Spaces:
Sleeping
Sleeping
Consoli Sergio
commited on
Commit
·
66b8c66
1
Parent(s):
eff93c6
corrected bug on double repetition on history
Browse files- app-demo-myMultiNER.py +103 -43
- nerBio.py +263 -118
- retrieverRAG_SF.py +114 -0
- virtuosoQueryRest.py +4 -0
app-demo-myMultiNER.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
| 1 |
import os
|
| 2 |
|
| 3 |
-
#
|
| 4 |
-
|
| 5 |
-
#
|
| 6 |
-
#
|
| 7 |
-
#
|
| 8 |
|
| 9 |
from transformers import file_utils
|
| 10 |
print(file_utils.default_cache_path)
|
|
@@ -76,8 +76,8 @@ examples = [
|
|
| 76 |
|
| 77 |
|
| 78 |
|
| 79 |
-
|
| 80 |
-
models_List = ["Babelscape/wikineural-multilingual-ner", "urchade/gliner_large-v2.1", "NCBO/BioPortal" ] # "urchade/gliner_large-v2.1", "knowledgator/gliner-multitask-large-v0.5"
|
| 81 |
#models_List = ["NCBO/BioPortal" ]
|
| 82 |
|
| 83 |
#categories_List = ["MED","LOC","PER","ORG","DATE","MISC"]
|
|
@@ -189,7 +189,12 @@ def nerBio(text, ModelsSelection, CategoriesSelection, ScoreFilt, EntityLinking,
|
|
| 189 |
state = {
|
| 190 |
"text": "",
|
| 191 |
"df_annotated_dict": dict(),
|
| 192 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
}
|
| 194 |
return {"text": text, "entities": []}, html_output, state, [], ""
|
| 195 |
|
|
@@ -224,7 +229,7 @@ def nerBio(text, ModelsSelection, CategoriesSelection, ScoreFilt, EntityLinking,
|
|
| 224 |
help="List of ontologies to which restrict the entity linking task.")
|
| 225 |
#consose 20250502:
|
| 226 |
if Counter(KGchoices) == Counter(POSSIBLE_KGchoices_List):
|
| 227 |
-
parser.add_argument("--USE_CACHE", type=str, default="
|
| 228 |
help="whether to use cache for the NER and NEL tasks or not")
|
| 229 |
else:
|
| 230 |
#print("Lists do not have the same elements")
|
|
@@ -237,6 +242,8 @@ def nerBio(text, ModelsSelection, CategoriesSelection, ScoreFilt, EntityLinking,
|
|
| 237 |
help="whether to extract a readable context from the extracted triples for the concept")
|
| 238 |
parser.add_argument("--computeEntityGlobalContext", type=str, default="False",
|
| 239 |
help="whether to extract a readable context from the extracted triples of all the entities extracted from the endpoint for the concept")
|
|
|
|
|
|
|
| 240 |
parser.add_argument("--UseRetrieverForContextCreation", type=str, default="True",
|
| 241 |
help="whether to use a retriever for the creation of the context of the entities from the triples coming from the KGs")
|
| 242 |
|
|
@@ -257,7 +264,39 @@ def nerBio(text, ModelsSelection, CategoriesSelection, ScoreFilt, EntityLinking,
|
|
| 257 |
if state:
|
| 258 |
previous_text = state.get("text", "")
|
| 259 |
previous_df_annotated_dict = state.get("df_annotated_dict", {})
|
|
|
|
| 260 |
previous_kg_choices = state.get("KGchoices", [])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
|
| 262 |
#print("Are all models in any row of the 'model' column, case-insensitively?", all_models_in_any_row)
|
| 263 |
#if (not history_dict) or (history_dict[args.source_column][0] != text) or (all_models_in_any_row == False):
|
|
@@ -319,7 +358,12 @@ def nerBio(text, ModelsSelection, CategoriesSelection, ScoreFilt, EntityLinking,
|
|
| 319 |
state = {
|
| 320 |
"text": text,
|
| 321 |
"df_annotated_dict": df_annotated.to_dict(),
|
| 322 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
}
|
| 324 |
|
| 325 |
else:
|
|
@@ -341,7 +385,12 @@ def nerBio(text, ModelsSelection, CategoriesSelection, ScoreFilt, EntityLinking,
|
|
| 341 |
state = {
|
| 342 |
"text": text,
|
| 343 |
"df_annotated_dict": df_annotated.to_dict(),
|
| 344 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
}
|
| 346 |
|
| 347 |
|
|
@@ -353,6 +402,7 @@ def nerBio(text, ModelsSelection, CategoriesSelection, ScoreFilt, EntityLinking,
|
|
| 353 |
df_annotated = df_annotated[df_annotated['model'].str.lower().isin([model.lower() for model in ModelsSelection])]
|
| 354 |
if df_annotated.empty and quoted_text==False:
|
| 355 |
html_output = f"<div class='gr-textbox' style='white-space: pre-wrap; overflow-wrap: break-word; padding: 10px; border: 1px solid #ddd; border-radius: 5px; font-family: monospace; font-size: 12px; line-height: 24px;'>{text}</div>"
|
|
|
|
| 356 |
return {"text": text, "entities": []}, html_output, state, [], ""
|
| 357 |
|
| 358 |
df_annotated_combined = pd.DataFrame()
|
|
@@ -360,6 +410,7 @@ def nerBio(text, ModelsSelection, CategoriesSelection, ScoreFilt, EntityLinking,
|
|
| 360 |
df_annotated_combined = entitiesFusion(df_annotated,args)
|
| 361 |
if df_annotated_combined.empty and quoted_text==False:
|
| 362 |
html_output = f"<div class='gr-textbox' style='white-space: pre-wrap; overflow-wrap: break-word; padding: 10px; border: 1px solid #ddd; border-radius: 5px; font-family: monospace; font-size: 12px; line-height: 24px;'>{text}</div>"
|
|
|
|
| 363 |
return {"text": text, "entities": []}, html_output, state, [], ""
|
| 364 |
else:
|
| 365 |
if (not df_annotated.empty):
|
|
@@ -530,6 +581,7 @@ def nerBio(text, ModelsSelection, CategoriesSelection, ScoreFilt, EntityLinking,
|
|
| 530 |
df_annotated_combined = df_annotated_combined[filter_mask]
|
| 531 |
if df_annotated_combined.empty:
|
| 532 |
html_output = f"<div class='gr-textbox' style='white-space: pre-wrap; overflow-wrap: break-word; padding: 10px; border: 1px solid #ddd; border-radius: 5px; font-family: monospace; font-size: 12px; line-height: 24px;'>{text}</div>"
|
|
|
|
| 533 |
return {"text": text, "entities": []}, html_output, state, [], ""
|
| 534 |
|
| 535 |
###
|
|
@@ -540,6 +592,7 @@ def nerBio(text, ModelsSelection, CategoriesSelection, ScoreFilt, EntityLinking,
|
|
| 540 |
df_annotated_combined = df_annotated_combined[df_annotated_combined['IsCrossInside'] != 1]
|
| 541 |
if df_annotated_combined.empty:
|
| 542 |
html_output = f"<div class='gr-textbox' style='white-space: pre-wrap; overflow-wrap: break-word; padding: 10px; border: 1px solid #ddd; border-radius: 5px; font-family: monospace; font-size: 12px; line-height: 24px;'>{text}</div>"
|
|
|
|
| 543 |
return {"text": text, "entities": []}, html_output, state, [], ""
|
| 544 |
|
| 545 |
dict_annotated_combined_NER = df_annotated_combined[["end", "entity_group", "score", "start", "word"]].to_dict(orient="records")
|
|
@@ -550,15 +603,15 @@ def nerBio(text, ModelsSelection, CategoriesSelection, ScoreFilt, EntityLinking,
|
|
| 550 |
|
| 551 |
# # Create a new column for the entities with links
|
| 552 |
df_annotated_combined['entity_with_link'] = df_annotated_combined.apply(
|
| 553 |
-
#lambda row: (
|
| 554 |
-
#
|
| 555 |
-
#
|
| 556 |
-
#
|
| 557 |
-
#),
|
| 558 |
lambda row: (
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
),
|
| 563 |
axis=1
|
| 564 |
)
|
|
@@ -641,17 +694,20 @@ def nerBio(text, ModelsSelection, CategoriesSelection, ScoreFilt, EntityLinking,
|
|
| 641 |
words_for_dropdown = []
|
| 642 |
|
| 643 |
html_output = f"<div class='gr-textbox' style='white-space: pre-wrap; overflow-wrap: break-word; padding: 10px; border: 1px solid #ddd; border-radius: 5px; font-family: monospace; font-size: 12px; line-height: 24px;'>{text_with_links}</div>"
|
|
|
|
| 644 |
|
| 645 |
#return {"text": text, "entities": dict_annotated_combined_NER}, html_output, state
|
| 646 |
return {"text": text, "entities": dict_annotated_combined_NER}, html_output, state, gr.update(choices=words_for_dropdown), ""
|
| 647 |
|
| 648 |
else:
|
| 649 |
html_output = f"<div class='gr-textbox' style='white-space: pre-wrap; overflow-wrap: break-word; padding: 10px; border: 1px solid #ddd; border-radius: 5px; font-family: monospace; font-size: 12px; line-height: 24px;'>{text}</div>"
|
|
|
|
| 650 |
return {"text": text, "entities": dict_annotated_combined_NER}, html_output, state, [], ""
|
| 651 |
|
| 652 |
else:
|
| 653 |
|
| 654 |
html_output = f"<div class='gr-textbox' style='white-space: pre-wrap; overflow-wrap: break-word; padding: 10px; border: 1px solid #ddd; border-radius: 5px; font-family: monospace; font-size: 12px; line-height: 24px;'>{text}</div>"
|
|
|
|
| 655 |
return {"text": text, "entities": []}, html_output, state, [], ""
|
| 656 |
|
| 657 |
|
|
@@ -663,28 +719,32 @@ def update_urls(selected_word, state):
|
|
| 663 |
# Convert the state dictionary back into a DataFrame
|
| 664 |
df = pd.DataFrame(state["df_annotated_combined_dict"])
|
| 665 |
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 688 |
|
| 689 |
else:
|
| 690 |
return ""
|
|
@@ -768,7 +828,7 @@ with gr.Blocks(title="BioAnnotator") as demo:
|
|
| 768 |
text_input = gr.Textbox(label="Input text", placeholder="Enter text here...")
|
| 769 |
models_selection = gr.CheckboxGroup(models_List, label="ModelsSelection", value=models_List)
|
| 770 |
categories_selection = gr.CheckboxGroup(categories_List, label="CategoriesSelection", value=categories_List)
|
| 771 |
-
score_slider = gr.Slider(minimum=0, maximum=1.0, step=0.
|
| 772 |
nel_checkbox = gr.Checkbox(label="Enable Named-Entity Linking (NEL)", value=False)
|
| 773 |
kgchoices_selection = gr.Dropdown(POSSIBLE_KGchoices_List, multiselect=True, label="KGchoices Selection", value=POSSIBLE_KGchoices_List)
|
| 774 |
state = gr.State(value={})
|
|
@@ -824,4 +884,4 @@ with gr.Blocks(title="BioAnnotator") as demo:
|
|
| 824 |
|
| 825 |
|
| 826 |
demo.launch()
|
| 827 |
-
#demo.launch(share=True) # Share your demo with just 1 extra parameter
|
|
|
|
| 1 |
import os
|
| 2 |
|
| 3 |
+
#os.environ["CUDA_VISIBLE_DEVICES"] = "1,6" # to use the GPUs 3,4 only
|
| 4 |
+
|
| 5 |
+
#os.environ["HF_HUB_CACHE"] = "/eos/jeodpp/home/users/consose/cache/huggingface/hub"
|
| 6 |
+
#os.environ["HUGGINGFACE_HUB_CACHE"] = "/eos/jeodpp/home/users/consose/cache/huggingface/hub"
|
| 7 |
+
#os.environ["HF_HOME"] = "/eos/jeodpp/home/users/consose/cache/huggingface/hub"
|
| 8 |
|
| 9 |
from transformers import file_utils
|
| 10 |
print(file_utils.default_cache_path)
|
|
|
|
| 76 |
|
| 77 |
|
| 78 |
|
| 79 |
+
models_List = ["FacebookAI/xlm-roberta-large-finetuned-conll03-english", "Babelscape/wikineural-multilingual-ner", "blaze999/Medical-NER", "urchade/gliner_large-v2.1", "urchade/gliner_large_bio-v0.1", "NCBO/BioPortal" ] # "urchade/gliner_large-v2.1", "knowledgator/gliner-multitask-large-v0.5"
|
| 80 |
+
#models_List = ["Babelscape/wikineural-multilingual-ner", "urchade/gliner_large-v2.1", "NCBO/BioPortal" ] # "urchade/gliner_large-v2.1", "knowledgator/gliner-multitask-large-v0.5"
|
| 81 |
#models_List = ["NCBO/BioPortal" ]
|
| 82 |
|
| 83 |
#categories_List = ["MED","LOC","PER","ORG","DATE","MISC"]
|
|
|
|
| 189 |
state = {
|
| 190 |
"text": "",
|
| 191 |
"df_annotated_dict": dict(),
|
| 192 |
+
"df_annotated_combined_dict": dict(),
|
| 193 |
+
"KGchoices": KGchoices,
|
| 194 |
+
"ModelsSelection": ModelsSelection,
|
| 195 |
+
"ScoreFilt": ScoreFilt,
|
| 196 |
+
"EntityLinking": EntityLinking,
|
| 197 |
+
"html_output": html_output
|
| 198 |
}
|
| 199 |
return {"text": text, "entities": []}, html_output, state, [], ""
|
| 200 |
|
|
|
|
| 229 |
help="List of ontologies to which restrict the entity linking task.")
|
| 230 |
#consose 20250502:
|
| 231 |
if Counter(KGchoices) == Counter(POSSIBLE_KGchoices_List):
|
| 232 |
+
parser.add_argument("--USE_CACHE", type=str, default="True",
|
| 233 |
help="whether to use cache for the NER and NEL tasks or not")
|
| 234 |
else:
|
| 235 |
#print("Lists do not have the same elements")
|
|
|
|
| 242 |
help="whether to extract a readable context from the extracted triples for the concept")
|
| 243 |
parser.add_argument("--computeEntityGlobalContext", type=str, default="False",
|
| 244 |
help="whether to extract a readable context from the extracted triples of all the entities extracted from the endpoint for the concept")
|
| 245 |
+
parser.add_argument("--maxTriplesGlobalContext", type=int, default=20000,
|
| 246 |
+
help="maximum number of triples to consider for global context computation") # if 0 or None it is not considered
|
| 247 |
parser.add_argument("--UseRetrieverForContextCreation", type=str, default="True",
|
| 248 |
help="whether to use a retriever for the creation of the context of the entities from the triples coming from the KGs")
|
| 249 |
|
|
|
|
| 264 |
if state:
|
| 265 |
previous_text = state.get("text", "")
|
| 266 |
previous_df_annotated_dict = state.get("df_annotated_dict", {})
|
| 267 |
+
previous_df_annotated_combined_dict = state.get("df_annotated_combined_dict", {})
|
| 268 |
previous_kg_choices = state.get("KGchoices", [])
|
| 269 |
+
previous_ModelsSelection = state.get("ModelsSelection", [])
|
| 270 |
+
previous_ScoreFilt_from_state = float(state.get("ScoreFilt", ScoreFilt)) # Ensure ScoreFilt is a float
|
| 271 |
+
previous_EntityLinking_from_state = bool(state.get("EntityLinking", EntityLinking)) # Ensure EntityLinking is a boolean
|
| 272 |
+
previous_html_output = state.get("html_output", "")
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
if previous_html_output and (previous_df_annotated_dict) and (previous_df_annotated_combined_dict) and (previous_text == text) and (sorted(previous_kg_choices) == sorted(KGchoices)) and (sorted(previous_ModelsSelection) == sorted(ModelsSelection)) and (previous_ScoreFilt_from_state == ScoreFilt) and (previous_EntityLinking_from_state == EntityLinking):
|
| 276 |
+
ddf_annot_prev = pd.DataFrame(previous_df_annotated_combined_dict)
|
| 277 |
+
if 'ALLURIScontext' in ddf_annot_prev.columns:
|
| 278 |
+
# words_for_dropdown = df_annotated_combined[
|
| 279 |
+
# df_annotated_combined['ALLURIScontext'].apply(lambda x: x is not None and x != [])][
|
| 280 |
+
# 'word'].unique().tolist()
|
| 281 |
+
words_for_dropdown = ddf_annot_prev[ddf_annot_prev['ALLURIScontext'].apply(
|
| 282 |
+
lambda x: x is not None and x != [] and (isinstance(x, list) and len(x) > 0) and (
|
| 283 |
+
isinstance(x, list) and (not (len(x) == 1 and not str(x[0]).strip()))))][
|
| 284 |
+
'word'].unique().tolist()
|
| 285 |
+
words_for_dropdown = list({entry.lower(): entry for entry in words_for_dropdown}.values())
|
| 286 |
+
words_for_dropdown.insert(0, "")
|
| 287 |
+
else:
|
| 288 |
+
words_for_dropdown = []
|
| 289 |
+
|
| 290 |
+
dict_annotated_combined_NER = ddf_annot_prev[
|
| 291 |
+
["end", "entity_group", "score", "start", "word"]].to_dict(orient="records")
|
| 292 |
+
|
| 293 |
+
# return {"text": text, "entities": dict_annotated_combined_NER}, html_output, state
|
| 294 |
+
return {"text": text, "entities": dict_annotated_combined_NER}, previous_html_output, state, gr.update(
|
| 295 |
+
choices=words_for_dropdown), ""
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
|
| 300 |
|
| 301 |
#print("Are all models in any row of the 'model' column, case-insensitively?", all_models_in_any_row)
|
| 302 |
#if (not history_dict) or (history_dict[args.source_column][0] != text) or (all_models_in_any_row == False):
|
|
|
|
| 358 |
state = {
|
| 359 |
"text": text,
|
| 360 |
"df_annotated_dict": df_annotated.to_dict(),
|
| 361 |
+
"df_annotated_combined_dict": dict(),
|
| 362 |
+
"KGchoices": KGchoices,
|
| 363 |
+
"ModelsSelection": ModelsSelection,
|
| 364 |
+
"ScoreFilt": ScoreFilt,
|
| 365 |
+
"EntityLinking": EntityLinking,
|
| 366 |
+
"html_output": ""
|
| 367 |
}
|
| 368 |
|
| 369 |
else:
|
|
|
|
| 385 |
state = {
|
| 386 |
"text": text,
|
| 387 |
"df_annotated_dict": df_annotated.to_dict(),
|
| 388 |
+
"df_annotated_combined_dict": dict(),
|
| 389 |
+
"KGchoices": KGchoices,
|
| 390 |
+
"ModelsSelection": ModelsSelection,
|
| 391 |
+
"ScoreFilt": ScoreFilt,
|
| 392 |
+
"EntityLinking": EntityLinking,
|
| 393 |
+
"html_output": ""
|
| 394 |
}
|
| 395 |
|
| 396 |
|
|
|
|
| 402 |
df_annotated = df_annotated[df_annotated['model'].str.lower().isin([model.lower() for model in ModelsSelection])]
|
| 403 |
if df_annotated.empty and quoted_text==False:
|
| 404 |
html_output = f"<div class='gr-textbox' style='white-space: pre-wrap; overflow-wrap: break-word; padding: 10px; border: 1px solid #ddd; border-radius: 5px; font-family: monospace; font-size: 12px; line-height: 24px;'>{text}</div>"
|
| 405 |
+
state["html_output"] = html_output
|
| 406 |
return {"text": text, "entities": []}, html_output, state, [], ""
|
| 407 |
|
| 408 |
df_annotated_combined = pd.DataFrame()
|
|
|
|
| 410 |
df_annotated_combined = entitiesFusion(df_annotated,args)
|
| 411 |
if df_annotated_combined.empty and quoted_text==False:
|
| 412 |
html_output = f"<div class='gr-textbox' style='white-space: pre-wrap; overflow-wrap: break-word; padding: 10px; border: 1px solid #ddd; border-radius: 5px; font-family: monospace; font-size: 12px; line-height: 24px;'>{text}</div>"
|
| 413 |
+
state["html_output"] = html_output
|
| 414 |
return {"text": text, "entities": []}, html_output, state, [], ""
|
| 415 |
else:
|
| 416 |
if (not df_annotated.empty):
|
|
|
|
| 581 |
df_annotated_combined = df_annotated_combined[filter_mask]
|
| 582 |
if df_annotated_combined.empty:
|
| 583 |
html_output = f"<div class='gr-textbox' style='white-space: pre-wrap; overflow-wrap: break-word; padding: 10px; border: 1px solid #ddd; border-radius: 5px; font-family: monospace; font-size: 12px; line-height: 24px;'>{text}</div>"
|
| 584 |
+
state["html_output"] = html_output
|
| 585 |
return {"text": text, "entities": []}, html_output, state, [], ""
|
| 586 |
|
| 587 |
###
|
|
|
|
| 592 |
df_annotated_combined = df_annotated_combined[df_annotated_combined['IsCrossInside'] != 1]
|
| 593 |
if df_annotated_combined.empty:
|
| 594 |
html_output = f"<div class='gr-textbox' style='white-space: pre-wrap; overflow-wrap: break-word; padding: 10px; border: 1px solid #ddd; border-radius: 5px; font-family: monospace; font-size: 12px; line-height: 24px;'>{text}</div>"
|
| 595 |
+
state["html_output"] = html_output
|
| 596 |
return {"text": text, "entities": []}, html_output, state, [], ""
|
| 597 |
|
| 598 |
dict_annotated_combined_NER = df_annotated_combined[["end", "entity_group", "score", "start", "word"]].to_dict(orient="records")
|
|
|
|
| 603 |
|
| 604 |
# # Create a new column for the entities with links
|
| 605 |
df_annotated_combined['entity_with_link'] = df_annotated_combined.apply(
|
| 606 |
+
# lambda row: (
|
| 607 |
+
# f"<a href='https://expl-rels-dev-vast.apps.ocpt.jrc.ec.europa.eu/?concept={row['namedEntity']}' target='_blank'>{row['word']}</a>"
|
| 608 |
+
# if row['namedEntity'] not in [None, '', 'NaN', 'nan'] and pd.notnull(row['namedEntity']) else row[
|
| 609 |
+
# 'word']
|
| 610 |
+
# ),
|
| 611 |
lambda row: (
|
| 612 |
+
f"<a href='https://api-vast.jrc.service.ec.europa.eu/describe//?url={row['namedEntity']}' target='_blank'>{row['word']}</a>"
|
| 613 |
+
if row['namedEntity'] not in [None, '', 'NaN', 'nan'] and pd.notnull(row['namedEntity']) else row[
|
| 614 |
+
'word']
|
| 615 |
),
|
| 616 |
axis=1
|
| 617 |
)
|
|
|
|
| 694 |
words_for_dropdown = []
|
| 695 |
|
| 696 |
html_output = f"<div class='gr-textbox' style='white-space: pre-wrap; overflow-wrap: break-word; padding: 10px; border: 1px solid #ddd; border-radius: 5px; font-family: monospace; font-size: 12px; line-height: 24px;'>{text_with_links}</div>"
|
| 697 |
+
state["html_output"] = html_output
|
| 698 |
|
| 699 |
#return {"text": text, "entities": dict_annotated_combined_NER}, html_output, state
|
| 700 |
return {"text": text, "entities": dict_annotated_combined_NER}, html_output, state, gr.update(choices=words_for_dropdown), ""
|
| 701 |
|
| 702 |
else:
|
| 703 |
html_output = f"<div class='gr-textbox' style='white-space: pre-wrap; overflow-wrap: break-word; padding: 10px; border: 1px solid #ddd; border-radius: 5px; font-family: monospace; font-size: 12px; line-height: 24px;'>{text}</div>"
|
| 704 |
+
state["html_output"] = html_output
|
| 705 |
return {"text": text, "entities": dict_annotated_combined_NER}, html_output, state, [], ""
|
| 706 |
|
| 707 |
else:
|
| 708 |
|
| 709 |
html_output = f"<div class='gr-textbox' style='white-space: pre-wrap; overflow-wrap: break-word; padding: 10px; border: 1px solid #ddd; border-radius: 5px; font-family: monospace; font-size: 12px; line-height: 24px;'>{text}</div>"
|
| 710 |
+
state["html_output"] = html_output
|
| 711 |
return {"text": text, "entities": []}, html_output, state, [], ""
|
| 712 |
|
| 713 |
|
|
|
|
| 719 |
# Convert the state dictionary back into a DataFrame
|
| 720 |
df = pd.DataFrame(state["df_annotated_combined_dict"])
|
| 721 |
|
| 722 |
+
if 'ALLURIScontext' in df.columns:
|
| 723 |
+
# # Filter the DataFrame to get rows where 'ALLURIScontextFromNCBO' is not empty or None
|
| 724 |
+
# valid_entries = df[df['ALLURIScontext'].apply(lambda x: x is not None and x != [])]
|
| 725 |
+
# # Filter the DataFrame to get rows where 'ALLURIScontext' is not None, not an empty list, and not an empty string
|
| 726 |
+
valid_entries = df[df['ALLURIScontext'].apply(lambda x: x is not None and x != [] and (isinstance(x, list) and len(x) > 0) and (isinstance(x, list) and (not (len(x) == 1 and not str(x[0]).strip())) ))]
|
| 727 |
+
|
| 728 |
+
# Check if the selected word is in the filtered DataFrame
|
| 729 |
+
if selected_word in valid_entries['word'].values:
|
| 730 |
+
urls = valid_entries.loc[valid_entries['word'] == selected_word, 'ALLURIScontext'].values[0]
|
| 731 |
+
if 'namedEntity' in df.columns:
|
| 732 |
+
firsturlinlist = df.loc[df['word'] == selected_word, 'namedEntity']
|
| 733 |
+
firsturlinlist = firsturlinlist.iloc[0] if not firsturlinlist.empty else None
|
| 734 |
+
if firsturlinlist and firsturlinlist in urls:
|
| 735 |
+
# Remove the URL from its current position
|
| 736 |
+
urls.remove(firsturlinlist)
|
| 737 |
+
# Insert the URL at the first position
|
| 738 |
+
urls.insert(0, firsturlinlist)
|
| 739 |
+
|
| 740 |
+
# Convert list of URLs to HTML string with clickable links
|
| 741 |
+
#html_links = "<br>".join([f'<a href="https://expl-rels-dev-vast.apps.ocpt.jrc.ec.europa.eu/?concept={url}" target="_blank">{url}</a>' for url in urls])
|
| 742 |
+
html_links = "<br>".join([f'<a href="https://api-vast.jrc.service.ec.europa.eu/describe//?url={url}" target="_blank">{url}</a>' for url in urls])
|
| 743 |
+
return html_links
|
| 744 |
+
return ""
|
| 745 |
+
else:
|
| 746 |
+
return""
|
| 747 |
+
|
| 748 |
|
| 749 |
else:
|
| 750 |
return ""
|
|
|
|
| 828 |
text_input = gr.Textbox(label="Input text", placeholder="Enter text here...")
|
| 829 |
models_selection = gr.CheckboxGroup(models_List, label="ModelsSelection", value=models_List)
|
| 830 |
categories_selection = gr.CheckboxGroup(categories_List, label="CategoriesSelection", value=categories_List)
|
| 831 |
+
score_slider = gr.Slider(minimum=0, maximum=1.0, step=0.05, label="Score", value=0.75)
|
| 832 |
nel_checkbox = gr.Checkbox(label="Enable Named-Entity Linking (NEL)", value=False)
|
| 833 |
kgchoices_selection = gr.Dropdown(POSSIBLE_KGchoices_List, multiselect=True, label="KGchoices Selection", value=POSSIBLE_KGchoices_List)
|
| 834 |
state = gr.State(value={})
|
|
|
|
| 884 |
|
| 885 |
|
| 886 |
demo.launch()
|
| 887 |
+
#demo.launch(share=True) # Share your demo with just 1 extra parameter
|
nerBio.py
CHANGED
|
@@ -65,7 +65,8 @@ import json
|
|
| 65 |
import random
|
| 66 |
import numpy as np
|
| 67 |
|
| 68 |
-
from retrieverRAG_testing import RAG_retrieval_Base, RAG_retrieval_Z_scores, RAG_retrieval_Percentile, RAG_retrieval_TopK
|
|
|
|
| 69 |
|
| 70 |
from joblib import Memory
|
| 71 |
|
|
@@ -957,135 +958,265 @@ def getLinearTextualContextFromTriples(word,labelTriplesLIST, text_splitter, arg
|
|
| 957 |
word = word.lower()
|
| 958 |
word = word.capitalize()
|
| 959 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 960 |
|
| 961 |
-
if (strtobool(args.UseRetrieverForContextCreation)==True):
|
| 962 |
labelTriples = ""
|
| 963 |
-
passages = []
|
| 964 |
-
nn = 200
|
| 965 |
|
| 966 |
-
|
|
|
|
| 967 |
passages = []
|
| 968 |
for i, triple in enumerate(labelTriplesLIST, start=1):
|
| 969 |
# for triple in labelTriplesLIST:
|
| 970 |
TriplesString = (" ".join(str(element).capitalize() for element in triple))
|
| 971 |
passages.append(TriplesString)
|
| 972 |
|
| 973 |
-
|
| 974 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 975 |
|
| 976 |
-
|
| 977 |
-
#labelTriplesLIST_RAGGED = df_retrieved.to_records(index=False).tolist()
|
| 978 |
-
labelTriplesLIST_RAGGED = df_retrieved['Passage'].apply(lambda x: (x,)).tolist()
|
| 979 |
-
labelTriplesAPP = ". ".join(
|
| 980 |
-
" ".join(str(element).capitalize() for element in triple) for triple in labelTriplesLIST_RAGGED)
|
| 981 |
|
| 982 |
-
|
| 983 |
-
labelTriples = labelTriplesAPP
|
| 984 |
-
else:
|
| 985 |
-
labelTriples = labelTriples + ". " + labelTriplesAPP
|
| 986 |
|
| 987 |
-
|
|
|
|
|
|
|
| 988 |
|
| 989 |
-
|
| 990 |
-
|
| 991 |
-
|
| 992 |
-
Oinnerlistiterative=[]
|
| 993 |
-
for i, triple in enumerate(OverallListRAGtriples, start=1):
|
| 994 |
-
# for triple in labelTriplesLIST:
|
| 995 |
-
TriplesString = (" ".join(str(element).capitalize() for element in triple))
|
| 996 |
-
passages.append(TriplesString)
|
| 997 |
-
# Check if the current index is a multiple of nn
|
| 998 |
-
if i % nn == 0:
|
| 999 |
-
# print("elaborate RAG triples")
|
| 1000 |
-
|
| 1001 |
-
# df_retrieved_Base = RAG_retrieval_Base(questionText, passages, min_threshold=0.7, max_num_passages=20)
|
| 1002 |
-
# df_retrievedZscore = RAG_retrieval_Z_scores(questionText, passages, z_threshold=1.0, max_num_passages=20, min_threshold=0.7)
|
| 1003 |
-
# df_retrievedPercentile = RAG_retrieval_Percentile(questionText, passages, percentile=90, max_num_passages=20, min_threshold=0.7)
|
| 1004 |
-
df_retrievedtopk = RAG_retrieval_TopK(questionText, passages, top_fraction=0.1, max_num_passages=20,
|
| 1005 |
-
min_threshold=0.7)
|
| 1006 |
-
|
| 1007 |
-
passages = []
|
| 1008 |
-
|
| 1009 |
-
df_retrieved = df_retrievedtopk.copy()
|
| 1010 |
-
if not df_retrieved.empty:
|
| 1011 |
-
#labelTriplesLIST_RAGGED = df_retrieved.to_records(index=False).tolist()
|
| 1012 |
-
labelTriplesLIST_RAGGED = df_retrieved['Passage'].apply(lambda x: (x,)).tolist()
|
| 1013 |
-
if not Oinnerlistiterative:
|
| 1014 |
-
Oinnerlistiterative=labelTriplesLIST_RAGGED
|
| 1015 |
-
else:
|
| 1016 |
-
Oinnerlistiterative.extend(labelTriplesLIST_RAGGED)
|
| 1017 |
-
|
| 1018 |
-
if passages:
|
| 1019 |
-
df_retrievedtopk = RAG_retrieval_TopK(questionText, passages, top_fraction=0.1, max_num_passages=20,
|
| 1020 |
-
min_threshold=0.7)
|
| 1021 |
-
|
| 1022 |
-
df_retrieved = df_retrievedtopk.copy()
|
| 1023 |
-
if not df_retrieved.empty:
|
| 1024 |
-
#labelTriplesLIST_RAGGED = df_retrieved.to_records(index=False).tolist()
|
| 1025 |
-
labelTriplesLIST_RAGGED = df_retrieved['Passage'].apply(lambda x: (x,)).tolist()
|
| 1026 |
-
if not Oinnerlistiterative:
|
| 1027 |
-
Oinnerlistiterative = labelTriplesLIST_RAGGED
|
| 1028 |
-
else:
|
| 1029 |
-
Oinnerlistiterative.extend(labelTriplesLIST_RAGGED)
|
| 1030 |
-
|
| 1031 |
-
OverallListRAGtriples = Oinnerlistiterative.copy()
|
| 1032 |
-
|
| 1033 |
-
if OverallListRAGtriples:
|
| 1034 |
-
labelTriplesAPP = ". ".join(" ".join(str(element).capitalize() for element in triple) for triple in OverallListRAGtriples)
|
| 1035 |
-
|
| 1036 |
-
if not labelTriples:
|
| 1037 |
-
labelTriples = labelTriplesAPP
|
| 1038 |
else:
|
| 1039 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1040 |
|
| 1041 |
-
|
|
|
|
|
|
|
| 1042 |
|
| 1043 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1044 |
|
| 1045 |
-
|
| 1046 |
-
|
| 1047 |
-
|
| 1048 |
-
|
| 1049 |
-
# #for triple in labelTriplesLIST:
|
| 1050 |
-
# TriplesString = (" ".join(str(element).capitalize() for element in triple))
|
| 1051 |
-
# passages.append(TriplesString)
|
| 1052 |
-
# # Check if the current index is a multiple of nn
|
| 1053 |
-
# if i % nn == 0:
|
| 1054 |
-
# #print("elaborate RAG triples")
|
| 1055 |
-
#
|
| 1056 |
-
# #df_retrieved_Base = RAG_retrieval_Base(questionText, passages, min_threshold=0.7, max_num_passages=20)
|
| 1057 |
-
# #df_retrievedZscore = RAG_retrieval_Z_scores(questionText, passages, z_threshold=1.0, max_num_passages=20, min_threshold=0.7)
|
| 1058 |
-
# #df_retrievedPercentile = RAG_retrieval_Percentile(questionText, passages, percentile=90, max_num_passages=20, min_threshold=0.7)
|
| 1059 |
-
# df_retrievedtopk = RAG_retrieval_TopK(questionText, passages, top_fraction=0.1, max_num_passages=20, min_threshold=0.7)
|
| 1060 |
-
#
|
| 1061 |
-
# passages = []
|
| 1062 |
-
#
|
| 1063 |
-
# df_retrieved = df_retrievedtopk.copy()
|
| 1064 |
-
# if not df_retrieved.empty:
|
| 1065 |
-
# #labelTriplesLIST_RAGGED = df_retrieved.to_records(index=False).tolist()
|
| 1066 |
-
# labelTriplesLIST_RAGGED = df_retrieved['Passage'].apply(lambda x: (x,)).tolist()
|
| 1067 |
-
# labelTriplesAPP = ". ".join(" ".join(str(element).capitalize() for element in triple) for triple in labelTriplesLIST_RAGGED)
|
| 1068 |
-
#
|
| 1069 |
-
# if not labelTriples:
|
| 1070 |
-
# labelTriples =labelTriplesAPP
|
| 1071 |
-
# else:
|
| 1072 |
-
# labelTriples = labelTriples + ". " + labelTriplesAPP
|
| 1073 |
-
#
|
| 1074 |
-
# if passages:
|
| 1075 |
-
# df_retrievedtopk = RAG_retrieval_TopK(questionText, passages, top_fraction=0.1, max_num_passages=20, min_threshold=0.7)
|
| 1076 |
-
#
|
| 1077 |
-
# df_retrieved = df_retrievedtopk.copy()
|
| 1078 |
-
# if not df_retrieved.empty:
|
| 1079 |
-
# #labelTriplesLIST_RAGGED = df_retrieved.to_records(index=False).tolist()
|
| 1080 |
-
# labelTriplesLIST_RAGGED = df_retrieved['Passage'].apply(lambda x: (x,)).tolist()
|
| 1081 |
-
# labelTriplesAPP = ". ".join(" ".join(str(element).capitalize() for element in triple) for triple in labelTriplesLIST_RAGGED)
|
| 1082 |
-
# if not labelTriples:
|
| 1083 |
-
# labelTriples = labelTriplesAPP
|
| 1084 |
-
# else:
|
| 1085 |
-
# labelTriples = labelTriples + ". " + labelTriplesAPP
|
| 1086 |
-
#
|
| 1087 |
-
# if labelTriples:
|
| 1088 |
-
# labelTriples.strip().replace("..",".").strip()
|
| 1089 |
|
| 1090 |
|
| 1091 |
else: # NO RAG on triples
|
|
@@ -1571,7 +1702,7 @@ def virtuoso_api_call(word, text_splitter, args, key_virtuoso, cache_map_virtuos
|
|
| 1571 |
|
| 1572 |
if entityBioeUrl:
|
| 1573 |
|
| 1574 |
-
if strtobool(args.computeEntityContext):
|
| 1575 |
|
| 1576 |
if strtobool(args.debug):
|
| 1577 |
print("START computeEntityContext")
|
|
@@ -1706,6 +1837,8 @@ def virtuoso_api_call(word, text_splitter, args, key_virtuoso, cache_map_virtuos
|
|
| 1706 |
|
| 1707 |
|
| 1708 |
if not globalContext:
|
|
|
|
|
|
|
| 1709 |
if unique_listGlobalTriples:
|
| 1710 |
globalContext, load_map_query_input_output = getLinearTextualContextFromTriples(word, unique_listGlobalTriples,
|
| 1711 |
text_splitter, args,
|
|
@@ -1750,6 +1883,7 @@ def virtuoso_api_call(word, text_splitter, args, key_virtuoso, cache_map_virtuos
|
|
| 1750 |
return None, None, None, None, None, None, cache_map_virtuoso, load_map_query_input_output
|
| 1751 |
|
| 1752 |
|
|
|
|
| 1753 |
if not ALLURIScontext:
|
| 1754 |
# Print the error message to stderr
|
| 1755 |
print("THIS CASE SHOULD NEVER HAPPEN NOW!!!! Check what's happening...exiting now...")
|
|
@@ -1879,9 +2013,16 @@ def virtuoso_api_call(word, text_splitter, args, key_virtuoso, cache_map_virtuos
|
|
| 1879 |
if unique_listLabelTriples:
|
| 1880 |
unique_listGlobalTriples.extend(unique_listLabelTriples)
|
| 1881 |
|
| 1882 |
-
|
| 1883 |
-
#
|
| 1884 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1885 |
|
| 1886 |
|
| 1887 |
if unique_listGlobalTriples:
|
|
@@ -1892,7 +2033,8 @@ def virtuoso_api_call(word, text_splitter, args, key_virtuoso, cache_map_virtuos
|
|
| 1892 |
if cache_map_virtuoso is not None:
|
| 1893 |
if not word in cache_map_virtuoso:
|
| 1894 |
cache_map_virtuoso[word] = {}
|
| 1895 |
-
|
|
|
|
| 1896 |
|
| 1897 |
globalContext, load_map_query_input_output = getLinearTextualContextFromTriples(word,
|
| 1898 |
unique_listGlobalTriples,
|
|
@@ -1903,7 +2045,8 @@ def virtuoso_api_call(word, text_splitter, args, key_virtuoso, cache_map_virtuos
|
|
| 1903 |
if cache_map_virtuoso is not None:
|
| 1904 |
if not word in cache_map_virtuoso:
|
| 1905 |
cache_map_virtuoso[word] = {}
|
| 1906 |
-
|
|
|
|
| 1907 |
|
| 1908 |
if unique_listLabelTriples:
|
| 1909 |
sssingleTriples = " ,., ".join(
|
|
@@ -2291,6 +2434,8 @@ if __name__ == '__main__':
|
|
| 2291 |
|
| 2292 |
parser.add_argument("--computeEntityContext", type=str, default="False", help="whether to extract a readable context from the extracted triples for the concept")
|
| 2293 |
parser.add_argument("--computeEntityGlobalContext", type=str, default="False", help="whether to extract a readable context from the extracted triples of all the entities extracted from the endpoint for the concept")
|
|
|
|
|
|
|
| 2294 |
parser.add_argument("--UseRetrieverForContextCreation", type=str, default="True",
|
| 2295 |
help="whether to use a retriever for the creation of the context of the entities from the triples coming from the KGs")
|
| 2296 |
|
|
|
|
| 65 |
import random
|
| 66 |
import numpy as np
|
| 67 |
|
| 68 |
+
#from retrieverRAG_testing import RAG_retrieval_Base, RAG_retrieval_Z_scores, RAG_retrieval_Percentile, RAG_retrieval_TopK, retrievePassageSimilarities
|
| 69 |
+
from retrieverRAG_SF import RAG_retrieval_Base
|
| 70 |
|
| 71 |
from joblib import Memory
|
| 72 |
|
|
|
|
| 958 |
word = word.lower()
|
| 959 |
word = word.capitalize()
|
| 960 |
|
| 961 |
+
labelTriples=""
|
| 962 |
+
|
| 963 |
+
if labelTriplesLIST and getattr(args, 'maxTriplesContextComputation', None): # it means it exists
|
| 964 |
+
if args.maxTriplesContextComputation > 0:
|
| 965 |
+
if len(labelTriplesLIST) > args.maxTriplesContextComputation:
|
| 966 |
+
labelTriplesLIST = labelTriplesLIST[:args.maxTriplesContextComputation]
|
| 967 |
+
|
| 968 |
+
if (strtobool(args.UseRetrieverForContextCreation) == True):
|
| 969 |
+
|
| 970 |
+
# if strtobool(args.debug):
|
| 971 |
+
# print("Start reranking - num passages : ", len(labelTriplesLIST), "\n")
|
| 972 |
+
# startRerank = time.time()
|
| 973 |
+
#
|
| 974 |
+
# labelTriples = ""
|
| 975 |
+
# passages = []
|
| 976 |
+
# nn = 200
|
| 977 |
+
#
|
| 978 |
+
# OverallListRAGtriples = []
|
| 979 |
+
# labelTriplesLIST_RAGGED = []
|
| 980 |
+
#
|
| 981 |
+
# if len(labelTriplesLIST) <= nn:
|
| 982 |
+
# passages = []
|
| 983 |
+
# for i, triple in enumerate(labelTriplesLIST, start=1):
|
| 984 |
+
# # for triple in labelTriplesLIST:
|
| 985 |
+
# TriplesString = (" ".join(str(element).capitalize() for element in triple))
|
| 986 |
+
# passages.append(TriplesString)
|
| 987 |
+
#
|
| 988 |
+
# df_retrieved = RAG_retrieval_TopK(questionText, passages, top_fraction=0.1, max_num_passages=20,
|
| 989 |
+
# min_threshold=0.7)
|
| 990 |
+
#
|
| 991 |
+
# if not df_retrieved.empty:
|
| 992 |
+
# # labelTriplesLIST_RAGGED = df_retrieved.to_records(index=False).tolist()
|
| 993 |
+
# labelTriplesLIST_RAGGED = df_retrieved['Passage'].apply(lambda x: (x,)).tolist()
|
| 994 |
+
# labelTriplesAPP = ". ".join(
|
| 995 |
+
# " ".join(str(element).capitalize() for element in triple) for triple in labelTriplesLIST_RAGGED)
|
| 996 |
+
#
|
| 997 |
+
# if not labelTriples:
|
| 998 |
+
# labelTriples = labelTriplesAPP
|
| 999 |
+
# else:
|
| 1000 |
+
# labelTriples = labelTriples + ". " + labelTriplesAPP
|
| 1001 |
+
#
|
| 1002 |
+
# else:
|
| 1003 |
+
#
|
| 1004 |
+
# OverallListRAGtriples = labelTriplesLIST.copy()
|
| 1005 |
+
#
|
| 1006 |
+
# while len(OverallListRAGtriples) > nn:
|
| 1007 |
+
# Oinnerlistiterative = []
|
| 1008 |
+
# for i, triple in enumerate(OverallListRAGtriples, start=1):
|
| 1009 |
+
# # for triple in labelTriplesLIST:
|
| 1010 |
+
# TriplesString = (" ".join(str(element).capitalize() for element in triple))
|
| 1011 |
+
# passages.append(TriplesString)
|
| 1012 |
+
# # Check if the current index is a multiple of nn
|
| 1013 |
+
# if i % nn == 0:
|
| 1014 |
+
# # print("elaborate RAG triples")
|
| 1015 |
+
#
|
| 1016 |
+
# # df_retrieved_Base = RAG_retrieval_Base(questionText, passages, min_threshold=0.7, max_num_passages=20)
|
| 1017 |
+
# # df_retrievedZscore = RAG_retrieval_Z_scores(questionText, passages, z_threshold=1.0, max_num_passages=20, min_threshold=0.7)
|
| 1018 |
+
# # df_retrievedPercentile = RAG_retrieval_Percentile(questionText, passages, percentile=90, max_num_passages=20, min_threshold=0.7)
|
| 1019 |
+
# df_retrievedtopk = RAG_retrieval_TopK(questionText, passages, top_fraction=0.1,
|
| 1020 |
+
# max_num_passages=20,
|
| 1021 |
+
# min_threshold=0.7)
|
| 1022 |
+
#
|
| 1023 |
+
# passages = []
|
| 1024 |
+
#
|
| 1025 |
+
# df_retrieved = df_retrievedtopk.copy()
|
| 1026 |
+
# if not df_retrieved.empty:
|
| 1027 |
+
# # labelTriplesLIST_RAGGED = df_retrieved.to_records(index=False).tolist()
|
| 1028 |
+
# labelTriplesLIST_RAGGED = df_retrieved['Passage'].apply(lambda x: (x,)).tolist()
|
| 1029 |
+
# if not Oinnerlistiterative:
|
| 1030 |
+
# Oinnerlistiterative = labelTriplesLIST_RAGGED
|
| 1031 |
+
# else:
|
| 1032 |
+
# Oinnerlistiterative.extend(labelTriplesLIST_RAGGED)
|
| 1033 |
+
#
|
| 1034 |
+
# if passages:
|
| 1035 |
+
# df_retrievedtopk = RAG_retrieval_TopK(questionText, passages, top_fraction=0.1, max_num_passages=20,
|
| 1036 |
+
# min_threshold=0.7)
|
| 1037 |
+
#
|
| 1038 |
+
# df_retrieved = df_retrievedtopk.copy()
|
| 1039 |
+
# if not df_retrieved.empty:
|
| 1040 |
+
# # labelTriplesLIST_RAGGED = df_retrieved.to_records(index=False).tolist()
|
| 1041 |
+
# labelTriplesLIST_RAGGED = df_retrieved['Passage'].apply(lambda x: (x,)).tolist()
|
| 1042 |
+
# if not Oinnerlistiterative:
|
| 1043 |
+
# Oinnerlistiterative = labelTriplesLIST_RAGGED
|
| 1044 |
+
# else:
|
| 1045 |
+
# Oinnerlistiterative.extend(labelTriplesLIST_RAGGED)
|
| 1046 |
+
#
|
| 1047 |
+
# OverallListRAGtriples = Oinnerlistiterative.copy()
|
| 1048 |
+
#
|
| 1049 |
+
# if OverallListRAGtriples:
|
| 1050 |
+
# labelTriplesAPP = ". ".join(
|
| 1051 |
+
# " ".join(str(element).capitalize() for element in triple) for triple in OverallListRAGtriples)
|
| 1052 |
+
#
|
| 1053 |
+
# if not labelTriples:
|
| 1054 |
+
# labelTriples = labelTriplesAPP
|
| 1055 |
+
# else:
|
| 1056 |
+
# labelTriples = labelTriples + ". " + labelTriplesAPP
|
| 1057 |
+
#
|
| 1058 |
+
# labelTriples = labelTriples.strip().replace("..", ".").strip()
|
| 1059 |
+
#
|
| 1060 |
+
# # labelTriples = ""
|
| 1061 |
+
# # passages = []
|
| 1062 |
+
# # nn=200
|
| 1063 |
+
# # for i, triple in enumerate(labelTriplesLIST, start=1):
|
| 1064 |
+
# # #for triple in labelTriplesLIST:
|
| 1065 |
+
# # TriplesString = (" ".join(str(element).capitalize() for element in triple))
|
| 1066 |
+
# # passages.append(TriplesString)
|
| 1067 |
+
# # # Check if the current index is a multiple of nn
|
| 1068 |
+
# # if i % nn == 0:
|
| 1069 |
+
# # #print("elaborate RAG triples")
|
| 1070 |
+
# #
|
| 1071 |
+
# # #df_retrieved_Base = RAG_retrieval_Base(questionText, passages, min_threshold=0.7, max_num_passages=20)
|
| 1072 |
+
# # #df_retrievedZscore = RAG_retrieval_Z_scores(questionText, passages, z_threshold=1.0, max_num_passages=20, min_threshold=0.7)
|
| 1073 |
+
# # #df_retrievedPercentile = RAG_retrieval_Percentile(questionText, passages, percentile=90, max_num_passages=20, min_threshold=0.7)
|
| 1074 |
+
# # df_retrievedtopk = RAG_retrieval_TopK(questionText, passages, top_fraction=0.1, max_num_passages=20, min_threshold=0.7)
|
| 1075 |
+
# #
|
| 1076 |
+
# # passages = []
|
| 1077 |
+
# #
|
| 1078 |
+
# # df_retrieved = df_retrievedtopk.copy()
|
| 1079 |
+
# # if not df_retrieved.empty:
|
| 1080 |
+
# # #labelTriplesLIST_RAGGED = df_retrieved.to_records(index=False).tolist()
|
| 1081 |
+
# # labelTriplesLIST_RAGGED = df_retrieved['Passage'].apply(lambda x: (x,)).tolist()
|
| 1082 |
+
# # labelTriplesAPP = ". ".join(" ".join(str(element).capitalize() for element in triple) for triple in labelTriplesLIST_RAGGED)
|
| 1083 |
+
# #
|
| 1084 |
+
# # if not labelTriples:
|
| 1085 |
+
# # labelTriples =labelTriplesAPP
|
| 1086 |
+
# # else:
|
| 1087 |
+
# # labelTriples = labelTriples + ". " + labelTriplesAPP
|
| 1088 |
+
# #
|
| 1089 |
+
# # if passages:
|
| 1090 |
+
# # df_retrievedtopk = RAG_retrieval_TopK(questionText, passages, top_fraction=0.1, max_num_passages=20, min_threshold=0.7)
|
| 1091 |
+
# #
|
| 1092 |
+
# # df_retrieved = df_retrievedtopk.copy()
|
| 1093 |
+
# # if not df_retrieved.empty:
|
| 1094 |
+
# # #labelTriplesLIST_RAGGED = df_retrieved.to_records(index=False).tolist()
|
| 1095 |
+
# # labelTriplesLIST_RAGGED = df_retrieved['Passage'].apply(lambda x: (x,)).tolist()
|
| 1096 |
+
# # labelTriplesAPP = ". ".join(" ".join(str(element).capitalize() for element in triple) for triple in labelTriplesLIST_RAGGED)
|
| 1097 |
+
# # if not labelTriples:
|
| 1098 |
+
# # labelTriples = labelTriplesAPP
|
| 1099 |
+
# # else:
|
| 1100 |
+
# # labelTriples = labelTriples + ". " + labelTriplesAPP
|
| 1101 |
+
# #
|
| 1102 |
+
# # if labelTriples:
|
| 1103 |
+
# # labelTriples.strip().replace("..",".").strip()
|
| 1104 |
+
#
|
| 1105 |
+
# if strtobool(args.debug):
|
| 1106 |
+
# numfinal = 0
|
| 1107 |
+
# if OverallListRAGtriples:
|
| 1108 |
+
# numfinal = len(OverallListRAGtriples)
|
| 1109 |
+
# elif labelTriplesLIST_RAGGED:
|
| 1110 |
+
# numfinal = len(labelTriplesLIST_RAGGED)
|
| 1111 |
+
# print("End reranking - found final passages : ", numfinal, "\n")
|
| 1112 |
+
# #
|
| 1113 |
+
# endRerank = time.time()
|
| 1114 |
+
# hours, rem = divmod(endRerank - startRerank, 3600)
|
| 1115 |
+
# minutes, seconds = divmod(rem, 60)
|
| 1116 |
+
# print("Rerank Time... {:0>2}:{:0>2}:{:05.2f}\n".format(int(hours), int(minutes), seconds))
|
| 1117 |
+
# #
|
| 1118 |
+
|
| 1119 |
+
# if len(labelTriplesLIST) > 10000:
|
| 1120 |
+
# print("debug")
|
| 1121 |
+
|
| 1122 |
+
if strtobool(args.debug):
|
| 1123 |
+
print("Start reranking2 - num passages : ", len(labelTriplesLIST), "\n")
|
| 1124 |
+
startRerank2 = time.time()
|
| 1125 |
|
|
|
|
| 1126 |
labelTriples = ""
|
|
|
|
|
|
|
| 1127 |
|
| 1128 |
+
try:
|
| 1129 |
+
|
| 1130 |
passages = []
|
| 1131 |
for i, triple in enumerate(labelTriplesLIST, start=1):
|
| 1132 |
# for triple in labelTriplesLIST:
|
| 1133 |
TriplesString = (" ".join(str(element).capitalize() for element in triple))
|
| 1134 |
passages.append(TriplesString)
|
| 1135 |
|
| 1136 |
+
nback = 1
|
| 1137 |
+
if len(passages) <= 10:
|
| 1138 |
+
nback = len(passages)
|
| 1139 |
+
elif len(passages) <= 1000:
|
| 1140 |
+
nback = 10+int(0.1 * len(passages)) # 10% of the number of passages
|
| 1141 |
+
elif len(passages) <= 5000:
|
| 1142 |
+
nback = 200
|
| 1143 |
+
elif len(passages) <= 10000:
|
| 1144 |
+
nback = 300
|
| 1145 |
+
else:
|
| 1146 |
+
nback = 400
|
| 1147 |
|
| 1148 |
+
df_retrieved = RAG_retrieval_Base(questionText, passages, min_threshold=0, max_num_passages=nback)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1149 |
|
| 1150 |
+
if not df_retrieved.empty:
|
|
|
|
|
|
|
|
|
|
| 1151 |
|
| 1152 |
+
countRetr = 0
|
| 1153 |
+
min_threshold = 0.80
|
| 1154 |
+
countRetr = (df_retrieved['score'] > min_threshold).sum()
|
| 1155 |
|
| 1156 |
+
countRetrThreshold = int(nback / 2)
|
| 1157 |
+
if nback > 10:
|
| 1158 |
+
countRetrThreshold = 10
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1159 |
else:
|
| 1160 |
+
countRetrThreshold = int(nback/2)
|
| 1161 |
+
if countRetrThreshold <=0:
|
| 1162 |
+
countRetrThreshold = 1
|
| 1163 |
+
|
| 1164 |
+
while countRetr <= countRetrThreshold:
|
| 1165 |
+
min_threshold = min_threshold - 0.05
|
| 1166 |
+
countRetr = (df_retrieved['score'] >= min_threshold).sum()
|
| 1167 |
+
if min_threshold < 0.2:
|
| 1168 |
+
break
|
| 1169 |
+
|
| 1170 |
+
# countRetrThreshold = int(0.1 + nback)
|
| 1171 |
+
# if countRetrThreshold > 5:
|
| 1172 |
+
# countRetrThreshold = 5
|
| 1173 |
+
#
|
| 1174 |
+
# countRetr=0
|
| 1175 |
+
# min_threshold = 0.90
|
| 1176 |
+
# countRetr = (df_retrieved['score'] > min_threshold).sum()
|
| 1177 |
+
# while countRetr<=countRetrThreshold:
|
| 1178 |
+
# min_threshold = min_threshold - 0.05
|
| 1179 |
+
# if min_threshold<0.7:
|
| 1180 |
+
# countRetrThreshold=0
|
| 1181 |
+
# if min_threshold == 0:
|
| 1182 |
+
# min_threshold = 0.01
|
| 1183 |
+
# countRetr = (df_retrieved['score'] > min_threshold).sum()
|
| 1184 |
+
# if min_threshold <= 0.01:
|
| 1185 |
+
# break
|
| 1186 |
+
|
| 1187 |
+
if countRetr > 0:
|
| 1188 |
+
df_retrieved = df_retrieved[df_retrieved['score'] > min_threshold]
|
| 1189 |
+
|
| 1190 |
+
# labelTriplesLIST_RAGGED = df_retrieved.to_records(index=False).tolist()
|
| 1191 |
+
labelTriplesLIST_RAGGED = df_retrieved['Passage'].apply(lambda x: (x,)).tolist()
|
| 1192 |
+
labelTriplesAPP = ". ".join(
|
| 1193 |
+
" ".join(str(element).capitalize() for element in triple) for triple in labelTriplesLIST_RAGGED)
|
| 1194 |
+
|
| 1195 |
+
if not labelTriples:
|
| 1196 |
+
labelTriples = labelTriplesAPP
|
| 1197 |
+
else:
|
| 1198 |
+
labelTriples = labelTriples + ". " + labelTriplesAPP
|
| 1199 |
|
| 1200 |
+
else:
|
| 1201 |
+
labelTriplesLIST_RAGGED = []
|
| 1202 |
+
labelTriples = ""
|
| 1203 |
|
| 1204 |
|
| 1205 |
+
if strtobool(args.debug):
|
| 1206 |
+
numfinal = 0
|
| 1207 |
+
if labelTriplesLIST_RAGGED:
|
| 1208 |
+
numfinal = len(labelTriplesLIST_RAGGED)
|
| 1209 |
+
print("End reranking2 - found final passages : ", numfinal, "\n")
|
| 1210 |
+
endRerank2 = time.time()
|
| 1211 |
+
hours, rem = divmod(endRerank2 - startRerank2, 3600)
|
| 1212 |
+
minutes, seconds = divmod(rem, 60)
|
| 1213 |
+
print("Rerank2 Time... {:0>2}:{:0>2}:{:05.2f}\n".format(int(hours), int(minutes), seconds))
|
| 1214 |
+
#
|
| 1215 |
|
| 1216 |
+
except Exception as err:
|
| 1217 |
+
print("SOMETHING HAPPENED on PASSAGE RERANKING for Question :"+questionText+"\n")
|
| 1218 |
+
print(err)
|
| 1219 |
+
#status_code: 422, body: type='validation_error' url='https://www.mixedbread.ai/api-reference' message='Your request is invalid. Please check your input and try again.' details=[[{'type': 'too_long', 'loc': ['body', 'input', 'list[str]'], 'msg': 'List should have at most 1000 items after validation, not 4249',
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1220 |
|
| 1221 |
|
| 1222 |
else: # NO RAG on triples
|
|
|
|
| 1702 |
|
| 1703 |
if entityBioeUrl:
|
| 1704 |
|
| 1705 |
+
if strtobool(args.computeEntityContext) and (strtobool(args.computeEntityGlobalContext)==False):
|
| 1706 |
|
| 1707 |
if strtobool(args.debug):
|
| 1708 |
print("START computeEntityContext")
|
|
|
|
| 1837 |
|
| 1838 |
|
| 1839 |
if not globalContext:
|
| 1840 |
+
|
| 1841 |
+
BreakenBeforeAll = False
|
| 1842 |
if unique_listGlobalTriples:
|
| 1843 |
globalContext, load_map_query_input_output = getLinearTextualContextFromTriples(word, unique_listGlobalTriples,
|
| 1844 |
text_splitter, args,
|
|
|
|
| 1883 |
return None, None, None, None, None, None, cache_map_virtuoso, load_map_query_input_output
|
| 1884 |
|
| 1885 |
|
| 1886 |
+
|
| 1887 |
if not ALLURIScontext:
|
| 1888 |
# Print the error message to stderr
|
| 1889 |
print("THIS CASE SHOULD NEVER HAPPEN NOW!!!! Check what's happening...exiting now...")
|
|
|
|
| 2013 |
if unique_listLabelTriples:
|
| 2014 |
unique_listGlobalTriples.extend(unique_listLabelTriples)
|
| 2015 |
|
| 2016 |
+
|
| 2017 |
+
# This is to speed things up: the loop over the global triples is broken here, but in that case the triples for the remaining URIs are not stored in the cache, which may be useful in the future
|
| 2018 |
+
# #if token_counter(str(unique_listGlobalTriples),args.model_name) > args.tokens_max:
|
| 2019 |
+
|
| 2020 |
+
if getattr(args, 'maxTriplesContextComputation', None): #it means it exists
|
| 2021 |
+
if args.maxTriplesContextComputation > 0:
|
| 2022 |
+
if len(unique_listGlobalTriples) > args.maxTriplesContextComputation:
|
| 2023 |
+
unique_listGlobalTriples = unique_listGlobalTriples[:args.maxTriplesContextComputation]
|
| 2024 |
+
BreakenBeforeAll = True
|
| 2025 |
+
break # BREAK THE FOR LOOP IF THE GLOBAL CONTEXT IS ALREADY TOO BIG, BIGGER THAN tokens_max
|
| 2026 |
|
| 2027 |
|
| 2028 |
if unique_listGlobalTriples:
|
|
|
|
| 2033 |
if cache_map_virtuoso is not None:
|
| 2034 |
if not word in cache_map_virtuoso:
|
| 2035 |
cache_map_virtuoso[word] = {}
|
| 2036 |
+
if BreakenBeforeAll == False:
|
| 2037 |
+
cache_map_virtuoso[word][("GlobalTriples"+" "+contextWordVirtuoso).strip()] = unique_listGlobalTriples
|
| 2038 |
|
| 2039 |
globalContext, load_map_query_input_output = getLinearTextualContextFromTriples(word,
|
| 2040 |
unique_listGlobalTriples,
|
|
|
|
| 2045 |
if cache_map_virtuoso is not None:
|
| 2046 |
if not word in cache_map_virtuoso:
|
| 2047 |
cache_map_virtuoso[word] = {}
|
| 2048 |
+
if BreakenBeforeAll == False:
|
| 2049 |
+
cache_map_virtuoso[word][("GlobalContext"+" "+contextWordVirtuoso).strip()] = globalContext
|
| 2050 |
|
| 2051 |
if unique_listLabelTriples:
|
| 2052 |
sssingleTriples = " ,., ".join(
|
|
|
|
| 2434 |
|
| 2435 |
parser.add_argument("--computeEntityContext", type=str, default="False", help="whether to extract a readable context from the extracted triples for the concept")
|
| 2436 |
parser.add_argument("--computeEntityGlobalContext", type=str, default="False", help="whether to extract a readable context from the extracted triples of all the entities extracted from the endpoint for the concept")
|
| 2437 |
+
parser.add_argument("--maxTriplesContextComputation", type=int, default=20000,
|
| 2438 |
+
help="maximum number of triples to consider for global context computation") # if 0 or None it is not considered
|
| 2439 |
parser.add_argument("--UseRetrieverForContextCreation", type=str, default="True",
|
| 2440 |
help="whether to use a retriever for the creation of the context of the entities from the triples coming from the KGs")
|
| 2441 |
|
retrieverRAG_SF.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# https://www.mixedbread.ai/blog/mxbai-embed-large-v1
|
| 3 |
+
# https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import numpy as np
|
| 9 |
+
from typing import Dict
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from transformers import AutoModel, AutoTokenizer
|
| 13 |
+
from sentence_transformers.util import cos_sim
|
| 14 |
+
from accelerate import Accelerator # Import from accelerate
|
| 15 |
+
from scipy.stats import zscore
|
| 16 |
+
|
| 17 |
+
# Set up environment variables for Hugging Face caching
|
| 18 |
+
os.environ["HF_HUB_CACHE"] = "/eos/jeodpp/home/users/consose/cache/huggingface/hub"
|
| 19 |
+
os.environ["HUGGINGFACE_HUB_CACHE"] = "/eos/jeodpp/home/users/consose/cache/huggingface/hub"
|
| 20 |
+
os.environ["HF_HOME"] = "/eos/jeodpp/home/users/consose/cache/huggingface/hub"
|
| 21 |
+
|
| 22 |
+
# Initialize the Accelerator
|
| 23 |
+
accelerator = Accelerator()
|
| 24 |
+
|
| 25 |
+
# Use the device managed by Accelerator
|
| 26 |
+
device = accelerator.device
|
| 27 |
+
print("Using accelerator device =", device)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
from sentence_transformers import CrossEncoder
|
| 31 |
+
model_sf_mxbai = CrossEncoder("mixedbread-ai/mxbai-rerank-large-v1" ,device=device)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def RAG_retrieval_Base(queryText, passages, min_threshold=0.0, max_num_passages=None):
    """Rerank ``passages`` against ``queryText`` and return the best ones.

    The module-level cross-encoder ``model_sf_mxbai`` scores every passage
    with respect to the query; the top-ranked results are optionally filtered
    by score and de-duplicated by passage text.

    Args:
        queryText: The query the passages are ranked against.
        passages: Sequence of passage strings to rank.
        min_threshold: When greater than 0, only results whose ``score`` is
            greater than or equal to this value are kept. A value of 0 (or
            less) disables the score filter.
        max_num_passages: Maximum number of results requested from the
            reranker. When falsy (``None`` or 0), 10% of the number of
            passages is requested instead, with a minimum of one.

    Returns:
        A ``pandas.DataFrame`` with the reranker's ``corpus_id`` and ``score``
        columns plus a ``Passage`` column holding the corresponding text, in
        the order returned by the reranker. An empty DataFrame is returned
        when there is nothing to rank, when the reranker yields no result, or
        when any error occurs (the error is printed, not raised).
    """
    try:
        # Nothing to rank: skip the model call instead of relying on the
        # generic error handler below.
        if len(passages) == 0:
            return pd.DataFrame()

        if max_num_passages:
            top_k = max_num_passages
        else:
            # Default to 10% of the passages, but always ask for at least one.
            top_k = max(1, int(0.1 * len(passages)))

        result_rerank = model_sf_mxbai.rank(queryText, passages, return_documents=False, top_k=top_k)
        if not result_rerank:
            return pd.DataFrame()

        df = pd.DataFrame(result_rerank)  # columns: corpus_id, score

        # Work on an explicit copy so that adding the "Passage" column below
        # never writes into a view of ``df`` (avoids SettingWithCopyWarning).
        if min_threshold > 0:
            df_filtered = df[df['score'] >= min_threshold].copy()
        else:
            df_filtered = df.copy()

        # Map every kept corpus index back to the passage text it refers to.
        df_filtered['Passage'] = [passages[i] for i in df_filtered['corpus_id']]

        # Identical passage texts are reported once, keeping the first hit.
        return df_filtered.drop_duplicates(subset='Passage', keep='first')

    except Exception as e:
        # Best-effort contract: report the problem and hand back an empty
        # DataFrame so callers can simply test ``.empty``.
        print(f"An error occurred: {e}")
        return pd.DataFrame()
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
if __name__ == '__main__':
    # Manual smoke test: rerank a handful of sentences against one query
    # and print the resulting DataFrame.
    demo_passages = [
        "A man is eating food.",
        "A man is eating pasta.",
        "The girl is carrying a baby.",
        "A man is riding a horse.",
    ]
    demo_query = 'A man is eating a piece of bread'

    ranked = RAG_retrieval_Base(
        demo_query,
        demo_passages,
        min_threshold=0,
        max_num_passages=3,
    )
    print(ranked)

    print("end of computations")
|
| 114 |
+
|
virtuosoQueryRest.py
CHANGED
|
@@ -3,10 +3,14 @@ from requests.auth import HTTPDigestAuth, HTTPBasicAuth
|
|
| 3 |
import ssl
|
| 4 |
import json
|
| 5 |
|
|
|
|
| 6 |
|
|
|
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
|
|
|
|
| 10 |
def execute_query(endpoint, query, auth):
|
| 11 |
headers = {
|
| 12 |
'Content-Type': 'application/x-www-form-urlencoded',
|
|
|
|
| 3 |
import ssl
|
| 4 |
import json
|
| 5 |
|
| 6 |
+
from joblib import Memory
|
| 7 |
|
| 8 |
+
cachedir = 'cached'
|
| 9 |
+
mem = Memory(cachedir, verbose=False)
|
| 10 |
|
| 11 |
|
| 12 |
|
| 13 |
+
@mem.cache
|
| 14 |
def execute_query(endpoint, query, auth):
|
| 15 |
headers = {
|
| 16 |
'Content-Type': 'application/x-www-form-urlencoded',
|