Spaces:
Runtime error
Runtime error
fullvol auto-padding added
Browse files
app.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
import json
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
import nibabel as nib
|
| 5 |
import torch
|
|
|
|
| 6 |
import scipy.io
|
| 7 |
from io import BytesIO
|
| 8 |
from transformers import AutoModel
|
|
@@ -15,11 +17,32 @@ from skimage.filters import threshold_otsu
|
|
| 15 |
def infer_full_vol(tensor, model):
|
| 16 |
tensor = torch.movedim(tensor, -1, -3)
|
| 17 |
tensor = tensor / tensor.max()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
with torch.no_grad():
|
| 19 |
output = model(tensor)
|
| 20 |
if type(output) is tuple or type(output) is list:
|
| 21 |
output = output[0]
|
| 22 |
output = torch.sigmoid(output)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
output = torch.movedim(output, -3, -1).type(tensor.type())
|
| 24 |
return output.squeeze().detach().cpu().numpy()
|
| 25 |
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import json
|
| 3 |
+
import math
|
| 4 |
import numpy as np
|
| 5 |
import nibabel as nib
|
| 6 |
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
import scipy.io
|
| 9 |
from io import BytesIO
|
| 10 |
from transformers import AutoModel
|
|
|
|
| 17 |
def infer_full_vol(tensor, model):
|
| 18 |
tensor = torch.movedim(tensor, -1, -3)
|
| 19 |
tensor = tensor / tensor.max()
|
| 20 |
+
|
| 21 |
+
sizes = tensor.shape[-3:]
|
| 22 |
+
new_sizes = [math.ceil(s / 16) * 16 for s in sizes]
|
| 23 |
+
total_pads = [new_size - s for s, new_size in zip(sizes, new_sizes)]
|
| 24 |
+
pad_before = [pad // 2 for pad in total_pads]
|
| 25 |
+
pad_after = [pad - pad_before[i] for i, pad in enumerate(total_pads)]
|
| 26 |
+
padding = []
|
| 27 |
+
for i in reversed(range(len(pad_before))):
|
| 28 |
+
padding.extend([pad_before[i], pad_after[i]])
|
| 29 |
+
tensor = F.pad(tensor, padding)
|
| 30 |
+
|
| 31 |
with torch.no_grad():
|
| 32 |
output = model(tensor)
|
| 33 |
if type(output) is tuple or type(output) is list:
|
| 34 |
output = output[0]
|
| 35 |
output = torch.sigmoid(output)
|
| 36 |
+
|
| 37 |
+
slices = [slice(None)] * output.dim()
|
| 38 |
+
for i in range(len(pad_before)):
|
| 39 |
+
dim = -3 + i
|
| 40 |
+
start = pad_before[i]
|
| 41 |
+
size = sizes[i]
|
| 42 |
+
end = start + size
|
| 43 |
+
slices[dim] = slice(start, end)
|
| 44 |
+
output = output[tuple(slices)]
|
| 45 |
+
|
| 46 |
output = torch.movedim(output, -3, -1).type(tensor.type())
|
| 47 |
return output.squeeze().detach().cpu().numpy()
|
| 48 |
|