samiur-r committed on
Commit
5037564
·
verified ·
1 Parent(s): 335198e

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +53 -0
handler.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ from peft import PeftModel
4
+
5
class SentimentAnalysisHandler:
    """Sentiment-analysis handler built on a 4-bit Llama-3 base model with a
    LoRA adapter attached.

    Pipeline: tokenize input text -> autoregressive generation -> decode the
    generated tokens back to a string.
    """

    def __init__(self):
        """Load base model, tokenizer, and fine-tuned LoRA adapter."""
        self.base_model_id = "unsloth/llama-3-8b-bnb-4bit"
        self.adapter_model_id = "samiur-r/BanglishSentiment-Llama3-8B"

        # Load tokenizer for the base model.
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_id)

        # Load base model. NOTE: device_map="auto" may place the model on
        # CPU when no GPU is available, so downstream code must not assume
        # CUDA.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.base_model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )

        # Attach the LoRA adapter and switch to inference mode.
        self.model = PeftModel.from_pretrained(self.model, self.adapter_model_id)
        self.model.eval()

    def preprocess(self, input_text):
        """Tokenize ``input_text`` and move the tensors to the model's device.

        Fix: the original hard-coded ``.to("cuda")``, which fails (or causes a
        device mismatch in ``generate``) whenever ``device_map="auto"`` placed
        the model on CPU. Using ``self.model.device`` keeps inputs and model
        colocated regardless of placement.
        """
        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
        return inputs

    def inference(self, inputs):
        """Generate up to 256 new tokens from tokenized ``inputs``.

        Runs under ``torch.no_grad()`` — inference only, no gradient tracking.
        """
        with torch.no_grad():
            output = self.model.generate(**inputs, max_new_tokens=256)
        return output

    def postprocess(self, output):
        """Decode the first generated sequence into a plain string."""
        sentiment = self.tokenizer.decode(output[0], skip_special_tokens=True)
        return sentiment

    def predict(self, input_text):
        """Full prediction pipeline: preprocess -> inference -> postprocess."""
        inputs = self.preprocess(input_text)
        output = self.inference(inputs)
        return self.postprocess(output)
46
+
47
# Create handler instance
# NOTE: instantiated at import time — loading the module triggers the full
# model/adapter load in SentimentAnalysisHandler.__init__.
_model_handler = SentimentAnalysisHandler()
49
+
50
def handle(inputs, context):
    """Entry point for model API inference.

    Reads the "text" field from *inputs* (falling back to an empty string
    when absent) and returns the model's output under the "prediction" key.
    """
    text = inputs.get("text", "")
    prediction = _model_handler.predict(text)
    return {"prediction": prediction}