import random
import torch
import logging
import string
from transformers import BertTokenizer, BertForMaskedLM
from nltk.corpus import stopwords
import nltk
from tqdm import tqdm

# Set logging to WARNING for a cleaner terminal.
logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Ensure stopwords are downloaded
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')


def clean_word(word):
    """More robust cleaning for consistent matching."""
    # Strip a possessive 's before removing other punctuation
    if word.lower().endswith("'s"):
        word = word[:-2]
    return word.lower().strip().translate(str.maketrans('', '', string.punctuation))
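
# Illustrative behaviour of clean_word (a sketch of expected outputs, derived
# from the logic above rather than from recorded runs):
#   clean_word("Bruyne's") -> "bruyne"   (possessive suffix stripped, lowercased)
#   clean_word("City,")    -> "city"     (trailing punctuation removed)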


class MaskingProcessor:
    def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.stop_words = set(stopwords.words('english'))
        tqdm.write(f"[MaskingProcessor] Initialized on device: {self.device}")

    def is_stopword(self, word):
        """Check if a word is a stopword, handling punctuation and case."""
        return clean_word(word) in self.stop_words

    def verify_and_correct_ngram_positions(self, sentence, common_ngrams):
        """Verify that ngram positions match actual words in the sentence and correct them if needed."""
        words = sentence.split()
        corrected_ngrams = {}
        for ngram, positions in common_ngrams.items():
            corrected_positions = []
            ngram_words = ngram.split()
            # Convert ngram words to clean format for matching
            clean_ngram_words = [clean_word(word) for word in ngram_words]
            # Scan the sentence to find actual occurrences of the ngram
            for i in range(len(words) - len(ngram_words) + 1):
                is_match = True
                for j, ngram_word in enumerate(clean_ngram_words):
                    if clean_word(words[i + j]) != ngram_word:
                        is_match = False
                        break
                if is_match:
                    # Found a matching position, add it
                    corrected_positions.append((i, i + len(ngram_words) - 1))
            if corrected_positions:
                corrected_ngrams[ngram] = corrected_positions
            else:
                # Log the issue and perform a more flexible search
                print(f"Warning: Could not find exact match for '{ngram}' in the sentence.")
                print("Attempting flexible matching...")
                # Try a more flexible approach by matching on word prefixes
                for i in range(len(words)):
                    if clean_word(words[i]) == clean_ngram_words[0]:
                        # Found the first word of the ngram
                        if len(ngram_words) == 1 or (
                            i + len(ngram_words) <= len(words) and
                            all(clean_word(words[i + j]).startswith(clean_ngram_words[j])
                                for j in range(len(ngram_words)))
                        ):
                            corrected_positions.append((i, i + len(ngram_words) - 1))
                if corrected_positions:
                    print(f"Found flexible matches for '{ngram}': {corrected_positions}")
                    corrected_ngrams[ngram] = corrected_positions
                else:
                    # If there is still no match, keep the original positions as a fallback
                    print(f"No matches found for '{ngram}'. Keeping original positions.")
                    corrected_ngrams[ngram] = positions
        # Log changes
        if corrected_ngrams != common_ngrams:
            print(f"Original ngram positions: {common_ngrams}")
            print(f"Corrected ngram positions: {corrected_ngrams}")
        return corrected_ngrams

    def in_any_ngram(self, idx, ngram_positions):
        """Check if an original sentence index is part of any n-gram span."""
        return any(start <= idx <= end for start, end in ngram_positions)

    def create_fallback_mask(self, sentence, ngrams):
        """Create a fallback mask when the normal strategies fail."""
        try:
            words = sentence.split()
            if not words:
                return None
            # Find any non-stopword that isn't in an ngram
            ngram_positions = []
            for positions in ngrams.values():
                for start, end in positions:
                    ngram_positions.append((start, end))
            ngram_positions.sort()
            # Find the first eligible word
            for idx, word in enumerate(words):
                if not self.is_stopword(word) and not self.in_any_ngram(idx, ngram_positions):
                    masked_words = words.copy()
                    masked_words[idx] = self.tokenizer.mask_token
                    tqdm.write(f"[INFO] Fallback mask created at position {idx}: '{word}'")
                    return " ".join(masked_words), [idx]
            # If no eligible word was found, just mask the first non-stopword
            for idx, word in enumerate(words):
                if not self.is_stopword(word):
                    masked_words = words.copy()
                    masked_words[idx] = self.tokenizer.mask_token
                    tqdm.write(f"[INFO] Last-resort fallback mask created at position {idx}: '{word}'")
                    return " ".join(masked_words), [idx]
            # If still nothing, mask the first word
            masked_words = words.copy()
            masked_words[0] = self.tokenizer.mask_token
            return " ".join(masked_words), [0]
        except Exception as e:
            tqdm.write(f"[ERROR] Error creating fallback mask: {e}")
            return None

    def mask_sentence_random(self, sentence, common_ngrams):
        """Mask random non-stopwords that are not part of common ngrams, with controlled positioning."""
        common_ngrams = self.verify_and_correct_ngram_positions(sentence, common_ngrams)
        tqdm.write(f"[MaskingProcessor] Masking (random) sentence: {sentence}")
        original_words = sentence.split()
        # Handle trailing punctuation on the final word
        has_punctuation = False
        punctuation = ''
        if original_words and original_words[-1][-1] in ['.', ',', '!', '?', ';', ':', '"', "'"]:
            has_punctuation = True
            punctuation = original_words[-1][-1]
            original_words[-1] = original_words[-1][:-1]
            if not original_words[-1]:  # The word was just punctuation
                original_words.pop()
        # Get flattened ngram positions
        ngram_positions = []
        for positions in common_ngrams.values():
            for start, end in positions:
                ngram_positions.append((start, end))
        ngram_positions.sort()
        # Find all candidate indices (non-stopwords not in ngrams)
        candidate_indices = []
        for idx, word in enumerate(original_words):
            if not self.is_stopword(word) and not self.in_any_ngram(idx, ngram_positions):
                candidate_indices.append(idx)
        # Debug print candidate words
        print("Candidate words for masking:")
        for idx in candidate_indices:
            print(f"  Position {idx}: '{original_words[idx]}'")
        selected_indices = []
        if ngram_positions:
            # Before the first ngram: select 1 word
            before_first = [idx for idx in candidate_indices if idx < ngram_positions[0][0]]
            if before_first:
                selected_indices.extend(random.sample(before_first, min(1, len(before_first))))
            # Between ngrams: select 1-2 words per gap
            for i in range(len(ngram_positions) - 1):
                between = [idx for idx in candidate_indices
                           if ngram_positions[i][1] < idx < ngram_positions[i + 1][0]]
                if between:
                    selected_indices.extend(random.sample(between, min(2, len(between))))
            # After the last ngram: select 1 word
            after_last = [idx for idx in candidate_indices if idx > ngram_positions[-1][1]]
            if after_last:
                selected_indices.extend(random.sample(after_last, min(1, len(after_last))))
        else:
            # If there are no ngrams, pick up to 6 random candidates
            if candidate_indices:
                selected_indices = random.sample(candidate_indices, min(6, len(candidate_indices)))
        masked_words = original_words.copy()
        for idx in selected_indices:
            masked_words[idx] = self.tokenizer.mask_token
        if has_punctuation:
            masked_words.append(punctuation)
        # Debug prints
        print("Original sentence:", sentence)
        print("Common ngrams:", common_ngrams)
        print("Common ngram positions:", ngram_positions)
        print("Candidate indices for masking:", candidate_indices)
        print("Selected for masking:", selected_indices)
        print("Masked sentence:", " ".join(masked_words))
        return " ".join(masked_words), selected_indices

    def mask_sentence_pseudorandom(self, sentence, common_ngrams):
        """Mask non-stopwords based on their position relative to ngrams, using a fixed random seed."""
        common_ngrams = self.verify_and_correct_ngram_positions(sentence, common_ngrams)
        tqdm.write(f"[MaskingProcessor] Masking (pseudorandom) sentence: {sentence}")
        random.seed(3)  # Fixed seed for pseudorandom (reproducible) behavior
        original_words = sentence.split()
        # Handle trailing punctuation on the final word
        has_punctuation = False
        punctuation = ''
        if original_words and original_words[-1][-1] in ['.', ',', '!', '?', ';', ':', '"', "'"]:
            has_punctuation = True
            punctuation = original_words[-1][-1]
            original_words[-1] = original_words[-1][:-1]
            if not original_words[-1]:  # The word was just punctuation
                original_words.pop()
        # Get flattened ngram positions
        ngram_positions = []
        for positions in common_ngrams.values():
            for start, end in positions:
                ngram_positions.append((start, end))
        ngram_positions.sort()
        # Find all candidate indices (non-stopwords not in ngrams)
        candidate_indices = []
        for idx, word in enumerate(original_words):
            if not self.is_stopword(word) and not self.in_any_ngram(idx, ngram_positions):
                candidate_indices.append(idx)
        # Debug print candidate words
        print("Candidate words for masking:")
        for idx in candidate_indices:
            print(f"  Position {idx}: '{original_words[idx]}'")
        # PSEUDORANDOM-SPECIFIC LOGIC: same selection scheme as random masking, but with the fixed seed above
        selected_indices = []
        if ngram_positions:
            # Before the first ngram: select 1 word
            before_first = [idx for idx in candidate_indices if idx < ngram_positions[0][0]]
            if before_first:
                selected_indices.extend(random.sample(before_first, min(1, len(before_first))))
            # Between ngrams: select 1-2 words per gap
            for i in range(len(ngram_positions) - 1):
                between = [idx for idx in candidate_indices
                           if ngram_positions[i][1] < idx < ngram_positions[i + 1][0]]
                if between:
                    selected_indices.extend(random.sample(between, min(2, len(between))))
            # After the last ngram: select 1 word
            after_last = [idx for idx in candidate_indices if idx > ngram_positions[-1][1]]
            if after_last:
                selected_indices.extend(random.sample(after_last, min(1, len(after_last))))
        else:
            # If there are no ngrams, pick up to 6 random candidates
            if candidate_indices:
                selected_indices = random.sample(candidate_indices, min(6, len(candidate_indices)))
        masked_words = original_words.copy()
        for idx in selected_indices:
            masked_words[idx] = self.tokenizer.mask_token
        if has_punctuation:
            masked_words.append(punctuation)
        # Debug prints
        print("Original sentence:", sentence)
        print("Common ngrams:", common_ngrams)
        print("Common ngram positions:", ngram_positions)
        print("Candidate indices for masking:", candidate_indices)
        print("Selected for masking:", selected_indices)
        print("Masked sentence:", " ".join(masked_words))
        return " ".join(masked_words), selected_indices

    def mask_sentence_entropy(self, sentence, common_ngrams):
        """Mask the highest-entropy words that are not part of common ngrams."""
        common_ngrams = self.verify_and_correct_ngram_positions(sentence, common_ngrams)
        tqdm.write(f"[MaskingProcessor] Masking (entropy) sentence: {sentence}")
        original_words = sentence.split()
        # Handle trailing punctuation on the final word
        has_punctuation = False
        punctuation = ''
        if original_words and original_words[-1][-1] in ['.', ',', '!', '?', ';', ':', '"', "'"]:
            has_punctuation = True
            punctuation = original_words[-1][-1]
            original_words[-1] = original_words[-1][:-1]
            if not original_words[-1]:  # The word was just punctuation
                original_words.pop()
        # Get flattened ngram positions
        ngram_positions = []
        for positions in common_ngrams.values():
            for start, end in positions:
                ngram_positions.append((start, end))
        ngram_positions.sort()
        # Find all candidate indices (non-stopwords not in ngrams)
        candidate_indices = []
        for idx, word in enumerate(original_words):
            if not self.is_stopword(word) and not self.in_any_ngram(idx, ngram_positions):
                candidate_indices.append(idx)
        # Debug print candidate words
        print("Candidate words for masking:")
        for idx in candidate_indices:
            print(f"  Position {idx}: '{original_words[idx]}'")
        # ENTROPY-SPECIFIC LOGIC: rank candidates by the entropy of the model's prediction at each position
        selected_indices = []
        if candidate_indices:
            if ngram_positions:
                # Group candidates by their position relative to the ngrams
                before_first = []
                between_ngrams = {}
                after_last = []
                for idx in candidate_indices:
                    if idx < ngram_positions[0][0]:
                        before_first.append(idx)
                    elif idx > ngram_positions[-1][1]:
                        after_last.append(idx)
                    else:
                        # Find which gap between ngrams this index belongs to
                        for i in range(len(ngram_positions) - 1):
                            if ngram_positions[i][1] < idx < ngram_positions[i + 1][0]:
                                between_ngrams.setdefault(i, []).append(idx)
                # Before the first ngram: select the highest-entropy word
                if before_first:
                    entropies = [(idx, self.calculate_word_entropy(sentence, idx)) for idx in before_first]
                    entropies.sort(key=lambda x: x[1], reverse=True)
                    selected_indices.extend([idx for idx, _ in entropies[:1]])
                # In each gap between ngrams: select the 1-2 highest-entropy words
                for indices in between_ngrams.values():
                    if indices:
                        entropies = [(idx, self.calculate_word_entropy(sentence, idx)) for idx in indices]
                        entropies.sort(key=lambda x: x[1], reverse=True)
                        selected_indices.extend([idx for idx, _ in entropies[:2]])
                # After the last ngram: select the highest-entropy word
                if after_last:
                    entropies = [(idx, self.calculate_word_entropy(sentence, idx)) for idx in after_last]
                    entropies.sort(key=lambda x: x[1], reverse=True)
                    selected_indices.extend([idx for idx, _ in entropies[:1]])
            else:
                # If there are no ngrams, take the up-to-6 highest-entropy candidates
                entropies = [(idx, self.calculate_word_entropy(sentence, idx)) for idx in candidate_indices]
                entropies.sort(key=lambda x: x[1], reverse=True)
                selected_indices = [idx for idx, _ in entropies[:6]]
        masked_words = original_words.copy()
        for idx in selected_indices:
            masked_words[idx] = self.tokenizer.mask_token
        if has_punctuation:
            masked_words.append(punctuation)
        # Debug prints
        print("Original sentence:", sentence)
        print("Common ngrams:", common_ngrams)
        print("Common ngram positions:", ngram_positions)
        print("Candidate indices for masking:", candidate_indices)
        print("Selected for masking:", selected_indices)
        print("Masked sentence:", " ".join(masked_words))
        return " ".join(masked_words), selected_indices

    def calculate_mask_logits(self, original_sentence, original_mask_indices):
        """Calculate logits for the masked positions, one position at a time."""
        logger.info(f"Calculating mask logits for sentence: {original_sentence}")
        words = original_sentence.split()
        mask_logits = {}
        for idx in original_mask_indices:
            masked_words = words.copy()
            masked_words[idx] = self.tokenizer.mask_token
            masked_sentence = " ".join(masked_words)
            input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"].to(self.device)
            mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits
            mask_logits_tensor = logits[0, mask_token_index, :]
            top_mask_logits, top_mask_indices = torch.topk(mask_logits_tensor, 100, dim=-1)
            top_tokens = []
            top_logits = []
            seen_words = set()
            for token_id, logit in zip(top_mask_indices[0], top_mask_logits[0]):
                token = self.tokenizer.convert_ids_to_tokens(token_id.item())
                if token.startswith('##'):
                    continue
                word = self.tokenizer.convert_tokens_to_string([token]).strip()
                if word and word not in seen_words:
                    seen_words.add(word)
                    top_tokens.append(word)
                    top_logits.append(logit.item())
                if len(top_tokens) == 50:
                    break
            mask_logits[idx] = {
                "tokens": top_tokens,
                "logits": top_logits
            }
        logger.info("Completed calculating mask logits.")
        return mask_logits
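
    # Sketch of the mapping returned by calculate_mask_logits (inferred from the
    # code above; field names are those used in the dict literal):
    #   { word_index: {"tokens": [up to 50 whole-word candidates, subword pieces skipped],
    #                  "logits": [their raw logit scores, in the same order]} }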

    def calculate_word_entropy(self, sentence, word_position):
        """Calculate the entropy of the model's prediction for the word at a specific position."""
        logger.info(f"Calculating word entropy for position {word_position} in sentence: {sentence}")
        words = sentence.split()
        masked_words = words.copy()
        masked_words[word_position] = self.tokenizer.mask_token
        masked_sentence = " ".join(masked_words)
        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"].to(self.device)
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits
        probs = torch.nn.functional.softmax(logits[0, mask_token_index], dim=-1)
        entropy = -torch.sum(probs * torch.log(probs + 1e-9))
        logger.info(f"Computed entropy: {entropy.item()}")
        return entropy.item()
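
    # Interpretation note: a higher entropy means the model's distribution over
    # the masked position is flatter, i.e. the original word is harder to predict
    # from context, which is why mask_sentence_entropy prefers high-entropy slots.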

    def process_sentences(self, sentences_list, common_grams, method="random"):
        """Process multiple sentences with the specified masking method."""
        tqdm.write(f"[MaskingProcessor] Processing sentences using method: {method}")
        results = {}
        for sentence in tqdm(sentences_list, desc="Masking Sentences"):
            try:
                ngrams = common_grams.get(sentence, {})
                if method == "random":
                    masked_sentence, original_mask_indices = self.mask_sentence_random(sentence, ngrams)
                elif method == "pseudorandom":
                    masked_sentence, original_mask_indices = self.mask_sentence_pseudorandom(sentence, ngrams)
                else:  # entropy
                    masked_sentence, original_mask_indices = self.mask_sentence_entropy(sentence, ngrams)
                # If no masks were applied, fall back to a masked sentence with at least one mask
                if not original_mask_indices:
                    tqdm.write(f"[WARNING] No mask indices found for sentence with method {method}: {sentence[:50]}...")
                    fallback_result = self.create_fallback_mask(sentence, ngrams)
                    if fallback_result:
                        masked_sentence, original_mask_indices = fallback_result
                        tqdm.write("[INFO] Created fallback mask for sentence")
                    else:
                        tqdm.write("[WARNING] Could not create fallback mask, skipping sentence")
                        continue
                logits = self.calculate_mask_logits(sentence, original_mask_indices)
                results[sentence] = {
                    "masked_sentence": masked_sentence,
                    "mask_logits": logits
                }
                logger.info(f"Processed sentence: {sentence}")
            except Exception as e:
                tqdm.write(f"[ERROR] Failed to process sentence with method {method}: {e}")
                tqdm.write(f"Sentence: {sentence[:100]}...")
                import traceback
                tqdm.write(traceback.format_exc())
        tqdm.write("[MaskingProcessor] Completed processing sentences.")
        return results
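
    # Sketch of the structure returned by process_sentences (inferred from the
    # code above):
    #   { original_sentence: {"masked_sentence": str,
    #                         "mask_logits": output of calculate_mask_logits} }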

    @staticmethod
    def identify_common_ngrams(sentences, entities):
        """Identify entity n-grams in each sentence, handling possessive forms."""
        common_grams = {}
        # Pre-process entities to include possessive variants
        processed_entities = []
        for entity in entities:
            processed_entities.append(entity)
            # Add the possessive form if it is not already covered
            if not entity.endswith("'s") and not entity.endswith("s"):
                processed_entities.append(f"{entity}'s")
        for sentence in sentences:
            words = sentence.split()
            common_grams[sentence] = {}
            # Look for each entity in the sentence
            for entity in processed_entities:
                entity_words = entity.split()
                entity_len = len(entity_words)
                # Convert entity words to clean form for matching
                clean_entity_words = [clean_word(word) for word in entity_words]
                # Find all occurrences
                for i in range(len(words) - entity_len + 1):
                    is_match = True
                    for j, entity_word in enumerate(clean_entity_words):
                        if clean_word(words[i + j]) != entity_word:
                            is_match = False
                            break
                    if is_match:
                        # Use the canonical form from the entity list for consistency
                        base_entity = entity
                        if entity.endswith("'s") and any(e == entity[:-2] for e in processed_entities):
                            base_entity = entity[:-2]
                        if base_entity not in common_grams[sentence]:
                            common_grams[sentence][base_entity] = []
                        common_grams[sentence][base_entity].append((i, i + entity_len - 1))
        return common_grams
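

# Illustrative output of identify_common_ngrams, assuming the example sentence
# and entities used in the __main__ block below:
#   {"Kevin De Bruyne scored for Manchester City ...": {
#       "Kevin De Bruyne": [(0, 2)],
#       "Manchester City": [(5, 6)],
#       "Premier League": [(12, 13)]}}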

if __name__ == "__main__":
    # Example test sentence and the entities (common n-grams) to preserve
    test_sentence = "Kevin De Bruyne scored for Manchester City as they won the 2019-20 Premier League title."
    entities = ["Kevin De Bruyne", "Manchester City", "Premier League"]

    # Identify common n-grams
    common_grams = MaskingProcessor.identify_common_ngrams([test_sentence], entities)

    # Print detected n-grams
    print(f"Detected common n-grams: {common_grams[test_sentence]}")

    # Initialize the processor
    processor = MaskingProcessor(
        BertTokenizer.from_pretrained("bert-large-cased-whole-word-masking"),
        BertForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
    )

    # Test all three masking methods
    print("\nTesting Random Masking:")
    masked_random, indices_random = processor.mask_sentence_random(test_sentence, common_grams[test_sentence])

    print("\nTesting Pseudorandom Masking:")
    masked_pseudorandom, indices_pseudorandom = processor.mask_sentence_pseudorandom(test_sentence, common_grams[test_sentence])

    print("\nTesting Entropy Masking:")
    masked_entropy, indices_entropy = processor.mask_sentence_entropy(test_sentence, common_grams[test_sentence])

    # Print results
    print("\nResults:")
    print(f"Original: {test_sentence}")
    print(f"Random Masked: {masked_random}")
    print(f"Pseudorandom Masked: {masked_pseudorandom}")
    print(f"Entropy Masked: {masked_entropy}")