atharvdeore999 commited on
Commit
5f6a162
·
verified ·
1 Parent(s): 8b9c168

Update DINOv3_FT.ipynb

Browse files

Added code to load the adapter weights

Files changed (1) hide show
  1. DINOv3_FT.ipynb +41 -28
DINOv3_FT.ipynb CHANGED
@@ -1691,34 +1691,47 @@
1691
  {
1692
  "cell_type": "code",
1693
  "source": [
1694
- "import torch\n",
1695
- "from PIL import Image\n",
1696
- "from typing import List, Dict\n",
1697
- "\n",
1698
- "\n",
1699
- "model.eval()\n",
1700
- "\n",
1701
- "images = [\"/content/pizza.jpg\", \"/content/spaghetti.JPG\"]\n",
1702
- "\n",
1703
- "pil_images = [Image.open(p).convert(\"RGB\") for p in images]\n",
1704
- "inputs = image_processor(images=pil_images, return_tensors=\"pt\").to(device)\n",
1705
- "\n",
1706
- "with torch.no_grad():\n",
1707
- " logits = model(inputs[\"pixel_values\"])\n",
1708
- "\n",
1709
- "# take top 2 classes\n",
1710
- "probs = logits.softmax(dim=-1)\n",
1711
- "scores, indices = probs.topk(2, dim=-1)\n",
1712
- "\n",
1713
- "results = []\n",
1714
- "for path, idxs, scs in zip(images, indices, scores):\n",
1715
- " preds = [\n",
1716
- " {\"label_id\": int(i.item()),\n",
1717
- " \"label\": id2label.get(int(i.item()), f\"class_{int(i)}\"),\n",
1718
- " \"score\": float(s.item())}\n",
1719
- " for i, s in zip(idxs, scs)\n",
1720
- " ]\n",
1721
- " results.append({\"image\": path, \"topk\": preds})\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
1722
  ],
1723
  "metadata": {
1724
  "id": "RGZntYQEaVbA"
 
1691
  {
1692
  "cell_type": "code",
1693
  "source": [
1694
+ "import torch\n" \
1695
+ "from PIL import Image\n" \
1696
+ "from typing import List, Dict\n" \
1697
+ "\n" \
1698
+ "# --- Load checkpoint ---\n" \
1699
+ "ckpt_path = \"./checkpoints_dinov3_class/best_acc_0.9025.pt\"\n" \
1700
+ "\n" \
1701
+ "model = DinoV3Linear(backbone, hidden_size, num_classes, freeze_backbone=True).to(device)\n" \
1702
+ "checkpoint = torch.load(ckpt_path, map_location=device)\n" \
1703
+ "model.load_state_dict(checkpoint[\"model_state_dict\"])\n" \
1704
+ "model.eval()\n" \
1705
+ "\n" \
1706
+ "# --- Prepare images ---\n" \
1707
+ "images = [\"/content/pizza.jpg\", \"/content/spaghetti.JPG\"]\n" \
1708
+ "\n" \
1709
+ "pil_images = [Image.open(p).convert(\"RGB\") for p in images]\n" \
1710
+ "inputs = image_processor(images=pil_images, return_tensors=\"pt\").to(device)\n" \
1711
+ "\n" \
1712
+ "# --- Inference ---\n" \
1713
+ "with torch.no_grad():\n" \
1714
+ " logits = model(inputs[\"pixel_values\"])\n" \
1715
+ "\n" \
1716
+ "# take top 2 classes\n" \
1717
+ "probs = logits.softmax(dim=-1)\n" \
1718
+ "scores, indices = probs.topk(2, dim=-1)\n" \
1719
+ "\n" \
1720
+ "# --- Format results ---\n" \
1721
+ "results = []\n" \
1722
+ "for path, idxs, scs in zip(images, indices, scores):\n" \
1723
+ " preds = [\n" \
1724
+ " {\n" \
1725
+ " \"label_id\": int(i.item()),\n" \
1726
+ " \"label\": id2label.get(int(i.item()), f\"class_{int(i)}\"),\n" \
1727
+ " \"score\": float(s.item())\n" \
1728
+ " }\n" \
1729
+ " for i, s in zip(idxs, scs)\n" \
1730
+ " ]\n" \
1731
+ " results.append({\"image\": path, \"topk\": preds})\n" \
1732
+ "\n" \
1733
+ "print(results)\n"
1734
+
1735
  ],
1736
  "metadata": {
1737
  "id": "RGZntYQEaVbA"