ubden committed
Commit b13fdfd · verified · 1 Parent(s): 3a15aeb

Upload 10 files

Files changed (3)
  1. generation_config.json +2 -43
  2. handler.py +86 -445
  3. requirements.txt +1 -5
generation_config.json CHANGED
@@ -1,49 +1,8 @@
  {
- "_from_model_config": true,
  "attn_implementation": "flash_attention_2",
  "bos_token_id": 1,
  "eos_token_id": 2,
- "pad_token_id": 0,
  "max_length": 4096,
- "max_new_tokens": 2048,
- "min_length": 1,
- "min_new_tokens": 1,
- "early_stopping": false,
- "max_time": null,
- "do_sample": true,
- "num_beams": 1,
- "num_beam_groups": 1,
- "penalty_alpha": null,
- "use_cache": true,
- "temperature": 0.2,
- "top_k": 50,
- "top_p": 0.9,
- "typical_p": 1.0,
- "epsilon_cutoff": 0.0,
- "eta_cutoff": 0.0,
- "diversity_penalty": 0.0,
- "repetition_penalty": 1.05,
- "encoder_repetition_penalty": 1.0,
- "length_penalty": 1.0,
- "no_repeat_ngram_size": 0,
- "bad_words_ids": null,
- "force_words_ids": null,
- "renormalize_logits": false,
- "constraints": null,
- "forced_bos_token_id": null,
- "forced_eos_token_id": null,
- "remove_invalid_values": false,
- "exponential_decay_length_penalty": null,
- "suppress_tokens": null,
- "begin_suppress_tokens": null,
- "forced_decoder_ids": null,
- "sequence_bias": null,
- "guidance_scale": null,
- "low_memory": null,
- "num_return_sequences": 1,
- "output_attentions": false,
- "output_hidden_states": false,
- "output_scores": false,
- "return_dict_in_generate": false,
- "transformers_version": "4.40.0"
+ "pad_token_id": 0,
+ "transformers_version": "4.37.2"
  }
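
With the sampling defaults (do_sample, temperature, top_p, repetition_penalty, and so on) stripped from generation_config.json, generation behaviour is now controlled per request through the "parameters" field that the rewritten handler.py reads. A minimal request sketch is shown below; the endpoint URL and token are placeholders, and the values mirror the defaults the new handler falls back to.

import requests

API_URL = "https://YOUR-ENDPOINT.endpoints.huggingface.cloud"  # placeholder endpoint URL
HEADERS = {"Authorization": "Bearer hf_xxx"}                   # placeholder token

payload = {
    "inputs": {"text": "Describe the main findings of this ECG."},
    "parameters": {
        # These mirror the defaults in the new handler.py; max_new_tokens is capped at 1024 there
        "max_new_tokens": 256,
        "temperature": 0.7,
        "top_p": 0.95,
        "do_sample": True,
        "repetition_penalty": 1.0,
    },
}

response = requests.post(API_URL, headers=HEADERS, json=payload, timeout=120)
print(response.json())  # expected shape: [{"generated_text": "..."}]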
handler.py CHANGED
@@ -1,450 +1,146 @@
- """
- PULSE-7B Enhanced Handler
- Ubden® Team - Edited by https://github.com/ck-cankurt
- Support: Text, Image URLs, and Base64 encoded images
- """
-
  import torch
  from typing import Dict, List, Any
- import base64
- from io import BytesIO
- from PIL import Image
- import requests
- import time
-
- # Import utilities if available
- try:
-     from utils import (
-         performance_monitor,
-         validate_image_input,
-         sanitize_parameters,
-         get_system_info,
-         create_health_check,
-         deepseek_client
-     )
-     UTILS_AVAILABLE = True
- except ImportError:
-     UTILS_AVAILABLE = False
-     deepseek_client = None
-     print("⚠️ Utils module not found - performance monitoring and DeepSeek integration disabled")
-
- # Try to import LLaVA modules for proper conversation handling
- try:
-     from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
-     from llava.conversation import conv_templates, SeparatorStyle
-     from llava.mm_utils import tokenizer_image_token, process_images, KeywordsStoppingCriteria
-     LLAVA_AVAILABLE = True
-     print("✅ LLaVA modules imported successfully")
- except ImportError:
-     LLAVA_AVAILABLE = False
-     print("⚠️ LLaVA modules not available - using basic text processing")


  class EndpointHandler:
      def __init__(self, path=""):
          """
-         Hey there! Let's get this PULSE-7B model up and running.
-         We'll load it from the HuggingFace hub directly, so no worries about local files.
+         Initialize the handler for PULSE-7B model.
+         Direct reference to the original model.

          Args:
-             path: Model directory path (we actually ignore this and load from HF hub)
+             path: Path to the model directory (not used, we load from HF hub)
          """
-         print("🚀 Starting up PULSE-7B handler...")
-         print("📝 Enhanced by Ubden® Team - github.com/ck-cankurt")
-         import sys
-         print(f"🔧 Python version: {sys.version}")
-         print(f"🔧 PyTorch version: {torch.__version__}")
-
-         # Check transformers version
-         try:
-             import transformers
-             print(f"🔧 Transformers version: {transformers.__version__}")
-
-             # PULSE LLaVA works with transformers==4.37.2
-             if transformers.__version__ == "4.37.2":
-                 print("✅ Using PULSE LLaVA compatible version (4.37.2)")
-             elif "dev" in transformers.__version__ or "git" in str(transformers.__version__):
-                 print("⚠️ Using development version - may conflict with PULSE LLaVA")
-             else:
-                 print("⚠️ Using different version - PULSE LLaVA prefers 4.37.2")
-         except Exception as e:
-             print(f"❌ Error checking transformers version: {e}")
+         print("Initializing PULSE-7B handler...")

-         print(f"🔧 CUDA available: {torch.cuda.is_available()}")
-         if torch.cuda.is_available():
-             print(f"🔧 CUDA device: {torch.cuda.get_device_name(0)}")
-
-         # Let's see what hardware we're working with
+         # Set up the device
          self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-         print(f"🖥️ Running on: {self.device}")
+         print(f"Using device: {self.device}")

          try:
-             # First attempt - PULSE demo's exact approach
-             if LLAVA_AVAILABLE:
-                 print("📦 Using PULSE demo's load_pretrained_model approach...")
-                 from llava.model.builder import load_pretrained_model
-                 from llava.mm_utils import get_model_name_from_path
-
-                 model_path = "PULSE-ECG/PULSE-7B"
-                 model_name = get_model_name_from_path(model_path)
-
-                 self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
-                     model_path=model_path,
-                     model_base=None,
-                     model_name=model_name,
-                     load_8bit=False,
-                     load_4bit=False
-                 )
-
-                 # Move model to device like demo
-                 self.model = self.model.to(self.device)
-                 self.use_pipeline = False
-                 print("✅ Model loaded successfully with PULSE demo's approach!")
-                 print(f"📸 Image processor: {type(self.image_processor).__name__}")
-
-             else:
-                 raise ImportError("LLaVA modules not available")
+             # Use a pipeline - the simplest and most reliable method
+             from transformers import pipeline
+
+             print("Loading model from HuggingFace Hub...")
+             self.pipe = pipeline(
+                 "text-generation",
+                 model="PULSE-ECG/PULSE-7B",
+                 torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+                 device=0 if torch.cuda.is_available() else -1,
+                 trust_remote_code=True,
+                 model_kwargs={
+                     "low_cpu_mem_usage": True,
+                     "use_safetensors": True
+                 }
+             )
+             print("Model loaded successfully via pipeline!")

          except Exception as e:
-             print(f"⚠️ PULSE demo approach failed: {e}")
-             print("🔄 Falling back to pipeline...")
+             print(f"Pipeline loading failed: {e}")
+             print("Trying alternative loading method...")

              try:
-                 # Fallback - using pipeline
-                 from transformers import pipeline
+                 # Alternative: load the model and tokenizer separately
+                 from transformers import AutoTokenizer, LlamaForCausalLM
+
+                 # Load the tokenizer
+                 print("Loading tokenizer...")
+                 self.tokenizer = AutoTokenizer.from_pretrained(
+                     "PULSE-ECG/PULSE-7B",
+                     trust_remote_code=True
+                 )

-                 print("📦 Fetching model from HuggingFace Hub...")
-                 self.pipe = pipeline(
-                     "text-generation",
-                     model="PULSE-ECG/PULSE-7B",
+                 # Load the model as Llama
+                 print("Loading model as Llama...")
+                 self.model = LlamaForCausalLM.from_pretrained(
+                     "PULSE-ECG/PULSE-7B",
                      torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-                     device=0 if torch.cuda.is_available() else -1,
-                     trust_remote_code=True,
-                     model_kwargs={
-                         "low_cpu_mem_usage": True,
-                         "use_safetensors": True
-                     }
+                     device_map="auto",
+                     low_cpu_mem_usage=True,
+                     trust_remote_code=True
                  )
-                 self.use_pipeline = True
-                 self.image_processor = None
-                 print("✅ Model loaded successfully via pipeline!")
-
-             except Exception as e2:
-                 print(f"😓 Pipeline also failed: {e2}")
-
-                 try:
-                     # Last resort - manual loading
-                     from transformers import AutoTokenizer, LlamaForCausalLM
-
-                     print("📖 Manual loading as last resort...")
-                     self.tokenizer = AutoTokenizer.from_pretrained(
-                         "PULSE-ECG/PULSE-7B",
-                         trust_remote_code=True
-                     )
-
-                     self.model = LlamaForCausalLM.from_pretrained(
-                         "PULSE-ECG/PULSE-7B",
-                         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-                         device_map="auto",
-                         low_cpu_mem_usage=True,
-                         trust_remote_code=True
-                     )
-
-                     if self.tokenizer.pad_token is None:
-                         self.tokenizer.pad_token = self.tokenizer.eos_token
-                         self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
-
-                     self.model.eval()
-                     self.use_pipeline = False
-                     self.image_processor = None
-                     print("✅ Model loaded manually!")
-
-                 except Exception as e3:
-                     print(f"😓 All approaches failed: {e3}")
-                     self.pipe = None
-                     self.model = None
-                     self.tokenizer = None
-                     self.image_processor = None
-                     self.use_pipeline = None
-
-         # Final status report
-         print("\n🔍 Model Loading Status Report:")
-         print(f" - use_pipeline: {self.use_pipeline}")
-         print(f" - model: {'✅ Loaded' if hasattr(self, 'model') and self.model is not None else '❌ None'}")
-         print(f" - tokenizer: {'✅ Loaded' if hasattr(self, 'tokenizer') and self.tokenizer is not None else '❌ None'}")
-         print(f" - image_processor: {'✅ Loaded' if hasattr(self, 'image_processor') and self.image_processor is not None else '❌ None'}")
-         print(f" - pipe: {'✅ Loaded' if hasattr(self, 'pipe') and self.pipe is not None else '❌ None'}")
-
-         # Check if any model component loaded successfully
-         has_model = hasattr(self, 'model') and self.model is not None
-         has_tokenizer = hasattr(self, 'tokenizer') and self.tokenizer is not None
-         has_pipe = hasattr(self, 'pipe') and self.pipe is not None
-         has_image_processor = hasattr(self, 'image_processor') and self.image_processor is not None
-
-         if not (has_model or has_tokenizer or has_pipe):
-             print("💥 CRITICAL: No model components loaded successfully!")
-         else:
-             print("✅ At least one model component loaded successfully")
-             if has_image_processor:
-                 print("🖼️ Vision capabilities available!")
-             else:
-                 print("⚠️ No image processor - text-only mode")
-
-     def is_valid_image_format(self, filename_or_url):
-         """Validate image format like PULSE demo"""
-         # Demo's supported formats
-         image_extensions = ["jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "heic", "heif", "jfif", "svg", "eps", "raw"]
-
-         if filename_or_url.startswith(('http://', 'https://')):
-             # For URLs, check the extension or content-type
-             ext = filename_or_url.split('.')[-1].split('?')[0].lower()
-             return ext in image_extensions
-         else:
-             # For base64 or local files
-             return True # Base64 will be validated during decode
-
-     def process_image_input(self, image_input):
-         """
-         Handle both URL and base64 image inputs exactly like PULSE demo
-
-         Args:
-             image_input: Can be a URL string or base64 encoded image
-
-         Returns:
-             PIL Image object or None if something goes wrong
-         """
-         try:
-             # Check if it's a URL (starts with http/https)
-             if isinstance(image_input, str) and (image_input.startswith('http://') or image_input.startswith('https://')):
-                 print(f"🌐 Fetching image from URL: {image_input[:50]}...")
-
-                 # Validate format like demo
-                 if not self.is_valid_image_format(image_input):
-                     print("❌ Invalid image format in URL")
-                     return None
-
-                 # Demo's exact image loading approach
-                 response = requests.get(image_input, timeout=15)
-                 if response.status_code == 200:
-                     image = Image.open(BytesIO(response.content)).convert("RGB")
-                     print(f"✅ Image downloaded successfully! Size: {image.size}")
-                     return image
-                 else:
-                     print(f"❌ Failed to load image: status {response.status_code}")
-                     return None
-
-             # Must be base64 then
-             elif isinstance(image_input, str):
-                 print("🔍 Decoding base64 image...")
-
-                 # Remove the data URL prefix if it exists
-                 base64_data = image_input
-                 if "base64," in image_input:
-                     base64_data = image_input.split("base64,")[1]
-
-                 # Clean and validate base64 data
-                 base64_data = base64_data.strip().replace('\n', '').replace('\r', '').replace(' ', '')
-
-                 try:
-                     image_data = base64.b64decode(base64_data)
-                     image = Image.open(BytesIO(image_data)).convert('RGB')
-                     print(f"✅ Base64 image decoded successfully! Size: {image.size}")
-                     return image
-                 except Exception as decode_error:
-                     print(f"❌ Base64 decode error: {decode_error}")
-                     return None
-
-         except Exception as e:
-             print(f" Couldn't process the image: {e}")
-             return None
-
-         return None
-
-     def add_turkish_commentary(self, response: Dict[str, Any], enable_commentary: bool, timeout: int = 30) -> Dict[str, Any]:
-         """Add Turkish commentary to the response using DeepSeek API"""
-         if not enable_commentary:
-             return response
-
-         if not UTILS_AVAILABLE or not deepseek_client:
-             print("⚠️ DeepSeek client not available - skipping Turkish commentary")
-             response["commentary_status"] = "unavailable"
-             return response
-
-         if not deepseek_client.is_available():
-             print("⚠️ DeepSeek API key not configured - skipping Turkish commentary")
-             response["commentary_status"] = "api_key_missing"
-             return response
-
-         generated_text = response.get("generated_text", "")
-         if not generated_text:
-             print("⚠️ No generated text to comment on")
-             response["commentary_status"] = "no_text"
-             return response
-
-         print("🔄 DeepSeek ile Türkçe yorum ekleniyor...")
-         commentary_result = deepseek_client.get_turkish_commentary(generated_text, timeout)
-
-         if commentary_result["success"]:
-             response["comment_text"] = commentary_result["comment_text"]
-             response["commentary_model"] = commentary_result.get("model", "deepseek-chat")
-             response["commentary_tokens"] = commentary_result.get("tokens_used", 0)
-             response["commentary_status"] = "success"
-             print("✅ Türkçe yorum başarıyla eklendi")
-         else:
-             response["comment_text"] = ""
-             response["commentary_error"] = commentary_result["error"]
-             response["commentary_status"] = "failed"
-             print(f"❌ Türkçe yorum eklenemedi: {commentary_result['error']}")
-
-         return response
-
-     def health_check(self) -> Dict[str, Any]:
-         """Health check endpoint"""
-         if UTILS_AVAILABLE:
-             return create_health_check()
+
+                 # Set the padding token
+                 if self.tokenizer.pad_token is None:
+                     self.tokenizer.pad_token = self.tokenizer.eos_token
+                     self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
+                 self.model.eval()
+                 self.use_pipeline = False
+                 print("Model loaded successfully via direct loading!")
+
+             except Exception as e2:
+                 print(f"Alternative loading also failed: {e2}")
+                 # Last resort: a simple fallback message
+                 self.pipe = None
+                 self.model = None
+                 self.tokenizer = None
+                 self.use_pipeline = None
          else:
-             return {
-                 'status': 'healthy',
-                 'model': 'PULSE-7B',
-                 'timestamp': time.time(),
-                 'handler_version': '2.0.0'
-             }
+             self.use_pipeline = True

      def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
          """
-         Main processing function - where the magic happens!
+         Process the inference request.

          Args:
-             data: Input data with 'inputs' and optional 'parameters'
+             data: Input data containing 'inputs' and optional 'parameters'

          Returns:
-             List with the generated response
+             List containing the generated response
          """
-         # Quick check - is our model ready?
+         # Return an error if the model could not be loaded
          if self.use_pipeline is None:
              return [{
-                 "generated_text": "Oops! Model couldn't load properly. Please check the deployment settings.",
-                 "error": "Model initialization failed",
-                 "handler": "Ubden® Team Enhanced Handler"
+                 "generated_text": "Model could not be loaded. Please check the deployment configuration.",
+                 "error": "Model initialization failed"
              }]

          try:
-             # Parse the inputs - flexible format support
+             # Get the inputs
              inputs = data.get("inputs", "")
-             text = ""
-             image = None
-
              if isinstance(inputs, dict):
-                 # Dictionary input - check for text and image
-                 # Support query field (new) plus original text/prompt fields
-                 text = inputs.get("query", inputs.get("text", inputs.get("prompt", str(inputs))))
-
-                 # Check for image in various formats
-                 image_input = inputs.get("image", inputs.get("image_url", inputs.get("image_base64", None)))
-                 if image_input:
-                     image = self.process_image_input(image_input)
-                     if image:
-                         # Since we're in text-only mode, create smart ECG context
-                         print(f"🖼️ Image loaded: {image.size[0]}x{image.size[1]} pixels - using text-only ECG analysis mode")
-
-                         # Create ECG-specific prompt that mimics visual analysis
-                         ecg_context = f"Analyzing an ECG image ({image.size[0]}x{image.size[1]} pixels). "
-
-                         # Use demo's exact approach - no additional context, just the query
-                         # Model is trained to understand ECG images from text queries
-                         pass # Keep text exactly as received
+                 text = inputs.get("text", inputs.get("prompt", str(inputs)))
              else:
-                 # Simple string input
                  text = str(inputs)

              if not text:
-                 return [{"generated_text": "Hey, I need some text to work with! Please provide an input."}]
+                 return [{"generated_text": "Please provide an input text."}]

-             # Get generation parameters - using PULSE-7B demo's exact settings
+             # Get the parameters
              parameters = data.get("parameters", {})
-             max_new_tokens = min(parameters.get("max_new_tokens", 1024), 8192) # Demo uses 1024 default
-             temperature = parameters.get("temperature", 0.05) # Demo uses 0.05 for precise medical analysis
-             top_p = parameters.get("top_p", 1.0) # Demo uses 1.0 for full vocabulary access
-             do_sample = parameters.get("do_sample", True) # Demo uses sampling
-             repetition_penalty = parameters.get("repetition_penalty", 1.0) # Demo default
-
-             print(f"🎛️ Generation params: max_tokens={max_new_tokens}, temp={temperature}, top_p={top_p}, do_sample={do_sample}, rep_penalty={repetition_penalty}")
-
-             # Check if Turkish commentary is requested (NEW FEATURE)
-             enable_turkish_commentary = parameters.get("enable_turkish_commentary", False) # Default false
+             max_new_tokens = min(parameters.get("max_new_tokens", 256), 1024)
+             temperature = parameters.get("temperature", 0.7)
+             top_p = parameters.get("top_p", 0.95)
+             do_sample = parameters.get("do_sample", True)
+             repetition_penalty = parameters.get("repetition_penalty", 1.0)

-             # Using pipeline? Let's go!
+             # If we are using the pipeline
              if self.use_pipeline:
-                 print(f"🎛️ Pipeline generation: temp={temperature}, tokens={max_new_tokens}")
-                 print(f"📝 Input text: '{text[:100]}...'")
-
                  result = self.pipe(
                      text,
                      max_new_tokens=max_new_tokens,
-                     min_new_tokens=200, # Force very detailed analysis to match demo
                      temperature=temperature,
                      top_p=top_p,
                      do_sample=do_sample,
                      repetition_penalty=repetition_penalty,
-                     return_full_text=False # Just the new stuff, not the input
+                     return_full_text=False # Return only the newly generated text
                  )

-                 # Pipeline returns a list, let's handle it
+                 # The pipeline returns a list
                  if isinstance(result, list) and len(result) > 0:
-                     generated_text = result[0].get("generated_text", "").strip()
-
-                     print(f"🔍 Pipeline debug:")
-                     print(f" - Raw result: '{str(result[0])[:200]}...'")
-                     print(f" - Generated text length: {len(generated_text)}")
-
-                     # Clean up common issues
-                     if generated_text.startswith(text):
-                         generated_text = generated_text[len(text):].strip()
-                         print("🔧 Removed input text from output")
-
-                     # Remove common artifacts
-                     generated_text = generated_text.replace("</s>", "").strip()
-
-                     if not generated_text:
-                         print("❌ Pipeline generated empty text!")
-                         generated_text = "Empty response from pipeline. Please try different parameters."
-
-                     print(f"✅ Final pipeline text: '{generated_text[:100]}...' (length: {len(generated_text)})")
-
-                     # Create response
-                     response = {"generated_text": generated_text}
-
-                     # Add Turkish commentary if requested (NEW FEATURE)
-                     if enable_turkish_commentary:
-                         response = self.add_turkish_commentary(response, True)
-
-                     return [response]
+                     return [{"generated_text": result[0].get("generated_text", "")}]
                  else:
-                     generated_text = str(result).strip()
-
-                     # Create response
-                     response = {"generated_text": generated_text}
-
-                     # Add Turkish commentary if requested (NEW FEATURE)
-                     if enable_turkish_commentary:
-                         response = self.add_turkish_commentary(response, True)
-
-                     return [response]
+                     return [{"generated_text": str(result)}]

-             # Manual generation mode - using PULSE demo's exact approach
+             # If we are using manual generation
              else:
-                 print(f"🔥 Manual generation with PULSE demo logic: temp={temperature}, tokens={max_new_tokens}")
-                 print(f"📝 Input text: '{text[:100]}...'")
-
-                 # Text-only generation with enhanced ECG context
-                 print("🔤 Using enhanced text-only generation with ECG context")
-
-                 # Tokenize the enhanced prompt
+                 # Tokenize
                  encoded = self.tokenizer(
                      text,
                      return_tensors="pt",
                      truncation=True,
-                     max_length=4096 # Increased for longer prompts
+                     max_length=2048
                  )

                  input_ids = encoded["input_ids"].to(self.device)
@@ -452,89 +148,34 @@ class EndpointHandler:
                  if attention_mask is not None:
                      attention_mask = attention_mask.to(self.device)

-                 print(f"🔍 Enhanced generation debug:")
-                 print(f" - Enhanced prompt length: {len(text)} chars")
-                 print(f" - Input tokens: {input_ids.shape[-1]}")
-                 print(f" - Prompt preview: '{text[:150]}...'")
-
-                 # Generate with enhanced settings for medical analysis
+                 # Generate
                  with torch.no_grad():
                      outputs = self.model.generate(
                          input_ids,
                          attention_mask=attention_mask,
                          max_new_tokens=max_new_tokens,
-                         min_new_tokens=200, # Force detailed response like demo
                          temperature=temperature,
                          top_p=top_p,
                          do_sample=do_sample,
                          repetition_penalty=repetition_penalty,
                          pad_token_id=self.tokenizer.pad_token_id,
-                         eos_token_id=self.tokenizer.eos_token_id,
-                         early_stopping=False
+                         eos_token_id=self.tokenizer.eos_token_id
                      )

-                 # Decode and clean response
+                 # Decode - take only the new tokens
                  generated_ids = outputs[0][input_ids.shape[-1]:]
                  generated_text = self.tokenizer.decode(
                      generated_ids,
                      skip_special_tokens=True,
                      clean_up_tokenization_spaces=True
-                 ).strip()
-
-                 # Aggressive cleanup of artifacts
-                 generated_text = generated_text.replace("</s>", "").strip()
-
-                 # Simple cleanup - just remove Answer prefix and parentheses
-                 if generated_text.startswith("(Answer:") and ")" in generated_text:
-                     # Just remove the parentheses and Answer: prefix
-                     end_paren = generated_text.find(")")
-                     answer_content = generated_text[8:end_paren].strip() # Remove "(Answer:"
-                     # Keep the rest of the response if there is any
-                     rest_of_response = generated_text[end_paren+1:].strip()
-
-                     if rest_of_response:
-                         generated_text = f"{answer_content}. {rest_of_response}"
-                     else:
-                         generated_text = answer_content
-
-                 elif generated_text.startswith("Answer:"):
-                     generated_text = generated_text[7:].strip()
-
-                 # Remove only clear training artifacts
-                 cleanup_patterns = [
-                     "In this task",
-                     "I'm asking the respondent",
-                     "The respondent should"
-                 ]
-
-                 for pattern in cleanup_patterns:
-                     if pattern in generated_text:
-                         parts = generated_text.split(pattern)
-                         generated_text = parts[0].strip()
-                         break
-
-                 # Only provide fallback if response is truly empty or malformed
-                 if len(generated_text) < 10 or generated_text.startswith("7)"):
-                     print("⚠️ Malformed response detected, providing fallback...")
-                     generated_text = "This ECG shows cardiac electrical activity. For accurate interpretation, please consult with a qualified cardiologist who can analyze the specific waveforms, intervals, and morphology patterns."
-
-                 print(f"✅ Enhanced text-only generation: '{generated_text[:100]}...' (length: {len(generated_text)})")
-
-                 # Create response
-                 response = {"generated_text": generated_text}
-
-                 # Add Turkish commentary if requested (NEW FEATURE)
-                 if enable_turkish_commentary:
-                     response = self.add_turkish_commentary(response, True)
+                 )

-                 return [response]
-
+                 return [{"generated_text": generated_text}]

          except Exception as e:
-             error_msg = f"Something went wrong during generation: {str(e)}"
-             print(f"❌ {error_msg}")
+             error_msg = f"Error during generation: {str(e)}"
+             print(error_msg)
              return [{
                  "generated_text": "",
-                 "error": error_msg,
-                 "handler": "Ubden® Team Enhanced Handler"
+                 "error": error_msg
              }]
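
The simplified handler keeps the same call contract: a dict with "inputs" (a plain string, or a dict carrying "text"/"prompt") and an optional "parameters" dict, returning a list with a single "generated_text" entry. A quick local smoke test might look like the sketch below, assuming handler.py is importable and the PULSE-ECG/PULSE-7B weights can be downloaded. Note that the image handling (process_image_input), health_check, and the DeepSeek Turkish-commentary path were removed, so "image"/"image_url" fields and "enable_turkish_commentary" are now ignored.

from handler import EndpointHandler

# Loads PULSE-ECG/PULSE-7B via the transformers pipeline, with a manual LlamaForCausalLM fallback
handler = EndpointHandler()

output = handler({
    "inputs": "Summarize the key features of a normal sinus rhythm ECG.",
    "parameters": {"max_new_tokens": 128, "temperature": 0.7},
})
print(output[0]["generated_text"])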
requirements.txt CHANGED
@@ -3,8 +3,4 @@ torch>=2.1.0
  accelerate>=0.25.0
  sentencepiece
  safetensors
- protobuf
- Pillow>=9.0.0
- requests>=2.28.0
- numpy>=1.21.0
- psutil>=5.8.0
+ protobuf