# NOTE(review): the lines "Spaces: / Sleeping / Sleeping" below were artifacts
# scraped from the hosting page (a Hugging Face Space status banner), not code;
# kept here as a comment so the module parses.
# References:
#   https://www.mixedbread.ai/blog/mxbai-embed-large-v1
#   https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1
import os
import time
from typing import Dict

import numpy as np
import pandas as pd
import torch
from accelerate import Accelerator
from scipy.stats import zscore
from sentence_transformers import CrossEncoder
from sentence_transformers.util import cos_sim
from transformers import AutoModel, AutoTokenizer

# Optional Hugging Face cache locations (deliberately disabled):
# os.environ["HF_HUB_CACHE"] = "/eos/jeodpp/home/users/consose/cache/huggingface/hub"
# os.environ["HUGGINGFACE_HUB_CACHE"] = "/eos/jeodpp/home/users/consose/cache/huggingface/hub"
# os.environ["HF_HOME"] = "/eos/jeodpp/home/users/consose/cache/huggingface/hub"

# Let `accelerate` choose the execution device (GPU when available), then
# load the mxbai cross-encoder reranker onto it.
accelerator = Accelerator()
device = accelerator.device
print("Using accelerator device =", device)

model_sf_mxbai = CrossEncoder("mixedbread-ai/mxbai-rerank-large-v1", device=device)
def RAG_retrieval_Base(queryText, passages, min_threshold=0.0, max_num_passages=None):
    """Rerank ``passages`` against ``queryText`` with the mxbai cross-encoder.

    Parameters
    ----------
    queryText : str
        Query text the passages are ranked against.
    passages : list[str]
        Candidate passages to rerank.
    min_threshold : float, optional
        When > 0, keep only passages whose rerank score is >= this value.
    max_num_passages : int or None, optional
        When truthy, ask the reranker for at most this many passages;
        otherwise keep roughly the top 10% of ``passages`` (at least one).

    Returns
    -------
    pandas.DataFrame
        Columns ``corpus_id``, ``score`` and ``Passage``, deduplicated on
        ``Passage``. Empty DataFrame when nothing is retrieved or an
        error occurs (best-effort contract preserved from the original).
    """
    try:
        # Compute top_k once instead of duplicating the rank() call.
        if max_num_passages:
            top_k = max_num_passages
        else:
            # Default budget: top 10% of the passages, but never fewer than one.
            top_k = max(1, int(0.1 * len(passages)))

        result_rerank = model_sf_mxbai.rank(
            queryText, passages, return_documents=False, top_k=top_k
        )
        if not result_rerank:
            # Nothing retrieved: return explicitly instead of falling through
            # to index a missing 'corpus_id' column (which previously raised
            # a KeyError that was silently swallowed by the except below).
            return pd.DataFrame()

        df = pd.DataFrame(result_rerank)  # columns: corpus_id, score
        if min_threshold > 0:
            # .copy() so the 'Passage' assignment below does not write into a
            # filtered view (avoids pandas SettingWithCopyWarning).
            df_filtered = df[df['score'] >= min_threshold].copy()
        else:
            df_filtered = df.copy()

        # Attach the passage text and drop duplicate passages, keeping the
        # first (highest-ranked) occurrence.
        df_filtered['Passage'] = [passages[i] for i in df_filtered['corpus_id']]
        df_filtered = df_filtered.drop_duplicates(subset='Passage', keep='first')
        return df_filtered
    except Exception as e:
        # Deliberate best-effort behavior: log and return an empty frame
        # rather than propagating reranker/indexing failures to callers.
        print(f"An error occurred: {e}")
        return pd.DataFrame()
if __name__ == '__main__':
    # Smoke test: rerank four candidate passages against a single query
    # and print the resulting DataFrame.
    query = 'A man is eating a piece of bread'
    candidates = [
        "A man is eating food.",
        "A man is eating pasta.",
        "The girl is carrying a baby.",
        "A man is riding a horse.",
    ]
    df_retrieved = RAG_retrieval_Base(query, candidates, min_threshold=0, max_num_passages=3)
    print(df_retrieved)
    print("end of computations")