Spaces:
Running
Running
File size: 4,188 Bytes
f9ffba7 5c0fa7a f9ffba7 5c0fa7a 0fbaeac f5b3197 5c0fa7a 8369acb f9ffba7 5c0fa7a f9ffba7 5c0fa7a c991573 5c0fa7a f5b3197 5c0fa7a f5b3197 5c0fa7a f5b3197 7b25f43 f5b3197 5c0fa7a f9ffba7 f5b3197 7b25f43 f5b3197 8369acb f9ffba7 5c0fa7a c991573 f69d788 f9ffba7 5c0fa7a f5b3197 f9ffba7 5c0fa7a f69d788 5c0fa7a 284d70a 5c0fa7a f69d788 5c0fa7a f5b3197 5c0fa7a f5b3197 5c0fa7a e4cdaf0 c569fce 5c0fa7a 02a6bc0 5c0fa7a e4cdaf0 c0474e7 5c0fa7a f9ffba7 89a6f1a c0474e7 d7d3d5f 60dc27e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
from fastapi import FastAPI, Request
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from datetime import datetime
from zoneinfo import ZoneInfo
import httpx
from functools import lru_cache
app = FastAPI()
# --- Tải model ---
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto"
)
# --- Request body ---
class GenerateRequest(BaseModel):
prompt: str | list[str]
# --- Cache IP info ---
@lru_cache(maxsize=100)
def cache_location(ip: str, city: str, country: str, timezone: str):
return city, country, timezone
# --- Hàm lấy vị trí từ IP ---
async def get_location_from_ip(ip: str):
try:
async with httpx.AsyncClient(timeout=3) as client:
res = await client.get(f"https://ipinfo.io/{ip}?token=2d478668dc5662")
data = res.json()
city = data.get("city", "Ho Chi Minh City")
country = data.get("country", "VN")
timezone = data.get("timezone", "Asia/Ho_Chi_Minh")
cache_location(ip, city, country, timezone)
return city, country, timezone
except:
return "Ho Chi Minh City", "VN", "Asia/Ho_Chi_Minh"
# --- Cache thời tiết ---
@lru_cache(maxsize=100)
def cache_weather(city: str, desc: str, temp: float):
return desc, temp
async def get_weather(city: str):
try:
api_key = "b7d49ff5de091794a9adc8ea62ef0ac7"
async with httpx.AsyncClient(timeout=3) as client:
url = f"http://api.openweathermap.org/data/2.5/weather?q={city}&appid={api_key}&units=metric&lang=vi"
res = await client.get(url)
data = res.json()
temp = data["main"]["temp"]
desc = data["weather"][0]["description"]
cache_weather(city, desc, temp)
return desc, temp
except:
return None, None
@app.get("/")
def home():
return {"message": "API FastAI đã sẵn sàng phục vụ Túc chủ 👑"}
@app.post("/generate")
async def generate(req: GenerateRequest, request: Request):
prompt = req.prompt
if isinstance(prompt, list):
prompt = " ".join(prompt)
# Lấy IP từ request
client_ip = request.client.host
city, country, timezone = await get_location_from_ip(client_ip)
# --- Logic đặc biệt ---
# 1. Thời gian hiện tại
if "mấy giờ" in prompt or "thời gian" in prompt:
tz = ZoneInfo(timezone)
current_time = datetime.now(tz).strftime("%H:%M")
return {"response": f"Bây giờ ở {city} là {current_time}"}
# 2. Ngày hôm nay
if "ngày hôm nay" in prompt or "hôm nay là ngày" in prompt:
tz = ZoneInfo(timezone)
today = datetime.now(tz).strftime("%d/%m/%Y")
return {"response": f"Hôm nay ở {city} là ngày {today}"}
# 3. Thời tiết
if "thời tiết" in prompt:
desc, temp = await get_weather(city)
if desc and temp:
return {"response": f"Thời tiết ở {city}: {desc}, nhiệt độ {temp}°C"}
else:
return {"response": "Không lấy được dữ liệu thời tiết."}
# --- Fallback sang mô hình ---
final_prompt = f"Trả lời bằng tiếng Việt, ngắn gọn, đầy đủ câu, không lặp lại câu hỏi và tiêu đề trang trích dẫn: {prompt}"
inputs = tokenizer(final_prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=250, # Giảm để tăng tốc
temperature=0.5,
top_p=0.8,
do_sample=True,
repetition_penalty=1.3,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id
)
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
clean_result = result[len(final_prompt):].strip()
if not clean_result.endswith((".", "!", "?")):
clean_result += "..."
return {"response": clean_result} |