import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder
import pickle
from typing import Dict, List, Optional
import warnings
import random
# ========================= Step 3: Inter-Task Attention model =========================
class InterTaskAttention(nn.Module):
    """
    Inter-Task Attention mechanism.
    Learns the mutual relationships and dependencies between tasks.
    """
    def __init__(self, hidden_dim: int, num_tasks: int, num_heads: int = 4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_tasks = num_tasks
        self.num_heads = num_heads
        # Multi-head attention for tasks
        self.task_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=0.1,
            batch_first=True
        )
        # Task-specific query, key, value projections
        self.task_query = nn.Linear(hidden_dim, hidden_dim)
        self.task_key = nn.Linear(hidden_dim, hidden_dim)
        self.task_value = nn.Linear(hidden_dim, hidden_dim)
        # Layer normalization
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim * 2, hidden_dim)
        )

    def forward(self, task_features):
        """
        Args:
            task_features: [batch_size, num_tasks, hidden_dim]
        Returns:
            task_features: refined features, [batch_size, num_tasks, hidden_dim]
            attention_weights: task-to-task attention map, [batch_size, num_tasks, num_tasks]
        """
        # Self-attention across tasks
        q = self.task_query(task_features)
        k = self.task_key(task_features)
        v = self.task_value(task_features)
        attended_features, attention_weights = self.task_attention(q, k, v)
        # Residual connection + normalization
        task_features = self.norm1(task_features + attended_features)
        # Feed-forward network
        ffn_output = self.ffn(task_features)
        task_features = self.norm2(task_features + ffn_output)
        return task_features, attention_weights
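
# Minimal usage sketch for InterTaskAttention (not part of the original upload): the batch
# size, number of tasks, and hidden dimension below are illustrative assumptions only.
def _demo_inter_task_attention():
    attn = InterTaskAttention(hidden_dim=256, num_tasks=5, num_heads=4)
    task_feats = torch.randn(8, 5, 256)   # [batch_size, num_tasks, hidden_dim]
    refined, weights = attn(task_feats)
    print(refined.shape)                  # torch.Size([8, 5, 256])
    print(weights.shape)                  # torch.Size([8, 5, 5]), task-to-task attention
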
class OmniPathWithInterTaskAttention(nn.Module):
    """
    OmniPath model + Inter-Task Attention.
    Performs multi-task learning from pre-extracted tile features.
    """
    def __init__(self,
                 label_mappings: Dict,
                 feature_dim: int = 1024,
                 hidden_dim: int = 256,
                 dropout: float = 0.3,
                 use_inter_task_attention: bool = True,
                 inter_task_heads: int = 4):
        super().__init__()
        self.label_mappings = label_mappings
        self.num_tasks = len(label_mappings)
        self.use_inter_task_attention = use_inter_task_attention
        # Tile-level feature encoder
        self.tile_encoder = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        # Tile attention (for aggregating tiles to patient-level)
        self.tile_attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.Tanh(),
            nn.Linear(hidden_dim // 2, 1)
        )
        # Task-specific encoders (before inter-task attention)
        self.task_encoders = nn.ModuleDict()
        for task_name in label_mappings.keys():
            self.task_encoders[task_name] = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout)
            )
        # Inter-Task Attention
        if use_inter_task_attention:
            self.inter_task_attention = InterTaskAttention(
                hidden_dim=hidden_dim,
                num_tasks=self.num_tasks,
                num_heads=inter_task_heads
            )
        # Task-specific prediction heads (after inter-task attention)
        self.task_heads = nn.ModuleDict()
        for task_name, mapping in label_mappings.items():
            self.task_heads[task_name] = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim // 2, mapping['num_classes'])
            )
    def forward(self, features, return_attention=False):
        """
        Args:
            features: [batch_size, num_tiles, feature_dim]
        Returns:
            outputs: dict of per-task logits; if return_attention is True, also contains
                'tile_attention' and 'inter_task_attention'
        """
        batch_size, num_tiles, _ = features.shape
        # Encode tile features
        tile_features = self.tile_encoder(features)  # [B, N, H]
        # Compute tile attention weights
        attention_logits = self.tile_attention(tile_features)  # [B, N, 1]
        attention_weights = F.softmax(attention_logits, dim=1)  # [B, N, 1]
        # Aggregate tiles to patient-level
        patient_features = torch.sum(tile_features * attention_weights, dim=1)  # [B, H]
        # Task-specific encoding
        task_features_list = []
        task_names = list(self.label_mappings.keys())
        for task_name in task_names:
            task_feat = self.task_encoders[task_name](patient_features)  # [B, H]
            task_features_list.append(task_feat.unsqueeze(1))  # [B, 1, H]
        task_features = torch.cat(task_features_list, dim=1)  # [B, num_tasks, H]
        # Inter-Task Attention
        inter_task_attn_weights = None
        if self.use_inter_task_attention:
            task_features, inter_task_attn_weights = self.inter_task_attention(task_features)
        # Task-specific predictions
        outputs = {}
        for i, task_name in enumerate(task_names):
            task_feat = task_features[:, i, :]  # [B, H]
            outputs[task_name] = self.task_heads[task_name](task_feat)  # [B, num_classes]
        if return_attention:
            outputs['tile_attention'] = attention_weights
            outputs['inter_task_attention'] = inter_task_attn_weights
        return outputs
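
if __name__ == "__main__":
    # Smoke-test sketch (not from the original upload): the label_mappings layout
    # ({task: {'num_classes': ...}}), task names, and tensor sizes are assumptions
    # inferred from how the model reads them above.
    label_mappings = {
        "task_a": {"num_classes": 2},
        "task_b": {"num_classes": 4},
        "task_c": {"num_classes": 3},
    }
    model = OmniPathWithInterTaskAttention(label_mappings, feature_dim=1024, hidden_dim=256)
    dummy_features = torch.randn(2, 50, 1024)   # [batch_size, num_tiles, feature_dim]
    outputs = model(dummy_features, return_attention=True)
    for task_name in label_mappings:
        print(task_name, tuple(outputs[task_name].shape))                # (2, num_classes)
    print("tile_attention", tuple(outputs["tile_attention"].shape))      # (2, 50, 1)
    print("inter_task_attention", tuple(outputs["inter_task_attention"].shape))  # (2, 3, 3)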