Upload 54 files
This view is limited to 50 files because it contains too many changes. See raw diff.
- UI/__pycache__/gradio.cpython-310.pyc +0 -0
- UI/__pycache__/gradio.cpython-311.pyc +0 -0
- UI/gradio.py +575 -0
- app.py +20 -0
- environment.yml +245 -0
- metrics/detectability.py +323 -0
- metrics/distortion.py +370 -0
- renderers/__pycache__/highlighter.cpython-310.pyc +0 -0
- renderers/__pycache__/highlighter.cpython-311.pyc +0 -0
- renderers/__pycache__/plot_3d.cpython-310.pyc +0 -0
- renderers/__pycache__/plot_3d.cpython-311.pyc +0 -0
- renderers/__pycache__/tree.cpython-310.pyc +0 -0
- renderers/__pycache__/tree.cpython-311.pyc +0 -0
- renderers/highlighter.py +185 -0
- renderers/plot_3d.py +126 -0
- renderers/tree.py +490 -0
- utils/__init__.py +5 -0
- utils/__pycache__/__init__.cpython-310.pyc +0 -0
- utils/__pycache__/__init__.cpython-311.pyc +0 -0
- utils/__pycache__/config.cpython-310.pyc +0 -0
- utils/__pycache__/config.cpython-311.pyc +0 -0
- utils/__pycache__/entailment.cpython-310.pyc +0 -0
- utils/__pycache__/entailment.cpython-311.pyc +0 -0
- utils/__pycache__/masking_methods.cpython-310.pyc +0 -0
- utils/__pycache__/masking_methods.cpython-311.pyc +0 -0
- utils/__pycache__/ngram_index_manager.cpython-311.pyc +0 -0
- utils/__pycache__/non_melting_point.cpython-310.pyc +0 -0
- utils/__pycache__/non_melting_point.cpython-311.pyc +0 -0
- utils/__pycache__/paraphraser.cpython-310.pyc +0 -0
- utils/__pycache__/paraphraser.cpython-311.pyc +0 -0
- utils/__pycache__/sampling.cpython-310.pyc +0 -0
- utils/__pycache__/sampling.cpython-311.pyc +0 -0
- utils/__pycache__/watermark.cpython-310.pyc +0 -0
- utils/__pycache__/watermark.cpython-311.pyc +0 -0
- utils/config.py +18 -0
- utils/config.yaml +48 -0
- utils/entailment.py +107 -0
- utils/masking_methods.py +583 -0
- utils/non_melting_point.py +590 -0
- utils/old/masking/masking_methods.py +355 -0
- utils/old/masking/masking_methods_new_work.py +447 -0
- utils/old/masking/masking_methods_ok_working.py +257 -0
- utils/old/masking/masking_methods_v1_working.py +233 -0
- utils/old/masking_methods_final_copy.py +619 -0
- utils/old/non_melting_points_v1.py +244 -0
- utils/old/sampling/sampling.py +330 -0
- utils/old/sampling/sampling_methods.py +291 -0
- utils/old/sampling/sampling_methods_v1.py +146 -0
- utils/old/sampling/sampling_methods_v2.py +112 -0
- utils/old/sampling_final_copy.py +168 -0
UI/__pycache__/gradio.cpython-310.pyc
ADDED
Binary file (6.61 kB).
UI/__pycache__/gradio.cpython-311.pyc
ADDED
Binary file (29.8 kB).
UI/gradio.py
ADDED
@@ -0,0 +1,575 @@
import gradio as gr
from utils.watermark import Watermarker
from utils.config import load_config
from renderers.highlighter import highlight_common_words, highlight_common_words_dict, reparaphrased_sentences_html
from renderers.tree import generate_subplot1, generate_subplot2
from pathlib import Path
import time
from typing import Dict, List, Tuple, Any
import plotly.graph_objects as go

class WatermarkerInterface:
    def __init__(self, config):
        self.pipeline = Watermarker(config)
        self.common_grams = {}
        self.highlight_info = []
        self.masked_sentences = []

        # Add tracking dictionaries for indexing
        self.masked_sentence_indices = {}    # Maps original sentences to masked indices
        self.sampled_sentence_indices = {}   # Maps masked sentences to sampling indices
        self.reparaphrased_indices = {}      # Maps sampled sentences to reparaphrased indices

    def handle_paraphrase(self, prompt: str) -> Tuple[str, str, str, str]:
        """Wrapper for paraphrasing that includes highlighting"""
        start_time = time.time()

        # Run paraphrasing
        self.pipeline.Paraphrase(prompt)

        # Step 1: Process the original sentence first
        seen_ngrams = {}              # Stores first occurrence index of each n-gram
        original_indexed_ngrams = []  # Final indexed list for original

        original_sentence = self.pipeline.user_prompt
        original_ngrams = self.pipeline.common_grams.get(original_sentence, {})

        # Step 1.1: Extract n-grams and their first occurrence index
        ngram_occurrences = [
            (min(indices, key=lambda x: x[0])[0], gram)  # Get first index
            for gram, indices in original_ngrams.items()
        ]

        # Step 1.2: Sort n-grams based on their first occurrence
        ngram_occurrences.sort()

        # Step 1.3: Assign sequential indices
        for idx, (position, gram) in enumerate(ngram_occurrences, start=1):
            seen_ngrams[gram] = idx  # Assign sequential index
            original_indexed_ngrams.append((idx, gram))

        print("Original Indexed N-grams:", original_indexed_ngrams)

        # Generate highlight_info
        colors = ["red", "blue", "green", "orange"]
        highlight_info = [
            (ngram, colors[i % len(colors)])
            for i, (index, ngram) in enumerate(original_indexed_ngrams)
        ]
        common_grams = original_indexed_ngrams
        self.highlight_info = highlight_info
        self.common_grams = common_grams

        # Step 2: Process paraphrased sentences and match indices
        paraphrase_indexed_ngrams = {}

        for sentence in self.pipeline.paraphrased_sentences:
            sentence_ngrams = []  # Stores n-grams for this sentence
            sentence_ngrams_dict = self.pipeline.common_grams.get(sentence, {})

            for gram, indices in sentence_ngrams_dict.items():
                first_occurrence = min(indices, key=lambda x: x[0])[0]

                # Use the original's index if it exists, otherwise assign a new one
                if gram in seen_ngrams:
                    index = seen_ngrams[gram]  # Use the same index as original
                else:
                    index = len(seen_ngrams) + 1  # Assign new index
                    seen_ngrams[gram] = index     # Store it

                sentence_ngrams.append((index, gram))

            sentence_ngrams.sort()
            paraphrase_indexed_ngrams[sentence] = sentence_ngrams

        print("Paraphrase Indexed N-grams:", paraphrase_indexed_ngrams)

        # Step 3: Generate highlighted versions using the renderer
        highlighted_prompt = highlight_common_words(
            common_grams,
            [self.pipeline.user_prompt],
            "Original Prompt with Highlighted Common Sequences"
        )

        highlighted_accepted = highlight_common_words_dict(
            common_grams,
            self.pipeline.selected_sentences,
            "Accepted Paraphrased Sentences with Entailment Scores"
        )

        highlighted_discarded = highlight_common_words_dict(
            common_grams,
            self.pipeline.discarded_sentences,
            "Discarded Paraphrased Sentences with Entailment Scores"
        )

        execution_time = f"<div class='execution-time'>Step 1 completed in {time.time() - start_time:.2f} seconds</div>"

        return highlighted_prompt, highlighted_accepted, highlighted_discarded, execution_time

    def handle_masking(self):
        start_time = time.time()
        masking_results = self.pipeline.Masking()
        trees = []
        highlight_info = self.highlight_info
        common_grams = self.common_grams
        sentence_to_masked = {}
        self.masked_sentence_indices = {}

        for strategy, sentence_dict in masking_results.items():
            for sent, data in sentence_dict.items():
                if sent not in sentence_to_masked:
                    sentence_to_masked[sent] = []
                masked_sentence = data.get("masked_sentence", "")
                if masked_sentence:
                    sentence_to_masked[sent].append((masked_sentence, strategy))

        plot_idx = 1
        for original_sentence, masked_sentences_data in sentence_to_masked.items():
            if not masked_sentences_data:
                continue
            masked_idx = 1
            for masked_sentence, strategy in masked_sentences_data:
                index = f"{plot_idx}{masked_idx}"
                if original_sentence not in self.masked_sentence_indices:
                    self.masked_sentence_indices[original_sentence] = {}
                key = f"{strategy}_{masked_sentence}"
                self.masked_sentence_indices[original_sentence][key] = {
                    'index': index,
                    'strategy': strategy,
                    'masked_sentence': masked_sentence
                }
                masked_idx += 1

            masked_sentences = [ms[0] for ms in masked_sentences_data]
            indexed_masked_sentences = []
            verified_strategies = []
            for masked_sentence, strategy in masked_sentences_data:
                key = f"{strategy}_{masked_sentence}"
                entry = self.masked_sentence_indices[original_sentence][key]
                idx = entry['index']
                indexed_masked_sentences.append(f"[{idx}] {masked_sentence}")
                verified_strategies.append(entry['strategy'])

            try:
                fig = generate_subplot1(
                    original_sentence,
                    indexed_masked_sentences,
                    verified_strategies,
                    highlight_info,
                    common_grams
                )
                trees.append(fig)
            except Exception as e:
                print(f"Error generating plot: {e}")
                trees.append(go.Figure())
            plot_idx += 1

        while len(trees) < 10:
            trees.append(go.Figure())

        execution_time = f"<div class='execution-time'>Step 2 completed in {time.time() - start_time:.2f} seconds</div>"
        return trees[:10] + [execution_time]

    def handle_sampling(self) -> Tuple[List[go.Figure], str]:
        start_time = time.time()
        sampling_results = self.pipeline.Sampling()
        trees = []
        self.sampled_sentence_indices = {}
        organized_results = {}

        for sampling_strategy, masking_dict in sampling_results.items():
            for masking_strategy, sentences in masking_dict.items():
                for original_sentence, data in sentences.items():
                    if original_sentence not in organized_results:
                        organized_results[original_sentence] = {}
                    if masking_strategy not in organized_results[original_sentence]:
                        organized_results[original_sentence][masking_strategy] = {
                            "masked_sentence": data.get("masked_sentence", ""),
                            "sampled_sentences": {}
                        }
                    organized_results[original_sentence][masking_strategy]["sampled_sentences"][sampling_strategy] = data.get("sampled_sentence", "")

        plot_idx = 1
        for original_sentence, data in organized_results.items():
            masked_sentences = []
            all_sampled_sentences = []
            indexed_sampled_sentences = []
            masked_indices = self.masked_sentence_indices.get(original_sentence, {})

            for masking_strategy, masking_data in list(data.items())[:3]:
                masked_sentence = masking_data.get("masked_sentence", "")
                if masked_sentence:
                    masked_sentences.append(masked_sentence)
                    masked_idx = None
                    for ms_key, ms_data in masked_indices.items():
                        if ms_key == f"{masking_strategy}_{masked_sentence}":
                            masked_idx = ms_data['index']
                            break

                    if not masked_idx:
                        print(f"Warning: No index found for masked sentence: {masked_sentence}")
                        continue

                    sample_count = 1
                    for sampling_strategy, sampled_sentence in masking_data.get("sampled_sentences", {}).items():
                        if sampled_sentence:
                            sample_idx = f"{masked_idx}.{sample_count}"
                            if masked_sentence not in self.sampled_sentence_indices:
                                self.sampled_sentence_indices[masked_sentence] = {}
                            self.sampled_sentence_indices[masked_sentence][sampled_sentence] = {
                                'index': sample_idx,
                                'strategy': sampling_strategy
                            }
                            indexed_sampled_sentences.append(f"[{sample_idx}] {sampled_sentence}")
                            all_sampled_sentences.append(sampled_sentence)
                            sample_count += 1

            if masked_sentences:
                indexed_masked_sentences = []
                for ms in masked_sentences:
                    idx = ""
                    for ms_key, ms_data in masked_indices.items():
                        if ms_key.endswith(f"_{ms}"):
                            idx = ms_data['index']
                            break
                    indexed_masked_sentences.append(f"[{idx}] {ms}")

                try:
                    fig = generate_subplot2(
                        indexed_masked_sentences,
                        indexed_sampled_sentences,
                        self.highlight_info,
                        self.common_grams
                    )
                    trees.append(fig)
                except Exception as e:
                    print(f"Error generating subplot for {original_sentence}: {e}")
                    trees.append(go.Figure())
            plot_idx += 1

        print("Sampled sentence indices:", self.sampled_sentence_indices)

        while len(trees) < 10:
            trees.append(go.Figure())

        execution_time = f"<div class='execution-time'>Step 3 completed in {time.time() - start_time:.2f} seconds</div>"

        return trees[:10] + [execution_time]

    def handle_reparaphrasing(self) -> Tuple[List[str], str]:
        start_time = time.time()
        results = self.pipeline.re_paraphrasing()
        html_outputs = []
        self.reparaphrased_indices = {}
        tab_count = 1

        for sampling_strategy, masking_dict in results.items():
            for masking_strategy, sentences in masking_dict.items():
                for original_sent, data in sentences.items():
                    sampled_sentence = data.get("sampled_sentence", "")
                    if not sampled_sentence or not data["re_paraphrased_sentences"]:
                        continue

                    sampled_index = None
                    for masked_sent, sampled_dict in self.sampled_sentence_indices.items():
                        if sampled_sentence in sampled_dict:
                            sampled_index = sampled_dict[sampled_sentence]['index']
                            break

                    if not sampled_index:
                        sampled_index = "unknown"

                    indexed_reparaphrased = []
                    for i, rp_sent in enumerate(data["re_paraphrased_sentences"], 1):
                        rp_idx = f"{tab_count}.({sampled_index}).{i}"
                        if sampled_sentence not in self.reparaphrased_indices:
                            self.reparaphrased_indices[sampled_sentence] = {}
                        self.reparaphrased_indices[sampled_sentence][rp_sent] = rp_idx
                        indexed_reparaphrased.append(f"[{rp_idx}] {rp_sent}")

                    print(f"Reparaphrasing {tab_count}.({sampled_index}): {' '.join(sampled_sentence.split()[:5])}...")
                    html = reparaphrased_sentences_html(indexed_reparaphrased)
                    html_outputs.append(html)
                    tab_count += 1

        print("Reparaphrased indices:", self.reparaphrased_indices)

        while len(html_outputs) < 120:
            html_outputs.append("")

        execution_time = f"<div class='execution-time'>Step 4 completed in {time.time() - start_time:.2f} seconds</div>"

        return html_outputs[:120] + [execution_time]

def create_gradio_interface(config):
    """Creates the Gradio interface with the updated pipeline"""
    interface = WatermarkerInterface(config)

    with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
        # CSS to enable scrolling for reparaphrased sentences and sampling plots
        demo.css = """
        /* Set fixed height for the reparaphrased tabs container only */
        .gradio-container .tabs[id="reparaphrased-tabs"],
        .gradio-container .tabs[id="sampling-tabs"] {
            overflow-x: hidden;
            white-space: normal;
            border-radius: 8px;
            max-height: 600px; /* Set fixed height for the entire tabs component */
            overflow-y: auto;  /* Enable vertical scrolling inside the container */
        }

        /* Tab content styling for reparaphrased and sampling tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tabitem,
        .gradio-container .tabs[id="sampling-tabs"] .tabitem {
            overflow-x: hidden;
            white-space: normal;
            display: block;
            border-radius: 8px;
        }

        /* Make the tab navigation fixed at the top for scrollable tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav {
            display: flex;
            overflow-x: auto;
            white-space: nowrap;
            scrollbar-width: thin;
            border-radius: 8px;
            scrollbar-color: #888 #f1f1f1;
            position: sticky;
            top: 0;
            background: white;
            z-index: 100;
        }

        /* Dropdown menu for scrollable tabs styling */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown {
            position: relative;
            display: inline-block;
        }

        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown-content,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown-content {
            display: none;
            position: absolute;
            background-color: #f9f9f9;
            min-width: 160px;
            box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2);
            z-index: 1;
            max-height: 300px;
            overflow-y: auto;
        }

        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown:hover .tab-dropdown-content,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown:hover .tab-dropdown-content {
            display: block;
        }

        /* Scrollbar styling for scrollable tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav::-webkit-scrollbar,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav::-webkit-scrollbar {
            height: 8px;
            border-radius: 8px;
        }

        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav::-webkit-scrollbar-track,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav::-webkit-scrollbar-track {
            background: #f1f1f1;
            border-radius: 8px;
        }

        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav::-webkit-scrollbar-thumb,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav::-webkit-scrollbar-thumb {
            background: #888;
            border-radius: 8px;
        }

        /* Tab button styling for scrollable tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-item,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-item {
            flex: 0 0 auto;
            border-radius: 8px;
        }

        /* Plot container styling specifically for sampling tabs */
        .gradio-container .tabs[id="sampling-tabs"] .plot-container {
            min-height: 600px;
            max-height: 1800px;
            overflow-y: auto;
        }

        /* Ensure text wraps in HTML components */
        .gradio-container .prose {
            white-space: normal;
            word-wrap: break-word;
            overflow-wrap: break-word;
        }

        /* Dropdown button styling for scrollable tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown button,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown button {
            background-color: #f0f0f0;
            border: 1px solid #ddd;
            border-radius: 4px;
            padding: 5px 10px;
            cursor: pointer;
            margin: 2px;
        }

        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown button:hover,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown button:hover {
            background-color: #e0e0e0;
        }

        /* Style dropdown content items for scrollable tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown-content div,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown-content div {
            padding: 8px 12px;
            cursor: pointer;
        }

        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown-content div:hover,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown-content div:hover {
            background-color: #e0e0e0;
        }

        /* Custom styling for execution time display */
        .execution-time {
            text-align: right;
            padding: 8px 16px;
            font-family: inherit;
            color: #555;
            font-size: 0.9rem;
            font-style: italic;
            margin-left: auto;
            width: 100%;
            border-top: 1px solid #eee;
            margin-top: 8px;
        }

        /* Layout for section headers with execution time */
        .section-header {
            display: flex;
            justify-content: space-between;
            align-items: center;
            width: 100%;
            margin-bottom: 12px;
        }

        .section-header h3 {
            margin: 0;
        }
        """
        gr.Markdown("# **AIISC Watermarking Model**")

        with gr.Column():
            gr.Markdown("## Input Prompt")
            user_input = gr.Textbox(
                label="Enter Your Prompt",
                placeholder="Type your text here..."
            )

        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("## Step 1: Paraphrasing, LCS and Entailment Analysis")
            with gr.Column(scale=1):
                step1_time = gr.HTML()

        paraphrase_button = gr.Button("Generate Paraphrases")
        highlighted_user_prompt = gr.HTML(label="Highlighted User Prompt")

        with gr.Tabs():
            with gr.TabItem("Accepted Paraphrased Sentences"):
                highlighted_accepted_sentences = gr.HTML()
            with gr.TabItem("Discarded Paraphrased Sentences"):
                highlighted_discarded_sentences = gr.HTML()

        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("## Step 2: Where to Mask?")
            with gr.Column(scale=1):
                step2_time = gr.HTML()

        masking_button = gr.Button("Apply Masking")
        gr.Markdown("### Masked Sentence Trees")
        tree1_plots = []
        with gr.Tabs() as tree1_tabs:
            for i in range(10):
                with gr.TabItem(f"Masked Sentence {i+1}"):
                    tree1 = gr.Plot()
                    tree1_plots.append(tree1)

        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("## Step 3: How to Mask?")
            with gr.Column(scale=1):
                step3_time = gr.HTML()

        sampling_button = gr.Button("Sample Words")
        gr.Markdown("### Sampled Sentence Trees")

        tree2_plots = []
        # Add elem_id to make this tab container scrollable
        with gr.Tabs(elem_id="sampling-tabs") as tree2_tabs:
            for i in range(10):
                with gr.TabItem(f"Sampled Sentence {i+1}"):
                    # Add a custom class to the container to enable proper styling
                    with gr.Column(elem_classes=["plot-container"]):
                        tree2 = gr.Plot()
                        tree2_plots.append(tree2)

        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("## Step 4: Re-paraphrasing")
            with gr.Column(scale=1):
                step4_time = gr.HTML()

        reparaphrase_button = gr.Button("Re-paraphrase")
        gr.Markdown("### Reparaphrased Sentences")
        reparaphrased_sentences_tabs = []
        with gr.Tabs(elem_id="reparaphrased-tabs") as reparaphrased_tabs:
            for i in range(120):
                with gr.TabItem(f"Reparaphrased Batch {i+1}"):
                    reparaphrased_sent_html = gr.HTML()
                    reparaphrased_sentences_tabs.append(reparaphrased_sent_html)

        # Connect the interface functions to the buttons
        paraphrase_button.click(
            interface.handle_paraphrase,
            inputs=user_input,
            outputs=[
                highlighted_user_prompt,
                highlighted_accepted_sentences,
                highlighted_discarded_sentences,
                step1_time
            ]
        )

        masking_button.click(
            interface.handle_masking,
            inputs=None,
            outputs=tree1_plots + [step2_time]
        )

        sampling_button.click(
            interface.handle_sampling,
            inputs=None,
            outputs=tree2_plots + [step3_time]
        )

        reparaphrase_button.click(
            interface.handle_reparaphrasing,
            inputs=None,
            outputs=reparaphrased_sentences_tabs + [step4_time]
        )

    return demo

if __name__ == "__main__":
    project_root = Path(__file__).parent.parent
    config_path = project_root / "utils" / "config.yaml"
    config = load_config(config_path)['PECCAVI_TEXT']

    create_gradio_interface(config).launch()
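For orientation, the three tracking dictionaries in WatermarkerInterface compose a hierarchical labelling scheme across the four steps: a masked sentence is labelled "{plot_idx}{masked_idx}", a sampled sentence extends that to "{masked_idx}.{sample_count}", and a re-paraphrased sentence becomes "{tab_count}.({sampled_index}).{i}". The following standalone sketch only illustrates how those format strings compose; the helper functions and example values are not part of the repository.

# Illustrative sketch of the index scheme above (hypothetical helpers, not repo code).
def masked_index(plot_idx: int, masked_idx: int) -> str:
    # Step 2: f"{plot_idx}{masked_idx}", e.g. plot 1, second mask -> "12"
    return f"{plot_idx}{masked_idx}"

def sampled_index(masked_idx: str, sample_count: int) -> str:
    # Step 3: f"{masked_idx}.{sample_count}", e.g. "12.1"
    return f"{masked_idx}.{sample_count}"

def reparaphrased_index(tab_count: int, sampled_idx: str, i: int) -> str:
    # Step 4: f"{tab_count}.({sampled_index}).{i}", e.g. "1.(12.1).3"
    return f"{tab_count}.({sampled_idx}).{i}"

print(masked_index(1, 2))                    # -> 12
print(sampled_index("12", 1))                # -> 12.1
print(reparaphrased_index(1, "12.1", 3))     # -> 1.(12.1).3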
app.py
ADDED
@@ -0,0 +1,20 @@
import gradio as gr
from UI.gradio import create_gradio_interface
from pathlib import Path
from utils.config import load_config

project_root = Path(__file__).resolve().parent
config_path = project_root / "utils" / "config.yaml"
config = load_config(config_path)['PECCAVI_TEXT']

def main():
    """
    This function is the entry point for the PECCAVI Watermarking Model.

    It creates the Gradio interface for the model and runs it.
    """
    create_gradio_interface(config).launch()

if __name__ == "__main__":
    main()
environment.yml
ADDED
@@ -0,0 +1,245 @@
name: panda
channels:
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_gnu
  - asttokens=2.4.1=pyhd8ed1ab_0
  - bzip2=1.0.8=h5eee18b_6
  - ca-certificates=2024.8.30=hbcca054_0
  - comm=0.2.2=pyhd8ed1ab_0
  - debugpy=1.8.6=py310hf71b8c6_0
  - decorator=5.1.1=pyhd8ed1ab_0
  - exceptiongroup=1.2.2=pyhd8ed1ab_0
  - executing=2.1.0=pyhd8ed1ab_0
  - ipykernel=6.29.5=pyh3099207_0
  - ipython=8.27.0=pyh707e725_0
  - jedi=0.19.1=pyhd8ed1ab_0
  - jupyter_client=8.6.3=pyhd8ed1ab_0
  - jupyter_core=5.7.2=pyh31011fe_1
  - krb5=1.21.3=h143b758_0
  - ld_impl_linux-64=2.40=h12ee557_0
  - libedit=3.1.20230828=h5eee18b_0
  - libffi=3.4.4=h6a678d5_1
  - libgcc=14.1.0=h77fa898_1
  - libgcc-ng=14.1.0=h69a702a_1
  - libgomp=14.1.0=h77fa898_1
  - libsodium=1.0.20=h4ab18f5_0
  - libstdcxx=14.1.0=hc0a3c3a_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=1.41.5=h5eee18b_0
  - matplotlib-inline=0.1.7=pyhd8ed1ab_0
  - ncurses=6.4=h6a678d5_0
  - nest-asyncio=1.6.0=pyhd8ed1ab_0
  - openssl=3.3.2=hb9d3cd8_0
  - packaging=24.1=pyhd8ed1ab_0
  - parso=0.8.4=pyhd8ed1ab_0
  - pexpect=4.9.0=pyhd8ed1ab_0
  - pickleshare=0.7.5=py_1003
  - pip=24.2=py310h06a4308_0
  - platformdirs=4.3.6=pyhd8ed1ab_0
  - prompt-toolkit=3.0.48=pyha770c72_0
  - ptyprocess=0.7.0=pyhd3deb0d_0
  - pure_eval=0.2.3=pyhd8ed1ab_0
  - pygments=2.18.0=pyhd8ed1ab_0
  - python=3.10.14=h955ad1f_1
  - python_abi=3.10=2_cp310
  - pyzmq=26.2.0=py310h71f11fc_2
  - readline=8.2=h5eee18b_0
  - setuptools=75.1.0=py310h06a4308_0
  - sqlite=3.45.3=h5eee18b_0
  - stack_data=0.6.2=pyhd8ed1ab_0
  - tk=8.6.14=h39e8969_0
  - tornado=6.4.1=py310ha75aee5_1
  - traitlets=5.14.3=pyhd8ed1ab_0
  - typing_extensions=4.12.2=pyha770c72_0
  - wcwidth=0.2.13=pyhd8ed1ab_0
  - wheel=0.44.0=py310h06a4308_0
  - xz=5.4.6=h5eee18b_1
  - zeromq=4.3.5=ha4adb4c_5
  - zlib=1.2.13=h5eee18b_1
  - pip:
    - absl-py==2.1.0
    - accelerate==0.33.0
    - aiofiles==23.2.1
    - aiohappyeyeballs==2.3.5
    - aiohttp==3.10.3
    - aiosignal==1.3.1
    - altgraph==0.17.4
    - annotated-types==0.7.0
    - anyio==4.6.0
    - astunparse==1.6.3
    - async-timeout==4.0.3
    - attrs==24.2.0
    - av==12.0.0
    - backports-tarfile==1.2.0
    - beautifulsoup4==4.12.3
    - build==1.2.2
    - cachetools==5.5.0
    - certifi==2024.7.4
    - cffi==1.17.1
    - charset-normalizer==3.3.2
    - clean-fid==0.1.35
    - click==8.1.7
    - colorama==0.4.6
    - contextlib2==21.6.0
    - contourpy==1.2.1
    - cryptography==43.0.1
    - cycler==0.12.1
    - datasets==2.21.0
    - diffusers==0.27.2
    - dill==0.3.8
    - docker-pycreds==0.4.0
    - docutils==0.21.2
    - fastapi==0.115.0
    - ffmpy==0.4.0
    - filelock==3.15.4
    - flatbuffers==24.3.25
    - fonttools==4.53.1
    - frozenlist==1.4.1
    - fsspec==2024.6.1
    - gast==0.4.0
    - gdown==5.2.0
    - gitdb==4.0.11
    - gitpython==3.1.43
    - google-auth==2.35.0
    - google-auth-oauthlib==0.4.6
    - google-pasta==0.2.0
    - gradio==4.44.0
    - gradio-client==1.3.0
    - grpcio==1.65.4
    - h11==0.14.0
    - h5py==3.11.0
    - httpcore==1.0.6
    - httpx==0.27.2
    - huggingface-hub==0.25.2
    - idna==3.7
    - imageio==2.35.0
    - importlib-metadata==8.2.0
    - importlib-resources==6.4.5
    - jaraco-classes==3.4.0
    - jaraco-context==6.0.1
    - jaraco-functools==4.1.0
    - jeepney==0.8.0
    - jinja2==3.1.4
    - joblib==1.4.2
    - json-with-comments==1.2.7
    - keras==3.5.0
    - keras-preprocessing==1.1.2
    - keyring==25.4.1
    - kiwisolver==1.4.5
    - kornia==0.7.4
    - kornia-rs==0.1.7
    - lazy-loader==0.4
    - libclang==18.1.1
    - markdown==3.6
    - markdown-it-py==3.0.0
    - markupsafe==2.1.5
    - matplotlib==3.9.2
    - mdurl==0.1.2
    - ml-collections==0.1.1
    - ml-dtypes==0.4.0
    - more-itertools==10.5.0
    - multidict==6.0.5
    - multiprocess==0.70.16
    - namex==0.0.8
    - networkx==3.3
    - nh3==0.2.18
    - nltk==3.9.1
    - numpy==1.26.4
    - nvidia-cublas-cu11==11.10.3.66
    - nvidia-cuda-nvrtc-cu11==11.7.99
    - nvidia-cuda-runtime-cu11==11.7.99
    - nvidia-cudnn-cu11==8.5.0.96
    - oauthlib==3.2.2
    - opencv-python==4.10.0.84
    - opencv-python-headless==4.10.0.84
    - opt-einsum==3.3.0
    - optree==0.12.1
    - orjson==3.10.7
    - pandas==2.2.2
    - pillow==10.4.0
    - pkginfo==1.10.0
    - plotly==5.24.1
    - protobuf==4.25.5
    - psutil==5.9.8
    - pyarrow==17.0.0
    - pyasn1==0.6.1
    - pyasn1-modules==0.4.1
    - pycparser==2.22
    - pydantic==2.9.2
    - pydantic-core==2.23.4
    - pydub==0.25.1
    - pyinstaller==6.10.0
    - pyinstaller-hooks-contrib==2024.8
    - pyparsing==3.1.2
    - pyproject-hooks==1.1.0
    - pysocks==1.7.1
    - python-dateutil==2.9.0.post0
    - python-multipart==0.0.12
    - pytorch-msssim==1.0.0
    - pytorchcv==0.0.73
    - pytz==2023.3.post1
    - pyyaml==6.0.2
    - readme-renderer==44.0
    - regex==2024.7.24
    - requests==2.32.3
    - requests-oauthlib==2.0.0
    - requests-toolbelt==1.0.0
    - rfc3986==2.0.0
    - rich==13.7.1
    - rsa==4.9
    - ruff==0.6.9
    - safetensors==0.4.4
    - saliency==0.2.1
    - scikit-image==0.24.0
    - scikit-learn==1.6.0
    - scipy==1.14.0
    - secretstorage==3.3.3
    - semantic-version==2.10.0
    - sentence-transformers==3.3.1
    - sentry-sdk==2.15.0
    - setproctitle==1.3.3
    - shapely==2.0.5
    - shellingham==1.5.4
    - six==1.12.0
    - smmap==5.0.1
    - sniffio==1.3.1
    - soupsieve==2.6
    - spaces==0.30.2
    - starlette==0.38.6
    - tenacity==9.0.0
    - tensorboard==2.17.1
    - tensorboard-data-server==0.7.2
    - tensorboard-plugin-wit==1.8.1
    - tensorflow==2.17.0
    - tensorflow-estimator==2.10.0
    - tensorflow-hub==0.16.1
    - tensorflow-intel==0.0.1
    - tensorflow-io-gcs-filesystem==0.31.0
    - termcolor==1.1.0
    - tf-keras==2.17.0
    - threadpoolctl==3.5.0
    - tifffile==2024.8.10
    - timm==1.0.10
    - tokenizers==0.19.1
    - tomli==2.0.1
    - tomlkit==0.12.0
    - torch==1.13.1
    - torchvision==0.14.1
    - tqdm==4.66.5
    - transformers==4.43.3
    - twine==5.1.1
    - typer==0.12.5
    - tzdata==2024.1
    - urllib3==2.2.2
    - uvicorn==0.31.0
    - wandb==0.18.3
    - websockets==12.0
    - werkzeug==3.0.4
    - wrapt==1.11.2
    - xxhash==3.4.1
    - yarl==1.9.4
    - zipp==3.20.0
prefix: /home/ashhar21137/miniconda3/envs/panda
metrics/detectability.py
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Import necessary libraries
|
| 2 |
+
import nltk
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 7 |
+
from transformers import BertModel, BertTokenizer
|
| 8 |
+
from scipy import stats
|
| 9 |
+
|
| 10 |
+
# Download NLTK data if not already present
|
| 11 |
+
nltk.download('punkt', quiet=True)
|
| 12 |
+
detectability_val = {}
|
| 13 |
+
|
| 14 |
+
class SentenceDetectabilityCalculator:
|
| 15 |
+
"""
|
| 16 |
+
A class to calculate and analyze detectability metrics between an original sentence and paraphrased sentences.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def __init__(self, original_sentence, paraphrased_sentences):
|
| 20 |
+
"""
|
| 21 |
+
Initialize the calculator with the original sentence and a list of paraphrased sentences.
|
| 22 |
+
"""
|
| 23 |
+
self.original_sentence = original_sentence
|
| 24 |
+
self.paraphrased_sentences = paraphrased_sentences
|
| 25 |
+
|
| 26 |
+
# Raw metric dictionaries
|
| 27 |
+
self.z_scores = {}
|
| 28 |
+
self.p_values = {}
|
| 29 |
+
self.metric_values = []
|
| 30 |
+
|
| 31 |
+
# Normalized metric dictionaries
|
| 32 |
+
self.normalized_z_scores = {}
|
| 33 |
+
self.normalized_p_values = {}
|
| 34 |
+
|
| 35 |
+
# Combined detectability dictionary
|
| 36 |
+
self.combined_detectabilities = {}
|
| 37 |
+
|
| 38 |
+
# Load pre-trained BERT for embeddings
|
| 39 |
+
self.bert_model = BertModel.from_pretrained('bert-base-uncased')
|
| 40 |
+
self.bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
| 41 |
+
|
| 42 |
+
def calculate_all_metrics(self):
|
| 43 |
+
"""
|
| 44 |
+
Calculate detectability metrics for each paraphrased sentence.
|
| 45 |
+
"""
|
| 46 |
+
original_embedding = self._get_sentence_embedding(self.original_sentence)
|
| 47 |
+
|
| 48 |
+
# First, compute the metric values (cosine similarities)
|
| 49 |
+
for idx, paraphrased_sentence in enumerate(self.paraphrased_sentences):
|
| 50 |
+
paraphrase_embedding = self._get_sentence_embedding(paraphrased_sentence)
|
| 51 |
+
cosine_sim = cosine_similarity([original_embedding], [paraphrase_embedding])[0][0]
|
| 52 |
+
self.metric_values.append(cosine_sim)
|
| 53 |
+
|
| 54 |
+
# Compute mean and standard deviation of the metric values
|
| 55 |
+
metric_mean = np.mean(self.metric_values)
|
| 56 |
+
metric_std = np.std(self.metric_values)
|
| 57 |
+
|
| 58 |
+
# Compute z-scores and p-values
|
| 59 |
+
for idx, (paraphrased_sentence, metric_value) in enumerate(zip(self.paraphrased_sentences, self.metric_values)):
|
| 60 |
+
key = f"Sentence_{idx+1}"
|
| 61 |
+
z_score = (metric_value - metric_mean) / metric_std if metric_std != 0 else 0.0
|
| 62 |
+
p_value = stats.norm.sf(abs(z_score)) * 2 # two-tailed p-value
|
| 63 |
+
self.z_scores[key] = z_score
|
| 64 |
+
self.p_values[key] = p_value
|
| 65 |
+
|
| 66 |
+
def normalize_metrics(self):
|
| 67 |
+
"""
|
| 68 |
+
Normalize z-scores and p-values to be between 0 and 1.
|
| 69 |
+
"""
|
| 70 |
+
self.normalized_z_scores = self._normalize_dict(self.z_scores)
|
| 71 |
+
self.normalized_p_values = self._normalize_dict(self.p_values)
|
| 72 |
+
|
| 73 |
+
def calculate_combined_detectability(self):
|
| 74 |
+
"""
|
| 75 |
+
Calculate the combined detectability using the root mean square of the normalized metrics.
|
| 76 |
+
"""
|
| 77 |
+
for key in self.normalized_z_scores.keys():
|
| 78 |
+
rms = np.sqrt(
|
| 79 |
+
(
|
| 80 |
+
self.normalized_z_scores[key] ** 2 +
|
| 81 |
+
self.normalized_p_values[key] ** 2
|
| 82 |
+
) / 2
|
| 83 |
+
)
|
| 84 |
+
self.combined_detectabilities[key] = rms
|
| 85 |
+
|
| 86 |
+
def plot_metrics(self):
|
| 87 |
+
"""
|
| 88 |
+
Plot each normalized metric and the combined detectability in separate graphs.
|
| 89 |
+
"""
|
| 90 |
+
keys = list(self.normalized_z_scores.keys())
|
| 91 |
+
indices = np.arange(len(keys))
|
| 92 |
+
|
| 93 |
+
# Prepare data for plotting
|
| 94 |
+
metrics = {
|
| 95 |
+
'Z-Score': [self.normalized_z_scores[key] for key in keys],
|
| 96 |
+
'P-Value': [self.normalized_p_values[key] for key in keys],
|
| 97 |
+
'Combined Detectability': [self.combined_detectabilities[key] for key in keys]
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
# Plot each metric separately
|
| 101 |
+
for metric_name, values in metrics.items():
|
| 102 |
+
plt.figure(figsize=(12, 6))
|
| 103 |
+
plt.plot(indices, values, marker='o', color=np.random.rand(3,))
|
| 104 |
+
plt.xlabel('Sentence Index')
|
| 105 |
+
plt.ylabel('Normalized Value (0-1)')
|
| 106 |
+
plt.title(f'Normalized {metric_name}')
|
| 107 |
+
plt.grid(True)
|
| 108 |
+
plt.tight_layout()
|
| 109 |
+
plt.show()
|
| 110 |
+
|
| 111 |
+
# Private methods
|
| 112 |
+
def _get_sentence_embedding(self, sentence):
|
| 113 |
+
"""
|
| 114 |
+
Get sentence embedding using BERT.
|
| 115 |
+
"""
|
| 116 |
+
tokens = self.bert_tokenizer(sentence, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
| 117 |
+
with torch.no_grad():
|
| 118 |
+
outputs = self.bert_model(**tokens)
|
| 119 |
+
return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
|
| 120 |
+
|
| 121 |
+
def _normalize_dict(self, metric_dict):
|
| 122 |
+
"""
|
| 123 |
+
Normalize the values in a dictionary to be between 0 and 1.
|
| 124 |
+
"""
|
| 125 |
+
values = np.array(list(metric_dict.values()))
|
| 126 |
+
min_val = values.min()
|
| 127 |
+
max_val = values.max()
|
| 128 |
+
# Avoid division by zero if all values are the same
|
| 129 |
+
if max_val - min_val == 0:
|
| 130 |
+
normalized_values = np.zeros_like(values)
|
| 131 |
+
else:
|
| 132 |
+
normalized_values = (values - min_val) / (max_val - min_val)
|
| 133 |
+
return dict(zip(metric_dict.keys(), normalized_values))
|
| 134 |
+
|
| 135 |
+
# Getter methods
|
| 136 |
+
def get_normalized_metrics(self):
|
| 137 |
+
"""
|
| 138 |
+
Get all normalized metrics as a dictionary.
|
| 139 |
+
"""
|
| 140 |
+
return {
|
| 141 |
+
'Z-Score': self.normalized_z_scores,
|
| 142 |
+
'P-Value': self.normalized_p_values
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
def get_combined_detectabilities(self):
|
| 146 |
+
"""
|
| 147 |
+
Get the dictionary of combined detectability values.
|
| 148 |
+
"""
|
| 149 |
+
return self.combined_detectabilities
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# Example usage
|
| 153 |
+
if __name__ == "__main__":
|
| 154 |
+
# Original sentence
|
| 155 |
+
original_sentence = "The quick brown fox jumps over the lazy dog"
|
| 156 |
+
|
| 157 |
+
# Paraphrased sentences
|
| 158 |
+
paraphrased_sentences = [
|
| 159 |
+
# Original 1: "A swift auburn fox leaps across a sleepy canine."
|
| 160 |
+
"The swift auburn fox leaps across a sleepy canine.",
|
| 161 |
+
"A quick auburn fox leaps across a sleepy canine.",
|
| 162 |
+
"A swift ginger fox leaps across a sleepy canine.",
|
| 163 |
+
"A swift auburn fox bounds across a sleepy canine.",
|
| 164 |
+
"A swift auburn fox leaps across a tired canine.",
|
| 165 |
+
"Three swift auburn foxes leap across a sleepy canine.",
|
| 166 |
+
"The vulpine specimen rapidly traverses over a dormant dog.",
|
| 167 |
+
"Like lightning, the russet hunter soars over the drowsy guardian.",
|
| 168 |
+
"Tha quick ginger fox jumps o'er the lazy hound, ye ken.",
|
| 169 |
+
"One rapid Vulpes vulpes traverses the path of a quiescent canine.",
|
| 170 |
+
"A swift auburn predator navigates across a lethargic pet.",
|
| 171 |
+
"Subject A (fox) demonstrates velocity over Subject B (dog).",
|
| 172 |
+
|
| 173 |
+
# Original 2: "The agile russet fox bounds over an idle hound."
|
| 174 |
+
"Some agile russet foxes bound over an idle hound.",
|
| 175 |
+
"The nimble russet fox bounds over an idle hound.",
|
| 176 |
+
"The agile brown fox bounds over an idle hound.",
|
| 177 |
+
"The agile russet fox jumps over an idle hound.",
|
| 178 |
+
"The agile russet fox bounds over a lazy hound.",
|
| 179 |
+
"Two agile russet foxes bound over an idle hound.",
|
| 180 |
+
"A dexterous vulpine surpasses a stationary canine.",
|
| 181 |
+
"Quick as thought, the copper warrior sails over the guardian.",
|
| 182 |
+
"Tha nimble reddish fox jumps o'er the doggo, don't ya know.",
|
| 183 |
+
"A dexterous V. vulpes exceeds the plane of an inactive canine.",
|
| 184 |
+
"An agile russet hunter maneuvers above a resting hound.",
|
| 185 |
+
"Test subject F-1 achieves displacement superior to subject D-1.",
|
| 186 |
+
|
| 187 |
+
# Original 3: "A nimble mahogany vulpine vaults above a drowsy dog."
|
| 188 |
+
"The nimble mahogany vulpine vaults above a drowsy dog.",
|
| 189 |
+
"A swift mahogany vulpine vaults above a drowsy dog.",
|
| 190 |
+
"A nimble reddish vulpine vaults above a drowsy dog.",
|
| 191 |
+
"A nimble mahogany fox vaults above a drowsy dog.",
|
| 192 |
+
"A nimble mahogany vulpine leaps above a drowsy dog.",
|
| 193 |
+
"Four nimble mahogany vulpines vault above a drowsy dog.",
|
| 194 |
+
"An agile specimen of reddish fur surpasses a somnolent canine.",
|
| 195 |
+
"Fleet as wind, the earth-toned hunter soars over the sleepy guard.",
|
| 196 |
+
"Tha quick brown beastie jumps o'er the tired pup, aye.",
|
| 197 |
+
"Single V. vulpes demonstrates vertical traverse over C. familiaris.",
|
| 198 |
+
"A nimble rust-colored predator crosses above a drowsy pet.",
|
| 199 |
+
"Observed: Subject Red executes vertical motion over Subject Gray.",
|
| 200 |
+
|
| 201 |
+
# Original 4: "The speedy copper-colored fox hops over the lethargic pup."
|
| 202 |
+
"A speedy copper-colored fox hops over the lethargic pup.",
|
| 203 |
+
"The quick copper-colored fox hops over the lethargic pup.",
|
| 204 |
+
"The speedy bronze fox hops over the lethargic pup.",
|
| 205 |
+
"The speedy copper-colored fox jumps over the lethargic pup.",
|
| 206 |
+
"The speedy copper-colored fox hops over the tired pup.",
|
| 207 |
+
"Multiple speedy copper-colored foxes hop over the lethargic pup.",
|
| 208 |
+
"A rapid vulpine of bronze hue traverses an inactive young canine.",
|
| 209 |
+
"Swift as a dart, the metallic hunter bounds over the lazy puppy.",
|
| 210 |
+
"Tha fast copper beastie leaps o'er the sleepy wee dog.",
|
| 211 |
+
"1 rapid V. vulpes crosses above 1 juvenile C. familiaris.",
|
| 212 |
+
"A fleet copper-toned predator moves past a sluggish young dog.",
|
    "Field note: Adult fox subject exceeds puppy subject vertically.",

    # Original 5: "A rapid tawny fox springs over a sluggish dog."
    "The rapid tawny fox springs over a sluggish dog.",
    "A quick tawny fox springs over a sluggish dog.",
    "A rapid golden fox springs over a sluggish dog.",
    "A rapid tawny fox jumps over a sluggish dog.",
    "A rapid tawny fox springs over a lazy dog.",
    "Six rapid tawny foxes spring over a sluggish dog.",
    "An expeditious yellowish vulpine surpasses a torpid canine.",
    "Fast as a bullet, the golden hunter vaults over the idle guard.",
    "Tha swift yellowy fox jumps o'er the lazy mutt, aye.",
    "One V. vulpes displays rapid transit over one inactive C. familiaris.",
    "A speedy yellow-brown predator bypasses a motionless dog.",
    "Log entry: Vulpine subject achieves swift vertical displacement.",

    # Original 6: "The fleet-footed chestnut fox soars above an indolent canine."
    "A fleet-footed chestnut fox soars above an indolent canine.",
    "The swift chestnut fox soars above an indolent canine.",
    "The fleet-footed brown fox soars above an indolent canine.",
    "The fleet-footed chestnut fox leaps above an indolent canine.",
    "The fleet-footed chestnut fox soars above a lazy canine.",
    "Several fleet-footed chestnut foxes soar above an indolent canine.",
    "A rapid brown vulpine specimen traverses a lethargic domestic dog.",
    "Graceful as a bird, the nutbrown hunter flies over the lazy guard.",
    "Tha quick brown beastie sails o'er the sleepy hound, ken.",
    "Single agile V. vulpes achieves elevation above stationary canine.",
    "A nimble brown predator glides over an unmoving domestic animal.",
    "Research note: Brown subject displays superior vertical mobility.",

    # Original 7: "A fast ginger fox hurdles past a slothful dog."
    "The fast ginger fox hurdles past a slothful dog.",
    "A quick ginger fox hurdles past a slothful dog.",
    "A fast red fox hurdles past a slothful dog.",
    "A fast ginger fox jumps past a slothful dog.",
    "A fast ginger fox hurdles past a lazy dog.",
    "Five fast ginger foxes hurdle past a slothful dog.",
    "A rapid orange vulpine bypasses a lethargic canine.",
    "Quick as lightning, the flame-colored hunter races past the lazy guard.",
    "Tha swift ginger beastie leaps past the tired doggy, ye see.",
    "1 rapid orange V. vulpes surpasses 1 inactive C. familiaris.",
    "A speedy red-orange predator overtakes a motionless dog.",
    "Data point: Orange subject demonstrates rapid transit past Gray subject.",

    # Original 8: "The spry rusty-colored fox jumps across a dozing hound."
    "A spry rusty-colored fox jumps across a dozing hound.",
    "The agile rusty-colored fox jumps across a dozing hound.",
    "The spry reddish fox jumps across a dozing hound.",
    "The spry rusty-colored fox leaps across a dozing hound.",
    "The spry rusty-colored fox jumps across a sleeping hound.",
    "Multiple spry rusty-colored foxes jump across a dozing hound.",
    "An agile rust-toned vulpine traverses a somnolent canine.",
    "Nimble as thought, the copper hunter bounds over the resting guard.",
    "Tha lively rust-colored beastie hops o'er the snoozin' hound.",
    "Single dexterous V. vulpes crosses path of dormant C. familiaris.",
    "A lithe rust-tinted predator moves past a slumbering dog.",
    "Observation: Russet subject exhibits agility over dormant subject.",

    # Original 9: "A quick tan fox leaps over an inactive dog."
    "The quick tan fox leaps over an inactive dog.",
    "A swift tan fox leaps over an inactive dog.",
    "A quick beige fox leaps over an inactive dog.",
    "A quick tan fox jumps over an inactive dog.",
    "A quick tan fox leaps over a motionless dog.",
    "Seven quick tan foxes leap over an inactive dog.",
    "A rapid light-brown vulpine surpasses a stationary canine.",
    "Fast as wind, the sand-colored hunter soars over the still guard.",
    "Tha nimble tan beastie jumps o'er the quiet doggy, aye.",
    "One agile fawn V. vulpes traverses one immobile C. familiaris.",
    "A fleet tan-colored predator bypasses an unmoving dog.",
    "Field report: Tan subject demonstrates movement over static subject.",

    # Original 10: "The brisk auburn vulpine bounces over a listless canine."
    "Some brisk auburn vulpines bounce over a listless canine.",
    "The quick auburn vulpine bounces over a listless canine.",
    "The brisk russet vulpine bounces over a listless canine.",
    "The brisk auburn fox bounces over a listless canine.",
    "The brisk auburn vulpine jumps over a listless canine.",
    "Five brisk auburn vulpines bounce over a listless canine.",
    "The expeditious specimen supersedes a quiescent Canis lupus.",
    "Swift as wind, the russet hunter vaults over the idle guardian.",
    "Tha quick ginger beastie hops o'er the lazy mutt, aye.",
    "One V. vulpes achieves displacement over inactive C. familiaris.",
    "A high-velocity auburn predator traverses an immobile animal.",
    "Final observation: Red subject shows mobility over Gray subject."
]

# Initialize the calculator
calculator = SentenceDetectabilityCalculator(original_sentence, paraphrased_sentences)

# Calculate all metrics
calculator.calculate_all_metrics()

# Normalize the metrics
calculator.normalize_metrics()

# Calculate combined detectability
calculator.calculate_combined_detectability()

# Retrieve the normalized metrics and combined detectabilities
normalized_metrics = calculator.get_normalized_metrics()
combined_detectabilities = calculator.get_combined_detectabilities()
detectability_val = combined_detectabilities

# Display the results
print("\nCombined Detectabilities:")
for each in combined_detectabilities.items():
    print(f"{each[1]}")

# Plot the metrics (optional)
# calculator.plot_metrics()
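The example above only prints the combined scores. As a minimal, hypothetical follow-up sketch (not part of the uploaded file), the `Sentence_i` keys returned by `get_combined_detectabilities()` could be ranked to pick out the most and least detectable paraphrases; the dictionary values here are made up for illustration:

# Hypothetical post-processing sketch: rank paraphrases by combined detectability.
# Assumes a dict shaped like the calculator's output, e.g. {"Sentence_1": 0.42, ...}.
combined_detectabilities = {"Sentence_1": 0.42, "Sentence_2": 0.77, "Sentence_3": 0.13}

ranked = sorted(combined_detectabilities.items(), key=lambda kv: kv[1], reverse=True)
most_key, most_score = ranked[0]
least_key, least_score = ranked[-1]

print(f"Most detectable:  {most_key} ({most_score:.2f})")
print(f"Least detectable: {least_key} ({least_score:.2f})")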
metrics/distortion.py
ADDED
@@ -0,0 +1,370 @@
import os
import sys
from tqdm import tqdm
import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from bert_score import BERTScorer
from bert_score.utils import model2layers
from nltk.tokenize import word_tokenize
from Levenshtein import distance as levenshtein_distance
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from config.config import load_config

config_path = os.path.join(os.path.dirname(__file__), '..', 'config', 'config.yaml')
config = load_config(config_path)['PECCAVI_TEXT']['Metrics']


class SentenceDistortionCalculator:
    """
    A class to calculate and analyze distortion metrics between an original sentence and modified sentences.
    """

    def __init__(self, config, original_sentence, paraphrased_sentences):
        """
        Initialize the calculator with the original sentence and a list of modified sentences.
        """
        self.original_sentence = original_sentence
        self.paraphrased_sentences = paraphrased_sentences

        self.levenshtein_distances = {}
        self.bert_scores = {}
        self.mover_scores = {}

        self.normalized_levenshtein = {}
        self.normalized_bert_scores = {}
        self.normalized_mover_scores = {}
        self.combined_distortions = {}

        self.tokenizer = GPT2TokenizerFast.from_pretrained(config['Distortion'])
        self.model = GPT2LMHeadModel.from_pretrained(config['Distortion'])
        self.model.eval()

    def calculate_all_metrics(self):
        """
        Calculate all distortion metrics for each modified sentence.
        """
        for idx, modified_sentence in tqdm(enumerate(self.paraphrased_sentences), total=len(self.paraphrased_sentences), desc="Calculating Metrics"):
            key = f"Sentence_{idx+1}"
            self.levenshtein_distances[key] = self._calculate_levenshtein_distance(modified_sentence)
            self.bert_scores[key] = self._calculate_bert_score(modified_sentence)
            self.mover_scores[key] = self._calculate_mover_score(modified_sentence)

    def normalize_metrics(self):
        """
        Normalize all metrics to be between 0 and 1.
        """
        for _ in tqdm(range(1), desc="Normalizing Metrics"):
            self.normalized_levenshtein = self._normalize_dict(self.levenshtein_distances)
            self.normalized_bert_scores = self._normalize_dict(self.bert_scores)
            self.normalized_mover_scores = self._normalize_dict(self.mover_scores)

    def calculate_combined_distortion(self):
        """
        Calculate the combined distortion using the root mean square of the normalized metrics.
        """
        for _ in tqdm(range(1), desc="Calculating Combined Distortion"):
            for key in self.normalized_levenshtein.keys():
                rms = np.sqrt(
                    (
                        self.normalized_levenshtein[key] ** 2 +
                        self.normalized_bert_scores[key] ** 2 +
                        self.normalized_mover_scores[key] ** 2
                    ) / 3
                )
                self.combined_distortions[key] = rms

    def plot_metrics(self):
        """
        Plot each normalized metric and the combined distortion in separate graphs.
        """
        keys = list(self.normalized_levenshtein.keys())
        indices = np.arange(len(keys))

        # Prepare data for plotting
        metrics = {
            'Levenshtein Distance': [self.normalized_levenshtein[key] for key in keys],
            'BERTScore': [self.normalized_bert_scores[key] for key in keys],
            'MOVERscore': [self.normalized_mover_scores[key] for key in keys],
            'Combined Distortion': [self.combined_distortions[key] for key in keys]
        }

        # Plot each metric separately
        for metric_name, values in tqdm(metrics.items(), desc="Plotting Metrics"):
            plt.figure(figsize=(12, 6))
            plt.plot(indices, values, marker='o', color=np.random.rand(3,))
            plt.xlabel('Sentence Index')
            plt.ylabel('Normalized Value (0-1)')
            plt.title(f'Normalized {metric_name}')
            plt.grid(True)
            plt.tight_layout()
            plt.show()

    def _calculate_levenshtein_distance(self, modified_sentence):
        """
        Calculate the word-level Levenshtein distance between the original and modified sentence.
        """
        words1 = word_tokenize(self.original_sentence)
        words2 = word_tokenize(modified_sentence)
        lev_distance = levenshtein_distance(words1, words2)
        return lev_distance / max(len(words1), len(words2))

    def _calculate_bert_score(self, modified_sentence):
        """
        Compute the BERTScore similarity between the original and modified sentence.
        Returns 1 - F1 score to represent dissimilarity.
        """
        if not hasattr(self, 'original_sentence'):
            raise ValueError("original_sentence is not set. Please set self.original_sentence before calling this function.")
        if not isinstance(modified_sentence, str):
            raise ValueError("modified_sentence must be a string.")

        model_type = "microsoft/deberta-xlarge-mnli"
        num_layers = model2layers[model_type]

        if not hasattr(self, "cached_bertscorer"):
            self.cached_bertscorer = BERTScorer(
                model_type=model_type,
                num_layers=num_layers,
                batch_size=1,  # Single sentence comparison
                nthreads=4,
                all_layers=False,
                idf=False,
                device="cuda" if torch.cuda.is_available() else "cpu",
                lang="en"
            )

        # Compute BERTScore
        _, _, F1 = self.cached_bertscorer.score(
            cands=[modified_sentence],
            refs=[self.original_sentence],
            verbose=False,
            batch_size=1
        )

        return 1 - F1.item()  # Return dissimilarity score

    def _calculate_mover_score(self, modified_sentence, model_name='all-MiniLM-L6-v2'):
        """Compute MoverScore correctly using word-level embeddings."""
        if not self.original_sentence:
            raise ValueError("Original sentence not provided.")

        # Tokenize sentences
        original_tokens = self.original_sentence.split()
        modified_tokens = modified_sentence.split()
        model = SentenceTransformer(model_name)

        # Compute word embeddings
        original_embeddings = model.encode(original_tokens, convert_to_numpy=True)
        modified_embeddings = model.encode(modified_tokens, convert_to_numpy=True)

        # Compute cost matrix (cosine distance)
        cost_matrix = cdist(original_embeddings, modified_embeddings, metric='cosine')

        # Solve optimal transport problem (Hungarian Algorithm)
        row_ind, col_ind = linear_sum_assignment(cost_matrix)

        # Compute IDF weights
        vectorizer = TfidfVectorizer()
        vectorizer.fit([self.original_sentence, modified_sentence])
        idf_values = dict(zip(vectorizer.get_feature_names_out(), vectorizer.idf_))

        # Apply IDF weighting to aligned word pairs
        idf_weights_original = np.array([idf_values.get(word.lower(), 1.0) for word in original_tokens])
        idf_weights_modified = np.array([idf_values.get(word.lower(), 1.0) for word in modified_tokens])
        combined_idf_weights = (idf_weights_original[row_ind] + idf_weights_modified[col_ind]) / 2
        weighted_score = np.sum((1 - cost_matrix[row_ind, col_ind]) * combined_idf_weights) / np.sum(combined_idf_weights)

        return 1 - weighted_score  # Higher score = more dissimilar

    def _normalize_dict(self, metric_dict):
        """
        Normalize the values in a dictionary to be between 0 and 1.
        """
        values = np.array(list(metric_dict.values()))
        min_val = values.min()
        max_val = values.max()
        if max_val - min_val == 0:
            normalized_values = np.zeros_like(values)
        else:
            normalized_values = (values - min_val) / (max_val - min_val)
        return dict(zip(metric_dict.keys(), normalized_values))

    def get_normalized_metrics(self):
        """
        Get all normalized metrics as a dictionary.
        """
        return {
            'Min Edit Distance': self.normalized_levenshtein,
            'BERTScore': self.normalized_bert_scores,
            'Mover Score': self.normalized_mover_scores
        }

    def get_combined_distortions(self):
        """
        Get the dictionary of combined distortion values.
        """
        return self.combined_distortions


# Example usage
if __name__ == "__main__":

    config = load_config(config_path)['PECCAVI_TEXT']['Metrics']

    # Original sentence
    original_sentence = "The quick brown fox jumps over the lazy dog"

    # Paraphrased sentences
    paraphrased_sentences = [
        # Original 1: "A swift auburn fox leaps across a sleepy canine."
        "The swift auburn fox leaps across a sleepy canine.",
        "A quick auburn fox leaps across a sleepy canine.",
        "A swift ginger fox leaps across a sleepy canine.",
        "A swift auburn fox bounds across a sleepy canine.",
        "A swift auburn fox leaps across a tired canine.",
        "Three swift auburn foxes leap across a sleepy canine.",
        "The vulpine specimen rapidly traverses over a dormant dog.",
        "Like lightning, the russet hunter soars over the drowsy guardian.",
        "Tha quick ginger fox jumps o'er the lazy hound, ye ken.",
        "One rapid Vulpes vulpes traverses the path of a quiescent canine.",
        "A swift auburn predator navigates across a lethargic pet.",
        "Subject A (fox) demonstrates velocity over Subject B (dog).",

        # Original 2: "The agile russet fox bounds over an idle hound."
        "Some agile russet foxes bound over an idle hound.",
        "The nimble russet fox bounds over an idle hound.",
        "The agile brown fox bounds over an idle hound.",
        "The agile russet fox jumps over an idle hound.",
        "The agile russet fox bounds over a lazy hound.",
        "Two agile russet foxes bound over an idle hound.",
        "A dexterous vulpine surpasses a stationary canine.",
        "Quick as thought, the copper warrior sails over the guardian.",
        "Tha nimble reddish fox jumps o'er the doggo, don't ya know.",
        "A dexterous V. vulpes exceeds the plane of an inactive canine.",
        "An agile russet hunter maneuvers above a resting hound.",
        "Test subject F-1 achieves displacement superior to subject D-1.",

        # Original 3: "A nimble mahogany vulpine vaults above a drowsy dog."
        "The nimble mahogany vulpine vaults above a drowsy dog.",
        "A swift mahogany vulpine vaults above a drowsy dog.",
        "A nimble reddish vulpine vaults above a drowsy dog.",
        "A nimble mahogany fox vaults above a drowsy dog.",
        "A nimble mahogany vulpine leaps above a drowsy dog.",
        "Four nimble mahogany vulpines vault above a drowsy dog.",
        "An agile specimen of reddish fur surpasses a somnolent canine.",
        "Fleet as wind, the earth-toned hunter soars over the sleepy guard.",
        "Tha quick brown beastie jumps o'er the tired pup, aye.",
        "Single V. vulpes demonstrates vertical traverse over C. familiaris.",
        "A nimble rust-colored predator crosses above a drowsy pet.",
        "Observed: Subject Red executes vertical motion over Subject Gray.",

        # Original 4: "The speedy copper-colored fox hops over the lethargic pup."
        "A speedy copper-colored fox hops over the lethargic pup.",
        "The quick copper-colored fox hops over the lethargic pup.",
        "The speedy bronze fox hops over the lethargic pup.",
        "The speedy copper-colored fox jumps over the lethargic pup.",
        "The speedy copper-colored fox hops over the tired pup.",
        "Multiple speedy copper-colored foxes hop over the lethargic pup.",
        "A rapid vulpine of bronze hue traverses an inactive young canine.",
        "Swift as a dart, the metallic hunter bounds over the lazy puppy.",
        "Tha fast copper beastie leaps o'er the sleepy wee dog.",
        "1 rapid V. vulpes crosses above 1 juvenile C. familiaris.",
        "A fleet copper-toned predator moves past a sluggish young dog.",
        "Field note: Adult fox subject exceeds puppy subject vertically.",

        # Original 5: "A rapid tawny fox springs over a sluggish dog."
        "The rapid tawny fox springs over a sluggish dog.",
        "A quick tawny fox springs over a sluggish dog.",
        "A rapid golden fox springs over a sluggish dog.",
        "A rapid tawny fox jumps over a sluggish dog.",
        "A rapid tawny fox springs over a lazy dog.",
        "Six rapid tawny foxes spring over a sluggish dog.",
        "An expeditious yellowish vulpine surpasses a torpid canine.",
        "Fast as a bullet, the golden hunter vaults over the idle guard.",
        "Tha swift yellowy fox jumps o'er the lazy mutt, aye.",
        "One V. vulpes displays rapid transit over one inactive C. familiaris.",
        "A speedy yellow-brown predator bypasses a motionless dog.",
        "Log entry: Vulpine subject achieves swift vertical displacement.",

        # Original 6: "The fleet-footed chestnut fox soars above an indolent canine."
        "A fleet-footed chestnut fox soars above an indolent canine.",
        "The swift chestnut fox soars above an indolent canine.",
        "The fleet-footed brown fox soars above an indolent canine.",
        "The fleet-footed chestnut fox leaps above an indolent canine.",
        "The fleet-footed chestnut fox soars above a lazy canine.",
        "Several fleet-footed chestnut foxes soar above an indolent canine.",
        "A rapid brown vulpine specimen traverses a lethargic domestic dog.",
        "Graceful as a bird, the nutbrown hunter flies over the lazy guard.",
        "Tha quick brown beastie sails o'er the sleepy hound, ken.",
        "Single agile V. vulpes achieves elevation above stationary canine.",
        "A nimble brown predator glides over an unmoving domestic animal.",
        "Research note: Brown subject displays superior vertical mobility.",

        # Original 7: "A fast ginger fox hurdles past a slothful dog."
        "The fast ginger fox hurdles past a slothful dog.",
        "A quick ginger fox hurdles past a slothful dog.",
        "A fast red fox hurdles past a slothful dog.",
        "A fast ginger fox jumps past a slothful dog.",
        "A fast ginger fox hurdles past a lazy dog.",
        "Five fast ginger foxes hurdle past a slothful dog.",
        "A rapid orange vulpine bypasses a lethargic canine.",
        "Quick as lightning, the flame-colored hunter races past the lazy guard.",
        "Tha swift ginger beastie leaps past the tired doggy, ye see.",
        "1 rapid orange V. vulpes surpasses 1 inactive C. familiaris.",
        "A speedy red-orange predator overtakes a motionless dog.",
        "Data point: Orange subject demonstrates rapid transit past Gray subject.",

        # Original 8: "The spry rusty-colored fox jumps across a dozing hound."
        "A spry rusty-colored fox jumps across a dozing hound.",
        "The agile rusty-colored fox jumps across a dozing hound.",
        "The spry reddish fox jumps across a dozing hound.",
        "The spry rusty-colored fox leaps across a dozing hound.",
        "The spry rusty-colored fox jumps across a sleeping hound.",
        "Multiple spry rusty-colored foxes jump across a dozing hound.",
        "An agile rust-toned vulpine traverses a somnolent canine.",
        "Nimble as thought, the copper hunter bounds over the resting guard.",
        "Tha lively rust-colored beastie hops o'er the snoozin' hound.",
        "Single dexterous V. vulpes crosses path of dormant C. familiaris.",
        "A lithe rust-tinted predator moves past a slumbering dog.",
        "Observation: Russet subject exhibits agility over dormant subject.",

        # Original 9: "A quick tan fox leaps over an inactive dog."
        "The quick tan fox leaps over an inactive dog.",
        "A swift tan fox leaps over an inactive dog.",
        "A quick beige fox leaps over an inactive dog.",
        "A quick tan fox jumps over an inactive dog.",
        "A quick tan fox leaps over a motionless dog.",
        "Seven quick tan foxes leap over an inactive dog.",
        "A rapid light-brown vulpine surpasses a stationary canine.",
        "Fast as wind, the sand-colored hunter soars over the still guard.",
        "Tha nimble tan beastie jumps o'er the quiet doggy, aye.",
        "One agile fawn V. vulpes traverses one immobile C. familiaris.",
        "A fleet tan-colored predator bypasses an unmoving dog.",
        "Field report: Tan subject demonstrates movement over static subject.",

        # Original 10: "The brisk auburn vulpine bounces over a listless canine."
        "Some brisk auburn vulpines bounce over a listless canine.",
        "The quick auburn vulpine bounces over a listless canine.",
        "The brisk russet vulpine bounces over a listless canine.",
        "The brisk auburn fox bounces over a listless canine.",
        "The brisk auburn vulpine jumps over a listless canine.",
        "Five brisk auburn vulpines bounce over a listless canine.",
        "The expeditious specimen supersedes a quiescent Canis lupus.",
        "Swift as wind, the russet hunter vaults over the idle guardian.",
        "Tha quick ginger beastie hops o'er the lazy mutt, aye.",
        "One V. vulpes achieves displacement over inactive C. familiaris.",
        "A high-velocity auburn predator traverses an immobile animal.",
        "Final observation: Red subject shows mobility over Gray subject."
    ]

    distortion_calculator = SentenceDistortionCalculator(config, original_sentence, paraphrased_sentences)
    for _ in tqdm(range(1)):
        distortion_calculator.calculate_all_metrics()
        distortion_calculator.normalize_metrics()
        distortion_calculator.calculate_combined_distortion()

    distortion_calculator.plot_metrics()
    print("Normalized Metrics:", distortion_calculator.get_normalized_metrics())
    print("Combined Distortion:", distortion_calculator.get_combined_distortions())
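For reference, the combined distortion computed above is simply the root mean square of the three normalized metrics. A standalone sketch of that combination on made-up values (independent of the class; the numbers are illustrative only, not output of the uploaded code):

import numpy as np

# Illustrative normalized metric values for one paraphrase (made-up numbers)
normalized_levenshtein = 0.30
normalized_bert = 0.10
normalized_mover = 0.20

# Same combination as SentenceDistortionCalculator.calculate_combined_distortion:
# RMS = sqrt((lev^2 + bert^2 + mover^2) / 3)
combined = np.sqrt((normalized_levenshtein ** 2 + normalized_bert ** 2 + normalized_mover ** 2) / 3)
print(round(float(combined), 4))  # ~0.216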
renderers/__pycache__/highlighter.cpython-310.pyc
ADDED
Binary file (4.98 kB). View file

renderers/__pycache__/highlighter.cpython-311.pyc
ADDED
Binary file (8.07 kB). View file

renderers/__pycache__/plot_3d.cpython-310.pyc
ADDED
Binary file (4.34 kB). View file

renderers/__pycache__/plot_3d.cpython-311.pyc
ADDED
Binary file (6 kB). View file

renderers/__pycache__/tree.cpython-310.pyc
ADDED
Binary file (10.6 kB). View file

renderers/__pycache__/tree.cpython-311.pyc
ADDED
Binary file (21.1 kB). View file

renderers/highlighter.py
ADDED
@@ -0,0 +1,185 @@
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
def create_matching_pattern(word):
|
| 4 |
+
"""Creates appropriate regex pattern based on word characteristics"""
|
| 5 |
+
escaped_word = re.escape(word)
|
| 6 |
+
|
| 7 |
+
# Check for special cases
|
| 8 |
+
if any(char in word for char in '&-/.\'()[]$€£¥+') or ' ' in word:
|
| 9 |
+
# Special handling for phrases with special characters or spaces
|
| 10 |
+
return rf'{escaped_word}'
|
| 11 |
+
elif word.endswith('%'):
|
| 12 |
+
# Special handling for percentage values
|
| 13 |
+
numeric_part = word[:-1]
|
| 14 |
+
return rf'\b{re.escape(numeric_part)}\s*%'
|
| 15 |
+
elif re.search(r'[0-9]', word) and re.search(r'[a-zA-Z]', word):
|
| 16 |
+
# Special handling for alphanumeric combinations
|
| 17 |
+
return rf'{escaped_word}'
|
| 18 |
+
else:
|
| 19 |
+
# Standard word boundary pattern for simple words
|
| 20 |
+
return rf'\b{escaped_word}\b'
|
| 21 |
+
|
| 22 |
+
def highlight_common_words(common_words, sentences, title):
|
| 23 |
+
"""
|
| 24 |
+
Highlight common words in sentences by adding color-coded background and unique IDs.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
common_words (list of tuples): List of tuples where each tuple contains a word's index and the word.
|
| 28 |
+
sentences (list of str): List of sentences to search through.
|
| 29 |
+
title (str): The title for the HTML output.
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
str: HTML string with the highlighted sentences.
|
| 33 |
+
"""
|
| 34 |
+
color_map = {}
|
| 35 |
+
color_index = 0
|
| 36 |
+
highlighted_html = []
|
| 37 |
+
pastel_colors = ['#E199C6','#7FB3D5', '#E57373', '#B388EB', '#80D9AA', '#F0B66B',
|
| 38 |
+
"#73D9A0", "#9A89EB", "#E38AD4", "#7FAFDB", "#DDAF8C", "#DA7FC1",
|
| 39 |
+
"#65CFA5", "#B38FDE", "#E6C97A"]
|
| 40 |
+
|
| 41 |
+
# Process each sentence
|
| 42 |
+
for idx, sentence in enumerate(sentences, start=1):
|
| 43 |
+
sentence_with_idx = f"{idx}. {sentence}"
|
| 44 |
+
highlighted_sentence = sentence_with_idx
|
| 45 |
+
|
| 46 |
+
# Highlight common words in each sentence
|
| 47 |
+
for index, word in common_words:
|
| 48 |
+
if word not in color_map:
|
| 49 |
+
color_map[word] = pastel_colors[color_index % len(pastel_colors)]
|
| 50 |
+
color_index += 1
|
| 51 |
+
|
| 52 |
+
# Create appropriate pattern based on word characteristics
|
| 53 |
+
pattern = create_matching_pattern(word)
|
| 54 |
+
|
| 55 |
+
# Replace the word with highlighted version
|
| 56 |
+
highlighted_sentence = re.sub(
|
| 57 |
+
pattern,
|
| 58 |
+
lambda m, idx=index, color=color_map[word]: (
|
| 59 |
+
f'<span style="background-color: {color}; font-weight: bold;'
|
| 60 |
+
f' padding: 2px 4px; border-radius: 2px; position: relative;">'
|
| 61 |
+
f'<span style="background-color: black; color: white; border-radius: 50%;'
|
| 62 |
+
f' padding: 2px 5px; margin-right: 5px;">{idx}</span>'
|
| 63 |
+
f'{m.group(0)}'
|
| 64 |
+
f'</span>'
|
| 65 |
+
),
|
| 66 |
+
highlighted_sentence,
|
| 67 |
+
flags=re.IGNORECASE
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
highlighted_html.append(highlighted_sentence)
|
| 71 |
+
|
| 72 |
+
# Format the HTML output with the title
|
| 73 |
+
final_html = "<br><br>".join(highlighted_html)
|
| 74 |
+
return f'''
|
| 75 |
+
<div style="border: solid 1px #FFFFFF; padding: 16px; background-color: #000000; color: #FFFFFF; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); border-radius: 8px;">
|
| 76 |
+
<h3 style="margin-top: 0; font-size: 1em; color: #FFFFFF;">{title}</h3>
|
| 77 |
+
<div style="background-color: #333333; line-height: 1.6; padding: 15px; border-radius: 8px; color: #FFFFFF;">{final_html}</div>
|
| 78 |
+
</div>
|
| 79 |
+
'''
|
| 80 |
+
|
| 81 |
+
def highlight_common_words_dict(common_words, sentences, title):
|
| 82 |
+
"""
|
| 83 |
+
Highlight common words in sentences (from a dictionary) by adding color-coded background and unique IDs.
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
common_words (list of tuples): List of tuples where each tuple contains a word's index and the word.
|
| 87 |
+
sentences (dict): A dictionary of sentences where the key is the sentence and the value is an entailment score.
|
| 88 |
+
title (str): The title for the HTML output.
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
str: HTML string with the highlighted sentences and their entailment scores.
|
| 92 |
+
"""
|
| 93 |
+
color_map = {}
|
| 94 |
+
color_index = 0
|
| 95 |
+
highlighted_html = []
|
| 96 |
+
pastel_colors = ['#E199C6','#7FB3D5', '#E57373', '#B388EB', '#80D9AA', '#F0B66B',
|
| 97 |
+
"#73D9A0", "#9A89EB", "#E38AD4", "#7FAFDB", "#DDAF8C", "#DA7FC1",
|
| 98 |
+
"#65CFA5", "#B38FDE", "#E6C97A"]
|
| 99 |
+
|
| 100 |
+
# Process each sentence and its score
|
| 101 |
+
for idx, (sentence, score) in enumerate(sentences.items(), start=1):
|
| 102 |
+
sentence_with_idx = f"{idx}. {sentence}"
|
| 103 |
+
highlighted_sentence = sentence_with_idx
|
| 104 |
+
|
| 105 |
+
# Highlight common words in each sentence
|
| 106 |
+
for index, word in common_words:
|
| 107 |
+
if word not in color_map:
|
| 108 |
+
color_map[word] = pastel_colors[color_index % len(pastel_colors)]
|
| 109 |
+
color_index += 1
|
| 110 |
+
|
| 111 |
+
# Create appropriate pattern based on word characteristics
|
| 112 |
+
pattern = create_matching_pattern(word)
|
| 113 |
+
|
| 114 |
+
# Replace the word with highlighted version
|
| 115 |
+
highlighted_sentence = re.sub(
|
| 116 |
+
pattern,
|
| 117 |
+
lambda m, idx=index, color=color_map[word]: (
|
| 118 |
+
f'<span style="background-color: {color}; font-weight: bold;'
|
| 119 |
+
f' padding: 1px 2px; border-radius: 2px; position: relative;">'
|
| 120 |
+
f'<span style="background-color: black; color: white; border-radius: 50%;'
|
| 121 |
+
f' padding: 1px 3px; margin-right: 3px; font-size: 0.8em;">{idx}</span>'
|
| 122 |
+
f'{m.group(0)}'
|
| 123 |
+
f'</span>'
|
| 124 |
+
),
|
| 125 |
+
highlighted_sentence,
|
| 126 |
+
flags=re.IGNORECASE
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
# Add the entailment score
|
| 130 |
+
highlighted_html.append(
|
| 131 |
+
f'<div style="margin-bottom: 5px;">'
|
| 132 |
+
f'{highlighted_sentence}'
|
| 133 |
+
f'<div style="display: inline-block; margin-left: 5px; padding: 3px 5px; border-radius: 3px; '
|
| 134 |
+
f'background-color: #333333; color: white; font-size: 0.9em;">'
|
| 135 |
+
f'Entailment Score: {score}</div></div>'
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
# Format the HTML output with the title
|
| 139 |
+
final_html = "<br>".join(highlighted_html)
|
| 140 |
+
return f'''
|
| 141 |
+
<div style="background-color: #000000; color: #FFFFFF;border: solid 1px #FFFFFF; border-radius: 8px;">
|
| 142 |
+
<h3 style="margin-top: 0; font-size: 1em; color: #FFFFFF;">{title}</h3>
|
| 143 |
+
<div style="background-color: #333333; line-height: 1.6; padding: 15px; border-radius: 8px; color: #FFFFFF;">{final_html}</div>
|
| 144 |
+
</div>
|
| 145 |
+
'''
|
| 146 |
+
|
| 147 |
+
def reparaphrased_sentences_html(sentences):
|
| 148 |
+
"""
|
| 149 |
+
Create an HTML representation of sentences with numbering.
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
sentences (list of str): List of sentences to format.
|
| 153 |
+
|
| 154 |
+
Returns:
|
| 155 |
+
str: HTML string with numbered sentences.
|
| 156 |
+
"""
|
| 157 |
+
formatted_sentences = []
|
| 158 |
+
|
| 159 |
+
# Process each sentence
|
| 160 |
+
for idx, sentence in enumerate(sentences, start=1):
|
| 161 |
+
sentence_with_idx = f"{idx}. {sentence}"
|
| 162 |
+
formatted_sentences.append(sentence_with_idx)
|
| 163 |
+
|
| 164 |
+
# Format the HTML output
|
| 165 |
+
final_html = "<br><br>".join(formatted_sentences)
|
| 166 |
+
return f'''
|
| 167 |
+
<div style="border: solid 1px #FFFFFF; background-color: #000000; color: #FFFFFF;
|
| 168 |
+
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); border-radius: 8px;">
|
| 169 |
+
<div style="background-color: #333333; line-height: 1.6; padding: 15px; border-radius: 8px;">{final_html}</div>
|
| 170 |
+
</div>
|
| 171 |
+
'''
|
| 172 |
+
|
| 173 |
+
if __name__ == "__main__":
|
| 174 |
+
# Example usage
|
| 175 |
+
common_words = [(1, "highlight"), (2, "numbering"), (3, "S&P 500")]
|
| 176 |
+
sentences = ["This is a test to highlight words.", "Numbering is important for clarity.", "The S&P 500 index rose 2% today."]
|
| 177 |
+
|
| 178 |
+
# Test highlight_common_words
|
| 179 |
+
highlighted_html = highlight_common_words(common_words, sentences, "Test Highlighting")
|
| 180 |
+
print(highlighted_html)
|
| 181 |
+
|
| 182 |
+
# Test highlight_common_words_dict
|
| 183 |
+
sentences_with_scores = {"Highlight words in this text.": 0.95, "Number sentences for clarity.": 0.8, "The S&P 500 index is a market benchmark.": 0.88}
|
| 184 |
+
highlighted_html_dict = highlight_common_words_dict(common_words, sentences_with_scores, "Test Dict Highlighting")
|
| 185 |
+
print(highlighted_html_dict)
|
renderers/plot_3d.py
ADDED
@@ -0,0 +1,126 @@
"""
This file contains the code to plot a 3d tree
"""
import numpy as np
import plotly.graph_objects as go
from scipy.interpolate import griddata

def gen_three_D_plot(detectability_val, distortion_val, euclidean_val):
    """
    Generates a 3D surface plot showing the relationship between detectability, distortion,
    and Euclidean distance, with a focus on highlighting the "sweet spot" based on a composite score.

    The function takes three sets of values: detectability, distortion, and Euclidean distance,
    normalizes them to a [0, 1] range, and computes a composite score that combines these three metrics.
    The "sweet spot" is the point where the composite score is maximized. This sweet spot is plotted
    as a red marker on the 3D surface plot.

    The function then uses a grid interpolation method (`griddata`) to generate a smooth surface
    for the Euclidean distance over the detectability and distortion values. The result is a surface plot
    where the contours represent different Euclidean distances.

    Args:
        detectability_val (list or array): A list or array of detectability scores.
        distortion_val (list or array): A list or array of distortion scores.
        euclidean_val (list or array): A list or array of Euclidean distances.

    Returns:
        plotly.graph_objects.Figure: A Plotly figure object representing the 3D surface plot,
        with contour lines and a marker for the sweet spot.

    Raises:
        ValueError: If `griddata` fails to generate a valid interpolation, which could happen if the
        input data does not allow for a proper interpolation.

    Example:
        # Example of usage:
        detectability_vals = [0.1, 0.3, 0.5, 0.7, 0.9]
        distortion_vals = [0.2, 0.4, 0.6, 0.8, 1.0]
        euclidean_vals = [0.5, 0.3, 0.2, 0.4, 0.6]

        fig = gen_three_D_plot(detectability_vals, distortion_vals, euclidean_vals)
        fig.show()  # Displays the plot in a web browser

    Notes:
        - The composite score is calculated as:
          `composite_score = norm_detectability - (norm_distortion + norm_euclidean)`,
          where the goal is to maximize detectability and minimize distortion and Euclidean distance.
        - The `griddata` call below uses nearest-neighbour interpolation to build the surface for the plot.
        - The function uses the "Plasma" colorscale for the surface plot, which provides a perceptually uniform color scheme.
    """

    detectability = np.array(detectability_val)
    distortion = np.array(distortion_val)
    euclidean = np.array(euclidean_val)

    # Normalize the values to range [0, 1]
    norm_detectability = (detectability - min(detectability)) / (max(detectability) - min(detectability))
    norm_distortion = (distortion - min(distortion)) / (max(distortion) - min(distortion))
    norm_euclidean = (euclidean - min(euclidean)) / (max(euclidean) - min(euclidean))

    # Composite score: maximize detectability, minimize distortion and Euclidean distance
    composite_score = norm_detectability - (norm_distortion + norm_euclidean)

    # Find the index of the maximum score (sweet spot)
    sweet_spot_index = np.argmax(composite_score)

    # Sweet spot values
    sweet_spot_detectability = detectability[sweet_spot_index]
    sweet_spot_distortion = distortion[sweet_spot_index]
    sweet_spot_euclidean = euclidean[sweet_spot_index]

    # Create a meshgrid from the data
    x_grid, y_grid = np.meshgrid(np.linspace(min(detectability), max(detectability), 30),
                                 np.linspace(min(distortion), max(distortion), 30))

    # Interpolate z values (Euclidean distances) to fit the grid using 'nearest' method
    z_grid = griddata((detectability, distortion), euclidean, (x_grid, y_grid), method='nearest')

    if z_grid is None:
        raise ValueError("griddata could not generate a valid interpolation. Check your input data.")

    # Create the 3D contour plot with the Plasma color scale
    fig = go.Figure(data=go.Surface(
        z=z_grid,
        x=x_grid,
        y=y_grid,
        contours={
            "z": {"show": True, "start": min(euclidean), "end": max(euclidean), "size": 0.1, "usecolormap": True}
        },
        colorscale='Plasma'
    ))

    # Add a marker for the sweet spot
    fig.add_trace(go.Scatter3d(
        x=[sweet_spot_detectability],
        y=[sweet_spot_distortion],
        z=[sweet_spot_euclidean],
        mode='markers+text',
        marker=dict(size=10, color='red', symbol='circle'),
        text=["Sweet Spot"],
        textposition="top center"
    ))

    # Set axis labels
    fig.update_layout(
        scene=dict(
            xaxis_title='Detectability Score',
            yaxis_title='Distortion Score',
            zaxis_title='Euclidean Distance'
        ),
        margin=dict(l=0, r=0, b=0, t=0)
    )

    return fig

if __name__ == "__main__":
    # Example input data
    detectability_vals = [0.1, 0.3, 0.5, 0.7, 0.9]
    distortion_vals = [0.2, 0.4, 0.6, 0.8, 1.0]
    euclidean_vals = [0.5, 0.3, 0.2, 0.4, 0.6]

    # Call the function with example data
    fig = gen_three_D_plot(detectability_vals, distortion_vals, euclidean_vals)

    # Show the plot
    fig.show()
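To make the sweet-spot selection concrete, here is a small standalone check (a sketch using the same example inputs as the `__main__` block above, not part of the uploaded file) of how the composite score picks the marker position; the numbers follow directly from the min-max normalization described in the docstring:

import numpy as np

detectability = np.array([0.1, 0.3, 0.5, 0.7, 0.9])
distortion = np.array([0.2, 0.4, 0.6, 0.8, 1.0])
euclidean = np.array([0.5, 0.3, 0.2, 0.4, 0.6])

def minmax(v):
    # Min-max normalization to [0, 1], as in gen_three_D_plot
    return (v - v.min()) / (v.max() - v.min())

composite = minmax(detectability) - (minmax(distortion) + minmax(euclidean))
# composite == [-0.75, -0.25, 0.0, -0.5, -1.0], so index 2 is the maximum
idx = int(np.argmax(composite))
print(idx, detectability[idx], distortion[idx], euclidean[idx])  # 2 0.5 0.6 0.2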
renderers/tree.py
ADDED
@@ -0,0 +1,490 @@
| 1 |
+
import plotly.graph_objects as go
|
| 2 |
+
import textwrap
|
| 3 |
+
import re
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
|
| 6 |
+
def generate_subplot1(paraphrased_sentence, masked_sentences, strategies, highlight_info, common_grams):
|
| 7 |
+
"""
|
| 8 |
+
Generates a subplot visualizing paraphrased and masked sentences in a tree structure.
|
| 9 |
+
Highlights common words with specific colors and applies Longest Common Subsequence (LCS) numbering.
|
| 10 |
+
|
| 11 |
+
Args:
|
| 12 |
+
paraphrased_sentence (str): The paraphrased sentence to be visualized.
|
| 13 |
+
masked_sentences (list of str): A list of masked sentences to be visualized.
|
| 14 |
+
strategies (list of str, optional): List of strategies used for each masked sentence.
|
| 15 |
+
highlight_info (list of tuples): A list of tuples where each tuple contains a word and its associated color for highlighting.
|
| 16 |
+
common_grams (list of tuples): A list of tuples containing an index and a common word or phrase for LCS numbering.
|
| 17 |
+
|
| 18 |
+
Returns:
|
| 19 |
+
plotly.graph_objects.Figure: A Plotly figure representing the tree structure with highlighted words and labeled edges.
|
| 20 |
+
"""
|
| 21 |
+
# Combine nodes into one list with appropriate labels
|
| 22 |
+
if isinstance(masked_sentences, str):
|
| 23 |
+
masked_sentences = [masked_sentences]
|
| 24 |
+
nodes = [paraphrased_sentence] + masked_sentences
|
| 25 |
+
nodes[0] += ' L0' # Paraphrased sentence is level 0
|
| 26 |
+
if len(nodes) < 2:
|
| 27 |
+
print("[ERROR] Insufficient nodes for visualization")
|
| 28 |
+
return go.Figure()
|
| 29 |
+
|
| 30 |
+
for i in range(1, len(nodes)):
|
| 31 |
+
nodes[i] += ' L1' # masked sentences are level 1
|
| 32 |
+
|
| 33 |
+
def apply_lcs_numbering(sentence, common_grams):
|
| 34 |
+
"""
|
| 35 |
+
Applies LCS numbering to the sentence based on the common_grams.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
sentence (str): The sentence to which the LCS numbering should be applied.
|
| 39 |
+
common_grams (list of tuples): A list of common grams to be replaced with LCS numbers.
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
str: The sentence with LCS numbering applied.
|
| 43 |
+
"""
|
| 44 |
+
for idx, lcs in common_grams:
|
| 45 |
+
sentence = re.sub(rf"\b{lcs}\b", f"({idx}){lcs}", sentence)
|
| 46 |
+
return sentence
|
| 47 |
+
|
| 48 |
+
# Apply LCS numbering
|
| 49 |
+
nodes = [apply_lcs_numbering(node, common_grams) for node in nodes]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def highlight_words(sentence, color_map):
|
| 53 |
+
"""
|
| 54 |
+
Highlights words in the sentence based on the color_map.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
sentence (str): The sentence where the words will be highlighted.
|
| 58 |
+
color_map (dict): A dictionary mapping words to their colors.
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
str: The sentence with highlighted words.
|
| 62 |
+
"""
|
| 63 |
+
for word, color in color_map.items():
|
| 64 |
+
sentence = re.sub(f"\\b{word}\\b", f"{{{{{word}}}}}", sentence, flags=re.IGNORECASE)
|
| 65 |
+
return sentence
|
| 66 |
+
|
| 67 |
+
# Clean and wrap nodes, and highlight specified words globally
|
| 68 |
+
cleaned_nodes = [re.sub(r'\sL[0-9]$', '', node) for node in nodes]
|
| 69 |
+
global_color_map = dict(highlight_info)
|
| 70 |
+
highlighted_nodes = [highlight_words(node, global_color_map) for node in cleaned_nodes]
|
| 71 |
+
wrapped_nodes = ['<br>'.join(textwrap.wrap(node, width=55)) for node in highlighted_nodes]
|
| 72 |
+
|
| 73 |
+
def get_levels_and_edges(nodes, strategies=None):
|
| 74 |
+
"""
|
| 75 |
+
Determines tree levels and creates edges dynamically.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
nodes (list of str): The nodes representing the sentences.
|
| 79 |
+
strategies (list of str, optional): The strategies used for each edge.
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
tuple: A tuple containing two dictionaries:
|
| 83 |
+
- levels: A dictionary mapping node indices to their levels.
|
| 84 |
+
- edges: A list of edges where each edge is represented by a tuple of node indices.
|
| 85 |
+
"""
|
| 86 |
+
levels = {}
|
| 87 |
+
edges = []
|
| 88 |
+
for i, node in enumerate(nodes):
|
| 89 |
+
level = int(node.split()[-1][1])
|
| 90 |
+
levels[i] = level
|
| 91 |
+
|
| 92 |
+
# Add edges from L0 to all L1 nodes
|
| 93 |
+
root_node = next((i for i, level in levels.items() if level == 0), 0)
|
| 94 |
+
for i, level in levels.items():
|
| 95 |
+
if level == 1:
|
| 96 |
+
edges.append((root_node, i))
|
| 97 |
+
|
| 98 |
+
return levels, edges
|
| 99 |
+
|
| 100 |
+
# Get levels and dynamic edges
|
| 101 |
+
levels, edges = get_levels_and_edges(nodes, strategies)
|
| 102 |
+
max_level = max(levels.values(), default=0)
|
| 103 |
+
|
| 104 |
+
# Calculate positions
|
| 105 |
+
positions = {}
|
| 106 |
+
level_heights = defaultdict(int)
|
| 107 |
+
for node, level in levels.items():
|
| 108 |
+
level_heights[level] += 1
|
| 109 |
+
|
| 110 |
+
y_offsets = {level: - (height - 1) / 2 for level, height in level_heights.items()}
|
| 111 |
+
x_gap = 2
|
| 112 |
+
l1_y_gap = 10
|
| 113 |
+
|
| 114 |
+
for node, level in levels.items():
|
| 115 |
+
if level == 1:
|
| 116 |
+
positions[node] = (-level * x_gap, y_offsets[level] * l1_y_gap)
|
| 117 |
+
else:
|
| 118 |
+
positions[node] = (-level * x_gap, y_offsets[level] * l1_y_gap)
|
| 119 |
+
y_offsets[level] += 1
|
| 120 |
+
|
| 121 |
+
def color_highlighted_words(node, color_map):
|
| 122 |
+
"""
|
| 123 |
+
Colors the highlighted words in the node text.
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
node (str): The node text to be highlighted.
|
| 127 |
+
color_map (dict): A dictionary mapping words to their colors.
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
str: The node text with highlighted words.
|
| 131 |
+
"""
|
| 132 |
+
parts = re.split(r'(\{\{.*?\}\})', node)
|
| 133 |
+
colored_parts = []
|
| 134 |
+
for part in parts:
|
| 135 |
+
match = re.match(r'\{\{(.*?)\}\}', part)
|
| 136 |
+
if match:
|
| 137 |
+
word = match.group(1)
|
| 138 |
+
color = color_map.get(word, 'black')
|
| 139 |
+
colored_parts.append(f"<span style='color: {color};'>{word}</span>")
|
| 140 |
+
else:
|
| 141 |
+
colored_parts.append(part)
|
| 142 |
+
return ''.join(colored_parts)
|
| 143 |
+
|
| 144 |
+
# Define the text for each edge
|
| 145 |
+
default_edge_texts = [
|
| 146 |
+
"Highest Entropy Masking", "Pseudo-random Masking", "Random Masking",
|
| 147 |
+
"Greedy Sampling", "Temperature Sampling", "Exponential Minimum Sampling",
|
| 148 |
+
"Inverse Transform Sampling", "Greedy Sampling", "Temperature Sampling",
|
| 149 |
+
"Exponential Minimum Sampling", "Inverse Transform Sampling", "Greedy Sampling",
|
| 150 |
+
"Temperature Sampling", "Exponential Minimum Sampling", "Inverse Transform Sampling"
|
| 151 |
+
]
|
| 152 |
+
|
| 153 |
+
if len(nodes) < 2:
|
| 154 |
+
print("[ERROR] Insufficient nodes for visualization")
|
| 155 |
+
return go.Figure()
|
| 156 |
+
|
| 157 |
+
# Create figure
|
| 158 |
+
fig1 = go.Figure()
|
| 159 |
+
|
| 160 |
+
# Add nodes to the figure
|
| 161 |
+
for i, node in enumerate(wrapped_nodes):
|
| 162 |
+
colored_node = color_highlighted_words(node, global_color_map)
|
| 163 |
+
x, y = positions[i]
|
| 164 |
+
fig1.add_trace(go.Scatter(
|
| 165 |
+
x=[-x], # Reflect the x coordinate
|
| 166 |
+
y=[y],
|
| 167 |
+
mode='markers',
|
| 168 |
+
marker=dict(size=20, color='blue', line=dict(color='black', width=2)),
|
| 169 |
+
hoverinfo='none'
|
| 170 |
+
))
|
| 171 |
+
fig1.add_annotation(
|
| 172 |
+
x=-x, # Reflect the x coordinate
|
| 173 |
+
y=y,
|
| 174 |
+
text=colored_node,
|
| 175 |
+
showarrow=False,
|
| 176 |
+
xshift=15,
|
| 177 |
+
align="center",
|
| 178 |
+
font=dict(size=12),
|
| 179 |
+
bordercolor='black',
|
| 180 |
+
borderwidth=2,
|
| 181 |
+
borderpad=4,
|
| 182 |
+
bgcolor='white',
|
| 183 |
+
width=400,
|
| 184 |
+
height=100
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
# Add edges and text above each edge
|
| 188 |
+
for i, edge in enumerate(edges):
|
| 189 |
+
x0, y0 = positions[edge[0]]
|
| 190 |
+
x1, y1 = positions[edge[1]]
|
| 191 |
+
|
| 192 |
+
# Use strategy if available, otherwise use default edge text
|
| 193 |
+
if strategies and i < len(strategies):
|
| 194 |
+
edge_text = strategies[i]
|
| 195 |
+
else:
|
| 196 |
+
edge_text = default_edge_texts[i % len(default_edge_texts)]
|
| 197 |
+
|
| 198 |
+
fig1.add_trace(go.Scatter(
|
| 199 |
+
x=[-x0, -x1], # Reflect the x coordinates
|
| 200 |
+
y=[y0, y1],
|
| 201 |
+
mode='lines',
|
| 202 |
+
line=dict(color='black', width=1)
|
| 203 |
+
))
|
| 204 |
+
|
| 205 |
+
# Calculate the midpoint of the edge
|
| 206 |
+
mid_x = (-x0 + -x1) / 2
|
| 207 |
+
mid_y = (y0 + y1) / 2
|
| 208 |
+
|
| 209 |
+
# Adjust y position to shift text upwards
|
| 210 |
+
text_y_position = mid_y + 0.8 # Increase this value to shift the text further upwards
|
| 211 |
+
|
| 212 |
+
# Add text annotation above the edge
|
| 213 |
+
fig1.add_annotation(
|
| 214 |
+
x=mid_x,
|
| 215 |
+
y=text_y_position,
|
| 216 |
+
text=edge_text, # Use the text specific to this edge
|
| 217 |
+
showarrow=False,
|
| 218 |
+
font=dict(size=12),
|
| 219 |
+
align="center"
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
fig1.update_layout(
|
| 223 |
+
showlegend=False,
|
| 224 |
+
margin=dict(t=50, b=50, l=50, r=50),
|
| 225 |
+
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
| 226 |
+
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
| 227 |
+
width=800 + max_level * 200, # Adjusted width to accommodate more levels
|
| 228 |
+
height=300 + len(nodes) * 100, # Adjusted height to accommodate more levels
|
| 229 |
+
plot_bgcolor='rgba(240,240,240,0.2)',
|
| 230 |
+
paper_bgcolor='white'
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
return fig1

def generate_subplot2(masked_sentences, sampled_sentences, highlight_info, common_grams):
    """
    Generates a subplot visualizing multiple masked sentences and their sampled variants in a tree structure.
    Each masked sentence will have multiple sampled sentences derived from it using different sampling techniques.

    Args:
        masked_sentences (list of str): A list of masked sentences to be visualized as root nodes.
        sampled_sentences (list of str): A list of sampled sentences derived from masked sentences.
        highlight_info (list of tuples): A list of tuples where each tuple contains a word and its associated color for highlighting.
        common_grams (list of tuples): A list of tuples containing an index and a common word or phrase for LCS numbering.

    Returns:
        plotly.graph_objects.Figure: A Plotly figure representing the tree structure with highlighted words and labeled edges.
    """
    # Define sampling techniques
    sampling_techniques = [
        "Greedy Sampling",
        "Temperature Sampling",
        "Exponential Minimum Sampling",
        "Inverse Transform Sampling"
    ]

    # Calculate total number of nodes
    num_masked = len(masked_sentences)
    num_sampled_per_masked = len(sampling_techniques)
    total_nodes = num_masked + (num_masked * num_sampled_per_masked)

    # Combine all sentences into nodes list with appropriate labels
    nodes = []
    # Level 0: masked sentences (root nodes)
    nodes.extend([s + ' L0' for s in masked_sentences])

    # Level 1: sampled sentences (branch nodes)
    # For each masked sentence, we should have samples from each technique
    sampled_nodes = []

    # Validate if we have the expected number of sampled sentences
    expected_sampled_count = num_masked * num_sampled_per_masked
    if len(sampled_sentences) < expected_sampled_count:
        # If insufficient samples provided, pad with placeholder sentences
        print(f"Warning: Expected {expected_sampled_count} sampled sentences, but got {len(sampled_sentences)}")
        while len(sampled_sentences) < expected_sampled_count:
            sampled_sentences.append(f"Placeholder sampled sentence {len(sampled_sentences) + 1}")

    # Add all sampled sentences with level information
    for s in sampled_sentences[:expected_sampled_count]:
        sampled_nodes.append(s + ' L1')

    nodes.extend(sampled_nodes)

    def apply_lcs_numbering(sentence, common_grams):
        """
        Applies LCS numbering to the sentence based on the common_grams.
        """
        for idx, lcs in common_grams:
            sentence = re.sub(rf"\b{lcs}\b", f"({idx}){lcs}", sentence)
        return sentence

    # Apply LCS numbering
    nodes = [apply_lcs_numbering(node, common_grams) for node in nodes]

    def highlight_words(sentence, color_map):
        """
        Highlights words in the sentence based on the color_map.
        """
        for word, color in color_map.items():
            sentence = re.sub(f"\\b{word}\\b", f"{{{{{word}}}}}", sentence, flags=re.IGNORECASE)
        return sentence

    # Helper function to color highlighted words
    def color_highlighted_words(node, color_map):
        """
        Colors the highlighted words in the node text.
        """
        parts = re.split(r'(\{\{.*?\}\})', node)
        colored_parts = []
        for part in parts:
            match = re.match(r'\{\{(.*?)\}\}', part)
            if match:
                word = match.group(1)
                color = color_map.get(word, 'black')
                colored_parts.append(f"<span style='color: {color};'>{word}</span>")
            else:
                colored_parts.append(part)
        return ''.join(colored_parts)
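To see what the two helpers above produce, here is a self-contained round trip: highlight_words wraps mapped words in {{...}} placeholders and color_highlighted_words turns each placeholder into HTML span markup. The color map below is an invented example.

```python
import re

color_map = {"quick": "red", "fox": "blue"}  # illustrative colors

def highlight_words(sentence, color_map):
    # Wrap every mapped word in {{...}} placeholders (same logic as the helper above).
    for word, color in color_map.items():
        sentence = re.sub(f"\\b{word}\\b", f"{{{{{word}}}}}", sentence, flags=re.IGNORECASE)
    return sentence

def color_highlighted_words(node, color_map):
    # Replace each {{word}} placeholder with a colored <span>.
    parts = re.split(r'(\{\{.*?\}\})', node)
    out = []
    for part in parts:
        m = re.match(r'\{\{(.*?)\}\}', part)
        if m:
            word = m.group(1)
            out.append(f"<span style='color: {color_map.get(word, 'black')};'>{word}</span>")
        else:
            out.append(part)
    return ''.join(out)

marked = highlight_words("The quick brown fox", color_map)
print(marked)                                   # The {{quick}} brown {{fox}}
print(color_highlighted_words(marked, color_map))
# The <span style='color: red;'>quick</span> brown <span style='color: blue;'>fox</span>
```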

    # Clean nodes, highlight words, and wrap text
    cleaned_nodes = [re.sub(r'\sL[0-9]$', '', node) for node in nodes]
    global_color_map = dict(highlight_info)
    highlighted_nodes = [highlight_words(node, global_color_map) for node in cleaned_nodes]
    wrapped_nodes = ['<br>'.join(textwrap.wrap(node, width=80)) for node in highlighted_nodes]

    # Generate edges based on the tree structure
    def get_levels_and_edges(nodes):
        levels = {}
        edges = []

        # Extract level info from node labels
        for i, node in enumerate(nodes):
            level = int(node.split()[-1][1])
            levels[i] = level

        # Create edges from masked sentences to their sampled variants
        for masked_idx in range(num_masked):
            # For each masked sentence, create edges to its sampled variants
            for technique_idx in range(num_sampled_per_masked):
                sampled_idx = num_masked + (masked_idx * num_sampled_per_masked) + technique_idx
                if sampled_idx < len(nodes):
                    edges.append((masked_idx, sampled_idx))

        return levels, edges

    levels, edges = get_levels_and_edges(nodes)

    # Calculate positions with improved spacing
    positions = {}

    # Calculate horizontal spacing for the root nodes (masked sentences)
    root_x_spacing = 0  # All root nodes at x=0
    root_y_spacing = 8.0  # Vertical spacing between root nodes

    # Calculate positions for sampled nodes
    sampled_x = 3  # X position for all sampled nodes

    # Calculate y positions for root nodes (masked sentences)
    root_y_start = -(num_masked - 1) * root_y_spacing / 2
    for i in range(num_masked):
        positions[i] = (root_x_spacing, root_y_start + i * root_y_spacing)

    # Calculate y positions for sampled nodes
    for masked_idx in range(num_masked):
        root_y = positions[masked_idx][1]  # Y position of parent masked sentence

        # Calculate y-spacing for children of this root
        children_y_spacing = 1.5  # Vertical spacing between children of the same root
        children_y_start = root_y - (num_sampled_per_masked - 1) * children_y_spacing / 2

        # Position each child
        for technique_idx in range(num_sampled_per_masked):
            child_idx = num_masked + (masked_idx * num_sampled_per_masked) + technique_idx
            child_y = children_y_start + technique_idx * children_y_spacing
            positions[child_idx] = (sampled_x, child_y)

    # Create figure
    fig2 = go.Figure()

    # Add nodes
    for i, node in enumerate(wrapped_nodes):
        x, y = positions[i]

        # Define node color based on level
        node_color = 'blue' if levels[i] == 0 else 'green'

        # Add the node marker
        fig2.add_trace(go.Scatter(
            x=[x],
            y=[y],
            mode='markers',
            marker=dict(size=20, color=node_color, line=dict(color='black', width=2)),
            hoverinfo='none'
        ))

        # Add node label with highlighting
        colored_node = color_highlighted_words(node, global_color_map)

        fig2.add_annotation(
            x=x,
            y=y,
            text=colored_node,
            showarrow=False,
            xshift=15,
            align="left",
            font=dict(size=12),
            bordercolor='black',
            borderwidth=2,
            borderpad=4,
            bgcolor='white',
            width=450,
            height=100
        )

    # Add edges with labels
    for i, (src, dst) in enumerate(edges):
        x0, y0 = positions[src]
        x1, y1 = positions[dst]

        # Draw the edge
        fig2.add_trace(go.Scatter(
            x=[x0, x1],
            y=[y0, y1],
            mode='lines',
            line=dict(color='black', width=1)
        ))

        # Add sampling technique label
        # Determine which sampling technique this is
        parent_idx = src
        technique_count = sum(1 for k, (s, _) in enumerate(edges) if s == parent_idx and k < i)
        technique_label = sampling_techniques[technique_count % len(sampling_techniques)]

        # Calculate midpoint for the label
        mid_x = (x0 + x1) / 2
        mid_y = (y0 + y1) / 2

        # Add slight offset to avoid overlap
        label_offset = 0.1

        fig2.add_annotation(
            x=mid_x,
            y=mid_y + label_offset,
            text=technique_label,
            showarrow=False,
            font=dict(size=8),
            align="center"
        )

    # Update layout
    fig2.update_layout(
        showlegend=False,
        margin=dict(t=20, b=20, l=20, r=20),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=1200,  # Adjusted width to accommodate more levels
        height=2000,  # Adjusted height to accommodate more levels
        plot_bgcolor='rgba(240,240,240,0.2)',
        paper_bgcolor='white'
    )

    return fig2


if __name__ == "__main__":
    paraphrased_sentence = "The quick brown fox jumps over the lazy dog."
    masked_sentences = [
        "A fast brown fox leaps over the lazy dog.",
        "A quick brown fox hops over a lazy dog."
    ]
    highlight_info = [
        ("quick", "red"),
        ("brown", "green"),
        ("fox", "blue"),
        ("lazy", "purple")
    ]
    common_grams = [
        (1, "quick brown fox"),
        (2, "lazy dog")
    ]

    fig1 = generate_subplot1(paraphrased_sentence, masked_sentences, highlight_info, common_grams)
    fig1.show()

    sampled_sentence = ["A fast brown fox jumps over a lazy dog."]

    fig2 = generate_subplot2(masked_sentences, sampled_sentence, highlight_info, common_grams)
    fig2.show()
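The tree in generate_subplot2 is laid out purely by index arithmetic: node i with i < num_masked is a root, and the sampled node for root m and technique t sits at index num_masked + m * num_sampled_per_masked + t. A tiny sketch of that mapping, with made-up counts:

```python
# Index arithmetic behind the generate_subplot2 tree (counts are illustrative).
num_masked = 2
sampling_techniques = ["Greedy Sampling", "Temperature Sampling",
                       "Exponential Minimum Sampling", "Inverse Transform Sampling"]
num_sampled_per_masked = len(sampling_techniques)

for masked_idx in range(num_masked):
    for technique_idx, technique in enumerate(sampling_techniques):
        sampled_idx = num_masked + masked_idx * num_sampled_per_masked + technique_idx
        print(f"root {masked_idx} --{technique}--> node {sampled_idx}")
```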
utils/__init__.py
ADDED
@@ -0,0 +1,5 @@
from utils.watermark import Watermarker
from utils.paraphraser import Paraphraser
from utils.entailment import EntailmentAnalyzer
from utils.sampling import SamplingProcessor
from utils.config import load_config
utils/__pycache__/__init__.cpython-310.pyc ADDED (binary, 404 Bytes)
utils/__pycache__/__init__.cpython-311.pyc ADDED (binary, 509 Bytes)
utils/__pycache__/config.cpython-310.pyc ADDED (binary, 594 Bytes)
utils/__pycache__/config.cpython-311.pyc ADDED (binary, 971 Bytes)
utils/__pycache__/entailment.cpython-310.pyc ADDED (binary, 3.69 kB)
utils/__pycache__/entailment.cpython-311.pyc ADDED (binary, 5.33 kB)
utils/__pycache__/masking_methods.cpython-310.pyc ADDED (binary, 11.1 kB)
utils/__pycache__/masking_methods.cpython-311.pyc ADDED (binary, 35.6 kB)
utils/__pycache__/ngram_index_manager.cpython-311.pyc ADDED (binary, 19.5 kB)
utils/__pycache__/non_melting_point.cpython-310.pyc ADDED (binary, 5.05 kB)
utils/__pycache__/non_melting_point.cpython-311.pyc ADDED (binary, 33.1 kB)
utils/__pycache__/paraphraser.cpython-310.pyc ADDED (binary, 2.85 kB)
utils/__pycache__/paraphraser.cpython-311.pyc ADDED (binary, 4.89 kB)
utils/__pycache__/sampling.cpython-310.pyc ADDED (binary, 5.06 kB)
utils/__pycache__/sampling.cpython-311.pyc ADDED (binary, 9.2 kB)
utils/__pycache__/watermark.cpython-310.pyc ADDED (binary, 11.8 kB)
utils/__pycache__/watermark.cpython-311.pyc ADDED (binary, 20.9 kB)
utils/config.py
ADDED
@@ -0,0 +1,18 @@
"""
This file loads config from config.yaml
"""

import yaml

def load_config(path):
    """
    Function to load config from config.yaml
    """
    try:
        with open(path, "r") as file:
            config = yaml.safe_load(file)
            return config
    except FileNotFoundError:
        raise FileNotFoundError("Config file not found")
    except Exception as e:
        raise e
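A minimal usage sketch for load_config. The relative path below is an assumption (it matches the utils/config.yaml shown next); once loaded, nested sections are plain dictionaries.

```python
from utils.config import load_config

# Assumed relative path from the repository root; adjust if running from elsewhere.
config = load_config("utils/config.yaml")

entailment_cfg = config["PECCAVI_TEXT"]["Entailment"]
print(entailment_cfg["model"])  # ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli
```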
utils/config.yaml
ADDED
@@ -0,0 +1,48 @@
# This is the official config file.
PECCAVI_TEXT:
  Entailment:
    task: "text-classification"
    model: "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli"

  Masking:
    task: "fill-mask"
    tokenizer: "bert-base-uncased"
    model: "bert-base-uncased"
    # tokenizer: "bert-large-cased-whole-word-masking"
    # model: "bert-large-cased-whole-word-masking"

  Vocabulary:
    tokenizer: "bert-base-uncased"
    model: "bert-base-uncased"
    # permissible_ratio: 0.5
    # tokenizer: "bert-large-cased-whole-word-masking"
    # model: "bert-large-cased-whole-word-masking"
    permissible_ratio: 1.0

  Sampling:
    tokenizer: "bert-base-uncased"
    model: "bert-base-uncased"
    # tokenizer: "bert-large-cased-whole-word-masking"
    # model: "bert-large-cased-whole-word-masking"

  Metrics:
    EuclideanDistance: "sentence-transformers/all-MiniLM-L6-v2"
    Distortion: "gpt2"

  Detector:
    tokenizer: "bert-base-uncased"
    model: "bert-base-uncased"
    # tokenizer: "bert-large-cased-whole-word-masking"
    # model: "bert-large-cased-whole-word-masking"

  Paraphrase:
    tokenizer: "humarin/chatgpt_paraphraser_on_T5_base"
    model: "humarin/chatgpt_paraphraser_on_T5_base"
    num_beams: 10
    num_beam_groups: 10
    num_return_sequences: 10
    repetition_penalty: 10.0
    diversity_penalty: 3.0
    no_repeat_ngram_size: 2
    temperature: 0.7
    max_length: 64
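The Paraphrase block mirrors keyword arguments of Hugging Face generate() for diverse beam search. The wiring below is only a sketch of how such a config is typically consumed; the actual consumer (utils/paraphraser.py) is not shown in this part of the diff, and the "paraphrase: " prompt prefix follows the model card's usage rather than anything in this file.

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

cfg = {  # mirrors the Paraphrase section of config.yaml
    "model": "humarin/chatgpt_paraphraser_on_T5_base",
    "num_beams": 10, "num_beam_groups": 10, "num_return_sequences": 10,
    "repetition_penalty": 10.0, "diversity_penalty": 3.0,
    "no_repeat_ngram_size": 2, "max_length": 64,
}

tokenizer = AutoTokenizer.from_pretrained(cfg["model"])
model = AutoModelForSeq2SeqLM.from_pretrained(cfg["model"])

inputs = tokenizer("paraphrase: The weather is nice today", return_tensors="pt")
# temperature only matters when sampling; diverse beam search below does not sample.
outputs = model.generate(
    **inputs,
    num_beams=cfg["num_beams"],
    num_beam_groups=cfg["num_beam_groups"],
    num_return_sequences=cfg["num_return_sequences"],
    repetition_penalty=cfg["repetition_penalty"],
    diversity_penalty=cfg["diversity_penalty"],
    no_repeat_ngram_size=cfg["no_repeat_ngram_size"],
    max_length=cfg["max_length"],
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```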
utils/entailment.py
ADDED
@@ -0,0 +1,107 @@
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import numpy as np
from transformers import pipeline
from typing import List
from utils.config import load_config


class EntailmentAnalyzer:
    # def __init__(self, config_path: str):
    def __init__(self, config):
        """
        Initialize the EntailmentAnalyzer with the entailment configuration.

        Args:
            config: The 'Entailment' section of the configuration (a dict with 'task' and 'model' keys).
        """
        # self.config = load_config(config_path)['PECCAVI_TEXT']['Entailment']
        self.config = config
        self.entailment_pipeline = pipeline(task=self.config['task'], model=self.config['model'])

    def check_entailment(self, premise: str, hypothesis: str) -> float:
        """
        Check entailment between the premise and hypothesis.

        Args:
            premise: The premise sentence.
            hypothesis: The hypothesis sentence.

        Returns:
            float: The entailment score.
        """
        results = self.entailment_pipeline(f"{premise} [SEP] {hypothesis}", top_k=None)
        entailment_score = next(item['score'] for item in results if item['label'] == 'entailment')
        return entailment_score

    def analyze_entailment(self, original_sentence: str, paraphrased_sentences: List[str], threshold: float) -> tuple:
        """
        Analyze entailment scores for paraphrased sentences. If no selected sentences are found,
        lower the threshold and rerun the analysis.

        Args:
            original_sentence: The original sentence.
            paraphrased_sentences: List of paraphrased sentences.
            threshold: Minimum score to select a sentence.

        Returns:
            tuple: A dictionary of all scores, selected sentences, and discarded sentences.
        """
        all_sentences = {}
        selected_sentences = {}
        discarded_sentences = {}

        # Loop to reduce threshold if no sentences are selected
        while not selected_sentences:
            for paraphrased_sentence in paraphrased_sentences:
                entailment_score = self.check_entailment(original_sentence, paraphrased_sentence)

                all_sentences[paraphrased_sentence] = entailment_score
                if entailment_score >= threshold:
                    selected_sentences[paraphrased_sentence] = entailment_score
                else:
                    discarded_sentences[paraphrased_sentence] = entailment_score

            # If no sentences are selected, lower the threshold
            if not selected_sentences:
                print(f"No selected sentences found. Lowering the threshold by 0.1 (from {threshold} to {threshold - 0.1}).")
                threshold -= 0.1
                if threshold <= 0:
                    print("Threshold has reached 0. No sentences meet the criteria.")
                    break

        return all_sentences, selected_sentences, discarded_sentences


if __name__ == "__main__":
    config_path = os.path.join(os.path.dirname(__file__), '..', 'config', 'config.yaml')

    config_path = '/home/ashhar21137/text_wm/scratch/utils/config/config.yaml'

    config = load_config(config_path)

    entailment_analyzer = EntailmentAnalyzer(config['PECCAVI_TEXT']['Entailment'])

    all_sentences, selected_sentences, discarded_sentences = entailment_analyzer.analyze_entailment(
        "The weather is nice today",
        [
            "The climate is pleasant today",
            "It's a good day weather-wise",
            "Today, the weather is terrible",
            "What a beautiful day it is",
            "The sky is clear and the weather is perfect",
            "It's pouring rain outside today",
            "The weather isn't bad today",
            "A lovely day for outdoor activities"
        ],
        0.7
    )

    print("----------------------- All Sentences -----------------------")
    print(all_sentences)
    print("----------------------- Discarded Sentences -----------------------")
    print(discarded_sentences)
    print("----------------------- Selected Sentences -----------------------")
    print(selected_sentences)
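For reference, check_entailment relies on the text-classification pipeline returning one dict per label when called with top_k=None, and it picks the score for the 'entailment' label. The snippet below only imitates that output shape without downloading the model; the scores are invented.

```python
# Shape of the pipeline output that check_entailment consumes (scores are made up).
results = [
    {"label": "entailment", "score": 0.93},
    {"label": "neutral", "score": 0.05},
    {"label": "contradiction", "score": 0.02},
]
entailment_score = next(item["score"] for item in results if item["label"] == "entailment")
print(entailment_score)  # 0.93
```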
utils/masking_methods.py
ADDED
|
@@ -0,0 +1,583 @@
| 1 |
+
import random
|
| 2 |
+
import torch
|
| 3 |
+
import logging
|
| 4 |
+
import string
|
| 5 |
+
from transformers import BertTokenizer, BertForMaskedLM
|
| 6 |
+
from nltk.corpus import stopwords
|
| 7 |
+
import nltk
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
# Set logging to WARNING for a cleaner terminal.
|
| 11 |
+
logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s")
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
# Ensure stopwords are downloaded
|
| 15 |
+
try:
|
| 16 |
+
nltk.data.find('corpora/stopwords')
|
| 17 |
+
except LookupError:
|
| 18 |
+
nltk.download('stopwords')
|
| 19 |
+
|
| 20 |
+
def clean_word(word):
|
| 21 |
+
"""More robust cleaning for consistent matching"""
|
| 22 |
+
# Remove possessive 's before other punctuation
|
| 23 |
+
if word.lower().endswith("'s"):
|
| 24 |
+
word = word[:-2]
|
| 25 |
+
return word.lower().strip().translate(str.maketrans('', '', string.punctuation))
|
| 26 |
+
|
| 27 |
+
class MaskingProcessor:
|
| 28 |
+
def __init__(self, tokenizer, model):
|
| 29 |
+
self.tokenizer = tokenizer
|
| 30 |
+
self.model = model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
|
| 31 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 32 |
+
self.stop_words = set(stopwords.words('english'))
|
| 33 |
+
tqdm.write(f"[MaskingProcessor] Initialized on device: {self.device}")
|
| 34 |
+
|
| 35 |
+
def is_stopword(self, word):
|
| 36 |
+
"""Check if a word is a stopword, handling punctuation and case"""
|
| 37 |
+
return clean_word(word) in self.stop_words
|
| 38 |
+
|
| 39 |
+
def verify_and_correct_ngram_positions(self, sentence, common_ngrams):
|
| 40 |
+
"""Verify ngram positions match actual words in sentence and correct if needed."""
|
| 41 |
+
words = sentence.split()
|
| 42 |
+
corrected_ngrams = {}
|
| 43 |
+
|
| 44 |
+
for ngram, positions in common_ngrams.items():
|
| 45 |
+
corrected_positions = []
|
| 46 |
+
ngram_words = ngram.split()
|
| 47 |
+
|
| 48 |
+
# Convert ngram words to clean format for matching
|
| 49 |
+
clean_ngram_words = [clean_word(word) for word in ngram_words]
|
| 50 |
+
|
| 51 |
+
# Scan the sentence to find actual occurrences of the ngram
|
| 52 |
+
for i in range(len(words) - len(ngram_words) + 1):
|
| 53 |
+
is_match = True
|
| 54 |
+
for j, ngram_word in enumerate(clean_ngram_words):
|
| 55 |
+
if clean_word(words[i + j]) != ngram_word:
|
| 56 |
+
is_match = False
|
| 57 |
+
break
|
| 58 |
+
|
| 59 |
+
if is_match:
|
| 60 |
+
# Found a matching position, add it
|
| 61 |
+
corrected_positions.append((i, i + len(ngram_words) - 1))
|
| 62 |
+
|
| 63 |
+
if corrected_positions:
|
| 64 |
+
corrected_ngrams[ngram] = corrected_positions
|
| 65 |
+
else:
|
| 66 |
+
# Log the issue and perform a more flexible search
|
| 67 |
+
print(f"Warning: Could not find exact match for '{ngram}' in the sentence.")
|
| 68 |
+
print(f"Attempting flexible matching...")
|
| 69 |
+
|
| 70 |
+
# Try a more flexible approach by looking for individual words
|
| 71 |
+
for i in range(len(words)):
|
| 72 |
+
if clean_word(words[i]) == clean_ngram_words[0]:
|
| 73 |
+
# We found the first word of the ngram
|
| 74 |
+
if len(ngram_words) == 1 or (
|
| 75 |
+
i + len(ngram_words) <= len(words) and
|
| 76 |
+
all(clean_word(words[i+j]).startswith(clean_ngram_words[j]) for j in range(len(ngram_words)))
|
| 77 |
+
):
|
| 78 |
+
corrected_positions.append((i, i + len(ngram_words) - 1))
|
| 79 |
+
|
| 80 |
+
if corrected_positions:
|
| 81 |
+
print(f"Found flexible matches for '{ngram}': {corrected_positions}")
|
| 82 |
+
corrected_ngrams[ngram] = corrected_positions
|
| 83 |
+
else:
|
| 84 |
+
# If still no match, keep original positions as fallback
|
| 85 |
+
print(f"No matches found for '{ngram}'. Keeping original positions.")
|
| 86 |
+
corrected_ngrams[ngram] = positions
|
| 87 |
+
|
| 88 |
+
# Log changes
|
| 89 |
+
if corrected_ngrams != common_ngrams:
|
| 90 |
+
print(f"Original ngram positions: {common_ngrams}")
|
| 91 |
+
print(f"Corrected ngram positions: {corrected_ngrams}")
|
| 92 |
+
|
| 93 |
+
return corrected_ngrams
|
| 94 |
+
def in_any_ngram(self, idx, ngram_positions):
|
| 95 |
+
"""Check if an original sentence index is part of any n-gram span"""
|
| 96 |
+
return any(start <= idx <= end for start, end in ngram_positions)
|
| 97 |
+
def create_fallback_mask(self, sentence, ngrams):
|
| 98 |
+
"""Create a fallback mask when normal strategies fail."""
|
| 99 |
+
try:
|
| 100 |
+
words = sentence.split()
|
| 101 |
+
if not words:
|
| 102 |
+
return None
|
| 103 |
+
|
| 104 |
+
# Find any non-stopword that isn't in an ngram
|
| 105 |
+
ngram_positions = []
|
| 106 |
+
for positions in ngrams.values():
|
| 107 |
+
for start, end in positions:
|
| 108 |
+
ngram_positions.append((start, end))
|
| 109 |
+
ngram_positions.sort()
|
| 110 |
+
|
| 111 |
+
# Find first eligible word
|
| 112 |
+
for idx, word in enumerate(words):
|
| 113 |
+
if not self.is_stopword(word) and not self.in_any_ngram(idx, ngram_positions):
|
| 114 |
+
masked_words = words.copy()
|
| 115 |
+
masked_words[idx] = self.tokenizer.mask_token
|
| 116 |
+
tqdm.write(f"[INFO] Fallback mask created at position {idx}: '{word}'")
|
| 117 |
+
return " ".join(masked_words), [idx]
|
| 118 |
+
|
| 119 |
+
# If no eligible word found, just mask the first non-stop word
|
| 120 |
+
for idx, word in enumerate(words):
|
| 121 |
+
if not self.is_stopword(word):
|
| 122 |
+
masked_words = words.copy()
|
| 123 |
+
masked_words[idx] = self.tokenizer.mask_token
|
| 124 |
+
tqdm.write(f"[INFO] Last resort fallback mask created at position {idx}: '{word}'")
|
| 125 |
+
return " ".join(masked_words), [idx]
|
| 126 |
+
|
| 127 |
+
# If still nothing, mask the first word
|
| 128 |
+
if words:
|
| 129 |
+
masked_words = words.copy()
|
| 130 |
+
masked_words[0] = self.tokenizer.mask_token
|
| 131 |
+
return " ".join(masked_words), [0]
|
| 132 |
+
|
| 133 |
+
return None
|
| 134 |
+
except Exception as e:
|
| 135 |
+
tqdm.write(f"[ERROR] Error creating fallback mask: {e}")
|
| 136 |
+
return None
|
| 137 |
+
|
| 138 |
+
def mask_sentence_random(self, sentence, common_ngrams):
|
| 139 |
+
"""Mask random non-stopwords that are not part of common ngrams with controlled positioning."""
|
| 140 |
+
common_ngrams = self.verify_and_correct_ngram_positions(sentence, common_ngrams)
|
| 141 |
+
tqdm.write(f"[MaskingProcessor] Masking (random) sentence: {sentence}")
|
| 142 |
+
original_words = sentence.split()
|
| 143 |
+
|
| 144 |
+
# Handle punctuation
|
| 145 |
+
has_punctuation = False
|
| 146 |
+
punctuation = ''
|
| 147 |
+
if original_words and original_words[-1][-1] in ['.', ',', '!', '?', ';', ':', '"', "'"]:
|
| 148 |
+
has_punctuation = True
|
| 149 |
+
punctuation = original_words[-1][-1]
|
| 150 |
+
original_words[-1] = original_words[-1][:-1]
|
| 151 |
+
if not original_words[-1]: # If the word was just punctuation
|
| 152 |
+
original_words.pop()
|
| 153 |
+
|
| 154 |
+
# Get flattened ngram positions
|
| 155 |
+
ngram_positions = []
|
| 156 |
+
for positions in common_ngrams.values():
|
| 157 |
+
for start, end in positions:
|
| 158 |
+
ngram_positions.append((start, end))
|
| 159 |
+
ngram_positions.sort()
|
| 160 |
+
|
| 161 |
+
# Find all candidate indices (non-stopwords not in ngrams)
|
| 162 |
+
candidate_indices = []
|
| 163 |
+
for idx, word in enumerate(original_words):
|
| 164 |
+
if not self.is_stopword(word) and not self.in_any_ngram(idx, ngram_positions):
|
| 165 |
+
candidate_indices.append(idx)
|
| 166 |
+
|
| 167 |
+
# Debug print candidate words
|
| 168 |
+
print("Candidate words for masking:")
|
| 169 |
+
for idx in candidate_indices:
|
| 170 |
+
print(f" Position {idx}: '{original_words[idx]}'")
|
| 171 |
+
|
| 172 |
+
selected_indices = []
|
| 173 |
+
if ngram_positions:
|
| 174 |
+
# Before first ngram
|
| 175 |
+
before_first = [idx for idx in candidate_indices if idx < ngram_positions[0][0]]
|
| 176 |
+
if before_first:
|
| 177 |
+
num_to_select = min(1, len(before_first)) # Select 1 word
|
| 178 |
+
if num_to_select > 0:
|
| 179 |
+
selected = random.sample(before_first, num_to_select)
|
| 180 |
+
selected_indices.extend(selected)
|
| 181 |
+
|
| 182 |
+
# Between ngrams
|
| 183 |
+
for i in range(len(ngram_positions) - 1):
|
| 184 |
+
between = [idx for idx in candidate_indices
|
| 185 |
+
if ngram_positions[i][1] < idx < ngram_positions[i+1][0]]
|
| 186 |
+
if between:
|
| 187 |
+
num_to_select = min(2, len(between)) # Select between 1-2 words
|
| 188 |
+
if num_to_select > 0:
|
| 189 |
+
selected = random.sample(between, num_to_select)
|
| 190 |
+
selected_indices.extend(selected)
|
| 191 |
+
|
| 192 |
+
# After last ngram
|
| 193 |
+
after_last = [idx for idx in candidate_indices if idx > ngram_positions[-1][1]]
|
| 194 |
+
if after_last:
|
| 195 |
+
num_to_select = min(1, len(after_last)) # Select 1 word
|
| 196 |
+
if num_to_select > 0:
|
| 197 |
+
selected = random.sample(after_last, num_to_select)
|
| 198 |
+
selected_indices.extend(selected)
|
| 199 |
+
else:
|
| 200 |
+
# If no ngrams, pick up to 6 random candidates
|
| 201 |
+
if candidate_indices:
|
| 202 |
+
selected_indices = random.sample(candidate_indices,
|
| 203 |
+
min(6, len(candidate_indices)))
|
| 204 |
+
|
| 205 |
+
masked_words = original_words.copy()
|
| 206 |
+
for idx in selected_indices:
|
| 207 |
+
masked_words[idx] = self.tokenizer.mask_token
|
| 208 |
+
|
| 209 |
+
if has_punctuation:
|
| 210 |
+
masked_words.append(punctuation)
|
| 211 |
+
|
| 212 |
+
# Debug prints
|
| 213 |
+
print("Original sentence:", sentence)
|
| 214 |
+
print("Common ngrams:", common_ngrams)
|
| 215 |
+
print("Common ngram positions:", ngram_positions)
|
| 216 |
+
print("Candidate indices for masking:", candidate_indices)
|
| 217 |
+
print("Selected for masking:", selected_indices)
|
| 218 |
+
print("Masked sentence:", " ".join(masked_words))
|
| 219 |
+
|
| 220 |
+
return " ".join(masked_words), selected_indices
|
| 221 |
+
|
| 222 |
+
def mask_sentence_pseudorandom(self, sentence, common_ngrams):
|
| 223 |
+
"""Mask specific non-stopwords based on their position relative to ngrams."""
|
| 224 |
+
common_ngrams = self.verify_and_correct_ngram_positions(sentence, common_ngrams)
|
| 225 |
+
tqdm.write(f"[MaskingProcessor] Masking (pseudorandom) sentence: {sentence}")
|
| 226 |
+
random.seed(3) # Fixed seed for pseudorandom behavior
|
| 227 |
+
original_words = sentence.split()
|
| 228 |
+
|
| 229 |
+
# Handle punctuation
|
| 230 |
+
has_punctuation = False
|
| 231 |
+
punctuation = ''
|
| 232 |
+
if original_words and original_words[-1][-1] in ['.', ',', '!', '?', ';', ':', '"', "'"]:
|
| 233 |
+
has_punctuation = True
|
| 234 |
+
punctuation = original_words[-1][-1]
|
| 235 |
+
original_words[-1] = original_words[-1][:-1]
|
| 236 |
+
if not original_words[-1]: # If the word was just punctuation
|
| 237 |
+
original_words.pop()
|
| 238 |
+
|
| 239 |
+
# Get flattened ngram positions
|
| 240 |
+
ngram_positions = []
|
| 241 |
+
for positions in common_ngrams.values():
|
| 242 |
+
for start, end in positions:
|
| 243 |
+
ngram_positions.append((start, end))
|
| 244 |
+
ngram_positions.sort()
|
| 245 |
+
|
| 246 |
+
# Find all candidate indices (non-stopwords not in ngrams)
|
| 247 |
+
candidate_indices = []
|
| 248 |
+
for idx, word in enumerate(original_words):
|
| 249 |
+
if not self.is_stopword(word) and not self.in_any_ngram(idx, ngram_positions):
|
| 250 |
+
candidate_indices.append(idx)
|
| 251 |
+
|
| 252 |
+
# Debug print candidate words
|
| 253 |
+
print("Candidate words for masking:")
|
| 254 |
+
for idx in candidate_indices:
|
| 255 |
+
print(f" Position {idx}: '{original_words[idx]}'")
|
| 256 |
+
|
| 257 |
+
# PSEUDORANDOM SPECIFIC LOGIC:
|
| 258 |
+
selected_indices = []
|
| 259 |
+
if ngram_positions:
|
| 260 |
+
# Before first ngram
|
| 261 |
+
before_first = [idx for idx in candidate_indices if idx < ngram_positions[0][0]]
|
| 262 |
+
if before_first:
|
| 263 |
+
num_to_select = min(1, len(before_first)) # Select 1 word
|
| 264 |
+
if num_to_select > 0:
|
| 265 |
+
selected = random.sample(before_first, num_to_select)
|
| 266 |
+
selected_indices.extend(selected)
|
| 267 |
+
|
| 268 |
+
# Between ngrams
|
| 269 |
+
for i in range(len(ngram_positions) - 1):
|
| 270 |
+
between = [idx for idx in candidate_indices
|
| 271 |
+
if ngram_positions[i][1] < idx < ngram_positions[i+1][0]]
|
| 272 |
+
if between:
|
| 273 |
+
num_to_select = min(2, len(between)) # Select between 1-2 words
|
| 274 |
+
if num_to_select > 0:
|
| 275 |
+
selected = random.sample(between, num_to_select)
|
| 276 |
+
selected_indices.extend(selected)
|
| 277 |
+
|
| 278 |
+
# After last ngram
|
| 279 |
+
after_last = [idx for idx in candidate_indices if idx > ngram_positions[-1][1]]
|
| 280 |
+
if after_last:
|
| 281 |
+
num_to_select = min(1, len(after_last)) # Select 1 word
|
| 282 |
+
if num_to_select > 0:
|
| 283 |
+
selected = random.sample(after_last, num_to_select)
|
| 284 |
+
selected_indices.extend(selected)
|
| 285 |
+
else:
|
| 286 |
+
# If no ngrams, pick up to 6 random candidates
|
| 287 |
+
if candidate_indices:
|
| 288 |
+
selected_indices = random.sample(candidate_indices,
|
| 289 |
+
min(6, len(candidate_indices)))
|
| 290 |
+
|
| 291 |
+
masked_words = original_words.copy()
|
| 292 |
+
for idx in selected_indices:
|
| 293 |
+
masked_words[idx] = self.tokenizer.mask_token
|
| 294 |
+
|
| 295 |
+
if has_punctuation:
|
| 296 |
+
masked_words.append(punctuation)
|
| 297 |
+
|
| 298 |
+
# Debug prints
|
| 299 |
+
print("Original sentence:", sentence)
|
| 300 |
+
print("Common ngrams:", common_ngrams)
|
| 301 |
+
print("Common ngram positions:", ngram_positions)
|
| 302 |
+
print("Candidate indices for masking:", candidate_indices)
|
| 303 |
+
print("Selected for masking:", selected_indices)
|
| 304 |
+
print("Masked sentence:", " ".join(masked_words))
|
| 305 |
+
|
| 306 |
+
return " ".join(masked_words), selected_indices
|
| 307 |
+
|
| 308 |
+
def mask_sentence_entropy(self, sentence, common_ngrams):
|
| 309 |
+
"""Mask words with highest entropy that are not part of common ngrams."""
|
| 310 |
+
common_ngrams = self.verify_and_correct_ngram_positions(sentence, common_ngrams)
|
| 311 |
+
tqdm.write(f"[MaskingProcessor] Masking (entropy) sentence: {sentence}")
|
| 312 |
+
original_words = sentence.split()
|
| 313 |
+
|
| 314 |
+
# Handle punctuation
|
| 315 |
+
has_punctuation = False
|
| 316 |
+
punctuation = ''
|
| 317 |
+
if original_words and original_words[-1][-1] in ['.', ',', '!', '?', ';', ':', '"', "'"]:
|
| 318 |
+
has_punctuation = True
|
| 319 |
+
punctuation = original_words[-1][-1]
|
| 320 |
+
original_words[-1] = original_words[-1][:-1]
|
| 321 |
+
if not original_words[-1]: # If the word was just punctuation
|
| 322 |
+
original_words.pop()
|
| 323 |
+
|
| 324 |
+
# Get flattened ngram positions
|
| 325 |
+
ngram_positions = []
|
| 326 |
+
for positions in common_ngrams.values():
|
| 327 |
+
for start, end in positions:
|
| 328 |
+
ngram_positions.append((start, end))
|
| 329 |
+
ngram_positions.sort()
|
| 330 |
+
|
| 331 |
+
# Find all candidate indices (non-stopwords not in ngrams)
|
| 332 |
+
candidate_indices = []
|
| 333 |
+
for idx, word in enumerate(original_words):
|
| 334 |
+
if not self.is_stopword(word) and not self.in_any_ngram(idx, ngram_positions):
|
| 335 |
+
candidate_indices.append(idx)
|
| 336 |
+
|
| 337 |
+
# Debug print candidate words
|
| 338 |
+
print("Candidate words for masking:")
|
| 339 |
+
for idx in candidate_indices:
|
| 340 |
+
print(f" Position {idx}: '{original_words[idx]}'")
|
| 341 |
+
|
| 342 |
+
# ENTROPY SPECIFIC LOGIC:
|
| 343 |
+
# Calculate entropy for each candidate word
|
| 344 |
+
selected_indices = []
|
| 345 |
+
if candidate_indices:
|
| 346 |
+
# Organize candidates by position relative to ngrams
|
| 347 |
+
if ngram_positions:
|
| 348 |
+
# Group candidates by position
|
| 349 |
+
before_first = []
|
| 350 |
+
between_ngrams = {}
|
| 351 |
+
after_last = []
|
| 352 |
+
|
| 353 |
+
for idx in candidate_indices:
|
| 354 |
+
if idx < ngram_positions[0][0]:
|
| 355 |
+
before_first.append(idx)
|
| 356 |
+
elif idx > ngram_positions[-1][1]:
|
| 357 |
+
after_last.append(idx)
|
| 358 |
+
else:
|
| 359 |
+
# Find which ngram gap this belongs to
|
| 360 |
+
for i in range(len(ngram_positions) - 1):
|
| 361 |
+
if ngram_positions[i][1] < idx < ngram_positions[i+1][0]:
|
| 362 |
+
if i not in between_ngrams:
|
| 363 |
+
between_ngrams[i] = []
|
| 364 |
+
between_ngrams[i].append(idx)
|
| 365 |
+
|
| 366 |
+
# Before first ngram: select 1 highest entropy words
|
| 367 |
+
if before_first:
|
| 368 |
+
entropies = [(idx, self.calculate_word_entropy(sentence, idx)) for idx in before_first]
|
| 369 |
+
entropies.sort(key=lambda x: x[1], reverse=True) # Sort by entropy (highest first)
|
| 370 |
+
num_to_select = min(1, len(entropies)) # Select 1 word
|
| 371 |
+
selected_indices.extend([idx for idx, _ in entropies[:num_to_select]])
|
| 372 |
+
|
| 373 |
+
# For each gap between ngrams: select 1-2 highest entropy words
|
| 374 |
+
for group, indices in between_ngrams.items():
|
| 375 |
+
if indices:
|
| 376 |
+
entropies = [(idx, self.calculate_word_entropy(sentence, idx)) for idx in indices]
|
| 377 |
+
entropies.sort(key=lambda x: x[1], reverse=True) # Sort by entropy (highest first)
|
| 378 |
+
num_to_select = min(2, len(entropies)) # Select between 1-2 words
|
| 379 |
+
selected_indices.extend([idx for idx, _ in entropies[:num_to_select]])
|
| 380 |
+
|
| 381 |
+
# After last ngram: select 1 highest entropy words
|
| 382 |
+
if after_last:
|
| 383 |
+
entropies = [(idx, self.calculate_word_entropy(sentence, idx)) for idx in after_last]
|
| 384 |
+
entropies.sort(key=lambda x: x[1], reverse=True) # Sort by entropy (highest first)
|
| 385 |
+
num_to_select = min(1, len(entropies)) # Select 1 word
|
| 386 |
+
selected_indices.extend([idx for idx, _ in entropies[:num_to_select]])
|
| 387 |
+
else:
|
| 388 |
+
# If no ngrams, calculate entropy for all candidates
|
| 389 |
+
entropies = [(idx, self.calculate_word_entropy(sentence, idx)) for idx in candidate_indices]
|
| 390 |
+
# Sort by entropy (highest first)
|
| 391 |
+
entropies.sort(key=lambda x: x[1], reverse=True)
|
| 392 |
+
# Take top 6 or all if fewer
|
| 393 |
+
selected_indices = [idx for idx, _ in entropies[:min(6, len(entropies))]]
|
| 394 |
+
|
| 395 |
+
masked_words = original_words.copy()
|
| 396 |
+
for idx in selected_indices:
|
| 397 |
+
masked_words[idx] = self.tokenizer.mask_token
|
| 398 |
+
|
| 399 |
+
if has_punctuation:
|
| 400 |
+
masked_words.append(punctuation)
|
| 401 |
+
|
| 402 |
+
# Debug prints
|
| 403 |
+
print("Original sentence:", sentence)
|
| 404 |
+
print("Common ngrams:", common_ngrams)
|
| 405 |
+
print("Common ngram positions:", ngram_positions)
|
| 406 |
+
print("Candidate indices for masking:", candidate_indices)
|
| 407 |
+
print("Selected for masking:", selected_indices)
|
| 408 |
+
print("Masked sentence:", " ".join(masked_words))
|
| 409 |
+
|
| 410 |
+
return " ".join(masked_words), selected_indices
|
| 411 |
+
|
| 412 |
+
def calculate_mask_logits(self, original_sentence, original_mask_indices):
|
| 413 |
+
"""Calculate logits for masked positions."""
|
| 414 |
+
logger.info(f"Calculating mask logits for sentence: {original_sentence}")
|
| 415 |
+
words = original_sentence.split()
|
| 416 |
+
mask_logits = {}
|
| 417 |
+
for idx in original_mask_indices:
|
| 418 |
+
masked_words = words.copy()
|
| 419 |
+
masked_words[idx] = self.tokenizer.mask_token
|
| 420 |
+
masked_sentence = " ".join(masked_words)
|
| 421 |
+
input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"].to(self.device)
|
| 422 |
+
mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
|
| 423 |
+
with torch.no_grad():
|
| 424 |
+
outputs = self.model(input_ids)
|
| 425 |
+
logits = outputs.logits
|
| 426 |
+
mask_logits_tensor = logits[0, mask_token_index, :]
|
| 427 |
+
top_mask_logits, top_mask_indices = torch.topk(mask_logits_tensor, 100, dim=-1)
|
| 428 |
+
top_tokens = []
|
| 429 |
+
top_logits = []
|
| 430 |
+
seen_words = set()
|
| 431 |
+
for token_id, logit in zip(top_mask_indices[0], top_mask_logits[0]):
|
| 432 |
+
token = self.tokenizer.convert_ids_to_tokens(token_id.item())
|
| 433 |
+
if token.startswith('##'):
|
| 434 |
+
continue
|
| 435 |
+
word = self.tokenizer.convert_tokens_to_string([token]).strip()
|
| 436 |
+
if word and word not in seen_words:
|
| 437 |
+
seen_words.add(word)
|
| 438 |
+
top_tokens.append(word)
|
| 439 |
+
top_logits.append(logit.item())
|
| 440 |
+
if len(top_tokens) == 50:
|
| 441 |
+
break
|
| 442 |
+
mask_logits[idx] = {
|
| 443 |
+
"tokens": top_tokens,
|
| 444 |
+
"logits": top_logits
|
| 445 |
+
}
|
| 446 |
+
logger.info("Completed calculating mask logits.")
|
| 447 |
+
return mask_logits
|
| 448 |
+
|
| 449 |
+
def calculate_word_entropy(self, sentence, word_position):
|
| 450 |
+
"""Calculate entropy for a word at a specific position."""
|
| 451 |
+
logger.info(f"Calculating word entropy for position {word_position} in sentence: {sentence}")
|
| 452 |
+
words = sentence.split()
|
| 453 |
+
masked_words = words.copy()
|
| 454 |
+
masked_words[word_position] = self.tokenizer.mask_token
|
| 455 |
+
masked_sentence = " ".join(masked_words)
|
| 456 |
+
input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"].to(self.device)
|
| 457 |
+
mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
|
| 458 |
+
with torch.no_grad():
|
| 459 |
+
outputs = self.model(input_ids)
|
| 460 |
+
logits = outputs.logits
|
| 461 |
+
probs = torch.nn.functional.softmax(logits[0, mask_token_index], dim=-1)
|
| 462 |
+
entropy = -torch.sum(probs * torch.log(probs + 1e-9))
|
| 463 |
+
logger.info(f"Computed entropy: {entropy.item()}")
|
| 464 |
+
return entropy.item()
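The value returned here is the Shannon entropy of the model's distribution over the masked position, H = -sum_i p_i * log(p_i), with the 1e-9 term only guarding against log(0). A toy standalone illustration with a made-up distribution:

```python
import torch

# Toy distribution over 4 candidate tokens (values are illustrative).
probs = torch.tensor([0.7, 0.2, 0.05, 0.05])
entropy = -torch.sum(probs * torch.log(probs + 1e-9))
print(entropy.item())  # about 0.87 nats; a flatter distribution scores higher
```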
|
| 465 |
+
|
| 466 |
+
def process_sentences(self, sentences_list, common_grams, method="random"):
|
| 467 |
+
"""Process multiple sentences with the specified masking method."""
|
| 468 |
+
tqdm.write(f"[MaskingProcessor] Processing sentences using method: {method}")
|
| 469 |
+
results = {}
|
| 470 |
+
for sentence in tqdm(sentences_list, desc="Masking Sentences"):
|
| 471 |
+
try:
|
| 472 |
+
ngrams = common_grams.get(sentence, {})
|
| 473 |
+
|
| 474 |
+
if method == "random":
|
| 475 |
+
masked_sentence, original_mask_indices = self.mask_sentence_random(sentence, ngrams)
|
| 476 |
+
elif method == "pseudorandom":
|
| 477 |
+
masked_sentence, original_mask_indices = self.mask_sentence_pseudorandom(sentence, ngrams)
|
| 478 |
+
else: # entropy
|
| 479 |
+
masked_sentence, original_mask_indices = self.mask_sentence_entropy(sentence, ngrams)
|
| 480 |
+
|
| 481 |
+
# Skip if no masks were applied
|
| 482 |
+
if not original_mask_indices:
|
| 483 |
+
tqdm.write(f"[WARNING] No mask indices found for sentence with method {method}: {sentence[:50]}...")
|
| 484 |
+
# Create a fallback masked sentence with at least one mask
|
| 485 |
+
fallback_result = self.create_fallback_mask(sentence, ngrams)
|
| 486 |
+
if fallback_result:
|
| 487 |
+
masked_sentence, original_mask_indices = fallback_result
|
| 488 |
+
tqdm.write(f"[INFO] Created fallback mask for sentence")
|
| 489 |
+
else:
|
| 490 |
+
tqdm.write(f"[WARNING] Could not create fallback mask, skipping sentence")
|
| 491 |
+
continue
|
| 492 |
+
|
| 493 |
+
logits = self.calculate_mask_logits(sentence, original_mask_indices)
|
| 494 |
+
results[sentence] = {
|
| 495 |
+
"masked_sentence": masked_sentence,
|
| 496 |
+
"mask_logits": logits
|
| 497 |
+
}
|
| 498 |
+
logger.info(f"Processed sentence: {sentence}")
|
| 499 |
+
except Exception as e:
|
| 500 |
+
tqdm.write(f"[ERROR] Failed to process sentence with method {method}: {e}")
|
| 501 |
+
tqdm.write(f"Sentence: {sentence[:100]}...")
|
| 502 |
+
import traceback
|
| 503 |
+
tqdm.write(traceback.format_exc())
|
| 504 |
+
tqdm.write("[MaskingProcessor] Completed processing sentences.")
|
| 505 |
+
return results
|
| 506 |
+
|
| 507 |
+
@staticmethod
|
| 508 |
+
def identify_common_ngrams(sentences, entities):
|
| 509 |
+
"""Enhanced to handle possessive forms better"""
|
| 510 |
+
common_grams = {}
|
| 511 |
+
|
| 512 |
+
# Pre-process entities to handle variations
|
| 513 |
+
processed_entities = []
|
| 514 |
+
for entity in entities:
|
| 515 |
+
processed_entities.append(entity)
|
| 516 |
+
# Add possessive form if not already there
|
| 517 |
+
if not entity.endswith("'s") and not entity.endswith("s"):
|
| 518 |
+
processed_entities.append(f"{entity}'s")
|
| 519 |
+
|
| 520 |
+
for sentence in sentences:
|
| 521 |
+
words = sentence.split()
|
| 522 |
+
common_grams[sentence] = {}
|
| 523 |
+
|
| 524 |
+
# Look for each entity in the sentence
|
| 525 |
+
for entity in processed_entities:
|
| 526 |
+
entity_words = entity.split()
|
| 527 |
+
entity_len = len(entity_words)
|
| 528 |
+
|
| 529 |
+
# Convert entity words for matching
|
| 530 |
+
clean_entity_words = [clean_word(word) for word in entity_words]
|
| 531 |
+
|
| 532 |
+
# Find all occurrences
|
| 533 |
+
for i in range(len(words) - entity_len + 1):
|
| 534 |
+
is_match = True
|
| 535 |
+
for j, entity_word in enumerate(clean_entity_words):
|
| 536 |
+
if clean_word(words[i + j]) != entity_word:
|
| 537 |
+
is_match = False
|
| 538 |
+
break
|
| 539 |
+
|
| 540 |
+
if is_match:
|
| 541 |
+
# Use canonical form from entity list for consistency
|
| 542 |
+
base_entity = entity
|
| 543 |
+
if entity.endswith("'s") and any(e == entity[:-2] for e in processed_entities):
|
| 544 |
+
base_entity = entity[:-2]
|
| 545 |
+
|
| 546 |
+
if base_entity not in common_grams[sentence]:
|
| 547 |
+
common_grams[sentence][base_entity] = []
|
| 548 |
+
common_grams[sentence][base_entity].append((i, i + entity_len - 1))
|
| 549 |
+
|
| 550 |
+
return common_grams
|
| 551 |
+
if __name__ == "__main__":
    # Example test sentence
    test_sentence = "Kevin De Bruyne scored for Manchester City as they won the 2019-20 Premier League title."
    # Entities to preserve
    entities = ["Kevin De Bruyne", "Manchester City", "Premier League"]

    # Identify common n-grams
    common_grams = MaskingProcessor.identify_common_ngrams([test_sentence], entities)

    # Print detected n-grams
    print(f"Detected common n-grams: {common_grams[test_sentence]}")

    # Initialize the processor
    processor = MaskingProcessor(
        BertTokenizer.from_pretrained("bert-large-cased-whole-word-masking"),
        BertForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
    )

    # Test all three masking methods
    print("\nTesting Random Masking:")
    masked_random, indices_random = processor.mask_sentence_random(test_sentence, common_grams[test_sentence])

    print("\nTesting Pseudorandom Masking:")
    masked_pseudorandom, indices_pseudorandom = processor.mask_sentence_pseudorandom(test_sentence, common_grams[test_sentence])

    print("\nTesting Entropy Masking:")
    masked_entropy, indices_entropy = processor.mask_sentence_entropy(test_sentence, common_grams[test_sentence])

    # Print results
    print("\nResults:")
    print(f"Original: {test_sentence}")
    print(f"Random Masked: {masked_random}")
    print(f"Pseudorandom Masked: {masked_pseudorandom}")
    print(f"Entropy Masked: {masked_entropy}")
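Beyond the per-method test above, the batch entry point is process_sentences, which masks each sentence and also returns the mask logits per masked position. A hedged sketch of that flow (loading the BERT weights requires a download; the smaller bert-base-uncased is used here for illustration, matching the defaults in utils/config.yaml):

```python
from transformers import BertTokenizer, BertForMaskedLM
from utils.masking_methods import MaskingProcessor

sentences = ["Kevin De Bruyne scored for Manchester City as they won the 2019-20 Premier League title."]
entities = ["Kevin De Bruyne", "Manchester City", "Premier League"]

# Map each sentence to the n-gram spans that must not be masked.
common_grams = MaskingProcessor.identify_common_ngrams(sentences, entities)

processor = MaskingProcessor(
    BertTokenizer.from_pretrained("bert-base-uncased"),
    BertForMaskedLM.from_pretrained("bert-base-uncased"),
)

results = processor.process_sentences(sentences, common_grams, method="entropy")
for sentence, data in results.items():
    print(data["masked_sentence"])
    print(sorted(data["mask_logits"].keys()))  # word positions that were masked
```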
utils/non_melting_point.py
ADDED
|
@@ -0,0 +1,590 @@
| 1 |
+
import nltk
|
| 2 |
+
import logging
|
| 3 |
+
import spacy
|
| 4 |
+
from nltk.corpus import stopwords
|
| 5 |
+
from nltk.util import ngrams
|
| 6 |
+
from collections import Counter
|
| 7 |
+
import re
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
# Logging setup
|
| 11 |
+
logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s")
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
class NgramProcessor:
|
| 15 |
+
def __init__(self, models=None):
|
| 16 |
+
try:
|
| 17 |
+
nltk.data.find('corpora/stopwords')
|
| 18 |
+
except LookupError:
|
| 19 |
+
nltk.download('stopwords')
|
| 20 |
+
self.stop_words = set(stopwords.words('english'))
|
| 21 |
+
|
| 22 |
+
# Default to standard model if none specified
|
| 23 |
+
if models is None:
|
| 24 |
+
models = ["en_core_web_trf"]
|
| 25 |
+
|
| 26 |
+
# Load specified model
|
| 27 |
+
self.models = {}
|
| 28 |
+
for model_name in models:
|
| 29 |
+
try:
|
| 30 |
+
self.models[model_name] = spacy.load(model_name)
|
| 31 |
+
tqdm.write(f"[NgramProcessor] Loaded model: {model_name}")
|
| 32 |
+
except IOError:
|
| 33 |
+
tqdm.write(f"[NgramProcessor] Error: Model '{model_name}' not found. Please install it with:")
|
| 34 |
+
tqdm.write(f"python -m spacy download {model_name}")
|
| 35 |
+
except Exception as e:
|
| 36 |
+
tqdm.write(f"[NgramProcessor] Error loading model '{model_name}': {str(e)}")
|
| 37 |
+
|
| 38 |
+
# Set primary NLP model for other processes
|
| 39 |
+
if "en_core_web_trf" in self.models:
|
| 40 |
+
self.nlp = self.models["en_core_web_trf"]
|
| 41 |
+
elif len(self.models) > 0:
|
| 42 |
+
# Use first available model as primary if preferred one isn't available
|
| 43 |
+
self.nlp = next(iter(self.models.values()))
|
| 44 |
+
else:
|
| 45 |
+
raise ValueError("No spaCy model was successfully loaded")
|
| 46 |
+
|
| 47 |
+
# Add custom entity patterns for numerical ranges to primary model
|
| 48 |
+
if "entity_ruler" not in self.nlp.pipe_names:
|
| 49 |
+
ruler = self.nlp.add_pipe("entity_ruler", before="ner")
|
| 50 |
+
patterns = [
|
| 51 |
+
{"label": "CARDINAL", "pattern": [{"TEXT": {"REGEX": "\\d+-\\d+"}}]}, # Pattern for ranges like "7-10"
|
| 52 |
+
{"label": "PERCENT", "pattern": [{"TEXT": {"REGEX": "\\d+%"}}]} # Pattern for percentages
|
| 53 |
+
]
|
| 54 |
+
ruler.add_patterns(patterns)
|
| 55 |
+
|
| 56 |
+
# Create special pattern for numerical ranges
|
| 57 |
+
self.number_range_pattern = re.compile(r'\b(\d+(?:-\d+)+)\b')
|
| 58 |
+
|
| 59 |
+
tqdm.write("[NgramProcessor] Initialized with stopwords, spaCy NLP model, and numerical range detection.")
|
| 60 |
+
|
| 61 |
+
def remove_stopwords(self, text):
|
| 62 |
+
words = re.findall(r'\w+', text.lower())
|
| 63 |
+
filtered_words = [word for word in words if word not in self.stop_words]
|
| 64 |
+
return ' '.join(filtered_words)
|
| 65 |
+
|
| 66 |
+
def extract_number_ranges(self, sentences):
|
| 67 |
+
"""Extract numerical ranges like '7-10' from sentences"""
|
| 68 |
+
tqdm.write("[NgramProcessor] Extracting numerical ranges...")
|
| 69 |
+
|
| 70 |
+
number_ranges = []
|
| 71 |
+
range_counts = Counter()
|
| 72 |
+
|
| 73 |
+
for sentence in sentences:
|
| 74 |
+
# Find all numerical ranges in the sentence
|
| 75 |
+
matches = self.number_range_pattern.findall(sentence)
|
| 76 |
+
for match in matches:
|
| 77 |
+
range_counts[match] += 1
|
| 78 |
+
|
| 79 |
+
# Add all ranges that appear in all sentences (threshold for ranges)
|
| 80 |
+
for range_text, count in range_counts.items():
|
| 81 |
+
if count >= 1:
|
| 82 |
+
number_ranges.append(range_text)
|
| 83 |
+
|
| 84 |
+
tqdm.write(f"[NgramProcessor] Found {len(number_ranges)} numerical ranges: {number_ranges}")
|
| 85 |
+
|
| 86 |
+
return number_ranges
|
| 87 |
+
def extract_standalone_numbers(self, sentences):
|
| 88 |
+
"""Extract standalone numerical values from sentences"""
|
| 89 |
+
tqdm.write("[NgramProcessor] Extracting standalone numbers...")
|
| 90 |
+
|
| 91 |
+
# Two patterns: one for percentages, one for regular numbers
|
| 92 |
+
percentage_pattern = re.compile(r'\b\d+%\b') # Only matches numbers with % sign
|
| 93 |
+
number_pattern = re.compile(r'\b\d+\b') # Only matches standalone numbers
|
| 94 |
+
|
| 95 |
+
percentage_counts = Counter()
|
| 96 |
+
number_counts = Counter()
|
| 97 |
+
percentage_values = set() # Store the numeric part of percentages for cross-reference
|
| 98 |
+
|
| 99 |
+
# First pass: Find all percentages
|
| 100 |
+
for sentence in sentences:
|
| 101 |
+
# Extract all percentages first
|
| 102 |
+
percentage_matches = percentage_pattern.findall(sentence)
|
| 103 |
+
for match in percentage_matches:
|
| 104 |
+
percentage_counts[match] += 1
|
| 105 |
+
# Store the numeric part for later comparison
|
| 106 |
+
numeric_part = match.rstrip('%')
|
| 107 |
+
percentage_values.add(numeric_part)
|
| 108 |
+
|
| 109 |
+
# Second pass: Find standalone numbers
|
| 110 |
+
for sentence in sentences:
|
| 111 |
+
# Only look for standalone numbers now
|
| 112 |
+
number_matches = number_pattern.findall(sentence)
|
| 113 |
+
for match in number_matches:
|
| 114 |
+
# Avoid double counting numbers that we already counted as percentages
|
| 115 |
+
if match not in percentage_values:
|
| 116 |
+
number_counts[match] += 1
|
| 117 |
+
|
| 118 |
+
# Process percentages first (they have priority)
|
| 119 |
+
threshold = max(1, int(len(sentences) * 1.0))
|
| 120 |
+
standalone_numbers = []
|
| 121 |
+
|
| 122 |
+
# Add percentages that meet the threshold
|
| 123 |
+
for num, count in percentage_counts.items():
|
| 124 |
+
if count >= threshold:
|
| 125 |
+
standalone_numbers.append(num) # Already has % sign
|
| 126 |
+
|
| 127 |
+
# Then add standalone numbers, converting to percentage format if needed
|
| 128 |
+
for num, count in number_counts.items():
|
| 129 |
+
if count >= threshold:
|
| 130 |
+
# If this number also appeared as part of a percentage, use the percentage format
|
| 131 |
+
if num in percentage_values:
|
| 132 |
+
standalone_numbers.append(f"{num}%")
|
| 133 |
+
else:
|
| 134 |
+
standalone_numbers.append(num)
|
| 135 |
+
|
| 136 |
+
tqdm.write(f"[NgramProcessor] Found {len(standalone_numbers)} standalone numbers: {standalone_numbers}")
|
| 137 |
+
return standalone_numbers
|
| 138 |
+
|
| 139 |
+
def extract_regex_subsequences(self, sentences):
|
| 140 |
+
"""Extract potential subsequences using regex patterns before applying NLP"""
|
| 141 |
+
tqdm.write("[NgramProcessor] Extracting regex subsequences...")
|
| 142 |
+
|
| 143 |
+
# Find potential multi-word subsequences (2-5 words) that occur across sentences
|
| 144 |
+
potential_subsequences = set()
|
| 145 |
+
|
| 146 |
+
# Process each sentence to find multi-word phrases
|
| 147 |
+
for sentence in sentences:
|
| 148 |
+
# First, clean the sentence by removing punctuation and converting to lowercase
|
| 149 |
+
clean_sentence = re.sub(r'[^\w\s&-./\'()[\]$€£¥+%]', ' ', sentence.lower())
|
| 150 |
+
|
| 151 |
+
# Extract sequences of 2-6 words
|
| 152 |
+
for i in range(2, 7): # Try sequences of length 2-6 words
|
| 153 |
+
pattern = r'\b(\w+(?:[-&\s./\'()[\]$€£¥+%]+\w+){' + str(i-1) + r'})\b'
|
| 154 |
+
matches = re.findall(pattern, clean_sentence)
|
| 155 |
+
potential_subsequences.update(matches)
|
| 156 |
+
|
| 157 |
+
# Filter out sequences that consist only of stopwords (but preserve numbers)
|
| 158 |
+
filtered_subsequences = []
|
| 159 |
+
for subseq in potential_subsequences:
|
| 160 |
+
words = re.split(r'[\s-]+', subseq) # Split on spaces or hyphens
|
| 161 |
+
|
| 162 |
+
# Function to check if a word is a number or percentage
|
| 163 |
+
def is_numeric(word):
|
| 164 |
+
return bool(re.match(r'^\d+(\.\d+)?%?$|^\d+-\d+$', word))
|
| 165 |
+
|
| 166 |
+
# Skip if ALL words are stopwords and none are numeric
|
| 167 |
+
if all((word in self.stop_words and not is_numeric(word)) for word in words):
|
| 168 |
+
tqdm.write(f"[NgramProcessor] Skipping all-stopword phrase: {subseq}")
|
| 169 |
+
continue
|
| 170 |
+
|
| 171 |
+
# Keep if sequence has significant words (not just stopwords)
|
| 172 |
+
# OR if it contains numbers/percentages
|
| 173 |
+
if len(words) > 1 and (
|
| 174 |
+
any(word not in self.stop_words and (len(word) > 2 or is_numeric(word)) for word in words)
|
| 175 |
+
):
|
| 176 |
+
# Additional check to reject if standalone "the" or other common stopwords
|
| 177 |
+
if not (len(words) == 1 and words[0] in self.stop_words and not is_numeric(words[0])):
|
| 178 |
+
filtered_subsequences.append(subseq)
|
| 179 |
+
|
| 180 |
+
# Count occurrences across all sentences
|
| 181 |
+
subseq_counts = Counter()
|
| 182 |
+
for subseq in filtered_subsequences:
|
| 183 |
+
for sentence in sentences:
|
| 184 |
+
if re.search(r'\b' + re.escape(subseq) + r'\b', sentence.lower()):
|
| 185 |
+
subseq_counts[subseq] += 1
|
| 186 |
+
|
| 187 |
+
# Keep only subsequences that appear in multiple sentences
|
| 188 |
+
threshold = max(2, int(len(sentences) * 1.0)) # threshold to catch all patterns
|
| 189 |
+
regex_candidates = [subseq for subseq, count in subseq_counts.items()
|
| 190 |
+
if count >= threshold]
|
| 191 |
+
|
| 192 |
+
tqdm.write(f"[NgramProcessor] Found {len(regex_candidates)} regex subsequences")
|
| 193 |
+
return regex_candidates
|
| 194 |
+
|
| 195 |
+
def filter_standalone_stopwords(self, ngrams_dict):
|
| 196 |
+
"""Remove standalone stopwords and very short terms from the ngrams dictionary"""
|
| 197 |
+
filtered_dict = {}
|
| 198 |
+
for sentence, ngrams in ngrams_dict.items():
|
| 199 |
+
filtered_dict[sentence] = {}
|
| 200 |
+
for ngram, indices in ngrams.items():
|
| 201 |
+
words = ngram.split()
|
| 202 |
+
# Skip single stopwords and very short terms UNLESS they are numbers
|
| 203 |
+
if (len(words) == 1 and (words[0] in self.stop_words or len(words[0]) < 3)):
|
| 204 |
+
# Exception for numbers
|
| 205 |
+
if len(words) == 1 and re.match(r'^\d+$', words[0]):
|
| 206 |
+
filtered_dict[sentence][ngram] = indices
|
| 207 |
+
continue
|
| 208 |
+
else:
|
| 209 |
+
continue
|
| 210 |
+
# Skip if ALL words are stopwords
|
| 211 |
+
if all(word in self.stop_words for word in words):
|
| 212 |
+
continue
|
| 213 |
+
filtered_dict[sentence][ngram] = indices
|
| 214 |
+
return filtered_dict
|
| 215 |
+
|
| 216 |
+
def extract_named_entities(self, sentences):
|
| 217 |
+
entity_counter = Counter()
|
| 218 |
+
|
| 219 |
+
# Process each sentence with each model
|
| 220 |
+
for model_name, nlp_model in self.models.items():
|
| 221 |
+
tqdm.write(f"[NgramProcessor] Extracting entities with model: {model_name}")
|
| 222 |
+
docs = list(nlp_model.pipe(sentences))
|
| 223 |
+
|
| 224 |
+
# Process each sentence
|
| 225 |
+
for doc in docs:
|
| 226 |
+
for ent in doc.ents:
|
| 227 |
+
# Include entity types relevant to this model
|
| 228 |
+
# This is a comprehensive list - some models may not use all these types
|
| 229 |
+
if ent.label_ in {
|
| 230 |
+
# People, organizations, locations
|
| 231 |
+
"PERSON", "ORG", "GPE", "LOC", "NORP",
|
| 232 |
+
|
| 233 |
+
# Facilities and products
|
| 234 |
+
"FAC", "PRODUCT", "WORK_OF_ART", "EVENT",
|
| 235 |
+
|
| 236 |
+
# Numeric entities
|
| 237 |
+
"DATE", "TIME", "MONEY", "QUANTITY", "PERCENT", "CARDINAL", "ORDINAL",
|
| 238 |
+
|
| 239 |
+
# Others
|
| 240 |
+
"LAW", "LANGUAGE",
|
| 241 |
+
|
| 242 |
+
# Scientific entities
|
| 243 |
+
"SCIENTIFIC", "SUBSTANCE", "CHEMICAL", "TECHNOLOGY",
|
| 244 |
+
|
| 245 |
+
# Medical entities
|
| 246 |
+
"DISEASE", "MEDICAL", "CLINICAL", "TREATMENT", "SYMPTOM", "DIAGNOSTIC",
|
| 247 |
+
"ANATOMICAL", "BIOLOGY", "GENE", "PROTEIN", "DRUG",
|
| 248 |
+
|
| 249 |
+
# Legal entities
|
| 250 |
+
"LEGAL", "COURT", "STATUTE", "PROVISION", "CASE_CITATION", "JUDGE",
|
| 251 |
+
"LEGAL_ROLE", "REGULATION", "CONTRACT"
|
| 252 |
+
}:
|
| 253 |
+
# Handle possessive forms by stripping 's
|
| 254 |
+
clean_entity = re.sub(r"'s\b", "", ent.text.lower()).strip()
|
| 255 |
+
# Add model name prefix to distinguish sources
|
| 256 |
+
entity_counter[clean_entity] += 1
|
| 257 |
+
|
| 258 |
+
threshold = max(1, len(sentences) * 1.0) # Adjusted threshold for entities
|
| 259 |
+
return [ent for ent, count in entity_counter.items() if count >= threshold]
|
| 260 |
+
|
| 261 |
+
def extract_domain_specific_entities(self, text):
|
| 262 |
+
"""Extract entities from all models and categorize by domain"""
|
| 263 |
+
domain_entities = {}
|
| 264 |
+
|
| 265 |
+
for model_name, nlp_model in self.models.items():
|
| 266 |
+
doc = nlp_model(text)
|
| 267 |
+
domain_entities[model_name] = [(ent.text, ent.label_) for ent in doc.ents]
|
| 268 |
+
|
| 269 |
+
return domain_entities
|
| 270 |
+
|
| 271 |
+
def is_substring_of_any(self, ngram, common_ngrams):
|
| 272 |
+
for other_ngram in common_ngrams:
|
| 273 |
+
if ngram != other_ngram and ngram in other_ngram:
|
| 274 |
+
return True
|
| 275 |
+
return False
|
| 276 |
+
|
| 277 |
+
def find_filtered_ngrams(self, sentences):
|
| 278 |
+
tqdm.write("[NgramProcessor] Processing...")
|
| 279 |
+
|
| 280 |
+
# Step 1: First extract numerical ranges or standalone numbers (special priority)
|
| 281 |
+
number_ranges = self.extract_number_ranges(sentences)
|
| 282 |
+
standalone_numbers = self.extract_standalone_numbers(sentences)
|
| 283 |
+
|
| 284 |
+
# Step 2: Use regex to find common subsequences
|
| 285 |
+
regex_subsequences = self.extract_regex_subsequences(sentences)
|
| 286 |
+
tqdm.write(f"[NgramProcessor] Regex Subsequences: {regex_subsequences}")
|
| 287 |
+
|
| 288 |
+
# Step 3: Then apply spaCy to detect named entities
|
| 289 |
+
named_entities = self.extract_named_entities(sentences)
|
| 290 |
+
# Make sure percentage values have proper format
|
| 291 |
+
for i, entity in enumerate(named_entities):
|
| 292 |
+
if re.match(r'\d+$', entity) and any(f"{entity}%" in sentence for sentence in sentences):
|
| 293 |
+
# Replace standalone digit with percentage if it appears as percentage in text
|
| 294 |
+
named_entities[i] = f"{entity}%"
|
| 295 |
+
|
| 296 |
+
tqdm.write(f"[NgramProcessor] Named Entities: {named_entities}")
|
| 297 |
+
|
| 298 |
+
# Step 4: Consolidate and filter all detected patterns
|
| 299 |
+
# Collect all patterns in one list
|
| 300 |
+
all_patterns = number_ranges + regex_subsequences + named_entities + standalone_numbers
|
| 301 |
+
|
| 302 |
+
# Sort by length (longer first) to prioritize more specific patterns
|
| 303 |
+
all_patterns.sort(key=len, reverse=True)
|
| 304 |
+
|
| 305 |
+
# Remove duplicates while preserving order
|
| 306 |
+
unique_patterns = []
|
| 307 |
+
seen = set()
|
| 308 |
+
for pattern in all_patterns:
|
| 309 |
+
if pattern not in seen:
|
| 310 |
+
# Check if this pattern is a substring of any already selected pattern
|
| 311 |
+
is_substring = False
|
| 312 |
+
for selected_pattern in unique_patterns:
|
| 313 |
+
if pattern in selected_pattern and pattern != selected_pattern:
|
| 314 |
+
is_substring = True
|
| 315 |
+
break
|
| 316 |
+
if not is_substring:
|
| 317 |
+
unique_patterns.append(pattern)
|
| 318 |
+
seen.add(pattern)
|
| 319 |
+
# Re-index sequentially
|
| 320 |
+
indexed_patterns = [(i+1, pattern) for i, pattern in enumerate(unique_patterns)]
|
| 321 |
+
self.indexed_patterns = indexed_patterns
|
| 322 |
+
non_melting_points = [pattern for _, pattern in indexed_patterns]
|
| 323 |
+
tqdm.write(f"[NgramProcessor] Filtered non_melting_points: {non_melting_points}")
|
| 324 |
+
tqdm.write(f"[NgramProcessor] Filtered non-melting points: {len(non_melting_points)}")
|
| 325 |
+
|
| 326 |
+
# Filter out patterns that are substrings of longer patterns or standalone numbers
|
| 327 |
+
standalone_numbers_set = set(standalone_numbers)
|
| 328 |
+
non_melting_points = []
|
| 329 |
+
for pattern in unique_patterns:
|
| 330 |
+
is_substring = False
|
| 331 |
+
for longer_pattern in non_melting_points:
|
| 332 |
+
# Check if pattern is contained within a longer pattern
|
| 333 |
+
if pattern in longer_pattern:
|
| 334 |
+
is_substring = True
|
| 335 |
+
break
|
| 336 |
+
if not is_substring or pattern in standalone_numbers_set:
|
| 337 |
+
non_melting_points.append(pattern)
|
| 338 |
+
|
| 339 |
+
# For remaining cases that might have been missed, apply NLTK n-gram extraction
|
| 340 |
+
# Only on cleaned sentences (less computationally expensive now)
|
| 341 |
+
clean_to_original = {}
|
| 342 |
+
sentences_cleaned = []
|
| 343 |
+
|
| 344 |
+
# Process sentences with spaCy to preserve entity information
|
| 345 |
+
docs = list(self.nlp.pipe(sentences))
|
| 346 |
+
|
| 347 |
+
for i, doc in enumerate(docs):
|
| 348 |
+
original_sentence = sentences[i]
|
| 349 |
+
entity_texts = {ent.text.lower() for ent in doc.ents if len(ent.text.split()) > 1}
|
| 350 |
+
|
| 351 |
+
# Tokenize while preserving entities and numerical ranges
|
| 352 |
+
tokens = []
|
| 353 |
+
j = 0
|
| 354 |
+
words = [token.text for token in doc]
|
| 355 |
+
|
| 356 |
+
while j < len(words):
|
| 357 |
+
# First check for numerical ranges
|
| 358 |
+
current_word = words[j].lower()
|
| 359 |
+
if self.number_range_pattern.match(current_word):
|
| 360 |
+
tokens.append(current_word)
|
| 361 |
+
j += 1
|
| 362 |
+
continue
|
| 363 |
+
|
| 364 |
+
# Then check for entities
|
| 365 |
+
matched_entity = None
|
| 366 |
+
for ent in sorted(entity_texts, key=len, reverse=True):
|
| 367 |
+
ent_words = ent.split()
|
| 368 |
+
if j + len(ent_words) <= len(words) and [w.lower() for w in words[j:j+len(ent_words)]] == ent_words:
|
| 369 |
+
matched_entity = " ".join(words[j:j+len(ent_words)])
|
| 370 |
+
tokens.append(matched_entity.lower()) # preserve full entity
|
| 371 |
+
j += len(ent_words)
|
| 372 |
+
break
|
| 373 |
+
|
| 374 |
+
if not matched_entity:
|
| 375 |
+
word = words[j].lower()
|
| 376 |
+
if word not in self.stop_words and re.match(r'\w+', word):
|
| 377 |
+
tokens.append(word)
|
| 378 |
+
j += 1
|
| 379 |
+
|
| 380 |
+
cleaned = " ".join(tokens)
|
| 381 |
+
sentences_cleaned.append(cleaned)
|
| 382 |
+
clean_to_original[cleaned] = original_sentence
|
| 383 |
+
|
| 384 |
+
# Step 5: Only run n-gram extraction on gaps not covered by regex and named entities
|
| 385 |
+
ngram_lengths = [4, 3, 2, 1] # Consider shorter n-grams now since we already have longer phrases
|
| 386 |
+
all_ngrams_by_length = {}
|
| 387 |
+
|
| 388 |
+
for n in ngram_lengths:
|
| 389 |
+
all_ngrams = []
|
| 390 |
+
for sentence in sentences_cleaned:
|
| 391 |
+
tokens = sentence.split()
|
| 392 |
+
if len(tokens) >= n:
|
| 393 |
+
sent_ngrams = list(ngrams(tokens, n))
|
| 394 |
+
all_ngrams.extend(sent_ngrams)
|
| 395 |
+
all_ngrams_by_length[n] = Counter(all_ngrams)
|
| 396 |
+
|
| 397 |
+
# Step 6: Add additional n-grams that are frequent but weren't caught by regex or named entities
|
| 398 |
+
threshold_factor = 1.0 # threshold since we're focusing on gaps
|
| 399 |
+
|
| 400 |
+
for n_size in sorted(ngram_lengths, reverse=True):
|
| 401 |
+
ngram_counts = all_ngrams_by_length[n_size]
|
| 402 |
+
threshold = max(2, int(len(sentences) * threshold_factor))
|
| 403 |
+
|
| 404 |
+
# Sort by count for efficiency
|
| 405 |
+
for ngram, count in ngram_counts.most_common():
|
| 406 |
+
if count >= threshold:
|
| 407 |
+
ngram_str = ' '.join(ngram)
|
| 408 |
+
# Skip if is a substring of existing n-grams or already in our collection
|
| 409 |
+
if ngram_str not in non_melting_points and not self.is_substring_of_any(ngram_str, non_melting_points):
|
| 410 |
+
non_melting_points.append(ngram_str)
|
| 411 |
+
|
| 412 |
+
# Create sorted version for efficient lookup
|
| 413 |
+
final_non_melting_points = non_melting_points.copy()
|
| 414 |
+
sorted_non_melting_points = sorted(final_non_melting_points, key=len, reverse=True)
|
| 415 |
+
final_indexed_patterns = [(i+1, pattern) for i, pattern in enumerate(sorted_non_melting_points)]
|
| 416 |
+
|
| 417 |
+
#Filter out n-grams that consist entirely of stop words
|
| 418 |
+
filtered_patterns = []
|
| 419 |
+
for idx, pattern in final_indexed_patterns:
|
| 420 |
+
words = pattern.lower().split()
|
| 421 |
+
|
| 422 |
+
# Check if the pattern is a number or contains a number
|
| 423 |
+
has_number = any(re.match(r'.*\d+.*', word) for word in words)
|
| 424 |
+
|
| 425 |
+
# If the pattern has a number OR has any non-stop word, keep it
|
| 426 |
+
if has_number or any(word not in self.stop_words for word in words):
|
| 427 |
+
filtered_patterns.append((idx, pattern))
|
| 428 |
+
else:
|
| 429 |
+
tqdm.write(f"[NgramProcessor] Removing n-gram with all stop words: {pattern}")
|
| 430 |
+
|
| 431 |
+
# Reassign filtered patterns with reindexed values
|
| 432 |
+
self.indexed_patterns = [(i+1, pattern) for i, (_, pattern) in enumerate(filtered_patterns)]
|
| 433 |
+
|
| 434 |
+
# Generate the results with more efficient regex matching
|
| 435 |
+
result = {}
|
| 436 |
+
for sentence in sentences:
|
| 437 |
+
sentence_result = {}
|
| 438 |
+
for _,ngram in self.indexed_patterns: # Use the filtered patterns
|
| 439 |
+
# Skip single word stopwords and short terms
|
| 440 |
+
words = ngram.split()
|
| 441 |
+
if len(words) == 1 and (words[0] in self.stop_words or len(words[0]) < 3):
|
| 442 |
+
continue
|
| 443 |
+
|
| 444 |
+
# Handle numerical ranges differently - need exact matching
|
| 445 |
+
if self.number_range_pattern.match(ngram):
|
| 446 |
+
pattern = re.compile(r'\b' + re.escape(ngram) + r'\b', re.IGNORECASE)
|
| 447 |
+
else:
|
| 448 |
+
# Compile the regex pattern once per n-gram - modified to handle special characters
|
| 449 |
+
pattern = re.compile(r'(?<!\w)' + re.escape(ngram) + r'(?!\w)', re.IGNORECASE)
|
| 450 |
+
|
| 451 |
+
matches = list(pattern.finditer(sentence))
|
| 452 |
+
|
| 453 |
+
if matches:
|
| 454 |
+
indices = []
|
| 455 |
+
for match in matches:
|
| 456 |
+
# Calculate word indices with improved handling for hyphenated terms
|
| 457 |
+
start_pos = match.start()
|
| 458 |
+
text_before = sentence[:start_pos]
|
| 459 |
+
|
| 460 |
+
# More accurate word counting that handles hyphenated terms
|
| 461 |
+
start_idx = len(re.findall(r'\s+', text_before)) + (0 if text_before.strip() == "" else 1)
|
| 462 |
+
|
| 463 |
+
# Count words in the matched n-gram (handling hyphens as single terms)
|
| 464 |
+
if self.number_range_pattern.match(ngram):
|
| 465 |
+
# Numerical ranges count as one term
|
| 466 |
+
ngram_word_count = 1
|
| 467 |
+
else:
|
| 468 |
+
ngram_word_count = len(re.findall(r'\S+', ngram))
|
| 469 |
+
|
| 470 |
+
end_idx = start_idx + ngram_word_count - 1
|
| 471 |
+
|
| 472 |
+
indices.append((start_idx, end_idx))
|
| 473 |
+
|
| 474 |
+
if indices: # Only add if we found valid indices
|
| 475 |
+
sentence_result[ngram] = indices
|
| 476 |
+
|
| 477 |
+
result[sentence] = sentence_result
|
| 478 |
+
|
| 479 |
+
# Apply the stopword filter before returning
|
| 480 |
+
result = self.filter_standalone_stopwords(result)
|
| 481 |
+
return result, dict(self.indexed_patterns)
|
| 482 |
+
|
| 483 |
+
def find_relative_order(self, sentence, common_ngrams):
|
| 484 |
+
# First, identify all possible matches without modifying the sentence
|
| 485 |
+
all_matches = []
|
| 486 |
+
|
| 487 |
+
for ngram in common_ngrams:
|
| 488 |
+
# Special handling for percentages
|
| 489 |
+
if any(char in ngram for char in '&-/.\'()[]$€£¥+%'):
|
| 490 |
+
pattern = re.compile(r'\b' + re.escape(ngram) + r'\b', re.IGNORECASE)
|
| 491 |
+
# Handle numerical ranges
|
| 492 |
+
elif self.number_range_pattern.match(ngram):
|
| 493 |
+
pattern = re.compile(r'\b' + re.escape(ngram) + r'\b', re.IGNORECASE)
|
| 494 |
+
else:
|
| 495 |
+
pattern = re.compile(r'(?<!\w)' + re.escape(ngram) + r"(?:'s)?(?!\w)", re.IGNORECASE)
|
| 496 |
+
|
| 497 |
+
for match in pattern.finditer(sentence):
|
| 498 |
+
start, end = match.span()
|
| 499 |
+
#store character position range, ngram text, and token count
|
| 500 |
+
all_matches.append((start, end, ngram, len(ngram.split())))
|
| 501 |
+
|
| 502 |
+
# Pre-process: identify all word spans in the original sentence
|
| 503 |
+
words = []
|
| 504 |
+
word_spans = []
|
| 505 |
+
for match in re.finditer(r'\S+', sentence):
|
| 506 |
+
words.append(match.group())
|
| 507 |
+
word_spans.append((match.start(), match.end()))
|
| 508 |
+
|
| 509 |
+
# Create a mapping from character positions to word indices
|
| 510 |
+
char_to_word_idx = {}
|
| 511 |
+
for i, (start, end) in enumerate(word_spans):
|
| 512 |
+
for pos in range(start, end + 1):
|
| 513 |
+
char_to_word_idx[pos] = i
|
| 514 |
+
|
| 515 |
+
# Sort by length in characters first, then by word count
|
| 516 |
+
all_matches.sort(key=lambda x: (-len(x[2]), -x[3], x[0]))
|
| 517 |
+
|
| 518 |
+
# Filter out ngrams that overlap with already claimed ranges
|
| 519 |
+
filtered_matches = []
|
| 520 |
+
claimed_ranges = []
|
| 521 |
+
|
| 522 |
+
for start, end, ngram, length in all_matches:
|
| 523 |
+
# Check if this match overlaps with any existing claimed range
|
| 524 |
+
is_overlapping = False
|
| 525 |
+
for c_start, c_end in claimed_ranges:
|
| 526 |
+
# Check for any overlap
|
| 527 |
+
if max(start, c_start) < min(end, c_end):
|
| 528 |
+
is_overlapping = True
|
| 529 |
+
break
|
| 530 |
+
|
| 531 |
+
if not is_overlapping:
|
| 532 |
+
# Add this ngram to our filtered list
|
| 533 |
+
filtered_matches.append((start, end, ngram, length))
|
| 534 |
+
# Claim its range
|
| 535 |
+
claimed_ranges.append((start, end))
|
| 536 |
+
|
| 537 |
+
# Sort filtered matches by position for final ordering
|
| 538 |
+
filtered_matches.sort(key=lambda x: x[0])
|
| 539 |
+
|
| 540 |
+
# Create word-level indices for the final matches
|
| 541 |
+
word_level_matches = []
|
| 542 |
+
for start, end, ngram, _ in filtered_matches:
|
| 543 |
+
# Find the word index for the start and end positions
|
| 544 |
+
try:
|
| 545 |
+
start_word_idx = char_to_word_idx.get(start, char_to_word_idx.get(start+1))
|
| 546 |
+
end_word_idx = char_to_word_idx.get(end-1, char_to_word_idx.get(end-2))
|
| 547 |
+
|
| 548 |
+
if start_word_idx is not None and end_word_idx is not None:
|
| 549 |
+
word_level_matches.append((start_word_idx, end_word_idx, ngram))
|
| 550 |
+
except (KeyError, IndexError):
|
| 551 |
+
# Skip this match if we can't determine word indices
|
| 552 |
+
continue
|
| 553 |
+
|
| 554 |
+
# Create the final order with 1-based indexing
|
| 555 |
+
ngram_to_index = {pattern: idx for idx, pattern in self.indexed_patterns}
|
| 556 |
+
relative_order = [(ngram_to_index.get(ngram, i+1), ngram) for i, (_, _, ngram) in enumerate(word_level_matches)]
|
| 557 |
+
|
| 558 |
+
return relative_order, sentence
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
# Example usage
|
| 562 |
+
if __name__ == "__main__":
|
| 563 |
+
# Test with NBA Play-In Tournament example
|
| 564 |
+
sentences = [
|
| 565 |
+
"The NBA Play-In Tournament tips off tonight as the No. 7-10 teams in each conference battle for a spot in the playoffs. Here's everything you need to know as the action unfolds.",
|
| 566 |
+
"Tonight the NBA Play-In Tournament begins with No. 7-10 teams from each conference competing for playoff spots. Here's your guide to following all the action.",
|
| 567 |
+
"The NBA Play-In Tournament kicks off this evening featuring the No. 7-10 teams across both conferences fighting for playoff positions. Here's what you should know about the upcoming games.",
|
| 568 |
+
"Starting tonight, the NBA Play-In Tournament will showcase the No. 7-10 teams from each conference as they compete for remaining playoff berths. Here's your complete guide to the action.",
|
| 569 |
+
"The NBA Play-In Tournament begins tonight with the No. 7-10 teams in both conferences battling for playoff spots. Here's everything you need to know about the upcoming games.",
|
| 570 |
+
"Tonight marks the start of the NBA Play-In Tournament where No. 7-10 teams in each conference compete for playoff positions. Here's your essential guide to following the action.",
|
| 571 |
+
"The NBA Play-In Tournament tips off tonight, featuring No. 7-10 teams from both conferences fighting for playoff berths. Here's what you need to know about the tournament.",
|
| 572 |
+
"Beginning tonight, the NBA Play-In Tournament will pit the No. 7-10 teams in each conference against each other for playoff spots. Here's everything you should know about the games.",
|
| 573 |
+
"The NBA Play-In Tournament starts tonight with No. 7-10 teams across both conferences competing for playoff positions. Here's your complete guide to all the action.",
|
| 574 |
+
"Tonight is the tip-off of the NBA Play-In Tournament where the No. 7-10 teams from each conference battle for remaining playoff spots. Here's what you need to know as the games unfold."
|
| 575 |
+
]
|
| 576 |
+
|
| 577 |
+
# Initialize with multiple models
|
| 578 |
+
processor = NgramProcessor(models=["en_core_web_trf"])
|
| 579 |
+
|
| 580 |
+
# Process with all models combined
|
| 581 |
+
common_ngrams,indexed_ngrams = processor.find_filtered_ngrams(sentences)
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
# Print results
|
| 585 |
+
print("Common n-grams with indices per sentence:")
|
| 586 |
+
for sentence in sentences:
|
| 587 |
+
order, updated_sentence = processor.find_relative_order(sentence, common_ngrams[sentence])
|
| 588 |
+
print(f"Sentence: {sentence}")
|
| 589 |
+
print(f"Order: {order}")
|
| 590 |
+
print()
|
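For orientation, here is a minimal, hypothetical driver sketch (not one of the uploaded files) showing the shape of what find_filtered_ngrams and find_relative_order return; it assumes the class above is importable as utils.non_melting_point and that the chosen spaCy model is installed locally:

# Hypothetical usage sketch; the import path and model name are assumptions.
from utils.non_melting_point import NgramProcessor

proc = NgramProcessor(models=["en_core_web_trf"])
sents = [
    "The quick brown fox jumps over the lazy dog.",
    "A speedy brown fox jumps over a lazy dog.",
]
per_sentence, indexed = proc.find_filtered_ngrams(sents)
# per_sentence: {sentence: {ngram: [(start_word_idx, end_word_idx), ...]}}
# indexed:      {1: "longest pattern", 2: "next pattern", ...}
for s in sents:
    order, _ = proc.find_relative_order(s, per_sentence[s])
    print(order)  # e.g. [(1, 'brown fox'), (2, 'dog')], depending on the extracted patterns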
utils/old/masking/masking_methods.py
ADDED
|
@@ -0,0 +1,355 @@
import random
import torch
from transformers import BertTokenizer, BertForMaskedLM
from nltk.corpus import stopwords
import nltk

# Ensure stopwords are downloaded
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')

class MaskingProcessor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
        self.stop_words = set(stopwords.words('english'))

    def adjust_ngram_indices(self, words, common_ngrams, remove_stopwords):
        """
        Adjust indices of common n-grams after removing stop words.

        Args:
            words (list): List of words in the original sentence.
            common_ngrams (dict): Common n-grams and their indices.

        Returns:
            dict: Adjusted common n-grams and their indices.
        """
        if not remove_stopwords:
            return common_ngrams

        non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        adjusted_ngrams = {}

        for ngram, positions in common_ngrams.items():
            adjusted_positions = []
            for start, end in positions:
                try:
                    new_start = non_stop_word_indices.index(start)
                    new_end = non_stop_word_indices.index(end)
                    adjusted_positions.append((new_start, new_end))
                except ValueError:
                    continue  # Skip if indices cannot be mapped
            adjusted_ngrams[ngram] = adjusted_positions

        return adjusted_ngrams

    # def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords):
    #     """
    #     Mask one word before the first common n-gram, one between two n-grams,
    #     and one after the last common n-gram (random selection).
    #
    #     Args:
    #         original_sentence (str): Original sentence
    #         common_ngrams (dict): Common n-grams and their indices
    #
    #     Returns:
    #         str: Masked sentence with original stop words retained
    #     """
    #     words = original_sentence.split()
    #     if remove_stopwords:
    #         non_stop_words = [word for word in words if word.lower() not in self.stop_words]
    #         non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
    #     else:
    #         non_stop_words = words
    #         non_stop_word_indices = list(range(len(words)))
    #     # non_stop_words = [word for word in words if word.lower() not in self.stop_words] if remove_stopwords else words
    #     adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)
    #
    #     mask_indices = []
    #     # Handle before the first common n-gram
    #     if adjusted_ngrams:
    #         first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
    #         if first_ngram_start > 0:
    #             mask_indices.append(random.randint(0, first_ngram_start - 1))
    #
    #         # Handle between common n-grams
    #         ngram_positions = list(adjusted_ngrams.values())
    #         for i in range(len(ngram_positions) - 1):
    #             end_prev = ngram_positions[i][-1][1]
    #             start_next = ngram_positions[i + 1][0][0]
    #             if start_next > end_prev + 1:
    #                 mask_indices.append(random.randint(end_prev + 1, start_next - 1))
    #
    #         # Handle after the last common n-gram
    #         last_ngram_end = ngram_positions[-1][-1][1]
    #         if last_ngram_end < len(non_stop_words) - 1:
    #             mask_indices.append(random.randint(last_ngram_end + 1, len(non_stop_words) - 1))
    #
    #     # Mask the chosen indices
    #     original_masked_sentence = words[:]
    #     # for idx in mask_indices:
    #     #     if idx not in [index for ngram_indices in adjusted_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
    #     #         non_stop_words[idx] = self.tokenizer.mask_token
    #     #         original_masked_sentence[idx] = self.tokenizer.mask_token
    #     for idx in mask_indices:
    #         if idx in [index for ngram_indices in adjusted_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
    #             continue  # Skip if index belongs to common n-grams
    #         if remove_stopwords:
    #             original_idx = non_stop_word_indices[idx]  # Map back to original indices
    #             original_masked_sentence[original_idx] = self.tokenizer.mask_token
    #         else:
    #             original_masked_sentence[idx] = self.tokenizer.mask_token
    #
    #     return " ".join(original_masked_sentence)

    def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords):
        """
        Mask one word before the first common n-gram, one between two n-grams,
        and one after the last common n-gram (random selection).

        Args:
            original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their indices
            remove_stopwords (bool): Whether to remove stop words

        Returns:
            str: Masked sentence with original stop words retained
        """
        words = original_sentence.split()
        if remove_stopwords:
            non_stop_words = [word for word in words if word.lower() not in self.stop_words]
            non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        else:
            non_stop_words = words
            non_stop_word_indices = list(range(len(words)))

        adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)

        # Collect all indices corresponding to common n-grams
        common_ngram_indices = {
            idx for ngram_positions in adjusted_ngrams.values()
            for start, end in ngram_positions
            for idx in range(start, end + 1)
        }

        mask_indices = []
        # Handle before the first common n-gram
        if adjusted_ngrams:
            first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
            if first_ngram_start > 0:
                potential_indices = [i for i in range(first_ngram_start) if i not in common_ngram_indices]
                if potential_indices:
                    mask_indices.append(random.choice(potential_indices))

            # Handle between common n-grams
            ngram_positions = list(adjusted_ngrams.values())
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][-1][1]
                start_next = ngram_positions[i + 1][0][0]
                potential_indices = [i for i in range(end_prev + 1, start_next) if i not in common_ngram_indices]
                if potential_indices:
                    mask_indices.append(random.choice(potential_indices))

            # Handle after the last common n-gram
            last_ngram_end = ngram_positions[-1][-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                potential_indices = [i for i in range(last_ngram_end + 1, len(non_stop_words)) if i not in common_ngram_indices]
                if potential_indices:
                    mask_indices.append(random.choice(potential_indices))

        # Mask the chosen indices
        original_masked_sentence = words[:]
        for idx in mask_indices:
            if remove_stopwords:
                original_idx = non_stop_word_indices[idx]  # Map back to original indices
                original_masked_sentence[original_idx] = self.tokenizer.mask_token
            else:
                original_masked_sentence[idx] = self.tokenizer.mask_token

        return " ".join(original_masked_sentence)

    def mask_sentence_entropy(self, original_sentence, common_ngrams, remove_stopwords):
        """
        Mask one word before the first common n-gram, one between two n-grams,
        and one after the last common n-gram (highest entropy selection).

        Args:
            original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their indices

        Returns:
            str: Masked sentence with original stop words retained
        """
        words = original_sentence.split()
        # non_stop_words = [word for word in words if word.lower() not in self.stop_words] if remove_stopwords else words
        if remove_stopwords:
            non_stop_words = [word for word in words if word.lower() not in self.stop_words]
            non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        else:
            non_stop_words = words
            non_stop_word_indices = list(range(len(words)))
        adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)
        entropy_scores = {}

        for idx, word in enumerate(non_stop_words):
            if idx in [index for ngram_indices in adjusted_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
                continue  # Skip words in common n-grams

            masked_sentence = non_stop_words[:idx] + [self.tokenizer.mask_token] + non_stop_words[idx + 1:]
            masked_sentence = " ".join(masked_sentence)
            input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
            mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits

            filtered_logits = logits[0, mask_token_index, :]
            probs = torch.softmax(filtered_logits, dim=-1)
            entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item()  # Add epsilon to prevent log(0)
            entropy_scores[idx] = entropy

        mask_indices = []

        # Handle before the first common n-gram
        if adjusted_ngrams:
            first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
            candidates = [i for i in range(first_ngram_start) if i in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

            # Handle between common n-grams
            ngram_positions = list(adjusted_ngrams.values())
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][-1][1]
                start_next = ngram_positions[i + 1][0][0]
                candidates = [i for i in range(end_prev + 1, start_next) if i in entropy_scores]
                if candidates:
                    mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

            # Handle after the last common n-gram
            last_ngram_end = ngram_positions[-1][-1][1]
            candidates = [i for i in range(last_ngram_end + 1, len(non_stop_words)) if i in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

        # Mask the chosen indices
        original_masked_sentence = words[:]
        # for idx in mask_indices:
        #     non_stop_words[idx] = self.tokenizer.mask_token
        #     original_masked_sentence[idx] = self.tokenizer.mask_token

        for idx in mask_indices:
            if idx in [index for ngram_indices in adjusted_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
                continue  # Skip if index belongs to common n-grams
            if remove_stopwords:
                original_idx = non_stop_word_indices[idx]  # Map back to original indices
                original_masked_sentence[original_idx] = self.tokenizer.mask_token
            else:
                original_masked_sentence[idx] = self.tokenizer.mask_token

        return " ".join(original_masked_sentence)

    def calculate_mask_logits(self, masked_sentence):
        """
        Calculate logits for masked tokens in the sentence using BERT.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens

        Returns:
            dict: Masked token indices and their logits
        """
        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits

        mask_logits = {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}
        return mask_logits

    def process_sentences(self, original_sentences, result_dict, method="random", remove_stopwords=False):
        """
        Process a list of sentences and calculate logits for masked tokens using the specified method.

        Args:
            original_sentences (list): List of original sentences
            result_dict (dict): Common n-grams and their indices for each sentence
            method (str): Masking method ("random" or "entropy")

        Returns:
            dict: Masked sentences and their logits for each sentence
        """
        results = {}

        for sentence, ngrams in result_dict.items():
            if method == "random":
                masked_sentence = self.mask_sentence_random(sentence, ngrams, remove_stopwords)
            elif method == "entropy":
                masked_sentence = self.mask_sentence_entropy(sentence, ngrams, remove_stopwords)
            else:
                raise ValueError("Invalid method. Choose 'random' or 'entropy'.")

            logits = self.calculate_mask_logits(masked_sentence)
            results[sentence] = {
                "masked_sentence": masked_sentence,
                "mask_logits": logits
            }

        return results

# Example usage
if __name__ == "__main__":
    # !!! Working both the cases regardless if the stopword is removed or not
    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A speedy brown fox jumps over a lazy dog.",
        "A swift brown fox leaps over the lethargic dog."
    ]
    result_dict = {
        'The quick brown fox jumps over the lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}
    }

    processor = MaskingProcessor()
    results_random = processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=True)
    # results_entropy = processor.process_sentences(sentences, result_dict, method="entropy", remove_stopwords=False)

    for sentence, output in results_random.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {output['masked_sentence']}")
        # # print(f"Mask Logits (Random): {output['mask_logits']}")
        # print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
        # print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
        # print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
        print('--------------------------------')
        # for mask_idx, logits in output["mask_logits"].items():
        #     print(f"Logits for [MASK] at position {mask_idx}:")
        #     print(f' logits : {logits[:5]}')  # List of logits for all vocabulary tokens

    # result_dict = {
    #     "The quick brown fox jumps over the lazy dog.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
    #     "A quick brown dog outpaces a lazy fox.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
    #     "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]}
    # }

    # print('--------------------------------')
    # for sentence, output in results_entropy.items():
    #     print(f"Original Sentence (Entropy): {sentence}")
    #     print(f"Masked Sentence (Entropy): {output['masked_sentence']}")
    #     # print(f"Mask Logits (Entropy): {output['mask_logits']}")
    #     print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
    #     print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
    #     print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
utils/old/masking/masking_methods_new_work.py
ADDED
|
@@ -0,0 +1,447 @@
| 1 |
+
import random
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import BertTokenizer, BertForMaskedLM
|
| 4 |
+
from nltk.corpus import stopwords
|
| 5 |
+
import nltk
|
| 6 |
+
|
| 7 |
+
# Ensure stopwords are downloaded
|
| 8 |
+
try:
|
| 9 |
+
nltk.data.find('corpora/stopwords')
|
| 10 |
+
except LookupError:
|
| 11 |
+
nltk.download('stopwords')
|
| 12 |
+
|
| 13 |
+
class MaskingProcessor:
|
| 14 |
+
def __init__(self):
|
| 15 |
+
|
| 16 |
+
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
| 17 |
+
self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
|
| 18 |
+
self.stop_words = set(stopwords.words('english'))
|
| 19 |
+
|
| 20 |
+
def remove_stopwords(self, words):
|
| 21 |
+
"""
|
| 22 |
+
Remove stopwords from the given list of words.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
words (list): List of words.
|
| 26 |
+
|
| 27 |
+
Returns:
|
| 28 |
+
list: List of non-stop words.
|
| 29 |
+
"""
|
| 30 |
+
return [word for word in words if word.lower() not in self.stop_words]
|
| 31 |
+
|
| 32 |
+
def adjust_ngram_indices(self, original_words, common_ngrams):
|
| 33 |
+
"""
|
| 34 |
+
Adjust indices of common n-grams after removing stopwords.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
original_words (list): Original list of words.
|
| 38 |
+
common_ngrams (dict): Common n-grams and their indices.
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
dict: Adjusted common n-grams with updated indices.
|
| 42 |
+
"""
|
| 43 |
+
non_stop_words = self.remove_stopwords(original_words)
|
| 44 |
+
original_to_non_stop = []
|
| 45 |
+
non_stop_idx = 0
|
| 46 |
+
|
| 47 |
+
for original_idx, word in enumerate(original_words):
|
| 48 |
+
if word.lower() not in self.stop_words:
|
| 49 |
+
original_to_non_stop.append((original_idx, non_stop_idx))
|
| 50 |
+
non_stop_idx += 1
|
| 51 |
+
|
| 52 |
+
adjusted_ngrams = {}
|
| 53 |
+
for ngram, positions in common_ngrams.items():
|
| 54 |
+
adjusted_positions = []
|
| 55 |
+
for start, end in positions:
|
| 56 |
+
try:
|
| 57 |
+
new_start = next(non_stop for orig, non_stop in original_to_non_stop if orig == start)
|
| 58 |
+
new_end = next(non_stop for orig, non_stop in original_to_non_stop if orig == end)
|
| 59 |
+
adjusted_positions.append((new_start, new_end))
|
| 60 |
+
except StopIteration:
|
| 61 |
+
continue # Skip if indices cannot be mapped
|
| 62 |
+
adjusted_ngrams[ngram] = adjusted_positions
|
| 63 |
+
|
| 64 |
+
return adjusted_ngrams
|
| 65 |
+
|
| 66 |
+
def mask_sentence_random(self, sentence, common_ngrams):
|
| 67 |
+
"""
|
| 68 |
+
Mask words in the sentence based on the specified rules after removing stopwords.
|
| 69 |
+
"""
|
| 70 |
+
original_words = sentence.split()
|
| 71 |
+
print(f' ---- original_words : {original_words} ----- ')
|
| 72 |
+
non_stop_words = self.remove_stopwords(original_words)
|
| 73 |
+
print(f' ---- non_stop_words : {non_stop_words} ----- ')
|
| 74 |
+
adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
|
| 75 |
+
print(f' ---- common_ngrams : {common_ngrams} ----- ')
|
| 76 |
+
print(f' ---- adjusted_ngrams : {adjusted_ngrams} ----- ')
|
| 77 |
+
|
| 78 |
+
mask_indices = []
|
| 79 |
+
|
| 80 |
+
# Extract n-gram positions in non-stop words
|
| 81 |
+
ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
|
| 82 |
+
|
| 83 |
+
# Mask a word before the first common n-gram
|
| 84 |
+
if ngram_positions:
|
| 85 |
+
print(f' ---- ngram_positions : {ngram_positions} ----- ')
|
| 86 |
+
first_ngram_start = ngram_positions[0][0]
|
| 87 |
+
print(f' ---- first_ngram_start : {first_ngram_start} ----- ')
|
| 88 |
+
if first_ngram_start > 0:
|
| 89 |
+
mask_index_before_ngram = random.randint(0, first_ngram_start-1)
|
| 90 |
+
print(f' ---- mask_index_before_ngram : {mask_index_before_ngram} ----- ')
|
| 91 |
+
mask_indices.append(mask_index_before_ngram)
|
| 92 |
+
|
| 93 |
+
# Mask words between common n-grams
|
| 94 |
+
for i in range(len(ngram_positions) - 1):
|
| 95 |
+
end_prev = ngram_positions[i][1]
|
| 96 |
+
print(f' ---- end_prev : {end_prev} ----- ') # END INDICE FROM PREV LOOP FUNKNLKNLKNLKNLKNLKNLSKDNFLKSDHJFLSDJKFH:KLSDHF:LHKSDF:HJKLDFS:HJKLDFSHJK:
                start_next = ngram_positions[i + 1][0]
                print(f' ---- start_next : {start_next} ----- ')
                if start_next > end_prev + 1:
                    mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
                    print(f' ---- mask_index_between_ngrams : {mask_index_between_ngrams} ----- ')
                    mask_indices.append(mask_index_between_ngrams)

            # Mask a word after the last common n-gram
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                print(f' ---- last_ngram_end : {last_ngram_end} ----- ')
                mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
                print(f' ---- mask_index_after_ngram : {mask_index_after_ngram} ----- ')
                mask_indices.append(mask_index_after_ngram)

        # Create mapping from non-stop words to original indices
        non_stop_to_original = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                non_stop_idx += 1

        # Map mask indices from non-stop word positions to original positions
        print(f' ---- non_stop_to_original : {non_stop_to_original} ----- ')
        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        print(f' ---- original_mask_indices : {original_mask_indices} ----- ')

        # Apply masks to the original sentence
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token

        return " ".join(masked_words)

    def mask_sentence_pseudorandom(self, sentence, common_ngrams):
        """
        Mask words in the sentence based on the specified rules after removing stopwords.
        """
        random.seed(42)
        original_words = sentence.split()
        print(f' ---- original_words : {original_words} ----- ')
        non_stop_words = self.remove_stopwords(original_words)
        print(f' ---- non_stop_words : {non_stop_words} ----- ')
        adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
        print(f' ---- common_ngrams : {common_ngrams} ----- ')
        print(f' ---- adjusted_ngrams : {adjusted_ngrams} ----- ')

        mask_indices = []

        # Extract n-gram positions in non-stop words
        ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]

        # Mask a word before the first common n-gram
        if ngram_positions:
            print(f' ---- ngram_positions : {ngram_positions} ----- ')
            first_ngram_start = ngram_positions[0][0]
            print(f' ---- first_ngram_start : {first_ngram_start} ----- ')
            if first_ngram_start > 0:
                mask_index_before_ngram = random.randint(0, first_ngram_start-1)
                print(f' ---- mask_index_before_ngram : {mask_index_before_ngram} ----- ')
                mask_indices.append(mask_index_before_ngram)

            # Mask words between common n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][1]
                print(f' ---- end_prev : {end_prev} ----- ')
                start_next = ngram_positions[i + 1][0]
                print(f' ---- start_next : {start_next} ----- ')
                if start_next > end_prev + 1:
                    mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
                    print(f' ---- mask_index_between_ngrams : {mask_index_between_ngrams} ----- ')
                    mask_indices.append(mask_index_between_ngrams)

            # Mask a word after the last common n-gram
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                print(f' ---- last_ngram_end : {last_ngram_end} ----- ')
                mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
                print(f' ---- mask_index_after_ngram : {mask_index_after_ngram} ----- ')
                mask_indices.append(mask_index_after_ngram)

        # Create mapping from non-stop words to original indices
        non_stop_to_original = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                non_stop_idx += 1

        # Map mask indices from non-stop word positions to original positions
        print(f' ---- non_stop_to_original : {non_stop_to_original} ----- ')
        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        print(f' ---- original_mask_indices : {original_mask_indices} ----- ')

        # Apply masks to the original sentence
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token

        return " ".join(masked_words)


    def calculate_word_entropy(self, sentence, word_position):
        """
        Calculate entropy for a specific word position in the sentence.

        Args:
            sentence (str): The input sentence
            word_position (int): Position of the word to calculate entropy for

        Returns:
            float: Entropy value for the word
        """
        words = sentence.split()
        masked_words = words.copy()
        masked_words[word_position] = self.tokenizer.mask_token
        masked_sentence = " ".join(masked_words)

        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits

        # Get probabilities for the masked position
        probs = torch.nn.functional.softmax(logits[0, mask_token_index], dim=-1)
        # Calculate entropy: -sum(p * log(p))
        entropy = -torch.sum(probs * torch.log(probs + 1e-9))

        return entropy.item()

    def mask_sentence_entropy(self, sentence, common_ngrams):
        """
        Mask words in the sentence based on entropy, following n-gram positioning rules.

        Args:
            sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their indices

        Returns:
            str: Masked sentence
        """
        original_words = sentence.split()
        non_stop_words = self.remove_stopwords(original_words)
        adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)

        # Create mapping from non-stop words to original indices
        non_stop_to_original = {}
        original_to_non_stop = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                original_to_non_stop[orig_idx] = non_stop_idx
                non_stop_idx += 1

        ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
        mask_indices = []

        if ngram_positions:
            # Handle words before first n-gram
            first_ngram_start = ngram_positions[0][0]
            if first_ngram_start > 0:
                # Calculate entropy for all candidate positions
                candidate_positions = range(0, first_ngram_start)
                entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                             for pos in candidate_positions]
                # Select position with highest entropy
                mask_indices.append(max(entropies, key=lambda x: x[1])[0])

            # Handle words between n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][1]
                start_next = ngram_positions[i + 1][0]
                if start_next > end_prev + 1:
                    candidate_positions = range(end_prev + 1, start_next)
                    entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                                 for pos in candidate_positions]
                    mask_indices.append(max(entropies, key=lambda x: x[1])[0])

            # Handle words after last n-gram
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                candidate_positions = range(last_ngram_end + 1, len(non_stop_words))
                entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                             for pos in candidate_positions]
                mask_indices.append(max(entropies, key=lambda x: x[1])[0])

        # Map mask indices to original sentence positions and apply masks
        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token

        return " ".join(masked_words)


    def calculate_mask_logits(self, masked_sentence):
        """
        Calculate logits for masked tokens in the sentence using BERT.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens.

        Returns:
            dict: Masked token indices and their logits.
        """
        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits

        mask_logits = {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}
        return mask_logits

    def process_sentences(self, sentences, result_dict, method="random"):
        """
        Process sentences and calculate logits for masked tokens.

        Args:
            sentences (list): List of sentences
            result_dict (dict): Dictionary of common n-grams
            method (str): Masking method ("random" or "entropy")

        Returns:
            dict: Masked sentences and logits for each sentence
        """
        results = {}

        for sentence, ngrams in result_dict.items():
            if method == "random":
                masked_sentence = self.mask_sentence_random(sentence, ngrams)
            elif method == "pseudorandom":
                masked_sentence = self.mask_sentence_pseudorandom(sentence, ngrams)
            else:  # entropy
                masked_sentence = self.mask_sentence_entropy(sentence, ngrams)

            logits = self.calculate_mask_logits(masked_sentence)
            results[sentence] = {
                "masked_sentence": masked_sentence,
                "mask_logits": logits
            }

        return results


if __name__ == "__main__":
    # !!! Working both the cases regardless if the stopword is removed or not
    sentences = [
        "The quick brown fox jumps over the lazy dog everyday.",
        # "A speedy brown fox jumps over a lazy dog.",
        # "A swift brown fox leaps over the lethargic dog."
    ]
    result_dict = {
        'The quick brown fox jumps over the lazy dog everyday.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        # 'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        # 'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}
    }

    processor = MaskingProcessor()
    # results_random = processor.process_sentences(sentences, result_dict)
    results_entropy = processor.process_sentences(sentences, result_dict, method="random")

    # results_entropy = processor.process_sentences(sentences, result_dict, method="entropy", remove_stopwords=False)

    for sentence, output in results_entropy.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {output['masked_sentence']}")
        # print(f"Mask Logits (Random): {output['mask_logits']}")
        print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
        print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
        print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
        print('--------------------------------')
        for mask_idx, logits in output["mask_logits"].items():
            print(f"Logits for [MASK] at position {mask_idx}:")
            print(f' logits : {logits[:5]}')  # List of logits for all vocabulary tokens
            print(f' len(logits) : {len(logits)}')


# -------------------------------------------------------------------------------------------
# def mask_sentence(self, sentence, common_ngrams):
#     """
#     Mask words in the sentence based on the specified rules after removing stopwords.

#     Args:
#         sentence (str): Original sentence.
#         common_ngrams (dict): Common n-grams and their indices.

#     Returns:
#         str: Masked sentence.
#     """
#     original_words = sentence.split()
#     print(f' ---- original_words : {original_words} ----- ')
#     non_stop_words = self.remove_stopwords(original_words)
#     print(f' ---- non_stop_words : {non_stop_words} ----- ')
#     adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
#     print(f' ---- common_ngrams : {common_ngrams} ----- ')
#     print(f' ---- adjusted_ngrams : {adjusted_ngrams} ----- ')

#     mask_indices = []

#     # Extract n-gram positions in non-stop words
#     ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
#     print(f' ---- ngram_positions : {ngram_positions} ----- ')
#     # Mask a word before the first common n-gram
#     if ngram_positions:
#         first_ngram_start = ngram_positions[0][0]
#         print(f' ---- first_ngram_start : {first_ngram_start} ----- ')
#         if first_ngram_start > 0:
#             mask_index_before_ngram = random.randint(0, first_ngram_start-1)
#             print(f' ---- mask_index_before_ngram : {mask_index_before_ngram} ----- ')
#             mask_indices.append(mask_index_before_ngram)

#         # Mask words between common n-grams
#         for i in range(len(ngram_positions) - 1):
#             end_prev = ngram_positions[i][1]
#             print(f' ---- end_prev : {end_prev} ----- ')
#             start_next = ngram_positions[i + 1][0]
#             print(f' ---- start_next : {start_next} ----- ')
#             if start_next > end_prev + 1:
#                 mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
#                 print(f' ---- mask_index_between_ngrams : {mask_index_between_ngrams} ----- ')
#                 mask_indices.append(mask_index_between_ngrams)

#         # Mask a word after the last common n-gram
#         last_ngram_end = ngram_positions[-1][1]
#         print(f' ---- last_ngram_end : {last_ngram_end} ----- ')
#         if last_ngram_end < len(non_stop_words) - 1:
#             mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
#             print(f' ---- mask_index_after_ngram : {mask_index_after_ngram} ----- ')
#             mask_indices.append(mask_index_after_ngram)

#     # Map mask indices back to original sentence
#     adjusted_indices = [
#         orig for orig, non_stop in enumerate(original_words)
#         if non_stop in mask_indices
#     ]

#     # Apply masks to the original sentence
#     for idx in adjusted_indices:
#         original_words[idx] = self.tokenizer.mask_token

#     return " ".join(original_words)
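Note (added for illustration; not part of the uploaded files): the entropy score used by calculate_word_entropy and mask_sentence_entropy above is the Shannon entropy of the model's distribution over the masked position. A minimal worked sketch, with a made-up five-token distribution standing in for the real vocabulary-sized softmax output:

    import torch

    # Toy probability distribution over a 5-token vocabulary for one [MASK] position.
    probs = torch.tensor([0.50, 0.20, 0.15, 0.10, 0.05])

    # Shannon entropy H = -sum(p * log p); the small epsilon mirrors the
    # log(probs + 1e-9) guard used in the files above to avoid log(0).
    entropy = -torch.sum(probs * torch.log(probs + 1e-9))
    print(entropy.item())  # ~1.33 nats; flatter distributions give higher entropy

Positions whose distribution is flatter (higher entropy, i.e. the model is less certain) are the ones selected for masking.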
utils/old/masking/masking_methods_ok_working.py
ADDED
@@ -0,0 +1,257 @@
import random
import torch
from transformers import BertTokenizer, BertForMaskedLM
from nltk.corpus import stopwords
import nltk

# Ensure stopwords are downloaded
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')

class MaskingProcessor:
    def __init__(self, ):
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
        self.stop_words = set(stopwords.words('english'))

    def adjust_ngram_indices(self, words, common_ngrams, remove_stopwords):
        """
        Adjust indices of common n-grams after removing stop words.

        Args:
            words (list): List of words in the original sentence.
            common_ngrams (dict): Common n-grams and their indices.

        Returns:
            dict: Adjusted common n-grams and their indices.
        """
        if not remove_stopwords:
            return common_ngrams

        non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        adjusted_ngrams = {}

        for ngram, positions in common_ngrams.items():
            adjusted_positions = []
            for start, end in positions:
                try:
                    new_start = non_stop_word_indices.index(start)
                    new_end = non_stop_word_indices.index(end)
                    adjusted_positions.append((new_start, new_end))
                except ValueError:
                    continue  # Skip if indices cannot be mapped
            adjusted_ngrams[ngram] = adjusted_positions

        return adjusted_ngrams

    def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords):
        """
        Mask one word before the first common n-gram, one between two n-grams,
        and one after the last common n-gram (random selection).

        Args:
            original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their indices

        Returns:
            str: Masked sentence with original stop words retained
        """
        words = original_sentence.split()
        non_stop_words = [word for word in words if word.lower() not in self.stop_words] if remove_stopwords else words
        adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)

        mask_indices = []
        # Handle before the first common n-gram
        if adjusted_ngrams:
            first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
            if first_ngram_start > 0:
                mask_indices.append(random.randint(0, first_ngram_start - 1))

            # Handle between common n-grams
            ngram_positions = list(adjusted_ngrams.values())
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][-1][1]
                start_next = ngram_positions[i + 1][0][0]
                if start_next > end_prev + 1:
                    mask_indices.append(random.randint(end_prev + 1, start_next - 1))

            # Handle after the last common n-gram
            last_ngram_end = ngram_positions[-1][-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                mask_indices.append(random.randint(last_ngram_end + 1, len(non_stop_words) - 1))

        # Mask the chosen indices
        original_masked_sentence = words[:]
        for idx in mask_indices:
            if idx not in [index for ngram_indices in adjusted_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
                non_stop_words[idx] = self.tokenizer.mask_token
                original_masked_sentence[idx] = self.tokenizer.mask_token

        return " ".join(original_masked_sentence)

    def mask_sentence_entropy(self, original_sentence, common_ngrams, remove_stopwords):
        """
        Mask one word before the first common n-gram, one between two n-grams,
        and one after the last common n-gram (highest entropy selection).

        Args:
            original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their indices

        Returns:
            str: Masked sentence with original stop words retained
        """
        words = original_sentence.split()
        non_stop_words = [word for word in words if word.lower() not in self.stop_words] if remove_stopwords else words
        adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)
        entropy_scores = {}

        for idx, word in enumerate(non_stop_words):
            if idx in [index for ngram_indices in adjusted_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
                continue  # Skip words in common n-grams

            masked_sentence = non_stop_words[:idx] + [self.tokenizer.mask_token] + non_stop_words[idx + 1:]
            masked_sentence = " ".join(masked_sentence)
            input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
            mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits

            filtered_logits = logits[0, mask_token_index, :]
            probs = torch.softmax(filtered_logits, dim=-1)
            entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item()  # Add epsilon to prevent log(0)
            entropy_scores[idx] = entropy

        mask_indices = []

        # Handle before the first common n-gram
        if adjusted_ngrams:
            first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
            candidates = [i for i in range(first_ngram_start) if i in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

            # Handle between common n-grams
            ngram_positions = list(adjusted_ngrams.values())
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][-1][1]
                start_next = ngram_positions[i + 1][0][0]
                candidates = [i for i in range(end_prev + 1, start_next) if i in entropy_scores]
                if candidates:
                    mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

            # Handle after the last common n-gram
            last_ngram_end = ngram_positions[-1][-1][1]
            candidates = [i for i in range(last_ngram_end + 1, len(non_stop_words)) if i in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

        # Mask the chosen indices
        original_masked_sentence = words[:]
        for idx in mask_indices:
            non_stop_words[idx] = self.tokenizer.mask_token
            original_masked_sentence[idx] = self.tokenizer.mask_token

        return " ".join(original_masked_sentence)

    def calculate_mask_logits(self, masked_sentence):
        """
        Calculate logits for masked tokens in the sentence using BERT.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens

        Returns:
            dict: Masked token indices and their logits
        """
        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits

        mask_logits = {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}
        return mask_logits

    def process_sentences(self, original_sentences, result_dict, method="random", remove_stopwords=False):
        """
        Process a list of sentences and calculate logits for masked tokens using the specified method.

        Args:
            original_sentences (list): List of original sentences
            result_dict (dict): Common n-grams and their indices for each sentence
            method (str): Masking method ("random" or "entropy")

        Returns:
            dict: Masked sentences and their logits for each sentence
        """
        results = {}

        for sentence, ngrams in result_dict.items():
            if method == "random":
                masked_sentence = self.mask_sentence_random(sentence, ngrams, remove_stopwords)
            elif method == "entropy":
                masked_sentence = self.mask_sentence_entropy(sentence, ngrams, remove_stopwords)
            else:
                raise ValueError("Invalid method. Choose 'random' or 'entropy'.")

            logits = self.calculate_mask_logits(masked_sentence)
            results[sentence] = {
                "masked_sentence": masked_sentence,
                "mask_logits": logits
            }

        return results

# Example usage
if __name__ == "__main__":
    # !!! Working both the cases regardless if the stopword is removed or not
    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A quick brown dog outpaces a lazy fox.",
        "Quick brown animals leap over lazy obstacles."
    ]

    result_dict = {
        "The quick brown fox jumps over the lazy dog.": {"quick brown": [(1, 2)], "lazy": [(7, 7)]},
        "A quick brown dog outpaces a lazy fox.": {"quick brown": [(1, 2)], "lazy": [(6, 6)]},
        "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(5, 5)]}
    }

    # result_dict = {
    #     "The quick brown fox jumps over the lazy dog.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
    #     "A quick brown dog outpaces a lazy fox.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
    #     "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]}
    # }

    processor = MaskingProcessor()
    results_random = processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)
    # results_entropy = processor.process_sentences(sentences, result_dict, method="entropy", remove_stopwords=False)

    for sentence, output in results_random.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {output['masked_sentence']}")
        # print(f"Mask Logits (Random): {output['mask_logits']}")
        print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
        print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
        print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
        print('--------------------------------')
        for mask_idx, logits in output["mask_logits"].items():
            print(f"Logits for [MASK] at position {mask_idx}:")
            print(f' logits : {logits[:5]}')  # List of logits for all vocabulary tokens

    # print('--------------------------------')
    # for sentence, output in results_entropy.items():
    #     print(f"Original Sentence (Entropy): {sentence}")
    #     print(f"Masked Sentence (Entropy): {output['masked_sentence']}")
    #     # print(f"Mask Logits (Entropy): {output['mask_logits']}")
    #     print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
    #     print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
    #     print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
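Note (added for illustration; not part of the uploaded files): the adjust_ngram_indices step above remaps n-gram spans from original-sentence word positions to positions in the stopword-free word list. A small worked sketch of that remapping, assuming the NLTK stopwords corpus is already downloaded (the files above download it on first use); the sentence and span values are illustrative only:

    from nltk.corpus import stopwords

    stop_words = set(stopwords.words('english'))
    words = "The quick brown fox jumps over the lazy dog".split()

    # Positions of non-stopword tokens in the original sentence.
    non_stop_word_indices = [i for i, w in enumerate(words) if w.lower() not in stop_words]
    # -> [1, 2, 3, 4, 7, 8]  (quick, brown, fox, jumps, lazy, dog)

    # "brown fox" spans original indices (2, 3); in the stopword-free list that becomes (1, 2).
    new_start = non_stop_word_indices.index(2)
    new_end = non_stop_word_indices.index(3)
    print(new_start, new_end)  # 1 2

This is why the example result_dict entries shift, e.g. "quick brown" at (1, 2) in the full sentence versus (0, 1) once stopwords are removed.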
utils/old/masking/masking_methods_v1_working.py
ADDED
@@ -0,0 +1,233 @@
import random
import torch
from transformers import BertTokenizer, BertForMaskedLM
from nltk.corpus import stopwords
import nltk

# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
# THIS IS WORKING WHEN THE COORDINATES ARE WITHOUT REMOVING STOPWORDS
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!


# Ensure stopwords are downloaded
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')

class MaskingProcessor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
        self.stop_words = set(stopwords.words('english'))

    def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords=False):
        """
        Mask one word before the first common n-gram, one between two n-grams,
        and one after the last common n-gram (random selection).

        Args:
            original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their indices

        Returns:
            str: Masked sentence
        """
        if remove_stopwords:
            words = original_sentence.split()
            words = [word for word in words if word not in self.stop_words]
        else:
            words = original_sentence.split()

        mask_indices = []
        # Handle before the first common n-gram
        if common_ngrams:
            first_ngram_start = list(common_ngrams.values())[0][0][0]
            if first_ngram_start > 0:
                mask_indices.append(random.randint(0, first_ngram_start - 1))

            # Handle between common n-grams
            ngram_positions = list(common_ngrams.values())
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][-1][1]
                start_next = ngram_positions[i + 1][0][0]
                if start_next > end_prev + 1:
                    mask_indices.append(random.randint(end_prev + 1, start_next - 1))

            # Handle after the last common n-gram
            last_ngram_end = ngram_positions[-1][-1][1]
            if last_ngram_end < len(words) - 1:
                mask_indices.append(random.randint(last_ngram_end + 1, len(words) - 1))

        # Mask the chosen indices
        for idx in mask_indices:
            if idx not in [index for ngram_indices in common_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
                words[idx] = self.tokenizer.mask_token

        return " ".join(words)

    def mask_sentence_entropy(self, original_sentence, common_ngrams, remove_stopwords=False):
        """
        Mask one word before the first common n-gram, one between two n-grams,
        and one after the last common n-gram (highest entropy selection).

        Args:
            original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their indices

        Returns:
            str: Masked sentence
        """
        if remove_stopwords:
            words = original_sentence.split()
            words = [word for word in words if word not in self.stop_words]
        else:
            words = original_sentence.split()
        entropy_scores = {}

        for idx, word in enumerate(words):
            if idx in [index for ngram_indices in common_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
                continue  # Skip words in common n-grams

            masked_sentence = words[:idx] + [self.tokenizer.mask_token] + words[idx + 1:]
            masked_sentence = " ".join(masked_sentence)
            input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
            mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits

            filtered_logits = logits[0, mask_token_index, :]
            probs = torch.softmax(filtered_logits, dim=-1)
            entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item()  # Add epsilon to prevent log(0)
            entropy_scores[idx] = entropy

        mask_indices = []

        # Handle before the first common n-gram
        if common_ngrams:
            first_ngram_start = list(common_ngrams.values())[0][0][0]
            candidates = [i for i in range(first_ngram_start) if i in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

            # Handle between common n-grams
            ngram_positions = list(common_ngrams.values())
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][-1][1]
                start_next = ngram_positions[i + 1][0][0]
                candidates = [i for i in range(end_prev + 1, start_next) if i in entropy_scores]
                if candidates:
                    mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

            # Handle after the last common n-gram
            last_ngram_end = ngram_positions[-1][-1][1]
            candidates = [i for i in range(last_ngram_end + 1, len(words)) if i in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

        # Mask the chosen indices
        for idx in mask_indices:
            words[idx] = self.tokenizer.mask_token

        return " ".join(words)

    def calculate_mask_logits(self, masked_sentence):
        """
        Calculate logits for masked tokens in the sentence using BERT.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens

        Returns:
            dict: Masked token indices and their logits
        """
        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits

        mask_logits = {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}
        return mask_logits

    def process_sentences(self, original_sentences, result_dict, remove_stopwords=False, method="random"):
        """
        Process a list of sentences and calculate logits for masked tokens using the specified method.

        Args:
            original_sentences (list): List of original sentences
            result_dict (dict): Common n-grams and their indices for each sentence
            method (str): Masking method ("random" or "entropy")

        Returns:
            dict: Masked sentences and their logits for each sentence
        """
        results = {}

        for sentence, ngrams in result_dict.items():
            if method == "random":
                masked_sentence = self.mask_sentence_random(sentence, ngrams)
            elif method == "entropy":
                masked_sentence = self.mask_sentence_entropy(sentence, ngrams)
            else:
                raise ValueError("Invalid method. Choose 'random' or 'entropy'.")

            logits = self.calculate_mask_logits(masked_sentence)
            results[sentence] = {
                "masked_sentence": masked_sentence,
                "mask_logits": logits
            }

        return results

# Example usage
if __name__ == "__main__":
    # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    # THIS IS WORKING WHEN THE COORDINATES ARE WITHOUT REMOVING STOPWORDS

    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A quick brown dog outpaces a lazy fox.",
        "Quick brown animals leap over lazy obstacles."
    ]

    result_dict = {
        "The quick brown fox jumps over the lazy dog.": {"quick brown": [(1, 2)], "lazy": [(7, 7)]},
        "A quick brown dog outpaces a lazy fox.": {"quick brown": [(1, 2)], "lazy": [(6, 6)]},
        "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(5, 5)]}
    }

    # result_dict = {
    #     "The quick brown fox jumps over the lazy dog.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
    #     "A quick brown dog outpaces a lazy fox.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
    #     "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]}
    # }

    processor = MaskingProcessor()
    results_random = processor.process_sentences(sentences, result_dict, remove_stopwords=True, method="random")
    results_entropy = processor.process_sentences(sentences, result_dict, remove_stopwords=True, method="entropy")

    for sentence, output in results_random.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {output['masked_sentence']}")
        # print(f"Mask Logits (Random): {output['mask_logits']}")

    for sentence, output in results_entropy.items():
        print(f"Original Sentence (Entropy): {sentence}")
        print(f"Masked Sentence (Entropy): {output['masked_sentence']}")
        # print(f"Mask Logits (Entropy): {output['mask_logits']}")




'''
result_dict = {
    "The quick brown fox jumps over the lazy dog.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
    "A quick brown dog outpaces a lazy fox.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
    "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]}
}

'''
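Note (added for illustration; not part of the uploaded files): all of these variants share the same position-selection rule, namely at most one mask before the first common n-gram, one in each gap between consecutive n-grams, and one after the last n-gram. A compact sketch of that rule with hard-coded, made-up n-gram spans (original-word indices), independent of any model or tokenizer:

    import random

    random.seed(0)
    words = "A swift brown fox leaps over the lethargic dog".split()
    ngram_spans = [(2, 3), (8, 8)]  # e.g. "brown fox" and "dog" (illustrative values)

    mask_positions = []
    if ngram_spans[0][0] > 0:                                    # before the first n-gram
        mask_positions.append(random.randint(0, ngram_spans[0][0] - 1))
    for (s, e), (s2, _) in zip(ngram_spans, ngram_spans[1:]):    # between consecutive n-grams
        if s2 > e + 1:
            mask_positions.append(random.randint(e + 1, s2 - 1))
    if ngram_spans[-1][1] < len(words) - 1:                      # after the last n-gram
        mask_positions.append(random.randint(ngram_spans[-1][1] + 1, len(words) - 1))

    masked = [("[MASK]" if i in mask_positions else w) for i, w in enumerate(words)]
    print(" ".join(masked))

The files differ only in how the position within each region is chosen (uniformly at random, with a fixed seed, or by highest masked-token entropy) and in whether indices refer to the full or stopword-free sentence.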
utils/old/masking_methods_final_copy.py
ADDED
@@ -0,0 +1,619 @@
import random
import torch
from transformers import BertTokenizer, BertForMaskedLM
from nltk.corpus import stopwords
import nltk
from transformers import RobertaTokenizer, RobertaForMaskedLM


# Ensure stopwords are downloaded
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')

class MaskingProcessor:
    # def __init__(self, tokenizer, model):
    def __init__(self):
        # self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        # self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")

        # self.tokenizer = tokenizer
        # self.model = model

        self.tokenizer = BertTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
        self.model = BertForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")

        # self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
        # self.model = RobertaForMaskedLM.from_pretrained("roberta-base")

        self.stop_words = set(stopwords.words('english'))

    def remove_stopwords(self, words):
        """
        Remove stopwords from the given list of words.

        Args:
            words (list): List of words.

        Returns:
            list: List of non-stop words.
        """
        return [word for word in words if word.lower() not in self.stop_words]

    def adjust_ngram_indices(self, original_words, common_ngrams):
        """
        Adjust indices of common n-grams after removing stopwords.

        Args:
            original_words (list): Original list of words.
            common_ngrams (dict): Common n-grams and their indices.

        Returns:
            dict: Adjusted common n-grams with updated indices.
        """
        non_stop_words = self.remove_stopwords(original_words)
        original_to_non_stop = []
        non_stop_idx = 0

        for original_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                original_to_non_stop.append((original_idx, non_stop_idx))
                non_stop_idx += 1

        adjusted_ngrams = {}
        for ngram, positions in common_ngrams.items():
            adjusted_positions = []
            for start, end in positions:
                try:
                    new_start = next(non_stop for orig, non_stop in original_to_non_stop if orig == start)
                    new_end = next(non_stop for orig, non_stop in original_to_non_stop if orig == end)
                    adjusted_positions.append((new_start, new_end))
                except StopIteration:
                    continue  # Skip if indices cannot be mapped
            adjusted_ngrams[ngram] = adjusted_positions

        return adjusted_ngrams

    def mask_sentence_random(self, sentence, common_ngrams):
        """
        Mask words in the sentence based on the specified rules after removing stopwords.
        """
        # Split sentence into words
        original_words = sentence.split()

        # Handle punctuation at the end
        has_punctuation = False
        punctuation = None
        if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
            has_punctuation = True
            punctuation = original_words[-1][-1]
            original_words = original_words[:-1]

        print(f' ---- original_words : {original_words} ----- ')

        # Process words without punctuation
        non_stop_words = self.remove_stopwords(original_words)
        adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)

        # Rest of the existing function code...
        mask_indices = []
        ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]

        if ngram_positions:
            first_ngram_start = ngram_positions[0][0]
            if first_ngram_start > 0:
                mask_index_before_ngram = random.randint(0, first_ngram_start-1)
                mask_indices.append(mask_index_before_ngram)

            # Mask words between common n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][1]
                start_next = ngram_positions[i + 1][0]
                if start_next > end_prev + 1:
                    mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
                    mask_indices.append(mask_index_between_ngrams)

            # Mask a word after the last common n-gram
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
                mask_indices.append(mask_index_after_ngram)

        # Create mapping from non-stop words to original indices
        non_stop_to_original = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                non_stop_idx += 1

        # Map mask indices and apply masks
        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token
            # masked_words[idx] = '<mask>'  # for roberta

        # Add back punctuation if it existed
        if has_punctuation:
            masked_words.append(punctuation)

        print(f' ***** masked_words at end : {masked_words} ***** ')
        print(f' ***** original_mask_indices : {original_mask_indices} ***** ')
        print(f' ***** TESTING : {" ".join(masked_words)} ***** ')

        return " ".join(masked_words), original_mask_indices

    def mask_sentence_pseudorandom(self, sentence, common_ngrams):
        """
        Mask words in the sentence based on the specified rules after removing stopwords.
        """
        # Split sentence into words
        random.seed(3)
        original_words = sentence.split()

        # Handle punctuation at the end
        has_punctuation = False
        punctuation = None
        if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
            has_punctuation = True
            punctuation = original_words[-1][-1]
            original_words = original_words[:-1]

        print(f' ---- original_words : {original_words} ----- ')

        # Process words without punctuation
        non_stop_words = self.remove_stopwords(original_words)
        adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)

        # Rest of the existing function code...
        mask_indices = []
        ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]

        if ngram_positions:
            first_ngram_start = ngram_positions[0][0]
            if first_ngram_start > 0:
                mask_index_before_ngram = random.randint(0, first_ngram_start-1)
                mask_indices.append(mask_index_before_ngram)

            # Mask words between common n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][1]
                start_next = ngram_positions[i + 1][0]
                if start_next > end_prev + 1:
                    mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
                    mask_indices.append(mask_index_between_ngrams)

            # Mask a word after the last common n-gram
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
                mask_indices.append(mask_index_after_ngram)

        # Create mapping from non-stop words to original indices
        non_stop_to_original = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                non_stop_idx += 1

        # Map mask indices and apply masks
        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token
            # masked_words[idx] = '<mask>'  # for roberta

        # Add back punctuation if it existed
        if has_punctuation:
            masked_words.append(punctuation)

        print(f' ***** masked_words at end : {masked_words} ***** ')
        print(f' ***** original_mask_indices : {original_mask_indices} ***** ')
        print(f' ***** TESTING : {" ".join(masked_words)} ***** ')

        return " ".join(masked_words), original_mask_indices


    def calculate_word_entropy(self, sentence, word_position):
        """
        Calculate entropy for a specific word position in the sentence.

        Args:
            sentence (str): The input sentence
            word_position (int): Position of the word to calculate entropy for

        Returns:
            float: Entropy value for the word
        """
        words = sentence.split()
        masked_words = words.copy()
        masked_words[word_position] = self.tokenizer.mask_token
        masked_sentence = " ".join(masked_words)

        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits

        # Get probabilities for the masked position
        probs = torch.nn.functional.softmax(logits[0, mask_token_index], dim=-1)
        # Calculate entropy: -sum(p * log(p))
        entropy = -torch.sum(probs * torch.log(probs + 1e-9))

        return entropy.item()

    def mask_sentence_entropy(self, sentence, common_ngrams):
        """
        Mask words in the sentence based on entropy, following n-gram positioning rules.

        Args:
            sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their indices

        Returns:
            str: Masked sentence
        """
        # Split sentence into words
        original_words = sentence.split()

        # Handle punctuation at the end
        has_punctuation = False
        punctuation = None
        if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
            has_punctuation = True
            punctuation = original_words[-1][-1]
            original_words = original_words[:-1]

        # Process words without punctuation
        non_stop_words = self.remove_stopwords(original_words)
        adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)

        # Create mapping from non-stop words to original indices
        non_stop_to_original = {}
        original_to_non_stop = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                original_to_non_stop[orig_idx] = non_stop_idx
                non_stop_idx += 1

        ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
        mask_indices = []

        if ngram_positions:
            # Handle words before first n-gram
            first_ngram_start = ngram_positions[0][0]
            if first_ngram_start > 0:
                candidate_positions = range(0, first_ngram_start)
                entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                             for pos in candidate_positions]
                mask_indices.append(max(entropies, key=lambda x: x[1])[0])

            # Handle words between n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][1]
                start_next = ngram_positions[i + 1][0]
                if start_next > end_prev + 1:
                    candidate_positions = range(end_prev + 1, start_next)
                    entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                                 for pos in candidate_positions]
                    mask_indices.append(max(entropies, key=lambda x: x[1])[0])

            # Handle words after last n-gram
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                candidate_positions = range(last_ngram_end + 1, len(non_stop_words))
                entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                             for pos in candidate_positions]
                mask_indices.append(max(entropies, key=lambda x: x[1])[0])

        # Map mask indices to original sentence positions and apply masks
        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token

        # Add back punctuation if it existed
        if has_punctuation:
            masked_words.append(punctuation)

        return " ".join(masked_words), original_mask_indices

    def calculate_mask_logits(self, original_sentence, original_mask_indices):
        """
        Calculate logits for masked tokens in the sentence using BERT.

        Args:
            original_sentence (str): Original sentence without masks
            original_mask_indices (list): List of indices to mask

        Returns:
            dict: Masked token indices and their logits
        """
        print('==========================================================================================================')
        words = original_sentence.split()
        print(f' ##### calculate_mask_logits >> words : {words} ##### ')
        mask_logits = {}

        for idx in original_mask_indices:
            # Create a copy of words and mask the current position
            print(f' ---- idx : {idx} ----- ')
            masked_words = words.copy()
            masked_words[idx] = '[MASK]'
            # masked_words[idx] = '<mask>'  # for roberta
            masked_sentence = " ".join(masked_words)
            print(f' ---- masked_sentence : {masked_sentence} ----- ')

            # Calculate logits for the current mask
            input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
            mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits

            # Extract logits for the masked position
            mask_logits_tensor = logits[0, mask_token_index, :]

            # Get top logits and corresponding tokens
            top_mask_logits, top_mask_indices = torch.topk(mask_logits_tensor, 100, dim=-1)  # Get more candidates

            # Convert token IDs to words and filter out subword tokens
# Convert token IDs to words and filter out subword tokens
|
| 368 |
+
top_tokens = []
|
| 369 |
+
top_logits = []
|
| 370 |
+
seen_words = set() # To keep track of unique words
|
| 371 |
+
|
| 372 |
+
for token_id, logit in zip(top_mask_indices[0], top_mask_logits[0]):
|
| 373 |
+
token = self.tokenizer.convert_ids_to_tokens(token_id.item())
|
| 374 |
+
|
| 375 |
+
# Skip if it's a subword token (starts with ##)
|
| 376 |
+
if token.startswith('##'):
|
| 377 |
+
continue
|
| 378 |
+
|
| 379 |
+
# Convert token to proper word
|
| 380 |
+
word = self.tokenizer.convert_tokens_to_string([token]).strip()
|
| 381 |
+
|
| 382 |
+
# Only add if it's a new word and not empty
|
| 383 |
+
if word and word not in seen_words:
|
| 384 |
+
seen_words.add(word)
|
| 385 |
+
top_tokens.append(word)
|
| 386 |
+
top_logits.append(logit.item())
|
| 387 |
+
|
| 388 |
+
# Break if we have 50 unique complete words
|
| 389 |
+
if len(top_tokens) == 50:
|
| 390 |
+
break
|
| 391 |
+
|
| 392 |
+
# print(f' ---- top_tokens : {top_tokens} ----- ')
|
| 393 |
+
|
| 394 |
+
# Store results
|
| 395 |
+
mask_logits[idx] = {
|
| 396 |
+
"tokens": top_tokens,
|
| 397 |
+
"logits": top_logits
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
return mask_logits
|
| 401 |
+
|
| 402 |
+
# def calculate_mask_logits(self, original_sentence, original_mask_indices):
|
| 403 |
+
# """
|
| 404 |
+
# Calculate logits for masked tokens in the sentence using BERT.
|
| 405 |
+
|
| 406 |
+
# Args:
|
| 407 |
+
# original_sentence (str): Original sentence without masks
|
| 408 |
+
# original_mask_indices (list): List of indices to mask
|
| 409 |
+
|
| 410 |
+
# Returns:
|
| 411 |
+
# dict: Masked token indices and their logits
|
| 412 |
+
# """
|
| 413 |
+
# words = original_sentence.split()
|
| 414 |
+
# print(f' ##### calculate_mask_logits >> words : {words} ##### ')
|
| 415 |
+
# mask_logits = {}
|
| 416 |
+
|
| 417 |
+
# for idx in original_mask_indices:
|
| 418 |
+
# # Create a copy of words and mask the current position
|
| 419 |
+
# print(f' ---- idx : {idx} ----- ')
|
| 420 |
+
# masked_words = words.copy()
|
| 421 |
+
# print(f' ---- words : {masked_words} ----- ')
|
| 422 |
+
# # masked_words[idx] = self.tokenizer.mask_token
|
| 423 |
+
# masked_words[idx] = '[MASK]'
|
| 424 |
+
# print(f' ---- masked_words : {masked_words} ----- ')
|
| 425 |
+
# masked_sentence = " ".join(masked_words)
|
| 426 |
+
# print(f' ---- masked_sentence : {masked_sentence} ----- ')
|
| 427 |
+
|
| 428 |
+
# # Calculate logits for the current mask
|
| 429 |
+
# input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
|
| 430 |
+
# mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
|
| 431 |
+
|
| 432 |
+
# with torch.no_grad():
|
| 433 |
+
# outputs = self.model(input_ids)
|
| 434 |
+
# logits = outputs.logits
|
| 435 |
+
|
| 436 |
+
# # Extract logits for the masked position
|
| 437 |
+
# mask_logits_tensor = logits[0, mask_token_index, :]
|
| 438 |
+
|
| 439 |
+
# # Get top 50 logits and corresponding tokens
|
| 440 |
+
# top_mask_logits, top_mask_indices = torch.topk(mask_logits_tensor, 50, dim=-1)
|
| 441 |
+
|
| 442 |
+
# # Convert token IDs to words
|
| 443 |
+
# top_tokens = [self.tokenizer.convert_ids_to_tokens(token_id.item()) for token_id in top_mask_indices[0]]
|
| 444 |
+
# print(f' ---- top_tokens : {top_tokens} ----- ')
|
| 445 |
+
|
| 446 |
+
# # Store results
|
| 447 |
+
# mask_logits[idx] = {
|
| 448 |
+
# "tokens": top_tokens,
|
| 449 |
+
# "logits": top_mask_logits.tolist()
|
| 450 |
+
# }
|
| 451 |
+
|
| 452 |
+
# return mask_logits
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
def process_sentences(self, sentences, result_dict, method="random"):
|
| 456 |
+
"""
|
| 457 |
+
Process sentences and calculate logits for masked tokens.
|
| 458 |
+
"""
|
| 459 |
+
results = {}
|
| 460 |
+
|
| 461 |
+
for sentence, ngrams in result_dict.items():
|
| 462 |
+
# Split punctuation from the last word before processing
|
| 463 |
+
words = sentence.split()
|
| 464 |
+
last_word = words[-1]
|
| 465 |
+
if any(last_word.endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
|
| 466 |
+
# Split the last word and punctuation
|
| 467 |
+
words[-1] = last_word[:-1]
|
| 468 |
+
punctuation = last_word[-1]
|
| 469 |
+
# Rejoin with space before punctuation to treat it as separate token
|
| 470 |
+
processed_sentence = " ".join(words) + " " + punctuation
|
| 471 |
+
else:
|
| 472 |
+
processed_sentence = sentence
|
| 473 |
+
|
| 474 |
+
if method == "random":
|
| 475 |
+
masked_sentence, original_mask_indices = self.mask_sentence_random(processed_sentence, ngrams)
|
| 476 |
+
elif method == "pseudorandom":
|
| 477 |
+
masked_sentence, original_mask_indices = self.mask_sentence_pseudorandom(processed_sentence, ngrams)
|
| 478 |
+
else: # entropy
|
| 479 |
+
masked_sentence, original_mask_indices = self.mask_sentence_entropy(processed_sentence, ngrams)
|
| 480 |
+
|
| 481 |
+
logits = self.calculate_mask_logits(processed_sentence, original_mask_indices)
|
| 482 |
+
results[sentence] = {
|
| 483 |
+
"masked_sentence": masked_sentence,
|
| 484 |
+
"mask_logits": logits
|
| 485 |
+
}
|
| 486 |
+
|
| 487 |
+
return results
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
if __name__ == "__main__":
|
| 492 |
+
# !!! Working both the cases regardless if the stopword is removed or not
|
| 493 |
+
sentences = [
|
| 494 |
+
"The quick brown fox jumps over small cat the lazy dog everyday again and again .",
|
| 495 |
+
# "A speedy brown fox jumps over a lazy dog.",
|
| 496 |
+
# "A swift brown fox leaps over the lethargic dog."
|
| 497 |
+
|
| 498 |
+
]
|
| 499 |
+
result_dict ={
|
| 500 |
+
'The quick brown fox jumps over small cat the lazy dog everyday again and again .': {'brown fox': [(2, 3)],'cat': [(7, 7)], 'dog': [(10, 10)]},
|
| 501 |
+
# 'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
|
| 502 |
+
# 'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}
|
| 503 |
+
}
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
processor = MaskingProcessor()
|
| 507 |
+
# results_random = processor.process_sentences(sentences, result_dict)
|
| 508 |
+
results_entropy = processor.process_sentences(sentences, result_dict, method="random")
|
| 509 |
+
|
| 510 |
+
'''
|
| 511 |
+
results structure :
|
| 512 |
+
results = {
|
| 513 |
+
"The quick brown fox jumps over the lazy dog everyday.":
|
| 514 |
+
{ # Original sentence as key
|
| 515 |
+
"masked_sentence": str, # The sentence with [MASK] tokens
|
| 516 |
+
"mask_logits":
|
| 517 |
+
{ # Dictionary of mask positions and their predictions
|
| 518 |
+
1:
|
| 519 |
+
{ # Position of mask in sentence
|
| 520 |
+
"tokens" (words) : list, # List of top 50 predicted tokens
|
| 521 |
+
"logits" (probabilities) : list # Corresponding logits for those tokens
|
| 522 |
+
},
|
| 523 |
+
7:
|
| 524 |
+
{
|
| 525 |
+
"tokens" (words) : list,
|
| 526 |
+
"logits" (probabilities) : list
|
| 527 |
+
},
|
| 528 |
+
10:
|
| 529 |
+
{
|
| 530 |
+
"tokens (words)": list,
|
| 531 |
+
"logits (probabilities)": list
|
| 532 |
+
}
|
| 533 |
+
}
|
| 534 |
+
}
|
| 535 |
+
}
|
| 536 |
+
|
| 537 |
+
'''
|
| 538 |
+
# results_entropy = processor.process_sentences(sentences, result_dict, method="entropy", remove_stopwords=False)
|
| 539 |
+
|
| 540 |
+
for sentence, output in results_entropy.items():
|
| 541 |
+
print(f"Original Sentence (Random): {sentence}")
|
| 542 |
+
print(f"Masked Sentence (Random): {output['masked_sentence']}")
|
| 543 |
+
# print(f"Mask Logits (Random): {output['mask_logits']}")
|
| 544 |
+
# print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
|
| 545 |
+
# print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
|
| 546 |
+
# print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
|
| 547 |
+
# print('--------------------------------')
|
| 548 |
+
# for mask_idx, logits in output["mask_logits"].items():
|
| 549 |
+
# print(f"Logits for [MASK] at position {mask_idx}:")
|
| 550 |
+
# print(f' logits : {logits[:5]}') # List of logits for all vocabulary tokens
|
| 551 |
+
# print(f' len(logits) : {len(logits)}')
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
# ------------------------------------------------------------------------------------------------
|
| 555 |
+
# def mask_sentence_random(self, sentence, common_ngrams):
|
| 556 |
+
# """
|
| 557 |
+
# Mask words in the sentence based on the specified rules after removing stopwords.
|
| 558 |
+
# """
|
| 559 |
+
# original_words = sentence.split()
|
| 560 |
+
# # print(f' ---- original_words : {original_words} ----- ')
|
| 561 |
+
# non_stop_words = self.remove_stopwords(original_words)
|
| 562 |
+
# # print(f' ---- non_stop_words : {non_stop_words} ----- ')
|
| 563 |
+
# adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
|
| 564 |
+
# # print(f' ---- common_ngrams : {common_ngrams} ----- ')
|
| 565 |
+
# # print(f' ---- adjusted_ngrams : {adjusted_ngrams} ----- ')
|
| 566 |
+
|
| 567 |
+
# mask_indices = []
|
| 568 |
+
|
| 569 |
+
# # Extract n-gram positions in non-stop words
|
| 570 |
+
# ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
|
| 571 |
+
|
| 572 |
+
# # Mask a word before the first common n-gram
|
| 573 |
+
# if ngram_positions:
|
| 574 |
+
# # print(f' ---- ngram_positions : {ngram_positions} ----- ')
|
| 575 |
+
# first_ngram_start = ngram_positions[0][0]
|
| 576 |
+
# # print(f' ---- first_ngram_start : {first_ngram_start} ----- ')
|
| 577 |
+
# if first_ngram_start > 0:
|
| 578 |
+
# mask_index_before_ngram = random.randint(0, first_ngram_start-1)
|
| 579 |
+
# # print(f' ---- mask_index_before_ngram : {mask_index_before_ngram} ----- ')
|
| 580 |
+
# mask_indices.append(mask_index_before_ngram)
|
| 581 |
+
|
| 582 |
+
# # Mask words between common n-grams
|
| 583 |
+
# for i in range(len(ngram_positions) - 1):
|
| 584 |
+
# end_prev = ngram_positions[i][1]
|
| 585 |
+
# # print(f' ---- end_prev : {end_prev} ----- ')
|
| 586 |
+
# start_next = ngram_positions[i + 1][0]
|
| 587 |
+
# # print(f' ---- start_next : {start_next} ----- ')
|
| 588 |
+
# if start_next > end_prev + 1:
|
| 589 |
+
# mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
|
| 590 |
+
# # print(f' ---- mask_index_between_ngrams : {mask_index_between_ngrams} ----- ')
|
| 591 |
+
# mask_indices.append(mask_index_between_ngrams)
|
| 592 |
+
|
| 593 |
+
# # Mask a word after the last common n-gram
|
| 594 |
+
# last_ngram_end = ngram_positions[-1][1]
|
| 595 |
+
# if last_ngram_end < len(non_stop_words) - 1:
|
| 596 |
+
# # print(f' ---- last_ngram_end : {last_ngram_end} ----- ')
|
| 597 |
+
# mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
|
| 598 |
+
# # print(f' ---- mask_index_after_ngram : {mask_index_after_ngram} ----- ')
|
| 599 |
+
# mask_indices.append(mask_index_after_ngram)
|
| 600 |
+
|
| 601 |
+
# # Create mapping from non-stop words to original indices
|
| 602 |
+
# non_stop_to_original = {}
|
| 603 |
+
# non_stop_idx = 0
|
| 604 |
+
# for orig_idx, word in enumerate(original_words):
|
| 605 |
+
# if word.lower() not in self.stop_words:
|
| 606 |
+
# non_stop_to_original[non_stop_idx] = orig_idx
|
| 607 |
+
# non_stop_idx += 1
|
| 608 |
+
|
| 609 |
+
# # Map mask indices from non-stop word positions to original positions
|
| 610 |
+
# # print(f' ---- non_stop_to_original : {non_stop_to_original} ----- ')
|
| 611 |
+
# original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
|
| 612 |
+
# # print(f' ---- original_mask_indices : {original_mask_indices} ----- ')
|
| 613 |
+
|
| 614 |
+
# # Apply masks to the original sentence
|
| 615 |
+
# masked_words = original_words.copy()
|
| 616 |
+
# for idx in original_mask_indices:
|
| 617 |
+
# masked_words[idx] = self.tokenizer.mask_token
|
| 618 |
+
|
| 619 |
+
# return " ".join(masked_words), original_mask_indices
|
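Note on the entropy criterion used above (a minimal illustrative sketch, not part of the uploaded file): the score is the Shannon entropy of the model's softmax distribution at the masked position, entropy = -sum(p * log(p)), so flatter (more uncertain) positions are preferred for masking.

import torch

def entropy_from_logits(logits: torch.Tensor) -> float:
    # Softmax over the candidate distribution, then Shannon entropy with a small
    # epsilon for numerical stability, mirroring the formula used above.
    probs = torch.softmax(logits, dim=-1)
    return float(-(probs * torch.log(probs + 1e-9)).sum())

# A flat distribution scores higher than a peaked one:
print(entropy_from_logits(torch.tensor([0.0, 0.0, 0.0])))   # ~1.10 (uniform over 3 options)
print(entropy_from_logits(torch.tensor([10.0, 0.0, 0.0])))  # ~0.00 (near-certain prediction)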
utils/old/non_melting_points_v1.py
ADDED
|
@@ -0,0 +1,244 @@
import nltk
from nltk.corpus import stopwords
from nltk.util import ngrams
from collections import Counter
import re

class NgramProcessor:
    def __init__(self):
        try:
            nltk.data.find('corpora/stopwords')
        except LookupError:
            nltk.download('stopwords')

        self.stop_words = set(stopwords.words('english'))

    def remove_stopwords(self, text):
        """
        Remove stopwords using NLTK's stopword list

        Args:
            text (str): Input text

        Returns:
            str: Cleaned text with stopwords removed
        """
        words = re.findall(r'\w+', text.lower())
        filtered_words = [word for word in words if word not in self.stop_words]
        return ' '.join(filtered_words)

    def is_exact_match(self, ngram, sentences):
        """
        Check if the given n-gram has an exact match in all sentences

        Args:
            ngram (str): The n-gram to search for
            sentences (list): List of sentences to search in

        Returns:
            bool: True if n-gram has exact match in all sentences, False otherwise
        """
        return all(ngram in sentence for sentence in sentences)

    def is_substring_of_any(self, ngram, common_ngrams):
        """
        Check if the given n-gram is an exact substring of any previously found common n-grams

        Args:
            ngram (str): The n-gram to check
            common_ngrams (list): List of previously found common n-grams

        Returns:
            bool: True if ngram is a substring of any common_ngrams, False otherwise
        """
        return any(ngram in other_ngram for other_ngram in common_ngrams if ngram != other_ngram)

    def find_filtered_ngrams(self, sentences):
        """
        Find all n-grams that have exact matches across all sentences,
        excluding those that are part of larger common n-grams

        Args:
            sentences (list): List of sentences to analyze

        Returns:
            list: List of tuples where each tuple contains the n-gram and its indices in each sentence
        """
        original_sentences = sentences[:]
        sentences = [self.remove_stopwords(sentence) for sentence in sentences]
        ngram_lengths = [4, 3, 2, 1]  # Quadgram, trigram, bigram, unigram
        common_ngrams = []

        for n in ngram_lengths:
            ngrams_list = [list(ngrams(sentence.split(), n)) for sentence in sentences]
            ngrams_counter = Counter(ngrams_list[0])

            for ngram in ngrams_counter:
                ngram_str = ' '.join(ngram)
                if self.is_exact_match(ngram_str, sentences) and not self.is_substring_of_any(ngram_str, [ng[0] for ng in common_ngrams]):
                    indices = []
                    for original_sentence in original_sentences:
                        words = original_sentence.split()
                        ngram_indices = [
                            (i, i + n - 1) for i in range(len(words) - n + 1)
                            if ' '.join(words[i:i + n]).lower() == ngram_str
                        ]
                        indices.append(ngram_indices)
                    common_ngrams.append((ngram_str, indices))

        return common_ngrams

    def find_relative_order(self, sentence, common_ngrams):
        """
        Find the relative order of the common n-grams in the sentence

        Args:
            sentence (str): Sentence in which to find the relative order
            common_ngrams (list): List of common n-grams

        Returns:
            list: List of tuples with the relative position and the n-gram
        """
        relative_order = []
        for ngram, _ in common_ngrams:
            index = sentence.find(ngram)
            if index != -1:
                relative_order.append((index, ngram))

        return sorted(relative_order)

# Example usage
if __name__ == "__main__":
    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A quick brown dog outpaces a lazy fox.",
        "Quick brown animals leap over lazy obstacles."
    ]

    processor = NgramProcessor()
    common_ngrams = processor.find_filtered_ngrams(sentences)
    print("Common n-grams and their indices:")
    for ngram, indices in common_ngrams:
        print(f"{ngram}: {indices}")

    for sentence in sentences:
        relative_order = processor.find_relative_order(sentence, common_ngrams)
        print(f"Relative order in sentence '{sentence}':", relative_order)

# import nltk
|
| 131 |
+
# from nltk.corpus import stopwords
|
| 132 |
+
# from nltk.util import ngrams
|
| 133 |
+
# from collections import Counter
|
| 134 |
+
# import re
|
| 135 |
+
|
| 136 |
+
# class NgramProcessor:
|
| 137 |
+
# def __init__(self):
|
| 138 |
+
# try:
|
| 139 |
+
# nltk.data.find('corpora/stopwords')
|
| 140 |
+
# except LookupError:
|
| 141 |
+
# nltk.download('stopwords')
|
| 142 |
+
|
| 143 |
+
# self.stop_words = set(stopwords.words('english'))
|
| 144 |
+
|
| 145 |
+
# def remove_stopwords(self, text):
|
| 146 |
+
# """
|
| 147 |
+
# Remove stopwords using NLTK's stopword list
|
| 148 |
+
|
| 149 |
+
# Args:
|
| 150 |
+
# text (str): Input text
|
| 151 |
+
|
| 152 |
+
# Returns:
|
| 153 |
+
# str: Cleaned text with stopwords removed
|
| 154 |
+
# """
|
| 155 |
+
# words = re.findall(r'\w+', text.lower())
|
| 156 |
+
# filtered_words = [word for word in words if word not in self.stop_words]
|
| 157 |
+
# return ' '.join(filtered_words)
|
| 158 |
+
|
| 159 |
+
# def is_exact_match(self, ngram, sentences):
|
| 160 |
+
# """
|
| 161 |
+
# Check if the given n-gram has an exact match in all sentences
|
| 162 |
+
|
| 163 |
+
# Args:
|
| 164 |
+
# ngram (str): The n-gram to search for
|
| 165 |
+
# sentences (list): List of sentences to search in
|
| 166 |
+
|
| 167 |
+
# Returns:
|
| 168 |
+
# bool: True if n-gram has exact match in all sentences, False otherwise
|
| 169 |
+
# """
|
| 170 |
+
# return all(ngram in sentence for sentence in sentences)
|
| 171 |
+
|
| 172 |
+
# def is_substring_of_any(self, ngram, common_ngrams):
|
| 173 |
+
# """
|
| 174 |
+
# Check if the given n-gram is an exact substring of any previously found common n-grams
|
| 175 |
+
|
| 176 |
+
# Args:
|
| 177 |
+
# ngram (str): The n-gram to check
|
| 178 |
+
# common_ngrams (list): List of previously found common n-grams
|
| 179 |
+
|
| 180 |
+
# Returns:
|
| 181 |
+
# bool: True if ngram is a substring of any common_ngrams, False otherwise
|
| 182 |
+
# """
|
| 183 |
+
# return any(ngram in other_ngram for other_ngram in common_ngrams if ngram != other_ngram)
|
| 184 |
+
|
| 185 |
+
# def find_filtered_ngrams(self, sentences):
|
| 186 |
+
# """
|
| 187 |
+
# Find all n-grams that have exact matches across all sentences,
|
| 188 |
+
# excluding those that are part of larger common n-grams
|
| 189 |
+
|
| 190 |
+
# Args:
|
| 191 |
+
# sentences (list): List of sentences to analyze
|
| 192 |
+
|
| 193 |
+
# Returns:
|
| 194 |
+
# list: List of all common n-grams in order of their appearance in the first sentence
|
| 195 |
+
# """
|
| 196 |
+
# sentences = [self.remove_stopwords(sentence) for sentence in sentences]
|
| 197 |
+
# ngram_lengths = [4, 3, 2, 1] # Quadgram, trigram, bigram, unigram
|
| 198 |
+
# common_ngrams = []
|
| 199 |
+
|
| 200 |
+
# for n in ngram_lengths:
|
| 201 |
+
# ngrams_list = [list(ngrams(sentence.split(), n)) for sentence in sentences]
|
| 202 |
+
# ngrams_counter = Counter(ngrams_list[0])
|
| 203 |
+
|
| 204 |
+
# for ngram in ngrams_counter:
|
| 205 |
+
# ngram_str = ' '.join(ngram)
|
| 206 |
+
# if self.is_exact_match(ngram_str, sentences) and not self.is_substring_of_any(ngram_str, common_ngrams):
|
| 207 |
+
# common_ngrams.append(ngram_str)
|
| 208 |
+
|
| 209 |
+
# return common_ngrams
|
| 210 |
+
|
| 211 |
+
# def find_relative_order(self, sentence, common_ngrams):
|
| 212 |
+
# """
|
| 213 |
+
# Find the relative order of the common n-grams in the sentence
|
| 214 |
+
|
| 215 |
+
# Args:
|
| 216 |
+
# sentence (str): Sentence in which to find the relative order
|
| 217 |
+
# common_ngrams (list): List of common n-grams
|
| 218 |
+
|
| 219 |
+
# Returns:
|
| 220 |
+
# list: List of tuples with the relative position and the n-gram
|
| 221 |
+
# """
|
| 222 |
+
# relative_order = []
|
| 223 |
+
# for ngram in common_ngrams:
|
| 224 |
+
# index = sentence.find(ngram)
|
| 225 |
+
# if index != -1:
|
| 226 |
+
# relative_order.append((index, ngram))
|
| 227 |
+
|
| 228 |
+
# return sorted(relative_order)
|
| 229 |
+
|
| 230 |
+
# # Example usage
|
| 231 |
+
# if __name__ == "__main__":
|
| 232 |
+
# sentences = [
|
| 233 |
+
# "The quick brown fox jumps over the lazy dog.",
|
| 234 |
+
# "A quick brown dog outpaces a lazy fox.",
|
| 235 |
+
# "Quick brown animals leap over lazy obstacles."
|
| 236 |
+
# ]
|
| 237 |
+
|
| 238 |
+
# processor = NgramProcessor()
|
| 239 |
+
# common_ngrams = processor.find_filtered_ngrams(sentences)
|
| 240 |
+
# print("Common n-grams:", common_ngrams)
|
| 241 |
+
|
| 242 |
+
# for sentence in sentences:
|
| 243 |
+
# relative_order = processor.find_relative_order(sentence, common_ngrams)
|
| 244 |
+
# print(f"Relative order in sentence '{sentence}':", relative_order)
|
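Note on the longest-match-first filter above (a minimal illustrative sketch, not part of the uploaded file): because n-gram lengths are scanned from 4 down to 1, a shorter n-gram is kept only if it is not contained in a longer common n-gram that was already accepted.

def keep_ngram(candidate: str, accepted: list) -> bool:
    # Same idea as is_substring_of_any, inverted: reject candidates that are
    # substrings of an already-accepted longer n-gram.
    return not any(candidate in longer for longer in accepted if candidate != longer)

accepted = ["quick brown"]
print(keep_ngram("brown", accepted))  # False: already covered by "quick brown"
print(keep_ngram("lazy", accepted))   # True: not part of any accepted n-gram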
utils/old/sampling/sampling.py
ADDED
|
@@ -0,0 +1,330 @@
import torch
import random
from masking_methods import MaskingProcessor
import nltk
from nltk.corpus import words
import torch.nn.functional as F


class SamplingProcessor:
    def __init__(self, tokenizer):
        """
        Initialize the SamplingProcessor.

        Args:
            tokenizer: BERT tokenizer instance
        """
        self.tokenizer = tokenizer
        self.subtoken_prefix = self._get_subtoken_prefix()
        self.subtoken_ids = self._get_subtoken_ids()
        try:
            nltk.data.find('corpora/words')
        except LookupError:
            nltk.download('words')
        self.english_words = set(words.words())

    # def _get_subtoken_prefix(self):
    #     """
    #     Identify the subtoken prefix based on the tokenizer.
    #
    #     Returns:
    #         str: The prefix used for subtokens (e.g., "##" for BERT).
    #     """
    #     # This method assumes that the tokenizer uses a consistent subtoken prefix.
    #     # Adjust accordingly if using different tokenizers.
    #     # For BERT's WordPiece tokenizer:
    #     if hasattr(self.tokenizer, "init_kwargs") and "wordpiece_prefix" in self.tokenizer.init_kwargs:
    #         return self.tokenizer.init_kwargs["wordpiece_prefix"]
    #     elif hasattr(self.tokenizer, "prefix_tokens"):
    #         return self.tokenizer.prefix_tokens
    #     else:
    #         # Default to BERT's subtoken prefix
    #         return "##"

    def _get_subtoken_prefix(self):
        """
        Identify the subtoken prefix based on the tokenizer.

        Returns:
            str: The prefix used for subtokens (e.g., "##" for BERT).
        """
        # This method assumes that the tokenizer uses a consistent subtoken prefix.
        # Adjust accordingly if using different tokenizers.
        # For BERT's WordPiece tokenizer:
        if hasattr(self.tokenizer, "init_kwargs") and "wordpiece_prefix" in self.tokenizer.init_kwargs:
            return self.tokenizer.init_kwargs["wordpiece_prefix"]
        elif hasattr(self.tokenizer, "prefix_tokens"):
            return self.tokenizer.prefix_tokens
        else:
            # Default to BERT's subtoken prefix
            return "##"

    # def _get_subtoken_ids(self):
    #     """
    #     Retrieve all token IDs that correspond to subtokens.
    #
    #     Returns:
    #         set: A set of subtoken IDs.
    #     """
    #     vocab = self.tokenizer.get_vocab()
    #     subtoken_ids = set()
    #     for token, idx in vocab.items():
    #         if token.startswith(self.subtoken_prefix):
    #             subtoken_ids.add(idx)
    #     return subtoken_ids

    def _get_subtoken_ids(self):
        """
        Retrieve all token IDs that correspond to subtokens.

        Returns:
            list: A list of subtoken IDs.
        """
        vocab = self.tokenizer.get_vocab()
        subtoken_ids = []
        for token, idx in vocab.items():
            if token.startswith(self.subtoken_prefix):
                subtoken_ids.append(idx)
        return subtoken_ids  # Changed from set to list


    def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0):
        tokens = self.tokenizer.tokenize(masked_sentence)

        for mask_pos in sorted(mask_logits_dict.keys()):
            try:
                # Get logits and squeeze extra dimension
                mask_logits = torch.tensor(mask_logits_dict[mask_pos]).squeeze(0)  # Remove the extra dimension

                # Create a mask for valid tokens (no special tokens, no subwords)
                valid_mask = torch.zeros_like(mask_logits, dtype=torch.bool)
                for idx in range(len(mask_logits)):
                    token = self.tokenizer.convert_ids_to_tokens([idx])[0]
                    # Only allow regular words (no special tokens, no subwords)
                    if token.isalpha() and not token.startswith('[') and not token.startswith('##'):
                        valid_mask[idx] = True

                # Get valid logits
                valid_logits = mask_logits[valid_mask]
                valid_indices = torch.where(valid_mask)[0]

                if len(valid_logits) == 0:
                    print(f"Warning: No valid tokens found for position {mask_pos}")
                    continue

                if sampling_technique == "inverse_transform":
                    probs = torch.softmax(valid_logits / temperature, dim=-1)
                    cumulative_probs = torch.cumsum(probs, dim=-1)
                    random_prob = random.random()
                    sampled_idx = torch.where(cumulative_probs >= random_prob)[0][0].item()
                    sampled_index = valid_indices[sampled_idx].item()

                elif sampling_technique == "exponential_minimum":
                    probs = torch.softmax(valid_logits / temperature, dim=-1)
                    exp_probs = torch.exp(-torch.log(probs))
                    random_probs = torch.rand_like(exp_probs)
                    sampled_idx = torch.argmax(random_probs * exp_probs).item()
                    sampled_index = valid_indices[sampled_idx].item()

                elif sampling_technique == "temperature":
                    valid_logits = torch.clamp(valid_logits, min=-1e8, max=1e8)
                    probs = torch.softmax(valid_logits / temperature, dim=-1)
                    if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
                        raise ValueError("The computed probabilities contain NaN or inf values.")
                    probs = torch.max(probs, torch.tensor(1e-8))
                    probs = probs / torch.sum(probs)
                    sampled_idx = torch.multinomial(probs, 1)[0].item()
                    sampled_index = valid_indices[sampled_idx].item()

                elif sampling_technique == 'greedy':
                    sampled_idx = torch.argmax(valid_logits).item()
                    sampled_index = valid_indices[sampled_idx].item()

                # Replace mask with sampled token
                sampled_token = self.tokenizer.convert_ids_to_tokens([sampled_index])[0]
                tokens[mask_pos] = sampled_token

            except Exception as e:
                print(f"Error sampling for position {mask_pos}: {str(e)}")
                continue

        return self.tokenizer.convert_tokens_to_string(tokens)


    def process_masked_sentences(self, results_dict, sampling_technique="temperature", temperature=1.0):
        """
        Process all masked sentences in the results dictionary.

        Args:
            results_dict (dict): Dictionary containing masked sentences and their logits
            sampling_technique (str): Sampling method to use
            temperature (float): Temperature parameter for sampling

        Returns:
            dict: Dictionary containing original, masked, and sampled sentences
        """
        processed_results = {}

        for original_sentence, data in results_dict.items():
            masked_sentence = data["masked_sentence"]
            mask_logits = data["mask_logits"]

            sampled_sentence = self.sample_tokens(
                mask_logits,
                masked_sentence,
                sampling_technique,
                temperature
            )

            processed_results[original_sentence] = {
                "masked_sentence": masked_sentence,
                "sampled_sentence": sampled_sentence
            }

        return processed_results

if __name__ == "__main__":
    sentences = [
        "The quick brown fox jumps over the lazy dog everyday.",
    ]
    result_dict = {
        'The quick brown fox jumps over the lazy dog everyday.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
    }

    # First, mask the sentences
    masking_processor = MaskingProcessor()
    masking_results = masking_processor.process_sentences(sentences, result_dict)

    # Then, sample replacements for the masks
    sampling_processor = SamplingProcessor(masking_processor.tokenizer)

    # Try different sampling techniques
    sampling_techniques = ["temperature", "greedy", "inverse_transform", "exponential_minimum"]

    for technique in sampling_techniques:
        print(f"\nSampling using {technique}:")
        sampled_results = sampling_processor.process_masked_sentences(
            masking_results,
            sampling_technique=technique,
            temperature=1.0
        )

        for original_sentence, result in sampled_results.items():
            print(f"Original: {original_sentence}")
            print(f"Masked: {result['masked_sentence']}")
            print(f"Sampled: {result['sampled_sentence']}")
            print("---")

# def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0, top_k=100):
|
| 222 |
+
# words = masked_sentence.split()
|
| 223 |
+
# mask_positions = sorted(mask_logits_dict.keys())
|
| 224 |
+
|
| 225 |
+
# for mask_pos in mask_positions:
|
| 226 |
+
# mask_logits = torch.tensor(mask_logits_dict[mask_pos])
|
| 227 |
+
|
| 228 |
+
# try:
|
| 229 |
+
# if sampling_technique == "inverse_transform":
|
| 230 |
+
# probs = torch.softmax(mask_logits / temperature, dim=-1)
|
| 231 |
+
# cumulative_probs = torch.cumsum(probs, dim=-1)
|
| 232 |
+
# random_prob = random.random()
|
| 233 |
+
# sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()
|
| 234 |
+
|
| 235 |
+
# elif sampling_technique == "exponential_minimum":
|
| 236 |
+
# probs = torch.softmax(mask_logits / temperature, dim=-1)
|
| 237 |
+
# exp_probs = torch.exp(-torch.log(probs))
|
| 238 |
+
# random_probs = torch.rand_like(exp_probs)
|
| 239 |
+
# sampled_index = torch.argmax(random_probs * exp_probs).item()
|
| 240 |
+
|
| 241 |
+
# elif sampling_technique == "temperature":
|
| 242 |
+
# mask_logits = torch.clamp(mask_logits, min=-1e8, max=1e8)
|
| 243 |
+
# probs = torch.softmax(mask_logits / temperature, dim=-1)
|
| 244 |
+
# if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
|
| 245 |
+
# raise ValueError("The computed probabilities contain NaN or inf values.")
|
| 246 |
+
# probs = torch.max(probs, torch.tensor(1e-8))
|
| 247 |
+
# probs = probs / torch.sum(probs)
|
| 248 |
+
# sampled_index = torch.multinomial(probs, 1)[0].item()
|
| 249 |
+
|
| 250 |
+
# elif sampling_technique == 'greedy':
|
| 251 |
+
# sampled_index = torch.argmax(mask_logits).item()
|
| 252 |
+
|
| 253 |
+
# else:
|
| 254 |
+
# raise ValueError(f"Unknown sampling technique: {sampling_technique}")
|
| 255 |
+
|
| 256 |
+
# # Replace mask with sampled token
|
| 257 |
+
# sampled_token = self.tokenizer.convert_ids_to_tokens([sampled_index])[0]
|
| 258 |
+
# words[mask_pos] = sampled_token
|
| 259 |
+
|
| 260 |
+
# except Exception as e:
|
| 261 |
+
# print(f"Error sampling for position {mask_pos}: {str(e)}")
|
| 262 |
+
# continue
|
| 263 |
+
|
| 264 |
+
# return " ".join(words)
|
| 265 |
+
|
| 266 |
+
## MORE WEIRD RESULTS
|
| 267 |
+
# def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0, top_k=100):
|
| 268 |
+
# words = masked_sentence.split()
|
| 269 |
+
# mask_positions = sorted(mask_logits_dict.keys())
|
| 270 |
+
|
| 271 |
+
# for mask_pos in mask_positions:
|
| 272 |
+
# mask_logits = torch.tensor(mask_logits_dict[mask_pos])
|
| 273 |
+
|
| 274 |
+
# try:
|
| 275 |
+
# # Create a mask for valid tokens (no special tokens, no subwords)
|
| 276 |
+
# valid_mask = torch.zeros_like(mask_logits, dtype=torch.bool)
|
| 277 |
+
# for idx in range(len(mask_logits)):
|
| 278 |
+
# token = self.tokenizer.convert_ids_to_tokens([idx])[0]
|
| 279 |
+
# # Only allow regular words (no special tokens, no subwords)
|
| 280 |
+
# if token.isalpha() and not token.startswith('[') and not token.startswith('##'):
|
| 281 |
+
# valid_mask[idx] = True
|
| 282 |
+
|
| 283 |
+
# # Get valid logits
|
| 284 |
+
# valid_logits = mask_logits[valid_mask]
|
| 285 |
+
# valid_indices = torch.where(valid_mask)[0]
|
| 286 |
+
|
| 287 |
+
# if len(valid_logits) == 0:
|
| 288 |
+
# print(f"Warning: No valid tokens found for position {mask_pos}")
|
| 289 |
+
# continue
|
| 290 |
+
|
| 291 |
+
# if sampling_technique == "inverse_transform":
|
| 292 |
+
# probs = torch.softmax(valid_logits / temperature, dim=-1)
|
| 293 |
+
# cumulative_probs = torch.cumsum(probs, dim=-1)
|
| 294 |
+
# random_prob = random.random()
|
| 295 |
+
# sampled_idx = torch.where(cumulative_probs >= random_prob)[0][0].item()
|
| 296 |
+
# sampled_index = valid_indices[sampled_idx].item()
|
| 297 |
+
|
| 298 |
+
# elif sampling_technique == "exponential_minimum":
|
| 299 |
+
# probs = torch.softmax(valid_logits / temperature, dim=-1)
|
| 300 |
+
# exp_probs = torch.exp(-torch.log(probs))
|
| 301 |
+
# random_probs = torch.rand_like(exp_probs)
|
| 302 |
+
# sampled_idx = torch.argmax(random_probs * exp_probs).item()
|
| 303 |
+
# sampled_index = valid_indices[sampled_idx].item()
|
| 304 |
+
|
| 305 |
+
# elif sampling_technique == "temperature":
|
| 306 |
+
# valid_logits = torch.clamp(valid_logits, min=-1e8, max=1e8)
|
| 307 |
+
# probs = torch.softmax(valid_logits / temperature, dim=-1)
|
| 308 |
+
# if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
|
| 309 |
+
# raise ValueError("The computed probabilities contain NaN or inf values.")
|
| 310 |
+
# probs = torch.max(probs, torch.tensor(1e-8))
|
| 311 |
+
# probs = probs / torch.sum(probs)
|
| 312 |
+
# sampled_idx = torch.multinomial(probs, 1)[0].item()
|
| 313 |
+
# sampled_index = valid_indices[sampled_idx].item()
|
| 314 |
+
|
| 315 |
+
# elif sampling_technique == 'greedy':
|
| 316 |
+
# sampled_idx = torch.argmax(valid_logits).item()
|
| 317 |
+
# sampled_index = valid_indices[sampled_idx].item()
|
| 318 |
+
|
| 319 |
+
# else:
|
| 320 |
+
# raise ValueError(f"Unknown sampling technique: {sampling_technique}")
|
| 321 |
+
|
| 322 |
+
# # Replace mask with sampled token
|
| 323 |
+
# sampled_token = self.tokenizer.convert_ids_to_tokens([sampled_index])[0]
|
| 324 |
+
# words[mask_pos] = sampled_token
|
| 325 |
+
|
| 326 |
+
# except Exception as e:
|
| 327 |
+
# print(f"Error sampling for position {mask_pos}: {str(e)}")
|
| 328 |
+
# continue
|
| 329 |
+
|
| 330 |
+
# return " ".join(words)
|
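Note on the sampling strategies above (a minimal illustrative sketch on a toy logits vector, not part of the uploaded file): greedy always takes the top candidate, temperature sampling draws from the softmax distribution, and inverse-transform sampling draws a uniform random number and walks the cumulative distribution; torch.searchsorted is used here as a compact equivalent of the torch.where(cdf >= r) lookup in the file.

import torch

logits = torch.tensor([2.0, 1.0, 0.5])           # scores for three candidate tokens
temperature = 1.0
probs = torch.softmax(logits / temperature, dim=-1)

greedy_idx = torch.argmax(logits).item()                     # deterministic: index 0
temperature_idx = torch.multinomial(probs, 1).item()         # random, weighted by probs
cdf = torch.cumsum(probs, dim=-1)
inverse_idx = torch.searchsorted(cdf, torch.rand(1)).item()  # inverse-transform sampling
print(greedy_idx, temperature_idx, inverse_idx)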
utils/old/sampling/sampling_methods.py
ADDED
|
@@ -0,0 +1,291 @@
from transformers import BertTokenizer, BertForMaskedLM
import torch
import random
from masking_methods import MaskingProcessor
from transformers import pipeline

class SamplingProcessorWithModel:
    def __init__(self, model_name='bert-base-uncased'):
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertForMaskedLM.from_pretrained(model_name)
        self.model.eval()  # Set the model to evaluation mode

    def fill_masked_sentence(self, masked_sentence, sampling_technique, temperature=1.0):
        """
        Fills each mask in the masked sentence using the specified sampling technique.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens.
            sampling_technique (str): Sampling technique to use (e.g., "inverse_transform", "exponential_minimum", "temperature", "greedy").
            temperature (float): Temperature parameter for sampling methods.

        Returns:
            str: Sentence with the masks filled.
        """
        input_ids = self.tokenizer.encode(masked_sentence, return_tensors="pt")

        while self.tokenizer.mask_token_id in input_ids[0]:
            # Find indices of all [MASK] tokens
            mask_indices = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

            # Process the first [MASK] token in the sequence
            mask_index = mask_indices[0].item()

            # Get logits from the model
            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits

            # Extract logits for the [MASK] token
            mask_logits = logits[0, mask_index]

            if sampling_technique == "inverse_transform":
                probs = torch.softmax(mask_logits / temperature, dim=-1)
                cumulative_probs = torch.cumsum(probs, dim=-1)
                random_prob = random.random()
                sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()

            elif sampling_technique == "exponential_minimum":
                probs = torch.softmax(mask_logits / temperature, dim=-1)
                exp_probs = torch.exp(-torch.log(probs))
                random_probs = torch.rand_like(exp_probs)
                sampled_index = torch.argmax(random_probs * exp_probs).item()

            elif sampling_technique == "temperature":
                mask_logits = torch.clamp(mask_logits, min=-1e8, max=1e8)
                probs = torch.softmax(mask_logits / temperature, dim=-1)
                if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
                    raise ValueError("The computed probabilities contain NaN or inf values.")
                probs = torch.max(probs, torch.tensor(1e-8, device=mask_logits.device))
                probs = probs / torch.sum(probs)
                probs = probs.flatten()
                if probs.size(0) > 1:
                    sampled_index = torch.multinomial(probs, 1).item()
                else:
                    sampled_index = torch.argmax(probs).item()

            elif sampling_technique == 'greedy':
                sampled_index = torch.argmax(mask_logits).item()

            else:
                raise ValueError(f"Unknown sampling technique: {sampling_technique}")

            # Replace the first [MASK] with the selected token
            input_ids[0, mask_index] = sampled_index

        return self.tokenizer.decode(input_ids[0], skip_special_tokens=True)

    def fill_masked_sentence(self, masked_sentence, sampling_technique, temperature=1.0):
        """
        Fills each mask in the masked sentence using the specified sampling technique.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens.
            sampling_technique (str): Sampling technique to use (e.g., "inverse_transform", "exponential_minimum", "temperature", "greedy").
            temperature (float): Temperature parameter for sampling methods.

        Returns:
            str: Sentence with the masks filled.
        """
        while '[MASK]' in masked_sentence:
            # Get predictions for the first [MASK]
            predictions = self.unmasker(masked_sentence)

            # Ensure predictions is a list of dictionaries
            if not isinstance(predictions, list) or not all(isinstance(pred, dict) for pred in predictions):
                raise ValueError("Unexpected structure in predictions from the pipeline.")

            # Extract logits (scores) from the predictions
            logits = torch.tensor([pred['score'] for pred in predictions], dtype=torch.float32)

            if sampling_technique == "inverse_transform":
                probs = torch.softmax(logits / temperature, dim=-1)
                cumulative_probs = torch.cumsum(probs, dim=-1)
                random_prob = random.random()
                sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()

            elif sampling_technique == "exponential_minimum":
                probs = torch.softmax(logits / temperature, dim=-1)
                exp_probs = torch.exp(-torch.log(probs))
                random_probs = torch.rand_like(exp_probs)
                sampled_index = torch.argmax(random_probs * exp_probs).item()

            elif sampling_technique == "temperature":
                logits = torch.clamp(logits, min=-1e8, max=1e8)
                probs = torch.softmax(logits / temperature, dim=-1)
                if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
                    raise ValueError("The computed probabilities contain NaN or inf values.")
                probs = torch.max(probs, torch.tensor(1e-8, device=logits.device))
                probs = probs / torch.sum(probs)
                probs = probs.flatten()
                if probs.size(0) > 1:
                    sampled_index = torch.multinomial(probs, 1).item()
                else:
                    sampled_index = torch.argmax(probs).item()

            elif sampling_technique == 'greedy':
                sampled_index = torch.argmax(logits).item()

            else:
                raise ValueError(f"Unknown sampling technique: {sampling_technique}")

            # Replace the first [MASK] with the selected word
            sampled_token = predictions[sampled_index]['token_str']
            masked_sentence = masked_sentence.replace('[MASK]', sampled_token, 1)

        return masked_sentence


# Example usage
if __name__ == "__main__":
    from transformers import BertTokenizer

    # Define sentences and result_dict
    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A quick brown dog outpaces a lazy fox.",
        "Quick brown dog leaps over lazy the fox."
    ]
    result_dict = {
        "The quick brown fox jumps over the lazy dog.": {'quick brown': [(0, 1)], 'fox': [(2, 2)], 'lazy': [(4, 4)], 'dog': [(5, 5)]},
        "A quick brown dog outpaces a lazy fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]},
        "Quick brown dog leaps over lazy the fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]}
    }

    masking_processor = MaskingProcessor()
    masking_results = masking_processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)

    # Use SamplingProcessor
    sampling_processor = SamplingProcessorWithModel()

    # Iterate through masking results to apply sampling
    for sentence, result in masking_results.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {result['masked_sentence']}")
        masked_sentence = result["masked_sentence"]

        # Apply different sampling techniques
        for technique in ["inverse_transform", "exponential_minimum", "temperature", "greedy"]:
            print(f"Sampling Technique: {technique}")
            filled_sentence = sampling_processor.fill_masked_sentence(
                masked_sentence=masked_sentence,
                sampling_technique=technique,
                temperature=1.0  # Adjust temperature as needed
            )
            print(f"Filled Sentence: {filled_sentence}\n")
        print('--------------------------------')

# from transformers import pipeline
|
| 181 |
+
# import torch
|
| 182 |
+
# import random
|
| 183 |
+
# from masking_methods import MaskingProcessor
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
# class SamplingProcessorWithPipeline:
|
| 187 |
+
# def __init__(self, model_name='bert-base-uncased'):
|
| 188 |
+
# self.unmasker = pipeline('fill-mask', model=model_name)
|
| 189 |
+
# self.tokenizer = self.unmasker.tokenizer
|
| 190 |
+
|
| 191 |
+
# def fill_masked_sentence(self, masked_sentence, sampling_technique, temperature=1.0):
|
| 192 |
+
# """
|
| 193 |
+
# Fills each mask in the masked sentence using the specified sampling technique.
|
| 194 |
+
|
| 195 |
+
# Args:
|
| 196 |
+
# masked_sentence (str): Sentence with [MASK] tokens.
|
| 197 |
+
# sampling_technique (str): Sampling technique to use (e.g., "inverse_transform", "exponential_minimum", "temperature", "greedy").
|
| 198 |
+
# temperature (float): Temperature parameter for sampling methods.
|
| 199 |
+
|
| 200 |
+
# Returns:
|
| 201 |
+
# str: Sentence with the masks filled.
|
| 202 |
+
# """
|
| 203 |
+
# while '[MASK]' in masked_sentence:
|
| 204 |
+
# # Get predictions for the first [MASK]
|
| 205 |
+
# predictions = self.unmasker(masked_sentence)
|
| 206 |
+
# print(f' predictions : {predictions}')
|
| 207 |
+
# print(f' type of predictions : {type(predictions)}')
|
| 208 |
+
|
| 209 |
+
# # Ensure predictions is a list of dictionaries for the first [MASK]
|
| 210 |
+
# if not isinstance(predictions, list) or not all(isinstance(pred, dict) for pred in predictions):
|
| 211 |
+
# raise ValueError("Unexpected structure in predictions from the pipeline.")
|
| 212 |
+
|
| 213 |
+
# # Extract logits (scores) from the predictions
|
| 214 |
+
# logits = torch.tensor([pred['score'] for pred in predictions], dtype=torch.float32)
|
| 215 |
+
|
| 216 |
+
# if sampling_technique == "inverse_transform":
|
| 217 |
+
# probs = torch.softmax(logits / temperature, dim=-1)
|
| 218 |
+
# cumulative_probs = torch.cumsum(probs, dim=-1)
|
| 219 |
+
# random_prob = random.random()
|
| 220 |
+
# sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()
|
| 221 |
+
|
| 222 |
+
# elif sampling_technique == "exponential_minimum":
|
| 223 |
+
# probs = torch.softmax(logits / temperature, dim=-1)
|
| 224 |
+
# exp_probs = torch.exp(-torch.log(probs))
|
| 225 |
+
# random_probs = torch.rand_like(exp_probs)
|
| 226 |
+
# sampled_index = torch.argmax(random_probs * exp_probs).item()
|
| 227 |
+
|
| 228 |
+
# elif sampling_technique == "temperature":
|
| 229 |
+
# logits = torch.clamp(logits, min=-1e8, max=1e8)
|
| 230 |
+
# probs = torch.softmax(logits / temperature, dim=-1)
|
| 231 |
+
# if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
|
| 232 |
+
# raise ValueError("The computed probabilities contain NaN or inf values.")
|
| 233 |
+
# probs = torch.max(probs, torch.tensor(1e-8, device=logits.device))
|
| 234 |
+
# probs = probs / torch.sum(probs)
|
| 235 |
+
# probs = probs.flatten()
|
| 236 |
+
# if probs.size(0) > 1:
|
| 237 |
+
# sampled_index = torch.multinomial(probs, 1).item()
|
| 238 |
+
# else:
|
| 239 |
+
# sampled_index = torch.argmax(probs).item()
|
| 240 |
+
|
| 241 |
+
# elif sampling_technique == 'greedy':
|
| 242 |
+
# sampled_index = torch.argmax(logits).item()
|
| 243 |
+
|
| 244 |
+
# else:
|
| 245 |
+
# raise ValueError(f"Unknown sampling technique: {sampling_technique}")
|
| 246 |
+
|
| 247 |
+
# # Replace the first [MASK] with the selected word
|
| 248 |
+
# sampled_token = predictions[sampled_index]['token_str']
|
| 249 |
+
# masked_sentence = masked_sentence.replace('[MASK]', sampled_token, 1)
|
| 250 |
+
|
| 251 |
+
# return masked_sentence
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# # Example usage
|
| 255 |
+
# if __name__ == "__main__":
|
| 256 |
+
# from transformers import BertTokenizer
|
| 257 |
+
|
| 258 |
+
# # Define sentences and result_dict
|
| 259 |
+
# sentences = [
|
| 260 |
+
# "The quick brown fox jumps over the lazy dog.",
|
| 261 |
+
# "A quick brown dog outpaces a lazy fox.",
|
| 262 |
+
# "Quick brown animals leap over lazy obstacles."
|
| 263 |
+
# ]
|
| 264 |
+
# result_dict = {
|
| 265 |
+
# "The quick brown fox jumps over the lazy dog.": {"quick brown": [(1, 2)], "lazy": [(7, 7)]},
|
| 266 |
+
# "A quick brown dog outpaces a lazy fox.": {"quick brown": [(1, 2)], "lazy": [(6, 6)]},
|
| 267 |
+
# "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(5, 5)]}
|
| 268 |
+
# }
|
| 269 |
+
|
| 270 |
+
# masking_processor = MaskingProcessor()
|
| 271 |
+
# masking_results = masking_processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)
|
| 272 |
+
|
| 273 |
+
# # Use SamplingProcessor
|
| 274 |
+
# sampling_processor = SamplingProcessorWithPipeline()
|
| 275 |
+
|
| 276 |
+
# # Iterate through masking results to apply sampling
|
| 277 |
+
# for sentence, result in masking_results.items():
|
| 278 |
+
# print(f"Original Sentence (Random): {sentence}")
|
| 279 |
+
# print(f"Masked Sentence (Random): {result['masked_sentence']}")
|
| 280 |
+
# masked_sentence = result["masked_sentence"]
|
| 281 |
+
|
| 282 |
+
# # Apply different sampling techniques
|
| 283 |
+
# for technique in ["inverse_transform", "exponential_minimum", "temperature", "greedy"]:
|
| 284 |
+
# print(f"Sampling Technique: {technique}")
|
| 285 |
+
# filled_sentence = sampling_processor.fill_masked_sentence(
|
| 286 |
+
# masked_sentence=masked_sentence,
|
| 287 |
+
# sampling_technique=technique,
|
| 288 |
+
# temperature=1.0 # Adjust temperature as needed
|
| 289 |
+
# )
|
| 290 |
+
# print(f"Filled Sentence: {filled_sentence}\n")
|
| 291 |
+
# print('--------------------------------')
|
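Note: the same four selection rules recur in every sampler version in this upload. The sketch below is illustrative only (the helper name `sample_index` and the toy logits are not part of the uploaded files); under that assumption it shows what each rule does to a single 1-D logits tensor.

import torch
import random

def sample_index(logits, technique, temperature=1.0):
    """Pick one index from a 1-D logits tensor using the named technique."""
    if technique == "greedy":
        return torch.argmax(logits).item()
    probs = torch.softmax(logits / temperature, dim=-1)
    if technique == "inverse_transform":
        # Invert the cumulative distribution at a uniform random point.
        cumulative = torch.cumsum(probs, dim=-1)
        return torch.where(cumulative >= random.random())[0][0].item()
    if technique == "exponential_minimum":
        # Mirrors the perturb-and-argmax rule used above: uniform noise scaled by 1/p.
        return torch.argmax(torch.rand_like(probs) / probs).item()
    if technique == "temperature":
        # Ordinary multinomial draw from the temperature-scaled softmax.
        return torch.multinomial(probs, 1).item()
    raise ValueError(f"Unknown sampling technique: {technique}")

toy_logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
for technique in ["inverse_transform", "exponential_minimum", "temperature", "greedy"]:
    print(technique, sample_index(toy_logits, technique))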
utils/old/sampling/sampling_methods_v1.py
ADDED
@@ -0,0 +1,146 @@
import torch
import random
from masking_methods import MaskingProcessor

class SamplingProcessor:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def fill_masked_sentence(self, original_sentence, mask_logits, sampling_technique, temperature=1.0):
        """
        Fills each mask in the masked sentence using the specified sampling technique.

        Args:
            original_sentence (str): The original masked sentence.
            mask_logits (dict): Logits for each [MASK] token.
            sampling_technique (str): Sampling technique to use (e.g., "inverse_transform", "exponential_minimum", "temperature", "greedy").
            temperature (float): Temperature parameter for sampling methods.

        Returns:
            str: Sentence with the masks filled.
        """
        sentence_tokens = self.tokenizer.tokenize(original_sentence)
        mask_token_indices = [i for i, token in enumerate(sentence_tokens) if token == self.tokenizer.mask_token]

        if len(mask_token_indices) != len(mask_logits):
            raise ValueError("Mismatch between number of [MASK] tokens and logits provided.")

        for mask_idx, filtered_logits in zip(mask_token_indices, mask_logits.values()):
            # Convert logits to a tensor
            filtered_logits = torch.tensor(filtered_logits)
            # filtered_logits, _ = torch.sort(filtered_logits, descending=True)
            # print(f' type of filtered_logits : {type(filtered_logits)}')
            # filtered_logits = filtered_logits[:5]

            if sampling_technique == "inverse_transform":
                probs = torch.softmax(filtered_logits / temperature, dim=-1)
                cumulative_probs = torch.cumsum(probs, dim=-1)
                random_prob = random.random()
                sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()

            elif sampling_technique == "exponential_minimum":
                probs = torch.softmax(filtered_logits / temperature, dim=-1)
                exp_probs = torch.exp(-torch.log(probs))
                random_probs = torch.rand_like(exp_probs)
                sampled_index = torch.argmax(random_probs * exp_probs).item()

            elif sampling_technique == "temperature":
                filtered_logits = torch.clamp(filtered_logits, min=-1e8, max=1e8)
                probs = torch.softmax(filtered_logits / temperature, dim=-1)
                if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
                    raise ValueError("The computed probabilities contain NaN or inf values.")
                probs = torch.max(probs, torch.tensor(1e-8, device=filtered_logits.device))
                probs = probs / torch.sum(probs)
                probs = probs.flatten()
                if probs.size(0) > 1:
                    sampled_index = torch.multinomial(probs, 1).item()
                else:
                    sampled_index = torch.argmax(probs).item()

            elif sampling_technique == 'greedy':
                sampled_index = torch.argmax(filtered_logits).item()

            else:
                raise ValueError(f"Unknown sampling technique: {sampling_technique}")

            sampled_token = self.tokenizer.convert_ids_to_tokens([sampled_index])[0]
            sentence_tokens[mask_idx] = sampled_token

        return self.tokenizer.convert_tokens_to_string(sentence_tokens)

    def process_samples(self, masked_sentences, mask_logits, sampling_technique, temperature=1.0):
        """
        Process multiple masked sentences and fill their masks using the specified sampling technique.

        Args:
            masked_sentences (list): List of masked sentences.
            mask_logits (dict): Logits for each [MASK] token in each sentence.
            sampling_technique (str): Sampling technique to use.
            temperature (float): Temperature parameter for sampling methods.

        Returns:
            list: List of sentences with masks filled.
        """
        filled_sentences = []
        for sentence, logits in zip(masked_sentences, mask_logits):
            filled_sentence = self.fill_masked_sentence(sentence, logits, sampling_technique, temperature)
            filled_sentences.append(filled_sentence)
        return filled_sentences

# Example usage
if __name__ == "__main__":
    from transformers import BertTokenizer

    # tokenizer = BertTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    processor = SamplingProcessor(tokenizer)

    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A quick brown dog outpaces a lazy fox.",
        "Quick brown dog leaps over lazy the fox."
    ]
    result_dict = {
        "The quick brown fox jumps over the lazy dog.": {'quick brown': [(0, 1)], 'fox': [(2, 2)], 'lazy': [(4, 4)], 'dog': [(5, 5)]},
        "A quick brown dog outpaces a lazy fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]},
        "Quick brown dog leaps over lazy the fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]}
    }

    masking_processor = MaskingProcessor()
    masking_results = masking_processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)
    # masked_sentence = "The [MASK] brown fox jumps [MASK] the lazy dog."
    # mask_logits = {
    #     1: torch.randn(len(tokenizer)),  # Example logits for first [MASK]
    #     5: torch.randn(len(tokenizer)),  # Example logits for second [MASK]
    # }

    # Iterate through masking results to apply sampling
    for sentence, result in masking_results.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {result['masked_sentence']}")
        # print(f"Mask Logits (Random): {output['mask_logits']}")
        print(f' type(result["mask_logits"]) : {type(result["mask_logits"])}')
        print(f' length of result["mask_logits"] : {len(result["mask_logits"])}')
        print(f' result["mask_logits"].keys() : {result["mask_logits"].keys()}')
        masked_sentence = result["masked_sentence"]
        mask_logits = result["mask_logits"]

        print(f"Original Masked Sentence: {masked_sentence}")

        # Apply different sampling techniques
        for technique in ["inverse_transform", "exponential_minimum", "temperature", "greedy"]:
            print(f"Sampling Technique: {technique}")

            # Fill the masks using the sampling processor
            filled_sentence = processor.fill_masked_sentence(
                original_sentence=masked_sentence,
                mask_logits=mask_logits,
                sampling_technique=technique,
                temperature=1.0  # Adjust temperature as needed
            )

            print(f"Filled Sentence: {filled_sentence}\n")
        print('--------------------------------')
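Note: fill_masked_sentence in this version expects one full-vocabulary logits entry per [MASK] token. A hypothetical smoke test in that shape (the sentence, the random logits, and the dictionary keys below are invented for illustration, and MaskingProcessor is not needed for it) might look like:

import torch
from transformers import BertTokenizer

# Assumed: SamplingProcessor from sampling_methods_v1.py is importable here.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
processor = SamplingProcessor(tokenizer)

masked = "the [MASK] brown fox jumps over the [MASK] dog ."
# One random full-vocabulary logits vector per [MASK], keyed by an arbitrary id.
fake_logits = {0: torch.randn(len(tokenizer)), 1: torch.randn(len(tokenizer))}

print(processor.fill_masked_sentence(masked, fake_logits, sampling_technique="greedy"))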
utils/old/sampling/sampling_methods_v2.py
ADDED
@@ -0,0 +1,112 @@
from transformers import pipeline
import torch
import random
from masking_methods import MaskingProcessor


class SamplingProcessorWithPipeline:
    def __init__(self, model_name='bert-base-uncased'):
        self.unmasker = pipeline('fill-mask', model=model_name)
        self.tokenizer = self.unmasker.tokenizer

    def fill_masked_sentence(self, masked_sentence, sampling_technique, temperature=1.0):
        """
        Fills each mask in the masked sentence using the specified sampling technique.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens.
            sampling_technique (str): Sampling technique to use (e.g., "inverse_transform", "exponential_minimum", "temperature", "greedy").
            temperature (float): Temperature parameter for sampling methods.

        Returns:
            str: Sentence with the masks filled.
        """
        while '[MASK]' in masked_sentence:
            # Get predictions for the first [MASK]
            predictions = self.unmasker(masked_sentence)
            print(f' predictions : {predictions}')
            print(f' type of predictions : {type(predictions)}')

            # Ensure predictions is a list of dictionaries
            if not isinstance(predictions, list) or not all(isinstance(pred, dict) for pred in predictions):
                raise ValueError("Unexpected structure in predictions from the pipeline.")

            # Extract logits (scores) from the predictions
            logits = torch.tensor([pred['score'] for pred in predictions], dtype=torch.float32)

            if sampling_technique == "inverse_transform":
                probs = torch.softmax(logits / temperature, dim=-1)
                cumulative_probs = torch.cumsum(probs, dim=-1)
                random_prob = random.random()
                sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()

            elif sampling_technique == "exponential_minimum":
                probs = torch.softmax(logits / temperature, dim=-1)
                exp_probs = torch.exp(-torch.log(probs))
                random_probs = torch.rand_like(exp_probs)
                sampled_index = torch.argmax(random_probs * exp_probs).item()

            elif sampling_technique == "temperature":
                logits = torch.clamp(logits, min=-1e8, max=1e8)
                probs = torch.softmax(logits / temperature, dim=-1)
                if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
                    raise ValueError("The computed probabilities contain NaN or inf values.")
                probs = torch.max(probs, torch.tensor(1e-8, device=logits.device))
                probs = probs / torch.sum(probs)
                probs = probs.flatten()
                if probs.size(0) > 1:
                    sampled_index = torch.multinomial(probs, 1).item()
                else:
                    sampled_index = torch.argmax(probs).item()

            elif sampling_technique == 'greedy':
                sampled_index = torch.argmax(logits).item()

            else:
                raise ValueError(f"Unknown sampling technique: {sampling_technique}")

            # Replace the first [MASK] with the selected word
            sampled_token = predictions[sampled_index]['token_str']
            masked_sentence = masked_sentence.replace('[MASK]', sampled_token, 1)

        return masked_sentence


# Example usage
if __name__ == "__main__":
    from transformers import BertTokenizer

    # Define sentences and result_dict
    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A quick brown dog outpaces a lazy fox.",
        "Quick brown dog leaps over lazy the fox."
    ]
    result_dict = {
        "The quick brown fox jumps over the lazy dog.": {'quick brown': [(0, 1)], 'fox': [(2, 2)], 'lazy': [(4, 4)], 'dog': [(5, 5)]},
        "A quick brown dog outpaces a lazy fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]},
        "Quick brown dog leaps over lazy the fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]}
    }

    masking_processor = MaskingProcessor()
    masking_results = masking_processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)

    # Use SamplingProcessor
    sampling_processor = SamplingProcessorWithPipeline()

    # Iterate through masking results to apply sampling
    for sentence, result in masking_results.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {result['masked_sentence']}")
        masked_sentence = result["masked_sentence"]

        # Apply different sampling techniques
        for technique in ["inverse_transform", "exponential_minimum", "temperature", "greedy"]:
            print(f"Sampling Technique: {technique}")
            filled_sentence = sampling_processor.fill_masked_sentence(
                masked_sentence=masked_sentence,
                sampling_technique=technique,
                temperature=1.0  # Adjust temperature as needed
            )
            print(f"Filled Sentence: {filled_sentence}\n")
        print('--------------------------------')
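Note: this variant leans on the Hugging Face fill-mask pipeline, which for a sentence with a single [MASK] returns the top candidates as a list of dicts with 'score', 'token', 'token_str', and 'sequence' keys; the loop above resolves one mask per call by picking among those candidates. A quick probe of that return shape, added here only for illustration:

from transformers import pipeline

unmasker = pipeline('fill-mask', model='bert-base-uncased')
for pred in unmasker("The quick brown [MASK] jumps over the lazy dog."):
    # Each candidate carries a probability-like 'score' and the surface form 'token_str'.
    print(round(pred['score'], 4), pred['token_str'])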
utils/old/sampling_final_copy.py
ADDED
@@ -0,0 +1,168 @@
import torch
import random
from masking_methods import MaskingProcessor

class SamplingProcessor:
    def __init__(self, tokenizer):
        """
        Initialize the SamplingProcessor.

        Args:
            tokenizer: BERT tokenizer instance
        """
        self.tokenizer = tokenizer

    def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0):
        """
        Sample tokens for each mask in the sentence using the specified sampling technique.

        Args:
            mask_logits_dict (dict): Dictionary of mask positions and their logits/tokens
            masked_sentence (str): Sentence with [MASK] tokens
            sampling_technique (str): Sampling method to use
            temperature (float): Temperature parameter for sampling

        Returns:
            str: Sentence with sampled tokens replacing masks
        """
        words = masked_sentence.split()

        # Convert positions and logits to sorted list to process masks in order
        mask_positions = sorted(mask_logits_dict.keys())

        for mask_pos in mask_positions:
            mask_data = mask_logits_dict[mask_pos]
            mask_logits = torch.tensor(mask_data['logits'])
            candidate_tokens = mask_data['tokens']

            try:
                if sampling_technique == "inverse_transform":
                    probs = torch.softmax(mask_logits / temperature, dim=-1)
                    cumulative_probs = torch.cumsum(probs, dim=-1)
                    random_prob = random.random()
                    sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()

                elif sampling_technique == "exponential_minimum":
                    probs = torch.softmax(mask_logits / temperature, dim=-1)
                    exp_probs = torch.exp(-torch.log(probs))
                    random_probs = torch.rand_like(exp_probs)
                    sampled_index = torch.argmax(random_probs * exp_probs).item()

                elif sampling_technique == "temperature":
                    mask_logits = torch.clamp(mask_logits, min=-1e8, max=1e8)
                    probs = torch.softmax(mask_logits / temperature, dim=-1)
                    if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
                        raise ValueError("The computed probabilities contain NaN or inf values.")
                    probs = torch.max(probs, torch.tensor(1e-8))
                    probs = probs / torch.sum(probs)
                    probs = probs.flatten()
                    if probs.size(0) > 1:
                        sampled_index = torch.multinomial(probs, 1).item()
                    else:
                        sampled_index = torch.argmax(probs).item()

                elif sampling_technique == 'greedy':
                    sampled_index = torch.argmax(mask_logits).item()

                else:
                    raise ValueError(f"Unknown sampling technique: {sampling_technique}")

                # Use the sampled index to get the corresponding token
                sampled_token = candidate_tokens[sampled_index]
                # Remove ## if it's a subword token
                sampled_token = sampled_token.replace('##', '')
                words[mask_pos] = sampled_token

            except Exception as e:
                print(f"Error sampling for position {mask_pos}: {str(e)}")
                continue

        return " ".join(words)

    def process_masked_sentences(self, results_dict, sampling_technique="temperature", temperature=1.0):
        """
        Process all masked sentences in the results dictionary.

        Args:
            results_dict (dict): Dictionary containing masked sentences and their logits
            sampling_technique (str): Sampling method to use
            temperature (float): Temperature parameter for sampling

        Returns:
            dict: Dictionary containing original, masked, and sampled sentences
        """
        processed_results = {}

        for original_sentence, data in results_dict.items():
            masked_sentence = data["masked_sentence"]
            mask_logits = data["mask_logits"]

            sampled_sentence = self.sample_tokens(
                mask_logits,
                masked_sentence,
                sampling_technique,
                temperature
            )

            processed_results[original_sentence] = {
                "masked_sentence": masked_sentence,
                "sampled_sentence": sampled_sentence
            }

        return processed_results


if __name__ == "__main__":
    sentences = [
        "The quick brown fox jumps over the lazy dog everyday.",
        "A speedy brown fox jumps over a lazy dog.",
        "A swift brown fox leaps over the lethargic dog."
    ]
    result_dict = {
        'The quick brown fox jumps over the lazy dog everyday.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}
    }

    # First, mask the sentences
    masking_processor = MaskingProcessor()
    masking_results = masking_processor.process_sentences(sentences, result_dict)

    # Then, sample replacements for the masks
    sampling_processor = SamplingProcessor(masking_processor.tokenizer)

    # Try different sampling techniques
    sampling_techniques = ["temperature", "greedy", "inverse_transform", "exponential_minimum"]

    for technique in sampling_techniques:
        print(f"\nSampling using {technique}:")
        sampled_results = sampling_processor.process_masked_sentences(
            masking_results,
            sampling_technique=technique,
            temperature=1.0
        )

        '''
        Shape of sampled_results:
        {
            "original_sentence_1":
            {
                "masked_sentence": "sentence with [MASK] tokens",
                "sampled_sentence": "sentence with sampled tokens"
            },
            "original_sentence_2":
            {
                "masked_sentence": "sentence with [MASK] tokens",
                "sampled_sentence": "sentence with sampled tokens"
            },
            # ... and so on for each input sentence
        }
        '''

        for original_sentence, result in sampled_results.items():
            print(f"Original: {original_sentence}")
            print(f"Masked: {result['masked_sentence']}")
            print(f"Sampled: {result['sampled_sentence']}")
            print("---")
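Note: the final-copy processor consumes the dictionary produced by MaskingProcessor, with one entry per original sentence containing a 'masked_sentence' string and a 'mask_logits' map from word position to candidate 'tokens' and 'logits'. A hand-built input in that shape (every value below is invented for illustration, and SamplingProcessor from sampling_final_copy.py is assumed to be in scope) exercises the sampling step on its own:

import torch

fake_results = {
    "The quick brown fox jumps over the lazy dog.": {
        "masked_sentence": "The quick brown [MASK] jumps over the lazy [MASK]",
        "mask_logits": {
            3: {"tokens": ["fox", "cat", "wolf"], "logits": [2.1, 0.3, -0.5]},
            8: {"tokens": ["dog", "cat", "hound"], "logits": [1.7, 1.2, 0.1]},
        },
    }
}

# The tokenizer is not used by sample_tokens, so None suffices for this check.
processor = SamplingProcessor(tokenizer=None)
out = processor.process_masked_sentences(fake_results, sampling_technique="greedy")
print(out["The quick brown fox jumps over the lazy dog."]["sampled_sentence"])
# Greedy picks the highest-logit candidate at each position: "fox" and "dog".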