import argparse
import json
from dataclasses import dataclass
from pathlib import Path

import genesis as gs
import h5py
import numpy as np
import torch
from scipy.spatial.transform import Rotation as R, Slerp

# ---------------------------------------------------------------------------
# Scene object wrappers (entities built once, poses are [n_envs, ...] tensors)
# ---------------------------------------------------------------------------

def _to_np(x):
    return x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else np.asarray(x)


class Table:
    """Wrapper that adds the simple-table MJCF model to a scene.

    The created entity is exposed as ``self.entity``.
    """

    def __init__(self, scene, pos=(0.0, 0.0, 0.0), euler=(0, 0, 90)):
        morph = gs.morphs.MJCF(
            file='asset/model/misc/simple_table.xml',
            pos=pos,
            decimate=False,
            euler=euler,
        )
        self.entity = scene.add_entity(morph)


class Cell_Dish:
    """Wrapper that adds the cell-dish MJCF model to a scene.

    The created entity is exposed as ``self.entity``.
    """

    def __init__(self, scene, pos=(0.1, 0.0, 0.854), euler=(0, 0, 90)):
        morph = gs.morphs.MJCF(
            file='asset/model/object/cell_dish_100.gen.xml',
            pos=pos,
            decimate=False,
            euler=euler,
        )
        self.entity = scene.add_entity(morph)

    def get_pose(self):
        """Return ``(position, quaternion)`` of the dish as NumPy arrays."""
        position = _to_np(self.entity.get_pos())
        orientation = _to_np(self.entity.get_quat())
        return position, orientation


class Pipette_Rack:
    """Wrapper that adds the pipette-rack MJCF model to a scene.

    The created entity is exposed as ``self.entity``.
    """

    def __init__(self, scene, pos=(-0.27, 0.0, 0.824), euler=(0, 0, 90)):
        morph = gs.morphs.MJCF(
            file='asset/model/object/pipette_rack.gen.xml',
            pos=pos,
            decimate=False,
            euler=euler,
        )
        self.entity = scene.add_entity(morph)


class Pipette:
    """Pipette entity with a separately jointed, weldable tip.

    All pose getters return NumPy arrays. Quaternions are handled in wxyz
    order (``get_grasp_pose`` reorders them to scipy's xyzw and back).
    """

    def __init__(self, scene, pos=(-0.22, -0.0, 1.08), euler=(0, 10, 180)):
        self.entity = scene.add_entity(
            gs.morphs.MJCF(file='asset/model/object/pipette_free_tip.gen.xml',
                           pos=pos, decimate=False, euler=euler,
                           batch_fixed_verts=True)
        )
        self.entity.set_friction(2.0)

    def _link_pose(self, link_name):
        """Return ``(pos, quat)`` of the named link as NumPy arrays."""
        link = self.entity.get_link(link_name)
        # Use the shared converter (same as get_pose) so tensors are detached
        # and moved to the CPU in one consistent way.
        return _to_np(link.get_pos()), _to_np(link.get_quat())

    def get_pose(self):
        """Return ``(pos, quat)`` of the pipette entity as NumPy arrays."""
        return _to_np(self.entity.get_pos()), _to_np(self.entity.get_quat())

    def get_button_pos(self):
        """Return ``(pos, quat)`` of the ``pipette_button`` link."""
        return self._link_pose('pipette_button')

    def get_ejector_pos(self):
        """Return ``(pos, quat)`` of the ``pipette_ejector`` link."""
        return self._link_pose('pipette_ejector')

    def get_tip_pos(self):
        """Return ``(pos, quat)`` of the ``tip/pipette_tip`` link."""
        return self._link_pose('tip/pipette_tip')

    def release_tip(self, scene, envs_idx=None):
        """Delete the pipette<->tip weld constraint in the given envs.

        The deletion is issued for both link orderings.
        """
        rigid = scene.sim.rigid_solver
        pipette_link = self.entity.get_link('pipette')
        tip_link = self.entity.get_link('tip/pipette_tip')
        rigid.delete_weld_constraint(pipette_link.idx, tip_link.idx, envs_idx=envs_idx)
        rigid.delete_weld_constraint(tip_link.idx, pipette_link.idx, envs_idx=envs_idx)

    def attach_tip(self, scene, envs_idx=None):
        """Add a weld constraint between the pipette body and its tip."""
        rigid = scene.sim.rigid_solver
        pipette_link = self.entity.get_link('pipette')
        tip_link = self.entity.get_link('tip/pipette_tip')
        rigid.add_weld_constraint(pipette_link.idx, tip_link.idx, envs_idx=envs_idx)

    def set_root_pos_keep_tip(self, pos, envs_idx=None):
        """Move the pipette root to ``pos`` and shift the tip by the same delta.

        Args:
            pos: target root position, either ``[3]`` (applied to every row of
                the current qpos) or one row per qpos row.
            envs_idx: forwarded to ``entity.set_qpos``.

        Velocities are zeroed by the ``set_qpos`` call.
        """
        qpos = _to_np(self.entity.get_qpos()).astype(np.float32)
        if qpos.ndim == 1:
            qpos = qpos[None, :]

        target_pos = np.asarray(pos, dtype=np.float32)
        if target_pos.ndim == 1:
            target_pos = np.tile(target_pos[None, :], (qpos.shape[0], 1))

        root_joint = self.entity.get_joint('pipette_root_joint')
        tip_joint = self.entity.get_joint('pipette_tip_root_joint')
        root_qidx = root_joint.qs_idx_local
        tip_qidx = tip_joint.qs_idx_local

        # The first three q entries of each joint are treated as translation;
        # the tip is translated by the same amount so it stays in place
        # relative to the pipette body.
        delta = target_pos - qpos[:, root_qidx[:3]]
        qpos[:, root_qidx[:3]] = target_pos
        qpos[:, tip_qidx[:3]] += delta

        self.entity.set_qpos(
            torch.as_tensor(qpos, device=gs.device, dtype=gs.tc_float),
            envs_idx=envs_idx,
            zero_velocity=True,
        )

    def get_grasp_pose(self):
        """Compute a gripper target pose for grasping the pipette.

        Returns:
            ``(cap_pos, target_quat)`` as float32 arrays of shape ``[n, 3]`` and
            ``[n, 4]`` (wxyz), where ``n`` is the leading dimension of the pose
            returned by ``get_pose``.
        """
        pos, quat = self.get_pose()
        n = pos.shape[0]
        # wxyz -> xyzw for scipy.
        r_pipette = R.from_quat(np.concatenate([quat[:, 1:4], quat[:, :1]], axis=1))
        mats = r_pipette.as_matrix()
        v_long = mats[:, :, 2]  # pipette local z axis expressed in world frame
        cap_pos = pos + v_long * 0.095

        # Gripper frame: z points down, y is perpendicular to both z and the
        # pipette's long axis, x completes the right-handed frame.
        z_grip = np.tile(np.array([0.0, 0.0, -1.0]), (n, 1))
        y_grip = np.cross(z_grip, v_long)
        cap_pos = cap_pos + np.tile([0.140, 0.0, 0.03], (n, 1))

        y_norm = np.linalg.norm(y_grip, axis=1, keepdims=True)
        y_norm = np.where(y_norm < 1e-6, 1.0, y_norm)
        y_grip = y_grip / y_norm
        x_grip = np.cross(y_grip, z_grip)
        # NOTE(review): unlike y_grip, this normalization is not guarded; a
        # pipette whose long axis is parallel to world z yields a zero vector
        # here and therefore NaNs.
        x_grip /= np.linalg.norm(x_grip, axis=1, keepdims=True)

        rot_mats = np.stack([x_grip, y_grip, z_grip], axis=-1)
        r_grip = R.from_matrix(rot_mats)

        # Post-rotate about the gripper's y axis by (pipette euler-y + 90 deg).
        eul = r_pipette.as_euler('xyz', degrees=True)            # [n,3]
        delta_y = eul[:, 1] + 90.0
        r_adjust = R.from_euler('y', delta_y, degrees=True)
        r_grip = r_grip * r_adjust

        # xyzw -> wxyz.
        qxyzw = r_grip.as_quat()
        target_quat = np.concatenate([qxyzw[:, 3:4], qxyzw[:, 0:3]], axis=1)
        return cap_pos.astype(np.float32), target_quat.astype(np.float32)

# ---------------------------------------------------------------------------
# Quaternion helpers (wxyz)
# ---------------------------------------------------------------------------

def wxyz_to_R(q):
    """Convert wxyz quaternion(s) to a scipy ``Rotation``.

    Args:
        q: array-like of shape ``[n, 4]`` (batched) or ``[4]`` (single), with
            the scalar component first.

    Returns:
        A batched ``Rotation`` for ``[n, 4]`` input, a single ``Rotation`` for
        ``[4]`` input (mirroring what ``R_to_wxyz`` accepts).
    """
    q = np.asarray(q)
    # scipy expects scalar-last (xyzw); index the last axis so both 1-D and
    # 2-D inputs work.
    return R.from_quat(np.concatenate([q[..., 1:4], q[..., :1]], axis=-1))

def R_to_wxyz(r):
    """Return the quaternion(s) of scipy ``Rotation`` ``r`` in wxyz order.

    Works for a single rotation (result shape ``[4]``) and a batch (result
    shape ``[n, 4]``). The result is float32.
    """
    xyzw = r.as_quat()
    # Rotating the last axis by one slot moves w from the back to the front.
    return np.roll(xyzw, 1, axis=-1).astype(np.float32)

def pose_to_cartesian6(pos, quat, envs_offset=None):
    """Pack a pose into ``[x, y, z, rx, ry, rz]`` rows (float32).

    Args:
        pos: ``[n, 3]`` or ``[3]`` position(s).
        quat: ``[n, 4]`` or ``[4]`` wxyz quaternion(s).
        envs_offset: accepted for call-site compatibility; it is not used by
            this function, so positions are returned as given.

    Returns:
        ``[n, 6]`` array; the last three columns are Euler "xyz" angles in
        radians.
    """
    xyz = np.asarray(pos, dtype=np.float32)
    wxyz = np.asarray(quat, dtype=np.float32)
    # Promote single poses to a batch of one.
    xyz = xyz[None, :] if xyz.ndim == 1 else xyz
    wxyz = wxyz[None, :] if wxyz.ndim == 1 else wxyz
    angles = wxyz_to_R(wxyz).as_euler("xyz", degrees=False)
    return np.concatenate([xyz, angles.astype(np.float32)], axis=1).astype(np.float32)

def next_frame(values):
    """Shift a sequence one frame forward along axis 0, repeating the last frame.

    Element ``i`` of the result is ``values[i + 1]``; the final element is kept
    as-is. Sequences with zero or one frame are returned as a float32 copy.
    """
    frames = np.asarray(values, dtype=np.float32)
    if frames.shape[0] <= 1:
        return frames.copy()
    shifted = np.empty(frames.shape, dtype=np.float32)
    shifted[:-1] = frames[1:]
    shifted[-1] = frames[-1]
    return shifted

def finite_difference(values, dt):
    """Forward difference along axis 0, divided by ``dt``.

    The result has the same shape as ``values`` (float32); the last row is
    zero because it has no successor.
    """
    samples = np.asarray(values, dtype=np.float32)
    rates = np.zeros_like(samples, dtype=np.float32)
    if samples.shape[0] > 1:
        rates[:-1] = np.diff(samples, axis=0) / float(dt)
    return rates

def binary_gripper(gripper):
    """Threshold gripper values at 0.5: 1.0 where strictly greater, else 0.0 (float32)."""
    values = np.asarray(gripper, dtype=np.float32)
    is_open = values > 0.5
    return is_open.astype(np.float32)

def max_true_run(mask):
    """Return the length of the longest run of consecutive truthy entries.

    ``mask`` is flattened first; an empty or all-false mask gives 0.
    """
    flags = np.asarray(mask, dtype=bool).reshape(-1)
    if not flags.any():
        return 0
    # Pad with False on both sides so every run has a rising and a falling edge.
    padded = np.concatenate(([False], flags, [False])).astype(np.int8)
    edges = np.diff(padded)
    run_starts = np.flatnonzero(edges == 1)
    run_ends = np.flatnonzero(edges == -1)
    return int((run_ends - run_starts).max())

def broadcast_quat(q, n):
    """Return ``q`` as float32; a single 1-D quaternion is repeated into ``n`` rows.

    Inputs that are not 1-D are returned unchanged (after float32 conversion).
    """
    quat = np.asarray(q, dtype=np.float32)
    if quat.ndim != 1:
        return quat
    return np.repeat(quat[None, :], n, axis=0)

def broadcast_pos(p, n):
    """Return ``p`` as float32; a single 1-D position is repeated into ``n`` rows.

    Inputs that are not 1-D are returned unchanged (after float32 conversion).
    """
    point = np.asarray(p, dtype=np.float32)
    if point.ndim != 1:
        return point
    return np.repeat(point[None, :], n, axis=0)

def slerp_batch(q0_wxyz, q1_wxyz, alpha):
    """Per-env spherical linear interpolation between two quaternion batches.

    Args:
        q0_wxyz: ``[n, 4]`` start quaternions (wxyz), array-like.
        q1_wxyz: ``[n, 4]`` end quaternions (wxyz), array-like.
        alpha: scalar interpolation factor in ``[0, 1]``.

    Returns:
        ``[n, 4]`` float32 wxyz quaternions.

    Raises:
        ValueError: if ``alpha`` lies outside ``[0, 1]`` (same contract as
            evaluating ``scipy.spatial.transform.Slerp`` outside its time range).
    """
    alpha = float(alpha)
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be within [0, 1], got {alpha}")
    r0 = wxyz_to_R(np.asarray(q0_wxyz))
    r1 = wxyz_to_R(np.asarray(q1_wxyz))
    # Same formula Slerp applies per segment, evaluated for the whole batch at
    # once: rotate r0 by a fraction of the relative rotation r0^-1 * r1.
    relative_rotvec = (r0.inv() * r1).as_rotvec()
    return R_to_wxyz(r0 * R.from_rotvec(relative_rotvec * alpha))

# ---------------------------------------------------------------------------
# Weld attachment checker (batched-aware, mirrors pipette3.WeldAttachmentChecker)
# ---------------------------------------------------------------------------

class WeldAttachmentChecker:
    """
    Batched weld rule across n_envs.

    Per-env state:
      attach when:  left finger contact AND right finger contact AND no support contact
      detach when:  bilateral finger contact gone OR support contact present

    acquire_steps / release_steps smooth one-frame contact noise: a condition
    must hold for that many consecutive ``update()`` calls before the weld is
    added / removed.

    The weld itself is global (see the note above ``_weld_already_present``):
    it is added for all envs on the first acquisition and removed for all envs
    once any env releases.
    """
    def __init__(
        self,
        scene,
        robot_entity,
        ee_link,
        left_finger_link,
        right_finger_link,
        object_entity,
        support_entities=None,
        acquire_steps=2,
        release_steps=2,
        n_envs=1,
    ):
        self.scene = scene
        self.rigid = scene.sim.rigid_solver

        self.robot = robot_entity
        self.ee_link = ee_link
        self.left_finger = left_finger_link
        self.right_finger = right_finger_link
        self.obj = object_entity

        self.support_entities = [] if support_entities is None else list(support_entities)

        self.acquire_steps = acquire_steps
        self.release_steps = release_steps

        self.n_envs = n_envs
        self.active = np.zeros(n_envs, dtype=bool)
        self._acquire_counter = np.zeros(n_envs, dtype=np.int32)
        self._release_counter = np.zeros(n_envs, dtype=np.int32)

        self.link_obj = np.array([self.obj.base_link.idx], dtype=gs.np_int)
        self.link_ee = np.array([self.ee_link.idx], dtype=gs.np_int)

    # ---- contact queries (per-env) ------------------------------------------
    def _contact_per_env(self, entity):
        """Return [n_envs] bool array indicating contact between obj and entity."""
        info = self.obj.get_contacts(with_entity=entity)
        n = self.n_envs
        # Preferred source: valid_mask (batched, shape [n_envs, max_contacts]).
        if "valid_mask" in info:
            vm = info["valid_mask"]
            arr = vm.detach().cpu().numpy() if isinstance(vm, torch.Tensor) else np.asarray(vm)
            if arr.ndim == 0:
                return np.array([bool(arr)] * n, dtype=bool)
            if arr.ndim == 1:
                if n == 1:
                    # Single env: the 1-D mask lists that env's contacts.
                    return np.array([bool(arr.any())], dtype=bool)
                # Otherwise treat entry i as env i; envs beyond the mask
                # length count as "no contact".
                return np.array(
                    [bool(arr[i]) if i < arr.shape[0] else False for i in range(n)],
                    dtype=bool,
                )
            # Reduce every non-env axis.
            return arr.any(axis=tuple(range(1, arr.ndim))).astype(bool)
        # Fallback: presence of any geom entries. This cannot tell envs apart,
        # so the same answer is reported for every env.
        if "geom_a" in info:
            ga = info["geom_a"]
            arr = ga.detach().cpu().numpy() if isinstance(ga, torch.Tensor) else np.asarray(ga)
            if arr.ndim <= 1:
                return np.array([int(arr.size) > 0] * n, dtype=bool)
            return (arr.shape[1] > 0) * np.ones(n, dtype=bool)
        return np.zeros(n, dtype=bool)

    def left_contact(self):
        """Per-env contact flag between the object and the left finger."""
        return self._contact_per_env(self.left_finger)

    def right_contact(self):
        """Per-env contact flag between the object and the right finger."""
        return self._contact_per_env(self.right_finger)

    def has_both_finger_contacts(self):
        """Per-env flag: both fingers touch the object."""
        return self.left_contact() & self.right_contact()

    def has_support_contact(self):
        """Per-env flag: the object touches any of the support entities."""
        out = np.zeros(self.n_envs, dtype=bool)
        for ent in self.support_entities:
            out = out | self._contact_per_env(ent)
        return out

    def grasp_success_now(self):
        """Per-env flag: bilateral finger contact and no support contact."""
        return self.has_both_finger_contacts() & (~self.has_support_contact())

    def grasp_failure_now(self):
        """Per-env flag: logical complement of ``grasp_success_now``."""
        return (~self.has_both_finger_contacts()) | self.has_support_contact()

    def finger_opening(self):
        """Mean of the robot's last two DOF positions, per env."""
        q = _to_np(self.robot.get_dofs_position())
        return q[:, -2:].mean(axis=1)

    # ---- weld add/remove ----------------------------------------------------
    # NOTE: Genesis' add_weld_constraint asserts that the (link1, link2) pair
    # is not already present in ANY env (global check, not per-env). To work
    # around this we attach the weld to ALL envs once on first acquisition,
    # and detach from ALL envs once any env loses grasp persistently.
    def _weld_already_present(self):
        """Check if a weld between link_obj and link_ee already exists anywhere.

        Best-effort: any failure while querying the solver is treated as
        "no weld present".
        """
        try:
            info = self.rigid.get_weld_constraints(as_tensor=True, to_torch=True)
            link_a = info["link_a"]
            link_b = info["link_b"]
            la = int(self.link_obj[0])
            lb = int(self.link_ee[0])
            mask = (((link_a == la) | (link_b == la)) &
                    ((link_a == lb) | (link_b == lb)))
            return bool(mask.any().item()) if hasattr(mask, "any") else bool(np.asarray(mask).any())
        except Exception:
            return False

    def _mark_all(self, attached):
        """Set every env's attached flag and clear both debounce counters."""
        self.active[:] = attached
        self._acquire_counter[:] = 0
        self._release_counter[:] = 0

    def _attach_envs(self, env_idx):
        """Add the (global) weld if needed and mark envs as attached."""
        if env_idx.size == 0:
            return
        if self.active.any() or self._weld_already_present():
            # Already welded globally; just mark these envs as active.
            self.active[env_idx] = True
            self._acquire_counter[env_idx] = 0
            self._release_counter[env_idx] = 0
            return
        try:
            self.rigid.add_weld_constraint(self.link_obj, self.link_ee)
        except Exception as e:
            # If the weld already exists in the solver (e.g., due to a stale
            # constraint that survived a prior detach), just mark as attached.
            print(f"[WELD] add_weld_constraint skipped ({type(e).__name__}: {e})")
        self._mark_all(True)

    def _detach_envs(self, env_idx):
        """Remove the (global) weld if present and mark all envs as detached."""
        if env_idx.size == 0:
            return
        if not (self.active.any() or self._weld_already_present()):
            self.active[:] = False
            return
        try:
            self.rigid.delete_weld_constraint(self.link_obj, self.link_ee)
        except Exception as e:
            print(f"[WELD] delete_weld_constraint skipped ({type(e).__name__}: {e})")
        self._mark_all(False)

    def update(self):
        """Advance the debounce counters by one step and add/remove the weld.

        Returns:
            dict with
              "attached":   copy of the per-env attached flags after this step,
              "attach_now": int32 indices of envs that met the acquire threshold,
              "detach_now": int32 indices of envs that met the release threshold.
        """
        success = self.grasp_success_now()
        # Failure is by definition the complement of success; deriving it here
        # avoids a second round of contact queries.
        failure = ~success

        # Inactive envs count consecutive success frames. Each frame
        # contributes exactly one increment so acquire_steps really means
        # "that many consecutive frames".
        inactive = ~self.active
        self._acquire_counter[inactive & success] += 1
        self._acquire_counter[inactive & ~success] = 0

        # Active envs count consecutive failure frames.
        self._release_counter[self.active & failure] += 1
        self._release_counter[self.active & ~failure] = 0

        attach_envs = np.where(inactive & (self._acquire_counter >= self.acquire_steps))[0].astype(np.int32)
        self._attach_envs(attach_envs)

        detach_envs = np.where(self.active & (self._release_counter >= self.release_steps))[0].astype(np.int32)
        self._detach_envs(detach_envs)

        return {
            "attached": self.active.copy(),
            "attach_now": attach_envs,
            "detach_now": detach_envs,
        }


# ---------------------------------------------------------------------------
# Parallel task
# ---------------------------------------------------------------------------

@dataclass
class TaskConfig:
    """Settings for the parallel pipette task."""

    # Simulation
    n_envs: int = 4                      # number of parallel environments
    dt: float = 0.05                     # passed to SimOptions(dt=...)
    substeps: int = 100                  # passed to SimOptions(substeps=...)
    env_spacing: tuple = (2.0, 2.0)      # passed to scene.build(env_spacing=...)

    # Output
    base_dir: Path = Path("results/discard_pipette_tip/dataset")
    episode_prefix: str = "episode_"

    # Visualisation
    show_viewer: bool = False
    record_video: bool = False

    # Randomization ranges: (low, high) for the pipette's initial y position.
    pipette_y_range: tuple = (-0.1, 0.1)


class ParallelPipetteTask:
    def __init__(self, cfg: TaskConfig):
        """Initialise Genesis on the GPU backend and set up empty task state.

        Scene handles start as ``None`` and are filled in by ``build_scene``.
        """
        self.cfg = cfg
        gs.init(backend=gs.gpu)

        # Scene, camera and static props.
        self.scene = None
        self.cam = None
        self.table = None
        self.cell_dish = None
        self.pipette = None
        self.pipette_rack = None

        # Robots and frequently used links.
        self.franka_left = None
        self.franka_right = None
        self.left_ee = None
        self.right_ee = None
        self.pipette_tip_link = None
        self.attachment_checker = None

        # DOF index groups: 7 arm joints followed by 2 finger joints.
        self.motors_dof = np.arange(7)
        self.fingers_dof = np.arange(7, 9)
        self.active_envs = np.arange(cfg.n_envs, dtype=np.int32)

        # Per-env recording buffers: one list of frames per env for each key.
        # "pipette_pose" rows are [pos3, quat4].
        n = cfg.n_envs
        buffer_keys = (
            "left_state",
            "left_action",
            "left_cartesian",
            "right_state",
            "right_action",
            "right_cartesian",
            "pipette_pose",
            "pipette_tip_pos",
            "pipette_qpos",
        )
        self.buf = {key: [[] for _ in range(n)] for key in buffer_keys}
        self.sim_t = 0.0

    def print_tip_welds(self, tag):
        """Debug helper: print the pipette/tip link indices and all weld pairs."""
        welds = self.scene.sim.rigid_solver.get_weld_constraints(as_tensor=True, to_torch=True)
        entity = self.pipette.entity
        pipette_idx = entity.get_link('pipette').idx
        tip_idx = entity.get_link('tip/pipette_tip').idx
        print(tag, "pipette", pipette_idx, "tip", tip_idx)
        for key in ("link_a", "link_b"):
            print(f"{key}:", welds[key])

    def release_tip_preserving_grasp(self):
        rigid = self.scene.sim.rigid_solver
        pipette_link = self.pipette.entity.get_link('pipette')
        tip_link = self.pipette.entity.get_link('tip/pipette_tip')
        ee_link = self.left_ee

        checker = self.attachment_checker
        self.attachment_checker = None

        # Work around Genesis' old dynamic-weld deletion bug by making the
        # pipette-tip weld the last dynamic weld before deleting it.
        rigid.delete_weld_constraint(pipette_link.idx, ee_link.idx)
        rigid.delete_weld_constraint(ee_link.idx, pipette_link.idx)
        rigid.delete_weld_constraint(pipette_link.idx, tip_link.idx)
        rigid.delete_weld_constraint(tip_link.idx, pipette_link.idx)
        rigid.add_weld_constraint(pipette_link.idx, ee_link.idx)

        self.attachment_checker = checker
        if self.attachment_checker is not None:
            self.attachment_checker.active[:] = True
            self.attachment_checker._acquire_counter[:] = 0
            self.attachment_checker._release_counter[:] = 0
        
    # ----- scene construction ------------------------------------------------
    def build_scene(self):
        """Create, build and configure the batched simulation scene.

        Steps, in order:
          1. sample one random pipette y position per env,
          2. create the scene, optional recording cameras and all entities,
          3. build ``n_envs`` environments and apply the per-env pipette
             placement, then weld the tip to the pipette,
          4. configure gains, force limits and the home pose of both arms,
          5. cache link handles and the per-env world offsets,
          6. create the ``WeldAttachmentChecker`` for the left arm,
          7. assemble ``self.scene_init`` (scene description metadata),
          8. start camera recording when cameras were created.
        """
        n = self.cfg.n_envs

        # Per-env randomization via torch.rand (each run is genuinely random;
        # torch's default RNG is seeded from OS entropy unless the user sets it).
        def _u(low, high, size):
            # Uniform samples in [low, high) as a float32 NumPy array.
            return (torch.rand(size) * (high - low) + low).cpu().numpy().astype(np.float32)

        # Pipette: random y drawn from cfg.pipette_y_range (default [-0.1, 0.1])
        self.pipette_y_rand = _u(self.cfg.pipette_y_range[0],
                                 self.cfg.pipette_y_range[1], (n,))

        print(f"[rand] n_envs={n}")
        print(f"  pipette_y:      {self.pipette_y_rand.tolist()}")

        self.scene = gs.Scene(
            sim_options=gs.options.SimOptions(dt=self.cfg.dt, substeps=self.cfg.substeps),
            viewer_options=gs.options.ViewerOptions(
                camera_pos=(2.1, 0.0, 1.5),
                camera_lookat=(0.0, 0.0, 0.8),
                camera_fov=30,
                max_FPS=60,
            ),
            show_viewer=self.cfg.show_viewer,
        )

        # Recording cameras exist only when video recording is requested;
        # self.cam stays None otherwise and self.cam_vice is not created.
        if self.cfg.record_video:
            self.cam = self.scene.add_camera(
                res=(640, 480), pos=(2.5, 0.0, 2.0),
                lookat=(0.0, 0.0, 1.0), fov=30, GUI=False,
            )
            self.cam_vice = self.scene.add_camera(
                res=(640, 480), pos=(0.1, -0.2, 1.1),
                lookat=(0.1, 0.0, 1.1), fov=30, GUI=False,
            )

        self.scene.add_entity(gs.morphs.Plane())
        self.table = Table(self.scene)
        self.cell_dish = Cell_Dish(self.scene)

        # Nominal (pre-randomization) positions used to compute per-env offsets.
        self._cell_dish_nominal = np.array([0.1, 0.0, 0.854], dtype=np.float32)
        self._pipette_nominal = np.array([-0.22, 0.0, 1.08], dtype=np.float32)

        self.pipette = Pipette(self.scene,
                               pos=tuple(self._pipette_nominal.tolist()),
                               euler=(0, 10, 180))
        self.pipette_rack = Pipette_Rack(self.scene)

        # Two Panda arms facing each other across the table (y = -0.5 / +0.5).
        self.franka_left = self.scene.add_entity(
            gs.morphs.MJCF(file='xml/franka_emika_panda/panda.xml',
                           pos=(0.0, -0.5, 0.824), euler=(0, 0, 90))
        )
        self.franka_right = self.scene.add_entity(
            gs.morphs.MJCF(file='xml/franka_emika_panda/panda.xml',
                           pos=(0.0, 0.5, 0.824), euler=(0, 0, -90))
        )

        self.scene.build(n_envs=n, env_spacing=self.cfg.env_spacing)

        # ---- Apply per-env randomization ----
        envs_all = torch.arange(n, device=gs.device, dtype=torch.long)

        # Pipette: only y is randomized (x/z/yaw fixed at nominal)
        pipette_pos = np.tile(self._pipette_nominal, (n, 1))
        pipette_pos[:, 1] = self.pipette_y_rand
        self.pipette.set_root_pos_keep_tip(pipette_pos, envs_idx=envs_all)
        # Weld the tip to the pipette body (envs_idx left at its default).
        self.pipette.attach_tip(self.scene)

        # Arm controller setup: 7 arm joints followed by 2 finger joints.
        for arm in (self.franka_left, self.franka_right):
            arm.set_dofs_kp(np.array([4500, 4500, 3500, 3500, 2000, 2000, 2000, 100, 100]))
            arm.set_dofs_kv(np.array([450, 450, 350, 350, 200, 200, 200, 10, 10]))
            arm.set_dofs_force_range(
                np.array([-87, -87, -87, -87, -12, -12, -12, -100, -100]),
                np.array([ 87,  87,  87,  87,  12,  12,  12,  100,  100]),
            )
            # Same values as "robot_home_qpos_9" in scene_init below.
            home_qpos = np.array([0.0, -0.785, 0.0, -2.356, 0.0, 1.571, 0.785, 0.04, 0.04])
            arm.set_dofs_position(home_qpos)
            arm.control_dofs_position(home_qpos)

        self.left_ee = self.franka_left.get_link('hand')
        self.right_ee = self.franka_right.get_link('hand')
        self.pipette_tip_link = self.pipette.entity.get_link('tip/pipette_tip')

        # Cache per-env world offsets so we can convert recorded poses back to
        # each env's local frame (matches the single-env replay used by render_pipette.py).
        # Three sources are tried in turn: the scene attribute, a scene getter,
        # and finally a value derived from the left arm's base position.
        try:
            envs_offset = _to_np(self.scene.envs_offset)  # [n,3]
        except Exception:
            try:
                envs_offset = _to_np(self.scene.get_envs_offset())
            except Exception:
                # Fallback: derive from arm base position vs nominal (0, -0.5, 0.824).
                left_base = _to_np(self.franka_left.get_pos())
                if left_base.ndim == 1:
                    left_base = np.tile(left_base[None, :], (self.cfg.n_envs, 1))
                envs_offset = left_base - np.array([0.0, -0.5, 0.824], dtype=np.float32)
        self.envs_offset = envs_offset.astype(np.float32)

        # Both finger links belong to the LEFT arm: the checker watches the
        # left gripper grasping the pipette, with the rack as support surface.
        left_finger_link = self.franka_left.get_link('left_finger')
        right_finger_link = self.franka_left.get_link('right_finger')
        self.attachment_checker = WeldAttachmentChecker(
            scene=self.scene,
            robot_entity=self.franka_left,
            ee_link=self.left_ee,
            left_finger_link=left_finger_link,
            right_finger_link=right_finger_link,
            object_entity=self.pipette.entity,
            support_entities=[self.pipette_rack.entity],
            acquire_steps=2,
            release_steps=3,
            n_envs=self.cfg.n_envs,
        )

        # Cache scene init configuration (constants + per-env initial positions
        # of every entity and camera) so it can be saved verbatim into init.json.
        # NOTE(review): the camera entries below (256x256, wrist cameras) do not
        # match the recording cameras created above (640x480) - presumably they
        # describe the cameras used at replay time; confirm against the renderer.
        cam_agent_cfg = {
            "res": [256, 256], "pos": [2.5, 0.0, 2.0],
            "lookat": [0.0, 0.0, 1.0], "fov": 30,
        }
        cam_wrist_cfg = {
            "res": [256, 256], "pos": [0.0, 0.0, 0.0],
            "lookat": [1.0, 0.0, 0.0], "fov": 70,
            "attach_link": "hand",
            "offset_translation_xyz": [-0.08, 0.0, -0.035],
            "offset_rotation_euler_deg": {
                "yaw_z": 90.0,
                "pitch_y": 180.0,
                "roll_x": -10.0
            },
            "offset_rotation_composition": "Rz @ Ry @ Rx"
        }
        self.scene_init = {
            "n_envs": int(n),
            "env_spacing_xy_m": list(self.cfg.env_spacing),
            "envs_offset_xyz_m": self.envs_offset.tolist(),
            # Entities whose placement is identical in every env; the values
            # repeat the constructor defaults / arguments used above.
            "fixed_entities": {
                "plane": {"pos": [0.0, 0.0, 0.0]},
                "table": {"file": "asset/model/misc/simple_table.xml",
                          "pos": [0.0, 0.0, 0.0], "euler_deg": [0, 0, 90]},
                "cell_dish": {"file": "asset/model/object/cell_dish_100.gen.xml",
                              "pos": self._cell_dish_nominal.tolist(), "euler_deg": [0, 0, 90]},
                "pipette_rack": {"file": "asset/model/object/pipette_rack.gen.xml",
                                 "pos": [-0.27, 0.0, 0.824], "euler_deg": [0, 0, 90]},
                "franka_left":  {"file": "xml/franka_emika_panda/panda.xml",
                                 "pos": [0.0, -0.5, 0.824], "euler_deg": [0, 0, 90]},
                "franka_right": {"file": "xml/franka_emika_panda/panda.xml",
                                 "pos": [0.0, 0.5, 0.824], "euler_deg": [0, 0, -90]},
            },
            "robot_home_qpos_9": [0.0, -0.785, 0.0, -2.356, 0.0, 1.571, 0.785, 0.04, 0.04],
            "robot_finger_open_m": float(self._FINGER_OPEN_M),
            "cameras": {
                "agentview": cam_agent_cfg,
                # The per-camera "attach_link" overrides the generic "hand"
                # entry carried in cam_wrist_cfg.
                "wrist_camera": {
                    **cam_wrist_cfg,
                    "attach_link": "franka_left/hand",
                    "output": "wrist.mp4",
                },
                "wrist_camera_2": {
                    **cam_wrist_cfg,
                    "attach_link": "franka_right/hand",
                    "output": "wrist_2.mp4",
                },
            },
            # Per-env entries: only the pipette's y coordinate differs.
            "per_env": [
                {
                    "env_idx": int(i),
                    "envs_offset_xyz_m": self.envs_offset[i].tolist(),
                    "pipette": {
                        "file": "asset/model/object/pipette_free_tip.gen.xml",
                        "pos_local": [float(self._pipette_nominal[0]),
                                       float(self.pipette_y_rand[i]),
                                       float(self._pipette_nominal[2])],
                        "euler_deg": [0, 10, 180],
                    },
                }
                for i in range(n)
            ],
        }

        # Cameras exist only when cfg.record_video was set (see above).
        if self.cam is not None:
            self.cam.start_recording()
            self.cam_vice.start_recording()

    # ----- stepping and recording -------------------------------------------
    # Franka panda finger range per side: 0 (closed) .. 0.04 (open). The
    # mean of the two fingers is therefore in [0, 0.04]. We normalize to
    # [0, 1] (DROID-style) for state/action gripper component.
    _FINGER_OPEN_M = 0.04

    def _norm_grip(self, m):
        return float(np.clip(m / self._FINGER_OPEN_M, 0.0, 1.0))

    def _norm_grip_arr(self, m_arr):
        return np.clip(np.asarray(m_arr, dtype=np.float32) / self._FINGER_OPEN_M, 0.0, 1.0)

    def _record_step(self, left_action_np=None, right_action_np=None,
                     left_grip_cmd=None, right_grip_cmd=None):
        """Append one frame per env to every recording buffer.

        left_action_np / right_action_np: [n,7] commanded joint targets for next
            frame; when omitted, the arm's current joint positions are recorded.
        left_grip_cmd  / right_grip_cmd : scalar or [n] commanded gripper opening
            (meters), next frame; when omitted, the current normalized opening
            is recorded.

        State rows are [7 joints, normalized gripper]; action rows are
        [7 commanded joints, normalized commanded gripper].
        """
        n = self.cfg.n_envs

        def _grip_column(cmd, current):
            # Normalized commanded opening as an [n,1] column (falls back to
            # the measured opening when no command was given).
            if cmd is None:
                return current
            flat = np.asarray(cmd, dtype=np.float32).reshape(-1)
            if flat.size == 1:
                flat = np.full((n,), float(flat[0]), dtype=np.float32)
            return self._norm_grip_arr(flat).reshape(-1, 1)

        def _arm_frames(arm, ee_link, joint_cmd, grip_cmd):
            # Returns (state [n,8], action [n,8], cartesian [n,6]) for one arm.
            qpos = _to_np(arm.get_qpos())   # [n, 9]
            joints = qpos[:, :7]
            # Current gripper opening (meters, mean of two fingers) -> normalized [0,1]
            grip_state = self._norm_grip_arr(qpos[:, 7:9].mean(axis=1, keepdims=True))
            state = np.concatenate([joints, grip_state], axis=1)
            cart = pose_to_cartesian6(_to_np(ee_link.get_pos()),
                                      _to_np(ee_link.get_quat()),
                                      self.envs_offset)
            if joint_cmd is None:
                cmd_joints = joints
            else:
                cmd_joints = np.asarray(joint_cmd, dtype=np.float32)
            action = np.concatenate([cmd_joints, _grip_column(grip_cmd, grip_state)], axis=1)
            return state, action, cart

        l_state, l_action, l_cart = _arm_frames(
            self.franka_left, self.left_ee, left_action_np, left_grip_cmd)
        r_state, r_action, r_cart = _arm_frames(
            self.franka_right, self.right_ee, right_action_np, right_grip_cmd)

        pip_pos, pip_quat = self.pipette.get_pose()
        pip_qpos = _to_np(self.pipette.entity.get_qpos())
        tip_pos = _to_np(self.pipette_tip_link.get_pos())
        # Per the original note, Genesis returns per-env local coordinates
        # here, so they are stored directly for single-env replay; subtracting
        # envs_offset would shift non-zero envs away from the camera.
        pip_pose = np.concatenate([pip_pos, pip_quat], axis=1)

        frames = {
            "left_state": l_state,
            "right_state": r_state,
            "left_action": l_action,
            "right_action": r_action,
            "left_cartesian": l_cart,
            "right_cartesian": r_cart,
            "pipette_pose": pip_pose,
            "pipette_tip_pos": tip_pos,
            "pipette_qpos": pip_qpos,
        }
        for env in range(n):
            for key, rows in frames.items():
                self.buf[key][env].append(rows[env])

    def _step(self, left_action=None, right_action=None,
              left_grip_cmd=None, right_grip_cmd=None):
        self.scene.step()
        self.sim_t += self.cfg.dt
        if self.attachment_checker is not None:
            self.attachment_checker.update()
        if self.cam is not None:
            self.cam.render()
            self.cam_vice.render()
        self._record_step(left_action, right_action, left_grip_cmd, right_grip_cmd)

    # ----- parallel motion primitives ---------------------------------------
    def _ik(self, arm, ee_link, pos, quat):
        """Solve inverse kinematics for ``ee_link`` on ``arm``.

        pos [n,3], quat [n,4] wxyz -> qpos [n, dof]. Both inputs are converted
        to tensors on the Genesis device with the Genesis float dtype first.
        """
        tensor_kwargs = {"device": gs.device, "dtype": gs.tc_float}
        return arm.inverse_kinematics(
            link=ee_link,
            pos=torch.as_tensor(pos, **tensor_kwargs),
            quat=torch.as_tensor(quat, **tensor_kwargs),
        )

    def _apply_arm(self, arm, q_cmd, grip_force, grip_open_m=None):
        """Send a position command to ``arm`` and report the gripper opening used.

        Args:
            arm: entity exposing ``control_dofs_position``.
            q_cmd: joint command, either one row or a 2-D batch of rows; it is
                copied, never modified in place.
            grip_force: <0 means close, >0 means open. Only consulted when
                ``grip_open_m`` is not given.
            grip_open_m: optional, the *intended* gripper opening in meters
                this command corresponds to (used purely for action recording).

        Returns:
            The gripper opening that was applied, or ``None`` when neither
            ``grip_open_m`` nor ``grip_force`` yields one.
        """
        opening = grip_open_m
        if opening is None:
            opening = self._grip_open_for_force(grip_force)

        joints = np.array(q_cmd, dtype=np.float32, order="C")  # private copy
        if joints.ndim == 1:
            joints = joints.reshape(1, -1)

        with_fingers = joints.shape[1] >= 9 and opening is not None
        if not with_fingers:
            # No finger columns (or no opening): command the first 7 joints only.
            arm.control_dofs_position(joints[:, :7], self.motors_dof)
            return opening

        # Columns 7:9 are overwritten with the requested opening.
        joints[:, 7:9] = float(opening)
        arm.control_dofs_position(joints)
        return opening

    def _grip_open_for_force(self, grip_force):
        """Translate this script's grip-force convention into an opening in meters.

        A negative force (e.g. -10) means "closed" and maps to 0.0; any other
        force maps to ``self._FINGER_OPEN_M``. ``None`` is passed through as
        ``None`` (no gripper command).
        """
        if grip_force is None:
            return None
        if grip_force < 0:
            return 0.0
        return self._FINGER_OPEN_M

    def ik_move_to_pose(self, arm, ee_link, target_pos, target_quat,
                       n_interp=120, gripper_force=2.0, n_steps=80, record_arm='left'):
        """Move one arm to a Cartesian target by blending in joint space.

        IK is solved once for the (per-env broadcast) target. The command is
        then blended linearly from the arm's current qpos to that solution
        over ``n_interp`` simulation steps, followed by ``n_steps`` further
        steps in which no command is passed for recording.

        Args:
            arm: robot entity to command.
            ee_link: end-effector link the IK target refers to.
            target_pos: target position, broadcast to all envs.
            target_quat: target orientation, broadcast to all envs.
            n_interp: number of interpolated command steps.
            gripper_force: gripper convention value (<0 closed, otherwise open).
            n_steps: number of trailing steps after the interpolation.
            record_arm: 'left' or 'right' selects which arm's action slot the
                commands are recorded under; other values record neither.

        Returns:
            The IK solution for the target as a numpy array.
        """
        n_envs = self.cfg.n_envs
        goal_pos = broadcast_pos(target_pos, n_envs)
        goal_quat = broadcast_quat(target_quat, n_envs)

        q_goal = _to_np(self._ik(arm, ee_link, goal_pos, goal_quat))
        q_start = _to_np(arm.get_qpos())
        grip_open_m = self._grip_open_for_force(gripper_force)

        nothing = (None, None)
        for k in range(n_interp):
            frac = (k + 1) / n_interp
            q_cmd = (1.0 - frac) * q_start + frac * q_goal
            self._apply_arm(arm, q_cmd, gripper_force)
            # (joint command, gripper command) per arm; only the recorded arm
            # carries values.
            left_rec = (q_cmd[:, :7], grip_open_m) if record_arm == 'left' else nothing
            right_rec = (q_cmd[:, :7], grip_open_m) if record_arm == 'right' else nothing
            self._step(left_rec[0], right_rec[0], left_rec[1], right_rec[1])

        for _ in range(n_steps):
            self._step()
        return q_goal

    def grisp(self, arm, gripper_force=-10.0, n_steps=100, record_arm='left'):
        """Ramp the gripper fingers of ``arm`` to the opening implied by ``gripper_force``.

        Columns 7:9 of the arm's qpos are blended linearly from their current
        value to the target opening while every other column is held at its
        current value. One simulation step is taken per blend step.

        Args:
            arm: robot entity to command.
            gripper_force: gripper convention value (<0 closed, otherwise
                open); must not be ``None``.
            n_steps: number of blend steps; values below 1 are treated as 1.
            record_arm: 'left' or 'right' selects which arm's action slot the
                commands are recorded under; other values record neither.
        """
        opening = self._grip_open_for_force(gripper_force)
        q_from = _to_np(arm.get_qpos()).astype(np.float32)
        q_to = q_from.copy()
        q_to[:, 7:9] = opening

        total = max(1, int(n_steps))
        nothing = (None, None)
        for k in range(1, total + 1):
            weight = k / total
            q_cmd = (1.0 - weight) * q_from + weight * q_to
            arm.control_dofs_position(q_cmd)
            left_rec = (q_cmd[:, :7], opening) if record_arm == 'left' else nothing
            right_rec = (q_cmd[:, :7], opening) if record_arm == 'right' else nothing
            self._step(left_rec[0], right_rec[0], left_rec[1], right_rec[1])

    def _obj_pose_batched(self, obj):
        """Return ``(pos[n,3], quat[n,4] wxyz, R batched)`` for ``obj``.

        A pose reported as a single 1-D row is repeated once per environment
        so callers can always index the result by env.
        """
        def per_env(values):
            arr = _to_np(values)
            if arr.ndim == 1:
                arr = np.tile(arr[None, :], (self.cfg.n_envs, 1))
            return arr

        pos = per_env(obj.get_pos())
        quat = per_env(obj.get_quat())
        return pos, quat, wxyz_to_R(quat)

    def move_to_pose_with_refinement(self, arm, ee_link,
                                     source_object=None,
                                     target_object=None,
                                     target_pos=None,
                                     target_quat=None,
                                     coarse_interp=12,
                                     coarse_force=2.0,
                                     coarse_steps=5,
                                     record_arm='left'):
        """Batched analog of pipette3.move_to_pose_with_refinement (coarse only).

        Two modes:

        * ``source_object`` given: the end-effector pose is expressed in the
          source object's frame and re-applied at the target frame, so that
          moving the end effector brings the source object onto the target.
          The target frame is ``target_object``'s pose or, without a target
          object, ``target_pos`` plus ``target_quat`` (the source object's
          current rotation is kept when ``target_quat`` is ``None``).
        * ``source_object`` is ``None``: ``target_pos`` / ``target_quat`` are
          used directly as the end-effector goal.

        The resulting goal is executed with ``ik_move_to_pose``.

        Args:
            arm: robot entity to command.
            ee_link: end-effector link of ``arm``.
            source_object: object (or link) to bring to the target, or ``None``.
            target_object: object whose pose is the target frame, or ``None``.
            target_pos: target position, broadcast to all envs.
            target_quat: target orientation (wxyz), broadcast to all envs.
            coarse_interp: interpolation steps passed to ``ik_move_to_pose``.
            coarse_force: gripper convention value passed to ``ik_move_to_pose``.
            coarse_steps: trailing steps passed to ``ik_move_to_pose``.
            record_arm: arm slot under which the commands are recorded.

        Raises:
            ValueError: if the arguments required by the selected mode are
                missing. Raised before the simulator is queried.
        """
        n = self.cfg.n_envs

        if source_object is None:
            if target_pos is None or target_quat is None:
                raise ValueError("直接 ee 模式下，必须提供 target_pos 和 target_quat")
            coarse_pos = broadcast_pos(target_pos, n)
            coarse_quat = broadcast_quat(target_quat, n)
        else:
            if target_object is None and target_pos is None:
                raise ValueError("source_object 模式下，必须提供 target_object 或 target_pos")

            # The EE pose is only needed in this mode.
            ee_pos0 = _to_np(ee_link.get_pos())
            ee_r0 = wxyz_to_R(_to_np(ee_link.get_quat()))

            # End-effector pose expressed in the source object's frame.
            src_pos0, _, src_r0 = self._obj_pose_batched(source_object)
            source_to_ee_pos = np.zeros((n, 3), dtype=np.float32)
            for i in range(n):
                source_to_ee_pos[i] = src_r0[i].inv().apply(ee_pos0[i] - src_pos0[i])
            source_to_ee_r = [src_r0[i].inv() * ee_r0[i] for i in range(n)]

            if target_object is not None:
                tgt_src_pos0, _, tgt_src_r0 = self._obj_pose_batched(target_object)
            else:
                tgt_src_pos0 = broadcast_pos(target_pos, n)
                if target_quat is None:
                    tgt_src_r0 = src_r0
                else:
                    tgt_src_r0 = wxyz_to_R(broadcast_quat(target_quat, n))

            # Re-apply the stored offset at the target frame.
            coarse_pos = np.zeros((n, 3), dtype=np.float32)
            coarse_quat = np.zeros((n, 4), dtype=np.float32)
            for i in range(n):
                coarse_pos[i] = tgt_src_pos0[i] + tgt_src_r0[i].apply(source_to_ee_pos[i])
                coarse_r = tgt_src_r0[i] * source_to_ee_r[i]
                # as_quat() yields xyzw; the rest of this file uses wxyz.
                qx, qy, qz, qw = coarse_r.as_quat()
                coarse_quat[i] = [qw, qx, qy, qz]

        self.ik_move_to_pose(arm, ee_link, coarse_pos, coarse_quat,
                             n_interp=coarse_interp, gripper_force=coarse_force,
                             n_steps=coarse_steps, record_arm=record_arm)

    def move_two_arms_with_refinement(self, left_cfg, right_cfg):
        """Batched analog of pipette3.move_two_arms_with_refinement (coarse, synchronous).

        Each cfg is a dict with the optional keys ``source_object``,
        ``target_object``, ``target_pos``, ``target_quat``, ``coarse_interp``
        (default 12), ``coarse_force`` (default 2.0) and ``coarse_steps``
        (default 0).

        Both arms are advanced in the same simulation steps. Each arm
        interpolates its end-effector pose (linear position, slerp
        orientation) over its own ``coarse_interp`` steps and then holds its
        goal until the slower arm finishes. An arm whose ``coarse_interp`` is
        0 is not commanded at all. Afterwards ``max(coarse_steps)`` extra
        steps are taken with no command passed for recording.

        Raises:
            ValueError: if a cfg lacks the targets its mode requires.
        """
        n = self.cfg.n_envs

        l_ee_pos0 = _to_np(self.left_ee.get_pos())
        l_ee_quat0 = _to_np(self.left_ee.get_quat())
        l_ee_r0 = wxyz_to_R(l_ee_quat0)
        r_ee_pos0 = _to_np(self.right_ee.get_pos())
        r_ee_quat0 = _to_np(self.right_ee.get_quat())
        r_ee_r0 = wxyz_to_R(r_ee_quat0)

        def _compute_goal(ee_pos0, ee_r0_batch, cfg):
            """Return the end-effector goal ``(pos, quat wxyz)`` described by ``cfg``."""
            source_object = cfg.get("source_object")
            target_object = cfg.get("target_object")
            target_pos = cfg.get("target_pos")
            target_quat = cfg.get("target_quat")

            if source_object is None:
                if target_pos is None or target_quat is None:
                    raise ValueError("直接 ee 模式下，必须提供 target_pos 和 target_quat")
                return broadcast_pos(target_pos, n), broadcast_quat(target_quat, n)

            src_pos0, _, src_r0 = self._obj_pose_batched(source_object)
            if target_object is not None:
                tgt_src_pos0, _, tgt_src_r0 = self._obj_pose_batched(target_object)
            else:
                if target_pos is None:
                    raise ValueError("source_object 模式下，必须提供 target_object 或 target_pos")
                tgt_src_pos0 = broadcast_pos(target_pos, n)
                if target_quat is None:
                    tgt_src_r0 = src_r0
                else:
                    tgt_src_r0 = wxyz_to_R(broadcast_quat(target_quat, n))

            goal_pos = np.zeros((n, 3), dtype=np.float32)
            goal_quat = np.zeros((n, 4), dtype=np.float32)
            for i in range(n):
                # The EE is translated by the source->target offset and
                # rotated by the source->target rotation change.
                goal_pos[i] = ee_pos0[i] + (tgt_src_pos0[i] - src_pos0[i])
                gr = (tgt_src_r0[i] * src_r0[i].inv()) * ee_r0_batch[i]
                # as_quat() yields xyzw; the rest of this file uses wxyz.
                qx, qy, qz, qw = gr.as_quat()
                goal_quat[i] = [qw, qx, qy, qz]
            return goal_pos, goal_quat

        l_goal_pos, l_goal_quat = _compute_goal(l_ee_pos0, l_ee_r0, left_cfg)
        r_goal_pos, r_goal_quat = _compute_goal(r_ee_pos0, r_ee_r0, right_cfg)

        l_interp = int(left_cfg.get("coarse_interp", 12))
        r_interp = int(right_cfg.get("coarse_interp", 12))
        l_force = float(left_cfg.get("coarse_force", 2.0))
        r_force = float(right_cfg.get("coarse_force", 2.0))
        l_grip_open = self._grip_open_for_force(l_force)
        r_grip_open = self._grip_open_for_force(r_force)
        coarse_iters = max(l_interp, r_interp)

        def _command(arm, ee_link, pos0, quat0, goal_pos, goal_quat, interp, force, t):
            """Command one arm for step ``t``; return its 7 joint targets or ``None``."""
            if interp <= 0:
                return None
            # Clamp so the faster arm holds its goal while the other finishes.
            frac = min(t / interp, 1.0)
            pos = (1 - frac) * pos0 + frac * goal_pos
            quat = slerp_batch(quat0, goal_quat, frac)
            q_cmd = _to_np(self._ik(arm, ee_link, pos, quat))
            self._apply_arm(arm, q_cmd, force)
            return q_cmd[:, :7]

        for t in range(1, coarse_iters + 1):
            left_joints = _command(self.franka_left, self.left_ee,
                                   l_ee_pos0, l_ee_quat0, l_goal_pos, l_goal_quat,
                                   l_interp, l_force, t)
            right_joints = _command(self.franka_right, self.right_ee,
                                    r_ee_pos0, r_ee_quat0, r_goal_pos, r_goal_quat,
                                    r_interp, r_force, t)
            self._step(left_joints, right_joints, l_grip_open, r_grip_open)

        coarse_settle = int(max(left_cfg.get("coarse_steps", 0),
                                right_cfg.get("coarse_steps", 0)))
        for _ in range(coarse_settle):
            self._step()

    def move_two_arms_coarse(self, left_cfg, right_cfg, n_interp=120,
                             gripper_force_l=-10.0, gripper_force_r=-10.0,
                             settle_steps=60):
        """Synchronous batched linear interp (pos lerp + slerp) for both arms.

        Both end effectors travel from their current pose to the pose given by
        ``cfg['target_pos']`` / ``cfg['target_quat']`` in ``n_interp`` shared
        steps, with IK solved at every step. ``settle_steps`` further steps
        follow with no command passed for recording.

        Args:
            left_cfg: dict with ``target_pos`` and ``target_quat`` for the left arm.
            right_cfg: dict with ``target_pos`` and ``target_quat`` for the right arm.
            n_interp: number of interpolated command steps.
            gripper_force_l: left gripper convention value (<0 closed).
            gripper_force_r: right gripper convention value (<0 closed).
            settle_steps: number of trailing steps.
        """
        n_envs = self.cfg.n_envs

        def lerp(start, goal, w):
            return (1 - w) * start + w * goal

        left_pos_start = _to_np(self.left_ee.get_pos())
        left_quat_start = _to_np(self.left_ee.get_quat())
        right_pos_start = _to_np(self.right_ee.get_pos())
        right_quat_start = _to_np(self.right_ee.get_quat())

        left_pos_goal = broadcast_pos(left_cfg['target_pos'], n_envs)
        left_quat_goal = broadcast_quat(left_cfg['target_quat'], n_envs)
        right_pos_goal = broadcast_pos(right_cfg['target_pos'], n_envs)
        right_quat_goal = broadcast_quat(right_cfg['target_quat'], n_envs)

        left_opening = self._grip_open_for_force(gripper_force_l)
        right_opening = self._grip_open_for_force(gripper_force_r)

        for k in range(n_interp):
            w = (k + 1) / n_interp
            left_pos = lerp(left_pos_start, left_pos_goal, w)
            right_pos = lerp(right_pos_start, right_pos_goal, w)
            left_quat = slerp_batch(left_quat_start, left_quat_goal, w)
            right_quat = slerp_batch(right_quat_start, right_quat_goal, w)

            left_q = _to_np(self._ik(self.franka_left, self.left_ee, left_pos, left_quat))
            right_q = _to_np(self._ik(self.franka_right, self.right_ee, right_pos, right_quat))

            self._apply_arm(self.franka_left, left_q, gripper_force_l)
            self._apply_arm(self.franka_right, right_q, gripper_force_r)
            self._step(left_q[:, :7], right_q[:, :7], left_opening, right_opening)

        for _ in range(settle_steps):
            self._step()

    def move_two_arms_vertical_z(self, left_dz=0.0, right_dz=0.0,
                                 n_steps=100, n_interp=100,
                                 left_gripper_force=-5.0, right_gripper_force=-5.0):
        """Shift both end effectors along z while keeping their orientation.

        The z coordinate of each end effector is interpolated linearly from
        its current value to ``current + dz``; the orientation captured at the
        start is used as the IK target throughout.

        Args:
            left_dz: z offset for the left arm; a scalar is applied to every
                env, anything else is added to the z column as given.
            right_dz: z offset for the right arm, same rules as ``left_dz``.
            n_steps: number of interpolated motion steps.
            n_interp: number of trailing steps taken *after* the motion. Note
                that this is the opposite of how ``ik_move_to_pose`` uses the
                two names.
            left_gripper_force: left gripper convention value (<0 closed).
            right_gripper_force: right gripper convention value (<0 closed).
        """
        n_envs = self.cfg.n_envs
        left_from = _to_np(self.left_ee.get_pos())
        right_from = _to_np(self.right_ee.get_pos())
        left_quat = _to_np(self.left_ee.get_quat())
        right_quat = _to_np(self.right_ee.get_quat())

        # A scalar dz is expanded to one value per env.
        if np.isscalar(left_dz):
            left_dz = np.full(n_envs, float(left_dz), dtype=np.float32)
        if np.isscalar(right_dz):
            right_dz = np.full(n_envs, float(right_dz), dtype=np.float32)

        left_to = left_from.copy()
        left_to[:, 2] += left_dz
        right_to = right_from.copy()
        right_to[:, 2] += right_dz

        all_rot_axes = [True, True, True]

        def solve(arm, link, pos, quat):
            return _to_np(arm.inverse_kinematics(
                link=link,
                pos=torch.as_tensor(pos, device=gs.device, dtype=gs.tc_float),
                quat=torch.as_tensor(quat, device=gs.device, dtype=gs.tc_float),
                rot_mask=all_rot_axes,
            ))

        # Recorded gripper commands are the same for every step.
        left_opening = self._grip_open_for_force(left_gripper_force)
        right_opening = self._grip_open_for_force(right_gripper_force)

        for k in range(1, n_steps + 1):
            w = k / n_steps
            left_pos = (1 - w) * left_from + w * left_to
            right_pos = (1 - w) * right_from + w * right_to
            left_q = solve(self.franka_left, self.left_ee, left_pos, left_quat)
            right_q = solve(self.franka_right, self.right_ee, right_pos, right_quat)
            self._apply_arm(self.franka_left, left_q, left_gripper_force)
            self._apply_arm(self.franka_right, right_q, right_gripper_force)
            self._step(left_q[:, :7], right_q[:, :7], left_opening, right_opening)

        for _ in range(n_interp):
            self._step()

    # ----- sequence ---------------------------------------------------------
    def run(self):
        """Build the scene, execute the scripted tip-discard sequence and write outputs.

        Sequence:
            1. Left arm moves to the pipette grasp pose.
            2. Left arm closes in by 37 mm along -x and grips.
            3. Left arm lifts the pipette by 0.1.
            4. Left arm brings the pipette tip 0.1 above the cell dish.
            5. Right arm moves above the ejector and presses it down, after
               which the tip weld is released.
            6. Right arm lifts slightly.

        Finally both camera recordings are stopped (when the cameras exist)
        and ``finalize_outputs`` writes the episodes.
        """
        self.build_scene()

        # Settle by stepping the scene directly: these steps bypass `_step`,
        # so they are neither recorded nor rendered.
        for _ in range(15):
            self.scene.step()

        n_envs = self.cfg.n_envs
        right_origin_quat = _to_np(self.right_ee.get_quat()).copy()
        pipette_tip_link = self.pipette_tip_link

        # Step 1: Move to the pipette
        s1_pos, s1_quat = self.pipette.get_grasp_pose()
        self.ik_move_to_pose(self.franka_left, self.left_ee, s1_pos, s1_quat,
                             n_interp=25, gripper_force=2.0, n_steps=5,
                             record_arm='left')

        # Step 2: Close in and grip (same orientation as step 1)
        s2_pos = s1_pos + np.array([-0.037, 0.0, 0.0], dtype=np.float32)
        self.ik_move_to_pose(self.franka_left, self.left_ee, s2_pos, s1_quat,
                             n_interp=16, gripper_force=2.0, n_steps=0,
                             record_arm='left')

        self.grisp(self.franka_left, gripper_force=-10.0, n_steps=10)

        # Step 3: Lift the pipette
        s3_pos = s2_pos + np.array([0.0, 0.0, 0.1], dtype=np.float32)
        s3_quat = np.tile(np.array([0.7071, 0, -0.7071, 0], dtype=np.float32),
                          (n_envs, 1))
        self.ik_move_to_pose(self.franka_left, self.left_ee, s3_pos, s3_quat,
                             n_interp=15, gripper_force=-10.0, n_steps=0,
                             record_arm='left')

        # Step 4: Move pipette tip above the cell dish.
        cell_dish_pos, _ = self.cell_dish.get_pose()
        target_pos = np.stack(
            [cell_dish_pos[:, 0], cell_dish_pos[:, 1], cell_dish_pos[:, 2] + 0.1], axis=1
        ).astype(np.float32)
        # Tip orientation over the dish; also used as the left arm's hold
        # orientation during step 5.
        dish_quat = np.tile(np.array([0.7071, 0.0, 0.0, -0.7071], dtype=np.float32),
                            (n_envs, 1))
        self.move_to_pose_with_refinement(
            arm=self.franka_left, ee_link=self.left_ee,
            source_object=pipette_tip_link, target_object=None,
            target_pos=target_pos, target_quat=dish_quat,
            coarse_interp=16, coarse_force=-10.0, coarse_steps=2,
            record_arm='left',
        )

        for _ in range(10):
            self._step()

        def press_phase(right_target_pos, settle_steps):
            """Move the right arm to ``right_target_pos`` while the left arm holds.

            The left cfg uses ``coarse_interp`` 0, so the left arm receives no
            new command during the phase.
            """
            dish_pos_now, _ = self.cell_dish.get_pose()
            pip_pos_now, _ = self.pipette.get_pose()
            left_target_pos = np.stack(
                [dish_pos_now[:, 0], dish_pos_now[:, 1], pip_pos_now[:, 2] + 0.01], axis=1
            ).astype(np.float32)
            self.move_two_arms_with_refinement(
                left_cfg={
                    "source_object": pipette_tip_link,
                    "target_pos": left_target_pos,
                    "target_quat": dish_quat,
                    "coarse_interp": 0,
                    "coarse_force": -10.0,
                    "coarse_steps": settle_steps,
                },
                right_cfg={
                    "source_object": None,
                    "target_pos": right_target_pos,
                    "target_quat": right_origin_quat,
                    "coarse_interp": 12,
                    "coarse_force": -10.0,
                    "coarse_steps": settle_steps,
                },
            )

        # Step 5: Move the right arm above the ejector, press it to the bottom,
        # then release the detachable tip weld so the tip can fall freely.
        pipette_ejector = self.pipette.entity.get_link('pipette_ejector')
        ejector_pos = _to_np(pipette_ejector.get_pos())
        press_phase(ejector_pos + np.array([0.00, -0.035, 0.30], dtype=np.float32), 3)

        press_pos = _to_np(self.right_ee.get_pos()) - np.array([0.0, 0.0, 0.021], dtype=np.float32)
        press_phase(press_pos, 5)

        self.release_tip_preserving_grasp()

        for _ in range(10):
            self._step()

        # Step 6: Right arm lift (slight)
        self.move_two_arms_vertical_z(
            left_dz=0.0, right_dz=0.04,
            n_steps=10, n_interp=4,
            left_gripper_force=-10.0, right_gripper_force=-10.0,
        )

        for _ in range(10):
            self._step()

        if self.cam is not None:
            self.cam.stop_recording(save_to_filename='pick_up_pipette_parallel.mp4', fps=20)

        if self.cam_vice is not None:
            self.cam_vice.stop_recording(save_to_filename='pick_up_pipette_parallel_vice.mp4', fps=20)

        self.finalize_outputs()

    # ----- output -----------------------------------------------------------
    def finalize_outputs(self):
        """Write one episode directory per environment from the recorded buffers.

        For every env index ``i`` the directory
        ``<cfg.base_dir>/<cfg.episode_prefix><i:03d>/`` receives:

        * ``data.h5`` -- full payload under ``data/demo_0``: ``init``, ``obs``
          (including the pipette signals), ``actions``, ``dones``, ``rewards``.
        * ``states_actions.hdf5`` -- the same obs/actions/dones/rewards at the
          file root, without the pipette signals and the init group.
        * ``init.json`` -- robot, pipette, cell dish, scene, camera and
          success metadata.

        Position and velocity actions are derived from the recorded states
        (next-frame state, and next-frame state minus current state). The
        gripper action is the recorded command, shifted one frame earlier and
        binarized.

        NOTE(review): ``success`` is simply "at least one frame was recorded";
        no task outcome is evaluated here.
        """
        base_dir = self.cfg.base_dir
        base_dir.mkdir(parents=True, exist_ok=True)

        # Values shared by every episode are built once, outside the env loop.
        language_instruction = (
            "Pick up the pipette with the left arm, move it above the cell dish, "
            "then press the pipette ejector with the right arm to discard the tip."
        )
        comp = {"compression": "gzip", "compression_opts": 4}
        fixed_entities = self.scene_init["fixed_entities"]
        cell_dish_cfg = fixed_entities["cell_dish"]
        pipette_rack_cfg = fixed_entities["pipette_rack"]
        buffer_keys = (
            "left_state", "left_action", "right_state", "right_action",
            "left_cartesian", "right_cartesian",
            "pipette_pose", "pipette_tip_pos", "pipette_qpos",
        )

        def both_arms(left, right):
            """Join the left and right signal column-wise as float32."""
            return np.concatenate([left, right], axis=1).astype(np.float32)

        def write_trajectory(parent, obs_items, action_items, dones, rewards):
            """Create obs/actions groups plus dones/rewards under ``parent``."""
            obs_grp = parent.create_group("obs")
            for name, values in obs_items:
                obs_grp.create_dataset(name, data=values, **comp)
            actions_grp = parent.create_group("actions")
            for name, values in action_items:
                actions_grp.create_dataset(name, data=values, **comp)
            parent.create_dataset("dones", data=dones)
            parent.create_dataset("rewards", data=rewards)

        for env_i in range(self.cfg.n_envs):
            ep_dir = base_dir / f"{self.cfg.episode_prefix}{env_i:03d}"
            ep_dir.mkdir(parents=True, exist_ok=True)

            # Stack every buffer and truncate all of them to the shortest one.
            series = [np.asarray(self.buf[key][env_i], dtype=np.float32)
                      for key in buffer_keys]
            T = min(len(arr) for arr in series)
            ls, la, rs, ra, lc, rc, pp, tip, pip_qpos = (arr[:T] for arr in series)

            # Action gripper at frame t = commanded opening intended for frame t+1.
            # Shift action gripper by one step so action[t][7] reflects the NEXT
            # frame's gripper command (last frame keeps its own command).
            la[:-1, 7] = la[1:, 7]
            ra[:-1, 7] = ra[1:, 7]

            left_joint_position = ls[:, :7].astype(np.float32)
            right_joint_position = rs[:, :7].astype(np.float32)
            left_cartesian_position = lc.astype(np.float32)
            right_cartesian_position = rc.astype(np.float32)

            left_joint_position_cmd = next_frame(left_joint_position)
            right_joint_position_cmd = next_frame(right_joint_position)
            left_cartesian_position_cmd = next_frame(left_cartesian_position)
            right_cartesian_position_cmd = next_frame(right_cartesian_position)

            obs_cartesian_position = both_arms(left_cartesian_position, right_cartesian_position)
            obs_joint_position = both_arms(left_joint_position, right_joint_position)
            # The stored gripper state is scaled by the open width to get meters.
            obs_gripper_position = both_arms(ls[:, 7:8], rs[:, 7:8]) * self._FINGER_OPEN_M

            act_cartesian_position = both_arms(left_cartesian_position_cmd,
                                               right_cartesian_position_cmd)
            act_joint_position = both_arms(left_joint_position_cmd, right_joint_position_cmd)
            act_gripper_position = both_arms(binary_gripper(la[:, 7:8]),
                                             binary_gripper(ra[:, 7:8]))
            act_cartesian_velocity = both_arms(
                left_cartesian_position_cmd - left_cartesian_position,
                right_cartesian_position_cmd - right_cartesian_position,
            )
            act_joint_velocity = both_arms(
                left_joint_position_cmd - left_joint_position,
                right_joint_position_cmd - right_joint_position,
            )

            dones = np.zeros((T,), dtype=np.int8)
            rewards = np.zeros((T,), dtype=np.float32)
            success = bool(T > 0)
            success_info = {"tip_release_attempted": True}
            if T > 0:
                dones[-1] = 1
                rewards[-1] = 1.0 if success else 0.0

            per_env = self.scene_init["per_env"][env_i]
            env_offset = np.asarray(per_env["envs_offset_xyz_m"], dtype=np.float32)
            pipette_init_pos = np.asarray(per_env["pipette"]["pos_local"], dtype=np.float32)
            pipette_init_quat = (pp[0, 3:7].astype(np.float32) if T > 0
                                 else np.array([1, 0, 0, 0], dtype=np.float32))
            pipette_tip_init_pos = (tip[0].astype(np.float32) if T > 0
                                    else np.zeros((3,), dtype=np.float32))

            obs_items = [
                ("cartesian_position", obs_cartesian_position),
                ("joint_position", obs_joint_position),
                ("gripper_position", obs_gripper_position),
            ]
            pipette_items = [
                ("pipette_pose", pp),
                ("pipette_tip_pos", tip),
                ("pipette_qpos", pip_qpos),
            ]
            action_items = [
                ("cartesian_position", act_cartesian_position),
                ("joint_position", act_joint_position),
                ("gripper_position", act_gripper_position),
                ("cartesian_velocity", act_cartesian_velocity),
                ("joint_velocity", act_joint_velocity),
            ]

            # data.h5 (full payload, aligned with the dual-arm collection format)
            with h5py.File(ep_dir / "data.h5", "w") as f:
                f.attrs["language_instruction"] = language_instruction
                f.attrs["dt"] = self.cfg.dt
                f.attrs["substeps"] = self.cfg.substeps
                f.attrs["task"] = "discard_pipette_tip"
                f.attrs["success"] = int(success)
                f.attrs["total"] = int(T)
                f.attrs["rotation"] = "Euler XYZ radians"
                f.attrs["gripper_units"] = "observation gripper in meters; action gripper is binary 0=closed, 1=open."

                data_grp = f.create_group("data")
                demo = data_grp.create_group("demo_0")
                demo.attrs["num_samples"] = int(T)
                demo.attrs["language_instruction"] = language_instruction
                demo.attrs["success"] = int(success)
                for k, v in success_info.items():
                    demo.attrs[f"success_{k}"] = v
                demo.attrs["pipette_y_randomized_m"] = float(self.pipette_y_rand[env_i])

                init = demo.create_group("init")
                init.create_dataset("env_offset_xyz_m", data=env_offset.astype(np.float32))
                init.create_dataset("pipette_init_pos", data=pipette_init_pos.astype(np.float32))
                init.create_dataset("pipette_init_quat", data=pipette_init_quat)
                init.create_dataset("pipette_tip_init_pos", data=pipette_tip_init_pos)
                init.create_dataset("cell_dish_init_pos", data=np.asarray(cell_dish_cfg["pos"], dtype=np.float32))
                init.create_dataset("cell_dish_init_euler_deg", data=np.asarray(cell_dish_cfg["euler_deg"], dtype=np.float32))
                init.create_dataset("pipette_rack_init_pos", data=np.asarray(pipette_rack_cfg["pos"], dtype=np.float32))
                init.create_dataset("pipette_rack_init_euler_deg", data=np.asarray(pipette_rack_cfg["euler_deg"], dtype=np.float32))
                init.attrs["pipette_file"] = per_env["pipette"]["file"]
                init.attrs["cell_dish_file"] = cell_dish_cfg["file"]
                init.attrs["pipette_rack_file"] = pipette_rack_cfg["file"]

                write_trajectory(demo, obs_items + pipette_items, action_items, dones, rewards)

            # states_actions.hdf5 (robot obs/actions only, stored at the file root)
            with h5py.File(ep_dir / "states_actions.hdf5", "w") as f:
                f.attrs["language_instruction"] = language_instruction
                f.attrs["dt"] = self.cfg.dt
                f.attrs["task"] = "discard_pipette_tip"
                f.attrs["success"] = int(success)
                f.attrs["total"] = int(T)
                f.attrs["rotation"] = "Euler XYZ radians"
                f.attrs["gripper_units"] = "obs gripper in meters; actions gripper binary 0=closed, 1=open."
                for k, v in success_info.items():
                    f.attrs[f"success_{k}"] = v

                write_trajectory(f, obs_items, action_items, dones, rewards)

            # init.json (scene + per-env init + camera + randomization)
            meta = {
                "robot": {
                    "left_base_pos": fixed_entities["franka_left"]["pos"],
                    "right_base_pos": fixed_entities["franka_right"]["pos"],
                    "home_qpos": self.scene_init["robot_home_qpos_9"],
                },
                "pipette": {
                    "init_pos": pipette_init_pos.tolist(),
                    # Same T > 0-guarded value that data.h5 stores, instead
                    # of indexing pp[0] unguarded.
                    "init_quat": pipette_init_quat.tolist(),
                },
                "cell_dish": cell_dish_cfg,
                "success_check": {
                    "success": bool(success),
                    **success_info,
                },
                "scene": {
                    "n_envs": int(self.cfg.n_envs),
                    "env_idx": int(env_i),
                    "envs_offset_xyz_m": env_offset.tolist(),
                    "dt": self.cfg.dt,
                    "substeps": self.cfg.substeps,
                    "record_hz": 1.0 / self.cfg.dt,
                },
                "cameras": self.scene_init["cameras"],
            }
            with open(ep_dir / "init.json", "w", encoding="utf-8") as f:
                json.dump(meta, f, ensure_ascii=False, indent=2)

        print(f"[done] wrote {self.cfg.n_envs} episodes to {self.cfg.base_dir}")


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the behavior existing
            callers get.

    Returns:
        ``argparse.Namespace`` with ``n_envs``, ``dt``, ``substeps``,
        ``output``, ``viewer`` and ``record_video``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--n_envs", type=int, default=5,
                   help="value passed to TaskConfig.n_envs")
    p.add_argument("--dt", type=float, default=5e-2,
                   help="value passed to TaskConfig.dt")
    p.add_argument("--substeps", type=int, default=25,
                   help="value passed to TaskConfig.substeps")
    p.add_argument("--output", type=str, default="results/discard_pipette_tip/dataset",
                   help="base directory the episodes are written to")
    p.add_argument("--viewer", action="store_true",
                   help="sets TaskConfig.show_viewer")
    p.add_argument("--record_video", action="store_true",
                   help="sets TaskConfig.record_video")
    return p.parse_args(argv)


def main():
    """Script entry point: turn the CLI options into a TaskConfig and run the task."""
    opts = parse_args()
    task_cfg = TaskConfig(
        n_envs=opts.n_envs,
        dt=opts.dt,
        substeps=opts.substeps,
        base_dir=Path(opts.output),
        show_viewer=opts.viewer,
        record_video=opts.record_video,
    )
    task = ParallelPipetteTask(task_cfg)
    task.run()


# Run the task only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
