|
|
import os |
|
|
import logging |
|
|
import random |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from nltk import sent_tokenize |
|
|
from sklearn.metrics import accuracy_score, precision_score, f1_score |
|
|
from sklearn.model_selection import train_test_split |
|
|
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler |
|
|
from transformers import AutoTokenizer, AutoModel, AutoConfig, get_linear_schedule_with_warmup |
|
|
from peft import PeftModel, LoraConfig, get_peft_model |
|
|
from datasets import load_dataset, DatasetDict, load_from_disk |
|
|
import spacy |
|
|
import re |
|
|
from tqdm.auto import tqdm |
|
|
from accelerate import Accelerator |
|
|
import matplotlib.pyplot as plt |
|
|
from torch.optim import AdamW |
|
|
import pandas as pd |
|
|
from typing import Optional, Tuple, List, Dict |
|
|
|
|
|
from models import GraphAugmentedNLIModel, GraphAugmentedFinNLIModel |
|
|
from preprocess_data import SpanExtractionChunkedDataset, process_data, chunk_transcript, span_collate_fn |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from config import MODEL_NAME, MAX_LENGTH, OVERLAP, PREPROCESSED_DIR, tokenizer, nlp |
|
|
|
|
|
|
|
|
BATCH_SIZE = 16 |
|
|
|
|
|
|
|
|
LEARNING_RATE = 2e-5 |
|
|
EPOCHS = 5 |
|
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
|
|
MIXED_PRECISION = "fp16" |
|
|
|
|
|
|
|
|
label_map = {0: "entailment", 1: "neutral", 2: "contradiction"} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") |
|
|
def set_seed(seed: int = 42): |
|
|
random.seed(seed) |
|
|
np.random.seed(seed) |
|
|
torch.manual_seed(seed) |
|
|
if torch.cuda.is_available(): |
|
|
torch.cuda.manual_seed_all(seed) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) |
|
|
nlp = spacy.load("en_core_web_sm") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_dependency_graph(sentence: str): |
|
|
doc = nlp(sentence) |
|
|
tokens = [token.text for token in doc] |
|
|
edges = [] |
|
|
for token in doc: |
|
|
if token.head.i != token.i: |
|
|
edges.append((token.i, token.head.i)) |
|
|
edges.append((token.head.i, token.i)) |
|
|
return tokens, edges |
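# Illustrative example (not executed): for a sentence like "Stocks fell sharply",
# en_core_web_sm typically parses "fell" as the root, giving
# tokens == ["Stocks", "fell", "sharply"] and bidirectional edges
# [(0, 1), (1, 0), (2, 1), (1, 2)]; the exact edge list depends on the parse.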
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def align_tokens(spacy_tokens, wp_tokens): |
|
|
node_indices = [] |
|
|
wp_idx = 1 |
|
|
for _ in spacy_tokens: |
|
|
if wp_idx >= len(wp_tokens) - 1: |
|
|
break |
|
|
node_indices.append(wp_idx) |
|
|
wp_idx += 1 |
|
|
while wp_idx < len(wp_tokens) - 1 and wp_tokens[wp_idx].startswith("##"): |
|
|
wp_idx += 1 |
|
|
return node_indices |
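# Note: align_tokens assumes BERT-style WordPiece output, where [CLS] sits at
# position 0 (hence wp_idx starts at 1) and continuation pieces carry a leading
# "##". For each spaCy token it records the index of that token's first
# wordpiece and then skips over its "##" continuations.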
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def my_collate_fn(batch): |
|
|
input_ids = [torch.tensor(ex["input_ids"], dtype=torch.long) for ex in batch] |
|
|
attention_mask = [torch.tensor(ex["attention_mask"], dtype=torch.long) for ex in batch] |
|
|
labels = [ex.get("labels", None) for ex in batch] |
|
|
|
|
|
premise_graph_tokens = [ex.get("premise_graph_tokens") for ex in batch] |
|
|
premise_graph_edges = [ex.get("premise_graph_edges") for ex in batch] |
|
|
premise_node_indices = [ex.get("premise_node_indices") for ex in batch] |
|
|
|
|
|
hypothesis_graph_tokens = [ex.get("hypothesis_graph_tokens") for ex in batch] |
|
|
hypothesis_graph_edges = [ex.get("hypothesis_graph_edges") for ex in batch] |
|
|
hypothesis_node_indices = [ex.get("hypothesis_node_indices") for ex in batch] |
|
|
|
|
|
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id) |
|
|
attention_mask = torch.nn.utils.rnn.pad_sequence(attention_mask, batch_first=True, padding_value=0) |
|
|
labels = torch.tensor(labels, dtype=torch.long) if labels and labels[0] is not None else None |
|
|
|
|
|
return { |
|
|
"input_ids": input_ids, |
|
|
"attention_mask": attention_mask, |
|
|
"labels": labels, |
|
|
"premise_graph_tokens": premise_graph_tokens, |
|
|
"premise_graph_edges": premise_graph_edges, |
|
|
"premise_node_indices": premise_node_indices, |
|
|
"hypothesis_graph_tokens": hypothesis_graph_tokens, |
|
|
"hypothesis_graph_edges": hypothesis_graph_edges, |
|
|
"hypothesis_node_indices": hypothesis_node_indices, |
|
|
} |
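# The six graph fields are deliberately left as ragged Python lists (one entry
# per example) rather than padded tensors; the model builds each example's
# dependency graph on the fly. Only input_ids, attention_mask, and labels
# (when present) are stacked into batch tensors.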
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_model(epochs: int = EPOCHS, |
|
|
batch_size: int = BATCH_SIZE, |
|
|
lr: float = LEARNING_RATE, |
|
|
save_model: bool = False, |
|
|
save_path: str = 'gnn_model_weights_3.pt'): |
|
|
set_seed() |
|
|
process_data() |
|
|
logging.info("Loading preprocessed dataset...") |
|
|
snli = load_from_disk(PREPROCESSED_DIR) |
|
|
snli.set_format("python", output_all_columns=True) |
|
|
|
|
|
train_loader = DataLoader(snli["train"], batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn) |
|
|
val_loader = DataLoader(snli["validation"], batch_size=batch_size, collate_fn=my_collate_fn) |
|
|
|
|
|
model = GraphAugmentedNLIModel(MODEL_NAME).to(DEVICE) |
|
|
|
|
|
if hasattr(model.bert, 'gradient_checkpointing_enable'): |
|
|
model.bert.gradient_checkpointing_enable() |
|
|
logging.info("Enabled gradient checkpointing on BERT.") |
|
|
|
|
|
optimizer = torch.optim.AdamW(model.parameters(), lr=lr) |
|
|
num_training_steps = epochs * len(train_loader) |
|
|
lr_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=num_training_steps) |
|
|
|
|
|
accelerator = Accelerator(mixed_precision=MIXED_PRECISION) |
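    # accelerator.prepare() moves the model and optimizer to the right device,
    # wraps the dataloaders so batches arrive on-device, and handles fp16
    # autocast/loss scaling. The explicit .to(DEVICE) calls in the loops below
    # are therefore redundant on a single GPU, but harmless.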
|
|
model, optimizer, train_loader, val_loader, lr_scheduler = accelerator.prepare( |
|
|
model, optimizer, train_loader, val_loader, lr_scheduler |
|
|
) |
|
|
|
|
|
model.train() |
|
|
all_losses = [] |
|
|
epoch_losses = [] |
|
|
best_val_loss = float('inf') |
|
|
best_epoch = 0 |
|
|
|
|
|
for epoch in range(1, epochs + 1): |
|
|
epoch_loss = [] |
|
|
progress = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", leave=False) |
|
|
for batch in progress: |
|
|
labels = batch["labels"].to(DEVICE) if batch.get("labels") is not None else None |
|
|
outputs = model( |
|
|
input_ids=batch["input_ids"].to(DEVICE), |
|
|
attention_mask=batch["attention_mask"].to(DEVICE), |
|
|
premise_graph_tokens=batch["premise_graph_tokens"], |
|
|
premise_graph_edges=batch["premise_graph_edges"], |
|
|
premise_node_indices=batch["premise_node_indices"], |
|
|
hypothesis_graph_tokens=batch["hypothesis_graph_tokens"], |
|
|
hypothesis_graph_edges=batch["hypothesis_graph_edges"], |
|
|
hypothesis_node_indices=batch["hypothesis_node_indices"], |
|
|
labels=labels |
|
|
) |
|
|
loss = outputs.get("loss") if isinstance(outputs, dict) else outputs |
|
|
|
|
|
optimizer.zero_grad() |
|
|
accelerator.backward(loss) |
|
|
optimizer.step() |
|
|
lr_scheduler.step() |
|
|
|
|
|
loss_val = loss.item() |
|
|
epoch_loss.append(loss_val) |
|
|
all_losses.append(loss_val) |
|
|
progress.set_postfix({"loss": f"{loss_val:.4f}"}) |
|
|
|
|
|
avg_epoch_loss = np.mean(epoch_loss) |
|
|
epoch_losses.append(avg_epoch_loss) |
|
|
logging.info(f"Epoch {epoch} completed. Avg Loss: {avg_epoch_loss:.4f}") |
|
|
|
|
|
|
|
|
model.eval() |
|
|
val_losses = [] |
|
|
with torch.no_grad(): |
|
|
for batch in val_loader: |
|
|
labels = batch["labels"].to(DEVICE) if batch.get("labels") is not None else None |
|
|
outputs = model( |
|
|
input_ids=batch["input_ids"].to(DEVICE), |
|
|
attention_mask=batch["attention_mask"].to(DEVICE), |
|
|
premise_graph_tokens=batch["premise_graph_tokens"], |
|
|
premise_graph_edges=batch["premise_graph_edges"], |
|
|
premise_node_indices=batch["premise_node_indices"], |
|
|
hypothesis_graph_tokens=batch["hypothesis_graph_tokens"], |
|
|
hypothesis_graph_edges=batch["hypothesis_graph_edges"], |
|
|
hypothesis_node_indices=batch["hypothesis_node_indices"], |
|
|
labels=labels |
|
|
) |
|
|
loss_item = outputs.get("loss").item() if isinstance(outputs, dict) else outputs.item() |
|
|
val_losses.append(loss_item) |
|
|
avg_val_loss = np.mean(val_losses) if val_losses else float('inf') |
|
|
logging.info(f"Validation Loss after Epoch {epoch}: {avg_val_loss:.4f}") |
|
|
|
|
|
if avg_val_loss < best_val_loss: |
|
|
best_val_loss = avg_val_loss |
|
|
best_epoch = epoch |
|
|
if save_model: |
|
|
logging.info(f"Saving best model at epoch {epoch} with val loss {avg_val_loss:.4f}") |
|
|
                torch.save(accelerator.unwrap_model(model).state_dict(), save_path)  # unwrap so the weights load cleanly outside Accelerate
|
|
model.train() |
|
|
|
|
|
|
|
|
plt.figure() |
|
|
plt.plot(all_losses) |
|
|
plt.xlabel('Training steps') |
|
|
plt.ylabel('Loss') |
|
|
plt.title('Step-wise Training Loss') |
|
|
plt.show() |
|
|
|
|
|
plt.figure() |
|
|
plt.plot(range(1, epochs+1), epoch_losses, marker='o') |
|
|
plt.xlabel('Epochs') |
|
|
plt.ylabel('Loss') |
|
|
plt.title('Epoch-wise Training Loss') |
|
|
plt.show() |
|
|
|
|
|
logging.info(f"Training complete. Best validation loss {best_val_loss:.4f} at epoch {best_epoch}.") |
|
|
return model |
|
|
|
|
|
|
|
|
def predict_nli(premise, hypothesis, tokenizer=tokenizer, model_path='gnn_model_checkpoint.pt'): |
|
|
|
|
|
model = GraphAugmentedNLIModel(MODEL_NAME).to(DEVICE) |
|
|
|
|
|
|
|
|
ckpt = torch.load(model_path, map_location=DEVICE) |
|
|
model.load_state_dict(ckpt["model_state_dict"]) |
|
|
|
|
|
model.eval() |
|
|
|
|
|
|
|
|
encoded = tokenizer( |
|
|
premise, hypothesis, |
|
|
truncation=True, |
|
|
padding="max_length", |
|
|
max_length=MAX_LENGTH, |
|
|
return_tensors="pt" |
|
|
) |
|
|
|
|
|
input_ids = encoded["input_ids"] |
|
|
attention_mask = encoded["attention_mask"] |
|
|
|
|
|
|
|
|
p_tokens, p_edges = build_dependency_graph(premise) |
|
|
h_tokens, h_edges = build_dependency_graph(hypothesis) |
|
|
|
|
|
|
|
|
wp_tokens = tokenizer.convert_ids_to_tokens(input_ids[0]) |
|
|
|
|
|
p_node_indices = align_tokens(p_tokens, wp_tokens) |
|
|
h_node_indices = align_tokens(h_tokens, wp_tokens) |
|
|
|
|
|
|
|
|
device = next(model.parameters()).device |
|
|
input_ids = input_ids.to(device) |
|
|
attention_mask = attention_mask.to(device) |
|
|
|
|
|
|
|
|
|
|
|
premise_graph_tokens = [p_tokens] |
|
|
premise_graph_edges = [p_edges] |
|
|
premise_node_indices = [p_node_indices] |
|
|
|
|
|
hypothesis_graph_tokens = [h_tokens] |
|
|
hypothesis_graph_edges = [h_edges] |
|
|
hypothesis_node_indices = [h_node_indices] |
|
|
|
|
|
with torch.no_grad(): |
|
|
outputs = model( |
|
|
input_ids=input_ids, |
|
|
attention_mask=attention_mask, |
|
|
premise_graph_tokens=premise_graph_tokens, |
|
|
premise_graph_edges=premise_graph_edges, |
|
|
premise_node_indices=premise_node_indices, |
|
|
hypothesis_graph_tokens=hypothesis_graph_tokens, |
|
|
hypothesis_graph_edges=hypothesis_graph_edges, |
|
|
hypothesis_node_indices=hypothesis_node_indices |
|
|
) |
|
|
|
|
|
logits = outputs["logits"] |
|
|
probs = F.softmax(logits, dim=-1).cpu().numpy()[0] |
|
|
|
|
|
predicted_label_id = torch.argmax(logits, dim=-1).item() |
|
|
predicted_label = label_map[predicted_label_id] |
|
|
    prob_map = {cls_label: probs[i] for i, cls_label in label_map.items()}
|
|
return predicted_label, prob_map |
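# Example usage (hypothetical sentences; assumes a checkpoint saved by
# train_model_with_chkpt exists at the default 'gnn_model_checkpoint.pt'):
#
#     label, probs = predict_nli(
#         "The company reported record quarterly revenue.",
#         "Revenue declined this quarter.",
#     )
#     print(label, probs)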
|
|
|
|
|
|
|
|
def predict_fin_nli( |
|
|
premise: str, |
|
|
hypothesis: str, |
|
|
tokenizer=tokenizer, |
|
|
model_path: str = 'gnn_model_checkpoint.pt', |
|
|
adapter_dir: str = './lora_finance_adapter', |
|
|
) -> Tuple[str, List[float]]:
|
|
|
|
|
base_model = GraphAugmentedFinNLIModel(MODEL_NAME).to(DEVICE) |
|
|
ckpt = torch.load(model_path, map_location=DEVICE) |
|
|
base_model.load_state_dict(ckpt['model_state_dict']) |
|
|
|
|
|
|
|
|
lora_cfg = LoraConfig( |
|
|
r=8, |
|
|
lora_alpha=32, |
|
|
lora_dropout=0.1, |
|
|
bias='none', |
|
|
task_type='SEQ_CLS', |
|
|
target_modules=['query', 'value'] |
|
|
) |
|
|
model = get_peft_model(base_model, lora_cfg).to(DEVICE) |
|
|
|
|
|
|
|
|
adapter_ckpt = torch.load(os.path.join(adapter_dir, 'training_checkpoint.pt'), map_location=DEVICE) |
|
|
|
|
|
model.load_state_dict(adapter_ckpt['model_state_dict'], strict=False) |
|
|
model.eval() |
|
|
|
|
|
|
|
|
enc = tokenizer( |
|
|
premise, hypothesis, |
|
|
truncation=True, |
|
|
padding='max_length', |
|
|
max_length=MAX_LENGTH, |
|
|
return_tensors='pt' |
|
|
) |
|
|
input_ids = enc['input_ids'].to(DEVICE) |
|
|
attention_mask = enc['attention_mask'].to(DEVICE) |
|
|
|
|
|
|
|
|
p_toks, p_edges = build_dependency_graph(premise) |
|
|
h_toks, h_edges = build_dependency_graph(hypothesis) |
|
|
wp = tokenizer.convert_ids_to_tokens(input_ids[0]) |
|
|
p_idx = align_tokens(p_toks, wp) |
|
|
h_idx = align_tokens(h_toks, wp) |
|
|
|
|
|
premise_graph_tokens = [p_toks] |
|
|
premise_graph_edges = [p_edges] |
|
|
premise_node_indices = [p_idx] |
|
|
hypothesis_graph_tokens = [h_toks] |
|
|
hypothesis_graph_edges = [h_edges] |
|
|
hypothesis_node_indices = [h_idx] |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
out = model( |
|
|
input_ids=input_ids, |
|
|
attention_mask=attention_mask, |
|
|
premise_graph_tokens=premise_graph_tokens, |
|
|
premise_graph_edges=premise_graph_edges, |
|
|
premise_node_indices=premise_node_indices, |
|
|
hypothesis_graph_tokens=hypothesis_graph_tokens, |
|
|
hypothesis_graph_edges=hypothesis_graph_edges, |
|
|
hypothesis_node_indices=hypothesis_node_indices |
|
|
) |
|
|
|
|
|
logits = out['logits'][0] |
|
|
probs = torch.softmax(logits, dim=-1).cpu().numpy() |
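    # Collapse the three-way NLI head to a binary decision: drop the neutral
    # probability and renormalize entailment vs. contradiction so the returned
    # scores sum to 1 (the 1e-12 term guards against division by zero).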
|
|
|
|
|
|
|
|
entail, neutral, contra = probs |
|
|
s = entail + contra + 1e-12 |
|
|
scores = [entail / s, contra / s] |
|
|
label = 'entailment' if entail >= contra else 'contradiction' |
|
|
return label, scores |
|
|
|
|
|
|
|
|
def train_model_with_chkpt(epochs: int = 5, |
|
|
batch_size: int = 16, |
|
|
lr: float = 2e-5, |
|
|
save_model: bool = False, |
|
|
save_path: str = 'gnn_model_checkpoint.pt', |
|
|
resume: bool = False): |
|
|
""" |
|
|
Train with mixed precision, gradient checkpointing, and resume support. |
|
|
If resume=True and save_path exists, picks up from last epoch. |
|
|
""" |
|
|
set_seed() |
|
|
process_data() |
|
|
logging.info("Loading preprocessed dataset…") |
|
|
snli = load_from_disk(PREPROCESSED_DIR) |
|
|
snli.set_format("python", output_all_columns=True) |
|
|
|
|
|
train_loader = DataLoader(snli["train"], batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn) |
|
|
val_loader = DataLoader(snli["validation"], batch_size=batch_size, collate_fn=my_collate_fn) |
|
|
|
|
|
model = GraphAugmentedNLIModel(MODEL_NAME).to(DEVICE) |
|
|
optimizer = torch.optim.AdamW(model.parameters(), lr=lr) |
|
|
total_steps = epochs * len(train_loader) |
|
|
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=total_steps) |
|
|
|
|
|
|
|
|
start_epoch = 1 |
|
|
if resume and os.path.isfile(save_path): |
|
|
ckpt = torch.load(save_path, map_location=DEVICE) |
|
|
model.load_state_dict(ckpt["model_state_dict"]) |
|
|
optimizer.load_state_dict(ckpt["optimizer_state_dict"]) |
|
|
scheduler.load_state_dict(ckpt["scheduler_state_dict"]) |
|
|
start_epoch = ckpt.get("epoch", 1) + 1 |
|
|
logging.info(f"Resuming from epoch {start_epoch}") |
|
|
|
|
|
|
|
|
if hasattr(model.bert, "gradient_checkpointing_enable"): |
|
|
model.bert.gradient_checkpointing_enable() |
|
|
logging.info("Enabled gradient checkpointing on BERT.") |
|
|
accelerator = Accelerator(mixed_precision=MIXED_PRECISION) |
|
|
model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare( |
|
|
model, optimizer, train_loader, val_loader, scheduler |
|
|
) |
|
|
|
|
|
best_val_loss = float("inf") |
|
|
for epoch in range(start_epoch, epochs + 1): |
|
|
model.train() |
|
|
train_losses = [] |
|
|
for batch in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}"): |
|
|
optimizer.zero_grad() |
|
|
outputs = model( |
|
|
input_ids=batch["input_ids"].to(DEVICE), |
|
|
attention_mask=batch["attention_mask"].to(DEVICE), |
|
|
premise_graph_tokens=batch["premise_graph_tokens"], |
|
|
premise_graph_edges=batch["premise_graph_edges"], |
|
|
premise_node_indices=batch["premise_node_indices"], |
|
|
hypothesis_graph_tokens=batch["hypothesis_graph_tokens"], |
|
|
hypothesis_graph_edges=batch["hypothesis_graph_edges"], |
|
|
hypothesis_node_indices=batch["hypothesis_node_indices"], |
|
|
                labels=batch["labels"].to(DEVICE) if batch.get("labels") is not None else None
|
|
) |
|
|
loss = outputs["loss"] if isinstance(outputs, dict) else outputs |
|
|
accelerator.backward(loss) |
|
|
optimizer.step() |
|
|
scheduler.step() |
|
|
train_losses.append(loss.item()) |
|
|
avg_train = np.mean(train_losses) |
|
|
logging.info(f"Epoch {epoch} train loss: {avg_train:.4f}") |
|
|
|
|
|
|
|
|
model.eval() |
|
|
val_losses = [] |
|
|
with torch.no_grad(): |
|
|
for batch in val_loader: |
|
|
outputs = model( |
|
|
input_ids=batch["input_ids"].to(DEVICE), |
|
|
attention_mask=batch["attention_mask"].to(DEVICE), |
|
|
premise_graph_tokens=batch["premise_graph_tokens"], |
|
|
premise_graph_edges=batch["premise_graph_edges"], |
|
|
premise_node_indices=batch["premise_node_indices"], |
|
|
hypothesis_graph_tokens=batch["hypothesis_graph_tokens"], |
|
|
hypothesis_graph_edges=batch["hypothesis_graph_edges"], |
|
|
hypothesis_node_indices=batch["hypothesis_node_indices"], |
|
|
                    labels=batch["labels"].to(DEVICE) if batch.get("labels") is not None else None
|
|
) |
|
|
v_loss = outputs["loss"].item() if isinstance(outputs, dict) else outputs.item() |
|
|
val_losses.append(v_loss) |
|
|
avg_val = np.mean(val_losses) if val_losses else float("inf") |
|
|
logging.info(f"Epoch {epoch} val loss: {avg_val:.4f}") |
|
|
|
|
|
|
|
|
ckpt = { |
|
|
"epoch": epoch, |
|
|
"model_state_dict": model.state_dict(), |
|
|
"optimizer_state_dict": optimizer.state_dict(), |
|
|
"scheduler_state_dict": scheduler.state_dict(), |
|
|
} |
|
|
torch.save(ckpt, save_path) |
|
|
logging.info(f"Saved checkpoint: {save_path}") |
|
|
|
|
|
if avg_val < best_val_loss: |
|
|
best_val_loss = avg_val |
|
|
|
|
|
logging.info(f"Training complete. Best val loss: {best_val_loss:.4f}") |
|
|
return model |
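# Example (hypothetical): resume an interrupted run from the saved checkpoint.
#
#     model = train_model_with_chkpt(epochs=5, resume=True,
#                                    save_path='gnn_model_checkpoint.pt')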
|
|
|
|
|
|
|
|
def extract_sentences_by_intent( |
|
|
text: str, |
|
|
intent: str, |
|
|
adapter_dir: str = "./lora_finance_adapter", |
|
|
threshold: float = 0.7, |
|
|
    top_k: Optional[int] = None,
|
|
min_words: int = 4, |
|
|
    convo_focus: Optional[str] = None
|
|
): |
|
|
""" |
|
|
Splits `text` into sentences, embeds them (and the `intent`) under your |
|
|
LoRA‐adapted BERT, and returns those whose cosine similarity ≥ `threshold`. |
|
|
Loads the adapter from the single `training_checkpoint.pt` in `adapter_dir`. |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
if convo_focus is None: |
|
|
sentences = [sent.text.strip() for sent in nlp(text).sents if sent.text.strip()] |
|
|
|
|
|
elif convo_focus == "customer": |
|
|
customer_lines = [ |
|
|
line.strip() |
|
|
for line in text.splitlines() |
|
|
if line.strip().lower().startswith("customer:") |
|
|
] |
|
|
|
|
|
|
|
|
sentences = [] |
|
|
for cust_line in customer_lines: |
|
|
for sent in nlp(cust_line).sents: |
|
|
s = sent.text.strip() |
|
|
                if s and len(s.split(' ')) > 6:
|
|
sentences.append(s) |
|
|
|
|
|
    else:
        agent_lines = [
            line.strip()
            for line in text.splitlines()
            if line.strip().lower().startswith("agent:")
        ]

        sentences = []
        for agent_line in agent_lines:
            for sent in nlp(agent_line).sents:
                s = sent.text.strip()
                if s and len(s.split(' ')) > 6:
                    sentences.append(s)
|
|
|
|
|
|
|
|
base_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE) |
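    # Wrap the plain encoder with a LoRA adapter so the low-rank weights stored
    # in training_checkpoint.pt can be loaded on top of it; the r/alpha/dropout
    # values below are assumed to match those used when the adapter was trained.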
|
|
lora_cfg = LoraConfig( |
|
|
r=8, |
|
|
lora_alpha=32, |
|
|
lora_dropout=0.1, |
|
|
bias="none", |
|
|
task_type="CAUSAL_LM", |
|
|
) |
|
|
model = get_peft_model(base_model, lora_cfg).to(DEVICE) |
|
|
|
|
|
|
|
|
chkpt_path = os.path.join(adapter_dir, "training_checkpoint.pt") |
|
|
if not os.path.isfile(chkpt_path): |
|
|
raise FileNotFoundError(f"No LoRA checkpoint at {chkpt_path}") |
|
|
ckpt = torch.load(chkpt_path, map_location=DEVICE) |
|
|
|
|
|
model.load_state_dict(ckpt["model_state_dict"], strict=False) |
|
|
model.eval() |
|
|
|
|
|
|
|
|
def embed(text_str): |
|
|
toks = tokenizer( |
|
|
text_str, |
|
|
truncation=True, |
|
|
padding="longest", |
|
|
return_tensors="pt" |
|
|
).to(DEVICE) |
|
|
|
|
|
em_args = { |
|
|
"input_ids": toks["input_ids"], |
|
|
"attention_mask": toks["attention_mask"], |
|
|
} |
|
|
if "token_type_ids" in toks: |
|
|
em_args["token_type_ids"] = toks["token_type_ids"] |
|
|
|
|
|
|
|
|
hf_model = getattr(model, "base_model", model) |
|
|
with torch.no_grad(): |
|
|
            last_hidden = hf_model(**em_args).last_hidden_state
|
|
return last_hidden[:, 0, :] |
|
|
|
|
|
|
|
|
intent_emb = embed(intent) |
|
|
|
|
|
results = [] |
|
|
with torch.no_grad(): |
|
|
for sent in sentences: |
|
|
clean = re.sub(r'^(Agent|Customer):\s*', "", sent) |
|
|
if len(clean.split()) < min_words: |
|
|
continue |
|
|
|
|
|
sent_emb = embed(clean) |
|
|
sim = F.cosine_similarity(sent_emb, intent_emb, dim=1).item() |
|
|
if sim >= threshold: |
|
|
results.append((clean, sim)) |
|
|
|
|
|
|
|
|
results.sort(key=lambda x: x[1], reverse=True) |
|
|
return results[:top_k] if top_k else results |
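# Example usage (hypothetical; `transcript_text` is a variable holding the call
# text, and the LoRA checkpoint is assumed to exist under ./lora_finance_adapter):
#
#     hits = extract_sentences_by_intent(
#         transcript_text,
#         intent="customer wants to dispute a transaction",
#         convo_focus="customer",
#         threshold=0.7,
#         top_k=5,
#     )
#     for sentence, score in hits:
#         print(f"{score:.3f}  {sentence}")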
|
|
|
|
|
|
|
|
def train_sentence_extractor( |
|
|
model: nn.Module, |
|
|
dataset: torch.utils.data.Dataset, |
|
|
output_dir: str, |
|
|
val_split: float = 0.2, |
|
|
epochs: int = 3, |
|
|
batch_size: int = 16, |
|
|
lr: float = 2e-5, |
|
|
device: str = "cpu", |
|
|
unfreeze_after_epoch: int = 1, |
|
|
threshold: float = 0.5 |
|
|
): |
|
|
""" |
|
|
Fine-tune `model` on `dataset`, hold out `val_split` for val, |
|
|
compute loss + acc + precision + F1 each epoch, save best checkpoint, |
|
|
and plot all four metrics at the end. |
|
|
""" |
|
|
|
|
|
total = len(dataset) |
|
|
val_n = int(total * val_split) |
|
|
train_n = total - val_n |
|
|
train_ds, val_ds = random_split(dataset, [train_n, val_n]) |
|
|
|
|
|
|
|
|
train_labels = [train_ds[i]['label'].item() for i in range(len(train_ds))] |
|
|
counts = torch.bincount(torch.tensor(train_labels, dtype=torch.long)) |
|
|
weights = (1.0 / counts.float()).tolist() |
|
|
sample_weights = [weights[int(l)] for l in train_labels] |
|
|
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True) |
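    # Inverse-frequency sampling: rarer classes receive proportionally larger
    # weights, so each batch sees a roughly balanced label distribution despite
    # class imbalance in the raw data.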
|
|
|
|
|
train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler, drop_last=True) |
|
|
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False) |
|
|
|
|
|
model.to(device) |
|
|
|
|
|
    for p in model.bert.parameters():
        p.requires_grad = False
|
|
|
|
|
optimizer = AdamW(model.parameters(), lr=lr) |
|
|
total_steps = epochs * len(train_loader) |
|
|
scheduler = get_linear_schedule_with_warmup( |
|
|
optimizer, |
|
|
num_warmup_steps=int(0.1 * total_steps), |
|
|
num_training_steps=total_steps |
|
|
) |
|
|
criterion = nn.BCEWithLogitsLoss() |
|
|
|
|
|
|
|
|
train_losses, val_losses = [], [] |
|
|
train_accs, val_accs = [], [] |
|
|
train_precs, val_precs = [], [] |
|
|
train_f1s, val_f1s = [], [] |
|
|
|
|
|
best_val_loss = float('inf') |
|
|
|
|
|
for epoch in range(1, epochs+1): |
|
|
|
|
|
model.train() |
|
|
epoch_loss = 0.0 |
|
|
preds, labels = [], [] |
|
|
for batch in tqdm(train_loader, desc=f"Train {epoch}/{epochs}"): |
|
|
inputs = batch['input_ids'].to(device) |
|
|
masks = batch['attention_mask'].to(device) |
|
|
labs = batch['label'].to(device) |
|
|
|
|
|
optimizer.zero_grad() |
|
|
logits = model(inputs, masks) |
|
|
loss = criterion(logits, labs) |
|
|
loss.backward() |
|
|
optimizer.step() |
|
|
scheduler.step() |
|
|
|
|
|
epoch_loss += loss.item() |
|
|
|
|
|
probs = torch.sigmoid(logits) |
|
|
batch_preds = (probs >= threshold).long() |
|
|
preds.extend(batch_preds.cpu().tolist()) |
|
|
labels.extend(labs.cpu().long().tolist()) |
|
|
|
|
|
avg_train = epoch_loss / len(train_loader) |
|
|
train_losses.append(avg_train) |
|
|
        train_accs.append(accuracy_score(labels, preds))
        train_precs.append(precision_score(labels, preds, zero_division=0))
        train_f1s.append(f1_score(labels, preds, zero_division=0))
|
|
print(f"→ Epoch {epoch} Train — loss {avg_train:.4f}, acc {train_accs[-1]:.4f}, prec {train_precs[-1]:.4f}, f1 {train_f1s[-1]:.4f}") |
|
|
|
|
|
|
|
|
if epoch == unfreeze_after_epoch: |
|
|
            for p in model.bert.parameters():
                p.requires_grad = True
|
|
optimizer = AdamW([ |
|
|
{"params": model.classifier.parameters(), "lr": 1e-3}, |
|
|
{"params": model.bert.parameters(), "lr": 1e-5}, |
|
|
], weight_decay=1e-2) |
|
|
scheduler = get_linear_schedule_with_warmup( |
|
|
optimizer, |
|
|
num_warmup_steps=int(0.1 * total_steps), |
|
|
num_training_steps=total_steps |
|
|
) |
|
|
|
|
|
|
|
|
model.eval() |
|
|
epoch_loss = 0.0 |
|
|
preds, labels = [], [] |
|
|
with torch.no_grad(): |
|
|
for batch in tqdm(val_loader, desc=f" Val {epoch}/{epochs}"): |
|
|
inputs = batch['input_ids'].to(device) |
|
|
masks = batch['attention_mask'].to(device) |
|
|
labs = batch['label'].to(device) |
|
|
|
|
|
logits = model(inputs, masks) |
|
|
loss = criterion(logits, labs) |
|
|
epoch_loss += loss.item() |
|
|
|
|
|
probs = torch.sigmoid(logits) |
|
|
batch_preds = (probs >= threshold).long() |
|
|
preds.extend(batch_preds.cpu().tolist()) |
|
|
labels.extend(labs.cpu().long().tolist()) |
|
|
|
|
|
avg_val = epoch_loss / len(val_loader) |
|
|
val_losses.append(avg_val) |
|
|
        val_accs.append(accuracy_score(labels, preds))
        val_precs.append(precision_score(labels, preds, zero_division=0))
        val_f1s.append(f1_score(labels, preds, zero_division=0))
|
|
print(f"→ Epoch {epoch} Val — loss {avg_val:.4f}, acc {val_accs[-1]:.4f}, prec {val_precs[-1]:.4f}, f1 {val_f1s[-1]:.4f}") |
|
|
|
|
|
|
|
|
os.makedirs(output_dir, exist_ok=True) |
|
|
ckpt = os.path.join(output_dir, f"epo{epoch}_val{avg_val:.4f}.pth") |
|
|
torch.save(model.state_dict(), ckpt) |
|
|
if avg_val < best_val_loss: |
|
|
best_val_loss = avg_val |
|
|
torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pth")) |
|
|
print(f"🎉 New best model saved (val loss {best_val_loss:.4f})") |
|
|
|
|
|
print(f"✔️ Training complete — best val loss: {best_val_loss:.4f}") |
|
|
|
|
|
|
|
|
epochs = list(range(1, epochs+1)) |
|
|
|
|
|
save_metric_plot( |
|
|
epochs, |
|
|
train_losses, |
|
|
val_losses, |
|
|
metric_name="Loss", |
|
|
output_path="results/Loss_Plot.png" |
|
|
) |
|
|
|
|
|
save_metric_plot( |
|
|
epochs, |
|
|
train_accs, |
|
|
val_accs, |
|
|
metric_name="Accuracy", |
|
|
output_path="results/Accuracy_Plot.png", |
|
|
threshold=0.5 |
|
|
) |
|
|
|
|
|
save_metric_plot( |
|
|
epochs, |
|
|
train_precs, |
|
|
val_precs, |
|
|
metric_name="Precision", |
|
|
output_path="results/Precision_Plot.png", |
|
|
threshold=0.5 |
|
|
) |
|
|
|
|
|
save_metric_plot( |
|
|
epochs, |
|
|
train_f1s, |
|
|
val_f1s, |
|
|
metric_name="F1 Score", |
|
|
output_path="results/F1Score_Plot.png", |
|
|
threshold=0.5 |
|
|
) |
|
|
|
|
|
|
|
|
def save_metric_plot( |
|
|
epochs, |
|
|
train_vals, |
|
|
val_vals, |
|
|
metric_name: str, |
|
|
output_path: str, |
|
|
    threshold: Optional[float] = None
|
|
): |
|
|
""" |
|
|
epochs – list of epoch indices |
|
|
train_vals – list of train metric values |
|
|
val_vals – list of validation metric values |
|
|
metric_name – e.g. "Loss", "Accuracy", "Precision", "F1 Score" |
|
|
output_path – where to save the PNG |
|
|
threshold – optional horizontal line to draw, e.g. 0.5 |
|
|
""" |
|
|
fig, ax = plt.subplots(figsize=(8, 5)) |
|
|
ax.plot(epochs, train_vals, marker='o', linewidth=2, label=f'Train {metric_name}') |
|
|
ax.plot(epochs, val_vals, marker='s', linewidth=2, label=f'Val {metric_name}') |
|
|
|
|
|
if threshold is not None: |
|
|
ax.axhline(threshold, color='gray', linestyle='--', linewidth=1, label=f'Threshold = {threshold}') |
|
|
|
|
|
ax.set_title(f'{metric_name} over Epochs', fontsize=14, pad=10) |
|
|
ax.set_xlabel('Epoch', fontsize=12) |
|
|
ax.set_ylabel(metric_name, fontsize=12) |
|
|
ax.grid(True, linestyle='--', alpha=0.4) |
|
|
ax.legend(loc='best', frameon=True, fontsize=10) |
|
|
fig.tight_layout() |
|
|
fig.savefig(output_path, dpi=300) |
|
|
plt.close(fig) |
|
|
|
|
|
|
|
|
def demo_on_random_val( |
|
|
model, |
|
|
tokenizer, |
|
|
excel_path: str, |
|
|
ckpt_path: str, |
|
|
max_length: int = 128, |
|
|
device: str = "cpu", |
|
|
temperature: float = 1.0 |
|
|
): |
|
|
""" |
|
|
    Scores each sentence of one randomly drawn validation transcript; instead of a fixed threshold:
|
|
1) Compute sigmoid(logits / temperature) for each sentence |
|
|
2) Sort probabilities descending |
|
|
3) Find the largest gap between adjacent probs |
|
|
4) Set dynamic_threshold = midpoint of that gap |
|
|
5) Extract all sentences with prob >= dynamic_threshold |
|
|
""" |
|
|
|
|
|
model.load_state_dict(torch.load(ckpt_path, map_location=device)) |
|
|
model.to(device).eval() |
|
|
|
|
|
|
|
|
df = pd.read_excel(excel_path) |
|
|
_, val_df = train_test_split(df, test_size=0.2, random_state=42) |
|
|
row = val_df.sample(n=1, random_state=random.randint(0,999)).iloc[0] |
|
|
transcript = str(row['Claude_Call']) |
|
|
print(f"\n── Transcript (val sample idx={row['idx']}):\n{transcript}\n") |
|
|
|
|
|
|
|
|
sentences, probs = [], [] |
|
|
for sent in sent_tokenize(transcript): |
|
|
enc = tokenizer.encode_plus( |
|
|
sent, |
|
|
max_length=max_length, |
|
|
padding='max_length', |
|
|
truncation=True, |
|
|
return_tensors='pt' |
|
|
) |
|
|
        with torch.no_grad():
            logits = model(enc['input_ids'].to(device),
                           enc['attention_mask'].to(device))
        prob = torch.sigmoid(logits / temperature).item()
|
|
sentences.append(sent) |
|
|
probs.append(prob) |
|
|
|
|
|
|
|
|
print("Sentence probabilities:") |
|
|
for s,p in zip(sentences, probs): |
|
|
print(f" → {p:.4f} → {s}") |
|
|
|
|
|
|
|
|
if len(probs) < 2 or max(probs) - min(probs) < 1e-3: |
|
|
dynamic_thr = 0.5 |
|
|
else: |
|
|
|
|
|
sorted_probs = sorted(probs, reverse=True) |
|
|
diffs = [sorted_probs[i] - sorted_probs[i+1] for i in range(len(sorted_probs)-1)] |
|
|
idx = max(range(len(diffs)), key=lambda i: diffs[i]) |
|
|
|
|
|
dynamic_thr = (sorted_probs[idx] + sorted_probs[idx+1]) / 2.0 |
|
|
|
|
|
print(f"\nDynamic threshold = {dynamic_thr:.4f}\n") |
|
|
print("Extracted sentences:") |
|
|
for s,p in zip(sentences, probs): |
|
|
if p >= dynamic_thr: |
|
|
print(f" • {p:.4f} → {s}") |
|
|
print() |
|
|
|
|
|
|
|
|
def batch_predict_and_save( |
|
|
model, |
|
|
tokenizer, |
|
|
excel_path: str, |
|
|
ckpt_path: str, |
|
|
output_path: str, |
|
|
n_samples: int = 40, |
|
|
max_length: int = 128, |
|
|
device: str = "cpu", |
|
|
temperature: float = 1.0, |
|
|
    random_state: Optional[int] = None
|
|
): |
|
|
""" |
|
|
1) Loads best checkpoint |
|
|
2) Samples `n_samples` rows |
|
|
3) For each transcript: |
|
|
- tokenize into sentences |
|
|
- compute p = sigmoid(logits/temperature) |
|
|
- compute elbow threshold on sorted p’s |
|
|
- extract all sentences with p >= elbow |
|
|
- if none, pick the highest-p sentence |
|
|
4) Save new Excel with columns: |
|
|
- 'Claude_Call' |
|
|
- 'Predicted Sel_K' (list of extracted sentences) |
|
|
""" |
|
|
|
|
|
model.load_state_dict(torch.load(ckpt_path, map_location=device)) |
|
|
model.to(device).eval() |
|
|
|
|
|
|
|
|
df = pd.read_excel(excel_path) |
|
|
    sampled = df.sample(n=n_samples, random_state=random_state)  # random_state=None falls back to pandas' default RNG
|
|
|
|
|
records = [] |
|
|
for _, row in tqdm(sampled.iterrows(), |
|
|
total=len(sampled), |
|
|
desc="Running Predictions"): |
|
|
transcript = str(row['Claude_Call']) |
|
|
sentences = sent_tokenize(transcript) |
|
|
|
|
|
|
|
|
probs = [] |
|
|
for sent in sentences: |
|
|
enc = tokenizer.encode_plus( |
|
|
sent, |
|
|
max_length=max_length, |
|
|
padding='max_length', |
|
|
truncation=True, |
|
|
return_tensors='pt' |
|
|
) |
|
|
with torch.no_grad(): |
|
|
logits = model(enc['input_ids'].to(device), |
|
|
enc['attention_mask'].to(device)) |
|
|
p = torch.sigmoid(logits / temperature).item() |
|
|
probs.append(p) |
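        # Elbow heuristic: sort the per-sentence probabilities, find the largest
        # gap between adjacent values, and threshold at the midpoint of that gap;
        # if the scores are nearly flat (or there is only one sentence), fall
        # back to a fixed 0.5 threshold.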
|
|
|
|
|
|
|
|
if len(probs) >= 2 and max(probs) - min(probs) > 1e-3: |
|
|
sp = sorted(probs, reverse=True) |
|
|
diffs = [sp[i] - sp[i+1] for i in range(len(sp)-1)] |
|
|
idx = max(range(len(diffs)), key=lambda i: diffs[i]) |
|
|
thr = (sp[idx] + sp[idx+1]) / 2.0 |
|
|
else: |
|
|
thr = 0.5 |
|
|
|
|
|
|
|
|
extracted = [s for s,p in zip(sentences, probs) if p >= thr] |
|
|
if not extracted and sentences: |
|
|
best_idx = int(max(range(len(probs)), key=lambda i: probs[i])) |
|
|
extracted = [sentences[best_idx]] |
|
|
|
|
|
records.append({ |
|
|
'Claude_Call': transcript, |
|
|
'Predicted Sel_K': extracted |
|
|
}) |
|
|
|
|
|
|
|
|
out_df = pd.DataFrame(records) |
|
|
os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True) |
|
|
out_df.to_excel(output_path, index=False) |
|
|
print(f"➡️ Saved {len(out_df)} rows to {output_path}") |