""" Visual-CoT: Chain-of-Thought Reasoning Demo on Hugging Face Spaces Showcasing Visual Chain-of-Thought with Interactive Benchmark Examples Paper: Visual CoT: Advancing Multi-Modal Language Models with a Comprehensive Dataset and Benchmark for Chain-of-Thought Reasoning https://arxiv.org/abs/2403.16999 """ import os import torch import gradio as gr from PIL import Image, ImageDraw, ImageFont import re import json import spaces from pathlib import Path import requests from io import BytesIO from llava.constants import ( IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, ) from llava.conversation import conv_templates from llava.model.builder import load_pretrained_model from llava.utils import disable_torch_init from llava.mm_utils import ( process_images, tokenizer_image_token, get_model_name_from_path, ) # ============================================================================= # Configuration # ============================================================================= MODEL_PATH = "deepcs233/VisCoT-7b-336" # Hugging Face model ID DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # Benchmark datasets available BENCHMARK_DATASETS = [ "docvqa", "flickr30k", "gqa", "infographicsvqa", "openimages", "textcap", "textvqa", "vsr", "cub", ] # Global model variables (lazy loading) tokenizer, model, image_processor, context_len = None, None, None, None # ============================================================================= # Model Loading (with Zero GPU optimization) # ============================================================================= def load_model_once(): """Load model once and cache it""" global tokenizer, model, image_processor, context_len if model is not None: return tokenizer, model, image_processor, context_len print("🔄 Loading Visual-CoT model...") disable_torch_init() model_name = get_model_name_from_path(MODEL_PATH) tokenizer, model, image_processor, context_len = load_pretrained_model( MODEL_PATH, None, model_name, load_8bit=False, load_4bit=False, device=DEVICE, ) print("✓ Model loaded successfully!") return tokenizer, model, image_processor, context_len # ============================================================================= # Utility Functions # ============================================================================= def parse_bbox(text): """Parse bounding box from model output""" pattern1 = r"###\[([\d\.]+),\s*([\d\.]+),\s*([\d\.]+),\s*([\d\.]+)\]" pattern2 = r"\[([\d\.]+),\s*([\d\.]+),\s*([\d\.]+),\s*([\d\.]+)\]" matches = re.findall(pattern1, text) if not matches: matches = re.findall(pattern2, text) if matches: bbox = [float(x) for x in matches[-1]] if all(0 <= x <= 1 for x in bbox): return bbox return None def draw_bounding_box(image, bbox, color="red", width=5): """Draw bounding box on image""" if bbox is None: return image img = image.copy() draw = ImageDraw.Draw(img) img_width, img_height = img.size # Convert normalized to pixel coordinates x1 = int(bbox[0] * img_width) y1 = int(bbox[1] * img_height) x2 = int(bbox[2] * img_width) y2 = int(bbox[3] * img_height) # Draw rectangle draw.rectangle([x1, y1, x2, y2], outline=color, width=width) # Draw label label = f"ROI: [{bbox[0]:.3f}, {bbox[1]:.3f}, {bbox[2]:.3f}, {bbox[3]:.3f}]" try: font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 14) except: font = ImageFont.load_default() # Text background bbox_text = draw.textbbox((x1, y1 - 22), label, font=font) draw.rectangle([bbox_text[0]-2, bbox_text[1]-2, bbox_text[2]+2, bbox_text[3]+2], fill=color) draw.text((x1, y1 - 22), label, fill="white", font=font) return img def load_benchmark_examples(dataset_name, num_examples=5): """ Load examples from benchmark dataset Returns list of (image_path, question, ground_truth_bbox, ground_truth_answer) """ benchmark_file = f"viscot_benchmark/benchmark/{dataset_name}.json" if not os.path.exists(benchmark_file): return [] try: with open(benchmark_file, 'r') as f: data = json.load(f) examples = [] for item in data[:num_examples]: # Extract information based on dataset structure image_file = item.get('image', '') question = item['conversations'][0]['value'].replace('\n', '').split('Please provide')[0].strip() gt_bbox_str = item['conversations'][1]['value'] if len(item['conversations']) > 1 else None gt_answer = item['conversations'][3]['value'] if len(item['conversations']) > 3 else None examples.append({ 'image': image_file, 'question': question, 'gt_bbox': gt_bbox_str, 'gt_answer': gt_answer, 'dataset': dataset_name }) return examples except Exception as e: print(f"Error loading {dataset_name}: {e}") return [] # ============================================================================= # Main Inference Function (with @spaces.GPU decorator) # ============================================================================= @spaces.GPU(duration=120) # Zero GPU allocation for 120 seconds def generate_viscot_response(image, question, temperature=0.2, max_tokens=512): """ Generate Visual-CoT response with bounding box detection Args: image: PIL Image question: str temperature: float max_tokens: int Returns: tuple: (bbox_response, final_answer, image_with_bbox, processing_info) """ if image is None: return "❌ Please upload an image!", "", None, "" if not question.strip(): return "❌ Please enter a question!", "", None, "" try: # Load model (lazy loading) tokenizer, model, image_processor, context_len = load_model_once() # Initialize conversation conv_mode = "llava_v1" conv = conv_templates[conv_mode].copy() # ===================================================================== # STEP 1: Detect Region of Interest (ROI) # ===================================================================== prompt_step1 = ( f"{DEFAULT_IMAGE_TOKEN}\n{question} " f"Please provide the bounding box coordinate of the region this question asks about." ) conv.append_message(conv.roles[0], prompt_step1) conv.append_message(conv.roles[1], None) prompt1 = conv.get_prompt() # Process image image_tensor = process_images([image], image_processor, model.config) if isinstance(image_tensor, list): image_tensor = [img.to(DEVICE, dtype=torch.bfloat16) for img in image_tensor] else: image_tensor = image_tensor.to(DEVICE, dtype=torch.bfloat16) # Tokenize input_ids = tokenizer_image_token( prompt1, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt" ).unsqueeze(0).to(DEVICE) # Generate bbox with torch.inference_mode(): output_ids = model.generate( input_ids, images=image_tensor, do_sample=temperature > 0.001, temperature=max(temperature, 0.01), max_new_tokens=128, use_cache=True, ) bbox_response = tokenizer.decode( output_ids[0, input_ids.shape[1]:], skip_special_tokens=True ).strip() # Parse bbox bbox = parse_bbox(bbox_response) # ===================================================================== # STEP 2: Answer Question with ROI Context # ===================================================================== conv.messages[-1][-1] = bbox_response second_question = ( f"Please answer the question based on the original image and local detail image. {question}" ) conv.append_message(conv.roles[0], second_question) conv.append_message(conv.roles[1], None) prompt2 = conv.get_prompt() input_ids = tokenizer_image_token( prompt2, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt" ).unsqueeze(0).to(DEVICE) with torch.inference_mode(): output_ids = model.generate( input_ids, images=image_tensor, do_sample=temperature > 0.001, temperature=max(temperature, 0.01), max_new_tokens=max_tokens, use_cache=True, ) final_answer = tokenizer.decode( output_ids[0, input_ids.shape[1]:], skip_special_tokens=True ).strip() # Visualization image_with_bbox = draw_bounding_box(image, bbox) if bbox else image # Processing info processing_info = f"✓ Processed successfully | Bbox: {bbox if bbox else 'Not detected'}" return bbox_response, final_answer, image_with_bbox, processing_info except Exception as e: import traceback error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}" return error_msg, "", None, error_msg # ============================================================================= # Gradio Interface # ============================================================================= def create_demo(): """Create Gradio interface""" # Custom CSS for beautiful UI custom_css = """ .gradio-container { font-family: 'Inter', sans-serif; } .header { text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 10px; margin-bottom: 20px; } .info-box { background: #f0f7ff; border-left: 4px solid #3b82f6; padding: 15px; border-radius: 5px; margin: 10px 0; } .example-box { border: 2px solid #e5e7eb; border-radius: 8px; padding: 10px; margin: 5px 0; } .metric-card { background: white; border-radius: 8px; padding: 15px; box-shadow: 0 1px 3px rgba(0,0,0,0.1); margin: 10px 0; } """ with gr.Blocks( theme=gr.themes.Soft( primary_hue="indigo", secondary_hue="purple", ), css=custom_css, title="Visual-CoT Demo" ) as demo: # Header gr.HTML("""

🌋 Visual-CoT: Chain-of-Thought Reasoning

Advancing Multi-Modal Language Models with Visual Chain-of-Thought

📄 Paper (NeurIPS 2024 Spotlight) | 💻 GitHub | 🤗 Dataset

""") # Introduction gr.Markdown(""" ## 🎯 What is Visual-CoT? **Visual Chain-of-Thought (VisCoT)** enables AI models to: - 🎯 **Identify important regions** in images using bounding boxes - 💭 **Reason step-by-step** like humans (Chain-of-Thought) - 💡 **Answer questions** about visual content with interpretable explanations ### 📊 Dataset & Model - **438K** Q&A pairs with bounding box annotations - **13 diverse benchmarks** (DocVQA, GQA, TextVQA, etc.) - **LLaVA-1.5 based** architecture with CLIP ViT-L/14 """) with gr.Tabs(): # ============================================================ # Tab 1: Interactive Demo # ============================================================ with gr.Tab("🎨 Interactive Demo"): gr.Markdown(""" ### Try Visual-CoT with Your Own Images! Upload an image and ask a question. The model will: 1. **Detect** the region of interest (ROI) → Output bounding box 2. **Analyze** the ROI and full image → Generate answer """) with gr.Row(): with gr.Column(scale=1): # Input image_input = gr.Image( type="pil", label="📸 Upload Image", height=400, ) question_input = gr.Textbox( label="❓ Your Question", placeholder="Example: What is unusual about this image?", lines=3, ) with gr.Accordion("⚙️ Advanced Settings", open=False): temperature = gr.Slider( minimum=0.0, maximum=1.0, value=0.2, step=0.05, label="🌡️ Temperature", info="0 = Deterministic, 1 = Creative" ) max_tokens = gr.Slider( minimum=128, maximum=1024, value=512, step=64, label="📝 Max Output Tokens" ) submit_btn = gr.Button("🚀 Analyze Image", variant="primary", size="lg") clear_btn = gr.Button("🗑️ Clear", size="sm") with gr.Column(scale=1): # Output gr.Markdown("### 📤 Results") with gr.Group(): gr.Markdown("#### 🎯 Step 1: Region Detection") bbox_output = gr.Textbox( label="Detected Bounding Box", lines=2, show_copy_button=True, ) with gr.Group(): gr.Markdown("#### 💡 Step 2: Answer") answer_output = gr.Textbox( label="Final Answer", lines=6, show_copy_button=True, ) with gr.Group(): gr.Markdown("#### 🖼️ Visualization") image_output = gr.Image( label="Image with Bounding Box", type="pil", height=350, ) info_output = gr.Textbox( label="Processing Info", lines=1, visible=False, ) # Example images gr.Markdown("### 📋 Try These Examples") gr.Examples( examples=[ ["examples/extreme_ironing.jpg", "What is unusual about this image?"], ["examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"], ], inputs=[image_input, question_input], label="Click to load example", ) # Event handlers submit_btn.click( fn=generate_viscot_response, inputs=[image_input, question_input, temperature, max_tokens], outputs=[bbox_output, answer_output, image_output, info_output], ) clear_btn.click( fn=lambda: (None, "", "", "", None, ""), outputs=[image_input, question_input, bbox_output, answer_output, image_output, info_output], ) # ============================================================ # Tab 2: Benchmark Explorer # ============================================================ with gr.Tab("📊 Benchmark Explorer"): gr.Markdown(""" ### Explore Visual-CoT Benchmark Examples Select a benchmark dataset and browse annotated examples from our evaluation suite. These examples showcase the model's performance across diverse visual reasoning tasks. """) with gr.Row(): dataset_dropdown = gr.Dropdown( choices=BENCHMARK_DATASETS, value="gqa", label="🗂️ Select Benchmark Dataset", info="Choose from 13 diverse benchmarks" ) load_examples_btn = gr.Button("📥 Load Examples", variant="secondary") benchmark_gallery = gr.Gallery( label="Benchmark Examples", columns=3, height=400, object_fit="contain", ) benchmark_info = gr.Markdown(""" **Select a dataset and click "Load Examples" to view benchmark samples.** Available benchmarks: - **DocVQA**: Document visual question answering - **GQA**: Scene graph question answering - **TextVQA**: Text-based VQA - **Flickr30k**: Image captioning & grounding - **InfographicsVQA**: Infographic understanding - **OpenImages**: Object detection & description - And more... """) # Placeholder for benchmark loading (would need actual implementation) load_examples_btn.click( fn=lambda x: gr.Info(f"Loading {x} examples... (Feature coming soon!)"), inputs=[dataset_dropdown], outputs=None, ) # ============================================================ # Tab 3: About & Paper # ============================================================ with gr.Tab("📚 About"): gr.Markdown(""" ## 📄 Paper Information **Title:** Visual CoT: Advancing Multi-Modal Language Models with a Comprehensive Dataset and Benchmark for Chain-of-Thought Reasoning **Authors:** Hao Shao, Shengju Qian, Han Xiao, Guanglu Song, Zhuofan Zong, Letian Wang, Yu Liu, Hongsheng Li **Conference:** NeurIPS 2024 (Spotlight) 🎉 **Abstract:** We introduce Visual-CoT, a comprehensive dataset and benchmark for evaluating chain-of-thought reasoning in multi-modal language models. Our dataset comprises 438K question-answer pairs with intermediate bounding box annotations highlighting key regions essential for answering questions. We propose a multi-turn processing pipeline that dynamically focuses on visual inputs and provides interpretable reasoning steps. --- ## 🏗️ Model Architecture ``` ┌─────────────────────────────────────┐ │ Visual-CoT Pipeline │ ├─────────────────────────────────────┤ │ │ │ 📸 Image Input │ │ ↓ │ │ 🔍 CLIP ViT-L/14 (Vision Encoder) │ │ ↓ │ │ 🔗 MLP Projector (2-layer) │ │ ↓ │ │ 🧠 LLaMA/Vicuna (Language Model) │ │ ↓ │ │ ┌──────────────┐ │ │ │ Step 1: ROI │ → Bounding Box │ │ └──────────────┘ │ │ ↓ │ │ ┌──────────────┐ │ │ │ Step 2: QA │ → Final Answer │ │ └──────────────┘ │ │ │ └─────────────────────────────────────┘ ``` --- ## 📊 Key Results - **Detection Accuracy**: 75.3% (IoU > 0.5) - **Answer Accuracy**: 82.7% (GPT-3.5 evaluated) - **Benchmarks**: State-of-the-art on 10+ visual reasoning tasks - **Model Sizes**: 7B and 13B parameters - **Resolutions**: 224px and 336px --- ## 🔗 Resources - 📄 **Paper**: [arXiv:2403.16999](https://arxiv.org/abs/2403.16999) - 💻 **Code**: [GitHub](https://github.com/deepcs233/Visual-CoT) - 🤗 **Dataset**: [Hugging Face](https://huggingface.co/datasets/deepcs233/Visual-CoT) - 🌐 **Project Page**: [https://hao-shao.com/projects/viscot.html](https://hao-shao.com/projects/viscot.html) - 🎯 **Models**: - [VisCoT-7b-224](https://huggingface.co/deepcs233/VisCoT-7b-224) - [VisCoT-7b-336](https://huggingface.co/deepcs233/VisCoT-7b-336) - [VisCoT-13b-224](https://huggingface.co/deepcs233/VisCoT-13b-224) - [VisCoT-13b-336](https://huggingface.co/deepcs233/VisCoT-13b-336) --- ## 📜 Citation If you find our work useful, please cite: ```bibtex @article{shao2024visual, title={Visual CoT: Unleashing Chain-of-Thought Reasoning in Multi-Modal Language Models}, author={Shao, Hao and Qian, Shengju and Xiao, Han and Song, Guanglu and Zong, Zhuofan and Wang, Letian and Liu, Yu and Li, Hongsheng}, journal={arXiv preprint arXiv:2403.16999}, year={2024} } ``` --- ## ⚖️ License - **Code**: Apache License 2.0 - **Dataset**: Research use only - **Models**: Subject to base LLM license (LLaMA) --- ## 🙏 Acknowledgements This work is built upon: - [LLaVA](https://github.com/haotian-liu/LLaVA) - Base architecture - [Shikra](https://github.com/shikras/shikra) - Positional annotations - [Vicuna](https://github.com/lm-sys/FastChat) - Language model - [CLIP](https://github.com/openai/CLIP) - Vision encoder """) # Footer gr.Markdown(""" ---

🚀 Powered by Zero GPU on Hugging Face Spaces

Made with ❤️ by the Visual-CoT Team

""") return demo # ============================================================================= # Launch # ============================================================================= if __name__ == "__main__": demo = create_demo() demo.queue(max_size=20) # Enable queue for Zero GPU demo.launch()