"""
Gradio application to load and run Sparse Autoencoders (SAEs) on hidden states
from the PlantRNA-FM model, with support for averaging activations across
multiple sequences.
"""
# NOTE(review): `spaces` is deliberately kept as the first import, ahead of
# torch — presumably required by the Hugging Face ZeroGPU runtime; confirm
# before reordering these imports.
import spaces
import gradio as gr
import torch
import plotly.express as px
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download
import os
import pathlib
import pandas as pd
import numpy as np

from dictionary_learning import AutoEncoder, utils

# --- Constants and Global Setup ---
SAE_REPO = "afg1/plantrna-fm-sparse-autoencoders"
BASE_MODEL_NAME = "yangheng/PlantRNA-FM"
# Inference is pinned to CPU; automatic CUDA selection
# ("cuda" if torch.cuda.is_available() else "cpu") is currently disabled.
DEVICE = "cpu"
BATCH_SIZE = 16  # Batch size for memory-efficient inference

print(f"Using device: {DEVICE}")

# --- Global Caches for Models ---
# Caching avoids reloading the models on every function call, which is slow.
BASE_MODEL_CACHE = {}
SAE_MODEL_CACHE = {}


# --- Model Loading Functions ---
def get_base_model():
    """Load (once) and return the base PlantRNA-FM model and its tokenizer.

    The first call downloads/loads both objects, moves the model to DEVICE,
    puts it in eval mode and stores them in BASE_MODEL_CACHE; later calls
    return the cached objects.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    if "model" not in BASE_MODEL_CACHE:
        print("Loading base model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
        model = AutoModel.from_pretrained(BASE_MODEL_NAME).to(DEVICE)
        model.eval()
        BASE_MODEL_CACHE["model"] = model
        BASE_MODEL_CACHE["tokenizer"] = tokenizer
        print("Base model and tokenizer loaded.")
    return BASE_MODEL_CACHE["model"], BASE_MODEL_CACHE["tokenizer"]


def get_sae_model(layer, k_sparsity):
    """Download, load, and cache a specific Sparse Autoencoder.

    Args:
        layer: Layer number used to build the path inside SAE_REPO.
        k_sparsity: Sparsity value (k) used to build the path inside SAE_REPO.

    Returns:
        The autoencoder returned by ``utils.load_dictionary`` for that
        layer / k combination (cached in SAE_MODEL_CACHE).

    Raises:
        gr.Error: If downloading or loading the SAE fails for any reason.
    """
    sae_key = f"layer_{layer}_k_{k_sparsity}"
    if sae_key not in SAE_MODEL_CACHE:
        print(f"Loading SAE for layer {layer}, k={k_sparsity}...")
        try:
            model_path = hf_hub_download(
                repo_id=SAE_REPO,
                filename=f"layer_{layer}/k_{k_sparsity}/trainer_0/ae.pt",
            )
            # The returned path is not used directly; the file is fetched so
            # it sits next to ae.pt — presumably load_dictionary reads it from
            # the directory it is given (verify against dictionary_learning).
            hf_hub_download(
                repo_id=SAE_REPO,
                filename=f"layer_{layer}/k_{k_sparsity}/trainer_0/config.json",
            )
            ae, _config = utils.load_dictionary(
                pathlib.Path(model_path).parents[0], device=DEVICE
            )
            SAE_MODEL_CACHE[sae_key] = ae
            print(f"SAE for {sae_key} loaded.")
        except Exception as e:
            # Top-level boundary: surface any failure to the UI, keeping the
            # original exception chained for the server logs.
            raise gr.Error(
                f"Could not load SAE (layer={layer}, k={k_sparsity}). "
                f"Check if it exists in the repo. Error: {e}"
            ) from e
    return SAE_MODEL_CACHE[sae_key]


# --- Plotting Helper Function ---
def create_plot(activations, title, k_sparsity):
    """Build a Plotly bar chart of the ``k_sparsity`` largest activations.

    Args:
        activations: 1-D NumPy array of activation values, indexed by latent
            feature.
        title: Title of the figure.
        k_sparsity: Number of top activations to show. Values larger than
            ``len(activations)`` are clamped; values <= 0 show nothing.

    Returns:
        A Plotly figure. When there is nothing to plot, the figure only
        carries a "No features activated" annotation.
    """
    n_features = len(activations)
    # Clamp k into [0, n_features]. Slicing from an explicit start index
    # (instead of ``[-k:]``) keeps k == 0 from selecting *every* feature.
    k = max(0, min(int(k_sparsity), n_features))

    # Find the indices of the top k activations
    top_indices = np.argsort(activations)[n_features - k:]
    top_values = activations[top_indices]

    # Sort the top activations by value for better plotting
    sorted_pairs = sorted(
        zip(top_indices, top_values), key=lambda item: item[1], reverse=True
    )
    plot_indices = [pair[0] for pair in sorted_pairs]
    plot_values = [pair[1] for pair in sorted_pairs]

    if not plot_indices:
        fig = px.bar(title=title)
        fig.update_layout(
            annotations=[dict(text="No features activated", showarrow=False)]
        )
        return fig

    fig = px.bar(
        x=plot_indices,
        y=plot_values,
        labels={'x': 'Latent Feature Index', 'y': 'Activation Value'},
        title=title,
    )
    # Use a dark color for the bars for better contrast
    fig.update_traces(marker_color='#2c3e50')
    # Treat feature indices as labels so bars stay in sorted-by-value order.
    fig.update_xaxes(type='category')
    fig.update_layout(
        xaxis_title="Latent Feature Index",
        yaxis_title="Activation Value",
        font=dict(family="Arial, sans-serif", size=12),
        plot_bgcolor='white',  # Set background to white
    )
    return fig


# --- Core Inference Logic ---
def parse_sequences(text_input, file_input):
    """Parse sequences from either a file upload or a textbox.

    The file upload takes priority over the textbox.

    Args:
        text_input: Textbox content with one sequence per line (may be None
            or empty). Blank lines are ignored.
        file_input: Uploaded Parquet file, or None. Either a path string or
            an object exposing the path as ``.name`` is accepted.

    Returns:
        A list of sequences; empty if neither input provides any.

    Raises:
        gr.Error: If the Parquet file cannot be read or lacks a 'sequence'
            column.
    """
    if file_input is not None:
        # NOTE(review): depending on the Gradio version / File `type`, the
        # upload may arrive as a plain path string rather than a tempfile-like
        # object — confirm against the deployed Gradio version.
        parquet_path = getattr(file_input, "name", file_input)
        try:
            df = pd.read_parquet(parquet_path)
        except Exception as e:
            raise gr.Error(f"Error reading Parquet file: {e}") from e
        # Checked outside the try so this message is not re-wrapped as a
        # "reading" error.
        if 'sequence' not in df.columns:
            raise gr.Error("Parquet file must have a 'sequence' column.")
        return df['sequence'].tolist()

    if text_input and text_input.strip():
        stripped_lines = (line.strip() for line in text_input.splitlines())
        # Drop blank lines so stray newlines do not become empty sequences.
        return [seq for seq in stripped_lines if seq]

    return []


@spaces.GPU
def get_average_activations(sequences, layer, k_sparsity, progress=None):
    """Process sequences in batches to get average SAE activations.

    Args:
        sequences: Non-empty list of equal-length sequence strings.
        layer: Index into the base model's ``hidden_states``; also selects
            the SAE to load.
        k_sparsity: Sparsity (k) of the SAE to load.
        progress: Optional object exposing ``tqdm(iterable, desc=...)`` (e.g.
            ``gr.Progress``) used to report per-batch progress.

    Returns:
        A tuple ``(avg_activations, indexed_tokens, total_sequences)``:
        the SAE activations summed over all sequences and divided by their
        count (assumes ``encode`` returns batch-first activations, i.e. the
        result is (positions, features) — TODO confirm), the
        ``"<index>: <token>"`` labels for the non-special tokens of the first
        sequence, and the number of sequences.

    Raises:
        gr.Error: If no sequences are given or their lengths differ.
    """
    if not sequences:
        raise gr.Error("No sequences provided.")

    # Activations are averaged position-wise, so every sequence must have the
    # same length.
    first_len = len(sequences[0])
    if not all(len(s) == first_len for s in sequences):
        raise gr.Error("All sequences must have the same length.")

    model, tokenizer = get_base_model()
    sae_model = get_sae_model(layer, k_sparsity)

    total_sequences = len(sequences)
    sum_activations = None

    iterable = range(0, total_sequences, BATCH_SIZE)
    if progress:
        iterable = progress.tqdm(iterable, desc="Processing Batches")

    for i in iterable:
        batch_sequences = sequences[i:i + BATCH_SIZE]
        inputs = tokenizer(
            batch_sequences,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=first_len,
        ).to(DEVICE)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
            target_hidden_state = outputs.hidden_states[layer]
            sae_activations = sae_model.encode(target_hidden_state)

            if sum_activations is None:
                # Accumulator shaped like one sequence's activations.
                sum_activations = torch.zeros_like(sae_activations[0])
            sum_activations += sae_activations.sum(dim=0)

    avg_activations = sum_activations / total_sequences

    # Position labels are taken from the first sequence only.
    tok_inputs = tokenizer(sequences[0], return_tensors="pt")
    tokens = tokenizer.convert_ids_to_tokens(
        tok_inputs['input_ids'][0], skip_special_tokens=True
    )
    indexed_tokens = [f"{i}: {token}" for i, token in enumerate(tokens)]

    return avg_activations, indexed_tokens, total_sequences


@spaces.GPU
def analyze_sequences_and_update_ui(text_input, file_input, layer, k_sparsity,
                                    progress=gr.Progress(track_tqdm=True)):
    """Main function to drive analysis and update the Gradio UI.

    Args:
        text_input: Textbox content (one sequence per line).
        file_input: Uploaded Parquet file, or None.
        layer: Selected SAE layer.
        k_sparsity: Selected SAE sparsity (k).
        progress: Gradio progress tracker, forwarded to the batch loop.

    Returns:
        A 4-tuple matching the click handler's outputs: the averaged
        activations, an update that fills the position selector with the
        token labels (nothing selected), an update that clears the plot, and
        the analysis title string.

    Raises:
        gr.Error: If no sequences were provided or analysis fails.
    """
    sequences = parse_sequences(text_input, file_input)
    if not sequences:
        raise gr.Error("Please provide sequences via file or textbox.")

    avg_activations, indexed_tokens, num_seqs = get_average_activations(
        sequences, layer, k_sparsity, progress
    )
    analysis_title = f"Averaged over {num_seqs} sequences"

    return (
        avg_activations,
        gr.update(choices=indexed_tokens, value=None),
        gr.update(value=None),
        analysis_title,
    )


@spaces.GPU
def update_plot_on_selection(evt: gr.SelectData, avg_activations, layer,
                             k_sparsity, analysis_title):
    """Callback to update the plot when a token position is selected.

    Args:
        evt: Gradio selection event; ``evt.index`` is used as the token
            position and ``evt.value`` as its label.
        avg_activations: Averaged activations stored by the analysis step
            (None before any analysis has run).
        layer: Selected SAE layer (used in the title only).
        k_sparsity: Number of top activations to plot.
        analysis_title: Title fragment produced by the analysis step.

    Returns:
        A Plotly figure, or None when there is nothing to plot.

    Raises:
        gr.Error: If the selected position is outside the activation tensor.
    """
    if evt.value is None or avg_activations is None:
        return None

    token_index = evt.index
    selected_token_str = evt.value

    # Offset by 1 to account for the [CLS] token at the beginning of the sequence
    activation_index = token_index + 1

    # Ensure the index is within the bounds of the tensor
    if activation_index >= avg_activations.shape[0]:
        raise gr.Error(
            f"Index {activation_index} is out of bounds for activation tensor "
            f"with shape {avg_activations.shape}"
        )

    token_activations = avg_activations[activation_index, :].cpu().numpy()

    # Plotly renders line breaks in titles via "<br>", not "\n".
    plot_title = (
        f"Top {k_sparsity} Avg. Activations for position '{selected_token_str}'<br>"
        f"({analysis_title} - Layer {layer}, k={k_sparsity})"
    )
    fig = create_plot(token_activations, plot_title, k_sparsity)
    return fig


# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        '''
        # PlantRNA-FM SAE Viewer: Average Activations
        Analyze average SAE feature activations over multiple sequences.
        Provide sequences by uploading a Parquet file or pasting into the textbox.
        '''
    )

    with gr.Row():
        sae_layer = gr.Dropdown(
            label="SAE Layer",
            choices=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
            value=6,
        )
        sae_sparsity = gr.Dropdown(
            label="SAE Sparsity (k)",
            choices=[32, 64, 128, 256, 512],
            value=128,
        )

    with gr.Accordion("Input Method", open=True):
        gr.Markdown("Provide sequences via a Parquet file (must contain a 'sequence' column) OR paste them below (one per line). The file upload takes priority.")
        file_input = gr.File(label="Upload Parquet File", file_types=[".parquet"])
        text_input = gr.Textbox(
            label="Input RNA Sequences",
            placeholder="e.g., GCGAAAGAUAUGUAGAUAAAGAUAUAGAUGAUAG...",
            lines=5,
        )

    run_button = gr.Button("Analyze Average Activations", variant="primary")

    # --- States and Outputs ---
    # Hold the analysis results between the "run" click and later selections.
    activations_state = gr.State()
    analysis_title_state = gr.State()

    gr.Markdown("### Select Nucleotide Position")
    sequence_selector = gr.Radio(label="Sequence Positions", interactive=True)

    output_plot = gr.Plot(label="SAE Latent Activations")

    # --- Event Wiring ---
    run_button.click(
        fn=analyze_sequences_and_update_ui,
        inputs=[text_input, file_input, sae_layer, sae_sparsity],
        outputs=[activations_state, sequence_selector, output_plot, analysis_title_state],
    )

    sequence_selector.select(
        fn=update_plot_on_selection,
        inputs=[activations_state, sae_layer, sae_sparsity, analysis_title_state],
        outputs=output_plot,
    )

    gr.Examples(
        examples=[
            ["AUGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGC\nAUGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGC", None, 6, 128],
        ],
        inputs=[text_input, file_input, sae_layer, sae_sparsity]
    )

demo.launch()