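# Couplet-continuation evaluation for LLaDA-8B-Instruct: each poem pair from
# data/poem_data.json is turned into a prompt asking for the next ('ftb') or
# previous ('btf') line; a generation counts as correct when it contains the
# reference line.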
import torch
import json
import argparse
import tqdm

from transformers import AutoTokenizer, AutoModelForCausalLM
from generate import generate


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--gen_length', type=int, default=28, help='generation length')
    parser.add_argument('--block_length', type=int, default=28, help='block length')
    parser.add_argument('--cfg', type=float, default=0., help='classifier-free guidance scale')
    parser.add_argument('--eos_inf', action='store_true', help='set eos token logit to -inf')
    parser.add_argument('--type', type=str, default='ftb',
                        help='btf (backward to forward): predict the previous sentence; '
                             'ftb (forward to backward): predict the next sentence')
    return parser.parse_args()


args = parse_args()
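
# Example invocation (the script filename below is a placeholder, not given in the repo):
#   python eval_poem.py --type ftb --gen_length 28 --block_length 28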


# Appended to every prompt; in English: "Just output the sentence directly."
extra_prompt = ' 直接输出句子即可。'

with open("data/poem_data.json", "r") as f:
    poems = json.load(f)


def next_prediction_pairs(poems):
    # Prompt: "<first line>的下一句是什么?" ("What is the line after <first>?"); the target is the second line.
    return [poem['first'] + "的下一句是什么?" + extra_prompt for poem in poems], [poem['second'] for poem in poems]


def prev_prediction_pairs(poems):
    # Prompt: "<second line>的上一句是什么?" ("What is the line before <second>?"); the target is the first line.
    return [poem['second'] + "的上一句是什么?" + extra_prompt for poem in poems], [poem['first'] for poem in poems]


device = 'cuda'
model = AutoModelForCausalLM.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True,
                                              torch_dtype=torch.bfloat16).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)

if args.type == 'ftb':
    prompts, answers = next_prediction_pairs(poems)
elif args.type == 'btf':
    prompts, answers = prev_prediction_pairs(poems)
else:
    raise NotImplementedError(args.type)


acc = 0
for index in tqdm.tqdm(range(len(prompts))):
    prompt, answer = prompts[index], answers[index]

    # Wrap the question in the chat template and tokenize it.
    m = [{"role": "user", "content": prompt}]
    prompt = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
    input_ids = tokenizer(prompt)['input_ids']
    input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)

    out = generate(model, input_ids, steps=args.gen_length, temperature=0., cfg_scale=args.cfg,
                   gen_length=args.gen_length, block_length=args.block_length, logits_eos_inf=args.eos_inf)

    # Decode only the newly generated tokens and count a substring match against the reference line.
    out = tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)[0]
    acc = acc + 1 if answer in out else acc

print(args)
print(f'Accuracy: {acc / len(prompts)}')
print('*' * 20)