richlai committed on
Commit a2ea235 · 1 Parent(s): 7d17f51

add Qdrant

Files changed (3)
  1. .gitignore +1 -0
  2. app.py +103 -11
  3. requirements.txt +2 -1
.gitignore CHANGED
@@ -1 +1,2 @@
 __pycache__/
+.env
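The newly ignored .env presumably holds the OpenAI credentials used by app.py. A minimal sketch of consuming it, assuming python-dotenv (which is not pinned in requirements.txt) and a hypothetical OPENAI_API_KEY entry:

# Sketch: load secrets from the ignored .env file.
# Assumes python-dotenv (not in requirements.txt) and an OPENAI_API_KEY entry.
import os
from dotenv import load_dotenv

load_dotenv()  # reads .env from the current working directory
openai_api_key = os.environ["OPENAI_API_KEY"]  # hypothetical key name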
app.py CHANGED
@@ -1,4 +1,5 @@
 import os
+import numpy as np
 from typing import List
 from chainlit.types import AskFileResponse
 from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader
@@ -11,12 +12,19 @@ from aimakerspace.openai_utils.embedding import EmbeddingModel
 from aimakerspace.vectordatabase import VectorDatabase
 from aimakerspace.openai_utils.chatmodel import ChatOpenAI
 import chainlit as cl
-import fitz
-
+from qdrant_client import QdrantClient
+from qdrant_client.models import VectorParams, Distance, PointStruct
+from chainlit.input_widget import Select
+
+# Qdrant client, initialized per session in on_chat_start
+client = None
+
+# System chat prompt
 system_template = """\
 Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer."""
 system_role_prompt = SystemRolePrompt(system_template)
 
+# User prompt for chat
 user_prompt_template = """\
 Context:
 {context}
@@ -26,6 +34,16 @@ Question:
 """
 user_role_prompt = UserRolePrompt(user_prompt_template)
 
+# Categorization prompt: route the user to a vector database
+system_template_db = """\
+You are an expert in categorization. Given the last user response, determine whether the user wants to use the Qdrant database. If yes, return the single word 'QDrant' without any other phrases. If no, return only the phrase 'AI Makerspace'.
+"""
+system_role_prompt_db = SystemRolePrompt(system_template_db)
+
+user_prompt_template_db = "User Input:\n{user_input}"
+user_role_prompt_db = UserRolePrompt(user_prompt_template_db)
+
 class RetrievalAugmentedQAPipeline:
     def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorDatabase) -> None:
         self.llm = llm
@@ -48,6 +66,34 @@ class RetrievalAugmentedQAPipeline:
 
         return {"response": generate_response(), "context": context_list}
 
+class RetrievalAugmentedQAPipelineQdrant:
+    def __init__(self, llm: ChatOpenAI, vector_db_retriever) -> None:
+        self.llm = llm
+        self.vector_db_retriever = vector_db_retriever
+        self.embedding_model = EmbeddingModel()
+
+    async def arun_pipeline(self, user_query: str):
+        # Embed the query and pull the four nearest chunks from Qdrant
+        query_vector = self.embedding_model.get_embedding(user_query)
+        context_list = self.vector_db_retriever.search(
+            collection_name="my_collection",
+            query_vector=query_vector,
+            limit=4
+        )
+
+        context_prompt = ""
+        for context in context_list:
+            context_prompt += context.payload['text'] + "\n"
+
+        formatted_system_prompt = system_role_prompt.create_message()
+        formatted_user_prompt = user_role_prompt.create_message(question=user_query, context=context_prompt)
+
+        async def generate_response():
+            async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
+                yield chunk
+
+        return {"response": generate_response(), "context": context_list}
+
 text_splitter = CharacterTextSplitter()
 
 
@@ -79,10 +125,42 @@ def process_pdf_file(file: AskFileResponse):
     texts = text_splitter.split_texts(documents)
     return texts
 
-
+async def initialize_qdrant(texts):
+    # In-memory Qdrant instance; vectors live only for this process
+    client = QdrantClient(":memory:")
+    if not client.collection_exists("my_collection"):
+        client.create_collection(
+            collection_name="my_collection",
+            vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
+        )
+    embedding_model = EmbeddingModel()
+    embeddings = await embedding_model.async_get_embeddings(texts)
+    for idx, (text, embedding) in enumerate(zip(texts, embeddings)):
+        insert(text, np.array(embedding), idx, client)
+
+    return client
+
+def insert(text, vector, idx, client):
+    point = PointStruct(
+        id=idx,
+        vector=vector.tolist(),
+        payload={"text": text}
+    )
+    client.upsert(
+        collection_name="my_collection",
+        points=[point]
+    )
+
+def choose_db(llm, user_input):
+    # Ask the LLM to classify the reply as 'QDrant' or 'AI Makerspace'
+    formatted_system_prompt_db = system_role_prompt_db.create_message()
+    formatted_user_prompt_db = user_role_prompt_db.create_message(user_input=user_input)
+    return llm.run([formatted_system_prompt_db, formatted_user_prompt_db])
 
 @cl.on_chat_start
 async def on_chat_start():
+    global client
     files = None
 
     # Wait for the user to upload a file
@@ -110,18 +188,32 @@ async def on_chat_start():
     texts = process_text_file(file)
 
     print(f"Processing {len(texts)} text chunks")
-
-    # Create a dict vector store
-    vector_db = VectorDatabase()
-    vector_db = await vector_db.abuild_from_list(texts)
 
     chat_openai = ChatOpenAI()
 
+    res = await cl.AskUserMessage(content="Do you want to use the QDrant vector database or AI Makerspace's?").send()
+
+    if res:
+        chosen_db = choose_db(chat_openai, res['content'])
+        await cl.Message(
+            content=f"You have chosen {chosen_db}. Please start asking questions!",
+        ).send()
+
     # Create a chain
-    retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
-        vector_db_retriever=vector_db,
-        llm=chat_openai
-    )
+    retrieval_augmented_qa_pipeline = None
+    if chosen_db.lower() == 'qdrant':
+        client = await initialize_qdrant(texts)
+        retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipelineQdrant(
+            vector_db_retriever=client,
+            llm=chat_openai
+        )
+    else:
+        vector_db = VectorDatabase()
+        vector_db = await vector_db.abuild_from_list(texts)
+        retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
+            vector_db_retriever=vector_db,
+            llm=chat_openai
+        )
 
     # Let the user know that the system is ready
     msg.content = f"Processing `{file.name}` done. You can now ask questions!"
requirements.txt CHANGED
@@ -1,4 +1,5 @@
 numpy
 chainlit==0.7.700
 openai
-pymupdf==1.24.9
+pymupdf==1.24.9
+qdrant-client==1.11.0
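A quick sketch for confirming the pinned packages resolved correctly after installation (standard library only, Python 3.8+):

# Sketch: verify installed versions match requirements.txt
from importlib.metadata import version

print(version("chainlit"))       # expect 0.7.700
print(version("qdrant-client"))  # expect 1.11.0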