XL2     &                !Mڲ     TypeV2Obj ID                DDirJ`\OXzSK90٦EcAlgoEcMEcN EcBSize   EcIndexEcDistCSumAlgoPartNumsPartETags 0a6774d7d30cd349caedc45d4a2a0346PartSizes͘HPartASizes͘HPartIdx Size͘HMTime!MڲMetaSysx-minio-internal-inline-datatruex-rustfs-internal-inline-datatrueMetaUsrcontent-typetext/x-python-scriptetag 0a6774d7d30cd349caedc45d4a2a0346v ΢$nullŘh`X5A)]CkD' import argparse
from dataclasses import dataclass
from pathlib import Path
import json

import genesis as gs
import h5py
import numpy as np
import torch
from scipy.spatial.transform import Rotation


@dataclass
class TaskConfig:
    """Configuration for the parallel beaker-to-beaker pouring data-collection task.

    Fields are grouped by purpose: simulation/parallelism, SPH liquid, dataset
    output, per-env scene randomization, and gripper control.
    """

    # Parallelism and physics parameters
    n_envs: int = 20  # number of parallel environments passed to scene.build()
    dt: float = 5e-2  # simulation step in seconds (gs.options.SimOptions.dt)
    substeps: int = 100  # physics substeps per scene.step() (SimOptions.substeps)
    record_hz: float = 20.0  # NOTE(review): not read in the visible code; the h5 writer stores 1/dt as record_hz
    env_spacing: tuple[float, float] = (2.0, 2.0)  # spacing between env origins, passed to scene.build()

    # SPH parameters
    pressure_solver: str = "WCSPH"  # passed to gs.options.SPHOptions
    particle_size: float = 0.005  # SPH particle size (SPHOptions.particle_size)
    liquid_stiffness: float = 50000.0  # stiffness of the SPH liquid material

    # Data output settings
    # NOTE(review): base_dir / episode_prefix / video_fps are not read in the
    # visible code; presumably consumed by finalize_outputs — confirm.
    base_dir: Path = Path("results/pour_water_from_beaker2beaker/dataset")
    episode_prefix: str = "episode_"
    video_fps: int = 20

    # Randomization ranges for scene setup, as ((x_min, x_max), (y_min, y_max)).
    # x and y are sampled uniformly per env; z is fixed to beaker_z.
    source_xy_range: tuple[tuple[float, float], tuple[float, float]] = ((-0.05, 0.12), (-0.3, 0.0))
    target_xy_range: tuple[tuple[float, float], tuple[float, float]] = ((0.12, 0.19), (-0.3, 0.0))
    beaker_z: float = 0.829

    # Gripper force settings. Only the sign is used by the helpers in this file:
    # a negative force means "close" (0 m finger target, action 0.0); zero or
    # positive means "open" (FINGER_OPEN_M target, action 1.0).
    close_force: float = -15.0
    open_force: float = 2.0
    # Presumably the target_ratio for evaluate_pour_success — confirm at call site.
    success_target_ratio: float = 0.5


# Finger joint position (meters) treated as fully open. Matches the 0.04 finger
# entries of the home qpos; used to normalize the gripper opening to [0, 1]
# (norm_gripper) and as the open finger target (grip_open_m_from_force).
FINGER_OPEN_M = 0.04


def _to_np(x):
    return x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else np.asarray(x)


def pose_to_cartesian6(pos, quat):
    """Convert positions plus (w, x, y, z) quaternions into 6-D Cartesian rows.

    Args:
        pos: ``(3,)`` or ``(B, 3)`` positions.
        quat: ``(4,)`` or ``(B, 4)`` quaternions, scalar-first (w, x, y, z).

    Returns:
        ``(B, 6)`` float32 array ``[x, y, z, a, b, c]`` where ``a, b, c`` are
        Euler angles in radians from SciPy's ``"xyz"`` convention.
    """
    positions = np.atleast_2d(np.asarray(pos, dtype=np.float32))
    quats_wxyz = np.atleast_2d(np.asarray(quat, dtype=np.float32))
    # SciPy expects scalar-last (x, y, z, w): rotate w from the front to the back.
    quats_xyzw = np.roll(quats_wxyz, -1, axis=1)
    angles = Rotation.from_quat(quats_xyzw).as_euler("xyz", degrees=False)
    return np.hstack([positions, angles.astype(np.float32)]).astype(np.float32)


def norm_gripper(gripper_m):
    """Map a finger opening in meters to ``[0, 1]``, where 1 means ``FINGER_OPEN_M``."""
    ratio = np.asarray(gripper_m, dtype=np.float32) / FINGER_OPEN_M
    return np.clip(ratio, 0.0, 1.0)


def binary_gripper(gripper):
    """Threshold a normalized gripper value: 1.0 (open) when strictly above 0.5, else 0.0."""
    is_open = np.asarray(gripper, dtype=np.float32) > 0.5
    return is_open.astype(np.float32)


def grip_action_from_force(grip_force):
    """Translate a gripper force command into a binary gripper action.

    A negative force means "close" and yields 0.0; any other value yields 1.0
    ("open"). ``None`` (no gripper command) is passed through as ``None``.
    """
    if grip_force is not None:
        return 0.0 if float(grip_force) < 0.0 else 1.0
    return None


def grip_open_m_from_force(grip_force):
    """Translate a gripper force command into a finger joint target in meters.

    A negative force gives 0.0 (closed); any other value gives ``FINGER_OPEN_M``.
    ``None`` (no gripper command) is passed through as ``None``.
    """
    if grip_force is not None:
        return 0.0 if float(grip_force) < 0.0 else FINGER_OPEN_M
    return None


def next_frame(values):
    """Shift a time series one step forward, repeating the final frame.

    ``out[t] = values[t + 1]`` for ``t < T - 1`` and ``out[T - 1] = values[T - 1]``.
    Series with zero or one frame come back as an unchanged float32 copy.
    """
    arr = np.asarray(values, dtype=np.float32)
    count = arr.shape[0]
    # Source index t + 1 for every frame, clamped so the last frame maps to itself.
    src = np.minimum(np.arange(count) + 1, max(count - 1, 0))
    return arr[src]


def finite_difference(values, dt):
    """Forward difference ``(v[t+1] - v[t]) / dt`` along axis 0.

    The output has the same shape as ``values``; its final frame (and the only
    frame of a length-1 series) is zero.
    """
    arr = np.asarray(values, dtype=np.float32)
    rates = np.zeros_like(arr, dtype=np.float32)
    if arr.shape[0] >= 2:
        rates[:-1] = np.diff(arr, axis=0) / float(dt)
    return rates


def evaluate_pour_success(
    source_pose,
    target_pose,
    particle_pos,
    target_ratio=0.5,
    target_xy_radius=0.06,
    source_xy_radius=0.06,
    z_min_margin=-0.02,
    z_max_above=0.24,
    min_valid_z=0.5,
):
    """Decide whether enough liquid ended up in the target beaker.

    A final-frame particle counts toward a beaker when it is finite, strictly
    above ``min_valid_z``, within ``*_xy_radius`` of that beaker's final xy
    position, and its z lies in ``[beaker_z + z_min_margin, beaker_z + z_max_above]``.
    The target count is normalized by the number of valid particles in the
    first frame.

    Args:
        source_pose: ``(T, >=3)`` source beaker poses; only the last row's xyz is used.
        target_pose: ``(T, >=3)`` target beaker poses; only the last row's xyz is used.
        particle_pos: ``(T, N, 3)`` particle positions over time.
        target_ratio: minimum fraction of initial particles that must end in the target.
        target_xy_radius: horizontal capture radius (m) around the target beaker.
        source_xy_radius: horizontal capture radius (m) around the source beaker.
        z_min_margin: lower bound of the vertical window, relative to beaker z (m).
        z_max_above: upper bound of the vertical window, relative to beaker z (m).
        min_valid_z: particles at or below this height are treated as invalid
            (previously a hard-coded 0.5).

    Returns:
        ``(success, info)``. ``info`` always has the same keys, including when
        any input is empty or no particle is valid in the first frame (counts
        and ratios are then zero).
    """
    source_pose = np.asarray(source_pose, dtype=np.float32)
    target_pose = np.asarray(target_pose, dtype=np.float32)
    particle_pos = np.asarray(particle_pos, dtype=np.float32)

    # Start from a fully-populated, zeroed report so every return path
    # produces the same schema (it is written out as h5 attributes).
    info = {
        "target_ratio": float(target_ratio),
        "final_target_ratio": 0.0,
        "final_source_ratio": 0.0,
        "target_particle_count": 0,
        "source_particle_count": 0,
        "initial_particle_count": 0,
        "target_xy_radius_m": float(target_xy_radius),
        "source_xy_radius_m": float(source_xy_radius),
        "z_min_margin_m": float(z_min_margin),
        "z_max_above_m": float(z_max_above),
    }
    if source_pose.shape[0] == 0 or target_pose.shape[0] == 0 or particle_pos.shape[0] == 0:
        return False, info

    def _valid(frame):
        # Finite coordinates and above the validity floor.
        return np.isfinite(frame).all(axis=1) & (frame[:, 2] > min_valid_z)

    initial_count = int(_valid(particle_pos[0]).sum())
    if initial_count <= 0:
        return False, info

    final_particles = particle_pos[-1]
    valid_final = _valid(final_particles)
    final_z = final_particles[:, 2]

    def _count_near(center, xy_radius):
        # Valid final particles inside the cylinder-like region above `center`.
        in_xy = np.linalg.norm(final_particles[:, :2] - center[:2], axis=1) <= xy_radius
        in_z = (final_z >= center[2] + z_min_margin) & (final_z <= center[2] + z_max_above)
        return int((valid_final & in_xy & in_z).sum())

    target_count = _count_near(target_pose[-1, :3], target_xy_radius)
    source_count = _count_near(source_pose[-1, :3], source_xy_radius)
    final_ratio = float(target_count / initial_count)

    info.update(
        {
            "final_target_ratio": final_ratio,
            "final_source_ratio": float(source_count / initial_count),
            "target_particle_count": target_count,
            "source_particle_count": source_count,
            "initial_particle_count": initial_count,
        }
    )
    return bool(final_ratio >= target_ratio), info


class WeldAttachmentChecker:
    """
    Batched weld rule that keeps the source beaker fixed to the gripper after grasping.

    Attach rule: on every ``update()``, an inactive env's acquire counter grows
    while both fingers touch the object and the object touches none of
    ``support_entities`` (and resets otherwise). Once it reaches
    ``acquire_steps`` a weld between the object's base link and ``ee_link`` is added.

    Release rule: ``grasp_failure_now()`` never reports failure, so ``update()``
    never detaches on its own. The weld is kept through contact-query flicker
    during transport and pouring until ``force_detach_all()`` is called.

    ``add_weld_constraint`` / ``delete_weld_constraint`` are called with the
    link pair only (no env index). The first attach (when no weld exists yet)
    and every detach therefore update ``active`` for all envs at once.
    """
    def __init__(
        self,
        scene,
        robot_entity,
        ee_link,
        left_finger_link,
        right_finger_link,
        object_entity,
        support_entities=None,
        acquire_steps=2,
        release_steps=3,
        n_envs=1,
    ):
        """
        Args:
            scene: built Genesis scene; its rigid solver manages the weld constraint.
            robot_entity: robot entity (stored for reference).
            ee_link: link the object is welded to.
            left_finger_link / right_finger_link: links whose contact with the
                object is required to acquire the grasp.
            object_entity: entity to weld (its ``base_link`` is used).
            support_entities: entities whose contact with the object blocks acquisition.
            acquire_steps: consecutive successful ``update()`` calls needed to attach.
            release_steps: consecutive failing ``update()`` calls needed to detach.
            n_envs: number of parallel envs (length of all per-env arrays).
        """
        self.rigid = scene.sim.rigid_solver
        self.robot = robot_entity
        self.ee_link = ee_link
        self.left_finger = left_finger_link
        self.right_finger = right_finger_link
        self.obj = object_entity
        self.support_entities = [] if support_entities is None else list(support_entities)
        self.acquire_steps = acquire_steps
        self.release_steps = release_steps
        self.n_envs = n_envs
        self.active = np.zeros(n_envs, dtype=bool)
        self._acquire_counter = np.zeros(n_envs, dtype=np.int32)
        self._release_counter = np.zeros(n_envs, dtype=np.int32)
        # Link-index arrays in the form passed to the weld-constraint API.
        self.link_obj = np.array([self.obj.base_link.idx], dtype=gs.np_int)
        self.link_ee = np.array([self.ee_link.idx], dtype=gs.np_int)

    def _contact_per_env(self, entity):
        """Return an ``(n_envs,)`` bool array: does the object touch ``entity`` in each env?

        Prefers ``valid_mask`` from ``get_contacts``; falls back to ``geom_a``.
        Returns all-False when neither key is present.
        """
        info = self.obj.get_contacts(with_entity=entity)
        n = self.n_envs
        if "valid_mask" in info:
            arr = _to_np(info["valid_mask"])
            if arr.ndim == 0:
                return np.full(n, bool(arr), dtype=bool)
            if arr.ndim == 1:
                if n == 1:
                    return np.array([bool(arr.any())], dtype=bool)
                # One flag per env; envs beyond the reported length count as no contact.
                out = np.zeros(n, dtype=bool)
                m = min(n, arr.shape[0])
                out[:m] = arr[:m].astype(bool)
                return out
            # Leading axis is the env axis; any contact slot set means contact.
            return arr.any(axis=tuple(range(1, arr.ndim))).astype(bool)
        if "geom_a" in info:
            arr = _to_np(info["geom_a"])
            if arr.ndim <= 1:
                return np.full(n, arr.size > 0, dtype=bool)
            # NOTE(review): this reports the same flag for every env (any contact
            # slot in the batch), not a per-env result.
            return np.full(n, arr.shape[1] > 0, dtype=bool)
        return np.zeros(n, dtype=bool)

    def left_contact(self):
        """Per-env contact between the object and the left finger."""
        return self._contact_per_env(self.left_finger)

    def right_contact(self):
        """Per-env contact between the object and the right finger."""
        return self._contact_per_env(self.right_finger)

    def has_both_finger_contacts(self):
        """Per-env flag: the object touches both fingers."""
        return self.left_contact() & self.right_contact()

    def has_support_contact(self):
        """Per-env flag: the object touches any support entity (all False if none)."""
        out = np.zeros(self.n_envs, dtype=bool)
        for entity in self.support_entities:
            out = out | self._contact_per_env(entity)
        return out

    def grasp_success_now(self):
        """Per-env grasp condition: both finger contacts and no support contact."""
        return self.has_both_finger_contacts() & (~self.has_support_contact())

    def grasp_failure_now(self):
        """Per-env release condition; always False.

        Release is explicit in the scripted sequence (``force_detach_all``).
        Once the beaker is welded, it stays fixed through table/contact-query
        flicker during transport and pouring.
        """
        return np.zeros(self.n_envs, dtype=bool)

    def _weld_already_present(self):
        """Best-effort check whether the solver already has this object/ee weld.

        Any failure while querying the solver is treated as "not present".
        """
        try:
            info = self.rigid.get_weld_constraints(as_tensor=True, to_torch=True)
            link_a = info["link_a"]
            link_b = info["link_b"]
            la = int(self.link_obj[0])
            lb = int(self.link_ee[0])
            mask = (((link_a == la) | (link_b == la)) &
                    ((link_a == lb) | (link_b == lb)))
            return bool(mask.any().item()) if hasattr(mask, "any") else bool(np.asarray(mask).any())
        except Exception:
            return False

    def _attach_envs(self, env_idx):
        """Mark ``env_idx`` attached, adding the weld constraint if none exists yet.

        When the weld is newly added, every env is marked active and all
        counters reset, because the constraint is added without an env index.
        """
        if env_idx.size == 0:
            return
        if self.active.any() or self._weld_already_present():
            self.active[env_idx] = True
            self._acquire_counter[env_idx] = 0
            self._release_counter[env_idx] = 0
            return
        try:
            self.rigid.add_weld_constraint(self.link_obj, self.link_ee)
        except Exception as e:
            print(f"[WELD] add_weld_constraint skipped ({type(e).__name__}: {e})")
        self.active[:] = True
        self._acquire_counter[:] = 0
        self._release_counter[:] = 0

    def force_attach_all(self):
        """Attach in every env regardless of contacts; return a copy of ``active``."""
        env_idx = np.arange(self.n_envs, dtype=np.int32)
        self._attach_envs(env_idx)
        return self.active.copy()

    def _detach_envs(self, env_idx):
        """Remove the weld constraint and mark every env detached.

        If nothing is attached (locally or in the solver), only ``active`` is cleared.
        """
        if env_idx.size == 0:
            return
        if not (self.active.any() or self._weld_already_present()):
            self.active[:] = False
            return
        try:
            self.rigid.delete_weld_constraint(self.link_obj, self.link_ee)
        except Exception as e:
            print(f"[WELD] delete_weld_constraint skipped ({type(e).__name__}: {e})")
        self.active[:] = False
        self._acquire_counter[:] = 0
        self._release_counter[:] = 0

    def force_detach_all(self):
        """Detach in every env; return a copy of ``active`` (all False)."""
        env_idx = np.arange(self.n_envs, dtype=np.int32)
        self._detach_envs(env_idx)
        return self.active.copy()

    def update(self):
        """Advance the acquire/release counters one step and attach/detach as needed.

        Returns:
            dict with ``attached`` (copy of ``active``), ``attach_now`` and
            ``detach_now`` (int32 env indices that crossed a threshold this call).
        """
        success = self.grasp_success_now()
        failure = self.grasp_failure_now()
        # Snapshot taken before attaching, so envs attached this call are not
        # also considered for detach via `inactive`.
        inactive = ~self.active

        self._acquire_counter[inactive & success] += 1
        self._acquire_counter[inactive & ~success] = 0
        self._release_counter[self.active & failure] += 1
        self._release_counter[self.active & ~failure] = 0

        attach_envs = np.where(inactive & (self._acquire_counter >= self.acquire_steps))[0].astype(np.int32)
        self._attach_envs(attach_envs)
        detach_envs = np.where(self.active & (self._release_counter >= self.release_steps))[0].astype(np.int32)
        self._detach_envs(detach_envs)

        return {
            "attached": self.active.copy(),
            "attach_now": attach_envs,
            "detach_now": detach_envs,
        }


class ParallelPourTask:
    """
    Parallel pouring task.

    Goals:
    1) Execute pouring motions in parallel.
    2) Record states/actions/particle_pos.
    3) Export one data.h5 per environment.
    4) Avoid cam.render to reduce rendering overhead.
    """

    def __init__(self, cfg: TaskConfig):
        """Store ``cfg``, initialize Genesis on the GPU backend, and reset all buffers.

        Scene entities are created later by ``build_scene``; until then they are ``None``.
        """
        self.cfg = cfg
        gs.init(backend=gs.gpu)
        n_envs = cfg.n_envs

        # Scene handles, populated by build_scene().
        self.scene = None
        self.table = None
        self.source_beaker = None
        self.target_beaker = None
        self.franka = None
        self.end_effector = None
        self.liquid = None
        self.attachment_checker = None

        # Franka DOF layout: 7 arm joints followed by 2 finger joints.
        self.motors_dof = np.arange(7)
        self.fingers_dof = np.arange(7, 9)

        self.active_envs = np.arange(n_envs, dtype=np.int32)
        self.success_envs: set[int] = set()

        self.source_init_pos = None
        self.target_init_pos = None
        self.initial_particle_count = 0

        def per_env_buffers() -> list[list[np.ndarray]]:
            # A fresh, independent list per env (never shared between envs).
            return [[] for _ in range(n_envs)]

        self.state_seq: list[list[np.ndarray]] = per_env_buffers()
        self.action_seq: list[list[np.ndarray]] = per_env_buffers()
        self.cartesian_seq: list[list[np.ndarray]] = per_env_buffers()
        self.source_pose_seq: list[list[np.ndarray]] = per_env_buffers()
        self.target_pose_seq: list[list[np.ndarray]] = per_env_buffers()
        self.particle_seq: list[list[np.ndarray]] = per_env_buffers()

        self.sim_t = 0.0

    @staticmethod
    def euler_deg_to_quat_wxyz(roll: float, pitch: float, yaw: float) -> np.ndarray:
        quat_xyzw = Rotation.from_euler("xyz", [roll, pitch, yaw], degrees=True).as_quat()
        return np.array([quat_xyzw[3], quat_xyzw[0], quat_xyzw[1], quat_xyzw[2]], dtype=np.float32)

    def _random_positions(self, x_rng, y_rng, z_val: float):
        """Sample one ``(x, y, z)`` position per env on ``gs.device``.

        x and y are drawn uniformly from ``[lo, hi)`` of ``x_rng`` / ``y_rng``
        (x sampled first); z is the constant ``z_val``.

        Returns:
            ``(n_envs, 3)`` tensor.
        """
        count = self.cfg.n_envs
        dev = gs.device

        def uniform(bounds):
            lo, hi = bounds[0], bounds[1]
            return lo + (hi - lo) * torch.rand(count, device=dev)

        xs = uniform(x_rng)
        ys = uniform(y_rng)
        zs = torch.full((count,), float(z_val), device=dev)
        return torch.stack((xs, ys, zs), dim=-1)

    def build_scene(self):
        """Create, build and initialize the Genesis scene for all parallel envs.

        Adds a ground plane, a table, source and target beakers, a Franka arm
        and an SPH liquid box, then builds the scene with ``cfg.n_envs`` envs.
        After building it:
          * randomizes both beaker positions per env,
          * shifts the liquid particles by the same per-env offset as the source beaker,
          * sets the arm's PD gains, force limits and home pose,
          * creates the weld checker (table as support) and takes one step.
        """
        # SPH solver boundary
        lower_bound = (-0.5, -1.0, 0.0)
        upper_bound = (1.2, 0.5, 1.2)

        self.scene = gs.Scene(
            sim_options=gs.options.SimOptions(dt=self.cfg.dt, substeps=self.cfg.substeps),
            sph_options=gs.options.SPHOptions(
                pressure_solver=self.cfg.pressure_solver,
                lower_bound=lower_bound,
                upper_bound=upper_bound,
                particle_size=self.cfg.particle_size,
            ),
            show_viewer=False,
        )

        self.scene.add_entity(gs.morphs.Plane())
        self.table = self.scene.add_entity(
            gs.morphs.MJCF(
                file='asset/model/misc/simple_table.xml',
                pos=(0.0, 0.0, 0.0),
                decimate=False,
                euler=(0, 0, 90),
            )
        )

        # Both beakers start at the same placeholder pose; per-env positions are
        # set after scene.build() below.
        # NOTE(review): the placeholder z (0.829) duplicates cfg.beaker_z's
        # default instead of reading it.
        self.source_beaker = self.scene.add_entity(
            gs.morphs.MJCF(
                file='asset/beaker/beaker.xml',
                pos=((0.0, 0.0, 0.829)),
                euler=(0, 0, 0),
            ),
        )
        self.target_beaker = self.scene.add_entity(
            gs.morphs.MJCF(
                file='asset/beaker/beaker.xml',
                pos=((0.0, 0.0, 0.829)),
                euler=(0, 0, 0),
            ),
        )

        # Franka base pose; the same values are written to init.json by _create_h5_episode.
        self.franka = self.scene.add_entity(
            gs.morphs.MJCF(file="xml/franka_emika_panda/panda.xml", pos=(-0.25, -0.5, 0.824))
        )

        # Key: add liquid entity. It spawns as a box at a fixed nominal pose and
        # is moved into each env's source beaker after the scene is built.
        self.liquid = self.scene.add_entity(
            material=gs.materials.SPH.Liquid(mu=0.001, gamma=0.01, stiffness=self.cfg.liquid_stiffness, sampler="regular"),
            morph=gs.morphs.Box(pos=(0.55, -0.20, 0.45), size=(0.04, 0.04, 0.06)),
            surface=gs.surfaces.Default(color=(0.4, 0.8, 1.0, 1.0), vis_mode="particle"),
        )

        self.scene.build(n_envs=self.cfg.n_envs, env_spacing=self.cfg.env_spacing)

        # Randomize beaker positions: x/y uniform in the configured ranges, z = cfg.beaker_z.
        envs_all = torch.arange(self.cfg.n_envs, device=gs.device, dtype=torch.long)
        source_pos = self._random_positions(*self.cfg.source_xy_range, self.cfg.beaker_z)
        target_pos = self._random_positions(*self.cfg.target_xy_range, self.cfg.beaker_z)
        self.source_beaker.set_pos(source_pos, envs_idx=envs_all)
        self.target_beaker.set_pos(target_pos, envs_idx=envs_all)

        # Align liquid particle cloud with randomized source beaker: shift every
        # env's particles by (source_pos - nominal_source). The nominal xy matches
        # the spawn box xy, so the liquid becomes centered on the beaker in xy;
        # nominal z (0.38) is 0.07 m below the box center z (0.45), so the box
        # ends up centered 0.07 m above the source beaker's z.
        nominal_source = torch.tensor([0.55, -0.20, 0.38], device=gs.device, dtype=gs.tc_float)
        offset = source_pos - nominal_source[None, :]
        # Slice to the liquid's own n_particles (the same slice is used when recording).
        p0 = self.liquid.get_particles_pos()[:, : self.liquid.n_particles, :]
        self.liquid.set_particles_pos(p0 + offset[:, None, :])

        self.source_init_pos = source_pos.detach().cpu().numpy()
        self.target_init_pos = target_pos.detach().cpu().numpy()

        # Controller parameters: PD gains and force limits for the 7 arm joints
        # followed by the 2 finger joints.
        self.franka.set_dofs_kp(np.array([4500, 4500, 3500, 3500, 2000, 2000, 2000, 100, 100]))
        self.franka.set_dofs_kv(np.array([450, 450, 350, 350, 200, 200, 200, 10, 10]))
        self.franka.set_dofs_force_range(
            np.array([-87, -87, -87, -87, -12, -12, -12, -100, -100]),
            np.array([87, 87, 87, 87, 12, 12, 12, 100, 100]),
        )
        # Home configuration (fingers at 0.04 m = FINGER_OPEN_M, i.e. open); set
        # both the joint state and the PD target to it.
        valid_home_qpos = np.array([-0.0613, -1.2243, 2.1648, -2.8683, -0.6867, 2.6172, 0.4415, 0.04, 0.04])
        self.franka.set_dofs_position(valid_home_qpos)
        self.franka.control_dofs_position(valid_home_qpos)
        self.end_effector = self.franka.get_link("hand")
        left_finger_link = self.franka.get_link("left_finger")
        right_finger_link = self.franka.get_link("right_finger")
        # Weld rule for the source beaker; the table is the support entity, so a
        # beaker still touching it does not count as grasped.
        self.attachment_checker = WeldAttachmentChecker(
            scene=self.scene,
            robot_entity=self.franka,
            ee_link=self.end_effector,
            left_finger_link=left_finger_link,
            right_finger_link=right_finger_link,
            object_entity=self.source_beaker,
            support_entities=[self.table],
            acquire_steps=2,
            release_steps=3,
            n_envs=self.cfg.n_envs,
        )

        # One step after initialization (not recorded, not counted in sim_t);
        # presumably so the new state/targets take effect — confirm.
        self.scene.step()

        self.initial_particle_count = int(self.liquid.n_particles)
    def _record_step(self, envs_idx: np.ndarray, action_t, grip_action=None):
        """Append the current time step to the per-env buffers of ``envs_idx``.

        Recorded for each env ``e`` in ``envs_idx``:
          * state: 7 arm joint positions + mean of the two finger joints (8 values),
          * cartesian: end-effector pose from ``pose_to_cartesian6`` (6 values),
          * source/target beaker pose: position + quaternion (7 values each),
          * liquid particle positions (first ``n_particles``),
          * action: the row of ``action_t`` for that env.

        A 7-wide ``action_t`` is padded with a gripper column: ``grip_action`` when
        given, otherwise the measured finger opening normalized by FINGER_OPEN_M and
        thresholded at 0.5 (0 = closed, 1 = open).

        Args:
            envs_idx: env indices to record; row ``i`` of ``action_t`` belongs to ``envs_idx[i]``.
            action_t: tensor or array with 7 or 8 columns.
            grip_action: optional scalar gripper command.
        """
        def host(t):
            return t.detach().cpu().numpy()

        # Record every simulation step to keep time series fully aligned.
        qpos = self.franka.get_qpos()
        finger_mean = qpos[:, 7:9].mean(axis=1, keepdims=True)
        state_all = host(torch.cat([qpos[:, :7], finger_mean], dim=1))  # [B, 8]
        ee_pose_all = pose_to_cartesian6(
            _to_np(self.end_effector.get_pos()),
            _to_np(self.end_effector.get_quat()),
        )

        src_pose_all = np.concatenate(
            [host(self.source_beaker.get_pos()), host(self.source_beaker.get_quat())],
            axis=1,
        )  # [B, 7]
        tgt_pose_all = np.concatenate(
            [host(self.target_beaker.get_pos()), host(self.target_beaker.get_quat())],
            axis=1,
        )  # [B, 7]

        n_live = self.liquid.n_particles
        particles_all = host(self.liquid.get_particles_pos()[:, :n_live, :])

        if isinstance(action_t, torch.Tensor):
            action_t = host(action_t)

        # Pad arm-only actions to 8 dims: [arm7, gripper_open_command_0_or_1].
        if action_t.shape[1] == 7:
            if grip_action is not None:
                grip_col = np.full((len(envs_idx), 1), float(grip_action), dtype=np.float32)
            else:
                grip_col = binary_gripper(norm_gripper(_to_np(finger_mean)[envs_idx]))
            action_t = np.concatenate([action_t, grip_col], axis=1)

        # Buffers indexed by absolute env id, paired with full-batch data.
        batch_buffers = (
            (self.state_seq, state_all),
            (self.cartesian_seq, ee_pose_all),
            (self.source_pose_seq, src_pose_all),
            (self.target_pose_seq, tgt_pose_all),
            (self.particle_seq, particles_all),
        )
        for row, env in enumerate(envs_idx):
            e = int(env)
            self.action_seq[e].append(action_t[row])
            for seq, data in batch_buffers:
                seq[e].append(data[e])

    def _step_and_record(self, envs_idx: np.ndarray, action_t, grip_action=None):
        """Advance the scene one step, update the weld rule, then record ``envs_idx``.

        ``sim_t`` is advanced by ``cfg.dt`` per call.
        """
        self.scene.step()
        self.sim_t += self.cfg.dt
        checker = self.attachment_checker
        if checker is not None:
            checker.update()
        self._record_step(envs_idx, action_t, grip_action)

    def _apply_qpos_target(self, envs_idx: np.ndarray, q_cmd, grip_force=None):
        """Send a joint-position target to ``envs_idx``, optionally overriding the fingers.

        Args:
            envs_idx: envs to command; when empty nothing is sent and ``None`` is returned.
            q_cmd: 1-D or 2-D joint target (tensor or array-like); a 1-D target is
                treated as a single row. It is copied, never modified in place.
            grip_force: when given, its sign sets finger joints 7:9 to closed
                (0 m, negative force) or ``FINGER_OPEN_M`` (otherwise).

        Returns:
            The binary gripper action for ``grip_force`` (0.0 closed / 1.0 open), or ``None``.
        """
        if len(envs_idx) == 0:
            return None

        if isinstance(q_cmd, torch.Tensor):
            target = q_cmd.detach().clone()
        else:
            target = torch.as_tensor(q_cmd, device=gs.device, dtype=gs.tc_float).clone()
        if target.ndim == 1:
            target = target.unsqueeze(0)

        finger_m = grip_open_m_from_force(grip_force)
        if finger_m is not None:
            target[:, 7:9] = float(finger_m)

        self.franka.control_dofs_position(target, envs_idx=envs_idx)
        return grip_action_from_force(grip_force)

    def _hold(self, envs_idx: np.ndarray, n_steps: int, grip_force=None):
        """Hold ``envs_idx`` at their current joint configuration for ``n_steps`` steps.

        The hold target is captured once, before the first step, and re-sent on
        every step (fingers overridden by ``grip_force``). Re-reading the measured
        qpos each step, as this method previously did, made every step's target
        equal to wherever the arm had already been pushed to. Displacement from
        gravity or the welded beaker's load could therefore accumulate over a long
        hold instead of being corrected.

        Args:
            envs_idx: envs to hold and record.
            n_steps: number of simulation steps.
            grip_force: optional gripper force; its sign sets the finger target.
        """
        q_hold = self.franka.get_qpos()[envs_idx]
        arm_action = q_hold[:, :7]
        for _ in range(n_steps):
            # _apply_qpos_target clones its input, so q_hold is never modified.
            grip_action = self._apply_qpos_target(envs_idx, q_hold, grip_force)
            self._step_and_record(envs_idx, arm_action, grip_action)

    def grisp(self, envs_idx: np.ndarray, q_ref, gripper_force=None, n_steps: int = 20):
        """Ramp the finger joints of ``envs_idx`` toward open/closed over ``n_steps`` steps.

        The arm joints are commanded to stay at their qpos from the start of the
        call. The fingers are linearly interpolated from their current value to
        the target implied by ``gripper_force``: closed (0 m) when negative,
        ``FINGER_OPEN_M`` otherwise, and left unchanged when ``None``.

        Args:
            envs_idx: envs to command; returns immediately when empty.
            q_ref: joint configuration whose first 7 entries are recorded as the
                arm action on every step.
            gripper_force: gripper force whose sign picks the finger target.
            n_steps: number of interpolation steps (at least 1 is taken).
        """
        if len(envs_idx) == 0:
            return

        q_from = self.franka.get_qpos(envs_idx=envs_idx).detach().clone()
        q_to = q_from.clone()
        finger_m = grip_open_m_from_force(gripper_force)
        if finger_m is not None:
            q_to[:, 7:9] = float(finger_m)

        grip_action = grip_action_from_force(gripper_force)
        arm_action = q_ref[:, :7]
        total = max(1, int(n_steps))
        for k in range(total):
            frac = (k + 1) / total
            q_cmd = (1.0 - frac) * q_from + frac * q_to
            self.franka.control_dofs_position(q_cmd, envs_idx=envs_idx)
            self._step_and_record(envs_idx, arm_action, grip_action)

    def _move_ik(self, envs_idx: np.ndarray, pos: np.ndarray, quat: np.ndarray, n_interp: int, grip_force: float):
        """Move the end effector of ``envs_idx`` to ``pos`` / ``quat`` via inverse kinematics.

        The joint-space path from the current qpos to the IK solution is split
        into ``n_interp`` linear steps. Each step sends the interpolated target
        (fingers overridden by ``grip_force``) and records its arm part as the action.

        Args:
            envs_idx: envs to move.
            pos: per-env target positions for the end-effector link.
            quat: per-env target orientations (as built in run() with
                euler_deg_to_quat_wxyz, i.e. w, x, y, z order).
            n_interp: number of interpolation steps; 0 moves nothing.
            grip_force: gripper force whose sign sets the finger target.

        Returns:
            The IK joint solution for ``envs_idx``.
        """
        q_goal = self.franka.inverse_kinematics(
            link=self.end_effector,
            pos=pos,
            quat=quat,
            envs_idx=envs_idx,
        )
        q_start = self.franka.get_qpos()[envs_idx]

        fractions = [k / n_interp for k in range(1, n_interp + 1)]
        for alpha in fractions:
            q_cmd = (1.0 - alpha) * q_start + alpha * q_goal
            grip_action = self._apply_qpos_target(envs_idx, q_cmd, grip_force)
            self._step_and_record(envs_idx, q_cmd[:, :7], grip_action)

        return q_goal

    def _rotate_joint7(self, envs_idx: np.ndarray, angle_deg: float, hold_steps: int = 8, grip_force: float = -30.0):
        """Offset the last arm joint (index 6) of ``envs_idx`` by ``angle_deg`` degrees.

        The new target is sent once. The scene is then stepped ``hold_steps``
        times, recording every env in ``self.active_envs`` (not only
        ``envs_idx``) with its measured arm joints as the action. Stepping stops
        early if ``self.active_envs`` is empty.

        NOTE(review): the recorded ``grip_action`` (derived from ``grip_force``) is
        applied to all active envs, including those not being rotated.
        """
        if len(envs_idx) == 0:
            return

        target = self.franka.get_qpos(envs_idx=envs_idx).clone()
        offset_rad = torch.deg2rad(torch.tensor(angle_deg, device=gs.device, dtype=gs.tc_float))
        target[:, 6] += offset_rad
        grip_action = self._apply_qpos_target(envs_idx, target, grip_force)

        for _ in range(hold_steps):
            recorded = self.active_envs
            if len(recorded) == 0:
                return
            arm_now = self.franka.get_qpos()[recorded][:, :7]
            self._step_and_record(recorded, arm_now, grip_action)

    def run(self):
        """Build the scene, run the scripted pour in every active env, then export.

        Scripted sequence (scene.step() advances every env at once):
          1. move to a pose offset from the source beaker, fingers open
          2. descend to the grasp pose, close the gripper, weld the beaker to the hand
          3. lift 0.03 m
          4. move beside the target beaker; the y-side offset is chosen per env
             from whether the source is at lower or higher y than the target
          5. tilt by rotating joint 7 a total of 105 deg (sign depends on side),
             then rotate back
          6. move above the original grasp pose, lower to it, unweld, open the gripper
          7. retreat 0.1 m in -x from the step-1 pose
        Finally every active env is marked successful and ``finalize_outputs`` is called.
        """
        self.build_scene()

        # settle (these steps bypass _step_and_record: not recorded, sim_t not advanced)
        for _ in range(5):
            self.scene.step()

        # Step 1
        # NOTE(review): step1_pos has one row per env (full batch) while the quat
        # has len(active_envs) rows; this assumes active_envs covers all envs.
        print("Step 1")
        step1_pos = self.source_beaker.get_pos().cpu().numpy().copy()
        step1_pos += np.array([-0.095, 0.0, 0.11], dtype=np.float32)
        step1_quat = np.repeat(self.euler_deg_to_quat_wxyz(0, 90, 0)[None, :], len(self.active_envs), axis=0)
        q = self._move_ik(self.active_envs, step1_pos, step1_quat, n_interp=20, grip_force=self.cfg.open_force)
        self._hold(self.active_envs, 5, grip_force=self.cfg.open_force)

        # Step 2: descend to the grasp pose (0.06 m lower than step 1), then close.
        print("Step 2")
        step2_pos = self.source_beaker.get_pos().cpu().numpy().copy()
        step2_pos += np.array([-0.095, 0.0, 0.05], dtype=np.float32)
        step2_quat = np.repeat(self.euler_deg_to_quat_wxyz(0, 90, 0)[None, :], len(self.active_envs), axis=0)
        q = self._move_ik(self.active_envs, step2_pos, step2_quat, n_interp=12, grip_force=self.cfg.open_force)
        self._hold(self.active_envs, 5, grip_force=self.cfg.open_force)

        self.grisp(self.active_envs, q, gripper_force=self.cfg.close_force, n_steps=30)
        self._hold(self.active_envs, 8, grip_force=self.cfg.close_force)
        # Weld unconditionally instead of waiting for the contact-based rule.
        if self.attachment_checker is not None:
            self.attachment_checker.force_attach_all()
            print("[WELD] attached source_beaker after grasp")

        # Step 3: small lift
        print("Step 3")
        step3_pos = step2_pos + np.array([0.0, 0.0, 0.03], dtype=np.float32)
        step3_quat = step2_quat
        _ = self._move_ik(self.active_envs, step3_pos, step3_quat, n_interp=18, grip_force=self.cfg.close_force)
        self._hold(self.active_envs, 8, grip_force=self.cfg.close_force)

        # Step 4: split envs by relative y of source vs target and approach the
        # target from the matching side (-0.045 m or +0.045 m in y).
        # While one group moves, the other group's envs are stepped but not recorded.
        print("Step 4")
        source_y = self.source_beaker.get_pos().cpu().numpy().copy()[self.active_envs, 1]
        target_y = self.target_beaker.get_pos().cpu().numpy().copy()[self.active_envs, 1]
        normal_envs = self.active_envs[source_y <= target_y]
        reverse_envs = self.active_envs[source_y > target_y]

        if len(normal_envs) > 0:
            step4_pos = self.target_beaker.get_pos().cpu().numpy().copy()[normal_envs] + np.array([-0.09, -0.045, 0.165], dtype=np.float32)
            step4_quat = np.repeat(self.euler_deg_to_quat_wxyz(0, 90, 0)[None, :], len(normal_envs), axis=0)
            q = self._move_ik(normal_envs, step4_pos, step4_quat, n_interp=30, grip_force=self.cfg.close_force)
            self._hold(normal_envs, 5, grip_force=self.cfg.close_force)

        if len(reverse_envs) > 0:
            step4_pos = self.target_beaker.get_pos().cpu().numpy().copy()[reverse_envs] + np.array([-0.09, 0.045, 0.165], dtype=np.float32)
            step4_quat = np.repeat(self.euler_deg_to_quat_wxyz(0, 90, 0)[None, :], len(reverse_envs), axis=0)
            q = self._move_ik(reverse_envs, step4_pos, step4_quat, n_interp=30, grip_force=self.cfg.close_force)
            self._hold(reverse_envs, 5, grip_force=self.cfg.close_force)

        # Step 5: tilt via joint 7 in three increments (-30, -30, -45 deg for
        # normal envs; mirrored for reverse envs), then rotate back by 105 deg.
        print("Step 5")
        self._rotate_joint7(normal_envs, -30, hold_steps=12, grip_force=self.cfg.close_force)
        self._rotate_joint7(normal_envs, -30, hold_steps=12, grip_force=self.cfg.close_force)
        self._rotate_joint7(normal_envs, -45, hold_steps=16, grip_force=self.cfg.close_force)

        self._rotate_joint7(reverse_envs, +30, hold_steps=12, grip_force=self.cfg.close_force)
        self._rotate_joint7(reverse_envs, +30, hold_steps=12, grip_force=self.cfg.close_force)
        self._rotate_joint7(reverse_envs, +45, hold_steps=16, grip_force=self.cfg.close_force)

        self._hold(self.active_envs, 3, grip_force=self.cfg.close_force)

        self._rotate_joint7(normal_envs, +105, hold_steps=18, grip_force=self.cfg.close_force)
        self._rotate_joint7(reverse_envs, -105, hold_steps=18, grip_force=self.cfg.close_force)

        self._hold(self.active_envs, 3, grip_force=self.cfg.close_force)

        # Step 6: go 0.13 m above the original grasp pose, lower back to it,
        # remove the weld, then open the gripper.
        print("Step 6")
        step6_pos = step2_pos + np.array([0.0, 0.0, 0.13], dtype=np.float32)
        step6_quat = step2_quat
        q = self._move_ik(self.active_envs, step6_pos, step6_quat, n_interp=10, grip_force=self.cfg.close_force)
        self._hold(self.active_envs, 5, grip_force=self.cfg.close_force)
        q = self._move_ik(self.active_envs, step2_pos, step2_quat, n_interp=10, grip_force=self.cfg.close_force)
        self._hold(self.active_envs, 5, grip_force=self.cfg.close_force)

        if self.attachment_checker is not None:
            self.attachment_checker.force_detach_all()
            print("[WELD] detached source_beaker before release")
        self.grisp(self.active_envs, q, gripper_force=self.cfg.open_force, n_steps=10)


        # Step 7: retreat 0.1 m in -x from the step-1 pose with fingers open.
        print("Step 7")
        step7_pos = step1_pos - np.array([0.1, 0.0, 0.0], dtype=np.float32)
        step7_quat = np.repeat(self.euler_deg_to_quat_wxyz(0, 90, 0)[None, :], len(self.active_envs), axis=0)
        q = self._move_ik(self.active_envs, step7_pos, step7_quat, n_interp=8, grip_force=self.cfg.open_force)
        self._hold(self.active_envs, 5, grip_force=self.cfg.open_force)

        # NOTE(review): every active env is marked successful here without calling
        # evaluate_pour_success; confirm finalize_outputs evaluates success itself.
        self.success_envs = set(int(e) for e in self.active_envs.tolist())
        self.finalize_outputs()

    @staticmethod
    def _create_h5_episode(
        episode_dir: Path,
        states: np.ndarray,
        actions: np.ndarray,
        cartesian: np.ndarray,
        source_pose: np.ndarray,
        target_pose: np.ndarray,
        particle_pos: np.ndarray,
        source_init_pos: np.ndarray,
        target_init_pos: np.ndarray,
        dt: float,
        substeps: int,
        record_hz: float,
        success: bool,
        success_info: dict,
    ):
        episode_dir.mkdir(parents=True, exist_ok=True)
        h5_path = episode_dir / "data.h5"
        states_actions_h5_path = episode_dir / "states_actions.hdf5"
        init_json_path = episode_dir / "init.json"

        t = states.shape[0]
        actual_record_hz = 1.0 / dt
        joint_position = states[:, :7].astype(np.float32)
        gripper_position = states[:, 7:8].astype(np.float32)
        cartesian_position = cartesian.astype(np.float32)
        joint_position_cmd = next_frame(joint_position)
        cartesian_position_cmd = next_frame(cartesian_position)
        joint_velocity_cmd = (joint_position_cmd - joint_position).astype(np.float32)
        cartesian_velocity_cmd = (cartesian_position_cmd - cartesian_position).astype(np.float32)
        action_gripper = binary_gripper(actions[:, 7:8]).astype(np.float32)
        comp = {"compression": "gzip", "compression_opts": 4}
        dones = np.zeros((t,), dtype=np.int8)
        rewards = np.zeros((t,), dtype=np.float32)
        if t > 0:
            dones[-1] = 1
            rewards[-1] = 1.0 if success else 0.0
        language_instruction = "Pick up the source beaker and pour its liquid into the target beaker."

        def write_attrs(f):
            f.attrs["language_instruction"] = language_instruction
            f.attrs["dt"] = dt
            f.attrs["substeps"] = substeps
            f.attrs["record_hz"] = actual_record_hz
            f.attrs["task"] = "pour_water_from_beaker2beaker"
            f.attrs["success"] = int(success)
            f.attrs["total"] = int(t)
            f.attrs["rotation"] = "Euler XYZ radians"
            f.attrs["gripper_units"] = "obs gripper in meters; actions gripper binary 0=closed, 1=open."
            for k, v in success_info.items():
                f.attrs[f"success_{k}"] = v

        def write_obs_actions(parent):
            obs = parent.create_group("obs")
            obs.create_dataset("cartesian_position", data=cartesian_position, compression="gzip", compression_opts=4)
            obs.create_dataset("joint_position", data=joint_position, compression="gzip", compression_opts=4)
            obs.create_dataset("gripper_position", data=gripper_position, compression="gzip", compression_opts=4)

            actions_grp = parent.create_group("actions")
            actions_grp.create_dataset("cartesian_position", data=cartesian_position_cmd, compression="gzip", compression_opts=4)
            actions_grp.create_dataset("joint_position", data=joint_position_cmd, compression="gzip", compression_opts=4)
            actions_grp.create_dataset("gripper_position", data=action_gripper, compression="gzip", compression_opts=4)
            actions_grp.create_dataset("cartesian_velocity", data=cartesian_velocity_cmd, compression="gzip", compression_opts=4)
            actions_grp.create_dataset("joint_velocity", data=joint_velocity_cmd, compression="gzip", compression_opts=4)
            parent.create_dataset("dones", data=dones)
            parent.create_dataset("rewards", data=rewards)

        with h5py.File(h5_path, "w") as f:
            write_attrs(f)
            data_grp = f.create_group("data")
            demo = data_grp.create_group("demo_0")
            demo.attrs["num_samples"] = int(t)
            demo.attrs["language_instruction"] = language_instruction
            demo.attrs["success"] = int(success)
            for k, v in success_info.items():
                demo.attrs[f"success_{k}"] = v
            write_obs_actions(demo)
            obs = demo["obs"]
            obs.create_dataset("source_pose", data=source_pose.astype(np.float32), **comp)
            obs.create_dataset("target_pose", data=target_pose.astype(np.float32), **comp)
            obs.create_dataset("particle_pos", data=particle_pos.astype(np.float32), **comp)

            g_scene = f.create_group("scene")
            g_scene.create_dataset("source_init_pos", data=source_init_pos.astype(np.float32))
            g_scene.create_dataset("target_init_pos", data=target_init_pos.astype(np.float32))

        with h5py.File(states_actions_h5_path, "w") as f:
            write_attrs(f)
            write_obs_actions(f)

        # JSON metadata for randomized initial conditions and object parameters.
        init_json = {
            "robot": {
                "base_pos": [-0.25, -0.5, 0.824],
                "home_qpos": [-0.0613, -1.2243, 2.1648, -2.8683, -0.6867, 2.6172, 0.4415, 0.04, 0.04],
            },
            "source_beaker": {
                "init_pos": np.asarray(source_init_pos, dtype=np.float32).tolist(),
                "init_quat": [1.0, 0.0, 0.0, 0.0],
            },
            "target_beaker": {
                "init_pos": np.asarray(target_init_pos, dtype=np.float32).tolist(),
                "init_quat": [1.0, 0.0, 0.0, 0.0],
            },
            "scene": {
                "n_envs": 1,
                "dt": dt,
                "substeps": substeps,
                "record_hz": record_hz,
            },
            "success_check": {
                "success": bool(success),
                **success_info,
            },
            "cameras": {
                "agentview": {
                "res": [256, 256],
                "pos": [2.5, 0, 2.0],
                "lookat": [0.0, 0.0, 1.0],
                "fov": 30
                },
                "wrist_camera": {
                "res": [256, 256],
                "pos": [0.0, 0.0, 0.0],
                "lookat": [1.0, 0.0, 0.0],
                "fov": 70,
                "attach_link": "hand",
                "offset_translation_xyz": [-0.08, 0.0, -0.035],
                "offset_rotation_euler_deg": {
                    "yaw_z": 90.0,
                    "pitch_y": 180.0,
                    "roll_x": -10.0
                },
                "offset_rotation_composition": "Rz @ Ry @ Rx"
                }
            }
        }
        with open(init_json_path, "w", encoding="utf-8") as f:
            json.dump(init_json, f, ensure_ascii=False, indent=2)

    def finalize_outputs(self):
        """Persist every recorded episode and write run-level summary files.

        For each of ``cfg.n_envs`` environments this:

        * converts the recorded per-env sequences (states, actions, cartesian
          poses, source/target beaker poses, particles) to float32 arrays and
          trims them to their shortest common length,
        * scores the episode with ``evaluate_pour_success`` using
          ``cfg.success_target_ratio``,
        * writes it via ``_create_h5_episode`` into ``base_dir/success/<name>``
          or ``base_dir/fail/<name>`` depending on that score.

        Afterwards ``tasks_success.json`` (one entry per successful episode)
        and ``summary.json`` (total / success / fail counts) are written to
        ``cfg.base_dir`` and a one-line report is printed.
        """
        cfg = self.cfg
        root = cfg.base_dir
        success_dir = root / "success"
        fail_dir = root / "fail"
        for directory in (root, success_dir, fail_dir):
            directory.mkdir(parents=True, exist_ok=True)

        succeeded = {}
        for env_idx in range(cfg.n_envs):
            episode = f"{cfg.episode_prefix}{env_idx:03d}"

            streams = [
                np.asarray(seq[env_idx], dtype=np.float32)
                for seq in (
                    self.state_seq,
                    self.action_seq,
                    self.cartesian_seq,
                    self.source_pose_seq,
                    self.target_pose_seq,
                    self.particle_seq,
                )
            ]
            # Boundary sampling can leave the streams with slightly different
            # lengths; keep only the frames present in all of them.
            n_frames = min(len(stream) for stream in streams)
            states, actions, cartesian, source_pose, target_pose, particles = (
                stream[:n_frames] for stream in streams
            )

            success, success_info = evaluate_pour_success(
                source_pose,
                target_pose,
                particles,
                target_ratio=cfg.success_target_ratio,
            )
            ep_dir = (success_dir if success else fail_dir) / episode

            self._create_h5_episode(
                episode_dir=ep_dir,
                states=states,
                actions=actions,
                cartesian=cartesian,
                source_pose=source_pose,
                target_pose=target_pose,
                particle_pos=particles,
                source_init_pos=self.source_init_pos[env_idx],
                target_init_pos=self.target_init_pos[env_idx],
                dt=cfg.dt,
                substeps=cfg.substeps,
                # NOTE(review): derived from dt rather than cfg.record_hz; the
                # two coincide when the config is built by main().
                record_hz=1.0 / cfg.dt,
                success=success,
                success_info=success_info,
            )
            if not success:
                continue
            succeeded[episode] = {
                "env_idx": int(env_idx),
                "episode_dir": str(ep_dir),
                "n_frames": int(n_frames),
                **success_info,
            }

        def write_json(filename, payload):
            # Pretty-printed UTF-8 JSON placed directly under base_dir.
            with open(root / filename, "w", encoding="utf-8") as fh:
                json.dump(payload, fh, ensure_ascii=False, indent=2)

        write_json("tasks_success.json", succeeded)

        n_success = len(succeeded)
        write_json(
            "summary.json",
            {
                "total": cfg.n_envs,
                "success": n_success,
                "fail": cfg.n_envs - n_success,
                "success_target_ratio": float(cfg.success_target_ratio),
            },
        )

        print(f"Done. total={cfg.n_envs}, success={n_success}, output={root}")


def parse_args(argv=None):
    """Parse command-line options for the parallel pour data-collection run.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (handy for tests or programmatic use). ``None``
            keeps the original behaviour of reading the real command line.

    Returns:
        argparse.Namespace with ``n_envs``, ``dt``, ``substeps``, ``output``,
        ``close_force``, ``open_force`` and ``success_target_ratio``.

    Raises:
        SystemExit: via ``parser.error`` (exit code 2) for malformed options,
            or when ``--n_envs``/``--substeps`` are < 1 or ``--dt`` is <= 0.
            ``main`` computes ``1.0 / dt``, so a zero step would otherwise
            crash later with ZeroDivisionError.
    """
    parser = argparse.ArgumentParser(description="Parallel pour task for Stage1 data collection")
    parser.add_argument(
        "--n_envs", type=int, default=100,
        help="Number of parallel environments (one episode each); must be >= 1.",
    )
    parser.add_argument(
        "--dt", type=float, default=5e-2,
        help="Simulation time step; must be > 0. The recording rate is set to 1/dt.",
    )
    parser.add_argument(
        "--substeps", type=int, default=100,
        help="Physics substeps per simulation step; must be >= 1.",
    )
    parser.add_argument(
        "--output", type=str, default="results/pour_water_from_b2b/dataset_10",
        help="Output directory for the dataset (success/ and fail/ subfolders are created).",
    )
    parser.add_argument(
        "--close_force", type=float, default=-10.0,
        help="Gripper force command used when closing.",
    )
    parser.add_argument(
        "--open_force", type=float, default=2.0,
        help="Gripper force command used when opening.",
    )
    parser.add_argument(
        "--success_target_ratio", type=float, default=0.7,
        help="Threshold ratio passed to the pour success check.",
    )
    args = parser.parse_args(argv)

    # Validate numeric ranges up front so bad input yields a usage error
    # instead of a crash (or an empty run) deep inside the simulation.
    if args.n_envs < 1:
        parser.error("--n_envs must be >= 1")
    if args.substeps < 1:
        parser.error("--substeps must be >= 1")
    if not args.dt > 0:
        parser.error("--dt must be > 0")
    return args


def main():
    """Command-line entry point: build a ``TaskConfig`` from CLI flags and run the task.

    The recording rate is not a separate flag; it is derived from the time
    step as ``record_hz = 1 / dt``.
    """
    cli = parse_args()
    config = TaskConfig(
        n_envs=cli.n_envs,
        dt=cli.dt,
        substeps=cli.substeps,
        record_hz=1.0 / cli.dt,
        base_dir=Path(cli.output),
        close_force=cli.close_force,
        open_force=cli.open_force,
        success_target_ratio=cli.success_target_ratio,
    )
    ParallelPourTask(config).run()


if __name__ == "__main__":
    # Start a data-collection run only when executed as a script, not on import.
    main()
