Spaces:
Sleeping
Sleeping
| from sentence_transformers import SentenceTransformer | |
| from sklearn.cluster import KMeans | |
| from config import Config | |
| from load_json import load_examples | |
| from sentence_transformers import SentenceTransformer | |
| from sklearn.cluster import KMeans | |
class SemanticClassifier:
    """Cluster example questions by sentence embedding and retrieve similar ones.

    Example questions are encoded with a SentenceTransformer model, grouped
    with KMeans, and stored per cluster. ``classify`` encodes a new question,
    asks the fitted KMeans model which cluster it belongs to, and returns the
    examples stored for that cluster.

    Attributes:
        model: The SentenceTransformer instance used to encode questions.
        clusters: Mapping of cluster id (int) to the list of example dicts
            (keys ``"category"``, ``"pregunta"``, ``"query"``) in that cluster.
        examples_embeddings: Embeddings of the training questions from the
            most recent ``train`` call, or None before training.
        kmeans: The fitted KMeans model, or None before training.
    """

    def __init__(self, model_name="paraphrase-multilingual-MiniLM-L12-v2", initialized_train=True):
        """Create the encoder and optionally train on the default examples.

        Args:
            model_name: Name passed to ``SentenceTransformer``.
            initialized_train: When true, ``train()`` is called immediately
                with its default arguments.
        """
        self.model = SentenceTransformer(model_name)
        self.clusters = {}
        self.examples_embeddings = None
        self.kmeans = None
        if initialized_train:
            self.train()

    def train(self, train_data=Config.EXMAPLES_JSON, n_clusters=15):
        """Load the examples, embed their questions and cluster them.

        Each call replaces the previous training state, so retraining never
        mixes examples from an earlier run into the new clusters.

        Args:
            train_data: Source handed to ``load_examples``. The result is
                consumed as a mapping of category -> iterable of dicts that
                have ``"pregunta"`` and ``"query"`` keys.
            n_clusters: Requested number of clusters. It is capped at the
                number of loaded examples, because KMeans cannot fit more
                clusters than samples.

        Raises:
            ValueError: If no examples were loaded.
        """
        examples = load_examples(train_data)

        # * Flatten the per-category examples into a single list
        flat_examples = [
            {
                "category": category,
                "pregunta": item["pregunta"],
                "query": item["query"],
            }
            for category, items in examples.items()
            for item in items
        ]
        if not flat_examples:
            raise ValueError("No training examples were loaded; cannot train the classifier.")

        # * Compute the embeddings
        questions = [example["pregunta"] for example in flat_examples]
        embeddings = self.model.encode(questions)

        # * Clustering (never ask for more clusters than there are samples)
        effective_clusters = min(n_clusters, len(flat_examples))
        kmeans = KMeans(n_clusters=effective_clusters, random_state=12)
        cluster_ids = kmeans.fit_predict(embeddings)

        # * Group the examples by cluster in a fresh dict, so a retrain
        # * does not keep entries from a previous run
        clusters = {}
        for example, cluster_id in zip(flat_examples, cluster_ids):
            # * Plain int keys instead of NumPy scalars
            clusters.setdefault(int(cluster_id), []).append(example)

        # * Commit the new state only once everything above succeeded
        self.kmeans = kmeans
        self.clusters = clusters
        self.examples_embeddings = embeddings

    def classify(self, question: str):
        """Return the stored examples of the cluster predicted for ``question``.

        Args:
            question: The question text to classify.

        Returns:
            The list of example dicts stored for the predicted cluster, or an
            empty list if nothing is stored under that cluster id.

        Raises:
            RuntimeError: If the classifier has not been trained yet.
        """
        if self.kmeans is None:
            raise RuntimeError("SemanticClassifier is not trained; call train() before classify().")

        # * Encode the question as an embedding
        question_embedding = self.model.encode([question])

        # * Find the predicted cluster
        cluster_id = int(self.kmeans.predict(question_embedding)[0])

        # * Return the examples of that cluster
        return self.clusters.get(cluster_id, [])
| # * HOW TO USE | |
| # classifier = SemanticClassifier(initialized_train=False)  # skip the default training run in __init__ | |
| # classifier.train(Config.EXMAPLES_JSON, n_clusters=5) | |
| # resultado = classifier.classify("¿Cuantas ciudades tenemos registradas?") | |
| # print(resultado)  # prints the examples stored for the matching cluster | |