from dataclasses import dataclass
import math

import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken
from transformers import GPT2Tokenizer, GPT2LMHeadModel, PretrainedConfig


class GPTConfig(PretrainedConfig):
    visual_size: int = 1024
    vocab_size: int = 50257
    block_size: int = 1024
    tags_embd: int = 400
    n_embd: int = 768
    n_layer: int = 6
    n_head: int = 12

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = self.n_embd
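

# Illustrative sketch (not part of the original file): the class-level defaults above act as
# fallbacks, and any keyword passed at construction time is stored on the instance by
# PretrainedConfig. Values used here are arbitrary examples.
def _demo_config():
    cfg = GPTConfig()
    print(cfg.n_embd, cfg.hidden_size, cfg.visual_size)  # 768 768 1024
    cfg_small = GPTConfig(n_layer=4)                     # hypothetical override
    print(cfg_small.n_layer)                             # 4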


class CasualSelfAttention(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, config.n_embd * 3)
        self.visual_attn = nn.Linear(config.visual_size, config.n_embd * 2)
        self.tags_attn = nn.Linear(config.tags_embd, config.n_embd * 2)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embed = config.n_embd
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Note: this buffer is not used in forward(), which builds its mask on the fly.
        self.register_buffer(
            'bias', torch.tril(torch.ones(199, 353))
            .view(1, 1, 199, 353)
        )

    def forward(self, x: torch.Tensor, visual_features: torch.Tensor = None, tags_embedding: torch.Tensor = None) -> torch.Tensor:
        B, T, C = x.size()
        # Move the optional modalities to the same device as x (guarded so None stays None).
        if visual_features is not None:
            visual_features = visual_features.to(x.device)
        if tags_embedding is not None:
            tags_embedding = tags_embedding.to(x.device)
        qkv = self.c_attn(x)  # the error happens here
        q, k, v = qkv.split(self.n_embed, dim=2)
        q = q.view(B, T, self.n_head, self.n_embed // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.n_embed // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.n_embed // self.n_head).transpose(1, 2)
        # Append visual keys/values if provided
        if visual_features is not None:
            visual_kv = self.visual_attn(visual_features)
            visual_k, visual_v = visual_kv.split(self.n_embed, dim=2)
            visual_k = visual_k.view(B, visual_features.size(1), self.n_head, self.n_embed // self.n_head).transpose(1, 2)
            visual_v = visual_v.view(B, visual_features.size(1), self.n_head, self.n_embed // self.n_head).transpose(1, 2)
            k = torch.cat([k, visual_k], dim=-2)
            v = torch.cat([v, visual_v], dim=-2)
        # Append tag keys/values if provided
        if tags_embedding is not None:
            tags_kv = self.tags_attn(tags_embedding)
            tags_k, tags_v = tags_kv.split(self.n_embed, dim=2)
            tags_k = tags_k.view(B, tags_embedding.size(1), self.n_head, self.n_embed // self.n_head).transpose(1, 2)
            tags_v = tags_v.view(B, tags_embedding.size(1), self.n_head, self.n_embed // self.n_head).transpose(1, 2)
            k = torch.cat([k, tags_k], dim=-2)
            v = torch.cat([v, tags_v], dim=-2)
        # Causal self-attention computation
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        device = att.device
        query_seq_len, key_seq_len = T, k.size(-2)
        # Text can attend to: previous text + all visual/tag tokens
        text_mask = torch.tril(torch.ones(T, T, device=device))        # text-to-text causal
        non_text_mask = torch.ones(T, key_seq_len - T, device=device)  # text-to-other full
        combined_mask = torch.cat([text_mask, non_text_mask], dim=1)
        # Reshape for broadcasting
        combined_mask = combined_mask.view(1, 1, T, key_seq_len)
        att = att.masked_fill(combined_mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        # Debug statistic (unused): mean attention from text queries onto the appended visual/tag keys
        if key_seq_len > T:
            visual_att = att[..., :T, T:].mean().item()
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * (self.n_embed // self.n_head))
        y = self.c_proj(y)
        return y
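

# Illustrative sketch (not part of the original file): a standalone shape check for the
# attention module with random tensors. The sizes (batch 2, 16 text tokens, 49 visual
# tokens, 10 tag tokens) are assumptions for the example only; the key point is that x
# must have last dimension n_embd (768) for c_attn, which is where "the error happens
# here" usually indicates a shape mismatch.
def _demo_attention_shapes():
    cfg = GPTConfig()
    attn = CasualSelfAttention(cfg)
    x = torch.randn(2, 16, cfg.n_embd)            # text stream: (B, T, n_embd)
    visual = torch.randn(2, 49, cfg.visual_size)  # (B, N_visual, visual_size)
    tags = torch.randn(2, 10, cfg.tags_embd)      # (B, N_tags, tags_embd)
    y = attn(x, visual, tags)
    print(y.shape)  # torch.Size([2, 16, 768]) -- extra keys only widen attention, not the output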


class MLP(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, config.n_embd * 4)  # c_fc means fully connected layer and c is for context
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(config.n_embd * 4, config.n_embd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x


class Block(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CasualSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, visual_features: torch.Tensor, tags_embedding: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln_1(x), visual_features, tags_embedding)
        x = x + self.mlp(self.ln_2(x))
        return x
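

# Illustrative sketch (not part of the original file): passing a random sequence through a
# single Block with and without the extra modalities; the residual stream keeps its
# (B, T, n_embd) shape either way. Sizes are assumptions for the example.
def _demo_block():
    cfg = GPTConfig()
    block = Block(cfg)
    x = torch.randn(2, 16, cfg.n_embd)
    y_text_only = block(x, visual_features=None, tags_embedding=None)
    y_multimodal = block(x, torch.randn(2, 49, cfg.visual_size), torch.randn(2, 10, cfg.tags_embd))
    print(y_text_only.shape, y_multimodal.shape)  # both torch.Size([2, 16, 768])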


class DistilGPT2(GPT2LMHeadModel):
    def __init__(self, config: GPTConfig):
        super().__init__(config)
        self.config = config
        self.transformer = nn.ModuleDict(
            {
                'wte': nn.Embedding(config.vocab_size, config.n_embd),
                'wpe': nn.Embedding(config.block_size, config.n_embd),
                'h': nn.ModuleList([Block(config) for _ in range(config.n_layer)]),  # transformer blocks
                'ln_f': nn.LayerNorm(config.n_embd)  # final layer normalization
            }
        )
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)  # linear layer for projection from embedding to vocab size

    def forward(self, idx: torch.Tensor, visual_features: torch.Tensor = None, tags_embedding: torch.Tensor = None, return_dict: bool = False) -> torch.Tensor:
        idx = idx.to(self.device)
        B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is {self.config.block_size}"
        # forward the token and positional embeddings
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        pos_emb = self.transformer['wpe'](pos)
        tok_emb = self.transformer['wte'](idx)
        x = tok_emb + pos_emb
        # forward the transformer
        for block in self.transformer['h']:
            x = block(x, visual_features=visual_features, tags_embedding=tags_embedding)
        # forward the head
        x = self.transformer['ln_f'](x)
        logits = self.lm_head(x)
        if return_dict:
            return {'logits': logits}
        else:
            return logits
    @classmethod
    def from_pretrained(cls, model_type: str):
        """Loads pre-trained GPT-2 model weights from Hugging Face and handles custom layers."""
        from transformers import GPT2LMHeadModel, GPT2Config
        print(f"Loading weights from pre-trained GPT: {model_type}")
        # Ensure the model type is supported
        assert model_type in {'distilgpt2', 'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        # Define configurations based on the model type
        config_args = {
            'distilgpt2': dict(n_layer=6, n_head=12, n_embd=768),
            'gpt2': dict(n_layer=12, n_head=12, n_embd=768),
            'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024),
            'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280),
            'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600),
        }[model_type]
        config_args['vocab_size'] = 50257
        config_args['block_size'] = 1024
        # Start from the Hugging Face config for the requested model and add the custom attributes
        config = GPT2Config.from_pretrained(model_type)
        config.visual_size = 1024
        config.block_size = config_args['block_size']
        config.tags_embd = 400
        config.n_embd = config_args['n_embd']
        config.n_layer = config_args['n_layer']
        config.n_head = config_args['n_head']
        model = cls(config)
        # Load state dictionary from Hugging Face model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()
        # State dictionary of the custom model
        sd = model.state_dict()
        # Filter out custom keys that are not in the pre-trained model
        custom_keys = {k for k in sd if 'visual_attn' in k or 'tags_attn' in k}
        sd_keys_filtered = [k for k in sd if k not in custom_keys]
        # Hugging Face GPT-2 stores these layers as Conv1D modules, so their weights are
        # transposed relative to nn.Linear and must be copied with .t()
        transposed = ('attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight')
        # Load matching keys
        for k in sd_keys_filtered:
            if k not in sd_hf:
                continue
            with torch.no_grad():
                if k.endswith(transposed) and sd_hf[k].shape[::-1] == sd[k].shape:
                    sd[k].copy_(sd_hf[k].t())
                elif sd_hf[k].shape == sd[k].shape:
                    sd[k].copy_(sd_hf[k])
        # Initialize custom layers separately
        for k in custom_keys:
            with torch.no_grad():
                print(f"Initializing custom layer: {k}")
                sd[k].normal_(0.0, 0.02)  # Adjust initialization method as needed
        # Update the model's state dictionary
        model.load_state_dict(sd, strict=False)
        return model

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        # Prepare inputs for autoregressive generation
        inputs = {"idx": input_ids}
        if past:
            inputs["past_key_values"] = past  # Include past key values for caching
        # Include additional features like visual and tags if provided
        if "visual_features" in kwargs:
            inputs["visual_features"] = kwargs["visual_features"]
        if "tags_embedding" in kwargs:
            inputs["tags_embedding"] = kwargs["tags_embedding"]
        return inputs

    def generate(
        self,
        input_ids: torch.Tensor = None,
        max_length: int = None,
        min_length: int = None,
        do_sample: bool = None,
        early_stopping: bool = None,
        num_beams: int = None,
        temperature: float = None,
        top_k: int = None,
        top_p: float = None,
        repetition_penalty: float = None,
        bos_token_id: int = None,
        pad_token_id: int = None,
        eos_token_ids: int = None,
        length_penalty: float = None,
        no_repeat_ngram_size: int = None,
        num_return_sequences: int = None,
        attention_mask: torch.Tensor = None,
        visual_features: torch.Tensor = None,
        tags_embedding: torch.Tensor = None,
    ):
        """
        Generate sequences using autoregressive decoding.

        Args:
            input_ids (torch.Tensor): Input tensor of token IDs.
            max_length (int): Maximum length of the generated sequence.
            min_length (int): Minimum length of the generated sequence.
            do_sample (bool): Whether to use sampling; if False, uses greedy decoding.
            early_stopping (bool): Whether to stop when all beams have finished.
            num_beams (int): Number of beams for beam search.
            temperature (float): Sampling temperature.
            top_k (int): Top-k sampling.
            top_p (float): Top-p (nucleus) sampling.
            repetition_penalty (float): Penalty for repeated tokens.
            bos_token_id (int): Beginning-of-sequence token ID.
            pad_token_id (int): Padding token ID.
            eos_token_ids (int): End-of-sequence token ID.
            length_penalty (float): Beam search length penalty.
            no_repeat_ngram_size (int): Size of n-grams not to repeat.
            num_return_sequences (int): Number of sequences to return.
            attention_mask (torch.Tensor): Attention mask for padding tokens.
            visual_features (torch.Tensor): Visual features for the transformer.
            tags_embedding (torch.Tensor): Tag embeddings for the transformer.

        Returns:
            torch.Tensor: Generated sequences of token IDs.
        """
        # Default values for unspecified parameters
        max_length = max_length or self.config.block_size
        min_length = min_length or 0
        do_sample = do_sample or False
        early_stopping = early_stopping or False
        num_beams = num_beams or 1
        temperature = temperature or 1.0
        top_k = top_k or 0
        top_p = top_p or 1.0
        repetition_penalty = repetition_penalty or 1.0
        bos_token_id = bos_token_id or self.config.bos_token_id
        pad_token_id = pad_token_id or self.config.pad_token_id
        eos_token_ids = eos_token_ids or self.config.eos_token_id
        length_penalty = length_penalty or 1.0
        no_repeat_ngram_size = no_repeat_ngram_size or 0
        num_return_sequences = num_return_sequences or 1
        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = 1
        if input_ids is None:
            assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
                "You should either supply a context to complete as `input_ids` input "
                "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
            )
            input_ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long)
        else:
            assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
        # Avoid duplicate outputs when greedy decoding
        if not do_sample:
            if num_beams == 1:
                assert num_return_sequences == 1, (
                    "Greedy decoding will always produce the same output for num_beams == 1 "
                    "and num_return_sequences > 1. Please set num_return_sequences = 1."
                )
            else:
                assert num_beams >= num_return_sequences, (
                    "Greedy beam search decoding cannot return more sequences than it has beams. "
                    "Please set num_beams >= num_return_sequences."
                )
        # Create attention mask if necessary
        if attention_mask is None:
            if pad_token_id is not None and pad_token_id in input_ids:
                attention_mask = (input_ids != pad_token_id).long()
            else:
                attention_mask = torch.ones_like(input_ids)
        # Set pad_token_id if not provided and eos_token_ids is available
        if pad_token_id is None and eos_token_ids is not None:
            pad_token_id = eos_token_ids
            print(f"Setting `pad_token_id` to {pad_token_id} (first `eos_token_ids`) to generate sequence.")
        # Current sequence length and vocabulary size
        cur_len = input_ids.size(1)
        vocab_size = self.config.vocab_size
        # Adjust effective batch size and multiplier for sampling
        if do_sample:
            effective_batch_size = batch_size * num_return_sequences
            effective_batch_mult = num_return_sequences
        else:
            effective_batch_size = batch_size
            effective_batch_mult = 1
        # Expand input_ids and attention_mask for beam search or multiple return sequences
        if num_return_sequences > 1 or num_beams > 1:
            input_ids_len = input_ids.size(-1)
            # Expand dimensions and repeat for each beam and return sequence
            input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
            attention_mask = attention_mask.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
            # Reshape to combine batch and beam dimensions
            input_ids = input_ids.reshape(effective_batch_size * num_beams, input_ids_len)
            attention_mask = attention_mask.reshape(effective_batch_size * num_beams, input_ids_len)
        if num_beams > 1:
            output = self._generate_beam_search(
                input_ids=input_ids,
                attention_mask=attention_mask,
                visual_features=visual_features,
                tags_embedding=tags_embedding,
                cur_len=input_ids.size(1),
                max_length=max_length,
                min_length=min_length,
                do_sample=do_sample,
                early_stopping=early_stopping,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                pad_token_id=pad_token_id,
                eos_token_ids=eos_token_ids,
                length_penalty=length_penalty,
                num_return_sequences=num_return_sequences,
                num_beams=num_beams,
            )
        else:
            output = self._generate_no_beam_search(
                input_ids=input_ids,
                attention_mask=attention_mask,
                visual_features=visual_features,
                tags_embedding=tags_embedding,
                cur_len=input_ids.size(1),
                max_length=max_length,
                min_length=min_length,
                do_sample=do_sample,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                pad_token_id=pad_token_id,
                eos_token_ids=eos_token_ids,
                batch_size=batch_size,
                vocab_size=vocab_size,
            )
        return output

    def _generate_no_beam_search(
        self,
        input_ids,
        visual_features,
        tags_embedding,
        cur_len,
        max_length,
        min_length,
        do_sample,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
        no_repeat_ngram_size,
        pad_token_id,
        eos_token_ids,
        batch_size,
        vocab_size,
        attention_mask,
    ):
        """
        Generate sequences for each example without beam search (num_beams == 1).
        All returned sequences are generated independently.
        """
        # Track unfinished sentences and their lengths
        unfinished_sents = torch.ones_like(input_ids[:, 0])
        sent_lengths = torch.ones_like(input_ids[:, 0]) * max_length
        past = None
        while cur_len < max_length:
            if past is None:
                inputs = input_ids
            else:
                inputs = input_ids[:, -1].unsqueeze(1)
            model_inputs = self.prepare_inputs_for_generation(
                inputs, past=past, visual_features=visual_features, tags_embedding=tags_embedding
            )
            outputs = self(**model_inputs)
            next_token_logits = outputs[:, -1, :]  # logits for the last position, shape: [batch_size, vocab_size]
            # if self._do_output_past(outputs):  # caching is disabled; this helper is not implemented
            #     past = outputs[1]
            # Apply repetition penalty
            if repetition_penalty != 1.0:
                next_token_logits_penalties = self._create_next_token_logits_penalties(input_ids, next_token_logits, repetition_penalty)
                next_token_logits = next_token_logits * next_token_logits_penalties  # element-wise, not a matrix product
            # Prevent repetition of n-grams
            if no_repeat_ngram_size > 0:  # not fully tested (generated with ChatGPT)
                banned_tokens = self.calc_banned_ngram_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
                banned_tokens_indices_mask = []
                for banned_tokens_slice in banned_tokens:
                    banned_tokens_indices_mask.append(
                        [True if token in banned_tokens_slice else False for token in range(vocab_size)]
                    )
                banned_tokens_indices_mask = torch.tensor(banned_tokens_indices_mask, dtype=torch.bool)
                next_token_logits[banned_tokens_indices_mask] = -float('inf')
            # Min length constraint for EOS
            if eos_token_ids is not None and cur_len < min_length:
                # create eos_token_id boolean mask
                is_token_logit_eos_token = torch.arange(vocab_size, device=next_token_logits.device) == eos_token_ids
                eos_token_indices_mask = is_token_logit_eos_token.unsqueeze(0).expand(batch_size, -1)
                next_token_logits = next_token_logits.masked_fill(eos_token_indices_mask, -float("inf"))
            # Sampling or greedy decoding
            if do_sample:
                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature
                next_token_logits = self.top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
                next_token = torch.multinomial(torch.softmax(next_token_logits, dim=-1), num_samples=1).squeeze(1)
            else:
                next_token = torch.argmax(next_token_logits, dim=-1)
            if eos_token_ids is not None:
                unfinished_sents = unfinished_sents.to(self.device)
                tokens_to_add = next_token * unfinished_sents + pad_token_id * (1 - unfinished_sents)
            else:
                tokens_to_add = next_token
            input_ids = input_ids.to(self.device)
            input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=1)
            if eos_token_ids is not None:
                eos_in_sents = tokens_to_add == eos_token_ids
                # If a sentence is unfinished and the token to add is eos, sent_lengths is filled with the current length
                is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents * eos_in_sents.int()
                sent_lengths = sent_lengths.to(self.device)
                sent_lengths = (
                    sent_lengths * (1 - is_sents_unfinished_and_token_to_add_is_eos)
                    + cur_len * is_sents_unfinished_and_token_to_add_is_eos
                )
                # Unfinished sentences are set to zero if eos is in the sentence
                unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos
            # Stop if there is an eos token in every sentence, or if we exceed the maximum length
            if torch.max(unfinished_sents) == 0:  # => this line is what makes generation stop early (e.g. at 57 tokens)
                break
            cur_len += 1
        # Pad sequences if necessary
        min_sent_length = sent_lengths.min().item()
        max_sent_length = sent_lengths.max().item()
        if min_sent_length != max_sent_length:
            assert pad_token_id is not None, "`pad_token_id` has to be defined if batches have different lengths"
            padding = torch.ones((batch_size, max_sent_length), dtype=input_ids.dtype, device=input_ids.device) * pad_token_id
            broad_casted_sent_lengths = sent_lengths.unsqueeze(-1).expand(batch_size, max_sent_length)
            broad_casted_range = torch.arange(max_sent_length, device=input_ids.device).unsqueeze(0).expand(batch_size, max_sent_length)
            # Use torch.where to apply padding where necessary
            decoded = torch.where(broad_casted_range < broad_casted_sent_lengths, input_ids, padding)
        else:
            decoded = input_ids
        return decoded

    def _create_next_token_logits_penalties(self, input_ids, logits, repetition_penalty):
        """
        Create logit penalties for already seen input_ids based on repetition penalty.

        Args:
            input_ids (torch.Tensor): Tensor of shape (batch_size, seq_len) containing input token IDs.
            logits (torch.Tensor): Tensor of shape (batch_size, vocab_size) containing next-token logits.
            repetition_penalty (float): The penalty to apply for repeated tokens.

        Returns:
            torch.Tensor: Tensor of shape (batch_size, vocab_size) with applied penalties.
        """
        token_penalties = torch.ones_like(logits)
        prev_input_ids = [torch.unique(input_id) for input_id in input_ids]
        for i, prev_input_id in enumerate(prev_input_ids):
            logits_penalized = logits[i][prev_input_id]  # logits of the tokens already seen in this example
            logit_penalties = torch.zeros_like(logits_penalized)
            logit_penalties[logits_penalized < 0] = repetition_penalty
            logit_penalties[logits_penalized > 0] = 1 / repetition_penalty
            token_penalties[i].scatter_(0, prev_input_id, logit_penalties)
        return token_penalties

    def top_k_top_p_filtering(self, logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
        """
        Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

        Args:
            logits: Logits distribution of shape (batch size, vocabulary size).
            top_k (int): Keep only the top k tokens with the highest probability.
            top_p (float): Keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            filter_value (float): Value to assign to filtered logits.
            min_tokens_to_keep (int): Ensure at least this many tokens are kept.

        Returns:
            torch.Tensor: Filtered logits.
        """
        logits_shape = logits.size()
        # Top-k filtering
        if top_k > 0:
            top_k = min(max(top_k, min_tokens_to_keep), logits_shape[-1])  # Safety check
            # Remove all tokens with a probability less than the last token of the top-k
            top_k_values, _ = torch.topk(logits, top_k, dim=-1)
            min_top_k_values = top_k_values[:, -1].unsqueeze(-1)  # Minimum logit in the top-k
            logits = torch.where(logits < min_top_k_values, torch.full_like(logits, filter_value), logits)
        # Top-p (nucleus) filtering
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            if min_tokens_to_keep > 1:
                # Ensure we keep at least min_tokens_to_keep tokens
                sorted_indices_to_remove[:, :min_tokens_to_keep] = 0
            # Shift the indices to the right to keep also the first token above the threshold
            sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dims=-1)
            sorted_indices_to_remove[:, 0] = 0
            # Scatter sorted indices back to original indexing
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            logits = torch.where(indices_to_remove, torch.full_like(logits, filter_value), logits)
        return logits

    def calc_banned_ngram_tokens(self, prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
        """
        Calculate banned n-gram tokens for no-repeat n-gram constraints.

        Args:
            prev_input_ids (torch.Tensor): Tensor of shape (num_hypos, seq_len) containing token sequences.
            num_hypos (int): Number of hypotheses in the batch.
            no_repeat_ngram_size (int): Size of the n-grams to avoid repeating.
            cur_len (int): Current length of the sequence being generated.

        Returns:
            List[List[int]]: List of banned tokens for each hypothesis.
        """
        if cur_len + 1 < no_repeat_ngram_size:
            # Return no banned tokens if not enough tokens have been generated
            return [[] for _ in range(num_hypos)]
        # Dictionary to store generated n-grams for each hypothesis
        generated_ngrams = [{} for _ in range(num_hypos)]
        # Populate the n-grams
        for idx in range(num_hypos):
            gen_tokens = prev_input_ids[idx].tolist()
            generated_ngram = generated_ngrams[idx]
            for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
                prev_ngram_tuple = tuple(ngram[:-1])
                generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]

        def _get_generated_ngrams(hypo_idx):
            # Get n-grams that have already appeared
            start_idx = cur_len + 1 - no_repeat_ngram_size
            ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
            return generated_ngrams[hypo_idx].get(ngram_idx, [])

        # Calculate banned tokens for each hypothesis
        banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
        return banned_tokens
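

# Illustrative end-to-end sketch (not part of the original file): loading the adapted
# DistilGPT-2 weights, encoding a prompt with the imported GPT2Tokenizer, and sampling a
# continuation conditioned on random visual/tag tensors. The prompt text and the feature
# sizes are assumptions for the example; guarded so importing this module has no side effects.
if __name__ == "__main__":
    tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
    model = DistilGPT2.from_pretrained('distilgpt2')
    model.eval()
    prompt = tokenizer.encode("A photo of", return_tensors='pt')  # (1, T) token IDs
    visual = torch.randn(1, 49, 1024)   # assumed (B, N_visual, visual_size)
    tags = torch.randn(1, 10, 400)      # assumed (B, N_tags, tags_embd)
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=prompt,
            max_length=30,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            visual_features=visual,
            tags_embedding=tags,
        )
    print(tokenizer.decode(output_ids[0].tolist()))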