from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
import torch
import torchaudio
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    AutoModelForCausalLM,
    AutoTokenizer,
)
import io
import tempfile
import os
import requests

app = FastAPI(title="Voice Assistant API - Simple Version")

# ============================================
# HUGGING FACE TOKEN (OPTIONAL)
# ============================================
# To use private models or get a larger quota, create a token at:
# https://huggingface.co/settings/tokens
HF_TOKEN = os.getenv("HF_TOKEN", None)

# ============================================
# LOAD MODELS
# ============================================
print("🔄 Loading models...")

# 1. WHISPER (speech-to-text)
print("📝 Loading Whisper...")
whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
whisper_model.eval()

# 2. LANGUAGE MODEL (small and fast)
print("🤖 Loading language model...")
# Spanish GPT-2 (medium variant)
llm_tokenizer = AutoTokenizer.from_pretrained("DeepESP/gpt2-spanish-medium")
llm_model = AutoModelForCausalLM.from_pretrained("DeepESP/gpt2-spanish-medium")
llm_model.eval()

print("✅ Models loaded!\n")

# ============================================
# DATA MODELS
# ============================================
class ChatRequest(BaseModel):
    question: str
    max_length: int = 150

class TTSRequest(BaseModel):
    text: str

# ============================================
# HELPER FUNCTIONS
# ============================================
def process_audio_file(audio_bytes):
    """Writes raw audio bytes to a temp WAV file and returns a mono 16 kHz numpy array."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp.write(audio_bytes)
        tmp_path = tmp.name

    try:
        # Load audio
        waveform, sample_rate = torchaudio.load(tmp_path)

        # Resample to 16 kHz (Whisper's expected rate)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)

        # Downmix to mono
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        return waveform.squeeze().numpy()
    finally:
        os.unlink(tmp_path)

# ============================================
# ENDPOINT 1: TRANSCRIPTION
# ============================================
@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
    """Converts WAV audio to text"""
    try:
        print(f"📥 Receiving audio: {file.filename}")

        # Process audio
        audio_bytes = await file.read()
        waveform = process_audio_file(audio_bytes)

        # Transcribe with Whisper (the language is auto-detected)
        input_features = whisper_processor(
            waveform, sampling_rate=16000, return_tensors="pt"
        ).input_features

        with torch.no_grad():
            predicted_ids = whisper_model.generate(input_features)

        transcription = whisper_processor.batch_decode(
            predicted_ids, skip_special_tokens=True
        )[0]

        print(f"✅ Transcribed: {transcription}")
        return JSONResponse({
            "text": transcription,
            "success": True
        })

    except Exception as e:
        print(f"❌ Error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

# ============================================
# ENDPOINT 2: AI CHAT
# ============================================
@app.post("/chat")
async def chat(request: ChatRequest):
    """Generates an AI answer (the assistant replies in Spanish)"""
    try:
        question = request.question.strip()
        print(f"💬 Question: {question}")

        if not question:
            # User-facing answer strings are kept in Spanish on purpose:
            # they are shown/spoken to the Spanish-speaking end user
            return JSONResponse({
                "answer": "No escuché ninguna pregunta",
                "success": False
            })
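        # Sketch (an assumption, not in the original flow): GPT-2 has a
        # 1024-token context window, so an unusually long question could
        # overflow the prompt built below. A minimal, hypothetical guard
        # would truncate it up front, e.g.:
        #
        #     MAX_QUESTION_CHARS = 500  # assumed limit, tune as needed
        #     question = question[:MAX_QUESTION_CHARS]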
        # Build the prompt (kept in Spanish on purpose: the LLM is a Spanish
        # model, and "Respuesta:" is used below to split off the answer)
        prompt = f"""Eres un asistente virtual amigable. Responde de forma breve y clara.

Pregunta: {question}
Respuesta:"""

        # Generate the answer
        inputs = llm_tokenizer.encode(prompt, return_tensors="pt")

        with torch.no_grad():
            outputs = llm_model.generate(
                inputs,
                # max_new_tokens counts only generated tokens; the original
                # max_length would also have counted the prompt tokens
                max_new_tokens=request.max_length,
                num_return_sequences=1,
                temperature=0.8,
                top_p=0.9,
                do_sample=True,
                pad_token_id=llm_tokenizer.eos_token_id,
                repetition_penalty=1.2
            )

        # Decode
        full_text = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract only the answer
        if "Respuesta:" in full_text:
            answer = full_text.split("Respuesta:")[-1].strip()
        else:
            answer = full_text.replace(prompt, "").strip()

        # Clean up and limit
        answer = answer.split("\n")[0].strip()  # first line only
        if len(answer) > 200:
            answer = answer[:200].rsplit(" ", 1)[0] + "..."

        # Fall back to a default answer if generation came back empty
        if not answer or len(answer) < 5:
            answer = "Interesante pregunta. Déjame pensar en eso."

        print(f"✅ Answer: {answer}")
        return JSONResponse({
            "answer": answer,
            "success": True
        })

    except Exception as e:
        print(f"❌ Error: {str(e)}")
        return JSONResponse({
            "answer": "Lo siento, tuve un problema procesando tu pregunta",
            "success": False
        })

# ============================================
# ENDPOINT 3: TTS (via the HF Inference API)
# ============================================
@app.post("/tts")
async def text_to_speech(request: TTSRequest):
    """
    Converts text to speech using the Hugging Face Inference API
    IMPORTANT: requires an internet connection
    """
    try:
        text = request.text.strip()
        print(f"🔊 Generating speech: {text[:50]}...")

        if not text:
            raise HTTPException(status_code=400, detail="Empty text")

        # Limit length
        if len(text) > 300:
            text = text[:300] + "..."

        # Use the Hugging Face Inference API for TTS
        # Model: Facebook MMS TTS, Spanish
        API_URL = "https://api-inference.huggingface.co/models/facebook/mms-tts-spa"
        headers = {}
        if HF_TOKEN:
            headers["Authorization"] = f"Bearer {HF_TOKEN}"

        # Call the API
        response = requests.post(
            API_URL,
            headers=headers,
            json={"inputs": text},
            timeout=30
        )

        if response.status_code == 200:
            print(f"✅ Audio generated: {len(response.content)} bytes")
            return StreamingResponse(
                io.BytesIO(response.content),
                media_type="audio/flac",
                headers={
                    "Content-Disposition": "attachment; filename=speech.flac"
                }
            )
        else:
            print(f"❌ TTS API error: {response.status_code}")
            raise HTTPException(
                status_code=response.status_code,
                detail=f"TTS error: {response.text}"
            )

    except HTTPException:
        # Re-raise HTTP errors as-is instead of turning them into 500s below
        raise
    except requests.exceptions.Timeout:
        print("⏱️ TTS timeout")
        raise HTTPException(status_code=504, detail="Timeout generating audio")
    except Exception as e:
        print(f"❌ Error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

# ============================================
# ENDPOINT 4: FULL PIPELINE
# ============================================
@app.post("/complete")
async def complete_conversation(file: UploadFile = File(...)):
    """
    Full pipeline: audio → text → AI → audio
    """
    try:
        print("\n" + "=" * 50)
        print("🔄 FULL PIPELINE STARTED")
        print("=" * 50)

        # STEP 1: Transcribe
        print("\n📝 STEP 1: Transcribing...")
        audio_bytes = await file.read()
        waveform = process_audio_file(audio_bytes)

        input_features = whisper_processor(
            waveform, sampling_rate=16000, return_tensors="pt"
        ).input_features

        with torch.no_grad():
            predicted_ids = whisper_model.generate(input_features)

        transcription = whisper_processor.batch_decode(
            predicted_ids, skip_special_tokens=True
        )[0].strip()

        print(f"✅ Transcription: {transcription}")
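        # Note (an assumption, not in the original code): Whisper auto-detects
        # the language above. On recent transformers versions it can be pinned
        # to Spanish, which avoids misdetection on short clips, e.g.:
        #
        #     forced = whisper_processor.get_decoder_prompt_ids(
        #         language="spanish", task="transcribe")
        #     predicted_ids = whisper_model.generate(
        #         input_features, forced_decoder_ids=forced)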
        if not transcription or len(transcription) < 3:
            transcription = "No te escuché bien"

        # STEP 2: Generate the answer
        print("\n🤖 STEP 2: Generating AI answer...")
        prompt = f"""Eres un asistente virtual amigable. Responde breve.

Pregunta: {transcription}
Respuesta:"""

        inputs = llm_tokenizer.encode(prompt, return_tensors="pt")
        with torch.no_grad():
            outputs = llm_model.generate(
                inputs,
                max_new_tokens=150,  # generated tokens only (see /chat)
                temperature=0.8,
                top_p=0.9,
                do_sample=True,
                pad_token_id=llm_tokenizer.eos_token_id,
                repetition_penalty=1.2
            )

        full_text = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
        if "Respuesta:" in full_text:
            answer = full_text.split("Respuesta:")[-1].strip()
        else:
            answer = full_text.replace(prompt, "").strip()

        answer = answer.split("\n")[0].strip()
        if len(answer) > 200:
            answer = answer[:200].rsplit(" ", 1)[0] + "..."
        if not answer or len(answer) < 5:
            answer = "Entiendo tu pregunta."

        print(f"✅ Answer: {answer}")

        # STEP 3: Generate audio
        print("\n🔊 STEP 3: Generating audio...")
        API_URL = "https://api-inference.huggingface.co/models/facebook/mms-tts-spa"
        headers = {}
        if HF_TOKEN:
            headers["Authorization"] = f"Bearer {HF_TOKEN}"

        response = requests.post(
            API_URL,
            headers=headers,
            json={"inputs": answer},
            timeout=30
        )

        if response.status_code != 200:
            print("⚠️ TTS failed, falling back to a text response")
            return JSONResponse({
                "transcription": transcription,
                "answer": answer,
                "audio_error": True
            })

        print("✅ Audio generated successfully")
        print("=" * 50 + "\n")

        # Return the audio with metadata in the headers.
        # HTTP header values must be Latin-1 encodable, so characters outside
        # that range are replaced to keep Starlette from raising here.
        def header_safe(value: str) -> str:
            return value.encode("latin-1", "replace").decode("latin-1")

        return StreamingResponse(
            io.BytesIO(response.content),
            media_type="audio/flac",
            headers={
                "X-Transcription": header_safe(transcription),
                "X-Answer": header_safe(answer),
                "Content-Disposition": "attachment; filename=response.flac"
            }
        )

    except Exception as e:
        print(f"❌ FULL PIPELINE ERROR: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

# ============================================
# INFO ENDPOINTS
# ============================================
@app.get("/")
async def root():
    return {
        "message": "🤖 ESP32 Voice Assistant API",
        "version": "2.0 - Simplified",
        "status": "online",
        "endpoints": {
            "POST /transcribe": "WAV audio → text",
            "POST /chat": "question → AI answer",
            "POST /tts": "text → audio",
            "POST /complete": "audio → audio (recommended)"
        },
        "models": {
            "stt": "openai/whisper-small",
            "llm": "DeepESP/gpt2-spanish-medium",
            "tts": "facebook/mms-tts-spa (API)"
        }
    }

@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "models_loaded": {
            "whisper": whisper_model is not None,
            "llm": llm_model is not None,
            "tts": "external API"
        }
    }

@app.get("/test")
async def test_endpoint():
    """Simple test endpoint"""
    return {
        "message": "Server running correctly!",
        "test": "OK"
    }

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
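# ------------------------------------------------------------
# Example usage (hypothetical file names, assuming the server is
# running locally on port 7860):
#
#   curl -X POST http://localhost:7860/transcribe \
#        -F "file=@question.wav"
#
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"question": "¿Cómo estás?"}'
#
#   curl -X POST http://localhost:7860/complete \
#        -F "file=@question.wav" --output response.flac
# ------------------------------------------------------------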