ubden committed on
Commit 1049a4d · verified · 1 Parent(s): 20c5de2

Upload 17 files

Files changed (2)
  1. handler.bak.py +540 -0
  2. handler.py +244 -539
handler.bak.py ADDED
@@ -0,0 +1,540 @@
1
+ """
2
+ PULSE-7B Enhanced Handler
3
+ Ubden® Team - Edited by https://github.com/ck-cankurt
4
+ Support: Text, Image URLs, and Base64 encoded images
5
+ """
6
+
7
+ import torch
8
+ from typing import Dict, List, Any
9
+ import base64
10
+ from io import BytesIO
11
+ from PIL import Image
12
+ import requests
13
+ import time
14
+
15
+ # Import utilities if available
16
+ try:
17
+ from utils import (
18
+ performance_monitor,
19
+ validate_image_input,
20
+ sanitize_parameters,
21
+ get_system_info,
22
+ create_health_check,
23
+ deepseek_client
24
+ )
25
+ UTILS_AVAILABLE = True
26
+ except ImportError:
27
+ UTILS_AVAILABLE = False
28
+ deepseek_client = None
29
+ print("⚠️ Utils module not found - performance monitoring and DeepSeek integration disabled")
30
+
31
+ # Try to import LLaVA modules for proper conversation handling
32
+ try:
33
+ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
34
+ from llava.conversation import conv_templates, SeparatorStyle
35
+ from llava.mm_utils import tokenizer_image_token, process_images, KeywordsStoppingCriteria
36
+ LLAVA_AVAILABLE = True
37
+ print("✅ LLaVA modules imported successfully")
38
+ except ImportError:
39
+ LLAVA_AVAILABLE = False
40
+ print("⚠️ LLaVA modules not available - using basic text processing")
41
+
42
+
43
+ class EndpointHandler:
44
+ def __init__(self, path=""):
45
+ """
46
+ Hey there! Let's get this PULSE-7B model up and running.
47
+ We'll load it from the HuggingFace hub directly, so no worries about local files.
48
+
49
+ Args:
50
+ path: Model directory path (we actually ignore this and load from HF hub)
51
+ """
52
+ print("🚀 Starting up PULSE-7B handler...")
53
+ print("📝 Enhanced by Ubden® Team - github.com/ck-cankurt")
54
+ import sys
55
+ print(f"🔧 Python version: {sys.version}")
56
+ print(f"🔧 PyTorch version: {torch.__version__}")
57
+
58
+ # Check transformers version
59
+ try:
60
+ import transformers
61
+ print(f"🔧 Transformers version: {transformers.__version__}")
62
+
63
+ # PULSE LLaVA works with transformers==4.37.2
64
+ if transformers.__version__ == "4.37.2":
65
+ print("✅ Using PULSE LLaVA compatible version (4.37.2)")
66
+ elif "dev" in transformers.__version__ or "git" in str(transformers.__version__):
67
+ print("⚠️ Using development version - may conflict with PULSE LLaVA")
68
+ else:
69
+ print("⚠️ Using different version - PULSE LLaVA prefers 4.37.2")
70
+ except Exception as e:
71
+ print(f"❌ Error checking transformers version: {e}")
72
+
73
+ print(f"🔧 CUDA available: {torch.cuda.is_available()}")
74
+ if torch.cuda.is_available():
75
+ print(f"🔧 CUDA device: {torch.cuda.get_device_name(0)}")
76
+
77
+ # Let's see what hardware we're working with
78
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
79
+ print(f"🖥️ Running on: {self.device}")
80
+
81
+ try:
82
+ # First attempt - PULSE demo's exact approach
83
+ if LLAVA_AVAILABLE:
84
+ print("📦 Using PULSE demo's load_pretrained_model approach...")
85
+ from llava.model.builder import load_pretrained_model
86
+ from llava.mm_utils import get_model_name_from_path
87
+
88
+ model_path = "PULSE-ECG/PULSE-7B"
89
+ model_name = get_model_name_from_path(model_path)
90
+
91
+ self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
92
+ model_path=model_path,
93
+ model_base=None,
94
+ model_name=model_name,
95
+ load_8bit=False,
96
+ load_4bit=False
97
+ )
98
+
99
+ # Move model to device like demo
100
+ self.model = self.model.to(self.device)
101
+ self.use_pipeline = False
102
+ print("✅ Model loaded successfully with PULSE demo's approach!")
103
+ print(f"📸 Image processor: {type(self.image_processor).__name__}")
104
+
105
+ else:
106
+ raise ImportError("LLaVA modules not available")
107
+
108
+ except Exception as e:
109
+ print(f"⚠️ PULSE demo approach failed: {e}")
110
+ print("🔄 Falling back to pipeline...")
111
+
112
+ try:
113
+ # Fallback - using pipeline
114
+ from transformers import pipeline
115
+
116
+ print("📦 Fetching model from HuggingFace Hub...")
117
+ self.pipe = pipeline(
118
+ "text-generation",
119
+ model="PULSE-ECG/PULSE-7B",
120
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
121
+ device=0 if torch.cuda.is_available() else -1,
122
+ trust_remote_code=True,
123
+ model_kwargs={
124
+ "low_cpu_mem_usage": True,
125
+ "use_safetensors": True
126
+ }
127
+ )
128
+ self.use_pipeline = True
129
+ self.image_processor = None
130
+ print("✅ Model loaded successfully via pipeline!")
131
+
132
+ except Exception as e2:
133
+ print(f"😓 Pipeline also failed: {e2}")
134
+
135
+ try:
136
+ # Last resort - manual loading
137
+ from transformers import AutoTokenizer, LlamaForCausalLM
138
+
139
+ print("📖 Manual loading as last resort...")
140
+ self.tokenizer = AutoTokenizer.from_pretrained(
141
+ "PULSE-ECG/PULSE-7B",
142
+ trust_remote_code=True
143
+ )
144
+
145
+ self.model = LlamaForCausalLM.from_pretrained(
146
+ "PULSE-ECG/PULSE-7B",
147
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
148
+ device_map="auto",
149
+ low_cpu_mem_usage=True,
150
+ trust_remote_code=True
151
+ )
152
+
153
+ if self.tokenizer.pad_token is None:
154
+ self.tokenizer.pad_token = self.tokenizer.eos_token
155
+ self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
156
+
157
+ self.model.eval()
158
+ self.use_pipeline = False
159
+ self.image_processor = None
160
+ print("✅ Model loaded manually!")
161
+
162
+ except Exception as e3:
163
+ print(f"😓 All approaches failed: {e3}")
164
+ self.pipe = None
165
+ self.model = None
166
+ self.tokenizer = None
167
+ self.image_processor = None
168
+ self.use_pipeline = None
169
+
170
+ # Final status report
171
+ print("\n🔍 Model Loading Status Report:")
172
+ print(f" - use_pipeline: {self.use_pipeline}")
173
+ print(f" - model: {'✅ Loaded' if hasattr(self, 'model') and self.model is not None else '❌ None'}")
174
+ print(f" - tokenizer: {'✅ Loaded' if hasattr(self, 'tokenizer') and self.tokenizer is not None else '❌ None'}")
175
+ print(f" - image_processor: {'✅ Loaded' if hasattr(self, 'image_processor') and self.image_processor is not None else '❌ None'}")
176
+ print(f" - pipe: {'✅ Loaded' if hasattr(self, 'pipe') and self.pipe is not None else '❌ None'}")
177
+
178
+ # Check if any model component loaded successfully
179
+ has_model = hasattr(self, 'model') and self.model is not None
180
+ has_tokenizer = hasattr(self, 'tokenizer') and self.tokenizer is not None
181
+ has_pipe = hasattr(self, 'pipe') and self.pipe is not None
182
+ has_image_processor = hasattr(self, 'image_processor') and self.image_processor is not None
183
+
184
+ if not (has_model or has_tokenizer or has_pipe):
185
+ print("💥 CRITICAL: No model components loaded successfully!")
186
+ else:
187
+ print("✅ At least one model component loaded successfully")
188
+ if has_image_processor:
189
+ print("🖼️ Vision capabilities available!")
190
+ else:
191
+ print("⚠️ No image processor - text-only mode")
192
+
193
+ def is_valid_image_format(self, filename_or_url):
194
+ """Validate image format like PULSE demo"""
195
+ # Demo's supported formats
196
+ image_extensions = ["jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "heic", "heif", "jfif", "svg", "eps", "raw"]
197
+
198
+ if filename_or_url.startswith(('http://', 'https://')):
199
+ # For URLs, check the extension or content-type
200
+ ext = filename_or_url.split('.')[-1].split('?')[0].lower()
201
+ return ext in image_extensions
202
+ else:
203
+ # For base64 or local files
204
+ return True # Base64 will be validated during decode
205
+
206
+ def process_image_input(self, image_input):
207
+ """
208
+ Handle both URL and base64 image inputs exactly like PULSE demo
209
+
210
+ Args:
211
+ image_input: Can be a URL string or base64 encoded image
212
+
213
+ Returns:
214
+ PIL Image object or None if something goes wrong
215
+ """
216
+ try:
217
+ # Check if it's a URL (starts with http/https)
218
+ if isinstance(image_input, str) and (image_input.startswith('http://') or image_input.startswith('https://')):
219
+ print(f"🌐 Fetching image from URL: {image_input[:50]}...")
220
+
221
+ # Validate format like demo
222
+ if not self.is_valid_image_format(image_input):
223
+ print("❌ Invalid image format in URL")
224
+ return None
225
+
226
+ # Demo's exact image loading approach
227
+ response = requests.get(image_input, timeout=15)
228
+ if response.status_code == 200:
229
+ image = Image.open(BytesIO(response.content)).convert("RGB")
230
+ print(f"✅ Image downloaded successfully! Size: {image.size}")
231
+ return image
232
+ else:
233
+ print(f"❌ Failed to load image: status {response.status_code}")
234
+ return None
235
+
236
+ # Must be base64 then
237
+ elif isinstance(image_input, str):
238
+ print("🔍 Decoding base64 image...")
239
+
240
+ # Remove the data URL prefix if it exists
241
+ base64_data = image_input
242
+ if "base64," in image_input:
243
+ base64_data = image_input.split("base64,")[1]
244
+
245
+ # Clean and validate base64 data
246
+ base64_data = base64_data.strip().replace('\n', '').replace('\r', '').replace(' ', '')
247
+
248
+ try:
249
+ image_data = base64.b64decode(base64_data)
250
+ image = Image.open(BytesIO(image_data)).convert('RGB')
251
+ print(f"✅ Base64 image decoded successfully! Size: {image.size}")
252
+ return image
253
+ except Exception as decode_error:
254
+ print(f"❌ Base64 decode error: {decode_error}")
255
+ return None
256
+
257
+ except Exception as e:
258
+ print(f"❌ Couldn't process the image: {e}")
259
+ return None
260
+
261
+ return None
262
+
263
+ def add_turkish_commentary(self, response: Dict[str, Any], enable_commentary: bool, timeout: int = 30) -> Dict[str, Any]:
264
+ """Add Turkish commentary to the response using DeepSeek API"""
265
+ if not enable_commentary:
266
+ return response
267
+
268
+ if not UTILS_AVAILABLE or not deepseek_client:
269
+ print("⚠️ DeepSeek client not available - skipping Turkish commentary")
270
+ response["commentary_status"] = "unavailable"
271
+ return response
272
+
273
+ if not deepseek_client.is_available():
274
+ print("⚠️ DeepSeek API key not configured - skipping Turkish commentary")
275
+ response["commentary_status"] = "api_key_missing"
276
+ return response
277
+
278
+ generated_text = response.get("generated_text", "")
279
+ if not generated_text:
280
+ print("⚠️ No generated text to comment on")
281
+ response["commentary_status"] = "no_text"
282
+ return response
283
+
284
+ print("🔄 DeepSeek ile Türkçe yorum ekleniyor...")
285
+ commentary_result = deepseek_client.get_turkish_commentary(generated_text, timeout)
286
+
287
+ if commentary_result["success"]:
288
+ response["comment_text"] = commentary_result["comment_text"]
289
+ response["commentary_model"] = commentary_result.get("model", "deepseek-chat")
290
+ response["commentary_tokens"] = commentary_result.get("tokens_used", 0)
291
+ response["commentary_status"] = "success"
292
+ print("✅ Türkçe yorum başarıyla eklendi")
293
+ else:
294
+ response["comment_text"] = ""
295
+ response["commentary_error"] = commentary_result["error"]
296
+ response["commentary_status"] = "failed"
297
+ print(f"❌ Türkçe yorum eklenemedi: {commentary_result['error']}")
298
+
299
+ return response
300
+
301
+ def health_check(self) -> Dict[str, Any]:
302
+ """Health check endpoint"""
303
+ if UTILS_AVAILABLE:
304
+ return create_health_check()
305
+ else:
306
+ return {
307
+ 'status': 'healthy',
308
+ 'model': 'PULSE-7B',
309
+ 'timestamp': time.time(),
310
+ 'handler_version': '2.0.0'
311
+ }
312
+
313
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
314
+ """
315
+ Main processing function - where the magic happens!
316
+
317
+ Args:
318
+ data: Input data with 'inputs' and optional 'parameters'
319
+
320
+ Returns:
321
+ List with the generated response
322
+ """
323
+ # Quick check - is our model ready?
324
+ if self.use_pipeline is None:
325
+ return [{
326
+ "generated_text": "Oops! Model couldn't load properly. Please check the deployment settings.",
327
+ "error": "Model initialization failed",
328
+ "handler": "Ubden® Team Enhanced Handler"
329
+ }]
330
+
331
+ try:
332
+ # Parse the inputs - flexible format support
333
+ inputs = data.get("inputs", "")
334
+ text = ""
335
+ image = None
336
+
337
+ if isinstance(inputs, dict):
338
+ # Dictionary input - check for text and image
339
+ # Support query field (new) plus original text/prompt fields
340
+ text = inputs.get("query", inputs.get("text", inputs.get("prompt", str(inputs))))
341
+
342
+ # Check for image in various formats
343
+ image_input = inputs.get("image", inputs.get("image_url", inputs.get("image_base64", None)))
344
+ if image_input:
345
+ image = self.process_image_input(image_input)
346
+ if image:
347
+ # Since we're in text-only mode, create smart ECG context
348
+ print(f"🖼️ Image loaded: {image.size[0]}x{image.size[1]} pixels - using text-only ECG analysis mode")
349
+
350
+ # Create ECG-specific prompt that mimics visual analysis
351
+ ecg_context = f"Analyzing an ECG image ({image.size[0]}x{image.size[1]} pixels). "
352
+
353
+ # Use demo's exact approach - no additional context, just the query
354
+ # Model is trained to understand ECG images from text queries
355
+ pass # Keep text exactly as received
356
+ else:
357
+ # Simple string input
358
+ text = str(inputs)
359
+
360
+ if not text:
361
+ return [{"generated_text": "Hey, I need some text to work with! Please provide an input."}]
362
+
363
+ # Get generation parameters - using PULSE-7B demo's exact settings
364
+ parameters = data.get("parameters", {})
365
+ max_new_tokens = min(parameters.get("max_new_tokens", 1024), 8192) # Demo uses 1024 default
366
+ temperature = parameters.get("temperature", 0.05) # Demo uses 0.05 for precise medical analysis
367
+ top_p = parameters.get("top_p", 1.0) # Demo uses 1.0 for full vocabulary access
368
+ do_sample = parameters.get("do_sample", True) # Demo uses sampling
369
+ repetition_penalty = parameters.get("repetition_penalty", 1.0) # Demo default
370
+
371
+ print(f"🎛️ Generation params: max_tokens={max_new_tokens}, temp={temperature}, top_p={top_p}, do_sample={do_sample}, rep_penalty={repetition_penalty}")
372
+
373
+ # Check if Turkish commentary is requested (NEW FEATURE)
374
+ enable_turkish_commentary = parameters.get("enable_turkish_commentary", False) # Default false
375
+
376
+ # Using pipeline? Let's go!
377
+ if self.use_pipeline:
378
+ print(f"🎛️ Pipeline generation: temp={temperature}, tokens={max_new_tokens}")
379
+ print(f"📝 Input text: '{text[:100]}...'")
380
+
381
+ result = self.pipe(
382
+ text,
383
+ max_new_tokens=max_new_tokens,
384
+ min_new_tokens=200, # Force very detailed analysis to match demo
385
+ temperature=temperature,
386
+ top_p=top_p,
387
+ do_sample=do_sample,
388
+ repetition_penalty=repetition_penalty,
389
+ return_full_text=False # Just the new stuff, not the input
390
+ )
391
+
392
+ # Pipeline returns a list, let's handle it
393
+ if isinstance(result, list) and len(result) > 0:
394
+ generated_text = result[0].get("generated_text", "").strip()
395
+
396
+ print(f"🔍 Pipeline debug:")
397
+ print(f" - Raw result: '{str(result[0])[:200]}...'")
398
+ print(f" - Generated text length: {len(generated_text)}")
399
+
400
+ # Clean up common issues
401
+ if generated_text.startswith(text):
402
+ generated_text = generated_text[len(text):].strip()
403
+ print("🔧 Removed input text from output")
404
+
405
+ # Remove common artifacts
406
+ generated_text = generated_text.replace("</s>", "").strip()
407
+
408
+ if not generated_text:
409
+ print("❌ Pipeline generated empty text!")
410
+ generated_text = "Empty response from pipeline. Please try different parameters."
411
+
412
+ print(f"✅ Final pipeline text: '{generated_text[:100]}...' (length: {len(generated_text)})")
413
+
414
+ # Create response
415
+ response = {"generated_text": generated_text}
416
+
417
+ # Add Turkish commentary if requested (NEW FEATURE)
418
+ if enable_turkish_commentary:
419
+ response = self.add_turkish_commentary(response, True)
420
+
421
+ return [response]
422
+ else:
423
+ generated_text = str(result).strip()
424
+
425
+ # Create response
426
+ response = {"generated_text": generated_text}
427
+
428
+ # Add Turkish commentary if requested (NEW FEATURE)
429
+ if enable_turkish_commentary:
430
+ response = self.add_turkish_commentary(response, True)
431
+
432
+ return [response]
433
+
434
+ # Manual generation mode - using PULSE demo's exact approach
435
+ else:
436
+ print(f"🔥 Manual generation with PULSE demo logic: temp={temperature}, tokens={max_new_tokens}")
437
+ print(f"📝 Input text: '{text[:100]}...'")
438
+
439
+ # Text-only generation with enhanced ECG context
440
+ print("🔤 Using enhanced text-only generation with ECG context")
441
+
442
+ # Tokenize the enhanced prompt
443
+ encoded = self.tokenizer(
444
+ text,
445
+ return_tensors="pt",
446
+ truncation=True,
447
+ max_length=4096 # Increased for longer prompts
448
+ )
449
+
450
+ input_ids = encoded["input_ids"].to(self.device)
451
+ attention_mask = encoded.get("attention_mask")
452
+ if attention_mask is not None:
453
+ attention_mask = attention_mask.to(self.device)
454
+
455
+ print(f"🔍 Enhanced generation debug:")
456
+ print(f" - Enhanced prompt length: {len(text)} chars")
457
+ print(f" - Input tokens: {input_ids.shape[-1]}")
458
+ print(f" - Prompt preview: '{text[:150]}...'")
459
+
460
+ # Generate with enhanced settings for medical analysis
461
+ with torch.no_grad():
462
+ outputs = self.model.generate(
463
+ input_ids,
464
+ attention_mask=attention_mask,
465
+ max_new_tokens=max_new_tokens,
466
+ min_new_tokens=200, # Force detailed response like demo
467
+ temperature=temperature,
468
+ top_p=top_p,
469
+ do_sample=do_sample,
470
+ repetition_penalty=repetition_penalty,
471
+ pad_token_id=self.tokenizer.pad_token_id,
472
+ eos_token_id=self.tokenizer.eos_token_id,
473
+ early_stopping=False
474
+ )
475
+
476
+ # Decode and clean response
477
+ generated_ids = outputs[0][input_ids.shape[-1]:]
478
+ generated_text = self.tokenizer.decode(
479
+ generated_ids,
480
+ skip_special_tokens=True,
481
+ clean_up_tokenization_spaces=True
482
+ ).strip()
483
+
484
+ # Aggressive cleanup of artifacts
485
+ generated_text = generated_text.replace("</s>", "").strip()
486
+
487
+ # Simple cleanup - just remove Answer prefix and parentheses
488
+ if generated_text.startswith("(Answer:") and ")" in generated_text:
489
+ # Just remove the parentheses and Answer: prefix
490
+ end_paren = generated_text.find(")")
491
+ answer_content = generated_text[8:end_paren].strip() # Remove "(Answer:"
492
+ # Keep the rest of the response if there is any
493
+ rest_of_response = generated_text[end_paren+1:].strip()
494
+
495
+ if rest_of_response:
496
+ generated_text = f"{answer_content}. {rest_of_response}"
497
+ else:
498
+ generated_text = answer_content
499
+
500
+ elif generated_text.startswith("Answer:"):
501
+ generated_text = generated_text[7:].strip()
502
+
503
+ # Remove only clear training artifacts
504
+ cleanup_patterns = [
505
+ "In this task",
506
+ "I'm asking the respondent",
507
+ "The respondent should"
508
+ ]
509
+
510
+ for pattern in cleanup_patterns:
511
+ if pattern in generated_text:
512
+ parts = generated_text.split(pattern)
513
+ generated_text = parts[0].strip()
514
+ break
515
+
516
+ # Only provide fallback if response is truly empty or malformed
517
+ if len(generated_text) < 10 or generated_text.startswith("7)"):
518
+ print("⚠️ Malformed response detected, providing fallback...")
519
+ generated_text = "This ECG shows cardiac electrical activity. For accurate interpretation, please consult with a qualified cardiologist who can analyze the specific waveforms, intervals, and morphology patterns."
520
+
521
+ print(f"✅ Enhanced text-only generation: '{generated_text[:100]}...' (length: {len(generated_text)})")
522
+
523
+ # Create response
524
+ response = {"generated_text": generated_text}
525
+
526
+ # Add Turkish commentary if requested (NEW FEATURE)
527
+ if enable_turkish_commentary:
528
+ response = self.add_turkish_commentary(response, True)
529
+
530
+ return [response]
531
+
532
+
533
+ except Exception as e:
534
+ error_msg = f"Something went wrong during generation: {str(e)}"
535
+ print(f"❌ {error_msg}")
536
+ return [{
537
+ "generated_text": "",
538
+ "error": error_msg,
539
+ "handler": "Ubden® Team Enhanced Handler"
540
+ }]
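A minimal sketch of how the enhanced handler added above could be exercised, assuming it is deployed as handler.py; the class name, payload keys, and parameter defaults come from the code in this file, while the query text, image URL, and import path are placeholder assumptions:

# Hypothetical local invocation of the enhanced handler (handler.bak.py in this commit)
from handler import EndpointHandler  # assumes the file is used as handler.py

handler = EndpointHandler()

payload = {
    "inputs": {
        # "query" is the new field; "text" and "prompt" are still accepted
        "query": "What abnormalities are visible in this ECG?",
        # any one of "image", "image_url", or "image_base64" may be supplied
        "image_url": "https://example.com/ecg.png"  # placeholder URL
    },
    "parameters": {
        "max_new_tokens": 1024,             # demo default, capped at 8192 by the handler
        "temperature": 0.05,
        "top_p": 1.0,
        "enable_turkish_commentary": False  # set True to request the DeepSeek commentary
    }
}

result = handler(payload)                   # returns a list with a single dict
print(result[0]["generated_text"])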
handler.py CHANGED
@@ -1,540 +1,245 @@
1
- """
2
- PULSE-7B Enhanced Handler
3
- Ubden® Team - Edited by https://github.com/ck-cankurt
4
- Support: Text, Image URLs, and Base64 encoded images
5
- """
6
-
7
- import torch
8
- from typing import Dict, List, Any
9
- import base64
10
- from io import BytesIO
11
- from PIL import Image
12
- import requests
13
- import time
14
-
15
- # Import utilities if available
16
- try:
17
- from utils import (
18
- performance_monitor,
19
- validate_image_input,
20
- sanitize_parameters,
21
- get_system_info,
22
- create_health_check,
23
- deepseek_client
24
- )
25
- UTILS_AVAILABLE = True
26
- except ImportError:
27
- UTILS_AVAILABLE = False
28
- deepseek_client = None
29
- print("⚠️ Utils module not found - performance monitoring and DeepSeek integration disabled")
30
-
31
- # Try to import LLaVA modules for proper conversation handling
32
- try:
33
- from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
34
- from llava.conversation import conv_templates, SeparatorStyle
35
- from llava.mm_utils import tokenizer_image_token, process_images, KeywordsStoppingCriteria
36
- LLAVA_AVAILABLE = True
37
- print("✅ LLaVA modules imported successfully")
38
- except ImportError:
39
- LLAVA_AVAILABLE = False
40
- print("⚠️ LLaVA modules not available - using basic text processing")
41
-
42
-
43
- class EndpointHandler:
44
- def __init__(self, path=""):
45
- """
46
- Hey there! Let's get this PULSE-7B model up and running.
47
- We'll load it from the HuggingFace hub directly, so no worries about local files.
48
-
49
- Args:
50
- path: Model directory path (we actually ignore this and load from HF hub)
51
- """
52
- print("🚀 Starting up PULSE-7B handler...")
53
- print("📝 Enhanced by Ubden® Team - github.com/ck-cankurt")
54
- import sys
55
- print(f"🔧 Python version: {sys.version}")
56
- print(f"🔧 PyTorch version: {torch.__version__}")
57
-
58
- # Check transformers version
59
- try:
60
- import transformers
61
- print(f"🔧 Transformers version: {transformers.__version__}")
62
-
63
- # PULSE LLaVA works with transformers==4.37.2
64
- if transformers.__version__ == "4.37.2":
65
- print("✅ Using PULSE LLaVA compatible version (4.37.2)")
66
- elif "dev" in transformers.__version__ or "git" in str(transformers.__version__):
67
- print("⚠️ Using development version - may conflict with PULSE LLaVA")
68
- else:
69
- print("⚠️ Using different version - PULSE LLaVA prefers 4.37.2")
70
- except Exception as e:
71
- print(f"❌ Error checking transformers version: {e}")
72
-
73
- print(f"🔧 CUDA available: {torch.cuda.is_available()}")
74
- if torch.cuda.is_available():
75
- print(f"🔧 CUDA device: {torch.cuda.get_device_name(0)}")
76
-
77
- # Let's see what hardware we're working with
78
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
79
- print(f"🖥️ Running on: {self.device}")
80
-
81
- try:
82
- # First attempt - PULSE demo's exact approach
83
- if LLAVA_AVAILABLE:
84
- print("📦 Using PULSE demo's load_pretrained_model approach...")
85
- from llava.model.builder import load_pretrained_model
86
- from llava.mm_utils import get_model_name_from_path
87
-
88
- model_path = "PULSE-ECG/PULSE-7B"
89
- model_name = get_model_name_from_path(model_path)
90
-
91
- self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
92
- model_path=model_path,
93
- model_base=None,
94
- model_name=model_name,
95
- load_8bit=False,
96
- load_4bit=False
97
- )
98
-
99
- # Move model to device like demo
100
- self.model = self.model.to(self.device)
101
- self.use_pipeline = False
102
- print("✅ Model loaded successfully with PULSE demo's approach!")
103
- print(f"📸 Image processor: {type(self.image_processor).__name__}")
104
-
105
- else:
106
- raise ImportError("LLaVA modules not available")
107
-
108
- except Exception as e:
109
- print(f"⚠️ PULSE demo approach failed: {e}")
110
- print("🔄 Falling back to pipeline...")
111
-
112
- try:
113
- # Fallback - using pipeline
114
- from transformers import pipeline
115
-
116
- print("📦 Fetching model from HuggingFace Hub...")
117
- self.pipe = pipeline(
118
- "text-generation",
119
- model="PULSE-ECG/PULSE-7B",
120
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
121
- device=0 if torch.cuda.is_available() else -1,
122
- trust_remote_code=True,
123
- model_kwargs={
124
- "low_cpu_mem_usage": True,
125
- "use_safetensors": True
126
- }
127
- )
128
- self.use_pipeline = True
129
- self.image_processor = None
130
- print("✅ Model loaded successfully via pipeline!")
131
-
132
- except Exception as e2:
133
- print(f"😓 Pipeline also failed: {e2}")
134
-
135
- try:
136
- # Last resort - manual loading
137
- from transformers import AutoTokenizer, LlamaForCausalLM
138
-
139
- print("📖 Manual loading as last resort...")
140
- self.tokenizer = AutoTokenizer.from_pretrained(
141
- "PULSE-ECG/PULSE-7B",
142
- trust_remote_code=True
143
- )
144
-
145
- self.model = LlamaForCausalLM.from_pretrained(
146
- "PULSE-ECG/PULSE-7B",
147
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
148
- device_map="auto",
149
- low_cpu_mem_usage=True,
150
- trust_remote_code=True
151
- )
152
-
153
- if self.tokenizer.pad_token is None:
154
- self.tokenizer.pad_token = self.tokenizer.eos_token
155
- self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
156
-
157
- self.model.eval()
158
- self.use_pipeline = False
159
- self.image_processor = None
160
- print("✅ Model loaded manually!")
161
-
162
- except Exception as e3:
163
- print(f"😓 All approaches failed: {e3}")
164
- self.pipe = None
165
- self.model = None
166
- self.tokenizer = None
167
- self.image_processor = None
168
- self.use_pipeline = None
169
-
170
- # Final status report
171
- print("\n🔍 Model Loading Status Report:")
172
- print(f" - use_pipeline: {self.use_pipeline}")
173
- print(f" - model: {'✅ Loaded' if hasattr(self, 'model') and self.model is not None else '❌ None'}")
174
- print(f" - tokenizer: {'✅ Loaded' if hasattr(self, 'tokenizer') and self.tokenizer is not None else '❌ None'}")
175
- print(f" - image_processor: {'✅ Loaded' if hasattr(self, 'image_processor') and self.image_processor is not None else '❌ None'}")
176
- print(f" - pipe: {'✅ Loaded' if hasattr(self, 'pipe') and self.pipe is not None else '❌ None'}")
177
-
178
- # Check if any model component loaded successfully
179
- has_model = hasattr(self, 'model') and self.model is not None
180
- has_tokenizer = hasattr(self, 'tokenizer') and self.tokenizer is not None
181
- has_pipe = hasattr(self, 'pipe') and self.pipe is not None
182
- has_image_processor = hasattr(self, 'image_processor') and self.image_processor is not None
183
-
184
- if not (has_model or has_tokenizer or has_pipe):
185
- print("💥 CRITICAL: No model components loaded successfully!")
186
- else:
187
- print("✅ At least one model component loaded successfully")
188
- if has_image_processor:
189
- print("🖼️ Vision capabilities available!")
190
- else:
191
- print("⚠️ No image processor - text-only mode")
192
-
193
- def is_valid_image_format(self, filename_or_url):
194
- """Validate image format like PULSE demo"""
195
- # Demo's supported formats
196
- image_extensions = ["jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "heic", "heif", "jfif", "svg", "eps", "raw"]
197
-
198
- if filename_or_url.startswith(('http://', 'https://')):
199
- # For URLs, check the extension or content-type
200
- ext = filename_or_url.split('.')[-1].split('?')[0].lower()
201
- return ext in image_extensions
202
- else:
203
- # For base64 or local files
204
- return True # Base64 will be validated during decode
205
-
206
- def process_image_input(self, image_input):
207
- """
208
- Handle both URL and base64 image inputs exactly like PULSE demo
209
-
210
- Args:
211
- image_input: Can be a URL string or base64 encoded image
212
-
213
- Returns:
214
- PIL Image object or None if something goes wrong
215
- """
216
- try:
217
- # Check if it's a URL (starts with http/https)
218
- if isinstance(image_input, str) and (image_input.startswith('http://') or image_input.startswith('https://')):
219
- print(f"🌐 Fetching image from URL: {image_input[:50]}...")
220
-
221
- # Validate format like demo
222
- if not self.is_valid_image_format(image_input):
223
- print("❌ Invalid image format in URL")
224
- return None
225
-
226
- # Demo's exact image loading approach
227
- response = requests.get(image_input, timeout=15)
228
- if response.status_code == 200:
229
- image = Image.open(BytesIO(response.content)).convert("RGB")
230
- print(f"✅ Image downloaded successfully! Size: {image.size}")
231
- return image
232
- else:
233
- print(f"❌ Failed to load image: status {response.status_code}")
234
- return None
235
-
236
- # Must be base64 then
237
- elif isinstance(image_input, str):
238
- print("🔍 Decoding base64 image...")
239
-
240
- # Remove the data URL prefix if it exists
241
- base64_data = image_input
242
- if "base64," in image_input:
243
- base64_data = image_input.split("base64,")[1]
244
-
245
- # Clean and validate base64 data
246
- base64_data = base64_data.strip().replace('\n', '').replace('\r', '').replace(' ', '')
247
-
248
- try:
249
- image_data = base64.b64decode(base64_data)
250
- image = Image.open(BytesIO(image_data)).convert('RGB')
251
- print(f"✅ Base64 image decoded successfully! Size: {image.size}")
252
- return image
253
- except Exception as decode_error:
254
- print(f"❌ Base64 decode error: {decode_error}")
255
- return None
256
-
257
- except Exception as e:
258
- print(f"❌ Couldn't process the image: {e}")
259
- return None
260
-
261
- return None
262
-
263
- def add_turkish_commentary(self, response: Dict[str, Any], enable_commentary: bool, timeout: int = 30) -> Dict[str, Any]:
264
- """Add Turkish commentary to the response using DeepSeek API"""
265
- if not enable_commentary:
266
- return response
267
-
268
- if not UTILS_AVAILABLE or not deepseek_client:
269
- print("⚠️ DeepSeek client not available - skipping Turkish commentary")
270
- response["commentary_status"] = "unavailable"
271
- return response
272
-
273
- if not deepseek_client.is_available():
274
- print("⚠️ DeepSeek API key not configured - skipping Turkish commentary")
275
- response["commentary_status"] = "api_key_missing"
276
- return response
277
-
278
- generated_text = response.get("generated_text", "")
279
- if not generated_text:
280
- print("⚠️ No generated text to comment on")
281
- response["commentary_status"] = "no_text"
282
- return response
283
-
284
- print("🔄 DeepSeek ile Türkçe yorum ekleniyor...")
285
- commentary_result = deepseek_client.get_turkish_commentary(generated_text, timeout)
286
-
287
- if commentary_result["success"]:
288
- response["comment_text"] = commentary_result["comment_text"]
289
- response["commentary_model"] = commentary_result.get("model", "deepseek-chat")
290
- response["commentary_tokens"] = commentary_result.get("tokens_used", 0)
291
- response["commentary_status"] = "success"
292
- print("✅ Türkçe yorum başarıyla eklendi")
293
- else:
294
- response["comment_text"] = ""
295
- response["commentary_error"] = commentary_result["error"]
296
- response["commentary_status"] = "failed"
297
- print(f"❌ Türkçe yorum eklenemedi: {commentary_result['error']}")
298
-
299
- return response
300
-
301
- def health_check(self) -> Dict[str, Any]:
302
- """Health check endpoint"""
303
- if UTILS_AVAILABLE:
304
- return create_health_check()
305
- else:
306
- return {
307
- 'status': 'healthy',
308
- 'model': 'PULSE-7B',
309
- 'timestamp': time.time(),
310
- 'handler_version': '2.0.0'
311
- }
312
-
313
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
314
- """
315
- Main processing function - where the magic happens!
316
-
317
- Args:
318
- data: Input data with 'inputs' and optional 'parameters'
319
-
320
- Returns:
321
- List with the generated response
322
- """
323
- # Quick check - is our model ready?
324
- if self.use_pipeline is None:
325
- return [{
326
- "generated_text": "Oops! Model couldn't load properly. Please check the deployment settings.",
327
- "error": "Model initialization failed",
328
- "handler": "Ubden® Team Enhanced Handler"
329
- }]
330
-
331
- try:
332
- # Parse the inputs - flexible format support
333
- inputs = data.get("inputs", "")
334
- text = ""
335
- image = None
336
-
337
- if isinstance(inputs, dict):
338
- # Dictionary input - check for text and image
339
- # Support query field (new) plus original text/prompt fields
340
- text = inputs.get("query", inputs.get("text", inputs.get("prompt", str(inputs))))
341
-
342
- # Check for image in various formats
343
- image_input = inputs.get("image", inputs.get("image_url", inputs.get("image_base64", None)))
344
- if image_input:
345
- image = self.process_image_input(image_input)
346
- if image:
347
- # Since we're in text-only mode, create smart ECG context
348
- print(f"🖼️ Image loaded: {image.size[0]}x{image.size[1]} pixels - using text-only ECG analysis mode")
349
-
350
- # Create ECG-specific prompt that mimics visual analysis
351
- ecg_context = f"Analyzing an ECG image ({image.size[0]}x{image.size[1]} pixels). "
352
-
353
- # Use demo's exact approach - no additional context, just the query
354
- # Model is trained to understand ECG images from text queries
355
- pass # Keep text exactly as received
356
- else:
357
- # Simple string input
358
- text = str(inputs)
359
-
360
- if not text:
361
- return [{"generated_text": "Hey, I need some text to work with! Please provide an input."}]
362
-
363
- # Get generation parameters - using PULSE-7B demo's exact settings
364
- parameters = data.get("parameters", {})
365
- max_new_tokens = min(parameters.get("max_new_tokens", 1024), 8192) # Demo uses 1024 default
366
- temperature = parameters.get("temperature", 0.05) # Demo uses 0.05 for precise medical analysis
367
- top_p = parameters.get("top_p", 1.0) # Demo uses 1.0 for full vocabulary access
368
- do_sample = parameters.get("do_sample", True) # Demo uses sampling
369
- repetition_penalty = parameters.get("repetition_penalty", 1.0) # Demo default
370
-
371
- print(f"🎛️ Generation params: max_tokens={max_new_tokens}, temp={temperature}, top_p={top_p}, do_sample={do_sample}, rep_penalty={repetition_penalty}")
372
-
373
- # Check if Turkish commentary is requested (NEW FEATURE)
374
- enable_turkish_commentary = parameters.get("enable_turkish_commentary", False) # Default false
375
-
376
- # Using pipeline? Let's go!
377
- if self.use_pipeline:
378
- print(f"🎛️ Pipeline generation: temp={temperature}, tokens={max_new_tokens}")
379
- print(f"📝 Input text: '{text[:100]}...'")
380
-
381
- result = self.pipe(
382
- text,
383
- max_new_tokens=max_new_tokens,
384
- min_new_tokens=200, # Force very detailed analysis to match demo
385
- temperature=temperature,
386
- top_p=top_p,
387
- do_sample=do_sample,
388
- repetition_penalty=repetition_penalty,
389
- return_full_text=False # Just the new stuff, not the input
390
- )
391
-
392
- # Pipeline returns a list, let's handle it
393
- if isinstance(result, list) and len(result) > 0:
394
- generated_text = result[0].get("generated_text", "").strip()
395
-
396
- print(f"🔍 Pipeline debug:")
397
- print(f" - Raw result: '{str(result[0])[:200]}...'")
398
- print(f" - Generated text length: {len(generated_text)}")
399
-
400
- # Clean up common issues
401
- if generated_text.startswith(text):
402
- generated_text = generated_text[len(text):].strip()
403
- print("🔧 Removed input text from output")
404
-
405
- # Remove common artifacts
406
- generated_text = generated_text.replace("</s>", "").strip()
407
-
408
- if not generated_text:
409
- print("❌ Pipeline generated empty text!")
410
- generated_text = "Empty response from pipeline. Please try different parameters."
411
-
412
- print(f"✅ Final pipeline text: '{generated_text[:100]}...' (length: {len(generated_text)})")
413
-
414
- # Create response
415
- response = {"generated_text": generated_text}
416
-
417
- # Add Turkish commentary if requested (NEW FEATURE)
418
- if enable_turkish_commentary:
419
- response = self.add_turkish_commentary(response, True)
420
-
421
- return [response]
422
- else:
423
- generated_text = str(result).strip()
424
-
425
- # Create response
426
- response = {"generated_text": generated_text}
427
-
428
- # Add Turkish commentary if requested (NEW FEATURE)
429
- if enable_turkish_commentary:
430
- response = self.add_turkish_commentary(response, True)
431
-
432
- return [response]
433
-
434
- # Manual generation mode - using PULSE demo's exact approach
435
- else:
436
- print(f"🔥 Manual generation with PULSE demo logic: temp={temperature}, tokens={max_new_tokens}")
437
- print(f"📝 Input text: '{text[:100]}...'")
438
-
439
- # Text-only generation with enhanced ECG context
440
- print("🔤 Using enhanced text-only generation with ECG context")
441
-
442
- # Tokenize the enhanced prompt
443
- encoded = self.tokenizer(
444
- text,
445
- return_tensors="pt",
446
- truncation=True,
447
- max_length=4096 # Increased for longer prompts
448
- )
449
-
450
- input_ids = encoded["input_ids"].to(self.device)
451
- attention_mask = encoded.get("attention_mask")
452
- if attention_mask is not None:
453
- attention_mask = attention_mask.to(self.device)
454
-
455
- print(f"🔍 Enhanced generation debug:")
456
- print(f" - Enhanced prompt length: {len(text)} chars")
457
- print(f" - Input tokens: {input_ids.shape[-1]}")
458
- print(f" - Prompt preview: '{text[:150]}...'")
459
-
460
- # Generate with enhanced settings for medical analysis
461
- with torch.no_grad():
462
- outputs = self.model.generate(
463
- input_ids,
464
- attention_mask=attention_mask,
465
- max_new_tokens=max_new_tokens,
466
- min_new_tokens=200, # Force detailed response like demo
467
- temperature=temperature,
468
- top_p=top_p,
469
- do_sample=do_sample,
470
- repetition_penalty=repetition_penalty,
471
- pad_token_id=self.tokenizer.pad_token_id,
472
- eos_token_id=self.tokenizer.eos_token_id,
473
- early_stopping=False
474
- )
475
-
476
- # Decode and clean response
477
- generated_ids = outputs[0][input_ids.shape[-1]:]
478
- generated_text = self.tokenizer.decode(
479
- generated_ids,
480
- skip_special_tokens=True,
481
- clean_up_tokenization_spaces=True
482
- ).strip()
483
-
484
- # Aggressive cleanup of artifacts
485
- generated_text = generated_text.replace("</s>", "").strip()
486
-
487
- # Simple cleanup - just remove Answer prefix and parentheses
488
- if generated_text.startswith("(Answer:") and ")" in generated_text:
489
- # Just remove the parentheses and Answer: prefix
490
- end_paren = generated_text.find(")")
491
- answer_content = generated_text[8:end_paren].strip() # Remove "(Answer:"
492
- # Keep the rest of the response if there is any
493
- rest_of_response = generated_text[end_paren+1:].strip()
494
-
495
- if rest_of_response:
496
- generated_text = f"{answer_content}. {rest_of_response}"
497
- else:
498
- generated_text = answer_content
499
-
500
- elif generated_text.startswith("Answer:"):
501
- generated_text = generated_text[7:].strip()
502
-
503
- # Remove only clear training artifacts
504
- cleanup_patterns = [
505
- "In this task",
506
- "I'm asking the respondent",
507
- "The respondent should"
508
- ]
509
-
510
- for pattern in cleanup_patterns:
511
- if pattern in generated_text:
512
- parts = generated_text.split(pattern)
513
- generated_text = parts[0].strip()
514
- break
515
-
516
- # Only provide fallback if response is truly empty or malformed
517
- if len(generated_text) < 10 or generated_text.startswith("7)"):
518
- print("⚠️ Malformed response detected, providing fallback...")
519
- generated_text = "This ECG shows cardiac electrical activity. For accurate interpretation, please consult with a qualified cardiologist who can analyze the specific waveforms, intervals, and morphology patterns."
520
-
521
- print(f"✅ Enhanced text-only generation: '{generated_text[:100]}...' (length: {len(generated_text)})")
522
-
523
- # Create response
524
- response = {"generated_text": generated_text}
525
-
526
- # Add Turkish commentary if requested (NEW FEATURE)
527
- if enable_turkish_commentary:
528
- response = self.add_turkish_commentary(response, True)
529
-
530
- return [response]
531
-
532
-
533
- except Exception as e:
534
- error_msg = f"Something went wrong during generation: {str(e)}"
535
- print(f"❌ {error_msg}")
536
- return [{
537
- "generated_text": "",
538
- "error": error_msg,
539
- "handler": "Ubden® Team Enhanced Handler"
540
  }]
 
1
+ """
2
+ PULSE-7B Enhanced Handler
3
+ Ubden® Team - Edited by https://github.com/ck-cankurt
4
+ Support: Text, Image URLs, and Base64 encoded images
5
+ """
6
+
7
+ import torch
8
+ from typing import Dict, List, Any
9
+ import base64
10
+ from io import BytesIO
11
+ from PIL import Image
12
+ import requests
13
+
14
+
15
+ class EndpointHandler:
16
+ def __init__(self, path=""):
17
+ """
18
+ Hey there! Let's get this PULSE-7B model up and running.
19
+ We'll load it from the HuggingFace hub directly, so no worries about local files.
20
+
21
+ Args:
22
+ path: Model directory path (we actually ignore this and load from HF hub)
23
+ """
24
+ print("🚀 Starting up PULSE-7B handler...")
25
+ print("📝 Enhanced by Ubden® Team - github.com/ck-cankurt")
26
+
27
+ # Let's see what hardware we're working with
28
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+ print(f"🖥️ Running on: {self.device}")
30
+
31
+ try:
32
+ # First attempt - using pipeline (easiest and most stable way)
33
+ from transformers import pipeline
34
+
35
+ print("📦 Fetching model from HuggingFace Hub...")
36
+ self.pipe = pipeline(
37
+ "text-generation",
38
+ model="PULSE-ECG/PULSE-7B",
39
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
40
+ device=0 if torch.cuda.is_available() else -1,
41
+ trust_remote_code=True,
42
+ model_kwargs={
43
+ "low_cpu_mem_usage": True,
44
+ "use_safetensors": True
45
+ }
46
+ )
47
+ print("✅ Model loaded successfully via pipeline!")
48
+
49
+ except Exception as e:
50
+ print(f"⚠️ Pipeline didn't work out: {e}")
51
+ print("🔄 Let me try a different approach...")
52
+
53
+ try:
54
+ # Plan B - load model and tokenizer separately
55
+ from transformers import AutoTokenizer, LlamaForCausalLM
56
+
57
+ # Get the tokenizer ready
58
+ print("📖 Setting up tokenizer...")
59
+ self.tokenizer = AutoTokenizer.from_pretrained(
60
+ "PULSE-ECG/PULSE-7B",
61
+ trust_remote_code=True
62
+ )
63
+
64
+ # Load the model as Llama (it works, trust me!)
65
+ print("🧠 Loading the model as Llama...")
66
+ self.model = LlamaForCausalLM.from_pretrained(
67
+ "PULSE-ECG/PULSE-7B",
68
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
69
+ device_map="auto",
70
+ low_cpu_mem_usage=True,
71
+ trust_remote_code=True
72
+ )
73
+
74
+ # Quick fix for padding token if it's missing
75
+ if self.tokenizer.pad_token is None:
76
+ self.tokenizer.pad_token = self.tokenizer.eos_token
77
+ self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
78
+
79
+ self.model.eval()
80
+ self.use_pipeline = False
81
+ print("✅ Model loaded successfully via direct loading!")
82
+
83
+ except Exception as e2:
84
+ print(f"😓 That didn't work either: {e2}")
85
+ # If all else fails, we'll handle it gracefully
86
+ self.pipe = None
87
+ self.model = None
88
+ self.tokenizer = None
89
+ self.use_pipeline = None
90
+ else:
91
+ self.use_pipeline = True
92
+
93
+ def process_image_input(self, image_input):
94
+ """
95
+ Handle both URL and base64 image inputs like a champ!
96
+
97
+ Args:
98
+ image_input: Can be a URL string or base64 encoded image
99
+
100
+ Returns:
101
+ PIL Image object or None if something goes wrong
102
+ """
103
+ try:
104
+ # Check if it's a URL (starts with http/https)
105
+ if isinstance(image_input, str) and (image_input.startswith('http://') or image_input.startswith('https://')):
106
+ print(f"🌐 Fetching image from URL: {image_input[:50]}...")
107
+ response = requests.get(image_input, timeout=10)
108
+ response.raise_for_status()
109
+ image = Image.open(BytesIO(response.content)).convert('RGB')
110
+ print("✅ Image downloaded successfully!")
111
+ return image
112
+
113
+ # Must be base64 then
114
+ elif isinstance(image_input, str):
115
+ print("🔍 Decoding base64 image...")
116
+ # Remove the data URL prefix if it exists
117
+ if "base64," in image_input:
118
+ image_input = image_input.split("base64,")[1]
119
+
120
+ image_data = base64.b64decode(image_input)
121
+ image = Image.open(BytesIO(image_data)).convert('RGB')
122
+ print("✅ Image decoded successfully!")
123
+ return image
124
+
125
+ except Exception as e:
126
+ print(f"❌ Couldn't process the image: {e}")
127
+ return None
128
+
129
+ return None
130
+
131
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
132
+ """
133
+ Main processing function - where the magic happens!
134
+
135
+ Args:
136
+ data: Input data with 'inputs' and optional 'parameters'
137
+
138
+ Returns:
139
+ List with the generated response
140
+ """
141
+ # Quick check - is our model ready?
142
+ if self.use_pipeline is None:
143
+ return [{
144
+ "generated_text": "Oops! Model couldn't load properly. Please check the deployment settings.",
145
+ "error": "Model initialization failed",
146
+ "handler": "Ubden® Team Enhanced Handler"
147
+ }]
148
+
149
+ try:
150
+ # Parse the inputs - flexible format support
151
+ inputs = data.get("inputs", "")
152
+ text = ""
153
+ image = None
154
+
155
+ if isinstance(inputs, dict):
156
+ # Dictionary input - check for text and image
157
+ text = inputs.get("text", inputs.get("prompt", str(inputs)))
158
+
159
+ # Check for image in various formats
160
+ image_input = inputs.get("image", inputs.get("image_url", inputs.get("image_base64", None)))
161
+ if image_input:
162
+ image = self.process_image_input(image_input)
163
+ if image:
164
+ # For now, we'll add a note about the image since we're text-only
165
+ text = f"[Image provided - {image.size[0]}x{image.size[1]} pixels] {text}"
166
+ else:
167
+ # Simple string input
168
+ text = str(inputs)
169
+
170
+ if not text:
171
+ return [{"generated_text": "Hey, I need some text to work with! Please provide an input."}]
172
+
173
+ # Get generation parameters with sensible defaults
174
+ parameters = data.get("parameters", {})
175
+ max_new_tokens = min(parameters.get("max_new_tokens", 256), 1024)
176
+ temperature = parameters.get("temperature", 0.7)
177
+ top_p = parameters.get("top_p", 0.95)
178
+ do_sample = parameters.get("do_sample", True)
179
+ repetition_penalty = parameters.get("repetition_penalty", 1.0)
180
+
181
+ # Using pipeline? Let's go!
182
+ if self.use_pipeline:
183
+ result = self.pipe(
184
+ text,
185
+ max_new_tokens=max_new_tokens,
186
+ temperature=temperature,
187
+ top_p=top_p,
188
+ do_sample=do_sample,
189
+ repetition_penalty=repetition_penalty,
190
+ return_full_text=False # Just the new stuff, not the input
191
+ )
192
+
193
+ # Pipeline returns a list, let's handle it
194
+ if isinstance(result, list) and len(result) > 0:
195
+ return [{"generated_text": result[0].get("generated_text", "")}]
196
+ else:
197
+ return [{"generated_text": str(result)}]
198
+
199
+ # Manual generation mode
200
+ else:
201
+ # Tokenize the input
202
+ encoded = self.tokenizer(
203
+ text,
204
+ return_tensors="pt",
205
+ truncation=True,
206
+ max_length=2048
207
+ )
208
+
209
+ input_ids = encoded["input_ids"].to(self.device)
210
+ attention_mask = encoded.get("attention_mask")
211
+ if attention_mask is not None:
212
+ attention_mask = attention_mask.to(self.device)
213
+
214
+ # Generate the response
215
+ with torch.no_grad():
216
+ outputs = self.model.generate(
217
+ input_ids,
218
+ attention_mask=attention_mask,
219
+ max_new_tokens=max_new_tokens,
220
+ temperature=temperature,
221
+ top_p=top_p,
222
+ do_sample=do_sample,
223
+ repetition_penalty=repetition_penalty,
224
+ pad_token_id=self.tokenizer.pad_token_id,
225
+ eos_token_id=self.tokenizer.eos_token_id
226
+ )
227
+
228
+ # Decode only the new tokens (not the input)
229
+ generated_ids = outputs[0][input_ids.shape[-1]:]
230
+ generated_text = self.tokenizer.decode(
231
+ generated_ids,
232
+ skip_special_tokens=True,
233
+ clean_up_tokenization_spaces=True
234
+ )
235
+
236
+ return [{"generated_text": generated_text}]
237
+
238
+ except Exception as e:
239
+ error_msg = f"Something went wrong during generation: {str(e)}"
240
+ print(f"❌ {error_msg}")
241
+ return [{
242
+ "generated_text": "",
243
+ "error": error_msg,
244
+ "handler": "Ubden® Team Enhanced Handler"
245
  }]
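For the simplified handler.py added above, a hedged usage sketch against a deployed Hugging Face Inference Endpoint; the endpoint URL, token, and local file name are placeholders, and only the "text"/"prompt" and "image"/"image_url"/"image_base64" keys read by this version are used:

import base64
import requests

ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HF_TOKEN = "hf_xxx"                                                   # placeholder token

# Encode a local ECG image as base64 (this simplified handler only notes the
# image dimensions in the prompt text; it does not pass pixels to the model)
with open("ecg.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "inputs": {
        "text": "Describe the rhythm and any abnormalities in this ECG.",
        "image_base64": image_b64
    },
    "parameters": {"max_new_tokens": 256, "temperature": 0.7}  # this version caps max_new_tokens at 1024
}

resp = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
    json=payload,
    timeout=120,
)
resp.raise_for_status()
print(resp.json()[0]["generated_text"])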