IlPakoZ commited on
Commit
7d1331e
·
verified ·
1 Parent(s): b7daa57

Upload analysis.py

Browse files
Files changed (1) hide show
  1. analysis.py +12 -14
analysis.py CHANGED
@@ -36,9 +36,6 @@ def plot_crossattention_weights(target_mask, drug_mask, target_tokenized, drug_t
36
  if hasattr(drug_mask, 'cpu'):
37
  drug_mask = drug_mask.cpu()
38
 
39
- # Convert boolean masks
40
- target_mask = target_mask.astype(bool)
41
- drug_mask = drug_mask.astype(bool)
42
  logger.info(f"Target mask shape: {target_mask.shape}, Drug mask shape: {drug_mask.shape}")
43
 
44
  # Get tokens for unmasked positions
@@ -56,32 +53,33 @@ def plot_crossattention_weights(target_mask, drug_mask, target_tokenized, drug_t
56
  crossattention_weights = crossattention_weights.cpu()
57
 
58
  subset = crossattention_weights[target_mask][:, drug_mask]
59
- logger.info(f"Subset shape: {subset.shape}")
60
- height, width = subset.shape
61
 
 
 
 
62
  fig, ax = plt.subplots(
63
  figsize=(width * 0.2 + 2, height * 0.2 + 3),
64
  dpi=300
65
  )
66
- im = ax.imshow(subset, cmap='hot', interpolation='nearest')
67
-
68
  plt.colorbar(im, ax=ax, orientation='vertical', fraction=0.05, shrink=0.8)
69
-
70
  plt.title("Cross-Attention Weights")
71
  plt.xlabel("Drug Tokens")
72
  plt.ylabel("Target Tokens")
73
-
74
  # Create vertical labels for drug tokens
75
  vertical_labels = ['\n'.join(label) for label in drug_token_str]
76
  plt.xticks(ticks=np.arange(width), labels=vertical_labels)
77
  plt.yticks(ticks=np.arange(height), labels=target_token_str)
78
-
79
  # Add text annotations
80
- max_val = subset.max()
81
  logger.info(f"Max crossattention weight: {max_val}")
82
  for i in range(height):
83
  for j in range(width):
84
- val = subset[i, j]
85
  if val > max_val / 2:
86
  # Extract just the digits after the decimal (no leading '0.')
87
  text = f"{val % 1:.2f}"[2:]
@@ -89,14 +87,14 @@ def plot_crossattention_weights(target_mask, drug_mask, target_tokenized, drug_t
89
  ha='center', va='center',
90
  color="black",
91
  fontsize=6)
92
-
93
  # Convert to PIL Image
94
  buf = io.BytesIO()
95
  plt.savefig(buf, format='png', bbox_inches='tight', dpi=300)
96
  buf.seek(0)
97
  img = Image.open(buf)
98
  plt.close()
99
-
100
  logger.info("Finished plot_crossattention_weights successfully")
101
  return img
102
 
 
36
  if hasattr(drug_mask, 'cpu'):
37
  drug_mask = drug_mask.cpu()
38
 
 
 
 
39
  logger.info(f"Target mask shape: {target_mask.shape}, Drug mask shape: {drug_mask.shape}")
40
 
41
  # Get tokens for unmasked positions
 
53
  crossattention_weights = crossattention_weights.cpu()
54
 
55
  subset = crossattention_weights[target_mask][:, drug_mask]
56
+ subset_np = subset.detach().numpy() # Convert to numpy for matplotlib
 
57
 
58
+ logger.info(f"Subset shape: {subset_np.shape}")
59
+ height, width = subset_np.shape
60
+
61
  fig, ax = plt.subplots(
62
  figsize=(width * 0.2 + 2, height * 0.2 + 3),
63
  dpi=300
64
  )
65
+ im = ax.imshow(subset_np, cmap='hot', interpolation='nearest')
66
+
67
  plt.colorbar(im, ax=ax, orientation='vertical', fraction=0.05, shrink=0.8)
 
68
  plt.title("Cross-Attention Weights")
69
  plt.xlabel("Drug Tokens")
70
  plt.ylabel("Target Tokens")
71
+
72
  # Create vertical labels for drug tokens
73
  vertical_labels = ['\n'.join(label) for label in drug_token_str]
74
  plt.xticks(ticks=np.arange(width), labels=vertical_labels)
75
  plt.yticks(ticks=np.arange(height), labels=target_token_str)
76
+
77
  # Add text annotations
78
+ max_val = subset_np.max()
79
  logger.info(f"Max crossattention weight: {max_val}")
80
  for i in range(height):
81
  for j in range(width):
82
+ val = subset_np[i, j]
83
  if val > max_val / 2:
84
  # Extract just the digits after the decimal (no leading '0.')
85
  text = f"{val % 1:.2f}"[2:]
 
87
  ha='center', va='center',
88
  color="black",
89
  fontsize=6)
90
+
91
  # Convert to PIL Image
92
  buf = io.BytesIO()
93
  plt.savefig(buf, format='png', bbox_inches='tight', dpi=300)
94
  buf.seek(0)
95
  img = Image.open(buf)
96
  plt.close()
97
+
98
  logger.info("Finished plot_crossattention_weights successfully")
99
  return img
100