AshjanMohammed commited on
Commit
36deff5
Β·
verified Β·
1 Parent(s): dcabbb0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +445 -0
app.py ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """Defense QA Chatbot - Streamlit Version with Password Protection"""
3
+
4
+ import streamlit as st
5
+ import pandas as pd
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from transformers import AutoTokenizer, AutoModel
10
+ import pickle
11
+ import os
12
+ import warnings
13
+
14
+ warnings.filterwarnings("ignore")
15
+ torch.manual_seed(42)
16
+ np.random.seed(42)
17
+
18
+ # ========== Page Config ==========
19
+ st.set_page_config(
20
+ page_title="Defense Protocol Assistant",
21
+ page_icon="πŸ›‘οΈ",
22
+ layout="wide",
23
+ initial_sidebar_state="collapsed"
24
+ )
25
+
26
+ # ========== Custom CSS ==========
27
+ st.markdown("""
28
+ <style>
29
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
30
+
31
+ * {
32
+ font-family: 'Inter', sans-serif;
33
+ }
34
+
35
+ .main {
36
+ background: linear-gradient(to bottom, #0f172a 0%, #1e293b 100%);
37
+ }
38
+
39
+ #MainMenu {visibility: hidden;}
40
+ footer {visibility: hidden;}
41
+ header {visibility: hidden;}
42
+
43
+ .header-container {
44
+ background: linear-gradient(135deg, #1e40af 0%, #3b82f6 50%, #60a5fa 100%);
45
+ padding: 40px 30px;
46
+ text-align: center;
47
+ color: white;
48
+ border-radius: 16px;
49
+ margin-bottom: 20px;
50
+ box-shadow: 0 8px 32px rgba(59, 130, 246, 0.3);
51
+ }
52
+
53
+ .header-container h1 {
54
+ margin: 0 0 12px 0;
55
+ font-size: 36px;
56
+ font-weight: 700;
57
+ }
58
+
59
+ .header-container p {
60
+ margin: 0;
61
+ font-size: 16px;
62
+ opacity: 0.95;
63
+ }
64
+
65
+ .stChatMessage {
66
+ background: rgba(30, 41, 59, 0.6);
67
+ border-radius: 12px;
68
+ border: 1px solid rgba(148, 163, 184, 0.3);
69
+ margin: 10px 0;
70
+ }
71
+
72
+ .stTextInput input, .stTextArea textarea {
73
+ background: rgba(30, 41, 59, 0.8) !important;
74
+ color: #f1f5f9 !important;
75
+ border: 2px solid rgba(148, 163, 184, 0.4) !important;
76
+ border-radius: 12px !important;
77
+ font-size: 15px !important;
78
+ }
79
+
80
+ .stTextArea textarea {
81
+ min-height: 100px !important;
82
+ }
83
+
84
+ .stButton button {
85
+ background: linear-gradient(135deg, #2563eb, #1e40af) !important;
86
+ color: white !important;
87
+ border: none !important;
88
+ border-radius: 10px !important;
89
+ font-weight: 600 !important;
90
+ padding: 12px 24px !important;
91
+ width: 100%;
92
+ }
93
+
94
+ .stButton button:hover {
95
+ background: linear-gradient(135deg, #1e40af, #1e3a8a) !important;
96
+ transform: translateY(-2px);
97
+ }
98
+
99
+ .streamlit-expanderHeader {
100
+ background: rgba(51, 65, 85, 0.9) !important;
101
+ color: #60a5fa !important;
102
+ border-radius: 12px !important;
103
+ font-weight: 600 !important;
104
+ }
105
+
106
+ .streamlit-expanderContent {
107
+ background: rgba(30, 41, 59, 0.8) !important;
108
+ border: 1px solid rgba(148, 163, 184, 0.3) !important;
109
+ border-radius: 0 0 12px 12px !important;
110
+ }
111
+
112
+ .info-box {
113
+ background: rgba(51, 65, 85, 0.5);
114
+ padding: 20px;
115
+ border-radius: 10px;
116
+ border: 1px solid rgba(148, 163, 184, 0.3);
117
+ color: #e2e8f0;
118
+ }
119
+
120
+ .info-box h3 {
121
+ color: #60a5fa;
122
+ font-size: 16px;
123
+ font-weight: 600;
124
+ margin-bottom: 12px;
125
+ border-bottom: 2px solid #3b82f6;
126
+ padding-bottom: 8px;
127
+ }
128
+
129
+ .footer-container {
130
+ text-align: center;
131
+ padding: 30px 20px;
132
+ margin-top: 30px;
133
+ color: #64748b;
134
+ border-top: 1px solid rgba(148, 163, 184, 0.3);
135
+ }
136
+
137
+ .nwtc-badge {
138
+ display: inline-block;
139
+ background: rgba(59, 130, 246, 0.1);
140
+ padding: 10px 20px;
141
+ border-radius: 8px;
142
+ margin-top: 10px;
143
+ border: 1px solid rgba(59, 130, 246, 0.3);
144
+ color: #3b82f6;
145
+ font-weight: 600;
146
+ }
147
+
148
+ .login-box {
149
+ background: rgba(30, 41, 59, 0.8);
150
+ padding: 40px;
151
+ border-radius: 16px;
152
+ border: 1px solid rgba(148, 163, 184, 0.3);
153
+ text-align: center;
154
+ max-width: 500px;
155
+ margin: 100px auto;
156
+ }
157
+ </style>
158
+ """, unsafe_allow_html=True)
159
+
160
+ # ========== Configuration ==========
161
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
162
+ model_name = "huawei-noah/TinyBERT_General_4L_312D"
163
+ max_length = 512
164
+
165
+ MODEL_PATH = "tinybert_defense_aug.pt"
166
+ EMBEDDINGS_PATH = "defense_embeddings_p3.pkl"
167
+
168
+ # ========== Helper Functions ==========
169
+ def mean_pooling(last_hidden, mask):
170
+ mask = mask.unsqueeze(-1).type_as(last_hidden)
171
+ summed = (last_hidden * mask).sum(dim=1)
172
+ counts = mask.sum(dim=1).clamp(min=1e-6)
173
+ emb = summed / counts
174
+ emb = F.normalize(emb, p=2, dim=1)
175
+ return emb
176
+
177
+ # ========== Chatbot Class ==========
178
+ class DefenseQAChatbot:
179
+ def __init__(self, model, tokenizer, device, embeddings_path):
180
+ self.model = model.eval()
181
+ self.tok = tokenizer
182
+ self.device = device
183
+
184
+ with open(embeddings_path, 'rb') as f:
185
+ saved_data = pickle.load(f)
186
+
187
+ self.response_embs = saved_data['embeddings']
188
+
189
+ if 'responses' in saved_data:
190
+ self.responses = saved_data['responses']
191
+ else:
192
+ num_embeddings = len(self.response_embs)
193
+ self.responses = [f"Defense Response #{i+1}" for i in range(num_embeddings)]
194
+
195
+ def _embed_one(self, text):
196
+ with torch.no_grad():
197
+ enc = self.tok([text], truncation=True, padding="longest",
198
+ max_length=max_length, return_tensors="pt").to(self.device)
199
+ out = self.model(**enc)
200
+ emb = mean_pooling(out.last_hidden_state, enc["attention_mask"])
201
+ return emb[0].cpu().numpy()
202
+
203
+ def get_response(self, user_prompt, top_k=5, reject=0.55):
204
+ if not user_prompt.strip():
205
+ return "Please ask a question about defense protocols."
206
+
207
+ q = self._embed_one(user_prompt)
208
+ sims = self.response_embs @ q
209
+ top = np.argpartition(-sims, min(top_k, len(sims)-1))[:top_k]
210
+ top = top[np.argsort(-sims[top])]
211
+ best = top[0]
212
+ score = float(sims[best])
213
+
214
+ if score < reject:
215
+ return f"I couldn't find a reliable answer with sufficient confidence (score: {score:.2f}). Please try rephrasing your question."
216
+
217
+ response_text = self.responses[best]
218
+
219
+ if score >= 0.80:
220
+ confidence = "🟒 High confidence"
221
+ elif score >= 0.65:
222
+ confidence = "🟑 Medium confidence"
223
+ else:
224
+ confidence = "🟠 Low confidence"
225
+
226
+ return f"{response_text}\n\n---\n*{confidence} β€’ Score: {score:.2f}*"
227
+
228
+ # ========== Password Protection ==========
229
+ def check_password():
230
+ """Returns True if user entered correct password"""
231
+
232
+ def password_entered():
233
+ """Checks whether password is correct"""
234
+ # ΨΊΩŠΩ‘Ψ± ΩƒΩ„Ω…Ψ© Ψ§Ω„Ψ³Ψ± Ω‡Ω†Ψ§
235
+ CORRECT_PASSWORD = "NWTC@2025"
236
+
237
+ if st.session_state["password"] == CORRECT_PASSWORD:
238
+ st.session_state["password_correct"] = True
239
+ del st.session_state["password"]
240
+ else:
241
+ st.session_state["password_correct"] = False
242
+
243
+ if "password_correct" not in st.session_state:
244
+ st.markdown("""
245
+ <div class="header-container">
246
+ <h1>πŸ›‘οΈ DEFENSE PROTOCOL ASSISTANT</h1>
247
+ <p style="margin-top: 15px; font-size: 18px;">Secure Access Required</p>
248
+ </div>
249
+ """, unsafe_allow_html=True)
250
+
251
+ col1, col2, col3 = st.columns([1, 2, 1])
252
+ with col2:
253
+ st.markdown("""
254
+ <div class="login-box">
255
+ <h2 style="color: #60a5fa; margin-bottom: 20px;">πŸ” Enter Access Code</h2>
256
+ <p style="color: #cbd5e1; margin-bottom: 30px;">
257
+ This system is restricted to authorized personnel only
258
+ </p>
259
+ </div>
260
+ """, unsafe_allow_html=True)
261
+
262
+ st.text_input(
263
+ "Password",
264
+ type="password",
265
+ on_change=password_entered,
266
+ key="password",
267
+ label_visibility="collapsed",
268
+ placeholder="Enter your access code..."
269
+ )
270
+
271
+ st.markdown("""
272
+ <p style="text-align: center; color: #64748b; font-size: 13px; margin-top: 20px;">
273
+ πŸ‡ΈπŸ‡¦ NWTC Defense AI Research Division
274
+ </p>
275
+ """, unsafe_allow_html=True)
276
+
277
+ return False
278
+
279
+ elif not st.session_state["password_correct"]:
280
+ st.markdown("""
281
+ <div class="header-container">
282
+ <h1>πŸ›‘οΈ DEFENSE PROTOCOL ASSISTANT</h1>
283
+ <p style="margin-top: 15px; font-size: 18px;">Secure Access Required</p>
284
+ </div>
285
+ """, unsafe_allow_html=True)
286
+
287
+ col1, col2, col3 = st.columns([1, 2, 1])
288
+ with col2:
289
+ st.markdown("""
290
+ <div class="login-box">
291
+ <h2 style="color: #60a5fa; margin-bottom: 20px;">πŸ” Enter Access Code</h2>
292
+ <p style="color: #cbd5e1; margin-bottom: 30px;">
293
+ This system is restricted to authorized personnel only
294
+ </p>
295
+ </div>
296
+ """, unsafe_allow_html=True)
297
+
298
+ st.text_input(
299
+ "Password",
300
+ type="password",
301
+ on_change=password_entered,
302
+ key="password",
303
+ label_visibility="collapsed",
304
+ placeholder="Enter your access code..."
305
+ )
306
+
307
+ st.error("❌ Access Denied - Incorrect password")
308
+
309
+ st.markdown("""
310
+ <p style="text-align: center; color: #64748b; font-size: 13px; margin-top: 20px;">
311
+ πŸ‡ΈπŸ‡¦ NWTC Defense AI Research Division
312
+ </p>
313
+ """, unsafe_allow_html=True)
314
+
315
+ return False
316
+
317
+ else:
318
+ return True
319
+
320
+ if not check_password():
321
+ st.stop()
322
+
323
+ # ========== Load Model (Cached) ==========
324
+ @st.cache_resource
325
+ def load_model():
326
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
327
+ model = AutoModel.from_pretrained(model_name).to(device)
328
+
329
+ if os.path.exists(MODEL_PATH):
330
+ ckpt = torch.load(MODEL_PATH, map_location=device)
331
+ model.load_state_dict(ckpt["model_state"])
332
+ model.eval()
333
+
334
+ chatbot = DefenseQAChatbot(
335
+ model=model,
336
+ tokenizer=tokenizer,
337
+ device=device,
338
+ embeddings_path=EMBEDDINGS_PATH
339
+ )
340
+
341
+ return chatbot
342
+
343
+ chatbot = load_model()
344
+
345
+ # ========== Initialize Session State ==========
346
+ if "messages" not in st.session_state:
347
+ st.session_state.messages = []
348
+
349
+ # ========== Header ==========
350
+ st.markdown("""
351
+ <div class="header-container">
352
+ <h1>πŸ›‘οΈ DEFENSE PROTOCOL ASSISTANT</h1>
353
+ <p>Advanced AI-Powered Military Intelligence System</p>
354
+ <p style="margin-top: 8px; font-size: 14px; opacity: 0.9;">
355
+ Fine-tuned TinyBERT Neural Network | Real-time Semantic Search
356
+ </p>
357
+ </div>
358
+ """, unsafe_allow_html=True)
359
+
360
+ # ========== Info Expander ==========
361
+ with st.expander("ℹ️ System Information & Examples"):
362
+ col1, col2 = st.columns(2)
363
+
364
+ with col1:
365
+ st.markdown(f"""
366
+ <div class="info-box">
367
+ <h3>βš™οΈ SYSTEM SPECIFICATIONS</h3>
368
+ <ul style="list-style: none; padding-left: 0;">
369
+ <li style="padding: 6px 0; border-left: 3px solid #3b82f6; padding-left: 12px; margin: 8px 0;">
370
+ <strong style="color: #93c5fd;">Knowledge Base:</strong> {len(chatbot.responses):,} Protocols
371
+ </li>
372
+ <li style="padding: 6px 0; border-left: 3px solid #3b82f6; padding-left: 12px; margin: 8px 0;">
373
+ <strong style="color: #93c5fd;">Compute:</strong> {str(device).upper()}
374
+ </li>
375
+ <li style="padding: 6px 0; border-left: 3px solid #3b82f6; padding-left: 12px; margin: 8px 0;">
376
+ <strong style="color: #93c5fd;">Architecture:</strong> TinyBERT-4L-312D
377
+ </li>
378
+ <li style="padding: 6px 0; border-left: 3px solid #3b82f6; padding-left: 12px; margin: 8px 0;">
379
+ <strong style="color: #93c5fd;">Response Time:</strong> &lt;2s Average
380
+ </li>
381
+ </ul>
382
+ </div>
383
+ """, unsafe_allow_html=True)
384
+
385
+ with col2:
386
+ st.markdown("""
387
+ <div class="info-box">
388
+ <h3>πŸ’‘ QUERY EXAMPLES</h3>
389
+ <p style="line-height: 1.8; font-size: 14px;">
390
+ β†’ Ground maneuver force target marking conditions<br>
391
+ β†’ Area description ordering protocols<br>
392
+ β†’ MEDEVAC reference documentation resources<br>
393
+ β†’ Military doctrine compliance procedures
394
+ </p>
395
+ </div>
396
+ """, unsafe_allow_html=True)
397
+
398
+ # ========== Chat Display ==========
399
+ if len(st.session_state.messages) == 0:
400
+ st.markdown("""
401
+ <div style="text-align: center; padding: 100px 20px; color: #64748b;">
402
+ <div style="font-size: 100px; font-weight: 700; color: #475569; margin-bottom: 20px;">
403
+ NWTC
404
+ </div>
405
+ <p style="font-size: 16px; margin: 0; color: #94a3b8;">
406
+ Ask your questions about defense protocols and procedures
407
+ </p>
408
+ </div>
409
+ """, unsafe_allow_html=True)
410
+ else:
411
+ for message in st.session_state.messages:
412
+ with st.chat_message(message["role"]):
413
+ st.markdown(message["content"])
414
+
415
+ # ========== Chat Input ==========
416
+ if prompt := st.chat_input("Enter your defense protocol inquiry here..."):
417
+ st.session_state.messages.append({"role": "user", "content": prompt})
418
+ with st.chat_message("user"):
419
+ st.markdown(prompt)
420
+
421
+ with st.chat_message("assistant"):
422
+ with st.spinner("Analyzing..."):
423
+ response = chatbot.get_response(prompt)
424
+ st.markdown(response)
425
+
426
+ st.session_state.messages.append({"role": "assistant", "content": response})
427
+
428
+ # ========== Clear Button ==========
429
+ col1, col2, col3 = st.columns([3, 1, 3])
430
+ with col2:
431
+ if st.button("πŸ—‘οΈ Clear Chat"):
432
+ st.session_state.messages = []
433
+ st.rerun()
434
+
435
+ # ========== Footer ==========
436
+ st.markdown("""
437
+ <div class="footer-container">
438
+ <p style="font-size: 14px; margin: 8px 0;">
439
+ Powered by Advanced Natural Language Processing
440
+ </p>
441
+ <div class="nwtc-badge">
442
+ πŸ‡ΈπŸ‡¦ NWTC Defense AI Research Division
443
+ </div>
444
+ </div>
445
+ """, unsafe_allow_html=True)