Sami Ali committed
Commit b098829 · 1 Parent(s): cacd7d3

feat: improve model and prompt

.gitignore CHANGED
@@ -7,6 +7,8 @@ __pycache__/
 data
 demo
 
+test.ipynb
+
 # C extensions
 *.so
 
src/data_processor.py CHANGED
@@ -1,3 +1,4 @@
+import re
 import os
 from src.constant import BASE_DIR
 from langchain.schema import Document
@@ -8,36 +9,31 @@ DATA_DIR = os.path.join(BASE_DIR, "data", "pmc")
 
 class DataProcessor:
     """
-    Handles loading, cleaning, and chunking of text files
+    Handles loading, cleaning, and chunking of text files
     from the PubMed Central (PMC) dataset.
     """
 
-    def __init__(self, data_path: str = DATA_DIR, limit=100):
+    def __init__(self, data_path: str = DATA_DIR):
         self.data_path = data_path
-        self.limit = limit
 
     def _load_files(self) -> list[dict]:
        """
         Load raw text files from the dataset directory.
         Returns a list of dictionaries with file name and raw content.
         """
-        count = 0
         data_list = []
         for file_name in os.listdir(self.data_path):
             if not file_name.endswith(".txt"):
                 continue
             file_path = os.path.join(self.data_path, file_name)
-            with open(file_path, "r", encoding="utf-8") as file_ref:
+            with open(file_path, "r", encoding="utf-8", errors="replace") as file_ref:
                 data_list.append(
                     {
                         "file_name": file_name,
                         "page_content": file_ref.read()
                     }
                 )
-            if count >= self.limit:
-                break
-            count += 1
-
+
         return data_list
 
     @staticmethod
@@ -48,10 +44,103 @@ class DataProcessor:
         if not isinstance(text, str):
             return text
         try:
-            return text.encode("utf-8").decode("unicode-escape")
+            return text.encode("utf-8", "ignore").decode("utf-8", "ignore")
         except Exception:
             return text
 
+    def _extract_body(self, text: str) -> str:
+
+        if not text:
+            return ""
+
+        text = text.replace("\r\n", "\n").replace("\r", "\n")
+        text = re.sub(r'-\n', '', text)
+        text = re.sub(r'\n{3,}', '\n\n', text)
+
+        start_patterns = [
+            r"====\s*Body", r"^Body\s*$", r"^BODY\s*$",
+            r"^Abstract\s*$", r"^ABSTRACT\s*$", r"^Introduction\s*$", r"^INTRODUCTION\s*$"
+        ]
+        end_patterns = [
+            r"====\s*Back", r"^Back\s*$", r"^BACK\s*$",
+            r"^References\s*$", r"^REFERENCES\s*$", r"^Bibliography\s*$",
+            r"^Acknowledg", r"^Acknowledgments\s*$", r"^ACKNOWLEDGMENTS\s*$"
+        ]
+
+        start_idx = None
+        for pat in start_patterns:
+            m = re.search(pat, text, flags=re.IGNORECASE | re.MULTILINE)
+            if m:
+                start_idx = m.end()
+                break
+
+        if start_idx is not None:
+            # find end after start_idx
+            end_idx = None
+            for pat in end_patterns:
+                m = re.search(pat, text[start_idx:], flags=re.IGNORECASE | re.MULTILINE)
+                if m:
+                    end_idx = start_idx + m.start()
+                    break
+            body = text[start_idx:end_idx] if end_idx else text[start_idx:]
+
+        else:
+            paragraphs = re.split(r'\n{2,}', text)
+            paragraphs = [p.strip() for p in paragraphs if p.strip()]
+
+            def is_metadata_para(p: str) -> bool:
+                # DOI / arXiv / ISSN / PMCID / PMID
+                if re.search(r'\b10\.\d{4,9}/\S+\b', p):
+                    return True
+                if re.search(r'\bPMCID\b|\bPMID\b', p, re.I):
+                    return True
+                if re.search(r'ISSN[:\s]', p, re.I):
+                    return True
+                # common metadata keywords
+                if re.search(r'Correspondence:|Affiliat|Author|ORCID|E-mail:|Contact:', p, re.I):
+                    return True
+                if re.search(r'©|license|all rights reserved|Published|Received|Accepted', p, re.I):
+                    return True
+
+                words = p.split()
+                if len(p) < 200 and len(words) <= 12 and sum(1 for w in words if w.isupper()) / max(1, len(words)) > 0.6:
+                    return True
+                return False
+
+            wc = [len(p.split()) for p in paragraphs]
+            good = [(wc_i >= 40 and not is_metadata_para(p)) for p, wc_i in zip(paragraphs, wc)]
+
+            best_start = best_len = 0
+            cur_start = cur_len = 0
+            for i, g in enumerate(good):
+                if g:
+                    if cur_len == 0:
+                        cur_start = i
+                    cur_len += 1
+                    if cur_len > best_len:
+                        best_len = cur_len
+                        best_start = cur_start
+                else:
+                    cur_len = 0
+            if best_len > 0:
+                body = "\n\n".join(paragraphs[best_start: best_start + best_len])
+            else:
+                # final fallback: pick the top N paragraphs by length (they likely contain body content)
+                top_idxs = sorted(range(len(paragraphs)), key=lambda i: wc[i], reverse=True)[:5]
+                top_idxs.sort()
+                body = "\n\n".join(paragraphs[i] for i in top_idxs)
+
+        body = re.sub(r'\n{2,}References[\s\S]*$', '', body, flags=re.IGNORECASE)
+        body = re.sub(r'\n{2,}Bibliography[\s\S]*$', '', body, flags=re.IGNORECASE)
+        body = re.sub(r'\n{2,}Acknowledg[\s\S]*$', '', body, flags=re.IGNORECASE)
+
+        # Clean junk: remove URLs/emails, collapse whitespace
+        body = re.sub(r'https?://\S+', ' ', body)
+        body = re.sub(r'\S+@\S+', ' ', body)
+        body = re.sub(r'\s+', ' ', body).strip()
+
+        return body
+
     def _preprocess(self, data: list[dict]) -> list[dict]:
         """
         Apply preprocessing steps (e.g., unicode decoding) to raw data.
@@ -59,17 +148,18 @@ class DataProcessor:
         cleaned_data = []
         for record in data:
             decoded_text = self._decode_unicode(record["page_content"])
+            main_body = self._extract_body(decoded_text)
             cleaned_data.append(
                 {
                     "file_name": record["file_name"],
-                    "page_content": decoded_text
+                    "page_content": main_body
                 }
            )
         return cleaned_data
 
     def load_documents(self) -> list[Document]:
         """
-        Load and preprocess text files, converting them into
+        Load and preprocess text files, converting them into
         LangChain Document objects.
         """
         raw_data = self._load_files()
@@ -106,4 +196,4 @@ class DataProcessor:
         """
         documents = self.load_documents()
         chunks = self.chunk_documents(documents)
-        return chunks, documents
+        return chunks, documents
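For context, a minimal sketch of what the new body-extraction path does, assuming the repository layout above; the sample text below is hypothetical, and relies on the "==== Body" / "==== Back" markers that the patterns in _extract_body look for:

from src.data_processor import DataProcessor

# Hypothetical PMC-style file content with explicit Body/Back markers.
sample = (
    "Journal of Examples, PMID 123456, Correspondence: someone@example.org\n\n"
    "==== Body\n"
    "Aspirin irreversibly inhibits cyclooxygenase-1 in platelets...\n\n"
    "==== Back\n"
    "References\n1. A citation.\n"
)

dp = DataProcessor()               # data_path defaults to DATA_DIR; the directory is not read here
print(dp._extract_body(sample))    # -> "Aspirin irreversibly inhibits cyclooxygenase-1 in platelets..."

# Full pipeline (requires the downloaded .txt files under the data directory):
# chunks, documents = dp.build()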
src/download_data.py CHANGED
@@ -17,6 +17,9 @@ def download_pmc_docs(
     target_dir=TARGET_DIR,
     limit=1000
 ):
+    if os.path.isdir(target_dir) and len(os.listdir(target_dir)) > 0:
+        return
+
     os.makedirs(target_dir, exist_ok=True)
 
     s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED))
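For reference, a minimal invocation sketch, assuming TARGET_DIR in src/download_data.py points at the local data directory and the unsigned S3 client above can reach the public PMC bucket:

from src.download_data import download_pmc_docs

# Pulls up to `limit` PMC .txt files into TARGET_DIR;
# with the new guard it returns immediately if files are already present.
download_pmc_docs(limit=1000)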
src/embedding.py CHANGED
@@ -1,17 +1,20 @@
-from typing import List
 import numpy as np
-from langchain_huggingface import HuggingFaceEmbeddings
+import torch
+from typing import List
+from sentence_transformers import SentenceTransformer
 from tqdm import tqdm
 
 class EmbeddingManager:
     def __init__(self, model_name: str = "pritamdeka/S-BioBERT-snli-multinli-stsb"):
         self.model_name = model_name
         self.model = None
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
         self.load_model()
 
     def load_model(self):
         print("Loading embedding model:", self.model_name)
-        self.model = HuggingFaceEmbeddings(model_name=self.model_name)
+        print('Using device', self.device)
+        self.model = SentenceTransformer(self.model_name, device=self.device)
         print("Model loaded.")
 
     def get_model(self):
@@ -24,10 +27,9 @@ class EmbeddingManager:
         embeddings = []
         for i in tqdm(range(0, len(texts), batch_size), desc="Embedding texts"):
             batch = texts[i:i + batch_size]
-            emb = self.model.embed_documents(batch)
+            emb = self.model.encode(batch, batch_size=batch_size, show_progress_bar=False, convert_to_numpy=True, normalize_embeddings=True)
             embeddings.extend(emb)
+        return np.vstack(embeddings)
 
-        return np.array(embeddings)
-
-    def embed_one(self, text: str) -> np.ndarray:
-        return self.model.embed_query(text)
+    def embed_query(self, text: str) -> np.ndarray:
+        return self.model.encode(text, convert_to_numpy=True, normalize_embeddings=True).flatten()
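A short usage sketch of the reworked embedder. The batch method's name is not visible in this hunk, so `embed_texts` below is an assumed stand-in for it; sentence-transformers must be installed and the model is fetched from the Hub on first use. Because embeddings are normalized, cosine similarity reduces to a dot product:

from src.embedding import EmbeddingManager

em = EmbeddingManager()   # picks CUDA when available, otherwise CPU

# `embed_texts` stands in for the batch method whose body is patched above.
doc_vecs = em.embed_texts([
    "Aspirin inhibits platelet aggregation.",
    "Metformin reduces hepatic glucose output.",
])
query_vec = em.embed_query("How does aspirin affect platelets?")

# normalize_embeddings=True makes cosine similarity equal to the dot product.
scores = doc_vecs @ query_vec
print(scores.argmax())    # index of the best-matching text (expected: 0)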
 
 
src/prompt.py CHANGED
@@ -1,22 +1,16 @@
 from langchain_core.prompts import PromptTemplate
 
-BIOMED_PROMPT = """
-You are MedRAG, an assistant specialized in biomedical research.
-Your job is to answer the question using ONLY the provided context.
+BIOMED_PROMPT = """You are a scholarly assistant analyzing historical medical texts.
+Use the retrieved documents to answer the user's question as completely as possible.
+If the context implies but does not explicitly state a detail, you may infer it cautiously.
 
-If the answer cannot be found in the context, say clearly:
-"I could not find an exact answer in the provided research papers."
-
-Always cite the source PMC IDs in your answer.
+Context:
+{context}
 
 Question:
 {question}
 
-Context:
-{context}
-
-Answer:
-"""
+Answer in a clear, factual summary style."""
 
 prompt = PromptTemplate(
     template=BIOMED_PROMPT,
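To illustrate how the new template is filled: the PromptTemplate constructor is truncated in this hunk, so input_variables=["context", "question"] is assumed, and the context string below is invented for the example:

from src.prompt import prompt

filled = prompt.format(
    context="PMC1234567: Aspirin irreversibly acetylates COX-1 in platelets ...",
    question="How does aspirin affect platelet function?",
)
print(filled)   # the rendered prompt string handed to the model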
streamlit_app.py CHANGED
@@ -9,7 +9,7 @@ import streamlit as st
 
 @st.cache_resource(show_spinner="🔄 Building pipeline...")
 def load_pipeline():
-    limit = 1000
+    limit = 2000
     download_pmc_docs(limit=limit)
     dp = DataProcessor()
     chunks, document = dp.build()