Spaces:
Running
Running
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import spaces | |
| import torch | |
| import torch.nn.functional as F | |
| from einops import rearrange | |
| from transformers import AutoModel | |
def calculate_ctr(mask: np.ndarray) -> float:
    """Estimate the cardiothoracic ratio (CTR) from a 2-D segmentation mask.

    The ratio is the horizontal extent of the heart divided by the horizontal
    extent of the lungs (both lung labels combined). Each extent is the
    difference between the largest and smallest column index that contains
    at least one pixel of that structure.

    Args:
        mask: Label map of shape (height, width) in which labels 1 and 2 mark
            the lungs and label 3 marks the heart.

    Returns:
        The ratio as a float. Returns ``nan`` when the heart or lungs are
        absent from the mask, or when the lungs span a single column, because
        the ratio is undefined in those cases.
    """
    # Indices (ascending) of the columns that contain each structure.
    lung_cols = np.flatnonzero(np.isin(mask, (1, 2)).any(axis=0))
    heart_cols = np.flatnonzero((mask == 3).any(axis=0))
    if lung_cols.size == 0 or heart_cols.size == 0:
        # Missing segmentation: the original code raised ValueError here.
        return float("nan")

    lung_range = lung_cols[-1] - lung_cols[0]
    if lung_range == 0:
        # Avoid a divide-by-zero that would yield inf/nan with a warning.
        return float("nan")
    heart_range = heart_cols[-1] - heart_cols[0]
    return float(heart_range / lung_range)
def make_overlay(
    img: np.ndarray, mask: np.ndarray, alpha: float = 0.7
) -> np.ndarray:
    """Alpha-blend ``mask`` onto ``img`` and return a ``uint8`` image.

    Computes ``alpha * img + (1 - alpha) * mask`` element-wise (the arrays
    must broadcast), clips the result to [0, 255], and truncates to uint8.

    Args:
        img: Base image (an RGB radiograph in this app).
        mask: Color layer blended on top of ``img``.
        alpha: Weight given to ``img``; ``mask`` receives ``1 - alpha``.

    Returns:
        The blended image with dtype ``uint8``.
    """
    overlay = alpha * img + (1 - alpha) * mask
    # Clip before casting: out-of-range values (possible when alpha is outside
    # [0, 1]) would otherwise wrap around or be undefined in the uint8 cast.
    return np.clip(overlay, 0, 255).astype(np.uint8)
def _run_model(model, radiograph):
    """Run one forward pass of a Hub model on a single grayscale radiograph.

    ``model.preprocess`` converts the raw image to a 2-D array. That array is
    moved to the module-level ``device``, given batch and channel axes
    (``h w -> 1 1 h w``), and fed to the model under ``torch.inference_mode``.

    Returns:
        The model's raw output, which callers index by key (e.g. ``"mask"``).
    """
    x = model.preprocess(radiograph)
    x = torch.from_numpy(x).float().to(device)
    x = rearrange(x, "h w -> 1 1 h w")
    with torch.inference_mode():
        return model(x)


def _heatmap_overlay(mask, rgb):
    """Colorize a single-channel model mask with JET and blend it onto ``rgb``.

    Args:
        mask: Tensor of shape (1, 1, h', w'), resized bilinearly to the height
            and width of ``rgb``. Values are assumed to lie in [0, 1] (TODO
            confirm against the model card); they are clipped after scaling
            by 255 so the uint8 conversion cannot wrap.
        rgb: RGB image of shape (h, w, 3).

    Returns:
        The ``uint8`` RGB overlay produced by ``make_overlay``.
    """
    h, w = rgb.shape[:2]
    mask = F.interpolate(mask, size=(h, w), mode="bilinear")
    mask = np.clip(mask.cpu().numpy()[0, 0] * 255, 0, 255).astype(np.uint8)
    heat = cv2.applyColorMap(mask, cv2.COLORMAP_JET)
    # applyColorMap produces BGR; reverse the channel axis to get RGB.
    return make_overlay(rgb, heat[..., ::-1])


def predict(Radiograph):
    """Run the three chest-radiograph models on one grayscale image.

    Args:
        Radiograph: 2-D grayscale image from the ``gr.Image(image_mode="L")``
            input, or ``None`` if the user submitted without an image.

    Returns:
        ``[info_string, preds, info_overlay, pna_overlay, ptx_overlay]``:
        a text summary (view, predicted age and sex, and the CTR for frontal
        views), a dict of pneumonia/pneumothorax scores, and three RGB
        overlays (heart & lungs, pneumonia, pneumothorax).

    Raises:
        gr.Error: if no image was provided.
    """
    if Radiograph is None:
        raise gr.Error("Please upload a chest radiograph.")
    rg = cv2.cvtColor(Radiograph, cv2.COLOR_GRAY2RGB)
    h, w = rg.shape[:2]

    # Heart & lungs segmentation plus the view/age/sex heads.
    info_out = _run_model(cxr_info_model, Radiograph)
    # Upsample the per-class mask to image size, then take the per-pixel class.
    info_mask = F.interpolate(info_out["mask"], size=(h, w), mode="bilinear")
    info_mask = info_mask.argmax(1)[0]
    # Labels 1-3 (lungs 1 and 2, heart 3) become three color channels;
    # label 0 is dropped (presumably background).
    info_mask_3ch = F.one_hot(info_mask, num_classes=4)[..., 1:]
    info_mask_3ch = (info_mask_3ch.cpu().numpy() * 255).astype(np.uint8)
    info_overlay = make_overlay(rg, info_mask_3ch[..., ::-1])

    view = info_out["view"].argmax(1).item()
    age = info_out["age"].item()
    sex = "male" if info_out["female"].item() < 0.5 else "female"

    view_descriptions = {
        0: "This is a frontal chest radiograph (AP projection).",
        1: "This is a frontal chest radiograph (PA projection).",
        2: "This is a lateral chest radiograph.",
    }
    # An unknown view contributes an empty first line, so the joined text
    # starts with "\n" exactly as the original concatenation did.
    lines = [
        view_descriptions.get(view, ""),
        f"The patient's predicted age is {round(age)} years.",
        f"The patient's predicted sex is {sex}.",
    ]
    if view in {0, 1}:
        ctr = calculate_ctr(info_mask.cpu().numpy())
        if np.isfinite(ctr):
            lines.append(f"The estimated cardiothoracic ratio (CTR) is {ctr:0.2f}.")
        else:
            lines.append("The cardiothoracic ratio (CTR) could not be estimated.")
        if view == 0:
            lines.append(
                "Note that the cardiac silhouette is magnified in the AP projection."
            )
    if view == 2:
        lines.append("NOTE: The below outputs are NOT VALID for lateral radiographs.")
    info_string = "\n".join(lines)

    pna_out = _run_model(pna_model, Radiograph)
    pna_overlay = _heatmap_overlay(pna_out["mask"], rg)

    ptx_out = _run_model(ptx_model, Radiograph)
    ptx_overlay = _heatmap_overlay(ptx_out["mask"], rg)

    preds = {"Pneumonia": pna_out["cls"].item(), "Pneumothorax": ptx_out["cls"].item()}
    return [info_string, preds, info_overlay, pna_overlay, ptx_overlay]
# Input component: a single grayscale ("L" mode) radiograph.
image = gr.Image(image_mode="L")

# Output components, in the same order as the list returned by ``predict``:
# summary text, per-finding scores, then the three RGB overlays.
info_textbox = gr.Textbox(show_label=False)
labels = gr.Label(show_heading=False, show_label=False)
heatmap1, heatmap2, heatmap3 = (
    gr.Image(label=title, image_mode="RGB")
    for title in ("Heart & Lungs", "Pneumonia", "Pneumothorax")
)
# Page description shown above the interface. Paragraphs are separated by
# blank lines so Markdown renders them separately instead of folding
# everything after the numbered list into its last item.
_DESCRIPTION = """
# Deep Learning for Chest Radiographs

This demo uses 3 models for chest radiographs:

1) Heart and lungs segmentation, with age, view, and sex prediction <https://huggingface.co/ianpan/chest-x-ray-basic>
2) Pneumonia classification and segmentation <https://huggingface.co/ianpan/pneumonia-cxr>
3) Pneumothorax classification and segmentation <https://huggingface.co/ianpan/pneumothorax-cxr>

Note that the pneumonia and pneumothorax heatmaps produced by these models are based on pixel-level segmentation maps.
Thus, they are expected to be more accurate than non-explicit localization methods such as GradCAM.

The example radiograph is my own, from when I had pneumonia.

This model is for demonstration purposes only and has NOT been approved by any regulatory agency for clinical use. The user assumes
any and all responsibility regarding their own use of this model and its outputs. Do NOT upload any images containing protected
health information, as this demonstration is not compliant with patient privacy laws.

Created by: Ian Pan, <https://ianpan.me>

Last updated: December 27, 2024
"""

with gr.Blocks() as demo:
    gr.Markdown(_DESCRIPTION)
    # ``predict`` wrapped in an Interface rendered inside this Blocks page.
    # ``outputs`` are listed in the same order as predict's return list.
    # cache_examples=True has Gradio precompute the example's outputs.
    # NOTE(review): "examples/cxr.png" is a relative path; confirm it resolves
    # from the directory the app is launched in.
    gr.Interface(
        fn=predict,
        inputs=image,
        outputs=[info_textbox, labels, heatmap1, heatmap2, heatmap3],
        examples=["examples/cxr.png"],
        cache_examples=True,
    )
if __name__ == "__main__":
    # Use the GPU when torch can see one; otherwise run on the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device `{device}` ...")

    # Load the three Hub models in order, switch them to eval mode, and move
    # them to ``device``. These module-level names are read by ``predict``.
    # trust_remote_code=True executes model code downloaded from each repo.
    cxr_info_model, pna_model, ptx_model = (
        AutoModel.from_pretrained(repo_id, trust_remote_code=True)
        .eval()
        .to(device)
        for repo_id in (
            "ianpan/chest-x-ray-basic",
            "ianpan/pneumonia-cxr",
            "ianpan/pneumothorax-cxr",
        )
    )

    demo.launch(share=True)