Carlos Isael Ramírez González commited on
Commit
a04ffe2
·
1 Parent(s): 9e123b2

Cambie el modelo antiguo por el nuevo

Browse files
Files changed (1) hide show
  1. mojica_agent.py +158 -208
mojica_agent.py CHANGED
@@ -1,109 +1,54 @@
1
- from memory import Memory as ConversationMemory
2
  from config import Config
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import torch, gc
5
  import unicodedata
6
- from typing import Dict, Tuple, Optional, Any
7
- import re
8
  import pandas as pd
9
  import sqlite3
10
- from intelligent_question_router import IntelligentQuestionRouter
11
-
12
 
13
  class MojicaAgent:
14
  def __init__(self, config: Config):
15
  self.config = config
16
  self.memory = ConversationMemory()
17
- self.essential_columns = [
18
- {
19
- "name": "Descripcion",
20
- "type": "TEXT",
21
- "description": "Nombre del producto",
22
- },
23
- {"name": "Cantidad", "type": "REAL", "description": "Unidades vendidas"},
24
- {"name": "Cliente", "type": "TEXT", "description": "Código de cliente"},
25
- {
26
- "name": "Razon Social",
27
- "type": "TEXT",
28
- "description": "Nombre completo del cliente",
29
- },
30
- {"name": "Ciudad", "type": "TEXT", "description": "Ciudad del cliente"},
31
- {
32
- "name": "Fecha",
33
- "type": "TEXT",
34
- "description": "Fecha de venta (YYYY-MM-DD)",
35
- },
36
- {"name": "Neto", "type": "REAL", "description": "Valor neto de la venta"},
37
- ]
38
- self.schema = self._load_schema()
39
  self._safe_initializer_model()
40
 
41
- def _initialize_model(self):
42
  def try_load_model():
43
- self.tokenizer = AutoTokenizer.from_pretrained(self.config.MODEL_NAME)
44
- self.model = AutoModelForCausalLM.from_pretrained(
45
- self.config.MODEL_NAME,
46
- device_map="auto",
47
- torch_dtype="auto",
48
- trust_remote_code=True,
49
- ).eval()
 
 
 
50
 
51
  try:
52
- try_load_model()
53
  except torch.cuda.OutOfMemoryError:
 
54
  gc.collect()
55
  torch.cuda.empty_cache()
56
  torch.cuda.ipc_collect()
57
- try_load_model()
58
 
59
- def _load_training_data(self):
60
- training_examples = [
61
- {"question": "productos más vendidos", "category": "producto"},
62
- {"question": "mejor producto", "category": "producto"},
63
- {"question": "clientes que más compran", "category": "cliente"},
64
- {"question": "clientes inactivos", "category": "cliente"},
65
- ]
66
- try:
67
- self.router.semantic_classifier.train(training_examples)
68
- except Exception as e:
69
- print(f"Error training semantic classifier: {e}")
70
-
71
- def _validate_result_existing(self, result):
72
- # Si es un string de error
73
- if isinstance(result, str) and "Error" in result:
74
- return False
75
-
76
- # Si es un DataFrame vacío
77
- if hasattr(result, "empty") and result.empty:
78
- return False
79
-
80
- # Si es una lista vacía
81
- if isinstance(result, list) and len(result) == 0:
82
- return False
83
-
84
- # En cualquier otro caso, asumimos éxito
85
- return True
86
-
87
- def _initialize_database(self):
88
- self.conn = sqlite3.connect(self.config.DB_PATH)
89
-
90
- cursor = self.conn.cursor()
91
- cursor.execute(f"DROP TABLE IF EXISTS {self.config.TABLE_NAME}")
92
- self.conn.commit()
93
- df = pd.read_csv(self.config.CSV_PATH, low_memory=False)
94
 
95
- real_cols = [
96
- col["name"] for col in self.essential_columns if col["type"] == "REAL"
 
 
 
 
97
  ]
98
- for col in real_cols:
99
- if col in df.columns:
100
- df[col] = pd.to_numeric(df[col], errors="coerce")
101
-
102
- df.to_sql(self.config.TABLE_NAME, self.conn, if_exists="replace", index=False)
103
- self.schema = self._get_schema_structured()
104
- # Configuracion de pandas:
105
- pd.set_option("display.float_format", "{:,.2f}".format)
106
-
107
  def _get_schema_structured(self) -> Dict:
108
  if self.memory.schema_cache:
109
  return self.memory.schema_cache
@@ -116,23 +61,99 @@ class MojicaAgent:
116
  self.memory.schema_cache = schema
117
  return schema
118
 
119
- def _generate_sql_prompt(self, question: str) -> str:
120
- memory_context = self.memory.get_relevant_memory(question)
121
  table_name = self.schema["table_name"]
122
- # Uso del router
123
- try:
124
- examples_list = self.router.route_question(question)
125
- # Convertir ejemplos a texto para el prompt
126
- examples_text = "\n".join(
127
- [f"-- P: '{ex['pregunta']}'\n{ex['query']}\n" for ex in examples_list]
128
- )
129
- question_type = "ROUTED_EXAMPLES"
130
- except Exception as e:
131
- print(f"Router failed, using manual detection: {e}")
132
- # Fallback a detección manual
133
- # question_type = self._detect_question_type_manual(question)
134
- # examples_text = self.examples.get(question_type, "")
135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  return (
137
  f"""
138
  ### TAREA ###
@@ -143,7 +164,7 @@ class MojicaAgent:
143
  + "\n".join(
144
  [
145
  f"- {col['name']} ({col['type']}): {col['description']}"
146
- for col in self.essential_columns
147
  ]
148
  )
149
  + f"""
@@ -151,32 +172,22 @@ class MojicaAgent:
151
  ### CONTEXTO (Últimas interacciones) ###
152
  {memory_context if memory_context else "Sin historial relevante"}
153
 
154
- ### EJEMPLOS ###
155
- {examples_text}
156
 
157
  ### REGLAS CRÍTICAS ###
158
  - Usar siempre nombres exactos de columnas
159
- - Usar solo las columnas listadas
160
- - Prohibido inventar columnas
161
- - Para el nombre del cliente, usar SIEMPRE "Razon Social".
162
- - Para un mes específico usar: strftime('%m', "Fecha") = 'MM'
163
- - Para cantidades usar SUM("Cantidad"), para dinero usar SUM("Neto")
164
  - Agrupar por la dimensión principal (producto/cliente)
165
  - Ordenar DESC para 'más/mayor', ASC para 'menos/menor'
166
- - Contesta siempre en el idioma en el que se te pregunta no traduzcas.
167
  - Año actual: 2025
168
- - No inventes columnas o tablas que no existan
169
- - Para preguntas sobre clientes cero, SIEMPRE usar la subconsulta NOT IN con las últimas 4 semanas.
170
- - Si se menciona una ciudad, incluir el filtro AND "Ciudad" LIKE '%...%'
171
- - Usa LIMIT cuando se te pida un numero finito de datos
172
  - Para 'más vendido' usar SUM("Cantidad"), para 'mayor valor' usar SUM("Neto")
173
  - Usar "Razon Social" cuando pregunten por el nombre del cliente
174
  - Usar "Ciudad" para filtrar o agrupar por ubicación
175
  - Queda estrictamente prohibido usar acentos
176
  - **Siempre excluir valores nulos con 'IS NOT NULL' en las columnas usadas en WHERE, GROUP BY u ORDER BY**
177
- - Para preguntas sobre ciudad SIEMPRE incluir "Ciudad" en la query
178
- - Para busquedas por Descripcion siempre usar LIKE
179
- - Mandar solo la cantidad de rows que el usuario pide.
180
  ### PREGUNTA ACTUAL ###
181
  \"\"\"{question}\"\"\"
182
 
@@ -184,27 +195,7 @@ class MojicaAgent:
184
  """
185
  )
186
 
187
- def _generate_analysis_prompt(self, question: str, result: Any) -> str:
188
- return f"""
189
- Basado EXCLUSIVAMENTE en estos datos: {result}
190
-
191
- Responde esta pregunta: {question}
192
-
193
- Reglas estrictas:
194
- - Nunca inventes numeros
195
- - Usa solo datos proporcionados
196
- - Maximo una oracion
197
- """
198
-
199
- def _clean_analysis_output(self, ouput: str) -> Optional[str]:
200
- pattern = r"Respuesta:([\s\S]+)"
201
- match = re.search(pattern, ouput)
202
- if match:
203
- return match.group(1).strip()
204
- else:
205
- return "Sin análisis"
206
-
207
- def _clean_sql_output(self, output: str) -> Optional[str]:
208
  # Encuentra todas las posibles queries completas que terminen en ;
209
  sql_matches = list(
210
  re.finditer(
@@ -245,63 +236,38 @@ class MojicaAgent:
245
  # 2. Agregar LIMIT si no existe
246
  # ────────────────────────────────
247
  # Buscar si ya hay un LIMIT en la query
248
- # if not re.search(r"\bLIMIT\s+\d+", sql, re.IGNORECASE):
249
- # # Insertar antes del último punto y coma
250
- # sql = sql[:-1] + " LIMIT 1;" # puedes cambiar 100 por el valor default que quieras
251
-
252
- validate_sql = self._validate_and_correct_sql(sql)
253
- return validate_sql
254
-
255
- def _validate_and_correct_sql(self, sql: str) -> str:
256
- cur = self.conn.cursor()
257
- cur.execute(f'PRAGMA table_info("{self.config.TABLE_NAME}")')
258
- real_columns = [row[1] for row in cur.fetchall()]
259
- column_lower_map = {col.lower(): col for col in real_columns}
260
- aliases = {
261
- "city": "Ciudad",
262
- "client": "Cliente",
263
- "razon_social": "Razon Social",
264
- "razón social": "Razon Social",
265
- "Sales": "sells",
266
- '"Date"': "Fecha",
267
- "mojica_Clientes": "sells",
268
- "value_total": "valor_total",
269
- "strstrftime": "strftime",
270
- }
271
- alias_map = {k.lower(): v for k, v in aliases.items()}
272
-
273
- pattern = r"\b\w+\b"
274
-
275
- def replace_column(m):
276
- candidate = m.group(0) # Palabra encontrada
277
- key = candidate.lower()
278
- # ¿Es una columna?
279
- corrected = column_lower_map.get(key)
280
- if corrected:
281
- return corrected
282
 
283
- # ¿Es una alias?
284
- corrected = alias_map.get(key)
285
- if corrected is not None:
286
- return corrected
287
- return candidate # si no encuentra nada, lo deja igual
288
 
289
- return re.sub(pattern, replace_column, sql).replace("\\", "")
290
-
291
- def _execute_sql(self, sql: str) -> Any:
292
  try:
293
- return pd.read_sql_query(sql, self.conn)
 
 
294
  except Exception as e:
295
- return f"Error: {str(e)}"
 
 
 
 
 
296
 
297
- def consult(self, question: str) -> Tuple[str, Any, str]:
298
- sql_prompt = self._generate_sql_prompt(question)
299
  tokenized_input = self.tokenizer(
300
- sql_prompt,
301
  return_tensors="pt",
302
  truncation=True,
303
  max_length=self.config.MAX_TOKENS,
304
  ).to(self.config.DEVICE)
 
 
 
 
305
  with torch.no_grad():
306
  tokenized_output_model = self.model.generate(
307
  **tokenized_input,
@@ -313,33 +279,17 @@ class MojicaAgent:
313
  do_sample=True,
314
  pad_token_id=self.tokenizer.eos_token_id,
315
  )
 
316
  output_model = self.tokenizer.decode(
317
  tokenized_output_model[0], skip_special_tokens=True
318
  )
319
- sql = self._clean_sql_output(output_model)
320
 
321
- # * Ejecución de SQL y generación de analisis
322
- result = self._execute_sql(sql)
323
- # * INICIO DE ANALISIS (COMENTADO)
324
- # Analisis
325
- # analysis_prompt = self._generate_analysis_prompt(question, result)
326
- # analyzed_token_input = self.tokenizer(
327
- # analysis_prompt,
328
- # return_tensors="pt",
329
- # truncation=True,
330
- # max_length=self.config.MAX_TOKENS,
331
- # ).to(self.config.DEVICE)
332
- # with torch.no_grad():
333
- # tokenized_analysis_output_model = self.model.generate(
334
- # **analyzed_token_input,
335
- # max_new_tokens=self.config.MAX_NEW_TOKENS,
336
- # temperature=0.65,
337
- # )
338
- # analysis = self.tokenizer.decode(
339
- # tokenized_analysis_output_model[0], skip_special_tokens=True
340
- # )
341
- # analysis = self._clean_analysis_output(analysis)
342
- # analysis <- LE quite ese parametro
343
- # * FIN DE ANALISIS (COMENTADO)
344
- self.memory.add_interaction(question=question, answer=result, sql=sql)
345
- return sql, result
 
1
+ from memory import ConversationMemory
2
  from config import Config
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import torch, gc
5
  import unicodedata
6
+ from typing import Dict, Tuple
7
+ import re
8
  import pandas as pd
9
  import sqlite3
 
 
10
 
11
  class MojicaAgent:
12
  def __init__(self, config: Config):
13
  self.config = config
14
  self.memory = ConversationMemory()
15
+ self.schema = self._load_schema()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  self._safe_initializer_model()
17
 
18
+ def _safe_initializer_model(self):
19
  def try_load_model():
20
+ dtype = torch.float16 if "cuda" in self.config.DEVICE else torch.float32
21
+ tokenizer = AutoTokenizer.from_pretrained(self.config.MODEL_NAME)
22
+ model = (
23
+ AutoModelForCausalLM.from_pretrained(
24
+ self.config.MODEL_NAME, trust_remote_code=True, torch_dtype=dtype
25
+ )
26
+ .to(self.config.DEVICE)
27
+ .eval()
28
+ ) # eval porque solo se predice
29
+ return tokenizer, model
30
 
31
  try:
32
+ self.tokenizer, self.model = try_load_model()
33
  except torch.cuda.OutOfMemoryError:
34
+ # Liberar memoria y volver a intentar
35
  gc.collect()
36
  torch.cuda.empty_cache()
37
  torch.cuda.ipc_collect()
 
38
 
39
+ self.tokenizer, self.model = try_load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
+ def _load_schema(self) -> Dict:
42
+ conn = sqlite3.connect(self.config.DB_PATH)
43
+ cursor = conn.cursor()
44
+ cursor.execute(f"PRAGMA table_info({self.config.TABLE_NAME})")
45
+ columns = [
46
+ {"name": column[1], "type": column[2]} for column in cursor.fetchall()
47
  ]
48
+ schema = {"table_name": self.config.TABLE_NAME, "columns": columns}
49
+ conn.close()
50
+ return schema
51
+
 
 
 
 
 
52
  def _get_schema_structured(self) -> Dict:
53
  if self.memory.schema_cache:
54
  return self.memory.schema_cache
 
61
  self.memory.schema_cache = schema
62
  return schema
63
 
64
+ def _build_prompt(self, question: str) -> str:
65
+ memory_context = self.memory.get_context(question)
66
  table_name = self.schema["table_name"]
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
+ # 1. Detectar tipo de pregunta
69
+ question_type = (
70
+ "PRODUCTOS"
71
+ if "producto" in question.lower()
72
+ else "CLIENTES" if "cliente" in question.lower() else "GENERAL"
73
+ )
74
+
75
+ # 2. Ejemplos dinámicos
76
+ examples = {
77
+ "PRODUCTOS": (
78
+ "-- P: 'Top 10 productos más vendidos'\n"
79
+ 'SELECT "Descripcion", SUM("Cantidad") AS total_vendido\n'
80
+ f'FROM "{table_name}"\n'
81
+ 'WHERE "Descripcion" IS NOT NULL\n'
82
+ 'GROUP BY "Descripcion"\n'
83
+ "ORDER BY total_vendido DESC\n"
84
+ "LIMIT 10;\n\n"
85
+ "-- P: 'Productos con mayor valor neto'\n"
86
+ 'SELECT "Descripcion", SUM("Neto") AS valor_total\n'
87
+ f'FROM "{table_name}"\n'
88
+ 'WHERE "Descripcion" IS NOT NULL\n'
89
+ 'GROUP BY "Descripcion"\n'
90
+ "ORDER BY valor_total DESC\n"
91
+ "LIMIT 5;"
92
+ ),
93
+ "CLIENTES": (
94
+ "-- P: 'Top 5 clientes con mayor valor neto'\n"
95
+ 'SELECT "Cliente", SUM("Neto") AS valor_total\n'
96
+ f'FROM "{table_name}"\n'
97
+ "WHERE \"Cliente\" IS NOT NULL AND \"Fecha\" BETWEEN '2025-01-01' AND '2025-12-31'\n"
98
+ 'GROUP BY "Cliente"\n'
99
+ "ORDER BY valor_total DESC\n"
100
+ "LIMIT 5;\n\n"
101
+ "-- P: 'Clientes con más compras en marzo'\n"
102
+ 'SELECT "Cliente", COUNT(*) AS total_compras\n'
103
+ f'FROM "{table_name}"\n'
104
+ "WHERE \"Cliente\" IS NOT NULL AND strftime('%m', \"Fecha\") = '03'\n"
105
+ 'GROUP BY "Cliente"\n'
106
+ "ORDER BY total_compras DESC\n"
107
+ "LIMIT 10;\n\n"
108
+ "-- P: 'Clientes de Guadalajara con más compras'\n"
109
+ 'SELECT "Cliente", "Razon Social", COUNT(*) AS total_compras\n'
110
+ f'FROM "{table_name}"\n'
111
+ 'WHERE "Cliente" IS NOT NULL AND "Ciudad" = \'Guadalajara\'\n'
112
+ 'GROUP BY "Cliente", "Razon Social"\n'
113
+ "ORDER BY total_compras DESC\n"
114
+ "LIMIT 10;"
115
+ ),
116
+ "GENERAL": (
117
+ "-- P: 'Ventas totales por mes'\n"
118
+ 'SELECT strftime(\'%m\', "Fecha") AS mes, SUM("Neto") AS ventas\n'
119
+ f'FROM "{table_name}"\n'
120
+ "WHERE mes IS NOT NULL\n"
121
+ "GROUP BY mes\n"
122
+ "ORDER BY mes;\n\n"
123
+ "-- P: 'Producto menos vendido en 2025'\n"
124
+ 'SELECT "Descripcion", SUM("Cantidad") AS total_vendido\n'
125
+ f'FROM "{table_name}"\n'
126
+ "WHERE \"Descripcion\" IS NOT NULL AND \"Fecha\" BETWEEN '2025-01-01' AND '2025-12-31'\n"
127
+ 'GROUP BY "Descripcion"\n'
128
+ "ORDER BY total_vendido ASC\n"
129
+ "LIMIT 1;"
130
+ ),
131
+ }
132
+
133
+ # 3. Columnas esenciales
134
+ essential_columns = [
135
+ {
136
+ "name": "Descripcion",
137
+ "type": "TEXT",
138
+ "description": "Nombre del producto",
139
+ },
140
+ {"name": "Cantidad", "type": "REAL", "description": "Unidades vendidas"},
141
+ {"name": "Cliente", "type": "TEXT", "description": "Código de cliente"},
142
+ {
143
+ "name": "Razon Social",
144
+ "type": "TEXT",
145
+ "description": "Nombre completo del cliente",
146
+ },
147
+ {"name": "Ciudad", "type": "TEXT", "description": "Ciudad del cliente"},
148
+ {
149
+ "name": "Fecha",
150
+ "type": "TEXT",
151
+ "description": "Fecha de venta (YYYY-MM-DD)",
152
+ },
153
+ {"name": "Neto", "type": "REAL", "description": "Valor neto de la venta"},
154
+ ]
155
+
156
+ # 4. Prompt final con nueva regla
157
  return (
158
  f"""
159
  ### TAREA ###
 
164
  + "\n".join(
165
  [
166
  f"- {col['name']} ({col['type']}): {col['description']}"
167
+ for col in essential_columns
168
  ]
169
  )
170
  + f"""
 
172
  ### CONTEXTO (Últimas interacciones) ###
173
  {memory_context if memory_context else "Sin historial relevante"}
174
 
175
+ ### EJEMPLOS ({question_type}) ###
176
+ {examples[question_type]}
177
 
178
  ### REGLAS CRÍTICAS ###
179
  - Usar siempre nombres exactos de columnas
 
 
 
 
 
180
  - Agrupar por la dimensión principal (producto/cliente)
181
  - Ordenar DESC para 'más/mayor', ASC para 'menos/menor'
182
+ - Usar LIMIT para top N
183
  - Año actual: 2025
184
+ - Siempre terminar con un LIMIT = 1 en caso que se indique lo contrario
 
 
 
185
  - Para 'más vendido' usar SUM("Cantidad"), para 'mayor valor' usar SUM("Neto")
186
  - Usar "Razon Social" cuando pregunten por el nombre del cliente
187
  - Usar "Ciudad" para filtrar o agrupar por ubicación
188
  - Queda estrictamente prohibido usar acentos
189
  - **Siempre excluir valores nulos con 'IS NOT NULL' en las columnas usadas en WHERE, GROUP BY u ORDER BY**
190
+
 
 
191
  ### PREGUNTA ACTUAL ###
192
  \"\"\"{question}\"\"\"
193
 
 
195
  """
196
  )
197
 
198
+ def _clean_sql_output(self, output: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  # Encuentra todas las posibles queries completas que terminen en ;
200
  sql_matches = list(
201
  re.finditer(
 
236
  # 2. Agregar LIMIT si no existe
237
  # ────────────────────────────────
238
  # Buscar si ya hay un LIMIT en la query
239
+ if not re.search(r"\bLIMIT\s+\d+", sql, re.IGNORECASE):
240
+ # Insertar antes del último punto y coma
241
+ sql = (
242
+ sql[:-1] + " LIMIT 1;"
243
+ ) # puedes cambiar 100 por el valor default que quieras
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
+ return sql
 
 
 
 
246
 
247
+ def _execute_sql(self, sql: str):
248
+ conn = sqlite3.connect(self.config.DB_PATH)
 
249
  try:
250
+ result = pd.read_sql_query(sql, conn)
251
+ conn.close()
252
+ return result
253
  except Exception as e:
254
+ return f"Error de ejecución: {str(e)}"
255
+ finally:
256
+ conn.close()
257
+
258
+ def consult(self, question: str) -> Tuple[str, any]:
259
+ prompt = self._build_prompt(question)
260
 
 
 
261
  tokenized_input = self.tokenizer(
262
+ prompt,
263
  return_tensors="pt",
264
  truncation=True,
265
  max_length=self.config.MAX_TOKENS,
266
  ).to(self.config.DEVICE)
267
+
268
+ # Desactiva el cálculo de gradientes -> Siempre poner cuando se haga prediccion
269
+ # - Reduce consumo de memoria
270
+ # - Acelera inferencia
271
  with torch.no_grad():
272
  tokenized_output_model = self.model.generate(
273
  **tokenized_input,
 
279
  do_sample=True,
280
  pad_token_id=self.tokenizer.eos_token_id,
281
  )
282
+
283
  output_model = self.tokenizer.decode(
284
  tokenized_output_model[0], skip_special_tokens=True
285
  )
 
286
 
287
+ sql_query = self._clean_sql_output(output_model)
288
+
289
+ if not sql_query:
290
+ return "Error: No se pudo generar SQL válido" + "\n" + output_model, None
291
+
292
+ result = self._execute_sql(sql_query)
293
+ self.memory.add_interaction(question, sql_query, result)
294
+
295
+ return sql_query, result