{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Table bounding-box regression: HOG features + linear SVR\n",
    "\n",
    "Predicts one table bounding box `[x, y, width, height]` (pixels) per page image. HOG features are computed from a grayscale, resized copy of each image, and one linear `SVR` is fit per box coordinate. Ground truth comes from YOLO-format label files; the box with class id `TABLE_CLASS_ID` is the regression target.\n",
    "\n",
    "**TODO:** targets are in pixels of each *original* image while the features come from a fixed-size resize, so the model implicitly assumes all pages share one image size. If they do not, consider regressing normalised coordinates instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import cv2\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from matplotlib.patches import Rectangle\n",
    "from sklearn import svm\n",
    "from sklearn.metrics import mean_squared_error\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Machine-specific paths: adjust before running.\n",
    "DATASET_PATH = 'C:/Users/keese/term_project/data/processed/training/images'\n",
    "ANNOTATIONS_PATH = 'C:/Users/keese/term_project/data/processed/training/labels'\n",
    "DEMO_IMAGES = [\n",
    "    'C:/Users/keese/term_project/Document_layout_Detection_Yolov8/training/images/PMC2987860_00002.jpg',\n",
    "    'C:/Users/keese/term_project/Document_layout_Detection_Yolov8/training/images/PMC3033327_00002.jpg',\n",
    "    'C:/Users/keese/term_project/Document_layout_Detection_Yolov8/validation/images/PMC2639556_00006.jpg',\n",
    "    'C:/Users/keese/term_project/Document_layout_Detection_Yolov8/validation/images/PMC2683799_00002.jpg',\n",
    "]\n",
    "\n",
    "IMAGE_EXTENSIONS = ('.jpg', '.png')  # matched case-insensitively\n",
    "TABLE_CLASS_ID = 3                   # YOLO class id used as the regression target\n",
    "HOG_INPUT_SIZE = (200, 200)          # every image is resized to this before HOG\n",
    "COORD_NAMES = ['x', 'y', 'width', 'height']\n",
    "TEST_SIZE = 0.2\n",
    "SEED = 42\n",
    "IOU_THRESHOLD = 0.5                  # a predicted box counts as a hit at IoU >= this"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Feature extraction and label parsing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One default HOG descriptor, reused for every image instead of rebuilt per call.\n",
    "_HOG = cv2.HOGDescriptor()\n",
    "\n",
    "\n",
    "def hog_features(gray_img, size=HOG_INPUT_SIZE):\n",
    "    \"\"\"Resize a grayscale image to `size` and return its flattened HOG descriptor.\"\"\"\n",
    "    return _HOG.compute(cv2.resize(gray_img, size)).flatten()\n",
    "\n",
    "\n",
    "def extract_features(image_path, size=HOG_INPUT_SIZE):\n",
    "    \"\"\"Read `image_path` as grayscale and return its flattened HOG descriptor.\n",
    "\n",
    "    Because every image is resized to `size`, all descriptors have the same length.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if OpenCV cannot read the image (cv2.imread returns None).\n",
    "    \"\"\"\n",
    "    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)\n",
    "    if img is None:\n",
    "        raise FileNotFoundError(f'Could not read image: {image_path}')\n",
    "    return hog_features(img, size)\n",
    "\n",
    "\n",
    "def load_yolo_annotations(annotation_path, img_width, img_height, class_id=TABLE_CLASS_ID):\n",
    "    \"\"\"Return the first `class_id` box of a YOLO label file as pixel [x, y, width, height].\n",
    "\n",
    "    Each non-blank line is `class cx cy w h`, where the centre and size are fractions of\n",
    "    the image size. They are scaled to pixels and the centre is converted to the top-left\n",
    "    corner. Returns None if the file has no box of `class_id`.\n",
    "    \"\"\"\n",
    "    with open(annotation_path, 'r') as file:\n",
    "        for line in file:\n",
    "            parts = line.split()\n",
    "            if not parts or int(parts[0]) != class_id:\n",
    "                continue  # skip blank lines and other classes\n",
    "            x_center = float(parts[1]) * img_width\n",
    "            y_center = float(parts[2]) * img_height\n",
    "            width = float(parts[3]) * img_width\n",
    "            height = float(parts[4]) * img_height\n",
    "            return [x_center - width / 2, y_center - height / 2, width, height]\n",
    "    return None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load dataset\n",
    "\n",
    "Images without an annotation file, unreadable images, and images whose label file has no table box are skipped."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "features_list = []  # one HOG vector per kept image\n",
    "bboxes = []         # matching [x, y, width, height] table box in pixels\n",
    "n_no_table = 0\n",
    "\n",
    "# sorted() fixes the sample order, so the seeded train/test split is reproducible.\n",
    "for filename in sorted(os.listdir(DATASET_PATH)):\n",
    "    if not filename.lower().endswith(IMAGE_EXTENSIONS):\n",
    "        continue\n",
    "    image_path = os.path.join(DATASET_PATH, filename)\n",
    "    stem = os.path.splitext(filename)[0]\n",
    "    annotation_file = os.path.join(ANNOTATIONS_PATH, stem + '.txt')\n",
    "\n",
    "    if not os.path.exists(annotation_file):\n",
    "        print(f'Warning: Annotation file not found for {image_path}')\n",
    "        continue\n",
    "\n",
    "    # Decode once: the grayscale image gives both the size and the HOG input.\n",
    "    gray = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)\n",
    "    if gray is None:\n",
    "        print(f'Warning: could not read image {image_path}')\n",
    "        continue\n",
    "    img_height, img_width = gray.shape[:2]\n",
    "\n",
    "    bbox = load_yolo_annotations(annotation_file, img_width, img_height)\n",
    "    if bbox is None:\n",
    "        n_no_table += 1  # no table box: nothing to regress for this image\n",
    "        continue\n",
    "\n",
    "    features_list.append(hog_features(gray))\n",
    "    bboxes.append(bbox)\n",
    "\n",
    "print(f'Kept {len(bboxes)} images; skipped {n_no_table} without a table box.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build arrays directly from the filtered lists; no None rows can reach np.array.\n",
    "X = np.array(features_list)\n",
    "y = np.array(bboxes, dtype=float)\n",
    "\n",
    "assert len(X) > 0, 'No annotated table images found'\n",
    "assert X.ndim == 2 and y.shape == (len(X), 4)\n",
    "X.shape, y.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train one SVR per coordinate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=TEST_SIZE, random_state=SEED)\n",
    "\n",
    "# SVR fits a single target, so each box coordinate gets its own linear model.\n",
    "models = {coord: svm.SVR(kernel='linear') for coord in COORD_NAMES}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for coord_index, coord in enumerate(COORD_NAMES):\n",
    "    models[coord].fit(X_train, y_train[:, coord_index])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate\n",
    "\n",
    "`precision_score` / `recall_score` are classification metrics and raise a `ValueError` on continuous box coordinates, so boxes are scored by intersection-over-union (IoU) instead. Each test image has exactly one ground-truth table box and one predicted box, so a prediction is a true positive iff its IoU reaches `IOU_THRESHOLD`; every miss is both a false positive and a false negative, which makes precision and recall equal to the hit rate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bbox_iou(boxes_a, boxes_b):\n",
    "    \"\"\"Row-wise IoU of two (n, 4) arrays of [x, y, width, height] boxes.\n",
    "\n",
    "    Negative widths/heights (possible from an unconstrained regressor) are clamped to 0.\n",
    "    Pairs whose union is empty get IoU 0.\n",
    "    \"\"\"\n",
    "    a = np.asarray(boxes_a, dtype=float).reshape(-1, 4)\n",
    "    b = np.asarray(boxes_b, dtype=float).reshape(-1, 4)\n",
    "    a_w, a_h = np.clip(a[:, 2], 0, None), np.clip(a[:, 3], 0, None)\n",
    "    b_w, b_h = np.clip(b[:, 2], 0, None), np.clip(b[:, 3], 0, None)\n",
    "    inter_w = np.clip(np.minimum(a[:, 0] + a_w, b[:, 0] + b_w) - np.maximum(a[:, 0], b[:, 0]), 0, None)\n",
    "    inter_h = np.clip(np.minimum(a[:, 1] + a_h, b[:, 1] + b_h) - np.maximum(a[:, 1], b[:, 1]), 0, None)\n",
    "    inter = inter_w * inter_h\n",
    "    union = a_w * a_h + b_w * b_h - inter\n",
    "    return np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred = np.column_stack([models[coord].predict(X_test) for coord in COORD_NAMES])\n",
    "\n",
    "mse = mean_squared_error(y_test, y_pred)\n",
    "ious = bbox_iou(y_pred, y_test)\n",
    "hit_rate = float(np.mean(ious >= IOU_THRESHOLD))\n",
    "precision = recall = hit_rate  # one GT box and one prediction per image (see above)\n",
    "\n",
    "print(f'Mean Squared Error: {mse}')\n",
    "print(f'Mean IoU: {ious.mean():.3f}')\n",
    "print(f'Precision @ IoU>={IOU_THRESHOLD}: {precision:.3f}')\n",
    "print(f'Recall @ IoU>={IOU_THRESHOLD}: {recall:.3f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predict and visualise\n",
    "\n",
    "Note: the first two demo images come from a `training/` folder and may overlap the training set, so their boxes can look better than they would on unseen pages."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_table_bbox(image_path):\n",
    "    \"\"\"Predict the table box of one image as [x, y, width, height] in pixels.\"\"\"\n",
    "    features = extract_features(image_path).reshape(1, -1)  # a single-sample 2-D batch\n",
    "    return [models[coord].predict(features)[0] for coord in COORD_NAMES]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_predictions(image_path, predictions):\n",
    "    \"\"\"Show `image_path` with each prediction drawn as a red, labelled rectangle.\n",
    "\n",
    "    Args:\n",
    "        image_path: image to display.\n",
    "        predictions: iterable of [x, y, width, height, label], with the box in pixels\n",
    "            of the original image.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if OpenCV cannot read the image.\n",
    "    \"\"\"\n",
    "    img = cv2.imread(image_path)\n",
    "    if img is None:\n",
    "        raise FileNotFoundError(f'Could not read image: {image_path}')\n",
    "    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib expects RGB\n",
    "\n",
    "    _, ax = plt.subplots(figsize=(20, 20))\n",
    "    ax.imshow(img)\n",
    "    ax.axis('off')\n",
    "    ax.set_title(os.path.basename(image_path))\n",
    "\n",
    "    for box_x, box_y, box_w, box_h, label in predictions:\n",
    "        ax.add_patch(Rectangle((box_x, box_y), box_w, box_h,\n",
    "                               linewidth=2, edgecolor='r', facecolor='none'))\n",
    "        ax.text(box_x + box_w / 2, box_y, f'{label}', color='r', ha='center', va='bottom')\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for demo_path in DEMO_IMAGES:\n",
    "    preds = predict_table_bbox(demo_path)\n",
    "    visualize_predictions(demo_path, [[*preds, 'Table']])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}