S-Dreamer committed (verified)
Commit b2ce241 · 1 Parent(s): 628c76c

Update app.py

Files changed (1):
app.py +62 −40
app.py CHANGED
@@ -1,59 +1,81 @@
-
-import gradio as gr
+import os
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+import gradio as gr
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+
+# --- Config ---
+MODEL_PATH = os.getenv("MODEL_PATH", "WhiteRabbitNeo/WhiteRabbitNeo-2.5-Qwen-2.5-Coder-7B")
+LOAD_IN_4BIT = os.getenv("LOAD_IN_4BIT", "true").lower() == "true"
+MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", 2048))
+DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
 # --- Model Setup ---
-model_path = "WhiteRabbitNeo/WhiteRabbitNeo-2.5-Qwen-2.5-Coder-7B"
-
+quant_config = BitsAndBytesConfig(
+    load_in_4bit=LOAD_IN_4BIT,
+    bnb_4bit_compute_dtype=torch.bfloat16,
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_quant_type="nf4",
+)
 
 model = AutoModelForCausalLM.from_pretrained(
-    model_path,
-    torch_dtype=torch.float16,
-    device_map="auto",
-    load_in_4bit=False,
-    trust_remote_code=False,
+    MODEL_PATH,
+    quantization_config=quant_config if LOAD_IN_4BIT else None,
+    torch_dtype=torch.bfloat16 if DEVICE != "cpu" else torch.float32,
+    device_map="auto" if DEVICE != "cpu" else None,
+    trust_remote_code=True,
 )
 
-tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-
-def generate_code(user_prompt, temperature=0.75, top_p=1.0, max_tokens=2048, top_k=50):
-    tokens = tokenizer.encode(user_prompt)
-    tokens = torch.LongTensor(tokens).unsqueeze(0).to("cuda")
+tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
+
+# --- Generation Function ---
+def generate_code(user_prompt, temperature=0.7, top_p=0.95, max_tokens=1024, top_k=50):
+    if not user_prompt.strip():
+        return "⚠️ Please enter a valid prompt."
 
+    inputs = tokenizer(user_prompt, return_tensors="pt", truncation=True)
+    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
+
     with torch.no_grad():
         output = model.generate(
-            input_ids=tokens,
-            max_length=len(tokens[0]) + max_tokens,
+            **inputs,
+            max_new_tokens=int(max_tokens),
             do_sample=True,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            num_return_sequences=1,
+            temperature=float(temperature),
+            top_p=float(top_p),
+            top_k=int(top_k),
             pad_token_id=tokenizer.eos_token_id,
         )
 
-    # Decode only newly generated tokens
-    generated_tokens = output[0][len(tokens[0]):]
-    code_string = tokenizer.decode(generated_tokens, skip_special_tokens=True)
-    # Wrap in python code block for Gradio Markdown rendering
-    return f"```python\n{code_string}\n```"
+    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+    new_text = generated_text[len(user_prompt):].strip()
+    safe_code = new_text.replace("```", "`\u200b``")  # Prevent Markdown escape issues
+    return f"```python\n{safe_code}\n```"
 
-# --- Gradio UI ---
+# --- UI ---
 with gr.Blocks(title="Spec Kit Copilot") as demo:
-    with gr.Tab("AI Code Generation"):
-        user_input = gr.Textbox(label="Describe code to generate", lines=4, placeholder="E.g., Python function to sort a list")
-        temperature = gr.Slider(0.0, 1.0, 0.75, label="Temperature")
-        top_p = gr.Slider(0.0, 1.0, 1.0, label="Top-p")
-        max_tokens = gr.Slider(256, 4096, 2048, step=128, label="Max Tokens")
-        top_k = gr.Slider(0, 100, 50, label="Top-k")
-        generate_btn = gr.Button("Generate Code")
-        preview = gr.Markdown()
-
-        generate_btn.click(
-            fn=generate_code,
-            inputs=[user_input, temperature, top_p, max_tokens, top_k],
-            outputs=preview
-        )
+    gr.Markdown("### 🧠 Spec Kit Copilot — AI Code Generator (Hugging Face Space Edition)")
+    with gr.Row():
+        with gr.Column(scale=2):
+            user_input = gr.Textbox(
+                label="Describe code to generate",
+                lines=4,
+                placeholder="E.g., Python function to parse a JSON file and pretty-print it."
+            )
+            with gr.Row():
+                temperature = gr.Slider(0.0, 1.0, 0.7, label="Temperature")
+                top_p = gr.Slider(0.0, 1.0, 0.95, label="Top-p")
+            with gr.Row():
+                max_tokens = gr.Slider(256, 4096, 1024, step=128, label="Max Tokens")
+                top_k = gr.Slider(0, 100, 50, label="Top-k")
+            generate_btn = gr.Button("🚀 Generate Code")
+        with gr.Column(scale=3):
+            preview = gr.Markdown("")
+
+    generate_btn.click(
+        fn=generate_code,
+        inputs=[user_input, temperature, top_p, max_tokens, top_k],
+        outputs=preview,
+    )
 
 if __name__ == "__main__":
-    demo.launch(server_name="0.0.0.0", server_port=7860)
+    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))
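
One caveat with the new defaults: LOAD_IN_4BIT defaults to "true", but bitsandbytes 4-bit quantization generally requires a CUDA backend, so the mps/cpu fallbacks that the DEVICE detection allows would need the flag set to "false". A minimal guard, reusing the commit's env-var convention (this guard itself is not part of the commit), might look like:

```python
import os
import torch

# Hypothetical guard (not in this commit): only honor LOAD_IN_4BIT when CUDA
# is actually available, since bitsandbytes 4-bit loading targets CUDA devices.
LOAD_IN_4BIT = (
    os.getenv("LOAD_IN_4BIT", "true").lower() == "true"
    and torch.cuda.is_available()
)
```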
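Because the new code reads MODEL_PATH, LOAD_IN_4BIT, and MAX_NEW_TOKENS from the environment at import time, generate_code() can be exercised outside the Space without launching the server (demo.launch() stays behind the __main__ guard). A minimal smoke-test sketch, assuming app.py is on the import path and swapping in sshleifer/tiny-gpt2 as a stand-in model so the test runs cheaply on CPU:

```python
import os

# Hypothetical test setup (not part of the commit): a tiny stand-in model
# and 4-bit disabled so the run works without a GPU. app.py reads these
# env vars when imported.
os.environ["MODEL_PATH"] = "sshleifer/tiny-gpt2"
os.environ["LOAD_IN_4BIT"] = "false"

import app  # loads the model and builds the Blocks UI; launch() is not called

result = app.generate_code(
    "def reverse_string(s):",
    temperature=0.7, top_p=0.95, max_tokens=64, top_k=50,
)
assert result.startswith("```python")  # output is wrapped for Markdown rendering
print(result)
```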