Update DINOv3_FT.ipynb
Browse files
Added code to load the adapter weights
- DINOv3_FT.ipynb +41 -28
DINOv3_FT.ipynb
CHANGED
|
@@ -1691,34 +1691,47 @@
|
|
| 1691 |
{
|
| 1692 |
"cell_type": "code",
|
| 1693 |
"source": [
|
| 1694 |
-
"import torch\n"
|
| 1695 |
-
"from PIL import Image\n"
|
| 1696 |
-
"from typing import List, Dict\n"
|
| 1697 |
-
"\n"
|
| 1698 |
-
"
|
| 1699 |
-
"
|
| 1700 |
-
"\n"
|
| 1701 |
-
"
|
| 1702 |
-
"\n"
|
| 1703 |
-
"
|
| 1704 |
-
"
|
| 1705 |
-
"\n"
|
| 1706 |
-
"
|
| 1707 |
-
"
|
| 1708 |
-
"\n"
|
| 1709 |
-
"
|
| 1710 |
-
"
|
| 1711 |
-
"
|
| 1712 |
-
"
|
| 1713 |
-
"
|
| 1714 |
-
"
|
| 1715 |
-
"
|
| 1716 |
-
"
|
| 1717 |
-
"
|
| 1718 |
-
"
|
| 1719 |
-
"
|
| 1720 |
-
"
|
| 1721 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1722 |
],
|
| 1723 |
"metadata": {
|
| 1724 |
"id": "RGZntYQEaVbA"
|
|
|
|
| 1691 |
{
|
| 1692 |
"cell_type": "code",
|
| 1693 |
"source": [
|
| 1694 |
+
"import torch\n",
|
| 1695 |
+
"from PIL import Image\n",
|
| 1696 |
+
"from typing import List, Dict\n",
|
| 1697 |
+
"\n",
|
| 1698 |
+
"# --- Load checkpoint ---\n",
|
| 1699 |
+
"ckpt_path = \"./checkpoints_dinov3_class/best_acc_0.9025.pt\"\n",
|
| 1700 |
+
"\n",
|
| 1701 |
+
"model = DinoV3Linear(backbone, hidden_size, num_classes, freeze_backbone=True).to(device)\n",
|
| 1702 |
+
"checkpoint = torch.load(ckpt_path, map_location=device)\n",
|
| 1703 |
+
"model.load_state_dict(checkpoint[\"model_state_dict\"])\n",
|
| 1704 |
+
"model.eval()\n",
|
| 1705 |
+
"\n",
|
| 1706 |
+
"# --- Prepare images ---\n",
|
| 1707 |
+
"images = [\"/content/pizza.jpg\", \"/content/spaghetti.JPG\"]\n",
|
| 1708 |
+
"\n",
|
| 1709 |
+
"pil_images = [Image.open(p).convert(\"RGB\") for p in images]\n",
|
| 1710 |
+
"inputs = image_processor(images=pil_images, return_tensors=\"pt\").to(device)\n",
|
| 1711 |
+
"\n",
|
| 1712 |
+
"# --- Inference ---\n",
|
| 1713 |
+
"with torch.no_grad():\n",
|
| 1714 |
+
" logits = model(inputs[\"pixel_values\"])\n",
|
| 1715 |
+
"\n",
|
| 1716 |
+
"# take top 2 classes\n",
|
| 1717 |
+
"probs = logits.softmax(dim=-1)\n",
|
| 1718 |
+
"scores, indices = probs.topk(2, dim=-1)\n",
|
| 1719 |
+
"\n",
|
| 1720 |
+
"# --- Format results ---\n",
|
| 1721 |
+
"results = []\n",
|
| 1722 |
+
"for path, idxs, scs in zip(images, indices, scores):\n",
|
| 1723 |
+
" preds = [\n",
|
| 1724 |
+
" {\n",
|
| 1725 |
+
" \"label_id\": int(i.item()),\n",
|
| 1726 |
+
" \"label\": id2label.get(int(i.item()), f\"class_{int(i)}\"),\n",
|
| 1727 |
+
" \"score\": float(s.item())\n",
|
| 1728 |
+
" }\n",
|
| 1729 |
+
" for i, s in zip(idxs, scs)\n",
|
| 1730 |
+
" ]\n",
|
| 1731 |
+
" results.append({\"image\": path, \"topk\": preds})\n",
|
| 1732 |
+
"\n",
|
| 1733 |
+
"print(results)\n"
|
| 1734 |
+
|
| 1735 |
],
|
| 1736 |
"metadata": {
|
| 1737 |
"id": "RGZntYQEaVbA"
|