essense 3b v1.2, fine-tuned on prompts extracted from Qwen3 8B.

VAE decoder usage:

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

class VAEDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.Linear(512, 512), nn.GELU(),
            nn.Linear(512, 1024), nn.GELU(),
            nn.Linear(1024, 2048), nn.GELU(),
            nn.Linear(2048, 6 * 2048),  # flattened 6-token embeddings
        )

    def forward(self, z):
        return self.decoder(z)

vae_checkpoint = torch.load("cosine_vae_final.pth", map_location=device, weights_only=False)

vae_decoder = VAEDecoder().to(device)
# the checkpoint key holding the decoder weights is assumed here; adjust to your checkpoint layout
vae_decoder.load_state_dict(vae_checkpoint["decoder_state_dict"])
vae_decoder.eval()

# normalization statistics saved alongside the weights
global_mean = vae_checkpoint["global_mean"]
global_std = vae_checkpoint["global_std"]

# sample latents from the prior and decode; no gradients are needed for generation
batch_size = 8
with torch.no_grad():
    z = torch.randn(batch_size, 512, device=device)
    generated_flat = vae_decoder(z)
    generated_flat = generated_flat * global_std + global_mean  # undo normalization
    sequences = generated_flat.reshape(-1, 6, 2048)
# output: (8, 6, 2048), i.e. 8 generated prompt embeddings, each 6 tokens long
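
To inspect what the sampled embeddings roughly correspond to as text, one option is a nearest-neighbor lookup against a language model's input embedding table. This is a minimal sketch, not part of the checkpoint: the model ID is a hypothetical placeholder (substitute the actual base model with 2048-dim embeddings), and cosine similarity is an assumed decoding choice.

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn.functional as F

# hypothetical model ID; replace with the base model whose embedding dim is 2048
tokenizer = AutoTokenizer.from_pretrained("your-base-model")
lm = AutoModelForCausalLM.from_pretrained("your-base-model").to(device)

with torch.no_grad():
    emb = lm.get_input_embeddings().weight            # (vocab_size, 2048)
    emb_norm = F.normalize(emb, dim=-1)
    seq_norm = F.normalize(sequences, dim=-1)         # (8, 6, 2048)
    sims = seq_norm @ emb_norm.T                      # (8, 6, vocab_size)
    token_ids = sims.argmax(dim=-1)                   # nearest token per position

for ids in token_ids:
    print(tokenizer.decode(ids))

This only gives a token-level approximation: the decoder outputs continuous embeddings, and the nearest vocabulary token per position is a lossy readout of them.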