import gradio as gr
import torch
from transformers import LlavaNextVideoProcessor, LlavaNextVideoForConditionalGeneration
import tempfile
import os
import cv2
import numpy as np

# Load model and processor
model_id = "llava-hf/LLaVA-NeXT-Video-7B-hf"

# Initialize processor and model with error handling
try:
    processor = LlavaNextVideoProcessor.from_pretrained(model_id)
    model = LlavaNextVideoForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        low_cpu_mem_usage=True
    )
except Exception as e:
    print(f"Error loading model: {e}")
    processor = None
    model = None
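
# Note: device_map="auto" relies on the `accelerate` package being installed.
# On CPU-only hardware, float16 can be slow or unsupported for some ops;
# torch.float32 (at a higher memory cost) is a common fallback.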

def process_video_file(video_path):
    """Convert video file to the format expected by the model"""
    try:
        # Read video using OpenCV
        cap = cv2.VideoCapture(video_path)
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Convert BGR to RGB
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(frame)
        cap.release()

        # Stack frames into a single numpy array of shape (num_frames, H, W, 3);
        # pixel scaling/normalization is left to the processor
        video_frames = np.array(frames)
        return video_frames
    except Exception as e:
        print(f"Error processing video: {e}")
        return None
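
# Optional sketch (not wired into the pipeline above): LLaVA-NeXT-Video checkpoints
# are typically run on a small, fixed number of frames, so passing every frame of a
# long clip can exhaust memory. Uniformly sampling a capped number of frames is one
# common way to bound the input; num_frames=32 here is an assumption, not a value
# taken from the model card.
def sample_frames(video_frames, num_frames=32):
    """Uniformly sample up to `num_frames` frames from a (T, H, W, 3) array."""
    total = len(video_frames)
    if total <= num_frames:
        return video_frames
    indices = np.linspace(0, total - 1, num_frames).astype(int)
    return video_frames[indices]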

def analyze_dance_video(video_file, pose_scores="0.85", music_info="Unknown"):
    """
    Analyze dance video with pose scores and music information
    """
    if model is None or processor is None:
        return "Error: Model not loaded properly. Please check the logs."
    if video_file is None:
        return "Please upload a video file."

    try:
        # Process the video file
        video_frames = process_video_file(video_file)
        if video_frames is None:
            return "Error: Could not process video file."

        # Prepare the prompt
| prompt = f"""USER: You are an expert dance instructor. Analyze this dance performance video. | |
| Additional Data: | |
| - MediaPipe Pose Scores: {pose_scores} | |
| - Music Information: {music_info} | |
| - When analyzing, combine what you see in the video with the pose scores and the music details to provide precise, realistic feedback. | |
| Please provide detailed feedback on: | |
| 1. Timing and synchronization with music | |
| 2. Pose accuracy and technique | |
| 3. Movement flow and transitions | |
| 4. Areas for improvement | |
| 5. Overall performance rating (1-10) | |
| 6. How well the dancer's moves are synchronized with the music's tempo and beat. | |
| 7. The accuracy and technique of the dancer's poses, considering the pose scores. | |
| 8. The fluidity and smoothness of transitions between moves. | |
| 9. Specific areas where the dancer can improve. | |
| Give constructive feedback in a friendly, encouraging tone. | |
| ASSISTANT:""" | |

        # Process video with the model
        inputs = processor(
            text=prompt,
            videos=[video_frames],  # note: `videos` expects a list
            return_tensors="pt"
        )

        # Move inputs to the same device as the model
        inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

        # Generate response
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=500,
                do_sample=True,
                temperature=0.7,
                pad_token_id=processor.tokenizer.eos_token_id,
                eos_token_id=processor.tokenizer.eos_token_id
            )

        # Decode response
        response = processor.decode(output[0], skip_special_tokens=True)

        # Keep only the generated part (after "ASSISTANT:")
        if "ASSISTANT:" in response:
            response = response.split("ASSISTANT:")[-1].strip()

        return response

    except Exception as e:
        return f"Error analyzing video: {str(e)}"

# Create Gradio interface
with gr.Blocks(title="AI Dance Instructor") as demo:
    gr.Markdown("# 🕺 AI Dance Instructor")
    gr.Markdown("Upload your dance video along with pose scores for detailed feedback!")

    with gr.Row():
        with gr.Column():
            video_input = gr.Video(
                label="Upload Dance Video"
            )
            pose_scores = gr.Textbox(
                label="MediaPipe Pose Scores",
                placeholder="Enter pose scores data from MediaPipe (e.g., 0.85)...",
                value="0.85",
                lines=3
            )
            music_info = gr.Textbox(
                label="Music Information",
                placeholder="BPM, genre, rhythm details...",
                value="120 BPM, Pop music",
                lines=3
            )
            analyze_btn = gr.Button("Analyze Dance", variant="primary")

        with gr.Column():
            feedback_output = gr.Textbox(
                label="Dance Feedback",
                lines=15,
                interactive=False
            )

    # Set up the analysis function
    analyze_btn.click(
        fn=analyze_dance_video,
        inputs=[video_input, pose_scores, music_info],
        outputs=[feedback_output]
    )

    # Add API endpoint info
    gr.Markdown("### API Usage")
    gr.Markdown("""
**For Next.js Integration, use the API endpoint:**

```javascript
const formData = new FormData();
formData.append('data', JSON.stringify([videoFile, poseScores, musicInfo]));

const response = await fetch('https://your-space-name.hf.space/api/predict', {
    method: 'POST',
    body: formData
});
```
""")

if __name__ == "__main__":
    demo.launch()