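# Gradio demo that streams diagnostic reasoning from the Intelligent-Internet/II-Medical-8B
# model, using a <think>/Solution system prompt and MathJax rendering for LaTeX formulas.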
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
import time
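
# Load the II-Medical-8B checkpoint and its tokenizer once at startup; device_map="auto"
# lets Accelerate decide where to place the weights.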
phi4_model_path = "Intelligent-Internet/II-Medical-8B"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
phi4_model = AutoModelForCausalLM.from_pretrained(phi4_model_path, device_map="auto", torch_dtype="auto")
phi4_tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)

# This is our streaming generator function that yields partial results
@spaces.GPU  # request ZeroGPU hardware while a response is being generated
def generate_streaming_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history):
    if not user_message.strip():
        yield history, history
        return

    model = phi4_model
    tokenizer = phi4_tokenizer

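    # Chat-format markers used to assemble the prompt below
    # (<|im_start|>role<|im_sep|>content<|im_end|> for each turn).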
    start_tag = "<|im_start|>"
    sep_tag = "<|im_sep|>"
    end_tag = "<|im_end|>"

    system_message = """You are a medical assistant AI designed to help diagnose symptoms, explain possible conditions, and recommend next steps. You must be cautious, thorough, and explain medical reasoning step-by-step. Structure your answer in two sections:
<think> In this section, reason through the symptoms by considering patient history, differential diagnoses, relevant physiological mechanisms, and possible investigations. Explain your thought process step-by-step. </think>
In the Solution section, summarize your working diagnosis, differential options, and suggest what to do next (e.g., tests, referral, lifestyle changes). Always clarify that this is not a replacement for a licensed medical professional.
Use LaTeX for any formulas or values (e.g., $\\text{BMI} = \\frac{\\text{weight (kg)}}{\\text{height (m)}^2}$).
Now, analyze the following case:"""

    # Build conversation history in the format the model expects
    prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"

    # Convert chat history from the Gradio Chatbot format to the prompt format
    for user_msg, bot_msg in history:
        if user_msg:
            prompt += f"{start_tag}user{sep_tag}{user_msg}{end_tag}"
        if bot_msg:
            prompt += f"{start_tag}assistant{sep_tag}{bot_msg}{end_tag}"

    # Add the current user message
    prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}"

    inputs = tokenizer(prompt, return_tensors="pt").to(device)

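    # TextIteratorStreamer yields tokens as they are generated; skip_prompt=True keeps
    # the echoed input prompt out of the streamed output.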
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": int(max_tokens),
        "do_sample": True,
        "temperature": float(temperature),
        "top_k": int(top_k),
        "top_p": float(top_p),
        "repetition_penalty": float(repetition_penalty),
        "streamer": streamer,
    }

    # Start generation in a separate thread
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Create a new history with the current user message
    new_history = history.copy() + [[user_message, ""]]

    # Collect the generated response
    assistant_response = ""
    for new_token in streamer:
        cleaned_token = new_token.replace("<|im_start|>", "").replace("<|im_sep|>", "").replace("<|im_end|>", "")
        assistant_response += cleaned_token
        # Update the last message in history with the current response
        new_history[-1][1] = assistant_response.strip()
        yield new_history, new_history
        # Add a small sleep to control the streaming rate
        time.sleep(0.01)

    # Yield the final state once streaming has completed
    yield new_history, new_history

# This is our non-streaming wrapper function for buttons that don't support streaming
def process_input(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history):
    generator = generate_streaming_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history)
    # Get the final result by exhausting the generator
    result = None
    for result in generator:
        pass
    return result
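
# Note: process_input is not wired into the UI below; the Send button streams via
# on_submit/on_stream. A hypothetical direct call would look like:
#   final_chat, final_state = process_input("A patient weighs 85 kg and is 1.75 m tall...",
#                                           1024, 0.8, 50, 0.95, 1.0, [])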

example_messages = {
    "Headache case": "A 35-year-old female presents with a throbbing headache, nausea, and sensitivity to light. It started on one side of her head and worsens with activity. No prior trauma.",
    "Chest pain": "A 58-year-old male presents with chest tightness radiating to his left arm, shortness of breath, and sweating. Symptoms began while climbing stairs.",
    "Abdominal pain": "A 24-year-old complains of right lower quadrant abdominal pain, nausea, and mild fever. The pain started around the belly button and migrated.",
    "BMI calculation": "A patient weighs 85 kg and is 1.75 meters tall. Calculate the BMI and interpret whether it's underweight, normal, overweight, or obese."
}
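
# CSS tweaks for KaTeX output inside the chatbot's markdown: enlarge formulas and let
# wide display math scroll horizontally instead of overflowing.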
css = """
.markdown-body .katex {
    font-size: 1.2em;
}
.markdown-body .katex-display {
    margin: 1em 0;
    overflow-x: auto;
    overflow-y: hidden;
}
"""

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("# Medical Diagnostic Assistant\nThis AI assistant helps analyze symptoms and provide preliminary diagnostic reasoning using LaTeX-rendered medical formulas where needed.")
| gr.HTML(""" | |
| <script> | |
| if (typeof window.MathJax === 'undefined') { | |
| const script = document.createElement('script'); | |
| script.src = 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/MathJax.js?config=TeX-MML-AM_CHTML'; | |
| script.async = true; | |
| document.head.appendChild(script); | |
| window.MathJax = { | |
| tex2jax: { | |
| inlineMath: [['$', '$']], | |
| displayMath: [['$$', '$$']], | |
| processEscapes: true | |
| }, | |
| showProcessingMessages: false, | |
| messageStyle: 'none' | |
| }; | |
| } | |
| function rerender() { | |
| if (window.MathJax && window.MathJax.Hub) { | |
| window.MathJax.Hub.Queue(['Typeset', window.MathJax.Hub]); | |
| } | |
| } | |
| setInterval(rerender, 1000); | |
| </script> | |
| """) | |
    chatbot = gr.Chatbot(label="Chat", render_markdown=True, show_copy_button=True)
    history = gr.State([])
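    # The State above mirrors the chatbot contents as [[user, assistant], ...] pairs so
    # the callbacks below can rebuild the full prompt on every turn.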

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            max_tokens_slider = gr.Slider(64, 32768, step=1024, value=4096, label="Max Tokens")
            with gr.Accordion("Advanced Settings", open=False):
                temperature_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperature")
                top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
                top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
                repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty")
        with gr.Column(scale=4):
            with gr.Row():
                user_input = gr.Textbox(label="Describe symptoms or ask a medical question", placeholder="Type your message here...", scale=3)
                submit_button = gr.Button("Send", variant="primary", scale=1)
                clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these examples:**")
            with gr.Row():
                example1 = gr.Button("Headache case")
                example2 = gr.Button("Chest pain")
                example3 = gr.Button("Abdominal pain")
                example4 = gr.Button("BMI calculation")

    # Set up the streaming interface
    def on_submit(message, history, max_tokens, temperature, top_k, top_p, repetition_penalty):
        # Return the modified history that includes the new user message
        modified_history = history + [[message, ""]]
        return "", modified_history, modified_history

    def on_stream(history, max_tokens, temperature, top_k, top_p, repetition_penalty):
        if not history:
            yield history, history
            return
        # Get the last user message from history
        user_message = history[-1][0]
        # Start a fresh history without the last entry
        prev_history = history[:-1]
        # Generate streaming responses, keeping the chatbot display and the history state in sync
        for new_history, new_state in generate_streaming_response(
            user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, prev_history
        ):
            yield new_history, new_state

    # Connect the submission event
    submit_button.click(
        fn=on_submit,
        inputs=[user_input, history, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider],
        outputs=[user_input, chatbot, history]
    ).then(
        fn=on_stream,
        inputs=[history, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider],
        outputs=[chatbot, history]
    )
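    # The chain above first clears the textbox and echoes the user turn (on_submit),
    # then the .then() step streams the model's reply into the chat (on_stream).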

    # Handle examples
    def set_example(example_text):
        return gr.update(value=example_text)

    clear_button.click(fn=lambda: ([], []), inputs=None, outputs=[chatbot, history])
    example1.click(fn=lambda: set_example(example_messages["Headache case"]), inputs=None, outputs=user_input)
    example2.click(fn=lambda: set_example(example_messages["Chest pain"]), inputs=None, outputs=user_input)
    example3.click(fn=lambda: set_example(example_messages["Abdominal pain"]), inputs=None, outputs=user_input)
    example4.click(fn=lambda: set_example(example_messages["BMI calculation"]), inputs=None, outputs=user_input)
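
# ssr_mode=False turns off Gradio's server-side rendering for this Space.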
demo.launch(ssr_mode=False)