import logging

import gradio as gr
import torch
from transformers import AutoProcessor, Idefics3ForConditionalGeneration

logger = logging.getLogger(__name__)


class SimpleVLMInterface:
    def __init__(self):
        self.model = None
        self.processor = None
        self.initialize_model()

    def initialize_model(self):
        try:
            model_id = "HuggingFaceTB/SmolVLM-Instruct"
            self.model = Idefics3ForConditionalGeneration.from_pretrained(
                model_id,
                device_map="auto",
                torch_dtype=torch.bfloat16,
            )
            self.processor = AutoProcessor.from_pretrained(model_id)

            # Load the fine-tuned ChartQA adapter (requires `peft`)
            adapter_path = "smolvlm-instruct-trl-sft-ChartQA"
            self.model.load_adapter(adapter_path)
        except Exception as e:
            logger.error(f"Error initializing model: {e}")
            raise

    def generate_response(
        self, text_input, image=None, max_tokens=512, temperature=0.7, top_p=0.95
    ):
        try:
            # Build the multimodal message content
            message_content = []

            # Add image content if provided
            if image is not None:
                if image.mode != "RGB":
                    image = image.convert("RGB")
                message_content.append({"type": "image", "image": image})

            # Add text content
            message_content.append({"type": "text", "text": text_input})

            # Create the complete message structure
            messages = {"role": "user", "content": message_content}

            # Apply the chat template; wrap in a list since it expects a
            # sequence of messages
            chat_input = self.processor.apply_chat_template(
                [messages], add_generation_prompt=True
            )

            # Prepare model inputs
            model_inputs = self.processor(
                text=chat_input,
                images=[image] if image is not None else None,
                return_tensors="pt",
            ).to(self.model.device)

            # Generate a response; Gradio sliders return floats, so cast
            # max_tokens to the int that generate() expects
            generated_ids = self.model.generate(
                **model_inputs,
                max_new_tokens=int(max_tokens),
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
            )

            # Strip the prompt tokens so only newly generated text is decoded
            trimmed_generated_ids = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
            ]
            output_text = self.processor.batch_decode(
                trimmed_generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )[0]
            return output_text
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            return f"Error: {str(e)}"


def create_interface():
    vlm = SimpleVLMInterface()

    with gr.Blocks(title="Simple VLM Interface") as demo:
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload Image (optional)")
                text_input = gr.Textbox(label="Enter your text", lines=2)
                with gr.Row():
                    max_tokens = gr.Slider(
                        minimum=1, maximum=2048, value=512, step=1, label="Max tokens"
                    )
                    temperature = gr.Slider(
                        minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"
                    )
                    top_p = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"
                    )
                submit_btn = gr.Button("Generate Response")
                output_text = gr.Textbox(label="Response", lines=4)

        submit_btn.click(
            fn=vlm.generate_response,
            inputs=[text_input, image_input, max_tokens, temperature, top_p],
            outputs=output_text,
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
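
# Usage note (a minimal sketch, not part of the script above): the class can
# also be exercised directly without launching the Gradio UI. This assumes a
# local image file such as "chart.png" exists (the filename is hypothetical)
# and that the adapter weights loaded successfully:
#
#     from PIL import Image
#
#     vlm = SimpleVLMInterface()
#     answer = vlm.generate_response(
#         "What is the highest value in this chart?",
#         image=Image.open("chart.png"),
#     )
#     print(answer)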