Trajectory Optimization

Basic Trajectory Optimization using PyRoKi.

The robot moves its end effector over a wall while avoiding collisions with the world.

All examples can be run by first cloning the PyRoKi repository, which includes the `pyroki_snippets` implementation details.

  1import time
  2from typing import Literal
  3
  4import numpy as np
  5import pyroki as pk
  6import trimesh
  7import tyro
  8import viser
  9from viser.extras import ViserUrdf
 10from robot_descriptions.loaders.yourdfpy import load_robot_description
 11
 12import pyroki_snippets as pks
 13
 14
 15def main(robot_name: Literal["ur5", "panda"] = "panda"):
 16    if robot_name == "ur5":
 17        urdf = load_robot_description("ur5_description")
 18        down_wxyz = np.array([0.707, 0, 0.707, 0])
 19        target_link_name = "ee_link"
 20
 21        # For UR5 it's important to initialize the robot in a safe configuration;
 22        # the zero-configuration puts the robot aligned with the wall obstacle.
 23        default_cfg = np.zeros(6)
 24        default_cfg[1] = -1.308
 25        robot = pk.Robot.from_urdf(urdf, default_joint_cfg=default_cfg)
 26
 27    elif robot_name == "panda":
 28        urdf = load_robot_description("panda_description")
 29        target_link_name = "panda_hand"
 30        down_wxyz = np.array([0, 0, 1, 0])  # for panda!
 31        robot = pk.Robot.from_urdf(urdf)
 32
 33    else:
 34        raise ValueError(f"Invalid robot: {robot_name}")
 35
 36    robot_coll = pk.collision.RobotCollision.from_urdf(urdf)
 37
 38    # Define the trajectory problem:
 39    # - number of timesteps, timestep size
 40    timesteps, dt = 50, 0.02
 41    # - the start and end poses.
 42    start_pos, end_pos = np.array([0.5, -0.3, 0.2]), np.array([0.5, 0.3, 0.2])
 43
 44    # Define the obstacles:
 45    # - Ground
 46    # ground_coll = pk.collision.HalfSpace.from_point_and_normal(
 47    #     np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0])
 48    # )
 49    # - Wall (using Box collision geometry)
 50    wall_height = 0.4
 51    wall_width = 0.1
 52    wall_length = 0.4
 53    wall_coll = pk.collision.Box.from_extent(
 54        extent=np.array([wall_length, wall_width, wall_height]),
 55        position=np.array([0.5, 0.0, wall_height / 2]),
 56    )
 57    # world_coll = [ground_coll, wall_coll]
 58
 59    # TODO: Constraints don't work with ground collision at the moment,
 60    # because the robot is in collision already with it (which destabilizes things).
 61    # We will fix this when we can have better robot collision geometries / handling in the future.
 62    world_coll = [wall_coll]
 63
 64    traj = pks.solve_trajopt(
 65        robot,
 66        robot_coll,
 67        world_coll,
 68        target_link_name,
 69        start_pos,
 70        down_wxyz,
 71        end_pos,
 72        down_wxyz,
 73        timesteps,
 74        dt,
 75    )
 76    traj = np.array(traj)
 77
 78    # Visualize!
 79    server = viser.ViserServer()
 80    urdf_vis = ViserUrdf(server, urdf)
 81    server.scene.add_grid("/grid", width=2, height=2, cell_size=0.1)
 82    server.scene.add_mesh_trimesh(
 83        "wall_box",
 84        trimesh.creation.box(
 85            extents=(wall_length, wall_width, wall_height),
 86            transform=trimesh.transformations.translation_matrix(
 87                np.array([0.5, 0.0, wall_height / 2])
 88            ),
 89        ),
 90    )
 91    for name, pos in zip(["start", "end"], [start_pos, end_pos]):
 92        server.scene.add_frame(
 93            f"/{name}",
 94            position=pos,
 95            wxyz=down_wxyz,
 96            axes_length=0.05,
 97            axes_radius=0.01,
 98        )
 99
100    slider = server.gui.add_slider(
101        "Timestep", min=0, max=timesteps - 1, step=1, initial_value=0
102    )
103    playing = server.gui.add_checkbox("Playing", initial_value=True)
104
105    while True:
106        if playing.value:
107            slider.value = (slider.value + 1) % timesteps
108
109        urdf_vis.update_cfg(traj[slider.value])
110        time.sleep(1.0 / 10.0)
111
112
113if __name__ == "__main__":
114    tyro.cli(main)