IlPakoZ commited on
Commit
b7daa57
·
verified ·
1 Parent(s): 696610f

Upload analysis.py

Browse files
Files changed (1) hide show
  1. analysis.py +5 -13
analysis.py CHANGED
@@ -32,11 +32,9 @@ def plot_crossattention_weights(target_mask, drug_mask, target_tokenized, drug_t
32
 
33
  # Convert masks to numpy arrays if they're tensors
34
  if hasattr(target_mask, 'cpu'):
35
- logger.info("Converting target_mask tensor to numpy")
36
- target_mask = target_mask.cpu().numpy()
37
  if hasattr(drug_mask, 'cpu'):
38
- logger.info("Converting drug_mask tensor to numpy")
39
- drug_mask = drug_mask.cpu().numpy()
40
 
41
  # Convert boolean masks
42
  target_mask = target_mask.astype(bool)
@@ -45,23 +43,17 @@ def plot_crossattention_weights(target_mask, drug_mask, target_tokenized, drug_t
45
 
46
  # Get tokens for unmasked positions
47
  tokens_input = target_tokenized["input_ids"][0][target_mask]
48
- if hasattr(tokens_input, 'cpu'):
49
- logger.info("Converting target tokens tensor to numpy")
50
- tokens_input = tokens_input.cpu().numpy()
51
  target_token_str = target_tokenizer.convert_ids_to_tokens(tokens_input)
52
- logger.info(f"Target tokens: {target_token_str}")
53
 
54
  tokens_input = drug_tokenized["input_ids"][0][drug_mask]
55
- if hasattr(tokens_input, 'cpu'):
56
- logger.info("Converting drug tokens tensor to numpy")
57
- tokens_input = tokens_input.cpu().numpy()
58
  drug_token_str = drug_tokenizer.convert_ids_to_tokens(tokens_input)
 
59
  logger.info(f"Drug tokens: {drug_token_str}")
60
 
61
  # Extract subset of attention weights
62
  if hasattr(crossattention_weights, 'cpu'):
63
- logger.info("Converting crossattention_weights tensor to numpy")
64
- crossattention_weights = crossattention_weights.cpu().numpy()
65
 
66
  subset = crossattention_weights[target_mask][:, drug_mask]
67
  logger.info(f"Subset shape: {subset.shape}")
 
32
 
33
  # Convert masks to numpy arrays if they're tensors
34
  if hasattr(target_mask, 'cpu'):
35
+ target_mask = target_mask.cpu()
 
36
  if hasattr(drug_mask, 'cpu'):
37
+ drug_mask = drug_mask.cpu()
 
38
 
39
  # Convert boolean masks
40
  target_mask = target_mask.astype(bool)
 
43
 
44
  # Get tokens for unmasked positions
45
  tokens_input = target_tokenized["input_ids"][0][target_mask]
46
+
 
 
47
  target_token_str = target_tokenizer.convert_ids_to_tokens(tokens_input)
 
48
 
49
  tokens_input = drug_tokenized["input_ids"][0][drug_mask]
 
 
 
50
  drug_token_str = drug_tokenizer.convert_ids_to_tokens(tokens_input)
51
+
52
  logger.info(f"Drug tokens: {drug_token_str}")
53
 
54
  # Extract subset of attention weights
55
  if hasattr(crossattention_weights, 'cpu'):
56
+ crossattention_weights = crossattention_weights.cpu()
 
57
 
58
  subset = crossattention_weights[target_mask][:, drug_mask]
59
  logger.info(f"Subset shape: {subset.shape}")