"""Utilities for interactive visualization of extracted graphs."""

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from plotly.subplots import make_subplots


def create_scatter_plot_with_filter(graph_data):
    """
    Create an interactive scatter plot with a cumulative-influence filter.

    Args:
        graph_data: Dictionary containing the graph data (nodes, metadata, etc.)

    Returns:
        DataFrame of the SAE feature nodes that survive the filters, or None
        if nothing could be plotted.
    """
    if 'nodes' not in graph_data:
        st.warning("⚠️ No nodes found in graph data")
        return
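
    # Expected input shape (hypothetical values; the fields shown are exactly
    # the ones read below):
    #
    #   graph_data = {
    #       "metadata": {"prompt_tokens": ["The", "cat"], "node_threshold": 0.8},
    #       "nodes": [
    #           {"node_id": "0_12345_1", "layer": "0", "ctx_idx": 1,
    #            "feature_type": "cross layer transcoder", "influence": 0.42},
    #       ],
    #   }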

    # Map context positions to prompt tokens for axis labels and hover text.
    prompt_tokens = graph_data.get('metadata', {}).get('prompt_tokens', [])
    token_map = {i: token for i, token in enumerate(prompt_tokens)}
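    # e.g. prompt_tokens = ["The", "cat"] -> token_map = {0: "The", 1: "cat"}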

    scatter_data = []
    skipped_nodes = []

    for node in graph_data['nodes']:
        layer_val = node.get('layer', '')

        try:
            # Layer "E" marks embedding nodes; plot them below layer 0.
            if str(layer_val).upper() == 'E':
                layer_numeric = -1
            else:
                layer_numeric = int(layer_val)

            # Guard against zero-sized markers downstream.
            influence_val = node.get('influence', 0)
            if influence_val is None or influence_val == 0:
                influence_val = 0.001

            ctx_idx_val = node.get('ctx_idx', 0)
            token_str = token_map.get(ctx_idx_val, f"ctx_{ctx_idx_val}")

            # For feature nodes, the feature index is the second "_"-separated
            # component of node_id (e.g. "0_12345_1" -> 12345).
            node_id = node.get('node_id', '')
            node_type = node.get('feature_type', '')
            feature_idx = None

            if node_type == 'cross layer transcoder':
                if node_id and '_' in node_id:
                    parts = node_id.split('_')
                    if len(parts) >= 2:
                        try:
                            feature_idx = int(parts[1])
                        except (ValueError, IndexError):
                            pass

                if feature_idx is None:
                    skipped_nodes.append(f"layer={layer_val}, node_id={node_id}, type=SAE")
                    continue
            else:
                # Non-feature nodes (embeddings, logits, reconstruction errors)
                # are marked with feature = -1.
                feature_idx = -1

            scatter_data.append({
                'layer': layer_numeric,
                'ctx_idx': ctx_idx_val,
                'token': token_str,
                'id': node_id,
                'influence': influence_val,
                'feature': feature_idx
            })
        except (ValueError, TypeError):
            # Skip nodes whose layer cannot be parsed as a number.
            continue

    if skipped_nodes:
        st.warning(f"⚠️ {len(skipped_nodes)} feature nodes with malformed node_id were skipped")
        with st.expander("Skipped nodes details"):
            for node_info in skipped_nodes[:10]:
                st.text(node_info)
            if len(skipped_nodes) > 10:
                st.text(f"... and {len(skipped_nodes) - 10} more nodes")

    if not scatter_data:
        st.warning("⚠️ No valid nodes found for plotting")
        return

    scatter_df = pd.DataFrame(scatter_data)

    # Defensive cleanup: influence must be strictly positive for marker sizing.
    scatter_df['influence'] = scatter_df['influence'].fillna(0.001)
    scatter_df['influence'] = scatter_df['influence'].replace(0, 0.001)

    # Horizontal jitter (sub-columns) is applied after filtering, further down;
    # only the shared bin width is defined here.
    bin_width = 0.3
st.markdown("### 3️⃣ Filter Features by Cumulative Influence Coverage") |
|
|
|
|
|
|
|
|
max_influence = scatter_df['influence'].max() |
|
|
|
|
|
|
|
|
node_threshold_used = graph_data.get('metadata', {}).get('node_threshold', None) |
|
|
|
|
|
if node_threshold_used is not None: |
|
|
st.info(f""" |
|
|
**The `influence` field is the cumulative coverage (0-{max_influence:.2f})** calculated by circuit tracer pruning. When nodes are sorted by descending influence, a node with `influence=0.65` means that |
|
|
**up to that node** covers 65% of the total influence. |
|
|
""") |
|
|
else: |
|
|
st.info(f""" |
|
|
**The `influence` field is the cumulative coverage (0-{max_influence:.2f})** calculated by circuit tracer pruning. |
|
|
|
|
|
When nodes are sorted by descending influence, a node with `influence=0.65` means that |
|
|
**up to that node** covers 65% of the total influence. |
|
|
""") |
|
|
|
|
|

    cumulative_threshold = st.slider(
        "Cumulative Influence Threshold",
        min_value=0.0,
        max_value=float(max_influence),
        value=float(max_influence),
        step=0.01,
        key="cumulative_slider_main",
        help=f"Keep only nodes with influence ≤ threshold. Range: 0.0 - {max_influence:.2f} (max in data)"
    )

    filter_error_nodes = st.checkbox(
        "Exclude Reconstruction Error Nodes (feature = -1)",
        value=False,
        key="filter_error_checkbox",
        help="Reconstruction error nodes represent the part of the model not explained by SAE features"
    )

    num_total = len(scatter_df)

    is_error_node = scatter_df['feature'] == -1
    n_error_total = is_error_node.sum()
    pct_error_nodes = (n_error_total / num_total * 100) if num_total > 0 else 0

    # Embeddings sit at layer -1; logit nodes occupy the highest layer present.
    is_embedding = scatter_df['layer'] == -1
    max_layer = scatter_df['layer'].max()
    is_logit = scatter_df['layer'] == max_layer

    # Embedding and logit nodes are always kept; the threshold only prunes
    # feature/error nodes. With the slider at its maximum, nothing is filtered.
    if cumulative_threshold < float(max_influence):
        mask_influence = scatter_df['influence'] <= cumulative_threshold
        mask_keep = mask_influence | is_embedding | is_logit
    else:
        mask_keep = pd.Series([True] * len(scatter_df), index=scatter_df.index)

    if filter_error_nodes:
        mask_not_error = (scatter_df['feature'] != -1) | is_embedding | is_logit
        mask_keep = mask_keep & mask_not_error

    scatter_filtered = scatter_df[mask_keep].copy()

    # Effective threshold actually reached: the max cumulative influence among
    # the kept feature nodes (embeddings and logits excluded).
    feature_nodes_filtered = scatter_filtered[~((scatter_filtered['layer'] == -1) | (scatter_filtered['layer'] == max_layer))]
    if len(feature_nodes_filtered) > 0:
        threshold_influence = feature_nodes_filtered['influence'].max()
    else:
        threshold_influence = 0.0

    num_selected = len(scatter_filtered)

    is_embedding_filtered = scatter_filtered['layer'] == -1
    max_layer_filtered = scatter_filtered['layer'].max()
    is_logit_filtered = scatter_filtered['layer'] == max_layer_filtered
    is_error_filtered = scatter_filtered['feature'] == -1

    n_embeddings = len(scatter_filtered[is_embedding_filtered])
    n_error_nodes = len(scatter_filtered[is_error_filtered & ~is_embedding_filtered & ~is_logit_filtered])
    n_features = len(scatter_filtered[~(is_embedding_filtered | is_logit_filtered | is_error_filtered)])
    n_logits_excluded = len(scatter_filtered[is_logit_filtered])
    n_error_excluded = n_error_total - n_error_nodes if filter_error_nodes else 0

    col1, col2, col3, col4 = st.columns(4)

    with col1:
        st.metric("Total Nodes", num_total)
    with col2:
        st.metric("Selected Nodes", num_selected)
    with col3:
        pct = (num_selected / num_total * 100) if num_total > 0 else 0
        st.metric("% Nodes", f"{pct:.1f}%")
    with col4:
        st.metric("Influence Threshold", f"{threshold_influence:.6f}")

    # Work on the filtered set from here on.
    scatter_df = scatter_filtered.copy()

    # Spread co-located nodes into sub-columns so they don't overlap: each
    # (ctx_idx, layer) cell gets up to 5 sub-columns centred on the token's x.
    scatter_df['sub_column'] = 0.0
    for (ctx, layer), group in scatter_df.groupby(['ctx_idx', 'layer']):
        n_nodes = len(group)
        if n_nodes > 1:
            n_bins = min(5, int(np.ceil(np.sqrt(n_nodes))))
            for i, idx in enumerate(group.index):
                sub_col = (i % n_bins) - (n_bins - 1) / 2
                scatter_df.at[idx, 'sub_column'] = sub_col * bin_width

    scatter_df['ctx_idx_display'] = scatter_df['ctx_idx'] + scatter_df['sub_column']
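
    # Worked example of the jitter (illustrative): a cell with n_nodes = 7 gets
    # n_bins = min(5, ceil(sqrt(7))) = 3, so the offsets cycle through
    # (-1, 0, +1) * 0.3 = -0.3, 0.0, +0.3 around the token's x position.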

    if 'node_influence' not in scatter_df.columns:
        # Recover each node's marginal influence from the cumulative values:
        # sort ascending and take consecutive differences (the first node's
        # marginal is its own cumulative value).
        df_sorted_by_cumul = scatter_df.sort_values('influence').reset_index(drop=True)
        df_sorted_by_cumul['node_influence'] = df_sorted_by_cumul['influence'].diff()
        df_sorted_by_cumul.loc[0, 'node_influence'] = df_sorted_by_cumul.loc[0, 'influence']

        node_id_to_marginal = dict(zip(df_sorted_by_cumul['id'], df_sorted_by_cumul['node_influence']))
        scatter_df['node_influence'] = scatter_df['id'].map(node_id_to_marginal).fillna(scatter_df['influence'])
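
    # Worked example (illustrative numbers): cumulative influences
    # [0.20, 0.50, 0.65] -> marginals [0.20, 0.30, 0.15]. This is only a proxy
    # when the kept nodes are not a contiguous prefix of the sorted order (see
    # the note in the size-formula info box below).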

    # Share of total marginal influence carried by reconstruction-error nodes.
    is_error_in_complete = scatter_df['feature'] == -1
    total_node_influence = scatter_df['node_influence'].sum()
    error_node_influence = scatter_df[is_error_in_complete]['node_influence'].sum()
    pct_error_influence = (error_node_influence / total_node_influence * 100) if total_node_influence > 0 else 0

    col1, col2 = st.columns(2)
    with col1:
        st.metric(
            "% Error Nodes",
            f"{pct_error_nodes:.1f}%",
            help=f"{n_error_total} out of {num_total} total nodes are reconstruction error (feature=-1)"
        )
    with col2:
        st.metric(
            "% Node Influence (Error)",
            f"{pct_error_influence:.1f}%",
            help=f"Reconstruction error nodes contribute {pct_error_influence:.1f}% of total node_influence"
        )

    info_parts = [f"{n_embeddings} embeddings", f"{n_features} features"]
    if n_error_nodes > 0:
        info_parts.append(f"{n_error_nodes} error nodes")

    excluded_parts = [f"{n_logits_excluded} logits"]
    if n_error_excluded > 0:
        excluded_parts.append(f"{n_error_excluded} error nodes")

    st.info(f"📊 Displaying {n_embeddings + n_features + n_error_nodes} nodes: {', '.join(info_parts)} ({', '.join(excluded_parts)} excluded)")

    # Logit nodes (top layer) are excluded from the plot entirely.
    max_layer = scatter_df['layer'].max()
    is_logit_group = scatter_df['layer'] == max_layer
    scatter_df = scatter_df[~is_logit_group].copy()

    # Group masks on the remaining nodes.
    is_embedding_group = scatter_df['layer'] == -1
    is_feature_group = scatter_df['layer'] != -1

    scatter_df['node_type'] = 'feature'
    scatter_df.loc[is_embedding_group, 'node_type'] = 'embedding'

    # Marker size: per-group normalization of |node_influence|, emphasized with
    # a cubic power. (Despite the name, 'influence_log' holds the cubic-scaled
    # size, not a logarithm.)
    scatter_df['influence_log'] = 0.0

    for group_name, group_mask in [('embedding', is_embedding_group),
                                   ('feature', is_feature_group)]:
        if group_mask.sum() > 0:
            group_data = scatter_df[group_mask]['node_influence'].abs()
            max_in_group = group_data.max()
            if max_in_group > 0:
                normalized = group_data / max_in_group
                scatter_df.loc[group_mask, 'influence_log'] = (normalized ** 3) * 1000 + 10
            else:
                scatter_df.loc[group_mask, 'influence_log'] = 10
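
    # Worked example (illustrative): if a group's max |node_influence| is 0.04,
    # a node at 0.02 normalizes to 0.5 and gets size 0.5**3 * 1000 + 10 = 135,
    # while the largest node gets 1.0**3 * 1000 + 10 = 1010; the cube widens
    # the visual gap between mid-range and top nodes.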

    symbol_map = {
        'embedding': 'square',
        'feature': 'circle'
    }

    fig = px.scatter(
        scatter_df,
        x='ctx_idx_display',
        y='layer',
        size='influence_log',
        symbol='node_type',
        symbol_map=symbol_map,
        color='node_type',
        color_discrete_map={
            'embedding': '#4CAF50',
            'feature': '#808080'
        },
        labels={
            'id': 'Node ID',
            'ctx_idx_display': 'Context Position',
            'ctx_idx': 'ctx_idx',
            'layer': 'Layer',
            'influence': 'Cumulative Influence',
            'node_influence': 'Node Influence',
            'node_type': 'Node Type',
            'token': 'Token',
            'feature': 'Feature'
        },
        title='Features by Layer and Position (size: node_influence^3 normalized per group)',
        hover_data={
            'ctx_idx': True,
            'token': True,
            'layer': True,
            'node_type': True,
            'id': True,
            'feature': True,
            'node_influence': ':.6f',
            'influence': ':.4f',
            'ctx_idx_display': False,
            'influence_log': False
        }
    )
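
    # Note on hover_data above: a value of False removes a helper column from
    # the hover box (ctx_idx_display, influence_log), while a format string
    # such as ':.6f' keeps the column with custom formatting.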

    max_influence_log = scatter_df['influence_log'].max()

    fig.update_traces(
        marker=dict(
            sizemode='area',
            sizeref=2. * max_influence_log / (50. ** 2) if max_influence_log > 0 else 1,
            sizemin=2,
            opacity=0.3,
            line=dict(width=1.5, color='white')
        )
    )
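
    # sizeref follows Plotly's recommended area scaling,
    # sizeref = 2 * max(size) / (desired_max_diameter_px ** 2),
    # so the largest marker renders at roughly 50 px.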

    # Label each context position with its index and token text.
    unique_ctx = sorted(scatter_df['ctx_idx'].unique())
    tick_labels = [f"{ctx}: {token_map.get(ctx, '')}" for ctx in unique_ctx]

    fig.update_layout(
        template='plotly_white',
        height=600,
        showlegend=True,
        legend=dict(
            title="Node Type",
            orientation="v",
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.99,
            bgcolor="rgba(255,255,255,0.8)"
        ),
        xaxis=dict(
            gridcolor='lightgray',
            tickmode='array',
            tickvals=unique_ctx,
            ticktext=tick_labels,
            tickangle=-45
        ),
        yaxis=dict(gridcolor='lightgray')
    )

    st.plotly_chart(fig, use_container_width=True)

    with st.expander("📊 Statistics by Group (Size Normalization)", expanded=False):
        col1, col2 = st.columns(2)

        with col1:
            st.markdown("**🟩 Embeddings (green squares)**")
            emb_data = scatter_df[scatter_df['node_type'] == 'embedding']
            if len(emb_data) > 0:
                st.metric("Nodes", len(emb_data))
                st.metric("Max node_influence", f"{emb_data['node_influence'].max():.6f}")
                st.metric("Mean node_influence", f"{emb_data['node_influence'].mean():.6f}")
                st.metric("Min node_influence", f"{emb_data['node_influence'].min():.6f}")
            else:
                st.info("No embeddings in filtered dataset")

        with col2:
            st.markdown("**⚪ Features (gray circles)**")
            feat_data = scatter_df[scatter_df['node_type'] == 'feature']
            if len(feat_data) > 0:
                st.metric("Nodes", len(feat_data))
                st.metric("Max node_influence", f"{feat_data['node_influence'].max():.6f}")
                st.metric("Mean node_influence", f"{feat_data['node_influence'].mean():.6f}")
                st.metric("Min node_influence", f"{feat_data['node_influence'].min():.6f}")
            else:
                st.info("No features in filtered dataset")

    st.info("""
    💡 **Size formula**: `size = (normalized_node_influence)³ × 1000 + 10`

    Size is normalized **per group** and uses **power 3** to emphasize differences:
    - A node at 50% of the group max → relative size 0.5³ = 12.5% (much smaller)
    - A node at 80% of the group max → relative size 0.8³ = 51.2%
    - A node at 100% of the group max → relative size 1.0³ = 100%

    The two groups (embeddings and features) have independent scales.

    Note: in the JSON, the `influence` field is the pre-pruning cumulative
    coverage. Estimating `node_influence` as the difference between consecutive
    cumulative values is therefore only a normalized proxy (renormalized on the
    current node set): the graph may already be topologically pruned, so the
    selection need not coincide with a contiguous prefix of the sorted nodes.
    """)

    with st.expander("📈 Pareto Analysis Node Influence (Features only)", expanded=False):
        try:
            features_only = scatter_df[scatter_df['node_type'] == 'feature'].copy()

            if len(features_only) == 0:
                st.warning("⚠️ No features found in filtered dataset")
                return  # bails out of the whole function (returns None)

            # Rank features by descending marginal influence.
            sorted_df = features_only.sort_values('node_influence', ascending=False).reset_index(drop=True)
            sorted_df['rank'] = range(1, len(sorted_df) + 1)
            sorted_df['rank_pct'] = sorted_df['rank'] / len(sorted_df) * 100

            total_node_inf = sorted_df['node_influence'].sum()
            if total_node_inf == 0:
                st.warning("⚠️ Total node influence is 0")
                return  # bails out of the whole function (returns None)

            sorted_df['cumulative_node_influence'] = sorted_df['node_influence'].cumsum()
            sorted_df['cumulative_node_influence_pct'] = sorted_df['cumulative_node_influence'] / total_node_inf * 100
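
            # Worked example (illustrative): marginals [0.5, 0.3, 0.2] give
            # cumulative percentages [50, 80, 100]; the top 2 of 3 features
            # (rank_pct ≈ 66.7%) already cover 80% of the total.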

            # Dual-axis Pareto chart: bars show individual node influence,
            # the line shows the cumulative percentage.
            fig_pareto = make_subplots(specs=[[{"secondary_y": True}]])

            # Cap the bar chart at the top 100 features for readability.
            display_limit = min(100, len(sorted_df))

            fig_pareto.add_trace(
                go.Bar(
                    x=sorted_df['rank'][:display_limit],
                    y=sorted_df['node_influence'][:display_limit],
                    name='Node Influence',
                    marker=dict(color='#2196F3', opacity=0.6),
                    hovertemplate='<b>Rank: %{x}</b><br>Node Influence: %{y:.6f}<extra></extra>'
                ),
                secondary_y=False
            )

            fig_pareto.add_trace(
                go.Scatter(
                    x=sorted_df['rank_pct'],
                    y=sorted_df['cumulative_node_influence_pct'],
                    mode='lines+markers',
                    name='Cumulative %',
                    line=dict(color='#FF5722', width=3),
                    marker=dict(size=4),
                    hovertemplate='<b>Top %{x:.1f}% features</b><br>Cumulative: %{y:.1f}%<extra></extra>'
                ),
                secondary_y=True
            )

            # Reference lines at 80/90/95% cumulative coverage.
            for pct, label in [(80, '80%'), (90, '90%'), (95, '95%')]:
                fig_pareto.add_hline(
                    y=pct,
                    line_dash="dash",
                    line_color="gray",
                    opacity=0.5,
                    secondary_y=True
                )
                fig_pareto.add_annotation(
                    x=100,
                    y=pct,
                    text=label,
                    showarrow=False,
                    xanchor='left',
                    yref='y2'
                )

            # Knee point: first rank whose cumulative coverage reaches 80%
            # (idxmax returns the index of the first True in a boolean series).
            knee_idx = (sorted_df['cumulative_node_influence_pct'] >= 80).idxmax()
            knee_rank_pct = sorted_df.loc[knee_idx, 'rank_pct']
            knee_cumul = sorted_df.loc[knee_idx, 'cumulative_node_influence_pct']

            fig_pareto.add_trace(
                go.Scatter(
                    x=[knee_rank_pct],
                    y=[knee_cumul],
                    mode='markers',
                    name='Knee (80%)',
                    marker=dict(size=15, color='#4CAF50', symbol='diamond', line=dict(width=2, color='white')),
                    hovertemplate=f'<b>Knee Point</b><br>Top {knee_rank_pct:.1f}% features<br>Cumulative: {knee_cumul:.1f}%<extra></extra>',
                    showlegend=True
                ),
                secondary_y=True
            )

            fig_pareto.update_xaxes(title_text="Rank % Features (by descending node_influence)")
            fig_pareto.update_yaxes(title_text="Node Influence (individual)", secondary_y=False)
            fig_pareto.update_yaxes(title_text="Cumulative % Node Influence", secondary_y=True, range=[0, 105])

            fig_pareto.update_layout(
                height=500,
                showlegend=True,
                template='plotly_white',
                legend=dict(x=0.02, y=0.98, xanchor='left', yanchor='top'),
                title="Pareto Chart: Node Influence of Features"
            )

            st.plotly_chart(fig_pareto, use_container_width=True)
st.markdown("#### 📊 Pareto Statistics (Node Influence)") |
|
|
|
|
|
col1, col2, col3, col4 = st.columns(4) |
|
|
|
|
|
|
|
|
top_10_idx = max(0, int(len(sorted_df) * 0.1)) |
|
|
top_20_idx = max(0, int(len(sorted_df) * 0.2)) |
|
|
top_50_idx = max(0, int(len(sorted_df) * 0.5)) |
|
|
|
|
|
top_10_pct = sorted_df['cumulative_node_influence_pct'].iloc[top_10_idx] if top_10_idx < len(sorted_df) else 0 |
|
|
top_20_pct = sorted_df['cumulative_node_influence_pct'].iloc[top_20_idx] if top_20_idx < len(sorted_df) else 0 |
|
|
top_50_pct = sorted_df['cumulative_node_influence_pct'].iloc[top_50_idx] if top_50_idx < len(sorted_df) else 0 |
|
|
|
|
|
with col1: |
|
|
st.metric("Top 10% features", f"{top_10_pct:.1f}% node_influence", |
|
|
help=f"The top {int(len(sorted_df)*0.1)} most influential features cover {top_10_pct:.1f}% of total influence") |
|
|
with col2: |
|
|
st.metric("Top 20% features", f"{top_20_pct:.1f}% node_influence", |
|
|
help=f"The top {int(len(sorted_df)*0.2)} most influential features cover {top_20_pct:.1f}% of total influence") |
|
|
with col3: |
|
|
st.metric("Top 50% features", f"{top_50_pct:.1f}% node_influence", |
|
|
help=f"The top {int(len(sorted_df)*0.5)} most influential features cover {top_50_pct:.1f}% of total influence") |
|
|
with col4: |
|
|
|
|
|
gini = 1 - 2 * np.trapz(sorted_df['cumulative_node_influence_pct'] / 100, sorted_df['rank_pct'] / 100) |
|
|
st.metric("Gini Coefficient", f"{gini:.3f}", help="0 = equal distribution, 1 = highly concentrated") |

            # Map the knee back to the cumulative-influence scale used by the
            # slider above.
            knee_cumul_threshold = sorted_df.loc[knee_idx, 'influence'] if 'influence' in sorted_df.columns else scatter_df['influence'].max()

            st.success(f"""
            🎯 **Knee Point (80%)**: the first **{knee_rank_pct:.1f}%** of features ({int(len(sorted_df) * knee_rank_pct / 100)} nodes)
            cover **80%** of total node_influence.

            💡 **Threshold Suggestion**: to focus on features up to the knee point (80%),
            use `cumulative_threshold ≈ {knee_cumul_threshold:.4f}` in the slider above.
            """)

            with st.expander("📊 Node Influence Distribution Histogram", expanded=False):
                fig_hist = px.histogram(
                    sorted_df,
                    x='node_influence',
                    nbins=50,
                    title='Node Influence Distribution (Features)',
                    labels={'node_influence': 'Node Influence', 'count': 'Frequency'},
                    color_discrete_sequence=['#2196F3']
                )
                fig_hist.update_layout(
                    height=350,
                    template='plotly_white',
                    showlegend=False
                )
                fig_hist.update_traces(marker=dict(opacity=0.7))
                st.plotly_chart(fig_hist, use_container_width=True)

                col1, col2, col3, col4 = st.columns(4)
                with col1:
                    st.metric("Mean", f"{sorted_df['node_influence'].mean():.6f}")
                with col2:
                    st.metric("Median", f"{sorted_df['node_influence'].median():.6f}")
                with col3:
                    st.metric("Std Dev", f"{sorted_df['node_influence'].std():.6f}")
                with col4:
                    st.metric("Max", f"{sorted_df['node_influence'].max():.6f}")

        except Exception as e:
            st.error(f"❌ Error creating Pareto analysis: {str(e)}")
            import traceback
            st.code(traceback.format_exc())

    # Return the SAE feature nodes that survived the filters (embeddings,
    # logits, and reconstruction-error nodes excluded).
    sae_features_only = scatter_filtered[
        ~(is_embedding_filtered | is_logit_filtered | is_error_filtered)
    ].copy()

    return sae_features_only
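

# Minimal usage sketch (hypothetical file and module names; run the host app
# with `streamlit run your_app.py`):
#
#   import json
#   from graph_visualization import create_scatter_plot_with_filter
#
#   with open("graph.json") as f:
#       graph_data = json.load(f)
#   sae_features = create_scatter_plot_with_filter(graph_data)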