.. Comment: this file is automatically generated by `update_example_docs.py`.
   It should not be modified manually.

Hand Retargeting (Fancy)
==========================================

Shadow Hand retargeting example, with costs to maintain contact with the object.
Find and unzip the shadowhand URDF at ``retarget_helpers/hand/shadowhand_urdf.zip``
(the ``retarget_helpers`` directory sits next to this example script).

All examples can be run by first cloning the PyRoki repository, which includes the ``pyroki_snippets`` implementation details.


.. code-block:: python
    :linenos:

    import time
    from typing import Tuple, TypedDict
    from pathlib import Path
    import pickle
    import trimesh
    from scipy.spatial.transform import Rotation as R

    import jax
    import jax.numpy as jnp
    import jax_dataclasses as jdc
    import jaxlie
    import jaxls
    import numpy as onp
    import viser
    from viser.extras import ViserUrdf
    import yourdfpy

    import pyroki as pk

    from retarget_helpers._utils import (
        create_conn_tree,
        get_mapping_from_mano_to_shadow,
        MANO_TO_SHADOW_MAPPING,
    )


    class RetargetingWeights(TypedDict):
        """Weights for the costs in `solve_retargeting`, tunable live from the viser GUI."""

        local_alignment: float
        """Local alignment weight, by matching the relative joint/keypoint positions and angles."""
        global_alignment: float
        """Global alignment weight, by matching the keypoint positions to the robot."""
        contact: float
        """Contact weight, to maintain contact between the robot and the object."""
        contact_margin: float
        """Contact margin, to stop penalizing contact when the robot is already close to the object."""
        joint_smoothness: float
        """Joint smoothness weight."""
        root_smoothness: float
        """Root translation smoothness weight."""


    def main():
        """Load the Shadow Hand and DexYCB motion data, retarget, and play the result in viser."""

        asset_dir = Path(__file__).parent / "retarget_helpers" / "hand"

        robot_urdf_path = asset_dir / "shadowhand_urdf" / "shadow_hand_right.urdf"

        def filename_handler(fname: str) -> str:
            # Resolve mesh paths referenced by the URDF relative to the URDF's directory.
            base_path = robot_urdf_path.parent
            return yourdfpy.filename_handler_magic(fname, dir=base_path)

        try:
            urdf = yourdfpy.URDF.load(robot_urdf_path, filename_handler=filename_handler)
        except FileNotFoundError as err:
            raise FileNotFoundError(
                "Please unzip the included URDF at `retarget_helpers/hand/shadowhand_urdf.zip`."
            ) from err

        robot = pk.Robot.from_urdf(urdf)

        # Get the mapping from MANO to Shadow Hand joints.
        shadow_link_idx, mano_joint_idx = get_mapping_from_mano_to_shadow(robot)

        # Create a mask for the MANO joints that are connected to the Shadow Hand.
        mano_mask = create_conn_tree(robot, shadow_link_idx)

        # Load source motion data.
        # NOTE: pickle can execute arbitrary code -- only load trusted files such as
        # the bundled asset.
        dexycb_motion_path = asset_dir / "dexycb_motion.pkl"
        with open(dexycb_motion_path, "rb") as f:
            dexycb_motion_data = pickle.load(f, encoding="latin1")

        # Load keypoints; the leading axis is time.
        keypoints = dexycb_motion_data["world_hand_joints"]
        if onp.isnan(keypoints).any():
            raise ValueError("Keypoints in the motion data contain NaN values.")
        num_timesteps = keypoints.shape[0]

        # Load mano hand contact information -- these are lists of lists,
        # len(contact_points_per_frame) = num_timesteps,
        # len(contact_points_per_frame[i]) = number of contacts in frame i,
        contact_points_per_frame = dexycb_motion_data["contact_object_points"]
        contact_indices_per_frame = dexycb_motion_data["contact_joint_indices"]

        # Now, we're going to pad this info + make a mask to indicate the padded regions.
        # We will also track the shadowhand link indices, NOT the MANO joint indices.
        # Padded slots keep index 0 and mask False, so they contribute nothing.
        max_num_contacts = max(len(c) for c in contact_points_per_frame)
        padded_contact_points_per_frame = onp.zeros((num_timesteps, max_num_contacts, 3))
        padded_contact_indices_per_frame = onp.zeros(
            (num_timesteps, max_num_contacts), dtype=onp.int32
        )
        padded_contact_mask = onp.zeros((num_timesteps, max_num_contacts), dtype=onp.bool_)
        for i in range(num_timesteps):
            num_contacts = len(contact_points_per_frame[i])
            if num_contacts == 0:
                continue
            contact_shadowhand_indices = [
                robot.links.names.index(MANO_TO_SHADOW_MAPPING[j])
                for j in contact_indices_per_frame[i]
            ]
            padded_contact_points_per_frame[i, :num_contacts] = contact_points_per_frame[i]
            padded_contact_indices_per_frame[i, :num_contacts] = contact_shadowhand_indices
            padded_contact_mask[i, :num_contacts] = True

        # Load the object.
        object_mesh_vertices = dexycb_motion_data["object_mesh_vertices"]
        object_mesh_faces = dexycb_motion_data["object_mesh_faces"]
        object_pose_list = dexycb_motion_data["object_poses"]  # (N, 4, 4)
        mesh = trimesh.Trimesh(object_mesh_vertices, object_mesh_faces)

        server = viser.ViserServer()

        # We will transform everything by the transform below, for aesthetics.
        server.scene.add_frame(
            "/scene_offset",
            show_axes=False,
            position=(-0.15415953, -0.73598871, 0.93434792),
            wxyz=(-0.381870867, 0.92421569, 0.0, 2.0004992e-32),
        )
        base_frame = server.scene.add_frame("/scene_offset/base", show_axes=False)
        urdf_vis = ViserUrdf(server, urdf, root_node_name="/scene_offset/base")
        playing = server.gui.add_checkbox("playing", True)
        timestep_slider = server.gui.add_slider("timestep", 0, num_timesteps - 1, 1, 0)
        object_handle = server.scene.add_mesh_trimesh("/scene_offset/object", mesh)
        server.scene.add_grid("/grid", 2.0, 2.0)

        default_weights = RetargetingWeights(
            local_alignment=10.0,
            global_alignment=1.0,
            contact=5.0,
            contact_margin=0.01,
            joint_smoothness=2.0,
            root_smoothness=2.0,
        )
        weights = pk.viewer.WeightTuner(
            server,
            default_weights,  # type: ignore
        )

        Ts_world_root, joints = None, None

        def generate_trajectory():
            """Re-run the solver with the current GUI weights and store the result."""
            nonlocal Ts_world_root, joints
            gen_button.disabled = True
            try:
                Ts_world_root, joints = solve_retargeting(
                    robot=robot,
                    target_keypoints=keypoints,
                    shadow_hand_link_retarget_indices=shadow_link_idx,
                    mano_joint_retarget_indices=mano_joint_idx,
                    mano_mask=mano_mask,
                    contact_points_per_frame=jnp.array(padded_contact_points_per_frame),
                    contact_indices_per_frame=jnp.array(padded_contact_indices_per_frame),
                    contact_mask=jnp.array(padded_contact_mask),
                    weights=weights.get_weights(),  # type: ignore
                )
            finally:
                # Re-enable the button even if the solve fails, so it can be retried.
                gen_button.disabled = False

        gen_button = server.gui.add_button("Retarget!")
        gen_button.on_click(lambda _: generate_trajectory())

        generate_trajectory()
        assert Ts_world_root is not None and joints is not None

        while True:
            with server.atomic():
                if playing.value:
                    timestep_slider.value = (timestep_slider.value + 1) % num_timesteps
                tstep = timestep_slider.value
                base_frame.wxyz = onp.array(Ts_world_root.wxyz_xyz[tstep][:4])
                base_frame.position = onp.array(Ts_world_root.wxyz_xyz[tstep][4:])
                urdf_vis.update_cfg(onp.array(joints[tstep]))

                # Size the color array from the points actually drawn, so the two
                # always agree.
                target_points = onp.array(keypoints[tstep]).reshape(-1, 3)
                server.scene.add_point_cloud(
                    "/scene_offset/target_keypoints",
                    target_points,
                    onp.array((0, 0, 255))[None]
                    .repeat(len(target_points), axis=0)
                    .reshape(-1, 3),
                    point_size=0.005,
                    point_shape="sparkle",
                )
                server.scene.add_point_cloud(
                    "/scene_offset/contact_points",
                    onp.array(contact_points_per_frame[tstep]).reshape(-1, 3),
                    onp.array((255, 0, 0))[None]
                    .repeat(len(contact_points_per_frame[tstep]), axis=0)
                    .reshape(-1, 3),
                    point_size=0.005,
                    point_shape="circle",
                )
                object_handle.position = object_pose_list[tstep][:3, 3]
                object_handle.wxyz = R.from_matrix(object_pose_list[tstep][:3, :3]).as_quat(
                    scalar_first=True
                )

            time.sleep(0.05)


    @jdc.jit
    def solve_retargeting(
        robot: pk.Robot,
        target_keypoints: jnp.ndarray,
        shadow_hand_link_retarget_indices: jnp.ndarray,
        mano_joint_retarget_indices: jnp.ndarray,
        mano_mask: jnp.ndarray,
        contact_points_per_frame: jnp.ndarray,
        contact_indices_per_frame: jnp.ndarray,
        contact_mask: jnp.ndarray,
        weights: RetargetingWeights,
    ) -> Tuple[jaxlie.SE3, jnp.ndarray]:
        """Solve the retargeting problem.

        Args:
            robot: Shadow Hand robot model.
            target_keypoints: Per-timestep MANO keypoint positions; the leading axis is time.
            shadow_hand_link_retarget_indices: Robot link indices paired with
                `mano_joint_retarget_indices`.
            mano_joint_retarget_indices: MANO keypoint indices to retarget.
            mano_mask: Pairwise mask over the retargeted keypoints, applied to the
                NxN relative-position and angle residuals.
            contact_points_per_frame: Padded object contact points, (T, C, 3).
            contact_indices_per_frame: Padded robot *link* indices, (T, C).
            contact_mask: True for real (non-padded) contacts, (T, C).
            weights: Cost weights.

        Returns:
            The per-timestep root poses (batched `jaxlie.SE3`) and joint configurations.
        """

        n_retarget = len(mano_joint_retarget_indices)
        timesteps = target_keypoints.shape[0]

        # Variables.
        class ManoJointsScaleVar(
            jaxls.Var[jax.Array], default_factory=lambda: jnp.ones((n_retarget, n_retarget))
        ): ...

        class OffsetVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((3,))): ...

        var_joints = robot.joint_var_cls(jnp.arange(timesteps))
        var_Ts_world_root = jaxls.SE3Var(jnp.arange(timesteps))
        # Every timestep refers to id 0, so one scale matrix / one offset is shared by
        # the whole trajectory.
        var_mano_joints_scale = ManoJointsScaleVar(jnp.zeros(timesteps))
        var_offset = OffsetVar(jnp.zeros(timesteps))

        # Costs.
        @jaxls.Cost.create_factory
        def retargeting_cost(
            var_values: jaxls.VarValues,
            var_Ts_world_root: jaxls.SE3Var,
            var_robot_cfg: jaxls.Var[jnp.ndarray],
            var_mano_joints_scale: ManoJointsScaleVar,
            keypoints: jnp.ndarray,
        ) -> jax.Array:
            """Retargeting factor, with a focus on:
            - matching the relative joint/keypoint positions (vectors).
            - and matching the relative angles between the vectors.
            """
            robot_cfg = var_values[var_robot_cfg]
            T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
            T_world_root = var_values[var_Ts_world_root]
            T_world_link = T_world_root @ T_root_link

            mano_pos = keypoints[jnp.array(mano_joint_retarget_indices)]
            robot_pos = T_world_link.translation()[
                jnp.array(shadow_hand_link_retarget_indices)
            ]

            # NxN grid of relative positions.
            delta_mano = mano_pos[:, None] - mano_pos[None, :]
            delta_robot = robot_pos[:, None] - robot_pos[None, :]

            # Vector regularization; the diagonal (self-pairs) is zeroed out.
            position_scale = var_values[var_mano_joints_scale][..., None]
            residual_position_delta = (
                (delta_mano - delta_robot * position_scale)
                * (1 - jnp.eye(delta_mano.shape[0])[..., None])
                * mano_mask[..., None]
            )

            # Vector angle regularization.
            # The epsilon is added to the vectors *before* the norm, so the exactly-zero
            # diagonal vectors never reach norm(0) -- presumably to avoid 0/0 and NaN
            # gradients there.
            delta_mano_normalized = delta_mano / jnp.linalg.norm(
                delta_mano + 1e-6, axis=-1, keepdims=True
            )
            delta_robot_normalized = delta_robot / jnp.linalg.norm(
                delta_robot + 1e-6, axis=-1, keepdims=True
            )
            residual_angle_delta = 1 - (delta_mano_normalized * delta_robot_normalized).sum(
                axis=-1
            )
            residual_angle_delta = (
                residual_angle_delta
                * (1 - jnp.eye(residual_angle_delta.shape[0]))
                * mano_mask
            )

            residual = (
                jnp.concatenate(
                    [
                        residual_position_delta.flatten(),
                        residual_angle_delta.flatten(),
                    ],
                    axis=0,
                )
                * weights["local_alignment"]
            )
            return residual

        @jaxls.Cost.create_factory
        def scale_regularization(
            var_values: jaxls.VarValues,
            var_mano_joints_scale: ManoJointsScaleVar,
        ) -> jax.Array:
            """Regularize the scale of the retargeted joints."""
            scale = var_values[var_mano_joints_scale]
            # Close to 1.
            res_0 = (scale - 1.0).flatten() * 1.0
            # Symmetric.
            res_1 = (scale - scale.T).flatten() * 100.0
            # Non-negative.
            res_2 = jnp.clip(-scale, min=0).flatten() * 100.0
            return jnp.concatenate([res_0, res_1, res_2])

        @jaxls.Cost.create_factory
        def pc_alignment_cost(
            var_values: jaxls.VarValues,
            var_Ts_world_root: jaxls.SE3Var,
            var_robot_cfg: jaxls.Var[jnp.ndarray],
            keypoints: jnp.ndarray,
        ) -> jax.Array:
            """Soft cost to align the human keypoints to the robot, in the world frame."""
            T_world_root = var_values[var_Ts_world_root]
            robot_cfg = var_values[var_robot_cfg]
            T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
            T_world_link = T_world_root @ T_root_link
            link_pos = T_world_link.translation()[shadow_hand_link_retarget_indices]
            keypoint_pos = keypoints[mano_joint_retarget_indices]
            return (link_pos - keypoint_pos).flatten() * weights["global_alignment"]

        @jaxls.Cost.create_factory
        def root_smoothness(
            var_values: jaxls.VarValues,
            var_Ts_world_root: jaxls.SE3Var,
            var_Ts_world_root_prev: jaxls.SE3Var,
        ) -> jax.Array:
            """Smoothness cost for the robot root translation."""
            return (
                var_values[var_Ts_world_root].translation()
                - var_values[var_Ts_world_root_prev].translation()
            ).flatten() * weights["root_smoothness"]

        @jaxls.Cost.create_factory
        def contact_cost(
            var_values: jaxls.VarValues,
            var_T_world_root: jaxls.SE3Var,
            var_robot_cfg: jaxls.Var[jnp.ndarray],
            contact_points: jax.Array,  # (C, 3) for one timestep
            contact_indices: jax.Array,  # (C,) - robot link indices.
            contact_points_mask: jax.Array,  # (C,)
        ) -> jax.Array:
            """Cost for keeping robot links in contact with object points.

            The cost is applied per timestep, so the inputs are one row of the padded
            per-frame arrays: `contact_points` is (C, 3), `contact_indices` is (C,) and
            `contact_points_mask` is (C,), where C is the max number of contacts in any
            frame. Slot k pairs link `contact_indices[k]` with `contact_points[k]`.
            """
            robot_cfg = var_values[var_robot_cfg]
            T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
            T_world_root = var_values[var_T_world_root]
            T_world_link = T_world_root @ T_root_link
            contact_link_positions_world = T_world_link.translation()[contact_indices]  # (C, 3)

            # Contact points are already in world frame (as processed in dexycb).
            residual = contact_points - contact_link_positions_world  # (C, 3)

            # Penalize distance beyond a margin. The margin is applied per axis
            # (|dx|, |dy|, |dz|), i.e. an axis-aligned box, not a Euclidean ball.
            residual_penalty = jnp.maximum(
                jnp.abs(residual) - weights["contact_margin"], 0.0
            )  # (C, 3)

            # Zero out padded contact slots.
            residual_penalty = residual_penalty * contact_points_mask[..., None]  # (C, 3)

            return residual_penalty.flatten() * weights["contact"]

        costs: list[jaxls.Cost] = [
            # Costs that are relatively self-contained to the robot.
            retargeting_cost(
                var_Ts_world_root,
                var_joints,
                var_mano_joints_scale,
                target_keypoints,
            ),
            scale_regularization(var_mano_joints_scale),
            pk.costs.limit_cost(
                jax.tree.map(lambda x: x[None], robot),
                var_joints,
                100.0,
            ),
            pk.costs.smoothness_cost(
                robot.joint_var_cls(jnp.arange(1, timesteps)),
                robot.joint_var_cls(jnp.arange(0, timesteps - 1)),
                jnp.array([weights["joint_smoothness"]]),
            ),
            pk.costs.rest_cost(
                var_joints,
                var_joints.default_factory()[None],
                jnp.array([0.2]),
            ),
            # Costs that are scene-centric.
            pc_alignment_cost(
                var_Ts_world_root,
                var_joints,
                target_keypoints,
            ),
            root_smoothness(
                jaxls.SE3Var(jnp.arange(1, timesteps)),
                jaxls.SE3Var(jnp.arange(0, timesteps - 1)),
            ),
            contact_cost(
                var_T_world_root=var_Ts_world_root,
                var_robot_cfg=var_joints,
                contact_points=contact_points_per_frame,
                contact_indices=contact_indices_per_frame,
                contact_points_mask=contact_mask,
            ),
        ]

        solution = (
            jaxls.LeastSquaresProblem(
                costs, [var_joints, var_Ts_world_root, var_mano_joints_scale, var_offset]
            )
            .analyze()
            .solve()
        )
        transform = solution[var_Ts_world_root]
        # NOTE(review): no cost above references `var_offset`, so it presumably stays
        # at its zero default and this composition leaves `transform` unchanged.
        offset = solution[var_offset]
        transform = jaxlie.SE3.from_translation(offset) @ transform
        return transform, solution[var_joints]


    if __name__ == "__main__":
        main()