Update modeling_videochat_flash.py
Browse files
modeling_videochat_flash.py
CHANGED
|
@@ -636,7 +636,7 @@ class VideoChatFlashQwenForCausalLM(LlavaMetaForCausalLM, Qwen2ForCausalLM_Flash
|
|
| 636 |
|
| 637 |
image_sizes = [frames[0].shape[:2]]
|
| 638 |
|
| 639 |
-
frames = [self.get_vision_tower().image_processor.preprocess(frames, return_tensors="pt")["pixel_values"]. … [remainder of removed line truncated in the diff view]
|
| 640 |
|
| 641 |
conv = conv_templates["qwen_2"].copy()
|
| 642 |
|
|
@@ -679,7 +679,7 @@ class VideoChatFlashQwenForCausalLM(LlavaMetaForCausalLM, Qwen2ForCausalLM_Flash
|
|
| 679 |
|
| 680 |
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
|
| 681 |
if outputs.endswith(stop_str):
|
| 682 |
-
|
| 683 |
|
| 684 |
outputs = outputs.strip()
|
| 685 |
|
|
|
|
| 636 |
|
| 637 |
image_sizes = [frames[0].shape[:2]]
|
| 638 |
|
| 639 |
+
frames = [self.get_vision_tower().image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].to(self.model.dtype).cuda()]
|
| 640 |
|
| 641 |
conv = conv_templates["qwen_2"].copy()
|
| 642 |
|
|
|
|
| 679 |
|
| 680 |
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
|
| 681 |
if outputs.endswith(stop_str):
|
| 682 |
+
outputs = outputs[: -len(stop_str)]
|
| 683 |
|
| 684 |
outputs = outputs.strip()
|
| 685 |
|