from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
import torch
import torchaudio
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    AutoModelForCausalLM,
    AutoTokenizer,
)
import io
import tempfile
import os
import requests

app = FastAPI(title="Voice Assistant API - Simple Version")

# Optional Hugging Face token for the hosted TTS calls below; anonymous
# requests work but are rate-limited more strictly.
HF_TOKEN = os.getenv("HF_TOKEN", None)

print("🔄 Cargando modelos...") |
|
|
|
|
|
|
|
|
print("📝 Cargando Whisper...") |
|
|
whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-small") |
|
|
whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small") |
|
|
whisper_model.eval() |
|
|
|
|
|
|
|
|
print("🤖 Cargando modelo de lenguaje...") |
|
|
|
|
|
llm_tokenizer = AutoTokenizer.from_pretrained("DeepESP/gpt2-spanish-medium") |
|
|
llm_model = AutoModelForCausalLM.from_pretrained("DeepESP/gpt2-spanish-medium") |
|
|
llm_model.eval() |
|
|
|
|
|
print("✅ Modelos cargados!\n") |


class ChatRequest(BaseModel):
    question: str
    max_length: int = 150  # total generation length in tokens, prompt included


class TTSRequest(BaseModel):
    text: str


def process_audio_file(audio_bytes):
    """Decode raw audio bytes and convert them to 16 kHz mono for Whisper."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp.write(audio_bytes)
        tmp_path = tmp.name

    try:
        waveform, sample_rate = torchaudio.load(tmp_path)

        # Whisper models expect 16 kHz input.
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)

        # Downmix multi-channel audio to mono.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        return waveform.squeeze().numpy()
    finally:
        os.unlink(tmp_path)
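# Note: torchaudio.load also accepts a file-like object, so a sketch that
# avoids the temp file entirely would be:
#     waveform, sample_rate = torchaudio.load(io.BytesIO(audio_bytes))
# The temp-file path is kept above because some audio backends only accept paths.
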
@app.post("/transcribe") |
|
|
async def transcribe_audio(file: UploadFile = File(...)): |
|
|
"""Convierte audio WAV a texto""" |
|
|
try: |
|
|
print(f"📥 Recibiendo audio: {file.filename}") |
|
|
|
|
|
|
|
|
audio_bytes = await file.read() |
|
|
waveform = process_audio_file(audio_bytes) |
|
|
|
|
|
|
|
|
input_features = whisper_processor( |
|
|
waveform, |
|
|
sampling_rate=16000, |
|
|
return_tensors="pt" |
|
|
).input_features |
|
|
|
|
|
with torch.no_grad(): |
|
|
predicted_ids = whisper_model.generate(input_features) |
|
|
|
|
|
transcription = whisper_processor.batch_decode( |
|
|
predicted_ids, |
|
|
skip_special_tokens=True |
|
|
)[0] |
|
|
|
|
|
print(f"✅ Transcrito: {transcription}") |
|
|
|
|
|
return JSONResponse({ |
|
|
"text": transcription, |
|
|
"success": True |
|
|
}) |
|
|
|
|
|
except Exception as e: |
|
|
print(f"❌ Error: {str(e)}") |
|
|
raise HTTPException(status_code=500, detail=str(e)) |
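# Whisper auto-detects the spoken language above. To pin it to Spanish, a
# sketch (assumes a transformers version exposing get_decoder_prompt_ids):
#     forced_ids = whisper_processor.get_decoder_prompt_ids(
#         language="spanish", task="transcribe")
#     whisper_model.generate(input_features, forced_decoder_ids=forced_ids)
#
# Example request, assuming the server from __main__ below (port 7860):
#     curl -X POST -F "file=@question.wav" http://localhost:7860/transcribe
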
@app.post("/chat") |
|
|
async def chat(request: ChatRequest): |
|
|
"""Genera respuesta de IA""" |
|
|
try: |
|
|
question = request.question.strip() |
|
|
print(f"💬 Pregunta: {question}") |
|
|
|
|
|
if not question: |
|
|
return JSONResponse({ |
|
|
"answer": "No escuché ninguna pregunta", |
|
|
"success": False |
|
|
}) |
|
|
|
|
|
|
|
|
prompt = f"""Eres un asistente virtual amigable. Responde de forma breve y clara. |
|
|
|
|
|
Pregunta: {question} |
|
|
Respuesta:""" |
|
|
|
|
|
|
|
|
inputs = llm_tokenizer.encode(prompt, return_tensors="pt") |
|
|
|
|
|
with torch.no_grad(): |
|
|
outputs = llm_model.generate( |
|
|
inputs, |
|
|
max_length=request.max_length, |
|
|
num_return_sequences=1, |
|
|
temperature=0.8, |
|
|
top_p=0.9, |
|
|
do_sample=True, |
|
|
pad_token_id=llm_tokenizer.eos_token_id, |
|
|
repetition_penalty=1.2 |
|
|
) |
|
|
|
|
|
|
|
|
full_text = llm_tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
|
|
|
|
|
|
|
if "Respuesta:" in full_text: |
|
|
answer = full_text.split("Respuesta:")[-1].strip() |
|
|
else: |
|
|
answer = full_text.replace(prompt, "").strip() |
|
|
|
|
|
|
|
|
answer = answer.split("\n")[0].strip() |
|
|
if len(answer) > 200: |
|
|
answer = answer[:200].rsplit(" ", 1)[0] + "..." |
|
|
|
|
|
|
|
|
if not answer or len(answer) < 5: |
|
|
answer = "Interesante pregunta. Déjame pensar en eso." |
|
|
|
|
|
print(f"✅ Respuesta: {answer}") |
|
|
|
|
|
return JSONResponse({ |
|
|
"answer": answer, |
|
|
"success": True |
|
|
}) |
|
|
|
|
|
except Exception as e: |
|
|
print(f"❌ Error: {str(e)}") |
|
|
return JSONResponse({ |
|
|
"answer": "Lo siento, tuve un problema procesando tu pregunta", |
|
|
"success": False |
|
|
}) |
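# max_length above caps the *total* sequence (prompt + answer), so a long
# question leaves little room to reply; max_new_tokens is the usual generate()
# alternative when only the reply length should be capped.
#
# Example request:
#     curl -X POST http://localhost:7860/chat \
#          -H "Content-Type: application/json" \
#          -d '{"question": "¿Qué es un ESP32?"}'
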
@app.post("/tts") |
|
|
async def text_to_speech(request: TTSRequest): |
|
|
""" |
|
|
Convierte texto a voz usando API de Hugging Face |
|
|
IMPORTANTE: Requiere conexión a internet |
|
|
""" |
|
|
try: |
|
|
text = request.text.strip() |
|
|
print(f"🔊 Generando voz: {text[:50]}...") |
|
|
|
|
|
if not text: |
|
|
raise HTTPException(status_code=400, detail="Texto vacío") |
|
|
|
|
|
|
|
|
if len(text) > 300: |
|
|
text = text[:300] + "..." |
|
|
|
|
|
|
|
|
|
|
|
API_URL = "https://api-inference.huggingface.co/models/facebook/mms-tts-spa" |
|
|
|
|
|
headers = {} |
|
|
if HF_TOKEN: |
|
|
headers["Authorization"] = f"Bearer {HF_TOKEN}" |
|
|
|
|
|
|
|
|
response = requests.post( |
|
|
API_URL, |
|
|
headers=headers, |
|
|
json={"inputs": text}, |
|
|
timeout=30 |
|
|
) |
|
|
|
|
|
if response.status_code == 200: |
|
|
print(f"✅ Audio generado: {len(response.content)} bytes") |
|
|
|
|
|
return StreamingResponse( |
|
|
io.BytesIO(response.content), |
|
|
media_type="audio/flac", |
|
|
headers={ |
|
|
"Content-Disposition": "attachment; filename=speech.flac" |
|
|
} |
|
|
) |
|
|
else: |
|
|
print(f"❌ Error API TTS: {response.status_code}") |
|
|
raise HTTPException( |
|
|
status_code=response.status_code, |
|
|
detail=f"Error en TTS: {response.text}" |
|
|
) |
|
|
|
|
|
except requests.exceptions.Timeout: |
|
|
print("⏱️ Timeout en TTS") |
|
|
raise HTTPException(status_code=504, detail="Timeout generando audio") |
|
|
except Exception as e: |
|
|
print(f"❌ Error: {str(e)}") |
|
|
raise HTTPException(status_code=500, detail=str(e)) |
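# Example request (saves the FLAC returned by the Inference API):
#     curl -X POST http://localhost:7860/tts \
#          -H "Content-Type: application/json" \
#          -d '{"text": "Hola, ¿cómo estás?"}' -o speech.flac
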
@app.post("/complete") |
|
|
async def complete_conversation(file: UploadFile = File(...)): |
|
|
""" |
|
|
Proceso completo: Audio → Texto → IA → Audio |
|
|
""" |
|
|
try: |
|
|
print("\n" + "="*50) |
|
|
print("🔄 PROCESO COMPLETO INICIADO") |
|
|
print("="*50) |
|
|
|
|
|
|
|
|
print("\n📝 PASO 1: Transcribiendo...") |
|
|
audio_bytes = await file.read() |
|
|
waveform = process_audio_file(audio_bytes) |
|
|
|
|
|
input_features = whisper_processor( |
|
|
waveform, |
|
|
sampling_rate=16000, |
|
|
return_tensors="pt" |
|
|
).input_features |
|
|
|
|
|
with torch.no_grad(): |
|
|
predicted_ids = whisper_model.generate(input_features) |
|
|
|
|
|
transcription = whisper_processor.batch_decode( |
|
|
predicted_ids, |
|
|
skip_special_tokens=True |
|
|
)[0].strip() |
|
|
|
|
|
print(f"✅ Transcripción: {transcription}") |
|
|
|
|
|
if not transcription or len(transcription) < 3: |
|
|
transcription = "No te escuché bien" |
|
|
|
|
|
|
|
|
print("\n🤖 PASO 2: Generando respuesta IA...") |
|
|
prompt = f"""Eres un asistente virtual amigable. Responde breve. |
|
|
|
|
|
Pregunta: {transcription} |
|
|
Respuesta:""" |
|
|
|
|
|
inputs = llm_tokenizer.encode(prompt, return_tensors="pt") |
|
|
|
|
|
with torch.no_grad(): |
|
|
outputs = llm_model.generate( |
|
|
inputs, |
|
|
max_length=150, |
|
|
temperature=0.8, |
|
|
top_p=0.9, |
|
|
do_sample=True, |
|
|
pad_token_id=llm_tokenizer.eos_token_id, |
|
|
repetition_penalty=1.2 |
|
|
) |
|
|
|
|
|
full_text = llm_tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
|
|
|
|
if "Respuesta:" in full_text: |
|
|
answer = full_text.split("Respuesta:")[-1].strip() |
|
|
else: |
|
|
answer = full_text.replace(prompt, "").strip() |
|
|
|
|
|
answer = answer.split("\n")[0].strip() |
|
|
if len(answer) > 200: |
|
|
answer = answer[:200].rsplit(" ", 1)[0] + "..." |
|
|
|
|
|
if not answer or len(answer) < 5: |
|
|
answer = "Entiendo tu pregunta." |
|
|
|
|
|
print(f"✅ Respuesta: {answer}") |
|
|
|
|
|
|
|
|
print("\n🔊 PASO 3: Generando audio...") |
|
|
API_URL = "https://api-inference.huggingface.co/models/facebook/mms-tts-spa" |
|
|
|
|
|
headers = {} |
|
|
if HF_TOKEN: |
|
|
headers["Authorization"] = f"Bearer {HF_TOKEN}" |
|
|
|
|
|
response = requests.post( |
|
|
API_URL, |
|
|
headers=headers, |
|
|
json={"inputs": answer}, |
|
|
timeout=30 |
|
|
) |
|
|
|
|
|
if response.status_code != 200: |
|
|
print(f"⚠️ Error TTS, usando respuesta de texto") |
|
|
return JSONResponse({ |
|
|
"transcription": transcription, |
|
|
"answer": answer, |
|
|
"audio_error": True |
|
|
}) |
|
|
|
|
|
print("✅ Audio generado correctamente") |
|
|
print("="*50 + "\n") |
|
|
|
|
|
|
|
|
return StreamingResponse( |
|
|
io.BytesIO(response.content), |
|
|
media_type="audio/flac", |
|
|
headers={ |
|
|
"X-Transcription": transcription, |
|
|
"X-Answer": answer, |
|
|
"Content-Disposition": "attachment; filename=response.flac" |
|
|
} |
|
|
) |
|
|
|
|
|
except Exception as e: |
|
|
print(f"❌ ERROR COMPLETO: {str(e)}") |
|
|
raise HTTPException(status_code=500, detail=str(e)) |
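# The transcription and answer ride along as X-Transcription / X-Answer
# headers. Starlette encodes header values as latin-1, which covers accented
# Spanish, but arbitrary Unicode in a model reply could raise an encoding
# error; percent-encoding the values (urllib.parse.quote) would be safer.
#
# Example request (audio in, audio out, text in the response headers):
#     curl -i -X POST -F "file=@question.wav" \
#          http://localhost:7860/complete -o response.flac
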
@app.get("/") |
|
|
async def root(): |
|
|
return { |
|
|
"message": "🤖 API Asistente de Voz ESP32", |
|
|
"version": "2.0 - Simplificada", |
|
|
"status": "online", |
|
|
"endpoints": { |
|
|
"POST /transcribe": "Audio WAV → Texto", |
|
|
"POST /chat": "Pregunta → Respuesta IA", |
|
|
"POST /tts": "Texto → Audio", |
|
|
"POST /complete": "Audio → Audio (recomendado)" |
|
|
}, |
|
|
"models": { |
|
|
"stt": "openai/whisper-small", |
|
|
"llm": "DeepESP/gpt2-spanish-medium", |
|
|
"tts": "facebook/mms-tts-spa (API)" |
|
|
} |
|
|
} |
|
|
|
|
|
@app.get("/health") |
|
|
async def health_check(): |
|
|
return { |
|
|
"status": "healthy", |
|
|
"models_loaded": { |
|
|
"whisper": whisper_model is not None, |
|
|
"llm": llm_model is not None, |
|
|
"tts": "API externa" |
|
|
} |
|
|
} |
|
|
|
|
|
@app.get("/test") |
|
|
async def test_endpoint(): |
|
|
"""Endpoint de prueba simple""" |
|
|
return { |
|
|
"message": "¡Servidor funcionando correctamente!", |
|
|
"test": "OK" |
|
|
} |
|
|
|
|
|
if __name__ == "__main__": |
|
|
import uvicorn |
|
|
uvicorn.run(app, host="0.0.0.0", port=7860) |
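# Alternatively, run with auto-reload during development (assumes this file
# is named main.py):
#     uvicorn main:app --host 0.0.0.0 --port 7860 --reload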