# MobileLLM-R1-950M-MLX / inference.py
import argparse
import re
import time
from pathlib import Path

import mlx.core as mx
from transformers import AutoTokenizer

from model import load_model
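
# Example invocation (illustrative prompt and settings; all flags are defined in main() below):
#   python inference.py --model-path . --prompt "What is 2+2?" \
#     --max-tokens 256 --temperature 0.0 --final-only --extract-boxed
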
def generate_text(
prompt: str,
model_path: str,
max_tokens: int = 100,
temperature: float = 0.1,
top_p: float = 0.9,
system: str | None = None,
final_only: bool = False,
stop_at_boxed: bool = False,
extract_boxed: bool = False,
disable_chat_template: bool = False,
repetition_penalty: float = 1.0,
frequency_penalty: float = 0.0,
):
"""Generates text using the loaded MLX model with better sampling."""
print("Loading model and tokenizer...")
model = load_model(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Check if we have the chat template
chat_template_path = Path(model_path) / "chat_template.jinja"
use_chat_format = chat_template_path.exists() and not disable_chat_template
print(f"Chat template found: {use_chat_format}")
print("Starting generation...")
print(f"Prompt: {prompt}")
# Format the prompt if using chat template
if use_chat_format:
messages = []
if system is None and final_only:
system = (
"You are a helpful assistant. Do not reveal your reasoning. "
"Respond with only the final answer enclosed in \\boxed{...}."
)
if system is not None:
messages.append({"role": "system", "content": system})
messages.append({"role": "user", "content": prompt})
formatted_prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
print(f"Formatted prompt: {formatted_prompt}")
else:
# No chat template: prepend BOS if available in tokenizer
bos = tokenizer.bos_token or ""
formatted_prompt = f"{bos}{prompt}"
# Tokenize the prompt
prompt_tokens = tokenizer.encode(formatted_prompt, add_special_tokens=False)
prompt_tokens = mx.array([prompt_tokens])
print(f"Prompt tokens shape: {prompt_tokens.shape}")
print(
f"First few token IDs: {prompt_tokens[0, : min(10, prompt_tokens.shape[1])].tolist()}"
)
# Generation loop with better sampling
start_time = time.time()
generated_tokens = []
freq_counts = {}
running_text = ""
seen_box_start = False
for i in range(max_tokens):
# Get logits from model
logits = model(prompt_tokens)
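        # Note: the full prompt + generated tokens are re-fed to the model on
        # every step; with no KV-cache reuse in this loop, per-token cost grows
        # with sequence length.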
# Focus on next-token logits
next_token_logits = logits[0, -1, :]
# Apply repetition and frequency penalties before sampling/argmax
if repetition_penalty and repetition_penalty != 1.0 and generated_tokens:
# Apply a simple repetition penalty to previously generated tokens
# Using HF-like rule: if logit > 0 divide by penalty else multiply by penalty
logits_list = next_token_logits.tolist()
seen = set(generated_tokens)
for tid in seen:
val = logits_list[tid]
if val > 0:
logits_list[tid] = val / repetition_penalty
else:
logits_list[tid] = val * repetition_penalty
next_token_logits = mx.array(logits_list)
        if frequency_penalty and frequency_penalty > 0 and generated_tokens:
            # Subtract a multiple of each token's running count from its logit;
            # freq_counts is maintained incrementally as tokens are generated.
            vocab_size = next_token_logits.shape[-1]
            pen = [0.0] * vocab_size
            for tid, c in freq_counts.items():
                pen[tid] = frequency_penalty * float(c)
            next_token_logits = next_token_logits - mx.array(pen)
# Apply temperature (temperature==0 -> greedy)
if temperature == 0:
# Greedy decode
next_token = int(mx.argmax(next_token_logits).item())
else:
# Sampling path: scale logits, apply top-p mask in logits space
scaled_logits = next_token_logits / temperature
if 0.0 < top_p < 1.0:
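                # Nucleus filtering: cutoff_index counts tokens whose cumulative
                # sorted probability is still below top_p, so the most probable
                # token always survives the mask below.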
probs = mx.softmax(scaled_logits, axis=-1)
sorted_probs = mx.sort(probs)[::-1]
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
cutoff_index = mx.sum(cumulative_probs < top_p)
cutoff_prob = sorted_probs[cutoff_index.item()]
mask = probs >= cutoff_prob
scaled_logits = mx.where(mask, scaled_logits, float("-inf"))
# Sample from logits (MLX categorical expects logits)
next_token = mx.random.categorical(scaled_logits, num_samples=1).item()
        # Safer stop condition: support a single EOS id, a list of ids, or none
        eos_ids = tokenizer.eos_token_id
        if isinstance(eos_ids, (list, tuple)):
            stop_ids = set(int(i) for i in eos_ids)
        elif eos_ids is not None:
            stop_ids = {int(eos_ids)}
        else:
            stop_ids = set()
if next_token in stop_ids:
print(f"Stopping generation at EOS token: {next_token}")
break
generated_tokens.append(next_token)
# Update frequency counts
freq_counts[next_token] = freq_counts.get(next_token, 0) + 1
# Append the new token for the next iteration
prompt_tokens = mx.concatenate(
[prompt_tokens, mx.array([[next_token]])], axis=1
)
# Print token as we generate for debugging
if i < 10: # Only print first 10 tokens to avoid spam
token_text = tokenizer.decode([next_token])
print(f"Token {i}: {next_token} -> '{token_text}'")
# Optional boxed stopping condition
if stop_at_boxed:
token_text_full = tokenizer.decode([next_token], skip_special_tokens=False)
running_text += token_text_full
if not seen_box_start and "\\boxed{" in running_text:
seen_box_start = True
if seen_box_start and "}" in running_text:
print("Stopping generation at boxed answer.")
break
end_time = time.time()
# Decode and print the result
if generated_tokens:
response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
print("\n--- Response ---")
print(response)
else:
print("\n--- No tokens generated ---")
print("------------------")
generation_speed = (
len(generated_tokens) / (end_time - start_time) if generated_tokens else 0
)
print(f"Generated {len(generated_tokens)} tokens")
print(f"Generation speed: {generation_speed:.2f} tokens/sec")
# Also print the full generated sequence including special tokens for debugging
if generated_tokens:
full_response = tokenizer.decode(generated_tokens, skip_special_tokens=False)
print(f"\nFull response (with special tokens): '{full_response}'")
    if extract_boxed and generated_tokens:
        m = None
        # Keep the last occurrence of \boxed{...}
        for m in re.finditer(r"\\boxed\{([^}]*)\}", full_response):
            pass
if m:
print(f"\nExtracted boxed answer: {m.group(1).strip()}")
else:
print("\nNo \\boxed{...} segment found to extract.")


def main():
parser = argparse.ArgumentParser(description="Run inference with the MLX model.")
parser.add_argument(
"--model-path", type=str, default=".", help="Path to the model directory."
)
parser.add_argument(
"--prompt",
type=str,
default="What is the capital of France?",
help="The prompt to start generation from.",
)
parser.add_argument(
"--max-tokens",
type=int,
default=100,
help="The maximum number of tokens to generate.",
)
parser.add_argument(
"--temperature", type=float, default=0.1, help="Sampling temperature."
)
parser.add_argument(
"--top-p", type=float, default=0.9, help="Top-p (nucleus) sampling parameter."
)
parser.add_argument(
"--system", type=str, default=None, help="Optional system message for chat template."
)
parser.add_argument(
"--final-only",
action="store_true",
help="Instruct the model to output only the final answer inside \\boxed{...}.",
)
parser.add_argument(
"--stop-at-boxed",
action="store_true",
help="Stop generation once a closing '}' appears after \\boxed{.",
)
parser.add_argument(
"--extract-boxed",
action="store_true",
help="Extract and print the content inside the last \\boxed{...} in the response.",
)
parser.add_argument(
"--disable-chat-template",
action="store_true",
help="Ignore chat_template.jinja and feed the raw prompt (prepended with BOS).",
)
parser.add_argument(
"--repetition-penalty",
type=float,
default=1.0,
help="Penalty (>1.0) to discourage previously generated tokens.",
)
parser.add_argument(
"--frequency-penalty",
type=float,
default=0.0,
help="Subtract alpha * count(token) from logits before sampling.",
)
args = parser.parse_args()
generate_text(
args.prompt,
args.model_path,
args.max_tokens,
args.temperature,
args.top_p,
args.system,
args.final_only,
args.stop_at_boxed,
args.extract_boxed,
args.disable_chat_template,
args.repetition_penalty,
args.frequency_penalty,
)


if __name__ == "__main__":
main()