import os
import tempfile
from pathlib import Path

import tiktoken
import torch
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from huggingface_hub import hf_hub_download
from pydantic import BaseModel

from model import GPT, GPTConfig
# Get the absolute path to the templates directory
TEMPLATES_DIR = os.path.join(os.path.dirname(__file__), "templates")

MODEL_ID = "sagargurujula/text-generator"

# Initialize FastAPI
app = FastAPI(title="GPT Text Generator")

# Templates with absolute path
templates = Jinja2Templates(directory=TEMPLATES_DIR)
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Cache downloads in the system's temporary directory.
# Note: newer transformers releases deprecate TRANSFORMERS_CACHE in favor of HF_HOME.
cache_dir = Path(tempfile.gettempdir()) / "model_cache"
os.environ['TRANSFORMERS_CACHE'] = str(cache_dir)
os.environ['HF_HOME'] = str(cache_dir)
# Load model from the Hugging Face Hub
def load_model():
    try:
        # Download the checkpoint file from the Hub, authenticating if a token is set
        model_path = hf_hub_download(
            repo_id=MODEL_ID,
            filename="best_model.pth",
            cache_dir=cache_dir,
            token=os.environ.get('HF_TOKEN'),  # get token from environment variable
        )

        # Initialize our custom GPT model
        model = GPT(GPTConfig())

        # Load the trained weights from the checkpoint
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()

        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
# Load the model once at startup
model = load_model()
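
# Since load_model() reads HF_TOKEN from the environment, the token should be
# set before launching the server (hypothetical placeholder value):
#   export HF_TOKEN=hf_xxxxxxxx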

# Define the request body
class TextInput(BaseModel):
    text: str
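
# Example request body matching this schema: {"text": "Once upon a time"}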

# Text-generation endpoint. The original file omits the route decorator;
# @app.post("/generate") is an assumed path.
@app.post("/generate")
async def generate_text(input: TextInput):
    # Encode the prompt with the GPT-2 BPE tokenizer
    enc = tiktoken.get_encoding('gpt2')
    input_ids = torch.tensor([enc.encode(input.text)]).to(device)

    # Generate new tokens one at a time
    generated_tokens = []
    num_tokens_to_generate = 50  # generate 50 new tokens

    with torch.no_grad():
        current_ids = input_ids
        for _ in range(num_tokens_to_generate):
            # Get model predictions for the current sequence
            logits, _ = model(current_ids)
            # Greedy decoding: pick the highest-scoring next token
            next_token = logits[0, -1, :].argmax().item()
            generated_tokens.append(next_token)
            # Append the new token to the running sequence
            current_ids = torch.cat(
                [current_ids, torch.tensor([[next_token]]).to(device)], dim=1
            )

    # Decode only the newly generated tokens
    generated_text = enc.decode(generated_tokens)

    # Return both the prompt and the generated continuation
    return {
        "input_text": input.text,
        "generated_text": generated_text,
    }
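
# The loop above decodes greedily (pure argmax), which is deterministic and
# prone to repetition. A minimal sketch of temperature + top-k sampling as a
# drop-in replacement for the argmax line (hypothetical helper, not part of
# the original app):
def sample_next_token(logits: torch.Tensor, temperature: float = 0.8, top_k: int = 50) -> int:
    # logits: 1-D tensor of vocabulary logits for the last position
    scaled = logits / temperature
    top_values, top_indices = torch.topk(scaled, top_k)  # keep the top-k candidates
    probs = torch.softmax(top_values, dim=-1)            # renormalize over them
    choice = torch.multinomial(probs, num_samples=1)     # sample one candidate
    return top_indices[choice].item()

# Usage inside the loop would be:
#   next_token = sample_next_token(logits[0, -1, :])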

# Root route serves the HTML template. The original file omits the decorator;
# @app.get("/") follows from the comment's mention of the root route.
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    return templates.TemplateResponse(
        "index.html",
        {"request": request, "title": "GPT Text Generator"}
    )

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="127.0.0.1", port=8080)

# To run the app, use the command: uvicorn app:app --reload
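
# Example request once the server is running (assumes the /generate route
# added above; port 8080 applies when started via `python app.py`):
#   curl -X POST http://127.0.0.1:8080/generate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Once upon a time"}'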