ssrogue commited on
Commit
b1e8fe0
Β·
verified Β·
1 Parent(s): 8f160bc

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/Fin_ExBERT_data.xlsx filter=lfs diff=lfs merge=lfs -text
37
+ data/Fin_ExBERT_test_set.xlsx filter=lfs diff=lfs merge=lfs -text
38
+ data/Fin_ExBERT_train_val_data.xlsx filter=lfs diff=lfs merge=lfs -text
39
+ images/methodology_flowchart.png filter=lfs diff=lfs merge=lfs -text
40
+ results/combined_results.xlsx filter=lfs diff=lfs merge=lfs -text
41
+ results/Fin-ExBERT.pptx filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Soumick Sarker
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,196 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FinExBERT: Financial Sentence Extraction with Graph-Augmented BERT
2
+
3
+ [![Python 3.10+](https://img.shields.io/badge/python-3.10+-green.svg)](https://www.python.org/downloads/)
4
+ [![PyTorch](https://img.shields.io/badge/PyTorch-1.9+-red.svg)](https://pytorch.org/)
5
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
6
+ [![arXiv](https://img.shields.io/badge/arXiv-2025.23259-b31b1b.svg)]([https://arxiv.org/](https://www.arxiv.org/abs/2509.23259))
7
+
8
+ > A state-of-the-art neural architecture for extracting relevant sentences from financial conversations using graph-augmented BERT with dependency parsing.
9
+
10
+ **Accepted at EMNLP 2025 Industry Track**
11
+
12
+ ## Overview
13
+
14
+ FinExBERT combines BERT's contextual understanding with graph neural networks to capture syntactic dependencies in financial conversations. The model achieves superior performance in extracting relevant sentences based on user intent, making it particularly effective for financial customer service applications.
15
+
16
+ ### Problem Statement
17
+
18
+ Traditional sequence-to-sequence models struggle with:
19
+ - Complex financial terminology and context
20
+ - Long conversation dependencies
21
+ - Intent-based sentence extraction
22
+ - Domain-specific reasoning requirements
23
+
24
+ ### Our Solution
25
+
26
+ FinExBERT addresses these challenges through:
27
+ - **Graph-Augmented Architecture**: Incorporates dependency parsing graphs to capture syntactic relationships
28
+ - **Financial Domain Adaptation**: LoRA fine-tuning on financial datasets
29
+ - **Intent-Aware Extraction**: Semantic similarity matching for targeted sentence selection
30
+ - **Efficient Training**: Mixed precision training with gradient accumulation
31
+
32
+ ## Key Features
33
+
34
+ - πŸ† **State-of-the-art Performance**: Outperforms baseline BERT by 37% in accuracy on financial conversation tasks
35
+ - 🧠 **Graph Neural Networks**: Integrates dependency parsing for enhanced linguistic understanding
36
+ - πŸ’° **Financial Domain Expertise**: Pre-trained on financial conversation data
37
+ - ⚑ **Production Ready**: Optimized for real-world deployment with batched inference
38
+ - πŸ”§ **Flexible Architecture**: Configurable model components for different use cases
39
+ - πŸ“Š **Comprehensive Evaluation**: Extensive ablation studies and evaluation metrics
40
+
41
+ ## Installation
42
+
43
+ ### Prerequisites
44
+
45
+ - Python 3.10 or higher
46
+ - PyTorch 1.9 or higher
47
+ - CUDA 11.0+ (for GPU acceleration)
48
+
49
+
50
+ ### Install dependencies
51
+
52
+ ```bash
53
+ git clone https://github.com/soumick1/Fin-ExBERT.git
54
+ pip install -r requirements.txt
55
+ ```
56
+
57
+ ## Quick Start
58
+
59
+ ### Download the model weights
60
+
61
+ Download the weights from the [Weights Link](https://drive.google.com/drive/folders/1jm3Yxpew8Y8mVsRizTyVvXKrGBXQ3ApI?usp=sharing)
62
+ And put the 3 folders inside the cloned directory.
63
+
64
+ ### Data setup
65
+
66
+ The CreditCall12H Dataset is available in the 'data' folder. If you want to train or test on your own data please use the same format.
67
+
68
+ ### Basic Usage and Testing
69
+
70
+ ```python
71
+ from utils import batch_predict_and_save
72
+ from config import *
73
+ from preprocess_data import SentenceDataset
74
+ from models import SentenceExtractionModel
75
+
76
+ # Initialize the model
77
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") ### You can change the tokenizer if you want
78
+ dataset = SentenceDataset("data/Fin_ExBERT_train_val_data.xlsx", tokenizer)
79
+
80
+ model = SentenceExtractionModel(
81
+ base_model_name=MODEL_NAME,
82
+ backbone='finexbert'
83
+ )
84
+
85
+ # Extract relevant sentences
86
+ batch_predict_and_save(
87
+ model,
88
+ tokenizer,
89
+ excel_path="data/Fin_ExBERT_test_set.xlsx",
90
+ ckpt_path="checkpoints/sentence_extractor/best_model.pth",
91
+ output_path="results/predictions_sample200.xlsx",
92
+ n_samples=200,
93
+ temperature=1.0,
94
+ device="cuda"
95
+ )
96
+ ```
97
+
98
+ ### Training the model
99
+
100
+ ```python
101
+ from utils import train_model_with_chkpt
102
+ from config import *
103
+ from preprocess_data import SentenceDataset
104
+ from models import SentenceExtractionModel
105
+
106
+ # Initialize the model
107
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") ### You can change the tokenizer if you want
108
+ dataset = SentenceDataset("data/Fin_ExBERT_train_val_data.xlsx", tokenizer)
109
+
110
+ model = SentenceExtractionModel(
111
+ base_model_name=MODEL_NAME,
112
+ backbone='finexbert'
113
+ )
114
+
115
+ train_sentence_extractor(
116
+ model,
117
+ dataset,
118
+ output_dir="checkpoints/sentence_extractor",
119
+ val_split=0.3,
120
+ epochs=10,
121
+ batch_size=16,
122
+ lr=3e-4,
123
+ device=DEVICE,
124
+ unfreeze_after_epoch=4
125
+ )
126
+ ```
127
+
128
+ ## Model Architecture
129
+
130
+ ![FinExBERT Architecture](images/methodology_flowchart.png)
131
+
132
+ ### Core Components
133
+
134
+ 1. **BERT Encoder**: Contextual embeddings for input sequences
135
+ 2. **Dependency Graph Parser**: SpaCy-based syntactic analysis
136
+ 3. **Graph Neural Network**: Message passing over dependency graphs
137
+ 4. **Fusion Layer**: Combines BERT and GNN representations
138
+ 5. **Classification Head**: Intent-aware sentence scoring
139
+
140
+ ### Technical Details
141
+
142
+ - **Base Model**: BERT-base-uncased (110M parameters)
143
+ - **GNN Architecture**: Simple message passing with attention
144
+ - **Training Strategy**: LoRA adaptation + full fine-tuning
145
+
146
+
147
+ ## Evaluation
148
+
149
+ ### Ablation Studies
150
+
151
+ We provide comprehensive ablation studies comparing:
152
+
153
+ - Baseline BERT vs. Graph-Augmented BERT
154
+ - Different GNN architectures
155
+ - Various training strategies
156
+ - Domain adaptation techniques
157
+
158
+
159
+ ### Performance Metrics
160
+
161
+ | Model | Accuracy | F1-Score | Precision | Recall |
162
+ |-------|----------|----------|-----------|--------|
163
+ | BERT Baseline | 0.323 | 0.163 | 0.145 | 0.189 |
164
+ | FinExBERT | 0.694 | 0.418 | 0.456 | 0.391 |
165
+ | **Improvement** | **+37%** | **+26%** | **+31%** | **+20%** |
166
+
167
+
168
+
169
+ ## Citation
170
+
171
+ If you use FinExBERT in your research, please cite:
172
+
173
+ ```bibtex
174
+ Will post it soon
175
+ ```
176
+
177
+
178
+ ## License
179
+
180
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
181
+
182
+ ## Acknowledgments
183
+
184
+ - Built on top of [Transformers](https://github.com/huggingface/transformers) by Hugging Face
185
+ - Graph processing with [SpaCy](https://spacy.io/)
186
+ - Training infrastructure powered by [PyTorch](https://pytorch.org/)
187
+
188
+ ## Support
189
+
190
+ - πŸ“§ Email: [email protected]
191
+
192
+ ---
193
+
194
+ <div align="center">
195
+ <strong>FinExBERT</strong> - Advancing Financial NLP with Graph-Augmented Models
196
+ </div>
__init__.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import sys
3
+
4
+ def install_without_version(package_name):
5
+ package_name = str(package_name)
6
+ try:
7
+ __import__(package_name)
8
+ print(f"{package_name} is already installed.")
9
+ except ImportError:
10
+ print(f"{package_name} is not installed. Installing now...")
11
+ subprocess.check_call([sys.executable, "-m", "pip", "install", '-U', package_name])
12
+
13
+
14
+ def install_with_version(package_name, package_version):
15
+ package_name = str(package_name)
16
+ package_version = str(package_version)
17
+ try:
18
+ pkg = __import__(package_name)
19
+ installed_version = pkg.__version__
20
+ if installed_version == package_version:
21
+ print(f"{package_name} {package_version} is already installed.")
22
+ else:
23
+ print(f"{package_name} {installed_version} is installed, but {package_version} is required. Updating now...")
24
+ subprocess.check_call([sys.executable, "-m", "pip", "install", f"{package_name}=={package_version}"])
25
+ except ImportError:
26
+ print(f"Installing version {package_version} now...")
27
+ subprocess.check_call([sys.executable, "-m", "pip", "install", f"{package_name}=={package_version}"])
28
+
29
+ if __name__=='__main__':
30
+ packages = [['torch', ''], ['datasets', ''], ['spacy', ''], ['networkx', ''], ['numpy', '1.26.4']]
31
+ for package in packages:
32
+ if package[1] == '':
33
+ install_without_version(package[0])
34
+ else:
35
+ install_with_version(package[0], package[1])
36
+
37
+ subprocess.check_call([sys.executable, "python", "-m", "spacy", "download", "en_core_web_sm"])
38
+
ablation_and_evaluation/ablation_studies.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import logging
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.utils.data import DataLoader
8
+ from tqdm.auto import tqdm
9
+ import matplotlib.pyplot as plt
10
+ from sklearn.metrics import accuracy_score, f1_score
11
+ from torch.utils.data import DataLoader, Subset
12
+ from datasets import load_from_disk
13
+ from utils import my_collate_fn
14
+
15
+ from config import MODEL_NAME, PREPROCESSED_DIR, DEVICE
16
+ from preprocess_data import process_data, SpanExtractionChunkedDataset, span_collate_fn
17
+ from models import GraphAugmentedNLIModel
18
+ from transformers import AutoConfig, AutoModel
19
+
20
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
21
+
22
+ # ---------------------
23
+ # 1) Define a BERT‐only baseline
24
+ # ---------------------
25
+ class BertOnlyNLIModel(nn.Module):
26
+ def __init__(self, base_model_name: str, num_labels: int = 3):
27
+ super().__init__()
28
+ config = AutoConfig.from_pretrained(base_model_name)
29
+ config.num_labels = num_labels
30
+ self.bert = AutoModel.from_pretrained(base_model_name, config=config)
31
+ hidden_dim = config.hidden_size
32
+ self.dropout = nn.Dropout(0.1)
33
+ self.classifier = nn.Linear(hidden_dim, num_labels)
34
+
35
+ def forward(self, input_ids, attention_mask, labels=None):
36
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
37
+ cls_emb = outputs.last_hidden_state[:, 0, :]
38
+ x = self.dropout(cls_emb)
39
+ logits = self.classifier(x)
40
+ loss = None
41
+ if labels is not None:
42
+ loss_fn = nn.CrossEntropyLoss()
43
+ loss = loss_fn(logits, labels)
44
+ return {"loss": loss, "logits": logits}
45
+
46
+ # ---------------------
47
+ # 2) Training & evaluation routines
48
+ # ---------------------
49
+ def set_seed(seed=42):
50
+ random.seed(seed)
51
+ np.random.seed(seed)
52
+ torch.manual_seed(seed)
53
+ if torch.cuda.is_available():
54
+ torch.cuda.manual_seed_all(seed)
55
+
56
+ def train_one_epoch(model, loader, optimizer, scheduler):
57
+ model.train()
58
+ losses = []
59
+ is_gnn = hasattr(model, "gnn_premise") # True for GraphAugmentedNLIModel
60
+
61
+ for batch in tqdm(loader, leave=False):
62
+ optimizer.zero_grad()
63
+ # Move all tensor fields to DEVICE
64
+ batch = {
65
+ k: v.to(DEVICE) if torch.is_tensor(v) else v
66
+ for k, v in batch.items()
67
+ }
68
+
69
+ if is_gnn:
70
+ out = model(
71
+ input_ids=batch["input_ids"],
72
+ attention_mask=batch["attention_mask"],
73
+ premise_graph_tokens=batch["premise_graph_tokens"],
74
+ premise_graph_edges=batch["premise_graph_edges"],
75
+ premise_node_indices=batch["premise_node_indices"],
76
+ hypothesis_graph_tokens=batch["hypothesis_graph_tokens"],
77
+ hypothesis_graph_edges=batch["hypothesis_graph_edges"],
78
+ hypothesis_node_indices=batch["hypothesis_node_indices"],
79
+ labels=batch["labels"],
80
+ )
81
+ else:
82
+ out = model(
83
+ input_ids=batch["input_ids"],
84
+ attention_mask=batch["attention_mask"],
85
+ labels=batch["labels"],
86
+ )
87
+
88
+ loss = out["loss"]
89
+ loss.backward()
90
+ optimizer.step()
91
+ scheduler.step()
92
+ losses.append(loss.item())
93
+
94
+ return float(np.mean(losses))
95
+
96
+
97
+ @torch.no_grad()
98
+ def evaluate(model, loader):
99
+ model.eval()
100
+ preds, golds = [], []
101
+ is_gnn = hasattr(model, "gnn_premise")
102
+
103
+ for batch in loader:
104
+ batch = {
105
+ k: v.to(DEVICE) if torch.is_tensor(v) else v
106
+ for k, v in batch.items()
107
+ }
108
+
109
+ if is_gnn:
110
+ out = model(
111
+ input_ids=batch["input_ids"],
112
+ attention_mask=batch["attention_mask"],
113
+ premise_graph_tokens=batch["premise_graph_tokens"],
114
+ premise_graph_edges=batch["premise_graph_edges"],
115
+ premise_node_indices=batch["premise_node_indices"],
116
+ hypothesis_graph_tokens=batch["hypothesis_graph_tokens"],
117
+ hypothesis_graph_edges=batch["hypothesis_graph_edges"],
118
+ hypothesis_node_indices=batch["hypothesis_node_indices"],
119
+ )
120
+ else:
121
+ out = model(
122
+ input_ids=batch["input_ids"],
123
+ attention_mask=batch["attention_mask"],
124
+ )
125
+
126
+ logits = out["logits"].cpu().numpy()
127
+ preds.extend(np.argmax(logits, axis=1).tolist())
128
+ golds.extend(batch["labels"].cpu().tolist())
129
+
130
+ acc = accuracy_score(golds, preds)
131
+ f1 = f1_score(golds, preds, average="macro")
132
+ return acc, f1
133
+
134
+
135
+
136
+ # ---------------------
137
+ # 3) Ablation runner
138
+ # ---------------------
139
+ def run_ablation(
140
+ epochs=3,
141
+ batch_size=16,
142
+ lr=2e-5,
143
+ sample_frac=0.05, # ← fraction of data to use
144
+ ):
145
+ set_seed()
146
+ process_data()
147
+
148
+ # --- Load the preprocessed SNLI dataset from disk ---
149
+ snli = load_from_disk(PREPROCESSED_DIR)
150
+ full_train = snli["train"]
151
+ full_val = snli["validation"]
152
+
153
+ # --- Sample 10% of each split ---
154
+ num_train = len(full_train)
155
+ num_val = len(full_val)
156
+ n_train = max(1, int(sample_frac * num_train))
157
+ n_val = max(1, int(sample_frac * num_val))
158
+
159
+ # reproducible shuffling
160
+ train_indices = list(range(num_train))
161
+ random.shuffle(train_indices)
162
+ train_subset = Subset(full_train, train_indices[:n_train])
163
+
164
+ val_indices = list(range(num_val))
165
+ random.shuffle(val_indices)
166
+ val_subset = Subset(full_val, val_indices[:n_val])
167
+
168
+ # --- Build DataLoaders with the SNLI collate_fn ---
169
+ train_loader = DataLoader(
170
+ train_subset,
171
+ batch_size=batch_size,
172
+ shuffle=True,
173
+ collate_fn=my_collate_fn,
174
+ num_workers=4,
175
+ pin_memory=True,
176
+ )
177
+ val_loader = DataLoader(
178
+ val_subset,
179
+ batch_size=batch_size,
180
+ shuffle=False,
181
+ collate_fn=my_collate_fn,
182
+ num_workers=2,
183
+ pin_memory=True,
184
+ )
185
+
186
+ # 4) Define models
187
+ models = {
188
+ "Baseline-BERT": BertOnlyNLIModel(MODEL_NAME).to(DEVICE),
189
+ "GNN-Augmented": GraphAugmentedNLIModel(
190
+ base_model_name=MODEL_NAME,
191
+ num_labels=3,
192
+ hidden_dim=768,
193
+ gnn_dim=128
194
+ ).to(DEVICE),
195
+ }
196
+
197
+ results = {}
198
+ for name, model in models.items():
199
+ logging.info(f"--- Training {name} on {sample_frac*100:.0f}% of data ---")
200
+ optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
201
+ total_steps = epochs * len(train_loader)
202
+ scheduler = torch.optim.lr_scheduler.LinearLR(
203
+ optimizer,
204
+ start_factor=0.1,
205
+ total_iters=total_steps
206
+ )
207
+
208
+ # training loop
209
+ for epoch in range(1, epochs + 1):
210
+ train_loss = train_one_epoch(model, train_loader, optimizer, scheduler)
211
+ logging.info(f"{name} Epoch {epoch}: train_loss={train_loss:.4f}")
212
+
213
+ # evaluation
214
+ acc, f1 = evaluate(model, val_loader)
215
+ logging.info(f"{name} on {sample_frac*100:.0f}% val β†’ acc={acc:.4f}, f1={f1:.4f}")
216
+ results[name] = {"accuracy": acc, "f1": f1}
217
+
218
+ # 5) Plot
219
+ names = list(results.keys())
220
+ accs = [results[n]["accuracy"] for n in names]
221
+ f1s = [results[n]["f1"] for n in names]
222
+
223
+ plt.figure()
224
+ plt.bar(names, accs)
225
+ plt.xlabel("Model")
226
+ plt.ylabel("Validation Accuracy")
227
+ plt.title(f"Ablation on {sample_frac*100:.0f}% Data: Accuracy")
228
+
229
+ plt.figure()
230
+ plt.bar(names, f1s)
231
+ plt.xlabel("Model")
232
+ plt.ylabel("Validation Macro-F1")
233
+ plt.title(f"Ablation on {sample_frac*100:.0f}% Data: Macro-F1")
234
+
235
+ plt.show()
236
+
237
+
238
+ if __name__ == "__main__":
239
+ run_ablation(epochs=5, batch_size=16, lr=5e-3, sample_frac=0.1)
ablation_and_evaluation/eval2.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, re, numpy as np, pandas as pd
2
+ from tqdm.auto import tqdm
3
+ from datasets import load_dataset
4
+ from transformers import pipeline
5
+ from utils import extract_sentences_by_intent, nlp # <- your spaCy model
6
+
7
+ # ─── CONFIG ──────────────────────────────────────────────────────────
8
+ TOP_K = 3 # candidate spans per example
9
+ N_PER_DS = 200 # keep *valid* examples per dataset
10
+ BATCH_SIZE = 16
11
+ DEVICE = 0 # GPU id (-1 = CPU)
12
+ OUTPUT_PATH = "results/combined_results.xlsx"
13
+ # ────────────────────────────────────────────────────────────────────
14
+
15
+ # ── helper: flatten arbitrary json-ish field to plain text ──────────
16
+ def flatten_to_text(x):
17
+ if isinstance(x, str):
18
+ return x
19
+ if isinstance(x, dict):
20
+ if "text" in x and isinstance(x["text"], str):
21
+ return x["text"]
22
+ return "\n".join(flatten_to_text(v) for v in x.values())
23
+ if isinstance(x, (list, tuple)):
24
+ return "\n".join(flatten_to_text(v) for v in x)
25
+ return str(x)
26
+
27
+ # ── helper: map any label (β€œscore 3” or β€œgood answer”) β†’ int 1-5 ────
28
+ LABEL_STRINGS = [
29
+ "very bad answer", # 1
30
+ "bad answer", # 2
31
+ "acceptable answer", # 3
32
+ "good answer", # 4
33
+ "perfect answer" # 5
34
+ ]
35
+ def label_to_int(lbl: str) -> int:
36
+ m = re.search(r"([1-5])", lbl)
37
+ if m: # the label already contains a digit
38
+ return int(m.group(1))
39
+ for i, s in enumerate(LABEL_STRINGS, 1):
40
+ if s in lbl.lower():
41
+ return i
42
+ return 1 # fallback
43
+
44
+ # ── datasets – full splits will be shuffled, then filtered ──────────
45
+ datasets_info = [
46
+ ("FinQA-10K", "virattt/financial-qa-10K", "train", False, {}),
47
+ ("SQuAD", "rajpurkar/squad", "validation", False, {}),
48
+ ]
49
+
50
+ # ── zero-shot classification judges (all ~125-140 M params) ─────────
51
+ candidate_labels = LABEL_STRINGS # same list for every judge
52
+
53
+ judge1 = pipeline("zero-shot-classification",
54
+ model="roberta-large-mnli",
55
+ device=DEVICE, batch_size=BATCH_SIZE)
56
+
57
+ judge2 = pipeline("zero-shot-classification",
58
+ model="microsoft/deberta-base-mnli",
59
+ device=DEVICE, batch_size=BATCH_SIZE)
60
+
61
+ judge3 = pipeline("zero-shot-classification",
62
+ model="valhalla/distilbart-mnli-12-3",
63
+ device=DEVICE, batch_size=BATCH_SIZE)
64
+
65
+ # ── main loop ───────────────────────────────────────────────────────
66
+ rows = []
67
+
68
+ for ds_name, hf_id, split, trust_code, extra_kwargs in datasets_info:
69
+ print(f"\nβ†’ Loading {ds_name} ({hf_id}#{split}) and collecting {N_PER_DS} examples…")
70
+ ds = load_dataset(hf_id, split=split, trust_remote_code=trust_code, **extra_kwargs)
71
+ ds = ds.shuffle(seed=42)
72
+
73
+ collected, bar = 0, tqdm(total=N_PER_DS, desc=f"{ds_name} valid")
74
+ for ex in ds:
75
+ # unified fields -------------------------------------------------------
76
+ question = ex.get("question") or ex.get("question_text") or ex.get("query") or ""
77
+ context = flatten_to_text(
78
+ ex.get("context") or ex.get("document_text") or ex.get("story") or ex.get("text") or ""
79
+ )
80
+
81
+ # keep only if context has β‰₯ 2 sentences -------------------------------
82
+ if len(list(nlp(context).sents)) < 2:
83
+ continue
84
+
85
+ # candidate spans ------------------------------------------------------
86
+ spans = [s for s, _ in extract_sentences_by_intent(
87
+ text=context, intent=question, threshold=-1.0, top_k=TOP_K)]
88
+
89
+ if not spans: # no hit – fill with defaults
90
+ rows.append({
91
+ "dataset": ds_name, "question": question, "context": context,
92
+ "span": "", "score1": 5.0, "score2": 5.0, "score3": 5.0, "score_avg": 5.0
93
+ })
94
+ collected += 1; bar.update(1)
95
+ if collected >= N_PER_DS: break
96
+ continue
97
+
98
+ prompts = [
99
+ f"Question: {question}\nCandidate answer: {span}\n\n"
100
+ "On a scale from 1 (completely wrong) to 5 (perfect), reply with a single digit."
101
+ for span in spans
102
+ ]
103
+
104
+ # run judges -----------------------------------------------------------
105
+ out1 = judge1(prompts, candidate_labels=candidate_labels, multi_label=False)
106
+ out2 = judge2(prompts, candidate_labels=candidate_labels, multi_label=False)
107
+ out3 = judge3(prompts, candidate_labels=candidate_labels, multi_label=False)
108
+
109
+ j1 = [label_to_int(o["labels"][0]) for o in out1]
110
+ j2 = [label_to_int(o["labels"][0]) for o in out2]
111
+ j3 = [label_to_int(o["labels"][0]) for o in out3]
112
+
113
+ avg_scores = [(a+b+c)/3.0 for a, b, c in zip(j1, j2, j3)]
114
+ best = int(np.argmax(avg_scores))
115
+
116
+ rows.append({
117
+ "dataset": ds_name, "question": question, "context": context,
118
+ "span": spans[best],
119
+ "score1": float(j1[best]), "score2": float(j2[best]), "score3": float(j3[best]),
120
+ "score_avg": float(avg_scores[best])
121
+ })
122
+ collected += 1; bar.update(1)
123
+ if collected >= N_PER_DS:
124
+ break
125
+ bar.close()
126
+ if collected < N_PER_DS:
127
+ print(f"⚠️ Only {collected} qualifying examples found for {ds_name}")
128
+
129
+ # ── save & report ───────────────────────────────────────────────────
130
+ os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
131
+ pd.DataFrame(rows).to_excel(OUTPUT_PATH, index=False)
132
+ print(f"\nβœ”οΈ Saved combined results β†’ {OUTPUT_PATH}")
133
+
134
+ print("\nβ–ΆοΈŽ Per-dataset judge averages:")
135
+ summary = (pd.DataFrame(rows)
136
+ .groupby("dataset")[["score1", "score2", "score3", "score_avg"]]
137
+ .mean())
138
+ for ds, row in summary.iterrows():
139
+ print(f" {ds:12s} | "
140
+ f"Judge1 {row['score1']:.2f} "
141
+ f"Judge2 {row['score2']:.2f} "
142
+ f"Judge3 {row['score3']:.2f} "
143
+ f"Combined {row['score_avg']:.2f}")
ablation_and_evaluation/evaluation_studies.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import numpy as np
4
+ import pandas as pd
5
+ from tqdm.auto import tqdm
6
+ from datasets import load_dataset
7
+ from transformers import pipeline
8
+ from utils import extract_sentences_by_intent, nlp
9
+
10
+ # ─── CONFIG ────────────────────────────────────────────────────────────────────
11
+ TOP_K = 3 # number of spans to extract per example
12
+ N_PER_DS = 200 # how many *valid* examples per dataset
13
+ BATCH_SIZE = 16 # batch size for judge pipelines
14
+ DEVICE = 0 # GPU id, or -1 for CPU
15
+ OUTPUT_PATH = "results/combined_results.xlsx"
16
+ # ────────────────────────────────────────────────────────────────────────────────
17
+
18
+ def flatten_to_text(x):
19
+ if isinstance(x, str):
20
+ return x
21
+ if isinstance(x, dict):
22
+ if "text" in x and isinstance(x["text"], str):
23
+ return x["text"]
24
+ return "\n".join(flatten_to_text(v) for v in x.values())
25
+ if isinstance(x, list):
26
+ return "\n".join(flatten_to_text(v) for v in x)
27
+ return str(x)
28
+
29
+ def label_to_int(lbl: str) -> int:
30
+ # handles both variants A and B
31
+ m = re.search(r"([1-5])", lbl)
32
+ if m: # digits present -> easy
33
+ return int(m.group(1))
34
+ # descriptive version -> map by order
35
+ mapping = {
36
+ "very bad answer": 1,
37
+ "bad answer": 2,
38
+ "acceptable answer": 3,
39
+ "good answer": 4,
40
+ "perfect answer": 5
41
+ }
42
+ return mapping.get(lbl.lower(), 1)
43
+
44
+ # ─── 1) choose these two datasets ──────────────────────────────────────────────
45
+ datasets_info = [
46
+ ("FinQA-10K", "virattt/financial-qa-10K", "train", False, {}),
47
+ ("SQuAD", "rajpurkar/squad", "validation", False, {}),
48
+ ]
49
+
50
+ # ─── 2) spin up zero‐shot classification judges ────────────────────────────────
51
+ candidate_labels = [
52
+ "very bad answer", # 1
53
+ "bad answer", # 2
54
+ "acceptable answer", # 3
55
+ "good answer", # 4
56
+ "perfect answer" # 5
57
+ ]
58
+
59
+ judge1 = pipeline(
60
+ "zero-shot-classification",
61
+ model="roberta-large-mnli",
62
+ device=DEVICE,
63
+ batch_size=BATCH_SIZE
64
+ )
65
+ judge2 = pipeline(
66
+ "zero-shot-classification",
67
+ model="microsoft/deberta-base-mnli",
68
+ tokenizer="microsoft/deberta-base-mnli",
69
+ device=DEVICE,
70
+ batch_size=BATCH_SIZE
71
+ )
72
+ judge3 = pipeline(
73
+ "zero-shot-classification",
74
+ model="valhalla/distilbart-mnli-12-3",
75
+ tokenizer="valhalla/distilbart-mnli-12-3",
76
+ device=DEVICE,
77
+ batch_size=BATCH_SIZE
78
+ )
79
+
80
+ all_rows = []
81
+
82
+ for ds_name, hf_id, split, trust_code, extra_kwargs in datasets_info:
83
+ print(f"\nβ†’ Loading {ds_name} ({hf_id}#{split}), gathering {N_PER_DS} valid examples…")
84
+ # load full split and shuffle
85
+ ds = load_dataset(hf_id, split=split, trust_remote_code=trust_code, **extra_kwargs)
86
+ ds = ds.shuffle(seed=42)
87
+
88
+ collected = 0
89
+ pbar = tqdm(total=N_PER_DS, desc=f"{ds_name} valid examples")
90
+ for ex in ds:
91
+ # unify question & context
92
+ question = (
93
+ ex.get("question")
94
+ or ex.get("question_text")
95
+ or ex.get("query")
96
+ or ""
97
+ )
98
+ raw_ctx = (
99
+ ex.get("context")
100
+ or ex.get("document_text")
101
+ or ex.get("story")
102
+ or ex.get("text")
103
+ or ""
104
+ )
105
+ context = flatten_to_text(raw_ctx)
106
+
107
+ # only keep examples whose context has at least 2 sentences
108
+ if len(list(nlp(context).sents)) < 2:
109
+ continue
110
+
111
+ # extract top-K spans
112
+ hits = extract_sentences_by_intent(
113
+ text = context,
114
+ intent = question,
115
+ threshold = -1.0,
116
+ top_k = TOP_K,
117
+ convo_focus = None
118
+ )
119
+ spans = [s for s,_ in hits]
120
+ if not spans:
121
+ # record defaults if no span found
122
+ all_rows.append({
123
+ "dataset": ds_name,
124
+ "question": question,
125
+ "context": context,
126
+ "span": "",
127
+ "score1": 5.0,
128
+ "score2": 5.0,
129
+ "score3": 5.0,
130
+ "score_avg": 5.0
131
+ })
132
+ collected += 1
133
+ pbar.update(1)
134
+ if collected >= N_PER_DS:
135
+ break
136
+ continue
137
+
138
+ # build prompts
139
+ prompts = [
140
+ f"Question: {question}\nCandidate answer: {span}\n\n"
141
+ "On a scale from 1 (completely wrong) to 5 (perfect), "
142
+ "reply with a single digit."
143
+ for span in spans
144
+ ]
145
+
146
+ # run judges
147
+ out1 = judge1(prompts, candidate_labels=candidate_labels, multi_label=False)
148
+ out2 = judge2(prompts, candidate_labels=candidate_labels, multi_label=False)
149
+ out3 = judge3(prompts, candidate_labels=candidate_labels, multi_label=False)
150
+
151
+ # parse their top‐chosen labels
152
+ j1 = [int(o["labels"][0]) for o in out1]
153
+ j2 = [int(o["labels"][0]) for o in out2]
154
+ j3 = [int(o["labels"][0]) for o in out3]
155
+
156
+ # average per span, pick best
157
+ avg_scores = [(a+b+c)/3.0 for a,b,c in zip(j1,j2,j3)]
158
+ best_idx = int(np.argmax(avg_scores))
159
+
160
+ all_rows.append({
161
+ "dataset": ds_name,
162
+ "question": question,
163
+ "context": context,
164
+ "span": spans[best_idx],
165
+ "score1": float(j1[best_idx]),
166
+ "score2": float(j2[best_idx]),
167
+ "score3": float(j3[best_idx]),
168
+ "score_avg": float(avg_scores[best_idx])
169
+ })
170
+ collected += 1
171
+ pbar.update(1)
172
+ if collected >= N_PER_DS:
173
+ break
174
+
175
+ pbar.close()
176
+ if collected < N_PER_DS:
177
+ print(f"⚠️ Only found {collected} valid examples for {ds_name}.")
178
+
179
+ # ─── dump to Excel ─────────────────────────────────────────────────────────────
180
+ df = pd.DataFrame(all_rows)
181
+ os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
182
+ df.to_excel(OUTPUT_PATH, index=False)
183
+ print(f"\nβœ”οΈ Saved combined results to ./{OUTPUT_PATH}")
184
+
185
+ # ─── per‐dataset summary ───────────────────────────────────────────────────────
186
+ print("\nβ–ΆοΈŽ Per‐dataset judge averages:")
187
+ grouped = df.groupby("dataset")[["score1","score2","score3","score_avg"]].mean()
188
+ for ds, row in grouped.iterrows():
189
+ print(f" {ds}: "
190
+ f"Judge1={row['score1']:.2f}, "
191
+ f"Judge2={row['score2']:.2f}, "
192
+ f"Judge3={row['score3']:.2f}, "
193
+ f"Combined={row['score_avg']:.2f}")
config.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # config.py
2
+ import torch
3
+ import numpy as np
4
+ from transformers import AutoTokenizer
5
+ import spacy
6
+
7
+ MODEL_NAME = "bert-base-uncased"
8
+ MAX_LENGTH = 128
9
+ OVERLAP = 32
10
+ PREPROCESSED_DIR= "preprocessed_snli"
11
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+
13
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
14
+ nlp = spacy.load("en_core_web_sm")
data/Fin_ExBERT_data.xlsx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3bcd855be3a6ee5d740c9f61439268f5a84a085a2bf5a0593027c80c8a97f34c
3
+ size 919294
data/Fin_ExBERT_test_set.xlsx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8ca7db6ee8c28319f3fd1736a0bb05cfe5bacf1f33e07549f54fd90134b3706f
3
+ size 324485
data/Fin_ExBERT_train_val_data.xlsx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d4804d0e204a40e797a537767c95999e585eb7f50cea416b784311efe76c731
3
+ size 1571514
finetune_lora.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+ import math
6
+ import matplotlib.pyplot as plt
7
+ from datasets import load_dataset
8
+ from torch.utils.data import DataLoader
9
+ from tqdm.auto import tqdm
10
+ from transformers import (
11
+ AutoTokenizer,
12
+ AutoModelForMaskedLM,
13
+ DataCollatorForLanguageModeling,
14
+ get_linear_schedule_with_warmup,
15
+ )
16
+ from accelerate import Accelerator
17
+ from peft import LoraConfig, get_peft_model
18
+
19
+ # Configuration constants
20
+ MODEL_NAME = "bert-base-uncased"
21
+ BATCH_SIZE = 16
22
+ MAX_LENGTH = 128
23
+ LEARNING_RATE = 5e-4
24
+ EPOCHS = 20
25
+ SEED = 42
26
+ ADAPTER_SAVE_DIR = "./lora_finance_adapter"
27
+ CHECKPOINT_PATH = os.path.join(ADAPTER_SAVE_DIR, "training_checkpoint.pt")
28
+
29
+
30
+ def set_seed(seed: int = SEED):
31
+ random.seed(seed)
32
+ np.random.seed(seed)
33
+ torch.manual_seed(seed)
34
+ if torch.cuda.is_available():
35
+ torch.cuda.manual_seed_all(seed)
36
+
37
+
38
+ def fine_tune_lora(dataset_name: str = "FinGPT/fingpt-fiqa_qa", split: str = "train"):
39
+ """
40
+ Fine-tune BERT with LoRA on an MLM objective.
41
+ Supports checkpointing and resuming, and plots loss, perplexity, and MLM accuracy per epoch.
42
+ Saves the LoRA adapter and checkpoint in ADAPTER_SAVE_DIR.
43
+ """
44
+ set_seed()
45
+
46
+ # Prepare save directory
47
+ os.makedirs(ADAPTER_SAVE_DIR, exist_ok=True)
48
+
49
+ # Load and prepare dataset
50
+ dataset = load_dataset(dataset_name, split=split)
51
+ def combine_fields(example):
52
+ text = ' '.join([example.get(k, '').strip() for k in ['instruction', 'input', 'output'] if example.get(k)])
53
+ return {"text": text}
54
+ dataset = dataset.map(combine_fields, remove_columns=[c for c in dataset.column_names if c != 'text'])
55
+
56
+ # Tokenization and DataLoader
57
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
58
+ def tokenize_fn(examples):
59
+ return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=MAX_LENGTH)
60
+ tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=[c for c in dataset.column_names if c != 'text'])
61
+ tokenized.set_format(type='torch', columns=['input_ids', 'attention_mask'])
62
+ collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
63
+ train_loader = DataLoader(
64
+ tokenized,
65
+ batch_size=BATCH_SIZE,
66
+ shuffle=True,
67
+ collate_fn=collator,
68
+ num_workers=4,
69
+ pin_memory=True,
70
+ )
71
+
72
+ # Model, LoRA, optimizer, scheduler
73
+ model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
74
+ lora_cfg = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.1, bias='none', task_type='CAUSAL_LM')
75
+ model = get_peft_model(model, lora_cfg)
76
+ optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
77
+ total_steps = EPOCHS * len(train_loader)
78
+ scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(0.1 * total_steps), num_training_steps=total_steps)
79
+
80
+ # Accelerator
81
+ accelerator = Accelerator()
82
+ model, optimizer, train_loader, scheduler = accelerator.prepare(model, optimizer, train_loader, scheduler)
83
+ device = accelerator.device
84
+
85
+ # Metrics storage and resume state
86
+ start_epoch = 1
87
+ epoch_losses = []
88
+ epoch_ppls = []
89
+ epoch_accs = []
90
+
91
+ # Load checkpoint if exists
92
+ if os.path.exists(CHECKPOINT_PATH):
93
+ ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
94
+ model.load_state_dict(ckpt['model_state_dict'])
95
+ optimizer.load_state_dict(ckpt['optimizer_state_dict'])
96
+ scheduler.load_state_dict(ckpt['scheduler_state_dict'])
97
+ start_epoch = ckpt['epoch'] + 1
98
+ epoch_losses = ckpt.get('epoch_losses', [])
99
+ epoch_ppls = ckpt.get('epoch_ppls', [])
100
+ epoch_accs = ckpt.get('epoch_accs', [])
101
+ print(f"Resuming from epoch {start_epoch}")
102
+
103
+ # Training loop
104
+ model.train()
105
+ for epoch in range(start_epoch, EPOCHS + 1):
106
+ total_loss, total_masked, correct_masked = 0.0, 0, 0
107
+ progress = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}", leave=False)
108
+ for batch in progress:
109
+ optimizer.zero_grad()
110
+ input_ids = batch['input_ids'].to(device)
111
+ attention_mask = batch['attention_mask'].to(device)
112
+ labels = batch['labels'].to(device)
113
+
114
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
115
+ loss, logits = outputs.loss, outputs.logits
116
+ accelerator.backward(loss)
117
+ optimizer.step()
118
+ scheduler.step()
119
+
120
+ # Accumulate
121
+ step_loss = loss.item()
122
+ total_loss += step_loss
123
+ preds = torch.argmax(logits, dim=-1)
124
+ mask = labels.ne(-100)
125
+ correct_masked += preds.eq(labels).masked_select(mask).sum().item()
126
+ total_masked += mask.sum().item()
127
+ progress.set_postfix({'loss': f"{step_loss:.4f}"})
128
+
129
+ # Epoch metrics
130
+ avg_loss = total_loss / len(train_loader)
131
+ avg_ppl = math.exp(avg_loss)
132
+ avg_acc = correct_masked / total_masked if total_masked > 0 else 0
133
+ epoch_losses.append(avg_loss)
134
+ epoch_ppls.append(avg_ppl)
135
+ epoch_accs.append(avg_acc)
136
+ print(f"Epoch {epoch}: Loss={avg_loss:.4f}, PPL={avg_ppl:.2f}, MLM Acc={avg_acc:.4%}")
137
+
138
+ # Save checkpoint
139
+ ckpt = {
140
+ 'epoch': epoch,
141
+ 'model_state_dict': model.state_dict(),
142
+ 'optimizer_state_dict': optimizer.state_dict(),
143
+ 'scheduler_state_dict': scheduler.state_dict(),
144
+ 'epoch_losses': epoch_losses,
145
+ 'epoch_ppls': epoch_ppls,
146
+ 'epoch_accs': epoch_accs,
147
+ }
148
+ torch.save(ckpt, CHECKPOINT_PATH)
149
+
150
+ # Final plots
151
+ fig, axes = plt.subplots(3, 1, figsize=(6, 10), sharex=True)
152
+ epochs_list = list(range(1, len(epoch_losses) + 1))
153
+ axes[0].plot(epochs_list, epoch_losses, marker='o'); axes[0].set_ylabel('Loss'); axes[0].set_title('Training Loss'); axes[0].grid(True)
154
+ axes[1].plot(epochs_list, epoch_ppls, marker='o'); axes[1].set_ylabel('Perplexity'); axes[1].set_title('Training Perplexity'); axes[1].grid(True)
155
+ axes[2].plot(epochs_list, epoch_accs, marker='o'); axes[2].set_ylabel('MLM Accuracy'); axes[2].set_xlabel('Epoch'); axes[2].set_title('Masked LM Accuracy'); axes[2].grid(True)
156
+ plt.tight_layout(); plt.show()
157
+
158
+ # Save LoRA adapter
159
+ model.save_pretrained(ADAPTER_SAVE_DIR)
160
+ print(f"LoRA adapter saved to {ADAPTER_SAVE_DIR}")
161
+
162
+
163
+ if __name__ == '__main__':
164
+ fine_tune_lora()
images/methodology_flowchart.png ADDED

Git LFS Details

  • SHA256: e3aafb2e0d0a78a75fdeb34deceabb946a6f08cba7ab861aba18eb10a7003639
  • Pointer size: 131 Bytes
  • Size of remote file: 239 kB
images/test.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
main.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #from models import *
2
+ #from preprocess_data import *
3
+ from utils import extract_sentences_by_intent, train_model_with_chkpt, batch_predict_and_save
4
+ from time import time
5
+ import logging
6
+ from config import *
7
+
8
+
9
+
10
+
11
+
12
+ if __name__ == '__main__':
13
+ # train_model_with_chkpt(epochs=5, batch_size=16, lr=2e-3,
14
+ # save_model=True,
15
+ # save_path='gnn_model_checkpoint.pt',
16
+ # resume=True)
17
+
18
+ from transformers import BertTokenizer
19
+ from preprocess_data import SentenceDataset
20
+ from models import SentenceExtractionModel
21
+ from utils import train_sentence_extractor
22
+
23
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
24
+ dataset = SentenceDataset("data/Fin_ExBERT_train_val_data.xlsx", tokenizer)
25
+
26
+
27
+ model = SentenceExtractionModel(
28
+ base_model_name=MODEL_NAME,
29
+ backbone='finexbert'
30
+ )
31
+
32
+ # train_sentence_extractor(
33
+ # model,
34
+ # dataset,
35
+ # output_dir="checkpoints/sentence_extractor",
36
+ # val_split=0.3,
37
+ # epochs=10,
38
+ # batch_size=16,
39
+ # lr=3e-4,
40
+ # device=DEVICE,
41
+ # unfreeze_after_epoch=4
42
+ # )
43
+ #
44
+ # from transformers import BertTokenizer
45
+ # from models import SentenceExtractionModel
46
+ # from utils import demo_on_random_val
47
+ #
48
+ # tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
49
+ # model = SentenceExtractionModel(
50
+ # base_model_name=MODEL_NAME,
51
+ # backbone='finexbert'
52
+ # )
53
+ #
54
+ # demo_on_random_val(
55
+ # model,
56
+ # tokenizer,
57
+ # excel_path="data/Fin_ExBERT_test_set.xlsx",
58
+ # ckpt_path="checkpoints/sentence_extractor/best_model.pth",
59
+ # device="cuda", # or "cpu"
60
+ # temperature=1,
61
+ # )
62
+
63
+ batch_predict_and_save(
64
+ model,
65
+ tokenizer,
66
+ excel_path="data/Fin_ExBERT_test_set.xlsx",
67
+ ckpt_path="checkpoints/sentence_extractor/best_model.pth",
68
+ output_path="results/predictions_sample200.xlsx",
69
+ n_samples=200,
70
+ temperature=1.0,
71
+ device="cuda"
72
+ )
73
+
74
+ sample_transcript = """
75
+ Agent: Hello, thank you for calling Acme Financial Services. My name is Priya. How can I help you today?
76
+ Customer: Hi Priya, I’m considering opening a new savings account with you.
77
+ Agent: Absolutelyβ€”our savings account offers 4% interest per annum. Do you have a balance in mind?
78
+ Customer: Yes, I’d like to deposit β‚Ή50,000 initially, and then I’m interested in investing another β‚Ή2 lakh in mutual funds over the next month.
79
+ Agent: Great, we have several mutual fund options. Are you more growth-oriented or looking for steady income?
80
+ Customer: I want to focus on growth. Also, could you tell me about your home loan rates? I may need a β‚Ή30 lakh mortgage in the next six months.
81
+ Agent: Certainlyβ€”we currently offer home loan rates starting at 6.8%. Do you already own property or are you planning to buy?
82
+ Customer: Planning to buy. Finally, I’d like to apply for a credit card with a high cashbackβ€”maybe one that gives 2% on all spends.
83
+ Agent: We have a Platinum Cashback Card at 1.5%, and our Signature Cashback Card at 2%. Would you like me to initiate the application?
84
+ Customer: Yes please, go ahead with the Signature Cashback Card, and send me the home-loan documents via email.
85
+ Agent: Done. You’ll receive an email shortly. Is there anything else I can help you with?
86
+ Customer: No, that’s all for todayβ€”thank you!
87
+ """
88
+
89
+ complex_transcript = """
90
+ Agent: Good morning, thank you for calling Maple Grove Bank. This is Rahul speakingβ€”how may I assist you today?
91
+ Customer: Hi Rahul, I’ve been reviewing my financial goals for the next five years and want to discuss a mix of savings, investments, and insurance.
92
+ Agent: Absolutely. Would you like to start with your current cash savings or jump straight into investment products?
93
+ Customer: Let’s begin with savings: I’d like to open a high-yield savings account with at least β‚Ή1 lakh to start, and then set up an automatic top-up of β‚Ή10,000 each month.
94
+ Agent: Great choice. We have our β€œPlus Savings” account at 4.2% APY. Next, investmentsβ€”are you looking at mutual funds, stocks, or retirement plans?
95
+ Customer: I’m particularly interested in tax-saving ELSS mutual funds and a more conservative retirement pension plan. Also, could you explain your term insurance offerings?
96
+ Agent: Sureβ€”our ELSS options include Fund A (equity-heavy) and Fund B (balanced). For term cover, we have 20-year plans up to β‚Ή50 lakhs. Any preference?
97
+ Customer: I want a balanced ELSS with a 3-year lock-in, and term insurance of β‚Ή30 lakhs for 25 years. After that, I may need advice on buying a second homeβ€”so let’s also discuss mortgage pre-approval.
98
+ Agent: Understood. For a β‚Ή30 lakh home loan, current interest rates start at 6.9%. We can pre-approve you based on your income. Shall I proceed?
99
+ Customer: Yes, please initiate the home-loan pre-approval. And lastly, I’d like to apply for a debit card with no annual fee and a co-branded credit card offering travel rewards.
100
+ Agent: Certainlyβ€”our β€œFreedom” debit card has no fee, and the β€œSkyMiles” credit card gives 2 airline miles per β‚Ή100 spent. Would you like to complete those applications now?
101
+ Customer: Yes, go ahead with both. Also, can you set up a quarterly portfolio review call with a financial advisor?
102
+ Agent: Absolutely. I’ll schedule a review every three months starting next quarter. You’ll get email confirmations shortly.
103
+ Customer: Perfectβ€”that covers all my needs. Thanks for your help!
104
+ Agent: My pleasure! Have a great day and feel free to call back anytime.
105
+ """
106
+
107
+ # premise_input = "personA is on the stage giving a speech."
108
+ # hypothesis_input = "personA is using a microphone."
109
+ # prediction, _ = predict_fin_nli(premise=premise_input, hypothesis=hypothesis_input, model_path='gnn_model_checkpoint.pt')
110
+ # print("Prediction:", prediction)
111
+ # print('Final layer logits:', _)
112
+
113
+ ################################
114
+
115
+ # start = time()
116
+ # results = extract_sentences_by_intent(
117
+ # complex_transcript,
118
+ # intent="customer tells about own financial condition",#"customer states specific financial product requests and planning preferences",
119
+ # #"agent provides assistance", #"customer states their financial needs",
120
+ # threshold=0.60,
121
+ # top_k=10,
122
+ # convo_focus='customer'
123
+ # )
124
+ # end = time()
125
+ #
126
+ # logging.info('Prediction Done in {:.2f}sec'.format(end - start))
127
+ #
128
+ # for sentence, score in results:
129
+ # print(f"{score:.2f} β†’ {sentence}")
130
+
131
+
132
+
models.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import math
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from peft import PeftModel, LoraConfig, get_peft_model
7
+ from transformers import AutoTokenizer, AutoModel, AutoConfig, get_linear_schedule_with_warmup
8
+ from torch.nn import MultiheadAttention, GELU
9
+
10
+ MODEL_NAME = "bert-base-uncased"
11
+ BATCH_SIZE = 16
12
+ MAX_LENGTH = 128
13
+ LEARNING_RATE = 2e-5
14
+ EPOCHS = 5
15
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ PREPROCESSED_DIR = "preprocessed_snli"
17
+ MIXED_PRECISION = "fp16"
18
+
19
+
20
+ class SimpleGNN(nn.Module):
21
+ def __init__(self, input_dim, hidden_dim):
22
+ super().__init__()
23
+ self.fc = nn.Linear(input_dim, hidden_dim)
24
+
25
+ def forward(self, node_embeddings, edges):
26
+ if node_embeddings.size(0) == 0:
27
+ return torch.zeros(1, self.fc.out_features, device=node_embeddings.device)
28
+ num_nodes = node_embeddings.size(0)
29
+ adj = torch.zeros((num_nodes, num_nodes), device=node_embeddings.device)
30
+ for (src, dst) in edges:
31
+ if src < num_nodes and dst < num_nodes:
32
+ adj[src, dst] = 1.0
33
+ deg = adj.sum(dim=1, keepdim=True) + 1e-10
34
+ adj_norm = adj / deg
35
+ agg_embeddings = adj_norm @ node_embeddings
36
+ return F.relu(self.fc(agg_embeddings))
37
+
38
+
39
+ class GraphAugmentedNLIModel(nn.Module):
40
+ def __init__(self, base_model_name, num_labels=3, hidden_dim=768, gnn_dim=128):
41
+ super().__init__()
42
+ config = AutoConfig.from_pretrained(base_model_name)
43
+ config.num_labels = num_labels
44
+ self.bert = AutoModel.from_pretrained(base_model_name, config=config)
45
+ self.dropout = nn.Dropout(0.1)
46
+
47
+ self.gnn_premise = SimpleGNN(hidden_dim, gnn_dim)
48
+ self.gnn_hypothesis = SimpleGNN(hidden_dim, gnn_dim)
49
+
50
+ self.classifier = nn.Linear(hidden_dim + gnn_dim*2, num_labels)
51
+
52
+ def forward(self, input_ids, attention_mask, premise_graph_tokens, premise_graph_edges, premise_node_indices,
53
+ hypothesis_graph_tokens, hypothesis_graph_edges, hypothesis_node_indices, labels=None):
54
+ outputs = self.bert(input_ids, attention_mask=attention_mask)
55
+ cls_embedding = outputs.last_hidden_state[:,0,:] # [batch, hidden_dim]
56
+
57
+ batch_size = input_ids.size(0)
58
+ gnn_p_outputs = []
59
+ gnn_h_outputs = []
60
+
61
+ # Now node indices are precomputed. We just take those embeddings directly.
62
+ # node_indices correspond to the positions in input_ids whose embeddings represent that node.
63
+ for i in range(batch_size):
64
+ instance_hidden = outputs.last_hidden_state[i] # [seq_len, hidden_dim]
65
+
66
+ p_edges = premise_graph_edges[i]
67
+ p_indices = premise_node_indices[i]
68
+ h_edges = hypothesis_graph_edges[i]
69
+ h_indices = hypothesis_node_indices[i]
70
+
71
+ # Gather node embeddings
72
+ p_nodes = instance_hidden[p_indices] if len(p_indices) > 0 else torch.empty(0, instance_hidden.size(-1), device=instance_hidden.device)
73
+ h_nodes = instance_hidden[h_indices] if len(h_indices) > 0 else torch.empty(0, instance_hidden.size(-1), device=instance_hidden.device)
74
+
75
+ p_gnn_out = self.gnn_premise(p_nodes, p_edges) if p_nodes.size(0) > 0 else torch.zeros(1,128, device=DEVICE)
76
+ h_gnn_out = self.gnn_hypothesis(h_nodes, h_edges) if h_nodes.size(0) > 0 else torch.zeros(1,128, device=DEVICE)
77
+
78
+ p_mean = p_gnn_out.mean(dim=0, keepdim=True)
79
+ h_mean = h_gnn_out.mean(dim=0, keepdim=True)
80
+
81
+ gnn_p_outputs.append(p_mean)
82
+ gnn_h_outputs.append(h_mean)
83
+
84
+ gnn_p_outputs = torch.cat(gnn_p_outputs, dim=0) # [batch, gnn_dim]
85
+ gnn_h_outputs = torch.cat(gnn_h_outputs, dim=0) # [batch, gnn_dim]
86
+
87
+ fused = torch.cat([cls_embedding, gnn_p_outputs, gnn_h_outputs], dim=-1)
88
+ fused = self.dropout(fused)
89
+ logits = self.classifier(fused)
90
+
91
+ loss = None
92
+ if labels is not None:
93
+ loss_fn = nn.CrossEntropyLoss()
94
+ loss = loss_fn(logits, labels)
95
+ return {"loss": loss, "logits": logits}
96
+
97
+
98
+
99
+ class SimpleFinGNN(nn.Module):
100
+ def __init__(self, input_dim, hidden_dim):
101
+ super().__init__()
102
+ self.fc = nn.Linear(input_dim, hidden_dim)
103
+
104
+ def forward(self, node_embeddings, edges):
105
+ if node_embeddings.size(0) == 0:
106
+ return torch.zeros(1, self.fc.out_features, device=node_embeddings.device)
107
+ num_nodes = node_embeddings.size(0)
108
+ adj = torch.zeros((num_nodes, num_nodes), device=node_embeddings.device)
109
+ for (src, dst) in edges:
110
+ if src < num_nodes and dst < num_nodes:
111
+ adj[src, dst] = 1.0
112
+ deg = adj.sum(dim=1, keepdim=True) + 1e-10
113
+ adj_norm = adj / deg
114
+ agg_embeddings = adj_norm @ node_embeddings
115
+ return F.relu(self.fc(agg_embeddings))
116
+
117
+
118
+ class GraphAugmentedFinNLIModel(nn.Module):
119
+ def __init__(self, base_model_name, num_labels=3, hidden_dim=768, gnn_dim=128):
120
+ super().__init__()
121
+ config = AutoConfig.from_pretrained(base_model_name)
122
+ config.num_labels = num_labels
123
+ self.bert = AutoModel.from_pretrained(base_model_name, config=config)
124
+ self.dropout = nn.Dropout(0.1)
125
+
126
+ self.gnn_premise = SimpleGNN(hidden_dim, gnn_dim)
127
+ self.gnn_hypothesis = SimpleGNN(hidden_dim, gnn_dim)
128
+
129
+ self.classifier = nn.Linear(hidden_dim + gnn_dim*2, num_labels)
130
+ self.config = self.bert.config
131
+ self.config.num_labels = num_labels
132
+
133
+ def forward(self,
134
+ input_ids=None,
135
+ attention_mask=None,
136
+ premise_graph_tokens=None,
137
+ hypothesis_graph_tokens=None,
138
+ premise_graph_edges=None,
139
+ hypothesis_graph_edges=None,
140
+ premise_node_indices=None,
141
+ hypothesis_node_indices=None,
142
+ labels=None,
143
+ inputs_embeds=None,
144
+ **kwargs):
145
+ # Even if we don't use inputs_embeds, we should pass it into self.bert call:
146
+ outputs = self.bert(input_ids=input_ids,
147
+ attention_mask=attention_mask,
148
+ inputs_embeds=inputs_embeds,
149
+ **{k:v for k,v in kwargs.items() if k in self.bert.forward.__code__.co_varnames})
150
+
151
+ cls_embedding = outputs.last_hidden_state[:,0,:] # [batch, hidden_dim]
152
+
153
+ batch_size = input_ids.size(0) if input_ids is not None else outputs.last_hidden_state.size(0)
154
+ gnn_p_outputs = []
155
+ gnn_h_outputs = []
156
+
157
+ for i in range(batch_size):
158
+ instance_hidden = outputs.last_hidden_state[i] # [seq_len, hidden_dim]
159
+
160
+ p_edges = premise_graph_edges[i]
161
+ p_indices = premise_node_indices[i]
162
+ h_edges = hypothesis_graph_edges[i]
163
+ h_indices = hypothesis_node_indices[i]
164
+
165
+ p_nodes = instance_hidden[p_indices] if len(p_indices) > 0 else torch.empty(0, instance_hidden.size(-1), device=instance_hidden.device)
166
+ h_nodes = instance_hidden[h_indices] if len(h_indices) > 0 else torch.empty(0, instance_hidden.size(-1), device=instance_hidden.device)
167
+
168
+ p_gnn_out = self.gnn_premise(p_nodes, p_edges) if p_nodes.size(0) > 0 else torch.zeros(1,128, device=instance_hidden.device)
169
+ h_gnn_out = self.gnn_hypothesis(h_nodes, h_edges) if h_nodes.size(0) > 0 else torch.zeros(1,128, device=instance_hidden.device)
170
+
171
+ p_mean = p_gnn_out.mean(dim=0, keepdim=True)
172
+ h_mean = h_gnn_out.mean(dim=0, keepdim=True)
173
+
174
+ gnn_p_outputs.append(p_mean)
175
+ gnn_h_outputs.append(h_mean)
176
+
177
+ gnn_p_outputs = torch.cat(gnn_p_outputs, dim=0) # [batch, gnn_dim]
178
+ gnn_h_outputs = torch.cat(gnn_h_outputs, dim=0) # [batch, gnn_dim]
179
+
180
+ fused = torch.cat([cls_embedding, gnn_p_outputs, gnn_h_outputs], dim=-1)
181
+ logits = self.classifier(fused)
182
+
183
+ loss = None
184
+ if labels is not None:
185
+ loss_fn = nn.CrossEntropyLoss()
186
+ loss = loss_fn(logits, labels)
187
+ return {"loss": loss, "logits": logits}
188
+
189
+
190
+ class SentenceExtractionModel(nn.Module):
191
+ def __init__(self,
192
+ base_model_name: str,
193
+ dropout_prob: float = 0.1,
194
+ adapter_dir: str = "./lora_finance_adapter",
195
+ backbone: str = 'default',
196
+ init_pos_frac: float = None # NEW!
197
+ ):
198
+ """
199
+ backbone:
200
+ - 'default' β†’ plain AutoModel.from_pretrained(base_model_name)
201
+ - 'finexbert' β†’ use the .bert submodule of your GraphAugmentedFinNLIModel
202
+ """
203
+ super().__init__()
204
+
205
+ # load config
206
+ config = AutoConfig.from_pretrained(base_model_name)
207
+
208
+ if backbone == 'default':
209
+ # plain BERT
210
+ self.bert = AutoModel.from_pretrained(base_model_name, config=config)
211
+
212
+ elif backbone == 'finexbert':
213
+ # instantiate your full FinNLI model, then grab its .bert
214
+ base_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
215
+ lora_cfg = LoraConfig(
216
+ r=8,
217
+ lora_alpha=32,
218
+ lora_dropout=0.1,
219
+ bias="none",
220
+ task_type="SEQ_CLS"#"CAUSAL_LM", # must match your fine-tune setting
221
+ )
222
+ full = get_peft_model(base_model, lora_cfg).to(DEVICE)
223
+ chkpt_path = os.path.join(adapter_dir, "training_checkpoint.pt")
224
+ if not os.path.isfile(chkpt_path):
225
+ raise FileNotFoundError(f"No LoRA checkpoint at {chkpt_path}")
226
+ ckpt = torch.load(chkpt_path, map_location=DEVICE)
227
+ # ckpt["model_state_dict"] contains both base + LoRA weights; strict=False
228
+ full.load_state_dict(ckpt["model_state_dict"], strict=False)
229
+ # if you have a saved finexbert checkpoint, load it here:
230
+ # full.load_state_dict(torch.load("path/to/finexbert.pth", map_location='cpu'))
231
+ self.bert = full.base_model
232
+
233
+ else:
234
+ raise ValueError(f"Unknown backbone {backbone}")
235
+
236
+ hidden_size = self.bert.config.hidden_size
237
+
238
+ self.dropout = nn.Dropout(dropout_prob)
239
+ self.classifier = nn.Linear(hidden_size, 1)
240
+
241
+ # initialize bias to log-odds of init_pos_frac
242
+ if init_pos_frac is not None:
243
+ b0 = float(math.log(init_pos_frac / (1.0 - init_pos_frac)))
244
+ self.classifier.bias.data.fill_(b0)
245
+
246
+ def forward(self, input_ids, attention_mask):
247
+ outputs = self.bert(input_ids=input_ids,
248
+ attention_mask=attention_mask)
249
+ x = self.dropout(outputs.pooler_output)
250
+ logits = self.classifier(x).squeeze(-1) # [batch]
251
+ return logits
preprocess_data.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import torch
4
+ from torch.utils.data import Dataset
5
+ from datasets import load_dataset, load_from_disk
6
+ import pandas as pd
7
+ import nltk
8
+
9
+ from config import MODEL_NAME, MAX_LENGTH, OVERLAP, PREPROCESSED_DIR, tokenizer, nlp
10
+
11
+ # =============================
12
+ # Logging Setup
13
+ # =============================
14
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
15
+
16
+ # =============================
17
+ # One-Time Preprocessing
18
+ # =============================
19
+ def process_data():
20
+ if not os.path.exists(PREPROCESSED_DIR):
21
+ logging.info("Preprocessing data... This may take a while.")
22
+ # Load and filter SNLI
23
+ snli = load_dataset("snli")
24
+ snli = snli.filter(lambda x: x["label"] != -1)
25
+
26
+ def build_dependency_graph(sentence):
27
+ doc = nlp(sentence)
28
+ tokens = [tok.text for tok in doc]
29
+ edges = []
30
+ for tok in doc:
31
+ if tok.head.i != tok.i:
32
+ edges.extend([(tok.i, tok.head.i), (tok.head.i, tok.i)])
33
+ return tokens, edges
34
+
35
+ def preprocess(examples):
36
+ premises = examples["premise"]
37
+ hypotheses = examples["hypothesis"]
38
+ labels = examples["label"]
39
+ tokenized = tokenizer(premises, hypotheses,
40
+ truncation=True, padding="max_length",
41
+ max_length=MAX_LENGTH)
42
+ tokenized["labels"] = labels
43
+
44
+ p_tokens_list, p_edges_list, p_idx_list = [], [], []
45
+ h_tokens_list, h_edges_list, h_idx_list = [], [], []
46
+
47
+ for p, h, input_ids in zip(premises, hypotheses, tokenized["input_ids"]):
48
+ p_toks, p_edges = build_dependency_graph(p)
49
+ h_toks, h_edges = build_dependency_graph(h)
50
+ wp_tokens = tokenizer.convert_ids_to_tokens(input_ids)
51
+
52
+ def align_tokens(spacy_tokens, wp_tokens):
53
+ node_indices, wp_idx = [], 1
54
+ for _ in spacy_tokens:
55
+ if wp_idx >= len(wp_tokens) - 1: break
56
+ node_indices.append(wp_idx)
57
+ wp_idx += 1
58
+ while wp_idx < len(wp_tokens) - 1 and wp_tokens[wp_idx].startswith("##"):
59
+ wp_idx += 1
60
+ return node_indices
61
+
62
+ p_idx = align_tokens(p_toks, wp_tokens)
63
+ h_idx = align_tokens(h_toks, wp_tokens)
64
+
65
+ p_tokens_list.append(p_toks)
66
+ p_edges_list.append(p_edges)
67
+ p_idx_list.append(p_idx)
68
+
69
+ h_tokens_list.append(h_toks)
70
+ h_edges_list.append(h_edges)
71
+ h_idx_list.append(h_idx)
72
+
73
+ tokenized.update({
74
+ "premise_graph_tokens": p_tokens_list,
75
+ "premise_graph_edges": p_edges_list,
76
+ "premise_node_indices": p_idx_list,
77
+ "hypothesis_graph_tokens": h_tokens_list,
78
+ "hypothesis_graph_edges": h_edges_list,
79
+ "hypothesis_node_indices": h_idx_list,
80
+ })
81
+ return tokenized
82
+
83
+ snli = snli.map(preprocess, batched=True)
84
+ snli.save_to_disk(PREPROCESSED_DIR)
85
+ logging.info(f"Preprocessing complete. Saved to {PREPROCESSED_DIR}")
86
+ else:
87
+ logging.info("Using existing preprocessed data at %s", PREPROCESSED_DIR)
88
+
89
+
90
+ def chunk_transcript(transcript_text, start_idx, end_idx, tokenizer):
91
+ encoded = tokenizer(transcript_text,
92
+ return_offsets_mapping=True,
93
+ add_special_tokens=True,
94
+ return_tensors=None,
95
+ max_length=1024,
96
+ padding=False,
97
+ truncation=False)
98
+ all_input_ids = encoded["input_ids"]
99
+ all_offsets = encoded["offset_mapping"]
100
+
101
+ chunks = []
102
+ i = 0
103
+ while i < len(all_input_ids):
104
+ chunk_ids = all_input_ids[i : i + MAX_LENGTH]
105
+ chunk_offsets = all_offsets[i : i + MAX_LENGTH]
106
+ attention_mask = [1] * len(chunk_ids)
107
+
108
+ no_span = 1
109
+ start_token, end_token = -1, -1
110
+ if start_idx >= 0 and end_idx >= 0:
111
+ for j, (off_s, off_e) in enumerate(chunk_offsets):
112
+ if off_s <= start_idx < off_e:
113
+ start_token = j
114
+ if off_s < end_idx <= off_e:
115
+ end_token = j
116
+ break
117
+ if 0 <= start_token <= end_token:
118
+ no_span = 0
119
+ else:
120
+ start_token, end_token = -1, -1
121
+
122
+ chunks.append({
123
+ "input_ids": torch.tensor(chunk_ids, dtype=torch.long),
124
+ "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
125
+ "start_label": start_token,
126
+ "end_label": end_token,
127
+ "no_span_label": no_span,
128
+ })
129
+ i += (MAX_LENGTH - OVERLAP)
130
+ return chunks
131
+
132
+
133
+ class SpanExtractionChunkedDataset(Dataset):
134
+ def __init__(self, data):
135
+ self.samples = []
136
+ for item in data:
137
+ chunks = chunk_transcript(
138
+ item.get("transcript", ""),
139
+ item.get("start_idx", -1),
140
+ item.get("end_idx", -1),
141
+ tokenizer)
142
+ self.samples.extend(chunks)
143
+
144
+ def __len__(self):
145
+ return len(self.samples)
146
+
147
+ def __getitem__(self, idx):
148
+ return self.samples[idx]
149
+
150
+
151
+ def span_collate_fn(batch):
152
+ max_len = max(len(x["input_ids"]) for x in batch)
153
+ inputs, masks, starts, ends, nos = [], [], [], [], []
154
+ for x in batch:
155
+ pad = max_len - len(x["input_ids"])
156
+ inputs.append(torch.cat([x["input_ids"], torch.zeros(pad, dtype=torch.long)]).unsqueeze(0))
157
+ masks.append(torch.cat([x["attention_mask"], torch.zeros(pad, dtype=torch.long)]).unsqueeze(0))
158
+ starts.append(x["start_label"])
159
+ ends.append(x["end_label"])
160
+ nos.append(x["no_span_label"])
161
+ return {
162
+ "input_ids": torch.cat(inputs, dim=0),
163
+ "attention_mask": torch.cat(masks, dim=0),
164
+ "start_positions": torch.tensor(starts, dtype=torch.long),
165
+ "end_positions": torch.tensor(ends, dtype=torch.long),
166
+ "no_span_label": torch.tensor(nos, dtype=torch.long),
167
+ }
168
+
169
+
170
+ nltk.download('punkt')
171
+ nltk.download('punkt_tab')
172
+
173
+ class SentenceDataset(Dataset):
174
+ def __init__(self,
175
+ excel_path: str,
176
+ tokenizer,
177
+ max_length: int = 128):
178
+ df = pd.read_excel(excel_path)
179
+ self.samples = []
180
+
181
+ for _, row in df.iterrows():
182
+ transcript = str(row['Claude_Call'])
183
+ gold_sentences = row['Sel_K']
184
+ # if it's a string repr of list, eval it
185
+ if isinstance(gold_sentences, str):
186
+ gold_sentences = eval(gold_sentences)
187
+
188
+ # split into sentences
189
+ sentences = nltk.sent_tokenize(transcript)
190
+ for sent in sentences:
191
+ label = 1 if sent in gold_sentences else 0
192
+
193
+ enc = tokenizer.encode_plus(
194
+ sent,
195
+ max_length=max_length,
196
+ padding='max_length',
197
+ truncation=True,
198
+ return_tensors='pt'
199
+ )
200
+ self.samples.append({
201
+ 'input_ids': enc['input_ids'].squeeze(0),
202
+ 'attention_mask': enc['attention_mask'].squeeze(0),
203
+ 'label': torch.tensor(label, dtype=torch.float)
204
+ })
205
+
206
+ def __len__(self):
207
+ return len(self.samples)
208
+
209
+ def __getitem__(self, idx):
210
+ return self.samples[idx]
requirements.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=1.9.0
2
+ transformers>=4.20.0
3
+ datasets>=2.0.0
4
+ spacy>=3.4.0
5
+ networkx>=2.8
6
+ numpy>=1.21.0
7
+ pandas>=1.3.0
8
+ scikit-learn>=1.0.0
9
+ tqdm>=4.62.0
10
+ matplotlib>=3.5.0
11
+ accelerate>=0.20.0
12
+ peft>=0.4.0
13
+ openpyxl>=3.0.0
14
+ nltk>=3.7
15
+
16
+ # requirements-dev.txt
17
+ pytest>=6.0.0
18
+ pytest-cov>=3.0.0
19
+ black>=22.0.0
20
+ isort>=5.10.0
21
+ flake8>=4.0.0
22
+ mypy>=0.950
23
+ pre-commit>=2.15.0
24
+ sphinx>=4.0.0
25
+ sphinx-rtd-theme>=1.0.0
26
+ jupyter>=1.0.0
results/Ablation.txt ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2025-05-21 16:21:26,783 INFO Using existing preprocessed data at preprocessed_snli
2
+ 2025-05-21 16:21:28,008 INFO --- Training Baseline-BERT on 10% of data ---
3
+ 2025-05-21 16:33:27,656 INFO Baseline-BERT Epoch 1: train_loss=1.1084
4
+ 2025-05-21 16:45:29,223 INFO Baseline-BERT Epoch 2: train_loss=1.1015
5
+ 2025-05-21 16:57:31,263 INFO Baseline-BERT Epoch 3: train_loss=1.1008
6
+ 2025-05-21 17:09:33,145 INFO Baseline-BERT Epoch 4: train_loss=1.1021
7
+ 2025-05-21 17:21:33,012 INFO Baseline-BERT Epoch 5: train_loss=1.1034
8
+ 2025-05-21 17:21:43,717 INFO Baseline-BERT on 10% val β†’ acc=0.3232, f1=0.1628
9
+ 2025-05-21 17:21:43,717 INFO --- Training GNN-Augmented on 10% of data ---
10
+ 2025-05-21 17:35:02,212 INFO GNN-Augmented Epoch 1: train_loss=1.1044
11
+ 2025-05-21 17:48:21,592 INFO GNN-Augmented Epoch 2: train_loss=1.1041
12
+ 2025-05-21 18:01:40,046 INFO GNN-Augmented Epoch 3: train_loss=1.1025
13
+ 2025-05-21 18:14:58,966 INFO GNN-Augmented Epoch 4: train_loss=1.1019
14
+ 2025-05-21 18:29:19,473 INFO GNN-Augmented Epoch 5: train_loss=1.1038
15
+ 2025-05-21 18:29:34,558 INFO GNN-Augmented on 10% val β†’ acc=0.3232, f1=0.1628
16
+
17
+
18
+
19
+ 2025-05-21 16:21:26,783 INFO Using existing preprocessed data at preprocessed_snli
20
+ 2025-05-21 16:21:28,008 INFO --- Training Baseline-BERT on 10% of data ---
21
+ 2025-05-21 16:33:27,656 INFO Baseline-BERT Epoch 1: train_loss=1.1084
22
+ 2025-05-21 16:45:29,223 INFO Baseline-BERT Epoch 2: train_loss=1.1015
23
+ 2025-05-21 16:57:31,263 INFO Baseline-BERT Epoch 3: train_loss=1.1008
24
+ 2025-05-21 17:09:33,145 INFO Baseline-BERT Epoch 4: train_loss=1.1021
25
+ 2025-05-21 17:21:33,012 INFO Baseline-BERT Epoch 5: train_loss=1.1001
26
+ 2025-05-21 17:21:43,717 INFO Baseline-BERT on 10% val β†’ acc=0.3232, f1=0.1628
27
+ 2025-05-21 17:21:43,717 INFO --- Training GNN-Augmented on 10% of data ---
28
+ 2025-05-21 17:35:02,212 INFO GNN-Augmented Epoch 1: train_loss=1.1044
29
+ 2025-05-21 17:48:21,592 INFO GNN-Augmented Epoch 2: train_loss=1.1011
30
+ 2025-05-21 18:01:40,046 INFO GNN-Augmented Epoch 3: train_loss=0.9025
31
+ 2025-05-21 18:14:58,966 INFO GNN-Augmented Epoch 4: train_loss=0.8319
32
+ 2025-05-21 18:29:19,473 INFO GNN-Augmented Epoch 5: train_loss=0.7638
33
+ 2025-05-21 18:29:34,558 INFO GNN-Augmented on 10% val β†’ acc=0.6937, f1=0.4184
34
+
35
+
36
+
37
+
38
+ β–ΆοΈŽ Per‐dataset judge averages:
39
+ FinQA-10K: Judge1=4.78, Judge2=4.46, Judge3=4.42, Combined=4.55
40
+ SQuAD: Judge1=5.00, Judge2=4.84, Judge3=4.52, Combined=4.79
41
+
42
+
43
+ β–ΆοΈŽ Per‐dataset judge averages:
44
+ FinQA-10K: Judge1=4.96, Judge2=4.86, Judge3=4.68, Combined=4.84
45
+ SQuAD: Judge1=5.00, Judge2=4.94, Judge3=4.84, Combined=4.93
46
+
results/Fin-ExBERT.pptx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ef24ff9e727c97ddca4cb52a818c11ef4b61667d76a516cbddc48da3b4d85b2
3
+ size 14227310
results/ablation_study.png ADDED
results/combined_results.xlsx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a43f4de8658d50e234dd910d74b2b34f5e5d1fd0c910a2e0e830dd3360906a19
3
+ size 127995
results/fine_tuning_results.png ADDED
results/methods_summary.xlsx ADDED
Binary file (10.7 kB). View file
 
utils.py ADDED
@@ -0,0 +1,967 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import random
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from nltk import sent_tokenize
9
+ from sklearn.metrics import accuracy_score, precision_score, f1_score
10
+ from sklearn.model_selection import train_test_split
11
+ from torch.utils.data import DataLoader, random_split, WeightedRandomSampler
12
+ from transformers import AutoTokenizer, AutoModel, AutoConfig, get_linear_schedule_with_warmup
13
+ from peft import PeftModel, LoraConfig, get_peft_model
14
+ from datasets import load_dataset, DatasetDict, load_from_disk
15
+ import spacy
16
+ import re
17
+ from tqdm.auto import tqdm
18
+ from accelerate import Accelerator
19
+ import matplotlib.pyplot as plt
20
+ from torch.optim import AdamW
21
+ import pandas as pd
22
+ from typing import Optional, Tuple, List, Dict
23
+
24
+ from models import GraphAugmentedNLIModel, GraphAugmentedFinNLIModel
25
+ from preprocess_data import SpanExtractionChunkedDataset, process_data, chunk_transcript, span_collate_fn
26
+
27
+ # =============================
28
+ # Configuration Constants
29
+ # =============================
30
+ from config import MODEL_NAME, MAX_LENGTH, OVERLAP, PREPROCESSED_DIR, tokenizer, nlp
31
+
32
+ #MODEL_NAME = "bert-base-uncased"
33
+ BATCH_SIZE = 16
34
+ #MAX_LENGTH = 128
35
+ #OVERLAP = 32
36
+ LEARNING_RATE = 2e-5
37
+ EPOCHS = 5
38
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
+ #PREPROCESSED_DIR = "preprocessed_snli"
40
+ MIXED_PRECISION = "fp16"
41
+
42
+ # label mapping
43
+ label_map = {0: "entailment", 1: "neutral", 2: "contradiction"}
44
+
45
+ # =============================
46
+ # Logging & Reproducibility
47
+ # =============================
48
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
49
+ def set_seed(seed: int = 42):
50
+ random.seed(seed)
51
+ np.random.seed(seed)
52
+ torch.manual_seed(seed)
53
+ if torch.cuda.is_available():
54
+ torch.cuda.manual_seed_all(seed)
55
+
56
+ # =============================
57
+ # Tokenizer & NLP Model
58
+ # =============================
59
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
60
+ nlp = spacy.load("en_core_web_sm")
61
+
62
+ # =============================
63
+ # Dependency Graph Helpers
64
+ # =============================
65
+ def build_dependency_graph(sentence: str):
66
+ doc = nlp(sentence)
67
+ tokens = [token.text for token in doc]
68
+ edges = []
69
+ for token in doc:
70
+ if token.head.i != token.i:
71
+ edges.append((token.i, token.head.i))
72
+ edges.append((token.head.i, token.i))
73
+ return tokens, edges
74
+
75
+ # =============================
76
+ # Token Alignment
77
+ # =============================
78
+ def align_tokens(spacy_tokens, wp_tokens):
79
+ node_indices = []
80
+ wp_idx = 1 # after [CLS]
81
+ for _ in spacy_tokens:
82
+ if wp_idx >= len(wp_tokens) - 1:
83
+ break
84
+ node_indices.append(wp_idx)
85
+ wp_idx += 1
86
+ while wp_idx < len(wp_tokens) - 1 and wp_tokens[wp_idx].startswith("##"):
87
+ wp_idx += 1
88
+ return node_indices
89
+
90
+ # =============================
91
+ # Data Collation
92
+ # =============================
93
+ def my_collate_fn(batch):
94
+ input_ids = [torch.tensor(ex["input_ids"], dtype=torch.long) for ex in batch]
95
+ attention_mask = [torch.tensor(ex["attention_mask"], dtype=torch.long) for ex in batch]
96
+ labels = [ex.get("labels", None) for ex in batch]
97
+
98
+ premise_graph_tokens = [ex.get("premise_graph_tokens") for ex in batch]
99
+ premise_graph_edges = [ex.get("premise_graph_edges") for ex in batch]
100
+ premise_node_indices = [ex.get("premise_node_indices") for ex in batch]
101
+
102
+ hypothesis_graph_tokens = [ex.get("hypothesis_graph_tokens") for ex in batch]
103
+ hypothesis_graph_edges = [ex.get("hypothesis_graph_edges") for ex in batch]
104
+ hypothesis_node_indices = [ex.get("hypothesis_node_indices") for ex in batch]
105
+
106
+ input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
107
+ attention_mask = torch.nn.utils.rnn.pad_sequence(attention_mask, batch_first=True, padding_value=0)
108
+ labels = torch.tensor(labels, dtype=torch.long) if labels and labels[0] is not None else None
109
+
110
+ return {
111
+ "input_ids": input_ids,
112
+ "attention_mask": attention_mask,
113
+ "labels": labels,
114
+ "premise_graph_tokens": premise_graph_tokens,
115
+ "premise_graph_edges": premise_graph_edges,
116
+ "premise_node_indices": premise_node_indices,
117
+ "hypothesis_graph_tokens": hypothesis_graph_tokens,
118
+ "hypothesis_graph_edges": hypothesis_graph_edges,
119
+ "hypothesis_node_indices": hypothesis_node_indices,
120
+ }
121
+
122
+ # =============================
123
+ # Training Loop
124
+ # =============================
125
+ def train_model(epochs: int = EPOCHS,
126
+ batch_size: int = BATCH_SIZE,
127
+ lr: float = LEARNING_RATE,
128
+ save_model: bool = False,
129
+ save_path: str = 'gnn_model_weights_3.pt'):
130
+ set_seed()
131
+ process_data()
132
+ logging.info("Loading preprocessed dataset...")
133
+ snli = load_from_disk(PREPROCESSED_DIR)
134
+ snli.set_format("python", output_all_columns=True)
135
+
136
+ train_loader = DataLoader(snli["train"], batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn)
137
+ val_loader = DataLoader(snli["validation"], batch_size=batch_size, collate_fn=my_collate_fn)
138
+
139
+ model = GraphAugmentedNLIModel(MODEL_NAME).to(DEVICE)
140
+
141
+ if hasattr(model.bert, 'gradient_checkpointing_enable'):
142
+ model.bert.gradient_checkpointing_enable()
143
+ logging.info("Enabled gradient checkpointing on BERT.")
144
+
145
+ optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
146
+ num_training_steps = epochs * len(train_loader)
147
+ lr_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=num_training_steps)
148
+
149
+ accelerator = Accelerator(mixed_precision=MIXED_PRECISION)
150
+ model, optimizer, train_loader, val_loader, lr_scheduler = accelerator.prepare(
151
+ model, optimizer, train_loader, val_loader, lr_scheduler
152
+ )
153
+
154
+ model.train()
155
+ all_losses = []
156
+ epoch_losses = []
157
+ best_val_loss = float('inf')
158
+ best_epoch = 0
159
+
160
+ for epoch in range(1, epochs + 1):
161
+ epoch_loss = []
162
+ progress = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", leave=False)
163
+ for batch in progress:
164
+ labels = batch["labels"].to(DEVICE) if batch.get("labels") is not None else None
165
+ outputs = model(
166
+ input_ids=batch["input_ids"].to(DEVICE),
167
+ attention_mask=batch["attention_mask"].to(DEVICE),
168
+ premise_graph_tokens=batch["premise_graph_tokens"],
169
+ premise_graph_edges=batch["premise_graph_edges"],
170
+ premise_node_indices=batch["premise_node_indices"],
171
+ hypothesis_graph_tokens=batch["hypothesis_graph_tokens"],
172
+ hypothesis_graph_edges=batch["hypothesis_graph_edges"],
173
+ hypothesis_node_indices=batch["hypothesis_node_indices"],
174
+ labels=labels
175
+ )
176
+ loss = outputs.get("loss") if isinstance(outputs, dict) else outputs
177
+
178
+ optimizer.zero_grad()
179
+ accelerator.backward(loss)
180
+ optimizer.step()
181
+ lr_scheduler.step()
182
+
183
+ loss_val = loss.item()
184
+ epoch_loss.append(loss_val)
185
+ all_losses.append(loss_val)
186
+ progress.set_postfix({"loss": f"{loss_val:.4f}"})
187
+
188
+ avg_epoch_loss = np.mean(epoch_loss)
189
+ epoch_losses.append(avg_epoch_loss)
190
+ logging.info(f"Epoch {epoch} completed. Avg Loss: {avg_epoch_loss:.4f}")
191
+
192
+ # Validation
193
+ model.eval()
194
+ val_losses = []
195
+ with torch.no_grad():
196
+ for batch in val_loader:
197
+ labels = batch["labels"].to(DEVICE) if batch.get("labels") is not None else None
198
+ outputs = model(
199
+ input_ids=batch["input_ids"].to(DEVICE),
200
+ attention_mask=batch["attention_mask"].to(DEVICE),
201
+ premise_graph_tokens=batch["premise_graph_tokens"],
202
+ premise_graph_edges=batch["premise_graph_edges"],
203
+ premise_node_indices=batch["premise_node_indices"],
204
+ hypothesis_graph_tokens=batch["hypothesis_graph_tokens"],
205
+ hypothesis_graph_edges=batch["hypothesis_graph_edges"],
206
+ hypothesis_node_indices=batch["hypothesis_node_indices"],
207
+ labels=labels
208
+ )
209
+ loss_item = outputs.get("loss").item() if isinstance(outputs, dict) else outputs.item()
210
+ val_losses.append(loss_item)
211
+ avg_val_loss = np.mean(val_losses) if val_losses else float('inf')
212
+ logging.info(f"Validation Loss after Epoch {epoch}: {avg_val_loss:.4f}")
213
+
214
+ if avg_val_loss < best_val_loss:
215
+ best_val_loss = avg_val_loss
216
+ best_epoch = epoch
217
+ if save_model:
218
+ logging.info(f"Saving best model at epoch {epoch} with val loss {avg_val_loss:.4f}")
219
+ torch.save(model.state_dict(), save_path)
220
+ model.train()
221
+
222
+ # Plot losses
223
+ plt.figure()
224
+ plt.plot(all_losses)
225
+ plt.xlabel('Training steps')
226
+ plt.ylabel('Loss')
227
+ plt.title('Step-wise Training Loss')
228
+ plt.show()
229
+
230
+ plt.figure()
231
+ plt.plot(range(1, epochs+1), epoch_losses, marker='o')
232
+ plt.xlabel('Epochs')
233
+ plt.ylabel('Loss')
234
+ plt.title('Epoch-wise Training Loss')
235
+ plt.show()
236
+
237
+ logging.info(f"Training complete. Best validation loss {best_val_loss:.4f} at epoch {best_epoch}.")
238
+ return model
239
+
240
+
241
+ def predict_nli(premise, hypothesis, tokenizer=tokenizer, model_path='gnn_model_checkpoint.pt'):
242
+ # 1) instantiate the model exactly as you did during training
243
+ model = GraphAugmentedNLIModel(MODEL_NAME).to(DEVICE)
244
+
245
+ # 2) load the checkpoint, then hand only the model weights to load_state_dict
246
+ ckpt = torch.load(model_path, map_location=DEVICE)
247
+ model.load_state_dict(ckpt["model_state_dict"])
248
+
249
+ model.eval()
250
+
251
+ # 3) tokenize & build graphs (as before)…
252
+ encoded = tokenizer(
253
+ premise, hypothesis,
254
+ truncation=True,
255
+ padding="max_length",
256
+ max_length=MAX_LENGTH,
257
+ return_tensors="pt"
258
+ )
259
+
260
+ input_ids = encoded["input_ids"]
261
+ attention_mask = encoded["attention_mask"]
262
+
263
+ # Build dependency graphs
264
+ p_tokens, p_edges = build_dependency_graph(premise)
265
+ h_tokens, h_edges = build_dependency_graph(hypothesis)
266
+
267
+ # Convert ids back to tokens for alignment
268
+ wp_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
269
+
270
+ p_node_indices = align_tokens(p_tokens, wp_tokens)
271
+ h_node_indices = align_tokens(h_tokens, wp_tokens)
272
+
273
+ # Move tensors to the same device as the model
274
+ device = next(model.parameters()).device
275
+ input_ids = input_ids.to(device)
276
+ attention_mask = attention_mask.to(device)
277
+
278
+ # Prepare inputs for the model: the model expects lists for graph fields
279
+ # since we used a custom collate_fn logic.
280
+ premise_graph_tokens = [p_tokens]
281
+ premise_graph_edges = [p_edges]
282
+ premise_node_indices = [p_node_indices]
283
+
284
+ hypothesis_graph_tokens = [h_tokens]
285
+ hypothesis_graph_edges = [h_edges]
286
+ hypothesis_node_indices = [h_node_indices]
287
+
288
+ with torch.no_grad():
289
+ outputs = model(
290
+ input_ids=input_ids,
291
+ attention_mask=attention_mask,
292
+ premise_graph_tokens=premise_graph_tokens,
293
+ premise_graph_edges=premise_graph_edges,
294
+ premise_node_indices=premise_node_indices,
295
+ hypothesis_graph_tokens=hypothesis_graph_tokens,
296
+ hypothesis_graph_edges=hypothesis_graph_edges,
297
+ hypothesis_node_indices=hypothesis_node_indices
298
+ )
299
+
300
+ logits = outputs["logits"]
301
+ probs = F.softmax(logits, dim=-1).cpu().numpy()[0]
302
+ # Get predicted label
303
+ predicted_label_id = torch.argmax(logits, dim=-1).item()
304
+ predicted_label = label_map[predicted_label_id]
305
+ prob_map = dict()
306
+ for i, cls_label in label_map.items():
307
+ prob_map[cls_label] = probs[i]
308
+ return predicted_label, prob_map
309
+
310
+
311
+ def predict_fin_nli(
312
+ premise: str,
313
+ hypothesis: str,
314
+ tokenizer=tokenizer,
315
+ model_path: str = 'gnn_model_checkpoint.pt',
316
+ adapter_dir: str = './lora_finance_adapter',
317
+ ) -> (str, list):
318
+ # 1) Load base GraphAugmentedFinNLIModel and its checkpoint
319
+ base_model = GraphAugmentedFinNLIModel(MODEL_NAME).to(DEVICE)
320
+ ckpt = torch.load(model_path, map_location=DEVICE)
321
+ base_model.load_state_dict(ckpt['model_state_dict'])
322
+
323
+ # 2) Wrap with the same LoRA config you used in training
324
+ lora_cfg = LoraConfig(
325
+ r=8,
326
+ lora_alpha=32,
327
+ lora_dropout=0.1,
328
+ bias='none',
329
+ task_type='SEQ_CLS',
330
+ target_modules=['query', 'value']
331
+ )
332
+ model = get_peft_model(base_model, lora_cfg).to(DEVICE)
333
+
334
+ # 3) Load your adapter checkpoint (the .pt under lora_finance_adapter/)
335
+ adapter_ckpt = torch.load(os.path.join(adapter_dir, 'training_checkpoint.pt'), map_location=DEVICE)
336
+ # This checkpoint contains the same 'model_state_dict' keysβ€”so load it leniently:
337
+ model.load_state_dict(adapter_ckpt['model_state_dict'], strict=False)
338
+ model.eval()
339
+
340
+ # 4) Tokenize
341
+ enc = tokenizer(
342
+ premise, hypothesis,
343
+ truncation=True,
344
+ padding='max_length',
345
+ max_length=MAX_LENGTH,
346
+ return_tensors='pt'
347
+ )
348
+ input_ids = enc['input_ids'].to(DEVICE)
349
+ attention_mask = enc['attention_mask'].to(DEVICE)
350
+
351
+ # 5) Build & align your dependency graphs
352
+ p_toks, p_edges = build_dependency_graph(premise)
353
+ h_toks, h_edges = build_dependency_graph(hypothesis)
354
+ wp = tokenizer.convert_ids_to_tokens(input_ids[0])
355
+ p_idx = align_tokens(p_toks, wp)
356
+ h_idx = align_tokens(h_toks, wp)
357
+
358
+ premise_graph_tokens = [p_toks]
359
+ premise_graph_edges = [p_edges]
360
+ premise_node_indices = [p_idx]
361
+ hypothesis_graph_tokens = [h_toks]
362
+ hypothesis_graph_edges = [h_edges]
363
+ hypothesis_node_indices = [h_idx]
364
+
365
+ # 6) Forward
366
+ with torch.no_grad():
367
+ out = model(
368
+ input_ids=input_ids,
369
+ attention_mask=attention_mask,
370
+ premise_graph_tokens=premise_graph_tokens,
371
+ premise_graph_edges=premise_graph_edges,
372
+ premise_node_indices=premise_node_indices,
373
+ hypothesis_graph_tokens=hypothesis_graph_tokens,
374
+ hypothesis_graph_edges=hypothesis_graph_edges,
375
+ hypothesis_node_indices=hypothesis_node_indices
376
+ )
377
+
378
+ logits = out['logits'][0] # shape [3]
379
+ probs = torch.softmax(logits, dim=-1).cpu().numpy()
380
+
381
+ # 7) Collapse to entailment vs. contradiction (ignore neutral)
382
+ entail, neutral, contra = probs
383
+ s = entail + contra + 1e-12
384
+ scores = [entail / s, contra / s]
385
+ label = 'entailment' if entail >= contra else 'contradiction'
386
+ return label, scores
387
+
388
+
389
+ def train_model_with_chkpt(epochs: int = 5,
390
+ batch_size: int = 16,
391
+ lr: float = 2e-5,
392
+ save_model: bool = False,
393
+ save_path: str = 'gnn_model_checkpoint.pt',
394
+ resume: bool = False):
395
+ """
396
+ Train with mixed precision, gradient checkpointing, and resume support.
397
+ If resume=True and save_path exists, picks up from last epoch.
398
+ """
399
+ set_seed()
400
+ process_data()
401
+ logging.info("Loading preprocessed dataset…")
402
+ snli = load_from_disk(PREPROCESSED_DIR)
403
+ snli.set_format("python", output_all_columns=True)
404
+
405
+ train_loader = DataLoader(snli["train"], batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn)
406
+ val_loader = DataLoader(snli["validation"], batch_size=batch_size, collate_fn=my_collate_fn)
407
+
408
+ model = GraphAugmentedNLIModel(MODEL_NAME).to(DEVICE)
409
+ optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
410
+ total_steps = epochs * len(train_loader)
411
+ scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=total_steps)
412
+
413
+ # --- Resume checkpoint if requested ---
414
+ start_epoch = 1
415
+ if resume and os.path.isfile(save_path):
416
+ ckpt = torch.load(save_path, map_location=DEVICE)
417
+ model.load_state_dict(ckpt["model_state_dict"])
418
+ optimizer.load_state_dict(ckpt["optimizer_state_dict"])
419
+ scheduler.load_state_dict(ckpt["scheduler_state_dict"])
420
+ start_epoch = ckpt.get("epoch", 1) + 1
421
+ logging.info(f"Resuming from epoch {start_epoch}")
422
+
423
+ # Mixed precision setup
424
+ if hasattr(model.bert, "gradient_checkpointing_enable"):
425
+ model.bert.gradient_checkpointing_enable()
426
+ logging.info("Enabled gradient checkpointing on BERT.")
427
+ accelerator = Accelerator(mixed_precision=MIXED_PRECISION)
428
+ model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(
429
+ model, optimizer, train_loader, val_loader, scheduler
430
+ )
431
+
432
+ best_val_loss = float("inf")
433
+ for epoch in range(start_epoch, epochs + 1):
434
+ model.train()
435
+ train_losses = []
436
+ for batch in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}"):
437
+ optimizer.zero_grad()
438
+ outputs = model(
439
+ input_ids=batch["input_ids"].to(DEVICE),
440
+ attention_mask=batch["attention_mask"].to(DEVICE),
441
+ premise_graph_tokens=batch["premise_graph_tokens"],
442
+ premise_graph_edges=batch["premise_graph_edges"],
443
+ premise_node_indices=batch["premise_node_indices"],
444
+ hypothesis_graph_tokens=batch["hypothesis_graph_tokens"],
445
+ hypothesis_graph_edges=batch["hypothesis_graph_edges"],
446
+ hypothesis_node_indices=batch["hypothesis_node_indices"],
447
+ labels=batch.get("labels", None).to(DEVICE) if batch.get("labels") is not None else None
448
+ )
449
+ loss = outputs["loss"] if isinstance(outputs, dict) else outputs
450
+ accelerator.backward(loss)
451
+ optimizer.step()
452
+ scheduler.step()
453
+ train_losses.append(loss.item())
454
+ avg_train = np.mean(train_losses)
455
+ logging.info(f"Epoch {epoch} train loss: {avg_train:.4f}")
456
+
457
+ # Validation
458
+ model.eval()
459
+ val_losses = []
460
+ with torch.no_grad():
461
+ for batch in val_loader:
462
+ outputs = model(
463
+ input_ids=batch["input_ids"].to(DEVICE),
464
+ attention_mask=batch["attention_mask"].to(DEVICE),
465
+ premise_graph_tokens=batch["premise_graph_tokens"],
466
+ premise_graph_edges=batch["premise_graph_edges"],
467
+ premise_node_indices=batch["premise_node_indices"],
468
+ hypothesis_graph_tokens=batch["hypothesis_graph_tokens"],
469
+ hypothesis_graph_edges=batch["hypothesis_graph_edges"],
470
+ hypothesis_node_indices=batch["hypothesis_node_indices"],
471
+ labels=batch.get("labels", None).to(DEVICE) if batch.get("labels") is not None else None
472
+ )
473
+ v_loss = outputs["loss"].item() if isinstance(outputs, dict) else outputs.item()
474
+ val_losses.append(v_loss)
475
+ avg_val = np.mean(val_losses) if val_losses else float("inf")
476
+ logging.info(f"Epoch {epoch} val loss: {avg_val:.4f}")
477
+
478
+ # Save checkpoint
479
+ ckpt = {
480
+ "epoch": epoch,
481
+ "model_state_dict": model.state_dict(),
482
+ "optimizer_state_dict": optimizer.state_dict(),
483
+ "scheduler_state_dict": scheduler.state_dict(),
484
+ }
485
+ torch.save(ckpt, save_path)
486
+ logging.info(f"Saved checkpoint: {save_path}")
487
+
488
+ if avg_val < best_val_loss:
489
+ best_val_loss = avg_val
490
+
491
+ logging.info(f"Training complete. Best val loss: {best_val_loss:.4f}")
492
+ return model
493
+
494
+
495
+ def extract_sentences_by_intent(
496
+ text: str,
497
+ intent: str,
498
+ adapter_dir: str = "./lora_finance_adapter",
499
+ threshold: float = 0.7,
500
+ top_k: int = None,
501
+ min_words: int = 4,
502
+ convo_focus: str = None
503
+ ):
504
+ """
505
+ Splits `text` into sentences, embeds them (and the `intent`) under your
506
+ LoRA‐adapted BERT, and returns those whose cosine similarity β‰₯ `threshold`.
507
+ Loads the adapter from the single `training_checkpoint.pt` in `adapter_dir`.
508
+ """
509
+ # 1) Sentence split & cleanup
510
+ # 1) Only consider lines spoken by the customer
511
+
512
+ if convo_focus is None:
513
+ sentences = [sent.text.strip() for sent in nlp(text).sents if sent.text.strip()]
514
+
515
+ elif convo_focus == "customer":
516
+ customer_lines = [
517
+ line.strip()
518
+ for line in text.splitlines()
519
+ if line.strip().lower().startswith("customer:")
520
+ ]
521
+
522
+ # 2) Sentence-split each customer line
523
+ sentences = []
524
+ for cust_line in customer_lines:
525
+ for sent in nlp(cust_line).sents:
526
+ s = sent.text.strip()
527
+ if s and len(s.split(' '))>6:
528
+ sentences.append(s)
529
+
530
+ else:
531
+ customer_lines = [
532
+ line.strip()
533
+ for line in text.splitlines()
534
+ if line.strip().lower().startswith("agent:")
535
+ ]
536
+
537
+ # 2) Sentence-split each customer line
538
+ sentences = []
539
+ for cust_line in customer_lines:
540
+ for sent in nlp(cust_line).sents:
541
+ s = sent.text.strip()
542
+ if s and len(s.split(' '))>6:
543
+ sentences.append(s)
544
+
545
+ # 2) Load base BERT + wrap in same LoRA config
546
+ base_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
547
+ lora_cfg = LoraConfig(
548
+ r=8,
549
+ lora_alpha=32,
550
+ lora_dropout=0.1,
551
+ bias="none",
552
+ task_type="CAUSAL_LM", # must match your fine-tune setting
553
+ )
554
+ model = get_peft_model(base_model, lora_cfg).to(DEVICE)
555
+
556
+ # 3) Load your adapter checkpoint
557
+ chkpt_path = os.path.join(adapter_dir, "training_checkpoint.pt")
558
+ if not os.path.isfile(chkpt_path):
559
+ raise FileNotFoundError(f"No LoRA checkpoint at {chkpt_path}")
560
+ ckpt = torch.load(chkpt_path, map_location=DEVICE)
561
+ # ckpt["model_state_dict"] contains both base + LoRA weights; strict=False
562
+ model.load_state_dict(ckpt["model_state_dict"], strict=False)
563
+ model.eval()
564
+
565
+ # helper: get [CLS] embedding under LoRA-BERT
566
+ def embed(text_str):
567
+ toks = tokenizer(
568
+ text_str,
569
+ truncation=True,
570
+ padding="longest",
571
+ return_tensors="pt"
572
+ ).to(DEVICE)
573
+
574
+ em_args = {
575
+ "input_ids": toks["input_ids"],
576
+ "attention_mask": toks["attention_mask"],
577
+ }
578
+ if "token_type_ids" in toks:
579
+ em_args["token_type_ids"] = toks["token_type_ids"]
580
+
581
+ # unwrap PEFT to call only the base BertModel
582
+ hf_model = getattr(model, "base_model", model)
583
+ with torch.no_grad():
584
+ last_hidden = hf_model(
585
+ input_ids=em_args["input_ids"],
586
+ attention_mask=em_args["attention_mask"],
587
+ **({"token_type_ids": em_args["token_type_ids"]} if "token_type_ids" in em_args else {})
588
+ ).last_hidden_state
589
+ return last_hidden[:, 0, :]
590
+
591
+ # now embed(intent) and each sentence using this safe helper
592
+ intent_emb = embed(intent)
593
+
594
+ results = []
595
+ with torch.no_grad():
596
+ for sent in sentences:
597
+ clean = re.sub(r'^(Agent|Customer):\s*', "", sent)
598
+ if len(clean.split()) < min_words:
599
+ continue
600
+
601
+ sent_emb = embed(clean)
602
+ sim = F.cosine_similarity(sent_emb, intent_emb, dim=1).item()
603
+ if sim >= threshold:
604
+ results.append((clean, sim))
605
+
606
+ # 5) sort & trim
607
+ results.sort(key=lambda x: x[1], reverse=True)
608
+ return results[:top_k] if top_k else results
609
+
610
+
611
+ def train_sentence_extractor(
612
+ model: nn.Module,
613
+ dataset: torch.utils.data.Dataset,
614
+ output_dir: str,
615
+ val_split: float = 0.2,
616
+ epochs: int = 3,
617
+ batch_size: int = 16,
618
+ lr: float = 2e-5,
619
+ device: str = "cpu",
620
+ unfreeze_after_epoch: int = 1,
621
+ threshold: float = 0.5
622
+ ):
623
+ """
624
+ Fine-tune `model` on `dataset`, hold out `val_split` for val,
625
+ compute loss + acc + precision + F1 each epoch, save best checkpoint,
626
+ and plot all four metrics at the end.
627
+ """
628
+ # Split
629
+ total = len(dataset)
630
+ val_n = int(total * val_split)
631
+ train_n = total - val_n
632
+ train_ds, val_ds = random_split(dataset, [train_n, val_n])
633
+
634
+ # Oversample train
635
+ train_labels = [train_ds[i]['label'].item() for i in range(len(train_ds))]
636
+ counts = torch.bincount(torch.tensor(train_labels, dtype=torch.long))
637
+ weights = (1.0 / counts.float()).tolist()
638
+ sample_weights = [weights[int(l)] for l in train_labels]
639
+ sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
640
+
641
+ train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler, drop_last=True)
642
+ val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
643
+
644
+ model.to(device)
645
+ # initially freeze backbone
646
+ for p in model.bert.parameters(): p.requires_grad = False
647
+
648
+ optimizer = AdamW(model.parameters(), lr=lr)
649
+ total_steps = epochs * len(train_loader)
650
+ scheduler = get_linear_schedule_with_warmup(
651
+ optimizer,
652
+ num_warmup_steps=int(0.1 * total_steps),
653
+ num_training_steps=total_steps
654
+ )
655
+ criterion = nn.BCEWithLogitsLoss()
656
+
657
+ # storage for metrics
658
+ train_losses, val_losses = [], []
659
+ train_accs, val_accs = [], []
660
+ train_precs, val_precs = [], []
661
+ train_f1s, val_f1s = [], []
662
+
663
+ best_val_loss = float('inf')
664
+
665
+ for epoch in range(1, epochs+1):
666
+ # β€”β€” TRAIN β€”β€”
667
+ model.train()
668
+ epoch_loss = 0.0
669
+ preds, labels = [], []
670
+ for batch in tqdm(train_loader, desc=f"Train {epoch}/{epochs}"):
671
+ inputs = batch['input_ids'].to(device)
672
+ masks = batch['attention_mask'].to(device)
673
+ labs = batch['label'].to(device)
674
+
675
+ optimizer.zero_grad()
676
+ logits = model(inputs, masks) # raw logits
677
+ loss = criterion(logits, labs)
678
+ loss.backward()
679
+ optimizer.step()
680
+ scheduler.step()
681
+
682
+ epoch_loss += loss.item()
683
+
684
+ probs = torch.sigmoid(logits)
685
+ batch_preds = (probs >= threshold).long()
686
+ preds.extend(batch_preds.cpu().tolist())
687
+ labels.extend(labs.cpu().long().tolist())
688
+
689
+ avg_train = epoch_loss / len(train_loader)
690
+ train_losses.append(avg_train)
691
+ train_accs.append( accuracy_score(labels, preds) )
692
+ train_precs.append( precision_score(labels, preds, zero_division=0) )
693
+ train_f1s.append( f1_score(labels, preds, zero_division=0) )
694
+ print(f"β†’ Epoch {epoch} Train β€” loss {avg_train:.4f}, acc {train_accs[-1]:.4f}, prec {train_precs[-1]:.4f}, f1 {train_f1s[-1]:.4f}")
695
+
696
+ # unfreeze if needed
697
+ if epoch == unfreeze_after_epoch:
698
+ for p in model.bert.parameters(): p.requires_grad = True
699
+ optimizer = AdamW([
700
+ {"params": model.classifier.parameters(), "lr": 1e-3},
701
+ {"params": model.bert.parameters(), "lr": 1e-5},
702
+ ], weight_decay=1e-2)
703
+ scheduler = get_linear_schedule_with_warmup(
704
+ optimizer,
705
+ num_warmup_steps=int(0.1 * total_steps),
706
+ num_training_steps=total_steps
707
+ )
708
+
709
+ # β€”β€” VALIDATION β€”β€”
710
+ model.eval()
711
+ epoch_loss = 0.0
712
+ preds, labels = [], []
713
+ with torch.no_grad():
714
+ for batch in tqdm(val_loader, desc=f" Val {epoch}/{epochs}"):
715
+ inputs = batch['input_ids'].to(device)
716
+ masks = batch['attention_mask'].to(device)
717
+ labs = batch['label'].to(device)
718
+
719
+ logits = model(inputs, masks)
720
+ loss = criterion(logits, labs)
721
+ epoch_loss += loss.item()
722
+
723
+ probs = torch.sigmoid(logits)
724
+ batch_preds = (probs >= threshold).long()
725
+ preds.extend(batch_preds.cpu().tolist())
726
+ labels.extend(labs.cpu().long().tolist())
727
+
728
+ avg_val = epoch_loss / len(val_loader)
729
+ val_losses.append(avg_val)
730
+ val_accs.append( accuracy_score(labels, preds) )
731
+ val_precs.append( precision_score(labels, preds, zero_division=0) )
732
+ val_f1s.append( f1_score(labels, preds, zero_division=0) )
733
+ print(f"β†’ Epoch {epoch} Val β€” loss {avg_val:.4f}, acc {val_accs[-1]:.4f}, prec {val_precs[-1]:.4f}, f1 {val_f1s[-1]:.4f}")
734
+
735
+ # checkpoints
736
+ os.makedirs(output_dir, exist_ok=True)
737
+ ckpt = os.path.join(output_dir, f"epo{epoch}_val{avg_val:.4f}.pth")
738
+ torch.save(model.state_dict(), ckpt)
739
+ if avg_val < best_val_loss:
740
+ best_val_loss = avg_val
741
+ torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pth"))
742
+ print(f"πŸŽ‰ New best model saved (val loss {best_val_loss:.4f})")
743
+
744
+ print(f"βœ”οΈ Training complete β€” best val loss: {best_val_loss:.4f}")
745
+
746
+ # β€”β€” PLOT METRICS β€”β€”
747
+ epochs = list(range(1, epochs+1))
748
+
749
+ save_metric_plot(
750
+ epochs,
751
+ train_losses,
752
+ val_losses,
753
+ metric_name="Loss",
754
+ output_path="results/Loss_Plot.png"
755
+ )
756
+
757
+ save_metric_plot(
758
+ epochs,
759
+ train_accs,
760
+ val_accs,
761
+ metric_name="Accuracy",
762
+ output_path="results/Accuracy_Plot.png",
763
+ threshold=0.5
764
+ )
765
+
766
+ save_metric_plot(
767
+ epochs,
768
+ train_precs,
769
+ val_precs,
770
+ metric_name="Precision",
771
+ output_path="results/Precision_Plot.png",
772
+ threshold=0.5
773
+ )
774
+
775
+ save_metric_plot(
776
+ epochs,
777
+ train_f1s,
778
+ val_f1s,
779
+ metric_name="F1 Score",
780
+ output_path="results/F1Score_Plot.png",
781
+ threshold=0.5
782
+ )
783
+
784
+
785
+ def save_metric_plot(
786
+ epochs,
787
+ train_vals,
788
+ val_vals,
789
+ metric_name: str,
790
+ output_path: str,
791
+ threshold: float = None
792
+ ):
793
+ """
794
+ epochs – list of epoch indices
795
+ train_vals – list of train metric values
796
+ val_vals – list of validation metric values
797
+ metric_name – e.g. "Loss", "Accuracy", "Precision", "F1 Score"
798
+ output_path – where to save the PNG
799
+ threshold – optional horizontal line to draw, e.g. 0.5
800
+ """
801
+ fig, ax = plt.subplots(figsize=(8, 5))
802
+ ax.plot(epochs, train_vals, marker='o', linewidth=2, label=f'Train {metric_name}')
803
+ ax.plot(epochs, val_vals, marker='s', linewidth=2, label=f'Val {metric_name}')
804
+
805
+ if threshold is not None:
806
+ ax.axhline(threshold, color='gray', linestyle='--', linewidth=1, label=f'Threshold = {threshold}')
807
+
808
+ ax.set_title(f'{metric_name} over Epochs', fontsize=14, pad=10)
809
+ ax.set_xlabel('Epoch', fontsize=12)
810
+ ax.set_ylabel(metric_name, fontsize=12)
811
+ ax.grid(True, linestyle='--', alpha=0.4)
812
+ ax.legend(loc='best', frameon=True, fontsize=10)
813
+ fig.tight_layout()
814
+ fig.savefig(output_path, dpi=300)
815
+ plt.close(fig)
816
+
817
+
818
+ def demo_on_random_val(
819
+ model,
820
+ tokenizer,
821
+ excel_path: str,
822
+ ckpt_path: str,
823
+ max_length: int = 128,
824
+ device: str = "cpu",
825
+ temperature: float = 1.0
826
+ ):
827
+ """
828
+ Like demo_on_random_val, but instead of a fixed threshold:
829
+ 1) Compute sigmoid(logits / temperature) for each sentence
830
+ 2) Sort probabilities descending
831
+ 3) Find the largest gap between adjacent probs
832
+ 4) Set dynamic_threshold = midpoint of that gap
833
+ 5) Extract all sentences with prob >= dynamic_threshold
834
+ """
835
+ # load model
836
+ model.load_state_dict(torch.load(ckpt_path, map_location=device))
837
+ model.to(device).eval()
838
+
839
+ # sample one from validation split
840
+ df = pd.read_excel(excel_path)
841
+ _, val_df = train_test_split(df, test_size=0.2, random_state=42)
842
+ row = val_df.sample(n=1, random_state=random.randint(0,999)).iloc[0]
843
+ transcript = str(row['Claude_Call'])
844
+ print(f"\n── Transcript (val sample idx={row['idx']}):\n{transcript}\n")
845
+
846
+ # split into sentences & run inference
847
+ sentences, probs = [], []
848
+ for sent in sent_tokenize(transcript):
849
+ enc = tokenizer.encode_plus(
850
+ sent,
851
+ max_length=max_length,
852
+ padding='max_length',
853
+ truncation=True,
854
+ return_tensors='pt'
855
+ )
856
+ logits = model(enc['input_ids'].to(device),
857
+ enc['attention_mask'].to(device))
858
+ prob = torch.sigmoid(logits / temperature).item()
859
+ sentences.append(sent)
860
+ probs.append(prob)
861
+
862
+ # print all
863
+ print("Sentence probabilities:")
864
+ for s,p in zip(sentences, probs):
865
+ print(f" β†’ {p:.4f} β†’ {s}")
866
+
867
+ # if no variation, fall back to 0.5
868
+ if len(probs) < 2 or max(probs) - min(probs) < 1e-3:
869
+ dynamic_thr = 0.5
870
+ else:
871
+ # find elbow in sorted probabilities
872
+ sorted_probs = sorted(probs, reverse=True)
873
+ diffs = [sorted_probs[i] - sorted_probs[i+1] for i in range(len(sorted_probs)-1)]
874
+ idx = max(range(len(diffs)), key=lambda i: diffs[i])
875
+ # threshold is midpoint between the two
876
+ dynamic_thr = (sorted_probs[idx] + sorted_probs[idx+1]) / 2.0
877
+
878
+ print(f"\nDynamic threshold = {dynamic_thr:.4f}\n")
879
+ print("Extracted sentences:")
880
+ for s,p in zip(sentences, probs):
881
+ if p >= dynamic_thr:
882
+ print(f" β€’ {p:.4f} β†’ {s}")
883
+ print()
884
+
885
+
886
+ def batch_predict_and_save(
887
+ model,
888
+ tokenizer,
889
+ excel_path: str,
890
+ ckpt_path: str,
891
+ output_path: str,
892
+ n_samples: int = 40,
893
+ max_length: int = 128,
894
+ device: str = "cpu",
895
+ temperature: float = 1.0,
896
+ random_state: int = None
897
+ ):
898
+ """
899
+ 1) Loads best checkpoint
900
+ 2) Samples `n_samples` rows
901
+ 3) For each transcript:
902
+ - tokenize into sentences
903
+ - compute p = sigmoid(logits/temperature)
904
+ - compute elbow threshold on sorted p’s
905
+ - extract all sentences with p >= elbow
906
+ - if none, pick the highest-p sentence
907
+ 4) Save new Excel with columns:
908
+ - 'Claude_Call'
909
+ - 'Predicted Sel_K' (list of extracted sentences)
910
+ """
911
+ # load model
912
+ model.load_state_dict(torch.load(ckpt_path, map_location=device))
913
+ model.to(device).eval()
914
+
915
+ # sample rows
916
+ df = pd.read_excel(excel_path)
917
+ sampled = df.sample(n=n_samples, random_state=random_state) \
918
+ if random_state is not None else df.sample(n=n_samples)
919
+
920
+ records = []
921
+ for _, row in tqdm(sampled.iterrows(),
922
+ total=len(sampled),
923
+ desc="Running Predictions"):
924
+ transcript = str(row['Claude_Call'])
925
+ sentences = sent_tokenize(transcript)
926
+
927
+ # compute probabilities
928
+ probs = []
929
+ for sent in sentences:
930
+ enc = tokenizer.encode_plus(
931
+ sent,
932
+ max_length=max_length,
933
+ padding='max_length',
934
+ truncation=True,
935
+ return_tensors='pt'
936
+ )
937
+ with torch.no_grad():
938
+ logits = model(enc['input_ids'].to(device),
939
+ enc['attention_mask'].to(device))
940
+ p = torch.sigmoid(logits / temperature).item()
941
+ probs.append(p)
942
+
943
+ # dynamic threshold via elbow detection
944
+ if len(probs) >= 2 and max(probs) - min(probs) > 1e-3:
945
+ sp = sorted(probs, reverse=True)
946
+ diffs = [sp[i] - sp[i+1] for i in range(len(sp)-1)]
947
+ idx = max(range(len(diffs)), key=lambda i: diffs[i])
948
+ thr = (sp[idx] + sp[idx+1]) / 2.0
949
+ else:
950
+ thr = 0.5 # fallback
951
+
952
+ # collect all above threshold, else top-1
953
+ extracted = [s for s,p in zip(sentences, probs) if p >= thr]
954
+ if not extracted and sentences:
955
+ best_idx = int(max(range(len(probs)), key=lambda i: probs[i]))
956
+ extracted = [sentences[best_idx]]
957
+
958
+ records.append({
959
+ 'Claude_Call': transcript,
960
+ 'Predicted Sel_K': extracted
961
+ })
962
+
963
+ # save
964
+ out_df = pd.DataFrame(records)
965
+ os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
966
+ out_df.to_excel(output_path, index=False)
967
+ print(f"➑️ Saved {len(out_df)} rows to {output_path}")