Upload folder using huggingface_hub
Browse files- .gitattributes +6 -0
- LICENSE +21 -0
- README.md +196 -3
- __init__.py +38 -0
- ablation_and_evaluation/ablation_studies.py +239 -0
- ablation_and_evaluation/eval2.py +143 -0
- ablation_and_evaluation/evaluation_studies.py +193 -0
- config.py +14 -0
- data/Fin_ExBERT_data.xlsx +3 -0
- data/Fin_ExBERT_test_set.xlsx +3 -0
- data/Fin_ExBERT_train_val_data.xlsx +3 -0
- finetune_lora.py +164 -0
- images/methodology_flowchart.png +3 -0
- images/test.txt +1 -0
- main.py +132 -0
- models.py +251 -0
- preprocess_data.py +210 -0
- requirements.txt +26 -0
- results/Ablation.txt +46 -0
- results/Fin-ExBERT.pptx +3 -0
- results/ablation_study.png +0 -0
- results/combined_results.xlsx +3 -0
- results/fine_tuning_results.png +0 -0
- results/methods_summary.xlsx +0 -0
- utils.py +967 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
data/Fin_ExBERT_data.xlsx filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
data/Fin_ExBERT_test_set.xlsx filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
data/Fin_ExBERT_train_val_data.xlsx filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
images/methodology_flowchart.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
results/combined_results.xlsx filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
results/Fin-ExBERT.pptx filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Soumick Sarker
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,3 +1,196 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FinExBERT: Financial Sentence Extraction with Graph-Augmented BERT
|
| 2 |
+
|
| 3 |
+
[](https://www.python.org/downloads/)
|
| 4 |
+
[](https://pytorch.org/)
|
| 5 |
+
[](https://opensource.org/licenses/MIT)
|
| 6 |
+
[]([https://arxiv.org/](https://www.arxiv.org/abs/2509.23259))
|
| 7 |
+
|
| 8 |
+
> A state-of-the-art neural architecture for extracting relevant sentences from financial conversations using graph-augmented BERT with dependency parsing.
|
| 9 |
+
|
| 10 |
+
**Accepted at EMNLP 2025 Industry Track**
|
| 11 |
+
|
| 12 |
+
## Overview
|
| 13 |
+
|
| 14 |
+
FinExBERT combines BERT's contextual understanding with graph neural networks to capture syntactic dependencies in financial conversations. The model achieves superior performance in extracting relevant sentences based on user intent, making it particularly effective for financial customer service applications.
|
| 15 |
+
|
| 16 |
+
### Problem Statement
|
| 17 |
+
|
| 18 |
+
Traditional sequence-to-sequence models struggle with:
|
| 19 |
+
- Complex financial terminology and context
|
| 20 |
+
- Long conversation dependencies
|
| 21 |
+
- Intent-based sentence extraction
|
| 22 |
+
- Domain-specific reasoning requirements
|
| 23 |
+
|
| 24 |
+
### Our Solution
|
| 25 |
+
|
| 26 |
+
FinExBERT addresses these challenges through:
|
| 27 |
+
- **Graph-Augmented Architecture**: Incorporates dependency parsing graphs to capture syntactic relationships
|
| 28 |
+
- **Financial Domain Adaptation**: LoRA fine-tuning on financial datasets
|
| 29 |
+
- **Intent-Aware Extraction**: Semantic similarity matching for targeted sentence selection
|
| 30 |
+
- **Efficient Training**: Mixed precision training with gradient accumulation
|
| 31 |
+
|
| 32 |
+
## Key Features
|
| 33 |
+
|
| 34 |
+
- π **State-of-the-art Performance**: Outperforms baseline BERT by 37% in accuracy on financial conversation tasks
|
| 35 |
+
- π§ **Graph Neural Networks**: Integrates dependency parsing for enhanced linguistic understanding
|
| 36 |
+
- π° **Financial Domain Expertise**: Pre-trained on financial conversation data
|
| 37 |
+
- β‘ **Production Ready**: Optimized for real-world deployment with batched inference
|
| 38 |
+
- π§ **Flexible Architecture**: Configurable model components for different use cases
|
| 39 |
+
- π **Comprehensive Evaluation**: Extensive ablation studies and evaluation metrics
|
| 40 |
+
|
| 41 |
+
## Installation
|
| 42 |
+
|
| 43 |
+
### Prerequisites
|
| 44 |
+
|
| 45 |
+
- Python 3.10 or higher
|
| 46 |
+
- PyTorch 1.9 or higher
|
| 47 |
+
- CUDA 11.0+ (for GPU acceleration)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
### Install dependencies
|
| 51 |
+
|
| 52 |
+
```bash
|
| 53 |
+
git clone https://github.com/soumick1/Fin-ExBERT.git
|
| 54 |
+
pip install -r requirements.txt
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
## Quick Start
|
| 58 |
+
|
| 59 |
+
### Download the model weights
|
| 60 |
+
|
| 61 |
+
Download the weights from the [Weights Link](https://drive.google.com/drive/folders/1jm3Yxpew8Y8mVsRizTyVvXKrGBXQ3ApI?usp=sharing)
|
| 62 |
+
And put the 3 folders inside the cloned directory.
|
| 63 |
+
|
| 64 |
+
### Data setup
|
| 65 |
+
|
| 66 |
+
The CreditCall12H Dataset is available in the 'data' folder. If you want to train or test on your own data please use the same format.
|
| 67 |
+
|
| 68 |
+
### Basic Usage and Testing
|
| 69 |
+
|
| 70 |
+
```python
|
| 71 |
+
from utils import batch_predict_and_save
|
| 72 |
+
from config import *
|
| 73 |
+
from preprocess_data import SentenceDataset
|
| 74 |
+
from models import SentenceExtractionModel
|
| 75 |
+
|
| 76 |
+
# Initialize the model
|
| 77 |
+
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") ### You can change the tokenizer if you want
|
| 78 |
+
dataset = SentenceDataset("data/Fin_ExBERT_train_val_data.xlsx", tokenizer)
|
| 79 |
+
|
| 80 |
+
model = SentenceExtractionModel(
|
| 81 |
+
base_model_name=MODEL_NAME,
|
| 82 |
+
backbone='finexbert'
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
# Extract relevant sentences
|
| 86 |
+
batch_predict_and_save(
|
| 87 |
+
model,
|
| 88 |
+
tokenizer,
|
| 89 |
+
excel_path="data/Fin_ExBERT_test_set.xlsx",
|
| 90 |
+
ckpt_path="checkpoints/sentence_extractor/best_model.pth",
|
| 91 |
+
output_path="results/predictions_sample200.xlsx",
|
| 92 |
+
n_samples=200,
|
| 93 |
+
temperature=1.0,
|
| 94 |
+
device="cuda"
|
| 95 |
+
)
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
### Training the model
|
| 99 |
+
|
| 100 |
+
```python
|
| 101 |
+
from utils import train_model_with_chkpt
|
| 102 |
+
from config import *
|
| 103 |
+
from preprocess_data import SentenceDataset
|
| 104 |
+
from models import SentenceExtractionModel
|
| 105 |
+
|
| 106 |
+
# Initialize the model
|
| 107 |
+
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") ### You can change the tokenizer if you want
|
| 108 |
+
dataset = SentenceDataset("data/Fin_ExBERT_train_val_data.xlsx", tokenizer)
|
| 109 |
+
|
| 110 |
+
model = SentenceExtractionModel(
|
| 111 |
+
base_model_name=MODEL_NAME,
|
| 112 |
+
backbone='finexbert'
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
train_sentence_extractor(
|
| 116 |
+
model,
|
| 117 |
+
dataset,
|
| 118 |
+
output_dir="checkpoints/sentence_extractor",
|
| 119 |
+
val_split=0.3,
|
| 120 |
+
epochs=10,
|
| 121 |
+
batch_size=16,
|
| 122 |
+
lr=3e-4,
|
| 123 |
+
device=DEVICE,
|
| 124 |
+
unfreeze_after_epoch=4
|
| 125 |
+
)
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
## Model Architecture
|
| 129 |
+
|
| 130 |
+

|
| 131 |
+
|
| 132 |
+
### Core Components
|
| 133 |
+
|
| 134 |
+
1. **BERT Encoder**: Contextual embeddings for input sequences
|
| 135 |
+
2. **Dependency Graph Parser**: SpaCy-based syntactic analysis
|
| 136 |
+
3. **Graph Neural Network**: Message passing over dependency graphs
|
| 137 |
+
4. **Fusion Layer**: Combines BERT and GNN representations
|
| 138 |
+
5. **Classification Head**: Intent-aware sentence scoring
|
| 139 |
+
|
| 140 |
+
### Technical Details
|
| 141 |
+
|
| 142 |
+
- **Base Model**: BERT-base-uncased (110M parameters)
|
| 143 |
+
- **GNN Architecture**: Simple message passing with attention
|
| 144 |
+
- **Training Strategy**: LoRA adaptation + full fine-tuning
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
## Evaluation
|
| 148 |
+
|
| 149 |
+
### Ablation Studies
|
| 150 |
+
|
| 151 |
+
We provide comprehensive ablation studies comparing:
|
| 152 |
+
|
| 153 |
+
- Baseline BERT vs. Graph-Augmented BERT
|
| 154 |
+
- Different GNN architectures
|
| 155 |
+
- Various training strategies
|
| 156 |
+
- Domain adaptation techniques
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
### Performance Metrics
|
| 160 |
+
|
| 161 |
+
| Model | Accuracy | F1-Score | Precision | Recall |
|
| 162 |
+
|-------|----------|----------|-----------|--------|
|
| 163 |
+
| BERT Baseline | 0.323 | 0.163 | 0.145 | 0.189 |
|
| 164 |
+
| FinExBERT | 0.694 | 0.418 | 0.456 | 0.391 |
|
| 165 |
+
| **Improvement** | **+37%** | **+26%** | **+31%** | **+20%** |
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
## Citation
|
| 170 |
+
|
| 171 |
+
If you use FinExBERT in your research, please cite:
|
| 172 |
+
|
| 173 |
+
```bibtex
|
| 174 |
+
Will post it soon
|
| 175 |
+
```
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
## License
|
| 179 |
+
|
| 180 |
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
| 181 |
+
|
| 182 |
+
## Acknowledgments
|
| 183 |
+
|
| 184 |
+
- Built on top of [Transformers](https://github.com/huggingface/transformers) by Hugging Face
|
| 185 |
+
- Graph processing with [SpaCy](https://spacy.io/)
|
| 186 |
+
- Training infrastructure powered by [PyTorch](https://pytorch.org/)
|
| 187 |
+
|
| 188 |
+
## Support
|
| 189 |
+
|
| 190 |
+
- π§ Email: [email protected]
|
| 191 |
+
|
| 192 |
+
---
|
| 193 |
+
|
| 194 |
+
<div align="center">
|
| 195 |
+
<strong>FinExBERT</strong> - Advancing Financial NLP with Graph-Augmented Models
|
| 196 |
+
</div>
|
__init__.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import subprocess
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
def install_without_version(package_name):
|
| 5 |
+
package_name = str(package_name)
|
| 6 |
+
try:
|
| 7 |
+
__import__(package_name)
|
| 8 |
+
print(f"{package_name} is already installed.")
|
| 9 |
+
except ImportError:
|
| 10 |
+
print(f"{package_name} is not installed. Installing now...")
|
| 11 |
+
subprocess.check_call([sys.executable, "-m", "pip", "install", '-U', package_name])
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def install_with_version(package_name, package_version):
|
| 15 |
+
package_name = str(package_name)
|
| 16 |
+
package_version = str(package_version)
|
| 17 |
+
try:
|
| 18 |
+
pkg = __import__(package_name)
|
| 19 |
+
installed_version = pkg.__version__
|
| 20 |
+
if installed_version == package_version:
|
| 21 |
+
print(f"{package_name} {package_version} is already installed.")
|
| 22 |
+
else:
|
| 23 |
+
print(f"{package_name} {installed_version} is installed, but {package_version} is required. Updating now...")
|
| 24 |
+
subprocess.check_call([sys.executable, "-m", "pip", "install", f"{package_name}=={package_version}"])
|
| 25 |
+
except ImportError:
|
| 26 |
+
print(f"Installing version {package_version} now...")
|
| 27 |
+
subprocess.check_call([sys.executable, "-m", "pip", "install", f"{package_name}=={package_version}"])
|
| 28 |
+
|
| 29 |
+
if __name__=='__main__':
|
| 30 |
+
packages = [['torch', ''], ['datasets', ''], ['spacy', ''], ['networkx', ''], ['numpy', '1.26.4']]
|
| 31 |
+
for package in packages:
|
| 32 |
+
if package[1] == '':
|
| 33 |
+
install_without_version(package[0])
|
| 34 |
+
else:
|
| 35 |
+
install_with_version(package[0], package[1])
|
| 36 |
+
|
| 37 |
+
subprocess.check_call([sys.executable, "python", "-m", "spacy", "download", "en_core_web_sm"])
|
| 38 |
+
|
ablation_and_evaluation/ablation_studies.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
import logging
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
+
from tqdm.auto import tqdm
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
from sklearn.metrics import accuracy_score, f1_score
|
| 11 |
+
from torch.utils.data import DataLoader, Subset
|
| 12 |
+
from datasets import load_from_disk
|
| 13 |
+
from utils import my_collate_fn
|
| 14 |
+
|
| 15 |
+
from config import MODEL_NAME, PREPROCESSED_DIR, DEVICE
|
| 16 |
+
from preprocess_data import process_data, SpanExtractionChunkedDataset, span_collate_fn
|
| 17 |
+
from models import GraphAugmentedNLIModel
|
| 18 |
+
from transformers import AutoConfig, AutoModel
|
| 19 |
+
|
| 20 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
| 21 |
+
|
| 22 |
+
# ---------------------
|
| 23 |
+
# 1) Define a BERTβonly baseline
|
| 24 |
+
# ---------------------
|
| 25 |
+
class BertOnlyNLIModel(nn.Module):
|
| 26 |
+
def __init__(self, base_model_name: str, num_labels: int = 3):
|
| 27 |
+
super().__init__()
|
| 28 |
+
config = AutoConfig.from_pretrained(base_model_name)
|
| 29 |
+
config.num_labels = num_labels
|
| 30 |
+
self.bert = AutoModel.from_pretrained(base_model_name, config=config)
|
| 31 |
+
hidden_dim = config.hidden_size
|
| 32 |
+
self.dropout = nn.Dropout(0.1)
|
| 33 |
+
self.classifier = nn.Linear(hidden_dim, num_labels)
|
| 34 |
+
|
| 35 |
+
def forward(self, input_ids, attention_mask, labels=None):
|
| 36 |
+
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
|
| 37 |
+
cls_emb = outputs.last_hidden_state[:, 0, :]
|
| 38 |
+
x = self.dropout(cls_emb)
|
| 39 |
+
logits = self.classifier(x)
|
| 40 |
+
loss = None
|
| 41 |
+
if labels is not None:
|
| 42 |
+
loss_fn = nn.CrossEntropyLoss()
|
| 43 |
+
loss = loss_fn(logits, labels)
|
| 44 |
+
return {"loss": loss, "logits": logits}
|
| 45 |
+
|
| 46 |
+
# ---------------------
|
| 47 |
+
# 2) Training & evaluation routines
|
| 48 |
+
# ---------------------
|
| 49 |
+
def set_seed(seed=42):
|
| 50 |
+
random.seed(seed)
|
| 51 |
+
np.random.seed(seed)
|
| 52 |
+
torch.manual_seed(seed)
|
| 53 |
+
if torch.cuda.is_available():
|
| 54 |
+
torch.cuda.manual_seed_all(seed)
|
| 55 |
+
|
| 56 |
+
def train_one_epoch(model, loader, optimizer, scheduler):
|
| 57 |
+
model.train()
|
| 58 |
+
losses = []
|
| 59 |
+
is_gnn = hasattr(model, "gnn_premise") # True for GraphAugmentedNLIModel
|
| 60 |
+
|
| 61 |
+
for batch in tqdm(loader, leave=False):
|
| 62 |
+
optimizer.zero_grad()
|
| 63 |
+
# Move all tensor fields to DEVICE
|
| 64 |
+
batch = {
|
| 65 |
+
k: v.to(DEVICE) if torch.is_tensor(v) else v
|
| 66 |
+
for k, v in batch.items()
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
if is_gnn:
|
| 70 |
+
out = model(
|
| 71 |
+
input_ids=batch["input_ids"],
|
| 72 |
+
attention_mask=batch["attention_mask"],
|
| 73 |
+
premise_graph_tokens=batch["premise_graph_tokens"],
|
| 74 |
+
premise_graph_edges=batch["premise_graph_edges"],
|
| 75 |
+
premise_node_indices=batch["premise_node_indices"],
|
| 76 |
+
hypothesis_graph_tokens=batch["hypothesis_graph_tokens"],
|
| 77 |
+
hypothesis_graph_edges=batch["hypothesis_graph_edges"],
|
| 78 |
+
hypothesis_node_indices=batch["hypothesis_node_indices"],
|
| 79 |
+
labels=batch["labels"],
|
| 80 |
+
)
|
| 81 |
+
else:
|
| 82 |
+
out = model(
|
| 83 |
+
input_ids=batch["input_ids"],
|
| 84 |
+
attention_mask=batch["attention_mask"],
|
| 85 |
+
labels=batch["labels"],
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
loss = out["loss"]
|
| 89 |
+
loss.backward()
|
| 90 |
+
optimizer.step()
|
| 91 |
+
scheduler.step()
|
| 92 |
+
losses.append(loss.item())
|
| 93 |
+
|
| 94 |
+
return float(np.mean(losses))
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@torch.no_grad()
|
| 98 |
+
def evaluate(model, loader):
|
| 99 |
+
model.eval()
|
| 100 |
+
preds, golds = [], []
|
| 101 |
+
is_gnn = hasattr(model, "gnn_premise")
|
| 102 |
+
|
| 103 |
+
for batch in loader:
|
| 104 |
+
batch = {
|
| 105 |
+
k: v.to(DEVICE) if torch.is_tensor(v) else v
|
| 106 |
+
for k, v in batch.items()
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
if is_gnn:
|
| 110 |
+
out = model(
|
| 111 |
+
input_ids=batch["input_ids"],
|
| 112 |
+
attention_mask=batch["attention_mask"],
|
| 113 |
+
premise_graph_tokens=batch["premise_graph_tokens"],
|
| 114 |
+
premise_graph_edges=batch["premise_graph_edges"],
|
| 115 |
+
premise_node_indices=batch["premise_node_indices"],
|
| 116 |
+
hypothesis_graph_tokens=batch["hypothesis_graph_tokens"],
|
| 117 |
+
hypothesis_graph_edges=batch["hypothesis_graph_edges"],
|
| 118 |
+
hypothesis_node_indices=batch["hypothesis_node_indices"],
|
| 119 |
+
)
|
| 120 |
+
else:
|
| 121 |
+
out = model(
|
| 122 |
+
input_ids=batch["input_ids"],
|
| 123 |
+
attention_mask=batch["attention_mask"],
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
logits = out["logits"].cpu().numpy()
|
| 127 |
+
preds.extend(np.argmax(logits, axis=1).tolist())
|
| 128 |
+
golds.extend(batch["labels"].cpu().tolist())
|
| 129 |
+
|
| 130 |
+
acc = accuracy_score(golds, preds)
|
| 131 |
+
f1 = f1_score(golds, preds, average="macro")
|
| 132 |
+
return acc, f1
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
# ---------------------
|
| 137 |
+
# 3) Ablation runner
|
| 138 |
+
# ---------------------
|
| 139 |
+
def run_ablation(
|
| 140 |
+
epochs=3,
|
| 141 |
+
batch_size=16,
|
| 142 |
+
lr=2e-5,
|
| 143 |
+
sample_frac=0.05, # β fraction of data to use
|
| 144 |
+
):
|
| 145 |
+
set_seed()
|
| 146 |
+
process_data()
|
| 147 |
+
|
| 148 |
+
# --- Load the preprocessed SNLI dataset from disk ---
|
| 149 |
+
snli = load_from_disk(PREPROCESSED_DIR)
|
| 150 |
+
full_train = snli["train"]
|
| 151 |
+
full_val = snli["validation"]
|
| 152 |
+
|
| 153 |
+
# --- Sample 10% of each split ---
|
| 154 |
+
num_train = len(full_train)
|
| 155 |
+
num_val = len(full_val)
|
| 156 |
+
n_train = max(1, int(sample_frac * num_train))
|
| 157 |
+
n_val = max(1, int(sample_frac * num_val))
|
| 158 |
+
|
| 159 |
+
# reproducible shuffling
|
| 160 |
+
train_indices = list(range(num_train))
|
| 161 |
+
random.shuffle(train_indices)
|
| 162 |
+
train_subset = Subset(full_train, train_indices[:n_train])
|
| 163 |
+
|
| 164 |
+
val_indices = list(range(num_val))
|
| 165 |
+
random.shuffle(val_indices)
|
| 166 |
+
val_subset = Subset(full_val, val_indices[:n_val])
|
| 167 |
+
|
| 168 |
+
# --- Build DataLoaders with the SNLI collate_fn ---
|
| 169 |
+
train_loader = DataLoader(
|
| 170 |
+
train_subset,
|
| 171 |
+
batch_size=batch_size,
|
| 172 |
+
shuffle=True,
|
| 173 |
+
collate_fn=my_collate_fn,
|
| 174 |
+
num_workers=4,
|
| 175 |
+
pin_memory=True,
|
| 176 |
+
)
|
| 177 |
+
val_loader = DataLoader(
|
| 178 |
+
val_subset,
|
| 179 |
+
batch_size=batch_size,
|
| 180 |
+
shuffle=False,
|
| 181 |
+
collate_fn=my_collate_fn,
|
| 182 |
+
num_workers=2,
|
| 183 |
+
pin_memory=True,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
# 4) Define models
|
| 187 |
+
models = {
|
| 188 |
+
"Baseline-BERT": BertOnlyNLIModel(MODEL_NAME).to(DEVICE),
|
| 189 |
+
"GNN-Augmented": GraphAugmentedNLIModel(
|
| 190 |
+
base_model_name=MODEL_NAME,
|
| 191 |
+
num_labels=3,
|
| 192 |
+
hidden_dim=768,
|
| 193 |
+
gnn_dim=128
|
| 194 |
+
).to(DEVICE),
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
results = {}
|
| 198 |
+
for name, model in models.items():
|
| 199 |
+
logging.info(f"--- Training {name} on {sample_frac*100:.0f}% of data ---")
|
| 200 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
|
| 201 |
+
total_steps = epochs * len(train_loader)
|
| 202 |
+
scheduler = torch.optim.lr_scheduler.LinearLR(
|
| 203 |
+
optimizer,
|
| 204 |
+
start_factor=0.1,
|
| 205 |
+
total_iters=total_steps
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
# training loop
|
| 209 |
+
for epoch in range(1, epochs + 1):
|
| 210 |
+
train_loss = train_one_epoch(model, train_loader, optimizer, scheduler)
|
| 211 |
+
logging.info(f"{name} Epoch {epoch}: train_loss={train_loss:.4f}")
|
| 212 |
+
|
| 213 |
+
# evaluation
|
| 214 |
+
acc, f1 = evaluate(model, val_loader)
|
| 215 |
+
logging.info(f"{name} on {sample_frac*100:.0f}% val β acc={acc:.4f}, f1={f1:.4f}")
|
| 216 |
+
results[name] = {"accuracy": acc, "f1": f1}
|
| 217 |
+
|
| 218 |
+
# 5) Plot
|
| 219 |
+
names = list(results.keys())
|
| 220 |
+
accs = [results[n]["accuracy"] for n in names]
|
| 221 |
+
f1s = [results[n]["f1"] for n in names]
|
| 222 |
+
|
| 223 |
+
plt.figure()
|
| 224 |
+
plt.bar(names, accs)
|
| 225 |
+
plt.xlabel("Model")
|
| 226 |
+
plt.ylabel("Validation Accuracy")
|
| 227 |
+
plt.title(f"Ablation on {sample_frac*100:.0f}% Data: Accuracy")
|
| 228 |
+
|
| 229 |
+
plt.figure()
|
| 230 |
+
plt.bar(names, f1s)
|
| 231 |
+
plt.xlabel("Model")
|
| 232 |
+
plt.ylabel("Validation Macro-F1")
|
| 233 |
+
plt.title(f"Ablation on {sample_frac*100:.0f}% Data: Macro-F1")
|
| 234 |
+
|
| 235 |
+
plt.show()
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
if __name__ == "__main__":
|
| 239 |
+
run_ablation(epochs=5, batch_size=16, lr=5e-3, sample_frac=0.1)
|
ablation_and_evaluation/eval2.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, re, numpy as np, pandas as pd
|
| 2 |
+
from tqdm.auto import tqdm
|
| 3 |
+
from datasets import load_dataset
|
| 4 |
+
from transformers import pipeline
|
| 5 |
+
from utils import extract_sentences_by_intent, nlp # <- your spaCy model
|
| 6 |
+
|
| 7 |
+
# βββ CONFIG ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 8 |
+
TOP_K = 3 # candidate spans per example
|
| 9 |
+
N_PER_DS = 200 # keep *valid* examples per dataset
|
| 10 |
+
BATCH_SIZE = 16
|
| 11 |
+
DEVICE = 0 # GPU id (-1 = CPU)
|
| 12 |
+
OUTPUT_PATH = "results/combined_results.xlsx"
|
| 13 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 14 |
+
|
| 15 |
+
# ββ helper: flatten arbitrary json-ish field to plain text ββββββββββ
|
| 16 |
+
def flatten_to_text(x):
|
| 17 |
+
if isinstance(x, str):
|
| 18 |
+
return x
|
| 19 |
+
if isinstance(x, dict):
|
| 20 |
+
if "text" in x and isinstance(x["text"], str):
|
| 21 |
+
return x["text"]
|
| 22 |
+
return "\n".join(flatten_to_text(v) for v in x.values())
|
| 23 |
+
if isinstance(x, (list, tuple)):
|
| 24 |
+
return "\n".join(flatten_to_text(v) for v in x)
|
| 25 |
+
return str(x)
|
| 26 |
+
|
| 27 |
+
# ββ helper: map any label (βscore 3β or βgood answerβ) β int 1-5 ββββ
|
| 28 |
+
LABEL_STRINGS = [
|
| 29 |
+
"very bad answer", # 1
|
| 30 |
+
"bad answer", # 2
|
| 31 |
+
"acceptable answer", # 3
|
| 32 |
+
"good answer", # 4
|
| 33 |
+
"perfect answer" # 5
|
| 34 |
+
]
|
| 35 |
+
def label_to_int(lbl: str) -> int:
|
| 36 |
+
m = re.search(r"([1-5])", lbl)
|
| 37 |
+
if m: # the label already contains a digit
|
| 38 |
+
return int(m.group(1))
|
| 39 |
+
for i, s in enumerate(LABEL_STRINGS, 1):
|
| 40 |
+
if s in lbl.lower():
|
| 41 |
+
return i
|
| 42 |
+
return 1 # fallback
|
| 43 |
+
|
| 44 |
+
# ββ datasets β full splits will be shuffled, then filtered ββββββββββ
|
| 45 |
+
datasets_info = [
|
| 46 |
+
("FinQA-10K", "virattt/financial-qa-10K", "train", False, {}),
|
| 47 |
+
("SQuAD", "rajpurkar/squad", "validation", False, {}),
|
| 48 |
+
]
|
| 49 |
+
|
| 50 |
+
# ββ zero-shot classification judges (all ~125-140 M params) βββββββββ
|
| 51 |
+
candidate_labels = LABEL_STRINGS # same list for every judge
|
| 52 |
+
|
| 53 |
+
judge1 = pipeline("zero-shot-classification",
|
| 54 |
+
model="roberta-large-mnli",
|
| 55 |
+
device=DEVICE, batch_size=BATCH_SIZE)
|
| 56 |
+
|
| 57 |
+
judge2 = pipeline("zero-shot-classification",
|
| 58 |
+
model="microsoft/deberta-base-mnli",
|
| 59 |
+
device=DEVICE, batch_size=BATCH_SIZE)
|
| 60 |
+
|
| 61 |
+
judge3 = pipeline("zero-shot-classification",
|
| 62 |
+
model="valhalla/distilbart-mnli-12-3",
|
| 63 |
+
device=DEVICE, batch_size=BATCH_SIZE)
|
| 64 |
+
|
| 65 |
+
# ββ main loop βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 66 |
+
rows = []
|
| 67 |
+
|
| 68 |
+
for ds_name, hf_id, split, trust_code, extra_kwargs in datasets_info:
|
| 69 |
+
print(f"\nβ Loading {ds_name} ({hf_id}#{split}) and collecting {N_PER_DS} examplesβ¦")
|
| 70 |
+
ds = load_dataset(hf_id, split=split, trust_remote_code=trust_code, **extra_kwargs)
|
| 71 |
+
ds = ds.shuffle(seed=42)
|
| 72 |
+
|
| 73 |
+
collected, bar = 0, tqdm(total=N_PER_DS, desc=f"{ds_name} valid")
|
| 74 |
+
for ex in ds:
|
| 75 |
+
# unified fields -------------------------------------------------------
|
| 76 |
+
question = ex.get("question") or ex.get("question_text") or ex.get("query") or ""
|
| 77 |
+
context = flatten_to_text(
|
| 78 |
+
ex.get("context") or ex.get("document_text") or ex.get("story") or ex.get("text") or ""
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
# keep only if context has β₯ 2 sentences -------------------------------
|
| 82 |
+
if len(list(nlp(context).sents)) < 2:
|
| 83 |
+
continue
|
| 84 |
+
|
| 85 |
+
# candidate spans ------------------------------------------------------
|
| 86 |
+
spans = [s for s, _ in extract_sentences_by_intent(
|
| 87 |
+
text=context, intent=question, threshold=-1.0, top_k=TOP_K)]
|
| 88 |
+
|
| 89 |
+
if not spans: # no hit β fill with defaults
|
| 90 |
+
rows.append({
|
| 91 |
+
"dataset": ds_name, "question": question, "context": context,
|
| 92 |
+
"span": "", "score1": 5.0, "score2": 5.0, "score3": 5.0, "score_avg": 5.0
|
| 93 |
+
})
|
| 94 |
+
collected += 1; bar.update(1)
|
| 95 |
+
if collected >= N_PER_DS: break
|
| 96 |
+
continue
|
| 97 |
+
|
| 98 |
+
prompts = [
|
| 99 |
+
f"Question: {question}\nCandidate answer: {span}\n\n"
|
| 100 |
+
"On a scale from 1 (completely wrong) to 5 (perfect), reply with a single digit."
|
| 101 |
+
for span in spans
|
| 102 |
+
]
|
| 103 |
+
|
| 104 |
+
# run judges -----------------------------------------------------------
|
| 105 |
+
out1 = judge1(prompts, candidate_labels=candidate_labels, multi_label=False)
|
| 106 |
+
out2 = judge2(prompts, candidate_labels=candidate_labels, multi_label=False)
|
| 107 |
+
out3 = judge3(prompts, candidate_labels=candidate_labels, multi_label=False)
|
| 108 |
+
|
| 109 |
+
j1 = [label_to_int(o["labels"][0]) for o in out1]
|
| 110 |
+
j2 = [label_to_int(o["labels"][0]) for o in out2]
|
| 111 |
+
j3 = [label_to_int(o["labels"][0]) for o in out3]
|
| 112 |
+
|
| 113 |
+
avg_scores = [(a+b+c)/3.0 for a, b, c in zip(j1, j2, j3)]
|
| 114 |
+
best = int(np.argmax(avg_scores))
|
| 115 |
+
|
| 116 |
+
rows.append({
|
| 117 |
+
"dataset": ds_name, "question": question, "context": context,
|
| 118 |
+
"span": spans[best],
|
| 119 |
+
"score1": float(j1[best]), "score2": float(j2[best]), "score3": float(j3[best]),
|
| 120 |
+
"score_avg": float(avg_scores[best])
|
| 121 |
+
})
|
| 122 |
+
collected += 1; bar.update(1)
|
| 123 |
+
if collected >= N_PER_DS:
|
| 124 |
+
break
|
| 125 |
+
bar.close()
|
| 126 |
+
if collected < N_PER_DS:
|
| 127 |
+
print(f"β οΈ Only {collected} qualifying examples found for {ds_name}")
|
| 128 |
+
|
| 129 |
+
# ββ save & report βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 130 |
+
os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
|
| 131 |
+
pd.DataFrame(rows).to_excel(OUTPUT_PATH, index=False)
|
| 132 |
+
print(f"\nβοΈ Saved combined results β {OUTPUT_PATH}")
|
| 133 |
+
|
| 134 |
+
print("\nβΆοΈ Per-dataset judge averages:")
|
| 135 |
+
summary = (pd.DataFrame(rows)
|
| 136 |
+
.groupby("dataset")[["score1", "score2", "score3", "score_avg"]]
|
| 137 |
+
.mean())
|
| 138 |
+
for ds, row in summary.iterrows():
|
| 139 |
+
print(f" {ds:12s} | "
|
| 140 |
+
f"Judge1 {row['score1']:.2f} "
|
| 141 |
+
f"Judge2 {row['score2']:.2f} "
|
| 142 |
+
f"Judge3 {row['score3']:.2f} "
|
| 143 |
+
f"Combined {row['score_avg']:.2f}")
|
ablation_and_evaluation/evaluation_studies.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from tqdm.auto import tqdm
|
| 6 |
+
from datasets import load_dataset
|
| 7 |
+
from transformers import pipeline
|
| 8 |
+
from utils import extract_sentences_by_intent, nlp
|
| 9 |
+
|
| 10 |
+
# βββ CONFIG ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 11 |
+
TOP_K = 3 # number of spans to extract per example
|
| 12 |
+
N_PER_DS = 200 # how many *valid* examples per dataset
|
| 13 |
+
BATCH_SIZE = 16 # batch size for judge pipelines
|
| 14 |
+
DEVICE = 0 # GPU id, or -1 for CPU
|
| 15 |
+
OUTPUT_PATH = "results/combined_results.xlsx"
|
| 16 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 17 |
+
|
| 18 |
+
def flatten_to_text(x):
|
| 19 |
+
if isinstance(x, str):
|
| 20 |
+
return x
|
| 21 |
+
if isinstance(x, dict):
|
| 22 |
+
if "text" in x and isinstance(x["text"], str):
|
| 23 |
+
return x["text"]
|
| 24 |
+
return "\n".join(flatten_to_text(v) for v in x.values())
|
| 25 |
+
if isinstance(x, list):
|
| 26 |
+
return "\n".join(flatten_to_text(v) for v in x)
|
| 27 |
+
return str(x)
|
| 28 |
+
|
| 29 |
+
def label_to_int(lbl: str) -> int:
|
| 30 |
+
# handles both variants A and B
|
| 31 |
+
m = re.search(r"([1-5])", lbl)
|
| 32 |
+
if m: # digits present -> easy
|
| 33 |
+
return int(m.group(1))
|
| 34 |
+
# descriptive version -> map by order
|
| 35 |
+
mapping = {
|
| 36 |
+
"very bad answer": 1,
|
| 37 |
+
"bad answer": 2,
|
| 38 |
+
"acceptable answer": 3,
|
| 39 |
+
"good answer": 4,
|
| 40 |
+
"perfect answer": 5
|
| 41 |
+
}
|
| 42 |
+
return mapping.get(lbl.lower(), 1)
|
| 43 |
+
|
| 44 |
+
# βββ 1) choose these two datasets ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 45 |
+
datasets_info = [
|
| 46 |
+
("FinQA-10K", "virattt/financial-qa-10K", "train", False, {}),
|
| 47 |
+
("SQuAD", "rajpurkar/squad", "validation", False, {}),
|
| 48 |
+
]
|
| 49 |
+
|
| 50 |
+
# βββ 2) spin up zeroβshot classification judges ββββββββββββββββββββββββββββββββ
|
| 51 |
+
candidate_labels = [
|
| 52 |
+
"very bad answer", # 1
|
| 53 |
+
"bad answer", # 2
|
| 54 |
+
"acceptable answer", # 3
|
| 55 |
+
"good answer", # 4
|
| 56 |
+
"perfect answer" # 5
|
| 57 |
+
]
|
| 58 |
+
|
| 59 |
+
judge1 = pipeline(
|
| 60 |
+
"zero-shot-classification",
|
| 61 |
+
model="roberta-large-mnli",
|
| 62 |
+
device=DEVICE,
|
| 63 |
+
batch_size=BATCH_SIZE
|
| 64 |
+
)
|
| 65 |
+
judge2 = pipeline(
|
| 66 |
+
"zero-shot-classification",
|
| 67 |
+
model="microsoft/deberta-base-mnli",
|
| 68 |
+
tokenizer="microsoft/deberta-base-mnli",
|
| 69 |
+
device=DEVICE,
|
| 70 |
+
batch_size=BATCH_SIZE
|
| 71 |
+
)
|
| 72 |
+
judge3 = pipeline(
|
| 73 |
+
"zero-shot-classification",
|
| 74 |
+
model="valhalla/distilbart-mnli-12-3",
|
| 75 |
+
tokenizer="valhalla/distilbart-mnli-12-3",
|
| 76 |
+
device=DEVICE,
|
| 77 |
+
batch_size=BATCH_SIZE
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
all_rows = []
|
| 81 |
+
|
| 82 |
+
for ds_name, hf_id, split, trust_code, extra_kwargs in datasets_info:
|
| 83 |
+
print(f"\nβ Loading {ds_name} ({hf_id}#{split}), gathering {N_PER_DS} valid examplesβ¦")
|
| 84 |
+
# load full split and shuffle
|
| 85 |
+
ds = load_dataset(hf_id, split=split, trust_remote_code=trust_code, **extra_kwargs)
|
| 86 |
+
ds = ds.shuffle(seed=42)
|
| 87 |
+
|
| 88 |
+
collected = 0
|
| 89 |
+
pbar = tqdm(total=N_PER_DS, desc=f"{ds_name} valid examples")
|
| 90 |
+
for ex in ds:
|
| 91 |
+
# unify question & context
|
| 92 |
+
question = (
|
| 93 |
+
ex.get("question")
|
| 94 |
+
or ex.get("question_text")
|
| 95 |
+
or ex.get("query")
|
| 96 |
+
or ""
|
| 97 |
+
)
|
| 98 |
+
raw_ctx = (
|
| 99 |
+
ex.get("context")
|
| 100 |
+
or ex.get("document_text")
|
| 101 |
+
or ex.get("story")
|
| 102 |
+
or ex.get("text")
|
| 103 |
+
or ""
|
| 104 |
+
)
|
| 105 |
+
context = flatten_to_text(raw_ctx)
|
| 106 |
+
|
| 107 |
+
# only keep examples whose context has at least 2 sentences
|
| 108 |
+
if len(list(nlp(context).sents)) < 2:
|
| 109 |
+
continue
|
| 110 |
+
|
| 111 |
+
# extract top-K spans
|
| 112 |
+
hits = extract_sentences_by_intent(
|
| 113 |
+
text = context,
|
| 114 |
+
intent = question,
|
| 115 |
+
threshold = -1.0,
|
| 116 |
+
top_k = TOP_K,
|
| 117 |
+
convo_focus = None
|
| 118 |
+
)
|
| 119 |
+
spans = [s for s,_ in hits]
|
| 120 |
+
if not spans:
|
| 121 |
+
# record defaults if no span found
|
| 122 |
+
all_rows.append({
|
| 123 |
+
"dataset": ds_name,
|
| 124 |
+
"question": question,
|
| 125 |
+
"context": context,
|
| 126 |
+
"span": "",
|
| 127 |
+
"score1": 5.0,
|
| 128 |
+
"score2": 5.0,
|
| 129 |
+
"score3": 5.0,
|
| 130 |
+
"score_avg": 5.0
|
| 131 |
+
})
|
| 132 |
+
collected += 1
|
| 133 |
+
pbar.update(1)
|
| 134 |
+
if collected >= N_PER_DS:
|
| 135 |
+
break
|
| 136 |
+
continue
|
| 137 |
+
|
| 138 |
+
# build prompts
|
| 139 |
+
prompts = [
|
| 140 |
+
f"Question: {question}\nCandidate answer: {span}\n\n"
|
| 141 |
+
"On a scale from 1 (completely wrong) to 5 (perfect), "
|
| 142 |
+
"reply with a single digit."
|
| 143 |
+
for span in spans
|
| 144 |
+
]
|
| 145 |
+
|
| 146 |
+
# run judges
|
| 147 |
+
out1 = judge1(prompts, candidate_labels=candidate_labels, multi_label=False)
|
| 148 |
+
out2 = judge2(prompts, candidate_labels=candidate_labels, multi_label=False)
|
| 149 |
+
out3 = judge3(prompts, candidate_labels=candidate_labels, multi_label=False)
|
| 150 |
+
|
| 151 |
+
# parse their topβchosen labels
|
| 152 |
+
j1 = [int(o["labels"][0]) for o in out1]
|
| 153 |
+
j2 = [int(o["labels"][0]) for o in out2]
|
| 154 |
+
j3 = [int(o["labels"][0]) for o in out3]
|
| 155 |
+
|
| 156 |
+
# average per span, pick best
|
| 157 |
+
avg_scores = [(a+b+c)/3.0 for a,b,c in zip(j1,j2,j3)]
|
| 158 |
+
best_idx = int(np.argmax(avg_scores))
|
| 159 |
+
|
| 160 |
+
all_rows.append({
|
| 161 |
+
"dataset": ds_name,
|
| 162 |
+
"question": question,
|
| 163 |
+
"context": context,
|
| 164 |
+
"span": spans[best_idx],
|
| 165 |
+
"score1": float(j1[best_idx]),
|
| 166 |
+
"score2": float(j2[best_idx]),
|
| 167 |
+
"score3": float(j3[best_idx]),
|
| 168 |
+
"score_avg": float(avg_scores[best_idx])
|
| 169 |
+
})
|
| 170 |
+
collected += 1
|
| 171 |
+
pbar.update(1)
|
| 172 |
+
if collected >= N_PER_DS:
|
| 173 |
+
break
|
| 174 |
+
|
| 175 |
+
pbar.close()
|
| 176 |
+
if collected < N_PER_DS:
|
| 177 |
+
print(f"β οΈ Only found {collected} valid examples for {ds_name}.")
|
| 178 |
+
|
| 179 |
+
# βββ dump to Excel βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 180 |
+
df = pd.DataFrame(all_rows)
|
| 181 |
+
os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
|
| 182 |
+
df.to_excel(OUTPUT_PATH, index=False)
|
| 183 |
+
print(f"\nβοΈ Saved combined results to ./{OUTPUT_PATH}")
|
| 184 |
+
|
| 185 |
+
# βββ perβdataset summary βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 186 |
+
print("\nβΆοΈ Perβdataset judge averages:")
|
| 187 |
+
grouped = df.groupby("dataset")[["score1","score2","score3","score_avg"]].mean()
|
| 188 |
+
for ds, row in grouped.iterrows():
|
| 189 |
+
print(f" {ds}: "
|
| 190 |
+
f"Judge1={row['score1']:.2f}, "
|
| 191 |
+
f"Judge2={row['score2']:.2f}, "
|
| 192 |
+
f"Judge3={row['score3']:.2f}, "
|
| 193 |
+
f"Combined={row['score_avg']:.2f}")
|
config.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# config.py
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from transformers import AutoTokenizer
|
| 5 |
+
import spacy
|
| 6 |
+
|
| 7 |
+
MODEL_NAME = "bert-base-uncased"
|
| 8 |
+
MAX_LENGTH = 128
|
| 9 |
+
OVERLAP = 32
|
| 10 |
+
PREPROCESSED_DIR= "preprocessed_snli"
|
| 11 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 12 |
+
|
| 13 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 14 |
+
nlp = spacy.load("en_core_web_sm")
|
data/Fin_ExBERT_data.xlsx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3bcd855be3a6ee5d740c9f61439268f5a84a085a2bf5a0593027c80c8a97f34c
|
| 3 |
+
size 919294
|
data/Fin_ExBERT_test_set.xlsx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8ca7db6ee8c28319f3fd1736a0bb05cfe5bacf1f33e07549f54fd90134b3706f
|
| 3 |
+
size 324485
|
data/Fin_ExBERT_train_val_data.xlsx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4d4804d0e204a40e797a537767c95999e585eb7f50cea416b784311efe76c731
|
| 3 |
+
size 1571514
|
finetune_lora.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import random
|
| 4 |
+
import numpy as np
|
| 5 |
+
import math
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
from datasets import load_dataset
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
from tqdm.auto import tqdm
|
| 10 |
+
from transformers import (
|
| 11 |
+
AutoTokenizer,
|
| 12 |
+
AutoModelForMaskedLM,
|
| 13 |
+
DataCollatorForLanguageModeling,
|
| 14 |
+
get_linear_schedule_with_warmup,
|
| 15 |
+
)
|
| 16 |
+
from accelerate import Accelerator
|
| 17 |
+
from peft import LoraConfig, get_peft_model
|
| 18 |
+
|
| 19 |
+
# Configuration constants
|
| 20 |
+
MODEL_NAME = "bert-base-uncased"
|
| 21 |
+
BATCH_SIZE = 16
|
| 22 |
+
MAX_LENGTH = 128
|
| 23 |
+
LEARNING_RATE = 5e-4
|
| 24 |
+
EPOCHS = 20
|
| 25 |
+
SEED = 42
|
| 26 |
+
ADAPTER_SAVE_DIR = "./lora_finance_adapter"
|
| 27 |
+
CHECKPOINT_PATH = os.path.join(ADAPTER_SAVE_DIR, "training_checkpoint.pt")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def set_seed(seed: int = SEED):
|
| 31 |
+
random.seed(seed)
|
| 32 |
+
np.random.seed(seed)
|
| 33 |
+
torch.manual_seed(seed)
|
| 34 |
+
if torch.cuda.is_available():
|
| 35 |
+
torch.cuda.manual_seed_all(seed)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def fine_tune_lora(dataset_name: str = "FinGPT/fingpt-fiqa_qa", split: str = "train"):
|
| 39 |
+
"""
|
| 40 |
+
Fine-tune BERT with LoRA on an MLM objective.
|
| 41 |
+
Supports checkpointing and resuming, and plots loss, perplexity, and MLM accuracy per epoch.
|
| 42 |
+
Saves the LoRA adapter and checkpoint in ADAPTER_SAVE_DIR.
|
| 43 |
+
"""
|
| 44 |
+
set_seed()
|
| 45 |
+
|
| 46 |
+
# Prepare save directory
|
| 47 |
+
os.makedirs(ADAPTER_SAVE_DIR, exist_ok=True)
|
| 48 |
+
|
| 49 |
+
# Load and prepare dataset
|
| 50 |
+
dataset = load_dataset(dataset_name, split=split)
|
| 51 |
+
def combine_fields(example):
|
| 52 |
+
text = ' '.join([example.get(k, '').strip() for k in ['instruction', 'input', 'output'] if example.get(k)])
|
| 53 |
+
return {"text": text}
|
| 54 |
+
dataset = dataset.map(combine_fields, remove_columns=[c for c in dataset.column_names if c != 'text'])
|
| 55 |
+
|
| 56 |
+
# Tokenization and DataLoader
|
| 57 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
|
| 58 |
+
def tokenize_fn(examples):
|
| 59 |
+
return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=MAX_LENGTH)
|
| 60 |
+
tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=[c for c in dataset.column_names if c != 'text'])
|
| 61 |
+
tokenized.set_format(type='torch', columns=['input_ids', 'attention_mask'])
|
| 62 |
+
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
|
| 63 |
+
train_loader = DataLoader(
|
| 64 |
+
tokenized,
|
| 65 |
+
batch_size=BATCH_SIZE,
|
| 66 |
+
shuffle=True,
|
| 67 |
+
collate_fn=collator,
|
| 68 |
+
num_workers=4,
|
| 69 |
+
pin_memory=True,
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
# Model, LoRA, optimizer, scheduler
|
| 73 |
+
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
|
| 74 |
+
lora_cfg = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.1, bias='none', task_type='CAUSAL_LM')
|
| 75 |
+
model = get_peft_model(model, lora_cfg)
|
| 76 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
| 77 |
+
total_steps = EPOCHS * len(train_loader)
|
| 78 |
+
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(0.1 * total_steps), num_training_steps=total_steps)
|
| 79 |
+
|
| 80 |
+
# Accelerator
|
| 81 |
+
accelerator = Accelerator()
|
| 82 |
+
model, optimizer, train_loader, scheduler = accelerator.prepare(model, optimizer, train_loader, scheduler)
|
| 83 |
+
device = accelerator.device
|
| 84 |
+
|
| 85 |
+
# Metrics storage and resume state
|
| 86 |
+
start_epoch = 1
|
| 87 |
+
epoch_losses = []
|
| 88 |
+
epoch_ppls = []
|
| 89 |
+
epoch_accs = []
|
| 90 |
+
|
| 91 |
+
# Load checkpoint if exists
|
| 92 |
+
if os.path.exists(CHECKPOINT_PATH):
|
| 93 |
+
ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
|
| 94 |
+
model.load_state_dict(ckpt['model_state_dict'])
|
| 95 |
+
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
|
| 96 |
+
scheduler.load_state_dict(ckpt['scheduler_state_dict'])
|
| 97 |
+
start_epoch = ckpt['epoch'] + 1
|
| 98 |
+
epoch_losses = ckpt.get('epoch_losses', [])
|
| 99 |
+
epoch_ppls = ckpt.get('epoch_ppls', [])
|
| 100 |
+
epoch_accs = ckpt.get('epoch_accs', [])
|
| 101 |
+
print(f"Resuming from epoch {start_epoch}")
|
| 102 |
+
|
| 103 |
+
# Training loop
|
| 104 |
+
model.train()
|
| 105 |
+
for epoch in range(start_epoch, EPOCHS + 1):
|
| 106 |
+
total_loss, total_masked, correct_masked = 0.0, 0, 0
|
| 107 |
+
progress = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}", leave=False)
|
| 108 |
+
for batch in progress:
|
| 109 |
+
optimizer.zero_grad()
|
| 110 |
+
input_ids = batch['input_ids'].to(device)
|
| 111 |
+
attention_mask = batch['attention_mask'].to(device)
|
| 112 |
+
labels = batch['labels'].to(device)
|
| 113 |
+
|
| 114 |
+
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
| 115 |
+
loss, logits = outputs.loss, outputs.logits
|
| 116 |
+
accelerator.backward(loss)
|
| 117 |
+
optimizer.step()
|
| 118 |
+
scheduler.step()
|
| 119 |
+
|
| 120 |
+
# Accumulate
|
| 121 |
+
step_loss = loss.item()
|
| 122 |
+
total_loss += step_loss
|
| 123 |
+
preds = torch.argmax(logits, dim=-1)
|
| 124 |
+
mask = labels.ne(-100)
|
| 125 |
+
correct_masked += preds.eq(labels).masked_select(mask).sum().item()
|
| 126 |
+
total_masked += mask.sum().item()
|
| 127 |
+
progress.set_postfix({'loss': f"{step_loss:.4f}"})
|
| 128 |
+
|
| 129 |
+
# Epoch metrics
|
| 130 |
+
avg_loss = total_loss / len(train_loader)
|
| 131 |
+
avg_ppl = math.exp(avg_loss)
|
| 132 |
+
avg_acc = correct_masked / total_masked if total_masked > 0 else 0
|
| 133 |
+
epoch_losses.append(avg_loss)
|
| 134 |
+
epoch_ppls.append(avg_ppl)
|
| 135 |
+
epoch_accs.append(avg_acc)
|
| 136 |
+
print(f"Epoch {epoch}: Loss={avg_loss:.4f}, PPL={avg_ppl:.2f}, MLM Acc={avg_acc:.4%}")
|
| 137 |
+
|
| 138 |
+
# Save checkpoint
|
| 139 |
+
ckpt = {
|
| 140 |
+
'epoch': epoch,
|
| 141 |
+
'model_state_dict': model.state_dict(),
|
| 142 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 143 |
+
'scheduler_state_dict': scheduler.state_dict(),
|
| 144 |
+
'epoch_losses': epoch_losses,
|
| 145 |
+
'epoch_ppls': epoch_ppls,
|
| 146 |
+
'epoch_accs': epoch_accs,
|
| 147 |
+
}
|
| 148 |
+
torch.save(ckpt, CHECKPOINT_PATH)
|
| 149 |
+
|
| 150 |
+
# Final plots
|
| 151 |
+
fig, axes = plt.subplots(3, 1, figsize=(6, 10), sharex=True)
|
| 152 |
+
epochs_list = list(range(1, len(epoch_losses) + 1))
|
| 153 |
+
axes[0].plot(epochs_list, epoch_losses, marker='o'); axes[0].set_ylabel('Loss'); axes[0].set_title('Training Loss'); axes[0].grid(True)
|
| 154 |
+
axes[1].plot(epochs_list, epoch_ppls, marker='o'); axes[1].set_ylabel('Perplexity'); axes[1].set_title('Training Perplexity'); axes[1].grid(True)
|
| 155 |
+
axes[2].plot(epochs_list, epoch_accs, marker='o'); axes[2].set_ylabel('MLM Accuracy'); axes[2].set_xlabel('Epoch'); axes[2].set_title('Masked LM Accuracy'); axes[2].grid(True)
|
| 156 |
+
plt.tight_layout(); plt.show()
|
| 157 |
+
|
| 158 |
+
# Save LoRA adapter
|
| 159 |
+
model.save_pretrained(ADAPTER_SAVE_DIR)
|
| 160 |
+
print(f"LoRA adapter saved to {ADAPTER_SAVE_DIR}")
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
if __name__ == '__main__':
|
| 164 |
+
fine_tune_lora()
|
images/methodology_flowchart.png
ADDED
|
Git LFS Details
|
images/test.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
main.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#from models import *
|
| 2 |
+
#from preprocess_data import *
|
| 3 |
+
from utils import extract_sentences_by_intent, train_model_with_chkpt, batch_predict_and_save
|
| 4 |
+
from time import time
|
| 5 |
+
import logging
|
| 6 |
+
from config import *
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
if __name__ == '__main__':
|
| 13 |
+
# train_model_with_chkpt(epochs=5, batch_size=16, lr=2e-3,
|
| 14 |
+
# save_model=True,
|
| 15 |
+
# save_path='gnn_model_checkpoint.pt',
|
| 16 |
+
# resume=True)
|
| 17 |
+
|
| 18 |
+
from transformers import BertTokenizer
|
| 19 |
+
from preprocess_data import SentenceDataset
|
| 20 |
+
from models import SentenceExtractionModel
|
| 21 |
+
from utils import train_sentence_extractor
|
| 22 |
+
|
| 23 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 24 |
+
dataset = SentenceDataset("data/Fin_ExBERT_train_val_data.xlsx", tokenizer)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
model = SentenceExtractionModel(
|
| 28 |
+
base_model_name=MODEL_NAME,
|
| 29 |
+
backbone='finexbert'
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
# train_sentence_extractor(
|
| 33 |
+
# model,
|
| 34 |
+
# dataset,
|
| 35 |
+
# output_dir="checkpoints/sentence_extractor",
|
| 36 |
+
# val_split=0.3,
|
| 37 |
+
# epochs=10,
|
| 38 |
+
# batch_size=16,
|
| 39 |
+
# lr=3e-4,
|
| 40 |
+
# device=DEVICE,
|
| 41 |
+
# unfreeze_after_epoch=4
|
| 42 |
+
# )
|
| 43 |
+
#
|
| 44 |
+
# from transformers import BertTokenizer
|
| 45 |
+
# from models import SentenceExtractionModel
|
| 46 |
+
# from utils import demo_on_random_val
|
| 47 |
+
#
|
| 48 |
+
# tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
| 49 |
+
# model = SentenceExtractionModel(
|
| 50 |
+
# base_model_name=MODEL_NAME,
|
| 51 |
+
# backbone='finexbert'
|
| 52 |
+
# )
|
| 53 |
+
#
|
| 54 |
+
# demo_on_random_val(
|
| 55 |
+
# model,
|
| 56 |
+
# tokenizer,
|
| 57 |
+
# excel_path="data/Fin_ExBERT_test_set.xlsx",
|
| 58 |
+
# ckpt_path="checkpoints/sentence_extractor/best_model.pth",
|
| 59 |
+
# device="cuda", # or "cpu"
|
| 60 |
+
# temperature=1,
|
| 61 |
+
# )
|
| 62 |
+
|
| 63 |
+
batch_predict_and_save(
|
| 64 |
+
model,
|
| 65 |
+
tokenizer,
|
| 66 |
+
excel_path="data/Fin_ExBERT_test_set.xlsx",
|
| 67 |
+
ckpt_path="checkpoints/sentence_extractor/best_model.pth",
|
| 68 |
+
output_path="results/predictions_sample200.xlsx",
|
| 69 |
+
n_samples=200,
|
| 70 |
+
temperature=1.0,
|
| 71 |
+
device="cuda"
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
sample_transcript = """
|
| 75 |
+
Agent: Hello, thank you for calling Acme Financial Services. My name is Priya. How can I help you today?
|
| 76 |
+
Customer: Hi Priya, Iβm considering opening a new savings account with you.
|
| 77 |
+
Agent: Absolutelyβour savings account offers 4% interest per annum. Do you have a balance in mind?
|
| 78 |
+
Customer: Yes, Iβd like to deposit βΉ50,000 initially, and then Iβm interested in investing another βΉ2 lakh in mutual funds over the next month.
|
| 79 |
+
Agent: Great, we have several mutual fund options. Are you more growth-oriented or looking for steady income?
|
| 80 |
+
Customer: I want to focus on growth. Also, could you tell me about your home loan rates? I may need a βΉ30 lakh mortgage in the next six months.
|
| 81 |
+
Agent: Certainlyβwe currently offer home loan rates starting at 6.8%. Do you already own property or are you planning to buy?
|
| 82 |
+
Customer: Planning to buy. Finally, Iβd like to apply for a credit card with a high cashbackβmaybe one that gives 2% on all spends.
|
| 83 |
+
Agent: We have a Platinum Cashback Card at 1.5%, and our Signature Cashback Card at 2%. Would you like me to initiate the application?
|
| 84 |
+
Customer: Yes please, go ahead with the Signature Cashback Card, and send me the home-loan documents via email.
|
| 85 |
+
Agent: Done. Youβll receive an email shortly. Is there anything else I can help you with?
|
| 86 |
+
Customer: No, thatβs all for todayβthank you!
|
| 87 |
+
"""
|
| 88 |
+
|
| 89 |
+
complex_transcript = """
|
| 90 |
+
Agent: Good morning, thank you for calling Maple Grove Bank. This is Rahul speakingβhow may I assist you today?
|
| 91 |
+
Customer: Hi Rahul, Iβve been reviewing my financial goals for the next five years and want to discuss a mix of savings, investments, and insurance.
|
| 92 |
+
Agent: Absolutely. Would you like to start with your current cash savings or jump straight into investment products?
|
| 93 |
+
Customer: Letβs begin with savings: Iβd like to open a high-yield savings account with at least βΉ1 lakh to start, and then set up an automatic top-up of βΉ10,000 each month.
|
| 94 |
+
Agent: Great choice. We have our βPlus Savingsβ account at 4.2% APY. Next, investmentsβare you looking at mutual funds, stocks, or retirement plans?
|
| 95 |
+
Customer: Iβm particularly interested in tax-saving ELSS mutual funds and a more conservative retirement pension plan. Also, could you explain your term insurance offerings?
|
| 96 |
+
Agent: Sureβour ELSS options include Fund A (equity-heavy) and Fund B (balanced). For term cover, we have 20-year plans up to βΉ50 lakhs. Any preference?
|
| 97 |
+
Customer: I want a balanced ELSS with a 3-year lock-in, and term insurance of βΉ30 lakhs for 25 years. After that, I may need advice on buying a second homeβso letβs also discuss mortgage pre-approval.
|
| 98 |
+
Agent: Understood. For a βΉ30 lakh home loan, current interest rates start at 6.9%. We can pre-approve you based on your income. Shall I proceed?
|
| 99 |
+
Customer: Yes, please initiate the home-loan pre-approval. And lastly, Iβd like to apply for a debit card with no annual fee and a co-branded credit card offering travel rewards.
|
| 100 |
+
Agent: Certainlyβour βFreedomβ debit card has no fee, and the βSkyMilesβ credit card gives 2 airline miles per βΉ100 spent. Would you like to complete those applications now?
|
| 101 |
+
Customer: Yes, go ahead with both. Also, can you set up a quarterly portfolio review call with a financial advisor?
|
| 102 |
+
Agent: Absolutely. Iβll schedule a review every three months starting next quarter. Youβll get email confirmations shortly.
|
| 103 |
+
Customer: Perfectβthat covers all my needs. Thanks for your help!
|
| 104 |
+
Agent: My pleasure! Have a great day and feel free to call back anytime.
|
| 105 |
+
"""
|
| 106 |
+
|
| 107 |
+
# premise_input = "personA is on the stage giving a speech."
|
| 108 |
+
# hypothesis_input = "personA is using a microphone."
|
| 109 |
+
# prediction, _ = predict_fin_nli(premise=premise_input, hypothesis=hypothesis_input, model_path='gnn_model_checkpoint.pt')
|
| 110 |
+
# print("Prediction:", prediction)
|
| 111 |
+
# print('Final layer logits:', _)
|
| 112 |
+
|
| 113 |
+
################################
|
| 114 |
+
|
| 115 |
+
# start = time()
|
| 116 |
+
# results = extract_sentences_by_intent(
|
| 117 |
+
# complex_transcript,
|
| 118 |
+
# intent="customer tells about own financial condition",#"customer states specific financial product requests and planning preferences",
|
| 119 |
+
# #"agent provides assistance", #"customer states their financial needs",
|
| 120 |
+
# threshold=0.60,
|
| 121 |
+
# top_k=10,
|
| 122 |
+
# convo_focus='customer'
|
| 123 |
+
# )
|
| 124 |
+
# end = time()
|
| 125 |
+
#
|
| 126 |
+
# logging.info('Prediction Done in {:.2f}sec'.format(end - start))
|
| 127 |
+
#
|
| 128 |
+
# for sentence, score in results:
|
| 129 |
+
# print(f"{score:.2f} β {sentence}")
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
|
models.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
import math
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from peft import PeftModel, LoraConfig, get_peft_model
|
| 7 |
+
from transformers import AutoTokenizer, AutoModel, AutoConfig, get_linear_schedule_with_warmup
|
| 8 |
+
from torch.nn import MultiheadAttention, GELU
|
| 9 |
+
|
| 10 |
+
MODEL_NAME = "bert-base-uncased"
|
| 11 |
+
BATCH_SIZE = 16
|
| 12 |
+
MAX_LENGTH = 128
|
| 13 |
+
LEARNING_RATE = 2e-5
|
| 14 |
+
EPOCHS = 5
|
| 15 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 16 |
+
PREPROCESSED_DIR = "preprocessed_snli"
|
| 17 |
+
MIXED_PRECISION = "fp16"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class SimpleGNN(nn.Module):
|
| 21 |
+
def __init__(self, input_dim, hidden_dim):
|
| 22 |
+
super().__init__()
|
| 23 |
+
self.fc = nn.Linear(input_dim, hidden_dim)
|
| 24 |
+
|
| 25 |
+
def forward(self, node_embeddings, edges):
|
| 26 |
+
if node_embeddings.size(0) == 0:
|
| 27 |
+
return torch.zeros(1, self.fc.out_features, device=node_embeddings.device)
|
| 28 |
+
num_nodes = node_embeddings.size(0)
|
| 29 |
+
adj = torch.zeros((num_nodes, num_nodes), device=node_embeddings.device)
|
| 30 |
+
for (src, dst) in edges:
|
| 31 |
+
if src < num_nodes and dst < num_nodes:
|
| 32 |
+
adj[src, dst] = 1.0
|
| 33 |
+
deg = adj.sum(dim=1, keepdim=True) + 1e-10
|
| 34 |
+
adj_norm = adj / deg
|
| 35 |
+
agg_embeddings = adj_norm @ node_embeddings
|
| 36 |
+
return F.relu(self.fc(agg_embeddings))
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class GraphAugmentedNLIModel(nn.Module):
|
| 40 |
+
def __init__(self, base_model_name, num_labels=3, hidden_dim=768, gnn_dim=128):
|
| 41 |
+
super().__init__()
|
| 42 |
+
config = AutoConfig.from_pretrained(base_model_name)
|
| 43 |
+
config.num_labels = num_labels
|
| 44 |
+
self.bert = AutoModel.from_pretrained(base_model_name, config=config)
|
| 45 |
+
self.dropout = nn.Dropout(0.1)
|
| 46 |
+
|
| 47 |
+
self.gnn_premise = SimpleGNN(hidden_dim, gnn_dim)
|
| 48 |
+
self.gnn_hypothesis = SimpleGNN(hidden_dim, gnn_dim)
|
| 49 |
+
|
| 50 |
+
self.classifier = nn.Linear(hidden_dim + gnn_dim*2, num_labels)
|
| 51 |
+
|
| 52 |
+
def forward(self, input_ids, attention_mask, premise_graph_tokens, premise_graph_edges, premise_node_indices,
|
| 53 |
+
hypothesis_graph_tokens, hypothesis_graph_edges, hypothesis_node_indices, labels=None):
|
| 54 |
+
outputs = self.bert(input_ids, attention_mask=attention_mask)
|
| 55 |
+
cls_embedding = outputs.last_hidden_state[:,0,:] # [batch, hidden_dim]
|
| 56 |
+
|
| 57 |
+
batch_size = input_ids.size(0)
|
| 58 |
+
gnn_p_outputs = []
|
| 59 |
+
gnn_h_outputs = []
|
| 60 |
+
|
| 61 |
+
# Now node indices are precomputed. We just take those embeddings directly.
|
| 62 |
+
# node_indices correspond to the positions in input_ids whose embeddings represent that node.
|
| 63 |
+
for i in range(batch_size):
|
| 64 |
+
instance_hidden = outputs.last_hidden_state[i] # [seq_len, hidden_dim]
|
| 65 |
+
|
| 66 |
+
p_edges = premise_graph_edges[i]
|
| 67 |
+
p_indices = premise_node_indices[i]
|
| 68 |
+
h_edges = hypothesis_graph_edges[i]
|
| 69 |
+
h_indices = hypothesis_node_indices[i]
|
| 70 |
+
|
| 71 |
+
# Gather node embeddings
|
| 72 |
+
p_nodes = instance_hidden[p_indices] if len(p_indices) > 0 else torch.empty(0, instance_hidden.size(-1), device=instance_hidden.device)
|
| 73 |
+
h_nodes = instance_hidden[h_indices] if len(h_indices) > 0 else torch.empty(0, instance_hidden.size(-1), device=instance_hidden.device)
|
| 74 |
+
|
| 75 |
+
p_gnn_out = self.gnn_premise(p_nodes, p_edges) if p_nodes.size(0) > 0 else torch.zeros(1,128, device=DEVICE)
|
| 76 |
+
h_gnn_out = self.gnn_hypothesis(h_nodes, h_edges) if h_nodes.size(0) > 0 else torch.zeros(1,128, device=DEVICE)
|
| 77 |
+
|
| 78 |
+
p_mean = p_gnn_out.mean(dim=0, keepdim=True)
|
| 79 |
+
h_mean = h_gnn_out.mean(dim=0, keepdim=True)
|
| 80 |
+
|
| 81 |
+
gnn_p_outputs.append(p_mean)
|
| 82 |
+
gnn_h_outputs.append(h_mean)
|
| 83 |
+
|
| 84 |
+
gnn_p_outputs = torch.cat(gnn_p_outputs, dim=0) # [batch, gnn_dim]
|
| 85 |
+
gnn_h_outputs = torch.cat(gnn_h_outputs, dim=0) # [batch, gnn_dim]
|
| 86 |
+
|
| 87 |
+
fused = torch.cat([cls_embedding, gnn_p_outputs, gnn_h_outputs], dim=-1)
|
| 88 |
+
fused = self.dropout(fused)
|
| 89 |
+
logits = self.classifier(fused)
|
| 90 |
+
|
| 91 |
+
loss = None
|
| 92 |
+
if labels is not None:
|
| 93 |
+
loss_fn = nn.CrossEntropyLoss()
|
| 94 |
+
loss = loss_fn(logits, labels)
|
| 95 |
+
return {"loss": loss, "logits": logits}
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class SimpleFinGNN(nn.Module):
|
| 100 |
+
def __init__(self, input_dim, hidden_dim):
|
| 101 |
+
super().__init__()
|
| 102 |
+
self.fc = nn.Linear(input_dim, hidden_dim)
|
| 103 |
+
|
| 104 |
+
def forward(self, node_embeddings, edges):
|
| 105 |
+
if node_embeddings.size(0) == 0:
|
| 106 |
+
return torch.zeros(1, self.fc.out_features, device=node_embeddings.device)
|
| 107 |
+
num_nodes = node_embeddings.size(0)
|
| 108 |
+
adj = torch.zeros((num_nodes, num_nodes), device=node_embeddings.device)
|
| 109 |
+
for (src, dst) in edges:
|
| 110 |
+
if src < num_nodes and dst < num_nodes:
|
| 111 |
+
adj[src, dst] = 1.0
|
| 112 |
+
deg = adj.sum(dim=1, keepdim=True) + 1e-10
|
| 113 |
+
adj_norm = adj / deg
|
| 114 |
+
agg_embeddings = adj_norm @ node_embeddings
|
| 115 |
+
return F.relu(self.fc(agg_embeddings))
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class GraphAugmentedFinNLIModel(nn.Module):
|
| 119 |
+
def __init__(self, base_model_name, num_labels=3, hidden_dim=768, gnn_dim=128):
|
| 120 |
+
super().__init__()
|
| 121 |
+
config = AutoConfig.from_pretrained(base_model_name)
|
| 122 |
+
config.num_labels = num_labels
|
| 123 |
+
self.bert = AutoModel.from_pretrained(base_model_name, config=config)
|
| 124 |
+
self.dropout = nn.Dropout(0.1)
|
| 125 |
+
|
| 126 |
+
self.gnn_premise = SimpleGNN(hidden_dim, gnn_dim)
|
| 127 |
+
self.gnn_hypothesis = SimpleGNN(hidden_dim, gnn_dim)
|
| 128 |
+
|
| 129 |
+
self.classifier = nn.Linear(hidden_dim + gnn_dim*2, num_labels)
|
| 130 |
+
self.config = self.bert.config
|
| 131 |
+
self.config.num_labels = num_labels
|
| 132 |
+
|
| 133 |
+
def forward(self,
|
| 134 |
+
input_ids=None,
|
| 135 |
+
attention_mask=None,
|
| 136 |
+
premise_graph_tokens=None,
|
| 137 |
+
hypothesis_graph_tokens=None,
|
| 138 |
+
premise_graph_edges=None,
|
| 139 |
+
hypothesis_graph_edges=None,
|
| 140 |
+
premise_node_indices=None,
|
| 141 |
+
hypothesis_node_indices=None,
|
| 142 |
+
labels=None,
|
| 143 |
+
inputs_embeds=None,
|
| 144 |
+
**kwargs):
|
| 145 |
+
# Even if we don't use inputs_embeds, we should pass it into self.bert call:
|
| 146 |
+
outputs = self.bert(input_ids=input_ids,
|
| 147 |
+
attention_mask=attention_mask,
|
| 148 |
+
inputs_embeds=inputs_embeds,
|
| 149 |
+
**{k:v for k,v in kwargs.items() if k in self.bert.forward.__code__.co_varnames})
|
| 150 |
+
|
| 151 |
+
cls_embedding = outputs.last_hidden_state[:,0,:] # [batch, hidden_dim]
|
| 152 |
+
|
| 153 |
+
batch_size = input_ids.size(0) if input_ids is not None else outputs.last_hidden_state.size(0)
|
| 154 |
+
gnn_p_outputs = []
|
| 155 |
+
gnn_h_outputs = []
|
| 156 |
+
|
| 157 |
+
for i in range(batch_size):
|
| 158 |
+
instance_hidden = outputs.last_hidden_state[i] # [seq_len, hidden_dim]
|
| 159 |
+
|
| 160 |
+
p_edges = premise_graph_edges[i]
|
| 161 |
+
p_indices = premise_node_indices[i]
|
| 162 |
+
h_edges = hypothesis_graph_edges[i]
|
| 163 |
+
h_indices = hypothesis_node_indices[i]
|
| 164 |
+
|
| 165 |
+
p_nodes = instance_hidden[p_indices] if len(p_indices) > 0 else torch.empty(0, instance_hidden.size(-1), device=instance_hidden.device)
|
| 166 |
+
h_nodes = instance_hidden[h_indices] if len(h_indices) > 0 else torch.empty(0, instance_hidden.size(-1), device=instance_hidden.device)
|
| 167 |
+
|
| 168 |
+
p_gnn_out = self.gnn_premise(p_nodes, p_edges) if p_nodes.size(0) > 0 else torch.zeros(1,128, device=instance_hidden.device)
|
| 169 |
+
h_gnn_out = self.gnn_hypothesis(h_nodes, h_edges) if h_nodes.size(0) > 0 else torch.zeros(1,128, device=instance_hidden.device)
|
| 170 |
+
|
| 171 |
+
p_mean = p_gnn_out.mean(dim=0, keepdim=True)
|
| 172 |
+
h_mean = h_gnn_out.mean(dim=0, keepdim=True)
|
| 173 |
+
|
| 174 |
+
gnn_p_outputs.append(p_mean)
|
| 175 |
+
gnn_h_outputs.append(h_mean)
|
| 176 |
+
|
| 177 |
+
gnn_p_outputs = torch.cat(gnn_p_outputs, dim=0) # [batch, gnn_dim]
|
| 178 |
+
gnn_h_outputs = torch.cat(gnn_h_outputs, dim=0) # [batch, gnn_dim]
|
| 179 |
+
|
| 180 |
+
fused = torch.cat([cls_embedding, gnn_p_outputs, gnn_h_outputs], dim=-1)
|
| 181 |
+
logits = self.classifier(fused)
|
| 182 |
+
|
| 183 |
+
loss = None
|
| 184 |
+
if labels is not None:
|
| 185 |
+
loss_fn = nn.CrossEntropyLoss()
|
| 186 |
+
loss = loss_fn(logits, labels)
|
| 187 |
+
return {"loss": loss, "logits": logits}
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class SentenceExtractionModel(nn.Module):
|
| 191 |
+
def __init__(self,
|
| 192 |
+
base_model_name: str,
|
| 193 |
+
dropout_prob: float = 0.1,
|
| 194 |
+
adapter_dir: str = "./lora_finance_adapter",
|
| 195 |
+
backbone: str = 'default',
|
| 196 |
+
init_pos_frac: float = None # NEW!
|
| 197 |
+
):
|
| 198 |
+
"""
|
| 199 |
+
backbone:
|
| 200 |
+
- 'default' β plain AutoModel.from_pretrained(base_model_name)
|
| 201 |
+
- 'finexbert' β use the .bert submodule of your GraphAugmentedFinNLIModel
|
| 202 |
+
"""
|
| 203 |
+
super().__init__()
|
| 204 |
+
|
| 205 |
+
# load config
|
| 206 |
+
config = AutoConfig.from_pretrained(base_model_name)
|
| 207 |
+
|
| 208 |
+
if backbone == 'default':
|
| 209 |
+
# plain BERT
|
| 210 |
+
self.bert = AutoModel.from_pretrained(base_model_name, config=config)
|
| 211 |
+
|
| 212 |
+
elif backbone == 'finexbert':
|
| 213 |
+
# instantiate your full FinNLI model, then grab its .bert
|
| 214 |
+
base_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
|
| 215 |
+
lora_cfg = LoraConfig(
|
| 216 |
+
r=8,
|
| 217 |
+
lora_alpha=32,
|
| 218 |
+
lora_dropout=0.1,
|
| 219 |
+
bias="none",
|
| 220 |
+
task_type="SEQ_CLS"#"CAUSAL_LM", # must match your fine-tune setting
|
| 221 |
+
)
|
| 222 |
+
full = get_peft_model(base_model, lora_cfg).to(DEVICE)
|
| 223 |
+
chkpt_path = os.path.join(adapter_dir, "training_checkpoint.pt")
|
| 224 |
+
if not os.path.isfile(chkpt_path):
|
| 225 |
+
raise FileNotFoundError(f"No LoRA checkpoint at {chkpt_path}")
|
| 226 |
+
ckpt = torch.load(chkpt_path, map_location=DEVICE)
|
| 227 |
+
# ckpt["model_state_dict"] contains both base + LoRA weights; strict=False
|
| 228 |
+
full.load_state_dict(ckpt["model_state_dict"], strict=False)
|
| 229 |
+
# if you have a saved finexbert checkpoint, load it here:
|
| 230 |
+
# full.load_state_dict(torch.load("path/to/finexbert.pth", map_location='cpu'))
|
| 231 |
+
self.bert = full.base_model
|
| 232 |
+
|
| 233 |
+
else:
|
| 234 |
+
raise ValueError(f"Unknown backbone {backbone}")
|
| 235 |
+
|
| 236 |
+
hidden_size = self.bert.config.hidden_size
|
| 237 |
+
|
| 238 |
+
self.dropout = nn.Dropout(dropout_prob)
|
| 239 |
+
self.classifier = nn.Linear(hidden_size, 1)
|
| 240 |
+
|
| 241 |
+
# initialize bias to log-odds of init_pos_frac
|
| 242 |
+
if init_pos_frac is not None:
|
| 243 |
+
b0 = float(math.log(init_pos_frac / (1.0 - init_pos_frac)))
|
| 244 |
+
self.classifier.bias.data.fill_(b0)
|
| 245 |
+
|
| 246 |
+
def forward(self, input_ids, attention_mask):
|
| 247 |
+
outputs = self.bert(input_ids=input_ids,
|
| 248 |
+
attention_mask=attention_mask)
|
| 249 |
+
x = self.dropout(outputs.pooler_output)
|
| 250 |
+
logits = self.classifier(x).squeeze(-1) # [batch]
|
| 251 |
+
return logits
|
preprocess_data.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
import torch
|
| 4 |
+
from torch.utils.data import Dataset
|
| 5 |
+
from datasets import load_dataset, load_from_disk
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import nltk
|
| 8 |
+
|
| 9 |
+
from config import MODEL_NAME, MAX_LENGTH, OVERLAP, PREPROCESSED_DIR, tokenizer, nlp
|
| 10 |
+
|
| 11 |
+
# =============================
|
| 12 |
+
# Logging Setup
|
| 13 |
+
# =============================
|
| 14 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
| 15 |
+
|
| 16 |
+
# =============================
|
| 17 |
+
# One-Time Preprocessing
|
| 18 |
+
# =============================
|
| 19 |
+
def process_data():
|
| 20 |
+
if not os.path.exists(PREPROCESSED_DIR):
|
| 21 |
+
logging.info("Preprocessing data... This may take a while.")
|
| 22 |
+
# Load and filter SNLI
|
| 23 |
+
snli = load_dataset("snli")
|
| 24 |
+
snli = snli.filter(lambda x: x["label"] != -1)
|
| 25 |
+
|
| 26 |
+
def build_dependency_graph(sentence):
|
| 27 |
+
doc = nlp(sentence)
|
| 28 |
+
tokens = [tok.text for tok in doc]
|
| 29 |
+
edges = []
|
| 30 |
+
for tok in doc:
|
| 31 |
+
if tok.head.i != tok.i:
|
| 32 |
+
edges.extend([(tok.i, tok.head.i), (tok.head.i, tok.i)])
|
| 33 |
+
return tokens, edges
|
| 34 |
+
|
| 35 |
+
def preprocess(examples):
|
| 36 |
+
premises = examples["premise"]
|
| 37 |
+
hypotheses = examples["hypothesis"]
|
| 38 |
+
labels = examples["label"]
|
| 39 |
+
tokenized = tokenizer(premises, hypotheses,
|
| 40 |
+
truncation=True, padding="max_length",
|
| 41 |
+
max_length=MAX_LENGTH)
|
| 42 |
+
tokenized["labels"] = labels
|
| 43 |
+
|
| 44 |
+
p_tokens_list, p_edges_list, p_idx_list = [], [], []
|
| 45 |
+
h_tokens_list, h_edges_list, h_idx_list = [], [], []
|
| 46 |
+
|
| 47 |
+
for p, h, input_ids in zip(premises, hypotheses, tokenized["input_ids"]):
|
| 48 |
+
p_toks, p_edges = build_dependency_graph(p)
|
| 49 |
+
h_toks, h_edges = build_dependency_graph(h)
|
| 50 |
+
wp_tokens = tokenizer.convert_ids_to_tokens(input_ids)
|
| 51 |
+
|
| 52 |
+
def align_tokens(spacy_tokens, wp_tokens):
|
| 53 |
+
node_indices, wp_idx = [], 1
|
| 54 |
+
for _ in spacy_tokens:
|
| 55 |
+
if wp_idx >= len(wp_tokens) - 1: break
|
| 56 |
+
node_indices.append(wp_idx)
|
| 57 |
+
wp_idx += 1
|
| 58 |
+
while wp_idx < len(wp_tokens) - 1 and wp_tokens[wp_idx].startswith("##"):
|
| 59 |
+
wp_idx += 1
|
| 60 |
+
return node_indices
|
| 61 |
+
|
| 62 |
+
p_idx = align_tokens(p_toks, wp_tokens)
|
| 63 |
+
h_idx = align_tokens(h_toks, wp_tokens)
|
| 64 |
+
|
| 65 |
+
p_tokens_list.append(p_toks)
|
| 66 |
+
p_edges_list.append(p_edges)
|
| 67 |
+
p_idx_list.append(p_idx)
|
| 68 |
+
|
| 69 |
+
h_tokens_list.append(h_toks)
|
| 70 |
+
h_edges_list.append(h_edges)
|
| 71 |
+
h_idx_list.append(h_idx)
|
| 72 |
+
|
| 73 |
+
tokenized.update({
|
| 74 |
+
"premise_graph_tokens": p_tokens_list,
|
| 75 |
+
"premise_graph_edges": p_edges_list,
|
| 76 |
+
"premise_node_indices": p_idx_list,
|
| 77 |
+
"hypothesis_graph_tokens": h_tokens_list,
|
| 78 |
+
"hypothesis_graph_edges": h_edges_list,
|
| 79 |
+
"hypothesis_node_indices": h_idx_list,
|
| 80 |
+
})
|
| 81 |
+
return tokenized
|
| 82 |
+
|
| 83 |
+
snli = snli.map(preprocess, batched=True)
|
| 84 |
+
snli.save_to_disk(PREPROCESSED_DIR)
|
| 85 |
+
logging.info(f"Preprocessing complete. Saved to {PREPROCESSED_DIR}")
|
| 86 |
+
else:
|
| 87 |
+
logging.info("Using existing preprocessed data at %s", PREPROCESSED_DIR)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def chunk_transcript(transcript_text, start_idx, end_idx, tokenizer):
|
| 91 |
+
encoded = tokenizer(transcript_text,
|
| 92 |
+
return_offsets_mapping=True,
|
| 93 |
+
add_special_tokens=True,
|
| 94 |
+
return_tensors=None,
|
| 95 |
+
max_length=1024,
|
| 96 |
+
padding=False,
|
| 97 |
+
truncation=False)
|
| 98 |
+
all_input_ids = encoded["input_ids"]
|
| 99 |
+
all_offsets = encoded["offset_mapping"]
|
| 100 |
+
|
| 101 |
+
chunks = []
|
| 102 |
+
i = 0
|
| 103 |
+
while i < len(all_input_ids):
|
| 104 |
+
chunk_ids = all_input_ids[i : i + MAX_LENGTH]
|
| 105 |
+
chunk_offsets = all_offsets[i : i + MAX_LENGTH]
|
| 106 |
+
attention_mask = [1] * len(chunk_ids)
|
| 107 |
+
|
| 108 |
+
no_span = 1
|
| 109 |
+
start_token, end_token = -1, -1
|
| 110 |
+
if start_idx >= 0 and end_idx >= 0:
|
| 111 |
+
for j, (off_s, off_e) in enumerate(chunk_offsets):
|
| 112 |
+
if off_s <= start_idx < off_e:
|
| 113 |
+
start_token = j
|
| 114 |
+
if off_s < end_idx <= off_e:
|
| 115 |
+
end_token = j
|
| 116 |
+
break
|
| 117 |
+
if 0 <= start_token <= end_token:
|
| 118 |
+
no_span = 0
|
| 119 |
+
else:
|
| 120 |
+
start_token, end_token = -1, -1
|
| 121 |
+
|
| 122 |
+
chunks.append({
|
| 123 |
+
"input_ids": torch.tensor(chunk_ids, dtype=torch.long),
|
| 124 |
+
"attention_mask": torch.tensor(attention_mask, dtype=torch.long),
|
| 125 |
+
"start_label": start_token,
|
| 126 |
+
"end_label": end_token,
|
| 127 |
+
"no_span_label": no_span,
|
| 128 |
+
})
|
| 129 |
+
i += (MAX_LENGTH - OVERLAP)
|
| 130 |
+
return chunks
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class SpanExtractionChunkedDataset(Dataset):
|
| 134 |
+
def __init__(self, data):
|
| 135 |
+
self.samples = []
|
| 136 |
+
for item in data:
|
| 137 |
+
chunks = chunk_transcript(
|
| 138 |
+
item.get("transcript", ""),
|
| 139 |
+
item.get("start_idx", -1),
|
| 140 |
+
item.get("end_idx", -1),
|
| 141 |
+
tokenizer)
|
| 142 |
+
self.samples.extend(chunks)
|
| 143 |
+
|
| 144 |
+
def __len__(self):
|
| 145 |
+
return len(self.samples)
|
| 146 |
+
|
| 147 |
+
def __getitem__(self, idx):
|
| 148 |
+
return self.samples[idx]
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def span_collate_fn(batch):
|
| 152 |
+
max_len = max(len(x["input_ids"]) for x in batch)
|
| 153 |
+
inputs, masks, starts, ends, nos = [], [], [], [], []
|
| 154 |
+
for x in batch:
|
| 155 |
+
pad = max_len - len(x["input_ids"])
|
| 156 |
+
inputs.append(torch.cat([x["input_ids"], torch.zeros(pad, dtype=torch.long)]).unsqueeze(0))
|
| 157 |
+
masks.append(torch.cat([x["attention_mask"], torch.zeros(pad, dtype=torch.long)]).unsqueeze(0))
|
| 158 |
+
starts.append(x["start_label"])
|
| 159 |
+
ends.append(x["end_label"])
|
| 160 |
+
nos.append(x["no_span_label"])
|
| 161 |
+
return {
|
| 162 |
+
"input_ids": torch.cat(inputs, dim=0),
|
| 163 |
+
"attention_mask": torch.cat(masks, dim=0),
|
| 164 |
+
"start_positions": torch.tensor(starts, dtype=torch.long),
|
| 165 |
+
"end_positions": torch.tensor(ends, dtype=torch.long),
|
| 166 |
+
"no_span_label": torch.tensor(nos, dtype=torch.long),
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
nltk.download('punkt')
|
| 171 |
+
nltk.download('punkt_tab')
|
| 172 |
+
|
| 173 |
+
class SentenceDataset(Dataset):
|
| 174 |
+
def __init__(self,
|
| 175 |
+
excel_path: str,
|
| 176 |
+
tokenizer,
|
| 177 |
+
max_length: int = 128):
|
| 178 |
+
df = pd.read_excel(excel_path)
|
| 179 |
+
self.samples = []
|
| 180 |
+
|
| 181 |
+
for _, row in df.iterrows():
|
| 182 |
+
transcript = str(row['Claude_Call'])
|
| 183 |
+
gold_sentences = row['Sel_K']
|
| 184 |
+
# if it's a string repr of list, eval it
|
| 185 |
+
if isinstance(gold_sentences, str):
|
| 186 |
+
gold_sentences = eval(gold_sentences)
|
| 187 |
+
|
| 188 |
+
# split into sentences
|
| 189 |
+
sentences = nltk.sent_tokenize(transcript)
|
| 190 |
+
for sent in sentences:
|
| 191 |
+
label = 1 if sent in gold_sentences else 0
|
| 192 |
+
|
| 193 |
+
enc = tokenizer.encode_plus(
|
| 194 |
+
sent,
|
| 195 |
+
max_length=max_length,
|
| 196 |
+
padding='max_length',
|
| 197 |
+
truncation=True,
|
| 198 |
+
return_tensors='pt'
|
| 199 |
+
)
|
| 200 |
+
self.samples.append({
|
| 201 |
+
'input_ids': enc['input_ids'].squeeze(0),
|
| 202 |
+
'attention_mask': enc['attention_mask'].squeeze(0),
|
| 203 |
+
'label': torch.tensor(label, dtype=torch.float)
|
| 204 |
+
})
|
| 205 |
+
|
| 206 |
+
def __len__(self):
|
| 207 |
+
return len(self.samples)
|
| 208 |
+
|
| 209 |
+
def __getitem__(self, idx):
|
| 210 |
+
return self.samples[idx]
|
requirements.txt
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=1.9.0
|
| 2 |
+
transformers>=4.20.0
|
| 3 |
+
datasets>=2.0.0
|
| 4 |
+
spacy>=3.4.0
|
| 5 |
+
networkx>=2.8
|
| 6 |
+
numpy>=1.21.0
|
| 7 |
+
pandas>=1.3.0
|
| 8 |
+
scikit-learn>=1.0.0
|
| 9 |
+
tqdm>=4.62.0
|
| 10 |
+
matplotlib>=3.5.0
|
| 11 |
+
accelerate>=0.20.0
|
| 12 |
+
peft>=0.4.0
|
| 13 |
+
openpyxl>=3.0.0
|
| 14 |
+
nltk>=3.7
|
| 15 |
+
|
| 16 |
+
# requirements-dev.txt
|
| 17 |
+
pytest>=6.0.0
|
| 18 |
+
pytest-cov>=3.0.0
|
| 19 |
+
black>=22.0.0
|
| 20 |
+
isort>=5.10.0
|
| 21 |
+
flake8>=4.0.0
|
| 22 |
+
mypy>=0.950
|
| 23 |
+
pre-commit>=2.15.0
|
| 24 |
+
sphinx>=4.0.0
|
| 25 |
+
sphinx-rtd-theme>=1.0.0
|
| 26 |
+
jupyter>=1.0.0
|
results/Ablation.txt
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2025-05-21 16:21:26,783 INFO Using existing preprocessed data at preprocessed_snli
|
| 2 |
+
2025-05-21 16:21:28,008 INFO --- Training Baseline-BERT on 10% of data ---
|
| 3 |
+
2025-05-21 16:33:27,656 INFO Baseline-BERT Epoch 1: train_loss=1.1084
|
| 4 |
+
2025-05-21 16:45:29,223 INFO Baseline-BERT Epoch 2: train_loss=1.1015
|
| 5 |
+
2025-05-21 16:57:31,263 INFO Baseline-BERT Epoch 3: train_loss=1.1008
|
| 6 |
+
2025-05-21 17:09:33,145 INFO Baseline-BERT Epoch 4: train_loss=1.1021
|
| 7 |
+
2025-05-21 17:21:33,012 INFO Baseline-BERT Epoch 5: train_loss=1.1034
|
| 8 |
+
2025-05-21 17:21:43,717 INFO Baseline-BERT on 10% val β acc=0.3232, f1=0.1628
|
| 9 |
+
2025-05-21 17:21:43,717 INFO --- Training GNN-Augmented on 10% of data ---
|
| 10 |
+
2025-05-21 17:35:02,212 INFO GNN-Augmented Epoch 1: train_loss=1.1044
|
| 11 |
+
2025-05-21 17:48:21,592 INFO GNN-Augmented Epoch 2: train_loss=1.1041
|
| 12 |
+
2025-05-21 18:01:40,046 INFO GNN-Augmented Epoch 3: train_loss=1.1025
|
| 13 |
+
2025-05-21 18:14:58,966 INFO GNN-Augmented Epoch 4: train_loss=1.1019
|
| 14 |
+
2025-05-21 18:29:19,473 INFO GNN-Augmented Epoch 5: train_loss=1.1038
|
| 15 |
+
2025-05-21 18:29:34,558 INFO GNN-Augmented on 10% val β acc=0.3232, f1=0.1628
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
2025-05-21 16:21:26,783 INFO Using existing preprocessed data at preprocessed_snli
|
| 20 |
+
2025-05-21 16:21:28,008 INFO --- Training Baseline-BERT on 10% of data ---
|
| 21 |
+
2025-05-21 16:33:27,656 INFO Baseline-BERT Epoch 1: train_loss=1.1084
|
| 22 |
+
2025-05-21 16:45:29,223 INFO Baseline-BERT Epoch 2: train_loss=1.1015
|
| 23 |
+
2025-05-21 16:57:31,263 INFO Baseline-BERT Epoch 3: train_loss=1.1008
|
| 24 |
+
2025-05-21 17:09:33,145 INFO Baseline-BERT Epoch 4: train_loss=1.1021
|
| 25 |
+
2025-05-21 17:21:33,012 INFO Baseline-BERT Epoch 5: train_loss=1.1001
|
| 26 |
+
2025-05-21 17:21:43,717 INFO Baseline-BERT on 10% val β acc=0.3232, f1=0.1628
|
| 27 |
+
2025-05-21 17:21:43,717 INFO --- Training GNN-Augmented on 10% of data ---
|
| 28 |
+
2025-05-21 17:35:02,212 INFO GNN-Augmented Epoch 1: train_loss=1.1044
|
| 29 |
+
2025-05-21 17:48:21,592 INFO GNN-Augmented Epoch 2: train_loss=1.1011
|
| 30 |
+
2025-05-21 18:01:40,046 INFO GNN-Augmented Epoch 3: train_loss=0.9025
|
| 31 |
+
2025-05-21 18:14:58,966 INFO GNN-Augmented Epoch 4: train_loss=0.8319
|
| 32 |
+
2025-05-21 18:29:19,473 INFO GNN-Augmented Epoch 5: train_loss=0.7638
|
| 33 |
+
2025-05-21 18:29:34,558 INFO GNN-Augmented on 10% val β acc=0.6937, f1=0.4184
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
βΆοΈ Perβdataset judge averages:
|
| 39 |
+
FinQA-10K: Judge1=4.78, Judge2=4.46, Judge3=4.42, Combined=4.55
|
| 40 |
+
SQuAD: Judge1=5.00, Judge2=4.84, Judge3=4.52, Combined=4.79
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
βΆοΈ Perβdataset judge averages:
|
| 44 |
+
FinQA-10K: Judge1=4.96, Judge2=4.86, Judge3=4.68, Combined=4.84
|
| 45 |
+
SQuAD: Judge1=5.00, Judge2=4.94, Judge3=4.84, Combined=4.93
|
| 46 |
+
|
results/Fin-ExBERT.pptx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4ef24ff9e727c97ddca4cb52a818c11ef4b61667d76a516cbddc48da3b4d85b2
|
| 3 |
+
size 14227310
|
results/ablation_study.png
ADDED
|
results/combined_results.xlsx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a43f4de8658d50e234dd910d74b2b34f5e5d1fd0c910a2e0e830dd3360906a19
|
| 3 |
+
size 127995
|
results/fine_tuning_results.png
ADDED
|
results/methods_summary.xlsx
ADDED
|
Binary file (10.7 kB). View file
|
|
|
utils.py
ADDED
|
@@ -0,0 +1,967 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
import random
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from nltk import sent_tokenize
|
| 9 |
+
from sklearn.metrics import accuracy_score, precision_score, f1_score
|
| 10 |
+
from sklearn.model_selection import train_test_split
|
| 11 |
+
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler
|
| 12 |
+
from transformers import AutoTokenizer, AutoModel, AutoConfig, get_linear_schedule_with_warmup
|
| 13 |
+
from peft import PeftModel, LoraConfig, get_peft_model
|
| 14 |
+
from datasets import load_dataset, DatasetDict, load_from_disk
|
| 15 |
+
import spacy
|
| 16 |
+
import re
|
| 17 |
+
from tqdm.auto import tqdm
|
| 18 |
+
from accelerate import Accelerator
|
| 19 |
+
import matplotlib.pyplot as plt
|
| 20 |
+
from torch.optim import AdamW
|
| 21 |
+
import pandas as pd
|
| 22 |
+
from typing import Optional, Tuple, List, Dict
|
| 23 |
+
|
| 24 |
+
from models import GraphAugmentedNLIModel, GraphAugmentedFinNLIModel
|
| 25 |
+
from preprocess_data import SpanExtractionChunkedDataset, process_data, chunk_transcript, span_collate_fn
|
| 26 |
+
|
| 27 |
+
# =============================
|
| 28 |
+
# Configuration Constants
|
| 29 |
+
# =============================
|
| 30 |
+
from config import MODEL_NAME, MAX_LENGTH, OVERLAP, PREPROCESSED_DIR, tokenizer, nlp
|
| 31 |
+
|
| 32 |
+
#MODEL_NAME = "bert-base-uncased"
|
| 33 |
+
BATCH_SIZE = 16
|
| 34 |
+
#MAX_LENGTH = 128
|
| 35 |
+
#OVERLAP = 32
|
| 36 |
+
LEARNING_RATE = 2e-5
|
| 37 |
+
EPOCHS = 5
|
| 38 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 39 |
+
#PREPROCESSED_DIR = "preprocessed_snli"
|
| 40 |
+
MIXED_PRECISION = "fp16"
|
| 41 |
+
|
| 42 |
+
# label mapping
|
| 43 |
+
label_map = {0: "entailment", 1: "neutral", 2: "contradiction"}
|
| 44 |
+
|
| 45 |
+
# =============================
|
| 46 |
+
# Logging & Reproducibility
|
| 47 |
+
# =============================
|
| 48 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
| 49 |
+
def set_seed(seed: int = 42):
|
| 50 |
+
random.seed(seed)
|
| 51 |
+
np.random.seed(seed)
|
| 52 |
+
torch.manual_seed(seed)
|
| 53 |
+
if torch.cuda.is_available():
|
| 54 |
+
torch.cuda.manual_seed_all(seed)
|
| 55 |
+
|
| 56 |
+
# =============================
|
| 57 |
+
# Tokenizer & NLP Model
|
| 58 |
+
# =============================
|
| 59 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 60 |
+
nlp = spacy.load("en_core_web_sm")
|
| 61 |
+
|
| 62 |
+
# =============================
|
| 63 |
+
# Dependency Graph Helpers
|
| 64 |
+
# =============================
|
| 65 |
+
def build_dependency_graph(sentence: str):
|
| 66 |
+
doc = nlp(sentence)
|
| 67 |
+
tokens = [token.text for token in doc]
|
| 68 |
+
edges = []
|
| 69 |
+
for token in doc:
|
| 70 |
+
if token.head.i != token.i:
|
| 71 |
+
edges.append((token.i, token.head.i))
|
| 72 |
+
edges.append((token.head.i, token.i))
|
| 73 |
+
return tokens, edges
|
| 74 |
+
|
| 75 |
+
# =============================
|
| 76 |
+
# Token Alignment
|
| 77 |
+
# =============================
|
| 78 |
+
def align_tokens(spacy_tokens, wp_tokens):
|
| 79 |
+
node_indices = []
|
| 80 |
+
wp_idx = 1 # after [CLS]
|
| 81 |
+
for _ in spacy_tokens:
|
| 82 |
+
if wp_idx >= len(wp_tokens) - 1:
|
| 83 |
+
break
|
| 84 |
+
node_indices.append(wp_idx)
|
| 85 |
+
wp_idx += 1
|
| 86 |
+
while wp_idx < len(wp_tokens) - 1 and wp_tokens[wp_idx].startswith("##"):
|
| 87 |
+
wp_idx += 1
|
| 88 |
+
return node_indices
|
| 89 |
+
|
| 90 |
+
# =============================
|
| 91 |
+
# Data Collation
|
| 92 |
+
# =============================
|
| 93 |
+
def my_collate_fn(batch):
|
| 94 |
+
input_ids = [torch.tensor(ex["input_ids"], dtype=torch.long) for ex in batch]
|
| 95 |
+
attention_mask = [torch.tensor(ex["attention_mask"], dtype=torch.long) for ex in batch]
|
| 96 |
+
labels = [ex.get("labels", None) for ex in batch]
|
| 97 |
+
|
| 98 |
+
premise_graph_tokens = [ex.get("premise_graph_tokens") for ex in batch]
|
| 99 |
+
premise_graph_edges = [ex.get("premise_graph_edges") for ex in batch]
|
| 100 |
+
premise_node_indices = [ex.get("premise_node_indices") for ex in batch]
|
| 101 |
+
|
| 102 |
+
hypothesis_graph_tokens = [ex.get("hypothesis_graph_tokens") for ex in batch]
|
| 103 |
+
hypothesis_graph_edges = [ex.get("hypothesis_graph_edges") for ex in batch]
|
| 104 |
+
hypothesis_node_indices = [ex.get("hypothesis_node_indices") for ex in batch]
|
| 105 |
+
|
| 106 |
+
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
|
| 107 |
+
attention_mask = torch.nn.utils.rnn.pad_sequence(attention_mask, batch_first=True, padding_value=0)
|
| 108 |
+
labels = torch.tensor(labels, dtype=torch.long) if labels and labels[0] is not None else None
|
| 109 |
+
|
| 110 |
+
return {
|
| 111 |
+
"input_ids": input_ids,
|
| 112 |
+
"attention_mask": attention_mask,
|
| 113 |
+
"labels": labels,
|
| 114 |
+
"premise_graph_tokens": premise_graph_tokens,
|
| 115 |
+
"premise_graph_edges": premise_graph_edges,
|
| 116 |
+
"premise_node_indices": premise_node_indices,
|
| 117 |
+
"hypothesis_graph_tokens": hypothesis_graph_tokens,
|
| 118 |
+
"hypothesis_graph_edges": hypothesis_graph_edges,
|
| 119 |
+
"hypothesis_node_indices": hypothesis_node_indices,
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
# =============================
|
| 123 |
+
# Training Loop
|
| 124 |
+
# =============================
|
| 125 |
+
def train_model(epochs: int = EPOCHS,
|
| 126 |
+
batch_size: int = BATCH_SIZE,
|
| 127 |
+
lr: float = LEARNING_RATE,
|
| 128 |
+
save_model: bool = False,
|
| 129 |
+
save_path: str = 'gnn_model_weights_3.pt'):
|
| 130 |
+
set_seed()
|
| 131 |
+
process_data()
|
| 132 |
+
logging.info("Loading preprocessed dataset...")
|
| 133 |
+
snli = load_from_disk(PREPROCESSED_DIR)
|
| 134 |
+
snli.set_format("python", output_all_columns=True)
|
| 135 |
+
|
| 136 |
+
train_loader = DataLoader(snli["train"], batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn)
|
| 137 |
+
val_loader = DataLoader(snli["validation"], batch_size=batch_size, collate_fn=my_collate_fn)
|
| 138 |
+
|
| 139 |
+
model = GraphAugmentedNLIModel(MODEL_NAME).to(DEVICE)
|
| 140 |
+
|
| 141 |
+
if hasattr(model.bert, 'gradient_checkpointing_enable'):
|
| 142 |
+
model.bert.gradient_checkpointing_enable()
|
| 143 |
+
logging.info("Enabled gradient checkpointing on BERT.")
|
| 144 |
+
|
| 145 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
|
| 146 |
+
num_training_steps = epochs * len(train_loader)
|
| 147 |
+
lr_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=num_training_steps)
|
| 148 |
+
|
| 149 |
+
accelerator = Accelerator(mixed_precision=MIXED_PRECISION)
|
| 150 |
+
model, optimizer, train_loader, val_loader, lr_scheduler = accelerator.prepare(
|
| 151 |
+
model, optimizer, train_loader, val_loader, lr_scheduler
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
model.train()
|
| 155 |
+
all_losses = []
|
| 156 |
+
epoch_losses = []
|
| 157 |
+
best_val_loss = float('inf')
|
| 158 |
+
best_epoch = 0
|
| 159 |
+
|
| 160 |
+
for epoch in range(1, epochs + 1):
|
| 161 |
+
epoch_loss = []
|
| 162 |
+
progress = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", leave=False)
|
| 163 |
+
for batch in progress:
|
| 164 |
+
labels = batch["labels"].to(DEVICE) if batch.get("labels") is not None else None
|
| 165 |
+
outputs = model(
|
| 166 |
+
input_ids=batch["input_ids"].to(DEVICE),
|
| 167 |
+
attention_mask=batch["attention_mask"].to(DEVICE),
|
| 168 |
+
premise_graph_tokens=batch["premise_graph_tokens"],
|
| 169 |
+
premise_graph_edges=batch["premise_graph_edges"],
|
| 170 |
+
premise_node_indices=batch["premise_node_indices"],
|
| 171 |
+
hypothesis_graph_tokens=batch["hypothesis_graph_tokens"],
|
| 172 |
+
hypothesis_graph_edges=batch["hypothesis_graph_edges"],
|
| 173 |
+
hypothesis_node_indices=batch["hypothesis_node_indices"],
|
| 174 |
+
labels=labels
|
| 175 |
+
)
|
| 176 |
+
loss = outputs.get("loss") if isinstance(outputs, dict) else outputs
|
| 177 |
+
|
| 178 |
+
optimizer.zero_grad()
|
| 179 |
+
accelerator.backward(loss)
|
| 180 |
+
optimizer.step()
|
| 181 |
+
lr_scheduler.step()
|
| 182 |
+
|
| 183 |
+
loss_val = loss.item()
|
| 184 |
+
epoch_loss.append(loss_val)
|
| 185 |
+
all_losses.append(loss_val)
|
| 186 |
+
progress.set_postfix({"loss": f"{loss_val:.4f}"})
|
| 187 |
+
|
| 188 |
+
avg_epoch_loss = np.mean(epoch_loss)
|
| 189 |
+
epoch_losses.append(avg_epoch_loss)
|
| 190 |
+
logging.info(f"Epoch {epoch} completed. Avg Loss: {avg_epoch_loss:.4f}")
|
| 191 |
+
|
| 192 |
+
# Validation
|
| 193 |
+
model.eval()
|
| 194 |
+
val_losses = []
|
| 195 |
+
with torch.no_grad():
|
| 196 |
+
for batch in val_loader:
|
| 197 |
+
labels = batch["labels"].to(DEVICE) if batch.get("labels") is not None else None
|
| 198 |
+
outputs = model(
|
| 199 |
+
input_ids=batch["input_ids"].to(DEVICE),
|
| 200 |
+
attention_mask=batch["attention_mask"].to(DEVICE),
|
| 201 |
+
premise_graph_tokens=batch["premise_graph_tokens"],
|
| 202 |
+
premise_graph_edges=batch["premise_graph_edges"],
|
| 203 |
+
premise_node_indices=batch["premise_node_indices"],
|
| 204 |
+
hypothesis_graph_tokens=batch["hypothesis_graph_tokens"],
|
| 205 |
+
hypothesis_graph_edges=batch["hypothesis_graph_edges"],
|
| 206 |
+
hypothesis_node_indices=batch["hypothesis_node_indices"],
|
| 207 |
+
labels=labels
|
| 208 |
+
)
|
| 209 |
+
loss_item = outputs.get("loss").item() if isinstance(outputs, dict) else outputs.item()
|
| 210 |
+
val_losses.append(loss_item)
|
| 211 |
+
avg_val_loss = np.mean(val_losses) if val_losses else float('inf')
|
| 212 |
+
logging.info(f"Validation Loss after Epoch {epoch}: {avg_val_loss:.4f}")
|
| 213 |
+
|
| 214 |
+
if avg_val_loss < best_val_loss:
|
| 215 |
+
best_val_loss = avg_val_loss
|
| 216 |
+
best_epoch = epoch
|
| 217 |
+
if save_model:
|
| 218 |
+
logging.info(f"Saving best model at epoch {epoch} with val loss {avg_val_loss:.4f}")
|
| 219 |
+
torch.save(model.state_dict(), save_path)
|
| 220 |
+
model.train()
|
| 221 |
+
|
| 222 |
+
# Plot losses
|
| 223 |
+
plt.figure()
|
| 224 |
+
plt.plot(all_losses)
|
| 225 |
+
plt.xlabel('Training steps')
|
| 226 |
+
plt.ylabel('Loss')
|
| 227 |
+
plt.title('Step-wise Training Loss')
|
| 228 |
+
plt.show()
|
| 229 |
+
|
| 230 |
+
plt.figure()
|
| 231 |
+
plt.plot(range(1, epochs+1), epoch_losses, marker='o')
|
| 232 |
+
plt.xlabel('Epochs')
|
| 233 |
+
plt.ylabel('Loss')
|
| 234 |
+
plt.title('Epoch-wise Training Loss')
|
| 235 |
+
plt.show()
|
| 236 |
+
|
| 237 |
+
logging.info(f"Training complete. Best validation loss {best_val_loss:.4f} at epoch {best_epoch}.")
|
| 238 |
+
return model
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def predict_nli(premise, hypothesis, tokenizer=tokenizer, model_path='gnn_model_checkpoint.pt'):
|
| 242 |
+
# 1) instantiate the model exactly as you did during training
|
| 243 |
+
model = GraphAugmentedNLIModel(MODEL_NAME).to(DEVICE)
|
| 244 |
+
|
| 245 |
+
# 2) load the checkpoint, then hand only the model weights to load_state_dict
|
| 246 |
+
ckpt = torch.load(model_path, map_location=DEVICE)
|
| 247 |
+
model.load_state_dict(ckpt["model_state_dict"])
|
| 248 |
+
|
| 249 |
+
model.eval()
|
| 250 |
+
|
| 251 |
+
# 3) tokenize & build graphs (as before)β¦
|
| 252 |
+
encoded = tokenizer(
|
| 253 |
+
premise, hypothesis,
|
| 254 |
+
truncation=True,
|
| 255 |
+
padding="max_length",
|
| 256 |
+
max_length=MAX_LENGTH,
|
| 257 |
+
return_tensors="pt"
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
input_ids = encoded["input_ids"]
|
| 261 |
+
attention_mask = encoded["attention_mask"]
|
| 262 |
+
|
| 263 |
+
# Build dependency graphs
|
| 264 |
+
p_tokens, p_edges = build_dependency_graph(premise)
|
| 265 |
+
h_tokens, h_edges = build_dependency_graph(hypothesis)
|
| 266 |
+
|
| 267 |
+
# Convert ids back to tokens for alignment
|
| 268 |
+
wp_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
|
| 269 |
+
|
| 270 |
+
p_node_indices = align_tokens(p_tokens, wp_tokens)
|
| 271 |
+
h_node_indices = align_tokens(h_tokens, wp_tokens)
|
| 272 |
+
|
| 273 |
+
# Move tensors to the same device as the model
|
| 274 |
+
device = next(model.parameters()).device
|
| 275 |
+
input_ids = input_ids.to(device)
|
| 276 |
+
attention_mask = attention_mask.to(device)
|
| 277 |
+
|
| 278 |
+
# Prepare inputs for the model: the model expects lists for graph fields
|
| 279 |
+
# since we used a custom collate_fn logic.
|
| 280 |
+
premise_graph_tokens = [p_tokens]
|
| 281 |
+
premise_graph_edges = [p_edges]
|
| 282 |
+
premise_node_indices = [p_node_indices]
|
| 283 |
+
|
| 284 |
+
hypothesis_graph_tokens = [h_tokens]
|
| 285 |
+
hypothesis_graph_edges = [h_edges]
|
| 286 |
+
hypothesis_node_indices = [h_node_indices]
|
| 287 |
+
|
| 288 |
+
with torch.no_grad():
|
| 289 |
+
outputs = model(
|
| 290 |
+
input_ids=input_ids,
|
| 291 |
+
attention_mask=attention_mask,
|
| 292 |
+
premise_graph_tokens=premise_graph_tokens,
|
| 293 |
+
premise_graph_edges=premise_graph_edges,
|
| 294 |
+
premise_node_indices=premise_node_indices,
|
| 295 |
+
hypothesis_graph_tokens=hypothesis_graph_tokens,
|
| 296 |
+
hypothesis_graph_edges=hypothesis_graph_edges,
|
| 297 |
+
hypothesis_node_indices=hypothesis_node_indices
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
logits = outputs["logits"]
|
| 301 |
+
probs = F.softmax(logits, dim=-1).cpu().numpy()[0]
|
| 302 |
+
# Get predicted label
|
| 303 |
+
predicted_label_id = torch.argmax(logits, dim=-1).item()
|
| 304 |
+
predicted_label = label_map[predicted_label_id]
|
| 305 |
+
prob_map = dict()
|
| 306 |
+
for i, cls_label in label_map.items():
|
| 307 |
+
prob_map[cls_label] = probs[i]
|
| 308 |
+
return predicted_label, prob_map
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def predict_fin_nli(
|
| 312 |
+
premise: str,
|
| 313 |
+
hypothesis: str,
|
| 314 |
+
tokenizer=tokenizer,
|
| 315 |
+
model_path: str = 'gnn_model_checkpoint.pt',
|
| 316 |
+
adapter_dir: str = './lora_finance_adapter',
|
| 317 |
+
) -> (str, list):
|
| 318 |
+
# 1) Load base GraphAugmentedFinNLIModel and its checkpoint
|
| 319 |
+
base_model = GraphAugmentedFinNLIModel(MODEL_NAME).to(DEVICE)
|
| 320 |
+
ckpt = torch.load(model_path, map_location=DEVICE)
|
| 321 |
+
base_model.load_state_dict(ckpt['model_state_dict'])
|
| 322 |
+
|
| 323 |
+
# 2) Wrap with the same LoRA config you used in training
|
| 324 |
+
lora_cfg = LoraConfig(
|
| 325 |
+
r=8,
|
| 326 |
+
lora_alpha=32,
|
| 327 |
+
lora_dropout=0.1,
|
| 328 |
+
bias='none',
|
| 329 |
+
task_type='SEQ_CLS',
|
| 330 |
+
target_modules=['query', 'value']
|
| 331 |
+
)
|
| 332 |
+
model = get_peft_model(base_model, lora_cfg).to(DEVICE)
|
| 333 |
+
|
| 334 |
+
# 3) Load your adapter checkpoint (the .pt under lora_finance_adapter/)
|
| 335 |
+
adapter_ckpt = torch.load(os.path.join(adapter_dir, 'training_checkpoint.pt'), map_location=DEVICE)
|
| 336 |
+
# This checkpoint contains the same 'model_state_dict' keysβso load it leniently:
|
| 337 |
+
model.load_state_dict(adapter_ckpt['model_state_dict'], strict=False)
|
| 338 |
+
model.eval()
|
| 339 |
+
|
| 340 |
+
# 4) Tokenize
|
| 341 |
+
enc = tokenizer(
|
| 342 |
+
premise, hypothesis,
|
| 343 |
+
truncation=True,
|
| 344 |
+
padding='max_length',
|
| 345 |
+
max_length=MAX_LENGTH,
|
| 346 |
+
return_tensors='pt'
|
| 347 |
+
)
|
| 348 |
+
input_ids = enc['input_ids'].to(DEVICE)
|
| 349 |
+
attention_mask = enc['attention_mask'].to(DEVICE)
|
| 350 |
+
|
| 351 |
+
# 5) Build & align your dependency graphs
|
| 352 |
+
p_toks, p_edges = build_dependency_graph(premise)
|
| 353 |
+
h_toks, h_edges = build_dependency_graph(hypothesis)
|
| 354 |
+
wp = tokenizer.convert_ids_to_tokens(input_ids[0])
|
| 355 |
+
p_idx = align_tokens(p_toks, wp)
|
| 356 |
+
h_idx = align_tokens(h_toks, wp)
|
| 357 |
+
|
| 358 |
+
premise_graph_tokens = [p_toks]
|
| 359 |
+
premise_graph_edges = [p_edges]
|
| 360 |
+
premise_node_indices = [p_idx]
|
| 361 |
+
hypothesis_graph_tokens = [h_toks]
|
| 362 |
+
hypothesis_graph_edges = [h_edges]
|
| 363 |
+
hypothesis_node_indices = [h_idx]
|
| 364 |
+
|
| 365 |
+
# 6) Forward
|
| 366 |
+
with torch.no_grad():
|
| 367 |
+
out = model(
|
| 368 |
+
input_ids=input_ids,
|
| 369 |
+
attention_mask=attention_mask,
|
| 370 |
+
premise_graph_tokens=premise_graph_tokens,
|
| 371 |
+
premise_graph_edges=premise_graph_edges,
|
| 372 |
+
premise_node_indices=premise_node_indices,
|
| 373 |
+
hypothesis_graph_tokens=hypothesis_graph_tokens,
|
| 374 |
+
hypothesis_graph_edges=hypothesis_graph_edges,
|
| 375 |
+
hypothesis_node_indices=hypothesis_node_indices
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
logits = out['logits'][0] # shape [3]
|
| 379 |
+
probs = torch.softmax(logits, dim=-1).cpu().numpy()
|
| 380 |
+
|
| 381 |
+
# 7) Collapse to entailment vs. contradiction (ignore neutral)
|
| 382 |
+
entail, neutral, contra = probs
|
| 383 |
+
s = entail + contra + 1e-12
|
| 384 |
+
scores = [entail / s, contra / s]
|
| 385 |
+
label = 'entailment' if entail >= contra else 'contradiction'
|
| 386 |
+
return label, scores
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def train_model_with_chkpt(epochs: int = 5,
|
| 390 |
+
batch_size: int = 16,
|
| 391 |
+
lr: float = 2e-5,
|
| 392 |
+
save_model: bool = False,
|
| 393 |
+
save_path: str = 'gnn_model_checkpoint.pt',
|
| 394 |
+
resume: bool = False):
|
| 395 |
+
"""
|
| 396 |
+
Train with mixed precision, gradient checkpointing, and resume support.
|
| 397 |
+
If resume=True and save_path exists, picks up from last epoch.
|
| 398 |
+
"""
|
| 399 |
+
set_seed()
|
| 400 |
+
process_data()
|
| 401 |
+
logging.info("Loading preprocessed datasetβ¦")
|
| 402 |
+
snli = load_from_disk(PREPROCESSED_DIR)
|
| 403 |
+
snli.set_format("python", output_all_columns=True)
|
| 404 |
+
|
| 405 |
+
train_loader = DataLoader(snli["train"], batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn)
|
| 406 |
+
val_loader = DataLoader(snli["validation"], batch_size=batch_size, collate_fn=my_collate_fn)
|
| 407 |
+
|
| 408 |
+
model = GraphAugmentedNLIModel(MODEL_NAME).to(DEVICE)
|
| 409 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
|
| 410 |
+
total_steps = epochs * len(train_loader)
|
| 411 |
+
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=total_steps)
|
| 412 |
+
|
| 413 |
+
# --- Resume checkpoint if requested ---
|
| 414 |
+
start_epoch = 1
|
| 415 |
+
if resume and os.path.isfile(save_path):
|
| 416 |
+
ckpt = torch.load(save_path, map_location=DEVICE)
|
| 417 |
+
model.load_state_dict(ckpt["model_state_dict"])
|
| 418 |
+
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
|
| 419 |
+
scheduler.load_state_dict(ckpt["scheduler_state_dict"])
|
| 420 |
+
start_epoch = ckpt.get("epoch", 1) + 1
|
| 421 |
+
logging.info(f"Resuming from epoch {start_epoch}")
|
| 422 |
+
|
| 423 |
+
# Mixed precision setup
|
| 424 |
+
if hasattr(model.bert, "gradient_checkpointing_enable"):
|
| 425 |
+
model.bert.gradient_checkpointing_enable()
|
| 426 |
+
logging.info("Enabled gradient checkpointing on BERT.")
|
| 427 |
+
accelerator = Accelerator(mixed_precision=MIXED_PRECISION)
|
| 428 |
+
model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(
|
| 429 |
+
model, optimizer, train_loader, val_loader, scheduler
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
best_val_loss = float("inf")
|
| 433 |
+
for epoch in range(start_epoch, epochs + 1):
|
| 434 |
+
model.train()
|
| 435 |
+
train_losses = []
|
| 436 |
+
for batch in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}"):
|
| 437 |
+
optimizer.zero_grad()
|
| 438 |
+
outputs = model(
|
| 439 |
+
input_ids=batch["input_ids"].to(DEVICE),
|
| 440 |
+
attention_mask=batch["attention_mask"].to(DEVICE),
|
| 441 |
+
premise_graph_tokens=batch["premise_graph_tokens"],
|
| 442 |
+
premise_graph_edges=batch["premise_graph_edges"],
|
| 443 |
+
premise_node_indices=batch["premise_node_indices"],
|
| 444 |
+
hypothesis_graph_tokens=batch["hypothesis_graph_tokens"],
|
| 445 |
+
hypothesis_graph_edges=batch["hypothesis_graph_edges"],
|
| 446 |
+
hypothesis_node_indices=batch["hypothesis_node_indices"],
|
| 447 |
+
labels=batch.get("labels", None).to(DEVICE) if batch.get("labels") is not None else None
|
| 448 |
+
)
|
| 449 |
+
loss = outputs["loss"] if isinstance(outputs, dict) else outputs
|
| 450 |
+
accelerator.backward(loss)
|
| 451 |
+
optimizer.step()
|
| 452 |
+
scheduler.step()
|
| 453 |
+
train_losses.append(loss.item())
|
| 454 |
+
avg_train = np.mean(train_losses)
|
| 455 |
+
logging.info(f"Epoch {epoch} train loss: {avg_train:.4f}")
|
| 456 |
+
|
| 457 |
+
# Validation
|
| 458 |
+
model.eval()
|
| 459 |
+
val_losses = []
|
| 460 |
+
with torch.no_grad():
|
| 461 |
+
for batch in val_loader:
|
| 462 |
+
outputs = model(
|
| 463 |
+
input_ids=batch["input_ids"].to(DEVICE),
|
| 464 |
+
attention_mask=batch["attention_mask"].to(DEVICE),
|
| 465 |
+
premise_graph_tokens=batch["premise_graph_tokens"],
|
| 466 |
+
premise_graph_edges=batch["premise_graph_edges"],
|
| 467 |
+
premise_node_indices=batch["premise_node_indices"],
|
| 468 |
+
hypothesis_graph_tokens=batch["hypothesis_graph_tokens"],
|
| 469 |
+
hypothesis_graph_edges=batch["hypothesis_graph_edges"],
|
| 470 |
+
hypothesis_node_indices=batch["hypothesis_node_indices"],
|
| 471 |
+
labels=batch.get("labels", None).to(DEVICE) if batch.get("labels") is not None else None
|
| 472 |
+
)
|
| 473 |
+
v_loss = outputs["loss"].item() if isinstance(outputs, dict) else outputs.item()
|
| 474 |
+
val_losses.append(v_loss)
|
| 475 |
+
avg_val = np.mean(val_losses) if val_losses else float("inf")
|
| 476 |
+
logging.info(f"Epoch {epoch} val loss: {avg_val:.4f}")
|
| 477 |
+
|
| 478 |
+
# Save checkpoint
|
| 479 |
+
ckpt = {
|
| 480 |
+
"epoch": epoch,
|
| 481 |
+
"model_state_dict": model.state_dict(),
|
| 482 |
+
"optimizer_state_dict": optimizer.state_dict(),
|
| 483 |
+
"scheduler_state_dict": scheduler.state_dict(),
|
| 484 |
+
}
|
| 485 |
+
torch.save(ckpt, save_path)
|
| 486 |
+
logging.info(f"Saved checkpoint: {save_path}")
|
| 487 |
+
|
| 488 |
+
if avg_val < best_val_loss:
|
| 489 |
+
best_val_loss = avg_val
|
| 490 |
+
|
| 491 |
+
logging.info(f"Training complete. Best val loss: {best_val_loss:.4f}")
|
| 492 |
+
return model
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
def extract_sentences_by_intent(
|
| 496 |
+
text: str,
|
| 497 |
+
intent: str,
|
| 498 |
+
adapter_dir: str = "./lora_finance_adapter",
|
| 499 |
+
threshold: float = 0.7,
|
| 500 |
+
top_k: int = None,
|
| 501 |
+
min_words: int = 4,
|
| 502 |
+
convo_focus: str = None
|
| 503 |
+
):
|
| 504 |
+
"""
|
| 505 |
+
Splits `text` into sentences, embeds them (and the `intent`) under your
|
| 506 |
+
LoRAβadapted BERT, and returns those whose cosine similarity β₯ `threshold`.
|
| 507 |
+
Loads the adapter from the single `training_checkpoint.pt` in `adapter_dir`.
|
| 508 |
+
"""
|
| 509 |
+
# 1) Sentence split & cleanup
|
| 510 |
+
# 1) Only consider lines spoken by the customer
|
| 511 |
+
|
| 512 |
+
if convo_focus is None:
|
| 513 |
+
sentences = [sent.text.strip() for sent in nlp(text).sents if sent.text.strip()]
|
| 514 |
+
|
| 515 |
+
elif convo_focus == "customer":
|
| 516 |
+
customer_lines = [
|
| 517 |
+
line.strip()
|
| 518 |
+
for line in text.splitlines()
|
| 519 |
+
if line.strip().lower().startswith("customer:")
|
| 520 |
+
]
|
| 521 |
+
|
| 522 |
+
# 2) Sentence-split each customer line
|
| 523 |
+
sentences = []
|
| 524 |
+
for cust_line in customer_lines:
|
| 525 |
+
for sent in nlp(cust_line).sents:
|
| 526 |
+
s = sent.text.strip()
|
| 527 |
+
if s and len(s.split(' '))>6:
|
| 528 |
+
sentences.append(s)
|
| 529 |
+
|
| 530 |
+
else:
|
| 531 |
+
customer_lines = [
|
| 532 |
+
line.strip()
|
| 533 |
+
for line in text.splitlines()
|
| 534 |
+
if line.strip().lower().startswith("agent:")
|
| 535 |
+
]
|
| 536 |
+
|
| 537 |
+
# 2) Sentence-split each customer line
|
| 538 |
+
sentences = []
|
| 539 |
+
for cust_line in customer_lines:
|
| 540 |
+
for sent in nlp(cust_line).sents:
|
| 541 |
+
s = sent.text.strip()
|
| 542 |
+
if s and len(s.split(' '))>6:
|
| 543 |
+
sentences.append(s)
|
| 544 |
+
|
| 545 |
+
# 2) Load base BERT + wrap in same LoRA config
|
| 546 |
+
base_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
|
| 547 |
+
lora_cfg = LoraConfig(
|
| 548 |
+
r=8,
|
| 549 |
+
lora_alpha=32,
|
| 550 |
+
lora_dropout=0.1,
|
| 551 |
+
bias="none",
|
| 552 |
+
task_type="CAUSAL_LM", # must match your fine-tune setting
|
| 553 |
+
)
|
| 554 |
+
model = get_peft_model(base_model, lora_cfg).to(DEVICE)
|
| 555 |
+
|
| 556 |
+
# 3) Load your adapter checkpoint
|
| 557 |
+
chkpt_path = os.path.join(adapter_dir, "training_checkpoint.pt")
|
| 558 |
+
if not os.path.isfile(chkpt_path):
|
| 559 |
+
raise FileNotFoundError(f"No LoRA checkpoint at {chkpt_path}")
|
| 560 |
+
ckpt = torch.load(chkpt_path, map_location=DEVICE)
|
| 561 |
+
# ckpt["model_state_dict"] contains both base + LoRA weights; strict=False
|
| 562 |
+
model.load_state_dict(ckpt["model_state_dict"], strict=False)
|
| 563 |
+
model.eval()
|
| 564 |
+
|
| 565 |
+
# helper: get [CLS] embedding under LoRA-BERT
|
| 566 |
+
def embed(text_str):
|
| 567 |
+
toks = tokenizer(
|
| 568 |
+
text_str,
|
| 569 |
+
truncation=True,
|
| 570 |
+
padding="longest",
|
| 571 |
+
return_tensors="pt"
|
| 572 |
+
).to(DEVICE)
|
| 573 |
+
|
| 574 |
+
em_args = {
|
| 575 |
+
"input_ids": toks["input_ids"],
|
| 576 |
+
"attention_mask": toks["attention_mask"],
|
| 577 |
+
}
|
| 578 |
+
if "token_type_ids" in toks:
|
| 579 |
+
em_args["token_type_ids"] = toks["token_type_ids"]
|
| 580 |
+
|
| 581 |
+
# unwrap PEFT to call only the base BertModel
|
| 582 |
+
hf_model = getattr(model, "base_model", model)
|
| 583 |
+
with torch.no_grad():
|
| 584 |
+
last_hidden = hf_model(
|
| 585 |
+
input_ids=em_args["input_ids"],
|
| 586 |
+
attention_mask=em_args["attention_mask"],
|
| 587 |
+
**({"token_type_ids": em_args["token_type_ids"]} if "token_type_ids" in em_args else {})
|
| 588 |
+
).last_hidden_state
|
| 589 |
+
return last_hidden[:, 0, :]
|
| 590 |
+
|
| 591 |
+
# now embed(intent) and each sentence using this safe helper
|
| 592 |
+
intent_emb = embed(intent)
|
| 593 |
+
|
| 594 |
+
results = []
|
| 595 |
+
with torch.no_grad():
|
| 596 |
+
for sent in sentences:
|
| 597 |
+
clean = re.sub(r'^(Agent|Customer):\s*', "", sent)
|
| 598 |
+
if len(clean.split()) < min_words:
|
| 599 |
+
continue
|
| 600 |
+
|
| 601 |
+
sent_emb = embed(clean)
|
| 602 |
+
sim = F.cosine_similarity(sent_emb, intent_emb, dim=1).item()
|
| 603 |
+
if sim >= threshold:
|
| 604 |
+
results.append((clean, sim))
|
| 605 |
+
|
| 606 |
+
# 5) sort & trim
|
| 607 |
+
results.sort(key=lambda x: x[1], reverse=True)
|
| 608 |
+
return results[:top_k] if top_k else results
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
def train_sentence_extractor(
|
| 612 |
+
model: nn.Module,
|
| 613 |
+
dataset: torch.utils.data.Dataset,
|
| 614 |
+
output_dir: str,
|
| 615 |
+
val_split: float = 0.2,
|
| 616 |
+
epochs: int = 3,
|
| 617 |
+
batch_size: int = 16,
|
| 618 |
+
lr: float = 2e-5,
|
| 619 |
+
device: str = "cpu",
|
| 620 |
+
unfreeze_after_epoch: int = 1,
|
| 621 |
+
threshold: float = 0.5
|
| 622 |
+
):
|
| 623 |
+
"""
|
| 624 |
+
Fine-tune `model` on `dataset`, hold out `val_split` for val,
|
| 625 |
+
compute loss + acc + precision + F1 each epoch, save best checkpoint,
|
| 626 |
+
and plot all four metrics at the end.
|
| 627 |
+
"""
|
| 628 |
+
# Split
|
| 629 |
+
total = len(dataset)
|
| 630 |
+
val_n = int(total * val_split)
|
| 631 |
+
train_n = total - val_n
|
| 632 |
+
train_ds, val_ds = random_split(dataset, [train_n, val_n])
|
| 633 |
+
|
| 634 |
+
# Oversample train
|
| 635 |
+
train_labels = [train_ds[i]['label'].item() for i in range(len(train_ds))]
|
| 636 |
+
counts = torch.bincount(torch.tensor(train_labels, dtype=torch.long))
|
| 637 |
+
weights = (1.0 / counts.float()).tolist()
|
| 638 |
+
sample_weights = [weights[int(l)] for l in train_labels]
|
| 639 |
+
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
|
| 640 |
+
|
| 641 |
+
train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler, drop_last=True)
|
| 642 |
+
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
|
| 643 |
+
|
| 644 |
+
model.to(device)
|
| 645 |
+
# initially freeze backbone
|
| 646 |
+
for p in model.bert.parameters(): p.requires_grad = False
|
| 647 |
+
|
| 648 |
+
optimizer = AdamW(model.parameters(), lr=lr)
|
| 649 |
+
total_steps = epochs * len(train_loader)
|
| 650 |
+
scheduler = get_linear_schedule_with_warmup(
|
| 651 |
+
optimizer,
|
| 652 |
+
num_warmup_steps=int(0.1 * total_steps),
|
| 653 |
+
num_training_steps=total_steps
|
| 654 |
+
)
|
| 655 |
+
criterion = nn.BCEWithLogitsLoss()
|
| 656 |
+
|
| 657 |
+
# storage for metrics
|
| 658 |
+
train_losses, val_losses = [], []
|
| 659 |
+
train_accs, val_accs = [], []
|
| 660 |
+
train_precs, val_precs = [], []
|
| 661 |
+
train_f1s, val_f1s = [], []
|
| 662 |
+
|
| 663 |
+
best_val_loss = float('inf')
|
| 664 |
+
|
| 665 |
+
for epoch in range(1, epochs+1):
|
| 666 |
+
# ββ TRAIN ββ
|
| 667 |
+
model.train()
|
| 668 |
+
epoch_loss = 0.0
|
| 669 |
+
preds, labels = [], []
|
| 670 |
+
for batch in tqdm(train_loader, desc=f"Train {epoch}/{epochs}"):
|
| 671 |
+
inputs = batch['input_ids'].to(device)
|
| 672 |
+
masks = batch['attention_mask'].to(device)
|
| 673 |
+
labs = batch['label'].to(device)
|
| 674 |
+
|
| 675 |
+
optimizer.zero_grad()
|
| 676 |
+
logits = model(inputs, masks) # raw logits
|
| 677 |
+
loss = criterion(logits, labs)
|
| 678 |
+
loss.backward()
|
| 679 |
+
optimizer.step()
|
| 680 |
+
scheduler.step()
|
| 681 |
+
|
| 682 |
+
epoch_loss += loss.item()
|
| 683 |
+
|
| 684 |
+
probs = torch.sigmoid(logits)
|
| 685 |
+
batch_preds = (probs >= threshold).long()
|
| 686 |
+
preds.extend(batch_preds.cpu().tolist())
|
| 687 |
+
labels.extend(labs.cpu().long().tolist())
|
| 688 |
+
|
| 689 |
+
avg_train = epoch_loss / len(train_loader)
|
| 690 |
+
train_losses.append(avg_train)
|
| 691 |
+
train_accs.append( accuracy_score(labels, preds) )
|
| 692 |
+
train_precs.append( precision_score(labels, preds, zero_division=0) )
|
| 693 |
+
train_f1s.append( f1_score(labels, preds, zero_division=0) )
|
| 694 |
+
print(f"β Epoch {epoch} Train β loss {avg_train:.4f}, acc {train_accs[-1]:.4f}, prec {train_precs[-1]:.4f}, f1 {train_f1s[-1]:.4f}")
|
| 695 |
+
|
| 696 |
+
# unfreeze if needed
|
| 697 |
+
if epoch == unfreeze_after_epoch:
|
| 698 |
+
for p in model.bert.parameters(): p.requires_grad = True
|
| 699 |
+
optimizer = AdamW([
|
| 700 |
+
{"params": model.classifier.parameters(), "lr": 1e-3},
|
| 701 |
+
{"params": model.bert.parameters(), "lr": 1e-5},
|
| 702 |
+
], weight_decay=1e-2)
|
| 703 |
+
scheduler = get_linear_schedule_with_warmup(
|
| 704 |
+
optimizer,
|
| 705 |
+
num_warmup_steps=int(0.1 * total_steps),
|
| 706 |
+
num_training_steps=total_steps
|
| 707 |
+
)
|
| 708 |
+
|
| 709 |
+
# ββ VALIDATION ββ
|
| 710 |
+
model.eval()
|
| 711 |
+
epoch_loss = 0.0
|
| 712 |
+
preds, labels = [], []
|
| 713 |
+
with torch.no_grad():
|
| 714 |
+
for batch in tqdm(val_loader, desc=f" Val {epoch}/{epochs}"):
|
| 715 |
+
inputs = batch['input_ids'].to(device)
|
| 716 |
+
masks = batch['attention_mask'].to(device)
|
| 717 |
+
labs = batch['label'].to(device)
|
| 718 |
+
|
| 719 |
+
logits = model(inputs, masks)
|
| 720 |
+
loss = criterion(logits, labs)
|
| 721 |
+
epoch_loss += loss.item()
|
| 722 |
+
|
| 723 |
+
probs = torch.sigmoid(logits)
|
| 724 |
+
batch_preds = (probs >= threshold).long()
|
| 725 |
+
preds.extend(batch_preds.cpu().tolist())
|
| 726 |
+
labels.extend(labs.cpu().long().tolist())
|
| 727 |
+
|
| 728 |
+
avg_val = epoch_loss / len(val_loader)
|
| 729 |
+
val_losses.append(avg_val)
|
| 730 |
+
val_accs.append( accuracy_score(labels, preds) )
|
| 731 |
+
val_precs.append( precision_score(labels, preds, zero_division=0) )
|
| 732 |
+
val_f1s.append( f1_score(labels, preds, zero_division=0) )
|
| 733 |
+
print(f"β Epoch {epoch} Val β loss {avg_val:.4f}, acc {val_accs[-1]:.4f}, prec {val_precs[-1]:.4f}, f1 {val_f1s[-1]:.4f}")
|
| 734 |
+
|
| 735 |
+
# checkpoints
|
| 736 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 737 |
+
ckpt = os.path.join(output_dir, f"epo{epoch}_val{avg_val:.4f}.pth")
|
| 738 |
+
torch.save(model.state_dict(), ckpt)
|
| 739 |
+
if avg_val < best_val_loss:
|
| 740 |
+
best_val_loss = avg_val
|
| 741 |
+
torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pth"))
|
| 742 |
+
print(f"π New best model saved (val loss {best_val_loss:.4f})")
|
| 743 |
+
|
| 744 |
+
print(f"βοΈ Training complete β best val loss: {best_val_loss:.4f}")
|
| 745 |
+
|
| 746 |
+
# ββ PLOT METRICS ββ
|
| 747 |
+
epochs = list(range(1, epochs+1))
|
| 748 |
+
|
| 749 |
+
save_metric_plot(
|
| 750 |
+
epochs,
|
| 751 |
+
train_losses,
|
| 752 |
+
val_losses,
|
| 753 |
+
metric_name="Loss",
|
| 754 |
+
output_path="results/Loss_Plot.png"
|
| 755 |
+
)
|
| 756 |
+
|
| 757 |
+
save_metric_plot(
|
| 758 |
+
epochs,
|
| 759 |
+
train_accs,
|
| 760 |
+
val_accs,
|
| 761 |
+
metric_name="Accuracy",
|
| 762 |
+
output_path="results/Accuracy_Plot.png",
|
| 763 |
+
threshold=0.5
|
| 764 |
+
)
|
| 765 |
+
|
| 766 |
+
save_metric_plot(
|
| 767 |
+
epochs,
|
| 768 |
+
train_precs,
|
| 769 |
+
val_precs,
|
| 770 |
+
metric_name="Precision",
|
| 771 |
+
output_path="results/Precision_Plot.png",
|
| 772 |
+
threshold=0.5
|
| 773 |
+
)
|
| 774 |
+
|
| 775 |
+
save_metric_plot(
|
| 776 |
+
epochs,
|
| 777 |
+
train_f1s,
|
| 778 |
+
val_f1s,
|
| 779 |
+
metric_name="F1 Score",
|
| 780 |
+
output_path="results/F1Score_Plot.png",
|
| 781 |
+
threshold=0.5
|
| 782 |
+
)
|
| 783 |
+
|
| 784 |
+
|
| 785 |
+
def save_metric_plot(
|
| 786 |
+
epochs,
|
| 787 |
+
train_vals,
|
| 788 |
+
val_vals,
|
| 789 |
+
metric_name: str,
|
| 790 |
+
output_path: str,
|
| 791 |
+
threshold: float = None
|
| 792 |
+
):
|
| 793 |
+
"""
|
| 794 |
+
epochs β list of epoch indices
|
| 795 |
+
train_vals β list of train metric values
|
| 796 |
+
val_vals β list of validation metric values
|
| 797 |
+
metric_name β e.g. "Loss", "Accuracy", "Precision", "F1 Score"
|
| 798 |
+
output_path β where to save the PNG
|
| 799 |
+
threshold β optional horizontal line to draw, e.g. 0.5
|
| 800 |
+
"""
|
| 801 |
+
fig, ax = plt.subplots(figsize=(8, 5))
|
| 802 |
+
ax.plot(epochs, train_vals, marker='o', linewidth=2, label=f'Train {metric_name}')
|
| 803 |
+
ax.plot(epochs, val_vals, marker='s', linewidth=2, label=f'Val {metric_name}')
|
| 804 |
+
|
| 805 |
+
if threshold is not None:
|
| 806 |
+
ax.axhline(threshold, color='gray', linestyle='--', linewidth=1, label=f'Threshold = {threshold}')
|
| 807 |
+
|
| 808 |
+
ax.set_title(f'{metric_name} over Epochs', fontsize=14, pad=10)
|
| 809 |
+
ax.set_xlabel('Epoch', fontsize=12)
|
| 810 |
+
ax.set_ylabel(metric_name, fontsize=12)
|
| 811 |
+
ax.grid(True, linestyle='--', alpha=0.4)
|
| 812 |
+
ax.legend(loc='best', frameon=True, fontsize=10)
|
| 813 |
+
fig.tight_layout()
|
| 814 |
+
fig.savefig(output_path, dpi=300)
|
| 815 |
+
plt.close(fig)
|
| 816 |
+
|
| 817 |
+
|
| 818 |
+
def demo_on_random_val(
|
| 819 |
+
model,
|
| 820 |
+
tokenizer,
|
| 821 |
+
excel_path: str,
|
| 822 |
+
ckpt_path: str,
|
| 823 |
+
max_length: int = 128,
|
| 824 |
+
device: str = "cpu",
|
| 825 |
+
temperature: float = 1.0
|
| 826 |
+
):
|
| 827 |
+
"""
|
| 828 |
+
Like demo_on_random_val, but instead of a fixed threshold:
|
| 829 |
+
1) Compute sigmoid(logits / temperature) for each sentence
|
| 830 |
+
2) Sort probabilities descending
|
| 831 |
+
3) Find the largest gap between adjacent probs
|
| 832 |
+
4) Set dynamic_threshold = midpoint of that gap
|
| 833 |
+
5) Extract all sentences with prob >= dynamic_threshold
|
| 834 |
+
"""
|
| 835 |
+
# load model
|
| 836 |
+
model.load_state_dict(torch.load(ckpt_path, map_location=device))
|
| 837 |
+
model.to(device).eval()
|
| 838 |
+
|
| 839 |
+
# sample one from validation split
|
| 840 |
+
df = pd.read_excel(excel_path)
|
| 841 |
+
_, val_df = train_test_split(df, test_size=0.2, random_state=42)
|
| 842 |
+
row = val_df.sample(n=1, random_state=random.randint(0,999)).iloc[0]
|
| 843 |
+
transcript = str(row['Claude_Call'])
|
| 844 |
+
print(f"\nββ Transcript (val sample idx={row['idx']}):\n{transcript}\n")
|
| 845 |
+
|
| 846 |
+
# split into sentences & run inference
|
| 847 |
+
sentences, probs = [], []
|
| 848 |
+
for sent in sent_tokenize(transcript):
|
| 849 |
+
enc = tokenizer.encode_plus(
|
| 850 |
+
sent,
|
| 851 |
+
max_length=max_length,
|
| 852 |
+
padding='max_length',
|
| 853 |
+
truncation=True,
|
| 854 |
+
return_tensors='pt'
|
| 855 |
+
)
|
| 856 |
+
logits = model(enc['input_ids'].to(device),
|
| 857 |
+
enc['attention_mask'].to(device))
|
| 858 |
+
prob = torch.sigmoid(logits / temperature).item()
|
| 859 |
+
sentences.append(sent)
|
| 860 |
+
probs.append(prob)
|
| 861 |
+
|
| 862 |
+
# print all
|
| 863 |
+
print("Sentence probabilities:")
|
| 864 |
+
for s,p in zip(sentences, probs):
|
| 865 |
+
print(f" β {p:.4f} β {s}")
|
| 866 |
+
|
| 867 |
+
# if no variation, fall back to 0.5
|
| 868 |
+
if len(probs) < 2 or max(probs) - min(probs) < 1e-3:
|
| 869 |
+
dynamic_thr = 0.5
|
| 870 |
+
else:
|
| 871 |
+
# find elbow in sorted probabilities
|
| 872 |
+
sorted_probs = sorted(probs, reverse=True)
|
| 873 |
+
diffs = [sorted_probs[i] - sorted_probs[i+1] for i in range(len(sorted_probs)-1)]
|
| 874 |
+
idx = max(range(len(diffs)), key=lambda i: diffs[i])
|
| 875 |
+
# threshold is midpoint between the two
|
| 876 |
+
dynamic_thr = (sorted_probs[idx] + sorted_probs[idx+1]) / 2.0
|
| 877 |
+
|
| 878 |
+
print(f"\nDynamic threshold = {dynamic_thr:.4f}\n")
|
| 879 |
+
print("Extracted sentences:")
|
| 880 |
+
for s,p in zip(sentences, probs):
|
| 881 |
+
if p >= dynamic_thr:
|
| 882 |
+
print(f" β’ {p:.4f} β {s}")
|
| 883 |
+
print()
|
| 884 |
+
|
| 885 |
+
|
| 886 |
+
def batch_predict_and_save(
|
| 887 |
+
model,
|
| 888 |
+
tokenizer,
|
| 889 |
+
excel_path: str,
|
| 890 |
+
ckpt_path: str,
|
| 891 |
+
output_path: str,
|
| 892 |
+
n_samples: int = 40,
|
| 893 |
+
max_length: int = 128,
|
| 894 |
+
device: str = "cpu",
|
| 895 |
+
temperature: float = 1.0,
|
| 896 |
+
random_state: int = None
|
| 897 |
+
):
|
| 898 |
+
"""
|
| 899 |
+
1) Loads best checkpoint
|
| 900 |
+
2) Samples `n_samples` rows
|
| 901 |
+
3) For each transcript:
|
| 902 |
+
- tokenize into sentences
|
| 903 |
+
- compute p = sigmoid(logits/temperature)
|
| 904 |
+
- compute elbow threshold on sorted pβs
|
| 905 |
+
- extract all sentences with p >= elbow
|
| 906 |
+
- if none, pick the highest-p sentence
|
| 907 |
+
4) Save new Excel with columns:
|
| 908 |
+
- 'Claude_Call'
|
| 909 |
+
- 'Predicted Sel_K' (list of extracted sentences)
|
| 910 |
+
"""
|
| 911 |
+
# load model
|
| 912 |
+
model.load_state_dict(torch.load(ckpt_path, map_location=device))
|
| 913 |
+
model.to(device).eval()
|
| 914 |
+
|
| 915 |
+
# sample rows
|
| 916 |
+
df = pd.read_excel(excel_path)
|
| 917 |
+
sampled = df.sample(n=n_samples, random_state=random_state) \
|
| 918 |
+
if random_state is not None else df.sample(n=n_samples)
|
| 919 |
+
|
| 920 |
+
records = []
|
| 921 |
+
for _, row in tqdm(sampled.iterrows(),
|
| 922 |
+
total=len(sampled),
|
| 923 |
+
desc="Running Predictions"):
|
| 924 |
+
transcript = str(row['Claude_Call'])
|
| 925 |
+
sentences = sent_tokenize(transcript)
|
| 926 |
+
|
| 927 |
+
# compute probabilities
|
| 928 |
+
probs = []
|
| 929 |
+
for sent in sentences:
|
| 930 |
+
enc = tokenizer.encode_plus(
|
| 931 |
+
sent,
|
| 932 |
+
max_length=max_length,
|
| 933 |
+
padding='max_length',
|
| 934 |
+
truncation=True,
|
| 935 |
+
return_tensors='pt'
|
| 936 |
+
)
|
| 937 |
+
with torch.no_grad():
|
| 938 |
+
logits = model(enc['input_ids'].to(device),
|
| 939 |
+
enc['attention_mask'].to(device))
|
| 940 |
+
p = torch.sigmoid(logits / temperature).item()
|
| 941 |
+
probs.append(p)
|
| 942 |
+
|
| 943 |
+
# dynamic threshold via elbow detection
|
| 944 |
+
if len(probs) >= 2 and max(probs) - min(probs) > 1e-3:
|
| 945 |
+
sp = sorted(probs, reverse=True)
|
| 946 |
+
diffs = [sp[i] - sp[i+1] for i in range(len(sp)-1)]
|
| 947 |
+
idx = max(range(len(diffs)), key=lambda i: diffs[i])
|
| 948 |
+
thr = (sp[idx] + sp[idx+1]) / 2.0
|
| 949 |
+
else:
|
| 950 |
+
thr = 0.5 # fallback
|
| 951 |
+
|
| 952 |
+
# collect all above threshold, else top-1
|
| 953 |
+
extracted = [s for s,p in zip(sentences, probs) if p >= thr]
|
| 954 |
+
if not extracted and sentences:
|
| 955 |
+
best_idx = int(max(range(len(probs)), key=lambda i: probs[i]))
|
| 956 |
+
extracted = [sentences[best_idx]]
|
| 957 |
+
|
| 958 |
+
records.append({
|
| 959 |
+
'Claude_Call': transcript,
|
| 960 |
+
'Predicted Sel_K': extracted
|
| 961 |
+
})
|
| 962 |
+
|
| 963 |
+
# save
|
| 964 |
+
out_df = pd.DataFrame(records)
|
| 965 |
+
os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
|
| 966 |
+
out_df.to_excel(output_path, index=False)
|
| 967 |
+
print(f"β‘οΈ Saved {len(out_df)} rows to {output_path}")
|