luciagomez commited on
Commit
74ae99d
·
verified ·
1 Parent(s): 91b3ba9

Update retriever.py

Browse files
Files changed (1) hide show
  1. retriever.py +32 -22
retriever.py CHANGED
@@ -8,14 +8,11 @@ import json
8
 
9
  # Define HF tokens
10
  HF_TOKEN_read = os.environ.get("HF_TOKEN_read")
11
- HF_TOKEN_inference = os.environ.get("HF_TOKEN_inf")
12
 
13
  # Make sure cache is redirected to /tmp
14
  os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache"
15
 
16
- # Setup InferenceClient for embeddings
17
- client = InferenceClient(provider="nebius", api_key=HF_TOKEN_inference)
18
-
19
  # Dataset repo (private)
20
  DATASET_REPO = "luciagomez/MrPhil_vector"
21
 
@@ -78,29 +75,42 @@ examples = [
78
  ]
79
 
80
 
81
- #model = FlagICLModel(
82
- # "BAAI/bge-en-icl",
83
- # query_instruction_for_retrieval="Given a mission statement, retrieve foundations with aligned purposes.",
84
- # examples_for_task=examples,
85
- # use_fp16=False
86
- #)
87
 
88
- def encode_query(perspective: str):
89
- response = client.feature_extraction(
90
- perspective,model="BAAI/bge-en-icl",
91
- )
92
- return np.array(response)
93
 
94
  # -------------------------------------------------------------------
95
  # 4. Retrieval function
96
  # -------------------------------------------------------------------
97
- def find_similar_foundations(perspective, top_k=5):
98
- q_emb = encode_query(perspective).astype("float32")
99
- faiss.normalize_L2(q_emb)
 
 
 
 
 
100
  scores, idxs = index.search(q_emb, top_k)
101
- return [
102
- {"title": df.iloc[i]["Title"], "purpose": df.iloc[i]["Purpose"], "similarity": float(scores[0][j])}
103
- for j, i in enumerate(idxs[0])
104
- ]
 
 
 
 
 
 
 
 
105
 
106
 
 
8
 
9
  # Define HF tokens
10
  HF_TOKEN_read = os.environ.get("HF_TOKEN_read")
11
+ #HF_TOKEN_inference = os.environ.get("HF_TOKEN_inf")
12
 
13
  # Make sure cache is redirected to /tmp
14
  os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache"
15
 
 
 
 
16
  # Dataset repo (private)
17
  DATASET_REPO = "luciagomez/MrPhil_vector"
18
 
 
75
  ]
76
 
77
 
78
+ model = FlagICLModel(
79
+ "BAAI/bge-en-icl",
80
+ query_instruction_for_retrieval="Given a mission statement, retrieve foundations with aligned purposes.",
81
+ examples_for_task=examples,
82
+ use_fp16=False, # set True if GPU with enough memory
83
+ )
84
 
85
+ # -------------------------------
86
+ # Helper to encode queries
87
+ # -------------------------------
88
+ def encode_query(query: str) -> np.ndarray:
89
+ return model.encode_queries([query])[0].astype("float32") # Encode a user query using BGE-EN-ICL.
90
 
91
  # -------------------------------------------------------------------
92
  # 4. Retrieval function
93
  # -------------------------------------------------------------------
94
+ def find_similar_foundations(perspective: str, top_k: int = 5):
95
+ """
96
+ Given a user perspective, retrieve top-k foundations aligned with it.
97
+ """
98
+ # Encode perspective
99
+ q_emb = encode_query(perspective).reshape(1, -1) # FAISS expects 2D
100
+
101
+ # Search FAISS index
102
  scores, idxs = index.search(q_emb, top_k)
103
+
104
+ # Retrieve foundation info
105
+ results = []
106
+ for score, idx in zip(scores[0], idxs[0]):
107
+ foundation_info = {
108
+ "Title": foundations.iloc[idx]["Title"],
109
+ "Purpose": foundations.iloc[idx]["Purpose"],
110
+ "Score": float(score)
111
+ }
112
+ results.append(foundation_info)
113
+
114
+ return results
115
 
116