| """This script refers to the dialogue example of streamlit, the interactive | |
| generation code of chatglm2 and transformers. | |
| We mainly modified part of the code logic to adapt to the | |
| generation of our model. | |
| Please refer to these links below for more information: | |
| 1. streamlit chat example: | |
| https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps | |
| 2. chatglm2: | |
| https://github.com/THUDM/ChatGLM2-6B | |
| 3. transformers: | |
| https://github.com/huggingface/transformers | |
| Please run with the command `streamlit run path/to/web_demo.py | |
| --server.address=0.0.0.0 --server.port 7860`. | |
| Using `python path/to/web_demo.py` may cause unknown problems. | |
| """ | |
# isort: skip_file
import copy
import warnings
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional

import streamlit as st
import torch
from torch import nn
from transformers.generation.utils import (LogitsProcessorList,
                                           StoppingCriteriaList)
from transformers.utils import logging

from transformers import AutoTokenizer, AutoModelForCausalLM  # isort: skip

logger = logging.get_logger(__name__)

# model_name_or_path = "/root/finetune/models/internlm2-chat-7b"
model_name_or_path = "../finetune/work_dirs/assistTuner/merged"

@dataclass
class GenerationConfig:
    # this config is used for chat to provide more diversity
    max_length: int = 32768
    top_p: float = 0.8
    temperature: float = 0.8
    do_sample: bool = True
    repetition_penalty: float = 1.005
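# The dataclass is converted with `asdict()` in main() and its fields are
# forwarded to `generate_interactive` as generation keyword arguments.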

# run the token-by-token generation without autograd bookkeeping
@torch.inference_mode()
def generate_interactive(
    model,
    tokenizer,
    prompt,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
                                                List[int]]] = None,
    additional_eos_token_id: Optional[int] = None,
    **kwargs,
):
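    # Minimal re-implementation of a sampling loop (similar in spirit to
    # transformers' GenerationMixin sampling) so that the partially decoded
    # response can be yielded after every generated token for streaming display.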
    inputs = tokenizer([prompt], padding=True, return_tensors='pt')
    input_length = len(inputs['input_ids'][0])
    for k, v in inputs.items():
        inputs[k] = v.cuda()
    input_ids = inputs['input_ids']
    _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
    if generation_config is None:
        generation_config = model.generation_config
    generation_config = copy.deepcopy(generation_config)
    model_kwargs = generation_config.update(**kwargs)
    bos_token_id, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
        generation_config.bos_token_id,
        generation_config.eos_token_id,
    )
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    if additional_eos_token_id is not None:
        eos_token_id.append(additional_eos_token_id)
    has_default_max_length = kwargs.get(
        'max_length') is None and generation_config.max_length is not None
    if has_default_max_length and generation_config.max_new_tokens is None:
        warnings.warn(
            "Using `max_length`'s default "
            f'({repr(generation_config.max_length)}) to control the '
            'generation length. This behaviour is deprecated and will be '
            'removed from the config in v5 of Transformers -- we recommend '
            'using `max_new_tokens` to control the maximum length of the '
            'generation.',
            UserWarning,
        )
    elif generation_config.max_new_tokens is not None:
        generation_config.max_length = generation_config.max_new_tokens + \
            input_ids_seq_length
        if not has_default_max_length:
            logger.warning(
                f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) "
                f"and 'max_length'(={generation_config.max_length}) seem to "
                "have been set. 'max_new_tokens' will take precedence. "
                'Please refer to the documentation for more information. '
                '(https://huggingface.co/docs/transformers/main/'
                'en/main_classes/text_generation)')
    if input_ids_seq_length >= generation_config.max_length:
        input_ids_string = 'input_ids'
        logger.warning(
            f'Input length of {input_ids_string} is {input_ids_seq_length}, '
            f"but 'max_length' is set to {generation_config.max_length}. "
            'This can lead to unexpected behavior. You should consider '
            "increasing 'max_new_tokens'.")
    # 2. Set generation parameters if not already defined
    logits_processor = logits_processor if logits_processor is not None \
        else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None \
        else StoppingCriteriaList()
    logits_processor = model._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_seq_length,
        encoder_input_ids=input_ids,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
    )
    stopping_criteria = model._get_stopping_criteria(
        generation_config=generation_config,
        stopping_criteria=stopping_criteria)
    logits_warper = model._get_logits_warper(generation_config)
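    # NOTE: _get_logits_processor / _get_stopping_criteria / _get_logits_warper
    # are private (underscore-prefixed) transformers helpers; they build the
    # repetition-penalty processor and the temperature / top-p warpers from the
    # generation config. Their signatures may change between transformers
    # releases, so pin the version this demo was written against if needed.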
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    scores = None
    while True:
        model_inputs = model.prepare_inputs_for_generation(
            input_ids, **model_kwargs)
        # forward pass to get next token
        outputs = model(
            **model_inputs,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )
        next_token_logits = outputs.logits[:, -1, :]
        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)
        # sample
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        if generation_config.do_sample:
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(probs, dim=-1)
        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = model._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=False)
        # mark the sequence as finished once any EOS token is produced
        # (this demo always generates with batch size 1)
        unfinished_sequences = unfinished_sequences.mul(
            (min(next_tokens != i for i in eos_token_id)).long())
        output_token_ids = input_ids[0].cpu().tolist()
        output_token_ids = output_token_ids[input_length:]
        for each_eos_token_id in eos_token_id:
            if output_token_ids[-1] == each_eos_token_id:
                output_token_ids = output_token_ids[:-1]
        response = tokenizer.decode(output_token_ids)
        yield response
        # stop when each sentence is finished
        # or if we exceed the maximum length
        if unfinished_sequences.max() == 0 or stopping_criteria(
                input_ids, scores):
            break

def on_btn_click():
    del st.session_state.messages

# cache the model/tokenizer so Streamlit reruns don't reload them
@st.cache_resource
def load_model():
    model = (AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        trust_remote_code=True).to(torch.bfloat16).cuda())
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
                                              trust_remote_code=True)
    return model, tokenizer
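# NOTE: loading in torch.bfloat16 assumes a GPU with native bf16 support
# (e.g. Ampere or newer); on older cards, switching to torch.float16 is a
# reasonable fallback.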

def prepare_generation_config():
    with st.sidebar:
        max_length = st.slider('Max Length',
                               min_value=8,
                               max_value=32768,
                               value=32768)
        top_p = st.slider('Top P', 0.0, 1.0, 0.8, step=0.01)
        temperature = st.slider('Temperature', 0.0, 1.0, 0.7, step=0.01)
        st.button('Clear Chat History', on_click=on_btn_click)

    generation_config = GenerationConfig(max_length=max_length,
                                         top_p=top_p,
                                         temperature=temperature)

    return generation_config

user_prompt = '<|im_start|>user\n{user}<|im_end|>\n'
robot_prompt = '<|im_start|>assistant\n{robot}<|im_end|>\n'
cur_query_prompt = ('<|im_start|>user\n{user}<|im_end|>\n'
                    '<|im_start|>assistant\n')
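# For reference, a single-turn prompt assembled below by combine_history()
# looks like (InternLM2-style chat format):
#   <s><|im_start|>system\n{meta_instruction}<|im_end|>\n
#   <|im_start|>user\n{question}<|im_end|>\n
#   <|im_start|>assistant\n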

def combine_history(prompt):
    messages = st.session_state.messages
    meta_instruction = ('You are a helpful, honest, '
                        'and harmless AI assistant.')
    total_prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
    for message in messages:
        cur_content = message['content']
        if message['role'] == 'user':
            cur_prompt = user_prompt.format(user=cur_content)
        elif message['role'] == 'robot':
            cur_prompt = robot_prompt.format(robot=cur_content)
        else:
            raise RuntimeError(f"Unknown role: {message['role']}")
        total_prompt += cur_prompt
    total_prompt = total_prompt + cur_query_prompt.format(user=prompt)
    return total_prompt

def main():
    st.title('internlm2_5-7b-chat-assistant')

    # torch.cuda.empty_cache()
    print('load model begin.')
    model, tokenizer = load_model()
    print('load model end.')

    generation_config = prepare_generation_config()

    # Initialize chat history
    if 'messages' not in st.session_state:
        st.session_state.messages = []

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message['role'], avatar=message.get('avatar')):
            st.markdown(message['content'])

    # Accept user input
    if prompt := st.chat_input('What is up?'):
        # Display user message in chat message container
        with st.chat_message('user', avatar='user'):
            st.markdown(prompt)
        real_prompt = combine_history(prompt)
        # Add user message to chat history
        st.session_state.messages.append({
            'role': 'user',
            'content': prompt,
            'avatar': 'user'
        })

        with st.chat_message('robot', avatar='assistant'):
            message_placeholder = st.empty()
            for cur_response in generate_interactive(
                    model=model,
                    tokenizer=tokenizer,
                    prompt=real_prompt,
                    # 92542 is the id of the <|im_end|> token in the
                    # InternLM2 tokenizer, used as an extra stop token
                    additional_eos_token_id=92542,
                    device='cuda:0',
                    **asdict(generation_config),
            ):
                # Display robot response in chat message container
                message_placeholder.markdown(cur_response + '▌')
            message_placeholder.markdown(cur_response)
        # Add robot response to chat history
        st.session_state.messages.append({
            'role': 'robot',
            'content': cur_response,  # pylint: disable=undefined-loop-variable
            'avatar': 'assistant',
        })
        torch.cuda.empty_cache()

if __name__ == '__main__':
    main()