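"""Gradio chat app for a fine-tuned SmolVLM-Instruct (Idefics3) model.

Loads the base HuggingFaceTB/SmolVLM-Instruct checkpoint, attaches the custom
adapter referenced by `adapter_path`, and serves a gr.ChatInterface that takes
a text prompt plus an optional image.
"""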
import gradio as gr
import torch
from transformers import AutoProcessor, Idefics3ForConditionalGeneration
import logging
logger = logging.getLogger(__name__)

class CustomModelChat:
    def __init__(self):
        self.model = None
        self.processor = None
        self.initialize_model()

    def initialize_model(self):
        try:
            model_id = "HuggingFaceTB/SmolVLM-Instruct"
            self.model = Idefics3ForConditionalGeneration.from_pretrained(
                model_id,
                device_map="auto",
                torch_dtype=torch.bfloat16
            )
            self.processor = AutoProcessor.from_pretrained(model_id)
            # Load your custom adapter
            adapter_path = "smolVLM_Essay_Knowledge_Distillation/smolvlm-instruct-trl-sft-ChartQA"
            self.model.load_adapter(adapter_path)
        except Exception as e:
            logger.error(f"Error initializing model: {e}")
            raise

    def process_chat_history(self, history, system_message):
        # Convert chat history to the format expected by the model's chat template:
        # each message's content is a list of typed parts, not a plain string.
        messages = [{"role": "system", "content": [{"type": "text", "text": system_message}]}]
        for user_msg, assistant_msg in history:
            if user_msg:
                messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
            if assistant_msg:
                messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]})
        return messages

    def generate_response(
        self,
        message,
        history,
        system_message,
        max_tokens=512,
        temperature=0.7,
        top_p=0.95,
        image=None
    ):
        try:
            messages = self.process_chat_history(history, system_message)
            # Build the new user turn; include an image part so the chat template
            # emits a matching <image> token when an image is attached.
            user_content = []
            if image is not None:
                if image.mode != 'RGB':
                    image = image.convert('RGB')
                user_content.append({"type": "image"})
            user_content.append({"type": "text", "text": message})
            messages.append({"role": "user", "content": user_content})
            # Prepare the chat template
            chat_input = self.processor.apply_chat_template(
                messages[1:],  # Exclude system message
                add_generation_prompt=True
            )
            # Prepare model inputs (pass the image only if one was provided)
            model_inputs = self.processor(
                text=chat_input,
                images=[[image]] if image is not None else None,
                return_tensors="pt",
            ).to(self.model.device)
            # Generate response
            generated_ids = self.model.generate(
                **model_inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True
            )
            # Trim off the prompt tokens and decode only the newly generated ones
            trimmed_ids = generated_ids[:, model_inputs.input_ids.shape[1]:]
            response = self.processor.batch_decode(
                trimmed_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )[0]
            yield response
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            yield f"Error: {str(e)}"

def create_chat_interface():
    chat_model = CustomModelChat()
    demo = gr.ChatInterface(
        chat_model.generate_response,
        # The order of additional_inputs must match the extra parameters of
        # generate_response: system_message, max_tokens, temperature, top_p, image.
        additional_inputs=[
            gr.Textbox(value="You are a helpful assistant.", label="System message"),
            gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
            gr.Image(type="pil", label="Upload Image (optional)")
        ],
        title="Custom SmolVLM Chat",
        description="Chat interface using custom fine-tuned SmolVLM model"
    )
    return demo

if __name__ == "__main__":
    demo = create_chat_interface()
    demo.launch()