# Copyright (C) 2025, FaceLift Research Group
# https://github.com/weijielyu/FaceLift
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact: [email protected]

"""
FaceLift: Single Image 3D Face Reconstruction
Generates 3D head models from single images using multi-view diffusion and GS-LRM.
"""
import json
from pathlib import Path
from datetime import datetime
import uuid
import time
import shutil

import gradio as gr
import numpy as np
import torch
import yaml
from easydict import EasyDict as edict
from einops import rearrange
from PIL import Image
from huggingface_hub import snapshot_download
import spaces

# Install diff-gaussian-rasterization at runtime (requires GPU)
import subprocess
import sys
import os
# Outputs directory for generated files
OUTPUTS_DIR = Path.cwd() / "outputs"
OUTPUTS_DIR.mkdir(exist_ok=True)

# Copy viewer.js to outputs so it can be served as a static file
VIEWER_JS_SRC = Path(__file__).parent / "viewer.js"
if VIEWER_JS_SRC.exists():
    shutil.copy2(VIEWER_JS_SRC, OUTPUTS_DIR / "viewer.js")
    print(f"✓ Copied viewer.js to {OUTPUTS_DIR / 'viewer.js'}")
# -----------------------------
# Ensure diff-gaussian-rasterization builds for current GPU
# -----------------------------
try:
    import diff_gaussian_rasterization  # noqa: F401
except ImportError:
    print("Installing diff-gaussian-rasterization (compiling for detected CUDA arch)...")
    env = os.environ.copy()
    try:
        import torch as _torch
        if _torch.cuda.is_available():
            maj, minr = _torch.cuda.get_device_capability()
            arch = f"{maj}.{minr}"  # e.g., "9.0" on H100/H200, "8.0" on A100
            env["TORCH_CUDA_ARCH_LIST"] = f"{arch}+PTX"
        else:
            # Build stage may not see a GPU on HF Spaces: compile a cross-arch set
            env["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6;8.9;9.0+PTX"
    except Exception:
        env["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6;8.9;9.0+PTX"
    # (Optional) side-step allocator+NVML quirks in restrictive containers
    env.setdefault("PYTORCH_NO_CUDA_MEMORY_CACHING", "1")
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install",
         "git+https://github.com/graphdeco-inria/diff-gaussian-rasterization"],
        env=env,
    )
    import diff_gaussian_rasterization  # noqa: F401
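# These project imports are deferred until after the install above; the GS-LRM
# renderer depends on diff-gaussian-rasterization (assumption based on the
# import ordering in this file).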
from gslrm.model.gaussians_renderer import render_turntable, imageseq2video
from mvdiffusion.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline
from utils_folder.face_utils import preprocess_image, preprocess_image_without_cropping
# HuggingFace repository configuration
HF_REPO_ID = "wlyu/OpenFaceLift"


def download_weights_from_hf() -> Path:
    """Download model weights from HuggingFace if not already present.

    Returns:
        Path to the downloaded repository.
    """
    workspace_dir = Path(__file__).parent

    # Check if weights already exist locally
    mvdiffusion_path = workspace_dir / "checkpoints/mvdiffusion/pipeckpts"
    gslrm_path = workspace_dir / "checkpoints/gslrm/ckpt_0000000000021125.pt"
    if mvdiffusion_path.exists() and gslrm_path.exists():
        print("Using local model weights")
        return workspace_dir

    print(f"Downloading model weights from HuggingFace: {HF_REPO_ID}")
    print("This may take a few minutes on first run...")

    # Disable fast transfer if hf_transfer is not installed
    if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER") == "1":
        try:
            import hf_transfer  # noqa: F401
        except ImportError:
            print("hf_transfer not available, disabling fast download")
            os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"

    # Download to local directory
    snapshot_download(
        repo_id=HF_REPO_ID,
        local_dir=str(workspace_dir / "checkpoints"),
        local_dir_use_symlinks=False,
    )
    print("Model weights downloaded successfully!")
    return workspace_dir
class FaceLiftPipeline:
    """Pipeline for FaceLift 3D head generation from single images."""

    def __init__(self):
        # Download weights from HuggingFace if needed
        workspace_dir = download_weights_from_hf()

        # Setup paths
        self.output_dir = workspace_dir / "outputs"
        self.examples_dir = workspace_dir / "examples"
        self.output_dir.mkdir(exist_ok=True)

        # Parameters
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.image_size = 512
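        # Assumption: this index order maps the camera frames in
        # opencv_cameras.json onto the fixed view order produced by the
        # multi-view diffusion model.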
        self.camera_indices = [2, 1, 0, 5, 4, 3]

        # Load models (keep on CPU for ZeroGPU compatibility)
        print("Loading models...")
        try:
            self.mvdiffusion_pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
                str(workspace_dir / "checkpoints/mvdiffusion/pipeckpts"),
                torch_dtype=torch.float16,
            )
            # Don't move to device or enable xformers here - will be done in GPU-decorated function
            self._models_on_gpu = False

            with open(workspace_dir / "configs/gslrm.yaml", "r") as f:
                config = edict(yaml.safe_load(f))
            # Dynamically import the model class named in the config
            module_name, class_name = config.model.class_name.rsplit(".", 1)
            module = __import__(module_name, fromlist=[class_name])
            ModelClass = getattr(module, class_name)
            self.gs_lrm_model = ModelClass(config)

            checkpoint = torch.load(
                workspace_dir / "checkpoints/gslrm/ckpt_0000000000021125.pt",
                map_location="cpu",
            )
            # Filter out loss_calculator weights (training-only, not needed for inference)
            state_dict = {k: v for k, v in checkpoint["model"].items()
                          if not k.startswith("loss_calculator.")}
            self.gs_lrm_model.load_state_dict(state_dict)
            # Keep on CPU initially - will move to GPU in decorated function

            self.color_prompt_embedding = torch.load(
                workspace_dir / "mvdiffusion/fixed_prompt_embeds_6view/clr_embeds.pt",
                map_location="cpu",
            )

            with open(workspace_dir / "utils_folder/opencv_cameras.json", "r") as f:
                self.cameras_data = json.load(f)["frames"]

            print("Models loaded successfully!")
        except Exception as e:
            print(f"Error loading models: {e}")
            import traceback
            traceback.print_exc()
            raise
    def _move_models_to_gpu(self):
        """Move models to GPU and enable optimizations. Called within @spaces.GPU context."""
        if not self._models_on_gpu and torch.cuda.is_available():
            print("Moving models to GPU...")
            self.device = torch.device("cuda:0")
            self.mvdiffusion_pipeline.to(self.device)
            self.mvdiffusion_pipeline.unet.enable_xformers_memory_efficient_attention()
            self.gs_lrm_model.to(self.device)
            self.gs_lrm_model.eval()  # Set to eval mode
            self.color_prompt_embedding = self.color_prompt_embedding.to(self.device)
            self._models_on_gpu = True
            torch.cuda.empty_cache()  # Clear cache after moving models
            print("Models on GPU, xformers enabled!")
    def generate_3d_head(self, image_path, auto_crop=True, guidance_scale=3.0,
                         random_seed=4, num_steps=50):
        """Generate 3D head from single image."""
        try:
            # Move models to GPU now that we're in the GPU context
            self._move_models_to_gpu()

            # Setup output directory
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_dir = self.output_dir / timestamp
            output_dir.mkdir(exist_ok=True)

            # Preprocess input
            original_img = np.array(Image.open(image_path))
            input_image = preprocess_image(original_img) if auto_crop else \
                preprocess_image_without_cropping(original_img)
            if input_image.size != (self.image_size, self.image_size):
                input_image = input_image.resize((self.image_size, self.image_size))
            input_path = output_dir / "input.png"
            input_image.save(input_path)

            # Generate multi-view images
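            # Seed a generator on the UNet's device so sampling is reproducible
            # on whichever device the pipeline actually runs on.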
            generator = torch.Generator(device=self.mvdiffusion_pipeline.unet.device)
            generator.manual_seed(random_seed)
            result = self.mvdiffusion_pipeline(
                input_image, None,
                prompt_embeds=self.color_prompt_embedding,
                height=self.image_size,
                width=self.image_size,
                guidance_scale=guidance_scale,
                num_images_per_prompt=1,
                num_inference_steps=num_steps,
                generator=generator,
                eta=1.0,
            )
            selected_views = result.images[:6]

            # Save multi-view composite
            multiview_image = Image.new("RGB", (self.image_size * 6, self.image_size))
            for i, view in enumerate(selected_views):
                multiview_image.paste(view, (self.image_size * i, 0))
            multiview_path = output_dir / "multiview.png"
            multiview_image.save(multiview_path)

            # Move diffusion model to CPU to free GPU memory for GS-LRM
            print("Moving diffusion model to CPU to free memory...")
            self.mvdiffusion_pipeline.to("cpu")
            # Delete intermediate variables to free memory
            del result, generator
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            # Prepare 3D reconstruction input
            view_arrays = [np.array(view) for view in selected_views]
            lrm_input = torch.from_numpy(np.stack(view_arrays, axis=0)).float()
            lrm_input = lrm_input[None].to(self.device) / 255.0
            lrm_input = rearrange(lrm_input, "b v h w c -> b v c h w")

            # Prepare camera parameters
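            # Intrinsics are (fx, fy, cx, cy); extrinsics are stored as
            # world-to-camera (OpenCV convention), so invert to camera-to-world.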
            selected_cameras = [self.cameras_data[i] for i in self.camera_indices]
            fxfycxcy_list = [[c["fx"], c["fy"], c["cx"], c["cy"]] for c in selected_cameras]
            c2w_list = [np.linalg.inv(np.array(c["w2c"])) for c in selected_cameras]
            fxfycxcy = torch.from_numpy(np.stack(fxfycxcy_list, axis=0).astype(np.float32))
            c2w = torch.from_numpy(np.stack(c2w_list, axis=0).astype(np.float32))
            fxfycxcy = fxfycxcy[None].to(self.device)
            c2w = c2w[None].to(self.device)
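            # (1, V, 2) tensor of (batch_idx, view_idx) pairs, one row per input
            # view, in the index layout GS-LRM expects.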
            batch_indices = torch.stack([
                torch.zeros(lrm_input.size(1)).long(),
                torch.arange(lrm_input.size(1)).long(),
            ], dim=-1)[None].to(self.device)
            batch = edict({
                "image": lrm_input,
                "c2w": c2w,
                "fxfycxcy": fxfycxcy,
                "index": batch_indices,
            })

            # Ensure GS-LRM model is on GPU
            if next(self.gs_lrm_model.parameters()).device.type == "cpu":
                print("Moving GS-LRM model to GPU...")
                self.gs_lrm_model.to(self.device)
                torch.cuda.empty_cache()

            # Final memory cleanup before reconstruction
            torch.cuda.empty_cache()

            # Run 3D reconstruction
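            # fp16 autocast keeps reconstruction within a single-GPU memory
            # budget; inference-only, so no_grad is safe here.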
            with torch.no_grad(), torch.autocast(enabled=True, device_type="cuda", dtype=torch.float16):
                result = self.gs_lrm_model.forward(batch, create_visual=False, split_data=True)
                comp_image = result.render[0].unsqueeze(0).detach()
                gaussians = result.gaussians[0]

            # Clear CUDA cache after reconstruction
            torch.cuda.empty_cache()

            # Save filtered gaussians
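            # The thresholds below prune near-transparent splats, oversized
            # splats, and floaters, then crop to a fixed head-centered bounding
            # box (values as shipped with the demo).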
            filtered_gaussians = gaussians.apply_all_filters(
                cam_origins=None,
                opacity_thres=0.04,
                scaling_thres=0.2,
                floater_thres=0.75,
                crop_bbx=[-0.91, 0.91, -0.91, 0.91, -1.0, 1.0],
                nearfar_percent=(0.0001, 1.0),
            )
            ply_path = output_dir / "gaussians.ply"
            filtered_gaussians.save_ply(str(ply_path))

            # Save output image
            comp_image = rearrange(comp_image, "x v c h w -> (x h) (v w) c")
            comp_image = (comp_image.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
            output_path = output_dir / "output.png"
            Image.fromarray(comp_image).save(output_path)

            # Generate turntable video
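            # 180 frames rendered at 30 fps gives a six-second 360-degree orbit.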
            turntable_resolution = 512
            num_turntable_views = 180
            turntable_frames = render_turntable(gaussians, rendering_resolution=turntable_resolution,
                                                num_views=num_turntable_views)
            turntable_frames = rearrange(turntable_frames, "h (v w) c -> v h w c", v=num_turntable_views)
            turntable_frames = np.ascontiguousarray(turntable_frames)
            turntable_path = output_dir / "turntable.mp4"
            imageseq2video(turntable_frames, str(turntable_path), fps=30)

            # Final CUDA cache clear
            torch.cuda.empty_cache()

            return str(input_path), str(multiview_path), str(output_path), \
                str(turntable_path), str(ply_path)
        except Exception as e:
            import traceback
            error_details = traceback.format_exc()
            print(f"Error details:\n{error_details}")
            raise gr.Error(f"Generation failed: {str(e)}")
# Create viewer HTML file that can be served by Gradio
def create_splat_viewer_html(ply_url: str, viewer_js_url: str, viewer_id: str) -> str:
    """Create a standalone HTML file with an embedded viewer for the PLY file.

    Note: viewer_js_url is accepted for interface compatibility, but viewer.js
    is inlined into the generated HTML rather than loaded from a URL.
    """
    # Read viewer.js content
    viewer_js_path = Path(__file__).parent / "viewer.js"
    viewer_js_content = viewer_js_path.read_text() if viewer_js_path.exists() else "console.error('viewer.js not found');"

    # Create HTML file in the outputs directory (strip Gradio's /file= prefix
    # from the URL to recover the on-disk path)
    output_dir = Path(ply_url.replace("/file=", "").rsplit("/", 1)[0])
    viewer_html_path = output_dir / f"viewer_{viewer_id}.html"
    html_content = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>3D Gaussian Splat Viewer</title>
<style>
    body {{ margin: 0; overflow: hidden; background: #000; }}
    #canvas {{ width: 100vw; height: 100vh; display: block; }}
    #spinner {{
        position: absolute; top: 50%; left: 50%;
        transform: translate(-50%, -50%);
        color: white; font-family: Arial; z-index: 10;
        text-align: center; background: rgba(0,0,0,0.8);
        padding: 20px; border-radius: 8px;
    }}
    #progress {{ background: #4CAF50; height: 4px; width: 0%; transition: width 0.3s; }}
    #message {{
        position: absolute; top: 50%; left: 50%;
        transform: translate(-50%, -50%);
        color: #ff4444; font-family: Arial; font-size: 14px;
        background: rgba(0,0,0,0.9); padding: 20px;
        border-radius: 8px; display: none; z-index: 11;
    }}
    #fps, #camid {{
        position: absolute; right: 10px;
        color: white; font-family: monospace; font-size: 11px;
        background: rgba(0,0,0,0.7); padding: 6px 10px;
        border-radius: 4px; display: none;
    }}
    #fps {{ top: 10px; }}
    #camid {{ top: 40px; }}
    #controls-info {{
        position: absolute; bottom: 10px; left: 10px;
        color: white; font-family: Arial; font-size: 11px;
        background: rgba(0,0,0,0.7); padding: 8px 12px;
        border-radius: 4px;
    }}
</style>
</head>
<body>
<canvas id="canvas"></canvas>
<div id="spinner">
    <div style="font-size:14px; margin-bottom:10px;">Loading 3D Viewer...</div>
    <div style="background:#333; height:4px; width:200px; border-radius:2px; overflow:hidden;">
        <div id="progress"></div>
    </div>
</div>
<div id="message"></div>
<div id="fps"></div>
<div id="camid"></div>
<div id="controls-info">
    <strong>Controls:</strong> Drag: Rotate | Scroll: Zoom | Right-drag: Pan
</div>
<script>
{viewer_js_content}
</script>
<script>
// Auto-load PLY after viewer initializes
const plyUrl = "{ply_url}";
console.log("=== Splat Viewer Init ===");
console.log("PLY URL:", plyUrl);
let attempts = 0;
const checkAndLoad = setInterval(function() {{
    attempts++;
    if (window.worker) {{
        console.log("✓ Worker ready after", attempts * 100, "ms");
        clearInterval(checkAndLoad);
        fetch(plyUrl)
            .then(r => {{ if (!r.ok) throw new Error("HTTP " + r.status); return r.arrayBuffer(); }})
            .then(buffer => {{
                console.log("✓ PLY loaded:", buffer.byteLength, "bytes");
                // fetch already yields an ArrayBuffer, so post it straight to the worker
                window.worker.postMessage({{ ply: buffer }});
                console.log("✓ Sent to worker");
            }})
            .catch(err => {{
                console.error("✗ Error:", err);
                document.getElementById("spinner").style.display = "none";
                const msg = document.getElementById("message");
                msg.textContent = "Error: " + err.message;
                msg.style.display = "block";
            }});
    }} else if (attempts >= 50) {{
        console.error("✗ Worker timeout");
        clearInterval(checkAndLoad);
        document.getElementById("spinner").style.display = "none";
        const msg = document.getElementById("message");
        msg.textContent = "Viewer failed to initialize.";
        msg.style.display = "block";
    }}
}}, 100);
</script>
</body>
</html>"""
    # Write HTML file
    viewer_html_path.write_text(html_content)

    # Return iframe that loads this HTML file
    viewer_html_url = f"/file={viewer_html_path}"
    return f"""
<iframe id="viewer-frame-{viewer_id}" src="{viewer_html_url}" style="width:100%; height:600px; border:1px solid #333; border-radius:8px;" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>
<p style="font-size:11px; color:#666; margin-top:5px;">
    <a href="{viewer_html_url}" target="_blank" style="color:#4CAF50;">Open in new tab</a> for better performance
</p>
"""
def main():
    """Run the FaceLift application."""
    pipeline = FaceLiftPipeline()

    # Prepare example inputs from the examples directory
    examples = []
    if pipeline.examples_dir.exists():
        examples = [[str(f), True, 3.0, 4, 50] for f in sorted(pipeline.examples_dir.iterdir())
                    if f.suffix.lower() in {'.png', '.jpg', '.jpeg'}]
    with gr.Blocks(title="FaceLift: Single Image 3D Face Reconstruction") as demo:
        # Wrapper to return outputs for display
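        # ZeroGPU: run this callback on a GPU worker. generate_3d_head assumes
        # it executes inside a @spaces.GPU context; the 120 s duration is an
        # assumption to cover diffusion sampling plus reconstruction.
        @spaces.GPU(duration=120)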
        def _generate_and_filter_outputs(image_path, auto_crop, guidance_scale, random_seed, num_steps):
            input_path, multiview_path, output_path, turntable_path, ply_path = \
                pipeline.generate_3d_head(image_path, auto_crop, guidance_scale, random_seed, num_steps)
            # Create Gradio-accessible URL for the PLY file
            ply_url = f"/file={ply_path}"
            # viewer.js is in the outputs directory
            viewer_js_url = f"/file={OUTPUTS_DIR}/viewer.js"
            # Generate unique viewer ID
            viewer_id = str(uuid.uuid4())[:8]
            viewer_html = create_splat_viewer_html(ply_url, viewer_js_url, viewer_id)
            # Debug info showing the paths
            debug_info = (f"PLY Path: {ply_path}\nPLY URL: {ply_url}\n"
                          f"Viewer JS URL: {viewer_js_url}\n"
                          f"File exists: {Path(ply_path).exists()}\nViewer ID: {viewer_id}")
            return viewer_html, output_path, turntable_path, ply_path, debug_info
| gr.Markdown("## FaceLift: Single Image 3D Face Reconstruction.") | |
| gr.Markdown(""" | |
| ### ๐ก Tips for Best Results | |
| - Works best with near-frontal portrait images | |
| - The provided checkpoints were not trained with accessories (glasses, hats, etc.). Portraits containing accessories may produce suboptimal results. | |
| - If face detection fails, try disabling auto-cropping and manually crop to square | |
| """) | |
        with gr.Row():
            with gr.Column(scale=1):
                in_image = gr.Image(type="filepath", label="Input Portrait Image")
                auto_crop = gr.Checkbox(value=True, label="Auto Cropping")
                guidance = gr.Slider(1.0, 10.0, 3.0, step=0.1, label="Guidance Scale")
                seed = gr.Number(value=4, label="Random Seed")
                steps = gr.Slider(10, 100, 50, step=5, label="Generation Steps")
                run_btn = gr.Button("Generate 3D Head", variant="primary")
                # Examples (match input signature)
                if examples:
                    gr.Examples(
                        examples=examples,
                        inputs=[in_image, auto_crop, guidance, seed, steps],
                        examples_per_page=10,
                    )
            with gr.Column(scale=1):
                out_viewer = gr.HTML(label="🎮 Interactive 3D Viewer")
                out_debug = gr.Textbox(label="Debug Info", lines=3, visible=True)
                out_recon = gr.Image(label="3D Reconstruction Views")
                out_video = gr.PlayableVideo(label="Turntable Animation (360° View)", height=400)
                out_ply = gr.File(label="Download 3D Model (.ply)")
                gr.Markdown("""
**💡 Controls:** Drag to rotate | Scroll to zoom | Right-drag to pan

*Interactive viewer powered by [antimatter15/splat](https://github.com/antimatter15/splat)*
""")
        # Run generation and display all outputs
        run_btn.click(
            fn=_generate_and_filter_outputs,
            inputs=[in_image, auto_crop, guidance, seed, steps],
            outputs=[out_viewer, out_recon, out_video, out_ply, out_debug],
        )

    demo.queue(max_size=10)
    # allowed_paths lets the /file= URLs under outputs/ be served (assumption:
    # the directory is not already permitted by the Gradio version in use)
    demo.launch(share=True, server_name="0.0.0.0", server_port=7860, show_error=True,
                allowed_paths=[str(OUTPUTS_DIR)])
if __name__ == "__main__":
    main()