Humanoid Retargeting
Simpler motion retargeting to the G1 humanoid.
All examples can be run by first cloning the PyRoki repository, which includes the pyroki_snippets
implementation details.
1import time
2from pathlib import Path
3from typing import Tuple, TypedDict
4
5import jax
6import jax.numpy as jnp
7import jax_dataclasses as jdc
8import jaxlie
9import jaxls
10import numpy as onp
11import pyroki as pk
12import viser
13from pyroki.collision import colldist_from_sdf, collide
14from robot_descriptions.loaders.yourdfpy import load_robot_description
15from viser.extras import ViserUrdf
16
17from retarget_helpers._utils import (
18 SMPL_JOINT_NAMES,
19 create_conn_tree,
20 get_humanoid_retarget_indices,
21)
22
23
class RetargetingWeights(TypedDict):
    """Scalar weights applied to the residuals of the retargeting costs.

    Keys:
        local_alignment: Local alignment weight, by matching the relative
            joint/keypoint positions and angles.
        global_alignment: Global alignment weight, by matching the keypoint
            positions to the robot.
    """

    # Multiplies the relative position/angle residual.
    local_alignment: float
    # Multiplies the world-frame keypoint-to-link residual.
    global_alignment: float
29
30
def main():
    """Run the interactive humanoid retargeting demo.

    Loads the G1 robot description and the motion assets stored next to this
    file, solves the retargeting problem once, and then loops forever,
    pushing the current timestep's root pose, joint configuration and target
    keypoints to a viser server. A "Retarget!" button re-runs the solve with
    the weights currently set in the weight tuner.

    This function never returns; stop it by interrupting the process.
    """

    urdf = load_robot_description("g1_description")
    robot = pk.Robot.from_urdf(urdf)

    # Load source motion data:
    # - keypoints [N, 45, 3],
    # - left/right foot contact (boolean) 2 x [N],
    # - heightmap [H, W].
    asset_dir = Path(__file__).parent / "retarget_helpers" / "humanoid"
    smpl_keypoints = onp.load(asset_dir / "smpl_keypoints.npy")
    is_left_foot_contact = onp.load(asset_dir / "left_foot_contact.npy")
    is_right_foot_contact = onp.load(asset_dir / "right_foot_contact.npy")
    heightmap_data = onp.load(asset_dir / "heightmap.npy")

    # Sanity-check that the assets agree on the number of timesteps.
    num_timesteps = smpl_keypoints.shape[0]
    assert smpl_keypoints.shape == (num_timesteps, 45, 3)
    assert is_left_foot_contact.shape == (num_timesteps,)
    assert is_right_foot_contact.shape == (num_timesteps,)
    num_keypoints = smpl_keypoints.shape[1]

    # The heightmap is only used for visualization in this example.
    heightmap = pk.collision.Heightmap(
        pose=jaxlie.SE3.identity(),
        size=jnp.array([0.01, 0.01, 1.0]),
        height_data=heightmap_data,
    )

    smpl_joint_retarget_indices, g1_joint_retarget_indices = (
        get_humanoid_retarget_indices()
    )
    smpl_mask = create_conn_tree(robot, g1_joint_retarget_indices)

    # Scene and GUI setup.
    server = viser.ViserServer()
    base_frame = server.scene.add_frame("/base", show_axes=False)
    urdf_vis = ViserUrdf(server, urdf, root_node_name="/base")
    playing = server.gui.add_checkbox("playing", True)
    timestep_slider = server.gui.add_slider("timestep", 0, num_timesteps - 1, 1, 0)
    server.scene.add_mesh_trimesh("/heightmap", heightmap.to_trimesh())

    weights = pk.viewer.WeightTuner(
        server,
        RetargetingWeights(  # type: ignore
            local_alignment=2.0,
            global_alignment=1.0,
        ),
    )

    # Latest solution; rebound by `generate_trajectory` on every solve.
    Ts_world_root, joints = None, None

    def generate_trajectory():
        """Re-solve the retargeting problem with the current GUI weights."""
        nonlocal Ts_world_root, joints
        gen_button.disabled = True
        try:
            Ts_world_root, joints = solve_retargeting(
                robot=robot,
                target_keypoints=smpl_keypoints,
                smpl_joint_retarget_indices=smpl_joint_retarget_indices,
                g1_joint_retarget_indices=g1_joint_retarget_indices,
                smpl_mask=smpl_mask,
                weights=weights.get_weights(),  # type: ignore
            )
        finally:
            # Always re-enable the button, even if the solve raised, so the
            # user can adjust the weights and retry.
            gen_button.disabled = False

    gen_button = server.gui.add_button("Retarget!")
    gen_button.on_click(lambda _: generate_trajectory())

    # Solve once up front so there is something to play back.
    generate_trajectory()
    assert Ts_world_root is not None and joints is not None

    # All target keypoints are drawn in the same color (0, 0, 255).
    keypoint_colors = onp.array((0, 0, 255))[None].repeat(num_keypoints, axis=0)

    while True:
        with server.atomic():
            if playing.value:
                # Advance playback, wrapping around at the end of the motion.
                timestep_slider.value = (timestep_slider.value + 1) % num_timesteps
            tstep = timestep_slider.value
            # `wxyz_xyz` packs the quaternion in the first four entries and
            # the translation in the remaining ones.
            base_frame.wxyz = onp.array(Ts_world_root.wxyz_xyz[tstep][:4])
            base_frame.position = onp.array(Ts_world_root.wxyz_xyz[tstep][4:])
            urdf_vis.update_cfg(onp.array(joints[tstep]))
            server.scene.add_point_cloud(
                "/target_keypoints",
                onp.array(smpl_keypoints[tstep]),
                keypoint_colors,
                point_size=0.01,
            )

        time.sleep(0.05)
126
127
@jdc.jit
def solve_retargeting(
    robot: pk.Robot,
    target_keypoints: jnp.ndarray,
    smpl_joint_retarget_indices: jnp.ndarray,
    g1_joint_retarget_indices: jnp.ndarray,
    smpl_mask: jnp.ndarray,
    weights: RetargetingWeights,
) -> Tuple[jaxlie.SE3, jnp.ndarray]:
    """Solve the retargeting problem.

    Builds one least-squares problem over all timesteps and solves it for the
    robot root pose and joint configuration at every timestep.

    Args:
        robot: Robot model used for forward kinematics, joint limits and the
            joint variable class.
        target_keypoints: Source keypoints; the leading axis is the timestep
            and each per-timestep entry is indexed by
            `smpl_joint_retarget_indices`.
        smpl_joint_retarget_indices: Indices into the keypoints that take part
            in retargeting.
        g1_joint_retarget_indices: Indices into the robot link positions,
            paired element-wise with `smpl_joint_retarget_indices`.
        smpl_mask: Mask multiplied onto the pairwise (n_retarget x n_retarget)
            residual grids.
        weights: Scalar weights for the local and global alignment costs.

    Returns:
        A tuple `(Ts_world_root, joints)` holding the solved root transforms
        and joint configurations for all timesteps.
    """

    n_retarget = len(smpl_joint_retarget_indices)
    timesteps = target_keypoints.shape[0]

    # Convert the index arguments once so both cost functions index the same way.
    smpl_indices = jnp.asarray(smpl_joint_retarget_indices)
    g1_indices = jnp.asarray(g1_joint_retarget_indices)

    # Robot properties.
    # - Joints that should move less for natural humanoid motion.
    joints_to_move_less = jnp.array(
        [
            robot.joints.actuated_names.index(name)
            for name in ["left_hip_yaw_joint", "right_hip_yaw_joint", "torso_joint"]
        ]
    )

    # Variables.
    class SmplJointsScaleVarG1(
        jaxls.Var[jax.Array], default_factory=lambda: jnp.ones((n_retarget, n_retarget))
    ): ...

    class OffsetVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((3,))): ...

    # Joints and root pose get one variable id per timestep.
    var_joints = robot.joint_var_cls(jnp.arange(timesteps))
    var_Ts_world_root = jaxls.SE3Var(jnp.arange(timesteps))
    # Scale and offset use id 0 for every timestep.
    # NOTE(review): presumably this makes jaxls share a single scale/offset
    # across the whole trajectory -- confirm against the jaxls docs.
    var_smpl_joints_scale = SmplJointsScaleVarG1(jnp.zeros(timesteps))
    var_offset = OffsetVar(jnp.zeros(timesteps))

    # Costs.
    @jaxls.Cost.create_factory
    def retargeting_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        var_smpl_joints_scale: SmplJointsScaleVarG1,
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Retargeting factor, with a focus on:
        - matching the relative joint/keypoint positions (vectors).
        - and matching the relative angles between the vectors.
        """
        robot_cfg = var_values[var_robot_cfg]
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
        T_world_root = var_values[var_Ts_world_root]
        T_world_link = T_world_root @ T_root_link

        smpl_pos = keypoints[smpl_indices]
        robot_pos = T_world_link.translation()[g1_indices]

        # NxN grid of relative positions.
        delta_smpl = smpl_pos[:, None] - smpl_pos[None, :]
        delta_robot = robot_pos[:, None] - robot_pos[None, :]

        # Entries on the diagonal pair a point with itself; zero them out.
        off_diagonal = 1 - jnp.eye(delta_smpl.shape[0])

        # Vector regularization.
        position_scale = var_values[var_smpl_joints_scale][..., None]
        residual_position_delta = (
            (delta_smpl - delta_robot * position_scale)
            * off_diagonal[..., None]
            * smpl_mask[..., None]
        )

        # Vector angle regularization.
        # The small epsilon keeps the norm non-zero for the all-zero diagonal
        # entries, avoiding a division by zero.
        delta_smpl_normalized = delta_smpl / jnp.linalg.norm(
            delta_smpl + 1e-6, axis=-1, keepdims=True
        )
        delta_robot_normalized = delta_robot / jnp.linalg.norm(
            delta_robot + 1e-6, axis=-1, keepdims=True
        )
        # One minus the dot product of the normalized direction vectors.
        residual_angle_delta = 1 - (delta_smpl_normalized * delta_robot_normalized).sum(
            axis=-1
        )
        residual_angle_delta = residual_angle_delta * off_diagonal * smpl_mask

        residual = (
            jnp.concatenate(
                [residual_position_delta.flatten(), residual_angle_delta.flatten()]
            )
            * weights["local_alignment"]
        )
        return residual

    @jaxls.Cost.create_factory
    def pc_alignment_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Soft cost to align the human keypoints to the robot, in the world frame."""
        T_world_root = var_values[var_Ts_world_root]
        robot_cfg = var_values[var_robot_cfg]
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
        T_world_link = T_world_root @ T_root_link
        link_pos = T_world_link.translation()[g1_indices]
        keypoint_pos = keypoints[smpl_indices]
        return (link_pos - keypoint_pos).flatten() * weights["global_alignment"]

    costs: list[jaxls.Cost] = [
        # Costs that are relatively self-contained to the robot.
        retargeting_cost(
            var_Ts_world_root,
            var_joints,
            var_smpl_joints_scale,
            target_keypoints,
        ),
        pk.costs.limit_cost(
            jax.tree.map(lambda x: x[None], robot),
            var_joints,
            100.0,
        ),
        # Smoothness between each timestep and the one before it.
        pk.costs.smoothness_cost(
            robot.joint_var_cls(jnp.arange(1, timesteps)),
            robot.joint_var_cls(jnp.arange(0, timesteps - 1)),
            jnp.array([0.2]),
        ),
        # Rest cost towards the default configuration, weighted more heavily
        # (2.0 instead of 0.2) on the joints that should move less.
        pk.costs.rest_cost(
            var_joints,
            var_joints.default_factory()[None],
            jnp.full(var_joints.default_factory().shape, 0.2)
            .at[joints_to_move_less]
            .set(2.0)[None],
        ),
        # Costs that are scene-centric.
        pc_alignment_cost(
            var_Ts_world_root,
            var_joints,
            target_keypoints,
        ),
    ]

    solution = (
        jaxls.LeastSquaresProblem(
            costs, [var_joints, var_Ts_world_root, var_smpl_joints_scale, var_offset]
        )
        .analyze()
        .solve()
    )
    transform = solution[var_Ts_world_root]
    # NOTE(review): no cost above references `var_offset`, so this offset is
    # presumably still its zero default and the composition below a no-op --
    # confirm before removing.
    offset = solution[var_offset]
    transform = jaxlie.SE3.from_translation(offset) @ transform
    return transform, solution[var_joints]
282
283
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()