import argparse
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path

import genesis as gs
import h5py
import imageio.v2 as imageio
import numpy as np

# Process-wide flag: flipped to True by ensure_gs_init() after gs.init() has run,
# so that repeated render_episode() calls in one process initialise Genesis once.
_GS_INITIALIZED = False
# 9-value joint configuration applied to both Franka arms in build_scene() right
# after the scene is built. Presumably 7 arm joints followed by 2 finger joints
# (the same layout state8_to_qpos9() produces) -- TODO confirm against panda.xml.
HOME_QPOS = np.array([0.0, -0.785, 0.0, -2.356, 0.0, 1.571, 0.785, 0.04, 0.04], dtype=np.float32)


def ensure_gs_init(backend: str):
    """Initialise Genesis the first time this is called in a process.

    Args:
        backend: ``"gpu"`` selects ``gs.gpu``; any other value selects ``gs.cpu``.

    Subsequent calls return immediately, so the backend requested by a later
    call is ignored.
    """
    global _GS_INITIALIZED
    if _GS_INITIALIZED:
        return

    if backend == "gpu":
        selected_backend = gs.gpu
    else:
        selected_backend = gs.cpu
    gs.init(backend=selected_backend)
    _GS_INITIALIZED = True


def _decode_attr(value):
    """Return an HDF5 attribute as text when it was stored as a byte string.

    Fixed-length string attributes can come back as ``bytes``; callers later
    pass the value through ``str()``, which would otherwise yield ``"b'...'"``.
    """
    if isinstance(value, bytes):
        return value.decode("utf-8")
    return value


def _init_array(init, key, fallback):
    """Read ``init[key]`` as float32, or return ``fallback()`` when it is absent.

    ``fallback`` is a zero-argument callable so that defaults derived from the
    trajectory are only evaluated when they are actually needed.
    """
    if init is not None and key in init:
        return np.asarray(init[key], dtype=np.float32)
    return fallback()


def _init_asset(init, key, default):
    """Read the asset path stored in ``init.attrs[key]``, else ``default``."""
    if init is None:
        return default
    return _decode_attr(init.attrs.get(key, default))


def load_episode(h5_path: Path):
    """Load one recorded episode into a plain dict of numpy arrays and settings.

    Arm states come from ``load_arm_states``. Pipette trajectories are read from
    ``data/demo_0/obs`` when that layout exists, otherwise from ``obs`` or the
    file root. Initial scene placement is read from the optional ``init`` group;
    every missing entry falls back to a built-in default or to the first
    recorded pipette sample.

    Args:
        h5_path: Path to the episode's HDF5 file.

    Returns:
        dict with ``left_state``, ``right_state``, ``pipette_pose``,
        ``pipette_tip_pos`` and ``pipette_qpos`` (``None`` when not recorded),
        all truncated to the shortest common length, plus the initial-placement
        entries, the three asset file paths and ``dt``.
    """
    left_state, right_state = load_arm_states(h5_path)
    with h5py.File(h5_path, "r") as f:
        demo = f["data/demo_0"] if "data/demo_0" in f else f
        obs = demo["obs"] if "obs" in demo else demo
        pipette_pose = np.asarray(obs["pipette_pose"], dtype=np.float32)  # [T,7]
        pipette_tip_pos = np.asarray(obs["pipette_tip_pos"], dtype=np.float32)  # [T,3]
        pipette_qpos = np.asarray(obs["pipette_qpos"], dtype=np.float32) if "pipette_qpos" in obs else None
        dt = float(f.attrs.get("dt", 5e-2))
        init = demo["init"] if "init" in demo else None

        # (output key, key inside the init group, lazily evaluated default)
        array_specs = (
            ("env_offset", "env_offset_xyz_m", lambda: np.zeros(3, dtype=np.float32)),
            ("pipette_init_pos", "pipette_init_pos", lambda: pipette_pose[0, :3].copy()),
            ("pipette_init_quat", "pipette_init_quat", lambda: pipette_pose[0, 3:7].copy()),
            ("pipette_tip_init_pos", "pipette_tip_init_pos", lambda: pipette_tip_pos[0].copy()),
            ("cell_dish_init_pos", "cell_dish_init_pos", lambda: np.array([0.1, 0.0, 0.854], dtype=np.float32)),
            ("cell_dish_init_euler_deg", "cell_dish_init_euler_deg", lambda: np.array([0.0, 0.0, 90.0], dtype=np.float32)),
            ("pipette_rack_init_pos", "pipette_rack_init_pos", lambda: np.array([-0.27, 0.0, 0.824], dtype=np.float32)),
            ("pipette_rack_init_euler_deg", "pipette_rack_init_euler_deg", lambda: np.array([0.0, 0.0, 90.0], dtype=np.float32)),
        )
        asset_defaults = (
            ("pipette_file", "asset/model/object/pipette_free_tip.gen.xml"),
            ("cell_dish_file", "asset/model/object/cell_dish_100.gen.xml"),
            ("pipette_rack_file", "asset/model/object/pipette_rack.gen.xml"),
        )

        # Everything must be materialised while the file is still open.
        init_data = {
            out_key: _init_array(init, h5_key, fallback)
            for out_key, h5_key, fallback in array_specs
        }
        for key, default in asset_defaults:
            init_data[key] = _init_asset(init, key, default)

    # Streams may have been recorded with different lengths; replay only the
    # frames that every stream covers.
    lengths = [left_state.shape[0], right_state.shape[0], pipette_pose.shape[0], pipette_tip_pos.shape[0]]
    if pipette_qpos is not None:
        lengths.append(pipette_qpos.shape[0])
    t = min(lengths)
    return {
        "left_state": left_state[:t],
        "right_state": right_state[:t],
        "pipette_pose": pipette_pose[:t],
        "pipette_tip_pos": pipette_tip_pos[:t],
        "pipette_qpos": pipette_qpos[:t] if pipette_qpos is not None else None,
        **init_data,
        "dt": dt,
    }


def _gripper_to_normalized(gripper_position, gripper_units=""):
    """Convert recorded gripper readings to a float32 array clipped to [0, 1].

    Readings are treated as metres, and divided by 0.04 before clipping, when
    ``gripper_units`` mentions "meter" or, failing that, when the array is
    non-empty and its largest non-NaN value does not exceed 0.08. Otherwise the
    values are assumed to be normalised already and are only clipped.
    """
    values = np.asarray(gripper_position, dtype=np.float32)

    looks_like_meters = "meter" in str(gripper_units).lower()
    if not looks_like_meters and values.size:
        # No unit hint: fall back to a magnitude heuristic.
        looks_like_meters = np.nanmax(values) <= 0.08

    scaled = values / 0.04 if looks_like_meters else values
    return np.clip(scaled, 0.0, 1.0).astype(np.float32)


def _states_from_joint_gripper(joint_position, gripper_position, gripper_units=""):
    """Build per-arm state arrays from combined joint and gripper recordings.

    Joint columns 0-6 plus gripper column 0 form the left state; joint columns
    7-13 plus gripper column 1 form the right state. Gripper values are
    normalised with ``_gripper_to_normalized`` first.

    Returns:
        ``(left_state, right_state)`` as float32 arrays.
    """
    joints = np.asarray(joint_position, dtype=np.float32)
    grip = _gripper_to_normalized(gripper_position, gripper_units)

    def _arm_state(joint_cols, grip_col):
        stacked = np.concatenate((joints[:, joint_cols], grip[:, grip_col]), axis=1)
        return stacked.astype(np.float32)

    return _arm_state(slice(0, 7), slice(0, 1)), _arm_state(slice(7, 14), slice(1, 2))


def load_arm_states(h5_path: Path):
    """Locate and load the left/right arm state trajectories for an episode.

    Lookup order:
      1. A sibling ``states_actions.hdf5`` file, under ``obs`` then ``observation``.
      2. ``h5_path`` itself, under ``data/demo_0/obs``, ``obs``, ``observation``.
      3. ``left_state`` / ``right_state`` datasets at the root of ``h5_path``.

    For 1 and 2 the joint/gripper datasets are converted with
    ``_states_from_joint_gripper``, using the file's ``gripper_units`` attribute.

    Returns:
        ``(left_state, right_state)`` as float32 arrays.
    """

    def _probe(h5_file, candidates):
        # Return the states from the first (joint, gripper) pair present, else None.
        for joint_key, gripper_key in candidates:
            if joint_key in h5_file and gripper_key in h5_file:
                return _states_from_joint_gripper(
                    h5_file[joint_key],
                    h5_file[gripper_key],
                    h5_file.attrs.get("gripper_units", ""),
                )
        return None

    sidecar_path = h5_path.with_name("states_actions.hdf5")
    if sidecar_path.exists():
        with h5py.File(sidecar_path, "r") as sidecar:
            found = _probe(
                sidecar,
                (
                    ("obs/joint_position", "obs/gripper_position"),
                    ("observation/joint_position", "observation/gripper_position"),
                ),
            )
        if found is not None:
            return found

    with h5py.File(h5_path, "r") as episode:
        found = _probe(
            episode,
            (
                ("data/demo_0/obs/joint_position", "data/demo_0/obs/gripper_position"),
                ("obs/joint_position", "obs/gripper_position"),
                ("observation/joint_position", "observation/gripper_position"),
            ),
        )
        if found is not None:
            return found
        left = np.asarray(episode["left_state"], dtype=np.float32)
        right = np.asarray(episode["right_state"], dtype=np.float32)
        return (left, right)


def _extract_rgb(frame):
    """Pull the colour image out of a render result and return it as uint8.

    A tuple yields its first element; a dict yields its ``"rgb"`` entry or,
    when that key is missing, its first value. Anything else is used as-is.
    Non-uint8 data is clipped to [0, 255] and cast.
    """
    if isinstance(frame, dict):
        frame = frame["rgb"] if "rgb" in frame else next(iter(frame.values()))
    elif isinstance(frame, tuple):
        frame = frame[0]

    pixels = np.asarray(frame)
    if pixels.dtype == np.uint8:
        return pixels
    return np.clip(pixels, 0, 255).astype(np.uint8)


def set_pipette_free_tip_pose(entity, pipette_pose7, tip_pos3):
    """Write a recorded pipette pose and tip position into ``entity``'s qpos.

    The current qpos is read once, the root joint's first seven entries are
    overwritten with ``pipette_pose7`` (three position values, then four
    orientation values), and the first three entries of the tip joint are
    overwritten with ``tip_pos3``. All other entries keep their current value.

    Args:
        entity: Object exposing ``get_qpos()``, ``get_joint(name)`` and
            ``set_qpos(qpos, zero_velocity=...)``.
        pipette_pose7: Sequence of at least 7 values for the pipette root joint.
        tip_pos3: Sequence of 3 values for the tip joint position.
    """
    # Query the simulator a single time; a tensor result (anything exposing
    # ``detach``) is converted to numpy first.
    current = entity.get_qpos()
    if hasattr(current, "detach"):
        current = current.detach().cpu().numpy()

    qpos = np.asarray(current, dtype=np.float32)
    # Non-1-D results are reduced to their first row; presumably a batch of
    # environments -- TODO confirm. The copy keeps the source untouched.
    qpos = (qpos if qpos.ndim == 1 else qpos[0]).copy()

    root_joint = entity.get_joint("pipette_root_joint")
    tip_joint = entity.get_joint("pipette_tip_root_joint")
    root_qidx = root_joint.qs_idx_local
    tip_qidx = tip_joint.qs_idx_local

    qpos[root_qidx[:3]] = np.asarray(pipette_pose7[:3], dtype=np.float32)
    qpos[root_qidx[3:7]] = np.asarray(pipette_pose7[3:7], dtype=np.float32)
    qpos[tip_qidx[:3]] = np.asarray(tip_pos3, dtype=np.float32)
    # Keep the recorded run's initial tip orientation; only tip position is stored.
    entity.set_qpos(qpos, zero_velocity=True)


def set_pipette_replay_state(entity, data, t):
    """Put the pipette entity into its recorded state for frame ``t``.

    When the episode contains ``pipette_qpos``, that row is written directly
    with zero velocity. Otherwise the state is rebuilt from the recorded
    pipette pose and tip position via ``set_pipette_free_tip_pose``.
    """
    if data.get("pipette_qpos") is None:
        set_pipette_free_tip_pose(entity, data["pipette_pose"][t], data["pipette_tip_pos"][t])
        return

    recorded_qpos = np.asarray(data["pipette_qpos"][t], dtype=np.float32)
    entity.set_qpos(recorded_qpos, zero_velocity=True)


def state8_to_qpos9(s8):
    """Expand an 8-value arm state into a 9-value float32 qpos.

    ``s8`` holds 7 joint values followed by a normalised gripper value. The
    gripper value is clipped to [0, 1], scaled to metres (0..0.04) and written
    to both finger slots.
    """
    opening_m = 0.04 * float(np.clip(s8[7], 0.0, 1.0))
    qpos = np.empty(9, dtype=np.float32)
    for joint_index in range(7):
        qpos[joint_index] = s8[joint_index]
    qpos[7:] = opening_m
    return qpos


def set_replay_frame(franka_left, franka_right, pipette, data, t):
    """Force both arms and the pipette into their recorded state for frame ``t``.

    Each arm's 8-value state is expanded with ``state8_to_qpos9``. Both arms
    first get their qpos set, then both get the same values as position-control
    targets, and finally the pipette is updated.
    """
    arm_targets = (
        (franka_left, state8_to_qpos9(data["left_state"][t])),
        (franka_right, state8_to_qpos9(data["right_state"][t])),
    )
    for arm, qpos9 in arm_targets:
        arm.set_qpos(qpos9)
    for arm, qpos9 in arm_targets:
        arm.control_dofs_position(qpos9)
    set_pipette_replay_state(pipette, data, t)


def make_wrist_offset():
    """Return the 4x4 float32 transform used to attach the wrist cameras.

    Translation is (-0.08, 0.0, -0.035). Rotation is Rz(90 deg) @ Ry(180 deg)
    @ Rx(-10 deg), followed by a 180 degree turn about the local Z axis.
    """

    def _cos_sin(angle_deg):
        angle = np.deg2rad(angle_deg)
        return np.cos(angle), np.sin(angle)

    cy, sy = _cos_sin(90.0)    # yaw
    cp, sp = _cos_sin(180.0)   # pitch
    cr, sr = _cos_sin(-10.0)   # roll

    rot_z = np.array([[cy, -sy, 0], [sy, cy, 0], [0, 0, 1]], dtype=np.float32)
    rot_y = np.array([[cp, 0, sp], [0, 1, 0], [-sp, 0, cp]], dtype=np.float32)
    rot_x = np.array([[1, 0, 0], [0, cr, -sr], [0, sr, cr]], dtype=np.float32)
    # Flip the camera image vertically by rotating 180 degrees about its local optical (Z) axis.
    image_flip = np.diag([-1, -1, 1]).astype(np.float32)

    transform = np.eye(4, dtype=np.float32)
    transform[:3, :3] = rot_z @ rot_y @ rot_x @ image_flip
    transform[:3, 3] = (-0.08, 0.0, -0.035)
    return transform


def build_scene(data, res_w, res_h):
    """Create and build the Genesis scene used to replay one episode.

    Entities are added in this order: ground plane, table, cell dish, pipette,
    pipette rack, left arm, right arm. Three cameras are then added (one fixed
    agent view and one per arm), the scene is built, both arms are moved to
    ``HOME_QPOS``, the two wrist cameras are attached to the arms' ``hand``
    links, and the pipette is placed at its frame-0 state.

    Args:
        data: Episode dict as returned by ``load_episode``.
        res_w: Camera image width in pixels.
        res_h: Camera image height in pixels.

    Returns:
        ``(scene, franka_left, franka_right, pipette, cam_agent, cam_wrist,
        cam_wrist_2)``.
    """
    pipette_start = data["pipette_pose"][0].copy()

    scene = gs.Scene(
        sim_options=gs.options.SimOptions(dt=data["dt"], substeps=25),
        renderer=gs.renderers.Rasterizer(),
        show_viewer=False,
    )

    def add_mjcf(path, position, euler_deg, collide, **morph_kwargs):
        # Add one MJCF model to the scene and return the created entity.
        morph = gs.morphs.MJCF(
            file=path,
            pos=position,
            euler=euler_deg,
            collision=collide,
            **morph_kwargs,
        )
        return scene.add_entity(morph)

    def recorded(key):
        # Convert a recorded numpy vector into a plain tuple.
        return tuple(data[key].tolist())

    scene.add_entity(gs.morphs.Plane(collision=False))

    # Static furniture and props (entity creation order is kept as recorded).
    add_mjcf('asset/model/misc/simple_table.xml', (0.0, 0.0, 0.0), (0, 0, 90), True, decimate=False)
    add_mjcf(
        str(data["cell_dish_file"]),
        recorded("cell_dish_init_pos"),
        recorded("cell_dish_init_euler_deg"),
        True,
        decimate=False,
    )
    # Pipette with detachable/free tip
    pipette = add_mjcf(
        str(data["pipette_file"]),
        tuple(pipette_start[:3].tolist()),
        (0, 10, 180),
        False,
        decimate=False,
    )
    add_mjcf(
        str(data["pipette_rack_file"]),
        recorded("pipette_rack_init_pos"),
        recorded("pipette_rack_init_euler_deg"),
        True,
        decimate=False,
    )

    # The two arms face each other across the table.
    franka_left = add_mjcf('xml/franka_emika_panda/panda.xml', (0.0, -0.5, 0.824), (0, 0, 90), False)
    franka_right = add_mjcf('xml/franka_emika_panda/panda.xml', (0.0, 0.5, 0.824), (0, 0, -90), False)

    resolution = (res_w, res_h)
    cam_agent = scene.add_camera(
        res=resolution,
        pos=(2.5, 0.0, 2.0),
        lookat=(0.0, 0.0, 1.0),
        fov=30,
        GUI=False,
    )
    # Both wrist cameras share the same settings.
    wrist_camera_kwargs = dict(
        res=resolution,
        pos=(0.0, 0.0, 0.0),
        lookat=(1.0, 0.0, 0.0),
        fov=70,
        GUI=False,
    )
    cam_wrist = scene.add_camera(**wrist_camera_kwargs)
    cam_wrist_2 = scene.add_camera(**wrist_camera_kwargs)

    scene.build()

    # Initialize arm poses
    for arm in (franka_left, franka_right):
        arm.set_dofs_position(HOME_QPOS)
        arm.control_dofs_position(HOME_QPOS)

    hand_offset = make_wrist_offset()
    cam_wrist.attach(franka_left.get_link("hand"), hand_offset)
    cam_wrist_2.attach(franka_right.get_link("hand"), hand_offset.copy())

    set_pipette_replay_state(pipette, data, 0)

    return scene, franka_left, franka_right, pipette, cam_agent, cam_wrist, cam_wrist_2


def render_episode(h5_path: Path, output_dir: Path, fps: int, res_w: int, res_h: int, backend: str):
    """Replay one recorded episode and write three camera videos.

    Writes ``agentview.mp4``, ``wrist.mp4`` and ``wrist_2.mp4`` into
    ``output_dir`` (created if missing), one video frame per recorded frame.

    Args:
        h5_path: Episode HDF5 file.
        output_dir: Directory that receives the videos.
        fps: Frame rate passed to the video writers.
        res_w: Render width in pixels.
        res_h: Render height in pixels.
        backend: ``"gpu"`` or ``"cpu"``; forwarded to ``ensure_gs_init``.

    Every video writer that was opened is closed and the scene is destroyed
    even when directory creation, opening a later writer, rendering, or one of
    the close calls fails.
    """
    # Local import keeps this block self-contained (module imports untouched).
    from contextlib import ExitStack

    ensure_gs_init(backend)

    data = load_episode(h5_path)
    if data.get("pipette_qpos") is None:
        print(
            f"[warn] {h5_path} has no obs/pipette_qpos; "
            "pipette ejector joints and tip orientation will be approximate. "
            "Regenerate data.h5 with the updated collector for accurate replay."
        )
    scene, franka_left, franka_right, pipette, cam_agent, cam_wrist, cam_wrist_2 = build_scene(data, res_w, res_h)

    with ExitStack() as cleanup:
        # Registered first so it runs last, after all writers are closed.
        cleanup.callback(scene.destroy)

        output_dir.mkdir(parents=True, exist_ok=True)

        streams = []
        for camera, filename in (
            (cam_agent, "agentview.mp4"),
            (cam_wrist, "wrist.mp4"),
            (cam_wrist_2, "wrist_2.mp4"),
        ):
            writer = imageio.get_writer(str(output_dir / filename), fps=fps)
            # Register immediately so a failure opening the next writer still
            # closes this one.
            cleanup.callback(writer.close)
            streams.append((camera, writer))

        T = data["left_state"].shape[0]
        for t in range(T):
            # Step once to refresh kinematics/camera attachments, then write the
            # recorded frame again so replayed free bodies do not drift before render.
            set_replay_frame(franka_left, franka_right, pipette, data, t)
            scene.step()
            set_replay_frame(franka_left, franka_right, pipette, data, t)

            for camera, writer in streams:
                writer.append_data(_extract_rgb(camera.render()))

            if t % 500 == 0:
                print(f"  frame {t}/{T}")


def find_episodes(dataset_root: Path):
    """List the episode recordings found under ``dataset_root``.

    Args:
        dataset_root: Directory to scan.

    Returns:
        A list of ``(h5_path, episode_dir)`` tuples: one per ``episode_*``
        sub-directory containing ``data.h5`` (sorted by directory path),
        followed by ``(dataset_root / "data.h5", dataset_root)`` when a bare
        ``data.h5`` sits directly in the root. Empty when nothing is found.
    """
    episodes = [
        (episode_dir / "data.h5", episode_dir)
        for episode_dir in sorted(dataset_root.glob("episode_*"))
        if episode_dir.is_dir() and (episode_dir / "data.h5").exists()
    ]
    # Also check for a bare data.h5 directly in the root. No duplicate check is
    # needed: the entries above live one directory deeper, so they can never
    # be equal to this path.
    bare = dataset_root / "data.h5"
    if bare.exists():
        episodes.append((bare, dataset_root))
    return episodes


def _worker_render(args_tuple):
    """Process-pool entry point: render one episode from a packed argument tuple.

    Args:
        args_tuple: ``(h5_path, output_dir, fps, res_w, res_h, backend)``.

    Returns:
        The output directory as a string.
    """
    source, destination, frame_rate, width, height, device = args_tuple
    render_episode(
        Path(source),
        Path(destination),
        fps=frame_rate,
        res_w=width,
        res_h=height,
        backend=device,
    )
    return str(destination)


def main():
    """Command-line entry point: render one episode or every episode in a dataset.

    Exactly one of ``--h5`` and ``--dataset_root`` must be given. With a single
    worker (or a single job) episodes are rendered in this process and the first
    failure propagates. With several workers, episodes are rendered in a
    process pool; each failure is reported with its input path and the process
    exits with a non-zero status once all jobs have finished.

    Raises:
        SystemExit: On invalid arguments, or when any parallel job failed.
        FileNotFoundError: When ``--dataset_root`` contains no ``data.h5``.
    """
    parser = argparse.ArgumentParser(description="Replay renderer for discard-pipette-tip h5 data")
    parser.add_argument("--h5", type=str, default="",
                        help="Path to a single data.h5 file. Mutually exclusive with --dataset_root.")
    parser.add_argument("--dataset_root", type=str, default="",
                        help="Folder containing episode_xxx subfolders with data.h5")
    parser.add_argument("--output", type=str, default="",
                        help="Output directory for rendered videos (default: same dir as h5)")
    parser.add_argument("--fps", type=int, default=20)
    parser.add_argument("--res_w", type=int, default=256)
    parser.add_argument("--res_h", type=int, default=256)
    parser.add_argument("--backend", type=str, choices=["gpu", "cpu"], default="gpu")
    parser.add_argument("--workers", type=int, default=1)
    args = parser.parse_args()

    # The help text promises exclusivity; previously --dataset_root was
    # silently ignored when both were supplied.
    if args.h5 and args.dataset_root:
        parser.error("--h5 and --dataset_root are mutually exclusive")

    jobs = []  # list of (h5_path, output_dir)

    if args.h5:
        h5_path = Path(args.h5)
        out = Path(args.output) if args.output else h5_path.parent
        jobs.append((h5_path, out))
    elif args.dataset_root:
        episodes = find_episodes(Path(args.dataset_root))
        if not episodes:
            raise FileNotFoundError(f"No data.h5 found under: {args.dataset_root}")
        for h5_path, ep_dir in episodes:
            out = Path(args.output) / ep_dir.name if args.output else ep_dir
            jobs.append((h5_path, out))
    else:
        parser.error("Provide either --h5 or --dataset_root")

    workers = max(1, int(args.workers))

    if workers == 1 or len(jobs) == 1:
        for h5_path, out_dir in jobs:
            print(f"Rendering {h5_path} -> {out_dir}")
            render_episode(h5_path, out_dir, fps=args.fps, res_w=args.res_w, res_h=args.res_h, backend=args.backend)
            print(f"Done: {out_dir}")
        return

    failed = []
    # No point in starting more processes than there are episodes.
    with ProcessPoolExecutor(max_workers=min(workers, len(jobs))) as ex:
        # Map each future back to its input so failures can be attributed.
        pending = {
            ex.submit(_worker_render, (str(h5), str(out), args.fps, args.res_w, args.res_h, args.backend)): h5
            for h5, out in jobs
        }
        for fut in as_completed(pending):
            try:
                print(f"Done: {fut.result()}")
            except Exception as e:
                failed.append(pending[fut])
                print(f"Failed: {pending[fut]}: {e}")

    if failed:
        # Surface the failure to shells and schedulers instead of exiting 0.
        raise SystemExit(f"{len(failed)} of {len(jobs)} episode(s) failed to render")


if __name__ == "__main__":
    main()
