Hand Retargeting (Fancy)
Shadow Hand retargeting example, with costs to maintain contact with the object.
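Find and unzip the Shadow Hand URDF archive at assets/hand_retargeting/shadowhand_urdf.zip. The script below loads the extracted URDF from `retarget_helpers/hand/shadowhand_urdf/shadow_hand_right.urdf`, relative to the script's own directory.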
All examples can be run by first cloning the PyRoki repository, which includes the pyroki_snippets implementation details.
1import pickle
2import time
3from pathlib import Path
4from typing import Tuple, TypedDict
5
6import jax
7import jax.numpy as jnp
8import jax_dataclasses as jdc
9import jaxlie
10import jaxls
11import numpy as onp
12import pyroki as pk
13import trimesh
14import viser
15import yourdfpy
16from scipy.spatial.transform import Rotation as R
17from viser.extras import ViserUrdf
18
19from retarget_helpers._utils import (
20 MANO_TO_SHADOW_MAPPING,
21 create_conn_tree,
22 get_mapping_from_mano_to_shadow,
23)
24
25
# Weights (all floats) for the individual terms of the retargeting objective.
RetargetingWeights = TypedDict(
    "RetargetingWeights",
    {
        # Local alignment: match relative joint/keypoint positions and angles.
        "local_alignment": float,
        # Global alignment: match keypoint positions to the robot links.
        "global_alignment": float,
        # Contact: keep the robot in contact with the object.
        "contact": float,
        # Margin within which contact is no longer penalized (robot already close).
        "contact_margin": float,
        # Joint-space smoothness between consecutive frames.
        "joint_smoothness": float,
        # Root translation smoothness between consecutive frames.
        "root_smoothness": float,
    },
)
39
40
def _pad_contacts(
    robot: "pk.Robot",
    contact_points_per_frame,
    contact_indices_per_frame,
    num_timesteps: int,
) -> tuple[onp.ndarray, onp.ndarray, onp.ndarray]:
    """Pad ragged per-frame contact lists into dense arrays plus a validity mask.

    Args:
        robot: Robot whose link names are used to look up Shadow Hand links.
        contact_points_per_frame: One entry per frame; each entry is a list of
            3D object contact points for that frame.
        contact_indices_per_frame: One entry per frame; each entry lists the
            MANO joint index paired with the contact point at the same position.
        num_timesteps: Number of frames to fill.

    Returns:
        ``(points, link_indices, mask)`` with shapes ``(T, C, 3)``, ``(T, C)``
        and ``(T, C)``, where ``C`` is the largest per-frame contact count.
        Padded slots hold zeros and are ``False`` in ``mask``. ``link_indices``
        are Shadow Hand link indices, NOT MANO joint indices.
    """
    max_num_contacts = max((len(c) for c in contact_points_per_frame), default=0)
    points = onp.zeros((num_timesteps, max_num_contacts, 3))
    link_indices = onp.zeros((num_timesteps, max_num_contacts), dtype=onp.int32)
    mask = onp.zeros((num_timesteps, max_num_contacts), dtype=onp.bool_)
    for i in range(num_timesteps):
        frame_points = contact_points_per_frame[i]
        num_contacts = len(frame_points)
        if num_contacts == 0:
            continue
        points[i, :num_contacts] = frame_points
        link_indices[i, :num_contacts] = [
            robot.links.names.index(MANO_TO_SHADOW_MAPPING[j])
            for j in contact_indices_per_frame[i]
        ]
        mask[i, :num_contacts] = True
    return points, link_indices, mask


def main():
    """Run the Shadow Hand retargeting demo with an interactive viser viewer.

    Loads the Shadow Hand URDF and the DexYCB motion clip from
    ``retarget_helpers/hand`` next to this file and solves the retargeting
    problem once. It then plays the result back forever. The "Retarget!" button
    re-solves using the weights currently set in the GUI weight tuner.

    Raises:
        FileNotFoundError: If the Shadow Hand URDF has not been unzipped.
        ValueError: If the source keypoints contain NaN values.
    """
    asset_dir = Path(__file__).parent / "retarget_helpers" / "hand"
    robot_urdf_path = asset_dir / "shadowhand_urdf" / "shadow_hand_right.urdf"

    def filename_handler(fname: str) -> str:
        # Resolve files referenced by the URDF relative to the URDF's directory.
        return yourdfpy.filename_handler_magic(fname, dir=robot_urdf_path.parent)

    try:
        urdf = yourdfpy.URDF.load(robot_urdf_path, filename_handler=filename_handler)
    except FileNotFoundError as err:
        raise FileNotFoundError(
            "Please unzip the included URDF at `retarget_helpers/hand/shadowhand_urdf.zip`."
        ) from err

    robot = pk.Robot.from_urdf(urdf)

    # Get the mapping from MANO to Shadow Hand joints.
    shadow_link_idx, mano_joint_idx = get_mapping_from_mano_to_shadow(robot)

    # Create a mask for the MANO joints that are connected to the Shadow Hand.
    mano_mask = create_conn_tree(robot, shadow_link_idx)

    # Load source motion data.
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    dexycb_motion_path = asset_dir / "dexycb_motion.pkl"
    with open(dexycb_motion_path, "rb") as f:
        dexycb_motion_data = pickle.load(f, encoding="latin1")

    # Load keypoints; the leading axis is time.
    keypoints = dexycb_motion_data["world_hand_joints"]
    if onp.isnan(keypoints).any():
        # Explicit check (not `assert`) so it still runs under `python -O`.
        raise ValueError(f"Keypoints in {dexycb_motion_path} contain NaN values.")
    num_timesteps = keypoints.shape[0]

    # MANO contact info is ragged (a list of contacts per frame). Pad it into
    # dense arrays plus a mask, tracking Shadow Hand link indices.
    contact_points_per_frame = dexycb_motion_data["contact_object_points"]
    (
        padded_contact_points_per_frame,
        padded_contact_indices_per_frame,
        padded_contact_mask,
    ) = _pad_contacts(
        robot,
        contact_points_per_frame,
        dexycb_motion_data["contact_joint_indices"],
        num_timesteps,
    )
    # These inputs never change between solves, so convert them once.
    contact_points_jnp = jnp.array(padded_contact_points_per_frame)
    contact_indices_jnp = jnp.array(padded_contact_indices_per_frame)
    contact_mask_jnp = jnp.array(padded_contact_mask)

    # Load the object.
    mesh = trimesh.Trimesh(
        dexycb_motion_data["object_mesh_vertices"],
        dexycb_motion_data["object_mesh_faces"],
    )
    object_pose_list = dexycb_motion_data["object_poses"]  # (N, 4, 4)

    server = viser.ViserServer()

    # We will transform everything by the transform below, for aesthetics.
    server.scene.add_frame(
        "/scene_offset",
        show_axes=False,
        position=(-0.15415953, -0.73598871, 0.93434792),
        wxyz=(-0.381870867, 0.92421569, 0.0, 2.0004992e-32),
    )
    base_frame = server.scene.add_frame("/scene_offset/base", show_axes=False)
    urdf_vis = ViserUrdf(server, urdf, root_node_name="/scene_offset/base")
    playing = server.gui.add_checkbox("playing", True)
    timestep_slider = server.gui.add_slider("timestep", 0, num_timesteps - 1, 1, 0)
    object_handle = server.scene.add_mesh_trimesh("/scene_offset/object", mesh)
    server.scene.add_grid("/grid", 2.0, 2.0)

    default_weights = RetargetingWeights(
        local_alignment=10.0,
        global_alignment=1.0,
        contact=5.0,
        contact_margin=0.01,
        joint_smoothness=2.0,
        root_smoothness=2.0,
    )

    weights = pk.viewer.WeightTuner(
        server,
        default_weights,  # type: ignore
    )

    Ts_world_root, joints = None, None

    def generate_trajectory():
        """Re-solve the retargeting problem with the current GUI weights."""
        nonlocal Ts_world_root, joints
        gen_button.disabled = True
        try:
            Ts_world_root, joints = solve_retargeting(
                robot=robot,
                target_keypoints=keypoints,
                shadow_hand_link_retarget_indices=shadow_link_idx,
                mano_joint_retarget_indices=mano_joint_idx,
                mano_mask=mano_mask,
                contact_points_per_frame=contact_points_jnp,
                contact_indices_per_frame=contact_indices_jnp,
                contact_mask=contact_mask_jnp,
                weights=weights.get_weights(),  # type: ignore
            )
        finally:
            # Re-enable even if the solve raises, so the user can retry.
            gen_button.disabled = False

    gen_button = server.gui.add_button("Retarget!")
    gen_button.on_click(lambda _: generate_trajectory())

    generate_trajectory()
    assert Ts_world_root is not None and joints is not None

    while True:
        # Read the latest solution once per frame; the "Retarget!" callback
        # may replace it while the viewer is running.
        Ts_frame, joints_frame = Ts_world_root, joints
        with server.atomic():
            if playing.value:
                timestep_slider.value = (timestep_slider.value + 1) % num_timesteps
            tstep = timestep_slider.value

            pose = onp.array(Ts_frame.wxyz_xyz[tstep])  # [wxyz | xyz]
            base_frame.wxyz = pose[:4]
            base_frame.position = pose[4:]
            urdf_vis.update_cfg(onp.array(joints_frame[tstep]))

            # Size each color array from the points actually drawn, so the
            # color count always matches the point count.
            target_points = onp.array(keypoints[tstep]).reshape(-1, 3)
            server.scene.add_point_cloud(
                "/scene_offset/target_keypoints",
                target_points,
                onp.tile(onp.array((0, 0, 255)), (len(target_points), 1)),
                point_size=0.005,
                point_shape="sparkle",
            )
            contact_points = onp.array(contact_points_per_frame[tstep]).reshape(-1, 3)
            server.scene.add_point_cloud(
                "/scene_offset/contact_points",
                contact_points,
                onp.tile(onp.array((255, 0, 0)), (len(contact_points), 1)),
                point_size=0.005,
                point_shape="circle",
            )

            object_pose = object_pose_list[tstep]
            object_handle.position = object_pose[:3, 3]
            object_handle.wxyz = R.from_matrix(object_pose[:3, :3]).as_quat(
                scalar_first=True
            )

        time.sleep(0.05)
197
198
@jdc.jit
def solve_retargeting(
    robot: pk.Robot,
    target_keypoints: jnp.ndarray,
    shadow_hand_link_retarget_indices: jnp.ndarray,
    mano_joint_retarget_indices: jnp.ndarray,
    mano_mask: jnp.ndarray,
    contact_points_per_frame: jnp.ndarray,
    contact_indices_per_frame: jnp.ndarray,
    contact_mask: jnp.ndarray,
    weights: RetargetingWeights,
) -> Tuple[jaxlie.SE3, jnp.ndarray]:
    """Solve the retargeting problem for the whole clip at once.

    Every timestep's root pose and joint configuration is a variable of a single
    nonlinear least-squares problem. As a result, the smoothness terms couple
    neighboring frames.

    Args:
        robot: Shadow Hand kinematic model.
        target_keypoints: Source keypoints with time on the leading axis.
            ``target_keypoints[t][mano_joint_retarget_indices]`` are the
            keypoints matched at frame ``t``.
        shadow_hand_link_retarget_indices: Robot link indices, paired
            element-wise with ``mano_joint_retarget_indices``.
        mano_joint_retarget_indices: Keypoint indices to retarget (length n).
        mano_mask: Mask broadcastable to (n, n). Keypoint pairs whose entry is
            zero are ignored by the local alignment cost.
        contact_points_per_frame: Padded object contact points, (T, C, 3).
        contact_indices_per_frame: Robot link index per contact, (T, C).
        contact_mask: True for real (non-padded) contacts, (T, C).
        weights: Per-term cost weights.

    Returns:
        ``(Ts_world_root, joints)``: the per-timestep root transforms and joint
        configurations.
    """
    # Convert the index arrays once; several cost closures below use them.
    mano_idx = jnp.asarray(mano_joint_retarget_indices)
    shadow_idx = jnp.asarray(shadow_hand_link_retarget_indices)
    n_retarget = mano_idx.shape[0]
    timesteps = target_keypoints.shape[0]

    # Pairwise keypoint scale factors. Every timestep refers to variable id 0,
    # so the whole clip shares a single (n, n) scale matrix.
    class ManoJointsScaleVar(
        jaxls.Var[jax.Array], default_factory=lambda: jnp.ones((n_retarget, n_retarget))
    ): ...

    var_joints = robot.joint_var_cls(jnp.arange(timesteps))
    var_Ts_world_root = jaxls.SE3Var(jnp.arange(timesteps))
    var_smpl_joints_scale = ManoJointsScaleVar(jnp.zeros(timesteps))

    def world_link_translations(
        var_values: jaxls.VarValues,
        var_T_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
    ) -> jax.Array:
        """World-frame positions of every robot link for one timestep."""
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=var_values[var_robot_cfg]))
        T_world_link = var_values[var_T_world_root] @ T_root_link
        return T_world_link.translation()

    @jaxls.Cost.factory
    def retargeting_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        var_smpl_joints_scale: ManoJointsScaleVar,
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Local alignment between the source keypoints and the robot links.

        For every ordered pair of retargeted keypoints, this cost matches:
        - the relative position vectors (robot side scaled per pair), and
        - the directions of those vectors (via 1 - cosine similarity).

        Self-pairs and pairs zeroed by ``mano_mask`` contribute nothing.
        """
        robot_pos = world_link_translations(var_values, var_Ts_world_root, var_robot_cfg)[
            shadow_idx
        ]
        mano_pos = keypoints[mano_idx]

        # NxN grid of relative positions, (n, n, 3).
        delta_mano = mano_pos[:, None] - mano_pos[None, :]
        delta_robot = robot_pos[:, None] - robot_pos[None, :]

        # Vector regularization.
        position_scale = var_values[var_smpl_joints_scale][..., None]
        residual_position_delta = (
            (delta_mano - delta_robot * position_scale)
            * (1 - jnp.eye(delta_mano.shape[0])[..., None])
            * mano_mask[..., None]
        )

        # Vector angle regularization. The epsilon is added to the vector
        # before taking the norm, so the zero-length diagonal (self-pair)
        # vectors still get a nonzero norm; presumably this keeps the norm and
        # its gradient finite there. Those entries are masked out below.
        delta_mano_normalized = delta_mano / jnp.linalg.norm(
            delta_mano + 1e-6, axis=-1, keepdims=True
        )
        delta_robot_normalized = delta_robot / jnp.linalg.norm(
            delta_robot + 1e-6, axis=-1, keepdims=True
        )
        residual_angle_delta = 1 - (delta_mano_normalized * delta_robot_normalized).sum(
            axis=-1
        )
        residual_angle_delta = (
            residual_angle_delta
            * (1 - jnp.eye(residual_angle_delta.shape[0]))
            * mano_mask
        )

        return (
            jnp.concatenate(
                [
                    residual_position_delta.flatten(),
                    residual_angle_delta.flatten(),
                ],
                axis=0,
            )
            * weights["local_alignment"]
        )

    @jaxls.Cost.factory
    def scale_regularization(
        var_values: jaxls.VarValues,
        var_smpl_joints_scale: ManoJointsScaleVar,
    ) -> jax.Array:
        """Keep the pairwise scales close to 1, symmetric, and non-negative."""
        scale = var_values[var_smpl_joints_scale]
        # Close to 1.
        res_0 = (scale - 1.0).flatten() * 1.0
        # Symmetric.
        res_1 = (scale - scale.T).flatten() * 100.0
        # Non-negative: only negative entries produce a residual.
        res_2 = jnp.clip(-scale, min=0).flatten() * 100.0
        return jnp.concatenate([res_0, res_1, res_2])

    @jaxls.Cost.factory
    def pc_alignment_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Soft cost to align the human keypoints to the robot, in the world frame."""
        link_pos = world_link_translations(var_values, var_Ts_world_root, var_robot_cfg)[
            shadow_idx
        ]
        keypoint_pos = keypoints[mano_idx]
        return (link_pos - keypoint_pos).flatten() * weights["global_alignment"]

    @jaxls.Cost.factory
    def root_smoothness(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_Ts_world_root_prev: jaxls.SE3Var,
    ) -> jax.Array:
        """Smoothness cost for the robot root translation (rotation is not penalized)."""
        return (
            var_values[var_Ts_world_root].translation()
            - var_values[var_Ts_world_root_prev].translation()
        ).flatten() * weights["root_smoothness"]

    @jaxls.Cost.factory
    def contact_cost(
        var_values: jaxls.VarValues,
        var_T_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        contact_points: jax.Array,  # (C, 3): this frame's slice of the padded points.
        contact_indices: jax.Array,  # (C,): robot link index paired with each point.
        contact_points_mask: jax.Array,  # (C,): True for real (non-padded) contacts.
    ) -> jax.Array:
        """Keep each contacted robot link near its paired object point.

        Contact ``c`` pairs link ``contact_indices[c]`` with
        ``contact_points[c]``. Each coordinate of their offset is penalized only
        beyond ``weights["contact_margin"]``, so the margin acts as a per-axis
        box. Padded contacts are masked out.
        """
        link_positions_world = world_link_translations(
            var_values, var_T_world_root, var_robot_cfg
        )[contact_indices]  # (C, 3)

        # Contact points are already in world frame (as processed in dexycb).
        residual = contact_points - link_positions_world  # (C, 3)

        # Penalize distance beyond a margin, per axis.
        residual_penalty = jnp.maximum(
            jnp.abs(residual) - weights["contact_margin"], 0.0
        )  # (C, 3)

        # Zero out padded contacts.
        residual_penalty = residual_penalty * contact_points_mask[..., None]  # (C, 3)
        return residual_penalty.flatten() * weights["contact"]

    costs: list[jaxls.Cost] = [
        # Costs that are relatively self-contained to the robot.
        retargeting_cost(
            var_Ts_world_root,
            var_joints,
            var_smpl_joints_scale,
            target_keypoints,
        ),
        scale_regularization(var_smpl_joints_scale),
        pk.costs.smoothness_cost(
            robot.joint_var_cls(jnp.arange(1, timesteps)),
            robot.joint_var_cls(jnp.arange(0, timesteps - 1)),
            jnp.array([weights["joint_smoothness"]]),
        ),
        pk.costs.rest_cost(
            var_joints,
            var_joints.default_factory()[None],
            jnp.array([0.2]),
        ),
        # Costs that are scene-centric.
        pc_alignment_cost(
            var_Ts_world_root,
            var_joints,
            target_keypoints,
        ),
        root_smoothness(
            jaxls.SE3Var(jnp.arange(1, timesteps)),
            jaxls.SE3Var(jnp.arange(0, timesteps - 1)),
        ),
        contact_cost(
            var_T_world_root=var_Ts_world_root,
            var_robot_cfg=var_joints,
            contact_points=contact_points_per_frame,
            contact_indices=contact_indices_per_frame,
            contact_points_mask=contact_mask,
        ),
        # Joint limits; the robot gets a leading batch axis to match the
        # batched joint variables.
        pk.costs.limit_constraint(
            jax.tree.map(lambda x: x[None], robot),
            var_joints,
        ),
    ]

    # The previous version also optimized an "offset" variable. No cost ever
    # referenced it, so it stayed at zero and composing it was an identity.
    solution = (
        jaxls.LeastSquaresProblem(
            costs=costs,
            variables=[
                var_joints,
                var_Ts_world_root,
                var_smpl_joints_scale,
            ],
        )
        .analyze()
        .solve()
    )
    return solution[var_Ts_world_root], solution[var_joints]
436
437
if __name__ == "__main__":
    # Start the interactive demo only when run as a script, not on import.
    main()