Spaces: Running
| """ | |
| app.py | |
| """ | |
| # Standard imports | |
| import json | |
| import os | |
| import sys | |
| import uuid | |
| import asyncio | |
| from datetime import datetime | |
| # Third party imports | |
| import openai | |
| import gradio as gr | |
| import gspread | |
| from google.oauth2 import service_account | |
| from transformers import AutoModel | |
| # Local imports | |
| from utils import get_embeddings | |
# --- Categories ---
# Maps each top-level moderation category to the model-output keys that score
# it. The "_l1"/"_l2" suffixes appear to be two severity levels per category
# (TODO confirm against the LionGuard 2 model card); "binary" is the single
# overall harmful/not-harmful head.
CATEGORIES = {
    "binary": ["binary"],
    "hateful": ["hateful_l1", "hateful_l2"],
    "insults": ["insults"],
    "sexual": [
        "sexual_l1",
        "sexual_l2",
    ],
    "physical_violence": ["physical_violence"],
    "self_harm": ["self_harm_l1", "self_harm_l2"],
    "all_other_misconduct": [
        "all_other_misconduct_l1",
        "all_other_misconduct_l2",
    ],
}
# --- OpenAI Setup ---
# Create both sync and async clients: the sync `client` is used by the
# thread-pool logging helper (_log_chatbot_sync), the `async_client` by the
# guardrail-comparison coroutines.
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
async_client = openai.AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# --- Model Loading ---
def load_lionguard2():
    """Fetch the LionGuard 2 classifier from the Hugging Face Hub."""
    return AutoModel.from_pretrained("govtech/lionguard-2", trust_remote_code=True)


# Single shared model instance, loaded once at startup.
model = load_lionguard2()
# --- Google Sheets Config ---
# Logging is enabled only when both env vars are present; call sites check
# `GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS` before writing.
GOOGLE_SHEET_URL = os.environ.get("GOOGLE_SHEET_URL")
GOOGLE_CREDENTIALS = os.environ.get("GCP_SERVICE_ACCOUNT")  # service-account JSON string
# Worksheet names inside the target spreadsheet.
RESULTS_SHEET_NAME = "results"
VOTES_SHEET_NAME = "votes"
CHATBOT_SHEET_NAME = "chatbot"
def get_gspread_client():
    """Build an authorized gspread client from the service-account JSON in env."""
    scopes = [
        "https://www.googleapis.com/auth/spreadsheets",
        "https://www.googleapis.com/auth/drive",
    ]
    account_info = json.loads(GOOGLE_CREDENTIALS)
    credentials = service_account.Credentials.from_service_account_info(
        account_info,
        scopes=scopes,
    )
    return gspread.authorize(credentials)
def save_results_data(row):
    """Best-effort append of one classifier-result row to the 'results' sheet.

    Any failure (missing credentials, network, bad sheet) is printed and
    swallowed so the UI never crashes on logging problems.
    """
    try:
        worksheet = (
            get_gspread_client()
            .open_by_url(GOOGLE_SHEET_URL)
            .worksheet(RESULTS_SHEET_NAME)
        )
        worksheet.append_row(list(row.values()))
    except Exception as e:
        print(f"Error saving results data: {e}")
def save_vote_data(text_id, agree):
    """Best-effort append of one thumbs-up/down vote to the 'votes' worksheet.

    Args:
        text_id: Analysis id the vote refers to.
        agree: True for thumbs-up, False for thumbs-down.
    """
    try:
        record = {
            "datetime": datetime.now().isoformat(),
            "text_id": text_id,
            "agree": agree,
        }
        spreadsheet = get_gspread_client().open_by_url(GOOGLE_SHEET_URL)
        spreadsheet.worksheet(VOTES_SHEET_NAME).append_row(list(record.values()))
    except Exception as e:
        print(f"Error saving vote data: {e}")
def log_chatbot_data(row):
    """Best-effort append of one guardrail-comparison row to the 'chatbot' sheet.

    The sheet's column order is fixed, so values are pulled from `row` by an
    explicit name list rather than relying on dict insertion order.
    """
    try:
        columns = [
            "datetime", "text_id", "text", "binary_score",
            "hateful_l1_score", "hateful_l2_score", "insults_score",
            "sexual_l1_score", "sexual_l2_score", "physical_violence_score",
            "self_harm_l1_score", "self_harm_l2_score", "aom_l1_score",
            "aom_l2_score", "openai_score",
        ]
        worksheet = (
            get_gspread_client()
            .open_by_url(GOOGLE_SHEET_URL)
            .worksheet(CHATBOT_SHEET_NAME)
        )
        worksheet.append_row([row[name] for name in columns])
    except Exception as e:
        print(f"Error saving chatbot data: {e}")
# --- Classifier logic ---
def format_score_with_style(score_str):
    """Render one category probability as a colour-coded HTML span.

    Args:
        score_str: Probability formatted as a string (e.g. "0.1234"), or "-"
            when there is no score to display.

    Returns:
        An HTML <span>: grey "-" for missing, green below 0.4, amber below
        0.7, red otherwise. Input that cannot be parsed as a float is
        returned unchanged.
    """
    if score_str == "-":
        return '<span style="color: #9ca3af;">-</span>'
    try:
        score = float(score_str)
        percentage = int(score * 100)
        if score < 0.4:
            return f'<span style="color: #34d399; font-weight:600;">π {percentage}%</span>'
        elif 0.4 <= score < 0.7:
            return f'<span style="color: #fbbf24; font-weight:600;">β οΈ {percentage}%</span>'
        else:
            return f'<span style="color: #fca5a5; font-weight:600;">π¨ {percentage}%</span>'
    except (TypeError, ValueError, OverflowError):
        # Was a bare `except:`, which also swallowed KeyboardInterrupt /
        # SystemExit and real bugs; narrowed to the parse/convert failures
        # this fallback is actually meant to handle.
        return score_str
def format_binary_score(score):
    """Render the overall binary probability as a large pass/warn/fail banner.

    Thresholds: < 0.4 pass (green), < 0.7 warning (amber), else fail (red).
    """
    percentage = int(score * 100)
    if score < 0.4:
        gradient = "linear-gradient(135deg, #065f46 0%, #047857 100%)"
        text_color, border_color, label = "#34d399", "#10b981", "β Pass"
    elif score < 0.7:
        gradient = "linear-gradient(135deg, #92400e 0%, #b45309 100%)"
        text_color, border_color, label = "#fbbf24", "#f59e0b", "β οΈ Warning"
    else:
        gradient = "linear-gradient(135deg, #991b1b 0%, #b91c1c 100%)"
        text_color, border_color, label = "#fca5a5", "#ef4444", "π¨ Fail"
    return (
        f'<div style="background:{gradient}; color:{text_color}; padding:24px 0; '
        f'border-radius:20px; text-align:center; font-weight:900; '
        f'border:3px solid {border_color}; font-size:24px; margin:24px 0; '
        f'box-shadow:0 4px 24px rgba(0,0,0,0.3);">{label} ({percentage}/100)</div>'
    )
# Emoji badge shown next to each category name in the results table.
# Hoisted to module level: the original rebuilt this dict literal on every
# iteration of the per-category loop.
_CATEGORY_EMOJIS = {
    'Hateful': 'π€¬',
    'Insults': 'π’',
    'Sexual': 'π',
    'Physical Violence': 'βοΈ',
    'Self Harm': 'βΉοΈ',
    'All Other Misconduct': 'π ββοΈ',
}


def analyze_text(text):
    """Run LionGuard 2 on `text` and build the classifier-tab outputs.

    Args:
        text: Raw user input from the textbox.

    Returns:
        Tuple of (binary banner HTML, category table HTML, text_id,
        voting-prompt HTML). Empty input yields placeholder HTML and empty
        strings; any error yields an error div and empty strings.
    """
    if not text.strip():
        empty_html = '<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Enter text to analyze</div>'
        return empty_html, empty_html, "", ""
    try:
        text_id = str(uuid.uuid4())
        embeddings = get_embeddings([text])
        results = model.predict(embeddings)
        binary_score = results.get('binary', [0.0])[0]
        main_categories = ['hateful', 'insults', 'sexual', 'physical_violence', 'self_harm', 'all_other_misconduct']
        categories_html = []
        max_scores = {}
        for category in main_categories:
            subcategories = CATEGORIES[category]
            category_name = category.replace('_', ' ').title()
            category_display = f"{_CATEGORY_EMOJIS.get(category_name, 'π')} {category_name}"
            # A category's displayed score is the max over its level keys.
            level_scores = [results.get(subcategory_key, [0.0])[0] for subcategory_key in subcategories]
            max_score = max(level_scores) if level_scores else 0.0
            max_scores[category] = max_score
            categories_html.append(f'''
            <tr>
                <td>{category_display}</td>
                <td style="text-align: center;">{format_score_with_style(f"{max_score:.4f}")}</td>
            </tr>
            ''')
        html_table = f'''
        <table style="width:100%">
            <thead>
                <tr><th>Category</th><th>Score</th></tr>
            </thead>
            <tbody>
                {''.join(categories_html)}
            </tbody>
        </table>
        '''
        # Save to Google Sheets if logging is configured.
        if GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
            results_row = {
                "datetime": datetime.now().isoformat(),
                "text_id": text_id,
                "text": text,
                "binary_score": binary_score,
            }
            for category in main_categories:
                results_row[f"{category}_max"] = max_scores[category]
            save_results_data(results_row)
        voting_html = '<div>Help improve LionGuard2! Rate the analysis below.</div>'
        return format_binary_score(binary_score), html_table, text_id, voting_html
    except Exception as e:
        error_msg = f"Error analyzing text: {str(e)}"
        return f'<div style="color: #fca5a5;">β {error_msg}</div>', '', '', ''
def vote_thumbs_up(text_id):
    """Record a positive vote for analysis `text_id`, when logging is configured."""
    if not (text_id and GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS):
        return '<div>Voting not available or analysis not yet run.</div>'
    save_vote_data(text_id, True)
    return '<div style="color: #34d399; font-weight:700;">π Thank you!</div>'
def vote_thumbs_down(text_id):
    """Record a negative vote for analysis `text_id`, when logging is configured."""
    if not (text_id and GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS):
        return '<div>Voting not available or analysis not yet run.</div>'
    save_vote_data(text_id, False)
    return '<div style="color: #fca5a5; font-weight:700;">π Thanks for the feedback!</div>'
# --- Guardrail Comparison logic (ASYNC VERSION) ---
async def get_openai_response_async(message, system_prompt="You are a helpful assistant."):
    """Get a gpt-4.1-nano chat reply for `message`.

    Returns the assistant's text, or an "Error: ..." string if the API call
    fails for any reason.
    """
    try:
        completion = await async_client.chat.completions.create(
            model="gpt-4.1-nano",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": message},
            ],
            max_tokens=500,
            temperature=0,  # deterministic-leaning output for comparison
            seed=42,
        )
        return completion.choices[0].message.content
    except Exception as e:
        return f"Error: {str(e)}. Please check your OpenAI API key."
async def openai_moderation_async(message):
    """Return True when OpenAI's moderation endpoint flags `message`.

    Fails open: any API error is printed and treated as not flagged.
    """
    try:
        moderation = await async_client.moderations.create(input=message)
        return moderation.results[0].flagged
    except Exception as e:
        print(f"Error in OpenAI moderation: {e}")
        return False
def lionguard_2_sync(message, threshold=0.5):
    """Score `message` with the local LionGuard 2 model (sync: local inference).

    Args:
        message: Text to classify.
        threshold: Probability above which (strictly) the text is flagged.

    Returns:
        (flagged, binary_probability); (False, 0.0) on any failure.
    """
    try:
        predictions = model.predict(get_embeddings([message]))
        probability = predictions['binary'][0]
        return probability > threshold, probability
    except Exception as e:
        print(f"Error in LionGuard 2: {e}")
        return False, 0.0
async def process_no_moderation(message, history_no_mod):
    """Append the user turn and an unmoderated assistant reply to the history."""
    reply = await get_openai_response_async(message)
    history_no_mod.extend([
        {"role": "user", "content": message},
        {"role": "assistant", "content": reply},
    ])
    return history_no_mod
async def process_openai_moderation(message, history_openai):
    """Append the user turn and an OpenAI-moderated assistant turn.

    Flagged messages get a block notice instead of a model reply.
    """
    flagged = await openai_moderation_async(message)
    history_openai.append({"role": "user", "content": message})
    if flagged:
        reply = "π« This message has been flagged by OpenAI moderation"
    else:
        reply = await get_openai_response_async(message)
    history_openai.append({"role": "assistant", "content": reply})
    return history_openai
async def process_lionguard(message, history_lg):
    """Append the user turn and a LionGuard-2-moderated assistant turn.

    Returns:
        (updated history, LionGuard binary probability).
    """
    # The local model call is synchronous; run it in the default executor so
    # the event loop is not blocked.
    loop = asyncio.get_event_loop()
    flagged, score = await loop.run_in_executor(None, lionguard_2_sync, message, 0.5)
    history_lg.append({"role": "user", "content": message})
    if flagged:
        reply = "π« This message has been flagged by LionGuard 2"
    else:
        reply = await get_openai_response_async(message)
    history_lg.append({"role": "assistant", "content": reply})
    return history_lg, score
async def process_message_async(message, history_no_mod, history_openai, history_lg):
    """Run the three guardrail pipelines concurrently on one user message.

    Args:
        message: Raw user input from the comparison tab.
        history_no_mod / history_openai / history_lg: Gradio 'messages'-style
            chat histories; the pipelines append to them.

    Returns:
        (history_no_mod, history_openai, history_lg, "") — the trailing empty
        string clears the input textbox.
    """
    if not message.strip():
        return history_no_mod, history_openai, history_lg, ""
    # return_exceptions=True keeps one failing guardrail from killing the rest.
    results = await asyncio.gather(
        process_no_moderation(message, history_no_mod),
        process_openai_moderation(message, history_openai),
        process_lionguard(message, history_lg),
        return_exceptions=True,
    )
    # On a pipeline failure, fall back to the (possibly partially updated)
    # history object that was passed in.
    if not isinstance(results[0], Exception):
        history_no_mod = results[0]
    if not isinstance(results[1], Exception):
        history_openai = results[1]
    lg_score = 0.0
    if isinstance(results[2], tuple):
        history_lg, lg_score = results[2]
    # --- Logging for chatbot worksheet (fire-and-forget, off the loop) ---
    if GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
        try:
            # get_running_loop() is the supported call from inside a coroutine;
            # asyncio.get_event_loop() here has been deprecated since 3.10.
            loop = asyncio.get_running_loop()
            loop.run_in_executor(None, _log_chatbot_sync, message, lg_score)
        except Exception as e:
            print(f"Chatbot logging failed: {e}")
    return history_no_mod, history_openai, history_lg, ""
| def _log_chatbot_sync(message, lg_score): | |
| """Sync helper for logging - runs in thread pool""" | |
| try: | |
| embeddings = get_embeddings([message]) | |
| results = model.predict(embeddings) | |
| now = datetime.now().isoformat() | |
| text_id = str(uuid.uuid4()) | |
| row = { | |
| "datetime": now, | |
| "text_id": text_id, | |
| "text": message, | |
| "binary_score": results.get("binary", [None])[0], | |
| "hateful_l1_score": results.get(CATEGORIES['hateful'][0], [None])[0], | |
| "hateful_l2_score": results.get(CATEGORIES['hateful'][1], [None])[0], | |
| "insults_score": results.get(CATEGORIES['insults'][0], [None])[0], | |
| "sexual_l1_score": results.get(CATEGORIES['sexual'][0], [None])[0], | |
| "sexual_l2_score": results.get(CATEGORIES['sexual'][1], [None])[0], | |
| "physical_violence_score": results.get(CATEGORIES['physical_violence'][0], [None])[0], | |
| "self_harm_l1_score": results.get(CATEGORIES['self_harm'][0], [None])[0], | |
| "self_harm_l2_score": results.get(CATEGORIES['self_harm'][1], [None])[0], | |
| "aom_l1_score": results.get(CATEGORIES['all_other_misconduct'][0], [None])[0], | |
| "aom_l2_score": results.get(CATEGORIES['all_other_misconduct'][1], [None])[0], | |
| "openai_score": None | |
| } | |
| try: | |
| openai_result = client.moderations.create(input=message) | |
| row["openai_score"] = float(openai_result.results[0].category_scores.get("hate", 0.0)) | |
| except Exception: | |
| row["openai_score"] = None | |
| log_chatbot_data(row) | |
| except Exception as e: | |
| print(f"Error in sync logging: {e}") | |
def process_message(message, history_no_mod, history_openai, history_lg):
    """Sync Gradio entry point: drive the async comparison pipeline to completion."""
    coro = process_message_async(message, history_no_mod, history_openai, history_lg)
    return asyncio.run(coro)
def clear_all_chats():
    """Reset all three comparison chat histories to empty lists."""
    return [], [], []
# ---- MAIN GRADIO UI ----
# Shared warning banner rendered at the top of both tabs.
DISCLAIMER = """
<div style='background: #fbbf24; color: #1e293b; border-radius: 8px; padding: 14px; margin-bottom: 12px; font-size: 15px; font-weight:500;'>
β οΈ LionGuard 2 may make mistakes. All entries are logged (anonymised) to improve the model.
</div>
"""

with gr.Blocks(title="LionGuard 2 Demo", theme=gr.themes.Soft()) as demo:
    gr.HTML("<h1 style='text-align:center'>LionGuard 2 Demo</h1>")
    with gr.Tabs():
        # --- Tab 1: single-text classifier with per-category scores and voting ---
        with gr.Tab("Classifier"):
            gr.HTML(DISCLAIMER)
            with gr.Row():
                with gr.Column(scale=1, min_width=400):
                    text_input = gr.Textbox(
                        label="Enter text to analyze:",
                        placeholder="Type your text here...",
                        lines=8,
                        max_lines=16,
                        container=True
                    )
                    analyze_btn = gr.Button("Analyze", variant="primary")
                with gr.Column(scale=1, min_width=400):
                    # Overall pass/warn/fail banner (filled by analyze_text).
                    binary_output = gr.HTML(
                        value='<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic; font-size:36px;">Enter text to analyze</div>'
                    )
                    # Per-category score table.
                    category_table = gr.HTML(
                        value='<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Category scores will appear here after analysis</div>'
                    )
                    voting_feedback = gr.HTML(value="")
                    # Hidden state: id of the most recent analysis, used to
                    # attribute thumbs-up/down votes.
                    current_text_id = gr.Textbox(value="", visible=False)
                    # Voting buttons stay hidden until an analysis has run.
                    with gr.Row(visible=False) as voting_buttons_row:
                        thumbs_up_btn = gr.Button("π Looks Accurate", variant="primary")
                        thumbs_down_btn = gr.Button("π Looks Wrong", variant="secondary")

            def analyze_and_show_voting(text):
                """Run the classifier, then reveal the voting row only when an
                analysis id was produced (i.e. the analysis succeeded)."""
                binary_score, category_table_val, text_id, voting_html = analyze_text(text)
                show_vote = gr.update(visible=True) if text_id else gr.update(visible=False)
                # Last two values clear the feedback HTML (bound twice in outputs).
                return binary_score, category_table_val, text_id, show_vote, "", ""

            # Both the button and Enter-in-textbox trigger the same handler.
            analyze_btn.click(
                analyze_and_show_voting,
                inputs=[text_input],
                outputs=[binary_output, category_table, current_text_id, voting_buttons_row, voting_feedback, voting_feedback]
            )
            text_input.submit(
                analyze_and_show_voting,
                inputs=[text_input],
                outputs=[binary_output, category_table, current_text_id, voting_buttons_row, voting_feedback, voting_feedback]
            )
            thumbs_up_btn.click(vote_thumbs_up, inputs=[current_text_id], outputs=[voting_feedback])
            thumbs_down_btn.click(vote_thumbs_down, inputs=[current_text_id], outputs=[voting_feedback])

        # --- Tab 2: side-by-side guardrail comparison (three chatbots) ---
        with gr.Tab("Guardrail Comparison"):
            gr.HTML(DISCLAIMER)
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("#### π΅ No Moderation")
                    chatbot_no_mod = gr.Chatbot(height=650, label="No Moderation", show_label=False, bubble_full_width=False, type='messages')
                with gr.Column(scale=1):
                    gr.Markdown("#### π OpenAI Moderation")
                    chatbot_openai = gr.Chatbot(height=650, label="OpenAI Moderation", show_label=False, bubble_full_width=False, type='messages')
                with gr.Column(scale=1):
                    gr.Markdown("#### π‘οΈ LionGuard 2")
                    chatbot_lg = gr.Chatbot(height=650, label="LionGuard 2", show_label=False, bubble_full_width=False, type='messages')
            gr.Markdown("##### π¬ Send Message to All Models")
            with gr.Row():
                message_input = gr.Textbox(
                    placeholder="Type your message to compare responses...",
                    show_label=False,
                    scale=4
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)
            with gr.Row():
                clear_btn = gr.Button("Clear All Chats", variant="stop")
            # One message fans out to all three pipelines; the final output
            # slot clears the input textbox (process_message returns "").
            send_btn.click(
                process_message,
                inputs=[message_input, chatbot_no_mod, chatbot_openai, chatbot_lg],
                outputs=[chatbot_no_mod, chatbot_openai, chatbot_lg, message_input]
            )
            message_input.submit(
                process_message,
                inputs=[message_input, chatbot_no_mod, chatbot_openai, chatbot_lg],
                outputs=[chatbot_no_mod, chatbot_openai, chatbot_lg, message_input]
            )
            clear_btn.click(
                clear_all_chats,
                outputs=[chatbot_no_mod, chatbot_openai, chatbot_lg]
            )

if __name__ == "__main__":
    demo.launch()