import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder
import pickle
from typing import Dict, List, Optional
import warnings
import random

# ========================= Step 3: Inter-Task Attention model =========================

class InterTaskAttention(nn.Module):
    """
    Inter-Task Attention mechanism.
    Learns the relationships and dependencies between tasks.
    """
    def __init__(self, hidden_dim: int, num_tasks: int, num_heads: int = 4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_tasks = num_tasks
        self.num_heads = num_heads

        # Multi-head attention across task embeddings
        self.task_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=0.1,
            batch_first=True
        )

        # Task-specific query, key, value projections
        self.task_query = nn.Linear(hidden_dim, hidden_dim)
        self.task_key = nn.Linear(hidden_dim, hidden_dim)
        self.task_value = nn.Linear(hidden_dim, hidden_dim)

        # Layer normalization
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim * 2, hidden_dim)
        )

    def forward(self, task_features):
        """
        Args:
            task_features: [batch_size, num_tasks, hidden_dim]
        Returns:
            refined_features: [batch_size, num_tasks, hidden_dim]
            attention_weights: [batch_size, num_tasks, num_tasks]
        """
        # Self-attention across tasks
        q = self.task_query(task_features)
        k = self.task_key(task_features)
        v = self.task_value(task_features)

        attended_features, attention_weights = self.task_attention(q, k, v)

        # Residual connection + normalization
        task_features = self.norm1(task_features + attended_features)

        # Feed-forward network with second residual connection
        ffn_output = self.ffn(task_features)
        task_features = self.norm2(task_features + ffn_output)

        return task_features, attention_weights


class OmniPathWithInterTaskAttention(nn.Module):
    """
    OmniPath model + Inter-Task Attention.
    Multi-task learning from pre-extracted tile features.
    """
    def __init__(self, label_mappings: Dict, feature_dim: int = 1024,
                 hidden_dim: int = 256, dropout: float = 0.3,
                 use_inter_task_attention: bool = True,
                 inter_task_heads: int = 4):
        super().__init__()
        self.label_mappings = label_mappings
        self.num_tasks = len(label_mappings)
        self.use_inter_task_attention = use_inter_task_attention

        # Tile-level feature encoder
        self.tile_encoder = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # Tile attention (for aggregating tiles to patient-level)
        self.tile_attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.Tanh(),
            nn.Linear(hidden_dim // 2, 1)
        )

        # Task-specific encoders (before inter-task attention)
        self.task_encoders = nn.ModuleDict()
        for task_name in label_mappings.keys():
            self.task_encoders[task_name] = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout)
            )

        # Inter-Task Attention
        if use_inter_task_attention:
            self.inter_task_attention = InterTaskAttention(
                hidden_dim=hidden_dim,
                num_tasks=self.num_tasks,
                num_heads=inter_task_heads
            )

        # Task-specific prediction heads (after inter-task attention)
        self.task_heads = nn.ModuleDict()
        for task_name, mapping in label_mappings.items():
            self.task_heads[task_name] = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim // 2, mapping['num_classes'])
            )

    def forward(self, features, return_attention=False):
        """
        Args:
            features: [batch_size, num_tiles, feature_dim]
        Returns:
            outputs: dict of task predictions
        """
        batch_size, num_tiles, _ = features.shape

        # Encode tile features
        tile_features = self.tile_encoder(features)  # [B, N, H]

        # Compute tile attention weights
        attention_logits = self.tile_attention(tile_features)  # [B, N, 1]
        attention_weights = F.softmax(attention_logits, dim=1)  # [B, N, 1]

        # Aggregate tiles to patient-level
        patient_features = torch.sum(tile_features * attention_weights, dim=1)  # [B, H]

        # Task-specific encoding
        task_features_list = []
        task_names = list(self.label_mappings.keys())
        for task_name in task_names:
            task_feat = self.task_encoders[task_name](patient_features)  # [B, H]
            task_features_list.append(task_feat.unsqueeze(1))  # [B, 1, H]
        task_features = torch.cat(task_features_list, dim=1)  # [B, num_tasks, H]

        # Inter-Task Attention
        inter_task_attn_weights = None
        if self.use_inter_task_attention:
            task_features, inter_task_attn_weights = self.inter_task_attention(task_features)

        # Task-specific predictions
        outputs = {}
        for i, task_name in enumerate(task_names):
            task_feat = task_features[:, i, :]  # [B, H]
            outputs[task_name] = self.task_heads[task_name](task_feat)  # [B, num_classes]

        if return_attention:
            outputs['tile_attention'] = attention_weights
            outputs['inter_task_attention'] = inter_task_attn_weights

        return outputs
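

# --------------------------------------------------------------------------
# Minimal usage sketch (not part of the original pipeline): it builds the
# model with a hypothetical two-task `label_mappings` dict and runs a forward
# pass on random features. The mapping format (one 'num_classes' entry per
# task) is inferred from how the prediction heads read mapping['num_classes']
# above; the task names and tensor sizes below are illustrative assumptions.
# --------------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical label mappings: one entry per task with its class count
    example_label_mappings = {
        'tumor_grade': {'num_classes': 3},
        'msi_status': {'num_classes': 2},
    }

    model = OmniPathWithInterTaskAttention(
        label_mappings=example_label_mappings,
        feature_dim=1024,   # matches the default pre-extracted feature size
        hidden_dim=256,
        use_inter_task_attention=True,
    )
    model.eval()

    # Dummy batch: 2 patients, 50 tiles each, 1024-dim pre-extracted features
    dummy_features = torch.randn(2, 50, 1024)
    with torch.no_grad():
        outputs = model(dummy_features, return_attention=True)

    for task_name in example_label_mappings:
        print(task_name, outputs[task_name].shape)                    # [2, num_classes]
    print('tile_attention', outputs['tile_attention'].shape)          # [2, 50, 1]
    print('inter_task_attention', outputs['inter_task_attention'].shape)  # [2, 2, 2]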