.. Comment: this file is automatically generated by `update_example_docs.py`.
   It should not be modified manually.

Hand Retargeting
==========================================


A simple Shadow Hand retargeting example: MANO hand keypoints from a DexYCB
motion clip are retargeted onto the Shadow Hand URDF and played back in viser.
Before running it, unzip the Shadow Hand URDF at
``retarget_helpers/hand/shadowhand_urdf.zip`` (relative to the example script's
directory).

All examples can be run by first cloning the PyRoki repository, which includes the ``pyroki_snippets`` implementation details.



.. code-block:: python
    :linenos:


    import pickle
    import time
    from pathlib import Path
    from typing import Tuple, TypedDict

    import jax
    import jax.numpy as jnp
    import jax_dataclasses as jdc
    import jaxlie
    import jaxls
    import numpy as onp
    import pyroki as pk
    import trimesh
    import viser
    import yourdfpy
    from scipy.spatial.transform import Rotation as R
    from viser.extras import ViserUrdf

    from retarget_helpers._utils import (
        MANO_TO_SHADOW_MAPPING,
        create_conn_tree,
        get_mapping_from_mano_to_shadow,
    )


    class RetargetingWeights(TypedDict):
        local_alignment: float
        """Local alignment weight, by matching the relative joint/keypoint positions and angles."""
        global_alignment: float
        """Global alignment weight, by matching the keypoint positions to the robot."""
        joint_smoothness: float
        """Joint smoothness weight."""
        root_smoothness: float
        """Root translation smoothness weight."""


    def main():
        """Retarget a DexYCB MANO hand clip onto the Shadow Hand and play it in viser.

        Loads the Shadow Hand URDF and ``dexycb_motion.pkl`` from
        ``retarget_helpers/hand`` next to this file, solves the retargeting problem
        (re-solvable from the "Retarget!" button with the tuned weights), then
        loops forever animating the result.

        Raises:
            FileNotFoundError: If loading the URDF raises FileNotFoundError
                (e.g. the bundled archive has not been unzipped).
            ValueError: If the source keypoints contain NaN values.
        """

        asset_dir = Path(__file__).parent / "retarget_helpers" / "hand"

        robot_urdf_path = asset_dir / "shadowhand_urdf" / "shadow_hand_right.urdf"

        def filename_handler(fname: str) -> str:
            # Resolve mesh paths referenced by the URDF relative to its own folder.
            base_path = robot_urdf_path.parent
            return yourdfpy.filename_handler_magic(fname, dir=base_path)

        try:
            urdf = yourdfpy.URDF.load(robot_urdf_path, filename_handler=filename_handler)
        except FileNotFoundError as e:
            # Chain the original error so the actual missing path is still visible.
            raise FileNotFoundError(
                "Please unzip the included URDF at `retarget_helpers/hand/shadowhand_urdf.zip`."
            ) from e

        robot = pk.Robot.from_urdf(urdf)

        # Get the mapping from MANO to Shadow Hand joints.
        shadow_link_idx, mano_joint_idx = get_mapping_from_mano_to_shadow(robot)

        # Create a mask for the MANO joints that are connected to the Shadow Hand.
        mano_mask = create_conn_tree(robot, shadow_link_idx)

        # Load source motion data.
        # NOTE: pickle can execute arbitrary code; only load trusted files here.
        dexycb_motion_path = asset_dir / "dexycb_motion.pkl"
        with open(dexycb_motion_path, "rb") as f:
            dexycb_motion_data = pickle.load(f, encoding="latin1")

        # Load keypoints.
        keypoints = dexycb_motion_data["world_hand_joints"]
        # An explicit check (rather than `assert`) so it still runs under `python -O`.
        if onp.isnan(keypoints).any():
            raise ValueError("Source keypoints (`world_hand_joints`) contain NaN values.")
        num_timesteps = keypoints.shape[0]

        # Load mano hand contact information -- these are lists of lists,
        # len(contact_points_per_frame) = num_timesteps,
        # len(contact_points_per_frame[i]) = number of contacts in frame i,
        contact_points_per_frame = dexycb_motion_data["contact_object_points"]
        contact_indices_per_frame = dexycb_motion_data["contact_joint_indices"]

        # Now, we're going to pad this info + make a mask to indicate the padded regions.
        # We will also track the shadowhand joint indices, NOT the MANO joint indices.
        # NOTE(review): the padded arrays below are not passed to solve_retargeting
        # in this example; only the raw per-frame contact points are visualized.
        max_num_contacts = max((len(c) for c in contact_points_per_frame), default=0)
        padded_contact_points_per_frame = onp.zeros((num_timesteps, max_num_contacts, 3))
        padded_contact_indices_per_frame = onp.zeros(
            (num_timesteps, max_num_contacts), dtype=onp.int32
        )
        padded_contact_mask = onp.zeros((num_timesteps, max_num_contacts), dtype=onp.bool_)
        for i in range(num_timesteps):
            num_contacts = len(contact_points_per_frame[i])
            if num_contacts == 0:
                continue
            contact_shadowhand_indices = [
                robot.links.names.index(MANO_TO_SHADOW_MAPPING[j])
                for j in contact_indices_per_frame[i]
            ]
            padded_contact_points_per_frame[i, :num_contacts] = contact_points_per_frame[i]
            padded_contact_indices_per_frame[i, :num_contacts] = contact_shadowhand_indices
            padded_contact_mask[i, :num_contacts] = True

        # Load the object.
        object_mesh_vertices = dexycb_motion_data["object_mesh_vertices"]
        object_mesh_faces = dexycb_motion_data["object_mesh_faces"]
        object_pose_list = dexycb_motion_data["object_poses"]  # (N, 4, 4)
        mesh = trimesh.Trimesh(object_mesh_vertices, object_mesh_faces)

        server = viser.ViserServer()

        # We will transform everything by the transform below, for aesthetics.
        server.scene.add_frame(
            "/scene_offset",
            show_axes=False,
            position=(-0.15415953, -0.73598871, 0.93434792),
            wxyz=(-0.381870867, 0.92421569, 0.0, 2.0004992e-32),
        )
        hand_mesh = server.scene.add_mesh_simple(
            "/scene_offset/hand_mesh",
            vertices=dexycb_motion_data["world_hand_vertices"][0, :, :],
            faces=dexycb_motion_data["hand_mesh_faces"],
            opacity=0.5,
        )
        base_frame = server.scene.add_frame("/scene_offset/base", show_axes=False)
        urdf_vis = ViserUrdf(server, urdf, root_node_name="/scene_offset/base")
        playing = server.gui.add_checkbox("playing", True)
        timestep_slider = server.gui.add_slider("timestep", 0, num_timesteps - 1, 1, 0)
        object_handle = server.scene.add_mesh_trimesh("/scene_offset/object", mesh)
        server.scene.add_grid("/grid", 2.0, 2.0)

        default_weights = RetargetingWeights(
            local_alignment=10.0,
            global_alignment=1.0,
            joint_smoothness=2.0,
            root_smoothness=2.0,
        )
        weights = pk.viewer.WeightTuner(
            server,
            default_weights,  # type: ignore
        )

        # Filled in by generate_trajectory() (via `nonlocal`) on every solve.
        Ts_world_root, joints = None, None

        def generate_trajectory():
            nonlocal Ts_world_root, joints
            gen_button.disabled = True
            try:
                Ts_world_root, joints = solve_retargeting(
                    robot=robot,
                    target_keypoints=keypoints,
                    shadow_hand_link_retarget_indices=shadow_link_idx,
                    mano_joint_retarget_indices=mano_joint_idx,
                    mano_mask=mano_mask,
                    weights=weights.get_weights(),  # type: ignore
                )
            finally:
                # Re-enable the button even if the solve raises, so the user can retry.
                gen_button.disabled = False

        gen_button = server.gui.add_button("Retarget!")
        gen_button.on_click(lambda _: generate_trajectory())

        generate_trajectory()
        assert Ts_world_root is not None and joints is not None

        while True:
            with server.atomic():
                if playing.value:
                    timestep_slider.value = (timestep_slider.value + 1) % num_timesteps
                tstep = timestep_slider.value

                # The root pose is stored as (wxyz, xyz); split it for the base frame.
                base_frame.wxyz = onp.array(Ts_world_root.wxyz_xyz[tstep][:4])
                base_frame.position = onp.array(Ts_world_root.wxyz_xyz[tstep][4:])
                urdf_vis.update_cfg(onp.array(joints[tstep]))

                # Each color array is sized from its own point array so the two
                # always have the same length.
                target_points = onp.array(keypoints[tstep]).reshape(-1, 3)
                server.scene.add_point_cloud(
                    "/scene_offset/target_keypoints",
                    target_points,
                    onp.array((0, 0, 255))[None].repeat(len(target_points), axis=0),
                    point_size=0.005,
                    point_shape="sparkle",
                )
                contact_points = onp.array(contact_points_per_frame[tstep]).reshape(-1, 3)
                server.scene.add_point_cloud(
                    "/scene_offset/contact_points",
                    contact_points,
                    onp.array((255, 0, 0))[None].repeat(len(contact_points), axis=0),
                    point_size=0.005,
                    point_shape="circle",
                )
                hand_mesh.vertices = dexycb_motion_data["world_hand_vertices"][tstep, :, :]
                object_handle.position = object_pose_list[tstep][:3, 3]
                object_handle.wxyz = R.from_matrix(object_pose_list[tstep][:3, :3]).as_quat(
                    scalar_first=True
                )

            time.sleep(0.05)


    @jdc.jit
    def solve_retargeting(
        robot: pk.Robot,
        target_keypoints: jnp.ndarray,
        shadow_hand_link_retarget_indices: jnp.ndarray,
        mano_joint_retarget_indices: jnp.ndarray,
        mano_mask: jnp.ndarray,
        weights: RetargetingWeights,
    ) -> Tuple[jaxlie.SE3, jnp.ndarray]:
        """Solve the retargeting problem.

        Jointly optimizes, over every timestep, the Shadow Hand joint configuration
        and the world-from-root transform so that the selected robot links match
        the selected MANO keypoints, both pairwise (relative vectors and angles)
        and in absolute world position.

        Args:
            robot: PyRoki robot model of the Shadow Hand.
            target_keypoints: Per-timestep MANO keypoints; the leading axis is time
                (presumably ``(timesteps, num_mano_joints, 3)`` -- confirm).
            shadow_hand_link_retarget_indices: Robot link indices to match.
            mano_joint_retarget_indices: MANO keypoint indices, paired element-wise
                with ``shadow_hand_link_retarget_indices``.
            mano_mask: Pairwise mask applied to the relative-alignment residuals.
            weights: Cost weights (see ``RetargetingWeights``).

        Returns:
            The per-timestep world-from-root ``SE3`` transforms and joint
            configurations.
        """

        n_retarget = len(mano_joint_retarget_indices)
        timesteps = target_keypoints.shape[0]

        # Variables.
        class ManoJointsScaleVar(
            jaxls.Var[jax.Array], default_factory=lambda: jnp.ones((n_retarget, n_retarget))
        ): ...

        class OffsetVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((3,))): ...

        var_joints = robot.joint_var_cls(jnp.arange(timesteps))
        var_Ts_world_root = jaxls.SE3Var(jnp.arange(timesteps))
        # All ids are 0, so every timestep refers to the same scale / offset
        # variable (presumably one value shared across the whole clip).
        var_smpl_joints_scale = ManoJointsScaleVar(jnp.zeros(timesteps))
        var_offset = OffsetVar(jnp.zeros(timesteps))

        # Costs.
        @jaxls.Cost.create_factory
        def retargeting_cost(
            var_values: jaxls.VarValues,
            var_Ts_world_root: jaxls.SE3Var,
            var_robot_cfg: jaxls.Var[jnp.ndarray],
            var_smpl_joints_scale: ManoJointsScaleVar,
            keypoints: jnp.ndarray,
        ) -> jax.Array:
            """Retargeting factor, with a focus on:
            - matching the relative joint/keypoint positions (vectors).
            - and matching the relative angles between the vectors.
            """
            robot_cfg = var_values[var_robot_cfg]
            T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
            T_world_root = var_values[var_Ts_world_root]
            T_world_link = T_world_root @ T_root_link

            mano_pos = keypoints[jnp.array(mano_joint_retarget_indices)]
            robot_pos = T_world_link.translation()[
                jnp.array(shadow_hand_link_retarget_indices)
            ]

            # NxN grid of relative positions.
            delta_mano = mano_pos[:, None] - mano_pos[None, :]
            delta_robot = robot_pos[:, None] - robot_pos[None, :]

            # Vector regularization.
            # Each robot pair vector is multiplied by an optimized per-pair scale
            # before comparison; self-pairs (the diagonal) are zeroed out.
            position_scale = var_values[var_smpl_joints_scale][..., None]
            residual_position_delta = (
                (delta_mano - delta_robot * position_scale)
                * (1 - jnp.eye(delta_mano.shape[0])[..., None])
                * mano_mask[..., None]
            )

            # Vector angle regularization.
            # Residual is 1 - (dot product of the unit vectors). The 1e-6 is added
            # to the vector before the norm, so zero vectors (the diagonal) are
            # divided by a small positive norm instead of zero.
            delta_mano_normalized = delta_mano / jnp.linalg.norm(
                delta_mano + 1e-6, axis=-1, keepdims=True
            )
            delta_robot_normalized = delta_robot / jnp.linalg.norm(
                delta_robot + 1e-6, axis=-1, keepdims=True
            )
            residual_angle_delta = 1 - (delta_mano_normalized * delta_robot_normalized).sum(
                axis=-1
            )
            residual_angle_delta = (
                residual_angle_delta
                * (1 - jnp.eye(residual_angle_delta.shape[0]))
                * mano_mask
            )

            residual = (
                jnp.concatenate(
                    [
                        residual_position_delta.flatten(),
                        residual_angle_delta.flatten(),
                    ],
                    axis=0,
                )
                * weights["local_alignment"]
            )
            return residual

        @jaxls.Cost.create_factory
        def pc_alignment_cost(
            var_values: jaxls.VarValues,
            var_Ts_world_root: jaxls.SE3Var,
            var_robot_cfg: jaxls.Var[jnp.ndarray],
            keypoints: jnp.ndarray,
        ) -> jax.Array:
            """Soft cost to align the human keypoints to the robot, in the world frame."""
            T_world_root = var_values[var_Ts_world_root]
            robot_cfg = var_values[var_robot_cfg]
            T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
            T_world_link = T_world_root @ T_root_link
            link_pos = T_world_link.translation()[shadow_hand_link_retarget_indices]
            keypoint_pos = keypoints[mano_joint_retarget_indices]
            return (link_pos - keypoint_pos).flatten() * weights["global_alignment"]

        @jaxls.Cost.create_factory
        def root_smoothness(
            var_values: jaxls.VarValues,
            var_Ts_world_root: jaxls.SE3Var,
            var_Ts_world_root_prev: jaxls.SE3Var,
        ) -> jax.Array:
            """Smoothness cost for the robot root translation."""
            return (
                var_values[var_Ts_world_root].translation()
                - var_values[var_Ts_world_root_prev].translation()
            ).flatten() * weights["root_smoothness"]

        costs: list[jaxls.Cost] = [
            retargeting_cost(
                var_Ts_world_root,
                var_joints,
                var_smpl_joints_scale,
                target_keypoints,
            ),
            # `x[None]` adds a leading axis to every robot leaf (presumably so it
            # broadcasts against the batched joint variables).
            pk.costs.limit_cost(
                jax.tree.map(lambda x: x[None], robot),
                var_joints,
                100.0,
            ),
            # Pairs timestep t with t - 1.
            pk.costs.smoothness_cost(
                robot.joint_var_cls(jnp.arange(1, timesteps)),
                robot.joint_var_cls(jnp.arange(0, timesteps - 1)),
                jnp.array([weights["joint_smoothness"]]),
            ),
            pc_alignment_cost(
                var_Ts_world_root,
                var_joints,
                target_keypoints,
            ),
            root_smoothness(
                jaxls.SE3Var(jnp.arange(1, timesteps)),
                jaxls.SE3Var(jnp.arange(0, timesteps - 1)),
            ),
        ]

        solution = (
            jaxls.LeastSquaresProblem(
                costs, [var_joints, var_Ts_world_root, var_smpl_joints_scale, var_offset]
            )
            .analyze()
            .solve()
        )
        transform = solution[var_Ts_world_root]
        # NOTE(review): no cost above references var_offset, so nothing pulls it
        # away from its zero default; this composition is presumably an identity.
        offset = solution[var_offset]
        transform = jaxlie.SE3.from_translation(offset) @ transform
        return transform, solution[var_joints]


    if __name__ == "__main__":
        main()