# Repository page header (not part of the program):
#   author: ayushsinha | commit: eed2432 (verified) | message: "Upload 2 files"
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
import numpy as np
# Model and tokenizer setup (executed once, when the module is imported)
model_name = "AventIQ-AI/distilbert-disease-specialist-recommendation"

# Prefer a CUDA GPU when one is available; otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)

# Specialty names, looked up by predicted class index
# NOTE(review): presumably listed in the model's class order — confirm against the model config
labels = [
    "Cardiology",
    "Neurology",
    "Orthopedics",
    "Dermatology",
]
# Define the prediction function
def recommend_specialist(symptoms: str) -> str:
    """Recommend a medical specialty for a free-text symptom description.

    The text is tokenized (truncated to at most 512 tokens), passed through the
    module-level ``model``, and the class with the highest softmax probability
    is mapped to a name through the module-level ``labels`` list.

    Args:
        symptoms: Free-text description of the symptoms.

    Returns:
        ``"Recommended Specialist: <name> (Confidence: <p>)"`` where ``<p>`` is
        the winning softmax probability formatted to two decimals, or a short
        prompt asking for input when ``symptoms`` is ``None``, empty, or
        whitespace-only.

    Raises:
        IndexError: If the predicted class index is outside ``labels``.
    """
    # Nothing to classify: return a prompt instead of a meaningless prediction.
    if symptoms is None or (isinstance(symptoms, str) and not symptoms.strip()):
        return "Please describe your symptoms to receive a recommendation."

    inputs = tokenizer(
        symptoms,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(device)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(**inputs).logits

    # [0] selects the first (only) sequence in the batch.
    probabilities = torch.nn.functional.softmax(logits, dim=1)[0]
    confidence, predicted_class = torch.max(probabilities, dim=0)

    # Convert the 0-dim tensor to a plain int before indexing the list.
    # NOTE(review): assumes the model's class order matches ``labels`` — confirm.
    recommended_specialist = labels[predicted_class.item()]
    return f"Recommended Specialist: {recommended_specialist} (Confidence: {confidence.item():.2f})"
# Create the Gradio interface: one text input, one text output, backed by
# recommend_specialist.
# The emoji in the labels/title were restored from mis-encoded text; the
# hospital glyph is recoverable from the bytes, the two label glyphs are a
# best guess — NOTE(review): confirm against the original file.
iface = gr.Interface(
    fn=recommend_specialist,
    inputs=gr.Textbox(
        label="📝 Describe Your Symptoms",
        placeholder="e.g., experiencing chest pain and shortness of breath...",
        lines=3,
    ),
    # NOTE(review): interactive=True leaves the output box editable by the
    # user — confirm this is intended.
    outputs=gr.Textbox(label="🔍 Specialist Recommendation", interactive=True),
    title="🏥 Medical Specialist Recommender",
    description="Enter your symptoms to receive a recommendation for the appropriate medical specialist.",
    # NOTE(review): "compact" may not be a recognized theme name in newer
    # Gradio releases — confirm against the installed version.
    theme="compact",
    allow_flagging="never",
    examples=[
        ["I have persistent headaches and dizziness."],
        ["My skin has developed a red, itchy rash."],
        ["I'm experiencing joint pain and stiffness in the mornings."],
    ],
)

# Start the web app only when this file is run as a script, not on import.
if __name__ == "__main__":
    iface.launch()