"""Gradio demo: predict multiplex immunofluorescence (mIF) channels from H&E tiles."""

import json
from datetime import datetime

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image, ImageDraw, ImageFont

from model import MIPHEIViT

# ---------------------------------------------------------------------------
# Model and configuration (loaded once, at import time)
# ---------------------------------------------------------------------------
repo_id = "Estabousi/MIPHEI-vit"
model = MIPHEIViT.from_pretrained_hf(repo_id=repo_id)
config_path = hf_hub_download(repo_id=repo_id, filename="config_hf.json")
model.eval()

# Per-channel RGB normalization statistics, shaped (3, 1, 1) so they
# broadcast over a (3, H, W) image tensor in `preprocess`.
mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).reshape((-1, 1, 1))
std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).reshape((-1, 1, 1))

with open(config_path, "r", encoding="utf-8") as f:
    config = json.load(f)
# Ordered names of the predicted channels; index i names output plane i.
channel_names = config["targ_channel_names"]

# Pseudocolour (R, G, B) per channel. Every component must lie in 0-255,
# since `apply_color_map` scales these by an intensity in [0, 1].
# NOTE: CD4 and Pan-CK currently share pure red.
channel_colors = {
    "Hoechst": (0, 0, 255),          # Blue (DAPI, nuclear stain)
    "CD31": (0, 255, 255),           # Cyan (endothelial)
    "CD45": (255, 255, 0),           # Yellow (leukocyte common antigen)
    "CD68": (255, 165, 0),           # Orange (macrophages)
    "CD4": (255, 0, 0),              # Red (helper T cells)
    "FOXP3": (138, 43, 226),         # Purple/Blue-Violet (regulatory T cells)
    "CD8a": (50, 205, 50),           # Lime Green (cytotoxic T cells); was (303, 100, 100), out of 0-255 range
    "CD45RO": (255, 105, 180),       # Hot Pink (memory T cells)
    "CD20": (0, 191, 255),           # Deep Sky Blue (B cells)
    "PD-L1": (255, 0, 255),          # Magenta
    "CD3e": (95, 95, 94),            # Dark Gray (T cells)
    "CD163": (184, 134, 11),         # Dark Goldenrod (M2 macrophages)
    "E-cadherin": (242, 12, 43),     # Crimson Red (epithelial marker)
    "Ki67": (255, 20, 147),          # Deep Pink (proliferation marker)
    "Pan-CK": (255, 0, 0),           # Red (epithelial/carcinoma)
    "SMA": (0, 255, 0),              # Green (smooth muscle, myofibroblasts)
}

# Per-channel display ceiling on the 0-255 scale: a normalized prediction at
# or above ceiling/255 is rendered at full brightness. Hoechst uses the full
# range (255), the channels listed in `correction_map` use their own value
# (100 here), and every other channel uses `default_contrast` (150).
default_contrast = 150.0
correction_map = {"Hoechst": 255.0, "CD8a": 100, "CD31": 100, "CD4": 100, "CD68": 100, "FOXP3": 100}
max_contrast_correction_value = torch.tensor([
    correction_map.get(name, default_contrast) / 255 for name in channel_names
]).reshape(len(channel_names), 1, 1)

# Channels blended into the composite overlay, and their indices into
# `channel_names` (computed once; raises ValueError at startup if the
# config lacks one of them).
overlay_markers = ["Hoechst", "Pan-CK", "SMA", "CD45"]
overlay_idx = [channel_names.index(name) for name in overlay_markers]


def preprocess(image):
    """Turn a PIL image into a normalized model input batch.

    The image is converted to RGB, resized to 256x256, scaled to [0, 1] and
    normalized with the module-level `mean` / `std`.

    Returns:
        A float32 tensor of shape [1, 3, 256, 256].
    """
    image = image.convert("RGB").resize((256, 256))
    tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255
    tensor = (tensor - mean) / std
    return tensor.unsqueeze(0)  # [1, 3, H, W]


def draw_legend_on_image(image, channel_names, channel_colors, indices,
                         box_size=18, spacing=5, top_margin=5):
    """Draw a semi-transparent legend on the bottom-right of the image.

    Args:
        image: PIL image to annotate (any mode convertible to RGBA).
        channel_names: Sequence of channel names, indexed by `indices`.
        channel_colors: Mapping from channel name to an (R, G, B) tuple.
        indices: Indices into `channel_names` of the entries to list.
        box_size: Side length in pixels of each colour swatch.
        spacing: Vertical gap in pixels between legend rows.
        top_margin: Extra pixels added to the legend background height.

    Returns:
        A new RGB PIL image with the legend composited on top.
    """
    overlay = image.convert("RGBA")  # alpha channel needed for compositing
    legend_layer = Image.new("RGBA", overlay.size, (255, 255, 255, 0))
    draw = ImageDraw.Draw(legend_layer)
    font = ImageFont.load_default()

    names = [channel_names[idx] for idx in indices]
    n_rows = len(names)
    legend_height = top_margin + box_size * n_rows + spacing * max(n_rows - 1, 0)
    # Default width is 60 px; widen it when swatch + 5 px gap + the longest
    # label would not fit, so text never spills past the background.
    text_width = max((draw.textlength(name, font=font) for name in names), default=0)
    legend_width = max(60, box_size + 5 + int(text_width) + 1)
    x_start = overlay.width - legend_width - 10
    y_start = overlay.height - legend_height - 10

    # Semi-transparent white background behind the whole legend.
    draw.rectangle(
        [x_start - 5, y_start - 5, x_start + legend_width + 5, y_start + legend_height + 5],
        fill=(255, 255, 255, 180),
    )
    for row, name in enumerate(names):
        y = y_start + row * (box_size + spacing)
        draw.rectangle(
            [x_start, y, x_start + box_size, y + box_size],
            fill=tuple(channel_colors[name]) + (255,),
        )
        draw.text((x_start + box_size + 5, y), name, fill=(0, 0, 0, 255), font=font)

    combined = Image.alpha_composite(overlay, legend_layer)
    return combined.convert("RGB")  # back to RGB for display


def merge_colored_images(color_imgs, top4_idx):
    """Additively blend selected pseudocoloured images into one RGB image.

    Args:
        color_imgs: Sequence of same-sized RGB PIL images.
        top4_idx: Indices into `color_imgs` of the images to blend (any
            number of indices is accepted despite the name).

    Returns:
        An RGB PIL image holding the per-pixel sum, clipped to [0, 255].
    """
    # Accumulate in float32 so the sum cannot wrap before clipping.
    accum = np.zeros_like(np.asarray(color_imgs[0]), dtype=np.float32)
    for idx in top4_idx:
        accum += np.asarray(color_imgs[idx], dtype=np.float32)
    return Image.fromarray(np.clip(accum, 0, 255).astype(np.uint8))


def apply_color_map(gray_img, rgb_color):
    """Map a grayscale image to RGB using a fixed pseudocolour.

    Each pixel becomes `rgb_color * (gray / 255)`. The result is clipped to
    [0, 255] before the uint8 cast, so an out-of-range colour saturates
    instead of wrapping around.
    """
    gray = np.asarray(gray_img).astype(np.float32) / 255.0
    color = np.asarray(rgb_color[:3], dtype=np.float32)
    rgb = gray[..., None] * color  # broadcast (H, W, 1) * (3,) -> (H, W, 3)
    return Image.fromarray(np.clip(rgb, 0, 255).astype(np.uint8))


def predict(image):
    """Run the model on one H&E image and build the display images.

    Returns:
        `[overlay_with_legend] + channel_imgs`: the composite overlay of
        `overlay_markers` followed by one pseudocoloured image per channel,
        in `channel_names` order.

    Raises:
        gr.Error: if no image was provided (shown to the user by Gradio).
        RuntimeError: if the model's channel count differs from the config.
    """
    if image is None:
        raise gr.Error("Please upload or select an H&E image first.")
    print(f"[{datetime.now().isoformat()}] Inference run")
    input_tensor = preprocess(image)
    with torch.inference_mode():
        output = model(input_tensor)[0]  # [C, H, W], one plane per channel

    # Clamp raw predictions to [-0.9, 0.9] and map that range linearly to [0, 1].
    output = (output.clamp(-0.9, 0.9) + 0.9) / 1.8
    # Stretch each channel by its contrast ceiling, then quantize to 0-255.
    ceiling = max_contrast_correction_value.to(output.device).clamp(min=1e-6)
    output_vis = (output / ceiling).clamp(0, 1) * 255
    output_vis = np.uint8(output_vis.cpu().numpy())

    if output_vis.shape[0] != len(channel_names):
        raise RuntimeError(
            f"Model returned {output_vis.shape[0]} channels but the config "
            f"lists {len(channel_names)}."
        )

    # One pseudocoloured RGB image per channel (uint8 2-D arrays load as "L").
    channel_imgs = [
        apply_color_map(Image.fromarray(output_vis[i]), channel_colors[name])
        for i, name in enumerate(channel_names)
    ]

    overlay = merge_colored_images(channel_imgs, overlay_idx)
    overlay_with_legend = draw_legend_on_image(overlay, channel_names, channel_colors, overlay_idx)
    return [overlay_with_legend] + channel_imgs


# Markdown header shown at the top of the page.
with open("HEADER.md", "r", encoding="utf-8") as f:
    HEADER_MD = f.read()

# Build interface using Blocks
with gr.Blocks() as demo:
    gr.Markdown(HEADER_MD)
    with gr.Row():
        # LEFT: input + examples + button
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input H&E")
            run_btn = gr.Button("Run Prediction")
            gr.Examples(
                examples=[
                    ["examples/crc100k_val.jpg"],
                    ["examples/orion_test_1.jpg"],
                    ["examples/orion_test_2.jpg"],
                    ["examples/orion_test_3.jpg"],
                    ["examples/orion_test_4.jpg"],
                    ["examples/orion_test_5.jpg"],
                    ["examples/tcga.jpg"],
                    ["examples/hemit.jpg"],
                ],
                inputs=[input_image],
                label="Example H&E tile (TCGA, ORION Test, CRC100K, HEMIT)",
            )
        # RIGHT: outputs (one image per configured channel, plus the overlay)
        with gr.Column(scale=2):
            overlay_image = gr.Image(type="pil", label="mIF Overlay")
            channel_images = [
                gr.Image(type="pil", label=f"mIF Channel {name}")
                for name in channel_names
            ]
    output_images = [overlay_image] + channel_images
    run_btn.click(fn=predict, inputs=input_image, outputs=output_images)

if __name__ == "__main__":
    demo.launch(ssr_mode=False)