Spaces:
Sleeping
Sleeping
| """Web App for the Codebook Features project.""" | |
| import argparse | |
| import glob | |
| import os | |
| import streamlit as st | |
| import code_search_utils | |
| import utils | |
| import webapp_utils | |
| # --- Parse command line arguments --- | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument( | |
| "--deploy", | |
| default=True, | |
| help="Deploy mode.", | |
| ) | |
| parser.add_argument( | |
| "--cache_dir", | |
| type=str, | |
| default="cache/", | |
| help="Path to directory containing cache for codebook models.", | |
| ) | |
| try: | |
| args = parser.parse_args() | |
| except SystemExit as e: | |
| # This exception will be raised if --help or invalid command line arguments | |
| # are used. Currently streamlit prevents the program from exiting normally | |
| # so we have to do a hard exit. | |
| os._exit(e.code if isinstance(e.code, int) else 1) | |
| deploy = args.deploy | |
| webapp_utils.load_widget_state() | |
| st.set_page_config( | |
| page_title="Codebook Features", | |
| page_icon="π", | |
| ) | |
| st.title("Codebook Features") | |
| # --- Load model info and cache --- | |
| pretty_model_names = { | |
| "TinyStories-1Layer-21M#100ksteps_vcb_mlp": "TinyStories-1L-21M-MLP", | |
| "TinyStories-1Layer-21M_ccb_attn_preproj": "TinyStories 1 Layer Attention Codebook", | |
| "TinyStories-33M_ccb_attn_preproj": "TinyStories 4 Layer Attention Codebook", | |
| "TinyStories-1Layer-21M_vcb_mlp": "TinyStories 1 Layer MLP Codebook", | |
| } | |
| orig_model_name = {v: k for k, v in pretty_model_names.items()} | |
| base_cache_dir = args.cache_dir | |
| dirs = glob.glob(base_cache_dir + "models/*/") | |
| model_name_options = [d.split("/")[-2].split("_")[:-2] for d in dirs] | |
| model_name_options = ["_".join(m) for m in model_name_options] | |
| model_name_options = sorted(set(model_name_options)) | |
| def_model_idx = ["attn" in m.lower() for m in model_name_options].index(True) | |
| p_model_name = st.selectbox( | |
| "Model", | |
| [pretty_model_names.get(m, m) for m in model_name_options], | |
| index=def_model_idx, | |
| key=webapp_utils.persist("model_name"), | |
| ) | |
| model_name = orig_model_name.get(p_model_name, p_model_name) | |
| is_fsm = "FSM" in p_model_name | |
| codes_cache_path = base_cache_dir + f"models/{model_name}_*" | |
| dirs = glob.glob(codes_cache_path) | |
| dirs.sort(key=os.path.getmtime) | |
| # session states | |
| codes_cache_path = dirs[-1] + "/" | |
| model_info = utils.ModelInfoForWebapp.load(codes_cache_path) | |
| num_codes = model_info.num_codes | |
| num_layers = model_info.n_layers | |
| num_heads = model_info.n_heads | |
| cb_at = model_info.cb_at | |
| gcb = model_info.gcb | |
| gcb = "_gcb" if gcb else "" | |
| is_attn = "attn" in cb_at | |
| dataset_cache_path = base_cache_dir + f"datasets/{model_info.dataset_name}/" | |
| ( | |
| tokens_str, | |
| tokens_text, | |
| token_byte_pos, | |
| cb_acts, | |
| act_count_ft_tkns, | |
| metrics, | |
| ) = webapp_utils.load_code_search_cache(codes_cache_path, dataset_cache_path) | |
| seq_len = len(tokens_str[0]) | |
| metric_keys = ["eval_loss", "eval_accuracy", "eval_dead_code_fraction"] | |
| metrics = {k: v for k, v in metrics.items() if k.split("/")[0] in metric_keys} | |
| # --- Set the session states --- | |
| st.session_state["model_name_id"] = model_name | |
| st.session_state["cb_acts"] = cb_acts | |
| st.session_state["tokens_text"] = tokens_text | |
| st.session_state["tokens_str"] = tokens_str | |
| st.session_state["act_count_ft_tkns"] = act_count_ft_tkns | |
| st.session_state["num_codes"] = num_codes | |
| st.session_state["gcb"] = gcb | |
| st.session_state["cb_at"] = cb_at | |
| st.session_state["is_attn"] = is_attn | |
| st.session_state["seq_len"] = seq_len | |
| if not deploy: | |
| st.markdown("## Metrics") | |
| # hide metrics by default | |
| if st.checkbox("Show Model Metrics"): | |
| st.write(metrics) | |
| st.markdown("## Demo Codes") | |
| demo_codes_desc = ( | |
| "This section contains codes that we've found to be interpretable along " | |
| "with a description of the feature we think they are capturing. " | |
| "Click on the π search button for a code to see the tokens that code activates on." | |
| ) | |
| st.write(demo_codes_desc) | |
| demo_file_path = codes_cache_path + "demo_codes.txt" | |
| if st.checkbox("Show Demo Codes"): | |
| try: | |
| with open(demo_file_path, "r") as f: | |
| demo_codes = f.readlines() | |
| except FileNotFoundError: | |
| demo_codes = [] | |
| code_desc, code_regex = "", "" | |
| demo_codes = [code.strip() for code in demo_codes if code.strip()] | |
| num_cols = 6 if is_attn else 5 | |
| cols = st.columns([1] * (num_cols - 1) + [2]) | |
| # st.markdown(button_height_style, unsafe_allow_html=True) | |
| cols[0].markdown("Search", help="Button to see token activations for the code.") | |
| cols[1].write("Code") | |
| cols[2].write("Layer") | |
| if is_attn: | |
| cols[3].write("Head") | |
| cols[-2].markdown( | |
| "Num Acts", | |
| help="Number of tokens that the code activates on in the acts dataset.", | |
| ) | |
| cols[-1].markdown("Description", help="Interpreted description of the code.") | |
| if len(demo_codes) == 0: | |
| st.markdown( | |
| f""" | |
| <div style="font-size: 1.0rem; color: red;"> | |
| No demo codes found in file {demo_file_path} | |
| </div> | |
| """, | |
| unsafe_allow_html=True, | |
| ) | |
| skip = True | |
| for code_txt in demo_codes: | |
| if code_txt.startswith("##"): | |
| skip = True | |
| continue | |
| if code_txt.startswith("#"): | |
| code_desc, code_regex = code_txt[1:].split(":") | |
| code_desc, code_regex = code_desc.strip(), code_regex.strip() | |
| skip = False | |
| continue | |
| if skip: | |
| continue | |
| code_info = utils.CodeInfo.from_str(code_txt, regex=code_regex) | |
| comp_info = f"layer{code_info.layer}_{f'head{code_info.head}' if code_info.head is not None else ''}" | |
| button_key = ( | |
| f"demo_search_code{code_info.code}_layer{code_info.layer}_desc-{code_info.description}" | |
| + (f"head{code_info.head}" if code_info.head is not None else "") | |
| ) | |
| cols = st.columns([1] * (num_cols - 1) + [2]) | |
| button_clicked = cols[0].button( | |
| "π", | |
| key=button_key, | |
| ) | |
| if button_clicked: | |
| webapp_utils.set_ct_acts( | |
| code_info.code, code_info.layer, code_info.head, None, is_attn | |
| ) | |
| cols[1].write(code_info.code) | |
| cols[2].write(str(code_info.layer)) | |
| if is_attn: | |
| cols[3].write(str(code_info.head)) | |
| cols[-2].write(str(act_count_ft_tkns[comp_info][code_info.code])) | |
| cols[-1].write(code_desc) | |
| skip = True | |
| # --- Code Search --- | |
| st.markdown("## Code Search") | |
| code_search_desc = ( | |
| "To find whether the codebooks model has captured a relevant feature from the data (e.g. pronouns)," | |
| " you can specify a regex pattern for your feature (e.g. βhe|she|theyβ) and find whether any code" | |
| " activating on the regex pattern exists.\n\n" | |
| "Since strings can contain several tokens, you can specify the token you want a code to fire on by" | |
| " using a capture group. For example, the search term βNew (York)β will try to find codes that" | |
| " activate on the bigram feature βNew Yorkβ at the York token" | |
| ) | |
| if st.checkbox("Search with Regex"): | |
| st.write(code_search_desc) | |
| regex_pattern = st.text_input( | |
| "Enter a regex pattern", | |
| help="Wrap code token in the first group. E.g. New (York)", | |
| key="regex_pattern", | |
| ) | |
| # topk = st.slider("Top K", 1, 20, 10) | |
| prec_col, sort_col = st.columns(2) | |
| prec_threshold = prec_col.slider( | |
| "Precision Threshold", | |
| 0.0, | |
| 1.0, | |
| 0.9, | |
| help="Shows codes with precision on the regex pattern above the threshold.", | |
| ) | |
| sort_by_options = ["Precision", "Recall", "Num Acts"] | |
| sort_by_name = sort_col.radio( | |
| "Sort By", | |
| sort_by_options, | |
| index=0, | |
| horizontal=True, | |
| help="Sorts the codes by the selected metric.", | |
| ) | |
| sort_by = sort_by_options.index(sort_by_name) | |
| def get_codebook_wise_codes_for_regex( | |
| regex_pattern, prec_threshold, gcb, model_name | |
| ): | |
| """Get codebook wise codes for a given regex pattern.""" | |
| assert model_name is not None # required for loading from correct cache data | |
| return code_search_utils.get_codes_from_pattern( | |
| regex_pattern, | |
| tokens_text, | |
| token_byte_pos, | |
| cb_acts, | |
| act_count_ft_tkns, | |
| gcb=gcb, | |
| topk=8, | |
| prec_threshold=prec_threshold, | |
| ) | |
| if regex_pattern: | |
| codebook_wise_codes, re_token_matches = get_codebook_wise_codes_for_regex( | |
| regex_pattern, | |
| prec_threshold, | |
| gcb, | |
| model_name, | |
| ) | |
| st.markdown( | |
| f"Found <span style='color:green;'>{re_token_matches}</span> matches", | |
| unsafe_allow_html=True, | |
| ) | |
| num_search_cols = 7 if is_attn else 6 | |
| non_deploy_offset = 0 | |
| if not deploy: | |
| non_deploy_offset = 1 | |
| num_search_cols += non_deploy_offset | |
| cols = st.columns(num_search_cols) | |
| cols[0].markdown("Search", help="Button to see token activations for the code.") | |
| cols[1].write("Layer") | |
| if is_attn: | |
| cols[2].write("Head") | |
| cols[-4 - non_deploy_offset].write("Code") | |
| cols[-3 - non_deploy_offset].write("Precision") | |
| cols[-2 - non_deploy_offset].write("Recall") | |
| cols[-1 - non_deploy_offset].markdown( | |
| "Num Acts", | |
| help="Number of tokens that the code activates on in the acts dataset.", | |
| ) | |
| if not deploy: | |
| cols[-1].markdown( | |
| "Save to Demos", | |
| help="Button to save the code to demos along with the regex pattern.", | |
| ) | |
| all_codes = codebook_wise_codes.items() | |
| all_codes = [ | |
| (cb_name, code_pr_info) | |
| for cb_name, code_pr_infos in all_codes | |
| for code_pr_info in code_pr_infos | |
| ] | |
| all_codes = sorted(all_codes, key=lambda x: x[1][1 + sort_by], reverse=True) | |
| for cb_name, (code, prec, rec, code_acts) in all_codes: | |
| layer_head = cb_name.split("_") | |
| layer = layer_head[0][5:] | |
| head = layer_head[1][4:] if len(layer_head) > 1 else None | |
| button_key = f"search_code{code}_layer{layer}" + ( | |
| f"head{head}" if head is not None else "" | |
| ) | |
| cols = st.columns(num_search_cols) | |
| extra_args = { | |
| "prec": prec, | |
| "recall": rec, | |
| "num_acts": code_acts, | |
| "regex": regex_pattern, | |
| } | |
| button_clicked = cols[0].button("π", key=button_key) | |
| if button_clicked: | |
| webapp_utils.set_ct_acts(code, layer, head, extra_args, is_attn) | |
| cols[1].write(layer) | |
| if is_attn: | |
| cols[2].write(head) | |
| cols[-4 - non_deploy_offset].write(code) | |
| cols[-3 - non_deploy_offset].write(f"{prec*100:.2f}%") | |
| cols[-2 - non_deploy_offset].write(f"{rec*100:.2f}%") | |
| cols[-1 - non_deploy_offset].write(str(code_acts)) | |
| if not deploy: | |
| webapp_utils.add_save_code_button( | |
| demo_file_path, | |
| num_acts=code_acts, | |
| save_regex=True, | |
| prec=prec, | |
| recall=rec, | |
| button_st_container=cols[-1], | |
| button_key_suffix=f"_code{code}_layer{layer}_head{head}", | |
| ) | |
| if len(all_codes) == 0: | |
| st.markdown( | |
| f""" | |
| <div style="font-size: 1.0rem; color: red;"> | |
| No codes found for pattern {regex_pattern} at precision threshold: {prec_threshold} | |
| </div> | |
| """, | |
| unsafe_allow_html=True, | |
| ) | |
| # --- Display Code Token Activations --- | |
| st.markdown("## Code Token Activations") | |
| filter_codes = st.checkbox("Show filters", key="filter_codes", value=True) | |
| act_range, layer_code_acts = None, None | |
| if filter_codes: | |
| act_range = st.slider( | |
| "Minimum number of activations", | |
| 0, | |
| 10_000, | |
| 100, | |
| key="ct_act_range", | |
| help="Filter codes by the number of tokens they activate on.", | |
| ) | |
| cols = st.columns(5 if is_attn else 4) | |
| layer = cols[0].number_input("Layer", 0, num_layers - 1, 0, key="ct_act_layer") | |
| if is_attn: | |
| head = cols[1].number_input("Head", 0, num_heads - 1, 0, key="ct_act_head") | |
| else: | |
| head = None | |
| def_code = st.session_state.get("ct_act_code", 0) | |
| if filter_codes: | |
| layer_code_acts = act_count_ft_tkns[ | |
| f"layer{layer}{'_head'+str(head) if head is not None else ''}" | |
| ] | |
| def_code = webapp_utils.find_next_code(def_code, layer_code_acts, act_range) | |
| if "ct_act_code" in st.session_state: | |
| st.session_state["ct_act_code"] = def_code | |
| code = cols[-3].number_input( | |
| "Code", | |
| 0, | |
| num_codes - 1, | |
| def_code, | |
| key="ct_act_code", | |
| ) | |
| num_examples = cols[-2].number_input( | |
| "Max Results", | |
| -1, | |
| 1000, # setting to 1000 for efficiency purposes even though it can be more than 1000. | |
| 100, | |
| help="Number of examples to show in the results. Set to -1 to show all examples.", | |
| ) | |
| ctx_size = cols[-1].number_input( | |
| "Context Size", | |
| 1, | |
| 10, | |
| 5, | |
| help="Number of tokens to show before and after the code token.", | |
| ) | |
| acts, acts_count = webapp_utils.get_code_acts( | |
| model_name, | |
| tokens_str, | |
| code, | |
| layer, | |
| head, | |
| ctx_size, | |
| num_examples, | |
| is_fsm=is_fsm, | |
| ) | |
| st.write( | |
| f"Token Activations for Layer {layer}{f' Head {head}' if head is not None else ''} Code {code} | " | |
| f"Activates on {acts_count[0]} tokens on the acts dataset", | |
| ) | |
| if not deploy: | |
| webapp_utils.add_save_code_button( | |
| demo_file_path, | |
| acts_count[0], | |
| save_regex=False, | |
| button_text=True, | |
| button_key_suffix="_token_acts", | |
| ) | |
| st.markdown(webapp_utils.escape_markdown(acts), unsafe_allow_html=True) | |