import argparse
import time
from pathlib import Path

import mlx.core as mx
from transformers import AutoTokenizer

from model import load_model


def generate_text(
    prompt: str,
    model_path: str,
    max_tokens: int = 100,
    temperature: float = 0.1,
    top_p: float = 0.9,
    system: str | None = None,
    final_only: bool = False,
    stop_at_boxed: bool = False,
    extract_boxed: bool = False,
    disable_chat_template: bool = False,
    repetition_penalty: float = 1.0,
    frequency_penalty: float = 0.0,
):
"""Generates text using the loaded MLX model with better sampling.""" |
|
|
print("Loading model and tokenizer...") |
|
|
model = load_model(model_path) |
|
|
tokenizer = AutoTokenizer.from_pretrained(model_path) |
|
|
|
|
|
|
|
|
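    # Use the tokenizer's chat template only when the checkpoint ships one and
    # the caller has not disabled it.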
    chat_template_path = Path(model_path) / "chat_template.jinja"
    use_chat_format = chat_template_path.exists() and not disable_chat_template

    print(f"Chat template found: {use_chat_format}")
    print("Starting generation...")
    print(f"Prompt: {prompt}")

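    # Build the prompt: either a chat-formatted conversation (with an optional
    # system message; --final-only injects one asking for a \boxed{...}-only
    # answer) or the raw prompt prefixed with BOS.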
    if use_chat_format:
        messages = []
        if system is None and final_only:
            system = (
                "You are a helpful assistant. Do not reveal your reasoning. "
                "Respond with only the final answer enclosed in \\boxed{...}."
            )
        if system is not None:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})
        formatted_prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        print(f"Formatted prompt: {formatted_prompt}")
    else:
        bos = tokenizer.bos_token or ""
        formatted_prompt = f"{bos}{prompt}"

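    # Tokenize without adding special tokens (the chat template or BOS prefix
    # above already supplies them) and add a leading batch dimension.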
    prompt_tokens = tokenizer.encode(formatted_prompt, add_special_tokens=False)
    prompt_tokens = mx.array([prompt_tokens])

    print(f"Prompt tokens shape: {prompt_tokens.shape}")
    print(
        f"First few token IDs: {prompt_tokens[0, : min(10, prompt_tokens.shape[1])].tolist()}"
    )

    start_time = time.time()
    generated_tokens = []
    freq_counts = {}

    running_text = ""
    seen_box_start = False

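    # Token-by-token decoding loop. The model is re-run on the full sequence at
    # every step (no KV cache), so cost grows with the generated length.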
    for i in range(max_tokens):
        logits = model(prompt_tokens)
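        # Only the logits at the last position are needed for the next token.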
        next_token_logits = logits[0, -1, :]

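        # Repetition penalty: make every previously generated token less likely
        # (divide positive logits by the penalty, multiply negative ones by it).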
        if repetition_penalty and repetition_penalty != 1.0 and generated_tokens:
            logits_list = next_token_logits.tolist()
            seen = set(generated_tokens)
            for tid in seen:
                val = logits_list[tid]
                if val > 0:
                    logits_list[tid] = val / repetition_penalty
                else:
                    logits_list[tid] = val * repetition_penalty
            next_token_logits = mx.array(logits_list)

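        # Frequency penalty: subtract frequency_penalty * count(token) from the
        # logits, so tokens repeated many times become progressively less likely.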
        if frequency_penalty and frequency_penalty > 0 and generated_tokens:
            counts = {}
            for t in generated_tokens:
                counts[t] = counts.get(t, 0) + 1

            vocab_size = next_token_logits.shape[-1]
            pen = [0.0] * vocab_size
            for tid, c in counts.items():
                pen[tid] = frequency_penalty * float(c)
            next_token_logits = next_token_logits - mx.array(pen)

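        # temperature == 0 -> greedy decoding; otherwise sample from the
        # temperature-scaled distribution (optionally truncated by top-p below).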
        if temperature == 0:
            next_token = int(mx.argmax(next_token_logits).item())
        else:
            scaled_logits = next_token_logits / temperature

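            # Top-p (nucleus) filtering: keep the smallest set of highest-probability
            # tokens whose cumulative mass reaches top_p; mask the rest to -inf.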
            if 0.0 < top_p < 1.0:
                probs = mx.softmax(scaled_logits, axis=-1)
                sorted_probs = mx.sort(probs)[::-1]
                cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
                cutoff_index = mx.sum(cumulative_probs < top_p)
                cutoff_prob = sorted_probs[cutoff_index.item()]
                mask = probs >= cutoff_prob
                scaled_logits = mx.where(mask, scaled_logits, float("-inf"))

            next_token = mx.random.categorical(scaled_logits, num_samples=1).item()

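        # Stop on end-of-sequence; eos_token_id may be a single id or a list.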
        eos_ids = tokenizer.eos_token_id
        if isinstance(eos_ids, (list, tuple)):
            stop_ids = set(int(i) for i in eos_ids)
        else:
            stop_ids = {int(eos_ids)}
        if next_token in stop_ids:
            print(f"Stopping generation at EOS token: {next_token}")
            break

        generated_tokens.append(next_token)

        freq_counts[next_token] = freq_counts.get(next_token, 0) + 1

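        # Feed the accepted token back in so the next step conditions on it.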
        prompt_tokens = mx.concatenate(
            [prompt_tokens, mx.array([[next_token]])], axis=1
        )

        if i < 10:
            token_text = tokenizer.decode([next_token])
            print(f"Token {i}: {next_token} -> '{token_text}'")

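        # When --stop-at-boxed is set, accumulate the decoded text and stop as
        # soon as the \boxed{...} answer has been closed.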
        if stop_at_boxed:
            token_text_full = tokenizer.decode([next_token], skip_special_tokens=False)
            running_text += token_text_full
            if not seen_box_start and "\\boxed{" in running_text:
                seen_box_start = True
            # Only a '}' that appears after "\boxed{" closes the answer.
            if seen_box_start and "}" in running_text.split("\\boxed{", 1)[-1]:
                print("Stopping generation at boxed answer.")
                break

    end_time = time.time()

    if generated_tokens:
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        print("\n--- Response ---")
        print(response)
    else:
        print("\n--- No tokens generated ---")

    print("------------------")

    generation_speed = (
        len(generated_tokens) / (end_time - start_time) if generated_tokens else 0
    )
    print(f"Generated {len(generated_tokens)} tokens")
    print(f"Generation speed: {generation_speed:.2f} tokens/sec")

    if generated_tokens:
        full_response = tokenizer.decode(generated_tokens, skip_special_tokens=False)
        print(f"\nFull response (with special tokens): '{full_response}'")

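    # Optionally pull the final answer out of the last \boxed{...} segment.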
    if extract_boxed and generated_tokens:
        import re

        # Keep only the last \boxed{...} match in the decoded response.
        m = None
        for m in re.finditer(r"\\boxed\{([^}]*)\}", full_response):
            pass
        if m:
            print(f"\nExtracted boxed answer: {m.group(1).strip()}")
        else:
            print("\nNo \\boxed{...} segment found to extract.")


def main():
    parser = argparse.ArgumentParser(description="Run inference with the MLX model.")
    parser.add_argument(
        "--model-path", type=str, default=".", help="Path to the model directory."
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="What is the capital of France?",
        help="The prompt to start generation from.",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=100,
        help="The maximum number of tokens to generate.",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.1, help="Sampling temperature."
    )
    parser.add_argument(
        "--top-p", type=float, default=0.9, help="Top-p (nucleus) sampling parameter."
    )
    parser.add_argument(
        "--system", type=str, default=None, help="Optional system message for chat template."
    )
    parser.add_argument(
        "--final-only",
        action="store_true",
        help="Instruct the model to output only the final answer inside \\boxed{...}.",
    )
    parser.add_argument(
        "--stop-at-boxed",
        action="store_true",
        help="Stop generation once a closing '}' appears after \\boxed{.",
    )
    parser.add_argument(
        "--extract-boxed",
        action="store_true",
        help="Extract and print the content inside the last \\boxed{...} in the response.",
    )
    parser.add_argument(
        "--disable-chat-template",
        action="store_true",
        help="Ignore chat_template.jinja and feed the raw prompt (prepended with BOS).",
    )
    parser.add_argument(
        "--repetition-penalty",
        type=float,
        default=1.0,
        help="Penalty (>1.0) to discourage previously generated tokens.",
    )
    parser.add_argument(
        "--frequency-penalty",
        type=float,
        default=0.0,
        help="Subtract alpha * count(token) from logits before sampling.",
    )
    args = parser.parse_args()

    generate_text(
        args.prompt,
        args.model_path,
        args.max_tokens,
        args.temperature,
        args.top_p,
        args.system,
        args.final_only,
        args.stop_at_boxed,
        args.extract_boxed,
        args.disable_chat_template,
        args.repetition_penalty,
        args.frequency_penalty,
    )
if __name__ == "__main__": |
|
|
main() |
|
|
|
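
# Example invocation (script name and model path are illustrative):
#   python generate.py --model-path ./mlx-model --max-tokens 200 \
#       --prompt "What is 17 * 23?" --final-only --stop-at-boxed --extract-boxed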