| """Inspired by the SantaCoder demo Huggingface space. | |
| Link: https://huggingface.co/spaces/bigcode/santacoder-demo/tree/main/app.py | |
| """ | |
import os
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

REPO = "replit/replit-code-v1-3b"
description = """# <h1 style="text-align: center; color: white;"><span style='color: #F26207;'> Code Completion with replit-code-v1-3b </span></h1>
<span style="color: white; text-align: center;"> The replit-code-v1-3b model is a 2.7B-parameter LLM trained on 20 languages from the Stack Dedup v1.2 dataset. You can click the button several times to keep completing your code.</span>"""
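# HUB_TOKEN is expected to be configured as a secret on the Space (or exported in the
# local environment); os.environ["HUB_TOKEN"] raises a KeyError if it is missing.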
token = os.environ["HUB_TOKEN"]
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| PAD_TOKEN = "<|pad|>" | |
| EOS_TOKEN = "<|endoftext|>" | |
| UNK_TOKEN = "<|unk|>" | |
| MAX_INPUT_TOKENS = 1024 # max tokens from context | |
tokenizer = AutoTokenizer.from_pretrained(REPO, use_auth_token=token, trust_remote_code=True)
tokenizer.truncation_side = "left"  # if the prompt is truncated, keep its last MAX_INPUT_TOKENS tokens (drop from the left)
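# Load the model: on GPU the weights are cast to bfloat16 to cut memory use roughly in
# half; on CPU the model is left in its default precision.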
if device == "cuda":
    model = AutoModelForCausalLM.from_pretrained(REPO, use_auth_token=token, trust_remote_code=True, low_cpu_mem_usage=True).to(device, dtype=torch.bfloat16)
else:
    model = AutoModelForCausalLM.from_pretrained(REPO, use_auth_token=token, trust_remote_code=True, low_cpu_mem_usage=True)
model.eval()
custom_css = """
.gradio-container {
    background-color: #0D1525;
    color: white;
}
#orange-button {
    background: #F26207 !important;
    color: white;
}
.cm-gutters {
    border: none !important;
}
"""
def post_processing(prompt, completion):
    return prompt + completion
    # completion = "<span style='color: #499cd5;'>" + completion + "</span>"
    # prompt = "<span style='color: black;'>" + prompt + "</span>"
    # code_html = f"<hr><br><pre style='font-size: 14px'><code>{prompt}{completion}</code></pre><br><hr>"
    # return code_html
def code_generation(prompt, max_new_tokens, temperature=0.2, seed=42, top_p=0.9, top_k=None, use_cache=True, repetition_penalty=1.0):
    # truncate the prompt to MAX_INPUT_TOKENS if it is too long
    x = tokenizer.encode(prompt, return_tensors="pt", max_length=MAX_INPUT_TOKENS, truncation=True).to(device)
    print("Prompt shape: ", x.shape)  # logged so the prompt length is visible in the Space logs
    set_seed(seed)
    y = model.generate(
        x,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        top_p=top_p,
        top_k=top_k,
        use_cache=use_cache,
        repetition_penalty=repetition_penalty,
    )
    completion = tokenizer.decode(y[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
    completion = completion[len(prompt):]
    return post_processing(prompt, completion)
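# A minimal smoke test of code_generation outside the Gradio UI (illustrative sketch,
# kept commented out so it never runs when the Space executes this file):
# print(code_generation("def fibonacci(n):", max_new_tokens=32, temperature=0.2, seed=42))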
demo = gr.Blocks(
    css=custom_css
)
with demo:
    gr.Markdown(value=description)
    with gr.Row():
        input_col, settings_col = gr.Column(scale=6), gr.Column(scale=6)
        with input_col:
            code = gr.Code(lines=28, label='Input', value="def sieve_eratosthenes(n):")
        with settings_col:
            with gr.Accordion("Generation Settings", open=True):
                max_new_tokens = gr.Slider(
                    minimum=8,
                    maximum=128,
                    step=1,
                    value=48,
                    label="Max Tokens",
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.5,
                    step=0.1,
                    value=0.2,
                    label="Temperature",
                )
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=1.9,
                    step=0.1,
                    value=1.0,
                    label="Repetition Penalty. 1.0 means no penalty.",
                )
                seed = gr.Slider(
                    minimum=0,
                    maximum=1000,
                    step=1,
                    label="Random Seed",
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    step=0.1,
                    value=0.9,
                    label="Top P",
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=64,
                    step=1,
                    value=4,
                    label="Top K",
                )
                use_cache = gr.Checkbox(
                    label="Use Cache",
                    value=True,
                )
    with gr.Row():
        run = gr.Button(elem_id="orange-button", value="Generate More Code")

    # with gr.Row():
    #     # _, middle_col_row_2, _ = gr.Column(scale=1), gr.Column(scale=6), gr.Column(scale=1)
    #     # with middle_col_row_2:
    #     output = gr.HTML(label="Generated Code")

    event = run.click(code_generation, [code, max_new_tokens, temperature, seed, top_p, top_k, use_cache, repetition_penalty], code, api_name="predict")

demo.queue(max_size=40).launch()
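# Because the click handler is registered with api_name="predict", the same endpoint can also
# be called programmatically with gradio_client once the Space is running. A rough sketch (the
# Space id below is a placeholder, and the positional arguments mirror the input list above):
#
#   from gradio_client import Client
#   client = Client("user/replit-code-demo")  # hypothetical Space id
#   result = client.predict(
#       "def sieve_eratosthenes(n):",    # prompt
#       48, 0.2, 42, 0.9, 4, True, 1.0,  # max_new_tokens, temperature, seed, top_p, top_k, use_cache, repetition_penalty
#       api_name="/predict",
#   )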