Z committed on
Commit a5cdd29 · verified · 1 parent: 9b2fb41

Upload 3523 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +7 -0
  2. SPC-UQ/.idea/.gitignore +8 -0
  3. SPC-UQ/.idea/UQ_baseline.iml +12 -0
  4. SPC-UQ/.idea/inspectionProfiles/Project_Default.xml +24 -0
  5. SPC-UQ/.idea/inspectionProfiles/profiles_settings.xml +6 -0
  6. SPC-UQ/.idea/misc.xml +7 -0
  7. SPC-UQ/.idea/modules.xml +8 -0
  8. SPC-UQ/.idea/other.xml +6 -0
  9. SPC-UQ/.idea/workspace.xml +247 -0
  10. SPC-UQ/Cubic_Regression/ConformalRegression.py +76 -0
  11. SPC-UQ/Cubic_Regression/DeepEnsembleRegression.py +93 -0
  12. SPC-UQ/Cubic_Regression/EDLQuantileRegression.py +155 -0
  13. SPC-UQ/Cubic_Regression/EDLRegression.py +141 -0
  14. SPC-UQ/Cubic_Regression/QROC.py +125 -0
  15. SPC-UQ/Cubic_Regression/SPCRegression.py +173 -0
  16. SPC-UQ/Cubic_Regression/__pycache__/ConformalRegression.cpython-37.pyc +0 -0
  17. SPC-UQ/Cubic_Regression/__pycache__/DeepEnsembleRegression.cpython-37.pyc +0 -0
  18. SPC-UQ/Cubic_Regression/__pycache__/EDLQuantileRegression.cpython-37.pyc +0 -0
  19. SPC-UQ/Cubic_Regression/__pycache__/EDLRegression.cpython-37.pyc +0 -0
  20. SPC-UQ/Cubic_Regression/__pycache__/QROC.cpython-37.pyc +0 -0
  21. SPC-UQ/Cubic_Regression/__pycache__/SPCRegression.cpython-37.pyc +0 -0
  22. SPC-UQ/Cubic_Regression/run_cubic_tests.py +335 -0
  23. SPC-UQ/Image_Classification/README.md +285 -0
  24. SPC-UQ/Image_Classification/data/__init__.py +0 -0
  25. SPC-UQ/Image_Classification/data/ood_detection/__init__.py +0 -0
  26. SPC-UQ/Image_Classification/data/ood_detection/cifar10.py +107 -0
  27. SPC-UQ/Image_Classification/data/ood_detection/cifar100.py +107 -0
  28. SPC-UQ/Image_Classification/data/ood_detection/imagenet.py +85 -0
  29. SPC-UQ/Image_Classification/data/ood_detection/imagenet_a.py +37 -0
  30. SPC-UQ/Image_Classification/data/ood_detection/imagenet_o.py +37 -0
  31. SPC-UQ/Image_Classification/data/ood_detection/ood_union.py +105 -0
  32. SPC-UQ/Image_Classification/data/ood_detection/svhn.py +94 -0
  33. SPC-UQ/Image_Classification/data/ood_detection/tinyimagenet.py +115 -0
  34. SPC-UQ/Image_Classification/environment.yml +16 -0
  35. SPC-UQ/Image_Classification/evaluate.py +1427 -0
  36. SPC-UQ/Image_Classification/evaluate_laplace.py +355 -0
  37. SPC-UQ/Image_Classification/metrics/__init__.py +0 -0
  38. SPC-UQ/Image_Classification/metrics/calibration_metrics.py +129 -0
  39. SPC-UQ/Image_Classification/metrics/classification_metrics.py +211 -0
  40. SPC-UQ/Image_Classification/metrics/ood_metrics.py +135 -0
  41. SPC-UQ/Image_Classification/metrics/uncertainty_confidence.py +67 -0
  42. SPC-UQ/Image_Classification/net/__init__.py +0 -0
  43. SPC-UQ/Image_Classification/net/imagenet_vgg.py +106 -0
  44. SPC-UQ/Image_Classification/net/imagenet_vit.py +101 -0
  45. SPC-UQ/Image_Classification/net/imagenet_wide.py +46 -0
  46. SPC-UQ/Image_Classification/net/lenet.py +37 -0
  47. SPC-UQ/Image_Classification/net/resnet.py +245 -0
  48. SPC-UQ/Image_Classification/net/resnet_edl.py +252 -0
  49. SPC-UQ/Image_Classification/net/resnet_uq.py +272 -0
  50. SPC-UQ/Image_Classification/net/spectral_normalization/__init__.py +0 -0
.gitattributes CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ SPC-UQ/MNIST_Classification/data/FashionMNIST/raw/t10k-images-idx3-ubyte filter=lfs diff=lfs merge=lfs -text
+ SPC-UQ/MNIST_Classification/data/FashionMNIST/raw/train-images-idx3-ubyte filter=lfs diff=lfs merge=lfs -text
+ SPC-UQ/MNIST_Classification/data/MNIST/raw/t10k-images-idx3-ubyte filter=lfs diff=lfs merge=lfs -text
+ SPC-UQ/MNIST_Classification/data/MNIST/raw/train-images-idx3-ubyte filter=lfs diff=lfs merge=lfs -text
+ SPC-UQ/UCI_Benchmarks/data/uci/concrete/Concrete_Data.xls filter=lfs diff=lfs merge=lfs -text
+ SPC-UQ/UCI_Benchmarks/data/uci/power-plant/Folds5x2_pp.ods filter=lfs diff=lfs merge=lfs -text
+ SPC-UQ/UCI_Benchmarks/data/uci/power-plant/Folds5x2_pp.xlsx filter=lfs diff=lfs merge=lfs -text
SPC-UQ/.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
+ # Default ignored files
+ /shelf/
+ /workspace.xml
+ # Editor-based HTTP Client requests
+ /httpRequests/
+ # Datasource local storage ignored files
+ /dataSources/
+ /dataSources.local.xml
SPC-UQ/.idea/UQ_baseline.iml ADDED
@@ -0,0 +1,12 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <module type="PYTHON_MODULE" version="4">
+   <component name="NewModuleRootManager">
+     <content url="file://$MODULE_DIR$" />
+     <orderEntry type="jdk" jdkName="py37" jdkType="Python SDK" />
+     <orderEntry type="sourceFolder" forTests="false" />
+   </component>
+   <component name="PyDocumentationSettings">
+     <option name="format" value="PLAIN" />
+     <option name="myDocStringFormat" value="Plain" />
+   </component>
+ </module>
SPC-UQ/.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,24 @@
+ <component name="InspectionProjectProfileManager">
+   <profile version="1.0">
+     <option name="myName" value="Project Default" />
+     <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
+       <option name="ignoredPackages">
+         <value>
+           <list size="11">
+             <item index="0" class="java.lang.String" itemvalue="tqdm" />
+             <item index="1" class="java.lang.String" itemvalue="scipy" />
+             <item index="2" class="java.lang.String" itemvalue="tabulate" />
+             <item index="3" class="java.lang.String" itemvalue="scikit_learn" />
+             <item index="4" class="java.lang.String" itemvalue="matplotlib" />
+             <item index="5" class="java.lang.String" itemvalue="gpytorch" />
+             <item index="6" class="java.lang.String" itemvalue="torch" />
+             <item index="7" class="java.lang.String" itemvalue="setuptools" />
+             <item index="8" class="java.lang.String" itemvalue="numpy" />
+             <item index="9" class="java.lang.String" itemvalue="torchvision" />
+             <item index="10" class="java.lang.String" itemvalue="Pillow" />
+           </list>
+         </value>
+       </option>
+     </inspection_tool>
+   </profile>
+ </component>
SPC-UQ/.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
+ <component name="InspectionProjectProfileManager">
+   <settings>
+     <option name="USE_PROJECT_PROFILE" value="false" />
+     <version value="1.0" />
+   </settings>
+ </component>
SPC-UQ/.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="Black">
+     <option name="sdkName" value="py37" />
+   </component>
+   <component name="ProjectRootManager" version="2" project-jdk-name="py37" project-jdk-type="Python SDK" />
+ </project>
SPC-UQ/.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ProjectModuleManager">
+     <modules>
+       <module fileurl="file://$PROJECT_DIR$/.idea/UQ_baseline.iml" filepath="$PROJECT_DIR$/.idea/UQ_baseline.iml" />
+     </modules>
+   </component>
+ </project>
SPC-UQ/.idea/other.xml ADDED
@@ -0,0 +1,6 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="PySciProjectComponent">
+     <option name="PY_INTERACTIVE_PLOTS_SUGGESTED" value="true" />
+   </component>
+ </project>
SPC-UQ/.idea/workspace.xml ADDED
@@ -0,0 +1,247 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="AutoImportSettings">
+     <option name="autoReloadType" value="SELECTIVE" />
+   </component>
+   <component name="ChangeListManager">
+     <list default="true" id="5a477d09-bea8-4806-81a6-0e58ccd074ce" name="Changes" comment="" />
+     <option name="SHOW_DIALOG" value="false" />
+     <option name="HIGHLIGHT_CONFLICTS" value="true" />
+     <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
+     <option name="LAST_RESOLUTION" value="IGNORE" />
+   </component>
+   <component name="FileTemplateManagerImpl">
+     <option name="RECENT_TEMPLATES">
+       <list>
+         <option value="Python Script" />
+       </list>
+     </option>
+   </component>
+   <component name="ProjectColorInfo">{
+   &quot;associatedIndex&quot;: 2
+ }</component>
+   <component name="ProjectId" id="30Vf7yOdfmRNkiAM7wBXYZUSVri" />
+   <component name="ProjectViewState">
+     <option name="hideEmptyMiddlePackages" value="true" />
+     <option name="showLibraryContents" value="true" />
+   </component>
+   <component name="PropertiesComponent">{
+   &quot;keyToString&quot;: {
+     &quot;Python.evidential.executor&quot;: &quot;Run&quot;,
+     &quot;Python.re.executor&quot;: &quot;Run&quot;,
+     &quot;Python.rename.executor&quot;: &quot;Run&quot;,
+     &quot;Python.run_cls_tests.executor&quot;: &quot;Run&quot;,
+     &quot;Python.run_cubic_tests.executor&quot;: &quot;Run&quot;,
+     &quot;Python.run_toy_tests.executor&quot;: &quot;Run&quot;,
+     &quot;Python.run_uci_dataset_tests (1).executor&quot;: &quot;Run&quot;,
+     &quot;Python.run_uci_dataset_tests (2).executor&quot;: &quot;Run&quot;,
+     &quot;Python.run_uci_dataset_tests.executor&quot;: &quot;Run&quot;,
+     &quot;RunOnceActivity.ShowReadmeOnStart&quot;: &quot;true&quot;,
+     &quot;last_opened_file_path&quot;: &quot;E:/Experiment/SPC-UQ/Depth_regression/trainers&quot;,
+     &quot;node.js.detected.package.eslint&quot;: &quot;true&quot;,
+     &quot;node.js.detected.package.tslint&quot;: &quot;true&quot;,
+     &quot;node.js.selected.package.eslint&quot;: &quot;(autodetect)&quot;,
+     &quot;node.js.selected.package.tslint&quot;: &quot;(autodetect)&quot;,
+     &quot;nodejs_package_manager_path&quot;: &quot;npm&quot;,
+     &quot;vue.rearranger.settings.migration&quot;: &quot;true&quot;
+   }
+ }</component>
+   <component name="RecentsManager">
+     <key name="CopyFile.RECENT_KEYS">
+       <recent name="E:\Experiment\SPC-UQ\Depth_regression\trainers" />
+     </key>
+     <key name="MoveFile.RECENT_KEYS">
+       <recent name="E:\Experiment\SPC-UQ\Depth_regression" />
+     </key>
+   </component>
+   <component name="RunManager" selected="Python.run_cls_tests">
+     <configuration name="run_cls_tests" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
+       <module name="UQ_baseline" />
+       <option name="ENV_FILES" value="" />
+       <option name="INTERPRETER_OPTIONS" value="" />
+       <option name="PARENT_ENVS" value="true" />
+       <envs>
+         <env name="PYTHONUNBUFFERED" value="1" />
+       </envs>
+       <option name="SDK_HOME" value="" />
+       <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/MNIST_Classification" />
+       <option name="IS_MODULE_SDK" value="true" />
+       <option name="ADD_CONTENT_ROOTS" value="true" />
+       <option name="ADD_SOURCE_ROOTS" value="true" />
+       <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
+       <option name="SCRIPT_NAME" value="$PROJECT_DIR$/MNIST_Classification/run_cls_tests.py" />
+       <option name="PARAMETERS" value="" />
+       <option name="SHOW_COMMAND_LINE" value="false" />
+       <option name="EMULATE_TERMINAL" value="false" />
+       <option name="MODULE_MODE" value="false" />
+       <option name="REDIRECT_INPUT" value="false" />
+       <option name="INPUT_FILE" value="" />
+       <method v="2" />
+     </configuration>
+     <configuration name="run_cubic_tests" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
+       <module name="UQ_baseline" />
+       <option name="ENV_FILES" value="" />
+       <option name="INTERPRETER_OPTIONS" value="" />
+       <option name="PARENT_ENVS" value="true" />
+       <envs>
+         <env name="PYTHONUNBUFFERED" value="1" />
+       </envs>
+       <option name="SDK_HOME" value="" />
+       <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/Cubic_Regression" />
+       <option name="IS_MODULE_SDK" value="true" />
+       <option name="ADD_CONTENT_ROOTS" value="true" />
+       <option name="ADD_SOURCE_ROOTS" value="true" />
+       <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
+       <option name="SCRIPT_NAME" value="$PROJECT_DIR$/Cubic_Regression/run_cubic_tests.py" />
+       <option name="PARAMETERS" value="" />
+       <option name="SHOW_COMMAND_LINE" value="false" />
+       <option name="EMULATE_TERMINAL" value="false" />
+       <option name="MODULE_MODE" value="false" />
+       <option name="REDIRECT_INPUT" value="false" />
+       <option name="INPUT_FILE" value="" />
+       <method v="2" />
+     </configuration>
+     <configuration name="run_toy_tests" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
+       <module name="UQ_baseline" />
+       <option name="ENV_FILES" value="" />
+       <option name="INTERPRETER_OPTIONS" value="" />
+       <option name="PARENT_ENVS" value="true" />
+       <envs>
+         <env name="PYTHONUNBUFFERED" value="1" />
+       </envs>
+       <option name="SDK_HOME" value="" />
+       <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/Toy_regression" />
+       <option name="IS_MODULE_SDK" value="true" />
+       <option name="ADD_CONTENT_ROOTS" value="true" />
+       <option name="ADD_SOURCE_ROOTS" value="true" />
+       <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
+       <option name="SCRIPT_NAME" value="$PROJECT_DIR$/Toy_regression/run_toy_tests.py" />
+       <option name="PARAMETERS" value="" />
+       <option name="SHOW_COMMAND_LINE" value="false" />
+       <option name="EMULATE_TERMINAL" value="false" />
+       <option name="MODULE_MODE" value="false" />
+       <option name="REDIRECT_INPUT" value="false" />
+       <option name="INPUT_FILE" value="" />
+       <method v="2" />
+     </configuration>
+     <configuration name="run_uci_dataset_tests (1)" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
+       <module name="UQ_baseline" />
+       <option name="ENV_FILES" value="" />
+       <option name="INTERPRETER_OPTIONS" value="" />
+       <option name="PARENT_ENVS" value="true" />
+       <envs>
+         <env name="PYTHONUNBUFFERED" value="1" />
+       </envs>
+       <option name="SDK_HOME" value="" />
+       <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/UCI_Benchmarks" />
+       <option name="IS_MODULE_SDK" value="true" />
+       <option name="ADD_CONTENT_ROOTS" value="true" />
+       <option name="ADD_SOURCE_ROOTS" value="true" />
+       <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
+       <option name="SCRIPT_NAME" value="$PROJECT_DIR$/UCI_Benchmarks/run_uci_dataset_tests.py" />
+       <option name="PARAMETERS" value="" />
+       <option name="SHOW_COMMAND_LINE" value="false" />
+       <option name="EMULATE_TERMINAL" value="false" />
+       <option name="MODULE_MODE" value="false" />
+       <option name="REDIRECT_INPUT" value="false" />
+       <option name="INPUT_FILE" value="" />
+       <method v="2" />
+     </configuration>
+     <configuration name="run_uci_dataset_tests (2)" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
+       <module name="UQ_baseline" />
+       <option name="ENV_FILES" value="" />
+       <option name="INTERPRETER_OPTIONS" value="" />
+       <option name="PARENT_ENVS" value="true" />
+       <envs>
+         <env name="PYTHONUNBUFFERED" value="1" />
+       </envs>
+       <option name="SDK_HOME" value="" />
+       <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/../UCI_Benchmarks" />
+       <option name="IS_MODULE_SDK" value="false" />
+       <option name="ADD_CONTENT_ROOTS" value="true" />
+       <option name="ADD_SOURCE_ROOTS" value="true" />
+       <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
+       <option name="SCRIPT_NAME" value="$PROJECT_DIR$/../UCI_Benchmarks/run_uci_dataset_tests.py" />
+       <option name="PARAMETERS" value="" />
+       <option name="SHOW_COMMAND_LINE" value="false" />
+       <option name="EMULATE_TERMINAL" value="false" />
+       <option name="MODULE_MODE" value="false" />
+       <option name="REDIRECT_INPUT" value="false" />
+       <option name="INPUT_FILE" value="" />
+       <method v="2" />
+     </configuration>
+     <recent_temporary>
+       <list>
+         <item itemvalue="Python.run_cls_tests" />
+         <item itemvalue="Python.run_cubic_tests" />
+         <item itemvalue="Python.run_uci_dataset_tests (2)" />
+         <item itemvalue="Python.run_uci_dataset_tests (1)" />
+         <item itemvalue="Python.run_toy_tests" />
+       </list>
+     </recent_temporary>
+   </component>
+   <component name="SharedIndexes">
+     <attachedChunks>
+       <set>
+         <option value="bundled-js-predefined-1d06a55b98c1-91d5c284f522-JavaScript-PY-241.15989.155" />
+         <option value="bundled-python-sdk-babbdf50b680-7c6932dee5e4-com.jetbrains.pycharm.pro.sharedIndexes.bundled-PY-241.15989.155" />
+       </set>
+     </attachedChunks>
+   </component>
+   <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
+   <component name="TaskManager">
+     <task active="true" id="Default" summary="Default task">
+       <changelist id="5a477d09-bea8-4806-81a6-0e58ccd074ce" name="Changes" comment="" />
+       <created>1753717498209</created>
+       <option name="number" value="Default" />
+       <option name="presentableId" value="Default" />
+       <updated>1753717498209</updated>
+       <workItem from="1753717499301" duration="61000" />
+       <workItem from="1753717566735" duration="10589000" />
+       <workItem from="1753789230404" duration="17167000" />
+       <workItem from="1753911784054" duration="686000" />
+       <workItem from="1753975576642" duration="9813000" />
+       <workItem from="1754180731050" duration="28000" />
+       <workItem from="1754221817091" duration="34326000" />
+       <workItem from="1754311631259" duration="1625000" />
+       <workItem from="1754430709706" duration="5414000" />
+       <workItem from="1754511717348" duration="1628000" />
+       <workItem from="1754839988337" duration="2408000" />
+       <workItem from="1754919749833" duration="13485000" />
+       <workItem from="1755004077990" duration="3927000" />
+       <workItem from="1756907577926" duration="3062000" />
+       <workItem from="1756995689716" duration="6564000" />
+       <workItem from="1757086350108" duration="1819000" />
+       <workItem from="1757163460686" duration="1011000" />
+     </task>
+     <servers />
+   </component>
+   <component name="TypeScriptGeneratedFilesManager">
+     <option name="version" value="3" />
+   </component>
+   <component name="XDebuggerManager">
+     <breakpoint-manager>
+       <default-breakpoints>
+         <breakpoint type="python-exception">
+           <properties notifyOnTerminate="true" exception="BaseException">
+             <option name="notifyOnTerminate" value="true" />
+           </properties>
+         </breakpoint>
+       </default-breakpoints>
+     </breakpoint-manager>
+   </component>
+   <component name="com.intellij.coverage.CoverageDataManagerImpl">
+     <SUITE FILE_PATH="coverage/UQ_baseline$evidential.coverage" NAME="evidential Coverage Results" MODIFIED="1753831104933" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/Toy_classification/trainers" />
+     <SUITE FILE_PATH="coverage/SPC_UQ$run_toy_tests.coverage" NAME="run_toy_tests Coverage Results" MODIFIED="1754511868967" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/Toy_regression" />
+     <SUITE FILE_PATH="coverage/SPC_UQ$rename.coverage" NAME="rename Coverage Results" MODIFIED="1754235285059" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/UCI_regression/trainers" />
+     <SUITE FILE_PATH="coverage/UQ_baseline$run_toy_tests.coverage" NAME="run_toy_tests Coverage Results" MODIFIED="1753976501175" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/Toy_regression" />
+     <SUITE FILE_PATH="coverage/UQ_baseline$run_cls_tests.coverage" NAME="run_cls_tests Coverage Results" MODIFIED="1753837116769" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/Toy_classification" />
+     <SUITE FILE_PATH="coverage/SPC_UQ$run_uci_dataset_tests__1_.coverage" NAME="run_uci_dataset_tests (1) Coverage Results" MODIFIED="1756909439135" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/UCI_Benchmarks" />
+     <SUITE FILE_PATH="coverage/SPC_UQ$run_uci_dataset_tests__2_.coverage" NAME="run_uci_dataset_tests (2) Coverage Results" MODIFIED="1756909699750" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/../UCI_Benchmarks" />
+     <SUITE FILE_PATH="coverage/SPC_UQ$run_uci_dataset_tests.coverage" NAME="run_uci_dataset_tests Coverage Results" MODIFIED="1754253016762" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/UCI_regression" />
+     <SUITE FILE_PATH="coverage/SPC_UQ$re.coverage" NAME="re Coverage Results" MODIFIED="1754263188586" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/Depth_regression/save" />
+     <SUITE FILE_PATH="coverage/UQ_baseline$run_uci_dataset_tests.coverage" NAME="run_uci_dataset_tests Coverage Results" MODIFIED="1753995527450" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/UCI_regression" />
+     <SUITE FILE_PATH="coverage/SPC_UQ$run_cubic_tests.coverage" NAME="run_cubic_tests Coverage Results" MODIFIED="1756996456492" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/Cubic_Regression" />
+     <SUITE FILE_PATH="coverage/SPC_UQ$run_cls_tests.coverage" NAME="run_cls_tests Coverage Results" MODIFIED="1757002010234" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/MNIST_Classification" />
+   </component>
+ </project>
SPC-UQ/Cubic_Regression/ConformalRegression.py ADDED
@@ -0,0 +1,76 @@
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import numpy as np
+
+
+ class ConformalRegressionNet(nn.Module):
+     """
+     Simple feedforward regression model with dropout.
+     Output: point prediction only (no uncertainty head).
+     """
+     def __init__(self, input_dim=1, hidden_dim=64, output_dim=1):
+         super().__init__()
+         self.fc1 = nn.Linear(input_dim, hidden_dim)
+         self.fc2 = nn.Linear(hidden_dim, hidden_dim)
+         self.out = nn.Linear(hidden_dim, output_dim)
+         self.relu = nn.ReLU()
+         self.dropout = nn.Dropout(p=0.2)
+
+     def forward(self, x):
+         x = self.relu(self.fc1(x))
+         x = self.dropout(x)
+         x = self.relu(self.fc2(x))
+         x = self.dropout(x)
+         return self.out(x)
+
+
+ class ConformalRegressor:
+     """
+     Quantile-based conformal prediction regression model.
+     """
+     def __init__(self, quantile=0.9, learning_rate=5e-3):
+         torch.manual_seed(24)
+         self.quantile = quantile
+         self.model = ConformalRegressionNet()
+         self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
+         self.criterion = nn.MSELoss()
+         self.quantile_up = 0.0
+         self.quantile_down = 0.0
+
+     def train(self, x, y, num_epochs=5000):
+         self.model.train()
+         for epoch in range(num_epochs):
+             self.optimizer.zero_grad()
+             pred = self.model(x)
+             loss = self.criterion(pred, y)
+             loss.backward()
+             self.optimizer.step()
+             if (epoch + 1) % 100 == 0:
+                 print(f"[Epoch {epoch + 1}/{num_epochs}] Loss: {loss.item():.4f}")
+
+     def calibrate(self, x_calib, y_calib):
+         """
+         Compute empirical quantiles from residuals on calibration set.
+         """
+         self.model.eval()
+         with torch.no_grad():
+             pred = self.model(x_calib).detach().cpu().numpy().squeeze()
+         y_calib_np = y_calib.detach().cpu().numpy().squeeze()
+         residuals = y_calib_np - pred
+
+         res_up = residuals[residuals > 0]
+         res_down = -residuals[residuals <= 0]
+
+         self.quantile_up = np.quantile(res_up, self.quantile) if len(res_up) > 0 else 0.0
+         self.quantile_down = np.quantile(res_down, self.quantile) if len(res_down) > 0 else 0.0
+
+     def predict(self, x):
+         self.model.eval()
+         with torch.no_grad():
+             y_pred = self.model(x).detach().cpu().numpy().squeeze()
+         upper = y_pred + self.quantile_up
+         lower = y_pred - self.quantile_down
+         uncertainty = np.zeros_like(y_pred)  # Conformal prediction doesn't model epistemic
+
+         return y_pred, upper, lower, uncertainty
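
The class above is a split conformal scheme: fit a point predictor, then take empirical quantiles of the positive and negative residuals on held-out data to form asymmetric intervals. A minimal usage sketch, assuming the file is importable as a module and substituting toy data (both the import path and the data are illustrative assumptions, not part of the commit):

import torch
from ConformalRegression import ConformalRegressor

# Toy cubic data shaped (N, 1), the layout the network expects
x = torch.linspace(-4, 4, 200).reshape(-1, 1)
y = x ** 3 + torch.randn_like(x)

model = ConformalRegressor(quantile=0.9, learning_rate=5e-3)
model.train(x, y, num_epochs=500)           # fit the MSE point predictor
model.calibrate(x, y)                       # in practice, pass a held-out calibration split
y_hat, upper, lower, _ = model.predict(x)   # interval = [y_hat - quantile_down, y_hat + quantile_up]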
SPC-UQ/Cubic_Regression/DeepEnsembleRegression.py ADDED
@@ -0,0 +1,93 @@
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import numpy as np
+
+ def nll_loss(mean, log_var, target):
+     """
+     Negative log-likelihood loss for Gaussian output
+     """
+     var = torch.exp(log_var)
+     loss = 0.5 * torch.log(2 * np.pi * var) + 0.5 * ((target - mean) ** 2) / var
+     return loss.mean()
+
+ class NLLRegressionNN(nn.Module):
+     """
+     Neural network for regression with Gaussian likelihood output
+     Outputs mean and log-variance
+     """
+     def __init__(self, input_dim=1, hidden_dim=64):
+         super().__init__()
+         self.fc1 = nn.Linear(input_dim, hidden_dim)
+         self.hidden = nn.Linear(hidden_dim, hidden_dim)
+         self.relu = nn.ReLU()
+         self.fc2 = nn.Linear(hidden_dim, 2)  # Outputs: mean and log_variance
+
+     def forward(self, x):
+         x = self.relu(self.fc1(x))
+         x = self.relu(self.hidden(x))
+         x = self.fc2(x)
+         mean = x[:, :1]
+         log_var = x[:, 1:]
+         return mean, log_var
+
+ class DeepEnsemble:
+     """
+     Deep Ensemble for probabilistic regression with Gaussian likelihood
+     """
+     def __init__(self, num_models=5, learning_rate=5e-3):
+         torch.manual_seed(42)
+         self.models = [NLLRegressionNN() for _ in range(num_models)]
+         self.optimizers = [optim.Adam(model.parameters(), lr=learning_rate) for model in self.models]
+
+     def train(self, data, target, num_epochs=5000):
+         """
+         Train all models independently on the same data
+         """
+         for idx, (model, optimizer) in enumerate(zip(self.models, self.optimizers), start=1):
+             torch.manual_seed(idx + 42)  # Different seed for each model
+             for epoch in range(num_epochs):
+                 model.train()
+                 optimizer.zero_grad()
+                 mean, log_var = model(data)
+                 loss = nll_loss(mean, log_var, target)
+                 loss.backward()
+                 optimizer.step()
+
+                 if (epoch + 1) % 500 == 0:
+                     print(f"Model {idx}, Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")
+
+     def predict(self, data):
+         """
+         Return ensemble mean, uncertainty, and prediction interval
+         """
+         means = []
+         variances = []
+
+         with torch.no_grad():
+             for model in self.models:
+                 model.eval()
+                 mean, log_var = model(data)
+                 means.append(mean.numpy())
+                 variances.append(torch.exp(log_var).numpy())
+
+         means = np.array(means)  # (num_models, batch, 1)
+         variances = np.array(variances)
+
+         # Mean and total predictive variance (mean of model variances)
+         mean_ensemble = np.mean(means, axis=0)
+         var_ensemble = np.mean(variances, axis=0)
+
+         # Epistemic uncertainty: variance across model means
+         epistemic_uncertainty = np.var(means, axis=0)
+
+         std_ensemble = np.sqrt(var_ensemble)
+         y_low = mean_ensemble - 2 * std_ensemble
+         y_high = mean_ensemble + 2 * std_ensemble
+
+         return (
+             mean_ensemble.squeeze(),
+             y_high.squeeze(),
+             y_low.squeeze(),
+             epistemic_uncertainty.squeeze()
+         )
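
For reference, predict treats the ensemble as an equally weighted Gaussian mixture. In the notation of the code, with M members emitting (mu_m, sigma_m^2):

\[
\mu_* = \frac{1}{M}\sum_{m=1}^{M}\mu_m, \qquad
\bar\sigma^2 = \frac{1}{M}\sum_{m=1}^{M}\sigma_m^2, \qquad
u_{\mathrm{epi}} = \frac{1}{M}\sum_{m=1}^{M}\left(\mu_m-\mu_*\right)^2 ,
\]

and the returned interval is \(\mu_* \pm 2\bar\sigma\). Note that the interval width uses only the averaged member variance (the aleatoric part); the variance of the member means is returned separately as the epistemic score rather than being added to the interval.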
SPC-UQ/Cubic_Regression/EDLQuantileRegression.py ADDED
@@ -0,0 +1,155 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import math
+
+
+ class DenseNormalGamma(nn.Module):
+     def __init__(self, in_features, out_units):
+         super().__init__()
+         self.out_units = int(out_units)
+         self.linear = nn.Linear(in_features, 4 * self.out_units)
+
+     def evidence(self, x):
+         return F.softplus(x)
+
+     def forward(self, x):
+         output = self.linear(x)
+         mu, log_v, log_alpha, log_beta = torch.chunk(output, chunks=4, dim=-1)
+         v = self.evidence(log_v)
+         alpha = self.evidence(log_alpha) + 1.0
+         beta = self.evidence(log_beta)
+         return torch.cat([mu, v, alpha, beta], dim=-1)
+
+     def extra_repr(self):
+         return f"out_units={self.out_units}"
+
+
+ class EDLQRNet(nn.Module):
+     def __init__(self, input_dim=1, num_quantiles=3, hidden_dim=64, num_layers=2, activation=nn.ReLU()):
+         super().__init__()
+         layers = []
+         in_features = input_dim
+         for _ in range(num_layers):
+             layers.append(nn.Linear(in_features, hidden_dim))
+             layers.append(activation)
+             in_features = hidden_dim
+         layers.append(DenseNormalGamma(in_features, num_quantiles))
+         self.network = nn.Sequential(*layers)
+
+     def forward(self, x):
+         output = self.network(x)
+         mu, v, alpha, beta = torch.chunk(output, 4, dim=-1)
+         return mu, v, alpha, beta
+
+
+ def nig_nll(y, mu, v, alpha, beta, wi_mean, quantile, reduce=True):
+     tau2 = 2.0 / (quantile * (1.0 - quantile))
+     two_b_lambda = 4.0 * beta * (1.0 + tau2 * wi_mean * v)
+
+     nll = 0.5 * torch.log(math.pi / v) \
+         - alpha * torch.log(two_b_lambda) \
+         + (alpha + 0.5) * torch.log(v * (y - mu) ** 2 + two_b_lambda) \
+         + torch.lgamma(alpha) \
+         - torch.lgamma(alpha + 0.5)
+
+     return torch.mean(nll) if reduce else nll
+
+
+ def kl_nig(mu1, v1, a1, b1, mu2, v2, a2, b2):
+     kl = 0.5 * (a1 - 1.0) / b1 * (v2 * (mu2 - mu1) ** 2) \
+         + 0.5 * (v2 / v1) \
+         - 0.5 * torch.log(torch.abs(v2) / torch.abs(v1)) \
+         - 0.5 \
+         + a2 * torch.log(b1 / b2) \
+         - (torch.lgamma(a1) - torch.lgamma(a2)) \
+         + (a1 - a2) * torch.digamma(a1) \
+         - (b1 - b2) * a1 / b1
+     return kl
+
+
+ def tilted_loss(q, e):
+     return torch.maximum(q * e, (q - 1.0) * e)
+
+
+ def nig_regularization(y, mu, v, alpha, beta, wi_mean, quantile, lambda_reg=0.01, reduce=True, use_kl=False):
+     theta = (1.0 - 2.0 * quantile) / (quantile * (1.0 - quantile))
+     error = tilted_loss(quantile, y - mu)
+
+     if use_kl:
+         kl_val = kl_nig(mu, v, alpha, beta, mu, lambda_reg, 1.0 + lambda_reg, beta)
+         reg = error * kl_val
+     else:
+         evidential_term = 2.0 * v + alpha + 1.0 / beta
+         reg = error * evidential_term
+
+     return torch.mean(reg) if reduce else reg
+
+
+ def quantile_evidential_loss(y_true, mu, v, alpha, beta, quantile, coeff=1.0, reduce=True):
+     theta = (1.0 - 2.0 * quantile) / (quantile * (1.0 - quantile))
+     wi_mean = beta / (alpha - 1.0)
+     mu_adj = mu + theta * wi_mean
+
+     loss_nll = nig_nll(y_true, mu_adj, v, alpha, beta, wi_mean, quantile, reduce)
+     loss_reg = nig_regularization(y_true, mu, v, alpha, beta, wi_mean, quantile, reduce)
+     return loss_nll + coeff * loss_reg
+
+
+ class EDLQuantileRegressor:
+     def __init__(self, tau_low=0.05, tau_high=0.95, learning_rate=5e-4):
+         torch.manual_seed(42)
+         self.model = EDLQRNet()
+         self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
+         self.quantiles = [tau_low, 0.5, tau_high]
+         self.coeff = 0.05
+
+     def loss_function(self, y, mu, v, alpha, beta):
+         total_loss = 0.0
+         for i, q in enumerate(self.quantiles):
+             total_loss += quantile_evidential_loss(
+                 y, mu[:, i].unsqueeze(1), v[:, i].unsqueeze(1),
+                 alpha[:, i].unsqueeze(1), beta[:, i].unsqueeze(1),
+                 q, coeff=self.coeff
+             )
+         return total_loss
+
+     def train(self, x, y, num_epochs=5000):
+         self.model.train()
+         for epoch in range(num_epochs):
+             self.optimizer.zero_grad()
+             mu, v, alpha, beta = self.model(x)
+             loss = self.loss_function(y, mu, v, alpha, beta)
+             loss.backward()
+             self.optimizer.step()
+
+             if (epoch + 1) % 100 == 0:
+                 print(f"Epoch [{epoch + 1}/{num_epochs}] Loss: {loss.item():.4f}")
+
+     def predict(self, x):
+         self.model.eval()
+         with torch.no_grad():
+             mu, v, alpha, beta = self.model(x)
+             mu_low, mu_mid, mu_high = torch.unbind(mu, dim=1)
+             v_low, v_mid, v_high = torch.unbind(v, dim=1)
+             alpha_low, alpha_mid, alpha_high = torch.unbind(alpha, dim=1)
+             beta_low, beta_mid, beta_high = torch.unbind(beta, dim=1)
+
+             aleatoric = beta_mid / (alpha_mid - 1.0)
+             epistemic_mid = beta_mid / (v_mid * (alpha_mid - 1.0))
+             epistemic_low = beta_low / (v_low * (alpha_low - 1.0))
+             epistemic_high = beta_high / (v_high * (alpha_high - 1.0))
+             uncertainty = epistemic_mid
+
+             plt.figure(figsize=(10, 6))
+             plt.plot(x, epistemic_mid, label='Mid', color='orange')
+             plt.plot(x, epistemic_low, label='Low', color='red')
+             plt.plot(x, epistemic_high, label='High', color='green')
+             plt.legend()
+             plt.title('Epistemic Uncertainty')
+             plt.show()
+
+         return mu_mid.numpy().squeeze(), mu_high.numpy().squeeze(), mu_low.numpy().squeeze(), uncertainty.numpy().squeeze()
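
The constants in quantile_evidential_loss match the usual asymmetric-Laplace link for quantile-level evidential losses (this reading is inferred from the code, not stated in the commit). For quantile level q, each head's NIG mean is shifted by

\[
\theta_q=\frac{1-2q}{q(1-q)}, \qquad \tau_q^2=\frac{2}{q(1-q)}, \qquad
\tilde\mu = \mu + \theta_q\,\hat\sigma^2, \qquad \hat\sigma^2=\frac{\beta}{\alpha-1},
\]

where \(\hat\sigma^2\) is wi_mean in the code and \(\tau_q^2\) rescales the spread inside nig_nll via the two_b_lambda term. At q = 0.5 the shift \(\theta_q\) vanishes, so the median head is trained on an unshifted NIG mean.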
SPC-UQ/Cubic_Regression/EDLRegression.py ADDED
@@ -0,0 +1,141 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+ from scipy.stats import t as student_t
+ import numpy as np
+ import math
+
+
+ class EDLRegressionNet(nn.Module):
+     def __init__(self, input_dim=1, hidden_dim=64, output_dim=1):
+         super().__init__()
+         self.hidden1 = nn.Linear(input_dim, hidden_dim)
+         self.hidden2 = nn.Linear(hidden_dim, hidden_dim)
+         self.output_mu = nn.Linear(hidden_dim, output_dim)
+         self.output_logv = nn.Linear(hidden_dim, output_dim)
+         self.output_alpha = nn.Linear(hidden_dim, output_dim)
+         self.output_beta = nn.Linear(hidden_dim, output_dim)
+         self.activation = nn.ReLU()
+
+     def forward(self, x):
+         x = self.activation(self.hidden1(x))
+         x = self.activation(self.hidden2(x))
+         mu = self.output_mu(x)
+         v = F.softplus(self.output_logv(x))
+         alpha = F.softplus(self.output_alpha(x)) + 1.0 + 1e-6
+         beta = F.softplus(self.output_beta(x))
+         return mu, v, alpha, beta
+
+
+ def nig_nll(y, mu, v, alpha, beta, reduce=True):
+     two_b_lambda = 2 * beta * (1 + v)
+
+     nll = 0.5 * torch.log(torch.tensor(np.pi) / v) \
+         - alpha * torch.log(two_b_lambda) \
+         + (alpha + 0.5) * torch.log(v * (y - mu) ** 2 + two_b_lambda) \
+         + torch.lgamma(alpha) \
+         - torch.lgamma(alpha + 0.5)
+
+     return torch.mean(nll) if reduce else nll
+
+
+ def kl_nig(mu1, v1, a1, b1, mu2, v2, a2, b2):
+     eps = 1e-6
+     v1 = torch.clamp(v1, min=eps)
+     v2 = torch.clamp(v2, min=eps)
+     b1 = torch.clamp(b1, min=eps)
+     b2 = torch.clamp(b2, min=eps)
+
+     term1 = 0.5 * (a1 - 1) / b1 * (v2 * (mu2 - mu1) ** 2)
+     term2 = 0.5 * v2 / v1
+     term3 = -0.5 * torch.log(v2 / v1)
+     term4 = -0.5
+     term5 = a2 * torch.log(b1 / b2)
+     term6 = -(torch.lgamma(a1) - torch.lgamma(a2))
+     term7 = (a1 - a2) * torch.digamma(a1)
+     term8 = -(b1 - b2) * a1 / b1
+
+     kl = term1 + term2 + term3 + term4 + term5 + term6 + term7 + term8
+     return kl
+
+
+ def nig_regularization(y, mu, v, alpha, beta, omega=0.01, reduce=True, use_kl=False):
+     error = torch.abs(y - mu)
+
+     if use_kl:
+         kl = kl_nig(mu, v, alpha, beta, mu, omega, 1 + omega, beta)
+         reg = error * kl
+     else:
+         evidential = 2 * v + alpha
+         reg = error * evidential
+
+     return reg.mean() if reduce else reg
+
+
+ def edl_loss(y, mu, v, alpha, beta, lam=0.0, reduce=True, return_components=False):
+     nll = nig_nll(y, mu, v, alpha, beta, reduce=reduce)
+     reg = nig_regularization(y, mu, v, alpha, beta, reduce=reduce)
+     loss = nll  # optionally: loss = nll + lam * reg
+
+     return (loss, (nll, reg)) if return_components else loss
+
+
+ def predictive_interval(mu, v, alpha, beta, confidence=0.95):
+     mu = torch.as_tensor(mu)
+     v = torch.as_tensor(v)
+     alpha = torch.as_tensor(alpha)
+     beta = torch.as_tensor(beta)
+
+     dof = 2.0 * alpha
+     scale = torch.sqrt((1.0 + v) * beta / (alpha * v))
+     lower_q = (1.0 - confidence) / 2.0
+     upper_q = 1.0 - lower_q
+
+     t_l = student_t.ppf(lower_q, df=dof.cpu().numpy())
+     t_u = student_t.ppf(upper_q, df=dof.cpu().numpy())
+
+     lower = mu + torch.from_numpy(t_l).to(mu.device) * scale
+     upper = mu + torch.from_numpy(t_u).to(mu.device) * scale
+
+     return lower, upper
+
+
+ class EDLRegressor:
+     def __init__(self, learning_rate=5e-4):
+         torch.manual_seed(33)
+         self.model = EDLRegressionNet()
+         self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
+         self.criterion = edl_loss
+         self.lambda_ = 0.01
+
+     def train(self, x, y, num_epochs=5000):
+         torch.manual_seed(33)
+         self.model.train()
+         for epoch in range(num_epochs):
+             self.optimizer.zero_grad()
+             mu, v, alpha, beta = self.model(x)
+             loss = self.criterion(y, mu, v, alpha, beta, lam=self.lambda_)
+             loss.backward()
+             self.optimizer.step()
+             if (epoch + 1) % 100 == 0:
+                 print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")
+
+     def predict(self, x):
+         self.model.eval()
+         with torch.no_grad():
+             mu, v, alpha, beta = self.model(x)
+
+         # Uncertainty decomposition
+         aleatoric = torch.sqrt(beta / (alpha - 1.0 + 1e-6))
+         epistemic = torch.sqrt(beta / (v * (alpha - 1.0 + 1e-6)))
+
+         mu_np = mu.detach().cpu().numpy()
+         aleatoric_np = aleatoric.detach().cpu().numpy()
+         epistemic_np = epistemic.detach().cpu().numpy()
+
+         lower, upper = predictive_interval(mu, v, alpha, beta, confidence=0.95)
+         lower_np = lower.detach().cpu().numpy()
+         upper_np = upper.detach().cpu().numpy()
+
+         return mu_np.squeeze(), upper_np.squeeze(), lower_np.squeeze(), epistemic_np.squeeze()
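
predictive_interval implements the standard Normal-Inverse-Gamma posterior predictive, a Student-t with 2*alpha degrees of freedom. In the notation of the code,

\[
y \,\big|\, (\mu, v, \alpha, \beta) \;\sim\; \mathrm{St}_{2\alpha}\!\Big(\mu,\ \frac{\beta(1+v)}{\alpha v}\Big),
\qquad
\sigma^2_{\mathrm{alea}}=\frac{\beta}{\alpha-1},
\qquad
\sigma^2_{\mathrm{epi}}=\frac{\beta}{v(\alpha-1)},
\]

which matches the scale term torch.sqrt((1.0 + v) * beta / (alpha * v)) and the square-rooted aleatoric/epistemic decomposition in predict.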
SPC-UQ/Cubic_Regression/QROC.py ADDED
@@ -0,0 +1,125 @@
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import numpy as np
+ from torch.utils.data import DataLoader, TensorDataset
+
+ def pinball_loss(q_pred, target, tau):
+     """
+     Compute the quantile (pinball) loss for a given quantile level tau
+     """
+     error = target - q_pred
+     return torch.mean(torch.max(tau * error, (tau - 1) * error))
+
+ def multi_quantile_loss(q_low, q_mid, q_high, target, tau_low=0.05, tau_high=0.95):
+     """
+     Compute total loss across low, median, and high quantiles
+     """
+     loss_l = pinball_loss(q_low, target, tau_low)
+     loss_m = pinball_loss(q_mid, target, 0.5)
+     loss_h = pinball_loss(q_high, target, tau_high)
+     return loss_l + loss_m + loss_h
+
+
+ class QuantileRegressionNN(nn.Module):
+     """
+     Fully-connected quantile regression network with three outputs:
+     lower, median, upper quantiles.
+     """
+     def __init__(self, input_dim=1, hidden_dim=64):
+         super().__init__()
+         self.fc1 = nn.Linear(input_dim, hidden_dim)
+         self.fc2 = nn.Linear(hidden_dim, hidden_dim)
+         self.out = nn.Linear(hidden_dim, 3)
+         self.relu = nn.ReLU()
+
+     def forward(self, x):
+         x = self.relu(self.fc1(x))
+         x = self.relu(self.fc2(x))
+         out = self.out(x)
+         return out[:, :1], out[:, 1:2], out[:, 2:]
+
+     def extract_features(self, x):
+         """
+         Extract penultimate-layer features (for certificate head)
+         """
+         x = self.relu(self.fc1(x))
+         x = self.relu(self.fc2(x))
+         return x
+
+
+ def build_certificate_head(features, out_dim=20, epochs=500):
+     """
+     Train an orthogonal linear projection (certificate head) on extracted features
+     """
+     projection = nn.Linear(features.size(1), out_dim)
+     loader = DataLoader(TensorDataset(features), shuffle=True, batch_size=128)
+     optimizer = optim.Adam(projection.parameters())
+
+     for _ in range(epochs):
+         for (feature_batch,) in loader:
+             optimizer.zero_grad()
+             output = projection(feature_batch)
+             cert_loss = output.pow(2).mean()
+
+             identity = torch.eye(out_dim, device=projection.weight.device)
+             ortho_penalty = (projection.weight @ projection.weight.T - identity).pow(2).mean()
+
+             (cert_loss + ortho_penalty).backward()
+             optimizer.step()
+
+     return projection
+
+
+ class QROC:
+     """
+     Single-model Quantile Regression with Orthogonal Certificate head (QROC)
+     Provides aleatoric and epistemic uncertainty estimation
+     """
+     def __init__(self, learning_rate=5e-3, tau_low=0.05, tau_high=0.95):
+         torch.manual_seed(42)
+         self.model = QuantileRegressionNN()
+         self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
+         self.tau_low = tau_low
+         self.tau_high = tau_high
+         self.certificate_head = None
+
+     def train(self, x, y, epochs=3000):
+         """
+         Train the quantile regression network and then fit certificate head on extracted features
+         """
+         for epoch in range(epochs):
+             self.model.train()
+             self.optimizer.zero_grad()
+             q_low, q_mid, q_high = self.model(x)
+             loss = multi_quantile_loss(q_low, q_mid, q_high, y, self.tau_low, self.tau_high)
+             loss.backward()
+             self.optimizer.step()
+
+             if (epoch + 1) % 500 == 0:
+                 print(f"[{epoch+1}/{epochs}] Loss: {loss.item():.4f}")
+
+         # Train certificate head using the final feature representation
+         self.model.eval()
+         with torch.no_grad():
+             features = self.model.extract_features(x).detach()
+         self.certificate_head = build_certificate_head(features)
+
+     def predict(self, x):
+         """
+         Predict quantiles and return aleatoric and epistemic uncertainties
+         """
+         self.model.eval()
+         with torch.no_grad():
+             q_low, q_mid, q_high = self.model(x)
+
+             # Epistemic: projection energy via orthogonal certificate head
+             features = self.model.extract_features(x)
+             epistemic = self.certificate_head(features).pow(2).mean(dim=1).cpu().numpy()
+
+         return (
+             q_mid.squeeze().numpy(),
+             q_high.squeeze().numpy(),
+             q_low.squeeze().numpy(),
+             epistemic
+         )
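
The certificate head above follows the orthogonal-certificates idea: a linear map trained so that its output is near zero on in-distribution features, with an orthogonality penalty keeping its rows from collapsing, so that at test time the projection energy of a feature vector serves as the epistemic score. A minimal usage sketch (toy data and import path are illustrative assumptions):

import torch
from QROC import QROC

x = torch.linspace(-4, 4, 500).reshape(-1, 1)
y = x ** 3 + 3.0 * torch.randn_like(x)

model = QROC(learning_rate=5e-3, tau_low=0.05, tau_high=0.95)
model.train(x, y, epochs=1000)  # also fits the certificate head on the final features

x_wide = torch.linspace(-8, 8, 500).reshape(-1, 1)  # extends beyond the training range
q_mid, q_high, q_low, epistemic = model.predict(x_wide)
# epistemic should grow on inputs whose features differ from the training data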
SPC-UQ/Cubic_Regression/SPCRegression.py ADDED
@@ -0,0 +1,173 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import torch.nn.functional as F
+
+ def pinball_loss(q_pred, target, tau):
+     """Standard quantile regression loss."""
+     errors = target - q_pred
+     loss = torch.max(tau * errors, (tau - 1) * errors)
+     return torch.mean(loss)
+
+ def cali_loss(y_pred, y_true, q, scale=True):
+     """
+     Calibration loss for quantile regression.
+     Penalizes over- or under-coverage relative to quantile level q.
+     """
+     diff = y_true - y_pred
+     under_mask = (y_true <= y_pred)
+     over_mask = ~under_mask
+
+     coverage = torch.mean(under_mask.float())
+
+     if coverage < q:
+         loss = torch.mean(diff[over_mask])
+     else:
+         loss = torch.mean(-diff[under_mask])
+
+     if scale:
+         loss *= torch.abs(q - coverage)
+     return loss
+
+ class SPCRegressionNet(nn.Module):
+     """
+     Neural network that predicts:
+     - point estimate (v)
+     - MAR (mean absolute residual)
+     - MAR up/down (for epistemic decomposition)
+     - QR up/down (for aleatoric decomposition)
+     """
+     def __init__(self, input_dim=1, hidden_dim=64):
+         super(SPCRegressionNet, self).__init__()
+         self.hidden = nn.Linear(input_dim, hidden_dim)
+         self.hidden2 = nn.Linear(hidden_dim, hidden_dim)
+         self.hidden3 = nn.Linear(hidden_dim, hidden_dim)
+
+         self.relu = nn.ReLU()
+         # self.dropout = nn.Dropout(p=0.2)
+         self.output_v = nn.Linear(hidden_dim, 1)
+         self.output_uq = nn.Linear(hidden_dim, 5)
+
+     def forward(self, x):
+         x = self.hidden(x)
+         x = self.relu(x)
+         x = self.hidden2(x)
+         x = self.relu(x)
+         v = self.output_v(x)
+         x = self.hidden3(x)
+         x = self.relu(x)
+         output = self.output_uq(x)
+         mar, mar_up, mar_down, q_up, q_down = torch.chunk(output, 5, dim=-1)
+
+         q_up = F.softplus(q_up)
+         q_down = F.softplus(q_down)
+         return v, mar, mar_up, mar_down, q_up, q_down
+
+ class SPCregression:
+     """
+     Trainer and predictor for SPC UQ model.
+     Supports joint or stagewise training strategies.
+     """
+     def __init__(self, learning_rate=5e-3):
+         torch.manual_seed(42)
+         self.model = SPCRegressionNet()
+         self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
+         self.optimizer1 = optim.Adam(
+             list(self.model.hidden.parameters()) + list(self.model.hidden2.parameters()) + list(self.model.output_v.parameters()),
+             lr=learning_rate, weight_decay=1e-4)
+         self.optimizer2 = optim.Adam(
+             list(self.model.hidden3.parameters()) + list(self.model.output_uq.parameters()),
+             lr=learning_rate, weight_decay=1e-4)
+         self.criterion = nn.MSELoss()
+         self.criterion2 = nn.L1Loss()
+
+     def mar_loss(self, y, predictions, mar, mar_up, mar_down, q_up, q_down):
+         """Computes loss for MAR and QR heads."""
+         residual = abs(y - predictions)
+         diff = (y - predictions.detach())
+         loss_mar = self.criterion(mar, residual)
+
+         mask_up = (diff > 0)
+         mask_down = (diff < 0)
+         loss_mar_up = self.criterion(mar_up[mask_up], (y[mask_up] - predictions[mask_up]))
+         loss_mar_down = self.criterion(mar_down[mask_down], (predictions[mask_down] - y[mask_down]))
+
+         # loss_q_up = pinball_loss(q_up[mask_up], (y[mask_up] - predictions[mask_up]), 0.95)
+         # loss_q_down = pinball_loss(q_down[mask_down], (predictions[mask_down] - y[mask_down]), 0.95)
+         loss_cali_up = cali_loss(q_up[mask_up], (y[mask_up] - predictions[mask_up]), 0.95)
+         loss_cali_down = cali_loss(q_down[mask_down], (predictions[mask_down] - y[mask_down]), 0.95)
+
+         loss = loss_mar + loss_mar_up + loss_mar_down + (loss_cali_up + loss_cali_down)
+         # + 0.2 * (loss_q_up + loss_q_down) \
+         # + 0.8 * (loss_cali_up + loss_cali_down) \
+         return loss
+
+     def train(self, data, target, num_epochs=5000, strategy='stagewise'):
+         """
+         Train the model using either:
+         - 'joint': full loss on all components
+         - 'stagewise': first fit task head, then UQ heads
+         """
+         torch.manual_seed(42)
+         self.model.train()
+
+         if strategy == 'joint':
+             for epoch in range(num_epochs):
+                 self.optimizer.zero_grad()
+                 predictions, mar, mar_up, mar_down, q_up, q_down = self.model(data)
+                 loss = self.criterion(predictions, target) + self.mar_loss(target, predictions, mar, mar_up, mar_down, q_up, q_down)
+                 loss.backward()
+                 self.optimizer.step()
+                 if (epoch + 1) % 100 == 0:
+                     print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")
+
+         if strategy == 'stagewise':
+             for epoch in range(num_epochs):
+                 self.optimizer1.zero_grad()
+                 predictions, mar, mar_up, mar_down, q_up, q_down = self.model(data)
+                 loss = self.criterion(predictions, target)
+                 loss.backward()
+                 self.optimizer1.step()
+                 if (epoch + 1) % 100 == 0:
+                     print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")
+
+             for epoch in range(num_epochs):
+                 self.optimizer2.zero_grad()
+                 predictions, mar, mar_up, mar_down, q_up, q_down = self.model(data)
+                 loss = self.mar_loss(target, predictions, mar, mar_up, mar_down, q_up, q_down)
+                 loss.backward()
+                 self.optimizer2.step()
+                 if (epoch + 1) % 100 == 0:
+                     print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")
+
+     def predict(self, data, calibration=False):
+         """Run prediction and return interval bounds and uncertainty estimate."""
+         self.model.eval()
+         with torch.no_grad():
+             predictions, mar, mar_up, mar_down, q_up, q_down = self.model(data)
+
+         v = predictions.numpy()
+         mar = mar.detach().numpy()
+         mar_up = mar_up.detach().numpy()
+         mar_down = mar_down.detach().numpy()
+         q_up = q_up.detach().numpy()
+         q_down = q_down.detach().numpy()
+
+         if calibration:
+             # Calibration adjustment based on self-consistency
+             d_up = (mar * mar_down) / ((2 * mar_down - mar) * mar_up)
+             d_down = (mar * mar_up) / ((2 * mar_up - mar) * mar_down)
+             d_up = np.clip(d_up, 1, None)
+             d_down = np.clip(d_down, 1, None)
+             q_up *= d_up
+             q_down *= d_down
+
+         high_bound = v + q_up
+         low_bound = v - q_down
+
+         # Self-consistency verification
+         uncertainty = (abs(2 * mar_up * mar_down - mar * (mar_up + mar_down)))
+
+         return v.squeeze(), high_bound.squeeze(), low_bound.squeeze(), uncertainty.squeeze()
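
Reading predict: the self-consistency score is the residual of an algebraic identity among the three MAR heads. With m, m_up, and m_down denoting mar, mar_up, and mar_down,

\[
U(x) = \left|\, 2\,m_{\uparrow} m_{\downarrow} - m\,(m_{\uparrow} + m_{\downarrow}) \,\right|,
\]

which vanishes exactly when the overall MAR m equals the harmonic mean \(2 m_{\uparrow} m_{\downarrow} / (m_{\uparrow} + m_{\downarrow})\) of the one-sided MARs. The factors d_up and d_down are precisely the one-sided rescalings that would restore this identity; predict applies them to the quantile offsets q_up and q_down and clips them at 1, so calibration only ever widens the interval.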
SPC-UQ/Cubic_Regression/__pycache__/ConformalRegression.cpython-37.pyc ADDED
Binary file (3.08 kB).
 
SPC-UQ/Cubic_Regression/__pycache__/DeepEnsembleRegression.cpython-37.pyc ADDED
Binary file (3.48 kB).
 
SPC-UQ/Cubic_Regression/__pycache__/EDLQuantileRegression.cpython-37.pyc ADDED
Binary file (6.07 kB).
 
SPC-UQ/Cubic_Regression/__pycache__/EDLRegression.cpython-37.pyc ADDED
Binary file (4.79 kB).
 
SPC-UQ/Cubic_Regression/__pycache__/QROC.cpython-37.pyc ADDED
Binary file (4.47 kB).
 
SPC-UQ/Cubic_Regression/__pycache__/SPCRegression.cpython-37.pyc ADDED
Binary file (5.2 kB).
 
SPC-UQ/Cubic_Regression/run_cubic_tests.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import argparse
4
+ import matplotlib.pyplot as plt
5
+ import matplotlib.cm as cm
6
+ from SPCRegression import SPCregression
7
+ from DeepEnsembleRegression import DeepEnsemble
8
+ from ConformalRegression import ConformalRegressor
9
+ from EDLRegression import EDLRegressor
10
+ from EDLQuantileRegression import EDLQuantileRegressor
11
+ from QROC import QROC
12
+ from scipy.stats import binned_statistic
13
+
14
+
15
+ def ece_pi(y_true, pred_lower, pred_upper, num_bins=10):
16
+ """Compute Expected Calibration Error (ECE) based on prediction interval width bins."""
17
+ N = y_true.shape[0]
18
+ in_interval = ((y_true >= pred_lower) & (y_true <= pred_upper)).astype(float)
19
+ widths = pred_upper - pred_lower
20
+ min_w, max_w = np.min(widths), np.max(widths)
21
+
22
+ if min_w == max_w:
23
+ return np.abs(in_interval.mean() - 1.0)
24
+
25
+ bin_edges = np.linspace(min_w, max_w, num_bins + 1)
26
+ ece = 0.0
27
+
28
+ for i in range(num_bins):
29
+ bin_mask = (widths >= bin_edges[i]) & (widths < bin_edges[i + 1])
30
+ count_in_bin = np.sum(bin_mask)
31
+ if count_in_bin == 0:
32
+ continue
33
+ avg_in_interval = in_interval[bin_mask].mean()
34
+ nominal_coverage = 0.95
35
+ calib_error = np.abs(avg_in_interval - nominal_coverage)
36
+ weight = count_in_bin / N
37
+ ece += weight * calib_error
38
+ return ece
39
+
40
+
41
+ def binning(pred_lower, pred_upper, num_bins=10):
42
+ """Group indices into bins based on interval width."""
43
+ widths = pred_upper - pred_lower
44
+ min_w, max_w = np.min(widths), np.max(widths)
45
+ bin_edges = np.linspace(min_w, max_w, num_bins + 1)
46
+
47
+ bins = []
48
+ for i in range(num_bins):
49
+ if i == num_bins - 1:
50
+ bin_mask = (widths >= bin_edges[i]) & (widths <= bin_edges[i + 1])
51
+ else:
52
+ bin_mask = (widths >= bin_edges[i]) & (widths < bin_edges[i + 1])
53
+ bin_indices = np.where(bin_mask)[0]
54
+ bins.append(bin_indices)
55
+ return bins, bin_edges
56
+
57
+
58
+ def generate_multimodal_data(n_samples=1000):
59
+ """Generate mixture-distribution noise samples."""
60
+ x = np.random.randn(n_samples)
61
+ mask = np.random.choice([0, 1, 2], p=[0.4, 0.3, 0.3], size=n_samples)
62
+ y = np.where(mask == 0, x + np.random.randn(n_samples) * 1,
63
+ np.where(mask == 1, x + 40 + np.random.randn(n_samples) * 1,
64
+ x - 10 + np.random.randn(n_samples) * 1))
65
+ return y
66
+
67
+
68
+ def generate_train_data(n_samples=20, noise='log'):
69
+ """Generate synthetic training data with nonlinear relationship and optional noise."""
70
+ np.random.seed(57)
71
+ x = np.linspace(-4, 4, n_samples)
72
+
73
+ if noise == 'log':
74
+ noise = np.random.lognormal(mean=1.5, sigma=1, size=n_samples)
75
+ elif noise == 'tri':
76
+ noise = generate_multimodal_data(n_samples)
77
+ elif noise == 'norm':
78
+ noise = np.random.normal(0, 8, size=n_samples)
79
+
80
+ noise = noise - np.mean(noise)
81
+ y = x ** 3 + noise
82
+ x = x.reshape(-1, 1).astype(np.float32)
83
+ y = y.reshape(-1, 1).astype(np.float32)
84
+ return x, y
85
+
86
+
87
+ def generate_test_data(n_samples=100, noise='log'):
88
+ """Generate synthetic test data over an extended input range (including OOD regions) with a selectable noise distribution."""
89
+ np.random.seed(27)
90
+ x = np.linspace(-6, 6, n_samples)
91
+
92
+ if noise == 'log':
93
+ noise = np.random.lognormal(mean=1.5, sigma=1, size=n_samples)
94
+ elif noise == 'tri':
95
+ noise = generate_multimodal_data(n_samples)
96
+ elif noise == 'norm':
97
+ noise = np.random.normal(0, 8, size=n_samples)
98
+
99
+ noise = noise - np.mean(noise)
100
+ y = x ** 3 + noise
101
+ x = x.reshape(-1, 1).astype(np.float32)
102
+ y = y.reshape(-1, 1).astype(np.float32)
103
+ return x, y
104
+
105
+
106
+ # ===================== Argument Parser ===================== #
107
+ parser = argparse.ArgumentParser()
108
+ parser.add_argument("--num-epochs", default=5000, type=int)
109
+ parser.add_argument('--data-noise', default='log', choices=['norm', 'tri', 'log'])
110
+ parser.add_argument('--UQ-model', default='SPCregression', choices=['SPCregression', 'DeepEnsemble', 'EDLRegressor', 'EDLQuantileRegressor', 'QROC', 'ConformalRegressor'], help='Select UQ model to test.')
111
+ args = parser.parse_args()
112
+
113
+ # Generate training, calibration, and testing datasets
114
+ x_train, y_train = generate_train_data(n_samples=2000, noise=args.data_noise)
115
+ x_calib, y_calib = generate_train_data(n_samples=500, noise=args.data_noise)
116
+ x_test, y_test = generate_test_data(n_samples=1000, noise=args.data_noise)
117
+
118
+ # Convert data to torch tensors
119
+ x_train_tensor = torch.from_numpy(x_train)
120
+ y_train_tensor = torch.from_numpy(y_train)
121
+ x_calib_tensor = torch.from_numpy(x_calib)
122
+ y_calib_tensor = torch.from_numpy(y_calib)
123
+ x_test_tensor = torch.from_numpy(x_test)
124
+ y_test_tensor = torch.from_numpy(y_test)
125
+
126
+ # Training parameters
127
+ num_epochs = args.num_epochs  # honour the --num-epochs argument rather than a hard-coded value
128
+ num_models = 5
129
+ lr = 0.001
130
+
131
+ # Select UQ model to test
132
+ UQ = args.UQ_model
133
+
134
+ if UQ == 'SPCregression':
135
+ model = SPCregression(learning_rate=lr)
136
+ elif UQ == 'DeepEnsemble':
137
+ model = DeepEnsemble(num_models=num_models, learning_rate=lr)
138
+ elif UQ == 'EDLRegressor':
139
+ model = EDLRegressor(learning_rate=lr)
140
+ elif UQ == 'EDLQuantileRegressor':
141
+ model = EDLQuantileRegressor(tau_low=0.025, tau_high=0.975, learning_rate=lr)
142
+ elif UQ == 'QROC':
143
+ model = QROC(tau_low=0.025, tau_high=0.975, learning_rate=lr)
144
+ elif UQ == 'ConformalRegressor':
145
+ model = ConformalRegressor(0.95, learning_rate=lr)
146
+
147
+ # Train model
148
+ model.train(x_train_tensor, y_train_tensor, num_epochs)
149
+
150
+ # Calibrate if applicable
151
+ if UQ == 'ConformalRegressor':
152
+ model.calibrate(x_calib_tensor, y_calib_tensor)
153
+
154
+ # Predict on the training set
155
+ mean, upper_bound, lower_bound, uncertainty = model.predict(x_train_tensor)
156
+ y_train_np = y_train.flatten()
157
+
158
+ # Evaluate train metrics
159
+ mpi_width = np.mean(upper_bound - lower_bound)
160
+ picp = np.mean(((y_train_np >= lower_bound) & (y_train_np <= upper_bound)).astype(float))
161
+ rmse = np.sqrt(np.mean((y_train_np - mean) ** 2))
162
+ print(f"Train Mean Prediction Interval Width (MPIW): {mpi_width:.4f}")
163
+ print(f"Train Prediction Interval Coverage Probability (PICP): {picp:.4f}")
164
+ print(f"Train Root Mean Squared Error (RMSE): {rmse:.4f}")
165
+
166
+ # Predict on test set
167
+ mean, upper_bound, lower_bound, uncertainty = model.predict(x_test_tensor)
168
+ x_test_np = x_test.flatten()
169
+ y_test_np = y_test.flatten()
170
+
171
+ # In-distribution test subset (indices 170:830, i.e. x roughly in [-4, 4]) for interval evaluation
172
+ y_low_id = lower_bound[170:830]
173
+ y_high_id = upper_bound[170:830]
174
+ y_mean_id = mean[170:830]
175
+ x_test_id = x_test_np[170:830]
176
+ y_test_id = y_test_np[170:830]
177
+
178
+ mpi_width_id = np.mean(y_high_id - y_low_id)
179
+ picp_id = np.mean(((y_test_id >= y_low_id) & (y_test_id <= y_high_id)).astype(float))
180
+ # PICP+: coverage of the upper half-interval [mean, upper]
181
+ picp_plus = np.sum(((y_test_id >= y_mean_id) & (y_test_id <= y_high_id)).astype(float))/np.sum((y_test_id >= y_mean_id).astype(float))
182
+ # PICP-: coverage of the lower half-interval [lower, mean]
183
+ picp_minus = np.sum(((y_test_id >= y_low_id) & (y_test_id <= y_mean_id)).astype(float))/np.sum((y_test_id <= y_mean_id).astype(float))
184
+
185
+ rmse_id = np.sqrt(np.mean((y_test_id - y_mean_id) ** 2))
186
+
187
+ print(f"ID Mean Prediction Interval Width (MPIW): {mpi_width_id:.4f}")
188
+ print(f"ID Prediction Interval Coverage Probability (PICP): {picp_id:.4f}")
189
+ print(f"ID Prediction Interval Coverage Probability (PICP+): {picp_plus:.4f}")
190
+ print(f"ID Prediction Interval Coverage Probability (PICP-): {picp_minus:.4f}")
191
+ print(f"ID Root Mean Squared Error (RMSE): {rmse_id:.4f}")
192
+
193
+ # Compute ECE
194
+ ece = ece_pi(y_test_id, y_low_id, y_high_id, num_bins=10)
195
+ print(f"Expected Calibration Error (ECE): {ece:.4f}")
196
+
197
+ threshold = uncertainty[170:830].mean()  # mean in-distribution uncertainty as the certain/uncertain split point
198
+ print('threshold:', threshold)
199
+ cer_count = 0
200
+ unc_count = 0
201
+ cer_diff = []
202
+ unc_diff = []
203
+
204
+ for i in range(len(uncertainty)):
205
+ unc = uncertainty[i]
206
+ diff_sq = (y_test_np[i] - mean[i])**2
207
+ if unc < threshold:
208
+ cer_count += 1
209
+ cer_diff.append(diff_sq)
210
+ else:
211
+ unc_count += 1
212
+ unc_diff.append(diff_sq)
213
+
214
+ rmse_certain = np.sqrt(np.mean(cer_diff)) if len(cer_diff)>0 else 0
215
+ rmse_uncertain = np.sqrt(np.mean(unc_diff)) if len(unc_diff)>0 else 0
216
+ rmse_all = np.sqrt(np.mean(cer_diff + unc_diff))
217
+
218
+ print('Certain sample:', cer_count,
219
+ 'RMSE_certain:', round(rmse_certain,4),
220
+ 'Uncertain sample:', unc_count,
221
+ 'RMSE_uncertain:', round(rmse_uncertain,4),
222
+ 'RMSE_all:', round(rmse_all,4))
223
+
224
+ print(round(rmse_id,4),
225
+ round(picp_id,4),
226
+ round(mpi_width_id,4),
227
+ round(rmse_certain,4),
228
+ round(rmse_uncertain,4),
229
+ round(rmse_all,4),
230
+ cer_count, unc_count)
231
+
232
+
233
+ def plot_binned_intervals(x_test, y_test, mean, lower_bound, upper_bound, bins, num_bins=1):
234
+ """Plot prediction intervals with color-coded bins based on interval width."""
235
+ x = x_test.squeeze()
236
+ y = y_test.squeeze()
237
+
238
+ def quantile_stat(q):
239
+ def func(y_in_bin):
240
+ return np.percentile(y_in_bin, q)
241
+ return func
242
+
243
+ bin_means, bin_edges, _ = binned_statistic(x, y, statistic='mean', bins=num_bins)
244
+ q5, _, _ = binned_statistic(x, y, statistic=quantile_stat(5), bins=bin_edges)
245
+ q95, _, _ = binned_statistic(x, y, statistic=quantile_stat(95), bins=bin_edges)
246
+
247
+ gt = x ** 3
248
+ y_up = (y - gt)[y > gt]
249
+ y_down = (gt - y)[y < gt]
250
+ gt_up = np.quantile(y_up, 0.95)
251
+ gt_down = np.quantile(y_down, 0.95)
252
+
253
+ plt.figure(figsize=(7, 5))
254
+ cmap = cm.get_cmap('viridis', len(bins))
255
+ interval_width = 0.1
256
+ for i, bin_indices in enumerate(bins):
257
+ color = cmap(i)
258
+ for j, idx in enumerate(bin_indices):
259
+ x_val = x_test[idx]
260
+ plt.fill_between(
261
+ [x_val - interval_width / 2, x_val + interval_width / 2],
262
+ [lower_bound[idx], lower_bound[idx]],
263
+ [upper_bound[idx], upper_bound[idx]],
264
+ color=color,
265
+ alpha=0.2,
266
+ label=f'Bin {i + 1}' if j == 0 else None
267
+ )
268
+
269
+ plt.plot(x, gt, color='blue', linestyle='--', label='Mean E[y|x]', linewidth=2)
270
+ plt.plot(x, gt - gt_down, color='darkorange', linestyle='--', label='Lower bound GT', linewidth=2)
271
+ plt.plot(x, gt + gt_up, color='purple', linestyle='--', label='Upper bound GT', linewidth=2)
272
+ plt.scatter(x_test, y_test, color='green', s=5, label='Test Data')
273
+ plt.plot(x_test, mean, color='red', label='Point Prediction')
274
+ plt.title("Equal-width Binned PIs")
275
+ plt.legend()
276
+ plt.ylim(-100, 100)
277
+ plt.tight_layout()
278
+ plt.show()
279
+
280
+ # Second plot with shaded split intervals
281
+ plt.figure(figsize=(7, 4))
282
+ plt.plot(x, gt, color='blue', linestyle='--', label='Mean E[y|x]', linewidth=2)
283
+ plt.plot(x, gt - gt_down, color='darkorange', linestyle='--', label='Lower bound GT', linewidth=2)
284
+ plt.plot(x, gt + gt_up, color='purple', linestyle='--', label='Upper bound GT', linewidth=2)
285
+ plt.scatter(x_test, y_test, color='green', s=5, label='Test Data')
286
+ plt.plot(x_test, mean, color='red', label='Point Prediction')
287
+ plt.fill_between(x_test.flatten(), lower_bound, mean, color='orange', alpha=0.3, label='Lower Interval')
288
+ plt.fill_between(x_test.flatten(), mean, upper_bound, color='purple', alpha=0.4, label='Upper Interval')
289
+ plt.legend(fontsize=8)
290
+ plt.ylim(-80, 100)
291
+ plt.title('Split-point Prediction Intervals')
292
+ plt.tight_layout()
293
+ plt.show()
294
+
295
+ def plot_intervals(x_test, y_test, mean, lower_bound, upper_bound, uncertainty):
296
+ """Plot prediction intervals and epistemic uncertainty with ground truth reference."""
297
+ x = x_test.squeeze()
298
+ y = y_test.squeeze()
299
+
300
+ gt = x ** 3
301
+ y_up = (y - gt)[y > gt]
302
+ y_down = (gt - y)[y < gt]
303
+ gt_up = np.quantile(y_up, 0.95)
304
+ gt_down = np.quantile(y_down, 0.95)
305
+
306
+ import matplotlib.gridspec as gridspec
307
+ plt.figure(figsize=(7, 9))
308
+ gs = gridspec.GridSpec(2, 1, height_ratios=[4, 3])
309
+
310
+ ax1 = plt.subplot(gs[0])
311
+ ax1.axvspan(x.min(), -4, facecolor='lightgray', alpha=0.4)
312
+ ax1.axvspan(4, x.max(), facecolor='lightgray', alpha=0.4)
313
+ ax1.plot(x, gt, color='blue', linestyle='--', label='Mean E[y|x]', linewidth=2)
314
+ ax1.plot(x, gt - gt_down, color='darkorange', linestyle='--', label='Lower bound GT', linewidth=2)
315
+ ax1.plot(x, gt + gt_up, color='purple', linestyle='--', label='Upper bound GT', linewidth=2)
316
+ ax1.scatter(x_test, y_test, color='green', s=5, label='Test Data')
317
+ ax1.plot(x_test, mean, color='red', label='Point Prediction')
318
+ ax1.fill_between(x_test.flatten(), lower_bound, mean, color='orange', alpha=0.3, label='Lower Interval')
319
+ ax1.fill_between(x_test.flatten(), mean, upper_bound, color='purple', alpha=0.4, label='Upper Interval')
320
+ ax1.legend(fontsize=9)
321
+ ax1.set_ylim(-150, 150)
322
+
323
+ ax2 = plt.subplot(gs[1])
324
+ ax2.axvspan(x.min(), -4, facecolor='lightgray', alpha=0.4, label='OOD Region')
325
+ ax2.axvspan(4, x.max(), facecolor='lightgray', alpha=0.4)
326
+ ax2.plot(x_test, uncertainty, label='Epistemic uncertainty', color='dodgerblue')
327
+ ax2.fill_between(x_test.squeeze(), uncertainty.squeeze(), alpha=0.3, color='dodgerblue')
328
+ ax2.legend()
329
+ plt.tight_layout()
330
+ plt.show()
331
+
332
+ # Call visualizations after metric evaluations
333
+ bins, bin_edges = binning(y_low_id, y_high_id, num_bins=5)
334
+ plot_binned_intervals(x_test_id, y_test_id, y_mean_id, y_low_id, y_high_id, bins)
335
+ plot_intervals(x_test, y_test, mean, lower_bound, upper_bound, uncertainty)
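For orientation, the interval metrics the script prints (MPIW and PICP) reduce to a few lines of NumPy. Below is a minimal, self-contained sketch of those same definitions on synthetic Gaussian data; it is not part of the commit, and the ±2 bounds are illustrative stand-ins for a model's predicted interval:
```
import numpy as np

rng = np.random.default_rng(0)
y = rng.normal(size=1000)                    # synthetic targets
mean = np.zeros(1000)                        # point predictions
lower, upper = mean - 2.0, mean + 2.0        # ~95% interval for N(0, 1)

mpiw = np.mean(upper - lower)                # Mean Prediction Interval Width
picp = np.mean((y >= lower) & (y <= upper))  # Prediction Interval Coverage Probability
print(f"MPIW={mpiw:.3f}  PICP={picp:.3f}")   # PICP should land near 0.95
```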
SPC-UQ/Image_Classification/README.md ADDED
@@ -0,0 +1,285 @@
1
+ # Deep Deterministic Uncertainty
2
+
3
+ [![arXiv](https://img.shields.io/badge/stat.ML-arXiv%3A2102.11582-B31B1B.svg)](https://arxiv.org/abs/2102.11582)
4
+ [![Pytorch 1.8.1](https://img.shields.io/badge/pytorch-1.8.1-blue.svg)](https://pytorch.org/)
5
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/omegafragger/DDU/blob/main/LICENSE)
6
+
7
+ This repository contains the code for [*Deterministic Neural Networks with Appropriate Inductive Biases Capture Epistemic and Aleatoric Uncertainty*](https://arxiv.org/abs/2102.11582).
8
+
9
+ If the code or the paper has been useful in your research, please add a citation to our work:
10
+
11
+ ```
12
+ @article{mukhoti2021deterministic,
13
+ title={Deterministic Neural Networks with Appropriate Inductive Biases Capture Epistemic and Aleatoric Uncertainty},
14
+ author={Mukhoti, Jishnu and Kirsch, Andreas and van Amersfoort, Joost and Torr, Philip HS and Gal, Yarin},
15
+ journal={arXiv preprint arXiv:2102.11582},
16
+ year={2021}
17
+ }
18
+ ```
19
+
20
+ ## Dependencies
21
+
22
+ The code is based on PyTorch and requires a few further dependencies, listed in [environment.yml](environment.yml). It should work with newer versions as well.
23
+
24
+
25
+ ## OoD Detection
26
+
27
+ ### Datasets
28
+
29
+ For OoD detection, you can train on [*CIFAR-10/100*](https://www.cs.toronto.edu/~kriz/cifar.html). You can also train on [*Dirty-MNIST*](https://blackhc.github.io/ddu_dirty_mnist/) by downloading *Ambiguous-MNIST* (```amnist_labels.pt``` and ```amnist_samples.pt```) from [here](https://github.com/BlackHC/ddu_dirty_mnist/releases/tag/data-v0.6.0) and using the following training instructions.
30
+
31
+ ### Training
32
+
33
+ To train a model for the OoD detection task, use the [train.py](train.py) script. The following are the main training parameters:
34
+ ```
35
+ --seed: seed for initialization
36
+ --dataset: dataset used for training (cifar10/cifar100/dirty_mnist)
37
+ --dataset-root: /path/to/amnist_labels.pt and amnist_samples.pt/ (if training on dirty-mnist)
38
+ --model: model to train (wide_resnet/vgg16/resnet18/resnet50/lenet)
39
+ -sn: whether to use spectral normalization (available for wide_resnet, vgg16 and resnets)
40
+ --coeff: Coefficient for spectral normalization
41
+ -mod: whether to use architectural modifications (leaky ReLU + average pooling in skip connections)
42
+ --save-path: path/for/saving/model/
43
+ ```
44
+
45
+ As an example, in order to train a Wide-ResNet-28-10 with spectral normalization and architectural modifications on CIFAR-10, use the following:
46
+ ```
47
+ python train.py \
48
+ --seed 1 \
49
+ --dataset cifar10 \
50
+ --model wide_resnet \
51
+ -sn -mod \
52
+ --coeff 3.0
53
+ ```
54
+ Similarly, to train a ResNet-18 with spectral normalization on Dirty-MNIST, use:
55
+ ```
56
+ python train.py \
57
+ --seed 1 \
58
+ --dataset dirty-mnist \
59
+ --dataset-root /home/user/amnist/ \
60
+ --model resnet18 \
61
+ -sn \
62
+ --coeff 3.0
63
+ ```
64
+
65
+ ### Evaluation
66
+
67
+ To evaluate trained models, use [evaluate.py](evaluate.py). This script can evaluate and aggregate results over multiple experimental runs. For example, if the pretrained models are stored in a directory path ```/home/user/models```, store them using the following directory structure:
68
+ ```
69
+ models
70
+ ├── Run1
71
+ │   └── wide_resnet_1_350.model
72
+ ├── Run2
73
+ │   └── wide_resnet_2_350.model
74
+ ├── Run3
75
+ │   └── wide_resnet_3_350.model
76
+ ├── Run4
77
+ │   └── wide_resnet_4_350.model
78
+ └── Run5
79
+ └── wide_resnet_5_350.model
80
+ ```
81
+ For an ensemble of models, store the models using the following directory structure:
82
+ ```
83
+ model_ensemble
84
+ ├── Run1
85
+ │   ├── wide_resnet_1_350.model
86
+ │   ├── wide_resnet_2_350.model
87
+ │   ├── wide_resnet_3_350.model
88
+ │   ├── wide_resnet_4_350.model
89
+ │   └── wide_resnet_5_350.model
90
+ ├── Run2
91
+ │   ├── wide_resnet_10_350.model
92
+ │   ├── wide_resnet_6_350.model
93
+ │   ├── wide_resnet_7_350.model
94
+ │   ├── wide_resnet_8_350.model
95
+ │   └── wide_resnet_9_350.model
96
+ ├── Run3
97
+ │   ├── wide_resnet_11_350.model
98
+ │   ├── wide_resnet_12_350.model
99
+ │   ├── wide_resnet_13_350.model
100
+ │   ├── wide_resnet_14_350.model
101
+ │   └── wide_resnet_15_350.model
102
+ ├── Run4
103
+ │   ├── wide_resnet_16_350.model
104
+ │   ├── wide_resnet_17_350.model
105
+ │   ├── wide_resnet_18_350.model
106
+ │   ├── wide_resnet_19_350.model
107
+ │   └── wide_resnet_20_350.model
108
+ └── Run5
109
+ ├── wide_resnet_21_350.model
110
+ ├── wide_resnet_22_350.model
111
+ ├── wide_resnet_23_350.model
112
+ ├── wide_resnet_24_350.model
113
+ └── wide_resnet_25_350.model
114
+ ```
115
+ The following are the main evaluation parameters:
116
+ ```
117
+ --seed: seed used for initializing the first trained model
118
+ --dataset: dataset used for training (cifar10/cifar100)
119
+ --ood_dataset: OoD dataset to compute AUROC
120
+ --load-path: /path/to/pretrained/models/
121
+ --model: model architecture to load (wide_resnet/vgg16)
122
+ --runs: number of experimental runs
123
+ -sn: whether the model was trained using spectral normalization
124
+ --coeff: Coefficient for spectral normalization
125
+ -mod: whether the model was trained using architectural modifications
126
+ --ensemble: number of models in the ensemble
127
+ --model-type: type of model to load for evaluation (softmax/ensemble/gmm)
128
+ ```
129
+ As an example, in order to evaluate a Wide-ResNet-28-10 with spectral normalization and architectural modifications on CIFAR-10 with OoD dataset as SVHN, use the following:
130
+ ```
131
+ python evaluate.py \
132
+ --seed 1 \
133
+ --dataset cifar10 \
134
+ --ood_dataset svhn \
135
+ --load-path /path/to/pretrained/models/ \
136
+ --model wide_resnet \
137
+ --runs 5 \
138
+ -sn -mod \
139
+ --coeff 3.0 \
140
+ --model-type softmax
141
+ ```
142
+ Similarly, to evaluate the above model using feature density, set ```--model-type gmm```. The evaluation script assumes that the seeds of models trained in consecutive runs differ by 1. The script stores the results in a json file with the following structure:
143
+ ```
144
+ {
145
+ "mean": {
146
+ "accuracy": mean accuracy,
147
+ "ece": mean ECE,
148
+ "m1_auroc": mean AUROC using log density / MI for ensembles,
149
+ "m1_auprc": mean AUPRC using log density / MI for ensembles,
150
+ "m2_auroc": mean AUROC using entropy / PE for ensembles,
151
+ "m2_auprc": mean AUPRC using entropy / PE for ensembles,
152
+ "t_ece": mean ECE (post temp scaling)
153
+ "t_m1_auroc": mean AUROC using log density / MI for ensembles (post temp scaling),
154
+ "t_m1_auprc": mean AUPRC using log density / MI for ensembles (post temp scaling),
155
+ "t_m2_auroc": mean AUROC using entropy / PE for ensembles (post temp scaling),
156
+ "t_m2_auprc": mean AUPRC using entropy / PE for ensembles (post temp scaling)
157
+ },
158
+ "std": {
159
+ "accuracy": std error accuracy,
160
+ "ece": std error ECE,
161
+ "m1_auroc": std error AUROC using log density / MI for ensembles,
162
+ "m1_auprc": std error AUPRC using log density / MI for ensembles,
163
+ "m2_auroc": std error AUROC using entropy / PE for ensembles,
164
+ "m2_auprc": std error AUPRC using entropy / PE for ensembles,
165
+ "t_ece": std error ECE (post temp scaling),
166
+ "t_m1_auroc": std error AUROC using log density / MI for ensembles (post temp scaling),
167
+ "t_m1_auprc": std error AUPRC using log density / MI for ensembles (post temp scaling),
168
+ "t_m2_auroc": std error AUROC using entropy / PE for ensembles (post temp scaling),
169
+ "t_m2_auprc": std error AUPRC using entropy / PE for ensembles (post temp scaling)
170
+ },
171
+ "values": {
172
+ "accuracy": accuracy list,
173
+ "ece": ece list,
174
+ "m1_auroc": AUROC list using log density / MI for ensembles,
175
+ "m2_auroc": AUROC list using entropy / PE for ensembles,
176
+ "t_ece": ece list (post temp scaling),
177
+ "t_m1_auroc": AUROC list using log density / MI for ensembles (post temp scaling),
178
+ "t_m1_auprc": AUPRC list using log density / MI for ensembles (post temp scaling),
179
+ "t_m2_auroc": AUROC list using entropy / PE for ensembles (post temp scaling),
180
+ "t_m2_auprc": AUPRC list using entropy / PE for ensembles (post temp scaling)
181
+ },
182
+ "info": {dictionary of args}
183
+ }
184
+ ```
185
+
186
+ ### Results
187
+
188
+ #### Dirty-MNIST
189
+
190
+ To visualise DDU's performance on Dirty-MNIST (i.e., Fig. 1 of the paper), use [fig_1_plot.ipynb](notebooks/fig_1_plot.ipynb). The notebook requires a pretrained LeNet, VGG-16 and ResNet-18 with spectral normalization trained on Dirty-MNIST and visualises the softmax entropy and feature density for Dirty-MNIST (iD) samples vs Fashion-MNIST (OoD) samples. The notebook also visualises the softmax entropies of MNIST vs Ambiguous-MNIST samples for the ResNet-18+SN model (Fig. 2 of the paper). The following figure shows the output of the notebook for the LeNet, VGG-16 and ResNet18+SN model we trained on Dirty-MNIST.
191
+
192
+ <p align="center">
193
+ <img src="vis/dirty_mnist_vis.png" width="500" />
194
+ </p>
195
+
196
+ #### CIFAR-10 vs SVHN
197
+
198
+ The following table presents results for a Wide-ResNet-28-10 architecture trained on CIFAR-10 with SVHN as the OoD dataset. For the full set of results, refer to the [paper](https://arxiv.org/abs/2102.11582).
199
+
200
+ | Method | Aleatoric Uncertainty | Epistemic Uncertainty | Test Accuracy | Test ECE | AUROC |
201
+ | --- | --- | --- | --- | --- | --- |
202
+ | Softmax | Softmax Entropy | Softmax Entropy | 95.98±0.02 | 0.85±0.02 | 94.44±0.43 |
203
+ | [Energy-based](https://arxiv.org/abs/2010.03759) | Softmax Entropy | Softmax Density | 95.98±0.02 | 0.85±0.02 | 94.56±0.51 |
204
+ | [5-Ensemble](https://arxiv.org/abs/1612.01474) | Predictive Entropy | Predictive Entropy | 96.59±0.02 | 0.76±0.03 | 97.73±0.31 |
205
+ | DDU (ours) | Softmax Entropy | GMM Density | 95.97±0.03 | 0.85±0.04 | 98.09±0.10 |
206
+
207
+
208
+ ## Active Learning
209
+
210
+ To run active learning experiments, use ```active_learning_script.py```. You can run active learning experiments on both [MNIST](http://yann.lecun.com/exdb/mnist/) as well as [Dirty-MNIST](https://blackhc.github.io/ddu_dirty_mnist/). When running with Dirty-MNIST, you will need to provide a pretrained model on Dirty-MNIST to distinguish between clean MNIST and Ambiguous-MNIST samples. The following are the main command line arguments for ```active_learning_script.py```.
211
+ ```
212
+ --seed: seed used for initializing the first model (later experimental runs will have seeds incremented by 1)
213
+ --model: model architecture to train (resnet18)
214
+ -ambiguous: whether to use ambiguous MNIST during training. If this is set to True, the models will be trained on Dirty-MNIST, otherwise they will train on MNIST.
215
+ --dataset-root: /path/to/amnist_labels.pt and amnist_samples.pt/
216
+ --trained-model: model architecture of pretrained model to distinguish clean and ambiguous MNIST samples
217
+ -tsn: if pretrained model has been trained using spectral normalization
218
+ --tcoeff: coefficient of spectral normalization used on pretrained model
219
+ -tmod: if pretrained model has been trained using architectural modifications (leaky ReLU and average pooling on skip connections)
220
+ --saved-model-path: /path/to/saved/pretrained/model/
221
+ --saved-model-name: name of the saved pretrained model file
222
+ --threshold: Threshold of softmax entropy to decide if a sample is ambiguous (samples having higher softmax entropy than threshold will be considered ambiguous)
223
+ --subsample: number of clean MNIST samples to use to subsample clean MNIST
224
+ -sn: whether to use spectral normalization during training
225
+ --coeff: coefficient of spectral normalization during training
226
+ -mod: whether to use architectural modifications (leaky ReLU and average pooling on skip connections) during training
227
+ --al-type: type of active learning acquisition model (softmax/ensemble/gmm)
228
+ -mi: whether to use mutual information for ensemble al-type
229
+ --num-initial-samples: number of initial samples in the training set
230
+ --max-training-samples: maximum number of training samples
231
+ --acquisition-batch-size: batch size for each acquisition step
232
+ ```
233
+
234
+ As an example, to run the active learning experiment on MNIST using the DDU method, use:
235
+ ```
236
+ python active_learning_script.py \
237
+ --seed 1 \
238
+ --model resnet18 \
239
+ -sn -mod \
240
+ --al-type gmm
241
+ ```
242
+ Similarly, to run the active learning experiment on Dirty-MNIST using the DDU baseline, with a pretrained ResNet-18 with SN to distinguish clean and ambiguous MNIST samples, use the following:
243
+ ```
244
+ python active_learning_script.py \
245
+ --seed 1 \
246
+ --model resnet18 \
247
+ -sn -mod \
248
+ -ambiguous \
249
+ --dataset-root /home/user/amnist/ \
250
+ --trained-model resnet18 \
251
+ -tsn \
252
+ --saved-model-path /path/to/pretrained/model \
253
+ --saved-model-name resnet18_sn_3.0_1_350.model \
254
+ --threshold 1.0 \
255
+ --subsample 1000 \
256
+ --al-type gmm
257
+ ```
258
+
259
+ ### Results
260
+
261
+ The active learning script stores all results in json files. The MNIST test set accuracy is stored in a json file with the following structure:
262
+ ```
263
+ {
264
+ "experiment run": list of MNIST test set accuracies, one per acquisition step
265
+ }
266
+ ```
267
+ When using ambiguous samples in the pool set, the script also stores the fraction of ambiguous samples acquired in each step in the following json:
268
+ ```
269
+ {
270
+ "experiment run": list of fractions of ambiguous samples in the acquired training set
271
+ }
272
+ ```
273
+
274
+ ### Visualisation
275
+
276
+ To visualise results from the above json files, use the [al_plot.ipynb](notebooks/al_plot.ipynb) notebook. The following diagram shows the performance of different baselines (softmax, ensemble PE, ensemble MI and DDU) on MNIST and Dirty-MNIST.
277
+
278
+ <p align="center">
279
+ <img src="vis/al_plots.png" width="700" />
280
+ </p>
281
+
282
+
283
+ ## Questions
284
+
285
+ For any questions, please feel free to raise an issue or email us directly. Our emails can be found on the [paper](https://arxiv.org/abs/2102.11582).
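The aggregated results file described in the Evaluation section can be consumed directly. A minimal sketch, assuming the output was saved as `res.json` (a hypothetical name; the script's actual filename depends on its arguments):
```
import json

with open("res.json") as f:   # hypothetical filename
    res = json.load(f)

for key in ("accuracy", "ece", "m1_auroc", "m2_auroc"):
    print(f"{key}: {res['mean'][key]:.4f} +/- {res['std'][key]:.4f}")
```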
SPC-UQ/Image_Classification/data/__init__.py ADDED
File without changes
SPC-UQ/Image_Classification/data/ood_detection/__init__.py ADDED
File without changes
SPC-UQ/Image_Classification/data/ood_detection/cifar10.py ADDED
@@ -0,0 +1,107 @@
1
+ """
2
+ Create train, valid, test iterators for CIFAR-10.
3
+ Train set size: 45000
4
+ Val set size: 5000
5
+ Test set size: 10000
6
+ """
7
+
8
+ import torch
9
+ import numpy as np
10
+ from torch.utils.data import Subset
11
+
12
+ from torchvision import datasets
13
+ from torchvision import transforms
14
+
15
+
16
+ def get_train_valid_loader(batch_size, augment, val_seed, imagesize=128, val_size=0.1, num_workers=1, pin_memory=False, **kwargs):
17
+ """
18
+ Utility function for loading and returning train and valid
19
+ multi-process iterators over the CIFAR-10 dataset.
20
+ Params:
21
+ ------
22
+ - batch_size: how many samples per batch to load.
23
+ - augment: whether to apply the data augmentation scheme
24
+ mentioned in the paper. Only applied on the train split.
25
+ - val_seed: fix seed for reproducibility.
26
+ - val_size: percentage split of the training set used for
27
+ the validation set. Should be a float in the range [0, 1].
28
+ - num_workers: number of subprocesses to use when loading the dataset.
29
+ - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
30
+ True if using GPU.
31
+
32
+ Returns
33
+ -------
34
+ - train_loader: training set iterator.
35
+ - valid_loader: validation set iterator.
36
+ """
37
+ error_msg = "[!] val_size should be in the range [0, 1]."
38
+ assert (val_size >= 0) and (val_size <= 1), error_msg
39
+
40
+ normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010],)
41
+
42
+ # define transforms
43
+ valid_transform = transforms.Compose([transforms.Resize(imagesize),transforms.ToTensor(), normalize,])
44
+
45
+ if augment:
46
+ train_transform = transforms.Compose(
47
+ [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),transforms.Resize(imagesize), transforms.ToTensor(), normalize,]
48
+ )
49
+ else:
50
+ train_transform = transforms.Compose([transforms.Resize(imagesize),transforms.ToTensor(), normalize,])
51
+
52
+ # load the dataset
53
+ data_dir = "./data"
54
+ train_dataset = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform,)
55
+
56
+ valid_dataset = datasets.CIFAR10(root=data_dir, train=True, download=False, transform=valid_transform,)
57
+
58
+ num_train = len(train_dataset)
59
+ indices = list(range(num_train))
60
+ split = int(np.floor(val_size * num_train))
61
+
62
+ np.random.seed(val_seed)
63
+ np.random.shuffle(indices)
64
+
65
+ train_idx, valid_idx = indices[split:], indices[:split]
66
+
67
+ train_subset = Subset(train_dataset, train_idx)
68
+ valid_subset = Subset(valid_dataset, valid_idx)
69
+
70
+ train_loader = torch.utils.data.DataLoader(
71
+ train_subset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=True,
72
+ )
73
+ valid_loader = torch.utils.data.DataLoader(
74
+ valid_subset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=False,
75
+ )
76
+
77
+ return (train_loader, valid_loader)
78
+
79
+
80
+ def get_test_loader(batch_size, imagesize=128, num_workers=1, pin_memory=False, **kwargs):
81
+ """
82
+ Utility function for loading and returning a multi-process
83
+ test iterator over the CIFAR-10 dataset.
84
+ If using CUDA, num_workers should be set to 1 and pin_memory to True.
85
+ Params
86
+ ------
87
+ - batch_size: how many samples per batch to load.
88
+ - num_workers: number of subprocesses to use when loading the dataset.
89
+ - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
90
+ True if using GPU.
91
+ Returns
92
+ -------
93
+ - data_loader: test set iterator.
94
+ """
95
+ normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010],)
96
+
97
+ # define transform
98
+ transform = transforms.Compose([transforms.Resize(imagesize), transforms.ToTensor(), normalize,])
99
+
100
+ data_dir = "./data"
101
+ dataset = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform,)
102
+
103
+ data_loader = torch.utils.data.DataLoader(
104
+ dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory,
105
+ )
106
+
107
+ return data_loader
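A minimal usage sketch for these loaders (assumes torchvision can download CIFAR-10 into `./data`; the batch size and image size below are illustrative):
```
import data.ood_detection.cifar10 as cifar10

train_loader, valid_loader = cifar10.get_train_valid_loader(
    batch_size=128, augment=True, val_seed=1, imagesize=32
)
test_loader = cifar10.get_test_loader(batch_size=128, imagesize=32)

images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # torch.Size([128, 3, 32, 32]) torch.Size([128])
```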
SPC-UQ/Image_Classification/data/ood_detection/cifar100.py ADDED
@@ -0,0 +1,107 @@
1
+ """
2
+ Create train, valid, test iterators for CIFAR-100.
3
+ Train set size: 45000
4
+ Val set size: 5000
5
+ Test set size: 10000
6
+ """
7
+
8
+ import torch
9
+ import numpy as np
10
+ from torch.utils.data import Subset
11
+
12
+ from torchvision import datasets
13
+ from torchvision import transforms
14
+
15
+
16
+ def get_train_valid_loader(batch_size, augment, val_seed, imagesize=128, val_size=0.1, num_workers=1, pin_memory=False, **kwargs):
17
+ """
18
+ Utility function for loading and returning train and valid
19
+ multi-process iterators over the CIFAR-100 dataset.
20
+ Params:
21
+ ------
22
+ - batch_size: how many samples per batch to load.
23
+ - augment: whether to apply the data augmentation scheme
24
+ mentioned in the paper. Only applied on the train split.
25
+ - val_seed: fix seed for reproducibility.
26
+ - val_size: percentage split of the training set used for
27
+ the validation set. Should be a float in the range [0, 1].
28
+ - num_workers: number of subprocesses to use when loading the dataset.
29
+ - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
30
+ True if using GPU.
31
+
32
+ Returns
33
+ -------
34
+ - train_loader: training set iterator.
35
+ - valid_loader: validation set iterator.
36
+ """
37
+ error_msg = "[!] val_size should be in the range [0, 1]."
38
+ assert (val_size >= 0) and (val_size <= 1), error_msg
39
+
40
+ normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010],)  # NOTE: these are the CIFAR-10 statistics, reused here for CIFAR-100
41
+
42
+ # define transforms
43
+ valid_transform = transforms.Compose([transforms.Resize(imagesize),transforms.ToTensor(), normalize,])
44
+
45
+ if augment:
46
+ train_transform = transforms.Compose(
47
+ [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.Resize(imagesize),transforms.ToTensor(), normalize,]
48
+ )
49
+ else:
50
+ train_transform = transforms.Compose([transforms.Resize(imagesize),transforms.ToTensor(), normalize,])
51
+
52
+ # load the dataset
53
+ data_dir = "./data"
54
+ train_dataset = datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform,)
55
+
56
+ valid_dataset = datasets.CIFAR100(root=data_dir, train=True, download=False, transform=valid_transform,)
57
+
58
+ num_train = len(train_dataset)
59
+ indices = list(range(num_train))
60
+ split = int(np.floor(val_size * num_train))
61
+
62
+ np.random.seed(val_seed)
63
+ np.random.shuffle(indices)
64
+
65
+ train_idx, valid_idx = indices[split:], indices[:split]
66
+
67
+ train_subset = Subset(train_dataset, train_idx)
68
+ valid_subset = Subset(valid_dataset, valid_idx)
69
+
70
+ train_loader = torch.utils.data.DataLoader(
71
+ train_subset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=True,
72
+ )
73
+ valid_loader = torch.utils.data.DataLoader(
74
+ valid_subset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=False,
75
+ )
76
+
77
+ return (train_loader, valid_loader)
78
+
79
+
80
+ def get_test_loader(batch_size, imagesize=128, num_workers=1, pin_memory=False, **kwargs):
81
+ """
82
+ Utility function for loading and returning a multi-process
83
+ test iterator over the CIFAR-100 dataset.
84
+ If using CUDA, num_workers should be set to 1 and pin_memory to True.
85
+ Params
86
+ ------
87
+ - batch_size: how many samples per batch to load.
88
+ - num_workers: number of subprocesses to use when loading the dataset.
89
+ - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
90
+ True if using GPU.
91
+ Returns
92
+ -------
93
+ - data_loader: test set iterator.
94
+ """
95
+ normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010],)
96
+
97
+ # define transform
98
+ transform = transforms.Compose([transforms.Resize(imagesize), transforms.ToTensor(), normalize,])
99
+
100
+ data_dir = "./data"
101
+ dataset = datasets.CIFAR100(root=data_dir, train=False, download=True, transform=transform,)
102
+
103
+ data_loader = torch.utils.data.DataLoader(
104
+ dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory,
105
+ )
106
+
107
+ return data_loader
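Since the `Normalize` constants above are the CIFAR-10 ones, the dataset's own statistics can be recomputed if needed. A minimal sketch (downloads CIFAR-100 into `./data`):
```
import numpy as np
from torchvision import datasets

ds = datasets.CIFAR100(root="./data", train=True, download=True)
data = ds.data / 255.0                # ds.data: uint8 (50000, 32, 32, 3); rescale to [0, 1]
print(data.mean(axis=(0, 1, 2)))      # per-channel mean, approx. [0.507, 0.487, 0.441]
print(data.std(axis=(0, 1, 2)))       # per-channel std
```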
SPC-UQ/Image_Classification/data/ood_detection/imagenet.py ADDED
@@ -0,0 +1,85 @@
1
+ """
2
+ Create train, valid, test iterators for ImageNet.
3
+ Train set size: user-defined
4
+ Val set size: user-defined
5
+ Test set size: user-defined (if available)
6
+ """
7
+
8
+ import torch
9
+ import numpy as np
10
+ from torch.utils.data import Subset
11
+ from torchvision import datasets, transforms
12
+
13
+
14
+ def get_train_valid_loader(batch_size, augment, val_seed, imagesize=224, val_size=0.1, num_workers=1, pin_memory=False, **kwargs):
15
+ assert 0 <= val_size <= 1, "[!] val_size should be in the range [0, 1]."
16
+
17
+ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
18
+ imagesize = 224  # fixed to the standard ImageNet input size, overriding the imagesize argument
19
+ # Define transformations
20
+ valid_transform = transforms.Compose([
21
+ transforms.Resize(256),
22
+ transforms.CenterCrop(imagesize),
23
+ transforms.ToTensor(),
24
+ normalize
25
+ ])
26
+ if augment:
27
+ # NOTE: the original assigned this Compose to `transform`, leaving
28
+ # `train_transform` undefined (a NameError) whenever augment=True.
29
+ # The pipeline itself is kept as written; it applies no stochastic augmentation.
30
+ train_transform = transforms.Compose([
31
+ transforms.Resize(256),
32
+ transforms.CenterCrop(imagesize),
33
+ transforms.ToTensor(),
34
+ normalize
35
+ ])
36
+ else:
37
+ train_transform = valid_transform
38
+
39
+ data_dir = "./data/Imagenet1K"
40
+ # Load the dataset
41
+ train_dataset = datasets.ImageFolder(root=f"{data_dir}/train", transform=train_transform)
42
+ valid_dataset = datasets.ImageFolder(root=f"{data_dir}/train", transform=valid_transform)
43
+
44
+ num_train = len(train_dataset)
45
+ indices = list(range(num_train))
46
+ split = int(np.floor(val_size * num_train))
47
+
48
+ np.random.seed(val_seed)
49
+ np.random.shuffle(indices)
50
+
51
+ train_idx, valid_idx = indices[split:], indices[:split]
52
+
53
+ train_subset = Subset(train_dataset, train_idx)
54
+ valid_subset = Subset(valid_dataset, valid_idx)
55
+
56
+ train_loader = torch.utils.data.DataLoader(
57
+ train_subset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=True,
58
+ )
59
+ valid_loader = torch.utils.data.DataLoader(
60
+ valid_subset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=False,
61
+ )
62
+
63
+ return train_loader, valid_loader
64
+
65
+
66
+ def get_test_loader(batch_size, imagesize=224, num_workers=1, pin_memory=False, **kwargs):
67
+
68
+ # Define transformation
69
+ transform = transforms.Compose([
70
+ transforms.Resize(256),
71
+ transforms.CenterCrop(224),
72
+ transforms.ToTensor(),
73
+ transforms.Normalize(
74
+ mean=[0.485, 0.456, 0.406],
75
+ std=[0.229, 0.224, 0.225]
76
+ )
77
+ ])
78
+ data_dir = "./data/Imagenet1K"
79
+ dataset = datasets.ImageFolder(root=f"{data_dir}/val", transform=transform)
80
+
81
+ data_loader = torch.utils.data.DataLoader(
82
+ dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory,
83
+ )
84
+
85
+ return data_loader
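Usage mirrors the CIFAR loaders. A minimal sketch, assuming an ImageFolder-style layout under `./data/Imagenet1K/train` and `./data/Imagenet1K/val`; it relies on the fixed `augment` branch above, so both branches now define `train_transform`:
```
import data.ood_detection.imagenet as imagenet

train_loader, valid_loader = imagenet.get_train_valid_loader(
    batch_size=64, augment=False, val_seed=1
)
test_loader = imagenet.get_test_loader(batch_size=64)
print(len(test_loader.dataset))  # number of validation images found on disk
```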
SPC-UQ/Image_Classification/data/ood_detection/imagenet_a.py ADDED
@@ -0,0 +1,37 @@
1
+ """
2
+ Create a test iterator for ImageNet-A (natural adversarial examples),
3
+ used as an OoD evaluation set for ImageNet classifiers.
4
+ Only a test loader is provided; the dataset is read from
5
+ ./data/imagenet-a via ImageFolder.
6
+ """
7
+
8
+ import torch
9
+ import numpy as np
10
+ from torch.utils.data import Subset
11
+ from torchvision import datasets, transforms
12
+
13
+
14
+
15
+ def get_test_loader(batch_size, imagesize=224, num_workers=1, pin_memory=False, **kwargs):
16
+ transform = transforms.Compose([
17
+ transforms.Resize(256),
18
+ transforms.CenterCrop(imagesize),
19
+ transforms.ToTensor(),
20
+ transforms.Normalize(
21
+ mean=[0.485, 0.456, 0.406],
22
+ std=[0.229, 0.224, 0.225]
23
+ )
24
+ ])
25
+
26
+ root = "./data/imagenet-a"
27
+ dataset = datasets.ImageFolder(
28
+ root=root,
29
+ transform=transform
30
+ )
31
+
32
+ data_loader = torch.utils.data.DataLoader(
33
+ dataset, batch_size=batch_size, shuffle=False,
34
+ num_workers=num_workers, pin_memory=pin_memory
35
+ )
36
+
37
+ return data_loader
SPC-UQ/Image_Classification/data/ood_detection/imagenet_o.py ADDED
@@ -0,0 +1,37 @@
1
+ """
2
+ Create a test iterator for ImageNet-O (out-of-distribution examples
3
+ for ImageNet classifiers), used as an OoD evaluation set.
4
+ Only a test loader is provided; the dataset is read from
5
+ ./data/imagenet-o via ImageFolder.
6
+ """
7
+
8
+ import torch
9
+ import numpy as np
10
+ from torch.utils.data import Subset
11
+ from torchvision import datasets, transforms
12
+
13
+
14
+
15
+ def get_test_loader(batch_size, imagesize=224, num_workers=1, pin_memory=False, **kwargs):
16
+ transform = transforms.Compose([
17
+ transforms.Resize(256),
18
+ transforms.CenterCrop(imagesize),
19
+ transforms.ToTensor(),
20
+ transforms.Normalize(
21
+ mean=[0.485, 0.456, 0.406],
22
+ std=[0.229, 0.224, 0.225]
23
+ )
24
+ ])
25
+
26
+ root = "./data/imagenet-o"
27
+ dataset = datasets.ImageFolder(
28
+ root=root,
29
+ transform=transform
30
+ )
31
+
32
+ data_loader = torch.utils.data.DataLoader(
33
+ dataset, batch_size=batch_size, shuffle=False,
34
+ num_workers=num_workers, pin_memory=pin_memory
35
+ )
36
+
37
+ return data_loader
SPC-UQ/Image_Classification/data/ood_detection/ood_union.py ADDED
@@ -0,0 +1,105 @@
1
+ import os
2
+ import torch
3
+ from torch.utils.data import Subset, ConcatDataset, DataLoader
4
+ import random
5
+
6
+ from torchvision import datasets, transforms
7
+ from torchvision.datasets import ImageFolder
8
+
9
+ def get_svhn_test_loader(batch_size, imagesize=128, num_workers=1, pin_memory=False, **kwargs):
10
+ normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010],)
11
+
12
+ # define transform
13
+ transform = transforms.Compose([transforms.Resize(imagesize), transforms.ToTensor(), normalize,])
14
+
15
+
16
+ data_dir = "./data"
17
+ dataset = datasets.SVHN(root=data_dir, split="test", download=True, transform=transform,)
18
+
19
+ data_loader = torch.utils.data.DataLoader(
20
+ dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory,
21
+ )
22
+
23
+ return data_loader
24
+
25
+
26
+ def get_tinyimagenet_test_loader(batch_size, imagesize=128, num_workers=1, pin_memory=False, **kwargs):
27
+
28
+ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
29
+
30
+ transform = transforms.Compose([
31
+ transforms.Resize(imagesize),
32
+ transforms.ToTensor(),
33
+ normalize
34
+ ])
35
+
36
+ data_dir = "./data/tinyimagenet"
37
+ test_dir = os.path.join(data_dir, "test")
38
+ dataset = ImageFolder(root=test_dir, transform=transform)
39
+
40
+ data_loader = torch.utils.data.DataLoader(
41
+ dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory
42
+ )
43
+ return data_loader
44
+
45
+
46
+ def get_cifar10_test_loader(batch_size, imagesize=128, num_workers=1, pin_memory=False, **kwargs):
47
+
48
+ normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010],)
49
+
50
+ # define transform
51
+ transform = transforms.Compose([transforms.Resize(imagesize), transforms.ToTensor(), normalize,])
52
+
53
+ data_dir = "./data"
54
+ dataset = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform,)
55
+
56
+ data_loader = torch.utils.data.DataLoader(
57
+ dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory,
58
+ )
59
+
60
+ return data_loader
61
+
62
+
63
+ def get_cifar100_test_loader(batch_size, imagesize=128, num_workers=1, pin_memory=False, **kwargs):
64
+
65
+ normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010],)
66
+
67
+ # define transform
68
+ transform = transforms.Compose([transforms.Resize(imagesize), transforms.ToTensor(), normalize,])
69
+
70
+ data_dir = "./data"
71
+ dataset = datasets.CIFAR100(root=data_dir, train=False, download=True, transform=transform,)
72
+
73
+ data_loader = torch.utils.data.DataLoader(
74
+ dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory,
75
+ )
76
+
77
+ return data_loader
78
+
79
+
80
+
81
+ def get_combined_ood_test_loader(batch_size, sample_seed, imagesize=128, num_workers=1, pin_memory=False, sample_size=10000, **kwargs):
82
+ svhn_ds = get_svhn_test_loader(batch_size=1, imagesize=imagesize).dataset
83
+ tiny_ds = get_tinyimagenet_test_loader(batch_size=1, imagesize=imagesize).dataset
84
+
85
+ combined_dataset = ConcatDataset([
86
+ svhn_ds,
87
+ tiny_ds
88
+ ])
89
+
90
+ # print(len(combined_dataset))
91
+
92
+ random.seed(sample_seed)
93
+ if sample_size is not None and sample_size < len(combined_dataset):
94
+ indices = random.sample(range(len(combined_dataset)), sample_size)
95
+ combined_dataset = Subset(combined_dataset, indices)
96
+
97
+ data_loader = DataLoader(
98
+ combined_dataset,
99
+ batch_size=batch_size,
100
+ shuffle=False,
101
+ num_workers=num_workers,
102
+ pin_memory=pin_memory,
103
+ )
104
+
105
+ return data_loader
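A minimal sketch of the combined OoD loader (assumes SVHN can download into `./data` and a Tiny-ImageNet ImageFolder tree exists at `./data/tinyimagenet/test`):
```
import data.ood_detection.ood_union as ood_union

# Union of the SVHN and Tiny-ImageNet test sets, subsampled to 10000 examples.
loader = ood_union.get_combined_ood_test_loader(
    batch_size=128, sample_seed=1, imagesize=32, sample_size=10000
)
print(len(loader.dataset))  # 10000 whenever the two source sets together exceed sample_size
```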
SPC-UQ/Image_Classification/data/ood_detection/svhn.py ADDED
@@ -0,0 +1,94 @@
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from torch.utils.data import Subset
5
+
6
+ from torchvision import datasets
7
+ from torchvision import transforms
8
+
9
+
10
+ def get_train_valid_loader(batch_size, augment, val_seed, imagesize=128, val_size=0.1, num_workers=1, pin_memory=False, **kwargs):
11
+ """
12
+ Utility function for loading and returning train and valid
13
+ multi-process iterators over the SVHN dataset.
14
+ Params:
15
+ ------
16
+ - batch_size: how many samples per batch to load.
17
+ - augment: whether to apply the data augmentation scheme
18
+ mentioned in the paper. Only applied on the train split.
19
+ - val_seed: fix seed for reproducibility.
20
+ - val_size: percentage split of the training set used for
21
+ the validation set. Should be a float in the range [0, 1].
22
+ - num_workers: number of subprocesses to use when loading the dataset.
23
+ - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
24
+ True if using GPU.
25
+
26
+ Returns
27
+ -------
28
+ - train_loader: training set iterator.
29
+ - valid_loader: validation set iterator.
30
+ """
31
+ error_msg = "[!] val_size should be in the range [0, 1]."
32
+ assert (val_size >= 0) and (val_size <= 1), error_msg
33
+
34
+ normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010],)
35
+
36
+ # define transforms (note: no augmentation is applied even when augment=True)
37
+ valid_transform = transforms.Compose([transforms.Resize(imagesize),transforms.ToTensor(), normalize,])
38
+
39
+ # load the dataset
40
+ data_dir = "./data"
41
+ train_dataset = datasets.SVHN(root=data_dir, split="train", download=True, transform=valid_transform,)
42
+
43
+ valid_dataset = datasets.SVHN(root=data_dir, split="train", download=True, transform=valid_transform,)
44
+
45
+ num_train = len(train_dataset)
46
+ indices = list(range(num_train))
47
+ split = int(np.floor(val_size * num_train))
48
+
49
+ np.random.seed(val_seed)
50
+ np.random.shuffle(indices)
51
+
52
+ train_idx, valid_idx = indices[split:], indices[:split]
53
+ train_subset = Subset(train_dataset, train_idx)
54
+ valid_subset = Subset(valid_dataset, valid_idx)
55
+
56
+ train_loader = torch.utils.data.DataLoader(
57
+ train_subset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=True,
58
+ )
59
+ valid_loader = torch.utils.data.DataLoader(
60
+ valid_subset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=False,
61
+ )
62
+
63
+ return (train_loader, valid_loader)
64
+
65
+
66
+ def get_test_loader(batch_size, imagesize=128, num_workers=1, pin_memory=False, **kwargs):
67
+ """
68
+ Utility function for loading and returning a multi-process
69
+ test iterator over the SVHN dataset.
70
+ If using CUDA, num_workers should be set to 1 and pin_memory to True.
71
+ Params
72
+ ------
73
+ - batch_size: how many samples per batch to load.
74
+ - num_workers: number of subprocesses to use when loading the dataset.
75
+ - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
76
+ True if using GPU.
77
+ Returns
78
+ -------
79
+ - data_loader: test set iterator.
80
+ """
81
+ normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010],)
82
+
83
+ # define transform
84
+ transform = transforms.Compose([transforms.Resize(imagesize), transforms.ToTensor(), normalize,])
85
+
86
+
87
+ data_dir = "./data"
88
+ dataset = datasets.SVHN(root=data_dir, split="test", download=True, transform=transform,)
89
+
90
+ data_loader = torch.utils.data.DataLoader(
91
+ dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory,
92
+ )
93
+
94
+ return data_loader
SPC-UQ/Image_Classification/data/ood_detection/tinyimagenet.py ADDED
@@ -0,0 +1,115 @@
1
+ """
2
+ Create train, valid, test iterators for Tiny-ImageNet.
3
+ Train set size: 90000 (450 per class)
4
+ Val set size: 10000 (50 per class)
5
+ Test set size: 10000 (no labels)
6
+ """
7
+
8
+ import torch
9
+ import numpy as np
10
+ import os
11
+ from torch.utils.data import Subset
12
+ from torchvision import datasets, transforms
13
+ from torchvision.datasets import ImageFolder
14
+
15
+
16
+ def get_train_valid_loader(batch_size, augment, val_seed, imagesize=128, val_size=0.1, num_workers=1, pin_memory=False, **kwargs):
17
+ """
18
+ Load and return train and valid iterators over the Tiny-ImageNet dataset.
19
+
20
+ Params:
21
+ ------
22
+ - data_dir: path to Tiny-ImageNet dataset directory.
23
+ - batch_size: number of samples per batch.
24
+ - augment: whether to apply data augmentation.
25
+ - val_seed: random seed for reproducibility.
26
+ - val_size: fraction of the training set used for validation (0 to 1).
27
+ - num_workers: number of subprocesses for data loading.
28
+ - pin_memory: set to True if using GPU.
29
+
30
+ Returns:
31
+ -------
32
+ - train_loader: training set iterator.
33
+ - valid_loader: validation set iterator.
34
+ """
35
+ assert 0 <= val_size <= 1, "[!] val_size should be in the range [0, 1]."
36
+
37
+ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
38
+
39
+ # Define transforms
40
+ valid_transform = transforms.Compose([
41
+ transforms.Resize(imagesize),
42
+ transforms.ToTensor(),
43
+ normalize
44
+ ])
45
+
46
+ if augment:
47
+ train_transform = transforms.Compose([
48
+ transforms.RandomCrop(64, padding=4),  # Tiny-ImageNet images are 64x64; the original 256 crop exceeds the input size and would raise an error
49
+ transforms.RandomHorizontalFlip(),
50
+ transforms.Resize(imagesize),
51
+ transforms.ToTensor(),
52
+ normalize
53
+ ])
54
+ else:
55
+ train_transform = valid_transform
56
+
57
+ # Load dataset
58
+ data_dir = "./data/tinyimagenet"
59
+ train_dir = os.path.join(data_dir, "train")
60
+ train_dataset = ImageFolder(root=train_dir, transform=train_transform)
61
+ valid_dataset = ImageFolder(root=train_dir, transform=valid_transform) # Same dataset, different transform
62
+
63
+ num_train = len(train_dataset)
64
+ indices = list(range(num_train))
65
+ split = int(np.floor(val_size * num_train))
66
+
67
+ np.random.seed(val_seed)
68
+ np.random.shuffle(indices)
69
+
70
+ train_idx, valid_idx = indices[split:], indices[:split]
71
+
72
+ train_subset = Subset(train_dataset, train_idx)
73
+ valid_subset = Subset(valid_dataset, valid_idx)
74
+
75
+ train_loader = torch.utils.data.DataLoader(
76
+ train_subset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=True
77
+ )
78
+ valid_loader = torch.utils.data.DataLoader(
79
+ valid_subset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=False
80
+ )
81
+
82
+ return train_loader, valid_loader
83
+
84
+
85
+ def get_test_loader(batch_size, imagesize=128, num_workers=1, pin_memory=False, **kwargs):
86
+ """
87
+ Load and return a test iterator over the Tiny-ImageNet dataset.
88
+
89
+ Params:
90
+ ------
91
+ - data_dir: path to Tiny-ImageNet dataset directory.
92
+ - batch_size: number of samples per batch.
93
+ - num_workers: number of subprocesses for data loading.
94
+ - pin_memory: set to True if using GPU.
95
+
96
+ Returns:
97
+ -------
98
+ - data_loader: test set iterator.
99
+ """
100
+ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
101
+
102
+ transform = transforms.Compose([
103
+ transforms.Resize(imagesize),
104
+ transforms.ToTensor(),
105
+ normalize
106
+ ])
107
+
108
+ data_dir = "./data/tinyimagenet"
109
+ test_dir = os.path.join(data_dir, "test")
110
+ dataset = ImageFolder(root=test_dir, transform=transform)
111
+
112
+ data_loader = torch.utils.data.DataLoader(
113
+ dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory
114
+ )
115
+ return data_loader
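These loaders expect an ImageFolder layout (one subdirectory per class) under `./data/tinyimagenet/train` and `./data/tinyimagenet/test`; the raw Tiny-ImageNet `val` split typically has to be restructured into that form first. A minimal sketch:
```
import data.ood_detection.tinyimagenet as tinyimagenet

train_loader, valid_loader = tinyimagenet.get_train_valid_loader(
    batch_size=128, augment=False, val_seed=1, imagesize=64
)
images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([128, 3, 64, 64])
```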
SPC-UQ/Image_Classification/environment.yml ADDED
@@ -0,0 +1,16 @@
1
+ name: DDU
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - python
7
+ - pytorch=1.7.1
8
+ - torchvision=0.8.2
9
+ - cudatoolkit=10.1
10
+ - tqdm
11
+ - tensorboard
12
+ - numpy
13
+ - scipy
14
+ - matplotlib
15
+ - seaborn
16
+ - scikit-learn
SPC-UQ/Image_Classification/evaluate.py ADDED
@@ -0,0 +1,1427 @@
1
+ """
2
+ Script to evaluate trained models and aggregate results over experimental runs.
3
+ """
4
+ import os
5
+ import gc
6
+ import json
7
+ import math
8
+ import torch
9
+ import argparse
10
+ import torch.backends.cudnn as cudnn
11
+ import numpy as np
12
+ import matplotlib.pyplot as plt
13
+ from sklearn.metrics import accuracy_score
14
+
15
+ # Import dataloaders
16
+ import data.ood_detection.cifar10 as cifar10
17
+ import data.ood_detection.cifar100 as cifar100
18
+ import data.ood_detection.svhn as svhn
19
+ import data.ood_detection.imagenet as imagenet
20
+ import data.ood_detection.tinyimagenet as tinyimagenet
21
+ import data.ood_detection.imagenet_o as imagenet_o
22
+ import data.ood_detection.imagenet_a as imagenet_a
23
+ import data.ood_detection.ood_union as ood_union
24
+
25
+ # Import network models
26
+ from net.resnet import resnet50
27
+ from net.resnet_edl import resnet50_edl
28
+ from net.wide_resnet import wrn
29
+ from net.wide_resnet_edl import wrn_edl
30
+ from net.wide_resnet_uq import wrn_uq
31
+ from net.vgg import vgg16
32
+ from net.vgg_edl import vgg16_edl
33
+ from net.vgg_uq import vgg16_uq
34
+ from net.imagenet_wide import imagenet_wide
35
+ from net.imagenet_vgg import imagenet_vgg16
36
+ from net.imagenet_vit import imagenet_vit
37
+
38
+ # Import metrics to compute
39
+ from metrics.classification_metrics import (
40
+ test_classification_net,
41
+ test_classification_net_logits,
42
+ test_classification_uq,
43
+ test_classification_net_ensemble,
44
+ test_classification_net_edl,
45
+ create_adversarial_dataloader,
46
+ test_classification_net_logits_edl,
47
+ test_classification_net_softmax
48
+ )
49
+ from metrics.calibration_metrics import expected_calibration_error
50
+ from metrics.uncertainty_confidence import entropy, logsumexp, self_consistency, edl_unc, certificate
51
+ from metrics.ood_metrics import get_roc_auc, get_roc_auc_logits, get_roc_auc_ensemble, get_unc_ensemble, get_roc_auc_uncs
52
+ from metrics.classification_metrics import get_logits_labels
53
+ from metrics.classification_metrics import get_logits_labels_uq
54
+
55
+ # Import GMM utils
56
+ from utils.gmm_utils import get_embeddings, gmm_evaluate, gmm_fit
57
+ from utils.ensemble_utils import load_ensemble, Ensemble_fit, Ensemble_evaluate, Ensemble_load
58
+ from utils.oc_utils import oc_fit, oc_evaluate
59
+ from utils.eval_utils import model_load_name
60
+ from utils.train_utils import model_save_name
61
+ from utils.args import eval_args
62
+
63
+ # Import SPC utils
64
+ from utils.spc_utils import SPC_fit, SPC_load, SPC_evaluate
65
+
66
+ # Import EDL utils
67
+ from utils.edl_utils import EDL_fit, EDL_load, EDL_evaluate
68
+
69
+ # Temperature scaling
70
+ from utils.temperature_scaling import ModelWithTemperature
71
+
72
+ # Dataset params
73
+ dataset_num_classes = {"cifar10": 10, "cifar100": 100, "svhn": 10, "imagenet": 1000, "tinyimagenet": 200, "imagenet_o":200, "imagenet_a":200}
74
+
75
+ dataset_loader = {"cifar10": cifar10, "cifar100": cifar100, "svhn": svhn, "imagenet": imagenet, "tinyimagenet": tinyimagenet, "imagenet_o":imagenet_o, "imagenet_a":imagenet_a}
76
+
77
+ # Mapping model name to model function
78
+ models = {"resnet50": resnet50, "resnet50_edl":resnet50_edl, "wide_resnet": wrn, "wide_resnet_edl": wrn_edl, "wide_resnet_uq": wrn_uq, "vgg16": vgg16, "vgg16_edl": vgg16_edl, "vgg16_uq": vgg16_uq, "imagenet_wide":imagenet_wide, "imagenet_vgg16":imagenet_vgg16, "imagenet_vit":imagenet_vit}
79
+
80
+ model_to_num_dim = {"resnet50": 2048, "resnet50_edl":2048, "wide_resnet": 640, "wide_resnet_edl": 640, "wide_resnet_uq": 640, "vgg16": 512, "vgg16_edl": 512, "vgg16_uq": 512, "imagenet_wide":2048, "imagenet_vgg16":4096, "imagenet_vit":768}
81
+
82
+ model_to_input_dim = {"resnet50": 32, "resnet50_edl": 32, "wide_resnet": 32, "wide_resnet_edl": 32, "wide_resnet_uq": 32, "vgg16": 32, "vgg16_edl": 32, "vgg16_uq": 32, "imagenet_wide":224, "imagenet_vgg16":224, "imagenet_vit":224}
83
+
84
+ model_to_last_layer = {"resnet50": "module.fc", "wide_resnet": "module.linear", "vgg16": "module.classifier", "imagenet_wide": "module.linear", "imagenet_vgg16": "module.classifier", "imagenet_vit": "module.linear"}
85
+
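The lookup tables above drive the rest of the script: the parsed `--model` and `--dataset` flags index into them to select a constructor, feature dimension, and input size, and `model_to_last_layer` stores dotted attribute paths that the SPC branch later resolves with a `getattr` walk. A minimal sketch of that resolution pattern, where `DummyNet` and its layers are hypothetical stand-ins for a `DataParallel`-wrapped backbone:

```python
import torch.nn as nn

class DummyNet(nn.Module):
    # Hypothetical stand-in: DataParallel wrapping puts the real model under .module.
    def __init__(self):
        super().__init__()
        self.module = nn.Sequential()
        self.module.fc = nn.Linear(2048, 10)

def resolve_attr(root, dotted_path):
    # Walk a dotted attribute path such as "module.fc", as the SPC branch does.
    obj = root
    for attr in dotted_path.split('.'):
        obj = getattr(obj, attr)
    return obj

net = DummyNet()
last_layer = resolve_attr(net, "module.fc")
print(type(last_layer).__name__)  # -> Linear
```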
+ if __name__ == "__main__":
+
+     args = eval_args().parse_args()
+
+     # Checking if GPU is available
+     cuda = torch.cuda.is_available()
+
+     # Setting additional parameters
+     print("Parsed args", args)
+     print("Seed: ", args.seed)
+     torch.manual_seed(args.seed)
+     device = torch.device("cuda" if cuda else "cpu")
+
+     # Taking input for the dataset
+     num_classes = dataset_num_classes[args.dataset]
+
+     test_loader = dataset_loader[args.dataset].get_test_loader(batch_size=args.batch_size, imagesize=model_to_input_dim[args.model], pin_memory=args.gpu)
+
+     if args.ood_dataset == 'ood_union':
+         ood_test_loader = ood_union.get_combined_ood_test_loader(batch_size=args.batch_size, sample_seed=args.seed, imagesize=model_to_input_dim[args.model], pin_memory=args.gpu)
+     else:
+         ood_test_loader = dataset_loader[args.ood_dataset].get_test_loader(batch_size=args.batch_size,
+                                                                            imagesize=model_to_input_dim[args.model],
+                                                                            pin_memory=args.gpu)
+
+     # Evaluating the models
+     accuracies = []
+     c_accuracies = []
+
+     # Pre temperature scaling
+     # m1 - Uncertainty/Confidence Metric 1
+     #      for deterministic model: logsumexp, for ensemble: entropy
+     # m2 - Uncertainty/Confidence Metric 2
+     #      for deterministic model: entropy, for ensemble: MI
+     eces = []
+     ood_m1_aurocs = []
+     ood_m1_auprcs = []
+     ood_m2_aurocs = []
+     ood_m2_auprcs = []
+
+     err_m1_aurocs = []
+     err_m1_auprcs = []
+     err_m2_aurocs = []
+     err_m2_auprcs = []
+
+     adv_ep = 0.02
+     adv_m1_aurocs = []
+     adv_m1_auprcs = []
+     adv_m2_aurocs = []
+     adv_m2_auprcs = []
+
+     # Post temperature scaling
+     t_eces = []
+     t_m1_aurocs = []
+     t_m1_auprcs = []
+     t_m2_aurocs = []
+     t_m2_auprcs = []
+
+     c_eces = []
+
+     adv_unc = np.zeros((args.runs, 9))
+     adv_acc = np.zeros((args.runs, 9))
+
+     topt = None
+
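The m1/m2 comment above names the two deterministic scores. `entropy` and `logsumexp` are imported from `metrics.uncertainty_confidence`, whose internals are not shown in this diff; a minimal sketch of the standard definitions the comment describes, which may differ in detail from the repo's implementations:

```python
import torch
import torch.nn.functional as F

def softmax_entropy(logits):
    # m2 for deterministic models: Shannon entropy of the softmax distribution.
    log_probs = F.log_softmax(logits, dim=1)
    return -(log_probs.exp() * log_probs).sum(dim=1)

def logsumexp_confidence(logits):
    # m1 for deterministic models: log-sum-exp over classes. For GMM/DDU-style
    # logits (per-class log densities) this is an unnormalised log marginal density.
    return torch.logsumexp(logits, dim=1)

logits = torch.randn(4, 10)
print(softmax_entropy(logits), logsumexp_confidence(logits))
```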
+     for i in range(args.runs):
+         print(f"Evaluating run: {i + 1}")
+         # Loading the model(s)
+         if args.model_type == "ensemble":
+             if args.dataset == 'imagenet':
+                 train_loader, val_loader = dataset_loader[args.dataset].get_train_valid_loader(
+                     batch_size=args.batch_size, imagesize=model_to_input_dim[args.model], augment=args.data_aug,
+                     val_seed=(args.seed + i), val_size=args.val_size, pin_memory=args.gpu)
+                 net = models[args.model](pretrained=True, num_classes=1000).cuda()
+                 if args.gpu:
+                     net.cuda()
+                     net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
+                     cudnn.benchmark = True
+                 net.eval()
+
+             else:
+                 val_loaders = []
+                 for j in range(args.ensemble):
+                     train_loader, val_loader = dataset_loader[args.dataset].get_train_valid_loader(
+                         batch_size=args.batch_size, imagesize=model_to_input_dim[args.model], augment=args.data_aug, val_seed=(args.seed+(5*i)+j), val_size=0.1, pin_memory=args.gpu,
+                     )
+                     val_loaders.append(val_loader)
+                 # Evaluate an ensemble
+                 ensemble_loc = args.load_loc
+                 net_ensemble = load_ensemble(
+                     ensemble_loc=ensemble_loc,
+                     model_name=args.model,
+                     device=device,
+                     num_classes=num_classes,
+                     spectral_normalization=args.sn,
+                     mod=args.mod,
+                     coeff=args.coeff,
+                     seed=(i)
+                 )
+
+         else:
+             train_loader, val_loader = dataset_loader[args.dataset].get_train_valid_loader(
+                 batch_size=args.batch_size, imagesize=model_to_input_dim[args.model], augment=args.data_aug, val_seed=(args.seed+i), val_size=args.val_size, pin_memory=args.gpu,
+             )
+             if args.dataset == 'imagenet':
+                 net = models[args.model](pretrained=True, num_classes=1000).cuda()
+                 if args.gpu:
+                     net.cuda()
+                     net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
+                     cudnn.benchmark = True
+                 net.eval()
+
+             else:
+                 if args.val_size == 0.1 or (not args.crossval):
+                     saved_model_name = os.path.join(
+                         args.load_loc,
+                         "Run" + str(i + 1),
+                         model_load_name(args.model, args.sn, args.mod, args.coeff, args.seed, i) + "_350.model",
+                     )
+                 else:
+                     saved_model_name = os.path.join(
+                         args.load_loc,
+                         "Run" + str(i + 1),
+                         model_load_name(args.model, args.sn, args.mod, args.coeff, args.seed, i) + "_350_0"+str(int(args.val_size*10))+".model",
+                     )
+                 print(saved_model_name)
+                 net = models[args.model](
+                     spectral_normalization=args.sn, mod=args.mod, coeff=args.coeff, num_classes=num_classes, temp=1.0,
+                 )
+                 if args.gpu:
+                     net.cuda()
+                     net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
+                     cudnn.benchmark = True
+                 net.load_state_dict(torch.load(str(saved_model_name)))
+                 net.eval()
+
+
+         # Evaluating the UQ method
+         if args.model_type == "ensemble":
+             if args.dataset == 'imagenet':
+                 ensemble_model_path = os.path.join(
+                     args.load_loc,
+                     "Run" + str(i + 1),
+                     model_load_name(args.model, args.sn, args.mod, args.coeff, args.seed, i) + "_350_ensemble_model.pth",
+                 )
+
+                 if os.path.exists(ensemble_model_path):
+                     print(f"Loading existing ensemble_model from {ensemble_model_path}")
+                     Ensemble_model = Ensemble_load(ensemble_model_path, model_to_num_dim[args.model],
+                                                    num_classes, device)
+                 else:
+                     if args.model == 'imagenet_vgg16':
+                         embed_path = 'data/imagenet_train_vgg_embedding.pt'
+                         # embed_path = 'data/imagenet_val_vgg_embedding.pt'
+                     if args.model == 'imagenet_wide':
+                         embed_path = 'data/imagenet_train_wide_embedding.pt'
+                         # embed_path = 'data/imagenet_val_wide_embedding.pt'
+                     if args.model == 'imagenet_vit':
+                         embed_path = 'data/imagenet_train_vit_embedding.pt'
+                         # embed_path = 'data/imagenet_val_vit_embedding.pt'
+                     if os.path.exists(embed_path):
+                         data = torch.load(embed_path, map_location=device)
+                         embeddings = data['embeddings']
+                         labels = data['labels']
+                     else:
+                         embeddings, labels = get_embeddings(
+                             net,
+                             train_loader,
+                             num_dim=model_to_num_dim[args.model],
+                             dtype=torch.double,
+                             device=device,
+                             storage_device=device,
+                         )
+                         torch.save({'embeddings': embeddings, 'labels': labels}, embed_path)
+                     Ensemble_model = Ensemble_fit(embeddings, labels, model_to_num_dim[args.model], num_classes, device)
+                     torch.save(Ensemble_model.state_dict(), ensemble_model_path)
+                     print(f"Model saved at {ensemble_model_path}")
+
+                 logits, predictive_entropy, mut_info, labels = Ensemble_evaluate(net, Ensemble_model, test_loader, model_to_num_dim[args.model], device)
+                 ood_logits, ood_predictive_entropy, ood_mut_info, ood_labels = Ensemble_evaluate(net, Ensemble_model, ood_test_loader, model_to_num_dim[args.model], device)
+                 (conf_matrix, accuracy, labels_list, predictions, confidences,) = test_classification_net_softmax(logits, labels)
+                 t_accuracy = accuracy
+                 ece = expected_calibration_error(confidences, predictions, labels_list, num_bins=15)
+                 t_ece = ece
+                 (_, _, _), (_, _, _), ood_m1_auroc, ood_m1_auprc = get_roc_auc_uncs(predictive_entropy, ood_predictive_entropy, device)
+                 (_, _, _), (_, _, _), ood_m2_auroc, ood_m2_auprc = get_roc_auc_uncs(mut_info, ood_mut_info, device)
+
+                 labels_array = np.array(labels_list)
+                 pred_array = np.array(predictions)
+                 correct_mask = labels_array == pred_array
+                 entropy_right = predictive_entropy[correct_mask]
+                 entropy_wrong = predictive_entropy[~correct_mask]
+                 mut_info_right = mut_info[correct_mask]
+                 mut_info_wrong = mut_info[~correct_mask]
+                 (_, _, _), (_, _, _), err_m1_auroc, err_m1_auprc = get_roc_auc_uncs(entropy_right, entropy_wrong, device)
+                 (_, _, _), (_, _, _), err_m2_auroc, err_m2_auprc = get_roc_auc_uncs(mut_info_right, mut_info_wrong, device)
+
+                 adv_test_loader = dataset_loader['imagenet_a'].get_test_loader(batch_size=args.batch_size,
+                                                                                imagesize=model_to_input_dim[args.model],
+                                                                                pin_memory=args.gpu)
+
+                 adv_logits, adv_predictive_entropy, adv_mut_info, adv_labels = Ensemble_evaluate(net, Ensemble_model, adv_test_loader, model_to_num_dim[args.model], device)
+
+
+                 (_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc_uncs(predictive_entropy, adv_predictive_entropy, device)
+                 (_, _, _), (_, _, _), adv_m2_auroc, adv_m2_auprc = get_roc_auc_uncs(mut_info, adv_mut_info, device)
+
+                 print('adv_m1_auroc', adv_m1_auroc)
+
+                 t_m1_auroc = ood_m1_auroc
+                 t_m1_auprc = ood_m1_auprc
+                 t_m2_auroc = ood_m2_auroc
+                 t_m2_auprc = ood_m2_auprc
+
+
+             else:
+                 (conf_matrix, accuracy, labels_list, predictions, confidences,) = test_classification_net_ensemble(
+                     net_ensemble, test_loader, device
+                 )
+                 ece = expected_calibration_error(confidences, predictions, labels_list, num_bins=15)
+
+                 (_, _, _), (_, _, _), ood_m1_auroc, ood_m1_auprc = get_roc_auc_ensemble(
+                     net_ensemble, test_loader, ood_test_loader, "entropy", device
+                 )
+                 (_, _, _), (_, _, _), ood_m2_auroc, ood_m2_auprc = get_roc_auc_ensemble(
+                     net_ensemble, test_loader, ood_test_loader, "mutual_information", device
+                 )
+
+                 labels_array = np.array(labels_list)
+                 pred_array = np.array(predictions)
+                 correct_mask = labels_array == pred_array
+                 from torch.utils.data import Subset, DataLoader
+                 dataset = test_loader.dataset
+                 correct_indices = np.where(correct_mask)[0]
+                 right_subset = Subset(dataset, correct_indices)
+                 right_loader = DataLoader(right_subset, batch_size=test_loader.batch_size, shuffle=False)
+                 wrong_indices = np.where(~correct_mask)[0]
+                 wrong_subset = Subset(dataset, wrong_indices)
+                 wrong_loader = DataLoader(wrong_subset, batch_size=test_loader.batch_size, shuffle=False)
+                 (_, _, _), (_, _, _), err_m1_auroc, err_m1_auprc = get_roc_auc_ensemble(
+                     net_ensemble, right_loader, wrong_loader, "entropy", device
+                 )
+                 (_, _, _), (_, _, _), err_m2_auroc, err_m2_auprc = get_roc_auc_ensemble(
+                     net_ensemble, right_loader, wrong_loader, "mutual_information", device
+                 )
+
+                 adv_test_loader = create_adversarial_dataloader(net_ensemble[0], test_loader, device, epsilon=adv_ep, batch_size=args.batch_size)
+                 (_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc_ensemble(
+                     net_ensemble, test_loader, adv_test_loader, "entropy", device
+                 )
+                 (_, _, _), (_, _, _), adv_m2_auroc, adv_m2_auprc = get_roc_auc_ensemble(
+                     net_ensemble, test_loader, adv_test_loader, "mutual_information", device
+                 )
+                 print('adv_m1_auroc,adv_m2_auroc', adv_m1_auroc, adv_m2_auroc)
+
+                 if args.sample_noise:
+                     adv_eps = np.linspace(0, 0.4, 9)
+                     print(adv_eps)
+                     for idx_ep, ep in enumerate(adv_eps):
+                         adv_loader = create_adversarial_dataloader(net_ensemble[0], test_loader, device, epsilon=ep,
+                                                                    batch_size=args.batch_size)
+                         (adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions, adv_confidences) = test_classification_net_ensemble(
+                             net_ensemble, adv_loader, device
+                         )
+                         uncertainties = get_unc_ensemble(net_ensemble, adv_loader, "entropy", device).detach().cpu().numpy()
+                         quantiles = np.quantile(uncertainties, np.linspace(0, 1, 10))
+                         quantiles = np.delete(quantiles, 0)
+                         unc_list = []
+                         accuracy_list = []
+                         for threshold in quantiles:
+                             cer_indices = (uncertainties < threshold)
+                             unc_indices = ~cer_indices
+                             labels_list = np.array(adv_labels_list)
+                             targets_cer = labels_list[cer_indices]
+                             predictions = np.array(adv_predictions)
+                             pred_cer = predictions[cer_indices]
+                             targets_unc = labels_list[unc_indices]
+                             pred_unc = predictions[unc_indices]
+                             cer_right = np.sum(targets_cer == pred_cer)
+                             cer = len(targets_cer)
+                             unc_right = np.sum(targets_unc == pred_unc)
+                             unc = len(targets_unc)
+                             accuracy_cer = cer_right / cer
+                             accuracy_unc = unc_right / unc
+                             unc_list.append(threshold)
+                             accuracy_list.append(accuracy_cer)
+                             print('ACC:', accuracy_cer, accuracy_unc)
+                         from scipy.stats import spearmanr
+
+                         Spearman_acc, p_acc = spearmanr(unc_list, accuracy_list)
+                         print("Spearman correlation:", Spearman_acc, "mean uncertainties:", uncertainties.mean())
+                         adv_unc[i][idx_ep] = uncertainties.mean()
+                         adv_acc[i][idx_ep] = adv_accuracy
+
+                     (_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc_ensemble(
+                         net_ensemble, test_loader, adv_loader, "entropy", device
+                     )
+                     (_, _, _), (_, _, _), adv_m2_auroc, adv_m2_auprc = get_roc_auc_ensemble(
+                         net_ensemble, test_loader, adv_loader, "mutual_information", device
+                     )
+                     print('adv_m1_auroc,adv_m2_auroc', adv_m1_auroc, adv_m2_auroc)
+
+                 # Temperature scale the ensemble
+                 t_ensemble = []
+                 for model, val_loader in zip(net_ensemble, val_loaders):
+                     t_model = ModelWithTemperature(model)
+                     t_model.set_temperature(val_loader)
+                     t_ensemble.append(t_model)
+
+                 (
+                     t_conf_matrix,
+                     t_accuracy,
+                     t_labels_list,
+                     t_predictions,
+                     t_confidences,
+                 ) = test_classification_net_ensemble(t_ensemble, test_loader, device)
+                 t_ece = expected_calibration_error(t_confidences, t_predictions, t_labels_list, num_bins=15)
+
+                 (_, _, _), (_, _, _), t_m1_auroc, t_m1_auprc = get_roc_auc_ensemble(
+                     t_ensemble, test_loader, ood_test_loader, "entropy", device
+                 )
+                 (_, _, _), (_, _, _), t_m2_auroc, t_m2_auprc = get_roc_auc_ensemble(
+                     t_ensemble, test_loader, ood_test_loader, "mutual_information", device
+                 )
+
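The two ensemble scores requested above ("entropy" and "mutual_information") are the standard decomposition of ensemble uncertainty into total and epistemic parts. A minimal sketch of that decomposition, assuming each member returns softmax probabilities; this illustrates the quantities, not the internals of the repo's `get_roc_auc_ensemble`:

```python
import torch

def ensemble_uncertainties(member_probs):
    # member_probs: (n_members, n_samples, n_classes) softmax outputs.
    mean_probs = member_probs.mean(dim=0)
    # Predictive entropy of the averaged distribution (total uncertainty).
    predictive_entropy = -(mean_probs * mean_probs.clamp_min(1e-12).log()).sum(dim=1)
    # Expected per-member entropy (aleatoric part).
    member_entropy = -(member_probs * member_probs.clamp_min(1e-12).log()).sum(dim=2).mean(dim=0)
    # Mutual information: member disagreement (epistemic part).
    mutual_information = predictive_entropy - member_entropy
    return predictive_entropy, mutual_information

probs = torch.softmax(torch.randn(5, 8, 10), dim=2)
pe, mi = ensemble_uncertainties(probs)
```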
+         elif args.model_type == "edl":
+             if args.dataset == 'imagenet':
+                 edl_model_path = os.path.join(
+                     args.load_loc,
+                     "Run" + str(i + 1),
+                     model_load_name(args.model, args.sn, args.mod, args.coeff, args.seed, i) + "_350_edl_model.pth",
+                 )
+
+                 if os.path.exists(edl_model_path):
+                     print(f"Loading existing edl_model from {edl_model_path}")
+                     EDL_model = EDL_load(edl_model_path, model_to_num_dim[args.model],
+                                          num_classes, device)
+                 else:
+                     if args.model == 'imagenet_vgg16':
+                         embed_path = 'data/imagenet_train_vgg_embedding.pt'
+                         # embed_path = 'data/imagenet_val_vgg_embedding.pt'
+                     if args.model == 'imagenet_wide':
+                         embed_path = 'data/imagenet_train_wide_embedding.pt'
+                         # embed_path = 'data/imagenet_val_wide_embedding.pt'
+                     if args.model == 'imagenet_vit':
+                         embed_path = 'data/imagenet_train_vit_embedding.pt'
+                         # embed_path = 'data/imagenet_val_vit_embedding.pt'
+                     if os.path.exists(embed_path):
+                         data = torch.load(embed_path, map_location=device)
+                         embeddings = data['embeddings']
+                         labels = data['labels']
+                     else:
+                         embeddings, labels = get_embeddings(
+                             net,
+                             train_loader,
+                             num_dim=model_to_num_dim[args.model],
+                             dtype=torch.double,
+                             device=device,
+                             storage_device=device,
+                         )
+                         torch.save({'embeddings': embeddings, 'labels': labels}, embed_path)
+                     EDL_model = EDL_fit(embeddings, labels, model_to_num_dim[args.model], num_classes,
+                                         device)
+                     torch.save(EDL_model.state_dict(), edl_model_path)
+                     print(f"Model saved at {edl_model_path}")
+
+                 logits, labels = EDL_evaluate(net, EDL_model, test_loader, model_to_num_dim[args.model], device)
+                 ood_logits, ood_labels = EDL_evaluate(net, EDL_model, ood_test_loader, model_to_num_dim[args.model], device)
+                 (conf_matrix, accuracy, labels_list, predictions, confidences,) = test_classification_net_logits_edl(logits, labels)
+                 t_accuracy = accuracy
+
+                 ece = expected_calibration_error(confidences, predictions, labels_list, num_bins=15)
+                 t_ece = ece
+
+                 (_, _, _), (_, _, _), ood_m1_auroc, ood_m1_auprc = get_roc_auc_logits(logits, ood_logits, edl_unc, device)
+
+                 labels_array = np.array(labels_list)
+                 pred_array = np.array(predictions)
+                 correct_mask = labels_array == pred_array
+                 # logits, _ = get_logits_labels(net, test_loader, device)
+                 logits_right = logits[correct_mask]
+                 logits_wrong = logits[~correct_mask]
+                 (_, _, _), (_, _, _), err_m1_auroc, err_m1_auprc = get_roc_auc_logits(logits_right, logits_wrong,
+                                                                                       edl_unc, device)
+
+                 adv_test_loader = dataset_loader['imagenet_a'].get_test_loader(batch_size=args.batch_size,
+                                                                                imagesize=model_to_input_dim[args.model],
+                                                                                pin_memory=args.gpu)
+                 # (adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions,adv_confidences,) = test_classification_net_edl(net, adv_test_loader, device)
+                 # adv_logits, _ = get_logits_labels(net, adv_test_loader, device)
+                 adv_logits, adv_labels = EDL_evaluate(net, EDL_model, adv_test_loader, model_to_num_dim[args.model], device)
+
+                 (_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc_logits(logits, adv_logits, edl_unc, device)
+                 print('adv_m1_auroc', adv_m1_auroc)
+
+                 if args.sample_noise:
+                     adv_eps = np.linspace(0, 0.4, 9)
+                     for idx_ep, ep in enumerate(adv_eps):
+                         adv_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=ep,
+                                                                    batch_size=args.batch_size)
+                         (adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions,
+                          adv_confidences,) = test_classification_net(net, adv_loader, device)
+                         adv_logits, adv_labels = EDL_evaluate(net, EDL_model, adv_loader,
+                                                               model_to_num_dim[args.model], device)
+                         uncertainties = edl_unc(adv_logits).detach().cpu().numpy()
+                         quantiles = np.quantile(uncertainties, np.linspace(0, 1, 10))
+                         quantiles = np.delete(quantiles, 0)
+                         unc_list = []
+                         accuracy_list = []
+                         for threshold in quantiles:
+                             cer_indices = (uncertainties < threshold)
+                             unc_indices = ~cer_indices
+                             labels_list = np.array(adv_labels_list)
+                             targets_cer = labels_list[cer_indices]
+                             predictions = np.array(adv_predictions)
+                             pred_cer = predictions[cer_indices]
+                             targets_unc = labels_list[unc_indices]
+                             pred_unc = predictions[unc_indices]
+                             cer_right = np.sum(targets_cer == pred_cer)
+                             cer = len(targets_cer)
+                             unc_right = np.sum(targets_unc == pred_unc)
+                             unc = len(targets_unc)
+                             accuracy_cer = cer_right / cer
+                             accuracy_unc = unc_right / unc
+                             unc_list.append(threshold)
+                             accuracy_list.append(accuracy_cer)
+                             print('ACC:', accuracy_cer, accuracy_unc)
+                         from scipy.stats import spearmanr
+
+                         Spearman_acc, p_acc = spearmanr(unc_list, accuracy_list)
+                         print("Spearman correlation:", Spearman_acc, "mean uncertainties:", uncertainties.mean())
+                         adv_unc[i][idx_ep] = uncertainties.mean()
+                         adv_acc[i][idx_ep] = adv_accuracy
+
+
+             else:
+                 (conf_matrix, accuracy, labels_list, predictions, confidences,) = test_classification_net_edl(
+                     net, test_loader, device
+                 )
+                 t_accuracy = accuracy
+                 ece = expected_calibration_error(confidences, predictions, labels_list, num_bins=15)
+                 t_ece = ece
+
+                 (_, _, _), (_, _, _), ood_m1_auroc, ood_m1_auprc = get_roc_auc(net, test_loader, ood_test_loader, edl_unc, device)
+
+                 labels_array = np.array(labels_list)
+                 pred_array = np.array(predictions)
+                 correct_mask = labels_array == pred_array
+                 logits, _ = get_logits_labels(net, test_loader, device)
+                 logits_right = logits[correct_mask]
+                 logits_wrong = logits[~correct_mask]
+                 (_, _, _), (_, _, _), err_m1_auroc, err_m1_auprc = get_roc_auc_logits(logits_right, logits_wrong, edl_unc, device)
+
+                 adv_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=adv_ep,
+                                                            batch_size=args.batch_size, edl=True)
+                 (adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions,
+                  adv_confidences,) = test_classification_net_edl(net, adv_loader, device)
+                 adv_logits, _ = get_logits_labels(net, adv_loader, device)
+
+                 (_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc_logits(logits, adv_logits, edl_unc, device)
+                 print('adv_m1_auroc', adv_m1_auroc)
+
+
+                 if args.sample_noise:
+                     adv_eps = np.linspace(0, 0.4, 9)
+                     for idx_ep, ep in enumerate(adv_eps):
+                         adv_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=ep,
+                                                                    batch_size=args.batch_size, edl=True)
+                         (adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions,
+                          adv_confidences,) = test_classification_net_edl(net, adv_loader, device)
+                         adv_logits, _ = get_logits_labels(net, adv_loader, device)
+                         uncertainties = edl_unc(adv_logits).detach().cpu().numpy()
+                         quantiles = np.quantile(uncertainties, np.linspace(0, 1, 10))
+                         quantiles = np.delete(quantiles, 0)
+                         unc_list = []
+                         accuracy_list = []
+                         for threshold in quantiles:
+                             cer_indices = (uncertainties < threshold)
+                             unc_indices = ~cer_indices
+                             labels_list = np.array(adv_labels_list)
+                             targets_cer = labels_list[cer_indices]
+                             predictions = np.array(adv_predictions)
+                             pred_cer = predictions[cer_indices]
+                             targets_unc = labels_list[unc_indices]
+                             pred_unc = predictions[unc_indices]
+                             cer_right = np.sum(targets_cer == pred_cer)
+                             cer = len(targets_cer)
+                             unc_right = np.sum(targets_unc == pred_unc)
+                             unc = len(targets_unc)
+                             accuracy_cer = cer_right / cer
+                             accuracy_unc = unc_right / unc
+                             unc_list.append(threshold)
+                             accuracy_list.append(accuracy_cer)
+                             print('ACC:', accuracy_cer, accuracy_unc)
+                         from scipy.stats import spearmanr
+
+                         Spearman_acc, p_acc = spearmanr(unc_list, accuracy_list)
+                         print("Spearman correlation:", Spearman_acc, "mean uncertainties:", uncertainties.mean())
+                         adv_unc[i][idx_ep] = uncertainties.mean()
+                         adv_acc[i][idx_ep] = adv_accuracy
+
+             ood_m2_auroc = ood_m1_auroc
+             ood_m2_auprc = ood_m1_auprc
+             err_m2_auroc = err_m1_auroc
+             err_m2_auprc = err_m1_auprc
+             adv_m2_auroc = adv_m1_auroc
+             adv_m2_auprc = adv_m1_auprc
+             t_m1_auroc = ood_m1_auroc
+             t_m1_auprc = ood_m1_auprc
+             t_m2_auroc = ood_m1_auroc
+             t_m2_auprc = ood_m1_auprc
+
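`edl_unc` is imported from `metrics.uncertainty_confidence` and its exact definition is not visible in this diff. The conventional evidential-deep-learning (Dirichlet) vacuity score is sketched below as an assumption; what the AUROC calls above actually rely on is only its direction (higher = more uncertain):

```python
import torch
import torch.nn.functional as F

def dirichlet_uncertainty(evidence_logits):
    # Assumed EDL-style score: non-negative evidence, alpha = evidence + 1,
    # vacuity u = K / sum(alpha). Higher u means less total evidence.
    evidence = F.relu(evidence_logits)
    alpha = evidence + 1.0
    num_classes = alpha.shape[1]
    return num_classes / alpha.sum(dim=1)

u = dirichlet_uncertainty(torch.randn(4, 10))
print(u)
```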
+         elif args.model_type == "joint":
+             (conf_matrix, accuracy, labels_list, predictions, confidences,) = test_classification_uq(
+                 net, test_loader, device
+             )
+             ece = expected_calibration_error(confidences, predictions, labels_list, num_bins=15)
+             print(accuracy)
+             print('ece', ece)
+
+             t_ece = ece
+             t_accuracy = accuracy
+
+             print("SPC Model")
+
+             logits, labels = get_logits_labels_uq(net, test_loader, device)
+
+             soft = torch.nn.functional.softmax(logits[0], dim=1)
+             delta = torch.min(torch.min(logits[2] - logits[3], logits[1] - 2 * logits[3]), 2 * logits[2] - logits[1])
+
+             uncertainty = abs(logits[2] + logits[3] - logits[1])
+
+             threshold = 0.05
+             mask = (uncertainty < threshold).float()
+             delta = delta * mask
+             softmax_prob = soft + delta
+
+             c_confidences, c_predictions = torch.max(softmax_prob, dim=1)
+             c_predictions = c_predictions.tolist()
+             c_confidences = c_confidences.tolist()
+             c_ece = expected_calibration_error(c_confidences, c_predictions, labels_list, num_bins=15)
+             print('ece', ece, 't_ece', t_ece, 'c_ece', c_ece)
+             c_accuracy = accuracy_score(labels_list, c_predictions)
+             print(accuracy, t_accuracy, c_accuracy)
+
+             ood_logits, ood_labels = get_logits_labels_uq(net, ood_test_loader, device)
+
+             (_, _, _), (_, _, _), ood_m1_auroc, ood_m1_auprc = get_roc_auc_logits(logits, ood_logits, self_consistency,
+                                                                                   device)
+             (_, _, _), (_, _, _), ood_m2_auroc, ood_m2_auprc = get_roc_auc_logits(logits[0], ood_logits[0], entropy, device)
+
+             labels_array = np.array(labels_list)
+             pred_array = np.array(predictions)
+             correct_mask = labels_array == pred_array
+             logits_right = [m[correct_mask] for m in logits]
+             logits_wrong = [m[~correct_mask] for m in logits]
+             (_, _, _), (_, _, _), err_m1_auroc, err_m1_auprc = get_roc_auc_logits(logits_right, logits_wrong,
+                                                                                   self_consistency, device)
+             (_, _, _), (_, _, _), err_m2_auroc, err_m2_auprc = get_roc_auc_logits(logits_right[0], logits_wrong[0], entropy,
+                                                                                   device)
+
+             if args.dataset == 'imagenet':
+                 adv_test_loader = dataset_loader['imagenet_a'].get_test_loader(batch_size=args.batch_size,
+                                                                                imagesize=model_to_input_dim[args.model],
+                                                                                pin_memory=args.gpu)
+             else:
+                 adv_test_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=adv_ep,
+                                                                 batch_size=args.batch_size, joint=True)
+
+             adv_logits, adv_labels = get_logits_labels_uq(net, adv_test_loader, device)
+
+             (_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc_logits(logits, adv_logits, self_consistency, device)
+             (_, _, _), (_, _, _), adv_m2_auroc, adv_m2_auprc = get_roc_auc_logits(logits[0], adv_logits[0], entropy, device)
+
+             t_m1_auroc = ood_m1_auroc
+             t_m1_auprc = ood_m1_auprc
+             t_m2_auroc = ood_m2_auroc
+             t_m2_auprc = ood_m2_auprc
+
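The joint branch above corrects the softmax with the auxiliary heads: a correction term `delta` built from the three extra outputs is applied only where the self-consistency residual |logits[2] + logits[3] - logits[1]| is small. What each head encodes is an assumption here; the sketch below just factors out the arithmetic the branch uses, with `spc_correct_softmax` as a name introduced for illustration:

```python
import torch

def spc_correct_softmax(soft, heads, threshold=0.05):
    # soft: (N, K) softmax probabilities; heads: list of four (N, K) tensors,
    # indexed as in the branch above (heads[1..3] are the auxiliary outputs).
    delta = torch.min(torch.min(heads[2] - heads[3], heads[1] - 2 * heads[3]),
                      2 * heads[2] - heads[1])
    # Self-consistency residual: zero when the auxiliary heads agree exactly.
    uncertainty = (heads[2] + heads[3] - heads[1]).abs()
    # Apply the correction only where the heads are mutually consistent.
    mask = (uncertainty < threshold).float()
    return soft + delta * mask

N, K = 8, 10
soft = torch.softmax(torch.randn(N, K), dim=1)
heads = [soft, torch.rand(N, K), torch.rand(N, K), torch.rand(N, K)]
corrected = spc_correct_softmax(soft, heads)
```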
+         else:
+             (conf_matrix, accuracy, labels_list, predictions, confidences,) = test_classification_net(
+                 net, test_loader, device
+             )
+             ece = expected_calibration_error(confidences, predictions, labels_list, num_bins=15)
+             print(accuracy)
+             print('ece', ece)
+
+             temp_scaled_net = ModelWithTemperature(net)
+             temp_scaled_net.set_temperature(val_loader)
+             # temp_scaled_net.set_temperature(train_loader)
+             topt = temp_scaled_net.temperature
+
+             (t_conf_matrix, t_accuracy, t_labels_list, t_predictions, t_confidences,) = test_classification_net(
+                 temp_scaled_net, test_loader, device
+             )
+             t_ece = expected_calibration_error(t_confidences, t_predictions, t_labels_list, num_bins=15)
+             print('t_ece', t_ece)
+
+             if (args.model_type == "gmm"):
+                 # Evaluate a GMM model
+                 print("GMM Model")
+
+                 if args.crossval:
+                     embeddings, labels = get_embeddings(
+                         net,
+                         val_loader,
+                         num_dim=model_to_num_dim[args.model],
+                         dtype=torch.double,
+                         device=device,
+                         storage_device=device,
+                     )
+                 else:
+                     if args.dataset == 'imagenet':
+                         if args.model == 'imagenet_vgg16':
+                             embed_path = 'data/imagenet_train_vgg_embedding.pt'
+                             # embed_path = 'data/imagenet_val_vgg_embedding.pt'
+                         if args.model == 'imagenet_wide':
+                             embed_path = 'data/imagenet_train_wide_embedding.pt'
+                             # embed_path = 'data/imagenet_val_wide_embedding.pt'
+                         if args.model == 'imagenet_vit':
+                             embed_path = 'data/imagenet_train_vit_embedding.pt'
+                             # embed_path = 'data/imagenet_val_vit_embedding.pt'
+                         if os.path.exists(embed_path):
+                             data = torch.load(embed_path, map_location=device)
+                             embeddings = data['embeddings']
+                             labels = data['labels']
+                         else:
+                             embeddings, labels = get_embeddings(
+                                 net,
+                                 train_loader,
+                                 num_dim=model_to_num_dim[args.model],
+                                 dtype=torch.double,
+                                 device=device,
+                                 storage_device=device,
+                             )
+                             torch.save({'embeddings': embeddings, 'labels': labels}, embed_path)
+                     else:
+                         embeddings, labels = get_embeddings(
+                             net,
+                             train_loader,
+                             num_dim=model_to_num_dim[args.model],
+                             dtype=torch.double,
+                             device=device,
+                             storage_device=device,
+                         )
+
+                 try:
+                     gaussians_model, jitter_eps = gmm_fit(embeddings=embeddings, labels=labels, num_classes=num_classes, device=device)
+                     logits, labels = gmm_evaluate(
+                         net, gaussians_model, test_loader, device=device, num_classes=num_classes, storage_device=device,
+                     )
+
+                     ood_logits, ood_labels = gmm_evaluate(
+                         net, gaussians_model, ood_test_loader, device=device, num_classes=num_classes, storage_device=device,
+                     )
+
+                     (_, _, _), (_, _, _), ood_m1_auroc, ood_m1_auprc = get_roc_auc_logits(
+                         logits, ood_logits, logsumexp, device, confidence=True
+                     )
+                     (_, _, _), (_, _, _), ood_m2_auroc, ood_m2_auprc = get_roc_auc_logits(logits, ood_logits, entropy, device)
+
+                     labels_array = np.array(labels_list)
+                     pred_array = np.array(predictions)
+                     correct_mask = labels_array == pred_array
+                     logits_right = logits[correct_mask]
+                     logits_wrong = logits[~correct_mask]
+                     (_, _, _), (_, _, _), err_m1_auroc, err_m1_auprc = get_roc_auc_logits(logits_right, logits_wrong, logsumexp, device, confidence=True)
+                     (_, _, _), (_, _, _), err_m2_auroc, err_m2_auprc = get_roc_auc_logits(logits_right, logits_wrong, entropy, device, confidence=True)
+
+                     if args.dataset == 'imagenet':
+                         adv_test_loader = dataset_loader['imagenet_a'].get_test_loader(batch_size=args.batch_size,
+                                                                                        imagesize=model_to_input_dim[args.model],
+                                                                                        pin_memory=args.gpu)
+                     else:
+                         adv_test_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=adv_ep, batch_size=args.batch_size)
+                     (adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions, adv_confidences,) = test_classification_net(net, adv_test_loader, device)
+                     adv_logits, adv_labels = gmm_evaluate(net, gaussians_model, adv_test_loader, device=device, num_classes=num_classes, storage_device=device,)
+                     (_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc_logits(logits, adv_logits, logsumexp, device, confidence=True)
+                     (_, _, _), (_, _, _), adv_m2_auroc, adv_m2_auprc = get_roc_auc_logits(logits, adv_logits, entropy, device)
+
+                     if args.sample_noise:
+                         adv_eps = np.linspace(0, 0.4, 9)
+                         for idx_ep, ep in enumerate(adv_eps):
+                             adv_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=ep,
+                                                                        batch_size=args.batch_size)
+                             (adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions,
+                              adv_confidences,) = test_classification_net(net, adv_loader, device)
+                             adv_logits, adv_labels = gmm_evaluate(net, gaussians_model, adv_loader, device=device, num_classes=num_classes, storage_device=device,)
+                             uncertainties = -logsumexp(adv_logits).detach().cpu().numpy()
+                             quantiles = np.quantile(uncertainties, np.linspace(0, 1, 10))
+                             quantiles = np.delete(quantiles, 0)
+                             unc_list = []
+                             accuracy_list = []
+                             for threshold in quantiles:
+                                 cer_indices = (uncertainties < threshold)
+                                 unc_indices = ~cer_indices
+                                 labels_list = np.array(adv_labels_list)
+                                 targets_cer = labels_list[cer_indices]
+                                 predictions = np.array(adv_predictions)
+                                 pred_cer = predictions[cer_indices]
+                                 targets_unc = labels_list[unc_indices]
+                                 pred_unc = predictions[unc_indices]
+                                 cer_right = np.sum(targets_cer == pred_cer)
+                                 cer = len(targets_cer)
+                                 unc_right = np.sum(targets_unc == pred_unc)
+                                 unc = len(targets_unc)
+                                 accuracy_cer = cer_right / cer
+                                 accuracy_unc = unc_right / unc
+                                 unc_list.append(threshold)
+                                 accuracy_list.append(accuracy_cer)
+                                 print('ACC:', accuracy_cer, accuracy_unc)
+                             from scipy.stats import spearmanr
+                             Spearman_acc, p_acc = spearmanr(unc_list, accuracy_list)
+                             print("Spearman correlation:", Spearman_acc, "mean uncertainties:", uncertainties.mean())
+                             adv_unc[i][idx_ep] = uncertainties.mean()
+                             adv_acc[i][idx_ep] = adv_accuracy
+
+                         (_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc_logits(logits, adv_logits, logsumexp, device, confidence=True)
+                         (_, _, _), (_, _, _), adv_m2_auroc, adv_m2_auprc = get_roc_auc_logits(logits, adv_logits, entropy, device)
+                         print('adv_m1_auroc,adv_m2_auroc', adv_m1_auroc, adv_m2_auroc)
+
+                     t_m1_auroc = ood_m1_auroc
+                     t_m1_auprc = ood_m1_auprc
+                     t_m2_auroc = ood_m2_auroc
+                     t_m2_auprc = ood_m2_auprc
+
+                 except RuntimeError as e:
+                     print("Runtime Error caught: " + str(e))
+                     continue
+
+             elif (args.model_type == "oc"):
+                 # Evaluate an OC model
+                 print("OC Model")
+
+                 if args.crossval:
+                     embeddings, labels = get_embeddings(
+                         net,
+                         val_loader,
+                         num_dim=model_to_num_dim[args.model],
+                         dtype=torch.double,
+                         device=device,
+                         storage_device=device,
+                     )
+                 else:
+                     if args.dataset == 'imagenet':
+                         if args.model == 'imagenet_vgg16':
+                             embed_path = 'data/imagenet_train_vgg_embedding.pt'
+                             # embed_path = 'data/imagenet_val_vgg_embedding.pt'
+                         if args.model == 'imagenet_wide':
+                             embed_path = 'data/imagenet_train_wide_embedding.pt'
+                             # embed_path = 'data/imagenet_val_wide_embedding.pt'
+                         if args.model == 'imagenet_vit':
+                             embed_path = 'data/imagenet_train_vit_embedding.pt'
+                             # embed_path = 'data/imagenet_val_vit_embedding.pt'
+                         if os.path.exists(embed_path):
+                             data = torch.load(embed_path, map_location=device)
+                             embeddings = data['embeddings']
+                             labels = data['labels']
+                         else:
+                             embeddings, labels = get_embeddings(
+                                 net,
+                                 train_loader,
+                                 num_dim=model_to_num_dim[args.model],
+                                 dtype=torch.double,
+                                 device=device,
+                                 storage_device=device,
+                             )
+                             torch.save({'embeddings': embeddings, 'labels': labels}, embed_path)
+                     else:
+                         embeddings, labels = get_embeddings(
+                             net,
+                             train_loader,
+                             num_dim=model_to_num_dim[args.model],
+                             dtype=torch.double,
+                             device=device,
+                             storage_device=device,
+                         )
+
+                 try:
+                     oc_model = oc_fit(embeddings=embeddings, device=device)
+                     logits, OCs = oc_evaluate(
+                         net, oc_model, test_loader, model_to_num_dim[args.model], device=device
+                     )
+
+                     ood_logits, ood_OCs = oc_evaluate(
+                         net, oc_model, ood_test_loader, model_to_num_dim[args.model], device=device)
+
+                     (_, _, _), (_, _, _), ood_m1_auroc, ood_m1_auprc = get_roc_auc_logits(OCs, ood_OCs, certificate, device, confidence=True)
+                     (_, _, _), (_, _, _), ood_m2_auroc, ood_m2_auprc = get_roc_auc_logits(logits, ood_logits, entropy, device)
+
+                     labels_array = np.array(labels_list)
+                     pred_array = np.array(predictions)
+                     correct_mask = labels_array == pred_array
+                     logits_right = logits[correct_mask]
+                     logits_wrong = logits[~correct_mask]
+                     OCs_right = OCs[correct_mask]
+                     OCs_wrong = OCs[~correct_mask]
+                     (_, _, _), (_, _, _), err_m1_auroc, err_m1_auprc = get_roc_auc_logits(OCs_right, OCs_wrong, certificate, device, confidence=True)
+                     (_, _, _), (_, _, _), err_m2_auroc, err_m2_auprc = get_roc_auc_logits(logits_right, logits_wrong, entropy, device)
+
+                     if args.dataset == 'imagenet':
+                         adv_test_loader = dataset_loader['imagenet_a'].get_test_loader(batch_size=args.batch_size,
+                                                                                        imagesize=model_to_input_dim[args.model],
+                                                                                        pin_memory=args.gpu)
+                     else:
+                         adv_test_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=adv_ep, batch_size=args.batch_size)
+                     (adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions, adv_confidences,) = test_classification_net(net, adv_test_loader, device)
+                     adv_logits, adv_OCs = oc_evaluate(net, oc_model, adv_test_loader, model_to_num_dim[args.model], device=device)
+                     (_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc_logits(OCs, adv_OCs, certificate, device, confidence=True)
+                     (_, _, _), (_, _, _), adv_m2_auroc, adv_m2_auprc = get_roc_auc_logits(logits, adv_logits, entropy, device)
+
+                     if args.sample_noise:
+                         adv_eps = np.linspace(0, 0.4, 9)
+                         for idx_ep, ep in enumerate(adv_eps):
+                             adv_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=ep,
+                                                                        batch_size=args.batch_size)
+                             (adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions,
+                              adv_confidences,) = test_classification_net(net, adv_loader, device)
+                             adv_logits, adv_OCs = oc_evaluate(net, oc_model, adv_loader, model_to_num_dim[args.model], device=device)
+                             uncertainties = -adv_OCs.cpu().numpy()
+                             quantiles = np.quantile(uncertainties, np.linspace(0, 1, 10))
+                             quantiles = np.delete(quantiles, 0)
+                             unc_list = []
+                             accuracy_list = []
+                             for threshold in quantiles:
+                                 cer_indices = (uncertainties < threshold)
+                                 unc_indices = ~cer_indices
+                                 labels_list = np.array(adv_labels_list)
+                                 targets_cer = labels_list[cer_indices]
+                                 predictions = np.array(adv_predictions)
+                                 pred_cer = predictions[cer_indices]
+                                 targets_unc = labels_list[unc_indices]
+                                 pred_unc = predictions[unc_indices]
+                                 cer_right = np.sum(targets_cer == pred_cer)
+                                 cer = len(targets_cer)
+                                 unc_right = np.sum(targets_unc == pred_unc)
+                                 unc = len(targets_unc)
+                                 accuracy_cer = cer_right / cer
+                                 accuracy_unc = unc_right / unc
+                                 unc_list.append(threshold)
+                                 accuracy_list.append(accuracy_cer)
+                                 print('ACC:', accuracy_cer, accuracy_unc)
+                             from scipy.stats import spearmanr
+                             Spearman_acc, p_acc = spearmanr(unc_list, accuracy_list)
+                             print("Spearman correlation:", Spearman_acc, "mean uncertainties:", uncertainties.mean())
+                             adv_unc[i][idx_ep] = uncertainties.mean()
+                             adv_acc[i][idx_ep] = adv_accuracy
+
+                         (_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc_logits(OCs, adv_OCs, certificate, device, confidence=True)
+                         (_, _, _), (_, _, _), adv_m2_auroc, adv_m2_auprc = get_roc_auc_logits(logits, adv_logits, entropy, device)
+                         print('adv_m1_auroc,adv_m2_auroc', adv_m1_auroc, adv_m2_auroc)
+
+                     t_m1_auroc = ood_m1_auroc
+                     t_m1_auprc = ood_m1_auprc
+                     t_m2_auroc = ood_m2_auroc
+                     t_m2_auprc = ood_m2_auprc
+
+                 except RuntimeError as e:
+                     print("Runtime Error caught: " + str(e))
+                     continue
+
+             elif (args.model_type == "spc"):
+                 print("SPC Model")
+
+                 if args.crossval:
+                     spc_model_path = os.path.join(
+                         args.load_loc,
+                         "Run" + str(i + 1),
+                         model_load_name(args.model, args.sn, args.mod, args.coeff, args.seed, i) + "_valsize_"+str(args.val_size)+"_350_mar_model.pth",
+                     )
+                 else:
+                     spc_model_path = os.path.join(
+                         args.load_loc,
+                         "Run" + str(i + 1),
+                         model_load_name(args.model, args.sn, args.mod, args.coeff, args.seed, i) + "_350_mar_model.pth",
+                     )
+
+                 if os.path.exists(spc_model_path):
+                     print(f"Loading existing spc_model from {spc_model_path}")
+                     SPC_model = SPC_load(spc_model_path, model_to_num_dim[args.model], num_classes, device)
+                 else:
+                     print("Model not found. Training a new one...")
+                     if args.crossval:
+                         embeddings, labels = get_embeddings(
+                             net,
+                             val_loader,
+                             num_dim=model_to_num_dim[args.model],
+                             dtype=torch.double,
+                             device=device,
+                             storage_device=device,
+                         )
+                     else:
+                         if args.dataset == 'imagenet':
+                             if args.model == 'imagenet_vgg16':
+                                 embed_path = 'data/imagenet_train_vgg_embedding.pt'
+                                 # embed_path = 'data/imagenet_val_vgg_embedding.pt'
+                             if args.model == 'imagenet_wide':
+                                 embed_path = 'data/imagenet_train_wide_embedding.pt'
+                                 # embed_path = 'data/imagenet_val_wide_embedding.pt'
+                             if args.model == 'imagenet_vit':
+                                 embed_path = 'data/imagenet_train_vit_embedding.pt'
+                                 # embed_path = 'data/imagenet_val_vit_embedding.pt'
+                             if os.path.exists(embed_path):
+                                 data = torch.load(embed_path, map_location=device)
+                                 embeddings = data['embeddings']
+                                 labels = data['labels']
+                             else:
+                                 embeddings, labels = get_embeddings(
+                                     net,
+                                     train_loader,
+                                     num_dim=model_to_num_dim[args.model],
+                                     dtype=torch.double,
+                                     device=device,
+                                     storage_device=device,
+                                 )
+                                 torch.save({'embeddings': embeddings, 'labels': labels}, embed_path)
+                         else:
+                             embeddings, labels = get_embeddings(
+                                 net,
+                                 train_loader,
+                                 num_dim=model_to_num_dim[args.model],
+                                 dtype=torch.double,
+                                 device=device,
+                                 storage_device=device,
+                             )
+
+                     parts = model_to_last_layer[args.model].split('.')
+                     net_last_layer = net
+
+                     for attr in parts:
+                         net_last_layer = getattr(net_last_layer, attr)
+
+                     SPC_model = SPC_fit(net_last_layer, topt, embeddings, labels, model_to_num_dim[args.model], num_classes, device)
+                     torch.save(SPC_model.state_dict(), spc_model_path)
+                     print(f"Model saved at {spc_model_path}")
+
+                 logits, mars = SPC_evaluate(net, SPC_model, test_loader, model_to_num_dim[args.model], num_classes, device)
+
+                 soft = torch.nn.functional.softmax(logits, dim=1)
+                 delta = torch.min(torch.min(mars[2] - mars[3], mars[1] - 2 * mars[3]), 2 * mars[2] - mars[1])
+                 # delta = mars[2] - mars[3]
+
+                 uncertainty = abs(mars[2] + mars[3] - mars[1])
+
+                 threshold = 0.05
+                 mask = (uncertainty < threshold).float()
+                 delta = delta * mask
+                 softmax_prob = soft + delta
+                 print(torch.sum(softmax_prob, dim=1))
+
+                 c_confidences, c_predictions = torch.max(softmax_prob, dim=1)
+                 c_predictions = c_predictions.tolist()
+                 c_confidences = c_confidences.tolist()
+                 c_ece = expected_calibration_error(c_confidences, c_predictions, labels_list, num_bins=15)
+                 print('ece', ece, 't_ece', t_ece, 'c_ece', c_ece)
+                 c_accuracy = accuracy_score(labels_list, c_predictions)
+                 print('accuracy', accuracy, 't_accuracy', t_accuracy, 'c_accuracy', c_accuracy)
+
+                 ood_logits, ood_mars = SPC_evaluate(net, SPC_model, ood_test_loader, model_to_num_dim[args.model], num_classes, device)
+                 (_, _, _), (_, _, _), ood_m1_auroc, ood_m1_auprc = get_roc_auc_logits(mars, ood_mars, self_consistency, device)
+                 (_, _, _), (_, _, _), ood_m2_auroc, ood_m2_auprc = get_roc_auc_logits(logits, ood_logits, entropy, device)
+
+                 labels_array = np.array(labels_list)
+                 pred_array = np.array(predictions)
+                 correct_mask = labels_array == pred_array
+                 mars_right = [m[correct_mask] for m in mars]
+                 mars_wrong = [m[~correct_mask] for m in mars]
+                 (_, _, _), (_, _, _), err_m1_auroc, err_m1_auprc = get_roc_auc_logits(mars_right, mars_wrong, self_consistency, device)
+                 (_, _, _), (_, _, _), err_m2_auroc, err_m2_auprc = get_roc_auc_logits(mars_right[0], mars_wrong[0], entropy, device)
+
+                 if args.dataset == 'imagenet':
+                     adv_test_loader = dataset_loader['imagenet_a'].get_test_loader(batch_size=args.batch_size,
+                                                                                    imagesize=model_to_input_dim[args.model],
+                                                                                    pin_memory=args.gpu)
+                 else:
+                     adv_test_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=adv_ep, batch_size=args.batch_size)
+                 (adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions, adv_confidences,) = test_classification_net(net, adv_test_loader, device)
+                 adv_logits, adv_mars = SPC_evaluate(net, SPC_model, adv_test_loader, model_to_num_dim[args.model], num_classes, device)
+                 (_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc_logits(mars, adv_mars, self_consistency, device)
+                 (_, _, _), (_, _, _), adv_m2_auroc, adv_m2_auprc = get_roc_auc_logits(logits, adv_logits, entropy, device)
+
+                 if args.sample_noise:
+                     adv_eps = np.linspace(0, 0.4, 9)
+                     print(adv_eps)
+                     for idx_ep, ep in enumerate(adv_eps):
+                         adv_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=ep,
+                                                                    batch_size=args.batch_size)
+                         (adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions,
+                          adv_confidences,) = test_classification_net(net, adv_loader, device)
+                         adv_logits, adv_mars = SPC_evaluate(net, SPC_model, adv_loader, model_to_num_dim[args.model], num_classes, device)
+                         uncertainties = self_consistency(adv_mars).detach().cpu().numpy()
+                         quantiles = np.quantile(uncertainties, np.linspace(0, 1, 10))
+                         quantiles = np.delete(quantiles, 0)
+                         unc_list = []
+                         accuracy_list = []
+                         for threshold in quantiles:
+                             cer_indices = (uncertainties < threshold)
+                             unc_indices = ~cer_indices
+                             labels_list = np.array(adv_labels_list)
+                             targets_cer = labels_list[cer_indices]
+                             predictions = np.array(adv_predictions)
+                             pred_cer = predictions[cer_indices]
+                             targets_unc = labels_list[unc_indices]
+                             pred_unc = predictions[unc_indices]
+                             cer_right = np.sum(targets_cer == pred_cer)
+                             cer = len(targets_cer)
+                             unc_right = np.sum(targets_unc == pred_unc)
+                             unc = len(targets_unc)
+                             accuracy_cer = cer_right / cer
+                             accuracy_unc = unc_right / unc
+                             unc_list.append(threshold)
+                             accuracy_list.append(accuracy_cer)
+                             print('ACC:', accuracy_cer, accuracy_unc)
+                         from scipy.stats import spearmanr
+                         Spearman_acc, p_acc = spearmanr(unc_list, accuracy_list)
+                         print("Spearman correlation:", Spearman_acc, "mean uncertainties:", uncertainties.mean())
+                         adv_unc[i][idx_ep] = uncertainties.mean()
+                         adv_acc[i][idx_ep] = adv_accuracy
+
+                     (_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc_logits(mars, adv_mars, self_consistency, device)
+                     (_, _, _), (_, _, _), adv_m2_auroc, adv_m2_auprc = get_roc_auc_logits(logits, adv_logits, entropy, device)
+                     print('adv_m1_auroc,adv_m2_auroc', adv_m1_auroc, adv_m2_auroc)
+
+
+                 t_m1_auroc = ood_m1_auroc
+                 t_m1_auprc = ood_m1_auprc
+                 t_m2_auroc = ood_m2_auroc
+                 t_m2_auprc = ood_m2_auprc
+
+             else:
+                 # Evaluate a normal Softmax model
+                 print("Softmax Model")
+                 (_, _, _), (_, _, _), ood_m1_auroc, ood_m1_auprc = get_roc_auc(net, test_loader, ood_test_loader, entropy, device)
+                 (_, _, _), (_, _, _), ood_m2_auroc, ood_m2_auprc = get_roc_auc(net, test_loader, ood_test_loader, logsumexp, device, confidence=True)
+
+                 (_, _, _), (_, _, _), t_m1_auroc, t_m1_auprc = get_roc_auc(temp_scaled_net, test_loader, ood_test_loader, entropy, device)
+                 (_, _, _), (_, _, _), t_m2_auroc, t_m2_auprc = get_roc_auc(temp_scaled_net, test_loader, ood_test_loader, logsumexp, device, confidence=True)
+
+                 labels_array = np.array(labels_list)
+                 pred_array = np.array(predictions)
+                 correct_mask = labels_array == pred_array
+                 logits, _ = get_logits_labels(net, test_loader, device)
+                 logits_right = logits[correct_mask]
+                 logits_wrong = logits[~correct_mask]
+                 (_, _, _), (_, _, _), err_m1_auroc, err_m1_auprc = get_roc_auc_logits(logits_right, logits_wrong, entropy, device)
+                 (_, _, _), (_, _, _), err_m2_auroc, err_m2_auprc = get_roc_auc_logits(logits_right, logits_wrong, logsumexp, device, confidence=True)
+
+                 if args.dataset == 'imagenet':
+                     adv_test_loader = dataset_loader['imagenet_a'].get_test_loader(batch_size=args.batch_size,
+                                                                                    imagesize=model_to_input_dim[args.model],
+                                                                                    pin_memory=args.gpu)
+                 else:
+                     adv_test_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=adv_ep, batch_size=args.batch_size)
+                 (adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions, adv_confidences,) = test_classification_net(net, adv_test_loader, device)
+                 adv_logits, _ = get_logits_labels(net, adv_test_loader, device)
+                 (_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc(net, test_loader, adv_test_loader, entropy, device)
+                 (_, _, _), (_, _, _), adv_m2_auroc, adv_m2_auprc = get_roc_auc(net, test_loader, adv_test_loader, logsumexp, device, confidence=True)
+
+                 if args.sample_noise:
+                     adv_eps = np.linspace(0, 0.4, 9)
+                     for idx_ep, ep in enumerate(adv_eps):
+                         adv_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=ep,
+                                                                    batch_size=args.batch_size)
+                         (adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions,
+                          adv_confidences,) = test_classification_net(net, adv_loader, device)
+                         adv_logits, _ = get_logits_labels(net, adv_loader, device)
+                         uncertainties = entropy(adv_logits).detach().cpu().numpy()
+                         quantiles = np.quantile(uncertainties, np.linspace(0, 1, 10))
+                         quantiles = np.delete(quantiles, 0)
+                         unc_list = []
+                         accuracy_list = []
+                         for threshold in quantiles:
+                             cer_indices = (uncertainties < threshold)
+                             unc_indices = ~cer_indices
+                             labels_list = np.array(adv_labels_list)
+                             targets_cer = labels_list[cer_indices]
+                             predictions = np.array(adv_predictions)
+                             pred_cer = predictions[cer_indices]
+                             targets_unc = labels_list[unc_indices]
+                             pred_unc = predictions[unc_indices]
+                             cer_right = np.sum(targets_cer == pred_cer)
+                             cer = len(targets_cer)
+                             unc_right = np.sum(targets_unc == pred_unc)
+                             unc = len(targets_unc)
+                             accuracy_cer = cer_right / cer
+                             accuracy_unc = unc_right / unc
+                             unc_list.append(threshold)
+                             accuracy_list.append(accuracy_cer)
+                             print('ACC:', accuracy_cer, accuracy_unc)
+                         from scipy.stats import spearmanr
+
+                         Spearman_acc, p_acc = spearmanr(unc_list, accuracy_list)
+                         print("Spearman correlation:", Spearman_acc, "mean uncertainties:", uncertainties.mean())
+                         adv_unc[i][idx_ep] = uncertainties.mean()
+                         adv_acc[i][idx_ep] = adv_accuracy
+
+                     (_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc(net, test_loader, adv_loader, entropy, device)
+                     (_, _, _), (_, _, _), adv_m2_auroc, adv_m2_auprc = get_roc_auc(net, test_loader, adv_loader, logsumexp, device, confidence=True)
+                     print('adv_m1_auroc,adv_m2_auroc', adv_m1_auroc, adv_m2_auroc)
+
+
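Each branch above repeats the same noise-sweep bookkeeping: threshold the uncertainty scores at their deciles, measure accuracy on the confident subset, and rank-correlate threshold against accuracy with Spearman's rho. A minimal sketch of that shared logic as one helper; `selective_accuracy_curve` is a name introduced here for illustration, not a function in the repo:

```python
import numpy as np
from scipy.stats import spearmanr

def selective_accuracy_curve(uncertainties, labels, predictions):
    # Decile thresholds on the uncertainty scores (drop the 0th quantile),
    # mirroring np.quantile(..., np.linspace(0, 1, 10)) with the first entry removed.
    quantiles = np.quantile(uncertainties, np.linspace(0, 1, 10))[1:]
    labels, predictions = np.asarray(labels), np.asarray(predictions)
    thresholds, accuracies = [], []
    for threshold in quantiles:
        keep = uncertainties < threshold  # the "certain" subset
        if keep.sum() == 0:
            continue
        thresholds.append(threshold)
        accuracies.append((labels[keep] == predictions[keep]).mean())
    rho, p = spearmanr(thresholds, accuracies)
    return thresholds, accuracies, rho

# Synthetic demo: predictions correct ~70% of the time.
unc = np.random.rand(100)
labels = np.random.randint(0, 10, size=100)
preds = np.where(np.random.rand(100) < 0.7, labels, (labels + 1) % 10)
_, _, rho = selective_accuracy_curve(unc, labels, preds)
```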
+         accuracies.append(accuracy)
+         if (args.model_type == "spc" or args.model_type == "joint"):
+             c_accuracies.append(c_accuracy)
+         else:
+             c_accuracies.append(t_accuracy)
+
+
+         # Pre-temperature results
+         eces.append(ece)
+         ood_m1_aurocs.append(ood_m1_auroc)
+         ood_m1_auprcs.append(ood_m1_auprc)
+         ood_m2_aurocs.append(ood_m2_auroc)
+         ood_m2_auprcs.append(ood_m2_auprc)
+
+         err_m1_aurocs.append(err_m1_auroc)
+         err_m1_auprcs.append(err_m1_auprc)
+         err_m2_aurocs.append(err_m2_auroc)
+         err_m2_auprcs.append(err_m2_auprc)
+
+         adv_m1_aurocs.append(adv_m1_auroc)
+         adv_m1_auprcs.append(adv_m1_auprc)
+         adv_m2_aurocs.append(adv_m2_auroc)
+         adv_m2_auprcs.append(adv_m2_auprc)
+
+         # Post-temperature results
+         t_eces.append(t_ece)
+         t_m1_aurocs.append(t_m1_auroc)
+         t_m1_auprcs.append(t_m1_auprc)
+         t_m2_aurocs.append(t_m2_auroc)
+         t_m2_auprcs.append(t_m2_auprc)
+
+         if (args.model_type == "spc" or args.model_type == "joint"):
+             c_eces.append(c_ece)
+
+         gc.collect()
+         torch.cuda.empty_cache()
+         torch.cuda.ipc_collect()
+
+     if args.sample_noise:
+         adv_unc_norm = (adv_unc - adv_unc.min(axis=1, keepdims=True)) / \
+                        (adv_unc.max(axis=1, keepdims=True) - adv_unc.min(axis=1, keepdims=True) + 1e-8)
+         mean_unc = np.mean(adv_unc_norm, axis=0)
+         std_unc = np.std(adv_unc_norm, axis=0)
+
+         plt.figure(figsize=(10, 6))
+         plt.plot(mean_unc, label='Uncertainty', color='orange')
+         plt.fill_between(range(len(mean_unc)),
+                          mean_unc - std_unc,
+                          mean_unc + std_unc,
+                          color='orange', alpha=0.3, label="±1 Std Dev")
+         # plt.legend()
+         # plt.title('Uncertainty Across Multiple Runs')
+         plt.xlabel('Noise')
+         plt.ylabel('Uncertainty')
+         plt.savefig("adv_unc_"
+                     + model_save_name(args.model, args.sn, args.mod, args.coeff, args.seed)
+                     + "_"
+                     + args.model_type
+                     + "_"
+                     + args.dataset
+                     + ".png")
+         plt.show()
+         plt.close()
+
+
+         mean_acc = np.mean(adv_acc, axis=0)
+         std_acc = np.std(adv_acc, axis=0)
+
+         plt.figure(figsize=(10, 6))
+         plt.plot(mean_acc, label='Accuracy', color='red')
+         plt.fill_between(range(len(mean_acc)),
+                          mean_acc - std_acc,
+                          mean_acc + std_acc,
+                          color='red', alpha=0.3, label="±1 Std Dev")
+         # plt.legend()
+         # plt.title('Uncertainty Across Multiple Runs')
+         plt.xlabel('Noise')
+         plt.ylabel('Accuracy')
+         plt.savefig("adv_acc_"
+                     + model_save_name(args.model, args.sn, args.mod, args.coeff, args.seed)
+                     + "_"
+                     + args.model_type
+                     + "_"
+                     + args.dataset
+                     + ".png")
+         plt.show()
+         plt.close()
+
+     save_dir = "curve_data"
+     os.makedirs(save_dir, exist_ok=True)
+
+     if args.sn:
+         prefix = f"{args.dataset}_{args.model_type}_{args.model}_SN"
+     else:
+         prefix = f"{args.dataset}_{args.model_type}_{args.model}"
+
+
+     accuracy_tensor = torch.tensor(accuracies)
+     c_accuracy_tensor = torch.tensor(c_accuracies)
+     ece_tensor = torch.tensor(eces)
+     ood_m1_auroc_tensor = torch.tensor(ood_m1_aurocs)
+     m1_auprc_tensor = torch.tensor(ood_m1_auprcs)
+     ood_m2_auroc_tensor = torch.tensor(ood_m2_aurocs)
+     ood_m2_auprc_tensor = torch.tensor(ood_m2_auprcs)
+
+     err_m1_auroc_tensor = torch.tensor(err_m1_aurocs)
+     err_m1_auprc_tensor = torch.tensor(err_m1_auprcs)
+     err_m2_auroc_tensor = torch.tensor(err_m2_aurocs)
+     err_m2_auprc_tensor = torch.tensor(err_m2_auprcs)
+
+     adv_m1_auroc_tensor = torch.tensor(adv_m1_aurocs)
+     adv_m1_auprc_tensor = torch.tensor(adv_m1_auprcs)
+     adv_m2_auroc_tensor = torch.tensor(adv_m2_aurocs)
+     adv_m2_auprc_tensor = torch.tensor(adv_m2_auprcs)
+
+     t_ece_tensor = torch.tensor(t_eces)
+     t_m1_auroc_tensor = torch.tensor(t_m1_aurocs)
+     t_m1_auprc_tensor = torch.tensor(t_m1_auprcs)
+     t_m2_auroc_tensor = torch.tensor(t_m2_aurocs)
+     t_m2_auprc_tensor = torch.tensor(t_m2_auprcs)
+
+     c_ece_tensor = torch.tensor(c_eces)
+
+     mean_accuracy = torch.mean(accuracy_tensor)
+     mean_c_accuracy = torch.mean(c_accuracy_tensor)
+     mean_ece = torch.mean(ece_tensor)
+     mean_ood_m1_auroc = torch.mean(ood_m1_auroc_tensor)
+     mean_m1_auprc = torch.mean(m1_auprc_tensor)
+     mean_m2_auroc = torch.mean(ood_m2_auroc_tensor)
+     mean_m2_auprc = torch.mean(ood_m2_auprc_tensor)
+
+     mean_err_m1_auroc = torch.mean(err_m1_auroc_tensor)
+     mean_err_m1_auprc = torch.mean(err_m1_auprc_tensor)
+     mean_err_m2_auroc = torch.mean(err_m2_auroc_tensor)
+     mean_err_m2_auprc = torch.mean(err_m2_auprc_tensor)
+
+     mean_adv_m1_auroc = torch.mean(adv_m1_auroc_tensor)
+     mean_adv_m1_auprc = torch.mean(adv_m1_auprc_tensor)
+     mean_adv_m2_auroc = torch.mean(adv_m2_auroc_tensor)
+     mean_adv_m2_auprc = torch.mean(adv_m2_auprc_tensor)
+
+     mean_t_ece = torch.mean(t_ece_tensor)
+     mean_t_m1_auroc = torch.mean(t_m1_auroc_tensor)
+     mean_t_m1_auprc = torch.mean(t_m1_auprc_tensor)
+     mean_t_m2_auroc = torch.mean(t_m2_auroc_tensor)
+     mean_t_m2_auprc = torch.mean(t_m2_auprc_tensor)
+
+     mean_c_ece = torch.mean(c_ece_tensor)
+
+     std_accuracy = torch.std(accuracy_tensor) / math.sqrt(accuracy_tensor.shape[0])
+     std_c_accuracy = torch.std(c_accuracy_tensor) / math.sqrt(c_accuracy_tensor.shape[0])
+     std_ece = torch.std(ece_tensor) / math.sqrt(ece_tensor.shape[0])
+     std_ood_m1_auroc = torch.std(ood_m1_auroc_tensor) / math.sqrt(ood_m1_auroc_tensor.shape[0])
+     std_m1_auprc = torch.std(m1_auprc_tensor) / math.sqrt(m1_auprc_tensor.shape[0])
+     std_m2_auroc = torch.std(ood_m2_auroc_tensor) / math.sqrt(ood_m2_auroc_tensor.shape[0])
+     std_m2_auprc = torch.std(ood_m2_auprc_tensor) / math.sqrt(ood_m2_auprc_tensor.shape[0])
+     std_err_m1_auroc = torch.std(err_m1_auroc_tensor) / math.sqrt(err_m1_auroc_tensor.shape[0])
+     std_err_m1_auprc = torch.std(err_m1_auprc_tensor) / math.sqrt(err_m1_auprc_tensor.shape[0])
+     std_err_m2_auroc = torch.std(err_m2_auroc_tensor) / math.sqrt(err_m2_auroc_tensor.shape[0])
+     std_err_m2_auprc = torch.std(err_m2_auprc_tensor) / math.sqrt(err_m2_auprc_tensor.shape[0])
+     std_adv_m1_auroc = torch.std(adv_m1_auroc_tensor) / math.sqrt(adv_m1_auroc_tensor.shape[0])
+     std_adv_m1_auprc = torch.std(adv_m1_auprc_tensor) / math.sqrt(adv_m1_auprc_tensor.shape[0])
+     std_adv_m2_auroc = torch.std(adv_m2_auroc_tensor) / math.sqrt(adv_m2_auroc_tensor.shape[0])
+     std_adv_m2_auprc = torch.std(adv_m2_auprc_tensor) / math.sqrt(adv_m2_auprc_tensor.shape[0])
+
+     std_t_ece = torch.std(t_ece_tensor) / math.sqrt(t_ece_tensor.shape[0])
+     std_t_m1_auroc = torch.std(t_m1_auroc_tensor) / math.sqrt(t_m1_auroc_tensor.shape[0])
+     std_t_m1_auprc = torch.std(t_m1_auprc_tensor) / math.sqrt(t_m1_auprc_tensor.shape[0])
+     std_t_m2_auroc = torch.std(t_m2_auroc_tensor) / math.sqrt(t_m2_auroc_tensor.shape[0])
+     std_t_m2_auprc = torch.std(t_m2_auprc_tensor) / math.sqrt(t_m2_auprc_tensor.shape[0])
+
+     std_c_ece = torch.std(c_ece_tensor) / math.sqrt(c_ece_tensor.shape[0])
+
1360
+ res_dict = {}
1361
+ res_dict["mean"] = {}
1362
+ res_dict["mean"]["accuracy"] = mean_accuracy.item()
1363
+ res_dict["mean"]["ece"] = mean_ece.item()
1364
+ res_dict["mean"]["ood_m1_auroc"] = mean_ood_m1_auroc.item()
1365
+ res_dict["mean"]["ood_m1_auprc"] = mean_m1_auprc.item()
1366
+ res_dict["mean"]["ood_m2_auroc"] = mean_m2_auroc.item()
1367
+ res_dict["mean"]["ood_m2_auprc"] = mean_m2_auprc.item()
1368
+ res_dict["mean"]["t_ece"] = mean_t_ece.item()
1369
+ res_dict["mean"]["t_m1_auroc"] = mean_t_m1_auroc.item()
1370
+ res_dict["mean"]["t_m1_auprc"] = mean_t_m1_auprc.item()
1371
+ res_dict["mean"]["t_m2_auroc"] = mean_t_m2_auroc.item()
1372
+ res_dict["mean"]["t_m2_auprc"] = mean_t_m2_auprc.item()
1373
+ res_dict["mean"]["c_ece"] = mean_c_ece.item()
1374
+
1375
+ res_dict["std"] = {}
1376
+ res_dict["std"]["accuracy"] = std_accuracy.item()
1377
+ res_dict["std"]["ece"] = std_ece.item()
1378
+ res_dict["std"]["ood_m1_auroc"] = std_ood_m1_auroc.item()
1379
+ res_dict["std"]["ood_m1_auprc"] = std_m1_auprc.item()
1380
+ res_dict["std"]["ood_m2_auroc"] = std_m2_auroc.item()
1381
+ res_dict["std"]["ood_m2_auprc"] = std_m2_auprc.item()
1382
+ res_dict["std"]["t_ece"] = std_t_ece.item()
1383
+ res_dict["std"]["t_m1_auroc"] = std_t_m1_auroc.item()
1384
+ res_dict["std"]["t_m1_auprc"] = std_t_m1_auprc.item()
1385
+ res_dict["std"]["t_m2_auroc"] = std_t_m2_auroc.item()
1386
+ res_dict["std"]["t_m2_auprc"] = std_t_m2_auprc.item()
1387
+ res_dict["std"]["c_ece"] = std_c_ece.item()
1388
+
1389
+ res_dict["values"] = {}
1390
+ res_dict["values"]["accuracy"] = accuracies
1391
+ res_dict["values"]["ece"] = eces
1392
+ res_dict["values"]["ood_m1_auroc"] = ood_m1_aurocs
1393
+ res_dict["values"]["ood_m1_auprc"] = ood_m1_auprcs
1394
+ res_dict["values"]["ood_m2_auroc"] = ood_m2_aurocs
1395
+ res_dict["values"]["ood_m2_auprc"] = ood_m2_auprcs
1396
+ res_dict["values"]["t_ece"] = t_eces
1397
+ res_dict["values"]["t_m1_auroc"] = t_m1_aurocs
1398
+ res_dict["values"]["t_m1_auprc"] = t_m1_auprcs
1399
+ res_dict["values"]["t_m2_auroc"] = t_m2_aurocs
1400
+ res_dict["values"]["t_m2_auprc"] = t_m2_auprcs
1401
+ res_dict["values"]["c_ece"] = c_eces
1402
+
1403
+ res_dict["info"] = vars(args)
1404
+
1405
+ print(f"{mean_accuracy.item() * 100:.2f} ± {std_accuracy.item() * 100:.2f}")
1406
+ print(f"{mean_c_accuracy.item() * 100:.2f} ± {std_c_accuracy.item() * 100:.2f}")
1407
+ print(f"{mean_ece.item()*100:.2f} ± {std_ece.item()*100:.2f}")
1408
+ print(f"{mean_t_ece.item()*100:.2f} ± {std_t_ece.item()*100:.2f}")
1409
+ print(f"{mean_c_ece.item() * 100:.2f} ± {std_c_ece.item() * 100:.2f}")
1410
+ print(f"{mean_adv_m1_auroc.item()*100:.2f} ± {std_adv_m1_auroc.item()*100:.2f}")
1411
+ print(f"{mean_err_m1_auroc.item()*100:.2f} ± {std_err_m1_auroc.item()*100:.2f}")
1412
+ print(f"{mean_ood_m1_auroc.item()*100:.2f} ± {std_ood_m1_auroc.item()*100:.2f}")
1413
+
1414
+
1415
+ with open(
1416
+ "res_"
1417
+ + model_save_name(args.model, args.sn, args.mod, args.coeff, args.seed)
1418
+ + "_"
1419
+ + args.model_type
1420
+ + "_"
1421
+ + args.dataset
1422
+ + "_"
1423
+ + args.ood_dataset
1424
+ + ".json",
1425
+ "w",
1426
+ ) as f:
1427
+ json.dump(res_dict, f)
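For quick inspection, a small sketch of how one of these result files might be consumed afterwards; the filename is hypothetical, following the `res_<model>_<model_type>_<dataset>_<ood_dataset>.json` pattern assembled above:

    import json

    with open("res_resnet50_spc_cifar10_svhn.json") as f:  # hypothetical filename
        res = json.load(f)
    for key in ("accuracy", "ece", "ood_m1_auroc"):
        print(f"{key}: {res['mean'][key] * 100:.2f} ± {res['std'][key] * 100:.2f}")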
SPC-UQ/Image_Classification/evaluate_laplace.py ADDED
@@ -0,0 +1,355 @@
+ """
+ Script to evaluate the Laplace Approximation.
+ """
+
+ import os
+ import gc
+ import json
+ import math
+ import torch
+ import torch.nn.functional as F
+ import torch.nn as nn
+ import argparse
+ import torch.backends.cudnn as cudnn
+ import numpy as np
+ from tqdm import tqdm
+ import matplotlib.pyplot as plt
+
+ # Import data loaders and networks
+ import data.ood_detection.cifar10 as cifar10
+ import data.ood_detection.cifar100 as cifar100
+ import data.ood_detection.svhn as svhn
+ import data.ood_detection.imagenet as imagenet
+ import data.ood_detection.tinyimagenet as tinyimagenet
+ import data.ood_detection.imagenet_o as imagenet_o
+ import data.ood_detection.imagenet_a as imagenet_a
+ import data.ood_detection.ood_union as ood_union
+
+ from net.resnet import resnet50
+ from net.resnet_edl import resnet50_edl
+ from net.wide_resnet import wrn
+ from net.wide_resnet_edl import wrn_edl
+ from net.vgg import vgg16
+ from net.vgg_edl import vgg16_edl
+ from net.imagenet_wide import imagenet_wide
+ from net.imagenet_vgg import imagenet_vgg16
+ from net.imagenet_vit import imagenet_vit
+
+ from metrics.classification_metrics import (
+     test_classification_net,
+     test_classification_net_logits,
+     test_classification_net_ensemble,
+     test_classification_net_edl,
+     create_adversarial_dataloader
+ )
+ from metrics.calibration_metrics import expected_calibration_error
+
+ from utils.gmm_utils import get_embeddings, gmm_evaluate, gmm_fit
+ from utils.ensemble_utils import load_ensemble, ensemble_forward_pass
+ from utils.eval_utils import model_load_name
+ from utils.train_utils import model_save_name
+ from utils.args import laplace_eval_args
+
+ from laplace import Laplace
+ from sklearn import metrics as M
+ from laplace.curvature import AsdlGGN, AsdlEF, BackPackGGN, BackPackEF
+
+ import warnings
+ warnings.filterwarnings('ignore')
+
+ # Dataset mapping and config
+ DATASET_NUM_CLASSES = {
+     "cifar10": 10, "cifar100": 100, "svhn": 10, "imagenet": 1000,
+     "tinyimagenet": 200, "imagenet_o": 200, "imagenet_a": 200
+ }
+
+ DATASET_LOADER = {
+     "cifar10": cifar10, "cifar100": cifar100, "svhn": svhn, "imagenet": imagenet,
+     "tinyimagenet": tinyimagenet, "imagenet_o": imagenet_o, "imagenet_a": imagenet_a
+ }
+
+ MODELS = {
+     "resnet50": resnet50, "resnet50_edl": resnet50_edl,
+     "wide_resnet": wrn, "wide_resnet_edl": wrn_edl,
+     "vgg16": vgg16, "vgg16_edl": vgg16_edl,
+     "imagenet_wide": imagenet_wide
+ }
+
+ MODEL_TO_NUM_DIM = {
+     "resnet50": 2048, "resnet50_edl": 2048, "wide_resnet": 640, "wide_resnet_edl": 640,
+     "vgg16": 512, "vgg16_edl": 512, "imagenet_wide": 2048,
+     "imagenet_vgg16": 4096, "imagenet_vit": 768
+ }
+
+ MODEL_TO_INPUT_DIM = {
+     "resnet50": 32, "resnet50_edl": 32, "wide_resnet": 32, "wide_resnet_edl": 32,
+     "vgg16": 32, "vgg16_edl": 32, "imagenet_wide": 224,
+     "imagenet_vgg16": 224, "imagenet_vit": 224
+ }
+
+ MODEL_TO_LAST_LAYER = {
+     "resnet50": "module.fc", "wide_resnet": "module.linear", "vgg16": "module.classifier",
+     "imagenet_wide": "module.linear",
+     "imagenet_vgg16": "module.classifier", "imagenet_vit": "module.linear"
+ }
+
+
+ def get_backend(backend, approx_type):
+     if backend == 'kazuki':
+         return AsdlGGN if approx_type == 'ggn' else AsdlEF
+     elif backend == 'backpack':
+         return BackPackGGN if approx_type == 'ggn' else BackPackEF
+     else:
+         raise ValueError(f"Unknown backend: {backend}")
+
+
+ def get_cpu_memory_mb():
+     import psutil
+     process = psutil.Process(os.getpid())
+     mem_info = process.memory_info()
+     return mem_info.rss / 1024 ** 2
+
+
+ def print_metrics(mean, std, name):
+     print(f"{name}: {mean * 100:.2f} ± {std * 100:.2f}")
+
+
+ if __name__ == "__main__":
+
+     args = laplace_eval_args().parse_args()
+
+     # Set random seed and device
+     torch.manual_seed(args.seed)
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print("Parsed args:", args)
+     print("Seed:", args.seed)
+
+     num_classes = DATASET_NUM_CLASSES[args.dataset]
+     test_loader = DATASET_LOADER[args.dataset].get_test_loader(
+         batch_size=args.batch_size, imagesize=MODEL_TO_INPUT_DIM[args.model], pin_memory=args.gpu
+     )
+     if args.ood_dataset == 'ood_union':
+         ood_test_loader = ood_union.get_combined_ood_test_loader(
+             batch_size=args.batch_size, sample_seed=args.seed,
+             imagesize=MODEL_TO_INPUT_DIM[args.model], pin_memory=args.gpu
+         )
+     else:
+         ood_test_loader = DATASET_LOADER[args.ood_dataset].get_test_loader(
+             batch_size=args.batch_size, imagesize=MODEL_TO_INPUT_DIM[args.model], pin_memory=args.gpu
+         )
+
+     # Prepare metric accumulators
+     accuracies, eces, ood_aurocs, err_aurocs, adv_aurocs = [], [], [], [], []
+     adv_unc = np.zeros((args.runs, 9))
+     adv_acc = np.zeros((args.runs, 9))
+     adv_ep = 0.02
+
+     for i in range(args.runs):
+         # Load training/validation splits
+         train_loader, val_loader = DATASET_LOADER[args.dataset].get_train_valid_loader(
+             batch_size=args.batch_size, imagesize=MODEL_TO_INPUT_DIM[args.model], augment=args.data_aug,
+             val_seed=(args.seed + i), val_size=args.val_size, pin_memory=args.gpu
+         )
+         mixture_components = []
+         for model_idx in range(args.nr_components):
+             if args.dataset == 'imagenet':
+                 net = MODELS[args.model](pretrained=True, num_classes=1000).cuda()
+                 net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
+                 cudnn.benchmark = True
+             else:
+                 if args.val_size == 0.1 or not args.crossval:
+                     saved_model_name = os.path.join(
+                         args.load_loc, f"Run{i+1}",
+                         model_load_name(args.model, args.sn, args.mod, args.coeff, args.seed, i) + "_350.model",
+                     )
+                 else:
+                     saved_model_name = os.path.join(
+                         args.load_loc, f"Run{i+1}",
+                         model_load_name(args.model, args.sn, args.mod, args.coeff, args.seed, i)
+                         + f"_350_0{int(args.val_size * 10)}.model"
+                     )
+                 print('Loading:', saved_model_name)
+                 net = MODELS[args.model](
+                     spectral_normalization=args.sn, mod=args.mod, coeff=args.coeff, num_classes=num_classes, temp=1.0
+                 )
+                 if args.gpu:
+                     net.cuda()
+                     net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
+                     cudnn.benchmark = True
+                 net.load_state_dict(torch.load(str(saved_model_name)))
+
+             # Laplace backend and fit
+             args.prior_precision = 1.0 if isinstance(args.prior_precision, float) else torch.load(args.prior_precision, map_location=device)
+             Backend = get_backend(args.backend, args.approx_type)
+             args.last_layer_name = MODEL_TO_LAST_LAYER[args.model]
+             optional_args = {"last_layer_name": args.last_layer_name} if args.subset_of_weights == 'last_layer' else {}
+
+             print('Fitting Laplace approximation...')
+             model = Laplace(
+                 net, args.likelihood, subset_of_weights=args.subset_of_weights,
+                 hessian_structure=args.hessian_structure, prior_precision=args.prior_precision,
+                 temperature=args.temperature, backend=Backend, **optional_args
+             )
+             model.fit(val_loader if args.crossval else train_loader)
+
+             # Optional: Optimize prior precision
+             if (args.optimize_prior_precision is not None) and (args.method == 'laplace'):
+                 n = model.n_params if args.prior_structure == 'all' else model.n_layers
+                 prior_precision = args.prior_precision * torch.ones(n, device=device)
+                 print('Optimizing prior precision...')
+                 model.optimize_prior_precision(
+                     method=args.optimize_prior_precision, init_prior_prec=prior_precision,
+                     val_loader=val_loader, pred_type=args.pred_type, link_approx=args.link_approx,
+                     n_samples=args.n_samples, verbose=(args.prior_structure == 'scalar')
+                 )
+             mixture_components.append(model)
+
+         model = mixture_components[0]
+         loss_fn = nn.NLLLoss()
+
+         # Evaluate ID data
+         id_y_true, id_y_prob = [], []
+         for data in tqdm(test_loader, desc='Evaluating ID data'):
+             x, y = data[0].to(device), data[1].to(device)
+             id_y_true.append(y.cpu())
+             y_prob = model(x, pred_type=args.pred_type, link_approx=args.link_approx, n_samples=args.n_samples)
+             id_y_prob.append(y_prob.cpu())
+         id_y_prob = torch.cat(id_y_prob, dim=0)
+         id_y_true = torch.cat(id_y_true, dim=0)
+         c, preds = torch.max(id_y_prob, 1)
+         metrics = {}
+         metrics['conf'] = c.mean().item()
+         metrics['nll'] = loss_fn(id_y_prob.log(), id_y_true).item()
+         metrics['acc'] = (id_y_true == preds).float().mean().item()
+         accuracy = metrics['acc']
+         id_confidences = id_y_prob.max(dim=1)[0].numpy()
+         ece = expected_calibration_error(id_confidences, preds.numpy(), id_y_true.numpy(), num_bins=15)
+         t_ece = ece
+         metrics['ece'] = ece
+         print(metrics)
+
+         # Evaluate OOD data
+         ood_y_true, ood_y_prob = [], []
+         for data in tqdm(ood_test_loader, desc='Evaluating OOD data'):
+             x, y = data[0].to(device), data[1].to(device)
+             ood_y_true.append(y.cpu())
+             y_prob = model(x, pred_type=args.pred_type, link_approx=args.link_approx, n_samples=args.n_samples)
+             ood_y_prob.append(y_prob.cpu())
+         ood_y_prob = torch.cat(ood_y_prob, dim=0)
+         ood_y_true = torch.cat(ood_y_true, dim=0)
+         ood_confidences = ood_y_prob.max(dim=1)[0].numpy()
+
+         # OOD AUROC/AUPRC metrics
+         bin_labels = np.concatenate([
+             np.zeros(id_confidences.shape[0]),
+             np.ones(ood_confidences.shape[0])
+         ])
+         scores = np.concatenate([id_confidences, ood_confidences])
+         fpr, tpr, thresholds = M.roc_curve(bin_labels, scores)
+         precision, recall, prc_thresholds = M.precision_recall_curve(bin_labels, scores)
+         ood_auroc = M.roc_auc_score(bin_labels, scores)
+         auprc = M.average_precision_score(bin_labels, scores)
+         print(f"OOD AUROC: {ood_auroc:.4f}, AUPRC: {auprc:.4f}")
+
+         # Error AUROC/AUPRC (in-distribution: correct vs incorrect)
+         labels_array = np.array(id_y_true)
+         pred_array = np.array(preds)
+         correct_mask = labels_array == pred_array
+         confidences_right = id_confidences[correct_mask]
+         confidences_wrong = id_confidences[~correct_mask]
+         bin_labels = np.concatenate([
+             np.zeros(confidences_right.shape[0]),
+             np.ones(confidences_wrong.shape[0])
+         ])
+         scores = np.concatenate([confidences_right, confidences_wrong])
+         err_auroc = M.roc_auc_score(bin_labels, scores)
+         err_auprc = M.average_precision_score(bin_labels, scores)
+         print(f"Error AUROC: {err_auroc:.4f}, AUPRC: {err_auprc:.4f}")
+
+         # Adversarial robustness
+         adv_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=adv_ep, batch_size=args.batch_size)
+         adv_y_prob, adv_y_true = [], []
+         for data in tqdm(adv_loader, desc='Adversarial evaluation'):
+             x, y = data[0].to(device), data[1].to(device)
+             y_prob = model(x, pred_type=args.pred_type, link_approx=args.link_approx, n_samples=args.n_samples)
+             adv_y_true.append(y.cpu())
+             adv_y_prob.append(y_prob.cpu())
+         adv_y_prob = torch.cat(adv_y_prob, dim=0)
+         adv_y_true = torch.cat(adv_y_true, dim=0).numpy()
+         _, adv_predictions = torch.max(adv_y_prob, 1)
+         # Compare as NumPy arrays (adv_y_true is already a NumPy array)
+         adv_accuracy = (adv_y_true == adv_predictions.numpy()).mean()
+         adv_confidences = adv_y_prob.max(dim=1)[0].numpy()
+         bin_labels = np.concatenate([
+             np.zeros(id_confidences.shape[0]),
+             np.ones(adv_confidences.shape[0])
+         ])
+         adv_scores = np.concatenate([id_confidences, adv_confidences])
+         adv_auroc = M.roc_auc_score(bin_labels, adv_scores)
+         adv_auprc = M.average_precision_score(bin_labels, adv_scores)
+         print(f"Adversarial AUROC: {adv_auroc:.4f}, AUPRC: {adv_auprc:.4f}")
+
+         # If sample_noise: save/plot noise-uncertainty and accuracy curves
+         if args.sample_noise:
+             adv_eps = np.linspace(0, 0.4, 9)
+             for idx_ep, ep in enumerate(adv_eps):
+                 adv_loader = create_adversarial_dataloader(
+                     net, test_loader, device, epsilon=ep, batch_size=args.batch_size
+                 )
+                 adv_y_prob, adv_y_true = [], []
+                 for data in tqdm(adv_loader, desc=f"Adv evaluation ep={ep:.2f}"):
+                     x, y = data[0].to(device), data[1].to(device)
+                     y_prob = model(x, pred_type=args.pred_type, link_approx=args.link_approx, n_samples=args.n_samples)
+                     adv_y_true.append(y.cpu())
+                     adv_y_prob.append(y_prob.cpu())
+                 adv_y_prob = torch.cat(adv_y_prob, dim=0)
+                 adv_y_true = torch.cat(adv_y_true, dim=0).numpy()
+                 _, predictions = torch.max(adv_y_prob, 1)
+                 adv_accuracy = (adv_y_true == predictions.numpy()).mean()
+                 uncertainties = 1 - adv_y_prob.max(dim=1)[0].numpy()
+                 adv_unc[i][idx_ep] = uncertainties.mean()
+                 adv_acc[i][idx_ep] = adv_accuracy
+             # Save/plot uncertainty/accuracy curves as in evaluate.py
+
+         # Accumulate results
+         accuracies.append(accuracy)
+         eces.append(ece)
+         ood_aurocs.append(ood_auroc)
+         err_aurocs.append(err_auroc)
+         adv_aurocs.append(adv_auroc)
+
+         del model, mixture_components
+         torch.cuda.empty_cache()
+         gc.collect()
+
+     # Final result reporting and saving
+
+     def mean_std(x):
+         arr = torch.tensor(x)
+         return arr.mean().item(), arr.std().item() / math.sqrt(arr.shape[0])
+
+     # Print summary
+     print_metrics(*mean_std(accuracies), "Accuracy")
+     print_metrics(*mean_std(eces), "ECE")
+     print_metrics(*mean_std(adv_aurocs), "Adv AUROC")
+     print_metrics(*mean_std(err_aurocs), "Error AUROC")
+     print_metrics(*mean_std(ood_aurocs), "OOD AUROC")
+
+     # Store only required metrics
+     result_json = {}
+     for key, arr in [
+         ("accuracy", accuracies),
+         ("ece", eces),
+         ("adv_auroc", adv_aurocs),
+         ("err_auroc", err_aurocs),
+         ("ood_auroc", ood_aurocs)
+     ]:
+         mean, std = mean_std(arr)
+         result_json[key] = {
+             "mean": mean,
+             "std": std,
+             "values": [float(v) for v in arr]
+         }
+
+     result_file = (
+         "res_" + model_save_name(args.model, args.sn, args.mod, args.coeff, args.seed)
+         + "_laplace_" + args.dataset + "_" + args.ood_dataset + ".json"
+     )
+     with open(result_file, "w") as f:
+         json.dump(result_json, f, indent=2)
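For orientation, a self-contained sketch of the last-layer Laplace workflow this script relies on (laplace-torch API); the toy model and data below are illustrative placeholders:

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
    from laplace import Laplace

    net = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))  # toy classifier
    X, y = torch.randn(256, 10), torch.randint(0, 3, (256,))
    train_loader = DataLoader(TensorDataset(X, y), batch_size=32)

    la = Laplace(net, "classification", subset_of_weights="last_layer", hessian_structure="kron")
    la.fit(train_loader)
    la.optimize_prior_precision(method="marglik")
    probs = la(torch.randn(4, 10), pred_type="glm", link_approx="probit")  # (4, 3) predictive probabilities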
SPC-UQ/Image_Classification/metrics/__init__.py ADDED
File without changes
SPC-UQ/Image_Classification/metrics/calibration_metrics.py ADDED
@@ -0,0 +1,129 @@
+ """
+ Metrics to measure calibration of a trained deep neural network.
+ References:
+ [1] C. Guo, G. Pleiss, Y. Sun, and K. Q. Weinberger. On calibration of modern neural networks.
+     arXiv preprint arXiv:1706.04599, 2017.
+ """
+
+ import math
+ import torch
+ import numpy as np
+ from torch import nn
+ from torch.nn import functional as F
+
+ import matplotlib.pyplot as plt
+
+ plt.rcParams.update({"font.size": 20})
+
+
+ # Some keys used for the following dictionaries
+ COUNT = "count"
+ CONF = "conf"
+ ACC = "acc"
+ BIN_ACC = "bin_acc"
+ BIN_CONF = "bin_conf"
+
+
+ def _bin_initializer(num_bins=10):
+     bin_dict = {}
+     for i in range(num_bins):
+         bin_dict[i] = {}
+         bin_dict[i][COUNT] = 0
+         bin_dict[i][CONF] = 0
+         bin_dict[i][ACC] = 0
+         bin_dict[i][BIN_ACC] = 0
+         bin_dict[i][BIN_CONF] = 0
+     return bin_dict
+
+
+ def _populate_bins(confs, preds, labels, num_bins=10):
+     bin_dict = _bin_initializer(num_bins)
+     num_test_samples = len(confs)
+
+     for i in range(0, num_test_samples):
+         confidence = confs[i]
+         prediction = preds[i]
+         label = labels[i]
+         # Clamp the bin index so that confidence == 1.0 falls into the last bin
+         binn = min(num_bins - 1, max(0, int(num_bins * confidence)))
+         bin_dict[binn][COUNT] = bin_dict[binn][COUNT] + 1
+         bin_dict[binn][CONF] = bin_dict[binn][CONF] + confidence
+         bin_dict[binn][ACC] = bin_dict[binn][ACC] + (1 if (label == prediction) else 0)
+
+     for binn in range(0, num_bins):
+         if bin_dict[binn][COUNT] == 0:
+             bin_dict[binn][BIN_ACC] = 0
+             bin_dict[binn][BIN_CONF] = 0
+         else:
+             bin_dict[binn][BIN_ACC] = float(bin_dict[binn][ACC]) / bin_dict[binn][COUNT]
+             bin_dict[binn][BIN_CONF] = bin_dict[binn][CONF] / float(bin_dict[binn][COUNT])
+     return bin_dict
+
+
+ def expected_calibration_error(confs, preds, labels, num_bins=10):
+     bin_dict = _populate_bins(confs, preds, labels, num_bins)
+     num_samples = len(labels)
+     ece = 0
+     for i in range(num_bins):
+         bin_accuracy = bin_dict[i][BIN_ACC]
+         bin_confidence = bin_dict[i][BIN_CONF]
+         bin_count = bin_dict[i][COUNT]
+         ece += (float(bin_count) / num_samples) * abs(bin_accuracy - bin_confidence)
+     return ece
+
+
+ # Calibration error scores in the form of loss metrics
+ class ECELoss(nn.Module):
+     """
+     Compute ECE (Expected Calibration Error)
+     """
+
+     def __init__(self, n_bins=15):
+         super(ECELoss, self).__init__()
+         bin_boundaries = torch.linspace(0, 1, n_bins + 1)
+         self.bin_lowers = bin_boundaries[:-1]
+         self.bin_uppers = bin_boundaries[1:]
+
+     def forward(self, logits, labels):
+         softmaxes = F.softmax(logits, dim=1)
+         confidences, predictions = torch.max(softmaxes, 1)
+         accuracies = predictions.eq(labels)
+
+         ece = torch.zeros(1, device=logits.device)
+         for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
+             # Compute |confidence - accuracy| in each bin
+             in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
+             prop_in_bin = in_bin.float().mean()
+             if prop_in_bin.item() > 0:
+                 accuracy_in_bin = accuracies[in_bin].float().mean()
+                 avg_confidence_in_bin = confidences[in_bin].mean()
+                 ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
+
+         return ece
+
+
+ # Methods for plotting reliability diagrams and bin-strength plots
+ def reliability_plot(confs, preds, labels, num_bins=15, model_name='model'):
+     """
+     Method to draw a reliability plot from a model's predictions and confidences.
+     """
+     bin_dict = _populate_bins(confs, preds, labels, num_bins)
+     bns = [(i / float(num_bins)) for i in range(num_bins)]
+     y = []
+     for i in range(num_bins):
+         y.append(bin_dict[i][BIN_ACC])
+     plt.figure(figsize=(10, 8))
+     plt.bar(bns, bns, align="edge", width=0.03, color="pink", label="Expected")
+     plt.bar(bns, y, align="edge", width=0.03, color="blue", alpha=0.5, label="Actual")
+     plt.ylabel("Accuracy", fontsize=30)
+     plt.xlabel("Confidence", fontsize=30)
+     plt.xticks(fontsize=30)
+     plt.yticks(fontsize=30)
+     plt.legend(fontsize=30, loc='upper left')
+     plt.savefig(f'./reliability_plot_{model_name}.pdf')
+     plt.savefig(f'./reliability_plot_{model_name}.png')
+     plt.show()
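As a sanity check on `expected_calibration_error`, a tiny worked example (three high-confidence predictions and one mid-confidence prediction):

    confs = [0.90, 0.92, 0.88, 0.60]
    preds = [1, 0, 2, 1]
    labels = [1, 0, 1, 1]
    # bin 9: avg conf ~0.91, acc = 1.0 -> gap 0.09 with weight 2/4
    # bin 8: conf = 0.88, acc = 0.0 -> gap 0.88 with weight 1/4
    # bin 6: conf = 0.60, acc = 1.0 -> gap 0.40 with weight 1/4
    print(round(expected_calibration_error(confs, preds, labels, num_bins=10), 4))  # 0.365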
SPC-UQ/Image_Classification/metrics/classification_metrics.py ADDED
@@ -0,0 +1,211 @@
+ """
+ Metrics to measure classification performance
+ """
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from torch.utils.data import DataLoader, TensorDataset
+
+ from utils.ensemble_utils import ensemble_forward_pass
+
+ from sklearn.metrics import accuracy_score
+ from sklearn.metrics import confusion_matrix
+
+
+ def evidential_loss(alpha, target, lambda_reg=0.001):
+     num_classes = alpha.shape[1]
+     target_one_hot = F.one_hot(target, num_classes=num_classes).float()
+     S = alpha.sum(dim=1, keepdim=True)
+
+     log_likelihood = torch.sum(target_one_hot * (torch.digamma(S) - torch.digamma(alpha)), dim=1)
+     kl_divergence = lambda_reg * torch.sum((alpha - 1) * (1 - target_one_hot), dim=1)
+
+     loss = log_likelihood + kl_divergence
+     return torch.mean(loss)
+
+
+ def create_adversarial_dataloader(model, data_loader, device, epsilon=0.03, batch_size=32, edl=False, joint=False):
+     """
+     Build a dataloader of FGSM adversarial examples (one gradient-sign step of size epsilon).
+     """
+     adv_examples = []
+     adv_labels = []
+
+     model.eval()
+     for data, label in data_loader:
+         data = data.to(device).detach().requires_grad_(True)
+         label = label.to(device)
+
+         model.zero_grad()
+         logit = model(data)
+         # Pick the loss matching the head type; the branches are mutually exclusive
+         if edl:
+             loss = evidential_loss(logit, label)
+         elif joint:
+             loss = F.cross_entropy(logit[0], label)
+         else:
+             loss = F.cross_entropy(logit, label)
+         loss.backward()
+
+         signed_grad = data.grad.sign()
+         data_adv = data + epsilon * signed_grad
+
+         adv_examples.append(data_adv.detach().cpu())
+         adv_labels.append(label.detach().cpu())
+
+     adv_examples = torch.cat(adv_examples, dim=0)
+     adv_labels = torch.cat(adv_labels, dim=0)
+
+     adv_dataset = TensorDataset(adv_examples, adv_labels)
+     adv_dataloader = DataLoader(adv_dataset, batch_size=batch_size, shuffle=False)
+
+     return adv_dataloader
+
+
+ def get_logits_labels(model, data_loader, device):
+     """
+     Utility function to get logits and labels.
+     """
+     model.eval()
+     logits = []
+     labels = []
+     with torch.no_grad():
+         for data, label in data_loader:
+             data = data.to(device)
+             label = label.to(device)
+
+             logit = model(data)
+             logits.append(logit)
+             labels.append(label)
+     logits = torch.cat(logits, dim=0)
+     labels = torch.cat(labels, dim=0)
+     return logits, labels
+
+
+ def get_logits_labels_uq(model, data_loader, device):
+     """
+     Utility function to get logits as a list: [pred_all, mar_all, mar_up_all, mar_down_all]
+     and labels as a single tensor.
+     """
+     model.eval()
+     preds = []
+     mars = []
+     mars_up = []
+     mars_down = []
+     labels = []
+
+     with torch.no_grad():
+         for data, label in data_loader:
+             data = data.to(device)
+             label = label.to(device)
+
+             pred, mar, mar_up, mar_down = model(data)
+             preds.append(pred)
+             mars.append(mar)
+             mars_up.append(mar_up)
+             mars_down.append(mar_down)
+             labels.append(label)
+
+     pred_all = torch.cat(preds, dim=0)
+     mar_all = torch.cat(mars, dim=0)
+     mar_up_all = torch.cat(mars_up, dim=0)
+     mar_down_all = torch.cat(mars_down, dim=0)
+     labels_all = torch.cat(labels, dim=0)
+
+     logits = [pred_all, mar_all, mar_up_all, mar_down_all]
+     return logits, labels_all
+
+
+ def test_classification_net_softmax(softmax_prob, labels):
+     """
+     This function reports classification accuracy and confusion matrix given softmax vectors and
+     labels from a model.
+     """
+     labels_list = []
+     predictions_list = []
+     confidence_vals_list = []
+
+     confidence_vals, predictions = torch.max(softmax_prob, dim=1)
+     labels_list.extend(labels.cpu().numpy())
+     predictions_list.extend(predictions.cpu().numpy())
+     confidence_vals_list.extend(confidence_vals.detach().cpu().numpy())
+     accuracy = accuracy_score(labels_list, predictions_list)
+     return (
+         confusion_matrix(labels_list, predictions_list),
+         accuracy,
+         labels_list,
+         predictions_list,
+         confidence_vals_list,
+     )
+
+
+ def test_classification_net_logits(logits, labels):
+     """
+     This function reports classification accuracy and confusion matrix given logits and labels
+     from a model.
+     """
+     softmax_prob = F.softmax(logits, dim=1)
+     return test_classification_net_softmax(softmax_prob, labels)
+
+
+ def test_classification_net(model, data_loader, device):
+     """
+     This function reports classification accuracy and confusion matrix over a dataset.
+     """
+     logits, labels = get_logits_labels(model, data_loader, device)
+     return test_classification_net_logits(logits, labels)
+
+
+ def test_classification_uq(model, data_loader, device):
+     """
+     This function reports classification accuracy and confusion matrix over a dataset.
+     """
+     logits, labels = get_logits_labels_uq(model, data_loader, device)
+     return test_classification_net_logits(logits[0], labels)
+
+
+ def test_classification_net_ensemble(model_ensemble, data_loader, device):
+     """
+     This function reports classification accuracy and confusion matrix over a dataset
+     for a deep ensemble.
+     """
+     for model in model_ensemble:
+         model.eval()
+     softmax_prob = []
+     labels = []
+     with torch.no_grad():
+         for data, label in data_loader:
+             data = data.to(device)
+             label = label.to(device)
+
+             softmax, _, _ = ensemble_forward_pass(model_ensemble, data)
+             softmax_prob.append(softmax)
+             labels.append(label)
+     softmax_prob = torch.cat(softmax_prob, dim=0)
+     labels = torch.cat(labels, dim=0)
+
+     return test_classification_net_softmax(softmax_prob, labels)
+
+
+ def test_classification_net_logits_edl(logits, labels):
+     """
+     This function reports classification accuracy and confusion matrix given Dirichlet
+     evidence (alpha) and labels from a model.
+     """
+     labels_list = []
+     predictions_list = []
+     confidence_vals_list = []
+
+     predicted_probs = logits / logits.sum(dim=1, keepdim=True)
+     confidence_vals, predictions = torch.max(predicted_probs, dim=1)
+     labels_list.extend(labels.cpu().numpy())
+     predictions_list.extend(predictions.cpu().numpy())
+     confidence_vals_list.extend(confidence_vals.cpu().numpy())
+     accuracy = accuracy_score(labels_list, predictions_list)
+     return (
+         confusion_matrix(labels_list, predictions_list),
+         accuracy,
+         labels_list,
+         predictions_list,
+         confidence_vals_list,
+     )
+
+
+ def test_classification_net_edl(model, data_loader, device):
+     """
+     This function reports classification accuracy and confusion matrix over a dataset.
+     """
+     logits, labels = get_logits_labels(model, data_loader, device)
+     return test_classification_net_logits_edl(logits, labels)
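The adversarial loader above is a single-step FGSM attack; stripped of the loader plumbing, the core update is one gradient-sign step (toy linear model, illustrative only):

    import torch
    import torch.nn.functional as F

    model = torch.nn.Linear(8, 3)
    x = torch.randn(4, 8, requires_grad=True)
    y = torch.randint(0, 3, (4,))
    F.cross_entropy(model(x), y).backward()
    x_adv = (x + 0.03 * x.grad.sign()).detach()  # epsilon = 0.03, the default above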
SPC-UQ/Image_Classification/metrics/ood_metrics.py ADDED
@@ -0,0 +1,135 @@
+ # Utility functions to get OOD detection ROC curves and AUROC scores
+ # Ideally should be agnostic of model architectures
+
+ import torch
+ import torch.nn.functional as F
+ from sklearn import metrics
+
+ from utils.ensemble_utils import ensemble_forward_pass
+ from metrics.classification_metrics import get_logits_labels
+ from metrics.uncertainty_confidence import entropy, logsumexp, confidence, edl_unc
+
+
+ def get_roc_auc(net, test_loader, ood_test_loader, uncertainty, device, confidence=False):
+     logits, _ = get_logits_labels(net, test_loader, device)
+     ood_logits, _ = get_logits_labels(net, ood_test_loader, device)
+
+     return get_roc_auc_logits(logits, ood_logits, uncertainty, device, confidence=confidence)
+
+
+ def get_roc_auc_logits(logits, ood_logits, uncertainty, device, confidence=False):
+     uncertainties = uncertainty(logits)
+     ood_uncertainties = uncertainty(ood_logits)
+
+     # In-distribution samples are labelled 0, OOD samples 1
+     bin_labels = torch.zeros(uncertainties.shape[0]).to(device)
+     in_scores = uncertainties
+
+     bin_labels = torch.cat((bin_labels, torch.ones(ood_uncertainties.shape[0]).to(device)))
+
+     if confidence:
+         # Higher score means more in-distribution, so flip the labels
+         bin_labels = 1 - bin_labels
+     ood_scores = ood_uncertainties
+     scores = torch.cat((in_scores, ood_scores))
+
+     fpr, tpr, thresholds = metrics.roc_curve(bin_labels.cpu().numpy(), scores.cpu().numpy())
+     precision, recall, prc_thresholds = metrics.precision_recall_curve(bin_labels.cpu().numpy(), scores.cpu().numpy())
+     auroc = metrics.roc_auc_score(bin_labels.cpu().numpy(), scores.cpu().numpy())
+     auprc = metrics.average_precision_score(bin_labels.cpu().numpy(), scores.cpu().numpy())
+
+     return (fpr, tpr, thresholds), (precision, recall, prc_thresholds), auroc, auprc
+
+
+ def get_roc_auc_uncs(uncertainties, ood_uncertainties, device, confidence=False):
+     # In-distribution samples are labelled 0, OOD samples 1
+     bin_labels = torch.zeros(uncertainties.shape[0]).to(device)
+     in_scores = uncertainties
+
+     bin_labels = torch.cat((bin_labels, torch.ones(ood_uncertainties.shape[0]).to(device)))
+
+     if confidence:
+         bin_labels = 1 - bin_labels
+     ood_scores = ood_uncertainties
+     scores = torch.cat((in_scores, ood_scores))
+
+     fpr, tpr, thresholds = metrics.roc_curve(bin_labels.detach().cpu().numpy(), scores.detach().cpu().numpy())
+     precision, recall, prc_thresholds = metrics.precision_recall_curve(bin_labels.detach().cpu().numpy(), scores.detach().cpu().numpy())
+     auroc = metrics.roc_auc_score(bin_labels.detach().cpu().numpy(), scores.detach().cpu().numpy())
+     auprc = metrics.average_precision_score(bin_labels.detach().cpu().numpy(), scores.detach().cpu().numpy())
+
+     return (fpr, tpr, thresholds), (precision, recall, prc_thresholds), auroc, auprc
+
+
+ def get_roc_auc_ensemble(model_ensemble, test_loader, ood_test_loader, uncertainty, device):
+     for model in model_ensemble:
+         model.eval()
+
+     bin_labels_uncertainties = []
+     uncertainties = []
+     with torch.no_grad():
+         # Getting uncertainties for in-distribution data
+         for data, label in test_loader:
+             data = data.to(device)
+             label = label.to(device)
+
+             bin_label_uncertainty = torch.zeros(label.shape).to(device)
+             if uncertainty == "mutual_information":
+                 net_output, _, unc = ensemble_forward_pass(model_ensemble, data)
+             else:
+                 net_output, unc, _ = ensemble_forward_pass(model_ensemble, data)
+
+             bin_labels_uncertainties.append(bin_label_uncertainty)
+             uncertainties.append(unc)
+
+         # Getting uncertainties for OOD data
+         for data, label in ood_test_loader:
+             data = data.to(device)
+             label = label.to(device)
+
+             bin_label_uncertainty = torch.ones(label.shape).to(device)
+             if uncertainty == "mutual_information":
+                 net_output, _, unc = ensemble_forward_pass(model_ensemble, data)
+             else:
+                 net_output, unc, _ = ensemble_forward_pass(model_ensemble, data)
+
+             bin_labels_uncertainties.append(bin_label_uncertainty)
+             uncertainties.append(unc)
+
+     bin_labels_uncertainties = torch.cat(bin_labels_uncertainties)
+     uncertainties = torch.cat(uncertainties)
+
+     fpr, tpr, roc_thresholds = metrics.roc_curve(bin_labels_uncertainties.cpu().numpy(), uncertainties.cpu().numpy())
+     precision, recall, prc_thresholds = metrics.precision_recall_curve(
+         bin_labels_uncertainties.cpu().numpy(), uncertainties.cpu().numpy()
+     )
+     auroc = metrics.roc_auc_score(bin_labels_uncertainties.cpu().numpy(), uncertainties.cpu().numpy())
+     auprc = metrics.average_precision_score(bin_labels_uncertainties.cpu().numpy(), uncertainties.cpu().numpy())
+
+     return (fpr, tpr, roc_thresholds), (precision, recall, prc_thresholds), auroc, auprc
+
+
+ def get_unc_ensemble(model_ensemble, test_loader, uncertainty, device):
+     for model in model_ensemble:
+         model.eval()
+
+     uncertainties = []
+     with torch.no_grad():
+         for data, label in test_loader:
+             data = data.to(device)
+
+             if uncertainty == "mutual_information":
+                 net_output, _, unc = ensemble_forward_pass(model_ensemble, data)
+             else:
+                 net_output, unc, _ = ensemble_forward_pass(model_ensemble, data)
+
+             uncertainties.append(unc)
+
+     uncertainties = torch.cat(uncertainties)
+
+     return uncertainties
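A hedged usage sketch for `get_roc_auc_logits`, with `entropy` as the uncertainty score and synthetic logits standing in for real model outputs (peaked logits should score as in-distribution):

    import torch
    from metrics.uncertainty_confidence import entropy

    id_logits = 5 * torch.randn(100, 10)     # confident, peaked logits
    ood_logits = 0.1 * torch.randn(100, 10)  # near-uniform logits
    _, _, auroc, auprc = get_roc_auc_logits(id_logits, ood_logits, entropy, "cpu")
    print(auroc, auprc)  # AUROC well above 0.5 in expectation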
SPC-UQ/Image_Classification/metrics/uncertainty_confidence.py ADDED
@@ -0,0 +1,67 @@
+ """
+ Metrics measuring either uncertainty or confidence of a model.
+ """
+ import torch
+ import torch.nn.functional as F
+
+
+ def entropy(logits):
+     p = F.softmax(logits, dim=1)
+     logp = F.log_softmax(logits, dim=1)
+     plogp = p * logp
+     entropy = -torch.sum(plogp, dim=1)
+     return entropy
+
+
+ def logsumexp(logits):
+     return torch.logsumexp(logits, dim=1, keepdim=False)
+
+
+ def confidence(logits):
+     p = F.softmax(logits, dim=1)
+     confidence, _ = torch.max(p, dim=1)
+     return confidence
+
+
+ def entropy_prob(probs):
+     p = probs
+     eps = 1e-12
+     logp = torch.log(p + eps)
+     plogp = p * logp
+     entropy = -torch.sum(plogp, dim=1)
+     return entropy
+
+
+ def mutual_information_prob(probs):
+     mean_output = torch.mean(probs, dim=0)
+     predictive_entropy = entropy_prob(mean_output)
+
+     # Computing expectation of entropies
+     p = probs
+     eps = 1e-12
+     logp = torch.log(p + eps)
+     plogp = p * logp
+     exp_entropies = torch.mean(-torch.sum(plogp, dim=2), dim=0)
+
+     # Computing mutual information
+     mi = predictive_entropy - exp_entropies
+     return mi
+
+
+ def self_consistency(mars):
+     logits = mars[0]
+     probs = F.softmax(logits, dim=1)
+     mar = mars[1].squeeze()
+     mar_up = 1 - probs
+     mar_down = probs
+     # Since mar_up + mar_down = 1, a calibrated model should satisfy mar = 2 * mar_up * mar_down
+     uncertainty = torch.abs(2 * mar_up * mar_down - mar)
+     uncertainty = torch.sum(uncertainty, dim=1)
+     return uncertainty
+
+
+ def edl_unc(logits):
+     num_classes = logits.shape[1]
+     uncertainty = num_classes / logits.sum(dim=1)
+     return uncertainty
+
+
+ def certificate(OCs):
+     return OCs
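The residual in `self_consistency` rests on a calibration identity: if class c is correct with probability p_c, then E|1[y = c] − p_c| = 2·p_c·(1 − p_c), so `mar` should match `2 * mar_up * mar_down` for a calibrated model. A toy check (illustrative tensors only):

    import torch

    logits = torch.tensor([[2.0, 0.5, -1.0]])
    p = torch.softmax(logits, dim=1)
    mar = 2 * p * (1 - p)                   # the marginal a perfectly calibrated model would report
    print(self_consistency([logits, mar]))  # tensor([0.]) up to float error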
SPC-UQ/Image_Classification/net/__init__.py ADDED
File without changes
SPC-UQ/Image_Classification/net/imagenet_vgg.py ADDED
@@ -0,0 +1,106 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Optional
+ from torchvision.models import vgg16, VGG16_Weights
+
+
+ class ImagenetVGG16(nn.Module):
+     """
+     VGG16 wrapper for ImageNet-like classification with:
+     - Optional pretrained backbone
+     - Optional feature freezing
+     - Temperature-scaled logits
+     - Exposes penultimate features via `self.feature` (detached)
+     """
+
+     def __init__(
+         self,
+         num_classes: int = 1000,
+         pretrained: bool = True,
+         temp: float = 1.0,
+         freeze_features: bool = False,
+     ) -> None:
+         super().__init__()
+
+         # Load base model (weights imply specific preprocessing; handle in dataloader)
+         base_model = vgg16(weights=VGG16_Weights.DEFAULT if pretrained else None)
+
+         # Convolutional feature extractor and avgpool
+         self.features: nn.Module = base_model.features
+         self.avgpool: nn.Module = base_model.avgpool
+
+         # Penultimate FC stack from original VGG16 (remove final classifier layer)
+         # VGG16 classifier:
+         # [Linear(25088->4096), ReLU, Dropout, Linear(4096->4096), ReLU, Dropout, Linear(4096->1000)]
+         self.fc_pre: nn.Sequential = nn.Sequential(*list(base_model.classifier.children())[:-1])
+
+         # New classification head
+         self.classifier: nn.Linear = nn.Linear(4096, num_classes)
+
+         # If using pretrained and num_classes matches 1000, copy the final layer weights
+         if pretrained and num_classes == 1000:
+             with torch.no_grad():
+                 self.classifier.weight.copy_(base_model.classifier[-1].weight)
+                 self.classifier.bias.copy_(base_model.classifier[-1].bias)
+
+         # Optional: freeze convolutional features (+ fc_pre) for linear probing
+         if freeze_features:
+             for p in self.features.parameters():
+                 p.requires_grad = False
+             for p in self.fc_pre.parameters():
+                 p.requires_grad = False
+
+         # Temperature (applied to logits at inference/training)
+         self.register_buffer("temperature", torch.tensor(float(temp)))
+         self.feature: Optional[torch.Tensor] = None  # cached detached penultimate features
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             x: Input batch of shape (B, 3, H, W)
+
+         Returns:
+             logits: Tensor of shape (B, num_classes)
+         """
+         x = self.features(x)
+         x = self.avgpool(x)
+         x = torch.flatten(x, 1)
+         x = self.fc_pre(x)
+
+         # Cache penultimate features (detached) for downstream use
+         self.feature = x.detach()
+
+         logits = self.classifier(x)
+         if float(self.temperature) != 1.0:
+             logits = logits / self.temperature
+         return logits
+
+
+ def imagenet_vgg16(
+     temp: float = 1.0,
+     pretrained: bool = True,
+     num_classes: int = 1000,
+     freeze_features: bool = False,
+     **kwargs,
+ ) -> ImagenetVGG16:
+     """
+     Factory function for ImagenetVGG16.
+
+     Args:
+         temp: Temperature applied to logits (T=1.0 disables scaling).
+         pretrained: Load pretrained ImageNet weights for the backbone.
+         num_classes: Output classes for the final classifier.
+         freeze_features: If True, freeze backbone + fc_pre for linear probing.
+         **kwargs: Forwarded to the ImagenetVGG16 init (future-proof).
+
+     Returns:
+         Initialized ImagenetVGG16 model.
+     """
+     return ImagenetVGG16(
+         num_classes=num_classes,
+         pretrained=pretrained,
+         temp=temp,
+         freeze_features=freeze_features,
+         **kwargs,
+     )
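A hypothetical linear-probe setup with the factory above: frozen backbone, fresh 10-way head, and the cached penultimate features:

    import torch

    model = imagenet_vgg16(pretrained=False, num_classes=10, freeze_features=True)
    logits = model(torch.randn(2, 3, 224, 224))  # (2, 10)
    feats = model.feature                        # (2, 4096), detached penultimate features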
SPC-UQ/Image_Classification/net/imagenet_vit.py ADDED
@@ -0,0 +1,101 @@
+ import torch
+ import torch.nn as nn
+ from typing import Optional
+ from torchvision.models import vit_b_16, ViT_B_16_Weights
+
+
+ class ImagenetViT(nn.Module):
+     """
+     ViT-B/16 wrapper with:
+     - Optional pretrained backbone (torchvision)
+     - Proper CLS token + positional embeddings usage
+     - Temperature-scaled logits
+     - Exposed CLS feature via `self.feature` (detached)
+     - Optional backbone freezing for linear probing
+     """
+
+     def __init__(
+         self,
+         num_classes: int = 1000,
+         pretrained: bool = True,
+         temp: float = 1.0,
+         freeze_backbone: bool = False,
+     ) -> None:
+         super().__init__()
+         self.backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT if pretrained else None)
+
+         # Hidden size & final norm
+         self.hidden_dim: int = self.backbone.hidden_dim
+         self.norm: nn.Module = self.backbone.encoder.ln  # final LayerNorm
+
+         # New classification head
+         self.head: nn.Linear = nn.Linear(self.hidden_dim, num_classes)
+
+         # If using pretrained and keeping 1000-class head, copy weights
+         if pretrained and num_classes == 1000:
+             with torch.no_grad():
+                 src = self.backbone.heads.head
+                 self.head.weight.copy_(src.weight)
+                 self.head.bias.copy_(src.bias)
+
+         # Optionally freeze everything except the head
+         if freeze_backbone:
+             for p in self.backbone.parameters():
+                 p.requires_grad = False
+
+         # Temperature buffer and CLS feature cache
+         self.register_buffer("temperature", torch.tensor(float(temp)))
+         self.feature: Optional[torch.Tensor] = None
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             x: (B, 3, H, W)
+
+         Returns:
+             logits: (B, num_classes)
+         """
+         # Patchify + linear proj (handles resizing logic)
+         x = self.backbone._process_input(x)  # (B, N, hidden_dim)
+
+         # CLS token + positional embeddings (use pretrained params)
+         n = x.shape[0]
+         cls_token = self.backbone.class_token.expand(n, -1, -1)  # (B, 1, hidden_dim)
+         x = torch.cat((cls_token, x), dim=1)  # (B, N+1, hidden_dim)
+
+         # Positional embeddings + dropout
+         x = x + self.backbone.encoder.pos_embedding  # (B, N+1, hidden_dim)
+         x = self.backbone.encoder.dropout(x)
+
+         # Transformer encoder (layers) + final norm
+         x = self.backbone.encoder.layers(x)
+         x = self.norm(x)
+
+         # CLS feature
+         cls = x[:, 0]  # (B, hidden_dim)
+         self.feature = cls.detach()  # cache detached feature
+
+         # Head + optional temperature scaling
+         logits = self.head(cls)
+         if float(self.temperature) != 1.0:
+             logits = logits / self.temperature
+         return logits
+
+
+ def imagenet_vit(
+     temp: float = 1.0,
+     pretrained: bool = True,
+     num_classes: int = 1000,
+     freeze_backbone: bool = False,
+     **kwargs,
+ ) -> ImagenetViT:
+     """
+     Factory for ImagenetViT.
+     """
+     return ImagenetViT(
+         num_classes=num_classes,
+         pretrained=pretrained,
+         temp=temp,
+         freeze_backbone=freeze_backbone,
+         **kwargs,
+     )
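Since temperature only rescales the logits, the argmax prediction is unchanged while softmax confidence is flattened; a quick sketch with random weights (eval mode keeps dropout off):

    import torch

    m = imagenet_vit(pretrained=False, num_classes=10, temp=2.0).eval()
    x = torch.randn(1, 3, 224, 224)
    scaled = m(x)                            # logits / 2
    m.temperature.fill_(1.0)                 # disable scaling on the same weights
    print(torch.allclose(m(x), 2 * scaled))  # True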
SPC-UQ/Image_Classification/net/imagenet_wide.py ADDED
@@ -0,0 +1,46 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision.models import wide_resnet50_2, Wide_ResNet50_2_Weights
+
+
+ class ImagenetWideResNet(nn.Module):
+     def __init__(self, num_classes=1000, pretrained=True, temp=1.0):
+         super().__init__()
+
+         base_model = wide_resnet50_2(weights=Wide_ResNet50_2_Weights.DEFAULT if pretrained else None)
+
+         # Mirror the WRN structure used elsewhere in this repo: feature stack + single linear head
+         self.features = nn.Sequential(
+             base_model.conv1,
+             base_model.bn1,
+             base_model.relu,
+             base_model.maxpool,
+             base_model.layer1,
+             base_model.layer2,
+             base_model.layer3,
+             base_model.layer4,
+             base_model.avgpool
+         )
+
+         self.linear = nn.Linear(2048, num_classes)
+         self.temp = temp
+         self.feature = None
+
+         # Copy the pretrained head only when its shape matches
+         if pretrained and num_classes == 1000:
+             self.linear.load_state_dict(base_model.fc.state_dict())
+
+     def forward(self, x):
+         out = self.features(x)
+         out = torch.flatten(out, 1)
+         self.feature = out.clone().detach()
+         if self.temp == 1:
+             out = self.linear(out)
+         else:
+             out = self.linear(out) / self.temp
+         return out
+
+
+ def imagenet_wide(temp=1.0, pretrained=True, **kwargs):
+     model = ImagenetWideResNet(pretrained=pretrained, temp=temp, **kwargs)
+     return model
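Usage sketch: a forward pass populates `model.feature` with the pooled 2048-d embedding, which downstream feature-based utilities can read back:

    import torch

    model = imagenet_wide(pretrained=False, num_classes=1000).eval()
    logits = model(torch.randn(2, 3, 224, 224))  # (2, 1000)
    print(model.feature.shape)                   # torch.Size([2, 2048])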
SPC-UQ/Image_Classification/net/lenet.py ADDED
@@ -0,0 +1,37 @@
+ """Implementation of LeNet in PyTorch.
+ Reference:
+ [1] LeCun, Y., Bottou, L., Bengio, Y., & Haffner, P. (1998).
+     Gradient-based learning applied to document recognition.
+     Proceedings of the IEEE, 86, 2278-2324.
+ """
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class LeNet(nn.Module):
+     def __init__(self, num_classes, temp=1.0, mnist=True, **kwargs):
+         super(LeNet, self).__init__()
+         self.num_classes = num_classes
+         self.conv1 = nn.Conv2d(1 if mnist else 3, 6, 5)
+         self.conv2 = nn.Conv2d(6, 16, 5)
+         self.fc1 = nn.Linear(256, 120)
+         self.fc2 = nn.Linear(120, 84)
+         self.fc3 = nn.Linear(84, num_classes)
+         self.temp = temp
+         self.feature = None
+
+     def forward(self, x):
+         out = F.relu(self.conv1(x))
+         out = F.max_pool2d(out, 2)
+         out = F.relu(self.conv2(out))
+         out = F.max_pool2d(out, 2)
+         out = out.view(out.size(0), -1)
+         out = F.relu(self.fc1(out))
+         out = F.relu(self.fc2(out))
+         self.feature = out
+         out = self.fc3(out) / self.temp
+         return out
+
+
+ def lenet(num_classes=10, temp=1.0, mnist=True, **kwargs):
+     # Forward the mnist flag instead of hard-coding it to True
+     return LeNet(num_classes=num_classes, temp=temp, mnist=mnist, **kwargs)
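Shape check: with a 28×28 MNIST input, two conv/pool stages leave 16 feature maps of size 4×4, which is exactly the 256 in-features `fc1` expects:

    import torch

    net = lenet(num_classes=10, mnist=True)
    out = net(torch.randn(4, 1, 28, 28))
    print(out.shape)  # torch.Size([4, 10])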
SPC-UQ/Image_Classification/net/resnet.py ADDED
@@ -0,0 +1,245 @@
+ """
+ Pytorch implementation of ResNet models.
+ Reference:
+ [1] He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: CVPR, 2016.
+ """
+ import torch
+ import math
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from net.spectral_normalization.spectral_norm_conv_inplace import spectral_norm_conv
+ from net.spectral_normalization.spectral_norm_fc import spectral_norm_fc
+
+
+ class AvgPoolShortCut(nn.Module):
+     def __init__(self, stride, out_c, in_c):
+         super(AvgPoolShortCut, self).__init__()
+         self.stride = stride
+         self.out_c = out_c
+         self.in_c = in_c
+
+     def forward(self, x):
+         if x.shape[2] % 2 != 0:
+             x = F.avg_pool2d(x, 1, self.stride)
+         else:
+             x = F.avg_pool2d(x, self.stride, self.stride)
+         pad = torch.zeros(x.shape[0], self.out_c - self.in_c, x.shape[2], x.shape[3], device=x.device)
+         x = torch.cat((x, pad), dim=1)
+         return x
+
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, input_size, wrapped_conv, in_planes, planes, stride=1, mod=True):
+         super(BasicBlock, self).__init__()
+         self.conv1 = wrapped_conv(input_size, in_planes, planes, kernel_size=3, stride=stride)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.conv2 = wrapped_conv(math.ceil(input_size / stride), planes, planes, kernel_size=3, stride=1)
+         self.bn2 = nn.BatchNorm2d(planes)
+         self.mod = mod
+         self.activation = F.leaky_relu if self.mod else F.relu
+
+         self.shortcut = nn.Sequential()
+         if stride != 1 or in_planes != self.expansion * planes:
+             if mod:
+                 self.shortcut = nn.Sequential(AvgPoolShortCut(stride, self.expansion * planes, in_planes))
+             else:
+                 self.shortcut = nn.Sequential(
+                     wrapped_conv(input_size, in_planes, self.expansion * planes, kernel_size=1, stride=stride),
+                     nn.BatchNorm2d(planes),
+                 )
+
+     def forward(self, x):
+         out = self.activation(self.bn1(self.conv1(x)))
+         out = self.bn2(self.conv2(out))
+         out += self.shortcut(x)
+         out = self.activation(out)
+         return out
+
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self, input_size, wrapped_conv, in_planes, planes, stride=1, mod=True):
+         super(Bottleneck, self).__init__()
+         self.conv1 = wrapped_conv(input_size, in_planes, planes, kernel_size=1, stride=1)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.conv2 = wrapped_conv(input_size, planes, planes, kernel_size=3, stride=stride)
+         self.bn2 = nn.BatchNorm2d(planes)
+         self.conv3 = wrapped_conv(math.ceil(input_size / stride), planes, self.expansion * planes, kernel_size=1, stride=1)
+         self.bn3 = nn.BatchNorm2d(self.expansion * planes)
+         self.mod = mod
+         self.activation = F.leaky_relu if self.mod else F.relu
+
+         self.shortcut = nn.Sequential()
+         if stride != 1 or in_planes != self.expansion * planes:
+             if mod:
+                 self.shortcut = nn.Sequential(AvgPoolShortCut(stride, self.expansion * planes, in_planes))
+             else:
+                 self.shortcut = nn.Sequential(
+                     wrapped_conv(input_size, in_planes, self.expansion * planes, kernel_size=1, stride=stride),
+                     nn.BatchNorm2d(self.expansion * planes),
+                 )
+
+     def forward(self, x):
+         out = self.activation(self.bn1(self.conv1(x)))
+         out = self.activation(self.bn2(self.conv2(out)))
+         out = self.bn3(self.conv3(out))
+         out += self.shortcut(x)
+         out = self.activation(out)
+         return out
+
+
+ class ResNet(nn.Module):
+     def __init__(
+         self,
+         block,
+         num_blocks,
+         num_classes=10,
+         temp=1.0,
+         spectral_normalization=True,
+         mod=True,
+         coeff=3,
+         n_power_iterations=1,
+         mnist=False,
+     ):
+         """
+         If the "mod" parameter is set to True, the architecture uses 2 modifications:
+         1. LeakyReLU instead of normal ReLU
+         2. Average Pooling on the residual connections.
+         """
+         super(ResNet, self).__init__()
+         self.in_planes = 64
+
+         self.mod = mod
+
+         def wrapped_conv(input_size, in_c, out_c, kernel_size, stride):
+             padding = 1 if kernel_size == 3 else 0
+
+             conv = nn.Conv2d(in_c, out_c, kernel_size, stride, padding, bias=False)
+
+             if not spectral_normalization:
+                 return conv
+
+             # NOTE: Google uses the spectral_norm_fc in all cases
+             if kernel_size == 1:
+                 # use spectral norm fc, because bounds are tight for 1x1 convolutions
+                 wrapped_conv = spectral_norm_fc(conv, coeff, n_power_iterations)
+             else:
+                 # Otherwise use spectral norm conv, with loose bound
+                 shapes = (in_c, input_size, input_size)
+                 wrapped_conv = spectral_norm_conv(conv, coeff, shapes, n_power_iterations)
+
+             return wrapped_conv
+
+         self.wrapped_conv = wrapped_conv
+
+         self.bn1 = nn.BatchNorm2d(64)
+
+         if mnist:
+             self.conv1 = wrapped_conv(28, 1, 64, kernel_size=3, stride=1)
+             self.layer1 = self._make_layer(block, 28, 64, num_blocks[0], stride=1)
+             self.layer2 = self._make_layer(block, 28, 128, num_blocks[1], stride=2)
+             self.layer3 = self._make_layer(block, 14, 256, num_blocks[2], stride=2)
+             self.layer4 = self._make_layer(block, 7, 512, num_blocks[3], stride=2)
+         else:
+             self.conv1 = wrapped_conv(32, 3, 64, kernel_size=3, stride=1)
+             self.layer1 = self._make_layer(block, 32, 64, num_blocks[0], stride=1)
+             self.layer2 = self._make_layer(block, 32, 128, num_blocks[1], stride=2)
+             self.layer3 = self._make_layer(block, 16, 256, num_blocks[2], stride=2)
+             self.layer4 = self._make_layer(block, 8, 512, num_blocks[3], stride=2)
+         self.fc = nn.Linear(512 * block.expansion, num_classes)
+         self.activation = F.leaky_relu if self.mod else F.relu
+         self.feature = None
+         self.temp = temp
+
+     def _make_layer(self, block, input_size, planes, num_blocks, stride):
+         strides = [stride] + [1] * (num_blocks - 1)
+         layers = []
+         for stride in strides:
+             layers.append(block(input_size, self.wrapped_conv, self.in_planes, planes, stride, self.mod))
+             self.in_planes = planes * block.expansion
+             input_size = math.ceil(input_size / stride)
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         out = self.activation(self.bn1(self.conv1(x)))
+         out = self.layer1(out)
+         out = self.layer2(out)
+         out = self.layer3(out)
+         out = self.layer4(out)
+         out = F.avg_pool2d(out, 4)
+         out = out.view(out.size(0), -1)
+         self.feature = out.clone().detach()
+         if self.temp == 1:
+             out = self.fc(out)
+         else:
+             out = self.fc(out) / self.temp
+         return out
+
+
+ def resnet18(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
+     model = ResNet(
+         BasicBlock,
+         [2, 2, 2, 2],
+         spectral_normalization=spectral_normalization,
+         mod=mod,
+         temp=temp,
+         mnist=mnist,
+         **kwargs
+     )
+     return model
+
+
+ def resnet50(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
+     model = ResNet(
+         Bottleneck,
+         [3, 4, 6, 3],
+         spectral_normalization=spectral_normalization,
+         mod=mod,
+         temp=temp,
+         mnist=mnist,
+         **kwargs
+     )
+     return model
+
+
+ def resnet101(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
+     model = ResNet(
+         Bottleneck,
+         [3, 4, 23, 3],
+         spectral_normalization=spectral_normalization,
+         mod=mod,
+         temp=temp,
+         mnist=mnist,
+         **kwargs
+     )
+     return model
+
+
+ def resnet110(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
223
+ model = ResNet(
224
+ Bottleneck,
225
+ [3, 4, 26, 3],
226
+ spectral_normalization=spectral_normalization,
227
+ mod=mod,
228
+ temp=temp,
229
+ mnist=mnist,
230
+ **kwargs
231
+ )
232
+ return model
233
+
234
+
235
+ def resnet152(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
236
+ model = ResNet(
237
+ Bottleneck,
238
+ [3, 8, 36, 3],
239
+ spectral_normalization=spectral_normalization,
240
+ mod=mod,
241
+ temp=temp,
242
+ mnist=mnist,
243
+ **kwargs
244
+ )
245
+ return model
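
As a quick orientation for readers of this diff, here is a minimal usage sketch (not part of the commit) of the factory functions above; the `net.resnet` import path and the CIFAR-sized input are assumptions based on the surrounding repository layout.

# Hypothetical usage sketch, assuming this file is importable as net.resnet.
import torch
from net.resnet import resnet18

model = resnet18(spectral_normalization=True, mod=True, num_classes=10)
model.eval()

x = torch.randn(2, 3, 32, 32)        # CIFAR-sized input; pass mnist=True for 28x28 grayscale
with torch.no_grad():
    logits = model(x)                # shape: (2, num_classes)
features = model.feature             # penultimate features cached as a side effect of forward()
print(logits.shape, features.shape)  # torch.Size([2, 10]) torch.Size([2, 512])

Note that `model.feature` is populated on every forward pass, which downstream evaluation code can read after inference.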
SPC-UQ/Image_Classification/net/resnet_edl.py ADDED
@@ -0,0 +1,252 @@
+ """
+ PyTorch implementation of ResNet models.
+ Reference:
+ [1] He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: CVPR, 2016.
+ """
+ import torch
+ import math
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from net.spectral_normalization.spectral_norm_conv_inplace import spectral_norm_conv
+ from net.spectral_normalization.spectral_norm_fc import spectral_norm_fc
+
+
+ class AvgPoolShortCut(nn.Module):
+     def __init__(self, stride, out_c, in_c):
+         super(AvgPoolShortCut, self).__init__()
+         self.stride = stride
+         self.out_c = out_c
+         self.in_c = in_c
+
+     def forward(self, x):
+         if x.shape[2] % 2 != 0:
+             x = F.avg_pool2d(x, 1, self.stride)
+         else:
+             x = F.avg_pool2d(x, self.stride, self.stride)
+         pad = torch.zeros(x.shape[0], self.out_c - self.in_c, x.shape[2], x.shape[3], device=x.device)
+         x = torch.cat((x, pad), dim=1)
+         return x
+
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, input_size, wrapped_conv, in_planes, planes, stride=1, mod=True):
+         super(BasicBlock, self).__init__()
+         self.conv1 = wrapped_conv(input_size, in_planes, planes, kernel_size=3, stride=stride)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.conv2 = wrapped_conv(math.ceil(input_size / stride), planes, planes, kernel_size=3, stride=1)
+         self.bn2 = nn.BatchNorm2d(planes)
+         self.mod = mod
+         self.activation = F.leaky_relu if self.mod else F.relu
+
+         self.shortcut = nn.Sequential()
+         if stride != 1 or in_planes != self.expansion * planes:
+             if mod:
+                 self.shortcut = nn.Sequential(AvgPoolShortCut(stride, self.expansion * planes, in_planes))
+             else:
+                 self.shortcut = nn.Sequential(
+                     wrapped_conv(input_size, in_planes, self.expansion * planes, kernel_size=1, stride=stride),
+                     nn.BatchNorm2d(planes),
+                 )
+
+     def forward(self, x):
+         out = self.activation(self.bn1(self.conv1(x)))
+         out = self.bn2(self.conv2(out))
+         out += self.shortcut(x)
+         out = self.activation(out)
+         return out
+
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self, input_size, wrapped_conv, in_planes, planes, stride=1, mod=True):
+         super(Bottleneck, self).__init__()
+         self.conv1 = wrapped_conv(input_size, in_planes, planes, kernel_size=1, stride=1)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.conv2 = wrapped_conv(input_size, planes, planes, kernel_size=3, stride=stride)
+         self.bn2 = nn.BatchNorm2d(planes)
+         self.conv3 = wrapped_conv(math.ceil(input_size / stride), planes, self.expansion * planes, kernel_size=1, stride=1)
+         self.bn3 = nn.BatchNorm2d(self.expansion * planes)
+         self.mod = mod
+         self.activation = F.leaky_relu if self.mod else F.relu
+
+         self.shortcut = nn.Sequential()
+         if stride != 1 or in_planes != self.expansion * planes:
+             if mod:
+                 self.shortcut = nn.Sequential(AvgPoolShortCut(stride, self.expansion * planes, in_planes))
+             else:
+                 self.shortcut = nn.Sequential(
+                     wrapped_conv(input_size, in_planes, self.expansion * planes, kernel_size=1, stride=stride),
+                     nn.BatchNorm2d(self.expansion * planes),
+                 )
+
+     def forward(self, x):
+         out = self.activation(self.bn1(self.conv1(x)))
+         out = self.activation(self.bn2(self.conv2(out)))
+         out = self.bn3(self.conv3(out))
+         out += self.shortcut(x)
+         out = self.activation(out)
+         return out
+
+
+ class ResNet(nn.Module):
+     def __init__(
+         self,
+         block,
+         num_blocks,
+         num_classes=10,
+         temp=1.0,
+         spectral_normalization=True,
+         mod=True,
+         coeff=3,
+         n_power_iterations=1,
+         mnist=False,
+     ):
+         """
+         If the "mod" parameter is set to True, the architecture uses two modifications:
+         1. LeakyReLU instead of normal ReLU
+         2. Average pooling on the residual connections
+         """
+         super(ResNet, self).__init__()
+         self.in_planes = 64
+
+         self.mod = mod
+
+         def wrapped_conv(input_size, in_c, out_c, kernel_size, stride):
+             padding = 1 if kernel_size == 3 else 0
+
+             conv = nn.Conv2d(in_c, out_c, kernel_size, stride, padding, bias=False)
+
+             if not spectral_normalization:
+                 return conv
+
+             # NOTE: Google uses spectral_norm_fc in all cases
+             if kernel_size == 1:
+                 # Use spectral_norm_fc, because the bound is tight for 1x1 convolutions
+                 wrapped_conv = spectral_norm_fc(conv, coeff, n_power_iterations)
+             else:
+                 # Otherwise use spectral_norm_conv, with a looser bound
+                 shapes = (in_c, input_size, input_size)
+                 wrapped_conv = spectral_norm_conv(conv, coeff, shapes, n_power_iterations)
+
+             return wrapped_conv
+
+         self.wrapped_conv = wrapped_conv
+
+         self.bn1 = nn.BatchNorm2d(64)
+
+         if mnist:
+             self.conv1 = wrapped_conv(28, 1, 64, kernel_size=3, stride=1)
+             self.layer1 = self._make_layer(block, 28, 64, num_blocks[0], stride=1)
+             self.layer2 = self._make_layer(block, 28, 128, num_blocks[1], stride=2)
+             self.layer3 = self._make_layer(block, 14, 256, num_blocks[2], stride=2)
+             self.layer4 = self._make_layer(block, 7, 512, num_blocks[3], stride=2)
+         else:
+             self.conv1 = wrapped_conv(32, 3, 64, kernel_size=3, stride=1)
+             self.layer1 = self._make_layer(block, 32, 64, num_blocks[0], stride=1)
+             self.layer2 = self._make_layer(block, 32, 128, num_blocks[1], stride=2)
+             self.layer3 = self._make_layer(block, 16, 256, num_blocks[2], stride=2)
+             self.layer4 = self._make_layer(block, 8, 512, num_blocks[3], stride=2)
+         self.activation = F.leaky_relu if self.mod else F.relu
+         self.feature = None
+         self.temp = temp
+         self.output_alpha = nn.Linear(512 * block.expansion, num_classes)
+         self.output_beta = nn.Linear(512 * block.expansion, num_classes)
+         self.output_nu = nn.Linear(512 * block.expansion, num_classes)
+         self.output_gamma = nn.Linear(512 * block.expansion, num_classes)
+
+     def _make_layer(self, block, input_size, planes, num_blocks, stride):
+         strides = [stride] + [1] * (num_blocks - 1)
+         layers = []
+         for stride in strides:
+             layers.append(block(input_size, self.wrapped_conv, self.in_planes, planes, stride, self.mod))
+             self.in_planes = planes * block.expansion
+             input_size = math.ceil(input_size / stride)
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         out = self.activation(self.bn1(self.conv1(x)))
+         out = self.layer1(out)
+         out = self.layer2(out)
+         out = self.layer3(out)
+         out = self.layer4(out)
+         out = F.avg_pool2d(out, 4)
+         out = out.view(out.size(0), -1)
+         self.feature = out.clone().detach()
+         alpha = self.output_alpha(out)
+         beta = self.output_beta(out)
+         nu = self.output_nu(out)
+         gamma = self.output_gamma(out)
+         alpha = F.softplus(alpha) + 1
+         beta = F.softplus(beta) + 1e-6
+         nu = F.softplus(nu) + 1e-6
+         gamma = F.softplus(gamma) + 1e-6
+         # Only the Dirichlet evidence alpha is returned; beta, nu and gamma are
+         # computed but unused. Return [alpha, beta, nu, gamma] to expose all heads.
+         return alpha
+
+
+ def resnet18_edl(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
+     model = ResNet(
+         BasicBlock,
+         [2, 2, 2, 2],
+         spectral_normalization=spectral_normalization,
+         mod=mod,
+         temp=temp,
+         mnist=mnist,
+         **kwargs
+     )
+     return model
+
+
+ def resnet50_edl(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
+     model = ResNet(
+         Bottleneck,
+         [3, 4, 6, 3],
+         spectral_normalization=spectral_normalization,
+         mod=mod,
+         temp=temp,
+         mnist=mnist,
+         **kwargs
+     )
+     return model
+
+
+ def resnet101_edl(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
+     model = ResNet(
+         Bottleneck,
+         [3, 4, 23, 3],
+         spectral_normalization=spectral_normalization,
+         mod=mod,
+         temp=temp,
+         mnist=mnist,
+         **kwargs
+     )
+     return model
+
+
+ def resnet110_edl(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
+     model = ResNet(
+         Bottleneck,
+         [3, 4, 26, 3],
+         spectral_normalization=spectral_normalization,
+         mod=mod,
+         temp=temp,
+         mnist=mnist,
+         **kwargs
+     )
+     return model
+
+
+ def resnet152_edl(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
+     model = ResNet(
+         Bottleneck,
+         [3, 8, 36, 3],
+         spectral_normalization=spectral_normalization,
+         mod=mod,
+         temp=temp,
+         mnist=mnist,
+         **kwargs
+     )
+     return model
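
Since this variant's `forward` returns only the Dirichlet evidence `alpha`, the sketch below shows one standard way to consume it, in the style of evidential deep learning (Sensoy et al., 2018): normalize `alpha` for expected class probabilities and use vacuity K/S as an uncertainty score. Whether this matches the repository's own evaluation code is an assumption.

# Hedged sketch: reading Dirichlet evidence the EDL way (an assumption, not
# necessarily how this repository's evaluate.py consumes the output).
import torch
from net.resnet_edl import resnet18_edl

model = resnet18_edl(num_classes=10)
model.eval()
with torch.no_grad():
    alpha = model(torch.randn(2, 3, 32, 32))  # Dirichlet concentration, elementwise > 1

S = alpha.sum(dim=1, keepdim=True)            # Dirichlet strength per sample
probs = alpha / S                             # expected class probabilities
vacuity = alpha.shape[1] / S.squeeze(1)       # K / S: high when total evidence is low
print(probs.shape, vacuity)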
SPC-UQ/Image_Classification/net/resnet_uq.py ADDED
@@ -0,0 +1,272 @@
+ """
+ PyTorch implementation of ResNet models.
+ Reference:
+ [1] He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: CVPR, 2016.
+ """
+ import torch
+ import math
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from net.spectral_normalization.spectral_norm_conv_inplace import spectral_norm_conv
+ from net.spectral_normalization.spectral_norm_fc import spectral_norm_fc
+
+
+ class AvgPoolShortCut(nn.Module):
+     def __init__(self, stride, out_c, in_c):
+         super(AvgPoolShortCut, self).__init__()
+         self.stride = stride
+         self.out_c = out_c
+         self.in_c = in_c
+
+     def forward(self, x):
+         if x.shape[2] % 2 != 0:
+             x = F.avg_pool2d(x, 1, self.stride)
+         else:
+             x = F.avg_pool2d(x, self.stride, self.stride)
+         pad = torch.zeros(x.shape[0], self.out_c - self.in_c, x.shape[2], x.shape[3], device=x.device)
+         x = torch.cat((x, pad), dim=1)
+         return x
+
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, input_size, wrapped_conv, in_planes, planes, stride=1, mod=True):
+         super(BasicBlock, self).__init__()
+         self.conv1 = wrapped_conv(input_size, in_planes, planes, kernel_size=3, stride=stride)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.conv2 = wrapped_conv(math.ceil(input_size / stride), planes, planes, kernel_size=3, stride=1)
+         self.bn2 = nn.BatchNorm2d(planes)
+         self.mod = mod
+         self.activation = F.leaky_relu if self.mod else F.relu
+
+         self.shortcut = nn.Sequential()
+         if stride != 1 or in_planes != self.expansion * planes:
+             if mod:
+                 self.shortcut = nn.Sequential(AvgPoolShortCut(stride, self.expansion * planes, in_planes))
+             else:
+                 self.shortcut = nn.Sequential(
+                     wrapped_conv(input_size, in_planes, self.expansion * planes, kernel_size=1, stride=stride),
+                     nn.BatchNorm2d(planes),
+                 )
+
+     def forward(self, x):
+         out = self.activation(self.bn1(self.conv1(x)))
+         out = self.bn2(self.conv2(out))
+         out += self.shortcut(x)
+         out = self.activation(out)
+         return out
+
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self, input_size, wrapped_conv, in_planes, planes, stride=1, mod=True):
+         super(Bottleneck, self).__init__()
+         self.conv1 = wrapped_conv(input_size, in_planes, planes, kernel_size=1, stride=1)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.conv2 = wrapped_conv(input_size, planes, planes, kernel_size=3, stride=stride)
+         self.bn2 = nn.BatchNorm2d(planes)
+         self.conv3 = wrapped_conv(math.ceil(input_size / stride), planes, self.expansion * planes, kernel_size=1, stride=1)
+         self.bn3 = nn.BatchNorm2d(self.expansion * planes)
+         self.mod = mod
+         self.activation = F.leaky_relu if self.mod else F.relu
+
+         self.shortcut = nn.Sequential()
+         if stride != 1 or in_planes != self.expansion * planes:
+             if mod:
+                 self.shortcut = nn.Sequential(AvgPoolShortCut(stride, self.expansion * planes, in_planes))
+             else:
+                 self.shortcut = nn.Sequential(
+                     wrapped_conv(input_size, in_planes, self.expansion * planes, kernel_size=1, stride=stride),
+                     nn.BatchNorm2d(self.expansion * planes),
+                 )
+
+     def forward(self, x):
+         out = self.activation(self.bn1(self.conv1(x)))
+         out = self.activation(self.bn2(self.conv2(out)))
+         out = self.bn3(self.conv3(out))
+         out += self.shortcut(x)
+         out = self.activation(out)
+         return out
+
+
+ class ResNet(nn.Module):
+     def __init__(
+         self,
+         block,
+         num_blocks,
+         num_classes=10,
+         temp=1.0,
+         spectral_normalization=True,
+         mod=True,
+         coeff=3,
+         n_power_iterations=1,
+         mnist=False,
+     ):
+         """
+         If the "mod" parameter is set to True, the architecture uses two modifications:
+         1. LeakyReLU instead of normal ReLU
+         2. Average pooling on the residual connections
+         """
+         super(ResNet, self).__init__()
+         self.in_planes = 64
+
+         self.mod = mod
+
+         def wrapped_conv(input_size, in_c, out_c, kernel_size, stride):
+             padding = 1 if kernel_size == 3 else 0
+
+             conv = nn.Conv2d(in_c, out_c, kernel_size, stride, padding, bias=False)
+
+             if not spectral_normalization:
+                 return conv
+
+             # NOTE: Google uses spectral_norm_fc in all cases
+             if kernel_size == 1:
+                 # Use spectral_norm_fc, because the bound is tight for 1x1 convolutions
+                 wrapped_conv = spectral_norm_fc(conv, coeff, n_power_iterations)
+             else:
+                 # Otherwise use spectral_norm_conv, with a looser bound
+                 shapes = (in_c, input_size, input_size)
+                 wrapped_conv = spectral_norm_conv(conv, coeff, shapes, n_power_iterations)
+
+             return wrapped_conv
+
+         self.wrapped_conv = wrapped_conv
+
+         self.bn1 = nn.BatchNorm2d(64)
+
+         if mnist:
+             self.conv1 = wrapped_conv(28, 1, 64, kernel_size=3, stride=1)
+             self.layer1 = self._make_layer(block, 28, 64, num_blocks[0], stride=1)
+             self.layer2 = self._make_layer(block, 28, 128, num_blocks[1], stride=2)
+             self.layer3 = self._make_layer(block, 14, 256, num_blocks[2], stride=2)
+             self.layer4 = self._make_layer(block, 7, 512, num_blocks[3], stride=2)
+         else:
+             self.conv1 = wrapped_conv(32, 3, 64, kernel_size=3, stride=1)
+             self.layer1 = self._make_layer(block, 32, 64, num_blocks[0], stride=1)
+             self.layer2 = self._make_layer(block, 32, 128, num_blocks[1], stride=2)
+             self.layer3 = self._make_layer(block, 16, 256, num_blocks[2], stride=2)
+             self.layer4 = self._make_layer(block, 8, 512, num_blocks[3], stride=2)
+         self.fc = nn.Linear(512 * block.expansion, num_classes)
+         self.activation = F.leaky_relu if self.mod else F.relu
+         self.feature = None
+         self.temp = temp
+
+         def make_branch():
+             layers = []
+             in_features = 512 * block.expansion
+             neurons = 512 * block.expansion
+             for _ in range(1):
+                 layers.append(nn.Linear(in_features, neurons))
+                 layers.append(nn.ReLU())  # F.relu is a plain function; nn.Sequential requires nn.Module instances
+                 # layers.append(nn.Dropout(dropout_p))
+                 in_features = neurons
+                 # neurons //= 2
+             return nn.Sequential(*layers), nn.Linear(in_features, num_classes)
+
+         self.hidden_mar, self.mar = make_branch()
+         self.hidden_mar_up, self.mar_up = make_branch()
+         self.hidden_mar_down, self.mar_down = make_branch()
+
+     def _make_layer(self, block, input_size, planes, num_blocks, stride):
+         strides = [stride] + [1] * (num_blocks - 1)
+         layers = []
+         for stride in strides:
+             layers.append(block(input_size, self.wrapped_conv, self.in_planes, planes, stride, self.mod))
+             self.in_planes = planes * block.expansion
+             input_size = math.ceil(input_size / stride)
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         out = self.activation(self.bn1(self.conv1(x)))
+         out = self.layer1(out)
+         out = self.layer2(out)
+         out = self.layer3(out)
+         out = self.layer4(out)
+         out = F.avg_pool2d(out, 4)
+         out = out.view(out.size(0), -1)
+         self.feature = out.clone().detach()
+         if self.temp == 1:
+             pred = self.fc(out)
+         else:
+             pred = self.fc(out) / self.temp
+
+         mar = self.mar(self.hidden_mar(out))
+         mar_up = self.mar_up(self.hidden_mar_up(out))
+         mar_down = self.mar_down(self.hidden_mar_down(out))
+         return pred, mar, mar_up, mar_down
+
+
+ def resnet18(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
+     model = ResNet(
+         BasicBlock,
+         [2, 2, 2, 2],
+         spectral_normalization=spectral_normalization,
+         mod=mod,
+         temp=temp,
+         mnist=mnist,
+         **kwargs
+     )
+     return model
+
+
+ def resnet50(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
+     model = ResNet(
+         Bottleneck,
+         [3, 4, 6, 3],
+         spectral_normalization=spectral_normalization,
+         mod=mod,
+         temp=temp,
+         mnist=mnist,
+         **kwargs
+     )
+     return model
+
+
+ def resnet101(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
+     model = ResNet(
+         Bottleneck,
+         [3, 4, 23, 3],
+         spectral_normalization=spectral_normalization,
+         mod=mod,
+         temp=temp,
+         mnist=mnist,
+         **kwargs
+     )
+     return model
+
+
+ def resnet110(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
+     model = ResNet(
+         Bottleneck,
+         [3, 4, 26, 3],
+         spectral_normalization=spectral_normalization,
+         mod=mod,
+         temp=temp,
+         mnist=mnist,
+         **kwargs
+     )
+     return model
+
+
+ def resnet152(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
+     model = ResNet(
+         Bottleneck,
+         [3, 8, 36, 3],
+         spectral_normalization=spectral_normalization,
+         mod=mod,
+         temp=temp,
+         mnist=mnist,
+         **kwargs
+     )
+     return model
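
The UQ variant returns four tensors per batch: the classifier logits plus three mean-absolute-residual (MAR) heads. The sketch below shows how they unpack; the self-consistency score at the end is a hypothetical illustration of how the heads might be combined, not necessarily the SPC-UQ metric implemented elsewhere in this repository.

# Hedged sketch of consuming the four outputs of the UQ variant.
import torch
from net.resnet_uq import resnet18

model = resnet18(num_classes=10)
model.eval()
with torch.no_grad():
    pred, mar, mar_up, mar_down = model(torch.randn(2, 3, 32, 32))

# Hypothetical score (an assumption): if the heads were perfectly self-consistent
# one might expect mar ≈ (mar_up + mar_down) / 2, so the residual of that identity
# can serve as a per-sample uncertainty proxy.
inconsistency = (mar_up + mar_down - 2 * mar).abs().mean(dim=1)
print(pred.argmax(dim=1), inconsistency)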
SPC-UQ/Image_Classification/net/spectral_normalization/__init__.py ADDED
File without changes