# invincible-jha
# Add CrewAI orchestrator for agent coordination
# 6f0fdff
# raw
# history blame
# 10.7 kB
import os
import torch
import gradio as gr
import logging
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from utils.log_manager import LogManager
from utils.analytics_logger import AnalyticsLogger
from agents.orchestrator import WellnessOrchestrator
# Force CPU-only mode.
# Patching ``torch.cuda.is_available`` makes every later "is there a GPU?"
# check in this process answer no.
torch.cuda.is_available = lambda: False

# Make float32 the default tensor dtype. ``torch.set_default_tensor_type`` is
# deprecated as of PyTorch 2.1 (it warns on every import), so prefer
# ``set_default_dtype`` and only fall back to the legacy call when the newer
# API is not available.
if hasattr(torch, "set_default_dtype"):
    torch.set_default_dtype(torch.float32)
elif hasattr(torch, "set_default_tensor_type"):
    torch.set_default_tensor_type('torch.FloatTensor')
class WellnessInterface:
    """Gradio chat front end for the mental-wellness assistant.

    Constructing an instance loads the emotion-detection pipeline and the
    conversation model on CPU, creates the ``WellnessOrchestrator`` and builds
    the Gradio ``Blocks`` layout. Call :meth:`launch` to start serving.
    """

    # Crisis contacts shown in the sidebar and reused as the emergency
    # fallback reply, so the two can never drift apart.
    _CRISIS_RESOURCES_MD = (
        "- 📞 Crisis Hotline: 988\n"
        "- 💭 Text HOME to 741741\n"
        "- 🏥 Emergency: 911"
    )

    # Stand-in message used when the user submitted only audio and/or an image.
    _MEDIA_ONLY_MESSAGE = "Sent media"

    def __init__(self, config):
        """Build the full interface.

        Args:
            config: Mapping that must contain ``"MODEL_CONFIGS"`` with
                ``"emotion_detection"`` and ``"conversation"`` entries, each
                providing a ``"model_id"``.

        Raises:
            Exception: Whatever model loading, orchestrator creation or
                interface construction raises (logged, then re-raised).
        """
        self.config = config
        self.log_manager = LogManager()
        self.logger = self.log_manager.get_agent_logger("interface")
        self.analytics = AnalyticsLogger()

        # Ensure CPU-only operation
        self.device = "cpu"
        self.logger.info("Using CPU-only mode")

        # Order matters: the event handlers wired in setup_interface() use the
        # models and the orchestrator at request time.
        self.initialize_models()
        self.initialize_orchestrator()
        self.setup_interface()

    def initialize_models(self):
        """Load the emotion-detection pipeline and the conversation model.

        Sets ``emotion_model``, ``conversation_tokenizer`` and
        ``conversation_model`` on the instance, all placed on ``self.device``.

        Raises:
            Exception: Any loading error is logged and re-raised.
        """
        self.logger.info("Initializing AI models")
        model_configs = self.config["MODEL_CONFIGS"]
        try:
            # Emotion detection model
            self.emotion_model = pipeline(
                "text-classification",
                model=model_configs["emotion_detection"]["model_id"],
                device=self.device,
            )

            # Conversation model
            conversation_id = model_configs["conversation"]["model_id"]
            self.conversation_tokenizer = AutoTokenizer.from_pretrained(
                conversation_id
            )
            self.conversation_model = AutoModelForCausalLM.from_pretrained(
                conversation_id,
                device_map={"": self.device},
            )
            self.logger.info("AI models initialized successfully")
        except Exception as e:
            self.logger.error(f"Error initializing models: {e}")
            raise

    def initialize_orchestrator(self):
        """Create the CrewAI orchestrator from the model configuration.

        Raises:
            Exception: Any construction error is logged and re-raised.
        """
        self.logger.info("Initializing CrewAI orchestrator")
        try:
            self.orchestrator = WellnessOrchestrator(
                model_config=self.config["MODEL_CONFIGS"]
            )
            self.logger.info("Orchestrator initialized successfully")
        except Exception as e:
            self.logger.error(f"Error initializing orchestrator: {e}")
            raise

    def setup_interface(self):
        """Build the Gradio layout and wire its event handlers.

        Stores the ``Blocks`` app on ``self.interface`` and the individual
        components (chatbot, inputs, buttons) on the instance.

        Raises:
            Exception: Any construction error is logged and re-raised.
        """
        self.logger.info("Setting up interface components")
        try:
            with gr.Blocks(
                theme=gr.themes.Soft(),
                css=".gradio-container {background-color: #f7f7f7}",
            ) as self.interface:
                gr.Markdown(
                    "# 🧠 Mental Wellness Support",
                    elem_classes="text-center",
                )
                gr.Markdown(
                    "A safe space for mental health support and guidance.",
                    elem_classes="text-center",
                )

                with gr.Row():
                    with gr.Column(scale=3):
                        # NOTE: ``avatar_images`` is intentionally not set.
                        # Gradio expects image file paths or URLs there; the
                        # emoji strings used previously are not files, so the
                        # component could not load them.
                        self.chatbot = gr.Chatbot(
                            label="Mental Wellness Assistant",
                            height=400,
                            value=[],
                            type="messages",
                            elem_id="wellness_chat",
                        )
                        with gr.Row():
                            self.text_input = gr.Textbox(
                                label="Type your message",
                                placeholder="Enter your message here...",
                                lines=2,
                                scale=4,
                                container=False,
                            )
                            self.submit_btn = gr.Button(
                                "Send",
                                scale=1,
                                variant="primary",
                            )
                        with gr.Row():
                            self.audio_input = gr.Audio(
                                label="Voice Input",
                                type="filepath",
                                format="wav",
                                scale=1,
                            )
                            self.image_input = gr.Image(
                                label="Image Upload",
                                type="filepath",
                                scale=1,
                            )

                    with gr.Column(scale=1):
                        with gr.Group():
                            gr.Markdown("### Quick Actions")
                            self.clear_btn = gr.Button(
                                "🗑️ Clear Chat",
                                variant="secondary",
                            )
                            self.emergency_btn = gr.Button(
                                "🚨 Emergency Help",
                                variant="stop",
                            )
                        gr.Markdown("### Resources")
                        gr.Markdown(self._CRISIS_RESOURCES_MD)

                # Event handlers. The Send button and pressing Enter in the
                # textbox run the same handler with the same inputs/outputs;
                # only the button click is given the explicit "chat" API name.
                chat_event = dict(
                    fn=self.process_input,
                    inputs=[
                        self.text_input,
                        self.audio_input,
                        self.image_input,
                        self.chatbot,
                    ],
                    outputs=[
                        self.chatbot,
                        self.text_input,
                    ],
                )
                self.submit_btn.click(api_name="chat", **chat_event)
                self.clear_btn.click(
                    fn=self.clear_chat,
                    inputs=[],
                    outputs=[self.chatbot],
                    api_name="clear",
                )
                self.emergency_btn.click(
                    fn=self.emergency_help,
                    inputs=[],
                    outputs=[self.chatbot],
                    api_name="emergency",
                )
                # Keyboard shortcut: Enter in the textbox submits
                self.text_input.submit(**chat_event)

            self.logger.info("Interface setup completed successfully")
        except Exception as e:
            self.logger.error(f"Error setting up interface: {e}")
            raise

    @staticmethod
    def _assistant_message(response):
        """Convert an orchestrator response dict into a chat message dict.

        Raises:
            KeyError: If ``response`` lacks ``"message"``, ``"agent_type"``
                or ``"task_type"``.
        """
        return {
            "role": "assistant",
            "content": response["message"],
            "metadata": {
                "agent": response["agent_type"],
                "task": response["task_type"],
            },
        }

    def process_input(self, text, audio, image, history):
        """Handle one user submission and return the updated chat state.

        Args:
            text: Contents of the message textbox (may be empty or ``None``).
            audio: Value of the audio component (a file path, or ``None``).
            image: Value of the image component (a file path, or ``None``).
            history: Current chat history as a list of message dicts, or
                ``None`` when the chat is empty.

        Returns:
            A ``(history, textbox_value)`` tuple. On success the textbox value
            is ``""`` so the input is cleared; on failure the (stripped) text
            is returned so the user does not lose what they typed.

        Only the *presence* of audio/image is forwarded to the orchestrator;
        the files themselves are not read here.
        """
        # Whitespace-only text is not a message.
        text = (text or "").strip()
        if not (text or audio or image):
            return history, ""

        # Work on a copy: never mutate the caller's list, and always hand the
        # orchestrator a list (history is None on the very first turn).
        prior = list(history or [])
        message = text or self._MEDIA_ONLY_MESSAGE

        try:
            # Log the interaction start
            self.analytics.log_user_interaction(
                user_id="anonymous",
                interaction_type="message",
                agent_type="interface",
                duration=0,
                success=True,
                details={"input_types": {
                    "text": bool(text),
                    "audio": bool(audio),
                    "image": bool(image),
                }},
            )

            # truncation=True keeps messages longer than the classifier's
            # maximum input length from raising inside the pipeline.
            emotion = (
                self.emotion_model(text, truncation=True)[0] if text else None
            )

            # Process through orchestrator
            response = self.orchestrator.process_message(
                message=message,
                context={
                    "history": prior,
                    "emotion": emotion,
                    "has_audio": bool(audio),
                    "has_image": bool(image),
                },
            )
            reply = self._assistant_message(response)
        except Exception as e:
            self.logger.error(f"Error processing input: {e}")
            apology = {
                "role": "assistant",
                "content": "I apologize, but I encountered an error. Please try again.",
            }
            return prior + [apology], text  # Keep text input in case of error

        user_message = {"role": "user", "content": message}
        return prior + [user_message, reply], ""  # "" clears the text input

    def clear_chat(self):
        """Clear the chat history.

        Returns:
            An empty list, matching the chatbot's initial ``value=[]``.
        """
        self.logger.info("Clearing chat history")
        return []

    def emergency_help(self):
        """Provide emergency help information.

        Asks the orchestrator for a crisis response. If that fails for any
        reason, falls back to the static crisis contacts so that the user
        always receives help information when pressing the emergency button.

        Returns:
            A one-element chat history containing the assistant's reply. It
            replaces the chatbot's current contents.
        """
        self.logger.info("Emergency help requested")
        try:
            # Use crisis agent through orchestrator
            response = self.orchestrator.process_message(
                message="EMERGENCY_HELP_REQUESTED",
                context={"is_emergency": True},
            )
            return [self._assistant_message(response)]
        except Exception as e:
            # This path must never surface a bare error to someone in crisis.
            self.logger.error(f"Error handling emergency request: {e}")
            return [{
                "role": "assistant",
                "content": (
                    "If you are in crisis or immediate danger, please reach "
                    "out right now:\n\n" + self._CRISIS_RESOURCES_MD
                ),
            }]

    def launch(self, **kwargs):
        """Launch the interface.

        Args:
            **kwargs: Forwarded to ``gradio.Blocks.launch``. ``show_api``,
                ``show_error`` and ``quiet`` are always overridden with the
                values below.
        """
        self.logger.info("Launching interface")
        # Configure for Hugging Face Spaces
        kwargs.update({
            "show_api": False,
            "show_error": True,
            "quiet": True,
        })
        self.interface.launch(**kwargs)