Upload analysis.py

analysis.py (CHANGED): +5 -13
@@ -32,11 +32,9 @@ def plot_crossattention_weights(target_mask, drug_mask, target_tokenized, drug_t
 
     # Convert masks to numpy arrays if they're tensors
     if hasattr(target_mask, 'cpu'):
-
-        target_mask = target_mask.cpu().numpy()
+        target_mask = target_mask.cpu()
     if hasattr(drug_mask, 'cpu'):
-
-        drug_mask = drug_mask.cpu().numpy()
+        drug_mask = drug_mask.cpu()
 
     # Convert boolean masks
     target_mask = target_mask.astype(bool)
@@ -45,23 +43,17 @@ def plot_crossattention_weights(target_mask, drug_mask, target_tokenized, drug_t
 
     # Get tokens for unmasked positions
     tokens_input = target_tokenized["input_ids"][0][target_mask]
-
-        logger.info("Converting target tokens tensor to numpy")
-        tokens_input = tokens_input.cpu().numpy()
+
     target_token_str = target_tokenizer.convert_ids_to_tokens(tokens_input)
-    logger.info(f"Target tokens: {target_token_str}")
 
     tokens_input = drug_tokenized["input_ids"][0][drug_mask]
-    if hasattr(tokens_input, 'cpu'):
-        logger.info("Converting drug tokens tensor to numpy")
-        tokens_input = tokens_input.cpu().numpy()
     drug_token_str = drug_tokenizer.convert_ids_to_tokens(tokens_input)
+
     logger.info(f"Drug tokens: {drug_token_str}")
 
     # Extract subset of attention weights
     if hasattr(crossattention_weights, 'cpu'):
-
-        crossattention_weights = crossattention_weights.cpu().numpy()
+        crossattention_weights = crossattention_weights.cpu()
 
     subset = crossattention_weights[target_mask][:, drug_mask]
     logger.info(f"Subset shape: {subset.shape}")
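
The inline conversions in this diff boil down to one step: move a possibly-GPU tensor to the CPU, then treat it as a boolean NumPy mask before indexing. Below is a minimal sketch of that idea, assuming the masks arrive either as torch tensors or as NumPy arrays; to_bool_mask is a hypothetical helper written for illustration, not a function in analysis.py.

import numpy as np
import torch

def to_bool_mask(mask):
    # Hypothetical helper (illustration only): move a tensor mask to the CPU,
    # then view it as a boolean NumPy array so it can index other arrays.
    if hasattr(mask, "cpu"):   # torch.Tensor, possibly on a GPU device
        mask = mask.cpu()      # mirrors the .cpu() call added in this commit
    return np.asarray(mask).astype(bool)  # CPU tensors support the NumPy array protocol

# Example: an attention mask as a tokenizer might return it
print(to_bool_mask(torch.tensor([1, 1, 0, 1])))  # [ True  True False  True]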
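
The unchanged subset line relies on NumPy-style boolean indexing: rows are kept for unmasked target positions, then columns for unmasked drug positions. A small self-contained illustration of that indexing pattern, with toy shapes and values invented for the example:

import numpy as np

# Toy cross-attention matrix: 4 target positions x 3 drug positions
crossattention_weights = np.arange(12, dtype=float).reshape(4, 3)
target_mask = np.array([True, True, False, True])
drug_mask = np.array([True, False, True])

# Row-select by target_mask, then column-select by drug_mask,
# as in `crossattention_weights[target_mask][:, drug_mask]`
subset = crossattention_weights[target_mask][:, drug_mask]
print(subset.shape)  # (3, 2) -- unmasked target rows x unmasked drug columns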