import ast
import logging
import os

import nltk
import pandas as pd
import torch
from datasets import load_dataset, load_from_disk
from torch.utils.data import Dataset

from config import MODEL_NAME, MAX_LENGTH, OVERLAP, PREPROCESSED_DIR, tokenizer, nlp

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
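
# Note on config (assumptions inferred from how these objects are used below):
# `tokenizer` is expected to be a Hugging Face *fast* tokenizer for MODEL_NAME
# (offset mappings and the "##" handling assume a WordPiece/BERT-style
# vocabulary), and `nlp` a loaded spaCy pipeline used for dependency parsing.
# MAX_LENGTH and OVERLAP control the sliding-window chunking of transcripts.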


def process_data():
    """Build and cache the graph-augmented SNLI dataset under PREPROCESSED_DIR."""
    if not os.path.exists(PREPROCESSED_DIR):
        logging.info("Preprocessing data... This may take a while.")

        # Drop examples without a gold label (SNLI marks these with -1).
        snli = load_dataset("snli")
        snli = snli.filter(lambda x: x["label"] != -1)

        def build_dependency_graph(sentence):
            # Parse with spaCy and return word-level tokens plus bidirectional
            # (child, head) edges of the dependency tree.
            doc = nlp(sentence)
            tokens = [tok.text for tok in doc]
            edges = []
            for tok in doc:
                if tok.head.i != tok.i:  # skip the root's self-loop
                    edges.extend([(tok.i, tok.head.i), (tok.head.i, tok.i)])
            return tokens, edges

        def preprocess(examples):
            premises = examples["premise"]
            hypotheses = examples["hypothesis"]
            labels = examples["label"]
            tokenized = tokenizer(premises, hypotheses,
                                  truncation=True, padding="max_length",
                                  max_length=MAX_LENGTH)
            tokenized["labels"] = labels

            p_tokens_list, p_edges_list, p_idx_list = [], [], []
            h_tokens_list, h_edges_list, h_idx_list = [], [], []

            for p, h, input_ids in zip(premises, hypotheses, tokenized["input_ids"]):
                p_toks, p_edges = build_dependency_graph(p)
                h_toks, h_edges = build_dependency_graph(h)
                wp_tokens = tokenizer.convert_ids_to_tokens(input_ids)

                def align_tokens(spacy_tokens, wp_tokens, start):
                    # Map each spaCy word onto the wordpiece index of its
                    # first subword, starting at `start` and skipping "##"
                    # continuation pieces.
                    node_indices, wp_idx = [], start
                    for _ in spacy_tokens:
                        if wp_idx >= len(wp_tokens) - 1:
                            break
                        node_indices.append(wp_idx)
                        wp_idx += 1
                        while wp_idx < len(wp_tokens) - 1 and wp_tokens[wp_idx].startswith("##"):
                            wp_idx += 1
                    return node_indices

                # Premise subwords start right after [CLS]; hypothesis subwords
                # start right after the first [SEP] of the paired encoding.
                sep_pos = (wp_tokens.index(tokenizer.sep_token)
                           if tokenizer.sep_token in wp_tokens else 0)
                p_idx = align_tokens(p_toks, wp_tokens, 1)
                h_idx = align_tokens(h_toks, wp_tokens, sep_pos + 1)

                p_tokens_list.append(p_toks)
                p_edges_list.append(p_edges)
                p_idx_list.append(p_idx)

                h_tokens_list.append(h_toks)
                h_edges_list.append(h_edges)
                h_idx_list.append(h_idx)

            tokenized.update({
                "premise_graph_tokens": p_tokens_list,
                "premise_graph_edges": p_edges_list,
                "premise_node_indices": p_idx_list,
                "hypothesis_graph_tokens": h_tokens_list,
                "hypothesis_graph_edges": h_edges_list,
                "hypothesis_node_indices": h_idx_list,
            })
            return tokenized

        snli = snli.map(preprocess, batched=True)
        snli.save_to_disk(PREPROCESSED_DIR)
        logging.info("Preprocessing complete. Saved to %s", PREPROCESSED_DIR)
    else:
        logging.info("Using existing preprocessed data at %s", PREPROCESSED_DIR)


def chunk_transcript(transcript_text, start_idx, end_idx, tokenizer):
    """Split a long transcript into overlapping MAX_LENGTH token windows.

    start_idx/end_idx are character offsets of the gold span (-1 if absent);
    each chunk carries token-level start/end labels and a no-span flag.
    """
    # Tokenize the full transcript once, without truncation; the sliding
    # window below handles the length limit.
    encoded = tokenizer(transcript_text,
                        return_offsets_mapping=True,
                        add_special_tokens=True,
                        return_tensors=None,
                        padding=False,
                        truncation=False)
    all_input_ids = encoded["input_ids"]
    all_offsets = encoded["offset_mapping"]

    chunks = []
    i = 0
    while i < len(all_input_ids):
        chunk_ids = all_input_ids[i : i + MAX_LENGTH]
        chunk_offsets = all_offsets[i : i + MAX_LENGTH]
        attention_mask = [1] * len(chunk_ids)

        # Locate the gold span inside this chunk via the character offsets;
        # if it does not fall entirely within the window, mark it as no-span.
        no_span = 1
        start_token, end_token = -1, -1
        if start_idx >= 0 and end_idx >= 0:
            for j, (off_s, off_e) in enumerate(chunk_offsets):
                if off_s <= start_idx < off_e:
                    start_token = j
                if off_s < end_idx <= off_e:
                    end_token = j
                    break
            if 0 <= start_token <= end_token:
                no_span = 0
            else:
                start_token, end_token = -1, -1

        chunks.append({
            "input_ids": torch.tensor(chunk_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "start_label": start_token,
            "end_label": end_token,
            "no_span_label": no_span,
        })
        i += (MAX_LENGTH - OVERLAP)  # slide the window, keeping OVERLAP tokens of context
    return chunks
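

# Illustrative sketch (hypothetical input values): chunking a long transcript
# with the module's MAX_LENGTH/OVERLAP settings. Windows containing the
# character span [120, 180) get no_span_label == 0; the rest are labeled as
# having no span.
def _example_chunking():
    transcript = "word " * 1000
    chunks = chunk_transcript(transcript, start_idx=120, end_idx=180, tokenizer=tokenizer)
    return [(c["start_label"], c["end_label"], c["no_span_label"]) for c in chunks]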


class SpanExtractionChunkedDataset(Dataset):
    """Flattens a list of transcript records into per-chunk training samples."""

    def __init__(self, data):
        self.samples = []
        for item in data:
            chunks = chunk_transcript(
                item.get("transcript", ""),
                item.get("start_idx", -1),
                item.get("end_idx", -1),
                tokenizer)
            self.samples.extend(chunks)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def span_collate_fn(batch):
    """Pad variable-length chunks in a batch to the longest sequence."""
    max_len = max(len(x["input_ids"]) for x in batch)
    pad_id = tokenizer.pad_token_id or 0
    inputs, masks, starts, ends, nos = [], [], [], [], []
    for x in batch:
        pad = max_len - len(x["input_ids"])
        # Pad input ids with the tokenizer's pad id and extend the attention
        # mask with zeros so the padding is ignored by the model.
        inputs.append(torch.cat([x["input_ids"],
                                 torch.full((pad,), pad_id, dtype=torch.long)]).unsqueeze(0))
        masks.append(torch.cat([x["attention_mask"], torch.zeros(pad, dtype=torch.long)]).unsqueeze(0))
        starts.append(x["start_label"])
        ends.append(x["end_label"])
        nos.append(x["no_span_label"])
    return {
        "input_ids": torch.cat(inputs, dim=0),
        "attention_mask": torch.cat(masks, dim=0),
        "start_positions": torch.tensor(starts, dtype=torch.long),
        "end_positions": torch.tensor(ends, dtype=torch.long),
        "no_span_label": torch.tensor(nos, dtype=torch.long),
    }
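

# Illustrative sketch (hypothetical helper): wiring the chunked dataset and
# collate function into a PyTorch DataLoader. `records` is an iterable of
# dicts with the transcript/start_idx/end_idx keys used above.
def _example_span_loader(records, batch_size=8):
    from torch.utils.data import DataLoader
    dataset = SpanExtractionChunkedDataset(records)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                      collate_fn=span_collate_fn)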


# Sentence splitting below relies on NLTK's Punkt models; download them once.
nltk.download('punkt')
nltk.download('punkt_tab')


class SentenceDataset(Dataset):
    """Sentence-level binary classification samples read from an Excel sheet.

    Each transcript in the 'Claude_Call' column is split into sentences; a
    sentence is labeled 1 if it appears in the gold list in 'Sel_K'.
    """

    def __init__(self,
                 excel_path: str,
                 tokenizer,
                 max_length: int = 128):
        df = pd.read_excel(excel_path)
        self.samples = []

        for _, row in df.iterrows():
            transcript = str(row['Claude_Call'])
            gold_sentences = row['Sel_K']

            # Gold sentences may be stored as a stringified Python list;
            # parse it safely instead of calling eval().
            if isinstance(gold_sentences, str):
                gold_sentences = ast.literal_eval(gold_sentences)

            sentences = nltk.sent_tokenize(transcript)
            for sent in sentences:
                label = 1 if sent in gold_sentences else 0

                enc = tokenizer.encode_plus(
                    sent,
                    max_length=max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )
                self.samples.append({
                    'input_ids': enc['input_ids'].squeeze(0),
                    'attention_mask': enc['attention_mask'].squeeze(0),
                    'label': torch.tensor(label, dtype=torch.float)
                })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]
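

# Illustrative sketch (the file name is a placeholder): building a DataLoader
# over the sentence-level dataset. The Excel file is expected to contain the
# 'Claude_Call' and 'Sel_K' columns used above.
def _example_sentence_loader(excel_path="annotations.xlsx", batch_size=16):
    from torch.utils.data import DataLoader
    dataset = SentenceDataset(excel_path, tokenizer, max_length=128)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)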