Humanoid Retargeting
Simpler motion retargeting to the G1 humanoid.
All examples can be run by first cloning the PyRoki repository, which includes the pyroki_snippets implementation details.
1import time
2from pathlib import Path
3from typing import Tuple, TypedDict
4
5import jax
6import jax.numpy as jnp
7import jax_dataclasses as jdc
8import jaxlie
9import jaxls
10import numpy as onp
11import pyroki as pk
12import viser
13from robot_descriptions.loaders.yourdfpy import load_robot_description
14from viser.extras import ViserUrdf
15
16from retarget_helpers._utils import (
17 SMPL_JOINT_NAMES,
18 create_conn_tree,
19 get_humanoid_retarget_indices,
20)
21
22
class RetargetingWeights(TypedDict):
    """Cost weights consumed by ``solve_retargeting``.

    Keys:
        local_alignment: Local alignment weight, by matching the relative
            joint/keypoint positions and angles.
        global_alignment: Global alignment weight, by matching the keypoint
            positions to the robot.
    """

    local_alignment: float
    global_alignment: float
28
29
def main():
    """Run the interactive G1 humanoid retargeting demo.

    Loads the G1 robot description and the recorded motion assets, solves the
    retargeting problem once, and then loops forever replaying the solution in
    a viser scene. Clicking the "Retarget!" button re-runs the solve with the
    weights currently held by the weight tuner.
    """
    urdf = load_robot_description("g1_description")
    robot = pk.Robot.from_urdf(urdf)

    # Load source motion data:
    # - keypoints [N, 45, 3],
    # - left/right foot contact (boolean) 2 x [N],
    # - heightmap [H, W].
    asset_dir = Path(__file__).parent / "retarget_helpers" / "humanoid"
    smpl_keypoints = onp.load(asset_dir / "smpl_keypoints.npy")
    is_left_foot_contact = onp.load(asset_dir / "left_foot_contact.npy")
    is_right_foot_contact = onp.load(asset_dir / "right_foot_contact.npy")
    heightmap = onp.load(asset_dir / "heightmap.npy")

    num_timesteps = smpl_keypoints.shape[0]
    assert smpl_keypoints.shape == (num_timesteps, 45, 3)
    assert is_left_foot_contact.shape == (num_timesteps,)
    assert is_right_foot_contact.shape == (num_timesteps,)
    num_keypoints = smpl_keypoints.shape[1]

    # From here on `heightmap` is the collision object, not the raw array.
    heightmap = pk.collision.Heightmap(
        pose=jaxlie.SE3.identity(),
        size=jnp.array([0.01, 0.01, 1.0]),
        height_data=heightmap,
    )

    # Get the left and right foot keypoints, projected on the heightmap.
    # NOTE(review): the projected foot keypoints (and the contact flags above)
    # are not passed to solve_retargeting() in this example.
    left_foot_keypoint_idx = SMPL_JOINT_NAMES.index("left_foot")
    right_foot_keypoint_idx = SMPL_JOINT_NAMES.index("right_foot")
    left_foot_keypoints = smpl_keypoints[..., left_foot_keypoint_idx, :].reshape(-1, 3)
    right_foot_keypoints = smpl_keypoints[..., right_foot_keypoint_idx, :].reshape(
        -1, 3
    )
    left_foot_keypoints = heightmap.project_points(left_foot_keypoints)
    right_foot_keypoints = heightmap.project_points(right_foot_keypoints)

    smpl_joint_retarget_indices, g1_joint_retarget_indices = (
        get_humanoid_retarget_indices()
    )
    smpl_mask = create_conn_tree(robot, g1_joint_retarget_indices)

    # Visualization scene and GUI controls.
    server = viser.ViserServer()
    base_frame = server.scene.add_frame("/base", show_axes=False)
    urdf_vis = ViserUrdf(server, urdf, root_node_name="/base")
    playing = server.gui.add_checkbox("playing", True)
    timestep_slider = server.gui.add_slider("timestep", 0, num_timesteps - 1, 1, 0)
    server.scene.add_mesh_trimesh("/heightmap", heightmap.to_trimesh())

    weights = pk.viewer.WeightTuner(
        server,
        RetargetingWeights(  # type: ignore
            local_alignment=2.0,
            global_alignment=1.0,
        ),
    )

    Ts_world_root, joints = None, None

    def generate_trajectory():
        """Re-solve the retargeting problem with the current tuner weights."""
        nonlocal Ts_world_root, joints
        gen_button.disabled = True
        try:
            Ts_world_root, joints = solve_retargeting(
                robot=robot,
                target_keypoints=smpl_keypoints,
                smpl_joint_retarget_indices=smpl_joint_retarget_indices,
                g1_joint_retarget_indices=g1_joint_retarget_indices,
                smpl_mask=smpl_mask,
                weights=weights.get_weights(),  # type: ignore
            )
        finally:
            # Always re-enable the button, even if the solve raised; otherwise
            # a failed solve would leave the GUI permanently locked.
            gen_button.disabled = False

    gen_button = server.gui.add_button("Retarget!")
    gen_button.on_click(lambda _: generate_trajectory())

    generate_trajectory()
    assert Ts_world_root is not None and joints is not None

    # The keypoint colors never change, so build them once outside the loop.
    keypoint_colors = onp.array((0, 0, 255))[None].repeat(num_keypoints, axis=0)

    while True:
        with server.atomic():
            if playing.value:
                timestep_slider.value = (timestep_slider.value + 1) % num_timesteps
            tstep = timestep_slider.value
            # wxyz_xyz packs the quaternion in the first 4 entries and the
            # translation in the remaining 3.
            base_frame.wxyz = onp.array(Ts_world_root.wxyz_xyz[tstep][:4])
            base_frame.position = onp.array(Ts_world_root.wxyz_xyz[tstep][4:])
            urdf_vis.update_cfg(onp.array(joints[tstep]))
            server.scene.add_point_cloud(
                "/target_keypoints",
                onp.array(smpl_keypoints[tstep]),
                keypoint_colors,
                point_size=0.01,
            )

        time.sleep(0.05)
124
125
@jdc.jit
def solve_retargeting(
    robot: pk.Robot,
    target_keypoints: jnp.ndarray,
    smpl_joint_retarget_indices: jnp.ndarray,
    g1_joint_retarget_indices: jnp.ndarray,
    smpl_mask: jnp.ndarray,
    weights: RetargetingWeights,
) -> Tuple[jaxlie.SE3, jnp.ndarray]:
    """Solve the retargeting problem.

    Builds one least-squares problem over the whole trajectory, with a root
    pose and a joint configuration per timestep, and solves it.

    Args:
        robot: Robot model used for forward kinematics, joint variables and
            joint limits.
        target_keypoints: Source keypoints; the leading axis is the timestep.
        smpl_joint_retarget_indices: Indices into the keypoint axis of
            ``target_keypoints`` selecting the keypoints to match.
        g1_joint_retarget_indices: Indices into the robot's forward-kinematics
            output selecting the links matched to those keypoints (same length
            as ``smpl_joint_retarget_indices``).
        smpl_mask: Pairwise mask multiplied into the NxN relative-position and
            relative-angle residuals (assumed to be (N, N) with N the number
            of retargeted keypoints; confirm against ``create_conn_tree``).
        weights: Scalar weights for the local and global alignment costs.

    Returns:
        A tuple ``(Ts_world_root, joints)``: the solved root transforms and
        the solved joint configurations, both indexed by timestep.
    """
    n_retarget = len(smpl_joint_retarget_indices)
    timesteps = target_keypoints.shape[0]

    # Convert the index arguments once so both cost factories index the same way.
    smpl_indices = jnp.array(smpl_joint_retarget_indices)
    g1_indices = jnp.array(g1_joint_retarget_indices)

    # Zeroes the i == j entries of the NxN pairwise grids.
    off_diagonal = 1 - jnp.eye(n_retarget)

    # Robot properties.
    # - Joints that should move less for natural humanoid motion; they get a
    #   larger rest-pose weight below.
    joints_to_move_less = jnp.array(
        [
            robot.joints.actuated_names.index(name)
            for name in ["left_hip_yaw_joint", "right_hip_yaw_joint", "waist_yaw_joint"]
        ]
    )

    # Variables.
    class SmplJointsScaleVarG1(
        jaxls.Var[jax.Array], default_factory=lambda: jnp.ones((n_retarget, n_retarget))
    ): ...

    class OffsetVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((3,))): ...

    var_joints = robot.joint_var_cls(jnp.arange(timesteps))
    var_Ts_world_root = jaxls.SE3Var(jnp.arange(timesteps))
    # Every timestep uses id 0 here, i.e. (presumably) one scale matrix and one
    # offset shared by the whole trajectory rather than one per timestep.
    var_smpl_joints_scale = SmplJointsScaleVarG1(jnp.zeros(timesteps))
    # NOTE(review): no cost below references var_offset; confirm it is intended
    # to stay a free variable.
    var_offset = OffsetVar(jnp.zeros(timesteps))

    def retarget_link_positions(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
    ) -> jax.Array:
        """World-frame positions of the retargeted robot links."""
        T_root_link = jaxlie.SE3(
            robot.forward_kinematics(cfg=var_values[var_robot_cfg])
        )
        T_world_link = var_values[var_Ts_world_root] @ T_root_link
        return T_world_link.translation()[g1_indices]

    # Costs and constraints.
    @jaxls.Cost.factory
    def retargeting_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        var_smpl_joints_scale: SmplJointsScaleVarG1,
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Retargeting factor, with a focus on:
        - matching the relative joint/keypoint positions (vectors).
        - and matching the relative angles between the vectors.
        """
        smpl_pos = keypoints[smpl_indices]
        robot_pos = retarget_link_positions(
            var_values, var_Ts_world_root, var_robot_cfg
        )

        # NxN grid of relative positions.
        delta_smpl = smpl_pos[:, None] - smpl_pos[None, :]
        delta_robot = robot_pos[:, None] - robot_pos[None, :]

        # Vector regularization: robot deltas are rescaled by the learned
        # per-pair scale before being compared with the source deltas.
        position_scale = var_values[var_smpl_joints_scale][..., None]
        residual_position_delta = (
            (delta_smpl - delta_robot * position_scale)
            * off_diagonal[..., None]
            * smpl_mask[..., None]
        )

        # Vector angle regularization. The 1e-6 is added to the vectors before
        # taking the norm, so the all-zero diagonal entries do not divide by
        # zero.
        delta_smpl_normalized = delta_smpl / jnp.linalg.norm(
            delta_smpl + 1e-6, axis=-1, keepdims=True
        )
        delta_robot_normalized = delta_robot / jnp.linalg.norm(
            delta_robot + 1e-6, axis=-1, keepdims=True
        )
        # 1 - cos(angle) between corresponding source and robot directions.
        residual_angle_delta = 1 - (delta_smpl_normalized * delta_robot_normalized).sum(
            axis=-1
        )
        residual_angle_delta = residual_angle_delta * off_diagonal * smpl_mask

        return (
            jnp.concatenate(
                [residual_position_delta.flatten(), residual_angle_delta.flatten()]
            )
            * weights["local_alignment"]
        )

    @jaxls.Cost.factory
    def pc_alignment_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Soft cost to align the human keypoints to the robot, in the world frame."""
        link_pos = retarget_link_positions(
            var_values, var_Ts_world_root, var_robot_cfg
        )
        keypoint_pos = keypoints[smpl_indices]
        return (link_pos - keypoint_pos).flatten() * weights["global_alignment"]

    rest_pose = var_joints.default_factory()
    rest_weights = jnp.full(rest_pose.shape, 0.2).at[joints_to_move_less].set(2.0)

    costs: list[jaxls.Cost] = [
        # Costs that are relatively self-contained to the robot.
        retargeting_cost(
            var_Ts_world_root,
            var_joints,
            var_smpl_joints_scale,
            target_keypoints,
        ),
        # Penalize joint changes between consecutive timesteps (t vs. t - 1).
        pk.costs.smoothness_cost(
            robot.joint_var_cls(jnp.arange(1, timesteps)),
            robot.joint_var_cls(jnp.arange(0, timesteps - 1)),
            jnp.array([0.2]),
        ),
        pk.costs.rest_cost(
            var_joints,
            rest_pose[None],
            rest_weights[None],
        ),
        # Costs that are scene-centric.
        pc_alignment_cost(
            var_Ts_world_root,
            var_joints,
            target_keypoints,
        ),
        # Joint limits. Every array leaf of `robot` gets a leading axis of
        # size 1 (presumably so it broadcasts against the batched variable).
        pk.costs.limit_constraint(
            jax.tree.map(lambda x: x[None], robot),
            var_joints,
        ),
    ]

    solution = (
        jaxls.LeastSquaresProblem(
            costs=costs,
            variables=[
                var_joints,
                var_Ts_world_root,
                var_smpl_joints_scale,
                var_offset,
            ],
        )
        .analyze()
        .solve()
    )

    # Apply the solved translation offset on top of the solved root poses.
    transform = solution[var_Ts_world_root]
    offset = solution[var_offset]
    transform = jaxlie.SE3.from_translation(offset) @ transform
    return transform, solution[var_joints]
288
289
# Launch the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()