weijielyu committed on
Commit 45fe0fa · 1 Parent(s): 465f735
Files changed (2)
  1. README.md +6 -6
  2. app.py +8 -541
README.md CHANGED
@@ -1,12 +1,12 @@
 ---
-title: My Space
-emoji: 🚀
-colorFrom: indigo
-colorTo: pink
-short_description: One-line pitch shown on the tile
+title: FaceLift
+emoji: 🎭
+colorFrom: blue
+colorTo: purple
+short_description: Single Image 3D Face Reconstruction with Gaussian Splatting
 sdk: gradio
 app_file: app.py
-thumbnail: https://huggingface.co/spaces/wlyu/FaceLift/blob/main/examples/teaser.png
+thumbnail: examples/teaser.png
 pinned: false
 ---
 
app.py CHANGED
@@ -365,511 +365,7 @@ class FaceLiftPipeline:
365
  print(f"Error details:\n{error_details}")
366
  raise gr.Error(f"Generation failed: {str(e)}")
367
 
368
- # -----------------------------
369
- # Custom WebGL Gaussian Splatting Viewer
370
- # Based on antimatter15/splat viewer, modified for Gradio
371
- # -----------------------------
372
- GSPLAT_HEAD = """
373
- <style>
374
- #gs-canvas { width: 100%; height: 100%; display: block; cursor: grab; }
375
- #gs-canvas:active { cursor: grabbing; }
376
- </style>
377
- <script>
378
- // Custom WebGL Gaussian Splatting Viewer - Handles PLY files natively
379
- (function() {
380
- 'use strict';
381
- let canvas, gl, program, worker, texture, indexBuffer;
382
- let viewMatrix, projectionMatrix;
383
- let vertexCount = 0;
384
- let isInitialized = false;
385
- let u_view, u_projection, u_viewport, u_focal;
386
-
387
- const vertexShaderSource = \`#version 300 es
388
- precision highp float;
389
- precision highp int;
390
- uniform highp usampler2D u_texture;
391
- uniform mat4 projection, view;
392
- uniform vec2 focal, viewport;
393
- in vec2 position;
394
- in int index;
395
- out vec4 vColor;
396
- out vec2 vPosition;
397
-
398
- void main () {
399
- uvec4 cen = texelFetch(u_texture, ivec2((uint(index) & 0x3ffu) << 1, uint(index) >> 10), 0);
400
- vec4 cam = view * vec4(uintBitsToFloat(cen.xyz), 1);
401
- vec4 pos2d = projection * cam;
402
- float clip = 1.2 * pos2d.w;
403
- if (pos2d.z < -clip || pos2d.x < -clip || pos2d.x > clip || pos2d.y < -clip || pos2d.y > clip) {
404
- gl_Position = vec4(0.0, 0.0, 2.0, 1.0);
405
- return;
406
- }
407
- uvec4 cov = texelFetch(u_texture, ivec2(((uint(index) & 0x3ffu) << 1) | 1u, uint(index) >> 10), 0);
408
- vec2 u1 = unpackHalf2x16(cov.x), u2 = unpackHalf2x16(cov.y), u3 = unpackHalf2x16(cov.z);
409
- mat3 Vrk = mat3(u1.x, u1.y, u2.x, u1.y, u2.y, u3.x, u2.x, u3.x, u3.y);
410
- mat3 J = mat3(focal.x / cam.z, 0., -(focal.x * cam.x) / (cam.z * cam.z),
411
- 0., -focal.y / cam.z, (focal.y * cam.y) / (cam.z * cam.z), 0., 0., 0.);
412
- mat3 T = transpose(mat3(view)) * J;
413
- mat3 cov2d = transpose(T) * Vrk * T;
414
- float mid = (cov2d[0][0] + cov2d[1][1]) / 2.0;
415
- float radius = length(vec2((cov2d[0][0] - cov2d[1][1]) / 2.0, cov2d[0][1]));
416
- float lambda1 = mid + radius, lambda2 = mid - radius;
417
- if(lambda2 < 0.0) return;
418
- vec2 diagonalVector = normalize(vec2(cov2d[0][1], lambda1 - cov2d[0][0]));
419
- vec2 majorAxis = min(sqrt(2.0 * lambda1), 1024.0) * diagonalVector;
420
- vec2 minorAxis = min(sqrt(2.0 * lambda2), 1024.0) * vec2(diagonalVector.y, -diagonalVector.x);
421
- vColor = clamp(pos2d.z/pos2d.w+1.0, 0.0, 1.0) * vec4((cov.w) & 0xffu, (cov.w >> 8) & 0xffu, (cov.w >> 16) & 0xffu, (cov.w >> 24) & 0xffu) / 255.0;
422
- vPosition = position;
423
- vec2 vCenter = vec2(pos2d) / pos2d.w;
424
- gl_Position = vec4(vCenter + position.x * majorAxis / viewport + position.y * minorAxis / viewport, 0.0, 1.0);
425
- }\`;
426
-
427
- const fragmentShaderSource = \`#version 300 es
428
- precision highp float;
429
- in vec4 vColor;
430
- in vec2 vPosition;
431
- out vec4 fragColor;
432
- void main () {
433
- float A = -dot(vPosition, vPosition);
434
- if (A < -4.0) discard;
435
- float B = exp(A) * vColor.a;
436
- fragColor = vec4(B * vColor.rgb, B);
437
- }\`;
438
-
439
- function createWorker() {
440
- const workerCode = `
441
- let buffer, vertexCount = 0;
442
- const rowLength = 32;
443
- var _floatView = new Float32Array(1), _int32View = new Int32Array(_floatView.buffer);
444
-
445
- function floatToHalf(float) {
446
- _floatView[0] = float;
447
- var f = _int32View[0];
448
- var sign = (f >> 31) & 0x0001, exp = (f >> 23) & 0x00ff, frac = f & 0x007fffff, newExp;
449
- if (exp == 0) newExp = 0;
450
- else if (exp < 113) {
451
- newExp = 0; frac |= 0x00800000; frac = frac >> (113 - exp);
452
- if (frac & 0x01000000) { newExp = 1; frac = 0; }
453
- } else if (exp < 142) newExp = exp - 112;
454
- else { newExp = 31; frac = 0; }
455
- return (sign << 15) | (newExp << 10) | (frac >> 13);
456
- }
457
-
458
- function packHalf2x16(x, y) {
459
- return (floatToHalf(x) | (floatToHalf(y) << 16)) >>> 0;
460
- }
461
-
462
- function processPlyBuffer(inputBuffer) {
463
- const ubuf = new Uint8Array(inputBuffer);
464
- const header = new TextDecoder().decode(ubuf.slice(0, 10240));
465
- const header_end = "end_header\\n";
466
- const header_end_index = header.indexOf(header_end);
467
- if (header_end_index < 0) throw new Error("Unable to read PLY header");
468
-
469
- const match = /element vertex (\\d+)\\n/.exec(header);
470
- const vertexCount = parseInt(match[1]);
471
- console.log("Vertex Count:", vertexCount);
472
-
473
- let row_offset = 0, offsets = {}, types = {};
474
- const TYPE_MAP = {
475
- double: "getFloat64", int: "getInt32", uint: "getUint32",
476
- float: "getFloat32", short: "getInt16", ushort: "getUint16", uchar: "getUint8"
477
- };
478
-
479
- const lines = header.slice(0, header_end_index).split("\\n");
480
- for (let line of lines) {
481
- if (line.startsWith("property ")) {
482
- const parts = line.split(" ");
483
- const type = parts[1];
484
- const name = parts[2];
485
- const arrayType = TYPE_MAP[type] || "getInt8";
486
- types[name] = arrayType;
487
- offsets[name] = row_offset;
488
- row_offset += parseInt(arrayType.replace(/[^\\d]/g, "")) / 8;
489
- }
490
- }
491
-
492
- let dataView = new DataView(inputBuffer, header_end_index + header_end.length);
493
- let row = 0;
494
- const attrs = new Proxy({}, {
495
- get(target, prop) {
496
- if (!types[prop]) throw new Error(prop + " not found");
497
- return dataView[types[prop]](row * row_offset + offsets[prop], true);
498
- }
499
- });
500
-
501
- let sizeList = new Float32Array(vertexCount);
502
- let sizeIndex = new Uint32Array(vertexCount);
503
- for (row = 0; row < vertexCount; row++) {
504
- sizeIndex[row] = row;
505
- if (!types["scale_0"]) continue;
506
- const size = Math.exp(attrs.scale_0) * Math.exp(attrs.scale_1) * Math.exp(attrs.scale_2);
507
- const opacity = 1 / (1 + Math.exp(-attrs.opacity));
508
- sizeList[row] = size * opacity;
509
- }
510
- sizeIndex.sort((b, a) => sizeList[a] - sizeList[b]);
511
-
512
- const buffer = new ArrayBuffer(rowLength * vertexCount);
513
- for (let j = 0; j < vertexCount; j++) {
514
- row = sizeIndex[j];
515
- const position = new Float32Array(buffer, j * rowLength, 3);
516
- const scales = new Float32Array(buffer, j * rowLength + 12, 3);
517
- const rgba = new Uint8ClampedArray(buffer, j * rowLength + 24, 4);
518
- const rot = new Uint8ClampedArray(buffer, j * rowLength + 28, 4);
519
-
520
- if (types["scale_0"]) {
521
- const qlen = Math.sqrt(attrs.rot_0 ** 2 + attrs.rot_1 ** 2 + attrs.rot_2 ** 2 + attrs.rot_3 ** 2);
522
- rot[0] = (attrs.rot_0 / qlen) * 128 + 128;
523
- rot[1] = (attrs.rot_1 / qlen) * 128 + 128;
524
- rot[2] = (attrs.rot_2 / qlen) * 128 + 128;
525
- rot[3] = (attrs.rot_3 / qlen) * 128 + 128;
526
- scales[0] = Math.exp(attrs.scale_0);
527
- scales[1] = Math.exp(attrs.scale_1);
528
- scales[2] = Math.exp(attrs.scale_2);
529
- } else {
530
- scales[0] = scales[1] = scales[2] = 0.01;
531
- rot[0] = 255; rot[1] = rot[2] = rot[3] = 0;
532
- }
533
-
534
- position[0] = attrs.x;
535
- position[1] = attrs.y;
536
- position[2] = attrs.z;
537
-
538
- if (types["f_dc_0"]) {
539
- const SH_C0 = 0.28209479177387814;
540
- rgba[0] = (0.5 + SH_C0 * attrs.f_dc_0) * 255;
541
- rgba[1] = (0.5 + SH_C0 * attrs.f_dc_1) * 255;
542
- rgba[2] = (0.5 + SH_C0 * attrs.f_dc_2) * 255;
543
- } else {
544
- rgba[0] = attrs.red;
545
- rgba[1] = attrs.green;
546
- rgba[2] = attrs.blue;
547
- }
548
- rgba[3] = types["opacity"] ? (1 / (1 + Math.exp(-attrs.opacity))) * 255 : 255;
549
- }
550
- return buffer;
551
- }
552
-
553
- function generateTexture() {
554
- if (!buffer) return;
555
- const f_buffer = new Float32Array(buffer);
556
- const u_buffer = new Uint8Array(buffer);
557
- const texwidth = 2048;
558
- const texheight = Math.ceil((2 * vertexCount) / texwidth);
559
- const texdata = new Uint32Array(texwidth * texheight * 4);
560
- const texdata_c = new Uint8Array(texdata.buffer);
561
- const texdata_f = new Float32Array(texdata.buffer);
562
-
563
- for (let i = 0; i < vertexCount; i++) {
564
- texdata_f[8 * i + 0] = f_buffer[8 * i + 0];
565
- texdata_f[8 * i + 1] = f_buffer[8 * i + 1];
566
- texdata_f[8 * i + 2] = f_buffer[8 * i + 2];
567
- texdata_c[4 * (8 * i + 7) + 0] = u_buffer[32 * i + 24 + 0];
568
- texdata_c[4 * (8 * i + 7) + 1] = u_buffer[32 * i + 24 + 1];
569
- texdata_c[4 * (8 * i + 7) + 2] = u_buffer[32 * i + 24 + 2];
570
- texdata_c[4 * (8 * i + 7) + 3] = u_buffer[32 * i + 24 + 3];
571
-
572
- let scale = [f_buffer[8 * i + 3], f_buffer[8 * i + 4], f_buffer[8 * i + 5]];
573
- let rot = [
574
- (u_buffer[32 * i + 28] - 128) / 128,
575
- (u_buffer[32 * i + 29] - 128) / 128,
576
- (u_buffer[32 * i + 30] - 128) / 128,
577
- (u_buffer[32 * i + 31] - 128) / 128
578
- ];
579
-
580
- const M = [
581
- 1.0 - 2.0 * (rot[2] * rot[2] + rot[3] * rot[3]),
582
- 2.0 * (rot[1] * rot[2] + rot[0] * rot[3]),
583
- 2.0 * (rot[1] * rot[3] - rot[0] * rot[2]),
584
- 2.0 * (rot[1] * rot[2] - rot[0] * rot[3]),
585
- 1.0 - 2.0 * (rot[1] * rot[1] + rot[3] * rot[3]),
586
- 2.0 * (rot[2] * rot[3] + rot[0] * rot[1]),
587
- 2.0 * (rot[1] * rot[3] + rot[0] * rot[2]),
588
- 2.0 * (rot[2] * rot[3] - rot[0] * rot[1]),
589
- 1.0 - 2.0 * (rot[1] * rot[1] + rot[2] * rot[2])
590
- ].map((k, idx) => k * scale[Math.floor(idx / 3)]);
591
-
592
- const sigma = [
593
- M[0]*M[0] + M[3]*M[3] + M[6]*M[6],
594
- M[0]*M[1] + M[3]*M[4] + M[6]*M[7],
595
- M[0]*M[2] + M[3]*M[5] + M[6]*M[8],
596
- M[1]*M[1] + M[4]*M[4] + M[7]*M[7],
597
- M[1]*M[2] + M[4]*M[5] + M[7]*M[8],
598
- M[2]*M[2] + M[5]*M[5] + M[8]*M[8]
599
- ];
600
-
601
- texdata[8 * i + 4] = packHalf2x16(4 * sigma[0], 4 * sigma[1]);
602
- texdata[8 * i + 5] = packHalf2x16(4 * sigma[2], 4 * sigma[3]);
603
- texdata[8 * i + 6] = packHalf2x16(4 * sigma[4], 4 * sigma[5]);
604
- }
605
- self.postMessage({ texdata, texwidth, texheight, vertexCount }, [texdata.buffer]);
606
- }
607
-
608
- self.onmessage = (e) => {
609
- console.log('[Worker] Message received');
610
- if (e.data.ply) {
611
- try {
612
- console.log('[Worker] Processing PLY buffer...');
613
- buffer = processPlyBuffer(e.data.ply);
614
- vertexCount = Math.floor(buffer.byteLength / rowLength);
615
- console.log('[Worker] Vertex count:', vertexCount);
616
- generateTexture();
617
- console.log('[Worker] Texture sent');
618
- } catch (error) {
619
- console.error('[Worker] Error:', error);
620
- }
621
- }
622
- };
623
- `;
624
- return new Worker(URL.createObjectURL(new Blob([workerCode], { type: 'application/javascript' })));
625
- }
626
-
627
- const invert4 = (a) => {
628
- let b00=a[0]*a[5]-a[1]*a[4],b01=a[0]*a[6]-a[2]*a[4],b02=a[0]*a[7]-a[3]*a[4],b03=a[1]*a[6]-a[2]*a[5],b04=a[1]*a[7]-a[3]*a[5],b05=a[2]*a[7]-a[3]*a[6];
629
- let b06=a[8]*a[13]-a[9]*a[12],b07=a[8]*a[14]-a[10]*a[12],b08=a[8]*a[15]-a[11]*a[12],b09=a[9]*a[14]-a[10]*a[13],b10=a[9]*a[15]-a[11]*a[13],b11=a[10]*a[15]-a[11]*a[14];
630
- let det=b00*b11-b01*b10+b02*b09+b03*b08-b04*b07+b05*b06;if(!det)return null;
631
- return[(a[5]*b11-a[6]*b10+a[7]*b09)/det,(a[2]*b10-a[1]*b11-a[3]*b09)/det,(a[13]*b05-a[14]*b04+a[15]*b03)/det,(a[10]*b04-a[9]*b05-a[11]*b03)/det,
632
- (a[6]*b08-a[4]*b11-a[7]*b07)/det,(a[0]*b11-a[2]*b08+a[3]*b07)/det,(a[14]*b02-a[12]*b05-a[15]*b01)/det,(a[8]*b05-a[10]*b02+a[11]*b01)/det,
633
- (a[4]*b10-a[5]*b08+a[7]*b06)/det,(a[1]*b08-a[0]*b10-a[3]*b06)/det,(a[12]*b04-a[13]*b02+a[15]*b00)/det,(a[9]*b02-a[8]*b04-a[11]*b00)/det,
634
- (a[5]*b07-a[4]*b09-a[6]*b06)/det,(a[0]*b09-a[1]*b07+a[2]*b06)/det,(a[13]*b01-a[12]*b03-a[14]*b00)/det,(a[8]*b03-a[9]*b01+a[10]*b00)/det];
635
- };
636
- const rotate4 = (a, rad, x, y, z) => {
637
- let len = Math.hypot(x, y, z); x /= len; y /= len; z /= len;
638
- let s = Math.sin(rad), c = Math.cos(rad), t = 1 - c;
639
- let b00=x*x*t+c, b01=y*x*t+z*s, b02=z*x*t-y*s, b10=x*y*t-z*s, b11=y*y*t+c, b12=z*y*t+x*s, b20=x*z*t+y*s, b21=y*z*t-x*s, b22=z*z*t+c;
640
- return [a[0]*b00+a[4]*b01+a[8]*b02,a[1]*b00+a[5]*b01+a[9]*b02,a[2]*b00+a[6]*b01+a[10]*b02,a[3]*b00+a[7]*b01+a[11]*b02,
641
- a[0]*b10+a[4]*b11+a[8]*b12,a[1]*b10+a[5]*b11+a[9]*b12,a[2]*b10+a[6]*b11+a[10]*b12,a[3]*b10+a[7]*b11+a[11]*b12,
642
- a[0]*b20+a[4]*b21+a[8]*b22,a[1]*b20+a[5]*b21+a[9]*b22,a[2]*b20+a[6]*b21+a[10]*b22,a[3]*b20+a[7]*b21+a[11]*b22,...a.slice(12,16)];
643
- };
644
- const translate4 = (a, x, y, z) => [...a.slice(0,12), a[0]*x+a[4]*y+a[8]*z+a[12], a[1]*x+a[5]*y+a[9]*z+a[13], a[2]*x+a[6]*y+a[10]*z+a[14], a[3]*x+a[7]*y+a[11]*z+a[15]];
645
- const getProjectionMatrix = (fx, fy, width, height) => {
646
- const znear = 0.2, zfar = 200;
647
- return [(2*fx)/width, 0, 0, 0, 0, -(2*fy)/height, 0, 0, 0, 0, zfar/(zfar-znear), 1, 0, 0, -(zfar*znear)/(zfar-znear), 0];
648
- };
649
-
650
- function initViewer() {
651
- console.log('initViewer called, isInitialized:', isInitialized);
652
- if (isInitialized) {
653
- console.log('Already initialized, skipping');
654
- return;
655
- }
656
-
657
- const container = document.getElementById('splat-container');
658
- if (!container) {
659
- console.error('Container #splat-container not found, retrying in 100ms');
660
- setTimeout(initViewer, 100);
661
- return;
662
- }
663
-
664
- console.log('Container found:', container);
665
- canvas = document.createElement('canvas');
666
- canvas.id = 'gs-canvas';
667
- container.innerHTML = '';
668
- container.appendChild(canvas);
669
- console.log('Canvas created and added');
670
-
671
- gl = canvas.getContext('webgl2', { antialias: false });
672
- if (!gl) {
673
- console.error('WebGL2 not supported');
674
- return;
675
- }
676
- console.log('WebGL2 context created');
677
-
678
- const vShader = gl.createShader(gl.VERTEX_SHADER);
679
- gl.shaderSource(vShader, vertexShaderSource);
680
- gl.compileShader(vShader);
681
- const fShader = gl.createShader(gl.FRAGMENT_SHADER);
682
- gl.shaderSource(fShader, fragmentShaderSource);
683
- gl.compileShader(fShader);
684
- program = gl.createProgram();
685
- gl.attachShader(program, vShader);
686
- gl.attachShader(program, fShader);
687
- gl.linkProgram(program);
688
- gl.useProgram(program);
689
-
690
- gl.disable(gl.DEPTH_TEST);
691
- gl.enable(gl.BLEND);
692
- gl.blendFuncSeparate(gl.ONE_MINUS_DST_ALPHA, gl.ONE, gl.ONE_MINUS_DST_ALPHA, gl.ONE);
693
-
694
- u_projection = gl.getUniformLocation(program, 'projection');
695
- u_viewport = gl.getUniformLocation(program, 'viewport');
696
- u_focal = gl.getUniformLocation(program, 'focal');
697
- u_view = gl.getUniformLocation(program, 'view');
698
-
699
- const vBuffer = gl.createBuffer();
700
- gl.bindBuffer(gl.ARRAY_BUFFER, vBuffer);
701
- gl.bufferData(gl.ARRAY_BUFFER, new Float32Array([-2, -2, 2, -2, 2, 2, -2, 2]), gl.STATIC_DRAW);
702
- const a_position = gl.getAttribLocation(program, 'position');
703
- gl.enableVertexAttribArray(a_position);
704
- gl.vertexAttribPointer(a_position, 2, gl.FLOAT, false, 0, 0);
705
-
706
- texture = gl.createTexture();
707
- gl.bindTexture(gl.TEXTURE_2D, texture);
708
- gl.uniform1i(gl.getUniformLocation(program, 'u_texture'), 0);
709
-
710
- indexBuffer = gl.createBuffer();
711
- const a_index = gl.getAttribLocation(program, 'index');
712
- gl.enableVertexAttribArray(a_index);
713
- gl.bindBuffer(gl.ARRAY_BUFFER, indexBuffer);
714
- gl.vertexAttribIPointer(a_index, 1, gl.INT, false, 0, 0);
715
- gl.vertexAttribDivisor(a_index, 1);
716
-
717
- viewMatrix = [0.99, 0, 0.14, 0, 0, 1, 0, 0, -0.14, 0, 0.99, 0, 0, 0, 2.5, 1];
718
- const fx = 500, fy = 500;
719
-
720
- function resize() {
721
- const dpr = window.devicePixelRatio || 1;
722
- const rect = container.getBoundingClientRect();
723
- canvas.width = rect.width * dpr;
724
- canvas.height = rect.height * dpr;
725
- canvas.style.width = rect.width + 'px';
726
- canvas.style.height = rect.height + 'px';
727
- projectionMatrix = getProjectionMatrix(fx, fy, canvas.width, canvas.height);
728
- gl.viewport(0, 0, canvas.width, canvas.height);
729
- gl.uniformMatrix4fv(u_projection, false, projectionMatrix);
730
- gl.uniform2fv(u_viewport, [canvas.width, canvas.height]);
731
- gl.uniform2fv(u_focal, [fx, fy]);
732
- }
733
- window.addEventListener('resize', resize);
734
- resize();
735
-
736
- let startX, startY, down;
737
- canvas.addEventListener('mousedown', (e) => { e.preventDefault(); startX = e.clientX; startY = e.clientY; down = e.ctrlKey || e.metaKey ? 2 : 1; });
738
- canvas.addEventListener('mousemove', (e) => {
739
- e.preventDefault();
740
- if (down == 1) {
741
- let inv = invert4(viewMatrix), dx = (5 * (e.clientX - startX)) / window.innerWidth, dy = (5 * (e.clientY - startY)) / window.innerHeight;
742
- inv = translate4(inv, 0, 0, 2);
743
- inv = rotate4(inv, dx, 0, 1, 0);
744
- inv = rotate4(inv, -dy, 1, 0, 0);
745
- inv = translate4(inv, 0, 0, -2);
746
- viewMatrix = invert4(inv);
747
- startX = e.clientX; startY = e.clientY;
748
- } else if (down == 2) {
749
- let inv = invert4(viewMatrix);
750
- inv = translate4(inv, (-5 * (e.clientX - startX)) / window.innerWidth, 0, (5 * (e.clientY - startY)) / window.innerHeight);
751
- viewMatrix = invert4(inv);
752
- startX = e.clientX; startY = e.clientY;
753
- }
754
- });
755
- canvas.addEventListener('mouseup', () => { down = false; });
756
- canvas.addEventListener('wheel', (e) => {
757
- e.preventDefault();
758
- let inv = invert4(viewMatrix);
759
- inv = translate4(inv, 0, 0, -e.deltaY * 0.001);
760
- viewMatrix = invert4(inv);
761
- }, { passive: false });
762
-
763
- let renderCount = 0;
764
- function render() {
765
- if (vertexCount > 0) {
766
- if (renderCount === 0) console.log('First render with', vertexCount, 'vertices');
767
- gl.uniformMatrix4fv(u_view, false, viewMatrix);
768
- gl.clear(gl.COLOR_BUFFER_BIT);
769
- gl.drawArraysInstanced(gl.TRIANGLE_FAN, 0, 4, vertexCount);
770
- renderCount++;
771
- } else {
772
- if (renderCount % 60 === 0) console.log('Rendering black screen, waiting for data...');
773
- gl.clearColor(0, 0, 0, 1);
774
- gl.clear(gl.COLOR_BUFFER_BIT);
775
- renderCount++;
776
- }
777
- requestAnimationFrame(render);
778
- }
779
- console.log('Starting render loop');
780
- render();
781
-
782
- worker = createWorker();
783
- console.log('Worker created');
784
-
785
- worker.onmessage = (e) => {
786
- console.log('Worker message received:', e.data);
787
- if (e.data.texdata) {
788
- const { texdata, texwidth, texheight, vertexCount: vc } = e.data;
789
- console.log('Texture data:', {texwidth, texheight, vertexCount: vc, textureSize: texdata.byteLength});
790
- vertexCount = vc;
791
-
792
- gl.bindTexture(gl.TEXTURE_2D, texture);
793
- gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
794
- gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
795
- gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
796
- gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
797
- gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA32UI, texwidth, texheight, 0, gl.RGBA_INTEGER, gl.UNSIGNED_INT, texdata);
798
-
799
- const depthIndex = new Uint32Array(vc);
800
- for (let i = 0; i < vc; i++) depthIndex[i] = i;
801
- gl.bindBuffer(gl.ARRAY_BUFFER, indexBuffer);
802
- gl.bufferData(gl.ARRAY_BUFFER, depthIndex, gl.STATIC_DRAW);
803
-
804
- console.log('✓ Successfully loaded', vc, 'gaussians');
805
- }
806
- };
807
-
808
- worker.onerror = (error) => {
809
- console.error('Worker error:', error);
810
- };
811
-
812
- isInitialized = true;
813
- console.log('✓ Viewer initialized successfully');
814
- }
815
-
816
- async function loadPlyFile(url) {
817
- try {
818
- console.log('=== Starting PLY load ===');
819
- console.log('URL:', url);
820
- console.log('Initialized?', isInitialized);
821
- console.log('Worker?', !!worker);
822
-
823
- if (!isInitialized) {
824
- console.log('Initializing viewer...');
825
- initViewer();
826
- // Wait for initialization
827
- await new Promise(resolve => setTimeout(resolve, 500));
828
- }
829
-
830
- if (!worker) {
831
- console.error('Worker not initialized after init attempt');
832
- return;
833
- }
834
-
835
- console.log('Fetching:', url);
836
- const response = await fetch(url);
837
- console.log('Response status:', response.status);
838
- console.log('Response headers:', [...response.headers.entries()]);
839
-
840
- if (!response.ok) throw new Error('Failed to load: ' + response.status);
841
-
842
- const arrayBuffer = await response.arrayBuffer();
843
- console.log('ArrayBuffer size:', arrayBuffer.byteLength, 'bytes');
844
-
845
- // Check PLY magic bytes
846
- const magic = new Uint8Array(arrayBuffer.slice(0, 4));
847
- console.log('Magic bytes:', magic, 'Expected: [112, 108, 121, 10] for "ply\\n"');
848
-
849
- worker.postMessage({ ply: arrayBuffer }, [arrayBuffer]);
850
- console.log('PLY file sent to worker');
851
- } catch (error) {
852
- console.error('=== Error loading PLY ===');
853
- console.error('Error:', error);
854
- console.error('Stack:', error.stack);
855
- }
856
- }
857
-
858
- window.__load_splat__ = loadPlyFile;
859
-
860
- // Initialize on load
861
- if (document.readyState === 'loading') {
862
- document.addEventListener('DOMContentLoaded', () => {
863
- console.log('DOM loaded, initializing viewer in 100ms');
864
- setTimeout(initViewer, 100);
865
- });
866
- } else {
867
- console.log('DOM already loaded, initializing viewer in 100ms');
868
- setTimeout(initViewer, 100);
869
- }
870
- })();
871
- </script>
872
- """
+# No custom head needed - using Gradio's built-in Model3D component
 
 def main():
     """Run the FaceLift application with an embedded gsplat.js viewer and per-session files."""
@@ -881,26 +377,13 @@ def main():
     examples = [[str(f), True, 3.0, 4, 50] for f in sorted(pipeline.examples_dir.iterdir())
                 if f.suffix.lower() in {'.png', '.jpg', '.jpeg'}]
 
-    with gr.Blocks(head=GSPLAT_HEAD, title="FaceLift: Single Image 3D Face Reconstruction") as demo:
-        session = gr.State()
-
-        # Light GC + session init
-        def _init_session():
-            cleanup_old_sessions()
-            return new_session_id()
+    with gr.Blocks(title="FaceLift: Single Image 3D Face Reconstruction") as demo:
 
-        # After generation: copy ply into per-session folder and return viewer URL
-        def _prep_viewer_url(ply_path: str, session_id: str):
-            if not ply_path or not os.path.exists(ply_path):
-                return "✗ Model file not found", ""
-            url = copy_to_session_and_get_url(ply_path, session_id)
-            return "✓ Model loaded! Use mouse to interact.", url
-
-        # Wrapper to return only the outputs we want to display
+        # Wrapper to return all outputs including 3D model for viewer
         def _generate_and_filter_outputs(image_path, auto_crop, guidance_scale, random_seed, num_steps):
             input_path, multiview_path, output_path, turntable_path, ply_path = \
                 pipeline.generate_3d_head(image_path, auto_crop, guidance_scale, random_seed, num_steps)
-            return output_path, turntable_path, ply_path
+            return ply_path, output_path, turntable_path, ply_path
 
         gr.Markdown("## FaceLift: Single Image 3D Face Reconstruction.")
 
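The rewritten wrapper returns `ply_path` twice because Gradio maps a handler's return values positionally onto the components listed in `outputs=`: the first copy feeds the new `gr.Model3D` viewer and the last one the downloadable `gr.File`. A minimal, self-contained sketch of that mapping (hypothetical placeholder paths, not code from this commit):

```python
# Hypothetical illustration of Gradio's positional output mapping; the paths
# below are placeholders, not artifacts produced by this Space.
import gradio as gr

def fake_generate():
    ply, preview, turntable = "head.ply", "recon.png", "turntable.mp4"
    # One value per component in `outputs=`, in order; the .ply path is reused.
    return ply, preview, turntable, ply

with gr.Blocks() as demo:
    btn = gr.Button("Run")
    viewer = gr.Model3D(label="Viewer")              # receives return value 1
    recon = gr.Image(label="Reconstruction")         # value 2
    video = gr.PlayableVideo(label="Turntable")      # value 3
    ply_file = gr.File(label="3D Gaussians (.ply)")  # value 4
    btn.click(fn=fake_generate, inputs=None, outputs=[viewer, recon, video, ply_file])

if __name__ == "__main__":
    demo.launch()
```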
@@ -928,36 +411,20 @@ def main():
                 )
 
             with gr.Column(scale=1):
-                gr.Markdown("### Interactive Gaussian Splat Viewer\n*Drag to rotate, Ctrl+drag to pan, scroll to zoom*")
-                viewer = gr.HTML("<div id='splat-container' style='width:100%;height:600px;background:#000;border:1px solid #ccc;border-radius:8px;'></div>")
-                viewer_status = gr.Textbox(label="Viewer Status", value="Viewer ready. Generate a model to view.", interactive=False)
-                url_box = gr.Textbox(label="Model URL (for debugging)", interactive=False, visible=False)
-                reload_btn = gr.Button("Reload Viewer", size="sm", visible=False)
+                gr.Markdown("### 3D Model Viewer\n*Use mouse to rotate and zoom*")
+                viewer_3d = gr.Model3D(label="Interactive 3D Viewer", height=600)
 
                 out_recon = gr.Image(label="3D Reconstruction")
                 out_video = gr.PlayableVideo(label="Turntable Animation")
                 out_ply = gr.File(label="3D Gaussians (.ply)")
 
-        # Initialize per-browser session
-        demo.load(fn=_init_session, inputs=None, outputs=session)
-
-        # Chain: run → show outputs → prepare viewer URL → load viewer (JS)
+        # Run generation and display all outputs
         run_btn.click(
             fn=_generate_and_filter_outputs,
             inputs=[in_image, auto_crop, guidance, seed, steps],
-            outputs=[out_recon, out_video, out_ply],
-        ).then(
-            fn=_prep_viewer_url,
-            inputs=[out_ply, session],
-            outputs=[viewer_status, url_box],
-        ).then(
-            fn=None, inputs=url_box, outputs=None,
-            js="(url)=>{console.log('Calling __load_splat__ with:', url); return window.__load_splat__(url);}"
+            outputs=[viewer_3d, out_recon, out_video, out_ply],
         )
 
-        # Manual reload if needed
-        reload_btn.click(fn=None, inputs=url_box, outputs=None, js="(url)=>window.__load_splat__(url)")
-
     demo.queue(max_size=10)
     demo.launch(share=True, server_name="0.0.0.0", server_port=7860, show_error=True)
 
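For reference, the decoding that the deleted web worker performed on the Gaussian-splat `.ply` (DC spherical-harmonics coefficients to RGB, a sigmoid on opacity, exponentiated scales, normalized quaternions) can be reproduced offline in a few lines of Python. This is a sketch rather than code from the commit: it assumes the third-party `plyfile` package and the 3DGS field names that appear in the removed JavaScript (`x/y/z`, `scale_*`, `rot_*`, `f_dc_*`, `opacity`).

```python
# Sketch of the PLY decoding the removed worker performed, using the third-party
# `plyfile` package (assumption: the file follows the standard 3D Gaussian
# Splatting layout, with the property names referenced in the deleted code).
import numpy as np
from plyfile import PlyData

SH_C0 = 0.28209479177387814  # same DC-to-RGB constant as in the removed worker

def read_gaussians(path):
    """Load centers, scales, rotations, colors, and opacities from a 3DGS .ply file."""
    v = PlyData.read(path)["vertex"]
    xyz = np.stack([v["x"], v["y"], v["z"]], axis=-1)
    scales = np.exp(np.stack([v["scale_0"], v["scale_1"], v["scale_2"]], axis=-1))
    quats = np.stack([v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]], axis=-1)
    quats = quats / np.linalg.norm(quats, axis=-1, keepdims=True)   # normalize, as the worker did
    rgb = np.clip(0.5 + SH_C0 * np.stack([v["f_dc_0"], v["f_dc_1"], v["f_dc_2"]], axis=-1), 0.0, 1.0)
    opacity = 1.0 / (1.0 + np.exp(-np.asarray(v["opacity"])))       # sigmoid, as the worker did
    return xyz, scales, quats, rgb, opacity
```

In the app itself, rendering is now left entirely to `gr.Model3D`, so no such parsing remains in app.py; a helper like this is only useful for inspecting the downloaded file.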