File size: 9,162 Bytes
8b4913f
 
 
 
5ba800a
8b4913f
5ba800a
8b4913f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ba800a
 
 
 
8b4913f
5ba800a
 
 
8b4913f
5ba800a
 
8b4913f
 
5ba800a
8b4913f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ba800a
 
8b4913f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ba800a
8b4913f
 
 
 
 
 
 
 
 
 
5ba800a
 
8b4913f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
import transformers
import torch
import requests
import re
import gradio as gr
from threading import Thread

# --- Configuration --------------------------------------------------

# 1. DEFINE YOUR MODEL
# Identifier handed to `from_pretrained` for both the tokenizer and the model.
model_id = "yrshi/AutoRefine-Qwen2.5-3B-Base"

# 2. !!! CRITICAL: UPDATE THIS URL !!!
# Your local 'http://127.0.0.1:8000/retrieve' will NOT work on Hugging Face.
# You must deploy your retrieval service and provide its public URL here.
# `search` POSTs its JSON payload to this endpoint.
RETRIEVER_URL = "http://127.0.0.1:8000/retrieve" # <-- UPDATE ME

# 3. MODEL & SEARCH CONSTANTS
# Token ids that `respond` treats as "generation finished" when one of them is
# the last token of the output.
# NOTE(review): presumably the <|im_end|> / <|endoftext|> ids - confirm against the tokenizer.
curr_eos = [151645, 151643] # for Qwen2.5 series models
# Text appended to the prompt after every search round: the model's latest output
# followed by the retrieved documents wrapped in <documents> tags.
curr_search_template = '\n\n{output_text}<documents>{search_results}</documents>\n\n'
# Variants of the closing search tag (optional leading space, up to two trailing
# newlines). Each variant is tokenized separately by `StopOnSequence`.
target_sequences = ["</search>", " </search>", "</search>\n", " </search>\n", "</search>\n\n", " </search>\n\n"]

# --- Global Model & Tokenizer Loading -------------------------------
# This happens once when the Space starts.
# Ensure your Space has a GPU assigned (e.g., T4, A10G).
# Everything below runs at import time, so importing this module loads the model.

print("Loading model and tokenizer...")
# Device that `respond` moves the encoded prompt tensors to.
# Model placement itself is controlled by `device_map="auto"` below, not by this value.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id, 
    torch_dtype=torch.bfloat16, 
    device_map="auto"
)
print("Model and tokenizer loaded successfully.")

# --- Custom Stopping Criteria Class ---------------------------------

class StopOnSequence(transformers.StoppingCriteria):
    """Stop generation once the generated ids end with one of the target strings.

    Every target string is tokenized once at construction time. On each
    generation step the tail of the first sequence in the batch is compared,
    id by id, against every tokenized target.
    """

    def __init__(self, target_sequences, tokenizer):
        """Tokenize the target strings.

        Args:
            target_sequences: Iterable of strings whose token ids should end
                the generation when they appear at the end of the output.
            tokenizer: Object exposing ``encode(text, add_special_tokens=False)``
                that returns a sequence of token ids.
        """
        encoded = [
            tokenizer.encode(target_sequence, add_special_tokens=False)
            for target_sequence in target_sequences
        ]
        # A target that tokenizes to nothing can never be matched meaningfully
        # (a `-0:` slice would select the whole row), so it is dropped up front.
        self.target_ids = [ids for ids in encoded if len(ids) > 0]
        self.target_lengths = [len(ids) for ids in self.target_ids]
        self._tokenizer = tokenizer
        # device -> list of target tensors, so they are not rebuilt on every step.
        self._target_tensors = {}

    def _targets_on(self, device):
        """Return the target id tensors on ``device``, building them on first use."""
        cached = self._target_tensors.get(device)
        if cached is None:
            cached = [torch.as_tensor(ids, device=device) for ids in self.target_ids]
            self._target_tensors[device] = cached
        return cached

    def __call__(self, input_ids, scores, **kwargs):
        """Return True when the first sequence of ``input_ids`` ends with a target.

        Args:
            input_ids: 2-D tensor of token ids; only row 0 is inspected.
            scores: Unused; accepted for the stopping-criteria call signature.
            **kwargs: Ignored.

        Returns:
            bool: True if any tokenized target equals the tail of row 0.
        """
        # No usable targets: never stop (previously `min()` raised ValueError here).
        if not self.target_ids:
            return False

        seq_len = input_ids.shape[1]
        if seq_len < min(self.target_lengths):
            return False

        row = input_ids[0]
        for length, target in zip(self.target_lengths, self._targets_on(input_ids.device)):
            # Targets longer than the sequence cannot match; skip the comparison.
            if length <= seq_len and torch.equal(row[-length:], target):
                return True
        return False

# Initialize stopping criteria globally
# Built once at import time and passed to `model.generate` inside `respond`,
# so the target strings are tokenized a single time.
stopping_criteria = transformers.StoppingCriteriaList([StopOnSequence(target_sequences, tokenizer)])

# --- Helper Functions (Search & Parse) ------------------------------

def get_query(text):
    """Return the content of the last ``<search>...</search>`` pair in ``text``.

    The content is returned exactly as written (surrounding whitespace and
    newlines included). ``None`` is returned when no complete pair exists.
    """
    last_query = None
    # Non-greedy match with DOTALL so a query may span several lines.
    for found in re.finditer(r"<search>(.*?)</search>", text, flags=re.DOTALL):
        last_query = found.group(1)
    return last_query

def search(query: str):
    """
    Calls your deployed retriever service and formats the returned documents.

    A JSON payload asking for the top 3 documents (with scores) for ``query``
    is POSTed to ``RETRIEVER_URL`` with a 10 second timeout. The first result
    list of the response is rendered as one ``Doc N(Title: ...) ...`` line per
    document, where the title is the first line of the document contents and
    the remainder is the body.

    Args:
        query: Search query text.

    Returns:
        str: The formatted documents, or an ``"Error: ..."`` message when the
        request fails or the response does not have the expected structure.
        This function does not raise for those failures so the caller can feed
        the message back to the model.
    """
    payload = {"queries": [query], "topk": 3, "return_scores": True}

    if RETRIEVER_URL == "http://127.0.0.1:8000/retrieve":
        print("WARNING: Using default local retriever URL. This will likely fail.")
        print("Please update RETRIEVER_URL in app.py to your deployed service.")

    try:
        response = requests.post(RETRIEVER_URL, json=payload, timeout=10)
        response.raise_for_status() # Raise an error for bad responses
        results = response.json()['result']

        formatted_docs = []
        for idx, doc_item in enumerate(results[0], start=1):
            content = doc_item['document']['contents']
            # First line is the title, everything after the first newline is the body.
            title, _, text = content.partition("\n")
            formatted_docs.append(f"Doc {idx}(Title: {title}) {text}\n")
        return "".join(formatted_docs)

    except requests.exceptions.RequestException as e:
        print(f"Error calling retriever: {e}")
        return f"Error: Could not retrieve search results for query: {query}"
    except (KeyError, IndexError, TypeError, AttributeError, ValueError) as e:
        # TypeError/AttributeError: response has the wrong shape (e.g. null or
        # non-dict entries). ValueError: body is not valid JSON on requests
        # versions whose JSON error is not a RequestException.
        print(f"Error parsing retriever response: {e!r}")
        return "Error: Malformed response from retriever."

# --- Main Gradio 'respond' Function ---------------------------------

def respond(
    message,
    history: list[dict[str, str]],
    system_message, # This is now our base prompt
    max_tokens,
    temperature,
    top_p,
    hf_token: gr.OAuthToken = None, # Not used here, but in template
):
    """
    Run the multi-turn search loop for one question as a streaming generator.

    The prompt is built from ``system_message`` and ``message``, then the model
    generates until it emits an end-of-sequence token, closes a ``</search>``
    tag, or runs out of its token budget. After a ``</search>`` stop, the
    query is sent to ``search``, the documents are appended to the prompt and
    generation resumes.

    Args:
        message: The user's question; ``None`` is treated as an empty string.
        history: Previous chat turns. Not used by this function.
        system_message: Base prompt; when empty, a built-in prompt is used.
        max_tokens: Maximum number of new tokens per generation round
            (converted to ``int``).
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus sampling value passed to ``model.generate``.
        hf_token: Unused.

    Yields:
        str: The trajectory accumulated so far (model output plus inserted
        ``<documents>`` blocks), once per round.
    """

    question = (message or "").strip()

    # UI sliders may hand over floats; the token budget must be an integer.
    max_tokens = int(max_tokens)

    # Use the system_message from the UI as the base prompt
    # Or, if empty, use your default.
    if not system_message:
        system_message = """You are a helpful assistant excel at answering questions with multi-turn search engine calling. \
    To answer questions, you must first reason through the available information using <think> and </think>. \
    If you identify missing knowledge, you may issue a search request using <search> query </search> at any time. The retrieval system will provide you with the three most relevant documents enclosed in <documents> and </documents>. \
    After each search, you need to summarize and refine the existing documents in <refine> and </refine>. \
    You may send multiple search requests if needed. \
    Once you have sufficient information, provide a concise final answer using <answer> and </answer>. For example, <answer> Donald Trump </answer>."""

    prompt = f"{system_message} Question: {question}\n"

    if tokenizer.chat_template:
        # Wrap the manually built prompt in the tokenizer's chat template so the
        # model sees the special tokens its template defines.
        chat_prompt = [{"role": "user", "content": prompt}]
        prompt = tokenizer.apply_chat_template(chat_prompt, add_generation_prompt=True, tokenize=False)

    # This string will accumulate the full agent trajectory
    full_response_trajectory = ""

    while True:
        input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
        attention_mask = torch.ones_like(input_ids)
        prompt_length = input_ids.shape[1]

        # Check for context overflow
        if prompt_length > model.config.max_position_embeddings - max_tokens:
            print("Context limit reached.")
            full_response_trajectory += "\n\n[Error: Context limit reached. Aborting.]"
            yield full_response_trajectory
            break

        # Generate text with the stopping criteria
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_tokens,
            stopping_criteria=stopping_criteria,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            temperature=temperature,
            top_p=top_p
        )

        # Decode the *newly* generated tokens
        generated_token_ids = outputs[0][prompt_length:]
        output_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True)

        # Only look at generated tokens for EOS, so an empty generation never
        # inspects the last prompt token by accident.
        hit_eos = (
            generated_token_ids.numel() > 0
            and generated_token_ids[-1].item() in curr_eos
        )
        # Every stop sequence is "</search>" plus optional whitespace.
        stopped_at_search = output_text.rstrip().endswith("</search>")

        # Finished (EOS) or truncated by the token budget: either way there is
        # no pending search, so emit what we have and stop. Treating a
        # truncation as a search stop would re-run a stale query.
        if hit_eos or not stopped_at_search:
            full_response_trajectory += output_text
            yield full_response_trajectory # Yield the final text
            break # Exit the loop

        # --- Generation stopped at </search> ---

        # Parse the query from this round's output only. The prompt itself
        # contains an example "<search> query </search>" that must not be
        # mistaken for a real request.
        query_text = get_query(output_text)

        if query_text and query_text.strip():
            search_results = search(query_text)
        else:
            search_results = 'Error: Stop token found but no <search> query was parsed.'

        # Construct the text to append to the prompt
        search_text = curr_search_template.format(
            output_text=output_text,
            search_results=search_results
        )

        # Append to the prompt for the next loop
        prompt += search_text

        # Append to the trajectory string and yield to the UI
        full_response_trajectory += search_text
        yield full_response_trajectory


# --- Gradio UI (Example) -------------------------------------------
# This part is just to make the file runnable.
# You can customize your Gradio UI as needed.

with gr.Blocks() as demo:
    gr.Markdown("# Multi-Turn Search Agent")
    gr.Markdown(f"Running model: `{model_id}`")

    with gr.Accordion("Prompt & Parameters"):
        # Left empty on purpose: `respond` falls back to its built-in search
        # prompt when this is empty. A non-empty stub here would be sent to the
        # model verbatim instead of the real instructions.
        system_message = gr.Textbox(
            label="System Message",
            value="",
            placeholder="Leave empty to use the built-in multi-turn search prompt.",
            lines=10
        )
        max_tokens = gr.Slider(50, 2048, value=1024, step=1, label="Max New Tokens")
        temperature = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")

    chatbot = gr.Chatbot(label="Agent Trajectory")
    msg = gr.Textbox(label="Your Question")

    def user_turn(user_message, history):
        """Clear the input box and append the question as a new, unanswered turn."""
        return "", (history or []) + [[user_message, None]]

    def bot_turn(history, system_prompt, max_new_tokens, sampling_temperature, nucleus_p):
        """Stream the agent trajectory into the last chat turn.

        The question is read from the chat history rather than from the input
        box, because `user_turn` has already cleared the box by the time this
        step runs. Each string yielded by `respond` replaces the answer slot of
        the last turn, and the whole history is yielded for the Chatbot.
        """
        # Copy turns into mutable lists; the history may arrive as tuples.
        turns = [list(turn) for turn in (history or [])]
        if not turns:
            yield turns
            return

        question = turns[-1][0] or ""
        for partial in respond(
            question,
            turns[:-1],
            system_prompt,
            max_new_tokens,
            sampling_temperature,
            nucleus_p,
        ):
            turns[-1][1] = partial
            yield turns

    msg.submit(
        user_turn,
        [msg, chatbot],
        [msg, chatbot],
        queue=False
    ).then(
        bot_turn,
        [chatbot, system_message, max_tokens, temperature, top_p],
        chatbot
    )

# Start the app only when this file is executed directly: enable the request
# queue on `demo`, then launch it with debug mode on.
if __name__ == "__main__":
    demo.queue().launch(debug=True)