# Fin-ExBERT / preprocess_data.py
import os
import ast
import logging
import torch
from torch.utils.data import Dataset
from datasets import load_dataset, load_from_disk
import pandas as pd
import nltk
from config import MODEL_NAME, MAX_LENGTH, OVERLAP, PREPROCESSED_DIR, tokenizer, nlp
# =============================
# Logging Setup
# =============================
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
# =============================
# One-Time Preprocessing
# =============================
def process_data():
    """One-time preprocessing of SNLI: tokenization plus dependency-graph construction."""
    if not os.path.exists(PREPROCESSED_DIR):
        logging.info("Preprocessing data... This may take a while.")
        # Load SNLI and drop unlabeled pairs (label == -1)
        snli = load_dataset("snli")
        snli = snli.filter(lambda x: x["label"] != -1)
        def build_dependency_graph(sentence):
            # Parse with spaCy and expose the dependency tree as an undirected
            # edge list (both directions are added for every head-child arc).
            doc = nlp(sentence)
            tokens = [tok.text for tok in doc]
            edges = []
            for tok in doc:
                if tok.head.i != tok.i:  # skip the root token's self-loop
                    edges.extend([(tok.i, tok.head.i), (tok.head.i, tok.i)])
            return tokens, edges
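        # Illustration (hedged): for a sentence like "The cat sleeps", a typical
        # spaCy parse heads "The" -> "cat" and "cat" -> "sleeps" (root), so
        # tokens == ["The", "cat", "sleeps"] and edges would include
        # (0, 1), (1, 0), (1, 2), (2, 1). The exact arcs depend on the spaCy
        # model bound to `nlp` in config.py.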
        def preprocess(examples):
            premises = examples["premise"]
            hypotheses = examples["hypothesis"]
            labels = examples["label"]
            tokenized = tokenizer(premises, hypotheses,
                                  truncation=True, padding="max_length",
                                  max_length=MAX_LENGTH)
            tokenized["labels"] = labels
            p_tokens_list, p_edges_list, p_idx_list = [], [], []
            h_tokens_list, h_edges_list, h_idx_list = [], [], []
            for p, h, input_ids in zip(premises, hypotheses, tokenized["input_ids"]):
                p_toks, p_edges = build_dependency_graph(p)
                h_toks, h_edges = build_dependency_graph(h)
                wp_tokens = tokenizer.convert_ids_to_tokens(input_ids)

                def align_tokens(spacy_tokens, wp_tokens, start):
                    # Map each spaCy token to the index of its first WordPiece:
                    # walk the WordPiece sequence from `start`, skipping "##"
                    # continuation pieces, and stop at the segment boundary.
                    node_indices, wp_idx = [], start
                    for _ in spacy_tokens:
                        if wp_idx >= len(wp_tokens) - 1 or wp_tokens[wp_idx] == tokenizer.sep_token:
                            break
                        node_indices.append(wp_idx)
                        wp_idx += 1
                        while wp_idx < len(wp_tokens) - 1 and wp_tokens[wp_idx].startswith("##"):
                            wp_idx += 1
                    return node_indices

                # Premise WordPieces start right after [CLS]; hypothesis
                # WordPieces start right after the first [SEP].
                sep_pos = wp_tokens.index(tokenizer.sep_token)
                p_idx = align_tokens(p_toks, wp_tokens, 1)
                h_idx = align_tokens(h_toks, wp_tokens, sep_pos + 1)
                p_tokens_list.append(p_toks)
                p_edges_list.append(p_edges)
                p_idx_list.append(p_idx)
                h_tokens_list.append(h_toks)
                h_edges_list.append(h_edges)
                h_idx_list.append(h_idx)
            tokenized.update({
                "premise_graph_tokens": p_tokens_list,
                "premise_graph_edges": p_edges_list,
                "premise_node_indices": p_idx_list,
                "hypothesis_graph_tokens": h_tokens_list,
                "hypothesis_graph_edges": h_edges_list,
                "hypothesis_node_indices": h_idx_list,
            })
            return tokenized
        snli = snli.map(preprocess, batched=True)
        snli.save_to_disk(PREPROCESSED_DIR)
        logging.info("Preprocessing complete. Saved to %s", PREPROCESSED_DIR)
    else:
        logging.info("Using existing preprocessed data at %s", PREPROCESSED_DIR)
def chunk_transcript(transcript_text, start_idx, end_idx, tokenizer):
    """Split a long transcript into overlapping MAX_LENGTH token windows.

    start_idx / end_idx are character offsets of the target span
    (-1, -1 when the transcript has no span). Each chunk gets token-level
    start/end labels, or (-1, -1) with no_span_label = 1 when the span does
    not fall inside that chunk.
    """
    # No truncation here: the full transcript is encoded once and chunked
    # manually below (a max_length value is unused with truncation=False).
    encoded = tokenizer(transcript_text,
                        return_offsets_mapping=True,
                        add_special_tokens=True,
                        return_tensors=None,
                        padding=False,
                        truncation=False)
    all_input_ids = encoded["input_ids"]
    all_offsets = encoded["offset_mapping"]
    chunks = []
    i = 0
    while i < len(all_input_ids):
        chunk_ids = all_input_ids[i : i + MAX_LENGTH]
        chunk_offsets = all_offsets[i : i + MAX_LENGTH]
        attention_mask = [1] * len(chunk_ids)
        no_span = 1
        start_token, end_token = -1, -1
        if start_idx >= 0 and end_idx >= 0:
            # Map the character span onto token positions within this chunk.
            for j, (off_s, off_e) in enumerate(chunk_offsets):
                if off_s <= start_idx < off_e:
                    start_token = j
                if off_s < end_idx <= off_e:
                    end_token = j
                    break
            if 0 <= start_token <= end_token:
                no_span = 0
            else:
                # Span is absent from (or only partially inside) this chunk.
                start_token, end_token = -1, -1
        chunks.append({
            "input_ids": torch.tensor(chunk_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "start_label": start_token,
            "end_label": end_token,
            "no_span_label": no_span,
        })
        # Slide the window forward, keeping OVERLAP tokens of shared context.
        i += (MAX_LENGTH - OVERLAP)
    return chunks
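
# Sketch (illustrative values only; variable names are placeholders): with
# MAX_LENGTH = 512 and OVERLAP = 128, a 1,000-token transcript yields windows
# starting at tokens 0, 384 and 768, so consecutive chunks share 128 tokens.
#
#     chunks = chunk_transcript(long_transcript, answer_start, answer_end, tokenizer)
#     for c in chunks:
#         print(len(c["input_ids"]), c["start_label"], c["end_label"], c["no_span_label"])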
class SpanExtractionChunkedDataset(Dataset):
    """Flattens a list of transcript records into per-chunk training samples."""
    def __init__(self, data):
        self.samples = []
        for item in data:
            chunks = chunk_transcript(
                item.get("transcript", ""),
                item.get("start_idx", -1),
                item.get("end_idx", -1),
                tokenizer)
            self.samples.extend(chunks)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]
def span_collate_fn(batch):
    # Pad every chunk in the batch to the longest sequence; id 0 is used for
    # padding, which matches BERT-style vocabularies where [PAD] has id 0.
    max_len = max(len(x["input_ids"]) for x in batch)
    inputs, masks, starts, ends, nos = [], [], [], [], []
    for x in batch:
        pad = max_len - len(x["input_ids"])
        inputs.append(torch.cat([x["input_ids"], torch.zeros(pad, dtype=torch.long)]).unsqueeze(0))
        masks.append(torch.cat([x["attention_mask"], torch.zeros(pad, dtype=torch.long)]).unsqueeze(0))
        starts.append(x["start_label"])
        ends.append(x["end_label"])
        nos.append(x["no_span_label"])
    return {
        "input_ids": torch.cat(inputs, dim=0),
        "attention_mask": torch.cat(masks, dim=0),
        "start_positions": torch.tensor(starts, dtype=torch.long),
        "end_positions": torch.tensor(ends, dtype=torch.long),
        "no_span_label": torch.tensor(nos, dtype=torch.long),
    }
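
# Sketch (assumed record layout): the dataset expects dicts with "transcript",
# "start_idx" and "end_idx" keys, as read via item.get(...) above.
#
#     from torch.utils.data import DataLoader
#     records = [{"transcript": "...", "start_idx": 10, "end_idx": 42}]
#     loader = DataLoader(SpanExtractionChunkedDataset(records),
#                         batch_size=8, shuffle=True, collate_fn=span_collate_fn)
#     batch = next(iter(loader))  # keys: input_ids, attention_mask, start_positions, ...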
# Sentence tokenizer models required by nltk.sent_tokenize below.
nltk.download('punkt')
nltk.download('punkt_tab')
class SentenceDataset(Dataset):
    """Sentence-level binary labels: 1 if a transcript sentence is in the gold set."""
    def __init__(self,
                 excel_path: str,
                 tokenizer,
                 max_length: int = 128):
        df = pd.read_excel(excel_path)
        self.samples = []
        for _, row in df.iterrows():
            transcript = str(row['Claude_Call'])
            gold_sentences = row['Sel_K']
            # If the gold sentences are stored as a string repr of a list,
            # parse it safely (literal_eval instead of eval).
            if isinstance(gold_sentences, str):
                gold_sentences = ast.literal_eval(gold_sentences)
            # Split the transcript into sentences and label each one.
            sentences = nltk.sent_tokenize(transcript)
            for sent in sentences:
                label = 1 if sent in gold_sentences else 0
                enc = tokenizer.encode_plus(
                    sent,
                    max_length=max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )
                self.samples.append({
                    'input_ids': enc['input_ids'].squeeze(0),
                    'attention_mask': enc['attention_mask'].squeeze(0),
                    'label': torch.tensor(label, dtype=torch.float)
                })
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]
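

# Minimal driver sketch (not part of the original module): run the one-time
# SNLI preprocessing when the script is executed directly. The commented lines
# show how SentenceDataset might be built; "transcripts.xlsx" is a placeholder
# filename, and the column names ("Claude_Call", "Sel_K") are the ones assumed
# in SentenceDataset above.
if __name__ == "__main__":
    process_data()
    # sentence_ds = SentenceDataset("transcripts.xlsx", tokenizer, max_length=128)
    # logging.info("Built sentence dataset with %d samples", len(sentence_ds))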