import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder
import pickle
from typing import Dict, List, Optional
import warnings
import random
# ========================= Step 3: Inter-Task Attention model =========================
class InterTaskAttention(nn.Module):
    """
    Inter-Task Attention mechanism.
    Learns the mutual relationships and dependencies between tasks.
    """
    def __init__(self, hidden_dim: int, num_tasks: int, num_heads: int = 4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_tasks = num_tasks
        self.num_heads = num_heads

        # Multi-head attention across task embeddings
        self.task_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=0.1,
            batch_first=True
        )

        # Task-specific query, key, value projections
        self.task_query = nn.Linear(hidden_dim, hidden_dim)
        self.task_key = nn.Linear(hidden_dim, hidden_dim)
        self.task_value = nn.Linear(hidden_dim, hidden_dim)

        # Layer normalization
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim * 2, hidden_dim)
        )

    def forward(self, task_features):
        """
        Args:
            task_features: [batch_size, num_tasks, hidden_dim]
        Returns:
            refined_features: [batch_size, num_tasks, hidden_dim]
            attention_weights: [batch_size, num_tasks, num_tasks]
        """
        # Self-attention across tasks
        q = self.task_query(task_features)
        k = self.task_key(task_features)
        v = self.task_value(task_features)
        attended_features, attention_weights = self.task_attention(q, k, v)

        # Residual connection + normalization
        task_features = self.norm1(task_features + attended_features)

        # Feed-forward network with residual connection
        ffn_output = self.ffn(task_features)
        task_features = self.norm2(task_features + ffn_output)

        return task_features, attention_weights
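
# --- Minimal sanity-check sketch (not part of the original script) ---
# Runs InterTaskAttention on random task embeddings to illustrate the expected
# tensor shapes. The batch size, task count, and hidden_dim used here are
# arbitrary assumptions for demonstration only.
if __name__ == "__main__":
    _attn = InterTaskAttention(hidden_dim=256, num_tasks=3, num_heads=4)
    _dummy_tasks = torch.randn(2, 3, 256)  # [batch_size, num_tasks, hidden_dim]
    _refined, _weights = _attn(_dummy_tasks)
    # Expected: torch.Size([2, 3, 256]) torch.Size([2, 3, 3])
    print(_refined.shape, _weights.shape)
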
class OmniPathWithInterTaskAttention(nn.Module):
    """
    OmniPath model + Inter-Task Attention.
    Multi-task learning from pre-extracted tile features.
    """
    def __init__(self,
                 label_mappings: Dict,
                 feature_dim: int = 1024,
                 hidden_dim: int = 256,
                 dropout: float = 0.3,
                 use_inter_task_attention: bool = True,
                 inter_task_heads: int = 4):
        super().__init__()
        self.label_mappings = label_mappings
        self.num_tasks = len(label_mappings)
        self.use_inter_task_attention = use_inter_task_attention

        # Tile-level feature encoder
        self.tile_encoder = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # Tile attention (for aggregating tiles to patient level)
        self.tile_attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.Tanh(),
            nn.Linear(hidden_dim // 2, 1)
        )

        # Task-specific encoders (before inter-task attention)
        self.task_encoders = nn.ModuleDict()
        for task_name in label_mappings.keys():
            self.task_encoders[task_name] = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout)
            )

        # Inter-Task Attention
        if use_inter_task_attention:
            self.inter_task_attention = InterTaskAttention(
                hidden_dim=hidden_dim,
                num_tasks=self.num_tasks,
                num_heads=inter_task_heads
            )

        # Task-specific prediction heads (after inter-task attention)
        self.task_heads = nn.ModuleDict()
        for task_name, mapping in label_mappings.items():
            self.task_heads[task_name] = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim // 2, mapping['num_classes'])
            )

    def forward(self, features, return_attention=False):
        """
        Args:
            features: [batch_size, num_tiles, feature_dim]
            return_attention: if True, also return tile and inter-task attention weights
        Returns:
            outputs: dict of task predictions (and optionally attention weights)
        """
        batch_size, num_tiles, _ = features.shape

        # Encode tile features
        tile_features = self.tile_encoder(features)  # [B, N, H]

        # Compute tile attention weights
        attention_logits = self.tile_attention(tile_features)  # [B, N, 1]
        attention_weights = F.softmax(attention_logits, dim=1)  # [B, N, 1]

        # Aggregate tiles to patient level (attention-weighted sum)
        patient_features = torch.sum(tile_features * attention_weights, dim=1)  # [B, H]

        # Task-specific encoding
        task_features_list = []
        task_names = list(self.label_mappings.keys())
        for task_name in task_names:
            task_feat = self.task_encoders[task_name](patient_features)  # [B, H]
            task_features_list.append(task_feat.unsqueeze(1))  # [B, 1, H]
        task_features = torch.cat(task_features_list, dim=1)  # [B, num_tasks, H]

        # Inter-Task Attention
        inter_task_attn_weights = None
        if self.use_inter_task_attention:
            task_features, inter_task_attn_weights = self.inter_task_attention(task_features)

        # Task-specific predictions
        outputs = {}
        for i, task_name in enumerate(task_names):
            task_feat = task_features[:, i, :]  # [B, H]
            outputs[task_name] = self.task_heads[task_name](task_feat)  # [B, num_classes]

        if return_attention:
            outputs['tile_attention'] = attention_weights
            outputs['inter_task_attention'] = inter_task_attn_weights

        return outputs
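
# --- Minimal usage sketch (not part of the original script) ---
# Instantiates the full model on random tile features to illustrate the expected
# input/output shapes. The label_mappings below (task names and class counts) and
# the tensor sizes are illustrative assumptions, not values from the original
# pipeline, which builds label_mappings from its own dataset.
if __name__ == "__main__":
    dummy_label_mappings = {
        'tumor_grade': {'num_classes': 3},   # hypothetical task
        'msi_status': {'num_classes': 2},    # hypothetical task
    }
    model = OmniPathWithInterTaskAttention(
        label_mappings=dummy_label_mappings,
        feature_dim=1024,
        hidden_dim=256,
    )
    dummy_features = torch.randn(2, 50, 1024)  # [batch_size, num_tiles, feature_dim]
    outputs = model(dummy_features, return_attention=True)
    for name, value in outputs.items():
        if value is not None:
            # Task heads: [2, num_classes]; tile_attention: [2, 50, 1];
            # inter_task_attention: [2, num_tasks, num_tasks]
            print(name, tuple(value.shape))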