hackergeek/RADIOCAP13

ROCO Radiology Image Captioning Model

This model is a medical image captioning system for radiology reports. It uses a frozen ViT encoder for image feature extraction and a custom decoder trained to generate captions. The model was trained on the full ROCO-radiology dataset.

  • Encoder: google/vit-base-patch16-224-in21k (frozen, features cached; see the caching sketch after this list)
  • Decoder: Trained on full ROCO dataset (~81k samples) for 3 epochs
  • Trainable parameters: Only decoder + ViT biases
  • Vocab size: 75460
  • Sequence Length: 32
  • Generation: Beam search (size=3)
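
Because the encoder is frozen, its image features only need to be computed once. The sketch below illustrates the cached-features setup; it is not the original training code, and the dataloader, batch shape, and output file name are assumptions made for illustration.

import torch
from transformers import ViTModel

@torch.no_grad()
def cache_vit_features(dataloader, out_path="vit_features.pt"):
    # Hypothetical caching step: run the frozen ViT once over every image and store
    # the pooled features so the decoder can be trained without re-running the encoder.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k").to(device).eval()
    cached = []
    for pixel_values in dataloader:                # each batch: (B, 3, 224, 224) image tensors
        out = vit(pixel_values=pixel_values.to(device))
        cached.append(out.pooler_output.cpu())     # (B, 768) pooled image features
    torch.save(torch.cat(cached), out_path)        # reused as img_feat when training the decoder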

Usage

from transformers import ViTModel
import torch
from PIL import Image
from torchvision import transforms
import json
import os

# The SimpleTokenizer and BiasDecoder classes come from the training script;
# their definitions are reproduced below so this example is self-contained.

# Constants matching the training configuration
IMG_SIZE = 224
SEQ_LEN = 32
VOCAB_SIZE = 75460

transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])

def preprocess_image(img):
    if img is None: raise ValueError("Image is None")
    if not isinstance(img, Image.Image): img = Image.fromarray(img)
    if img.mode != "RGB": img = img.convert("RGB")
    return transform(img)

# SimpleTokenizer class (copy-pasted from notebook for self-contained example)
class SimpleTokenizer:
    def __init__(self, word2idx=None):
        if word2idx is None:
            # Placeholder for actual vocab loading or creation if not loaded from file
            self.word2idx = {}
        else:
            self.word2idx = word2idx
        self.idx2word = {v: k for k, v in self.word2idx.items()}

    def encode(self, text, max_len=SEQ_LEN):
        # Out-of-vocabulary words fall back to <UNK> if it is in the vocab, else <PAD>.
        unk = self.word2idx.get("<UNK>", self.word2idx["<PAD>"])
        tokens = [self.word2idx.get(w, unk) for w in text.lower().split()]
        tokens = [self.word2idx["<SOS>"]] + tokens[:max_len-2] + [self.word2idx["<EOS>"]]
        tokens += [self.word2idx["<PAD>"]] * (max_len - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)

    def decode(self, tokens):
        return " ".join(self.idx2word.get(t.item(), "<UNK>") for t in tokens if t not in [self.word2idx["<PAD>"], self.word2idx["<SOS>"], self.word2idx["<EOS>"]])

    @classmethod
    def load(cls, path):
        with open(f"{path}/vocab.json", "r") as f: # Correctly escaped
            word2idx = json.load(f)
        tokenizer = cls(word2idx)
        return tokenizer

# BiasDecoder class (copy-pasted from notebook for self-contained example)
class BiasDecoder(torch.nn.Module):
    def __init__(self, feature_dim=768, vocab_size=VOCAB_SIZE):
        super().__init__()
        self.token_emb   = torch.nn.Embedding(vocab_size, feature_dim)
        self.pos_emb     = torch.nn.Embedding(SEQ_LEN-1, feature_dim)
        self.final_layer = torch.nn.Linear(feature_dim, vocab_size)

    def forward(self, img_feat, target_seq):
        x = self.token_emb(target_seq)
        pos = torch.arange(x.size(1), device=x.device).clamp(max=self.pos_emb.num_embeddings-1)
        x = x + self.pos_emb(pos)
        x = x + img_feat.unsqueeze(1)
        return self.final_layer(x)

# Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load ViT (frozen)
vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
vit.eval()
vit.to(device)
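
# Optionally, if the weight and vocab files are not already local, they can be fetched
# from the Hub first (hedged sketch; assumes the repo exposes pytorch_model.bin and vocab.json):
# from huggingface_hub import hf_hub_download
# weights_path = hf_hub_download("hackergeek/RADIOCAP13", "pytorch_model.bin")
# vocab_dir = os.path.dirname(hf_hub_download("hackergeek/RADIOCAP13", "vocab.json"))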

# Load decoder
decoder = BiasDecoder().to(device)
# Assuming 'pytorch_model.bin' is in the current directory or specified path
decoder.load_state_dict(torch.load("pytorch_model.bin", map_location=device))
decoder.eval()

# Load tokenizer
# Assuming 'vocab.json' is in the current directory or specified path
tokenizer = SimpleTokenizer.load("./")
pad_idx = tokenizer.word2idx["<PAD>"]

# Generation function
@torch.no_grad()
def generate_caption(model, img_feat, max_len=SEQ_LEN, beam_size=3):
    model.eval()
    img_feat = img_feat.to(device)
    beams = [([tokenizer.word2idx["<SOS>"]], 0.0)]
    for _ in range(max_len - 1):
        candidates = []
        for seq, score in beams:
            inp = torch.tensor(seq + [pad_idx] * (SEQ_LEN - len(seq)), device=device).unsqueeze(0)
            logits = model(img_feat, inp)
            probs = torch.nn.functional.log_softmax(logits[0, len(seq)-1], dim=-1)
            top_p, top_i = torch.topk(probs, beam_size)
            for i in range(beam_size):
                candidates.append((seq + [top_i[i].item()], score + top_p[i].item()))
        beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_size]
        if all(s[-1] == tokenizer.word2idx["<EOS>"] for s, _ in beams): break
    words = [tokenizer.idx2word.get(i, "<UNK>") for i in beams[0][0][1:] if i not in (pad_idx, tokenizer.word2idx["<EOS>"])]
    return " ".join(words)

# Example: Generate a caption for an image
# For a real example, you would load an actual image and process it.
# img_path = "path/to/your/image.jpg"
# image = Image.open(img_path).convert("RGB")
# img_tensor = preprocess_image(image).unsqueeze(0).to(device)
# img_feat = vit(pixel_values=img_tensor).pooler_output
# generated_caption = generate_caption(decoder, img_feat)
# print(f"Generated caption: {generated_caption}")

Evaluation (on ROCO Test Set)

  • BLEU-1: N/A
  • BLEU-2: N/A
  • BLEU-3: N/A
  • BLEU-4: N/A
  • Overall BLEU Score: N/A

Note: BLEU score computation was interrupted, so no scores are reported yet. Re-run the evaluation cell (eXra19D_oqcs) after pushing the model to obtain accurate scores.
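
For reference, BLEU-1 through BLEU-4 can also be recomputed offline once the test split has been run through generate_caption. The snippet below is a minimal sketch using NLTK's corpus_bleu rather than the notebook's evaluation cell; the references and hypotheses lists are assumed inputs (one reference and one generated caption per test image).

from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

def bleu_report(references, hypotheses):
    # references: one reference caption (str) per test image
    # hypotheses: one generated caption (str) per test image
    refs = [[r.lower().split()] for r in references]   # corpus_bleu expects a list of reference lists
    hyps = [h.lower().split() for h in hypotheses]
    smooth = SmoothingFunction().method1                # avoids zero scores on short captions
    return {
        f"BLEU-{n}": corpus_bleu(refs, hyps, weights=tuple([1.0 / n] * n), smoothing_function=smooth)
        for n in (1, 2, 3, 4)
    }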
