First_agent_template / model_init.py
DocUA's picture
Підключено Gemini, але потребує агентної роботи
435589f
raw
history blame
14.2 kB
from typing import Optional, Dict, Any
import os
import google.generativeai as genai
from huggingface_hub import HfApi
import logging
from smolagents import HfApiModel
logger = logging.getLogger(__name__)
class ModelWrapper:
"""Спрощена обгортка для моделей"""
def __init__(self, model, model_type):
self.model = model
self.model_type = model_type
def __call__(self, prompt, **kwargs):
try:
if self.model_type == 'gemini':
# Якщо prompt - це словник з роллю і контентом
if isinstance(prompt, dict) and 'content' in prompt:
text = prompt['content']
# Якщо prompt - це список повідомлень
elif isinstance(prompt, list):
# Беремо останнє повідомлення
last_message = prompt[-1]
text = last_message.get('content', '') if isinstance(last_message, dict) else str(last_message)
else:
text = str(prompt)
logger.info(f"Prompt для Gemini: {text[:100]}...")
response = self.model.generate_content(text)
return response.text
else:
kwargs.pop('stop_sequences', None) # Видаляємо stop_sequences для HF моделей
return self.model(prompt, **kwargs)
except Exception as e:
logger.error(f"Помилка ModelWrapper: {str(e)}")
logger.error(f"Тип prompt: {type(prompt)}")
logger.error(f"Prompt: {str(prompt)[:200]}")
raise
"""Обгортка для уніфікації інтерфейсу різних моделей"""
def __init__(self, model, model_type):
self.model = model
self.model_type = model_type
def _extract_text_from_input(self, input_data):
"""Витягує текст з різних форматів вхідних даних"""
try:
if isinstance(input_data, str):
return input_data
elif isinstance(input_data, list):
# Обробка списку повідомлень
messages = []
for msg in input_data:
if isinstance(msg, dict):
# Витягуємо контент з повідомлення
content = msg.get('content', '')
if isinstance(content, list):
# Якщо контент - список, обробляємо кожен елемент
for item in content:
if isinstance(item, dict) and 'text' in item:
messages.append(item['text'])
else:
messages.append(str(item))
else:
messages.append(str(content))
else:
messages.append(str(msg))
return ' '.join(messages)
elif isinstance(input_data, dict):
# Обробка одиночного повідомлення
content = input_data.get('content', '')
if isinstance(content, list):
return ' '.join(item.get('text', str(item)) for item in content if isinstance(item, dict))
return str(content)
return str(input_data)
except Exception as e:
logger.error(f"Помилка при обробці вхідних даних: {e}")
logger.error(f"Тип даних: {type(input_data)}")
logger.error(f"Дані: {str(input_data)[:200]}")
return str(input_data)
def __call__(self, prompt: str, **kwargs) -> str:
"""
Виклик моделі з підтримкою додаткових параметрів.
Args:
prompt: Текст запиту
**kwargs: Додаткові параметри (ігноруються для Gemini)
"""
try:
if self.model_type == 'gemini':
# Для Gemini ігноруємо додаткові параметри і просто передаємо текст
text = self._extract_text_from_input(prompt)
logger.info(f"Gemini отримав запит: {text[:200]}...") # Логуємо перші 200 символів
response = self.model.generate_content(text)
if response and hasattr(response, 'text'):
logger.info("Gemini успішно згенерував відповідь")
return response.text
else:
error_msg = "Gemini повернув порожню або неправильну відповідь"
logger.error(error_msg)
raise ValueError(error_msg)
else: # huggingface model
# Видаляємо stop_sequences, якщо він є
kwargs.pop('stop_sequences', None)
return self.model(prompt, **kwargs)
except Exception as e:
logger.error(f"Помилка при виклику моделі {self.model_type}: {e}")
logger.error(f"Тип вхідних даних: {type(prompt)}")
logger.error(f"Вміст вхідних даних: {str(prompt)[:200]}")
raise
"""Обгортка для уніфікації інтерфейсу різних моделей"""
def __init__(self, model, model_type):
self.model = model
self.model_type = model_type
def _extract_text_from_input(self, input_data):
"""Витягує текст з різних форматів вхідних даних"""
if isinstance(input_data, str):
return input_data
elif isinstance(input_data, dict):
return input_data.get('content', str(input_data))
elif isinstance(input_data, list):
return ' '.join(self._extract_text_from_input(item) for item in input_data)
return str(input_data)
def __call__(self, prompt: str, **kwargs) -> str:
"""
Виклик моделі з підтримкою додаткових параметрів.
Args:
prompt: Текст запиту
**kwargs: Додаткові параметри (ігноруються для Gemini)
"""
try:
if self.model_type == 'gemini':
# Для Gemini ігноруємо додаткові параметри і просто передаємо текст
text = self._extract_text_from_input(prompt)
logger.info(f"Gemini отримав запит: {text[:200]}...") # Логуємо перші 200 символів
response = self.model.generate_content(text)
if response and hasattr(response, 'text'):
logger.info("Gemini успішно згенерував відповідь")
return response.text
else:
error_msg = "Gemini повернув порожню або неправильну відповідь"
logger.error(error_msg)
raise ValueError(error_msg)
else: # huggingface model
# Видаляємо stop_sequences, якщо він є
kwargs.pop('stop_sequences', None)
return self.model(prompt, **kwargs)
except Exception as e:
logger.error(f"Помилка при виклику моделі {self.model_type}: {e}")
logger.error(f"Тип вхідних даних: {type(prompt)}")
logger.error(f"Вміст вхідних даних: {str(prompt)[:200]}") # Логуємо перші 200 символів
raise
class ModelInitializer:
def __init__(self, config_path: str = 'models_config.json'):
self.config = self._load_config(config_path)
self._setup_api_keys()
def _setup_api_keys(self):
"""Налаштування API ключів"""
self.hf_api_token = os.getenv('HF_API_TOKEN')
self.gemini_api_key = os.getenv('GEMINI_API_KEY')
if self.gemini_api_key:
genai.configure(api_key=self.gemini_api_key)
logger.info("API ключ Gemini успішно налаштовано")
else:
logger.warning("API ключ Gemini не знайдено")
def _load_config(self, config_path: str) -> Dict[str, Any]:
"""Завантаження конфігурації моделей"""
import json
try:
with open(config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
logger.info(f"Конфігурацію успішно завантажено з {config_path}")
return config
except Exception as e:
logger.error(f"Помилка завантаження конфігурації: {e}")
raise
def initialize_model(self, model_key: Optional[str] = None) -> Any:
"""Ініціалізація вибраної моделі"""
try:
if model_key is None:
model_key = self.config['default_model']
logger.info(f"Використовуємо модель за замовчуванням: {model_key}")
model_config = self.config['models'].get(model_key)
if not model_config:
error_msg = f"Модель {model_key} не знайдена в конфігурації"
logger.error(error_msg)
raise ValueError(error_msg)
if model_key == 'gemini-flash':
model = self._initialize_gemini(model_config)
logger.info("Ініціалізовано Gemini модель")
return ModelWrapper(model, 'gemini')
else:
model = self._initialize_huggingface(model_config)
logger.info("Ініціалізовано HuggingFace модель")
return ModelWrapper(model, 'huggingface')
except Exception as e:
error_msg = f"Помилка ініціалізації моделі: {e}"
logger.error(error_msg)
raise ValueError(error_msg)
def _initialize_gemini(self, config: Dict[str, Any]) -> Any:
"""Ініціалізація Gemini моделі"""
if not self.gemini_api_key:
raise ValueError("GEMINI_API_KEY не знайдено в змінних середовища")
try:
# Налаштування безпеки (вимикаємо всі обмеження для наукових досліджень)
safety_settings = [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE"
}
]
# Створюємо модель з мінімальними обмеженнями
model = genai.GenerativeModel(
model_name='gemini-pro',
safety_settings=safety_settings
)
# Тестуємо модель
try:
logger.info("Тестування з'єднання з Gemini...")
test_response = model.generate_content("Test connection")
if test_response and hasattr(test_response, 'text'):
logger.info("Gemini модель успішно ініціалізовано та протестовано")
return model
else:
raise ValueError("Тестова генерація не повернула текст")
except Exception as e:
raise ValueError(f"Помилка тестування Gemini: {str(e)}")
except Exception as e:
error_msg = f"Помилка ініціалізації Gemini: {str(e)}"
logger.error(error_msg)
raise ValueError(error_msg)
def _initialize_huggingface(self, config: Dict[str, Any]) -> Any:
"""Ініціалізація Hugging Face моделі"""
if not self.hf_api_token:
raise ValueError("HF_API_TOKEN не знайдено в змінних середовища")
try:
model = HfApiModel(
model_id=config['model_id'],
token=self.hf_api_token,
temperature=config['parameters']['temperature'],
max_tokens=config['parameters']['max_tokens']
)
return model
except Exception as e:
error_msg = f"Помилка ініціалізації HuggingFace: {str(e)}"
logger.error(error_msg)
raise ValueError(error_msg)
def get_available_models(self) -> list:
"""Отримання списку доступних моделей"""
return [(key, model['description'])
for key, model in self.config['models'].items()]