import os
import torch
import gradio as gr
import spaces
import json
import random
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

# Model checkpoint and test-set location; override via the MODEL_ID and
# TEST_JSON_PATH environment variables.
MODEL_ID = os.getenv("MODEL_ID", "theostos/babel-ssreflect-fp8")
TEST_JSON_PATH = os.getenv("TEST_JSON_PATH", "test.json")

INSTRUCTION_TEMPLATE = "You are given a proof term:\n\n{term}\n\nYour task is to derive a sequence of SSReflect tactics that corresponds to this term.\n\nWhen you work through the problem, write down your reasoning in detail inside <think> ... </think> tags. This reasoning should reflect your natural thought process as you explore the structure of the term and figure out what tactics to apply. You should consider different possible approaches, reflect on why some might or might not work, and gradually converge on a tactic choice.\n\nAfter each reasoning block, provide the next (group of) tactic(s) enclosed in:\n\n\\box{{\n  <tactic>\n}}\n\nSome dependencies that could be helpful:\n\n{dependencies}"
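# {term} and {dependencies} are filled in by build_messages() via .format();
# the doubled braces and escaped backslash make \box{ ... } appear literally
# in the rendered prompt.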

HF_TOKEN = os.getenv("HF_TOKEN")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN, use_fast=True)
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

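# Load the model lazily so the heavy checkpoint load happens inside the
# GPU-decorated generate() call below, not at Space startup.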
_model = None
def load_model():
    global _model
    if _model is None:
        _model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN,
            device_map="auto",
            dtype="auto",
            trust_remote_code=True
        )
    return _model

def build_messages(term: str, deps: str):
    instr = INSTRUCTION_TEMPLATE.format(term=term, dependencies=deps)
    return [{"role": "user", "content": instr}]

def load_test_examples(path=TEST_JSON_PATH):
    """
    Expects a JSON list of dicts with keys:
      - 'rocq'
        - 'term'
        - 'notations'
        - 'constants'
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, list):
            raise ValueError("Test set JSON must be a list of objects.")
        
        for entry in data:
            entry['dependencies'] = "\n".join(entry['notations'] + entry['constants'])
            entry['initial_proof'] = "\n".join(entry['steps'])
        print(f"[info] Loaded {len(data)} test examples from {path}")
        return data
    except Exception as e:
        print(f"[warn] Could not load test set {path}: {e}")
        return []
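
# A minimal test.json entry matching the schema expected above (values are
# illustrative only, not taken from the real test set):
#
# [
#   {
#     "term": "fun n : nat => addn0 n",
#     "notations": ["Notation \"m + n\" := (addn m n) : nat_scope."],
#     "constants": ["addn0 : forall n : nat, n + 0 = n"],
#     "steps": ["move=> n.", "by rewrite addn0."]
#   }
# ]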

TEST_EXAMPLES = load_test_examples()

def _duration(term, deps, temperature, top_p, max_new_tokens):
    # Heuristic ZeroGPU time budget: grow with the token budget,
    # clamped to the 60-300 s range.
    return int(min(300, max(60, (int(max_new_tokens) / 2.5) + 30)))

@spaces.GPU(duration=_duration)
def generate(term, deps, temperature, top_p, max_new_tokens):
    model = load_model()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    messages = build_messages(term, deps)

    prompt_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    inputs = tokenizer(prompt_text, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        do_sample=True,
        streamer=streamer,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # model.generate blocks until completion, so run it on a background thread
    # and consume decoded chunks from the streamer as they arrive.
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    out = ""
    for token in streamer:  # stream partial output to the UI as it is decoded
        out += token
        yield f"```rocq\n{out}\n```"
    thread.join()


def _sample_test_example():
    if not TEST_EXAMPLES:
        return "", "", "No test examples loaded. Set TEST_JSON_PATH or add test.json at repo root."
    ex = random.choice(TEST_EXAMPLES)
    return ex['term'], ex['dependencies'], ex['initial_proof']

# Hot-reload the test set from disk without restarting the app
def _reload_test_set():
    global TEST_EXAMPLES
    TEST_EXAMPLES = load_test_examples()
    return gr.update(value=f"Reloaded {len(TEST_EXAMPLES)} test examples from {TEST_JSON_PATH}.")

with gr.Blocks(title="Proof translator (ZeroGPU, FP8)") as demo:
    gr.Markdown(
        "# Vanilla Rocq to SSReflect proof translator\n"
        "Write a proof term, "
        "then write dependencies appearing in the source proof.\n\n"
        "You can also use **🎲 Draw test example** to pull a sample from the test set."
    )

    with gr.Row():
        sample_btn = gr.Button("🎲 Draw test example", variant="secondary")
        reload_test_btn = gr.Button("Reload test set", variant="secondary")

    with gr.Row():
        term_box = gr.Code(
            label="Pretty-printed proof term",
            language=None,
            interactive=True,
            lines=18,
        )
        dep_box = gr.Code(
            label="Dependencies contain in the proof term",
            language=None,
            interactive=True,
            lines=18,
        )

    with gr.Row():
        temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
        max_new = gr.Slider(256, 8192, value=4096, step=128, label="max_new_tokens")

    # Output panels: model vs baseline/truth
    with gr.Row():
        out = gr.Markdown(label="Generated proof")
        baseline = gr.Code(label="Source proof", language=None)
    btn = gr.Button("Translate", variant="primary")

    test_notice = gr.Markdown("")
    sample_btn.click(_sample_test_example, inputs=None, outputs=[term_box, dep_box, baseline])
    reload_test_btn.click(_reload_test_set, inputs=None, outputs=test_notice)
    btn.click(
        generate,
        inputs=[term_box, dep_box, temperature, top_p, max_new],
        outputs=out,
        concurrency_limit=1,
    )

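    # Allow at most one generation at a time on the shared ZeroGPU slot.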
    demo.queue(max_size=20, default_concurrency_limit=1)

if __name__ == "__main__":
    demo.launch()
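
# Example launch with overrides (hypothetical values; any HF model id and any
# JSON file matching the schema in load_test_examples should work):
#   MODEL_ID=some-org/some-model TEST_JSON_PATH=data/test.json python app.py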