Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from PIL import Image,ImageFilter | |
| import cv2 | |
| import numpy as np | |
| import base64 | |
| import spaces | |
| from loadimg import load_img | |
| from io import BytesIO | |
| import numpy as np | |
| import insightface | |
| import onnxruntime as ort | |
| import huggingface_hub | |
| from SegCloth import segment_clothing | |
| from transparent_background import Remover | |
| import uuid | |
| from transformers import AutoModelForImageSegmentation | |
| import torch | |
| from torchvision import transforms | |
# Person detector placeholders: both stay None until load_model() runs, after
# which they refer to the same detector object.
model = detector = None
def load_model():
    """Download the person-detection ONNX model and initialise the detector.

    Fetches ``models/scrfd_person_2.5g.onnx`` from the
    ``public-data/insightface`` Hub repository, opens it in an ONNX Runtime
    session and wraps that session in an insightface ``RetinaFace`` object.
    The object is stored in the module-level ``model`` and ``detector``
    globals (both names refer to the same object); nothing is returned.
    """
    global model, detector

    onnx_path = huggingface_hub.hf_hub_download(
        "public-data/insightface",
        "models/scrfd_person_2.5g.onnx",
    )

    # Thread pools: 8 threads inside a single operator, 8 across operators.
    session_options = ort.SessionOptions()
    session_options.intra_op_num_threads = 8
    session_options.inter_op_num_threads = 8

    # NOTE(review): ONNX Runtime reads the provider list as a priority order,
    # so with CPU listed first the CUDA provider is presumably never chosen --
    # confirm this is intended.
    ort_session = ort.InferenceSession(
        onnx_path,
        sess_options=session_options,
        providers=["CPUExecutionProvider", "CUDAExecutionProvider"],
    )

    model = insightface.model_zoo.retinaface.RetinaFace(
        model_file=onnx_path,
        session=ort_session,
    )
    # The first positional argument is insightface's ctx_id; -1 presumably
    # selects CPU execution -- TODO confirm against the insightface version used.
    model.prepare(-1, nms_thresh=0.5, input_size=(640, 640))
    detector = model
# Segmentation model (BiRefNet) and the preprocessing applied to its inputs.
# This runs at import time: the weights are fetched/loaded and moved to the GPU
# as soon as the module is imported.
torch.set_float32_matmul_precision("high")
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to("cuda")

# Fixed 1024x1024 input, converted to a tensor and normalised per channel.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def refine_edges(image):
    """Replace an image's alpha channel with a softened edge map, then smooth.

    The image is converted to grayscale, Canny edges are detected, thickened
    with a 3x3 dilation and Gaussian-blurred; that map is installed as the
    alpha channel.  Finally PIL's ``SMOOTH_MORE`` filter is applied.

    Args:
        image: PIL image in any mode that can be converted to RGBA.

    Returns:
        A new RGBA PIL image.  The input image is not modified.
    """
    # Work on an RGBA copy: COLOR_RGBA2GRAY requires exactly four channels, and
    # putalpha() below mutates in place, which must not leak to the caller.
    rgba = image.convert("RGBA")
    pixels = np.array(rgba)

    # Grayscale -> Canny edge detection.
    gray = cv2.cvtColor(pixels, cv2.COLOR_RGBA2GRAY)
    edges = cv2.Canny(gray, threshold1=50, threshold2=150)

    # Thicken the edges, then blur them so the mask has soft transitions.
    kernel = np.ones((3, 3), np.uint8)
    thick_edges = cv2.dilate(edges, kernel, iterations=1)
    soft_edges = cv2.GaussianBlur(thick_edges, (5, 5), 0)

    # NOTE(review): this overwrites any existing alpha with the edge map, so
    # only pixels near detected edges stay opaque -- confirm that is intended.
    rgba.putalpha(Image.fromarray(soft_edges).convert("L"))

    # Extra smoothing pass over the whole image.
    return rgba.filter(ImageFilter.SMOOTH_MORE)
def save_image(img):
    """Save *img* under a random UUID4-based ``.png`` name and return the name.

    The file name is relative, so the file is written to the current working
    directory.  *img* only needs a ``save(filename)`` method.
    """
    filename = f"{uuid.uuid4()}.png"
    img.save(filename)
    return filename
def rm_background(image):
    """Return *image* with its background made transparent using BiRefNet.

    Args:
        image: Any input accepted by ``load_img``; it is loaded as a PIL image
            and converted to RGB before inference.

    Returns:
        The loaded image with the predicted foreground mask installed as its
        alpha channel, at the input's original size.
    """
    rgb = load_img(image, output_type="pil").convert("RGB")
    original_size = rgb.size
    image = load_img(rgb)

    # Send the batch to whichever device the model lives on, so this function
    # never disagrees with where the module-level setup placed the weights.
    device = next(birefnet.parameters()).device
    batch = transform_image(image).unsqueeze(0).to(device)

    # Inference only: no gradients needed.  The last output is taken as the
    # prediction and squashed to [0, 1] with a sigmoid.
    with torch.no_grad():
        preds = birefnet(batch)[-1].sigmoid().cpu()

    # The network runs at a fixed resolution; scale the mask back to the
    # original image size before using it as alpha.
    mask = transforms.ToPILImage()(preds[0].squeeze()).resize(original_size)
    image.putalpha(mask)
    return image
def detect_and_segment_persons(image, clothes):
    """Detect people in *image* and segment the requested clothing on each.

    Args:
        image: PIL image; assumed to be 3-channel RGB -- TODO confirm against
            callers (the channel axis is simply reversed for the detector).
        clothes: Clothing class names forwarded to ``segment_clothing``.

    Returns:
        A list of images.  When no person is detected, a single-element list
        holding the background-removed full image.  Otherwise the concatenated
        ``segment_clothing`` results for every person crop with non-zero area.
    """
    bgr = np.array(image)[:, :, ::-1]  # RGB -> BGR

    if detector is None:
        load_model()  # Lazily create the detector on first use.

    bboxes, _ = detector.detect(bgr)
    if bboxes.shape[0] == 0:
        return [rm_background(image)]

    # Round to integer pixel coordinates and clamp them to the image bounds.
    height, width = bgr.shape[:2]
    boxes = np.round(bboxes[:, :4]).astype(int)
    boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, width)
    boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, height)

    segmented = []
    for x1, y1, x2, y2 in boxes:
        # A box that collapses to zero width/height after rounding and clipping
        # yields an empty crop that cannot be processed; skip it instead of
        # failing the whole request.
        if x2 <= x1 or y2 <= y1:
            continue
        crop = bgr[y1:y2, x1:x2]
        person = Image.fromarray(crop[:, :, ::-1])  # BGR -> RGB
        segmented.extend(segment_clothing(rm_background(person), clothes))
    return segmented
def process_image(input_image):
    """Run clothing segmentation on *input_image*, reporting failures as text.

    Returns the list produced by ``detect_and_segment_persons`` for the fixed
    garment classes below.  If any exception is raised, the string
    ``"Error occurred: <exception>"`` is returned instead of propagating it.
    """
    garment_classes = ["Upper-clothes", "Skirt", "Pants", "Dress"]
    try:
        # Edge refinement via refine_edges() is currently disabled.
        return detect_and_segment_persons(input_image, garment_classes)
    except Exception as exc:
        return f"Error occurred: {exc}"
# Gradio Interface
def gradio_interface(image):
    """Gradio callback: segment *image* and return the results for the gallery.

    Args:
        image: The PIL image supplied by the Gradio input component.

    Returns:
        The list of result images from ``process_image``.

    Raises:
        gr.Error: If ``process_image`` reported a failure (it does so by
            returning a non-list value).  The output component is a gallery,
            which cannot display an error string, so the message is raised
            for Gradio to show to the user instead.
    """
    results = process_image(image)
    if not isinstance(results, list):
        raise gr.Error(f"Error: {results}")
    return results
# Build the Gradio app: one PIL image in, a gallery of segmented results out.
_image_input = gr.Image(type="pil")
_gallery_output = gr.Gallery(label="Segmented Results")
interface = gr.Interface(
    fn=gradio_interface,
    inputs=_image_input,
    outputs=_gallery_output,
    title="Clothing Segmentation API",
)

# Start the server at import time, bound to all interfaces on port 7860.
interface.launch(server_name="0.0.0.0", server_port=7860)