from typing import Dict, Optional, Tuple, Union import torch from transformers import AutoConfig, AutoModelForCausalLM from transformers import EsmConfig, LlamaConfig, PretrainedConfig from transformers import EsmModel, LlamaForCausalLM, PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.generation.utils import Cache, GenerateOutput class ModalityAdapterConfig(PretrainedConfig): model_type = "modality_adapter" def __init__( self, input_dim: int, intermediate_dim: int, output_dim: int, dropout_rate: float = 0.3, **kwargs ): super().__init__(**kwargs) self.input_dim = input_dim self.intermediate_dim = intermediate_dim self.output_dim = output_dim self.dropout_rate = dropout_rate class Esm2LlamaInstructConfig(PretrainedConfig): model_type = "esm2llama_instruct" def __init__( self, # model components esm_config: Optional[Union[EsmConfig, Dict]] = None, adapter_config: Optional[Union[ModalityAdapterConfig, Dict]] = None, llama_config: Optional[Union[LlamaConfig, Dict]] = None, # standalone attributes placeholder_id: int = 128003, **kwargs ): super().__init__(**kwargs) if isinstance(esm_config, dict): self.esm_config = EsmConfig(**esm_config) else: self.esm_config = esm_config if isinstance(llama_config, dict): self.llama_config = LlamaConfig(**llama_config) else: self.llama_config = llama_config if isinstance(adapter_config, dict): self.adapter_config = ModalityAdapterConfig(**adapter_config) else: self.adapter_config = adapter_config self.placeholder_id = placeholder_id class ModalityAdapter(PreTrainedModel): config_class = ModalityAdapterConfig def __init__(self, config: ModalityAdapterConfig): super().__init__(config) self.config = config self.fc1 = torch.nn.Linear(config.input_dim, config.intermediate_dim) self.fc2 = torch.nn.Linear(config.intermediate_dim, config.output_dim) self.activation = torch.nn.GELU() self.ln1 = torch.nn.LayerNorm(normalized_shape=config.intermediate_dim) # DEPRECATED self.ln2 = torch.nn.LayerNorm(normalized_shape=config.output_dim) # DEPRECATED self.dropout = torch.nn.Dropout(p=config.dropout_rate) self.post_init() # initialize weights and apply final processing def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: # input: (bsz, seq_len, input_dim) hidden_states = self.activation(self.fc1(hidden_states)) hidden_states = self.dropout(hidden_states) # interm: (bsz, seq_len, interm_dim) hidden_states = self.activation(self.fc2(hidden_states)) hidden_states = self.dropout(hidden_states) hidden_states = torch.nn.functional.normalize(hidden_states, p=2, dim=-1) return hidden_states # (bsz, seq_len, output_dim) class Esm2LlamaInstructForCausalLM(PreTrainedModel): """ Esm2LlamaInstructForCausalLM model for protein function prediction. Similar to `EncoderDecoderModel` but with more complicated architecture. Initialize with either a configuration OR all three components. `kwargs` can override standalone attributes in `Esm2LlamaInstructConfig`. """ config_class = Esm2LlamaInstructConfig def __init__( self, config: Optional[Esm2LlamaInstructConfig] = None, esm_encoder: Optional[EsmModel] = None, adapter: Optional[ModalityAdapter] = None, llama_decoder: Optional[LlamaForCausalLM] = None, **kwargs ): if config is not None: # components ignored if config is provided super().__init__(config) self.esm_encoder = EsmModel( config.esm_config, add_pooling_layer=False ) self.adapter = ModalityAdapter(config.adapter_config) self.llama_decoder = LlamaForCausalLM(config.llama_config) else: config = Esm2LlamaInstructConfig( esm_config=esm_encoder.config, adapter_config=adapter.config, llama_config=llama_decoder.config, **kwargs # override standalone attributes ) super().__init__(config) self.esm_encoder = esm_encoder self.adapter = adapter self.llama_decoder = llama_decoder def prepare_decoder_inputs( self, input_ids: torch.LongTensor, encoder_hidden_states: torch.FloatTensor, attention_mask: Optional[torch.LongTensor] = None, encoder_attention_mask: Optional[torch.LongTensor] = None, ): """ Embed and replace placeholder in `input_ids` by encoder hidden states. `input_ids` must be passed to locate placeholder for replacement. """ # preparation batch_size, seq_len = input_ids.size() _, encoder_seq_len, _ = encoder_hidden_states.size() if attention_mask is None: attention_mask = torch.ones( (batch_size, seq_len), dtype=torch.long, device=input_ids.device ) if encoder_attention_mask is None: encoder_attention_mask = torch.ones( (batch_size, encoder_seq_len), dtype=torch.long, device=encoder_hidden_states.device ) inputs_embeds = self.llama_decoder.get_input_embeddings()(input_ids) # replacement placeholder_mask = input_ids == self.config.placeholder_id encoder_mask = encoder_attention_mask.bool() inputs_embeds[placeholder_mask] = encoder_hidden_states[encoder_mask] return inputs_embeds, attention_mask def forward( self, # chat template text inputs input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, labels: Optional[torch.LongTensor] = None, # protein amino-acid sequence inputs protein_input_ids: Optional[torch.LongTensor] = None, protein_attention_mask: Optional[torch.LongTensor] = None, protein_position_ids: Optional[torch.LongTensor] = None, protein_head_mask: Optional[torch.LongTensor] = None, protein_inputs_embeds: Optional[torch.FloatTensor] = None, # behavior control arguments use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, return_encoder_outputs: bool = False, return_adapter_outputs: bool = False, return_decoder_inputs: bool = False, cache_position: Optional[torch.LongTensor] = None ) -> Union[Tuple, CausalLMOutputWithPast]: """ Compute encoder and adapter outputs, then pass to decoder. `input_ids` is expected to be [prompt + description] in teacher-forcing scenario and [prompt] only in first iteration of inference (with return_decoder_inputs=True). Attention: possible concatenation of the mask and labels should be handled before calling this method. `inputs_embeds` not allowed due to placeholder replacement scheme. """ # esm_encoder forward encoder_output = self.esm_encoder( input_ids=protein_input_ids, attention_mask=protein_attention_mask, position_ids=protein_position_ids, head_mask=protein_head_mask, inputs_embeds=protein_inputs_embeds, use_cache=False, # because config.esm_config.is_decoder=False output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict ) encoder_hidden_states = encoder_output[0] encoder_attention_mask = protein_attention_mask if return_encoder_outputs: return encoder_output # adapter forward adapter_output = self.adapter(encoder_hidden_states) if return_adapter_outputs: return adapter_output, encoder_attention_mask # decoder input preparation inputs_embeds, attention_mask = self.prepare_decoder_inputs( input_ids=input_ids, encoder_hidden_states=adapter_output, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, ) if return_decoder_inputs: return inputs_embeds, attention_mask # llama_decoder forward return self.llama_decoder.forward( input_ids=None, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, return_dict=return_dict, cache_position=cache_position ) def generate( self, inputs: torch.LongTensor, # alias of `input_ids` attention_mask: Optional[torch.LongTensor] = None, protein_input_ids: Optional[torch.LongTensor] = None, protein_attention_mask: Optional[torch.LongTensor] = None, protein_inputs_embeds: Optional[torch.FloatTensor] = None, **kwargs ) -> Union[GenerateOutput, torch.LongTensor]: """ Do inference based on given input prompt. `inputs` is expected to be [prompt] only. Output will not keep the input prompt due to input in form of embeds. Generation behavior can be controlled by `args` and `kwargs`, read `GenerationMixin.generate` for more info. """ # get decoder inputs prompt_inputs_embeds, prompt_attention_mask = self( input_ids=inputs, attention_mask=attention_mask, protein_input_ids=protein_input_ids, protein_attention_mask=protein_attention_mask, protein_inputs_embeds=protein_inputs_embeds, use_cache=False, output_attentions=False, output_hidden_states=False, return_dict=False, return_decoder_inputs=True ) # do generate on llama_decoder return self.llama_decoder.generate( inputs_embeds=prompt_inputs_embeds, attention_mask=prompt_attention_mask, **kwargs ) def gradient_checkpointing_enable(self): """ Enable gradient checkpointing for all submodules that support it. Attention! Model need to be in train mode before calling this method. """ if hasattr(self.esm_encoder, "gradient_checkpointing_enable"): self.esm_encoder.gradient_checkpointing_enable() if hasattr(self.llama_decoder, "gradient_checkpointing_enable"): self.llama_decoder.gradient_checkpointing_enable() # simple adapter no need to implement gradient checkpointing def gradient_checkpointing_disable(self): if hasattr(self.esm_encoder, "gradient_checkpointing_disable"): self.esm_encoder.gradient_checkpointing_disable() if hasattr(self.llama_decoder, "gradient_checkpointing_disable"): self.llama_decoder.gradient_checkpointing_disable() AutoConfig.register("esm2llama_instruct", Esm2LlamaInstructConfig) AutoModelForCausalLM.register(Esm2LlamaInstructConfig, Esm2LlamaInstructForCausalLM)