Ziad Meligy commited on
Commit
f1ec150
·
1 Parent(s): 57fef69

Pushing deployment to space

Browse files
Files changed (1) hide show
  1. generate_report.py +46 -20
generate_report.py CHANGED
@@ -12,32 +12,58 @@ from huggingface_hub import hf_hub_download
12
  # from src.models.cnn_encoder import
13
  # from src.models.distil_gpt2 import DistilGPT2
14
  # from src.configs import argHandler
15
-
16
  FLAGS = argHandler()
17
- FLAGS.setDefaults()
18
- tokenizer_wrapper = TokenizerWrapper( FLAGS.csv_label_columns[0], FLAGS.max_sequence_length, FLAGS.tokenizer_vocab_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- encoder = CNN_Encoder('pretrained_visual_model', FLAGS.visual_model_name, FLAGS.visual_model_pop_layers,
21
- FLAGS.encoder_layers, FLAGS.tags_threshold, num_tags=len(FLAGS.tags))
22
- decoder = DistilGPT2.from_pretrained('distilgpt2')
 
 
 
 
 
23
 
24
- optimizer = torch.optim.Adam(decoder.parameters(), lr=FLAGS.learning_rate)
25
- # checkpoint_path = os.path.join(FLAGS.ckpt_path, "checkpoint.pth")
26
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
- encoder.to(device)
28
- decoder.to(device)
29
 
30
- checkpoint_path = hf_hub_download(repo_id="TransformingBerry/CDGPT2_checkpoint", filename="checkpoint.pth")
31
 
32
 
33
- if os.path.exists(checkpoint_path):
34
- print(f"Restoring from checkpoint: {checkpoint_path}")
35
- checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
36
- encoder.load_state_dict(checkpoint['encoder_state_dict'])
37
- decoder.load_state_dict(checkpoint['decoder_state_dict'])
38
- optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
39
- else:
40
- print("No checkpoint found. Starting from scratch.")
41
 
42
  def generate_report(image_bytes):
43
  image = Image.open(io.BytesIO(image_bytes))
 
12
  # from src.models.cnn_encoder import
13
  # from src.models.distil_gpt2 import DistilGPT2
14
  # from src.configs import argHandler
 
15
  FLAGS = argHandler()
16
+ def init_model():
17
+ global tokenizer_wrapper, encoder, decoder, optimizer
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+
20
+ print("✅ Initializing model components...")
21
+
22
+ from configs import argHandler
23
+
24
+ FLAGS.setDefaults()
25
+
26
+ tokenizer_wrapper = TokenizerWrapper(
27
+ FLAGS.csv_label_columns[0],
28
+ FLAGS.max_sequence_length,
29
+ FLAGS.tokenizer_vocab_size
30
+ )
31
+
32
+ encoder_model_dir = 'pretrained_visual_model'
33
+ encoder = CNN_Encoder(
34
+ encoder_model_dir,
35
+ FLAGS.visual_model_name,
36
+ FLAGS.visual_model_pop_layers,
37
+ FLAGS.encoder_layers,
38
+ FLAGS.tags_threshold,
39
+ num_tags=len(FLAGS.tags)
40
+ )
41
+
42
+ decoder = DistilGPT2.from_pretrained('distilgpt2')
43
+ optimizer = torch.optim.Adam(decoder.parameters(), lr=FLAGS.learning_rate)
44
+
45
+ encoder.to(device)
46
+ decoder.to(device)
47
+
48
+ checkpoint_path = hf_hub_download(
49
+ repo_id="TransformingBerry/CDGPT2_checkpoint",
50
+ filename="checkpoint.pth"
51
+ )
52
 
53
+ if os.path.exists(checkpoint_path):
54
+ print(f"✅ Restoring from checkpoint: {checkpoint_path}")
55
+ checkpoint = torch.load(checkpoint_path, map_location=device)
56
+ encoder.load_state_dict(checkpoint['encoder_state_dict'])
57
+ decoder.load_state_dict(checkpoint['decoder_state_dict'])
58
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
59
+ else:
60
+ print("⚠️ No checkpoint found. Starting from scratch.")
61
 
62
+ print("✅ Model initialized.")
 
 
 
 
63
 
 
64
 
65
 
66
+ init_model()
 
 
 
 
 
 
 
67
 
68
  def generate_report(image_bytes):
69
  image = Image.open(io.BytesIO(image_bytes))