Spaces:
Running
on
L4
Running
on
L4
| import base64 | |
| import logging | |
| import os | |
| import sys | |
| from contextlib import nullcontext | |
| import comfy.model_management | |
| import folder_paths | |
| import numpy as np | |
| import torch | |
| import trimesh | |
| from PIL import Image | |
| from trimesh.exchange import gltf | |
| sys.path.append(os.path.dirname(__file__)) | |
| from sf3d.system import SF3D | |
| from sf3d.utils import resize_foreground | |
# Value assigned to the CATEGORY attribute of every node class in this module.
SF3D_CATEGORY = "StableFast3D"
# Identifier passed to SF3D.from_pretrained when the loader node fetches weights.
SF3D_MODEL_NAME = "stabilityai/stable-fast-3d"
class StableFast3DLoader:
    """Node that loads the pretrained Stable Fast 3D model.

    The node takes no inputs and produces a single ``SF3D_MODEL`` output
    named ``sf3d_model``.
    """

    CATEGORY = SF3D_CATEGORY
    FUNCTION = "load"
    RETURN_NAMES = ("sf3d_model",)
    RETURN_TYPES = ("SF3D_MODEL",)

    # Must be a classmethod: the input spec is queried on the class itself,
    # and a plain function would fail with a missing ``cls`` argument.
    @classmethod
    def INPUT_TYPES(cls):
        """Return the input specification; this node has no inputs."""
        return {"required": {}}

    def load(self):
        """Load the SF3D model and prepare it for inference.

        The model is created via ``SF3D.from_pretrained``, moved to the
        device reported by ``comfy.model_management`` and switched to
        evaluation mode.

        Returns:
            A one-element tuple containing the loaded model.
        """
        device = comfy.model_management.get_torch_device()
        model = SF3D.from_pretrained(
            SF3D_MODEL_NAME,
            config_name="config.yaml",
            weight_name="model.safetensors",
        )
        model.to(device)
        model.eval()
        return (model,)
class StableFast3DPreview:
    """Output node that sends meshes to the UI as base64-encoded GLB data."""

    CATEGORY = SF3D_CATEGORY
    FUNCTION = "preview"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    # Must be a classmethod: the input spec is queried on the class itself,
    # and a plain function would fail with a missing argument.
    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification: one required ``MESH`` input."""
        return {"required": {"mesh": ("MESH",)}}

    def preview(self, mesh):
        """Export every mesh to GLB and return the payloads for the UI.

        Args:
            mesh: Iterable of meshes; each one is wrapped in its own
                ``trimesh.Scene`` before export.

        Returns:
            ``{"ui": {"glbs": [...]}}`` where each list entry is the
            base64 text of one exported GLB, in input order.
        """
        glbs = []
        for m in mesh:
            scene = trimesh.Scene(m)
            glb_data = gltf.export_glb(scene, include_normals=True)
            glbs.append(base64.b64encode(glb_data).decode("utf-8"))
        return {"ui": {"glbs": glbs}}
class StableFast3DSampler:
    """Node that runs the SF3D model on a single image and returns a mesh."""

    CATEGORY = SF3D_CATEGORY
    FUNCTION = "predict"
    RETURN_NAMES = ("mesh",)
    RETURN_TYPES = ("MESH",)

    # Must be a classmethod: the input spec is queried on the class itself,
    # and a plain function would fail with a missing argument.
    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification for the sampler node."""
        return {
            "required": {
                "model": ("SF3D_MODEL",),
                "image": ("IMAGE",),
                "foreground_ratio": (
                    "FLOAT",
                    {"default": 0.85, "min": 0.0, "max": 1.0, "step": 0.01},
                ),
                "texture_resolution": (
                    "INT",
                    {"default": 1024, "min": 512, "max": 2048, "step": 256},
                ),
            },
            "optional": {
                "mask": ("MASK",),
                "remesh": (["none", "triangle", "quad"],),
                "vertex_count": (
                    "INT",
                    {"default": -1, "min": -1, "max": 20000, "step": 1},
                ),
            },
        }

    def predict(
        s,
        model,
        image,
        mask=None,
        foreground_ratio=0.85,
        texture_resolution=1024,
        remesh="none",
        vertex_count=-1,
    ):
        """Reconstruct a textured mesh from one image.

        Args:
            model: Loaded SF3D model exposing ``run_image``.
            image: Image tensor indexed as (batch, ?, ?, channels); the
                batch size must be 1 and a channel count of 4 is treated
                as already carrying alpha. Values are scaled by 255, so
                they are presumably in [0, 1] -- confirm against callers.
            mask: Optional mask tensor; ``mask[0]`` is used as the alpha
                channel. It is declared as an optional input, hence the
                ``None`` default (assumed to match the image size --
                TODO confirm).
            foreground_ratio: Passed to ``resize_foreground``.
            texture_resolution: Passed to ``run_image`` as
                ``bake_resolution``.
            remesh: Remeshing option forwarded to ``run_image``.
            vertex_count: Vertex count forwarded to ``run_image``.

        Returns:
            A one-element tuple holding a list with the generated mesh.

        Raises:
            ValueError: If more than one image is supplied, or if the
                resulting mesh has no vertices.
        """
        if image.shape[0] != 1:
            raise ValueError("Only one image can be processed at a time")

        # Convert the float tensor to an 8-bit array PIL can wrap.
        pixels = torch.clamp(torch.round(255.0 * image[0]), 0, 255)
        pil_image = Image.fromarray(pixels.type(torch.uint8).cpu().numpy())

        if mask is not None:
            logging.info("Using Mask")
            mask_np = np.clip(
                255.0 * mask[0].detach().cpu().numpy(), 0, 255
            ).astype(np.uint8)
            pil_image.putalpha(Image.fromarray(mask_np, mode="L"))
        elif image.shape[3] != 4:
            logging.info("No mask or alpha channel detected, Converting to RGBA")
            pil_image = pil_image.convert("RGBA")

        pil_image = resize_foreground(pil_image, foreground_ratio)
        logging.info("Remesh option: %s", remesh)

        # bfloat16 autocast is only enabled when the active device is CUDA.
        device = comfy.model_management.get_torch_device()
        if "cuda" in device.type:
            autocast_ctx = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
        else:
            autocast_ctx = nullcontext()

        with torch.no_grad(), autocast_ctx:
            # The second return value is not used by this node.
            mesh, _ = model.run_image(
                pil_image,
                bake_resolution=texture_resolution,
                remesh=remesh,
                vertex_count=vertex_count,
            )

        if mesh.vertices.shape[0] == 0:
            raise ValueError("No subject detected in the image")
        return ([mesh],)
class StableFast3DSave:
    """Output node that writes meshes to disk as GLB files."""

    CATEGORY = SF3D_CATEGORY
    FUNCTION = "save"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    # Must be a classmethod: the input spec is queried on the class itself,
    # and a plain function would fail with a missing argument.
    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification: a mesh and a filename prefix."""
        return {
            "required": {
                "mesh": ("MESH",),
                "filename_prefix": ("STRING", {"default": "SF3D"}),
            }
        }

    def __init__(self):
        self.type = "output"

    def save(self, mesh, filename_prefix):
        """Export each mesh to a ``.glb`` file in the output directory.

        Args:
            mesh: Iterable of meshes; each one is wrapped in its own
                ``trimesh.Scene`` before export.
            filename_prefix: Prefix handed to
                ``folder_paths.get_save_image_path`` to derive the target
                folder, base filename and counter.

        Returns:
            ``{"ui": {"glbs": [...]}}`` where each list entry is the
            base64 text of one saved GLB, in input order.
        """
        output_dir = folder_paths.get_output_directory()
        glbs = []
        for idx, m in enumerate(mesh):
            scene = trimesh.Scene(m)
            glb_data = gltf.export_glb(scene, include_normals=True)
            logging.info("Generated GLB model with %d bytes", len(glb_data))

            # Resolve the path per mesh so the counter is looked up again
            # after the previous file was written. The caller's prefix is
            # passed unchanged every time instead of being rebound from the
            # helper's return value.
            full_output_folder, filename, counter, _, _ = (
                folder_paths.get_save_image_path(filename_prefix, output_dir)
            )
            filename = filename.replace("%batch_num%", str(idx))
            out_path = os.path.join(
                full_output_folder, f"{filename}_{counter:05}_.glb"
            )
            with open(out_path, "wb") as f:
                f.write(glb_data)
            glbs.append(base64.b64encode(glb_data).decode("utf-8"))
        return {"ui": {"glbs": glbs}}
# Human-readable titles, keyed by the same identifiers as the class registry.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    StableFast3DLoader="Stable Fast 3D Loader",
    StableFast3DPreview="Stable Fast 3D Preview",
    StableFast3DSampler="Stable Fast 3D Sampler",
    StableFast3DSave="Stable Fast 3D Save",
)

# Each node is registered under its own class name.
NODE_CLASS_MAPPINGS = {
    node_cls.__name__: node_cls
    for node_cls in (
        StableFast3DLoader,
        StableFast3DPreview,
        StableFast3DSampler,
        StableFast3DSave,
    )
}

# NOTE(review): presumably the folder holding this extension's front-end
# assets -- confirm against the repository layout.
WEB_DIRECTORY = "./comfyui"

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]