# jem-mini / app.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import gradio as gr
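# Requires: torch, transformers, peft, gradio, and accelerate
# (accelerate backs device_map="auto" below).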
# ============================================
# Model + LoRA paths
# ============================================
BASE_MODEL = "beomi/Llama-3-Open-Ko-8B"
LORA_PATH = "./lora"  # upload the lora folder to the Space repo
# ============================================
# Load tokenizer and model
# ============================================
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers ship without a pad token; reuse EOS
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
# Attach the LoRA adapter from the local ./lora folder
# (PeftModel.from_pretrained has no is_local argument; a local path is used directly).
model = PeftModel.from_pretrained(model, LORA_PATH)
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.eval()  # inference only; disables dropout
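# Llama-3-style chat formatting ends each turn with <|eot_id|>, so generation
# should stop on either that token or the regular EOS. This assumes the base
# tokenizer carries the Llama-3 special tokens, as Llama-3 checkpoints do.
TERMINATORS = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]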
# ============================================
# AI personality
# ============================================
AI_PERSONALITY = """
You are a friend who sincerely listens to the user.
When the user starts a conversation, reply in a natural, everyday tone.
Don't ramble; be empathetic and keep your answers short and warm.
"""
history = []
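# NOTE: a module-level history is shared by every visitor to the Space;
# fine for a single-user demo, not for concurrent sessions.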
# ============================================
# Chat function
# ============================================
def chat(user_input):
    history.append({"role": "user", "content": user_input})

    # Build a Llama-3-style prompt: system message first, then each turn
    # closed with <|eot_id|> (not <|eot|>), then an open assistant header.
    prompt = "<|begin_of_text|>"
    prompt += f"<|start_header_id|>system<|end_header_id|>\n\n{AI_PERSONALITY.strip()}<|eot_id|>"
    for turn in history:
        prompt += f"<|start_header_id|>{turn['role']}<|end_header_id|>\n\n{turn['content']}<|eot_id|>"
    prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,  # required for temperature/top_p to take effect
            temperature=0.6,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=TERMINATORS,
        )

    # Decode only the newly generated tokens; skip_special_tokens already
    # strips <|eot_id|>, so no extra split on it is needed.
    response = tokenizer.decode(
        output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    ).strip()

    history.append({"role": "assistant", "content": response})
    if len(history) > 10:
        history[:] = history[-10:]  # keep only the last 10 messages (5 turns)
    return response
# ============================================
# Launch the Gradio interface
# ============================================
iface = gr.Interface(fn=chat, inputs="text", outputs="text")
iface.launch()
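# Sketch of a per-session alternative (assumes a recent gradio release):
# gr.ChatInterface keeps history per browser session instead of the shared
# global list above; its fn receives (message, history), so chat() could be
# wrapped to ignore the second argument, e.g.:
#
#   iface = gr.ChatInterface(fn=lambda message, _history: chat(message))
#   iface.launch()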