# CDGPT2-Deployment / distil_gpt2.py
import math
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import GPT2LMHeadModel, PretrainedConfig
@dataclass
class GPTConfig(PretrainedConfig):
    """Configuration for the multimodal DistilGPT-2 decoder conditioned on visual features and tag embeddings."""
    visual_size: int = 1024   # dimensionality of the visual feature vectors
    vocab_size: int = 50257   # GPT-2 BPE vocabulary size
    block_size: int = 1024    # maximum sequence length (positional embeddings)
    tags_embd: int = 400      # dimensionality of the tag embeddings
    n_embd: int = 768
    n_layer: int = 6
    n_head: int = 12

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # PretrainedConfig consumers expect `hidden_size`; alias it to n_embd.
        self.hidden_size = self.n_embd
class CasualSelfAttention(nn.Module):
    """Causal self-attention over text tokens, with extra keys/values projected from visual features and tag embeddings."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # query/key/value projection for the text tokens, fused into one linear layer
        self.c_attn = nn.Linear(config.n_embd, config.n_embd * 3)
        # key/value projections for the visual features and the tag embeddings
        self.visual_attn = nn.Linear(config.visual_size, config.n_embd * 2)
        self.tags_attn = nn.Linear(config.tags_embd, config.n_embd * 2)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embed = config.n_embd
        # This buffer is not used by forward() (the mask is built there dynamically);
        # it is kept so that existing checkpoints load without unexpected-key errors.
        self.register_buffer(
            'bias', torch.tril(torch.ones(199, 353)).view(1, 1, 199, 353)
        )
    def forward(self, x: torch.Tensor, visual_features: torch.Tensor = None, tags_embedding: torch.Tensor = None) -> torch.Tensor:
        B, T, C = x.size()
        if visual_features is not None:
            visual_features = visual_features.to(x.device)
        if tags_embedding is not None:
            tags_embedding = tags_embedding.to(x.device)
        # project the text tokens to queries, keys and values
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embed, dim=2)
        q = q.view(B, T, self.n_head, self.n_embed // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.n_embed // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.n_embed // self.n_head).transpose(1, 2)
# Handle visual input if provided
if visual_features is not None:
visual_kv = self.visual_attn(visual_features)
visual_k, visual_v = visual_kv.split(self.n_embed, dim=2)
visual_k = visual_k.view(B, visual_features.size(1), self.n_head, self.n_embed // self.n_head).transpose(1, 2)
visual_v = visual_v.view(B, visual_features.size(1), self.n_head, self.n_embed // self.n_head).transpose(1, 2)
k = torch.cat([k, visual_k], dim=-2)
v = torch.cat([v, visual_v], dim=-2)
if tags_embedding is not None:
tags_kv = self.tags_attn(tags_embedding)
tags_k, tags_v = tags_kv.split(self.n_embed, dim=2)
tags_k = tags_k.view(B, tags_embedding.size(1), self.n_head, self.n_embed // self.n_head).transpose(1, 2)
tags_v = tags_v.view(B, tags_embedding.size(1), self.n_head, self.n_embed // self.n_head).transpose(1, 2)
k = torch.cat([k, tags_k], dim=-2)
v = torch.cat([v, tags_v], dim=-2)
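        # When visual features and tag embeddings are provided, k and v now have shape
        # (B, n_head, T + n_visual + n_tags, head_dim): text keys/values come first,
        # followed by the projected visual tokens and then the projected tag tokens.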
        # Scaled dot-product attention over the text + visual + tag keys
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        device = att.device
        key_seq_len = k.size(-2)
        # Text tokens attend causally to earlier text tokens and fully to all visual/tag tokens
        text_mask = torch.tril(torch.ones(T, T, device=device))        # text-to-text: causal
        non_text_mask = torch.ones(T, key_seq_len - T, device=device)  # text-to-visual/tags: full
        combined_mask = torch.cat([text_mask, non_text_mask], dim=1)
        combined_mask = combined_mask.view(1, 1, T, key_seq_len)       # reshape for broadcasting
        att = att.masked_fill(combined_mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v
        # re-assemble all head outputs side by side
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        return y
class MLP(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        # layer names (c_fc, c_proj) mirror GPT-2's naming so checkpoint keys line up
        self.c_fc = nn.Linear(config.n_embd, config.n_embd * 4)    # expansion to 4x hidden size
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(config.n_embd * 4, config.n_embd)  # projection back to hidden size
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.c_fc(x)
x = self.gelu(x)
x = self.c_proj(x)
return x
class Block(nn.Module):
def __init__(self, config: GPTConfig):
super(Block, self).__init__()
self.ln_1 = nn.LayerNorm(config.n_embd)
self.attn = CasualSelfAttention(config)
self.ln_2 = nn.LayerNorm(config.n_embd)
self.mlp = MLP(config)
    def forward(self, x: torch.Tensor, visual_features: torch.Tensor, tags_embedding: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln_1(x), visual_features, tags_embedding)
x = x + self.mlp(self.ln_2(x))
return x
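# With the default config, a Block expects x of shape (batch, seq_len, 768), visual_features of
# shape (batch, n_visual, 1024) and tags_embedding of shape (batch, n_tags, 400); the latter two
# may also be None, in which case the block reduces to plain causal self-attention over the text.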
class DistilGPT2(GPT2LMHeadModel):
def __init__(self, config: GPTConfig):
super(DistilGPT2, self).__init__(config)
self.config = config
self.transformer = nn.ModuleDict(
{
'wte': nn.Embedding(config.vocab_size, config.n_embd),
'wpe': nn.Embedding(config.block_size, config.n_embd),
'h': nn.ModuleList(
[
Block(config) for _ in range(config.n_layer)
]
), # transformer blocks
'ln_f': nn.LayerNorm(config.n_embd) # final layer normalization
}
)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # linear layer for projection from embedding to vocab size
def forward(self, idx: torch.Tensor, visual_features: torch.Tensor = None, tags_embedding: torch.Tensor = None, return_dict: bool = False) -> torch.Tensor:
idx=idx.to(self.device)
B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward a sequence of length {T}; block size is only {self.config.block_size}"
# forward the token and positional embeddings
pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
pos_emb = self.transformer['wpe'](pos)
tok_emb = self.transformer['wte'](idx)
x = tok_emb + pos_emb
# forward the transformer
for block in self.transformer['h']:
x = block(x, visual_features=visual_features, tags_embedding=tags_embedding)
# forward the head
x = self.transformer['ln_f'](x)
logits = self.lm_head(x)
if return_dict:
return {'logits': logits}
else:
return logits
@classmethod
def from_pretrained(cls, model_type: str):
"""Loads pre-trained GPT-2 model weights from Hugging Face and handles custom layers."""
from transformers import GPT2LMHeadModel
print(f"Loading weights from pre-trained GPT: {model_type}")
# Ensure the model type is supported
assert model_type in {'distilgpt2', 'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
# Define configurations based on the model type
config_args = {
'distilgpt2': dict(n_layer=6, n_head=12, n_embd=768),
'gpt2': dict(n_layer=12, n_head=12, n_embd=768),
'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024),
'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280),
'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600),
}[model_type]
config_args['vocab_size'] = 50257
config_args['block_size'] = 1024
        # Build the Hugging Face config for the requested checkpoint and attach the custom fields
        from transformers import GPT2Config
        config = GPT2Config.from_pretrained(model_type)
        config.visual_size = 1024
        config.block_size = config_args['block_size']
        config.tags_embd = 400
        config.n_embd = config_args['n_embd']
        config.n_layer = config_args['n_layer']
        config.n_head = config_args['n_head']
        model = cls(config)
        # Load the state dictionary of the pre-trained Hugging Face model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()
        # State dictionary of the custom model
        sd = model.state_dict()
        # The custom visual/tags projections have no counterpart in the pre-trained model
        custom_keys = {k for k in sd if 'visual_attn' in k or 'tags_attn' in k}
        sd_keys_filtered = [k for k in sd if k not in custom_keys]
        # Hugging Face GPT-2 stores these weights in Conv1D modules, i.e. transposed relative to nn.Linear
        transposed = ('attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight')
        # Copy every compatible key, transposing the Conv1D weights
        for k in sd_keys_filtered:
            if k not in sd_hf:
                continue
            with torch.no_grad():
                if any(k.endswith(suffix) for suffix in transposed) and sd_hf[k].shape[::-1] == sd[k].shape:
                    sd[k].copy_(sd_hf[k].t())
                elif sd_hf[k].shape == sd[k].shape:
                    sd[k].copy_(sd_hf[k])
        # Initialize the custom layers separately
        for k in custom_keys:
            with torch.no_grad():
                print(f"Initializing custom layer: {k}")
                sd[k].normal_(0.0, 0.02)
        # Update the model's state dictionary
        model.load_state_dict(sd, strict=False)
        return model
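    # Usage (sketch): DistilGPT2.from_pretrained('distilgpt2') builds the multimodal model and copies
    # the compatible GPT-2 weights, while the visual/tags projections start from a normal(0, 0.02) init.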
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
# Prepare inputs for autoregressive generation
inputs = {"idx": input_ids}
if past:
inputs["past_key_values"] = past # Include past key values for caching
# Include additional features like visual and tags if provided
if "visual_features" in kwargs:
inputs["visual_features"] = kwargs["visual_features"]
if "tags_embedding" in kwargs:
inputs["tags_embedding"] = kwargs["tags_embedding"]
return inputs
def generate(
self,
input_ids: torch.Tensor = None,
max_length: int = None,
min_length: int = None,
do_sample: bool = None,
early_stopping: bool = None,
num_beams: int = None,
temperature: float = None,
top_k: int = None,
top_p: float = None,
repetition_penalty: float = None,
bos_token_id: int = None,
pad_token_id: int = None,
eos_token_ids: int = None,
length_penalty: float = None,
no_repeat_ngram_size: int = None,
num_return_sequences: int = None,
attention_mask: torch.Tensor = None,
visual_features: torch.Tensor = None,
tags_embedding: torch.Tensor = None,
):
"""
Generate sequences using autoregressive decoding.
Args:
input_ids (torch.Tensor): Input tensor of token IDs.
max_length (int): Maximum length of the generated sequence.
min_length (int): Minimum length of the generated sequence.
do_sample (bool): Whether to use sampling; if False, uses greedy decoding.
early_stopping (bool): Whether to stop when all beams have finished.
num_beams (int): Number of beams for beam search.
temperature (float): Sampling temperature.
top_k (int): Top-k sampling.
top_p (float): Top-p (nucleus) sampling.
            repetition_penalty (float): Penalty applied to tokens that already appear in the sequence.
bos_token_id (int): Beginning of sequence token ID.
pad_token_id (int): Padding token ID.
eos_token_ids (int): End of sequence token ID.
length_penalty (float): Beam search length penalty.
no_repeat_ngram_size (int): Size of n-grams not to repeat.
num_return_sequences (int): Number of sequences to return.
attention_mask (torch.Tensor): Attention mask for padding tokens.
visual_features (torch.Tensor): Visual features for the transformer.
tags_embedding (torch.Tensor): Tags embeddings for the transformer.
Returns:
torch.Tensor: Generated sequences of token IDs.
"""
# Default values for unspecified parameters
max_length = max_length or self.config.block_size
min_length = min_length or 0
do_sample = do_sample or False
early_stopping = early_stopping or False
num_beams = num_beams or 1
temperature = temperature or 1.0
top_k = top_k or 0
top_p = top_p or 1.0
repetition_penalty = repetition_penalty or 1.0
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        # GPT2Config defines `eos_token_id` (singular); fall back to it when no value is passed
        eos_token_ids = eos_token_ids if eos_token_ids is not None else getattr(self.config, "eos_token_id", None)
length_penalty = length_penalty or 1.0
no_repeat_ngram_size = no_repeat_ngram_size or 0
num_return_sequences = num_return_sequences or 1
if input_ids is not None:
batch_size=input_ids.shape[0]
else:
batch_size=1
if input_ids is None:
assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
"You should either supply a context to complete as `input_ids` input "
"or a `bos_token_id` (integer >= 0) as a first token to start the generation."
)
input_ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long)
else:
assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
# Avoid duplicate outputs when greedy decoding
if not do_sample:
if num_beams == 1:
assert num_return_sequences == 1, (
"Greedy decoding will always produce the same output for num_beams == 1 "
"and num_return_sequences > 1. Please set num_return_sequences = 1."
)
else:
assert num_beams >= num_return_sequences, (
"Greedy beam search decoding cannot return more sequences than it has beams. "
"Please set num_beams >= num_return_sequences."
)
# Create attention mask if necessary
if attention_mask is None:
if pad_token_id is not None and pad_token_id in input_ids:
attention_mask = (input_ids != pad_token_id).long()
else:
attention_mask = torch.ones_like(input_ids)
# Set pad_token_id if not provided and eos_token_ids is available
if pad_token_id is None and eos_token_ids is not None:
pad_token_id = eos_token_ids
            print(f"Setting `pad_token_id` to {pad_token_id} (the `eos_token_ids` value) so finished sequences can be padded.")
# Current sequence length and vocabulary size
cur_len = input_ids.size(1)
vocab_size = self.config.vocab_size
# Adjust effective batch size and multiplier for sampling
if do_sample:
effective_batch_size = batch_size * num_return_sequences
effective_batch_mult = num_return_sequences
else:
effective_batch_size = batch_size
effective_batch_mult = 1
# Expand input_ids and attention_mask for beam search or multiple return sequences
if num_return_sequences > 1 or num_beams > 1:
input_ids_len = input_ids.size(-1)
# Expand dimensions and repeat for each beam and return sequence
input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
attention_mask = attention_mask.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
# Reshape to combine batch and beam dimensions
input_ids = input_ids.reshape(effective_batch_size * num_beams, input_ids_len)
attention_mask = attention_mask.reshape(effective_batch_size * num_beams, input_ids_len)
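        # Note: `_generate_beam_search` is not defined in this module; the num_beams > 1 branch
        # assumes it is provided elsewhere (e.g. by the transformers base class in older releases).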
if num_beams > 1:
output = self._generate_beam_search(
input_ids=input_ids,
attention_mask=attention_mask,
visual_features=visual_features,
tags_embedding=tags_embedding,
cur_len=input_ids.size(1),
max_length=max_length,
min_length=min_length,
do_sample=do_sample,
early_stopping=early_stopping,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
pad_token_id=pad_token_id,
eos_token_ids=eos_token_ids,
length_penalty=length_penalty,
num_return_sequences=num_return_sequences,
num_beams=num_beams,
)
else:
output = self._generate_no_beam_search(
input_ids=input_ids,
attention_mask=attention_mask,
visual_features=visual_features,
tags_embedding=tags_embedding,
cur_len=input_ids.size(1),
max_length=max_length,
min_length=min_length,
do_sample=do_sample,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
pad_token_id=pad_token_id,
eos_token_ids=eos_token_ids,
batch_size=batch_size,
vocab_size=vocab_size,
)
return output
def _generate_no_beam_search(
self,
input_ids,
visual_features,
tags_embedding,
cur_len,
max_length,
min_length,
do_sample,
temperature,
top_k,
top_p,
repetition_penalty,
no_repeat_ngram_size,
pad_token_id,
eos_token_ids,
batch_size,
vocab_size,
attention_mask,
):
"""
Generate sequences for each example without beam search (num_beams == 1).
All returned sequences are generated independently.
"""
# Track unfinished sentences and their lengths
unfinished_sents=torch.ones_like(input_ids[:,0])
sent_lengths=torch.ones_like(input_ids[:,0])*max_length
past=None
while cur_len < max_length:
if past is None:
inputs = input_ids
else:
inputs = input_ids[:, -1].unsqueeze(1)
model_inputs = self.prepare_inputs_for_generation(
inputs, past=past, visual_features=visual_features, tags_embedding=tags_embedding
)
            outputs = self(**model_inputs)         # (batch_size, seq_len, vocab_size)
            next_token_logits = outputs[:, -1, :]  # logits for the last position: (batch_size, vocab_size)
            # Past key/value caching is not implemented in the custom forward, so `past` stays None
            # and the full sequence is re-encoded at every step.
            # Apply the repetition penalty elementwise to tokens already present in each sequence
            if repetition_penalty != 1.0:
                next_token_logits_penalties = self._create_next_token_logits_penalties(
                    input_ids, next_token_logits, repetition_penalty
                )
                next_token_logits = next_token_logits * next_token_logits_penalties
            # Prevent repetition of n-grams (note: this branch has not been fully verified)
            if no_repeat_ngram_size > 0:
                banned_tokens = self.calc_banned_ngram_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
                banned_tokens_indices_mask = []
                for banned_tokens_slice in banned_tokens:
                    banned_tokens_indices_mask.append(
                        [token in banned_tokens_slice for token in range(vocab_size)]
                    )
                banned_tokens_indices_mask = torch.tensor(
                    banned_tokens_indices_mask, dtype=torch.bool, device=next_token_logits.device
                )
                next_token_logits[banned_tokens_indices_mask] = -float('inf')
            # Forbid EOS while the generated sequence is still shorter than min_length
            if eos_token_ids is not None and cur_len < min_length:
                is_token_logit_eos_token = torch.arange(vocab_size, device=next_token_logits.device) == eos_token_ids
                eos_token_indices_mask = is_token_logit_eos_token.unsqueeze(0).expand(batch_size, -1)
                next_token_logits = next_token_logits.masked_fill(eos_token_indices_mask, -float("inf"))
# Sampling or greedy decoding
if do_sample:
if temperature != 1.0:
next_token_logits = next_token_logits / temperature
                next_token_logits = self.top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
                next_token = torch.multinomial(torch.softmax(next_token_logits, dim=-1), num_samples=1).squeeze(1)
            else:
                next_token = torch.argmax(next_token_logits, dim=-1)
if eos_token_ids is not None:
unfinished_sents=unfinished_sents.to(self.device)
tokens_to_add = next_token * unfinished_sents + pad_token_id * (1 - unfinished_sents)
else:
tokens_to_add = next_token
input_ids=input_ids.to(self.device)
input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=1)
if eos_token_ids is not None:
eos_in_sents = tokens_to_add == eos_token_ids
# If sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents * eos_in_sents.int()
sent_lengths=sent_lengths.to(self.device)
sent_lengths = (
sent_lengths * (1 - is_sents_unfinished_and_token_to_add_is_eos)
+ cur_len * is_sents_unfinished_and_token_to_add_is_eos
)
# Unfinished sentences are set to zero if eos is in the sentence
unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos
            # Stop early once every sequence has produced an EOS token
            if torch.max(unfinished_sents) == 0:
                break
cur_len += 1
        # Pad sequences to a common length if they finished at different steps
        min_sent_length = sent_lengths.min()
        max_sent_length = sent_lengths.max()
        if min_sent_length != max_sent_length:
            assert pad_token_id is not None, "`pad_token_id` has to be defined if sequences have different lengths"
            max_sent_length = int(max_sent_length)
            padding = torch.full((batch_size, max_sent_length), pad_token_id, dtype=input_ids.dtype, device=input_ids.device)
            broad_casted_sent_lengths = sent_lengths.unsqueeze(-1).expand(batch_size, max_sent_length)
            broad_casted_range = torch.arange(max_sent_length, device=input_ids.device).unsqueeze(0).expand(batch_size, max_sent_length)
            # Keep each sequence's generated tokens up to its recorded length, pad the rest
            decoded = torch.where(broad_casted_range < broad_casted_sent_lengths, input_ids[:, :max_sent_length], padding)
        else:
            decoded = input_ids
return decoded
def _create_next_token_logits_penalties(self,input_ids, logits, repetition_penalty):
"""
Create logit penalties for already seen input_ids based on repetition penalty.
Args:
input_ids (torch.Tensor): Tensor of shape (batch_size, seq_len) containing input token IDs.
logits (torch.Tensor): Tensor of shape (batch_size, vocab_size) containing next-token logits.
repetition_penalty (float): The penalty to apply for repeated tokens.
Returns:
torch.Tensor: Tensor of shape (batch_size, vocab_size) with applied penalties.
"""
        token_penalties = torch.ones_like(logits)
        prev_input_ids = [torch.unique(input_id) for input_id in input_ids]
        for i, prev_input_id in enumerate(prev_input_ids):
            # Gather the logits of the tokens already present in sequence i
            logits_penalized = logits[i][prev_input_id]
            logit_penalties = torch.zeros_like(logits_penalized)
            # Negative logits are scaled by the penalty, positive logits by its inverse
            logit_penalties[logits_penalized < 0] = repetition_penalty
            logit_penalties[logits_penalized > 0] = 1 / repetition_penalty
            token_penalties[i].scatter_(0, prev_input_id, logit_penalties)
        return token_penalties
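    # Worked example (sketch): with repetition_penalty = 1.2, a previously generated token whose logit
    # is 2.0 gets multiplied by 1/1.2 (≈ 1.67), while one whose logit is -2.0 is scaled to -2.4,
    # making both less likely to be chosen again.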
def top_k_top_p_filtering(self,logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
"""
Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
Args:
logits: Logits distribution of shape (batch size, vocabulary size).
top_k (int): Keep only top k tokens with the highest probability.
top_p (float): Keep the top tokens with cumulative probability >= top_p (nucleus filtering).
filter_value (float): Value to assign to filtered logits.
min_tokens_to_keep (int): Ensure at least this many tokens are kept.
Returns:
torch.Tensor: Filtered logits.
"""
logits_shape = logits.size()
# Top-k filtering
if top_k > 0:
top_k = min(max(top_k, min_tokens_to_keep), logits_shape[-1]) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
top_k_values, _ = torch.topk(logits, top_k, dim=-1)
min_top_k_values = top_k_values[:, -1].unsqueeze(-1) # Minimum logit in top-k
logits = torch.where(logits < min_top_k_values, torch.full_like(logits, filter_value), logits)
# Top-p (nucleus) filtering
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens_to_keep > 1:
# Ensure we keep at least min_tokens_to_keep tokens
sorted_indices_to_remove[:, :min_tokens_to_keep] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dims=-1)
sorted_indices_to_remove[:, 0] = 0
# Scatter sorted indices back to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits = torch.where(indices_to_remove, torch.full_like(logits, filter_value), logits)
return logits
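    # Example (sketch): top_k=1 keeps only the single most likely token (greedy behaviour even when
    # sampling), while top_p=0.9 keeps the smallest set of tokens whose cumulative probability reaches
    # 0.9 and sets every other logit to filter_value.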
def calc_banned_ngram_tokens(self,prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
"""
Calculate banned n-gram tokens for no-repeat n-gram constraints.
Args:
prev_input_ids (torch.Tensor): Tensor of shape (num_hypos, seq_len) containing token sequences.
num_hypos (int): Number of hypotheses in the batch.
no_repeat_ngram_size (int): Size of the n-grams to avoid repeating.
cur_len (int): Current length of the sequence being generated.
Returns:
List[List[int]]: List of banned tokens for each hypothesis.
"""
if cur_len + 1 < no_repeat_ngram_size:
# Return no banned tokens if not enough tokens have been generated
return [[] for _ in range(num_hypos)]
# Dictionary to store generated n-grams for each hypothesis
generated_ngrams = [{} for _ in range(num_hypos)]
# Populate the n-grams
for idx in range(num_hypos):
gen_tokens = prev_input_ids[idx].tolist()
generated_ngram = generated_ngrams[idx]
for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
prev_ngram_tuple = tuple(ngram[:-1])
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
def _get_generated_ngrams(hypo_idx):
# Get n-grams that have already appeared
start_idx = cur_len + 1 - no_repeat_ngram_size
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
return generated_ngrams[hypo_idx].get(ngram_idx, [])
# Calculate banned tokens for each hypothesis
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
return banned_tokens
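

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the deployed pipeline. The prompt text, the number of
    # visual tokens (49) and the number of tag embeddings (10) are illustrative assumptions; only the
    # feature dimensions (visual_size=1024, tags_embd=400) come from the config defaults above.
    from transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
    model = DistilGPT2.from_pretrained("distilgpt2")
    model.eval()

    input_ids = tokenizer("the image shows", return_tensors="pt")["input_ids"]
    visual_features = torch.randn(1, 49, 1024)  # assumed shape: (batch, n_visual, visual_size)
    tags_embedding = torch.randn(1, 10, 400)    # assumed shape: (batch, n_tags, tags_embd)

    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids,
            visual_features=visual_features,
            tags_embedding=tags_embedding,
            max_length=30,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_ids=tokenizer.eos_token_id,
        )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))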