ubden committed on
Commit 3a15aeb · verified · 1 Parent(s): 1049a4d

Update handler.py

Files changed (1)
  1. handler.py +539 -244
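
For reference, below is a sketch of a request payload the updated handler is written to accept. The field names mirror the parsing logic in __call__ in the diff that follows (query/text/prompt for the question, image/image_url/image_base64 for the ECG image, plus the generation parameters and the new enable_turkish_commentary flag); the endpoint URL, token, and image URL are placeholders, not values from this repository.

import requests

# Hypothetical request to an Inference Endpoint running this handler.
# Endpoint URL, token, and image URL are placeholders.
payload = {
    "inputs": {
        "query": "Interpret this ECG and describe any abnormalities.",
        "image_url": "https://example.com/ecg.png",  # or "image_base64": "<...>"
    },
    "parameters": {
        "max_new_tokens": 1024,        # the handler caps this at 8192
        "temperature": 0.05,           # demo-style default used by the handler
        "top_p": 1.0,
        "do_sample": True,
        "repetition_penalty": 1.0,
        "enable_turkish_commentary": False,  # True requests DeepSeek commentary if configured
    },
}

resp = requests.post(
    "https://<your-endpoint>.endpoints.huggingface.cloud",
    headers={"Authorization": "Bearer <HF_TOKEN>", "Content-Type": "application/json"},
    json=payload,
    timeout=120,
)
# The handler returns a one-element list with a "generated_text" field.
print(resp.json()[0]["generated_text"])
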
handler.py CHANGED
@@ -1,245 +1,540 @@
1
- """
2
- PULSE-7B Enhanced Handler
3
- Ubden® Team - Edited by https://github.com/ck-cankurt
4
- Support: Text, Image URLs, and Base64 encoded images
5
- """
6
-
7
- import torch
8
- from typing import Dict, List, Any
9
- import base64
10
- from io import BytesIO
11
- from PIL import Image
12
- import requests
13
-
14
-
15
- class EndpointHandler:
16
- def __init__(self, path=""):
17
- """
18
- Hey there! Let's get this PULSE-7B model up and running.
19
- We'll load it from the HuggingFace hub directly, so no worries about local files.
20
-
21
- Args:
22
- path: Model directory path (we actually ignore this and load from HF hub)
23
- """
24
- print("🚀 Starting up PULSE-7B handler...")
25
- print("📝 Enhanced by Ubden® Team - github.com/ck-cankurt")
26
-
27
- # Let's see what hardware we're working with
28
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
- print(f"🖥️ Running on: {self.device}")
30
-
31
- try:
32
- # First attempt - using pipeline (easiest and most stable way)
33
- from transformers import pipeline
34
-
35
- print("📦 Fetching model from HuggingFace Hub...")
36
- self.pipe = pipeline(
37
- "text-generation",
38
- model="PULSE-ECG/PULSE-7B",
39
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
40
- device=0 if torch.cuda.is_available() else -1,
41
- trust_remote_code=True,
42
- model_kwargs={
43
- "low_cpu_mem_usage": True,
44
- "use_safetensors": True
45
- }
46
- )
47
- print("✅ Model loaded successfully via pipeline!")
48
-
49
- except Exception as e:
50
- print(f"⚠️ Pipeline didn't work out: {e}")
51
- print("🔄 Let me try a different approach...")
52
-
53
- try:
54
- # Plan B - load model and tokenizer separately
55
- from transformers import AutoTokenizer, LlamaForCausalLM
56
-
57
- # Get the tokenizer ready
58
- print("📖 Setting up tokenizer...")
59
- self.tokenizer = AutoTokenizer.from_pretrained(
60
- "PULSE-ECG/PULSE-7B",
61
- trust_remote_code=True
62
- )
63
-
64
- # Load the model as Llama (it works, trust me!)
65
- print("🧠 Loading the model as Llama...")
66
- self.model = LlamaForCausalLM.from_pretrained(
67
- "PULSE-ECG/PULSE-7B",
68
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
69
- device_map="auto",
70
- low_cpu_mem_usage=True,
71
- trust_remote_code=True
72
- )
73
-
74
- # Quick fix for padding token if it's missing
75
- if self.tokenizer.pad_token is None:
76
- self.tokenizer.pad_token = self.tokenizer.eos_token
77
- self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
78
-
79
- self.model.eval()
80
- self.use_pipeline = False
81
- print("✅ Model loaded successfully via direct loading!")
82
-
83
- except Exception as e2:
84
- print(f"😓 That didn't work either: {e2}")
85
- # If all else fails, we'll handle it gracefully
86
- self.pipe = None
87
- self.model = None
88
- self.tokenizer = None
89
- self.use_pipeline = None
90
- else:
91
- self.use_pipeline = True
92
-
93
- def process_image_input(self, image_input):
94
- """
95
- Handle both URL and base64 image inputs like a champ!
96
-
97
- Args:
98
- image_input: Can be a URL string or base64 encoded image
99
-
100
- Returns:
101
- PIL Image object or None if something goes wrong
102
- """
103
- try:
104
- # Check if it's a URL (starts with http/https)
105
- if isinstance(image_input, str) and (image_input.startswith('http://') or image_input.startswith('https://')):
106
- print(f"🌐 Fetching image from URL: {image_input[:50]}...")
107
- response = requests.get(image_input, timeout=10)
108
- response.raise_for_status()
109
- image = Image.open(BytesIO(response.content)).convert('RGB')
110
- print(" Image downloaded successfully!")
111
- return image
112
-
113
- # Must be base64 then
114
- elif isinstance(image_input, str):
115
- print("🔍 Decoding base64 image...")
116
- # Remove the data URL prefix if it exists
117
- if "base64," in image_input:
118
- image_input = image_input.split("base64,")[1]
119
-
120
- image_data = base64.b64decode(image_input)
121
- image = Image.open(BytesIO(image_data)).convert('RGB')
122
- print("✅ Image decoded successfully!")
123
- return image
124
-
125
- except Exception as e:
126
- print(f"❌ Couldn't process the image: {e}")
127
- return None
128
-
129
- return None
130
-
131
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
132
- """
133
- Main processing function - where the magic happens!
134
-
135
- Args:
136
- data: Input data with 'inputs' and optional 'parameters'
137
-
138
- Returns:
139
- List with the generated response
140
- """
141
- # Quick check - is our model ready?
142
- if self.use_pipeline is None:
143
- return [{
144
- "generated_text": "Oops! Model couldn't load properly. Please check the deployment settings.",
145
- "error": "Model initialization failed",
146
- "handler": "Ubden® Team Enhanced Handler"
147
- }]
148
-
149
- try:
150
- # Parse the inputs - flexible format support
151
- inputs = data.get("inputs", "")
152
- text = ""
153
- image = None
154
-
155
- if isinstance(inputs, dict):
156
- # Dictionary input - check for text and image
157
- text = inputs.get("text", inputs.get("prompt", str(inputs)))
158
-
159
- # Check for image in various formats
160
- image_input = inputs.get("image", inputs.get("image_url", inputs.get("image_base64", None)))
161
- if image_input:
162
- image = self.process_image_input(image_input)
163
- if image:
164
- # For now, we'll add a note about the image since we're text-only
165
- text = f"[Image provided - {image.size[0]}x{image.size[1]} pixels] {text}"
166
- else:
167
- # Simple string input
168
- text = str(inputs)
169
-
170
- if not text:
171
- return [{"generated_text": "Hey, I need some text to work with! Please provide an input."}]
172
-
173
- # Get generation parameters with sensible defaults
174
- parameters = data.get("parameters", {})
175
- max_new_tokens = min(parameters.get("max_new_tokens", 256), 1024)
176
- temperature = parameters.get("temperature", 0.7)
177
- top_p = parameters.get("top_p", 0.95)
178
- do_sample = parameters.get("do_sample", True)
179
- repetition_penalty = parameters.get("repetition_penalty", 1.0)
180
-
181
- # Using pipeline? Let's go!
182
- if self.use_pipeline:
183
- result = self.pipe(
184
- text,
185
- max_new_tokens=max_new_tokens,
186
- temperature=temperature,
187
- top_p=top_p,
188
- do_sample=do_sample,
189
- repetition_penalty=repetition_penalty,
190
- return_full_text=False # Just the new stuff, not the input
191
- )
192
-
193
- # Pipeline returns a list, let's handle it
194
- if isinstance(result, list) and len(result) > 0:
195
- return [{"generated_text": result[0].get("generated_text", "")}]
196
- else:
197
- return [{"generated_text": str(result)}]
198
-
199
- # Manual generation mode
200
- else:
201
- # Tokenize the input
202
- encoded = self.tokenizer(
203
- text,
204
- return_tensors="pt",
205
- truncation=True,
206
- max_length=2048
207
- )
208
-
209
- input_ids = encoded["input_ids"].to(self.device)
210
- attention_mask = encoded.get("attention_mask")
211
- if attention_mask is not None:
212
- attention_mask = attention_mask.to(self.device)
213
-
214
- # Generate the response
215
- with torch.no_grad():
216
- outputs = self.model.generate(
217
- input_ids,
218
- attention_mask=attention_mask,
219
- max_new_tokens=max_new_tokens,
220
- temperature=temperature,
221
- top_p=top_p,
222
- do_sample=do_sample,
223
- repetition_penalty=repetition_penalty,
224
- pad_token_id=self.tokenizer.pad_token_id,
225
- eos_token_id=self.tokenizer.eos_token_id
226
- )
227
-
228
- # Decode only the new tokens (not the input)
229
- generated_ids = outputs[0][input_ids.shape[-1]:]
230
- generated_text = self.tokenizer.decode(
231
- generated_ids,
232
- skip_special_tokens=True,
233
- clean_up_tokenization_spaces=True
234
- )
235
-
236
- return [{"generated_text": generated_text}]
237
-
238
- except Exception as e:
239
- error_msg = f"Something went wrong during generation: {str(e)}"
240
- print(f"❌ {error_msg}")
241
- return [{
242
- "generated_text": "",
243
- "error": error_msg,
244
- "handler": "Ubden® Team Enhanced Handler"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  }]
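
The rewritten handler below strips a data-URL prefix and surrounding whitespace from base64 input before decoding it. A minimal client-side sketch for producing the image_base64 field, assuming a local file named ecg.png (a placeholder):

import base64

# Read a local ECG image and encode it for the "image_base64" field.
with open("ecg.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

inputs = {
    "query": "What rhythm does this ECG show?",
    # A data-URL prefix is also fine; the handler keeps only what follows "base64,".
    "image_base64": f"data:image/png;base64,{image_b64}",
}
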
 
1
+ """
2
+ PULSE-7B Enhanced Handler
3
+ Ubden® Team - Edited by https://github.com/ck-cankurt
4
+ Support: Text, Image URLs, and Base64 encoded images
5
+ """
6
+
7
+ import torch
8
+ from typing import Dict, List, Any
9
+ import base64
10
+ from io import BytesIO
11
+ from PIL import Image
12
+ import requests
13
+ import time
14
+
15
+ # Import utilities if available
16
+ try:
17
+ from utils import (
18
+ performance_monitor,
19
+ validate_image_input,
20
+ sanitize_parameters,
21
+ get_system_info,
22
+ create_health_check,
23
+ deepseek_client
24
+ )
25
+ UTILS_AVAILABLE = True
26
+ except ImportError:
27
+ UTILS_AVAILABLE = False
28
+ deepseek_client = None
29
+ print("⚠️ Utils module not found - performance monitoring and DeepSeek integration disabled")
30
+
31
+ # Try to import LLaVA modules for proper conversation handling
32
+ try:
33
+ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
34
+ from llava.conversation import conv_templates, SeparatorStyle
35
+ from llava.mm_utils import tokenizer_image_token, process_images, KeywordsStoppingCriteria
36
+ LLAVA_AVAILABLE = True
37
+ print("✅ LLaVA modules imported successfully")
38
+ except ImportError:
39
+ LLAVA_AVAILABLE = False
40
+ print("⚠️ LLaVA modules not available - using basic text processing")
41
+
42
+
43
+ class EndpointHandler:
44
+ def __init__(self, path=""):
45
+ """
46
+ Hey there! Let's get this PULSE-7B model up and running.
47
+ We'll load it from the HuggingFace hub directly, so no worries about local files.
48
+
49
+ Args:
50
+ path: Model directory path (we actually ignore this and load from HF hub)
51
+ """
52
+ print("🚀 Starting up PULSE-7B handler...")
53
+ print("📝 Enhanced by Ubden® Team - github.com/ck-cankurt")
54
+ import sys
55
+ print(f"🔧 Python version: {sys.version}")
56
+ print(f"🔧 PyTorch version: {torch.__version__}")
57
+
58
+ # Check transformers version
59
+ try:
60
+ import transformers
61
+ print(f"🔧 Transformers version: {transformers.__version__}")
62
+
63
+ # PULSE LLaVA works with transformers==4.37.2
64
+ if transformers.__version__ == "4.37.2":
65
+ print(" Using PULSE LLaVA compatible version (4.37.2)")
66
+ elif "dev" in transformers.__version__ or "git" in str(transformers.__version__):
67
+ print("⚠️ Using development version - may conflict with PULSE LLaVA")
68
+ else:
69
+ print("⚠️ Using different version - PULSE LLaVA prefers 4.37.2")
70
+ except Exception as e:
71
+ print(f"❌ Error checking transformers version: {e}")
72
+
73
+ print(f"🔧 CUDA available: {torch.cuda.is_available()}")
74
+ if torch.cuda.is_available():
75
+ print(f"🔧 CUDA device: {torch.cuda.get_device_name(0)}")
76
+
77
+ # Let's see what hardware we're working with
78
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
79
+ print(f"🖥️ Running on: {self.device}")
80
+
81
+ try:
82
+ # First attempt - PULSE demo's exact approach
83
+ if LLAVA_AVAILABLE:
84
+ print("📦 Using PULSE demo's load_pretrained_model approach...")
85
+ from llava.model.builder import load_pretrained_model
86
+ from llava.mm_utils import get_model_name_from_path
87
+
88
+ model_path = "PULSE-ECG/PULSE-7B"
89
+ model_name = get_model_name_from_path(model_path)
90
+
91
+ self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
92
+ model_path=model_path,
93
+ model_base=None,
94
+ model_name=model_name,
95
+ load_8bit=False,
96
+ load_4bit=False
97
+ )
98
+
99
+ # Move model to device like demo
100
+ self.model = self.model.to(self.device)
101
+ self.use_pipeline = False
102
+ print("✅ Model loaded successfully with PULSE demo's approach!")
103
+ print(f"📸 Image processor: {type(self.image_processor).__name__}")
104
+
105
+ else:
106
+ raise ImportError("LLaVA modules not available")
107
+
108
+ except Exception as e:
109
+ print(f"⚠️ PULSE demo approach failed: {e}")
110
+ print("🔄 Falling back to pipeline...")
111
+
112
+ try:
113
+ # Fallback - using pipeline
114
+ from transformers import pipeline
115
+
116
+ print("📦 Fetching model from HuggingFace Hub...")
117
+ self.pipe = pipeline(
118
+ "text-generation",
119
+ model="PULSE-ECG/PULSE-7B",
120
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
121
+ device=0 if torch.cuda.is_available() else -1,
122
+ trust_remote_code=True,
123
+ model_kwargs={
124
+ "low_cpu_mem_usage": True,
125
+ "use_safetensors": True
126
+ }
127
+ )
128
+ self.use_pipeline = True
129
+ self.image_processor = None
130
+ print("✅ Model loaded successfully via pipeline!")
131
+
132
+ except Exception as e2:
133
+ print(f"😓 Pipeline also failed: {e2}")
134
+
135
+ try:
136
+ # Last resort - manual loading
137
+ from transformers import AutoTokenizer, LlamaForCausalLM
138
+
139
+ print("📖 Manual loading as last resort...")
140
+ self.tokenizer = AutoTokenizer.from_pretrained(
141
+ "PULSE-ECG/PULSE-7B",
142
+ trust_remote_code=True
143
+ )
144
+
145
+ self.model = LlamaForCausalLM.from_pretrained(
146
+ "PULSE-ECG/PULSE-7B",
147
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
148
+ device_map="auto",
149
+ low_cpu_mem_usage=True,
150
+ trust_remote_code=True
151
+ )
152
+
153
+ if self.tokenizer.pad_token is None:
154
+ self.tokenizer.pad_token = self.tokenizer.eos_token
155
+ self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
156
+
157
+ self.model.eval()
158
+ self.use_pipeline = False
159
+ self.image_processor = None
160
+ print(" Model loaded manually!")
161
+
162
+ except Exception as e3:
163
+ print(f"😓 All approaches failed: {e3}")
164
+ self.pipe = None
165
+ self.model = None
166
+ self.tokenizer = None
167
+ self.image_processor = None
168
+ self.use_pipeline = None
169
+
170
+ # Final status report
171
+ print("\n🔍 Model Loading Status Report:")
172
+ print(f" - use_pipeline: {self.use_pipeline}")
173
+ print(f" - model: {'✅ Loaded' if hasattr(self, 'model') and self.model is not None else '❌ None'}")
174
+ print(f" - tokenizer: {'✅ Loaded' if hasattr(self, 'tokenizer') and self.tokenizer is not None else '❌ None'}")
175
+ print(f" - image_processor: {'✅ Loaded' if hasattr(self, 'image_processor') and self.image_processor is not None else '❌ None'}")
176
+ print(f" - pipe: {'✅ Loaded' if hasattr(self, 'pipe') and self.pipe is not None else '❌ None'}")
177
+
178
+ # Check if any model component loaded successfully
179
+ has_model = hasattr(self, 'model') and self.model is not None
180
+ has_tokenizer = hasattr(self, 'tokenizer') and self.tokenizer is not None
181
+ has_pipe = hasattr(self, 'pipe') and self.pipe is not None
182
+ has_image_processor = hasattr(self, 'image_processor') and self.image_processor is not None
183
+
184
+ if not (has_model or has_tokenizer or has_pipe):
185
+ print("💥 CRITICAL: No model components loaded successfully!")
186
+ else:
187
+ print("✅ At least one model component loaded successfully")
188
+ if has_image_processor:
189
+ print("🖼️ Vision capabilities available!")
190
+ else:
191
+ print("⚠️ No image processor - text-only mode")
192
+
193
+ def is_valid_image_format(self, filename_or_url):
194
+ """Validate image format like PULSE demo"""
195
+ # Demo's supported formats
196
+ image_extensions = ["jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "heic", "heif", "jfif", "svg", "eps", "raw"]
197
+
198
+ if filename_or_url.startswith(('http://', 'https://')):
199
+ # For URLs, check the extension or content-type
200
+ ext = filename_or_url.split('.')[-1].split('?')[0].lower()
201
+ return ext in image_extensions
202
+ else:
203
+ # For base64 or local files
204
+ return True # Base64 will be validated during decode
205
+
206
+ def process_image_input(self, image_input):
207
+ """
208
+ Handle both URL and base64 image inputs exactly like PULSE demo
209
+
210
+ Args:
211
+ image_input: Can be a URL string or base64 encoded image
212
+
213
+ Returns:
214
+ PIL Image object or None if something goes wrong
215
+ """
216
+ try:
217
+ # Check if it's a URL (starts with http/https)
218
+ if isinstance(image_input, str) and (image_input.startswith('http://') or image_input.startswith('https://')):
219
+ print(f"🌐 Fetching image from URL: {image_input[:50]}...")
220
+
221
+ # Validate format like demo
222
+ if not self.is_valid_image_format(image_input):
223
+ print("❌ Invalid image format in URL")
224
+ return None
225
+
226
+ # Demo's exact image loading approach
227
+ response = requests.get(image_input, timeout=15)
228
+ if response.status_code == 200:
229
+ image = Image.open(BytesIO(response.content)).convert("RGB")
230
+ print(f"✅ Image downloaded successfully! Size: {image.size}")
231
+ return image
232
+ else:
233
+ print(f"❌ Failed to load image: status {response.status_code}")
234
+ return None
235
+
236
+ # Must be base64 then
237
+ elif isinstance(image_input, str):
238
+ print("🔍 Decoding base64 image...")
239
+
240
+ # Remove the data URL prefix if it exists
241
+ base64_data = image_input
242
+ if "base64," in image_input:
243
+ base64_data = image_input.split("base64,")[1]
244
+
245
+ # Clean and validate base64 data
246
+ base64_data = base64_data.strip().replace('\n', '').replace('\r', '').replace(' ', '')
247
+
248
+ try:
249
+ image_data = base64.b64decode(base64_data)
250
+ image = Image.open(BytesIO(image_data)).convert('RGB')
251
+ print(f"✅ Base64 image decoded successfully! Size: {image.size}")
252
+ return image
253
+ except Exception as decode_error:
254
+ print(f"❌ Base64 decode error: {decode_error}")
255
+ return None
256
+
257
+ except Exception as e:
258
+ print(f"❌ Couldn't process the image: {e}")
259
+ return None
260
+
261
+ return None
262
+
263
+ def add_turkish_commentary(self, response: Dict[str, Any], enable_commentary: bool, timeout: int = 30) -> Dict[str, Any]:
264
+ """Add Turkish commentary to the response using DeepSeek API"""
265
+ if not enable_commentary:
266
+ return response
267
+
268
+ if not UTILS_AVAILABLE or not deepseek_client:
269
+ print("⚠️ DeepSeek client not available - skipping Turkish commentary")
270
+ response["commentary_status"] = "unavailable"
271
+ return response
272
+
273
+ if not deepseek_client.is_available():
274
+ print("⚠️ DeepSeek API key not configured - skipping Turkish commentary")
275
+ response["commentary_status"] = "api_key_missing"
276
+ return response
277
+
278
+ generated_text = response.get("generated_text", "")
279
+ if not generated_text:
280
+ print("⚠️ No generated text to comment on")
281
+ response["commentary_status"] = "no_text"
282
+ return response
283
+
284
+ print("🔄 DeepSeek ile Türkçe yorum ekleniyor...")
285
+ commentary_result = deepseek_client.get_turkish_commentary(generated_text, timeout)
286
+
287
+ if commentary_result["success"]:
288
+ response["comment_text"] = commentary_result["comment_text"]
289
+ response["commentary_model"] = commentary_result.get("model", "deepseek-chat")
290
+ response["commentary_tokens"] = commentary_result.get("tokens_used", 0)
291
+ response["commentary_status"] = "success"
292
+ print("✅ Türkçe yorum başarıyla eklendi")
293
+ else:
294
+ response["comment_text"] = ""
295
+ response["commentary_error"] = commentary_result["error"]
296
+ response["commentary_status"] = "failed"
297
+ print(f"❌ Türkçe yorum eklenemedi: {commentary_result['error']}")
298
+
299
+ return response
300
+
301
+ def health_check(self) -> Dict[str, Any]:
302
+ """Health check endpoint"""
303
+ if UTILS_AVAILABLE:
304
+ return create_health_check()
305
+ else:
306
+ return {
307
+ 'status': 'healthy',
308
+ 'model': 'PULSE-7B',
309
+ 'timestamp': time.time(),
310
+ 'handler_version': '2.0.0'
311
+ }
312
+
313
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
314
+ """
315
+ Main processing function - where the magic happens!
316
+
317
+ Args:
318
+ data: Input data with 'inputs' and optional 'parameters'
319
+
320
+ Returns:
321
+ List with the generated response
322
+ """
323
+ # Quick check - is our model ready?
324
+ if self.use_pipeline is None:
325
+ return [{
326
+ "generated_text": "Oops! Model couldn't load properly. Please check the deployment settings.",
327
+ "error": "Model initialization failed",
328
+ "handler": "Ubden® Team Enhanced Handler"
329
+ }]
330
+
331
+ try:
332
+ # Parse the inputs - flexible format support
333
+ inputs = data.get("inputs", "")
334
+ text = ""
335
+ image = None
336
+
337
+ if isinstance(inputs, dict):
338
+ # Dictionary input - check for text and image
339
+ # Support query field (new) plus original text/prompt fields
340
+ text = inputs.get("query", inputs.get("text", inputs.get("prompt", str(inputs))))
341
+
342
+ # Check for image in various formats
343
+ image_input = inputs.get("image", inputs.get("image_url", inputs.get("image_base64", None)))
344
+ if image_input:
345
+ image = self.process_image_input(image_input)
346
+ if image:
347
+ # Since we're in text-only mode, create smart ECG context
348
+ print(f"🖼️ Image loaded: {image.size[0]}x{image.size[1]} pixels - using text-only ECG analysis mode")
349
+
350
+ # Create ECG-specific prompt that mimics visual analysis
351
+ ecg_context = f"Analyzing an ECG image ({image.size[0]}x{image.size[1]} pixels). "
352
+
353
+ # Use demo's exact approach - no additional context, just the query
354
+ # Model is trained to understand ECG images from text queries
355
+ pass # Keep text exactly as received
356
+ else:
357
+ # Simple string input
358
+ text = str(inputs)
359
+
360
+ if not text:
361
+ return [{"generated_text": "Hey, I need some text to work with! Please provide an input."}]
362
+
363
+ # Get generation parameters - using PULSE-7B demo's exact settings
364
+ parameters = data.get("parameters", {})
365
+ max_new_tokens = min(parameters.get("max_new_tokens", 1024), 8192) # Demo uses 1024 default
366
+ temperature = parameters.get("temperature", 0.05) # Demo uses 0.05 for precise medical analysis
367
+ top_p = parameters.get("top_p", 1.0) # Demo uses 1.0 for full vocabulary access
368
+ do_sample = parameters.get("do_sample", True) # Demo uses sampling
369
+ repetition_penalty = parameters.get("repetition_penalty", 1.0) # Demo default
370
+
371
+ print(f"🎛️ Generation params: max_tokens={max_new_tokens}, temp={temperature}, top_p={top_p}, do_sample={do_sample}, rep_penalty={repetition_penalty}")
372
+
373
+ # Check if Turkish commentary is requested (NEW FEATURE)
374
+ enable_turkish_commentary = parameters.get("enable_turkish_commentary", False) # Default false
375
+
376
+ # Using pipeline? Let's go!
377
+ if self.use_pipeline:
378
+ print(f"🎛️ Pipeline generation: temp={temperature}, tokens={max_new_tokens}")
379
+ print(f"📝 Input text: '{text[:100]}...'")
380
+
381
+ result = self.pipe(
382
+ text,
383
+ max_new_tokens=max_new_tokens,
384
+ min_new_tokens=200, # Force very detailed analysis to match demo
385
+ temperature=temperature,
386
+ top_p=top_p,
387
+ do_sample=do_sample,
388
+ repetition_penalty=repetition_penalty,
389
+ return_full_text=False # Just the new stuff, not the input
390
+ )
391
+
392
+ # Pipeline returns a list, let's handle it
393
+ if isinstance(result, list) and len(result) > 0:
394
+ generated_text = result[0].get("generated_text", "").strip()
395
+
396
+ print(f"🔍 Pipeline debug:")
397
+ print(f" - Raw result: '{str(result[0])[:200]}...'")
398
+ print(f" - Generated text length: {len(generated_text)}")
399
+
400
+ # Clean up common issues
401
+ if generated_text.startswith(text):
402
+ generated_text = generated_text[len(text):].strip()
403
+ print("🔧 Removed input text from output")
404
+
405
+ # Remove common artifacts
406
+ generated_text = generated_text.replace("</s>", "").strip()
407
+
408
+ if not generated_text:
409
+ print("❌ Pipeline generated empty text!")
410
+ generated_text = "Empty response from pipeline. Please try different parameters."
411
+
412
+ print(f"✅ Final pipeline text: '{generated_text[:100]}...' (length: {len(generated_text)})")
413
+
414
+ # Create response
415
+ response = {"generated_text": generated_text}
416
+
417
+ # Add Turkish commentary if requested (NEW FEATURE)
418
+ if enable_turkish_commentary:
419
+ response = self.add_turkish_commentary(response, True)
420
+
421
+ return [response]
422
+ else:
423
+ generated_text = str(result).strip()
424
+
425
+ # Create response
426
+ response = {"generated_text": generated_text}
427
+
428
+ # Add Turkish commentary if requested (NEW FEATURE)
429
+ if enable_turkish_commentary:
430
+ response = self.add_turkish_commentary(response, True)
431
+
432
+ return [response]
433
+
434
+ # Manual generation mode - using PULSE demo's exact approach
435
+ else:
436
+ print(f"🔥 Manual generation with PULSE demo logic: temp={temperature}, tokens={max_new_tokens}")
437
+ print(f"📝 Input text: '{text[:100]}...'")
438
+
439
+ # Text-only generation with enhanced ECG context
440
+ print("🔤 Using enhanced text-only generation with ECG context")
441
+
442
+ # Tokenize the enhanced prompt
443
+ encoded = self.tokenizer(
444
+ text,
445
+ return_tensors="pt",
446
+ truncation=True,
447
+ max_length=4096 # Increased for longer prompts
448
+ )
449
+
450
+ input_ids = encoded["input_ids"].to(self.device)
451
+ attention_mask = encoded.get("attention_mask")
452
+ if attention_mask is not None:
453
+ attention_mask = attention_mask.to(self.device)
454
+
455
+ print(f"🔍 Enhanced generation debug:")
456
+ print(f" - Enhanced prompt length: {len(text)} chars")
457
+ print(f" - Input tokens: {input_ids.shape[-1]}")
458
+ print(f" - Prompt preview: '{text[:150]}...'")
459
+
460
+ # Generate with enhanced settings for medical analysis
461
+ with torch.no_grad():
462
+ outputs = self.model.generate(
463
+ input_ids,
464
+ attention_mask=attention_mask,
465
+ max_new_tokens=max_new_tokens,
466
+ min_new_tokens=200, # Force detailed response like demo
467
+ temperature=temperature,
468
+ top_p=top_p,
469
+ do_sample=do_sample,
470
+ repetition_penalty=repetition_penalty,
471
+ pad_token_id=self.tokenizer.pad_token_id,
472
+ eos_token_id=self.tokenizer.eos_token_id,
473
+ early_stopping=False
474
+ )
475
+
476
+ # Decode and clean response
477
+ generated_ids = outputs[0][input_ids.shape[-1]:]
478
+ generated_text = self.tokenizer.decode(
479
+ generated_ids,
480
+ skip_special_tokens=True,
481
+ clean_up_tokenization_spaces=True
482
+ ).strip()
483
+
484
+ # Aggressive cleanup of artifacts
485
+ generated_text = generated_text.replace("</s>", "").strip()
486
+
487
+ # Simple cleanup - just remove Answer prefix and parentheses
488
+ if generated_text.startswith("(Answer:") and ")" in generated_text:
489
+ # Just remove the parentheses and Answer: prefix
490
+ end_paren = generated_text.find(")")
491
+ answer_content = generated_text[8:end_paren].strip() # Remove "(Answer:"
492
+ # Keep the rest of the response if there is any
493
+ rest_of_response = generated_text[end_paren+1:].strip()
494
+
495
+ if rest_of_response:
496
+ generated_text = f"{answer_content}. {rest_of_response}"
497
+ else:
498
+ generated_text = answer_content
499
+
500
+ elif generated_text.startswith("Answer:"):
501
+ generated_text = generated_text[7:].strip()
502
+
503
+ # Remove only clear training artifacts
504
+ cleanup_patterns = [
505
+ "In this task",
506
+ "I'm asking the respondent",
507
+ "The respondent should"
508
+ ]
509
+
510
+ for pattern in cleanup_patterns:
511
+ if pattern in generated_text:
512
+ parts = generated_text.split(pattern)
513
+ generated_text = parts[0].strip()
514
+ break
515
+
516
+ # Only provide fallback if response is truly empty or malformed
517
+ if len(generated_text) < 10 or generated_text.startswith("7)"):
518
+ print("⚠️ Malformed response detected, providing fallback...")
519
+ generated_text = "This ECG shows cardiac electrical activity. For accurate interpretation, please consult with a qualified cardiologist who can analyze the specific waveforms, intervals, and morphology patterns."
520
+
521
+ print(f"✅ Enhanced text-only generation: '{generated_text[:100]}...' (length: {len(generated_text)})")
522
+
523
+ # Create response
524
+ response = {"generated_text": generated_text}
525
+
526
+ # Add Turkish commentary if requested (NEW FEATURE)
527
+ if enable_turkish_commentary:
528
+ response = self.add_turkish_commentary(response, True)
529
+
530
+ return [response]
531
+
532
+
533
+ except Exception as e:
534
+ error_msg = f"Something went wrong during generation: {str(e)}"
535
+ print(f"❌ {error_msg}")
536
+ return [{
537
+ "generated_text": "",
538
+ "error": error_msg,
539
+ "handler": "Ubden® Team Enhanced Handler"
540
  }]
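
For local testing outside an Inference Endpoint, the handler can also be exercised directly. A minimal smoke-test sketch, assuming handler.py is importable from the working directory and the PULSE-ECG/PULSE-7B weights can be downloaded:

from handler import EndpointHandler

# Instantiating the handler triggers model loading (LLaVA path, pipeline, or manual fallback).
handler = EndpointHandler()

# Basic status dict; richer output if the optional utils module is present.
print(handler.health_check())

# Text-only request using the same payload shape as the endpoint.
out = handler({
    "inputs": {"query": "Describe the key features of a normal sinus rhythm ECG."},
    "parameters": {"max_new_tokens": 256, "temperature": 0.05},
})
print(out[0]["generated_text"])
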