mmrag-hf / src /preprocess /embedding.py
doggdad's picture
Upload 18 files
21cbf97 verified
from utils import encode_image
from utils import bt_embeddings
from tqdm import tqdm
from typing import List
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel
class BridgeTowerEmbeddings(BaseModel,Embeddings):
""" BridgeTower embedding model """
def embed_image_text_pairs(self, texts: List[str], images: List[str], batch_size=2) -> List[List[float]]:
"""Embed a list of image-text pairs using BridgeTower.
Parameters:
-----------
texts: str
The list of texts to embed.
images: List
The list of path-to-images to embed
batch_size: int
The batch size to process, default to 2
Returns:
--------
List of embeddings, one for each image-text pairs.
"""
# the length of texts must be equal to the length of images
assert len(texts)==len(images), "the len of captions should be equal to the len of images"
print(f"Embedding {len(texts)} image-text pairs...")
embeddings = []
for path_to_img, text in tqdm(zip(images, texts), total=len(images), desc="Processing pairs"):
embedding = bt_embeddings(text, encode_image(path_to_img))
embeddings.append(embedding)
return embeddings
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed a list of documents using BridgeTower.
Parameters:
-----------
texts: str
The list of texts to embed.
Returns:
--------
List of embeddings, one for each text.
"""
embeddings = []
for text in texts:
embedding = bt_embeddings(text, "")
embeddings.append(embedding)
return embeddings
def embed_query(self, text: str) -> List[float]:
"""Embed a query using BridgeTower.
Parameters:
-----------
texts: str
The text to embed.
Returns:
Embeddings for the text.
"""
return self.embed_documents([text])[0]