---
license: apache-2.0
language:
- en
metrics:
- accuracy
base_model:
- microsoft/resnet-50
- timm/vgg19.tv_in1k
- google/vit-base-patch16-224
- xai-org/grok-1
pipeline_tag: image-classification
tags:
- Ocular-Toxoplasmosis(FundusImages)
- Retinal-images(Diabetics,Cataract,Glaucoma,Healthy)
- PyTorch
- Transformers
- Image-Classification
- Image_feature_extraction
- Grad-CAM
- XAI-Visualization
---
# Model Card: ROYXAI [Vision Transformer + VGG19 + ResNet50 Ensemble with Grad-CAM]
## Model Description
This model is an ensemble of three deep learning architectures: **Vision Transformer (ViT), VGG19, and ResNet50**. The ensemble approach enhances classification performance on medical image datasets related to ocular diseases. The model also integrates **Grad-CAM** visualization to highlight regions of interest for better interpretability.
## Model Details
- **Model Name**: ROYXAI
- **Developed by**: Avishek Roy Sparsho
- **Framework**: PyTorch
- **Ensemble Method**: Bagging
- **Backbone Models**: Vision Transformer, VGG19, ResNet50
- **Target Task**: Medical Image Classification
- **Supported Classes**:
  - OT (Ocular Toxoplasmosis)
  - Healthy
  - SC_diabetes
  - SC_cataract
  - SC_glucoma
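The card names bagging as the ensemble method but does not show how the three backbones are combined. Bagging (bootstrap aggregating) trains each member on its own resample of the training set, drawn with replacement, and then aggregates their predictions. The NumPy sketch below assumes aggregation by averaging softmax probabilities (soft voting); the helper names and the example logits are hypothetical and are not taken from the repository.

```python
import numpy as np


def bootstrap_indices(n_samples: int, rng: np.random.Generator) -> np.ndarray:
    """Draw n_samples training indices with replacement (one bootstrap resample)."""
    return rng.integers(0, n_samples, size=n_samples)


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last (class) axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def aggregate(logits_per_model: list) -> np.ndarray:
    """Average per-model class probabilities (soft voting) -> (batch, classes)."""
    probs = np.stack([softmax(logits) for logits in logits_per_model])  # (models, batch, classes)
    return probs.mean(axis=0)


# Hypothetical usage: one bootstrap resample per backbone (ViT, VGG19, ResNet50).
rng = np.random.default_rng(0)
train_size = 100
subsets = [bootstrap_indices(train_size, rng) for _ in range(3)]

# Invented logits for a single image over the 5 supported classes.
logits_vit = np.array([[2.0, 0.1, 0.0, 0.0, 0.0]])
logits_vgg = np.array([[1.0, 1.5, 0.0, 0.0, 0.0]])
logits_resnet = np.array([[1.8, 0.2, 0.0, 0.0, 0.0]])
ensemble_probs = aggregate([logits_vit, logits_vgg, logits_resnet])
prediction = ensemble_probs.argmax(axis=1)
```

Here VGG19 alone would pick class 1, but the averaged probabilities follow the two backbones that favour class 0; majority (hard) voting is a common alternative to averaging.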
## Model Sources
- **Repository**: [ROYXAI on Hugging Face](https://huggingface.co/Aviroy/ROYXAI)
## Uses
### Direct Use
This model is designed for medical image classification, to detect and visualize ocular diseases and their secondary complications.
### Downstream Use
Can be fine-tuned on different medical datasets to improve performance for specific conditions.
### Out-of-Scope Use
Not suitable for non-medical image classification tasks or use as a standalone medical diagnostic tool.
## Bias, Risks, and Limitations
- This model is trained on a specific dataset and may not generalize well to other medical image datasets without fine-tuning.
- It is **not a substitute for professional medical diagnosis**.
- The Vision Transformer model is computationally expensive compared to CNNs.
## Training Details
### Dataset
- **Dataset Name**: Custom Ocular Disease and Secondary Complications Dataset
- **Dataset Source**: Private Dataset (Medical Images)
- **Dataset Structure**: Images stored in folders based on class labels
- **Preprocessing**:
  - Resized images to 224x224 pixels
  - Normalized using ImageNet mean and standard deviation
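The normalization step matches the `transforms.ToTensor()` and `transforms.Normalize(...)` calls in the inference example: pixel values are scaled to [0, 1] and each channel is standardized as `(x - mean) / std` with the ImageNet statistics. As a sketch of the arithmetic only, the NumPy function below (a hypothetical helper, assuming the image has already been resized to 224x224) reproduces that step.

```python
import numpy as np

# ImageNet channel statistics (RGB), the same values passed to transforms.Normalize
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_imagenet(image_hwc: np.ndarray) -> np.ndarray:
    """Convert a 224x224x3 uint8 RGB array to a normalized 3x224x224 float32 array.

    Mirrors ToTensor() (scale to [0, 1], HWC -> CHW) followed by
    Normalize(mean, std), which computes (x - mean) / std per channel.
    """
    if image_hwc.shape != (224, 224, 3):
        raise ValueError(f"expected (224, 224, 3), got {image_hwc.shape}")
    scaled = image_hwc.astype(np.float32) / 255.0          # ToTensor scaling
    normalized = (scaled - IMAGENET_MEAN) / IMAGENET_STD  # broadcast over the channel axis
    return normalized.transpose(2, 0, 1)                  # HWC -> CHW


# Example: an all-white image maps to (1 - mean) / std in every channel.
white = np.full((224, 224, 3), 255, dtype=np.uint8)
out = normalize_imagenet(white)
```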
### Training Procedure
- **Optimizer**: Adam with weight decay
- **Learning Rate Scheduler**: Cosine Annealing LR
- **Loss Function**: Cross-Entropy Loss
- **Batch Size**: 32
- **Training Epochs**: 20
- **Hardware Used**: T4 GPU x2
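Cosine annealing (as in PyTorch's `torch.optim.lr_scheduler.CosineAnnealingLR`) lowers the learning rate from its initial value to a floor `eta_min` along half a cosine period over `T_max` steps, following `lr_t = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2`. The card does not state the base learning rate, `eta_min`, or `T_max`, so the values in this plain-Python sketch are placeholders.

```python
import math


def cosine_annealing_lr(step: int, base_lr: float, t_max: int, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing learning rate at a given step (here: epoch)."""
    cos_term = (1 + math.cos(math.pi * step / t_max)) / 2
    return eta_min + (base_lr - eta_min) * cos_term


# Placeholder hyperparameters: BASE_LR and eta_min are NOT taken from the card;
# T_MAX = 20 assumes one scheduler step per epoch over the 20 listed epochs.
BASE_LR = 1e-4
T_MAX = 20
schedule = [cosine_annealing_lr(epoch, BASE_LR, T_MAX) for epoch in range(T_MAX + 1)]
```

In PyTorch the equivalent setup would be an optimizer such as `torch.optim.Adam(model.parameters(), lr=..., weight_decay=...)` wrapped in `CosineAnnealingLR(optimizer, T_max=20)`, with `scheduler.step()` called once per epoch; the learning rate reaches half its starting value at the midpoint of training.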
## Model Performance
- **Accuracy**: 98% on the test dataset
- **Precision/Recall/F1-score**: Evaluated and optimized for medical diagnosis
- **Overfitting Prevention**: Implemented **data augmentation, dropout, and weight regularization**
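Precision, recall, and F1 are listed without values. For reference, the sketch below shows how per-class and macro-averaged scores follow from a confusion matrix (rows are true classes, columns are predicted classes). The matrix is invented for illustration and is not a result of this model.

```python
import numpy as np


def per_class_metrics(confusion: np.ndarray) -> dict:
    """Per-class precision, recall and F1 from a square confusion matrix.

    confusion[i, j] counts samples whose true class is i and predicted class is j.
    Classes with no predictions (or no true samples) get a score of 0.
    """
    tp = np.diag(confusion).astype(float)
    predicted = confusion.sum(axis=0)  # column sums: samples predicted as class j
    actual = confusion.sum(axis=1)     # row sums: samples truly in class i
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, actual, out=np.zeros_like(tp), where=actual > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    return {"precision": precision, "recall": recall, "f1": f1}


# Illustrative 3-class confusion matrix (invented numbers, not model results).
cm = np.array([
    [8, 2, 0],
    [1, 9, 0],
    [0, 0, 10],
])
scores = per_class_metrics(cm)
macro_f1 = scores["f1"].mean()     # unweighted mean over classes
accuracy = np.trace(cm) / cm.sum()
```

Macro averaging weights every class equally, which matters for medical datasets where rare classes can be hidden behind a high overall accuracy.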
## Installation and Usage
### Clone the Repository
```bash
git clone https://huggingface.co/Aviroy/ROYXAI
cd ROYXAI
```
### Install Dependencies
```bash
pip install -r requirements.txt
```
### Training the Model
To train the model from scratch, run:
```bash
python train.py --epochs 50 --batch_size 32
```
### Load Pretrained Model
To directly use the trained model:
```python
import torch
from PIL import Image
import torchvision.transforms as transforms
from model import ensemble_model  # Load the trained ensemble model (from the repository's model module)

# Select a device and move the model onto it so the model and input match
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ensemble_model.to(device)

# Define image transformations (resize to 224x224, ImageNet normalization)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load and preprocess an image, adding a batch dimension
image_path = "path/to/image.jpg"
image = Image.open(image_path).convert('RGB')
image = transform(image).unsqueeze(0).to(device)

# Perform inference
ensemble_model.eval()
with torch.no_grad():
    output = ensemble_model(image)
predicted_class = torch.argmax(output, dim=1).item()

# Print classification result (an integer class index)
print("Predicted Class:", predicted_class)
```
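`predicted_class` is an integer index, not a label. If training loaded the class-folder layout with torchvision's `ImageFolder` (an assumption; the card only says images are stored in folders named after their class labels), indices follow the sorted order of the folder names. The pure-Python sketch below mirrors that ordering rule with a hypothetical `find_classes` helper and a throwaway directory whose folders are assumed to be named exactly like the supported classes; check it against the `class_to_idx` actually used in training before relying on it.

```python
import os
import tempfile


def find_classes(root: str) -> dict:
    """Map class-folder names to indices in sorted order, as ImageFolder does."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"no class folders found in {root}")
    return {name: idx for idx, name in enumerate(classes)}


# Assumed folder names, identical to the labels under "Supported Classes".
SUPPORTED_CLASSES = ["OT", "Healthy", "SC_diabetes", "SC_cataract", "SC_glucoma"]

# Build a temporary directory with one folder per class to show the ordering.
with tempfile.TemporaryDirectory() as root:
    for name in SUPPORTED_CLASSES:
        os.makedirs(os.path.join(root, name))
    class_to_idx = find_classes(root)

idx_to_class = {idx: name for name, idx in class_to_idx.items()}
```

Under this assumption, the inference snippet's result can be printed as a label with `idx_to_class[predicted_class]`.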
## Grad-CAM Visualization
### Visualizing Class Activation Maps for Interpretability
#### Vision Transformer (ViT)
```python
from visualization import visualize_gradcam_vit # Function for ViT Grad-CAM
# Generate Grad-CAM visualization
overlay = visualize_gradcam_vit(ensemble_model.models[0], image, target_class=predicted_class)
# Display the Grad-CAM output
import matplotlib.pyplot as plt
plt.imshow(overlay)
plt.axis('off')
plt.title("Grad-CAM for Vision Transformer")
plt.show()
```
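Grad-CAM was designed for convolutional feature maps, so ViT variants usually first reshape the token sequence from a transformer block into a spatial grid. For a ViT-Base/16 model at 224x224 input (as with `google/vit-base-patch16-224`), there are (224 / 16)² = 196 patch tokens plus one [CLS] token, each of width 768. The NumPy sketch below shows that reshape, assuming the [CLS] token comes first as in the standard ViT layout; the card does not say whether `visualize_gradcam_vit` does exactly this, and `tokens_to_grid` is a hypothetical name.

```python
import numpy as np


def tokens_to_grid(tokens: np.ndarray, patch_grid: int = 14) -> np.ndarray:
    """Reshape ViT tokens (batch, 1 + H*W, dim) into a feature map (batch, dim, H, W).

    The first token is assumed to be [CLS] and is dropped; the remaining patch
    tokens are laid out row by row, matching the patch-embedding order.
    """
    batch, num_tokens, dim = tokens.shape
    if num_tokens != 1 + patch_grid * patch_grid:
        raise ValueError(f"expected {1 + patch_grid ** 2} tokens, got {num_tokens}")
    patches = tokens[:, 1:, :]                                   # drop [CLS]
    grid = patches.reshape(batch, patch_grid, patch_grid, dim)  # (B, H, W, D)
    return grid.transpose(0, 3, 1, 2)                           # (B, D, H, W)


# ViT-Base/16 at 224x224: 196 patches + 1 [CLS] token, hidden size 768.
vit_tokens = np.random.default_rng(0).standard_normal((1, 197, 768))
feature_map = tokens_to_grid(vit_tokens)
```

Once in (batch, channels, height, width) form, the tokens can be treated like a 14x14 convolutional feature map by the usual Grad-CAM weighting.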
#### ResNet50
```python
from visualization import visualize_gradcam # General Grad-CAM function
# Generate Grad-CAM visualization for ResNet50
overlay = visualize_gradcam(ensemble_model.models[2], image, target_class=predicted_class)
# Display the Grad-CAM output
import matplotlib.pyplot as plt
plt.imshow(overlay)
plt.axis('off')
plt.title("Grad-CAM for ResNet50")
plt.show()
```
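The card does not show what `visualize_gradcam` computes internally. The standard Grad-CAM recipe weights each channel of a convolutional feature map (for ResNet50, typically the output of the last stage, 2048 x 7 x 7 at 224x224 input) by the spatially averaged gradient of the target-class score, sums the weighted channels, and applies ReLU. The NumPy sketch below covers only that arithmetic, assuming activations and gradients have already been captured, for example with forward and backward hooks; it is not the repository's implementation.

```python
import numpy as np


def grad_cam(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Compute a Grad-CAM heatmap for one image.

    activations: (K, H, W) feature maps A^k from the chosen layer.
    gradients:   (K, H, W) gradients of the target-class score w.r.t. A^k.
    Returns an (H, W) map scaled to [0, 1].
    """
    weights = gradients.mean(axis=(1, 2))                    # alpha_k: average-pooled gradients
    weighted = (weights[:, None, None] * activations).sum(axis=0)
    cam = np.maximum(weighted, 0)                            # ReLU keeps positive evidence only
    peak = cam.max()
    return cam / peak if peak > 0 else cam


# Toy example: 2 channels on a 2x2 grid.
acts = np.array([[[1.0, 0.0], [0.0, 0.0]],
                 [[0.0, 0.0], [0.0, 1.0]]])
grads = np.array([[[1.0, 1.0], [1.0, 1.0]],       # channel 0 supports the class
                  [[-1.0, -1.0], [-1.0, -1.0]]])  # channel 1 argues against it
heatmap = grad_cam(acts, grads)
```

Only the region driven by the supporting channel survives; the ReLU discards the location that lowers the class score.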
#### VGG19
```python
from visualization import visualize_gradcam # General Grad-CAM function
# Generate Grad-CAM visualization for VGG19
overlay = visualize_gradcam(ensemble_model.models[1], image, target_class=predicted_class)
# Display the Grad-CAM output
import matplotlib.pyplot as plt
plt.imshow(overlay)
plt.axis('off')
plt.title("Grad-CAM for VGG19")
plt.show()
```
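Each snippet above displays an `overlay`: a low-resolution heatmap resized to the 224x224 input and blended with the image. The card does not specify how the `visualization` module does this. The sketch below is one plain approach, assuming nearest-neighbour upsampling by an integer factor (e.g. 7x7 to 224x224) and a fixed alpha blend with the heatmap in the red channel; many implementations use bilinear interpolation and a colormap such as jet instead. Both helper names are hypothetical.

```python
import numpy as np


def upsample_nearest(cam: np.ndarray, size: int = 224) -> np.ndarray:
    """Nearest-neighbour upsampling of a square (h, h) map to (size, size)."""
    factor, remainder = divmod(size, cam.shape[0])
    if remainder or cam.shape[0] != cam.shape[1]:
        raise ValueError("expects a square map whose side divides size")
    return np.kron(cam, np.ones((factor, factor)))  # repeat each cell factor x factor times


def blend_overlay(image: np.ndarray, cam: np.ndarray, alpha: float = 0.4) -> np.ndarray:
    """Blend a [0, 1] heatmap (shown in the red channel) onto an RGB float image in [0, 1]."""
    heat_rgb = np.zeros_like(image)
    heat_rgb[..., 0] = upsample_nearest(cam, image.shape[0])
    return (1 - alpha) * image + alpha * heat_rgb


# Example: a 7x7 map (ResNet50-sized) with one hot cell over a mid-grey image.
cam_7x7 = np.zeros((7, 7))
cam_7x7[3, 3] = 1.0
grey = np.full((224, 224, 3), 0.5)
blended = blend_overlay(grey, cam_7x7)
```

The result is an RGB float array in [0, 1], which `plt.imshow` accepts directly.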
## Environmental Impact
- **Hardware Type**: T4 GPU x2
- **Hours used**: 50
- **Cloud Provider**: Google Cloud (GCP)
- **Compute Region**: US-Central1
- **Carbon Emitted**: Estimated using [Machine Learning Impact Calculator](https://mlco2.github.io/impact#compute)
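The card names the calculator but gives no emissions figure. Such estimates rest on simple arithmetic: energy = GPU count x power draw x hours (optionally scaled by data-centre PUE), then emissions = energy x grid carbon intensity. The sketch below uses the card's hardware and hours together with the T4's 70 W rated board power, which makes the energy figure roughly an upper bound for the GPUs alone. No carbon-intensity value for US-Central1 is assumed; it must be supplied from the calculator or the provider's published data.

```python
def energy_kwh(num_gpus: int, watts_per_gpu: float, hours: float, pue: float = 1.0) -> float:
    """Energy drawn by the GPUs in kWh, optionally scaled by data-centre PUE."""
    return num_gpus * watts_per_gpu * hours * pue / 1000.0


def emissions_kg(energy: float, kg_co2_per_kwh: float) -> float:
    """CO2-equivalent emissions (kg) for an energy use and a grid carbon intensity."""
    return energy * kg_co2_per_kwh


# Card values: 2x NVIDIA T4 for 50 hours; 70 W is the T4's rated board power.
# CPU, memory and cooling are excluded unless folded in through pue.
gpu_energy = energy_kwh(num_gpus=2, watts_per_gpu=70.0, hours=50.0)
```

With these inputs the GPUs draw at most about 7 kWh; multiplying by the region's intensity via `emissions_kg` gives the emissions estimate.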
## Citation
If you use this model in your research, please cite:
```bibtex
@article{Sparsho2025,
  author  = {Avishek Roy Sparsho},
  title   = {ROYXAI Model For Proper Visualization of Classified Medical Image},
  journal = {Medical AI Research},
  year    = {2025}
}
```
## Acknowledgments
Special thanks to the open-source community and Kaggle for providing medical datasets for deep learning research.
## Contact
For inquiries, please contact: Avishek Roy Sparsho
## License
This model is released under the **Apache 2.0 License**. Use it responsibly.