from memory import ConversationMemory from config import Config from transformers import AutoTokenizer, AutoModelForCausalLM import torch, gc import unicodedata from typing import Dict, Tuple import re import pandas as pd import sqlite3 class MojicaAgent: def __init__(self, config: Config): self.config = config self.memory = ConversationMemory() self.schema = self._load_schema() self._safe_initializer_model() def _safe_initializer_model(self): def try_load_model(): dtype = torch.float16 if "cuda" in self.config.DEVICE else torch.float32 tokenizer = AutoTokenizer.from_pretrained(self.config.MODEL_NAME) model = ( AutoModelForCausalLM.from_pretrained( self.config.MODEL_NAME, trust_remote_code=True, torch_dtype=dtype ) .to(self.config.DEVICE) .eval() ) # eval porque solo se predice return tokenizer, model try: self.tokenizer, self.model = try_load_model() except torch.cuda.OutOfMemoryError: # Liberar memoria y volver a intentar gc.collect() torch.cuda.empty_cache() torch.cuda.ipc_collect() self.tokenizer, self.model = try_load_model() def _load_schema(self) -> Dict: conn = sqlite3.connect(self.config.DB_PATH) cursor = conn.cursor() cursor.execute(f"PRAGMA table_info({self.config.TABLE_NAME})") columns = [ {"name": column[1], "type": column[2]} for column in cursor.fetchall() ] schema = {"table_name": self.config.TABLE_NAME, "columns": columns} conn.close() return schema def _get_schema_structured(self) -> Dict: if self.memory.schema_cache: return self.memory.schema_cache cursor = self.conn.cursor() cursor.execute(f"PRAGMA table_info({self.config.TABLE_NAME})") columns = [ {"name": column[1], "type": column[2]} for column in cursor.fetchall() ] schema = {"table_name": self.config.TABLE_NAME, "columns": columns} self.memory.schema_cache = schema return schema def _build_prompt(self, question: str) -> str: memory_context = self.memory.get_context(question) table_name = self.schema["table_name"] # 1. Detectar tipo de pregunta question_type = ( "PRODUCTOS" if "producto" in question.lower() else "CLIENTES" if "cliente" in question.lower() else "GENERAL" ) # 2. Ejemplos dinámicos examples = { "PRODUCTOS": ( "-- P: 'Top 10 productos más vendidos'\n" 'SELECT "Descripcion", SUM("Cantidad") AS total_vendido\n' f'FROM "{table_name}"\n' 'WHERE "Descripcion" IS NOT NULL\n' 'GROUP BY "Descripcion"\n' "ORDER BY total_vendido DESC\n" "LIMIT 10;\n\n" "-- P: 'Productos con mayor valor neto'\n" 'SELECT "Descripcion", SUM("Neto") AS valor_total\n' f'FROM "{table_name}"\n' 'WHERE "Descripcion" IS NOT NULL\n' 'GROUP BY "Descripcion"\n' "ORDER BY valor_total DESC\n" "LIMIT 5;" ), "CLIENTES": ( "-- P: 'Top 5 clientes con mayor valor neto'\n" 'SELECT "Cliente", SUM("Neto") AS valor_total\n' f'FROM "{table_name}"\n' "WHERE \"Cliente\" IS NOT NULL AND \"Fecha\" BETWEEN '2025-01-01' AND '2025-12-31'\n" 'GROUP BY "Cliente"\n' "ORDER BY valor_total DESC\n" "LIMIT 5;\n\n" "-- P: 'Clientes con más compras en marzo'\n" 'SELECT "Cliente", COUNT(*) AS total_compras\n' f'FROM "{table_name}"\n' "WHERE \"Cliente\" IS NOT NULL AND strftime('%m', \"Fecha\") = '03'\n" 'GROUP BY "Cliente"\n' "ORDER BY total_compras DESC\n" "LIMIT 10;\n\n" "-- P: 'Clientes de Guadalajara con más compras'\n" 'SELECT "Cliente", "Razon Social", COUNT(*) AS total_compras\n' f'FROM "{table_name}"\n' 'WHERE "Cliente" IS NOT NULL AND "Ciudad" = \'Guadalajara\'\n' 'GROUP BY "Cliente", "Razon Social"\n' "ORDER BY total_compras DESC\n" "LIMIT 10;" ), "GENERAL": ( "-- P: 'Ventas totales por mes'\n" 'SELECT strftime(\'%m\', "Fecha") AS mes, SUM("Neto") AS ventas\n' f'FROM "{table_name}"\n' "WHERE mes IS NOT NULL\n" "GROUP BY mes\n" "ORDER BY mes;\n\n" "-- P: 'Producto menos vendido en 2025'\n" 'SELECT "Descripcion", SUM("Cantidad") AS total_vendido\n' f'FROM "{table_name}"\n' "WHERE \"Descripcion\" IS NOT NULL AND \"Fecha\" BETWEEN '2025-01-01' AND '2025-12-31'\n" 'GROUP BY "Descripcion"\n' "ORDER BY total_vendido ASC\n" "LIMIT 1;" ), } # 3. Columnas esenciales essential_columns = [ { "name": "Descripcion", "type": "TEXT", "description": "Nombre del producto", }, {"name": "Cantidad", "type": "REAL", "description": "Unidades vendidas"}, {"name": "Cliente", "type": "TEXT", "description": "Código de cliente"}, { "name": "Razon Social", "type": "TEXT", "description": "Nombre completo del cliente", }, {"name": "Ciudad", "type": "TEXT", "description": "Ciudad del cliente"}, { "name": "Fecha", "type": "TEXT", "description": "Fecha de venta (YYYY-MM-DD)", }, {"name": "Neto", "type": "REAL", "description": "Valor neto de la venta"}, ] # 4. Prompt final con nueva regla return ( f""" ### TAREA ### Generar SOLO código SQL para la pregunta, usando EXCLUSIVAMENTE la tabla: "{table_name}" ### COLUMNAS RELEVANTES ### """ + "\n".join( [ f"- {col['name']} ({col['type']}): {col['description']}" for col in essential_columns ] ) + f""" ### CONTEXTO (Últimas interacciones) ### {memory_context if memory_context else "Sin historial relevante"} ### EJEMPLOS ({question_type}) ### {examples[question_type]} ### REGLAS CRÍTICAS ### - Usar siempre nombres exactos de columnas - Agrupar por la dimensión principal (producto/cliente) - Ordenar DESC para 'más/mayor', ASC para 'menos/menor' - Usar LIMIT para top N - Año actual: 2025 - Siempre terminar con un LIMIT = 1 en caso que se indique lo contrario - Para 'más vendido' usar SUM("Cantidad"), para 'mayor valor' usar SUM("Neto") - Usar "Razon Social" cuando pregunten por el nombre del cliente - Usar "Ciudad" para filtrar o agrupar por ubicación - Queda estrictamente prohibido usar acentos - **Siempre excluir valores nulos con 'IS NOT NULL' en las columnas usadas en WHERE, GROUP BY u ORDER BY** ### PREGUNTA ACTUAL ### \"\"\"{question}\"\"\" ### SQL: """ ) def _clean_sql_output(self, output: str) -> str: # Encuentra todas las posibles queries completas que terminen en ; sql_matches = list( re.finditer( r"(SELECT|WITH|INSERT|UPDATE|DELETE)[\s\S]+?;", output, re.IGNORECASE ) ) if not sql_matches: return None # Tomar la última query encontrada sql = sql_matches[-1].group(0).strip() # Seguridad: bloquear queries peligrosas if any( cmd in sql.upper() for cmd in ["DROP", "DELETE", "UPDATE", "INSERT", "ALTER"] ): return None # Asegurar que termine en ; if not sql.endswith(";"): sql += ";" # ──────────────────────────────── # 1. Quitar acentos de toda la query # ──────────────────────────────── def remove_accents(text: str) -> str: return "".join( c for c in unicodedata.normalize("NFKD", text) if not unicodedata.combining(c) ) sql = remove_accents(sql) # ──────────────────────────────── # 2. Agregar LIMIT si no existe # ──────────────────────────────── # Buscar si ya hay un LIMIT en la query if not re.search(r"\bLIMIT\s+\d+", sql, re.IGNORECASE): # Insertar antes del último punto y coma sql = ( sql[:-1] + " LIMIT 1;" ) # puedes cambiar 100 por el valor default que quieras return sql def _execute_sql(self, sql: str): conn = sqlite3.connect(self.config.DB_PATH) try: result = pd.read_sql_query(sql, conn) conn.close() return result except Exception as e: return f"Error de ejecución: {str(e)}" finally: conn.close() def consult(self, question: str) -> Tuple[str, any]: prompt = self._build_prompt(question) tokenized_input = self.tokenizer( prompt, return_tensors="pt", truncation=True, max_length=self.config.MAX_TOKENS, ).to(self.config.DEVICE) # Desactiva el cálculo de gradientes -> Siempre poner cuando se haga prediccion # - Reduce consumo de memoria # - Acelera inferencia with torch.no_grad(): tokenized_output_model = self.model.generate( **tokenized_input, max_new_tokens=self.config.MAX_NEW_TOKENS, temperature=0.2, top_p=0.95, top_k=50, repetition_penalty=1.1, do_sample=True, pad_token_id=self.tokenizer.eos_token_id, ) output_model = self.tokenizer.decode( tokenized_output_model[0], skip_special_tokens=True ) sql_query = self._clean_sql_output(output_model) if not sql_query: return "Error: No se pudo generar SQL válido" + "\n" + output_model, None result = self._execute_sql(sql_query) self.memory.add_interaction(question, sql_query, result) return sql_query, result