from typing import Dict, Optional, Tuple, Union
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from transformers import EsmConfig, LlamaConfig, PretrainedConfig
from transformers import EsmModel, LlamaForCausalLM, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.cache_utils import Cache
from transformers.generation.utils import GenerateOutput


class ModalityAdapterConfig(PretrainedConfig):
    model_type = "modality_adapter"

    def __init__(
            self, 
            input_dim: int, 
            intermediate_dim: int,
            output_dim: int, 
            dropout_rate: float = 0.3,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.intermediate_dim = intermediate_dim
        self.output_dim = output_dim
        self.dropout_rate = dropout_rate


class Esm2LlamaInstructConfig(PretrainedConfig):
    model_type = "esm2llama_instruct"

    def __init__(
            self, 
            # model components
            esm_config: Optional[Union[EsmConfig, Dict]] = None, 
            adapter_config: Optional[Union[ModalityAdapterConfig, Dict]] = None,
            llama_config: Optional[Union[LlamaConfig, Dict]] = None, 
            # standalone attributes
            placeholder_id: int = 128003, 
            **kwargs
    ):
        super().__init__(**kwargs)
        
        if isinstance(esm_config, dict):
            self.esm_config = EsmConfig(**esm_config)
        else:
            self.esm_config = esm_config
            
        if isinstance(llama_config, dict):
            self.llama_config = LlamaConfig(**llama_config)
        else:
            self.llama_config = llama_config
            
        if isinstance(adapter_config, dict):
            self.adapter_config = ModalityAdapterConfig(**adapter_config)
        else:
            self.adapter_config = adapter_config
            
        self.placeholder_id = placeholder_id


class ModalityAdapter(PreTrainedModel):
    config_class = ModalityAdapterConfig

    def __init__(self, config: ModalityAdapterConfig):
        super().__init__(config)
        self.config = config
        self.fc1 = torch.nn.Linear(config.input_dim, config.intermediate_dim)
        self.fc2 = torch.nn.Linear(config.intermediate_dim, config.output_dim)
        self.activation = torch.nn.GELU()
        self.ln1 = torch.nn.LayerNorm(normalized_shape=config.intermediate_dim)  # DEPRECATED: unused in forward
        self.ln2 = torch.nn.LayerNorm(normalized_shape=config.output_dim)  # DEPRECATED: unused in forward
        self.dropout = torch.nn.Dropout(p=config.dropout_rate)

        self.post_init()  # initialize weights and apply final processing

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        # input: (bsz, seq_len, input_dim)
        hidden_states = self.activation(self.fc1(hidden_states))
        hidden_states = self.dropout(hidden_states)
        # interm: (bsz, seq_len, interm_dim)
        hidden_states = self.activation(self.fc2(hidden_states))
        hidden_states = self.dropout(hidden_states)
        hidden_states = torch.nn.functional.normalize(hidden_states, p=2, dim=-1)
        return hidden_states  # (bsz, seq_len, output_dim)


class Esm2LlamaInstructForCausalLM(PreTrainedModel):
    """
    Esm2LlamaInstructForCausalLM model for protein function prediction.
    Similar to `EncoderDecoderModel`, but composed of an ESM-2 protein
    encoder, a modality adapter, and a Llama decoder.
    Initialize with either a configuration OR all three components.
    `kwargs` can override standalone attributes in `Esm2LlamaInstructConfig`.
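
    Example (illustrative sketch; the checkpoint names and the adapter
    intermediate size are placeholders, not prescribed by this module):

        esm = EsmModel.from_pretrained(
            "facebook/esm2_t33_650M_UR50D", add_pooling_layer=False
        )
        llama = LlamaForCausalLM.from_pretrained(
            "meta-llama/Llama-3.1-8B-Instruct"
        )
        adapter = ModalityAdapter(ModalityAdapterConfig(
            input_dim=esm.config.hidden_size,
            intermediate_dim=2048,  # arbitrary choice for the sketch
            output_dim=llama.config.hidden_size,
        ))
        model = Esm2LlamaInstructForCausalLM(
            esm_encoder=esm, adapter=adapter, llama_decoder=llama
        )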
    """
    config_class = Esm2LlamaInstructConfig

    def __init__(
            self, 
            config: Optional[Esm2LlamaInstructConfig] = None, 
            esm_encoder: Optional[EsmModel] = None, 
            adapter: Optional[ModalityAdapter] = None,
            llama_decoder: Optional[LlamaForCausalLM] = None, 
            **kwargs
        ):
        if config is not None:  # components ignored if config is provided
            super().__init__(config)
            self.esm_encoder = EsmModel(
                config.esm_config, 
                add_pooling_layer=False
            )
            self.adapter = ModalityAdapter(config.adapter_config)
            self.llama_decoder = LlamaForCausalLM(config.llama_config)
        else:  # build the config from the provided components
            if esm_encoder is None or adapter is None or llama_decoder is None:
                raise ValueError(
                    "Either `config` or all of `esm_encoder`, `adapter` and "
                    "`llama_decoder` must be provided."
                )
            config = Esm2LlamaInstructConfig(
                esm_config=esm_encoder.config,
                adapter_config=adapter.config,
                llama_config=llama_decoder.config,
                **kwargs  # override standalone attributes
            )
            super().__init__(config)
            self.esm_encoder = esm_encoder
            self.adapter = adapter
            self.llama_decoder = llama_decoder
            
    def prepare_decoder_inputs(
            self, 
            input_ids: torch.LongTensor,
            encoder_hidden_states: torch.FloatTensor,
            attention_mask: Optional[torch.LongTensor] = None,
            encoder_attention_mask: Optional[torch.LongTensor] = None, 
    ): 
        """
        Embed and replace placeholder in `input_ids` by encoder hidden states.
        `input_ids` must be passed to locate placeholder for replacement.
        """
        # preparation
        batch_size, seq_len = input_ids.size()
        _, encoder_seq_len, _ = encoder_hidden_states.size()
        if attention_mask is None: 
            attention_mask = torch.ones(
                (batch_size, seq_len), 
                dtype=torch.long, 
                device=input_ids.device
            )
        if encoder_attention_mask is None: 
            encoder_attention_mask = torch.ones(
                (batch_size, encoder_seq_len), 
                dtype=torch.long, 
                device=encoder_hidden_states.device
            )
        inputs_embeds = self.llama_decoder.get_input_embeddings()(input_ids)
        # replacement: scatter the valid encoder states into the placeholder
        # slots (placeholder and unmasked-encoder counts must match, see docstring)
        placeholder_mask = input_ids == self.config.placeholder_id
        encoder_mask = encoder_attention_mask.bool()
        # cast to the embedding dtype so the in-place assignment also works
        # when the decoder runs in reduced precision
        inputs_embeds[placeholder_mask] = encoder_hidden_states[encoder_mask].to(
            inputs_embeds.dtype
        )
        return inputs_embeds, attention_mask

    def forward(
            self, 
            # chat template text inputs
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.LongTensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[Cache] = None,
            labels: Optional[torch.LongTensor] = None,
            # protein amino-acid sequence inputs
            protein_input_ids: Optional[torch.LongTensor] = None,
            protein_attention_mask: Optional[torch.LongTensor] = None,
            protein_position_ids: Optional[torch.LongTensor] = None, 
            protein_head_mask: Optional[torch.LongTensor] = None,
            protein_inputs_embeds: Optional[torch.FloatTensor] = None,
            # behavior control arguments
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            return_encoder_outputs: bool = False,
            return_adapter_outputs: bool = False, 
            return_decoder_inputs: bool = False,
            cache_position: Optional[torch.LongTensor] = None
    ) -> Union[Tuple, CausalLMOutputWithPast]: 
        """
        Compute encoder and adapter outputs, then pass them to the decoder.
        `input_ids` is expected to be [prompt + description] in the
        teacher-forcing scenario, and [prompt] only in the first iteration of
        inference (with `return_decoder_inputs=True`).
        Note: any concatenation of attention masks and labels must be handled
        before calling this method.
        `inputs_embeds` is not supported, because the placeholder replacement
        scheme requires token ids.
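
        Teacher-forcing sketch (illustrative; `full_ids`, `full_mask`,
        `full_labels`, `protein_ids` and `protein_mask` are assumed to be
        prepared by the caller, with labels set to -100 outside the
        description):

            outputs = model(
                input_ids=full_ids,
                attention_mask=full_mask,
                labels=full_labels,
                protein_input_ids=protein_ids,
                protein_attention_mask=protein_mask,
                return_dict=True,
            )
            loss = outputs.loss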
        """
        # esm_encoder forward
        encoder_output = self.esm_encoder(
            input_ids=protein_input_ids,
            attention_mask=protein_attention_mask,
            position_ids=protein_position_ids,
            head_mask=protein_head_mask,
            inputs_embeds=protein_inputs_embeds,
            use_cache=False, # because config.esm_config.is_decoder=False
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
        encoder_hidden_states = encoder_output[0]
        encoder_attention_mask = protein_attention_mask
        if return_encoder_outputs:
            return encoder_output
        # adapter forward
        adapter_output = self.adapter(encoder_hidden_states)
        if return_adapter_outputs:
            return adapter_output, encoder_attention_mask
        # decoder input preparation
        inputs_embeds, attention_mask = self.prepare_decoder_inputs(
            input_ids=input_ids, 
            encoder_hidden_states=adapter_output, 
            attention_mask=attention_mask, 
            encoder_attention_mask=encoder_attention_mask, 
        )
        if return_decoder_inputs:
            return inputs_embeds, attention_mask
        # llama_decoder forward
        return self.llama_decoder.forward(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position
        )

    def generate(
        self,
        inputs: torch.LongTensor,  # alias of `input_ids`
        attention_mask: Optional[torch.LongTensor] = None,
        protein_input_ids: Optional[torch.LongTensor] = None,
        protein_attention_mask: Optional[torch.LongTensor] = None,
        protein_inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """
        Run inference on the given input prompt.
        `inputs` is expected to be [prompt] only.
        The output does not include the input prompt, because the decoder is
        fed embeddings rather than token ids.
        Generation behavior can be controlled through `kwargs`; see
        `GenerationMixin.generate` for details, and the `__main__` sketch at
        the end of this module for an end-to-end example.
        """
        # get decoder inputs
        prompt_inputs_embeds, prompt_attention_mask = self(
            input_ids=inputs, 
            attention_mask=attention_mask,
            protein_input_ids=protein_input_ids,
            protein_attention_mask=protein_attention_mask,
            protein_inputs_embeds=protein_inputs_embeds,
            use_cache=False, 
            output_attentions=False,
            output_hidden_states=False,
            return_dict=False,
            return_decoder_inputs=True
        )
        # do generate on llama_decoder
        return self.llama_decoder.generate(
            inputs_embeds=prompt_inputs_embeds, 
            attention_mask=prompt_attention_mask, 
            **kwargs
        )

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """
        Enable gradient checkpointing for all submodules that support it.
        Note: the model must be in train mode before calling this method.
        """
        if hasattr(self.esm_encoder, "gradient_checkpointing_enable"):
            self.esm_encoder.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs=gradient_checkpointing_kwargs
            )
        if hasattr(self.llama_decoder, "gradient_checkpointing_enable"):
            self.llama_decoder.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs=gradient_checkpointing_kwargs
            )
        # the adapter is small, so gradient checkpointing is unnecessary there

    def gradient_checkpointing_disable(self):
        if hasattr(self.esm_encoder, "gradient_checkpointing_disable"):
            self.esm_encoder.gradient_checkpointing_disable()
        if hasattr(self.llama_decoder, "gradient_checkpointing_disable"):
            self.llama_decoder.gradient_checkpointing_disable()


AutoConfig.register("esm2llama_instruct", Esm2LlamaInstructConfig)
AutoModelForCausalLM.register(Esm2LlamaInstructConfig, Esm2LlamaInstructForCausalLM)
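

if __name__ == "__main__":
    # Minimal end-to-end sketch. Everything below is illustrative: the
    # checkpoint names, the protein sequence and the prompt are assumptions,
    # not part of this module, and the checkpoints must be available locally
    # or via the Hugging Face Hub.
    from transformers import AutoTokenizer

    esm = EsmModel.from_pretrained(
        "facebook/esm2_t33_650M_UR50D", add_pooling_layer=False
    )
    llama = LlamaForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
    adapter = ModalityAdapter(ModalityAdapterConfig(
        input_dim=esm.config.hidden_size,
        intermediate_dim=2048,  # arbitrary choice for the sketch
        output_dim=llama.config.hidden_size,
    ))
    model = Esm2LlamaInstructForCausalLM(
        esm_encoder=esm, adapter=adapter, llama_decoder=llama
    ).eval()

    protein_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    text_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

    protein = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVAT"
    protein_inputs = protein_tokenizer(protein, return_tensors="pt")
    n_protein_tokens = protein_inputs["input_ids"].shape[1]

    # The prompt must contain exactly one placeholder token per encoder
    # position so that prepare_decoder_inputs can substitute every adapted
    # protein embedding into the prompt.
    placeholder = text_tokenizer.decode([model.config.placeholder_id])
    messages = [{
        "role": "user",
        "content": "Describe the function of this protein: "
                   + placeholder * n_protein_tokens,
    }]
    prompt_ids = text_tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    )

    with torch.no_grad():
        output_ids = model.generate(
            inputs=prompt_ids,
            protein_input_ids=protein_inputs["input_ids"],
            protein_attention_mask=protein_inputs["attention_mask"],
            max_new_tokens=64,
        )
    print(text_tokenizer.decode(output_ids[0], skip_special_tokens=True))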