import re
import textwrap
from collections import defaultdict

import plotly.graph_objects as go
def generate_subplot1(paraphrased_sentence, masked_sentences, highlight_info, common_grams, strategies=None):
    """
    Generates a subplot visualizing paraphrased and masked sentences in a tree structure.
    Highlights common words with specific colors and applies Longest Common Subsequence (LCS) numbering.

    Args:
        paraphrased_sentence (str): The paraphrased sentence to be visualized.
        masked_sentences (list of str): A list of masked sentences to be visualized.
        highlight_info (list of tuples): A list of (word, color) tuples used for highlighting.
        common_grams (list of tuples): A list of (index, phrase) tuples used for LCS numbering.
        strategies (list of str, optional): Masking strategies used for each masked sentence,
            shown as edge labels.

    Returns:
        plotly.graph_objects.Figure: A Plotly figure representing the tree structure with
        highlighted words and labeled edges.
    """
    # Combine nodes into one list with appropriate level labels
    if isinstance(masked_sentences, str):
        masked_sentences = [masked_sentences]
    nodes = [paraphrased_sentence] + masked_sentences
    nodes[0] += ' L0'  # The paraphrased sentence is level 0
    if len(nodes) < 2:
        print("[ERROR] Insufficient nodes for visualization")
        return go.Figure()
    for i in range(1, len(nodes)):
        nodes[i] += ' L1'  # Masked sentences are level 1

    def apply_lcs_numbering(sentence, common_grams):
        """
        Applies LCS numbering to the sentence based on the common_grams.

        Args:
            sentence (str): The sentence to which the LCS numbering should be applied.
            common_grams (list of tuples): A list of (index, phrase) tuples; each phrase is
                prefixed with its LCS number.

        Returns:
            str: The sentence with LCS numbering applied.
        """
        for idx, lcs in common_grams:
            # re.escape guards against regex metacharacters inside the phrase
            sentence = re.sub(rf"\b{re.escape(lcs)}\b", f"({idx}){lcs}", sentence)
        return sentence
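    # Example (using the demo inputs at the bottom of this module):
    # apply_lcs_numbering("The quick brown fox jumps over the lazy dog.", [(1, "quick brown fox")])
    # returns "The (1)quick brown fox jumps over the lazy dog."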
    # Apply LCS numbering
    nodes = [apply_lcs_numbering(node, common_grams) for node in nodes]

    def highlight_words(sentence, color_map):
        """
        Wraps words from the color_map in {{...}} markers so they can be colored later.

        Args:
            sentence (str): The sentence in which the words will be marked.
            color_map (dict): A dictionary mapping words to their colors.

        Returns:
            str: The sentence with marked words.
        """
        for word, color in color_map.items():
            sentence = re.sub(rf"\b{re.escape(word)}\b", f"{{{{{word}}}}}", sentence, flags=re.IGNORECASE)
        return sentence
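    # Example: highlight_words("the lazy dog", {"lazy": "purple"}) returns "the {{lazy}} dog";
    # color_highlighted_words (defined below) later converts the {{...}} markers into colored
    # <span> tags.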
    # Clean and wrap nodes, and mark highlighted words globally
    cleaned_nodes = [re.sub(r'\sL[0-9]$', '', node) for node in nodes]
    global_color_map = dict(highlight_info)
    highlighted_nodes = [highlight_words(node, global_color_map) for node in cleaned_nodes]
    wrapped_nodes = ['<br>'.join(textwrap.wrap(node, width=55)) for node in highlighted_nodes]

    def get_levels_and_edges(nodes, strategies=None):
        """
        Determines tree levels and creates edges dynamically.

        Args:
            nodes (list of str): The nodes representing the sentences.
            strategies (list of str, optional): The strategies used for each edge.

        Returns:
            tuple: A pair (levels, edges) where:
                - levels (dict): Maps each node index to its level.
                - edges (list of tuple): Edges as (parent index, child index) pairs.
        """
        levels = {}
        edges = []
        for i, node in enumerate(nodes):
            level = int(node.split()[-1][1])
            levels[i] = level

        # Add edges from the L0 root to all L1 nodes
        root_node = next((i for i, level in levels.items() if level == 0), 0)
        for i, level in levels.items():
            if level == 1:
                edges.append((root_node, i))
        return levels, edges
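    # Example: with one paraphrased sentence and two masked sentences, the helper returns
    # levels == {0: 0, 1: 1, 2: 1} and edges == [(0, 1), (0, 2)].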
    # Get levels and dynamic edges
    levels, edges = get_levels_and_edges(nodes, strategies)
    max_level = max(levels.values(), default=0)

    # Calculate positions
    positions = {}
    level_heights = defaultdict(int)
    for node, level in levels.items():
        level_heights[level] += 1

    y_offsets = {level: -(height - 1) / 2 for level, height in level_heights.items()}
    x_gap = 2
    l1_y_gap = 10

    for node, level in levels.items():
        positions[node] = (-level * x_gap, y_offsets[level] * l1_y_gap)
        y_offsets[level] += 1  # Advance the offset so nodes on the same level do not overlap
    def color_highlighted_words(node, color_map):
        """
        Colors the marked words in the node text.

        Args:
            node (str): The node text containing {{...}} markers.
            color_map (dict): A dictionary mapping words to their colors.

        Returns:
            str: The node text with marked words wrapped in colored <span> tags.
        """
        parts = re.split(r'(\{\{.*?\}\})', node)
        colored_parts = []
        for part in parts:
            match = re.match(r'\{\{(.*?)\}\}', part)
            if match:
                word = match.group(1)
                color = color_map.get(word, 'black')
                colored_parts.append(f"<span style='color: {color};'>{word}</span>")
            else:
                colored_parts.append(part)
        return ''.join(colored_parts)
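    # Example: color_highlighted_words("the {{lazy}} dog", {"lazy": "purple"}) returns
    # "the <span style='color: purple;'>lazy</span> dog".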
    # Default text for each edge, used when no strategies are provided
    default_edge_texts = [
        "Highest Entropy Masking", "Pseudo-random Masking", "Random Masking",
        "Greedy Sampling", "Temperature Sampling", "Exponential Minimum Sampling",
        "Inverse Transform Sampling", "Greedy Sampling", "Temperature Sampling",
        "Exponential Minimum Sampling", "Inverse Transform Sampling", "Greedy Sampling",
        "Temperature Sampling", "Exponential Minimum Sampling", "Inverse Transform Sampling"
    ]
    # Create figure
    fig1 = go.Figure()

    # Add nodes to the figure
    for i, node in enumerate(wrapped_nodes):
        colored_node = color_highlighted_words(node, global_color_map)
        x, y = positions[i]
        fig1.add_trace(go.Scatter(
            x=[-x],  # Reflect the x coordinate so the tree grows to the right
            y=[y],
            mode='markers',
            marker=dict(size=20, color='blue', line=dict(color='black', width=2)),
            hoverinfo='none'
        ))
        fig1.add_annotation(
            x=-x,  # Reflect the x coordinate
            y=y,
            text=colored_node,
            showarrow=False,
            xshift=15,
            align="center",
            font=dict(size=12),
            bordercolor='black',
            borderwidth=2,
            borderpad=4,
            bgcolor='white',
            width=400,
            height=100
        )
    # Add edges and a text label above each edge
    for i, edge in enumerate(edges):
        x0, y0 = positions[edge[0]]
        x1, y1 = positions[edge[1]]

        # Use the strategy if available, otherwise fall back to the default edge text
        if strategies and i < len(strategies):
            edge_text = strategies[i]
        else:
            edge_text = default_edge_texts[i % len(default_edge_texts)]

        fig1.add_trace(go.Scatter(
            x=[-x0, -x1],  # Reflect the x coordinates
            y=[y0, y1],
            mode='lines',
            line=dict(color='black', width=1)
        ))

        # Place the label slightly above the midpoint of the edge
        mid_x = (-x0 + -x1) / 2
        mid_y = (y0 + y1) / 2
        text_y_position = mid_y + 0.8  # Increase this value to shift the text further upwards
        fig1.add_annotation(
            x=mid_x,
            y=text_y_position,
            text=edge_text,  # Text specific to this edge
            showarrow=False,
            font=dict(size=12),
            align="center"
        )
    fig1.update_layout(
        showlegend=False,
        margin=dict(t=50, b=50, l=50, r=50),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=800 + max_level * 200,    # Widen the figure for deeper trees
        height=300 + len(nodes) * 100,  # Grow the figure with the number of nodes
        plot_bgcolor='rgba(240,240,240,0.2)',
        paper_bgcolor='white'
    )

    return fig1
def generate_subplot2(masked_sentences, sampled_sentences, highlight_info, common_grams):
    """
    Generates a subplot visualizing multiple masked sentences and their sampled variants in a tree structure.
    Each masked sentence has multiple sampled sentences derived from it using different sampling techniques.

    Args:
        masked_sentences (list of str): A list of masked sentences to be visualized as root nodes.
        sampled_sentences (list of str): A list of sampled sentences derived from the masked sentences.
        highlight_info (list of tuples): A list of (word, color) tuples used for highlighting.
        common_grams (list of tuples): A list of (index, phrase) tuples used for LCS numbering.

    Returns:
        plotly.graph_objects.Figure: A Plotly figure representing the tree structure with
        highlighted words and labeled edges.
    """
    # Define sampling techniques
    sampling_techniques = [
        "Greedy Sampling",
        "Temperature Sampling",
        "Exponential Minimum Sampling",
        "Inverse Transform Sampling"
    ]

    # Number of root nodes and of sampled variants expected per root
    num_masked = len(masked_sentences)
    num_sampled_per_masked = len(sampling_techniques)
    # Combine all sentences into one nodes list with appropriate level labels
    nodes = []

    # Level 0: masked sentences (root nodes)
    nodes.extend([s + ' L0' for s in masked_sentences])

    # Level 1: sampled sentences (branch nodes)
    # For each masked sentence, we expect one sample from each technique
    sampled_nodes = []

    # Validate that we have the expected number of sampled sentences
    expected_sampled_count = num_masked * num_sampled_per_masked
    if len(sampled_sentences) < expected_sampled_count:
        # If too few samples are provided, pad with placeholder sentences.
        # Copy the list first so the caller's list is not mutated.
        print(f"Warning: Expected {expected_sampled_count} sampled sentences, but got {len(sampled_sentences)}")
        sampled_sentences = list(sampled_sentences)
        while len(sampled_sentences) < expected_sampled_count:
            sampled_sentences.append(f"Placeholder sampled sentence {len(sampled_sentences) + 1}")

    # Add all sampled sentences with level information
    for s in sampled_sentences[:expected_sampled_count]:
        sampled_nodes.append(s + ' L1')
    nodes.extend(sampled_nodes)
    def apply_lcs_numbering(sentence, common_grams):
        """
        Applies LCS numbering to the sentence based on the common_grams.
        """
        for idx, lcs in common_grams:
            sentence = re.sub(rf"\b{re.escape(lcs)}\b", f"({idx}){lcs}", sentence)
        return sentence

    # Apply LCS numbering
    nodes = [apply_lcs_numbering(node, common_grams) for node in nodes]

    def highlight_words(sentence, color_map):
        """
        Wraps words from the color_map in {{...}} markers so they can be colored later.
        """
        for word, color in color_map.items():
            sentence = re.sub(rf"\b{re.escape(word)}\b", f"{{{{{word}}}}}", sentence, flags=re.IGNORECASE)
        return sentence
    # Helper function to color the marked words
    def color_highlighted_words(node, color_map):
        """
        Colors the marked words in the node text.
        """
        parts = re.split(r'(\{\{.*?\}\})', node)
        colored_parts = []
        for part in parts:
            match = re.match(r'\{\{(.*?)\}\}', part)
            if match:
                word = match.group(1)
                color = color_map.get(word, 'black')
                colored_parts.append(f"<span style='color: {color};'>{word}</span>")
            else:
                colored_parts.append(part)
        return ''.join(colored_parts)

    # Clean nodes, mark highlighted words, and wrap text
    cleaned_nodes = [re.sub(r'\sL[0-9]$', '', node) for node in nodes]
    global_color_map = dict(highlight_info)
    highlighted_nodes = [highlight_words(node, global_color_map) for node in cleaned_nodes]
    wrapped_nodes = ['<br>'.join(textwrap.wrap(node, width=80)) for node in highlighted_nodes]
    # Generate edges based on the tree structure
    def get_levels_and_edges(nodes):
        levels = {}
        edges = []

        # Extract level info from the node labels
        for i, node in enumerate(nodes):
            level = int(node.split()[-1][1])
            levels[i] = level

        # Create edges from each masked sentence to its sampled variants
        for masked_idx in range(num_masked):
            for technique_idx in range(num_sampled_per_masked):
                sampled_idx = num_masked + (masked_idx * num_sampled_per_masked) + technique_idx
                if sampled_idx < len(nodes):
                    edges.append((masked_idx, sampled_idx))
        return levels, edges
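    # Example: with 2 masked sentences and 4 sampling techniques, the sampled nodes occupy
    # indices 2..9, so edges == [(0, 2), (0, 3), (0, 4), (0, 5), (1, 6), (1, 7), (1, 8), (1, 9)].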
    levels, edges = get_levels_and_edges(nodes)

    # Calculate positions with improved spacing
    positions = {}

    # Horizontal and vertical spacing for the root nodes (masked sentences)
    root_x_spacing = 0    # All root nodes at x = 0
    root_y_spacing = 8.0  # Vertical spacing between root nodes

    # X position shared by all sampled nodes
    sampled_x = 3

    # Y positions for the root nodes (masked sentences)
    root_y_start = -(num_masked - 1) * root_y_spacing / 2
    for i in range(num_masked):
        positions[i] = (root_x_spacing, root_y_start + i * root_y_spacing)

    # Y positions for the sampled nodes
    for masked_idx in range(num_masked):
        root_y = positions[masked_idx][1]  # Y position of the parent masked sentence

        # Vertical spacing between children of the same root
        children_y_spacing = 1.5
        children_y_start = root_y - (num_sampled_per_masked - 1) * children_y_spacing / 2

        # Position each child
        for technique_idx in range(num_sampled_per_masked):
            child_idx = num_masked + (masked_idx * num_sampled_per_masked) + technique_idx
            child_y = children_y_start + technique_idx * children_y_spacing
            positions[child_idx] = (sampled_x, child_y)
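    # Worked example with num_masked = 2: the roots sit at y = -4.0 and y = 4.0, and the
    # four children of the first root sit at y = -6.25, -4.75, -3.25 and -1.75.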
    # Create figure
    fig2 = go.Figure()

    # Add nodes
    for i, node in enumerate(wrapped_nodes):
        x, y = positions[i]

        # Define the node color based on its level
        node_color = 'blue' if levels[i] == 0 else 'green'

        # Add the node marker
        fig2.add_trace(go.Scatter(
            x=[x],
            y=[y],
            mode='markers',
            marker=dict(size=20, color=node_color, line=dict(color='black', width=2)),
            hoverinfo='none'
        ))

        # Add the node label with highlighting
        colored_node = color_highlighted_words(node, global_color_map)
        fig2.add_annotation(
            x=x,
            y=y,
            text=colored_node,
            showarrow=False,
            xshift=15,
            align="left",
            font=dict(size=12),
            bordercolor='black',
            borderwidth=2,
            borderpad=4,
            bgcolor='white',
            width=450,
            height=100
        )
    # Add edges with labels
    for i, (src, dst) in enumerate(edges):
        x0, y0 = positions[src]
        x1, y1 = positions[dst]

        # Draw the edge
        fig2.add_trace(go.Scatter(
            x=[x0, x1],
            y=[y0, y1],
            mode='lines',
            line=dict(color='black', width=1)
        ))

        # Label the edge with its sampling technique: count how many earlier edges share
        # the same parent to determine which technique this edge represents
        parent_idx = src
        technique_count = sum(1 for k, (s, _) in enumerate(edges) if s == parent_idx and k < i)
        technique_label = sampling_techniques[technique_count % len(sampling_techniques)]

        # Place the label slightly above the midpoint of the edge to avoid overlap
        mid_x = (x0 + x1) / 2
        mid_y = (y0 + y1) / 2
        label_offset = 0.1
        fig2.add_annotation(
            x=mid_x,
            y=mid_y + label_offset,
            text=technique_label,
            showarrow=False,
            font=dict(size=8),
            align="center"
        )
    # Update layout
    fig2.update_layout(
        showlegend=False,
        margin=dict(t=20, b=20, l=20, r=20),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=1200,   # Wide enough for the two columns of nodes
        height=2000,  # Tall enough for the stacked root and sampled nodes
        plot_bgcolor='rgba(240,240,240,0.2)',
        paper_bgcolor='white'
    )

    return fig2
if __name__ == "__main__":
    paraphrased_sentence = "The quick brown fox jumps over the lazy dog."
    masked_sentences = [
        "A fast brown fox leaps over the lazy dog.",
        "A quick brown fox hops over a lazy dog."
    ]
    highlight_info = [
        ("quick", "red"),
        ("brown", "green"),
        ("fox", "blue"),
        ("lazy", "purple")
    ]
    common_grams = [
        (1, "quick brown fox"),
        (2, "lazy dog")
    ]

    fig1 = generate_subplot1(paraphrased_sentence, masked_sentences, highlight_info, common_grams)
    fig1.show()

    # Only one sampled sentence is provided here, so generate_subplot2 pads the remaining
    # slots with placeholder sentences and prints a warning.
    sampled_sentences = ["A fast brown fox jumps over a lazy dog."]
    fig2 = generate_subplot2(masked_sentences, sampled_sentences, highlight_info, common_grams)
    fig2.show()
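    # Minimal export sketch (assumed output paths): instead of showing the figures
    # interactively, they could be written to standalone HTML files via Plotly's
    # Figure.write_html, e.g.:
    # fig1.write_html("subplot1.html")
    # fig2.write_html("subplot2.html")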