import gradio as gr
import torch
from transformers import AutoProcessor, Idefics3ForConditionalGeneration
import logging

logger = logging.getLogger(__name__)

class SimpleVLMInterface:
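    """Minimal Gradio wrapper around SmolVLM-Instruct with a ChartQA adapter."""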
    def __init__(self):
        self.model = None
        self.processor = None
        self.initialize_model()

    def initialize_model(self):
        try:
            model_id = "HuggingFaceTB/SmolVLM-Instruct"
            self.model = Idefics3ForConditionalGeneration.from_pretrained(
                model_id,
                device_map="auto",
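                # bfloat16 assumes hardware with bf16 support;
                # swap in torch.float16 or torch.float32 otherwise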
                torch_dtype=torch.bfloat16
            )
            self.processor = AutoProcessor.from_pretrained(model_id)
            # Load custom adapter
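            # (a TRL SFT LoRA checkpoint saved at this local path;
            # load_adapter requires the `peft` package)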
            adapter_path = "smolvlm-instruct-trl-sft-ChartQA"
            self.model.load_adapter(adapter_path)
        except Exception as e:
            logger.error(f"Error initializing model: {e}")
            raise

    def generate_response(
        self,
        text_input,
        image=None,
        max_tokens=512,
        temperature=0.7,
        top_p=0.95
    ):
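        """Run one multimodal generation; returns decoded text or an error string."""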
        try:
            # Prepare the multimodal message format
            message_content = []
            
            # Add image content if provided
            if image is not None:
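                # Normalize to RGB so the processor receives a consistent format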
                if image.mode != 'RGB':
                    image = image.convert('RGB')
                message_content.append({
                    'type': 'image',
                    'image': image
                })

            # Add text content
            message_content.append({
                'type': 'text',
                'text': text_input
            })

            # Create the complete message structure
            message = {
                'role': 'user',
                'content': message_content
            }

            # Apply the chat template; it reads each content entry's 'type'
            # key to insert image placeholder tokens, while the PIL images
            # themselves are passed to the processor call below
            chat_input = self.processor.apply_chat_template(
                [message],  # the template expects a list of messages
                add_generation_prompt=True
            )

            # Collect the PIL images referenced in the message content
            # (an empty list collapses to None for text-only inputs)
            images = [
                part['image'] for part in message_content
                if part['type'] == 'image'
            ] or None

            # Prepare model inputs
            model_inputs = self.processor(
                text=chat_input,
                images=images,
                return_tensors="pt",
            ).to(self.model.device)

            # Generate a response; do_sample=True is required for
            # temperature and top_p to take effect
            generated_ids = self.model.generate(
                **model_inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True
            )

            # Trim the prompt tokens so only the newly generated text is decoded
            trimmed_generated_ids = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
            ]
            output_text = self.processor.batch_decode(
                trimmed_generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )[0]

            return output_text
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            return f"Error: {str(e)}"

def create_interface():
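    # Load the model once at startup; the single instance is shared across requests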
    vlm = SimpleVLMInterface()
    with gr.Blocks(title="Simple VLM Interface") as demo:
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload Image (optional)")
                text_input = gr.Textbox(label="Enter your text", lines=2)
                with gr.Row():
                    max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max tokens")
                    temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
                    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p")
                submit_btn = gr.Button("Generate Response")
                output_text = gr.Textbox(label="Response", lines=4)

        submit_btn.click(
            fn=vlm.generate_response,
            inputs=[text_input, image_input, max_tokens, temperature, top_p],
            outputs=output_text
        )

    return demo

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    demo = create_interface()
    demo.launch()