Masked Autoencoder for Euclid Images

Overview

This masked autoencoder (MAE) is trained to reconstruct Euclid galaxy images where 90% of the image is masked. The trained model shows superhuman performance at reconstruction.

There is an interactive demo here. Try choosing your own images!

This version is trained on Euclid DR1 (13.6M images, 12M for training). This updates and significantly improves on the otherwise-identical Euclid RR2 model presented in the reference below.

The base model is a custom timm vision transformer; see config.yaml for the exact hyperparameters. These don't matter much; the main things to note are that the patches must be fine-grained (8x8 pixels here), but the model otherwise need not be large to work well. This version is under 100M parameters and comfortably runs predictions on CPU.

This version is as presented in the citation below, except that I have removed the angular-scale-dependent positional encoding for simplicity (it turns out to not be necessary for good performance), and we trained on Euclid DR1. Please cite the workshop paper if you find our work helpful.

@misc{wu2025reenvisioningeuclidgalaxymorphology,
    title={Re-envisioning Euclid Galaxy Morphology: Identifying and Interpreting Features with Sparse Autoencoders}, 
    author={John F. Wu and Michael Walmsley},
    year={2025},
    eprint={2510.23749},
    archivePrefix={arXiv},
    primaryClass={astro-ph.IM},
    url={https://arxiv.org/abs/2510.23749}, 
}

Instructions

Quickstart - Download

import mae_timm_simplified  # download this script from the "files and versions" tab

import omegaconf
from huggingface_hub import hf_hub_download

cfg_path = hf_hub_download(repo_id="mwalmsley/euclid-dr1-mae", filename="config.yaml")
cfg = omegaconf.OmegaConf.load(cfg_path)
mae = mae_timm_simplified.MAE.from_pretrained("mwalmsley/euclid-dr1-mae", cfg=cfg)
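
As a quick sanity check that the download worked, you can inspect the loaded model. This is just a sketch; it only assumes mae is a standard torch.nn.Module:

mae.eval()  # inference mode

# parameter count - should come out under 100M
print(f'{sum(p.numel() for p in mae.parameters()) / 1e6:.1f}M parameters')

# token sequence length, used below when building masks
print(mae.sequence_length)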

Quickstart - Make Prediction

from PIL import Image
import torch
from lightly.models.utils import random_token_mask

image = Image.open('foo.jpg')  # any Euclid galaxy cutout jpg

image = preprocess_image(image)  # defined below

batch = {
    'image': image.unsqueeze(0),  # (1, 3, H, W)
    'id_str': ['dummy'],  # required by my convention, ignore
}


# utility function for generating random patch indices to mask
_, idx_mask = random_token_mask(
    size=(1, mae.sequence_length),  # (batch_size, seq_len)
    mask_ratio=0.9,  # your choice
    device=batch['image'].device,
)

with torch.no_grad():
    result = mae.predict(batch, idx_mask=idx_mask)

result is a dict with keys for the original image, masked image, reconstructed image, reconstruction loss, and embedding.
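
For example (the 'reconstructed' key is used in the visualisation example below; check the exact key names for the loss and embedding on your install):

print(result.keys())  # check the exact key names

reconstruction = result['reconstructed'][0]  # a PIL image (see the visualisation example below)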

preprocess_image is a torchvision transform that returns 3x224x224 float tensors normalised from 0 to 1:

from torchvision.transforms import v2

def preprocess_image(image):
    preprocess = v2.Compose([
        v2.ToImage(),
        v2.Resize((224, 224)),
        v2.ToDtype(torch.float32, scale=True)
    ])
    return preprocess(image)

idx_mask is a tensor of patch indices to mask for each image in the batch (e.g. [1, 5, 45, ...]). random_token_mask is a small utility from lightly that generates random patch indices - it's basically equivalent to np.random.choice.
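
If you'd rather not depend on lightly for the random case, a roughly-equivalent hand-rolled version looks like this (just a sketch; it assumes mae.sequence_length counts the 784 patches plus the class token at index 0):

num_patches = mae.sequence_length - 1  # 784 patches; index 0 is the class token
num_masked = int(0.9 * num_patches)  # mask ratio of your choice

# sample patch indices 1..num_patches without replacement, leaving the class token unmasked
idx_mask = torch.randperm(num_patches)[:num_masked] + 1
idx_mask = idx_mask.unsqueeze(0)  # (batch_size, num_masked)

For custom masks, read on.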

Generate a custom mask

We divide the image into 784 patches, in a grid of 28 by 28 patches. Each patch is 8x8 pixels (covering our 224x224 image).

The patch in the top-left corner is index 1 (not 0 - index 0 is the class token!), and higher indices go left-to-right and then down a row (like reading a page).

For example, to mask only the first 28 patches (the top of the image)

    row_mask = torch.arange(28) + 1  # patch indices 1 to 28
    # copy for all images in the batch
    idx_mask = row_mask.unsqueeze(0).repeat(batch_size, 1)  # (batch_size, num_masked)
    idx_mask = idx_mask.to('cuda')

To mask the middle strip:

    row_mask = torch.arange(28) + 1 + 13*28  # row 13 of 28 (counting from 0)

And so on, however you like. Just remember to add 1 for the class token!
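
For instance, here is a square mask over the central 14x14 block of patches, using the row-major indexing described above (a sketch; batch_size as before):

    rows = torch.arange(7, 21)  # patch rows 7..20 (0-indexed)
    cols = torch.arange(7, 21)  # patch columns 7..20
    grid_rows, grid_cols = torch.meshgrid(rows, cols, indexing='ij')
    centre_mask = (grid_rows * 28 + grid_cols).flatten() + 1  # index = row*28 + col, +1 for the class token
    idx_mask = centre_mask.unsqueeze(0).repeat(batch_size, 1).to('cuda')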

Make predictions for the masked patches

    import matplotlib.pyplot as plt

    mae = mae.to('cuda')
    batch['image'] = batch['image'].to('cuda')  # model and batch must be on the same device
    with torch.no_grad():
        result = mae.predict(batch, idx_mask=idx_mask)

    # result has keys including images, masked, reconstructed
    # each key is a list of standard PIL images
    images = result['images']
    masked = result['masked']
    reconstructed = result['reconstructed']

    # Visualize the results
    fig, axes = plt.subplots(nrows=3, ncols=8, figsize=(24, 9))
    for i in range(8):  # assumes a batch of at least 8 images; adjust ncols to your batch size
        axes[0, i].imshow(images[i])
        axes[0, i].set_title("Original")
        axes[1, i].imshow(masked[i])
        axes[1, i].set_title("Masked")
        axes[2, i].imshow(reconstructed[i])
        axes[2, i].set_title("Reconstructed")
    plt.tight_layout()
    plt.show()
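
To use the model for representations rather than reconstructions, the same predict call also returns an embedding. A minimal sketch for a folder of jpgs (the folder name is a placeholder, and I'm assuming the embedding comes back as a torch tensor - check result.keys() and the shapes on your install):

    from pathlib import Path

    device = next(mae.parameters()).device
    embeddings = []
    for path in Path('my_galaxies').glob('*.jpg'):  # placeholder folder of cutout jpgs
        image = preprocess_image(Image.open(path))
        batch = {'image': image.unsqueeze(0).to(device), 'id_str': [path.stem]}
        _, idx_mask = random_token_mask(
            size=(1, mae.sequence_length), mask_ratio=0.9, device=device)
        with torch.no_grad():
            result = mae.predict(batch, idx_mask=idx_mask)
        embeddings.append(result['embedding'].cpu())
    # stack however you prefer, e.g. torch.cat(embeddings)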

That's everything you need to know for doing inference (reconstructions or embeddings). To reproduce my training, keep reading.

Download Data

Get a dataset of Euclid images, prepared as Galaxy-Zoo-style jpgs:


from datasets import load_dataset

dataset_dict = load_dataset(
    'mwalmsley/euclid_q1',   # _rr2, _dr1 versions are available to EC members
    name='tiny-v1-gz_arcsinh_vis_y'  # tiny subset for testing
)
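
This is a standard Hugging Face DatasetDict; a quick look at what you downloaded (split and column names are best checked by printing, since they vary by configuration):

print(dataset_dict)  # available splits and columns
example = dataset_dict['train'][0]  # assuming the usual 'train' split
example['image']  # a galaxy cutout, loaded as a PIL image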

Use my utility package galaxy-datasets to load this as a Lightning DataModule, including an appropriate torchvision transform...


from galaxy_datasets.pytorch.galaxy_datamodule import HuggingFaceDataModule
from galaxy_datasets.transforms import default_view_config, get_galaxy_transform


# define augmentations to use
view_config = default_view_config()
view_config.output_size = 224
view_config.erase_iterations = 0  # for simplicity
ssl_image_transform = get_galaxy_transform(cfg=view_config)
# this is just a torchvision Compose transform
# returns 3x224x224 float tensor normalised 0-1.

batch_size = 8  # set these to suit your hardware
num_workers = 4
prefetch_factor = 2

datamodule = HuggingFaceDataModule(
    dataset_dict=dataset_dict,
    train_transform=ssl_image_transform,
    test_transform=ssl_image_transform,
    batch_size=batch_size,
    num_workers=num_workers,
    prefetch_factor=prefetch_factor
)
datamodule.setup()
# this is just a lightning datamodule
# should yield batches with an 'image' key, see below

# get a batch
test_loader = datamodule.test_dataloader()
for batch in test_loader:
    batch['image'] = batch['image'].to('cuda')
    break

...or you can do this yourself. You should make batches that include an 'image' key which contains

  • BxCx224x224 float tensors normalised from 0 to 1
  • where those tensors are created by transforming (e.g. with torchvision) a GZ-style jpg (download from HuggingFace above)

It might work for other human-friendly jpgs, but that's outside of the training distribution, so no promises.
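
A minimal sketch of the do-it-yourself route: a plain torch Dataset over a folder of GZ-style jpgs, reusing preprocess_image from above (the folder name is a placeholder):

from pathlib import Path
from torch.utils.data import Dataset, DataLoader

class GalaxyJpgDataset(Dataset):
    def __init__(self, folder):
        self.paths = sorted(Path(folder).glob('*.jpg'))

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        return {
            'image': preprocess_image(Image.open(path)),  # 3x224x224 float tensor, 0-1
            'id_str': path.stem,
        }

loader = DataLoader(GalaxyJpgDataset('my_galaxies'), batch_size=8, shuffle=False)
# default collation stacks the tensors into a (B, 3, 224, 224) 'image' key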


Acknowledgements

Walmsley trained the model and Wu ran the sparsity analysis. Additional thanks to Inigo Val Slijepcevic, Micah Bowles, Devina Mohan, Anna Scaife, and Joshua Speagle for their help and advice. We are grateful to the Euclid Consortium and the European Space Agency for making the data available.
