Humanoid Retargeting

Simpler motion retargeting to the G1 humanoid.

All examples can be run by first cloning the PyRoki repository, which includes the pyroki_snippets implementation details.

  1import time
  2from pathlib import Path
  3from typing import Tuple, TypedDict
  4
  5import jax
  6import jax.numpy as jnp
  7import jax_dataclasses as jdc
  8import jaxlie
  9import jaxls
 10import numpy as onp
 11import pyroki as pk
 12import viser
 13from robot_descriptions.loaders.yourdfpy import load_robot_description
 14from viser.extras import ViserUrdf
 15
 16from retarget_helpers._utils import (
 17    SMPL_JOINT_NAMES,
 18    create_conn_tree,
 19    get_humanoid_retarget_indices,
 20)
 21
 22
 23class RetargetingWeights(TypedDict):
 24    local_alignment: float
 25    """Local alignment weight, by matching the relative joint/keypoint positions and angles."""
 26    global_alignment: float
 27    """Global alignment weight, by matching the keypoint positions to the robot."""
 28
 29
 30def main():
 31    """Main function for humanoid retargeting."""
 32
 33    urdf = load_robot_description("g1_description")
 34    robot = pk.Robot.from_urdf(urdf)
 35
 36    # Load source motion data:
 37    # - keypoints [N, 45, 3],
 38    # - left/right foot contact (boolean) 2 x [N],
 39    # - heightmap [H, W].
 40    asset_dir = Path(__file__).parent / "retarget_helpers" / "humanoid"
 41    smpl_keypoints = onp.load(asset_dir / "smpl_keypoints.npy")
 42    is_left_foot_contact = onp.load(asset_dir / "left_foot_contact.npy")
 43    is_right_foot_contact = onp.load(asset_dir / "right_foot_contact.npy")
 44    heightmap = onp.load(asset_dir / "heightmap.npy")
 45
 46    num_timesteps = smpl_keypoints.shape[0]
 47    assert smpl_keypoints.shape == (num_timesteps, 45, 3)
 48    assert is_left_foot_contact.shape == (num_timesteps,)
 49    assert is_right_foot_contact.shape == (num_timesteps,)
 50
 51    heightmap = pk.collision.Heightmap(
 52        pose=jaxlie.SE3.identity(),
 53        size=jnp.array([0.01, 0.01, 1.0]),
 54        height_data=heightmap,
 55    )
 56
 57    # Get the left and right foot keypoints, projected on the heightmap.
 58    left_foot_keypoint_idx = SMPL_JOINT_NAMES.index("left_foot")
 59    right_foot_keypoint_idx = SMPL_JOINT_NAMES.index("right_foot")
 60    left_foot_keypoints = smpl_keypoints[..., left_foot_keypoint_idx, :].reshape(-1, 3)
 61    right_foot_keypoints = smpl_keypoints[..., right_foot_keypoint_idx, :].reshape(
 62        -1, 3
 63    )
 64    left_foot_keypoints = heightmap.project_points(left_foot_keypoints)
 65    right_foot_keypoints = heightmap.project_points(right_foot_keypoints)
 66
 67    smpl_joint_retarget_indices, g1_joint_retarget_indices = (
 68        get_humanoid_retarget_indices()
 69    )
 70    smpl_mask = create_conn_tree(robot, g1_joint_retarget_indices)
 71
 72    server = viser.ViserServer()
 73    base_frame = server.scene.add_frame("/base", show_axes=False)
 74    urdf_vis = ViserUrdf(server, urdf, root_node_name="/base")
 75    playing = server.gui.add_checkbox("playing", True)
 76    timestep_slider = server.gui.add_slider("timestep", 0, num_timesteps - 1, 1, 0)
 77    server.scene.add_mesh_trimesh("/heightmap", heightmap.to_trimesh())
 78
 79    weights = pk.viewer.WeightTuner(
 80        server,
 81        RetargetingWeights(  # type: ignore
 82            local_alignment=2.0,
 83            global_alignment=1.0,
 84        ),
 85    )
 86
 87    Ts_world_root, joints = None, None
 88
 89    def generate_trajectory():
 90        nonlocal Ts_world_root, joints
 91        gen_button.disabled = True
 92        Ts_world_root, joints = solve_retargeting(
 93            robot=robot,
 94            target_keypoints=smpl_keypoints,
 95            smpl_joint_retarget_indices=smpl_joint_retarget_indices,
 96            g1_joint_retarget_indices=g1_joint_retarget_indices,
 97            smpl_mask=smpl_mask,
 98            weights=weights.get_weights(),  # type: ignore
 99        )
100        gen_button.disabled = False
101
102    gen_button = server.gui.add_button("Retarget!")
103    gen_button.on_click(lambda _: generate_trajectory())
104
105    generate_trajectory()
106    assert Ts_world_root is not None and joints is not None
107
108    while True:
109        with server.atomic():
110            if playing.value:
111                timestep_slider.value = (timestep_slider.value + 1) % num_timesteps
112            tstep = timestep_slider.value
113            base_frame.wxyz = onp.array(Ts_world_root.wxyz_xyz[tstep][:4])
114            base_frame.position = onp.array(Ts_world_root.wxyz_xyz[tstep][4:])
115            urdf_vis.update_cfg(onp.array(joints[tstep]))
116            server.scene.add_point_cloud(
117                "/target_keypoints",
118                onp.array(smpl_keypoints[tstep]),
119                onp.array((0, 0, 255))[None].repeat(45, axis=0),
120                point_size=0.01,
121            )
122
123        time.sleep(0.05)
124
125
@jdc.jit
def solve_retargeting(
    robot: pk.Robot,
    target_keypoints: jnp.ndarray,
    smpl_joint_retarget_indices: jnp.ndarray,
    g1_joint_retarget_indices: jnp.ndarray,
    smpl_mask: jnp.ndarray,
    weights: RetargetingWeights,
) -> Tuple[jaxlie.SE3, jnp.ndarray]:
    """Retarget a keypoint trajectory onto the robot via least squares.

    Builds one joint-configuration variable and one SE(3) root-pose variable
    per timestep, plus a pairwise scale variable and a translation offset
    variable, and solves a jaxls least-squares problem over them.

    Args:
        robot: Robot model providing forward kinematics and joint variables.
        target_keypoints: Source keypoints; the leading axis is time.
        smpl_joint_retarget_indices: Indices into the source keypoints.
        g1_joint_retarget_indices: Indices into the robot links, paired
            one-to-one with ``smpl_joint_retarget_indices``.
        smpl_mask: Pairwise mask multiplied into the local-alignment
            residuals.
        weights: Scalar weights for the local and global alignment costs.

    Returns:
        A tuple of (root transforms, joint configurations) read from the
        solution.
    """

    num_frames = target_keypoints.shape[0]
    num_pairs = len(smpl_joint_retarget_indices)

    # Joints that should move less for natural humanoid motion; they get a
    # larger rest-pose weight below.
    stiff_joint_names = ["left_hip_yaw_joint", "right_hip_yaw_joint", "waist_yaw_joint"]
    joints_to_move_less = jnp.array(
        [robot.joints.actuated_names.index(name) for name in stiff_joint_names]
    )

    # --- Variable types ---------------------------------------------------
    class SmplJointsScaleVarG1(
        jaxls.Var[jax.Array], default_factory=lambda: jnp.ones((num_pairs, num_pairs))
    ): ...

    class OffsetVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((3,))): ...

    # --- Variable instances -----------------------------------------------
    # Joints and root pose use one id per timestep. The scale and offset
    # variables are created with all-zero ids.
    # NOTE(review): presumably this makes a single scale/offset shared by all
    # timesteps; confirm against jaxls id semantics.
    var_joints = robot.joint_var_cls(jnp.arange(num_frames))
    var_Ts_world_root = jaxls.SE3Var(jnp.arange(num_frames))
    var_smpl_joints_scale = SmplJointsScaleVarG1(jnp.zeros(num_frames))
    var_offset = OffsetVar(jnp.zeros(num_frames))

    def _world_link_translations(var_values, root_var, cfg_var):
        """World-frame translations of the robot links for the given variables."""
        T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=var_values[cfg_var]))
        T_world_link = var_values[root_var] @ T_root_link
        return T_world_link.translation()

    def _normalize_pairs(delta):
        """Divide pairwise vectors by their norm; the 1e-6 shift avoids 0/0."""
        return delta / jnp.linalg.norm(delta + 1e-6, axis=-1, keepdims=True)

    # --- Costs --------------------------------------------------------------
    @jaxls.Cost.factory
    def retargeting_cost(
        var_values: jaxls.VarValues,
        root_var: jaxls.SE3Var,
        cfg_var: jaxls.Var[jnp.ndarray],
        scale_var: SmplJointsScaleVarG1,
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Local alignment residual.

        Compares the source and robot skeletons through their pairwise
        keypoint offsets: once as scaled vectors and once as the angle
        between the normalized vectors.
        """
        link_translations = _world_link_translations(var_values, root_var, cfg_var)
        smpl_pos = keypoints[jnp.array(smpl_joint_retarget_indices)]
        robot_pos = link_translations[jnp.array(g1_joint_retarget_indices)]

        # NxN grid of relative positions.
        delta_smpl = smpl_pos[:, None] - smpl_pos[None, :]
        delta_robot = robot_pos[:, None] - robot_pos[None, :]

        # Zero on the diagonal: a point paired with itself carries no signal.
        off_diagonal = 1 - jnp.eye(delta_smpl.shape[0])

        # Vector term: robot offsets scaled per pair should match the source.
        position_scale = var_values[scale_var][..., None]
        scaled_mismatch = delta_smpl - delta_robot * position_scale
        residual_position_delta = (
            scaled_mismatch * off_diagonal[..., None] * smpl_mask[..., None]
        )

        # Angle term: 1 - cosine between matching pair directions.
        cosine = (_normalize_pairs(delta_smpl) * _normalize_pairs(delta_robot)).sum(
            axis=-1
        )
        residual_angle_delta = (1 - cosine) * off_diagonal * smpl_mask

        stacked = jnp.concatenate(
            [residual_position_delta.flatten(), residual_angle_delta.flatten()]
        )
        return stacked * weights["local_alignment"]

    @jaxls.Cost.factory
    def pc_alignment_cost(
        var_values: jaxls.VarValues,
        root_var: jaxls.SE3Var,
        cfg_var: jaxls.Var[jnp.ndarray],
        keypoints: jnp.ndarray,
    ) -> jax.Array:
        """Global alignment residual: robot link positions vs. source keypoints in world frame."""
        link_translations = _world_link_translations(var_values, root_var, cfg_var)
        link_pos = link_translations[g1_joint_retarget_indices]
        keypoint_pos = keypoints[smpl_joint_retarget_indices]
        return (link_pos - keypoint_pos).flatten() * weights["global_alignment"]

    # Rest-pose weights: 0.2 everywhere, 2.0 on the joints that should move less.
    rest_pose = var_joints.default_factory()
    rest_weights = jnp.full(rest_pose.shape, 0.2).at[joints_to_move_less].set(2.0)

    costs: list[jaxls.Cost] = [
        # Costs that are relatively self-contained to the robot.
        retargeting_cost(
            var_Ts_world_root,
            var_joints,
            var_smpl_joints_scale,
            target_keypoints,
        ),
        pk.costs.smoothness_cost(
            robot.joint_var_cls(jnp.arange(1, num_frames)),
            robot.joint_var_cls(jnp.arange(0, num_frames - 1)),
            jnp.array([0.2]),
        ),
        pk.costs.rest_cost(
            var_joints,
            var_joints.default_factory()[None],
            rest_weights[None],
        ),
        # Costs that are scene-centric.
        pc_alignment_cost(
            var_Ts_world_root,
            var_joints,
            target_keypoints,
        ),
        # Joint limits, with a leading batch axis added to the robot pytree.
        pk.costs.limit_constraint(
            jax.tree.map(lambda x: x[None], robot),
            var_joints,
        ),
    ]

    problem = jaxls.LeastSquaresProblem(
        costs=costs,
        variables=[
            var_joints,
            var_Ts_world_root,
            var_smpl_joints_scale,
            var_offset,
        ],
    )
    solution = problem.analyze().solve()

    # The offset variable appears in no cost above; it is applied to the root
    # transforms as a left translation.
    root_transforms = (
        jaxlie.SE3.from_translation(solution[var_offset]) @ solution[var_Ts_world_root]
    )
    return root_transforms, solution[var_joints]
288
289
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()