szymmon committed
Commit 7a9fb99 · 1 Parent(s): 62733d8
Files changed (1):
  1. app.py +38 -48
app.py CHANGED
@@ -5,7 +5,7 @@ import logging
 
 logger = logging.getLogger(__name__)
 
-class CustomModelChat:
+class SimpleVLMInterface:
     def __init__(self):
         self.model = None
         self.processor = None
@@ -21,58 +21,41 @@ class CustomModelChat:
             )
             self.processor = AutoProcessor.from_pretrained(model_id)
 
-            # Load your custom adapter
+            # Load custom adapter
            adapter_path = "smolvlm-instruct-trl-sft-ChartQA"
            self.model.load_adapter(adapter_path)
        except Exception as e:
            logger.error(f"Error initializing model: {e}")
            raise
 
-    def process_chat_history(self, history, system_message):
-        # Convert chat history to the format expected by the model
-        messages = [{"role": "system", "content": system_message}]
-
-        for user_msg, assistant_msg in history:
-            if user_msg:
-                messages.append({"role": "user", "content": user_msg})
-            if assistant_msg:
-                messages.append({"role": "assistant", "content": assistant_msg})
-
-        return messages
-
     def generate_response(
         self,
-        message,
-        history,
-        system_message,
+        text_input,
+        image=None,
         max_tokens=512,
         temperature=0.7,
-        top_p=0.95,
-        image=None
     ):
         try:
-            messages = self.process_chat_history(history, system_message)
-            messages.append({"role": "user", "content": message})
-
-            # Prepare the chat template
+            # Prepare the input text
+            messages = [{"role": "user", "content": text_input}]
             chat_input = self.processor.apply_chat_template(
-                messages[1:2], # Exclude system message
+                messages,
                 add_generation_prompt=True
             )
 
-            # Handle image input if provided
+            # Handle image input
             if image is not None:
                 if image.mode != 'RGB':
                     image = image.convert('RGB')
-                # Create a list of None for previous turns and add the current image
-                image_inputs = [None] * (len(messages) - 1) + [image]
+                image_inputs = [image]
             else:
                 image_inputs = None
 
             # Prepare model inputs
             model_inputs = self.processor(
                 text=chat_input,
-                images=image_inputs if image_inputs else None,
+                images=image_inputs,
                 return_tensors="pt",
             ).to(self.model.device)
 
@@ -85,41 +68,48 @@ class CustomModelChat:
                 do_sample=True
             )
 
-            # Trim and decode the response
+            # Process output
             trimmed_generated_ids = [
                 out_ids[len(in_ids):] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
             ]
-
             output_text = self.processor.batch_decode(
                 trimmed_generated_ids,
                 skip_special_tokens=True,
                 clean_up_tokenization_spaces=False
-            )
+            )[0]
 
-            yield output_text[0]
+            return output_text
 
         except Exception as e:
             logger.error(f"Error generating response: {e}")
-            yield f"Error: {str(e)}"
+            return f"Error: {str(e)}"
 
-def create_chat_interface():
-    chat_model = CustomModelChat()
+def create_interface():
+    vlm = SimpleVLMInterface()
 
-    demo = gr.ChatInterface(
-        chat_model.generate_response,
-        additional_inputs=[
-            gr.Textbox(value="You are a helpful assistant.", label="System message"),
-            gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-            gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-            gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
-            gr.Image(type="pil", label="Upload Image (optional)")
-        ],
-        title="Custom SmolVLM Chat",
-        description="Chat interface using custom fine-tuned SmolVLM model"
-    )
+    with gr.Blocks(title="Simple VLM Interface") as demo:
+        with gr.Row():
+            with gr.Column():
+                image_input = gr.Image(type="pil", label="Upload Image (optional)")
+                text_input = gr.Textbox(label="Enter your text", lines=2)
+
+                with gr.Row():
+                    max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max tokens")
+                    temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
+                    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p")
+
+                submit_btn = gr.Button("Generate Response")
+
+            output_text = gr.Textbox(label="Response", lines=4)
+
+        submit_btn.click(
+            fn=vlm.generate_response,
+            inputs=[text_input, image_input, max_tokens, temperature, top_p],
+            outputs=output_text
+        )
 
     return demo
 
 if __name__ == "__main__":
-    demo = create_chat_interface()
+    demo = create_interface()
     demo.launch()
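For a quick sanity check, the refactored class can be exercised without the Gradio UI. A minimal sketch, assuming the model-loading code elided between the first two hunks runs at construction time, and that chart.png is a hypothetical local file:

from PIL import Image
from app import SimpleVLMInterface  # assumes this script sits next to app.py

vlm = SimpleVLMInterface()       # assumption: the elided initializer populates model/processor
image = Image.open("chart.png")  # hypothetical example input

# New single-turn signature: text first, image as an optional keyword,
# sampling knobs matching the Gradio sliders. The method now returns a
# plain string (the old version yielded chunks for gr.ChatInterface).
reply = vlm.generate_response(
    "What is the highest value in this chart?",
    image=image,
    max_tokens=256,
    temperature=0.7,
    top_p=0.95,
)
print(reply)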
 
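One behavioral note on the new single-message path: when an image is supplied, SmolVLM-style processors expect an image placeholder inside the message itself so that apply_chat_template can emit the image token(s); passing a plain string as content and handing images= to the processor separately may leave the rendered prompt without a slot for the pixels. A hedged sketch of the pattern from the SmolVLM-Instruct model card (model id taken from the card; the question text is illustrative):

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")

# Content is a list of typed parts; the {"type": "image"} entry marks
# where the image token(s) are inserted when the template is rendered.
messages = [{
    "role": "user",
    "content": [
        {"type": "image"},
        {"type": "text", "text": "What trend does this chart show?"},
    ],
}]

prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
# The rendered prompt now carries the image placeholder; the processor
# then pairs it with the actual pixels:
# inputs = processor(text=prompt, images=[image], return_tensors="pt")

If the deployed app misbehaves on image inputs, the rendered prompt is the first thing worth inspecting.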