Carlos Isael Ramírez González committed on
Commit
1dd8a9e
·
1 Parent(s): d269fec

Modelo y logica lista

Browse files
Files changed (6) hide show
  1. .gitattributes +1 -0
  2. app.py +25 -0
  3. config.py +10 -0
  4. memory.py +47 -0
  5. mojica_agent.py +292 -0
  6. requirements.txt +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.db filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,7 +1,32 @@
1
  from fastapi import FastAPI
 
 
 
 
 
2
 
3
  app = FastAPI()
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  @app.get("/")
6
  def greet_json():
7
  return {"Hello": "World!"}
 
1
  from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from typing import Any
4
+ import pandas as pd
5
+ from mojica_agent import MojicaAgent
6
+ from config import Config
7
 
8
  app = FastAPI()
9
 
10
+ mojica_bot = MojicaAgent(Config)
11
+
12
# Request body schema for POST /ask (plays the role a marshmallow schema would).
class QuestionRequest(BaseModel):
    # Natural-language question that is forwarded to the agent.
    question: str
15
+
16
# Response body schema for POST /ask.
class AnswerResponse(BaseModel):
    # First element of the tuple returned by the agent's ``consult`` call.
    sql: str
    # Second element of that tuple; DataFrames are converted to a list of
    # row dictionaries by the endpoint before being returned.
    result: Any
19
+
20
@app.post("/ask", response_model=AnswerResponse)
def ask_question(req: QuestionRequest):
    """Answer a natural-language question with the generated SQL and its result.

    Args:
        req: Request body carrying the user's question.

    Returns:
        A dict with the ``sql`` text and the ``result`` produced by the agent.
        DataFrame results are converted to a list of row dictionaries.
    """
    sql, result = mojica_bot.consult(req.question)

    if isinstance(result, pd.DataFrame):
        # JSON has no representation for NaN/NaT and the response encoder
        # rejects them, so missing values are mapped to None (JSON null).
        # The cast to object is required for None to survive in numeric columns.
        json_safe = result.astype(object).where(result.notna(), None)
        result = json_safe.to_dict(orient="records")

    return {"sql": sql, "result": result}
29
+
30
@app.get("/")
def greet_json():
    """Root endpoint: reply with a fixed greeting payload."""
    greeting = {"Hello": "World!"}
    return greeting
config.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
class Config:
    """Static settings read by the agent and the conversation memory."""

    # --- Data ---
    # Path handed to sqlite3.connect by the agent.
    DB_PATH = "dataset.db"
    # Table the agent inspects and queries.
    TABLE_NAME = "sells"
    # CSV source of the table (only referenced by the disabled load step).
    CSV_PATH = "/kaggle/input/mojica-hoja-1/mojica_hoja_1.csv"

    # --- Model ---
    MODEL_NAME = "ibm-granite/granite-3b-code-instruct-128k"
    # Truncation length applied to the tokenized prompt.
    MAX_TOKENS = 8000
    # Upper bound on tokens generated per answer.
    MAX_NEW_TOKENS = 400
    # Prefer the GPU when one is available.
    if torch.cuda.is_available():
        DEVICE = "cuda"
    else:
        DEVICE = "cpu"

    # --- Memory ---
    # Keep the last 3 interactions (conversation memory).
    MAX_HISTORY = 3
memory.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+ import pandas as pd
3
+ from config import Config
4
+
5
class ConversationMemory:
    """Bounded log of past interactions plus a slot for a cached table schema."""

    # Topics used to decide whether a past interaction relates to a new question.
    _TOPICS = ("producto", "cliente")

    def __init__(self, max_history: int = Config.MAX_HISTORY):
        """Create an empty memory that keeps at most ``max_history`` interactions."""
        # deque(maxlen=...) silently discards the oldest entry when full.
        self.history = deque(maxlen=max_history)
        self.schema_cache = None

    def add_interaction(self, question: str, sql: str, result) -> None:
        """Store one interaction.

        Args:
            question: The user's question.
            sql: The SQL statement that was executed.
            result: Whatever the execution produced; a DataFrame is summarized,
                anything else (e.g. an error message) is stored as ``str(result)``.
        """
        self.history.append(
            {
                "question": question,
                "sql": sql,
                "result_summary": self._summarize_result(result),
            }
        )

    def _summarize_result(self, result) -> str:
        """Return a short summary of ``result`` suitable for prompt context."""
        if not isinstance(result, pd.DataFrame):
            return str(result)

        # Focus on the key data rather than on metadata.
        if len(result) == 1:
            return f"Único resultado: {result.iloc[0].to_dict()}"

        if "Cliente" in result.columns:
            top = result.head(3)
            if "Neto" in result.columns:
                try:
                    top = result.nlargest(3, "Neto")
                except TypeError:
                    # nlargest rejects non-numeric columns; summarizing must
                    # never fail a query that already succeeded, so keep the
                    # first rows instead.
                    pass
            return f"Top clientes: {top['Cliente'].tolist()}"

        return f"Filas: {len(result)}, Columnas: {list(result.columns)}"

    def get_context(self, current_question: str) -> str:
        """Return a text block describing the latest related interaction.

        An interaction is related when its question and ``current_question``
        both mention the same topic ("producto" or "cliente"). Returns an empty
        string when there is no history or no related interaction.
        """
        if not self.history:
            return ""

        asked = current_question.lower()
        related = [
            interaction
            for interaction in self.history
            if any(
                topic in asked and topic in interaction["question"].lower()
                for topic in self._TOPICS
            )
        ]
        if not related:
            return ""

        # Only the most recent related interaction is reported.
        latest = related[-1]
        return (
            f"Interacción #1: {latest['question'][:50]}...\n"
            f"SQL: {latest['sql'][:70]}...\n"
            f"Resultado: {latest['result_summary']}\n\n"
        )
mojica_agent.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from memory import ConversationMemory
2
+ from config import Config
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ import torch, gc
5
+ import unicodedata
6
+ from typing import Dict, Tuple
7
+ import re
8
+ import pandas as pd
9
+ import sqlite3
10
+
11
class MojicaAgent:
    """Text-to-SQL agent.

    Builds a prompt from the question, generates SQL with a causal language
    model, sanitizes the statement and runs it against a SQLite database.
    """

    def __init__(self, config: "type[Config]"):
        """Set up memory, database connection and model.

        Args:
            config: Settings holder exposing DB_PATH, TABLE_NAME, MODEL_NAME,
                DEVICE, MAX_TOKENS and MAX_NEW_TOKENS.
        """
        self.config = config
        self.memory = ConversationMemory()
        # Must run before any question is answered: it creates ``self.conn``
        # and ``self.schema``, which ``_build_prompt`` and ``_execute_sql`` read.
        self._initialize_database()
        self._safe_initializer_model()

    def _safe_initializer_model(self):
        """Load tokenizer and model, retrying once after a CUDA out-of-memory error."""

        def try_load_model():
            # Half precision on GPU, full precision otherwise.
            dtype = torch.float16 if "cuda" in self.config.DEVICE else torch.float32
            tokenizer = AutoTokenizer.from_pretrained(self.config.MODEL_NAME)
            # NOTE(review): trust_remote_code=True executes code shipped with the
            # model repository; confirm the model source is trusted.
            model = AutoModelForCausalLM.from_pretrained(
                self.config.MODEL_NAME, trust_remote_code=True, torch_dtype=dtype
            )
            # eval(): the model is only used for inference.
            model = model.to(self.config.DEVICE).eval()
            return tokenizer, model

        try:
            self.tokenizer, self.model = try_load_model()
        except torch.cuda.OutOfMemoryError:
            # Release memory and try once more.
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

            self.tokenizer, self.model = try_load_model()

    def _initialize_database(self):
        """Open the SQLite connection and load the table schema.

        The table is expected to exist already; the step that rebuilt it from
        ``config.CSV_PATH`` is intentionally not performed here.
        """
        # The agent is created in one thread but queried from others (e.g. a
        # web framework's worker threads), so the same-thread check is disabled.
        # NOTE(review): the connection is shared; add a lock if writes are ever
        # issued through it.
        self.conn = sqlite3.connect(self.config.DB_PATH, check_same_thread=False)
        self.schema = self._get_schema_structured()

    def _get_schema_structured(self) -> Dict:
        """Return ``{"table_name": ..., "columns": [{"name", "type"}, ...]}``.

        The result is cached on ``self.memory.schema_cache``.
        """
        if self.memory.schema_cache:
            return self.memory.schema_cache
        cursor = self.conn.cursor()
        cursor.execute(f"PRAGMA table_info({self.config.TABLE_NAME})")
        # PRAGMA table_info rows: index 1 is the column name, index 2 its type.
        columns = [
            {"name": column[1], "type": column[2]} for column in cursor.fetchall()
        ]
        schema = {"table_name": self.config.TABLE_NAME, "columns": columns}
        self.memory.schema_cache = schema
        return schema

    def _build_prompt(self, question: str) -> str:
        """Compose the generation prompt for ``question``.

        The prompt contains the task, the relevant columns, the latest related
        interaction from memory, few-shot examples chosen by question type,
        the generation rules and finally the question itself.
        """
        memory_context = self.memory.get_context(question)
        table_name = self.schema["table_name"]

        # 1. Detect the question type.
        lowered = question.lower()
        if "producto" in lowered:
            question_type = "PRODUCTOS"
        elif "cliente" in lowered:
            question_type = "CLIENTES"
        else:
            question_type = "GENERAL"

        # 2. Few-shot examples per question type.
        examples = {
            "PRODUCTOS": (
                "-- P: 'Top 10 productos más vendidos'\n"
                'SELECT "Descripcion", SUM("Cantidad") AS total_vendido\n'
                f'FROM "{table_name}"\n'
                'WHERE "Descripcion" IS NOT NULL\n'
                'GROUP BY "Descripcion"\n'
                "ORDER BY total_vendido DESC\n"
                "LIMIT 10;\n\n"
                "-- P: 'Productos con mayor valor neto'\n"
                'SELECT "Descripcion", SUM("Neto") AS valor_total\n'
                f'FROM "{table_name}"\n'
                'WHERE "Descripcion" IS NOT NULL\n'
                'GROUP BY "Descripcion"\n'
                "ORDER BY valor_total DESC\n"
                "LIMIT 5;"
            ),
            "CLIENTES": (
                "-- P: 'Top 5 clientes con mayor valor neto'\n"
                'SELECT "Cliente", SUM("Neto") AS valor_total\n'
                f'FROM "{table_name}"\n'
                "WHERE \"Cliente\" IS NOT NULL AND \"Fecha\" BETWEEN '2025-01-01' AND '2025-12-31'\n"
                'GROUP BY "Cliente"\n'
                "ORDER BY valor_total DESC\n"
                "LIMIT 5;\n\n"
                "-- P: 'Clientes con más compras en marzo'\n"
                'SELECT "Cliente", COUNT(*) AS total_compras\n'
                f'FROM "{table_name}"\n'
                "WHERE \"Cliente\" IS NOT NULL AND strftime('%m', \"Fecha\") = '03'\n"
                'GROUP BY "Cliente"\n'
                "ORDER BY total_compras DESC\n"
                "LIMIT 10;\n\n"
                "-- P: 'Clientes de Guadalajara con más compras'\n"
                'SELECT "Cliente", "Razon Social", COUNT(*) AS total_compras\n'
                f'FROM "{table_name}"\n'
                'WHERE "Cliente" IS NOT NULL AND "Ciudad" = \'Guadalajara\'\n'
                'GROUP BY "Cliente", "Razon Social"\n'
                "ORDER BY total_compras DESC\n"
                "LIMIT 10;"
            ),
            "GENERAL": (
                "-- P: 'Ventas totales por mes'\n"
                'SELECT strftime(\'%m\', "Fecha") AS mes, SUM("Neto") AS ventas\n'
                f'FROM "{table_name}"\n'
                "WHERE mes IS NOT NULL\n"
                "GROUP BY mes\n"
                "ORDER BY mes;\n\n"
                "-- P: 'Producto menos vendido en 2025'\n"
                'SELECT "Descripcion", SUM("Cantidad") AS total_vendido\n'
                f'FROM "{table_name}"\n'
                "WHERE \"Descripcion\" IS NOT NULL AND \"Fecha\" BETWEEN '2025-01-01' AND '2025-12-31'\n"
                'GROUP BY "Descripcion"\n'
                "ORDER BY total_vendido ASC\n"
                "LIMIT 1;"
            ),
        }

        # 3. Essential columns shown to the model.
        essential_columns = [
            {
                "name": "Descripcion",
                "type": "TEXT",
                "description": "Nombre del producto",
            },
            {"name": "Cantidad", "type": "REAL", "description": "Unidades vendidas"},
            {"name": "Cliente", "type": "TEXT", "description": "Código de cliente"},
            {
                "name": "Razon Social",
                "type": "TEXT",
                "description": "Nombre completo del cliente",
            },
            {"name": "Ciudad", "type": "TEXT", "description": "Ciudad del cliente"},
            {
                "name": "Fecha",
                "type": "TEXT",
                "description": "Fecha de venta (YYYY-MM-DD)",
            },
            {"name": "Neto", "type": "REAL", "description": "Valor neto de la venta"},
        ]
        columns_block = "\n".join(
            f"- {col['name']} ({col['type']}): {col['description']}"
            for col in essential_columns
        )

        # 4. Generation rules. The LIMIT rule reads "unless stated otherwise";
        #    the previous wording ("in case the opposite is stated") inverted it.
        rules = (
            "- Usar siempre nombres exactos de columnas",
            "- Agrupar por la dimensión principal (producto/cliente)",
            "- Ordenar DESC para 'más/mayor', ASC para 'menos/menor'",
            "- Usar LIMIT para top N",
            "- Año actual: 2025",
            "- Siempre terminar con LIMIT 1 a menos que se indique lo contrario",
            "- Para 'más vendido' usar SUM(\"Cantidad\"), para 'mayor valor' usar SUM(\"Neto\")",
            '- Usar "Razon Social" cuando pregunten por el nombre del cliente',
            '- Usar "Ciudad" para filtrar o agrupar por ubicación',
            "- Queda estrictamente prohibido usar acentos",
            "- **Siempre excluir valores nulos con 'IS NOT NULL' en las columnas"
            " usadas en WHERE, GROUP BY u ORDER BY**",
        )

        # 5. Final prompt, assembled line by line so its whitespace does not
        #    depend on source-code indentation.
        sections = [
            "",
            "### TAREA ###",
            "Generar SOLO código SQL para la pregunta, usando EXCLUSIVAMENTE"
            f' la tabla: "{table_name}"',
            "",
            "### COLUMNAS RELEVANTES ###",
            columns_block,
            "",
            "### CONTEXTO (Últimas interacciones) ###",
            memory_context if memory_context else "Sin historial relevante",
            "",
            f"### EJEMPLOS ({question_type}) ###",
            examples[question_type],
            "",
            "### REGLAS CRÍTICAS ###",
            *rules,
            "",
            "### PREGUNTA ACTUAL ###",
            f'"""{question}"""',
            "",
            "### SQL:",
            "",
        ]
        return "\n".join(sections)

    def _clean_sql_output(self, output: str, default_limit: int = 1) -> "str | None":
        """Extract and sanitize one SQL statement from model output.

        Args:
            output: Text generated by the model (without the prompt).
            default_limit: Row limit appended when the statement has no LIMIT.

        Returns:
            The first ``;``-terminated statement with accents removed and a
            LIMIT guaranteed, or ``None`` when no statement is found or it
            contains a data-modifying keyword.
        """
        # The first statement is the direct answer to the prompt; anything
        # after it is run-on text the model produced before hitting the token
        # limit.
        match = re.search(
            r"\b(?:SELECT|WITH|INSERT|UPDATE|DELETE)\b[\s\S]+?;",
            output,
            re.IGNORECASE,
        )
        if match is None:
            return None
        sql = match.group(0).strip()

        # Safety: block dangerous statements. Whole-word matching keeps
        # identifiers such as "Updated_At" from being rejected.
        if re.search(r"\b(?:DROP|DELETE|UPDATE|INSERT|ALTER)\b", sql, re.IGNORECASE):
            return None

        # Remove accents from the whole query: decompose each character and
        # drop the combining marks.
        sql = "".join(
            char
            for char in unicodedata.normalize("NFKD", sql)
            if not unicodedata.combining(char)
        )

        # Append a LIMIT before the final semicolon when none is present.
        if not re.search(r"\bLIMIT\s+\d+", sql, re.IGNORECASE):
            sql = f"{sql[:-1].rstrip()} LIMIT {default_limit};"

        return sql

    def _execute_sql(self, sql: str):
        """Run ``sql`` and return a DataFrame, or an error message string on failure."""
        try:
            return pd.read_sql_query(sql, self.conn)
        except Exception as e:
            # Boundary handler: callers receive the failure as text instead of
            # an exception so it can be returned to the user and stored in memory.
            return f"Error de ejecución: {str(e)}"

    def consult(self, question: str) -> Tuple[str, object]:
        """Answer ``question`` by generating and executing SQL.

        Returns:
            ``(sql, result)`` where ``result`` is a DataFrame or an execution
            error string. When no valid SQL can be extracted, returns
            ``(error_message_with_raw_output, None)``.
        """
        prompt = self._build_prompt(question)

        tokenized_input = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.config.MAX_TOKENS,
        ).to(self.config.DEVICE)

        # Disable gradient tracking for inference: less memory, faster.
        with torch.no_grad():
            tokenized_output_model = self.model.generate(
                **tokenized_input,
                max_new_tokens=self.config.MAX_NEW_TOKENS,
                temperature=0.2,
                top_p=0.95,
                top_k=50,
                repetition_penalty=1.1,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # The generated sequence starts with the prompt tokens. Decode only the
        # continuation; otherwise the few-shot SQL examples inside the prompt
        # could be extracted and executed as if they were the model's answer.
        prompt_length = tokenized_input["input_ids"].shape[-1]
        output_model = self.tokenizer.decode(
            tokenized_output_model[0][prompt_length:], skip_special_tokens=True
        )

        sql_query = self._clean_sql_output(output_model)

        if not sql_query:
            return "Error: No se pudo generar SQL válido" + "\n" + output_model, None

        result = self._execute_sql(sql_query)
        self.memory.add_interaction(question, sql_query, result)

        return sql_query, result
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ