Spaces:
Running
Running
| import gradio as gr | |
| from datetime import datetime | |
| from huggingface_hub import hf_hub_download | |
| import torch | |
| import json | |
| from PIL import Image | |
| from PIL import ImageDraw, ImageFont | |
| import numpy as np | |
| from model import MIPHEIViT | |
# Load the pretrained model and its Hub config once, at import time, so that
# every prediction request reuses the same objects.
repo_id = "Estabousi/MIPHEI-vit"
model = MIPHEIViT.from_pretrained_hf(repo_id=repo_id)
config_path = hf_hub_download(repo_id=repo_id, filename="config_hf.json")
model.eval()

# Per-RGB-channel normalization statistics, shaped (3, 1, 1) so they broadcast
# over a (3, H, W) image tensor in preprocess().
mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).reshape((-1, 1, 1))
std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).reshape((-1, 1, 1))

# JSON is UTF-8 by specification; pass the encoding explicitly so the read
# does not depend on the host locale.
with open(config_path, "r", encoding="utf-8") as f:
    config = json.load(f)

# Ordered names of the predicted channels, read from the model config.
channel_names = config["targ_channel_names"]
# Pseudocolor assigned to each marker as an (R, G, B) tuple. Every component
# must lie in 0-255 because the colors end up in uint8 images.
channel_colors = {
    "Hoechst": (0, 0, 255),  # Blue (DAPI, nuclear stain)
    "CD31": (0, 255, 255),  # Cyan (endothelial)
    "CD45": (255, 255, 0),  # Yellow (leukocyte common antigen)
    "CD68": (255, 165, 0),  # Orange (macrophages)
    "CD4": (255, 0, 0),  # Red (helper T cells)
    "FOXP3": (138, 43, 226),  # Purple/Blue-Violet (regulatory T cells)
    # Light red (cytotoxic T cells). The red component was 303, which is not
    # a valid 8-bit value and overflowed when cast to uint8; it is clamped to
    # the maximum of 255 here.
    # NOTE(review): this entry was labelled "Green" - confirm the intended color.
    "CD8a": (255, 100, 100),
    "CD45RO": (255, 105, 180),  # Hot Pink (memory T cells)
    "CD20": (0, 191, 255),  # Deep Sky Blue (B cells)
    "PD-L1": (255, 0, 255),  # Magenta
    # Dark gray (T cells).
    # NOTE(review): this entry was labelled "Crimson" - confirm the intended color.
    "CD3e": (95, 95, 94),
    "CD163": (184, 134, 11),  # Dark Goldenrod (M2 macrophages)
    # Red (epithelial marker).
    # NOTE(review): this entry was labelled "Spring Green" - confirm the intended color.
    "E-cadherin": (242, 12, 43),
    "Ki67": (255, 20, 147),  # Deep Pink (proliferation marker)
    "Pan-CK": (255, 0, 0),  # Red (epithelial/carcinoma)
    "SMA": (0, 255, 0),  # Green (smooth muscle, myofibroblasts)
}
# Per-channel contrast ceilings on a 0-255 scale. Channels listed in
# correction_map use their own ceiling; every other channel falls back to
# default_contrast.
default_contrast = 150.0
correction_map = {
    "Hoechst": 255.0,
    "CD8a": 100,
    "CD31": 100,
    "CD4": 100,
    "CD68": 100,
    "FOXP3": 100,
}

# Ceilings rescaled to [0, 1] and shaped (C, 1, 1) so the tensor broadcasts
# over a (C, H, W) model output, in the same order as channel_names.
max_contrast_correction_value = torch.tensor(
    [correction_map.get(name, default_contrast) / 255 for name in channel_names]
).reshape(len(channel_names), 1, 1)

# Markers that are blended together into the composite overlay image.
overlay_markers = ["Hoechst", "Pan-CK", "SMA", "CD45"]
def preprocess(image):
    """Turn a PIL image into a normalized, batched model input.

    The image is converted to RGB, resized to 256x256, scaled to [0, 1] and
    normalized with the module-level ``mean`` and ``std`` tensors.

    Args:
        image: PIL image (anything exposing ``convert`` and ``resize``).

    Returns:
        Float tensor of shape [1, 3, 256, 256].
    """
    resized = image.convert("RGB").resize((256, 256))
    pixels = np.array(resized)  # (H, W, 3) array
    # Channels-first layout, then scale the 0-255 pixel values to [0, 1].
    chw = torch.from_numpy(pixels).permute(2, 0, 1).float().div(255)
    normalized = chw.sub(mean).div(std)
    return normalized.unsqueeze(0)  # add the batch dimension: [1, 3, H, W]
def draw_legend_on_image(image, channel_names, channel_colors, indices, box_size=18, spacing=5, top_margin=5):
    """Overlay a translucent color legend in the bottom-right corner of an image.

    Args:
        image: PIL image to annotate; it is not modified in place.
        channel_names: Sequence of channel names, indexed by ``indices``.
        channel_colors: Mapping from channel name to an (R, G, B) tuple.
        indices: Positions in ``channel_names`` of the entries to list.
        box_size: Side length, in pixels, of each color swatch.
        spacing: Vertical gap, in pixels, between consecutive swatches.
        top_margin: Extra pixels added to the height of the legend background.

    Returns:
        A new RGB PIL image with the legend composited on top.
    """
    base = image.convert("RGBA")  # alpha channel needed for compositing
    layer = Image.new("RGBA", base.size, (255, 255, 255, 0))  # fully transparent
    pen = ImageDraw.Draw(layer)
    label_font = ImageFont.load_default()

    entry_count = len(indices)
    panel_width = 60  # fixed width; long marker names may exceed it
    panel_height = top_margin + box_size * entry_count + spacing * (entry_count - 1)

    # Anchor the panel 10 px away from the right and bottom edges.
    left = base.width - panel_width - 10
    top = base.height - panel_height - 10

    # Semi-transparent white backdrop, padded by 5 px on every side.
    pen.rectangle(
        [left - 5, top - 5, left + panel_width + 5, top + panel_height + 5],
        fill=(255, 255, 255, 180),
    )

    row_step = box_size + spacing
    for row, channel_idx in enumerate(indices):
        label = channel_names[channel_idx]
        swatch_color = channel_colors[label]
        swatch_top = top + row * row_step
        pen.rectangle(
            [left, swatch_top, left + box_size, swatch_top + box_size],
            fill=swatch_color + (255,),  # opaque swatch
        )
        pen.text((left + box_size + 5, swatch_top), label, fill=(0, 0, 0, 255), font=label_font)

    # Blend the legend layer over the image and drop alpha for display.
    return Image.alpha_composite(base, layer).convert("RGB")
def merge_colored_images(color_imgs, top4_idx):
    """Additively blend selected RGB channel images into a single RGB image.

    Args:
        color_imgs: Sequence of equally sized RGB PIL images; the first one
            determines the output shape.
        top4_idx: Indices into ``color_imgs`` of the images to sum.

    Returns:
        RGB PIL image holding the per-pixel sum, saturated to 0-255.
    """
    canvas_shape = np.array(color_imgs[0]).shape
    # Accumulate in float32 so the sum can exceed 255 before it is clipped.
    blended = np.zeros(canvas_shape, dtype=np.float32)
    for channel_idx in top4_idx:
        layer = np.asarray(color_imgs[channel_idx], dtype=np.float32)
        np.add(blended, layer, out=blended)
    blended_u8 = np.clip(blended, 0, 255).astype(np.uint8)
    return Image.fromarray(blended_u8, mode='RGB')
def apply_color_map(gray_img, rgb_color):
    """Map a grayscale image to RGB using a fixed pseudocolor.

    Each pixel intensity (0-255) is rescaled to [0, 1] and multiplied by the
    three color components, so a full-intensity pixel takes ``rgb_color``.

    Args:
        gray_img: 2-D grayscale image (PIL image or array) with values 0-255.
        rgb_color: Sequence whose first three items are the R, G, B components.

    Returns:
        RGB PIL image of the same height and width as ``gray_img``.
    """
    gray = np.asarray(gray_img).astype(np.float32) / 255.0
    color = np.asarray(rgb_color, dtype=np.float32)[:3]
    rgb = gray[..., np.newaxis] * color  # (H, W, 1) * (3,) -> (H, W, 3)
    # Saturate before the uint8 cast: a component above 255 would otherwise
    # overflow and render bright pixels with the wrong color.
    rgb = np.clip(rgb, 0, 255).astype(np.uint8)
    return Image.fromarray(rgb, mode='RGB')
def predict(image):
    """Run the model on one H&E image and render the predicted channels.

    Args:
        image: PIL image supplied by the Gradio input component, or ``None``
            when the user has not provided one.

    Returns:
        A list of RGB PIL images: the composite overlay of ``overlay_markers``
        (with legend) first, followed by one pseudocolored image per
        predicted channel.

    Raises:
        gr.Error: If no input image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of an AttributeError
        # from preprocess().
        raise gr.Error("Please upload or select an H&E image before running the prediction.")

    print(f"[{datetime.now().isoformat()}] Inference run")
    input_tensor = preprocess(image)
    with torch.inference_mode():
        # First (only) batch item: [C, H, W]; presumably C == len(channel_names).
        output = model(input_tensor)[0]
        # Clamp to [-0.9, 0.9], then rescale linearly to [0, 1].
        output = (output.clamp(-0.9, 0.9) + 0.9) / 1.8
        # Stretch each channel by its contrast ceiling, then saturate.
        output_vis = output / max_contrast_correction_value.to(output.device).clamp(min=1e-6)
        output_vis = output_vis.clamp(0, 1) * 255
    output_vis = output_vis.cpu().numpy().astype(np.uint8)

    # Convert each channel to a grayscale PIL image, then pseudocolor it.
    channel_imgs = [
        apply_color_map(Image.fromarray(plane, mode='L'), channel_colors[channel_names[i]])
        for i, plane in enumerate(output_vis)
    ]

    fixed_idx = [channel_names.index(name) for name in overlay_markers]
    overlay = merge_colored_images(channel_imgs, fixed_idx)
    overlay_with_legend = draw_legend_on_image(overlay, channel_names, channel_colors, fixed_idx)
    return [overlay_with_legend] + channel_imgs
# Markdown header shown at the top of the page.
with open("HEADER.md", "r", encoding="utf-8") as f:
    HEADER_MD = f.read()

# Build the interface using Blocks.
with gr.Blocks() as demo:
    gr.Markdown(HEADER_MD)
    with gr.Row():
        # LEFT: input + button + examples
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input H&E")
            run_btn = gr.Button("Run Prediction")
            gr.Examples(
                examples=[
                    ["examples/crc100k_val.jpg"],
                    ["examples/orion_test_1.jpg"],
                    ["examples/orion_test_2.jpg"],
                    ["examples/orion_test_3.jpg"],
                    ["examples/orion_test_4.jpg"],
                    ["examples/orion_test_5.jpg"],
                    ["examples/tcga.jpg"],
                    ["examples/hemit.jpg"],
                ],
                inputs=[input_image],
                label="Example H&E tile (TCGA, ORION Test, CRC100K, HEMIT)"
            )
        # RIGHT: outputs
        with gr.Column(scale=2):
            overlay_image = gr.Image(type="pil", label="mIF Overlay")
            # One output component per configured channel. Deriving the count
            # from channel_names keeps the UI in step with the model config
            # instead of assuming exactly 16 channels.
            channel_images = [
                gr.Image(type="pil", label=f"mIF Channel {name}")
                for name in channel_names
            ]
    # Order matches predict()'s return value: overlay first, then channels.
    output_images = [overlay_image] + channel_images
    run_btn.click(fn=predict, inputs=input_image, outputs=output_images)

if __name__ == "__main__":
    demo.launch(ssr_mode=False)