Argument and dtype fix
#1
by
SpiridonSunRotator
- opened
- pipeline.py +2 -4
pipeline.py
CHANGED
|
@@ -244,11 +244,9 @@ class SwDPipeline(StableDiffusion3Pipeline):
|
|
| 244 |
sigma = sigmas[i]
|
| 245 |
sigma_next = sigmas[i + 1]
|
| 246 |
x0_pred = (latents - sigma * noise_pred)
|
| 247 |
-
|
| 248 |
x0_pred = torch.nn.functional.interpolate(x0_pred, size=scales[i + 1], mode='bicubic')
|
| 249 |
-
|
| 250 |
-
x0_pred = x0_pred
|
| 251 |
-
noise = torch.randn(x0_pred.shape, generator=generator).to('cuda').half()
|
| 252 |
latents = (1 - sigma_next) * x0_pred + sigma_next * noise
|
| 253 |
|
| 254 |
if latents.dtype != latents_dtype:
|
|
|
|
| 244 |
sigma = sigmas[i]
|
| 245 |
sigma_next = sigmas[i + 1]
|
| 246 |
x0_pred = (latents - sigma * noise_pred)
|
| 247 |
+
if scales and i + 1 < len(scales):
|
| 248 |
x0_pred = torch.nn.functional.interpolate(x0_pred, size=scales[i + 1], mode='bicubic')
|
| 249 |
+
noise = torch.randn(x0_pred.shape, generator=generator, device=device, dtype=x0_pred.dtype)
|
|
|
|
|
|
|
| 250 |
latents = (1 - sigma_next) * x0_pred + sigma_next * noise
|
| 251 |
|
| 252 |
if latents.dtype != latents_dtype:
|