"""
app.py
"""
# Standard imports
import asyncio
import json
import os
import uuid
from datetime import datetime
# Third party imports
import openai
import gradio as gr
import gspread
from google.oauth2 import service_account
from transformers import AutoModel
# Local imports
from utils import get_embeddings
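# get_embeddings (defined in utils.py) is expected to return one embedding vector
# per input text; those vectors are the inputs to the LionGuard 2 classifier.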
# --- Categories
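# Each category maps to the model's output keys; the _l1/_l2 suffixes appear to
# denote increasing severity levels within a category.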
CATEGORIES = {
    "binary": ["binary"],
    "hateful": ["hateful_l1", "hateful_l2"],
    "insults": ["insults"],
    "sexual": [
        "sexual_l1",
        "sexual_l2",
    ],
    "physical_violence": ["physical_violence"],
    "self_harm": ["self_harm_l1", "self_harm_l2"],
    "all_other_misconduct": [
        "all_other_misconduct_l1",
        "all_other_misconduct_l2",
    ],
}
# --- OpenAI Setup ---
# Create both sync and async clients
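# The async client drives the three-way chatbot comparison; the sync client is
# only used for background moderation logging.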
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
async_client = openai.AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# --- Model Loading ---
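# LionGuard 2 ships its classifier as custom code on the Hugging Face Hub,
# hence trust_remote_code=True below.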
def load_lionguard2():
    model = AutoModel.from_pretrained("govtech/lionguard-2", trust_remote_code=True)
    return model

model = load_lionguard2()
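# model.predict(embeddings) returns a dict mapping category keys (see CATEGORIES)
# to a list of scores, one per input text.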
# --- Google Sheets Config ---
GOOGLE_SHEET_URL = os.environ.get("GOOGLE_SHEET_URL")
GOOGLE_CREDENTIALS = os.environ.get("GCP_SERVICE_ACCOUNT")
RESULTS_SHEET_NAME = "results"
VOTES_SHEET_NAME = "votes"
CHATBOT_SHEET_NAME = "chatbot"
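# Logging is optional: it only runs when both GOOGLE_SHEET_URL and GCP_SERVICE_ACCOUNT
# are set. "results" stores classifier runs, "votes" stores thumbs-up/down feedback,
# and "chatbot" stores guardrail-comparison messages.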
def get_gspread_client():
    credentials = service_account.Credentials.from_service_account_info(
        json.loads(GOOGLE_CREDENTIALS),
        scopes=[
            "https://www.googleapis.com/auth/spreadsheets",
            "https://www.googleapis.com/auth/drive",
        ],
    )
    return gspread.authorize(credentials)

def save_results_data(row):
    try:
        gc = get_gspread_client()
        sheet = gc.open_by_url(GOOGLE_SHEET_URL)
        ws = sheet.worksheet(RESULTS_SHEET_NAME)
        ws.append_row(list(row.values()))
    except Exception as e:
        print(f"Error saving results data: {e}")

def save_vote_data(text_id, agree):
    try:
        gc = get_gspread_client()
        sheet = gc.open_by_url(GOOGLE_SHEET_URL)
        ws = sheet.worksheet(VOTES_SHEET_NAME)
        vote_row = {
            "datetime": datetime.now().isoformat(),
            "text_id": text_id,
            "agree": agree,
        }
        ws.append_row(list(vote_row.values()))
    except Exception as e:
        print(f"Error saving vote data: {e}")

def log_chatbot_data(row):
    try:
        gc = get_gspread_client()
        sheet = gc.open_by_url(GOOGLE_SHEET_URL)
        ws = sheet.worksheet(CHATBOT_SHEET_NAME)
        ws.append_row([
            row["datetime"], row["text_id"], row["text"], row["binary_score"],
            row["hateful_l1_score"], row["hateful_l2_score"], row["insults_score"],
            row["sexual_l1_score"], row["sexual_l2_score"], row["physical_violence_score"],
            row["self_harm_l1_score"], row["self_harm_l2_score"], row["aom_l1_score"],
            row["aom_l2_score"], row["openai_score"],
        ])
    except Exception as e:
        print(f"Error saving chatbot data: {e}")
# --- Classifier logic ---
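# Scores are rendered in three bands: below 0.4 is treated as a pass, 0.4-0.7 as a
# warning, and 0.7 and above as a fail.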
def format_score_with_style(score_str):
    if score_str == "-":
        return '<span style="color: #9ca3af;">-</span>'
    try:
        score = float(score_str)
        percentage = int(score * 100)
        if score < 0.4:
            return f'<span style="color: #34d399; font-weight:600;">👌 {percentage}%</span>'
        elif 0.4 <= score < 0.7:
            return f'<span style="color: #fbbf24; font-weight:600;">⚠️ {percentage}%</span>'
        else:
            return f'<span style="color: #fca5a5; font-weight:600;">🚨 {percentage}%</span>'
    except (TypeError, ValueError):
        # Not a numeric score; show the raw value unchanged
        return score_str

def format_binary_score(score):
    percentage = int(score * 100)
    if score < 0.4:
        return f'<div style="background:linear-gradient(135deg, #065f46 0%, #047857 100%); color:#34d399; padding:24px 0; border-radius:20px; text-align:center; font-weight:900; border:3px solid #10b981; font-size:24px; margin:24px 0; box-shadow:0 4px 24px rgba(0,0,0,0.3);">✅ Pass ({percentage}/100)</div>'
    elif 0.4 <= score < 0.7:
        return f'<div style="background:linear-gradient(135deg, #92400e 0%, #b45309 100%); color:#fbbf24; padding:24px 0; border-radius:20px; text-align:center; font-weight:900; border:3px solid #f59e0b; font-size:24px; margin:24px 0; box-shadow:0 4px 24px rgba(0,0,0,0.3);">⚠️ Warning ({percentage}/100)</div>'
    else:
        return f'<div style="background:linear-gradient(135deg, #991b1b 0%, #b91c1c 100%); color:#fca5a5; padding:24px 0; border-radius:20px; text-align:center; font-weight:900; border:3px solid #ef4444; font-size:24px; margin:24px 0; box-shadow:0 4px 24px rgba(0,0,0,0.3);">🚨 Fail ({percentage}/100)</div>'
def analyze_text(text):
    if not text.strip():
        empty_html = '<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Enter text to analyze</div>'
        return empty_html, empty_html, "", ""
    try:
        text_id = str(uuid.uuid4())
        embeddings = get_embeddings([text])
        results = model.predict(embeddings)
        binary_score = results.get('binary', [0.0])[0]
        main_categories = ['hateful', 'insults', 'sexual', 'physical_violence', 'self_harm', 'all_other_misconduct']
        category_emojis = {
            'Hateful': '🤬',
            'Insults': '💢',
            'Sexual': '🔞',
            'Physical Violence': '⚔️',
            'Self Harm': '☹️',
            'All Other Misconduct': '🙅‍♀️'
        }
        categories_html = []
        max_scores = {}
        for category in main_categories:
            subcategories = CATEGORIES[category]
            category_name = category.replace('_', ' ').title()
            category_display = f"{category_emojis.get(category_name, '📝')} {category_name}"
            level_scores = [results.get(subcategory_key, [0.0])[0] for subcategory_key in subcategories]
            max_score = max(level_scores) if level_scores else 0.0
            max_scores[category] = max_score
            categories_html.append(f'''
            <tr>
                <td>{category_display}</td>
                <td style="text-align: center;">{format_score_with_style(f"{max_score:.4f}")}</td>
            </tr>
            ''')
        html_table = f'''
        <table style="width:100%">
            <thead>
                <tr><th>Category</th><th>Score</th></tr>
            </thead>
            <tbody>
                {''.join(categories_html)}
            </tbody>
        </table>
        '''
        # Save to Google Sheets if enabled
        if GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
            results_row = {
                "datetime": datetime.now().isoformat(),
                "text_id": text_id,
                "text": text,
                "binary_score": binary_score,
            }
            for category in main_categories:
                results_row[f"{category}_max"] = max_scores[category]
            save_results_data(results_row)
        voting_html = '<div>Help improve LionGuard2! Rate the analysis below.</div>'
        return format_binary_score(binary_score), html_table, text_id, voting_html
    except Exception as e:
        error_msg = f"Error analyzing text: {str(e)}"
        return f'<div style="color: #fca5a5;">❌ {error_msg}</div>', '', '', ''
def vote_thumbs_up(text_id):
    if text_id and GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
        save_vote_data(text_id, True)
        return '<div style="color: #34d399; font-weight:700;">🎉 Thank you!</div>'
    return '<div>Voting not available or analysis not yet run.</div>'

def vote_thumbs_down(text_id):
    if text_id and GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
        save_vote_data(text_id, False)
        return '<div style="color: #fca5a5; font-weight:700;">📝 Thanks for the feedback!</div>'
    return '<div>Voting not available or analysis not yet run.</div>'
# --- Guardrail Comparison logic (ASYNC VERSION) ---
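# The comparison tab runs three pipelines on each message: no moderation, OpenAI's
# moderation endpoint, and LionGuard 2. They are awaited concurrently with
# asyncio.gather in process_message_async below.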
async def get_openai_response_async(message, system_prompt="You are a helpful assistant."):
    """Async version of OpenAI API call"""
    try:
        response = await async_client.chat.completions.create(
            model="gpt-4.1-nano",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": message}
            ],
            max_tokens=500,
            temperature=0,
            seed=42,
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"Error: {str(e)}. Please check your OpenAI API key."

async def openai_moderation_async(message):
    """Async version of OpenAI moderation"""
    try:
        response = await async_client.moderations.create(input=message)
        return response.results[0].flagged
    except Exception as e:
        print(f"Error in OpenAI moderation: {e}")
        return False
def lionguard_2_sync(message, threshold=0.5):
    """LionGuard remains sync as it's using a local model"""
    try:
        embeddings = get_embeddings([message])
        results = model.predict(embeddings)
        binary_prob = results['binary'][0]
        return binary_prob > threshold, binary_prob
    except Exception as e:
        print(f"Error in LionGuard 2: {e}")
        return False, 0.0
async def process_no_moderation(message, history_no_mod):
    """Process message without moderation"""
    no_mod_response = await get_openai_response_async(message)
    history_no_mod.append({"role": "user", "content": message})
    history_no_mod.append({"role": "assistant", "content": no_mod_response})
    return history_no_mod

async def process_openai_moderation(message, history_openai):
    """Process message with OpenAI moderation"""
    openai_flagged = await openai_moderation_async(message)
    history_openai.append({"role": "user", "content": message})
    if openai_flagged:
        openai_response = "🚫 This message has been flagged by OpenAI moderation"
        history_openai.append({"role": "assistant", "content": openai_response})
    else:
        openai_response = await get_openai_response_async(message)
        history_openai.append({"role": "assistant", "content": openai_response})
    return history_openai

async def process_lionguard(message, history_lg):
    """Process message with LionGuard 2"""
    # Run LionGuard sync check in thread pool to not block
    loop = asyncio.get_running_loop()
    lg_flagged, lg_score = await loop.run_in_executor(None, lionguard_2_sync, message, 0.5)
    history_lg.append({"role": "user", "content": message})
    if lg_flagged:
        lg_response = "🚫 This message has been flagged by LionGuard 2"
        history_lg.append({"role": "assistant", "content": lg_response})
    else:
        lg_response = await get_openai_response_async(message)
        history_lg.append({"role": "assistant", "content": lg_response})
    return history_lg, lg_score
async def process_message_async(message, history_no_mod, history_openai, history_lg):
    """Process message concurrently across all three guardrails"""
    if not message.strip():
        return history_no_mod, history_openai, history_lg, ""
    # Run all three processes concurrently using asyncio.gather
    results = await asyncio.gather(
        process_no_moderation(message, history_no_mod),
        process_openai_moderation(message, history_openai),
        process_lionguard(message, history_lg),
        return_exceptions=True  # Continue even if one fails
    )
    # Unpack results
    history_no_mod = results[0] if not isinstance(results[0], Exception) else history_no_mod
    history_openai = results[1] if not isinstance(results[1], Exception) else history_openai
    if isinstance(results[2], Exception):
        lg_score = 0.0
    else:
        history_lg, lg_score = results[2]
    # --- Logging for chatbot worksheet (runs in background) ---
    if GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
        try:
            loop = asyncio.get_running_loop()
            # Run logging in thread pool so it doesn't block
            loop.run_in_executor(None, _log_chatbot_sync, message, lg_score)
        except Exception as e:
            print(f"Chatbot logging failed: {e}")
    return history_no_mod, history_openai, history_lg, ""
def _log_chatbot_sync(message, lg_score):
    """Sync helper for logging - runs in thread pool"""
    try:
        embeddings = get_embeddings([message])
        results = model.predict(embeddings)
        now = datetime.now().isoformat()
        text_id = str(uuid.uuid4())
        row = {
            "datetime": now,
            "text_id": text_id,
            "text": message,
            "binary_score": results.get("binary", [None])[0],
            "hateful_l1_score": results.get(CATEGORIES['hateful'][0], [None])[0],
            "hateful_l2_score": results.get(CATEGORIES['hateful'][1], [None])[0],
            "insults_score": results.get(CATEGORIES['insults'][0], [None])[0],
            "sexual_l1_score": results.get(CATEGORIES['sexual'][0], [None])[0],
            "sexual_l2_score": results.get(CATEGORIES['sexual'][1], [None])[0],
            "physical_violence_score": results.get(CATEGORIES['physical_violence'][0], [None])[0],
            "self_harm_l1_score": results.get(CATEGORIES['self_harm'][0], [None])[0],
            "self_harm_l2_score": results.get(CATEGORIES['self_harm'][1], [None])[0],
            "aom_l1_score": results.get(CATEGORIES['all_other_misconduct'][0], [None])[0],
            "aom_l2_score": results.get(CATEGORIES['all_other_misconduct'][1], [None])[0],
            "openai_score": None
        }
        try:
            # Log OpenAI's "hate" category score as a rough point of comparison.
            # category_scores is an object with per-category attributes, not a dict.
            openai_result = client.moderations.create(input=message)
            row["openai_score"] = float(openai_result.results[0].category_scores.hate)
        except Exception:
            row["openai_score"] = None
        log_chatbot_data(row)
    except Exception as e:
        print(f"Error in sync logging: {e}")
def process_message(message, history_no_mod, history_openai, history_lg):
    """Wrapper function for Gradio (converts async to sync)"""
    return asyncio.run(process_message_async(message, history_no_mod, history_openai, history_lg))

def clear_all_chats():
    return [], [], []
# ---- MAIN GRADIO UI ----
DISCLAIMER = """
<div style='background: #fbbf24; color: #1e293b; border-radius: 8px; padding: 14px; margin-bottom: 12px; font-size: 15px; font-weight:500;'>
⚠️ LionGuard 2 may make mistakes. All entries are logged (anonymised) to improve the model.
</div>
"""
with gr.Blocks(title="LionGuard 2 Demo", theme=gr.themes.Soft()) as demo:
    gr.HTML("<h1 style='text-align:center'>LionGuard 2 Demo</h1>")
    with gr.Tabs():
        with gr.Tab("Classifier"):
            gr.HTML(DISCLAIMER)
            with gr.Row():
                with gr.Column(scale=1, min_width=400):
                    text_input = gr.Textbox(
                        label="Enter text to analyze:",
                        placeholder="Type your text here...",
                        lines=8,
                        max_lines=16,
                        container=True
                    )
                    analyze_btn = gr.Button("Analyze", variant="primary")
                with gr.Column(scale=1, min_width=400):
                    binary_output = gr.HTML(
                        value='<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic; font-size:36px;">Enter text to analyze</div>'
                    )
                    category_table = gr.HTML(
                        value='<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Category scores will appear here after analysis</div>'
                    )
                    voting_feedback = gr.HTML(value="")
            current_text_id = gr.Textbox(value="", visible=False)
            with gr.Row(visible=False) as voting_buttons_row:
                thumbs_up_btn = gr.Button("👍 Looks Accurate", variant="primary")
                thumbs_down_btn = gr.Button("👎 Looks Wrong", variant="secondary")
            def analyze_and_show_voting(text):
                binary_score, category_table_val, text_id, _voting_html = analyze_text(text)
                show_vote = gr.update(visible=True) if text_id else gr.update(visible=False)
                # Clear any previous voting feedback when a new analysis comes in
                return binary_score, category_table_val, text_id, show_vote, ""

            analyze_btn.click(
                analyze_and_show_voting,
                inputs=[text_input],
                outputs=[binary_output, category_table, current_text_id, voting_buttons_row, voting_feedback]
            )
            text_input.submit(
                analyze_and_show_voting,
                inputs=[text_input],
                outputs=[binary_output, category_table, current_text_id, voting_buttons_row, voting_feedback]
            )
            thumbs_up_btn.click(vote_thumbs_up, inputs=[current_text_id], outputs=[voting_feedback])
            thumbs_down_btn.click(vote_thumbs_down, inputs=[current_text_id], outputs=[voting_feedback])
with gr.Tab("Guardrail Comparison"):
gr.HTML(DISCLAIMER)
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("#### πŸ”΅ No Moderation")
chatbot_no_mod = gr.Chatbot(height=650, label="No Moderation", show_label=False, bubble_full_width=False, type='messages')
with gr.Column(scale=1):
gr.Markdown("#### 🟠 OpenAI Moderation")
chatbot_openai = gr.Chatbot(height=650, label="OpenAI Moderation", show_label=False, bubble_full_width=False, type='messages')
with gr.Column(scale=1):
gr.Markdown("#### πŸ›‘οΈ LionGuard 2")
chatbot_lg = gr.Chatbot(height=650, label="LionGuard 2", show_label=False, bubble_full_width=False, type='messages')
gr.Markdown("##### πŸ’¬ Send Message to All Models")
with gr.Row():
message_input = gr.Textbox(
placeholder="Type your message to compare responses...",
show_label=False,
scale=4
)
send_btn = gr.Button("Send", variant="primary", scale=1)
with gr.Row():
clear_btn = gr.Button("Clear All Chats", variant="stop")
send_btn.click(
process_message,
inputs=[message_input, chatbot_no_mod, chatbot_openai, chatbot_lg],
outputs=[chatbot_no_mod, chatbot_openai, chatbot_lg, message_input]
)
message_input.submit(
process_message,
inputs=[message_input, chatbot_no_mod, chatbot_openai, chatbot_lg],
outputs=[chatbot_no_mod, chatbot_openai, chatbot_lg, message_input]
)
clear_btn.click(
clear_all_chats,
outputs=[chatbot_no_mod, chatbot_openai, chatbot_lg]
)
if __name__ == "__main__":
demo.launch()