import gradio as gr
import torch
import numpy as np
from pathlib import Path
import re
from Model import OmniPathWithInterTaskAttention
from transformers import AutoModelForCausalLM, AutoTokenizer
import transformers
import os
# Force Gradio to use the English locale
os.environ["GRADIO_LOCALE"] = "en"
# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Preload models (to avoid repeated loading)
@torch.no_grad()
def load_models():
    """Preload necessary models"""
    # 1. Load classification model
    ckpt_path = "best_model.pth"
    if not Path(ckpt_path).exists():
        raise FileNotFoundError(f"Model file not found: {ckpt_path}")
    ckpt = torch.load(ckpt_path, map_location=device)
    label_mappings = ckpt.get('label_mappings', None)
    if not label_mappings:
        raise ValueError("The checkpoint is missing label_mappings")
    ck_cfg = ckpt.get('config', {})
    feature_dim = 768  # Adjust according to your actual feature dimension
    hidden_dim = int(ck_cfg.get('hidden_dim', 256))
    dropout = float(ck_cfg.get('dropout', 0.3))
    use_inter_task_attention = bool(ck_cfg.get('use_inter_task_attention', True))
    inter_task_heads = int(ck_cfg.get('inter_task_heads', 4))
    classification_model = OmniPathWithInterTaskAttention(
        label_mappings=label_mappings,
        feature_dim=feature_dim,
        hidden_dim=hidden_dim,
        dropout=dropout,
        use_inter_task_attention=use_inter_task_attention,
        inter_task_heads=inter_task_heads
    ).to(device)
    classification_model.load_state_dict(ckpt['model_state_dict'], strict=False)
    classification_model.eval()

    # 2. Load text generation model
    llm_model_name = "Qwen/Qwen3-0.6B"
    # llm_model_name = "Qwen/QwQ-32B"
    tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
    llm_model = AutoModelForCausalLM.from_pretrained(
        llm_model_name,
        dtype="auto",
        device_map="auto"
    )
    return classification_model, llm_model, tokenizer, label_mappings
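# Expected layout of best_model.pth, inferred from the loading code above
# (only the keys actually read by load_models are assumed):
#   {
#       'model_state_dict': ...,                                   # OmniPathWithInterTaskAttention weights
#       'label_mappings':   {task_name: {'classes': [...]}, ...},  # per-task class names
#       'config':           {'hidden_dim': ..., 'dropout': ..., ...}  # optional hyperparameters
#   }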
# Preload models at startup
classification_model, llm_model, tokenizer, label_mappings = load_models()
def analyze_npy_file(npy_file):
    """Analyze NPY file and return prediction results"""
    if npy_file is None:
        return None, "Please upload an NPY file first"
    try:
        # gr.File(type="filepath") passes a plain path string; older Gradio versions
        # pass a tempfile-like object with a .name attribute, so handle both.
        file_path = npy_file if isinstance(npy_file, str) else npy_file.name

        # Read NPY file
        arr = np.load(file_path, allow_pickle=False)
        if not isinstance(arr, np.ndarray) or arr.ndim != 2:
            return None, "Error: NPY file must be a two-dimensional feature matrix"
        features = torch.from_numpy(arr).float()

        # Extract short patient ID (e.g. TCGA-XX-XXXX) from the file name
        p = Path(file_path)
        m = re.search(r'(TCGA-[A-Z0-9]{2}-[A-Z0-9]{4})', p.name.upper())
        short_id = m.group(1) if m else p.stem[:12]

        # Inference
        feat_batch = features.unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = classification_model(feat_batch)

        # Decode results
        pred_names, pred_scores = {}, {}
        for task_name, logits in outputs.items():
            probs = torch.softmax(logits[0], dim=-1)
            idx = int(torch.argmax(probs).item())
            classes = label_mappings[task_name]['classes']
            class_name = classes[idx] if 0 <= idx < len(classes) else str(idx)
            pred_names[task_name] = class_name
            pred_scores[task_name] = float(probs[idx].item())

        # Format results
        results_text = f"Patient ID: {short_id}\n\nPrediction Results:\n"
        for task, name in pred_names.items():
            results_text += f"- {task}: {name} (Confidence: {pred_scores.get(task, 0.0):.3f})\n"

        return {"pred_names": pred_names, "pred_scores": pred_scores, "patient_id": short_id}, results_text
    except Exception as e:
        return None, f"An error occurred during processing: {str(e)}"
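# A compatible input is a 2-D float feature matrix saved with np.save.
# Minimal sketch, assuming feature_dim=768 as configured in load_models
# (the patch count and the upstream feature extractor are not assumed here):
#
#   feats = np.random.rand(1024, 768).astype(np.float32)   # (num_patches, feature_dim)
#   np.save("TCGA-XX-XXXX_features.npy", feats)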
def generate_response(message, chat_history, analysis_results):
    """Generate response based on user message and analysis results"""
    if analysis_results is None:
        chat_history.append((message, "Please upload an NPY file first to analyze the patient data."))
        return "", chat_history

    pred_names = analysis_results["pred_names"]
    pred_scores = analysis_results["pred_scores"]
    patient_id = analysis_results["patient_id"]

    # Build context from analysis results
    context = f"Patient {patient_id} analysis results:\n"
    for task, name in pred_names.items():
        context += f"- {task}: {name} (confidence: {pred_scores.get(task, 0.0):.3f})\n"

    # Build prompt based on the user message
    if "diagnosis" in message.lower() or "result" in message.lower():
        prompt = f"{context}\nBased on the above analysis results, provide a detailed diagnosis summary and interpretation."
    elif "treatment" in message.lower() or "therapy" in message.lower():
        prompt = f"{context}\nBased on the diagnosis, suggest appropriate treatment options and considerations."
    elif "prognosis" in message.lower() or "outlook" in message.lower():
        prompt = f"{context}\nDiscuss the prognosis and potential outcomes for this patient."
    elif "stage" in message.lower():
        prompt = f"{context}\nExplain the staging information and its clinical implications."
    elif "histology" in message.lower() or "type" in message.lower():
        prompt = f"{context}\nDescribe the histological characteristics and their significance."
    else:
        prompt = f"{context}\nUser question: {message}\nPlease provide a helpful response based on the analysis results."

    try:
        # Generate response using the LLM
        messages = [{"role": "user", "content": prompt}]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False  # Qwen3: disable the <think> reasoning block
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(llm_model.device)
        generated_ids = llm_model.generate(
            **model_inputs,
            max_new_tokens=32768,
            do_sample=True,
            temperature=0.7,
        )
        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
        # Strip any thinking content: 151668 is the </think> token id in the Qwen3 tokenizer.
        # With enable_thinking=False it is normally absent, so index falls back to 0.
        try:
            index = len(output_ids) - output_ids[::-1].index(151668)
        except ValueError:
            index = 0
        response = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")

        # Add to chat history
        chat_history.append((message, response))
        return "", chat_history
    except Exception as e:
        error_msg = f"Error generating response: {str(e)}"
        chat_history.append((message, error_msg))
        return "", chat_history
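# Sketch of the prompt assembled above for a "diagnosis"-style question
# (the patient ID, task names, and scores are illustrative placeholders):
#
#   Patient TCGA-XX-XXXX analysis results:
#   - <task>: <predicted class> (confidence: 0.912)
#   ...
#   Based on the above analysis results, provide a detailed diagnosis summary and interpretation.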
def upload_file(npy_file, chat_history, analysis_results):
    """Handle file upload and initial analysis"""
    if npy_file is None:
        return chat_history, analysis_results, "Please select a file to upload"

    new_analysis_results, results_text = analyze_npy_file(npy_file)
    if new_analysis_results is None:
        return chat_history, analysis_results, results_text

    # Add analysis results to chat
    chat_history.append(("System", f"File uploaded and analyzed successfully!\n{results_text}"))
    chat_history.append(("System", "You can now ask questions about this patient's diagnosis, treatment options, prognosis, etc."))
    return chat_history, new_analysis_results, "Analysis completed successfully!"

def example_click(example):
    """Handle example question click"""
    return example
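# UI wiring at a glance (a sketch of the flow defined in the Blocks below):
#   file_input --(Upload & Analyze)--> upload_file -> analyze_npy_file
#       -> analysis_results (gr.State) + summary posted to chatbot
#   msg / send_btn / Enter -> generate_response (reads analysis_results, appends to chatbot)
#   clear_btn -> resets chatbot, analysis_results, and status_output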
# Create conversational interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🏥 Medical Pathology Diagnostic Chat Assistant
    Upload a pathology NPY file and chat with the AI assistant about the diagnosis, treatment options, prognosis, and more.
    """)

    # Store analysis results in session state
    analysis_results = gr.State(value=None)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Upload Patient Data")
            file_input = gr.File(
                label="Upload NPY Feature File",
                file_types=[".npy"],
                type="filepath"
            )
            upload_btn = gr.Button("Upload & Analyze", variant="primary")
            status_output = gr.Textbox(
                label="Status",
                lines=2,
                interactive=False
            )

        with gr.Column(scale=2):
            gr.Markdown("### Chat with Medical Assistant")
            chatbot = gr.Chatbot(
                label="Conversation",
                height=400
            )

            with gr.Row():
                msg = gr.Textbox(
                    label="Your Question",
                    placeholder="Ask about diagnosis, treatment, prognosis...",
                    lines=2,
                    scale=4
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)

            with gr.Row():
                clear_btn = gr.Button("Clear Chat")

            gr.Markdown("### Suggested Questions")
            examples = gr.Examples(
                examples=[
                    "What is the diagnosis?",
                    "What treatment options are available?",
                    "What is the prognosis?",
                    "Explain the staging information",
                    "Describe the histological findings"
                ],
                inputs=msg,        # send the clicked example to the message input box
                fn=example_click,  # handler invoked when an example is clicked
                outputs=msg,       # write the example text back to the message input box
                label="Click a question to use it"
            )

    # Event handlers
    upload_btn.click(
        upload_file,
        inputs=[file_input, chatbot, analysis_results],
        outputs=[chatbot, analysis_results, status_output]
    )

    send_btn.click(
        generate_response,
        inputs=[msg, chatbot, analysis_results],
        outputs=[msg, chatbot]
    )

    msg.submit(
        generate_response,
        inputs=[msg, chatbot, analysis_results],
        outputs=[msg, chatbot]
    )

    clear_btn.click(
        lambda: ([], None, "Chat cleared"),
        inputs=[],
        outputs=[chatbot, analysis_results, status_output]
    )
if __name__ == "__main__":
    demo.launch(share=True)