| ## zR-Llama-1B-chatglm2-6b-tokenizer | |
| 本模型是基于 [build_MiniLLM_from_scratch 开源框架](https://github.com/Tongjilibo/build_MiniLLM_from_scratch) 自行训练的一个1B模型。 | |
| ## 模型参数 | |
| + 1B 参数量 | |
| + 训练语料670亿。 | |
| + 模型支持token长度 896 | |
| ## 预训练模型 | |
| + 使用 [build_MiniLLM_from_scratch 开源框架](https://github.com/Tongjilibo/build_MiniLLM_from_scratch) 的预训练数据集,自己完成 Tokenize 过程。 | |
| + 使用 8 x 80GB A800 GPU 训练。 | |
| + 训练 1 Epoch,bs=32 (每张卡) , lr=1.5e-4。 | |
| + 共耗时 1 天。 | |
| ## SFT模型 | |
| + 使用 [build_MiniLLM_from_scratch 开源框架](https://github.com/Tongjilibo/build_MiniLLM_from_scratch) 提供的全部数据集 | |
| + 使用 单卡A800 微调。 | |
| + 微调 5 Epoch, bs=8, lr=2e-5。 | |
| + 共耗时 3 天 12 小时。 | |
| ## 使用模型 | |
| ```python | |
| import os | |
| import torch | |
| from transformers import AutoTokenizer, LlamaForCausalLM | |
| max_length = 896 | |
| HUMAN = '<human>' | |
| ROBOT = '<robot>' | |
| def build_prompt(query, history) -> str: | |
| texts = '' | |
| for user_input, response in history: | |
| texts += f'{HUMAN}{user_input}{ROBOT}{response}' | |
| texts += f'{HUMAN}{query}{ROBOT}' | |
| return texts | |
| def build_cli_history(history): | |
| prompt = '' | |
| for query, response in history: | |
| prompt += f"\n\nUser:{query.strip()}" | |
| prompt += f"\n\nRobot:{response.strip()}" | |
| return prompt | |
| device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
| tokenizer = AutoTokenizer.from_pretrained("zRzRzRzRzRzRzR/zR-Llama-1b-ChatGLM2-6b-tokenizer", trust_remote_code=True) | |
| model = LlamaForCausalLM.from_pretrained("zRzRzRzRzRzRzR/zR-Llama-1b-ChatGLM2-6b-tokenizer").to(device) | |
| history = [] | |
| clear_command = 'cls' if os.name == 'nt' else 'clear' | |
| while True: | |
| query = input('\n输入:') | |
| if query.strip() == "stop": | |
| break | |
| if query.strip() == "clear": | |
| history = [] | |
| os.system(clear_command) | |
| continue | |
| inputs = tokenizer.encode(build_prompt(query, history), return_tensors='pt', add_special_tokens=False).to(device) | |
| response = model.generate(inputs) | |
| response = tokenizer.decode(response[0].cpu(), skip_special_tokens=True) | |
| os.system(clear_command) | |
| print(build_cli_history(history + [(query, response)]), flush=True) | |
| ``` | |