Merge pull request #29 from asutermo/user/asutermo/improvements-to-tryon-cog
Cog: Fix up a couple of issues with the hybrid tryon/tryoff support. Working on Replicate now :-)
predict.py  CHANGED  +9 -5
@@ -19,8 +19,8 @@ class Predictor(BasePredictor):
         hf_token: Secret = Input(description="Hugging Face API token. Create a write token at https://huggingface.co/settings/token. You also need to approve the Flux Dev terms."),
         image: Path = Input(description="Image file path", default="https://github.com/nftblackmagic/catvton-flux/raw/main/example/person/1.jpg"),
         mask: Path = Input(description="Mask file path", default="https://github.com/nftblackmagic/catvton-flux/blob/main/example/person/1_mask.png?raw=true"),
-        try_on: bool = Input(
-        garment: Path = Input(description="Garment file path
+        try_on: bool = Input(False, description="Try on or try off"),
+        garment: Path = Input(description="Garment file path like https://github.com/nftblackmagic/catvton-flux/raw/main/example/garment/00035_00.jpg", default=None),
         num_steps: int = Input(50, description="Number of steps to run the model for"),
         guidance_scale: float = Input(30, description="Guidance scale for the model"),
         seed: int = Input(0, description="Seed for the model"),
@@ -30,9 +30,9 @@ class Predictor(BasePredictor):
         size = (width, height)
         i = load_image(str(image)).convert("RGB").resize(size)
         m = load_image(str(mask)).convert("RGB").resize(size)
-        g = load_image(str(garment)).convert("RGB").resize(size)

         if try_on:
+            g = load_image(str(garment)).convert("RGB").resize(size)
             self.transformer = self.try_on_transformer
         else:
             self.transformer = self.try_off_transformer
@@ -57,7 +57,11 @@ class Predictor(BasePredictor):
         # Transform images using the new preprocessing
         image_tensor = transform(i)
         mask_tensor = mask_transform(m)[:1] # Take only first channel
-
+        if try_on:
+            garment_tensor = transform(g)
+        else:
+            garment_tensor = torch.zeros_like(image_tensor)
+            image_tensor = image_tensor * mask_tensor

         # Create concatenated images
         inpaint_image = torch.cat([garment_tensor, image_tensor], dim=2) # Concatenate along width
@@ -66,7 +70,7 @@ class Predictor(BasePredictor):
         if try_on:
             extended_mask = torch.cat([garment_mask, mask_tensor], dim=2)
         else:
-            extended_mask = torch.cat([1 - garment_mask,
+            extended_mask = torch.cat([1 - garment_mask, garment_mask], dim=2)

         prompt = f"The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; " \
             f"[IMAGE1] Detailed product shot of a clothing" \