Optional texts in processor.
Browse files- processor.py +10 -9
processor.py
CHANGED
|
@@ -38,7 +38,7 @@ class CondViTProcessor(ImageProcessingMixin):
|
|
| 38 |
img = F.normalize(img, self.image_mean, self.image_std)
|
| 39 |
return img
|
| 40 |
|
| 41 |
-
def __call__(self, images, texts):
|
| 42 |
"""
|
| 43 |
Parameters
|
| 44 |
----------
|
|
@@ -55,14 +55,15 @@ class CondViTProcessor(ImageProcessingMixin):
|
|
| 55 |
texts : Union[str, List[str]]
|
| 56 |
"""
|
| 57 |
# Single Image
|
|
|
|
| 58 |
if isinstance(images, Image.Image):
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
| 61 |
)
|
| 62 |
|
| 63 |
-
|
| 64 |
-
data=
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
}
|
| 68 |
-
)
|
|
|
|
| 38 |
img = F.normalize(img, self.image_mean, self.image_std)
|
| 39 |
return img
|
| 40 |
|
| 41 |
+
def __call__(self, images, texts=None):
|
| 42 |
"""
|
| 43 |
Parameters
|
| 44 |
----------
|
|
|
|
| 55 |
texts : Union[str, List[str]]
|
| 56 |
"""
|
| 57 |
# Single Image
|
| 58 |
+
data = {}
|
| 59 |
if isinstance(images, Image.Image):
|
| 60 |
+
data["pixel_values"] = self.process_img(images)
|
| 61 |
+
else:
|
| 62 |
+
data["pixel_values"] = torch.stack(
|
| 63 |
+
[self.process_img(img) for img in images]
|
| 64 |
)
|
| 65 |
|
| 66 |
+
if texts is not None:
|
| 67 |
+
data["texts"] = texts
|
| 68 |
+
|
| 69 |
+
return BatchFeature(data=data)
|
|
|
|
|
|