|
|
from typing import Any, Iterable, List, Optional |
|
|
from langchain_core.embeddings import Embeddings |
|
|
import uuid |
|
|
from langchain_community.vectorstores.lancedb import LanceDB |
|
|
|
|
|
class MultimodalLanceDB(LanceDB): |
|
|
"""`LanceDB` vector store to process multimodal data |
|
|
|
|
|
Parameters: |
|
|
----------- |
|
|
connection: Any |
|
|
LanceDB connection to use. If not provided, a new connection will be created. |
|
|
embedding: Embeddings |
|
|
Embedding to use for the vectorstore. |
|
|
vector_key: str |
|
|
Key to use for the vector in the database. Defaults to ``vector``. |
|
|
id_key: str |
|
|
Key to use for the id in the database. Defaults to ``id``. |
|
|
text_key: str |
|
|
Key to use for the text in the database. Defaults to ``text``. |
|
|
image_path_key: str |
|
|
Key to use for the path to image in the database. Defaults to ``image_path``. |
|
|
table_name: str |
|
|
Name of the table to use. Defaults to ``vectorstore``. |
|
|
api_key: str |
|
|
API key to use for LanceDB cloud database. |
|
|
region: str |
|
|
Region to use for LanceDB cloud database. |
|
|
mode: str |
|
|
Mode to use for adding data to the table. Defaults to ``overwrite``. |
|
|
|
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
connection: Optional[Any] = None, |
|
|
embedding: Optional[Embeddings] = None, |
|
|
uri: Optional[str] = "/tmp/lancedb", |
|
|
vector_key: Optional[str] = "vector", |
|
|
id_key: Optional[str] = "id", |
|
|
text_key: Optional[str] = "text", |
|
|
image_path_key: Optional[str] = "image_path", |
|
|
table_name: Optional[str] = "vectorstore", |
|
|
api_key: Optional[str] = None, |
|
|
region: Optional[str] = None, |
|
|
mode: Optional[str] = "append", |
|
|
): |
|
|
super(MultimodalLanceDB, self).__init__(connection, embedding, uri, vector_key, id_key, text_key, table_name, api_key, region, mode) |
|
|
self._image_path_key = image_path_key |
|
|
|
|
|
def add_text_image_pairs( |
|
|
self, |
|
|
texts: Iterable[str], |
|
|
image_paths: Iterable[str], |
|
|
metadatas: Optional[List[dict]] = None, |
|
|
ids: Optional[List[str]] = None, |
|
|
**kwargs: Any, |
|
|
) -> List[str]: |
|
|
"""Turn text-image pairs into embedding and add it to the database |
|
|
|
|
|
Parameters: |
|
|
---------- |
|
|
texts: Iterable[str] |
|
|
Iterable of strings to combine with corresponding images to add to the vectorstore. |
|
|
images: Iterable[str] |
|
|
Iterable of path-to-images as strings to combine with corresponding texts to add to the vectorstore. |
|
|
metadatas: List[str] |
|
|
Optional list of metadatas associated with the texts. |
|
|
ids: List[str] |
|
|
Optional list of ids to associate with the texts. |
|
|
|
|
|
Returns: |
|
|
-------- |
|
|
List of ids of the added text-image pairs. |
|
|
""" |
|
|
|
|
|
assert len(texts)==len(image_paths), "the len of transcripts should be equal to the len of images" |
|
|
|
|
|
print(f'The length of texts is {len(texts)}') |
|
|
|
|
|
|
|
|
docs = [] |
|
|
ids = ids or [str(uuid.uuid4()) for _ in texts] |
|
|
embeddings = self._embedding.embed_image_text_pairs(texts=list(texts), images=list(image_paths)) |
|
|
for idx, text in enumerate(texts): |
|
|
embedding = embeddings[idx] |
|
|
metadata = metadatas[idx] if metadatas else {"id": ids[idx]} |
|
|
docs.append( |
|
|
{ |
|
|
self._vector_key: embedding, |
|
|
self._id_key: ids[idx], |
|
|
self._text_key: text, |
|
|
self._image_path_key : image_paths[idx], |
|
|
"metadata": metadata, |
|
|
} |
|
|
) |
|
|
print(f'Adding {len(docs)} text-image pairs to the vectorstore...') |
|
|
|
|
|
if 'mode' in kwargs: |
|
|
mode = kwargs['mode'] |
|
|
else: |
|
|
mode = self.mode |
|
|
if self._table_name in self._connection.table_names(): |
|
|
tbl = self._connection.open_table(self._table_name) |
|
|
if self.api_key is None: |
|
|
tbl.add(docs) |
|
|
else: |
|
|
tbl.add(docs) |
|
|
else: |
|
|
self._connection.create_table(self._table_name, data=docs) |
|
|
return ids |
|
|
|
|
|
@classmethod |
|
|
def from_text_image_pairs( |
|
|
cls, |
|
|
texts: List[str], |
|
|
image_paths: List[str], |
|
|
embedding: Embeddings, |
|
|
metadatas: Optional[List[dict]] = None, |
|
|
connection: Any = None, |
|
|
vector_key: Optional[str] = "vector", |
|
|
id_key: Optional[str] = "id", |
|
|
text_key: Optional[str] = "text", |
|
|
image_path_key: Optional[str] = "image_path", |
|
|
table_name: Optional[str] = "vectorstore", |
|
|
**kwargs: Any, |
|
|
): |
|
|
|
|
|
instance = MultimodalLanceDB( |
|
|
connection=connection, |
|
|
embedding=embedding, |
|
|
vector_key=vector_key, |
|
|
id_key=id_key, |
|
|
text_key=text_key, |
|
|
image_path_key=image_path_key, |
|
|
table_name=table_name, |
|
|
) |
|
|
instance.add_text_image_pairs(texts, image_paths, metadatas=metadatas, **kwargs) |
|
|
|
|
|
return instance |