import os
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import MultiheadAttention, GELU

from peft import PeftModel, LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModel, AutoConfig, get_linear_schedule_with_warmup

MODEL_NAME = "bert-base-uncased"
BATCH_SIZE = 16
MAX_LENGTH = 128
LEARNING_RATE = 2e-5
EPOCHS = 5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PREPROCESSED_DIR = "preprocessed_snli"
MIXED_PRECISION = "fp16"


class SimpleGNN(nn.Module):
    """One-layer graph network: mean-aggregates each node's out-neighbours, then a linear projection + ReLU."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, hidden_dim)

    def forward(self, node_embeddings, edges):
        # Empty graph: return a single zero vector so downstream pooling still works.
        if node_embeddings.size(0) == 0:
            return torch.zeros(1, self.fc.out_features, device=node_embeddings.device)

        num_nodes = node_embeddings.size(0)
        adj = torch.zeros((num_nodes, num_nodes), device=node_embeddings.device)
        for (src, dst) in edges:
            if src < num_nodes and dst < num_nodes:
                adj[src, dst] = 1.0

        # Row-normalise the adjacency matrix so each node averages its neighbours' embeddings.
        deg = adj.sum(dim=1, keepdim=True) + 1e-10
        adj_norm = adj / deg
        agg_embeddings = adj_norm @ node_embeddings
        return F.relu(self.fc(agg_embeddings))
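

# A minimal sketch (not part of the original module) showing how SimpleGNN aggregates a tiny
# 3-node graph; the node features and edge list below are made-up values for illustration only.
def _simple_gnn_demo():
    gnn = SimpleGNN(input_dim=4, hidden_dim=2)
    nodes = torch.randn(3, 4)         # 3 nodes with 4-dim features
    edges = [(0, 1), (0, 2), (1, 2)]  # directed (src, dst) pairs
    out = gnn(nodes, edges)           # -> shape (3, 2)
    return out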


class GraphAugmentedNLIModel(nn.Module):
    """BERT encoder whose [CLS] embedding is fused with pooled GNN encodings of the premise and
    hypothesis graphs before a linear classification head."""

    def __init__(self, base_model_name, num_labels=3, hidden_dim=768, gnn_dim=128):
        super().__init__()
        config = AutoConfig.from_pretrained(base_model_name)
        config.num_labels = num_labels
        self.bert = AutoModel.from_pretrained(base_model_name, config=config)
        self.dropout = nn.Dropout(0.1)

        self.gnn_premise = SimpleGNN(hidden_dim, gnn_dim)
        self.gnn_hypothesis = SimpleGNN(hidden_dim, gnn_dim)

        # [CLS] embedding + pooled premise GNN output + pooled hypothesis GNN output.
        self.classifier = nn.Linear(hidden_dim + gnn_dim * 2, num_labels)

    def forward(self, input_ids, attention_mask, premise_graph_tokens, premise_graph_edges, premise_node_indices,
                hypothesis_graph_tokens, hypothesis_graph_edges, hypothesis_node_indices, labels=None):
        # premise_graph_tokens / hypothesis_graph_tokens are accepted for batch compatibility but unused here.
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        cls_embedding = outputs.last_hidden_state[:, 0, :]

        batch_size = input_ids.size(0)
        gnn_p_outputs = []
        gnn_h_outputs = []

        # Graphs differ in size per instance, so run the GNNs one example at a time.
        for i in range(batch_size):
            instance_hidden = outputs.last_hidden_state[i]

            p_edges = premise_graph_edges[i]
            p_indices = premise_node_indices[i]
            h_edges = hypothesis_graph_edges[i]
            h_indices = hypothesis_node_indices[i]

            # Gather the token embeddings that act as graph nodes (possibly empty).
            p_nodes = instance_hidden[p_indices] if len(p_indices) > 0 else torch.empty(0, instance_hidden.size(-1), device=instance_hidden.device)
            h_nodes = instance_hidden[h_indices] if len(h_indices) > 0 else torch.empty(0, instance_hidden.size(-1), device=instance_hidden.device)

            # SimpleGNN returns a zero vector of width gnn_dim for empty node sets,
            # so no hard-coded fallback tensor is needed here.
            p_gnn_out = self.gnn_premise(p_nodes, p_edges)
            h_gnn_out = self.gnn_hypothesis(h_nodes, h_edges)

            # Mean-pool node representations into one vector per graph.
            p_mean = p_gnn_out.mean(dim=0, keepdim=True)
            h_mean = h_gnn_out.mean(dim=0, keepdim=True)

            gnn_p_outputs.append(p_mean)
            gnn_h_outputs.append(h_mean)

        gnn_p_outputs = torch.cat(gnn_p_outputs, dim=0)
        gnn_h_outputs = torch.cat(gnn_h_outputs, dim=0)

        fused = torch.cat([cls_embedding, gnn_p_outputs, gnn_h_outputs], dim=-1)
        fused = self.dropout(fused)
        logits = self.classifier(fused)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}
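

# A minimal usage sketch (not part of the original training code): the edge lists and node indices
# below are dummy placeholders for whatever the preprocessing step actually produces.
def _graph_nli_demo():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = GraphAugmentedNLIModel(MODEL_NAME).to(DEVICE)
    enc = tokenizer("A man is eating.", "Someone is having a meal.",
                    padding="max_length", truncation=True, max_length=MAX_LENGTH, return_tensors="pt")
    out = model(enc["input_ids"].to(DEVICE), enc["attention_mask"].to(DEVICE),
                premise_graph_tokens=[[]], premise_graph_edges=[[(0, 1)]], premise_node_indices=[[1, 2]],
                hypothesis_graph_tokens=[[]], hypothesis_graph_edges=[[(0, 1)]], hypothesis_node_indices=[[5, 6]],
                labels=torch.tensor([0], device=DEVICE))
    return out["loss"], out["logits"]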


class SimpleFinGNN(SimpleGNN):
    """Functionally identical to SimpleGNN; kept under a separate name for the financial NLI pipeline."""


class GraphAugmentedFinNLIModel(nn.Module):
    """Financial-domain variant of GraphAugmentedNLIModel with an HF-style forward signature
    (keyword-only inputs, **kwargs) so it can be wrapped by PEFT / Trainer utilities."""

    def __init__(self, base_model_name, num_labels=3, hidden_dim=768, gnn_dim=128):
        super().__init__()
        config = AutoConfig.from_pretrained(base_model_name)
        config.num_labels = num_labels
        self.bert = AutoModel.from_pretrained(base_model_name, config=config)
        self.dropout = nn.Dropout(0.1)

        self.gnn_premise = SimpleGNN(hidden_dim, gnn_dim)
        self.gnn_hypothesis = SimpleGNN(hidden_dim, gnn_dim)

        self.classifier = nn.Linear(hidden_dim + gnn_dim * 2, num_labels)
        # Expose a top-level config so wrappers that introspect model.config still work.
        self.config = self.bert.config
        self.config.num_labels = num_labels

    def forward(self,
                input_ids=None,
                attention_mask=None,
                premise_graph_tokens=None,
                hypothesis_graph_tokens=None,
                premise_graph_edges=None,
                hypothesis_graph_edges=None,
                premise_node_indices=None,
                hypothesis_node_indices=None,
                labels=None,
                inputs_embeds=None,
                **kwargs):
        # Forward only the kwargs that the underlying BERT forward() actually accepts.
        bert_kwargs = {k: v for k, v in kwargs.items() if k in self.bert.forward.__code__.co_varnames}
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            inputs_embeds=inputs_embeds,
                            **bert_kwargs)

        cls_embedding = outputs.last_hidden_state[:, 0, :]

        batch_size = input_ids.size(0) if input_ids is not None else outputs.last_hidden_state.size(0)
        gnn_p_outputs = []
        gnn_h_outputs = []

        # Per-instance GNN encoding, as in GraphAugmentedNLIModel.
        for i in range(batch_size):
            instance_hidden = outputs.last_hidden_state[i]

            p_edges = premise_graph_edges[i]
            p_indices = premise_node_indices[i]
            h_edges = hypothesis_graph_edges[i]
            h_indices = hypothesis_node_indices[i]

            p_nodes = instance_hidden[p_indices] if len(p_indices) > 0 else torch.empty(0, instance_hidden.size(-1), device=instance_hidden.device)
            h_nodes = instance_hidden[h_indices] if len(h_indices) > 0 else torch.empty(0, instance_hidden.size(-1), device=instance_hidden.device)

            # SimpleGNN returns a zero vector of width gnn_dim for empty node sets.
            p_gnn_out = self.gnn_premise(p_nodes, p_edges)
            h_gnn_out = self.gnn_hypothesis(h_nodes, h_edges)

            p_mean = p_gnn_out.mean(dim=0, keepdim=True)
            h_mean = h_gnn_out.mean(dim=0, keepdim=True)

            gnn_p_outputs.append(p_mean)
            gnn_h_outputs.append(h_mean)

        gnn_p_outputs = torch.cat(gnn_p_outputs, dim=0)
        gnn_h_outputs = torch.cat(gnn_h_outputs, dim=0)

        fused = torch.cat([cls_embedding, gnn_p_outputs, gnn_h_outputs], dim=-1)
        fused = self.dropout(fused)
        logits = self.classifier(fused)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}
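

# A minimal sketch (the actual fine-tuning loop is assumed to live elsewhere) of how a checkpoint
# with the "model_state_dict" key expected by SentenceExtractionModel below could be produced.
# The LoraConfig values mirror the ones used in the 'finexbert' branch.
def _save_lora_checkpoint_demo(adapter_dir="./lora_finance_adapter"):
    base_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
    lora_cfg = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.1, bias="none", task_type="SEQ_CLS")
    peft_model = get_peft_model(base_model, lora_cfg).to(DEVICE)
    # ... fine-tune peft_model here ...
    os.makedirs(adapter_dir, exist_ok=True)
    torch.save({"model_state_dict": peft_model.state_dict()},
               os.path.join(adapter_dir, "training_checkpoint.pt"))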


class SentenceExtractionModel(nn.Module):
    """Binary sentence-relevance scorer: a BERT backbone, dropout, and a single-logit head."""

    def __init__(self,
                 base_model_name: str,
                 dropout_prob: float = 0.1,
                 adapter_dir: str = "./lora_finance_adapter",
                 backbone: str = 'default',
                 init_pos_frac: float = None
                 ):
        """
        backbone:
        - 'default'   → plain AutoModel.from_pretrained(base_model_name)
        - 'finexbert' → the MODEL_NAME encoder wrapped with LoRA, with adapter weights loaded from
                        <adapter_dir>/training_checkpoint.pt (the encoder fine-tuned for
                        GraphAugmentedFinNLIModel)
        init_pos_frac: expected fraction of positive sentences; if given, the classifier bias is
                       initialised to its logit so early predictions match the class prior.
        """
        super().__init__()

        config = AutoConfig.from_pretrained(base_model_name)

        if backbone == 'default':
            self.bert = AutoModel.from_pretrained(base_model_name, config=config)

        elif backbone == 'finexbert':
            # Rebuild the LoRA-wrapped encoder, then load the fine-tuned adapter weights.
            # Note: this branch always starts from the global MODEL_NAME, not base_model_name.
            base_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
            lora_cfg = LoraConfig(
                r=8,
                lora_alpha=32,
                lora_dropout=0.1,
                bias="none",
                task_type="SEQ_CLS"
            )
            full = get_peft_model(base_model, lora_cfg).to(DEVICE)

            chkpt_path = os.path.join(adapter_dir, "training_checkpoint.pt")
            if not os.path.isfile(chkpt_path):
                raise FileNotFoundError(f"No LoRA checkpoint at {chkpt_path}")
            ckpt = torch.load(chkpt_path, map_location=DEVICE)
            # strict=False tolerates extra or missing keys in the saved state dict.
            full.load_state_dict(ckpt["model_state_dict"], strict=False)

            # Keep the LoRA-injected encoder (drop the outer PeftModel wrapper).
            self.bert = full.base_model

        else:
            raise ValueError(f"Unknown backbone {backbone}")

        hidden_size = self.bert.config.hidden_size

        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(hidden_size, 1)

        if init_pos_frac is not None:
            # Initialise the bias to the logit of the positive-class prior.
            b0 = float(math.log(init_pos_frac / (1.0 - init_pos_frac)))
            self.classifier.bias.data.fill_(b0)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask)
        # Score each sentence from BERT's pooled [CLS] representation.
        x = self.dropout(outputs.pooler_output)
        logits = self.classifier(x).squeeze(-1)
        return logits
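

# A small smoke test, assuming network access to download `bert-base-uncased`; it exercises only
# the 'default' backbone because the 'finexbert' path needs a LoRA checkpoint on disk.
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = SentenceExtractionModel(MODEL_NAME, backbone='default', init_pos_frac=0.25).to(DEVICE)
    enc = tokenizer(["Revenue grew 12% year over year."],
                    padding="max_length", truncation=True, max_length=MAX_LENGTH, return_tensors="pt")
    with torch.no_grad():
        scores = model(enc["input_ids"].to(DEVICE), enc["attention_mask"].to(DEVICE))
    print("relevance logits:", scores)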