Hand Retargeting (Fancy)
Shadow Hand retargeting example, with costs to maintain contact with the object.
Find and unzip the shadowhand URDF at `assets/hand_retargeting/shadowhand_urdf.zip`.
All examples can be run by first cloning the PyRoki repository, which includes the `pyroki_snippets` implementation details.
1import time
2from typing import Tuple, TypedDict
3from pathlib import Path
4import pickle
5import trimesh
6from scipy.spatial.transform import Rotation as R
7
8import jax
9import jax.numpy as jnp
10import jax_dataclasses as jdc
11import jaxlie
12import jaxls
13import numpy as onp
14import viser
15from viser.extras import ViserUrdf
16import yourdfpy
17
18import pyroki as pk
19
20from retarget_helpers._utils import (
21 create_conn_tree,
22 get_mapping_from_mano_to_shadow,
23 MANO_TO_SHADOW_MAPPING,
24)
25
26
class RetargetingWeights(TypedDict):
    """Weights (plus the contact margin) for the retargeting costs in `solve_retargeting`."""

    # Local alignment weight: match relative joint/keypoint offset vectors and
    # the angles between them.
    local_alignment: float
    # Global alignment weight: match absolute keypoint positions to robot link
    # positions in the world frame.
    global_alignment: float
    # Contact weight: keep robot links near their recorded object contact points.
    contact: float
    # Contact margin: per-axis distance under which contact is no longer penalized.
    contact_margin: float
    # Joint smoothness weight (consecutive timesteps).
    joint_smoothness: float
    # Root translation smoothness weight (consecutive timesteps).
    root_smoothness: float
40
41
def pad_contacts(
    num_timesteps: int,
    contact_points_per_frame,
    contact_indices_per_frame,
    link_names,
    mano_to_link_name,
) -> Tuple[onp.ndarray, onp.ndarray, onp.ndarray]:
    """Pad ragged per-frame contact lists into dense arrays plus a validity mask.

    Args:
        num_timesteps: Number of frames to emit (rows of every output array).
        contact_points_per_frame: Per-frame lists of 3D object contact points;
            a frame may contain no contacts at all.
        contact_indices_per_frame: Per-frame lists of MANO joint indices, one per
            contact point of that frame.
        link_names: Robot link names; a contact's robot link index is the
            position of its mapped name in this sequence (via ``.index``).
        mano_to_link_name: Indexable lookup from a MANO joint index to a robot
            link name (e.g. ``MANO_TO_SHADOW_MAPPING``).

    Returns:
        ``(points, link_indices, mask)`` with shapes ``(T, C, 3)`` float,
        ``(T, C)`` int32 and ``(T, C)`` bool, where ``T = num_timesteps`` and
        ``C`` is the largest contact count over all frames (0 if none).
        Padded slots are zero and have ``mask == False``.

    Lookup errors from ``mano_to_link_name[...]`` or ``link_names.index(...)``
    propagate unchanged.
    """
    # `default=0` keeps an empty frame list from crashing `max`.
    max_num_contacts = max((len(c) for c in contact_points_per_frame), default=0)
    points = onp.zeros((num_timesteps, max_num_contacts, 3))
    link_indices = onp.zeros((num_timesteps, max_num_contacts), dtype=onp.int32)
    mask = onp.zeros((num_timesteps, max_num_contacts), dtype=onp.bool_)
    for i in range(num_timesteps):
        frame_points = contact_points_per_frame[i]
        num_contacts = len(frame_points)
        if num_contacts == 0:
            continue
        points[i, :num_contacts] = frame_points
        # Track robot link indices, NOT the MANO joint indices.
        link_indices[i, :num_contacts] = [
            link_names.index(mano_to_link_name[j])
            for j in contact_indices_per_frame[i]
        ]
        mask[i, :num_contacts] = True
    return points, link_indices, mask


def main():
    """Retarget DexYCB MANO hand motion onto the Shadow Hand and show it in viser.

    Loads the Shadow Hand URDF and the recorded hand/object motion. Solves the
    retargeting problem once at startup, and again whenever "Retarget!" is
    clicked. Then plays the result back in an endless loop.
    """
    asset_dir = Path(__file__).parent / "retarget_helpers" / "hand"
    robot_urdf_path = asset_dir / "shadowhand_urdf" / "shadow_hand_right.urdf"

    def filename_handler(fname: str) -> str:
        # Resolve files referenced by the URDF relative to the URDF's directory.
        return yourdfpy.filename_handler_magic(fname, dir=robot_urdf_path.parent)

    try:
        urdf = yourdfpy.URDF.load(robot_urdf_path, filename_handler=filename_handler)
    except FileNotFoundError as err:
        raise FileNotFoundError(
            "Please unzip the included URDF at `retarget_helpers/hand/shadowhand_urdf.zip`."
        ) from err

    robot = pk.Robot.from_urdf(urdf)

    # Get the mapping from MANO to Shadow Hand joints.
    shadow_link_idx, mano_joint_idx = get_mapping_from_mano_to_shadow(robot)

    # Mask over the retargeted MANO joint pairs (built by `create_conn_tree`).
    mano_mask = create_conn_tree(robot, shadow_link_idx)

    # Load source motion data.
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    dexycb_motion_path = asset_dir / "dexycb_motion.pkl"
    with open(dexycb_motion_path, "rb") as f:
        dexycb_motion_data = pickle.load(f, encoding="latin1")

    # Load keypoints; the leading axis is time.
    keypoints = dexycb_motion_data["world_hand_joints"]
    if onp.isnan(keypoints).any():
        raise ValueError("Keypoints contain NaN values.")
    num_timesteps = keypoints.shape[0]

    # Contact information is ragged: lists of lists, one inner list per frame.
    # Pad it to dense arrays with a mask marking the real (non-padded) entries.
    contact_points_per_frame = dexycb_motion_data["contact_object_points"]
    (
        padded_contact_points_per_frame,
        padded_contact_indices_per_frame,
        padded_contact_mask,
    ) = pad_contacts(
        num_timesteps,
        contact_points_per_frame,
        dexycb_motion_data["contact_joint_indices"],
        robot.links.names,
        MANO_TO_SHADOW_MAPPING,
    )

    # Load the object.
    object_mesh_vertices = dexycb_motion_data["object_mesh_vertices"]
    object_mesh_faces = dexycb_motion_data["object_mesh_faces"]
    object_pose_list = dexycb_motion_data["object_poses"]  # (N, 4, 4)
    mesh = trimesh.Trimesh(object_mesh_vertices, object_mesh_faces)

    server = viser.ViserServer()

    # We will transform everything by the transform below, for aesthetics.
    server.scene.add_frame(
        "/scene_offset",
        show_axes=False,
        position=(-0.15415953, -0.73598871, 0.93434792),
        wxyz=(-0.381870867, 0.92421569, 0.0, 2.0004992e-32),
    )
    base_frame = server.scene.add_frame("/scene_offset/base", show_axes=False)
    urdf_vis = ViserUrdf(server, urdf, root_node_name="/scene_offset/base")
    playing = server.gui.add_checkbox("playing", True)
    timestep_slider = server.gui.add_slider("timestep", 0, num_timesteps - 1, 1, 0)
    object_handle = server.scene.add_mesh_trimesh("/scene_offset/object", mesh)
    server.scene.add_grid("/grid", 2.0, 2.0)

    default_weights = RetargetingWeights(
        local_alignment=10.0,
        global_alignment=1.0,
        contact=5.0,
        contact_margin=0.01,
        joint_smoothness=2.0,
        root_smoothness=2.0,
    )

    weights = pk.viewer.WeightTuner(
        server,
        default_weights,  # type: ignore
    )

    Ts_world_root, joints = None, None

    def generate_trajectory():
        nonlocal Ts_world_root, joints
        gen_button.disabled = True
        try:
            Ts_world_root, joints = solve_retargeting(
                robot=robot,
                target_keypoints=keypoints,
                shadow_hand_link_retarget_indices=shadow_link_idx,
                mano_joint_retarget_indices=mano_joint_idx,
                mano_mask=mano_mask,
                contact_points_per_frame=jnp.array(padded_contact_points_per_frame),
                contact_indices_per_frame=jnp.array(padded_contact_indices_per_frame),
                contact_mask=jnp.array(padded_contact_mask),
                weights=weights.get_weights(),  # type: ignore
            )
        finally:
            # Re-enable the button even if the solve fails, so the user can retry.
            gen_button.disabled = False

    gen_button = server.gui.add_button("Retarget!")
    gen_button.on_click(lambda _: generate_trajectory())

    generate_trajectory()
    assert Ts_world_root is not None and joints is not None

    while True:
        with server.atomic():
            if playing.value:
                timestep_slider.value = (timestep_slider.value + 1) % num_timesteps
            tstep = timestep_slider.value
            base_frame.wxyz = onp.array(Ts_world_root.wxyz_xyz[tstep][:4])
            base_frame.position = onp.array(Ts_world_root.wxyz_xyz[tstep][4:])
            urdf_vis.update_cfg(onp.array(joints[tstep]))

            # Size each color array from the points actually drawn, so the
            # point and color counts always agree.
            target_points = onp.array(keypoints[tstep]).reshape(-1, 3)
            server.scene.add_point_cloud(
                "/scene_offset/target_keypoints",
                target_points,
                onp.array((0, 0, 255))[None].repeat(len(target_points), axis=0),
                point_size=0.005,
                point_shape="sparkle",
            )
            contact_points = onp.array(contact_points_per_frame[tstep]).reshape(-1, 3)
            server.scene.add_point_cloud(
                "/scene_offset/contact_points",
                contact_points,
                onp.array((255, 0, 0))[None].repeat(len(contact_points), axis=0),
                point_size=0.005,
                point_shape="circle",
            )
            object_handle.position = object_pose_list[tstep][:3, 3]
            object_handle.wxyz = R.from_matrix(object_pose_list[tstep][:3, :3]).as_quat(
                scalar_first=True
            )

        time.sleep(0.05)
198
199
@jdc.jit
def solve_retargeting(
    robot: pk.Robot,
    target_keypoints: jnp.ndarray,
    shadow_hand_link_retarget_indices: jnp.ndarray,
    mano_joint_retarget_indices: jnp.ndarray,
    mano_mask: jnp.ndarray,
    contact_points_per_frame: jnp.ndarray,
    contact_indices_per_frame: jnp.ndarray,
    contact_mask: jnp.ndarray,
    weights: RetargetingWeights,
) -> Tuple[jaxlie.SE3, jnp.ndarray]:
    """Solve the retargeting problem for the whole trajectory at once.

    Builds a single nonlinear least-squares problem over every timestep and
    solves it with jaxls.

    Args:
        robot: Robot model (the Shadow Hand).
        target_keypoints: MANO keypoints; the leading axis is time.
        shadow_hand_link_retarget_indices: Robot link indices, paired
            elementwise with ``mano_joint_retarget_indices``.
        mano_joint_retarget_indices: MANO keypoint indices to retarget.
        mano_mask: Pairwise mask broadcast against the ``(N, N)`` grid of
            relative vectors in the local alignment cost.
        contact_points_per_frame: Padded object contact points per timestep.
        contact_indices_per_frame: Robot link index of each padded contact.
        contact_mask: ``True`` for real contacts, ``False`` for padding.
        weights: Cost weights and contact margin.

    Returns:
        ``(Ts_world_root, joints)``: per-timestep root poses and joint configs.
    """
    n_retarget = len(mano_joint_retarget_indices)
    timesteps = target_keypoints.shape[0]

    # Convert the index sets once; every cost below indexes with these.
    link_idx = jnp.asarray(shadow_hand_link_retarget_indices)
    mano_idx = jnp.asarray(mano_joint_retarget_indices)

    # Variables.
    class ManoJointsScaleVar(
        jaxls.Var[jax.Array], default_factory=lambda: jnp.ones((n_retarget, n_retarget))
    ): ...

    class OffsetVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((3,))): ...

    var_joints = robot.joint_var_cls(jnp.arange(timesteps))
    var_Ts_world_root = jaxls.SE3Var(jnp.arange(timesteps))
    # All ids are zero, so every timestep refers to the same single scale
    # matrix (and the same single offset).
    var_smpl_joints_scale = ManoJointsScaleVar(jnp.zeros(timesteps))
    # NOTE(review): no cost below references `var_offset`, so it presumably
    # stays at its zero default and the offset applied to the result is a
    # no-op. Confirm before relying on it.
    var_offset = OffsetVar(jnp.zeros(timesteps))

    # Costs. Each factory below is written for ONE timestep; array arguments
    # carry a leading time axis when the cost is constructed.

    @jaxls.Cost.create_factory
    def retargeting_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        var_smpl_joints_scale: ManoJointsScaleVar,
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Local alignment: match relative keypoint/link vectors and their angles."""
        robot_cfg = var_values[var_robot_cfg]
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
        T_world_root = var_values[var_Ts_world_root]
        T_world_link = T_world_root @ T_root_link

        mano_pos = keypoints[mano_idx]
        robot_pos = T_world_link.translation()[link_idx]

        # NxN grid of relative positions.
        delta_mano = mano_pos[:, None] - mano_pos[None, :]
        delta_robot = robot_pos[:, None] - robot_pos[None, :]

        # Vector regularization: robot vectors, scaled per pair, should match
        # the MANO vectors. The diagonal (self-pairs) is masked out.
        position_scale = var_values[var_smpl_joints_scale][..., None]
        residual_position_delta = (
            (delta_mano - delta_robot * position_scale)
            * (1 - jnp.eye(delta_mano.shape[0])[..., None])
            * mano_mask[..., None]
        )

        # Vector angle regularization: 1 - cosine between matching vectors.
        # The epsilon is added INSIDE the norm, so the all-zero diagonal
        # vectors still have a non-zero norm; this keeps both the value and
        # its gradient finite there.
        delta_mano_normalized = delta_mano / jnp.linalg.norm(
            delta_mano + 1e-6, axis=-1, keepdims=True
        )
        delta_robot_normalized = delta_robot / jnp.linalg.norm(
            delta_robot + 1e-6, axis=-1, keepdims=True
        )
        residual_angle_delta = 1 - (delta_mano_normalized * delta_robot_normalized).sum(
            axis=-1
        )
        residual_angle_delta = (
            residual_angle_delta
            * (1 - jnp.eye(residual_angle_delta.shape[0]))
            * mano_mask
        )

        return (
            jnp.concatenate(
                [
                    residual_position_delta.flatten(),
                    residual_angle_delta.flatten(),
                ],
                axis=0,
            )
            * weights["local_alignment"]
        )

    @jaxls.Cost.create_factory
    def scale_regularization(
        var_values: jaxls.VarValues,
        var_smpl_joints_scale: ManoJointsScaleVar,
    ) -> jax.Array:
        """Regularize the per-pair scale matrix."""
        scale = var_values[var_smpl_joints_scale]
        # Close to 1.
        res_0 = (scale - 1.0).flatten() * 1.0
        # Symmetric.
        res_1 = (scale - scale.T).flatten() * 100.0
        # Non-negative.
        res_2 = jnp.clip(-scale, min=0).flatten() * 100.0
        return jnp.concatenate([res_0, res_1, res_2])

    @jaxls.Cost.create_factory
    def pc_alignment_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Soft cost to align the human keypoints to the robot, in the world frame."""
        T_world_root = var_values[var_Ts_world_root]
        robot_cfg = var_values[var_robot_cfg]
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
        T_world_link = T_world_root @ T_root_link
        link_pos = T_world_link.translation()[link_idx]
        keypoint_pos = keypoints[mano_idx]
        return (link_pos - keypoint_pos).flatten() * weights["global_alignment"]

    @jaxls.Cost.create_factory
    def root_smoothness(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_Ts_world_root_prev: jaxls.SE3Var,
    ) -> jax.Array:
        """Smoothness cost for the robot root translation (rotation is not penalized)."""
        return (
            var_values[var_Ts_world_root].translation()
            - var_values[var_Ts_world_root_prev].translation()
        ).flatten() * weights["root_smoothness"]

    @jaxls.Cost.create_factory
    def contact_cost(
        var_values: jaxls.VarValues,
        var_T_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        contact_points: jax.Array,  # (C, 3): object contact points, one timestep.
        contact_indices: jax.Array,  # (C,): robot link index for each point.
        contact_points_mask: jax.Array,  # (C,): False for padding slots.
    ) -> jax.Array:
        """Keep each contacting robot link near its recorded object contact point."""
        robot_cfg = var_values[var_robot_cfg]
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
        T_world_root = var_values[var_T_world_root]
        T_world_link = T_world_root @ T_root_link

        # (C, 3): world position of the link assigned to each contact slot.
        contact_link_positions_world = T_world_link.translation()[contact_indices]

        # Contact points are already in the world frame (as processed in dexycb).
        residual = contact_points - contact_link_positions_world  # (C, 3)

        # Per-axis hinge: no penalty while a coordinate is within the margin.
        residual_penalty = jnp.maximum(
            jnp.abs(residual) - weights["contact_margin"], 0.0
        )  # (C, 3)

        # Zero out padding slots.
        residual_penalty = residual_penalty * contact_points_mask[..., None]
        return residual_penalty.flatten() * weights["contact"]

    costs: list[jaxls.Cost] = [
        # Costs that are relatively self-contained to the robot.
        retargeting_cost(
            var_Ts_world_root,
            var_joints,
            var_smpl_joints_scale,
            target_keypoints,
        ),
        scale_regularization(var_smpl_joints_scale),
        pk.costs.limit_cost(
            jax.tree.map(lambda x: x[None], robot),
            var_joints,
            100.0,
        ),
        pk.costs.smoothness_cost(
            robot.joint_var_cls(jnp.arange(1, timesteps)),
            robot.joint_var_cls(jnp.arange(0, timesteps - 1)),
            jnp.array([weights["joint_smoothness"]]),
        ),
        pk.costs.rest_cost(
            var_joints,
            var_joints.default_factory()[None],
            jnp.array([0.2]),
        ),
        # Costs that are scene-centric.
        pc_alignment_cost(
            var_Ts_world_root,
            var_joints,
            target_keypoints,
        ),
        root_smoothness(
            jaxls.SE3Var(jnp.arange(1, timesteps)),
            jaxls.SE3Var(jnp.arange(0, timesteps - 1)),
        ),
        contact_cost(
            var_T_world_root=var_Ts_world_root,
            var_robot_cfg=var_joints,
            contact_points=contact_points_per_frame,
            contact_indices=contact_indices_per_frame,
            contact_points_mask=contact_mask,
        ),
    ]

    solution = (
        jaxls.LeastSquaresProblem(
            costs, [var_joints, var_Ts_world_root, var_smpl_joints_scale, var_offset]
        )
        .analyze()
        .solve()
    )
    transform = solution[var_Ts_world_root]
    offset = solution[var_offset]
    transform = jaxlie.SE3.from_translation(offset) @ transform
    return transform, solution[var_joints]
429
430
if __name__ == "__main__":
    # Script entry point: start the interactive retargeting viewer.
    main()