Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # Modified by Bowen Cheng from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py | |
| import argparse | |
| import glob | |
| import multiprocessing as mp | |
| import os | |
| import sys | |
| sys.path.insert(1, os.getcwd()) | |
| import tempfile | |
| import time | |
| import warnings | |
| import cv2 | |
| import numpy as np | |
| import tqdm | |
| import torch | |
| from detectron2.config import get_cfg | |
| from detectron2.data.detection_utils import read_image | |
| from detectron2.projects.deeplab import add_deeplab_config | |
| from detectron2.utils.logger import setup_logger | |
| from mask2former import add_maskformer2_config | |
| from predictor import VisualizationDemo | |
| from annotator.util import annotator_ckpts_path | |
# URL of the CropFormer_hornet_3x.pth checkpoint hosted on Hugging Face.
# Referenced by the (currently disabled) download code in
# EntitysegDetector.__init__.
model_url = (
    "https://huggingface.co/datasets/qqlu1992/Adobe_EntitySeg/resolve/main/"
    "CropFormer_model/Entity_Segmentation/CropFormer_hornet_3x.pth"
)
def make_colors():
    """Return the "color" entry of every detectron2 COCO category, in order.

    Returns:
        A list with one color per item of ``COCO_CATEGORIES``.
    """
    # Function-scope import, as in the original code.
    from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES

    return [category["color"] for category in COCO_CATEGORIES]
class EntitysegDetector:
    """Run CropFormer entity segmentation and render the result as a color mask.

    The detector wraps a ``VisualizationDemo`` built from the
    ``cropformer_hornet_3x.yaml`` config and paints every predicted entity
    whose score reaches ``confidence_threshold`` with a palette color.
    """

    def __init__(self, confidence_threshold=0.5):
        """Build the detectron2 config and the inference wrapper.

        Args:
            confidence_threshold: Minimum prediction score an instance must
                reach to be drawn by ``__call__``.
        """
        cfg = get_cfg()
        add_deeplab_config(cfg)
        add_maskformer2_config(cfg)

        # The config is looked up relative to the current working directory,
        # the weights inside the shared annotator checkpoint directory.
        workdir = os.getcwd()
        config_file = f"{workdir}/annotator/entityseg/configs/cropformer_hornet_3x.yaml"
        model_path = f'{annotator_ckpts_path}/CropFormer_hornet_3x_03823a.pth'

        # Authentication required, so the automatic download stays disabled;
        # the checkpoint has to be present at ``model_path`` already.
        # if not os.path.exists(model_path):
        #     from basicsr.utils.download_util import load_file_from_url
        #     load_file_from_url(model_url, model_dir=annotator_ckpts_path)

        cfg.merge_from_file(config_file)
        cfg.merge_from_list(['MODEL.WEIGHTS', model_path])
        cfg.freeze()

        self.confidence_threshold = confidence_threshold
        self.colors = make_colors()
        self.demo = VisualizationDemo(cfg)

    def __call__(self, image):
        """Segment ``image`` and return a color-coded entity mask.

        Args:
            image: Input image array passed to ``self.demo.run_on_image``.
                Assumed to be H x W x C with the same spatial size as the
                predicted masks -- TODO confirm against the caller.

        Returns:
            A ``uint8`` array with the shape of ``image``. Pixels covered by
            a kept entity hold that entity's palette color; all other pixels
            are zero. Where masks overlap, the higher-scoring entity wins.
        """
        predictions = self.demo.run_on_image(image)
        instances = predictions["instances"]
        all_masks = instances.pred_masks
        all_scores = instances.scores

        # Keep only the instances that reach the confidence threshold.
        keep = all_scores >= self.confidence_threshold
        kept_scores = all_scores[keep]
        kept_masks = all_masks[keep]

        _, height, width = kept_masks.shape
        # int32 rather than uint8: with more than 255 kept entities the
        # 1-based ids would otherwise wrap around (id 256 -> 0, i.e.
        # background) or raise OverflowError on recent NumPy versions.
        mask_id = np.zeros((height, width), dtype=np.int32)

        # Paint in ascending score order so that higher-scoring entities are
        # written last and therefore win on overlapping pixels.
        _, order = torch.sort(kept_scores)
        for position in order.tolist():
            covered = (kept_masks[position] == 1).cpu().numpy()
            mask_id[covered] = position + 1  # 0 is reserved for background

        color_mask = np.zeros(image.shape, dtype=np.uint8)
        for entity_id in np.unique(mask_id):
            if entity_id == 0:
                continue
            # Cycle through the palette when there are more ids than colors.
            color_mask[mask_id == entity_id] = self.colors[entity_id % len(self.colors)]
        return color_mask