import transformers
import torch
import requests
import re
import gradio as gr
from threading import Thread
# --- Configuration --------------------------------------------------

# 1. DEFINE YOUR MODEL
model_id = "yrshi/AutoRefine-Qwen2.5-3B-Base"

# 2. !!! CRITICAL: UPDATE THIS URL !!!
# Your local 'http://127.0.0.1:8000/retrieve' will NOT work on Hugging Face.
# You must deploy your retrieval service and provide its public URL here.
RETRIEVER_URL = "http://127.0.0.1:8000/retrieve"  # <-- UPDATE ME
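# The retrieval service is assumed to expose a POST endpoint that accepts
# {"queries": [...], "topk": N, "return_scores": true} and returns JSON of
# roughly the shape sketched below. This is inferred from the parsing logic
# in search() further down, not from any particular server implementation:
#
#   {
#     "result": [                                   # one list per query
#       [
#         {"document": {"contents": "Title line\nBody text ..."},
#          "score": 0.87},                          # score is returned but unused here
#         ...
#       ]
#     ]
#   }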
# 3. MODEL & SEARCH CONSTANTS
# EOS token ids for the Qwen2.5 series (<|im_end|> and <|endoftext|>).
curr_eos = [151645, 151643]
curr_search_template = '\n\n{output_text}<documents>{search_results}</documents>\n\n'
target_sequences = ["</search>", " </search>", "</search>\n", " </search>\n", "</search>\n\n", " </search>\n\n"]
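# For illustration only: when generation stops at a </search> sequence, one
# iteration of the loop in respond() appends text of roughly this shape to the
# prompt (query and documents below are placeholders, not real output):
#
#   <think> I need the capital of France. </think>
#   <search> capital of France </search>
#   <documents>Doc 1(Title: Paris) Paris is the capital of France ...</documents>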
# --- Global Model & Tokenizer Loading -------------------------------
# This happens once when the Space starts.
# Ensure your Space has a GPU assigned (e.g., T4, A10G).
print("Loading model and tokenizer...")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
print("Model and tokenizer loaded successfully.")
# --- Custom Stopping Criteria Class ---------------------------------
class StopOnSequence(transformers.StoppingCriteria):
    """Stops generation once the output ends with any target sequence (e.g. "</search>")."""

    def __init__(self, target_sequences, tokenizer):
        self.target_ids = [tokenizer.encode(target_sequence, add_special_tokens=False) for target_sequence in target_sequences]
        self.target_lengths = [len(target_id) for target_id in self.target_ids]
        self._tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        targets = [torch.as_tensor(target_id, device=input_ids.device) for target_id in self.target_ids]
        if input_ids.shape[1] < min(self.target_lengths):
            return False
        for i, target in enumerate(targets):
            if torch.equal(input_ids[0, -self.target_lengths[i]:], target):
                return True
        return False


# Initialize stopping criteria globally
stopping_criteria = transformers.StoppingCriteriaList([StopOnSequence(target_sequences, tokenizer)])
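# Note: target_sequences (defined above) lists several whitespace variants of
# "</search>" because, with a BPE tokenizer, a leading space or trailing
# newline maps to different token ids, and StopOnSequence compares exact
# token-id suffixes rather than decoded text.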
# --- Helper Functions (Search & Parse) ------------------------------
def get_query(text):
    pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL)
    matches = pattern.findall(text)
    return matches[-1] if matches else None


def search(query: str):
    """
    Calls your deployed retriever service.
    """
    payload = {"queries": [query], "topk": 3, "return_scores": True}

    if RETRIEVER_URL == "http://127.0.0.1:8000/retrieve":
        print("WARNING: Using default local retriever URL. This will likely fail.")
        print("Please update RETRIEVER_URL in app.py to your deployed service.")

    try:
        response = requests.post(RETRIEVER_URL, json=payload, timeout=10)
        response.raise_for_status()  # Raise an error for bad responses
        results = response.json()['result']

        format_reference = ''
        for idx, doc_item in enumerate(results[0]):
            content = doc_item['document']['contents']
            title = content.split("\n")[0]
            text = "\n".join(content.split("\n")[1:])
            format_reference += f"Doc {idx+1}(Title: {title}) {text}\n"
        return format_reference
    except requests.exceptions.RequestException as e:
        print(f"Error calling retriever: {e}")
        return f"Error: Could not retrieve search results for query: {query}"
    except (KeyError, IndexError):
        print("Error parsing retriever response")
        return "Error: Malformed response from retriever."
# --- Main Gradio 'respond' Function ---------------------------------
def respond(
    message,
    history: list[dict[str, str]],
    system_message,  # This is now our base prompt
    max_tokens,
    temperature,
    top_p,
    hf_token: gr.OAuthToken = None,  # Not used here, but kept from the template
):
    """
    Implements the multi-turn search loop as a streaming generator for the
    Gradio interface.
    """
    question = message.strip()

    # Use the system_message from the UI as the base prompt,
    # or fall back to the default below if it is empty.
    if not system_message:
        system_message = """You are a helpful assistant excel at answering questions with multi-turn search engine calling. \
To answer questions, you must first reason through the available information using <think> and </think>. \
If you identify missing knowledge, you may issue a search request using <search> query </search> at any time. The retrieval system will provide you with the three most relevant documents enclosed in <documents> and </documents>. \
After each search, you need to summarize and refine the existing documents in <refine> and </refine>. \
You may send multiple search requests if needed. \
Once you have sufficient information, provide a concise final answer using <answer> and </answer>. For example, <answer> Donald Trump </answer>."""

    prompt = f"{system_message} Question: {question}\n"

    if tokenizer.chat_template:
        # Apply the chat template if the model defines one.
        # Note: the search loop below builds the prompt manually, but this
        # ensures correct special tokens if the model needs them.
        chat_prompt = [{"role": "user", "content": prompt}]
        prompt = tokenizer.apply_chat_template(chat_prompt, add_generation_prompt=True, tokenize=False)

    # This string will accumulate the full agent trajectory
    full_response_trajectory = ""

    while True:
        input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
        attention_mask = torch.ones_like(input_ids)

        # Check for context overflow
        if input_ids.shape[1] > model.config.max_position_embeddings - max_tokens:
            print("Context limit reached.")
            full_response_trajectory += "\n\n[Error: Context limit reached. Aborting.]"
            yield full_response_trajectory
            break

        # Generate text with the stopping criteria
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_tokens,
            stopping_criteria=stopping_criteria,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            temperature=temperature,
            top_p=top_p
        )

        # Decode the *newly* generated tokens
        generated_token_ids = outputs[0][input_ids.shape[1]:]
        output_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True)

        # Check if generation ended with an EOS token
        if outputs[0][-1].item() in curr_eos:
            full_response_trajectory += output_text
            yield full_response_trajectory  # Yield the final text
            break  # Exit the loop

        # --- Generation stopped at </search> ---
        # Decode the full text (prompt + new generation) to parse the *last* query
        full_generation_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        query_text = get_query(full_generation_text)

        if query_text:
            search_results = search(query_text)
        else:
            search_results = 'Error: Stop token found but no <search> query was parsed.'

        # Construct the text to append to the prompt
        search_text = curr_search_template.format(
            output_text=output_text,
            search_results=search_results
        )

        # Append to the prompt for the next loop
        prompt += search_text

        # Append to the trajectory string and yield to the UI
        full_response_trajectory += search_text
        yield full_response_trajectory
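# Quick sanity check without the UI (hypothetical question, shown only as an
# illustration; it runs the full search loop, so it needs a reachable
# RETRIEVER_URL):
#
#   for partial in respond("Who wrote The Old Man and the Sea?", [], "", 1024, 0.7, 1.0):
#       pass
#   print(partial)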
# --- Gradio UI (Example) -------------------------------------------
# This part is just to make the file runnable.
# You can customize your Gradio UI as needed.
with gr.Blocks() as demo:
    gr.Markdown("# Multi-Turn Search Agent")
    gr.Markdown(f"Running model: `{model_id}`")

    with gr.Accordion("Prompt & Parameters"):
        system_message = gr.Textbox(
            label="System Message (leave empty to use the built-in default prompt)",
            value="",
            lines=10
        )
        max_tokens = gr.Slider(50, 2048, value=1024, label="Max New Tokens")
        temperature = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")

    chatbot = gr.Chatbot(label="Agent Trajectory")
    msg = gr.Textbox(label="Your Question")
    def user_turn(user_message, history):
        # Stash the user message in the chat history and clear the textbox.
        return "", history + [[user_message, None]]

    def bot_turn(history, system_message, max_tokens, temperature, top_p):
        # The textbox is already cleared by user_turn, so read the question
        # back out of the history and stream the trajectory into the last turn.
        question = history[-1][0]
        for trajectory in respond(question, history[:-1], system_message, max_tokens, temperature, top_p):
            history[-1][1] = trajectory
            yield history

    msg.submit(
        user_turn,
        [msg, chatbot],
        [msg, chatbot],
        queue=False
    ).then(
        bot_turn,
        [chatbot, system_message, max_tokens, temperature, top_p],
        chatbot
    )
if __name__ == "__main__":
    demo.queue().launch(debug=True)