import transformers import torch import requests import re import gradio as gr from threading import Thread import subprocess import time import atexit try: server_process = subprocess.Popen(["bash", "retrieval_launch.sh"]) print(f"Server process started with PID: {server_process.pid}") # Register a function to kill the server when app.py exits def cleanup(): print("Shutting down retrieval server...") server_process.terminate() server_process.wait() print("Server process terminated.") atexit.register(cleanup) except Exception as e: print(f"Failed to start retrieval_launch.sh: {e}") print("WARNING: The retrieval server may not be running.") # --- Configuration -------------------------------------------------- # 1. DEFINE YOUR MODEL model_id = "yrshi/AutoRefine-Qwen2.5-3B-Base" # 2. !!! CRITICAL: UPDATE THIS URL !!! # Your local 'http://127.0.0.1:8000/retrieve' will NOT work on Hugging Face. # You must deploy your retrieval service and provide its public URL here. RETRIEVER_URL = "http://127.0.0.1:8000/retrieve" # <-- UPDATE ME # 3. MODEL & SEARCH CONSTANTS curr_eos = [151645, 151643] # for Qwen2.5 series models curr_search_template = '\n\n{output_text}{search_results}\n\n' target_sequences = ["", " ", "\n", " \n", "\n\n", " \n\n"] # --- Global Model & Tokenizer Loading ------------------------------- # This happens once when the Space starts. # Ensure your Space has a GPU assigned (e.g., T4, A10G). print("Loading model and tokenizer...") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) model = transformers.AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=torch.bfloat16, device_map="auto" ) print("Model and tokenizer loaded successfully.") # --- Custom Stopping Criteria Class --------------------------------- class StopOnSequence(transformers.StoppingCriteria): def __init__(self, target_sequences, tokenizer): self.target_ids = [tokenizer.encode(target_sequence, add_special_tokens=False) for target_sequence in target_sequences] self.target_lengths = [len(target_id) for target_id in self.target_ids] self._tokenizer = tokenizer def __call__(self, input_ids, scores, **kwargs): targets = [torch.as_tensor(target_id, device=input_ids.device) for target_id in self.target_ids] if input_ids.shape[1] < min(self.target_lengths): return False for i, target in enumerate(targets): if torch.equal(input_ids[0, -self.target_lengths[i]:], target): return True return False # Initialize stopping criteria globally stopping_criteria = transformers.StoppingCriteriaList([StopOnSequence(target_sequences, tokenizer)]) # --- Helper Functions (Search & Parse) ------------------------------ def get_query(text): pattern = re.compile(r"(.*?)", re.DOTALL) matches = pattern.findall(text) return matches[-1] if matches else None def search(query: str): """ Calls your deployed retriever service. """ payload = {"queries": [query], "topk": 3, "return_scores": True} if RETRIEVER_URL == "http://127.0.0.1:8000/retrieve": print("WARNING: Using default local retriever URL. This will likely fail.") print("Please update RETRIEVER_URL in app.py to your deployed service.") try: response = requests.post(RETRIEVER_URL, json=payload, timeout=10) response.raise_for_status() # Raise an error for bad responses results = response.json()['result'] format_reference = '' for idx, doc_item in enumerate(results[0]): content = doc_item['document']['contents'] title = content.split("\n")[0] text = "\n".join(content.split("\n")[1:]) format_reference += f"Doc {idx+1}(Title: {title}) {text}\n" return format_reference except requests.exceptions.RequestException as e: print(f"Error calling retriever: {e}") return f"Error: Could not retrieve search results for query: {query}" except (KeyError, IndexError): print("Error parsing retriever response") return "Error: Malformed response from retriever." # --- Main Gradio 'respond' Function --------------------------------- def respond( message, history: list[dict[str, str]], system_message, # This is now our base prompt max_tokens, temperature, top_p, hf_token: gr.OAuthToken = None, # Not used here, but in template ): """ This function implements your local multi-turn search logic as a streaming generator for the Gradio interface. """ question = message.strip() # Use the system_message from the UI as the base prompt # Or, if empty, use your default. if not system_message: system_message = """You are a helpful assistant excel at answering questions with multi-turn search engine calling. \ To answer questions, you must first reason through the available information using and . \ If you identify missing knowledge, you may issue a search request using query at any time. The retrieval system will provide you with the three most relevant documents enclosed in and . \ After each search, you need to summarize and refine the existing documents in and . \ You may send multiple search requests if needed. \ Once you have sufficient information, provide a concise final answer using and . For example, Donald Trump .""" prompt = f"{system_message} Question: {question}\n" if tokenizer.chat_template: # Apply chat template if it exists # Note: Your logic builds the prompt manually, but this ensures # correct special tokens if the model needs them. chat_prompt = [{"role": "user", "content": prompt}] prompt = tokenizer.apply_chat_template(chat_prompt, add_generation_prompt=True, tokenize=False) # This string will accumulate the full agent trajectory full_response_trajectory = "" while True: input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device) attention_mask = torch.ones_like(input_ids) # Check for context overflow if input_ids.shape[1] > model.config.max_position_embeddings - max_tokens: print("Context limit reached.") full_response_trajectory += "\n\n[Error: Context limit reached. Aborting.]" yield full_response_trajectory break # Generate text with the stopping criteria outputs = model.generate( input_ids, attention_mask=attention_mask, max_new_tokens=max_tokens, stopping_criteria=stopping_criteria, pad_token_id=tokenizer.eos_token_id, do_sample=True, temperature=temperature, top_p=top_p ) # Decode the *newly* generated tokens generated_token_ids = outputs[0][input_ids.shape[1]:] output_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True) # Check if generation ended with an EOS token if outputs[0][-1].item() in curr_eos: full_response_trajectory += output_text yield full_response_trajectory # Yield the final text break # Exit the loop # --- Generation stopped at --- # Get the full text (prompt + new generation) to parse the *last* query full_generation_text = tokenizer.decode(outputs[0], skip_special_tokens=True) query_text = get_query(full_generation_text) if query_text: search_results = search(query_text) else: search_results = 'Error: Stop token found but no query was parsed.' # Construct the text to append to the prompt search_text = curr_search_template.format( output_text=output_text, search_results=search_results ) # Append to the prompt for the next loop prompt += search_text # Append to the trajectory string and yield to the UI full_response_trajectory += search_text yield full_response_trajectory # --- Gradio UI (Example) ------------------------------------------- # This part is just to make the file runnable. # You can customize your Gradio UI as needed. with gr.Blocks() as demo: gr.Markdown("# Multi-Turn Search Agent") gr.Markdown(f"Running model: `{model_id}`") with gr.Accordion("Prompt & Parameters"): system_message = gr.Textbox( label="System Message", value="""You are a helpful assistant... (full prompt from code)""", lines=10 ) max_tokens = gr.Slider(50, 2048, value=1024, label="Max New Tokens") temperature = gr.Slider(0.1, 1.0, value=0.7, label="Temperature") top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p") chatbot = gr.Chatbot(label="Agent Trajectory") msg = gr.Textbox(label="Your Question") def user_turn(user_message, history): return "", history + [[user_message, None]] msg.submit( user_turn, [msg, chatbot], [msg, chatbot], queue=False ).then( respond, [msg, chatbot, system_message, max_tokens, temperature, top_p], chatbot ) if __name__ == "__main__": demo.queue().launch(debug=True)