Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| """Defense QA Chatbot - Streamlit Version with Password Protection""" | |
import hmac
import html
import os
import pickle
import warnings

import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
# Suppress all Python warnings process-wide so library chatter (torch,
# transformers, streamlit) does not clutter the app log.
warnings.simplefilter("ignore")
# Pin the global RNG seeds so any stochastic step is reproducible run to run.
np.random.seed(42)
torch.manual_seed(42)
# ========== Page Config ==========
# Configured before any other st.* element is rendered by this script.
st.set_page_config(
    layout="wide",
    page_title="Mission Assistant",
    initial_sidebar_state="collapsed",
)
# ========== Custom CSS ==========
# Global stylesheet injected once per run. It loads the Inter font, darkens the
# main area, hides Streamlit's default menu/header/footer, restyles built-in
# widgets (chat messages, text inputs, buttons, expanders) and defines classes
# for the hand-written HTML further down (.header-container, .footer-container,
# .nwtc-badge, .login-box; .info-box is defined but not referenced in this file).
# NOTE(review): `.streamlit-expanderHeader` / `.streamlit-expanderContent` are
# Streamlit-internal class names, not a stable API — confirm they still match
# the installed Streamlit version.
_GLOBAL_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
* {
font-family: 'Inter', sans-serif;
}
.main {
background: linear-gradient(to bottom, #0f172a 0%, #1e293b 100%);
}
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
header {visibility: hidden;}
.header-container {
background: linear-gradient(135deg, #1e40af 0%, #3b82f6 50%, #60a5fa 100%);
padding: 43px 30px;
text-align: center;
color: white;
border-radius: 16px;
margin: -60px -20px 20px -20px;
box-shadow: 0 8px 32px rgba(59, 130, 246, 0.3);
}
.header-container h1 {
margin: 0 0 12px 0;
font-size: 36px;
font-weight: 700;
}
.header-container p {
margin: 0;
font-size: 16px;
opacity: 0.95;
}
.stChatMessage {
background: rgba(30, 41, 59, 0.6);
border-radius: 12px;
border: 1px solid rgba(148, 163, 184, 0.3);
margin: 10px 0;
}
.stTextInput input, .stTextArea textarea {
background: rgba(30, 41, 59, 0.8) !important;
color: #f1f5f9 !important;
border: 2px solid rgba(148, 163, 184, 0.4) !important;
border-radius: 12px !important;
font-size: 15px !important;
}
.stTextArea textarea {
min-height: 100px !important;
}
.stButton button {
background: linear-gradient(135deg, #2563eb, #1e40af) !important;
color: white !important;
border: none !important;
border-radius: 10px !important;
font-weight: 600 !important;
padding: 12px 24px !important;
width: 100%;
}
.stButton button:hover {
background: linear-gradient(135deg, #1e40af, #1e3a8a) !important;
transform: translateY(-2px);
}
.streamlit-expanderHeader {
background: rgba(51, 65, 85, 0.9) !important;
color: #60a5fa !important;
border-radius: 12px !important;
font-weight: 600 !important;
}
.streamlit-expanderContent {
background: rgba(30, 41, 59, 0.8) !important;
border: 1px solid rgba(148, 163, 184, 0.3) !important;
border-radius: 0 0 12px 12px !important;
}
.info-box {
background: rgba(51, 65, 85, 0.5);
padding: 20px;
border-radius: 10px;
border: 1px solid rgba(148, 163, 184, 0.3);
color: #e2e8f0;
}
.info-box h3 {
color: #60a5fa;
font-size: 16px;
font-weight: 600;
margin-bottom: 12px;
border-bottom: 2px solid #3b82f6;
padding-bottom: 8px;
}
.footer-container {
text-align: center;
padding: 30px 20px;
margin-top: 50px;
margin-bottom: 20px;
color: #64748b;
border-top: 1px solid rgba(148, 163, 184, 0.3);
}
.nwtc-badge {
display: inline-block;
background: rgba(59, 130, 246, 0.1);
padding: 10px 20px;
border-radius: 8px;
margin-top: 10px;
border: 1px solid rgba(59, 130, 246, 0.3);
color: #3b82f6;
font-weight: 600;
}
.login-box {
background: rgba(30, 41, 59, 0.8);
padding: 40px;
border-radius: 16px;
border: 1px solid rgba(148, 163, 184, 0.3);
text-align: center;
max-width: 500px;
margin: 100px auto;
}
</style>
"""
st.markdown(_GLOBAL_CSS, unsafe_allow_html=True)
# ========== Configuration ==========
# Inference device: the GPU when torch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Base encoder checkpoint name passed to from_pretrained.
model_name = "huawei-noah/TinyBERT_General_4L_312D"
# Token cap used when truncating a query before encoding it.
max_length = 512
# Local artefacts: fine-tuned encoder weights and the precomputed response bank.
MODEL_PATH = "tinybert_defense_aug.pt"
EMBEDDINGS_PATH = "defense_embeddings_p3.pkl"
# ========== Helper Functions ==========
def mean_pooling(last_hidden, mask):
    """Mean-pool token states over unmasked positions, then L2-normalise each row.

    Args:
        last_hidden: Per-token hidden states; assumed (batch, seq, hidden) since
            pooling reduces over dim 1 — TODO confirm for other callers.
        mask: Attention mask such that ``mask.unsqueeze(-1)`` broadcasts against
            ``last_hidden`` (e.g. (batch, seq)); its values act as per-token weights,
            so a 0/1 mask keeps real tokens and drops padding.

    Returns:
        Tensor of unit-length embeddings, one row per batch item. A row whose mask
        is all zeros comes back as a zero vector (the count is clamped, not zero).
    """
    weights = mask.unsqueeze(-1).type_as(last_hidden)
    # Zero out padding positions, then add up what is left along the sequence axis.
    token_totals = torch.sum(last_hidden * weights, dim=1)
    # Clamp so a fully-masked row divides by a tiny number instead of zero.
    token_counts = torch.clamp(torch.sum(weights, dim=1), min=1e-6)
    return F.normalize(token_totals / token_counts, p=2, dim=1)
# ========== Chatbot Class ==========
class DefenseQAChatbot:
    """Retrieval QA bot: embeds a question and returns the most similar stored response.

    The response bank (precomputed embeddings plus, optionally, their texts) is
    read from a pickle file when the bot is constructed.
    """

    # Fallback text returned when nothing in the bank is similar enough.
    _NO_MATCH_MESSAGE = (
        "I couldn't find a reliable answer. Please try rephrasing your question "
        "or ask about specific defense protocols and procedures."
    )

    def __init__(self, model, tokenizer, device, embeddings_path):
        """Load the response bank and keep the encoder ready for inference.

        Args:
            model: Encoder whose output exposes ``last_hidden_state``; switched to
                eval mode here.
            tokenizer: Tokenizer matching ``model``.
            device: Device the tokenized inputs are moved to.
            embeddings_path: Path to a pickle holding a dict with key
                ``'embeddings'`` and, optionally, ``'responses'``.

        Raises:
            OSError: if the pickle file cannot be opened.
            KeyError: if the pickle has no ``'embeddings'`` entry.
        """
        self.model = model.eval()
        self.tok = tokenizer
        self.device = device
        # SECURITY: pickle.load can execute arbitrary code — only ever point this
        # at a trusted, locally shipped file, never a user-supplied path.
        with open(embeddings_path, 'rb') as f:
            saved_data = pickle.load(f)
        # Normalise to an ndarray so `embs @ query` works whether the pickle stored
        # an array, a list of rows or a CPU tensor.
        self.response_embs = np.asarray(saved_data['embeddings'])
        if 'responses' in saved_data:
            self.responses = saved_data['responses']
        else:
            # No texts were saved: fall back to numbered placeholders.
            self.responses = [
                f"Defense Response #{i + 1}" for i in range(len(self.response_embs))
            ]

    def _embed_one(self, text):
        """Encode one string and mean-pool it into a unit-length 1-D numpy vector."""
        with torch.no_grad():
            enc = self.tok([text], truncation=True, padding="longest",
                           max_length=max_length, return_tensors="pt").to(self.device)
            out = self.model(**enc)
            emb = mean_pooling(out.last_hidden_state, enc["attention_mask"])
        return emb[0].cpu().numpy()

    def get_response(self, user_prompt, top_k=5, reject=0.55):
        """Return the stored response most similar to ``user_prompt``.

        Args:
            user_prompt: Free-text question; ``None`` or blank input gets a prompt
                to ask something instead.
            top_k: Kept for backward compatibility only. Just the single best match
                is ever returned, so it does not affect the result.
            reject: Minimum similarity (dot product of the query embedding with a
                stored embedding — cosine if the stored rows are unit length, TODO
                confirm) needed to answer; below it a fallback message is returned.

        Returns:
            The best-matching response text, or a fallback message.
        """
        if user_prompt is None or not user_prompt.strip():
            return "Please ask a question about defense protocols."
        if len(self.response_embs) == 0:
            # Empty bank: there is nothing to match against.
            return self._NO_MATCH_MESSAGE

        q = self._embed_one(user_prompt)
        sims = self.response_embs @ q
        best = int(np.argmax(sims))
        score = float(sims[best])

        # Confidence too low: do not guess.
        if score < reject:
            return self._NO_MATCH_MESSAGE
        # Return the matched answer text as-is (no score attached).
        return self.responses[best]
# ========== Password Protection ==========
def _render_login_screen(on_submit, show_error):
    """Draw the access-code screen (header, centred login box, password field).

    Args:
        on_submit: Callback bound as the password field's ``on_change``.
        show_error: When True, an "Access Denied" banner is shown under the field.
    """
    st.markdown("""
<div class="header-container">
<div style="display: flex; align-items: center; justify-content: center; gap: 20px;">
<h1 style="margin: 0;">Mission Assistant</h1>
</div>
<p style="margin-top: 15px; font-size: 18px;">Secure Access Required</p>
</div>
""", unsafe_allow_html=True)
    _, center, _ = st.columns([1, 2, 1])
    with center:
        st.markdown("""
<div class="login-box">
<h2 style="color: #60a5fa; margin-bottom: 20px;">🔐 Enter Access Code</h2>
<p style="color: #cbd5e1; margin-bottom: 30px;">
This system is restricted to authorized personnel only
</p>
</div>
""", unsafe_allow_html=True)
        st.text_input(
            "Password",
            type="password",
            on_change=on_submit,
            key="password",
            label_visibility="collapsed",
            placeholder="Enter your access code...",
        )
        if show_error:
            st.error("❌ Access Denied - Incorrect password")
        st.markdown("""
<p style="text-align: center; color: #64748b; font-size: 13px; margin-top: 20px;">
🇸🇦 Made by NWTC
</p>
""", unsafe_allow_html=True)


def check_password():
    """Return True once the user has entered the correct access code.

    Until then, renders the login screen (with an error banner after a wrong
    attempt) and returns False. The expected code is read from the
    ``MISSION_ASSISTANT_PASSWORD`` environment variable, falling back to the
    built-in default.
    """

    def password_entered():
        """on_change callback: compare the typed code with the expected one."""
        # SECURITY: the default below is committed to source control; set
        # MISSION_ASSISTANT_PASSWORD (e.g. as a deployment secret) to override it.
        correct_password = os.environ.get("MISSION_ASSISTANT_PASSWORD", "NWTC@2025")
        entered = st.session_state["password"]
        # Constant-time comparison; compare bytes because compare_digest rejects
        # non-ASCII str arguments.
        if hmac.compare_digest(entered.encode("utf-8"), correct_password.encode("utf-8")):
            st.session_state["password_correct"] = True
            # Do not keep the plaintext code around in session state.
            del st.session_state["password"]
        else:
            st.session_state["password_correct"] = False

    if st.session_state.get("password_correct", False):
        return True
    # The key exists (and is False) only after at least one failed attempt.
    _render_login_screen(
        password_entered,
        show_error="password_correct" in st.session_state,
    )
    return False
# Halt this script run here until the visitor has entered the access code.
_is_authenticated = check_password()
if not _is_authenticated:
    st.stop()
# ========== Load Model (Cached) ==========
@st.cache_resource(show_spinner="Loading model...")
def _build_chatbot():
    """Build the chatbot once per server process; later reruns reuse the instance.

    Any exception from loading the tokenizer, model, checkpoint or embeddings
    propagates. Streamlit does not cache a call that raises, so a failed load
    is retried on the next rerun.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    if os.path.exists(MODEL_PATH):
        # SECURITY: weights_only=False unpickles arbitrary objects — only load a
        # trusted, locally shipped checkpoint.
        try:
            ckpt = torch.load(MODEL_PATH, map_location=device, weights_only=False)
        except TypeError:
            # torch releases that predate the `weights_only` argument reject it.
            ckpt = torch.load(MODEL_PATH, map_location=device)
        model.load_state_dict(ckpt["model_state"])
    model.eval()
    return DefenseQAChatbot(
        model=model,
        tokenizer=tokenizer,
        device=device,
        embeddings_path=EMBEDDINGS_PATH,
    )


def load_model():
    """Return the shared chatbot instance; on failure show the error and stop the run."""
    try:
        return _build_chatbot()
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        st.stop()
chatbot = load_model()
# ========== Initialize Session State ==========
# Chat history is kept in session state so it survives Streamlit reruns.
if "messages" not in st.session_state:
    st.session_state["messages"] = []
# ========== Header ==========
_HEADER_HTML = """
<div class="header-container">
<h1 style="margin: 0;">Mission Assistant</h1>
</div>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
# ========== Info Expander ==========
# Canned example questions grouped by topic; clicking one asks it right away.
categories = {
    "C2/STAFF": [
        "What elements make up a C2 system in military operations?",
        "What is the purpose of coordination measures in staff operations?",
    ],
    "INTEL/RECON": [
        "What are the key components of intelligence operations according to ADP 2-0?",
        "Why are reconnaissance objectives important in planning missions?",
    ],
    "TACTICAL": [
        "What is the goal of tactical weapon positioning?",
        "What are common maneuver forms used in offensive operations?",
    ],
    "OE/ENVIRONMENT": [
        "What details should be included when describing an area-type disposition?",
        "What is terrain analysis and why is it important in military operations?",
        "How does precipitation affect operational planning?",
    ],
}

with st.expander("💡 Quick Questions & Examples"):
    st.markdown("""
<p style="color: #cbd5e1; font-size: 14px; margin-bottom: 20px; text-align: center;">
Click any question to ask instantly
</p>
""", unsafe_allow_html=True)
    # One tab per category, one full-width button per question.
    tabs = st.tabs(list(categories))
    for tab, category in zip(tabs, categories):
        with tab:
            for idx, question in enumerate(categories[category]):
                clicked = st.button(
                    f"→ {question}",
                    key=f"{category}_{idx}",
                    use_container_width=True,
                )
                if not clicked:
                    continue
                st.session_state.messages.append({"role": "user", "content": question})
                with st.chat_message("assistant"):
                    with st.spinner("Analyzing..."):
                        answer = chatbot.get_response(question)
                    st.markdown(answer)
                st.session_state.messages.append({"role": "assistant", "content": answer})
                # Rerun so the chat history below re-renders with the new exchange.
                st.rerun()
# ========== Chat Display ==========
# Empty state: shown only before the first message of the session.
if not st.session_state.messages:
    st.markdown(
        """
<div style="text-align: center; padding: 100px 20px; color: #64748b;">
<div style="font-size: 100px; font-weight: 700; color: #475569; margin-bottom: 20px;">
NWTC
</div>
<p style="font-size: 15px; margin: 0; color: #94a3b8; line-height: 1.8; max-width: 900px; margin: 0 auto;">
Ask only about:<br>
• C2 (command & staff)<br>
• Intel/Recon (intelligence collection, analysis & reconnaissance tasks)<br>
• OE/Environment (terrain, weather, infrastructure & civil factors)<br>
• Tactical concepts & techniques (high-level; no step-by-step actionable instructions)
</p>
</div>
""",
        unsafe_allow_html=True
    )

# Chat history (no height is set, so the container grows with the conversation).
chat_container = st.container()
with chat_container:
    for message in st.session_state.messages:
        # Message text is embedded in raw HTML (unsafe_allow_html=True), so escape
        # it: user input is untrusted and must not be able to inject markup, and a
        # literal "<" or "&" in any message would otherwise corrupt the bubble.
        # Newlines become <br> so a blank line cannot end the HTML block early.
        content = html.escape(str(message["content"])).replace("\n", "<br>")
        if message["role"] == "user":
            st.markdown(f"""
<div style="background: linear-gradient(90deg, rgba(59, 130, 246, 0.25), rgba(37, 99, 235, 0.1));
border-left: 4px solid #3b82f6;
border-radius: 12px;
padding: 15px 20px;
margin: 12px 0;">
<div style="color: #60a5fa; font-weight: 600; margin-bottom: 8px;">👤 You</div>
<div style="color: #e2e8f0; line-height: 1.6;">{content}</div>
</div>
""", unsafe_allow_html=True)
        else:
            st.markdown(f"""
<div style="background: linear-gradient(90deg, rgba(16, 185, 129, 0.25), rgba(5, 150, 105, 0.1));
border-left: 4px solid #10b981;
border-radius: 12px;
padding: 15px 20px;
margin: 12px 0;">
<div style="color: #34d399; font-weight: 600; margin-bottom: 8px;">🤖 Assistant</div>
<div style="color: #e2e8f0; line-height: 1.6;">{content}</div>
</div>
""", unsafe_allow_html=True)
# ========== Chat Input & Clear Button ==========
input_col, clear_col = st.columns([5, 1])
with input_col:
    user_query = st.chat_input("Enter your inquiry here...")
with clear_col:
    if st.button("🗑️ Clear", use_container_width=True):
        st.session_state.messages = []
        st.rerun()

# Answer a freshly submitted question, record both turns, then rerun so the
# history above re-renders with the new exchange.
if user_query:
    history = st.session_state.messages
    history.append({"role": "user", "content": user_query})
    with st.chat_message("assistant"):
        with st.spinner("Analyzing..."):
            reply = chatbot.get_response(user_query)
        st.markdown(reply)
    history.append({"role": "assistant", "content": reply})
    st.rerun()
# ========== Footer ==========
# Static footer with the attribution badge (styled by .footer-container / .nwtc-badge).
_FOOTER_HTML = """
<div class="footer-container">
<p style="font-size: 14px; margin: 8px 0;">
Powered by Advanced Natural Language Processing
</p>
<div class="nwtc-badge">
🇸🇦 Made by NWTC
</div>
</div>
"""
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)