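"""Token-by-token text generation with an MLX model.

Loads the model via load_model() and a Hugging Face tokenizer, optionally formats
the prompt with the model directory's chat template, and decodes either greedily
(temperature 0) or with temperature / top-p sampling, with optional repetition and
frequency penalties, plus helpers for stopping at and extracting a final boxed answer.
"""
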
import argparse
import time
from pathlib import Path

import mlx.core as mx
from transformers import AutoTokenizer

from model import load_model


def generate_text(
    prompt: str,
    model_path: str,
    max_tokens: int = 100,
    temperature: float = 0.1,
    top_p: float = 0.9,
    system: str | None = None,
    final_only: bool = False,
    stop_at_boxed: bool = False,
    extract_boxed: bool = False,
    disable_chat_template: bool = False,
    repetition_penalty: float = 1.0,
    frequency_penalty: float = 0.0,
):
    """Generates text using the loaded MLX model with better sampling."""
    print("Loading model and tokenizer...")
    model = load_model(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Check if we have the chat template
    chat_template_path = Path(model_path) / "chat_template.jinja"
    use_chat_format = chat_template_path.exists() and not disable_chat_template

    print(f"Chat template found: {use_chat_format}")
    print("Starting generation...")
    print(f"Prompt: {prompt}")

    # Format the prompt if using chat template
    if use_chat_format:
        messages = []
        if system is None and final_only:
            system = (
                "You are a helpful assistant. Do not reveal your reasoning. "
                "Respond with only the final answer enclosed in \\boxed{...}."
            )
        if system is not None:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})
        formatted_prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        print(f"Formatted prompt: {formatted_prompt}")
    else:
        # No chat template: prepend BOS if available in tokenizer
        bos = tokenizer.bos_token or ""
        formatted_prompt = f"{bos}{prompt}"

    # Tokenize the prompt
    prompt_tokens = tokenizer.encode(formatted_prompt, add_special_tokens=False)
    prompt_tokens = mx.array([prompt_tokens])

    print(f"Prompt tokens shape: {prompt_tokens.shape}")
    print(
        f"First few token IDs: {prompt_tokens[0, : min(10, prompt_tokens.shape[1])].tolist()}"
    )

    # Token-by-token generation loop (greedy or sampled, with optional penalties)
    start_time = time.time()
    generated_tokens = []
    freq_counts = {}

    running_text = ""
    seen_box_start = False
    for i in range(max_tokens):
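        # Note: the full prompt + generated-so-far sequence is re-run through the
        # model on every step (no KV cache), so per-token cost grows with length.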
        # Get logits from model
        logits = model(prompt_tokens)

        # Focus on next-token logits
        next_token_logits = logits[0, -1, :]

        # Apply repetition and frequency penalties before sampling/argmax
        if repetition_penalty and repetition_penalty != 1.0 and generated_tokens:
            # Apply a simple repetition penalty to previously generated tokens
            # Using HF-like rule: if logit > 0 divide by penalty else multiply by penalty
            logits_list = next_token_logits.tolist()
            seen = set(generated_tokens)
            for tid in seen:
                val = logits_list[tid]
                if val > 0:
                    logits_list[tid] = val / repetition_penalty
                else:
                    logits_list[tid] = val * repetition_penalty
            next_token_logits = mx.array(logits_list)

        if frequency_penalty and frequency_penalty > 0 and generated_tokens:
            # Subtract frequency_penalty * count(token) from each previously
            # generated token's logit, reusing the running freq_counts updated below
            vocab_size = next_token_logits.shape[-1]
            pen = [0.0] * vocab_size
            for tid, c in freq_counts.items():
                pen[tid] = frequency_penalty * float(c)
            next_token_logits = next_token_logits - mx.array(pen)

        # Apply temperature (temperature==0 -> greedy)
        if temperature == 0:
            # Greedy decode
            next_token = int(mx.argmax(next_token_logits).item())
        else:
            # Sampling path: scale logits, apply top-p mask in logits space
            scaled_logits = next_token_logits / temperature

            if 0.0 < top_p < 1.0:
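                # Nucleus (top-p) filtering: keep the smallest set of tokens whose
                # cumulative probability reaches top_p; everything below the cutoff
                # probability is masked to -inf before sampling.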
                probs = mx.softmax(scaled_logits, axis=-1)
                sorted_probs = mx.sort(probs)[::-1]
                cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
                cutoff_index = mx.sum(cumulative_probs < top_p)
                cutoff_prob = sorted_probs[cutoff_index.item()]
                mask = probs >= cutoff_prob
                scaled_logits = mx.where(mask, scaled_logits, float("-inf"))

            # Sample from logits (MLX categorical expects logits)
            next_token = mx.random.categorical(scaled_logits, num_samples=1).item()

        # Safer stop condition: support multiple EOS ids
        eos_ids = tokenizer.eos_token_id
        if isinstance(eos_ids, (list, tuple)):
            stop_ids = set(int(i) for i in eos_ids)
        elif eos_ids is not None:
            stop_ids = {int(eos_ids)}
        else:
            stop_ids = set()
        if next_token in stop_ids:
            print(f"Stopping generation at EOS token: {next_token}")
            break

        generated_tokens.append(next_token)
        # Update frequency counts
        freq_counts[next_token] = freq_counts.get(next_token, 0) + 1
        # Append the new token for the next iteration
        prompt_tokens = mx.concatenate(
            [prompt_tokens, mx.array([[next_token]])], axis=1
        )

        # Print token as we generate for debugging
        if i < 10:  # Only print first 10 tokens to avoid spam
            token_text = tokenizer.decode([next_token])
            print(f"Token {i}: {next_token} -> '{token_text}'")

        # Optional boxed stopping condition
        if stop_at_boxed:
            token_text_full = tokenizer.decode([next_token], skip_special_tokens=False)
            running_text += token_text_full
            if not seen_box_start and "\\boxed{" in running_text:
                seen_box_start = True
            if seen_box_start and "}" in running_text.rsplit("\\boxed{", 1)[-1]:
                print("Stopping generation at boxed answer.")
                break

    end_time = time.time()

    # Decode and print the result
    if generated_tokens:
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        print("\n--- Response ---")
        print(response)
    else:
        print("\n--- No tokens generated ---")

    print("------------------")

    generation_speed = (
        len(generated_tokens) / (end_time - start_time) if generated_tokens else 0
    )
    print(f"Generated {len(generated_tokens)} tokens")
    print(f"Generation speed: {generation_speed:.2f} tokens/sec")

    # Also print the full generated sequence including special tokens for debugging
    if generated_tokens:
        full_response = tokenizer.decode(generated_tokens, skip_special_tokens=False)
        print(f"\nFull response (with special tokens): '{full_response}'")

    if extract_boxed and generated_tokens:
        import re
        m = None
        # Keep the match for the last occurrence of \boxed{...}
        for m in re.finditer(r"\\boxed\{([^}]*)\}", full_response):
            pass
        if m:
            print(f"\nExtracted boxed answer: {m.group(1).strip()}")
        else:
            print("\nNo \\boxed{...} segment found to extract.")


def main():
    parser = argparse.ArgumentParser(description="Run inference with the MLX model.")
    parser.add_argument(
        "--model-path", type=str, default=".", help="Path to the model directory."
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="What is the capital of France?",
        help="The prompt to start generation from.",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=100,
        help="The maximum number of tokens to generate.",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.1, help="Sampling temperature."
    )
    parser.add_argument(
        "--top-p", type=float, default=0.9, help="Top-p (nucleus) sampling parameter."
    )
    parser.add_argument(
        "--system", type=str, default=None, help="Optional system message for chat template."
    )
    parser.add_argument(
        "--final-only",
        action="store_true",
        help="Instruct the model to output only the final answer inside \\boxed{...}.",
    )
    parser.add_argument(
        "--stop-at-boxed",
        action="store_true",
        help="Stop generation once a closing '}' appears after \\boxed{.",
    )
    parser.add_argument(
        "--extract-boxed",
        action="store_true",
        help="Extract and print the content inside the last \\boxed{...} in the response.",
    )
    parser.add_argument(
        "--disable-chat-template",
        action="store_true",
        help="Ignore chat_template.jinja and feed the raw prompt (prepended with BOS).",
    )
    parser.add_argument(
        "--repetition-penalty",
        type=float,
        default=1.0,
        help="Penalty (>1.0) to discourage previously generated tokens.",
    )
    parser.add_argument(
        "--frequency-penalty",
        type=float,
        default=0.0,
        help="Subtract alpha * count(token) from logits before sampling.",
    )
    args = parser.parse_args()

    generate_text(
        args.prompt,
        args.model_path,
        args.max_tokens,
        args.temperature,
        args.top_p,
        args.system,
        args.final_only,
        args.stop_at_boxed,
        args.extract_boxed,
        args.disable_chat_template,
        args.repetition_penalty,
        args.frequency_penalty,
    )


if __name__ == "__main__":
    main()
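

# Example invocations (script and model directory names are illustrative):
#   python generate.py --model-path ./my-mlx-model --prompt "What is 17 * 12?" --final-only --extract-boxed
#   python generate.py --model-path ./my-mlx-model --prompt "Tell me a joke." --temperature 0.7 --max-tokens 200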