Humanoid Retargeting (Fancy)

Retarget human motion to the G1 humanoid using scene contacts: keep the feet close to the contact points while avoiding collisions with the world.

All examples can be run by first cloning the PyRoki repository, which includes the pyroki_snippets implementation details.

  1import time
  2from typing import Tuple, TypedDict
  3from pathlib import Path
  4
  5import jax
  6import jax.numpy as jnp
  7import jax_dataclasses as jdc
  8import jaxlie
  9import jaxls
 10import numpy as onp
 11import pyroki as pk
 12import viser
 13from viser.extras import ViserUrdf
 14from pyroki.collision import colldist_from_sdf, collide
 15from robot_descriptions.loaders.yourdfpy import load_robot_description
 16
 17from retarget_helpers._utils import (
 18    SMPL_JOINT_NAMES,
 19    create_conn_tree,
 20    get_humanoid_retarget_indices,
 21)
 22
 23
 24class RetargetingWeights(TypedDict):
 25    local_alignment: float
 26    """Local alignment weight, by matching the relative joint/keypoint positions and angles."""
 27    global_alignment: float
 28    """Global alignment weight, by matching the keypoint positions to the robot."""
 29    floor_contact: float
 30    """Floor contact weight, to place the robot's foot on the floor."""
 31    root_smoothness: float
 32    """Root smoothness weight, to penalize the robot's root from jittering too much."""
 33    foot_skating: float
 34    """Foot skating weight, to penalize the robot's foot from moving when it is in contact with the floor."""
 35    world_collision: float
 36    """World collision weight, to penalize the robot from colliding with the world."""
 37
 38
 39def main():
 40    """Main function for humanoid retargeting."""
 41
 42    urdf = load_robot_description("g1_description")
 43    robot = pk.Robot.from_urdf(urdf)
 44    robot_coll = pk.collision.RobotCollision.from_urdf(urdf)
 45
 46    # Load source motion data:
 47    # - keypoints [N, 45, 3],
 48    # - left/right foot contact (boolean) 2 x [N],
 49    # - heightmap [H, W].
 50    asset_dir = Path(__file__).parent / "retarget_helpers" / "humanoid"
 51    smpl_keypoints = onp.load(asset_dir / "smpl_keypoints.npy")
 52    is_left_foot_contact = onp.load(asset_dir / "left_foot_contact.npy")
 53    is_right_foot_contact = onp.load(asset_dir / "right_foot_contact.npy")
 54    heightmap = onp.load(asset_dir / "heightmap.npy")
 55
 56    num_timesteps = smpl_keypoints.shape[0]
 57    assert smpl_keypoints.shape == (num_timesteps, 45, 3)
 58    assert is_left_foot_contact.shape == (num_timesteps,)
 59    assert is_right_foot_contact.shape == (num_timesteps,)
 60
 61    heightmap = pk.collision.Heightmap(
 62        pose=jaxlie.SE3.identity(),
 63        size=jnp.array([0.01, 0.01, 1.0]),
 64        height_data=heightmap,
 65    )
 66
 67    # Get the left and right foot keypoints, projected on the heightmap.
 68    left_foot_keypoint_idx = SMPL_JOINT_NAMES.index("left_foot")
 69    right_foot_keypoint_idx = SMPL_JOINT_NAMES.index("right_foot")
 70    left_foot_keypoints = smpl_keypoints[..., left_foot_keypoint_idx, :].reshape(-1, 3)
 71    right_foot_keypoints = smpl_keypoints[..., right_foot_keypoint_idx, :].reshape(
 72        -1, 3
 73    )
 74    left_foot_keypoints = heightmap.project_points(left_foot_keypoints)
 75    right_foot_keypoints = heightmap.project_points(right_foot_keypoints)
 76
 77    smpl_joint_retarget_indices, g1_joint_retarget_indices = (
 78        get_humanoid_retarget_indices()
 79    )
 80    smpl_mask = create_conn_tree(robot, g1_joint_retarget_indices)
 81
 82    server = viser.ViserServer()
 83    base_frame = server.scene.add_frame("/base", show_axes=False)
 84    urdf_vis = ViserUrdf(server, urdf, root_node_name="/base")
 85    playing = server.gui.add_checkbox("playing", True)
 86    timestep_slider = server.gui.add_slider("timestep", 0, num_timesteps - 1, 1, 0)
 87    server.scene.add_mesh_trimesh("/heightmap", heightmap.to_trimesh())
 88
 89    weights = pk.viewer.WeightTuner(
 90        server,
 91        RetargetingWeights(
 92            local_alignment=2.0,
 93            global_alignment=1.0,
 94            floor_contact=1.0,
 95            root_smoothness=1.0,
 96            foot_skating=1.0,
 97            world_collision=1.0,
 98        ),  # type: ignore
 99    )
100
101    Ts_world_root, joints = None, None
102
103    def generate_trajectory():
104        nonlocal Ts_world_root, joints
105        gen_button.disabled = True
106        Ts_world_root, joints = solve_retargeting(
107            robot=robot,
108            robot_coll=robot_coll,
109            target_keypoints=smpl_keypoints,
110            is_left_foot_contact=is_left_foot_contact,
111            is_right_foot_contact=is_right_foot_contact,
112            left_foot_keypoints=left_foot_keypoints,
113            right_foot_keypoints=right_foot_keypoints,
114            smpl_joint_retarget_indices=smpl_joint_retarget_indices,
115            g1_joint_retarget_indices=g1_joint_retarget_indices,
116            smpl_mask=smpl_mask,
117            heightmap=heightmap,
118            weights=weights.get_weights(),  # type: ignore
119        )
120        gen_button.disabled = False
121
122    gen_button = server.gui.add_button("Retarget!")
123    gen_button.on_click(lambda _: generate_trajectory())
124
125    generate_trajectory()
126    assert Ts_world_root is not None and joints is not None
127
128    while True:
129        with server.atomic():
130            if playing.value:
131                timestep_slider.value = (timestep_slider.value + 1) % num_timesteps
132            tstep = timestep_slider.value
133            base_frame.wxyz = onp.array(Ts_world_root.wxyz_xyz[tstep][:4])
134            base_frame.position = onp.array(Ts_world_root.wxyz_xyz[tstep][4:])
135            urdf_vis.update_cfg(onp.array(joints[tstep]))
136            server.scene.add_point_cloud(
137                "/target_keypoints",
138                onp.array(smpl_keypoints[tstep]),
139                onp.array((0, 0, 255))[None].repeat(45, axis=0),
140                point_size=0.01,
141            )
142
143        time.sleep(0.05)
144
145
@jdc.jit
def solve_retargeting(
    robot: pk.Robot,
    robot_coll: pk.collision.RobotCollision,
    target_keypoints: jnp.ndarray,
    is_left_foot_contact: jnp.ndarray,
    is_right_foot_contact: jnp.ndarray,
    left_foot_keypoints: jnp.ndarray,
    right_foot_keypoints: jnp.ndarray,
    smpl_joint_retarget_indices: jnp.ndarray,
    g1_joint_retarget_indices: jnp.ndarray,
    smpl_mask: jnp.ndarray,
    heightmap: pk.collision.Heightmap,
    weights: RetargetingWeights,
) -> Tuple[jaxlie.SE3, jnp.ndarray]:
    """Solve the retargeting problem.

    Jointly optimizes these variables as one nonlinear least-squares problem:
    - a root pose and a joint configuration per timestep;
    - one (n_retarget, n_retarget) SMPL-to-G1 scale matrix shared by all
      timesteps;
    - one world-frame translation offset shared by all timesteps.

    Args:
        robot: Robot kinematic model.
        robot_coll: Robot collision model (self- and world-collision costs).
        target_keypoints: Source SMPL keypoints, indexed [timestep, keypoint];
            the caller loads them as (T, 45, 3).
        is_left_foot_contact: Per-timestep left-foot contact flags, shape (T,).
        is_right_foot_contact: Per-timestep right-foot contact flags, shape (T,).
        left_foot_keypoints: Left-foot keypoints projected onto the heightmap;
            presumably (T, 3) -- TODO confirm.
        right_foot_keypoints: Same as above, for the right foot.
        smpl_joint_retarget_indices: SMPL keypoint indices to match.
        g1_joint_retarget_indices: Robot link indices, paired element-wise with
            `smpl_joint_retarget_indices`.
        smpl_mask: Mask gating the pairwise (NxN) local-alignment residuals.
        heightmap: Scene heightmap for the world-collision cost.
        weights: Per-cost weights.

    Returns:
        `(Ts_world_root, joints)`. The first is the per-timestep root poses with
        the solved offset folded in. The second is the per-timestep joint
        configurations.
    """

    n_retarget = len(smpl_joint_retarget_indices)
    timesteps = target_keypoints.shape[0]

    # Robot properties.
    # - Joints that should move less for natural humanoid motion.
    joints_to_move_less = jnp.array(
        [
            robot.joints.actuated_names.index(name)
            for name in ["left_hip_yaw_joint", "right_hip_yaw_joint", "torso_joint"]
        ]
    )
    # - Foot indices.
    left_foot_idx = robot.links.names.index("left_ankle_roll_link")
    right_foot_idx = robot.links.names.index("right_ankle_roll_link")

    # Variables.
    class SmplJointsScaleVarG1(
        jaxls.Var[jax.Array], default_factory=lambda: jnp.ones((n_retarget, n_retarget))
    ): ...

    class OffsetVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((3,))): ...

    var_joints = robot.joint_var_cls(jnp.arange(timesteps))
    var_Ts_world_root = jaxls.SE3Var(jnp.arange(timesteps))
    # All-zero IDs: the scale matrix and the offset are each ONE variable
    # (ID 0), broadcast across every timestep.
    var_smpl_joints_scale = SmplJointsScaleVarG1(jnp.zeros(timesteps))
    var_offset = OffsetVar(jnp.zeros(timesteps))

    @jaxls.Cost.create_factory
    def retargeting_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        var_smpl_joints_scale: SmplJointsScaleVarG1,
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Retargeting factor, with a focus on:
        - matching the relative joint/keypoint positions (vectors).
        - and matching the relative angles between the vectors.
        """
        robot_cfg = var_values[var_robot_cfg]
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
        T_world_root = var_values[var_Ts_world_root]
        T_world_link = T_world_root @ T_root_link

        smpl_pos = keypoints[jnp.array(smpl_joint_retarget_indices)]
        robot_pos = T_world_link.translation()[jnp.array(g1_joint_retarget_indices)]

        # NxN grid of relative positions.
        delta_smpl = smpl_pos[:, None] - smpl_pos[None, :]
        delta_robot = robot_pos[:, None] - robot_pos[None, :]

        # Vector regularization; the diagonal (a point vs. itself) is masked out.
        position_scale = var_values[var_smpl_joints_scale][..., None]
        residual_position_delta = (
            (delta_smpl - delta_robot * position_scale)
            * (1 - jnp.eye(delta_smpl.shape[0])[..., None])
            * smpl_mask[..., None]
        )

        # Vector angle regularization. The epsilon is added *inside* the norm,
        # so the zero-length diagonal deltas never hit the undefined gradient
        # of the norm at the origin.
        delta_smpl_normalized = delta_smpl / jnp.linalg.norm(
            delta_smpl + 1e-6, axis=-1, keepdims=True
        )
        delta_robot_normalized = delta_robot / jnp.linalg.norm(
            delta_robot + 1e-6, axis=-1, keepdims=True
        )
        residual_angle_delta = 1 - (delta_smpl_normalized * delta_robot_normalized).sum(
            axis=-1
        )
        residual_angle_delta = (
            residual_angle_delta
            * (1 - jnp.eye(residual_angle_delta.shape[0]))
            * smpl_mask
        )

        residual = (
            jnp.concatenate(
                [residual_position_delta.flatten(), residual_angle_delta.flatten()]
            )
            * weights["local_alignment"]
        )
        return residual

    @jaxls.Cost.create_factory
    def scale_regularization(
        var_values: jaxls.VarValues,
        var_smpl_joints_scale: SmplJointsScaleVarG1,
    ) -> jax.Array:
        """Regularize the scale of the retargeted joints."""
        # Close to 1.
        res_0 = (var_values[var_smpl_joints_scale] - 1.0).flatten() * 1.0
        # Symmetric.
        res_1 = (
            var_values[var_smpl_joints_scale] - var_values[var_smpl_joints_scale].T
        ).flatten() * 100.0
        # Non-negative.
        res_2 = jnp.clip(-var_values[var_smpl_joints_scale], min=0).flatten() * 100.0
        return jnp.concatenate([res_0, res_1, res_2])

    @jaxls.Cost.create_factory
    def pc_alignment_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Soft cost to align the human keypoints to the robot, in the world frame."""
        T_world_root = var_values[var_Ts_world_root]
        robot_cfg = var_values[var_robot_cfg]
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
        T_world_link = T_world_root @ T_root_link
        link_pos = T_world_link.translation()[g1_joint_retarget_indices]
        keypoint_pos = keypoints[smpl_joint_retarget_indices]
        return (link_pos - keypoint_pos).flatten() * weights["global_alignment"]

    @jaxls.Cost.create_factory
    def floor_contact_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        var_offset: OffsetVar,
        is_left_foot_contact: jnp.ndarray,
        is_right_foot_contact: jnp.ndarray,
        left_foot_keypoints: jnp.ndarray,
        right_foot_keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Cost to place the robot on the floor:
        - match foot keypoint positions, and
        - penalize the foot from tilting too much.
        Both terms are active only on timesteps where that foot is in contact.
        """
        # World-frame link poses, computed once and shared by both terms.
        T_world_link = var_values[var_Ts_world_root] @ jaxlie.SE3(
            robot.forward_kinematics(cfg=var_values[var_robot_cfg])
        )
        link_pos = T_world_link.translation()
        link_rot = T_world_link.rotation().as_matrix()

        offset = var_values[var_offset]
        left_foot_pos = link_pos[left_foot_idx] + offset
        right_foot_pos = link_pos[right_foot_idx] + offset
        # NOTE(review): this residual is already squared, and least squares
        # squares it again, so the effective penalty is 4th-order in the error.
        # Kept as-is to preserve the tuned behavior.
        left_foot_contact_cost = (
            is_left_foot_contact * (left_foot_pos - left_foot_keypoints) ** 2
        )
        right_foot_contact_cost = (
            is_right_foot_contact * (right_foot_pos - right_foot_keypoints) ** 2
        )

        # Also penalize the foot from tilting too much -- keep z axis up!
        # R[2, 2] is the world-z component of the foot's local z axis.
        left_foot_ori = link_rot[left_foot_idx]
        right_foot_ori = link_rot[right_foot_idx]
        left_foot_contact_residual_rot = jnp.where(
            is_left_foot_contact,
            left_foot_ori[2, 2] - 1,
            0.0,
        )
        right_foot_contact_residual_rot = jnp.where(
            is_right_foot_contact,
            right_foot_ori[2, 2] - 1,
            0.0,
        )

        return (
            jnp.concatenate(
                [
                    left_foot_contact_cost.flatten(),
                    right_foot_contact_cost.flatten(),
                    left_foot_contact_residual_rot.flatten(),
                    right_foot_contact_residual_rot.flatten(),
                ]
            )
            * weights["floor_contact"]
        )

    @jaxls.Cost.create_factory
    def root_smoothness(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_Ts_world_root_prev: jaxls.SE3Var,
    ) -> jax.Array:
        """Smoothness cost for the robot root pose."""
        return (
            var_values[var_Ts_world_root].inverse() @ var_values[var_Ts_world_root_prev]
        ).log().flatten() * weights["root_smoothness"]

    @jaxls.Cost.create_factory
    def skating_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        var_offset: OffsetVar,
        var_Ts_world_root_prev: jaxls.SE3Var,
        var_robot_cfg_prev: jaxls.Var[jnp.ndarray],
        var_offset_prev: OffsetVar,
        is_left_foot_contact: jnp.ndarray,
        is_right_foot_contact: jnp.ndarray,
    ) -> jax.Array:
        """Cost to penalize the robot for skating (feet moving while in contact)."""
        T_world_root = var_values[var_Ts_world_root]
        robot_cfg = var_values[var_robot_cfg]
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
        offset = var_values[var_offset]
        T_link = T_world_root @ T_root_link
        left_foot_pos = T_link.translation()[left_foot_idx] + offset
        right_foot_pos = T_link.translation()[right_foot_idx] + offset

        T_world_root_prev = var_values[var_Ts_world_root_prev]
        robot_cfg_prev = var_values[var_robot_cfg_prev]
        T_root_link_prev = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg_prev))
        offset_prev = var_values[var_offset_prev]
        T_link_prev = T_world_root_prev @ T_root_link_prev
        left_foot_pos_prev = T_link_prev.translation()[left_foot_idx] + offset_prev
        right_foot_pos_prev = T_link_prev.translation()[right_foot_idx] + offset_prev

        skating_cost_left = is_left_foot_contact * (left_foot_pos - left_foot_pos_prev)
        skating_cost_right = is_right_foot_contact * (
            right_foot_pos - right_foot_pos_prev
        )

        return (
            jnp.stack([skating_cost_left, skating_cost_right]) * weights["foot_skating"]
        )

    @jaxls.Cost.create_factory
    def world_collision_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        var_offset: OffsetVar,
    ) -> jax.Array:
        """
        World collision; we intentionally use a low weight --
        high enough to lift the robot up from the ground, but
        low enough to not interfere with the retargeting.
        """
        Ts_world_root = var_values[var_Ts_world_root]
        T_offset = jaxlie.SE3.from_translation(var_values[var_offset])
        # Same offset-then-root composition as the returned trajectory.
        transform = T_offset @ Ts_world_root

        robot_cfg = var_values[var_robot_cfg]
        coll = robot_coll.at_config(robot, robot_cfg)
        coll = coll.transform(transform)

        dist = collide(coll, heightmap)
        act = colldist_from_sdf(dist, activation_dist=0.005)
        return act.flatten() * weights["world_collision"]

    costs: list[jaxls.Cost] = [
        # Costs that are relatively self-contained to the robot.
        retargeting_cost(
            var_Ts_world_root,
            var_joints,
            var_smpl_joints_scale,
            target_keypoints,
        ),
        scale_regularization(var_smpl_joints_scale),
        pk.costs.limit_cost(
            jax.tree.map(lambda x: x[None], robot),
            var_joints,
            100.0,
        ),
        pk.costs.smoothness_cost(
            robot.joint_var_cls(jnp.arange(1, timesteps)),
            robot.joint_var_cls(jnp.arange(0, timesteps - 1)),
            jnp.array([0.2]),
        ),
        pk.costs.rest_cost(
            var_joints,
            var_joints.default_factory()[None],
            jnp.full(var_joints.default_factory().shape, 0.2)
            .at[joints_to_move_less]
            .set(2.0)[None],
        ),
        pk.costs.self_collision_cost(
            jax.tree.map(lambda x: x[None], robot),
            jax.tree.map(lambda x: x[None], robot_coll),
            var_joints,
            margin=0.05,
            weight=2.0,
        ),
        # Costs that are scene-centric.
        pc_alignment_cost(
            var_Ts_world_root,
            var_joints,
            target_keypoints,
        ),
        floor_contact_cost(
            var_Ts_world_root,
            var_joints,
            var_offset,
            is_left_foot_contact,
            is_right_foot_contact,
            left_foot_keypoints,
            right_foot_keypoints,
        ),
        root_smoothness(
            jaxls.SE3Var(jnp.arange(1, timesteps)),
            jaxls.SE3Var(jnp.arange(0, timesteps - 1)),
        ),
        # Pairs (t, t-1) for t = 1..T-1, gated by the contact flag of the
        # earlier frame of each pair.
        skating_cost(
            jaxls.SE3Var(jnp.arange(1, timesteps)),
            robot.joint_var_cls(jnp.arange(1, timesteps)),
            # The offset is a single variable shared by all timesteps (ID 0,
            # see `var_offset`). Referencing IDs 1..T-1 would point at
            # variables that are never registered with the problem.
            OffsetVar(jnp.zeros(timesteps - 1)),
            jaxls.SE3Var(jnp.arange(0, timesteps - 1)),
            robot.joint_var_cls(jnp.arange(0, timesteps - 1)),
            OffsetVar(jnp.zeros(timesteps - 1)),
            is_left_foot_contact[:-1],
            is_right_foot_contact[:-1],
        ),
        world_collision_cost(
            var_Ts_world_root,
            var_joints,
            var_offset,
        ),
    ]

    solution = (
        jaxls.LeastSquaresProblem(
            costs, [var_joints, var_Ts_world_root, var_smpl_joints_scale, var_offset]
        )
        .analyze()
        .solve()
    )
    # Fold the solved offset into the root poses, matching the composition
    # used by the floor-contact and world-collision costs.
    transform = solution[var_Ts_world_root]
    offset = solution[var_offset]
    transform = jaxlie.SE3.from_translation(offset) @ transform
    return transform, solution[var_joints]
498
499
if __name__ == "__main__":
    # Launch the viewer only when executed as a script, not on import.
    main()