# AutoRefine / app.py
import transformers
import torch
import requests
import re
import gradio as gr
from threading import Thread
# --- Configuration --------------------------------------------------
# 1. DEFINE YOUR MODEL
model_id = "yrshi/AutoRefine-Qwen2.5-3B-Base"
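# This checkpoint is expected to interleave <think>, <search>, <refine> and
# <answer> tags (see the default prompt in respond() below).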
# 2. !!! CRITICAL: UPDATE THIS URL !!!
# Your local 'http://127.0.0.1:8000/retrieve' will NOT work on Hugging Face.
# You must deploy your retrieval service and provide its public URL here.
RETRIEVER_URL = "http://127.0.0.1:8000/retrieve" # <-- UPDATE ME
# 3. MODEL & SEARCH CONSTANTS
curr_eos = [151645, 151643]  # Qwen2.5 EOS token ids: <|im_end|>, <|endoftext|>
curr_search_template = '\n\n{output_text}<documents>{search_results}</documents>\n\n'
target_sequences = ["</search>", " </search>", "</search>\n", " </search>\n", "</search>\n\n", " </search>\n\n"]
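# Illustrative sketch (hypothetical query) of how one search turn is spliced back
# into the prompt via curr_search_template: the model text ending in </search>
# is followed directly by the retrieved documents, e.g.
#   "...<search> example query </search><documents>Doc 1(Title: ...) ...</documents>"
# target_sequences covers whitespace/newline variants of "</search>" so that
# generation halts right after a search request is emitted.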
# --- Global Model & Tokenizer Loading -------------------------------
# This happens once when the Space starts.
# Ensure your Space has a GPU assigned (e.g., T4, A10G).
print("Loading model and tokenizer...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
model = transformers.AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
device_map="auto"
)
print("Model and tokenizer loaded successfully.")
# --- Custom Stopping Criteria Class ---------------------------------
class StopOnSequence(transformers.StoppingCriteria):
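    """Stopping criterion that halts generation as soon as the most recently
    generated tokens exactly match any target sequence (here, '</search>' variants)."""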
def __init__(self, target_sequences, tokenizer):
self.target_ids = [tokenizer.encode(target_sequence, add_special_tokens=False) for target_sequence in target_sequences]
self.target_lengths = [len(target_id) for target_id in self.target_ids]
self._tokenizer = tokenizer
def __call__(self, input_ids, scores, **kwargs):
targets = [torch.as_tensor(target_id, device=input_ids.device) for target_id in self.target_ids]
if input_ids.shape[1] < min(self.target_lengths):
return False
for i, target in enumerate(targets):
if torch.equal(input_ids[0, -self.target_lengths[i]:], target):
return True
return False
# Initialize stopping criteria globally
stopping_criteria = transformers.StoppingCriteriaList([StopOnSequence(target_sequences, tokenizer)])
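# With this criteria list, each model.generate() call pauses whenever the model
# emits a </search> tag, so the loop in respond() can fetch documents and resume.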
# --- Helper Functions (Search & Parse) ------------------------------
def get_query(text):
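    """Return the contents of the last <search>...</search> block in `text`, or None."""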
pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL)
matches = pattern.findall(text)
return matches[-1] if matches else None
def search(query: str):
"""
Calls your deployed retriever service.
"""
payload = {"queries": [query], "topk": 3, "return_scores": True}
if RETRIEVER_URL == "http://127.0.0.1:8000/retrieve":
print("WARNING: Using default local retriever URL. This will likely fail.")
print("Please update RETRIEVER_URL in app.py to your deployed service.")
try:
response = requests.post(RETRIEVER_URL, json=payload, timeout=10)
response.raise_for_status() # Raise an error for bad responses
results = response.json()['result']
format_reference = ''
for idx, doc_item in enumerate(results[0]):
content = doc_item['document']['contents']
title = content.split("\n")[0]
text = "\n".join(content.split("\n")[1:])
format_reference += f"Doc {idx+1}(Title: {title}) {text}\n"
return format_reference
except requests.exceptions.RequestException as e:
print(f"Error calling retriever: {e}")
return f"Error: Could not retrieve search results for query: {query}"
except (KeyError, IndexError):
print("Error parsing retriever response")
return "Error: Malformed response from retriever."
# --- Main Gradio 'respond' Function ---------------------------------
def respond(
message,
    history,  # chat history from the UI (not used by the agent loop)
system_message, # This is now our base prompt
max_tokens,
temperature,
top_p,
    hf_token: gr.OAuthToken = None,  # not used here; kept from the Space template
):
"""
This function implements your local multi-turn search logic as a
streaming generator for the Gradio interface.
"""
question = message.strip()
# Use the system_message from the UI as the base prompt
# Or, if empty, use your default.
if not system_message:
system_message = """You are a helpful assistant excel at answering questions with multi-turn search engine calling. \
To answer questions, you must first reason through the available information using <think> and </think>. \
If you identify missing knowledge, you may issue a search request using <search> query </search> at any time. The retrieval system will provide you with the three most relevant documents enclosed in <documents> and </documents>. \
After each search, you need to summarize and refine the existing documents in <refine> and </refine>. \
You may send multiple search requests if needed. \
Once you have sufficient information, provide a concise final answer using <answer> and </answer>. For example, <answer> Donald Trump </answer>."""
prompt = f"{system_message} Question: {question}\n"
if tokenizer.chat_template:
# Apply chat template if it exists
# Note: Your logic builds the prompt manually, but this ensures
# correct special tokens if the model needs them.
chat_prompt = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(chat_prompt, add_generation_prompt=True, tokenize=False)
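    # If the checkpoint ships no chat template (common for base models), the raw
    # prompt string built above is used as-is.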
# This string will accumulate the full agent trajectory
full_response_trajectory = ""
while True:
input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
attention_mask = torch.ones_like(input_ids)
# Check for context overflow
if input_ids.shape[1] > model.config.max_position_embeddings - max_tokens:
print("Context limit reached.")
full_response_trajectory += "\n\n[Error: Context limit reached. Aborting.]"
yield full_response_trajectory
break
# Generate text with the stopping criteria
outputs = model.generate(
input_ids,
attention_mask=attention_mask,
max_new_tokens=max_tokens,
stopping_criteria=stopping_criteria,
pad_token_id=tokenizer.eos_token_id,
do_sample=True,
temperature=temperature,
top_p=top_p
)
# Decode the *newly* generated tokens
generated_token_ids = outputs[0][input_ids.shape[1]:]
output_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True)
# Check if generation ended with an EOS token
if outputs[0][-1].item() in curr_eos:
full_response_trajectory += output_text
yield full_response_trajectory # Yield the final text
break # Exit the loop
# --- Generation stopped at </search> ---
# Get the full text (prompt + new generation) to parse the *last* query
full_generation_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
query_text = get_query(full_generation_text)
if query_text:
search_results = search(query_text)
else:
search_results = 'Error: Stop token found but no <search> query was parsed.'
# Construct the text to append to the prompt
search_text = curr_search_template.format(
output_text=output_text,
search_results=search_results
)
# Append to the prompt for the next loop
prompt += search_text
# Append to the trajectory string and yield to the UI
full_response_trajectory += search_text
yield full_response_trajectory
# --- Gradio UI (Example) -------------------------------------------
# This part is just to make the file runnable.
# You can customize your Gradio UI as needed.
with gr.Blocks() as demo:
gr.Markdown("# Multi-Turn Search Agent")
gr.Markdown(f"Running model: `{model_id}`")
with gr.Accordion("Prompt & Parameters"):
        system_message = gr.Textbox(
            label="System Message",
            value="",  # leave empty to fall back to the default prompt defined in respond()
            lines=10
        )
max_tokens = gr.Slider(50, 2048, value=1024, label="Max New Tokens")
temperature = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
chatbot = gr.Chatbot(label="Agent Trajectory")
msg = gr.Textbox(label="Your Question")
    def user_turn(user_message, history):
        # Append the new question to the chat history and clear the textbox.
        return "", history + [[user_message, None]]

    def bot_turn(history, system_message, max_tokens, temperature, top_p):
        # The textbox is already cleared by user_turn, so read the question from
        # the chat history and stream the agent trajectory into the last turn.
        user_message = history[-1][0]
        for partial in respond(user_message, history[:-1], system_message, max_tokens, temperature, top_p):
            history[-1][1] = partial
            yield history

    msg.submit(
        user_turn,
        [msg, chatbot],
        [msg, chatbot],
        queue=False
    ).then(
        bot_turn,
        [chatbot, system_message, max_tokens, temperature, top_p],
        chatbot
    )
if __name__ == "__main__":
demo.queue().launch(debug=True)