Hand Retargeting (Fancy)

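Shadow Hand retargeting example, with additional costs that keep the hand in contact with the object. Before running, unzip the Shadow Hand URDF archive at retarget_helpers/hand/shadowhand_urdf.zip (relative to this script); the script loads shadowhand_urdf/shadow_hand_right.urdf from that directory.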

All examples can be run by first cloning the PyRoki repository, which includes the pyroki_snippets implementation details.

  1import time
  2from typing import Tuple, TypedDict
  3from pathlib import Path
  4import pickle
  5import trimesh
  6from scipy.spatial.transform import Rotation as R
  7
  8import jax
  9import jax.numpy as jnp
 10import jax_dataclasses as jdc
 11import jaxlie
 12import jaxls
 13import numpy as onp
 14import viser
 15from viser.extras import ViserUrdf
 16import yourdfpy
 17
 18import pyroki as pk
 19
 20from retarget_helpers._utils import (
 21    create_conn_tree,
 22    get_mapping_from_mano_to_shadow,
 23    MANO_TO_SHADOW_MAPPING,
 24)
 25
 26
 27class RetargetingWeights(TypedDict):
 28    local_alignment: float
 29    """Local alignment weight, by matching the relative joint/keypoint positions and angles."""
 30    global_alignment: float
 31    """Global alignment weight, by matching the keypoint positions to the robot."""
 32    contact: float
 33    """Contact weight, to maintain contact between the robot and the object."""
 34    contact_margin: float
 35    """Contact margin, to stop penalizing contact when the robot is already close to the object."""
 36    joint_smoothness: float
 37    """Joint smoothness weight."""
 38    root_smoothness: float
 39    """Root translation smoothness weight."""
 40
 41
 42def main():
 43    """Main function for hand retargeting."""
 44
 45    asset_dir = Path(__file__).parent / "retarget_helpers" / "hand"
 46
 47    robot_urdf_path = asset_dir / "shadowhand_urdf" / "shadow_hand_right.urdf"
 48
 49    def filename_handler(fname: str) -> str:
 50        base_path = robot_urdf_path.parent
 51        return yourdfpy.filename_handler_magic(fname, dir=base_path)
 52
 53    try:
 54        urdf = yourdfpy.URDF.load(robot_urdf_path, filename_handler=filename_handler)
 55    except FileNotFoundError:
 56        raise FileNotFoundError(
 57            "Please unzip the included URDF at `retarget_helpers/hand/shadowhand_urdf.zip`."
 58        )
 59
 60    robot = pk.Robot.from_urdf(urdf)
 61
 62    # Get the mapping from MANO to Shadow Hand joints.
 63    shadow_link_idx, mano_joint_idx = get_mapping_from_mano_to_shadow(robot)
 64
 65    # Create a mask for the MANO joints that are connected to the Shadow Hand.
 66    mano_mask = create_conn_tree(robot, shadow_link_idx)
 67
 68    # Load source motion data.
 69    dexycb_motion_path = asset_dir / "dexycb_motion.pkl"
 70    with open(dexycb_motion_path, "rb") as f:
 71        dexycb_motion_data = pickle.load(f, encoding="latin1")
 72
 73    # Load keypoints.
 74    keypoints = dexycb_motion_data["world_hand_joints"]
 75    assert not onp.isnan(keypoints).any()
 76    num_timesteps = keypoints.shape[0]
 77    num_mano_joints = len(MANO_TO_SHADOW_MAPPING)
 78
 79    # Load mano hand contact information -- these are lists of lists,
 80    # len(contact_points_per_frame) = num_timesteps,
 81    # len(contact_points_per_frame[i]) = number of contacts in frame i,
 82    contact_points_per_frame = dexycb_motion_data["contact_object_points"]
 83    contact_indices_per_frame = dexycb_motion_data["contact_joint_indices"]
 84
 85    # Now, we're going to pad this info + make a mask to indicate the padded regions.
 86    # We will also track the shadowhand joint indices, NOT the MANO joint indices.
 87    max_num_contacts = max(len(c) for c in contact_points_per_frame)
 88    padded_contact_points_per_frame = onp.zeros((num_timesteps, max_num_contacts, 3))
 89    padded_contact_indices_per_frame = onp.zeros(
 90        (num_timesteps, max_num_contacts), dtype=onp.int32
 91    )
 92    padded_contact_mask = onp.zeros((num_timesteps, max_num_contacts), dtype=onp.bool_)
 93    for i in range(num_timesteps):
 94        num_contacts = len(contact_points_per_frame[i])
 95        if num_contacts == 0:
 96            continue
 97        contact_shadowhand_indices = [
 98            robot.links.names.index(MANO_TO_SHADOW_MAPPING[j])
 99            for j in contact_indices_per_frame[i]
100        ]
101        padded_contact_points_per_frame[i, :num_contacts] = contact_points_per_frame[i]
102        padded_contact_indices_per_frame[i, :num_contacts] = contact_shadowhand_indices
103        padded_contact_mask[i, :num_contacts] = True
104
105    # Load the object.
106    object_mesh_vertices = dexycb_motion_data["object_mesh_vertices"]
107    object_mesh_faces = dexycb_motion_data["object_mesh_faces"]
108    object_pose_list = dexycb_motion_data["object_poses"]  # (N, 4, 4)
109    mesh = trimesh.Trimesh(object_mesh_vertices, object_mesh_faces)
110
111    server = viser.ViserServer()
112
113    # We will transform everything by the transform below, for aesthetics.
114    server.scene.add_frame(
115        "/scene_offset",
116        show_axes=False,
117        position=(-0.15415953, -0.73598871, 0.93434792),
118        wxyz=(-0.381870867, 0.92421569, 0.0, 2.0004992e-32),
119    )
120    base_frame = server.scene.add_frame("/scene_offset/base", show_axes=False)
121    urdf_vis = ViserUrdf(server, urdf, root_node_name="/scene_offset/base")
122    playing = server.gui.add_checkbox("playing", True)
123    timestep_slider = server.gui.add_slider("timestep", 0, num_timesteps - 1, 1, 0)
124    object_handle = server.scene.add_mesh_trimesh("/scene_offset/object", mesh)
125    server.scene.add_grid("/grid", 2.0, 2.0)
126
127    default_weights = RetargetingWeights(
128        local_alignment=10.0,
129        global_alignment=1.0,
130        contact=5.0,
131        contact_margin=0.01,
132        joint_smoothness=2.0,
133        root_smoothness=2.0,
134    )
135
136    weights = pk.viewer.WeightTuner(
137        server,
138        default_weights,  # type: ignore
139    )
140
141    Ts_world_root, joints = None, None
142
143    def generate_trajectory():
144        nonlocal Ts_world_root, joints
145        gen_button.disabled = True
146        Ts_world_root, joints = solve_retargeting(
147            robot=robot,
148            target_keypoints=keypoints,
149            shadow_hand_link_retarget_indices=shadow_link_idx,
150            mano_joint_retarget_indices=mano_joint_idx,
151            mano_mask=mano_mask,
152            contact_points_per_frame=jnp.array(padded_contact_points_per_frame),
153            contact_indices_per_frame=jnp.array(padded_contact_indices_per_frame),
154            contact_mask=jnp.array(padded_contact_mask),
155            weights=weights.get_weights(),  # type: ignore
156        )
157        gen_button.disabled = False
158
159    gen_button = server.gui.add_button("Retarget!")
160    gen_button.on_click(lambda _: generate_trajectory())
161
162    generate_trajectory()
163    assert Ts_world_root is not None and joints is not None
164
165    while True:
166        with server.atomic():
167            if playing.value:
168                timestep_slider.value = (timestep_slider.value + 1) % num_timesteps
169            tstep = timestep_slider.value
170            base_frame.wxyz = onp.array(Ts_world_root.wxyz_xyz[tstep][:4])
171            base_frame.position = onp.array(Ts_world_root.wxyz_xyz[tstep][4:])
172            urdf_vis.update_cfg(onp.array(joints[tstep]))
173
174            server.scene.add_point_cloud(
175                "/scene_offset/target_keypoints",
176                onp.array(keypoints[tstep]).reshape(-1, 3),
177                onp.array((0, 0, 255))[None]
178                .repeat(num_mano_joints, axis=0)
179                .reshape(-1, 3),
180                point_size=0.005,
181                point_shape="sparkle",
182            )
183            server.scene.add_point_cloud(
184                "/scene_offset/contact_points",
185                onp.array(contact_points_per_frame[tstep]).reshape(-1, 3),
186                onp.array((255, 0, 0))[None]
187                .repeat(len(contact_points_per_frame[tstep]), axis=0)
188                .reshape(-1, 3),
189                point_size=0.005,
190                point_shape="circle",
191            )
192            object_handle.position = object_pose_list[tstep][:3, 3]
193            object_handle.wxyz = R.from_matrix(object_pose_list[tstep][:3, :3]).as_quat(
194                scalar_first=True
195            )
196
197        time.sleep(0.05)
198
199
@jdc.jit
def solve_retargeting(
    robot: pk.Robot,
    target_keypoints: jnp.ndarray,
    shadow_hand_link_retarget_indices: jnp.ndarray,
    mano_joint_retarget_indices: jnp.ndarray,
    mano_mask: jnp.ndarray,
    contact_points_per_frame: jnp.ndarray,
    contact_indices_per_frame: jnp.ndarray,
    contact_mask: jnp.ndarray,
    weights: RetargetingWeights,
) -> Tuple[jaxlie.SE3, jnp.ndarray]:
    """Solve the retargeting problem as one least-squares problem over all timesteps.

    Args:
        robot: Robot model; provides forward kinematics and the joint variable class.
        target_keypoints: Source hand keypoints; the leading axis is the timestep.
        shadow_hand_link_retarget_indices: Robot link indices used to select
            link positions from the forward-kinematics output.
        mano_joint_retarget_indices: Keypoint indices paired one-to-one with
            ``shadow_hand_link_retarget_indices``.
        mano_mask: Mask multiplied into the N x N pairwise residuals of the
            local alignment cost (N = number of retargeted joints).
        contact_points_per_frame: Zero-padded contact points; ``main`` passes
            shape (timesteps, max_contacts, 3).
        contact_indices_per_frame: Robot link index for each padded contact;
            ``main`` passes shape (timesteps, max_contacts).
        contact_mask: Boolean mask that is True for real (non-padded) contacts.
        weights: Cost weights, see ``RetargetingWeights``.

    Returns:
        A tuple of the solved root transforms (with the solved offset applied
        as a left translation) and the solved joint configurations.
    """

    n_retarget = len(mano_joint_retarget_indices)
    timesteps = target_keypoints.shape[0]

    # Variables.
    class ManoJointsScaleVar(
        jaxls.Var[jax.Array], default_factory=lambda: jnp.ones((n_retarget, n_retarget))
    ): ...

    class OffsetVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((3,))): ...

    var_joints = robot.joint_var_cls(jnp.arange(timesteps))
    var_Ts_world_root = jaxls.SE3Var(jnp.arange(timesteps))
    # All ids are 0 here, so presumably a single scale / offset variable is
    # shared across every timestep -- TODO confirm against jaxls id semantics.
    var_smpl_joints_scale = ManoJointsScaleVar(jnp.zeros(timesteps))
    var_offset = OffsetVar(jnp.zeros(timesteps))

    # Costs.
    @jaxls.Cost.create_factory
    def retargeting_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        var_smpl_joints_scale: ManoJointsScaleVar,
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Retargeting factor, with a focus on:
        - matching the relative joint/keypoint positions (vectors).
        - and matching the relative angles between the vectors.
        """
        robot_cfg = var_values[var_robot_cfg]
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
        T_world_root = var_values[var_Ts_world_root]
        T_world_link = T_world_root @ T_root_link

        mano_pos = keypoints[jnp.array(mano_joint_retarget_indices)]
        robot_pos = T_world_link.translation()[
            jnp.array(shadow_hand_link_retarget_indices)
        ]

        # NxN grid of relative positions.
        delta_mano = mano_pos[:, None] - mano_pos[None, :]
        delta_robot = robot_pos[:, None] - robot_pos[None, :]

        # Vector regularization. The diagonal (a joint paired with itself) is
        # zeroed out, and mano_mask selects which pairs contribute.
        position_scale = var_values[var_smpl_joints_scale][..., None]
        residual_position_delta = (
            (delta_mano - delta_robot * position_scale)
            * (1 - jnp.eye(delta_mano.shape[0])[..., None])
            * mano_mask[..., None]
        )

        # Vector angle regularization. The small epsilon keeps the norm
        # non-zero for the zero-length diagonal vectors.
        delta_mano_normalized = delta_mano / jnp.linalg.norm(
            delta_mano + 1e-6, axis=-1, keepdims=True
        )
        delta_robot_normalized = delta_robot / jnp.linalg.norm(
            delta_robot + 1e-6, axis=-1, keepdims=True
        )
        # 1 - cos(angle) between corresponding human and robot vectors.
        residual_angle_delta = 1 - (delta_mano_normalized * delta_robot_normalized).sum(
            axis=-1
        )
        residual_angle_delta = (
            residual_angle_delta
            * (1 - jnp.eye(residual_angle_delta.shape[0]))
            * mano_mask
        )

        residual = (
            jnp.concatenate(
                [
                    residual_position_delta.flatten(),
                    residual_angle_delta.flatten(),
                ],
                axis=0,
            )
            * weights["local_alignment"]
        )
        return residual

    @jaxls.Cost.create_factory
    def scale_regularization(
        var_values: jaxls.VarValues,
        var_smpl_joints_scale: ManoJointsScaleVar,
    ) -> jax.Array:
        """Regularize the scale of the retargeted joints."""
        # Close to 1.
        res_0 = (var_values[var_smpl_joints_scale] - 1.0).flatten() * 1.0
        # Symmetric.
        res_1 = (
            var_values[var_smpl_joints_scale] - var_values[var_smpl_joints_scale].T
        ).flatten() * 100.0
        # Non-negative.
        res_2 = jnp.clip(-var_values[var_smpl_joints_scale], min=0).flatten() * 100.0
        return jnp.concatenate([res_0, res_1, res_2])

    @jaxls.Cost.create_factory
    def pc_alignment_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Soft cost to align the human keypoints to the robot, in the world frame."""
        T_world_root = var_values[var_Ts_world_root]
        robot_cfg = var_values[var_robot_cfg]
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
        T_world_link = T_world_root @ T_root_link
        link_pos = T_world_link.translation()[shadow_hand_link_retarget_indices]
        keypoint_pos = keypoints[mano_joint_retarget_indices]
        return (link_pos - keypoint_pos).flatten() * weights["global_alignment"]

    @jaxls.Cost.create_factory
    def root_smoothness(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_Ts_world_root_prev: jaxls.SE3Var,
    ) -> jax.Array:
        """Smoothness cost for the robot root translation."""
        return (
            var_values[var_Ts_world_root].translation()
            - var_values[var_Ts_world_root_prev].translation()
        ).flatten() * weights["root_smoothness"]

    @jaxls.Cost.create_factory
    def contact_cost(
        var_values: jaxls.VarValues,
        var_T_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        contact_points: jax.Array,
        contact_indices: jax.Array,
        contact_points_mask: jax.Array,
    ) -> jax.Array:
        """Cost for maintaining contact between specified robot links and object points.

        Shapes below assume jaxls evaluates this cost once per timestep by
        batching over the leading axis of the arguments -- TODO confirm. Under
        that assumption, with P = padded contacts per frame: contact_points is
        (P, 3), contact_indices is (P,) robot link indices, and
        contact_points_mask is (P,).
        """
        robot_cfg = var_values[var_robot_cfg]
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
        T_world_root = var_values[var_T_world_root]
        T_world_link = T_world_root @ T_root_link

        # World position of the link paired with each contact point.
        contact_joint_positions_world = T_world_link.translation()[contact_indices]

        # Contact points are already in world frame (as processed in dexycb).
        # Per-axis offset between each contact point and its paired link.
        residual = contact_points - contact_joint_positions_world

        # Penalize distance beyond a margin.
        residual_penalty = jnp.maximum(
            jnp.abs(residual) - weights["contact_margin"], 0.0
        )

        # Apply mask, so padded contact slots contribute nothing.
        residual_penalty = residual_penalty * contact_points_mask[..., None]
        residual = residual_penalty.flatten() * weights["contact"]

        return residual

    costs: list[jaxls.Cost] = [
        # Costs that are relatively self-contained to the robot.
        retargeting_cost(
            var_Ts_world_root,
            var_joints,
            var_smpl_joints_scale,
            target_keypoints,
        ),
        scale_regularization(var_smpl_joints_scale),
        pk.costs.limit_cost(
            jax.tree.map(lambda x: x[None], robot),
            var_joints,
            100.0,
        ),
        pk.costs.smoothness_cost(
            robot.joint_var_cls(jnp.arange(1, timesteps)),
            robot.joint_var_cls(jnp.arange(0, timesteps - 1)),
            jnp.array([weights["joint_smoothness"]]),
        ),
        pk.costs.rest_cost(
            var_joints,
            var_joints.default_factory()[None],
            jnp.array([0.2]),
        ),
        # Costs that are scene-centric.
        pc_alignment_cost(
            var_Ts_world_root,
            var_joints,
            target_keypoints,
        ),
        # Pairs timestep t (ids 1..T-1) with timestep t-1 (ids 0..T-2).
        root_smoothness(
            jaxls.SE3Var(jnp.arange(1, timesteps)),
            jaxls.SE3Var(jnp.arange(0, timesteps - 1)),
        ),
        contact_cost(
            var_T_world_root=var_Ts_world_root,
            var_robot_cfg=var_joints,
            contact_points=contact_points_per_frame,
            contact_indices=contact_indices_per_frame,
            contact_points_mask=contact_mask,
        ),
    ]

    solution = (
        jaxls.LeastSquaresProblem(
            costs, [var_joints, var_Ts_world_root, var_smpl_joints_scale, var_offset]
        )
        .analyze()
        .solve()
    )
    transform = solution[var_Ts_world_root]
    # NOTE(review): no cost above references var_offset, so it presumably stays
    # at its zero default and this composition is an identity -- confirm.
    offset = solution[var_offset]
    transform = jaxlie.SE3.from_translation(offset) @ transform
    return transform, solution[var_joints]
429
430
# Run the interactive retargeting demo only when executed as a script.
if __name__ == "__main__":
    main()