Hand Retargeting
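A simpler Shadow Hand retargeting example.
Find and unzip the Shadow Hand URDF at `assets/hand_retargeting/shadowhand_urdf.zip`; the script below loads the unzipped `shadowhand_urdf/` folder from `retarget_helpers/hand/`, next to the script.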
All examples can be run by first cloning the PyRoki repository, which includes the `pyroki_snippets` implementation details.
1import pickle
2import time
3from pathlib import Path
4from typing import Tuple, TypedDict
5
6import jax
7import jax.numpy as jnp
8import jax_dataclasses as jdc
9import jaxlie
10import jaxls
11import numpy as onp
12import pyroki as pk
13import trimesh
14import viser
15import yourdfpy
16from scipy.spatial.transform import Rotation as R
17from viser.extras import ViserUrdf
18
19from retarget_helpers._utils import (
20 MANO_TO_SHADOW_MAPPING,
21 create_conn_tree,
22 get_mapping_from_mano_to_shadow,
23)
24
25
class RetargetingWeights(TypedDict):
    """Per-term weights of the hand retargeting objective."""

    # Match relative keypoint positions (pairwise vectors) and the angles between them.
    local_alignment: float
    # Match world-frame keypoint positions to the corresponding robot links.
    global_alignment: float
    # Penalize changes of the robot joint configuration between consecutive timesteps.
    joint_smoothness: float
    # Penalize changes of the robot root translation between consecutive timesteps.
    root_smoothness: float


def main():
    """Retarget a recorded DexYCB MANO hand motion onto the Shadow Hand and play it back.

    Loads the Shadow Hand URDF and the DexYCB motion pickle (hand keypoints,
    hand mesh, object mesh/poses and per-frame contacts), solves the
    retargeting problem with `solve_retargeting`, and animates the result in a
    viser server. The "Retarget!" button re-solves using the weights currently
    set in the GUI weight tuner.

    Raises:
        FileNotFoundError: If the Shadow Hand URDF has not been unzipped.
        ValueError: If the loaded hand keypoints contain NaN values.
    """

    asset_dir = Path(__file__).parent / "retarget_helpers" / "hand"

    robot_urdf_path = asset_dir / "shadowhand_urdf" / "shadow_hand_right.urdf"

    def filename_handler(fname: str) -> str:
        # Resolve mesh filenames referenced by the URDF relative to its folder.
        base_path = robot_urdf_path.parent
        return yourdfpy.filename_handler_magic(fname, dir=base_path)

    try:
        urdf = yourdfpy.URDF.load(robot_urdf_path, filename_handler=filename_handler)
    except FileNotFoundError as e:
        # Chain the original error so the actual missing path stays visible.
        raise FileNotFoundError(
            "Please unzip the included URDF at `retarget_helpers/hand/shadowhand_urdf.zip`."
        ) from e

    robot = pk.Robot.from_urdf(urdf)

    # Get the mapping from MANO to Shadow Hand joints.
    shadow_link_idx, mano_joint_idx = get_mapping_from_mano_to_shadow(robot)

    # Create a mask for the MANO joints that are connected to the Shadow Hand.
    mano_mask = create_conn_tree(robot, shadow_link_idx)

    # Load source motion data.
    # NOTE: `pickle.load` can execute arbitrary code; only load trusted asset files.
    dexycb_motion_path = asset_dir / "dexycb_motion.pkl"
    with open(dexycb_motion_path, "rb") as f:
        dexycb_motion_data = pickle.load(f, encoding="latin1")

    # Load keypoints. Validate with an explicit raise: asserts vanish under `python -O`.
    keypoints = dexycb_motion_data["world_hand_joints"]
    if onp.isnan(keypoints).any():
        raise ValueError(f"Hand keypoints in {dexycb_motion_path} contain NaN values.")
    num_timesteps = keypoints.shape[0]

    # Load mano hand contact information -- these are lists of lists,
    # len(contact_points_per_frame) = num_timesteps,
    # len(contact_points_per_frame[i]) = number of contacts in frame i.
    contact_points_per_frame = dexycb_motion_data["contact_object_points"]
    contact_indices_per_frame = dexycb_motion_data["contact_joint_indices"]

    # Pad the per-frame contacts to a fixed size and build a mask marking the
    # real (non-padded) entries. Contacts are stored as Shadow Hand link
    # indices, NOT MANO joint indices.
    # NOTE(review): the padded arrays below are not passed to
    # `solve_retargeting` in this simplified example; they are kept as the
    # starting point for a contact cost.
    max_num_contacts = max(len(c) for c in contact_points_per_frame)
    padded_contact_points_per_frame = onp.zeros((num_timesteps, max_num_contacts, 3))
    padded_contact_indices_per_frame = onp.zeros(
        (num_timesteps, max_num_contacts), dtype=onp.int32
    )
    padded_contact_mask = onp.zeros((num_timesteps, max_num_contacts), dtype=onp.bool_)
    for i in range(num_timesteps):
        num_contacts = len(contact_points_per_frame[i])
        if num_contacts == 0:
            continue
        contact_shadowhand_indices = [
            robot.links.names.index(MANO_TO_SHADOW_MAPPING[j])
            for j in contact_indices_per_frame[i]
        ]
        padded_contact_points_per_frame[i, :num_contacts] = contact_points_per_frame[i]
        padded_contact_indices_per_frame[i, :num_contacts] = contact_shadowhand_indices
        padded_contact_mask[i, :num_contacts] = True

    # Load the object.
    object_mesh_vertices = dexycb_motion_data["object_mesh_vertices"]
    object_mesh_faces = dexycb_motion_data["object_mesh_faces"]
    object_pose_list = dexycb_motion_data["object_poses"]  # (N, 4, 4)
    mesh = trimesh.Trimesh(object_mesh_vertices, object_mesh_faces)

    server = viser.ViserServer()

    # We will transform everything by the transform below, for aesthetics.
    server.scene.add_frame(
        "/scene_offset",
        show_axes=False,
        position=(-0.15415953, -0.73598871, 0.93434792),
        wxyz=(-0.381870867, 0.92421569, 0.0, 2.0004992e-32),
    )
    hand_mesh = server.scene.add_mesh_simple(
        "/scene_offset/hand_mesh",
        vertices=dexycb_motion_data["world_hand_vertices"][0, :, :],
        faces=dexycb_motion_data["hand_mesh_faces"],
        opacity=0.5,
    )
    base_frame = server.scene.add_frame("/scene_offset/base", show_axes=False)
    urdf_vis = ViserUrdf(server, urdf, root_node_name="/scene_offset/base")
    playing = server.gui.add_checkbox("playing", True)
    timestep_slider = server.gui.add_slider("timestep", 0, num_timesteps - 1, 1, 0)
    object_handle = server.scene.add_mesh_trimesh("/scene_offset/object", mesh)
    server.scene.add_grid("/grid", 2.0, 2.0)

    default_weights = RetargetingWeights(
        local_alignment=10.0,
        global_alignment=1.0,
        joint_smoothness=2.0,
        root_smoothness=2.0,
    )

    weights = pk.viewer.WeightTuner(
        server,
        default_weights,  # type: ignore
    )

    # Latest solution, stored as ONE (Ts_world_root, joints) tuple. The button
    # callback may run while the playback loop below is reading; a single
    # assignment keeps the loop from pairing a root trajectory from one solve
    # with joints from another.
    trajectory = None

    def generate_trajectory():
        """Re-solve the retargeting problem with the current GUI weights."""
        nonlocal trajectory
        gen_button.disabled = True
        try:
            trajectory = solve_retargeting(
                robot=robot,
                target_keypoints=keypoints,
                shadow_hand_link_retarget_indices=shadow_link_idx,
                mano_joint_retarget_indices=mano_joint_idx,
                mano_mask=mano_mask,
                weights=weights.get_weights(),  # type: ignore
            )
        finally:
            # Re-enable the button even if the solve raises.
            gen_button.disabled = False

    gen_button = server.gui.add_button("Retarget!")
    gen_button.on_click(lambda _: generate_trajectory())

    generate_trajectory()
    assert trajectory is not None

    while True:
        with server.atomic():
            if playing.value:
                timestep_slider.value = (timestep_slider.value + 1) % num_timesteps
            tstep = timestep_slider.value

            # Snapshot the shared solution once per frame so root pose and joints agree.
            Ts_world_root, joints = trajectory
            wxyz_xyz = onp.array(Ts_world_root.wxyz_xyz[tstep])
            base_frame.wxyz = wxyz_xyz[:4]
            base_frame.position = wxyz_xyz[4:]
            urdf_vis.update_cfg(onp.array(joints[tstep]))

            # Size each color array from its point array so the counts always match.
            target_points = onp.array(keypoints[tstep]).reshape(-1, 3)
            server.scene.add_point_cloud(
                "/scene_offset/target_keypoints",
                target_points,
                onp.array((0, 0, 255))[None].repeat(len(target_points), axis=0),
                point_size=0.005,
                point_shape="sparkle",
            )
            contact_points = onp.array(contact_points_per_frame[tstep]).reshape(-1, 3)
            server.scene.add_point_cloud(
                "/scene_offset/contact_points",
                contact_points,
                onp.array((255, 0, 0))[None].repeat(len(contact_points), axis=0),
                point_size=0.005,
                point_shape="circle",
            )
            hand_mesh.vertices = dexycb_motion_data["world_hand_vertices"][tstep, :, :]
            object_handle.position = object_pose_list[tstep][:3, 3]
            object_handle.wxyz = R.from_matrix(object_pose_list[tstep][:3, :3]).as_quat(
                scalar_first=True
            )

        time.sleep(0.05)


@jdc.jit
def solve_retargeting(
    robot: pk.Robot,
    target_keypoints: jnp.ndarray,
    shadow_hand_link_retarget_indices: jnp.ndarray,
    mano_joint_retarget_indices: jnp.ndarray,
    mano_mask: jnp.ndarray,
    weights: RetargetingWeights,
) -> Tuple[jaxlie.SE3, jnp.ndarray]:
    """Solve for a robot-hand trajectory that tracks MANO hand keypoints.

    Jointly optimizes, over all timesteps at once, the robot root pose, the
    robot joint configuration, and one matrix of pairwise keypoint scale
    factors shared by the whole trajectory.

    Args:
        robot: Robot model to retarget onto.
        target_keypoints: World-frame MANO keypoints, indexed
            `target_keypoints[t][joint]` (presumably (T, J, 3) — confirm).
        shadow_hand_link_retarget_indices: Robot link indices matched to
            keypoints.
        mano_joint_retarget_indices: MANO keypoint indices, paired element-wise
            with `shadow_hand_link_retarget_indices` (N pairs).
        mano_mask: Mask broadcast against the (N, N) grid of keypoint pairs;
            only masked-in pairs contribute to the local alignment cost.
        weights: Cost weights (see `RetargetingWeights`).

    Returns:
        The root poses `Ts_world_root` (batched over timesteps) and the joint
        configuration for every timestep.
    """

    n_retarget = len(mano_joint_retarget_indices)
    timesteps = target_keypoints.shape[0]

    # Variables.
    class ManoJointsScaleVar(
        jaxls.Var[jax.Array], default_factory=lambda: jnp.ones((n_retarget, n_retarget))
    ): ...

    # One joint configuration and one root pose per timestep.
    var_joints = robot.joint_var_cls(jnp.arange(timesteps))
    var_Ts_world_root = jaxls.SE3Var(jnp.arange(timesteps))
    # Every timestep refers to variable id 0, so a single (N, N) scale matrix
    # is shared by the whole trajectory.
    var_smpl_joints_scale = ManoJointsScaleVar(jnp.zeros(timesteps))

    # Costs.
    @jaxls.Cost.create_factory
    def retargeting_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        var_smpl_joints_scale: ManoJointsScaleVar,
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Retargeting factor, with a focus on:
        - matching the relative joint/keypoint positions (vectors).
        - and matching the relative angles between the vectors.
        """
        robot_cfg = var_values[var_robot_cfg]
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
        T_world_root = var_values[var_Ts_world_root]
        T_world_link = T_world_root @ T_root_link

        mano_pos = keypoints[jnp.array(mano_joint_retarget_indices)]
        robot_pos = T_world_link.translation()[
            jnp.array(shadow_hand_link_retarget_indices)
        ]

        # NxN grid of relative positions.
        delta_mano = mano_pos[:, None] - mano_pos[None, :]
        delta_robot = robot_pos[:, None] - robot_pos[None, :]

        # Vector regularization: human vectors should equal scaled robot
        # vectors. The (1 - eye) factor drops the trivial self-pairs.
        position_scale = var_values[var_smpl_joints_scale][..., None]
        residual_position_delta = (
            (delta_mano - delta_robot * position_scale)
            * (1 - jnp.eye(delta_mano.shape[0])[..., None])
            * mano_mask[..., None]
        )

        # Vector angle regularization (1 - cosine similarity). The epsilon is
        # added inside the norm so the norm stays nonzero even for exactly-zero
        # vectors (e.g. the i == i diagonal); masking alone would not prevent
        # a NaN gradient there, since 0 * NaN is NaN.
        delta_mano_normalized = delta_mano / jnp.linalg.norm(
            delta_mano + 1e-6, axis=-1, keepdims=True
        )
        delta_robot_normalized = delta_robot / jnp.linalg.norm(
            delta_robot + 1e-6, axis=-1, keepdims=True
        )
        residual_angle_delta = 1 - (delta_mano_normalized * delta_robot_normalized).sum(
            axis=-1
        )
        residual_angle_delta = (
            residual_angle_delta
            * (1 - jnp.eye(residual_angle_delta.shape[0]))
            * mano_mask
        )

        residual = (
            jnp.concatenate(
                [
                    residual_position_delta.flatten(),
                    residual_angle_delta.flatten(),
                ],
                axis=0,
            )
            * weights["local_alignment"]
        )
        return residual

    @jaxls.Cost.create_factory
    def pc_alignment_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Soft cost to align the human keypoints to the robot, in the world frame."""
        T_world_root = var_values[var_Ts_world_root]
        robot_cfg = var_values[var_robot_cfg]
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
        T_world_link = T_world_root @ T_root_link
        link_pos = T_world_link.translation()[shadow_hand_link_retarget_indices]
        keypoint_pos = keypoints[mano_joint_retarget_indices]
        return (link_pos - keypoint_pos).flatten() * weights["global_alignment"]

    @jaxls.Cost.create_factory
    def root_smoothness(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_Ts_world_root_prev: jaxls.SE3Var,
    ) -> jax.Array:
        """Smoothness cost for the robot root translation."""
        return (
            var_values[var_Ts_world_root].translation()
            - var_values[var_Ts_world_root_prev].translation()
        ).flatten() * weights["root_smoothness"]

    costs = [
        retargeting_cost(
            var_Ts_world_root,
            var_joints,
            var_smpl_joints_scale,
            target_keypoints,
        ),
        # Joint-limit cost from pyroki, weight 100.
        pk.costs.limit_cost(
            jax.tree.map(lambda x: x[None], robot),
            var_joints,
            100.0,
        ),
        # Pairs timestep t with t - 1 for every t in [1, timesteps).
        pk.costs.smoothness_cost(
            robot.joint_var_cls(jnp.arange(1, timesteps)),
            robot.joint_var_cls(jnp.arange(0, timesteps - 1)),
            jnp.array([weights["joint_smoothness"]]),
        ),
        pc_alignment_cost(
            var_Ts_world_root,
            var_joints,
            target_keypoints,
        ),
        root_smoothness(
            jaxls.SE3Var(jnp.arange(1, timesteps)),
            jaxls.SE3Var(jnp.arange(0, timesteps - 1)),
        ),
    ]

    # Only variables that some cost actually constrains are optimized.
    solution = (
        jaxls.LeastSquaresProblem(
            costs, [var_joints, var_Ts_world_root, var_smpl_joints_scale]
        )
        .analyze()
        .solve()
    )
    return solution[var_Ts_world_root], solution[var_joints]


if __name__ == "__main__":
    # Launch the interactive retargeting viewer when executed as a script.
    main()