import torch
from torch import nn
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Module-level processor and model (populated once at import time, below)
image_processor = None
model = None

# convenience expression for automatically determining device
device = (
    "cuda"
    # Device for NVIDIA or AMD GPUs
    if torch.cuda.is_available()
    else "mps"
    # Device for Apple Silicon (Metal Performance Shaders)
    if torch.backends.mps.is_available()
    else "cpu"
)

# Load the processor and model once so they are reused across calls
image_processor = SegformerImageProcessor.from_pretrained("jonathandinu/face-parsing")
model = SegformerForSemanticSegmentation.from_pretrained("jonathandinu/face-parsing")
model.to(device)
model.eval()

def get_face_mask(image):
    try:
        # run inference on the input image (no gradients needed)
        inputs = image_processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits  # shape (batch_size, num_labels, ~height/4, ~width/4)

        # resize output to match input image dimensions
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.size[::-1],  # PIL size is (W, H); interpolate expects (H, W)
            mode="bilinear",
            align_corners=False,
        )

        # per-pixel class labels for the first (only) image in the batch
        labels = upsampled_logits.argmax(dim=1)[0]

        # move to CPU and convert to NumPy for colormapping
        labels_viz = labels.cpu().numpy()

        # Debug: print label statistics
        print(f"Labels min: {labels_viz.min()}, max: {labels_viz.max()}, unique: {np.unique(labels_viz)}")

        # Map labels to colors with a matplotlib colormap.
        # The face-parsing model has 19 classes (0-18), so normalize by 18
        # to put the labels in the 0-1 range the colormap expects.
        color_map = plt.get_cmap("tab20")
        normalized_labels = labels_viz.astype(np.float32) / 18.0
        colors = color_map(normalized_labels)

        # Convert to a PIL image, keeping only the RGB channels (drop alpha)
        colors_rgb = colors[:, :, :3]
        colors_pil = Image.fromarray((colors_rgb * 255).astype(np.uint8))
        return colors_pil
    except Exception as e:
        print(f"Error in face parsing: {e}")
        return f"Error: {str(e)}"