import logging

import gradio as gr
import torch
from transformers import AutoProcessor, Idefics3ForConditionalGeneration

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class CustomModelChat:
    def __init__(self):
        self.model = None
        self.processor = None
        self.initialize_model()

    def initialize_model(self):
        try:
            model_id = "HuggingFaceTB/SmolVLM-Instruct"
            self.model = Idefics3ForConditionalGeneration.from_pretrained(
                model_id,
                device_map="auto",
                torch_dtype=torch.bfloat16,
            )
            self.processor = AutoProcessor.from_pretrained(model_id)

            # Load your custom adapter (requires `peft` to be installed)
            adapter_path = "smolVLM_Essay_Knowledge_Distillation/smolvlm-instruct-trl-sft-ChartQA"
            self.model.load_adapter(adapter_path)
        except Exception as e:
            logger.error(f"Error initializing model: {e}")
            raise

    def process_chat_history(self, history, system_message):
        # Convert Gradio's (user, assistant) pairs into the format expected by
        # the processor's chat template: content must be a list of typed parts,
        # not a bare string.
        messages = [{"role": "system", "content": [{"type": "text", "text": system_message}]}]
        for user_msg, assistant_msg in history:
            if user_msg:
                messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
            if assistant_msg:
                messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]})
        return messages

    def generate_response(
        self,
        message,
        history,
        system_message,
        max_tokens=512,
        temperature=0.7,
        top_p=0.95,
        image=None,
    ):
        try:
            messages = self.process_chat_history(history, system_message)

            # Build the current turn. If an image was uploaded, prepend an
            # image placeholder so the chat template inserts the image tokens
            # the processor expects.
            content = [{"type": "text", "text": message}]
            if image is not None:
                if image.mode != "RGB":
                    image = image.convert("RGB")
                content = [{"type": "image"}] + content
            messages.append({"role": "user", "content": content})

            # Prepare the chat template
            chat_input = self.processor.apply_chat_template(
                messages[1:],  # exclude the system message
                add_generation_prompt=True,
            )

            # Prepare model inputs
            model_inputs = self.processor(
                text=chat_input,
                images=[image] if image is not None else None,
                return_tensors="pt",
            ).to(self.model.device)

            # Generate the response
            generated_ids = self.model.generate(
                **model_inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
            )

            # Strip the prompt tokens and decode only the newly generated ones
            trimmed_ids = generated_ids[:, model_inputs.input_ids.shape[1]:]
            response = self.processor.batch_decode(
                trimmed_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )[0]
            yield response
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            yield f"Error: {e}"


def create_chat_interface():
    chat_model = CustomModelChat()

    demo = gr.ChatInterface(
        chat_model.generate_response,
        additional_inputs=[
            gr.Textbox(value="You are a helpful assistant.", label="System message"),
            gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
            gr.Image(type="pil", label="Upload Image (optional)"),
        ],
        title="Custom SmolVLM Chat",
        description="Chat interface using a custom fine-tuned SmolVLM model",
    )
    return demo


if __name__ == "__main__":
    demo = create_chat_interface()
    demo.launch()