TETSU0701 committed on
Commit 59f7675 · verified · 1 Parent(s): af53027

Upload Model.py

Files changed (1)
  1. Model.py +184 -0
Model.py ADDED
@@ -0,0 +1,184 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder
import pickle
from typing import Dict, List, Optional
import warnings
import random

# ========================= Step 3: Inter-Task Attention model =========================

class InterTaskAttention(nn.Module):
    """
    Inter-Task Attention mechanism.
    Learns the relationships and dependencies between tasks.
    """

    def __init__(self, hidden_dim: int, num_tasks: int, num_heads: int = 4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_tasks = num_tasks
        self.num_heads = num_heads

        # Multi-head attention across tasks
        self.task_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=0.1,
            batch_first=True
        )

        # Task-specific query, key, value projections
        self.task_query = nn.Linear(hidden_dim, hidden_dim)
        self.task_key = nn.Linear(hidden_dim, hidden_dim)
        self.task_value = nn.Linear(hidden_dim, hidden_dim)

        # Layer normalization
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim * 2, hidden_dim)
        )

    def forward(self, task_features):
        """
        Args:
            task_features: [batch_size, num_tasks, hidden_dim]
        Returns:
            refined_features: [batch_size, num_tasks, hidden_dim]
            attention_weights: [batch_size, num_tasks, num_tasks]
        """
        # Self-attention across tasks
        q = self.task_query(task_features)
        k = self.task_key(task_features)
        v = self.task_value(task_features)

        attended_features, attention_weights = self.task_attention(q, k, v)

        # Residual connection + normalization
        task_features = self.norm1(task_features + attended_features)

        # Feed-forward network
        ffn_output = self.ffn(task_features)
        task_features = self.norm2(task_features + ffn_output)

        return task_features, attention_weights


class OmniPathWithInterTaskAttention(nn.Module):
    """
    OmniPath model + Inter-Task Attention.
    Multi-task learning from pre-extracted features.
    """

    def __init__(self,
                 label_mappings: Dict,
                 feature_dim: int = 1024,
                 hidden_dim: int = 256,
                 dropout: float = 0.3,
                 use_inter_task_attention: bool = True,
                 inter_task_heads: int = 4):
        super().__init__()

        self.label_mappings = label_mappings
        self.num_tasks = len(label_mappings)
        self.use_inter_task_attention = use_inter_task_attention

        # Tile-level feature encoder
        self.tile_encoder = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # Tile attention (for aggregating tiles to patient level)
        self.tile_attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.Tanh(),
            nn.Linear(hidden_dim // 2, 1)
        )

        # Task-specific encoders (before inter-task attention)
        self.task_encoders = nn.ModuleDict()
        for task_name in label_mappings.keys():
            self.task_encoders[task_name] = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout)
            )

        # Inter-Task Attention
        if use_inter_task_attention:
            self.inter_task_attention = InterTaskAttention(
                hidden_dim=hidden_dim,
                num_tasks=self.num_tasks,
                num_heads=inter_task_heads
            )

        # Task-specific prediction heads (after inter-task attention)
        self.task_heads = nn.ModuleDict()
        for task_name, mapping in label_mappings.items():
            self.task_heads[task_name] = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim // 2, mapping['num_classes'])
            )

    def forward(self, features, return_attention=False):
        """
        Args:
            features: [batch_size, num_tiles, feature_dim]
        Returns:
            outputs: dict of task predictions (plus attention maps if return_attention=True)
        """
        batch_size, num_tiles, _ = features.shape

        # Encode tile features
        tile_features = self.tile_encoder(features)  # [B, N, H]

        # Compute tile attention weights
        attention_logits = self.tile_attention(tile_features)  # [B, N, 1]
        attention_weights = F.softmax(attention_logits, dim=1)  # [B, N, 1]

        # Aggregate tiles to patient level
        patient_features = torch.sum(tile_features * attention_weights, dim=1)  # [B, H]

        # Task-specific encoding
        task_features_list = []
        task_names = list(self.label_mappings.keys())

        for task_name in task_names:
            task_feat = self.task_encoders[task_name](patient_features)  # [B, H]
            task_features_list.append(task_feat.unsqueeze(1))  # [B, 1, H]

        task_features = torch.cat(task_features_list, dim=1)  # [B, num_tasks, H]

        # Inter-Task Attention
        inter_task_attn_weights = None
        if self.use_inter_task_attention:
            task_features, inter_task_attn_weights = self.inter_task_attention(task_features)

        # Task-specific predictions
        outputs = {}
        for i, task_name in enumerate(task_names):
            task_feat = task_features[:, i, :]  # [B, H]
            outputs[task_name] = self.task_heads[task_name](task_feat)  # [B, num_classes]

        if return_attention:
            outputs['tile_attention'] = attention_weights
            outputs['inter_task_attention'] = inter_task_attn_weights

        return outputs
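
A minimal smoke-test sketch of how the uploaded model can be exercised. It is not part of the committed file: the task names and class counts below are hypothetical, and only the structure the code actually reads is assumed (a label_mappings dict whose values carry a 'num_classes' key, plus the 1024-dim default feature_dim).

# Hypothetical usage sketch (not in the original upload); task names and
# class counts are made up for illustration.
if __name__ == "__main__":
    label_mappings = {
        'tumor_grade': {'num_classes': 3},  # hypothetical task
        'msi_status': {'num_classes': 2},   # hypothetical task
    }

    model = OmniPathWithInterTaskAttention(
        label_mappings=label_mappings,
        feature_dim=1024,
        hidden_dim=256,
    )

    # Batch of 2 patients, 50 pre-extracted 1024-dim tile features each
    features = torch.randn(2, 50, 1024)
    outputs = model(features, return_attention=True)

    for task_name in label_mappings:
        print(task_name, outputs[task_name].shape)                        # [2, num_classes]
    print('tile_attention', outputs['tile_attention'].shape)              # [2, 50, 1]
    print('inter_task_attention', outputs['inter_task_attention'].shape)  # [2, num_tasks, num_tasks]

The inter-task attention weights come out as [batch, num_tasks, num_tasks] because nn.MultiheadAttention averages over heads by default, so each row shows how much one task head borrows from the others.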