| import torch | |
| import torchvision | |
| from torchvision.models import EfficientNet_B2_Weights | |
| from torch import nn | |
def create_model(num_classes=7, *, weights=EfficientNet_B2_Weights.DEFAULT, dropout=0.3):
    """Build an EfficientNet-B2 with a frozen backbone and a new classification head.

    Args:
        num_classes: Number of output logits produced by the new head.
        weights: torchvision weights enum used to initialise the backbone.
            Defaults to ``EfficientNet_B2_Weights.DEFAULT``, which torchvision
            fetches if it is not already cached. Pass ``None`` to skip that
            download, e.g. when a fine-tuned checkpoint will be loaded over the
            model anyway.
        dropout: Dropout probability applied before the final linear layer.

    Returns:
        The model. Every pre-existing parameter has ``requires_grad=False``;
        only the freshly created classifier head is trainable.
    """
    model = torchvision.models.efficientnet_b2(weights=weights)

    # Freeze the whole pretrained network so training only updates the new head.
    for param in model.parameters():
        param.requires_grad = False

    # Read the head's input width before replacing it. The new layers are
    # created after the freeze loop, so they keep requires_grad=True.
    in_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(p=dropout),
        nn.Linear(in_features, num_classes),
    )
    return model
def load_model(weights_path="model/effnetb2_dermamnist.pth", num_classes=7):
    """Build the EfficientNet-B2 classifier and load saved weights for inference.

    Args:
        weights_path: Path to a saved ``state_dict`` for the architecture
            produced by ``create_model``.
        num_classes: Number of output classes the checkpoint was trained with.
            Defaults to 7, the value previously hard-coded here.

    Returns:
        The model with checkpoint tensors mapped to CPU, switched to
        evaluation mode (so the classifier's dropout is disabled).

    Raises:
        FileNotFoundError: If ``weights_path`` does not exist.
        RuntimeError: If the checkpoint's keys or tensor shapes do not match
            the model (``load_state_dict`` runs in strict mode by default).
    """
    # NOTE: create_model's default initialises the backbone from pretrained
    # ImageNet weights; load_state_dict below overwrites them.
    model = create_model(num_classes=num_classes)

    # weights_only=True restricts unpickling to tensors and primitive
    # containers, so a tampered checkpoint cannot execute arbitrary code.
    # A plain state_dict needs nothing more.
    state_dict = torch.load(
        weights_path,
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model