# MisPhil_v3 / retriever.py
# Last commit: 7a93c4f "Update retriever.py" (author: luciagomez)
from huggingface_hub import hf_hub_download, InferenceClient
from FlagEmbedding import FlagICLModel
import pandas as pd
import faiss
import numpy as np
import os
import json
# Read-only Hugging Face token for the private dataset below;
# None when the HF_TOKEN_read environment variable is not set.
HF_TOKEN_read = os.getenv("HF_TOKEN_read")

# Private dataset repository holding the parquet table, FAISS index and metadata.
DATASET_REPO = "luciagomez/MrPhil_vector"
# -------------------------------------------------------------------
# 1. Download files from Hugging Face dataset
# -------------------------------------------------------------------
def _download_dataset_file(filename: str) -> str:
    """Fetch *filename* from DATASET_REPO into /tmp/huggingface and return its local path."""
    return hf_hub_download(
        repo_id=DATASET_REPO,
        filename=filename,
        repo_type="dataset",
        token=HF_TOKEN_read,
        cache_dir="/tmp/huggingface",
    )


# Downloaded in this order: foundation table, FAISS index, metadata JSON.
parquet_path = _download_dataset_file("bgem3/foundations.parquet")
faiss_path = _download_dataset_file("bgem3/faiss.index")
meta_path = _download_dataset_file("bgem3/meta.json")
# -------------------------------------------------------------------
# 2. Load data
# -------------------------------------------------------------------
# Foundation table; retrieval treats FAISS result ids as positional row numbers into it.
df = pd.read_parquet(parquet_path, engine="pyarrow")
index = faiss.read_index(faiss_path)
# JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
with open(meta_path, "r", encoding="utf-8") as f:
    meta = json.load(f)
dim = index.d
n = index.ntotal
print(f"Loaded FAISS index with {n} vectors of dimension {dim}")
if n != len(df):
    # Ids are mapped to rows by position, so a size mismatch means hits may
    # point at the wrong foundation (or past the end of the table).
    print(
        f"WARNING: FAISS index has {n} vectors but the foundations table "
        f"has {len(df)} rows"
    )
# -------------------------------------------------------------------
# 3. Initialize BGE-ICL model for queries
# -------------------------------------------------------------------
# Few-shot demonstrations passed to the ICL model as `examples_for_task`.
# Every demonstration shares the same instruction; only query/response vary.
_EXAMPLE_INSTRUCTION = "Retrieve foundations whose mission aligns with the given perspective."
_EXAMPLE_PAIRS = (
    (
        "Protect marine life while educating children about ocean conservation",
        "Foundations working on marine conservation and youth education.",
    ),
    (
        "Promote renewable energy education and community awareness",
        "Foundations focused on clean energy advocacy and public education.",
    ),
)
examples = [
    {"instruct": _EXAMPLE_INSTRUCTION, "query": example_query, "response": example_response}
    for example_query, example_response in _EXAMPLE_PAIRS
]
# Query encoder (BGE-EN-ICL) that embeds user perspectives for the FAISS search.
# NOTE(review): the index files are downloaded from a `bgem3/` folder; if they
# were embedded with a different model than bge-en-icl, query and index
# dimensions will not match — confirm which model built the index.
_QUERY_INSTRUCTION = "Given a mission statement, retrieve foundations with aligned purposes."
model = FlagICLModel(
    "BAAI/bge-en-icl",
    use_fp16=False,  # flip to True on a GPU with enough memory
    examples_for_task=examples,
    query_instruction_for_retrieval=_QUERY_INSTRUCTION,
)
# -------------------------------
# Helper to encode queries
# -------------------------------
def encode_query(query: str) -> np.ndarray:
    """Embed one user query with the module-level BGE-EN-ICL ``model``.

    The query is sent to ``model.encode_queries`` as a one-element batch, and
    the single resulting row is returned cast to float32, the dtype FAISS
    searches with.
    """
    batch = model.encode_queries([query])
    first_row = batch[0]
    return first_row.astype("float32")
# -------------------------------------------------------------------
# 4. Retrieval function
# -------------------------------------------------------------------
def find_similar_foundations(perspective: str, top_k: int = 5):
    """
    Retrieve the foundations whose indexed embeddings best match a perspective.

    Args:
        perspective: Free-text statement (e.g. a mission) to match against.
        top_k: Maximum number of foundations to return. Values <= 0 yield [].

    Returns:
        A list of at most ``top_k`` dicts with keys "Title", "Purpose" and
        "Score", in the order FAISS returns them. Fewer entries are returned
        when the index holds fewer than ``top_k`` vectors.

    Raises:
        ValueError: if the query embedding's width differs from the FAISS
            index dimension (query model and index built with different models).
    """
    # Never ask FAISS for more neighbours than it holds; k must be positive.
    k = min(top_k, index.ntotal)
    if k <= 0:
        return []

    # FAISS expects a 2-D (n_queries, dim) matrix.
    q_emb = encode_query(perspective).reshape(1, -1)
    if q_emb.shape[1] != index.d:
        raise ValueError(
            f"Query embedding has dimension {q_emb.shape[1]}, "
            f"but the FAISS index expects {index.d}"
        )

    scores, idxs = index.search(q_emb, k)

    results = []
    for score, idx in zip(scores[0], idxs[0]):
        # FAISS pads unfilled slots with id -1; df.iloc[-1] would silently
        # return the last row, so skip them.
        if idx < 0:
            continue
        # FAISS ids are positional row numbers into the foundations table.
        row = df.iloc[idx]
        results.append(
            {"Title": row["Title"], "Purpose": row["Purpose"], "Score": float(score)}
        )
    return results