essense 3b v1.2, fine-tuned on prompts extracted from Qwen 3 8B.
VAE decoder usage:
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

class VAEDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.Linear(512, 512), nn.GELU(),
            nn.Linear(512, 1024), nn.GELU(),
            nn.Linear(1024, 2048), nn.GELU(),
            nn.Linear(2048, 6 * 2048),  # 6-token embeddings, 2048 dims per token
        )

    def forward(self, z):
        return self.decoder(z)

# the checkpoint holds the decoder weights plus the normalization stats used in training
vae_checkpoint = torch.load("cosine_vae_final.pth", map_location=device, weights_only=False)

vae_decoder = VAEDecoder().to(device)
vae_decoder.load_state_dict(vae_checkpoint["decoder"])  # key name assumed; check vae_checkpoint.keys() if it differs
vae_decoder.eval()

global_mean = vae_checkpoint["global_mean"]
global_std = vae_checkpoint["global_std"]

# sample latents from the standard-normal prior and decode them
batch_size = 8
with torch.no_grad():
    z = torch.randn(batch_size, 512, device=device)
    generated_flat = vae_decoder(z)

# undo the normalization that was applied to the training embeddings
generated_flat = generated_flat * global_std + global_mean
sequences = generated_flat.reshape(-1, 6, 2048)
# output: (8, 6, 2048), i.e. 8 generated embedding sequences, each 6 tokens long
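
The decoder produces continuous token embeddings rather than token ids. As a minimal sketch of how you might inspect them as text, assuming the embeddings live in the token-embedding space of a Hugging Face causal LM with hidden size 2048 (BASE_MODEL below is a hypothetical placeholder, not named by this card), you can take the cosine-nearest vocabulary token at each position:

import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

BASE_MODEL = "..."  # hypothetical: the model whose embedding space the VAE was trained on

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
lm = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
emb_matrix = lm.get_input_embeddings().weight.detach().float().to(device)  # (vocab_size, 2048)

# cosine similarity between each generated embedding and every vocabulary embedding
seq_n = F.normalize(sequences, dim=-1)          # (8, 6, 2048)
vocab_n = F.normalize(emb_matrix, dim=-1)       # (V, 2048)
token_ids = (seq_n @ vocab_n.T).argmax(dim=-1)  # (8, 6): nearest token id per position
print(tokenizer.batch_decode(token_ids))

A nearest-neighbor decode by cosine similarity matches the objective the checkpoint name ("cosine_vae") suggests, but it snaps each generated embedding to the closest vocabulary token and discards anything in between.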