from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
import traceback
import tempfile
import torch
# import mimetypes
from PIL import Image
import av
import numpy as np
import os
from transformers import LlavaNextVideoProcessor, LlavaNextVideoForConditionalGeneration
from my_lib.preproces_video import read_video_pyav
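# `my_lib.preproces_video` is not included in this listing. For reference, a
# minimal sketch of `read_video_pyav`, following the helper shown on the
# LLaVA-NeXT-Video model card (decode the stream and keep only the frames at
# the requested indices, stacked as an (N, H, W, 3) RGB uint8 array):
#
# def read_video_pyav(container, indices):
#     frames = []
#     container.seek(0)
#     start_index = indices[0]
#     end_index = indices[-1]
#     for i, frame in enumerate(container.decode(video=0)):
#         if i > end_index:
#             break
#         if i >= start_index and i in indices:
#             frames.append(frame)
#     return np.stack([x.to_ndarray(format="rgb24") for x in frames])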
app = FastAPI()

# Load model and processor
MODEL_ID = "llava-hf/LLaVA-NeXT-Video-7B-hf"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Loading model and processor...")
processor = LlavaNextVideoProcessor.from_pretrained(MODEL_ID)

# Optional: Pre-cache model on HF Spaces to avoid redownloading
# from huggingface_hub import snapshot_download
# snapshot_download(MODEL_ID)
if device.type == "cuda":
    try:
        # bitsandbytes places 4-bit weights on the GPU itself; calling
        # .to(device) on a quantized model raises, so use device_map instead.
        model = LlavaNextVideoForConditionalGeneration.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            load_in_4bit=True,  # Requires bitsandbytes and GPU
            device_map="auto",
        )
        print("Loaded model in 4-bit quantized mode.")
    except Exception as e:
        print("Failed to load in 4-bit mode:", e)
        print("Falling back to half-precision FP16.")
        model = LlavaNextVideoForConditionalGeneration.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        ).to(device)
else:
    model = LlavaNextVideoForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
    ).to(device)

print(f"Model and processor loaded on {device}.")
@app.get("/")
async def root():
    return {"message": "Welcome to the Summarization API. Use /summarize to summarize media files."}


@app.get("/health")
async def health():
    return {"status": "ok", "device": device.type}
@app.post("/summarize")
async def summarize_media(file: UploadFile = File(...)):
    try:
        # Keep only the original extension so PyAV/PIL can sniff the format.
        suffix = os.path.splitext(file.filename or "")[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp.write(await file.read())
            tmp_path = tmp.name

        content_type = file.content_type or ""
        is_video = content_type.startswith("video/")
        is_image = content_type.startswith("image/")

        if not (is_video or is_image):
            os.unlink(tmp_path)
            return JSONResponse(status_code=400, content={"error": f"Unsupported file type: {content_type}"})
        if is_video:
            # Count frames (fall back to decoding when metadata is missing).
            container = av.open(tmp_path)
            total_frames = container.streams.video[0].frames or sum(1 for _ in container.decode(video=0))
            container.close()
            container = av.open(tmp_path)  # reopen to reset the decode position

            if total_frames == 0:
                raise ValueError("Could not extract frames: total frame count is zero.")

            # Sample up to 8 frames, evenly spaced across the video.
            num_frames = min(8, total_frames)
            indices = np.linspace(0, total_frames - 1, num_frames).astype(int)
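            # Worked example: for a 100-frame clip, np.linspace(0, 99, 8)
            # truncates to indices [0, 14, 28, 42, 56, 70, 84, 99], i.e. one
            # frame roughly every 14 frames, including the first and last.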
            clip = read_video_pyav(container, indices)

            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Summarize this video and explain the key highlights."},
                        {"type": "video"},
                    ],
                },
            ]
            prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
            inputs = processor(text=prompt, videos=clip, return_tensors="pt").to(device)
        elif is_image:
            image = Image.open(tmp_path).convert("RGB")

            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Describe the image and summarize its content."},
                        {"type": "image"},
                    ],
                },
            ]
            prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
            inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
        output_ids = model.generate(**inputs, max_new_tokens=512)
        # generate() echoes the prompt tokens for decoder-only models;
        # decode only the newly generated portion.
        generated_ids = output_ids[:, inputs["input_ids"].shape[1]:]
        response_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        return JSONResponse(content={"summary": response_text})
    except Exception as e:
        print("Unhandled error:", e)
        print(traceback.format_exc())
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        if 'tmp_path' in locals() and os.path.exists(tmp_path):
            os.unlink(tmp_path)
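# Example request (assumes the default Spaces/uvicorn port 7860 and the
# /summarize route defined above; adjust host and port for your deployment):
#
#   curl -X POST "http://localhost:7860/summarize" \
#        -F "file=@sample.mp4;type=video/mp4"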