import os, re
import spaces
import traceback
import shutil
import subprocess
import torch
from utils.mm_utils import (
    KeywordsStoppingCriteria,
    get_model_name_from_path,
    tokenizer_mm_token,
    ApolloMMLoader
)
from utils.conversation import conv_templates, SeparatorStyle
# X_START_TOKEN and X_END_TOKEN are referenced in apply_first_prompt below;
# they are assumed to be exported by utils.constants alongside X_TOKEN.
from utils.constants import (
    X_TOKEN,
    X_TOKEN_INDEX,
    X_START_TOKEN,
    X_END_TOKEN,
)
from decord import cpu, VideoReader
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel  # , BitsAndBytesConfig
import gradio as gr
model_url = "GoodiesHere/Apollo-LMMs-Apollo-1_5B-t32"
video_url = "https://www.youtube.com/watch?v=dQw4w9WgXcQ"
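# Resolve the yt-dlp binary from the YT_DLP env var (falling back to the one
# on PATH) and download the example clip once, re-encoding it to mp4.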
yt_dlp_bin = os.getenv('YT_DLP')
if not yt_dlp_bin:  # covers both an unset (None) and an empty env var
    yt_dlp_bin = "yt-dlp"
if not os.path.exists('example.mp4'):
    subprocess.run([yt_dlp_bin, '-o', 'example.mp4', '--recode-video', 'mp4', video_url])
title_markdown = """
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
  <div>
    <h1>You are chatting with Apollo-1.5B</h1>
  </div>
</div>
<div align="center">
  <div style="display:flex; gap: 0.25rem; margin-top: 10px;" align="center">
    <a href='https://apollo-lmms.github.io/Apollo/'><img src='https://img.shields.io/badge/Project-Apollo-deepskyblue'></a>
    <a href='https://huggingface.co/GoodiesHere/Apollo-LMMs-Apollo-1_5B-t32'><img src='https://img.shields.io/badge/model-checkpoints-gold'></a>
  </div>
</div>
"""
block_css = """
#buttons button {
    min-width: min(120px, 100%);
    color: #9C276A;
}
"""
plum_color = gr.themes.colors.Color(
    name='plum',
    c50='#F8E4EF',
    c100='#E9D0DE',
    c200='#DABCCD',
    c300='#CBA8BC',
    c400='#BC94AB',
    c500='#AD809A',
    c600='#9E6C89',
    c700='#8F5878',
    c800='#804467',
    c900='#713056',
    c950='#662647',
)
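# Fetch the checkpoint from the Hugging Face Hub and mirror it into ./tmp/data.
# The model itself is loaded from the snapshot path; the copy is only
# referenced again when building the relative path for the Examples widget.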
model_path = snapshot_download(model_url, repo_type="model")
destination_path = './tmp/data'
os.makedirs(destination_path, exist_ok=True)
shutil.copytree(model_path, destination_path, dirs_exist_ok=True)
# quantization_config = BitsAndBytesConfig(llm_int8_enable_fp32_cpu_offload=True)
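# Wraps model loading and inference. The checkpoint is loaded with
# trust_remote_code=True, so attributes such as model.tokenizer,
# model.vision_tower and the generate() signature come from the remote
# Apollo code, not from a standard transformers class.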
class Chat:
    def __init__(self):
        self.version = "qwen_1_5"
        model_name = "apollo"
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # attn_implementation = "sdpa" if torch.__version__ > "2.1.2" else "eager"
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            # attn_implementation=attn_implementation,
            device_map="auto",
            # quantization_config=quantization_config,
            # load_in_4bit=True,
        ).to(device=device, dtype=torch.bfloat16)  # keep a single consistent dtype
        self._model = model
        self._tokenizer = model.tokenizer
        self._vision_processors = model.vision_tower.vision_processor
        self._max_length = model.config.llm_cfg['model_max_length']
        self._config = self._model.config
        self.num_repeat_token = self._config.mm_connector_cfg['num_output_tokens']
        self.mm_use_im_start_end = self._config.use_mm_start_end

        frames_per_clip = 4
        clip_duration = getattr(self._config, 'clip_duration')
        self.mm_processor = ApolloMMLoader(self._vision_processors,
                                           clip_duration,
                                           frames_per_clip,
                                           clip_sampling_ratio=0.65,
                                           model_max_length=self._config.model_max_length,
                                           device=device,
                                           num_repeat_token=self.num_repeat_token)
        self._model.config.encode_batch_size = 35
        self._model.eval()
    def remove_after_last_dot(self, s):
        # Trim any trailing partial sentence from the generated text.
        last_dot_index = s.rfind('.')
        if last_dot_index == -1:
            return s
        return s[:last_dot_index + 1]
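    # Prefixes the first user message with the multimodal placeholder string
    # (and, when use_mm_start_end is set, the per-modality start/end tokens)
    # so the remote code can splice the visual embeddings into the prompt.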
    def apply_first_prompt(self, message, replace_string, data_type):
        if self.mm_use_im_start_end:
            message = X_START_TOKEN[data_type] + replace_string + X_END_TOKEN[data_type] + '\n\n' + message
        else:
            message = replace_string + '\n\n' + message
        return message
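    # Builds a conversation from the template for self.version, tokenizes the
    # prompt with multimodal token expansion, and runs generate() on the
    # remote model with the preprocessed vision input.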
    def generate(self, data: list, message, temperature, top_p, max_output_tokens):
        # TODO: support multiple turns of conversation.
        mm_data, replace_string, data_type = data[0]
        print(message)
        conv = conv_templates[self.version].copy()
        if isinstance(message, str):
            message = self.apply_first_prompt(message, replace_string, data_type)
            conv.append_message(conv.roles[0], message)
        elif isinstance(message, list):
            if X_TOKEN[data_type] not in message[0]['content']:
                print('applying prompt')
                message[0]['content'] = self.apply_first_prompt(message[0]['content'], replace_string, data_type)
            for mes in message:
                conv.append_message(mes["role"], mes["content"])
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        print(prompt.replace(X_TOKEN['video'], '<v>'))

        input_ids = tokenizer_mm_token(prompt, self._tokenizer, return_tensors="pt").unsqueeze(0).to(self._model.device)
        pad_token_ids = self._tokenizer.pad_token_id if self._tokenizer.pad_token_id is not None else self._tokenizer.eos_token_id
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self._tokenizer, input_ids)
        print(f'running on {input_ids.shape[1]} tokens!')

        with torch.inference_mode():
            output_ids = self._model.generate(input_ids,
                                              vision_input=[mm_data],
                                              data_types=[data_type],
                                              do_sample=temperature > 0,
                                              temperature=temperature,
                                              max_new_tokens=max_output_tokens,
                                              top_p=top_p,
                                              use_cache=True,
                                              num_beams=1,
                                              stopping_criteria=[stopping_criteria])
        print(f'generated {output_ids.shape[1]} tokens!')
        print(output_ids)
        pred = self._tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        print(pred)
        return self.remove_after_last_dot(pred)
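# Gradio callback shared by the Send button and textbox submit. It returns
# updates for (image, video, message state, chatbot history, textbox); the
# trailing None clears the textbox after a successful turn.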
def generate(image, video, message, chatbot, textbox_in, temperature, top_p, max_output_tokens):
    print(message)
    if textbox_in is None:
        # gr.Error aborts the callback and surfaces the message in the UI,
        # so no explicit return is needed here.
        raise gr.Error("Chat messages cannot be empty")
    data = []
    mm_processor = handler.mm_processor
    try:
        if image is not None and video is not None:
            raise NotImplementedError("Supplying an image and a video at the same time is not supported")
        elif image is not None:
            image, prompt = mm_processor.load_image(image)
            data.append((image, prompt, 'image'))
        elif video is not None:
            video_tensor, prompt = mm_processor.load_video(video)
            data.append((video_tensor, prompt, 'video'))
        else:
            data.append((None, None, 'text'))
    except Exception:
        traceback.print_exc()
        return gr.update(value=None, interactive=True), gr.update(value=None, interactive=True), message, chatbot, None
    assert len(message) % 2 == 0, "The message history should hold alternating user/assistant pairs."

    show_images = ""
    if image is not None:
        show_images += f'<img src="./file={image}" style="display: inline-block;width: 250px;max-height: 400px;">'
    if video is not None:
        show_images += f'<video controls playsinline width="300" style="display: inline-block;" src="./file={video}"></video>'

    one_turn_chat = [textbox_in, None]
    # 1. first run case
    if len(chatbot) == 0:
        one_turn_chat[0] += "\n" + show_images
    # 2. not first run case
    else:
        # scan backwards for the last image or video shown in the chat
        length = len(chatbot)
        for i in range(length - 1, -1, -1):
            previous_image = re.findall(r'<img src="./file=(.+?)"', chatbot[i][0])
            previous_video = re.findall(r'<video controls playsinline width="300" style="display: inline-block;" src="./file=(.+?)"', chatbot[i][0])
            if len(previous_image) > 0:
                previous_image = previous_image[-1]
                # 2.1 a new image (or any video input) starts a new conversation
                if (video is not None) or (image is not None and os.path.basename(previous_image) != os.path.basename(image)):
                    message.clear()
                    one_turn_chat[0] += "\n" + show_images
                break
            elif len(previous_video) > 0:
                previous_video = previous_video[-1]
                # 2.2 a new video (or any image input) starts a new conversation
                if image is not None or (video is not None and os.path.basename(previous_video) != os.path.basename(video)):
                    message.clear()
                    one_turn_chat[0] += "\n" + show_images
                break

    message.append({'role': 'user', 'content': textbox_in})
    text_en_out = handler.generate(data, message, temperature=temperature, top_p=top_p, max_output_tokens=max_output_tokens)
    message.append({'role': 'assistant', 'content': text_en_out})
    one_turn_chat[1] = text_en_out
    chatbot.append(one_turn_chat)

    return gr.update(value=image, interactive=True), gr.update(value=video, interactive=True), message, chatbot, None
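# Helpers for the Regenerate and Clear history buttons. regenerate drops the
# last exchange so the chained generate() call re-answers the previous prompt;
# clear_history resets both the message state and the visible chat.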
def regenerate(message, chatbot):
    if len(message) >= 2:
        del message[-2:]
    if chatbot:
        chatbot.pop()
    return message, chatbot

def clear_history(message, chatbot):
    message.clear()
    chatbot.clear()
    return (gr.update(value=None, interactive=True),
            gr.update(value=None, interactive=True),
            message, chatbot,
            gr.update(value=None, interactive=True))
handler = Chat()

textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)

theme = gr.themes.Default(primary_hue=plum_color)
# theme.update_color("primary", plum_color.c500)
theme.set(
    slider_color="#9C276A",
    block_title_text_color="#9C276A",
    block_label_text_color="#9C276A",
    button_primary_text_color="#9C276A",
)
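# UI layout: video input and sampling parameters on the left, the chat panel
# with action buttons on the right. `image` is kept as a hidden gr.State so
# the shared generate() callback signature still applies in this video-only UI.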
with gr.Blocks(title='Apollo-1.5B', theme=theme, css=block_css) as demo:
    gr.Markdown(title_markdown)
    message = gr.State([])

    with gr.Row():
        with gr.Column(scale=3):
            image = gr.State(None)
            video = gr.Video(label="Input Video")
            with gr.Accordion("Parameters", open=True) as parameter_row:
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.4,
                    step=0.1,
                    interactive=True,
                    label="Temperature",
                )
                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    interactive=True,
                    label="Top P",
                )
                max_output_tokens = gr.Slider(
                    minimum=32,
                    maximum=1024,
                    value=256,
                    step=32,
                    interactive=True,
                    label="Max output tokens",
                )

        with gr.Column(scale=7):
            chatbot = gr.Chatbot(label="Apollo", bubble_full_width=True, height=420)
            with gr.Row():
                with gr.Column(scale=8):
                    textbox.render()
                with gr.Column(scale=1, min_width=50):
                    submit_btn = gr.Button(value="Send", variant="primary", interactive=True)
            with gr.Row(elem_id="buttons") as button_row:
                upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
                downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
                regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
                clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
    with gr.Row():
        with gr.Column():
            gr.Examples(
                examples=[
                    [
                        f"{destination_path}/../../example.mp4",
                        "What is happening in this video?",
                    ],
                ],
                inputs=[video, textbox],
            )
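    # Event wiring: Send and ENTER share the generate callback; Regenerate
    # first trims the last exchange, then re-runs generate with the same
    # component lists so every output slot is updated consistently.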
    submit_btn.click(
        generate,
        [image, video, message, chatbot, textbox, temperature, top_p, max_output_tokens],
        [image, video, message, chatbot, textbox])

    textbox.submit(
        generate,
        [
            image,
            video,
            message,
            chatbot,
            textbox,
            temperature,
            top_p,
            max_output_tokens,
        ],
        [image, video, message, chatbot, textbox],
    )

    regenerate_btn.click(
        regenerate,
        [message, chatbot],
        [message, chatbot]).then(
            generate,
            [image, video, message, chatbot, textbox, temperature, top_p, max_output_tokens],
            [image, video, message, chatbot, textbox])

    clear_btn.click(
        clear_history,
        [message, chatbot],
        [image, video, message, chatbot, textbox])

demo.launch()