# Page header residue (author avatar caption, commit message, commit hash):
# szymmon's picture
# ui
# f565140
import gradio as gr
import torch
from transformers import AutoProcessor, Idefics3ForConditionalGeneration
import logging
logger = logging.getLogger(__name__)
class SimpleVLMInterface:
    """Single-turn chat wrapper around an Idefics3-family VLM plus an adapter.

    The base model, its processor and the adapter are loaded eagerly when the
    object is constructed, so instantiation is slow and may download weights.
    """

    # Base checkpoint passed to ``from_pretrained``.
    DEFAULT_MODEL_ID = "HuggingFaceTB/SmolVLM-Instruct"
    # Adapter passed to ``load_adapter``; presumably a local fine-tuned
    # (TRL SFT / PEFT) directory or Hub repo -- TODO confirm.
    DEFAULT_ADAPTER_PATH = "smolvlm-instruct-trl-sft-ChartQA"

    def __init__(self, *, model_id=DEFAULT_MODEL_ID, adapter_path=DEFAULT_ADAPTER_PATH):
        """Load the model, processor and adapter.

        Args:
            model_id: Hub id or local path of the base checkpoint.
            adapter_path: Adapter to load on top of the base model, or ``None``
                to run the base model without an adapter.
        """
        self.model_id = model_id
        self.adapter_path = adapter_path
        self.model = None
        self.processor = None
        self.initialize_model()

    def initialize_model(self):
        """Load ``self.model``, ``self.processor`` and, if configured, the adapter.

        Raises:
            Exception: any loading error is logged (with traceback) and re-raised.
        """
        try:
            self.model = Idefics3ForConditionalGeneration.from_pretrained(
                self.model_id,
                device_map="auto",  # let accelerate place weights on available device(s)
                torch_dtype=torch.bfloat16,
            )
            self.processor = AutoProcessor.from_pretrained(self.model_id)
            if self.adapter_path is not None:
                self.model.load_adapter(self.adapter_path)
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("Error initializing model: %s", e)
            raise

    @staticmethod
    def _build_messages(text_input, image=None):
        """Build a one-message chat list and the matching list of images.

        Args:
            text_input: The user's prompt text.
            image: Optional PIL-style image (needs ``.mode`` / ``.convert``).

        Returns:
            ``(messages, images)`` where ``messages`` is a list holding one
            ``{'role': 'user', 'content': [...]}`` dict and ``images`` is a list
            containing the (RGB-converted) image, or empty when none was given.
        """
        content = []
        images = []
        if image is not None:
            # Normalise e.g. RGBA / palette uploads to 3-channel RGB.
            if image.mode != 'RGB':
                image = image.convert('RGB')
            content.append({'type': 'image', 'image': image})
            images.append(image)
        content.append({'type': 'text', 'text': text_input})
        return [{'role': 'user', 'content': content}], images

    def generate_response(
        self,
        text_input,
        image=None,
        max_tokens=512,
        temperature=0.7,
        top_p=0.95
    ):
        """Generate a sampled reply to ``text_input`` (optionally about ``image``).

        Args:
            text_input: The user's prompt text.
            image: Optional PIL image to condition on.
            max_tokens: Maximum number of new tokens to generate.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling probability mass.

        Returns:
            The decoded reply with the prompt tokens stripped, or the string
            ``"Error: <message>"`` if anything fails (so the UI shows the error
            instead of crashing).
        """
        try:
            messages, images = self._build_messages(text_input, image)
            chat_input = self.processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
            )
            model_inputs = self.processor(
                text=chat_input,
                # The processor must get None (not an empty list) for text-only prompts.
                images=images or None,
                return_tensors="pt",
            ).to(self.model.device)
            generated_ids = self.model.generate(
                **model_inputs,
                # UI sliders may hand back non-int numbers; the token budget must be an int.
                max_new_tokens=int(max_tokens),
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
            )
            # generate() returns prompt + completion; keep only the new tokens.
            trimmed_generated_ids = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
            ]
            return self.processor.batch_decode(
                trimmed_generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )[0]
        except Exception as e:
            logger.exception("Error generating response: %s", e)
            return f"Error: {str(e)}"
def create_interface():
    """Build the Gradio Blocks UI backed by a freshly loaded ``SimpleVLMInterface``.

    Constructing the backend loads the model, so this call is slow.

    Returns:
        The ``gr.Blocks`` app; call ``.launch()`` on it to serve the UI.
    """
    backend = SimpleVLMInterface()

    with gr.Blocks(title="Simple VLM Interface") as demo:
        with gr.Row():
            with gr.Column():
                image_box = gr.Image(label="Upload Image (optional)", type="pil")
                prompt_box = gr.Textbox(lines=2, label="Enter your text")
                # Sampling controls, laid out side by side under the prompt.
                with gr.Row():
                    tokens_slider = gr.Slider(
                        label="Max tokens", minimum=1, maximum=2048, step=1, value=512
                    )
                    temp_slider = gr.Slider(
                        label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7
                    )
                    nucleus_slider = gr.Slider(
                        label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=0.95
                    )
                generate_btn = gr.Button("Generate Response")
                response_box = gr.Textbox(lines=4, label="Response")

        # Order must match generate_response(text_input, image, max_tokens, temperature, top_p).
        handler_inputs = [prompt_box, image_box, tokens_slider, temp_slider, nucleus_slider]
        generate_btn.click(
            fn=backend.generate_response,
            inputs=handler_inputs,
            outputs=response_box,
        )

    return demo
if __name__ == "__main__":
    # Building the UI loads the model; launch() then serves the app.
    demo = create_interface()
    demo.launch()