Trouter-Library committed on
Commit
23a2e44
·
verified ·
1 Parent(s): 676bc11

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +322 -0
inference.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Helion-V1.5-XL Inference Script
3
+ Supports multiple inference modes and optimization techniques
4
+ """
5
+
6
+ import torch
7
+ from transformers import (
8
+ AutoTokenizer,
9
+ AutoModelForCausalLM,
10
+ BitsAndBytesConfig,
11
+ GenerationConfig
12
+ )
13
+ from typing import Optional, Dict, Any, List
14
+ import argparse
15
+ import json
16
+ import time
17
+
18
+
19
class HelionInference:
    """Inference wrapper for Helion-V1.5-XL.

    Loads a causal language model and its tokenizer through the
    ``transformers`` auto classes and offers three entry points:
    :meth:`generate` (one prompt), :meth:`chat` (list of role/content
    messages) and :meth:`batch_generate` (several prompts at once).
    """

    # Accepted ``torch_dtype`` names. Unknown names fall back to bfloat16.
    _DTYPE_MAP = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }

    def __init__(
        self,
        model_name: str = "DeepXR/Helion-V1.5-XL",
        load_in_4bit: bool = False,
        load_in_8bit: bool = False,
        device_map: str = "auto",
        torch_dtype: str = "bfloat16"
    ):
        """
        Initialize the model and tokenizer

        Args:
            model_name: HuggingFace model identifier
            load_in_4bit: Enable 4-bit quantization (wins over 8-bit if both
                are requested)
            load_in_8bit: Enable 8-bit quantization
            device_map: Device mapping strategy
            torch_dtype: PyTorch dtype name for model weights; one of
                "bfloat16", "float16", "float32"
        """
        self.model_name = model_name
        print(f"Loading model: {model_name}")

        dtype = self._resolve_dtype(torch_dtype)

        # Setup quantization config
        quantization_config = None
        if load_in_4bit:
            if load_in_8bit:
                print("Warning: both 4-bit and 8-bit requested; using 4-bit.")
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=dtype,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
        elif load_in_8bit:
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)

        # NOTE(security): trust_remote_code=True executes Python code shipped
        # in the model repository. Only point this at repositories you trust.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True
        )
        self._ensure_pad_token()

        model_kwargs = {
            "device_map": device_map,
            "trust_remote_code": True,
        }
        # With quantization the weight dtype is decided by bitsandbytes, so
        # torch_dtype is only forwarded for full-precision loads.
        if quantization_config:
            model_kwargs["quantization_config"] = quantization_config
        else:
            model_kwargs["torch_dtype"] = dtype

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            **model_kwargs
        )

        self.model.eval()
        print("Model loaded successfully!")

    @classmethod
    def _resolve_dtype(cls, name: str):
        """Map a dtype name to a ``torch.dtype``, defaulting to bfloat16."""
        try:
            return cls._DTYPE_MAP[name]
        except (KeyError, TypeError):
            print(f"Warning: unknown torch_dtype {name!r}; using bfloat16.")
            return torch.bfloat16

    @staticmethod
    def _tokens_per_second(token_count: int, seconds: float) -> float:
        """Return throughput, or 0.0 when the elapsed time is not positive."""
        return token_count / seconds if seconds > 0 else 0.0

    def _ensure_pad_token(self) -> None:
        """Reuse EOS as the pad token when the tokenizer defines none."""
        tokenizer = self.tokenizer
        if tokenizer.pad_token_id is None and tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token

    def _decode_completions(self, sequences, prompt_len: int) -> List[str]:
        """Decode only the tokens that follow the first ``prompt_len`` ids.

        Slicing token ids is used instead of cutting ``len(prompt)``
        characters off the decoded text: decoding with
        ``skip_special_tokens=True`` does not have to reproduce the prompt
        string character for character, so a character offset can remove
        part of the answer or leave part of the prompt behind.
        """
        return [
            self.tokenizer.decode(seq[prompt_len:], skip_special_tokens=True).strip()
            for seq in sequences
        ]

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        repetition_penalty: float = 1.1,
        do_sample: bool = True,
        num_return_sequences: int = 1,
        add_special_tokens: bool = True,
        **kwargs
    ) -> List[str]:
        """
        Generate text from a prompt

        Args:
            prompt: Input text prompt
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (0.0 to 2.0); a value <= 0
                switches to greedy decoding
            top_p: Nucleus sampling threshold
            top_k: Top-k sampling threshold
            repetition_penalty: Penalty for repetition
            do_sample: Whether to use sampling
            num_return_sequences: Number of sequences to generate
            add_special_tokens: Whether the tokenizer adds its special tokens
                to the prompt; set to False for text that already contains
                them (e.g. a rendered chat template)
            **kwargs: Extra GenerationConfig fields; these override the
                values derived from the named arguments

        Returns:
            List of generated text strings (prompt removed, whitespace
            stripped)
        """
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=add_special_tokens
        ).to(self.model.device)
        prompt_len = inputs["input_ids"].shape[1]

        # Sampling rejects a non-positive temperature; treat it as greedy.
        if temperature is not None and temperature <= 0:
            do_sample = False

        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        config_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": repetition_penalty,
            "do_sample": do_sample,
            "num_return_sequences": num_return_sequences,
            "pad_token_id": pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        if do_sample:
            # Sampling-only knobs are left out for greedy decoding, where
            # they have no effect.
            config_kwargs.update(temperature=temperature, top_p=top_p, top_k=top_k)
        # Merged last so a caller-supplied pad_token_id/eos_token_id/etc.
        # overrides instead of raising a duplicate-keyword TypeError.
        config_kwargs.update(kwargs)
        generation_config = GenerationConfig(**config_kwargs)

        start_time = time.perf_counter()
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                generation_config=generation_config
            )
        generation_time = time.perf_counter() - start_time

        responses = self._decode_completions(outputs, prompt_len)

        # Throughput counts newly generated positions only, not the prompt.
        new_tokens = sum(max(len(seq) - prompt_len, 0) for seq in outputs)
        tokens_per_sec = self._tokens_per_second(new_tokens, generation_time)

        print(f"\nGeneration Stats:")
        print(f" Time: {generation_time:.2f}s")
        print(f" Tokens/sec: {tokens_per_sec:.2f}")

        return responses

    def chat(
        self,
        messages: List[Dict[str, str]],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Generate response in chat format

        Args:
            messages: List of message dicts with 'role' and 'content'
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            **kwargs: Forwarded to :meth:`generate`

        Returns:
            Generated response string
        """
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # The rendered template already carries the special tokens, so the
        # tokenizer must not add them a second time.
        kwargs.setdefault("add_special_tokens", False)

        responses = self.generate(
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            **kwargs
        )

        return responses[0]

    def batch_generate(
        self,
        prompts: List[str],
        max_new_tokens: int = 512,
        **kwargs
    ) -> List[str]:
        """
        Generate responses for multiple prompts in batch

        Args:
            prompts: List of input prompts
            max_new_tokens: Maximum tokens per generation
            **kwargs: Forwarded to ``model.generate``

        Returns:
            List of generated responses, in the order returned by
            ``model.generate`` (empty list for empty input)
        """
        if not prompts:
            return []

        self._ensure_pad_token()

        # Left padding keeps every prompt flush against the generated
        # tokens, so all prompts end at the same column. The tokenizer's own
        # setting is restored afterwards.
        previous_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = "left"
        try:
            inputs = self.tokenizer(
                prompts,
                return_tensors="pt",
                padding=True,
                truncation=True
            )
        finally:
            self.tokenizer.padding_side = previous_side
        inputs = inputs.to(self.model.device)
        prompt_len = inputs["input_ids"].shape[1]

        kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                **kwargs
            )

        return self._decode_completions(outputs, prompt_len)
231
+
232
+
233
def main():
    """Command-line entry point.

    Parses the CLI options, loads the model once, runs a single generation
    (plain or chat-formatted, depending on ``--chat-mode``) and prints the
    prompt and the response between separator rules.
    """
    parser = argparse.ArgumentParser(description="Helion-V1.5-XL Inference")
    parser.add_argument(
        "--model",
        type=str,
        default="DeepXR/Helion-V1.5-XL",
        help="Model name or path"
    )
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Input prompt"
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=512,
        help="Maximum tokens to generate"
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature"
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=0.9,
        help="Nucleus sampling threshold"
    )

    # The two quantization modes cannot be combined; let argparse reject the
    # combination instead of silently preferring one of them.
    quantization = parser.add_mutually_exclusive_group()
    quantization.add_argument(
        "--load-in-4bit",
        action="store_true",
        help="Load model in 4-bit quantization"
    )
    quantization.add_argument(
        "--load-in-8bit",
        action="store_true",
        help="Load model in 8-bit quantization"
    )

    parser.add_argument(
        "--chat-mode",
        action="store_true",
        help="Use chat format"
    )

    args = parser.parse_args()

    # Initialize model
    inference = HelionInference(
        model_name=args.model,
        load_in_4bit=args.load_in_4bit,
        load_in_8bit=args.load_in_8bit
    )

    # Both modes take the same sampling options.
    generation_args = {
        "max_new_tokens": args.max_tokens,
        "temperature": args.temperature,
        "top_p": args.top_p,
    }

    # Generate response
    if args.chat_mode:
        messages = [
            {"role": "user", "content": args.prompt}
        ]
        response = inference.chat(messages, **generation_args)
    else:
        response = inference.generate(args.prompt, **generation_args)[0]

    rule = "=" * 80
    print("\n" + rule)
    print("PROMPT:")
    print(rule)
    print(args.prompt)
    print("\n" + rule)
    print("RESPONSE:")
    print(rule)
    print(response)
    print(rule)
319
+
320
+
321
# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()