---
license: apache-2.0
tags:
- medical
- computer-vision
- image-classification
- multi-label-classification
- radiology
- pytorch
- convnext
- cbam
- timm
datasets:
- stanfordmlgroup/chexpert
metrics:
- auc
pipeline_tag: image-classification
library_name: timm
# This link adds the demo widget to the model page
demonstrated_in: calender/GRADCAM-Convnext-Chexpert-Attention
---

# ConvNeXt-CheXpert: CBAM-Augmented Thoracic Classifier

[![GitHub](https://img.shields.io/badge/GitHub-Source_Code-black?logo=github)](https://github.com/jikaan/convnext-chexpert-attention) [![Dataset](https://img.shields.io/badge/Dataset-CheXpert-green)](https://stanfordmlgroup.github.io/competitions/chexpert/) [![Open In Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Open%20Demo-blue)](https://huggingface.co/spaces/calender/GRADCAM-Convnext-Chexpert-Attention)

## Model Description

This repository contains the weights for a **ConvNeXt-Base** architecture fine-tuned for multi-label classification of chest radiographs. The model is augmented with **Convolutional Block Attention Modules (CBAM)** to enhance feature localization and interpretability.

**Try the live demo here:** [👉 GRADCAM-Convnext-Chexpert-Attention Space](https://huggingface.co/spaces/calender/GRADCAM-Convnext-Chexpert-Attention)

* **Architecture:** ConvNeXt-Base + CBAM
* **Validation AUC:** 0.81 (Iteration 6)
* **Input Resolution:** 384x384

## Detectable Classes

The model outputs an independent probability for each of the following 14 classes:

1. No Finding
2. Enlarged Cardiomediastinum
3. Cardiomegaly
4. Lung Opacity
5. Lung Lesion
6. Edema
7. Consolidation
8. Pneumonia
9. Atelectasis
10. Pneumothorax
11. Pleural Effusion
12. Pleural Other
13. Fracture
14. Support Devices

## Interpretability (Grad-CAM)

The CBAM attention modules are intended to help the model localize pathologies, and Grad-CAM heatmaps visualize the image regions that drive each prediction.

Figure 1: Atelectasis analysis. Multi-label detection visualized via Grad-CAM.
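
The heatmaps in Figure 1 come from Grad-CAM. The following is a minimal sketch of the general Grad-CAM technique, not the code used by the demo Space: it captures the feature maps of one layer, weights each channel by the spatial mean of its gradient for a chosen class, and keeps the positive part. The helper name `grad_cam` and the tiny stand-in network `toy_model` are hypothetical and exist only so the example runs without downloading weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def grad_cam(model, target_layer, input_tensor, class_idx):
    """Return an [H, W] heatmap in [0, 1] for one class of a single-image batch."""
    activations, gradients = [], []

    def forward_hook(module, inputs, output):
        # Keep the feature maps and register a hook that captures their gradient.
        activations.append(output)
        output.register_hook(lambda grad: gradients.append(grad))

    handle = target_layer.register_forward_hook(forward_hook)
    try:
        model.zero_grad()
        logits = model(input_tensor)
        # Backpropagate the logit of the class we want to explain.
        logits[0, class_idx].backward()
    finally:
        handle.remove()

    acts = activations[0].detach()   # [1, C, h, w]
    grads = gradients[0].detach()    # [1, C, h, w]

    # Channel weights: spatial mean of the gradients.
    weights = grads.mean(dim=(2, 3), keepdim=True)
    cam = F.relu((weights * acts).sum(dim=1, keepdim=True))  # [1, 1, h, w]

    # Upsample to the input resolution and rescale to [0, 1].
    cam = F.interpolate(cam, size=input_tensor.shape[-2:],
                        mode="bilinear", align_corners=False)[0, 0]
    cam = cam - cam.min()
    return cam / cam.max().clamp(min=1e-8)


# Hypothetical stand-in for the real classifier (14 outputs, like this model).
toy_model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 14),
).eval()

image = torch.randn(1, 3, 64, 64)
# Index 8 corresponds to "Atelectasis" in the class list of this model card.
heatmap = grad_cam(toy_model, toy_model[1], image, class_idx=8)
print(heatmap.shape)
```

With the timm model from the Usage section, the last stage (`model.stages[-1]`) would be a plausible target layer, assuming the loaded model exposes timm's standard ConvNeXt layout. Because the classifier is multi-label, the function is called once per class of interest.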

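For readers unfamiliar with the attention mechanism named above, here is a minimal sketch of a CBAM block as it is usually described: channel attention followed by spatial attention. It is an illustration under stated assumptions, not this model's implementation. The class names, the `reduction=16` ratio, the 7x7 spatial kernel, and the placement on a 1024-channel feature map are all assumptions; the GitHub repository holds the actual definition.

```python
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """Reweights channels using average- and max-pooled descriptors."""

    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        hidden = max(channels // reduction, 1)
        # Shared MLP, written as 1x1 convolutions.
        self.mlp = nn.Sequential(
            nn.Conv2d(channels, hidden, kernel_size=1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, channels, kernel_size=1, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg = self.mlp(torch.mean(x, dim=(2, 3), keepdim=True))
        mx = self.mlp(torch.amax(x, dim=(2, 3), keepdim=True))
        return x * torch.sigmoid(avg + mx)


class SpatialAttention(nn.Module):
    """Reweights spatial positions using channel-wise mean and max maps."""

    def __init__(self, kernel_size: int = 7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size,
                              padding=kernel_size // 2, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg = torch.mean(x, dim=1, keepdim=True)
        mx = torch.amax(x, dim=1, keepdim=True)
        attn = torch.sigmoid(self.conv(torch.cat([avg, mx], dim=1)))
        return x * attn


class CBAM(nn.Module):
    """Channel attention followed by spatial attention."""

    def __init__(self, channels: int, reduction: int = 16, kernel_size: int = 7):
        super().__init__()
        self.channel = ChannelAttention(channels, reduction)
        self.spatial = SpatialAttention(kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.spatial(self.channel(x))


# Assumed shape: ConvNeXt-Base final-stage features for a 384x384 input.
features = torch.randn(1, 1024, 12, 12)
out = CBAM(1024)(features)
print(out.shape)  # same shape as the input feature map
```

The block preserves the feature-map shape, which is why it can be inserted between existing stages of a backbone without changing the layers around it.
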
## Usage

### Prerequisites

```bash
pip install torch torchvision timm pillow
```

### Inference Code

This script loads the model and runs inference using the normalization statistics from training.

```python
import torch
from PIL import Image
from torchvision import transforms
import timm

# 1. Configuration
# Ensure you have downloaded model.pth from this repository
MODEL_PATH = "model.pth"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

pathologies = [
    "No Finding", "Enlarged Cardiomediastinum", "Cardiomegaly",
    "Lung Opacity", "Lung Lesion", "Edema", "Consolidation",
    "Pneumonia", "Atelectasis", "Pneumothorax", "Pleural Effusion",
    "Pleural Other", "Fracture", "Support Devices"
]

# 2. Load Model
# Assumption: the checkpoint keys match timm's plain convnext_base.
# If loading fails because of extra CBAM keys, build the model from the
# definition in the GitHub repository instead.
model = timm.create_model('convnext_base', pretrained=False, num_classes=14)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()

# 3. Preprocessing
# Note: These specific mean/std values are critical for accuracy
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5029414296150208] * 3,
        std=[0.2892409563064575] * 3
    )
])

# 4. Run Prediction
image = Image.open('chest_xray.jpg')
input_tensor = transform(image).unsqueeze(0).to(DEVICE)

with torch.no_grad():
    logits = model(input_tensor)
    # Multi-label task: apply a sigmoid per class rather than a softmax
    probs = torch.sigmoid(logits)[0]

print("Predictions:")
for name, score in zip(pathologies, probs):
    print(f"{name}: {score.item():.3f}")
```

## Citation

```bibtex
@misc{convnext_chexpert_2025,
  author = {Calendar, S.},
  title = {ConvNeXt-CheXpert: CBAM-Augmented Thoracic Classifier},
  year = {2025},
  publisher = {Hugging Face},
  url = {https://huggingface.co/calender/GRADCAM-Convnext-Chexpert-Attention}
}
```