Spaces:
Sleeping
Sleeping
| from PIL import Image | |
| import io | |
| import torch | |
| import os | |
| import numpy as np | |
| from CNN_encoder import CNN_Encoder | |
| from distil_gpt2 import DistilGPT2 | |
| from configs import argHandler | |
| from utils import load_image, split_report_sections | |
| from tokenizer_wrapper import TokenizerWrapper | |
| from huggingface_hub import hf_hub_download | |
| from api import API_call | |
| # from src.models.cnn_encoder import | |
| # from src.models.distil_gpt2 import DistilGPT2 | |
| # from src.configs import argHandler | |
# Shared configuration object; init_model() fills in its defaults before use.
FLAGS = argHandler()


def init_model():
    """Build the inference components and publish them as module globals.

    Creates the tokenizer wrapper, the CNN encoder, the DistilGPT2 decoder and
    an Adam optimizer over the decoder parameters, moves both networks to CUDA
    when available (CPU otherwise), downloads ``checkpoint.pth`` from the
    ``TransformingBerry/CDGPT2_checkpoint`` Hub repository and restores the
    encoder, decoder and optimizer state from it.

    Side effects:
        Sets the globals ``tokenizer_wrapper``, ``encoder``, ``decoder`` and
        ``optimizer``; calls ``FLAGS.setDefaults()``; prints progress messages.
    """
    global tokenizer_wrapper, encoder, decoder, optimizer

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): the leading glyphs in the log strings below look
    # mis-encoded; they are kept as-is until the intended characters are
    # confirmed.
    print("β Initializing model components...")

    FLAGS.setDefaults()

    tokenizer_wrapper = TokenizerWrapper(
        FLAGS.csv_label_columns[0],
        FLAGS.max_sequence_length,
        FLAGS.tokenizer_vocab_size,
    )

    encoder_model_dir = 'pretrained_visual_model'
    encoder = CNN_Encoder(
        encoder_model_dir,
        FLAGS.visual_model_name,
        FLAGS.visual_model_pop_layers,
        FLAGS.encoder_layers,
        FLAGS.tags_threshold,
        num_tags=len(FLAGS.tags),
    )
    decoder = DistilGPT2.from_pretrained('distilgpt2')
    optimizer = torch.optim.Adam(decoder.parameters(), lr=FLAGS.learning_rate)

    encoder.to(device)
    decoder.to(device)

    checkpoint_path = hf_hub_download(
        repo_id="TransformingBerry/CDGPT2_checkpoint",
        filename="checkpoint.pth",
    )

    if os.path.exists(checkpoint_path):
        print(f"β Restoring from checkpoint: {checkpoint_path}")
        # NOTE(review): torch.load unpickles the downloaded file. Consider
        # passing weights_only=True if the installed torch version supports it
        # and the checkpoint holds only state dicts — confirm before changing.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        # Presumably only a safety net: hf_hub_download is expected to raise
        # rather than return a missing path — TODO confirm.
        print("β οΈ No checkpoint found. Starting from scratch.")

    # This module only serves generation requests, so put both networks in
    # evaluation mode; otherwise dropout / batch-norm layers (if any) would
    # keep their training-time behaviour during inference.
    encoder.eval()
    decoder.eval()

    print("β Model initialized.")


# Load the models once, at import time, so generate_report() can use them.
init_model()
def generate_report(image_bytes):
    """Generate a structured report for one image supplied as raw bytes.

    The image is decoded with PIL, preprocessed by ``load_image``, encoded by
    the global ``encoder`` and decoded greedily (``num_beams=1``,
    ``do_sample=False``) by the global ``decoder`` starting from the
    ``'startseq'`` token. The decoded sentence is filtered for special words,
    sent to ``API_call`` for structuring and finally split with
    ``split_report_sections``.

    Args:
        image_bytes: Encoded image data that ``PIL.Image.open`` can read from
            an in-memory buffer.

    Returns:
        The result of ``split_report_sections`` applied to the API-structured
        report.
    """
    # Inputs must live on the same device as the model weights; init_model()
    # moves the networks to CUDA when it is available.
    device = next(decoder.parameters()).device

    image = Image.open(io.BytesIO(image_bytes))
    image_tensor = load_image(image)
    # presumably load_image returns a torch tensor — only move it if so.
    if torch.is_tensor(image_tensor):
        image_tensor = image_tensor.to(device)

    start_ids = tokenizer_wrapper.GPT2_encode('startseq', pad=False)
    # Add a leading batch dimension of size 1.
    dec_input = torch.tensor(start_ids, device=device).unsqueeze(0)

    # Pure inference: skip autograd bookkeeping for the forward passes.
    with torch.no_grad():
        visual_features, tags_embedding = encoder(image_tensor)
        generation_config = {
            "visual_features": visual_features,
            "tags_embedding": tags_embedding,
            "num_beams": 1,
            "max_length": FLAGS.max_sequence_length,
            "min_length": 3,
            "eos_token_ids": tokenizer_wrapper.GPT2_eos_token_id(),
            "pad_token_id": tokenizer_wrapper.GPT2_pad_token_id(),
            "do_sample": False,
            "early_stopping": True,
        }
        tokens = decoder.generate(dec_input, **generation_config)

    sentence = tokenizer_wrapper.GPT2_decode(tokens[0])
    sentence = tokenizer_wrapper.filter_special_words(sentence)
    print(sentence)

    # Call the API to structure the report
    structured_report = API_call(sentence)
    print(structured_report)

    structured_report = split_report_sections(structured_report)
    return structured_report