Fix generation when `repetition_penalty` is activated (#57)
Browse files
- Make sure `input_ids` does not contain negative numbers (which indicate image tokens) once those markers are no longer needed (5905c926df4db18660da263a9777998ca66a14fe)
Co-authored-by: Yen-Chun Chen <[email protected]>
- image_embedding_phi3_v.py +10 -1
image_embedding_phi3_v.py
CHANGED
|
@@ -12,6 +12,7 @@
|
|
| 12 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
# See the License for the specific language governing permissions and
|
| 14 |
# limitations under the License.
|
|
|
|
| 15 |
|
| 16 |
import torch
|
| 17 |
from torch import nn
|
|
@@ -191,7 +192,15 @@ class Phi3ImageEmbedding(nn.Module):
|
|
| 191 |
# positions for image tokens
|
| 192 |
positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=True)
|
| 193 |
has_image = len(positions[0].tolist()) > 0
|
| 194 |
-
input_ids = input_ids.clamp_min(0).clamp_max(self.vocab_size).detach()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
hidden_states = self.wte(input_ids)
|
| 196 |
|
| 197 |
if has_image:
|
|
|
|
| 12 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
# See the License for the specific language governing permissions and
|
| 14 |
# limitations under the License.
|
| 15 |
+
import warnings
|
| 16 |
|
| 17 |
import torch
|
| 18 |
from torch import nn
|
|
|
|
| 192 |
# positions for image tokens
|
| 193 |
positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=True)
|
| 194 |
has_image = len(positions[0].tolist()) > 0
|
| 195 |
+
# input_ids = input_ids.clamp_min(0).clamp_max(self.vocab_size).detach()
|
| 196 |
+
input_ids.clamp_min_(0).clamp_max_(self.vocab_size)
|
| 197 |
+
warnings.warn(
|
| 198 |
+
"Phi-3-V modifies `input_ids` in-place and the tokens indicating images will be "
|
| 199 |
+
"removed after model forward. If your workflow requires multiple forward passes on "
|
| 200 |
+
"the same `input_ids`, please make a copy of `input_ids` before passing it to the "
|
| 201 |
+
"model."
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
hidden_states = self.wte(input_ids)
|
| 205 |
|
| 206 |
if has_image:
|