XL2     &                !M]     TypeV2Obj ID                DDirf#|'I3]ѦEcAlgoEcMEcN EcBSize   EcIndexEcDistCSumAlgoPartNumsPartETags 58651c1537ce98ae543f97be9c52cac8PartSizes1PartASizes1PartIdx Size1MTime!M]MetaSysx-minio-internal-inline-datatruex-rustfs-internal-inline-datatrueMetaUsrcontent-typetext/x-python-scriptetag 58651c1537ce98ae543f97be9c52cac8v null1H_em	tK90qyimport argparse
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path

import genesis as gs
import h5py
import imageio.v2 as imageio
import numpy as np


_GS_INITIALIZED = False
FINGER_OPEN_M = 0.04


def ensure_gs_init(backend: str):
    """Initialise Genesis once for the current process.

    Args:
        backend: ``"gpu"`` selects ``gs.gpu``; any other value selects ``gs.cpu``.

    Later calls are no-ops, whatever backend they request.
    """
    global _GS_INITIALIZED
    if _GS_INITIALIZED:
        return
    selected = gs.gpu if backend == "gpu" else gs.cpu
    gs.init(backend=selected)
    _GS_INITIALIZED = True


def load_episode(h5_path: Path):
    """Load the arm states and object/liquid observations of one recorded episode.

    Arm states come from ``load_arm_states``. Beaker poses and liquid particle
    positions are read from ``h5_path`` using the first supported layout found:

    * ``data/demo_0/obs/{source_pose, target_pose, particle_pos}``
    * top-level ``source_pose``, ``target_pose`` and ``particle_pos``
    * top-level ``source_pose`` and ``target_pose`` with ``liquid/particle_pos``

    All arrays are truncated to the shortest leading (time) dimension so they
    can be indexed with the same frame index.

    Args:
        h5_path: Path to the episode's HDF5 file.

    Returns:
        ``(states, source_pose, target_pose, particle_pos, dt)``. The arrays are
        float32. ``dt`` is the file's ``dt`` attribute, or 0.05 when absent.

    Raises:
        KeyError: If none of the supported layouts is present.
    """
    states = load_arm_states(h5_path)
    with h5py.File(h5_path, "r") as f:
        # The two top-level layouts differ only in where the particle data lives.
        # The particle key must therefore be part of the test: checking
        # "source_pose" alone would route liquid/particle_pos files into the
        # top-level "particle_pos" branch and fail there.
        if "data/demo_0/obs/source_pose" in f:
            group, particle_key = f["data/demo_0/obs"], "particle_pos"
        elif "source_pose" in f and "particle_pos" in f:
            group, particle_key = f, "particle_pos"
        elif "source_pose" in f and "liquid/particle_pos" in f:
            group, particle_key = f, "liquid/particle_pos"
        else:
            raise KeyError(f"No source/target/particle observations found in {h5_path}")

        source_pose = np.asarray(group["source_pose"], dtype=np.float32)
        target_pose = np.asarray(group["target_pose"], dtype=np.float32)
        particle_pos = np.asarray(group[particle_key], dtype=np.float32)

        dt = float(f.attrs.get("dt", 5e-2))

    t = min(states.shape[0], source_pose.shape[0], target_pose.shape[0], particle_pos.shape[0])
    return states[:t], source_pose[:t], target_pose[:t], particle_pos[:t], dt


def _gripper_to_normalized(gripper_position, gripper_units=""):
    gripper_position = np.asarray(gripper_position, dtype=np.float32)
    units = str(gripper_units).lower()
    if "meter" in units or "meters" in units:
        return np.clip(gripper_position / FINGER_OPEN_M, 0.0, 1.0).astype(np.float32)
    if gripper_position.size and np.nanmax(gripper_position) <= 0.08:
        return np.clip(gripper_position / FINGER_OPEN_M, 0.0, 1.0).astype(np.float32)
    return np.clip(gripper_position, 0.0, 1.0).astype(np.float32)


def _states_from_joint_gripper(joint_position, gripper_position, gripper_units=""):
    """Build replay states: up to 7 arm joints followed by one normalized gripper column.

    Args:
        joint_position: 2-D array-like indexed ``[time, joint]``. Only the first
            7 columns are used.
        gripper_position: Gripper series, either flat ``(T,)`` or 2-D
            ``(T, k)``. It is normalized with ``_gripper_to_normalized`` and only
            its first column is used.
        gripper_units: Units hint forwarded to ``_gripper_to_normalized``.

    Returns:
        float32 array of shape ``(T, 8)`` when ``joint_position`` has at least 7 columns.
    """
    joint_position = np.asarray(joint_position, dtype=np.float32)
    gripper_position = _gripper_to_normalized(gripper_position, gripper_units)
    # A gripper series may be stored flat as (T,). Give it a column axis so it
    # can be sliced with [:, :1] and concatenated next to the joints.
    if gripper_position.ndim == 1:
        gripper_position = gripper_position[:, None]
    return np.concatenate([joint_position[:, :7], gripper_position[:, :1]], axis=1).astype(np.float32)


def _read_joint_gripper_states(f, prefixes):
    """Return replay states from the first prefix holding joint + gripper datasets, else None."""
    units = f.attrs.get("gripper_units", "")
    for prefix in prefixes:
        joint_key = f"{prefix}/joint_position"
        gripper_key = f"{prefix}/gripper_position"
        if joint_key in f and gripper_key in f:
            return _states_from_joint_gripper(f[joint_key], f[gripper_key], units)
    return None


def load_arm_states(h5_path: Path):
    """Load per-frame arm states (7 joints + normalized gripper) for replay.

    Sources are tried in order:

    1. ``states_actions.hdf5`` next to ``h5_path``: ``obs/...`` then
       ``observation/...`` joint/gripper datasets.
    2. ``h5_path`` itself: ``data/demo_0/obs/...``, ``obs/...`` and
       ``observation/...`` joint/gripper datasets.
    3. A raw ``states`` (or else ``state``) dataset in ``h5_path``. Its first
       8 columns are returned as-is.

    Args:
        h5_path: Path to the episode's HDF5 file.

    Returns:
        float32 array of shape ``(T, 8)``.

    Raises:
        ValueError: If only a raw state array is available and it is the
            7-column cartesian state, or otherwise not a 2-D array with at least
            8 columns.
        KeyError: If no usable dataset exists (raised by h5py for the missing
            ``state`` dataset).
    """
    states_actions_path = h5_path.with_name("states_actions.hdf5")
    if states_actions_path.exists():
        with h5py.File(states_actions_path, "r") as f:
            states = _read_joint_gripper_states(f, ("obs", "observation"))
        if states is not None:
            return states

    with h5py.File(h5_path, "r") as f:
        states = _read_joint_gripper_states(f, ("data/demo_0/obs", "obs", "observation"))
        if states is not None:
            return states
        raw = np.asarray(f["states"] if "states" in f else f["state"], dtype=np.float32)

    if raw.ndim == 2 and raw.shape[1] == 7:
        raise ValueError(
            "Cannot replay from top-level state alone because it is cartesian state. "
            "Expected obs/joint_position or data/demo_0/obs/joint_position with gripper_position."
        )
    # Replay indexes columns 0..7 (7 joints + gripper), so anything narrower
    # would fail later with an unhelpful IndexError.
    if raw.ndim != 2 or raw.shape[1] < 8:
        raise ValueError(
            f"Expected a 2-D state array with at least 8 columns (7 joints + gripper) "
            f"in {h5_path}, got shape {raw.shape}."
        )
    return raw[:, :8].astype(np.float32)


def _extract_rgb(frame):
    if isinstance(frame, tuple):
        frame = frame[0]
    elif isinstance(frame, dict):
        frame = frame.get("rgb", next(iter(frame.values())))
    frame = np.asarray(frame)
    if frame.dtype != np.uint8:
        frame = np.clip(frame, 0, 255).astype(np.uint8)
    return frame


def build_scene(first_source_pose, first_target_pose, dt, res_w, res_h):
    """Create and build the Genesis scene used to replay one episode.

    Args:
        first_source_pose: Source-beaker pose at frame 0. Only its first three
            elements (position) are used here.
        first_target_pose: Target-beaker pose at frame 0. Only its first three
            elements (position) are used here.
        dt: Simulation time step passed to ``SimOptions``.
        res_w: Image width in pixels for both cameras.
        res_h: Image height in pixels for both cameras.

    Returns:
        ``(scene, franka, source_beaker, target_beaker, liquid, cam_agent, cam_wrist)``.
        ``scene.build()`` has already been called and the wrist camera is
        attached to the robot's ``"hand"`` link.
    """
    scene = gs.Scene(
        sim_options=gs.options.SimOptions(dt=dt, substeps=100),
        # lower_bound/upper_bound define the SPH solver's domain. Recorded
        # particle positions are presumably expected to stay inside it (confirm).
        sph_options=gs.options.SPHOptions(
            pressure_solver="WCSPH",
            lower_bound=(-0.5, -1.0, 0.0),
            upper_bound=(1.2, 0.5, 1.2),
            particle_size=0.005,
        ),
        renderer=gs.renderers.Rasterizer(),
        show_viewer=False,
    )

    scene.add_entity(gs.morphs.Plane(collision=False))
    # NOTE(review): 'asset/...' paths are relative, presumably to the working
    # directory. Confirm where the script is expected to be run from.
    scene.add_entity(
        gs.morphs.MJCF(
            file='asset/model/misc/simple_table.xml',
            pos=(0.0, 0.0, 0.0),
            decimate=False,
            euler=(0, 0, 90),
        )
    )

    # Beakers spawn at their recorded frame-0 positions with collision disabled.
    # render_episode overwrites their poses from the recording every frame.
    source_beaker = scene.add_entity(
        gs.morphs.MJCF(file="asset/beaker/beaker.xml", pos=tuple(first_source_pose[:3].tolist()), collision=False),
    )
    target_beaker = scene.add_entity(
        gs.morphs.MJCF(file="asset/beaker/beaker.xml", pos=tuple(first_target_pose[:3].tolist()), collision=False),
    )

    # Robot base at a fixed position with collision disabled. Its joints are set
    # from the recording during replay.
    franka = scene.add_entity(gs.morphs.MJCF(file="xml/franka_emika_panda/panda.xml", pos=(-0.25, -0.5, 0.824), collision=False))

    # The box only seeds the initial liquid particles.
    # NOTE(review): render_episode replaces all particle positions with recorded
    # ones each frame. The particle count produced here (box size, particle_size,
    # sampler) presumably has to match the recording. Confirm.
    liquid = scene.add_entity(
        material=gs.materials.SPH.Liquid(mu=0.001, gamma=0.01, stiffness=50000.0, sampler="regular"),
        morph=gs.morphs.Box(pos=(0.55, -0.20, 0.45), size=(0.04, 0.04, 0.06)),
        surface=gs.surfaces.Default(color=(0.4, 0.8, 1.0, 1.0), vis_mode="recon"),
    )

    # Fixed third-person camera.
    cam_agent = scene.add_camera(
        res=(res_w, res_h),
        pos=(2.5, 0, 2.0),
        lookat=(0.0, 0.0, 1.0),
        fov=30,
        GUI=False,
    )
    # Wrist camera. This pos/lookat is a placeholder. The camera is attached to
    # the hand link after build (below) and presumably follows it from then on.
    cam_wrist = scene.add_camera(
        res=(res_w, res_h),
        pos=(0.0, 0.0, 0.0),
        lookat=(1.0, 0.0, 0.0),
        fov=70,
        GUI=False,
    )

    # All entities and cameras are added before build. The wrist camera is
    # attached after it.
    scene.build()

    # 4x4 transform of the wrist camera relative to the "hand" link:
    # translation set here, rotation filled in below.
    hand_link = franka.get_link("hand")
    wrist_offset = np.eye(4, dtype=np.float32)
    wrist_offset[:3, 3] = np.array([-0.08, 0.0, -0.035], dtype=np.float32)

    # Match wrist camera orientation from pourTask_liquid_better.py
    yaw = np.deg2rad(90.0)
    pitch = np.deg2rad(180.0)
    roll = np.deg2rad(-10.0)

    # Elementary rotations about z (yaw), y (pitch) and x (roll).
    rz = np.array(
        [
            [np.cos(yaw), -np.sin(yaw), 0.0],
            [np.sin(yaw), np.cos(yaw), 0.0],
            [0.0, 0.0, 1.0],
        ],
        dtype=np.float32,
    )
    ry = np.array(
        [
            [np.cos(pitch), 0.0, np.sin(pitch)],
            [0.0, 1.0, 0.0],
            [-np.sin(pitch), 0.0, np.cos(pitch)],
        ],
        dtype=np.float32,
    )
    rx = np.array(
        [
            [1.0, 0.0, 0.0],
            [0.0, np.cos(roll), -np.sin(roll)],
            [0.0, np.sin(roll), np.cos(roll)],
        ],
        dtype=np.float32,
    )

    # Composite rotation R = Rz @ Ry @ Rx.
    wrist_offset[:3, :3] = rz @ ry @ rx
    cam_wrist.attach(hand_link, wrist_offset)

    return scene, franka, source_beaker, target_beaker, liquid, cam_agent, cam_wrist


def set_entity_pose(entity, pose7):
    """Teleport ``entity`` to the position and quaternion stored in ``pose7``.

    ``pose7[:3]`` is the position and ``pose7[3:7]`` the quaternion. Component
    order is passed through unchanged and not checked here.

    A combined ``set_pose`` is used when the entity has one and it succeeds.
    Otherwise ``set_pos`` is called and ``set_quat`` is attempted. A failing
    ``set_quat`` emits a ``RuntimeWarning`` instead of being silently ignored,
    because the entity would otherwise keep a stale orientation in the replay.

    Args:
        entity: Object providing ``set_pose(pos=, quat=)``, or ``set_pos`` and
            optionally ``set_quat``.
        pose7: Sequence/array of at least 7 numbers: position then quaternion.
    """
    import warnings

    pos = tuple(np.asarray(pose7[:3], dtype=np.float32).tolist())
    quat = tuple(np.asarray(pose7[3:7], dtype=np.float32).tolist())

    set_pose = getattr(entity, "set_pose", None)
    if set_pose is not None:
        try:
            set_pose(pos=pos, quat=quat)
            return
        except Exception:
            # Best effort: fall back to the separate position/orientation setters.
            pass

    entity.set_pos(pos)
    try:
        entity.set_quat(quat)
    except Exception as err:
        warnings.warn(
            f"set_entity_pose: could not set orientation on {type(entity).__name__}: {err}",
            RuntimeWarning,
            stacklevel=2,
        )


def render_episode(episode_dir: Path, fps: int, res_w: int, res_h: int, backend: str):
    """Replay one recorded episode in Genesis and write its two camera videos.

    Each frame, the recorded arm joints, beaker poses and liquid particle
    positions from ``episode_dir / "data.h5"`` are written into the scene. The
    agent-view and wrist cameras are rendered into ``agentview.mp4`` and
    ``wrist.mp4`` inside ``episode_dir``.

    Args:
        episode_dir: Episode folder containing ``data.h5``. Videos are written here.
        fps: Frame rate of the output videos.
        res_w: Camera image width in pixels.
        res_h: Camera image height in pixels.
        backend: ``"gpu"`` or ``"cpu"``, forwarded to ``ensure_gs_init``.

    Raises:
        FileNotFoundError: If ``data.h5`` is missing.
        ValueError: If the episode contains no frames.
    """
    import contextlib

    ensure_gs_init(backend)

    h5_path = episode_dir / "data.h5"
    if not h5_path.exists():
        raise FileNotFoundError(h5_path)

    states, source_pose, target_pose, particle_pos, dt = load_episode(h5_path)
    if states.shape[0] == 0:
        # build_scene needs the frame-0 poses, so an empty episode cannot be replayed.
        raise ValueError(f"Episode has no frames: {h5_path}")

    scene, franka, source_beaker, target_beaker, liquid, cam_agent, cam_wrist = build_scene(
        source_pose[0], target_pose[0], dt, res_w, res_h
    )

    def apply_recorded_state(t, q9, particles_t):
        """Write frame ``t`` of the recording into the robot, beakers and liquid."""
        franka.set_qpos(q9)
        franka.control_dofs_position(q9)
        set_entity_pose(source_beaker, source_pose[t])
        set_entity_pose(target_beaker, target_pose[t])
        liquid.set_particles_pos(particles_t)

    # ExitStack runs its callbacks in reverse registration order and keeps
    # unwinding if one of them raises. Both writers are therefore closed and the
    # scene destroyed even if opening the second writer, rendering, or closing a
    # writer fails.
    with contextlib.ExitStack() as cleanup:
        cleanup.callback(scene.destroy)
        writer_agent = imageio.get_writer(str(episode_dir / "agentview.mp4"), fps=fps)
        cleanup.callback(writer_agent.close)
        writer_wrist = imageio.get_writer(str(episode_dir / "wrist.mp4"), fps=fps)
        cleanup.callback(writer_wrist.close)

        for t, q in enumerate(states):
            # Column 7 is the gripper opening in [0, 1]. Both finger joints get
            # the same position in metres.
            finger_m = float(np.clip(q[7], 0.0, 1.0)) * FINGER_OPEN_M
            q9 = np.array([*q[:7], finger_m, finger_m], dtype=np.float32)
            particles_t = np.asarray(particle_pos[t], dtype=np.float32)[None, ...]

            apply_recorded_state(t, q9, particles_t)
            scene.step()

            # Replay is data-driven. The beakers are freejoint MJCFs and render
            # with collision disabled, so a physics step can move them slightly
            # under gravity before the frame is drawn. Re-apply the recorded
            # state after stepping so the video matches the h5 exactly.
            apply_recorded_state(t, q9, particles_t)

            writer_agent.append_data(_extract_rgb(cam_agent.render()))
            writer_wrist.append_data(_extract_rgb(cam_wrist.render()))


def _worker_render(args_tuple):
    """Process-pool entry point: unpack ``args_tuple`` and render one episode.

    ``args_tuple`` is ``(episode_dir, fps, res_w, res_h, backend)``. Returns the
    episode directory as a string so the parent process can report it.
    """
    directory, frame_rate, width, height, gs_backend = args_tuple
    episode_path = Path(directory)
    render_episode(episode_path, fps=frame_rate, res_w=width, res_h=height, backend=gs_backend)
    return str(episode_path)


def find_episodes(dataset_root: Path):
    """Return, in sorted order, every ``episode_*`` directory under ``dataset_root``
    (searched recursively, including ``dataset_root`` itself) that contains a
    ``data.h5`` file."""
    candidates = sorted(dataset_root.glob("**/episode_*"))
    return [path for path in candidates if path.is_dir() and (path / "data.h5").exists()]


def main():
    """Command-line entry point: render one episode or every episode under ``--dataset_root``.

    Uses a process pool when ``--workers`` > 1 and there is more than one
    episode. In that mode, per-episode failures are reported by episode
    directory, and the process exits with a non-zero status if any episode
    failed. In sequential mode, the first failure propagates as an exception.

    Raises:
        FileNotFoundError: If no episode folders with ``data.h5`` are found.
        SystemExit: With a non-zero status when any parallel render failed.
    """
    parser = argparse.ArgumentParser(description="Stage2 data-driven replay renderer (supports folder + parallel)")
    parser.add_argument("--dataset_root", type=str, required=True, help="Folder containing episode_xxx subfolders")
    parser.add_argument("--episode", type=str, default="", help="Render one episode only, e.g. episode_000")
    parser.add_argument("--fps", type=int, default=20)
    parser.add_argument("--res_w", type=int, default=256)
    parser.add_argument("--res_h", type=int, default=256)
    parser.add_argument("--backend", type=str, choices=["gpu", "cpu"], default="gpu")
    parser.add_argument("--workers", type=int, default=1, help="Number of parallel processes")
    args = parser.parse_args()

    dataset_root = Path(args.dataset_root)
    if args.episode:
        episodes = [dataset_root / args.episode]
    else:
        episodes = find_episodes(dataset_root)

    if not episodes:
        raise FileNotFoundError(f"No episode folders with data.h5 under: {dataset_root}")

    # More worker processes than episodes would only sit idle.
    workers = min(max(1, int(args.workers)), len(episodes))

    if workers == 1:
        for ep in episodes:
            render_episode(ep, fps=args.fps, res_w=args.res_w, res_h=args.res_h, backend=args.backend)
            print(f"Render done: {ep}")
        return

    ok = 0
    failed = []
    with ProcessPoolExecutor(max_workers=workers) as ex:
        # Map each future back to its episode so failures can name the episode.
        future_to_episode = {
            ex.submit(_worker_render, (str(ep), args.fps, args.res_w, args.res_h, args.backend)): ep
            for ep in episodes
        }
        for fut in as_completed(future_to_episode):
            ep = future_to_episode[fut]
            try:
                out = fut.result()
            except Exception as e:
                failed.append(ep)
                print(f"Render failed: {ep}: {e}")
            else:
                ok += 1
                print(f"Render done: {out}")

    print(f"All done. success={ok}/{len(episodes)}")
    if failed:
        # Surface failures through the exit status so scripted callers notice them.
        raise SystemExit(f"{len(failed)} episode(s) failed to render")


if __name__ == "__main__":
    # Run the CLI only when executed as a script. Importing this module (for
    # example from a spawned ProcessPoolExecutor worker) must not re-run it.
    main()
