Trajectory Optimization

Basic Trajectory Optimization using PyRoKi.

The robot moves its end effector over a wall while avoiding collisions with the world.

All examples can be run after cloning the PyRoKi repository, which includes the `pyroki_snippets` implementation details.

  1import time
  2from typing import Literal
  3
  4import numpy as np
  5import pyroki as pk
  6import trimesh
  7import tyro
  8import viser
  9from viser.extras import ViserUrdf
 10from robot_descriptions.loaders.yourdfpy import load_robot_description
 11
 12import pyroki_snippets as pks
 13
 14
 15def main(robot_name: Literal["ur5", "panda"] = "panda"):
 16    if robot_name == "ur5":
 17        urdf = load_robot_description("ur5_description")
 18        down_wxyz = np.array([0.707, 0, 0.707, 0])
 19        target_link_name = "ee_link"
 20
 21        # For UR5 it's important to initialize the robot in a safe configuration;
 22        # the zero-configuration puts the robot aligned with the wall obstacle.
 23        default_cfg = np.zeros(6)
 24        default_cfg[1] = -1.308
 25        robot = pk.Robot.from_urdf(urdf, default_joint_cfg=default_cfg)
 26
 27    elif robot_name == "panda":
 28        urdf = load_robot_description("panda_description")
 29        target_link_name = "panda_hand"
 30        down_wxyz = np.array([0, 0, 1, 0])  # for panda!
 31        robot = pk.Robot.from_urdf(urdf)
 32
 33    else:
 34        raise ValueError(f"Invalid robot: {robot_name}")
 35
 36    robot_coll = pk.collision.RobotCollision.from_urdf(urdf)
 37
 38    # Define the trajectory problem:
 39    # - number of timesteps, timestep size
 40    timesteps, dt = 25, 0.02
 41    # - the start and end poses.
 42    start_pos, end_pos = np.array([0.5, -0.3, 0.2]), np.array([0.5, 0.3, 0.2])
 43
 44    # Define the obstacles:
 45    # - Ground
 46    ground_coll = pk.collision.HalfSpace.from_point_and_normal(
 47        np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0])
 48    )
 49    # - Wall
 50    wall_height = 0.4
 51    wall_width = 0.1
 52    wall_length = 0.4
 53    wall_intervals = np.arange(start=0.3, stop=wall_length + 0.3, step=0.05)
 54    translation = np.concatenate(
 55        [
 56            wall_intervals.reshape(-1, 1),
 57            np.full((wall_intervals.shape[0], 1), 0.0),
 58            np.full((wall_intervals.shape[0], 1), wall_height / 2),
 59        ],
 60        axis=1,
 61    )
 62    wall_coll = pk.collision.Capsule.from_radius_height(
 63        position=translation,
 64        radius=np.full((translation.shape[0], 1), wall_width / 2),
 65        height=np.full((translation.shape[0], 1), wall_height),
 66    )
 67    world_coll = [ground_coll, wall_coll]
 68
 69    traj = pks.solve_trajopt(
 70        robot,
 71        robot_coll,
 72        world_coll,
 73        target_link_name,
 74        start_pos,
 75        down_wxyz,
 76        end_pos,
 77        down_wxyz,
 78        timesteps,
 79        dt,
 80    )
 81    traj = np.array(traj)
 82
 83    # Visualize!
 84    server = viser.ViserServer()
 85    urdf_vis = ViserUrdf(server, urdf)
 86    server.scene.add_grid("/grid", width=2, height=2, cell_size=0.1)
 87    server.scene.add_mesh_trimesh(
 88        "wall_box",
 89        trimesh.creation.box(
 90            extents=(wall_length, wall_width, wall_height),
 91            transform=trimesh.transformations.translation_matrix(
 92                np.array([0.5, 0.0, wall_height / 2])
 93            ),
 94        ),
 95    )
 96    for name, pos in zip(["start", "end"], [start_pos, end_pos]):
 97        server.scene.add_frame(
 98            f"/{name}",
 99            position=pos,
100            wxyz=down_wxyz,
101            axes_length=0.05,
102            axes_radius=0.01,
103        )
104
105    slider = server.gui.add_slider(
106        "Timestep", min=0, max=timesteps - 1, step=1, initial_value=0
107    )
108    playing = server.gui.add_checkbox("Playing", initial_value=True)
109
110    while True:
111        if playing.value:
112            slider.value = (slider.value + 1) % timesteps
113
114        urdf_vis.update_cfg(traj[slider.value])
115        time.sleep(1.0 / 10.0)
116
117
if __name__ == "__main__":
    # Script entry point: tyro builds a command-line interface from main()'s
    # signature (its `robot_name` parameter) and then invokes it.
    tyro.cli(main)