Commit c774338 · Marina Pliusnina committed
Parent: c8bd9ca

adding number of chunks and context
app.py
CHANGED

@@ -37,13 +37,14 @@ def generate(prompt, model_parameters):
     )


-def submit_input(input_, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, num_beams, temperature):
+def submit_input(input_, num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, num_beams, temperature):
     if input_.strip() == "":
         gr.Warning("Not possible to inference an empty input")
         return None


     model_parameters = {
+        "NUM_CHUNKS": num_chunks,
         "MAX_NEW_TOKENS": max_new_tokens,
         "REPETITION_PENALTY": repetition_penalty,
         "TOP_K": top_k,
@@ -109,6 +110,13 @@ def gradio_app():

         with gr.Row(variant="panel"):
             with gr.Accordion("Model parameters", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
+                num_chunks = Slider(
+                    minimum=1,
+                    maximum=6,
+                    step=1,
+                    value=4,
+                    label="Number of chunks"
+                )
                 max_new_tokens = Slider(
                     minimum=50,
                     maximum=1000,
@@ -154,7 +162,7 @@ def gradio_app():
                     label="Temperature"
                 )

-            parameters_compontents = [max_new_tokens, repetition_penalty, top_k, top_p, do_sample, num_beams, temperature]
+            parameters_compontents = [num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, num_beams, temperature]

             with gr.Column(variant="panel"):
                 output = Textbox(
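For reference, a minimal runnable sketch (not the app's actual code) of the wiring this change depends on: Gradio passes component values to the click handler positionally, which is why `num_chunks` has to be added at the same position in both the `parameters_compontents` list and `submit_input`'s signature. The button and the reduced parameter set below are illustrative assumptions.

```python
# Minimal sketch of how a new slider value reaches the handler: Gradio passes
# component values positionally, so the order of inputs=[...] must match the
# order of the handler's parameters.
import gradio as gr

def submit_input(input_, num_chunks, max_new_tokens):
    # Illustrative subset of the real handler's parameters
    model_parameters = {
        "NUM_CHUNKS": num_chunks,
        "MAX_NEW_TOKENS": max_new_tokens,
    }
    return f"would run inference with {model_parameters!r}"

with gr.Blocks() as demo:
    input_ = gr.Textbox(label="Input")
    # Same configuration as the slider added in this commit
    num_chunks = gr.Slider(minimum=1, maximum=6, step=1, value=4, label="Number of chunks")
    max_new_tokens = gr.Slider(minimum=50, maximum=1000, value=200, label="Max new tokens")
    output = gr.Textbox(label="Output")
    submit = gr.Button("Submit")
    # num_chunks sits right after input_, mirroring its position in submit_input's signature
    submit.click(submit_input, inputs=[input_, num_chunks, max_new_tokens], outputs=output)

demo.launch()
```

If the component list and the signature drift out of order, each parameter silently receives its neighbor's value, so keeping the two in lockstep is the main invariant to watch in this diff.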
rag.py
CHANGED

@@ -24,19 +24,11 @@ class RAG:

         logging.info("RAG loaded!")

-    def get_context(self, instruction, number_of_contexts=
-
-        context = ""
-
+    def get_context(self, instruction, number_of_contexts=4):

         documentos = self.vectore_store.similarity_search_with_score(instruction, k=number_of_contexts)

-        for doc in documentos:
-
-            context += doc[0].page_content
-
-        return context
+        return documentos

     def predict(self, instruction, context, model_parameters):

@@ -61,14 +53,30 @@ class RAG:
         response = requests.post(self.model_name, headers=headers, json=payload)

         return response.json()[0]["generated_text"].split("###")[-1][8:-1]
+
+    def beautiful_context(self, docs):
+
+        text_context = ""
+
+        full_context = ""
+
+        for doc in docs:
+            text_context += doc[0].page_content
+            full_context += doc[0].page_content + "\n"
+            full_context += doc[0].metadata["Títol de la norma"] + "\n\n"
+
+        return text_context, full_context

     def get_response(self, prompt: str, model_parameters: dict) -> str:

-        context = self.get_context(prompt)
+        docs = self.get_context(prompt, model_parameters["NUM_CHUNKS"])
+        text_context, full_context = beautiful_context(docs)
+
+        del model_parameters["NUM_CHUNKS"]

-        response = self.predict(prompt, context, model_parameters)
+        response = self.predict(prompt, text_context, model_parameters)

         if not response:
             return self.NO_ANSWER_MESSAGE

-        return response
+        return response, full_context
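Two details in the new rag.py code are worth flagging. First, LangChain's `similarity_search_with_score` returns (document, score) pairs, which is why `beautiful_context` reads `doc[0].page_content` and the `doc[0].metadata["Títol de la norma"]` ("Title of the norm") field. Second, `get_response` calls `beautiful_context(docs)` as a bare name instead of `self.beautiful_context(docs)`, which raises `NameError` on the first query, and its `-> str` annotation is stale now that the method returns a `(response, full_context)` pair. Below is a corrected sketch of the intended flow, assuming only the attributes and methods shown elsewhere in rag.py; the class wrapper and the placeholder message are illustrative.

```python
from typing import Tuple

class RAGPatch:
    """Illustrative stand-in for the real RAG class in rag.py."""

    NO_ANSWER_MESSAGE = "Sorry, no answer found."  # placeholder; real text lives in rag.py

    # get_context, beautiful_context and predict are assumed to be the
    # methods shown in the diff above.

    def get_response(self, prompt: str, model_parameters: dict) -> Tuple[str, str]:
        # Retrieve the top NUM_CHUNKS (document, score) pairs for the prompt
        docs = self.get_context(prompt, model_parameters["NUM_CHUNKS"])

        # Must be qualified with self.; the bare call in the commit raises NameError
        text_context, full_context = self.beautiful_context(docs)

        # NUM_CHUNKS is a retrieval knob, not a generation parameter; drop it
        # without mutating the caller's dict (the committed `del` mutates it)
        model_parameters = {k: v for k, v in model_parameters.items() if k != "NUM_CHUNKS"}

        response = self.predict(prompt, text_context, model_parameters)
        if not response:
            # Return a pair here too, so callers always unpack the same shape
            return self.NO_ANSWER_MESSAGE, full_context
        return response, full_context
```

One deviation from the commit is deliberate: the no-answer branch also returns the context, so `submit_input` in app.py can always unpack two values; as committed, that branch returns a bare string while the success path returns a tuple.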