NWTCompany / app.py
AshjanMohammed's picture
Update app.py
4c38303 verified
# -*- coding: utf-8 -*-
"""Defense QA Chatbot - Streamlit Version with Password Protection"""
import hmac
import html
import os
import pickle
import warnings

import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
# Suppress all Python warnings for the rest of the process.
warnings.filterwarnings("ignore")

# Seed torch and numpy RNGs so any stochastic behaviour is reproducible.
_RANDOM_SEED = 42
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(_RANDOM_SEED)

# ========== Page Config ==========
# Kept ahead of every other st.* call, as Streamlit has historically required.
_PAGE_CONFIG = dict(
    page_title="Mission Assistant",
    layout="wide",
    initial_sidebar_state="collapsed",
)
st.set_page_config(**_PAGE_CONFIG)
# ========== Custom CSS ==========
# Global stylesheet injected on every run: dark theme, header/login/footer
# cards, input and button styling, and hiding Streamlit's default menu,
# header and footer.
_CUSTOM_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
* {
font-family: 'Inter', sans-serif;
}
.main {
background: linear-gradient(to bottom, #0f172a 0%, #1e293b 100%);
}
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
header {visibility: hidden;}
.header-container {
background: linear-gradient(135deg, #1e40af 0%, #3b82f6 50%, #60a5fa 100%);
padding: 43px 30px;
text-align: center;
color: white;
border-radius: 16px;
margin: -60px -20px 20px -20px;
box-shadow: 0 8px 32px rgba(59, 130, 246, 0.3);
}
.header-container h1 {
margin: 0 0 12px 0;
font-size: 36px;
font-weight: 700;
}
.header-container p {
margin: 0;
font-size: 16px;
opacity: 0.95;
}
.stChatMessage {
background: rgba(30, 41, 59, 0.6);
border-radius: 12px;
border: 1px solid rgba(148, 163, 184, 0.3);
margin: 10px 0;
}
.stTextInput input, .stTextArea textarea {
background: rgba(30, 41, 59, 0.8) !important;
color: #f1f5f9 !important;
border: 2px solid rgba(148, 163, 184, 0.4) !important;
border-radius: 12px !important;
font-size: 15px !important;
}
.stTextArea textarea {
min-height: 100px !important;
}
.stButton button {
background: linear-gradient(135deg, #2563eb, #1e40af) !important;
color: white !important;
border: none !important;
border-radius: 10px !important;
font-weight: 600 !important;
padding: 12px 24px !important;
width: 100%;
}
.stButton button:hover {
background: linear-gradient(135deg, #1e40af, #1e3a8a) !important;
transform: translateY(-2px);
}
.streamlit-expanderHeader {
background: rgba(51, 65, 85, 0.9) !important;
color: #60a5fa !important;
border-radius: 12px !important;
font-weight: 600 !important;
}
.streamlit-expanderContent {
background: rgba(30, 41, 59, 0.8) !important;
border: 1px solid rgba(148, 163, 184, 0.3) !important;
border-radius: 0 0 12px 12px !important;
}
.info-box {
background: rgba(51, 65, 85, 0.5);
padding: 20px;
border-radius: 10px;
border: 1px solid rgba(148, 163, 184, 0.3);
color: #e2e8f0;
}
.info-box h3 {
color: #60a5fa;
font-size: 16px;
font-weight: 600;
margin-bottom: 12px;
border-bottom: 2px solid #3b82f6;
padding-bottom: 8px;
}
.footer-container {
text-align: center;
padding: 30px 20px;
margin-top: 50px;
margin-bottom: 20px;
color: #64748b;
border-top: 1px solid rgba(148, 163, 184, 0.3);
}
.nwtc-badge {
display: inline-block;
background: rgba(59, 130, 246, 0.1);
padding: 10px 20px;
border-radius: 8px;
margin-top: 10px;
border: 1px solid rgba(59, 130, 246, 0.3);
color: #3b82f6;
font-weight: 600;
}
.login-box {
background: rgba(30, 41, 59, 0.8);
padding: 40px;
border-radius: 16px;
border: 1px solid rgba(148, 163, 184, 0.3);
text-align: center;
max-width: 500px;
margin: 100px auto;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# ========== Configuration ==========
# Run on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Base encoder on the Hugging Face Hub; fine-tuned weights from MODEL_PATH are
# optionally loaded on top of it.
model_name = "huawei-noah/TinyBERT_General_4L_312D"
# Tokenizer truncation limit (tokens per input).
max_length = 512
# Local artifacts: fine-tuned encoder checkpoint and precomputed response embeddings.
MODEL_PATH = "tinybert_defense_aug.pt"
EMBEDDINGS_PATH = "defense_embeddings_p3.pkl"
# ========== Helper Functions ==========
def mean_pooling(last_hidden, mask):
    """Average token vectors over unmasked positions, then L2-normalize each row.

    Args:
        last_hidden: encoder output, presumably (batch, seq_len, hidden); dim 1
            is summed over as the token axis.
        mask: attention mask broadcastable after ``unsqueeze(-1)``, presumably
            (batch, seq_len) with 1 for real tokens and 0 for padding.

    Returns:
        Tensor with the token axis reduced away and unit L2 norm along dim 1.
    """
    weights = mask.unsqueeze(-1).type_as(last_hidden)
    # Clamp avoids division by zero when a row is entirely padding.
    n_tokens = weights.sum(dim=1).clamp(min=1e-6)
    pooled = (last_hidden * weights).sum(dim=1) / n_tokens
    return F.normalize(pooled, p=2, dim=1)
# ========== Chatbot Class ==========
class DefenseQAChatbot:
    """Retrieval QA bot that answers with the stored response closest to a question.

    The question is embedded with ``model`` + ``mean_pooling`` and scored against
    precomputed response embeddings by dot product. The best response is
    returned unless its score is below a rejection threshold. The stored
    embeddings are presumably produced offline with the same encoder — TODO
    confirm against the script that built the pickle.
    """

    # Shown when no stored response is similar enough (or none are loaded).
    _NO_MATCH = ("I couldn't find a reliable answer. Please try rephrasing your "
                 "question or ask about specific defense protocols and procedures.")

    def __init__(self, model, tokenizer, device, embeddings_path):
        """Load the response embeddings (and texts, if present) from a pickle.

        Args:
            model: encoder whose output exposes ``last_hidden_state``; it is
                switched to eval mode here.
            tokenizer: tokenizer matching ``model``.
            device: torch device that tokenized inputs are moved to.
            embeddings_path: pickle of a dict with key ``'embeddings'`` and an
                optional ``'responses'`` list of answer texts.
        """
        self.model = model.eval()
        self.tok = tokenizer
        self.device = device
        # NOTE: pickle.load can execute arbitrary code; only point this at a
        # trusted file bundled with the app.
        with open(embeddings_path, 'rb') as f:
            saved_data = pickle.load(f)
        # Assumed (num_responses, hidden) with L2-normalized rows, so the dot
        # product in get_response is cosine similarity — TODO confirm.
        self.response_embs = np.asarray(saved_data['embeddings'])
        if 'responses' in saved_data:
            self.responses = saved_data['responses']
        else:
            # No texts stored alongside the embeddings: use numbered placeholders.
            self.responses = [f"Defense Response #{i+1}" for i in range(len(self.response_embs))]

    def _embed_one(self, text):
        """Return the pooled, L2-normalized embedding of ``text`` as a 1-D numpy array."""
        with torch.no_grad():
            enc = self.tok([text], truncation=True, padding="longest",
                           max_length=max_length, return_tensors="pt").to(self.device)
            out = self.model(**enc)
            emb = mean_pooling(out.last_hidden_state, enc["attention_mask"])
        return emb[0].cpu().numpy()

    def get_response(self, user_prompt, top_k=5, reject=0.55):
        """Return the best-matching stored response for ``user_prompt``.

        Args:
            user_prompt: the user's question.
            top_k: accepted for backward compatibility only. A single best match
                is returned, so it does not affect the result. Values <= 0 used
                to raise IndexError.
            reject: minimum similarity needed to answer. Below it, a
                "couldn't find" message is returned instead.

        Returns:
            The matched response text; a request for a question when
            ``user_prompt`` is blank; or the no-match message when confidence is
            too low or no responses are loaded.
        """
        if not user_prompt.strip():
            return "Please ask a question about defense protocols."
        if len(self.response_embs) == 0:
            # Nothing to retrieve from; np.argmax would raise on an empty array.
            return self._NO_MATCH
        q = self._embed_one(user_prompt)
        sims = self.response_embs @ q
        # Only the single best match is used, so argmax suffices (ties resolve
        # deterministically to the lowest index).
        best = int(np.argmax(sims))
        score = float(sims[best])
        # Confidence too low: refuse rather than return an unrelated answer.
        if score < reject:
            return self._NO_MATCH
        # Return the matched answer text as-is (no score attached).
        return self.responses[best]
# ========== Password Protection ==========
# Legacy access code, kept only so existing deployments keep working.
# TODO(security): set APP_PASSWORD (e.g. as a deployment secret) and remove this
# fallback instead of keeping a credential in source control.
_DEFAULT_PASSWORD = "NWTC@2025"


def _expected_password():
    """Return the access code: $APP_PASSWORD if set and non-empty, else the legacy default."""
    return os.environ.get("APP_PASSWORD") or _DEFAULT_PASSWORD


def _render_login_form(show_error):
    """Draw the access-code screen; ``show_error`` adds the 'Access Denied' banner."""
    st.markdown("""
<div class="header-container">
<div style="display: flex; align-items: center; justify-content: center; gap: 20px;">
<h1 style="margin: 0;">Mission Assistant</h1>
</div>
<p style="margin-top: 15px; font-size: 18px;">Secure Access Required</p>
</div>
""", unsafe_allow_html=True)
    _, center, _ = st.columns([1, 2, 1])
    with center:
        st.markdown("""
<div class="login-box">
<h2 style="color: #60a5fa; margin-bottom: 20px;">🔐 Enter Access Code</h2>
<p style="color: #cbd5e1; margin-bottom: 30px;">
This system is restricted to authorized personnel only
</p>
</div>
""", unsafe_allow_html=True)
        st.text_input(
            "Password",
            type="password",
            on_change=password_entered_callback,
            key="password",
            label_visibility="collapsed",
            placeholder="Enter your access code..."
        )
        if show_error:
            st.error("❌ Access Denied - Incorrect password")
        st.markdown("""
<p style="text-align: center; color: #64748b; font-size: 13px; margin-top: 20px;">
🇸🇦 Made by NWTC
</p>
""", unsafe_allow_html=True)


def password_entered_callback():
    """on_change callback: record in session state whether the submitted code matches."""
    entered = st.session_state["password"]
    # compare_digest runs in time independent of where the values differ, so
    # response timing does not reveal how much of the code was right. Bytes are
    # compared so that non-ASCII input cannot raise TypeError.
    if hmac.compare_digest(entered.encode("utf-8"), _expected_password().encode("utf-8")):
        st.session_state["password_correct"] = True
        # Drop the entered code from session state once it has been accepted.
        del st.session_state["password"]
    else:
        st.session_state["password_correct"] = False


def check_password():
    """Return True if the user entered the correct access code this session.

    Otherwise render the login form and return False. After a failed attempt,
    the form also shows an error banner.
    """
    if st.session_state.get("password_correct"):
        return True
    # No attempt yet -> plain form; a failed attempt stored False -> add the error.
    _render_login_form(show_error="password_correct" in st.session_state)
    return False


if not check_password():
    st.stop()
# ========== Load Model (Cached) ==========
@st.cache_resource
def load_model():
    """Build the retrieval chatbot once and reuse it across reruns (st.cache_resource).

    Steps:
      1. Load the base tokenizer and encoder named by ``model_name``.
      2. If MODEL_PATH exists, overlay its ``"model_state"`` weights.
      3. Wrap the result in DefenseQAChatbot with EMBEDDINGS_PATH.

    On any failure, the error is shown in the UI and the script run is stopped.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name).to(device)
        if os.path.exists(MODEL_PATH):
            # weights_only=False lets the checkpoint hold arbitrary pickled
            # objects, presumably more than bare tensors — TODO confirm. This
            # unpickles the file, so only ship a trusted checkpoint.
            try:
                ckpt = torch.load(MODEL_PATH, map_location=device, weights_only=False)
            except TypeError:
                # Older torch releases do not accept the weights_only keyword.
                ckpt = torch.load(MODEL_PATH, map_location=device)
            model.load_state_dict(ckpt["model_state"])
        model.eval()
        chatbot = DefenseQAChatbot(
            model=model,
            tokenizer=tokenizer,
            device=device,
            embeddings_path=EMBEDDINGS_PATH
        )
        return chatbot
    except Exception as e:
        # Top-level UI boundary: surface any load failure and halt this run.
        st.error(f"Error loading model: {e}")
        st.stop()


chatbot = load_model()
# ========== Initialize Session State ==========
# Per-session chat history: a list of {"role": ..., "content": ...} dicts.
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# ========== Header ==========
_HEADER_HTML = """
<div class="header-container">
<h1 style="margin: 0;">Mission Assistant</h1>
</div>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
# ========== Info Expander ==========
# Example questions, one tab per topic. Clicking one answers it immediately,
# appends both turns to the history, and reruns the script.
_QUICK_QUESTIONS = {
    "C2/STAFF": [
        "What elements make up a C2 system in military operations?",
        "What is the purpose of coordination measures in staff operations?",
    ],
    "INTEL/RECON": [
        "What are the key components of intelligence operations according to ADP 2-0?",
        "Why are reconnaissance objectives important in planning missions?",
    ],
    "TACTICAL": [
        "What is the goal of tactical weapon positioning?",
        "What are common maneuver forms used in offensive operations?",
    ],
    "OE/ENVIRONMENT": [
        "What details should be included when describing an area-type disposition?",
        "What is terrain analysis and why is it important in military operations?",
        "How does precipitation affect operational planning?",
    ],
}

with st.expander("💡 Quick Questions & Examples"):
    st.markdown("""
<p style="color: #cbd5e1; font-size: 14px; margin-bottom: 20px; text-align: center;">
Click any question to ask instantly
</p>
""", unsafe_allow_html=True)
    category_tabs = st.tabs(list(_QUICK_QUESTIONS))
    for category, tab in zip(_QUICK_QUESTIONS, category_tabs):
        with tab:
            for idx, question in enumerate(_QUICK_QUESTIONS[category]):
                clicked = st.button(
                    f"→ {question}",
                    key=f"{category}_{idx}",
                    use_container_width=True
                )
                if not clicked:
                    continue
                history = st.session_state.messages
                history.append({"role": "user", "content": question})
                with st.chat_message("assistant"):
                    with st.spinner("Analyzing..."):
                        answer = chatbot.get_response(question)
                    st.markdown(answer)
                history.append({"role": "assistant", "content": answer})
                st.rerun()
# ========== Chat Display ==========
# Placeholder banner shown only while the chat history is empty.
_EMPTY_STATE_HTML = """
<div style="text-align: center; padding: 100px 20px; color: #64748b;">
<div style="font-size: 100px; font-weight: 700; color: #475569; margin-bottom: 20px;">
NWTC
</div>
<p style="font-size: 15px; margin: 0; color: #94a3b8; line-height: 1.8; max-width: 900px; margin: 0 auto;">
Ask only about:<br>
• C2 (command & staff)<br>
• Intel/Recon (intelligence collection, analysis & reconnaissance tasks)<br>
• OE/Environment (terrain, weather, infrastructure & civil factors)<br>
• Tactical concepts & techniques (high-level; no step-by-step actionable instructions)
</p>
</div>
"""
if not st.session_state.messages:
    st.markdown(_EMPTY_STATE_HTML, unsafe_allow_html=True)
# Container holding the rendered chat history.
chat_container = st.container()
with chat_container:
    for message in st.session_state.messages:
        # The text is interpolated into raw HTML (unsafe_allow_html=True), and it
        # comes from user input or the stored answer set. Escape it so "<", ">"
        # and "&" display literally instead of being parsed as markup. Newlines
        # become <br>: per CommonMark, a blank line would otherwise end the HTML
        # block early and leave the closing </div> tags as stray text.
        safe_content = html.escape(str(message["content"])).replace("\n", "<br>")
        if message["role"] == "user":
            st.markdown(f"""
<div style="background: linear-gradient(90deg, rgba(59, 130, 246, 0.25), rgba(37, 99, 235, 0.1));
border-left: 4px solid #3b82f6;
border-radius: 12px;
padding: 15px 20px;
margin: 12px 0;">
<div style="color: #60a5fa; font-weight: 600; margin-bottom: 8px;">👤 You</div>
<div style="color: #e2e8f0; line-height: 1.6;">{safe_content}</div>
</div>
""", unsafe_allow_html=True)
        else:
            st.markdown(f"""
<div style="background: linear-gradient(90deg, rgba(16, 185, 129, 0.25), rgba(5, 150, 105, 0.1));
border-left: 4px solid #10b981;
border-radius: 12px;
padding: 15px 20px;
margin: 12px 0;">
<div style="color: #34d399; font-weight: 600; margin-bottom: 8px;">🤖 Assistant</div>
<div style="color: #e2e8f0; line-height: 1.6;">{safe_content}</div>
</div>
""", unsafe_allow_html=True)
# ========== Chat Input & Clear Button ==========
input_col, clear_col = st.columns([5, 1])
with input_col:
    prompt = st.chat_input("Enter your inquiry here...")
with clear_col:
    if st.button("🗑️ Clear", use_container_width=True):
        st.session_state.messages = []
        st.rerun()

# Process message: answer it, record both turns, then rerun so the history
# renders them.
if prompt:
    history = st.session_state.messages
    history.append({"role": "user", "content": prompt})
    with st.chat_message("assistant"):
        with st.spinner("Analyzing..."):
            answer = chatbot.get_response(prompt)
        st.markdown(answer)
    history.append({"role": "assistant", "content": answer})
    st.rerun()

# ========== Footer ==========
_FOOTER_HTML = """
<div class="footer-container">
<p style="font-size: 14px; margin: 8px 0;">
Powered by Advanced Natural Language Processing
</p>
<div class="nwtc-badge">
🇸🇦 Made by NWTC
</div>
</div>
"""
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)