Humanoid Retargeting

Simpler motion retargeting to the G1 humanoid.

All examples can be run by first cloning the PyRoki repository, which includes the pyroki_snippets implementation details.

  1import time
  2from pathlib import Path
  3from typing import Tuple, TypedDict
  4
  5import jax
  6import jax.numpy as jnp
  7import jax_dataclasses as jdc
  8import jaxlie
  9import jaxls
 10import numpy as onp
 11import pyroki as pk
 12import viser
 13from pyroki.collision import colldist_from_sdf, collide
 14from robot_descriptions.loaders.yourdfpy import load_robot_description
 15from viser.extras import ViserUrdf
 16
 17from retarget_helpers._utils import (
 18    SMPL_JOINT_NAMES,
 19    create_conn_tree,
 20    get_humanoid_retarget_indices,
 21)
 22
 23
 24class RetargetingWeights(TypedDict):
 25    local_alignment: float
 26    """Local alignment weight, by matching the relative joint/keypoint positions and angles."""
 27    global_alignment: float
 28    """Global alignment weight, by matching the keypoint positions to the robot."""
 29
 30
 31def main():
 32    """Main function for humanoid retargeting."""
 33
 34    urdf = load_robot_description("g1_description")
 35    robot = pk.Robot.from_urdf(urdf)
 36    robot_coll = pk.collision.RobotCollision.from_urdf(urdf)
 37
 38    # Load source motion data:
 39    # - keypoints [N, 45, 3],
 40    # - left/right foot contact (boolean) 2 x [N],
 41    # - heightmap [H, W].
 42    asset_dir = Path(__file__).parent / "retarget_helpers" / "humanoid"
 43    smpl_keypoints = onp.load(asset_dir / "smpl_keypoints.npy")
 44    is_left_foot_contact = onp.load(asset_dir / "left_foot_contact.npy")
 45    is_right_foot_contact = onp.load(asset_dir / "right_foot_contact.npy")
 46    heightmap = onp.load(asset_dir / "heightmap.npy")
 47
 48    num_timesteps = smpl_keypoints.shape[0]
 49    assert smpl_keypoints.shape == (num_timesteps, 45, 3)
 50    assert is_left_foot_contact.shape == (num_timesteps,)
 51    assert is_right_foot_contact.shape == (num_timesteps,)
 52
 53    heightmap = pk.collision.Heightmap(
 54        pose=jaxlie.SE3.identity(),
 55        size=jnp.array([0.01, 0.01, 1.0]),
 56        height_data=heightmap,
 57    )
 58
 59    # Get the left and right foot keypoints, projected on the heightmap.
 60    left_foot_keypoint_idx = SMPL_JOINT_NAMES.index("left_foot")
 61    right_foot_keypoint_idx = SMPL_JOINT_NAMES.index("right_foot")
 62    left_foot_keypoints = smpl_keypoints[..., left_foot_keypoint_idx, :].reshape(-1, 3)
 63    right_foot_keypoints = smpl_keypoints[..., right_foot_keypoint_idx, :].reshape(
 64        -1, 3
 65    )
 66    left_foot_keypoints = heightmap.project_points(left_foot_keypoints)
 67    right_foot_keypoints = heightmap.project_points(right_foot_keypoints)
 68
 69    smpl_joint_retarget_indices, g1_joint_retarget_indices = (
 70        get_humanoid_retarget_indices()
 71    )
 72    smpl_mask = create_conn_tree(robot, g1_joint_retarget_indices)
 73
 74    server = viser.ViserServer()
 75    base_frame = server.scene.add_frame("/base", show_axes=False)
 76    urdf_vis = ViserUrdf(server, urdf, root_node_name="/base")
 77    playing = server.gui.add_checkbox("playing", True)
 78    timestep_slider = server.gui.add_slider("timestep", 0, num_timesteps - 1, 1, 0)
 79    server.scene.add_mesh_trimesh("/heightmap", heightmap.to_trimesh())
 80
 81    weights = pk.viewer.WeightTuner(
 82        server,
 83        RetargetingWeights(  # type: ignore
 84            local_alignment=2.0,
 85            global_alignment=1.0,
 86        ),
 87    )
 88
 89    Ts_world_root, joints = None, None
 90
 91    def generate_trajectory():
 92        nonlocal Ts_world_root, joints
 93        gen_button.disabled = True
 94        Ts_world_root, joints = solve_retargeting(
 95            robot=robot,
 96            target_keypoints=smpl_keypoints,
 97            smpl_joint_retarget_indices=smpl_joint_retarget_indices,
 98            g1_joint_retarget_indices=g1_joint_retarget_indices,
 99            smpl_mask=smpl_mask,
100            weights=weights.get_weights(),  # type: ignore
101        )
102        gen_button.disabled = False
103
104    gen_button = server.gui.add_button("Retarget!")
105    gen_button.on_click(lambda _: generate_trajectory())
106
107    generate_trajectory()
108    assert Ts_world_root is not None and joints is not None
109
110    while True:
111        with server.atomic():
112            if playing.value:
113                timestep_slider.value = (timestep_slider.value + 1) % num_timesteps
114            tstep = timestep_slider.value
115            base_frame.wxyz = onp.array(Ts_world_root.wxyz_xyz[tstep][:4])
116            base_frame.position = onp.array(Ts_world_root.wxyz_xyz[tstep][4:])
117            urdf_vis.update_cfg(onp.array(joints[tstep]))
118            server.scene.add_point_cloud(
119                "/target_keypoints",
120                onp.array(smpl_keypoints[tstep]),
121                onp.array((0, 0, 255))[None].repeat(45, axis=0),
122                point_size=0.01,
123            )
124
125        time.sleep(0.05)
126
127
@jdc.jit
def solve_retargeting(
    robot: pk.Robot,
    target_keypoints: jnp.ndarray,
    smpl_joint_retarget_indices: jnp.ndarray,
    g1_joint_retarget_indices: jnp.ndarray,
    smpl_mask: jnp.ndarray,
    weights: RetargetingWeights,
) -> Tuple[jaxlie.SE3, jnp.ndarray]:
    """Solve the retargeting problem for every timestep in one least-squares solve.

    Builds a jaxls problem whose variables are the per-timestep joint
    configurations and root poses (plus a scale variable and an offset
    variable), adds the costs listed below, and returns the solved values.

    Args:
        robot: Robot model; used here for forward kinematics, its joint
            variable class, and its actuated joint names.
        target_keypoints: Target keypoints; the leading axis is time (its
            length sets the number of timesteps).
        smpl_joint_retarget_indices: Indices selecting, from each timestep's
            keypoints, the keypoints to match.
        g1_joint_retarget_indices: Indices selecting, from the translations of
            the forward-kinematics output, the entries paired one-to-one with
            the selected keypoints.
        smpl_mask: Multiplicative mask applied to the pairwise
            (n_retarget x n_retarget) residuals of the local alignment cost.
        weights: Cost weights, read via the "local_alignment" and
            "global_alignment" keys.

    Returns:
        A tuple ``(Ts_world_root, joints)``: the solved root transforms,
        left-multiplied by a translation built from the solved offset
        variable, and the solved joint variable values.
    """

    n_retarget = len(smpl_joint_retarget_indices)
    timesteps = target_keypoints.shape[0]

    # Robot properties.
    # - Joints that should move less for natural humanoid motion.
    #   (They receive a larger rest-cost weight further down.)
    joints_to_move_less = jnp.array(
        [
            robot.joints.actuated_names.index(name)
            for name in ["left_hip_yaw_joint", "right_hip_yaw_joint", "torso_joint"]
        ]
    )

    # Variables.
    # Pairwise scale applied to the robot's relative positions; starts at all ones.
    class SmplJointsScaleVarG1(
        jaxls.Var[jax.Array], default_factory=lambda: jnp.ones((n_retarget, n_retarget))
    ): ...

    # 3-vector translation offset; starts at zero.
    class OffsetVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((3,))): ...

    # Joint and root-pose variables get one id per timestep.
    var_joints = robot.joint_var_cls(jnp.arange(timesteps))
    var_Ts_world_root = jaxls.SE3Var(jnp.arange(timesteps))
    # NOTE(review): the ids below are all 0 (zeros, not arange) -- presumably so
    # that a single scale / offset value is shared by every timestep; confirm
    # against jaxls variable-id semantics.
    var_smpl_joints_scale = SmplJointsScaleVarG1(jnp.zeros(timesteps))
    # NOTE(review): no cost defined below takes var_offset as an argument, so
    # nothing in this function constrains it.
    var_offset = OffsetVar(jnp.zeros(timesteps))

    # Costs.
    # (This empty list is overwritten by the full list assigned further down.)
    costs: list[jaxls.Cost] = []

    @jaxls.Cost.create_factory
    def retargeting_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        var_smpl_joints_scale: SmplJointsScaleVarG1,
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Retargeting factor, with a focus on:
        - matching the relative joint/keypoint positions (vectors).
        - and matching the relative angles between the vectors.

        Returns the flattened position and angle residuals, concatenated and
        scaled by ``weights["local_alignment"]``.
        """
        robot_cfg = var_values[var_robot_cfg]
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
        T_world_root = var_values[var_Ts_world_root]
        T_world_link = T_world_root @ T_root_link

        # Select the matched keypoints and the corresponding robot positions.
        smpl_pos = keypoints[jnp.array(smpl_joint_retarget_indices)]
        robot_pos = T_world_link.translation()[jnp.array(g1_joint_retarget_indices)]

        # NxN grid of relative positions.
        delta_smpl = smpl_pos[:, None] - smpl_pos[None, :]
        delta_robot = robot_pos[:, None] - robot_pos[None, :]

        # Vector regularization.
        # The robot deltas are scaled by the learned pairwise scale before
        # comparison; the (1 - eye) factor zeroes the diagonal (a point paired
        # with itself), and smpl_mask further selects which pairs contribute.
        position_scale = var_values[var_smpl_joints_scale][..., None]
        residual_position_delta = (
            (delta_smpl - delta_robot * position_scale)
            * (1 - jnp.eye(delta_smpl.shape[0])[..., None])
            * smpl_mask[..., None]
        )

        # Vector angle regularization.
        # The 1e-6 is added to the vector before taking the norm, which keeps
        # the denominator non-zero for the all-zero diagonal entries.
        delta_smpl_normalized = delta_smpl / jnp.linalg.norm(
            delta_smpl + 1e-6, axis=-1, keepdims=True
        )
        delta_robot_normalized = delta_robot / jnp.linalg.norm(
            delta_robot + 1e-6, axis=-1, keepdims=True
        )
        # 1 - dot product of the normalized direction vectors.
        residual_angle_delta = 1 - (delta_smpl_normalized * delta_robot_normalized).sum(
            axis=-1
        )
        residual_angle_delta = (
            residual_angle_delta
            * (1 - jnp.eye(residual_angle_delta.shape[0]))
            * smpl_mask
        )

        residual = (
            jnp.concatenate(
                [residual_position_delta.flatten(), residual_angle_delta.flatten()]
            )
            * weights["local_alignment"]
        )
        return residual

    @jaxls.Cost.create_factory
    def pc_alignment_cost(
        var_values: jaxls.VarValues,
        var_Ts_world_root: jaxls.SE3Var,
        var_robot_cfg: jaxls.Var[jnp.ndarray],
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Soft cost to align the human keypoints to the robot, in the world frame.

        Returns the flattened difference between the selected robot positions
        and the selected keypoints, scaled by ``weights["global_alignment"]``.
        """
        T_world_root = var_values[var_Ts_world_root]
        robot_cfg = var_values[var_robot_cfg]
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
        T_world_link = T_world_root @ T_root_link
        link_pos = T_world_link.translation()[g1_joint_retarget_indices]
        keypoint_pos = keypoints[smpl_joint_retarget_indices]
        return (link_pos - keypoint_pos).flatten() * weights["global_alignment"]

    costs = [
        # Costs that are relatively self-contained to the robot.
        retargeting_cost(
            var_Ts_world_root,
            var_joints,
            var_smpl_joints_scale,
            target_keypoints,
        ),
        # The robot pytree gets a leading axis added to every leaf --
        # presumably so it broadcasts against the per-timestep joint variables.
        pk.costs.limit_cost(
            jax.tree.map(lambda x: x[None], robot),
            var_joints,
            100.0,
        ),
        # Pairs each timestep t (1..T-1) with its predecessor t-1 (0..T-2).
        pk.costs.smoothness_cost(
            robot.joint_var_cls(jnp.arange(1, timesteps)),
            robot.joint_var_cls(jnp.arange(0, timesteps - 1)),
            jnp.array([0.2]),
        ),
        # Rest cost toward the joint variable's default value; weight 0.2
        # everywhere except the joints_to_move_less entries, which get 2.0.
        pk.costs.rest_cost(
            var_joints,
            var_joints.default_factory()[None],
            jnp.full(var_joints.default_factory().shape, 0.2)
            .at[joints_to_move_less]
            .set(2.0)[None],
        ),
        # Costs that are scene-centric.
        pc_alignment_cost(
            var_Ts_world_root,
            var_joints,
            target_keypoints,
        ),
    ]

    solution = (
        jaxls.LeastSquaresProblem(
            costs, [var_joints, var_Ts_world_root, var_smpl_joints_scale, var_offset]
        )
        .analyze()
        .solve()
    )
    transform = solution[var_Ts_world_root]
    offset = solution[var_offset]
    # Apply the solved offset as a translation on the left of the root poses.
    transform = jaxlie.SE3.from_translation(offset) @ transform
    return transform, solution[var_joints]
282
283
# Run the retargeting demo only when this file is executed as a script.
if __name__ == "__main__":
    main()