import gradio as gr import torch import numpy as np from pathlib import Path import re from Model import OmniPathWithInterTaskAttention from transformers import AutoModelForCausalLM, AutoTokenizer import transformers import os # 强制设置 Gradio 为英文环境 os.environ["GRADIO_LOCALE"] = "en" # 设备设置 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Using device: {device}") # 预加载模型(避免重复加载) @torch.no_grad() def load_models(): """Preload necessary models""" # 1. Load classification model ckpt_path = "best_model.pth" if not Path(ckpt_path).exists(): raise FileNotFoundError(f"Model file not found: {ckpt_path}") ckpt = torch.load(ckpt_path, map_location=device) label_mappings = ckpt.get('label_mappings', None) if not label_mappings: raise ValueError("The checkpoint is missing label_mappings") ck_cfg = ckpt.get('config', {}) feature_dim = 768 # Adjust according to your actual feature dimension hidden_dim = int(ck_cfg.get('hidden_dim', 256)) dropout = float(ck_cfg.get('dropout', 0.3)) use_inter_task_attention = bool(ck_cfg.get('use_inter_task_attention', True)) inter_task_heads = int(ck_cfg.get('inter_task_heads', 4)) classification_model = OmniPathWithInterTaskAttention( label_mappings=label_mappings, feature_dim=feature_dim, hidden_dim=hidden_dim, dropout=dropout, use_inter_task_attention=use_inter_task_attention, inter_task_heads=inter_task_heads ).to(device) classification_model.load_state_dict(ckpt['model_state_dict'], strict=False) classification_model.eval() # 2. Load text generation model llm_model_name = "Qwen/Qwen3-0.6B" # llm_model_name = "Qwen/QwQ-32B" tokenizer = AutoTokenizer.from_pretrained(llm_model_name) llm_model = AutoModelForCausalLM.from_pretrained( llm_model_name, dtype="auto", device_map="auto" ) return classification_model, llm_model, tokenizer, label_mappings # 预加载模型 classification_model, llm_model, tokenizer, label_mappings = load_models() def analyze_npy_file(npy_file): """Analyze NPY file and return prediction results""" if npy_file is None: return None, "Please upload an NPY file first" try: # Read NPY file arr = np.load(npy_file.name, allow_pickle=False) if not isinstance(arr, np.ndarray) or arr.ndim != 2: return None, "Error: NPY file must be a two-dimensional feature matrix" features = torch.from_numpy(arr).float() # Extract short ID p = Path(npy_file.name) m = re.search(r'(TCGA-[A-Z0-9]{2}-[A-Z0-9]{4})', p.name.upper()) short_id = m.group(1) if m else p.stem[:12] # Inference feat_batch = features.unsqueeze(0).to(device) outputs = classification_model(feat_batch) # Decode results pred_names, pred_scores = {}, {} for task_name, logits in outputs.items(): probs = torch.softmax(logits[0], dim=-1) idx = int(torch.argmax(probs).item()) classes = label_mappings[task_name]['classes'] class_name = classes[idx] if 0 <= idx < len(classes) else str(idx) pred_names[task_name] = class_name pred_scores[task_name] = float(probs[idx].item()) # Format results results_text = f"Patient ID: {short_id}\n\nPrediction Results:\n" for task, name in pred_names.items(): results_text += f"- {task}: {name} (Confidence: {pred_scores.get(task, 0.0):.3f})\n" return {"pred_names": pred_names, "pred_scores": pred_scores, "patient_id": short_id}, results_text except Exception as e: return None, f"An error occurred during processing: {str(e)}" def generate_response(message, chat_history, analysis_results): """Generate response based on user message and analysis results""" if analysis_results is None: return "Please upload an NPY file first to analyze the patient data.", chat_history pred_names = analysis_results["pred_names"] pred_scores = analysis_results["pred_scores"] patient_id = analysis_results["patient_id"] # Build context from analysis results context = f"Patient {patient_id} analysis results:\n" for task, name in pred_names.items(): context += f"- {task}: {name} (confidence: {pred_scores.get(task, 0.0):.3f})\n" # Build prompt based on user message if "diagnosis" in message.lower() or "result" in message.lower(): prompt = f"{context}\nBased on the above analysis results, provide a detailed diagnosis summary and interpretation." elif "treatment" in message.lower() or "therapy" in message.lower(): prompt = f"{context}\nBased on the diagnosis, suggest appropriate treatment options and considerations." elif "prognosis" in message.lower() or "outlook" in message.lower(): prompt = f"{context}\nDiscuss the prognosis and potential outcomes for this patient." elif "stage" in message.lower(): prompt = f"{context}\nExplain the staging information and its clinical implications." elif "histology" in message.lower() or "type" in message.lower(): prompt = f"{context}\nDescribe the histological characteristics and their significance." else: prompt = f"{context}\nUser question: {message}\nPlease provide a helpful response based on the analysis results." try: # Generate response using LLM messages = [{"role": "user", "content": prompt}] text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, enable_thinking=False ) model_inputs = tokenizer([text], return_tensors="pt").to(llm_model.device) generated_ids = llm_model.generate( **model_inputs, max_new_tokens=32768, do_sample=True, temperature=0.7, ) output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist() try: index = len(output_ids) - output_ids[::-1].index(151668) except ValueError: index = 0 response = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n") # Add to chat history chat_history.append((message, response)) return "", chat_history except Exception as e: error_msg = f"Error generating response: {str(e)}" chat_history.append((message, error_msg)) return "", chat_history def upload_file(npy_file, chat_history, analysis_results): """Handle file upload and initial analysis""" if npy_file is None: return chat_history, analysis_results, "Please select a file to upload" new_analysis_results, results_text = analyze_npy_file(npy_file) if new_analysis_results is None: return chat_history, analysis_results, results_text # Add analysis results to chat chat_history.append(("System", f"File uploaded and analyzed successfully!\n{results_text}")) chat_history.append(("System", "You can now ask questions about this patient's diagnosis, treatment options, prognosis, etc.")) return chat_history, new_analysis_results, "Analysis completed successfully!" def example_click(example): """Handle example question click""" return example # Create conversational interface with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown(""" # 🏥 Medical Pathology Diagnostic Chat Assistant Upload a pathology NPY file and chat with the AI assistant about the diagnosis, treatment options, prognosis, and more. """) # Store analysis results in session state analysis_results = gr.State(value=None) with gr.Row(): with gr.Column(scale=1): gr.Markdown("### Upload Patient Data") file_input = gr.File( label="Upload NPY Feature File", file_types=[".npy"], type="filepath" ) upload_btn = gr.Button("Upload & Analyze", variant="primary") status_output = gr.Textbox( label="Status", lines=2, interactive=False ) with gr.Column(scale=2): gr.Markdown("### Chat with Medical Assistant") chatbot = gr.Chatbot( label="Conversation", height=400 ) with gr.Row(): msg = gr.Textbox( label="Your Question", placeholder="Ask about diagnosis, treatment, prognosis...", lines=2, scale=4 ) send_btn = gr.Button("Send", variant="primary", scale=1) with gr.Row(): clear_btn = gr.Button("Clear Chat") gr.Markdown("### Suggested Questions") examples = gr.Examples( examples=[ "What is the diagnosis?", "What treatment options are available?", "What is the prognosis?", "Explain the staging information", "Describe the histological findings" ], inputs=msg, # 将示例应用到消息输入框 fn=example_click, # 点击示例时的处理函数 outputs=msg, # 输出到消息输入框 label="Click a question to use it" ) # Event handlers upload_btn.click( upload_file, inputs=[file_input, chatbot, analysis_results], outputs=[chatbot, analysis_results, status_output] ) send_btn.click( generate_response, inputs=[msg, chatbot, analysis_results], outputs=[msg, chatbot] ) msg.submit( generate_response, inputs=[msg, chatbot, analysis_results], outputs=[msg, chatbot] ) clear_btn.click( lambda: ([], None, "Chat cleared"), inputs=[], outputs=[chatbot, analysis_results, status_output] ) if __name__ == "__main__": demo.launch(share=True)