AutoRefine / app.py
yrshi's picture
update some files
de48493
raw
history blame
9.77 kB
import transformers
import torch
import requests
import re
import gradio as gr
from threading import Thread
import subprocess
import time
import atexit
# Launch the retrieval backend as a child process. Startup is deliberately
# best-effort: on failure we only warn so the UI can still come up.
try:
    server_process = subprocess.Popen(["bash", "retrieval_launch.sh"])
except Exception as e:
    print(f"Failed to start retrieval_launch.sh: {e}")
    print("WARNING: The retrieval server may not be running.")
else:
    print(f"Server process started with PID: {server_process.pid}")

    # Register a function to kill the server when app.py exits
    def cleanup():
        """Stop the retrieval server, force-killing it if it ignores SIGTERM."""
        print("Shutting down retrieval server...")
        if server_process.poll() is None:
            server_process.terminate()
            try:
                # Bounded wait so a stuck server cannot hang interpreter exit.
                server_process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                server_process.kill()
                server_process.wait()
        print("Server process terminated.")

    atexit.register(cleanup)
# --- Configuration --------------------------------------------------
# 1. Hugging Face Hub id of the checkpoint served by this app.
model_id = "yrshi/AutoRefine-Qwen2.5-3B-Base"

# 2. !!! CRITICAL: UPDATE THIS URL !!!
# The local 'http://127.0.0.1:8000/retrieve' address will NOT work on
# Hugging Face; deploy the retrieval service and put its public URL here.
RETRIEVER_URL = "http://127.0.0.1:8000/retrieve"  # <-- UPDATE ME

# 3. Model & search constants.
# Token ids treated as end-of-sequence (for Qwen2.5 series models).
curr_eos = [
    151645,
    151643,
]

# Text appended to the prompt after each search: the model's own output
# followed by the retrieved documents wrapped in <documents> tags.
curr_search_template = '\n\n{output_text}<documents>{search_results}</documents>\n\n'

# Spellings of the closing search tag that end a generation step.
target_sequences = [
    "</search>",
    " </search>",
    "</search>\n",
    " </search>\n",
    "</search>\n\n",
    " </search>\n\n",
]
# --- Global Model & Tokenizer Loading -------------------------------
# Executed a single time, when the Space starts.
# Ensure your Space has a GPU assigned (e.g., T4, A10G).
print("Loading model and tokenizer...")

# Prefer CUDA when present; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)

# Weights are loaded in bfloat16 with automatic device placement.
_model_load_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "auto"}
model = transformers.AutoModelForCausalLM.from_pretrained(model_id, **_model_load_kwargs)

print("Model and tokenizer loaded successfully.")
# --- Custom Stopping Criteria Class ---------------------------------
class StopOnSequence(transformers.StoppingCriteria):
    """Stop generation once the sequence ends with any target's token ids.

    Only the first row of ``input_ids`` is inspected, so the criterion is
    meant for single-sequence generation.
    """

    def __init__(self, target_sequences, tokenizer):
        """Tokenize every target string once, without special tokens.

        Args:
            target_sequences: Strings whose token ids mark a stop point.
            tokenizer: Object providing
                ``encode(text, add_special_tokens=False)``.
        """
        self.target_ids = [
            tokenizer.encode(target_sequence, add_special_tokens=False)
            for target_sequence in target_sequences
        ]
        self.target_lengths = [len(target_id) for target_id in self.target_ids]
        self._tokenizer = tokenizer
        # device -> list of target tensors. This criterion runs after every
        # generated token, so the tensors are built once per device instead
        # of being re-created (and re-copied) on each call.
        self._targets_by_device = {}

    def _targets_on(self, device):
        """Return the target tensors for *device*, creating them on first use."""
        targets = self._targets_by_device.get(device)
        if targets is None:
            targets = [torch.as_tensor(ids, device=device) for ids in self.target_ids]
            self._targets_by_device[device] = targets
        return targets

    def __call__(self, input_ids, scores, **kwargs):
        """Return True when the first sequence ends with one of the targets.

        Args:
            input_ids: 2-D tensor of token ids; only row 0 is examined.
            scores: Unused; accepted for interface compatibility.
            **kwargs: Unused.
        """
        # With no targets there is nothing to match (and min() would raise).
        if not self.target_lengths:
            return False
        if input_ids.shape[1] < min(self.target_lengths):
            return False
        first_row = input_ids[0]
        for length, target in zip(self.target_lengths, self._targets_on(input_ids.device)):
            # A target longer than the sequence yields a shorter slice, which
            # torch.equal reports as unequal.
            if torch.equal(first_row[-length:], target):
                return True
        return False
# Build the stopping criteria a single time at module load for reuse
_search_stop = StopOnSequence(target_sequences, tokenizer)
stopping_criteria = transformers.StoppingCriteriaList([_search_stop])
# --- Helper Functions (Search & Parse) ------------------------------
def get_query(text):
    """Return the body of the last ``<search>...</search>`` pair in *text*.

    Matching is non-greedy and spans newlines. The captured text is returned
    verbatim, surrounding whitespace included. Returns ``None`` when *text*
    holds no complete pair.
    """
    found = re.findall(r"<search>(.*?)</search>", text, flags=re.DOTALL)
    if not found:
        return None
    return found[-1]
def search(query: str):
    """Query the retriever service and format its hits as prompt text.

    Posts *query* to ``RETRIEVER_URL`` asking for the top 3 documents. For
    each returned document, the first line of its contents is used as the
    title and the remainder as the body.

    Args:
        query: Search string to send to the retriever.

    Returns:
        One ``Doc N(Title: ...) ...`` line per document, concatenated; or a
        string starting with ``Error:`` when the request fails or the
        response does not have the expected shape. Never raises for those
        cases, so the caller's generation loop keeps running.
    """
    payload = {"queries": [query], "topk": 3, "return_scores": True}
    if RETRIEVER_URL == "http://127.0.0.1:8000/retrieve":
        print("WARNING: Using default local retriever URL. This will likely fail.")
        print("Please update RETRIEVER_URL in app.py to your deployed service.")
    try:
        response = requests.post(RETRIEVER_URL, json=payload, timeout=10)
        response.raise_for_status()  # Raise an error for bad responses
        results = response.json()['result']
        formatted_docs = []
        for idx, doc_item in enumerate(results[0], start=1):
            content = doc_item['document']['contents']
            # First line is the title; everything after it is the body.
            title, _, text = content.partition("\n")
            formatted_docs.append(f"Doc {idx}(Title: {title}) {text}\n")
        return "".join(formatted_docs)
    except requests.exceptions.RequestException as e:
        print(f"Error calling retriever: {e}")
        return f"Error: Could not retrieve search results for query: {query}"
    except (KeyError, IndexError, TypeError, ValueError, AttributeError):
        # TypeError/AttributeError: payload parts are None or the wrong type.
        # ValueError: body is not valid JSON on requests versions whose JSON
        # error is not a RequestException.
        print("Error parsing retriever response")
        return "Error: Malformed response from retriever."
# --- Main Gradio 'respond' Function ---------------------------------
def respond(
    message,
    history: list[dict[str, str]],
    system_message,  # This is now our base prompt
    max_tokens,
    temperature,
    top_p,
    hf_token: gr.OAuthToken = None,  # Not used here, but in template
):
    """Run the multi-turn search loop for one question, streaming the result.

    The model generates until it emits EOS, closes a ``</search>`` tag, or
    runs out of its per-step token budget. After a ``</search>`` the query is
    sent to ``search`` and the documents are appended to the prompt before
    generating again.

    Args:
        message: The user's question.
        history: Unused; accepted for interface compatibility.
        system_message: Base prompt; a built-in default is used when empty.
        max_tokens: Per-step ``max_new_tokens`` budget.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        hf_token: Unused.

    Yields:
        The cumulative agent trajectory (generated text plus inserted
        ``<documents>`` blocks) after each generation step.
    """
    question = message.strip()
    # Coerce in case the UI delivers a float; generate() needs an int budget.
    max_tokens = int(max_tokens)

    # Use the system_message from the UI as the base prompt,
    # or fall back to the default when it is empty.
    if not system_message:
        system_message = (
            "You are a helpful assistant excel at answering questions with multi-turn search engine calling. "
            "To answer questions, you must first reason through the available information using <think> and </think>. "
            "If you identify missing knowledge, you may issue a search request using <search> query </search> at any time. The retrieval system will provide you with the three most relevant documents enclosed in <documents> and </documents>. "
            "After each search, you need to summarize and refine the existing documents in <refine> and </refine>. "
            "You may send multiple search requests if needed. "
            "Once you have sufficient information, provide a concise final answer using <answer> and </answer>. For example, <answer> Donald Trump </answer>."
        )
    prompt = f"{system_message} Question: {question}\n"
    if tokenizer.chat_template:
        # Wrap the manual prompt in the chat template so the model sees the
        # special tokens it expects.
        chat_prompt = [{"role": "user", "content": prompt}]
        prompt = tokenizer.apply_chat_template(chat_prompt, add_generation_prompt=True, tokenize=False)

    # This string will accumulate the full agent trajectory
    full_response_trajectory = ""
    while True:
        input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
        attention_mask = torch.ones_like(input_ids)

        # Check for context overflow
        if input_ids.shape[1] > model.config.max_position_embeddings - max_tokens:
            print("Context limit reached.")
            full_response_trajectory += "\n\n[Error: Context limit reached. Aborting.]"
            yield full_response_trajectory
            break

        # Generate text with the stopping criteria
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_tokens,
            stopping_criteria=stopping_criteria,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            temperature=temperature,
            top_p=top_p
        )

        # Decode the *newly* generated tokens
        generated_token_ids = outputs[0][input_ids.shape[1]:]
        output_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True)

        # Check if generation ended with an EOS token. Only generated tokens
        # are examined so a prompt token can never be mistaken for EOS.
        ended_with_eos = (
            generated_token_ids.numel() > 0
            and generated_token_ids[-1].item() in curr_eos
        )
        if ended_with_eos:
            full_response_trajectory += output_text
            yield full_response_trajectory  # Yield the final text
            break  # Exit the loop

        # Neither EOS nor a closing search tag: the step hit max_new_tokens.
        # Stop here instead of issuing a search the model never asked for.
        if not output_text.rstrip().endswith("</search>"):
            print("Max new tokens reached.")
            full_response_trajectory += output_text
            full_response_trajectory += "\n\n[Error: Max new tokens reached before completion. Aborting.]"
            yield full_response_trajectory
            break

        # --- Generation stopped at </search> ---
        # Parse only the new text: the prompt itself contains an example
        # "<search> query </search>" and possibly earlier queries, which must
        # not be re-issued.
        query_text = get_query(output_text)
        query_text = query_text.strip() if query_text else ""
        if query_text:
            search_results = search(query_text)
        else:
            search_results = 'Error: Stop token found but no <search> query was parsed.'

        # Construct the text to append to the prompt
        search_text = curr_search_template.format(
            output_text=output_text,
            search_results=search_results
        )
        # Append to the prompt for the next loop
        prompt += search_text
        # Append to the trajectory string and yield to the UI
        full_response_trajectory += search_text
        yield full_response_trajectory
# --- Gradio UI (Example) -------------------------------------------
# Minimal Blocks UI: a question box, a chatbot pane showing the agent
# trajectory, and an accordion with the prompt and sampling parameters.
with gr.Blocks() as demo:
    gr.Markdown("# Multi-Turn Search Agent")
    gr.Markdown(f"Running model: `{model_id}`")
    with gr.Accordion("Prompt & Parameters"):
        # Empty by default so `respond` falls back to its built-in prompt;
        # a non-empty stub here would silently replace that prompt.
        system_message = gr.Textbox(
            label="System Message",
            value="",
            placeholder="Leave empty to use the built-in default prompt",
            lines=10
        )
        max_tokens = gr.Slider(50, 2048, value=1024, label="Max New Tokens")
        temperature = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
    chatbot = gr.Chatbot(label="Agent Trajectory")
    msg = gr.Textbox(label="Your Question")

    def user_turn(user_message, history):
        """Clear the input box and append the question as a pending turn."""
        return "", (history or []) + [[user_message, None]]

    def bot_turn(history, system_message, max_tokens, temperature, top_p):
        """Stream the agent trajectory into the last chat turn.

        The question is read from the history because `user_turn` has
        already cleared the textbox by the time this step runs.
        """
        question = history[-1][0]
        for trajectory in respond(
            question, history[:-1], system_message, max_tokens, temperature, top_p
        ):
            history[-1] = [question, trajectory]
            yield history

    msg.submit(
        user_turn,
        [msg, chatbot],
        [msg, chatbot],
        queue=False
    ).then(
        bot_turn,
        [chatbot, system_message, max_tokens, temperature, top_p],
        chatbot
    )
# Script entry point: enable request queuing, then serve the UI in debug mode.
if __name__ == "__main__":
    queued_demo = demo.queue()
    queued_demo.launch(debug=True)