alperensn committed on
Commit 25fcb73 · verified · 1 Parent(s): cb96db0

Upload files

Files changed (9)
  1. app.py +121 -0
  2. build_knowledge_base.py +55 -0
  3. config.py +34 -0
  4. dataFetch.py +79 -0
  5. dataPrep.py +199 -0
  6. data_processing.py +138 -0
  7. rag_pipeline.py +139 -0
  8. requirements.txt +130 -0
  9. vector_store_manager.py +29 -0
app.py ADDED
@@ -0,0 +1,121 @@
+ # =================================================================================
+ # app.py: Main application file for the Streamlit web interface
+ # =================================================================================
+ import time
+ import streamlit as st
+ from dotenv import load_dotenv
+
+ # Load environment variables from .env file
+ load_dotenv()
+
+ # Import the modules we've created
+ import config
+ import rag_pipeline  # Now using the LlamaIndex pipeline
+
+ # --- Page Configuration ---
+ # st.set_page_config must be the first Streamlit command and may only be called once.
+ st.set_page_config(
+     page_title="PharmaBot",
+     page_icon="🤖",
+     layout="wide",
+     initial_sidebar_state="expanded",
+ )
+
+ # --- State Management ---
+ def initialize_state():
+     """Initializes session state variables."""
+     if "messages" not in st.session_state:
+         st.session_state.messages = [{"role": "assistant", "content": "Welcome to PharmaBot! How can I help you today?"}]
+     if "query_engine" not in st.session_state:
+         st.session_state.query_engine = None
+     if "initialized" not in st.session_state:
+         st.session_state.initialized = False
+
+ # --- UI Components ---
+ def setup_sidebar():
+     """Sets up the sidebar with app information."""
+     with st.sidebar:
+         st.header("About PharmaBot")
+         st.info(
+             "PharmaBot is an AI assistant designed to answer questions about "
+             "pharmaceuticals based on a knowledge base of FDA drug label documents. "
+             "It uses a Retrieval-Augmented Generation (RAG) pipeline to provide accurate, "
+             "context-aware answers."
+         )
+         st.warning(
+             "**Disclaimer: I am an AI assistant, not a medical professional. This information is for educational purposes only. Please consult with a qualified healthcare provider for any health concerns or before making any medical decisions.**"
+         )
+         st.markdown("---")
+         st.header("Technical Details")
+         st.markdown(
+             f"""
+             - **LLM Model:** `{config.LLM_MODEL_ID}`
+             - **Embedding Model:** `{config.EMBEDDING_MODEL_NAME}`
+             - **Vector Type:** `LlamaIndex Vector Store`
+             - **Vector Store:** `{config.VECTOR_STORE_PATH}`
+             """
+         )
+
+ def display_chat_history():
+     """Displays the chat history."""
+     for message in st.session_state.messages:
+         with st.chat_message(message["role"]):
+             st.write(message["content"])
+
+ def handle_user_input(chat_engine):
+     """Handles user input and displays the response."""
+     if prompt := st.chat_input("Ask me anything about pharmaceuticals..."):
+         st.session_state.messages.append({"role": "user", "content": prompt})
+         with st.chat_message("user"):
+             st.write(prompt)
+
+         with st.chat_message("assistant"):
+             with st.spinner("Thinking..."):
+                 response = chat_engine.chat(prompt)
+                 st.write(str(response))
+
+         st.session_state.messages.append({"role": "assistant", "content": str(response)})
+
+ # --- Main Application Logic ---
+ def main():
+     """Main function to run the Streamlit app."""
+     initialize_state()
+     st.title("💊 PharmaBot: Your AI Pharmaceutical Assistant")
+     setup_sidebar()
+
+     # Initialize the RAG pipeline if it hasn't been already
+     if not st.session_state.initialized:
+         with st.status("Initializing the RAG pipeline...", expanded=True) as status:
+             try:
+                 status.write("Step 1/3: Initializing LLM and embedding models...")
+                 rag_pipeline.initialize_llm_and_embed_model()
+
+                 status.write("Step 2/3: Loading vector index from storage...")
+                 index = rag_pipeline.load_vector_index()
+
+                 status.write("Step 3/3: Building the conversational chat engine...")
+                 st.session_state.query_engine = rag_pipeline.build_query_engine(index)
+
+                 st.session_state.initialized = True
+                 status.update(label="Initialization Complete!", state="complete", expanded=False)
+                 time.sleep(1)  # Brief pause to show completion
+
+             except FileNotFoundError as e:
+                 status.update(label="Initialization Failed", state="error")
+                 st.error(f"Error: {e}. Please make sure the vector store is built.")
+                 st.warning("To build the vector store, run `python build_knowledge_base.py` from your terminal.")
+                 return
+             except Exception as e:
+                 status.update(label="Initialization Failed", state="error")
+                 st.error(f"An unexpected error occurred during initialization: {e}")
+                 return
+         st.rerun()
+
+     # Display chat and handle input if initialized
+     if st.session_state.initialized:
+         display_chat_history()
+         handle_user_input(st.session_state.query_engine)
+
+ if __name__ == "__main__":
+     main()
+
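Note on configuration: load_dotenv() above implies a local .env file. The only key read explicitly in this commit is HUGGING_FACE_TOKEN (see rag_pipeline.py); the Gemini LLM is assumed to pick up a GOOGLE_API_KEY from the environment as well, per the google-generativeai client's defaults, so treat that variable name as an assumption. A minimal pre-flight check, as a sketch under those assumptions:

# check_env.py (hypothetical helper, not part of this commit)
import os
from dotenv import load_dotenv

load_dotenv()  # reads .env from the working directory
for key in ("GOOGLE_API_KEY", "HUGGING_FACE_TOKEN"):
    print(f"{key}: {'set' if os.getenv(key) else 'MISSING'}")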
build_knowledge_base.py ADDED
@@ -0,0 +1,55 @@
+ # =================================================================================
+ # build_knowledge_base.py: One-time script to build and save the vector store
+ # =================================================================================
+ from llama_index.core import VectorStoreIndex, Document
+ from llama_index.core.node_parser import SentenceSplitter
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+ import config
+ import data_processing
+ import os
+
+ def build_vector_store():
+     """
+     Builds and saves a LlamaIndex vector store from the processed documents.
+     """
+     # Load and process documents from all sources
+     all_docs = data_processing.load_and_process_all()
+
+     # If no documents were created, exit
+     if not all_docs:
+         print("No documents were created. Exiting.")
+         return
+
+     # The documents are already in the correct LlamaIndex format.
+     llama_documents = all_docs
+
+     # Initialize the embedding model
+     print(f"Loading embedding model: {config.EMBEDDING_MODEL_NAME}...")
+     embed_model = HuggingFaceEmbedding(model_name=config.EMBEDDING_MODEL_NAME)
+
+     # Create the LlamaIndex VectorStoreIndex
+     print("Creating the LlamaIndex vector store...")
+     index = VectorStoreIndex.from_documents(
+         llama_documents,
+         embed_model=embed_model,
+         transformations=[SentenceSplitter(chunk_size=1000, chunk_overlap=150)]
+     )
+
+     # Persist the index to disk
+     print(f"Saving the vector store to: {config.LLAMA_INDEX_STORE_PATH}")
+     index.storage_context.persist(persist_dir=config.LLAMA_INDEX_STORE_PATH)
+     print("Vector store built and saved successfully.")
+
+ def main():
+     """
+     Main function to build the knowledge base.
+     """
+     # Check if the vector store already exists
+     if os.path.exists(config.LLAMA_INDEX_STORE_PATH):
+         print("Vector store already exists. Skipping build process.")
+     else:
+         build_vector_store()
+
+ if __name__ == "__main__":
+     main()
+
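Once the index has been persisted, a quick smoke test can confirm it loads and answers. A minimal sketch that reuses the helpers from rag_pipeline.py, assuming the .env keys are in place and the store was built as above:

# smoke_test.py (hypothetical, not part of this commit)
from dotenv import load_dotenv
import rag_pipeline

load_dotenv()
rag_pipeline.initialize_llm_and_embed_model()   # sets Settings.llm / Settings.embed_model
index = rag_pipeline.load_vector_index()        # raises FileNotFoundError if the store is missing
print(index.as_query_engine(similarity_top_k=3).query("What is ibuprofen used for?"))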
config.py ADDED
@@ -0,0 +1,34 @@
+ # =================================================================================
+ # config.py: Project configuration settings
+ # =================================================================================
+ # This file contains constant parameters like model names, file paths, etc.
+ # Sensitive information like API keys will be read from the .env file.
+
+ # --- Model Settings ---
+ # The main language model to be used in the RAG chain
+ LLM_MODEL_ID = "gemini-2.0-flash-001"
+
+ # The embedding model for converting text to vectors
+
+ EMBEDDING_MODEL_NAME = "pritamdeka/S-BioBert-snli-multinli-stsb"
+
+ # --- File Paths ---
+ # Path to the raw data downloaded from the openFDA API
+ RAW_DATA_PATH = "./fda_data/drug_labels_all.json"
+ # Path to the cleaned/processed data
+ CLEANED_DATA_PATH = "./fda_data/fda_data_processed.jsonl"
+
+ # The name of the folder where the vector database will be saved
+ VECTOR_STORE_PATH = "llamaIndexVectorBase_fda"
+
+ # =================================================================================
+ # LlamaIndex Settings
+ # =================================================================================
+ LLAMA_INDEX_STORE_PATH = "./llamaIndexVectorBase_fda"
+
+ # =================================================================================
+ # Data Source Paths
+ # =================================================================================
+ ##HEALTHCARE_MAGIC_PATH = "../healthCareMagic/HealthCareMagic-100k.json"
+ ##MEDQUAD_PATH = "../medQuad/medDataset.json"
+
dataFetch.py ADDED
@@ -0,0 +1,79 @@
+ import requests
+ import json
+ import os
+ import math
+
+ # Define the API endpoint
+ API_URL = "https://api.fda.gov/drug/label.json"
+
+ # Define the output directory and file for all data
+ OUTPUT_DIR = "fda_data"
+ OUTPUT_FILE = os.path.join(OUTPUT_DIR, "drug_labels_all.json")
+
+ # The API's maximum limit per request is 1000
+ CHUNK_SIZE = 1000
+ MAX_RECORDS = 25000  # openFDA caps skip-based paging at 25,000 records
+
+ def fetch_all_fda_data():
+     """
+     Fetches drug label data from the openFDA API using pagination
+     and saves it to a single file.
+     """
+     print("Starting to fetch data from the openFDA endpoint...")
+
+     try:
+         # Step 1: Make an initial request to get the total number of records
+         print("Determining the total number of records...")
+         initial_response = requests.get(API_URL, params={"limit": 1})
+         initial_response.raise_for_status()
+         total_records = initial_response.json()['meta']['results']['total']
+
+         records_to_fetch = min(total_records, MAX_RECORDS)
+         print(f"Found a total of {total_records} records. Fetching up to {records_to_fetch} records.")
+
+         all_results = []
+
+         # Step 2: Loop through the data in chunks
+         num_chunks = math.ceil(records_to_fetch / CHUNK_SIZE)
+         for i in range(num_chunks):
+             skip = i * CHUNK_SIZE
+
+             # Ensure we don't request more than records_to_fetch
+             limit = min(CHUNK_SIZE, records_to_fetch - skip)
+             if limit <= 0:
+                 break
+
+             params = {"limit": limit, "skip": skip}
+
+             print(f"Fetching chunk {i+1}/{num_chunks} (records {skip} to {skip + limit - 1})...")
+
+             response = requests.get(API_URL, params=params)
+             response.raise_for_status()
+
+             chunk_data = response.json()
+             if 'results' in chunk_data:
+                 all_results.extend(chunk_data['results'])
+
+         print("\nAll data has been fetched successfully.")
+
+         # Step 3: Save all the data to a single file
+         if not os.path.exists(OUTPUT_DIR):
+             os.makedirs(OUTPUT_DIR)
+             print(f"Created directory: {OUTPUT_DIR}")
+
+         with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
+             json.dump({"results": all_results}, f, ensure_ascii=False, indent=4)
+
+         print(f"All {len(all_results)} records saved to: {OUTPUT_FILE}")
+
+     except requests.exceptions.HTTPError as http_err:
+         print(f"HTTP error occurred: {http_err}")
+     except requests.exceptions.RequestException as req_err:
+         print(f"An error occurred while fetching data: {req_err}")
+     except json.JSONDecodeError:
+         print("Failed to parse the response as JSON.")
+     except Exception as e:
+         print(f"An unexpected error occurred: {e}")
+
+ if __name__ == "__main__":
+     fetch_all_fda_data()
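For orientation, the file this script writes mirrors the openFDA response shape, which is what dataPrep.py and data_processing.py below expect. A trimmed example, shown as a Python-style literal with made-up field values:

# Rough shape of fda_data/drug_labels_all.json (illustrative values only)
{
    "results": [
        {
            "openfda": {"brand_name": ["Advil"], "generic_name": ["IBUPROFEN"]},
            "indications_and_usage": ["temporarily relieves minor aches and pains ..."],
            "warnings": ["..."]
        }
    ]
}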
dataPrep.py ADDED
@@ -0,0 +1,199 @@
+ import json
+ import re
+ from tqdm import tqdm
+ import os
+ import config
+
+ # --- Functions from dataOrganize.py ---
+
+ def clean_text(text: str) -> str:
+     """
+     Cleans the input text by removing common noise from FDA documents.
+     """
+     if not text:
+         return ""
+     text = re.sub(r'REVISED:\s*\d{1,2}/\d{4}', '', text)
+     text = re.sub(r'\s{2,}', ' ', text).strip()
+     text = re.sub(r'[\-=*]{3,}', '', text)
+     return text
+
+ def organize_drug_data(input_path):
+     """
+     Loads raw drug data, filters for high-quality entries, cleans the text,
+     and returns the organized data as a list.
+     """
+     print(f"Loading raw data from: {input_path}...")
+     try:
+         with open(input_path, 'r', encoding='utf-8') as f:
+             data = json.load(f)
+     except FileNotFoundError:
+         print(f"Error: The file '{input_path}' was not found.")
+         return []
+     except json.JSONDecodeError:
+         print(f"Error: Could not decode JSON from '{input_path}'.")
+         return []
+
+     entries = data.get('results', data) if isinstance(data, dict) else data
+
+     if not isinstance(entries, list):
+         print("Error: The JSON data is not in the expected list format.")
+         return []
+
+     organized_data = []
+     print("Filtering, cleaning, and organizing drug data...")
+
+     for entry in tqdm(entries, desc="Processing drug entries"):
+         if not isinstance(entry, dict):
+             continue
+
+         openfda = entry.get("openfda", {})
+         brand_name_list = openfda.get("brand_name")
+         generic_name_list = openfda.get("generic_name")
+
+         if not brand_name_list and not generic_name_list:
+             continue
+
+         if "indications_and_usage" not in entry:
+             continue
+
+         brand_name = brand_name_list[0] if brand_name_list else "Unknown Brand"
+         generic_name = generic_name_list[0] if generic_name_list else "Unknown Generic"
+
+         sections_to_extract = {
+             "indications_and_usage": "Indications and Usage", "adverse_reactions": "Adverse Reactions",
+             "drug_interactions": "Drug Interactions", "contraindications": "Contraindications",
+             "warnings": "Warnings", "boxed_warning": "Boxed Warning",
+             "mechanism_of_action": "Mechanism of Action", "pharmacokinetics": "Pharmacokinetics",
+             "dosage_and_administration": "Dosage and Administration", "how_supplied": "How Supplied",
+             "storage_and_handling": "Storage and Handling", "information_for_patients": "Information for Patients",
+             "pregnancy": "Pregnancy", "nursing_mothers": "Nursing Mothers",
+             "pediatric_use": "Pediatric Use", "geriatric_use": "Geriatric Use"
+         }
+
+         processed_sections = {}
+         for key, section_name in sections_to_extract.items():
+             text_list = entry.get(key)
+             if text_list and isinstance(text_list, list) and text_list[0]:
+                 cleaned_text = clean_text(text_list[0])
+                 if cleaned_text:
+                     processed_sections[section_name] = cleaned_text
+
+         if processed_sections:
+             organized_entry = {
+                 "brand_name": brand_name,
+                 "generic_name": generic_name,
+                 "sections": processed_sections
+             }
+             organized_data.append(organized_entry)
+
+     print(f"Found {len(organized_data)} high-quality drug entries.")
+     return organized_data
+
+ # --- Functions from deduplicate_drugs.py ---
+
+ def deduplicate_drugs(data):
+     """
+     Deduplicates a list of drugs based on brand_name and generic_name.
+     """
+     print(f"Deduplicating {len(data)} drugs...")
+     seen_drugs = set()
+     deduplicated_drugs = []
+
+     for drug in data:
+         brand_name = drug.get('brand_name')
+         generic_name = drug.get('generic_name')
+
+         if isinstance(brand_name, list):
+             brand_name = brand_name[0] if brand_name else None
+         if isinstance(generic_name, list):
+             generic_name = generic_name[0] if generic_name else None
+
+         brand_name_lower = brand_name.lower() if brand_name else None
+         generic_name_lower = generic_name.lower() if generic_name else None
+
+         drug_identifier = (brand_name_lower, generic_name_lower)
+
+         if drug_identifier not in seen_drugs:
+             seen_drugs.add(drug_identifier)
+             deduplicated_drugs.append(drug)
+
+     print(f"Deduplication complete. Found {len(deduplicated_drugs)} unique drugs.")
+     return deduplicated_drugs
+
+ # --- Functions from format_fda_data.py ---
+
+ def generate_section_id(section_title):
+     """Generates a simplified, lowercase, underscore-separated ID from a section title."""
+     s = re.sub(r'[/\-&]', ' ', section_title)
+     s = re.sub(r'[^a-zA-Z0-9\s]', '', s)
+     parts = s.lower().split()
+     if len(parts) >= 2:
+         return '_'.join(parts[:2])
+     elif len(parts) == 1:
+         return parts[0]
+     else:
+         return "section"
+
+ def transform_drug_data(drugs, output_file_path):
+     """
+     Transforms drug data to a JSON Lines format.
+     """
+     print(f"Transforming {len(drugs)} drugs to JSONL format...")
+     processed_records = []
+
+     for drug in drugs:
+         generic_name = drug.get('generic_name')
+         sections = drug.get('sections')
+
+         if not generic_name or not isinstance(sections, dict):
+             continue
+
+         if isinstance(generic_name, list):
+             generic_name = generic_name[0] if generic_name else None
+
+         if not generic_name:
+             continue
+
+         generic_name_upper = generic_name.upper()
+
+         for section_title, section_content in sections.items():
+             if not section_title or not section_content:
+                 continue
+
+             section_id = generate_section_id(section_title)
+             doc_id = f"{generic_name_upper.replace(' ', '_')}_{section_id}"
+
+             record = {
+                 "doc_id": doc_id,
+                 "generic_name": generic_name_upper,
+                 "section": section_title,
+                 "content": section_content.strip()
+             }
+             processed_records.append(json.dumps(record))
+
+     os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
+     with open(output_file_path, 'w') as f_out:
+         f_out.write('\n'.join(processed_records))
+
+     print(f"Transformation complete. {len(processed_records)} records created.")
+     print(f"Transformed data saved to: {output_file_path}")
+
+
+ if __name__ == '__main__':
+     # Define file paths using config
+     raw_data_path = config.RAW_DATA_PATH
+     cleaned_data_path = config.CLEANED_DATA_PATH
+
+     # --- Run the full pipeline ---
+     print("--- Starting Data Preparation Pipeline ---")
+
+     # Step 1: Organize and clean the raw data in memory
+     organized_data = organize_drug_data(raw_data_path)
+
+     # Step 2: Deduplicate the cleaned data in memory
+     deduplicated_data = deduplicate_drugs(organized_data)
+
+     # Step 3: Transform the deduplicated data and write to the final file
+     transform_drug_data(deduplicated_data, cleaned_data_path)
+
+     print("--- Data Preparation Pipeline Finished ---")
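Each line of the resulting JSONL file is a standalone JSON object with the four keys assembled in transform_drug_data. For example, the "Indications and Usage" section of a drug with generic name IBUPROFEN would produce a record like the following (the content text is illustrative only):

# One line of fda_data/fda_data_processed.jsonl (content abbreviated for the example)
{"doc_id": "IBUPROFEN_indications_and", "generic_name": "IBUPROFEN", "section": "Indications and Usage", "content": "temporarily relieves minor aches and pains ..."}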
data_processing.py ADDED
@@ -0,0 +1,138 @@
+ # =================================================================================
+ # data_processing.py: Process and prepare raw data
+ # =================================================================================
+ import json
+ import re
+ from llama_index.core import Document
+ from tqdm import tqdm
+ import config
+
+ def clean_text(text: str) -> str:
+     """
+     Cleans the input text by removing common noise from FDA documents.
+     """
+     text = re.sub(r'REVISED:\s*\d{1,2}/\d{4}', '', text)
+     text = re.sub(r'\s{2,}', ' ', text).strip()
+     text = re.sub(r'[\-=*]{3,}', '', text)
+     return text
+
+ def load_and_prepare_documents(json_path=config.RAW_DATA_PATH):
+     """
+     Loads drug data from a JSON file, filters for high-quality entries,
+     cleans the text, and returns a list of LlamaIndex Document objects.
+     """
+     print(f"Loading data from: {json_path}...")
+     with open(json_path, 'r', encoding='utf-8') as f:
+         data = json.load(f)
+
+     all_docs = []
+     print("Filtering, cleaning, and converting data to 'Document' objects...")
+     for entry in tqdm(data.get("results", []), desc="Processing drug data"):
+         if not entry: continue
+
+         # --- NEW FILTERING LOGIC ---
+         # 1. Ensure the entry has a brand or generic name.
+         brand_name_list = entry.get("openfda", {}).get("brand_name")
+         generic_name_list = entry.get("openfda", {}).get("generic_name")
+
+         if not brand_name_list and not generic_name_list:
+             continue  # Skip entries with no name
+
+         # 2. Ensure it's likely a real drug by checking for a crucial section.
+         if "indications_and_usage" not in entry:
+             continue  # Skip entries that don't say what the drug is for
+
+         brand_name = brand_name_list[0] if brand_name_list else "Unknown Brand"
+         generic_name = generic_name_list[0] if generic_name_list else "Unknown Generic"
+         # ---------------------------
+
+         sections_to_process = {
+             "indications_and_usage": "Indications and Usage",
+             "adverse_reactions": "Adverse Reactions",
+             "drug_interactions": "Drug Interactions",
+             "contraindications": "Contraindications",
+             "warnings": "Warnings",
+             "boxed_warning": "Boxed Warning",
+             "mechanism_of_action": "Mechanism of Action",
+             "pharmacokinetics": "Pharmacokinetics",
+             "dosage_and_administration": "Dosage and Administration",
+             "how_supplied": "How Supplied",
+             "storage_and_handling": "Storage and Handling",
+             "information_for_patients": "Information for Patients",
+             "pregnancy": "Pregnancy",
+             "nursing_mothers": "Nursing Mothers",
+             "pediatric_use": "Pediatric Use",
+             "geriatric_use": "Geriatric Use"
+         }
+
+         for key, section_name in sections_to_process.items():
+             text_list = entry.get(key)
+             if text_list and isinstance(text_list, list) and text_list[0] and text_list[0].strip():
+                 cleaned_text = clean_text(text_list[0])
+                 if cleaned_text:
+                     metadata = {"brand_name": brand_name, "generic_name": generic_name, "section": section_name}
+                     doc = Document(text=cleaned_text, metadata=metadata)
+                     all_docs.append(doc)
+
+     print(f"Created a total of {len(all_docs)} 'Document' objects after filtering.")
+     return all_docs
+
+ def load_and_process_all():
+     """
+     Loads and processes documents from all configured data sources.
+     """
+     all_docs = []
+
+     # Process FDA drug data
+     fda_docs = load_and_prepare_fda_documents()
+     all_docs.extend(fda_docs)
+
+     # Process HealthCareMagic data
+     # healthcare_docs = healthcare_data_processing.load_and_prepare_documents(config.HEALTHCARE_MAGIC_PATH)
+     # all_docs.extend(healthcare_docs)
+
+     '''# Process MedQuad data
+     medquad_docs = medquad_data_processing.load_and_prepare_documents(config.MEDQUAD_PATH)
+     all_docs.extend(medquad_docs)'''
+
+     print(f"Total documents loaded from all sources: {len(all_docs)}")
+     return all_docs
+
+ def load_and_prepare_fda_documents(json_path=config.CLEANED_DATA_PATH):
+     """
+     Loads cleaned drug data from a JSON Lines file and converts it into
+     a list of LlamaIndex Document objects for the RAG pipeline.
+     """
+     print(f"Loading cleaned drug data from: {json_path}...")
+     all_docs = []
+     try:
+         with open(json_path, 'r', encoding='utf-8') as f:
+             for line in tqdm(f, desc="Processing cleaned drug data"):
+                 entry = json.loads(line)
+
+                 content = entry.get("content")
+                 if not content:
+                     continue
+
+                 metadata = {
+                     "doc_id": entry.get("doc_id"),
+                     "brand_name": entry.get("brand_name"),  # not written by dataPrep.py, so this is usually None
+                     "generic_name": entry.get("generic_name"),
+                     "section": entry.get("section"),
+                     "source": "FDA Drug Labels"
+                 }
+
+                 # The text for the document is just the content of the section
+                 doc = Document(text=content, metadata=metadata)
+                 all_docs.append(doc)
+
+     except FileNotFoundError:
+         print(f"Error: The file '{json_path}' was not found.")
+         return []
+     except json.JSONDecodeError as e:
+         print(f"Error: Could not decode JSON from a line in '{json_path}'. Details: {e}")
+         return []
+
+     print(f"Created {len(all_docs)} 'Document' objects from the cleaned FDA data.")
+     return all_docs
+
rag_pipeline.py ADDED
@@ -0,0 +1,139 @@
+ # =================================================================================
+ # rag_pipeline.py: Create the Gemini model and the RAG chain
+ # =================================================================================
+ from llama_index.core import VectorStoreIndex, StorageContext, load_index_from_storage
+ from llama_index.llms.gemini import Gemini
+ from llama_index.core.prompts.base import PromptTemplate
+ from llama_index.core.prompts import ChatPromptTemplate
+ from llama_index.core.llms import ChatMessage, MessageRole
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+ from llama_index.core import Settings
+ from llama_index.core.memory import ChatMemoryBuffer
+ from google.generativeai.types import HarmCategory, HarmBlockThreshold
+ import config
+ import os
+
+ def initialize_llm_and_embed_model():
+     """
+     Initializes and sets the global LLM and embedding model for LlamaIndex.
+     """
+     print(f"Initializing Gemini model: {config.LLM_MODEL_ID}...")
+
+     # Define safety settings to be less restrictive, especially for medical content
+     safety_settings = {
+         HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+         HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+         HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
+         HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+     }
+
+     # System instruction for Gemini (if supported by your version)
+     system_instruction = (
+         "You are PharmaBot, an AI pharmaceutical information assistant. "
+         "You provide accurate information from FDA drug labels but never give medical advice or diagnose conditions. "
+         "You always respond in the user's language and maintain conversation context throughout the session."
+     )
+
+     llm = Gemini(
+         model_name=config.LLM_MODEL_ID,
+         temperature=0.3,
+         safety_settings=safety_settings,
+         generation_config={"candidate_count": 1},
+         system_instruction=system_instruction  # Add system instruction
+     )
+
+     print(f"Loading embedding model: {config.EMBEDDING_MODEL_NAME}...")
+
+     # Get the token from environment variables
+     hf_token = os.getenv("HUGGING_FACE_TOKEN")
+     if not hf_token:
+         print("Warning: HUGGING_FACE_TOKEN environment variable not set.")
+
+     embed_model = HuggingFaceEmbedding(
+         model_name=config.EMBEDDING_MODEL_NAME,
+         token=hf_token
+     )
+
+     # Set the global models for LlamaIndex
+     Settings.llm = llm
+     Settings.embed_model = embed_model
+
+ def load_vector_index():
+     """
+     Loads the LlamaIndex vector index from storage.
+     """
+     if not os.path.exists(config.LLAMA_INDEX_STORE_PATH):
+         raise FileNotFoundError(f"LlamaIndex store not found at {config.LLAMA_INDEX_STORE_PATH}. Please run build_knowledge_base.py first.")
+
+     print("Loading LlamaIndex vector store...")
+     storage_context = StorageContext.from_defaults(persist_dir=config.LLAMA_INDEX_STORE_PATH)
+     index = load_index_from_storage(storage_context)
+     return index
+
+
+ def build_query_engine(index):
+     """
+     Builds a query engine from the LlamaIndex vector index.
+     """
+
+     # Condensed, action-oriented prompt that guides behavior without being conversational
+     qa_template_str = (
+         "Context information from FDA drug labels:\n"
+         "---------------------\n"
+         "{context_str}\n"
+         "---------------------\n\n"
+         "Instructions:\n"
+         "1. LANGUAGE: Respond entirely in the same language as the query. Detect: English, Turkish, Spanish, French, German, Arabic, etc.\n"
+         "2. QUERY TYPE:\n"
+         "   - Medical/Drug query (medications, symptoms, dosages, interactions) → Use context to provide structured response\n"
+         "   - General conversation (greetings, small talk) → Respond conversationally, no context needed\n"
+         "3. CONTEXT CHECK:\n"
+         "   - If context is empty/irrelevant → State you couldn't find information, ask for clarification\n"
+         "   - If context is relevant → Proceed with response\n"
+         "4. RESPONSE FORMAT FOR DRUG QUERIES:\n"
+         "   **Drug Name:** [from brand_name/generic_name]\n"
+         "   **What It's Used For:** [summarize indications_and_usage]\n"
+         "   **How to Take It:** [summarize dosage_and_administration]\n"
+         "   **Important Warnings:** [list 4-5 critical points from warnings/adverse_reactions/contraindications]\n"
+         "   **Drug Interactions:** [if available from drug_interactions]\n"
+         "5. RESPONSE FORMAT FOR DRUG INTERACTIONS:\n"
+         "   **Drug Interaction: [Drug A] and [Drug B]**\n"
+         "   **Interaction Found:** [describe]\n"
+         "   **Clinical Significance:** [explain risks]\n"
+         "   **Recommendation:** [FDA guidance]\n"
+         "6. RESPONSE FORMAT FOR SYMPTOM QUERIES (first ask):\n"
+         "   Ask 5 clarifying questions: duration, severity, prior medications, current medications, allergies\n"
+         "7. RESPONSE FORMAT FOR SYMPTOM QUERIES (after details):\n"
+         "   Present 2-3 FDA-approved medication options with: Type, Used For, Dosage, Key Warning\n"
+         "8. SAFETY:\n"
+         "   - Only use info from context for medical responses\n"
+         "   - If details missing from context, state explicitly\n"
+         "   - ALWAYS end medical responses with:\n"
+         "   ⚠️ Disclaimer: I am an AI assistant, not a medical professional. This information is from FDA labels and is for educational purposes only. Always consult your doctor or pharmacist before taking any medication.\n"
+         "9. MEMORY: Reference previous drugs/symptoms/allergies mentioned in conversation\n\n"
+         "Query: {query_str}\n\n"
+         "Answer (in same language as query):"
+     )
+
+     qa_template = PromptTemplate(qa_template_str)
+
+     print("Building query engine...")
+
+     memory = ChatMemoryBuffer.from_defaults(token_limit=3000)
+
+     # Use simple chat mode to avoid condense_question_prompt issues
+     # The chat mode will still maintain conversation history through memory
+     query_engine = index.as_chat_engine(
+         chat_mode="context",  # Changed from "condense_question" to "context"
+         memory=memory,
+         system_prompt=(
+             "You are PharmaBot, an AI pharmaceutical information assistant. "
+             "Always respond in the user's language. Use FDA drug label data to answer medical queries. "
+             "Never diagnose or prescribe. Include disclaimers on medical responses."
+         ),
+         context_template=qa_template,  # Use our custom template
+         similarity_top_k=5,
+         verbose=True
+     )
+
+     return query_engine
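The three functions above are consumed by app.py, but they can also be exercised from a plain terminal. A minimal sketch, assuming the vector store has been built and the API keys are loaded from .env:

# chat_cli.py (hypothetical, not part of this commit)
from dotenv import load_dotenv
import rag_pipeline

load_dotenv()
rag_pipeline.initialize_llm_and_embed_model()
engine = rag_pipeline.build_query_engine(rag_pipeline.load_vector_index())
while True:
    question = input("You: ")
    if question.lower() in {"exit", "quit"}:
        break
    print("PharmaBot:", engine.chat(question))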
requirements.txt ADDED
@@ -0,0 +1,130 @@
+ aiohappyeyeballs==2.6.1
+ aiohttp==3.13.1
+ aiosignal==1.4.0
+ aiosqlite==0.21.0
+ altair==5.5.0
+ annotated-types==0.7.0
+ anyio==4.11.0
+ attrs==25.4.0
+ banks==2.2.0
+ beautifulsoup4==4.14.2
+ blinker==1.9.0
+ cachetools==6.2.1
+ certifi==2025.10.5
+ charset-normalizer==3.4.4
+ click==8.3.0
+ colorama==0.4.6
+ dataclasses-json==0.6.7
+ defusedxml==0.7.1
+ Deprecated==1.2.18
+ dirtyjson==1.0.8
+ distro==1.9.0
+ filelock==3.20.0
+ filetype==1.2.0
+ frozenlist==1.8.0
+ fsspec==2025.9.0
+ gitdb==4.0.12
+ GitPython==3.1.45
+ google-ai-generativelanguage==0.6.15
+ google-api-core==2.26.0
+ google-api-python-client==2.185.0
+ google-auth==2.41.1
+ google-auth-httplib2==0.2.0
+ google-generativeai==0.8.5
+ googleapis-common-protos==1.70.0
+ greenlet==3.2.4
+ griffe==1.14.0
+ grpcio==1.75.1
+ grpcio-status==1.71.2
+ h11==0.16.0
+ hf-xet==1.1.10
+ httpcore==1.0.9
+ httplib2==0.31.0
+ httpx==0.28.1
+ huggingface-hub==0.35.3
+ idna==3.11
+ Jinja2==3.1.6
+ jiter==0.11.1
+ joblib==1.5.2
+ jsonschema==4.25.1
+ jsonschema-specifications==2025.9.1
+ llama-cloud==0.1.35
+ llama-cloud-services==0.6.54
+ llama-index==0.14.5
+ llama-index-cli==0.5.3
+ llama-index-core==0.14.5
+ llama-index-embeddings-huggingface==0.6.1
+ llama-index-embeddings-openai==0.5.1
+ llama-index-indices-managed-llama-cloud==0.9.4
+ llama-index-instrumentation==0.4.2
+ llama-index-llms-gemini==0.6.1
+ llama-index-llms-openai==0.6.5
+ llama-index-readers-file==0.5.4
+ llama-index-readers-llama-parse==0.5.1
+ llama-index-workflows==2.8.3
+ llama-parse==0.6.54
+ MarkupSafe==3.0.3
+ marshmallow==3.26.1
+ mpmath==1.3.0
+ multidict==6.7.0
+ mypy_extensions==1.1.0
+ narwhals==2.9.0
+ nest-asyncio==1.6.0
+ networkx==3.5
+ nltk==3.9.2
+ numpy==2.3.4
+ openai==1.109.1
+ packaging==25.0
+ pandas==2.2.3
+ pillow==10.4.0
+ platformdirs==4.5.0
+ propcache==0.4.1
+ proto-plus==1.26.1
+ protobuf==5.29.5
+ pyarrow==21.0.0
+ pyasn1==0.6.1
+ pyasn1_modules==0.4.2
+ pydantic==2.12.3
+ pydantic_core==2.41.4
+ pydeck==0.9.1
+ pyparsing==3.2.5
+ pypdf==6.1.2
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.1.1
+ pytz==2025.2
+ PyYAML==6.0.3
+ referencing==0.37.0
+ regex==2025.9.18
+ requests==2.32.5
+ rpds-py==0.27.1
+ rsa==4.9.1
+ safetensors==0.6.2
+ scikit-learn==1.7.2
+ scipy==1.16.2
+ sentence-transformers==5.1.1
+ setuptools==80.9.0
+ six==1.17.0
+ smmap==5.0.2
+ sniffio==1.3.1
+ soupsieve==2.8
+ SQLAlchemy==2.0.44
+ streamlit==1.50.0
+ striprtf==0.0.26
+ sympy==1.14.0
+ tenacity==9.1.2
+ threadpoolctl==3.6.0
+ tiktoken==0.12.0
+ tokenizers==0.22.1
+ toml==0.10.2
+ torch==2.9.0
+ tornado==6.5.2
+ tqdm==4.67.1
+ transformers==4.57.1
+ typing-inspect==0.9.0
+ typing-inspection==0.4.2
+ typing_extensions==4.15.0
+ tzdata==2025.2
+ uritemplate==4.2.0
+ urllib3==2.5.0
+ wrapt==1.17.3
+ yarl==1.22.0
vector_store_manager.py ADDED
@@ -0,0 +1,29 @@
+ # =================================================================================
+ # vector_store_manager.py: Management of the FAISS vector database (legacy LangChain path, not used by the LlamaIndex pipeline)
+ # =================================================================================
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ import config
+
+ def get_embeddings_model(model_name=config.EMBEDDING_MODEL_NAME):
+     """Loads and returns the embedding model."""
+     print(f"Loading embedding model: {model_name}...")
+     return HuggingFaceEmbeddings(model_name=model_name)
+
+ def create_and_save_store(documents, embeddings, save_path=config.VECTOR_STORE_PATH):
+     """
+     Creates a FAISS vector database from the given documents and saves it to disk.
+     """
+     print("Creating and saving the FAISS vector store...")
+     vector_store = FAISS.from_documents(documents, embeddings)
+     vector_store.save_local(save_path)
+     print(f"✅ Vector store successfully saved to '{save_path}'.")
+
+ def load_store(embeddings, load_path=config.VECTOR_STORE_PATH):
+     """
+     Loads the FAISS vector database from a local path.
+     """
+     print(f"Loading vector store from: {load_path}...")
+     # The allow_dangerous_deserialization flag is required for loading FAISS indexes with LangChain.
+     return FAISS.load_local(load_path, embeddings, allow_dangerous_deserialization=True)
+
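If this legacy FAISS path is ever exercised, note that langchain-community and faiss-cpu are not listed in requirements.txt and would need to be installed separately, and that FAISS.from_documents expects LangChain Document objects rather than the LlamaIndex Documents produced by data_processing.py. A minimal sketch under those assumptions:

# faiss_demo.py (hypothetical, not part of this commit; requires langchain-community and faiss-cpu)
from langchain_core.documents import Document
import vector_store_manager as vsm

embeddings = vsm.get_embeddings_model()
docs = [Document(page_content="IBUPROFEN - Indications and Usage: ...", metadata={"generic_name": "IBUPROFEN"})]
vsm.create_and_save_store(docs, embeddings)
store = vsm.load_store(embeddings)
print(store.similarity_search("ibuprofen", k=1))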