AshjanMohammed commited on
Commit
99d4e94
·
verified ·
1 Parent(s): 59d8bcf

Update app.py

Browse files
Files changed (1):
  1. app.py (+27 −37)
app.py CHANGED
# ========== Load Model (Cached) ==========
@st.cache_resource
def load_model():
    """Build the tokenizer/encoder pair and wrap them in a DefenseQAChatbot.

    If a fine-tuned checkpoint exists at MODEL_PATH it is applied on top of
    the pretrained weights; any load failure is logged and the pretrained
    TinyBERT model is used as-is (best-effort, never raises).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)

    if os.path.exists(MODEL_PATH):
        try:
            checkpoint = torch.load(MODEL_PATH, map_location=device)

            if not isinstance(checkpoint, dict):
                # Checkpoint pickled the full model object — use it directly.
                print("Checkpoint is a full model object — using directly ✅")
                model = checkpoint.to(device)
            elif "model_state" in checkpoint:
                # Wrapper dict holding the state_dict under "model_state".
                model.load_state_dict(checkpoint["model_state"], strict=False)
                print("Loaded checkpoint with model_state ✅")
            else:
                # Plain state_dict saved directly.
                model.load_state_dict(checkpoint, strict=False)
                print("Loaded raw state_dict ✅")

        except Exception as e:
            # Best-effort: keep the pretrained weights on any load failure.
            print(f"⚠️ Could not load checkpoint ({e}). Using pretrained TinyBERT model.")

    model.eval()

    chatbot = DefenseQAChatbot(
        model=model,
        tokenizer=tokenizer,
        device=device,
        embeddings_path=EMBEDDINGS_PATH
    )

    return chatbot

chatbot = load_model()

# ========== Initialize Session State ==========
 
# ========== Load Model (Cached) ==========
@st.cache_resource
def load_model():
    """Load tokenizer + encoder, apply the fine-tuned checkpoint if present,
    and return a ready DefenseQAChatbot.

    On any failure the error is shown in the Streamlit UI via st.error and
    the app is halted with st.stop(), so callers always get a usable chatbot.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name).to(device)

        if os.path.exists(MODEL_PATH):
            # torch >= 2.6 defaults weights_only=True, which rejects
            # checkpoints that pickle full objects, so pass it explicitly.
            # NOTE(security): weights_only=False executes arbitrary pickled
            # code — only load checkpoints from a trusted source.
            try:
                ckpt = torch.load(MODEL_PATH, map_location=device, weights_only=False)
            except TypeError:
                # Older torch without the weights_only parameter.
                ckpt = torch.load(MODEL_PATH, map_location=device)

            # Accept all three historical checkpoint layouts; the previous
            # unconditional ckpt["model_state"] raised on two of them.
            if isinstance(ckpt, dict) and "model_state" in ckpt:
                # Wrapper dict holding the state_dict under "model_state".
                model.load_state_dict(ckpt["model_state"])
            elif isinstance(ckpt, dict):
                # Plain state_dict saved directly.
                model.load_state_dict(ckpt)
            else:
                # Full pickled model object — use it directly.
                model = ckpt.to(device)

        model.eval()

        return DefenseQAChatbot(
            model=model,
            tokenizer=tokenizer,
            device=device,
            embeddings_path=EMBEDDINGS_PATH
        )
    except Exception as e:
        # Surface the failure in the UI and halt the script run.
        st.error(f"Error loading model: {str(e)}")
        st.stop()

chatbot = load_model()

# ========== Initialize Session State ==========