"""Gradio app that translates vanilla Rocq proof terms into SSReflect tactics.

The model runs inside a ZeroGPU-decorated function (``spaces.GPU``) and its
output is streamed into the UI chunk by chunk.
"""

import os
import torch
import gradio as gr
import spaces
import json
import random
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

# Model checkpoint and test-set location; both can be overridden through
# environment variables.
MODEL_ID = os.getenv("MODEL_ID", "theostos/babel-ssreflect-fp8")
TEST_JSON_PATH = os.getenv("TEST_JSON_PATH", "test.json")
# Prompt sent to the model. It is filled with str.format, so literal braces
# are written doubled ("{{" / "}}").
INSTRUCTION_TEMPLATE = "You are given a proof term:\n\n{term}\n\nYour task is to derive a sequence of SSReflect tactics that corresponds to this term.\n\nWhen you work through the problem, write down your reasoning in detail inside ... tags. This reasoning should reflect your natural thought process as you explore the structure of the term and figure out what tactics to apply. You should consider different possible approaches, reflect on why some might or might not work, and gradually converge on a tactic choice.\n\nAfter each reasoning block, provide the next (group of) tactic(s) enclosed in:\n\n\\box{{\n \n}}\n\nSome dependencies that could be helpful:\n\n{dependencies}"
HF_TOKEN = os.getenv("HF_TOKEN")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN, use_fast=True)
# generate() is given pad_token_id explicitly; fall back to EOS when the
# tokenizer does not define a pad token.
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

_model = None


def load_model():
    """Return the causal LM, loading it on the first call and caching it.

    Loading is deferred until the first call (made from ``generate``) instead
    of happening at import time.
    """
    global _model
    if _model is None:
        _model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN,
            device_map="auto",
            dtype="auto",
            trust_remote_code=True,
        )
    return _model


def build_messages(term: str, deps: str):
    """Build the single-turn chat message list for ``term`` and ``deps``."""
    instr = INSTRUCTION_TEMPLATE.format(term=term, dependencies=deps)
    return [{"role": "user", "content": instr}]


def load_test_examples(path=TEST_JSON_PATH):
    """Load the test set used by the "Draw test example" button.

    Expects a JSON list of objects with at least the keys:
      - 'term':      the pretty-printed proof term
      - 'notations': notation dependencies (sequence of str)
      - 'constants': constant dependencies (sequence of str)
      - 'steps':     the source proof, one item per step (sequence of str)

    Each entry is augmented in place with:
      - 'dependencies':  notations followed by constants, one per line
      - 'initial_proof': the steps joined with newlines

    Returns the list, or ``[]`` after printing a warning if the file cannot
    be read or does not have the expected shape, so the app still starts
    without a test set.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, list):
            raise ValueError("Test set JSON must be a list of objects.")
        for entry in data:
            # Join both lists in a single pass: joining them separately and
            # concatenating the results would glue the last notation and the
            # first constant onto the same line.
            entry['dependencies'] = "\n".join([*entry['notations'], *entry['constants']])
            entry['initial_proof'] = "\n".join(entry['steps'])
        print(f"[info] Loaded {len(data)} test examples from {path}")
        return data
    except Exception as e:
        # Best effort: a missing or malformed test set must not stop the app.
        print(f"[warn] Could not load test set {path}: {e}")
        return []


TEST_EXAMPLES = load_test_examples()


def _duration(term, deps, temperature, top_p, max_new_tokens):
    """Return the GPU time budget, in seconds, for one ``generate`` call.

    Takes the same arguments as ``generate``. The budget grows with
    ``max_new_tokens`` and is clamped to the range [60, 300].
    """
    return int(min(300, max(60, (int(max_new_tokens) / 2.5) + 30)))


@spaces.GPU(duration=_duration)
def generate(term, deps, temperature, top_p, max_new_tokens):
    """Stream the model's SSReflect translation of ``term``.

    After every decoded chunk, yields the whole output so far wrapped in a
    ```rocq Markdown code fence.

    A temperature of 0 (or below) selects greedy decoding, because
    ``transformers`` rejects sampling with a non-positive temperature.
    An exception raised by the background generation thread is re-raised
    here once the stream is closed, instead of leaving the loop waiting
    forever.
    """
    model = load_model()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    messages = build_messages(term, deps)
    prompt_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt_text, return_tensors="pt").to(device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    temperature = float(temperature)
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        streamer=streamer,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=float(top_p))
    else:
        gen_kwargs.update(do_sample=False)

    errors = []

    def _run_generation():
        # On failure, record the error and close the stream; otherwise the
        # streamer never receives its stop signal and the loop below blocks.
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_run_generation)
    thread.start()

    out = ""
    for token in streamer:  # stream decoded text to the UI
        out += token
        yield f"```rocq\n{out}\n```"

    thread.join()
    if errors:
        raise errors[0]


def _sample_test_example():
    """Return (term, dependencies, source proof) for a random test example.

    If no test set is loaded, the first two values are empty and the third
    holds a hint for the user.
    """
    if not TEST_EXAMPLES:
        return "", "", "No test examples loaded. Set TEST_JSON_PATH or add test.json at repo root."
    ex = random.choice(TEST_EXAMPLES)
    return ex['term'], ex['dependencies'], ex['initial_proof']


def _reload_test_set():
    """Re-read the test set and return a status message reporting its size."""
    global TEST_EXAMPLES
    TEST_EXAMPLES = load_test_examples()
    return gr.update(value=f"Reloaded {len(TEST_EXAMPLES)} test examples from {TEST_JSON_PATH}.")


with gr.Blocks(title="Proof translator (ZeroGPU, FP8)") as demo:
    gr.Markdown(
        "# Vanilla Rocq to SSReflect proof translator\n"
        "Write a proof term, "
        "then write dependencies appearing in the source proof.\n\n"
        "You can also use **🎲 Draw test example** to pull a sample from the test set."
    )

    with gr.Row():
        sample_btn = gr.Button("🎲 Draw test example", variant="secondary")
        reload_test_btn = gr.Button("Reload test set", variant="secondary")

    with gr.Row():
        term_box = gr.Code(
            label="Pretty-printed proof term",
            language=None,
            interactive=True,
            lines=18,
        )
        dep_box = gr.Code(
            label="Dependencies contained in the proof term",
            language=None,
            interactive=True,
            lines=18,
        )

    with gr.Row():
        temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
        max_new = gr.Slider(256, 8192, value=4096, step=128, label="max_new_tokens")

    # Output panels: model output next to the source proof of a drawn example.
    with gr.Row():
        out = gr.Markdown(label="Generated proof")
        baseline = gr.Code(label="Source proof", language=None)

    btn = gr.Button("Translate", variant="primary")
    test_notice = gr.Markdown("")

    sample_btn.click(_sample_test_example, inputs=None, outputs=[term_box, dep_box, baseline])
    reload_test_btn.click(_reload_test_set, inputs=None, outputs=test_notice)
    btn.click(
        generate,
        inputs=[term_box, dep_box, temperature, top_p, max_new],
        outputs=out,
        concurrency_limit=1,
    )

demo.queue(max_size=20, default_concurrency_limit=1)

if __name__ == "__main__":
    demo.launch()