# -*- coding: utf-8 -*- """Defense QA Chatbot - Streamlit Version with Password Protection""" import streamlit as st import pandas as pd import numpy as np import torch import torch.nn.functional as F from transformers import AutoTokenizer, AutoModel import pickle import os import warnings warnings.filterwarnings("ignore") torch.manual_seed(42) np.random.seed(42) # ========== Page Config ========== st.set_page_config( page_title="Mission Assistant", layout="wide", initial_sidebar_state="collapsed" ) # ========== Custom CSS ========== st.markdown(""" """, unsafe_allow_html=True) # ========== Configuration ========== device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model_name = "huawei-noah/TinyBERT_General_4L_312D" max_length = 512 MODEL_PATH = "tinybert_defense_aug.pt" EMBEDDINGS_PATH = "defense_embeddings_p3.pkl" # ========== Helper Functions ========== def mean_pooling(last_hidden, mask): mask = mask.unsqueeze(-1).type_as(last_hidden) summed = (last_hidden * mask).sum(dim=1) counts = mask.sum(dim=1).clamp(min=1e-6) emb = summed / counts emb = F.normalize(emb, p=2, dim=1) return emb # ========== Chatbot Class ========== class DefenseQAChatbot: def __init__(self, model, tokenizer, device, embeddings_path): self.model = model.eval() self.tok = tokenizer self.device = device with open(embeddings_path, 'rb') as f: saved_data = pickle.load(f) self.response_embs = saved_data['embeddings'] if 'responses' in saved_data: self.responses = saved_data['responses'] else: num_embeddings = len(self.response_embs) self.responses = [f"Defense Response #{i+1}" for i in range(num_embeddings)] def _embed_one(self, text): with torch.no_grad(): enc = self.tok([text], truncation=True, padding="longest", max_length=max_length, return_tensors="pt").to(self.device) out = self.model(**enc) emb = mean_pooling(out.last_hidden_state, enc["attention_mask"]) return emb[0].cpu().numpy() def get_response(self, user_prompt, top_k=5, reject=0.55): if not user_prompt.strip(): return "Please ask a question about defense protocols." q = self._embed_one(user_prompt) sims = self.response_embs @ q top = np.argpartition(-sims, min(top_k, len(sims)-1))[:top_k] top = top[np.argsort(-sims[top])] best = top[0] score = float(sims[best]) # إذا الثقة منخفضة جداً if score < reject: return "I couldn't find a reliable answer. Please try rephrasing your question or ask about specific defense protocols and procedures." # إرجاع الإجابة بدون تقييم response_text = self.responses[best] return response_text # ========== Password Protection ========== # ========== Password Protection ========== def check_password(): """Returns True if user entered correct password""" def password_entered(): """Checks whether password is correct""" # غيّر كلمة السر هنا CORRECT_PASSWORD = "NWTC@2025" if st.session_state["password"] == CORRECT_PASSWORD: st.session_state["password_correct"] = True del st.session_state["password"] else: st.session_state["password_correct"] = False if "password_correct" not in st.session_state: st.markdown("""
NWTC Logo

Mission Assistant

Secure Access Required

""", unsafe_allow_html=True) col1, col2, col3 = st.columns([1, 2, 1]) with col2: st.markdown("""

🔐 Enter Access Code

This system is restricted to authorized personnel only

""", unsafe_allow_html=True) st.text_input( "Password", type="password", on_change=password_entered, key="password", label_visibility="collapsed", placeholder="Enter your access code..." ) st.markdown("""

🇸🇦 Made by NWTC

""", unsafe_allow_html=True) return False elif not st.session_state["password_correct"]: st.markdown("""

Mission Assistant

Secure Access Required

""", unsafe_allow_html=True) col1, col2, col3 = st.columns([1, 2, 1]) with col2: st.markdown("""

🔐 Enter Access Code

This system is restricted to authorized personnel only

""", unsafe_allow_html=True) st.text_input( "Password", type="password", on_change=password_entered, key="password", label_visibility="collapsed", placeholder="Enter your access code..." ) st.error("❌ Access Denied - Incorrect password") st.markdown("""

🇸🇦 Made by NWTC

""", unsafe_allow_html=True) return False else: return True if not check_password(): st.stop() # ========== Load Model (Cached) ========== @st.cache_resource def load_model(): try: tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModel.from_pretrained(model_name).to(device) if os.path.exists(MODEL_PATH): # Try loading with weights_only=False try: ckpt = torch.load(MODEL_PATH, map_location=device, weights_only=False) except: # Fallback to old method ckpt = torch.load(MODEL_PATH, map_location=device) model.load_state_dict(ckpt["model_state"]) model.eval() chatbot = DefenseQAChatbot( model=model, tokenizer=tokenizer, device=device, embeddings_path=EMBEDDINGS_PATH ) return chatbot except Exception as e: st.error(f"Error loading model: {str(e)}") st.stop() chatbot = load_model() # ========== Initialize Session State ========== if "messages" not in st.session_state: st.session_state.messages = [] # ========== Header ========== # ========== Header ========== st.markdown("""
NWTC Logo

Mission Assistant

""", unsafe_allow_html=True) # ========== Info Expander ========== with st.expander("💡 Quick Questions & Examples"): st.markdown("""

Click any question to ask instantly

""", unsafe_allow_html=True) # Organize questions by category categories = { "C2/STAFF": [ "Explain C2 command structure in operations", "What are staff coordination procedures?", "Decision making process in command" ], "INTEL/RECON": [ "Ground maneuver force target marking conditions", "Intelligence gathering best practices", "Reconnaissance task priorities" ], "MEDICAL": [ "MEDEVAC reference documentation", "Casualty care protocols", "Medical evacuation procedures" ], "TACTICAL": [ "Area description ordering protocols", "Tactical positioning principles", "Maneuver techniques overview" ], "OE/ENVIRONMENT": [ "Terrain analysis methods", "Weather impact on operations", "Infrastructure assessment" ], "CA/PA": [ "Civil affairs protocols", "Public affairs guidelines", "Community engagement procedures" ] } # Create tabs for categories tabs = st.tabs(list(categories.keys())) for tab, (category, questions) in zip(tabs, categories.items()): with tab: for i, question in enumerate(questions): if st.button( f"→ {question}", key=f"{category}_{i}", use_container_width=True ): st.session_state.messages.append({"role": "user", "content": question}) with st.chat_message("assistant"): with st.spinner("Analyzing..."): response = chatbot.get_response(question) st.markdown(response) st.session_state.messages.append({"role": "assistant", "content": response}) st.rerun() # ========== Chat Display ========== # Empty state - only show when no messages if len(st.session_state.messages) == 0: st.markdown( """
NWTC

Ask only about:
• C2 (command & staff)
• CA/PA (civil & public affairs)
• Intel/Recon (intelligence collection, analysis & reconnaissance tasks)
• Medical (casualty care & MEDEVAC)
• OE/Environment (terrain, weather, infrastructure & civil factors)
• Tactical concepts & techniques (high-level; no step-by-step actionable instructions)

""", unsafe_allow_html=True ) # Chat container with fixed height chat_container = st.container() with chat_container: for message in st.session_state.messages: if message["role"] == "user": st.markdown(f"""
👤 You
{message["content"]}
""", unsafe_allow_html=True) else: st.markdown(f"""
🤖 Assistant
{message["content"]}
""", unsafe_allow_html=True) # ========== Chat Input & Clear Button ========== col1, col2 = st.columns([5, 1]) with col1: prompt = st.chat_input("Enter your inquiry here...") with col2: if st.button("🗑️ Clear", use_container_width=True): st.session_state.messages = [] st.rerun() # Process message if prompt: st.session_state.messages.append({"role": "user", "content": prompt}) with st.chat_message("assistant"): with st.spinner("Analyzing..."): response = chatbot.get_response(prompt) st.markdown(response) st.session_state.messages.append({"role": "assistant", "content": response}) st.rerun() # ========== Footer ========== st.markdown(""" """, unsafe_allow_html=True)