import gradio as gr
from huggingface_hub import InferenceClient  # Keep for direct use if needed, though agent will use its own model
import os
import json
import base64
from PIL import Image
import io

# Smolagents imports
from smolagents import CodeAgent, Tool
from smolagents.models import InferenceClientModel as SmolInferenceClientModel
# We'll use PIL.Image directly for opening; AgentImage is for agent's internal typing if needed by a tool
from smolagents.gradio_ui import pull_messages_from_step  # For formatting agent steps
from smolagents.memory import ActionStep, FinalAnswerStep, PlanningStep, MemoryStep  # For type checking steps
from smolagents.models import ChatMessageStreamDelta  # For type checking stream deltas

ACCESS_TOKEN = os.getenv("HF_TOKEN")
print("Access token loaded.")
# Function to encode image to base64 (remains useful if we ever need to pass base64 to a non-smolagents component)
def encode_image(image_path_or_pil):
    if not image_path_or_pil:
        print("No image path or PIL Image provided")
        return None
    try:
        # print(f"Encoding image: {type(image_path_or_pil)}")  # Debug
        if isinstance(image_path_or_pil, Image.Image):
            image = image_path_or_pil
        else:  # Assuming it's a path
            image = Image.open(image_path_or_pil)
        if image.mode == 'RGBA':
            image = image.convert('RGB')
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")  # JPEG is generally smaller for transfer
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        # print("Image encoded successfully")  # Debug
        return img_str
    except Exception as e:
        print(f"Error encoding image: {e}")
        return None
# This function sets up and runs the smolagents agent
def respond(
    message_text,                 # Text from MultimodalTextbox
    image_file_paths,             # List of file paths from MultimodalTextbox
    gradio_history: list[tuple[str, str]],  # Gradio history (for context if needed; agent is stateless per call here)
    system_message_for_agent,     # System prompt for the main LLM agent
    max_tokens,
    temperature,
    top_p,
    frequency_penalty,
    seed,
    provider_for_agent_llm,
    api_key_for_agent_llm,
    model_id_for_agent_llm,
    model_search_term,            # Unused directly by agent logic
    selected_model_for_agent_llm  # Fallback model ID
):
| print(f"Respond function called. Message: '{message_text}', Images: {image_file_paths}") | |
| token_to_use = api_key_for_agent_llm if api_key_for_agent_llm.strip() != "" else ACCESS_TOKEN | |
| model_to_use = model_id_for_agent_llm.strip() if model_id_for_agent_llm.strip() != "" else selected_model_for_agent_llm | |
| # --- Initialize the LLM for the CodeAgent --- | |
| agent_llm_params = { | |
| "model_id": model_to_use, | |
| "token": token_to_use, | |
| # smolagents's InferenceClientModel uses max_tokens for max_new_tokens | |
| "max_tokens": max_tokens, | |
| "temperature": temperature if temperature > 0.01 else None, # Some models require temp > 0 | |
| "top_p": top_p if top_p < 1.0 else None, # Often 1.0 means no top_p | |
| "seed": seed if seed != -1 else None, | |
| } | |
| if provider_for_agent_llm and provider_for_agent_llm != "hf-inference": | |
| agent_llm_params["provider"] = provider_for_agent_llm | |
| # HFIC specific params, add if not default and supported | |
| if frequency_penalty != 0.0: | |
| agent_llm_params["frequency_penalty"] = frequency_penalty | |
| agent_llm = SmolInferenceClientModel(**agent_llm_params) | |
| print(f"Smolagents LLM for agent initialized: model='{model_to_use}', provider='{provider_for_agent_llm or 'default'}'") | |
    # --- Define Tools for the Agent ---
    agent_tools = []
    try:
        image_gen_tool = Tool.from_space(
            space_id="black-forest-labs/FLUX.1-schnell",
            name="image_generator",
            description="Generates an image from a textual prompt. Input is a single string argument named 'prompt'. Output is an image file path.",
            token=token_to_use
        )
        agent_tools.append(image_gen_tool)
        print("Image generation tool loaded: black-forest-labs/FLUX.1-schnell")
    except Exception as e:
        print(f"Error loading image generation tool: {e}")
        yield f"Error: Could not load image generation tool. {e}"
        return
    # --- Initialize the CodeAgent ---
    # If system_message_for_agent is empty, CodeAgent will use its default prompt,
    # which is usually good as it explains how to use tools.
    agent = CodeAgent(
        tools=agent_tools,
        model=agent_llm,
        system_prompt=system_message_for_agent if system_message_for_agent and system_message_for_agent.strip() else None,
        # add_base_tools=True,  # Consider adding the Python interpreter, etc.
        stream_outputs=True  # Important for Gradio streaming
    )
    print("Smolagents CodeAgent initialized.")
    # --- Prepare task and image inputs for the agent ---
    agent_task_text = message_text
    pil_images_for_agent = []
    if image_file_paths:
        for file_path in image_file_paths:
            try:
                pil_images_for_agent.append(Image.open(file_path))
            except Exception as e:
                print(f"Error opening image file {file_path} for agent: {e}")
    print(f"Agent task: '{agent_task_text}'")
    if pil_images_for_agent:
        print(f"Passing {len(pil_images_for_agent)} image(s) to agent.")
    # --- Run the agent and stream the response ---
    # The agent is reset each turn. For conversational memory, the agent instance
    # would need to be stored in session state and agent.run(..., reset=False) used.
    current_agent_response_text = ""
    try:
        # agent.run returns a generator when stream=True
        step_item = None  # Guard against an empty generator for the checks after the loop
        for step_item in agent.run(
            task=agent_task_text,
            images=pil_images_for_agent,
            stream=True,
            reset=True  # Explicitly reset for stateless operation per call
        ):
            if isinstance(step_item, ChatMessageStreamDelta):
                if step_item.content:
                    current_agent_response_text += step_item.content
                    yield current_agent_response_text  # Yield the accumulated text
            elif isinstance(step_item, (ActionStep, PlanningStep, FinalAnswerStep)):
                # A structured step: format it for Gradio.
                # pull_messages_from_step yields gr.ChatMessage objects.
                for gradio_chat_msg in pull_messages_from_step(step_item, skip_model_outputs=agent.stream_outputs):
                    # The 'bot' function will handle these gr.ChatMessage objects.
                    yield gradio_chat_msg  # Yield the gr.ChatMessage object directly
                current_agent_response_text = ""  # Reset the text buffer after a structured step
            # else:
            #     print(f"Unhandled stream item type: {type(step_item)}")  # Debug

        # Yield any remaining text that was not folded into a gr.ChatMessage. This normally
        # does not happen, but covers the case where the agent's final textual answer
        # arrives as pure deltas after all structured steps.
        if current_agent_response_text and not isinstance(step_item, FinalAnswerStep):
            # Skip if the last yielded item already contains this text
            if not (isinstance(step_item, gr.ChatMessage) and step_item.content == current_agent_response_text):
                yield current_agent_response_text
    except Exception as e:
        error_message = f"Error during agent execution: {str(e)}"
        print(error_message)
        yield error_message  # Surface the error message in the UI

    print("Agent run completed.")
# Function to validate provider selection based on BYOK
def validate_provider(api_key, provider):
    if not api_key.strip() and provider != "hf-inference":
        return gr.update(value="hf-inference")
    return gr.update(value=provider)
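
# The UI wiring further down references filter_models and set_custom_model_from_radio,
# which are not defined elsewhere in this file. The functions below are minimal sketches
# of what that wiring assumes: filtering the featured-model radio choices and copying a
# radio selection into the custom-model textbox. Adjust as needed.
def filter_models(search_term):
    # models_list is assigned inside the gr.Blocks context below; `with` does not create
    # a new scope, so the name is available at module level by the time this is called.
    filtered = [m for m in models_list if search_term.lower() in m.lower()]
    return gr.update(choices=filtered)

def set_custom_model_from_radio(selected_model):
    # Mirror the selected featured model into the custom-model textbox
    return selected_model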
# GRADIO UI
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    chatbot = gr.Chatbot(
        height=600,
        show_copy_button=True,
        placeholder="Select a model and begin chatting. Now uses smolagents with tools!",
        layout="panel",
        bubble_full_width=False  # For better display of images/files
    )
    print("Chatbot interface created.")

    msg = gr.MultimodalTextbox(
        placeholder="Type a message or upload images...",
        show_label=False,
        container=False,
        scale=12,
        file_types=["image"],
        file_count="multiple",
        sources=["upload"]
    )
    with gr.Accordion("Settings", open=False):
        system_message_box = gr.Textbox(
            value="You are a helpful AI assistant. You can generate images if asked. Be precise with your prompts for image generation.",
            placeholder="You are a helpful AI assistant.",
            label="System Prompt for Agent"
        )
        with gr.Row():
            with gr.Column():
                max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="Max New Tokens")
                temperature_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.01, label="Temperature")
                top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="Top-P")
            with gr.Column():
                frequency_penalty_slider = gr.Slider(minimum=-2.0, maximum=2.0, value=0.0, step=0.1, label="Frequency Penalty")
                seed_slider = gr.Slider(minimum=-1, maximum=65535, value=-1, step=1, label="Seed (-1 for random)")

        providers_list = [
            "hf-inference", "cerebras", "together", "sambanova", "novita",
            "cohere", "fireworks-ai", "hyperbolic", "nebius",
        ]
        provider_radio = gr.Radio(choices=providers_list, value="hf-inference", label="Inference Provider for Agent's LLM")
        byok_textbox = gr.Textbox(value="", label="BYOK (Your HF Token or Provider API Key)", info="Enter API key for the selected provider. Uses HF_TOKEN if empty.", placeholder="Enter your API token", type="password")
        custom_model_box = gr.Textbox(value="", label="Custom Model ID for Agent's LLM", info="(Optional) Provide a custom model ID. Overrides the featured model.", placeholder="meta-llama/Llama-3.3-70B-Instruct")
        model_search_box = gr.Textbox(label="Filter Featured Models", placeholder="Search for a featured model...", lines=1)

        models_list = [
            "meta-llama/Llama-3.3-70B-Instruct", "meta-llama/Llama-3.1-70B-Instruct", "meta-llama/Llama-3.0-70B-Instruct",
            "meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.2-3B-Instruct", "meta-llama/Llama-3.2-1B-Instruct",
            "meta-llama/Llama-3.1-8B-Instruct", "NousResearch/Hermes-3-Llama-3.1-8B", "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
            "mistralai/Mistral-Nemo-Instruct-2407", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.3",
            "Qwen/Qwen3-235B-A22B", "Qwen/Qwen3-32B", "Qwen/Qwen2.5-72B-Instruct", "Qwen/Qwen2.5-3B-Instruct",
            "Qwen/Qwen2.5-Coder-32B-Instruct", "microsoft/Phi-3.5-mini-instruct", "microsoft/Phi-3-mini-128k-instruct",
        ]
        featured_model_radio = gr.Radio(label="Select a Featured Model for Agent's LLM", choices=models_list, value="meta-llama/Llama-3.3-70B-Instruct", interactive=True)
        gr.Markdown("[View all Text-to-Text models](https://huggingface.co/models?inference_provider=all&pipeline_tag=text-generation&sort=trending) | [View all multimodal models](https://huggingface.co/models?inference_provider=all&pipeline_tag=image-text-to-text&sort=trending)")
    # Chat history: the chatbot's value itself is the displayed history.
    # A separate gr.State would be needed if the agent should stay conversational
    # across turns; for now the agent is stateless per turn.
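    # A minimal sketch (not wired up here; names are illustrative) of keeping the agent
    # conversational across turns:
    #   agent_state = gr.State(None)   # one agent instance per browser session
    #   # inside the respond logic: create the agent only when agent_state is empty,
    #   # otherwise reuse it and call agent.run(task, images=..., stream=True, reset=False)
    #   # so the agent keeps its memory between messages.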
    # Chat interface callbacks
    def user(user_multimodal_input_dict, history):
        print(f"User input: {user_multimodal_input_dict}")
        text_content = user_multimodal_input_dict.get("text", "")
        files = user_multimodal_input_dict.get("files", [])
        user_display_parts = []
        if text_content and text_content.strip():
            user_display_parts.append(text_content)
        for file_path_obj in files:  # file_path_obj is a tempfile._TemporaryFileWrapper
            user_display_parts.append((file_path_obj.name, os.path.basename(file_path_obj.name)))
        if not user_display_parts:
            return history
        # Append the user's multimodal message to history for display: a plain string,
        # or a list mixing text and (path, filename) tuples. `bot` parses this back out.
        history.append([user_display_parts if len(user_display_parts) > 1 else user_display_parts[0], None])
        return history
    def bot(history, system_msg, max_tokens, temperature, top_p, freq_penalty, seed, provider, api_key, custom_model, search_term, selected_model):
        if not history or not history[-1][0]:  # No user input
            yield history
            return

        # The user's turn is in history[-1][0]. The `user` function stores display parts
        # (a string, a (path, name) tuple, or a list mixing both) rather than the raw
        # MultimodalTextbox dict, so normalize whatever is there into text plus image paths.
        user_input_data = history[-1][0]
        if isinstance(user_input_data, dict):
            # Raw MultimodalTextbox dict (in case `user` is ever changed to store it directly)
            text_input_for_agent = user_input_data.get("text", "")
            # Files from MultimodalTextbox are temp-file objects with a .name path
            image_file_paths_for_agent = [f.name for f in user_input_data.get("files", []) if hasattr(f, 'name')]
        else:
            parts = user_input_data if isinstance(user_input_data, list) else [user_input_data]
            text_input_for_agent = " ".join(p for p in parts if isinstance(p, str))
            image_file_paths_for_agent = [p[0] for p in parts if isinstance(p, tuple)]

        history[-1][1] = ""  # Initialize the assistant's part for streaming
        # Buffer for the current text stream from the agent.
        # Handles both pure text deltas and text content from gr.ChatMessage objects.
        current_text_for_turn = ""
        for item in respond(
            message_text=text_input_for_agent,
            image_file_paths=image_file_paths_for_agent,
            gradio_history=history[:-1],  # Pass previous turns for context if the agent uses it
            system_message_for_agent=system_msg,
            max_tokens=max_tokens, temperature=temperature, top_p=top_p,
            frequency_penalty=freq_penalty, seed=seed,
            provider_for_agent_llm=provider, api_key_for_agent_llm=api_key,
            model_id_for_agent_llm=custom_model,
            model_search_term=search_term,  # unused
            selected_model_for_agent_llm=selected_model
        ):
            if isinstance(item, str):  # Text delta from the agent's thoughts or its textual answer
                current_text_for_turn = item
                history[-1][1] = current_text_for_turn
            elif isinstance(item, gr.ChatMessage):
                # A structured step (thought, tool output, image, etc.).
                # Text content replaces the current assistant text; files are appended
                # to the assistant message as (path, alt_text) tuples.
                if isinstance(item.content, str):
                    current_text_for_turn = item.content  # A full message, so replace rather than append
                elif isinstance(item.content, dict) and "path" in item.content:
                    # Typically an image or audio file produced by a tool.
                    file_path = item.content["path"]
                    # Gradio needs an absolute path or URL it is allowed to serve; tool outputs
                    # may be temp files, so a common pattern is to copy them into a directory
                    # Gradio serves (or use gr.File). Here the local path is handed to the
                    # Chatbot as a (path, alt_text) tuple, which means history[-1][1] must
                    # become a list.
                    if current_text_for_turn and not isinstance(history[-1][1], list):
                        history[-1][1] = [current_text_for_turn]
                    elif not current_text_for_turn and not isinstance(history[-1][1], list):
                        history[-1][1] = []
                    alt_text = item.metadata.get("title", os.path.basename(file_path)) if item.metadata else os.path.basename(file_path)
                    # Add the file as a new component of the current assistant message
                    if isinstance(history[-1][1], list):
                        history[-1][1].append((file_path, alt_text))
                    else:  # Should already be a list by this point
                        history[-1][1] = [(file_path, alt_text)]
                    current_text_for_turn = ""  # Reset the text buffer after a file
                # If the message was plain text (no file turned this into a list), replace the current text
                if not isinstance(history[-1][1], list):
                    history[-1][1] = current_text_for_turn
            yield history
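
    # A hedged sketch (not used above) of making tool-produced temp files servable by Gradio:
    # copy them into a local folder and allow Gradio to serve that folder at launch.
    #   import shutil
    #   os.makedirs("agent_outputs", exist_ok=True)
    #   served_path = shutil.copy(file_path, "agent_outputs/")
    #   # ...and start the app with demo.launch(allowed_paths=["agent_outputs"])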
    # Event handlers: when msg is submitted,
    # 1. `user` appends the user's input to history (updating the chatbot display),
    # 2. `bot` streams the agent's response into the chatbot,
    # 3. the MultimodalTextbox is cleared.
    msg.submit(
        user,
        [msg, chatbot],
        [chatbot],  # `user` returns the new history, updating the chatbot display
        queue=False
    ).then(
        bot,
        [chatbot, system_message_box, max_tokens_slider, temperature_slider, top_p_slider,
         frequency_penalty_slider, seed_slider, provider_radio, byok_textbox, custom_model_box,
         model_search_box, featured_model_radio],
        [chatbot]  # `bot` yields history updates, streaming to the chatbot
    ).then(
        lambda: {"text": "", "files": []},  # Clear the MultimodalTextbox
        None,
        [msg]
    )

    model_search_box.change(fn=filter_models, inputs=model_search_box, outputs=featured_model_radio)
    featured_model_radio.change(fn=set_custom_model_from_radio, inputs=featured_model_radio, outputs=custom_model_box)
    byok_textbox.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
    provider_radio.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)

print("Gradio interface initialized.")

if __name__ == "__main__":
    print("Launching the demo application.")
    demo.launch(show_api=False)  # show_api=False for a cleaner launch; set True for API docs