text-peccavi committed
Commit c8b9192 · verified · 1 Parent(s): 71a5e88

Upload 54 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete set of changes.
Files changed (50)
  1. UI/__pycache__/gradio.cpython-310.pyc +0 -0
  2. UI/__pycache__/gradio.cpython-311.pyc +0 -0
  3. UI/gradio.py +575 -0
  4. app.py +20 -0
  5. environment.yml +245 -0
  6. metrics/detectability.py +323 -0
  7. metrics/distortion.py +370 -0
  8. renderers/__pycache__/highlighter.cpython-310.pyc +0 -0
  9. renderers/__pycache__/highlighter.cpython-311.pyc +0 -0
  10. renderers/__pycache__/plot_3d.cpython-310.pyc +0 -0
  11. renderers/__pycache__/plot_3d.cpython-311.pyc +0 -0
  12. renderers/__pycache__/tree.cpython-310.pyc +0 -0
  13. renderers/__pycache__/tree.cpython-311.pyc +0 -0
  14. renderers/highlighter.py +185 -0
  15. renderers/plot_3d.py +126 -0
  16. renderers/tree.py +490 -0
  17. utils/__init__.py +5 -0
  18. utils/__pycache__/__init__.cpython-310.pyc +0 -0
  19. utils/__pycache__/__init__.cpython-311.pyc +0 -0
  20. utils/__pycache__/config.cpython-310.pyc +0 -0
  21. utils/__pycache__/config.cpython-311.pyc +0 -0
  22. utils/__pycache__/entailment.cpython-310.pyc +0 -0
  23. utils/__pycache__/entailment.cpython-311.pyc +0 -0
  24. utils/__pycache__/masking_methods.cpython-310.pyc +0 -0
  25. utils/__pycache__/masking_methods.cpython-311.pyc +0 -0
  26. utils/__pycache__/ngram_index_manager.cpython-311.pyc +0 -0
  27. utils/__pycache__/non_melting_point.cpython-310.pyc +0 -0
  28. utils/__pycache__/non_melting_point.cpython-311.pyc +0 -0
  29. utils/__pycache__/paraphraser.cpython-310.pyc +0 -0
  30. utils/__pycache__/paraphraser.cpython-311.pyc +0 -0
  31. utils/__pycache__/sampling.cpython-310.pyc +0 -0
  32. utils/__pycache__/sampling.cpython-311.pyc +0 -0
  33. utils/__pycache__/watermark.cpython-310.pyc +0 -0
  34. utils/__pycache__/watermark.cpython-311.pyc +0 -0
  35. utils/config.py +18 -0
  36. utils/config.yaml +48 -0
  37. utils/entailment.py +107 -0
  38. utils/masking_methods.py +583 -0
  39. utils/non_melting_point.py +590 -0
  40. utils/old/masking/masking_methods.py +355 -0
  41. utils/old/masking/masking_methods_new_work.py +447 -0
  42. utils/old/masking/masking_methods_ok_working.py +257 -0
  43. utils/old/masking/masking_methods_v1_working.py +233 -0
  44. utils/old/masking_methods_final_copy.py +619 -0
  45. utils/old/non_melting_points_v1.py +244 -0
  46. utils/old/sampling/sampling.py +330 -0
  47. utils/old/sampling/sampling_methods.py +291 -0
  48. utils/old/sampling/sampling_methods_v1.py +146 -0
  49. utils/old/sampling/sampling_methods_v2.py +112 -0
  50. utils/old/sampling_final_copy.py +168 -0
UI/__pycache__/gradio.cpython-310.pyc ADDED
Binary file (6.61 kB).
 
UI/__pycache__/gradio.cpython-311.pyc ADDED
Binary file (29.8 kB).
 
UI/gradio.py ADDED
@@ -0,0 +1,575 @@
1
+ import gradio as gr
2
+ from utils.watermark import Watermarker
3
+ from utils.config import load_config
4
+ from renderers.highlighter import highlight_common_words, highlight_common_words_dict, reparaphrased_sentences_html
5
+ from renderers.tree import generate_subplot1, generate_subplot2
6
+ from pathlib import Path
7
+ import time
8
+ from typing import Dict, List, Tuple, Any
9
+ import plotly.graph_objects as go
10
+
11
+ class WatermarkerInterface:
12
+ def __init__(self, config):
13
+ self.pipeline = Watermarker(config)
14
+ self.common_grams = {}
15
+ self.highlight_info = []
16
+ self.masked_sentences = []
17
+
18
+ # Add tracking dictionaries for indexing
19
+ self.masked_sentence_indices = {} # Maps original sentences to masked indices
20
+ self.sampled_sentence_indices = {} # Maps masked sentences to sampling indices
21
+ self.reparaphrased_indices = {} # Maps sampled sentences to reparaphrased indices
22
+
23
+ def handle_paraphrase(self, prompt: str) -> Tuple[str, str, str, str]:
24
+ """Wrapper for paraphrasing that includes highlighting"""
25
+ start_time = time.time()
26
+
27
+ # Run paraphrasing
28
+ self.pipeline.Paraphrase(prompt)
29
+
30
+ # Step 1: Process the original sentence first
31
+ seen_ngrams = {} # Stores first occurrence index of each n-gram
32
+ original_indexed_ngrams = [] # Final indexed list for original
33
+
34
+ original_sentence = self.pipeline.user_prompt
35
+ original_ngrams = self.pipeline.common_grams.get(original_sentence, {})
36
+
37
+ # Step 1.1: Extract n-grams and their first occurrence index
38
+ ngram_occurrences = [
39
+ (min(indices, key=lambda x: x[0])[0], gram) # Get first index
40
+ for gram, indices in original_ngrams.items()
41
+ ]
42
+
43
+ # Step 1.2: Sort n-grams based on their first occurrence
44
+ ngram_occurrences.sort()
45
+
46
+ # Step 1.3: Assign sequential indices
47
+ for idx, (position, gram) in enumerate(ngram_occurrences, start=1):
48
+ seen_ngrams[gram] = idx # Assign sequential index
49
+ original_indexed_ngrams.append((idx, gram))
50
+
51
+ print("Original Indexed N-grams:", original_indexed_ngrams)
52
+
53
+ #generate highlight_info
54
+ colors = ["red", "blue", "green", "orange"]
55
+ highlight_info = [
56
+ (ngram, colors[i % len(colors)])
57
+ for i, (index, ngram) in enumerate(original_indexed_ngrams)
58
+ ]
59
+ common_grams = original_indexed_ngrams
60
+ self.highlight_info = highlight_info
61
+ self.common_grams = common_grams
62
+
63
+ # Step 2: Process paraphrased sentences and match indices
64
+ paraphrase_indexed_ngrams = {}
65
+
66
+ for sentence in self.pipeline.paraphrased_sentences:
67
+ sentence_ngrams = [] # Stores n-grams for this sentence
68
+ sentence_ngrams_dict = self.pipeline.common_grams.get(sentence, {})
69
+
70
+ for gram, indices in sentence_ngrams_dict.items():
71
+ first_occurrence = min(indices, key=lambda x: x[0])[0]
72
+
73
+ # Use the original's index if exists, otherwise assign a new one
74
+ if gram in seen_ngrams:
75
+ index = seen_ngrams[gram] # Use the same index as original
76
+ else:
77
+ index = len(seen_ngrams) + 1 # Assign new index
78
+ seen_ngrams[gram] = index # Store it
79
+
80
+ sentence_ngrams.append((index, gram))
81
+
82
+ sentence_ngrams.sort()
83
+ paraphrase_indexed_ngrams[sentence] = sentence_ngrams
84
+
85
+ print("Paraphrase Indexed N-grams:", paraphrase_indexed_ngrams)
86
+
87
+ # Step 3: Generate highlighted versions using the renderer
88
+ highlighted_prompt = highlight_common_words(
89
+ common_grams,
90
+ [self.pipeline.user_prompt],
91
+ "Original Prompt with Highlighted Common Sequences"
92
+ )
93
+
94
+ highlighted_accepted = highlight_common_words_dict(
95
+ common_grams,
96
+ self.pipeline.selected_sentences,
97
+ "Accepted Paraphrased Sentences with Entailment Scores"
98
+ )
99
+
100
+ highlighted_discarded = highlight_common_words_dict(
101
+ common_grams,
102
+ self.pipeline.discarded_sentences,
103
+ "Discarded Paraphrased Sentences with Entailment Scores"
104
+ )
105
+
106
+ execution_time = f"<div class='execution-time'>Step 1 completed in {time.time() - start_time:.2f} seconds</div>"
107
+
108
+ return highlighted_prompt, highlighted_accepted, highlighted_discarded, execution_time
109
+
110
+ def handle_masking(self):
111
+ start_time = time.time()
112
+ masking_results = self.pipeline.Masking()
113
+ trees = []
114
+ highlight_info = self.highlight_info
115
+ common_grams = self.common_grams
116
+ sentence_to_masked = {}
117
+ self.masked_sentence_indices = {}
118
+
119
+ for strategy, sentence_dict in masking_results.items():
120
+ for sent, data in sentence_dict.items():
121
+ if sent not in sentence_to_masked:
122
+ sentence_to_masked[sent] = []
123
+ masked_sentence = data.get("masked_sentence", "")
124
+ if masked_sentence:
125
+ sentence_to_masked[sent].append((masked_sentence, strategy))
126
+
127
+ plot_idx = 1
128
+ for original_sentence, masked_sentences_data in sentence_to_masked.items():
129
+ if not masked_sentences_data:
130
+ continue
131
+ masked_idx = 1
132
+ for masked_sentence, strategy in masked_sentences_data:
133
+ index = f"{plot_idx}{masked_idx}"
134
+ if original_sentence not in self.masked_sentence_indices:
135
+ self.masked_sentence_indices[original_sentence] = {}
136
+ key = f"{strategy}_{masked_sentence}"
137
+ self.masked_sentence_indices[original_sentence][key] = {
138
+ 'index': index,
139
+ 'strategy': strategy,
140
+ 'masked_sentence': masked_sentence
141
+ }
142
+ masked_idx += 1
143
+
144
+ masked_sentences = [ms[0] for ms in masked_sentences_data]
145
+ indexed_masked_sentences = []
146
+ verified_strategies = []
147
+ for masked_sentence, strategy in masked_sentences_data:
148
+ key = f"{strategy}_{masked_sentence}"
149
+ entry = self.masked_sentence_indices[original_sentence][key]
150
+ idx = entry['index']
151
+ indexed_masked_sentences.append(f"[{idx}] {masked_sentence}")
152
+ verified_strategies.append(entry['strategy'])
153
+
154
+ try:
155
+ fig = generate_subplot1(
156
+ original_sentence,
157
+ indexed_masked_sentences,
158
+ verified_strategies,
159
+ highlight_info,
160
+ common_grams
161
+ )
162
+ trees.append(fig)
163
+ except Exception as e:
164
+ print(f"Error generating plot: {e}")
165
+ trees.append(go.Figure())
166
+ plot_idx += 1
167
+
168
+ while len(trees) < 10:
169
+ trees.append(go.Figure())
170
+
171
+ execution_time = f"<div class='execution-time'>Step 2 completed in {time.time() - start_time:.2f} seconds</div>"
172
+ return trees[:10] + [execution_time]
173
+
174
+ def handle_sampling(self) -> Tuple[List[go.Figure], str]:
175
+ start_time = time.time()
176
+ sampling_results = self.pipeline.Sampling()
177
+ trees = []
178
+ self.sampled_sentence_indices = {}
179
+ organized_results = {}
180
+
181
+ for sampling_strategy, masking_dict in sampling_results.items():
182
+ for masking_strategy, sentences in masking_dict.items():
183
+ for original_sentence, data in sentences.items():
184
+ if original_sentence not in organized_results:
185
+ organized_results[original_sentence] = {}
186
+ if masking_strategy not in organized_results[original_sentence]:
187
+ organized_results[original_sentence][masking_strategy] = {
188
+ "masked_sentence": data.get("masked_sentence", ""),
189
+ "sampled_sentences": {}
190
+ }
191
+ organized_results[original_sentence][masking_strategy]["sampled_sentences"][sampling_strategy] = data.get("sampled_sentence", "")
192
+
193
+ plot_idx = 1
194
+ for original_sentence, data in organized_results.items():
195
+ masked_sentences = []
196
+ all_sampled_sentences = []
197
+ indexed_sampled_sentences = []
198
+ masked_indices = self.masked_sentence_indices.get(original_sentence, {})
199
+
200
+ for masking_strategy, masking_data in list(data.items())[:3]:
201
+ masked_sentence = masking_data.get("masked_sentence", "")
202
+ if masked_sentence:
203
+ masked_sentences.append(masked_sentence)
204
+ masked_idx = None
205
+ for ms_key, ms_data in masked_indices.items():
206
+ if ms_key == f"{masking_strategy}_{masked_sentence}":
207
+ masked_idx = ms_data['index']
208
+ break
209
+
210
+ if not masked_idx:
211
+ print(f"Warning: No index found for masked sentence: {masked_sentence}")
212
+ continue
213
+
214
+ sample_count = 1
215
+ for sampling_strategy, sampled_sentence in masking_data.get("sampled_sentences", {}).items():
216
+ if sampled_sentence:
217
+ sample_idx = f"{masked_idx}.{sample_count}"
218
+ if masked_sentence not in self.sampled_sentence_indices:
219
+ self.sampled_sentence_indices[masked_sentence] = {}
220
+ self.sampled_sentence_indices[masked_sentence][sampled_sentence] = {
221
+ 'index': sample_idx,
222
+ 'strategy': sampling_strategy
223
+ }
224
+ indexed_sampled_sentences.append(f"[{sample_idx}] {sampled_sentence}")
225
+ all_sampled_sentences.append(sampled_sentence)
226
+ sample_count += 1
227
+
228
+ if masked_sentences:
229
+ indexed_masked_sentences = []
230
+ for ms in masked_sentences:
231
+ idx = ""
232
+ for ms_key, ms_data in masked_indices.items():
233
+ if ms_key.endswith(f"_{ms}"):
234
+ idx = ms_data['index']
235
+ break
236
+ indexed_masked_sentences.append(f"[{idx}] {ms}")
237
+
238
+ try:
239
+ fig = generate_subplot2(
240
+ indexed_masked_sentences,
241
+ indexed_sampled_sentences,
242
+ self.highlight_info,
243
+ self.common_grams
244
+ )
245
+ trees.append(fig)
246
+ except Exception as e:
247
+ print(f"Error generating subplot for {original_sentence}: {e}")
248
+ trees.append(go.Figure())
249
+ plot_idx += 1
250
+
251
+ print("Sampled sentence indices:", self.sampled_sentence_indices)
252
+
253
+ while len(trees) < 10:
254
+ trees.append(go.Figure())
255
+
256
+ execution_time = f"<div class='execution-time'>Step 3 completed in {time.time() - start_time:.2f} seconds</div>"
257
+
258
+ return trees[:10] + [execution_time]
259
+
260
+ def handle_reparaphrasing(self) -> Tuple[List[str], str]:
261
+ start_time = time.time()
262
+ results = self.pipeline.re_paraphrasing()
263
+ html_outputs = []
264
+ self.reparaphrased_indices = {}
265
+ tab_count = 1
266
+
267
+ for sampling_strategy, masking_dict in results.items():
268
+ for masking_strategy, sentences in masking_dict.items():
269
+ for original_sent, data in sentences.items():
270
+ sampled_sentence = data.get("sampled_sentence", "")
271
+ if not sampled_sentence or not data["re_paraphrased_sentences"]:
272
+ continue
273
+
274
+ sampled_index = None
275
+ for masked_sent, sampled_dict in self.sampled_sentence_indices.items():
276
+ if sampled_sentence in sampled_dict:
277
+ sampled_index = sampled_dict[sampled_sentence]['index']
278
+ break
279
+
280
+ if not sampled_index:
281
+ sampled_index = "unknown"
282
+
283
+ indexed_reparaphrased = []
284
+ for i, rp_sent in enumerate(data["re_paraphrased_sentences"], 1):
285
+ rp_idx = f"{tab_count}.({sampled_index}).{i}"
286
+ if sampled_sentence not in self.reparaphrased_indices:
287
+ self.reparaphrased_indices[sampled_sentence] = {}
288
+ self.reparaphrased_indices[sampled_sentence][rp_sent] = rp_idx
289
+ indexed_reparaphrased.append(f"[{rp_idx}] {rp_sent}")
290
+
291
+ print(f"Reparaphrasing {tab_count}.({sampled_index}): {' '.join(sampled_sentence.split()[:5])}...")
292
+ html = reparaphrased_sentences_html(indexed_reparaphrased)
293
+ html_outputs.append(html)
294
+ tab_count += 1
295
+
296
+ print("Reparaphrased indices:", self.reparaphrased_indices)
297
+
298
+ while len(html_outputs) < 120:
299
+ html_outputs.append("")
300
+
301
+ execution_time = f"<div class='execution-time'>Step 4 completed in {time.time() - start_time:.2f} seconds</div>"
302
+
303
+ return html_outputs[:120] + [execution_time]
304
+
305
+ def create_gradio_interface(config):
306
+ """Creates the Gradio interface with the updated pipeline"""
307
+ interface = WatermarkerInterface(config)
308
+
309
+ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
310
+ #CSS to enable scrolling for reparaphrased sentences and sampling plots
311
+ demo.css = """
312
+ /* Set fixed height for the reparaphrased tabs container only */
313
+ .gradio-container .tabs[id="reparaphrased-tabs"],
314
+ .gradio-container .tabs[id="sampling-tabs"] {
315
+ overflow-x: hidden;
316
+ white-space: normal;
317
+ border-radius: 8px;
318
+ max-height: 600px; /* Set fixed height for the entire tabs component */
319
+ overflow-y: auto; /* Enable vertical scrolling inside the container */
320
+ }
321
+
322
+ /* Tab content styling for reparaphrased and sampling tabs */
323
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tabitem,
324
+ .gradio-container .tabs[id="sampling-tabs"] .tabitem {
325
+ overflow-x: hidden;
326
+ white-space: normal;
327
+ display: block;
328
+ border-radius: 8px;
329
+ }
330
+
331
+ /* Make the tab navigation fixed at the top for scrollable tabs */
332
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav,
333
+ .gradio-container .tabs[id="sampling-tabs"] .tab-nav {
334
+ display: flex;
335
+ overflow-x: auto;
336
+ white-space: nowrap;
337
+ scrollbar-width: thin;
338
+ border-radius: 8px;
339
+ scrollbar-color: #888 #f1f1f1;
340
+ position: sticky;
341
+ top: 0;
342
+ background: white;
343
+ z-index: 100;
344
+ }
345
+
346
+ /* Dropdown menu for scrollable tabs styling */
347
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown,
348
+ .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown {
349
+ position: relative;
350
+ display: inline-block;
351
+ }
352
+
353
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown-content,
354
+ .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown-content {
355
+ display: none;
356
+ position: absolute;
357
+ background-color: #f9f9f9;
358
+ min-width: 160px;
359
+ box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2);
360
+ z-index: 1;
361
+ max-height: 300px;
362
+ overflow-y: auto;
363
+ }
364
+
365
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown:hover .tab-dropdown-content,
366
+ .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown:hover .tab-dropdown-content {
367
+ display: block;
368
+ }
369
+
370
+ /* Scrollbar styling for scrollable tabs */
371
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav::-webkit-scrollbar,
372
+ .gradio-container .tabs[id="sampling-tabs"] .tab-nav::-webkit-scrollbar {
373
+ height: 8px;
374
+ border-radius: 8px;
375
+ }
376
+
377
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav::-webkit-scrollbar-track,
378
+ .gradio-container .tabs[id="sampling-tabs"] .tab-nav::-webkit-scrollbar-track {
379
+ background: #f1f1f1;
380
+ border-radius: 8px;
381
+ }
382
+
383
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav::-webkit-scrollbar-thumb,
384
+ .gradio-container .tabs[id="sampling-tabs"] .tab-nav::-webkit-scrollbar-thumb {
385
+ background: #888;
386
+ border-radius: 8px;
387
+ }
388
+
389
+ /* Tab button styling for scrollable tabs */
390
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-item,
391
+ .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-item {
392
+ flex: 0 0 auto;
393
+ border-radius: 8px;
394
+ }
395
+
396
+ /* Plot container styling specifically for sampling tabs */
397
+ .gradio-container .tabs[id="sampling-tabs"] .plot-container {
398
+ min-height: 600px;
399
+ max-height: 1800px;
400
+ overflow-y: auto;
401
+ }
402
+
403
+ /* Ensure text wraps in HTML components */
404
+ .gradio-container .prose {
405
+ white-space: normal;
406
+ word-wrap: break-word;
407
+ overflow-wrap: break-word;
408
+ }
409
+
410
+ /* Dropdown button styling for scrollable tabs */
411
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown button,
412
+ .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown button {
413
+ background-color: #f0f0f0;
414
+ border: 1px solid #ddd;
415
+ border-radius: 4px;
416
+ padding: 5px 10px;
417
+ cursor: pointer;
418
+ margin: 2px;
419
+ }
420
+
421
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown button:hover,
422
+ .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown button:hover {
423
+ background-color: #e0e0e0;
424
+ }
425
+
426
+ /* Style dropdown content items for scrollable tabs */
427
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown-content div,
428
+ .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown-content div {
429
+ padding: 8px 12px;
430
+ cursor: pointer;
431
+ }
432
+
433
+ .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown-content div:hover,
434
+ .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown-content div:hover {
435
+ background-color: #e0e0e0;
436
+ }
437
+
438
+ /* Custom styling for execution time display */
439
+ .execution-time {
440
+ text-align: right;
441
+ padding: 8px 16px;
442
+ font-family: inherit;
443
+ color: #555;
444
+ font-size: 0.9rem;
445
+ font-style: italic;
446
+ margin-left: auto;
447
+ width: 100%;
448
+ border-top: 1px solid #eee;
449
+ margin-top: 8px;
450
+ }
451
+
452
+ /* Layout for section headers with execution time */
453
+ .section-header {
454
+ display: flex;
455
+ justify-content: space-between;
456
+ align-items: center;
457
+ width: 100%;
458
+ margin-bottom: 12px;
459
+ }
460
+
461
+ .section-header h3 {
462
+ margin: 0;
463
+ }
464
+ """
465
+ gr.Markdown("# **AIISC Watermarking Model**")
466
+
467
+ with gr.Column():
468
+ gr.Markdown("## Input Prompt")
469
+ user_input = gr.Textbox(
470
+ label="Enter Your Prompt",
471
+ placeholder="Type your text here..."
472
+ )
473
+
474
+ with gr.Row():
475
+ with gr.Column(scale=3):
476
+ gr.Markdown("## Step 1: Paraphrasing, LCS and Entailment Analysis")
477
+ with gr.Column(scale=1):
478
+ step1_time = gr.HTML()
479
+
480
+ paraphrase_button = gr.Button("Generate Paraphrases")
481
+ highlighted_user_prompt = gr.HTML(label="Highlighted User Prompt")
482
+
483
+ with gr.Tabs():
484
+ with gr.TabItem("Accepted Paraphrased Sentences"):
485
+ highlighted_accepted_sentences = gr.HTML()
486
+ with gr.TabItem("Discarded Paraphrased Sentences"):
487
+ highlighted_discarded_sentences = gr.HTML()
488
+
489
+ with gr.Row():
490
+ with gr.Column(scale=3):
491
+ gr.Markdown("## Step 2: Where to Mask?")
492
+ with gr.Column(scale=1):
493
+ step2_time = gr.HTML()
494
+
495
+ masking_button = gr.Button("Apply Masking")
496
+ gr.Markdown("### Masked Sentence Trees")
497
+ tree1_plots = []
498
+ with gr.Tabs() as tree1_tabs:
499
+ for i in range(10):
500
+ with gr.TabItem(f"Masked Sentence {i+1}"):
501
+ tree1 = gr.Plot()
502
+ tree1_plots.append(tree1)
503
+
504
+ with gr.Row():
505
+ with gr.Column(scale=3):
506
+ gr.Markdown("## Step 3: How to Mask?")
507
+ with gr.Column(scale=1):
508
+ step3_time = gr.HTML()
509
+
510
+ sampling_button = gr.Button("Sample Words")
511
+ gr.Markdown("### Sampled Sentence Trees")
512
+
513
+ tree2_plots = []
514
+ # Add elem_id to make this tab container scrollable
515
+ with gr.Tabs(elem_id="sampling-tabs") as tree2_tabs:
516
+ for i in range(10):
517
+ with gr.TabItem(f"Sampled Sentence {i+1}"):
518
+ # Add a custom class to the container to enable proper styling
519
+ with gr.Column(elem_classes=["plot-container"]):
520
+ tree2 = gr.Plot()
521
+ tree2_plots.append(tree2)
522
+
523
+ with gr.Row():
524
+ with gr.Column(scale=3):
525
+ gr.Markdown("## Step 4: Re-paraphrasing")
526
+ with gr.Column(scale=1):
527
+ step4_time = gr.HTML()
528
+
529
+ reparaphrase_button = gr.Button("Re-paraphrase")
530
+ gr.Markdown("### Reparaphrased Sentences")
531
+ reparaphrased_sentences_tabs = []
532
+ with gr.Tabs(elem_id="reparaphrased-tabs") as reparaphrased_tabs:
533
+ for i in range(120):
534
+ with gr.TabItem(f"Reparaphrased Batch {i+1}"):
535
+ reparaphrased_sent_html = gr.HTML()
536
+ reparaphrased_sentences_tabs.append(reparaphrased_sent_html)
537
+
538
+ # Connect the interface functions to the buttons
539
+ paraphrase_button.click(
540
+ interface.handle_paraphrase,
541
+ inputs=user_input,
542
+ outputs=[
543
+ highlighted_user_prompt,
544
+ highlighted_accepted_sentences,
545
+ highlighted_discarded_sentences,
546
+ step1_time
547
+ ]
548
+ )
549
+
550
+ masking_button.click(
551
+ interface.handle_masking,
552
+ inputs=None,
553
+ outputs=tree1_plots + [step2_time]
554
+ )
555
+
556
+ sampling_button.click(
557
+ interface.handle_sampling,
558
+ inputs=None,
559
+ outputs=tree2_plots + [step3_time]
560
+ )
561
+
562
+ reparaphrase_button.click(
563
+ interface.handle_reparaphrasing,
564
+ inputs=None,
565
+ outputs=reparaphrased_sentences_tabs + [step4_time]
566
+ )
567
+
568
+ return demo
569
+
570
+ if __name__ == "__main__":
571
+ project_root = Path(__file__).parent.parent
572
+ config_path = project_root / "utils" / "config.yaml"
573
+ config = load_config(config_path)['PECCAVI_TEXT']
574
+
575
+ create_gradio_interface(config).launch()
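
Note on the handler return shapes in `UI/gradio.py`: a Gradio `click` callback must return exactly one value per wired output component, which is why `handle_masking` and `handle_sampling` pad their figure lists to 10 and `handle_reparaphrasing` pads its HTML list to 120 before returning. A minimal, self-contained sketch of that pad-to-fixed-slots pattern (the slot count and figure data here are made up, not taken from the commit):

```python
import gradio as gr
import plotly.graph_objects as go

NUM_SLOTS = 10  # mirrors the ten "Masked Sentence" tabs built in UI/gradio.py

def fill_slots():
    # A handler may produce fewer results than there are tabs...
    figures = [go.Figure(go.Scatter(y=[0, i, i * 2])) for i in range(1, 4)]
    # ...so it pads with empty figures until every output component gets a value.
    while len(figures) < NUM_SLOTS:
        figures.append(go.Figure())
    return figures[:NUM_SLOTS]

with gr.Blocks() as demo:
    plots = []
    with gr.Tabs():
        for i in range(NUM_SLOTS):
            with gr.TabItem(f"Slot {i + 1}"):
                plots.append(gr.Plot())
    gr.Button("Fill").click(fill_slots, inputs=None, outputs=plots)

if __name__ == "__main__":
    demo.launch()
```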
app.py ADDED
@@ -0,0 +1,20 @@
+
+ import gradio as gr
+ from UI.gradio import create_gradio_interface
+ from pathlib import Path
+ from utils.config import load_config
+
+ project_root = Path(__file__).resolve().parent
+ config_path = project_root / "utils" / "config.yaml"
+ config = load_config(config_path)['PECCAVI_TEXT']
+
+ def main():
+     """
+     This function is the entry point for the PECCAVI Watermarking Model.
+
+     It creates the Gradio interface for the model and runs it.
+     """
+     create_gradio_interface(config).launch()
+
+ if __name__ == "__main__":
+     main()
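
Both entry points (`app.py` above and the `__main__` block of `UI/gradio.py`) resolve `utils/config.yaml` and index the result of `load_config` with `['PECCAVI_TEXT']`, so the loader presumably parses the YAML into a plain dict. The body of `utils/config.py` does not appear in this truncated view; the following is only a sketch of what such a loader typically looks like, not the repository's actual implementation:

```python
from pathlib import Path
import yaml  # pyyaml is pinned in environment.yml

def load_config(config_path):
    """Parse a YAML config file and return its contents as a dict (assumed behaviour)."""
    with open(Path(config_path), "r") as f:
        return yaml.safe_load(f)

# Usage mirroring app.py:
# config = load_config("utils/config.yaml")["PECCAVI_TEXT"]
```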
environment.yml ADDED
@@ -0,0 +1,245 @@
1
+ name: panda
2
+ channels:
3
+ - conda-forge
4
+ - defaults
5
+ dependencies:
6
+ - _libgcc_mutex=0.1=conda_forge
7
+ - _openmp_mutex=4.5=2_gnu
8
+ - asttokens=2.4.1=pyhd8ed1ab_0
9
+ - bzip2=1.0.8=h5eee18b_6
10
+ - ca-certificates=2024.8.30=hbcca054_0
11
+ - comm=0.2.2=pyhd8ed1ab_0
12
+ - debugpy=1.8.6=py310hf71b8c6_0
13
+ - decorator=5.1.1=pyhd8ed1ab_0
14
+ - exceptiongroup=1.2.2=pyhd8ed1ab_0
15
+ - executing=2.1.0=pyhd8ed1ab_0
16
+ - ipykernel=6.29.5=pyh3099207_0
17
+ - ipython=8.27.0=pyh707e725_0
18
+ - jedi=0.19.1=pyhd8ed1ab_0
19
+ - jupyter_client=8.6.3=pyhd8ed1ab_0
20
+ - jupyter_core=5.7.2=pyh31011fe_1
21
+ - krb5=1.21.3=h143b758_0
22
+ - ld_impl_linux-64=2.40=h12ee557_0
23
+ - libedit=3.1.20230828=h5eee18b_0
24
+ - libffi=3.4.4=h6a678d5_1
25
+ - libgcc=14.1.0=h77fa898_1
26
+ - libgcc-ng=14.1.0=h69a702a_1
27
+ - libgomp=14.1.0=h77fa898_1
28
+ - libsodium=1.0.20=h4ab18f5_0
29
+ - libstdcxx=14.1.0=hc0a3c3a_1
30
+ - libstdcxx-ng=11.2.0=h1234567_1
31
+ - libuuid=1.41.5=h5eee18b_0
32
+ - matplotlib-inline=0.1.7=pyhd8ed1ab_0
33
+ - ncurses=6.4=h6a678d5_0
34
+ - nest-asyncio=1.6.0=pyhd8ed1ab_0
35
+ - openssl=3.3.2=hb9d3cd8_0
36
+ - packaging=24.1=pyhd8ed1ab_0
37
+ - parso=0.8.4=pyhd8ed1ab_0
38
+ - pexpect=4.9.0=pyhd8ed1ab_0
39
+ - pickleshare=0.7.5=py_1003
40
+ - pip=24.2=py310h06a4308_0
41
+ - platformdirs=4.3.6=pyhd8ed1ab_0
42
+ - prompt-toolkit=3.0.48=pyha770c72_0
43
+ - ptyprocess=0.7.0=pyhd3deb0d_0
44
+ - pure_eval=0.2.3=pyhd8ed1ab_0
45
+ - pygments=2.18.0=pyhd8ed1ab_0
46
+ - python=3.10.14=h955ad1f_1
47
+ - python_abi=3.10=2_cp310
48
+ - pyzmq=26.2.0=py310h71f11fc_2
49
+ - readline=8.2=h5eee18b_0
50
+ - setuptools=75.1.0=py310h06a4308_0
51
+ - sqlite=3.45.3=h5eee18b_0
52
+ - stack_data=0.6.2=pyhd8ed1ab_0
53
+ - tk=8.6.14=h39e8969_0
54
+ - tornado=6.4.1=py310ha75aee5_1
55
+ - traitlets=5.14.3=pyhd8ed1ab_0
56
+ - typing_extensions=4.12.2=pyha770c72_0
57
+ - wcwidth=0.2.13=pyhd8ed1ab_0
58
+ - wheel=0.44.0=py310h06a4308_0
59
+ - xz=5.4.6=h5eee18b_1
60
+ - zeromq=4.3.5=ha4adb4c_5
61
+ - zlib=1.2.13=h5eee18b_1
62
+ - pip:
63
+ - absl-py==2.1.0
64
+ - accelerate==0.33.0
65
+ - aiofiles==23.2.1
66
+ - aiohappyeyeballs==2.3.5
67
+ - aiohttp==3.10.3
68
+ - aiosignal==1.3.1
69
+ - altgraph==0.17.4
70
+ - annotated-types==0.7.0
71
+ - anyio==4.6.0
72
+ - astunparse==1.6.3
73
+ - async-timeout==4.0.3
74
+ - attrs==24.2.0
75
+ - av==12.0.0
76
+ - backports-tarfile==1.2.0
77
+ - beautifulsoup4==4.12.3
78
+ - build==1.2.2
79
+ - cachetools==5.5.0
80
+ - certifi==2024.7.4
81
+ - cffi==1.17.1
82
+ - charset-normalizer==3.3.2
83
+ - clean-fid==0.1.35
84
+ - click==8.1.7
85
+ - colorama==0.4.6
86
+ - contextlib2==21.6.0
87
+ - contourpy==1.2.1
88
+ - cryptography==43.0.1
89
+ - cycler==0.12.1
90
+ - datasets==2.21.0
91
+ - diffusers==0.27.2
92
+ - dill==0.3.8
93
+ - docker-pycreds==0.4.0
94
+ - docutils==0.21.2
95
+ - fastapi==0.115.0
96
+ - ffmpy==0.4.0
97
+ - filelock==3.15.4
98
+ - flatbuffers==24.3.25
99
+ - fonttools==4.53.1
100
+ - frozenlist==1.4.1
101
+ - fsspec==2024.6.1
102
+ - gast==0.4.0
103
+ - gdown==5.2.0
104
+ - gitdb==4.0.11
105
+ - gitpython==3.1.43
106
+ - google-auth==2.35.0
107
+ - google-auth-oauthlib==0.4.6
108
+ - google-pasta==0.2.0
109
+ - gradio==4.44.0
110
+ - gradio-client==1.3.0
111
+ - grpcio==1.65.4
112
+ - h11==0.14.0
113
+ - h5py==3.11.0
114
+ - httpcore==1.0.6
115
+ - httpx==0.27.2
116
+ - huggingface-hub==0.25.2
117
+ - idna==3.7
118
+ - imageio==2.35.0
119
+ - importlib-metadata==8.2.0
120
+ - importlib-resources==6.4.5
121
+ - jaraco-classes==3.4.0
122
+ - jaraco-context==6.0.1
123
+ - jaraco-functools==4.1.0
124
+ - jeepney==0.8.0
125
+ - jinja2==3.1.4
126
+ - joblib==1.4.2
127
+ - json-with-comments==1.2.7
128
+ - keras==3.5.0
129
+ - keras-preprocessing==1.1.2
130
+ - keyring==25.4.1
131
+ - kiwisolver==1.4.5
132
+ - kornia==0.7.4
133
+ - kornia-rs==0.1.7
134
+ - lazy-loader==0.4
135
+ - libclang==18.1.1
136
+ - markdown==3.6
137
+ - markdown-it-py==3.0.0
138
+ - markupsafe==2.1.5
139
+ - matplotlib==3.9.2
140
+ - mdurl==0.1.2
141
+ - ml-collections==0.1.1
142
+ - ml-dtypes==0.4.0
143
+ - more-itertools==10.5.0
144
+ - multidict==6.0.5
145
+ - multiprocess==0.70.16
146
+ - namex==0.0.8
147
+ - networkx==3.3
148
+ - nh3==0.2.18
149
+ - nltk==3.9.1
150
+ - numpy==1.26.4
151
+ - nvidia-cublas-cu11==11.10.3.66
152
+ - nvidia-cuda-nvrtc-cu11==11.7.99
153
+ - nvidia-cuda-runtime-cu11==11.7.99
154
+ - nvidia-cudnn-cu11==8.5.0.96
155
+ - oauthlib==3.2.2
156
+ - opencv-python==4.10.0.84
157
+ - opencv-python-headless==4.10.0.84
158
+ - opt-einsum==3.3.0
159
+ - optree==0.12.1
160
+ - orjson==3.10.7
161
+ - pandas==2.2.2
162
+ - pillow==10.4.0
163
+ - pkginfo==1.10.0
164
+ - plotly==5.24.1
165
+ - protobuf==4.25.5
166
+ - psutil==5.9.8
167
+ - pyarrow==17.0.0
168
+ - pyasn1==0.6.1
169
+ - pyasn1-modules==0.4.1
170
+ - pycparser==2.22
171
+ - pydantic==2.9.2
172
+ - pydantic-core==2.23.4
173
+ - pydub==0.25.1
174
+ - pyinstaller==6.10.0
175
+ - pyinstaller-hooks-contrib==2024.8
176
+ - pyparsing==3.1.2
177
+ - pyproject-hooks==1.1.0
178
+ - pysocks==1.7.1
179
+ - python-dateutil==2.9.0.post0
180
+ - python-multipart==0.0.12
181
+ - pytorch-msssim==1.0.0
182
+ - pytorchcv==0.0.73
183
+ - pytz==2023.3.post1
184
+ - pyyaml==6.0.2
185
+ - readme-renderer==44.0
186
+ - regex==2024.7.24
187
+ - requests==2.32.3
188
+ - requests-oauthlib==2.0.0
189
+ - requests-toolbelt==1.0.0
190
+ - rfc3986==2.0.0
191
+ - rich==13.7.1
192
+ - rsa==4.9
193
+ - ruff==0.6.9
194
+ - safetensors==0.4.4
195
+ - saliency==0.2.1
196
+ - scikit-image==0.24.0
197
+ - scikit-learn==1.6.0
198
+ - scipy==1.14.0
199
+ - secretstorage==3.3.3
200
+ - semantic-version==2.10.0
201
+ - sentence-transformers==3.3.1
202
+ - sentry-sdk==2.15.0
203
+ - setproctitle==1.3.3
204
+ - shapely==2.0.5
205
+ - shellingham==1.5.4
206
+ - six==1.12.0
207
+ - smmap==5.0.1
208
+ - sniffio==1.3.1
209
+ - soupsieve==2.6
210
+ - spaces==0.30.2
211
+ - starlette==0.38.6
212
+ - tenacity==9.0.0
213
+ - tensorboard==2.17.1
214
+ - tensorboard-data-server==0.7.2
215
+ - tensorboard-plugin-wit==1.8.1
216
+ - tensorflow==2.17.0
217
+ - tensorflow-estimator==2.10.0
218
+ - tensorflow-hub==0.16.1
219
+ - tensorflow-intel==0.0.1
220
+ - tensorflow-io-gcs-filesystem==0.31.0
221
+ - termcolor==1.1.0
222
+ - tf-keras==2.17.0
223
+ - threadpoolctl==3.5.0
224
+ - tifffile==2024.8.10
225
+ - timm==1.0.10
226
+ - tokenizers==0.19.1
227
+ - tomli==2.0.1
228
+ - tomlkit==0.12.0
229
+ - torch==1.13.1
230
+ - torchvision==0.14.1
231
+ - tqdm==4.66.5
232
+ - transformers==4.43.3
233
+ - twine==5.1.1
234
+ - typer==0.12.5
235
+ - tzdata==2024.1
236
+ - urllib3==2.2.2
237
+ - uvicorn==0.31.0
238
+ - wandb==0.18.3
239
+ - websockets==12.0
240
+ - werkzeug==3.0.4
241
+ - wrapt==1.11.2
242
+ - xxhash==3.4.1
243
+ - yarl==1.9.4
244
+ - zipp==3.20.0
245
+ prefix: /home/ashhar21137/miniconda3/envs/panda
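
The environment pins a specific stack for the demo (among others, `gradio==4.44.0`, `plotly==5.24.1`, `transformers==4.43.3`, `torch==1.13.1`). After recreating it, e.g. with `conda env create -f environment.yml`, a quick version check can catch resolver drift; this is a convenience sketch, not part of the commit:

```python
from importlib.metadata import PackageNotFoundError, version

# A few of the pins from environment.yml worth verifying after install.
EXPECTED = {
    "gradio": "4.44.0",
    "plotly": "5.24.1",
    "transformers": "4.43.3",
    "torch": "1.13.1",
}

for package, pinned in EXPECTED.items():
    try:
        installed = version(package)
    except PackageNotFoundError:
        print(f"{package}: not installed (pinned {pinned})")
        continue
    status = "ok" if installed == pinned else f"mismatch (pinned {pinned})"
    print(f"{package}: {installed} -> {status}")
```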
metrics/detectability.py ADDED
@@ -0,0 +1,323 @@
1
+ # Import necessary libraries
2
+ import nltk
3
+ import numpy as np
4
+ import torch
5
+ import matplotlib.pyplot as plt
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
+ from transformers import BertModel, BertTokenizer
8
+ from scipy import stats
9
+
10
+ # Download NLTK data if not already present
11
+ nltk.download('punkt', quiet=True)
12
+ detectability_val = {}
13
+
14
+ class SentenceDetectabilityCalculator:
15
+ """
16
+ A class to calculate and analyze detectability metrics between an original sentence and paraphrased sentences.
17
+ """
18
+
19
+ def __init__(self, original_sentence, paraphrased_sentences):
20
+ """
21
+ Initialize the calculator with the original sentence and a list of paraphrased sentences.
22
+ """
23
+ self.original_sentence = original_sentence
24
+ self.paraphrased_sentences = paraphrased_sentences
25
+
26
+ # Raw metric dictionaries
27
+ self.z_scores = {}
28
+ self.p_values = {}
29
+ self.metric_values = []
30
+
31
+ # Normalized metric dictionaries
32
+ self.normalized_z_scores = {}
33
+ self.normalized_p_values = {}
34
+
35
+ # Combined detectability dictionary
36
+ self.combined_detectabilities = {}
37
+
38
+ # Load pre-trained BERT for embeddings
39
+ self.bert_model = BertModel.from_pretrained('bert-base-uncased')
40
+ self.bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
41
+
42
+ def calculate_all_metrics(self):
43
+ """
44
+ Calculate detectability metrics for each paraphrased sentence.
45
+ """
46
+ original_embedding = self._get_sentence_embedding(self.original_sentence)
47
+
48
+ # First, compute the metric values (cosine similarities)
49
+ for idx, paraphrased_sentence in enumerate(self.paraphrased_sentences):
50
+ paraphrase_embedding = self._get_sentence_embedding(paraphrased_sentence)
51
+ cosine_sim = cosine_similarity([original_embedding], [paraphrase_embedding])[0][0]
52
+ self.metric_values.append(cosine_sim)
53
+
54
+ # Compute mean and standard deviation of the metric values
55
+ metric_mean = np.mean(self.metric_values)
56
+ metric_std = np.std(self.metric_values)
57
+
58
+ # Compute z-scores and p-values
59
+ for idx, (paraphrased_sentence, metric_value) in enumerate(zip(self.paraphrased_sentences, self.metric_values)):
60
+ key = f"Sentence_{idx+1}"
61
+ z_score = (metric_value - metric_mean) / metric_std if metric_std != 0 else 0.0
62
+ p_value = stats.norm.sf(abs(z_score)) * 2 # two-tailed p-value
63
+ self.z_scores[key] = z_score
64
+ self.p_values[key] = p_value
65
+
66
+ def normalize_metrics(self):
67
+ """
68
+ Normalize z-scores and p-values to be between 0 and 1.
69
+ """
70
+ self.normalized_z_scores = self._normalize_dict(self.z_scores)
71
+ self.normalized_p_values = self._normalize_dict(self.p_values)
72
+
73
+ def calculate_combined_detectability(self):
74
+ """
75
+ Calculate the combined detectability using the root mean square of the normalized metrics.
76
+ """
77
+ for key in self.normalized_z_scores.keys():
78
+ rms = np.sqrt(
79
+ (
80
+ self.normalized_z_scores[key] ** 2 +
81
+ self.normalized_p_values[key] ** 2
82
+ ) / 2
83
+ )
84
+ self.combined_detectabilities[key] = rms
85
+
86
+ def plot_metrics(self):
87
+ """
88
+ Plot each normalized metric and the combined detectability in separate graphs.
89
+ """
90
+ keys = list(self.normalized_z_scores.keys())
91
+ indices = np.arange(len(keys))
92
+
93
+ # Prepare data for plotting
94
+ metrics = {
95
+ 'Z-Score': [self.normalized_z_scores[key] for key in keys],
96
+ 'P-Value': [self.normalized_p_values[key] for key in keys],
97
+ 'Combined Detectability': [self.combined_detectabilities[key] for key in keys]
98
+ }
99
+
100
+ # Plot each metric separately
101
+ for metric_name, values in metrics.items():
102
+ plt.figure(figsize=(12, 6))
103
+ plt.plot(indices, values, marker='o', color=np.random.rand(3,))
104
+ plt.xlabel('Sentence Index')
105
+ plt.ylabel('Normalized Value (0-1)')
106
+ plt.title(f'Normalized {metric_name}')
107
+ plt.grid(True)
108
+ plt.tight_layout()
109
+ plt.show()
110
+
111
+ # Private methods
112
+ def _get_sentence_embedding(self, sentence):
113
+ """
114
+ Get sentence embedding using BERT.
115
+ """
116
+ tokens = self.bert_tokenizer(sentence, return_tensors='pt', padding=True, truncation=True, max_length=512)
117
+ with torch.no_grad():
118
+ outputs = self.bert_model(**tokens)
119
+ return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
120
+
121
+ def _normalize_dict(self, metric_dict):
122
+ """
123
+ Normalize the values in a dictionary to be between 0 and 1.
124
+ """
125
+ values = np.array(list(metric_dict.values()))
126
+ min_val = values.min()
127
+ max_val = values.max()
128
+ # Avoid division by zero if all values are the same
129
+ if max_val - min_val == 0:
130
+ normalized_values = np.zeros_like(values)
131
+ else:
132
+ normalized_values = (values - min_val) / (max_val - min_val)
133
+ return dict(zip(metric_dict.keys(), normalized_values))
134
+
135
+ # Getter methods
136
+ def get_normalized_metrics(self):
137
+ """
138
+ Get all normalized metrics as a dictionary.
139
+ """
140
+ return {
141
+ 'Z-Score': self.normalized_z_scores,
142
+ 'P-Value': self.normalized_p_values
143
+ }
144
+
145
+ def get_combined_detectabilities(self):
146
+ """
147
+ Get the dictionary of combined detectability values.
148
+ """
149
+ return self.combined_detectabilities
150
+
151
+
152
+ # Example usage
153
+ if __name__ == "__main__":
154
+ # Original sentence
155
+ original_sentence = "The quick brown fox jumps over the lazy dog"
156
+
157
+ # Paraphrased sentences
158
+ paraphrased_sentences = [
159
+ # Original 1: "A swift auburn fox leaps across a sleepy canine."
160
+ "The swift auburn fox leaps across a sleepy canine.",
161
+ "A quick auburn fox leaps across a sleepy canine.",
162
+ "A swift ginger fox leaps across a sleepy canine.",
163
+ "A swift auburn fox bounds across a sleepy canine.",
164
+ "A swift auburn fox leaps across a tired canine.",
165
+ "Three swift auburn foxes leap across a sleepy canine.",
166
+ "The vulpine specimen rapidly traverses over a dormant dog.",
167
+ "Like lightning, the russet hunter soars over the drowsy guardian.",
168
+ "Tha quick ginger fox jumps o'er the lazy hound, ye ken.",
169
+ "One rapid Vulpes vulpes traverses the path of a quiescent canine.",
170
+ "A swift auburn predator navigates across a lethargic pet.",
171
+ "Subject A (fox) demonstrates velocity over Subject B (dog).",
172
+
173
+ # Original 2: "The agile russet fox bounds over an idle hound."
174
+ "Some agile russet foxes bound over an idle hound.",
175
+ "The nimble russet fox bounds over an idle hound.",
176
+ "The agile brown fox bounds over an idle hound.",
177
+ "The agile russet fox jumps over an idle hound.",
178
+ "The agile russet fox bounds over a lazy hound.",
179
+ "Two agile russet foxes bound over an idle hound.",
180
+ "A dexterous vulpine surpasses a stationary canine.",
181
+ "Quick as thought, the copper warrior sails over the guardian.",
182
+ "Tha nimble reddish fox jumps o'er the doggo, don't ya know.",
183
+ "A dexterous V. vulpes exceeds the plane of an inactive canine.",
184
+ "An agile russet hunter maneuvers above a resting hound.",
185
+ "Test subject F-1 achieves displacement superior to subject D-1.",
186
+
187
+ # Original 3: "A nimble mahogany vulpine vaults above a drowsy dog."
188
+ "The nimble mahogany vulpine vaults above a drowsy dog.",
189
+ "A swift mahogany vulpine vaults above a drowsy dog.",
190
+ "A nimble reddish vulpine vaults above a drowsy dog.",
191
+ "A nimble mahogany fox vaults above a drowsy dog.",
192
+ "A nimble mahogany vulpine leaps above a drowsy dog.",
193
+ "Four nimble mahogany vulpines vault above a drowsy dog.",
194
+ "An agile specimen of reddish fur surpasses a somnolent canine.",
195
+ "Fleet as wind, the earth-toned hunter soars over the sleepy guard.",
196
+ "Tha quick brown beastie jumps o'er the tired pup, aye.",
197
+ "Single V. vulpes demonstrates vertical traverse over C. familiaris.",
198
+ "A nimble rust-colored predator crosses above a drowsy pet.",
199
+ "Observed: Subject Red executes vertical motion over Subject Gray.",
200
+
201
+ # Original 4: "The speedy copper-colored fox hops over the lethargic pup."
202
+ "A speedy copper-colored fox hops over the lethargic pup.",
203
+ "The quick copper-colored fox hops over the lethargic pup.",
204
+ "The speedy bronze fox hops over the lethargic pup.",
205
+ "The speedy copper-colored fox jumps over the lethargic pup.",
206
+ "The speedy copper-colored fox hops over the tired pup.",
207
+ "Multiple speedy copper-colored foxes hop over the lethargic pup.",
208
+ "A rapid vulpine of bronze hue traverses an inactive young canine.",
209
+ "Swift as a dart, the metallic hunter bounds over the lazy puppy.",
210
+ "Tha fast copper beastie leaps o'er the sleepy wee dog.",
211
+ "1 rapid V. vulpes crosses above 1 juvenile C. familiaris.",
212
+ "A fleet copper-toned predator moves past a sluggish young dog.",
213
+ "Field note: Adult fox subject exceeds puppy subject vertically.",
214
+
215
+ # Original 5: "A rapid tawny fox springs over a sluggish dog."
216
+ "The rapid tawny fox springs over a sluggish dog.",
217
+ "A quick tawny fox springs over a sluggish dog.",
218
+ "A rapid golden fox springs over a sluggish dog.",
219
+ "A rapid tawny fox jumps over a sluggish dog.",
220
+ "A rapid tawny fox springs over a lazy dog.",
221
+ "Six rapid tawny foxes spring over a sluggish dog.",
222
+ "An expeditious yellowish vulpine surpasses a torpid canine.",
223
+ "Fast as a bullet, the golden hunter vaults over the idle guard.",
224
+ "Tha swift yellowy fox jumps o'er the lazy mutt, aye.",
225
+ "One V. vulpes displays rapid transit over one inactive C. familiaris.",
226
+ "A speedy yellow-brown predator bypasses a motionless dog.",
227
+ "Log entry: Vulpine subject achieves swift vertical displacement.",
228
+
229
+ # Original 6: "The fleet-footed chestnut fox soars above an indolent canine."
230
+ "A fleet-footed chestnut fox soars above an indolent canine.",
231
+ "The swift chestnut fox soars above an indolent canine.",
232
+ "The fleet-footed brown fox soars above an indolent canine.",
233
+ "The fleet-footed chestnut fox leaps above an indolent canine.",
234
+ "The fleet-footed chestnut fox soars above a lazy canine.",
235
+ "Several fleet-footed chestnut foxes soar above an indolent canine.",
236
+ "A rapid brown vulpine specimen traverses a lethargic domestic dog.",
237
+ "Graceful as a bird, the nutbrown hunter flies over the lazy guard.",
238
+ "Tha quick brown beastie sails o'er the sleepy hound, ken.",
239
+ "Single agile V. vulpes achieves elevation above stationary canine.",
240
+ "A nimble brown predator glides over an unmoving domestic animal.",
241
+ "Research note: Brown subject displays superior vertical mobility.",
242
+
243
+ # Original 7: "A fast ginger fox hurdles past a slothful dog."
244
+ "The fast ginger fox hurdles past a slothful dog.",
245
+ "A quick ginger fox hurdles past a slothful dog.",
246
+ "A fast red fox hurdles past a slothful dog.",
247
+ "A fast ginger fox jumps past a slothful dog.",
248
+ "A fast ginger fox hurdles past a lazy dog.",
249
+ "Five fast ginger foxes hurdle past a slothful dog.",
250
+ "A rapid orange vulpine bypasses a lethargic canine.",
251
+ "Quick as lightning, the flame-colored hunter races past the lazy guard.",
252
+ "Tha swift ginger beastie leaps past the tired doggy, ye see.",
253
+ "1 rapid orange V. vulpes surpasses 1 inactive C. familiaris.",
254
+ "A speedy red-orange predator overtakes a motionless dog.",
255
+ "Data point: Orange subject demonstrates rapid transit past Gray subject.",
256
+
257
+ # Original 8: "The spry rusty-colored fox jumps across a dozing hound."
258
+ "A spry rusty-colored fox jumps across a dozing hound.",
259
+ "The agile rusty-colored fox jumps across a dozing hound.",
260
+ "The spry reddish fox jumps across a dozing hound.",
261
+ "The spry rusty-colored fox leaps across a dozing hound.",
262
+ "The spry rusty-colored fox jumps across a sleeping hound.",
263
+ "Multiple spry rusty-colored foxes jump across a dozing hound.",
264
+ "An agile rust-toned vulpine traverses a somnolent canine.",
265
+ "Nimble as thought, the copper hunter bounds over the resting guard.",
266
+ "Tha lively rust-colored beastie hops o'er the snoozin' hound.",
267
+ "Single dexterous V. vulpes crosses path of dormant C. familiaris.",
268
+ "A lithe rust-tinted predator moves past a slumbering dog.",
269
+ "Observation: Russet subject exhibits agility over dormant subject.",
270
+
271
+ # Original 9: "A quick tan fox leaps over an inactive dog."
272
+ "The quick tan fox leaps over an inactive dog.",
273
+ "A swift tan fox leaps over an inactive dog.",
274
+ "A quick beige fox leaps over an inactive dog.",
275
+ "A quick tan fox jumps over an inactive dog.",
276
+ "A quick tan fox leaps over a motionless dog.",
277
+ "Seven quick tan foxes leap over an inactive dog.",
278
+ "A rapid light-brown vulpine surpasses a stationary canine.",
279
+ "Fast as wind, the sand-colored hunter soars over the still guard.",
280
+ "Tha nimble tan beastie jumps o'er the quiet doggy, aye.",
281
+ "One agile fawn V. vulpes traverses one immobile C. familiaris.",
282
+ "A fleet tan-colored predator bypasses an unmoving dog.",
283
+ "Field report: Tan subject demonstrates movement over static subject.",
284
+
285
+ # Original 10: "The brisk auburn vulpine bounces over a listless canine."
286
+ "Some brisk auburn vulpines bounce over a listless canine.",
287
+ "The quick auburn vulpine bounces over a listless canine.",
288
+ "The brisk russet vulpine bounces over a listless canine.",
289
+ "The brisk auburn fox bounces over a listless canine.",
290
+ "The brisk auburn vulpine jumps over a listless canine.",
291
+ "Five brisk auburn vulpines bounce over a listless canine.",
292
+ "The expeditious specimen supersedes a quiescent Canis lupus.",
293
+ "Swift as wind, the russet hunter vaults over the idle guardian.",
294
+ "Tha quick ginger beastie hops o'er the lazy mutt, aye.",
295
+ "One V. vulpes achieves displacement over inactive C. familiaris.",
296
+ "A high-velocity auburn predator traverses an immobile animal.",
297
+ "Final observation: Red subject shows mobility over Gray subject."
298
+ ]
299
+
300
+ # Initialize the calculator
301
+ calculator = SentenceDetectabilityCalculator(original_sentence, paraphrased_sentences)
302
+
303
+ # Calculate all metrics
304
+ calculator.calculate_all_metrics()
305
+
306
+ # Normalize the metrics
307
+ calculator.normalize_metrics()
308
+
309
+ # Calculate combined detectability
310
+ calculator.calculate_combined_detectability()
311
+
312
+ # Retrieve the normalized metrics and combined detectabilities
313
+ normalized_metrics = calculator.get_normalized_metrics()
314
+ combined_detectabilities = calculator.get_combined_detectabilities()
315
+ detectability_val = combined_detectabilities
316
+
317
+ # Display the results
318
+ print("\nCombined Detectabilities:")
319
+ for each in combined_detectabilities.items():
320
+ print(f"{each[1]}")
321
+
322
+ # Plot the metrics (optional)
323
+ #calculator.plot_metrics()
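
To make the arithmetic in `SentenceDetectabilityCalculator` concrete: each paraphrase's BERT-embedding cosine similarity to the original is turned into a z-score against the batch mean and standard deviation, a two-tailed p-value is taken from the normal survival function, both metrics are min-max normalized, and the combined detectability is their root mean square. A toy sketch of those steps on made-up similarity values (the embedding step is omitted):

```python
import numpy as np
from scipy import stats

# Hypothetical cosine similarities between an original sentence and five paraphrases.
similarities = np.array([0.91, 0.84, 0.78, 0.95, 0.88])

# Per-sentence z-score against the batch statistics, then a two-tailed p-value.
z_scores = (similarities - similarities.mean()) / similarities.std()
p_values = stats.norm.sf(np.abs(z_scores)) * 2

def minmax(values):
    """Min-max normalize to [0, 1], as _normalize_dict does (zeros if all values are equal)."""
    span = values.max() - values.min()
    return np.zeros_like(values) if span == 0 else (values - values.min()) / span

z_norm, p_norm = minmax(z_scores), minmax(p_values)

# Combined detectability: root mean square of the two normalized metrics.
combined = np.sqrt((z_norm ** 2 + p_norm ** 2) / 2)
print(dict(zip([f"Sentence_{i + 1}" for i in range(len(combined))], combined.round(3))))
```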
metrics/distortion.py ADDED
@@ -0,0 +1,370 @@
1
+ import os
2
+ import sys
3
+ from tqdm import tqdm
4
+ import numpy as np
5
+ import torch
6
+ import matplotlib.pyplot as plt
7
+ from transformers import GPT2LMHeadModel, GPT2TokenizerFast
8
+ from bert_score import BERTScorer
9
+ from bert_score.utils import model2layers
10
+ from nltk.tokenize import word_tokenize
11
+ from Levenshtein import distance as levenshtein_distance
12
+ from sentence_transformers import SentenceTransformer
13
+ from sklearn.feature_extraction.text import TfidfVectorizer
14
+ from scipy.spatial.distance import cdist
15
+ from scipy.optimize import linear_sum_assignment
16
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
17
+
18
+ from config.config import load_config
19
+ config_path = os.path.join(os.path.dirname(__file__), '..', 'config', 'config.yaml')
20
+ config = load_config(config_path)['PECCAVI_TEXT']['Metrics']
21
+
22
+ class SentenceDistortionCalculator:
23
+ """
24
+ A class to calculate and analyze distortion metrics between an original sentence and modified sentences.
25
+ """
26
+ def __init__(self, config, original_sentence, paraphrased_sentences):
27
+ """
28
+ Initialize the calculator with the original sentence and a list of modified sentences.
29
+ """
30
+ self.original_sentence = original_sentence
31
+ self.paraphrased_sentences = paraphrased_sentences
32
+
33
+ self.levenshtein_distances = {}
34
+ self.bert_scores = {}
35
+ self.mover_scores = {}
36
+
37
+ self.normalized_levenshtein = {}
38
+ self.normalized_bert_scores = {}
39
+ self.normalized_mover_scores = {}
40
+ self.combined_distortions = {}
41
+
42
+ self.tokenizer = GPT2TokenizerFast.from_pretrained(config['Distortion'])
43
+ self.model = GPT2LMHeadModel.from_pretrained(config['Distortion'])
44
+ self.model.eval()
45
+
46
+ def calculate_all_metrics(self):
47
+ """
48
+ Calculate all distortion metrics for each modified sentence.
49
+ """
50
+ for idx, modified_sentence in tqdm(enumerate(self.paraphrased_sentences), total=len(self.paraphrased_sentences), desc="Calculating Metrics"):
51
+ key = f"Sentence_{idx+1}"
52
+ self.levenshtein_distances[key] = self._calculate_levenshtein_distance(modified_sentence)
53
+ self.bert_scores[key] = self._calculate_bert_score(modified_sentence)
54
+ self.mover_scores[key] = self._calculate_mover_score(modified_sentence)
55
+
56
+
57
+ def normalize_metrics(self):
58
+ """
59
+ Normalize all metrics to be between 0 and 1.
60
+ """
61
+ for _ in tqdm(range(1), desc="Normalizing Metrics"): # Add tqdm here (wrap the normalization process)
62
+ self.normalized_levenshtein = self._normalize_dict(self.levenshtein_distances)
63
+ self.normalized_bert_scores = self._normalize_dict(self.bert_scores)
64
+ self.normalized_mover_scores = self._normalize_dict(self.mover_scores)
65
+
66
+ def calculate_combined_distortion(self):
67
+ """
68
+ Calculate the combined distortion using the root mean square of the normalized metrics.
69
+ """
70
+ for _ in tqdm(range(1), desc="Calculating Combined Distortion"): # Add tqdm here
71
+ for key in self.normalized_levenshtein.keys():
72
+ rms = np.sqrt(
73
+ (
74
+ self.normalized_levenshtein[key] ** 2 +
75
+ self.normalized_bert_scores[key] ** 2+
76
+ self.normalized_mover_scores[key] **2
77
+ ) / 3
78
+ )
79
+ self.combined_distortions[key] = rms
80
+
81
+ def plot_metrics(self):
82
+ """
83
+ Plot each normalized metric and the combined distortion in separate graphs.
84
+ """
85
+ keys = list(self.normalized_levenshtein.keys())
86
+ indices = np.arange(len(keys))
87
+
88
+ # Prepare data for plotting
89
+ metrics = {
90
+ 'Levenshtein Distance': [self.normalized_levenshtein[key] for key in keys],
91
+ 'BERTScore': [self.normalized_bert_scores[key] for key in keys],
92
+ 'MOVERscore':[self.normalized_mover_scores[key] for key in keys],
93
+ 'Combined Distortion': [self.combined_distortions[key] for key in keys]
94
+ }
95
+
96
+ # Plot each metric separately
97
+ for metric_name, values in tqdm(metrics.items(), desc="Plotting Metrics"): # Add tqdm here
98
+ plt.figure(figsize=(12, 6))
99
+ plt.plot(indices, values, marker='o', color=np.random.rand(3,))
100
+ plt.xlabel('Sentence Index')
101
+ plt.ylabel('Normalized Value (0-1)')
102
+ plt.title(f'Normalized {metric_name}')
103
+ plt.grid(True)
104
+ plt.tight_layout()
105
+ plt.show()
106
+
107
+ def _calculate_levenshtein_distance(self, modified_sentence):
108
+ """
109
+ Calculate the word-level Levenshtein distance between the original and modified sentence.
110
+ """
111
+ words1 = word_tokenize(self.original_sentence)
112
+ words2 = word_tokenize(modified_sentence)
113
+ lev_distance = levenshtein_distance(words1, words2)
114
+ return (lev_distance / max(len(words1), len(words2)))
115
+
116
+ def _calculate_bert_score(self, modified_sentence):
117
+ """
118
+ Compute the BERTScore similarity between the original and modified sentence.
119
+ Returns 1 - F1 score to represent dissimilarity.
120
+ """
121
+ if not hasattr(self, 'original_sentence'):
122
+ raise ValueError("original_sentence is not set. Please set self.original_sentence before calling this function.")
123
+ if not isinstance(modified_sentence, str):
124
+ raise ValueError("modified_sentence must be a string.")
125
+
126
+ model_type = "microsoft/deberta-xlarge-mnli"
127
+ num_layers = model2layers[model_type]
128
+
129
+ if not hasattr(self, "cached_bertscorer"):
130
+ self.cached_bertscorer = BERTScorer(
131
+ model_type=model_type,
132
+ num_layers=num_layers,
133
+ batch_size=1, # Single sentence comparison
134
+ nthreads=4,
135
+ all_layers=False,
136
+ idf=False,
137
+ device="cuda" if torch.cuda.is_available() else "cpu",
138
+ lang="en"
139
+ )
140
+
141
+ # Compute BERTScore
142
+ _, _, F1 = self.cached_bertscorer.score(
143
+ cands=[modified_sentence],
144
+ refs=[self.original_sentence],
145
+ verbose=False,
146
+ batch_size=1
147
+ )
148
+
149
+ return 1 - F1.item() # Return dissimilarity score
150
+ def _calculate_mover_score(self,modified_sentence,model_name='all-MiniLM-L6-v2'):
151
+ """Compute MoverScore correctly using word-level embeddings."""
152
+ if not self.original_sentence:
153
+ raise ValueError("Original sentence not provided.")
154
+
155
+ # Tokenize sentences
156
+ original_tokens = self.original_sentence.split()
157
+ modified_tokens = modified_sentence.split()
158
+ model = SentenceTransformer(model_name)
159
+
160
+ # Compute word embeddings
161
+ original_embeddings = model.encode(original_tokens, convert_to_numpy=True)
162
+ modified_embeddings = model.encode(modified_tokens, convert_to_numpy=True)
163
+
164
+ # Compute cost matrix (cosine distance)
165
+ cost_matrix = cdist(original_embeddings, modified_embeddings, metric='cosine')
166
+
167
+ # Solve optimal transport problem (Hungarian Algorithm)
168
+ row_ind, col_ind = linear_sum_assignment(cost_matrix)
169
+
170
+ # Compute IDF weights
171
+ vectorizer = TfidfVectorizer()
172
+ vectorizer.fit([self.original_sentence, modified_sentence])
173
+ idf_values = dict(zip(vectorizer.get_feature_names_out(), vectorizer.idf_))
174
+
175
+ # Apply IDF weighting to aligned word pairs
176
+ idf_weights_original = np.array([idf_values.get(word.lower(), 1.0) for word in original_tokens])
177
+ idf_weights_modified = np.array([idf_values.get(word.lower(), 1.0) for word in modified_tokens])
178
+ combined_idf_weights = (idf_weights_original[row_ind] + idf_weights_modified[col_ind]) / 2
179
+ weighted_score = np.sum((1 - cost_matrix[row_ind, col_ind]) * combined_idf_weights) / np.sum(combined_idf_weights)
180
+
181
+ return 1-weighted_score # Higher score = more dissimilar
182
+
183
+ def _normalize_dict(self, metric_dict):
184
+ """
185
+ Normalize the values in a dictionary to be between 0 and 1.
186
+ """
187
+ values = np.array(list(metric_dict.values()))
188
+ min_val = values.min()
189
+ max_val = values.max()
190
+ if max_val - min_val == 0:
191
+ normalized_values = np.zeros_like(values)
192
+ else:
193
+ normalized_values = (values - min_val) / (max_val - min_val)
194
+ return dict(zip(metric_dict.keys(), normalized_values))
195
+
196
+ def get_normalized_metrics(self):
197
+ """
198
+ Get all normalized metrics as a dictionary.
199
+ """
200
+ return {
201
+ 'Min Edit Distance': self.normalized_levenshtein,
202
+ 'BERTScore': self.normalized_bert_scores,
203
+ 'Mover Score': self.normalized_mover_scores
204
+ }
205
+
206
+ def get_combined_distortions(self):
207
+ """
208
+ Get the dictionary of combined distortion values.
209
+ """
210
+ return self.combined_distortions
211
+
212
+ # Example usage
213
+ if __name__ == "__main__":
214
+
215
+ config = load_config(config_path)['PECCAVI_TEXT']['Metrics']
216
+
217
+ # Original sentence
218
+ original_sentence = "The quick brown fox jumps over the lazy dog"
219
+
220
+ # Paraphrased sentences
221
+ paraphrased_sentences = [
222
+ # Original 1: "A swift auburn fox leaps across a sleepy canine."
223
+ "The swift auburn fox leaps across a sleepy canine.",
224
+ "A quick auburn fox leaps across a sleepy canine.",
225
+ "A swift ginger fox leaps across a sleepy canine.",
226
+ "A swift auburn fox bounds across a sleepy canine.",
227
+ "A swift auburn fox leaps across a tired canine.",
228
+ "Three swift auburn foxes leap across a sleepy canine.",
229
+ "The vulpine specimen rapidly traverses over a dormant dog.",
230
+ "Like lightning, the russet hunter soars over the drowsy guardian.",
231
+ "Tha quick ginger fox jumps o'er the lazy hound, ye ken.",
232
+ "One rapid Vulpes vulpes traverses the path of a quiescent canine.",
233
+ "A swift auburn predator navigates across a lethargic pet.",
234
+ "Subject A (fox) demonstrates velocity over Subject B (dog).",
235
+
236
+ # Original 2: "The agile russet fox bounds over an idle hound."
237
+ "Some agile russet foxes bound over an idle hound.",
238
+ "The nimble russet fox bounds over an idle hound.",
239
+ "The agile brown fox bounds over an idle hound.",
240
+ "The agile russet fox jumps over an idle hound.",
241
+ "The agile russet fox bounds over a lazy hound.",
242
+ "Two agile russet foxes bound over an idle hound.",
243
+ "A dexterous vulpine surpasses a stationary canine.",
244
+ "Quick as thought, the copper warrior sails over the guardian.",
245
+ "Tha nimble reddish fox jumps o'er the doggo, don't ya know.",
246
+ "A dexterous V. vulpes exceeds the plane of an inactive canine.",
247
+ "An agile russet hunter maneuvers above a resting hound.",
248
+ "Test subject F-1 achieves displacement superior to subject D-1.",
249
+
250
+ # Original 3: "A nimble mahogany vulpine vaults above a drowsy dog."
251
+ "The nimble mahogany vulpine vaults above a drowsy dog.",
252
+ "A swift mahogany vulpine vaults above a drowsy dog.",
253
+ "A nimble reddish vulpine vaults above a drowsy dog.",
254
+ "A nimble mahogany fox vaults above a drowsy dog.",
255
+ "A nimble mahogany vulpine leaps above a drowsy dog.",
256
+ "Four nimble mahogany vulpines vault above a drowsy dog.",
257
+ "An agile specimen of reddish fur surpasses a somnolent canine.",
258
+ "Fleet as wind, the earth-toned hunter soars over the sleepy guard.",
259
+ "Tha quick brown beastie jumps o'er the tired pup, aye.",
260
+ "Single V. vulpes demonstrates vertical traverse over C. familiaris.",
261
+ "A nimble rust-colored predator crosses above a drowsy pet.",
262
+ "Observed: Subject Red executes vertical motion over Subject Gray.",
263
+
264
+ # Original 4: "The speedy copper-colored fox hops over the lethargic pup."
265
+ "A speedy copper-colored fox hops over the lethargic pup.",
266
+ "The quick copper-colored fox hops over the lethargic pup.",
267
+ "The speedy bronze fox hops over the lethargic pup.",
268
+ "The speedy copper-colored fox jumps over the lethargic pup.",
269
+ "The speedy copper-colored fox hops over the tired pup.",
270
+ "Multiple speedy copper-colored foxes hop over the lethargic pup.",
271
+ "A rapid vulpine of bronze hue traverses an inactive young canine.",
272
+ "Swift as a dart, the metallic hunter bounds over the lazy puppy.",
273
+ "Tha fast copper beastie leaps o'er the sleepy wee dog.",
274
+ "1 rapid V. vulpes crosses above 1 juvenile C. familiaris.",
275
+ "A fleet copper-toned predator moves past a sluggish young dog.",
276
+ "Field note: Adult fox subject exceeds puppy subject vertically.",
277
+
278
+ # Original 5: "A rapid tawny fox springs over a sluggish dog."
279
+ "The rapid tawny fox springs over a sluggish dog.",
280
+ "A quick tawny fox springs over a sluggish dog.",
281
+ "A rapid golden fox springs over a sluggish dog.",
282
+ "A rapid tawny fox jumps over a sluggish dog.",
283
+ "A rapid tawny fox springs over a lazy dog.",
284
+ "Six rapid tawny foxes spring over a sluggish dog.",
285
+ "An expeditious yellowish vulpine surpasses a torpid canine.",
286
+ "Fast as a bullet, the golden hunter vaults over the idle guard.",
287
+ "Tha swift yellowy fox jumps o'er the lazy mutt, aye.",
288
+ "One V. vulpes displays rapid transit over one inactive C. familiaris.",
289
+ "A speedy yellow-brown predator bypasses a motionless dog.",
290
+ "Log entry: Vulpine subject achieves swift vertical displacement.",
291
+
292
+ # Original 6: "The fleet-footed chestnut fox soars above an indolent canine."
293
+ "A fleet-footed chestnut fox soars above an indolent canine.",
294
+ "The swift chestnut fox soars above an indolent canine.",
295
+ "The fleet-footed brown fox soars above an indolent canine.",
296
+ "The fleet-footed chestnut fox leaps above an indolent canine.",
297
+ "The fleet-footed chestnut fox soars above a lazy canine.",
298
+ "Several fleet-footed chestnut foxes soar above an indolent canine.",
299
+ "A rapid brown vulpine specimen traverses a lethargic domestic dog.",
300
+ "Graceful as a bird, the nutbrown hunter flies over the lazy guard.",
301
+ "Tha quick brown beastie sails o'er the sleepy hound, ken.",
302
+ "Single agile V. vulpes achieves elevation above stationary canine.",
303
+ "A nimble brown predator glides over an unmoving domestic animal.",
304
+ "Research note: Brown subject displays superior vertical mobility.",
305
+
306
+ # Original 7: "A fast ginger fox hurdles past a slothful dog."
307
+ "The fast ginger fox hurdles past a slothful dog.",
308
+ "A quick ginger fox hurdles past a slothful dog.",
309
+ "A fast red fox hurdles past a slothful dog.",
310
+ "A fast ginger fox jumps past a slothful dog.",
311
+ "A fast ginger fox hurdles past a lazy dog.",
312
+ "Five fast ginger foxes hurdle past a slothful dog.",
313
+ "A rapid orange vulpine bypasses a lethargic canine.",
314
+ "Quick as lightning, the flame-colored hunter races past the lazy guard.",
315
+ "Tha swift ginger beastie leaps past the tired doggy, ye see.",
316
+ "1 rapid orange V. vulpes surpasses 1 inactive C. familiaris.",
317
+ "A speedy red-orange predator overtakes a motionless dog.",
318
+ "Data point: Orange subject demonstrates rapid transit past Gray subject.",
319
+
320
+ # Original 8: "The spry rusty-colored fox jumps across a dozing hound."
321
+ "A spry rusty-colored fox jumps across a dozing hound.",
322
+ "The agile rusty-colored fox jumps across a dozing hound.",
323
+ "The spry reddish fox jumps across a dozing hound.",
324
+ "The spry rusty-colored fox leaps across a dozing hound.",
325
+ "The spry rusty-colored fox jumps across a sleeping hound.",
326
+ "Multiple spry rusty-colored foxes jump across a dozing hound.",
327
+ "An agile rust-toned vulpine traverses a somnolent canine.",
328
+ "Nimble as thought, the copper hunter bounds over the resting guard.",
329
+ "Tha lively rust-colored beastie hops o'er the snoozin' hound.",
330
+ "Single dexterous V. vulpes crosses path of dormant C. familiaris.",
331
+ "A lithe rust-tinted predator moves past a slumbering dog.",
332
+ "Observation: Russet subject exhibits agility over dormant subject.",
333
+
334
+ # Original 9: "A quick tan fox leaps over an inactive dog."
335
+ "The quick tan fox leaps over an inactive dog.",
336
+ "A swift tan fox leaps over an inactive dog.",
337
+ "A quick beige fox leaps over an inactive dog.",
338
+ "A quick tan fox jumps over an inactive dog.",
339
+ "A quick tan fox leaps over a motionless dog.",
340
+ "Seven quick tan foxes leap over an inactive dog.",
341
+ "A rapid light-brown vulpine surpasses a stationary canine.",
342
+ "Fast as wind, the sand-colored hunter soars over the still guard.",
343
+ "Tha nimble tan beastie jumps o'er the quiet doggy, aye.",
344
+ "One agile fawn V. vulpes traverses one immobile C. familiaris.",
345
+ "A fleet tan-colored predator bypasses an unmoving dog.",
346
+ "Field report: Tan subject demonstrates movement over static subject.",
347
+
348
+ # Original 10: "The brisk auburn vulpine bounces over a listless canine."
349
+ "Some brisk auburn vulpines bounce over a listless canine.",
350
+ "The quick auburn vulpine bounces over a listless canine.",
351
+ "The brisk russet vulpine bounces over a listless canine.",
352
+ "The brisk auburn fox bounces over a listless canine.",
353
+ "The brisk auburn vulpine jumps over a listless canine.",
354
+ "Five brisk auburn vulpines bounce over a listless canine.",
355
+ "The expeditious specimen supersedes a quiescent Canis lupus.",
356
+ "Swift as wind, the russet hunter vaults over the idle guardian.",
357
+ "Tha quick ginger beastie hops o'er the lazy mutt, aye.",
358
+ "One V. vulpes achieves displacement over inactive C. familiaris.",
359
+ "A high-velocity auburn predator traverses an immobile animal.",
360
+ "Final observation: Red subject shows mobility over Gray subject."
361
+ ]
362
+
363
+ distortion_calculator = SentenceDistortionCalculator(config, original_sentence, paraphrased_sentences)
364
+ for _ in tqdm(range(1)):
365
+ distortion_calculator.calculate_all_metrics()
366
+ distortion_calculator.normalize_metrics()
367
+ distortion_calculator.calculate_combined_distortion()
368
+ distortion_calculator.plot_metrics()
369
+ print("Normalized Metrics:", distortion_calculator.get_normalized_metrics())
370
+ print("Combined Distortion:", distortion_calculator.get_combined_distortions())
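
Note on the metrics above: each per-sentence dictionary is min-max normalized (as in `_normalize_dict`) before being combined. Below is a minimal standalone sketch, assuming an equal-weight average for the combined distortion (the exact weighting is whatever `calculate_combined_distortion` implements in the class); the sentence keys and raw values are hypothetical:

    import numpy as np

    def normalize(metric_dict):
        # Min-max scale a {sentence: value} dict to [0, 1], mirroring _normalize_dict.
        values = np.array(list(metric_dict.values()), dtype=float)
        span = values.max() - values.min()
        scaled = np.zeros_like(values) if span == 0 else (values - values.min()) / span
        return dict(zip(metric_dict.keys(), scaled))

    # Hypothetical raw scores for two paraphrases.
    levenshtein = {"s1": 0.30, "s2": 0.70}
    bert_dissimilarity = {"s1": 0.10, "s2": 0.40}
    mover_dissimilarity = {"s1": 0.20, "s2": 0.60}

    normalized = [normalize(m) for m in (levenshtein, bert_dissimilarity, mover_dissimilarity)]
    combined = {k: float(np.mean([m[k] for m in normalized])) for k in levenshtein}
    print(combined)  # {'s1': 0.0, 's2': 1.0} after min-max scaling of two-element dicts
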
renderers/__pycache__/highlighter.cpython-310.pyc ADDED
Binary file (4.98 kB). View file
 
renderers/__pycache__/highlighter.cpython-311.pyc ADDED
Binary file (8.07 kB). View file
 
renderers/__pycache__/plot_3d.cpython-310.pyc ADDED
Binary file (4.34 kB). View file
 
renderers/__pycache__/plot_3d.cpython-311.pyc ADDED
Binary file (6 kB). View file
 
renderers/__pycache__/tree.cpython-310.pyc ADDED
Binary file (10.6 kB). View file
 
renderers/__pycache__/tree.cpython-311.pyc ADDED
Binary file (21.1 kB). View file
 
renderers/highlighter.py ADDED
@@ -0,0 +1,185 @@
1
+ import re
2
+
3
+ def create_matching_pattern(word):
4
+ """Creates appropriate regex pattern based on word characteristics"""
5
+ escaped_word = re.escape(word)
6
+
7
+ # Check for special cases
8
+ if any(char in word for char in '&-/.\'()[]$€£¥+') or ' ' in word:
9
+ # Special handling for phrases with special characters or spaces
10
+ return rf'{escaped_word}'
11
+ elif word.endswith('%'):
12
+ # Special handling for percentage values
13
+ numeric_part = word[:-1]
14
+ return rf'\b{re.escape(numeric_part)}\s*%'
15
+ elif re.search(r'[0-9]', word) and re.search(r'[a-zA-Z]', word):
16
+ # Special handling for alphanumeric combinations
17
+ return rf'{escaped_word}'
18
+ else:
19
+ # Standard word boundary pattern for simple words
20
+ return rf'\b{escaped_word}\b'
21
+
22
+ def highlight_common_words(common_words, sentences, title):
23
+ """
24
+ Highlight common words in sentences by adding color-coded background and unique IDs.
25
+
26
+ Args:
27
+ common_words (list of tuples): List of tuples where each tuple contains a word's index and the word.
28
+ sentences (list of str): List of sentences to search through.
29
+ title (str): The title for the HTML output.
30
+
31
+ Returns:
32
+ str: HTML string with the highlighted sentences.
33
+ """
34
+ color_map = {}
35
+ color_index = 0
36
+ highlighted_html = []
37
+ pastel_colors = ['#E199C6','#7FB3D5', '#E57373', '#B388EB', '#80D9AA', '#F0B66B',
38
+ "#73D9A0", "#9A89EB", "#E38AD4", "#7FAFDB", "#DDAF8C", "#DA7FC1",
39
+ "#65CFA5", "#B38FDE", "#E6C97A"]
40
+
41
+ # Process each sentence
42
+ for idx, sentence in enumerate(sentences, start=1):
43
+ sentence_with_idx = f"{idx}. {sentence}"
44
+ highlighted_sentence = sentence_with_idx
45
+
46
+ # Highlight common words in each sentence
47
+ for index, word in common_words:
48
+ if word not in color_map:
49
+ color_map[word] = pastel_colors[color_index % len(pastel_colors)]
50
+ color_index += 1
51
+
52
+ # Create appropriate pattern based on word characteristics
53
+ pattern = create_matching_pattern(word)
54
+
55
+ # Replace the word with highlighted version
56
+ highlighted_sentence = re.sub(
57
+ pattern,
58
+ lambda m, idx=index, color=color_map[word]: (
59
+ f'<span style="background-color: {color}; font-weight: bold;'
60
+ f' padding: 2px 4px; border-radius: 2px; position: relative;">'
61
+ f'<span style="background-color: black; color: white; border-radius: 50%;'
62
+ f' padding: 2px 5px; margin-right: 5px;">{idx}</span>'
63
+ f'{m.group(0)}'
64
+ f'</span>'
65
+ ),
66
+ highlighted_sentence,
67
+ flags=re.IGNORECASE
68
+ )
69
+
70
+ highlighted_html.append(highlighted_sentence)
71
+
72
+ # Format the HTML output with the title
73
+ final_html = "<br><br>".join(highlighted_html)
74
+ return f'''
75
+ <div style="border: solid 1px #FFFFFF; padding: 16px; background-color: #000000; color: #FFFFFF; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); border-radius: 8px;">
76
+ <h3 style="margin-top: 0; font-size: 1em; color: #FFFFFF;">{title}</h3>
77
+ <div style="background-color: #333333; line-height: 1.6; padding: 15px; border-radius: 8px; color: #FFFFFF;">{final_html}</div>
78
+ </div>
79
+ '''
80
+
81
+ def highlight_common_words_dict(common_words, sentences, title):
82
+ """
83
+ Highlight common words in sentences (from a dictionary) by adding color-coded background and unique IDs.
84
+
85
+ Args:
86
+ common_words (list of tuples): List of tuples where each tuple contains a word's index and the word.
87
+ sentences (dict): A dictionary of sentences where the key is the sentence and the value is an entailment score.
88
+ title (str): The title for the HTML output.
89
+
90
+ Returns:
91
+ str: HTML string with the highlighted sentences and their entailment scores.
92
+ """
93
+ color_map = {}
94
+ color_index = 0
95
+ highlighted_html = []
96
+ pastel_colors = ['#E199C6','#7FB3D5', '#E57373', '#B388EB', '#80D9AA', '#F0B66B',
97
+ "#73D9A0", "#9A89EB", "#E38AD4", "#7FAFDB", "#DDAF8C", "#DA7FC1",
98
+ "#65CFA5", "#B38FDE", "#E6C97A"]
99
+
100
+ # Process each sentence and its score
101
+ for idx, (sentence, score) in enumerate(sentences.items(), start=1):
102
+ sentence_with_idx = f"{idx}. {sentence}"
103
+ highlighted_sentence = sentence_with_idx
104
+
105
+ # Highlight common words in each sentence
106
+ for index, word in common_words:
107
+ if word not in color_map:
108
+ color_map[word] = pastel_colors[color_index % len(pastel_colors)]
109
+ color_index += 1
110
+
111
+ # Create appropriate pattern based on word characteristics
112
+ pattern = create_matching_pattern(word)
113
+
114
+ # Replace the word with highlighted version
115
+ highlighted_sentence = re.sub(
116
+ pattern,
117
+ lambda m, idx=index, color=color_map[word]: (
118
+ f'<span style="background-color: {color}; font-weight: bold;'
119
+ f' padding: 1px 2px; border-radius: 2px; position: relative;">'
120
+ f'<span style="background-color: black; color: white; border-radius: 50%;'
121
+ f' padding: 1px 3px; margin-right: 3px; font-size: 0.8em;">{idx}</span>'
122
+ f'{m.group(0)}'
123
+ f'</span>'
124
+ ),
125
+ highlighted_sentence,
126
+ flags=re.IGNORECASE
127
+ )
128
+
129
+ # Add the entailment score
130
+ highlighted_html.append(
131
+ f'<div style="margin-bottom: 5px;">'
132
+ f'{highlighted_sentence}'
133
+ f'<div style="display: inline-block; margin-left: 5px; padding: 3px 5px; border-radius: 3px; '
134
+ f'background-color: #333333; color: white; font-size: 0.9em;">'
135
+ f'Entailment Score: {score}</div></div>'
136
+ )
137
+
138
+ # Format the HTML output with the title
139
+ final_html = "<br>".join(highlighted_html)
140
+ return f'''
141
+ <div style="background-color: #000000; color: #FFFFFF;border: solid 1px #FFFFFF; border-radius: 8px;">
142
+ <h3 style="margin-top: 0; font-size: 1em; color: #FFFFFF;">{title}</h3>
143
+ <div style="background-color: #333333; line-height: 1.6; padding: 15px; border-radius: 8px; color: #FFFFFF;">{final_html}</div>
144
+ </div>
145
+ '''
146
+
147
+ def reparaphrased_sentences_html(sentences):
148
+ """
149
+ Create an HTML representation of sentences with numbering.
150
+
151
+ Args:
152
+ sentences (list of str): List of sentences to format.
153
+
154
+ Returns:
155
+ str: HTML string with numbered sentences.
156
+ """
157
+ formatted_sentences = []
158
+
159
+ # Process each sentence
160
+ for idx, sentence in enumerate(sentences, start=1):
161
+ sentence_with_idx = f"{idx}. {sentence}"
162
+ formatted_sentences.append(sentence_with_idx)
163
+
164
+ # Format the HTML output
165
+ final_html = "<br><br>".join(formatted_sentences)
166
+ return f'''
167
+ <div style="border: solid 1px #FFFFFF; background-color: #000000; color: #FFFFFF;
168
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); border-radius: 8px;">
169
+ <div style="background-color: #333333; line-height: 1.6; padding: 15px; border-radius: 8px;">{final_html}</div>
170
+ </div>
171
+ '''
172
+
173
+ if __name__ == "__main__":
174
+ # Example usage
175
+ common_words = [(1, "highlight"), (2, "numbering"), (3, "S&P 500")]
176
+ sentences = ["This is a test to highlight words.", "Numbering is important for clarity.", "The S&P 500 index rose 2% today."]
177
+
178
+ # Test highlight_common_words
179
+ highlighted_html = highlight_common_words(common_words, sentences, "Test Highlighting")
180
+ print(highlighted_html)
181
+
182
+ # Test highlight_common_words_dict
183
+ sentences_with_scores = {"Highlight words in this text.": 0.95, "Number sentences for clarity.": 0.8, "The S&P 500 index is a market benchmark.": 0.88}
184
+ highlighted_html_dict = highlight_common_words_dict(common_words, sentences_with_scores, "Test Dict Highlighting")
185
+ print(highlighted_html_dict)
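
A quick sketch of how the three branches of `create_matching_pattern` behave on representative inputs; the expected matches in the comments follow directly from the regexes above:

    import re

    # Simple word: word-boundary match, case-insensitive as in highlight_common_words.
    print(re.findall(r'\bfox\b', "The fox and the foxglove.", re.IGNORECASE))    # ['fox']

    # Percentage value: numeric part plus optional whitespace before '%'.
    print(re.findall(r'\b2\s*%', "rose 2% today, then 2 % more"))                # ['2%', '2 %']

    # Phrase with special characters: escaped literally, no word boundaries added.
    print(re.findall(re.escape("S&P 500"), "The S&P 500 index rose."))           # ['S&P 500']
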
renderers/plot_3d.py ADDED
@@ -0,0 +1,126 @@
1
+ """
2
+ This file contains the code to plot a 3d tree
3
+ """
4
+ import numpy as np
5
+ import plotly.graph_objects as go
6
+ from scipy.interpolate import griddata
7
+
8
+ def gen_three_D_plot(detectability_val, distortion_val, euclidean_val):
9
+ """
10
+ Generates a 3D surface plot showing the relationship between detectability, distortion,
11
+ and Euclidean distance, with a focus on highlighting the "sweet spot" based on a composite score.
12
+
13
+ The function takes three sets of values: detectability, distortion, and Euclidean distance,
14
+ normalizes them to a [0, 1] range, and computes a composite score that combines these three metrics.
15
+ The "sweet spot" is the point where the composite score is maximized. This sweet spot is plotted
16
+ as a red marker on the 3D surface plot.
17
+
18
+ The function then uses a grid interpolation method (`griddata`) to generate a smooth surface
19
+ for the Euclidean distance over the detectability and distortion values. The result is a surface plot
20
+ where the contours represent different Euclidean distances.
21
+
22
+ Args:
23
+ detectability_val (list or array): A list or array of detectability scores.
24
+ distortion_val (list or array): A list or array of distortion scores.
25
+ euclidean_val (list or array): A list or array of Euclidean distances.
26
+
27
+ Returns:
28
+ plotly.graph_objects.Figure: A Plotly figure object representing the 3D surface plot,
29
+ with contour lines and a marker for the sweet spot.
30
+
31
+ Raises:
32
+ ValueError: If `griddata` fails to generate a valid interpolation, which could happen if the
33
+ input data does not allow for a proper interpolation.
34
+
35
+ Example:
36
+ # Example of usage:
37
+ detectability_vals = [0.1, 0.3, 0.5, 0.7, 0.9]
38
+ distortion_vals = [0.2, 0.4, 0.6, 0.8, 1.0]
39
+ euclidean_vals = [0.5, 0.3, 0.2, 0.4, 0.6]
40
+
41
+ fig = gen_three_D_plot(detectability_vals, distortion_vals, euclidean_vals)
42
+ fig.show() # Displays the plot in a web browser
43
+
44
+ Notes:
45
+ - The composite score is calculated as:
46
+ `composite_score = norm_detectability - (norm_distortion + norm_euclidean)`,
47
+ where the goal is to maximize detectability and minimize distortion and Euclidean distance.
48
+ - The `griddata` function uses linear interpolation to create a smooth surface for the plot.
49
+ - The function uses the "Plasma" colorscale for the surface plot, which provides a perceptually uniform color scheme.
50
+ """
51
+
52
+ detectability = np.array(detectability_val)
53
+ distortion = np.array(distortion_val)
54
+ euclidean = np.array(euclidean_val)
55
+
56
+ # Normalize the values to range [0, 1]
57
+ norm_detectability = (detectability - min(detectability)) / (max(detectability) - min(detectability))
58
+ norm_distortion = (distortion - min(distortion)) / (max(distortion) - min(distortion))
59
+ norm_euclidean = (euclidean - min(euclidean)) / (max(euclidean) - min(euclidean))
60
+
61
+ # Composite score: maximize detectability, minimize distortion and Euclidean distance
62
+ composite_score = norm_detectability - (norm_distortion + norm_euclidean)
63
+
64
+ # Find the index of the maximum score (sweet spot)
65
+ sweet_spot_index = np.argmax(composite_score)
66
+
67
+ # Sweet spot values
68
+ sweet_spot_detectability = detectability[sweet_spot_index]
69
+ sweet_spot_distortion = distortion[sweet_spot_index]
70
+ sweet_spot_euclidean = euclidean[sweet_spot_index]
71
+
72
+ # Create a meshgrid from the data
73
+ x_grid, y_grid = np.meshgrid(np.linspace(min(detectability), max(detectability), 30),
74
+ np.linspace(min(distortion), max(distortion), 30))
75
+
76
+ # Interpolate z values (Euclidean distances) to fit the grid using 'nearest' method
77
+ z_grid = griddata((detectability, distortion), euclidean, (x_grid, y_grid), method='nearest')
78
+
79
+ if z_grid is None:
80
+ raise ValueError("griddata could not generate a valid interpolation. Check your input data.")
81
+
82
+ # Create the 3D contour plot with the Plasma color scale
83
+ fig = go.Figure(data=go.Surface(
84
+ z=z_grid,
85
+ x=x_grid,
86
+ y=y_grid,
87
+ contours={
88
+ "z": {"show": True, "start": min(euclidean), "end": max(euclidean), "size": 0.1, "usecolormap": True}
89
+ },
90
+ colorscale='Plasma'
91
+ ))
92
+
93
+ # Add a marker for the sweet spot
94
+ fig.add_trace(go.Scatter3d(
95
+ x=[sweet_spot_detectability],
96
+ y=[sweet_spot_distortion],
97
+ z=[sweet_spot_euclidean],
98
+ mode='markers+text',
99
+ marker=dict(size=10, color='red', symbol='circle'),
100
+ text=["Sweet Spot"],
101
+ textposition="top center"
102
+ ))
103
+
104
+ # Set axis labels
105
+ fig.update_layout(
106
+ scene=dict(
107
+ xaxis_title='Detectability Score',
108
+ yaxis_title='Distortion Score',
109
+ zaxis_title='Euclidean Distance'
110
+ ),
111
+ margin=dict(l=0, r=0, b=0, t=0)
112
+ )
113
+
114
+ return fig
115
+
116
+ if __name__ == "__main__":
117
+ # Example input data
118
+ detectability_vals = [0.1, 0.3, 0.5, 0.7, 0.9]
119
+ distortion_vals = [0.2, 0.4, 0.6, 0.8, 1.0]
120
+ euclidean_vals = [0.5, 0.3, 0.2, 0.4, 0.6]
121
+
122
+ # Call the function with example data
123
+ fig = gen_three_D_plot(detectability_vals, distortion_vals, euclidean_vals)
124
+
125
+ # Show the plot
126
+ fig.show()
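
The sweet-spot selection in `gen_three_D_plot` reduces to a few lines; this sketch reruns just that step on the example data above so the composite score can be checked by hand:

    import numpy as np

    det = np.array([0.1, 0.3, 0.5, 0.7, 0.9])
    dis = np.array([0.2, 0.4, 0.6, 0.8, 1.0])
    euc = np.array([0.5, 0.3, 0.2, 0.4, 0.6])

    norm = lambda a: (a - a.min()) / (a.max() - a.min())
    composite = norm(det) - (norm(dis) + norm(euc))
    best = int(np.argmax(composite))
    print(best, det[best], dis[best], euc[best])  # 2 0.5 0.6 0.2 -> the red marker plotted above
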
renderers/tree.py ADDED
@@ -0,0 +1,490 @@
1
+ import plotly.graph_objects as go
2
+ import textwrap
3
+ import re
4
+ from collections import defaultdict
5
+
6
+ def generate_subplot1(paraphrased_sentence, masked_sentences, strategies, highlight_info, common_grams):
7
+ """
8
+ Generates a subplot visualizing paraphrased and masked sentences in a tree structure.
9
+ Highlights common words with specific colors and applies Longest Common Subsequence (LCS) numbering.
10
+
11
+ Args:
12
+ paraphrased_sentence (str): The paraphrased sentence to be visualized.
13
+ masked_sentences (list of str): A list of masked sentences to be visualized.
14
+ strategies (list of str, optional): List of strategies used for each masked sentence.
15
+ highlight_info (list of tuples): A list of tuples where each tuple contains a word and its associated color for highlighting.
16
+ common_grams (list of tuples): A list of tuples containing an index and a common word or phrase for LCS numbering.
17
+
18
+ Returns:
19
+ plotly.graph_objects.Figure: A Plotly figure representing the tree structure with highlighted words and labeled edges.
20
+ """
21
+ # Combine nodes into one list with appropriate labels
22
+ if isinstance(masked_sentences, str):
23
+ masked_sentences = [masked_sentences]
24
+ nodes = [paraphrased_sentence] + masked_sentences
25
+ nodes[0] += ' L0' # Paraphrased sentence is level 0
26
+ if len(nodes) < 2:
27
+ print("[ERROR] Insufficient nodes for visualization")
28
+ return go.Figure()
29
+
30
+ for i in range(1, len(nodes)):
31
+ nodes[i] += ' L1' # masked sentences are level 1
32
+
33
+ def apply_lcs_numbering(sentence, common_grams):
34
+ """
35
+ Applies LCS numbering to the sentence based on the common_grams.
36
+
37
+ Args:
38
+ sentence (str): The sentence to which the LCS numbering should be applied.
39
+ common_grams (list of tuples): A list of common grams to be replaced with LCS numbers.
40
+
41
+ Returns:
42
+ str: The sentence with LCS numbering applied.
43
+ """
44
+ for idx, lcs in common_grams:
45
+ sentence = re.sub(rf"\b{lcs}\b", f"({idx}){lcs}", sentence)
46
+ return sentence
47
+
48
+ # Apply LCS numbering
49
+ nodes = [apply_lcs_numbering(node, common_grams) for node in nodes]
50
+
51
+
52
+ def highlight_words(sentence, color_map):
53
+ """
54
+ Highlights words in the sentence based on the color_map.
55
+
56
+ Args:
57
+ sentence (str): The sentence where the words will be highlighted.
58
+ color_map (dict): A dictionary mapping words to their colors.
59
+
60
+ Returns:
61
+ str: The sentence with highlighted words.
62
+ """
63
+ for word, color in color_map.items():
64
+ sentence = re.sub(f"\\b{word}\\b", f"{{{{{word}}}}}", sentence, flags=re.IGNORECASE)
65
+ return sentence
66
+
67
+ # Clean and wrap nodes, and highlight specified words globally
68
+ cleaned_nodes = [re.sub(r'\sL[0-9]$', '', node) for node in nodes]
69
+ global_color_map = dict(highlight_info)
70
+ highlighted_nodes = [highlight_words(node, global_color_map) for node in cleaned_nodes]
71
+ wrapped_nodes = ['<br>'.join(textwrap.wrap(node, width=55)) for node in highlighted_nodes]
72
+
73
+ def get_levels_and_edges(nodes, strategies=None):
74
+ """
75
+ Determines tree levels and creates edges dynamically.
76
+
77
+ Args:
78
+ nodes (list of str): The nodes representing the sentences.
79
+ strategies (list of str, optional): The strategies used for each edge.
80
+
81
+ Returns:
82
+ tuple: A tuple containing two dictionaries:
83
+ - levels: A dictionary mapping node indices to their levels.
84
+ - edges: A list of edges where each edge is represented by a tuple of node indices.
85
+ """
86
+ levels = {}
87
+ edges = []
88
+ for i, node in enumerate(nodes):
89
+ level = int(node.split()[-1][1])
90
+ levels[i] = level
91
+
92
+ # Add edges from L0 to all L1 nodes
93
+ root_node = next((i for i, level in levels.items() if level == 0), 0)
94
+ for i, level in levels.items():
95
+ if level == 1:
96
+ edges.append((root_node, i))
97
+
98
+ return levels, edges
99
+
100
+ # Get levels and dynamic edges
101
+ levels, edges = get_levels_and_edges(nodes, strategies)
102
+ max_level = max(levels.values(), default=0)
103
+
104
+ # Calculate positions
105
+ positions = {}
106
+ level_heights = defaultdict(int)
107
+ for node, level in levels.items():
108
+ level_heights[level] += 1
109
+
110
+ y_offsets = {level: - (height - 1) / 2 for level, height in level_heights.items()}
111
+ x_gap = 2
112
+ l1_y_gap = 10
113
+
114
+ for node, level in levels.items():
115
+ positions[node] = (-level * x_gap, y_offsets[level] * l1_y_gap)
+ # Advance the offset for every node so nodes on the same level do not overlap vertically.
+ y_offsets[level] += 1
120
+
121
+ def color_highlighted_words(node, color_map):
122
+ """
123
+ Colors the highlighted words in the node text.
124
+
125
+ Args:
126
+ node (str): The node text to be highlighted.
127
+ color_map (dict): A dictionary mapping words to their colors.
128
+
129
+ Returns:
130
+ str: The node text with highlighted words.
131
+ """
132
+ parts = re.split(r'(\{\{.*?\}\})', node)
133
+ colored_parts = []
134
+ for part in parts:
135
+ match = re.match(r'\{\{(.*?)\}\}', part)
136
+ if match:
137
+ word = match.group(1)
138
+ color = color_map.get(word, 'black')
139
+ colored_parts.append(f"<span style='color: {color};'>{word}</span>")
140
+ else:
141
+ colored_parts.append(part)
142
+ return ''.join(colored_parts)
143
+
144
+ # Define the text for each edge
145
+ default_edge_texts = [
146
+ "Highest Entropy Masking", "Pseudo-random Masking", "Random Masking",
147
+ "Greedy Sampling", "Temperature Sampling", "Exponential Minimum Sampling",
148
+ "Inverse Transform Sampling", "Greedy Sampling", "Temperature Sampling",
149
+ "Exponential Minimum Sampling", "Inverse Transform Sampling", "Greedy Sampling",
150
+ "Temperature Sampling", "Exponential Minimum Sampling", "Inverse Transform Sampling"
151
+ ]
152
+
153
+ if len(nodes) < 2:
154
+ print("[ERROR] Insufficient nodes for visualization")
155
+ return go.Figure()
156
+
157
+ # Create figure
158
+ fig1 = go.Figure()
159
+
160
+ # Add nodes to the figure
161
+ for i, node in enumerate(wrapped_nodes):
162
+ colored_node = color_highlighted_words(node, global_color_map)
163
+ x, y = positions[i]
164
+ fig1.add_trace(go.Scatter(
165
+ x=[-x], # Reflect the x coordinate
166
+ y=[y],
167
+ mode='markers',
168
+ marker=dict(size=20, color='blue', line=dict(color='black', width=2)),
169
+ hoverinfo='none'
170
+ ))
171
+ fig1.add_annotation(
172
+ x=-x, # Reflect the x coordinate
173
+ y=y,
174
+ text=colored_node,
175
+ showarrow=False,
176
+ xshift=15,
177
+ align="center",
178
+ font=dict(size=12),
179
+ bordercolor='black',
180
+ borderwidth=2,
181
+ borderpad=4,
182
+ bgcolor='white',
183
+ width=400,
184
+ height=100
185
+ )
186
+
187
+ # Add edges and text above each edge
188
+ for i, edge in enumerate(edges):
189
+ x0, y0 = positions[edge[0]]
190
+ x1, y1 = positions[edge[1]]
191
+
192
+ # Use strategy if available, otherwise use default edge text
193
+ if strategies and i < len(strategies):
194
+ edge_text = strategies[i]
195
+ else:
196
+ edge_text = default_edge_texts[i % len(default_edge_texts)]
197
+
198
+ fig1.add_trace(go.Scatter(
199
+ x=[-x0, -x1], # Reflect the x coordinates
200
+ y=[y0, y1],
201
+ mode='lines',
202
+ line=dict(color='black', width=1)
203
+ ))
204
+
205
+ # Calculate the midpoint of the edge
206
+ mid_x = (-x0 + -x1) / 2
207
+ mid_y = (y0 + y1) / 2
208
+
209
+ # Adjust y position to shift text upwards
210
+ text_y_position = mid_y + 0.8 # Increase this value to shift the text further upwards
211
+
212
+ # Add text annotation above the edge
213
+ fig1.add_annotation(
214
+ x=mid_x,
215
+ y=text_y_position,
216
+ text=edge_text, # Use the text specific to this edge
217
+ showarrow=False,
218
+ font=dict(size=12),
219
+ align="center"
220
+ )
221
+
222
+ fig1.update_layout(
223
+ showlegend=False,
224
+ margin=dict(t=50, b=50, l=50, r=50),
225
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
226
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
227
+ width=800 + max_level * 200, # Adjusted width to accommodate more levels
228
+ height=300 + len(nodes) * 100, # Adjusted height to accommodate more levels
229
+ plot_bgcolor='rgba(240,240,240,0.2)',
230
+ paper_bgcolor='white'
231
+ )
232
+
233
+ return fig1
234
+
235
+ def generate_subplot2(masked_sentences, sampled_sentences, highlight_info, common_grams):
236
+ """
237
+ Generates a subplot visualizing multiple masked sentences and their sampled variants in a tree structure.
238
+ Each masked sentence will have multiple sampled sentences derived from it using different sampling techniques.
239
+
240
+ Args:
241
+ masked_sentences (list of str): A list of masked sentences to be visualized as root nodes.
242
+ sampled_sentences (list of str): A list of sampled sentences derived from masked sentences.
243
+ highlight_info (list of tuples): A list of tuples where each tuple contains a word and its associated color for highlighting.
244
+ common_grams (list of tuples): A list of tuples containing an index and a common word or phrase for LCS numbering.
245
+
246
+ Returns:
247
+ plotly.graph_objects.Figure: A Plotly figure representing the tree structure with highlighted words and labeled edges.
248
+ """
249
+ # Define sampling techniques
250
+ sampling_techniques = [
251
+ "Greedy Sampling",
252
+ "Temperature Sampling",
253
+ "Exponential Minimum Sampling",
254
+ "Inverse Transform Sampling"
255
+ ]
256
+
257
+ # Calculate total number of nodes
258
+ num_masked = len(masked_sentences)
259
+ num_sampled_per_masked = len(sampling_techniques)
260
+ total_nodes = num_masked + (num_masked * num_sampled_per_masked)
261
+
262
+ # Combine all sentences into nodes list with appropriate labels
263
+ nodes = []
264
+ # Level 0: masked sentences (root nodes)
265
+ nodes.extend([s + ' L0' for s in masked_sentences])
266
+
267
+ # Level 1: sampled sentences (branch nodes)
268
+ # For each masked sentence, we should have samples from each technique
269
+ sampled_nodes = []
270
+
271
+ # Validate if we have the expected number of sampled sentences
272
+ expected_sampled_count = num_masked * num_sampled_per_masked
273
+ if len(sampled_sentences) < expected_sampled_count:
274
+ # If insufficient samples provided, pad with placeholder sentences
275
+ print(f"Warning: Expected {expected_sampled_count} sampled sentences, but got {len(sampled_sentences)}")
276
+ while len(sampled_sentences) < expected_sampled_count:
277
+ sampled_sentences.append(f"Placeholder sampled sentence {len(sampled_sentences) + 1}")
278
+
279
+ # Add all sampled sentences with level information
280
+ for s in sampled_sentences[:expected_sampled_count]:
281
+ sampled_nodes.append(s + ' L1')
282
+
283
+ nodes.extend(sampled_nodes)
284
+
285
+ def apply_lcs_numbering(sentence, common_grams):
286
+ """
287
+ Applies LCS numbering to the sentence based on the common_grams.
288
+ """
289
+ for idx, lcs in common_grams:
290
+ sentence = re.sub(rf"\b{lcs}\b", f"({idx}){lcs}", sentence)
291
+ return sentence
292
+
293
+ # Apply LCS numbering
294
+ nodes = [apply_lcs_numbering(node, common_grams) for node in nodes]
295
+
296
+ def highlight_words(sentence, color_map):
297
+ """
298
+ Highlights words in the sentence based on the color_map.
299
+ """
300
+ for word, color in color_map.items():
301
+ sentence = re.sub(f"\\b{word}\\b", f"{{{{{word}}}}}", sentence, flags=re.IGNORECASE)
302
+ return sentence
303
+
304
+ # Helper function to color highlighted words
305
+ def color_highlighted_words(node, color_map):
306
+ """
307
+ Colors the highlighted words in the node text.
308
+ """
309
+ parts = re.split(r'(\{\{.*?\}\})', node)
310
+ colored_parts = []
311
+ for part in parts:
312
+ match = re.match(r'\{\{(.*?)\}\}', part)
313
+ if match:
314
+ word = match.group(1)
315
+ color = color_map.get(word, 'black')
316
+ colored_parts.append(f"<span style='color: {color};'>{word}</span>")
317
+ else:
318
+ colored_parts.append(part)
319
+ return ''.join(colored_parts)
320
+
321
+ # Clean nodes, highlight words, and wrap text
322
+ cleaned_nodes = [re.sub(r'\sL[0-9]$', '', node) for node in nodes]
323
+ global_color_map = dict(highlight_info)
324
+ highlighted_nodes = [highlight_words(node, global_color_map) for node in cleaned_nodes]
325
+ wrapped_nodes = ['<br>'.join(textwrap.wrap(node, width=80)) for node in highlighted_nodes]
326
+
327
+ # Generate edges based on the tree structure
328
+ def get_levels_and_edges(nodes):
329
+ levels = {}
330
+ edges = []
331
+
332
+ # Extract level info from node labels
333
+ for i, node in enumerate(nodes):
334
+ level = int(node.split()[-1][1])
335
+ levels[i] = level
336
+
337
+ # Create edges from masked sentences to their sampled variants
338
+ for masked_idx in range(num_masked):
339
+ # For each masked sentence, create edges to its sampled variants
340
+ for technique_idx in range(num_sampled_per_masked):
341
+ sampled_idx = num_masked + (masked_idx * num_sampled_per_masked) + technique_idx
342
+ if sampled_idx < len(nodes):
343
+ edges.append((masked_idx, sampled_idx))
344
+
345
+ return levels, edges
346
+
347
+ levels, edges = get_levels_and_edges(nodes)
348
+
349
+ # Calculate positions with improved spacing
350
+ positions = {}
351
+
352
+ # Calculate horizontal spacing for the root nodes (masked sentences)
353
+ root_x_spacing = 0 # All root nodes at x=0
354
+ root_y_spacing = 8.0 # Vertical spacing between root nodes
355
+
356
+ # Calculate positions for sampled nodes
357
+ sampled_x = 3 # X position for all sampled nodes
358
+
359
+ # Calculate y positions for root nodes (masked sentences)
360
+ root_y_start = -(num_masked - 1) * root_y_spacing / 2
361
+ for i in range(num_masked):
362
+ positions[i] = (root_x_spacing, root_y_start + i * root_y_spacing)
363
+
364
+ # Calculate y positions for sampled nodes
365
+ for masked_idx in range(num_masked):
366
+ root_y = positions[masked_idx][1] # Y position of parent masked sentence
367
+
368
+ # Calculate y-spacing for children of this root
369
+ children_y_spacing = 1.5 # Vertical spacing between children of the same root
370
+ children_y_start = root_y - (num_sampled_per_masked - 1) * children_y_spacing / 2
371
+
372
+ # Position each child
373
+ for technique_idx in range(num_sampled_per_masked):
374
+ child_idx = num_masked + (masked_idx * num_sampled_per_masked) + technique_idx
375
+ child_y = children_y_start + technique_idx * children_y_spacing
376
+ positions[child_idx] = (sampled_x, child_y)
377
+
378
+ # Create figure
379
+ fig2 = go.Figure()
380
+
381
+ # Add nodes
382
+ for i, node in enumerate(wrapped_nodes):
383
+ x, y = positions[i]
384
+
385
+ # Define node color based on level
386
+ node_color = 'blue' if levels[i] == 0 else 'green'
387
+
388
+ # Add the node marker
389
+ fig2.add_trace(go.Scatter(
390
+ x=[x],
391
+ y=[y],
392
+ mode='markers',
393
+ marker=dict(size=20, color=node_color, line=dict(color='black', width=2)),
394
+ hoverinfo='none'
395
+ ))
396
+
397
+ # Add node label with highlighting
398
+ colored_node = color_highlighted_words(node, global_color_map)
399
+
400
+ fig2.add_annotation(
401
+ x=x,
402
+ y=y,
403
+ text=colored_node,
404
+ showarrow=False,
405
+ xshift=15,
406
+ align="left",
407
+ font=dict(size=12),
408
+ bordercolor='black',
409
+ borderwidth=2,
410
+ borderpad=4,
411
+ bgcolor='white',
412
+ width=450,
413
+ height=100
414
+ )
415
+
416
+ # Add edges with labels
417
+ for i, (src, dst) in enumerate(edges):
418
+ x0, y0 = positions[src]
419
+ x1, y1 = positions[dst]
420
+
421
+ # Draw the edge
422
+ fig2.add_trace(go.Scatter(
423
+ x=[x0, x1],
424
+ y=[y0, y1],
425
+ mode='lines',
426
+ line=dict(color='black', width=1)
427
+ ))
428
+
429
+ # Add sampling technique label
430
+ # Determine which sampling technique this is
431
+ parent_idx = src
432
+ technique_count = sum(1 for k, (s, _) in enumerate(edges) if s == parent_idx and k < i)
433
+ technique_label = sampling_techniques[technique_count % len(sampling_techniques)]
434
+
435
+ # Calculate midpoint for the label
436
+ mid_x = (x0 + x1) / 2
437
+ mid_y = (y0 + y1) / 2
438
+
439
+ # Add slight offset to avoid overlap
440
+ label_offset = 0.1
441
+
442
+ fig2.add_annotation(
443
+ x=mid_x,
444
+ y=mid_y + label_offset,
445
+ text=technique_label,
446
+ showarrow=False,
447
+ font=dict(size=8),
448
+ align="center"
449
+ )
450
+
451
+ # Update layout
452
+ fig2.update_layout(
453
+ showlegend=False,
454
+ margin=dict(t=20, b=20, l=20, r=20),
455
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
456
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
457
+ width=1200, # Adjusted width to accommodate more levels
458
+ height=2000, # Adjusted height to accommodate more levels
459
+ plot_bgcolor='rgba(240,240,240,0.2)',
460
+ paper_bgcolor='white'
461
+
462
+ )
463
+
464
+ return fig2
465
+
466
+ if __name__ == "__main__":
467
+ paraphrased_sentence = "The quick brown fox jumps over the lazy dog."
468
+ masked_sentences = [
469
+ "A fast brown fox leaps over the lazy dog.",
470
+ "A quick brown fox hops over a lazy dog."
471
+ ]
472
+ highlight_info = [
473
+ ("quick", "red"),
474
+ ("brown", "green"),
475
+ ("fox", "blue"),
476
+ ("lazy", "purple")
477
+ ]
478
+ common_grams = [
479
+ (1, "quick brown fox"),
480
+ (2, "lazy dog")
481
+ ]
482
+
483
+ strategies = ["Highest Entropy Masking", "Random Masking"]
+ fig1 = generate_subplot1(paraphrased_sentence, masked_sentences, strategies, highlight_info, common_grams)
484
+ fig1.show()
485
+
486
+ sampled_sentence = ["A fast brown fox jumps over a lazy dog."]
487
+
488
+
489
+ fig2 = generate_subplot2(masked_sentences, sampled_sentence, highlight_info, common_grams)
490
+ fig2.show()
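
A small sketch isolating the LCS-numbering step shared by both subplot builders (`apply_lcs_numbering`): each common n-gram is prefixed with its index via a word-boundary substitution. The inputs are the example data from this file:

    import re

    common_grams = [(1, "quick brown fox"), (2, "lazy dog")]
    sentence = "The quick brown fox jumps over the lazy dog."

    for idx, lcs in common_grams:
        sentence = re.sub(rf"\b{lcs}\b", f"({idx}){lcs}", sentence)

    print(sentence)  # The (1)quick brown fox jumps over the (2)lazy dog.
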
utils/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from utils.watermark import Watermarker
2
+ from utils.paraphraser import Paraphraser
3
+ from utils.entailment import EntailmentAnalyzer
4
+ from utils.sampling import SamplingProcessor
5
+ from utils.config import load_config
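
Since utils/__init__.py re-exports the main classes, callers can import them from the package root; a one-line usage sketch:

    from utils import Watermarker, Paraphraser, EntailmentAnalyzer, SamplingProcessor, load_config
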
utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (404 Bytes). View file
 
utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (509 Bytes). View file
 
utils/__pycache__/config.cpython-310.pyc ADDED
Binary file (594 Bytes). View file
 
utils/__pycache__/config.cpython-311.pyc ADDED
Binary file (971 Bytes). View file
 
utils/__pycache__/entailment.cpython-310.pyc ADDED
Binary file (3.69 kB). View file
 
utils/__pycache__/entailment.cpython-311.pyc ADDED
Binary file (5.33 kB). View file
 
utils/__pycache__/masking_methods.cpython-310.pyc ADDED
Binary file (11.1 kB). View file
 
utils/__pycache__/masking_methods.cpython-311.pyc ADDED
Binary file (35.6 kB). View file
 
utils/__pycache__/ngram_index_manager.cpython-311.pyc ADDED
Binary file (19.5 kB). View file
 
utils/__pycache__/non_melting_point.cpython-310.pyc ADDED
Binary file (5.05 kB). View file
 
utils/__pycache__/non_melting_point.cpython-311.pyc ADDED
Binary file (33.1 kB). View file
 
utils/__pycache__/paraphraser.cpython-310.pyc ADDED
Binary file (2.85 kB). View file
 
utils/__pycache__/paraphraser.cpython-311.pyc ADDED
Binary file (4.89 kB). View file
 
utils/__pycache__/sampling.cpython-310.pyc ADDED
Binary file (5.06 kB). View file
 
utils/__pycache__/sampling.cpython-311.pyc ADDED
Binary file (9.2 kB). View file
 
utils/__pycache__/watermark.cpython-310.pyc ADDED
Binary file (11.8 kB). View file
 
utils/__pycache__/watermark.cpython-311.pyc ADDED
Binary file (20.9 kB). View file
 
utils/config.py ADDED
@@ -0,0 +1,18 @@
1
+ """
2
+ This file loads config from config.yaml
3
+ """
4
+
5
+ import yaml
6
+
7
+ def load_config(path):
8
+ """
9
+ Function to load config from config.yaml
10
+ """
11
+ try:
12
+ with open(path, "r") as file:
13
+ config = yaml.safe_load(file)
14
+ return config
15
+ except FileNotFoundError:
16
+ raise FileNotFoundError("Config file not found")
17
+ except Exception as e:
18
+ raise e
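
A hedged usage sketch for `load_config`, assuming the script is run from the repository root, where the YAML shown next lives at utils/config.yaml:

    from utils.config import load_config

    config = load_config("utils/config.yaml")
    masking_cfg = config["PECCAVI_TEXT"]["Masking"]
    print(masking_cfg["model"])  # "bert-base-uncased" with the shipped config
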
utils/config.yaml ADDED
@@ -0,0 +1,48 @@
1
+ # This is the official config file.
2
+ PECCAVI_TEXT:
3
+ Entailment:
4
+ task: "text-classification"
5
+ model: "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli"
6
+
7
+ Masking:
8
+ task: "fill-mask"
9
+ tokenizer: "bert-base-uncased"
10
+ model: "bert-base-uncased"
11
+ # tokenizer: "bert-large-cased-whole-word-masking"
12
+ # model: "bert-large-cased-whole-word-masking"
13
+
14
+ Vocabulary:
15
+ tokenizer: "bert-base-uncased"
16
+ model: "bert-base-uncased"
17
+ # permissible_ratio: 0.5
18
+ # tokenizer: "bert-large-cased-whole-word-masking"
19
+ # model: "bert-large-cased-whole-word-masking"
20
+ permissible_ratio: 1.0
21
+
22
+ Sampling:
23
+ tokenizer: "bert-base-uncased"
24
+ model: "bert-base-uncased"
25
+ # tokenizer: "bert-large-cased-whole-word-masking"
26
+ # model: "bert-large-cased-whole-word-masking"
27
+
28
+ Metrics:
29
+ EuclideanDistance: "sentence-transformers/all-MiniLM-L6-v2"
30
+ Distortion: "gpt2"
31
+
32
+ Detector:
33
+ tokenizer: "bert-base-uncased"
34
+ model: "bert-base-uncased"
35
+ # tokenizer: "bert-large-cased-whole-word-masking"
36
+ # model: "bert-large-cased-whole-word-masking"
37
+
38
+ Paraphrase:
39
+ tokenizer: "humarin/chatgpt_paraphraser_on_T5_base"
40
+ model: "humarin/chatgpt_paraphraser_on_T5_base"
41
+ num_beams: 10
42
+ num_beam_groups: 10
43
+ num_return_sequences: 10
44
+ repetition_penalty: 10.0
45
+ diversity_penalty: 3.0
46
+ no_repeat_ngram_size: 2
47
+ temperature: 0.7
48
+ max_length: 64
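
The Paraphrase block maps directly onto Hugging Face `generate()` arguments for diverse beam search with the T5 paraphraser. Below is a hedged sketch of how such a config is typically consumed; the real wiring lives in utils/paraphraser.py (not shown here), and the "paraphrase:" prompt prefix is an assumption. `temperature` is omitted because sampling is not enabled in this sketch:

    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
    from utils.config import load_config

    cfg = load_config("utils/config.yaml")["PECCAVI_TEXT"]["Paraphrase"]
    tokenizer = AutoTokenizer.from_pretrained(cfg["tokenizer"])
    model = AutoModelForSeq2SeqLM.from_pretrained(cfg["model"])

    inputs = tokenizer("paraphrase: The weather is nice today", return_tensors="pt")
    outputs = model.generate(
        **inputs,
        num_beams=cfg["num_beams"],
        num_beam_groups=cfg["num_beam_groups"],
        num_return_sequences=cfg["num_return_sequences"],
        repetition_penalty=cfg["repetition_penalty"],
        diversity_penalty=cfg["diversity_penalty"],
        no_repeat_ngram_size=cfg["no_repeat_ngram_size"],
        max_length=cfg["max_length"],
    )
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
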
utils/entailment.py ADDED
@@ -0,0 +1,107 @@
1
+ import sys
2
+ import os
3
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
4
+
5
+ import numpy as np
6
+ from transformers import pipeline
7
+ from typing import List
8
+ from utils.config import load_config
9
+
10
+
11
+ class EntailmentAnalyzer:
12
+ # def __init__(self, config_path: str):
13
+ def __init__(self, config):
14
+ """
15
+ Initialize the EntailmentAnalyzer with the config file path.
16
+
17
+ Args:
18
+ config_path: The path to the configuration file.
19
+ """
20
+ # self.config = load_config(config_path)['PECCAVI_TEXT']['Entailment']
21
+ self.config = config
22
+ self.entailment_pipeline = pipeline(task=self.config['task'], model=self.config['model'])
23
+
24
+ def check_entailment(self, premise: str, hypothesis: str) -> float:
25
+ """
26
+ Check entailment between the premise and hypothesis.
27
+
28
+ Args:
29
+ premise: The premise sentence.
30
+ hypothesis: The hypothesis sentence.
31
+
32
+ Returns:
33
+ float: The entailment score.
34
+ """
35
+ results = self.entailment_pipeline(f"{premise} [SEP] {hypothesis}", top_k=None)
36
+ entailment_score = next(item['score'] for item in results if item['label'] == 'entailment')
37
+ return entailment_score
38
+
39
+ def analyze_entailment(self, original_sentence: str, paraphrased_sentences: List[str], threshold: float) -> tuple:
40
+ """
41
+ Analyze entailment scores for paraphrased sentences. If no selected sentences are found,
42
+ lower the threshold and rerun the analysis.
43
+
44
+ Args:
45
+ original_sentence: The original sentence.
46
+ paraphrased_sentences: List of paraphrased sentences.
47
+ threshold: Minimum score to select a sentence.
48
+
49
+ Returns:
50
+ tuple: A dictionary of all scores, selected sentences, and discarded sentences.
51
+ """
52
+ all_sentences = {}
53
+ selected_sentences = {}
54
+ discarded_sentences = {}
55
+
56
+ # Loop to reduce threshold if no sentences are selected
57
+ while not selected_sentences:
58
+ for paraphrased_sentence in paraphrased_sentences:
59
+ entailment_score = self.check_entailment(original_sentence, paraphrased_sentence)
60
+
61
+ all_sentences[paraphrased_sentence] = entailment_score
62
+ if entailment_score >= threshold:
63
+ selected_sentences[paraphrased_sentence] = entailment_score
64
+ else:
65
+ discarded_sentences[paraphrased_sentence] = entailment_score
66
+
67
+ # If no sentences are selected, lower the threshold
68
+ if not selected_sentences:
69
+ print(f"No selected sentences found. Lowering the threshold by 0.1 (from {threshold} to {threshold - 0.1}).")
70
+ threshold -= 0.1
71
+ if threshold <= 0:
72
+ print("Threshold has reached 0. No sentences meet the criteria.")
73
+ break
74
+
75
+ return all_sentences, selected_sentences, discarded_sentences
76
+
77
+
78
+ if __name__ == "__main__":
79
+ config_path = os.path.join(os.path.dirname(__file__), '..', 'config', 'config.yaml')
80
+
81
+ config_path = '/home/ashhar21137/text_wm/scratch/utils/config/config.yaml'
82
+
83
+ config = load_config(config_path)
84
+
85
+ entailment_analyzer = EntailmentAnalyzer(config['PECCAVI_TEXT']['Entailment'])
86
+
87
+ all_sentences, selected_sentences, discarded_sentences = entailment_analyzer.analyze_entailment(
88
+ "The weather is nice today",
89
+ [
90
+ "The climate is pleasant today",
91
+ "It's a good day weather-wise",
92
+ "Today, the weather is terrible",
93
+ "What a beautiful day it is",
94
+ "The sky is clear and the weather is perfect",
95
+ "It's pouring rain outside today",
96
+ "The weather isn't bad today",
97
+ "A lovely day for outdoor activities"
98
+ ],
99
+ 0.7
100
+ )
101
+
102
+ print("----------------------- All Sentences -----------------------")
103
+ print(all_sentences)
104
+ print("----------------------- Discarded Sentences -----------------------")
105
+ print(discarded_sentences)
106
+ print("----------------------- Selected Sentences -----------------------")
107
+ print(selected_sentences)
utils/masking_methods.py ADDED
@@ -0,0 +1,583 @@
1
+ import random
2
+ import torch
3
+ import logging
4
+ import string
5
+ from transformers import BertTokenizer, BertForMaskedLM
6
+ from nltk.corpus import stopwords
7
+ import nltk
8
+ from tqdm import tqdm
9
+
10
+ # Set logging to WARNING for a cleaner terminal.
11
+ logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s")
12
+ logger = logging.getLogger(__name__)
13
+
14
+ # Ensure stopwords are downloaded
15
+ try:
16
+ nltk.data.find('corpora/stopwords')
17
+ except LookupError:
18
+ nltk.download('stopwords')
19
+
20
+ def clean_word(word):
21
+ """More robust cleaning for consistent matching"""
22
+ # Remove possessive 's before other punctuation
23
+ if word.lower().endswith("'s"):
24
+ word = word[:-2]
25
+ return word.lower().strip().translate(str.maketrans('', '', string.punctuation))
26
+
27
+ class MaskingProcessor:
28
+ def __init__(self, tokenizer, model):
29
+ self.tokenizer = tokenizer
30
+ self.model = model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
31
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+ self.stop_words = set(stopwords.words('english'))
33
+ tqdm.write(f"[MaskingProcessor] Initialized on device: {self.device}")
34
+
35
+ def is_stopword(self, word):
36
+ """Check if a word is a stopword, handling punctuation and case"""
37
+ return clean_word(word) in self.stop_words
38
+
39
+ def verify_and_correct_ngram_positions(self, sentence, common_ngrams):
40
+ """Verify ngram positions match actual words in sentence and correct if needed."""
41
+ words = sentence.split()
42
+ corrected_ngrams = {}
43
+
44
+ for ngram, positions in common_ngrams.items():
45
+ corrected_positions = []
46
+ ngram_words = ngram.split()
47
+
48
+ # Convert ngram words to clean format for matching
49
+ clean_ngram_words = [clean_word(word) for word in ngram_words]
50
+
51
+ # Scan the sentence to find actual occurrences of the ngram
52
+ for i in range(len(words) - len(ngram_words) + 1):
53
+ is_match = True
54
+ for j, ngram_word in enumerate(clean_ngram_words):
55
+ if clean_word(words[i + j]) != ngram_word:
56
+ is_match = False
57
+ break
58
+
59
+ if is_match:
60
+ # Found a matching position, add it
61
+ corrected_positions.append((i, i + len(ngram_words) - 1))
62
+
63
+ if corrected_positions:
64
+ corrected_ngrams[ngram] = corrected_positions
65
+ else:
66
+ # Log the issue and perform a more flexible search
67
+ print(f"Warning: Could not find exact match for '{ngram}' in the sentence.")
68
+ print("Attempting flexible matching...")
69
+
70
+ # Try a more flexible approach by looking for individual words
71
+ for i in range(len(words)):
72
+ if clean_word(words[i]) == clean_ngram_words[0]:
73
+ # We found the first word of the ngram
74
+ if len(ngram_words) == 1 or (
75
+ i + len(ngram_words) <= len(words) and
76
+ all(clean_word(words[i+j]).startswith(clean_ngram_words[j]) for j in range(len(ngram_words)))
77
+ ):
78
+ corrected_positions.append((i, i + len(ngram_words) - 1))
79
+
80
+ if corrected_positions:
81
+ print(f"Found flexible matches for '{ngram}': {corrected_positions}")
82
+ corrected_ngrams[ngram] = corrected_positions
83
+ else:
84
+ # If still no match, keep original positions as fallback
85
+ print(f"No matches found for '{ngram}'. Keeping original positions.")
86
+ corrected_ngrams[ngram] = positions
87
+
88
+ # Log changes
89
+ if corrected_ngrams != common_ngrams:
90
+ print(f"Original ngram positions: {common_ngrams}")
91
+ print(f"Corrected ngram positions: {corrected_ngrams}")
92
+
93
+ return corrected_ngrams
94
+ def in_any_ngram(self, idx, ngram_positions):
95
+ """Check if an original sentence index is part of any n-gram span"""
96
+ return any(start <= idx <= end for start, end in ngram_positions)
97
+ def create_fallback_mask(self, sentence, ngrams):
98
+ """Create a fallback mask when normal strategies fail."""
99
+ try:
100
+ words = sentence.split()
101
+ if not words:
102
+ return None
103
+
104
+ # Find any non-stopword that isn't in an ngram
105
+ ngram_positions = []
106
+ for positions in ngrams.values():
107
+ for start, end in positions:
108
+ ngram_positions.append((start, end))
109
+ ngram_positions.sort()
110
+
111
+ # Find first eligible word
112
+ for idx, word in enumerate(words):
113
+ if not self.is_stopword(word) and not self.in_any_ngram(idx, ngram_positions):
114
+ masked_words = words.copy()
115
+ masked_words[idx] = self.tokenizer.mask_token
116
+ tqdm.write(f"[INFO] Fallback mask created at position {idx}: '{word}'")
117
+ return " ".join(masked_words), [idx]
118
+
119
+ # If no eligible word found, just mask the first non-stop word
120
+ for idx, word in enumerate(words):
121
+ if not self.is_stopword(word):
122
+ masked_words = words.copy()
123
+ masked_words[idx] = self.tokenizer.mask_token
124
+ tqdm.write(f"[INFO] Last resort fallback mask created at position {idx}: '{word}'")
125
+ return " ".join(masked_words), [idx]
126
+
127
+ # If still nothing, mask the first word
128
+ if words:
129
+ masked_words = words.copy()
130
+ masked_words[0] = self.tokenizer.mask_token
131
+ return " ".join(masked_words), [0]
132
+
133
+ return None
134
+ except Exception as e:
135
+ tqdm.write(f"[ERROR] Error creating fallback mask: {e}")
136
+ return None
137
+
138
+ def mask_sentence_random(self, sentence, common_ngrams):
139
+ """Mask random non-stopwords that are not part of common ngrams with controlled positioning."""
140
+ common_ngrams = self.verify_and_correct_ngram_positions(sentence, common_ngrams)
141
+ tqdm.write(f"[MaskingProcessor] Masking (random) sentence: {sentence}")
142
+ original_words = sentence.split()
143
+
144
+ # Handle punctuation
145
+ has_punctuation = False
146
+ punctuation = ''
147
+ if original_words and original_words[-1][-1] in ['.', ',', '!', '?', ';', ':', '"', "'"]:
148
+ has_punctuation = True
149
+ punctuation = original_words[-1][-1]
150
+ original_words[-1] = original_words[-1][:-1]
151
+ if not original_words[-1]: # If the word was just punctuation
152
+ original_words.pop()
153
+
154
+ # Get flattened ngram positions
155
+ ngram_positions = []
156
+ for positions in common_ngrams.values():
157
+ for start, end in positions:
158
+ ngram_positions.append((start, end))
159
+ ngram_positions.sort()
160
+
161
+ # Find all candidate indices (non-stopwords not in ngrams)
162
+ candidate_indices = []
163
+ for idx, word in enumerate(original_words):
164
+ if not self.is_stopword(word) and not self.in_any_ngram(idx, ngram_positions):
165
+ candidate_indices.append(idx)
166
+
167
+ # Debug print candidate words
168
+ print("Candidate words for masking:")
169
+ for idx in candidate_indices:
170
+ print(f" Position {idx}: '{original_words[idx]}'")
171
+
172
+ selected_indices = []
173
+ if ngram_positions:
174
+ # Before first ngram
175
+ before_first = [idx for idx in candidate_indices if idx < ngram_positions[0][0]]
176
+ if before_first:
177
+ num_to_select = min(1, len(before_first)) # Select 1 word
178
+ if num_to_select > 0:
179
+ selected = random.sample(before_first, num_to_select)
180
+ selected_indices.extend(selected)
181
+
182
+ # Between ngrams
183
+ for i in range(len(ngram_positions) - 1):
184
+ between = [idx for idx in candidate_indices
185
+ if ngram_positions[i][1] < idx < ngram_positions[i+1][0]]
186
+ if between:
187
+ num_to_select = min(2, len(between)) # Select between 1-2 words
188
+ if num_to_select > 0:
189
+ selected = random.sample(between, num_to_select)
190
+ selected_indices.extend(selected)
191
+
192
+ # After last ngram
193
+ after_last = [idx for idx in candidate_indices if idx > ngram_positions[-1][1]]
194
+ if after_last:
195
+ num_to_select = min(1, len(after_last)) # Select 1 word
196
+ if num_to_select > 0:
197
+ selected = random.sample(after_last, num_to_select)
198
+ selected_indices.extend(selected)
199
+ else:
200
+ # If no ngrams, pick up to 6 random candidates
201
+ if candidate_indices:
202
+ selected_indices = random.sample(candidate_indices,
203
+ min(6, len(candidate_indices)))
204
+
205
+ masked_words = original_words.copy()
206
+ for idx in selected_indices:
207
+ masked_words[idx] = self.tokenizer.mask_token
208
+
209
+ if has_punctuation:
210
+ masked_words.append(punctuation)
211
+
212
+ # Debug prints
213
+ print("Original sentence:", sentence)
214
+ print("Common ngrams:", common_ngrams)
215
+ print("Common ngram positions:", ngram_positions)
216
+ print("Candidate indices for masking:", candidate_indices)
217
+ print("Selected for masking:", selected_indices)
218
+ print("Masked sentence:", " ".join(masked_words))
219
+
220
+ return " ".join(masked_words), selected_indices
221
+
222
+ def mask_sentence_pseudorandom(self, sentence, common_ngrams):
223
+ """Mask specific non-stopwords based on their position relative to ngrams."""
224
+ common_ngrams = self.verify_and_correct_ngram_positions(sentence, common_ngrams)
225
+ tqdm.write(f"[MaskingProcessor] Masking (pseudorandom) sentence: {sentence}")
226
+ random.seed(3) # Fixed seed for pseudorandom behavior
227
+ original_words = sentence.split()
228
+
229
+ # Handle punctuation
230
+ has_punctuation = False
231
+ punctuation = ''
232
+ if original_words and original_words[-1][-1] in ['.', ',', '!', '?', ';', ':', '"', "'"]:
233
+ has_punctuation = True
234
+ punctuation = original_words[-1][-1]
235
+ original_words[-1] = original_words[-1][:-1]
236
+ if not original_words[-1]: # If the word was just punctuation
237
+ original_words.pop()
238
+
239
+ # Get flattened ngram positions
240
+ ngram_positions = []
241
+ for positions in common_ngrams.values():
242
+ for start, end in positions:
243
+ ngram_positions.append((start, end))
244
+ ngram_positions.sort()
245
+
246
+ # Find all candidate indices (non-stopwords not in ngrams)
247
+ candidate_indices = []
248
+ for idx, word in enumerate(original_words):
249
+ if not self.is_stopword(word) and not self.in_any_ngram(idx, ngram_positions):
250
+ candidate_indices.append(idx)
251
+
252
+ # Debug print candidate words
253
+ print("Candidate words for masking:")
254
+ for idx in candidate_indices:
255
+ print(f" Position {idx}: '{original_words[idx]}'")
256
+
257
+ # PSEUDORANDOM SPECIFIC LOGIC:
258
+ selected_indices = []
259
+ if ngram_positions:
260
+ # Before first ngram
261
+ before_first = [idx for idx in candidate_indices if idx < ngram_positions[0][0]]
262
+ if before_first:
263
+ num_to_select = min(1, len(before_first)) # Select 1 word
264
+ if num_to_select > 0:
265
+ selected = random.sample(before_first, num_to_select)
266
+ selected_indices.extend(selected)
267
+
268
+ # Between ngrams
269
+ for i in range(len(ngram_positions) - 1):
270
+ between = [idx for idx in candidate_indices
271
+ if ngram_positions[i][1] < idx < ngram_positions[i+1][0]]
272
+ if between:
273
+ num_to_select = min(2, len(between)) # Select between 1-2 words
274
+ if num_to_select > 0:
275
+ selected = random.sample(between, num_to_select)
276
+ selected_indices.extend(selected)
277
+
278
+ # After last ngram
279
+ after_last = [idx for idx in candidate_indices if idx > ngram_positions[-1][1]]
280
+ if after_last:
281
+ num_to_select = min(1, len(after_last)) # Select 1 word
282
+ if num_to_select > 0:
283
+ selected = random.sample(after_last, num_to_select)
284
+ selected_indices.extend(selected)
285
+ else:
286
+ # If no ngrams, pick up to 6 random candidates
287
+ if candidate_indices:
288
+ selected_indices = random.sample(candidate_indices,
289
+ min(6, len(candidate_indices)))
290
+
291
+ masked_words = original_words.copy()
292
+ for idx in selected_indices:
293
+ masked_words[idx] = self.tokenizer.mask_token
294
+
295
+ if has_punctuation:
296
+ masked_words.append(punctuation)
297
+
298
+ # Debug prints
299
+ print("Original sentence:", sentence)
300
+ print("Common ngrams:", common_ngrams)
301
+ print("Common ngram positions:", ngram_positions)
302
+ print("Candidate indices for masking:", candidate_indices)
303
+ print("Selected for masking:", selected_indices)
304
+ print("Masked sentence:", " ".join(masked_words))
305
+
306
+ return " ".join(masked_words), selected_indices
307
+
308
+ def mask_sentence_entropy(self, sentence, common_ngrams):
309
+ """Mask words with highest entropy that are not part of common ngrams."""
310
+ common_ngrams = self.verify_and_correct_ngram_positions(sentence, common_ngrams)
311
+ tqdm.write(f"[MaskingProcessor] Masking (entropy) sentence: {sentence}")
312
+ original_words = sentence.split()
313
+
314
+ # Handle punctuation
315
+ has_punctuation = False
316
+ punctuation = ''
317
+ if original_words and original_words[-1][-1] in ['.', ',', '!', '?', ';', ':', '"', "'"]:
318
+ has_punctuation = True
319
+ punctuation = original_words[-1][-1]
320
+ original_words[-1] = original_words[-1][:-1]
321
+ if not original_words[-1]: # If the word was just punctuation
322
+ original_words.pop()
323
+
324
+ # Get flattened ngram positions
325
+ ngram_positions = []
326
+ for positions in common_ngrams.values():
327
+ for start, end in positions:
328
+ ngram_positions.append((start, end))
329
+ ngram_positions.sort()
330
+
331
+ # Find all candidate indices (non-stopwords not in ngrams)
332
+ candidate_indices = []
333
+ for idx, word in enumerate(original_words):
334
+ if not self.is_stopword(word) and not self.in_any_ngram(idx, ngram_positions):
335
+ candidate_indices.append(idx)
336
+
337
+ # Debug print candidate words
338
+ print("Candidate words for masking:")
339
+ for idx in candidate_indices:
340
+ print(f" Position {idx}: '{original_words[idx]}'")
341
+
342
+ # ENTROPY SPECIFIC LOGIC:
343
+ # Calculate entropy for each candidate word
344
+ selected_indices = []
345
+ if candidate_indices:
346
+ # Organize candidates by position relative to ngrams
347
+ if ngram_positions:
348
+ # Group candidates by position
349
+ before_first = []
350
+ between_ngrams = {}
351
+ after_last = []
352
+
353
+ for idx in candidate_indices:
354
+ if idx < ngram_positions[0][0]:
355
+ before_first.append(idx)
356
+ elif idx > ngram_positions[-1][1]:
357
+ after_last.append(idx)
358
+ else:
359
+ # Find which ngram gap this belongs to
360
+ for i in range(len(ngram_positions) - 1):
361
+ if ngram_positions[i][1] < idx < ngram_positions[i+1][0]:
362
+ if i not in between_ngrams:
363
+ between_ngrams[i] = []
364
+ between_ngrams[i].append(idx)
365
+
366
+ # Before first ngram: select 1 highest entropy words
367
+ if before_first:
368
+ entropies = [(idx, self.calculate_word_entropy(sentence, idx)) for idx in before_first]
369
+ entropies.sort(key=lambda x: x[1], reverse=True) # Sort by entropy (highest first)
370
+ num_to_select = min(1, len(entropies)) # Select 1 word
371
+ selected_indices.extend([idx for idx, _ in entropies[:num_to_select]])
372
+
373
+ # For each gap between ngrams: select 1-2 highest entropy words
374
+ for group, indices in between_ngrams.items():
375
+ if indices:
376
+ entropies = [(idx, self.calculate_word_entropy(sentence, idx)) for idx in indices]
377
+ entropies.sort(key=lambda x: x[1], reverse=True) # Sort by entropy (highest first)
378
+ num_to_select = min(2, len(entropies)) # Select between 1-2 words
379
+ selected_indices.extend([idx for idx, _ in entropies[:num_to_select]])
380
+
381
+ # After last ngram: select 1 highest entropy words
382
+ if after_last:
383
+ entropies = [(idx, self.calculate_word_entropy(sentence, idx)) for idx in after_last]
384
+ entropies.sort(key=lambda x: x[1], reverse=True) # Sort by entropy (highest first)
385
+ num_to_select = min(1, len(entropies)) # Select 1 word
386
+ selected_indices.extend([idx for idx, _ in entropies[:num_to_select]])
387
+ else:
388
+ # If no ngrams, calculate entropy for all candidates
389
+ entropies = [(idx, self.calculate_word_entropy(sentence, idx)) for idx in candidate_indices]
390
+ # Sort by entropy (highest first)
391
+ entropies.sort(key=lambda x: x[1], reverse=True)
392
+ # Take top 6 or all if fewer
393
+ selected_indices = [idx for idx, _ in entropies[:min(6, len(entropies))]]
394
+
395
+ masked_words = original_words.copy()
396
+ for idx in selected_indices:
397
+ masked_words[idx] = self.tokenizer.mask_token
398
+
399
+ if has_punctuation:
400
+ masked_words.append(punctuation)
401
+
402
+ # Debug prints
403
+ print("Original sentence:", sentence)
404
+ print("Common ngrams:", common_ngrams)
405
+ print("Common ngram positions:", ngram_positions)
406
+ print("Candidate indices for masking:", candidate_indices)
407
+ print("Selected for masking:", selected_indices)
408
+ print("Masked sentence:", " ".join(masked_words))
409
+
410
+ return " ".join(masked_words), selected_indices
411
+
412
+ def calculate_mask_logits(self, original_sentence, original_mask_indices):
413
+ """Calculate logits for masked positions."""
414
+ logger.info(f"Calculating mask logits for sentence: {original_sentence}")
415
+ words = original_sentence.split()
416
+ mask_logits = {}
417
+ for idx in original_mask_indices:
418
+ masked_words = words.copy()
419
+ masked_words[idx] = self.tokenizer.mask_token
420
+ masked_sentence = " ".join(masked_words)
421
+ input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"].to(self.device)
422
+ mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
423
+ with torch.no_grad():
424
+ outputs = self.model(input_ids)
425
+ logits = outputs.logits
426
+ mask_logits_tensor = logits[0, mask_token_index, :]
427
+ top_mask_logits, top_mask_indices = torch.topk(mask_logits_tensor, 100, dim=-1)
428
+ top_tokens = []
429
+ top_logits = []
430
+ seen_words = set()
431
+ for token_id, logit in zip(top_mask_indices[0], top_mask_logits[0]):
432
+ token = self.tokenizer.convert_ids_to_tokens(token_id.item())
433
+ if token.startswith('##'):
434
+ continue
435
+ word = self.tokenizer.convert_tokens_to_string([token]).strip()
436
+ if word and word not in seen_words:
437
+ seen_words.add(word)
438
+ top_tokens.append(word)
439
+ top_logits.append(logit.item())
440
+ if len(top_tokens) == 50:
441
+ break
442
+ mask_logits[idx] = {
443
+ "tokens": top_tokens,
444
+ "logits": top_logits
445
+ }
446
+ logger.info("Completed calculating mask logits.")
447
+ return mask_logits
448
+
449
+ def calculate_word_entropy(self, sentence, word_position):
450
+ """Calculate entropy for a word at a specific position."""
451
+ logger.info(f"Calculating word entropy for position {word_position} in sentence: {sentence}")
452
+ words = sentence.split()
453
+ masked_words = words.copy()
454
+ masked_words[word_position] = self.tokenizer.mask_token
455
+ masked_sentence = " ".join(masked_words)
456
+ input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"].to(self.device)
457
+ mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
458
+ with torch.no_grad():
459
+ outputs = self.model(input_ids)
460
+ logits = outputs.logits
461
+ probs = torch.nn.functional.softmax(logits[0, mask_token_index], dim=-1)
462
+ entropy = -torch.sum(probs * torch.log(probs + 1e-9))
463
+ logger.info(f"Computed entropy: {entropy.item()}")
464
+ return entropy.item()
465
+
466
+ def process_sentences(self, sentences_list, common_grams, method="random"):
467
+ """Process multiple sentences with the specified masking method."""
468
+ tqdm.write(f"[MaskingProcessor] Processing sentences using method: {method}")
469
+ results = {}
470
+ for sentence in tqdm(sentences_list, desc="Masking Sentences"):
471
+ try:
472
+ ngrams = common_grams.get(sentence, {})
473
+
474
+ if method == "random":
475
+ masked_sentence, original_mask_indices = self.mask_sentence_random(sentence, ngrams)
476
+ elif method == "pseudorandom":
477
+ masked_sentence, original_mask_indices = self.mask_sentence_pseudorandom(sentence, ngrams)
478
+ else: # entropy
479
+ masked_sentence, original_mask_indices = self.mask_sentence_entropy(sentence, ngrams)
480
+
481
+ # Skip if no masks were applied
482
+ if not original_mask_indices:
483
+ tqdm.write(f"[WARNING] No mask indices found for sentence with method {method}: {sentence[:50]}...")
484
+ # Create a fallback masked sentence with at least one mask
485
+ fallback_result = self.create_fallback_mask(sentence, ngrams)
486
+ if fallback_result:
487
+ masked_sentence, original_mask_indices = fallback_result
488
+ tqdm.write(f"[INFO] Created fallback mask for sentence")
489
+ else:
490
+ tqdm.write(f"[WARNING] Could not create fallback mask, skipping sentence")
491
+ continue
492
+
493
+ logits = self.calculate_mask_logits(sentence, original_mask_indices)
494
+ results[sentence] = {
495
+ "masked_sentence": masked_sentence,
496
+ "mask_logits": logits
497
+ }
498
+ logger.info(f"Processed sentence: {sentence}")
499
+ except Exception as e:
500
+ tqdm.write(f"[ERROR] Failed to process sentence with method {method}: {e}")
501
+ tqdm.write(f"Sentence: {sentence[:100]}...")
502
+ import traceback
503
+ tqdm.write(traceback.format_exc())
504
+ tqdm.write("[MaskingProcessor] Completed processing sentences.")
505
+ return results
506
+
507
+ @staticmethod
508
+ def identify_common_ngrams(sentences, entities):
509
+ """Enhanced to handle possessive forms better"""
510
+ common_grams = {}
511
+
512
+ # Pre-process entities to handle variations
513
+ processed_entities = []
514
+ for entity in entities:
515
+ processed_entities.append(entity)
516
+ # Add possessive form if not already there
517
+ if not entity.endswith("'s") and not entity.endswith("s"):
518
+ processed_entities.append(f"{entity}'s")
519
+
520
+ for sentence in sentences:
521
+ words = sentence.split()
522
+ common_grams[sentence] = {}
523
+
524
+ # Look for each entity in the sentence
525
+ for entity in processed_entities:
526
+ entity_words = entity.split()
527
+ entity_len = len(entity_words)
528
+
529
+ # Convert entity words for matching
530
+ clean_entity_words = [clean_word(word) for word in entity_words]
531
+
532
+ # Find all occurrences
533
+ for i in range(len(words) - entity_len + 1):
534
+ is_match = True
535
+ for j, entity_word in enumerate(clean_entity_words):
536
+ if clean_word(words[i + j]) != entity_word:
537
+ is_match = False
538
+ break
539
+
540
+ if is_match:
541
+ # Use canonical form from entity list for consistency
542
+ base_entity = entity
543
+ if entity.endswith("'s") and any(e == entity[:-2] for e in processed_entities):
544
+ base_entity = entity[:-2]
545
+
546
+ if base_entity not in common_grams[sentence]:
547
+ common_grams[sentence][base_entity] = []
548
+ common_grams[sentence][base_entity].append((i, i + entity_len - 1))
549
+
550
+ return common_grams
551
+ if __name__ == "__main__":
552
+ # Example test sentence
553
+ test_sentence = "Kevin De Bruyne scored for Manchester City as they won the 2019-20 Premier League title."
554
+ # Entities to preserve
555
+ entities = ["Kevin De Bruyne", "Manchester City", "Premier League"]
556
+ # Identify common n-grams
557
+ common_grams = MaskingProcessor.identify_common_ngrams([test_sentence], entities)
558
+
559
+ # Print detected n-grams
560
+ print(f"Detected common n-grams: {common_grams[test_sentence]}")
561
+
562
+ # Initialize the processor
563
+ processor = MaskingProcessor(
564
+ BertTokenizer.from_pretrained("bert-large-cased-whole-word-masking"),
565
+ BertForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
566
+ )
567
+
568
+ # Test all three masking methods
569
+ print("\nTesting Random Masking:")
570
+ masked_random, indices_random = processor.mask_sentence_random(test_sentence, common_grams[test_sentence])
571
+
572
+ print("\nTesting Pseudorandom Masking:")
573
+ masked_pseudorandom, indices_pseudorandom = processor.mask_sentence_pseudorandom(test_sentence, common_grams[test_sentence])
574
+
575
+ print("\nTesting Entropy Masking:")
576
+ masked_entropy, indices_entropy = processor.mask_sentence_entropy(test_sentence, common_grams[test_sentence])
577
+
578
+ # Print results
579
+ print("\nResults:")
580
+ print(f"Original: {test_sentence}")
581
+ print(f"Random Masked: {masked_random}")
582
+ print(f"Pseudorandom Masked: {masked_pseudorandom}")
583
+ print(f"Entropy Masked: {masked_entropy}")
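Note for reviewers: a minimal end-to-end sketch of how this class is intended to be driven from outside the __main__ block above. The import path utils.masking_methods is an assumption based on this commit's file list; the checkpoint name mirrors the example above.

from transformers import BertTokenizer, BertForMaskedLM
from utils.masking_methods import MaskingProcessor  # assumed module path for this file

sentences = ["Kevin De Bruyne scored for Manchester City as they won the title."]
entities = ["Kevin De Bruyne", "Manchester City"]

# Static helper: detect the entity n-grams that must be preserved in each sentence
common_grams = MaskingProcessor.identify_common_ngrams(sentences, entities)

processor = MaskingProcessor(
    BertTokenizer.from_pretrained("bert-large-cased-whole-word-masking"),
    BertForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking"),
)

# Each result holds the masked sentence plus the top candidate tokens/logits
# for every masked position, as produced by calculate_mask_logits()
results = processor.process_sentences(sentences, common_grams, method="entropy")
for sentence, info in results.items():
    print(info["masked_sentence"])
    for idx, cand in info["mask_logits"].items():
        print(idx, cand["tokens"][:5])

The same call works with method="random" or method="pseudorandom"; only the selection of candidate positions to mask changes.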
utils/non_melting_point.py ADDED
@@ -0,0 +1,590 @@
1
+ import nltk
2
+ import logging
3
+ import spacy
4
+ from nltk.corpus import stopwords
5
+ from nltk.util import ngrams
6
+ from collections import Counter
7
+ import re
8
+ from tqdm import tqdm
9
+
10
+ # Logging setup
11
+ logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s")
12
+ logger = logging.getLogger(__name__)
13
+
14
+ class NgramProcessor:
15
+ def __init__(self, models=None):
16
+ try:
17
+ nltk.data.find('corpora/stopwords')
18
+ except LookupError:
19
+ nltk.download('stopwords')
20
+ self.stop_words = set(stopwords.words('english'))
21
+
22
+ # Default to standard model if none specified
23
+ if models is None:
24
+ models = ["en_core_web_trf"]
25
+
26
+ # Load specified model
27
+ self.models = {}
28
+ for model_name in models:
29
+ try:
30
+ self.models[model_name] = spacy.load(model_name)
31
+ tqdm.write(f"[NgramProcessor] Loaded model: {model_name}")
32
+ except IOError:
33
+ tqdm.write(f"[NgramProcessor] Error: Model '{model_name}' not found. Please install it with:")
34
+ tqdm.write(f"python -m spacy download {model_name}")
35
+ except Exception as e:
36
+ tqdm.write(f"[NgramProcessor] Error loading model '{model_name}': {str(e)}")
37
+
38
+ # Set primary NLP model for other processes
39
+ if "en_core_web_trf" in self.models:
40
+ self.nlp = self.models["en_core_web_trf"]
41
+ elif len(self.models) > 0:
42
+ # Use first available model as primary if preferred one isn't available
43
+ self.nlp = next(iter(self.models.values()))
44
+ else:
45
+ raise ValueError("No spaCy model was successfully loaded")
46
+
47
+ # Add custom entity patterns for numerical ranges to primary model
48
+ if "entity_ruler" not in self.nlp.pipe_names:
49
+ ruler = self.nlp.add_pipe("entity_ruler", before="ner")
50
+ patterns = [
51
+ {"label": "CARDINAL", "pattern": [{"TEXT": {"REGEX": "\\d+-\\d+"}}]}, # Pattern for ranges like "7-10"
52
+ {"label": "PERCENT", "pattern": [{"TEXT": {"REGEX": "\\d+%"}}]} # Pattern for percentages
53
+ ]
54
+ ruler.add_patterns(patterns)
55
+
56
+ # Create special pattern for numerical ranges
57
+ self.number_range_pattern = re.compile(r'\b(\d+(?:-\d+)+)\b')
58
+
59
+ tqdm.write("[NgramProcessor] Initialized with stopwords, spaCy NLP model, and numerical range detection.")
60
+
61
+ def remove_stopwords(self, text):
62
+ words = re.findall(r'\w+', text.lower())
63
+ filtered_words = [word for word in words if word not in self.stop_words]
64
+ return ' '.join(filtered_words)
65
+
66
+ def extract_number_ranges(self, sentences):
67
+ """Extract numerical ranges like '7-10' from sentences"""
68
+ tqdm.write("[NgramProcessor] Extracting numerical ranges...")
69
+
70
+ number_ranges = []
71
+ range_counts = Counter()
72
+
73
+ for sentence in sentences:
74
+ # Find all numerical ranges in the sentence
75
+ matches = self.number_range_pattern.findall(sentence)
76
+ for match in matches:
77
+ range_counts[match] += 1
78
+
79
+ # Add all ranges that appear in all sentences (threshold for ranges)
80
+ for range_text, count in range_counts.items():
81
+ if count >= 1:
82
+ number_ranges.append(range_text)
83
+
84
+ tqdm.write(f"[NgramProcessor] Found {len(number_ranges)} numerical ranges: {number_ranges}")
85
+
86
+ return number_ranges
87
+ def extract_standalone_numbers(self, sentences):
88
+ """Extract standalone numerical values from sentences"""
89
+ tqdm.write("[NgramProcessor] Extracting standalone numbers...")
90
+
91
+ # Two patterns: one for percentages, one for regular numbers
92
+ percentage_pattern = re.compile(r'\b\d+%\b') # Only matches numbers with % sign
93
+ number_pattern = re.compile(r'\b\d+\b') # Only matches standalone numbers
94
+
95
+ percentage_counts = Counter()
96
+ number_counts = Counter()
97
+ percentage_values = set() # Store the numeric part of percentages for cross-reference
98
+
99
+ # First pass: Find all percentages
100
+ for sentence in sentences:
101
+ # Extract all percentages first
102
+ percentage_matches = percentage_pattern.findall(sentence)
103
+ for match in percentage_matches:
104
+ percentage_counts[match] += 1
105
+ # Store the numeric part for later comparison
106
+ numeric_part = match.rstrip('%')
107
+ percentage_values.add(numeric_part)
108
+
109
+ # Second pass: Find standalone numbers
110
+ for sentence in sentences:
111
+ # Only look for standalone numbers now
112
+ number_matches = number_pattern.findall(sentence)
113
+ for match in number_matches:
114
+ # Avoid double counting numbers that we already counted as percentages
115
+ if match not in percentage_values:
116
+ number_counts[match] += 1
117
+
118
+ # Process percentages first (they have priority)
119
+ threshold = max(1, int(len(sentences) * 1.0))
120
+ standalone_numbers = []
121
+
122
+ # Add percentages that meet the threshold
123
+ for num, count in percentage_counts.items():
124
+ if count >= threshold:
125
+ standalone_numbers.append(num) # Already has % sign
126
+
127
+ # Then add standalone numbers, converting to percentage format if needed
128
+ for num, count in number_counts.items():
129
+ if count >= threshold:
130
+ # If this number also appeared as part of a percentage, use the percentage format
131
+ if num in percentage_values:
132
+ standalone_numbers.append(f"{num}%")
133
+ else:
134
+ standalone_numbers.append(num)
135
+
136
+ tqdm.write(f"[NgramProcessor] Found {len(standalone_numbers)} standalone numbers: {standalone_numbers}")
137
+ return standalone_numbers
138
+
139
+ def extract_regex_subsequences(self, sentences):
140
+ """Extract potential subsequences using regex patterns before applying NLP"""
141
+ tqdm.write("[NgramProcessor] Extracting regex subsequences...")
142
+
143
+ # Find potential multi-word subsequences (2-5 words) that occur across sentences
144
+ potential_subsequences = set()
145
+
146
+ # Process each sentence to find multi-word phrases
147
+ for sentence in sentences:
148
+ # First, clean the sentence by removing punctuation and converting to lowercase
149
+ clean_sentence = re.sub(r'[^\w\s&./\'()[\]$€£¥+%-]', ' ', sentence.lower())  # '-' placed last so it is literal, not a character range
150
+
151
+ # Extract sequences of 2-6 words
152
+ for i in range(2, 7): # Try sequences of length 2-6 words
153
+ pattern = r'\b(\w+(?:[-&\s./\'()[\]$€£¥+%]+\w+){' + str(i-1) + r'})\b'
154
+ matches = re.findall(pattern, clean_sentence)
155
+ potential_subsequences.update(matches)
156
+
157
+ # Filter out sequences that consist only of stopwords (but preserve numbers)
158
+ filtered_subsequences = []
159
+ for subseq in potential_subsequences:
160
+ words = re.split(r'[\s-]+', subseq) # Split on spaces or hyphens
161
+
162
+ # Function to check if a word is a number or percentage
163
+ def is_numeric(word):
164
+ return bool(re.match(r'^\d+(\.\d+)?%?$|^\d+-\d+$', word))
165
+
166
+ # Skip if ALL words are stopwords and none are numeric
167
+ if all((word in self.stop_words and not is_numeric(word)) for word in words):
168
+ tqdm.write(f"[NgramProcessor] Skipping all-stopword phrase: {subseq}")
169
+ continue
170
+
171
+ # Keep if sequence has significant words (not just stopwords)
172
+ # OR if it contains numbers/percentages
173
+ if len(words) > 1 and (
174
+ any(word not in self.stop_words and (len(word) > 2 or is_numeric(word)) for word in words)
175
+ ):
176
+ # Additional check to reject if standalone "the" or other common stopwords
177
+ if not (len(words) == 1 and words[0] in self.stop_words and not is_numeric(words[0])):
178
+ filtered_subsequences.append(subseq)
179
+
180
+ # Count occurrences across all sentences
181
+ subseq_counts = Counter()
182
+ for subseq in filtered_subsequences:
183
+ for sentence in sentences:
184
+ if re.search(r'\b' + re.escape(subseq) + r'\b', sentence.lower()):
185
+ subseq_counts[subseq] += 1
186
+
187
+ # Keep only subsequences that appear in multiple sentences
188
+ threshold = max(2, int(len(sentences) * 1.0)) # threshold to catch all patterns
189
+ regex_candidates = [subseq for subseq, count in subseq_counts.items()
190
+ if count >= threshold]
191
+
192
+ tqdm.write(f"[NgramProcessor] Found {len(regex_candidates)} regex subsequences")
193
+ return regex_candidates
194
+
195
+ def filter_standalone_stopwords(self, ngrams_dict):
196
+ """Remove standalone stopwords and very short terms from the ngrams dictionary"""
197
+ filtered_dict = {}
198
+ for sentence, ngrams in ngrams_dict.items():
199
+ filtered_dict[sentence] = {}
200
+ for ngram, indices in ngrams.items():
201
+ words = ngram.split()
202
+ # Skip single stopwords and very short terms UNLESS they are numbers
203
+ if (len(words) == 1 and (words[0] in self.stop_words or len(words[0]) < 3)):
204
+ # Exception for numbers
205
+ if len(words) == 1 and re.match(r'^\d+$', words[0]):
206
+ filtered_dict[sentence][ngram] = indices
207
+ continue
208
+ else:
209
+ continue
210
+ # Skip if ALL words are stopwords
211
+ if all(word in self.stop_words for word in words):
212
+ continue
213
+ filtered_dict[sentence][ngram] = indices
214
+ return filtered_dict
215
+
216
+ def extract_named_entities(self, sentences):
217
+ entity_counter = Counter()
218
+
219
+ # Process each sentence with each model
220
+ for model_name, nlp_model in self.models.items():
221
+ tqdm.write(f"[NgramProcessor] Extracting entities with model: {model_name}")
222
+ docs = list(nlp_model.pipe(sentences))
223
+
224
+ # Process each sentence
225
+ for doc in docs:
226
+ for ent in doc.ents:
227
+ # Include entity types relevant to this model
228
+ # This is a comprehensive list - some models may not use all these types
229
+ if ent.label_ in {
230
+ # People, organizations, locations
231
+ "PERSON", "ORG", "GPE", "LOC", "NORP",
232
+
233
+ # Facilities and products
234
+ "FAC", "PRODUCT", "WORK_OF_ART", "EVENT",
235
+
236
+ # Numeric entities
237
+ "DATE", "TIME", "MONEY", "QUANTITY", "PERCENT", "CARDINAL", "ORDINAL",
238
+
239
+ # Others
240
+ "LAW", "LANGUAGE",
241
+
242
+ # Scientific entities
243
+ "SCIENTIFIC", "SUBSTANCE", "CHEMICAL", "TECHNOLOGY",
244
+
245
+ # Medical entities
246
+ "DISEASE", "MEDICAL", "CLINICAL", "TREATMENT", "SYMPTOM", "DIAGNOSTIC",
247
+ "ANATOMICAL", "BIOLOGY", "GENE", "PROTEIN", "DRUG",
248
+
249
+ # Legal entities
250
+ "LEGAL", "COURT", "STATUTE", "PROVISION", "CASE_CITATION", "JUDGE",
251
+ "LEGAL_ROLE", "REGULATION", "CONTRACT"
252
+ }:
253
+ # Handle possessive forms by stripping 's
254
+ clean_entity = re.sub(r"'s\b", "", ent.text.lower()).strip()
255
+ # Add model name prefix to distinguish sources
256
+ entity_counter[clean_entity] += 1
257
+
258
+ threshold = max(1, len(sentences) * 1.0) # Adjusted threshold for entities
259
+ return [ent for ent, count in entity_counter.items() if count >= threshold]
260
+
261
+ def extract_domain_specific_entities(self, text):
262
+ """Extract entities from all models and categorize by domain"""
263
+ domain_entities = {}
264
+
265
+ for model_name, nlp_model in self.models.items():
266
+ doc = nlp_model(text)
267
+ domain_entities[model_name] = [(ent.text, ent.label_) for ent in doc.ents]
268
+
269
+ return domain_entities
270
+
271
+ def is_substring_of_any(self, ngram, common_ngrams):
272
+ for other_ngram in common_ngrams:
273
+ if ngram != other_ngram and ngram in other_ngram:
274
+ return True
275
+ return False
276
+
277
+ def find_filtered_ngrams(self, sentences):
278
+ tqdm.write("[NgramProcessor] Processing...")
279
+
280
+ # Step 1: First extract numerical ranges or standalone numbers (special priority)
281
+ number_ranges = self.extract_number_ranges(sentences)
282
+ standalone_numbers = self.extract_standalone_numbers(sentences)
283
+
284
+ # Step 2: Use regex to find common subsequences
285
+ regex_subsequences = self.extract_regex_subsequences(sentences)
286
+ tqdm.write(f"[NgramProcessor] Regex Subsequences: {regex_subsequences}")
287
+
288
+ # Step 3: Then apply spaCy to detect named entities
289
+ named_entities = self.extract_named_entities(sentences)
290
+ # Make sure percentage values have proper format
291
+ for i, entity in enumerate(named_entities):
292
+ if re.match(r'\d+$', entity) and any(f"{entity}%" in sentence for sentence in sentences):
293
+ # Replace standalone digit with percentage if it appears as percentage in text
294
+ named_entities[i] = f"{entity}%"
295
+
296
+ tqdm.write(f"[NgramProcessor] Named Entities: {named_entities}")
297
+
298
+ # Step 4: Consolidate and filter all detected patterns
299
+ # Collect all patterns in one list
300
+ all_patterns = number_ranges + regex_subsequences + named_entities + standalone_numbers
301
+
302
+ # Sort by length (longer first) to prioritize more specific patterns
303
+ all_patterns.sort(key=len, reverse=True)
304
+
305
+ # Remove duplicates while preserving order
306
+ unique_patterns = []
307
+ seen = set()
308
+ for pattern in all_patterns:
309
+ if pattern not in seen:
310
+ # Check if this pattern is a substring of any already selected pattern
311
+ is_substring = False
312
+ for selected_pattern in unique_patterns:
313
+ if pattern in selected_pattern and pattern != selected_pattern:
314
+ is_substring = True
315
+ break
316
+ if not is_substring:
317
+ unique_patterns.append(pattern)
318
+ seen.add(pattern)
319
+ # Re-index sequentially
320
+ indexed_patterns = [(i+1, pattern) for i, pattern in enumerate(unique_patterns)]
321
+ self.indexed_patterns = indexed_patterns
322
+ non_melting_points = [pattern for _, pattern in indexed_patterns]
323
+ tqdm.write(f"[NgramProcessor] Filtered non_melting_points: {non_melting_points}")
324
+ tqdm.write(f"[NgramProcessor] Filtered non-melting points: {len(non_melting_points)}")
325
+
326
+ # Filter out patterns that are substrings of longer patterns or standalone numbers
327
+ standalone_numbers_set = set(standalone_numbers)
328
+ non_melting_points = []
329
+ for pattern in unique_patterns:
330
+ is_substring = False
331
+ for longer_pattern in non_melting_points:
332
+ # Check if pattern is contained within a longer pattern
333
+ if pattern in longer_pattern:
334
+ is_substring = True
335
+ break
336
+ if not is_substring or pattern in standalone_numbers_set:
337
+ non_melting_points.append(pattern)
338
+
339
+ # For remaining cases that might have been missed, apply NLTK n-gram extraction
340
+ # Only on cleaned sentences (less computationally expensive now)
341
+ clean_to_original = {}
342
+ sentences_cleaned = []
343
+
344
+ # Process sentences with spaCy to preserve entity information
345
+ docs = list(self.nlp.pipe(sentences))
346
+
347
+ for i, doc in enumerate(docs):
348
+ original_sentence = sentences[i]
349
+ entity_texts = {ent.text.lower() for ent in doc.ents if len(ent.text.split()) > 1}
350
+
351
+ # Tokenize while preserving entities and numerical ranges
352
+ tokens = []
353
+ j = 0
354
+ words = [token.text for token in doc]
355
+
356
+ while j < len(words):
357
+ # First check for numerical ranges
358
+ current_word = words[j].lower()
359
+ if self.number_range_pattern.match(current_word):
360
+ tokens.append(current_word)
361
+ j += 1
362
+ continue
363
+
364
+ # Then check for entities
365
+ matched_entity = None
366
+ for ent in sorted(entity_texts, key=len, reverse=True):
367
+ ent_words = ent.split()
368
+ if j + len(ent_words) <= len(words) and [w.lower() for w in words[j:j+len(ent_words)]] == ent_words:
369
+ matched_entity = " ".join(words[j:j+len(ent_words)])
370
+ tokens.append(matched_entity.lower()) # preserve full entity
371
+ j += len(ent_words)
372
+ break
373
+
374
+ if not matched_entity:
375
+ word = words[j].lower()
376
+ if word not in self.stop_words and re.match(r'\w+', word):
377
+ tokens.append(word)
378
+ j += 1
379
+
380
+ cleaned = " ".join(tokens)
381
+ sentences_cleaned.append(cleaned)
382
+ clean_to_original[cleaned] = original_sentence
383
+
384
+ # Step 5: Only run n-gram extraction on gaps not covered by regex and named entities
385
+ ngram_lengths = [4, 3, 2, 1] # Consider shorter n-grams now since we already have longer phrases
386
+ all_ngrams_by_length = {}
387
+
388
+ for n in ngram_lengths:
389
+ all_ngrams = []
390
+ for sentence in sentences_cleaned:
391
+ tokens = sentence.split()
392
+ if len(tokens) >= n:
393
+ sent_ngrams = list(ngrams(tokens, n))
394
+ all_ngrams.extend(sent_ngrams)
395
+ all_ngrams_by_length[n] = Counter(all_ngrams)
396
+
397
+ # Step 6: Add additional n-grams that are frequent but weren't caught by regex or named entities
398
+ threshold_factor = 1.0 # threshold since we're focusing on gaps
399
+
400
+ for n_size in sorted(ngram_lengths, reverse=True):
401
+ ngram_counts = all_ngrams_by_length[n_size]
402
+ threshold = max(2, int(len(sentences) * threshold_factor))
403
+
404
+ # Sort by count for efficiency
405
+ for ngram, count in ngram_counts.most_common():
406
+ if count >= threshold:
407
+ ngram_str = ' '.join(ngram)
408
+ # Skip if is a substring of existing n-grams or already in our collection
409
+ if ngram_str not in non_melting_points and not self.is_substring_of_any(ngram_str, non_melting_points):
410
+ non_melting_points.append(ngram_str)
411
+
412
+ # Create sorted version for efficient lookup
413
+ final_non_melting_points = non_melting_points.copy()
414
+ sorted_non_melting_points = sorted(final_non_melting_points, key=len, reverse=True)
415
+ final_indexed_patterns = [(i+1, pattern) for i, pattern in enumerate(sorted_non_melting_points)]
416
+
417
+ #Filter out n-grams that consist entirely of stop words
418
+ filtered_patterns = []
419
+ for idx, pattern in final_indexed_patterns:
420
+ words = pattern.lower().split()
421
+
422
+ # Check if the pattern is a number or contains a number
423
+ has_number = any(re.match(r'.*\d+.*', word) for word in words)
424
+
425
+ # If the pattern has a number OR has any non-stop word, keep it
426
+ if has_number or any(word not in self.stop_words for word in words):
427
+ filtered_patterns.append((idx, pattern))
428
+ else:
429
+ tqdm.write(f"[NgramProcessor] Removing n-gram with all stop words: {pattern}")
430
+
431
+ # Reassign filtered patterns with reindexed values
432
+ self.indexed_patterns = [(i+1, pattern) for i, (_, pattern) in enumerate(filtered_patterns)]
433
+
434
+ # Generate the results with more efficient regex matching
435
+ result = {}
436
+ for sentence in sentences:
437
+ sentence_result = {}
438
+ for _, ngram in self.indexed_patterns:  # Use the filtered patterns
439
+ # Skip single word stopwords and short terms
440
+ words = ngram.split()
441
+ if len(words) == 1 and (words[0] in self.stop_words or len(words[0]) < 3):
442
+ continue
443
+
444
+ # Handle numerical ranges differently - need exact matching
445
+ if self.number_range_pattern.match(ngram):
446
+ pattern = re.compile(r'\b' + re.escape(ngram) + r'\b', re.IGNORECASE)
447
+ else:
448
+ # Compile the regex pattern once per n-gram - modified to handle special characters
449
+ pattern = re.compile(r'(?<!\w)' + re.escape(ngram) + r'(?!\w)', re.IGNORECASE)
450
+
451
+ matches = list(pattern.finditer(sentence))
452
+
453
+ if matches:
454
+ indices = []
455
+ for match in matches:
456
+ # Calculate word indices with improved handling for hyphenated terms
457
+ start_pos = match.start()
458
+ text_before = sentence[:start_pos]
459
+
460
+ # More accurate word counting that handles hyphenated terms
461
+ start_idx = len(re.findall(r'\s+', text_before)) + (0 if text_before.strip() == "" else 1)
462
+
463
+ # Count words in the matched n-gram (handling hyphens as single terms)
464
+ if self.number_range_pattern.match(ngram):
465
+ # Numerical ranges count as one term
466
+ ngram_word_count = 1
467
+ else:
468
+ ngram_word_count = len(re.findall(r'\S+', ngram))
469
+
470
+ end_idx = start_idx + ngram_word_count - 1
471
+
472
+ indices.append((start_idx, end_idx))
473
+
474
+ if indices: # Only add if we found valid indices
475
+ sentence_result[ngram] = indices
476
+
477
+ result[sentence] = sentence_result
478
+
479
+ # Apply the stopword filter before returning
480
+ result = self.filter_standalone_stopwords(result)
481
+ return result, dict(self.indexed_patterns)
482
+
483
+ def find_relative_order(self, sentence, common_ngrams):
484
+ # First, identify all possible matches without modifying the sentence
485
+ all_matches = []
486
+
487
+ for ngram in common_ngrams:
488
+ # Special handling for percentages
489
+ if any(char in ngram for char in '&-/.\'()[]$€£¥+%'):
490
+ pattern = re.compile(r'\b' + re.escape(ngram) + r'\b', re.IGNORECASE)
491
+ # Handle numerical ranges
492
+ elif self.number_range_pattern.match(ngram):
493
+ pattern = re.compile(r'\b' + re.escape(ngram) + r'\b', re.IGNORECASE)
494
+ else:
495
+ pattern = re.compile(r'(?<!\w)' + re.escape(ngram) + r"(?:'s)?(?!\w)", re.IGNORECASE)
496
+
497
+ for match in pattern.finditer(sentence):
498
+ start, end = match.span()
499
+ #store character position range, ngram text, and token count
500
+ all_matches.append((start, end, ngram, len(ngram.split())))
501
+
502
+ # Pre-process: identify all word spans in the original sentence
503
+ words = []
504
+ word_spans = []
505
+ for match in re.finditer(r'\S+', sentence):
506
+ words.append(match.group())
507
+ word_spans.append((match.start(), match.end()))
508
+
509
+ # Create a mapping from character positions to word indices
510
+ char_to_word_idx = {}
511
+ for i, (start, end) in enumerate(word_spans):
512
+ for pos in range(start, end + 1):
513
+ char_to_word_idx[pos] = i
514
+
515
+ # Sort by length in characters first, then by word count
516
+ all_matches.sort(key=lambda x: (-len(x[2]), -x[3], x[0]))
517
+
518
+ # Filter out ngrams that overlap with already claimed ranges
519
+ filtered_matches = []
520
+ claimed_ranges = []
521
+
522
+ for start, end, ngram, length in all_matches:
523
+ # Check if this match overlaps with any existing claimed range
524
+ is_overlapping = False
525
+ for c_start, c_end in claimed_ranges:
526
+ # Check for any overlap
527
+ if max(start, c_start) < min(end, c_end):
528
+ is_overlapping = True
529
+ break
530
+
531
+ if not is_overlapping:
532
+ # Add this ngram to our filtered list
533
+ filtered_matches.append((start, end, ngram, length))
534
+ # Claim its range
535
+ claimed_ranges.append((start, end))
536
+
537
+ # Sort filtered matches by position for final ordering
538
+ filtered_matches.sort(key=lambda x: x[0])
539
+
540
+ # Create word-level indices for the final matches
541
+ word_level_matches = []
542
+ for start, end, ngram, _ in filtered_matches:
543
+ # Find the word index for the start and end positions
544
+ try:
545
+ start_word_idx = char_to_word_idx.get(start, char_to_word_idx.get(start+1))
546
+ end_word_idx = char_to_word_idx.get(end-1, char_to_word_idx.get(end-2))
547
+
548
+ if start_word_idx is not None and end_word_idx is not None:
549
+ word_level_matches.append((start_word_idx, end_word_idx, ngram))
550
+ except (KeyError, IndexError):
551
+ # Skip this match if we can't determine word indices
552
+ continue
553
+
554
+ # Create the final order with 1-based indexing
555
+ ngram_to_index = {pattern: idx for idx, pattern in self.indexed_patterns}
556
+ relative_order = [(ngram_to_index.get(ngram, i+1), ngram) for i, (_, _, ngram) in enumerate(word_level_matches)]
557
+
558
+ return relative_order, sentence
559
+
560
+
561
+ # Example usage
562
+ if __name__ == "__main__":
563
+ # Test with NBA Play-In Tournament example
564
+ sentences = [
565
+ "The NBA Play-In Tournament tips off tonight as the No. 7-10 teams in each conference battle for a spot in the playoffs. Here's everything you need to know as the action unfolds.",
566
+ "Tonight the NBA Play-In Tournament begins with No. 7-10 teams from each conference competing for playoff spots. Here's your guide to following all the action.",
567
+ "The NBA Play-In Tournament kicks off this evening featuring the No. 7-10 teams across both conferences fighting for playoff positions. Here's what you should know about the upcoming games.",
568
+ "Starting tonight, the NBA Play-In Tournament will showcase the No. 7-10 teams from each conference as they compete for remaining playoff berths. Here's your complete guide to the action.",
569
+ "The NBA Play-In Tournament begins tonight with the No. 7-10 teams in both conferences battling for playoff spots. Here's everything you need to know about the upcoming games.",
570
+ "Tonight marks the start of the NBA Play-In Tournament where No. 7-10 teams in each conference compete for playoff positions. Here's your essential guide to following the action.",
571
+ "The NBA Play-In Tournament tips off tonight, featuring No. 7-10 teams from both conferences fighting for playoff berths. Here's what you need to know about the tournament.",
572
+ "Beginning tonight, the NBA Play-In Tournament will pit the No. 7-10 teams in each conference against each other for playoff spots. Here's everything you should know about the games.",
573
+ "The NBA Play-In Tournament starts tonight with No. 7-10 teams across both conferences competing for playoff positions. Here's your complete guide to all the action.",
574
+ "Tonight is the tip-off of the NBA Play-In Tournament where the No. 7-10 teams from each conference battle for remaining playoff spots. Here's what you need to know as the games unfold."
575
+ ]
576
+
577
+ # Initialize with multiple models
578
+ processor = NgramProcessor(models=["en_core_web_trf"])
579
+
580
+ # Process with all models combined
581
+ common_ngrams,indexed_ngrams = processor.find_filtered_ngrams(sentences)
582
+
583
+
584
+ # Print results
585
+ print("Common n-grams with indices per sentence:")
586
+ for sentence in sentences:
587
+ order, updated_sentence = processor.find_relative_order(sentence, common_ngrams[sentence])
588
+ print(f"Sentence: {sentence}")
589
+ print(f"Order: {order}")
590
+ print()
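A short sketch (not part of the diff) of how the two utilities in this commit are presumably chained: the per-sentence n-gram positions returned by find_filtered_ngrams have the same shape as the common_grams dict that MaskingProcessor.process_sentences expects. Import paths are assumptions based on the commit's file list.

from transformers import BertTokenizer, BertForMaskedLM
from utils.non_melting_point import NgramProcessor   # assumed module path
from utils.masking_methods import MaskingProcessor   # assumed module path

sentences = [
    "The NBA Play-In Tournament tips off tonight.",
    "Tonight the NBA Play-In Tournament begins.",
]

ngram_processor = NgramProcessor(models=["en_core_web_trf"])
common_ngrams, indexed_ngrams = ngram_processor.find_filtered_ngrams(sentences)
# common_ngrams maps each sentence to {ngram: [(start_word_idx, end_word_idx), ...]},
# e.g. roughly {"The NBA Play-In Tournament tips off tonight.": {"nba play-in tournament": [(1, 3)]}}
# indexed_ngrams maps a 1-based index to each retained pattern.

masker = MaskingProcessor(
    BertTokenizer.from_pretrained("bert-large-cased-whole-word-masking"),
    BertForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking"),
)
results = masker.process_sentences(sentences, common_ngrams, method="random")

find_relative_order(sentence, common_ngrams[sentence]) can then recover the 1-based order of the preserved patterns in each sentence, as in the __main__ example above.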
utils/old/masking/masking_methods.py ADDED
@@ -0,0 +1,355 @@
1
+ import random
2
+ import torch
3
+ from transformers import BertTokenizer, BertForMaskedLM
4
+ from nltk.corpus import stopwords
5
+ import nltk
6
+
7
+ # Ensure stopwords are downloaded
8
+ try:
9
+ nltk.data.find('corpora/stopwords')
10
+ except LookupError:
11
+ nltk.download('stopwords')
12
+
13
+ class MaskingProcessor:
14
+ def __init__(self, ):
15
+ self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
16
+ self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
17
+ self.stop_words = set(stopwords.words('english'))
18
+
19
+ def adjust_ngram_indices(self, words, common_ngrams, remove_stopwords):
20
+ """
21
+ Adjust indices of common n-grams after removing stop words.
22
+
23
+ Args:
24
+ words (list): List of words in the original sentence.
25
+ common_ngrams (dict): Common n-grams and their indices.
26
+
27
+ Returns:
28
+ dict: Adjusted common n-grams and their indices.
29
+ """
30
+ if not remove_stopwords:
31
+ return common_ngrams
32
+
33
+ non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
34
+ adjusted_ngrams = {}
35
+
36
+ for ngram, positions in common_ngrams.items():
37
+ adjusted_positions = []
38
+ for start, end in positions:
39
+ try:
40
+ new_start = non_stop_word_indices.index(start)
41
+ new_end = non_stop_word_indices.index(end)
42
+ adjusted_positions.append((new_start, new_end))
43
+ except ValueError:
44
+ continue # Skip if indices cannot be mapped
45
+ adjusted_ngrams[ngram] = adjusted_positions
46
+
47
+ return adjusted_ngrams
48
+
49
+ # def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords):
50
+ # """
51
+ # Mask one word before the first common n-gram, one between two n-grams,
52
+ # and one after the last common n-gram (random selection).
53
+
54
+ # Args:
55
+ # original_sentence (str): Original sentence
56
+ # common_ngrams (dict): Common n-grams and their indices
57
+
58
+ # Returns:
59
+ # str: Masked sentence with original stop words retained
60
+ # """
61
+ # words = original_sentence.split()
62
+ # if remove_stopwords:
63
+ # non_stop_words = [word for word in words if word.lower() not in self.stop_words]
64
+ # non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
65
+ # else:
66
+ # non_stop_words = words
67
+ # non_stop_word_indices = list(range(len(words)))
68
+ # # non_stop_words = [word for word in words if word.lower() not in self.stop_words] if remove_stopwords else words
69
+ # adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)
70
+
71
+ # mask_indices = []
72
+ # # Handle before the first common n-gram
73
+ # if adjusted_ngrams:
74
+ # first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
75
+ # if first_ngram_start > 0:
76
+ # mask_indices.append(random.randint(0, first_ngram_start - 1))
77
+
78
+ # # Handle between common n-grams
79
+ # ngram_positions = list(adjusted_ngrams.values())
80
+ # for i in range(len(ngram_positions) - 1):
81
+ # end_prev = ngram_positions[i][-1][1]
82
+ # start_next = ngram_positions[i + 1][0][0]
83
+ # if start_next > end_prev + 1:
84
+ # mask_indices.append(random.randint(end_prev + 1, start_next - 1))
85
+
86
+ # # Handle after the last common n-gram
87
+ # last_ngram_end = ngram_positions[-1][-1][1]
88
+ # if last_ngram_end < len(non_stop_words) - 1:
89
+ # mask_indices.append(random.randint(last_ngram_end + 1, len(non_stop_words) - 1))
90
+
91
+ # # Mask the chosen indices
92
+ # original_masked_sentence = words[:]
93
+ # # for idx in mask_indices:
94
+ # # if idx not in [index for ngram_indices in adjusted_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
95
+ # # non_stop_words[idx] = self.tokenizer.mask_token
96
+ # # original_masked_sentence[idx] = self.tokenizer.mask_token
97
+ # for idx in mask_indices:
98
+ # if idx in [index for ngram_indices in adjusted_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
99
+ # continue # Skip if index belongs to common n-grams
100
+ # if remove_stopwords:
101
+ # original_idx = non_stop_word_indices[idx] # Map back to original indices
102
+ # original_masked_sentence[original_idx] = self.tokenizer.mask_token
103
+ # else:
104
+ # original_masked_sentence[idx] = self.tokenizer.mask_token
105
+
106
+
107
+ # return " ".join(original_masked_sentence)
108
+ def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords):
109
+ """
110
+ Mask one word before the first common n-gram, one between two n-grams,
111
+ and one after the last common n-gram (random selection).
112
+
113
+ Args:
114
+ original_sentence (str): Original sentence
115
+ common_ngrams (dict): Common n-grams and their indices
116
+ remove_stopwords (bool): Whether to remove stop words
117
+
118
+ Returns:
119
+ str: Masked sentence with original stop words retained
120
+ """
121
+ words = original_sentence.split()
122
+ if remove_stopwords:
123
+ non_stop_words = [word for word in words if word.lower() not in self.stop_words]
124
+ non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
125
+ else:
126
+ non_stop_words = words
127
+ non_stop_word_indices = list(range(len(words)))
128
+
129
+ adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)
130
+
131
+ # Collect all indices corresponding to common n-grams
132
+ common_ngram_indices = {
133
+ idx for ngram_positions in adjusted_ngrams.values()
134
+ for start, end in ngram_positions
135
+ for idx in range(start, end + 1)
136
+ }
137
+
138
+ mask_indices = []
139
+ # Handle before the first common n-gram
140
+ if adjusted_ngrams:
141
+ first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
142
+ if first_ngram_start > 0:
143
+ potential_indices = [i for i in range(first_ngram_start) if i not in common_ngram_indices]
144
+ if potential_indices:
145
+ mask_indices.append(random.choice(potential_indices))
146
+
147
+ # Handle between common n-grams
148
+ ngram_positions = list(adjusted_ngrams.values())
149
+ for i in range(len(ngram_positions) - 1):
150
+ end_prev = ngram_positions[i][-1][1]
151
+ start_next = ngram_positions[i + 1][0][0]
152
+ potential_indices = [i for i in range(end_prev + 1, start_next) if i not in common_ngram_indices]
153
+ if potential_indices:
154
+ mask_indices.append(random.choice(potential_indices))
155
+
156
+ # Handle after the last common n-gram
157
+ last_ngram_end = ngram_positions[-1][-1][1]
158
+ if last_ngram_end < len(non_stop_words) - 1:
159
+ potential_indices = [i for i in range(last_ngram_end + 1, len(non_stop_words)) if i not in common_ngram_indices]
160
+ if potential_indices:
161
+ mask_indices.append(random.choice(potential_indices))
162
+
163
+ # Mask the chosen indices
164
+ original_masked_sentence = words[:]
165
+ for idx in mask_indices:
166
+ if remove_stopwords:
167
+ original_idx = non_stop_word_indices[idx] # Map back to original indices
168
+ original_masked_sentence[original_idx] = self.tokenizer.mask_token
169
+ else:
170
+ original_masked_sentence[idx] = self.tokenizer.mask_token
171
+
172
+ return " ".join(original_masked_sentence)
173
+
174
+ def mask_sentence_entropy(self, original_sentence, common_ngrams, remove_stopwords):
175
+ """
176
+ Mask one word before the first common n-gram, one between two n-grams,
177
+ and one after the last common n-gram (highest entropy selection).
178
+
179
+ Args:
180
+ original_sentence (str): Original sentence
181
+ common_ngrams (dict): Common n-grams and their indices
182
+
183
+ Returns:
184
+ str: Masked sentence with original stop words retained
185
+ """
186
+ words = original_sentence.split()
187
+ # non_stop_words = [word for word in words if word.lower() not in self.stop_words] if remove_stopwords else words
188
+ if remove_stopwords:
189
+ non_stop_words = [word for word in words if word.lower() not in self.stop_words]
190
+ non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
191
+ else:
192
+ non_stop_words = words
193
+ non_stop_word_indices = list(range(len(words)))
194
+ adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)
195
+ entropy_scores = {}
196
+
197
+ for idx, word in enumerate(non_stop_words):
198
+ if idx in [index for ngram_indices in adjusted_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
199
+ continue # Skip words in common n-grams
200
+
201
+ masked_sentence = non_stop_words[:idx] + [self.tokenizer.mask_token] + non_stop_words[idx + 1:]
202
+ masked_sentence = " ".join(masked_sentence)
203
+ input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
204
+ mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
205
+
206
+ with torch.no_grad():
207
+ outputs = self.model(input_ids)
208
+ logits = outputs.logits
209
+
210
+ filtered_logits = logits[0, mask_token_index, :]
211
+ probs = torch.softmax(filtered_logits, dim=-1)
212
+ entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item() # Add epsilon to prevent log(0)
213
+ entropy_scores[idx] = entropy
214
+
215
+ mask_indices = []
216
+
217
+ # Handle before the first common n-gram
218
+ if adjusted_ngrams:
219
+ first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
220
+ candidates = [i for i in range(first_ngram_start) if i in entropy_scores]
221
+ if candidates:
222
+ mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))
223
+
224
+ # Handle between common n-grams
225
+ ngram_positions = list(adjusted_ngrams.values())
226
+ for i in range(len(ngram_positions) - 1):
227
+ end_prev = ngram_positions[i][-1][1]
228
+ start_next = ngram_positions[i + 1][0][0]
229
+ candidates = [i for i in range(end_prev + 1, start_next) if i in entropy_scores]
230
+ if candidates:
231
+ mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))
232
+
233
+ # Handle after the last common n-gram
234
+ last_ngram_end = ngram_positions[-1][-1][1]
235
+ candidates = [i for i in range(last_ngram_end + 1, len(non_stop_words)) if i in entropy_scores]
236
+ if candidates:
237
+ mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))
238
+
239
+ # Mask the chosen indices
240
+ original_masked_sentence = words[:]
241
+ # for idx in mask_indices:
242
+ # non_stop_words[idx] = self.tokenizer.mask_token
243
+ # original_masked_sentence[idx] = self.tokenizer.mask_token
244
+
245
+ for idx in mask_indices:
246
+ if idx in [index for ngram_indices in adjusted_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
247
+ continue # Skip if index belongs to common n-grams
248
+ if remove_stopwords:
249
+ original_idx = non_stop_word_indices[idx] # Map back to original indices
250
+ original_masked_sentence[original_idx] = self.tokenizer.mask_token
251
+ else:
252
+ original_masked_sentence[idx] = self.tokenizer.mask_token
253
+
254
+
255
+ return " ".join(original_masked_sentence)
256
+
257
+ def calculate_mask_logits(self, masked_sentence):
258
+ """
259
+ Calculate logits for masked tokens in the sentence using BERT.
260
+
261
+ Args:
262
+ masked_sentence (str): Sentence with [MASK] tokens
263
+
264
+ Returns:
265
+ dict: Masked token indices and their logits
266
+ """
267
+ input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
268
+ mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
269
+
270
+ with torch.no_grad():
271
+ outputs = self.model(input_ids)
272
+ logits = outputs.logits
273
+
274
+ mask_logits = {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}
275
+ return mask_logits
276
+
277
+ def process_sentences(self, original_sentences, result_dict, method="random", remove_stopwords=False):
278
+ """
279
+ Process a list of sentences and calculate logits for masked tokens using the specified method.
280
+
281
+ Args:
282
+ original_sentences (list): List of original sentences
283
+ result_dict (dict): Common n-grams and their indices for each sentence
284
+ method (str): Masking method ("random" or "entropy")
285
+
286
+ Returns:
287
+ dict: Masked sentences and their logits for each sentence
288
+ """
289
+ results = {}
290
+
291
+ for sentence, ngrams in result_dict.items():
292
+ if method == "random":
293
+ masked_sentence = self.mask_sentence_random(sentence, ngrams, remove_stopwords)
294
+ elif method == "entropy":
295
+ masked_sentence = self.mask_sentence_entropy(sentence, ngrams, remove_stopwords)
296
+ else:
297
+ raise ValueError("Invalid method. Choose 'random' or 'entropy'.")
298
+
299
+ logits = self.calculate_mask_logits(masked_sentence)
300
+ results[sentence] = {
301
+ "masked_sentence": masked_sentence,
302
+ "mask_logits": logits
303
+ }
304
+
305
+ return results
306
+
307
+ # Example usage
308
+ if __name__ == "__main__":
309
+ # !!! Works in both cases, whether stopwords are removed or not
310
+ sentences = [
311
+ "The quick brown fox jumps over the lazy dog.",
312
+ "A speedy brown fox jumps over a lazy dog.",
313
+ "A swift brown fox leaps over the lethargic dog."
314
+ ]
315
+ result_dict ={
316
+ 'The quick brown fox jumps over the lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
317
+ 'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
318
+ 'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}
319
+ }
320
+
321
+
322
+ processor = MaskingProcessor()
323
+ results_random = processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=True)
324
+ # results_entropy = processor.process_sentences(sentences, result_dict, method="entropy", remove_stopwords=False)
325
+
326
+ for sentence, output in results_random.items():
327
+ print(f"Original Sentence (Random): {sentence}")
328
+ print(f"Masked Sentence (Random): {output['masked_sentence']}")
329
+ # # print(f"Mask Logits (Random): {output['mask_logits']}")
330
+ # print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
331
+ # print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
332
+ # print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
333
+ print('--------------------------------')
334
+ # for mask_idx, logits in output["mask_logits"].items():
335
+ # print(f"Logits for [MASK] at position {mask_idx}:")
336
+ # print(f' logits : {logits[:5]}') # List of logits for all vocabulary tokens
337
+
338
+
339
+
340
+
341
+ # result_dict = {
342
+ # "The quick brown fox jumps over the lazy dog.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
343
+ # "A quick brown dog outpaces a lazy fox.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
344
+ # "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]}
345
+ # }
346
+
347
+
348
+ # print('--------------------------------')
349
+ # for sentence, output in results_entropy.items():
350
+ # print(f"Original Sentence (Entropy): {sentence}")
351
+ # print(f"Masked Sentence (Entropy): {output['masked_sentence']}")
352
+ # # print(f"Mask Logits (Entropy): {output['mask_logits']}")
353
+ # print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
354
+ # print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
355
+ # print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
utils/old/masking/masking_methods_new_work.py ADDED
@@ -0,0 +1,447 @@
1
+ import random
2
+ import torch
3
+ from transformers import BertTokenizer, BertForMaskedLM
4
+ from nltk.corpus import stopwords
5
+ import nltk
6
+
7
+ # Ensure stopwords are downloaded
8
+ try:
9
+ nltk.data.find('corpora/stopwords')
10
+ except LookupError:
11
+ nltk.download('stopwords')
12
+
13
+ class MaskingProcessor:
14
+ def __init__(self):
15
+
16
+ self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
17
+ self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
18
+ self.stop_words = set(stopwords.words('english'))
19
+
20
+ def remove_stopwords(self, words):
21
+ """
22
+ Remove stopwords from the given list of words.
23
+
24
+ Args:
25
+ words (list): List of words.
26
+
27
+ Returns:
28
+ list: List of non-stop words.
29
+ """
30
+ return [word for word in words if word.lower() not in self.stop_words]
31
+
32
+ def adjust_ngram_indices(self, original_words, common_ngrams):
33
+ """
34
+ Adjust indices of common n-grams after removing stopwords.
35
+
36
+ Args:
37
+ original_words (list): Original list of words.
38
+ common_ngrams (dict): Common n-grams and their indices.
39
+
40
+ Returns:
41
+ dict: Adjusted common n-grams with updated indices.
42
+ """
43
+ non_stop_words = self.remove_stopwords(original_words)
44
+ original_to_non_stop = []
45
+ non_stop_idx = 0
46
+
47
+ for original_idx, word in enumerate(original_words):
48
+ if word.lower() not in self.stop_words:
49
+ original_to_non_stop.append((original_idx, non_stop_idx))
50
+ non_stop_idx += 1
51
+
52
+ adjusted_ngrams = {}
53
+ for ngram, positions in common_ngrams.items():
54
+ adjusted_positions = []
55
+ for start, end in positions:
56
+ try:
57
+ new_start = next(non_stop for orig, non_stop in original_to_non_stop if orig == start)
58
+ new_end = next(non_stop for orig, non_stop in original_to_non_stop if orig == end)
59
+ adjusted_positions.append((new_start, new_end))
60
+ except StopIteration:
61
+ continue # Skip if indices cannot be mapped
62
+ adjusted_ngrams[ngram] = adjusted_positions
63
+
64
+ return adjusted_ngrams
65
+
66
+ def mask_sentence_random(self, sentence, common_ngrams):
67
+ """
68
+ Mask words in the sentence based on the specified rules after removing stopwords.
69
+ """
70
+ original_words = sentence.split()
71
+ print(f' ---- original_words : {original_words} ----- ')
72
+ non_stop_words = self.remove_stopwords(original_words)
73
+ print(f' ---- non_stop_words : {non_stop_words} ----- ')
74
+ adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
75
+ print(f' ---- common_ngrams : {common_ngrams} ----- ')
76
+ print(f' ---- adjusted_ngrams : {adjusted_ngrams} ----- ')
77
+
78
+ mask_indices = []
79
+
80
+ # Extract n-gram positions in non-stop words
81
+ ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
82
+
83
+ # Mask a word before the first common n-gram
84
+ if ngram_positions:
85
+ print(f' ---- ngram_positions : {ngram_positions} ----- ')
86
+ first_ngram_start = ngram_positions[0][0]
87
+ print(f' ---- first_ngram_start : {first_ngram_start} ----- ')
88
+ if first_ngram_start > 0:
89
+ mask_index_before_ngram = random.randint(0, first_ngram_start-1)
90
+ print(f' ---- mask_index_before_ngram : {mask_index_before_ngram} ----- ')
91
+ mask_indices.append(mask_index_before_ngram)
92
+
93
+ # Mask words between common n-grams
94
+ for i in range(len(ngram_positions) - 1):
95
+ end_prev = ngram_positions[i][1]
96
+ print(f' ---- end_prev : {end_prev} ----- ') # end index of the previous n-gram
97
+ start_next = ngram_positions[i + 1][0]
98
+ print(f' ---- start_next : {start_next} ----- ')
99
+ if start_next > end_prev + 1:
100
+ mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
101
+ print(f' ---- mask_index_between_ngrams : {mask_index_between_ngrams} ----- ')
102
+ mask_indices.append(mask_index_between_ngrams)
103
+
104
+ # Mask a word after the last common n-gram
105
+ last_ngram_end = ngram_positions[-1][1]
106
+ if last_ngram_end < len(non_stop_words) - 1:
107
+ print(f' ---- last_ngram_end : {last_ngram_end} ----- ')
108
+ mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
109
+ print(f' ---- mask_index_after_ngram : {mask_index_after_ngram} ----- ')
110
+ mask_indices.append(mask_index_after_ngram)
111
+
112
+ # Create mapping from non-stop words to original indices
113
+ non_stop_to_original = {}
114
+ non_stop_idx = 0
115
+ for orig_idx, word in enumerate(original_words):
116
+ if word.lower() not in self.stop_words:
117
+ non_stop_to_original[non_stop_idx] = orig_idx
118
+ non_stop_idx += 1
119
+
120
+ # Map mask indices from non-stop word positions to original positions
121
+ print(f' ---- non_stop_to_original : {non_stop_to_original} ----- ')
122
+ original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
123
+ print(f' ---- original_mask_indices : {original_mask_indices} ----- ')
124
+
125
+ # Apply masks to the original sentence
126
+ masked_words = original_words.copy()
127
+ for idx in original_mask_indices:
128
+ masked_words[idx] = self.tokenizer.mask_token
129
+
130
+ return " ".join(masked_words)
131
+
132
+ def mask_sentence_pseudorandom(self, sentence, common_ngrams):
133
+ """
134
+ Mask words in the sentence based on the specified rules after removing stopwords.
135
+ """
136
+ random.seed(42)
137
+ original_words = sentence.split()
138
+ print(f' ---- original_words : {original_words} ----- ')
139
+ non_stop_words = self.remove_stopwords(original_words)
140
+ print(f' ---- non_stop_words : {non_stop_words} ----- ')
141
+ adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
142
+ print(f' ---- common_ngrams : {common_ngrams} ----- ')
143
+ print(f' ---- adjusted_ngrams : {adjusted_ngrams} ----- ')
144
+
145
+ mask_indices = []
146
+
147
+ # Extract n-gram positions in non-stop words
148
+ ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
149
+
150
+ # Mask a word before the first common n-gram
151
+ if ngram_positions:
152
+ print(f' ---- ngram_positions : {ngram_positions} ----- ')
153
+ first_ngram_start = ngram_positions[0][0]
154
+ print(f' ---- first_ngram_start : {first_ngram_start} ----- ')
155
+ if first_ngram_start > 0:
156
+ mask_index_before_ngram = random.randint(0, first_ngram_start-1)
157
+ print(f' ---- mask_index_before_ngram : {mask_index_before_ngram} ----- ')
158
+ mask_indices.append(mask_index_before_ngram)
159
+
160
+ # Mask words between common n-grams
161
+ for i in range(len(ngram_positions) - 1):
162
+ end_prev = ngram_positions[i][1]
163
+ print(f' ---- end_prev : {end_prev} ----- ')
164
+ start_next = ngram_positions[i + 1][0]
165
+ print(f' ---- start_next : {start_next} ----- ')
166
+ if start_next > end_prev + 1:
167
+ mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
168
+ print(f' ---- mask_index_between_ngrams : {mask_index_between_ngrams} ----- ')
169
+ mask_indices.append(mask_index_between_ngrams)
170
+
171
+ # Mask a word after the last common n-gram
172
+ last_ngram_end = ngram_positions[-1][1]
173
+ if last_ngram_end < len(non_stop_words) - 1:
174
+ print(f' ---- last_ngram_end : {last_ngram_end} ----- ')
175
+ mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
176
+ print(f' ---- mask_index_after_ngram : {mask_index_after_ngram} ----- ')
177
+ mask_indices.append(mask_index_after_ngram)
178
+
179
+ # Create mapping from non-stop words to original indices
180
+ non_stop_to_original = {}
181
+ non_stop_idx = 0
182
+ for orig_idx, word in enumerate(original_words):
183
+ if word.lower() not in self.stop_words:
184
+ non_stop_to_original[non_stop_idx] = orig_idx
185
+ non_stop_idx += 1
186
+
187
+ # Map mask indices from non-stop word positions to original positions
188
+ print(f' ---- non_stop_to_original : {non_stop_to_original} ----- ')
189
+ original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
190
+ print(f' ---- original_mask_indices : {original_mask_indices} ----- ')
191
+
192
+ # Apply masks to the original sentence
193
+ masked_words = original_words.copy()
194
+ for idx in original_mask_indices:
195
+ masked_words[idx] = self.tokenizer.mask_token
196
+
197
+ return " ".join(masked_words)
198
+
199
+
200
+ def calculate_word_entropy(self, sentence, word_position):
201
+ """
202
+ Calculate entropy for a specific word position in the sentence.
203
+
204
+ Args:
205
+ sentence (str): The input sentence
206
+ word_position (int): Position of the word to calculate entropy for
207
+
208
+ Returns:
209
+ float: Entropy value for the word
210
+ """
211
+ words = sentence.split()
212
+ masked_words = words.copy()
213
+ masked_words[word_position] = self.tokenizer.mask_token
214
+ masked_sentence = " ".join(masked_words)
215
+
216
+ input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
217
+ mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
218
+
219
+ with torch.no_grad():
220
+ outputs = self.model(input_ids)
221
+ logits = outputs.logits
222
+
223
+ # Get probabilities for the masked position
224
+ probs = torch.nn.functional.softmax(logits[0, mask_token_index], dim=-1)
225
+ # Calculate entropy: -sum(p * log(p))
226
+ entropy = -torch.sum(probs * torch.log(probs + 1e-9))
227
+
228
+ return entropy.item()
229
+
230
+ def mask_sentence_entropy(self, sentence, common_ngrams):
231
+ """
232
+ Mask words in the sentence based on entropy, following n-gram positioning rules.
233
+
234
+ Args:
235
+ sentence (str): Original sentence
236
+ common_ngrams (dict): Common n-grams and their indices
237
+
238
+ Returns:
239
+ str: Masked sentence
240
+ """
241
+ original_words = sentence.split()
242
+ non_stop_words = self.remove_stopwords(original_words)
243
+ adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
244
+
245
+ # Create mapping from non-stop words to original indices
246
+ non_stop_to_original = {}
247
+ original_to_non_stop = {}
248
+ non_stop_idx = 0
249
+ for orig_idx, word in enumerate(original_words):
250
+ if word.lower() not in self.stop_words:
251
+ non_stop_to_original[non_stop_idx] = orig_idx
252
+ original_to_non_stop[orig_idx] = non_stop_idx
253
+ non_stop_idx += 1
254
+
255
+ ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
256
+ mask_indices = []
257
+
258
+ if ngram_positions:
259
+ # Handle words before first n-gram
260
+ first_ngram_start = ngram_positions[0][0]
261
+ if first_ngram_start > 0:
262
+ # Calculate entropy for all candidate positions
263
+ candidate_positions = range(0, first_ngram_start)
264
+ entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
265
+ for pos in candidate_positions]
266
+ # Select position with highest entropy
267
+ mask_indices.append(max(entropies, key=lambda x: x[1])[0])
268
+
269
+ # Handle words between n-grams
270
+ for i in range(len(ngram_positions) - 1):
271
+ end_prev = ngram_positions[i][1]
272
+ start_next = ngram_positions[i + 1][0]
273
+ if start_next > end_prev + 1:
274
+ candidate_positions = range(end_prev + 1, start_next)
275
+ entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
276
+ for pos in candidate_positions]
277
+ mask_indices.append(max(entropies, key=lambda x: x[1])[0])
278
+
279
+ # Handle words after last n-gram
280
+ last_ngram_end = ngram_positions[-1][1]
281
+ if last_ngram_end < len(non_stop_words) - 1:
282
+ candidate_positions = range(last_ngram_end + 1, len(non_stop_words))
283
+ entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
284
+ for pos in candidate_positions]
285
+ mask_indices.append(max(entropies, key=lambda x: x[1])[0])
286
+
287
+ # Map mask indices to original sentence positions and apply masks
288
+ original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
289
+ masked_words = original_words.copy()
290
+ for idx in original_mask_indices:
291
+ masked_words[idx] = self.tokenizer.mask_token
292
+
293
+ return " ".join(masked_words)
294
+
295
+
296
+ def calculate_mask_logits(self, masked_sentence):
297
+ """
298
+ Calculate logits for masked tokens in the sentence using BERT.
299
+
300
+ Args:
301
+ masked_sentence (str): Sentence with [MASK] tokens.
302
+
303
+ Returns:
304
+ dict: Masked token indices and their logits.
305
+ """
306
+ input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
307
+ mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
308
+
309
+ with torch.no_grad():
310
+ outputs = self.model(input_ids)
311
+ logits = outputs.logits
312
+
313
+ mask_logits = {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}
314
+ return mask_logits
315
+
316
+ def process_sentences(self, sentences, result_dict, method="random"):
317
+ """
318
+ Process sentences and calculate logits for masked tokens.
319
+
320
+ Args:
321
+ sentences (list): List of sentences
322
+ result_dict (dict): Dictionary of common n-grams
323
+ method (str): Masking method ("random" or "entropy")
324
+
325
+ Returns:
326
+ dict: Masked sentences and logits for each sentence
327
+ """
328
+ results = {}
329
+
330
+ for sentence, ngrams in result_dict.items():
331
+ if method == "random":
332
+ masked_sentence = self.mask_sentence_random(sentence, ngrams)
333
+ elif method == "pseudorandom":
334
+ masked_sentence = self.mask_sentence_pseudorandom(sentence, ngrams)
335
+ else: # entropy
336
+ masked_sentence = self.mask_sentence_entropy(sentence, ngrams)
337
+
338
+ logits = self.calculate_mask_logits(masked_sentence)
339
+ results[sentence] = {
340
+ "masked_sentence": masked_sentence,
341
+ "mask_logits": logits
342
+ }
343
+
344
+ return results
345
+
346
+
347
+
348
+ if __name__ == "__main__":
349
+ # !!! Works in both cases, whether stopwords are removed or not
350
+ sentences = [
351
+ "The quick brown fox jumps over the lazy dog everyday.",
352
+ # "A speedy brown fox jumps over a lazy dog.",
353
+ # "A swift brown fox leaps over the lethargic dog."
354
+ ]
355
+ result_dict ={
356
+ 'The quick brown fox jumps over the lazy dog everyday.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
357
+ # 'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
358
+ # 'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}
359
+ }
360
+
361
+
362
+ processor = MaskingProcessor()
363
+ # results_random = processor.process_sentences(sentences, result_dict)
364
+ results_entropy = processor.process_sentences(sentences, result_dict, method="random")
365
+
366
+ # results_entropy = processor.process_sentences(sentences, result_dict, method="entropy", remove_stopwords=False)
367
+
368
+ for sentence, output in results_entropy.items():
369
+ print(f"Original Sentence (Random): {sentence}")
370
+ print(f"Masked Sentence (Random): {output['masked_sentence']}")
371
+ # print(f"Mask Logits (Random): {output['mask_logits']}")
372
+ print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
373
+ print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
374
+ print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
375
+ print('--------------------------------')
376
+ for mask_idx, logits in output["mask_logits"].items():
377
+ print(f"Logits for [MASK] at position {mask_idx}:")
378
+ print(f' logits : {logits[:5]}') # List of logits for all vocabulary tokens
379
+ print(f' len(logits) : {len(logits)}')
380
+
381
+
382
+
383
+
384
+ # -------------------------------------------------------------------------------------------
385
+ # def mask_sentence(self, sentence, common_ngrams):
386
+ # """
387
+ # Mask words in the sentence based on the specified rules after removing stopwords.
388
+
389
+ # Args:
390
+ # sentence (str): Original sentence.
391
+ # common_ngrams (dict): Common n-grams and their indices.
392
+
393
+ # Returns:
394
+ # str: Masked sentence.
395
+ # """
396
+ # original_words = sentence.split()
397
+ # print(f' ---- original_words : {original_words} ----- ')
398
+ # non_stop_words = self.remove_stopwords(original_words)
399
+ # print(f' ---- non_stop_words : {non_stop_words} ----- ')
400
+ # adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
401
+ # print(f' ---- common_ngrams : {common_ngrams} ----- ')
402
+ # print(f' ---- adjusted_ngrams : {adjusted_ngrams} ----- ')
403
+
404
+ # mask_indices = []
405
+
406
+ # # Extract n-gram positions in non-stop words
407
+ # ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
408
+ # print(f' ---- ngram_positions : {ngram_positions} ----- ')
409
+ # # Mask a word before the first common n-gram
410
+ # if ngram_positions:
411
+ # first_ngram_start = ngram_positions[0][0]
412
+ # print(f' ---- first_ngram_start : {first_ngram_start} ----- ')
413
+ # if first_ngram_start > 0:
414
+ # mask_index_before_ngram = random.randint(0, first_ngram_start-1)
415
+ # print(f' ---- mask_index_before_ngram : {mask_index_before_ngram} ----- ')
416
+ # mask_indices.append(mask_index_before_ngram)
417
+
418
+ # # Mask words between common n-grams
419
+ # for i in range(len(ngram_positions) - 1):
420
+ # end_prev = ngram_positions[i][1]
421
+ # print(f' ---- end_prev : {end_prev} ----- ')
422
+ # start_next = ngram_positions[i + 1][0]
423
+ # print(f' ---- start_next : {start_next} ----- ')
424
+ # if start_next > end_prev + 1:
425
+ # mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
426
+ # print(f' ---- mask_index_between_ngrams : {mask_index_between_ngrams} ----- ')
427
+ # mask_indices.append(mask_index_between_ngrams)
428
+
429
+ # # Mask a word after the last common n-gram
430
+ # last_ngram_end = ngram_positions[-1][1]
431
+ # print(f' ---- last_ngram_end : {last_ngram_end} ----- ')
432
+ # if last_ngram_end < len(non_stop_words) - 1:
433
+ # mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
434
+ # print(f' ---- mask_index_after_ngram : {mask_index_after_ngram} ----- ')
435
+ # mask_indices.append(mask_index_after_ngram)
436
+
437
+ # # Map mask indices back to original sentence
438
+ # adjusted_indices = [
439
+ # orig for orig, non_stop in enumerate(original_words)
440
+ # if non_stop in mask_indices
441
+ # ]
442
+
443
+ # # Apply masks to the original sentence
444
+ # for idx in adjusted_indices:
445
+ # original_words[idx] = self.tokenizer.mask_token
446
+
447
+ # return " ".join(original_words)
utils/old/masking/masking_methods_ok_working.py ADDED
@@ -0,0 +1,257 @@
1
+ import random
2
+ import torch
3
+ from transformers import BertTokenizer, BertForMaskedLM
4
+ from nltk.corpus import stopwords
5
+ import nltk
6
+
7
+ # Ensure stopwords are downloaded
8
+ try:
9
+ nltk.data.find('corpora/stopwords')
10
+ except LookupError:
11
+ nltk.download('stopwords')
12
+
13
+ class MaskingProcessor:
14
+ def __init__(self):
15
+ self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
16
+ self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
17
+ self.stop_words = set(stopwords.words('english'))
18
+
19
+ def adjust_ngram_indices(self, words, common_ngrams, remove_stopwords):
20
+ """
21
+ Adjust indices of common n-grams after removing stop words.
22
+
23
+ Args:
24
+ words (list): List of words in the original sentence.
25
+ common_ngrams (dict): Common n-grams and their indices.
26
+
27
+ Returns:
28
+ dict: Adjusted common n-grams and their indices.
29
+ """
30
+ if not remove_stopwords:
31
+ return common_ngrams
32
+
33
+ non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
34
+ adjusted_ngrams = {}
35
+
36
+ for ngram, positions in common_ngrams.items():
37
+ adjusted_positions = []
38
+ for start, end in positions:
39
+ try:
40
+ new_start = non_stop_word_indices.index(start)
41
+ new_end = non_stop_word_indices.index(end)
42
+ adjusted_positions.append((new_start, new_end))
43
+ except ValueError:
44
+ continue # Skip if indices cannot be mapped
45
+ adjusted_ngrams[ngram] = adjusted_positions
46
+
47
+ return adjusted_ngrams
48
+
49
+ def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords):
50
+ """
51
+ Mask one word before the first common n-gram, one between two n-grams,
52
+ and one after the last common n-gram (random selection).
53
+
54
+ Args:
55
+ original_sentence (str): Original sentence
56
+ common_ngrams (dict): Common n-grams and their indices
57
+
58
+ Returns:
59
+ str: Masked sentence with original stop words retained
60
+ """
61
+ words = original_sentence.split()
62
+ non_stop_words = [word for word in words if word.lower() not in self.stop_words] if remove_stopwords else words
63
+ adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)
64
+
65
+ mask_indices = []
66
+ # Handle before the first common n-gram
67
+ if adjusted_ngrams:
68
+ first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
69
+ if first_ngram_start > 0:
70
+ mask_indices.append(random.randint(0, first_ngram_start - 1))
71
+
72
+ # Handle between common n-grams
73
+ ngram_positions = list(adjusted_ngrams.values())
74
+ for i in range(len(ngram_positions) - 1):
75
+ end_prev = ngram_positions[i][-1][1]
76
+ start_next = ngram_positions[i + 1][0][0]
77
+ if start_next > end_prev + 1:
78
+ mask_indices.append(random.randint(end_prev + 1, start_next - 1))
79
+
80
+ # Handle after the last common n-gram
81
+ last_ngram_end = ngram_positions[-1][-1][1]
82
+ if last_ngram_end < len(non_stop_words) - 1:
83
+ mask_indices.append(random.randint(last_ngram_end + 1, len(non_stop_words) - 1))
84
+
85
+ # Mask the chosen indices
86
+ original_masked_sentence = words[:]
87
+ for idx in mask_indices:
88
+ if idx not in [index for ngram_indices in adjusted_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
89
+ non_stop_words[idx] = self.tokenizer.mask_token
90
+ original_masked_sentence[idx] = self.tokenizer.mask_token
91
+
92
+ return " ".join(original_masked_sentence)
93
+
94
+ def mask_sentence_entropy(self, original_sentence, common_ngrams, remove_stopwords):
95
+ """
96
+ Mask one word before the first common n-gram, one between two n-grams,
97
+ and one after the last common n-gram (highest entropy selection).
98
+
99
+ Args:
100
+ original_sentence (str): Original sentence
101
+ common_ngrams (dict): Common n-grams and their indices
102
+
103
+ Returns:
104
+ str: Masked sentence with original stop words retained
105
+ """
106
+ words = original_sentence.split()
107
+ non_stop_words = [word for word in words if word.lower() not in self.stop_words] if remove_stopwords else words
108
+ adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)
109
+ entropy_scores = {}
110
+
111
+ for idx, word in enumerate(non_stop_words):
112
+ if idx in [index for ngram_indices in adjusted_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
113
+ continue # Skip words in common n-grams
114
+
115
+ masked_sentence = non_stop_words[:idx] + [self.tokenizer.mask_token] + non_stop_words[idx + 1:]
116
+ masked_sentence = " ".join(masked_sentence)
117
+ input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
118
+ mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
119
+
120
+ with torch.no_grad():
121
+ outputs = self.model(input_ids)
122
+ logits = outputs.logits
123
+
124
+ filtered_logits = logits[0, mask_token_index, :]
125
+ probs = torch.softmax(filtered_logits, dim=-1)
126
+ entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item() # Add epsilon to prevent log(0)
127
+ entropy_scores[idx] = entropy
128
+
129
+ mask_indices = []
130
+
131
+ # Handle before the first common n-gram
132
+ if adjusted_ngrams:
133
+ first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
134
+ candidates = [i for i in range(first_ngram_start) if i in entropy_scores]
135
+ if candidates:
136
+ mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))
137
+
138
+ # Handle between common n-grams
139
+ ngram_positions = list(adjusted_ngrams.values())
140
+ for i in range(len(ngram_positions) - 1):
141
+ end_prev = ngram_positions[i][-1][1]
142
+ start_next = ngram_positions[i + 1][0][0]
143
+ candidates = [i for i in range(end_prev + 1, start_next) if i in entropy_scores]
144
+ if candidates:
145
+ mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))
146
+
147
+ # Handle after the last common n-gram
148
+ last_ngram_end = ngram_positions[-1][-1][1]
149
+ candidates = [i for i in range(last_ngram_end + 1, len(non_stop_words)) if i in entropy_scores]
150
+ if candidates:
151
+ mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))
152
+
153
+ # Mask the chosen indices
154
+ original_masked_sentence = words[:]
155
+ for idx in mask_indices:
156
+ non_stop_words[idx] = self.tokenizer.mask_token
157
+ original_masked_sentence[idx] = self.tokenizer.mask_token
158
+
159
+ return " ".join(original_masked_sentence)
160
+
161
+ def calculate_mask_logits(self, masked_sentence):
162
+ """
163
+ Calculate logits for masked tokens in the sentence using BERT.
164
+
165
+ Args:
166
+ masked_sentence (str): Sentence with [MASK] tokens
167
+
168
+ Returns:
169
+ dict: Masked token indices and their logits
170
+ """
171
+ input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
172
+ mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
173
+
174
+ with torch.no_grad():
175
+ outputs = self.model(input_ids)
176
+ logits = outputs.logits
177
+
178
+ mask_logits = {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}
179
+ return mask_logits
180
+
181
+ def process_sentences(self, original_sentences, result_dict, method="random", remove_stopwords=False):
182
+ """
183
+ Process a list of sentences and calculate logits for masked tokens using the specified method.
184
+
185
+ Args:
186
+ original_sentences (list): List of original sentences
187
+ result_dict (dict): Common n-grams and their indices for each sentence
188
+ method (str): Masking method ("random" or "entropy")
189
+
190
+ Returns:
191
+ dict: Masked sentences and their logits for each sentence
192
+ """
193
+ results = {}
194
+
195
+ for sentence, ngrams in result_dict.items():
196
+ if method == "random":
197
+ masked_sentence = self.mask_sentence_random(sentence, ngrams, remove_stopwords)
198
+ elif method == "entropy":
199
+ masked_sentence = self.mask_sentence_entropy(sentence, ngrams, remove_stopwords)
200
+ else:
201
+ raise ValueError("Invalid method. Choose 'random' or 'entropy'.")
202
+
203
+ logits = self.calculate_mask_logits(masked_sentence)
204
+ results[sentence] = {
205
+ "masked_sentence": masked_sentence,
206
+ "mask_logits": logits
207
+ }
208
+
209
+ return results
210
+
211
+ # Example usage
212
+ if __name__ == "__main__":
213
+ # !!! Works in both cases, whether stopwords are removed or not
214
+ sentences = [
215
+ "The quick brown fox jumps over the lazy dog.",
216
+ "A quick brown dog outpaces a lazy fox.",
217
+ "Quick brown animals leap over lazy obstacles."
218
+ ]
219
+
220
+ result_dict = {
221
+ "The quick brown fox jumps over the lazy dog.": {"quick brown": [(1, 2)], "lazy": [(7, 7)]},
222
+ "A quick brown dog outpaces a lazy fox.": {"quick brown": [(1, 2)], "lazy": [(6, 6)]},
223
+ "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(5, 5)]}
224
+ }
225
+
226
+ # result_dict = {
227
+ # "The quick brown fox jumps over the lazy dog.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
228
+ # "A quick brown dog outpaces a lazy fox.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
229
+ # "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]}
230
+ # }
231
+
232
+ processor = MaskingProcessor()
233
+ results_random = processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)
234
+ # results_entropy = processor.process_sentences(sentences, result_dict, method="entropy", remove_stopwords=False)
235
+
236
+ for sentence, output in results_random.items():
237
+ print(f"Original Sentence (Random): {sentence}")
238
+ print(f"Masked Sentence (Random): {output['masked_sentence']}")
239
+ # print(f"Mask Logits (Random): {output['mask_logits']}")
240
+ print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
241
+ print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
242
+ print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
243
+ print('--------------------------------')
244
+ for mask_idx, logits in output["mask_logits"].items():
245
+ print(f"Logits for [MASK] at position {mask_idx}:")
246
+ print(f' logits : {logits[:5]}') # List of logits for all vocabulary tokens
247
+
248
+
249
+
250
+ # print('--------------------------------')
251
+ # for sentence, output in results_entropy.items():
252
+ # print(f"Original Sentence (Entropy): {sentence}")
253
+ # print(f"Masked Sentence (Entropy): {output['masked_sentence']}")
254
+ # # print(f"Mask Logits (Entropy): {output['mask_logits']}")
255
+ # print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
256
+ # print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
257
+ # print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
utils/old/masking/masking_methods_v1_working.py ADDED
@@ -0,0 +1,233 @@
1
+ import random
2
+ import torch
3
+ from transformers import BertTokenizer, BertForMaskedLM
4
+ from nltk.corpus import stopwords
5
+ import nltk
6
+
7
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
8
+ # THIS WORKS WHEN THE COORDINATES ARE COMPUTED WITHOUT REMOVING STOPWORDS
9
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
10
+
11
+
12
+ # Ensure stopwords are downloaded
13
+ try:
14
+ nltk.data.find('corpora/stopwords')
15
+ except LookupError:
16
+ nltk.download('stopwords')
17
+
18
+ class MaskingProcessor:
19
+ def __init__(self):
20
+ self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
21
+ self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
22
+ self.stop_words = set(stopwords.words('english'))
23
+
24
+ def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords=False):
25
+ """
26
+ Mask one word before the first common n-gram, one between two n-grams,
27
+ and one after the last common n-gram (random selection).
28
+
29
+ Args:
30
+ original_sentence (str): Original sentence
31
+ common_ngrams (dict): Common n-grams and their indices
32
+
33
+ Returns:
34
+ str: Masked sentence
35
+ """
36
+ if remove_stopwords:
37
+ words = original_sentence.split()
38
+ words = [word for word in words if word not in self.stop_words]
39
+ else:
40
+ words = original_sentence.split()
41
+
42
+ mask_indices = []
43
+ # Handle before the first common n-gram
44
+ if common_ngrams:
45
+ first_ngram_start = list(common_ngrams.values())[0][0][0]
46
+ if first_ngram_start > 0:
47
+ mask_indices.append(random.randint(0, first_ngram_start - 1))
48
+
49
+ # Handle between common n-grams
50
+ ngram_positions = list(common_ngrams.values())
51
+ for i in range(len(ngram_positions) - 1):
52
+ end_prev = ngram_positions[i][-1][1]
53
+ start_next = ngram_positions[i + 1][0][0]
54
+ if start_next > end_prev + 1:
55
+ mask_indices.append(random.randint(end_prev + 1, start_next - 1))
56
+
57
+ # Handle after the last common n-gram
58
+ last_ngram_end = ngram_positions[-1][-1][1]
59
+ if last_ngram_end < len(words) - 1:
60
+ mask_indices.append(random.randint(last_ngram_end + 1, len(words) - 1))
61
+
62
+ # Mask the chosen indices
63
+ for idx in mask_indices:
64
+ if idx not in [index for ngram_indices in common_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
65
+ words[idx] = self.tokenizer.mask_token
66
+
67
+ return " ".join(words)
68
+
69
+ def mask_sentence_entropy(self, original_sentence, common_ngrams, remove_stopwords=False):
70
+ """
71
+ Mask one word before the first common n-gram, one between two n-grams,
72
+ and one after the last common n-gram (highest entropy selection).
73
+
74
+ Args:
75
+ original_sentence (str): Original sentence
76
+ common_ngrams (dict): Common n-grams and their indices
77
+
78
+ Returns:
79
+ str: Masked sentence
80
+ """
81
+ if remove_stopwords:
82
+ words = original_sentence.split()
83
+ words = [word for word in words if word not in self.stop_words]
84
+ else:
85
+ words = original_sentence.split()
86
+ entropy_scores = {}
87
+
88
+ for idx, word in enumerate(words):
89
+ if idx in [index for ngram_indices in common_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
90
+ continue # Skip words in common n-grams
91
+
92
+ masked_sentence = words[:idx] + [self.tokenizer.mask_token] + words[idx + 1:]
93
+ masked_sentence = " ".join(masked_sentence)
94
+ input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
95
+ mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
96
+
97
+ with torch.no_grad():
98
+ outputs = self.model(input_ids)
99
+ logits = outputs.logits
100
+
101
+ filtered_logits = logits[0, mask_token_index, :]
102
+ probs = torch.softmax(filtered_logits, dim=-1)
103
+ entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item() # Add epsilon to prevent log(0)
104
+ entropy_scores[idx] = entropy
105
+
106
+ mask_indices = []
107
+
108
+ # Handle before the first common n-gram
109
+ if common_ngrams:
110
+ first_ngram_start = list(common_ngrams.values())[0][0][0]
111
+ candidates = [i for i in range(first_ngram_start) if i in entropy_scores]
112
+ if candidates:
113
+ mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))
114
+
115
+ # Handle between common n-grams
116
+ ngram_positions = list(common_ngrams.values())
117
+ for i in range(len(ngram_positions) - 1):
118
+ end_prev = ngram_positions[i][-1][1]
119
+ start_next = ngram_positions[i + 1][0][0]
120
+ candidates = [i for i in range(end_prev + 1, start_next) if i in entropy_scores]
121
+ if candidates:
122
+ mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))
123
+
124
+ # Handle after the last common n-gram
125
+ last_ngram_end = ngram_positions[-1][-1][1]
126
+ candidates = [i for i in range(last_ngram_end + 1, len(words)) if i in entropy_scores]
127
+ if candidates:
128
+ mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))
129
+
130
+ # Mask the chosen indices
131
+ for idx in mask_indices:
132
+ words[idx] = self.tokenizer.mask_token
133
+
134
+ return " ".join(words)
135
+
136
+ def calculate_mask_logits(self, masked_sentence):
137
+ """
138
+ Calculate logits for masked tokens in the sentence using BERT.
139
+
140
+ Args:
141
+ masked_sentence (str): Sentence with [MASK] tokens
142
+
143
+ Returns:
144
+ dict: Masked token indices and their logits
145
+ """
146
+ input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
147
+ mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
148
+
149
+ with torch.no_grad():
150
+ outputs = self.model(input_ids)
151
+ logits = outputs.logits
152
+
153
+ mask_logits = {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}
154
+ return mask_logits
155
+
156
+ def process_sentences(self, original_sentences, result_dict, remove_stopwords=False, method="random"):
157
+ """
158
+ Process a list of sentences and calculate logits for masked tokens using the specified method.
159
+
160
+ Args:
161
+ original_sentences (list): List of original sentences
162
+ result_dict (dict): Common n-grams and their indices for each sentence
163
+ method (str): Masking method ("random" or "entropy")
164
+
165
+ Returns:
166
+ dict: Masked sentences and their logits for each sentence
167
+ """
168
+ results = {}
169
+
170
+ for sentence, ngrams in result_dict.items():
171
+ if method == "random":
172
+ masked_sentence = self.mask_sentence_random(sentence, ngrams)
173
+ elif method == "entropy":
174
+ masked_sentence = self.mask_sentence_entropy(sentence, ngrams)
175
+ else:
176
+ raise ValueError("Invalid method. Choose 'random' or 'entropy'.")
177
+
178
+ logits = self.calculate_mask_logits(masked_sentence)
179
+ results[sentence] = {
180
+ "masked_sentence": masked_sentence,
181
+ "mask_logits": logits
182
+ }
183
+
184
+ return results
185
+
186
+ # Example usage
187
+ if __name__ == "__main__":
188
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
189
+ # THIS WORKS WHEN THE COORDINATES ARE COMPUTED WITHOUT REMOVING STOPWORDS
190
+
191
+ sentences = [
192
+ "The quick brown fox jumps over the lazy dog.",
193
+ "A quick brown dog outpaces a lazy fox.",
194
+ "Quick brown animals leap over lazy obstacles."
195
+ ]
196
+
197
+ result_dict = {
198
+ "The quick brown fox jumps over the lazy dog.": {"quick brown": [(1, 2)], "lazy": [(7, 7)]},
199
+ "A quick brown dog outpaces a lazy fox.": {"quick brown": [(1, 2)], "lazy": [(6, 6)]},
200
+ "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(5, 5)]}
201
+ }
202
+
203
+ # result_dict = {
204
+ # "The quick brown fox jumps over the lazy dog.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
205
+ # "A quick brown dog outpaces a lazy fox.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
206
+ # "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]}
207
+ # }
208
+
209
+ processor = MaskingProcessor()
210
+ results_random = processor.process_sentences(sentences, result_dict, remove_stopwords=True, method="random")
211
+ results_entropy = processor.process_sentences(sentences, result_dict, remove_stopwords=True, method="entropy")
212
+
213
+ for sentence, output in results_random.items():
214
+ print(f"Original Sentence (Random): {sentence}")
215
+ print(f"Masked Sentence (Random): {output['masked_sentence']}")
216
+ # print(f"Mask Logits (Random): {output['mask_logits']}")
217
+
218
+ for sentence, output in results_entropy.items():
219
+ print(f"Original Sentence (Entropy): {sentence}")
220
+ print(f"Masked Sentence (Entropy): {output['masked_sentence']}")
221
+ # print(f"Mask Logits (Entropy): {output['mask_logits']}")
222
+
223
+
224
+
225
+
226
+ '''
227
+ result_dict = {
228
+ "The quick brown fox jumps over the lazy dog.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
229
+ "A quick brown dog outpaces a lazy fox.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
230
+ "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]}
231
+ }
232
+
233
+ '''
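
A hedged sketch of one way the mask_logits dictionary returned by calculate_mask_logits above (token position -> list of vocabulary logits) could be consumed downstream, e.g. to shortlist candidate fill-in tokens. The top_candidates helper and the dummy random logits are assumptions for illustration; the repository's own sampling code may consume these values differently.

    import torch

    def top_candidates(mask_logits, k=5):
        # mask_logits: {token position -> list of vocabulary logits}, the shape
        # returned by calculate_mask_logits. Returns the k most probable token
        # ids (with probabilities) for each masked position.
        out = {}
        for idx, logit_list in mask_logits.items():
            probs = torch.softmax(torch.tensor(logit_list), dim=-1)
            top = torch.topk(probs, k)
            out[idx] = list(zip(top.indices.tolist(), top.values.tolist()))
        return out

    torch.manual_seed(0)
    dummy = {4: torch.randn(30522).tolist()}       # one fake masked position
    print(top_candidates(dummy)[4][:3])
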
utils/old/masking_methods_final_copy.py ADDED
@@ -0,0 +1,619 @@
1
+ import random
2
+ import torch
3
+ from transformers import BertTokenizer, BertForMaskedLM
4
+ from nltk.corpus import stopwords
5
+ import nltk
6
+ from transformers import RobertaTokenizer, RobertaForMaskedLM
7
+
8
+
9
+ # Ensure stopwords are downloaded
10
+ try:
11
+ nltk.data.find('corpora/stopwords')
12
+ except LookupError:
13
+ nltk.download('stopwords')
14
+
15
+ class MaskingProcessor:
16
+ # def __init__(self, tokenizer, model):
17
+ def __init__(self):
18
+ # self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
19
+ # self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
20
+
21
+ # self.tokenizer = tokenizer
22
+ # self.model = model
23
+
24
+ self.tokenizer = BertTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
25
+ self.model = BertForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
26
+
27
+ # self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
28
+ # self.model = RobertaForMaskedLM.from_pretrained("roberta-base")
29
+
30
+ self.stop_words = set(stopwords.words('english'))
31
+
32
+ def remove_stopwords(self, words):
33
+ """
34
+ Remove stopwords from the given list of words.
35
+
36
+ Args:
37
+ words (list): List of words.
38
+
39
+ Returns:
40
+ list: List of non-stop words.
41
+ """
42
+ return [word for word in words if word.lower() not in self.stop_words]
43
+
44
+ def adjust_ngram_indices(self, original_words, common_ngrams):
45
+ """
46
+ Adjust indices of common n-grams after removing stopwords.
47
+
48
+ Args:
49
+ original_words (list): Original list of words.
50
+ common_ngrams (dict): Common n-grams and their indices.
51
+
52
+ Returns:
53
+ dict: Adjusted common n-grams with updated indices.
54
+ """
55
+ non_stop_words = self.remove_stopwords(original_words)
56
+ original_to_non_stop = []
57
+ non_stop_idx = 0
58
+
59
+ for original_idx, word in enumerate(original_words):
60
+ if word.lower() not in self.stop_words:
61
+ original_to_non_stop.append((original_idx, non_stop_idx))
62
+ non_stop_idx += 1
63
+
64
+ adjusted_ngrams = {}
65
+ for ngram, positions in common_ngrams.items():
66
+ adjusted_positions = []
67
+ for start, end in positions:
68
+ try:
69
+ new_start = next(non_stop for orig, non_stop in original_to_non_stop if orig == start)
70
+ new_end = next(non_stop for orig, non_stop in original_to_non_stop if orig == end)
71
+ adjusted_positions.append((new_start, new_end))
72
+ except StopIteration:
73
+ continue # Skip if indices cannot be mapped
74
+ adjusted_ngrams[ngram] = adjusted_positions
75
+
76
+ return adjusted_ngrams
77
+
78
+ def mask_sentence_random(self, sentence, common_ngrams):
79
+ """
80
+ Mask words in the sentence based on the specified rules after removing stopwords.
81
+ """
82
+ # Split sentence into words
83
+ original_words = sentence.split()
84
+
85
+ # Handle punctuation at the end
86
+ has_punctuation = False
87
+ punctuation = None
88
+ if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
89
+ has_punctuation = True
90
+ punctuation = original_words[-1][-1]
91
+ original_words = original_words[:-1]
92
+
93
+ print(f' ---- original_words : {original_words} ----- ')
94
+
95
+ # Process words without punctuation
96
+ non_stop_words = self.remove_stopwords(original_words)
97
+ adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
98
+
99
+ # Rest of the existing function code...
100
+ mask_indices = []
101
+ ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
102
+
103
+ if ngram_positions:
104
+ first_ngram_start = ngram_positions[0][0]
105
+ if first_ngram_start > 0:
106
+ mask_index_before_ngram = random.randint(0, first_ngram_start-1)
107
+ mask_indices.append(mask_index_before_ngram)
108
+
109
+ # Mask words between common n-grams
110
+ for i in range(len(ngram_positions) - 1):
111
+ end_prev = ngram_positions[i][1]
112
+ start_next = ngram_positions[i + 1][0]
113
+ if start_next > end_prev + 1:
114
+ mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
115
+ mask_indices.append(mask_index_between_ngrams)
116
+
117
+ # Mask a word after the last common n-gram
118
+ last_ngram_end = ngram_positions[-1][1]
119
+ if last_ngram_end < len(non_stop_words) - 1:
120
+ mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
121
+ mask_indices.append(mask_index_after_ngram)
122
+
123
+ # Create mapping from non-stop words to original indices
124
+ non_stop_to_original = {}
125
+ non_stop_idx = 0
126
+ for orig_idx, word in enumerate(original_words):
127
+ if word.lower() not in self.stop_words:
128
+ non_stop_to_original[non_stop_idx] = orig_idx
129
+ non_stop_idx += 1
130
+
131
+ # Map mask indices and apply masks
132
+ original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
133
+ masked_words = original_words.copy()
134
+ for idx in original_mask_indices:
135
+ masked_words[idx] = self.tokenizer.mask_token
136
+ # masked_words[idx] = '<mask>' # for roberta
137
+
138
+ # Add back punctuation if it existed
139
+ if has_punctuation:
140
+ masked_words.append(punctuation)
141
+
142
+ print(f' ***** masked_words at end : {masked_words} ***** ')
143
+ print(f' ***** original_mask_indices : {original_mask_indices} ***** ')
144
+ print(f' ***** TESTING : {" ".join(masked_words)} ***** ')
145
+
146
+ return " ".join(masked_words), original_mask_indices
147
+
148
+ def mask_sentence_pseudorandom(self, sentence, common_ngrams):
149
+ """
150
+ Mask words in the sentence based on the specified rules after removing stopwords.
151
+ """
152
+ # Split sentence into words
153
+ random.seed(3)
154
+ original_words = sentence.split()
155
+
156
+ # Handle punctuation at the end
157
+ has_punctuation = False
158
+ punctuation = None
159
+ if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
160
+ has_punctuation = True
161
+ punctuation = original_words[-1][-1]
162
+ original_words = original_words[:-1]
163
+
164
+ print(f' ---- original_words : {original_words} ----- ')
165
+
166
+ # Process words without punctuation
167
+ non_stop_words = self.remove_stopwords(original_words)
168
+ adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
169
+
170
+ # Rest of the existing function code...
171
+ mask_indices = []
172
+ ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
173
+
174
+ if ngram_positions:
175
+ first_ngram_start = ngram_positions[0][0]
176
+ if first_ngram_start > 0:
177
+ mask_index_before_ngram = random.randint(0, first_ngram_start-1)
178
+ mask_indices.append(mask_index_before_ngram)
179
+
180
+ # Mask words between common n-grams
181
+ for i in range(len(ngram_positions) - 1):
182
+ end_prev = ngram_positions[i][1]
183
+ start_next = ngram_positions[i + 1][0]
184
+ if start_next > end_prev + 1:
185
+ mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
186
+ mask_indices.append(mask_index_between_ngrams)
187
+
188
+ # Mask a word after the last common n-gram
189
+ last_ngram_end = ngram_positions[-1][1]
190
+ if last_ngram_end < len(non_stop_words) - 1:
191
+ mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
192
+ mask_indices.append(mask_index_after_ngram)
193
+
194
+ # Create mapping from non-stop words to original indices
195
+ non_stop_to_original = {}
196
+ non_stop_idx = 0
197
+ for orig_idx, word in enumerate(original_words):
198
+ if word.lower() not in self.stop_words:
199
+ non_stop_to_original[non_stop_idx] = orig_idx
200
+ non_stop_idx += 1
201
+
202
+ # Map mask indices and apply masks
203
+ original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
204
+ masked_words = original_words.copy()
205
+ for idx in original_mask_indices:
206
+ masked_words[idx] = self.tokenizer.mask_token
207
+ # masked_words[idx] = '<mask>' # for roberta
208
+
209
+ # Add back punctuation if it existed
210
+ if has_punctuation:
211
+ masked_words.append(punctuation)
212
+
213
+ print(f' ***** masked_words at end : {masked_words} ***** ')
214
+ print(f' ***** original_mask_indices : {original_mask_indices} ***** ')
215
+ print(f' ***** TESTING : {" ".join(masked_words)} ***** ')
216
+
217
+ return " ".join(masked_words), original_mask_indices
218
+
219
+
220
+ def calculate_word_entropy(self, sentence, word_position):
221
+ """
222
+ Calculate entropy for a specific word position in the sentence.
223
+
224
+ Args:
225
+ sentence (str): The input sentence
226
+ word_position (int): Position of the word to calculate entropy for
227
+
228
+ Returns:
229
+ float: Entropy value for the word
230
+ """
231
+ words = sentence.split()
232
+ masked_words = words.copy()
233
+ masked_words[word_position] = self.tokenizer.mask_token
234
+ masked_sentence = " ".join(masked_words)
235
+
236
+ input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
237
+ mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
238
+
239
+ with torch.no_grad():
240
+ outputs = self.model(input_ids)
241
+ logits = outputs.logits
242
+
243
+ # Get probabilities for the masked position
244
+ probs = torch.nn.functional.softmax(logits[0, mask_token_index], dim=-1)
245
+ # Calculate entropy: -sum(p * log(p))
246
+ entropy = -torch.sum(probs * torch.log(probs + 1e-9))
247
+
248
+ return entropy.item()
249
+
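For reference, the value computed here is the Shannon entropy of the model's predicted distribution at the masked position, H = -sum_i p_i log p_i; a higher value means the context constrains the word less. A minimal, self-contained sketch of the same calculation from raw logits (assuming only torch, no model needed):

import torch

def entropy_from_logits(logits: torch.Tensor) -> float:
    # logits: 1-D tensor of vocabulary scores for a single masked position
    probs = torch.softmax(logits, dim=-1)
    # the small epsilon mirrors the +1e-9 above and keeps log() finite
    return float(-(probs * torch.log(probs + 1e-9)).sum())

# A flat distribution over 4 tokens gives the maximum entropy log(4) ~= 1.386
print(entropy_from_logits(torch.zeros(4)))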
250
+ def mask_sentence_entropy(self, sentence, common_ngrams):
251
+ """
252
+ Mask words in the sentence based on entropy, following n-gram positioning rules.
253
+
254
+ Args:
255
+ sentence (str): Original sentence
256
+ common_ngrams (dict): Common n-grams and their indices
257
+
258
+ Returns:
259
+ tuple: (masked sentence with [MASK] tokens, list of masked word indices in the original sentence)
260
+ """
261
+ # Split sentence into words
262
+ original_words = sentence.split()
263
+
264
+ # Handle punctuation at the end
265
+ has_punctuation = False
266
+ punctuation = None
267
+ if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
268
+ has_punctuation = True
269
+ punctuation = original_words[-1][-1]
270
+ original_words = original_words[:-1]
271
+
272
+ # Process words without punctuation
273
+ non_stop_words = self.remove_stopwords(original_words)
274
+ adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
275
+
276
+ # Create mapping from non-stop words to original indices
277
+ non_stop_to_original = {}
278
+ original_to_non_stop = {}
279
+ non_stop_idx = 0
280
+ for orig_idx, word in enumerate(original_words):
281
+ if word.lower() not in self.stop_words:
282
+ non_stop_to_original[non_stop_idx] = orig_idx
283
+ original_to_non_stop[orig_idx] = non_stop_idx
284
+ non_stop_idx += 1
285
+
286
+ ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
287
+ mask_indices = []
288
+
289
+ if ngram_positions:
290
+ # Handle words before first n-gram
291
+ first_ngram_start = ngram_positions[0][0]
292
+ if first_ngram_start > 0:
293
+ candidate_positions = range(0, first_ngram_start)
294
+ entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
295
+ for pos in candidate_positions]
296
+ mask_indices.append(max(entropies, key=lambda x: x[1])[0])
297
+
298
+ # Handle words between n-grams
299
+ for i in range(len(ngram_positions) - 1):
300
+ end_prev = ngram_positions[i][1]
301
+ start_next = ngram_positions[i + 1][0]
302
+ if start_next > end_prev + 1:
303
+ candidate_positions = range(end_prev + 1, start_next)
304
+ entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
305
+ for pos in candidate_positions]
306
+ mask_indices.append(max(entropies, key=lambda x: x[1])[0])
307
+
308
+ # Handle words after last n-gram
309
+ last_ngram_end = ngram_positions[-1][1]
310
+ if last_ngram_end < len(non_stop_words) - 1:
311
+ candidate_positions = range(last_ngram_end + 1, len(non_stop_words))
312
+ entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
313
+ for pos in candidate_positions]
314
+ mask_indices.append(max(entropies, key=lambda x: x[1])[0])
315
+
316
+ # Map mask indices to original sentence positions and apply masks
317
+ original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
318
+ masked_words = original_words.copy()
319
+ for idx in original_mask_indices:
320
+ masked_words[idx] = self.tokenizer.mask_token
321
+
322
+ # Add back punctuation if it existed
323
+ if has_punctuation:
324
+ masked_words.append(punctuation)
325
+
326
+ return " ".join(masked_words), original_mask_indices
327
+
328
+ def calculate_mask_logits(self, original_sentence, original_mask_indices):
329
+ """
330
+ Calculate logits for masked tokens in the sentence using BERT.
331
+
332
+ Args:
333
+ original_sentence (str): Original sentence without masks
334
+ original_mask_indices (list): List of indices to mask
335
+
336
+ Returns:
337
+ dict: For each masked word index, the top unique whole-word candidates ("tokens") and their corresponding logits ("logits")
338
+ """
339
+ print('==========================================================================================================')
340
+ words = original_sentence.split()
341
+ print(f' ##### calculate_mask_logits >> words : {words} ##### ')
342
+ mask_logits = {}
343
+
344
+ for idx in original_mask_indices:
345
+ # Create a copy of words and mask the current position
346
+ print(f' ---- idx : {idx} ----- ')
347
+ masked_words = words.copy()
348
+ masked_words[idx] = '[MASK]'
349
+ # masked_words[idx] = '<mask>' # for roberta
350
+ masked_sentence = " ".join(masked_words)
351
+ print(f' ---- masked_sentence : {masked_sentence} ----- ')
352
+
353
+ # Calculate logits for the current mask
354
+ input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
355
+ mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
356
+
357
+ with torch.no_grad():
358
+ outputs = self.model(input_ids)
359
+ logits = outputs.logits
360
+
361
+ # Extract logits for the masked position
362
+ mask_logits_tensor = logits[0, mask_token_index, :]
363
+
364
+ # Get top logits and corresponding tokens
365
+ top_mask_logits, top_mask_indices = torch.topk(mask_logits_tensor, 100, dim=-1) # Get more candidates
366
+
367
+ # Convert token IDs to words and filter out subword tokens
368
+ top_tokens = []
369
+ top_logits = []
370
+ seen_words = set() # To keep track of unique words
371
+
372
+ for token_id, logit in zip(top_mask_indices[0], top_mask_logits[0]):
373
+ token = self.tokenizer.convert_ids_to_tokens(token_id.item())
374
+
375
+ # Skip if it's a subword token (starts with ##)
376
+ if token.startswith('##'):
377
+ continue
378
+
379
+ # Convert token to proper word
380
+ word = self.tokenizer.convert_tokens_to_string([token]).strip()
381
+
382
+ # Only add if it's a new word and not empty
383
+ if word and word not in seen_words:
384
+ seen_words.add(word)
385
+ top_tokens.append(word)
386
+ top_logits.append(logit.item())
387
+
388
+ # Break if we have 50 unique complete words
389
+ if len(top_tokens) == 50:
390
+ break
391
+
392
+ # print(f' ---- top_tokens : {top_tokens} ----- ')
393
+
394
+ # Store results
395
+ mask_logits[idx] = {
396
+ "tokens": top_tokens,
397
+ "logits": top_logits
398
+ }
399
+
400
+ return mask_logits
401
+
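The '##' check above relies on the WordPiece convention that word-internal pieces are prefixed with '##', so dropping them keeps only tokens that can stand as whole words. A small standalone sketch of that convention (assuming the bert-base-uncased tokenizer is available):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
pieces = tok.tokenize("antidisestablishmentarianism fox")
print(pieces)  # word-internal pieces appear with a '##' prefix
whole_word_starts = [p for p in pieces if not p.startswith("##")]
print(whole_word_starts)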
402
+ # def calculate_mask_logits(self, original_sentence, original_mask_indices):
403
+ # """
404
+ # Calculate logits for masked tokens in the sentence using BERT.
405
+
406
+ # Args:
407
+ # original_sentence (str): Original sentence without masks
408
+ # original_mask_indices (list): List of indices to mask
409
+
410
+ # Returns:
411
+ # dict: Masked token indices and their logits
412
+ # """
413
+ # words = original_sentence.split()
414
+ # print(f' ##### calculate_mask_logits >> words : {words} ##### ')
415
+ # mask_logits = {}
416
+
417
+ # for idx in original_mask_indices:
418
+ # # Create a copy of words and mask the current position
419
+ # print(f' ---- idx : {idx} ----- ')
420
+ # masked_words = words.copy()
421
+ # print(f' ---- words : {masked_words} ----- ')
422
+ # # masked_words[idx] = self.tokenizer.mask_token
423
+ # masked_words[idx] = '[MASK]'
424
+ # print(f' ---- masked_words : {masked_words} ----- ')
425
+ # masked_sentence = " ".join(masked_words)
426
+ # print(f' ---- masked_sentence : {masked_sentence} ----- ')
427
+
428
+ # # Calculate logits for the current mask
429
+ # input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
430
+ # mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
431
+
432
+ # with torch.no_grad():
433
+ # outputs = self.model(input_ids)
434
+ # logits = outputs.logits
435
+
436
+ # # Extract logits for the masked position
437
+ # mask_logits_tensor = logits[0, mask_token_index, :]
438
+
439
+ # # Get top 50 logits and corresponding tokens
440
+ # top_mask_logits, top_mask_indices = torch.topk(mask_logits_tensor, 50, dim=-1)
441
+
442
+ # # Convert token IDs to words
443
+ # top_tokens = [self.tokenizer.convert_ids_to_tokens(token_id.item()) for token_id in top_mask_indices[0]]
444
+ # print(f' ---- top_tokens : {top_tokens} ----- ')
445
+
446
+ # # Store results
447
+ # mask_logits[idx] = {
448
+ # "tokens": top_tokens,
449
+ # "logits": top_mask_logits.tolist()
450
+ # }
451
+
452
+ # return mask_logits
453
+
454
+
455
+ def process_sentences(self, sentences, result_dict, method="random"):
456
+ """
457
+ Process sentences and calculate logits for masked tokens.
458
+ """
459
+ results = {}
460
+
461
+ for sentence, ngrams in result_dict.items():
462
+ # Split punctuation from the last word before processing
463
+ words = sentence.split()
464
+ last_word = words[-1]
465
+ if any(last_word.endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
466
+ # Split the last word and punctuation
467
+ words[-1] = last_word[:-1]
468
+ punctuation = last_word[-1]
469
+ # Rejoin with space before punctuation to treat it as separate token
470
+ processed_sentence = " ".join(words) + " " + punctuation
471
+ else:
472
+ processed_sentence = sentence
473
+
474
+ if method == "random":
475
+ masked_sentence, original_mask_indices = self.mask_sentence_random(processed_sentence, ngrams)
476
+ elif method == "pseudorandom":
477
+ masked_sentence, original_mask_indices = self.mask_sentence_pseudorandom(processed_sentence, ngrams)
478
+ else: # entropy
479
+ masked_sentence, original_mask_indices = self.mask_sentence_entropy(processed_sentence, ngrams)
480
+
481
+ logits = self.calculate_mask_logits(processed_sentence, original_mask_indices)
482
+ results[sentence] = {
483
+ "masked_sentence": masked_sentence,
484
+ "mask_logits": logits
485
+ }
486
+
487
+ return results
488
+
489
+
490
+
491
+ if __name__ == "__main__":
492
+ # !!! Works in both cases, regardless of whether stopwords are removed or not
493
+ sentences = [
494
+ "The quick brown fox jumps over small cat the lazy dog everyday again and again .",
495
+ # "A speedy brown fox jumps over a lazy dog.",
496
+ # "A swift brown fox leaps over the lethargic dog."
497
+
498
+ ]
499
+ result_dict ={
500
+ 'The quick brown fox jumps over small cat the lazy dog everyday again and again .': {'brown fox': [(2, 3)],'cat': [(7, 7)], 'dog': [(10, 10)]},
501
+ # 'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
502
+ # 'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}
503
+ }
504
+
505
+
506
+ processor = MaskingProcessor()
507
+ # results_random = processor.process_sentences(sentences, result_dict)
508
+ results_entropy = processor.process_sentences(sentences, result_dict, method="random")
509
+
510
+ '''
511
+ results structure :
512
+ results = {
513
+ "The quick brown fox jumps over the lazy dog everyday.":
514
+ { # Original sentence as key
515
+ "masked_sentence": str, # The sentence with [MASK] tokens
516
+ "mask_logits":
517
+ { # Dictionary of mask positions and their predictions
518
+ 1:
519
+ { # Position of mask in sentence
520
+ "tokens" (words) : list, # List of top 50 predicted tokens
521
+ "logits" (raw scores) : list # Corresponding logits for those tokens
522
+ },
523
+ 7:
524
+ {
525
+ "tokens" (words) : list,
526
+ "logits" (raw scores) : list
527
+ },
528
+ 10:
529
+ {
530
+ "tokens (words)": list,
531
+ "logits (raw scores)": list
532
+ }
533
+ }
534
+ }
535
+ }
536
+
537
+ '''
538
+ # results_entropy = processor.process_sentences(sentences, result_dict, method="entropy", remove_stopwords=False)
539
+
540
+ for sentence, output in results_entropy.items():
541
+ print(f"Original Sentence (Random): {sentence}")
542
+ print(f"Masked Sentence (Random): {output['masked_sentence']}")
543
+ # print(f"Mask Logits (Random): {output['mask_logits']}")
544
+ # print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
545
+ # print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
546
+ # print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
547
+ # print('--------------------------------')
548
+ # for mask_idx, logits in output["mask_logits"].items():
549
+ # print(f"Logits for [MASK] at position {mask_idx}:")
550
+ # print(f' logits : {logits[:5]}') # List of logits for all vocabulary tokens
551
+ # print(f' len(logits) : {len(logits)}')
552
+
553
+
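To make the structure above concrete, here is a short sketch of how a downstream consumer might walk the returned dictionary. The values are hand-written stand-ins, not real model output:

# Illustrative stand-in following the structure documented above
results = {
    "The quick brown fox jumps over the lazy dog .": {
        "masked_sentence": "The [MASK] brown fox jumps over the [MASK] dog .",
        "mask_logits": {
            1: {"tokens": ["quick", "swift"], "logits": [9.1, 8.4]},
            7: {"tokens": ["lazy", "sleepy"], "logits": [8.7, 7.9]},
        },
    }
}

for sentence, out in results.items():
    print("masked  :", out["masked_sentence"])
    for pos, preds in out["mask_logits"].items():
        print(f"  position {pos}: top candidate {preds['tokens'][0]} ({preds['logits'][0]:.2f})")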
554
+ # ------------------------------------------------------------------------------------------------
555
+ # def mask_sentence_random(self, sentence, common_ngrams):
556
+ # """
557
+ # Mask words in the sentence based on the specified rules after removing stopwords.
558
+ # """
559
+ # original_words = sentence.split()
560
+ # # print(f' ---- original_words : {original_words} ----- ')
561
+ # non_stop_words = self.remove_stopwords(original_words)
562
+ # # print(f' ---- non_stop_words : {non_stop_words} ----- ')
563
+ # adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
564
+ # # print(f' ---- common_ngrams : {common_ngrams} ----- ')
565
+ # # print(f' ---- adjusted_ngrams : {adjusted_ngrams} ----- ')
566
+
567
+ # mask_indices = []
568
+
569
+ # # Extract n-gram positions in non-stop words
570
+ # ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
571
+
572
+ # # Mask a word before the first common n-gram
573
+ # if ngram_positions:
574
+ # # print(f' ---- ngram_positions : {ngram_positions} ----- ')
575
+ # first_ngram_start = ngram_positions[0][0]
576
+ # # print(f' ---- first_ngram_start : {first_ngram_start} ----- ')
577
+ # if first_ngram_start > 0:
578
+ # mask_index_before_ngram = random.randint(0, first_ngram_start-1)
579
+ # # print(f' ---- mask_index_before_ngram : {mask_index_before_ngram} ----- ')
580
+ # mask_indices.append(mask_index_before_ngram)
581
+
582
+ # # Mask words between common n-grams
583
+ # for i in range(len(ngram_positions) - 1):
584
+ # end_prev = ngram_positions[i][1]
585
+ # # print(f' ---- end_prev : {end_prev} ----- ')
586
+ # start_next = ngram_positions[i + 1][0]
587
+ # # print(f' ---- start_next : {start_next} ----- ')
588
+ # if start_next > end_prev + 1:
589
+ # mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
590
+ # # print(f' ---- mask_index_between_ngrams : {mask_index_between_ngrams} ----- ')
591
+ # mask_indices.append(mask_index_between_ngrams)
592
+
593
+ # # Mask a word after the last common n-gram
594
+ # last_ngram_end = ngram_positions[-1][1]
595
+ # if last_ngram_end < len(non_stop_words) - 1:
596
+ # # print(f' ---- last_ngram_end : {last_ngram_end} ----- ')
597
+ # mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
598
+ # # print(f' ---- mask_index_after_ngram : {mask_index_after_ngram} ----- ')
599
+ # mask_indices.append(mask_index_after_ngram)
600
+
601
+ # # Create mapping from non-stop words to original indices
602
+ # non_stop_to_original = {}
603
+ # non_stop_idx = 0
604
+ # for orig_idx, word in enumerate(original_words):
605
+ # if word.lower() not in self.stop_words:
606
+ # non_stop_to_original[non_stop_idx] = orig_idx
607
+ # non_stop_idx += 1
608
+
609
+ # # Map mask indices from non-stop word positions to original positions
610
+ # # print(f' ---- non_stop_to_original : {non_stop_to_original} ----- ')
611
+ # original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
612
+ # # print(f' ---- original_mask_indices : {original_mask_indices} ----- ')
613
+
614
+ # # Apply masks to the original sentence
615
+ # masked_words = original_words.copy()
616
+ # for idx in original_mask_indices:
617
+ # masked_words[idx] = self.tokenizer.mask_token
618
+
619
+ # return " ".join(masked_words), original_mask_indices
utils/old/non_melting_points_v1.py ADDED
@@ -0,0 +1,244 @@
1
+ import nltk
2
+ from nltk.corpus import stopwords
3
+ from nltk.util import ngrams
4
+ from collections import Counter
5
+ import re
6
+
7
+ class NgramProcessor:
8
+ def __init__(self):
9
+ try:
10
+ nltk.data.find('corpora/stopwords')
11
+ except LookupError:
12
+ nltk.download('stopwords')
13
+
14
+ self.stop_words = set(stopwords.words('english'))
15
+
16
+ def remove_stopwords(self, text):
17
+ """
18
+ Remove stopwords using NLTK's stopword list
19
+
20
+ Args:
21
+ text (str): Input text
22
+
23
+ Returns:
24
+ str: Cleaned text with stopwords removed
25
+ """
26
+ words = re.findall(r'\w+', text.lower())
27
+ filtered_words = [word for word in words if word not in self.stop_words]
28
+ return ' '.join(filtered_words)
29
+
30
+ def is_exact_match(self, ngram, sentences):
31
+ """
32
+ Check if the given n-gram has an exact match in all sentences
33
+
34
+ Args:
35
+ ngram (str): The n-gram to search for
36
+ sentences (list): List of sentences to search in
37
+
38
+ Returns:
39
+ bool: True if n-gram has exact match in all sentences, False otherwise
40
+ """
41
+ return all(ngram in sentence for sentence in sentences)
42
+
43
+ def is_substring_of_any(self, ngram, common_ngrams):
44
+ """
45
+ Check if the given n-gram is an exact substring of any previously found common n-grams
46
+
47
+ Args:
48
+ ngram (str): The n-gram to check
49
+ common_ngrams (list): List of previously found common n-grams
50
+
51
+ Returns:
52
+ bool: True if ngram is a substring of any common_ngrams, False otherwise
53
+ """
54
+ return any(ngram in other_ngram for other_ngram in common_ngrams if ngram != other_ngram)
55
+
56
+ def find_filtered_ngrams(self, sentences):
57
+ """
58
+ Find all n-grams that have exact matches across all sentences,
59
+ excluding those that are part of larger common n-grams
60
+
61
+ Args:
62
+ sentences (list): List of sentences to analyze
63
+
64
+ Returns:
65
+ list: List of tuples where each tuple contains the n-gram and its indices in each sentence
66
+ """
67
+ original_sentences = sentences[:]
68
+ sentences = [self.remove_stopwords(sentence) for sentence in sentences]
69
+ ngram_lengths = [4, 3, 2, 1] # Quadgram, trigram, bigram, unigram
70
+ common_ngrams = []
71
+
72
+ for n in ngram_lengths:
73
+ ngrams_list = [list(ngrams(sentence.split(), n)) for sentence in sentences]
74
+ ngrams_counter = Counter(ngrams_list[0])
75
+
76
+ for ngram in ngrams_counter:
77
+ ngram_str = ' '.join(ngram)
78
+ if self.is_exact_match(ngram_str, sentences) and not self.is_substring_of_any(ngram_str, [ng[0] for ng in common_ngrams]):
79
+ indices = []
80
+ for original_sentence in original_sentences:
81
+ words = original_sentence.split()
82
+ ngram_indices = [
83
+ (i, i + n - 1) for i in range(len(words) - n + 1)
84
+ if ' '.join(words[i:i + n]).lower() == ngram_str
85
+ ]
86
+ indices.append(ngram_indices)
87
+ common_ngrams.append((ngram_str, indices))
88
+
89
+ return common_ngrams
90
+
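The two building blocks used here, NLTK stopword removal and nltk.util.ngrams, can be exercised on their own; a small self-contained sketch (assuming the NLTK stopword corpus is available):

import nltk
from nltk.corpus import stopwords
from nltk.util import ngrams

nltk.download("stopwords", quiet=True)
stop_words = set(stopwords.words("english"))

text = "The quick brown fox jumps over the lazy dog."
content_words = [w for w in text.lower().rstrip(".").split() if w not in stop_words]
print(content_words)                   # stopwords such as "the" and "over" are dropped
print(list(ngrams(content_words, 2)))  # adjacent pairs that become candidate bigrams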
91
+ def find_relative_order(self, sentence, common_ngrams):
92
+ """
93
+ Find the relative order of the common n-grams in the sentence
94
+
95
+ Args:
96
+ sentence (str): Sentence in which to find the relative order
97
+ common_ngrams (list): List of common n-grams
98
+
99
+ Returns:
100
+ list: List of tuples with the relative position and the n-gram
101
+ """
102
+ relative_order = []
103
+ for ngram, _ in common_ngrams:
104
+ index = sentence.find(ngram)
105
+ if index != -1:
106
+ relative_order.append((index, ngram))
107
+
108
+ return sorted(relative_order)
109
+
110
+ # Example usage
111
+ if __name__ == "__main__":
112
+ sentences = [
113
+ "The quick brown fox jumps over the lazy dog.",
114
+ "A quick brown dog outpaces a lazy fox.",
115
+ "Quick brown animals leap over lazy obstacles."
116
+ ]
117
+
118
+ processor = NgramProcessor()
119
+ common_ngrams = processor.find_filtered_ngrams(sentences)
120
+ print("Common n-grams and their indices:")
121
+ for ngram, indices in common_ngrams:
122
+ print(f"{ngram}: {indices}")
123
+
124
+ for sentence in sentences:
125
+ relative_order = processor.find_relative_order(sentence, common_ngrams)
126
+ print(f"Relative order in sentence '{sentence}':", relative_order)
127
+
128
+
129
+
130
+ # import nltk
131
+ # from nltk.corpus import stopwords
132
+ # from nltk.util import ngrams
133
+ # from collections import Counter
134
+ # import re
135
+
136
+ # class NgramProcessor:
137
+ # def __init__(self):
138
+ # try:
139
+ # nltk.data.find('corpora/stopwords')
140
+ # except LookupError:
141
+ # nltk.download('stopwords')
142
+
143
+ # self.stop_words = set(stopwords.words('english'))
144
+
145
+ # def remove_stopwords(self, text):
146
+ # """
147
+ # Remove stopwords using NLTK's stopword list
148
+
149
+ # Args:
150
+ # text (str): Input text
151
+
152
+ # Returns:
153
+ # str: Cleaned text with stopwords removed
154
+ # """
155
+ # words = re.findall(r'\w+', text.lower())
156
+ # filtered_words = [word for word in words if word not in self.stop_words]
157
+ # return ' '.join(filtered_words)
158
+
159
+ # def is_exact_match(self, ngram, sentences):
160
+ # """
161
+ # Check if the given n-gram has an exact match in all sentences
162
+
163
+ # Args:
164
+ # ngram (str): The n-gram to search for
165
+ # sentences (list): List of sentences to search in
166
+
167
+ # Returns:
168
+ # bool: True if n-gram has exact match in all sentences, False otherwise
169
+ # """
170
+ # return all(ngram in sentence for sentence in sentences)
171
+
172
+ # def is_substring_of_any(self, ngram, common_ngrams):
173
+ # """
174
+ # Check if the given n-gram is an exact substring of any previously found common n-grams
175
+
176
+ # Args:
177
+ # ngram (str): The n-gram to check
178
+ # common_ngrams (list): List of previously found common n-grams
179
+
180
+ # Returns:
181
+ # bool: True if ngram is a substring of any common_ngrams, False otherwise
182
+ # """
183
+ # return any(ngram in other_ngram for other_ngram in common_ngrams if ngram != other_ngram)
184
+
185
+ # def find_filtered_ngrams(self, sentences):
186
+ # """
187
+ # Find all n-grams that have exact matches across all sentences,
188
+ # excluding those that are part of larger common n-grams
189
+
190
+ # Args:
191
+ # sentences (list): List of sentences to analyze
192
+
193
+ # Returns:
194
+ # list: List of all common n-grams in order of their appearance in the first sentence
195
+ # """
196
+ # sentences = [self.remove_stopwords(sentence) for sentence in sentences]
197
+ # ngram_lengths = [4, 3, 2, 1] # Quadgram, trigram, bigram, unigram
198
+ # common_ngrams = []
199
+
200
+ # for n in ngram_lengths:
201
+ # ngrams_list = [list(ngrams(sentence.split(), n)) for sentence in sentences]
202
+ # ngrams_counter = Counter(ngrams_list[0])
203
+
204
+ # for ngram in ngrams_counter:
205
+ # ngram_str = ' '.join(ngram)
206
+ # if self.is_exact_match(ngram_str, sentences) and not self.is_substring_of_any(ngram_str, common_ngrams):
207
+ # common_ngrams.append(ngram_str)
208
+
209
+ # return common_ngrams
210
+
211
+ # def find_relative_order(self, sentence, common_ngrams):
212
+ # """
213
+ # Find the relative order of the common n-grams in the sentence
214
+
215
+ # Args:
216
+ # sentence (str): Sentence in which to find the relative order
217
+ # common_ngrams (list): List of common n-grams
218
+
219
+ # Returns:
220
+ # list: List of tuples with the relative position and the n-gram
221
+ # """
222
+ # relative_order = []
223
+ # for ngram in common_ngrams:
224
+ # index = sentence.find(ngram)
225
+ # if index != -1:
226
+ # relative_order.append((index, ngram))
227
+
228
+ # return sorted(relative_order)
229
+
230
+ # # Example usage
231
+ # if __name__ == "__main__":
232
+ # sentences = [
233
+ # "The quick brown fox jumps over the lazy dog.",
234
+ # "A quick brown dog outpaces a lazy fox.",
235
+ # "Quick brown animals leap over lazy obstacles."
236
+ # ]
237
+
238
+ # processor = NgramProcessor()
239
+ # common_ngrams = processor.find_filtered_ngrams(sentences)
240
+ # print("Common n-grams:", common_ngrams)
241
+
242
+ # for sentence in sentences:
243
+ # relative_order = processor.find_relative_order(sentence, common_ngrams)
244
+ # print(f"Relative order in sentence '{sentence}':", relative_order)
utils/old/sampling/sampling.py ADDED
@@ -0,0 +1,330 @@
1
+ import torch
2
+ import random
3
+ from masking_methods import MaskingProcessor
4
+ import nltk
5
+ from nltk.corpus import words
6
+ import torch.nn.functional as F
7
+
8
+
9
+ class SamplingProcessor:
10
+ def __init__(self, tokenizer):
11
+ """
12
+ Initialize the SamplingProcessor.
13
+
14
+ Args:
15
+ tokenizer: BERT tokenizer instance
16
+ """
17
+ self.tokenizer = tokenizer
18
+ self.subtoken_prefix = self._get_subtoken_prefix()
19
+ self.subtoken_ids = self._get_subtoken_ids()
20
+ try:
21
+ nltk.data.find('corpora/words')
22
+ except LookupError:
23
+ nltk.download('words')
24
+ self.english_words = set(words.words())
25
+
26
+ # def _get_subtoken_prefix(self):
27
+ # """
28
+ # Identify the subtoken prefix based on the tokenizer.
29
+
30
+ # Returns:
31
+ # str: The prefix used for subtokens (e.g., "##" for BERT).
32
+ # """
33
+ # # This method assumes that the tokenizer uses a consistent subtoken prefix.
34
+ # # Adjust accordingly if using different tokenizers.
35
+ # # For BERT's WordPiece tokenizer:
36
+ # if hasattr(self.tokenizer, "init_kwargs") and "wordpiece_prefix" in self.tokenizer.init_kwargs:
37
+ # return self.tokenizer.init_kwargs["wordpiece_prefix"]
38
+ # elif hasattr(self.tokenizer, "prefix_tokens"):
39
+ # return self.tokenizer.prefix_tokens
40
+ # else:
41
+ # # Default to BERT's subtoken prefix
42
+ # return "##"
43
+
44
+ def _get_subtoken_prefix(self):
45
+ """
46
+ Identify the subtoken prefix based on the tokenizer.
47
+
48
+ Returns:
49
+ str: The prefix used for subtokens (e.g., "##" for BERT).
50
+ """
51
+ # This method assumes that the tokenizer uses a consistent subtoken prefix.
52
+ # Adjust accordingly if using different tokenizers.
53
+ # For BERT's WordPiece tokenizer:
54
+ if hasattr(self.tokenizer, "init_kwargs") and "wordpiece_prefix" in self.tokenizer.init_kwargs:
55
+ return self.tokenizer.init_kwargs["wordpiece_prefix"]
56
+ elif hasattr(self.tokenizer, "prefix_tokens"):
57
+ return self.tokenizer.prefix_tokens
58
+ else:
59
+ # Default to BERT's subtoken prefix
60
+ return "##"
61
+
62
+
63
+ # def _get_subtoken_ids(self):
64
+ # """
65
+ # Retrieve all token IDs that correspond to subtokens.
66
+
67
+ # Returns:
68
+ # set: A set of subtoken IDs.
69
+ # """
70
+ # vocab = self.tokenizer.get_vocab()
71
+ # subtoken_ids = set()
72
+ # for token, idx in vocab.items():
73
+ # if token.startswith(self.subtoken_prefix):
74
+ # subtoken_ids.add(idx)
75
+ # return subtoken_ids
76
+
77
+ def _get_subtoken_ids(self):
78
+ """
79
+ Retrieve all token IDs that correspond to subtokens.
80
+
81
+ Returns:
82
+ list: A list of subtoken IDs.
83
+ """
84
+ vocab = self.tokenizer.get_vocab()
85
+ subtoken_ids = []
86
+ for token, idx in vocab.items():
87
+ if token.startswith(self.subtoken_prefix):
88
+ subtoken_ids.append(idx)
89
+ return subtoken_ids # Changed from set to list
90
+
91
+
92
+ def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0):
93
+ tokens = self.tokenizer.tokenize(masked_sentence)
94
+
95
+ for mask_pos in sorted(mask_logits_dict.keys()):
96
+ try:
97
+ # Get logits and squeeze extra dimension
98
+ mask_logits = torch.tensor(mask_logits_dict[mask_pos]).squeeze(0) # Remove the extra dimension
99
+
100
+ # Create a mask for valid tokens (no special tokens, no subwords)
101
+ valid_mask = torch.zeros_like(mask_logits, dtype=torch.bool)
102
+ for idx in range(len(mask_logits)):
103
+ token = self.tokenizer.convert_ids_to_tokens([idx])[0]
104
+ # Only allow regular words (no special tokens, no subwords)
105
+ if token.isalpha() and not token.startswith('[') and not token.startswith('##'):
106
+ valid_mask[idx] = True
107
+
108
+ # Get valid logits
109
+ valid_logits = mask_logits[valid_mask]
110
+ valid_indices = torch.where(valid_mask)[0]
111
+
112
+ if len(valid_logits) == 0:
113
+ print(f"Warning: No valid tokens found for position {mask_pos}")
114
+ continue
115
+
116
+ if sampling_technique == "inverse_transform":
117
+ probs = torch.softmax(valid_logits / temperature, dim=-1)
118
+ cumulative_probs = torch.cumsum(probs, dim=-1)
119
+ random_prob = random.random()
120
+ sampled_idx = torch.where(cumulative_probs >= random_prob)[0][0].item()
121
+ sampled_index = valid_indices[sampled_idx].item()
122
+
123
+ elif sampling_technique == "exponential_minimum":
124
+ probs = torch.softmax(valid_logits / temperature, dim=-1)
125
+ exp_probs = torch.exp(-torch.log(probs))
126
+ random_probs = torch.rand_like(exp_probs)
127
+ sampled_idx = torch.argmax(random_probs * exp_probs).item()
128
+ sampled_index = valid_indices[sampled_idx].item()
129
+
130
+ elif sampling_technique == "temperature":
131
+ valid_logits = torch.clamp(valid_logits, min=-1e8, max=1e8)
132
+ probs = torch.softmax(valid_logits / temperature, dim=-1)
133
+ if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
134
+ raise ValueError("The computed probabilities contain NaN or inf values.")
135
+ probs = torch.max(probs, torch.tensor(1e-8))
136
+ probs = probs / torch.sum(probs)
137
+ sampled_idx = torch.multinomial(probs, 1)[0].item()
138
+ sampled_index = valid_indices[sampled_idx].item()
139
+
140
+ elif sampling_technique == 'greedy':
141
+ sampled_idx = torch.argmax(valid_logits).item()
142
+ sampled_index = valid_indices[sampled_idx].item()
143
+
144
+ # Replace mask with sampled token
145
+ sampled_token = self.tokenizer.convert_ids_to_tokens([sampled_index])[0]
146
+ tokens[mask_pos] = sampled_token
147
+
148
+ except Exception as e:
149
+ print(f"Error sampling for position {mask_pos}: {str(e)}")
150
+ continue
151
+
152
+ return self.tokenizer.convert_tokens_to_string(tokens)
153
+
154
+
155
+
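The four branches above differ only in how an index is drawn from the logits of the valid candidates. A toy, self-contained version of the temperature and greedy branches (torch only, no model required):

import torch

logits = torch.tensor([2.0, 1.0, 0.1])               # toy scores for three candidate tokens
temperature = 1.0
probs = torch.softmax(logits / temperature, dim=-1)
sampled_idx = torch.multinomial(probs, 1)[0].item()  # temperature sampling
greedy_idx = torch.argmax(logits).item()             # greedy always picks index 0 here
print(probs, sampled_idx, greedy_idx)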
156
+ def process_masked_sentences(self, results_dict, sampling_technique="temperature", temperature=1.0):
157
+ """
158
+ Process all masked sentences in the results dictionary.
159
+
160
+ Args:
161
+ results_dict (dict): Dictionary containing masked sentences and their logits
162
+ sampling_technique (str): Sampling method to use
163
+ temperature (float): Temperature parameter for sampling
164
+
165
+ Returns:
166
+ dict: Dictionary containing original, masked, and sampled sentences
167
+ """
168
+ processed_results = {}
169
+
170
+ for original_sentence, data in results_dict.items():
171
+ masked_sentence = data["masked_sentence"]
172
+ mask_logits = data["mask_logits"]
173
+
174
+ sampled_sentence = self.sample_tokens(
175
+ mask_logits,
176
+ masked_sentence,
177
+ sampling_technique,
178
+ temperature
179
+ )
180
+
181
+ processed_results[original_sentence] = {
182
+ "masked_sentence": masked_sentence,
183
+ "sampled_sentence": sampled_sentence
184
+ }
185
+
186
+ return processed_results
187
+
188
+ if __name__ == "__main__":
189
+ sentences = [
190
+ "The quick brown fox jumps over the lazy dog everyday.",
191
+ ]
192
+ result_dict = {
193
+ 'The quick brown fox jumps over the lazy dog everyday.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
194
+ }
195
+
196
+ # First, mask the sentences
197
+ masking_processor = MaskingProcessor()
198
+ masking_results = masking_processor.process_sentences(sentences, result_dict)
199
+
200
+ # Then, sample replacements for the masks
201
+ sampling_processor = SamplingProcessor(masking_processor.tokenizer)
202
+
203
+ # Try different sampling techniques
204
+ sampling_techniques = ["temperature", "greedy", "inverse_transform", "exponential_minimum"]
205
+
206
+ for technique in sampling_techniques:
207
+ print(f"\nSampling using {technique}:")
208
+ sampled_results = sampling_processor.process_masked_sentences(
209
+ masking_results,
210
+ sampling_technique=technique,
211
+ temperature=1.0
212
+ )
213
+
214
+ for original_sentence, result in sampled_results.items():
215
+ print(f"Original: {original_sentence}")
216
+ print(f"Masked: {result['masked_sentence']}")
217
+ print(f"Sampled: {result['sampled_sentence']}")
218
+ print("---")
219
+
220
+ # --------------------------------------------------------------------------------------------------
221
+ # def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0, top_k=100):
222
+ # words = masked_sentence.split()
223
+ # mask_positions = sorted(mask_logits_dict.keys())
224
+
225
+ # for mask_pos in mask_positions:
226
+ # mask_logits = torch.tensor(mask_logits_dict[mask_pos])
227
+
228
+ # try:
229
+ # if sampling_technique == "inverse_transform":
230
+ # probs = torch.softmax(mask_logits / temperature, dim=-1)
231
+ # cumulative_probs = torch.cumsum(probs, dim=-1)
232
+ # random_prob = random.random()
233
+ # sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()
234
+
235
+ # elif sampling_technique == "exponential_minimum":
236
+ # probs = torch.softmax(mask_logits / temperature, dim=-1)
237
+ # exp_probs = torch.exp(-torch.log(probs))
238
+ # random_probs = torch.rand_like(exp_probs)
239
+ # sampled_index = torch.argmax(random_probs * exp_probs).item()
240
+
241
+ # elif sampling_technique == "temperature":
242
+ # mask_logits = torch.clamp(mask_logits, min=-1e8, max=1e8)
243
+ # probs = torch.softmax(mask_logits / temperature, dim=-1)
244
+ # if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
245
+ # raise ValueError("The computed probabilities contain NaN or inf values.")
246
+ # probs = torch.max(probs, torch.tensor(1e-8))
247
+ # probs = probs / torch.sum(probs)
248
+ # sampled_index = torch.multinomial(probs, 1)[0].item()
249
+
250
+ # elif sampling_technique == 'greedy':
251
+ # sampled_index = torch.argmax(mask_logits).item()
252
+
253
+ # else:
254
+ # raise ValueError(f"Unknown sampling technique: {sampling_technique}")
255
+
256
+ # # Replace mask with sampled token
257
+ # sampled_token = self.tokenizer.convert_ids_to_tokens([sampled_index])[0]
258
+ # words[mask_pos] = sampled_token
259
+
260
+ # except Exception as e:
261
+ # print(f"Error sampling for position {mask_pos}: {str(e)}")
262
+ # continue
263
+
264
+ # return " ".join(words)
265
+
266
+ ## MORE WEIRD RESULTS
267
+ # def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0, top_k=100):
268
+ # words = masked_sentence.split()
269
+ # mask_positions = sorted(mask_logits_dict.keys())
270
+
271
+ # for mask_pos in mask_positions:
272
+ # mask_logits = torch.tensor(mask_logits_dict[mask_pos])
273
+
274
+ # try:
275
+ # # Create a mask for valid tokens (no special tokens, no subwords)
276
+ # valid_mask = torch.zeros_like(mask_logits, dtype=torch.bool)
277
+ # for idx in range(len(mask_logits)):
278
+ # token = self.tokenizer.convert_ids_to_tokens([idx])[0]
279
+ # # Only allow regular words (no special tokens, no subwords)
280
+ # if token.isalpha() and not token.startswith('[') and not token.startswith('##'):
281
+ # valid_mask[idx] = True
282
+
283
+ # # Get valid logits
284
+ # valid_logits = mask_logits[valid_mask]
285
+ # valid_indices = torch.where(valid_mask)[0]
286
+
287
+ # if len(valid_logits) == 0:
288
+ # print(f"Warning: No valid tokens found for position {mask_pos}")
289
+ # continue
290
+
291
+ # if sampling_technique == "inverse_transform":
292
+ # probs = torch.softmax(valid_logits / temperature, dim=-1)
293
+ # cumulative_probs = torch.cumsum(probs, dim=-1)
294
+ # random_prob = random.random()
295
+ # sampled_idx = torch.where(cumulative_probs >= random_prob)[0][0].item()
296
+ # sampled_index = valid_indices[sampled_idx].item()
297
+
298
+ # elif sampling_technique == "exponential_minimum":
299
+ # probs = torch.softmax(valid_logits / temperature, dim=-1)
300
+ # exp_probs = torch.exp(-torch.log(probs))
301
+ # random_probs = torch.rand_like(exp_probs)
302
+ # sampled_idx = torch.argmax(random_probs * exp_probs).item()
303
+ # sampled_index = valid_indices[sampled_idx].item()
304
+
305
+ # elif sampling_technique == "temperature":
306
+ # valid_logits = torch.clamp(valid_logits, min=-1e8, max=1e8)
307
+ # probs = torch.softmax(valid_logits / temperature, dim=-1)
308
+ # if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
309
+ # raise ValueError("The computed probabilities contain NaN or inf values.")
310
+ # probs = torch.max(probs, torch.tensor(1e-8))
311
+ # probs = probs / torch.sum(probs)
312
+ # sampled_idx = torch.multinomial(probs, 1)[0].item()
313
+ # sampled_index = valid_indices[sampled_idx].item()
314
+
315
+ # elif sampling_technique == 'greedy':
316
+ # sampled_idx = torch.argmax(valid_logits).item()
317
+ # sampled_index = valid_indices[sampled_idx].item()
318
+
319
+ # else:
320
+ # raise ValueError(f"Unknown sampling technique: {sampling_technique}")
321
+
322
+ # # Replace mask with sampled token
323
+ # sampled_token = self.tokenizer.convert_ids_to_tokens([sampled_index])[0]
324
+ # words[mask_pos] = sampled_token
325
+
326
+ # except Exception as e:
327
+ # print(f"Error sampling for position {mask_pos}: {str(e)}")
328
+ # continue
329
+
330
+ # return " ".join(words)
utils/old/sampling/sampling_methods.py ADDED
@@ -0,0 +1,291 @@
1
+ from transformers import BertTokenizer, BertForMaskedLM
2
+ import torch
3
+ import random
4
+ from masking_methods import MaskingProcessor
5
+ from transformers import pipeline
6
+
7
+ class SamplingProcessorWithModel:
8
+ def __init__(self, model_name='bert-base-uncased'):
9
+ self.tokenizer = BertTokenizer.from_pretrained(model_name)
10
+ self.model = BertForMaskedLM.from_pretrained(model_name)
11
+ self.model.eval() # Set the model to evaluation mode
+ self.unmasker = pipeline('fill-mask', model=model_name) # used by the pipeline-based variant below
12
+
13
+ def fill_masked_sentence(self, masked_sentence, sampling_technique, temperature=1.0):
14
+ """
15
+ Fills each mask in the masked sentence using the specified sampling technique.
16
+
17
+ Args:
18
+ masked_sentence (str): Sentence with [MASK] tokens.
19
+ sampling_technique (str): Sampling technique to use (e.g., "inverse_transform", "exponential_minimum", "temperature", "greedy").
20
+ temperature (float): Temperature parameter for sampling methods.
21
+
22
+ Returns:
23
+ str: Sentence with the masks filled.
24
+ """
25
+ input_ids = self.tokenizer.encode(masked_sentence, return_tensors="pt")
26
+
27
+ while self.tokenizer.mask_token_id in input_ids[0]:
28
+ # Find indices of all [MASK] tokens
29
+ mask_indices = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
30
+
31
+ # Process the first [MASK] token in the sequence
32
+ mask_index = mask_indices[0].item()
33
+
34
+ # Get logits from the model
35
+ with torch.no_grad():
36
+ outputs = self.model(input_ids)
37
+ logits = outputs.logits
38
+
39
+ # Extract logits for the [MASK] token
40
+ mask_logits = logits[0, mask_index]
41
+
42
+ if sampling_technique == "inverse_transform":
43
+ probs = torch.softmax(mask_logits / temperature, dim=-1)
44
+ cumulative_probs = torch.cumsum(probs, dim=-1)
45
+ random_prob = random.random()
46
+ sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()
47
+
48
+ elif sampling_technique == "exponential_minimum":
49
+ probs = torch.softmax(mask_logits / temperature, dim=-1)
50
+ exp_probs = torch.exp(-torch.log(probs))
51
+ random_probs = torch.rand_like(exp_probs)
52
+ sampled_index = torch.argmax(random_probs * exp_probs).item()
53
+
54
+ elif sampling_technique == "temperature":
55
+ mask_logits = torch.clamp(mask_logits, min=-1e8, max=1e8)
56
+ probs = torch.softmax(mask_logits / temperature, dim=-1)
57
+ if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
58
+ raise ValueError("The computed probabilities contain NaN or inf values.")
59
+ probs = torch.max(probs, torch.tensor(1e-8, device=mask_logits.device))
60
+ probs = probs / torch.sum(probs)
61
+ probs = probs.flatten()
62
+ if probs.size(0) > 1:
63
+ sampled_index = torch.multinomial(probs, 1).item()
64
+ else:
65
+ sampled_index = torch.argmax(probs).item()
66
+
67
+ elif sampling_technique == 'greedy':
68
+ sampled_index = torch.argmax(mask_logits).item()
69
+
70
+ else:
71
+ raise ValueError(f"Unknown sampling technique: {sampling_technique}")
72
+
73
+ # Replace the first [MASK] with the selected token
74
+ input_ids[0, mask_index] = sampled_index
75
+
76
+ return self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
77
+
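A minimal illustration of the mask-locating step used in the loop above (a sketch assuming torch and the bert-base-uncased tokenizer):

import torch
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
ids = tok.encode("The [MASK] brown fox jumps over the [MASK] dog.", return_tensors="pt")
mask_positions = torch.where(ids == tok.mask_token_id)[1]
print(mask_positions)  # column indices of the two [MASK] tokens in the encoded batch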
78
+ def fill_masked_sentence_with_pipeline(self, masked_sentence, sampling_technique, temperature=1.0):
79
+ """
80
+ Fills each mask using the fill-mask pipeline's candidate scores rather than raw model logits.
81
+
82
+ Args:
83
+ masked_sentence (str): Sentence with [MASK] tokens.
84
+ sampling_technique (str): Sampling technique to use (e.g., "inverse_transform", "exponential_minimum", "temperature", "greedy").
85
+ temperature (float): Temperature parameter for sampling methods.
86
+
87
+ Returns:
88
+ str: Sentence with the masks filled.
89
+ """
90
+ while '[MASK]' in masked_sentence:
91
+ # Get predictions for the first [MASK]
92
+ predictions = self.unmasker(masked_sentence)
93
+
94
+ # Ensure predictions is a list of dictionaries
95
+ if not isinstance(predictions, list) or not all(isinstance(pred, dict) for pred in predictions):
96
+ raise ValueError("Unexpected structure in predictions from the pipeline.")
97
+
98
+ # Extract logits (scores) from the predictions
99
+ logits = torch.tensor([pred['score'] for pred in predictions], dtype=torch.float32)
100
+
101
+ if sampling_technique == "inverse_transform":
102
+ probs = torch.softmax(logits / temperature, dim=-1)
103
+ cumulative_probs = torch.cumsum(probs, dim=-1)
104
+ random_prob = random.random()
105
+ sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()
106
+
107
+ elif sampling_technique == "exponential_minimum":
108
+ probs = torch.softmax(logits / temperature, dim=-1)
109
+ exp_probs = torch.exp(-torch.log(probs))
110
+ random_probs = torch.rand_like(exp_probs)
111
+ sampled_index = torch.argmax(random_probs * exp_probs).item()
112
+
113
+ elif sampling_technique == "temperature":
114
+ logits = torch.clamp(logits, min=-1e8, max=1e8)
115
+ probs = torch.softmax(logits / temperature, dim=-1)
116
+ if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
117
+ raise ValueError("The computed probabilities contain NaN or inf values.")
118
+ probs = torch.max(probs, torch.tensor(1e-8, device=logits.device))
119
+ probs = probs / torch.sum(probs)
120
+ probs = probs.flatten()
121
+ if probs.size(0) > 1:
122
+ sampled_index = torch.multinomial(probs, 1).item()
123
+ else:
124
+ sampled_index = torch.argmax(probs).item()
125
+
126
+ elif sampling_technique == 'greedy':
127
+ sampled_index = torch.argmax(logits).item()
128
+
129
+ else:
130
+ raise ValueError(f"Unknown sampling technique: {sampling_technique}")
131
+
132
+ # Replace the first [MASK] with the selected word
133
+ sampled_token = predictions[sampled_index]['token_str']
134
+ masked_sentence = masked_sentence.replace('[MASK]', sampled_token, 1)
135
+
136
+ return masked_sentence
137
+
138
+
139
+ # Example usage
140
+ if __name__ == "__main__":
141
+ from transformers import BertTokenizer
142
+
143
+ # Define sentences and result_dict
144
+ sentences = [
145
+ "The quick brown fox jumps over the lazy dog.",
146
+ "A quick brown dog outpaces a lazy fox.",
147
+ "Quick brown dog leaps over lazy the fox."
148
+ ]
149
+ result_dict = {
150
+ "The quick brown fox jumps over the lazy dog.": {'quick brown': [(0, 1)], 'fox': [(2, 2)], 'lazy': [(4, 4)], 'dog': [(5, 5)]},
151
+ "A quick brown dog outpaces a lazy fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]},
152
+ "Quick brown dog leaps over lazy the fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]}
153
+ }
154
+
155
+ masking_processor = MaskingProcessor()
156
+ masking_results = masking_processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)
157
+
158
+ # Use SamplingProcessor
159
+ sampling_processor = SamplingProcessorWithModel()
160
+
161
+ # Iterate through masking results to apply sampling
162
+ for sentence, result in masking_results.items():
163
+ print(f"Original Sentence (Random): {sentence}")
164
+ print(f"Masked Sentence (Random): {result['masked_sentence']}")
165
+ masked_sentence = result["masked_sentence"]
166
+
167
+ # Apply different sampling techniques
168
+ for technique in ["inverse_transform", "exponential_minimum", "temperature", "greedy"]:
169
+ print(f"Sampling Technique: {technique}")
170
+ filled_sentence = sampling_processor.fill_masked_sentence(
171
+ masked_sentence=masked_sentence,
172
+ sampling_technique=technique,
173
+ temperature=1.0 # Adjust temperature as needed
174
+ )
175
+ print(f"Filled Sentence: {filled_sentence}\n")
176
+ print('--------------------------------')
177
+
178
+
179
+
180
+ # from transformers import pipeline
181
+ # import torch
182
+ # import random
183
+ # from masking_methods import MaskingProcessor
184
+
185
+
186
+ # class SamplingProcessorWithPipeline:
187
+ # def __init__(self, model_name='bert-base-uncased'):
188
+ # self.unmasker = pipeline('fill-mask', model=model_name)
189
+ # self.tokenizer = self.unmasker.tokenizer
190
+
191
+ # def fill_masked_sentence(self, masked_sentence, sampling_technique, temperature=1.0):
192
+ # """
193
+ # Fills each mask in the masked sentence using the specified sampling technique.
194
+
195
+ # Args:
196
+ # masked_sentence (str): Sentence with [MASK] tokens.
197
+ # sampling_technique (str): Sampling technique to use (e.g., "inverse_transform", "exponential_minimum", "temperature", "greedy").
198
+ # temperature (float): Temperature parameter for sampling methods.
199
+
200
+ # Returns:
201
+ # str: Sentence with the masks filled.
202
+ # """
203
+ # while '[MASK]' in masked_sentence:
204
+ # # Get predictions for the first [MASK]
205
+ # predictions = self.unmasker(masked_sentence)
206
+ # print(f' predictions : {predictions}')
207
+ # print(f' type of predictions : {type(predictions)}')
208
+
209
+ # # Ensure predictions is a list of dictionaries for the first [MASK]
210
+ # if not isinstance(predictions, list) or not all(isinstance(pred, dict) for pred in predictions):
211
+ # raise ValueError("Unexpected structure in predictions from the pipeline.")
212
+
213
+ # # Extract logits (scores) from the predictions
214
+ # logits = torch.tensor([pred['score'] for pred in predictions], dtype=torch.float32)
215
+
216
+ # if sampling_technique == "inverse_transform":
217
+ # probs = torch.softmax(logits / temperature, dim=-1)
218
+ # cumulative_probs = torch.cumsum(probs, dim=-1)
219
+ # random_prob = random.random()
220
+ # sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()
221
+
222
+ # elif sampling_technique == "exponential_minimum":
223
+ # probs = torch.softmax(logits / temperature, dim=-1)
224
+ # exp_probs = torch.exp(-torch.log(probs))
225
+ # random_probs = torch.rand_like(exp_probs)
226
+ # sampled_index = torch.argmax(random_probs * exp_probs).item()
227
+
228
+ # elif sampling_technique == "temperature":
229
+ # logits = torch.clamp(logits, min=-1e8, max=1e8)
230
+ # probs = torch.softmax(logits / temperature, dim=-1)
231
+ # if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
232
+ # raise ValueError("The computed probabilities contain NaN or inf values.")
233
+ # probs = torch.max(probs, torch.tensor(1e-8, device=logits.device))
234
+ # probs = probs / torch.sum(probs)
235
+ # probs = probs.flatten()
236
+ # if probs.size(0) > 1:
237
+ # sampled_index = torch.multinomial(probs, 1).item()
238
+ # else:
239
+ # sampled_index = torch.argmax(probs).item()
240
+
241
+ # elif sampling_technique == 'greedy':
242
+ # sampled_index = torch.argmax(logits).item()
243
+
244
+ # else:
245
+ # raise ValueError(f"Unknown sampling technique: {sampling_technique}")
246
+
247
+ # # Replace the first [MASK] with the selected word
248
+ # sampled_token = predictions[sampled_index]['token_str']
249
+ # masked_sentence = masked_sentence.replace('[MASK]', sampled_token, 1)
250
+
251
+ # return masked_sentence
252
+
253
+
254
+ # # Example usage
255
+ # if __name__ == "__main__":
256
+ # from transformers import BertTokenizer
257
+
258
+ # # Define sentences and result_dict
259
+ # sentences = [
260
+ # "The quick brown fox jumps over the lazy dog.",
261
+ # "A quick brown dog outpaces a lazy fox.",
262
+ # "Quick brown animals leap over lazy obstacles."
263
+ # ]
264
+ # result_dict = {
265
+ # "The quick brown fox jumps over the lazy dog.": {"quick brown": [(1, 2)], "lazy": [(7, 7)]},
266
+ # "A quick brown dog outpaces a lazy fox.": {"quick brown": [(1, 2)], "lazy": [(6, 6)]},
267
+ # "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(5, 5)]}
268
+ # }
269
+
270
+ # masking_processor = MaskingProcessor()
271
+ # masking_results = masking_processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)
272
+
273
+ # # Use SamplingProcessor
274
+ # sampling_processor = SamplingProcessorWithPipeline()
275
+
276
+ # # Iterate through masking results to apply sampling
277
+ # for sentence, result in masking_results.items():
278
+ # print(f"Original Sentence (Random): {sentence}")
279
+ # print(f"Masked Sentence (Random): {result['masked_sentence']}")
280
+ # masked_sentence = result["masked_sentence"]
281
+
282
+ # # Apply different sampling techniques
283
+ # for technique in ["inverse_transform", "exponential_minimum", "temperature", "greedy"]:
284
+ # print(f"Sampling Technique: {technique}")
285
+ # filled_sentence = sampling_processor.fill_masked_sentence(
286
+ # masked_sentence=masked_sentence,
287
+ # sampling_technique=technique,
288
+ # temperature=1.0 # Adjust temperature as needed
289
+ # )
290
+ # print(f"Filled Sentence: {filled_sentence}\n")
291
+ # print('--------------------------------')
utils/old/sampling/sampling_methods_v1.py ADDED
@@ -0,0 +1,146 @@
1
+ import torch
2
+ import random
3
+ from masking_methods import MaskingProcessor
4
+
5
+ class SamplingProcessor:
6
+ def __init__(self, tokenizer):
7
+ self.tokenizer = tokenizer
8
+
9
+ def fill_masked_sentence(self, original_sentence, mask_logits, sampling_technique, temperature=1.0):
10
+ """
11
+ Fills each mask in the masked sentence using the specified sampling technique.
12
+
13
+ Args:
14
+ original_sentence (str): The original masked sentence.
15
+ mask_logits (dict): Logits for each [MASK] token.
16
+ sampling_technique (str): Sampling technique to use (e.g., "inverse_transform", "exponential_minimum", "temperature", "greedy").
17
+ temperature (float): Temperature parameter for sampling methods.
18
+
19
+ Returns:
20
+ str: Sentence with the masks filled.
21
+ """
22
+ sentence_tokens = self.tokenizer.tokenize(original_sentence)
23
+ mask_token_indices = [i for i, token in enumerate(sentence_tokens) if token == self.tokenizer.mask_token]
24
+
25
+ if len(mask_token_indices) != len(mask_logits):
26
+ raise ValueError("Mismatch between number of [MASK] tokens and logits provided.")
27
+
28
+ for mask_idx, filtered_logits in zip(mask_token_indices, mask_logits.values()):
29
+ # Convert logits to a tensor
30
+ filtered_logits = torch.tensor(filtered_logits)
31
+ # filtered_logits, _ = torch.sort(filtered_logits, descending=True)
32
+ # print(f' type of filtered_logits : {type(filtered_logits)}')
33
+ # filtered_logits = filtered_logits[:5]
34
+
35
+ if sampling_technique == "inverse_transform":
36
+ probs = torch.softmax(filtered_logits / temperature, dim=-1)
37
+ cumulative_probs = torch.cumsum(probs, dim=-1)
38
+ random_prob = random.random()
39
+ sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()
40
+
41
+ elif sampling_technique == "exponential_minimum":
42
+ probs = torch.softmax(filtered_logits / temperature, dim=-1)
43
+ exp_probs = torch.exp(-torch.log(probs))
44
+ random_probs = torch.rand_like(exp_probs)
45
+ sampled_index = torch.argmax(random_probs * exp_probs).item()
46
+
47
+ elif sampling_technique == "temperature":
48
+ filtered_logits = torch.clamp(filtered_logits, min=-1e8, max=1e8)
49
+ probs = torch.softmax(filtered_logits / temperature, dim=-1)
50
+ if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
51
+ raise ValueError("The computed probabilities contain NaN or inf values.")
52
+ probs = torch.max(probs, torch.tensor(1e-8, device=filtered_logits.device))
53
+ probs = probs / torch.sum(probs)
54
+ probs = probs.flatten()
55
+ if probs.size(0) > 1:
56
+ sampled_index = torch.multinomial(probs, 1).item()
57
+ else:
58
+ sampled_index = torch.argmax(probs).item()
59
+
60
+ elif sampling_technique == 'greedy':
61
+ sampled_index = torch.argmax(filtered_logits).item()
62
+
63
+ else:
64
+ raise ValueError(f"Unknown sampling technique: {sampling_technique}")
65
+
66
+ sampled_token = self.tokenizer.convert_ids_to_tokens([sampled_index])[0]
67
+ sentence_tokens[mask_idx] = sampled_token
68
+
69
+ return self.tokenizer.convert_tokens_to_string(sentence_tokens)
70
+
71
+
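The inverse-transform branch above draws a uniform random number and selects the first index whose cumulative probability reaches it. A toy, self-contained version:

import random
import torch

probs = torch.tensor([0.6, 0.3, 0.1])   # already-normalised candidate probabilities
cdf = torch.cumsum(probs, dim=-1)       # [0.6, 0.9, 1.0]
u = random.random()
idx = torch.where(cdf >= u)[0][0].item()
print(u, idx)                           # idx is 0 with probability 0.6, 1 with 0.3, 2 with 0.1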
72
+
73
+ def process_samples(self, masked_sentences, mask_logits, sampling_technique, temperature=1.0):
74
+ """
75
+ Process multiple masked sentences and fill their masks using the specified sampling technique.
76
+
77
+ Args:
78
+ masked_sentences (list): List of masked sentences.
79
+ mask_logits (dict): Logits for each [MASK] token in each sentence.
80
+ sampling_technique (str): Sampling technique to use.
81
+ temperature (float): Temperature parameter for sampling methods.
82
+
83
+ Returns:
84
+ list: List of sentences with masks filled.
85
+ """
86
+ filled_sentences = []
87
+ for sentence, logits in zip(masked_sentences, mask_logits):
88
+ filled_sentence = self.fill_masked_sentence(sentence, logits, sampling_technique, temperature)
89
+ filled_sentences.append(filled_sentence)
90
+ return filled_sentences
91
+
92
+ # Example usage
93
+ if __name__ == "__main__":
94
+ from transformers import BertTokenizer
95
+
96
+ # tokenizer = BertTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
97
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
98
+ processor = SamplingProcessor(tokenizer)
99
+
100
+ sentences = [
101
+ "The quick brown fox jumps over the lazy dog.",
102
+ "A quick brown dog outpaces a lazy fox.",
103
+ "Quick brown dog leaps over lazy the fox."
104
+ ]
105
+ result_dict = {
106
+ "The quick brown fox jumps over the lazy dog.": {'quick brown': [(0, 1)], 'fox': [(2, 2)], 'lazy': [(4, 4)], 'dog': [(5, 5)]},
107
+ "A quick brown dog outpaces a lazy fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]},
108
+ "Quick brown dog leaps over lazy the fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]}
109
+ }
110
+
111
+
112
+ masking_processor = MaskingProcessor()
113
+ masking_results = masking_processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)
114
+ # masked_sentence = "The [MASK] brown fox jumps [MASK] the lazy dog."
115
+ # mask_logits = {
116
+ # 1: torch.randn(len(tokenizer)), # Example logits for first [MASK]
117
+ # 5: torch.randn(len(tokenizer)), # Example logits for second [MASK]
118
+ # }
119
+
120
+ # Iterate through masking results to apply sampling
121
+ for sentence, result in masking_results.items():
122
+ print(f"Original Sentence (Random): {sentence}")
123
+ print(f"Masked Sentence (Random): {result['masked_sentence']}")
124
+ # print(f"Mask Logits (Random): {output['mask_logits']}")
125
+ print(f' type(result["mask_logits"]) : {type(result["mask_logits"])}')
126
+ print(f' length of result["mask_logits"] : {len(result["mask_logits"])}')
127
+ print(f' result["mask_logits"].keys() : {result["mask_logits"].keys()}')
128
+ masked_sentence = result["masked_sentence"]
129
+ mask_logits = result["mask_logits"]
130
+
131
+ print(f"Original Masked Sentence: {masked_sentence}")
132
+
133
+ # Apply different sampling techniques
134
+ for technique in ["inverse_transform", "exponential_minimum", "temperature", "greedy"]:
135
+ print(f"Sampling Technique: {technique}")
136
+
137
+ # Fill the masks using the sampling processor
138
+ filled_sentence = processor.fill_masked_sentence(
139
+ original_sentence=masked_sentence,
140
+ mask_logits=mask_logits,
141
+ sampling_technique=technique,
142
+ temperature=1.0 # Adjust temperature as needed
143
+ )
144
+
145
+ print(f"Filled Sentence: {filled_sentence}\n")
146
+ print('--------------------------------')
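Note: the four sampling branches above (inverse transform, exponential minimum, temperature, greedy) reappear essentially unchanged in the two files below. As a quick sanity check they can be exercised on a toy logits tensor without loading any model; the helper sample_index and the tensor values here are illustrative only and are not part of the repository.

import random
import torch

def sample_index(logits, technique, temperature=1.0):
    """Pick one index from a 1-D tensor of candidate scores (toy re-implementation of the branches above)."""
    if technique == "greedy":
        return torch.argmax(logits).item()
    probs = torch.softmax(logits / temperature, dim=-1)
    if technique == "inverse_transform":
        # walk the cumulative distribution until it passes a uniform random draw
        cdf = torch.cumsum(probs, dim=-1)
        return torch.where(cdf >= random.random())[0][0].item()
    if technique == "exponential_minimum":
        # same arithmetic as the repo code: exp(-log p) = 1/p, scaled by uniform noise
        weights = torch.rand_like(probs) * torch.exp(-torch.log(probs))
        return torch.argmax(weights).item()
    if technique == "temperature":
        return torch.multinomial(probs, 1).item()
    raise ValueError(f"Unknown sampling technique: {technique}")

toy_logits = torch.tensor([2.0, 1.0, 0.5, -1.0])  # made-up scores for four hypothetical candidates
for technique in ["greedy", "temperature", "inverse_transform", "exponential_minimum"]:
    print(technique, sample_index(toy_logits, technique))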
utils/old/sampling/sampling_methods_v2.py ADDED
@@ -0,0 +1,112 @@
+from transformers import pipeline
+import torch
+import random
+from masking_methods import MaskingProcessor
+
+
+class SamplingProcessorWithPipeline:
+    def __init__(self, model_name='bert-base-uncased'):
+        self.unmasker = pipeline('fill-mask', model=model_name)
+        self.tokenizer = self.unmasker.tokenizer
+
+    def fill_masked_sentence(self, masked_sentence, sampling_technique, temperature=1.0):
+        """
+        Fills each mask in the masked sentence using the specified sampling technique.
+
+        Args:
+            masked_sentence (str): Sentence with [MASK] tokens.
+            sampling_technique (str): Sampling technique to use (e.g., "inverse_transform", "exponential_minimum", "temperature", "greedy").
+            temperature (float): Temperature parameter for sampling methods.
+
+        Returns:
+            str: Sentence with the masks filled.
+        """
+        while '[MASK]' in masked_sentence:
+            # Get predictions for the first [MASK]
+            predictions = self.unmasker(masked_sentence)
+            print(f' predictions : {predictions}')
+            print(f' type of predictions : {type(predictions)}')
+
+            # Ensure predictions is a list of dictionaries
+            if not isinstance(predictions, list) or not all(isinstance(pred, dict) for pred in predictions):
+                raise ValueError("Unexpected structure in predictions from the pipeline.")
+
+            # Extract logits (scores) from the predictions
+            logits = torch.tensor([pred['score'] for pred in predictions], dtype=torch.float32)
+
+            if sampling_technique == "inverse_transform":
+                probs = torch.softmax(logits / temperature, dim=-1)
+                cumulative_probs = torch.cumsum(probs, dim=-1)
+                random_prob = random.random()
+                sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()
+
+            elif sampling_technique == "exponential_minimum":
+                probs = torch.softmax(logits / temperature, dim=-1)
+                exp_probs = torch.exp(-torch.log(probs))
+                random_probs = torch.rand_like(exp_probs)
+                sampled_index = torch.argmax(random_probs * exp_probs).item()
+
+            elif sampling_technique == "temperature":
+                logits = torch.clamp(logits, min=-1e8, max=1e8)
+                probs = torch.softmax(logits / temperature, dim=-1)
+                if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
+                    raise ValueError("The computed probabilities contain NaN or inf values.")
+                probs = torch.max(probs, torch.tensor(1e-8, device=logits.device))
+                probs = probs / torch.sum(probs)
+                probs = probs.flatten()
+                if probs.size(0) > 1:
+                    sampled_index = torch.multinomial(probs, 1).item()
+                else:
+                    sampled_index = torch.argmax(probs).item()
+
+            elif sampling_technique == 'greedy':
+                sampled_index = torch.argmax(logits).item()
+
+            else:
+                raise ValueError(f"Unknown sampling technique: {sampling_technique}")
+
+            # Replace the first [MASK] with the selected word
+            sampled_token = predictions[sampled_index]['token_str']
+            masked_sentence = masked_sentence.replace('[MASK]', sampled_token, 1)
+
+        return masked_sentence
+
+
+# Example usage
+if __name__ == "__main__":
+    from transformers import BertTokenizer
+
+    # Define sentences and result_dict
+    sentences = [
+        "The quick brown fox jumps over the lazy dog.",
+        "A quick brown dog outpaces a lazy fox.",
+        "Quick brown dog leaps over lazy the fox."
+    ]
+    result_dict = {
+        "The quick brown fox jumps over the lazy dog.": {'quick brown': [(0, 1)], 'fox': [(2, 2)], 'lazy': [(4, 4)], 'dog': [(5, 5)]},
+        "A quick brown dog outpaces a lazy fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]},
+        "Quick brown dog leaps over lazy the fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]}
+    }
+
+    masking_processor = MaskingProcessor()
+    masking_results = masking_processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)
+
+    # Use SamplingProcessor
+    sampling_processor = SamplingProcessorWithPipeline()
+
+    # Iterate through masking results to apply sampling
+    for sentence, result in masking_results.items():
+        print(f"Original Sentence (Random): {sentence}")
+        print(f"Masked Sentence (Random): {result['masked_sentence']}")
+        masked_sentence = result["masked_sentence"]
+
+        # Apply different sampling techniques
+        for technique in ["inverse_transform", "exponential_minimum", "temperature", "greedy"]:
+            print(f"Sampling Technique: {technique}")
+            filled_sentence = sampling_processor.fill_masked_sentence(
+                masked_sentence=masked_sentence,
+                sampling_technique=technique,
+                temperature=1.0  # Adjust temperature as needed
+            )
+            print(f"Filled Sentence: {filled_sentence}\n")
+            print('--------------------------------')
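A note on the pipeline variant above: transformers' fill-mask pipeline returns only a short top-k list of candidates per mask (five by default in recent releases, though this is worth verifying), so the sampling step chooses among that small list rather than the full vocabulary. Also, when the input still contains more than one [MASK], the pipeline returns one list per mask, which the isinstance check above rejects, so this variant only completes single-mask sentences as written. If a wider candidate pool is wanted, the pipeline accepts a top_k argument; a hedged sketch follows, where the model name and top_k value are chosen purely for illustration.

from transformers import pipeline

unmasker = pipeline("fill-mask", model="bert-base-uncased")
# top_k widens the candidate pool per mask; 50 is an arbitrary value chosen for illustration
predictions = unmasker("The quick brown [MASK] jumps over the lazy dog.", top_k=50)
print(len(predictions))                                      # number of candidates returned
print(predictions[0]["token_str"], predictions[0]["score"])  # best candidate and its score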
utils/old/sampling_final_copy.py ADDED
@@ -0,0 +1,168 @@
+import torch
+import random
+from masking_methods import MaskingProcessor
+
+class SamplingProcessor:
+    def __init__(self, tokenizer):
+        """
+        Initialize the SamplingProcessor.
+
+        Args:
+            tokenizer: BERT tokenizer instance
+        """
+        self.tokenizer = tokenizer
+
+    def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0):
+        """
+        Sample tokens for each mask in the sentence using the specified sampling technique.
+
+        Args:
+            mask_logits_dict (dict): Dictionary of mask positions and their logits/tokens
+            masked_sentence (str): Sentence with [MASK] tokens
+            sampling_technique (str): Sampling method to use
+            temperature (float): Temperature parameter for sampling
+
+        Returns:
+            str: Sentence with sampled tokens replacing masks
+        """
+        words = masked_sentence.split()
+
+        # Convert positions and logits to sorted list to process masks in order
+        mask_positions = sorted(mask_logits_dict.keys())
+
+        for mask_pos in mask_positions:
+            mask_data = mask_logits_dict[mask_pos]
+            mask_logits = torch.tensor(mask_data['logits'])
+            candidate_tokens = mask_data['tokens']
+
+            try:
+                if sampling_technique == "inverse_transform":
+                    probs = torch.softmax(mask_logits / temperature, dim=-1)
+                    cumulative_probs = torch.cumsum(probs, dim=-1)
+                    random_prob = random.random()
+                    sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()
+
+                elif sampling_technique == "exponential_minimum":
+                    probs = torch.softmax(mask_logits / temperature, dim=-1)
+                    exp_probs = torch.exp(-torch.log(probs))
+                    random_probs = torch.rand_like(exp_probs)
+                    sampled_index = torch.argmax(random_probs * exp_probs).item()
+
+                elif sampling_technique == "temperature":
+                    mask_logits = torch.clamp(mask_logits, min=-1e8, max=1e8)
+                    probs = torch.softmax(mask_logits / temperature, dim=-1)
+                    if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
+                        raise ValueError("The computed probabilities contain NaN or inf values.")
+                    probs = torch.max(probs, torch.tensor(1e-8))
+                    probs = probs / torch.sum(probs)
+                    probs = probs.flatten()
+                    if probs.size(0) > 1:
+                        sampled_index = torch.multinomial(probs, 1).item()
+                    else:
+                        sampled_index = torch.argmax(probs).item()
+
+                elif sampling_technique == 'greedy':
+                    sampled_index = torch.argmax(mask_logits).item()
+
+                else:
+                    raise ValueError(f"Unknown sampling technique: {sampling_technique}")
+
+                # Use the sampled index to get the corresponding token
+                sampled_token = candidate_tokens[sampled_index]
+                # Remove ## if it's a subword token
+                sampled_token = sampled_token.replace('##', '')
+                words[mask_pos] = sampled_token
+
+            except Exception as e:
+                print(f"Error sampling for position {mask_pos}: {str(e)}")
+                continue
+
+        return " ".join(words)
+
+    def process_masked_sentences(self, results_dict, sampling_technique="temperature", temperature=1.0):
+        """
+        Process all masked sentences in the results dictionary.
+
+        Args:
+            results_dict (dict): Dictionary containing masked sentences and their logits
+            sampling_technique (str): Sampling method to use
+            temperature (float): Temperature parameter for sampling
+
+        Returns:
+            dict: Dictionary containing original, masked, and sampled sentences
+        """
+        processed_results = {}
+
+        for original_sentence, data in results_dict.items():
+            masked_sentence = data["masked_sentence"]
+            mask_logits = data["mask_logits"]
+
+            sampled_sentence = self.sample_tokens(
+                mask_logits,
+                masked_sentence,
+                sampling_technique,
+                temperature
+            )
+
+            processed_results[original_sentence] = {
+                "masked_sentence": masked_sentence,
+                "sampled_sentence": sampled_sentence
+            }
+
+        return processed_results
+
+
+if __name__ == "__main__":
+    sentences = [
+        "The quick brown fox jumps over the lazy dog everyday.",
+        "A speedy brown fox jumps over a lazy dog.",
+        "A swift brown fox leaps over the lethargic dog."
+    ]
+    result_dict = {
+        'The quick brown fox jumps over the lazy dog everyday.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
+        'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
+        'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}
+    }
+
+    # First, mask the sentences
+    masking_processor = MaskingProcessor()
+    masking_results = masking_processor.process_sentences(sentences, result_dict)
+
+    # Then, sample replacements for the masks
+    sampling_processor = SamplingProcessor(masking_processor.tokenizer)
+
+    # Try different sampling techniques
+    sampling_techniques = ["temperature", "greedy", "inverse_transform", "exponential_minimum"]
+
+    for technique in sampling_techniques:
+        print(f"\nSampling using {technique}:")
+        sampled_results = sampling_processor.process_masked_sentences(
+            masking_results,
+            sampling_technique=technique,
+            temperature=1.0
+        )
+
+        '''
+        Shape of sampled_results:
+        {
+            "original_sentence_1": {
+                "masked_sentence": "sentence with [MASK] tokens",
+                "sampled_sentence": "sentence with sampled tokens",
+            },
+            "original_sentence_2": {
+                "masked_sentence": "sentence with [MASK] tokens",
+                "sampled_sentence": "sentence with sampled tokens"
+            },
+            # ... and so on for each input sentence
+        }
+        '''
+
+        for original_sentence, result in sampled_results.items():
+            print(f"Original: {original_sentence}")
+            print(f"Masked: {result['masked_sentence']}")
+            print(f"Sampled: {result['sampled_sentence']}")
+            print("---")
+
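For reference, sample_tokens in sampling_final_copy.py indexes into the whitespace-split sentence and expects, for each mask position, both the candidate logits and the matching token strings (a dictionary MaskingProcessor would normally build). A minimal direct call is sketched below; the import path, the mask position, and the logits/tokens are assumptions made purely for illustration, and the sketch assumes masking_methods is importable alongside the module.

from transformers import BertTokenizer
from sampling_final_copy import SamplingProcessor  # assumed import path; adjust to where the file lives

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
processor = SamplingProcessor(tokenizer)

# position 2 of the whitespace-split sentence is the [MASK]; logits and tokens are toy values
mask_logits_dict = {
    2: {"logits": [3.1, 2.4, 0.7], "tokens": ["fox", "dog", "cat"]},
}
masked = "The quick [MASK] jumps over the lazy dog."
print(processor.sample_tokens(mask_logits_dict, masked, sampling_technique="greedy"))
# expected: The quick fox jumps over the lazy dog.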