# Source: Hugging Face Space file view (Space status: Sleeping; file size:
# 9,162 bytes; revisions 8b4913f / 5ba800a). The page's line-number gutter
# (lines 1-229) was scraped along with the code and has been dropped.
import os
import re
from threading import Thread

import gradio as gr
import requests
import torch
import transformers
# --- Configuration --------------------------------------------------
# 1. MODEL
model_id = "yrshi/AutoRefine-Qwen2.5-3B-Base"

# 2. RETRIEVER ENDPOINT
# The localhost default only works when the retrieval service runs on the
# same machine. On Hugging Face Spaces, set a RETRIEVER_URL variable/secret
# to the public URL of your deployed retrieval service instead of editing
# this file. An unset or blank variable falls back to the local default.
RETRIEVER_URL = (
    os.environ.get("RETRIEVER_URL", "").strip() or "http://127.0.0.1:8000/retrieve"
)

# 3. MODEL & SEARCH CONSTANTS
# Token ids treated as end-of-generation (the original note says these are for
# the Qwen2.5 series; presumably <|im_end|> and <|endoftext|> — verify against
# the tokenizer).
curr_eos = [151645, 151643]
# Text appended after each search round: the round's model output followed by
# the retrieved documents wrapped in <documents> tags.
curr_search_template = '\n\n{output_text}<documents>{search_results}</documents>\n\n'
# Strings that end a generation round. The leading-space / trailing-newline
# variants are listed separately, presumably because each tokenizes differently.
target_sequences = [
    "</search>",
    " </search>",
    "</search>\n",
    " </search>\n",
    "</search>\n\n",
    " </search>\n\n",
]
# --- Global Model & Tokenizer Loading -------------------------------
# Executed once at import time, i.e. when the Space starts. The original
# setup assumes a GPU has been assigned to the Space (e.g. T4, A10G).
print("Loading model and tokenizer...")
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
# device_map="auto" lets transformers/accelerate decide weight placement.
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
print("Model and tokenizer loaded successfully.")
# --- Custom Stopping Criteria Class ---------------------------------
class StopOnSequence(transformers.StoppingCriteria):
    """Stop generation once the sequence ends with any of the target strings.

    Each target string is tokenized once (without special tokens). On every
    decoding step the tail of the FIRST row of ``input_ids`` is compared
    against each tokenized target; only row 0 is inspected, so this is meant
    for batch-size-1 generation.
    """

    def __init__(self, target_sequences, tokenizer):
        """
        Args:
            target_sequences: Strings that should end generation.
            tokenizer: Tokenizer exposing ``encode(text, add_special_tokens=...)``.
        """
        self.target_ids = [
            tokenizer.encode(target_sequence, add_special_tokens=False)
            for target_sequence in target_sequences
        ]
        self.target_lengths = [len(target_id) for target_id in self.target_ids]
        self._tokenizer = tokenizer
        # Target tensors keyed by device, built lazily once per device instead
        # of being re-created on every decoding step.
        self._targets_by_device = {}

    def _targets_on(self, device):
        """Return the tokenized targets as tensors on ``device`` (cached)."""
        targets = self._targets_by_device.get(device)
        if targets is None:
            targets = [torch.as_tensor(ids, device=device) for ids in self.target_ids]
            self._targets_by_device[device] = targets
        return targets

    def __call__(self, input_ids, scores, **kwargs):
        """Return True if row 0 of ``input_ids`` ends with any target, else False."""
        # No targets configured: never stop (min() of an empty list would raise).
        if not self.target_ids:
            return False
        seq_len = input_ids.shape[1]
        if seq_len < min(self.target_lengths):
            return False
        for target, length in zip(self._targets_on(input_ids.device), self.target_lengths):
            # Zero-length targets (strings that tokenize to nothing) never match.
            if length and seq_len >= length and torch.equal(input_ids[0, -length:], target):
                return True
        return False
# Stopping criteria built once at import time and handed to model.generate()
# in respond(); halts a generation round when a </search> variant is emitted.
stopping_criteria = transformers.StoppingCriteriaList(
    [StopOnSequence(target_sequences, tokenizer)]
)
# --- Helper Functions (Search & Parse) ------------------------------
def get_query(text):
    """Return the body of the last ``<search>...</search>`` span in *text*.

    The body is returned verbatim (surrounding whitespace included); the match
    may span newlines. Returns None when *text* contains no complete span.
    """
    found = re.findall(r"<search>(.*?)</search>", text, flags=re.DOTALL)
    if not found:
        return None
    return found[-1]
def search(query: str, topk: int = 3):
    """Query the deployed retriever service and format the hits as documents.

    Sends ``{"queries": [query], "topk": topk, "return_scores": True}`` to
    ``RETRIEVER_URL`` and formats the documents of the first result list.

    Args:
        query: Search string extracted from the model's ``<search>`` tag.
        topk: Number of documents to request. The default of 3 matches the
            default system prompt, which tells the model it receives the three
            most relevant documents.

    Returns:
        One line per document, ``"Doc {i}(Title: {title}) {body}\\n"``, where
        the title is the first line of the document's ``contents`` and the body
        is the rest. On a network/HTTP failure or a malformed response, an
        ``"Error: ..."`` string is returned instead of raising.
    """
    payload = {"queries": [query], "topk": topk, "return_scores": True}
    if RETRIEVER_URL == "http://127.0.0.1:8000/retrieve":
        print("WARNING: Using default local retriever URL. This will likely fail.")
        print("Please update RETRIEVER_URL in app.py to your deployed service.")
    try:
        response = requests.post(RETRIEVER_URL, json=payload, timeout=10)
        response.raise_for_status()  # Raise an error for bad HTTP status codes
        results = response.json()['result']
        documents = []
        for idx, doc_item in enumerate(results[0], start=1):
            content = doc_item['document']['contents']
            # First line is the title; everything after the first newline is the body.
            title, _, text = content.partition("\n")
            documents.append(f"Doc {idx}(Title: {title}) {text}\n")
        return "".join(documents)
    except requests.exceptions.RequestException as e:
        print(f"Error calling retriever: {e}")
        return f"Error: Could not retrieve search results for query: {query}"
    except (KeyError, IndexError, TypeError, AttributeError, ValueError):
        # Unexpected JSON shape (missing keys, null/non-list results, non-string
        # contents) or an undecodable body on requests versions whose JSON error
        # is not a RequestException.
        print("Error parsing retriever response")
        return "Error: Malformed response from retriever."
# --- Main Gradio 'respond' Function ---------------------------------
def respond(
    message,
    history: list[dict[str, str]],
    system_message,  # Base prompt; the built-in default is used when empty
    max_tokens,
    temperature,
    top_p,
    hf_token: gr.OAuthToken = None,  # Not used here, but in template
):
    """Run the multi-turn search agent and stream its trajectory.

    Each round generates until one of three things happens:
    * an id from ``curr_eos`` is emitted: the answer is complete, so stop;
    * ``stopping_criteria`` fires on a closing ``</search>``: the query in this
      round's newest ``<search>...</search>`` is sent to ``search()``, the
      results are appended inside ``<documents>`` tags, and another round runs;
    * ``max_tokens`` new tokens are produced without either: the truncated
      text is returned with a warning instead of issuing a bogus search.

    Args:
        message: The user's question.
        history: Chat history supplied by Gradio; not used by the agent.
        system_message: Instruction prompt; blank means the default below.
        max_tokens: Max new tokens per generation round (cast to int).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        hf_token: Unused.

    Yields:
        The accumulated trajectory text after each round.
    """
    question = message.strip()
    if not (system_message and system_message.strip()):
        system_message = (
            "You are a helpful assistant excel at answering questions with multi-turn search engine calling. "
            "To answer questions, you must first reason through the available information using <think> and </think>. "
            "If you identify missing knowledge, you may issue a search request using <search> query </search> at any time. "
            "The retrieval system will provide you with the three most relevant documents enclosed in <documents> and </documents>. "
            "After each search, you need to summarize and refine the existing documents in <refine> and </refine>. "
            "You may send multiple search requests if needed. "
            "Once you have sufficient information, provide a concise final answer using <answer> and </answer>. "
            "For example, <answer> Donald Trump </answer>."
        )
    prompt = f"{system_message} Question: {question}\n"
    if tokenizer.chat_template:
        # Wrap the manually built prompt in the tokenizer's chat template so
        # the model sees the role/special tokens that template defines.
        chat_prompt = [{"role": "user", "content": prompt}]
        prompt = tokenizer.apply_chat_template(
            chat_prompt, add_generation_prompt=True, tokenize=False
        )

    # UI sliders may deliver floats; generate() needs an integer budget.
    max_new_tokens = int(max_tokens)
    # Accumulates the full agent trajectory shown to the user.
    full_response_trajectory = ""
    while True:
        input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
        attention_mask = torch.ones_like(input_ids)

        # Abort before the prompt plus a full round could exceed the context.
        if input_ids.shape[1] > model.config.max_position_embeddings - max_new_tokens:
            print("Context limit reached.")
            full_response_trajectory += "\n\n[Error: Context limit reached. Aborting.]"
            yield full_response_trajectory
            break

        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            stopping_criteria=stopping_criteria,
            # Stop on every id this loop treats as terminal, not only the
            # model config's default EOS.
            eos_token_id=curr_eos,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
        )

        # Only the tokens generated in this round.
        generated_token_ids = outputs[0][input_ids.shape[1]:]
        output_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True)

        ended_with_eos = (
            generated_token_ids.numel() > 0
            and generated_token_ids[-1].item() in curr_eos
        )
        # The stop sequences are "</search>" plus optional leading space /
        # trailing newlines, so the decoded round ends with "</search>" after rstrip.
        stopped_at_search = output_text.rstrip().endswith("</search>")

        if ended_with_eos or not stopped_at_search:
            full_response_trajectory += output_text
            if not ended_with_eos:
                # Token budget exhausted mid-round: do not treat it as a search.
                full_response_trajectory += (
                    "\n\n[Warning: Max new tokens reached before the model finished.]"
                )
            yield full_response_trajectory
            break

        # --- Generation stopped at </search> ---
        # Parse only this round's text: the prompt itself contains the example
        # "<search> query </search>" and the queries of earlier rounds.
        query_text = (get_query(output_text) or "").strip()
        if query_text:
            search_results = search(query_text)
        else:
            search_results = 'Error: Stop token found but no <search> query was parsed.'

        search_text = curr_search_template.format(
            output_text=output_text,
            search_results=search_results,
        )
        # Feed the round + documents back to the model and show them in the UI.
        prompt += search_text
        full_response_trajectory += search_text
        yield full_response_trajectory
# --- Gradio UI (Example) -------------------------------------------
# Minimal Blocks UI around `respond`. The chatbot holds a list of
# [user_message, bot_message] pairs.
with gr.Blocks() as demo:
    gr.Markdown("# Multi-Turn Search Agent")
    gr.Markdown(f"Running model: `{model_id}`")
    with gr.Accordion("Prompt & Parameters"):
        system_message = gr.Textbox(
            label="System Message",
            # Left empty so respond() falls back to its built-in prompt; any
            # non-empty text here is sent to the model verbatim.
            value="",
            placeholder="Leave empty to use the default multi-turn search prompt.",
            lines=10,
        )
        max_tokens = gr.Slider(50, 2048, value=1024, step=1, label="Max New Tokens")
        temperature = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
    chatbot = gr.Chatbot(label="Agent Trajectory")
    msg = gr.Textbox(label="Your Question")

    def user_turn(user_message, history):
        """Clear the textbox and append the question as a pending chat turn."""
        return "", (history or []) + [[user_message, None]]

    def bot_turn(history, system_prompt, max_new_tokens, temp, nucleus_p):
        """Stream the agent trajectory into the last (pending) chat turn.

        The question is read from the chat history because `user_turn` has
        already cleared the textbox by the time this handler runs. Each string
        yielded by `respond` is placed into a copy of the history as the bot
        reply, since the Chatbot output expects a history list, not a string.
        """
        if not history or not history[-1][0]:
            yield history
            return
        question = history[-1][0]
        for trajectory in respond(
            question, [], system_prompt, max_new_tokens, temp, nucleus_p
        ):
            yield history[:-1] + [[question, trajectory]]

    msg.submit(
        user_turn,
        [msg, chatbot],
        [msg, chatbot],
        queue=False,
    ).then(
        bot_turn,
        [chatbot, system_message, max_tokens, temperature, top_p],
        chatbot,
    )
# Script entry point: enable the request queue (needed for streaming
# generators) and start the app with debug output.
if __name__ == "__main__":
    demo.queue().launch(debug=True)