from typing import List, Tuple, Optional, Any, Dict import torch from transformers.cache_utils import Cache class FgateDynamicCache(Cache): """ A cache that grows dynamically as more tokens are generated. This is the default for generative models. It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is `[batch_size, num_heads, seq_len, head_dim]`. Example: ```python >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") >>> # Prepare a cache class and pass it to model's forward >>> past_key_values = DynamicCache() >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> outputs.past_key_values # access cache filled with key/values from generation DynamicCache() ``` """ def __init__(self) -> None: super().__init__() self.key_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = [] self.log_fgate_cache: List[torch.Tensor] = [] self.key_shift_cache: List[torch.Tensor] = [] self.value_shift_cache: List[torch.Tensor] = [] self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen def update_shift_cache( self, key_shift_state: torch.Tensor, value_shift_state: torch.Tensor, layer_idx, ): assert layer_idx == len(self.key_shift_cache) == len(self.value_shift_cache) self.key_shift_cache.append(key_shift_state) self.value_shift_cache.append(value_shift_state) def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: """ Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the sequence length. """ if layer_idx < len(self): return (self.key_cache[layer_idx], self.value_cache[layer_idx], self.log_fgate_cache[layer_idx]) else: raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") def __iter__(self): """ Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over keys and values """ for layer_idx in range(len(self)): yield (self.key_cache[layer_idx], self.value_cache[layer_idx], self.log_fgate_cache[layer_idx]) def __len__(self): """ Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds to the number of layers in the model. """ return len(self.key_cache) def update( self, key_states: torch.Tensor, value_states: torch.Tensor, log_fgate_states: torch.Tensor, layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. Parameters: key_states (`torch.Tensor`): The new key states to cache. value_states (`torch.Tensor`): The new value states to cache. layer_idx (`int`): The index of the layer to cache the states for. cache_kwargs (`Dict[str, Any]`, `optional`): Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. Return: A tuple containing the updated key and value states. """ assert log_fgate_states.ndim == 3, f"log_fgate must be (B, H, T), but get {log_fgate_states.size()}" # Update the number of seen tokens if layer_idx == 0: self._seen_tokens += key_states.shape[-2] # Update the cache if len(self.key_cache) <= layer_idx: self.key_cache.append(key_states) self.value_cache.append(value_states) self.log_fgate_cache.append(log_fgate_states) else: self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) self.log_fgate_cache[layer_idx] = torch.cat([self.log_fgate_cache[layer_idx], log_fgate_states], dim=-1) return self.key_cache[layer_idx], self.value_cache[layer_idx], self.log_fgate_cache[layer_idx] def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" # TODO: deprecate this function in favor of `cache_position` if len(self.key_cache) <= layer_idx: return 0 return self.key_cache[layer_idx].shape[-2] def get_max_length(self) -> Optional[int]: """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" return None def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for backward compatibility.""" legacy_cache = () for layer_idx in range(len(self)): legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx], self.log_fgate_cache[layer_idx]),) return legacy_cache @classmethod def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_layers: Optional[int] = None) -> "DynamicCache": """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for backward compatibility.""" raise NotImplementedError assert num_layers is not None cache = cls(num_layers) if past_key_values is not None: for layer_idx in range(len(past_key_values)): key_states, value_states, log_fgate_states = past_key_values[layer_idx] cache.update(key_states, value_states, log_fgate_states, layer_idx) return cache def crop(self, max_length: int): """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search.""" # In case it is negative if max_length < 0: max_length = self.get_seq_length() - abs(max_length) if self.get_seq_length() <= max_length: return self._seen_tokens = max_length for idx in range(len(self.key_cache)): self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] self.log_fgate_cache[idx] = self.log_fgate_cache[idx][..., :max_length] def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]: """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by `_split_model_inputs()` in `generation.utils`""" out = [] for i in range(0, full_batch_size, split_size): current_split = DynamicCache() current_split._seen_tokens = self._seen_tokens current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] current_split.log_fgate_cache = [tensor[i : i + split_size] for tensor in self.log_fgate_cache] out.append(current_split) return out @classmethod def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache": """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in `generation.utils`""" cache = cls() for idx in range(len(splits[0])): layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0) layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0) layer_log_fgates = torch.cat([current.log_fgate_cache[idx] for current in splits], dim=0) cache.update(layer_keys, layer_values, layer_log_fgates, idx) return cache def batch_repeat_interleave(self, repeats: int): """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" for layer_idx in range(len(self)): self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) self.log_fgate_cache[layer_idx] = self.log_fgate_cache[layer_idx].repeat_interleave(repeats, dim=0) def batch_select_indices(self, indices: torch.Tensor): """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" for layer_idx in range(len(self)): self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] self.log_fgate_cache[layer_idx] = self.log_fgate_cache[layer_idx][indices, ...]