Spaces:
Build error
Build error
| from sdc_classifier import SDCClassifier | |
| from dotenv import load_dotenv | |
| import torch | |
| import json | |
| import os | |
| from typing import Dict, Tuple, Optional, Any, List | |
| from dataclasses import dataclass, field | |
| import pandas as pd | |
| # Load environment variables | |
| load_dotenv() | |
| class Config: | |
| # DEFAULT_CLASSES_FILE: str = "classes.json" | |
| DEFAULT_CLASSES_FILE: str = "kw_questions.json" | |
| DEFAULT_SIGNATURES_FILE: str = "signatures.npz" | |
| CACHE_FILE: str = "embeddings_cache.db" | |
| MODEL_INFO_FILE: str = "model_info.json" | |
| DEFAULT_OPENAI_MODELS: List[str] = field(default_factory=lambda: ["text-embedding-3-large", "text-embedding-3-small"]) | |
| DEFAULT_LOCAL_MODEL: str = "cambridgeltl/SapBERT-from-PubMedBERT-fulltext" | |
| config = Config() | |
| class ClassifierApp: | |
| def __init__(self): | |
| self.classifier = None | |
| self.initial_info = { | |
| "status": "initializing", | |
| "model_info": {}, | |
| "classes_info": {}, | |
| "errors": [] | |
| } | |
| self.model_type = "Local" # Додати цей рядок | |
| def initialize_environment(self) -> Tuple[Dict, Optional[SDCClassifier]]: | |
| """Ініціалізація середовища при першому запуску""" | |
| try: | |
| # Перевіряємо наявність необхідних файлів | |
| if not os.path.exists(config.DEFAULT_CLASSES_FILE): | |
| self.initial_info["errors"].append( | |
| f"ПОМИЛКА: Файл {config.DEFAULT_CLASSES_FILE} не знайдено!" | |
| ) | |
| self.initial_info["status"] = "error" | |
| print(f"\nПомилка: Файл {config.DEFAULT_CLASSES_FILE} не знайдено!") | |
| return self.initial_info, None | |
| print("\nСтворення класифікатора...") | |
| try: | |
| # Визначаємо яка модель використовувалась для сигнатур | |
| signatures_model = None | |
| if os.path.exists(config.MODEL_INFO_FILE): | |
| with open(config.MODEL_INFO_FILE, 'r') as f: | |
| model_info = json.load(f) | |
| if not model_info.get('using_local', True): | |
| signatures_model = "text-embedding-3-small" # Модель, яка використовувалась | |
| # Створюємо класифікатор з тією ж моделлю | |
| self.classifier = SDCClassifier(openai_api_key=os.getenv("OPENAI_API_KEY")) | |
| print(f"Використовується модель: {signatures_model or 'local'}") | |
| except Exception as e: | |
| print(f"\nПомилка при створенні класифікатора: {str(e)}") | |
| self.initial_info["errors"].append(f"Помилка при створенні класифікатора: {str(e)}") | |
| self.initial_info["status"] = "error" | |
| return self.initial_info, None | |
| print("\nЗавантаження класів...") | |
| try: | |
| classes = self.classifier.load_classes(config.DEFAULT_CLASSES_FILE) | |
| self.initial_info["classes_info"] = { | |
| "total_classes": len(classes), | |
| "classes_list": list(classes.keys()), | |
| "hints_per_class": {cls: len(hints) for cls, hints in classes.items()} | |
| } | |
| except Exception as e: | |
| print(f"\nПомилка при завантаженні класів: {str(e)}") | |
| self.initial_info["errors"].append(f"Помилка при завантаженні класів: {str(e)}") | |
| self.initial_info["status"] = "error" | |
| return self.initial_info, None | |
| print("\nПеревірка та завантаження сигнатур...") | |
| if os.path.exists(config.DEFAULT_SIGNATURES_FILE): | |
| try: | |
| self.classifier.load_signatures(config.DEFAULT_SIGNATURES_FILE) | |
| self.initial_info["status"] = "success" | |
| print("Сигнатури завантажено успішно") | |
| except Exception as e: | |
| print(f"\nПомилка при завантаженні сигнатур: {str(e)}") | |
| self.initial_info["errors"].append(f"Помилка при завантаженні сигнатур: {str(e)}") | |
| self.initial_info["status"] = "error" | |
| return self.initial_info, None | |
| else: | |
| print("\nСтворення нових сигнатур...") | |
| self.initial_info["status"] = "creating_signatures" | |
| try: | |
| result = self.classifier.initialize_signatures( | |
| force_rebuild=True, | |
| signatures_file=config.DEFAULT_SIGNATURES_FILE | |
| ) | |
| if isinstance(result, str) and "error" in result.lower(): | |
| print(f"\nПомилка при створенні сигнатур: {result}") | |
| self.initial_info["errors"].append(result) | |
| self.initial_info["status"] = "error" | |
| return self.initial_info, None | |
| except Exception as e: | |
| print(f"\nПомилка при створенні сигнатур: {str(e)}") | |
| self.initial_info["errors"].append(f"Помилка при створенні сигнатур: {str(e)}") | |
| self.initial_info["status"] = "error" | |
| return self.initial_info, None | |
| print("\nЗбереження інформації про модель...") | |
| try: | |
| self.classifier.save_model_info(config.MODEL_INFO_FILE) | |
| with open(config.MODEL_INFO_FILE, "r") as f: | |
| self.initial_info["model_info"] = json.load(f) | |
| self.initial_info["status"] = "success" | |
| print("\nІніціалізація завершена успішно") | |
| return self.initial_info, self.classifier | |
| except Exception as e: | |
| print(f"\nПомилка при збереженні інформації про модель: {str(e)}") | |
| self.initial_info["errors"].append(f"Помилка при читанні model_info: {str(e)}") | |
| self.initial_info["status"] = "error" | |
| return self.initial_info, None | |
| except Exception as e: | |
| print(f"\nЗагальна помилка при ініціалізації: {str(e)}") | |
| self.initial_info["errors"].append(f"ПОМИЛКА при ініціалізації: {str(e)}") | |
| self.initial_info["status"] = "error" | |
| return self.initial_info, None | |
| def create_classifier( | |
| self, | |
| model_type: str, | |
| openai_model: Optional[str] = None, | |
| local_model: Optional[str] = None, | |
| device: Optional[str] = None | |
| ) -> SDCClassifier: | |
| """Створення класифікатора з відповідними параметрами""" | |
| classifier = SDCClassifier() | |
| if model_type == "OpenAI": | |
| if hasattr(classifier, 'set_openai_model'): | |
| classifier.set_openai_model(openai_model) | |
| else: | |
| if hasattr(classifier, 'set_local_model'): | |
| classifier.set_local_model(local_model, device) | |
| return classifier | |
| def update_model_inputs( | |
| self, | |
| model_type: str, | |
| openai_model: str, | |
| local_model: str, | |
| device: str | |
| ) -> Dict[str, Any]: | |
| """Оновлення моделі та інтерфейсу при зміні типу моделі""" | |
| try: | |
| self.classifier = self.create_classifier( | |
| model_type=model_type, | |
| openai_model=openai_model if model_type == "OpenAI" else None, | |
| local_model=local_model if model_type == "Local" else None, | |
| device=device if model_type == "Local" else None | |
| ) | |
| self.classifier.restore_base_state() | |
| result = self.classifier.initialize_signatures() | |
| self.classifier.save_model_info(config.MODEL_INFO_FILE) | |
| with open(config.MODEL_INFO_FILE, "r") as f: | |
| model_info = json.load(f) | |
| new_system_info = { | |
| "status": "success", | |
| "model_info": model_info, | |
| "classes_info": { | |
| "total_classes": len(self.classifier.classes_json), | |
| "classes_list": list(self.classifier.classes_json.keys()), | |
| "hints_per_class": {cls: len(hints) for cls, hints in self.classifier.classes_json.items()} | |
| }, | |
| "errors": [] | |
| } | |
| return { | |
| "model_choice": gr.update(visible=model_type == "OpenAI"), | |
| "local_model_path": gr.update(visible=model_type == "Local"), | |
| "device_choice": gr.update(visible=model_type == "Local"), | |
| "system_info": new_system_info, | |
| "system_md": self.update_system_markdown(new_system_info), | |
| "build_out": f"Модель змінено на {model_type}", | |
| "cache_stats": self.classifier.get_cache_stats() | |
| } | |
| except Exception as e: | |
| error_info = { | |
| "status": "error", | |
| "errors": [str(e)], | |
| "model_info": {}, | |
| "classes_info": {} | |
| } | |
| return { | |
| "model_choice": gr.update(visible=model_type == "OpenAI"), | |
| "local_model_path": gr.update(visible=model_type == "Local"), | |
| "device_choice": gr.update(visible=model_type == "Local"), | |
| "system_info": error_info, | |
| "system_md": self.update_system_markdown(error_info), | |
| "build_out": f"Помилка: {str(e)}", | |
| "cache_stats": {} | |
| } | |
| def update_classifier_settings( | |
| self, | |
| json_file: Optional[str], | |
| model_type: str, | |
| openai_model: str, | |
| local_model: str, | |
| device: str, | |
| force_rebuild: bool | |
| ) -> Tuple[str, Dict, Dict, str]: | |
| """Оновлення налаштувань класифікатора""" | |
| try: | |
| self.classifier = self.create_classifier( | |
| model_type=model_type, | |
| openai_model=openai_model if model_type == "OpenAI" else None, | |
| local_model=local_model if model_type == "Local" else None, | |
| device=device if model_type == "Local" else None | |
| ) | |
| if json_file is not None: | |
| with open(json_file.name, 'r', encoding='utf-8') as f: | |
| new_classes = json.load(f) | |
| self.classifier.load_classes(new_classes) | |
| else: | |
| self.classifier.restore_base_state() | |
| result = self.classifier.initialize_signatures( | |
| force_rebuild=force_rebuild, | |
| signatures_file=config.DEFAULT_SIGNATURES_FILE if not force_rebuild else None | |
| ) | |
| self.classifier.save_model_info(config.MODEL_INFO_FILE) | |
| with open(config.MODEL_INFO_FILE, "r") as f: | |
| model_info = json.load(f) | |
| new_system_info = { | |
| "status": "success", | |
| "model_info": model_info, | |
| "classes_info": { | |
| "total_classes": len(self.classifier.classes_json), | |
| "classes_list": list(self.classifier.classes_json.keys()), | |
| "hints_per_class": { | |
| cls: len(hints) | |
| for cls, hints in self.classifier.classes_json.items() | |
| } | |
| }, | |
| "errors": [] | |
| } | |
| return ( | |
| result, | |
| self.classifier.get_cache_stats(), | |
| new_system_info, | |
| self.update_system_markdown(new_system_info) | |
| ) | |
| except Exception as e: | |
| error_info = { | |
| "status": "error", | |
| "errors": [str(e)], | |
| "model_info": {}, | |
| "classes_info": {} | |
| } | |
| return ( | |
| f"Помилка: {str(e)}", | |
| self.classifier.get_cache_stats(), | |
| error_info, | |
| self.update_system_markdown(error_info) | |
| ) | |
| def process_single_text(self, text: str, threshold: float) -> Dict: | |
| """Обробка одного тексту""" | |
| try: | |
| if self.classifier is None: | |
| raise ValueError("Класифікатор не ініціалізовано") | |
| return self.classifier.process_single_text(text, threshold) | |
| except Exception as e: | |
| return {"error": str(e)} | |
| def load_data(self, csv_path: str, emb_path: str) -> str: | |
| """Завантаження даних для пакетної обробки""" | |
| try: | |
| if self.classifier is None: | |
| raise ValueError("Класифікатор не ініціалізовано") | |
| return self.classifier.load_data(csv_path, emb_path) | |
| except Exception as e: | |
| return f"Помилка: {str(e)}" | |
| def classify_batch(self, filter_str: str, threshold: float): | |
| """Пакетна класифікація""" | |
| try: | |
| if self.classifier is None: | |
| raise ValueError("Класифікатор не ініціалізовано") | |
| return self.classifier.classify_rows(filter_str, threshold) | |
| except Exception as e: | |
| return None | |
| def save_results(self) -> str: | |
| """Збереження результатів""" | |
| try: | |
| if self.classifier is None: | |
| raise ValueError("Класифікатор не ініціалізовано") | |
| return self.classifier.save_results() | |
| except Exception as e: | |
| return f"Помилка: {str(e)}" | |
| def sync_system_info(self) -> Dict: | |
| """Синхронізація системної інформації""" | |
| try: | |
| if self.classifier is None: | |
| raise ValueError("Класифікатор не ініціалізовано") | |
| self.classifier.save_model_info(config.MODEL_INFO_FILE) | |
| with open(config.MODEL_INFO_FILE, "r") as f: | |
| model_info = json.load(f) | |
| self.initial_info = { | |
| "status": "success", | |
| "model_info": model_info, | |
| "classes_info": { | |
| "total_classes": len(self.classifier.classes_json), | |
| "classes_list": list(self.classifier.classes_json.keys()), | |
| "hints_per_class": { | |
| cls: len(hints) | |
| for cls, hints in self.classifier.classes_json.items() | |
| } | |
| }, | |
| "errors": [] | |
| } | |
| return self.initial_info | |
| except Exception as e: | |
| self.initial_info = { | |
| "status": "error", | |
| "model_info": {}, | |
| "classes_info": {}, | |
| "errors": [str(e)] | |
| } | |
| return self.initial_info | |
| def evaluate_batch(self, csv_file, threshold: float) -> tuple[pd.DataFrame, str]: | |
| """ | |
| Оцінка пакетної класифікації | |
| Args: | |
| csv_file: завантажений CSV файл від gradio | |
| threshold: поріг впевненості | |
| Returns: | |
| tuple[pd.DataFrame, str]: результати та статистика | |
| """ | |
| try: | |
| if self.classifier is None: | |
| return None, "Помилка: Класифікатор не ініціалізовано" | |
| # Перевірка на None | |
| if csv_file is None: | |
| return None, "Помилка: Файл не завантажено" | |
| # Зберігаємо тимчасовий файл | |
| temp_path = "temp_upload.csv" | |
| if hasattr(csv_file, 'name'): | |
| # Якщо це файловий об'єкт від gradio | |
| import shutil | |
| shutil.copy2(csv_file.name, temp_path) | |
| else: | |
| # Якщо це шлях до файлу | |
| temp_path = str(csv_file) | |
| # Виконуємо класифікацію | |
| results_df, statistics = self.classifier.evaluate_classification(temp_path, threshold) | |
| # Формуємо текст статистики | |
| stats_md = f"""### Статистика класифікації | |
| - Всього зразків: {statistics['total_samples']} | |
| - Правильний клас на першому місці: {statistics['correct_first_place']['count']} ({statistics['correct_first_place']['percentage']}%) | |
| - Правильний клас в топ-3: {statistics['in_top3']['count']} ({statistics['in_top3']['percentage']}%) | |
| - Правильний клас не знайдено: {statistics['not_found']['count']} ({statistics['not_found']['percentage']}%) | |
| #### Середня впевненість для правильних класифікацій: {statistics['mean_confidence_correct']}% | |
| #### Розподіл впевненості: | |
| - 90-100%: {statistics['confidence_distribution']['90-100%']['count']} ({statistics['confidence_distribution']['90-100%']['percentage']}%) | |
| - 70-90%: {statistics['confidence_distribution']['70-90%']['count']} ({statistics['confidence_distribution']['70-90%']['percentage']}%) | |
| - 50-70%: {statistics['confidence_distribution']['50-70%']['count']} ({statistics['confidence_distribution']['50-70%']['percentage']}%) | |
| - <50%: {statistics['confidence_distribution']['<50%']['count']} ({statistics['confidence_distribution']['<50%']['percentage']}%) | |
| """ | |
| # Зберігаємо результати для подальшого використання | |
| self.current_evaluation_results = results_df | |
| # Видаляємо тимчасовий файл якщо він був створений | |
| if temp_path == "temp_upload.csv" and os.path.exists(temp_path): | |
| os.remove(temp_path) | |
| return results_df, stats_md | |
| except Exception as e: | |
| # У випадку помилки спробуємо видалити тимчасовий файл | |
| if os.path.exists("temp_upload.csv"): | |
| os.remove("temp_upload.csv") | |
| return None, f"Помилка: {str(e)}" | |
| def save_evaluation_results(self) -> tuple[str, str]: | |
| """ | |
| Зберігає результати останньої оцінки класифікації та готує файл для завантаження | |
| Returns: | |
| tuple[str, str]: (шлях до файлу, повідомлення про статус) | |
| """ | |
| try: | |
| if not hasattr(self, 'current_evaluation_results'): | |
| return None, "Помилка: Немає результатів для збереження" | |
| output_path = "evaluation_results.csv" | |
| self.current_evaluation_results.to_csv(output_path, index=False) | |
| return output_path, f"Результати збережено у файл {output_path}" | |
| except Exception as e: | |
| return None, f"Помилка при збереженні: {str(e)}" | |
| def update_system_markdown(info: Dict) -> str: | |
| """Оновлення Markdown з системною інформацією""" | |
| if info["status"] == "success": | |
| return f""" | |
| ### Поточна конфігурація: | |
| - Модель: {info['model_info'].get('using_local', 'OpenAI')} | |
| - Кількість класів: {info['classes_info']['total_classes']} | |
| - Класи: {', '.join(info['classes_info']['classes_list'])} | |
| """ | |
| else: | |
| return f""" | |
| ### Помилки ініціалізації: | |
| {chr(10).join('- ' + err for err in info['errors'])} | |
| """ |