# unitree-g1-mujoco / sim /base_sim.py
# nepyope's picture
# Update sim/base_sim.py
# 638fd14 verified
import argparse
import os
import pathlib
import threading
from pathlib import Path
from threading import Lock, Thread
from typing import Any, Dict, Optional

import mujoco
import mujoco.viewer
import numpy as np

try:
    import rclpy

    HAS_RCLPY = True
except ImportError:
    HAS_RCLPY = False
    print("Warning: rclpy not found. Camera image publishing will be disabled.")

from unitree_sdk2py.core.channel import ChannelFactoryInitialize
import yaml

from .image_publish_utils import ImagePublishProcess
from .metric_utils import check_contact, check_height
from .sim_utilts import get_subtree_body_names
from .unitree_sdk2py_bridge import ElasticBand, UnitreeSdk2Bridge
# Directory two levels above this file (parent of the package directory).
# NOTE(review): no code visible in this file reads this constant; init_scene
# recomputes the same location itself - confirm whether other modules import it.
GR00T_WBC_ROOT = Path(__file__).resolve().parent.parent # Points to mujoco_sim_g1/
class DefaultEnv:
    """Base MuJoCo environment: scene setup, command handling and stepping.

    The environment loads the scene named by ``config["ROBOT_SCENE"]``,
    optionally opens a passive viewer and offscreen renderers, and exchanges
    state/commands with a Unitree SDK bridge attached later through
    :meth:`set_unitree_bridge`.

    Indexing convention used throughout: a joint with MuJoCo joint index ``j``
    is read at ``qpos[j + 7 - 1]`` / ``qvel[j + 6 - 1]`` and mapped to
    actuator ``j - 1``.
    NOTE(review): this presumably assumes joint 0 is the floating-base free
    joint (7 qpos / 6 qvel entries) followed only by 1-DoF joints - confirm
    against the scene XML.
    """

    # Scene for which the extra "global_view" camera is registered.
    _GLOBAL_VIEW_SCENE = "gr00t_wbc/control/robot_model/model_data/g1/scene_29dof.xml"
    # Substrings that mark a joint as a body (non-hand) joint.
    _BODY_JOINT_KEYWORDS = ("hip", "knee", "ankle", "waist", "shoulder", "elbow", "wrist")

    def __init__(
        self,
        config: Dict[str, Any],
        env_name: str = "default",
        camera_configs: Optional[Dict[str, Any]] = None,
        onscreen: bool = False,
        offscreen: bool = False,
    ):
        """Build the environment.

        Args:
            config: Simulation settings. Keys read here: ``ROBOT_SCENE``,
                ``NUM_JOINTS``, ``NUM_HAND_JOINTS``, ``SIMULATE_DT``,
                ``motor_effort_limit_list`` and optionally ``IMAGE_DT``.
            env_name: Label stored on the instance.
            camera_configs: Mapping ``camera name -> {"height", "width"[, "params"]}``.
                The mapping is copied, so the caller's dict is never modified.
            onscreen: Open an interactive passive viewer.
            offscreen: Create one offscreen renderer per configured camera.
        """
        # Work on a private copy: "global_view" may be added below and must not
        # leak into a dict owned by the caller.
        camera_configs = dict(camera_configs) if camera_configs else {}
        # global_view is only set up for this specific scene for now.
        if config["ROBOT_SCENE"] == self._GLOBAL_VIEW_SCENE:
            camera_configs["global_view"] = {
                "height": 400,
                "width": 400,
            }

        self.config = config
        self.env_name = env_name
        self.num_body_dof = self.config["NUM_JOINTS"]
        self.num_hand_dof = self.config["NUM_HAND_JOINTS"]
        self.sim_dt = self.config["SIMULATE_DT"]
        self.obs = None
        self.torques = np.zeros(self.num_body_dof + self.num_hand_dof * 2)
        self.torque_limit = np.array(self.config["motor_effort_limit_list"])
        self.camera_configs = camera_configs
        # Debug: print camera config
        if camera_configs:
            print(f"✓ DefaultEnv initialized with {len(camera_configs)} camera(s): {list(camera_configs.keys())}")

        # Guards last_reward, which is written by update_reward and read by get_reward.
        self.reward_lock = Lock()
        # Unitree bridge will be initialized by the simulator
        self.unitree_bridge = None
        # Store display mode
        self.onscreen = onscreen

        # Defaults for attributes that are only conditionally created later, so
        # that handle_keyboard_button / update_render_caches can rely on them.
        self.elastic_band = None
        self.viewer = None
        self.renderers = {}

        # Initialize scene (may be overridden in subclasses)
        self.init_scene()
        self.last_reward = 0

        # Setup offscreen rendering if needed
        self.offscreen = offscreen
        if self.offscreen:
            self.init_renderers()
        self.image_dt = self.config.get("IMAGE_DT", 0.033333)
        # Image publishing subprocess (initialized separately)
        self.image_publish_process = None

    def _band_link_name(self) -> str:
        """Return the name of the body the elastic band is attached to."""
        robot_type = self.config["ROBOT_TYPE"]
        if "g1" in robot_type:
            return "pelvis" if self.config["enable_waist"] else "torso_link"
        if "h1" in robot_type:
            return "torso_link"
        return "base_link"

    def init_scene(self):
        """Load the MuJoCo model, open the viewer if requested and index the joints."""
        assets_root = Path(__file__).parent.parent
        self.mj_model = mujoco.MjModel.from_xml_path(
            str(assets_root / self.config["ROBOT_SCENE"])
        )
        self.mj_data = mujoco.MjData(self.mj_model)
        self.mj_model.opt.timestep = self.sim_dt
        self.torso_index = mujoco.mj_name2id(self.mj_model, mujoco.mjtObj.mjOBJ_BODY, "torso_link")
        self.root_body = "pelvis"

        viewer_kwargs = {"show_left_ui": False, "show_right_ui": False}
        if self.config["ENABLE_ELASTIC_BAND"]:
            self.elastic_band = ElasticBand()
            self.band_attached_link = self.mj_model.body(self._band_link_name()).id
            # The band reacts to viewer key presses.
            viewer_kwargs["key_callback"] = self.elastic_band.MujuocoKeyCallback

        if self.onscreen:
            self.viewer = mujoco.viewer.launch_passive(
                self.mj_model, self.mj_data, **viewer_kwargs
            )
        else:
            mujoco.mj_forward(self.mj_model, self.mj_data)
            self.viewer = None

        if self.viewer:
            # viewer camera
            self.viewer.cam.azimuth = 120  # Horizontal rotation in degrees
            self.viewer.cam.elevation = -30  # Vertical tilt in degrees
            self.viewer.cam.distance = 2.0  # Distance from camera to target
            self.viewer.cam.lookat = np.array([0, 0, 0.5])  # Point the camera is looking at

        # Note that the actuator order is the same as the joint order in the mujoco model.
        body_joints, left_hand, right_hand = [], [], []
        for joint_id in range(self.mj_model.njnt):
            name = self.mj_model.joint(joint_id).name
            if any(keyword in name for keyword in self._BODY_JOINT_KEYWORDS):
                body_joints.append(joint_id)
            elif "left_hand" in name:
                left_hand.append(joint_id)
            elif "right_hand" in name:
                right_hand.append(joint_id)

        assert len(body_joints) == self.config["NUM_JOINTS"], \
            f"Expected {self.config['NUM_JOINTS']} body joints, got {len(body_joints)}"

        # Hand joints are optional (some models don't have hands)
        expected_hands = self.config.get("NUM_HAND_JOINTS", 0)
        if expected_hands > 0 and (
            len(left_hand) != expected_hands or len(right_hand) != expected_hands
        ):
            print(f"Warning: Expected {expected_hands} hand joints, got left={len(left_hand)}, right={len(right_hand)}")
            print("Continuing without hands...")
            # Actually drop the hands: otherwise every later hand read/write
            # would index with arrays of the wrong length.
            left_hand, right_hand = [], []
            self.num_hand_dof = 0
            self.torques = np.zeros(self.num_body_dof)

        # dtype=int keeps empty arrays usable as indices (np.array([]) is float).
        self.body_joint_index = np.array(body_joints, dtype=int)
        self.left_hand_index = np.array(left_hand, dtype=int)
        self.right_hand_index = np.array(right_hand, dtype=int)

    def init_renderers(self):
        """Create one offscreen ``mujoco.Renderer`` per configured camera."""
        self.renderers = {
            camera_name: mujoco.Renderer(
                self.mj_model, height=camera_config["height"], width=camera_config["width"]
            )
            for camera_name, camera_config in self.camera_configs.items()
        }

    def start_image_publish_subprocess(self, start_method: str = "spawn", camera_port: int = 5555):
        """Start the image publishing subprocess (ZMQ).

        Args:
            start_method: Multiprocessing start method. ``config["MP_START_METHOD"]``
                takes precedence when present; otherwise this argument is used.
            camera_port: ZMQ port handed to the publisher.
        """
        if not self.camera_configs:
            print(
                "Warning: No camera configs provided, image publishing subprocess will not be started"
            )
            return
        # Previously the argument was unconditionally replaced by the config
        # lookup with a hard-coded "spawn" fallback, so callers could not choose.
        start_method = self.config.get("MP_START_METHOD", start_method)
        self.image_publish_process = ImagePublishProcess(
            camera_configs=self.camera_configs,
            image_dt=self.image_dt,
            zmq_port=camera_port,
            start_method=start_method,
            verbose=self.config.get("verbose", False),
        )
        self.image_publish_process.start_process()
        print(f"✓ Started image publishing subprocess on ZMQ port {camera_port}")

    @staticmethod
    def _pd_torque(cmd, q, dq):
        """Feed-forward torque plus PD feedback for one motor command."""
        return cmd.tau + cmd.kp * (cmd.q - q) + cmd.kd * (cmd.dq - dq)

    def _has_command(self) -> bool:
        """True when a bridge is attached and it holds a low-level command."""
        return self.unitree_bridge is not None and bool(self.unitree_bridge.low_cmd)

    def compute_body_torques(self) -> np.ndarray:
        """Return body joint torques for the current command (zeros without one).

        Joint state comes from ``sensordata`` when the bridge has ``use_sensor``
        set (positions first, velocities offset by ``num_body_motor``), otherwise
        from ``qpos`` / ``qvel``.
        """
        body_torques = np.zeros(self.num_body_dof)
        if not self._has_command():
            return body_torques
        bridge = self.unitree_bridge
        motor_count = bridge.num_body_motor
        for i in range(motor_count):
            if bridge.use_sensor:
                q = self.mj_data.sensordata[i]
                dq = self.mj_data.sensordata[i + motor_count]
            else:
                q = self.mj_data.qpos[self.body_joint_index[i] + 7 - 1]
                dq = self.mj_data.qvel[self.body_joint_index[i] + 6 - 1]
            body_torques[i] = self._pd_torque(bridge.low_cmd.motor_cmd[i], q, dq)
        return body_torques

    def compute_hand_torques(self) -> np.ndarray:
        """Return hand torques as ``[left..., right...]`` (empty when there are no hands)."""
        left_hand_torques = np.zeros(self.num_hand_dof)
        right_hand_torques = np.zeros(self.num_hand_dof)
        # NOTE(review): availability is gated on low_cmd although the values come
        # from left_hand_cmd / right_hand_cmd - kept as in the original.
        if self.num_hand_dof > 0 and self._has_command():
            bridge = self.unitree_bridge
            for i in range(bridge.num_hand_motor):
                left_hand_torques[i] = self._pd_torque(
                    bridge.left_hand_cmd.motor_cmd[i],
                    self.mj_data.qpos[self.left_hand_index[i] + 7 - 1],
                    self.mj_data.qvel[self.left_hand_index[i] + 6 - 1],
                )
                right_hand_torques[i] = self._pd_torque(
                    bridge.right_hand_cmd.motor_cmd[i],
                    self.mj_data.qpos[self.right_hand_index[i] + 7 - 1],
                    self.mj_data.qvel[self.right_hand_index[i] + 6 - 1],
                )
        return np.concatenate((left_hand_torques, right_hand_torques))

    def compute_body_qpos(self) -> np.ndarray:
        """Return the commanded body joint positions (zeros without a command)."""
        body_qpos = np.zeros(self.num_body_dof)
        if self._has_command():
            motor_cmd = self.unitree_bridge.low_cmd.motor_cmd
            for i in range(self.unitree_bridge.num_body_motor):
                body_qpos[i] = motor_cmd[i].q
        return body_qpos

    def compute_hand_qpos(self) -> np.ndarray:
        """Return commanded hand positions as ``[left..., right...]``."""
        hand_qpos = np.zeros(self.num_hand_dof * 2)
        if self.num_hand_dof > 0 and self._has_command():
            bridge = self.unitree_bridge
            for i in range(bridge.num_hand_motor):
                hand_qpos[i] = bridge.left_hand_cmd.motor_cmd[i].q
                hand_qpos[i + self.num_hand_dof] = bridge.right_hand_cmd.motor_cmd[i].q
        return hand_qpos

    def prepare_obs(self) -> Dict[str, Any]:
        """Collect the observation dictionary from the current ``mj_data``.

        The ``floating_base_*`` and ``secondary_imu_*`` entries are slices of
        the simulator buffers rather than copies.
        """
        data = self.mj_data
        obs = {}
        obs["floating_base_pose"] = data.qpos[:7]
        obs["floating_base_vel"] = data.qvel[:6]
        obs["floating_base_acc"] = data.qacc[:6]
        obs["secondary_imu_quat"] = data.xquat[self.torso_index]
        obs["secondary_imu_vel"] = data.cvel[self.torso_index]
        obs["body_q"] = data.qpos[self.body_joint_index + 7 - 1]
        obs["body_dq"] = data.qvel[self.body_joint_index + 6 - 1]
        obs["body_ddq"] = data.qacc[self.body_joint_index + 6 - 1]
        obs["body_tau_est"] = data.actuator_force[self.body_joint_index - 1]
        if self.num_hand_dof > 0:
            for side, index in (
                ("left_hand", self.left_hand_index),
                ("right_hand", self.right_hand_index),
            ):
                obs[f"{side}_q"] = data.qpos[index + 7 - 1]
                obs[f"{side}_dq"] = data.qvel[index + 6 - 1]
                obs[f"{side}_ddq"] = data.qacc[index + 6 - 1]
                obs[f"{side}_tau_est"] = data.actuator_force[index - 1]
        obs["time"] = data.time
        return obs

    def _publish_state(self, obs):
        """Send the observation (and joystick state, if any) through the bridge."""
        if self.unitree_bridge is None:
            return
        self.unitree_bridge.PublishLowState(obs)
        if self.unitree_bridge.joystick:
            self.unitree_bridge.PublishWirelessController()

    def _apply_elastic_band(self):
        """Write the elastic band force into ``xfrc_applied`` (zero when disabled)."""
        if not self.config["ENABLE_ELASTIC_BAND"]:
            return
        link = self.band_attached_link
        if not self.elastic_band.enable:
            # explicitly resetting the force when the band is not enabled
            self.mj_data.xfrc_applied[link] = np.zeros(6)
            return
        # Cartesian pose of the attached link followed by a 6-slot velocity placeholder.
        pose = np.concatenate(
            [
                self.mj_data.xpos[link],  # link position in world
                self.mj_data.xquat[link],  # link quaternion in world [w,x,y,z]
                np.zeros(6),  # placeholder for velocity
            ]
        )
        # Get velocity in world frame; the result is written into pose[7:13].
        mujoco.mj_objectVelocity(
            self.mj_model,
            self.mj_data,
            mujoco.mjtObj.mjOBJ_BODY,
            link,
            pose[7:13],
            0,  # 0 for world frame
        )
        # Reorder velocity from [ang, lin] to [lin, ang]; the right-hand side
        # is a fresh array, so the in-place write cannot alias.
        pose[7:13] = np.concatenate((pose[10:13], pose[7:10]))
        self.mj_data.xfrc_applied[link] = self.elastic_band.Advance(pose)

    def sim_step(self):
        """Publish state, turn the current command into torques and step physics."""
        self.obs = self.prepare_obs()
        self._publish_state(self.obs)
        self._apply_elastic_band()

        self.torques[self.body_joint_index - 1] = self.compute_body_torques()
        if self.num_hand_dof > 0:
            hand_torques = self.compute_hand_torques()
            self.torques[self.left_hand_index - 1] = hand_torques[: self.num_hand_dof]
            self.torques[self.right_hand_index - 1] = hand_torques[self.num_hand_dof :]
        self.torques = np.clip(self.torques, -self.torque_limit, self.torque_limit)

        if self.config["FREE_BASE"]:
            # Six zero controls are prepended ahead of the joint torques.
            self.mj_data.ctrl = np.concatenate((np.zeros(6), self.torques))
        else:
            self.mj_data.ctrl = self.torques
        mujoco.mj_step(self.mj_model, self.mj_data)
        # self.check_self_collision()

    def kinematics_step(self):
        """
        Run kinematics only: compute the qpos of the robot and directly set the qpos.
        For debugging purposes.
        """
        self._publish_state(self.prepare_obs())
        self._apply_elastic_band()

        self.mj_data.qpos[self.body_joint_index + 7 - 1] = self.compute_body_qpos()
        if self.num_hand_dof > 0:
            hand_qpos = self.compute_hand_qpos()  # (num_hand_dof * 2,)
            self.mj_data.qpos[self.left_hand_index + 7 - 1] = hand_qpos[: self.num_hand_dof]
            self.mj_data.qpos[self.right_hand_index + 7 - 1] = hand_qpos[self.num_hand_dof :]
        mujoco.mj_kinematics(self.mj_model, self.mj_data)
        mujoco.mj_comPos(self.mj_model, self.mj_data)

    def apply_perturbation(self, key):
        """Add a 1.0 velocity kick to the base, given in the body frame.

        ``key`` is one of ``"up"`` (forward), ``"down"`` (backward), ``"left"``
        or ``"right"``; any other value adds nothing but still recomputes the
        forward dynamics.
        """
        kicks = {
            "up": (1.0, 0.0),
            "down": (-1.0, 0.0),
            "left": (0.0, 1.0),
            "right": (0.0, -1.0),
        }
        forward, lateral = kicks.get(key, (0.0, 0.0))
        vel_body = np.array([forward, lateral, 0.0])
        vel_world = np.zeros(3)
        base_quat = self.mj_data.qpos[3:7]  # [w, x, y, z] quaternion
        # Rotate the body-frame velocity into the world frame.
        mujoco.mju_rotVecQuat(vel_world, vel_body, base_quat)
        # Apply to base linear velocity in world frame (X and Y only)
        self.mj_data.qvel[0] += vel_world[0]
        self.mj_data.qvel[1] += vel_world[1]
        # Update dynamics after velocity change
        mujoco.mj_forward(self.mj_model, self.mj_data)

    def update_viewer(self):
        """Sync the passive viewer with the simulation state, if a viewer exists."""
        if self.viewer is not None:
            self.viewer.sync()

    def update_viewer_camera(self):
        """Toggle the viewer camera between tracking and free mode."""
        if self.viewer is None:
            return
        tracking = mujoco.mjtCamera.mjCAMERA_TRACKING
        if self.viewer.cam.type == tracking:
            self.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_FREE
        else:
            self.viewer.cam.type = tracking

    def update_reward(self):
        """Calculate reward. Should be implemented by subclasses."""
        with self.reward_lock:
            self.last_reward = 0

    def get_reward(self):
        """Thread-safe way to get the last calculated reward."""
        with self.reward_lock:
            return self.last_reward

    def set_unitree_bridge(self, unitree_bridge):
        """Set the unitree bridge from the simulator"""
        self.unitree_bridge = unitree_bridge

    def get_privileged_obs(self):
        """Get privileged observation. Should be implemented by subclasses."""
        return {}

    def update_render_caches(self):
        """Render every configured camera and forward the images to the publisher.

        Returns:
            ``{"<camera>_image": frame}``; an empty dict when no offscreen
            renderers exist (e.g. the environment was built with ``offscreen=False``).
        """
        if not self.renderers:
            # Cameras may be configured without offscreen rendering; there is
            # nothing to render in that case.
            return {}
        render_caches = {}
        for camera_name, camera_config in self.camera_configs.items():
            renderer = self.renderers[camera_name]
            camera = camera_config["params"] if "params" in camera_config else camera_name
            renderer.update_scene(self.mj_data, camera=camera)
            render_caches[camera_name + "_image"] = renderer.render()
        # Update shared memory if image publishing process is available
        if self.image_publish_process is not None:
            self.image_publish_process.update_shared_memory(render_caches)
        return render_caches

    def handle_keyboard_button(self, key):
        """Dispatch a key name to the elastic band, reset, camera toggle or perturbation."""
        if self.elastic_band is not None:
            self.elastic_band.handle_keyboard_button(key)
        if key == "backspace":
            self.reset()
        if key == "v":
            self.update_viewer_camera()
        if key in ["up", "down", "left", "right"]:
            self.apply_perturbation(key)

    def check_fall(self):
        """Reset the simulation when the base height (``qpos[2]``) drops below 0.2."""
        self.fall = bool(self.mj_data.qpos[2] < 0.2)
        if self.fall:
            print(f"Warning: Robot has fallen, height: {self.mj_data.qpos[2]:.3f} m")
            self.reset()

    def check_self_collision(self):
        """Check for self-collision of the robot"""
        robot_bodies = get_subtree_body_names(self.mj_model, self.mj_model.body(self.root_body).id)
        self_collision, contact_bodies = check_contact(
            self.mj_model, self.mj_data, robot_bodies, robot_bodies, return_all_contact_bodies=True
        )
        if self_collision:
            print(f"Warning: Self-collision detected: {contact_bodies}")
        return self_collision

    def reset(self):
        """Reset ``mj_data`` to the model defaults."""
        mujoco.mj_resetData(self.mj_model, self.mj_data)
class CubeEnv(DefaultEnv):
    """Environment with a cube object for pick and place tasks"""

    # Right-hand bodies whose contact with the cube counts as a grasp.
    _RIGHT_HAND_BODIES = (
        "right_hand_thumb_2_link",
        "right_hand_middle_1_link",
        "right_hand_index_1_link",
    )

    def __init__(
        self,
        config: Dict[str, Any],
        onscreen: bool = False,
        offscreen: bool = False,
    ):
        """Build the cube scene.

        Args:
            config: Simulation settings; copied before ``ROBOT_SCENE`` is
                overridden so the caller's dict is left untouched.
            onscreen: Open an interactive viewer.
            offscreen: Create offscreen renderers.
        """
        scene_config = config.copy()
        scene_config["ROBOT_SCENE"] = "gr00t_wbc/control/robot_model/model_data/g1/pnp_cube_43dof.xml"
        super().__init__(scene_config, "cube", {}, onscreen, offscreen)

    def update_reward(self):
        """Store the reward: right hand touches ``cube_body`` AND the cube is lifted.

        "Lifted" is whatever ``check_height`` reports for ``"cube"`` with the
        bounds 0.85 and 2.0.
        """
        gripper_cube_contact = check_contact(
            self.mj_model, self.mj_data, list(self._RIGHT_HAND_BODIES), "cube_body"
        )
        cube_lifted = check_height(self.mj_model, self.mj_data, "cube", 0.85, 2.0)
        with self.reward_lock:
            self.last_reward = gripper_cube_contact & cube_lifted
class BoxEnv(DefaultEnv):
    """Environment with a box object for manipulation tasks"""

    # Hand bodies whose contact with the box counts as a grasp.
    _LEFT_HAND_BODIES = (
        "left_hand_thumb_2_link",
        "left_hand_middle_1_link",
        "left_hand_index_1_link",
    )
    _RIGHT_HAND_BODIES = (
        "right_hand_thumb_2_link",
        "right_hand_middle_1_link",
        "right_hand_index_1_link",
    )

    def __init__(
        self,
        config: Dict[str, Any],
        onscreen: bool = False,
        offscreen: bool = False,
    ):
        """Build the box-lifting scene.

        Args:
            config: Simulation settings; copied before ``ROBOT_SCENE`` is
                overridden so the caller's dict is left untouched.
            onscreen: Open an interactive viewer.
            offscreen: Create offscreen renderers.
        """
        scene_config = config.copy()
        scene_config["ROBOT_SCENE"] = "gr00t_wbc/control/robot_model/model_data/g1/lift_box_43dof.xml"
        super().__init__(scene_config, "box", {}, onscreen, offscreen)

    def _evaluate(self):
        """Return ``(both hands touch box_body, box lifted)`` for the current state."""
        gripper_box_contact = check_contact(
            self.mj_model, self.mj_data, list(self._LEFT_HAND_BODIES), "box_body"
        )
        gripper_box_contact &= check_contact(
            self.mj_model, self.mj_data, list(self._RIGHT_HAND_BODIES), "box_body"
        )
        box_lifted = check_height(self.mj_model, self.mj_data, "box", 0.92, 2.0)
        return gripper_box_contact, box_lifted

    def update_reward(self):
        """Store the reward: both hands touch the box AND the box is lifted.

        This is the hook the simulation loop calls. Without this override the
        base-class version would keep the reward at 0. Unlike :meth:`reward`,
        it does not print.
        """
        gripper_box_contact, box_lifted = self._evaluate()
        with self.reward_lock:
            self.last_reward = gripper_box_contact & box_lifted

    def reward(self):
        """Compute, print, store and return the reward (kept for existing callers)."""
        gripper_box_contact, box_lifted = self._evaluate()
        print("gripper_box_contact: ", gripper_box_contact, "box_lifted: ", box_lifted)
        value = gripper_box_contact & box_lifted
        with self.reward_lock:
            self.last_reward = value
        return value
class BottleEnv(DefaultEnv):
    """Environment with a cylinder object for manipulation tasks"""

    def __init__(
        self,
        config: Dict[str, Any],
        onscreen: bool = False,
        offscreen: bool = False,
    ):
        """Build the bottle scene with a 400x400 ``egoview`` camera.

        Args:
            config: Simulation settings; copied before ``ROBOT_SCENE`` is
                overridden so the caller's dict is left untouched.
            onscreen: Open an interactive viewer, fixed to the ``egoview`` camera.
            offscreen: Create offscreen renderers.
        """
        scene_config = config.copy()
        scene_config["ROBOT_SCENE"] = "gr00t_wbc/control/robot_model/model_data/g1/pnp_bottle_43dof.xml"
        camera_configs = {
            "egoview": {
                "height": 400,
                "width": 400,
            },
        }
        super().__init__(scene_config, "cylinder", camera_configs, onscreen, offscreen)

        self.bottle_body = self.mj_model.body("bottle_body")
        self.bottle_geom = self.mj_model.geom("bottle")

        if self.viewer is not None:
            # Look through the robot's fixed egoview camera instead of the free camera.
            self.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_FIXED
            self.viewer.cam.fixedcamid = self.mj_model.camera("egoview").id

    def update_reward(self):
        """No reward is computed for this environment.

        The override does nothing, so ``last_reward`` is never changed here.
        """
        pass

    def get_privileged_obs(self):
        """Return the bottle's world position and quaternion.

        Copies are returned so the values do not change underneath the caller
        when the simulation takes its next step.
        """
        body_id = self.bottle_body.id
        return {
            "bottle_pos": self.mj_data.xpos[body_id].copy(),
            "bottle_quat": self.mj_data.xquat[body_id].copy(),
        }
class BaseSimulator:
    """Base simulator class that handles initialization and running of simulations"""

    def __init__(self, config: Dict[str, Any], env_name: str = "default", **kwargs):
        """Create the environment, the DDS channel and the Unitree bridge.

        Args:
            config: Simulation settings. Keys read here: ``SIMULATE_DT``,
                ``DOMAIN_ID``, ``USE_JOYSTICK`` and optionally ``REWARD_DT``,
                ``IMAGE_DT``, ``VIEWER_DT``, ``INTERFACE``, ``JOYSTICK_DEVICE``,
                ``JOYSTICK_TYPE``.
            env_name: One of ``"default"``, ``"pnp_cube"``, ``"lift_box"``,
                ``"pnp_bottle"``.
            **kwargs: Forwarded to the environment constructor.

        Raises:
            ValueError: If ``env_name`` is not one of the names above.
        """
        self.config = config
        self.env_name = env_name

        # Initialize ROS 2 node (optional, only if rclpy is available)
        if HAS_RCLPY:
            if not rclpy.ok():
                rclpy.init()
                self.node = rclpy.create_node("sim_mujoco")
                self.thread = threading.Thread(target=rclpy.spin, args=(self.node,), daemon=True)
                self.thread.start()
            else:
                # rclpy was initialised elsewhere: reuse a node of the global executor.
                self.thread = None
                executor = rclpy.get_global_executor()
                self.node = executor.get_nodes()[0]  # will only take the first node
        else:
            self.node = None
            self.thread = None

        # Set update frequencies
        self.sim_dt = self.config["SIMULATE_DT"]
        self.reward_dt = self.config.get("REWARD_DT", 0.02)
        self.image_dt = self.config.get("IMAGE_DT", 0.033333)
        self.viewer_dt = self.config.get("VIEWER_DT", 0.02)

        # Create the appropriate environment based on name
        if env_name == "default":
            self.sim_env = DefaultEnv(config, env_name, **kwargs)
        elif env_name == "pnp_cube":
            self.sim_env = CubeEnv(config, **kwargs)
        elif env_name == "lift_box":
            self.sim_env = BoxEnv(config, **kwargs)
        elif env_name == "pnp_bottle":
            self.sim_env = BottleEnv(config, **kwargs)
        else:
            raise ValueError(f"Invalid environment name: {env_name}")

        # Initialize the DDS communication layer - should be safe to call multiple times
        try:
            if self.config.get("INTERFACE", None):
                ChannelFactoryInitialize(self.config["DOMAIN_ID"], self.config["INTERFACE"])
            else:
                ChannelFactoryInitialize(self.config["DOMAIN_ID"])
        except Exception as e:
            # If it fails because it's already initialized, that's okay
            print(f"Note: Channel factory initialization attempt: {e}")

        # Initialize the unitree bridge and pass it to the environment
        self.init_unitree_bridge()
        self.sim_env.set_unitree_bridge(self.unitree_bridge)

        # Initialize additional components
        self.init_subscriber()
        self.init_publisher()

        self.sim_thread = None

    def start_as_thread(self):
        """Run :meth:`start` in a new (non-daemon) thread."""
        self.sim_thread = Thread(target=self.start)
        self.sim_thread.start()

    def start_image_publish_subprocess(self, start_method: str = "spawn", camera_port: int = 5555):
        """Start the image publish subprocess"""
        self.sim_env.start_image_publish_subprocess(start_method, camera_port)

    def init_subscriber(self):
        """Initialize subscribers. Can be overridden by subclasses."""
        pass

    def init_publisher(self):
        """Initialize publishers. Can be overridden by subclasses."""
        pass

    def init_unitree_bridge(self):
        """Initialize the unitree SDK bridge"""
        self.unitree_bridge = UnitreeSdk2Bridge(self.config)
        if self.config["USE_JOYSTICK"]:
            self.unitree_bridge.SetupJoystick(
                device_id=self.config["JOYSTICK_DEVICE"], js_type=self.config["JOYSTICK_TYPE"]
            )

    def _steps_per_period(self, period: float) -> int:
        """Number of simulation steps between two events spaced ``period`` apart.

        The ratio is floored, with a tiny epsilon so that exact multiples are
        not lost to floating-point error (0.06 / 0.02 evaluates to 2.999...),
        and clamped to at least 1 so that periods shorter than ``sim_dt`` do
        not lead to a modulo-by-zero.
        """
        return max(1, int(period / self.sim_dt + 1e-9))

    def start(self):
        """Main simulation loop.

        Steps the environment until the viewer is closed (or forever when
        there is no viewer), updating viewer, reward and render caches at their
        own rates. :meth:`close` is always called on the way out.
        """
        import time
        import traceback

        viewer_every = self._steps_per_period(self.viewer_dt)
        reward_every = self._steps_per_period(self.reward_dt)
        image_every = self._steps_per_period(self.image_dt)

        sim_cnt = 0
        # perf_counter is monotonic, so pacing is immune to wall-clock changes.
        last_time = time.perf_counter()

        print(f"Starting simulation loop. Viewer: {self.sim_env.viewer is not None}")

        try:
            while self.sim_env.viewer is None or self.sim_env.viewer.is_running():
                # Run simulation step
                self.sim_env.sim_step()

                # Update viewer at viewer rate
                if sim_cnt % viewer_every == 0:
                    self.sim_env.update_viewer()

                # Calculate reward at reward rate
                if sim_cnt % reward_every == 0:
                    self.sim_env.update_reward()

                # Update render caches at image rate
                if sim_cnt % image_every == 0:
                    self.sim_env.update_render_caches()

                # Sleep to maintain correct rate (simple timing without ROS)
                remaining = self.sim_dt - (time.perf_counter() - last_time)
                if remaining > 0:
                    time.sleep(remaining)
                last_time = time.perf_counter()
                sim_cnt += 1

            print(f"Loop exited. Viewer running: {self.sim_env.viewer.is_running() if self.sim_env.viewer else 'No viewer'}")
        except KeyboardInterrupt:
            # User pressed Ctrl+C - exit cleanly
            print("Keyboard interrupt received")
        except Exception as e:
            print(f"Exception in simulation loop: {e}")
            traceback.print_exc()
        finally:
            # Also runs for exits that are not caught above (e.g. SystemExit).
            self.close()

    def __del__(self):
        """Clean up resources when simulator is deleted"""
        self.close()

    def reset(self):
        """Reset the simulation. Can be overridden by subclasses."""
        self.sim_env.reset()

    def close(self):
        """Close the simulation. Can be overridden by subclasses.

        Safe to call more than once and on a partially constructed instance.
        NOTE(review): this shuts down rclpy even when it was initialised by
        someone else (the ``rclpy.ok()`` branch in ``__init__``) - confirm
        that this is intended.
        """
        try:
            # Close viewer; sim_env may be missing if __init__ failed early.
            viewer = getattr(getattr(self, "sim_env", None), "viewer", None)
            if viewer is not None:
                viewer.close()

            # Shutdown ROS (if available)
            if HAS_RCLPY and rclpy.ok():
                rclpy.shutdown()
        except Exception as e:
            print(f"Warning during close: {e}")

    def get_privileged_obs(self):
        """Return the environment's privileged observation."""
        obs = self.sim_env.get_privileged_obs()
        # TODO: add ros2 topic to get privileged obs
        return obs

    def handle_keyboard_button(self, key):
        """Forward a key press to the environment (default environment only)."""
        # Only handles keyboard buttons for default env.
        if self.env_name == "default":
            self.sim_env.handle_keyboard_button(key)
if __name__ == "__main__":
    # Command-line entry point: load a YAML config, bring up DDS and run the
    # default simulator in a background thread.
    parser = argparse.ArgumentParser(description="Robot")
    parser.add_argument(
        "--config",
        type=str,
        default="./gr00t_wbc/control/main/teleop/configs/g1_29dof_gear_wbc.yaml",
        help="config file",
    )
    args = parser.parse_args()

    # NOTE(review): yaml.load with FullLoader is used on a user-supplied path;
    # consider yaml.safe_load if the config only contains plain YAML types.
    with open(args.config, "r") as file:
        config = yaml.load(file, Loader=yaml.FullLoader)

    # The network interface is optional; pass it only when configured.
    channel_args = [config["DOMAIN_ID"]]
    if config.get("INTERFACE", None):
        channel_args.append(config["INTERFACE"])
    ChannelFactoryInitialize(*channel_args)

    simulation = BaseSimulator(config)
    simulation.start_as_thread()