XL2   Æ  ÇÄ&—Ä                Ó«!P¶ÖÄ     Å™ƒ¤Type¥V2ObjÞ ¢IDÄ                ¤DDirÄueg—bpO”›3m®ê;¦EcAlgo£EcM£EcN §EcBSizeÎ   §EcIndex¦EcDist‘¨CSumAlgo¨PartNums‘©PartETags‘Ù a194ac4bb4ba2e6e03cdc7466a0ca1aa©PartSizes‘Î ªPartASizes‘Î §PartIdx‘Ä ¤SizeÎ ¥MTimeÏ«!P¶Ö§MetaSys‚¼x-minio-internal-inline-dataÄtrue½x-rustfs-internal-inline-dataÄtrue§MetaUsr‚¤etagÙ a194ac4bb4ba2e6e03cdc7466a0ca1aa¬content-type´text/x-python-script¡v Îß2Ã¤nullÆ =b7ŸÒ]ÕIf~ï‰ÜfŸç.=^çSlÃË¾A=µúæêimport argparse
import json
from dataclasses import dataclass
from pathlib import Path

import genesis as gs
import h5py
import numpy as np
import torch
from scipy.spatial.transform import Rotation as R, Slerp

# Default fallback value for the downward-move height. build_scene derives a
# per-env value instead (self.down_height_rand, via down_height_from_size_z).
Down_height = 0.15

# Tube slot centres, xy only, in the env frame: the rack's local hole offsets
# (x in -0.072..0.072) shifted by the rack position x = 0.1. Two rows of five
# slots at y = +0.018 then y = -0.018. Tube z is not taken from here;
# build_scene places every tube at a fixed z (0.83033).
TUBE_SLOT_LOCAL_XY = np.array(
    [(x, y) for y in (0.018, -0.018) for x in (0.028, 0.064, 0.100, 0.136, 0.172)],
    dtype=np.float32,
)

def down_height_from_size_z(size_z: np.ndarray) -> np.ndarray:
    """Map liquid column height to the downward-move height on a linear ramp.

    size_z = 0.15 -> 0.15 and size_z = 0.35 -> 0.10 (slope -0.25). Values
    outside that range are extrapolated, not clamped. The input is cast to
    float32; scalars and arrays are both accepted.
    """
    heights = np.asarray(size_z, dtype=np.float32)
    return 0.1875 - 0.25 * heights


# ---------------------------------------------------------------------------
# Scene object wrappers (entities built once, poses are [n_envs, ...] tensors)
# ---------------------------------------------------------------------------

def _to_np(x):
    return x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else np.asarray(x)


class Table:
    """Static table model loaded from MJCF and added to ``scene``."""

    def __init__(self, scene, pos=(0.0, 0.0, 0.0), euler=(0, 0, 90)):
        table_morph = gs.morphs.MJCF(
            file='asset/model/misc/simple_table.xml',
            pos=pos,
            decimate=False,
            euler=euler,
        )
        self.entity = scene.add_entity(table_morph)


class Tube_Rack:
    """10-slot centrifuge tube rack loaded from MJCF and added to ``scene``."""

    def __init__(self, scene, pos=(0.1, 0.0, 0.854), euler=(0, 0, 0)):
        rack_morph = gs.morphs.MJCF(
            file='asset/model/object/centrifuge_10slot.gen.xml',
            pos=pos,
            decimate=False,
            euler=euler,
        )
        self.entity = scene.add_entity(rack_morph)
        # Hole centres in the rack's local frame: two rows of five at
        # y = +0.018 then y = -0.018, x spaced 0.036 apart, all at z = -0.024.
        self.local_50ml_holes = [
            (x, y, -0.024)
            for y in (0.018, -0.018)
            for x in (-0.072, -0.036, 0.000, 0.036, 0.072)
        ]


class Tube:
    """50 ml centrifuge tube (MJCF) with batched pose and grasp-pose helpers.

    Quaternions are (w, x, y, z); returned arrays carry a leading n_envs axis.
    """

    def __init__(self, scene, pos=(0.136, -0.018, 0.83033), euler=(0, 0, 0)):
        tube_morph = gs.morphs.MJCF(
            file='asset/model/object/centrifuge_50ml_collision.gen.xml',
            pos=pos,
            decimate=False,
            euler=euler,
            batch_fixed_verts=True,
        )
        self.entity = scene.add_entity(tube_morph)
        self.entity.set_friction(2.0)

    def get_pose(self):
        """Returns (pos [n_envs,3], quat [n_envs,4]) numpy."""
        position = _to_np(self.entity.get_pos())
        orientation = _to_np(self.entity.get_quat())
        return position, orientation

    def get_grasp_pose(self):
        """Top-down grasp target for every env.

        The target point lies 0.095 along the tube's local +z axis from the
        tube origin. It is then lifted by 0.125 in world z when the tube axis
        is (anti)parallel to world z, and by 0.115 otherwise. The gripper
        z axis points straight down (world -z). Its y axis is perpendicular to
        both that axis and the tube axis, falling back to world +y when the
        two are parallel.

        Returns:
            (pos [n, 3], quat [n, 4] wxyz), both float32.
        """
        tube_pos, tube_wxyz = self.get_pose()
        count = tube_pos.shape[0]

        tube_xyzw = np.concatenate([tube_wxyz[:, 1:4], tube_wxyz[:, :1]], axis=1)
        tube_axis = R.from_quat(tube_xyzw).as_matrix()[:, :, 2]  # tube local +z in world
        grasp_pos = tube_pos + tube_axis * 0.095

        approach = np.tile(np.array([0.0, 0.0, -1.0]), (count, 1))
        side = np.cross(approach, tube_axis)
        side_len = np.linalg.norm(side, axis=1, keepdims=True)   # [n,1]
        parallel = side_len < 1e-3                               # [n,1]
        side_unit = np.where(
            parallel,
            np.tile([0.0, 1.0, 0.0], (count, 1)),
            side / np.where(parallel, 1.0, side_len),
        )

        lift = np.zeros((count, 3))
        lift[:, 2] = np.where(parallel[:, 0], 0.125, 0.115)
        grasp_pos = grasp_pos + lift

        forward = np.cross(side_unit, approach)
        forward /= np.linalg.norm(forward, axis=1, keepdims=True)

        frames = np.stack([forward, side_unit, approach], axis=-1)  # columns: gripper x, y, z
        grip_xyzw = R.from_matrix(frames).as_quat()
        grip_wxyz = np.concatenate([grip_xyzw[:, 3:4], grip_xyzw[:, 0:3]], axis=1)
        return grasp_pos.astype(np.float32), grip_wxyz.astype(np.float32)


class Pipette_Rack:
    """Pipette rack model loaded from MJCF and added to ``scene``."""

    def __init__(self, scene, pos=(-0.27, 0.0, 0.824), euler=(0, 0, 90)):
        rack_morph = gs.morphs.MJCF(
            file='asset/model/object/pipette_rack.gen.xml',
            pos=pos,
            decimate=False,
            euler=euler,
        )
        self.entity = scene.add_entity(rack_morph)


class Pipette:
    """Pipette (MJCF) with batched pose, button and grasp-pose helpers.

    Quaternions are (w, x, y, z); returned arrays carry a leading n_envs axis.
    """

    def __init__(self, scene, pos=(-0.22, -0.0, 1.08), euler=(0, 10, 180)):
        self.entity = scene.add_entity(
            gs.morphs.MJCF(file='asset/model/object/pipette_add_stiffness.gen.xml',
                           pos=pos, decimate=False, euler=euler,
                           batch_fixed_verts=True)
        )
        self.entity.set_friction(2.0)

    def get_pose(self):
        """Return (pos [n_envs,3], quat [n_envs,4] wxyz) as numpy arrays."""
        return _to_np(self.entity.get_pos()), _to_np(self.entity.get_quat())

    def get_button_pos(self):
        """Return (pos, quat) of the 'pipette_button' link as numpy arrays."""
        button_link = self.entity.get_link('pipette_button')
        # _to_np detaches before converting, so this also works on tensors
        # that require grad (a bare .cpu().numpy() would raise there).
        button_pos = _to_np(button_link.get_pos())
        button_quat = _to_np(button_link.get_quat())
        return button_pos, button_quat

    def get_grasp_pose(self):
        """Batched grasp target for the pipette body.

        Position: 0.095 along the pipette's local +z axis from its origin, plus
        a fixed world-frame offset (0.140, 0.0, 0.03).

        Orientation: the gripper z axis points down (world -z). Its y axis is
        perpendicular to both that axis and the pipette axis, falling back to
        world +y when the two are (anti)parallel; without the fallback the
        frame would be all-NaN. The result is then rotated about its own y
        axis by (pipette 'xyz' Euler y angle + 90 deg).

        Returns:
            (pos [n, 3], quat [n, 4] wxyz), both float32.
        """
        pos, quat = self.get_pose()
        n = pos.shape[0]
        r_pipette = R.from_quat(np.concatenate([quat[:, 1:4], quat[:, :1]], axis=1))
        v_long = r_pipette.as_matrix()[:, :, 2]          # pipette local +z in world
        cap_pos = pos + v_long * 0.095
        cap_pos = cap_pos + np.tile([0.140, 0.0, 0.03], (n, 1))

        z_grip = np.tile(np.array([0.0, 0.0, -1.0]), (n, 1))
        y_grip = np.cross(z_grip, v_long)
        y_norm = np.linalg.norm(y_grip, axis=1, keepdims=True)
        degenerate = y_norm < 1e-6
        y_grip = np.where(
            degenerate,
            np.tile([0.0, 1.0, 0.0], (n, 1)),
            y_grip / np.where(degenerate, 1.0, y_norm),
        )
        x_grip = np.cross(y_grip, z_grip)
        x_grip /= np.linalg.norm(x_grip, axis=1, keepdims=True)

        rot_mats = np.stack([x_grip, y_grip, z_grip], axis=-1)   # columns: gripper x, y, z
        r_grip = R.from_matrix(rot_mats)

        eul = r_pipette.as_euler('xyz', degrees=True)            # [n,3]
        delta_y = eul[:, 1] + 90.0
        r_adjust = R.from_euler('y', delta_y, degrees=True)
        r_grip = r_grip * r_adjust                               # rotate about gripper's own y

        qxyzw = r_grip.as_quat()
        target_quat = np.concatenate([qxyzw[:, 3:4], qxyzw[:, 0:3]], axis=1)
        return cap_pos.astype(np.float32), target_quat.astype(np.float32)


class Liquid:
    """Box-shaped SPH liquid entity added to ``scene``, rendered as particles."""

    def __init__(self, scene, pos=(-0.036, 0.018, 1.1), size=(0.016, 0.016, 0.35)):
        liquid_material = gs.materials.SPH.Liquid(
            mu=0.001,
            gamma=0.01,
            stiffness=50000.0,
            sampler="regular",
        )
        box = gs.morphs.Box(pos=pos, size=size)
        look = gs.surfaces.Default(color=(0.4, 0.8, 1.0, 1.0), vis_mode='particle')
        self.entity = scene.add_entity(material=liquid_material, morph=box, surface=look)


# ---------------------------------------------------------------------------
# Quaternion helpers (wxyz)
# ---------------------------------------------------------------------------

def wxyz_to_R(q):
    """Convert (w, x, y, z) quaternion(s) to a scipy ``Rotation``.

    Args:
        q: a single quaternion of shape [4] or a batch of shape [n, 4],
            scalar-first (w, x, y, z).

    Returns:
        A single ``Rotation`` for [4] input, a batched ``Rotation`` for [n, 4].
        scipy expects scalar-last (x, y, z, w), hence the reorder along the
        last axis.
    """
    q = np.asarray(q)
    return R.from_quat(np.concatenate([q[..., 1:4], q[..., :1]], axis=-1))

def R_to_wxyz(r):
    """Convert a scipy ``Rotation`` (single or batched) to float32 (w, x, y, z).

    scipy's ``as_quat`` is scalar-last; rolling the last axis by one moves w
    to the front. The result has shape [4] for a single rotation, [n, 4] for a batch.
    """
    xyzw = r.as_quat()
    return np.roll(xyzw, 1, axis=-1).astype(np.float32)

def pose_to_cartesian6(pos, quat, envs_offset=None):
    """Concatenate position and Euler angles into one 6-vector per row.

    Args:
        pos: position(s), [n, 3] or [3]; a 1-D input is promoted to [1, 3].
        quat: quaternion(s) in (w, x, y, z) order, [n, 4] or [4].
        envs_offset: accepted but currently unused; positions are returned
            exactly as given.

    Returns:
        float32 array [n, 6]: x, y, z, rx, ry, rz, with the angles in radians
        from scipy's lowercase (extrinsic) "xyz" Euler convention.
    """
    position = np.asarray(pos, dtype=np.float32)
    rotation = np.asarray(quat, dtype=np.float32)
    position = position[None, :] if position.ndim == 1 else position
    rotation = rotation[None, :] if rotation.ndim == 1 else rotation
    angles = wxyz_to_R(rotation).as_euler("xyz", degrees=False).astype(np.float32)
    return np.concatenate([position, angles], axis=1).astype(np.float32)

def next_frame(values):
    """Shift a sequence forward by one step along axis 0, repeating the last row.

    Row t of the result is row t+1 of the input; the final row is duplicated.
    Sequences of length 0 or 1 are returned as a float32 copy.
    """
    seq = np.asarray(values, dtype=np.float32)
    if seq.shape[0] > 1:
        shifted = np.concatenate((seq[1:], seq[-1:]), axis=0)
        return shifted.astype(np.float32)
    return seq.copy()

def finite_difference(values, dt):
    """Forward difference along axis 0, divided by ``dt``; the last row is zero.

    Returns a float32 array with the same shape as ``values``.
    """
    seq = np.asarray(values, dtype=np.float32)
    rate = np.zeros_like(seq, dtype=np.float32)
    if seq.shape[0] >= 2:
        rate[:-1] = np.diff(seq, axis=0) / float(dt)
    return rate

def binary_gripper(gripper):
    """Threshold gripper values: 1.0 where strictly above 0.5, else 0.0 (float32)."""
    readings = np.asarray(gripper, dtype=np.float32)
    is_closed = np.greater(readings, 0.5)
    return is_closed.astype(np.float32)

def max_true_run(mask):
    """Length of the longest run of consecutive truthy entries in ``mask``.

    The mask is flattened first; an empty or all-false mask gives 0.
    """
    flags = np.asarray(mask, dtype=bool).reshape(-1)
    # Pad with False on both ends so every run has a rising and a falling edge.
    padded = np.concatenate(([False], flags, [False])).astype(np.int8)
    edges = np.diff(padded)
    starts = np.flatnonzero(edges == 1)
    ends = np.flatnonzero(edges == -1)
    if starts.size == 0:
        return 0
    return int((ends - starts).max())

def evaluate_pipette_success(
    pipette_tip_pos,
    tube_pose,
    liquid_particles,
    tube_xy_radius=0.035,
    liquid_xy_radius=0.055,
    surface_margin=0.003,
    bottom_margin=0.02,
    min_consecutive_frames=3,
):
    """Success if the pipette tip stays inside the tube and below liquid surface.

    A frame counts as successful when the tip is
      * within ``tube_xy_radius`` of the tube position in xy,
      * at least ``surface_margin`` below the liquid surface, and
      * no more than ``bottom_margin`` below the liquid bottom.
    The episode succeeds when successful frames form a run of at least
    ``min_consecutive_frames``.

    Per frame, the liquid surface is the 95th percentile and the bottom the
    5th percentile of particle z. Only finite particles with z > 0.5 within
    ``liquid_xy_radius`` of the tube xy are used. If none lie near the tube,
    all finite particles with z > 0.5 are used instead.

    Args:
        pipette_tip_pos: [T, 3] tip positions.
        tube_pose: [T, >=2] tube poses; only the first two columns (xy) are read.
        liquid_particles: [T, P, 3] particle positions.
        The three inputs are truncated to their common length T.

    Returns:
        (success, info). ``info`` has the same keys on every path. When there
        are no frames, the run count is 0, the first-success frame is -1 and
        the depth/distance statistics are NaN.
    """
    pipette_tip_pos = np.asarray(pipette_tip_pos, dtype=np.float32)
    tube_pose = np.asarray(tube_pose, dtype=np.float32)
    liquid_particles = np.asarray(liquid_particles, dtype=np.float32)
    T = min(pipette_tip_pos.shape[0], tube_pose.shape[0], liquid_particles.shape[0])

    # One fixed schema for every return path, so callers that aggregate or
    # serialize ``info`` never hit a missing key on empty episodes.
    info = {
        "max_consecutive_frames": 0,
        "first_success_frame": -1,
        "min_consecutive_frames": int(min_consecutive_frames),
        "tube_xy_radius_m": float(tube_xy_radius),
        "liquid_xy_radius_m": float(liquid_xy_radius),
        "surface_margin_m": float(surface_margin),
        "bottom_margin_m": float(bottom_margin),
        "best_surface_depth_m": float("nan"),
        "min_tip_xy_dist_m": float("nan"),
    }
    if T <= 0:
        return False, info

    tip = pipette_tip_pos[:T]
    tube_xy = tube_pose[:T, :2]
    particles = liquid_particles[:T]

    surface_z = np.full((T,), np.nan, dtype=np.float32)
    bottom_z = np.full((T,), np.nan, dtype=np.float32)
    for t in range(T):
        p = particles[t]
        # z > 0.5 drops particles parked far below the scene
        # (e.g. the unused ones build_scene moves to z = -5).
        valid = np.isfinite(p).all(axis=1) & (p[:, 2] > 0.5)
        near_liquid_col = np.linalg.norm(p[:, :2] - tube_xy[t], axis=1) <= liquid_xy_radius
        mask = valid & near_liquid_col
        if not np.any(mask):
            mask = valid
        if np.any(mask):
            z = p[mask, 2]
            surface_z[t] = np.percentile(z, 95)
            bottom_z[t] = np.percentile(z, 5)

    tip_xy_dist = np.linalg.norm(tip[:, :2] - tube_xy, axis=1)
    finite_surface = np.isfinite(surface_z) & np.isfinite(bottom_z) & np.isfinite(tip).all(axis=1)
    inside_tube_xy = tip_xy_dist <= tube_xy_radius
    below_surface = tip[:, 2] <= (surface_z - surface_margin)
    above_bottom = tip[:, 2] >= (bottom_z - bottom_margin)
    success_mask = finite_surface & inside_tube_xy & below_surface & above_bottom

    max_run = max_true_run(success_mask)
    first_success = int(np.argmax(success_mask)) if np.any(success_mask) else -1
    best_surface_depth = float(np.nanmax(surface_z - tip[:, 2])) if np.any(finite_surface) else float("nan")
    min_tip_xy_dist = float(np.nanmin(tip_xy_dist)) if tip_xy_dist.size else float("nan")
    info.update({
        "max_consecutive_frames": int(max_run),
        "first_success_frame": first_success,
        "best_surface_depth_m": best_surface_depth,
        "min_tip_xy_dist_m": min_tip_xy_dist,
    })
    return bool(max_run >= min_consecutive_frames), info

def broadcast_quat(q, n):
    """Return ``q`` as float32; a single 1-D quaternion is repeated into ``n`` rows.

    Inputs that are not 1-D are returned (after the float32 conversion) as-is.
    """
    quats = np.asarray(q, dtype=np.float32)
    if quats.ndim != 1:
        return quats
    return np.repeat(quats[None, :], n, axis=0)

def broadcast_pos(p, n):
    """Return ``p`` as float32; a single 1-D position is repeated into ``n`` rows.

    Inputs that are not 1-D are returned (after the float32 conversion) as-is.
    """
    points = np.asarray(p, dtype=np.float32)
    if points.ndim != 1:
        return points
    return np.repeat(points[None, :], n, axis=0)

def slerp_batch(q0_wxyz, q1_wxyz, alpha):
    """Per-env spherical linear interpolation between paired quaternions.

    Args:
        q0_wxyz: start orientations, array-like [n, 4] in (w, x, y, z) order.
        q1_wxyz: end orientations, array-like [n, 4] in (w, x, y, z) order.
        alpha: scalar fraction in [0, 1]; 0 gives q0 and 1 gives q1.

    Returns:
        float32 array [n, 4] in (w, x, y, z) order.

    Raises:
        ValueError: if ``alpha`` is outside [0, 1]. scipy's ``Slerp`` raised
            for this case too when it was used per env.
    """
    q0 = np.asarray(q0_wxyz)
    q1 = np.asarray(q1_wxyz)
    n = q0.shape[0]
    if n == 0:
        return np.zeros_like(q0, dtype=np.float32)
    alpha = float(alpha)
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("Interpolation times must be within the range [0, 1], both inclusive.")
    r0 = wxyz_to_R(q0)
    r1 = wxyz_to_R(q1)
    # This is the same formula scipy's Slerp evaluates for a [0, 1] keyframe
    # pair, r0 * exp(alpha * log(r0^-1 * r1)), applied to all envs at once
    # instead of building one Slerp object per env.
    delta = (r0.inv() * r1).as_rotvec()
    return R_to_wxyz(r0 * R.from_rotvec(delta * alpha))


# ---------------------------------------------------------------------------
# Weld attachment checker (batched-aware, mirrors pipette3.WeldAttachmentChecker)
# ---------------------------------------------------------------------------

class WeldAttachmentChecker:
    """
    Batched weld rule across n_envs.

    Per-env state:
      attach when:  left finger contact AND right finger contact AND no support contact
      detach when:  bilateral finger contact gone OR support contact present

    acquire_steps / release_steps smooth one-frame contact noise. The attach
    (detach) condition must hold on that many consecutive ``update()`` calls
    before the weld is added (removed). The weld constraint itself is added
    and removed for all envs at once (see the NOTE above
    ``_weld_already_present``).
    """
    def __init__(
        self,
        scene,
        robot_entity,
        ee_link,
        left_finger_link,
        right_finger_link,
        object_entity,
        support_entities=None,
        acquire_steps=2,
        release_steps=2,
        n_envs=1,
    ):
        self.scene = scene
        self.rigid = scene.sim.rigid_solver

        self.robot = robot_entity
        self.ee_link = ee_link
        self.left_finger = left_finger_link
        self.right_finger = right_finger_link
        self.obj = object_entity

        self.support_entities = [] if support_entities is None else list(support_entities)

        self.acquire_steps = acquire_steps
        self.release_steps = release_steps

        self.n_envs = n_envs
        self.active = np.zeros(n_envs, dtype=bool)
        # Consecutive-frame counters used for debouncing (one per env).
        self._acquire_counter = np.zeros(n_envs, dtype=np.int32)
        self._release_counter = np.zeros(n_envs, dtype=np.int32)

        self.link_obj = np.array([self.obj.base_link.idx], dtype=gs.np_int)
        self.link_ee = np.array([self.ee_link.idx], dtype=gs.np_int)

    # ---- contact queries (per-env) ------------------------------------------
    def _contact_per_env(self, entity):
        """Return [n_envs] bool array indicating contact between obj and entity."""
        info = self.obj.get_contacts(with_entity=entity)
        n = self.n_envs
        # Preferred: valid_mask, presumably [n_envs, max_contacts] when batched.
        if "valid_mask" in info:
            vm = info["valid_mask"]
            arr = vm.detach().cpu().numpy() if isinstance(vm, torch.Tensor) else np.asarray(vm)
            if arr.ndim == 0:
                return np.array([bool(arr)] * n, dtype=bool)
            if arr.ndim == 1:
                # Single env case
                return np.array([bool(arr.any())] * n, dtype=bool) if n == 1 else \
                       np.array([bool(arr[i]) if i < arr.shape[0] else False for i in range(n)], dtype=bool)
            return arr.any(axis=tuple(range(1, arr.ndim))).astype(bool)
        # Fallback: without a mask, any reported contact is applied to all envs.
        if "geom_a" in info:
            ga = info["geom_a"]
            arr = ga.detach().cpu().numpy() if isinstance(ga, torch.Tensor) else np.asarray(ga)
            if arr.ndim <= 1:
                return np.array([int(arr.size) > 0] * n, dtype=bool)
            return (arr.shape[1] > 0) * np.ones(n, dtype=bool)
        return np.zeros(n, dtype=bool)

    def left_contact(self):
        return self._contact_per_env(self.left_finger)

    def right_contact(self):
        return self._contact_per_env(self.right_finger)

    def has_both_finger_contacts(self):
        return self.left_contact() & self.right_contact()

    def has_support_contact(self):
        if not self.support_entities:
            return np.zeros(self.n_envs, dtype=bool)
        out = np.zeros(self.n_envs, dtype=bool)
        for ent in self.support_entities:
            out = out | self._contact_per_env(ent)
        return out

    def grasp_success_now(self):
        return self.has_both_finger_contacts() & (~self.has_support_contact())

    def grasp_failure_now(self):
        return (~self.has_both_finger_contacts()) | self.has_support_contact()

    def finger_opening(self):
        """Mean of the robot's last two DOF positions, per env."""
        q = _to_np(self.robot.get_dofs_position())
        return q[:, -2:].mean(axis=1)

    # ---- weld add/remove ----------------------------------------------------
    # NOTE: Genesis' add_weld_constraint asserts that the (link1, link2) pair
    # is not already present in ANY env (global check, not per-env). To work
    # around this we attach the weld to ALL envs once on first acquisition,
    # and detach from ALL envs once any env loses grasp persistently.
    def _weld_already_present(self):
        """Check if a weld between link_obj and link_ee already exists anywhere."""
        try:
            info = self.rigid.get_weld_constraints(as_tensor=True, to_torch=True)
            link_a = info["link_a"]
            link_b = info["link_b"]
            la = int(self.link_obj[0]); lb = int(self.link_ee[0])
            mask = (((link_a == la) | (link_b == la)) &
                    ((link_a == lb) | (link_b == lb)))
            return bool(mask.any().item()) if hasattr(mask, "any") else bool(np.asarray(mask).any())
        except Exception:
            # Best effort: treat an unavailable query as "no weld present".
            return False

    def _reset_counters(self, env_idx=slice(None)):
        """Zero both debounce counters for ``env_idx`` (all envs by default)."""
        self._acquire_counter[env_idx] = 0
        self._release_counter[env_idx] = 0

    def _attach_envs(self, env_idx):
        """Mark ``env_idx`` as attached, adding the (global) weld if needed."""
        if env_idx.size == 0:
            return
        if self.active.any() or self._weld_already_present():
            # Already welded globally; just mark these envs as active.
            self.active[env_idx] = True
            self._reset_counters(env_idx)
            return
        try:
            self.rigid.add_weld_constraint(self.link_obj, self.link_ee)
        except Exception as e:  # includes Genesis' AssertionError
            # If the weld already exists in the solver (e.g., due to a stale
            # constraint that survived a prior detach), just mark as attached.
            print(f"[WELD] add_weld_constraint skipped ({type(e).__name__}: {e})")
        self.active[:] = True
        self._reset_counters()

    def _detach_envs(self, env_idx):
        """Remove the (global) weld and mark every env as detached."""
        if env_idx.size == 0:
            return
        if not (self.active.any() or self._weld_already_present()):
            self.active[:] = False
            return
        try:
            self.rigid.delete_weld_constraint(self.link_obj, self.link_ee)
        except Exception as e:  # includes Genesis' AssertionError
            print(f"[WELD] delete_weld_constraint skipped ({type(e).__name__}: {e})")
        self.active[:] = False
        self._reset_counters()

    def update(self):
        """Advance the debounce counters one frame and attach/detach as needed.

        Returns:
            dict with "attached" (copy of the per-env active flags) and the
            env indices attached ("attach_now") / detached ("detach_now") on
            this call.
        """
        success = self.grasp_success_now()
        failure = self.grasp_failure_now()

        # Each counter advances by exactly one per call, so attaching needs
        # ``acquire_steps`` consecutive success frames and detaching needs
        # ``release_steps`` consecutive failure frames.
        inactive = ~self.active
        self._acquire_counter[inactive & success] += 1
        self._acquire_counter[inactive & ~success] = 0

        self._release_counter[self.active & failure] += 1
        self._release_counter[self.active & ~failure] = 0

        attach_envs = np.where(inactive & (self._acquire_counter >= self.acquire_steps))[0].astype(np.int32)
        self._attach_envs(attach_envs)

        detach_envs = np.where(self.active & (self._release_counter >= self.release_steps))[0].astype(np.int32)
        self._detach_envs(detach_envs)

        return {
            "attached": self.active.copy(),
            "attach_now": attach_envs,
            "detach_now": detach_envs,
        }


# ---------------------------------------------------------------------------
# Parallel task
# ---------------------------------------------------------------------------

@dataclass
class TaskConfig:
    """Settings for ParallelPipetteTask: simulation, output and per-env randomization."""
    n_envs: int = 4                      # number of parallel envs passed to scene.build(n_envs=...)
    dt: float = 5e-2                     # SimOptions.dt
    substeps: int = 100                  # SimOptions.substeps
    env_spacing: tuple = (2.0, 2.0)      # passed to scene.build(env_spacing=...)
    base_dir: Path = Path("results/pipette/dataset_parallel")  # presumably the output root — usage not visible here
    episode_prefix: str = "episode_"     # presumably prefixes per-episode output names — TODO confirm
    show_viewer: bool = False            # gs.Scene(show_viewer=...)
    record_video: bool = False           # when True, build_scene adds a 640x480 camera

    # Randomization ranges
    pipette_y_range: tuple = (-0.1, 0.1)       # uniform (low, high) for the per-env pipette y position
    liquid_size_z_range: tuple = (0.15, 0.35)  # uniform (low, high) liquid column height; the high end sizes the built box


class ParallelPipetteTask:
    def __init__(self, cfg: TaskConfig):
        """Store ``cfg``, initialise Genesis on the GPU backend and allocate
        empty per-env recording buffers.

        Scene/entity handles start as ``None``; build_scene() fills them in
        (``cam`` only when ``record_video`` is set).
        """
        self.cfg = cfg
        gs.init(backend=gs.gpu)

        for handle in (
            "scene", "cam", "table", "tube_rack", "tube", "pipette",
            "pipette_rack", "liquid", "franka_left", "franka_right",
            "left_ee", "right_ee", "pipette_tip_link", "attachment_checker",
        ):
            setattr(self, handle, None)

        # DOF indices 0-6 are treated as arm motors, 7-8 as fingers.
        self.motors_dof = np.arange(7)
        self.fingers_dof = np.arange(7, 9)
        self.active_envs = np.arange(cfg.n_envs, dtype=np.int32)

        # Per-env recording buffers: for each stream, one list of frames per env.
        stream_names = (
            "left_state",
            "left_action",
            "left_cartesian",
            "right_state",
            "right_action",
            "right_cartesian",
            "pipette_pose",        # per frame: [pos3, quat4]
            "pipette_tip_pos",
            "tube_pose",
            "liquid_particles",
        )
        self.buf = {name: [[] for _ in range(cfg.n_envs)] for name in stream_names}
        self.sim_t = 0.0

    # ----- scene construction ------------------------------------------------
    def build_scene(self):
        n = self.cfg.n_envs

        # Per-env randomization via torch.rand (each run is genuinely random;
        # torch's default RNG is seeded from OS entropy unless the user sets it).
        def _u(low, high, size):
            return (torch.rand(size) * (high - low) + low).cpu().numpy().astype(np.float32)

        # Pipette: random y in [-0.1, 0.1]
        self.pipette_y_rand = _u(self.cfg.pipette_y_range[0],
                                 self.cfg.pipette_y_range[1], (n,))

        # Tube: discrete slot from the 10 rack holes
        slot_idx = torch.randint(low=0, high=TUBE_SLOT_LOCAL_XY.shape[0], size=(n,)).cpu().numpy()
        self.tube_xy_rand = TUBE_SLOT_LOCAL_XY[slot_idx].copy()      # [n,2]
        self.tube_slot_idx = slot_idx.astype(np.int32)

        # Liquid column height
        self.liquid_size_z_rand = _u(self.cfg.liquid_size_z_range[0],
                                     self.cfg.liquid_size_z_range[1], (n,))
        self.down_height_rand = down_height_from_size_z(self.liquid_size_z_rand)

        print(f"[rand] n_envs={n}")
        print(f"  pipette_y:      {self.pipette_y_rand.tolist()}")
        print(f"  tube_slot_idx:  {self.tube_slot_idx.tolist()}")
        print(f"  tube_xy:\n{self.tube_xy_rand}")
        print(f"  liquid_size_z:  {self.liquid_size_z_rand.tolist()}")

        # Liquid is built with MAX size_z; per-env z-size variation is approximated
        # by shifting particle heights after build.
        liquid_size_z_max = float(self.cfg.liquid_size_z_range[1])

        self.scene = gs.Scene(
            sim_options=gs.options.SimOptions(dt=self.cfg.dt, substeps=self.cfg.substeps),
            sph_options=gs.options.SPHOptions(
                lower_bound=(-0.5, -0.5, 0.0),
                upper_bound=(0.8, 0.5, 2.0),
                particle_size=0.005,
            ),
            viewer_options=gs.options.ViewerOptions(
                camera_pos=(2.1, 0.0, 1.5),
                camera_lookat=(0.0, 0.0, 0.8),
                camera_fov=30,
                max_FPS=60,
            ),
            show_viewer=self.cfg.show_viewer,
        )

        if self.cfg.record_video:
            self.cam = self.scene.add_camera(
                res=(640, 480), pos=(2.5, 0.0, 2.0),
                lookat=(0.0, 0.0, 1.0), fov=30, GUI=False,
            )

        self.scene.add_entity(gs.morphs.Plane())
        self.table = Table(self.scene)
        self.tube_rack = Tube_Rack(self.scene)

        # Nominal (pre-randomization) positions used to compute per-env offsets.
        self._tube_nominal_xy = np.array([0.136, -0.018], dtype=np.float32)
        self._tube_nominal_z = 0.83033
        self._pipette_nominal = np.array([-0.22, 0.0, 1.08], dtype=np.float32)
        self._liquid_nominal_pos = np.array([-0.172, 0.018, 1.1], dtype=np.float32)

        self.tube = Tube(self.scene,
                         pos=tuple(self._tube_nominal_xy.tolist()) + (self._tube_nominal_z,),
                         euler=(0, 0, 0))
        self.pipette = Pipette(self.scene,
                               pos=tuple(self._pipette_nominal.tolist()),
                               euler=(0, 10, 180))
        self.pipette_rack = Pipette_Rack(self.scene)

        self.liquid = self.scene.add_entity(
            material=gs.materials.SPH.Liquid(
                mu=0.001, gamma=0.01, stiffness=50000.0, sampler="regular",
            ),
            morph=gs.morphs.Box(
                pos=tuple(self._liquid_nominal_pos.tolist()),
                size=(0.016, 0.016, liquid_size_z_max),
            ),
            surface=gs.surfaces.Default(color=(0.4, 0.8, 1.0, 1.0), vis_mode='particle'),
        )

        self.franka_left = self.scene.add_entity(
            gs.morphs.MJCF(file='xml/franka_emika_panda/panda.xml',
                           pos=(0.0, -0.5, 0.824), euler=(0, 0, 90))
        )
        self.franka_right = self.scene.add_entity(
            gs.morphs.MJCF(file='xml/franka_emika_panda/panda.xml',
                           pos=(0.0, 0.5, 0.824), euler=(0, 0, -90))
        )

        self.scene.build(n_envs=n, env_spacing=self.cfg.env_spacing)

        # ---- Apply per-env randomization ----
        envs_all = torch.arange(n, device=gs.device, dtype=torch.long)

        # Pipette: only y is randomized (x/z/yaw fixed at nominal)
        pipette_pos = np.tile(self._pipette_nominal, (n, 1))
        pipette_pos[:, 1] = self.pipette_y_rand
        self.pipette.entity.set_pos(
            torch.as_tensor(pipette_pos, device=gs.device, dtype=gs.tc_float),
            envs_idx=envs_all,
        )

        # Tube: random slot xy, z fixed to 0.854
        tube_pos = np.zeros((n, 3), dtype=np.float32)
        tube_pos[:, 0:2] = self.tube_xy_rand
        tube_pos[:, 2] = 0.83033
        self.tube.entity.set_pos(
            torch.as_tensor(tube_pos, device=gs.device, dtype=gs.tc_float),
            envs_idx=envs_all,
        )

        # Liquid: xy follows tube xy, z=1.1; z-size variation via particle displacement
        # Shift all particles per-env: dx,dy = tube_xy - nominal_xy, dz = 1.1 - nominal_z.
        dxy = self.tube_xy_rand - self._liquid_nominal_pos[None, 0:2]
        dz = 1.1 - self._liquid_nominal_pos[2]
        offset = np.concatenate([dxy, np.full((n, 1), dz, dtype=np.float32)], axis=1)  # [n,3]

        p0 = self.liquid.get_particles_pos()[:, : self.liquid.n_particles, :]           # [n,P,3]
        p0_np = _to_np(p0)
        # Per-env: keep only particles whose local z (relative to column bottom) fits size_z_env;
        # displace the rest far below the floor so they don't interact with the scene.
        col_bottom_z = self._liquid_nominal_pos[2] - liquid_size_z_max / 2.0
        local_z = p0_np[..., 2] - col_bottom_z                                         # [n,P]
        keep_mask = local_z <= self.liquid_size_z_rand[:, None]                         # [n,P]

        new_p = p0_np + offset[:, None, :]
        # Hide unused particles by sending them far below ground within their env bounds
        hide_z = -5.0
        new_p[..., 2] = np.where(keep_mask, new_p[..., 2], hide_z)

        self.liquid.set_particles_pos(
            torch.as_tensor(new_p, device=gs.device, dtype=gs.tc_float)
        )

        # Arm controller setup
        for arm in (self.franka_left, self.franka_right):
            arm.set_dofs_kp(np.array([4500, 4500, 3500, 3500, 2000, 2000, 2000, 100, 100]))
            arm.set_dofs_kv(np.array([450, 450, 350, 350, 200, 200, 200, 10, 10]))
            arm.set_dofs_force_range(
                np.array([-87, -87, -87, -87, -12, -12, -12, -100, -100]),
                np.array([ 87,  87,  87,  87,  12,  12,  12,  100,  100]),
            )
            home_qpos = np.array([0.0, -0.785, 0.0, -2.356, 0.0, 1.571, 0.785, 0.04, 0.04])
            arm.set_dofs_position(home_qpos)
            arm.control_dofs_position(home_qpos)

        self.left_ee = self.franka_left.get_link('hand')
        self.right_ee = self.franka_right.get_link('hand')
        self.pipette_tip_link = self.pipette.entity.get_link('tip/pipette_tip')

        # Cache per-env world offsets so we can convert recorded poses back to
        # each env's local frame (matches the single-env replay used by render_pipette.py).
        try:
            envs_offset = _to_np(self.scene.envs_offset)  # [n,3]
        except Exception:
            try:
                envs_offset = _to_np(self.scene.get_envs_offset())
            except Exception:
                # Fallback: derive from arm base position vs nominal (0, -0.5, 0.824).
                left_base = _to_np(self.franka_left.get_pos())
                if left_base.ndim == 1:
                    left_base = np.tile(left_base[None, :], (self.cfg.n_envs, 1))
                envs_offset = left_base - np.array([0.0, -0.5, 0.824], dtype=np.float32)
        self.envs_offset = envs_offset.astype(np.float32)

        left_finger_link = self.franka_left.get_link('left_finger')
        right_finger_link = self.franka_left.get_link('right_finger')
        self.attachment_checker = WeldAttachmentChecker(
            scene=self.scene,
            robot_entity=self.franka_left,
            ee_link=self.left_ee,
            left_finger_link=left_finger_link,
            right_finger_link=right_finger_link,
            object_entity=self.pipette.entity,
            support_entities=[self.pipette_rack.entity],
            acquire_steps=2,
            release_steps=3,
            n_envs=self.cfg.n_envs,
        )

        # Cache scene init configuration (constants + per-env initial positions
        # of every entity and camera) so it can be saved verbatim into init.json.
        cam_agent_cfg = {
            "res": [256, 256], "pos": [2.5, 0.0, 2.0],
            "lookat": [0.0, 0.0, 1.0], "fov": 30,
        }
        cam_wrist_cfg = {
            "res": [256, 256], "pos": [0.0, 0.0, 0.0],
            "lookat": [1.0, 0.0, 0.0], "fov": 70,
            "attach_link": "hand",
            "offset_translation_xyz": [-0.08, 0.0, -0.035],
            "offset_rotation_euler_deg": {
                "yaw_z": 90.0,
                "pitch_y": 180.0,
                "roll_x": -10.0
            },
            "offset_rotation_composition": "Rz @ Ry @ Rx"
        }
        self.scene_init = {
            "n_envs": int(n),
            "env_spacing_xy_m": list(self.cfg.env_spacing),
            "envs_offset_xyz_m": self.envs_offset.tolist(),
            "fixed_entities": {
                "plane": {"pos": [0.0, 0.0, 0.0]},
                "table": {"file": "asset/model/misc/simple_table.xml",
                          "pos": [0.0, 0.0, 0.0], "euler_deg": [0, 0, 90]},
                "tube_rack": {"file": "asset/model/object/centrifuge_10slot.gen.xml",
                              "pos": [0.1, 0.0, 0.854], "euler_deg": [0, 0, 0]},
                "pipette_rack": {"file": "asset/model/object/pipette_rack.gen.xml",
                                 "pos": [-0.27, 0.0, 0.824], "euler_deg": [0, 0, 90]},
                "franka_left":  {"file": "xml/franka_emika_panda/panda.xml",
                                 "pos": [0.0, -0.5, 0.824], "euler_deg": [0, 0, 90]},
                "franka_right": {"file": "xml/franka_emika_panda/panda.xml",
                                 "pos": [0.0, 0.5, 0.824], "euler_deg": [0, 0, -90]},
            },
            "robot_home_qpos_9": [0.0, -0.785, 0.0, -2.356, 0.0, 1.571, 0.785, 0.04, 0.04],
            "robot_finger_open_m": float(self._FINGER_OPEN_M),
            "cameras": {
                "agentview": cam_agent_cfg,
                "wrist_camera": {
                    **cam_wrist_cfg,
                    "attach_link": "franka_left/hand",
                    "output": "wrist.mp4",
                },
                "wrist_camera_2": {
                    **cam_wrist_cfg,
                    "attach_link": "franka_right/hand",
                    "output": "wrist_2.mp4",
                },
            },
            "per_env": [
                {
                    "env_idx": int(i),
                    "envs_offset_xyz_m": self.envs_offset[i].tolist(),
                    "pipette": {
                        "file": "asset/model/object/pipette_add_stiffness.gen.xml",
                        "pos_local": [float(self._pipette_nominal[0]),
                                       float(self.pipette_y_rand[i]),
                                       float(self._pipette_nominal[2])],
                        "euler_deg": [0, 10, 180],
                    },
                    "tube": {
                        "file": "asset/model/object/centrifuge_50ml_collision.gen.xml",
                        "pos_local": [float(self.tube_xy_rand[i, 0]),
                                       float(self.tube_xy_rand[i, 1]),
                                       float(self._tube_nominal_z)],
                        "euler_deg": [0, 0, 0],
                        "slot_idx": int(self.tube_slot_idx[i]),
                    },
                    "liquid": {
                        "morph": "box",
                        "pos_local": [float(self.tube_xy_rand[i, 0]),
                                       float(self.tube_xy_rand[i, 1]),
                                       1.1],
                        "size": [0.016, 0.016, float(self.liquid_size_z_rand[i])],
                        "size_z_max_built": liquid_size_z_max,
                    },
                    "down_height_m": float(self.down_height_rand[i]),
                }
                for i in range(n)
            ],
        }

        if self.cam is not None:
            self.cam.start_recording()

    # ----- stepping and recording -------------------------------------------
    # Franka Panda finger travel per side: 0 (closed) .. 0.04 m (open), so the
    # mean of the two finger joints also lies in [0, 0.04]. Recorded
    # state/action gripper components are normalized to [0, 1] (DROID-style)
    # by dividing by this value (see _norm_grip / _norm_grip_arr).
    _FINGER_OPEN_M = 0.04

    def _norm_grip(self, m):
        """Return opening ``m`` (meters) as a Python float fraction of full open, clipped to [0, 1]."""
        ratio = m / self._FINGER_OPEN_M
        clipped = np.clip(ratio, 0.0, 1.0)
        return float(clipped)

    def _norm_grip_arr(self, m_arr):
        """Vectorized ``_norm_grip``: float32 array of openings / full open, clipped to [0, 1].

        The input shape is preserved.
        """
        ratio = np.asarray(m_arr, dtype=np.float32) / self._FINGER_OPEN_M
        return np.clip(ratio, 0.0, 1.0)

    def _record_step(self, left_action_np=None, right_action_np=None,
                     left_grip_cmd=None, right_grip_cmd=None):
        """Append one frame of per-env observations and actions to ``self.buf``.

        Args:
            left_action_np / right_action_np: [n, 7] commanded joint targets for
                the next frame; ``None`` records the current arm joint positions.
            left_grip_cmd / right_grip_cmd: commanded gripper opening (meters)
                for the next frame, one value per env or a single value that is
                broadcast to every env; ``None`` records the current gripper state.

        Gripper components of state and action are normalized to [0, 1] with
        ``_norm_grip_arr``; state/action rows are [7 joints, gripper].
        """
        n = self.cfg.n_envs

        def _split_state(qpos):
            # Gripper state = mean of the two finger joints (cols 7:9), normalized.
            grip = self._norm_grip_arr(qpos[:, 7:9].mean(axis=1, keepdims=True))
            return np.concatenate([qpos[:, :7], grip], axis=1), grip

        def _joint_action(cmd, qpos):
            if cmd is None:
                return qpos[:, :7]
            return np.asarray(cmd, dtype=np.float32)

        def _grip_action(cmd, current_grip):
            if cmd is None:
                return current_grip
            flat = np.asarray(cmd, dtype=np.float32).reshape(-1)
            if flat.size == 1:
                # A single command applies to every env.
                flat = np.full((n,), float(flat[0]), dtype=np.float32)
            return self._norm_grip_arr(flat).reshape(-1, 1)

        left_qpos = _to_np(self.franka_left.get_qpos())   # [n, 9]
        right_qpos = _to_np(self.franka_right.get_qpos())
        left_state, left_grip_now = _split_state(left_qpos)     # [n, 8]
        right_state, right_grip_now = _split_state(right_qpos)
        left_cart = pose_to_cartesian6(_to_np(self.left_ee.get_pos()), _to_np(self.left_ee.get_quat()), self.envs_offset)
        right_cart = pose_to_cartesian6(_to_np(self.right_ee.get_pos()), _to_np(self.right_ee.get_quat()), self.envs_offset)

        left_action_full = np.concatenate(
            [_joint_action(left_action_np, left_qpos),
             _grip_action(left_grip_cmd, left_grip_now)], axis=1)
        right_action_full = np.concatenate(
            [_joint_action(right_action_np, right_qpos),
             _grip_action(right_grip_cmd, right_grip_now)], axis=1)

        pipette_pos, pipette_quat = self.pipette.get_pose()
        tube_pos, tube_quat = self.tube.get_pose()
        tip_pos = _to_np(self.pipette_tip_link.get_pos())
        # Object poses are stored exactly as Genesis reports them (per-env local
        # coordinates); envs_offset is deliberately NOT subtracted so single-env
        # replay keeps non-zero envs in front of the camera.
        pipette_pose = np.concatenate([pipette_pos, pipette_quat], axis=1)
        tube_pose = np.concatenate([tube_pos, tube_quat], axis=1)
        particles = _to_np(self.liquid.get_particles_pos())[:, :self.liquid.n_particles, :]

        frame = {
            "left_state": left_state,
            "right_state": right_state,
            "left_action": left_action_full,
            "right_action": right_action_full,
            "left_cartesian": left_cart,
            "right_cartesian": right_cart,
            "pipette_pose": pipette_pose,
            "pipette_tip_pos": tip_pos,
            "tube_pose": tube_pose,
            "liquid_particles": particles,
        }
        for env_i in range(n):
            for key, values in frame.items():
                self.buf[key][env_i].append(values[env_i])

    def _step(self, left_action=None, right_action=None,
              left_grip_cmd=None, right_grip_cmd=None):
        """Advance the simulation one step, then update, render and record.

        After ``scene.step()`` and advancing ``sim_t`` by ``cfg.dt``, the optional
        attachment checker is updated, the optional camera renders a frame, and
        the frame is recorded via ``_record_step`` with the given commands.
        """
        self.scene.step()
        self.sim_t += self.cfg.dt

        checker = self.attachment_checker
        if checker is not None:
            checker.update()

        camera = self.cam
        if camera is not None:
            camera.render()

        self._record_step(left_action, right_action, left_grip_cmd, right_grip_cmd)

    # ----- parallel motion primitives ---------------------------------------
    def _ik(self, arm, ee_link, pos, quat):
        """Solve batched IK for ``ee_link`` on ``arm``.

        ``pos`` [n,3] and ``quat`` [n,4] (wxyz) are converted to tensors on the
        Genesis device with Genesis' float dtype; the solver's qpos [n, dof]
        result is returned unchanged.
        """
        def _tensor(values):
            return torch.as_tensor(values, device=gs.device, dtype=gs.tc_float)

        pos_t = _tensor(pos)
        quat_t = _tensor(quat)
        return arm.inverse_kinematics(link=ee_link, pos=pos_t, quat=quat_t)

    def _apply_arm(self, arm, q_cmd, grip_force, grip_open_m=None):
        """Send a position command to ``arm`` and return the gripper opening used.

        Args:
            arm: robot entity receiving ``control_dofs_position``.
            q_cmd: joint command, [dof] or [n, dof]; a 1-D command is treated as
                a single row.
            grip_force: gripper convention of ``_grip_open_for_force``: negative
                means close, zero or positive means open, ``None`` means no
                gripper command.
            grip_open_m: optional explicit opening in meters; when given it
                overrides the value derived from ``grip_force``.

        When the command has at least 9 columns and an opening is known, both
        finger columns (7:9) are overwritten with it and all dofs are commanded.
        Otherwise only the first 7 arm joints are commanded (on ``self.motors_dof``).

        Returns:
            The gripper opening in meters, or ``None``.
        """
        opening = grip_open_m if grip_open_m is not None else self._grip_open_for_force(grip_force)

        target = np.array(q_cmd, dtype=np.float32)
        if target.ndim == 1:
            target = target[np.newaxis, :]

        drives_fingers = target.shape[1] >= 9 and opening is not None
        if not drives_fingers:
            arm.control_dofs_position(target[:, :7], self.motors_dof)
            return opening

        target[:, 7:9] = float(opening)
        arm.control_dofs_position(target)
        return opening

    def _grip_open_for_force(self, grip_force):
        """Translate this script's gripper "force" into an opening in meters.

        Negative values mean closed (0.0 m); zero or positive values mean fully
        open (``_FINGER_OPEN_M``); ``None`` passes through as ``None``.
        """
        if grip_force is not None:
            return 0.0 if grip_force < 0 else self._FINGER_OPEN_M
        return None

    def ik_move_to_pose(self, arm, ee_link, target_pos, target_quat,
                        n_interp=120, gripper_force=2.0, n_steps=80, record_arm='left'):
        """Drive ``arm`` to an end-effector pose by interpolating in joint space.

        IK is solved once for the target pose (broadcast to all envs). The joint
        command is then blended linearly from the current qpos to that solution
        over ``n_interp`` recorded steps, followed by ``n_steps`` steps with no
        new command. ``gripper_force`` follows ``_grip_open_for_force``.
        ``record_arm`` ('left' or 'right') selects which arm's action columns
        receive the command; any other value records no commands.

        Returns:
            The IK joint solution as a numpy array.
        """
        n = self.cfg.n_envs
        goal_pos = broadcast_pos(target_pos, n)
        goal_quat = broadcast_quat(target_quat, n)

        q_goal = _to_np(self._ik(arm, ee_link, goal_pos, goal_quat))
        q_start = _to_np(arm.get_qpos())
        grip_open_m = self._grip_open_for_force(gripper_force)
        is_left = record_arm == 'left'
        is_right = record_arm == 'right'

        for k in range(1, n_interp + 1):
            alpha = k / n_interp
            q_cmd = (1.0 - alpha) * q_start + alpha * q_goal
            self._apply_arm(arm, q_cmd, gripper_force)
            arm_cmd = q_cmd[:, :7]
            self._step(
                arm_cmd if is_left else None,
                arm_cmd if is_right else None,
                grip_open_m if is_left else None,
                grip_open_m if is_right else None,
            )

        for _ in range(n_steps):
            self._step()
        return q_goal

    def grisp(self, arm, gripper_force=-10.0, n_steps=100, record_arm='left'):
        """Ramp the finger joints of ``arm`` to the opening implied by ``gripper_force``.

        (Name kept for callers; this is the grasp primitive.) Arm joints are
        held at the qpos read on entry while both finger joints are blended
        linearly to the target opening over ``max(1, n_steps)`` recorded steps.
        ``record_arm`` selects which arm's action columns get the command.
        """
        target_open = self._grip_open_for_force(gripper_force)
        q_from = _to_np(arm.get_qpos()).astype(np.float32)
        q_to = q_from.copy()
        # Only the finger columns change; arm joints keep their entry values.
        q_to[:, 7:9] = target_open

        total = max(1, int(n_steps))
        is_left = record_arm == 'left'
        is_right = record_arm == 'right'
        for k in range(1, total + 1):
            weight = k / total
            q_cmd = (1.0 - weight) * q_from + weight * q_to
            arm.control_dofs_position(q_cmd)
            joints = q_cmd[:, :7]
            self._step(
                joints if is_left else None,
                joints if is_right else None,
                target_open if is_left else None,
                target_open if is_right else None,
            )

    def _obj_pose_batched(self, obj):
        """Return ``(pos [n,3], quat [n,4] wxyz, rotations)`` for ``obj``.

        A 1-D (unbatched) position or quaternion reading is tiled across
        ``cfg.n_envs`` rows so callers can always index per env. Rotations are
        built from the quaternions with ``wxyz_to_R``.
        """
        pos = _to_np(obj.get_pos())
        quat = _to_np(obj.get_quat())
        if pos.ndim == 1:
            pos = np.tile(pos, (self.cfg.n_envs, 1))
        if quat.ndim == 1:
            quat = np.tile(quat, (self.cfg.n_envs, 1))
        return pos, quat, wxyz_to_R(quat)

    def move_to_pose_with_refinement(self, arm, ee_link,
                                     source_object=None,
                                     target_object=None,
                                     target_pos=None,
                                     target_quat=None,
                                     coarse_interp=12,
                                     coarse_force=2.0,
                                     coarse_steps=5,
                                     record_arm='left'):
        """Batched analog of pipette3.move_to_pose_with_refinement (coarse only).

        Two modes:

        * ``source_object`` given: the current source→EE transform is treated
          as fixed. The EE goal is chosen so that ``source_object`` reaches the
          target pose, which is ``target_object``'s current pose, or
          ``target_pos`` with optional ``target_quat`` (wxyz). When
          ``target_quat`` is omitted, the source object's current orientation
          is kept.
        * ``source_object`` is None: ``target_pos`` / ``target_quat`` are the
          EE goal directly.

        The goal is executed with ``ik_move_to_pose``. It uses
        ``coarse_interp`` joint-space interpolation steps and ``coarse_steps``
        settle steps, with the gripper command derived from ``coarse_force``.

        Raises:
            ValueError: if the targets required by the selected mode are missing.
        """
        n = self.cfg.n_envs
        ee_pos0 = _to_np(ee_link.get_pos())
        ee_quat0 = _to_np(ee_link.get_quat())
        ee_r0 = wxyz_to_R(ee_quat0)

        if source_object is not None:
            src_pos0, _, src_r0 = self._obj_pose_batched(source_object)
            # EE pose expressed in the source object's frame (kept constant).
            source_to_ee_pos = np.zeros((n, 3), dtype=np.float32)
            for i in range(n):
                source_to_ee_pos[i] = src_r0[i].inv().apply(ee_pos0[i] - src_pos0[i])
            source_to_ee_r = [src_r0[i].inv() * ee_r0[i] for i in range(n)]

            if target_object is not None:
                tgt_src_pos0, _, tgt_src_r0 = self._obj_pose_batched(target_object)
            else:
                if target_pos is None:
                    # "In source_object mode, target_object or target_pos must be provided."
                    raise ValueError("source_object 模式下，必须提供 target_object 或 target_pos")
                tgt_src_pos0 = broadcast_pos(target_pos, n)
                if target_quat is None:
                    tgt_src_r0 = src_r0
                else:
                    tgt_src_r0 = wxyz_to_R(broadcast_quat(target_quat, n))

            # Re-attach the EE to the source object placed at its target pose.
            coarse_pos = np.zeros((n, 3), dtype=np.float32)
            coarse_quat = np.zeros((n, 4), dtype=np.float32)
            for i in range(n):
                coarse_pos[i] = tgt_src_pos0[i] + tgt_src_r0[i].apply(source_to_ee_pos[i])
                coarse_r = tgt_src_r0[i] * source_to_ee_r[i]
                # SciPy quaternions are scalar-last (xyzw); store as wxyz.
                qx, qy, qz, qw = coarse_r.as_quat()
                coarse_quat[i] = [qw, qx, qy, qz]
        else:
            if target_pos is None or target_quat is None:
                # "In direct-EE mode, target_pos and target_quat must be provided."
                raise ValueError("直接 ee 模式下，必须提供 target_pos 和 target_quat")
            coarse_pos = broadcast_pos(target_pos, n)
            coarse_quat = broadcast_quat(target_quat, n)

        self.ik_move_to_pose(arm, ee_link, coarse_pos, coarse_quat,
                             n_interp=coarse_interp, gripper_force=coarse_force,
                             n_steps=coarse_steps, record_arm=record_arm)

    def move_two_arms_with_refinement(self, left_cfg, right_cfg):
        """Batched analog of pipette3.move_two_arms_with_refinement (coarse, synchronous).

        ``left_cfg`` / ``right_cfg`` are dicts with these optional keys:
        ``source_object``, ``target_object``, ``target_pos``, ``target_quat``
        (wxyz), ``coarse_interp`` (default 12), ``coarse_force`` (default 2.0)
        and ``coarse_steps`` (default 0).

        EE goal per arm:

        * with ``source_object``: the EE goal is the EE pose moved by the same
          rigid motion that takes the source object from its current pose to
          the target pose. The target pose is ``target_object``'s pose, or
          ``target_pos`` with optional ``target_quat``; a missing quat keeps
          the source orientation. The source→EE lever arm is rotated too, so
          the source object actually lands on the target position when the
          orientation changes. This matches ``move_to_pose_with_refinement``.
        * without it: ``target_pos`` / ``target_quat`` are the EE goal.

        Both arms then move together for ``max(coarse_interp)`` recorded steps.
        Each step lerps the EE position, slerps the orientation and re-solves
        IK. An arm with a shorter count holds its goal, and an arm with
        ``coarse_interp <= 0`` is not commanded. Finally ``max(coarse_steps)``
        settle steps run with no new command.

        Raises:
            ValueError: if a cfg lacks the targets its mode requires.
        """
        n = self.cfg.n_envs

        l_ee_pos0 = _to_np(self.left_ee.get_pos())
        l_ee_quat0 = _to_np(self.left_ee.get_quat())
        l_ee_r0 = wxyz_to_R(l_ee_quat0)
        r_ee_pos0 = _to_np(self.right_ee.get_pos())
        r_ee_quat0 = _to_np(self.right_ee.get_quat())
        r_ee_r0 = wxyz_to_R(r_ee_quat0)

        def _compute_goal(ee_pos0, ee_r0_batch, cfg):
            """Return the (goal_pos [n,3], goal_quat [n,4] wxyz) for one arm."""
            source_object = cfg.get("source_object", None)
            target_object = cfg.get("target_object", None)
            target_pos = cfg.get("target_pos", None)
            target_quat = cfg.get("target_quat", None)

            if source_object is None:
                if target_pos is None or target_quat is None:
                    # "In direct-EE mode, target_pos and target_quat must be provided."
                    raise ValueError("直接 ee 模式下，必须提供 target_pos 和 target_quat")
                return broadcast_pos(target_pos, n), broadcast_quat(target_quat, n)

            src_pos0, _, src_r0 = self._obj_pose_batched(source_object)
            if target_object is not None:
                tgt_src_pos0, _, tgt_src_r0 = self._obj_pose_batched(target_object)
            else:
                if target_pos is None:
                    # "In source_object mode, target_object or target_pos must be provided."
                    raise ValueError("source_object 模式下，必须提供 target_object 或 target_pos")
                tgt_src_pos0 = broadcast_pos(target_pos, n)
                if target_quat is None:
                    tgt_src_r0 = src_r0
                else:
                    tgt_src_r0 = wxyz_to_R(broadcast_quat(target_quat, n))

            goal_pos = np.zeros((n, 3), dtype=np.float32)
            goal_quat = np.zeros((n, 4), dtype=np.float32)
            for i in range(n):
                # Rigid rotation taking the source from its current to its target orientation.
                delta_r = tgt_src_r0[i] * src_r0[i].inv()
                # Rotate the source->EE lever arm as well. A pure translation
                # (ee + tgt - src) would leave the source off-target by
                # (delta_r - I)(src - ee) whenever the orientation changes.
                goal_pos[i] = tgt_src_pos0[i] + delta_r.apply(ee_pos0[i] - src_pos0[i])
                # SciPy quaternions are scalar-last (xyzw); store as wxyz.
                qx, qy, qz, qw = (delta_r * ee_r0_batch[i]).as_quat()
                goal_quat[i] = [qw, qx, qy, qz]
            return goal_pos, goal_quat

        l_goal_pos, l_goal_quat = _compute_goal(l_ee_pos0, l_ee_r0, left_cfg)
        r_goal_pos, r_goal_quat = _compute_goal(r_ee_pos0, r_ee_r0, right_cfg)

        l_interp = int(left_cfg.get("coarse_interp", 12))
        r_interp = int(right_cfg.get("coarse_interp", 12))
        l_force = float(left_cfg.get("coarse_force", 2.0))
        r_force = float(right_cfg.get("coarse_force", 2.0))
        l_grip_open = self._grip_open_for_force(l_force)
        r_grip_open = self._grip_open_for_force(r_force)
        coarse_iters = max(l_interp, r_interp)

        for t in range(1, coarse_iters + 1):
            if l_interp > 0:
                # Clamp so an arm with fewer steps holds its goal.
                la = min(t / l_interp, 1.0)
                lp = (1 - la) * l_ee_pos0 + la * l_goal_pos
                lq = slerp_batch(l_ee_quat0, l_goal_quat, la)
                lq_goal = _to_np(self._ik(self.franka_left, self.left_ee, lp, lq))
                self._apply_arm(self.franka_left, lq_goal, l_force)
            else:
                lq_goal = None
            if r_interp > 0:
                ra = min(t / r_interp, 1.0)
                rp = (1 - ra) * r_ee_pos0 + ra * r_goal_pos
                rq = slerp_batch(r_ee_quat0, r_goal_quat, ra)
                rq_goal = _to_np(self._ik(self.franka_right, self.right_ee, rp, rq))
                self._apply_arm(self.franka_right, rq_goal, r_force)
            else:
                rq_goal = None
            self._step(
                lq_goal[:, :7] if lq_goal is not None else None,
                rq_goal[:, :7] if rq_goal is not None else None,
                l_grip_open, r_grip_open,
            )

        coarse_settle = int(max(left_cfg.get("coarse_steps", 0),
                                right_cfg.get("coarse_steps", 0)))
        for _ in range(coarse_settle):
            self._step()

    def move_two_arms_coarse(self, left_cfg, right_cfg, n_interp=120,
                             gripper_force_l=-10.0, gripper_force_r=-10.0,
                             settle_steps=60):
        """Move both arms synchronously along straight Cartesian paths.

        Each EE position is lerped and its orientation slerped from the current
        pose to ``cfg['target_pos']`` / ``cfg['target_quat']`` (wxyz) over
        ``n_interp`` recorded steps, with IK re-solved each step. Then
        ``settle_steps`` steps run with no new command. Gripper forces follow
        ``_grip_open_for_force``.
        """
        n = self.cfg.n_envs

        left_pos0 = _to_np(self.left_ee.get_pos())
        left_quat0 = _to_np(self.left_ee.get_quat())
        right_pos0 = _to_np(self.right_ee.get_pos())
        right_quat0 = _to_np(self.right_ee.get_quat())

        left_goal_pos = broadcast_pos(left_cfg['target_pos'], n)
        left_goal_quat = broadcast_quat(left_cfg['target_quat'], n)
        right_goal_pos = broadcast_pos(right_cfg['target_pos'], n)
        right_goal_quat = broadcast_quat(right_cfg['target_quat'], n)
        left_open = self._grip_open_for_force(gripper_force_l)
        right_open = self._grip_open_for_force(gripper_force_r)

        for k in range(1, n_interp + 1):
            a = k / n_interp
            left_waypoint = (1 - a) * left_pos0 + a * left_goal_pos
            right_waypoint = (1 - a) * right_pos0 + a * right_goal_pos
            left_rot = slerp_batch(left_quat0, left_goal_quat, a)
            right_rot = slerp_batch(right_quat0, right_goal_quat, a)

            left_q = _to_np(self._ik(self.franka_left, self.left_ee, left_waypoint, left_rot))
            right_q = _to_np(self._ik(self.franka_right, self.right_ee, right_waypoint, right_rot))

            self._apply_arm(self.franka_left, left_q, gripper_force_l)
            self._apply_arm(self.franka_right, right_q, gripper_force_r)
            self._step(left_q[:, :7], right_q[:, :7], left_open, right_open)

        for _ in range(settle_steps):
            self._step()

    def move_two_arms_vertical_z(self, left_dz=0.0, right_dz=0.0,
                                 n_steps=100, n_interp=100,
                                 left_gripper_force=-5.0, right_gripper_force=-5.0):
        """Translate both end effectors along z while holding their orientation.

        Args:
            left_dz, right_dz: z displacement in meters. A scalar applies to all
                envs; otherwise one value per env.
            n_steps: number of recorded interpolation steps. IK is re-solved
                each step against the EE orientation captured on entry.
            n_interp: number of settle steps run afterwards with no new command.
                Note that these two names are used the other way round from
                ``ik_move_to_pose``.
            left_gripper_force, right_gripper_force: gripper convention of
                ``_grip_open_for_force``.
        """
        n = self.cfg.n_envs
        left_start = _to_np(self.left_ee.get_pos())
        right_start = _to_np(self.right_ee.get_pos())
        left_quat = _to_np(self.left_ee.get_quat())
        right_quat = _to_np(self.right_ee.get_quat())

        def _as_per_env(dz):
            # Scalars become a float32 vector of length n; arrays pass through.
            if np.isscalar(dz):
                return np.full(n, float(dz), dtype=np.float32)
            return dz

        left_goal = left_start.copy()
        left_goal[:, 2] += _as_per_env(left_dz)
        right_goal = right_start.copy()
        right_goal[:, 2] += _as_per_env(right_dz)

        def _solve(arm, link, pos, quat):
            return _to_np(arm.inverse_kinematics(
                link=link,
                pos=torch.as_tensor(pos, device=gs.device, dtype=gs.tc_float),
                quat=torch.as_tensor(quat, device=gs.device, dtype=gs.tc_float),
                rot_mask=[True, True, True],
            ))

        left_open = self._grip_open_for_force(left_gripper_force)
        right_open = self._grip_open_for_force(right_gripper_force)
        for k in range(1, n_steps + 1):
            a = k / n_steps
            left_q = _solve(self.franka_left, self.left_ee,
                            (1 - a) * left_start + a * left_goal, left_quat)
            right_q = _solve(self.franka_right, self.right_ee,
                             (1 - a) * right_start + a * right_goal, right_quat)
            self._apply_arm(self.franka_left, left_q, left_gripper_force)
            self._apply_arm(self.franka_right, right_q, right_gripper_force)
            self._step(left_q[:, :7], right_q[:, :7], left_open, right_open)

        for _ in range(n_interp):
            self._step()

    # ----- sequence ---------------------------------------------------------
    def run(self):
        """Build the scene and execute the scripted pipette episode in every env.

        All envs move in lock-step through the batched motion primitives:
          1-3. Left arm goes to ``pipette.get_grasp_pose()``, advances 3.7 cm in
               -x, closes the gripper and lifts 10 cm.
          4.   Left arm carries the pipette so the tip link sits at the tube's
               xy and at the pipette body's current z + 1 cm.
          5.   Right arm moves to an offset from the pipette button, then 9.5 mm
               lower, while the left arm re-targets the tube each time.
          6.   Both arms descend 12.5 cm.
          7.   Right arm rises 8.5 mm.
          8.   Right arm returns to the pose recorded right after settling.
          9.   Left arm rises 12.5 cm.
        Then the camera recording (if ``self.cam`` is set) is saved, and
        ``finalize_outputs`` writes the per-env episode files.
        """
        self.build_scene()

        # Settle: raw scene.step() calls bypass _step, so these 15 steps are
        # neither rendered nor recorded.
        for _ in range(15):
            self.scene.step()

        # Right EE pose after settling; Step 8 returns the right arm here.
        right_origin_pos = _to_np(self.right_ee.get_pos()).copy()
        right_origin_quat = _to_np(self.right_ee.get_quat()).copy()

        # Step 1: Move to the pipette (gripper open)
        s1_pos, s1_quat = self.pipette.get_grasp_pose()
        self.ik_move_to_pose(self.franka_left, self.left_ee, s1_pos, s1_quat,
                             n_interp=25, gripper_force=2.0, n_steps=5,
                             record_arm='left')

        # Step 2: Close in (3.7 cm along -x, same orientation) and grip
        s2_pos = s1_pos + np.array([-0.037, 0.0, 0.0], dtype=np.float32)
        s2_quat = s1_quat
        self.ik_move_to_pose(self.franka_left, self.left_ee, s2_pos, s2_quat,
                             n_interp=16, gripper_force=2.0, n_steps=0,
                             record_arm='left')

        # Negative force -> fingers ramp to fully closed over 10 recorded steps.
        self.grisp(self.franka_left, gripper_force=-10.0, n_steps=10)

        # Step 3: Lift the pipette 10 cm with a fixed (wxyz) EE orientation
        s3_pos = s2_pos + np.array([0.0, 0.0, 0.1], dtype=np.float32)
        s3_quat = np.tile(np.array([0.7071, 0, -0.7071, 0], dtype=np.float32),
                          (self.cfg.n_envs, 1))
        self.ik_move_to_pose(self.franka_left, self.left_ee, s3_pos, s3_quat,
                             n_interp=15, gripper_force=-10.0, n_steps=0,
                             record_arm='left')

        pipette_tip_link = self.pipette_tip_link

        # Step 4: Move pipette over the tube using source-binding refinement:
        # the EE goal is chosen so the tip link (source) reaches target_pos/quat.
        tube_pos, _ = self.tube.get_pose()
        pip_pos, _ = self.pipette.get_pose()
        target_pos = np.stack(
            [tube_pos[:, 0], tube_pos[:, 1], pip_pos[:, 2] + 0.01], axis=1
        ).astype(np.float32)
        target_quat = np.tile(np.array([0.7071, 0.0, 0.0, -0.7071], dtype=np.float32),
                              (self.cfg.n_envs, 1))
        self.move_to_pose_with_refinement(
            arm=self.franka_left, ee_link=self.left_ee,
            source_object=pipette_tip_link, target_object=None,
            target_pos=target_pos, target_quat=target_quat,
            coarse_interp=16, coarse_force=-10.0, coarse_steps=2,
            record_arm='left',
        )

        # Step 5: Press the button (right arm approaches).
        # 5a: right EE goes to the button position + (0, -0.018, 0.325) m while
        #     the left arm re-targets the tube from freshly read poses.
        pipette_button = self.pipette.entity.get_link('pipette_button')
        btn_pos = _to_np(pipette_button.get_pos())
        tube_pos_now, _ = self.tube.get_pose()
        pip_pos_now, _ = self.pipette.get_pose()
        left_target_pos = np.stack(
            [tube_pos_now[:, 0], tube_pos_now[:, 1], pip_pos_now[:, 2] + 0.01], axis=1
        ).astype(np.float32)
        left_target_quat = np.tile(np.array([0.7071, 0.0, 0.0, -0.7071], dtype=np.float32),
                                   (self.cfg.n_envs, 1))

        self.move_two_arms_with_refinement(
            left_cfg={
                "source_object": pipette_tip_link,
                "target_pos": left_target_pos,
                "target_quat": left_target_quat,
                "coarse_interp": 12,
                "coarse_force": -10.0,
                "coarse_steps": 3,
            },
            right_cfg={
                "source_object": None,
                "target_pos": btn_pos + np.array([0.00, -0.018, 0.325], dtype=np.float32),
                "target_quat": right_origin_quat,
                "coarse_interp": 12,
                "coarse_force": -10.0,
                "coarse_steps": 3,
            },
        )

        # 5b: right EE moves 9.5 mm down from where it ended up; the left arm
        #     again re-targets the tube using re-read poses.
        press_pos = _to_np(self.right_ee.get_pos()) - np.array([0, 0, 0.0095], dtype=np.float32)
        tube_pos_now, _ = self.tube.get_pose()
        pip_pos_now, _ = self.pipette.get_pose()
        left_target_pos = np.stack(
            [tube_pos_now[:, 0], tube_pos_now[:, 1], pip_pos_now[:, 2] + 0.01], axis=1
        ).astype(np.float32)
        self.move_two_arms_with_refinement(
            left_cfg={
                "source_object": pipette_tip_link,
                "target_pos": left_target_pos,
                "target_quat": left_target_quat,
                "coarse_interp": 12,
                "coarse_force": -10.0,
                "coarse_steps": 5,
            },
            right_cfg={
                "source_object": None,
                "target_pos": press_pos,
                "target_quat": right_origin_quat,
                "coarse_interp": 12,
                "coarse_force": -10.0,
                "coarse_steps": 5,
            },
        )

        # Step 6: Down both (12.5 cm; n_steps = motion steps, n_interp = settle steps)
        self.move_two_arms_vertical_z(
            left_dz=-0.125, right_dz=-0.125,
            n_steps=15, n_interp=3,
            left_gripper_force=-10.0, right_gripper_force=-10.0,
        )

        # Step 7: Right arm lift (slight, 8.5 mm); left arm holds
        self.move_two_arms_vertical_z(
            left_dz=0.0, right_dz=0.0085,
            n_steps=10, n_interp=4,
            left_gripper_force=-10.0, right_gripper_force=-10.0,
        )

        # Step 8: Right arm back to origin (pose captured after settling)
        self.ik_move_to_pose(
            self.franka_right, self.right_ee,
            right_origin_pos, right_origin_quat,
            n_interp=10, gripper_force=-10.0, n_steps=1,
            record_arm='right',
        )

        # Step 9: Left arm lift (12.5 cm); right arm holds
        self.move_two_arms_vertical_z(
            left_dz=0.125, right_dz=0.0,
            n_steps=10, n_interp=2,
            left_gripper_force=-10.0, right_gripper_force=-10.0,
        )

        if self.cam is not None:
            self.cam.stop_recording(save_to_filename='pick_up_pipette_parallel.mp4', fps=20)

        self.finalize_outputs()

    # ----- output -----------------------------------------------------------
    def finalize_outputs(self):
        self.cfg.base_dir.mkdir(parents=True, exist_ok=True)
        for env_i in range(self.cfg.n_envs):
            ep_dir = self.cfg.base_dir / f"{self.cfg.episode_prefix}{env_i:03d}"
            ep_dir.mkdir(parents=True, exist_ok=True)

            def stack(key):
                return np.asarray(self.buf[key][env_i], dtype=np.float32)

            ls = stack("left_state"); la = stack("left_action")
            rs = stack("right_state"); ra = stack("right_action")
            lc = stack("left_cartesian"); rc = stack("right_cartesian")
            pp = stack("pipette_pose"); tip = stack("pipette_tip_pos"); tp = stack("tube_pose")
            part = stack("liquid_particles")
            T = min(len(ls), len(la), len(lc), len(rs), len(ra), len(rc), len(pp), len(tip), len(tp), len(part))

            ls = ls[:T]; la = la[:T]; rs = rs[:T]; ra = ra[:T]
            lc = lc[:T]; rc = rc[:T]
            pp = pp[:T]; tip = tip[:T]; tp = tp[:T]; part = part[:T]

            # Action gripper at frame t = commanded opening intended for frame t+1.
            # Shift action gripper by one step so action[t][7] reflects the NEXT
            # frame's gripper command (last frame keeps its own command).
            la[:-1, 7] = la[1:, 7]
            ra[:-1, 7] = ra[1:, 7]

            left_joint_position = ls[:, :7].astype(np.float32)
            right_joint_position = rs[:, :7].astype(np.float32)
            left_gripper_position = ls[:, 7:8].astype(np.float32)
            right_gripper_position = rs[:, 7:8].astype(np.float32)
            left_cartesian_position = lc.astype(np.float32)
            right_cartesian_position = rc.astype(np.float32)

            left_joint_position_cmd = next_frame(left_joint_position)
            right_joint_position_cmd = next_frame(right_joint_position)
            left_cartesian_position_cmd = next_frame(left_cartesian_position)
            right_cartesian_position_cmd = next_frame(right_cartesian_position)
            left_joint_velocity_cmd = (left_joint_position_cmd - left_joint_position).astype(np.float32)
            right_joint_velocity_cmd = (right_joint_position_cmd - right_joint_position).astype(np.float32)
            left_cartesian_velocity_cmd = (left_cartesian_position_cmd - left_cartesian_position).astype(np.float32)
            right_cartesian_velocity_cmd = (right_cartesian_position_cmd - right_cartesian_position).astype(np.float32)
            left_gripper_action = binary_gripper(la[:, 7:8])
            right_gripper_action = binary_gripper(ra[:, 7:8])

            obs_cartesian_position = np.concatenate([left_cartesian_position, right_cartesian_position], axis=1).astype(np.float32)
            obs_joint_position = np.concatenate([left_joint_position, right_joint_position], axis=1).astype(np.float32)
            obs_gripper_position = np.concatenate([
                left_gripper_position, right_gripper_position
            ], axis=1).astype(np.float32) * self._FINGER_OPEN_M

            act_cartesian_position = np.concatenate([left_cartesian_position_cmd, right_cartesian_position_cmd], axis=1).astype(np.float32)
            act_joint_position = np.concatenate([left_joint_position_cmd, right_joint_position_cmd], axis=1).astype(np.float32)
            act_gripper_position = np.concatenate([left_gripper_action, right_gripper_action], axis=1).astype(np.float32)
            act_cartesian_velocity = np.concatenate([left_cartesian_velocity_cmd, right_cartesian_velocity_cmd], axis=1).astype(np.float32)
            act_joint_velocity = np.concatenate([left_joint_velocity_cmd, right_joint_velocity_cmd], axis=1).astype(np.float32)

            dones = np.zeros((T,), dtype=np.int8)
            rewards = np.zeros((T,), dtype=np.float32)
            success, success_info = evaluate_pipette_success(tip, tp, part)
            if T > 0:
                dones[-1] = 1
                rewards[-1] = 1.0 if success else 0.0
            language_instruction = (
                "Pick up the pipette with the left arm, press the pipette button "
                "with the right arm, and dispense liquid into the centrifuge tube."
            )

            # data.h5 (full payload, aligned with insert_tube_into_rack_command.py)
            comp = {"compression": "gzip", "compression_opts": 4}
            with h5py.File(ep_dir / "data.h5", "w") as f:
                f.attrs["language_instruction"] = language_instruction
                f.attrs["dt"] = self.cfg.dt
                f.attrs["substeps"] = self.cfg.substeps
                f.attrs["task"] = "pipette"
                f.attrs["success"] = int(success)
                f.attrs["total"] = int(T)
                f.attrs["rotation"] = "Euler XYZ radians"
                f.attrs["gripper_units"] = "observation gripper in meters; action gripper is binary 0=closed, 1=open."

                data_grp = f.create_group("data")
                demo = data_grp.create_group("demo_0")
                demo.attrs["num_samples"] = int(T)
                demo.attrs["language_instruction"] = language_instruction
                demo.attrs["success"] = int(success)
                for k, v in success_info.items():
                    demo.attrs[f"success_{k}"] = v

                obs = demo.create_group("obs")
                obs.create_dataset("cartesian_position", data=obs_cartesian_position, **comp)
                obs.create_dataset("joint_position", data=obs_joint_position, **comp)
                obs.create_dataset("gripper_position", data=obs_gripper_position, **comp)
                obs.create_dataset("pipette_pose", data=pp, **comp)
                obs.create_dataset("pipette_tip_pos", data=tip, **comp)
                obs.create_dataset("tube_pose", data=tp, **comp)
                obs.create_dataset("liquid_particles", data=part, **comp)

                actions = demo.create_group("actions")
                actions.create_dataset("cartesian_position", data=act_cartesian_position, **comp)
                actions.create_dataset("joint_position", data=act_joint_position, **comp)
                actions.create_dataset("gripper_position", data=act_gripper_position, **comp)
                actions.create_dataset("cartesian_velocity", data=act_cartesian_velocity, **comp)
                actions.create_dataset("joint_velocity", data=act_joint_velocity, **comp)

                demo.create_dataset("dones", data=dones)
                demo.create_dataset("rewards", data=rewards)

            with h5py.File(ep_dir / "states_actions.hdf5", "w") as f:
                f.attrs["language_instruction"] = language_instruction
                f.attrs["dt"] = self.cfg.dt
                f.attrs["task"] = "pipette"
                f.attrs["success"] = int(success)
                f.attrs["total"] = int(T)
                f.attrs["rotation"] = "Euler XYZ radians"
                f.attrs["gripper_units"] = "obs gripper in meters; actions gripper binary 0=closed, 1=open."
                for k, v in success_info.items():
                    f.attrs[f"success_{k}"] = v

                obs = f.create_group("obs")
                obs.create_dataset("cartesian_position", data=obs_cartesian_position, **comp)
                obs.create_dataset("joint_position", data=obs_joint_position, **comp)
                obs.create_dataset("gripper_position", data=obs_gripper_position, **comp)

                actions = f.create_group("actions")
                actions.create_dataset("cartesian_position", data=act_cartesian_position, **comp)
                actions.create_dataset("joint_position", data=act_joint_position, **comp)
                actions.create_dataset("gripper_position", data=act_gripper_position, **comp)
                actions.create_dataset("cartesian_velocity", data=act_cartesian_velocity, **comp)
                actions.create_dataset("joint_velocity", data=act_joint_velocity, **comp)

                f.create_dataset("dones", data=dones)
                f.create_dataset("rewards", data=rewards)

            # init.json (scene + per-env init + camera + randomization)
            per_env = self.scene_init["per_env"][env_i]
            env_offset = np.asarray(per_env["envs_offset_xyz_m"], dtype=np.float32)
            pipette_init_pos = np.asarray(per_env["pipette"]["pos_local"], dtype=np.float32)
            tube_init_pos = np.asarray(per_env["tube"]["pos_local"], dtype=np.float32)
            meta = {
                "robot": {
                    "left_base_pos": self.scene_init["fixed_entities"]["franka_left"]["pos"],
                    "right_base_pos": self.scene_init["fixed_entities"]["franka_right"]["pos"],
                    "home_qpos": self.scene_init["robot_home_qpos_9"],
                },
                "pipette": {
                    "init_pos": pipette_init_pos.tolist(),
                    "init_quat": pp[0, 3:7].tolist(),
                },
                "tube": {
                    "init_pos": tube_init_pos.tolist(),
                    "init_quat": tp[0, 3:7].tolist(),
                },
                "liquid": {
                    "init_pos": per_env["liquid"]["pos_local"],
                    "size": per_env["liquid"]["size"],
                },
                "success_check": {
                    "success": bool(success),
                    **success_info,
                },
                "scene": {
                    "n_envs": int(self.cfg.n_envs),
                    "env_idx": int(env_i),
                    "envs_offset_xyz_m": env_offset.tolist(),
                    "dt": self.cfg.dt,
                    "substeps": self.cfg.substeps,
                    "record_hz": 1.0 / self.cfg.dt,
                },
                "cameras": self.scene_init["cameras"],
            }
            with open(ep_dir / "init.json", "w", encoding="utf-8") as f:
                json.dump(meta, f, ensure_ascii=False, indent=2)

        print(f"[done] wrote {self.cfg.n_envs} episodes to {self.cfg.base_dir}")


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def parse_args(argv=None):
    """Parse command-line options for the parallel pipette task run.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so existing ``parse_args()`` calls
            behave exactly as before; passing a list allows programmatic use.

    Returns:
        argparse.Namespace with ``n_envs``, ``dt``, ``substeps``, ``output``,
        ``viewer`` and ``record_video``.

    Exits via ``ArgumentParser.error`` (status 2) when ``--dt`` is not strictly
    positive, or when ``--n_envs`` / ``--substeps`` is below 1.
    """
    p = argparse.ArgumentParser(
        description="Run the parallel pipette task and write one episode per environment.",
    )
    p.add_argument("--n_envs", type=int, default=50,
                   help="number of parallel environments (one episode is written per env)")
    p.add_argument("--dt", type=float, default=5e-2,
                   help="simulation/record timestep; init.json stores record_hz = 1/dt")
    p.add_argument("--substeps", type=int, default=100,
                   help="simulator substeps, passed through to TaskConfig")
    p.add_argument("--output", type=str, default="results/pipette/dataset7",
                   help="base directory the episodes are written to")
    p.add_argument("--viewer", action="store_true",
                   help="open the interactive viewer")
    p.add_argument("--record_video", action="store_true",
                   help="enable video recording")
    args = p.parse_args(argv)

    # dt is later used as 1.0 / dt for init.json's record_hz, which only happens
    # after the rollout; reject bad values (including NaN, for which `> 0` is
    # False) up front instead of failing after all the simulation work.
    if not args.dt > 0:
        p.error(f"--dt must be > 0, got {args.dt}")
    if args.n_envs < 1:
        p.error(f"--n_envs must be >= 1, got {args.n_envs}")
    if args.substeps < 1:
        p.error(f"--substeps must be >= 1, got {args.substeps}")
    return args


def main():
    """Script entry point: read CLI flags, build a TaskConfig, and run the task."""
    cli = parse_args()
    task_cfg = TaskConfig(
        base_dir=Path(cli.output),
        n_envs=cli.n_envs,
        dt=cli.dt,
        substeps=cli.substeps,
        record_video=cli.record_video,
        show_viewer=cli.viewer,  # CLI flag name differs from the config field
    )
    task = ParallelPipetteTask(task_cfg)
    task.run()


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
