import torch from safetensors.torch import load_file from huggingface_hub import hf_hub_download import gradio as gr from PIL import Image, ImageOps import numpy as np from kornia.color import rgb_to_lab, lab_to_rgb REPO_ID = "ayushshah/imagecolorization" WEIGHTS_FILE = "model.safetensors" ARCHITECTURE_FILE = "model.py" # Download architecture file hf_hub_download( repo_id=REPO_ID, filename=ARCHITECTURE_FILE, local_dir=".", local_dir_use_symlinks=False ) # Downloading the weights weights_path = hf_hub_download( repo_id=REPO_ID, filename=WEIGHTS_FILE ) # Initialize the model from model import UNet model = UNet() state_dict = load_file(weights_path) model.load_state_dict(state_dict) model.eval() # Center crop and resize to 224x224 def prepare_input(image): if image is None: raise gr.Error("Please upload an image.") pil_image = Image.fromarray(image) side = min(pil_image.size) square = ImageOps.fit( pil_image, (side, side), centering=(0.5, 0.5), ) resized = square.resize((224, 224), Image.Resampling.BICUBIC) return np.array(resized) # Colorize the image def colorize(image): image = image / 255.0 img_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float() lab_tensor = rgb_to_lab(img_tensor) L = lab_tensor[:, 0:1, :, :] L_normalized = (L / 100.0) with torch.no_grad(): ab_pred = model(L_normalized) ab_pred = (ab_pred+1)*255.0/2-128.0 combined_lab = torch.cat([L, ab_pred], dim=1) colorized_rgb = lab_to_rgb(combined_lab) return colorized_rgb.squeeze().permute(1, 2, 0).numpy() def clear_images(): return None, None # Gradio interface with gr.Blocks(title="Image Colorization") as demo: gr.HTML("