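# Evaluate LLaDA-8B-Instruct on Chinese poem line prediction: given one line of a
# couplet, the model must produce the following line (ftb) or the preceding line (btf).
# Example invocation (script name assumed):
#   python eval_poem.py --type ftb --gen_length 28 --block_length 28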
import torch
import json
import argparse
import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from generate import generate
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--gen_length', type=int, default=28, help='generation length')
    parser.add_argument('--block_length', type=int, default=28, help='block length')
    parser.add_argument('--cfg', type=float, default=0., help='classifier-free guidance scale')
    parser.add_argument('--eos_inf', action='store_true', help='set eos token logit to -inf')
    parser.add_argument('--type', type=str, default='ftb',
                        help='btf (backward to forward): predict previous sentence; '
                             'ftb (forward to backward): predict next sentence')
    return parser.parse_args()
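# Note: with the defaults gen_length == block_length, so the whole reply is
# presumably generated as a single block.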
args = parse_args()
# English translation of extra_prompt: ' Just output the sentence directly.'
extra_prompt = ' 直接输出句子即可。'
with open("data/poem_data.json", "r") as f:
poems = json.load(f)
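# Each entry in poem_data.json is assumed to hold a pair of adjacent poem lines,
# e.g. {"first": "...", "second": "..."}; only these two keys are used below.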
def next_prediction_pairs(poems):
    # Prompt: "<first line>的下一句是什么?" ("What is the next line after <first line>?")
    return [poem['first'] + "的下一句是什么?" + extra_prompt for poem in poems], [poem['second'] for poem in poems]

def prev_prediction_pairs(poems):
    # Prompt: "<second line>的上一句是什么?" ("What is the line before <second line>?")
    return [poem['second'] + "的上一句是什么?" + extra_prompt for poem in poems], [poem['first'] for poem in poems]
device = 'cuda'
model = AutoModelForCausalLM.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)
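# LLaDA is a masked-diffusion LM, so decoding goes through the custom generate()
# helper imported above rather than model.generate().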
if args.type == 'ftb':
    # forward to backward: given the first line, predict the second
    prompts, answers = next_prediction_pairs(poems)
elif args.type == 'btf':
    # backward to forward: given the second line, predict the first
    prompts, answers = prev_prediction_pairs(poems)
else:
    raise NotImplementedError(args.type)
acc = 0
for index in tqdm.tqdm(range(len(prompts))):
    prompt, answer = prompts[index], answers[index]
    # Wrap the question in the chat template, then tokenize.
    m = [{"role": "user", "content": prompt}]
    prompt = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
    input_ids = tokenizer(prompt)['input_ids']
    input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)
    out = generate(model, input_ids, steps=args.gen_length, temperature=0., cfg_scale=args.cfg,
                   gen_length=args.gen_length, block_length=args.block_length, logits_eos_inf=args.eos_inf)
    # Decode only the generated continuation; count a hit if it contains the reference line.
    out = tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)[0]
    if answer in out:
        acc += 1
print(args)
print(f'Accuracy: {acc / len(prompts)}')
print('*' * 20)