hackergeek/RADIOCAP13
ROCO Radiology Image Captioning Model
This model is a medical image captioning system designed for radiology reports. It utilizes a frozen ViT encoder for image feature extraction and a custom decoder trained to generate captions. The model was trained on the full ROCO-radiology dataset.
- Encoder: google/vit-base-patch16-224-in21k (frozen, features cached)
- Decoder: Trained on the full ROCO dataset (~81k samples) for 3 epochs
- Trainable parameters: Only decoder + ViT biases
- Vocab size: 75460
- Sequence Length: 32
- Generation: Beam search (size=3)
Usage
from transformers import ViTModel
import torch
from PIL import Image
from torchvision import transforms
import json
import os
# Assuming SimpleTokenizer and BiasDecoder classes are available from your training script.
# For a full runnable example, their definitions are included below.
# Re-define necessary components and classes for a self-contained example
# Square input resolution fed to the ViT encoder.
IMG_SIZE = 224
# Caption length in tokens used by the tokenizer, decoder and generator defaults.
SEQ_LEN = 32
# Number of entries in vocab.json; must match the decoder checkpoint.
VOCAB_SIZE = 75460

# Resize to IMG_SIZE x IMG_SIZE, then convert the PIL image to a float tensor in [0, 1].
# NOTE(review): no Normalize step is applied. Hugging Face's ViTImageProcessor for
# this checkpoint normalizes with mean/std 0.5. Presumably the cached training
# features were computed with this same transform — confirm before changing it.
transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE,) * 2),
        transforms.ToTensor(),
    ]
)
def preprocess_image(img):
    """Turn an input image into the tensor expected by the ViT encoder.

    Args:
        img: A PIL image, or an array accepted by ``Image.fromarray``.

    Returns:
        ``transform`` applied to the RGB version of ``img``.

    Raises:
        ValueError: If ``img`` is None.
    """
    if img is None:
        raise ValueError("Image is None")
    pil_img = img if isinstance(img, Image.Image) else Image.fromarray(img)
    # The encoder expects 3 channels; grayscale/RGBA inputs are converted.
    rgb_img = pil_img if pil_img.mode == "RGB" else pil_img.convert("RGB")
    return transform(rgb_img)
# SimpleTokenizer class (copy-pasted from notebook for self-contained example)
class SimpleTokenizer:
    """Lower-casing whitespace tokenizer backed by a word -> index vocabulary.

    The vocabulary must contain the special tokens ``<PAD>``, ``<SOS>`` and
    ``<EOS>``; ``encode`` and ``decode`` look them up by name.
    """

    def __init__(self, word2idx=None):
        # An empty vocabulary is only a placeholder: encode/decode need the
        # special tokens and raise KeyError without them.
        self.word2idx = {} if word2idx is None else word2idx
        self.idx2word = {v: k for k, v in self.word2idx.items()}

    def encode(self, text, max_len=None):
        """Encode ``text`` into a fixed-length 1-D LongTensor of ``max_len`` ids.

        Layout: ``<SOS>``, up to ``max_len - 2`` word ids, ``<EOS>``, then
        ``<PAD>`` up to ``max_len``. ``max_len`` defaults to ``SEQ_LEN``.

        Raises:
            ValueError: If ``max_len`` < 2 (no room for ``<SOS>`` and ``<EOS>``).
        """
        if max_len is None:
            # Resolved at call time so the class does not depend on definition order.
            max_len = SEQ_LEN
        if max_len < 2:
            raise ValueError(f"max_len must be >= 2 to fit <SOS> and <EOS>, got {max_len}")
        pad = self.word2idx["<PAD>"]
        # NOTE(review): out-of-vocabulary words map to <PAD>, not <UNK>. Kept as-is;
        # presumably this mirrors how training targets were built — confirm first.
        ids = [self.word2idx.get(w, pad) for w in text.lower().split()]
        ids = [self.word2idx["<SOS>"]] + ids[: max_len - 2] + [self.word2idx["<EOS>"]]
        ids += [pad] * (max_len - len(ids))
        return torch.tensor(ids, dtype=torch.long)

    def decode(self, tokens):
        """Map ids back to a space-joined string, dropping ``<PAD>``/``<SOS>``/``<EOS>``.

        ``tokens`` may be any iterable of ints or 0-d tensors (e.g. a 1-D tensor
        or a plain list). Ids missing from the vocabulary decode as ``"<UNK>"``.
        """
        special = {self.word2idx["<PAD>"], self.word2idx["<SOS>"], self.word2idx["<EOS>"]}
        ids = (int(t) for t in tokens)
        return " ".join(self.idx2word.get(i, "<UNK>") for i in ids if i not in special)

    @classmethod
    def load(cls, path):
        """Build a tokenizer from the JSON vocabulary at ``<path>/vocab.json``."""
        # Explicit UTF-8: the platform default encoding may not decode non-ASCII words.
        with open(os.path.join(path, "vocab.json"), "r", encoding="utf-8") as f:
            word2idx = json.load(f)
        return cls(word2idx)
# BiasDecoder class (copy-pasted from notebook for self-contained example)
class BiasDecoder(torch.nn.Module):
    """Position-wise caption decoder conditioned on an image feature vector.

    Every output position is computed independently as token embedding plus
    (clamped) position embedding plus image feature, projected to vocabulary
    logits. There is no attention or recurrence, so the logits at position
    ``i`` depend only on ``target_seq[:, i]``, on ``i`` and on ``img_feat``.

    The attribute names ``token_emb``, ``pos_emb`` and ``final_layer`` become the
    state_dict keys and must match the saved checkpoint.
    """

    def __init__(self, feature_dim=768, vocab_size=VOCAB_SIZE):
        super().__init__()
        embed = torch.nn.Embedding
        self.token_emb = embed(vocab_size, feature_dim)
        # Only SEQ_LEN - 1 learned positions; forward() clamps later positions to the last one.
        self.pos_emb = embed(SEQ_LEN - 1, feature_dim)
        self.final_layer = torch.nn.Linear(feature_dim, vocab_size)

    def forward(self, img_feat, target_seq):
        """Return vocabulary logits for every position of ``target_seq``.

        Args:
            img_feat: Image features, presumably (batch, feature_dim). They are
                broadcast over the sequence axis via ``unsqueeze(1)``.
            target_seq: Token ids, indexed as (batch, seq_len).
        """
        tok = self.token_emb(target_seq)
        last_pos = self.pos_emb.num_embeddings - 1
        positions = torch.arange(tok.size(1), device=tok.device).clamp(max=last_pos)
        hidden = tok + self.pos_emb(positions) + img_feat.unsqueeze(1)
        return self.final_layer(hidden)
# Setup device: use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load ViT (frozen). eval() puts it in inference mode; requires_grad_(False)
# makes the freeze explicit, so encoder calls made outside torch.no_grad()
# (e.g. the commented usage example) do not build an autograd graph.
vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
vit.eval()
vit.requires_grad_(False)
vit.to(device)

# Load decoder
decoder = BiasDecoder().to(device)
# Assuming 'pytorch_model.bin' is in the current directory or specified path.
# weights_only=True limits torch.load's unpickler to tensors and plain
# containers, so a tampered downloaded checkpoint cannot execute code
# (needs torch >= 1.13; it is the default from torch 2.6).
decoder.load_state_dict(torch.load("pytorch_model.bin", map_location=device, weights_only=True))
decoder.eval()

# Load tokenizer
# Assuming 'vocab.json' is in the current directory or specified path
tokenizer = SimpleTokenizer.load("./")
pad_idx = tokenizer.word2idx["<PAD>"]
# Generation function
@torch.no_grad()
def generate_caption(model, img_feat, max_len=None, beam_size=3, tok=None):
    """Generate a caption for a single image using beam search.

    Args:
        model: Decoder called as ``model(img_feat, ids)`` with a (1, seq_len) id
            tensor, returning logits indexed as [0, position, vocab]
            (e.g. ``BiasDecoder``).
        img_feat: Features of one image, e.g. the ViT ``pooler_output``.
        max_len: Maximum hypothesis length in tokens, including ``<SOS>``.
            Defaults to ``SEQ_LEN``.
        beam_size: Number of hypotheses kept after each step.
        tok: Tokenizer exposing ``word2idx``/``idx2word``. Defaults to the
            module-level ``tokenizer``.

    Returns:
        The highest-scoring caption as a space-joined string. Everything from
        ``<EOS>`` onward and any ``<PAD>`` tokens are dropped.
    """
    if max_len is None:
        max_len = SEQ_LEN
    if tok is None:
        tok = tokenizer
    sos = tok.word2idx["<SOS>"]
    eos = tok.word2idx["<EOS>"]
    pad = tok.word2idx["<PAD>"]

    model.eval()
    # Run on the model's device; a parameter-free model falls back to img_feat's device.
    param = next(model.parameters(), None)
    dev = param.device if param is not None else img_feat.device
    img_feat = img_feat.to(dev)

    beams = [([sos], 0.0)]
    for _ in range(max_len - 1):
        candidates = []
        for seq, score in beams:
            if seq[-1] == eos:
                # Finished hypothesis: carry it forward unchanged instead of
                # extending it past <EOS> (which also lets the early stop fire).
                candidates.append((seq, score))
                continue
            # Pad to a fixed length; only the logits at the last real position are read.
            inp = torch.tensor(seq + [pad] * (max_len - len(seq)), device=dev).unsqueeze(0)
            log_probs = torch.nn.functional.log_softmax(model(img_feat, inp)[0, len(seq) - 1], dim=-1)
            top_p, top_i = torch.topk(log_probs, min(beam_size, log_probs.size(-1)))
            for p, i in zip(top_p.tolist(), top_i.tolist()):
                candidates.append((seq + [i], score + p))
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
        if all(seq[-1] == eos for seq, _ in beams):
            break

    words = []
    for idx in beams[0][0][1:]:  # skip <SOS>
        if idx == eos:
            break
        if idx != pad:
            words.append(tok.idx2word.get(idx, "<UNK>"))
    return " ".join(words)
# Example: Generate a caption for an image
# For a real example, you would load an actual image and process it.
# img_path = "path/to/your/image.jpg"
# image = Image.open(img_path).convert("RGB")
# img_tensor = preprocess_image(image).unsqueeze(0).to(device)
# img_feat = vit(pixel_values=img_tensor).pooler_output
# generated_caption = generate_caption(decoder, img_feat)
# print(f"Generated caption: {generated_caption}")
Evaluation (on ROCO Test Set)
- BLEU-1: N/A
- BLEU-2: N/A
- BLEU-3: N/A
- BLEU-4: N/A
- Overall BLEU Score: N/A
Note: The BLEU computation was interrupted, so no scores are reported yet. Re-run the evaluation cell (eXra19D_oqcs) after pushing the model to obtain accurate scores.
- Downloads last month: 15
Inference Providers
NEW
This model isn't deployed by any Inference Provider.
Ask for provider support
Model tree for hackergeek/RADIOCAP13
Base model
google/vit-base-patch16-224