import spaces
import gradio as gr
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
from datetime import datetime
from uuid import uuid4
import os
from pathlib import Path
from huggingface_hub import CommitScheduler
# TODO make it so that feedback is only saved on prev. example if user makes another obfuscation
# and changes slider but doesn't hit obfuscate
# TODO maybe make it save and reset if user hits submit feedback
# TODO sampling params for models
# TODO obfuscation ID?
# Converts text to the correct format for LoRA adapters in StyleRemix
def convert_data_to_format(text):
    output = f"### Original: {text}\n ### Rewrite:"
    return output
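# Example (illustrative, not from the original code):
#   convert_data_to_format("I went home.")
#   -> "### Original: I went home.\n ### Rewrite:"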
MODEL_PATHS = {
    "length_more": "hallisky/lora-length-long-llama-3-8b",
    "length_less": "hallisky/lora-length-short-llama-3-8b",
    "function_more": "hallisky/lora-function-more-llama-3-8b",
    "function_less": "hallisky/lora-function-less-llama-3-8b",
    "grade_more": "hallisky/lora-grade-highschool-llama-3-8b",
    "grade_less": "hallisky/lora-grade-elementary-llama-3-8b",
    "formality_more": "hallisky/lora-formality-formal-llama-3-8b",
    "formality_less": "hallisky/lora-formality-informal-llama-3-8b",
    "sarcasm_more": "hallisky/lora-sarcasm-more-llama-3-8b",
    "sarcasm_less": "hallisky/lora-sarcasm-less-llama-3-8b",
    "voice_passive": "hallisky/lora-voice-passive-llama-3-8b",
    "voice_active": "hallisky/lora-voice-active-llama-3-8b",
}
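# Each entry maps a style direction (e.g. "formality_more") to a LoRA adapter repo on the
# Hugging Face Hub; all adapters share the Meta-Llama-3-8B base model loaded below.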
FIRST_MODEL = list(MODEL_PATHS.keys())[5]  # "grade_less"; the only adapter loaded at startup
DESCRIPTION = """\
# Authorship Obfuscation
This Space demonstrates StyleRemix, an authorship obfuscation method that rewrites text using LoRA adapters fine-tuned from Llama-3-8B, each steering a single style element (length, function words, grade level, formality, sarcasm, voice, and writing type). Feel free to play with it, or duplicate the Space to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).
"""
# Load models
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
model_id = "meta-llama/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id, add_bos_token=True, add_eos_token=False, padding_side="left")
tokenizer.add_special_tokens({'pad_token': '<padding_token>'})
base_model = AutoModelForCausalLM.from_pretrained(model_id).to(device)  # device_map="auto" requires accelerate
base_model.resize_token_embeddings(len(tokenizer))  # Resize to add pad token. Value doesn't matter
# Load in the first model
model = PeftModel.from_pretrained(base_model, MODEL_PATHS[FIRST_MODEL], adapter_name=FIRST_MODEL).to(device)
# Load in the rest of the models
# for cur_adapter in MODEL_PATHS.keys():
#     if cur_adapter != FIRST_MODEL:
#         model.load_adapter(MODEL_PATHS[cur_adapter], adapter_name=cur_adapter)
model.eval()
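# The sliders in the UI select among the adapters above, but the demo as written only ever
# runs FIRST_MODEL. Below is a hedged sketch (not part of the original code, and never called)
# of how the slider weights could drive adapter selection once all adapters are loaded,
# assuming PEFT's add_weighted_adapter / set_adapter / delete_adapter API.
def activate_adapters_from_sliders(peft_model, sliders_dict, combined_name="combined"):
    """Merge the selected LoRA adapters, weighted by slider magnitude, and activate the result."""
    if not sliders_dict:
        return
    # Drop a previously merged adapter so the name can be reused on the next request
    if combined_name in peft_model.peft_config:
        peft_model.delete_adapter(combined_name)
    adapters = list(sliders_dict.keys())
    weights = [abs(v) for v in sliders_dict.values()]
    peft_model.add_weighted_adapter(adapters, weights, adapter_name=combined_name, combination_type="linear")
    peft_model.set_adapter(combined_name)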
# Session-specific user ID and local JSON log used to record obfuscations and feedback
user_id = str(uuid4())  # Generate a unique session-specific user ID
JSON_DATASET_DIR = Path("json_dataset")
JSON_DATASET_DIR.mkdir(parents=True, exist_ok=True)
JSON_DATASET_PATH = JSON_DATASET_DIR / f"train-{user_id}.json"
scheduler = CommitScheduler(
    repo_id="authorship-obfuscation-demo-data",
    repo_type="dataset",
    folder_path=JSON_DATASET_DIR,
    path_in_repo="data",
    every=0.5
)
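# CommitScheduler pushes everything in folder_path to the dataset repo in the background;
# `every` is in minutes, so 0.5 commits roughly every 30 seconds. Writes in save_data are
# wrapped in scheduler.lock so a file is not uploaded while it is still being appended to.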
def save_data(data):
    with scheduler.lock:
        with JSON_DATASET_PATH.open("a") as f:
            json.dump(data, f)
            f.write("\n")
def save_feedback(feedback_rating, feedback_text, latest_obfuscation):
    latest_obfuscation["feedback_rating"] = feedback_rating
    latest_obfuscation["feedback_text"] = feedback_text
    save_data(latest_obfuscation)
    return "No Feedback Selected", ""
def greet(input_text, length, function_words, grade_level, formality, sarcasm, voice, persuasive, descriptive, narrative, expository):
    # Parameter order must match the slider order built in the UI (formality comes before sarcasm)
    current_time = datetime.now().isoformat()

    sliders_dict = {}
    cur_keys = []
    cur_keys.append(("length_more" if length > 0 else (None if length == 0 else "length_less"), length))
    cur_keys.append(("function_more" if function_words > 0 else (None if function_words == 0 else "function_less"), function_words))
    cur_keys.append(("grade_more" if grade_level > 0 else (None if grade_level == 0 else "grade_less"), grade_level))
    cur_keys.append(("formality_more" if formality > 0 else (None if formality == 0 else "formality_less"), formality))
    cur_keys.append(("sarcasm_more" if sarcasm > 0 else (None if sarcasm == 0 else "sarcasm_less"), sarcasm))
    cur_keys.append(("voice_active" if voice > 0 else (None if voice == 0 else "voice_passive"), voice))
    for cur_key in cur_keys:
        if cur_key[0] is not None:
            sliders_dict[cur_key[0]] = cur_key[1]
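    # sliders_dict maps adapter name -> slider magnitude (e.g. {"formality_more": 0.7}).
    # The original demo never uses it, since only FIRST_MODEL is loaded; see
    # activate_adapters_from_sliders above for a hedged sketch of how it could be applied.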
    converted_text = convert_data_to_format(input_text)

    # Convert the input text to a model input
    inputs = tokenizer(converted_text, return_tensors="pt", max_length=2048, truncation=True).to(device)
    input_length = inputs.input_ids.shape[1]
    with torch.no_grad():
        # max_new_tokens bounds the generated rewrite (max_length would also count the prompt tokens);
        # top_p only takes effect when sampling is enabled
        outputs = model.generate(**inputs, max_new_tokens=100, do_sample=True, top_p=0.95)
    response = tokenizer.decode(outputs[0, input_length:], skip_special_tokens=True)
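    # Slicing from input_length keeps only the newly generated tokens, so the decoded
    # response is the rewrite without the "### Original: ... ### Rewrite:" prompt.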
    # Save the new obfuscation result and reset feedback
    latest_obfuscation = {
        "datetime": current_time,
        "user_id": user_id,
        "input_text": input_text,
        "sliders": {
            "length": length,
            "function_words": function_words,
            "grade_level": grade_level,
            "sarcasm": sarcasm,
            "formality": formality,
            "voice": voice,
            "persuasive": persuasive,
            "descriptive": descriptive,
            "narrative": narrative,
            "expository": expository
        },
        "input": input_text,
        "output": response,
        "feedback_rating": "No Feedback Selected",
        "feedback_text": ""
    }

    # Save the obfuscation result
    save_data(latest_obfuscation)

    return response, gr.update(interactive=True), gr.update(interactive=True), latest_obfuscation
def reset_sliders():
    return [0.5] * 6 + [0] * 4
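# Placeholder behaviour: the reset button promises to choose slider values "automatically
# (based on input text)", but reset_sliders simply returns 0.5 for the six style sliders and
# 0 for the four writing-type sliders; automatic selection is not implemented in this demo.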
def toggle_slider(checked, value):
    if checked:
        return gr.update(value=value, interactive=True)
    else:
        return gr.update(value=0, interactive=False)
def reset_writing_type_sliders(selected_type):
    reset_values = [gr.update(value=0, interactive=False) for _ in range(4)]
    if selected_type != "None":
        index = ["Persuasive", "Descriptive", "Narrative", "Expository"].index(selected_type)
        reset_values[index] = gr.update(value=0, interactive=True)
    return reset_values
def update_save_feedback_button(feedback_rating, feedback_text):
    if feedback_rating != "No Feedback Selected" or feedback_text.strip() != "":
        return gr.update(interactive=True), gr.update(visible=False)
    else:
        return gr.update(interactive=False), gr.update(visible=True)

def update_obfuscate_button(input_text):
    if input_text.strip() == "":
        return gr.update(interactive=False), gr.update(visible=True)
    else:
        return gr.update(interactive=True), gr.update(visible=False)

def check_initial_feedback_state(feedback_rating, feedback_text):
    return update_save_feedback_button(feedback_rating, feedback_text)
demo = gr.Blocks()

with demo:
    latest_obfuscation = gr.State({})
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(variant="panel"):
            gr.Markdown("# 1) Input Text\n### Enter the text to be obfuscated.")
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="The quick brown fox jumped over the lazy dogs."
            )

            gr.Markdown("# 2) Style Element Sliders\n### Adjust the style element sliders to the desired levels to steer the obfuscation.")
            reset_button = gr.Button("Choose slider values automatically (based on input text)")

            sliders = []
            slider_values = [
                ("Length (Shorter \u2192 Longer)", -1, 1, 0),
                ("Function Words (Fewer \u2192 More)", -1, 1, 0),
                ("Grade Level (Lower \u2192 Higher)", -1, 1, 0),
                ("Formality (Less \u2192 More)", -1, 1, 0),
                ("Sarcasm (Less \u2192 More)", -1, 1, 0),
                ("Voice (Passive \u2192 Active)", -1, 1, 0),
                ("Writing Type: Persuasive (None \u2192 More)", 0, 1, 0),
                ("Writing Type: Descriptive (None \u2192 More)", 0, 1, 0),
                ("Writing Type: Narrative (None \u2192 More)", 0, 1, 0),
                ("Writing Type: Expository (None \u2192 More)", 0, 1, 0)
            ]

            non_writing_type_sliders = []
            writing_type_sliders = []

            for idx, (label, min_val, max_val, default) in enumerate(slider_values):
                if "Writing Type" not in label:
                    with gr.Row():
                        checkbox = gr.Checkbox(label=label)
                        slider = gr.Slider(label=label, minimum=min_val, maximum=max_val, step=0.01, value=default, interactive=False)
                        checkbox.change(fn=toggle_slider, inputs=[checkbox, gr.State(default)], outputs=slider)
                        non_writing_type_sliders.append(slider)
                        sliders.append(slider)
            writing_type_radio = gr.Radio(
                label="Writing Type",
                choices=["None", "Persuasive", "Descriptive", "Narrative", "Expository"],
                value="None"
            )

            for idx, (label, min_val, max_val, default) in enumerate(slider_values):
                if "Writing Type" in label:
                    with gr.Row():
                        slider = gr.Slider(label=label, minimum=min_val, maximum=max_val, step=0.01, value=default, interactive=False)
                        writing_type_sliders.append(slider)
                        sliders.append(slider)

            # Bind the radio's change handler only after the writing-type sliders exist,
            # otherwise the outputs list is still empty when the event is registered
            writing_type_radio.change(fn=reset_writing_type_sliders, inputs=writing_type_radio, outputs=writing_type_sliders)
            obfuscate_button = gr.Button("Obfuscate Text", interactive=False)
            warning_message = gr.Markdown(
                "<div style='text-align: center; color: red;'>⚠️ Please enter text before obfuscating. ⚠️</div>", visible=True
            )

            reset_button.click(fn=reset_sliders, inputs=[], outputs=sliders)
            input_text.change(fn=update_obfuscate_button, inputs=input_text, outputs=[obfuscate_button, warning_message])
            # Initialize the button and warning message state on page load
            demo.load(fn=update_obfuscate_button, inputs=input_text, outputs=[obfuscate_button, warning_message])

        # with gr.Column(variant="panel"):
        #     gr.Markdown("# 3) Obfuscated Output")

        with gr.Column(variant="panel"):
            gr.Markdown("# 3) Obfuscated Output")
            output = gr.Textbox(label="Output", lines=3)

            gr.Markdown("## Feedback [Optional]")
            # Add thumbs up / thumbs down
            gr.Markdown("### Is the response good or bad?")
            feedback_rating = gr.Radio(choices=["No Feedback Selected", "Good 👍", "Bad 👎"], value="No Feedback Selected", interactive=False, label="Rate the Response")

            # Add feedback box
            gr.Markdown("### Provide any feedback on the obfuscation")
            feedback_text = gr.Textbox(label="Feedback", lines=3, interactive=False)
            obfuscate_button.click(
                fn=greet,
                inputs=[input_text] + sliders,
                outputs=[output, feedback_rating, feedback_text, latest_obfuscation])

            save_feedback_button = gr.Button("Save Feedback", interactive=False)
            feedback_warning_message = gr.Markdown(
                "<div style='text-align: center; color: red;'>⚠️ Please provide feedback or a rating before submitting. ⚠️</div>", visible=True
            )

            # Update the interactivity of the save_feedback_button based on feedback_rating and feedback_text
            feedback_rating.change(fn=update_save_feedback_button, inputs=[feedback_rating, feedback_text], outputs=[save_feedback_button, feedback_warning_message])
            feedback_text.change(fn=update_save_feedback_button, inputs=[feedback_rating, feedback_text], outputs=[save_feedback_button, feedback_warning_message])

            save_feedback_button.click(
                fn=save_feedback,
                inputs=[feedback_rating, feedback_text, latest_obfuscation],
                outputs=[feedback_rating, feedback_text]
            )

            # Initialize the save feedback button and warning message state on page load
            demo.load(fn=check_initial_feedback_state, inputs=[feedback_rating, feedback_text], outputs=[save_feedback_button, feedback_warning_message])

demo.launch()