import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder
import pickle
from typing import Dict, List, Optional
import warnings
import random


class InterTaskAttention(nn.Module):
    """
    Inter-task attention mechanism.

    Learns the relationships and dependencies between tasks by letting each
    task's feature vector attend to the feature vectors of all other tasks.
    """

    def __init__(self, hidden_dim: int, num_tasks: int, num_heads: int = 4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_tasks = num_tasks
        self.num_heads = num_heads

        # Multi-head attention applied over the task dimension
        self.task_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=0.1,
            batch_first=True
        )

        # Learned query/key/value projections for the task features
        self.task_query = nn.Linear(hidden_dim, hidden_dim)
        self.task_key = nn.Linear(hidden_dim, hidden_dim)
        self.task_value = nn.Linear(hidden_dim, hidden_dim)

        # Layer norms for the two residual connections (post-norm transformer block)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

        # Position-wise feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim * 2, hidden_dim)
        )

    def forward(self, task_features):
        """
        Args:
            task_features: [batch_size, num_tasks, hidden_dim]
        Returns:
            refined_features: [batch_size, num_tasks, hidden_dim]
            attention_weights: [batch_size, num_tasks, num_tasks]
        """
        # Project the task features into query/key/value spaces
        q = self.task_query(task_features)
        k = self.task_key(task_features)
        v = self.task_value(task_features)

        # Each task attends to every other task
        attended_features, attention_weights = self.task_attention(q, k, v)

        # Residual connection + layer norm
        task_features = self.norm1(task_features + attended_features)

        # Feed-forward block with a second residual connection
        ffn_output = self.ffn(task_features)
        task_features = self.norm2(task_features + ffn_output)

        return task_features, attention_weights
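
# Added illustrative helper (not part of the original module): a minimal shape
# check showing how InterTaskAttention is meant to be called. The batch size,
# task count, and hidden size below are assumed values chosen for the example.
def _inter_task_attention_shape_check():
    attn = InterTaskAttention(hidden_dim=256, num_tasks=5, num_heads=4)
    dummy = torch.randn(2, 5, 256)           # [batch_size, num_tasks, hidden_dim]
    refined, weights = attn(dummy)
    assert refined.shape == (2, 5, 256)      # refined per-task features
    assert weights.shape == (2, 5, 5)        # task-to-task attention map
    return refined, weights
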
class OmniPathWithInterTaskAttention(nn.Module):
    """
    OmniPath model with inter-task attention.

    Multi-task learning from pre-extracted tile features: tiles are encoded and
    attention-pooled into a patient-level representation, per-task encoders
    produce one feature vector per task, and an optional InterTaskAttention
    block refines those vectors before the classification heads.
    """

    def __init__(self,
                 label_mappings: Dict,
                 feature_dim: int = 1024,
                 hidden_dim: int = 256,
                 dropout: float = 0.3,
                 use_inter_task_attention: bool = True,
                 inter_task_heads: int = 4):
        super().__init__()

        self.label_mappings = label_mappings
        self.num_tasks = len(label_mappings)
        self.use_inter_task_attention = use_inter_task_attention

        # Project pre-extracted tile features into the shared hidden space
        self.tile_encoder = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # Tanh attention scoring for tile aggregation (attention-based MIL pooling)
        self.tile_attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.Tanh(),
            nn.Linear(hidden_dim // 2, 1)
        )

        # One encoder per task, mapping the patient feature to a task-specific space
        self.task_encoders = nn.ModuleDict()
        for task_name in label_mappings.keys():
            self.task_encoders[task_name] = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout)
            )

        # Optional inter-task attention to share information across tasks
        if use_inter_task_attention:
            self.inter_task_attention = InterTaskAttention(
                hidden_dim=hidden_dim,
                num_tasks=self.num_tasks,
                num_heads=inter_task_heads
            )

        # One classification head per task
        self.task_heads = nn.ModuleDict()
        for task_name, mapping in label_mappings.items():
            self.task_heads[task_name] = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim // 2, mapping['num_classes'])
            )

    def forward(self, features, return_attention=False):
        """
        Args:
            features: [batch_size, num_tiles, feature_dim]
            return_attention: if True, also return tile and inter-task attention weights
        Returns:
            outputs: dict mapping each task name to its logits
        """
        batch_size, num_tiles, _ = features.shape

        # Encode tiles into the shared hidden space
        tile_features = self.tile_encoder(features)

        # Attention-weighted pooling of tiles into a patient-level feature
        attention_logits = self.tile_attention(tile_features)
        attention_weights = F.softmax(attention_logits, dim=1)
        patient_features = torch.sum(tile_features * attention_weights, dim=1)

        # Build one feature vector per task: [batch_size, num_tasks, hidden_dim]
        task_features_list = []
        task_names = list(self.label_mappings.keys())
        for task_name in task_names:
            task_feat = self.task_encoders[task_name](patient_features)
            task_features_list.append(task_feat.unsqueeze(1))
        task_features = torch.cat(task_features_list, dim=1)

        # Refine the task features with inter-task attention, if enabled
        inter_task_attn_weights = None
        if self.use_inter_task_attention:
            task_features, inter_task_attn_weights = self.inter_task_attention(task_features)

        # Apply the per-task classification heads
        outputs = {}
        for i, task_name in enumerate(task_names):
            task_feat = task_features[:, i, :]
            outputs[task_name] = self.task_heads[task_name](task_feat)

        if return_attention:
            outputs['tile_attention'] = attention_weights
            outputs['inter_task_attention'] = inter_task_attn_weights

        return outputs
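

# The block below is an added usage sketch, not part of the original training
# pipeline: the task names ('tumor_grade', 'msi_status'), class counts, tile
# count, and feature dimension are illustrative assumptions only.
if __name__ == "__main__":
    example_label_mappings = {
        'tumor_grade': {'num_classes': 3},   # hypothetical 3-class task
        'msi_status': {'num_classes': 2},    # hypothetical binary task
    }

    model = OmniPathWithInterTaskAttention(
        label_mappings=example_label_mappings,
        feature_dim=1024,
        hidden_dim=256,
        use_inter_task_attention=True,
    )

    # Dummy bag of pre-extracted tile features: [batch_size, num_tiles, feature_dim]
    dummy_features = torch.randn(2, 50, 1024)
    outputs = model(dummy_features, return_attention=True)

    # Expected shapes: per-task logits [2, num_classes], tile attention [2, 50, 1],
    # inter-task attention [2, num_tasks, num_tasks]
    for name, value in outputs.items():
        shape = tuple(value.shape) if value is not None else None
        print(f"{name}: {shape}")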