Update modelling_longitudinal.py
Browse files
modelling_longitudinal.py
CHANGED
|
@@ -270,7 +270,7 @@ class LongitudinalPromptMultiCXREncoderDecoderModel(VisionEncoderDecoderModel):
|
|
| 270 |
if torch.all(input_ids[:, 0] == 1):
|
| 271 |
input_ids = input_ids[:, 1:]
|
| 272 |
|
| 273 |
-
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids)
|
| 274 |
decoder_attention_mask = (input_ids != mask_token_id).int()
|
| 275 |
decoder_position_ids = torch.nn.functional.relu(
|
| 276 |
torch.cumsum(decoder_attention_mask, dim=1, dtype=torch.int64) - 1
|
|
|
|
| 270 |
if torch.all(input_ids[:, 0] == 1):
|
| 271 |
input_ids = input_ids[:, 1:]
|
| 272 |
|
| 273 |
+
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
|
| 274 |
decoder_attention_mask = (input_ids != mask_token_id).int()
|
| 275 |
decoder_position_ids = torch.nn.functional.relu(
|
| 276 |
torch.cumsum(decoder_attention_mask, dim=1, dtype=torch.int64) - 1
|