import gradio as gr
import torch
import numpy as np
from pathlib import Path
import re
from Model import OmniPathWithInterTaskAttention
from transformers import AutoModelForCausalLM, AutoTokenizer
import os

# Keep the Gradio interface in English regardless of the host locale.
os.environ["GRADIO_LOCALE"] = "en"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

@torch.no_grad()
def load_models():
    """Load the multi-task classification model from the local checkpoint and
    the Qwen chat model used by the assistant."""
    ckpt_path = "best_model.pth"
    if not Path(ckpt_path).exists():
        raise FileNotFoundError(f"Model file not found: {ckpt_path}")

    ckpt = torch.load(ckpt_path, map_location=device)
    label_mappings = ckpt.get('label_mappings', None)
    if not label_mappings:
        raise ValueError("The checkpoint is missing label_mappings")

    # Hyperparameters come from the checkpoint's config when present,
    # otherwise the defaults below are used.
    ck_cfg = ckpt.get('config', {})
    feature_dim = 768
    hidden_dim = int(ck_cfg.get('hidden_dim', 256))
    dropout = float(ck_cfg.get('dropout', 0.3))
    use_inter_task_attention = bool(ck_cfg.get('use_inter_task_attention', True))
    inter_task_heads = int(ck_cfg.get('inter_task_heads', 4))

    classification_model = OmniPathWithInterTaskAttention(
        label_mappings=label_mappings,
        feature_dim=feature_dim,
        hidden_dim=hidden_dim,
        dropout=dropout,
        use_inter_task_attention=use_inter_task_attention,
        inter_task_heads=inter_task_heads
    ).to(device)
    # strict=False tolerates extra or missing keys between the checkpoint and
    # the instantiated model.
    classification_model.load_state_dict(ckpt['model_state_dict'], strict=False)
    classification_model.eval()

    llm_model_name = "Qwen/Qwen3-0.6B"
    tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
    llm_model = AutoModelForCausalLM.from_pretrained(
        llm_model_name,
        dtype="auto",
        device_map="auto"
    )

    return classification_model, llm_model, tokenizer, label_mappings


classification_model, llm_model, tokenizer, label_mappings = load_models()
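
# Note on the checkpoint layout assumed above: best_model.pth must contain
# 'model_state_dict' and 'label_mappings' (a dict mapping each task name to a
# dict with a 'classes' list), and may contain a 'config' dict with the
# hyperparameters read in load_models().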


def analyze_npy_file(npy_file):
    """Analyze an NPY feature file and return prediction results."""
    if npy_file is None:
        return None, "Please upload an NPY file first"

    try:
        # gr.File(type="filepath") passes a path string; fall back to .name for
        # Gradio versions that pass a tempfile wrapper instead.
        npy_path = npy_file if isinstance(npy_file, str) else npy_file.name
        arr = np.load(npy_path, allow_pickle=False)
        if not isinstance(arr, np.ndarray) or arr.ndim != 2:
            return None, "Error: NPY file must be a two-dimensional feature matrix"

        features = torch.from_numpy(arr).float()

        # Derive a short patient ID from the file name (TCGA barcode if present).
        p = Path(npy_path)
        m = re.search(r'(TCGA-[A-Z0-9]{2}-[A-Z0-9]{4})', p.name.upper())
        short_id = m.group(1) if m else p.stem[:12]

        # Run the classifier on a batch of one.
        feat_batch = features.unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = classification_model(feat_batch)

        # Pick the top class and its probability for each task.
        pred_names, pred_scores = {}, {}
        for task_name, logits in outputs.items():
            probs = torch.softmax(logits[0], dim=-1)
            idx = int(torch.argmax(probs).item())
            classes = label_mappings[task_name]['classes']
            class_name = classes[idx] if 0 <= idx < len(classes) else str(idx)
            pred_names[task_name] = class_name
            pred_scores[task_name] = float(probs[idx].item())

        results_text = f"Patient ID: {short_id}\n\nPrediction Results:\n"
        for task, name in pred_names.items():
            results_text += f"- {task}: {name} (Confidence: {pred_scores.get(task, 0.0):.3f})\n"

        return {"pred_names": pred_names, "pred_scores": pred_scores, "patient_id": short_id}, results_text

    except Exception as e:
        return None, f"An error occurred during processing: {str(e)}"


def generate_response(message, chat_history, analysis_results):
    """Generate a response based on the user message and the analysis results."""
    if analysis_results is None:
        return "Please upload an NPY file first to analyze the patient data.", chat_history

    pred_names = analysis_results["pred_names"]
    pred_scores = analysis_results["pred_scores"]
    patient_id = analysis_results["patient_id"]

    # Summarize the classifier output as context for the LLM prompt.
    context = f"Patient {patient_id} analysis results:\n"
    for task, name in pred_names.items():
        context += f"- {task}: {name} (confidence: {pred_scores.get(task, 0.0):.3f})\n"

    # Simple keyword routing to give the LLM a more focused instruction.
    if "diagnosis" in message.lower() or "result" in message.lower():
        prompt = f"{context}\nBased on the above analysis results, provide a detailed diagnosis summary and interpretation."
    elif "treatment" in message.lower() or "therapy" in message.lower():
        prompt = f"{context}\nBased on the diagnosis, suggest appropriate treatment options and considerations."
    elif "prognosis" in message.lower() or "outlook" in message.lower():
        prompt = f"{context}\nDiscuss the prognosis and potential outcomes for this patient."
    elif "stage" in message.lower():
        prompt = f"{context}\nExplain the staging information and its clinical implications."
    elif "histology" in message.lower() or "type" in message.lower():
        prompt = f"{context}\nDescribe the histological characteristics and their significance."
    else:
        prompt = f"{context}\nUser question: {message}\nPlease provide a helpful response based on the analysis results."

    try:
        messages = [{"role": "user", "content": prompt}]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False
        )

        model_inputs = tokenizer([text], return_tensors="pt").to(llm_model.device)
        generated_ids = llm_model.generate(
            **model_inputs,
            max_new_tokens=32768,
            do_sample=True,
            temperature=0.7,
        )

        # Keep only the newly generated tokens and strip any thinking content:
        # 151668 is the id of Qwen3's </think> token.
        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
        try:
            index = len(output_ids) - output_ids[::-1].index(151668)
        except ValueError:
            index = 0

        response = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")

        chat_history.append((message, response))
        return "", chat_history

    except Exception as e:
        error_msg = f"Error generating response: {str(e)}"
        chat_history.append((message, error_msg))
        return "", chat_history


def upload_file(npy_file, chat_history, analysis_results):
    """Handle a file upload and run the initial analysis."""
    if npy_file is None:
        return chat_history, analysis_results, "Please select a file to upload"

    new_analysis_results, results_text = analyze_npy_file(npy_file)

    if new_analysis_results is None:
        # Keep the previous analysis state and surface the error message.
        return chat_history, analysis_results, results_text

    chat_history.append(("System", f"File uploaded and analyzed successfully!\n{results_text}"))
    chat_history.append(("System", "You can now ask questions about this patient's diagnosis, treatment options, prognosis, etc."))

    return chat_history, new_analysis_results, "Analysis completed successfully!"


def example_click(example):
    """Return the clicked example so it populates the question textbox."""
    return example
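
# UI layout: a narrow left column for uploading the NPY feature file and a
# wider right column with the chat assistant, input box, and suggested questions.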


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🏥 Medical Pathology Diagnostic Chat Assistant

    Upload a pathology NPY file and chat with the AI assistant about the diagnosis, treatment options, prognosis, and more.
    """)

    analysis_results = gr.State(value=None)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Upload Patient Data")
            file_input = gr.File(
                label="Upload NPY Feature File",
                file_types=[".npy"],
                type="filepath"
            )
            upload_btn = gr.Button("Upload & Analyze", variant="primary")
            status_output = gr.Textbox(
                label="Status",
                lines=2,
                interactive=False
            )

        with gr.Column(scale=2):
            gr.Markdown("### Chat with Medical Assistant")
            chatbot = gr.Chatbot(
                label="Conversation",
                height=400
            )

            with gr.Row():
                msg = gr.Textbox(
                    label="Your Question",
                    placeholder="Ask about diagnosis, treatment, prognosis...",
                    lines=2,
                    scale=4
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)

            with gr.Row():
                clear_btn = gr.Button("Clear Chat")

            gr.Markdown("### Suggested Questions")
            examples = gr.Examples(
                examples=[
                    "What is the diagnosis?",
                    "What treatment options are available?",
                    "What is the prognosis?",
                    "Explain the staging information",
                    "Describe the histological findings"
                ],
                inputs=msg,
                fn=example_click,
                outputs=msg,
                label="Click a question to use it"
            )
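
    # Event wiring: the upload button runs the classifier; the send button and
    # pressing Enter in the textbox both call generate_response; the clear
    # button resets the chat, the stored analysis state, and the status box.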
    upload_btn.click(
        upload_file,
        inputs=[file_input, chatbot, analysis_results],
        outputs=[chatbot, analysis_results, status_output]
    )

    send_btn.click(
        generate_response,
        inputs=[msg, chatbot, analysis_results],
        outputs=[msg, chatbot]
    )

    msg.submit(
        generate_response,
        inputs=[msg, chatbot, analysis_results],
        outputs=[msg, chatbot]
    )

    clear_btn.click(
        lambda: ([], None, "Chat cleared"),
        inputs=[],
        outputs=[chatbot, analysis_results, status_output]
    )


if __name__ == "__main__":
    # share=True additionally exposes a temporary public Gradio link.
    demo.launch(share=True)