File size: 7,911 Bytes
ef6a683
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
#!/usr/bin/env python3
"""
Live camera viewer for MuJoCo simulator using matplotlib
Works without X11/GTK - suitable for SSH sessions with X forwarding
"""
import argparse
import sys
import time
from pathlib import Path

# Add sim module to path
sys.path.insert(0, str(Path(__file__).parent))

import cv2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from sim.sensor_utils import SensorClient, ImageUtils


class CameraViewer:
    """Live viewer that draws streamed camera frames in a matplotlib figure.

    Messages are pulled from a ``SensorClient``. Two message layouts are
    understood:

    * nested: ``data["images"][camera_name]``
    * flat:   ``data[camera_name]`` (keys ``"timestamps"`` and ``"images"``
      are not treated as cameras)

    Each camera payload is either a ``str`` (decoded with
    ``ImageUtils.decode_image``) or a ``numpy.ndarray``. Colour frames are
    treated as BGR and converted to RGB before display.
    """

    # Top-level message keys that never name a camera in the flat layout.
    _NON_CAMERA_KEYS = ("timestamps", "images")

    def __init__(self, host, port):
        """Connect the sensor client and reset all plot/FPS state.

        Args:
            host: Forwarded to ``SensorClient.start_client`` as ``server_ip``.
            port: Forwarded to ``SensorClient.start_client`` as ``port``.
        """
        self.client = SensorClient()
        self.client.start_client(server_ip=host, port=port)

        self.fig = None
        self.axes = {}       # camera name -> matplotlib Axes
        self.images = {}     # camera name -> AxesImage returned by imshow
        self.text_objs = {}  # camera name -> FPS overlay Text artist
        self._anim = None    # keeps the FuncAnimation alive while showing

        self.frame_count = 0
        self.last_time = time.time()
        self.fps = 0

    @staticmethod
    def _nested_images(data):
        """Return ``data["images"]`` when it is a dict, otherwise ``None``."""
        nested = data["images"] if "images" in data else None
        return nested if isinstance(nested, dict) else None

    @classmethod
    def _camera_names(cls, data):
        """List the camera names advertised by one message."""
        nested = cls._nested_images(data)
        if nested is not None:
            return list(nested)
        return [key for key in data if key not in cls._NON_CAMERA_KEYS]

    @classmethod
    def _lookup_frame(cls, data, cam_name):
        """Return the raw payload for ``cam_name`` or ``None`` if absent."""
        nested = cls._nested_images(data)
        if nested is not None and cam_name in nested:
            return nested[cam_name]
        if cam_name in data:
            return data[cam_name]
        return None

    @staticmethod
    def _to_rgb(img):
        """Convert a decoded frame to RGB for matplotlib.

        Single-channel frames are expanded to RGB; everything else goes
        through the BGR->RGB conversion the viewer has always used.
        """
        if img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1):
            return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def _read_rgb_frame(self, data, cam_name, verbose=False):
        """Extract, decode and colour-convert one camera frame.

        Args:
            data: A message as returned by ``self.client.receive_message()``.
            cam_name: Camera whose frame should be read.
            verbose: When true, print a warning explaining why a frame was
                rejected (used for the first frame only, to avoid spamming
                the console on every animation tick).

        Returns:
            An RGB ``numpy.ndarray``, or ``None`` when the camera is missing
            from the message or its payload cannot be displayed.
        """
        img_data = self._lookup_frame(data, cam_name)

        if isinstance(img_data, str):
            img = ImageUtils.decode_image(img_data)
        elif isinstance(img_data, np.ndarray):
            img = img_data
        else:
            if verbose:
                print(f"Warning: Unknown image format for {cam_name}: {type(img_data)}")
            return None

        # decode_image may hand back None (or something unexpected); an empty
        # array would make cvtColor raise, so reject it here as well.
        if not isinstance(img, np.ndarray) or img.size == 0:
            if verbose:
                print(f"Warning: Invalid image data for {cam_name}")
            return None

        return self._to_rgb(img)

    def init_plot(self):
        """Initialize matplotlib figure and axes from the first message.

        Blocks on ``self.client.receive_message()`` to learn which cameras
        exist, creates one subplot per camera and draws the first frame.

        Returns:
            ``True`` when at least one camera image is on screen, ``False``
            when the message has no cameras or none of them could be shown.
        """
        print("Waiting for first frame to detect cameras...")
        data = self.client.receive_message()

        camera_names = self._camera_names(data)
        num_cameras = len(camera_names)

        if num_cameras == 0:
            print("No cameras found in stream!")
            return False

        print(f"Found {num_cameras} camera(s): {', '.join(map(str, camera_names))}")

        # One camera gets a single large panel; otherwise use a two-column grid.
        if num_cameras == 1:
            rows, cols, figsize = 1, 1, (10, 8)
        else:
            rows = (num_cameras + 1) // 2
            cols, figsize = 2, (16, 6 * rows)

        # squeeze=False always yields a 2-D array of Axes, so a single
        # flatten() works for every layout.
        self.fig, grid = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
        axes_list = grid.flatten()

        for cam_name, ax in zip(camera_names, axes_list):
            ax.set_title(f"{cam_name}", fontsize=12, fontweight='bold')
            ax.axis('off')

            img_rgb = self._read_rgb_frame(data, cam_name, verbose=True)
            if img_rgb is None:
                continue

            self.images[cam_name] = ax.imshow(img_rgb)
            self.axes[cam_name] = ax

            # FPS overlay in the top-left corner of the panel.
            self.text_objs[cam_name] = ax.text(
                0.02, 0.98, 'FPS: 0.0',
                transform=ax.transAxes,
                fontsize=10,
                verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='black', alpha=0.7),
                color='lime',
                fontweight='bold',
            )

        # Hide the spare panel of an odd-sized grid.
        for ax in axes_list[num_cameras:]:
            ax.axis('off')

        # Only cameras registered here are ever refreshed, so a viewer with
        # no registered image would stay blank forever.
        if not self.images:
            print("No displayable camera images in first frame!")
            return False

        self.fig.tight_layout()
        return True

    def _tick_fps(self):
        """Count one received message and refresh ``self.fps`` once a second."""
        self.frame_count += 1
        now = time.time()
        elapsed = now - self.last_time
        if elapsed >= 1.0:
            self.fps = self.frame_count / elapsed
            self.frame_count = 0
            self.last_time = now

    def update_frame(self, frame_num):
        """Update function for animation.

        Args:
            frame_num: Frame index supplied by ``FuncAnimation`` (unused).

        Returns:
            The list of artists (images followed by FPS texts) for blitting.
        """
        try:
            data = self.client.receive_message()
            self._tick_fps()

            for cam_name, artist in self.images.items():
                img_rgb = self._read_rgb_frame(data, cam_name)
                if img_rgb is None:
                    continue

                artist.set_data(img_rgb)
                self.text_objs[cam_name].set_text(f'FPS: {self.fps:.1f}')

        except Exception as e:
            # Best effort: one bad message must not kill the animation loop.
            print(f"Error updating frame: {e}")

        return list(self.images.values()) + list(self.text_objs.values())

    def start(self, interval=33):
        """Start the live viewer and block until the window is closed.

        Args:
            interval: Delay between animation updates in milliseconds.

        The client is always stopped on the way out, including when the
        plot could not be initialised.
        """
        try:
            if not self.init_plot():
                return

            print(f"\n{'='*60}")
            print("📹 Live camera viewer started!")
            print("Close the window or press Ctrl+C to exit")
            print(f"{'='*60}\n")

            # Keep a reference on self so the animation is not garbage
            # collected while the window is open.
            self._anim = FuncAnimation(
                self.fig,
                self.update_frame,
                interval=interval,  # ms between frames
                blit=True,
                cache_frame_data=False,
            )

            plt.show()
        except KeyboardInterrupt:
            print("\nStopping viewer...")
        finally:
            self.client.stop_client()
            # Nothing to close if no figure was ever created.
            if self.fig is not None:
                plt.close('all')


def main():
    """Parse command-line options and run the live camera viewer.

    Options:
        --host: Simulator host address (default ``localhost``).
        --port: ZMQ port (default ``5555``).
        --interval: Update interval in milliseconds (default ``33``); must
            be positive.

    Exits through ``argparse`` (status 2) when ``--interval`` is not positive.
    """
    parser = argparse.ArgumentParser(description="Live camera viewer for MuJoCo simulator")
    parser.add_argument("--host", type=str, default="localhost",
                       help="Simulator host address (default: localhost)")
    parser.add_argument("--port", type=int, default=5555,
                       help="ZMQ port (default: 5555)")
    parser.add_argument("--interval", type=int, default=33,
                       help="Update interval in ms (default: 33 = ~30fps)")
    args = parser.parse_args()

    # The interval is used as a divisor below and as the animation timer
    # period, so zero or negative values are rejected up front.
    if args.interval <= 0:
        parser.error("--interval must be a positive number of milliseconds")

    print("="*60)
    print("📷 MuJoCo Live Camera Viewer (matplotlib)")
    print("="*60)
    print(f"🌐 Connecting to: tcp://{args.host}:{args.port}")
    print(f"⏱️  Update interval: {args.interval}ms (~{1000/args.interval:.0f} fps)")
    print("="*60)

    viewer = CameraViewer(host=args.host, port=args.port)
    viewer.start(interval=args.interval)


# Run the viewer only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()