from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from openai import OpenAI
def hide(original_input, hide_model, tokenizer):
    """Anonymize the user's prompt by paraphrasing it with the local hide model."""
    hide_template = """<s>Paraphrase the text:%s\n\n"""
    input_text = hide_template % original_input
    inputs = tokenizer(input_text, return_tensors='pt').to(hide_model.device)
    pred = hide_model.generate(
        **inputs,
        generation_config=GenerationConfig(
            # Allow the paraphrase to run about 30% longer than the prompt.
            max_new_tokens=int(len(inputs['input_ids'][0]) * 1.3),
            do_sample=False,
            num_beams=3,
            repetition_penalty=5.0,
        ),
    )
    # Keep only the newly generated tokens, dropping the echoed prompt.
    pred = pred.cpu()[0][len(inputs['input_ids'][0]):]
    hide_input = tokenizer.decode(pred, skip_special_tokens=True)
    return hide_input
def seek(hide_input, hide_output, original_input, seek_model, tokenizer):
    """Recover the answer to the original prompt from the remote model's reply
    to the hidden prompt, using the local seek model."""
    seek_template = """<s>Convert the text:\n%s\n\n%s\n\nConvert the text:\n%s\n\n"""
    # One-shot demonstration: the (hidden input -> original input) pair shows the
    # entity mapping, after which the model converts the hidden output the same way.
    input_text = seek_template % (hide_input, original_input, hide_output)
    inputs = tokenizer(input_text, return_tensors='pt').to(seek_model.device)
    pred = seek_model.generate(
        **inputs,
        generation_config=GenerationConfig(
            max_new_tokens=int(len(inputs['input_ids'][0]) * 1.3),
            do_sample=False,
            num_beams=3,
        ),
    )
    # Keep only the newly generated tokens, dropping the echoed prompt.
    pred = pred.cpu()[0][len(inputs['input_ids'][0]):]
    original_output = tokenizer.decode(pred, skip_special_tokens=True)
    return original_output
def get_gpt_output(prompt, api_key=None):
    """Send the (hidden) prompt to the remote LLM and return its reply."""
    if not api_key:
        raise ValueError('An OpenAI API key is required for this function.')
    client = OpenAI(api_key=api_key)
    completion = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "user", "content": prompt}
        ],
    )
    return completion.choices[0].message.content
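# Minimal end-to-end sketch (not part of the original source): the checkpoint
# path and API key are hypothetical placeholders, and a single checkpoint is
# assumed to serve as both the hide and seek model.
if __name__ == '__main__':
    model_path = 'path/to/hide-seek-checkpoint'  # hypothetical; substitute your model
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path).eval()

    original_input = 'Summarize: Alice Smith met Bob Jones in Paris on May 3rd.'
    hide_input = hide(original_input, model, tokenizer)               # anonymized prompt
    hide_output = get_gpt_output(hide_input, api_key='YOUR_API_KEY')  # remote LLM sees only hidden text
    original_output = seek(hide_input, hide_output, original_input, model, tokenizer)
    print(original_output)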