Hand Retargeting

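A simple Shadow Hand retargeting example. Before running it, find and unzip the Shadow Hand URDF archive at `retarget_helpers/hand/shadowhand_urdf.zip` (next to the example script); the script loads `shadowhand_urdf/shadow_hand_right.urdf` from the unzipped folder.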

All examples can be run by first cloning the PyRoki repository, which includes the pyroki_snippets implementation details.

  1import pickle
  2import time
  3from pathlib import Path
  4from typing import Tuple, TypedDict
  5
  6import jax
  7import jax.numpy as jnp
  8import jax_dataclasses as jdc
  9import jaxlie
 10import jaxls
 11import numpy as onp
 12import pyroki as pk
 13import trimesh
 14import viser
 15import yourdfpy
 16from scipy.spatial.transform import Rotation as R
 17from viser.extras import ViserUrdf
 18
 19from retarget_helpers._utils import (
 20    MANO_TO_SHADOW_MAPPING,
 21    create_conn_tree,
 22    get_mapping_from_mano_to_shadow,
 23)
 24
 25
 26class RetargetingWeights(TypedDict):
 27    local_alignment: float
 28    """Local alignment weight, by matching the relative joint/keypoint positions and angles."""
 29    global_alignment: float
 30    """Global alignment weight, by matching the keypoint positions to the robot."""
 31    joint_smoothness: float
 32    """Joint smoothness weight."""
 33    root_smoothness: float
 34    """Root translation smoothness weight."""
 35
 36
 37def main():
 38    """Main function for hand retargeting."""
 39
 40    asset_dir = Path(__file__).parent / "retarget_helpers" / "hand"
 41
 42    robot_urdf_path = asset_dir / "shadowhand_urdf" / "shadow_hand_right.urdf"
 43
 44    def filename_handler(fname: str) -> str:
 45        base_path = robot_urdf_path.parent
 46        return yourdfpy.filename_handler_magic(fname, dir=base_path)
 47
 48    try:
 49        urdf = yourdfpy.URDF.load(robot_urdf_path, filename_handler=filename_handler)
 50    except FileNotFoundError:
 51        raise FileNotFoundError(
 52            "Please unzip the included URDF at `retarget_helpers/hand/shadowhand_urdf.zip`."
 53        )
 54
 55    robot = pk.Robot.from_urdf(urdf)
 56
 57    # Get the mapping from MANO to Shadow Hand joints.
 58    shadow_link_idx, mano_joint_idx = get_mapping_from_mano_to_shadow(robot)
 59
 60    # Create a mask for the MANO joints that are connected to the Shadow Hand.
 61    mano_mask = create_conn_tree(robot, shadow_link_idx)
 62
 63    # Load source motion data.
 64    dexycb_motion_path = asset_dir / "dexycb_motion.pkl"
 65    with open(dexycb_motion_path, "rb") as f:
 66        dexycb_motion_data = pickle.load(f, encoding="latin1")
 67
 68    # Load keypoints.
 69    keypoints = dexycb_motion_data["world_hand_joints"]
 70    assert not onp.isnan(keypoints).any()
 71    num_timesteps = keypoints.shape[0]
 72    num_mano_joints = len(MANO_TO_SHADOW_MAPPING)
 73
 74    # Load mano hand contact information -- these are lists of lists,
 75    # len(contact_points_per_frame) = num_timesteps,
 76    # len(contact_points_per_frame[i]) = number of contacts in frame i,
 77    contact_points_per_frame = dexycb_motion_data["contact_object_points"]
 78    contact_indices_per_frame = dexycb_motion_data["contact_joint_indices"]
 79
 80    # Now, we're going to pad this info + make a mask to indicate the padded regions.
 81    # We will also track the shadowhand joint indices, NOT the MANO joint indices.
 82    max_num_contacts = max(len(c) for c in contact_points_per_frame)
 83    padded_contact_points_per_frame = onp.zeros((num_timesteps, max_num_contacts, 3))
 84    padded_contact_indices_per_frame = onp.zeros(
 85        (num_timesteps, max_num_contacts), dtype=onp.int32
 86    )
 87    padded_contact_mask = onp.zeros((num_timesteps, max_num_contacts), dtype=onp.bool_)
 88    for i in range(num_timesteps):
 89        num_contacts = len(contact_points_per_frame[i])
 90        if num_contacts == 0:
 91            continue
 92        contact_shadowhand_indices = [
 93            robot.links.names.index(MANO_TO_SHADOW_MAPPING[j])
 94            for j in contact_indices_per_frame[i]
 95        ]
 96        padded_contact_points_per_frame[i, :num_contacts] = contact_points_per_frame[i]
 97        padded_contact_indices_per_frame[i, :num_contacts] = contact_shadowhand_indices
 98        padded_contact_mask[i, :num_contacts] = True
 99
100    # Load the object.
101    object_mesh_vertices = dexycb_motion_data["object_mesh_vertices"]
102    object_mesh_faces = dexycb_motion_data["object_mesh_faces"]
103    object_pose_list = dexycb_motion_data["object_poses"]  # (N, 4, 4)
104    mesh = trimesh.Trimesh(object_mesh_vertices, object_mesh_faces)
105
106    server = viser.ViserServer()
107
108    # We will transform everything by the transform below, for aesthetics.
109    server.scene.add_frame(
110        "/scene_offset",
111        show_axes=False,
112        position=(-0.15415953, -0.73598871, 0.93434792),
113        wxyz=(-0.381870867, 0.92421569, 0.0, 2.0004992e-32),
114    )
115    hand_mesh = server.scene.add_mesh_simple(
116        "/scene_offset/hand_mesh",
117        vertices=dexycb_motion_data["world_hand_vertices"][0, :, :],
118        faces=dexycb_motion_data["hand_mesh_faces"],
119        opacity=0.5,
120    )
121    base_frame = server.scene.add_frame("/scene_offset/base", show_axes=False)
122    urdf_vis = ViserUrdf(server, urdf, root_node_name="/scene_offset/base")
123    playing = server.gui.add_checkbox("playing", True)
124    timestep_slider = server.gui.add_slider("timestep", 0, num_timesteps - 1, 1, 0)
125    object_handle = server.scene.add_mesh_trimesh("/scene_offset/object", mesh)
126    server.scene.add_grid("/grid", 2.0, 2.0)
127
128    default_weights = RetargetingWeights(
129        local_alignment=10.0,
130        global_alignment=1.0,
131        joint_smoothness=2.0,
132        root_smoothness=2.0,
133    )
134
135    weights = pk.viewer.WeightTuner(
136        server,
137        default_weights,  # type: ignore
138    )
139
140    Ts_world_root, joints = None, None
141
142    def generate_trajectory():
143        nonlocal Ts_world_root, joints
144        gen_button.disabled = True
145        Ts_world_root, joints = solve_retargeting(
146            robot=robot,
147            target_keypoints=keypoints,
148            shadow_hand_link_retarget_indices=shadow_link_idx,
149            mano_joint_retarget_indices=mano_joint_idx,
150            mano_mask=mano_mask,
151            weights=weights.get_weights(),  # type: ignore
152        )
153        gen_button.disabled = False
154
155    gen_button = server.gui.add_button("Retarget!")
156    gen_button.on_click(lambda _: generate_trajectory())
157
158    generate_trajectory()
159    assert Ts_world_root is not None and joints is not None
160
161    while True:
162        with server.atomic():
163            if playing.value:
164                timestep_slider.value = (timestep_slider.value + 1) % num_timesteps
165            tstep = timestep_slider.value
166            base_frame.wxyz = onp.array(Ts_world_root.wxyz_xyz[tstep][:4])
167            base_frame.position = onp.array(Ts_world_root.wxyz_xyz[tstep][4:])
168            urdf_vis.update_cfg(onp.array(joints[tstep]))
169
170            server.scene.add_point_cloud(
171                "/scene_offset/target_keypoints",
172                onp.array(keypoints[tstep]).reshape(-1, 3),
173                onp.array((0, 0, 255))[None]
174                .repeat(num_mano_joints, axis=0)
175                .reshape(-1, 3),
176                point_size=0.005,
177                point_shape="sparkle",
178            )
179            server.scene.add_point_cloud(
180                "/scene_offset/contact_points",
181                onp.array(contact_points_per_frame[tstep]).reshape(-1, 3),
182                onp.array((255, 0, 0))[None]
183                .repeat(len(contact_points_per_frame[tstep]), axis=0)
184                .reshape(-1, 3),
185                point_size=0.005,
186                point_shape="circle",
187            )
188            hand_mesh.vertices = dexycb_motion_data["world_hand_vertices"][tstep, :, :]
189            object_handle.position = object_pose_list[tstep][:3, 3]
190            object_handle.wxyz = R.from_matrix(object_pose_list[tstep][:3, :3]).as_quat(
191                scalar_first=True
192            )
193
194        time.sleep(0.05)
195
196
@jdc.jit
def solve_retargeting(
    robot: pk.Robot,
    target_keypoints: jnp.ndarray,
    shadow_hand_link_retarget_indices: jnp.ndarray,
    mano_joint_retarget_indices: jnp.ndarray,
    mano_mask: jnp.ndarray,
    weights: RetargetingWeights,
) -> Tuple[jaxlie.SE3, jnp.ndarray]:
    """Solve the retargeting problem.

    Builds a jaxls least-squares problem over per-timestep joint
    configurations and root poses, plus a pairwise keypoint scale variable and
    an offset variable, and returns the solved root poses and joints.

    Args:
        robot: Robot model used for forward kinematics and joint limits.
        target_keypoints: Source hand keypoints; the first axis is time
            (``shape[0]`` is used as the number of timesteps). Presumably
            shaped (timesteps, num_keypoints, 3) -- TODO confirm.
        shadow_hand_link_retarget_indices: Indices into the robot's link
            positions that are matched against the keypoints.
        mano_joint_retarget_indices: Indices into the keypoints that are
            matched against the robot links; its length sets the size of the
            pairwise grids below.
        mano_mask: Mask multiplied element-wise into the pairwise residual
            grids, so it must broadcast against (n_retarget, n_retarget).
        weights: Scalar weights for the individual cost terms.

    Returns:
        A tuple of the solved root transforms (left-multiplied by a pure
        translation built from the solved offset variable) and the solved
        joint configurations.
    """

    n_retarget = len(mano_joint_retarget_indices)
    timesteps = target_keypoints.shape[0]

    # Variables.
    # Pairwise scale applied to the robot's relative link vectors; starts at 1.
    class ManoJointsScaleVar(
        jaxls.Var[jax.Array], default_factory=lambda: jnp.ones((n_retarget, n_retarget))
    ): ...

    # 3-vector offset; starts at zero.
    class OffsetVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((3,))): ...

    # One joint configuration and one root pose per timestep.
    var_joints = robot.joint_var_cls(jnp.arange(timesteps))
    var_Ts_world_root = jaxls.SE3Var(jnp.arange(timesteps))
    # All ids are zero here, presumably so every timestep refers to the same
    # single scale / offset variable -- verify against jaxls id semantics.
    var_smpl_joints_scale = ManoJointsScaleVar(jnp.zeros(timesteps))
    var_offset = OffsetVar(jnp.zeros(timesteps))

    # Costs.
    # NOTE(review): this empty list is rebound to the real cost list further down.
    costs: list[jaxls.Cost] = []

    @jaxls.Cost.create_factory
    def retargeting_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        var_smpl_joints_scale: ManoJointsScaleVar,
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Retargeting factor, with a focus on:
        - matching the relative joint/keypoint positions (vectors).
        - and matching the relative angles between the vectors.
        """
        # Link poses in the world frame: root pose composed with forward kinematics.
        robot_cfg = var_values[var_robot_cfg]
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
        T_world_root = var_values[var_Ts_world_root]
        T_world_link = T_world_root @ T_root_link

        # Select the corresponding keypoints / link positions to compare.
        mano_pos = keypoints[jnp.array(mano_joint_retarget_indices)]
        robot_pos = T_world_link.translation()[
            jnp.array(shadow_hand_link_retarget_indices)
        ]

        # NxN grid of relative positions.
        delta_mano = mano_pos[:, None] - mano_pos[None, :]
        delta_robot = robot_pos[:, None] - robot_pos[None, :]

        # Vector regularization.
        # The robot deltas are scaled per pair; the (1 - eye) factor zeroes the
        # diagonal (a point relative to itself) and mano_mask drops masked pairs.
        position_scale = var_values[var_smpl_joints_scale][..., None]
        residual_position_delta = (
            (delta_mano - delta_robot * position_scale)
            * (1 - jnp.eye(delta_mano.shape[0])[..., None])
            * mano_mask[..., None]
        )

        # Vector angle regularization.
        # The 1e-6 is added to the vectors before taking the norm, which keeps
        # the divisor non-zero for the zero-length diagonal entries.
        delta_mano_normalized = delta_mano / jnp.linalg.norm(
            delta_mano + 1e-6, axis=-1, keepdims=True
        )
        delta_robot_normalized = delta_robot / jnp.linalg.norm(
            delta_robot + 1e-6, axis=-1, keepdims=True
        )
        # 1 - dot product of the normalized vectors: zero when directions agree.
        residual_angle_delta = 1 - (delta_mano_normalized * delta_robot_normalized).sum(
            axis=-1
        )
        residual_angle_delta = (
            residual_angle_delta
            * (1 - jnp.eye(residual_angle_delta.shape[0]))
            * mano_mask
        )

        # Stack both residuals into one vector, scaled by the local alignment weight.
        residual = (
            jnp.concatenate(
                [
                    residual_position_delta.flatten(),
                    residual_angle_delta.flatten(),
                ],
                axis=0,
            )
            * weights["local_alignment"]
        )
        return residual

    @jaxls.Cost.create_factory
    def pc_alignment_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Soft cost to align the human keypoints to the robot, in the world frame."""
        T_world_root = var_values[var_Ts_world_root]
        robot_cfg = var_values[var_robot_cfg]
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
        T_world_link = T_world_root @ T_root_link
        link_pos = T_world_link.translation()[shadow_hand_link_retarget_indices]
        keypoint_pos = keypoints[mano_joint_retarget_indices]
        # Direct position difference between matched links and keypoints.
        return (link_pos - keypoint_pos).flatten() * weights["global_alignment"]

    @jaxls.Cost.create_factory
    def root_smoothness(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_Ts_world_root_prev: jaxls.SE3Var,
    ) -> jax.Array:
        """Smoothness cost for the robot root translation."""
        # Only the translation part is compared; rotation is not penalized here.
        return (
            var_values[var_Ts_world_root].translation()
            - var_values[var_Ts_world_root_prev].translation()
        ).flatten() * weights["root_smoothness"]

    costs = [
        retargeting_cost(
            var_Ts_world_root,
            var_joints,
            var_smpl_joints_scale,
            target_keypoints,
        ),
        # Joint limit cost with a fixed weight of 100. The tree map adds a
        # leading axis to every leaf of `robot`, presumably so it broadcasts
        # against the per-timestep joint variables -- TODO confirm.
        pk.costs.limit_cost(
            jax.tree.map(lambda x: x[None], robot),
            var_joints,
            100.0,
        ),
        # Joint smoothness between timestep t (ids 1..T-1) and t-1 (ids 0..T-2).
        pk.costs.smoothness_cost(
            robot.joint_var_cls(jnp.arange(1, timesteps)),
            robot.joint_var_cls(jnp.arange(0, timesteps - 1)),
            jnp.array([weights["joint_smoothness"]]),
        ),
        pc_alignment_cost(
            var_Ts_world_root,
            var_joints,
            target_keypoints,
        ),
        # Root translation smoothness over the same consecutive-timestep pairs.
        root_smoothness(
            jaxls.SE3Var(jnp.arange(1, timesteps)),
            jaxls.SE3Var(jnp.arange(0, timesteps - 1)),
        ),
    ]

    solution = (
        jaxls.LeastSquaresProblem(
            costs, [var_joints, var_Ts_world_root, var_smpl_joints_scale, var_offset]
        )
        .analyze()
        .solve()
    )
    transform = solution[var_Ts_world_root]
    # NOTE(review): no cost above references var_offset, so the solved offset
    # is presumably still its zero default and this composition is a no-op --
    # confirm before relying on it.
    offset = solution[var_offset]
    transform = jaxlie.SE3.from_translation(offset) @ transform
    return transform, solution[var_joints]
355
356
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()