Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from torchvision.models import resnet18 | |
| from datasets import load_dataset | |
| from huggingface_hub import hf_hub_download | |
| import numpy as np | |
| import random | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import io | |
| from torch.utils.data import DataLoader | |
| import base64 | |
# Model architecture definition
class ResNet18_Dropout(nn.Module):
    """ResNet-18 with a configurable input stem and a dropout + linear head.

    The backbone is built without pretrained weights. Its first convolution is
    replaced by one accepting ``in_channels`` inputs, and its final fully
    connected layer by ``Sequential(Dropout(dropout_rate), Linear(..., num_classes))``.

    The submodule names (``model``, ``model.conv1``, ``model.fc.0``/``.1``)
    determine the state_dict keys a checkpoint is loaded against, so they
    must stay stable.
    """

    def __init__(self, in_channels, num_classes, dropout_rate=0.3):
        super().__init__()
        self.model = resnet18(weights=None)
        backbone = self.model

        # New stem: same geometry as the stock conv1, different input width.
        backbone.conv1 = nn.Conv2d(
            in_channels,
            64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )

        head_width = backbone.fc.in_features
        backbone.fc = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(head_width, num_classes),
        )

    def forward(self, x):
        """Return raw class scores (no softmax) for the batch ``x``.

        ``x`` is presumably (N, in_channels, H, W) -- confirm against callers.
        """
        logits = self.model(x)
        return logits
def transform_multispectral_map(example, augment=True):
    """Normalize one EuroSAT multispectral sample and convert it to tensors.

    Used as the function passed to ``datasets.Dataset.map`` in this file.

    Args:
        example: mapping with ``"image"`` (array-like, expected shape
            (H, W, 13)) and ``"label"`` (integer class id).
        augment: if True (the default) apply random horizontal/vertical flips
            and, with probability 0.5, a random 90/180/270-degree rotation,
            all drawn from the ``random`` module. Pass False for deterministic
            preprocessing, e.g. for evaluation data.

    Returns:
        dict with ``"image"``, a float32 tensor of shape (13, H, W) with values
        in [0, 1], and ``"label"``, an int64 scalar tensor. A 90/270-degree
        rotation swaps H and W for non-square inputs.

    Raises:
        ValueError: if the image is not 3-D with 13 channels last.
    """
    image = np.array(example["image"], dtype=np.float32)
    if image.ndim != 3 or image.shape[2] != 13:
        raise ValueError(f"Expected shape (H, W, 13), got {image.shape}")

    # Map raw values 0..2750 onto 0..1; anything brighter saturates at 1.
    image = np.clip(image / 2750.0, 0, 1)

    if augment:
        # Draw order (h-flip, v-flip, rotate?, k) is fixed so a seeded
        # `random` module reproduces the same augmentation sequence.
        if random.random() < 0.5:
            image = np.flip(image, axis=1).copy()  # horizontal flip
        if random.random() < 0.5:
            image = np.flip(image, axis=0).copy()  # vertical flip
        if random.random() < 0.5:
            k = random.choice([1, 2, 3])
            image = np.rot90(image, k=k, axes=(0, 1)).copy()

    # HWC -> CHW, the channel-first layout torch convolutions consume.
    image = image.transpose(2, 0, 1)
    return {
        "image": torch.tensor(image, dtype=torch.float32),
        "label": torch.tensor(example["label"], dtype=torch.long),
    }
# RGB conversion functions
def load_rgb_from_multispectral_sample(numpy_array):
    """Build an 8-bit RGB preview from a raw 13-band multispectral array.

    Bands 4 (red), 3 (green) and 2 (blue) -- indices 3, 2, 1 -- are each
    mapped linearly from raw values 0..2750 to 0..255, clipped to that range
    and truncated to uint8. Raw 0 maps to 0, so the output range is 0..255.

    Args:
        numpy_array: ``np.ndarray`` of shape (H, W, 13) holding raw band values.

    Returns:
        ``np.ndarray`` of shape (H, W, 3), dtype uint8.

    Raises:
        TypeError: if the input is not a NumPy array.
        ValueError: if the input is not 3-D or does not have 13 channels.
    """
    if not isinstance(numpy_array, np.ndarray):
        raise TypeError("Input must be a NumPy array")
    # Checked before the channel count so that 0-d and 2-D inputs raise a
    # clear ValueError instead of an IndexError during band extraction.
    if numpy_array.ndim != 3:
        raise ValueError(
            f"Input array must be 3-D (H, W, 13), but got shape {numpy_array.shape}"
        )
    if numpy_array.shape[-1] != 13:
        raise ValueError(f"Input array must have 13 channels, but got {numpy_array.shape[-1]}")

    def scale_band(band):
        # Linear 0..2750 -> 0..255, clipped; astype truncates toward zero.
        band = np.clip((band / 2750) * 255, 0, 255)
        return band.astype(np.uint8)

    bands = [3, 2, 1]  # bands 4-3-2 (R, G, B) as 0-based indices
    rgb = np.stack([scale_band(numpy_array[:, :, b]) for b in bands], axis=-1)
    return rgb
def load_rgb_from_transformed_tensor(tensor_image):
    """Build an 8-bit RGB preview from a normalized 13-band (C, H, W) tensor.

    Expects values already scaled to roughly [0, 1], as produced by
    ``transform_multispectral_map``. Bands 4-3-2 (indices 3, 2, 1) are
    multiplied by 255, clipped to 0..255 and truncated to uint8.

    Args:
        tensor_image: ``torch.Tensor`` of shape (13, H, W). It may require grad
            or live on a non-CPU device; it is detached and copied to CPU first.

    Returns:
        ``np.ndarray`` of shape (H, W, 3), dtype uint8.

    Raises:
        TypeError: if the input is not a ``torch.Tensor``.
        ValueError: if the tensor is not 3-D or does not have 13 channels.
    """
    if not isinstance(tensor_image, torch.Tensor):
        raise TypeError("Input must be a torch.Tensor")
    if tensor_image.ndim != 3:
        raise ValueError(
            f"Expected a 3-D (C, H, W) tensor, got shape {tuple(tensor_image.shape)}"
        )
    if tensor_image.shape[0] != 13:
        raise ValueError(f"Expected 13 channels, got {tensor_image.shape[0]}")

    # .numpy() refuses tensors that require grad or are not on the CPU.
    np_image = tensor_image.detach().cpu().numpy()
    np_image = np.transpose(np_image, (1, 2, 0))  # (C, H, W) -> (H, W, 13)

    def scale_band(band):
        # [0, 1] -> [0, 255], clipped; astype truncates toward zero.
        band = np.clip((band * 255), 0, 255)
        return band.astype(np.uint8)

    bands = [3, 2, 1]  # bands 4-3-2 (R, G, B) as 0-based indices
    rgb = np.stack([scale_band(np_image[:, :, b]) for b in bands], axis=-1)  # (H, W, 3)
    return rgb
# Module-level state shared by the loader and the prediction callback.
# Everything stays None until load_model_and_data() fills it in.
model = dataset = None  # ResNet18_Dropout instance / object returned by load_dataset
label_names = label2id = id2label = None  # class-name list, name->index, index->name
def load_model_and_data():
    """Load the EuroSAT dataset and the trained classifier into module globals.

    On success sets ``dataset``, ``label_names``, ``label2id``, ``id2label`` and
    ``model`` (in eval mode, on CPU) and returns True. On any failure it prints
    the error, leaves those globals unchanged and returns False.
    """
    global model, dataset, label_names, label2id, id2label
    try:
        print("Loading dataset...")
        ds = load_dataset("blanchon/EuroSAT_MSI", cache_dir="./hf_cache", streaming=False)
        # NOTE(review): transform_multispectral_map applies random flips and
        # rotations by default, so the test images shown are augmented.
        ds["test"] = ds["test"].map(transform_multispectral_map)
        ds["test"].set_format(type="torch", columns=["image", "label"])

        # Class names come from the train split, which has not been passed
        # through map().
        names = ds["train"].features["label"].names
        name_to_id = {name: i for i, name in enumerate(names)}
        id_to_name = {i: name for name, i in name_to_id.items()}

        print("Loading model...")
        model_path = hf_hub_download(
            repo_id="Rhodham96/Resnet18DropoutSentinel", filename="pytorch_model.bin"
        )
        net = ResNet18_Dropout(in_channels=13, num_classes=len(names))
        # NOTE(security): torch.load unpickles the downloaded file, so a
        # tampered checkpoint could execute code. Consider weights_only=True
        # if the installed torch version supports it.
        net.load_state_dict(torch.load(model_path, map_location="cpu"))
        net.eval()
    except Exception as e:
        # Top-level boundary: report and let the caller show an error UI.
        print(f"Error loading model or dataset: {str(e)}")
        return False

    # Publish only after everything succeeded, so a failure part-way never
    # leaves the app with a dataset but no model (or vice versa).
    dataset, label_names, label2id, id2label, model = ds, names, name_to_id, id_to_name, net
    print("Model and dataset loaded successfully!")
    print(f"Classes: {label_names}")
    return True
def _sample_test_predictions(num_images):
    """Classify one random batch of up to ``num_images`` test samples; return CPU (images, labels, preds)."""
    loader = DataLoader(dataset["test"], batch_size=num_images, shuffle=True)
    # The first batch of a shuffled loader is a uniform random sample, so one
    # forward pass over exactly the images to display is enough.
    batch = next(iter(loader), None)
    if batch is None:
        raise ValueError("The test split is empty")
    images, labels = batch["image"], batch["label"]
    model.eval()
    with torch.no_grad():
        preds = model(images).argmax(dim=1)
    return images.cpu(), labels.cpu(), preds.cpu()


def _render_prediction_grid(images, labels, preds, max_cols=5):
    """Plot each sample as an RGB preview titled with true (T) / predicted (P) class; return a PIL image."""
    count = len(images)
    ncols = min(max_cols, count)
    nrows = (count + ncols - 1) // ncols
    # squeeze=False keeps `axes` 2-D even for a single row/column.
    fig, axes = plt.subplots(nrows, ncols, figsize=(3 * ncols, 3 * nrows), squeeze=False)
    try:
        for ax in axes.flat:
            ax.axis("off")  # also hides any unused trailing cells
        for ax, image, true_id, pred_id in zip(axes.flat, images, labels, preds):
            true_label = id2label[true_id.item()]
            pred_label = id2label[pred_id.item()]
            ax.imshow(load_rgb_from_transformed_tensor(image))
            color = "green" if pred_label == true_label else "red"
            ax.set_title(f"T: {true_label}\nP: {pred_label}", color=color)
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=150, bbox_inches="tight")
    finally:
        # Close this specific figure (not merely the "current" one), even on
        # error, so repeated requests do not accumulate open figures.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def predict_images(num_images=10):
    """Classify ``num_images`` random test images and render the results.

    Returns a ``(PIL.Image, str)`` pair matching the interface's two outputs
    (an image and a text box). The pair holds a grid of RGB previews and a
    summary with the accuracy on the shown samples plus the class list. When
    the model or dataset is not loaded, or prediction fails, it returns a gray
    placeholder image and an explanatory message instead.
    """
    if model is None or dataset is None:
        # Still return two values: the interface declares an image output and
        # a text output, and a bare string would not fill both.
        placeholder = Image.new('RGB', (800, 600), color='lightgray')
        return placeholder, "Model or dataset not loaded. Please wait for initialization."

    try:
        images, labels, preds = _sample_test_predictions(num_images)
        result_image = _render_prediction_grid(images, labels, preds)

        correct_predictions = (preds == labels).sum().item()
        total = len(labels)
        accuracy = correct_predictions / total * 100
        summary = (
            f"Accuracy: {correct_predictions}/{total} ({accuracy:.1f}%)\n"
            f"Classes: {', '.join(label_names)}"
        )
        return result_image, summary
    except Exception as e:
        # UI boundary: log and show the error rather than crashing the request.
        error_msg = f"Error during prediction: {str(e)}"
        print(error_msg)
        placeholder = Image.new('RGB', (800, 600), color='lightgray')
        return placeholder, error_msg
# Markdown shown under the app title. Blank lines separate paragraphs so the
# "Dataset"/"Model"/"Click" lines are not merged into the last bullet item.
_APP_DESCRIPTION = """
This app classifies satellite images from the EuroSAT dataset using a trained ResNet18 model.

**How it works:**

- Loads 10 random satellite images from the test set
- Each image has 13 spectral bands, converted to RGB for display
- Shows true labels vs predicted labels
- Green titles = correct predictions, Red titles = incorrect predictions

**Dataset:** EuroSAT with 13 multispectral bands

**Model:** ResNet18 with dropout, trained on 13-channel input

Click "Generate" to process 10 new random images!
"""


def create_interface():
    """Load the model and dataset, then build the Gradio app.

    Returns a ``gr.Interface`` wired to ``predict_images``. If loading fails,
    it instead returns a fallback interface whose only action reports the
    failure with a red placeholder image.
    """
    init_success = load_model_and_data()

    if not init_success:
        def error_function():
            placeholder = Image.new('RGB', (800, 600), color='red')
            return placeholder, "Failed to load model or dataset. Please check the logs."

        return gr.Interface(
            fn=error_function,
            inputs=[],
            outputs=[
                gr.Image(type="pil", label="Results"),
                gr.Textbox(label="Summary"),
            ],
            title="🛰️ Satellite Image Classification - ERROR",
            description="Failed to initialize the application.",
        )

    return gr.Interface(
        fn=predict_images,
        inputs=[],
        outputs=[
            gr.Image(type="pil", label="Classification Results (10 Random Images)"),
            gr.Textbox(label="Summary", lines=3),
        ],
        title="🛰️ Satellite Image Classification with ResNet18",
        description=_APP_DESCRIPTION,
        examples=[],
        cache_examples=False,
        # NOTE(review): newer Gradio releases deprecate allow_flagging in
        # favour of flagging_mode -- confirm against the pinned version.
        allow_flagging="never",
    )
# Launch the app
if __name__ == "__main__":
    # Building the interface also downloads and loads the dataset and model.
    launch_kwargs = {"share": True}
    demo = create_interface()
    demo.launch(**launch_kwargs)