XL2     &                !PE     TypeV2Obj ID                DDir/&MBŐ0LfEcAlgoEcMEcN EcBSize   EcIndexEcDistCSumAlgoPartNumsPartETags 04a0bd190d84b31c81e37e957dc1aa3fPartSizesCPartASizesCPartIdx SizeCMTime!PEͧMetaSysx-minio-internal-inline-datatruex-rustfs-internal-inline-datatrueMetaUsretag 04a0bd190d84b31c81e37e957dc1aa3fcontent-typetext/x-python-scriptv D,QnullC>P,-Xcȴc̃{jazC6import argparse
import json
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path

import genesis as gs
import h5py
import imageio.v2 as imageio
import numpy as np

# Process-wide flag: set once gs.init() has run (see ensure_gs_init).
_GS_INITIALIZED = False
# Panda configuration applied to both arms right after scene.build():
# seven arm joint positions followed by two finger positions (0.04 each).
HOME_QPOS = np.array(
    [
        0.0, -0.785, 0.0, -2.356, 0.0, 1.571, 0.785,  # arm joints
        0.04, 0.04,  # fingers
    ],
    dtype=np.float32,
)


def ensure_gs_init(backend: str):
    """Call ``gs.init`` at most once per process.

    ``backend == "gpu"`` selects ``gs.gpu``; any other value selects ``gs.cpu``.
    Later calls are no-ops, whatever backend they request.
    """
    global _GS_INITIALIZED
    if _GS_INITIALIZED:
        return
    chosen_backend = gs.gpu if backend == "gpu" else gs.cpu
    gs.init(backend=chosen_backend)
    _GS_INITIALIZED = True


def load_episode(h5_path: Path):
    """Load one recorded episode and trim all per-frame streams to a common length.

    Arm states come from ``load_arm_states(h5_path)``. Pipette pose, tube pose
    and liquid particle positions (dataset ``liquid_particles``) are read from
    ``h5_path``: from the ``data/demo_0/obs`` group when it exists, otherwise
    from the file root. ``dt`` is the file attribute ``dt`` (default 0.05).

    An optional ``init.json`` next to ``h5_path`` supplies
    ``scene.envs_offset_xyz_m`` (default zeros) and ``pipette.init_pos`` /
    ``tube.init_pos`` (``None`` when absent). The offset is only returned;
    it is not applied to any pose here.

    Returns:
        dict with ``left_state``, ``right_state``, ``pipette_pose``,
        ``tube_pose``, ``particle_pos`` (each cut to the shortest stream's
        length), ``env_offset``, ``init_pipette_pos``, ``init_tube_pos`` and
        ``dt``.
    """
    left_state, right_state = load_arm_states(h5_path)

    with h5py.File(h5_path, "r") as h5:
        group = h5["data/demo_0/obs"] if "data/demo_0/obs" in h5 else h5
        pipette_pose = np.asarray(group["pipette_pose"], dtype=np.float32)  # [T,7]
        tube_pose = np.asarray(group["tube_pose"], dtype=np.float32)  # [T,7]
        particle_pos = np.asarray(group["liquid_particles"], dtype=np.float32)  # [T,P,3]
        dt = float(h5.attrs.get("dt", 5e-2))

    init_json = h5_path.with_name("init.json")
    env_offset = np.zeros(3, dtype=np.float32)
    init_pipette_pos = init_tube_pos = None
    if init_json.exists():
        with open(init_json, "r", encoding="utf-8") as fh:
            meta = json.load(fh)
        scene_meta = meta.get("scene", {})
        env_offset = np.asarray(scene_meta.get("envs_offset_xyz_m", [0.0, 0.0, 0.0]), dtype=np.float32)
        if "init_pos" in meta.get("pipette", {}):
            init_pipette_pos = np.asarray(meta["pipette"]["init_pos"], dtype=np.float32)
        if "init_pos" in meta.get("tube", {}):
            init_tube_pos = np.asarray(meta["tube"]["init_pos"], dtype=np.float32)

    # Streams may be recorded with slightly different lengths; keep the overlap.
    n_frames = min(
        stream.shape[0]
        for stream in (left_state, right_state, pipette_pose, tube_pose, particle_pos)
    )
    return {
        "left_state": left_state[:n_frames],
        "right_state": right_state[:n_frames],
        "pipette_pose": pipette_pose[:n_frames],
        "tube_pose": tube_pose[:n_frames],
        "particle_pos": particle_pos[:n_frames],
        "env_offset": env_offset,
        "init_pipette_pos": init_pipette_pos,
        "init_tube_pos": init_tube_pos,
        "dt": dt,
    }


def align_recorded_positions_to_init(data):
    """Shift recorded object/particle positions so frame 0 matches ``init.json``.

    For each available reference (``init_tube_pos`` vs. ``tube_pose[0, :3]``
    and ``init_pipette_pos`` vs. ``pipette_pose[0, :3]``), the position error
    is computed. Their element-wise median is the correction, with z forced to
    0. The correction is added to pipette xyz, tube xyz and every liquid
    particle.

    Returns ``data`` itself (unchanged) when no init position is known, when a
    trajectory needed as a reference has no frames, or when the correction
    norm is below 1e-4. Otherwise returns a shallow copy of ``data`` holding
    new, shifted ``pipette_pose`` / ``tube_pose`` / ``particle_pos`` arrays.
    The input arrays are never modified.
    """
    init_tube_pos = data.get("init_tube_pos")
    init_pipette_pos = data.get("init_pipette_pos")
    if init_tube_pos is None and init_pipette_pos is None:
        return data

    references = []
    if init_tube_pos is not None:
        references.append((init_tube_pos, data["tube_pose"]))
    if init_pipette_pos is not None:
        references.append((init_pipette_pos, data["pipette_pose"]))
    # An empty episode has no frame 0 to align against; indexing [0] would
    # raise IndexError, so leave the data untouched instead.
    if any(len(track) == 0 for _, track in references):
        return data

    corrections = [
        np.asarray(init_pos, dtype=np.float32) - track[0, :3]
        for init_pos, track in references
    ]
    correction = np.median(np.stack(corrections, axis=0), axis=0).astype(np.float32)
    # The legacy bug is an env xy offset issue. Do not correct z here:
    # pipette's recorded base-link z can differ from init_pos because the MJCF
    # has an internal body offset, while tube/liquid z is already correct.
    correction[2] = 0.0
    if np.linalg.norm(correction) < 1e-4:
        return data

    data = dict(data)
    data["pipette_pose"] = data["pipette_pose"].copy()
    data["tube_pose"] = data["tube_pose"].copy()
    data["particle_pos"] = data["particle_pos"].copy()
    data["pipette_pose"][:, :3] += correction
    data["tube_pose"][:, :3] += correction
    data["particle_pos"] += correction[None, None, :]
    print(f"[offset] aligned recorded poses to init.json by {correction.tolist()}")
    return data


def _gripper_to_normalized(gripper_position, gripper_units=""):
    gripper_position = np.asarray(gripper_position, dtype=np.float32)
    units = str(gripper_units).lower()
    if "meter" in units or "meters" in units:
        return np.clip(gripper_position / 0.04, 0.0, 1.0).astype(np.float32)
    if gripper_position.size and np.nanmax(gripper_position) <= 0.08:
        return np.clip(gripper_position / 0.04, 0.0, 1.0).astype(np.float32)
    return np.clip(gripper_position, 0.0, 1.0).astype(np.float32)


def _states_from_joint_gripper(joint_position, gripper_position, gripper_units=""):
    """Split bimanual joint/gripper arrays into per-arm 8-dim float32 states.

    Joint columns 0-6 go to the left arm and columns 7-13 to the right arm.
    Gripper column 0 (left) and column 1 (right), normalized by
    ``_gripper_to_normalized``, are appended as the 8th value of each state.
    """
    joints = np.asarray(joint_position, dtype=np.float32)
    grippers = _gripper_to_normalized(gripper_position, gripper_units)
    left = np.hstack((joints[:, 0:7], grippers[:, 0:1]))
    right = np.hstack((joints[:, 7:14], grippers[:, 1:2]))
    return left.astype(np.float32), right.astype(np.float32)


def _read_joint_gripper(h5, layouts):
    """Return per-arm states from the first (joint_key, gripper_key) pair present in ``h5``, else None."""
    for joint_key, gripper_key in layouts:
        if joint_key in h5 and gripper_key in h5:
            return _states_from_joint_gripper(
                h5[joint_key],
                h5[gripper_key],
                h5.attrs.get("gripper_units", ""),
            )
    return None


def load_arm_states(h5_path: Path):
    """Return ``(left_state, right_state)`` float32 arrays for an episode.

    Lookup order:
      1. ``states_actions.hdf5`` next to ``h5_path`` (if it exists), using
         ``obs/...`` then ``observation/...`` joint/gripper datasets;
      2. ``h5_path`` itself, using ``data/demo_0/obs/...``, ``obs/...`` and then
         ``observation/...`` joint/gripper datasets;
      3. the ``left_state`` / ``right_state`` datasets of ``h5_path``.
    ``gripper_units`` is read from the root attributes of whichever file matched.
    """
    sidecar_layouts = (
        ("obs/joint_position", "obs/gripper_position"),
        ("observation/joint_position", "observation/gripper_position"),
    )
    episode_layouts = (
        ("data/demo_0/obs/joint_position", "data/demo_0/obs/gripper_position"),
    ) + sidecar_layouts

    sidecar_path = h5_path.with_name("states_actions.hdf5")
    if sidecar_path.exists():
        with h5py.File(sidecar_path, "r") as h5:
            states = _read_joint_gripper(h5, sidecar_layouts)
        if states is not None:
            return states

    with h5py.File(h5_path, "r") as h5:
        states = _read_joint_gripper(h5, episode_layouts)
        if states is not None:
            return states
        return (
            np.asarray(h5["left_state"], dtype=np.float32),
            np.asarray(h5["right_state"], dtype=np.float32),
        )


def _extract_rgb(frame):
    if isinstance(frame, tuple):
        frame = frame[0]
    elif isinstance(frame, dict):
        frame = frame.get("rgb", next(iter(frame.values())))
    frame = np.asarray(frame)
    if frame.dtype != np.uint8:
        frame = np.clip(frame, 0, 255).astype(np.uint8)
    return frame


def set_entity_pose(entity, pose7):
    """Teleport ``entity`` to ``pose7`` (xyz in ``pose7[:3]``, quaternion in ``pose7[3:7]``).

    Both parts are converted to float32 tuples. ``entity.set_pose(pos=, quat=)``
    is tried first. If it fails for any reason (for example, the entity type
    has no ``set_pose``), the function falls back to ``set_pos`` followed by
    ``set_quat``.

    A failing ``set_quat`` in the fallback is still tolerated (best effort).
    It now emits a ``RuntimeWarning`` instead of silently leaving the entity
    with a stale orientation. Errors from ``set_pos`` propagate.
    """
    import warnings

    pos = tuple(np.asarray(pose7[:3], dtype=np.float32).tolist())
    quat = tuple(np.asarray(pose7[3:7], dtype=np.float32).tolist())
    try:
        entity.set_pose(pos=pos, quat=quat)
    except Exception:
        # Deliberate fallback: not every entity exposes a combined pose setter.
        entity.set_pos(pos)
        try:
            entity.set_quat(quat)
        except Exception as exc:
            warnings.warn(
                f"set_entity_pose: set_quat failed on {type(entity).__name__} "
                f"({exc!r}); orientation was not updated",
                RuntimeWarning,
                stacklevel=2,
            )


def state8_to_qpos9(s8):
    """Turn an 8-value arm state into a 9-value Panda qpos (float32).

    ``s8[:7]`` are copied as the arm joints. ``s8[7]`` is the normalized gripper
    opening; it is clipped to [0, 1] and scaled to a 0..0.04 m finger position,
    which is used for both fingers.
    """
    finger_width = 0.04 * float(np.clip(s8[7], 0.0, 1.0))
    return np.array(list(s8[:7]) + [finger_width, finger_width], dtype=np.float32)


def make_wrist_offset():
    """Build the 4x4 float32 transform used to attach the wrist cameras to the hand link.

    Translation is (-0.08, 0.0, -0.035). The rotation is
    Rz(90°) @ Ry(180°) @ Rx(-10°) @ Rz(180°), where the final 180° turn about
    the camera's local Z axis flips the image vertically.
    """
    yaw, pitch, roll = np.deg2rad([90.0, 180.0, -10.0])
    cy, sy = np.cos(yaw), np.sin(yaw)
    cp, sp = np.cos(pitch), np.sin(pitch)
    cr, sr = np.cos(roll), np.sin(roll)

    rot_z = np.array([[cy, -sy, 0], [sy, cy, 0], [0, 0, 1]], dtype=np.float32)
    rot_y = np.array([[cp, 0, sp], [0, 1, 0], [-sp, 0, cp]], dtype=np.float32)
    rot_x = np.array([[1, 0, 0], [0, cr, -sr], [0, sr, cr]], dtype=np.float32)
    # 180° about the optical (Z) axis -> vertically flipped image.
    flip_optical = np.diag(np.array([-1, -1, 1], dtype=np.float32))

    transform = np.eye(4, dtype=np.float32)
    transform[:3, :3] = rot_z @ rot_y @ rot_x @ flip_optical
    transform[:3, 3] = (-0.08, 0.0, -0.035)
    return transform


def build_scene(data, res_w, res_h):
    """Create, populate and build the Genesis scene used to replay one episode.

    Args:
        data: Episode dict (as returned by ``load_episode``). Only
            ``data["pipette_pose"][0]`` and ``data["tube_pose"][0]`` (their xyz
            is used as spawn positions) and ``data["dt"]`` (simulation step) are
            read here. Recorded orientations are NOT used for spawning; tube
            and pipette get the fixed Euler angles given below.
        res_w: Width in pixels of every camera.
        res_h: Height in pixels of every camera.

    Returns:
        Tuple ``(scene, franka_left, franka_right, pipette, tube, liquid,
        cam_agent, cam_wrist, cam_wrist_2)``. The scene is already built, both
        arms are set (and PD-targeted) to ``HOME_QPOS``, and the two wrist
        cameras are attached to the left/right arms' ``"hand"`` links.

    NOTE(review): asset paths below are relative; they must resolve from
    wherever Genesis looks them up (working directory or its asset dir) — confirm.
    NOTE(review): the order of ``add_entity`` calls is kept deliberately, since it
    presumably determines entity indices — do not reorder without checking.
    """
    # First recorded frame of each object; only the xyz part is used below.
    pip0 = data["pipette_pose"][0].copy()
    tube0 = data["tube_pose"][0].copy()

    # Use the same nominal liquid box as pipette2_parallel.py so SPH samples
    # the exact same number of particles as recorded.
    LIQUID_NOMINAL_POS = (-0.028, 0.018, 1.1)
    LIQUID_NOMINAL_SIZE = (0.016, 0.016, 0.35)

    # substeps=100: physics sub-iterations run on every scene.step() call.
    # SPH domain bounds and particle size presumably in meters — confirm.
    scene = gs.Scene(
        sim_options=gs.options.SimOptions(dt=data["dt"], substeps=100),
        sph_options=gs.options.SPHOptions(
            lower_bound=(-0.5, -0.5, 0.0),
            upper_bound=(0.8, 0.5, 2.0),
            particle_size=0.005,
        ),
        renderer=gs.renderers.Rasterizer(),
        show_viewer=False,
    )

    # Ground plane with collision disabled.
    scene.add_entity(gs.morphs.Plane(collision=False))

    # Table
    scene.add_entity(
        gs.morphs.MJCF(file='asset/model/misc/simple_table.xml',
                       pos=(0.0, 0.0, 0.0), decimate=False, euler=(0, 0, 90), collision=True)
    )

    # Tube rack
    scene.add_entity(
        gs.morphs.MJCF(file='asset/model/object/centrifuge_10slot.gen.xml',
                       pos=(0.1, 0.0, 0.854), decimate=False, euler=(0, 0, 0), collision=True)
    )

    # Tube (collision off; its pose is overwritten from the recording during replay)
    tube = scene.add_entity(
        gs.morphs.MJCF(file='asset/model/object/centrifuge_50ml_collision.gen.xml',
                       pos=tuple(tube0[:3].tolist()), decimate=False, euler=(0, 0, 0), collision=False)
    )

    # Pipette (collision off; spawned at the first recorded xyz with a fixed orientation)
    pipette = scene.add_entity(
        gs.morphs.MJCF(file='asset/model/object/pipette_add_stiffness.gen.xml',
                       pos=tuple(pip0[:3].tolist()), decimate=False, euler=(0, 10, 180), collision=False)
    )

    # Pipette rack
    scene.add_entity(
        gs.morphs.MJCF(file='asset/model/object/pipette_rack.gen.xml',
                       pos=(-0.27, 0.0, 0.824), decimate=False, euler=(0, 0, 90), collision=True)
    )

    # Liquid (recon mode for surface rendering)
    liquid = scene.add_entity(
        material=gs.materials.SPH.Liquid(
            mu=0.001, gamma=0.01, stiffness=50000.0, sampler="regular",
        ),
        morph=gs.morphs.Box(
            pos=LIQUID_NOMINAL_POS,
            size=LIQUID_NOMINAL_SIZE,
        ),
        surface=gs.surfaces.Default(
            color=(0.4, 0.8, 1.0, 1.0),
            vis_mode="recon",
        ),
    )

    # Left arm (mirrored with the right arm about y=0, opposite yaw)
    franka_left = scene.add_entity(
        gs.morphs.MJCF(file='xml/franka_emika_panda/panda.xml',
                       pos=(0.0, -0.5, 0.824), euler=(0, 0, 90), collision=False)
    )

    # Right arm
    franka_right = scene.add_entity(
        gs.morphs.MJCF(file='xml/franka_emika_panda/panda.xml',
                       pos=(0.0, 0.5, 0.824), euler=(0, 0, -90), collision=False)
    )

    # Cameras. The wrist cameras' pos/lookat are placeholders; presumably their
    # pose is driven by the hand links once attached below — confirm.
    cam_agent = scene.add_camera(
        res=(res_w, res_h),
        pos=(2.5, 0.0, 2.0),
        lookat=(0.0, 0.0, 1.0),
        fov=30,
        GUI=False,
    )
    cam_wrist = scene.add_camera(
        res=(res_w, res_h),
        pos=(0.0, 0.0, 0.0),
        lookat=(1.0, 0.0, 0.0),
        fov=70,
        GUI=False,
    )
    cam_wrist_2 = scene.add_camera(
        res=(res_w, res_h),
        pos=(0.0, 0.0, 0.0),
        lookat=(1.0, 0.0, 0.0),
        fov=70,
        GUI=False,
    )

    # Everything below operates on the built scene (presumably required by
    # Genesis for DOF access and camera attachment).
    scene.build()

    # Initialize arm poses: set the state and also the PD target, so the
    # controller holds HOME_QPOS rather than driving toward another target.
    for robot in (franka_left, franka_right):
        robot.set_dofs_position(HOME_QPOS)
        robot.control_dofs_position(HOME_QPOS)

    # Same offset for both wrist cameras; the right one gets its own copy so
    # the two attachments never share one array.
    wrist_offset = make_wrist_offset()
    cam_wrist.attach(franka_left.get_link("hand"), wrist_offset)
    cam_wrist_2.attach(franka_right.get_link("hand"), wrist_offset.copy())

    return scene, franka_left, franka_right, pipette, tube, liquid, cam_agent, cam_wrist, cam_wrist_2


def render_episode(h5_path: Path, output_dir: Path, fps: int, res_w: int, res_h: int, backend: str):
    """Replay one recorded episode in Genesis and encode three camera videos.

    Loads the episode, aligns recorded positions to ``init.json`` (when
    present) and builds the scene. Then, for every recorded frame, it drives
    both arms, the pipette, the tube and the liquid particles to the recorded
    state, steps the simulation once and appends one frame per camera.

    Writes ``agentview.mp4``, ``wrist.mp4`` and ``wrist_2.mp4`` into
    ``output_dir`` (created if missing) at ``fps`` frames per second.

    Cleanup is handled by an ``ExitStack``. If anything fails after the scene
    is built (creating the directory, opening a writer, rendering, or closing
    another writer), every writer opened so far is still closed and the scene
    is still destroyed.
    """
    from contextlib import ExitStack

    ensure_gs_init(backend)

    data = load_episode(h5_path)
    data = align_recorded_positions_to_init(data)
    (scene, franka_left, franka_right, pipette, tube, liquid,
     cam_agent, cam_wrist, cam_wrist_2) = build_scene(data, res_w, res_h)

    with ExitStack() as cleanup:
        # Registered first so it runs last (LIFO), after all writers are closed.
        cleanup.callback(scene.destroy)

        output_dir.mkdir(parents=True, exist_ok=True)
        camera_writers = []
        for camera, filename in (
            (cam_agent, "agentview.mp4"),
            (cam_wrist, "wrist.mp4"),
            (cam_wrist_2, "wrist_2.mp4"),
        ):
            writer = imageio.get_writer(str(output_dir / filename), fps=fps)
            cleanup.callback(writer.close)
            camera_writers.append((camera, writer))

        n_frames = data["left_state"].shape[0]
        for t in range(n_frames):
            # Set arm joint positions, and also drive PD target to the same
            # value so the controller does not pull the joints back to home
            # during the (substeps=100) physics integration that scene.step()
            # runs on every frame.
            left_q9 = state8_to_qpos9(data["left_state"][t])
            right_q9 = state8_to_qpos9(data["right_state"][t])
            franka_left.set_qpos(left_q9)
            franka_right.set_qpos(right_q9)
            franka_left.control_dofs_position(left_q9)
            franka_right.control_dofs_position(right_q9)

            # Overwrite object poses each frame to suppress drift
            # (set_entity_pose only reads the row, so no copy is needed).
            set_entity_pose(pipette, data["pipette_pose"][t])
            set_entity_pose(tube, data["tube_pose"][t])

            # Stored object and particle positions are already env-local.
            particle_pos = np.asarray(data["particle_pos"][t], dtype=np.float32)
            liquid.set_particles_pos(particle_pos[None, ...])

            scene.step()
            for camera, writer in camera_writers:
                writer.append_data(_extract_rgb(camera.render()))

            if t % 500 == 0:
                print(f"  frame {t}/{n_frames}")


def find_episodes(dataset_root: Path):
    """List ``(h5_path, episode_dir)`` pairs found under ``dataset_root``.

    Every ``episode_*`` sub-directory that contains ``data.h5`` is included,
    in sorted path order (plain lexicographic, so ``episode_10`` precedes
    ``episode_2``). A ``data.h5`` located directly in ``dataset_root`` is
    appended last, paired with ``dataset_root`` itself.
    """
    found = [
        (episode_dir / "data.h5", episode_dir)
        for episode_dir in sorted(dataset_root.glob("episode_*"))
        if episode_dir.is_dir() and (episode_dir / "data.h5").exists()
    ]
    root_h5 = dataset_root / "data.h5"
    if root_h5.exists() and root_h5 not in {h5 for h5, _ in found}:
        found.append((root_h5, dataset_root))
    return found


def _worker_render(args_tuple):
    """Process-pool entry point that renders a single job.

    ``args_tuple`` is ``(h5_path, output_dir, fps, res_w, res_h, backend)``;
    the paths may be ``str`` or ``Path``. Returns ``str(output_dir)`` so the
    parent can report which job finished.
    """
    h5_ref, out_ref, frame_rate, width, height, sim_backend = args_tuple
    render_episode(
        Path(h5_ref),
        Path(out_ref),
        fps=frame_rate,
        res_w=width,
        res_h=height,
        backend=sim_backend,
    )
    return str(out_ref)


def main():
    """CLI entry point: render one episode (``--h5``) or every episode under ``--dataset_root``.

    Jobs run sequentially when only one worker is effective; in that mode the
    first failure aborts the run. Otherwise jobs run in a process pool of at
    most ``len(jobs)`` workers. There, each failure is reported together with
    the episode it belongs to, and once all jobs have finished the process
    exits with a non-zero status if any job failed.
    """
    parser = argparse.ArgumentParser(description="Replay renderer for pipette task h5 data (liquid surface recon)")
    parser.add_argument("--h5", type=str, default="",
                        help="Path to a single data.h5 file. Mutually exclusive with --dataset_root.")
    parser.add_argument("--dataset_root", type=str, default="",
                        help="Folder containing episode_xxx subfolders with data.h5")
    parser.add_argument("--output", type=str, default="",
                        help="Output directory for rendered videos (default: same dir as h5)")
    parser.add_argument("--fps", type=int, default=20)
    parser.add_argument("--res_w", type=int, default=256)
    parser.add_argument("--res_h", type=int, default=256)
    parser.add_argument("--backend", type=str, choices=["gpu", "cpu"], default="gpu")
    parser.add_argument("--workers", type=int, default=1)
    args = parser.parse_args()

    jobs = []  # list of (h5_path, output_dir)

    if args.h5:
        if args.dataset_root:
            # Documented as mutually exclusive; --h5 has always taken
            # precedence, so keep that but make the ignored flag visible.
            print(f"[warn] --h5 and --dataset_root are mutually exclusive; "
                  f"ignoring --dataset_root={args.dataset_root}")
        h5_path = Path(args.h5)
        out = Path(args.output) if args.output else h5_path.parent
        jobs.append((h5_path, out))
    elif args.dataset_root:
        episodes = find_episodes(Path(args.dataset_root))
        if not episodes:
            raise FileNotFoundError(f"No data.h5 found under: {args.dataset_root}")
        for h5_path, ep_dir in episodes:
            out = Path(args.output) / ep_dir.name if args.output else ep_dir
            jobs.append((h5_path, out))
    else:
        parser.error("Provide either --h5 or --dataset_root")

    # Never start more processes than there are jobs.
    workers = max(1, min(int(args.workers), len(jobs)))

    if workers == 1:
        for h5_path, out_dir in jobs:
            print(f"Rendering {h5_path} -> {out_dir}")
            render_episode(h5_path, out_dir, fps=args.fps, res_w=args.res_w, res_h=args.res_h, backend=args.backend)
            print(f"Done: {out_dir}")
        return

    failures = []
    with ProcessPoolExecutor(max_workers=workers) as ex:
        # Map each future back to its job so failures can name the episode.
        future_to_job = {
            ex.submit(_worker_render, (str(h5), str(out), args.fps, args.res_w, args.res_h, args.backend)): (h5, out)
            for h5, out in jobs
        }
        for fut in as_completed(future_to_job):
            h5, out = future_to_job[fut]
            try:
                print(f"Done: {fut.result()}")
            except Exception as e:
                failures.append(h5)
                print(f"Failed: {h5} -> {out}: {e!r}")

    if failures:
        raise SystemExit(f"{len(failures)}/{len(jobs)} episode(s) failed")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when this module is imported.
    main()
