File size: 3,653 Bytes
3eeef5e
6ae16c3
755dc57
 
 
3eeef5e
 
755dc57
3eeef5e
 
74ae99d
3eeef5e
 
 
 
 
 
 
 
 
7a93c4f
3eeef5e
77b138b
06a854d
3eeef5e
 
 
 
7a93c4f
3eeef5e
77b138b
06a854d
3eeef5e
 
 
 
7a93c4f
3eeef5e
77b138b
06a854d
3eeef5e
 
 
 
 
 
4d07400
3eeef5e
 
 
 
c87ac9d
 
3eeef5e
 
 
 
 
 
 
4dad5ce
 
 
 
 
 
 
 
 
 
 
 
3eeef5e
74ae99d
 
 
 
 
 
3eeef5e
74ae99d
 
 
 
 
3eeef5e
 
 
 
74ae99d
 
 
 
 
 
 
 
3eeef5e
74ae99d
 
 
 
 
 
 
 
 
 
 
 
3eeef5e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from huggingface_hub import hf_hub_download, InferenceClient
from FlagEmbedding import FlagICLModel
import pandas as pd
import faiss
import numpy as np
import os
import json

# -------------------------------------------------------------------
# Hugging Face credentials and dataset location
# -------------------------------------------------------------------
# Read-only token used to download files from the private dataset repo.
HF_TOKEN_read = os.getenv("HF_TOKEN_read")
#HF_TOKEN_inference = os.environ.get("HF_TOKEN_inf")

# Private dataset repo holding the parquet table, FAISS index and metadata.
DATASET_REPO = "luciagomez/MrPhil_vector"

# -------------------------------------------------------------------
# 1. Download files from Hugging Face dataset
# -------------------------------------------------------------------
def _download(filename: str) -> str:
    """Fetch one file from the private dataset repo; return its local cache path.

    All three artifacts share the same repo, token and cache dir, so the
    repeated call is factored out here (the original spelled the same
    five-argument call three times).
    """
    return hf_hub_download(
        repo_id=DATASET_REPO,
        filename=filename,
        repo_type="dataset",
        token=HF_TOKEN_read,
        cache_dir="/tmp/huggingface",
    )

# Foundation records table, FAISS vector index, and index metadata.
parquet_path = _download("bgem3/foundations.parquet")
faiss_path = _download("bgem3/faiss.index")
meta_path = _download("bgem3/meta.json")


# -------------------------------------------------------------------
# 2. Load data
# -------------------------------------------------------------------
# Foundation records — presumably one row per vector in the FAISS index.
df = pd.read_parquet(parquet_path, engine="pyarrow")
# Dense vector index built from the foundation embeddings.
index = faiss.read_index(faiss_path)
# Auxiliary metadata stored alongside the index.
with open(meta_path) as f:
    meta = json.load(f)

dim, n = index.d, index.ntotal
print(f"Loaded FAISS index with {n} vectors of dimension {dim}")


# -------------------------------------------------------------------
# 3. Initialize BGE-ICL model for queries
# -------------------------------------------------------------------
# Few-shot in-context examples showing the model what a query/response
# pair looks like for this retrieval task. Both share one instruction.
_EXAMPLE_INSTRUCT = "Retrieve foundations whose mission aligns with the given perspective."

examples = [
    {
        "instruct": _EXAMPLE_INSTRUCT,
        "query": "Protect marine life while educating children about ocean conservation",
        "response": "Foundations working on marine conservation and youth education.",
    },
    {
        "instruct": _EXAMPLE_INSTRUCT,
        "query": "Promote renewable energy education and community awareness",
        "response": "Foundations focused on clean energy advocacy and public education.",
    },
]


# Embedding model for queries: BGE-EN-ICL conditions query embeddings on a
# task instruction plus the few-shot `examples` defined above, so queries
# are encoded for mission-alignment retrieval specifically.
model = FlagICLModel(
    "BAAI/bge-en-icl",
    query_instruction_for_retrieval="Given a mission statement, retrieve foundations with aligned purposes.",
    examples_for_task=examples,
    use_fp16=False, # set True if GPU with enough memory
)

# -------------------------------
# Helper to encode queries
# -------------------------------
def encode_query(query: str) -> np.ndarray:
    """Embed a single user query with BGE-EN-ICL as a float32 1-D vector."""
    embedding = model.encode_queries([query])[0]
    return embedding.astype("float32")

# -------------------------------------------------------------------
# 4. Retrieval function
# -------------------------------------------------------------------
def find_similar_foundations(perspective: str, top_k: int = 5):
    """
    Given a user perspective, retrieve the top-k foundations aligned with it.

    Parameters
    ----------
    perspective : str
        Free-text description of the user's mission or interests.
    top_k : int, optional
        Number of foundations to return (default 5).

    Returns
    -------
    list[dict]
        One dict per hit with keys "Title", "Purpose" and the FAISS
        similarity "Score", ordered best-first.
    """
    # Encode perspective; FAISS expects a 2-D (n_queries, dim) array.
    q_emb = encode_query(perspective).reshape(1, -1)

    # Search FAISS index.
    scores, idxs = index.search(q_emb, top_k)

    # Map row indices back to the foundation records.
    # BUG FIX: the original referenced an undefined `foundations` variable;
    # the dataframe loaded from the parquet file is named `df`.
    results = []
    for score, idx in zip(scores[0], idxs[0]):
        if idx < 0:
            # FAISS pads with -1 when the index holds fewer than top_k vectors.
            continue
        row = df.iloc[idx]
        results.append({
            "Title": row["Title"],
            "Purpose": row["Purpose"],
            "Score": float(score),
        })

    return results