Hand Retargeting (Fancy)

Shadow Hand retargeting example, with costs to maintain contact with the object. Before running, find and unzip the Shadow Hand URDF archive at retarget_helpers/hand/shadowhand_urdf.zip (the location the script below loads it from).

All examples can be run by first cloning the PyRoki repository, which includes the pyroki_snippets implementation details.

  1import pickle
  2import time
  3from pathlib import Path
  4from typing import Tuple, TypedDict
  5
  6import jax
  7import jax.numpy as jnp
  8import jax_dataclasses as jdc
  9import jaxlie
 10import jaxls
 11import numpy as onp
 12import pyroki as pk
 13import trimesh
 14import viser
 15import yourdfpy
 16from scipy.spatial.transform import Rotation as R
 17from viser.extras import ViserUrdf
 18
 19from retarget_helpers._utils import (
 20    MANO_TO_SHADOW_MAPPING,
 21    create_conn_tree,
 22    get_mapping_from_mano_to_shadow,
 23)
 24
 25
class RetargetingWeights(TypedDict):
    """Tunable weights (and one margin) read by the retargeting cost functions.

    An instance with default values is built in ``main`` and handed to the
    viewer's weight tuner; ``solve_retargeting`` looks each entry up by key.
    """

    local_alignment: float
    """Local alignment weight, by matching the relative joint/keypoint positions and angles."""
    global_alignment: float
    """Global alignment weight, by matching the keypoint positions to the robot."""
    contact: float
    """Contact weight, to maintain contact between the robot and the object."""
    contact_margin: float
    """Contact margin, to stop penalizing contact when the robot is already close to the object."""
    joint_smoothness: float
    """Joint smoothness weight."""
    root_smoothness: float
    """Root translation smoothness weight."""
 39
 40
 41def main():
 42    """Main function for hand retargeting."""
 43
 44    asset_dir = Path(__file__).parent / "retarget_helpers" / "hand"
 45
 46    robot_urdf_path = asset_dir / "shadowhand_urdf" / "shadow_hand_right.urdf"
 47
 48    def filename_handler(fname: str) -> str:
 49        base_path = robot_urdf_path.parent
 50        return yourdfpy.filename_handler_magic(fname, dir=base_path)
 51
 52    try:
 53        urdf = yourdfpy.URDF.load(robot_urdf_path, filename_handler=filename_handler)
 54    except FileNotFoundError:
 55        raise FileNotFoundError(
 56            "Please unzip the included URDF at `retarget_helpers/hand/shadowhand_urdf.zip`."
 57        )
 58
 59    robot = pk.Robot.from_urdf(urdf)
 60
 61    # Get the mapping from MANO to Shadow Hand joints.
 62    shadow_link_idx, mano_joint_idx = get_mapping_from_mano_to_shadow(robot)
 63
 64    # Create a mask for the MANO joints that are connected to the Shadow Hand.
 65    mano_mask = create_conn_tree(robot, shadow_link_idx)
 66
 67    # Load source motion data.
 68    dexycb_motion_path = asset_dir / "dexycb_motion.pkl"
 69    with open(dexycb_motion_path, "rb") as f:
 70        dexycb_motion_data = pickle.load(f, encoding="latin1")
 71
 72    # Load keypoints.
 73    keypoints = dexycb_motion_data["world_hand_joints"]
 74    assert not onp.isnan(keypoints).any()
 75    num_timesteps = keypoints.shape[0]
 76    num_mano_joints = len(MANO_TO_SHADOW_MAPPING)
 77
 78    # Load mano hand contact information -- these are lists of lists,
 79    # len(contact_points_per_frame) = num_timesteps,
 80    # len(contact_points_per_frame[i]) = number of contacts in frame i,
 81    contact_points_per_frame = dexycb_motion_data["contact_object_points"]
 82    contact_indices_per_frame = dexycb_motion_data["contact_joint_indices"]
 83
 84    # Now, we're going to pad this info + make a mask to indicate the padded regions.
 85    # We will also track the shadowhand joint indices, NOT the MANO joint indices.
 86    max_num_contacts = max(len(c) for c in contact_points_per_frame)
 87    padded_contact_points_per_frame = onp.zeros((num_timesteps, max_num_contacts, 3))
 88    padded_contact_indices_per_frame = onp.zeros(
 89        (num_timesteps, max_num_contacts), dtype=onp.int32
 90    )
 91    padded_contact_mask = onp.zeros((num_timesteps, max_num_contacts), dtype=onp.bool_)
 92    for i in range(num_timesteps):
 93        num_contacts = len(contact_points_per_frame[i])
 94        if num_contacts == 0:
 95            continue
 96        contact_shadowhand_indices = [
 97            robot.links.names.index(MANO_TO_SHADOW_MAPPING[j])
 98            for j in contact_indices_per_frame[i]
 99        ]
100        padded_contact_points_per_frame[i, :num_contacts] = contact_points_per_frame[i]
101        padded_contact_indices_per_frame[i, :num_contacts] = contact_shadowhand_indices
102        padded_contact_mask[i, :num_contacts] = True
103
104    # Load the object.
105    object_mesh_vertices = dexycb_motion_data["object_mesh_vertices"]
106    object_mesh_faces = dexycb_motion_data["object_mesh_faces"]
107    object_pose_list = dexycb_motion_data["object_poses"]  # (N, 4, 4)
108    mesh = trimesh.Trimesh(object_mesh_vertices, object_mesh_faces)
109
110    server = viser.ViserServer()
111
112    # We will transform everything by the transform below, for aesthetics.
113    server.scene.add_frame(
114        "/scene_offset",
115        show_axes=False,
116        position=(-0.15415953, -0.73598871, 0.93434792),
117        wxyz=(-0.381870867, 0.92421569, 0.0, 2.0004992e-32),
118    )
119    base_frame = server.scene.add_frame("/scene_offset/base", show_axes=False)
120    urdf_vis = ViserUrdf(server, urdf, root_node_name="/scene_offset/base")
121    playing = server.gui.add_checkbox("playing", True)
122    timestep_slider = server.gui.add_slider("timestep", 0, num_timesteps - 1, 1, 0)
123    object_handle = server.scene.add_mesh_trimesh("/scene_offset/object", mesh)
124    server.scene.add_grid("/grid", 2.0, 2.0)
125
126    default_weights = RetargetingWeights(
127        local_alignment=10.0,
128        global_alignment=1.0,
129        contact=5.0,
130        contact_margin=0.01,
131        joint_smoothness=2.0,
132        root_smoothness=2.0,
133    )
134
135    weights = pk.viewer.WeightTuner(
136        server,
137        default_weights,  # type: ignore
138    )
139
140    Ts_world_root, joints = None, None
141
142    def generate_trajectory():
143        nonlocal Ts_world_root, joints
144        gen_button.disabled = True
145        Ts_world_root, joints = solve_retargeting(
146            robot=robot,
147            target_keypoints=keypoints,
148            shadow_hand_link_retarget_indices=shadow_link_idx,
149            mano_joint_retarget_indices=mano_joint_idx,
150            mano_mask=mano_mask,
151            contact_points_per_frame=jnp.array(padded_contact_points_per_frame),
152            contact_indices_per_frame=jnp.array(padded_contact_indices_per_frame),
153            contact_mask=jnp.array(padded_contact_mask),
154            weights=weights.get_weights(),  # type: ignore
155        )
156        gen_button.disabled = False
157
158    gen_button = server.gui.add_button("Retarget!")
159    gen_button.on_click(lambda _: generate_trajectory())
160
161    generate_trajectory()
162    assert Ts_world_root is not None and joints is not None
163
164    while True:
165        with server.atomic():
166            if playing.value:
167                timestep_slider.value = (timestep_slider.value + 1) % num_timesteps
168            tstep = timestep_slider.value
169            base_frame.wxyz = onp.array(Ts_world_root.wxyz_xyz[tstep][:4])
170            base_frame.position = onp.array(Ts_world_root.wxyz_xyz[tstep][4:])
171            urdf_vis.update_cfg(onp.array(joints[tstep]))
172
173            server.scene.add_point_cloud(
174                "/scene_offset/target_keypoints",
175                onp.array(keypoints[tstep]).reshape(-1, 3),
176                onp.array((0, 0, 255))[None]
177                .repeat(num_mano_joints, axis=0)
178                .reshape(-1, 3),
179                point_size=0.005,
180                point_shape="sparkle",
181            )
182            server.scene.add_point_cloud(
183                "/scene_offset/contact_points",
184                onp.array(contact_points_per_frame[tstep]).reshape(-1, 3),
185                onp.array((255, 0, 0))[None]
186                .repeat(len(contact_points_per_frame[tstep]), axis=0)
187                .reshape(-1, 3),
188                point_size=0.005,
189                point_shape="circle",
190            )
191            object_handle.position = object_pose_list[tstep][:3, 3]
192            object_handle.wxyz = R.from_matrix(object_pose_list[tstep][:3, :3]).as_quat(
193                scalar_first=True
194            )
195
196        time.sleep(0.05)
197
198
@jdc.jit
def solve_retargeting(
    robot: pk.Robot,
    target_keypoints: jnp.ndarray,
    shadow_hand_link_retarget_indices: jnp.ndarray,
    mano_joint_retarget_indices: jnp.ndarray,
    mano_mask: jnp.ndarray,
    contact_points_per_frame: jnp.ndarray,
    contact_indices_per_frame: jnp.ndarray,
    contact_mask: jnp.ndarray,
    weights: RetargetingWeights,
) -> Tuple[jaxlie.SE3, jnp.ndarray]:
    """Solve the retargeting problem as one least-squares problem over all frames.

    Args:
        robot: Robot model; its forward kinematics give per-link poses.
        target_keypoints: Human hand keypoints. The first axis is time
            (``shape[0]`` is used as the number of timesteps).
        shadow_hand_link_retarget_indices: Robot link indices used to pick
            link positions out of the forward-kinematics result.
        mano_joint_retarget_indices: Keypoint indices paired with the robot
            link indices above.
        mano_mask: Mask multiplied into the pairwise (N x N) residual grids of
            the local alignment cost.
        contact_points_per_frame: Padded contact points; ``main`` passes an
            array of shape (num_timesteps, max_num_contacts, 3).
        contact_indices_per_frame: Padded robot link index for each contact;
            ``main`` passes shape (num_timesteps, max_num_contacts).
        contact_mask: True for real contacts, False for padding; same leading
            shape as ``contact_indices_per_frame``.
        weights: Cost weights, looked up by key inside the cost functions.

    Returns:
        A tuple ``(transform, joints)``: the solved root transforms with the
        solved offset applied as a left translation, and the solved joint
        configurations.
    """

    n_retarget = len(mano_joint_retarget_indices)
    timesteps = target_keypoints.shape[0]

    # Variables.
    # Pairwise scale between keypoint-pair vectors and robot-link-pair vectors;
    # defaults to all ones.
    class ManoJointsScaleVar(
        jaxls.Var[jax.Array], default_factory=lambda: jnp.ones((n_retarget, n_retarget))
    ): ...

    # 3-vector translation offset; defaults to zero.
    class OffsetVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((3,))): ...

    # One joint configuration and one root pose per timestep (ids 0..timesteps-1).
    var_joints = robot.joint_var_cls(jnp.arange(timesteps))
    var_Ts_world_root = jaxls.SE3Var(jnp.arange(timesteps))
    # Every entry uses id 0 -- presumably a single scale / offset variable
    # shared by all timesteps; confirm against jaxls variable-id semantics.
    var_smpl_joints_scale = ManoJointsScaleVar(jnp.zeros(timesteps))
    var_offset = OffsetVar(jnp.zeros(timesteps))

    # Costs and constraints.
    # Placeholder; rebound to the full list of costs further down.
    costs: list[jaxls.Cost] = []

    @jaxls.Cost.factory
    def retargeting_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        var_smpl_joints_scale: ManoJointsScaleVar,
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Retargeting factor, with a focus on:
        - matching the relative joint/keypoint positions (vectors).
        - and matching the relative angles between the vectors.

        Returns the flattened position and angle residuals, concatenated and
        scaled by ``weights["local_alignment"]``.
        """
        robot_cfg = var_values[var_robot_cfg]
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
        T_world_root = var_values[var_Ts_world_root]
        T_world_link = T_world_root @ T_root_link

        mano_pos = keypoints[jnp.array(mano_joint_retarget_indices)]
        robot_pos = T_world_link.translation()[
            jnp.array(shadow_hand_link_retarget_indices)
        ]

        # NxN grid of relative positions.
        delta_mano = mano_pos[:, None] - mano_pos[None, :]
        delta_robot = robot_pos[:, None] - robot_pos[None, :]

        # Vector regularization.
        # The identity term zeroes the diagonal (a point paired with itself);
        # mano_mask then selects which pairs contribute.
        position_scale = var_values[var_smpl_joints_scale][..., None]
        residual_position_delta = (
            (delta_mano - delta_robot * position_scale)
            * (1 - jnp.eye(delta_mano.shape[0])[..., None])
            * mano_mask[..., None]
        )

        # Vector angle regularization.
        # The 1e-6 is added to the vector before taking the norm, which keeps
        # the zero-length diagonal entries from dividing by zero.
        delta_mano_normalized = delta_mano / jnp.linalg.norm(
            delta_mano + 1e-6, axis=-1, keepdims=True
        )
        delta_robot_normalized = delta_robot / jnp.linalg.norm(
            delta_robot + 1e-6, axis=-1, keepdims=True
        )
        # One minus the dot product of the normalized vectors: zero when the
        # two directions agree.
        residual_angle_delta = 1 - (delta_mano_normalized * delta_robot_normalized).sum(
            axis=-1
        )
        residual_angle_delta = (
            residual_angle_delta
            * (1 - jnp.eye(residual_angle_delta.shape[0]))
            * mano_mask
        )

        residual = (
            jnp.concatenate(
                [
                    residual_position_delta.flatten(),
                    residual_angle_delta.flatten(),
                ],
                axis=0,
            )
            * weights["local_alignment"]
        )
        return residual

    @jaxls.Cost.factory
    def scale_regularization(
        var_values: jaxls.VarValues,
        var_smpl_joints_scale: ManoJointsScaleVar,
    ) -> jax.Array:
        """Regularize the scale of the retargeted joints.

        Three residual groups, with fixed weights 1, 100 and 100 respectively.
        """
        # Close to 1.
        res_0 = (var_values[var_smpl_joints_scale] - 1.0).flatten() * 1.0
        # Symmetric.
        res_1 = (
            var_values[var_smpl_joints_scale] - var_values[var_smpl_joints_scale].T
        ).flatten() * 100.0
        # Non-negative.
        res_2 = jnp.clip(-var_values[var_smpl_joints_scale], min=0).flatten() * 100.0
        return jnp.concatenate([res_0, res_1, res_2])

    @jaxls.Cost.factory
    def pc_alignment_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Soft cost to align the human keypoints to the robot, in the world frame.

        Residual is the per-pair position difference (link minus keypoint),
        scaled by ``weights["global_alignment"]``.
        """
        T_world_root = var_values[var_Ts_world_root]
        robot_cfg = var_values[var_robot_cfg]
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
        T_world_link = T_world_root @ T_root_link
        link_pos = T_world_link.translation()[shadow_hand_link_retarget_indices]
        keypoint_pos = keypoints[mano_joint_retarget_indices]
        return (link_pos - keypoint_pos).flatten() * weights["global_alignment"]

    @jaxls.Cost.factory
    def root_smoothness(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_Ts_world_root_prev: jaxls.SE3Var,
    ) -> jax.Array:
        """Smoothness cost for the robot root translation.

        Only the translation difference between the two root poses is
        penalized; rotation does not enter this residual.
        """
        return (
            var_values[var_Ts_world_root].translation()
            - var_values[var_Ts_world_root_prev].translation()
        ).flatten() * weights["root_smoothness"]

    # NOTE(review): the shape annotations below are from the original code and
    # do not obviously match what main() passes, which is
    # (num_timesteps, max_num_contacts, 3) / (num_timesteps, max_num_contacts)
    # arrays. Presumably jaxls evaluates the cost per timestep, giving
    # (max_num_contacts, 3) points and (max_num_contacts,) indices/mask here --
    # TODO confirm.
    @jaxls.Cost.factory
    def contact_cost(
        var_values: jaxls.VarValues,
        var_T_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        contact_points: jax.Array,  # (J, P, 3)
        contact_indices: jax.Array,  # (J,) - Robot link indices (built from robot.links.names in main).
        contact_points_mask: jax.Array,  # (J, P)
    ) -> jax.Array:
        """Cost for maintaining contact between specified robot links and object points."""
        robot_cfg = var_values[var_robot_cfg]
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
        T_world_root = var_values[var_T_world_root]
        T_world_link = T_world_root @ T_root_link

        # World-frame position of the link associated with each contact.
        contact_joint_positions_world = T_world_link.translation()[contact_indices]

        # Contact points are already in world frame (as processed in dexycb).
        # The residual is the per-axis difference between each contact point
        # and the position of its associated link.
        residual = contact_points - contact_joint_positions_world

        # Penalize distance beyond a margin.
        # Applied per axis: differences within the margin contribute zero.
        residual_penalty = jnp.maximum(
            jnp.abs(residual) - weights["contact_margin"], 0.0
        )

        # Apply mask.
        # Padded (masked-out) contacts contribute zero residual.
        residual_penalty = (
            residual_penalty * contact_points_mask[..., None]
        )
        residual = residual_penalty.flatten() * weights["contact"]

        return residual

    costs = [
        # Costs that are relatively self-contained to the robot.
        retargeting_cost(
            var_Ts_world_root,
            var_joints,
            var_smpl_joints_scale,
            target_keypoints,
        ),
        scale_regularization(var_smpl_joints_scale),
        # Pairs timestep t (ids 1..T-1) with timestep t-1 (ids 0..T-2).
        pk.costs.smoothness_cost(
            robot.joint_var_cls(jnp.arange(1, timesteps)),
            robot.joint_var_cls(jnp.arange(0, timesteps - 1)),
            jnp.array([weights["joint_smoothness"]]),
        ),
        # Presumably pulls joints toward the joint variable's default
        # configuration, with a fixed weight of 0.2.
        pk.costs.rest_cost(
            var_joints,
            var_joints.default_factory()[None],
            jnp.array([0.2]),
        ),
        # Costs that are scene-centric.
        pc_alignment_cost(
            var_Ts_world_root,
            var_joints,
            target_keypoints,
        ),
        # Same consecutive-timestep pairing as the joint smoothness cost.
        root_smoothness(
            jaxls.SE3Var(jnp.arange(1, timesteps)),
            jaxls.SE3Var(jnp.arange(0, timesteps - 1)),
        ),
        contact_cost(
            var_T_world_root=var_Ts_world_root,
            var_robot_cfg=var_joints,
            contact_points=contact_points_per_frame,
            contact_indices=contact_indices_per_frame,
            contact_points_mask=contact_mask,
        ),
    ]

    # Joint limits. The tree map adds a leading axis to every array leaf of
    # `robot` -- presumably so it lines up with the batched joint variables.
    costs.append(
        pk.costs.limit_constraint(
            jax.tree.map(lambda x: x[None], robot),
            var_joints,
        ),
    )

    solution = (
        jaxls.LeastSquaresProblem(
            costs=costs,
            variables=[
                var_joints,
                var_Ts_world_root,
                var_smpl_joints_scale,
                var_offset,
            ],
        )
        .analyze()
        .solve()
    )
    transform = solution[var_Ts_world_root]
    # NOTE(review): no cost defined in this function takes var_offset, so it
    # presumably stays at its zero default and this composition leaves the
    # transform unchanged -- confirm.
    offset = solution[var_offset]
    transform = jaxlie.SE3.from_translation(offset) @ transform
    return transform, solution[var_joints]
436
437
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()