---
license: apache-2.0
datasets:
- dmedhi/garbage-image-classification-detection
language:
- en
base_model:
- microsoft/resnet-18
tags:
- image-classification
- resnet-18
- pruning
---

# ResNet18 Garbage Classifier

This is a ResNet18 model, pruned and fine-tuned to classify different types of garbage.

![image/png](https://cdn-uploads.huggingface.co/production/uploads/657f09ddcec775bfe0fb5539/sosCSmvDjEl4sUEPYSOrZ.png)

## Model Details

* **Architecture:** ResNet18
* **Task:** Image Classification

## How to Use for Inference

The following Python snippet shows how to load the model and run inference on a single image:

```python
import cv2
import torch
from torchvision import models, transforms

# Select the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the architecture with an 8-class head (no ImageNet weights needed)
model = models.resnet18(weights=None)
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, 8)

# Load the trained weights
model.load_state_dict(torch.load("resnet_18_pruned.pth", map_location=device))
model.to(device)
model.eval()

# Class names, indexed by the model's output ID
class_names = ["Garbage", "Cardboard", "Garbage", "Glass", "Metal", "Paper", "Plastic", "Trash"]

# Preprocessing for inference: resize to 224x224 and apply ImageNet normalization
def get_transform(train=False):
    if train:
        raise ValueError("This transform is for inference only; use train=False.")
    return transforms.Compose([
        transforms.ToPILImage(),  # cv2 returns a NumPy array
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

def predict_image(model, image_path, transform, class_names):
    model.eval()
    image = cv2.imread(image_path)  # BGR array, or None if the file cannot be read
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = transform(image).unsqueeze(0).to(device)  # add a batch dimension
    with torch.no_grad():
        outputs = model(image)
        _, predicted = torch.max(outputs, 1)
    print(f"Predicted Class ID: {predicted.item()}")
    print(f"Predicted Class: {class_names[predicted.item()]}")

# Example usage:
# Replace 'path/to/your/image.jpg' with the actual path
image_path = 'path/to/your/image.jpg'
transform = get_transform(train=False)
predict_image(model, image_path, transform, class_names)
```

## Intended Use

This model is intended for classifying common types of garbage.

## Limitations

The model's accuracy depends on the quality and diversity of the training data, and it may not perform well on unseen or unusual types of waste.

The model was trained on the [dmedhi/garbage-image-classification-detection](https://huggingface.co/datasets/dmedhi/garbage-image-classification-detection) dataset for 50 epochs, reaching a validation loss of 1.49. Accuracy and loss can be improved with further preprocessing of the dataset.

## Pruning

Fine-grained pruning reduced the effective model size from `42.65 MB` to just `6.45 MB` (15.13% of the original model), and fine-tuning for just 5 epochs recovered the accuracy lost to pruning, bringing it back to the level achieved during training.

The checkpoint in the Files section is still `44 MB` because fine-grained pruning does not delete weights; it only sets them to zero, and the zeros are still stored. To measure the effective size of the pruned model, count its non-zero parameters with `count_nonzero()`:

```python
num_counted_elements = 0
for param in model.parameters():
    num_counted_elements += param.count_nonzero().item()
print(f"Non-zero parameters: {num_counted_elements}")
```