|
|
import argparse |
|
|
import pathlib |
|
|
from pathlib import Path |
|
|
import threading |
|
|
from threading import Lock, Thread |
|
|
from typing import Dict |
|
|
|
|
|
import mujoco |
|
|
import mujoco.viewer |
|
|
import numpy as np |
|
|
try: |
|
|
import rclpy |
|
|
HAS_RCLPY = True |
|
|
except ImportError: |
|
|
HAS_RCLPY = False |
|
|
print("Warning: rclpy not found. Camera image publishing will be disabled.") |
|
|
from unitree_sdk2py.core.channel import ChannelFactoryInitialize |
|
|
import yaml |
|
|
import os |
|
|
from .image_publish_utils import ImagePublishProcess |
|
|
from .metric_utils import check_contact, check_height |
|
|
from .sim_utilts import get_subtree_body_names |
|
|
from .unitree_sdk2py_bridge import ElasticBand, UnitreeSdk2Bridge |
|
|
|
|
|
# Package root: two directory levels above this (symlink-resolved) module file.
GR00T_WBC_ROOT = Path(__file__).resolve().parents[1]
|
|
|
|
|
|
|
|
class DefaultEnv:
    """Base environment class that handles simulation environment setup and step.

    Loads the MuJoCo scene named by ``config["ROBOT_SCENE"]``, turns low-level
    commands received through a Unitree SDK bridge into joint torques
    (:meth:`sim_step`) or joint positions (:meth:`kinematics_step`), and can
    render offscreen camera images for an image-publishing subprocess.
    """

    def __init__(
        self,
        config: Dict[str, any],
        env_name: str = "default",
        camera_configs: Dict[str, any] = None,
        onscreen: bool = False,
        offscreen: bool = False,
    ):
        """Build the environment.

        Args:
            config: Simulator configuration. Keys read by this class include
                ``ROBOT_SCENE``, ``NUM_JOINTS``, ``NUM_HAND_JOINTS``,
                ``SIMULATE_DT``, ``motor_effort_limit_list``,
                ``ENABLE_ELASTIC_BAND``, ``ROBOT_TYPE``, ``FREE_BASE`` and the
                optional ``IMAGE_DT`` / ``MP_START_METHOD`` / ``verbose``.
            env_name: Name of this environment.
            camera_configs: Mapping of camera name -> ``{"height", "width"}``
                (optionally ``"params"``). Copied; the caller's dict is never
                modified.
            onscreen: Launch a passive interactive MuJoCo viewer.
            offscreen: Create offscreen renderers for ``camera_configs``.
        """
        # Work on a copy so adding "global_view" below never mutates the caller's dict.
        camera_configs = dict(camera_configs) if camera_configs else {}

        # The 29-DoF G1 scene always gets an extra "global_view" camera.
        if config["ROBOT_SCENE"] == "gr00t_wbc/control/robot_model/model_data/g1/scene_29dof.xml":
            camera_configs["global_view"] = {
                "height": 400,
                "width": 400,
            }

        self.config = config
        self.env_name = env_name
        self.num_body_dof = self.config["NUM_JOINTS"]
        self.num_hand_dof = self.config["NUM_HAND_JOINTS"]
        self.sim_dt = self.config["SIMULATE_DT"]
        self.obs = None
        # Sized for body joints plus both hands; sim_step addresses entries by (model joint index - 1).
        self.torques = np.zeros(self.num_body_dof + self.num_hand_dof * 2)
        self.torque_limit = np.array(self.config["motor_effort_limit_list"])
        self.camera_configs = camera_configs

        if len(camera_configs) > 0:
            print(f"✓ DefaultEnv initialized with {len(camera_configs)} camera(s): {list(camera_configs.keys())}")

        # Guards last_reward, which may be read from another thread via get_reward().
        self.reward_lock = Lock()
        self.unitree_bridge = None
        self.onscreen = onscreen

        # Defined up front so handle_keyboard_button() works when the band is disabled;
        # init_scene() replaces it when ENABLE_ELASTIC_BAND is set.
        self.elastic_band = None
        self.init_scene()
        self.last_reward = 0

        self.offscreen = offscreen
        # Stays empty unless offscreen rendering is enabled.
        self.renderers = {}
        if self.offscreen:
            self.init_renderers()
        self.image_dt = self.config.get("IMAGE_DT", 0.033333)
        self.image_publish_process = None

    def init_scene(self):
        """Initialize the default robot scene.

        Loads the model, sets the timestep, optionally creates the elastic band
        and the passive viewer, and sorts joints into body / left-hand /
        right-hand index arrays.
        """
        assets_root = Path(__file__).parent.parent
        self.mj_model = mujoco.MjModel.from_xml_path(str(assets_root / self.config["ROBOT_SCENE"]))
        self.mj_data = mujoco.MjData(self.mj_model)
        self.mj_model.opt.timestep = self.sim_dt
        self.torso_index = mujoco.mj_name2id(self.mj_model, mujoco.mjtObj.mjOBJ_BODY, "torso_link")
        self.root_body = "pelvis"

        key_callback = None
        if self.config["ENABLE_ELASTIC_BAND"]:
            self.elastic_band = ElasticBand()
            if "g1" in self.config["ROBOT_TYPE"]:
                attached_body = "pelvis" if self.config["enable_waist"] else "torso_link"
            elif "h1" in self.config["ROBOT_TYPE"]:
                attached_body = "torso_link"
            else:
                attached_body = "base_link"
            self.band_attached_link = self.mj_model.body(attached_body).id
            # The viewer forwards key presses to the elastic band.
            key_callback = self.elastic_band.MujuocoKeyCallback

        if self.onscreen:
            self.viewer = mujoco.viewer.launch_passive(
                self.mj_model,
                self.mj_data,
                key_callback=key_callback,
                show_left_ui=False,
                show_right_ui=False,
            )
        else:
            mujoco.mj_forward(self.mj_model, self.mj_data)
            self.viewer = None

        if self.viewer:
            # Initial free-camera placement.
            self.viewer.cam.azimuth = 120
            self.viewer.cam.elevation = -30
            self.viewer.cam.distance = 2.0
            self.viewer.cam.lookat = np.array([0, 0, 0.5])

        # Classify joints by name.
        body_parts = ("hip", "knee", "ankle", "waist", "shoulder", "elbow", "wrist")
        body_joint_index = []
        left_hand_index = []
        right_hand_index = []
        for i in range(self.mj_model.njnt):
            name = self.mj_model.joint(i).name
            if any(part_name in name for part_name in body_parts):
                body_joint_index.append(i)
            elif "left_hand" in name:
                left_hand_index.append(i)
            elif "right_hand" in name:
                right_hand_index.append(i)

        assert len(body_joint_index) == self.config["NUM_JOINTS"], \
            f"Expected {self.config['NUM_JOINTS']} body joints, got {len(body_joint_index)}"

        if self.config.get("NUM_HAND_JOINTS", 0) > 0:
            expected_hands = self.config["NUM_HAND_JOINTS"]
            if len(left_hand_index) != expected_hands or len(right_hand_index) != expected_hands:
                print(f"Warning: Expected {expected_hands} hand joints, got left={len(left_hand_index)}, right={len(right_hand_index)}")
                # NOTE(review): hand handling is not actually disabled here; later
                # indexing assumes these arrays match num_hand_dof — confirm.
                print("Continuing without hands...")

        # Integer dtype so an empty list still yields a valid (empty) index array;
        # np.array([]) would be float64, and indexing with it raises IndexError.
        self.body_joint_index = np.asarray(body_joint_index, dtype=int)
        self.left_hand_index = np.asarray(left_hand_index, dtype=int)
        self.right_hand_index = np.asarray(right_hand_index, dtype=int)

    def init_renderers(self):
        """Create one offscreen ``mujoco.Renderer`` per configured camera."""
        self.renderers = {}
        for camera_name, camera_config in self.camera_configs.items():
            self.renderers[camera_name] = mujoco.Renderer(
                self.mj_model, height=camera_config["height"], width=camera_config["width"]
            )

    def start_image_publish_subprocess(self, start_method: str = "spawn", camera_port: int = 5555):
        """Start image publishing subprocess using ZMQ.

        Args:
            start_method: Multiprocessing start method. It is used unless
                ``config["MP_START_METHOD"]`` overrides it.
            camera_port: ZMQ port passed to the publisher.
        """
        if len(self.camera_configs) == 0:
            print(
                "Warning: No camera configs provided, image publishing subprocess will not be started"
            )
            return

        # Config takes precedence; previously the argument was always discarded.
        start_method = self.config.get("MP_START_METHOD", start_method)
        self.image_publish_process = ImagePublishProcess(
            camera_configs=self.camera_configs,
            image_dt=self.image_dt,
            zmq_port=camera_port,
            start_method=start_method,
            verbose=self.config.get("verbose", False),
        )
        self.image_publish_process.start_process()
        print(f"✓ Started image publishing subprocess on ZMQ port {camera_port}")

    @staticmethod
    def _pd_torque(cmd, q, dq):
        """Return ``tau + kp * (q_des - q) + kd * (dq_des - dq)`` for one motor command."""
        return cmd.tau + cmd.kp * (cmd.q - q) + cmd.kd * (cmd.dq - dq)

    def compute_body_torques(self) -> np.ndarray:
        """Compute body torques based on the current robot state.

        Returns:
            Array of length ``num_body_dof``. It is all zeros when no bridge or
            no low-level command is available.
        """
        body_torques = np.zeros(self.num_body_dof)
        if self.unitree_bridge is not None and self.unitree_bridge.low_cmd:
            n_motor = self.unitree_bridge.num_body_motor
            for i in range(n_motor):
                cmd = self.unitree_bridge.low_cmd.motor_cmd[i]
                if self.unitree_bridge.use_sensor:
                    # Joint positions first, then velocities offset by n_motor.
                    q = self.mj_data.sensordata[i]
                    dq = self.mj_data.sensordata[i + n_motor]
                else:
                    # Assumes joint 0 is the 7-qpos / 6-qvel free base followed by 1-DoF joints — confirm.
                    q = self.mj_data.qpos[self.body_joint_index[i] + 7 - 1]
                    dq = self.mj_data.qvel[self.body_joint_index[i] + 6 - 1]
                body_torques[i] = self._pd_torque(cmd, q, dq)
        return body_torques

    def compute_hand_torques(self) -> np.ndarray:
        """Compute hand torques based on the current robot state.

        Returns:
            ``concatenate((left, right))``, each half of length ``num_hand_dof``.
        """
        left_hand_torques = np.zeros(self.num_hand_dof)
        right_hand_torques = np.zeros(self.num_hand_dof)
        if self.unitree_bridge is not None and self.unitree_bridge.low_cmd:
            for i in range(self.unitree_bridge.num_hand_motor):
                left_cmd = self.unitree_bridge.left_hand_cmd.motor_cmd[i]
                left_hand_torques[i] = self._pd_torque(
                    left_cmd,
                    self.mj_data.qpos[self.left_hand_index[i] + 7 - 1],
                    self.mj_data.qvel[self.left_hand_index[i] + 6 - 1],
                )
                right_cmd = self.unitree_bridge.right_hand_cmd.motor_cmd[i]
                right_hand_torques[i] = self._pd_torque(
                    right_cmd,
                    self.mj_data.qpos[self.right_hand_index[i] + 7 - 1],
                    self.mj_data.qvel[self.right_hand_index[i] + 6 - 1],
                )
        return np.concatenate((left_hand_torques, right_hand_torques))

    def compute_body_qpos(self) -> np.ndarray:
        """Compute body joint positions based on the current command (zeros if none)."""
        body_qpos = np.zeros(self.num_body_dof)
        if self.unitree_bridge is not None and self.unitree_bridge.low_cmd:
            for i in range(self.unitree_bridge.num_body_motor):
                body_qpos[i] = self.unitree_bridge.low_cmd.motor_cmd[i].q
        return body_qpos

    def compute_hand_qpos(self) -> np.ndarray:
        """Compute hand joint positions based on the current command.

        Returns:
            Length ``2 * num_hand_dof``: left-hand targets, then right-hand targets.
        """
        hand_qpos = np.zeros(self.num_hand_dof * 2)
        if self.unitree_bridge is not None and self.unitree_bridge.low_cmd:
            for i in range(self.unitree_bridge.num_hand_motor):
                hand_qpos[i] = self.unitree_bridge.left_hand_cmd.motor_cmd[i].q
                hand_qpos[i + self.num_hand_dof] = self.unitree_bridge.right_hand_cmd.motor_cmd[i].q
        return hand_qpos

    def prepare_obs(self) -> Dict[str, any]:
        """Prepare observation dictionary from the current robot state.

        Note: plain slices (e.g. ``qpos[:7]``) are numpy views into ``mj_data``,
        so those entries change as the simulation advances.
        """
        obs = {}
        obs["floating_base_pose"] = self.mj_data.qpos[:7]
        obs["floating_base_vel"] = self.mj_data.qvel[:6]
        obs["floating_base_acc"] = self.mj_data.qacc[:6]
        obs["secondary_imu_quat"] = self.mj_data.xquat[self.torso_index]
        obs["secondary_imu_vel"] = self.mj_data.cvel[self.torso_index]
        obs["body_q"] = self.mj_data.qpos[self.body_joint_index + 7 - 1]
        obs["body_dq"] = self.mj_data.qvel[self.body_joint_index + 6 - 1]
        obs["body_ddq"] = self.mj_data.qacc[self.body_joint_index + 6 - 1]
        obs["body_tau_est"] = self.mj_data.actuator_force[self.body_joint_index - 1]
        if self.num_hand_dof > 0:
            obs["left_hand_q"] = self.mj_data.qpos[self.left_hand_index + 7 - 1]
            obs["left_hand_dq"] = self.mj_data.qvel[self.left_hand_index + 6 - 1]
            obs["left_hand_ddq"] = self.mj_data.qacc[self.left_hand_index + 6 - 1]
            obs["left_hand_tau_est"] = self.mj_data.actuator_force[self.left_hand_index - 1]
            obs["right_hand_q"] = self.mj_data.qpos[self.right_hand_index + 7 - 1]
            obs["right_hand_dq"] = self.mj_data.qvel[self.right_hand_index + 6 - 1]
            obs["right_hand_ddq"] = self.mj_data.qacc[self.right_hand_index + 6 - 1]
            obs["right_hand_tau_est"] = self.mj_data.actuator_force[self.right_hand_index - 1]
        obs["time"] = self.mj_data.time
        return obs

    def _publish_state(self, obs):
        """Publish ``obs`` (and the joystick, if any) through the bridge, when one is set."""
        if self.unitree_bridge is None:
            return
        self.unitree_bridge.PublishLowState(obs)
        if self.unitree_bridge.joystick:
            self.unitree_bridge.PublishWirelessController()

    def _apply_elastic_band(self):
        """Apply or clear the elastic-band wrench on ``band_attached_link``.

        This is a no-op unless ``config["ENABLE_ELASTIC_BAND"]`` is set.
        """
        if not self.config["ENABLE_ELASTIC_BAND"]:
            return
        if not self.elastic_band.enable:
            # Band released: remove any previously applied wrench.
            self.mj_data.xfrc_applied[self.band_attached_link] = np.zeros(6)
            return

        pose = np.concatenate(
            [
                self.mj_data.xpos[self.band_attached_link],
                self.mj_data.xquat[self.band_attached_link],
                np.zeros(6),
            ]
        )
        # mj_objectVelocity writes the 6D velocity as [rot(3), lin(3)] into pose[7:13]
        # (flg_local=0: world-aligned orientation).
        mujoco.mj_objectVelocity(
            self.mj_model,
            self.mj_data,
            mujoco.mjtObj.mjOBJ_BODY,
            self.band_attached_link,
            pose[7:13],
            0,
        )
        # Reorder to [lin(3), rot(3)]; .copy() keeps the old rot values before they are overwritten.
        pose[7:10], pose[10:13] = pose[10:13], pose[7:10].copy()
        self.mj_data.xfrc_applied[self.band_attached_link] = self.elastic_band.Advance(pose)

    def sim_step(self):
        """Advance the physics one step using torques from the latest SDK commands."""
        self.obs = self.prepare_obs()
        self._publish_state(self.obs)
        self._apply_elastic_band()

        body_torques = self.compute_body_torques()
        hand_torques = self.compute_hand_torques()
        self.torques[self.body_joint_index - 1] = body_torques
        if self.num_hand_dof > 0:
            self.torques[self.left_hand_index - 1] = hand_torques[: self.num_hand_dof]
            self.torques[self.right_hand_index - 1] = hand_torques[self.num_hand_dof :]

        self.torques = np.clip(self.torques, -self.torque_limit, self.torque_limit)

        if self.config["FREE_BASE"]:
            # FREE_BASE prepends six zero controls (presumably base actuators in that scene — confirm).
            self.mj_data.ctrl = np.concatenate((np.zeros(6), self.torques))
        else:
            self.mj_data.ctrl = self.torques
        mujoco.mj_step(self.mj_model, self.mj_data)

    def kinematics_step(self):
        """
        Run kinematics only: compute the qpos of the robot and directly set the qpos.
        For debugging purposes.
        """
        self._publish_state(self.prepare_obs())
        self._apply_elastic_band()

        body_qpos = self.compute_body_qpos()
        hand_qpos = self.compute_hand_qpos()

        self.mj_data.qpos[self.body_joint_index + 7 - 1] = body_qpos
        if self.num_hand_dof > 0:
            self.mj_data.qpos[self.left_hand_index + 7 - 1] = hand_qpos[: self.num_hand_dof]
            self.mj_data.qpos[self.right_hand_index + 7 - 1] = hand_qpos[self.num_hand_dof :]

        mujoco.mj_kinematics(self.mj_model, self.mj_data)
        mujoco.mj_comPos(self.mj_model, self.mj_data)

    def apply_perturbation(self, key):
        """Apply perturbation to the robot.

        Adds a 1 m/s velocity kick, given in the base frame (up/down = +/-x,
        left/right = +/-y), to the world-frame x/y velocity of the base.
        """
        kicks = {
            "up": (1.0, 0.0),
            "down": (-1.0, 0.0),
            "left": (0.0, 1.0),
            "right": (0.0, -1.0),
        }
        perturbation_x_body, perturbation_y_body = kicks.get(key, (0.0, 0.0))

        vel_body = np.array([perturbation_x_body, perturbation_y_body, 0.0])
        vel_world = np.zeros(3)
        base_quat = self.mj_data.qpos[3:7]
        # Rotate the body-frame kick into the world frame.
        mujoco.mju_rotVecQuat(vel_world, vel_body, base_quat)

        self.mj_data.qvel[0] += vel_world[0]
        self.mj_data.qvel[1] += vel_world[1]
        mujoco.mj_forward(self.mj_model, self.mj_data)

    def update_viewer(self):
        """Sync the passive viewer with the current state (no-op without a viewer)."""
        if self.viewer is not None:
            self.viewer.sync()

    def update_viewer_camera(self):
        """Toggle the viewer camera between tracking and free modes."""
        if self.viewer is not None:
            if self.viewer.cam.type == mujoco.mjtCamera.mjCAMERA_TRACKING:
                self.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_FREE
            else:
                self.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_TRACKING

    def update_reward(self):
        """Calculate reward. Should be implemented by subclasses (base: always 0)."""
        with self.reward_lock:
            self.last_reward = 0

    def get_reward(self):
        """Thread-safe way to get the last calculated reward."""
        with self.reward_lock:
            return self.last_reward

    def set_unitree_bridge(self, unitree_bridge):
        """Set the unitree bridge from the simulator"""
        self.unitree_bridge = unitree_bridge

    def get_privileged_obs(self):
        """Get privileged observation. Should be implemented by subclasses."""
        return {}

    def update_render_caches(self):
        """Render every configured camera and push the images to the publisher.

        Returns:
            Dict mapping ``"<camera>_image"`` to the rendered image. It is empty
            when offscreen rendering is disabled (no renderers exist).
        """
        renderers = getattr(self, "renderers", None)
        if not renderers:
            # Offscreen rendering disabled; the caller invokes this unconditionally.
            return {}

        render_caches = {}
        for camera_name, camera_config in self.camera_configs.items():
            renderer = renderers[camera_name]
            camera = camera_config["params"] if "params" in camera_config else camera_name
            renderer.update_scene(self.mj_data, camera=camera)
            render_caches[camera_name + "_image"] = renderer.render()

        if self.image_publish_process is not None:
            self.image_publish_process.update_shared_memory(render_caches)

        return render_caches

    def handle_keyboard_button(self, key):
        """Dispatch a key press: elastic band, reset, camera toggle, perturbation."""
        if self.elastic_band is not None:
            self.elastic_band.handle_keyboard_button(key)

        if key == "backspace":
            self.reset()
        if key == "v":
            self.update_viewer_camera()
        if key in ["up", "down", "left", "right"]:
            self.apply_perturbation(key)

    def check_fall(self):
        """Check if the robot has fallen (base height < 0.2) and reset if so."""
        self.fall = False
        if self.mj_data.qpos[2] < 0.2:
            self.fall = True
            print(f"Warning: Robot has fallen, height: {self.mj_data.qpos[2]:.3f} m")

        if self.fall:
            self.reset()

    def check_self_collision(self):
        """Check for self-collision among bodies in the subtree of ``root_body``."""
        robot_bodies = get_subtree_body_names(self.mj_model, self.mj_model.body(self.root_body).id)
        self_collision, contact_bodies = check_contact(
            self.mj_model, self.mj_data, robot_bodies, robot_bodies, return_all_contact_bodies=True
        )
        if self_collision:
            print(f"Warning: Self-collision detected: {contact_bodies}")
        return self_collision

    def reset(self):
        """Reset ``mj_data`` to the model defaults."""
        mujoco.mj_resetData(self.mj_model, self.mj_data)
|
|
|
|
|
|
|
|
class CubeEnv(DefaultEnv):
    """Pick-and-place environment whose scene contains a single cube."""

    def __init__(
        self,
        config: Dict[str, any],
        onscreen: bool = False,
        offscreen: bool = False,
    ):
        # Shallow-copy so the caller's config keeps its own ROBOT_SCENE entry.
        scene_config = config.copy()
        scene_config["ROBOT_SCENE"] = "gr00t_wbc/control/robot_model/model_data/g1/pnp_cube_43dof.xml"
        # No extra cameras for this task.
        super().__init__(scene_config, "cube", {}, onscreen, offscreen)

    def update_reward(self):
        """Store ``right-hand/cube contact & cube lifted`` as the latest reward.

        The height test is ``check_height(..., "cube", 0.85, 2.0)``; presumably
        that is a [min, max] window in metres — confirm in metric_utils.
        """
        fingertip_links = [
            "right_hand_thumb_2_link",
            "right_hand_middle_1_link",
            "right_hand_index_1_link",
        ]
        touching_cube = check_contact(self.mj_model, self.mj_data, fingertip_links, "cube_body")
        cube_is_up = check_height(self.mj_model, self.mj_data, "cube", 0.85, 2.0)

        with self.reward_lock:
            self.last_reward = touching_cube & cube_is_up
|
|
|
|
|
|
|
|
class BoxEnv(DefaultEnv):
    """Environment with a box object for manipulation tasks"""

    def __init__(
        self,
        config: Dict[str, any],
        onscreen: bool = False,
        offscreen: bool = False,
    ):
        # Copy so the caller's config keeps its own ROBOT_SCENE entry.
        config = config.copy()
        config["ROBOT_SCENE"] = "gr00t_wbc/control/robot_model/model_data/g1/lift_box_43dof.xml"
        super().__init__(config, "box", {}, onscreen, offscreen)

    def _reward_terms(self):
        """Return ``(both_hands_touch_box, box_lifted)`` for the current state.

        Contact requires both the left- and the right-hand link groups to touch
        ``box_body``. The height test is ``check_height(..., "box", 0.92, 2.0)``;
        presumably that is a [min, max] window in metres — confirm in metric_utils.
        """
        left_hand_body = [
            "left_hand_thumb_2_link",
            "left_hand_middle_1_link",
            "left_hand_index_1_link",
        ]
        right_hand_body = [
            "right_hand_thumb_2_link",
            "right_hand_middle_1_link",
            "right_hand_index_1_link",
        ]
        gripper_box_contact = check_contact(self.mj_model, self.mj_data, left_hand_body, "box_body")
        gripper_box_contact &= check_contact(
            self.mj_model, self.mj_data, right_hand_body, "box_body"
        )
        box_lifted = check_height(self.mj_model, self.mj_data, "box", 0.92, 2.0)
        return gripper_box_contact, box_lifted

    def update_reward(self):
        """Calculate reward based on both hands' contact with the box and the box height.

        This overrides ``DefaultEnv.update_reward``, the hook the simulation loop
        calls. Previously the logic lived only in ``reward()``, so the loop
        always recorded the base-class reward of 0.
        """
        gripper_box_contact, box_lifted = self._reward_terms()
        with self.reward_lock:
            self.last_reward = gripper_box_contact & box_lifted

    def reward(self):
        """Compute, log and return the reward. Kept for backward compatibility."""
        gripper_box_contact, box_lifted = self._reward_terms()
        print("gripper_box_contact: ", gripper_box_contact, "box_lifted: ", box_lifted)
        with self.reward_lock:
            self.last_reward = gripper_box_contact & box_lifted
            return self.last_reward
|
|
|
|
|
|
|
|
class BottleEnv(DefaultEnv):
    """Manipulation environment built around a bottle (cylinder) object.

    Adds a 400x400 ``egoview`` camera. When a viewer is open, it is locked to
    that camera.
    """

    def __init__(
        self,
        config: Dict[str, any],
        onscreen: bool = False,
        offscreen: bool = False,
    ):
        # Shallow-copy so the caller's config keeps its own ROBOT_SCENE entry.
        scene_config = config.copy()
        scene_config["ROBOT_SCENE"] = "gr00t_wbc/control/robot_model/model_data/g1/pnp_bottle_43dof.xml"
        ego_cameras = {"egoview": {"height": 400, "width": 400}}
        super().__init__(scene_config, "cylinder", ego_cameras, onscreen, offscreen)

        # Handles used by get_privileged_obs().
        self.bottle_body = self.mj_model.body("bottle_body")
        self.bottle_geom = self.mj_model.geom("bottle")

        viewer = self.viewer
        if viewer is not None:
            viewer.cam.type = mujoco.mjtCamera.mjCAMERA_FIXED
            viewer.cam.fixedcamid = self.mj_model.camera("egoview").id

    def update_reward(self):
        """Reward is not implemented for this task; ``last_reward`` is left untouched."""
        return None

    def get_privileged_obs(self):
        """Return the bottle body's world position and orientation quaternion."""
        body_id = self.bottle_body.id
        return {
            "bottle_pos": self.mj_data.xpos[body_id],
            "bottle_quat": self.mj_data.xquat[body_id],
        }
|
|
|
|
|
|
|
|
class BaseSimulator:
    """Base simulator class that handles initialization and running of simulations"""

    def __init__(self, config: Dict[str, any], env_name: str = "default", **kwargs):
        """Create the ROS node (if rclpy is available), the environment and the SDK bridge.

        Args:
            config: Simulator configuration (``SIMULATE_DT``, ``DOMAIN_ID``,
                optional ``INTERFACE``, ``REWARD_DT``, ``IMAGE_DT``,
                ``VIEWER_DT``, ``USE_JOYSTICK``, ...).
            env_name: One of ``default``, ``pnp_cube``, ``lift_box``, ``pnp_bottle``.
            **kwargs: Forwarded to the environment constructor (e.g. ``onscreen``).

        Raises:
            ValueError: If ``env_name`` is not recognised.
        """
        self.config = config
        self.env_name = env_name

        if HAS_RCLPY:
            if not rclpy.ok():
                rclpy.init()
                self._create_spinning_node()
            else:
                # rclpy was initialised elsewhere: reuse the first node registered with
                # the global executor, or create our own if it has none.
                nodes = rclpy.get_global_executor().get_nodes()
                if nodes:
                    self.thread = None
                    self.node = nodes[0]
                else:
                    self._create_spinning_node()
        else:
            self.node = None
            self.thread = None

        self.sim_dt = self.config["SIMULATE_DT"]
        self.reward_dt = self.config.get("REWARD_DT", 0.02)
        self.image_dt = self.config.get("IMAGE_DT", 0.033333)
        self.viewer_dt = self.config.get("VIEWER_DT", 0.02)

        if env_name == "default":
            self.sim_env = DefaultEnv(config, env_name, **kwargs)
        elif env_name == "pnp_cube":
            self.sim_env = CubeEnv(config, **kwargs)
        elif env_name == "lift_box":
            self.sim_env = BoxEnv(config, **kwargs)
        elif env_name == "pnp_bottle":
            self.sim_env = BottleEnv(config, **kwargs)
        else:
            raise ValueError(f"Invalid environment name: {env_name}")

        # Best effort: the factory may already have been initialised (e.g. by the
        # __main__ entry point), so failures are reported rather than raised.
        try:
            if self.config.get("INTERFACE", None):
                ChannelFactoryInitialize(self.config["DOMAIN_ID"], self.config["INTERFACE"])
            else:
                ChannelFactoryInitialize(self.config["DOMAIN_ID"])
        except Exception as e:
            print(f"Note: Channel factory initialization attempt: {e}")

        self.init_unitree_bridge()
        self.sim_env.set_unitree_bridge(self.unitree_bridge)

        self.init_subscriber()
        self.init_publisher()

        self.sim_thread = None

    def _create_spinning_node(self):
        """Create the ``sim_mujoco`` node and spin it on a daemon thread."""
        self.node = rclpy.create_node("sim_mujoco")
        self.thread = threading.Thread(target=rclpy.spin, args=(self.node,), daemon=True)
        self.thread.start()

    @staticmethod
    def _period_in_steps(period, sim_dt):
        """Convert a period in seconds into a whole number of sim steps, at least 1.

        The clamp keeps ``sim_cnt % steps`` valid when ``period < sim_dt``
        (previously a modulo by zero).
        """
        return max(1, int(period / sim_dt))

    def start_as_thread(self):
        """Run :meth:`start` on a new (non-daemon) thread."""
        self.sim_thread = Thread(target=self.start)
        self.sim_thread.start()

    def start_image_publish_subprocess(self, start_method: str = "spawn", camera_port: int = 5555):
        """Start the image publish subprocess"""
        self.sim_env.start_image_publish_subprocess(start_method, camera_port)

    def init_subscriber(self):
        """Initialize subscribers. Can be overridden by subclasses."""
        pass

    def init_publisher(self):
        """Initialize publishers. Can be overridden by subclasses."""
        pass

    def init_unitree_bridge(self):
        """Initialize the unitree SDK bridge (and joystick if ``USE_JOYSTICK``)."""
        self.unitree_bridge = UnitreeSdk2Bridge(self.config)
        if self.config["USE_JOYSTICK"]:
            self.unitree_bridge.SetupJoystick(
                device_id=self.config["JOYSTICK_DEVICE"], js_type=self.config["JOYSTICK_TYPE"]
            )

    def start(self):
        """Main simulation loop.

        Steps the environment, paced to wall-clock time with ``time.sleep``.
        Each of viewer sync, reward update and render-cache refresh runs at its
        own period. The loop runs while the viewer is open, or indefinitely
        without a viewer. :meth:`close` is called after the loop ends or after
        a KeyboardInterrupt / Exception is caught.
        """
        import time

        sim_cnt = 0
        last_time = time.time()

        print(f"Starting simulation loop. Viewer: {self.sim_env.viewer is not None}")

        try:
            # Inside the try so a bad SIMULATE_DT still reaches close().
            viewer_every = self._period_in_steps(self.viewer_dt, self.sim_dt)
            reward_every = self._period_in_steps(self.reward_dt, self.sim_dt)
            image_every = self._period_in_steps(self.image_dt, self.sim_dt)

            while (
                self.sim_env.viewer and self.sim_env.viewer.is_running()
            ) or self.sim_env.viewer is None:
                self.sim_env.sim_step()

                if sim_cnt % viewer_every == 0:
                    self.sim_env.update_viewer()

                if sim_cnt % reward_every == 0:
                    self.sim_env.update_reward()

                if sim_cnt % image_every == 0:
                    self.sim_env.update_render_caches()

                # Sleep off the rest of this step to approximate real time.
                elapsed = time.time() - last_time
                sleep_time = max(0, self.sim_dt - elapsed)
                if sleep_time > 0:
                    time.sleep(sleep_time)
                last_time = time.time()

                sim_cnt += 1

            print(f"Loop exited. Viewer running: {self.sim_env.viewer.is_running() if self.sim_env.viewer else 'No viewer'}")
        except KeyboardInterrupt:
            print("Keyboard interrupt received")
        except Exception as e:
            print(f"Exception in simulation loop: {e}")
            import traceback

            traceback.print_exc()
        self.close()

    def __del__(self):
        """Clean up resources when simulator is deleted"""
        self.close()

    def reset(self):
        """Reset the simulation. Can be overridden by subclasses."""
        self.sim_env.reset()

    def close(self):
        """Close the simulation. Can be overridden by subclasses.

        This is safe to call more than once, and safe on a partially
        constructed instance (e.g. when ``__init__`` raised before ``sim_env``
        existed).
        """
        try:
            sim_env = getattr(self, "sim_env", None)
            viewer = getattr(sim_env, "viewer", None)
            if viewer is not None:
                viewer.close()

            if HAS_RCLPY and rclpy.ok():
                rclpy.shutdown()
        except Exception as e:
            print(f"Warning during close: {e}")

    def get_privileged_obs(self):
        """Return the environment's privileged observation dict."""
        return self.sim_env.get_privileged_obs()

    def handle_keyboard_button(self, key):
        """Forward a key press to the environment (only for the ``default`` env)."""
        if self.env_name == "default":
            self.sim_env.handle_keyboard_button(key)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Robot")
    cli.add_argument(
        "--config",
        type=str,
        default="./gr00t_wbc/control/main/teleop/configs/g1_29dof_gear_wbc.yaml",
        help="config file",
    )
    cli_args = cli.parse_args()

    # Load the simulator configuration from the YAML file given on the command line.
    with open(cli_args.config, "r") as config_file:
        sim_config = yaml.load(config_file, Loader=yaml.FullLoader)

    # Initialise the SDK channel factory; pass a network interface only when one is configured.
    factory_args = [sim_config["DOMAIN_ID"]]
    if sim_config.get("INTERFACE", None):
        factory_args.append(sim_config["INTERFACE"])
    ChannelFactoryInitialize(*factory_args)

    sim = BaseSimulator(sim_config)
    sim.start_as_thread()
|
|
|