Humanoid Retargeting (Fancy)
Retarget motion to the G1 humanoid, with scene contacts (keeping the feet close to their contact points while avoiding collisions with the world).
All examples can be run by first cloning the PyRoki repository, which includes the pyroki_snippets
implementation details.
1import time
2from typing import Tuple, TypedDict
3from pathlib import Path
4
5import jax
6import jax.numpy as jnp
7import jax_dataclasses as jdc
8import jaxlie
9import jaxls
10import numpy as onp
11import pyroki as pk
12import viser
13from viser.extras import ViserUrdf
14from pyroki.collision import colldist_from_sdf, collide
15from robot_descriptions.loaders.yourdfpy import load_robot_description
16
17from retarget_helpers._utils import (
18 SMPL_JOINT_NAMES,
19 create_conn_tree,
20 get_humanoid_retarget_indices,
21)
22
23
class RetargetingWeights(TypedDict):
    """Scalar weights for the individual cost terms of the retargeting problem.

    Each entry multiplies the residual of the cost term of the same name, so a
    larger value makes the solver prioritize that term.
    """

    local_alignment: float
    """Local alignment weight, by matching the relative joint/keypoint positions and angles."""
    global_alignment: float
    """Global alignment weight, by matching the keypoint positions to the robot."""
    floor_contact: float
    """Floor contact weight, to place the robot's foot on the floor."""
    root_smoothness: float
    """Root smoothness weight, to penalize the robot's root from jittering too much."""
    foot_skating: float
    """Foot skating weight, to penalize the robot's foot from moving when it is in contact with the floor."""
    world_collision: float
    """World collision weight, to penalize the robot from colliding with the world."""


def main():
    """Run the interactive humanoid retargeting demo.

    Loads the G1 robot description and the recorded source motion (SMPL
    keypoints, per-foot contact flags and a heightmap) from the
    ``retarget_helpers/humanoid`` directory next to this file, solves the
    retargeting problem once, and then replays the result in a viser viewer
    in an endless loop. Clicking the "Retarget!" button re-runs the solve with
    the weights currently held by the weight tuner.
    """

    urdf = load_robot_description("g1_description")
    robot = pk.Robot.from_urdf(urdf)
    robot_coll = pk.collision.RobotCollision.from_urdf(urdf)

    # Load source motion data:
    # - keypoints [N, 45, 3],
    # - left/right foot contact (boolean) 2 x [N],
    # - heightmap [H, W].
    asset_dir = Path(__file__).parent / "retarget_helpers" / "humanoid"
    smpl_keypoints = onp.load(asset_dir / "smpl_keypoints.npy")
    is_left_foot_contact = onp.load(asset_dir / "left_foot_contact.npy")
    is_right_foot_contact = onp.load(asset_dir / "right_foot_contact.npy")
    heightmap_data = onp.load(asset_dir / "heightmap.npy")

    num_timesteps = smpl_keypoints.shape[0]
    num_keypoints = smpl_keypoints.shape[1]
    assert smpl_keypoints.shape == (num_timesteps, 45, 3)
    assert is_left_foot_contact.shape == (num_timesteps,)
    assert is_right_foot_contact.shape == (num_timesteps,)

    heightmap = pk.collision.Heightmap(
        pose=jaxlie.SE3.identity(),
        size=jnp.array([0.01, 0.01, 1.0]),
        height_data=heightmap_data,
    )

    # Get the left and right foot keypoints, projected on the heightmap.
    left_foot_keypoint_idx = SMPL_JOINT_NAMES.index("left_foot")
    right_foot_keypoint_idx = SMPL_JOINT_NAMES.index("right_foot")
    left_foot_keypoints = smpl_keypoints[..., left_foot_keypoint_idx, :].reshape(-1, 3)
    right_foot_keypoints = smpl_keypoints[..., right_foot_keypoint_idx, :].reshape(
        -1, 3
    )
    left_foot_keypoints = heightmap.project_points(left_foot_keypoints)
    right_foot_keypoints = heightmap.project_points(right_foot_keypoints)

    smpl_joint_retarget_indices, g1_joint_retarget_indices = (
        get_humanoid_retarget_indices()
    )
    smpl_mask = create_conn_tree(robot, g1_joint_retarget_indices)

    server = viser.ViserServer()
    base_frame = server.scene.add_frame("/base", show_axes=False)
    urdf_vis = ViserUrdf(server, urdf, root_node_name="/base")
    playing = server.gui.add_checkbox("playing", True)
    timestep_slider = server.gui.add_slider("timestep", 0, num_timesteps - 1, 1, 0)
    server.scene.add_mesh_trimesh("/heightmap", heightmap.to_trimesh())

    weights = pk.viewer.WeightTuner(
        server,
        RetargetingWeights(
            local_alignment=2.0,
            global_alignment=1.0,
            floor_contact=1.0,
            root_smoothness=1.0,
            foot_skating=1.0,
            world_collision=1.0,
        ),  # type: ignore
    )

    Ts_world_root, joints = None, None

    def generate_trajectory():
        """Solve the retargeting problem and publish the result to the replay loop."""
        nonlocal Ts_world_root, joints
        gen_button.disabled = True
        try:
            Ts_world_root, joints = solve_retargeting(
                robot=robot,
                robot_coll=robot_coll,
                target_keypoints=smpl_keypoints,
                is_left_foot_contact=is_left_foot_contact,
                is_right_foot_contact=is_right_foot_contact,
                left_foot_keypoints=left_foot_keypoints,
                right_foot_keypoints=right_foot_keypoints,
                smpl_joint_retarget_indices=smpl_joint_retarget_indices,
                g1_joint_retarget_indices=g1_joint_retarget_indices,
                smpl_mask=smpl_mask,
                heightmap=heightmap,
                weights=weights.get_weights(),  # type: ignore
            )
        finally:
            # Re-enable the button even if the solve raises, so the user can
            # adjust the weights and try again instead of being locked out.
            gen_button.disabled = False

    gen_button = server.gui.add_button("Retarget!")
    gen_button.on_click(lambda _: generate_trajectory())

    generate_trajectory()
    assert Ts_world_root is not None and joints is not None

    # One blue color row per keypoint; constant, so built once outside the loop.
    keypoint_colors = onp.array((0, 0, 255))[None].repeat(num_keypoints, axis=0)

    while True:
        with server.atomic():
            if playing.value:
                timestep_slider.value = (timestep_slider.value + 1) % num_timesteps
            tstep = timestep_slider.value
            # wxyz_xyz packs the quaternion in the first four entries and the
            # translation in the remaining three.
            base_frame.wxyz = onp.array(Ts_world_root.wxyz_xyz[tstep][:4])
            base_frame.position = onp.array(Ts_world_root.wxyz_xyz[tstep][4:])
            urdf_vis.update_cfg(onp.array(joints[tstep]))
            server.scene.add_point_cloud(
                "/target_keypoints",
                onp.array(smpl_keypoints[tstep]),
                keypoint_colors,
                point_size=0.01,
            )

        time.sleep(0.05)


@jdc.jit
def solve_retargeting(
    robot: pk.Robot,
    robot_coll: pk.collision.RobotCollision,
    target_keypoints: jnp.ndarray,
    is_left_foot_contact: jnp.ndarray,
    is_right_foot_contact: jnp.ndarray,
    left_foot_keypoints: jnp.ndarray,
    right_foot_keypoints: jnp.ndarray,
    smpl_joint_retarget_indices: jnp.ndarray,
    g1_joint_retarget_indices: jnp.ndarray,
    smpl_mask: jnp.ndarray,
    heightmap: pk.collision.Heightmap,
    weights: RetargetingWeights,
) -> Tuple[jaxlie.SE3, jnp.ndarray]:
    """Solve the retargeting problem.

    Builds a least-squares problem over the whole trajectory with one joint
    configuration and one root pose per timestep, plus a pairwise keypoint
    scale and a translation offset that are both declared with the single
    variable id 0 (i.e. shared across timesteps).

    Args:
        robot: Kinematic model; provides forward kinematics, the joint
            variable class and the joint/link name lookups.
        robot_coll: Collision model of the robot, used for the self-collision
            and world-collision terms.
        target_keypoints: Source keypoints; the leading axis is time.
        is_left_foot_contact: Per-timestep left-foot contact flags; gate the
            floor-contact and skating residuals.
        is_right_foot_contact: Same, for the right foot.
        left_foot_keypoints: Per-timestep target position of the left foot.
        right_foot_keypoints: Same, for the right foot.
        smpl_joint_retarget_indices: Indices into the keypoint axis selecting
            the keypoints that are matched.
        g1_joint_retarget_indices: Indices into the robot's link list selecting
            the links matched to those keypoints, in the same order.
        smpl_mask: Multiplies the pairwise local-alignment residuals.
        heightmap: World geometry the robot is collided against.
        weights: Scalar multiplier for each custom cost term.

    Returns:
        A tuple ``(Ts_world_root, joints)``: the solved root poses with the
        solved offset already applied as a left-multiplied translation, and
        the solved joint configurations.
    """

    n_retarget = len(smpl_joint_retarget_indices)
    timesteps = target_keypoints.shape[0]

    # Robot properties.
    # - Joints that should move less for natural humanoid motion.
    joints_to_move_less = jnp.array(
        [
            robot.joints.actuated_names.index(name)
            for name in ["left_hip_yaw_joint", "right_hip_yaw_joint", "torso_joint"]
        ]
    )
    # - Foot indices.
    left_foot_idx = robot.links.names.index("left_ankle_roll_link")
    right_foot_idx = robot.links.names.index("right_ankle_roll_link")

    # Variables.
    class SmplJointsScaleVarG1(
        jaxls.Var[jax.Array], default_factory=lambda: jnp.ones((n_retarget, n_retarget))
    ): ...

    class OffsetVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((3,))): ...

    var_joints = robot.joint_var_cls(jnp.arange(timesteps))
    var_Ts_world_root = jaxls.SE3Var(jnp.arange(timesteps))
    # All-zero ids: every timestep refers to the same scale / offset variable.
    var_smpl_joints_scale = SmplJointsScaleVarG1(jnp.zeros(timesteps))
    var_offset = OffsetVar(jnp.zeros(timesteps))

    # Costs.
    @jaxls.Cost.create_factory
    def retargeting_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        var_smpl_joints_scale: SmplJointsScaleVarG1,
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Retargeting factor, with a focus on:
        - matching the relative joint/keypoint positions (vectors).
        - and matching the relative angles between the vectors.
        """
        robot_cfg = var_values[var_robot_cfg]
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
        T_world_root = var_values[var_Ts_world_root]
        T_world_link = T_world_root @ T_root_link

        smpl_pos = keypoints[jnp.array(smpl_joint_retarget_indices)]
        robot_pos = T_world_link.translation()[jnp.array(g1_joint_retarget_indices)]

        # NxN grid of relative positions.
        delta_smpl = smpl_pos[:, None] - smpl_pos[None, :]
        delta_robot = robot_pos[:, None] - robot_pos[None, :]

        # Vector regularization.
        position_scale = var_values[var_smpl_joints_scale][..., None]
        residual_position_delta = (
            (delta_smpl - delta_robot * position_scale)
            * (1 - jnp.eye(delta_smpl.shape[0])[..., None])
            * smpl_mask[..., None]
        )

        # Vector angle regularization.
        # The 1e-6 keeps the norm non-zero on the diagonal, where the deltas
        # are exactly zero.
        delta_smpl_normalized = delta_smpl / jnp.linalg.norm(
            delta_smpl + 1e-6, axis=-1, keepdims=True
        )
        delta_robot_normalized = delta_robot / jnp.linalg.norm(
            delta_robot + 1e-6, axis=-1, keepdims=True
        )
        residual_angle_delta = 1 - (delta_smpl_normalized * delta_robot_normalized).sum(
            axis=-1
        )
        residual_angle_delta = (
            residual_angle_delta
            * (1 - jnp.eye(residual_angle_delta.shape[0]))
            * smpl_mask
        )

        residual = (
            jnp.concatenate(
                [residual_position_delta.flatten(), residual_angle_delta.flatten()]
            )
            * weights["local_alignment"]
        )
        return residual

    @jaxls.Cost.create_factory
    def scale_regularization(
        var_values: jaxls.VarValues,
        var_smpl_joints_scale: SmplJointsScaleVarG1,
    ) -> jax.Array:
        """Regularize the scale of the retargeted joints."""
        # Close to 1.
        res_0 = (var_values[var_smpl_joints_scale] - 1.0).flatten() * 1.0
        # Symmetric.
        res_1 = (
            var_values[var_smpl_joints_scale] - var_values[var_smpl_joints_scale].T
        ).flatten() * 100.0
        # Non-negative.
        res_2 = jnp.clip(-var_values[var_smpl_joints_scale], min=0).flatten() * 100.0
        return jnp.concatenate([res_0, res_1, res_2])

    @jaxls.Cost.create_factory
    def pc_alignment_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Soft cost to align the human keypoints to the robot, in the world frame."""
        T_world_root = var_values[var_Ts_world_root]
        robot_cfg = var_values[var_robot_cfg]
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
        T_world_link = T_world_root @ T_root_link
        link_pos = T_world_link.translation()[g1_joint_retarget_indices]
        keypoint_pos = keypoints[smpl_joint_retarget_indices]
        return (link_pos - keypoint_pos).flatten() * weights["global_alignment"]

    @jaxls.Cost.create_factory
    def floor_contact_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        var_offset: OffsetVar,
        is_left_foot_contact: jnp.ndarray,
        is_right_foot_contact: jnp.ndarray,
        left_foot_keypoints: jnp.ndarray,
        right_foot_keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Cost to place the robot on the floor:
        - match foot keypoint positions, and
        - penalize the foot from tilting too much.
        """
        T_world_root = var_values[var_Ts_world_root]
        T_root_link = jaxlie.SE3(
            robot.forward_kinematics(cfg=var_values[var_robot_cfg])
        )
        # World-frame link poses, computed once and shared by the position and
        # orientation terms below.
        T_world_link = T_world_root @ T_root_link
        link_positions = T_world_link.translation()
        link_rotations = T_world_link.rotation().as_matrix()

        offset = var_values[var_offset]
        left_foot_pos = link_positions[left_foot_idx] + offset
        right_foot_pos = link_positions[right_foot_idx] + offset
        left_foot_contact_cost = (
            is_left_foot_contact * (left_foot_pos - left_foot_keypoints) ** 2
        )
        right_foot_contact_cost = (
            is_right_foot_contact * (right_foot_pos - right_foot_keypoints) ** 2
        )

        # Also penalize the foot from tilting too much -- keep z axis up!
        left_foot_ori = link_rotations[left_foot_idx]
        right_foot_ori = link_rotations[right_foot_idx]
        left_foot_contact_residual_rot = jnp.where(
            is_left_foot_contact,
            left_foot_ori[2, 2] - 1,
            0.0,
        )
        right_foot_contact_residual_rot = jnp.where(
            is_right_foot_contact,
            right_foot_ori[2, 2] - 1,
            0.0,
        )

        return (
            jnp.concatenate(
                [
                    left_foot_contact_cost.flatten(),
                    right_foot_contact_cost.flatten(),
                    left_foot_contact_residual_rot.flatten(),
                    right_foot_contact_residual_rot.flatten(),
                ]
            )
            * weights["floor_contact"]
        )

    @jaxls.Cost.create_factory
    def root_smoothness(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_Ts_world_root_prev: jaxls.SE3Var,
    ) -> jax.Array:
        """Smoothness cost for the robot root pose."""
        return (
            var_values[var_Ts_world_root].inverse() @ var_values[var_Ts_world_root_prev]
        ).log().flatten() * weights["root_smoothness"]

    @jaxls.Cost.create_factory
    def skating_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        var_offset: OffsetVar,
        var_Ts_world_root_prev: jaxls.SE3Var,
        var_robot_cfg_prev: jaxls.Var[jnp.ndarray],
        var_offset_prev: OffsetVar,
        is_left_foot_contact: jnp.ndarray,
        is_right_foot_contact: jnp.ndarray,
    ) -> jax.Array:
        """Cost to penalize the robot for skating."""
        T_world_root = var_values[var_Ts_world_root]
        robot_cfg = var_values[var_robot_cfg]
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
        offset = var_values[var_offset]
        T_link = T_world_root @ T_root_link
        left_foot_pos = T_link.translation()[left_foot_idx] + offset
        right_foot_pos = T_link.translation()[right_foot_idx] + offset

        T_world_root_prev = var_values[var_Ts_world_root_prev]
        robot_cfg_prev = var_values[var_robot_cfg_prev]
        T_root_link_prev = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg_prev))
        offset_prev = var_values[var_offset_prev]
        T_link_prev = T_world_root_prev @ T_root_link_prev
        left_foot_pos_prev = T_link_prev.translation()[left_foot_idx] + offset_prev
        right_foot_pos_prev = T_link_prev.translation()[right_foot_idx] + offset_prev

        # Foot displacement between consecutive timesteps, zeroed when the
        # corresponding contact flag is off.
        skating_cost_left = is_left_foot_contact * (left_foot_pos - left_foot_pos_prev)
        skating_cost_right = is_right_foot_contact * (
            right_foot_pos - right_foot_pos_prev
        )

        return (
            jnp.stack([skating_cost_left, skating_cost_right]) * weights["foot_skating"]
        )

    @jaxls.Cost.create_factory
    def world_collision_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        var_offset: OffsetVar,
    ) -> jax.Array:
        """
        World collision; we intentionally use a low weight --
        high enough to lift the robot up from the ground, but
        low enough to not interfere with the retargeting.
        """
        Ts_world_root = var_values[var_Ts_world_root]
        T_offset = jaxlie.SE3.from_translation(var_values[var_offset])
        transform = T_offset @ Ts_world_root

        robot_cfg = var_values[var_robot_cfg]
        coll = robot_coll.at_config(robot, robot_cfg)
        coll = coll.transform(transform)

        dist = collide(coll, heightmap)
        act = colldist_from_sdf(dist, activation_dist=0.005)
        return act.flatten() * weights["world_collision"]

    costs: list[jaxls.Cost] = [
        # Costs that are relatively self-contained to the robot.
        retargeting_cost(
            var_Ts_world_root,
            var_joints,
            var_smpl_joints_scale,
            target_keypoints,
        ),
        scale_regularization(var_smpl_joints_scale),
        pk.costs.limit_cost(
            jax.tree.map(lambda x: x[None], robot),
            var_joints,
            100.0,
        ),
        pk.costs.smoothness_cost(
            robot.joint_var_cls(jnp.arange(1, timesteps)),
            robot.joint_var_cls(jnp.arange(0, timesteps - 1)),
            jnp.array([0.2]),
        ),
        pk.costs.rest_cost(
            var_joints,
            var_joints.default_factory()[None],
            # Stronger pull toward the rest pose for the joints listed in
            # joints_to_move_less.
            jnp.full(var_joints.default_factory().shape, 0.2)
            .at[joints_to_move_less]
            .set(2.0)[None],
        ),
        pk.costs.self_collision_cost(
            jax.tree.map(lambda x: x[None], robot),
            jax.tree.map(lambda x: x[None], robot_coll),
            var_joints,
            margin=0.05,
            weight=2.0,
        ),
        # Costs that are scene-centric.
        pc_alignment_cost(
            var_Ts_world_root,
            var_joints,
            target_keypoints,
        ),
        floor_contact_cost(
            var_Ts_world_root,
            var_joints,
            var_offset,
            is_left_foot_contact,
            is_right_foot_contact,
            left_foot_keypoints,
            right_foot_keypoints,
        ),
        root_smoothness(
            jaxls.SE3Var(jnp.arange(1, timesteps)),
            jaxls.SE3Var(jnp.arange(0, timesteps - 1)),
        ),
        # Pairs timestep t (ids 1..T-1) with t-1 (ids 0..T-2); the contact
        # flags are taken from the earlier timestep of each pair.
        # NOTE(review): var_offset above is declared with the single shared
        # id 0, but the OffsetVar ids here run over 1..T-1 / 0..T-2 -- confirm
        # that jaxls resolves them to the intended variable.
        skating_cost(
            jaxls.SE3Var(jnp.arange(1, timesteps)),
            robot.joint_var_cls(jnp.arange(1, timesteps)),
            OffsetVar(jnp.arange(1, timesteps)),
            jaxls.SE3Var(jnp.arange(0, timesteps - 1)),
            robot.joint_var_cls(jnp.arange(0, timesteps - 1)),
            OffsetVar(jnp.arange(0, timesteps - 1)),
            is_left_foot_contact[:-1],
            is_right_foot_contact[:-1],
        ),
        world_collision_cost(
            var_Ts_world_root,
            var_joints,
            var_offset,
        ),
    ]

    solution = (
        jaxls.LeastSquaresProblem(
            costs, [var_joints, var_Ts_world_root, var_smpl_joints_scale, var_offset]
        )
        .analyze()
        .solve()
    )
    # Fold the solved offset into the returned root poses, using the same
    # composition order as world_collision_cost.
    transform = solution[var_Ts_world_root]
    offset = solution[var_offset]
    transform = jaxlie.SE3.from_translation(offset) @ transform
    return transform, solution[var_joints]


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()