weijielyu committed on
Commit
f9b802b
·
1 Parent(s): 6272050
Files changed (2)
  1. app.py +5 -3
  2. app_old.py +398 -0
app.py CHANGED
@@ -351,7 +351,7 @@ def main():
             input_path, multiview_path, output_path, turntable_path, ply_path = \
                 pipeline.generate_3d_head(image_path, auto_crop, guidance_scale, random_seed, num_steps)
 
-            return output_path, turntable_path, ply_path
+            return output_path, turntable_path, str(ply_path), ply_path
 
         gr.Markdown("## FaceLift: Single Image 3D Face Reconstruction.")
 
@@ -382,13 +382,15 @@ def main():
             with gr.Column(scale=1):
                 out_recon = gr.Image(label="3D Reconstruction Views")
                 out_video = gr.PlayableVideo(label="Turntable Animation (360° View)", height=600)
-                out_ply = gr.File(label="Download 3D Model (.ply)")
+                # Interactive 3D viewer for the generated Gaussian PLY (uses three.js under the hood)
+                out_viewer = gr.Model3D(label="Interactive 3D Viewer (.ply)", height=600)
+                out_ply = gr.File(label="Download 3D Model (.ply)")
 
         # Run generation and display all outputs
         run_btn.click(
             fn=_generate_and_filter_outputs,
             inputs=[in_image, auto_crop, guidance, seed, steps],
-            outputs=[out_recon, out_video, out_ply],
+            outputs=[out_recon, out_video, out_viewer, out_ply],
         )
 
     demo.queue(max_size=10)
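For reference, a minimal, self-contained sketch of the output wiring this change introduces, assuming a Gradio 4.x environment where gr.Model3D can display .ply files (which is why the handler now returns str(ply_path) for the viewer alongside the downloadable path). The stub handler and file paths below are placeholders, not code from the repository:

import gradio as gr

def _stub_generate(image_path):
    # Placeholder for pipeline.generate_3d_head(...); returns the same four values
    # the updated _generate_and_filter_outputs yields: composite image, turntable
    # video, a path for the interactive viewer, and the downloadable .ply.
    output_path, turntable_path, ply_path = "output.png", "turntable.mp4", "gaussians.ply"
    return output_path, turntable_path, str(ply_path), ply_path

with gr.Blocks() as demo:
    in_image = gr.Image(type="filepath", label="Input Portrait Image")
    run_btn = gr.Button("Generate 3D Head", variant="primary")
    out_recon = gr.Image(label="3D Reconstruction Views")
    out_video = gr.PlayableVideo(label="Turntable Animation (360° View)", height=600)
    out_viewer = gr.Model3D(label="Interactive 3D Viewer (.ply)", height=600)  # added by this commit
    out_ply = gr.File(label="Download 3D Model (.ply)")
    run_btn.click(
        fn=_stub_generate,
        inputs=[in_image],  # the real app also passes auto_crop, guidance, seed, steps
        outputs=[out_recon, out_video, out_viewer, out_ply],
    )

demo.launch()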
app_old.py ADDED
@@ -0,0 +1,398 @@
+# Copyright (C) 2025, FaceLift Research Group
+# https://github.com/weijielyu/FaceLift
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact: [email protected]
+
+"""
+FaceLift: Single Image 3D Face Reconstruction
+Generates 3D head models from single images using multi-view diffusion and GS-LRM.
+"""
+
+# Disable HF fast transfer if hf_transfer is not installed
+# This MUST be done before importing huggingface_hub
+import os
+if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER") == "1":
+    try:
+        import hf_transfer
+    except ImportError:
+        print("⚠️ hf_transfer not available, disabling fast download")
+        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
+
+import json
+from pathlib import Path
+from datetime import datetime
+import uuid
+import time
+import shutil
+
+import gradio as gr
+import numpy as np
+import torch
+import yaml
+from easydict import EasyDict as edict
+from einops import rearrange
+from PIL import Image
+from huggingface_hub import snapshot_download
+import spaces
+
+# Install diff-gaussian-rasterization at runtime (requires GPU)
+import subprocess
+import sys
+
+# Outputs directory for generated files
+OUTPUTS_DIR = Path.cwd() / "outputs"
+OUTPUTS_DIR.mkdir(exist_ok=True)
+
+# -----------------------------
+# Ensure diff-gaussian-rasterization builds for current GPU
+# -----------------------------
+try:
+    import diff_gaussian_rasterization  # noqa: F401
+except ImportError:
+    print("Installing diff-gaussian-rasterization (compiling for detected CUDA arch)...")
+    env = os.environ.copy()
+    try:
+        import torch as _torch
+        if _torch.cuda.is_available():
+            maj, minr = _torch.cuda.get_device_capability()
+            arch = f"{maj}.{minr}"  # e.g., "9.0" on H100/H200, "8.0" on A100
+            env["TORCH_CUDA_ARCH_LIST"] = f"{arch}+PTX"
+        else:
+            # Build stage may not see a GPU on HF Spaces: compile a cross-arch set
+            env["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6;8.9;9.0+PTX"
+    except Exception:
+        env["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6;8.9;9.0+PTX"
+
+    # (Optional) side-step allocator+NVML quirks in restrictive containers
+    env.setdefault("PYTORCH_NO_CUDA_MEMORY_CACHING", "1")
+
+    subprocess.check_call(
+        [sys.executable, "-m", "pip", "install",
+         "git+https://github.com/graphdeco-inria/diff-gaussian-rasterization"],
+        env=env,
+    )
+    import diff_gaussian_rasterization  # noqa: F401
+
+
+from gslrm.model.gaussians_renderer import render_turntable, imageseq2video
+from mvdiffusion.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline
+from utils_folder.face_utils import preprocess_image, preprocess_image_without_cropping
+
+# HuggingFace repository configuration
+HF_REPO_ID = "wlyu/OpenFaceLift"
+
+def download_weights_from_hf() -> Path:
+    """Download model weights from HuggingFace if not already present.
+
+    Returns:
+        Path to the downloaded repository
+    """
+    workspace_dir = Path(__file__).parent
+
+    # Check if weights already exist locally
+    mvdiffusion_path = workspace_dir / "checkpoints/mvdiffusion/pipeckpts"
+    gslrm_path = workspace_dir / "checkpoints/gslrm/ckpt_0000000000021125.pt"
+
+    if mvdiffusion_path.exists() and gslrm_path.exists():
+        print("Using local model weights")
+        return workspace_dir
+
+    print(f"Downloading model weights from HuggingFace: {HF_REPO_ID}")
+    print("This may take a few minutes on first run...")
+
+    # Download to local directory
+    snapshot_download(
+        repo_id=HF_REPO_ID,
+        local_dir=str(workspace_dir / "checkpoints"),
+        local_dir_use_symlinks=False,
+    )
+
+    print("Model weights downloaded successfully!")
+    return workspace_dir
+
+class FaceLiftPipeline:
+    """Pipeline for FaceLift 3D head generation from single images."""
+
+    def __init__(self):
+        # Download weights from HuggingFace if needed
+        workspace_dir = download_weights_from_hf()
+
+        # Setup paths
+        self.output_dir = workspace_dir / "outputs"
+        self.examples_dir = workspace_dir / "examples"
+        self.output_dir.mkdir(exist_ok=True)
+
+        # Parameters
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.image_size = 512
+        self.camera_indices = [2, 1, 0, 5, 4, 3]
+
+        # Load models (keep on CPU for ZeroGPU compatibility)
+        print("Loading models...")
+        try:
+            self.mvdiffusion_pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
+                str(workspace_dir / "checkpoints/mvdiffusion/pipeckpts"),
+                torch_dtype=torch.float16,
+            )
+            # Don't move to device or enable xformers here - will be done in GPU-decorated function
+            self._models_on_gpu = False
+
+            with open(workspace_dir / "configs/gslrm.yaml", "r") as f:
+                config = edict(yaml.safe_load(f))
+
+            module_name, class_name = config.model.class_name.rsplit(".", 1)
+            module = __import__(module_name, fromlist=[class_name])
+            ModelClass = getattr(module, class_name)
+
+            self.gs_lrm_model = ModelClass(config)
+            checkpoint = torch.load(
+                workspace_dir / "checkpoints/gslrm/ckpt_0000000000021125.pt",
+                map_location="cpu"
+            )
+            # Filter out loss_calculator weights (training-only, not needed for inference)
+            state_dict = {k: v for k, v in checkpoint["model"].items()
+                          if not k.startswith("loss_calculator.")}
+            self.gs_lrm_model.load_state_dict(state_dict)
+            # Keep on CPU initially - will move to GPU in decorated function
+
+            self.color_prompt_embedding = torch.load(
+                workspace_dir / "mvdiffusion/fixed_prompt_embeds_6view/clr_embeds.pt",
+                map_location="cpu"
+            )
+
+            with open(workspace_dir / "utils_folder/opencv_cameras.json", 'r') as f:
+                self.cameras_data = json.load(f)["frames"]
+
+            print("Models loaded successfully!")
+        except Exception as e:
+            print(f"Error loading models: {e}")
+            import traceback
+            traceback.print_exc()
+            raise
+
+    def _move_models_to_gpu(self):
+        """Move models to GPU and enable optimizations. Called within @spaces.GPU context."""
+        if not self._models_on_gpu and torch.cuda.is_available():
+            print("Moving models to GPU...")
+            self.device = torch.device("cuda:0")
+            self.mvdiffusion_pipeline.to(self.device)
+            self.mvdiffusion_pipeline.unet.enable_xformers_memory_efficient_attention()
+            self.gs_lrm_model.to(self.device)
+            self.gs_lrm_model.eval()  # Set to eval mode
+            self.color_prompt_embedding = self.color_prompt_embedding.to(self.device)
+            self._models_on_gpu = True
+            torch.cuda.empty_cache()  # Clear cache after moving models
+            print("Models on GPU, xformers enabled!")
+
+    @spaces.GPU(duration=120)
+    def generate_3d_head(self, image_path, auto_crop=True, guidance_scale=3.0,
+                         random_seed=4, num_steps=50):
+        """Generate 3D head from single image."""
+        try:
+            # Move models to GPU now that we're in the GPU context
+            self._move_models_to_gpu()
+            # Setup output directory
+            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+            output_dir = self.output_dir / timestamp
+            output_dir.mkdir(exist_ok=True)
+
+            # Preprocess input
+            original_img = np.array(Image.open(image_path))
+            input_image = preprocess_image(original_img) if auto_crop else \
+                preprocess_image_without_cropping(original_img)
+
+            if input_image.size != (self.image_size, self.image_size):
+                input_image = input_image.resize((self.image_size, self.image_size))
+
+            input_path = output_dir / "input.png"
+            input_image.save(input_path)
+
+            # Generate multi-view images
+            generator = torch.Generator(device=self.mvdiffusion_pipeline.unet.device)
+            generator.manual_seed(random_seed)
+
+            result = self.mvdiffusion_pipeline(
+                input_image, None,
+                prompt_embeds=self.color_prompt_embedding,
+                height=self.image_size,
+                width=self.image_size,
+                guidance_scale=guidance_scale,
+                num_images_per_prompt=1,
+                num_inference_steps=num_steps,
+                generator=generator,
+                eta=1.0,
+            )
+
+            selected_views = result.images[:6]
+
+            # Save multi-view composite
+            multiview_image = Image.new("RGB", (self.image_size * 6, self.image_size))
+            for i, view in enumerate(selected_views):
+                multiview_image.paste(view, (self.image_size * i, 0))
+
+            multiview_path = output_dir / "multiview.png"
+            multiview_image.save(multiview_path)
+
+            # Move diffusion model to CPU to free GPU memory for GS-LRM
+            print("Moving diffusion model to CPU to free memory...")
+            self.mvdiffusion_pipeline.to("cpu")
+
+            # Delete intermediate variables to free memory
+            del result, generator
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+
+            # Prepare 3D reconstruction input
+            view_arrays = [np.array(view) for view in selected_views]
+            lrm_input = torch.from_numpy(np.stack(view_arrays, axis=0)).float()
+            lrm_input = lrm_input[None].to(self.device) / 255.0
+            lrm_input = rearrange(lrm_input, "b v h w c -> b v c h w")
+
+            # Prepare camera parameters
+            selected_cameras = [self.cameras_data[i] for i in self.camera_indices]
+            fxfycxcy_list = [[c["fx"], c["fy"], c["cx"], c["cy"]] for c in selected_cameras]
+            c2w_list = [np.linalg.inv(np.array(c["w2c"])) for c in selected_cameras]
+
+            fxfycxcy = torch.from_numpy(np.stack(fxfycxcy_list, axis=0).astype(np.float32))
+            c2w = torch.from_numpy(np.stack(c2w_list, axis=0).astype(np.float32))
+            fxfycxcy = fxfycxcy[None].to(self.device)
+            c2w = c2w[None].to(self.device)
+
+            batch_indices = torch.stack([
+                torch.zeros(lrm_input.size(1)).long(),
+                torch.arange(lrm_input.size(1)).long(),
+            ], dim=-1)[None].to(self.device)
+
+            batch = edict({
+                "image": lrm_input,
+                "c2w": c2w,
+                "fxfycxcy": fxfycxcy,
+                "index": batch_indices,
+            })
+
+            # Ensure GS-LRM model is on GPU
+            if next(self.gs_lrm_model.parameters()).device.type == "cpu":
+                print("Moving GS-LRM model to GPU...")
+                self.gs_lrm_model.to(self.device)
+                torch.cuda.empty_cache()
+
+            # Final memory cleanup before reconstruction
+            torch.cuda.empty_cache()
+
+            # Run 3D reconstruction
+            with torch.no_grad(), torch.autocast(enabled=True, device_type="cuda", dtype=torch.float16):
+                result = self.gs_lrm_model.forward(batch, create_visual=False, split_data=True)
+
+            comp_image = result.render[0].unsqueeze(0).detach()
+            gaussians = result.gaussians[0]
+
+            # Clear CUDA cache after reconstruction
+            torch.cuda.empty_cache()
+
+            # Save filtered gaussians
+            filtered_gaussians = gaussians.apply_all_filters(
+                cam_origins=None,
+                opacity_thres=0.04,
+                scaling_thres=0.2,
+                floater_thres=0.75,
+                crop_bbx=[-0.91, 0.91, -0.91, 0.91, -1.0, 1.0],
+                nearfar_percent=(0.0001, 1.0),
+            )
+
+            ply_path = output_dir / "gaussians.ply"
+            filtered_gaussians.save_ply(str(ply_path))
+
+            # Save output image
+            comp_image = rearrange(comp_image, "x v c h w -> (x h) (v w) c")
+            comp_image = (comp_image.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
+            output_path = output_dir / "output.png"
+            Image.fromarray(comp_image).save(output_path)
+
+            # Generate turntable video
+            turntable_resolution = 512
+            num_turntable_views = 180
+            turntable_frames = render_turntable(gaussians, rendering_resolution=turntable_resolution,
+                                                num_views=num_turntable_views)
+            turntable_frames = rearrange(turntable_frames, "h (v w) c -> v h w c", v=num_turntable_views)
+            turntable_frames = np.ascontiguousarray(turntable_frames)
+
+            turntable_path = output_dir / "turntable.mp4"
+            imageseq2video(turntable_frames, str(turntable_path), fps=30)
+
+            # Final CUDA cache clear
+            torch.cuda.empty_cache()
+
+            return str(input_path), str(multiview_path), str(output_path), \
+                str(turntable_path), str(ply_path)
+
+        except Exception as e:
+            import traceback
+            error_details = traceback.format_exc()
+            print(f"Error details:\n{error_details}")
+            raise gr.Error(f"Generation failed: {str(e)}")
+
+def main():
+    """Run the FaceLift application."""
+    pipeline = FaceLiftPipeline()
+
+    # Prepare examples (same as before)
+    examples = []
+    if pipeline.examples_dir.exists():
+        examples = [[str(f), True, 3.0, 4, 50] for f in sorted(pipeline.examples_dir.iterdir())
+                    if f.suffix.lower() in {'.png', '.jpg', '.jpeg'}]
+
+    with gr.Blocks(title="FaceLift: Single Image 3D Face Reconstruction") as demo:
+
+        # Wrapper to return outputs for display
+        def _generate_and_filter_outputs(image_path, auto_crop, guidance_scale, random_seed, num_steps):
+            input_path, multiview_path, output_path, turntable_path, ply_path = \
+                pipeline.generate_3d_head(image_path, auto_crop, guidance_scale, random_seed, num_steps)
+
+            return output_path, turntable_path, ply_path
+
+        gr.Markdown("## FaceLift: Single Image 3D Face Reconstruction.")
+
+        gr.Markdown("""
+        ### 💡 Tips for Best Results
+        - Works best with near-frontal portrait images
+        - The provided checkpoints were not trained with accessories (glasses, hats, etc.). Portraits containing accessories may produce suboptimal results.
+        - If face detection fails, try disabling auto-cropping and manually crop to square
+        """)
+
+        with gr.Row():
+            with gr.Column(scale=1):
+                in_image = gr.Image(type="filepath", label="Input Portrait Image")
+                auto_crop = gr.Checkbox(value=True, label="Auto Cropping")
+                guidance = gr.Slider(1.0, 10.0, 3.0, step=0.1, label="Guidance Scale")
+                seed = gr.Number(value=4, label="Random Seed")
+                steps = gr.Slider(10, 100, 50, step=5, label="Generation Steps")
+                run_btn = gr.Button("Generate 3D Head", variant="primary")
+
+                # Examples (match input signature)
+                if examples:
+                    gr.Examples(
+                        examples=examples,
+                        inputs=[in_image, auto_crop, guidance, seed, steps],
+                        examples_per_page=10,
+                    )
+
+            with gr.Column(scale=1):
+                out_recon = gr.Image(label="3D Reconstruction Views")
+                out_video = gr.PlayableVideo(label="Turntable Animation (360° View)", height=600)
+                out_ply = gr.File(label="Download 3D Model (.ply)")
+
+        # Run generation and display all outputs
+        run_btn.click(
+            fn=_generate_and_filter_outputs,
+            inputs=[in_image, auto_crop, guidance, seed, steps],
+            outputs=[out_recon, out_video, out_ply],
+        )
+
+    demo.queue(max_size=10)
+    demo.launch(share=True, server_name="0.0.0.0", server_port=7860, show_error=True)
+
+if __name__ == "__main__":
+    main()
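As a usage note, the gaussians.ply written by generate_3d_head (and now exposed through the Model3D viewer and File download in app.py) can also be inspected offline. Below is a hedged sketch using the plyfile package; it assumes the common 3D Gaussian Splatting per-vertex layout (x/y/z, opacity, scale_*, rot_*, f_dc_*), the exact attribute names depend on what gslrm's save_ply writes, and the timestamped path shown is hypothetical:

from plyfile import PlyData

# Load the Gaussian-splat point cloud saved by filtered_gaussians.save_ply(...)
ply = PlyData.read("outputs/20250101_000000/gaussians.ply")  # hypothetical output path
vertices = ply["vertex"]

print(f"{len(vertices)} gaussians")
print("per-gaussian attributes:", vertices.data.dtype.names)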