| """Util functions for codebook features.""" | |
| import pathlib | |
| import re | |
| import typing | |
| from dataclasses import dataclass | |
| from functools import partial | |
| from typing import Optional | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from termcolor import colored | |
| from tqdm import tqdm | |


@dataclass
class CodeInfo:
    """Dataclass for codebook info."""

    code: int
    layer: int
    head: Optional[int]
    cb_at: Optional[str] = None

    # for patching interventions
    pos: Optional[int] = None
    code_pos: Optional[int] = -1

    # for description & regex-based interpretation
    description: Optional[str] = None
    regex: Optional[str] = None
    prec: Optional[float] = None
    recall: Optional[float] = None
    num_acts: Optional[int] = None

    def __post_init__(self):
        """Convert to appropriate types."""
        self.code = int(self.code)
        self.layer = int(self.layer)
        if self.head:
            self.head = int(self.head)
        if self.pos:
            self.pos = int(self.pos)
        if self.code_pos:
            self.code_pos = int(self.code_pos)
        if self.prec:
            self.prec = float(self.prec)
            assert 0 <= self.prec <= 1
        if self.recall:
            self.recall = float(self.recall)
            assert 0 <= self.recall <= 1
        if self.num_acts:
            self.num_acts = int(self.num_acts)

    def check_description_info(self):
        """Check that the description info (and regex metrics, if any) is present."""
        assert self.num_acts is not None and self.description is not None
        if self.regex is not None:
            assert self.prec is not None and self.recall is not None

    def __repr__(self):
        """Return the string representation."""
        repr = f"CodeInfo(code={self.code}, layer={self.layer}, head={self.head}, cb_at={self.cb_at}"
        if self.pos is not None or self.code_pos is not None:
            repr += f", pos={self.pos}, code_pos={self.code_pos}"
        if self.description is not None:
            repr += f", description={self.description}"
        if self.regex is not None:
            repr += f", regex={self.regex}, prec={self.prec}, recall={self.recall}"
        if self.num_acts is not None:
            repr += f", num_acts={self.num_acts}"
        repr += ")"
        return repr

    @classmethod
    def from_str(cls, code_txt, *args, **kwargs):
        """Extract code info fields from string."""
        code_txt = code_txt.strip().lower()
        code_txt = code_txt.split(", ")
        code_txt = dict(txt.split(": ") for txt in code_txt)
        return cls(*args, **code_txt, **kwargs)
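
# Illustrative usage sketch (not part of the original module): `from_str` expects a
# lowercase, comma-separated "field: value" string whose keys match the dataclass
# fields; the concrete field names and values below are placeholders.
#
#     info = CodeInfo.from_str("code: 123, layer: 4, head: 2", cb_at="attn")
#     # -> CodeInfo(code=123, layer=4, head=2, cb_at=attn, pos=None, code_pos=-1)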


@dataclass
class ModelInfoForWebapp:
    """Model info for webapp."""

    model_name: str
    pretrained_path: str
    dataset_name: str
    num_codes: int
    cb_at: str
    gcb: str
    n_layers: int
    n_heads: Optional[int] = None
    seed: int = 42
    max_samples: int = 2000

    def __post_init__(self):
        """Convert to correct types."""
        self.num_codes = int(self.num_codes)
        self.n_layers = int(self.n_layers)
        if self.n_heads == "None":
            self.n_heads = None
        elif self.n_heads is not None:
            self.n_heads = int(self.n_heads)
        self.seed = int(self.seed)
        self.max_samples = int(self.max_samples)

    @classmethod
    def load(cls, path):
        """Parse model info from path."""
        path = pathlib.Path(path)
        with open(path / "info.txt", "r") as f:
            lines = f.readlines()
            lines = dict(line.strip().split(": ") for line in lines)
        return cls(**lines)

    def save(self, path):
        """Save model info to path."""
        path = pathlib.Path(path)
        with open(path / "info.txt", "w") as f:
            for k, v in self.__dict__.items():
                f.write(f"{k}: {v}\n")


def logits_to_pred(logits, tokenizer, k=5):
    """Convert logits to top-k predictions."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    probs = sorted_logits.softmax(dim=-1)
    topk_preds = [tokenizer.convert_ids_to_tokens(e) for e in sorted_indices[:, -1, :k]]
    topk_preds = [
        tokenizer.convert_tokens_to_string([e]) for batch in topk_preds for e in batch
    ]
    return [(topk_preds[i], probs[:, -1, i].item()) for i in range(len(topk_preds))]
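
# Illustrative sketch (not in the original file): `logits` is assumed to have shape
# (batch, seq, vocab), e.g. from a causal language model; the top-k predictions are
# read off the final position. The probabilities shown are placeholders.
#
#     preds = logits_to_pred(model(input_ids).logits, tokenizer, k=5)
#     # -> [(" the", 0.21), (" a", 0.08), ...]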


def features_to_tokens(cb_key, cb_acts, num_codes, code=None):
    """Return the set of token ids each codebook feature activates on."""
    codebook_ids = cb_acts[cb_key]

    if code is None:
        features_tokens = [[] for _ in range(num_codes)]
        for i in tqdm(range(codebook_ids.shape[0])):
            for j in range(codebook_ids.shape[1]):
                for k in range(codebook_ids.shape[2]):
                    features_tokens[codebook_ids[i, j, k]].append((i, j))
    else:
        idx0, idx1, _ = np.where(codebook_ids == code)
        features_tokens = list(zip(idx0, idx1))

    return features_tokens


def color_str(s: str, html: bool, color: Optional[str] = None):
    """Color the string for html or terminal."""
    if html:
        color = "DeepSkyBlue" if color is None else color
        return f"<span style='color:{color}'>{s}</span>"
    else:
        color = "light_cyan" if color is None else color
        return colored(s, color)


def color_tokens_tokfsm(tokens, color_idx, html=False):
    """Separate states with a dash and color the tokens in color_idx."""
    ret_string = ""
    itr_over_color_idx = 0
    tokens_enumerate = enumerate(tokens)
    if tokens[0] == "<|endoftext|>":
        next(tokens_enumerate)
        if color_idx[0] == 0:
            itr_over_color_idx += 1
    for i, c in tokens_enumerate:
        if i % 2 == 1:
            ret_string += "-"
        if itr_over_color_idx < len(color_idx) and i == color_idx[itr_over_color_idx]:
            ret_string += color_str(c, html)
            itr_over_color_idx += 1
        else:
            ret_string += c
    return ret_string


def color_tokens(tokens, color_idx, n=3, html=False):
    """Color the tokens in color_idx, keeping `n` tokens of context on each side."""
    ret_string = ""
    last_colored_token_idx = -1
    for i in color_idx:
        c_str = tokens[i]
        if i <= last_colored_token_idx + 2 * n + 1:
            ret_string += "".join(tokens[last_colored_token_idx + 1 : i])
        else:
            ret_string += "".join(
                tokens[last_colored_token_idx + 1 : last_colored_token_idx + n + 1]
            )
            ret_string += " ... "
            ret_string += "".join(tokens[i - n : i])
        ret_string += color_str(c_str, html)
        last_colored_token_idx = i
    ret_string += "".join(
        tokens[
            last_colored_token_idx + 1 : min(last_colored_token_idx + n, len(tokens))
        ]
    )
    return ret_string


def prepare_example_print(
    example_id,
    example_tokens,
    tokens_to_color,
    html,
    color_fn=color_tokens,
):
    """Format example to print."""
    example_output = color_str(example_id, html, "green")
    example_output += (
        ": "
        + color_fn(example_tokens, tokens_to_color, html=html)
        + ("<br>" if html else "\n")
    )
    return example_output


def print_token_activations_of_code(
    code_act_by_pos,
    tokens,
    is_fsm=False,
    n=3,
    max_examples=100,
    randomize=False,
    html=False,
    return_example_list=False,
):
    """Print the context with the tokens that a code activates on.

    Args:
        code_act_by_pos: list of (example_id, token_pos_id) tuples specifying
            the token positions that a code activates on in a dataset.
        tokens: list of tokens of a dataset.
        is_fsm: whether the dataset is the TokFSM dataset.
        n: context to print around each side of a token that the code activates on.
        max_examples: maximum number of examples to print.
        randomize: whether to randomize the order of examples.
        html: whether to format the output for html instead of the terminal.
        return_example_list: whether to return the output split by example or as a
            single string.

    Returns:
        A single formatted string of all examples if `return_example_list` is False,
        otherwise a list of (example_string, num_tokens_colored) tuples, one per example.
    """
    if randomize:
        raise NotImplementedError("Randomize not yet implemented.")
    indices = range(len(code_act_by_pos))
    print_output = [] if return_example_list else ""
    curr_ex = code_act_by_pos[0][0]
    total_examples = 0
    tokens_to_color = []
    color_fn = color_tokens_tokfsm if is_fsm else partial(color_tokens, n=n)
    for idx in indices:
        if total_examples > max_examples:
            break
        i, j = code_act_by_pos[idx]
        if i != curr_ex and curr_ex >= 0:
            # got a new example, so print the previous one
            curr_ex_output = prepare_example_print(
                curr_ex,
                tokens[curr_ex],
                tokens_to_color,
                html,
                color_fn,
            )
            total_examples += 1
            if return_example_list:
                print_output.append((curr_ex_output, len(tokens_to_color)))
            else:
                print_output += curr_ex_output
            curr_ex = i
            tokens_to_color = []
        tokens_to_color.append(j)
    curr_ex_output = prepare_example_print(
        curr_ex,
        tokens[curr_ex],
        tokens_to_color,
        html,
        color_fn,
    )
    if return_example_list:
        print_output.append((curr_ex_output, len(tokens_to_color)))
    else:
        print_output += curr_ex_output
        print_output += color_str("*" * 50, html, "green")
    total_examples += 1
    return print_output
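
# Illustrative sketch (not in the original file): tie `features_to_tokens` together
# with this printer. `cb_acts` and `tokens` are assumed to come from a cached run over
# a dataset; the layer index, code id, and codebook size below are placeholders.
#
#     cb_key = get_cb_hook_key("attn", layer_idx=3)
#     acts = features_to_tokens(cb_key, cb_acts, num_codes=1024, code=123)
#     print(print_token_activations_of_code(acts, tokens, n=3, max_examples=10))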


def print_token_activations_of_codes(
    ft_tkns,
    tokens,
    is_fsm=False,
    n=3,
    start=0,
    stop=1000,
    indices=None,
    max_examples=100,
    freq_filter=None,
    randomize=False,
    html=False,
    return_example_list=False,
):
    """Print the tokens for the codebook features."""
    indices = list(range(start, stop)) if indices is None else indices
    num_tokens = len(tokens) * len(tokens[0])
    codes, token_act_freqs, token_acts = [], [], []
    for i in indices:
        tkns_of_code = ft_tkns[i]
        freq = (len(tkns_of_code), 100 * len(tkns_of_code) / num_tokens)
        if freq_filter is not None and freq[1] > freq_filter:
            continue
        codes.append(i)
        token_act_freqs.append(freq)
        if len(tkns_of_code) > 0:
            tkn_acts = print_token_activations_of_code(
                tkns_of_code,
                tokens,
                is_fsm,
                n=n,
                max_examples=max_examples,
                randomize=randomize,
                html=html,
                return_example_list=return_example_list,
            )
            token_acts.append(tkn_acts)
        else:
            token_acts.append("")
    return codes, token_act_freqs, token_acts


def patch_in_codes(run_cb_ids, hook, pos, code, code_pos=None):
    """Patch in the `code` at `run_cb_ids`."""
    pos = slice(None) if pos is None else pos
    code_pos = slice(None) if code_pos is None else code_pos

    if code_pos == "append":
        assert pos == slice(None)
        run_cb_ids = F.pad(run_cb_ids, (0, 1), mode="constant", value=code)
    if isinstance(pos, typing.Iterable):
        for p in pos:
            run_cb_ids[:, p, code_pos] = code
    else:
        run_cb_ids[:, pos, code_pos] = code
    return run_cb_ids


def get_cb_hook_key(cb_at: str, layer_idx: int, gcb_idx: Optional[int] = None):
    """Get the layer name used to store hooks/cache."""
    comp_name = "attn" if "attn" in cb_at else "mlp"
    if gcb_idx is None:
        return f"blocks.{layer_idx}.{comp_name}.codebook_layer.hook_codebook_ids"
    else:
        return f"blocks.{layer_idx}.{comp_name}.codebook_layer.codebook.{gcb_idx}.hook_codebook_ids"


def run_model_fn_with_codes(
    input,
    cb_model,
    fn_name,
    fn_kwargs=None,
    list_of_code_infos=(),
):
    """Run the `cb_model`'s `fn_name` method while activating the codes in `list_of_code_infos`.

    A common use case is running the `run_with_cache` method while activating the codes.
    For running the `generate` method, use `generate_with_codes` instead.
    """
    if fn_kwargs is None:
        fn_kwargs = {}
    hook_fns = [
        partial(patch_in_codes, pos=tupl.pos, code=tupl.code, code_pos=tupl.code_pos)
        for tupl in list_of_code_infos
    ]
    fwd_hooks = [
        (get_cb_hook_key(tupl.cb_at, tupl.layer, tupl.head), hook_fns[i])
        for i, tupl in enumerate(list_of_code_infos)
    ]
    cb_model.reset_hook_kwargs()
    with cb_model.hooks(fwd_hooks, [], True, False) as hooked_model:
        ret = getattr(hooked_model, fn_name)(input, **fn_kwargs)
    return ret
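
# Illustrative sketch (not in the original file): activate a single code while caching
# activations. `cb_model` is assumed to expose a TransformerLens-style `run_with_cache`
# method; the code info values below are placeholders.
#
#     code = CodeInfo(code=123, layer=4, head=2, cb_at="attn")
#     logits, cache = run_model_fn_with_codes(
#         input_ids, cb_model, "run_with_cache", list_of_code_infos=[code]
#     )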


def generate_with_codes(
    input,
    cb_model,
    list_of_code_infos=(),
    tokfsm=None,
    generate_kwargs=None,
):
    """Sample from the language model while activating the codes in `list_of_code_infos`."""
    gen = run_model_fn_with_codes(
        input,
        cb_model,
        "generate",
        generate_kwargs,
        list_of_code_infos,
    )
    return tokfsm.seq_to_traj(gen) if tokfsm is not None else gen


def JSD(logits1, logits2, pos=-1, reduction="batchmean"):
    """Compute the Jensen-Shannon divergence between two distributions."""
    if len(logits1.shape) == 3:
        logits1, logits2 = logits1[:, pos, :], logits2[:, pos, :]

    # JSD(P1, P2) = 0.5 * [KL(P1 || M) + KL(P2 || M)] with M = 0.5 * (P1 + P2).
    probs1 = F.softmax(logits1, dim=-1)
    probs2 = F.softmax(logits2, dim=-1)
    total_m = (0.5 * (probs1 + probs2)).log()

    loss = 0.0
    loss += F.kl_div(
        total_m,
        F.log_softmax(logits1, dim=-1),
        log_target=True,
        reduction=reduction,
    )
    loss += F.kl_div(
        total_m,
        F.log_softmax(logits2, dim=-1),
        log_target=True,
        reduction=reduction,
    )
    return 0.5 * loss


def cb_hook_key_to_info(layer_hook_key: str):
    """Get the layer info from the codebook layer hook key.

    Args:
        layer_hook_key: the hook key of the codebook layer,
            e.g. `blocks.3.attn.codebook_layer.hook_codebook_ids`.

    Returns:
        comp_name: the name of the component the codebook is applied at.
        layer_idx: the layer index.
        gcb_idx: the codebook index if the codebook layer is grouped, otherwise None.
    """
    layer_search = re.search(r"blocks\.(\d+)\.(\w+)\.", layer_hook_key)
    assert layer_search is not None
    layer_idx, comp_name = int(layer_search.group(1)), layer_search.group(2)
    gcb_idx_search = re.search(r"codebook\.(\d+)", layer_hook_key)
    if gcb_idx_search is not None:
        gcb_idx = int(gcb_idx_search.group(1))
    else:
        gcb_idx = None
    return comp_name, layer_idx, gcb_idx


def find_code_changes(cache1, cache2, pos=None):
    """Find the codebook codes that are different between the two caches."""
    for k in cache1.keys():
        if "codebook" in k:
            c1 = cache1[k][0, pos]
            c2 = cache2[k][0, pos]
            if not torch.all(c1 == c2):
                print(cb_hook_key_to_info(k), c1.tolist(), c2.tolist())


def common_codes_in_cache(cache_codes, threshold=0.0):
    """Get the common codes in the cache."""
    codes, counts = torch.unique(cache_codes, return_counts=True, sorted=True)
    counts = counts.float() * 100
    counts /= cache_codes.shape[1]
    counts, indices = torch.sort(counts, descending=True)
    codes = codes[indices]
    indices = counts > threshold
    codes, counts = codes[indices], counts[indices]
    return codes, counts


def parse_topic_codes_string(
    info_str: str,
    pos: Optional[int] = None,
    code_append: Optional[bool] = False,
    **code_info_kwargs,
):
    """Parse the topic codes string."""
    code_info_strs = info_str.strip().split("\n")
    code_info_strs = [e.strip() for e in code_info_strs if e]
    topic_codes = []
    layer, head = None, None
    if code_append is None:
        code_pos = None
    else:
        code_pos = "append" if code_append else -1
    for code_info_str in code_info_strs:
        topic_codes.append(
            CodeInfo.from_str(
                code_info_str,
                pos=pos,
                code_pos=code_pos,
                **code_info_kwargs,
            )
        )
        if code_append is None or code_append:
            continue
        if layer == topic_codes[-1].layer and head == topic_codes[-1].head:
            code_pos -= 1  # type: ignore
        else:
            code_pos = -1
        topic_codes[-1].code_pos = code_pos
        layer, head = topic_codes[-1].layer, topic_codes[-1].head
    return topic_codes
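
# Illustrative sketch (not in the original file): `info_str` is expected to contain one
# "field: value, field: value" line per code, in the format consumed by
# `CodeInfo.from_str`; the field names and values below are placeholders.
#
#     topic_codes = parse_topic_codes_string(
#         "code: 123, layer: 4, head: 2\n"
#         "code: 456, layer: 4, head: 2\n",
#         cb_at="attn",
#     )
#     # consecutive codes on the same (layer, head) get code_pos -1, -2, ...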


def find_similar_codes(cb_model, code_info, n=8):
    """Find the `n` most similar codes to the given code using cosine similarity.

    Useful for finding related codes for interpretability.
    """
    codebook = cb_model.get_codebook(code_info)
    device = codebook.weight.device
    code = codebook(torch.tensor(code_info.code).to(device))
    code = code.to(device)
    # Codebook vectors are assumed to be unit-normalized, so the dot products below are
    # cosine similarities; the second assert checks that the code's self-similarity is 1.
    logits = torch.matmul(code, codebook.weight.T)
    _, indices = torch.topk(logits, n)
    assert indices[0] == code_info.code
    assert torch.allclose(logits[indices[0]], torch.tensor(1.0))
    return indices[1:], logits[indices[1:]].tolist()
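
# Illustrative sketch (not in the original file): `cb_model.get_codebook` is assumed to
# return the embedding layer for the codebook identified by `code_info`; the values
# below are placeholders.
#
#     neighbors, sims = find_similar_codes(
#         cb_model, CodeInfo(code=123, layer=4, head=2, cb_at="attn"), n=8
#     )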