unitree-g1-mujoco / view_cameras_live.py
nepyope's picture
Upload folder using huggingface_hub
ef6a683 verified
raw
history blame
7.91 kB
#!/usr/bin/env python3
"""
Live camera viewer for MuJoCo simulator using matplotlib
Works without X11/GTK - suitable for SSH sessions with X forwarding
"""
import argparse
import sys
import time
from pathlib import Path
# Add sim module to path
sys.path.insert(0, str(Path(__file__).parent))
import cv2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from sim.sensor_utils import SensorClient, ImageUtils
class CameraViewer:
def __init__(self, host, port):
self.client = SensorClient()
self.client.start_client(server_ip=host, port=port)
self.fig = None
self.axes = {}
self.images = {}
self.text_objs = {}
self.frame_count = 0
self.last_time = time.time()
self.fps = 0
def init_plot(self):
"""Initialize matplotlib figure and axes"""
# Wait for first frame to know how many cameras we have
print("Waiting for first frame to detect cameras...")
data = self.client.receive_message()
# Parse camera names - handle nested 'images' dict
camera_names = []
if "images" in data and isinstance(data["images"], dict):
# Nested structure: data["images"]["camera_name"]
camera_names = list(data["images"].keys())
else:
# Flat structure: data["camera_name"] directly
camera_names = [k for k in data.keys() if k not in ["timestamps", "images"]]
num_cameras = len(camera_names)
if num_cameras == 0:
print("No cameras found in stream!")
return False
print(f"Found {num_cameras} camera(s): {', '.join(camera_names)}")
# Create subplots
if num_cameras == 1:
self.fig, ax = plt.subplots(1, 1, figsize=(10, 8))
axes_list = [ax]
elif num_cameras == 2:
self.fig, axes_list = plt.subplots(1, 2, figsize=(16, 6))
else:
rows = (num_cameras + 1) // 2
self.fig, axes_list = plt.subplots(rows, 2, figsize=(16, 6 * rows))
axes_list = axes_list.flatten()
# Initialize each camera subplot
for i, cam_name in enumerate(camera_names):
ax = axes_list[i]
ax.set_title(f"{cam_name}", fontsize=12, fontweight='bold')
ax.axis('off')
# Get image data from nested or flat structure
if "images" in data and cam_name in data["images"]:
img_data = data["images"][cam_name]
elif cam_name in data:
img_data = data[cam_name]
else:
img_data = cam_name # Use the actual data if it's the value
# Decode first image
if isinstance(img_data, str):
img = ImageUtils.decode_image(img_data)
elif isinstance(img_data, np.ndarray):
img = img_data
else:
print(f"Warning: Unknown image format for {cam_name}: {type(img_data)}")
continue
# Check if image is valid
if img is None or not isinstance(img, np.ndarray):
print(f"Warning: Invalid image data for {cam_name}")
continue
# Convert BGR to RGB for matplotlib
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Display image
im = ax.imshow(img_rgb)
self.images[cam_name] = im
self.axes[cam_name] = ax
# Add FPS text
text = ax.text(0.02, 0.98, 'FPS: 0.0',
transform=ax.transAxes,
fontsize=10,
verticalalignment='top',
bbox=dict(boxstyle='round', facecolor='black', alpha=0.7),
color='lime',
fontweight='bold')
self.text_objs[cam_name] = text
# Hide unused subplots
if num_cameras < len(axes_list):
for i in range(num_cameras, len(axes_list)):
axes_list[i].axis('off')
self.fig.tight_layout()
return True
def update_frame(self, frame_num):
"""Update function for animation"""
try:
# Receive new frame
data = self.client.receive_message()
# Calculate FPS
self.frame_count += 1
current_time = time.time()
if current_time - self.last_time >= 1.0:
self.fps = self.frame_count / (current_time - self.last_time)
self.frame_count = 0
self.last_time = current_time
# Update each camera
for cam_name in self.images.keys():
# Get image data from nested or flat structure
if "images" in data and cam_name in data["images"]:
img_data = data["images"][cam_name]
elif cam_name in data:
img_data = data[cam_name]
else:
continue
# Decode image
if isinstance(img_data, str):
img = ImageUtils.decode_image(img_data)
elif isinstance(img_data, np.ndarray):
img = img_data
else:
continue
# Check if image is valid
if img is None or not isinstance(img, np.ndarray):
continue
# Convert BGR to RGB for matplotlib
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Update image
self.images[cam_name].set_data(img_rgb)
# Update FPS text
self.text_objs[cam_name].set_text(f'FPS: {self.fps:.1f}')
except Exception as e:
print(f"Error updating frame: {e}")
return list(self.images.values()) + list(self.text_objs.values())
def start(self, interval=33):
"""Start the live viewer"""
if not self.init_plot():
return
print(f"\n{'='*60}")
print("📹 Live camera viewer started!")
print("Close the window or press Ctrl+C to exit")
print(f"{'='*60}\n")
# Create animation
anim = FuncAnimation(
self.fig,
self.update_frame,
interval=interval, # ms between frames
blit=True,
cache_frame_data=False
)
try:
plt.show()
except KeyboardInterrupt:
print("\nStopping viewer...")
finally:
self.client.stop_client()
plt.close('all')
def main():
parser = argparse.ArgumentParser(description="Live camera viewer for MuJoCo simulator")
parser.add_argument("--host", type=str, default="localhost",
help="Simulator host address (default: localhost)")
parser.add_argument("--port", type=int, default=5555,
help="ZMQ port (default: 5555)")
parser.add_argument("--interval", type=int, default=33,
help="Update interval in ms (default: 33 = ~30fps)")
args = parser.parse_args()
print("="*60)
print("📷 MuJoCo Live Camera Viewer (matplotlib)")
print("="*60)
print(f"🌐 Connecting to: tcp://{args.host}:{args.port}")
print(f"⏱️ Update interval: {args.interval}ms (~{1000/args.interval:.0f} fps)")
print("="*60)
viewer = CameraViewer(host=args.host, port=args.port)
viewer.start(interval=args.interval)
if __name__ == "__main__":
main()