FaceLift / app.py
weijielyu's picture
Update demo
59ae2c2
raw
history blame
3.09 kB
# Copyright (C) 2025, FaceLift Research Group
# https://github.com/weijielyu/FaceLift
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact: [email protected]
"""
FaceLift: Single Image 3D Face Reconstruction
Generates 3D head models from single images using multi-view diffusion and GS-LRM.
"""
import json
from pathlib import Path
from datetime import datetime
import uuid
import time
import shutil
import gradio as gr
import numpy as np
import torch
import yaml
from easydict import EasyDict as edict
from einops import rearrange
from PIL import Image
from huggingface_hub import snapshot_download
import spaces
# Install diff-gaussian-rasterization at runtime (requires GPU)
import subprocess
import sys
import os
# -----------------------------
# Static paths (for viewer files)
# -----------------------------
# Generated viewer assets live under ./outputs (resolved against the working
# directory at import time); per-session folders are created under ./outputs/splats.
OUTPUTS_DIR = Path.cwd().joinpath("outputs")
SPLATS_ROOT = OUTPUTS_DIR.joinpath("splats")
SPLATS_ROOT.mkdir(exist_ok=True, parents=True)
# Let Gradio serve everything below ./outputs from its static file route:
#   /gradio_api/file=outputs/...
gr.set_static_paths(paths=[OUTPUTS_DIR])
# -----------------------------
# Per-session helpers
# -----------------------------
def new_session_id() -> str:
    """Return a short random identifier for one viewer session.

    The identifier is the first 10 lowercase hex characters of a random UUID4.
    """
    full_hex = uuid.uuid4().hex
    return full_hex[:10]
def session_dir(session_id: str) -> Path:
    """Return the splat folder for ``session_id``, creating it if necessary.

    Args:
        session_id: Name of the sub-directory under ``SPLATS_ROOT``.

    Returns:
        ``SPLATS_ROOT / session_id``, which exists when this function returns.
    """
    # NOTE(review): session_id is joined into the path as-is; presumably it
    # always comes from new_session_id() — confirm no user-controlled value
    # (e.g. containing "..") can reach here.
    folder = SPLATS_ROOT.joinpath(session_id)
    folder.mkdir(exist_ok=True, parents=True)
    return folder
def cleanup_old_sessions(max_age_hours: float = 6, root=None):
    """Delete session folders whose modification time is older than the cutoff.

    This is best-effort housekeeping. Filesystem errors on individual entries
    (for example, a folder removed concurrently) are skipped instead of raised.

    Args:
        max_age_hours: Age threshold in hours. Sub-directories whose mtime is
            earlier than ``now - max_age_hours`` are removed recursively.
        root: Directory to scan. Defaults to ``SPLATS_ROOT``. Non-directory
            entries inside it are never touched.
    """
    base = SPLATS_ROOT if root is None else Path(root)
    cutoff = time.time() - max_age_hours * 3600
    try:
        # Materialise the listing here so a missing root is handled in one place
        # (avoids the exists()/iterdir() race of a look-before-you-leap check).
        entries = list(base.iterdir())
    except (FileNotFoundError, NotADirectoryError):
        return
    for child in entries:
        try:
            is_stale = child.is_dir() and child.stat().st_mtime < cutoff
        except OSError:
            # Entry vanished or became unreadable after listing; skip it.
            continue
        if is_stale:
            # ignore_errors: a partially-removed folder is acceptable here.
            shutil.rmtree(child, ignore_errors=True)
def copy_to_session_and_get_url(src_path: str, session_id: str) -> str:
    """
    Copy a splat file (.splat / .ply) into this session's folder and return its URL.

    The copy is named ``<unix-seconds>_<6 hex chars><ext>``, where ``ext`` is the
    lower-cased source suffix, or ``.ply`` when the source has no suffix. The copy
    is made with ``shutil.copy2``, so file metadata is preserved.

    Args:
        src_path: Path of the file to copy.
        session_id: Session sub-folder name (created on demand by ``session_dir``).

    Returns:
        A Gradio static-file URL of the form
        ``/gradio_api/file=outputs/splats/<session_id>/<name>?v=<6 hex chars>``,
        where ``v`` is a fresh random value used as a cache-buster.
    """
    source = Path(src_path)
    suffix = source.suffix
    extension = suffix.lower() if suffix else ".ply"
    stamp = int(time.time())
    unique_name = f"{stamp}_{uuid.uuid4().hex[:6]}{extension}"
    target = session_dir(session_id) / unique_name
    shutil.copy2(source, target)
    cache_buster = uuid.uuid4().hex[:6]
    # /gradio_api/file=outputs/...
    return f"/gradio_api/file=outputs/splats/{session_id}/{unique_name}?v={cache_buster}"
# -----------------------------
# Ensure diff-gaussian-rasterization builds for current GPU
# -----------------------------
try:
import diff_gaussian_rasterization # noqa: F401
except ImportError:
print("Installing diff-gaussian-rasterization (compiling for detected CUDA arch)...")
env = os.environ.copy()
try:
import torch as _torch
if _torch.cuda.is_available():
maj, minr = _torch.cuda.get_device_capability()
arch = f"{maj}.{minr}" # e.g., "9.0" on H100/H200, "8.0" on A100
env["TORCH_CUDA_ARCH_LIST"] = f"{arch}+PTX"
else:
# Build stage may not see a GPU on HF Spaces: compile a cross-arch set
env["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6;8.9;9.0+PTX"
except Excep