weijielyu committed
Commit 59ae2c2 · 1 Parent(s): 4d09c83

Update demo

Files changed (1):
  app.py +54 -322

app.py CHANGED
@@ -14,6 +14,9 @@ Generates 3D head models from single images using multi-view diffusion and GS-LRM
 import json
 from pathlib import Path
 from datetime import datetime
+import uuid
+import time
+import shutil
 
 import gradio as gr
 import numpy as np
@@ -29,337 +32,66 @@ import spaces
 import subprocess
 import sys
 import os
-import subprocess, sys, os
 
-# Ensure diff-gaussian-rasterization is compiled for the current GPU arch
+# -----------------------------
+# Static paths (for viewer files)
+# -----------------------------
+OUTPUTS_DIR = Path.cwd() / "outputs"
+SPLATS_ROOT = OUTPUTS_DIR / "splats"
+SPLATS_ROOT.mkdir(parents=True, exist_ok=True)
+
+# Serve ./outputs via Gradio's static router: /gradio_api/file=outputs/...
+gr.set_static_paths(paths=[OUTPUTS_DIR])
+
+# -----------------------------
+# Per-session helpers
+# -----------------------------
+def new_session_id() -> str:
+    return uuid.uuid4().hex[:10]
+
+def session_dir(session_id: str) -> Path:
+    p = SPLATS_ROOT / session_id
+    p.mkdir(parents=True, exist_ok=True)
+    return p
+
+def cleanup_old_sessions(max_age_hours: int = 6):
+    cutoff = time.time() - max_age_hours * 3600
+    if not SPLATS_ROOT.exists():
+        return
+    for child in SPLATS_ROOT.iterdir():
+        try:
+            if child.is_dir() and child.stat().st_mtime < cutoff:
+                shutil.rmtree(child, ignore_errors=True)
+        except Exception:
+            pass
+
+def copy_to_session_and_get_url(src_path: str, session_id: str) -> str:
+    """
+    Copy a .splat or .ply into this user's session folder and return a cache-busted URL.
+    """
+    src = Path(src_path)
+    ext = src.suffix.lower() if src.suffix else ".ply"
+    fn = f"{int(time.time())}_{uuid.uuid4().hex[:6]}{ext}"
+    dst = session_dir(session_id) / fn
+    shutil.copy2(src, dst)
+    # /gradio_api/file=outputs/...
+    return f"/gradio_api/file=outputs/splats/{session_id}/{fn}?v={uuid.uuid4().hex[:6]}"
+
+# -----------------------------
+# Ensure diff-gaussian-rasterization builds for current GPU
+# -----------------------------
 try:
     import diff_gaussian_rasterization  # noqa: F401
 except ImportError:
     print("Installing diff-gaussian-rasterization (compiling for detected CUDA arch)...")
     env = os.environ.copy()
     try:
-        import torch
-        if torch.cuda.is_available():
-            maj, minr = torch.cuda.get_device_capability()
+        import torch as _torch
+        if _torch.cuda.is_available():
+            maj, minr = _torch.cuda.get_device_capability()
             arch = f"{maj}.{minr}"  # e.g., "9.0" on H100/H200, "8.0" on A100
             env["TORCH_CUDA_ARCH_LIST"] = f"{arch}+PTX"
         else:
             # Build stage may not see a GPU on HF Spaces: compile a cross-arch set
             env["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6;8.9;9.0+PTX"
-    except Exception:
-        env["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6;8.9;9.0+PTX"
-
-    # (Optional) side-step allocator+NVML quirks in restrictive containers
-    env.setdefault("PYTORCH_NO_CUDA_MEMORY_CACHING", "1")
-
-    subprocess.check_call(
-        [sys.executable, "-m", "pip", "install",
-         "git+https://github.com/graphdeco-inria/diff-gaussian-rasterization"],
-        env=env,
-    )
-    import diff_gaussian_rasterization  # noqa: F401
-
-
-from gslrm.model.gaussians_renderer import render_turntable, imageseq2video
-from mvdiffusion.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline
-from utils_folder.face_utils import preprocess_image, preprocess_image_without_cropping
-
-# HuggingFace repository configuration
-HF_REPO_ID = "wlyu/OpenFaceLift"
-
-def download_weights_from_hf() -> Path:
-    """Download model weights from HuggingFace if not already present.
-
-    Returns:
-        Path to the downloaded repository
-    """
-    workspace_dir = Path(__file__).parent
-
-    # Check if weights already exist locally
-    mvdiffusion_path = workspace_dir / "checkpoints/mvdiffusion/pipeckpts"
-    gslrm_path = workspace_dir / "checkpoints/gslrm/ckpt_0000000000021125.pt"
-
-    if mvdiffusion_path.exists() and gslrm_path.exists():
-        print("Using local model weights")
-        return workspace_dir
-
-    print(f"Downloading model weights from HuggingFace: {HF_REPO_ID}")
-    print("This may take a few minutes on first run...")
-
-    # Download to local directory
-    snapshot_download(
-        repo_id=HF_REPO_ID,
-        local_dir=str(workspace_dir / "checkpoints"),
-        local_dir_use_symlinks=False,
-    )
-
-    print("Model weights downloaded successfully!")
-    return workspace_dir
-
-class FaceLiftPipeline:
-    """Pipeline for FaceLift 3D head generation from single images."""
-
-    def __init__(self):
-        # Download weights from HuggingFace if needed
-        workspace_dir = download_weights_from_hf()
-
-        # Setup paths
-        self.output_dir = workspace_dir / "outputs"
-        self.examples_dir = workspace_dir / "examples"
-        self.output_dir.mkdir(exist_ok=True)
-
-        # Parameters
-        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-        self.image_size = 512
-        self.camera_indices = [2, 1, 0, 5, 4, 3]
-
-        # Load models (keep on CPU for ZeroGPU compatibility)
-        print("Loading models...")
-        try:
-            self.mvdiffusion_pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
-                str(workspace_dir / "checkpoints/mvdiffusion/pipeckpts"),
-                torch_dtype=torch.float16,
-            )
-            # Don't move to device or enable xformers here - will be done in GPU-decorated function
-            self._models_on_gpu = False
-
-            with open(workspace_dir / "configs/gslrm.yaml", "r") as f:
-                config = edict(yaml.safe_load(f))
-
-            module_name, class_name = config.model.class_name.rsplit(".", 1)
-            module = __import__(module_name, fromlist=[class_name])
-            ModelClass = getattr(module, class_name)
-
-            self.gs_lrm_model = ModelClass(config)
-            checkpoint = torch.load(
-                workspace_dir / "checkpoints/gslrm/ckpt_0000000000021125.pt",
-                map_location="cpu"
-            )
-            # Filter out loss_calculator weights (training-only, not needed for inference)
-            state_dict = {k: v for k, v in checkpoint["model"].items()
-                          if not k.startswith("loss_calculator.")}
-            self.gs_lrm_model.load_state_dict(state_dict)
-            # Keep on CPU initially - will move to GPU in decorated function
-
-            self.color_prompt_embedding = torch.load(
-                workspace_dir / "mvdiffusion/fixed_prompt_embeds_6view/clr_embeds.pt",
-                map_location="cpu"
-            )
-
-            with open(workspace_dir / "utils_folder/opencv_cameras.json", 'r') as f:
-                self.cameras_data = json.load(f)["frames"]
-
-            print("Models loaded successfully!")
-        except Exception as e:
-            print(f"Error loading models: {e}")
-            import traceback
-            traceback.print_exc()
-            raise
-
-    def _move_models_to_gpu(self):
-        """Move models to GPU and enable optimizations. Called within @spaces.GPU context."""
-        if not self._models_on_gpu and torch.cuda.is_available():
-            print("Moving models to GPU...")
-            self.device = torch.device("cuda:0")
-            self.mvdiffusion_pipeline.to(self.device)
-            self.mvdiffusion_pipeline.unet.enable_xformers_memory_efficient_attention()
-            self.gs_lrm_model.to(self.device)
-            self.gs_lrm_model.eval()  # Set to eval mode
-            self.color_prompt_embedding = self.color_prompt_embedding.to(self.device)
-            self._models_on_gpu = True
-            torch.cuda.empty_cache()  # Clear cache after moving models
-            print("Models on GPU, xformers enabled!")
-
-    @spaces.GPU(duration=120)
-    def generate_3d_head(self, image_path, auto_crop=True, guidance_scale=3.0,
-                         random_seed=4, num_steps=50):
-        """Generate 3D head from single image."""
-        try:
-            # Move models to GPU now that we're in the GPU context
-            self._move_models_to_gpu()
-            # Setup output directory
-            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-            output_dir = self.output_dir / timestamp
-            output_dir.mkdir(exist_ok=True)
-
-            # Preprocess input
-            original_img = np.array(Image.open(image_path))
-            input_image = preprocess_image(original_img) if auto_crop else \
-                preprocess_image_without_cropping(original_img)
-
-            if input_image.size != (self.image_size, self.image_size):
-                input_image = input_image.resize((self.image_size, self.image_size))
-
-            input_path = output_dir / "input.png"
-            input_image.save(input_path)
-
-            # Generate multi-view images
-            generator = torch.Generator(device=self.mvdiffusion_pipeline.unet.device)
-            generator.manual_seed(random_seed)
-
-            result = self.mvdiffusion_pipeline(
-                input_image, None,
-                prompt_embeds=self.color_prompt_embedding,
-                height=self.image_size,
-                width=self.image_size,
-                guidance_scale=guidance_scale,
-                num_images_per_prompt=1,
-                num_inference_steps=num_steps,
-                generator=generator,
-                eta=1.0,
-            )
-
-            selected_views = result.images[:6]
-
-            # Save multi-view composite
-            multiview_image = Image.new("RGB", (self.image_size * 6, self.image_size))
-            for i, view in enumerate(selected_views):
-                multiview_image.paste(view, (self.image_size * i, 0))
-
-            multiview_path = output_dir / "multiview.png"
-            multiview_image.save(multiview_path)
-
-            # Move diffusion model to CPU to free GPU memory for GS-LRM
-            print("Moving diffusion model to CPU to free memory...")
-            self.mvdiffusion_pipeline.to("cpu")
-
-            # Delete intermediate variables to free memory
-            del result, generator
-            torch.cuda.empty_cache()
-            torch.cuda.synchronize()
-
-            # Prepare 3D reconstruction input
-            view_arrays = [np.array(view) for view in selected_views]
-            lrm_input = torch.from_numpy(np.stack(view_arrays, axis=0)).float()
-            lrm_input = lrm_input[None].to(self.device) / 255.0
-            lrm_input = rearrange(lrm_input, "b v h w c -> b v c h w")
-
-            # Prepare camera parameters
-            selected_cameras = [self.cameras_data[i] for i in self.camera_indices]
-            fxfycxcy_list = [[c["fx"], c["fy"], c["cx"], c["cy"]] for c in selected_cameras]
-            c2w_list = [np.linalg.inv(np.array(c["w2c"])) for c in selected_cameras]
-
-            fxfycxcy = torch.from_numpy(np.stack(fxfycxcy_list, axis=0).astype(np.float32))
-            c2w = torch.from_numpy(np.stack(c2w_list, axis=0).astype(np.float32))
-            fxfycxcy = fxfycxcy[None].to(self.device)
-            c2w = c2w[None].to(self.device)
-
-            batch_indices = torch.stack([
-                torch.zeros(lrm_input.size(1)).long(),
-                torch.arange(lrm_input.size(1)).long(),
-            ], dim=-1)[None].to(self.device)
-
-            batch = edict({
-                "image": lrm_input,
-                "c2w": c2w,
-                "fxfycxcy": fxfycxcy,
-                "index": batch_indices,
-            })
-
-            # Ensure GS-LRM model is on GPU
-            if next(self.gs_lrm_model.parameters()).device.type == "cpu":
-                print("Moving GS-LRM model to GPU...")
-                self.gs_lrm_model.to(self.device)
-                torch.cuda.empty_cache()
-
-            # Final memory cleanup before reconstruction
-            torch.cuda.empty_cache()
-
-            # Run 3D reconstruction
-            with torch.no_grad(), torch.autocast(enabled=True, device_type="cuda", dtype=torch.float16):
-                result = self.gs_lrm_model.forward(batch, create_visual=False, split_data=True)
-
-            comp_image = result.render[0].unsqueeze(0).detach()
-            gaussians = result.gaussians[0]
-
-            # Clear CUDA cache after reconstruction
-            torch.cuda.empty_cache()
-
-            # Save filtered gaussians
-            filtered_gaussians = gaussians.apply_all_filters(
-                cam_origins=None,
-                opacity_thres=0.04,
-                scaling_thres=0.2,
-                floater_thres=0.75,
-                crop_bbx=[-0.91, 0.91, -0.91, 0.91, -1.0, 1.0],
-                nearfar_percent=(0.0001, 1.0),
-            )
-
-            ply_path = output_dir / "gaussians.ply"
-            filtered_gaussians.save_ply(str(ply_path))
-
-            # Save output image
-            comp_image = rearrange(comp_image, "x v c h w -> (x h) (v w) c")
-            comp_image = (comp_image.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
-            output_path = output_dir / "output.png"
-            Image.fromarray(comp_image).save(output_path)
-
-            # Generate turntable video
-            turntable_resolution = 512
-            num_turntable_views = 180
-            turntable_frames = render_turntable(gaussians, rendering_resolution=turntable_resolution,
-                                                num_views=num_turntable_views)
-            turntable_frames = rearrange(turntable_frames, "h (v w) c -> v h w c", v=num_turntable_views)
-            turntable_frames = np.ascontiguousarray(turntable_frames)
-
-            turntable_path = output_dir / "turntable.mp4"
-            imageseq2video(turntable_frames, str(turntable_path), fps=30)
-
-            # Final CUDA cache clear
-            torch.cuda.empty_cache()
-
-            return str(input_path), str(multiview_path), str(output_path), \
-                str(turntable_path), str(ply_path)
-
-        except Exception as e:
-            import traceback
-            error_details = traceback.format_exc()
-            print(f"Error details:\n{error_details}")
-            raise gr.Error(f"Generation failed: {str(e)}")
-
-
-def main():
-    """Run the FaceLift application."""
-    pipeline = FaceLiftPipeline()
-
-    # Load examples - provide all 5 input values (image, auto_crop, guidance_scale, random_seed, num_steps)
-    examples = []
-    if pipeline.examples_dir.exists():
-        examples = [[str(f), True, 3.0, 4, 50] for f in sorted(pipeline.examples_dir.iterdir())
-                    if f.suffix.lower() in {'.png', '.jpg', '.jpeg'}]
-
-    # Create interface
-    demo = gr.Interface(
-        fn=pipeline.generate_3d_head,
-        title="FaceLift: Single Image 3D Face Reconstruction",
-        description="""
-        Transform a single portrait image into a complete 3D head model.
-
-        **Tips:**
-        - Use high-quality portrait images with clear facial features
-        - If face detection fails, try disabling auto-cropping and manually crop to square
-        """,
-        inputs=[
-            gr.Image(type="filepath", label="Input Portrait Image"),
-            gr.Checkbox(value=True, label="Auto Cropping"),
-            gr.Slider(1.0, 10.0, 3.0, step=0.1, label="Guidance Scale"),
-            gr.Number(value=4, label="Random Seed"),
-            gr.Slider(10, 100, 50, step=5, label="Generation Steps"),
-        ],
-        outputs=[
-            gr.Image(label="Processed Input"),
-            gr.Image(label="Multi-view Generation"),
-            gr.Image(label="3D Reconstruction"),
-            gr.PlayableVideo(label="Turntable Animation"),
-            gr.File(label="3D Model (.ply)"),
-        ],
-        examples=examples,
-        allow_flagging="never",
-    )
-
-    demo.queue(max_size=10)
-    demo.launch(share=True, server_name="0.0.0.0", server_port=7860, show_error=True)
-
-
-if __name__ == "__main__":
-    main()
+    except Exception:
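
For context, here is a minimal sketch of how the new per-session helpers could be wired into a Gradio Blocks UI. This is an illustration under assumptions, not part of the commit: run_pipeline is a hypothetical stand-in for the generation step, the component layout is invented, and the sketch assumes new_session_id, cleanup_old_sessions, and copy_to_session_and_get_url from the diff above are defined in the same module.

import gradio as gr

def run_pipeline(image_path: str) -> str:
    # Hypothetical stand-in: assume the real app writes a gaussian splat
    # for this input and returns the .ply path.
    return "outputs/gaussians.ply"

with gr.Blocks() as demo:
    session = gr.State("")                      # filled in on page load
    image = gr.Image(type="filepath", label="Input Portrait Image")
    splat_url = gr.Textbox(label="Splat URL")   # could instead feed an HTML splat viewer
    generate = gr.Button("Generate")

    # Give each browser session its own id so outputs/splats/<id>/ stays isolated.
    demo.load(fn=new_session_id, outputs=session)

    def on_generate(image_path, session_id):
        cleanup_old_sessions()                  # prune session folders older than 6 hours
        ply_path = run_pipeline(image_path)
        return copy_to_session_and_get_url(ply_path, session_id)

    generate.click(on_generate, inputs=[image, session], outputs=splat_url)

The cache-busting ?v= query string returned by copy_to_session_and_get_url matters in this flow: a browser-side viewer that caches by URL would otherwise keep showing the previous model after regenerating into the same session folder.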