import io
import re

from PIL import Image
import numpy as np
import pydicom
import torch
from fastapi import HTTPException, UploadFile
from skimage.transform import resize


# Leading list markers stripped from each report line: bullets ("-", "•") and
# enumerators such as "1." or "2)". The negative lookahead keeps decimal
# measurements like "2.5 cm" intact; the previous pattern ([-•\d.]*) would
# have truncated that to "cm".
_LIST_MARKER_RE = re.compile(r"^\s*(?:(?:[-•]+|\d+[.)](?!\d))\s*)*")

# Case-insensitive section headers. Findings run until an "Impression:" /
# "Impressions:" header or, when the report has none, to the end of the text.
_FINDINGS_RE = re.compile(r"findings:\s*(.*?)(?=impressions?:|\Z)", re.IGNORECASE | re.DOTALL)
_IMPRESSION_RE = re.compile(r"impressions?:\s*(.*)", re.IGNORECASE | re.DOTALL)

# Per-channel normalization constants, shaped (3, 1, 1) so they broadcast over
# a CxHxW tensor. The values are the widely used ImageNet RGB mean/std;
# presumably the downstream model was trained with them — confirm.
_NORM_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_NORM_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# Upload content types decoded directly with PIL, and file extensions that
# mark an upload as DICOM when the content type does not.
_IMAGE_CONTENT_TYPES = ("image/jpeg", "image/png", "image/jpg")
_DICOM_EXTENSIONS = (".dcm", ".dicom")


def clean_paragraph(text):
    """Collapse a multi-line report section into a single paragraph.

    Each line has its leading list marker removed (bullets such as "-" or "•"
    and enumerators such as "1." or "2)") along with surrounding whitespace.
    Lines left empty are dropped, and the rest are joined with single spaces.
    """
    cleaned = (_LIST_MARKER_RE.sub("", line, count=1).strip() for line in text.splitlines())
    return " ".join(line for line in cleaned if line)


def split_report_sections(report_text):
    """Split a free-text report into cleaned "findings" and "impression" text.

    All "*" characters are removed first (e.g. markdown emphasis). The findings
    section is the text after "Findings:" up to the first "Impression(s):"
    header, or to the end of the report if there is no such header. The
    impression section is everything after the first "Impression(s):" header.
    Header matching is case-insensitive. A missing section yields "". Both
    sections are normalized with clean_paragraph.

    Returns:
        dict with keys "findings" and "impression", both str.
    """
    report_text = report_text.replace("*", "")

    findings_match = _FINDINGS_RE.search(report_text)
    impression_match = _IMPRESSION_RE.search(report_text)

    findings_raw = findings_match.group(1) if findings_match else ""
    impression_raw = impression_match.group(1) if impression_match else ""

    return {
        "findings": clean_paragraph(findings_raw),
        "impression": clean_paragraph(impression_raw),
    }


def load_image(image):
    """Convert a PIL image into a normalized float32 tensor of shape 1x3x224x224.

    The image is forced to RGB and scaled to [0, 1]. It is resized to 224x224
    (the channel axis is kept) and reordered to CxHxW. Each channel is then
    standardized with the module's mean/std constants, and a leading batch
    dimension is added.
    """
    pixels = np.asarray(image.convert("RGB")) / 255.0  # HxWx3, values in [0, 1]
    pixels = resize(pixels, (224, 224))  # 224x224x3
    tensor = torch.tensor(pixels, dtype=torch.float32).permute(2, 0, 1)  # 3xHxW
    tensor = (tensor - _NORM_MEAN) / _NORM_STD
    return tensor.unsqueeze(0)  # add batch dimension


async def convert_to_png(file: UploadFile) -> Image.Image:
    """Decode an uploaded JPEG, PNG or DICOM file into a PIL image.

    Despite the name, no PNG bytes are produced; the decoded PIL image is
    returned.
    - JPEG/PNG uploads (by content type) are opened as-is, in their original
      mode.
    - DICOM uploads (content type "application/dicom", or a ".dcm"/".dicom"
      filename in any letter case) are converted via dicom_to_png and
      returned as RGB.

    Raises:
        HTTPException: status 400 when the upload is neither an image nor DICOM.
    """
    image_data = await file.read()

    if file.content_type in _IMAGE_CONTENT_TYPES:
        return Image.open(io.BytesIO(image_data))

    # UploadFile.filename may be None; treat that as "no extension".
    filename = (file.filename or "").lower()
    if file.content_type == "application/dicom" or filename.endswith(_DICOM_EXTENSIONS):
        dicom_data = pydicom.dcmread(io.BytesIO(image_data))
        # Shared min-max scaling: also safe for signed and constant-valued pixel data.
        return dicom_to_png(dicom_data)

    raise HTTPException(status_code=400, detail="Unsupported media type")


def dicom_to_png(ds):
    """Render a pydicom dataset's pixel data as an RGB PIL image.

    uint8 pixel data is used unchanged. Any other dtype (e.g. 12/16-bit or
    signed data) is min-max scaled to 0-255 and then cast to uint8. A
    constant-valued image has zero range and becomes all black rather than
    dividing by zero.
    """
    pixel_array = ds.pixel_array

    if pixel_array.dtype != np.uint8:
        pixel_array = pixel_array.astype(np.float32)
        pixel_array -= pixel_array.min()
        peak = pixel_array.max()
        if peak > 0:  # avoid 0/0 -> NaN for constant-valued images
            pixel_array /= peak
            pixel_array *= 255.0
        pixel_array = pixel_array.astype(np.uint8)

    return Image.fromarray(pixel_array).convert("RGB")