Commit 14ce2e6
Parent(s): 5b19a88
Upload modeling_ddpm.py

modeling_ddpm.py CHANGED (+24 -26)
@@ -14,15 +14,13 @@
 # limitations under the License.
 
 
-from diffusers import DiffusionPipeline
-import tqdm
 import torch
 
+import tqdm
+from diffusers import DiffusionPipeline
 
-class DDPM(DiffusionPipeline):
-
-    modeling_file = "modeling_ddpm.py"
 
+class DDPM(DiffusionPipeline):
     def __init__(self, unet, noise_scheduler):
         super().__init__()
         self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
@@ -32,30 +30,30 @@ class DDPM(DiffusionPipeline):
         torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 
         self.unet.to(torch_device)
-
-
-
-
-
-
-
-
-
-
+
+        # Sample gaussian noise to begin loop
+        image = self.noise_scheduler.sample_noise(
+            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
+            device=torch_device,
+            generator=generator,
+        )
+
+        num_prediction_steps = len(self.noise_scheduler)
+        for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
+            # 1. predict noise residual
             with torch.no_grad():
-
+                residual = self.unet(image, t)
 
-            #
-
-            pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual
-            pred_mean = torch.clamp(pred_mean, -1, 1)
-            prev_image = clip_coeff * pred_mean + image_coeff * image
+            # 2. predict previous mean of image x_t-1
+            pred_prev_image = self.noise_scheduler.compute_prev_image_step(residual, image, t)
 
-            #
-
+            # 3. optionally sample variance
+            variance = 0
+            if t > 0:
+                noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
+                variance = self.noise_scheduler.get_variance(t).sqrt() * noise
 
-            #
-
-            image = sampled_prev_image
+            # 4. set current image to prev_image: x_t -> x_t-1
+            image = pred_prev_image + variance
 
         return image
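
The four numbered steps in the rewritten loop follow the standard DDPM ancestral-sampling update (Ho et al., 2020). As a sketch in the usual notation, not notation taken from this repository, each iteration computes

x_{t-1} = \mu_\theta(x_t, t) + \sigma_t z, \qquad z \sim \mathcal{N}(0, I) \text{ if } t > 0, \quad z = 0 \text{ if } t = 0, \qquad x_T \sim \mathcal{N}(0, I),

where \mu_\theta(x_t, t) corresponds to pred_prev_image (steps 1-2, computed from the predicted noise residual), \sigma_t^2 corresponds to noise_scheduler.get_variance(t) (step 3), and the loop is initialized from pure Gaussian noise x_T by the sample_noise call before the loop.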
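
For orientation, a minimal usage sketch of the updated pipeline. Only the constructor signature is shown in the diff; the __call__ arguments (batch_size, generator) are an assumption inferred from the variables used inside the loop, and unet / noise_scheduler stand for a trained UNet and a scheduler exposing sample_noise, compute_prev_image_step, and get_variance as used above.

import torch

from modeling_ddpm import DDPM

# Placeholders (hypothetical): a trained UNet with .in_channels and .resolution attributes,
# and a noise scheduler with the sample_noise / compute_prev_image_step / get_variance
# methods called in the loop above. How they are built or loaded is outside this file.
unet = ...
noise_scheduler = ...

pipeline = DDPM(unet=unet, noise_scheduler=noise_scheduler)

# Assumed call signature, matching the names used in the loop body.
generator = torch.manual_seed(0)
images = pipeline(batch_size=1, generator=generator)
# Expected result: a tensor of shape (batch_size, unet.in_channels, unet.resolution, unet.resolution).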