Z
committed on
Upload 3523 files
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +7 -0
- SPC-UQ/.idea/.gitignore +8 -0
- SPC-UQ/.idea/UQ_baseline.iml +12 -0
- SPC-UQ/.idea/inspectionProfiles/Project_Default.xml +24 -0
- SPC-UQ/.idea/inspectionProfiles/profiles_settings.xml +6 -0
- SPC-UQ/.idea/misc.xml +7 -0
- SPC-UQ/.idea/modules.xml +8 -0
- SPC-UQ/.idea/other.xml +6 -0
- SPC-UQ/.idea/workspace.xml +247 -0
- SPC-UQ/Cubic_Regression/ConformalRegression.py +76 -0
- SPC-UQ/Cubic_Regression/DeepEnsembleRegression.py +93 -0
- SPC-UQ/Cubic_Regression/EDLQuantileRegression.py +155 -0
- SPC-UQ/Cubic_Regression/EDLRegression.py +141 -0
- SPC-UQ/Cubic_Regression/QROC.py +125 -0
- SPC-UQ/Cubic_Regression/SPCRegression.py +173 -0
- SPC-UQ/Cubic_Regression/__pycache__/ConformalRegression.cpython-37.pyc +0 -0
- SPC-UQ/Cubic_Regression/__pycache__/DeepEnsembleRegression.cpython-37.pyc +0 -0
- SPC-UQ/Cubic_Regression/__pycache__/EDLQuantileRegression.cpython-37.pyc +0 -0
- SPC-UQ/Cubic_Regression/__pycache__/EDLRegression.cpython-37.pyc +0 -0
- SPC-UQ/Cubic_Regression/__pycache__/QROC.cpython-37.pyc +0 -0
- SPC-UQ/Cubic_Regression/__pycache__/SPCRegression.cpython-37.pyc +0 -0
- SPC-UQ/Cubic_Regression/run_cubic_tests.py +335 -0
- SPC-UQ/Image_Classification/README.md +285 -0
- SPC-UQ/Image_Classification/data/__init__.py +0 -0
- SPC-UQ/Image_Classification/data/ood_detection/__init__.py +0 -0
- SPC-UQ/Image_Classification/data/ood_detection/cifar10.py +107 -0
- SPC-UQ/Image_Classification/data/ood_detection/cifar100.py +107 -0
- SPC-UQ/Image_Classification/data/ood_detection/imagenet.py +85 -0
- SPC-UQ/Image_Classification/data/ood_detection/imagenet_a.py +37 -0
- SPC-UQ/Image_Classification/data/ood_detection/imagenet_o.py +37 -0
- SPC-UQ/Image_Classification/data/ood_detection/ood_union.py +105 -0
- SPC-UQ/Image_Classification/data/ood_detection/svhn.py +94 -0
- SPC-UQ/Image_Classification/data/ood_detection/tinyimagenet.py +115 -0
- SPC-UQ/Image_Classification/environment.yml +16 -0
- SPC-UQ/Image_Classification/evaluate.py +1427 -0
- SPC-UQ/Image_Classification/evaluate_laplace.py +355 -0
- SPC-UQ/Image_Classification/metrics/__init__.py +0 -0
- SPC-UQ/Image_Classification/metrics/calibration_metrics.py +129 -0
- SPC-UQ/Image_Classification/metrics/classification_metrics.py +211 -0
- SPC-UQ/Image_Classification/metrics/ood_metrics.py +135 -0
- SPC-UQ/Image_Classification/metrics/uncertainty_confidence.py +67 -0
- SPC-UQ/Image_Classification/net/__init__.py +0 -0
- SPC-UQ/Image_Classification/net/imagenet_vgg.py +106 -0
- SPC-UQ/Image_Classification/net/imagenet_vit.py +101 -0
- SPC-UQ/Image_Classification/net/imagenet_wide.py +46 -0
- SPC-UQ/Image_Classification/net/lenet.py +37 -0
- SPC-UQ/Image_Classification/net/resnet.py +245 -0
- SPC-UQ/Image_Classification/net/resnet_edl.py +252 -0
- SPC-UQ/Image_Classification/net/resnet_uq.py +272 -0
- SPC-UQ/Image_Classification/net/spectral_normalization/__init__.py +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+SPC-UQ/MNIST_Classification/data/FashionMNIST/raw/t10k-images-idx3-ubyte filter=lfs diff=lfs merge=lfs -text
+SPC-UQ/MNIST_Classification/data/FashionMNIST/raw/train-images-idx3-ubyte filter=lfs diff=lfs merge=lfs -text
+SPC-UQ/MNIST_Classification/data/MNIST/raw/t10k-images-idx3-ubyte filter=lfs diff=lfs merge=lfs -text
+SPC-UQ/MNIST_Classification/data/MNIST/raw/train-images-idx3-ubyte filter=lfs diff=lfs merge=lfs -text
+SPC-UQ/UCI_Benchmarks/data/uci/concrete/Concrete_Data.xls filter=lfs diff=lfs merge=lfs -text
+SPC-UQ/UCI_Benchmarks/data/uci/power-plant/Folds5x2_pp.ods filter=lfs diff=lfs merge=lfs -text
+SPC-UQ/UCI_Benchmarks/data/uci/power-plant/Folds5x2_pp.xlsx filter=lfs diff=lfs merge=lfs -text
SPC-UQ/.idea/.gitignore
ADDED
@@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
SPC-UQ/.idea/UQ_baseline.iml
ADDED
@@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="jdk" jdkName="py37" jdkType="Python SDK" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
  <component name="PyDocumentationSettings">
    <option name="format" value="PLAIN" />
    <option name="myDocStringFormat" value="Plain" />
  </component>
</module>
SPC-UQ/.idea/inspectionProfiles/Project_Default.xml
ADDED
@@ -0,0 +1,24 @@
<component name="InspectionProjectProfileManager">
  <profile version="1.0">
    <option name="myName" value="Project Default" />
    <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
      <option name="ignoredPackages">
        <value>
          <list size="11">
            <item index="0" class="java.lang.String" itemvalue="tqdm" />
            <item index="1" class="java.lang.String" itemvalue="scipy" />
            <item index="2" class="java.lang.String" itemvalue="tabulate" />
            <item index="3" class="java.lang.String" itemvalue="scikit_learn" />
            <item index="4" class="java.lang.String" itemvalue="matplotlib" />
            <item index="5" class="java.lang.String" itemvalue="gpytorch" />
            <item index="6" class="java.lang.String" itemvalue="torch" />
            <item index="7" class="java.lang.String" itemvalue="setuptools" />
            <item index="8" class="java.lang.String" itemvalue="numpy" />
            <item index="9" class="java.lang.String" itemvalue="torchvision" />
            <item index="10" class="java.lang.String" itemvalue="Pillow" />
          </list>
        </value>
      </option>
    </inspection_tool>
  </profile>
</component>
SPC-UQ/.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
  <settings>
    <option name="USE_PROJECT_PROFILE" value="false" />
    <version value="1.0" />
  </settings>
</component>
SPC-UQ/.idea/misc.xml
ADDED
@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="Black">
    <option name="sdkName" value="py37" />
  </component>
  <component name="ProjectRootManager" version="2" project-jdk-name="py37" project-jdk-type="Python SDK" />
</project>
SPC-UQ/.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/UQ_baseline.iml" filepath="$PROJECT_DIR$/.idea/UQ_baseline.iml" />
    </modules>
  </component>
</project>
SPC-UQ/.idea/other.xml
ADDED
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="PySciProjectComponent">
    <option name="PY_INTERACTIVE_PLOTS_SUGGESTED" value="true" />
  </component>
</project>
SPC-UQ/.idea/workspace.xml
ADDED
@@ -0,0 +1,247 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="AutoImportSettings">
    <option name="autoReloadType" value="SELECTIVE" />
  </component>
  <component name="ChangeListManager">
    <list default="true" id="5a477d09-bea8-4806-81a6-0e58ccd074ce" name="Changes" comment="" />
    <option name="SHOW_DIALOG" value="false" />
    <option name="HIGHLIGHT_CONFLICTS" value="true" />
    <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
    <option name="LAST_RESOLUTION" value="IGNORE" />
  </component>
  <component name="FileTemplateManagerImpl">
    <option name="RECENT_TEMPLATES">
      <list>
        <option value="Python Script" />
      </list>
    </option>
  </component>
  <component name="ProjectColorInfo">{
  "associatedIndex": 2
}</component>
  <component name="ProjectId" id="30Vf7yOdfmRNkiAM7wBXYZUSVri" />
  <component name="ProjectViewState">
    <option name="hideEmptyMiddlePackages" value="true" />
    <option name="showLibraryContents" value="true" />
  </component>
  <component name="PropertiesComponent">{
  "keyToString": {
    "Python.evidential.executor": "Run",
    "Python.re.executor": "Run",
    "Python.rename.executor": "Run",
    "Python.run_cls_tests.executor": "Run",
    "Python.run_cubic_tests.executor": "Run",
    "Python.run_toy_tests.executor": "Run",
    "Python.run_uci_dataset_tests (1).executor": "Run",
    "Python.run_uci_dataset_tests (2).executor": "Run",
    "Python.run_uci_dataset_tests.executor": "Run",
    "RunOnceActivity.ShowReadmeOnStart": "true",
    "last_opened_file_path": "E:/Experiment/SPC-UQ/Depth_regression/trainers",
    "node.js.detected.package.eslint": "true",
    "node.js.detected.package.tslint": "true",
    "node.js.selected.package.eslint": "(autodetect)",
    "node.js.selected.package.tslint": "(autodetect)",
    "nodejs_package_manager_path": "npm",
    "vue.rearranger.settings.migration": "true"
  }
}</component>
  <component name="RecentsManager">
    <key name="CopyFile.RECENT_KEYS">
      <recent name="E:\Experiment\SPC-UQ\Depth_regression\trainers" />
    </key>
    <key name="MoveFile.RECENT_KEYS">
      <recent name="E:\Experiment\SPC-UQ\Depth_regression" />
    </key>
  </component>
  <component name="RunManager" selected="Python.run_cls_tests">
    <configuration name="run_cls_tests" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
      <module name="UQ_baseline" />
      <option name="ENV_FILES" value="" />
      <option name="INTERPRETER_OPTIONS" value="" />
      <option name="PARENT_ENVS" value="true" />
      <envs>
        <env name="PYTHONUNBUFFERED" value="1" />
      </envs>
      <option name="SDK_HOME" value="" />
      <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/MNIST_Classification" />
      <option name="IS_MODULE_SDK" value="true" />
      <option name="ADD_CONTENT_ROOTS" value="true" />
      <option name="ADD_SOURCE_ROOTS" value="true" />
      <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
      <option name="SCRIPT_NAME" value="$PROJECT_DIR$/MNIST_Classification/run_cls_tests.py" />
      <option name="PARAMETERS" value="" />
      <option name="SHOW_COMMAND_LINE" value="false" />
      <option name="EMULATE_TERMINAL" value="false" />
      <option name="MODULE_MODE" value="false" />
      <option name="REDIRECT_INPUT" value="false" />
      <option name="INPUT_FILE" value="" />
      <method v="2" />
    </configuration>
    <configuration name="run_cubic_tests" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
      <module name="UQ_baseline" />
      <option name="ENV_FILES" value="" />
      <option name="INTERPRETER_OPTIONS" value="" />
      <option name="PARENT_ENVS" value="true" />
      <envs>
        <env name="PYTHONUNBUFFERED" value="1" />
      </envs>
      <option name="SDK_HOME" value="" />
      <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/Cubic_Regression" />
      <option name="IS_MODULE_SDK" value="true" />
      <option name="ADD_CONTENT_ROOTS" value="true" />
      <option name="ADD_SOURCE_ROOTS" value="true" />
      <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
      <option name="SCRIPT_NAME" value="$PROJECT_DIR$/Cubic_Regression/run_cubic_tests.py" />
      <option name="PARAMETERS" value="" />
      <option name="SHOW_COMMAND_LINE" value="false" />
      <option name="EMULATE_TERMINAL" value="false" />
      <option name="MODULE_MODE" value="false" />
      <option name="REDIRECT_INPUT" value="false" />
      <option name="INPUT_FILE" value="" />
      <method v="2" />
    </configuration>
    <configuration name="run_toy_tests" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
      <module name="UQ_baseline" />
      <option name="ENV_FILES" value="" />
      <option name="INTERPRETER_OPTIONS" value="" />
      <option name="PARENT_ENVS" value="true" />
      <envs>
        <env name="PYTHONUNBUFFERED" value="1" />
      </envs>
      <option name="SDK_HOME" value="" />
      <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/Toy_regression" />
      <option name="IS_MODULE_SDK" value="true" />
      <option name="ADD_CONTENT_ROOTS" value="true" />
      <option name="ADD_SOURCE_ROOTS" value="true" />
      <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
      <option name="SCRIPT_NAME" value="$PROJECT_DIR$/Toy_regression/run_toy_tests.py" />
      <option name="PARAMETERS" value="" />
      <option name="SHOW_COMMAND_LINE" value="false" />
      <option name="EMULATE_TERMINAL" value="false" />
      <option name="MODULE_MODE" value="false" />
      <option name="REDIRECT_INPUT" value="false" />
      <option name="INPUT_FILE" value="" />
      <method v="2" />
    </configuration>
    <configuration name="run_uci_dataset_tests (1)" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
      <module name="UQ_baseline" />
      <option name="ENV_FILES" value="" />
      <option name="INTERPRETER_OPTIONS" value="" />
      <option name="PARENT_ENVS" value="true" />
      <envs>
        <env name="PYTHONUNBUFFERED" value="1" />
      </envs>
      <option name="SDK_HOME" value="" />
      <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/UCI_Benchmarks" />
      <option name="IS_MODULE_SDK" value="true" />
      <option name="ADD_CONTENT_ROOTS" value="true" />
      <option name="ADD_SOURCE_ROOTS" value="true" />
      <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
      <option name="SCRIPT_NAME" value="$PROJECT_DIR$/UCI_Benchmarks/run_uci_dataset_tests.py" />
      <option name="PARAMETERS" value="" />
      <option name="SHOW_COMMAND_LINE" value="false" />
      <option name="EMULATE_TERMINAL" value="false" />
      <option name="MODULE_MODE" value="false" />
      <option name="REDIRECT_INPUT" value="false" />
      <option name="INPUT_FILE" value="" />
      <method v="2" />
    </configuration>
    <configuration name="run_uci_dataset_tests (2)" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
      <module name="UQ_baseline" />
      <option name="ENV_FILES" value="" />
      <option name="INTERPRETER_OPTIONS" value="" />
      <option name="PARENT_ENVS" value="true" />
      <envs>
        <env name="PYTHONUNBUFFERED" value="1" />
      </envs>
      <option name="SDK_HOME" value="" />
      <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/../UCI_Benchmarks" />
      <option name="IS_MODULE_SDK" value="false" />
      <option name="ADD_CONTENT_ROOTS" value="true" />
      <option name="ADD_SOURCE_ROOTS" value="true" />
      <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
      <option name="SCRIPT_NAME" value="$PROJECT_DIR$/../UCI_Benchmarks/run_uci_dataset_tests.py" />
      <option name="PARAMETERS" value="" />
      <option name="SHOW_COMMAND_LINE" value="false" />
      <option name="EMULATE_TERMINAL" value="false" />
      <option name="MODULE_MODE" value="false" />
      <option name="REDIRECT_INPUT" value="false" />
      <option name="INPUT_FILE" value="" />
      <method v="2" />
    </configuration>
    <recent_temporary>
      <list>
        <item itemvalue="Python.run_cls_tests" />
        <item itemvalue="Python.run_cubic_tests" />
        <item itemvalue="Python.run_uci_dataset_tests (2)" />
        <item itemvalue="Python.run_uci_dataset_tests (1)" />
        <item itemvalue="Python.run_toy_tests" />
      </list>
    </recent_temporary>
  </component>
  <component name="SharedIndexes">
    <attachedChunks>
      <set>
        <option value="bundled-js-predefined-1d06a55b98c1-91d5c284f522-JavaScript-PY-241.15989.155" />
        <option value="bundled-python-sdk-babbdf50b680-7c6932dee5e4-com.jetbrains.pycharm.pro.sharedIndexes.bundled-PY-241.15989.155" />
      </set>
    </attachedChunks>
  </component>
  <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
  <component name="TaskManager">
    <task active="true" id="Default" summary="Default task">
      <changelist id="5a477d09-bea8-4806-81a6-0e58ccd074ce" name="Changes" comment="" />
      <created>1753717498209</created>
      <option name="number" value="Default" />
      <option name="presentableId" value="Default" />
      <updated>1753717498209</updated>
      <workItem from="1753717499301" duration="61000" />
      <workItem from="1753717566735" duration="10589000" />
      <workItem from="1753789230404" duration="17167000" />
      <workItem from="1753911784054" duration="686000" />
      <workItem from="1753975576642" duration="9813000" />
      <workItem from="1754180731050" duration="28000" />
      <workItem from="1754221817091" duration="34326000" />
      <workItem from="1754311631259" duration="1625000" />
      <workItem from="1754430709706" duration="5414000" />
      <workItem from="1754511717348" duration="1628000" />
      <workItem from="1754839988337" duration="2408000" />
      <workItem from="1754919749833" duration="13485000" />
      <workItem from="1755004077990" duration="3927000" />
      <workItem from="1756907577926" duration="3062000" />
      <workItem from="1756995689716" duration="6564000" />
      <workItem from="1757086350108" duration="1819000" />
      <workItem from="1757163460686" duration="1011000" />
    </task>
    <servers />
  </component>
  <component name="TypeScriptGeneratedFilesManager">
    <option name="version" value="3" />
  </component>
  <component name="XDebuggerManager">
    <breakpoint-manager>
      <default-breakpoints>
        <breakpoint type="python-exception">
          <properties notifyOnTerminate="true" exception="BaseException">
            <option name="notifyOnTerminate" value="true" />
          </properties>
        </breakpoint>
      </default-breakpoints>
    </breakpoint-manager>
  </component>
  <component name="com.intellij.coverage.CoverageDataManagerImpl">
    <SUITE FILE_PATH="coverage/UQ_baseline$evidential.coverage" NAME="evidential Coverage Results" MODIFIED="1753831104933" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/Toy_classification/trainers" />
    <SUITE FILE_PATH="coverage/SPC_UQ$run_toy_tests.coverage" NAME="run_toy_tests Coverage Results" MODIFIED="1754511868967" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/Toy_regression" />
    <SUITE FILE_PATH="coverage/SPC_UQ$rename.coverage" NAME="rename Coverage Results" MODIFIED="1754235285059" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/UCI_regression/trainers" />
    <SUITE FILE_PATH="coverage/UQ_baseline$run_toy_tests.coverage" NAME="run_toy_tests Coverage Results" MODIFIED="1753976501175" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/Toy_regression" />
    <SUITE FILE_PATH="coverage/UQ_baseline$run_cls_tests.coverage" NAME="run_cls_tests Coverage Results" MODIFIED="1753837116769" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/Toy_classification" />
    <SUITE FILE_PATH="coverage/SPC_UQ$run_uci_dataset_tests__1_.coverage" NAME="run_uci_dataset_tests (1) Coverage Results" MODIFIED="1756909439135" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/UCI_Benchmarks" />
    <SUITE FILE_PATH="coverage/SPC_UQ$run_uci_dataset_tests__2_.coverage" NAME="run_uci_dataset_tests (2) Coverage Results" MODIFIED="1756909699750" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/../UCI_Benchmarks" />
    <SUITE FILE_PATH="coverage/SPC_UQ$run_uci_dataset_tests.coverage" NAME="run_uci_dataset_tests Coverage Results" MODIFIED="1754253016762" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/UCI_regression" />
    <SUITE FILE_PATH="coverage/SPC_UQ$re.coverage" NAME="re Coverage Results" MODIFIED="1754263188586" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/Depth_regression/save" />
    <SUITE FILE_PATH="coverage/UQ_baseline$run_uci_dataset_tests.coverage" NAME="run_uci_dataset_tests Coverage Results" MODIFIED="1753995527450" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/UCI_regression" />
    <SUITE FILE_PATH="coverage/SPC_UQ$run_cubic_tests.coverage" NAME="run_cubic_tests Coverage Results" MODIFIED="1756996456492" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/Cubic_Regression" />
    <SUITE FILE_PATH="coverage/SPC_UQ$run_cls_tests.coverage" NAME="run_cls_tests Coverage Results" MODIFIED="1757002010234" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$/MNIST_Classification" />
  </component>
</project>
SPC-UQ/Cubic_Regression/ConformalRegression.py
ADDED
@@ -0,0 +1,76 @@
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np


class ConformalRegressionNet(nn.Module):
    """
    Simple feedforward regression model with dropout.
    Output: point prediction only (no uncertainty head).
    """
    def __init__(self, input_dim=1, hidden_dim=64, output_dim=1):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = self.dropout(x)
        return self.out(x)


class ConformalRegressor:
    """
    Quantile-based conformal prediction regression model.
    """
    def __init__(self, quantile=0.9, learning_rate=5e-3):
        torch.manual_seed(24)
        self.quantile = quantile
        self.model = ConformalRegressionNet()
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        self.criterion = nn.MSELoss()
        self.quantile_up = 0.0
        self.quantile_down = 0.0

    def train(self, x, y, num_epochs=5000):
        self.model.train()
        for epoch in range(num_epochs):
            self.optimizer.zero_grad()
            pred = self.model(x)
            loss = self.criterion(pred, y)
            loss.backward()
            self.optimizer.step()
            if (epoch + 1) % 100 == 0:
                print(f"[Epoch {epoch + 1}/{num_epochs}] Loss: {loss.item():.4f}")

    def calibrate(self, x_calib, y_calib):
        """
        Compute empirical quantiles from residuals on calibration set.
        """
        self.model.eval()
        with torch.no_grad():
            pred = self.model(x_calib).detach().cpu().numpy().squeeze()
            y_calib_np = y_calib.detach().cpu().numpy().squeeze()
            residuals = y_calib_np - pred

            res_up = residuals[residuals > 0]
            res_down = -residuals[residuals <= 0]

            self.quantile_up = np.quantile(res_up, self.quantile) if len(res_up) > 0 else 0.0
            self.quantile_down = np.quantile(res_down, self.quantile) if len(res_down) > 0 else 0.0

    def predict(self, x):
        self.model.eval()
        with torch.no_grad():
            y_pred = self.model(x).detach().cpu().numpy().squeeze()
            upper = y_pred + self.quantile_up
            lower = y_pred - self.quantile_down
            uncertainty = np.zeros_like(y_pred)  # Conformal prediction doesn't model epistemic

        return y_pred, upper, lower, uncertainty
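ConformalRegressor implements split conformal prediction: the network is fit on a proper training split, then a held-out calibration split turns the upper and lower residual quantiles into fixed interval offsets. A minimal usage sketch, not part of this commit; the synthetic cubic data and the even/odd split below are illustrative assumptions, and the import assumes the script is run from SPC-UQ/Cubic_Regression:

# Hypothetical usage sketch for ConformalRegressor (not from the repository).
import torch
from ConformalRegression import ConformalRegressor

x = torch.linspace(-4, 4, 400).unsqueeze(1)   # (N, 1) inputs, as the class expects
y = x ** 3 + 3.0 * torch.randn_like(x)        # noisy cubic target (illustrative)

x_train, y_train = x[::2], y[::2]             # proper training split
x_calib, y_calib = x[1::2], y[1::2]           # held-out calibration split

cr = ConformalRegressor(quantile=0.9)
cr.train(x_train, y_train, num_epochs=2000)
cr.calibrate(x_calib, y_calib)                # sets quantile_up / quantile_down
y_pred, upper, lower, _ = cr.predict(x)       # roughly 90% coverage of each residual side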
SPC-UQ/Cubic_Regression/DeepEnsembleRegression.py
ADDED
@@ -0,0 +1,93 @@
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

def nll_loss(mean, log_var, target):
    """
    Negative log-likelihood loss for Gaussian output
    """
    var = torch.exp(log_var)
    loss = 0.5 * torch.log(2 * np.pi * var) + 0.5 * ((target - mean) ** 2) / var
    return loss.mean()

class NLLRegressionNN(nn.Module):
    """
    Neural network for regression with Gaussian likelihood output
    Outputs mean and log-variance
    """
    def __init__(self, input_dim=1, hidden_dim=64):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.hidden = nn.Linear(hidden_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, 2)  # Outputs: mean and log_variance

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.hidden(x))
        x = self.fc2(x)
        mean = x[:, :1]
        log_var = x[:, 1:]
        return mean, log_var

class DeepEnsemble:
    """
    Deep Ensemble for probabilistic regression with Gaussian likelihood
    """
    def __init__(self, num_models=5, learning_rate=5e-3):
        torch.manual_seed(42)
        self.models = [NLLRegressionNN() for _ in range(num_models)]
        self.optimizers = [optim.Adam(model.parameters(), lr=learning_rate) for model in self.models]

    def train(self, data, target, num_epochs=5000):
        """
        Train all models independently on the same data
        """
        for idx, (model, optimizer) in enumerate(zip(self.models, self.optimizers), start=1):
            torch.manual_seed(idx + 42)  # Different seed for each model
            for epoch in range(num_epochs):
                model.train()
                optimizer.zero_grad()
                mean, log_var = model(data)
                loss = nll_loss(mean, log_var, target)
                loss.backward()
                optimizer.step()

                if (epoch + 1) % 500 == 0:
                    print(f"Model {idx}, Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")

    def predict(self, data):
        """
        Return ensemble mean, uncertainty, and prediction interval
        """
        means = []
        variances = []

        with torch.no_grad():
            for model in self.models:
                model.eval()
                mean, log_var = model(data)
                means.append(mean.numpy())
                variances.append(torch.exp(log_var).numpy())

        means = np.array(means)  # (num_models, batch, 1)
        variances = np.array(variances)

        # Mean and total predictive variance (mean of model variances)
        mean_ensemble = np.mean(means, axis=0)
        var_ensemble = np.mean(variances, axis=0)

        # Epistemic uncertainty: variance across model means
        epistemic_uncertainty = np.var(means, axis=0)

        std_ensemble = np.sqrt(var_ensemble)
        y_low = mean_ensemble - 2 * std_ensemble
        y_high = mean_ensemble + 2 * std_ensemble

        return (
            mean_ensemble.squeeze(),
            y_high.squeeze(),
            y_low.squeeze(),
            epistemic_uncertainty.squeeze()
        )
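Each ensemble member fits a heteroscedastic Gaussian by minimizing the NLL; predict() averages member means and variances for the interval and reports the variance of the member means as epistemic uncertainty. A minimal usage sketch under the same illustrative assumptions as above (synthetic cubic data, script run from SPC-UQ/Cubic_Regression):

# Hypothetical usage sketch for DeepEnsemble (not from the repository).
import torch
from DeepEnsembleRegression import DeepEnsemble

x = torch.linspace(-4, 4, 300).unsqueeze(1)
y = x ** 3 + 3.0 * torch.randn_like(x)

ensemble = DeepEnsemble(num_models=5)
ensemble.train(x, y, num_epochs=2000)

# mean +/- 2*std interval from the averaged member variances;
# epistemic = variance of the member means at each input
mean, upper, lower, epistemic = ensemble.predict(x)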
SPC-UQ/Cubic_Regression/EDLQuantileRegression.py
ADDED
@@ -0,0 +1,155 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import math


class DenseNormalGamma(nn.Module):
    def __init__(self, in_features, out_units):
        super().__init__()
        self.out_units = int(out_units)
        self.linear = nn.Linear(in_features, 4 * self.out_units)

    def evidence(self, x):
        return F.softplus(x)

    def forward(self, x):
        output = self.linear(x)
        mu, log_v, log_alpha, log_beta = torch.chunk(output, chunks=4, dim=-1)
        v = self.evidence(log_v)
        alpha = self.evidence(log_alpha) + 1.0
        beta = self.evidence(log_beta)
        return torch.cat([mu, v, alpha, beta], dim=-1)

    def extra_repr(self):
        return f"out_units={self.out_units}"


class EDLQRNet(nn.Module):
    def __init__(self, input_dim=1, num_quantiles=3, hidden_dim=64, num_layers=2, activation=nn.ReLU()):
        super().__init__()
        layers = []
        in_features = input_dim
        for _ in range(num_layers):
            layers.append(nn.Linear(in_features, hidden_dim))
            layers.append(activation)
            in_features = hidden_dim
        layers.append(DenseNormalGamma(in_features, num_quantiles))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        output = self.network(x)
        mu, v, alpha, beta = torch.chunk(output, 4, dim=-1)
        return mu, v, alpha, beta


def nig_nll(y, mu, v, alpha, beta, wi_mean, quantile, reduce=True):
    tau2 = 2.0 / (quantile * (1.0 - quantile))
    two_b_lambda = 4.0 * beta * (1.0 + tau2 * wi_mean * v)

    nll = 0.5 * torch.log(math.pi / v) \
        - alpha * torch.log(two_b_lambda) \
        + (alpha + 0.5) * torch.log(v * (y - mu) ** 2 + two_b_lambda) \
        + torch.lgamma(alpha) \
        - torch.lgamma(alpha + 0.5)

    return torch.mean(nll) if reduce else nll


def kl_nig(mu1, v1, a1, b1, mu2, v2, a2, b2):
    kl = 0.5 * (a1 - 1.0) / b1 * (v2 * (mu2 - mu1) ** 2) \
        + 0.5 * (v2 / v1) \
        - 0.5 * torch.log(torch.abs(v2) / torch.abs(v1)) \
        - 0.5 \
        + a2 * torch.log(b1 / b2) \
        - (torch.lgamma(a1) - torch.lgamma(a2)) \
        + (a1 - a2) * torch.digamma(a1) \
        - (b1 - b2) * a1 / b1
    return kl


def tilted_loss(q, e):
    return torch.maximum(q * e, (q - 1.0) * e)


def nig_regularization(y, mu, v, alpha, beta, wi_mean, quantile, lambda_reg=0.01, reduce=True, use_kl=False):
    theta = (1.0 - 2.0 * quantile) / (quantile * (1.0 - quantile))
    error = tilted_loss(quantile, y - mu)

    if use_kl:
        kl_val = kl_nig(mu, v, alpha, beta, mu, lambda_reg, 1.0 + lambda_reg, beta)
        reg = error * kl_val
    else:
        evidential_term = 2.0 * v + alpha + 1.0 / beta
        reg = error * evidential_term

    return torch.mean(reg) if reduce else reg


def quantile_evidential_loss(y_true, mu, v, alpha, beta, quantile, coeff=1.0, reduce=True):
    theta = (1.0 - 2.0 * quantile) / (quantile * (1.0 - quantile))
    wi_mean = beta / (alpha - 1.0)
    mu_adj = mu + theta * wi_mean

    loss_nll = nig_nll(y_true, mu_adj, v, alpha, beta, wi_mean, quantile, reduce)
    # Pass reduce by keyword so it does not land in the lambda_reg slot.
    loss_reg = nig_regularization(y_true, mu, v, alpha, beta, wi_mean, quantile, reduce=reduce)
    return loss_nll + coeff * loss_reg


class EDLQuantileRegressor:
    def __init__(self, tau_low=0.05, tau_high=0.95, learning_rate=5e-4):
        torch.manual_seed(42)
        self.model = EDLQRNet()
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        self.quantiles = [tau_low, 0.5, tau_high]
        self.coeff = 0.05

    def loss_function(self, y, mu, v, alpha, beta):
        total_loss = 0.0
        for i, q in enumerate(self.quantiles):
            total_loss += quantile_evidential_loss(
                y, mu[:, i].unsqueeze(1), v[:, i].unsqueeze(1),
                alpha[:, i].unsqueeze(1), beta[:, i].unsqueeze(1),
                q, coeff=self.coeff
            )
        return total_loss

    def train(self, x, y, num_epochs=5000):
        self.model.train()
        for epoch in range(num_epochs):
            self.optimizer.zero_grad()
            mu, v, alpha, beta = self.model(x)
            loss = self.loss_function(y, mu, v, alpha, beta)
            loss.backward()
            self.optimizer.step()

            if (epoch + 1) % 100 == 0:
                print(f"Epoch [{epoch + 1}/{num_epochs}] Loss: {loss.item():.4f}")

    def predict(self, x):
        self.model.eval()
        with torch.no_grad():
            mu, v, alpha, beta = self.model(x)
            mu_low, mu_mid, mu_high = torch.unbind(mu, dim=1)
            v_low, v_mid, v_high = torch.unbind(v, dim=1)
            alpha_low, alpha_mid, alpha_high = torch.unbind(alpha, dim=1)
            beta_low, beta_mid, beta_high = torch.unbind(beta, dim=1)

            aleatoric = beta_mid / (alpha_mid - 1.0)
            epistemic_mid = beta_mid / (v_mid * (alpha_mid - 1.0))
            epistemic_low = beta_low / (v_low * (alpha_low - 1.0))
            epistemic_high = beta_high / (v_high * (alpha_high - 1.0))
            uncertainty = epistemic_mid

            plt.figure(figsize=(10, 6))
            plt.plot(x, epistemic_mid, label='Mid', color='orange')
            plt.plot(x, epistemic_low, label='Low', color='red')
            plt.plot(x, epistemic_high, label='High', color='green')
            plt.legend()
            plt.title('Epistemic Uncertainty')
            plt.show()

        return mu_mid.numpy().squeeze(), mu_high.numpy().squeeze(), mu_low.numpy().squeeze(), uncertainty.numpy().squeeze()
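Here each of the three quantile heads carries its own Normal-Inverse-Gamma parameters, so the median and both tail quantiles each come with an evidential epistemic estimate beta / (v * (alpha - 1)). A minimal usage sketch under the same illustrative assumptions as the earlier examples:

# Hypothetical usage sketch for EDLQuantileRegressor (not from the repository).
import torch
from EDLQuantileRegression import EDLQuantileRegressor

x = torch.linspace(-4, 4, 300).unsqueeze(1)
y = x ** 3 + 3.0 * torch.randn_like(x)

edl_qr = EDLQuantileRegressor(tau_low=0.05, tau_high=0.95)
edl_qr.train(x, y, num_epochs=2000)

# Median prediction, upper/lower quantile heads, and the median head's
# NIG epistemic term. Note that predict() also draws a matplotlib figure
# of the three epistemic curves as a side effect.
median, upper, lower, epistemic = edl_qr.predict(x)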
SPC-UQ/Cubic_Regression/EDLRegression.py
ADDED
@@ -0,0 +1,141 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from scipy.stats import t as student_t
import numpy as np
import math


class EDLRegressionNet(nn.Module):
    def __init__(self, input_dim=1, hidden_dim=64, output_dim=1):
        super().__init__()
        self.hidden1 = nn.Linear(input_dim, hidden_dim)
        self.hidden2 = nn.Linear(hidden_dim, hidden_dim)
        self.output_mu = nn.Linear(hidden_dim, output_dim)
        self.output_logv = nn.Linear(hidden_dim, output_dim)
        self.output_alpha = nn.Linear(hidden_dim, output_dim)
        self.output_beta = nn.Linear(hidden_dim, output_dim)
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.activation(self.hidden1(x))
        x = self.activation(self.hidden2(x))
        mu = self.output_mu(x)
        v = F.softplus(self.output_logv(x))
        alpha = F.softplus(self.output_alpha(x)) + 1.0 + 1e-6
        beta = F.softplus(self.output_beta(x))
        return mu, v, alpha, beta


def nig_nll(y, mu, v, alpha, beta, reduce=True):
    two_b_lambda = 2 * beta * (1 + v)

    nll = 0.5 * torch.log(torch.tensor(np.pi) / v) \
        - alpha * torch.log(two_b_lambda) \
        + (alpha + 0.5) * torch.log(v * (y - mu) ** 2 + two_b_lambda) \
        + torch.lgamma(alpha) \
        - torch.lgamma(alpha + 0.5)

    return torch.mean(nll) if reduce else nll


def kl_nig(mu1, v1, a1, b1, mu2, v2, a2, b2):
    eps = 1e-6
    v1 = torch.clamp(v1, min=eps)
    v2 = torch.clamp(v2, min=eps)
    b1 = torch.clamp(b1, min=eps)
    b2 = torch.clamp(b2, min=eps)

    term1 = 0.5 * (a1 - 1) / b1 * (v2 * (mu2 - mu1) ** 2)
    term2 = 0.5 * v2 / v1
    term3 = -0.5 * torch.log(v2 / v1)
    term4 = -0.5
    term5 = a2 * torch.log(b1 / b2)
    term6 = -(torch.lgamma(a1) - torch.lgamma(a2))
    term7 = (a1 - a2) * torch.digamma(a1)
    term8 = -(b1 - b2) * a1 / b1

    kl = term1 + term2 + term3 + term4 + term5 + term6 + term7 + term8
    return kl


def nig_regularization(y, mu, v, alpha, beta, omega=0.01, reduce=True, use_kl=False):
    error = torch.abs(y - mu)

    if use_kl:
        kl = kl_nig(mu, v, alpha, beta, mu, omega, 1 + omega, beta)
        reg = error * kl
    else:
        evidential = 2 * v + alpha
        reg = error * evidential

    return reg.mean() if reduce else reg


def edl_loss(y, mu, v, alpha, beta, lam=0.0, reduce=True, return_components=False):
    nll = nig_nll(y, mu, v, alpha, beta, reduce=reduce)
    reg = nig_regularization(y, mu, v, alpha, beta, reduce=reduce)
    loss = nll  # optionally: loss = nll + lam * reg

    return (loss, (nll, reg)) if return_components else loss


def predictive_interval(mu, v, alpha, beta, confidence=0.95):
    mu = torch.as_tensor(mu)
    v = torch.as_tensor(v)
    alpha = torch.as_tensor(alpha)
    beta = torch.as_tensor(beta)

    dof = 2.0 * alpha
    scale = torch.sqrt((1.0 + v) * beta / (alpha * v))
    lower_q = (1.0 - confidence) / 2.0
    upper_q = 1.0 - lower_q

    t_l = student_t.ppf(lower_q, df=dof.cpu().numpy())
    t_u = student_t.ppf(upper_q, df=dof.cpu().numpy())

    lower = mu + torch.from_numpy(t_l).to(mu.device) * scale
    upper = mu + torch.from_numpy(t_u).to(mu.device) * scale

    return lower, upper


class EDLRegressor:
    def __init__(self, learning_rate=5e-4):
        torch.manual_seed(33)
        self.model = EDLRegressionNet()
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        self.criterion = edl_loss
        self.lambda_ = 0.01

    def train(self, x, y, num_epochs=5000):
        torch.manual_seed(33)
        self.model.train()
        for epoch in range(num_epochs):
            self.optimizer.zero_grad()
            mu, v, alpha, beta = self.model(x)
            loss = self.criterion(y, mu, v, alpha, beta, lam=self.lambda_)
            loss.backward()
            self.optimizer.step()
            if (epoch + 1) % 100 == 0:
                print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")

    def predict(self, x):
        self.model.eval()
        with torch.no_grad():
            mu, v, alpha, beta = self.model(x)

            # Uncertainty decomposition
            aleatoric = torch.sqrt(beta / (alpha - 1.0 + 1e-6))
            epistemic = torch.sqrt(beta / (v * (alpha - 1.0 + 1e-6)))

            mu_np = mu.detach().cpu().numpy()
            aleatoric_np = aleatoric.detach().cpu().numpy()
            epistemic_np = epistemic.detach().cpu().numpy()

            lower, upper = predictive_interval(mu, v, alpha, beta, confidence=0.95)
            lower_np = lower.detach().cpu().numpy()
            upper_np = upper.detach().cpu().numpy()

        return mu_np.squeeze(), upper_np.squeeze(), lower_np.squeeze(), epistemic_np.squeeze()
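EDLRegressor fits a single Normal-Inverse-Gamma head, and predictive_interval() derives bounds from the implied Student-t posterior predictive with 2*alpha degrees of freedom. A minimal usage sketch under the same illustrative assumptions as the earlier examples:

# Hypothetical usage sketch for EDLRegressor (not from the repository).
import torch
from EDLRegression import EDLRegressor

x = torch.linspace(-4, 4, 300).unsqueeze(1)
y = x ** 3 + 3.0 * torch.randn_like(x)

edl = EDLRegressor(learning_rate=5e-4)
edl.train(x, y, num_epochs=2000)

# upper/lower are 95% Student-t predictive bounds (dof = 2*alpha);
# the returned uncertainty is the epistemic term sqrt(beta / (v * (alpha - 1)))
mu, upper, lower, epistemic = edl.predict(x)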
SPC-UQ/Cubic_Regression/QROC.py
ADDED
@@ -0,0 +1,125 @@
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset

def pinball_loss(q_pred, target, tau):
    """
    Compute the quantile (pinball) loss for a given quantile level tau
    """
    error = target - q_pred
    return torch.mean(torch.max(tau * error, (tau - 1) * error))

def multi_quantile_loss(q_low, q_mid, q_high, target, tau_low=0.05, tau_high=0.95):
    """
    Compute total loss across low, median, and high quantiles
    """
    loss_l = pinball_loss(q_low, target, tau_low)
    loss_m = pinball_loss(q_mid, target, 0.5)
    loss_h = pinball_loss(q_high, target, tau_high)
    return loss_l + loss_m + loss_h


class QuantileRegressionNN(nn.Module):
    """
    Fully-connected quantile regression network with three outputs:
    lower, median, upper quantiles.
    """
    def __init__(self, input_dim=1, hidden_dim=64):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, 3)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        out = self.out(x)
        return out[:, :1], out[:, 1:2], out[:, 2:]

    def extract_features(self, x):
        """
        Extract penultimate-layer features (for certificate head)
        """
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return x


def build_certificate_head(features, out_dim=20, epochs=500):
    """
    Train an orthogonal linear projection (certificate head) on extracted features
    """
    projection = nn.Linear(features.size(1), out_dim)
    loader = DataLoader(TensorDataset(features), shuffle=True, batch_size=128)
    optimizer = optim.Adam(projection.parameters())

    for _ in range(epochs):
        for (feature_batch,) in loader:
            optimizer.zero_grad()
            output = projection(feature_batch)
            cert_loss = output.pow(2).mean()

            identity = torch.eye(out_dim, device=projection.weight.device)
            ortho_penalty = (projection.weight @ projection.weight.T - identity).pow(2).mean()

            (cert_loss + ortho_penalty).backward()
            optimizer.step()

    return projection


class QROC:
    """
    Single-model Quantile Regression with Orthogonal Certificate head (QROC)
    Provides aleatoric and epistemic uncertainty estimation
    """
    def __init__(self, learning_rate=5e-3, tau_low=0.05, tau_high=0.95):
        torch.manual_seed(42)
        self.model = QuantileRegressionNN()
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        self.tau_low = tau_low
        self.tau_high = tau_high
        self.certificate_head = None

    def train(self, x, y, epochs=3000):
        """
        Train the quantile regression network and then fit certificate head on extracted features
        """
        for epoch in range(epochs):
            self.model.train()
            self.optimizer.zero_grad()
            q_low, q_mid, q_high = self.model(x)
            loss = multi_quantile_loss(q_low, q_mid, q_high, y, self.tau_low, self.tau_high)
            loss.backward()
            self.optimizer.step()

            if (epoch + 1) % 500 == 0:
                print(f"[{epoch+1}/{epochs}] Loss: {loss.item():.4f}")

        # Train certificate head using the final feature representation
        self.model.eval()
        with torch.no_grad():
            features = self.model.extract_features(x).detach()
        self.certificate_head = build_certificate_head(features)

    def predict(self, x):
        """
        Predict quantiles and return aleatoric and epistemic uncertainties
        """
        self.model.eval()
        with torch.no_grad():
            q_low, q_mid, q_high = self.model(x)

            # Epistemic: projection energy via orthogonal certificate head
            features = self.model.extract_features(x)
            epistemic = self.certificate_head(features).pow(2).mean(dim=1).cpu().numpy()

        return (
            q_mid.squeeze().numpy(),
            q_high.squeeze().numpy(),
            q_low.squeeze().numpy(),
            epistemic
        )
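QROC trains one quantile network, then fits a linear certificate head whose outputs are driven toward zero on training features while its rows are pushed toward orthonormality; the projection energy on new inputs then serves as the epistemic signal. A minimal usage sketch under the same illustrative assumptions as the earlier examples:

# Hypothetical usage sketch for QROC (not from the repository).
import torch
from QROC import QROC

x = torch.linspace(-4, 4, 300).unsqueeze(1)
y = x ** 3 + 3.0 * torch.randn_like(x)

qroc = QROC(tau_low=0.05, tau_high=0.95)
qroc.train(x, y, epochs=3000)

# In-distribution inputs should give small certificate energy; inputs far
# outside the training range should give larger values.
x_test = torch.linspace(-7, 7, 300).unsqueeze(1)
median, upper, lower, epistemic = qroc.predict(x_test)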
SPC-UQ/Cubic_Regression/SPCRegression.py
ADDED
@@ -0,0 +1,173 @@
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


def pinball_loss(q_pred, target, tau):
    """Standard quantile regression (pinball) loss."""
    errors = target - q_pred
    loss = torch.max(tau * errors, (tau - 1) * errors)
    return torch.mean(loss)


def cali_loss(y_pred, y_true, q, scale=True):
    """
    Calibration loss for quantile regression.
    Penalizes over- or under-coverage relative to quantile level q.
    """
    diff = y_true - y_pred
    under_mask = (y_true <= y_pred)
    over_mask = ~under_mask

    coverage = torch.mean(under_mask.float())

    if coverage < q:
        loss = torch.mean(diff[over_mask])
    else:
        loss = torch.mean(-diff[under_mask])

    if scale:
        loss *= torch.abs(q - coverage)
    return loss


class SPCRegressionNet(nn.Module):
    """
    Neural network that predicts:
    - point estimate (v)
    - MAR (mean absolute residual)
    - MAR up/down (for epistemic decomposition)
    - QR up/down (for aleatoric decomposition)
    """
    def __init__(self, input_dim=1, hidden_dim=64):
        super(SPCRegressionNet, self).__init__()
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.hidden2 = nn.Linear(hidden_dim, hidden_dim)
        self.hidden3 = nn.Linear(hidden_dim, hidden_dim)

        self.relu = nn.ReLU()
        # self.dropout = nn.Dropout(p=0.2)
        self.output_v = nn.Linear(hidden_dim, 1)
        self.output_uq = nn.Linear(hidden_dim, 5)

    def forward(self, x):
        x = self.hidden(x)
        x = self.relu(x)
        x = self.hidden2(x)
        x = self.relu(x)
        v = self.output_v(x)
        x = self.hidden3(x)
        x = self.relu(x)
        output = self.output_uq(x)
        mar, mar_up, mar_down, q_up, q_down = torch.chunk(output, 5, dim=-1)

        q_up = F.softplus(q_up)
        q_down = F.softplus(q_down)
        return v, mar, mar_up, mar_down, q_up, q_down


class SPCregression:
    """
    Trainer and predictor for the SPC UQ model.
    Supports joint or stagewise training strategies.
    """
    def __init__(self, learning_rate=5e-3):
        torch.manual_seed(42)
        self.model = SPCRegressionNet()
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        self.optimizer1 = optim.Adam(
            list(self.model.hidden.parameters()) + list(self.model.hidden2.parameters())
            + list(self.model.output_v.parameters()),
            lr=learning_rate, weight_decay=1e-4)
        self.optimizer2 = optim.Adam(
            list(self.model.hidden3.parameters()) + list(self.model.output_uq.parameters()),
            lr=learning_rate, weight_decay=1e-4)
        self.criterion = nn.MSELoss()
        self.criterion2 = nn.L1Loss()

    def mar_loss(self, y, predictions, mar, mar_up, mar_down, q_up, q_down):
        """Computes the loss for the MAR and QR heads."""
        residual = torch.abs(y - predictions)
        diff = (y - predictions.detach())
        loss_mar = self.criterion(mar, residual)

        mask_up = (diff > 0)
        mask_down = (diff < 0)
        loss_mar_up = self.criterion(mar_up[mask_up], (y[mask_up] - predictions[mask_up]))
        loss_mar_down = self.criterion(mar_down[mask_down], (predictions[mask_down] - y[mask_down]))

        # loss_q_up = pinball_loss(q_up[mask_up], (y[mask_up] - predictions[mask_up]), 0.95)
        # loss_q_down = pinball_loss(q_down[mask_down], (predictions[mask_down] - y[mask_down]), 0.95)
        loss_cali_up = cali_loss(q_up[mask_up], (y[mask_up] - predictions[mask_up]), 0.95)
        loss_cali_down = cali_loss(q_down[mask_down], (predictions[mask_down] - y[mask_down]), 0.95)

        loss = loss_mar + loss_mar_up + loss_mar_down + (loss_cali_up + loss_cali_down)
        # + 0.2 * (loss_q_up + loss_q_down) \
        # + 0.8 * (loss_cali_up + loss_cali_down) \
        return loss

    def train(self, data, target, num_epochs=5000, strategy='stagewise'):
        """
        Train the model using either:
        - 'joint': full loss on all components
        - 'stagewise': first fit the task head, then the UQ heads
        """
        torch.manual_seed(42)
        self.model.train()

        if strategy == 'joint':
            for epoch in range(num_epochs):
                self.optimizer.zero_grad()
                predictions, mar, mar_up, mar_down, q_up, q_down = self.model(data)
                loss = self.criterion(predictions, target) + self.mar_loss(target, predictions, mar, mar_up, mar_down, q_up, q_down)
                loss.backward()
                self.optimizer.step()
                if (epoch + 1) % 100 == 0:
                    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")

        if strategy == 'stagewise':
            for epoch in range(num_epochs):
                self.optimizer1.zero_grad()
                predictions, mar, mar_up, mar_down, q_up, q_down = self.model(data)
                loss = self.criterion(predictions, target)
                loss.backward()
                self.optimizer1.step()
                if (epoch + 1) % 100 == 0:
                    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")

            for epoch in range(num_epochs):
                self.optimizer2.zero_grad()
                predictions, mar, mar_up, mar_down, q_up, q_down = self.model(data)
                loss = self.mar_loss(target, predictions, mar, mar_up, mar_down, q_up, q_down)
                loss.backward()
                self.optimizer2.step()
                if (epoch + 1) % 100 == 0:
                    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")

    def predict(self, data, calibration=False):
        """Run prediction and return interval bounds and an uncertainty estimate."""
        self.model.eval()
        with torch.no_grad():
            predictions, mar, mar_up, mar_down, q_up, q_down = self.model(data)

        v = predictions.numpy()
        mar = mar.detach().numpy()
        mar_up = mar_up.detach().numpy()
        mar_down = mar_down.detach().numpy()
        q_up = q_up.detach().numpy()
        q_down = q_down.detach().numpy()

        if calibration:
            # Calibration adjustment based on self-consistency
            d_up = (mar * mar_down) / ((2 * mar_down - mar) * mar_up)
            d_down = (mar * mar_up) / ((2 * mar_up - mar) * mar_down)
            d_up = np.clip(d_up, 1, None)
            d_down = np.clip(d_down, 1, None)
            q_up *= d_up
            q_down *= d_down

        high_bound = v + q_up
        low_bound = v - q_down

        # Self-consistency verification
        uncertainty = np.abs(2 * mar_up * mar_down - mar * (mar_up + mar_down))

        return v.squeeze(), high_bound.squeeze(), low_bound.squeeze(), uncertainty.squeeze()
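A note on the uncertainty score returned by `predict` above: when `v` is the conditional mean, `E[y - v] = 0` forces `p * mar_up = (1 - p) * mar_down` with `p = P(y > v)`, and since `mar = p * mar_up + (1 - p) * mar_down`, self-consistency requires `mar = 2 * mar_up * mar_down / (mar_up + mar_down)`; the returned `uncertainty` is exactly the absolute violation of that identity. A minimal usage sketch (the toy data and epoch count are illustrative, not from the repo):

```python
import torch
from SPCRegression import SPCregression

# Toy cubic data, mirroring the setup in run_cubic_tests.py
x = torch.linspace(-4, 4, 500).reshape(-1, 1)
y = x ** 3 + 3 * torch.randn_like(x)

model = SPCregression(learning_rate=5e-3)
model.train(x, y, num_epochs=1000, strategy='stagewise')

v, high, low, unc = model.predict(x, calibration=True)
print("mean interval width:", (high - low).mean())
print("mean self-consistency violation:", unc.mean())
```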
SPC-UQ/Cubic_Regression/__pycache__/ConformalRegression.cpython-37.pyc
ADDED
Binary file (3.08 kB)
SPC-UQ/Cubic_Regression/__pycache__/DeepEnsembleRegression.cpython-37.pyc
ADDED
Binary file (3.48 kB)
SPC-UQ/Cubic_Regression/__pycache__/EDLQuantileRegression.cpython-37.pyc
ADDED
Binary file (6.07 kB)
SPC-UQ/Cubic_Regression/__pycache__/EDLRegression.cpython-37.pyc
ADDED
Binary file (4.79 kB)
SPC-UQ/Cubic_Regression/__pycache__/QROC.cpython-37.pyc
ADDED
Binary file (4.47 kB)
SPC-UQ/Cubic_Regression/__pycache__/SPCRegression.cpython-37.pyc
ADDED
Binary file (5.2 kB)
SPC-UQ/Cubic_Regression/run_cubic_tests.py
ADDED
@@ -0,0 +1,335 @@
import torch
import numpy as np
import argparse
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from SPCRegression import SPCregression
from DeepEnsembleRegression import DeepEnsemble
from ConformalRegression import ConformalRegressor
from EDLRegression import EDLRegressor
from EDLQuantileRegression import EDLQuantileRegressor
from QROC import QROC
from scipy.stats import binned_statistic


def ece_pi(y_true, pred_lower, pred_upper, num_bins=10):
    """Compute Expected Calibration Error (ECE) based on prediction interval width bins."""
    N = y_true.shape[0]
    in_interval = ((y_true >= pred_lower) & (y_true <= pred_upper)).astype(float)
    widths = pred_upper - pred_lower
    min_w, max_w = np.min(widths), np.max(widths)

    if min_w == max_w:
        return np.abs(in_interval.mean() - 1.0)

    bin_edges = np.linspace(min_w, max_w, num_bins + 1)
    ece = 0.0

    for i in range(num_bins):
        bin_mask = (widths >= bin_edges[i]) & (widths < bin_edges[i + 1])
        count_in_bin = np.sum(bin_mask)
        if count_in_bin == 0:
            continue
        avg_in_interval = in_interval[bin_mask].mean()
        nominal_coverage = 0.95
        calib_error = np.abs(avg_in_interval - nominal_coverage)
        weight = count_in_bin / N
        ece += weight * calib_error
    return ece


def binning(pred_lower, pred_upper, num_bins=10):
    """Group indices into bins based on interval width."""
    widths = pred_upper - pred_lower
    min_w, max_w = np.min(widths), np.max(widths)
    bin_edges = np.linspace(min_w, max_w, num_bins + 1)

    bins = []
    for i in range(num_bins):
        if i == num_bins - 1:
            bin_mask = (widths >= bin_edges[i]) & (widths <= bin_edges[i + 1])
        else:
            bin_mask = (widths >= bin_edges[i]) & (widths < bin_edges[i + 1])
        bin_indices = np.where(bin_mask)[0]
        bins.append(bin_indices)
    return bins, bin_edges


def generate_multimodal_data(n_samples=1000):
    """Generate mixture-distribution noise samples."""
    x = np.random.randn(n_samples)
    mask = np.random.choice([0, 1, 2], p=[0.4, 0.3, 0.3], size=n_samples)
    y = np.where(mask == 0, x + np.random.randn(n_samples) * 1,
                 np.where(mask == 1, x + 40 + np.random.randn(n_samples) * 1,
                          x - 10 + np.random.randn(n_samples) * 1))
    return y


def generate_train_data(n_samples=20, noise='log'):
    """Generate synthetic training data with a nonlinear relationship and optional noise."""
    np.random.seed(57)
    x = np.linspace(-4, 4, n_samples)

    if noise == 'log':
        noise = np.random.lognormal(mean=1.5, sigma=1, size=n_samples)
    elif noise == 'tri':
        noise = generate_multimodal_data(n_samples)
    elif noise == 'norm':
        noise = np.random.normal(0, 8, size=n_samples)

    noise = noise - np.mean(noise)
    y = x ** 3 + noise
    x = x.reshape(-1, 1).astype(np.float32)
    y = y.reshape(-1, 1).astype(np.float32)
    return x, y


def generate_test_data(n_samples=100, noise='log'):
    """Generate synthetic test data with an extended input range and optional noise."""
    np.random.seed(27)
    x = np.linspace(-6, 6, n_samples)

    if noise == 'log':
        noise = np.random.lognormal(mean=1.5, sigma=1, size=n_samples)
    elif noise == 'tri':
        noise = generate_multimodal_data(n_samples)
    elif noise == 'norm':
        noise = np.random.normal(0, 8, size=n_samples)

    noise = noise - np.mean(noise)
    y = x ** 3 + noise
    x = x.reshape(-1, 1).astype(np.float32)
    y = y.reshape(-1, 1).astype(np.float32)
    return x, y


# ===================== Argument Parser ===================== #
parser = argparse.ArgumentParser()
parser.add_argument("--num-epochs", default=5000, type=int)
parser.add_argument('--data-noise', default='log', choices=['norm', 'tri', 'log'])
parser.add_argument('--UQ-model', default='SPCregression',
                    choices=['SPCregression', 'DeepEnsemble', 'EDLRegressor', 'EDLQuantileRegressor', 'QROC', 'ConformalRegressor'],
                    help='Select UQ model to test.')
args = parser.parse_args()

# Generate training, calibration, and testing datasets
x_train, y_train = generate_train_data(n_samples=2000, noise=args.data_noise)
x_calib, y_calib = generate_train_data(n_samples=500, noise=args.data_noise)
x_test, y_test = generate_test_data(n_samples=1000, noise=args.data_noise)

# Convert data to torch tensors
x_train_tensor = torch.from_numpy(x_train)
y_train_tensor = torch.from_numpy(y_train)
x_calib_tensor = torch.from_numpy(x_calib)
y_calib_tensor = torch.from_numpy(y_calib)
x_test_tensor = torch.from_numpy(x_test)
y_test_tensor = torch.from_numpy(y_test)

# Training parameters
num_epochs = args.num_epochs
num_models = 5
lr = 0.001

# Select UQ model to test
UQ = args.UQ_model

if UQ == 'SPCregression':
    model = SPCregression(learning_rate=lr)
elif UQ == 'DeepEnsemble':
    model = DeepEnsemble(num_models=num_models, learning_rate=lr)
elif UQ == 'EDLRegressor':
    model = EDLRegressor(learning_rate=lr)
elif UQ == 'EDLQuantileRegressor':
    model = EDLQuantileRegressor(tau_low=0.025, tau_high=0.975, learning_rate=lr)
elif UQ == 'QROC':
    model = QROC(tau_low=0.025, tau_high=0.975, learning_rate=lr)
elif UQ == 'ConformalRegressor':
    model = ConformalRegressor(0.95, learning_rate=lr)

# Train model
model.train(x_train_tensor, y_train_tensor, num_epochs)

# Calibrate if applicable
if UQ == 'ConformalRegressor':
    model.calibrate(x_calib_tensor, y_calib_tensor)

# Predict on train set
mean, upper_bound, lower_bound, uncertainty = model.predict(x_train_tensor)
y_train_np = y_train.flatten()

# Evaluate train metrics
mpi_width = np.mean(upper_bound - lower_bound)
picp = np.mean(((y_train_np >= lower_bound) & (y_train_np <= upper_bound)).astype(float))
rmse = np.sqrt(np.mean((y_train_np - mean) ** 2))
print(f"Train Mean Prediction Interval Width (MPIW): {mpi_width:.4f}")
print(f"Train Prediction Interval Coverage Probability (PICP): {picp:.4f}")
print(f"Train Root Mean Squared Error (RMSE): {rmse:.4f}")

# Predict on test set
mean, upper_bound, lower_bound, uncertainty = model.predict(x_test_tensor)
x_test_np = x_test.flatten()
y_test_np = y_test.flatten()

# In-distribution test subset for interval evaluation
y_low_id = lower_bound[170:830]
y_high_id = upper_bound[170:830]
y_mean_id = mean[170:830]
x_test_id = x_test_np[170:830]
y_test_id = y_test_np[170:830]

mpi_width_id = np.mean(y_high_id - y_low_id)
picp_id = np.mean(((y_test_id >= y_low_id) & (y_test_id <= y_high_id)).astype(float))
# PICP+ (coverage of the upper half-interval)
picp_plus = np.sum(((y_test_id >= y_mean_id) & (y_test_id <= y_high_id)).astype(float)) / np.sum((y_test_id >= y_mean_id).astype(float))
# PICP- (coverage of the lower half-interval)
picp_minus = np.sum(((y_test_id >= y_low_id) & (y_test_id <= y_mean_id)).astype(float)) / np.sum((y_test_id <= y_mean_id).astype(float))

rmse_id = np.sqrt(np.mean((y_test_id - y_mean_id) ** 2))

print(f"ID Mean Prediction Interval Width (MPIW): {mpi_width_id:.4f}")
print(f"ID Prediction Interval Coverage Probability (PICP): {picp_id:.4f}")
print(f"ID Prediction Interval Coverage Probability (PICP+): {picp_plus:.4f}")
print(f"ID Prediction Interval Coverage Probability (PICP-): {picp_minus:.4f}")
print(f"ID Root Mean Squared Error (RMSE): {rmse_id:.4f}")

# Compute ECE
ece = ece_pi(y_test_id, y_low_id, y_high_id, num_bins=10)
print(f"Expected Calibration Error (ECE): {ece:.4f}")

# Split test samples into certain/uncertain using the mean iD uncertainty as threshold
threshold = uncertainty[170:830].mean()
print('threshold:', threshold)
cer_count = 0
unc_count = 0
cer_diff = []
unc_diff = []

for i in range(len(uncertainty)):
    unc = uncertainty[i]
    diff_sq = (y_test_np[i] - mean[i]) ** 2
    if unc < threshold:
        cer_count += 1
        cer_diff.append(diff_sq)
    else:
        unc_count += 1
        unc_diff.append(diff_sq)

rmse_certain = np.sqrt(np.mean(cer_diff)) if len(cer_diff) > 0 else 0
rmse_uncertain = np.sqrt(np.mean(unc_diff)) if len(unc_diff) > 0 else 0
rmse_all = np.sqrt(np.mean(cer_diff + unc_diff))

print('Certain sample:', cer_count,
      'RMSE_certain:', round(rmse_certain, 4),
      'Uncertain sample:', unc_count,
      'RMSE_uncertain:', round(rmse_uncertain, 4),
      'RMSE_all:', round(rmse_all, 4))

print(round(rmse_id, 4),
      round(picp_id, 4),
      round(mpi_width_id, 4),
      round(rmse_certain, 4),
      round(rmse_uncertain, 4),
      round(rmse_all, 4),
      cer_count, unc_count)


def plot_binned_intervals(x_test, y_test, mean, lower_bound, upper_bound, bins, num_bins=1):
    """Plot prediction intervals with color-coded bins based on interval width."""
    x = x_test.squeeze()
    y = y_test.squeeze()

    def quantile_stat(q):
        def func(y_in_bin):
            return np.percentile(y_in_bin, q)
        return func

    bin_means, bin_edges, _ = binned_statistic(x, y, statistic='mean', bins=num_bins)
    q5, _, _ = binned_statistic(x, y, statistic=quantile_stat(5), bins=bin_edges)
    q95, _, _ = binned_statistic(x, y, statistic=quantile_stat(95), bins=bin_edges)

    gt = x ** 3
    y_up = (y - gt)[y > gt]
    y_down = (gt - y)[y < gt]
    gt_up = np.quantile(y_up, 0.95)
    gt_down = np.quantile(y_down, 0.95)

    plt.figure(figsize=(7, 5))
    cmap = cm.get_cmap('viridis', len(bins))
    interval_width = 0.1
    for i, bin_indices in enumerate(bins):
        color = cmap(i)
        for j, idx in enumerate(bin_indices):
            x_val = x_test[idx]
            plt.fill_between(
                [x_val - interval_width / 2, x_val + interval_width / 2],
                [lower_bound[idx], lower_bound[idx]],
                [upper_bound[idx], upper_bound[idx]],
                color=color,
                alpha=0.2,
                label=f'Bin {i + 1}' if j == 0 else None
            )

    plt.plot(x, gt, color='blue', linestyle='--', label='Mean E[y|x]', linewidth=2)
    plt.plot(x, gt - gt_down, color='darkorange', linestyle='--', label='Lower bound GT', linewidth=2)
    plt.plot(x, gt + gt_up, color='purple', linestyle='--', label='Upper bound GT', linewidth=2)
    plt.scatter(x_test, y_test, color='green', s=5, label='Test Data')
    plt.plot(x_test, mean, color='red', label='Point Prediction')
    plt.title("Equal-width Binned PIs")
    plt.legend()
    plt.ylim(-100, 100)
    plt.tight_layout()
    plt.show()

    # Second plot with shaded split intervals
    plt.figure(figsize=(7, 4))
    plt.plot(x, gt, color='blue', linestyle='--', label='Mean E[y|x]', linewidth=2)
    plt.plot(x, gt - gt_down, color='darkorange', linestyle='--', label='Lower bound GT', linewidth=2)
    plt.plot(x, gt + gt_up, color='purple', linestyle='--', label='Upper bound GT', linewidth=2)
    plt.scatter(x_test, y_test, color='green', s=5, label='Test Data')
    plt.plot(x_test, mean, color='red', label='Point Prediction')
    plt.fill_between(x_test.flatten(), lower_bound, mean, color='orange', alpha=0.3, label='Lower Interval')
    plt.fill_between(x_test.flatten(), mean, upper_bound, color='purple', alpha=0.4, label='Upper Interval')
    plt.legend(fontsize=8)
    plt.ylim(-80, 100)
    plt.title('Split-point Prediction Intervals')
    plt.tight_layout()
    plt.show()


def plot_intervals(x_test, y_test, mean, lower_bound, upper_bound, uncertainty):
    """Plot prediction intervals and epistemic uncertainty with a ground-truth reference."""
    x = x_test.squeeze()
    y = y_test.squeeze()

    gt = x ** 3
    y_up = (y - gt)[y > gt]
    y_down = (gt - y)[y < gt]
    gt_up = np.quantile(y_up, 0.95)
    gt_down = np.quantile(y_down, 0.95)

    import matplotlib.gridspec as gridspec
    plt.figure(figsize=(7, 9))
    gs = gridspec.GridSpec(2, 1, height_ratios=[4, 3])

    ax1 = plt.subplot(gs[0])
    ax1.axvspan(x.min(), -4, facecolor='lightgray', alpha=0.4)
    ax1.axvspan(4, x.max(), facecolor='lightgray', alpha=0.4)
    ax1.plot(x, gt, color='blue', linestyle='--', label='Mean E[y|x]', linewidth=2)
    ax1.plot(x, gt - gt_down, color='darkorange', linestyle='--', label='Lower bound GT', linewidth=2)
    ax1.plot(x, gt + gt_up, color='purple', linestyle='--', label='Upper bound GT', linewidth=2)
    ax1.scatter(x_test, y_test, color='green', s=5, label='Test Data')
    ax1.plot(x_test, mean, color='red', label='Point Prediction')
    ax1.fill_between(x_test.flatten(), lower_bound, mean, color='orange', alpha=0.3, label='Lower Interval')
    ax1.fill_between(x_test.flatten(), mean, upper_bound, color='purple', alpha=0.4, label='Upper Interval')
    ax1.legend(fontsize=9)
    ax1.set_ylim(-150, 150)

    ax2 = plt.subplot(gs[1])
    ax2.axvspan(x.min(), -4, facecolor='lightgray', alpha=0.4, label='OOD Region')
    ax2.axvspan(4, x.max(), facecolor='lightgray', alpha=0.4)
    ax2.plot(x_test, uncertainty, label='Epistemic uncertainty', color='dodgerblue')
    ax2.fill_between(x_test.squeeze(), uncertainty.squeeze(), alpha=0.3, color='dodgerblue')
    ax2.legend()
    plt.tight_layout()
    plt.show()


# Call visualizations after metric evaluations
bins, bin_edges = binning(y_low_id, y_high_id, num_bins=5)
plot_binned_intervals(x_test_id, y_test_id, y_mean_id, y_low_id, y_high_id, bins)
plot_intervals(x_test, y_test, mean, lower_bound, upper_bound, uncertainty)
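As a quick sanity check of `ece_pi` above: intervals that genuinely cover about 95% of the targets in every width bin should give an ECE near zero. A synthetic check (illustrative numbers, assuming `ece_pi` is in scope):

```python
import numpy as np

rng = np.random.default_rng(0)
y = rng.normal(size=5000)
# 95% central intervals for N(0, 1), with slightly varying widths across samples
half_width = 1.96 + rng.uniform(0.0, 0.2, size=y.shape)
print(ece_pi(y, -half_width, half_width, num_bins=10))  # close to 0
```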
SPC-UQ/Image_Classification/README.md
ADDED
@@ -0,0 +1,285 @@
# Deep Deterministic Uncertainty

[](https://arxiv.org/abs/2102.11582)
[](https://pytorch.org/)
[](https://github.com/omegafragger/DDU/blob/main/LICENSE)

This repository contains the code for [*Deterministic Neural Networks with Appropriate Inductive Biases Capture Epistemic and Aleatoric Uncertainty*](https://arxiv.org/abs/2102.11582).

If the code or the paper has been useful in your research, please add a citation to our work:

```
@article{mukhoti2021deterministic,
  title={Deterministic Neural Networks with Appropriate Inductive Biases Capture Epistemic and Aleatoric Uncertainty},
  author={Mukhoti, Jishnu and Kirsch, Andreas and van Amersfoort, Joost and Torr, Philip HS and Gal, Yarin},
  journal={arXiv preprint arXiv:2102.11582},
  year={2021}
}
```

## Dependencies

The code is based on PyTorch and requires a few further dependencies, listed in [environment.yml](environment.yml). It should work with newer versions as well.


## OoD Detection

### Datasets

For OoD detection, you can train on [*CIFAR-10/100*](https://www.cs.toronto.edu/~kriz/cifar.html). You can also train on [*Dirty-MNIST*](https://blackhc.github.io/ddu_dirty_mnist/) by downloading *Ambiguous-MNIST* (```amnist_labels.pt``` and ```amnist_samples.pt```) from [here](https://github.com/BlackHC/ddu_dirty_mnist/releases/tag/data-v0.6.0) and using the following training instructions.

### Training

In order to train a model for the OoD detection task, use the [train.py](train.py) script. Following are the main parameters for training:
```
--seed: seed for initialization
--dataset: dataset used for training (cifar10/cifar100/dirty_mnist)
--dataset-root: /path/to/amnist_labels.pt and amnist_samples.pt/ (if training on dirty-mnist)
--model: model to train (wide_resnet/vgg16/resnet18/resnet50/lenet)
-sn: whether to use spectral normalization (available for wide_resnet, vgg16 and resnets)
--coeff: Coefficient for spectral normalization
-mod: whether to use architectural modifications (leaky ReLU + average pooling in skip connections)
--save-path: path/for/saving/model/
```

As an example, in order to train a Wide-ResNet-28-10 with spectral normalization and architectural modifications on CIFAR-10, use the following:
```
python train.py \
       --seed 1 \
       --dataset cifar10 \
       --model wide_resnet \
       -sn -mod \
       --coeff 3.0
```
Similarly, to train a ResNet-18 with spectral normalization on Dirty-MNIST, use:
```
python train.py \
       --seed 1 \
       --dataset dirty-mnist \
       --dataset-root /home/user/amnist/ \
       --model resnet18 \
       -sn \
       --coeff 3.0
```

### Evaluation

To evaluate trained models, use [evaluate.py](evaluate.py). This script can evaluate and aggregate results over multiple experimental runs. For example, if the pretrained models are stored in a directory path ```/home/user/models```, store them using the following directory structure:
```
models
├── Run1
│   └── wide_resnet_1_350.model
├── Run2
│   └── wide_resnet_2_350.model
├── Run3
│   └── wide_resnet_3_350.model
├── Run4
│   └── wide_resnet_4_350.model
└── Run5
    └── wide_resnet_5_350.model
```
For an ensemble of models, store the models using the following directory structure:
```
model_ensemble
├── Run1
│   ├── wide_resnet_1_350.model
│   ├── wide_resnet_2_350.model
│   ├── wide_resnet_3_350.model
│   ├── wide_resnet_4_350.model
│   └── wide_resnet_5_350.model
├── Run2
│   ├── wide_resnet_10_350.model
│   ├── wide_resnet_6_350.model
│   ├── wide_resnet_7_350.model
│   ├── wide_resnet_8_350.model
│   └── wide_resnet_9_350.model
├── Run3
│   ├── wide_resnet_11_350.model
│   ├── wide_resnet_12_350.model
│   ├── wide_resnet_13_350.model
│   ├── wide_resnet_14_350.model
│   └── wide_resnet_15_350.model
├── Run4
│   ├── wide_resnet_16_350.model
│   ├── wide_resnet_17_350.model
│   ├── wide_resnet_18_350.model
│   ├── wide_resnet_19_350.model
│   └── wide_resnet_20_350.model
└── Run5
    ├── wide_resnet_21_350.model
    ├── wide_resnet_22_350.model
    ├── wide_resnet_23_350.model
    ├── wide_resnet_24_350.model
    └── wide_resnet_25_350.model
```
Following are the main parameters for evaluation:
```
--seed: seed used for initializing the first trained model
--dataset: dataset used for training (cifar10/cifar100)
--ood_dataset: OoD dataset to compute AUROC
--load-path: /path/to/pretrained/models/
--model: model architecture to load (wide_resnet/vgg16)
--runs: number of experimental runs
-sn: whether the model was trained using spectral normalization
--coeff: Coefficient for spectral normalization
-mod: whether the model was trained using architectural modifications
--ensemble: number of models in the ensemble
--model-type: type of model to load for evaluation (softmax/ensemble/gmm)
```
As an example, in order to evaluate a Wide-ResNet-28-10 with spectral normalization and architectural modifications on CIFAR-10 with OoD dataset as SVHN, use the following:
```
python evaluate.py \
       --seed 1 \
       --dataset cifar10 \
       --ood_dataset svhn \
       --load-path /path/to/pretrained/models/ \
       --model wide_resnet \
       --runs 5 \
       -sn -mod \
       --coeff 3.0 \
       --model-type softmax
```
Similarly, to evaluate the above model using feature density, set ```--model-type gmm```. The evaluation script assumes that the seeds of models trained in consecutive runs differ by 1. The script stores the results in a json file with the following structure:
```
{
    "mean": {
        "accuracy": mean accuracy,
        "ece": mean ECE,
        "m1_auroc": mean AUROC using log density / MI for ensembles,
        "m1_auprc": mean AUPRC using log density / MI for ensembles,
        "m2_auroc": mean AUROC using entropy / PE for ensembles,
        "m2_auprc": mean AUPRC using entropy / PE for ensembles,
        "t_ece": mean ECE (post temp scaling),
        "t_m1_auroc": mean AUROC using log density / MI for ensembles (post temp scaling),
        "t_m1_auprc": mean AUPRC using log density / MI for ensembles (post temp scaling),
        "t_m2_auroc": mean AUROC using entropy / PE for ensembles (post temp scaling),
        "t_m2_auprc": mean AUPRC using entropy / PE for ensembles (post temp scaling)
    },
    "std": {
        "accuracy": std error accuracy,
        "ece": std error ECE,
        "m1_auroc": std error AUROC using log density / MI for ensembles,
        "m1_auprc": std error AUPRC using log density / MI for ensembles,
        "m2_auroc": std error AUROC using entropy / PE for ensembles,
        "m2_auprc": std error AUPRC using entropy / PE for ensembles,
        "t_ece": std error ECE (post temp scaling),
        "t_m1_auroc": std error AUROC using log density / MI for ensembles (post temp scaling),
        "t_m1_auprc": std error AUPRC using log density / MI for ensembles (post temp scaling),
        "t_m2_auroc": std error AUROC using entropy / PE for ensembles (post temp scaling),
        "t_m2_auprc": std error AUPRC using entropy / PE for ensembles (post temp scaling)
    },
    "values": {
        "accuracy": accuracy list,
        "ece": ece list,
        "m1_auroc": AUROC list using log density / MI for ensembles,
        "m2_auroc": AUROC list using entropy / PE for ensembles,
        "t_ece": ece list (post temp scaling),
        "t_m1_auroc": AUROC list using log density / MI for ensembles (post temp scaling),
        "t_m1_auprc": AUPRC list using log density / MI for ensembles (post temp scaling),
        "t_m2_auroc": AUROC list using entropy / PE for ensembles (post temp scaling),
        "t_m2_auprc": AUPRC list using entropy / PE for ensembles (post temp scaling)
    },
    "info": {dictionary of args}
}
```

### Results

#### Dirty-MNIST

To visualise DDU's performance on Dirty-MNIST (i.e., Fig. 1 of the paper), use [fig_1_plot.ipynb](notebooks/fig_1_plot.ipynb). The notebook requires a pretrained LeNet, VGG-16 and ResNet-18 with spectral normalization trained on Dirty-MNIST and visualises the softmax entropy and feature density for Dirty-MNIST (iD) samples vs Fashion-MNIST (OoD) samples. The notebook also visualises the softmax entropies of MNIST vs Ambiguous-MNIST samples for the ResNet-18+SN model (Fig. 2 of the paper). The following figure shows the output of the notebook for the LeNet, VGG-16 and ResNet18+SN model we trained on Dirty-MNIST.

<p align="center">
  <img src="vis/dirty_mnist_vis.png" width="500" />
</p>

#### CIFAR-10 vs SVHN

The following table presents results for a Wide-ResNet-28-10 architecture trained on CIFAR-10 with SVHN as the OoD dataset. For the full set of results, refer to the [paper](https://arxiv.org/abs/2102.11582).

| Method | Aleatoric Uncertainty | Epistemic Uncertainty | Test Accuracy | Test ECE | AUROC |
| --- | --- | --- | --- | --- | --- |
| Softmax | Softmax Entropy | Softmax Entropy | 95.98±0.02 | 0.85±0.02 | 94.44±0.43 |
| [Energy-based](https://arxiv.org/abs/2010.03759) | Softmax Entropy | Softmax Density | 95.98±0.02 | 0.85±0.02 | 94.56±0.51 |
| [5-Ensemble](https://arxiv.org/abs/1612.01474) | Predictive Entropy | Predictive Entropy | 96.59±0.02 | 0.76±0.03 | 97.73±0.31 |
| DDU (ours) | Softmax Entropy | GMM Density | 95.97±0.03 | 0.85±0.04 | 98.09±0.10 |


## Active Learning

To run active learning experiments, use ```active_learning_script.py```. You can run active learning experiments on both [MNIST](http://yann.lecun.com/exdb/mnist/) as well as [Dirty-MNIST](https://blackhc.github.io/ddu_dirty_mnist/). When running with Dirty-MNIST, you will need to provide a pretrained model on Dirty-MNIST to distinguish between clean MNIST and Ambiguous-MNIST samples. The following are the main command line arguments for ```active_learning_script.py```.
```
--seed: seed used for initializing the first model (later experimental runs will have seeds incremented by 1)
--model: model architecture to train (resnet18)
-ambiguous: whether to use ambiguous MNIST during training. If this is set to True, the models will be trained on Dirty-MNIST, otherwise they will train on MNIST.
--dataset-root: /path/to/amnist_labels.pt and amnist_samples.pt/
--trained-model: model architecture of pretrained model to distinguish clean and ambiguous MNIST samples
-tsn: if pretrained model has been trained using spectral normalization
--tcoeff: coefficient of spectral normalization used on pretrained model
-tmod: if pretrained model has been trained using architectural modifications (leaky ReLU and average pooling on skip connections)
--saved-model-path: /path/to/saved/pretrained/model/
--saved-model-name: name of the saved pretrained model file
--threshold: Threshold of softmax entropy to decide if a sample is ambiguous (samples having higher softmax entropy than threshold will be considered ambiguous)
--subsample: number of clean MNIST samples to use to subsample clean MNIST
-sn: whether to use spectral normalization during training
--coeff: coefficient of spectral normalization during training
-mod: whether to use architectural modifications (leaky ReLU and average pooling on skip connections) during training
--al-type: type of active learning acquisition model (softmax/ensemble/gmm)
-mi: whether to use mutual information for ensemble al-type
--num-initial-samples: number of initial samples in the training set
--max-training-samples: maximum number of training samples
--acquisition-batch-size: batch size for each acquisition step
```

As an example, to run the active learning experiment on MNIST using the DDU method, use:
```
python active_learning_script.py \
       --seed 1 \
       --model resnet18 \
       -sn -mod \
       --al-type gmm
```
Similarly, to run the active learning experiment on Dirty-MNIST using the DDU baseline, with a pretrained ResNet-18 with SN to distinguish clean and ambiguous MNIST samples, use the following:
```
python active_learning_script.py \
       --seed 1 \
       --model resnet18 \
       -sn -mod \
       -ambiguous \
       --dataset-root /home/user/amnist/ \
       --trained-model resnet18 \
       -tsn \
       --saved-model-path /path/to/pretrained/model \
       --saved-model-name resnet18_sn_3.0_1_350.model \
       --threshold 1.0 \
       --subsample 1000 \
       --al-type gmm
```

### Results

The active learning script stores all results in json files. The MNIST test set accuracy is stored in a json file with the following structure:
```
{
    "experiment run": list of MNIST test set accuracies one per acquisition step
}
```
When using ambiguous samples in the pool set, the script also stores the fraction of ambiguous samples acquired in each step in the following json:
```
{
    "experiment run": list of fractions of ambiguous samples in the acquired training set
}
```

### Visualisation

To visualise results from the above json files, use the [al_plot.ipynb](notebooks/al_plot.ipynb) notebook. The following diagram shows the performance of different baselines (softmax, ensemble PE, ensemble MI and DDU) on MNIST and Dirty-MNIST.

<p align="center">
  <img src="vis/al_plots.png" width="700" />
</p>


## Questions

For any questions, please feel free to raise an issue or email us directly. Our emails can be found on the [paper](https://arxiv.org/abs/2102.11582).
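Since evaluate.py writes the mean/std/values structure documented in the README above, aggregating several result files into a table takes only a few lines of Python. A minimal sketch (the results directory and the ± formatting are assumptions, not part of the repo):

```python
import json
from pathlib import Path

def summarize(results_dir="./results"):
    """Tabulate accuracy, ECE and AUROC from evaluate.py-style json files."""
    for path in sorted(Path(results_dir).glob("*.json")):
        res = json.loads(path.read_text())
        mean, std = res["mean"], res["std"]
        print(f"{path.name}: "
              f"acc={mean['accuracy']:.2f}±{std['accuracy']:.2f}  "
              f"ece={mean['ece']:.2f}±{std['ece']:.2f}  "
              f"auroc={mean['m1_auroc']:.2f}±{std['m1_auroc']:.2f}")

summarize()
```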
SPC-UQ/Image_Classification/data/__init__.py
ADDED
File without changes
SPC-UQ/Image_Classification/data/ood_detection/__init__.py
ADDED
File without changes
SPC-UQ/Image_Classification/data/ood_detection/cifar10.py
ADDED
@@ -0,0 +1,107 @@
"""
Create train, valid, test iterators for CIFAR-10.
Train set size: 45000
Val set size: 5000
Test set size: 10000
"""

import torch
import numpy as np
from torch.utils.data import Subset

from torchvision import datasets
from torchvision import transforms


def get_train_valid_loader(batch_size, augment, val_seed, imagesize=128, val_size=0.1, num_workers=1, pin_memory=False, **kwargs):
    """
    Utility function for loading and returning train and valid
    multi-process iterators over the CIFAR-10 dataset.
    Params
    ------
    - batch_size: how many samples per batch to load.
    - augment: whether to apply the data augmentation scheme
      mentioned in the paper. Only applied on the train split.
    - val_seed: fix seed for reproducibility.
    - val_size: percentage split of the training set used for
      the validation set. Should be a float in the range [0, 1].
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.

    Returns
    -------
    - train_loader: training set iterator.
    - valid_loader: validation set iterator.
    """
    error_msg = "[!] val_size should be in the range [0, 1]."
    assert (val_size >= 0) and (val_size <= 1), error_msg

    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])

    # define transforms
    valid_transform = transforms.Compose([transforms.Resize(imagesize), transforms.ToTensor(), normalize])

    if augment:
        train_transform = transforms.Compose(
            [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.Resize(imagesize), transforms.ToTensor(), normalize]
        )
    else:
        train_transform = transforms.Compose([transforms.Resize(imagesize), transforms.ToTensor(), normalize])

    # load the dataset twice so augmentation only applies to the train subset
    data_dir = "./data"
    train_dataset = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform)

    valid_dataset = datasets.CIFAR10(root=data_dir, train=True, download=False, transform=valid_transform)

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(val_size * num_train))

    np.random.seed(val_seed)
    np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]

    train_subset = Subset(train_dataset, train_idx)
    valid_subset = Subset(valid_dataset, valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_subset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_subset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=False,
    )

    return (train_loader, valid_loader)


def get_test_loader(batch_size, imagesize=128, num_workers=1, pin_memory=False, **kwargs):
    """
    Utility function for loading and returning a multi-process
    test iterator over the CIFAR-10 dataset.
    If using CUDA, num_workers should be set to 1 and pin_memory to True.
    Params
    ------
    - batch_size: how many samples per batch to load.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.
    Returns
    -------
    - data_loader: test set iterator.
    """
    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])

    # define transform
    transform = transforms.Compose([transforms.Resize(imagesize), transforms.ToTensor(), normalize])

    data_dir = "./data"
    dataset = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory,
    )

    return data_loader
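A short smoke test for the two loaders above (batch size and image size are arbitrary choices, assuming the functions are imported from this module):

```python
train_loader, valid_loader = get_train_valid_loader(
    batch_size=128, augment=True, val_seed=1, imagesize=32, val_size=0.1
)
test_loader = get_test_loader(batch_size=128, imagesize=32)

images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # torch.Size([128, 3, 32, 32]) torch.Size([128])
```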
SPC-UQ/Image_Classification/data/ood_detection/cifar100.py
ADDED
@@ -0,0 +1,107 @@
"""
Create train, valid, test iterators for CIFAR-100.
Train set size: 45000
Val set size: 5000
Test set size: 10000
"""

import torch
import numpy as np
from torch.utils.data import Subset

from torchvision import datasets
from torchvision import transforms


def get_train_valid_loader(batch_size, augment, val_seed, imagesize=128, val_size=0.1, num_workers=1, pin_memory=False, **kwargs):
    """
    Utility function for loading and returning train and valid
    multi-process iterators over the CIFAR-100 dataset.
    Params
    ------
    - batch_size: how many samples per batch to load.
    - augment: whether to apply the data augmentation scheme
      mentioned in the paper. Only applied on the train split.
    - val_seed: fix seed for reproducibility.
    - val_size: percentage split of the training set used for
      the validation set. Should be a float in the range [0, 1].
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.

    Returns
    -------
    - train_loader: training set iterator.
    - valid_loader: validation set iterator.
    """
    error_msg = "[!] val_size should be in the range [0, 1]."
    assert (val_size >= 0) and (val_size <= 1), error_msg

    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])

    # define transforms
    valid_transform = transforms.Compose([transforms.Resize(imagesize), transforms.ToTensor(), normalize])

    if augment:
        train_transform = transforms.Compose(
            [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.Resize(imagesize), transforms.ToTensor(), normalize]
        )
    else:
        train_transform = transforms.Compose([transforms.Resize(imagesize), transforms.ToTensor(), normalize])

    # load the dataset twice so augmentation only applies to the train subset
    data_dir = "./data"
    train_dataset = datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform)

    valid_dataset = datasets.CIFAR100(root=data_dir, train=True, download=False, transform=valid_transform)

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(val_size * num_train))

    np.random.seed(val_seed)
    np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]

    train_subset = Subset(train_dataset, train_idx)
    valid_subset = Subset(valid_dataset, valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_subset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_subset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=False,
    )

    return (train_loader, valid_loader)


def get_test_loader(batch_size, imagesize=128, num_workers=1, pin_memory=False, **kwargs):
    """
    Utility function for loading and returning a multi-process
    test iterator over the CIFAR-100 dataset.
    If using CUDA, num_workers should be set to 1 and pin_memory to True.
    Params
    ------
    - batch_size: how many samples per batch to load.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.
    Returns
    -------
    - data_loader: test set iterator.
    """
    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])

    # define transform
    transform = transforms.Compose([transforms.Resize(imagesize), transforms.ToTensor(), normalize])

    data_dir = "./data"
    dataset = datasets.CIFAR100(root=data_dir, train=False, download=True, transform=transform)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory,
    )

    return data_loader
SPC-UQ/Image_Classification/data/ood_detection/imagenet.py
ADDED
@@ -0,0 +1,85 @@
"""
Create train, valid, test iterators for ImageNet.
Train set size: user-defined
Val set size: user-defined
Test set size: user-defined (if available)
"""

import torch
import numpy as np
from torch.utils.data import Subset
from torchvision import datasets, transforms


def get_train_valid_loader(batch_size, augment, val_seed, imagesize=224, val_size=0.1, num_workers=1, pin_memory=False, **kwargs):
    assert 0 <= val_size <= 1, "[!] val_size should be in the range [0, 1]."

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    imagesize = 224  # NOTE: forces 224px crops, overriding the imagesize argument
    # Define transformations
    valid_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(imagesize),
        transforms.ToTensor(),
        normalize
    ])
    if augment:
        # (currently identical to valid_transform: no random augmentation is applied)
        train_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])
    else:
        train_transform = valid_transform

    data_dir = "./data/Imagenet1K"
    # Load the dataset (the validation split is carved out of the train folder)
    train_dataset = datasets.ImageFolder(root=f"{data_dir}/train", transform=train_transform)
    valid_dataset = datasets.ImageFolder(root=f"{data_dir}/train", transform=valid_transform)

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(val_size * num_train))

    np.random.seed(val_seed)
    np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]

    train_subset = Subset(train_dataset, train_idx)
    valid_subset = Subset(valid_dataset, valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_subset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_subset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=False,
    )

    return train_loader, valid_loader


def get_test_loader(batch_size, imagesize=224, num_workers=1, pin_memory=False, **kwargs):

    # Define transformation
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])
    data_dir = "./data/Imagenet1K"
    dataset = datasets.ImageFolder(root=f"{data_dir}/val", transform=transform)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory,
    )

    return data_loader
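A short sketch of how the two entry points above fit together, assuming the hardcoded ./data/Imagenet1K/{train,val} folder layout exists locally:

import data.ood_detection.imagenet as imagenet

# 10% of the train folder becomes the validation split, shuffled by val_seed.
train_loader, valid_loader = imagenet.get_train_valid_loader(
    batch_size=64, augment=False, val_seed=1, val_size=0.1, pin_memory=True
)
# The ImageNet "val" folder serves as the test set.
test_loader = imagenet.get_test_loader(batch_size=64, pin_memory=True)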
SPC-UQ/Image_Classification/data/ood_detection/imagenet_a.py
ADDED
@@ -0,0 +1,37 @@
"""
Create a test iterator for ImageNet-A (natural adversarial examples for ImageNet classifiers).
"""

import torch
import numpy as np
from torch.utils.data import Subset
from torchvision import datasets, transforms


def get_test_loader(batch_size, imagesize=224, num_workers=1, pin_memory=False, **kwargs):
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(imagesize),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    root = "./data/imagenet-a"
    dataset = datasets.ImageFolder(
        root=root,
        transform=transform
    )

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin_memory
    )

    return data_loader
SPC-UQ/Image_Classification/data/ood_detection/imagenet_o.py
ADDED
@@ -0,0 +1,37 @@
"""
Create a test iterator for ImageNet-O (out-of-distribution examples for ImageNet classifiers).
"""

import torch
import numpy as np
from torch.utils.data import Subset
from torchvision import datasets, transforms


def get_test_loader(batch_size, imagesize=224, num_workers=1, pin_memory=False, **kwargs):
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(imagesize),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    root = "./data/imagenet-o"
    dataset = datasets.ImageFolder(
        root=root,
        transform=transform
    )

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin_memory
    )

    return data_loader
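Both modules above are plain ImageFolder test iterators over the standard ImageNet normalization; evaluate.py below uses ImageNet-A as the adversarial set and ImageNet-O as the OOD set for ImageNet models. A sketch, assuming both dataset folders have been downloaded:

import data.ood_detection.imagenet_a as imagenet_a
import data.ood_detection.imagenet_o as imagenet_o

adv_loader = imagenet_a.get_test_loader(batch_size=64, imagesize=224)
ood_loader = imagenet_o.get_test_loader(batch_size=64, imagesize=224)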
SPC-UQ/Image_Classification/data/ood_detection/ood_union.py
ADDED
@@ -0,0 +1,105 @@
import os
import torch
from torch.utils.data import Subset, ConcatDataset, DataLoader
import random

from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder


def get_svhn_test_loader(batch_size, imagesize=128, num_workers=1, pin_memory=False, **kwargs):
    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010],)

    # define transform
    transform = transforms.Compose([transforms.Resize(imagesize), transforms.ToTensor(), normalize,])

    data_dir = "./data"
    dataset = datasets.SVHN(root=data_dir, split="test", download=True, transform=transform,)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory,
    )

    return data_loader


def get_tinyimagenet_test_loader(batch_size, imagesize=128, num_workers=1, pin_memory=False, **kwargs):

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    transform = transforms.Compose([
        transforms.Resize(imagesize),
        transforms.ToTensor(),
        normalize
    ])

    data_dir = "./data/tinyimagenet"
    test_dir = os.path.join(data_dir, "test")
    dataset = ImageFolder(root=test_dir, transform=transform)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory
    )
    return data_loader


def get_cifar10_test_loader(batch_size, imagesize=128, num_workers=1, pin_memory=False, **kwargs):

    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010],)

    # define transform
    transform = transforms.Compose([transforms.Resize(imagesize), transforms.ToTensor(), normalize,])

    data_dir = "./data"
    dataset = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform,)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory,
    )

    return data_loader


def get_cifar100_test_loader(batch_size, imagesize=128, num_workers=1, pin_memory=False, **kwargs):

    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010],)

    # define transform
    transform = transforms.Compose([transforms.Resize(imagesize), transforms.ToTensor(), normalize,])

    data_dir = "./data"
    dataset = datasets.CIFAR100(root=data_dir, train=False, download=True, transform=transform,)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory,
    )

    return data_loader


def get_combined_ood_test_loader(batch_size, sample_seed, imagesize=128, num_workers=1, pin_memory=False, sample_size=10000, **kwargs):
    # NOTE: each source dataset keeps its own normalization statistics
    # (SVHN uses CIFAR-style stats, Tiny-ImageNet uses ImageNet stats).
    svhn_ds = get_svhn_test_loader(batch_size=1, imagesize=imagesize).dataset
    tiny_ds = get_tinyimagenet_test_loader(batch_size=1, imagesize=imagesize).dataset

    combined_dataset = ConcatDataset([
        svhn_ds,
        tiny_ds
    ])

    # print(len(combined_dataset))

    random.seed(sample_seed)
    if sample_size is not None and sample_size < len(combined_dataset):
        indices = random.sample(range(len(combined_dataset)), sample_size)
        combined_dataset = Subset(combined_dataset, indices)

    data_loader = DataLoader(
        combined_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return data_loader
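Usage sketch for the combined OOD set above; sub-sampling is reproducible via sample_seed, and passing sample_size=None keeps the full union:

from data.ood_detection.ood_union import get_combined_ood_test_loader

# SVHN test + Tiny-ImageNet test, randomly sub-sampled to 10,000 images.
ood_loader = get_combined_ood_test_loader(
    batch_size=128, sample_seed=1, imagesize=32, sample_size=10000
)
print(len(ood_loader.dataset))  # at most 10000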
SPC-UQ/Image_Classification/data/ood_detection/svhn.py
ADDED
@@ -0,0 +1,94 @@
import os
import torch
import numpy as np
from torch.utils.data import Subset

from torchvision import datasets
from torchvision import transforms


def get_train_valid_loader(batch_size, augment, val_seed, imagesize=128, val_size=0.1, num_workers=1, pin_memory=False, **kwargs):
    """
    Utility function for loading and returning train and valid
    multi-process iterators over the SVHN dataset.
    Params
    ------
    - batch_size: how many samples per batch to load.
    - augment: whether to apply the data augmentation scheme
      mentioned in the paper. Only applied on the train split.
    - val_seed: fix seed for reproducibility.
    - val_size: percentage split of the training set used for
      the validation set. Should be a float in the range [0, 1].
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.

    Returns
    -------
    - train_loader: training set iterator.
    - valid_loader: validation set iterator.
    """
    error_msg = "[!] val_size should be in the range [0, 1]."
    assert (val_size >= 0) and (val_size <= 1), error_msg

    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010],)

    # define transforms
    valid_transform = transforms.Compose([transforms.Resize(imagesize), transforms.ToTensor(), normalize,])

    # load the dataset
    data_dir = "./data"
    train_dataset = datasets.SVHN(root=data_dir, split="train", download=True, transform=valid_transform,)

    valid_dataset = datasets.SVHN(root=data_dir, split="train", download=True, transform=valid_transform,)

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(val_size * num_train))

    np.random.seed(val_seed)
    np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_subset = Subset(train_dataset, train_idx)
    valid_subset = Subset(valid_dataset, valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_subset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_subset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=False,
    )

    return (train_loader, valid_loader)


def get_test_loader(batch_size, imagesize=128, num_workers=1, pin_memory=False, **kwargs):
    """
    Utility function for loading and returning a multi-process
    test iterator over the SVHN dataset.
    If using CUDA, num_workers should be set to 1 and pin_memory to True.
    Params
    ------
    - batch_size: how many samples per batch to load.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.
    Returns
    -------
    - data_loader: test set iterator.
    """
    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010],)

    # define transform
    transform = transforms.Compose([transforms.Resize(imagesize), transforms.ToTensor(), normalize,])

    data_dir = "./data"
    dataset = datasets.SVHN(root=data_dir, split="test", download=True, transform=transform,)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory,
    )

    return data_loader
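A sketch of the train/valid split above; with val_size=0.1 the validation set is the first 10% of a val_seed-shuffled index list, so the split is reproducible across runs:

import data.ood_detection.svhn as svhn

train_loader, valid_loader = svhn.get_train_valid_loader(
    batch_size=128, augment=False, val_seed=1, imagesize=32, val_size=0.1
)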
SPC-UQ/Image_Classification/data/ood_detection/tinyimagenet.py
ADDED
@@ -0,0 +1,115 @@
"""
Create train, valid, test iterators for Tiny-ImageNet.
Train set size: 90000 (450 per class)
Val set size: 10000 (50 per class)
Test set size: 10000 (no labels)
"""

import torch
import numpy as np
import os
from torch.utils.data import Subset
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder


def get_train_valid_loader(batch_size, augment, val_seed, imagesize=128, val_size=0.1, num_workers=1, pin_memory=False, **kwargs):
    """
    Load and return train and valid iterators over the Tiny-ImageNet dataset.

    Params
    ------
    - batch_size: number of samples per batch.
    - augment: whether to apply data augmentation.
    - val_seed: random seed for reproducibility.
    - val_size: fraction of the training set used for validation (0 to 1).
    - num_workers: number of subprocesses for data loading.
    - pin_memory: set to True if using GPU.

    Returns
    -------
    - train_loader: training set iterator.
    - valid_loader: validation set iterator.
    """
    assert 0 <= val_size <= 1, "[!] val_size should be in the range [0, 1]."

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Define transforms
    valid_transform = transforms.Compose([
        transforms.Resize(imagesize),
        transforms.ToTensor(),
        normalize
    ])

    if augment:
        # NOTE: RandomCrop(256, ...) requires images at least 256px per side;
        # standard Tiny-ImageNet images are 64x64, so this branch appears to
        # assume pre-upscaled data.
        train_transform = transforms.Compose([
            transforms.RandomCrop(256, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.Resize(imagesize),
            transforms.ToTensor(),
            normalize
        ])
    else:
        train_transform = valid_transform

    # Load dataset
    data_dir = "./data/tinyimagenet"
    train_dir = os.path.join(data_dir, "train")
    train_dataset = ImageFolder(root=train_dir, transform=train_transform)
    valid_dataset = ImageFolder(root=train_dir, transform=valid_transform)  # Same dataset, different transform

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(val_size * num_train))

    np.random.seed(val_seed)
    np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]

    train_subset = Subset(train_dataset, train_idx)
    valid_subset = Subset(valid_dataset, valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_subset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=True
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_subset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=False
    )

    return train_loader, valid_loader


def get_test_loader(batch_size, imagesize=128, num_workers=1, pin_memory=False, **kwargs):
    """
    Load and return a test iterator over the Tiny-ImageNet dataset.

    Params
    ------
    - batch_size: number of samples per batch.
    - num_workers: number of subprocesses for data loading.
    - pin_memory: set to True if using GPU.

    Returns
    -------
    - data_loader: test set iterator.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    transform = transforms.Compose([
        transforms.Resize(imagesize),
        transforms.ToTensor(),
        normalize
    ])

    data_dir = "./data/tinyimagenet"
    test_dir = os.path.join(data_dir, "test")
    dataset = ImageFolder(root=test_dir, transform=transform)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory
    )
    return data_loader
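Usage sketch, assuming ./data/tinyimagenet/{train,test} in ImageFolder layout (augment=False avoids the RandomCrop(256) branch flagged above):

import data.ood_detection.tinyimagenet as tinyimagenet

train_loader, valid_loader = tinyimagenet.get_train_valid_loader(
    batch_size=128, augment=False, val_seed=1, imagesize=64, val_size=0.1
)
test_loader = tinyimagenet.get_test_loader(batch_size=128, imagesize=64)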
SPC-UQ/Image_Classification/environment.yml
ADDED
@@ -0,0 +1,16 @@
name: DDU
channels:
  - pytorch
  - defaults
dependencies:
  - python
  - pytorch=1.7.1
  - torchvision=0.8.2
  - cudatoolkit=10.1
  - tqdm
  - tensorboard
  - numpy
  - scipy
  - matplotlib
  - seaborn
  - scikit-learn
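The pinned stack (PyTorch 1.7.1, torchvision 0.8.2, CUDA 10.1) can be materialized with "conda env create -f environment.yml" followed by "conda activate DDU"; newer GPUs may need a different cudatoolkit pin.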
SPC-UQ/Image_Classification/evaluate.py
ADDED
@@ -0,0 +1,1427 @@
"""
Script to evaluate a single model.
"""
import os
import gc
import json
import math
import torch
import argparse
import torch.backends.cudnn as cudnn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score

# Import dataloaders
import data.ood_detection.cifar10 as cifar10
import data.ood_detection.cifar100 as cifar100
import data.ood_detection.svhn as svhn
import data.ood_detection.imagenet as imagenet
import data.ood_detection.tinyimagenet as tinyimagenet
import data.ood_detection.imagenet_o as imagenet_o
import data.ood_detection.imagenet_a as imagenet_a
import data.ood_detection.ood_union as ood_union

# Import network models
from net.resnet import resnet50
from net.resnet_edl import resnet50_edl
from net.wide_resnet import wrn
from net.wide_resnet_edl import wrn_edl
from net.wide_resnet_uq import wrn_uq
from net.vgg import vgg16
from net.vgg_edl import vgg16_edl
from net.vgg_uq import vgg16_uq
from net.imagenet_wide import imagenet_wide
from net.imagenet_vgg import imagenet_vgg16
from net.imagenet_vit import imagenet_vit

# Import metrics to compute
from metrics.classification_metrics import (
    test_classification_net,
    test_classification_net_logits,
    test_classification_uq,
    test_classification_net_ensemble,
    test_classification_net_edl,
    create_adversarial_dataloader,
    test_classification_net_logits_edl,
    test_classification_net_softmax
)
from metrics.calibration_metrics import expected_calibration_error
from metrics.uncertainty_confidence import entropy, logsumexp, self_consistency, edl_unc, certificate
from metrics.ood_metrics import get_roc_auc, get_roc_auc_logits, get_roc_auc_ensemble, get_unc_ensemble, get_roc_auc_uncs
from metrics.classification_metrics import get_logits_labels
from metrics.classification_metrics import get_logits_labels_uq

# Import GMM utils
from utils.gmm_utils import get_embeddings, gmm_evaluate, gmm_fit
from utils.ensemble_utils import load_ensemble, Ensemble_fit, Ensemble_evaluate, Ensemble_load
from utils.oc_utils import oc_fit, oc_evaluate
from utils.eval_utils import model_load_name
from utils.train_utils import model_save_name
from utils.args import eval_args

# Import SPC utils
from utils.spc_utils import SPC_fit, SPC_load, SPC_evaluate

# Import EDL utils
from utils.edl_utils import EDL_fit, EDL_load, EDL_evaluate

# Temperature scaling
from utils.temperature_scaling import ModelWithTemperature

# Dataset params
dataset_num_classes = {"cifar10": 10, "cifar100": 100, "svhn": 10, "imagenet": 1000, "tinyimagenet": 200, "imagenet_o": 200, "imagenet_a": 200}

dataset_loader = {"cifar10": cifar10, "cifar100": cifar100, "svhn": svhn, "imagenet": imagenet, "tinyimagenet": tinyimagenet, "imagenet_o": imagenet_o, "imagenet_a": imagenet_a}

# Mapping model name to model function
models = {"resnet50": resnet50, "resnet50_edl": resnet50_edl, "wide_resnet": wrn, "wide_resnet_edl": wrn_edl, "wide_resnet_uq": wrn_uq, "vgg16": vgg16, "vgg16_edl": vgg16_edl, "vgg16_uq": vgg16_uq, "imagenet_wide": imagenet_wide, "imagenet_vgg16": imagenet_vgg16, "imagenet_vit": imagenet_vit}

model_to_num_dim = {"resnet50": 2048, "resnet50_edl": 2048, "wide_resnet": 640, "wide_resnet_edl": 640, "wide_resnet_uq": 640, "vgg16": 512, "vgg16_edl": 512, "vgg16_uq": 512, "imagenet_wide": 2048, "imagenet_vgg16": 4096, "imagenet_vit": 768}

model_to_input_dim = {"resnet50": 32, "resnet50_edl": 32, "wide_resnet": 32, "wide_resnet_edl": 32, "wide_resnet_uq": 32, "vgg16": 32, "vgg16_edl": 32, "vgg16_uq": 32, "imagenet_wide": 224, "imagenet_vgg16": 224, "imagenet_vit": 224}

model_to_last_layer = {"resnet50": "module.fc", "wide_resnet": "module.linear", "vgg16": "module.classifier", "imagenet_wide": "module.linear", "imagenet_vgg16": "module.classifier", "imagenet_vit": "module.linear"}

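# Illustrative note: the mapping dicts above drive all per-run wiring. For a
# hypothetical invocation selecting model "wide_resnet" on dataset "cifar10"
# (the exact CLI flag names live in utils.args.eval_args, not shown here):
#   dataset_num_classes["cifar10"]    -> 10 output classes
#   model_to_input_dim["wide_resnet"] -> 32x32 inputs requested from the loaders
#   model_to_num_dim["wide_resnet"]   -> 640-d embeddings for the UQ fits
#   dataset_loader["cifar10"]         -> module providing get_test_loader(...)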
| 86 |
+
if __name__ == "__main__":
|
| 87 |
+
|
| 88 |
+
args = eval_args().parse_args()
|
| 89 |
+
|
| 90 |
+
# Checking if GPU is available
|
| 91 |
+
cuda = torch.cuda.is_available()
|
| 92 |
+
|
| 93 |
+
# Setting additional parameters
|
| 94 |
+
print("Parsed args", args)
|
| 95 |
+
print("Seed: ", args.seed)
|
| 96 |
+
torch.manual_seed(args.seed)
|
| 97 |
+
device = torch.device("cuda" if cuda else "cpu")
|
| 98 |
+
|
| 99 |
+
# Taking input for the dataset
|
| 100 |
+
num_classes = dataset_num_classes[args.dataset]
|
| 101 |
+
|
| 102 |
+
test_loader = dataset_loader[args.dataset].get_test_loader(batch_size=args.batch_size, imagesize=model_to_input_dim[args.model], pin_memory=args.gpu)
|
| 103 |
+
|
| 104 |
+
if args.ood_dataset=='ood_union':
|
| 105 |
+
ood_test_loader = ood_union.get_combined_ood_test_loader(batch_size=args.batch_size, sample_seed=args.seed, imagesize=model_to_input_dim[args.model], pin_memory=args.gpu)
|
| 106 |
+
else:
|
| 107 |
+
ood_test_loader = dataset_loader[args.ood_dataset].get_test_loader(batch_size=args.batch_size,
|
| 108 |
+
imagesize=model_to_input_dim[args.model],
|
| 109 |
+
pin_memory=args.gpu)
|
| 110 |
+
|
| 111 |
+
# Evaluating the models
|
| 112 |
+
accuracies = []
|
| 113 |
+
c_accuracies = []
|
| 114 |
+
|
| 115 |
+
# Pre temperature scaling
|
| 116 |
+
# m1 - Uncertainty/Confidence Metric 1
|
| 117 |
+
# for deterministic model: logsumexp, for ensemble: entropy
|
| 118 |
+
# m2 - Uncertainty/Confidence Metric 2
|
| 119 |
+
# for deterministic model: entropy, for ensemble: MI
|
| 120 |
+
eces = []
|
| 121 |
+
ood_m1_aurocs = []
|
| 122 |
+
ood_m1_auprcs = []
|
| 123 |
+
ood_m2_aurocs = []
|
| 124 |
+
ood_m2_auprcs = []
|
| 125 |
+
|
| 126 |
+
err_m1_aurocs = []
|
| 127 |
+
err_m1_auprcs = []
|
| 128 |
+
err_m2_aurocs = []
|
| 129 |
+
err_m2_auprcs = []
|
| 130 |
+
|
| 131 |
+
adv_ep = 0.02
|
| 132 |
+
adv_m1_aurocs = []
|
| 133 |
+
adv_m1_auprcs = []
|
| 134 |
+
adv_m2_aurocs = []
|
| 135 |
+
adv_m2_auprcs = []
|
| 136 |
+
|
| 137 |
+
# Post temperature scaling
|
| 138 |
+
t_eces = []
|
| 139 |
+
t_m1_aurocs = []
|
| 140 |
+
t_m1_auprcs = []
|
| 141 |
+
t_m2_aurocs = []
|
| 142 |
+
t_m2_auprcs = []
|
| 143 |
+
|
| 144 |
+
c_eces = []
|
| 145 |
+
|
| 146 |
+
adv_unc = np.zeros((args.runs, 9))
|
| 147 |
+
adv_acc = np.zeros((args.runs, 9))
|
| 148 |
+
|
| 149 |
+
topt = None
|
| 150 |
+
|
| 151 |
+
for i in range(args.runs):
|
| 152 |
+
print (f"Evaluating run: {(i+1)}")
|
| 153 |
+
# Loading the model(s)
|
| 154 |
+
if args.model_type == "ensemble":
|
| 155 |
+
if args.dataset == 'imagenet':
|
| 156 |
+
train_loader, val_loader = dataset_loader[args.dataset].get_train_valid_loader(
|
| 157 |
+
batch_size=args.batch_size, imagesize=model_to_input_dim[args.model], augment=args.data_aug,
|
| 158 |
+
val_seed=(args.seed + i), val_size=args.val_size, pin_memory=args.gpu)
|
| 159 |
+
net = models[args.model](pretrained=True, num_classes=1000).cuda()
|
| 160 |
+
if args.gpu:
|
| 161 |
+
net.cuda()
|
| 162 |
+
net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
|
| 163 |
+
cudnn.benchmark = True
|
| 164 |
+
net.eval()
|
| 165 |
+
|
| 166 |
+
else:
|
| 167 |
+
val_loaders = []
|
| 168 |
+
for j in range(args.ensemble):
|
| 169 |
+
train_loader, val_loader = dataset_loader[args.dataset].get_train_valid_loader(
|
| 170 |
+
batch_size=args.batch_size, imagesize=model_to_input_dim[args.model], augment=args.data_aug, val_seed=(args.seed+(5*i)+j), val_size=0.1, pin_memory=args.gpu,
|
| 171 |
+
)
|
| 172 |
+
val_loaders.append(val_loader)
|
| 173 |
+
# Evaluate an ensemble
|
| 174 |
+
ensemble_loc = args.load_loc
|
| 175 |
+
net_ensemble = load_ensemble(
|
| 176 |
+
ensemble_loc=ensemble_loc,
|
| 177 |
+
model_name=args.model,
|
| 178 |
+
device=device,
|
| 179 |
+
num_classes=num_classes,
|
| 180 |
+
spectral_normalization=args.sn,
|
| 181 |
+
mod=args.mod,
|
| 182 |
+
coeff=args.coeff,
|
| 183 |
+
seed=(i)
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
else:
|
| 187 |
+
train_loader, val_loader = dataset_loader[args.dataset].get_train_valid_loader(
|
| 188 |
+
batch_size=args.batch_size, imagesize=model_to_input_dim[args.model], augment=args.data_aug, val_seed=(args.seed+i), val_size=args.val_size, pin_memory=args.gpu,
|
| 189 |
+
)
|
| 190 |
+
if args.dataset == 'imagenet':
|
| 191 |
+
net = models[args.model](pretrained=True, num_classes=1000).cuda()
|
| 192 |
+
if args.gpu:
|
| 193 |
+
net.cuda()
|
| 194 |
+
net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
|
| 195 |
+
cudnn.benchmark = True
|
| 196 |
+
net.eval()
|
| 197 |
+
|
| 198 |
+
else:
|
| 199 |
+
if args.val_size==0.1 or (not args.crossval):
|
| 200 |
+
saved_model_name = os.path.join(
|
| 201 |
+
args.load_loc,
|
| 202 |
+
"Run" + str(i + 1),
|
| 203 |
+
model_load_name(args.model, args.sn, args.mod, args.coeff, args.seed, i) + "_350.model",
|
| 204 |
+
)
|
| 205 |
+
else:
|
| 206 |
+
saved_model_name = os.path.join(
|
| 207 |
+
args.load_loc,
|
| 208 |
+
"Run" + str(i + 1),
|
| 209 |
+
model_load_name(args.model, args.sn, args.mod, args.coeff, args.seed, i) + "_350_0"+str(int(args.val_size*10))+".model",
|
| 210 |
+
)
|
| 211 |
+
print(saved_model_name)
|
| 212 |
+
net = models[args.model](
|
| 213 |
+
spectral_normalization=args.sn, mod=args.mod, coeff=args.coeff, num_classes=num_classes, temp=1.0,
|
| 214 |
+
)
|
| 215 |
+
if args.gpu:
|
| 216 |
+
net.cuda()
|
| 217 |
+
net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
|
| 218 |
+
cudnn.benchmark = True
|
| 219 |
+
net.load_state_dict(torch.load(str(saved_model_name)))
|
| 220 |
+
net.eval()
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
# Evaluating the UQ method
|
| 224 |
+
if args.model_type == "ensemble":
|
| 225 |
+
if args.dataset == 'imagenet':
|
| 226 |
+
ensemble_model_path = os.path.join(
|
| 227 |
+
args.load_loc,
|
| 228 |
+
"Run" + str(i + 1),
|
| 229 |
+
model_load_name(args.model, args.sn, args.mod, args.coeff, args.seed, i) + "_350_ensemble_model.pth",
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
if os.path.exists(ensemble_model_path):
|
| 233 |
+
print(f"Loading existing ensemble_model from {ensemble_model_path}")
|
| 234 |
+
Ensemble_model = Ensemble_load(ensemble_model_path, model_to_num_dim[args.model],
|
| 235 |
+
num_classes, device)
|
| 236 |
+
else:
|
| 237 |
+
if args.model == 'imagenet_vgg16':
|
| 238 |
+
embed_path = 'data/imagenet_train_vgg_embedding.pt'
|
| 239 |
+
# embed_path = 'data/imagenet_val_vgg_embedding.pt'
|
| 240 |
+
if args.model == 'imagenet_wide':
|
| 241 |
+
embed_path = 'data/imagenet_train_wide_embedding.pt'
|
| 242 |
+
# embed_path = 'data/imagenet_val_wide_embedding.pt'
|
| 243 |
+
if args.model == 'imagenet_vit':
|
| 244 |
+
embed_path = 'data/imagenet_train_vit_embedding.pt'
|
| 245 |
+
# embed_path = 'data/imagenet_val_vit_embedding.pt'
|
| 246 |
+
if os.path.exists(embed_path):
|
| 247 |
+
data = torch.load(embed_path, map_location=device)
|
| 248 |
+
embeddings = data['embeddings']
|
| 249 |
+
labels = data['labels']
|
| 250 |
+
else:
|
| 251 |
+
embeddings, labels = get_embeddings(
|
| 252 |
+
net,
|
| 253 |
+
train_loader,
|
| 254 |
+
num_dim=model_to_num_dim[args.model],
|
| 255 |
+
dtype=torch.double,
|
| 256 |
+
device=device,
|
| 257 |
+
storage_device=device,
|
| 258 |
+
)
|
| 259 |
+
torch.save({'embeddings': embeddings, 'labels': labels}, embed_path)
|
| 260 |
+
Ensemble_model = Ensemble_fit(embeddings, labels, model_to_num_dim[args.model], num_classes, device)
|
| 261 |
+
torch.save(Ensemble_model.state_dict(), ensemble_model_path)
|
| 262 |
+
print(f"Model saved at {ensemble_model_path}")
|
| 263 |
+
|
| 264 |
+
logits, predictive_entropy, mut_info, labels = Ensemble_evaluate(net, Ensemble_model, test_loader, model_to_num_dim[args.model], device)
|
| 265 |
+
ood_logits, ood_predictive_entropy, ood_mut_info, ood_labels = Ensemble_evaluate(net, Ensemble_model, ood_test_loader, model_to_num_dim[args.model], device)
|
| 266 |
+
(conf_matrix, accuracy, labels_list, predictions, confidences,) = test_classification_net_softmax(logits, labels)
|
| 267 |
+
t_accuracy = accuracy
|
| 268 |
+
ece = expected_calibration_error(confidences, predictions, labels_list, num_bins=15)
|
| 269 |
+
t_ece = ece
|
| 270 |
+
(_, _, _), (_, _, _), ood_m1_auroc, ood_m1_auprc = get_roc_auc_uncs(predictive_entropy, ood_predictive_entropy, device)
|
| 271 |
+
(_, _, _), (_, _, _), ood_m2_auroc, ood_m2_auprc = get_roc_auc_uncs(mut_info, ood_mut_info, device)
|
| 272 |
+
|
| 273 |
+
labels_array = np.array(labels_list)
|
| 274 |
+
pred_array = np.array(predictions)
|
| 275 |
+
correct_mask = labels_array == pred_array
|
| 276 |
+
entropy_right = predictive_entropy[correct_mask]
|
| 277 |
+
entropy_wrong = predictive_entropy[~correct_mask]
|
| 278 |
+
mut_info_right = mut_info[correct_mask]
|
| 279 |
+
mut_info_wrong = mut_info[~correct_mask]
|
| 280 |
+
(_, _, _), (_, _, _), err_m1_auroc, err_m1_auprc = get_roc_auc_uncs(entropy_right, entropy_wrong, device)
|
| 281 |
+
(_, _, _), (_, _, _), err_m2_auroc, err_m2_auprc = get_roc_auc_uncs(mut_info_right, mut_info_wrong, device)
|
| 282 |
+
|
| 283 |
+
adv_test_loader = dataset_loader['imagenet_a'].get_test_loader(batch_size=args.batch_size,
|
| 284 |
+
imagesize=model_to_input_dim[args.model],
|
| 285 |
+
pin_memory=args.gpu)
|
| 286 |
+
|
| 287 |
+
adv_logits, adv_predictive_entropy, adv_mut_info, adv_labels = Ensemble_evaluate(net, Ensemble_model, adv_test_loader, model_to_num_dim[args.model], device)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
(_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc_uncs(predictive_entropy, adv_predictive_entropy, device)
|
| 291 |
+
(_, _, _), (_, _, _), adv_m2_auroc, adv_m2_auprc = get_roc_auc_uncs(mut_info, adv_mut_info, device)
|
| 292 |
+
|
| 293 |
+
print('adv_m1_auroc', adv_m1_auroc)
|
| 294 |
+
|
| 295 |
+
t_m1_auroc = ood_m1_auroc
|
| 296 |
+
t_m1_auprc = ood_m1_auprc
|
| 297 |
+
t_m2_auroc = ood_m2_auroc
|
| 298 |
+
t_m2_auprc = ood_m2_auprc
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
else:
|
| 302 |
+
(conf_matrix, accuracy, labels_list, predictions, confidences,) = test_classification_net_ensemble(
|
| 303 |
+
net_ensemble, test_loader, device
|
| 304 |
+
)
|
| 305 |
+
ece = expected_calibration_error(confidences, predictions, labels_list, num_bins=15)
|
| 306 |
+
|
| 307 |
+
(_, _, _), (_, _, _), ood_m1_auroc, ood_m1_auprc = get_roc_auc_ensemble(
|
| 308 |
+
net_ensemble, test_loader, ood_test_loader, "entropy", device
|
| 309 |
+
)
|
| 310 |
+
(_, _, _), (_, _, _), ood_m2_auroc, ood_m2_auprc = get_roc_auc_ensemble(
|
| 311 |
+
net_ensemble, test_loader, ood_test_loader, "mutual_information", device
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
labels_array = np.array(labels_list)
|
| 315 |
+
pred_array = np.array(predictions)
|
| 316 |
+
correct_mask = labels_array == pred_array
|
| 317 |
+
from torch.utils.data import Subset, DataLoader
|
| 318 |
+
dataset = test_loader.dataset
|
| 319 |
+
correct_indices = np.where(correct_mask)[0]
|
| 320 |
+
right_subset = Subset(dataset, correct_indices)
|
| 321 |
+
right_loader = DataLoader(right_subset, batch_size=test_loader.batch_size, shuffle=False)
|
| 322 |
+
wrong_indices = np.where(~correct_mask)[0]
|
| 323 |
+
wrong_subset = Subset(dataset, wrong_indices)
|
| 324 |
+
wrong_loader = DataLoader(wrong_subset, batch_size=test_loader.batch_size, shuffle=False)
|
| 325 |
+
(_, _, _), (_, _, _), err_m1_auroc, err_m1_auprc = get_roc_auc_ensemble(
|
| 326 |
+
net_ensemble, right_loader, wrong_loader, "entropy", device
|
| 327 |
+
)
|
| 328 |
+
(_, _, _), (_, _, _), err_m2_auroc, err_m2_auprc = get_roc_auc_ensemble(
|
| 329 |
+
net_ensemble, right_loader, wrong_loader, "mutual_information", device
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
adv_test_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=adv_ep,batch_size=args.batch_size)
|
| 333 |
+
(_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc_ensemble(
|
| 334 |
+
net_ensemble, test_loader, adv_test_loader, "entropy", device
|
| 335 |
+
)
|
| 336 |
+
(_, _, _), (_, _, _), adv_m2_auroc, adv_m2_auprc = get_roc_auc_ensemble(
|
| 337 |
+
net_ensemble, test_loader, adv_test_loader, "mutual_information", device
|
| 338 |
+
)
|
| 339 |
+
print('adv_m1_auroc,adv_m2_auroc', adv_m1_auroc, adv_m2_auroc)
|
| 340 |
+
|
| 341 |
+
if args.sample_noise:
|
| 342 |
+
adv_eps = np.linspace(0, 0.4, 9)
|
| 343 |
+
print(adv_eps)
|
| 344 |
+
for idx_ep, ep in enumerate(adv_eps):
|
| 345 |
+
adv_loader = create_adversarial_dataloader(net_ensemble[0], test_loader, device, epsilon=ep,
|
| 346 |
+
batch_size=args.batch_size)
|
| 347 |
+
(adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions, adv_confidences) = test_classification_net_ensemble(
|
| 348 |
+
net_ensemble, adv_loader, device
|
| 349 |
+
)
|
| 350 |
+
uncertainties = get_unc_ensemble(net_ensemble, adv_loader, "entropy", device).detach().cpu().numpy()
|
| 351 |
+
quantiles = np.quantile(uncertainties, np.linspace(0, 1, 10))
|
| 352 |
+
quantiles = np.delete(quantiles, 0)
|
| 353 |
+
unc_list = []
|
| 354 |
+
accuracy_list = []
|
| 355 |
+
for threshold in quantiles:
|
| 356 |
+
cer_indices = (uncertainties < threshold)
|
| 357 |
+
unc_indices = ~cer_indices
|
| 358 |
+
labels_list = np.array(adv_labels_list)
|
| 359 |
+
targets_cer = labels_list[cer_indices]
|
| 360 |
+
predictions = np.array(adv_predictions)
|
| 361 |
+
pred_cer = predictions[cer_indices]
|
| 362 |
+
targets_unc = labels_list[unc_indices]
|
| 363 |
+
pred_unc = predictions[unc_indices]
|
| 364 |
+
cer_right = np.sum(targets_cer == pred_cer)
|
| 365 |
+
cer = len(targets_cer)
|
| 366 |
+
unc_right = np.sum(targets_unc == pred_unc)
|
| 367 |
+
unc = len(targets_unc)
|
| 368 |
+
accuracy_cer = cer_right / cer
|
| 369 |
+
accuracy_unc = unc_right / unc
|
| 370 |
+
unc_list.append(threshold)
|
| 371 |
+
accuracy_list.append(accuracy_cer)
|
| 372 |
+
print('ACC:', accuracy_cer, accuracy_unc)
|
| 373 |
+
from scipy.stats import spearmanr
|
| 374 |
+
|
| 375 |
+
Spearman_acc, p_acc = spearmanr(unc_list, accuracy_list)
|
| 376 |
+
print("Spearman correlation:", Spearman_acc, "mean uncertainties:", uncertainties.mean())
|
| 377 |
+
adv_unc[i][idx_ep] = uncertainties.mean()
|
| 378 |
+
adv_acc[i][idx_ep] = adv_accuracy
|
| 379 |
+
|
| 380 |
+
(_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc_ensemble(
|
| 381 |
+
net_ensemble, test_loader, adv_loader, "entropy", device
|
| 382 |
+
)
|
| 383 |
+
(_, _, _), (_, _, _), adv_m2_auroc, adv_m2_auprc = get_roc_auc_ensemble(
|
| 384 |
+
net_ensemble, test_loader, adv_loader, "mutual_information", device
|
| 385 |
+
)
|
| 386 |
+
print('adv_m1_auroc,adv_m2_auroc', adv_m1_auroc, adv_m2_auroc)
|
| 387 |
+
|
| 388 |
+
# Temperature scale the ensemble
|
| 389 |
+
t_ensemble = []
|
| 390 |
+
for model, val_loader in zip(net_ensemble, val_loaders):
|
| 391 |
+
t_model = ModelWithTemperature(model)
|
| 392 |
+
t_model.set_temperature(val_loader)
|
| 393 |
+
t_ensemble.append(t_model)
|
| 394 |
+
|
| 395 |
+
(
|
| 396 |
+
t_conf_matrix,
|
| 397 |
+
t_accuracy,
|
| 398 |
+
t_labels_list,
|
| 399 |
+
t_predictions,
|
| 400 |
+
t_confidences,
|
| 401 |
+
) = test_classification_net_ensemble(t_ensemble, test_loader, device)
|
| 402 |
+
t_ece = expected_calibration_error(t_confidences, t_predictions, t_labels_list, num_bins=15)
|
| 403 |
+
|
| 404 |
+
(_, _, _), (_, _, _), t_m1_auroc, t_m1_auprc = get_roc_auc_ensemble(
|
| 405 |
+
t_ensemble, test_loader, ood_test_loader, "entropy", device
|
| 406 |
+
)
|
| 407 |
+
(_, _, _), (_, _, _), t_m2_auroc, t_m2_auprc = get_roc_auc_ensemble(
|
| 408 |
+
t_ensemble, test_loader, ood_test_loader, "mutual_information", device
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
elif args.model_type == "edl":
|
| 412 |
+
if args.dataset == 'imagenet':
|
| 413 |
+
edl_model_path = os.path.join(
|
| 414 |
+
args.load_loc,
|
| 415 |
+
"Run" + str(i + 1),
|
| 416 |
+
model_load_name(args.model, args.sn, args.mod, args.coeff, args.seed, i) + "_350_edl_model.pth",
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
if os.path.exists(edl_model_path):
|
| 420 |
+
print(f"Loading existing edl_model from {edl_model_path}")
|
| 421 |
+
EDL_model = EDL_load(edl_model_path, model_to_num_dim[args.model],
|
| 422 |
+
num_classes, device)
|
| 423 |
+
else:
|
| 424 |
+
if args.model=='imagenet_vgg16':
|
| 425 |
+
embed_path = 'data/imagenet_train_vgg_embedding.pt'
|
| 426 |
+
# embed_path = 'data/imagenet_val_vgg_embedding.pt'
|
| 427 |
+
if args.model=='imagenet_wide':
|
| 428 |
+
embed_path = 'data/imagenet_train_wide_embedding.pt'
|
| 429 |
+
# embed_path = 'data/imagenet_val_wide_embedding.pt'
|
| 430 |
+
if args.model=='imagenet_vit':
|
| 431 |
+
embed_path = 'data/imagenet_train_vit_embedding.pt'
|
| 432 |
+
# embed_path = 'data/imagenet_val_vit_embedding.pt'
|
| 433 |
+
if os.path.exists(embed_path):
|
| 434 |
+
data = torch.load(embed_path, map_location=device)
|
| 435 |
+
embeddings = data['embeddings']
|
| 436 |
+
labels = data['labels']
|
| 437 |
+
else:
|
| 438 |
+
embeddings, labels = get_embeddings(
|
| 439 |
+
net,
|
| 440 |
+
train_loader,
|
| 441 |
+
num_dim=model_to_num_dim[args.model],
|
| 442 |
+
dtype=torch.double,
|
| 443 |
+
device=device,
|
| 444 |
+
storage_device=device,
|
| 445 |
+
)
|
| 446 |
+
torch.save({'embeddings': embeddings, 'labels': labels}, embed_path)
|
| 447 |
+
EDL_model = EDL_fit(embeddings, labels, model_to_num_dim[args.model], num_classes,
|
| 448 |
+
device)
|
| 449 |
+
torch.save(EDL_model.state_dict(), edl_model_path)
|
| 450 |
+
print(f"Model saved at {edl_model_path}")
|
| 451 |
+
|
| 452 |
+
logits, labels = EDL_evaluate(net, EDL_model, test_loader, model_to_num_dim[args.model], device)
|
| 453 |
+
ood_logits, ood_labels = EDL_evaluate(net, EDL_model, ood_test_loader, model_to_num_dim[args.model], device)
|
| 454 |
+
(conf_matrix, accuracy, labels_list, predictions, confidences,) = test_classification_net_logits_edl(logits, labels)
|
| 455 |
+
t_accuracy = accuracy
|
| 456 |
+
|
| 457 |
+
ece = expected_calibration_error(confidences, predictions, labels_list, num_bins=15)
|
| 458 |
+
t_ece = ece
|
| 459 |
+
|
| 460 |
+
(_, _, _), (_, _, _), ood_m1_auroc, ood_m1_auprc = get_roc_auc_logits(logits, ood_logits, edl_unc, device)
|
| 461 |
+
|
| 462 |
+
labels_array = np.array(labels_list)
|
| 463 |
+
pred_array = np.array(predictions)
|
| 464 |
+
correct_mask = labels_array == pred_array
|
| 465 |
+
# logits, _ = get_logits_labels(net, test_loader, device)
|
| 466 |
+
logits_right = logits[correct_mask]
|
| 467 |
+
logits_wrong = logits[~correct_mask]
|
| 468 |
+
(_, _, _), (_, _, _), err_m1_auroc, err_m1_auprc = get_roc_auc_logits(logits_right, logits_wrong,
|
| 469 |
+
edl_unc, device)
|
| 470 |
+
|
| 471 |
+
adv_test_loader = dataset_loader['imagenet_a'].get_test_loader(batch_size=args.batch_size,
|
| 472 |
+
imagesize=model_to_input_dim[args.model],
|
| 473 |
+
pin_memory=args.gpu)
|
| 474 |
+
# (adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions,adv_confidences,) = test_classification_net_edl(net, adv_test_loader, device)
|
| 475 |
+
# adv_logits, _ = get_logits_labels(net, adv_test_loader, device)
|
| 476 |
+
adv_logits, adv_labels = EDL_evaluate(net, EDL_model, adv_test_loader, model_to_num_dim[args.model], device)
|
| 477 |
+
|
| 478 |
+
(_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc_logits(logits, adv_logits, edl_unc, device)
|
| 479 |
+
print('adv_m1_auroc', adv_m1_auroc)
|
| 480 |
+
|
| 481 |
+
if args.sample_noise:
|
| 482 |
+
adv_eps = np.linspace(0, 0.4, 9)
|
| 483 |
+
for idx_ep, ep in enumerate(adv_eps):
|
| 484 |
+
adv_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=ep,
|
| 485 |
+
batch_size=args.batch_size)
|
| 486 |
+
(adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions,
|
| 487 |
+
adv_confidences,) = test_classification_net(net, adv_loader, device)
|
| 488 |
+
adv_logits, adv_labels = EDL_evaluate(net, EDL_model, adv_loader,
|
| 489 |
+
model_to_num_dim[args.model], device)
|
| 490 |
+
uncertainties = edl_unc(adv_logits).detach().cpu().numpy()
|
| 491 |
+
quantiles = np.quantile(uncertainties, np.linspace(0, 1, 10))
|
| 492 |
+
quantiles = np.delete(quantiles, 0)
|
| 493 |
+
unc_list = []
|
| 494 |
+
accuracy_list = []
|
| 495 |
+
for threshold in quantiles:
|
| 496 |
+
cer_indices = (uncertainties < threshold)
|
| 497 |
+
unc_indices = ~cer_indices
|
| 498 |
+
labels_list = np.array(adv_labels_list)
|
| 499 |
+
targets_cer = labels_list[cer_indices]
|
| 500 |
+
predictions = np.array(adv_predictions)
|
| 501 |
+
pred_cer = predictions[cer_indices]
|
| 502 |
+
targets_unc = labels_list[unc_indices]
|
| 503 |
+
pred_unc = predictions[unc_indices]
|
| 504 |
+
cer_right = np.sum(targets_cer == pred_cer)
|
| 505 |
+
cer = len(targets_cer)
|
| 506 |
+
unc_right = np.sum(targets_unc == pred_unc)
|
| 507 |
+
unc = len(targets_unc)
|
| 508 |
+
accuracy_cer = cer_right / cer
|
| 509 |
+
accuracy_unc = unc_right / unc
|
| 510 |
+
unc_list.append(threshold)
|
| 511 |
+
accuracy_list.append(accuracy_cer)
|
| 512 |
+
print('ACC:', accuracy_cer, accuracy_unc)
|
| 513 |
+
from scipy.stats import spearmanr
|
| 514 |
+
|
| 515 |
+
Spearman_acc, p_acc = spearmanr(unc_list, accuracy_list)
|
| 516 |
+
print("Spearman correlation:", Spearman_acc, "mean uncertainties:", uncertainties.mean())
|
| 517 |
+
adv_unc[i][idx_ep] = uncertainties.mean()
|
| 518 |
+
adv_acc[i][idx_ep] = adv_accuracy
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
        else:
            (conf_matrix, accuracy, labels_list, predictions, confidences,) = test_classification_net_edl(
                net, test_loader, device
            )
            t_accuracy = accuracy
            ece = expected_calibration_error(confidences, predictions, labels_list, num_bins=15)
            t_ece = ece

            (_, _, _), (_, _, _), ood_m1_auroc, ood_m1_auprc = get_roc_auc(net, test_loader, ood_test_loader, edl_unc, device)

            labels_array = np.array(labels_list)
            pred_array = np.array(predictions)
            correct_mask = labels_array == pred_array
            logits, _ = get_logits_labels(net, test_loader, device)
            logits_right = logits[correct_mask]
            logits_wrong = logits[~correct_mask]
            (_, _, _), (_, _, _), err_m1_auroc, err_m1_auprc = get_roc_auc_logits(logits_right, logits_wrong, edl_unc, device)

            adv_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=adv_ep,
                                                       batch_size=args.batch_size, edl=True)
            (adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions,
             adv_confidences,) = test_classification_net_edl(net, adv_loader, device)
            adv_logits, _ = get_logits_labels(net, adv_loader, device)

            (_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc_logits(logits, adv_logits, edl_unc, device)
            print('adv_m1_auroc', adv_m1_auroc)

            if args.sample_noise:
                adv_eps = np.linspace(0, 0.4, 9)
                for idx_ep, ep in enumerate(adv_eps):
                    adv_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=ep,
                                                               batch_size=args.batch_size, edl=True)
                    (adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions,
                     adv_confidences,) = test_classification_net_edl(net, adv_loader, device)
                    adv_logits, _ = get_logits_labels(net, adv_loader, device)
                    uncertainties = edl_unc(adv_logits).detach().cpu().numpy()
                    quantiles = np.quantile(uncertainties, np.linspace(0, 1, 10))
                    quantiles = np.delete(quantiles, 0)
                    unc_list = []
                    accuracy_list = []
                    for threshold in quantiles:
                        cer_indices = (uncertainties < threshold)
                        unc_indices = ~cer_indices
                        labels_list = np.array(adv_labels_list)
                        targets_cer = labels_list[cer_indices]
                        predictions = np.array(adv_predictions)
                        pred_cer = predictions[cer_indices]
                        targets_unc = labels_list[unc_indices]
                        pred_unc = predictions[unc_indices]
                        cer_right = np.sum(targets_cer == pred_cer)
                        cer = len(targets_cer)
                        unc_right = np.sum(targets_unc == pred_unc)
                        unc = len(targets_unc)
                        accuracy_cer = cer_right / cer
                        accuracy_unc = unc_right / unc
                        unc_list.append(threshold)
                        accuracy_list.append(accuracy_cer)
                        print('ACC:', accuracy_cer, accuracy_unc)
                    from scipy.stats import spearmanr

                    Spearman_acc, p_acc = spearmanr(unc_list, accuracy_list)
                    print("Spearman correlation:", Spearman_acc, "mean uncertainties:", uncertainties.mean())
                    adv_unc[i][idx_ep] = uncertainties.mean()
                    adv_acc[i][idx_ep] = adv_accuracy

        # EDL exposes a single uncertainty score, so the second-measure and
        # temperature-scaled slots are mirrored to keep the result dict uniform.
        ood_m2_auroc = ood_m1_auroc
        ood_m2_auprc = ood_m1_auprc
        err_m2_auroc = err_m1_auroc
        err_m2_auprc = err_m1_auprc
        adv_m2_auroc = adv_m1_auroc
        adv_m2_auprc = adv_m1_auprc
        t_m1_auroc = ood_m1_auroc
        t_m1_auprc = ood_m1_auprc
        t_m2_auroc = ood_m1_auroc
        t_m2_auprc = ood_m1_auprc

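    # "joint" model type: one network whose auxiliary UQ heads are read out with
    # test_classification_uq / get_logits_labels_uq; logits[0] appears to carry the
    # class logits and logits[1..3] the MAR-style quantities used for the
    # self-consistency correction below.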
elif args.model_type == "joint":
|
| 599 |
+
(conf_matrix, accuracy, labels_list, predictions, confidences,) = test_classification_uq(
|
| 600 |
+
net, test_loader, device
|
| 601 |
+
)
|
| 602 |
+
ece = expected_calibration_error(confidences, predictions, labels_list, num_bins=15)
|
| 603 |
+
print(accuracy)
|
| 604 |
+
print('ece', ece)
|
| 605 |
+
|
| 606 |
+
t_ece=ece
|
| 607 |
+
t_accuracy=accuracy
|
| 608 |
+
|
| 609 |
+
print("SPC Model")
|
| 610 |
+
|
| 611 |
+
logits, labels = get_logits_labels_uq(net, test_loader, device)
|
| 612 |
+
|
| 613 |
+
soft = torch.nn.functional.softmax(logits[0], dim=1)
|
| 614 |
+
delta = torch.min(torch.min(logits[2] - logits[3], logits[1] - 2 * logits[3]), 2 * logits[2] - logits[1])
|
| 615 |
+
|
| 616 |
+
uncertainty = abs(logits[2] + logits[3] - logits[1])
|
| 617 |
+
|
| 618 |
+
threshold = 0.05
|
| 619 |
+
mask = (uncertainty < threshold).float()
|
| 620 |
+
delta = delta * mask
|
| 621 |
+
softmax_prob = soft + delta
|
| 622 |
+
|
| 623 |
+
c_confidences, c_predictions = torch.max(softmax_prob, dim=1)
|
| 624 |
+
c_predictions = c_predictions.tolist()
|
| 625 |
+
c_confidences = c_confidences.tolist()
|
| 626 |
+
c_ece = expected_calibration_error(c_confidences, c_predictions, labels_list, num_bins=15)
|
| 627 |
+
print('ece', ece, 't_ece', t_ece, 'c_ece', c_ece)
|
| 628 |
+
c_accuracy = accuracy_score(labels_list, c_predictions)
|
| 629 |
+
print(accuracy, t_accuracy, c_accuracy)
|
| 630 |
+
|
| 631 |
+
ood_logits, ood_labels = get_logits_labels_uq(net, ood_test_loader, device)
|
| 632 |
+
|
| 633 |
+
(_, _, _), (_, _, _), ood_m1_auroc, ood_m1_auprc = get_roc_auc_logits(logits, ood_logits, self_consistency,
|
| 634 |
+
device)
|
| 635 |
+
(_, _, _), (_, _, _), ood_m2_auroc, ood_m2_auprc = get_roc_auc_logits(logits[0], ood_logits[0], entropy, device)
|
| 636 |
+
|
| 637 |
+
labels_array = np.array(labels_list)
|
| 638 |
+
pred_array = np.array(predictions)
|
| 639 |
+
correct_mask = labels_array == pred_array
|
| 640 |
+
logits_right = [m[correct_mask] for m in logits]
|
| 641 |
+
logits_wrong = [m[~correct_mask] for m in logits]
|
| 642 |
+
(_, _, _), (_, _, _), err_m1_auroc, err_m1_auprc = get_roc_auc_logits(logits_right, logits_wrong,
|
| 643 |
+
self_consistency, device)
|
| 644 |
+
(_, _, _), (_, _, _), err_m2_auroc, err_m2_auprc = get_roc_auc_logits(logits_right[0], logits_wrong[0], entropy,
|
| 645 |
+
device)
|
| 646 |
+
|
| 647 |
+
if args.dataset == 'imagenet':
|
| 648 |
+
adv_test_loader = dataset_loader['imagenet_a'].get_test_loader(batch_size=args.batch_size,
|
| 649 |
+
imagesize=model_to_input_dim[args.model],
|
| 650 |
+
pin_memory=args.gpu)
|
| 651 |
+
else:
|
| 652 |
+
adv_test_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=adv_ep,
|
| 653 |
+
batch_size=args.batch_size, joint=True)
|
| 654 |
+
|
| 655 |
+
adv_logits, adv_labels = get_logits_labels_uq(net, adv_test_loader, device)
|
| 656 |
+
|
| 657 |
+
(_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc_logits(logits, adv_logits, self_consistency, device)
|
| 658 |
+
(_, _, _), (_, _, _), adv_m2_auroc, adv_m2_auprc = get_roc_auc_logits(logits[0], adv_logits[0], entropy, device)
|
| 659 |
+
|
| 660 |
+
t_m1_auroc = ood_m1_auroc
|
| 661 |
+
t_m1_auprc = ood_m1_auprc
|
| 662 |
+
t_m2_auroc = ood_m2_auroc
|
| 663 |
+
t_m2_auprc = ood_m2_auprc
|
| 664 |
+
|
| 665 |
+
    else:
        (conf_matrix, accuracy, labels_list, predictions, confidences,) = test_classification_net(
            net, test_loader, device
        )
        ece = expected_calibration_error(confidences, predictions, labels_list, num_bins=15)
        print(accuracy)
        print('ece', ece)

        # Temperature scaling fitted on the validation set; the fitted temperature
        # (topt) is also reused later when fitting the SPC head.
        temp_scaled_net = ModelWithTemperature(net)
        temp_scaled_net.set_temperature(val_loader)
        # temp_scaled_net.set_temperature(train_loader)
        topt = temp_scaled_net.temperature

        (t_conf_matrix, t_accuracy, t_labels_list, t_predictions, t_confidences,) = test_classification_net(
            temp_scaled_net, test_loader, device
        )
        t_ece = expected_calibration_error(t_confidences, t_predictions, t_labels_list, num_bins=15)
        print('t_ece', t_ece)

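        # GMM branch: fit a class-conditional Gaussian mixture on feature embeddings
        # and score samples by feature-space density (logsumexp of per-class
        # log-likelihoods, used as a confidence), with predictive entropy as the
        # second measure.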
if (args.model_type == "gmm"):
|
| 685 |
+
# Evaluate a GMM model
|
| 686 |
+
print("GMM Model")
|
| 687 |
+
|
| 688 |
+
if args.crossval:
|
| 689 |
+
embeddings, labels = get_embeddings(
|
| 690 |
+
net,
|
| 691 |
+
val_loader,
|
| 692 |
+
num_dim=model_to_num_dim[args.model],
|
| 693 |
+
dtype=torch.double,
|
| 694 |
+
device=device,
|
| 695 |
+
storage_device=device,
|
| 696 |
+
)
|
| 697 |
+
else:
|
| 698 |
+
if args.dataset == 'imagenet':
|
| 699 |
+
if args.model == 'imagenet_vgg16':
|
| 700 |
+
embed_path = 'data/imagenet_train_vgg_embedding.pt'
|
| 701 |
+
# embed_path = 'data/imagenet_val_vgg_embedding.pt'
|
| 702 |
+
if args.model == 'imagenet_wide':
|
| 703 |
+
embed_path = 'data/imagenet_train_wide_embedding.pt'
|
| 704 |
+
# embed_path = 'data/imagenet_val_wide_embedding.pt'
|
| 705 |
+
if args.model == 'imagenet_vit':
|
| 706 |
+
embed_path = 'data/imagenet_train_vit_embedding.pt'
|
| 707 |
+
# embed_path = 'data/imagenet_val_vit_embedding.pt'
|
| 708 |
+
if os.path.exists(embed_path):
|
| 709 |
+
data = torch.load(embed_path, map_location=device)
|
| 710 |
+
embeddings = data['embeddings']
|
| 711 |
+
labels = data['labels']
|
| 712 |
+
else:
|
| 713 |
+
embeddings, labels = get_embeddings(
|
| 714 |
+
net,
|
| 715 |
+
train_loader,
|
| 716 |
+
num_dim=model_to_num_dim[args.model],
|
| 717 |
+
dtype=torch.double,
|
| 718 |
+
device=device,
|
| 719 |
+
storage_device=device,
|
| 720 |
+
)
|
| 721 |
+
torch.save({'embeddings': embeddings, 'labels': labels}, embed_path)
|
| 722 |
+
else:
|
| 723 |
+
embeddings, labels = get_embeddings(
|
| 724 |
+
net,
|
| 725 |
+
train_loader,
|
| 726 |
+
num_dim=model_to_num_dim[args.model],
|
| 727 |
+
dtype=torch.double,
|
| 728 |
+
device=device,
|
| 729 |
+
storage_device=device,
|
| 730 |
+
)
|
| 731 |
+
|
| 732 |
+
try:
|
| 733 |
+
gaussians_model, jitter_eps = gmm_fit(embeddings=embeddings, labels=labels, num_classes=num_classes, device=device)
|
| 734 |
+
logits, labels = gmm_evaluate(
|
| 735 |
+
net, gaussians_model, test_loader, device=device, num_classes=num_classes, storage_device=device,
|
| 736 |
+
)
|
| 737 |
+
|
| 738 |
+
ood_logits, ood_labels = gmm_evaluate(
|
| 739 |
+
net, gaussians_model, ood_test_loader, device=device, num_classes=num_classes, storage_device=device,
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
(_, _, _), (_, _, _), ood_m1_auroc, ood_m1_auprc = get_roc_auc_logits(
|
| 743 |
+
logits, ood_logits, logsumexp, device, confidence=True
|
| 744 |
+
)
|
| 745 |
+
(_, _, _), (_, _, _), ood_m2_auroc, ood_m2_auprc = get_roc_auc_logits(logits, ood_logits, entropy, device)
|
| 746 |
+
|
| 747 |
+
labels_array = np.array(labels_list)
|
| 748 |
+
pred_array = np.array(predictions)
|
| 749 |
+
correct_mask = labels_array == pred_array
|
| 750 |
+
logits_right = logits[correct_mask]
|
| 751 |
+
logits_wrong = logits[~correct_mask]
|
| 752 |
+
(_, _, _), (_, _, _), err_m1_auroc, err_m1_auprc = get_roc_auc_logits(logits_right, logits_wrong,logsumexp, device, confidence=True)
|
| 753 |
+
(_, _, _), (_, _, _), err_m2_auroc, err_m2_auprc = get_roc_auc_logits(logits_right, logits_wrong,entropy, device, confidence=True)
|
| 754 |
+
|
| 755 |
+
if args.dataset == 'imagenet':
|
| 756 |
+
adv_test_loader = dataset_loader['imagenet_a'].get_test_loader(batch_size=args.batch_size,
|
| 757 |
+
imagesize=model_to_input_dim[args.model],
|
| 758 |
+
pin_memory=args.gpu)
|
| 759 |
+
else:
|
| 760 |
+
adv_test_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=adv_ep,batch_size=args.batch_size)
|
| 761 |
+
(adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions, adv_confidences,) = test_classification_net(net, adv_test_loader, device)
|
| 762 |
+
adv_logits, adv_labels = gmm_evaluate(net, gaussians_model, adv_test_loader, device=device, num_classes=num_classes, storage_device=device, )
|
| 763 |
+
(_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc_logits(logits, adv_logits, logsumexp, device, confidence=True)
|
| 764 |
+
(_, _, _), (_, _, _), adv_m2_auroc, adv_m2_auprc = get_roc_auc_logits(logits, adv_logits, entropy, device)
|
| 765 |
+
|
| 766 |
+
if args.sample_noise:
|
| 767 |
+
adv_eps = np.linspace(0, 0.4, 9)
|
| 768 |
+
for idx_ep, ep in enumerate(adv_eps):
|
| 769 |
+
adv_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=ep,
|
| 770 |
+
batch_size=args.batch_size)
|
| 771 |
+
(adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions,
|
| 772 |
+
adv_confidences,) = test_classification_net(net, adv_loader, device)
|
| 773 |
+
adv_logits, adv_labels = gmm_evaluate(net, gaussians_model, adv_loader, device=device, num_classes=num_classes,storage_device=device,)
|
| 774 |
+
uncertainties = -logsumexp(adv_logits).detach().cpu().numpy()
|
| 775 |
+
quantiles = np.quantile(uncertainties, np.linspace(0, 1, 10))
|
| 776 |
+
quantiles = np.delete(quantiles, 0)
|
| 777 |
+
unc_list = []
|
| 778 |
+
accuracy_list = []
|
| 779 |
+
for threshold in quantiles:
|
| 780 |
+
cer_indices = (uncertainties < threshold)
|
| 781 |
+
unc_indices = ~cer_indices
|
| 782 |
+
labels_list = np.array(adv_labels_list)
|
| 783 |
+
targets_cer = labels_list[cer_indices]
|
| 784 |
+
predictions = np.array(adv_predictions)
|
| 785 |
+
pred_cer = predictions[cer_indices]
|
| 786 |
+
targets_unc = labels_list[unc_indices]
|
| 787 |
+
pred_unc = predictions[unc_indices]
|
| 788 |
+
cer_right = np.sum(targets_cer == pred_cer)
|
| 789 |
+
cer = len(targets_cer)
|
| 790 |
+
unc_right = np.sum(targets_unc == pred_unc)
|
| 791 |
+
unc = len(targets_unc)
|
| 792 |
+
accuracy_cer = cer_right / cer
|
| 793 |
+
accuracy_unc = unc_right / unc
|
| 794 |
+
unc_list.append(threshold)
|
| 795 |
+
accuracy_list.append(accuracy_cer)
|
| 796 |
+
print('ACC:', accuracy_cer, accuracy_unc)
|
| 797 |
+
from scipy.stats import spearmanr
|
| 798 |
+
Spearman_acc, p_acc = spearmanr(unc_list, accuracy_list)
|
| 799 |
+
print("Spearman correlation:", Spearman_acc, "mean uncertainties:", uncertainties.mean())
|
| 800 |
+
adv_unc[i][idx_ep]=uncertainties.mean()
|
| 801 |
+
adv_acc[i][idx_ep] = adv_accuracy
|
| 802 |
+
|
| 803 |
+
(_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc_logits(logits, adv_logits, logsumexp, device, confidence=True)
|
| 804 |
+
(_, _, _), (_, _, _), adv_m2_auroc, adv_m2_auprc = get_roc_auc_logits(logits, adv_logits, entropy, device)
|
| 805 |
+
print('adv_m1_auroc,adv_m2_auroc', adv_m1_auroc, adv_m2_auroc)
|
| 806 |
+
|
| 807 |
+
t_m1_auroc = ood_m1_auroc
|
| 808 |
+
t_m1_auprc = ood_m1_auprc
|
| 809 |
+
t_m2_auroc = ood_m2_auroc
|
| 810 |
+
t_m2_auprc = ood_m2_auprc
|
| 811 |
+
|
| 812 |
+
except RuntimeError as e:
|
| 813 |
+
print("Runtime Error caught: " + str(e))
|
| 814 |
+
continue
|
| 815 |
+
|
| 816 |
+
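        # OC branch: fit what appears to be a one-class model on the embeddings; its
        # certificate score is treated as a confidence (confidence=True), with softmax
        # entropy as the second uncertainty measure.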
elif (args.model_type == "oc"):
|
| 817 |
+
# Evaluate a OC model
|
| 818 |
+
print("OC Model")
|
| 819 |
+
|
| 820 |
+
if args.crossval:
|
| 821 |
+
embeddings, labels = get_embeddings(
|
| 822 |
+
net,
|
| 823 |
+
val_loader,
|
| 824 |
+
num_dim=model_to_num_dim[args.model],
|
| 825 |
+
dtype=torch.double,
|
| 826 |
+
device=device,
|
| 827 |
+
storage_device=device,
|
| 828 |
+
)
|
| 829 |
+
else:
|
| 830 |
+
if args.dataset == 'imagenet':
|
| 831 |
+
if args.model == 'imagenet_vgg16':
|
| 832 |
+
embed_path = 'data/imagenet_train_vgg_embedding.pt'
|
| 833 |
+
# embed_path = 'data/imagenet_val_vgg_embedding.pt'
|
| 834 |
+
if args.model == 'imagenet_wide':
|
| 835 |
+
embed_path = 'data/imagenet_train_wide_embedding.pt'
|
| 836 |
+
# embed_path = 'data/imagenet_val_wide_embedding.pt'
|
| 837 |
+
if args.model == 'imagenet_vit':
|
| 838 |
+
embed_path = 'data/imagenet_train_vit_embedding.pt'
|
| 839 |
+
# embed_path = 'data/imagenet_val_vit_embedding.pt'
|
| 840 |
+
if os.path.exists(embed_path):
|
| 841 |
+
data = torch.load(embed_path, map_location=device)
|
| 842 |
+
embeddings = data['embeddings']
|
| 843 |
+
labels = data['labels']
|
| 844 |
+
else:
|
| 845 |
+
embeddings, labels = get_embeddings(
|
| 846 |
+
net,
|
| 847 |
+
train_loader,
|
| 848 |
+
num_dim=model_to_num_dim[args.model],
|
| 849 |
+
dtype=torch.double,
|
| 850 |
+
device=device,
|
| 851 |
+
storage_device=device,
|
| 852 |
+
)
|
| 853 |
+
torch.save({'embeddings': embeddings, 'labels': labels}, embed_path)
|
| 854 |
+
else:
|
| 855 |
+
embeddings, labels = get_embeddings(
|
| 856 |
+
net,
|
| 857 |
+
train_loader,
|
| 858 |
+
num_dim=model_to_num_dim[args.model],
|
| 859 |
+
dtype=torch.double,
|
| 860 |
+
device=device,
|
| 861 |
+
storage_device=device,
|
| 862 |
+
)
|
| 863 |
+
|
| 864 |
+
try:
|
| 865 |
+
oc_model = oc_fit(embeddings=embeddings, device=device)
|
| 866 |
+
logits, OCs = oc_evaluate(
|
| 867 |
+
net, oc_model, test_loader,model_to_num_dim[args.model], device=device
|
| 868 |
+
)
|
| 869 |
+
|
| 870 |
+
ood_logits, ood_OCs = oc_evaluate(
|
| 871 |
+
net, oc_model, ood_test_loader, model_to_num_dim[args.model], device=device)
|
| 872 |
+
|
| 873 |
+
(_, _, _), (_, _, _), ood_m1_auroc, ood_m1_auprc = get_roc_auc_logits(OCs, ood_OCs, certificate, device, confidence=True)
|
| 874 |
+
(_, _, _), (_, _, _), ood_m2_auroc, ood_m2_auprc = get_roc_auc_logits(logits, ood_logits, entropy, device)
|
| 875 |
+
|
| 876 |
+
labels_array = np.array(labels_list)
|
| 877 |
+
pred_array = np.array(predictions)
|
| 878 |
+
correct_mask = labels_array == pred_array
|
| 879 |
+
logits_right = logits[correct_mask]
|
| 880 |
+
logits_wrong = logits[~correct_mask]
|
| 881 |
+
OCs_right = OCs[correct_mask]
|
| 882 |
+
OCs_wrong = OCs[~correct_mask]
|
| 883 |
+
(_, _, _), (_, _, _), err_m1_auroc, err_m1_auprc = get_roc_auc_logits(OCs_right, OCs_wrong, certificate, device, confidence=True)
|
| 884 |
+
(_, _, _), (_, _, _), err_m2_auroc, err_m2_auprc = get_roc_auc_logits(logits_right, logits_wrong, entropy, device)
|
| 885 |
+
|
| 886 |
+
if args.dataset == 'imagenet':
|
| 887 |
+
adv_test_loader = dataset_loader['imagenet_a'].get_test_loader(batch_size=args.batch_size,
|
| 888 |
+
imagesize=model_to_input_dim[args.model],
|
| 889 |
+
pin_memory=args.gpu)
|
| 890 |
+
else:
|
| 891 |
+
adv_test_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=adv_ep,batch_size=args.batch_size)
|
| 892 |
+
(adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions, adv_confidences,) = test_classification_net(net, adv_test_loader, device)
|
| 893 |
+
adv_logits, adv_OCs = oc_evaluate(net, oc_model, adv_test_loader, model_to_num_dim[args.model], device=device)
|
| 894 |
+
(_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc_logits(OCs, adv_OCs, certificate, device, confidence=True)
|
| 895 |
+
(_, _, _), (_, _, _), adv_m2_auroc, adv_m2_auprc = get_roc_auc_logits(logits, adv_logits, entropy, device)
|
| 896 |
+
|
| 897 |
+
if args.sample_noise:
|
| 898 |
+
adv_eps = np.linspace(0, 0.4, 9)
|
| 899 |
+
for idx_ep, ep in enumerate(adv_eps):
|
| 900 |
+
adv_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=ep,
|
| 901 |
+
batch_size=args.batch_size)
|
| 902 |
+
(adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions,
|
| 903 |
+
adv_confidences,) = test_classification_net(net, adv_loader, device)
|
| 904 |
+
adv_logits, adv_OCs = oc_evaluate(net, oc_model, adv_loader, model_to_num_dim[args.model], device=device)
|
| 905 |
+
uncertainties = -adv_OCs.cpu().numpy()
|
| 906 |
+
quantiles = np.quantile(uncertainties, np.linspace(0, 1, 10))
|
| 907 |
+
quantiles = np.delete(quantiles, 0)
|
| 908 |
+
unc_list = []
|
| 909 |
+
accuracy_list = []
|
| 910 |
+
for threshold in quantiles:
|
| 911 |
+
cer_indices = (uncertainties < threshold)
|
| 912 |
+
unc_indices = ~cer_indices
|
| 913 |
+
labels_list = np.array(adv_labels_list)
|
| 914 |
+
targets_cer = labels_list[cer_indices]
|
| 915 |
+
predictions = np.array(adv_predictions)
|
| 916 |
+
pred_cer = predictions[cer_indices]
|
| 917 |
+
targets_unc = labels_list[unc_indices]
|
| 918 |
+
pred_unc = predictions[unc_indices]
|
| 919 |
+
cer_right = np.sum(targets_cer == pred_cer)
|
| 920 |
+
cer = len(targets_cer)
|
| 921 |
+
unc_right = np.sum(targets_unc == pred_unc)
|
| 922 |
+
unc = len(targets_unc)
|
| 923 |
+
accuracy_cer = cer_right / cer
|
| 924 |
+
accuracy_unc = unc_right / unc
|
| 925 |
+
unc_list.append(threshold)
|
| 926 |
+
accuracy_list.append(accuracy_cer)
|
| 927 |
+
print('ACC:', accuracy_cer, accuracy_unc)
|
| 928 |
+
from scipy.stats import spearmanr
|
| 929 |
+
Spearman_acc, p_acc = spearmanr(unc_list, accuracy_list)
|
| 930 |
+
print("Spearman correlation:", Spearman_acc, "mean uncertainties:", uncertainties.mean())
|
| 931 |
+
adv_unc[i][idx_ep]=uncertainties.mean()
|
| 932 |
+
adv_acc[i][idx_ep] = adv_accuracy
|
| 933 |
+
|
| 934 |
+
(_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc_logits(OCs, adv_OCs, certificate, device, confidence=True)
|
| 935 |
+
(_, _, _), (_, _, _), adv_m2_auroc, adv_m2_auprc = get_roc_auc_logits(logits, adv_logits, entropy, device)
|
| 936 |
+
print('adv_m1_auroc,adv_m2_auroc', adv_m1_auroc, adv_m2_auroc)
|
| 937 |
+
|
| 938 |
+
t_m1_auroc = ood_m1_auroc
|
| 939 |
+
t_m1_auprc = ood_m1_auprc
|
| 940 |
+
t_m2_auroc = ood_m2_auroc
|
| 941 |
+
t_m2_auprc = ood_m2_auprc
|
| 942 |
+
|
| 943 |
+
except RuntimeError as e:
|
| 944 |
+
print("Runtime Error caught: " + str(e))
|
| 945 |
+
continue
|
| 946 |
+
|
| 947 |
+
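        # SPC branch: load (or fit) the SPC head on frozen embeddings, score
        # uncertainty by the self-consistency of its MAR outputs, and recalibrate the
        # softmax probabilities with the delta correction below.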
elif (args.model_type == "spc"):
|
| 948 |
+
print("SPC Model")
|
| 949 |
+
|
| 950 |
+
if args.crossval:
|
| 951 |
+
spc_model_path = os.path.join(
|
| 952 |
+
args.load_loc,
|
| 953 |
+
"Run" + str(i + 1),
|
| 954 |
+
model_load_name(args.model, args.sn, args.mod, args.coeff, args.seed, i) + "_valsize_"+str(args.val_size)+"_350_mar_model.pth",
|
| 955 |
+
)
|
| 956 |
+
else:
|
| 957 |
+
spc_model_path = os.path.join(
|
| 958 |
+
args.load_loc,
|
| 959 |
+
"Run" + str(i + 1),
|
| 960 |
+
model_load_name(args.model, args.sn, args.mod, args.coeff, args.seed, i) + "_350_mar_model.pth",
|
| 961 |
+
)
|
| 962 |
+
|
| 963 |
+
if os.path.exists(spc_model_path):
|
| 964 |
+
print(f"Loading existing spc_model from {spc_model_path}")
|
| 965 |
+
SPC_model = SPC_load(spc_model_path, model_to_num_dim[args.model], num_classes, device)
|
| 966 |
+
else:
|
| 967 |
+
print(f"Model not found. Training a new one...")
|
| 968 |
+
if args.crossval:
|
| 969 |
+
embeddings, labels = get_embeddings(
|
| 970 |
+
net,
|
| 971 |
+
val_loader,
|
| 972 |
+
num_dim=model_to_num_dim[args.model],
|
| 973 |
+
dtype=torch.double,
|
| 974 |
+
device=device,
|
| 975 |
+
storage_device=device,
|
| 976 |
+
)
|
| 977 |
+
else:
|
| 978 |
+
if args.dataset == 'imagenet':
|
| 979 |
+
if args.model=='imagenet_vgg16':
|
| 980 |
+
embed_path = 'data/imagenet_train_vgg_embedding.pt'
|
| 981 |
+
# embed_path = 'data/imagenet_val_vgg_embedding.pt'
|
| 982 |
+
if args.model=='imagenet_wide':
|
| 983 |
+
embed_path = 'data/imagenet_train_wide_embedding.pt'
|
| 984 |
+
# embed_path = 'data/imagenet_val_wide_embedding.pt'
|
| 985 |
+
if args.model=='imagenet_vit':
|
| 986 |
+
embed_path = 'data/imagenet_train_vit_embedding.pt'
|
| 987 |
+
# embed_path = 'data/imagenet_val_vit_embedding.pt'
|
| 988 |
+
if os.path.exists(embed_path):
|
| 989 |
+
data = torch.load(embed_path, map_location=device)
|
| 990 |
+
embeddings = data['embeddings']
|
| 991 |
+
labels = data['labels']
|
| 992 |
+
else:
|
| 993 |
+
embeddings, labels = get_embeddings(
|
| 994 |
+
net,
|
| 995 |
+
train_loader,
|
| 996 |
+
num_dim=model_to_num_dim[args.model],
|
| 997 |
+
dtype=torch.double,
|
| 998 |
+
device=device,
|
| 999 |
+
storage_device=device,
|
| 1000 |
+
)
|
| 1001 |
+
torch.save({'embeddings': embeddings, 'labels': labels}, embed_path)
|
| 1002 |
+
else:
|
| 1003 |
+
embeddings, labels = get_embeddings(
|
| 1004 |
+
net,
|
| 1005 |
+
train_loader,
|
| 1006 |
+
num_dim=model_to_num_dim[args.model],
|
| 1007 |
+
dtype=torch.double,
|
| 1008 |
+
device=device,
|
| 1009 |
+
storage_device=device,
|
| 1010 |
+
)
|
| 1011 |
+
|
| 1012 |
+
parts= model_to_last_layer[args.model].split('.')
|
| 1013 |
+
net_last_layer = net
|
| 1014 |
+
|
| 1015 |
+
for attr in parts:
|
| 1016 |
+
net_last_layer = getattr(net_last_layer, attr)
|
| 1017 |
+
|
| 1018 |
+
SPC_model=SPC_fit(net_last_layer, topt, embeddings, labels, model_to_num_dim[args.model], num_classes, device)
|
| 1019 |
+
torch.save(SPC_model.state_dict(), spc_model_path)
|
| 1020 |
+
print(f"Model saved at {spc_model_path}")
|
| 1021 |
+
|
| 1022 |
+
logits,mars=SPC_evaluate(net, SPC_model, test_loader, model_to_num_dim[args.model], num_classes, device)
|
| 1023 |
+
|
| 1024 |
+
soft = torch.nn.functional.softmax(logits, dim=1)
|
| 1025 |
+
delta = torch.min(torch.min(mars[2] - mars[3], mars[1] - 2 * mars[3]), 2 * mars[2] - mars[1])
|
| 1026 |
+
# delta=mars[2] - mars[3]
|
| 1027 |
+
|
| 1028 |
+
uncertainty = abs(mars[2]+mars[3] - mars[1])
|
| 1029 |
+
|
| 1030 |
+
threshold=0.05
|
| 1031 |
+
mask=(uncertainty<threshold).float()
|
| 1032 |
+
delta = delta*mask
|
| 1033 |
+
softmax_prob = soft + delta
|
| 1034 |
+
print(torch.sum(softmax_prob, dim=1))
|
| 1035 |
+
|
| 1036 |
+
c_confidences, c_predictions = torch.max(softmax_prob, dim=1)
|
| 1037 |
+
c_predictions=c_predictions.tolist()
|
| 1038 |
+
c_confidences=c_confidences.tolist()
|
| 1039 |
+
c_ece = expected_calibration_error(c_confidences, c_predictions, labels_list, num_bins=15)
|
| 1040 |
+
print('ece', ece, 't_ece', t_ece, 'c_ece', c_ece)
|
| 1041 |
+
c_accuracy=accuracy_score(labels_list, c_predictions)
|
| 1042 |
+
print('accuracy',accuracy,'t_accuracy',t_accuracy,'c_accuracy',c_accuracy)
|
| 1043 |
+
|
| 1044 |
+
ood_logits,ood_mars=SPC_evaluate(net, SPC_model, ood_test_loader, model_to_num_dim[args.model], num_classes, device)
|
| 1045 |
+
(_, _, _), (_, _, _), ood_m1_auroc, ood_m1_auprc = get_roc_auc_logits(mars, ood_mars, self_consistency, device)
|
| 1046 |
+
(_, _, _), (_, _, _), ood_m2_auroc, ood_m2_auprc = get_roc_auc_logits(logits, ood_logits, entropy, device)
|
| 1047 |
+
|
| 1048 |
+
labels_array = np.array(labels_list)
|
| 1049 |
+
pred_array = np.array(predictions)
|
| 1050 |
+
correct_mask = labels_array == pred_array
|
| 1051 |
+
mars_right = [m[correct_mask] for m in mars]
|
| 1052 |
+
mars_wrong = [m[~correct_mask] for m in mars]
|
| 1053 |
+
(_, _, _), (_, _, _), err_m1_auroc, err_m1_auprc = get_roc_auc_logits(mars_right, mars_wrong, self_consistency, device)
|
| 1054 |
+
(_, _, _), (_, _, _), err_m2_auroc, err_m2_auprc = get_roc_auc_logits(mars_right[0], mars_wrong[0], entropy, device)
|
| 1055 |
+
|
| 1056 |
+
if args.dataset == 'imagenet':
|
| 1057 |
+
adv_test_loader = dataset_loader['imagenet_a'].get_test_loader(batch_size=args.batch_size,
|
| 1058 |
+
imagesize=model_to_input_dim[args.model],
|
| 1059 |
+
pin_memory=args.gpu)
|
| 1060 |
+
else:
|
| 1061 |
+
adv_test_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=adv_ep,batch_size=args.batch_size)
|
| 1062 |
+
(adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions,adv_confidences,) = test_classification_net(net, adv_test_loader, device)
|
| 1063 |
+
adv_logits, adv_mars = SPC_evaluate(net, SPC_model, adv_test_loader,model_to_num_dim[args.model], num_classes, device)
|
| 1064 |
+
(_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc_logits(mars, adv_mars,self_consistency, device)
|
| 1065 |
+
(_, _, _), (_, _, _), adv_m2_auroc, adv_m2_auprc = get_roc_auc_logits(logits, adv_logits, entropy, device)
|
| 1066 |
+
|
| 1067 |
+
if args.sample_noise:
|
| 1068 |
+
adv_eps = np.linspace(0, 0.4, 9)
|
| 1069 |
+
print(adv_eps)
|
| 1070 |
+
for idx_ep, ep in enumerate(adv_eps):
|
| 1071 |
+
adv_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=ep,
|
| 1072 |
+
batch_size=args.batch_size)
|
| 1073 |
+
(adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions,
|
| 1074 |
+
adv_confidences,) = test_classification_net(net, adv_loader, device)
|
| 1075 |
+
adv_logits, adv_mars=SPC_evaluate(net, SPC_model, adv_loader, model_to_num_dim[args.model], num_classes, device)
|
| 1076 |
+
uncertainties = self_consistency(adv_mars).detach().cpu().numpy()
|
| 1077 |
+
quantiles = np.quantile(uncertainties, np.linspace(0, 1, 10))
|
| 1078 |
+
quantiles = np.delete(quantiles, 0)
|
| 1079 |
+
unc_list = []
|
| 1080 |
+
accuracy_list = []
|
| 1081 |
+
for threshold in quantiles:
|
| 1082 |
+
cer_indices = (uncertainties < threshold)
|
| 1083 |
+
unc_indices = ~cer_indices
|
| 1084 |
+
labels_list = np.array(adv_labels_list)
|
| 1085 |
+
targets_cer = labels_list[cer_indices]
|
| 1086 |
+
predictions = np.array(adv_predictions)
|
| 1087 |
+
pred_cer = predictions[cer_indices]
|
| 1088 |
+
targets_unc = labels_list[unc_indices]
|
| 1089 |
+
pred_unc = predictions[unc_indices]
|
| 1090 |
+
cer_right = np.sum(targets_cer == pred_cer)
|
| 1091 |
+
cer = len(targets_cer)
|
| 1092 |
+
unc_right = np.sum(targets_unc == pred_unc)
|
| 1093 |
+
unc = len(targets_unc)
|
| 1094 |
+
accuracy_cer = cer_right / cer
|
| 1095 |
+
accuracy_unc = unc_right / unc
|
| 1096 |
+
unc_list.append(threshold)
|
| 1097 |
+
accuracy_list.append(accuracy_cer)
|
| 1098 |
+
print('ACC:', accuracy_cer, accuracy_unc)
|
| 1099 |
+
from scipy.stats import spearmanr
|
| 1100 |
+
Spearman_acc, p_acc = spearmanr(unc_list, accuracy_list)
|
| 1101 |
+
print("Spearman correlation:", Spearman_acc, "mean uncertainties:", uncertainties.mean())
|
| 1102 |
+
adv_unc[i][idx_ep] = uncertainties.mean()
|
| 1103 |
+
adv_acc[i][idx_ep] = adv_accuracy
|
| 1104 |
+
|
| 1105 |
+
(_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc_logits(mars, adv_mars, self_consistency,device)
|
| 1106 |
+
(_, _, _), (_, _, _), adv_m2_auroc, adv_m2_auprc = get_roc_auc_logits(logits, adv_logits, entropy,device)
|
| 1107 |
+
print('adv_m1_auroc,adv_m2_auroc',adv_m1_auroc,adv_m2_auroc)
|
| 1108 |
+
|
| 1109 |
+
|
| 1110 |
+
t_m1_auroc = ood_m1_auroc
|
| 1111 |
+
t_m1_auprc = ood_m1_auprc
|
| 1112 |
+
t_m2_auroc = ood_m2_auroc
|
| 1113 |
+
t_m2_auprc = ood_m2_auprc
|
| 1114 |
+
|
| 1115 |
+
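        # Fallback: plain softmax baseline. Measure 1 is predictive entropy and
        # measure 2 is logsumexp of the logits (an energy-style confidence).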
        else:
            # Evaluate a normal Softmax model
            print("Softmax Model")
            (_, _, _), (_, _, _), ood_m1_auroc, ood_m1_auprc = get_roc_auc(net, test_loader, ood_test_loader, entropy, device)
            (_, _, _), (_, _, _), ood_m2_auroc, ood_m2_auprc = get_roc_auc(net, test_loader, ood_test_loader, logsumexp, device, confidence=True)

            (_, _, _), (_, _, _), t_m1_auroc, t_m1_auprc = get_roc_auc(temp_scaled_net, test_loader, ood_test_loader, entropy, device)
            (_, _, _), (_, _, _), t_m2_auroc, t_m2_auprc = get_roc_auc(temp_scaled_net, test_loader, ood_test_loader, logsumexp, device, confidence=True)

            labels_array = np.array(labels_list)
            pred_array = np.array(predictions)
            correct_mask = labels_array == pred_array
            logits, _ = get_logits_labels(net, test_loader, device)
            logits_right = logits[correct_mask]
            logits_wrong = logits[~correct_mask]
            (_, _, _), (_, _, _), err_m1_auroc, err_m1_auprc = get_roc_auc_logits(logits_right, logits_wrong, entropy, device)
            (_, _, _), (_, _, _), err_m2_auroc, err_m2_auprc = get_roc_auc_logits(logits_right, logits_wrong, logsumexp, device, confidence=True)

            if args.dataset == 'imagenet':
                adv_test_loader = dataset_loader['imagenet_a'].get_test_loader(batch_size=args.batch_size,
                                                                               imagesize=model_to_input_dim[args.model],
                                                                               pin_memory=args.gpu)
            else:
                adv_test_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=adv_ep, batch_size=args.batch_size)
            (adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions, adv_confidences,) = test_classification_net(net, adv_test_loader, device)
            adv_logits, _ = get_logits_labels(net, adv_test_loader, device)
            (_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc(net, test_loader, adv_test_loader, entropy, device)
            (_, _, _), (_, _, _), adv_m2_auroc, adv_m2_auprc = get_roc_auc(net, test_loader, adv_test_loader, logsumexp, device, confidence=True)

            if args.sample_noise:
                adv_eps = np.linspace(0, 0.4, 9)
                for idx_ep, ep in enumerate(adv_eps):
                    adv_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=ep,
                                                               batch_size=args.batch_size)
                    (adv_conf_matrix, adv_accuracy, adv_labels_list, adv_predictions,
                     adv_confidences,) = test_classification_net(net, adv_loader, device)
                    adv_logits, _ = get_logits_labels(net, adv_loader, device)
                    uncertainties = entropy(adv_logits).detach().cpu().numpy()
                    quantiles = np.quantile(uncertainties, np.linspace(0, 1, 10))
                    quantiles = np.delete(quantiles, 0)
                    unc_list = []
                    accuracy_list = []
                    for threshold in quantiles:
                        cer_indices = (uncertainties < threshold)
                        unc_indices = ~cer_indices
                        labels_list = np.array(adv_labels_list)
                        targets_cer = labels_list[cer_indices]
                        predictions = np.array(adv_predictions)
                        pred_cer = predictions[cer_indices]
                        targets_unc = labels_list[unc_indices]
                        pred_unc = predictions[unc_indices]
                        cer_right = np.sum(targets_cer == pred_cer)
                        cer = len(targets_cer)
                        unc_right = np.sum(targets_unc == pred_unc)
                        unc = len(targets_unc)
                        accuracy_cer = cer_right / cer
                        accuracy_unc = unc_right / unc
                        unc_list.append(threshold)
                        accuracy_list.append(accuracy_cer)
                        print('ACC:', accuracy_cer, accuracy_unc)
                    from scipy.stats import spearmanr

                    Spearman_acc, p_acc = spearmanr(unc_list, accuracy_list)
                    print("Spearman correlation:", Spearman_acc, "mean uncertainties:", uncertainties.mean())
                    adv_unc[i][idx_ep] = uncertainties.mean()
                    adv_acc[i][idx_ep] = adv_accuracy

                    (_, _, _), (_, _, _), adv_m1_auroc, adv_m1_auprc = get_roc_auc(net, test_loader, adv_loader, entropy, device)
                    (_, _, _), (_, _, _), adv_m2_auroc, adv_m2_auprc = get_roc_auc(net, test_loader, adv_loader, logsumexp, device, confidence=True)
                    print('adv_m1_auroc,adv_m2_auroc', adv_m1_auroc, adv_m2_auroc)

    accuracies.append(accuracy)
    if (args.model_type == "spc" or args.model_type == "joint"):
        c_accuracies.append(c_accuracy)
    else:
        c_accuracies.append(t_accuracy)

    # Pre-temperature results
    eces.append(ece)
    ood_m1_aurocs.append(ood_m1_auroc)
    ood_m1_auprcs.append(ood_m1_auprc)
    ood_m2_aurocs.append(ood_m2_auroc)
    ood_m2_auprcs.append(ood_m2_auprc)

    err_m1_aurocs.append(err_m1_auroc)
    err_m1_auprcs.append(err_m1_auprc)
    err_m2_aurocs.append(err_m2_auroc)
    err_m2_auprcs.append(err_m2_auprc)

    adv_m1_aurocs.append(adv_m1_auroc)
    adv_m1_auprcs.append(adv_m1_auprc)
    adv_m2_aurocs.append(adv_m2_auroc)
    adv_m2_auprcs.append(adv_m2_auprc)

    # Post-temperature results
    t_eces.append(t_ece)
    t_m1_aurocs.append(t_m1_auroc)
    t_m1_auprcs.append(t_m1_auprc)
    t_m2_aurocs.append(t_m2_auroc)
    t_m2_auprcs.append(t_m2_auprc)

    if (args.model_type == "spc" or args.model_type == "joint"):
        c_eces.append(c_ece)

    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

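# Noise-sweep summary across runs: min-max normalise each run's mean-uncertainty
# curve, then plot mean ± one standard deviation over runs against the noise level.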
if args.sample_noise:
    adv_unc_norm = (adv_unc - adv_unc.min(axis=1, keepdims=True)) / \
                   (adv_unc.max(axis=1, keepdims=True) - adv_unc.min(axis=1, keepdims=True) + 1e-8)
    mean_unc = np.mean(adv_unc_norm, axis=0)
    std_unc = np.std(adv_unc_norm, axis=0)

    plt.figure(figsize=(10, 6))
    plt.plot(mean_unc, label='Uncertainty', color='orange')
    plt.fill_between(range(len(mean_unc)),
                     mean_unc - std_unc,
                     mean_unc + std_unc,
                     color='orange', alpha=0.3, label="±1 Std Dev")
    # plt.legend()
    # plt.title('Uncertainty Across Multiple Runs')
    plt.xlabel('Noise')
    plt.ylabel('Uncertainty')
    plt.savefig("adv_unc_"
                + model_save_name(args.model, args.sn, args.mod, args.coeff, args.seed)
                + "_"
                + args.model_type
                + "_"
                + args.dataset
                + ".png")
    plt.show()
    plt.close()

    mean_acc = np.mean(adv_acc, axis=0)
    std_acc = np.std(adv_acc, axis=0)

    plt.figure(figsize=(10, 6))
    plt.plot(mean_acc, label='Accuracy', color='red')
    plt.fill_between(range(len(mean_acc)),
                     mean_acc - std_acc,
                     mean_acc + std_acc,
                     color='red', alpha=0.3, label="±1 Std Dev")
    # plt.legend()
    # plt.title('Accuracy Across Multiple Runs')
    plt.xlabel('Noise')
    plt.ylabel('Accuracy')
    plt.savefig("adv_acc_"
                + model_save_name(args.model, args.sn, args.mod, args.coeff, args.seed)
                + "_"
                + args.model_type
                + "_"
                + args.dataset
                + ".png")
    plt.show()
    plt.close()

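# Aggregate over runs: report means with standard errors (std / sqrt(n_runs)) and dump
# everything, including the raw per-run values and the argument namespace, to JSON.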
save_dir = "curve_data"
|
| 1276 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 1277 |
+
|
| 1278 |
+
if args.sn:
|
| 1279 |
+
prefix = f"{args.dataset}_{args.model_type}_{args.model}_SN"
|
| 1280 |
+
else:
|
| 1281 |
+
prefix = f"{args.dataset}_{args.model_type}_{args.model}"
|
| 1282 |
+
|
| 1283 |
+
|
| 1284 |
+
accuracy_tensor = torch.tensor(accuracies)
|
| 1285 |
+
c_accuracy_tensor = torch.tensor(c_accuracies)
|
| 1286 |
+
ece_tensor = torch.tensor(eces)
|
| 1287 |
+
ood_m1_auroc_tensor = torch.tensor(ood_m1_aurocs)
|
| 1288 |
+
m1_auprc_tensor = torch.tensor(ood_m1_auprcs)
|
| 1289 |
+
ood_m2_auroc_tensor = torch.tensor(ood_m2_aurocs)
|
| 1290 |
+
ood_m2_auprc_tensor = torch.tensor(ood_m2_auprcs)
|
| 1291 |
+
|
| 1292 |
+
err_m1_auroc_tensor = torch.tensor(err_m1_aurocs)
|
| 1293 |
+
err_m1_auprc_tensor = torch.tensor(err_m1_auprcs)
|
| 1294 |
+
err_m2_auroc_tensor = torch.tensor(err_m2_aurocs)
|
| 1295 |
+
err_m2_auprc_tensor = torch.tensor(err_m2_auprcs)
|
| 1296 |
+
|
| 1297 |
+
adv_m1_auroc_tensor = torch.tensor(adv_m1_aurocs)
|
| 1298 |
+
adv_m1_auprc_tensor = torch.tensor(adv_m1_auprcs)
|
| 1299 |
+
adv_m2_auroc_tensor = torch.tensor(adv_m2_aurocs)
|
| 1300 |
+
adv_m2_auprc_tensor = torch.tensor(adv_m2_auprcs)
|
| 1301 |
+
|
| 1302 |
+
t_ece_tensor = torch.tensor(t_eces)
|
| 1303 |
+
t_m1_auroc_tensor = torch.tensor(t_m1_aurocs)
|
| 1304 |
+
t_m1_auprc_tensor = torch.tensor(t_m1_auprcs)
|
| 1305 |
+
t_m2_auroc_tensor = torch.tensor(t_m2_aurocs)
|
| 1306 |
+
t_m2_auprc_tensor = torch.tensor(t_m2_auprcs)
|
| 1307 |
+
|
| 1308 |
+
c_ece_tensor = torch.tensor(c_eces)
|
| 1309 |
+
|
| 1310 |
+
mean_accuracy = torch.mean(accuracy_tensor)
|
| 1311 |
+
mean_c_accuracy = torch.mean(c_accuracy_tensor)
|
| 1312 |
+
mean_ece = torch.mean(ece_tensor)
|
| 1313 |
+
mean_ood_m1_auroc = torch.mean(ood_m1_auroc_tensor)
|
| 1314 |
+
mean_m1_auprc = torch.mean(m1_auprc_tensor)
|
| 1315 |
+
mean_m2_auroc = torch.mean(ood_m2_auroc_tensor)
|
| 1316 |
+
mean_m2_auprc = torch.mean(ood_m2_auprc_tensor)
|
| 1317 |
+
|
| 1318 |
+
mean_err_m1_auroc = torch.mean(err_m1_auroc_tensor)
|
| 1319 |
+
mean_err_m1_auprc = torch.mean(err_m1_auprc_tensor)
|
| 1320 |
+
mean_err_m2_auroc = torch.mean(err_m2_auroc_tensor)
|
| 1321 |
+
mean_err_m2_auprc = torch.mean(err_m2_auprc_tensor)
|
| 1322 |
+
|
| 1323 |
+
mean_adv_m1_auroc = torch.mean(adv_m1_auroc_tensor)
|
| 1324 |
+
mean_adv_m1_auprc = torch.mean(adv_m1_auprc_tensor)
|
| 1325 |
+
mean_adv_m2_auroc = torch.mean(adv_m2_auroc_tensor)
|
| 1326 |
+
mean_adv_m2_auprc = torch.mean(adv_m2_auprc_tensor)
|
| 1327 |
+
|
| 1328 |
+
mean_t_ece = torch.mean(t_ece_tensor)
|
| 1329 |
+
mean_t_m1_auroc = torch.mean(t_m1_auroc_tensor)
|
| 1330 |
+
mean_t_m1_auprc = torch.mean(t_m1_auprc_tensor)
|
| 1331 |
+
mean_t_m2_auroc = torch.mean(t_m2_auroc_tensor)
|
| 1332 |
+
mean_t_m2_auprc = torch.mean(t_m2_auprc_tensor)
|
| 1333 |
+
|
| 1334 |
+
mean_c_ece = torch.mean(c_ece_tensor)
|
| 1335 |
+
|
| 1336 |
+
std_accuracy = torch.std(accuracy_tensor) / math.sqrt(accuracy_tensor.shape[0])
|
| 1337 |
+
std_c_accuracy = torch.std(c_accuracy_tensor) / math.sqrt(c_accuracy_tensor.shape[0])
|
| 1338 |
+
std_ece = torch.std(ece_tensor) / math.sqrt(ece_tensor.shape[0])
|
| 1339 |
+
std_ood_m1_auroc = torch.std(ood_m1_auroc_tensor) / math.sqrt(ood_m1_auroc_tensor.shape[0])
|
| 1340 |
+
std_m1_auprc = torch.std(m1_auprc_tensor) / math.sqrt(m1_auprc_tensor.shape[0])
|
| 1341 |
+
std_m2_auroc = torch.std(ood_m2_auroc_tensor) / math.sqrt(ood_m2_auroc_tensor.shape[0])
|
| 1342 |
+
std_m2_auprc = torch.std(ood_m2_auprc_tensor) / math.sqrt(ood_m2_auprc_tensor.shape[0])
|
| 1343 |
+
std_err_m1_auroc = torch.std(err_m1_auroc_tensor) / math.sqrt(err_m1_auroc_tensor.shape[0])
|
| 1344 |
+
std_err_m1_auprc = torch.std(err_m1_auprc_tensor) / math.sqrt(err_m1_auprc_tensor.shape[0])
|
| 1345 |
+
std_err_m2_auroc = torch.std(err_m2_auroc_tensor) / math.sqrt(err_m2_auroc_tensor.shape[0])
|
| 1346 |
+
std_err_m2_auprc = torch.std(err_m2_auprc_tensor) / math.sqrt(err_m2_auprc_tensor.shape[0])
|
| 1347 |
+
std_adv_m1_auroc = torch.std(adv_m1_auroc_tensor) / math.sqrt(adv_m1_auroc_tensor.shape[0])
|
| 1348 |
+
std_adv_m1_auprc = torch.std(adv_m1_auprc_tensor) / math.sqrt(adv_m1_auprc_tensor.shape[0])
|
| 1349 |
+
std_adv_m2_auroc = torch.std(adv_m2_auroc_tensor) / math.sqrt(adv_m2_auroc_tensor.shape[0])
|
| 1350 |
+
std_adv_m2_auprc = torch.std(adv_m2_auprc_tensor) / math.sqrt(adv_m2_auprc_tensor.shape[0])
|
| 1351 |
+
|
| 1352 |
+
std_t_ece = torch.std(t_ece_tensor) / math.sqrt(t_ece_tensor.shape[0])
|
| 1353 |
+
std_t_m1_auroc = torch.std(t_m1_auroc_tensor) / math.sqrt(t_m1_auroc_tensor.shape[0])
|
| 1354 |
+
std_t_m1_auprc = torch.std(t_m1_auprc_tensor) / math.sqrt(t_m1_auprc_tensor.shape[0])
|
| 1355 |
+
std_t_m2_auroc = torch.std(t_m2_auroc_tensor) / math.sqrt(t_m2_auroc_tensor.shape[0])
|
| 1356 |
+
std_t_m2_auprc = torch.std(t_m2_auprc_tensor) / math.sqrt(t_m2_auprc_tensor.shape[0])
|
| 1357 |
+
|
| 1358 |
+
std_c_ece = torch.std(c_ece_tensor) / math.sqrt(c_ece_tensor.shape[0])
|
| 1359 |
+
|
| 1360 |
+
res_dict = {}
|
| 1361 |
+
res_dict["mean"] = {}
|
| 1362 |
+
res_dict["mean"]["accuracy"] = mean_accuracy.item()
|
| 1363 |
+
res_dict["mean"]["ece"] = mean_ece.item()
|
| 1364 |
+
res_dict["mean"]["ood_m1_auroc"] = mean_ood_m1_auroc.item()
|
| 1365 |
+
res_dict["mean"]["ood_m1_auprc"] = mean_m1_auprc.item()
|
| 1366 |
+
res_dict["mean"]["ood_m2_auroc"] = mean_m2_auroc.item()
|
| 1367 |
+
res_dict["mean"]["ood_m2_auprc"] = mean_m2_auprc.item()
|
| 1368 |
+
res_dict["mean"]["t_ece"] = mean_t_ece.item()
|
| 1369 |
+
res_dict["mean"]["t_m1_auroc"] = mean_t_m1_auroc.item()
|
| 1370 |
+
res_dict["mean"]["t_m1_auprc"] = mean_t_m1_auprc.item()
|
| 1371 |
+
res_dict["mean"]["t_m2_auroc"] = mean_t_m2_auroc.item()
|
| 1372 |
+
res_dict["mean"]["t_m2_auprc"] = mean_t_m2_auprc.item()
|
| 1373 |
+
res_dict["mean"]["c_ece"] = mean_c_ece.item()
|
| 1374 |
+
|
| 1375 |
+
res_dict["std"] = {}
|
| 1376 |
+
res_dict["std"]["accuracy"] = std_accuracy.item()
|
| 1377 |
+
res_dict["std"]["ece"] = std_ece.item()
|
| 1378 |
+
res_dict["std"]["ood_m1_auroc"] = std_ood_m1_auroc.item()
|
| 1379 |
+
res_dict["std"]["ood_m1_auprc"] = std_m1_auprc.item()
|
| 1380 |
+
res_dict["std"]["ood_m2_auroc"] = std_m2_auroc.item()
|
| 1381 |
+
res_dict["std"]["ood_m2_auprc"] = std_m2_auprc.item()
|
| 1382 |
+
res_dict["std"]["t_ece"] = std_t_ece.item()
|
| 1383 |
+
res_dict["std"]["t_m1_auroc"] = std_t_m1_auroc.item()
|
| 1384 |
+
res_dict["std"]["t_m1_auprc"] = std_t_m1_auprc.item()
|
| 1385 |
+
res_dict["std"]["t_m2_auroc"] = std_t_m2_auroc.item()
|
| 1386 |
+
res_dict["std"]["t_m2_auprc"] = std_t_m2_auprc.item()
|
| 1387 |
+
res_dict["std"]["c_ece"] = std_c_ece.item()
|
| 1388 |
+
|
| 1389 |
+
res_dict["values"] = {}
|
| 1390 |
+
res_dict["values"]["accuracy"] = accuracies
|
| 1391 |
+
res_dict["values"]["ece"] = eces
|
| 1392 |
+
res_dict["values"]["ood_m1_auroc"] = ood_m1_aurocs
|
| 1393 |
+
res_dict["values"]["ood_m1_auprc"] = ood_m1_auprcs
|
| 1394 |
+
res_dict["values"]["ood_m2_auroc"] = ood_m2_aurocs
|
| 1395 |
+
res_dict["values"]["ood_m2_auprc"] = ood_m2_auprcs
|
| 1396 |
+
res_dict["values"]["t_ece"] = t_eces
|
| 1397 |
+
res_dict["values"]["t_m1_auroc"] = t_m1_aurocs
|
| 1398 |
+
res_dict["values"]["t_m1_auprc"] = t_m1_auprcs
|
| 1399 |
+
res_dict["values"]["t_m2_auroc"] = t_m2_aurocs
|
| 1400 |
+
res_dict["values"]["t_m2_auprc"] = t_m2_auprcs
|
| 1401 |
+
res_dict["values"]["c_ece"] = c_eces
|
| 1402 |
+
|
| 1403 |
+
res_dict["info"] = vars(args)
|
| 1404 |
+
|
| 1405 |
+
print(f"{mean_accuracy.item() * 100:.2f} ± {std_accuracy.item() * 100:.2f}")
|
| 1406 |
+
print(f"{mean_c_accuracy.item() * 100:.2f} ± {std_c_accuracy.item() * 100:.2f}")
|
| 1407 |
+
print(f"{mean_ece.item()*100:.2f} ± {std_ece.item()*100:.2f}")
|
| 1408 |
+
print(f"{mean_t_ece.item()*100:.2f} ± {std_t_ece.item()*100:.2f}")
|
| 1409 |
+
print(f"{mean_c_ece.item() * 100:.2f} ± {std_c_ece.item() * 100:.2f}")
|
| 1410 |
+
print(f"{mean_adv_m1_auroc.item()*100:.2f} ± {std_adv_m1_auroc.item()*100:.2f}")
|
| 1411 |
+
print(f"{mean_err_m1_auroc.item()*100:.2f} ± {std_err_m1_auroc.item()*100:.2f}")
|
| 1412 |
+
print(f"{mean_ood_m1_auroc.item()*100:.2f} ± {std_ood_m1_auroc.item()*100:.2f}")
|
| 1413 |
+
|
| 1414 |
+
|
| 1415 |
+
with open(
|
| 1416 |
+
"res_"
|
| 1417 |
+
+ model_save_name(args.model, args.sn, args.mod, args.coeff, args.seed)
|
| 1418 |
+
+ "_"
|
| 1419 |
+
+ args.model_type
|
| 1420 |
+
+ "_"
|
| 1421 |
+
+ args.dataset
|
| 1422 |
+
+ "_"
|
| 1423 |
+
+ args.ood_dataset
|
| 1424 |
+
+ ".json",
|
| 1425 |
+
"w",
|
| 1426 |
+
) as f:
|
| 1427 |
+
json.dump(res_dict, f)
|
SPC-UQ/Image_Classification/evaluate_laplace.py
ADDED
@@ -0,0 +1,355 @@
"""
Script to evaluate the Laplace Approximation.
"""

import os
import json
import math
import torch
import torch.nn.functional as F
import torch.nn as nn
import argparse
import torch.backends.cudnn as cudnn
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

# Import data loaders and networks
import data.ood_detection.cifar10 as cifar10
import data.ood_detection.cifar100 as cifar100
import data.ood_detection.svhn as svhn
import data.ood_detection.imagenet as imagenet
import data.ood_detection.tinyimagenet as tinyimagenet
import data.ood_detection.imagenet_o as imagenet_o
import data.ood_detection.imagenet_a as imagenet_a
import data.ood_detection.ood_union as ood_union

from net.resnet import resnet50
from net.resnet_edl import resnet50_edl
from net.wide_resnet import wrn
from net.wide_resnet_edl import wrn_edl
from net.vgg import vgg16
from net.vgg_edl import vgg16_edl
from net.imagenet_wide import imagenet_wide
from net.imagenet_vgg import imagenet_vgg16
from net.imagenet_vit import imagenet_vit

from metrics.classification_metrics import (
    test_classification_net,
    test_classification_net_logits,
    test_classification_net_ensemble,
    test_classification_net_edl,
    create_adversarial_dataloader
)
from metrics.calibration_metrics import expected_calibration_error

from utils.gmm_utils import get_embeddings, gmm_evaluate, gmm_fit
from utils.ensemble_utils import load_ensemble, ensemble_forward_pass
from utils.eval_utils import model_load_name
from utils.train_utils import model_save_name
from utils.args import laplace_eval_args

from laplace import Laplace
from sklearn import metrics as M
from laplace.curvature import AsdlGGN, AsdlEF, BackPackGGN, BackPackEF

import warnings
warnings.filterwarnings('ignore')

# Dataset mapping and config
DATASET_NUM_CLASSES = {
    "cifar10": 10, "cifar100": 100, "svhn": 10, "imagenet": 1000,
    "tinyimagenet": 200, "imagenet_o": 200, "imagenet_a": 200
}

DATASET_LOADER = {
    "cifar10": cifar10, "cifar100": cifar100, "svhn": svhn, "imagenet": imagenet,
    "tinyimagenet": tinyimagenet, "imagenet_o": imagenet_o, "imagenet_a": imagenet_a
}

MODELS = {
    "resnet50": resnet50, "resnet50_edl": resnet50_edl,
    "wide_resnet": wrn, "wide_resnet_edl": wrn_edl,
    "vgg16": vgg16, "vgg16_edl": vgg16_edl,
    "imagenet_wide": imagenet_wide
}

MODEL_TO_NUM_DIM = {
    "resnet50": 2048, "resnet50_edl": 2048, "wide_resnet": 640, "wide_resnet_edl": 640,
    "vgg16": 512, "vgg16_edl": 512, "imagenet_wide": 2048,
    "imagenet_vgg16": 4096, "imagenet_vit": 768
}

MODEL_TO_INPUT_DIM = {
    "resnet50": 32, "resnet50_edl": 32, "wide_resnet": 32, "wide_resnet_edl": 32,
    "vgg16": 32, "vgg16_edl": 32, "imagenet_wide": 224,
    "imagenet_vgg16": 224, "imagenet_vit": 224
}

MODEL_TO_LAST_LAYER = {
    "resnet50": "module.fc", "wide_resnet": "module.linear", "vgg16": "module.classifier",
    "imagenet_wide": "module.linear",
    "imagenet_vgg16": "module.classifier", "imagenet_vit": "module.linear"
}

def get_backend(backend, approx_type):
    if backend == 'kazuki':
        return AsdlGGN if approx_type == 'ggn' else AsdlEF
    elif backend == 'backpack':
        return BackPackGGN if approx_type == 'ggn' else BackPackEF
    else:
        raise ValueError(f"Unknown backend: {backend}")

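# get_backend above resolves the curvature implementation, e.g. (a usage sketch):
#   Backend = get_backend('kazuki', 'ggn')  # -> AsdlGGN
# The class itself is passed unchanged to Laplace(..., backend=Backend) in the
# main block below.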
def get_cpu_memory_mb():
    import psutil
    process = psutil.Process(os.getpid())
    mem_info = process.memory_info()
    return mem_info.rss / 1024 ** 2

def print_metrics(mean, std, name):
    print(f"{name}: {mean * 100:.2f} ± {std * 100:.2f}")

if __name__ == "__main__":
|
| 113 |
+
|
| 114 |
+
args = laplace_eval_args().parse_args()
|
| 115 |
+
|
| 116 |
+
# Set random seed and device
|
| 117 |
+
torch.manual_seed(args.seed)
|
| 118 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 119 |
+
print("Parsed args:", args)
|
| 120 |
+
print("Seed:", args.seed)
|
| 121 |
+
|
| 122 |
+
num_classes = DATASET_NUM_CLASSES[args.dataset]
|
| 123 |
+
test_loader = DATASET_LOADER[args.dataset].get_test_loader(
|
| 124 |
+
batch_size=args.batch_size, imagesize=MODEL_TO_INPUT_DIM[args.model], pin_memory=args.gpu
|
| 125 |
+
)
|
| 126 |
+
if args.ood_dataset == 'ood_union':
|
| 127 |
+
ood_test_loader = ood_union.get_combined_ood_test_loader(
|
| 128 |
+
batch_size=args.batch_size, sample_seed=args.seed,
|
| 129 |
+
imagesize=MODEL_TO_INPUT_DIM[args.model], pin_memory=args.gpu
|
| 130 |
+
)
|
| 131 |
+
else:
|
| 132 |
+
ood_test_loader = DATASET_LOADER[args.ood_dataset].get_test_loader(
|
| 133 |
+
batch_size=args.batch_size, imagesize=MODEL_TO_INPUT_DIM[args.model], pin_memory=args.gpu
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
# Prepare metric accumulators
|
| 137 |
+
accuracies, eces, ood_aurocs, err_aurocs, adv_aurocs = [], [], [], [], []
|
| 138 |
+
err_aurocs, adv_aurocs = [], []
|
| 139 |
+
adv_unc = np.zeros((args.runs, 9))
|
| 140 |
+
adv_acc = np.zeros((args.runs, 9))
|
| 141 |
+
adv_ep = 0.02
|
| 142 |
+
|
| 143 |
+
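    # Each run below re-splits the validation set, reloads the backbone, refits the
    # Laplace posterior, and recomputes all metrics, so per-run numbers are independent.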
    for i in range(args.runs):
        # Load training/validation splits
        train_loader, val_loader = DATASET_LOADER[args.dataset].get_train_valid_loader(
            batch_size=args.batch_size, imagesize=MODEL_TO_INPUT_DIM[args.model], augment=args.data_aug,
            val_seed=(args.seed + i), val_size=args.val_size, pin_memory=args.gpu
        )
        mixture_components = []
        for model_idx in range(args.nr_components):
            if args.dataset == 'imagenet':
                net = MODELS[args.model](pretrained=True, num_classes=1000).cuda()
                net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
                cudnn.benchmark = True
            else:
                if args.val_size == 0.1 or not args.crossval:
                    saved_model_name = os.path.join(
                        args.load_loc, f"Run{i+1}",
                        model_load_name(args.model, args.sn, args.mod, args.coeff, args.seed, i) + "_350.model",
                    )
                else:
                    saved_model_name = os.path.join(
                        args.load_loc, f"Run{i+1}",
                        model_load_name(args.model, args.sn, args.mod, args.coeff, args.seed, i)
                        + f"_350_0{int(args.val_size * 10)}.model"
                    )
                print('Loading:', saved_model_name)
                net = MODELS[args.model](
                    spectral_normalization=args.sn, mod=args.mod, coeff=args.coeff, num_classes=num_classes, temp=1.0
                )
                if args.gpu:
                    net.cuda()
                    net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
                    cudnn.benchmark = True
                net.load_state_dict(torch.load(str(saved_model_name)))

            # Laplace backend and fit
            args.prior_precision = 1.0 if isinstance(args.prior_precision, float) else torch.load(args.prior_precision, map_location=device)
            Backend = get_backend(args.backend, args.approx_type)
            args.last_layer_name = MODEL_TO_LAST_LAYER[args.model]
            optional_args = {"last_layer_name": args.last_layer_name} if args.subset_of_weights == 'last_layer' else {}

            print('Fitting Laplace approximation...')
            model = Laplace(
                net, args.likelihood, subset_of_weights=args.subset_of_weights,
                hessian_structure=args.hessian_structure, prior_precision=args.prior_precision,
                temperature=args.temperature, backend=Backend, **optional_args
            )
            model.fit(val_loader if args.crossval else train_loader)

            # Optional: Optimize prior precision
            if (args.optimize_prior_precision is not None) and (args.method == 'laplace'):
                n = model.n_params if args.prior_structure == 'all' else model.n_layers
                prior_precision = args.prior_precision * torch.ones(n, device=device)
                print('Optimizing prior precision...')
                model.optimize_prior_precision(
                    method=args.optimize_prior_precision, init_prior_prec=prior_precision,
                    val_loader=val_loader, pred_type=args.pred_type, link_approx=args.link_approx,
                    n_samples=args.n_samples, verbose=(args.prior_structure == 'scalar')
                )
            mixture_components.append(model)

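        # Note: with subset_of_weights='last_layer', only the final linear layer gets a
        # Gaussian posterior; the predictive calls below marginalise over it via the
        # chosen pred_type/link_approx (e.g. probit approximation or MC sampling).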
        model = mixture_components[0]
        loss_fn = nn.NLLLoss()

        # Evaluate ID data
        id_y_true, id_y_prob = [], []
        for data in tqdm(test_loader, desc='Evaluating ID data'):
            x, y = data[0].to(device), data[1].to(device)
            id_y_true.append(y.cpu())
            y_prob = model(x, pred_type=args.pred_type, link_approx=args.link_approx, n_samples=args.n_samples)
            id_y_prob.append(y_prob.cpu())
        id_y_prob = torch.cat(id_y_prob, dim=0)
        id_y_true = torch.cat(id_y_true, dim=0)
        c, preds = torch.max(id_y_prob, 1)
        metrics = {}
        metrics['conf'] = c.mean().item()
        metrics['nll'] = loss_fn(id_y_prob.log(), id_y_true).item()
        metrics['acc'] = (id_y_true == preds).float().mean().item()
        accuracy = metrics['acc']
        id_confidences = id_y_prob.max(dim=1)[0].numpy()
        ece = expected_calibration_error(id_confidences, preds.numpy(), id_y_true.numpy(), num_bins=15)
        t_ece = ece
        metrics['ece'] = ece
        print(metrics)

        # Evaluate OOD data
        ood_y_true, ood_y_prob = [], []
        for data in tqdm(ood_test_loader, desc='Evaluating OOD data'):
            x, y = data[0].to(device), data[1].to(device)
            ood_y_true.append(y.cpu())
            y_prob = model(x, pred_type=args.pred_type, link_approx=args.link_approx, n_samples=args.n_samples)
            ood_y_prob.append(y_prob.cpu())
        ood_y_prob = torch.cat(ood_y_prob, dim=0)
        ood_y_true = torch.cat(ood_y_true, dim=0)
        ood_confidences = ood_y_prob.max(dim=1)[0].numpy()

        # OOD AUROC/AUPRC metrics
        bin_labels = np.concatenate([
            np.zeros(id_confidences.shape[0]),
            np.ones(ood_confidences.shape[0])
        ])
        scores = np.concatenate([id_confidences, ood_confidences])
        fpr, tpr, thresholds = M.roc_curve(bin_labels, scores)
        precision, recall, prc_thresholds = M.precision_recall_curve(bin_labels, scores)
        ood_auroc = M.roc_auc_score(bin_labels, scores)
        auprc = M.average_precision_score(bin_labels, scores)
        print(f"OOD AUROC: {ood_auroc:.4f}, AUPRC: {auprc:.4f}")

# Error AUROC/AUPRC (in-distribution: correct vs incorrect)
|
| 251 |
+
labels_array = np.array(id_y_true)
|
| 252 |
+
pred_array = np.array(preds)
|
| 253 |
+
correct_mask = labels_array == pred_array
|
| 254 |
+
confidences_right = id_confidences[correct_mask]
|
| 255 |
+
confidences_wrong = id_confidences[~correct_mask]
|
| 256 |
+
bin_labels = np.concatenate([
|
| 257 |
+
np.zeros(confidences_right.shape[0]),
|
| 258 |
+
np.ones(confidences_wrong.shape[0])
|
| 259 |
+
])
|
| 260 |
+
scores = np.concatenate([confidences_right, confidences_wrong])
|
| 261 |
+
err_auroc = M.roc_auc_score(bin_labels, scores)
|
| 262 |
+
err_auprc = M.average_precision_score(bin_labels, scores)
|
| 263 |
+
print(f"Error AUROC: {err_auroc:.4f}, AUPRC: {err_auprc:.4f}")
|
| 264 |
+
|
| 265 |
+
# Adversarial robustness
|
| 266 |
+
adv_loader = create_adversarial_dataloader(net, test_loader, device, epsilon=adv_ep, batch_size=args.batch_size)
|
| 267 |
+
adv_y_prob, adv_y_true = [], []
|
| 268 |
+
for data in tqdm(adv_loader, desc='Adversarial evaluation'):
|
| 269 |
+
x, y = data[0].to(device), data[1].to(device)
|
| 270 |
+
y_prob = model(x, pred_type=args.pred_type, link_approx=args.link_approx, n_samples=args.n_samples)
|
| 271 |
+
adv_y_true.append(y.cpu())
|
| 272 |
+
adv_y_prob.append(y_prob.cpu())
|
| 273 |
+
adv_y_prob = torch.cat(adv_y_prob, dim=0)
|
| 274 |
+
adv_y_true = torch.cat(adv_y_true, dim=0).numpy()
|
| 275 |
+
_, adv_predictions = torch.max(adv_y_prob, 1)
|
| 276 |
+
adv_accuracy = (adv_y_true == adv_predictions).mean()
|
| 277 |
+
adv_confidences = adv_y_prob.max(dim=1)[0].numpy()
|
| 278 |
+
bin_labels = np.concatenate([
|
| 279 |
+
np.zeros(id_confidences.shape[0]),
|
| 280 |
+
np.ones(adv_confidences.shape[0])
|
| 281 |
+
])
|
| 282 |
+
adv_scores = np.concatenate([id_confidences, adv_confidences])
|
| 283 |
+
adv_auroc = M.roc_auc_score(bin_labels, adv_scores)
|
| 284 |
+
adv_auprc = M.average_precision_score(bin_labels, adv_scores)
|
| 285 |
+
print(f"Adversarial AUROC: {adv_auroc:.4f}, AUPRC: {adv_auprc:.4f}")
|
| 286 |
+
|
| 287 |
+
# If sample_noise: save/plot noise-uncertainty and accuracy curves
|
| 288 |
+
if args.sample_noise:
|
| 289 |
+
adv_eps = np.linspace(0, 0.4, 9)
|
| 290 |
+
for idx_ep, ep in enumerate(adv_eps):
|
| 291 |
+
adv_loader = create_adversarial_dataloader(
|
| 292 |
+
net, test_loader, device, epsilon=ep, batch_size=args.batch_size
|
| 293 |
+
)
|
| 294 |
+
adv_y_prob, adv_y_true = [], []
|
| 295 |
+
for data in tqdm(adv_loader, desc=f"Adv evaluation ep={ep:.2f}"):
|
| 296 |
+
x, y = data[0].to(device), data[1].to(device)
|
| 297 |
+
y_prob = model(x, pred_type=args.pred_type, link_approx=args.link_approx, n_samples=args.n_samples)
|
| 298 |
+
adv_y_true.append(y.cpu())
|
| 299 |
+
adv_y_prob.append(y_prob.cpu())
|
| 300 |
+
adv_y_prob = torch.cat(adv_y_prob, dim=0)
|
| 301 |
+
adv_y_true = torch.cat(adv_y_true, dim=0).numpy()
|
| 302 |
+
_, predictions = torch.max(adv_y_prob, 1)
|
| 303 |
+
adv_accuracy = (adv_y_true == predictions).mean()
|
| 304 |
+
uncertainties = 1 - adv_y_prob.max(dim=1)[0].numpy()
|
| 305 |
+
adv_unc[i][idx_ep] = uncertainties.mean()
|
| 306 |
+
adv_acc[i][idx_ep] = adv_accuracy
|
| 307 |
+
# Save/plot uncertainty/accuracy curves as in your original
|
| 308 |
+
|
| 309 |
+
# Accumulate results
|
| 310 |
+
accuracies.append(accuracy)
|
| 311 |
+
eces.append(ece)
|
| 312 |
+
ood_aurocs.append(ood_auroc)
|
| 313 |
+
err_aurocs.append(err_auroc)
|
| 314 |
+
adv_aurocs.append(adv_auroc)
|
| 315 |
+
|
| 316 |
+
del model, mixture_components
|
| 317 |
+
torch.cuda.empty_cache()
|
| 318 |
+
import gc
|
| 319 |
+
gc.collect()
|
| 320 |
+
|
| 321 |
+
# Final result reporting and saving
|
| 322 |
+
|
| 323 |
+
def mean_std(x):
|
| 324 |
+
arr = torch.tensor(x)
|
| 325 |
+
return arr.mean().item(), arr.std().item() / math.sqrt(arr.shape[0])
|
| 326 |
+
|
| 327 |
+
# Print summary
|
| 328 |
+
print_metrics(*mean_std(accuracies), "Accuracy")
|
| 329 |
+
print_metrics(*mean_std(eces), "ECE")
|
| 330 |
+
print_metrics(*mean_std(adv_aurocs), "Adv AUROC")
|
| 331 |
+
print_metrics(*mean_std(err_aurocs), "Error AUROC")
|
| 332 |
+
print_metrics(*mean_std(ood_aurocs), "OOD AUROC")
|
| 333 |
+
|
| 334 |
+
# Store only required metrics
|
| 335 |
+
result_json = {}
|
| 336 |
+
for key, arr in [
|
| 337 |
+
("accuracy", accuracies),
|
| 338 |
+
("ece", eces),
|
| 339 |
+
("adv_auroc", adv_aurocs),
|
| 340 |
+
("err_auroc", err_aurocs),
|
| 341 |
+
("ood_auroc", ood_aurocs)
|
| 342 |
+
]:
|
| 343 |
+
mean, std = mean_std(arr)
|
| 344 |
+
result_json[key] = {
|
| 345 |
+
"mean": mean,
|
| 346 |
+
"std": std,
|
| 347 |
+
"values": [float(v) for v in arr]
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
result_file = (
|
| 351 |
+
"res_" + model_save_name(args.model, args.sn, args.mod, args.coeff, args.seed)
|
| 352 |
+
+ "_laplace_" + args.dataset + "_" + args.ood_dataset + ".json"
|
| 353 |
+
)
|
| 354 |
+
with open(result_file, "w") as f:
|
| 355 |
+
json.dump(result_json, f, indent=2)
|
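For reference, mean_std above reports the mean together with the standard error of the mean (std / sqrt(n)) over runs, not the raw standard deviation. A self-contained sanity check, with made-up per-run accuracies:

import math
import torch

def mean_std(x):
    arr = torch.tensor(x)
    return arr.mean().item(), arr.std().item() / math.sqrt(arr.shape[0])

acc_runs = [0.91, 0.93, 0.92]        # hypothetical per-run accuracies
mean, sem = mean_std(acc_runs)
print(f"{mean:.4f} +/- {sem:.4f}")   # mean and standard error over 3 runs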
SPC-UQ/Image_Classification/metrics/__init__.py
ADDED
File without changes
SPC-UQ/Image_Classification/metrics/calibration_metrics.py
ADDED
@@ -0,0 +1,129 @@
"""
Metrics to measure calibration of a trained deep neural network.
References:
[1] C. Guo, G. Pleiss, Y. Sun, and K. Q. Weinberger. On calibration of modern neural networks.
    arXiv preprint arXiv:1706.04599, 2017.
"""

import math
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F

import matplotlib.pyplot as plt

plt.rcParams.update({"font.size": 20})


# Some keys used for the following dictionaries
COUNT = "count"
CONF = "conf"
ACC = "acc"
BIN_ACC = "bin_acc"
BIN_CONF = "bin_conf"


def _bin_initializer(num_bins=10):
    bin_dict = {}
    for i in range(num_bins):
        bin_dict[i] = {}
        bin_dict[i][COUNT] = 0
        bin_dict[i][CONF] = 0
        bin_dict[i][ACC] = 0
        bin_dict[i][BIN_ACC] = 0
        bin_dict[i][BIN_CONF] = 0
    return bin_dict


def _populate_bins(confs, preds, labels, num_bins=10):
    bin_dict = _bin_initializer(num_bins)
    num_test_samples = len(confs)

    for i in range(0, num_test_samples):
        confidence = confs[i]
        prediction = preds[i]
        label = labels[i]
        # Clamp the bin index so confidence == 1.0 falls into the last bin
        binn = min(num_bins - 1, max(0, int(num_bins * confidence)))
        bin_dict[binn][COUNT] = bin_dict[binn][COUNT] + 1
        bin_dict[binn][CONF] = bin_dict[binn][CONF] + confidence
        bin_dict[binn][ACC] = bin_dict[binn][ACC] + (1 if (label == prediction) else 0)

    for binn in range(0, num_bins):
        if bin_dict[binn][COUNT] == 0:
            bin_dict[binn][BIN_ACC] = 0
            bin_dict[binn][BIN_CONF] = 0
        else:
            bin_dict[binn][BIN_ACC] = float(bin_dict[binn][ACC]) / bin_dict[binn][COUNT]
            bin_dict[binn][BIN_CONF] = bin_dict[binn][CONF] / float(bin_dict[binn][COUNT])
    return bin_dict


def expected_calibration_error(confs, preds, labels, num_bins=10):
    bin_dict = _populate_bins(confs, preds, labels, num_bins)
    num_samples = len(labels)
    ece = 0
    for i in range(num_bins):
        bin_accuracy = bin_dict[i][BIN_ACC]
        bin_confidence = bin_dict[i][BIN_CONF]
        bin_count = bin_dict[i][COUNT]
        ece += (float(bin_count) / num_samples) * abs(bin_accuracy - bin_confidence)
    return ece


# Calibration error scores in the form of loss metrics
class ECELoss(nn.Module):
    """
    Compute ECE (Expected Calibration Error)
    """

    def __init__(self, n_bins=15):
        super(ECELoss, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, logits, labels):
        softmaxes = F.softmax(logits, dim=1)
        confidences, predictions = torch.max(softmaxes, 1)
        accuracies = predictions.eq(labels)

        ece = torch.zeros(1, device=logits.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            # Calculate |confidence - accuracy| in each bin
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece


# Methods for plotting reliability diagrams and bin-strength plots
def reliability_plot(confs, preds, labels, num_bins=15, model_name='model'):
    """
    Method to draw a reliability plot from a model's predictions and confidences.
    """
    bin_dict = _populate_bins(confs, preds, labels, num_bins)
    bns = [(i / float(num_bins)) for i in range(num_bins)]
    y = []
    for i in range(num_bins):
        y.append(bin_dict[i][BIN_ACC])
    plt.figure(figsize=(10, 8))
    plt.bar(bns, bns, align="edge", width=0.03, color="pink", label="Expected")
    plt.bar(bns, y, align="edge", width=0.03, color="blue", alpha=0.5, label="Actual")
    plt.ylabel("Accuracy", fontsize=30)
    plt.xlabel("Confidence", fontsize=30)
    plt.xticks(fontsize=30)
    plt.yticks(fontsize=30)
    plt.legend(fontsize=30, loc='upper left')
    plt.savefig(f'./reliability_plot_{model_name}.pdf')
    plt.savefig(f'./reliability_plot_{model_name}.png')
    plt.show()
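A minimal sketch exercising expected_calibration_error on synthetic predictions (assumes this repo is on the import path; the data is fabricated so that confidence roughly matches accuracy, which should drive the ECE toward zero):

import numpy as np
from metrics.calibration_metrics import expected_calibration_error

rng = np.random.default_rng(0)
confs = rng.uniform(0.5, 1.0, 1000)            # max-softmax confidences
preds = rng.integers(0, 10, 1000)              # predicted classes
# Labels agree with preds with probability equal to the stated confidence,
# i.e. a roughly calibrated synthetic model.
labels = np.where(rng.uniform(size=1000) < confs, preds, (preds + 1) % 10)
print(expected_calibration_error(confs, preds, labels, num_bins=15))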
SPC-UQ/Image_Classification/metrics/classification_metrics.py
ADDED
@@ -0,0 +1,211 @@
"""
Metrics to measure classification performance
"""

import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset

from utils.ensemble_utils import ensemble_forward_pass

from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix


def evidential_loss(alpha, target, lambda_reg=0.001):
    num_classes = alpha.shape[1]
    target_one_hot = F.one_hot(target, num_classes=num_classes).float()
    S = alpha.sum(dim=1, keepdim=True)

    log_likelihood = torch.sum(target_one_hot * (torch.digamma(S) - torch.digamma(alpha)), dim=1)
    kl_divergence = lambda_reg * torch.sum((alpha - 1) * (1 - target_one_hot), dim=1)

    loss = log_likelihood + kl_divergence
    return torch.mean(loss)


def create_adversarial_dataloader(model, data_loader, device, epsilon=0.03, batch_size=32, edl=False, joint=False):
    adv_examples = []
    adv_labels = []

    model.eval()
    for data, label in data_loader:
        data = data.to(device).detach().requires_grad_(True)
        label = label.to(device)

        model.zero_grad()
        logit = model(data)
        # Pick the loss matching the model head; the branches are mutually
        # exclusive (edl / joint / plain cross-entropy).
        if edl:
            loss = evidential_loss(logit, label)
        elif joint:
            loss = F.cross_entropy(logit[0], label)
        else:
            loss = F.cross_entropy(logit, label)
        loss.backward()

        # FGSM perturbation: one signed-gradient step of size epsilon
        signed_grad = data.grad.sign()
        data_adv = data + epsilon * signed_grad

        adv_examples.append(data_adv.detach().cpu())
        adv_labels.append(label.detach().cpu())

    adv_examples = torch.cat(adv_examples, dim=0)
    adv_labels = torch.cat(adv_labels, dim=0)

    adv_dataset = TensorDataset(adv_examples, adv_labels)
    adv_dataloader = DataLoader(adv_dataset, batch_size=batch_size, shuffle=False)

    return adv_dataloader


def get_logits_labels(model, data_loader, device):
    """
    Utility function to get logits and labels.
    """
    model.eval()
    logits = []
    labels = []
    with torch.no_grad():
        for data, label in data_loader:
            data = data.to(device)
            label = label.to(device)

            logit = model(data)
            logits.append(logit)
            labels.append(label)
    logits = torch.cat(logits, dim=0)
    labels = torch.cat(labels, dim=0)
    return logits, labels


def get_logits_labels_uq(model, data_loader, device):
    """
    Utility function to get logits as a list: [pred_all, mar_all, mar_up_all, mar_down_all]
    and labels as a single tensor.
    """
    model.eval()
    preds = []
    mars = []
    mars_up = []
    mars_down = []
    labels = []

    with torch.no_grad():
        for data, label in data_loader:
            data = data.to(device)
            label = label.to(device)

            pred, mar, mar_up, mar_down = model(data)
            preds.append(pred)
            mars.append(mar)
            mars_up.append(mar_up)
            mars_down.append(mar_down)
            labels.append(label)

    pred_all = torch.cat(preds, dim=0)
    mar_all = torch.cat(mars, dim=0)
    mar_up_all = torch.cat(mars_up, dim=0)
    mar_down_all = torch.cat(mars_down, dim=0)
    labels_all = torch.cat(labels, dim=0)

    logits = [pred_all, mar_all, mar_up_all, mar_down_all]
    return logits, labels_all


def test_classification_net_softmax(softmax_prob, labels):
    """
    This function reports classification accuracy and confusion matrix given softmax vectors and
    labels from a model.
    """
    labels_list = []
    predictions_list = []
    confidence_vals_list = []

    confidence_vals, predictions = torch.max(softmax_prob, dim=1)
    labels_list.extend(labels.cpu().numpy())
    predictions_list.extend(predictions.cpu().numpy())
    confidence_vals_list.extend(confidence_vals.detach().cpu().numpy())
    accuracy = accuracy_score(labels_list, predictions_list)
    return (
        confusion_matrix(labels_list, predictions_list),
        accuracy,
        labels_list,
        predictions_list,
        confidence_vals_list,
    )


def test_classification_net_logits(logits, labels):
    """
    This function reports classification accuracy and confusion matrix given logits and labels
    from a model.
    """
    softmax_prob = F.softmax(logits, dim=1)
    return test_classification_net_softmax(softmax_prob, labels)


def test_classification_net(model, data_loader, device):
    """
    This function reports classification accuracy and confusion matrix over a dataset.
    """
    logits, labels = get_logits_labels(model, data_loader, device)
    return test_classification_net_logits(logits, labels)


def test_classification_uq(model, data_loader, device):
    """
    This function reports classification accuracy and confusion matrix over a dataset
    for a model with UQ heads (only the prediction head is scored).
    """
    logits, labels = get_logits_labels_uq(model, data_loader, device)
    return test_classification_net_logits(logits[0], labels)


def test_classification_net_ensemble(model_ensemble, data_loader, device):
    """
    This function reports classification accuracy and confusion matrix over a dataset
    for a deep ensemble.
    """
    for model in model_ensemble:
        model.eval()
    softmax_prob = []
    labels = []
    with torch.no_grad():
        for data, label in data_loader:
            data = data.to(device)
            label = label.to(device)

            softmax, _, _ = ensemble_forward_pass(model_ensemble, data)
            softmax_prob.append(softmax)
            labels.append(label)
    softmax_prob = torch.cat(softmax_prob, dim=0)
    labels = torch.cat(labels, dim=0)

    return test_classification_net_softmax(softmax_prob, labels)


def test_classification_net_logits_edl(logits, labels):
    """
    This function reports classification accuracy and confusion matrix given evidential
    (Dirichlet) outputs and labels from a model.
    """
    labels_list = []
    predictions_list = []
    confidence_vals_list = []

    predicted_probs = logits / logits.sum(dim=1, keepdim=True)
    confidence_vals, predictions = torch.max(predicted_probs, dim=1)
    labels_list.extend(labels.cpu().numpy())
    predictions_list.extend(predictions.cpu().numpy())
    confidence_vals_list.extend(confidence_vals.cpu().numpy())
    accuracy = accuracy_score(labels_list, predictions_list)
    return (
        confusion_matrix(labels_list, predictions_list),
        accuracy,
        labels_list,
        predictions_list,
        confidence_vals_list,
    )


def test_classification_net_edl(model, data_loader, device):
    """
    This function reports classification accuracy and confusion matrix over a dataset.
    """
    logits, labels = get_logits_labels(model, data_loader, device)
    return test_classification_net_logits_edl(logits, labels)
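A hypothetical end-to-end exercise of the FGSM helper above, using a throwaway linear model and random tensors rather than a trained network (all names below are stand-ins; assumes utils.ensemble_utils is importable, since the module imports it at load time):

import torch
from torch.utils.data import DataLoader, TensorDataset
from metrics.classification_metrics import create_adversarial_dataloader, test_classification_net

# Tiny stand-in classifier and random data, just to exercise the helper.
net = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
loader = DataLoader(TensorDataset(torch.randn(64, 3, 32, 32),
                                  torch.randint(0, 10, (64,))), batch_size=32)

device = torch.device("cpu")
adv_loader = create_adversarial_dataloader(net, loader, device, epsilon=0.03)
_, acc, _, _, _ = test_classification_net(net, adv_loader, device)
print(f"accuracy under FGSM(eps=0.03): {acc:.3f}")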
SPC-UQ/Image_Classification/metrics/ood_metrics.py
ADDED
@@ -0,0 +1,135 @@
# Utility functions to get OOD detection ROC curves and AUROC scores
# Ideally should be agnostic of model architectures

import torch
import torch.nn.functional as F
from sklearn import metrics

from utils.ensemble_utils import ensemble_forward_pass
from metrics.classification_metrics import get_logits_labels
from metrics.uncertainty_confidence import entropy, logsumexp, confidence, edl_unc


def get_roc_auc(net, test_loader, ood_test_loader, uncertainty, device, confidence=False):
    logits, _ = get_logits_labels(net, test_loader, device)
    ood_logits, _ = get_logits_labels(net, ood_test_loader, device)

    return get_roc_auc_logits(logits, ood_logits, uncertainty, device, confidence=confidence)


def get_roc_auc_logits(logits, ood_logits, uncertainty, device, confidence=False):
    uncertainties = uncertainty(logits)
    ood_uncertainties = uncertainty(ood_logits)

    # In-distribution samples are labelled 0
    bin_labels = torch.zeros(uncertainties.shape[0]).to(device)
    in_scores = uncertainties

    # OOD samples are labelled 1
    bin_labels = torch.cat((bin_labels, torch.ones(ood_uncertainties.shape[0]).to(device)))

    # If the score is a confidence (higher = more in-distribution), flip the labels
    if confidence:
        bin_labels = 1 - bin_labels
    ood_scores = ood_uncertainties
    scores = torch.cat((in_scores, ood_scores))

    fpr, tpr, thresholds = metrics.roc_curve(bin_labels.cpu().numpy(), scores.cpu().numpy())
    precision, recall, prc_thresholds = metrics.precision_recall_curve(bin_labels.cpu().numpy(), scores.cpu().numpy())
    auroc = metrics.roc_auc_score(bin_labels.cpu().numpy(), scores.cpu().numpy())
    auprc = metrics.average_precision_score(bin_labels.cpu().numpy(), scores.cpu().numpy())

    return (fpr, tpr, thresholds), (precision, recall, prc_thresholds), auroc, auprc


def get_roc_auc_uncs(uncertainties, ood_uncertainties, device, confidence=False):
    # In-distribution
    bin_labels = torch.zeros(uncertainties.shape[0]).to(device)
    in_scores = uncertainties

    # OOD
    bin_labels = torch.cat((bin_labels, torch.ones(ood_uncertainties.shape[0]).to(device)))

    if confidence:
        bin_labels = 1 - bin_labels
    ood_scores = ood_uncertainties
    scores = torch.cat((in_scores, ood_scores))

    fpr, tpr, thresholds = metrics.roc_curve(bin_labels.detach().cpu().numpy(), scores.detach().cpu().numpy())
    precision, recall, prc_thresholds = metrics.precision_recall_curve(bin_labels.detach().cpu().numpy(), scores.detach().cpu().numpy())
    auroc = metrics.roc_auc_score(bin_labels.detach().cpu().numpy(), scores.detach().cpu().numpy())
    auprc = metrics.average_precision_score(bin_labels.detach().cpu().numpy(), scores.detach().cpu().numpy())

    return (fpr, tpr, thresholds), (precision, recall, prc_thresholds), auroc, auprc


def get_roc_auc_ensemble(model_ensemble, test_loader, ood_test_loader, uncertainty, device):
    for model in model_ensemble:
        model.eval()

    bin_labels_uncertainties = []
    uncertainties = []
    with torch.no_grad():
        # Getting uncertainties for in-distribution data
        for data, label in test_loader:
            data = data.to(device)
            label = label.to(device)

            bin_label_uncertainty = torch.zeros(label.shape).to(device)
            if uncertainty == "mutual_information":
                net_output, _, unc = ensemble_forward_pass(model_ensemble, data)
            else:
                net_output, unc, _ = ensemble_forward_pass(model_ensemble, data)

            bin_labels_uncertainties.append(bin_label_uncertainty)
            uncertainties.append(unc)

        # Getting uncertainties for OOD data
        for data, label in ood_test_loader:
            data = data.to(device)
            label = label.to(device)

            bin_label_uncertainty = torch.ones(label.shape).to(device)
            if uncertainty == "mutual_information":
                net_output, _, unc = ensemble_forward_pass(model_ensemble, data)
            else:
                net_output, unc, _ = ensemble_forward_pass(model_ensemble, data)

            bin_labels_uncertainties.append(bin_label_uncertainty)
            uncertainties.append(unc)

        bin_labels_uncertainties = torch.cat(bin_labels_uncertainties)
        uncertainties = torch.cat(uncertainties)

    fpr, tpr, roc_thresholds = metrics.roc_curve(bin_labels_uncertainties.cpu().numpy(), uncertainties.cpu().numpy())
    precision, recall, prc_thresholds = metrics.precision_recall_curve(
        bin_labels_uncertainties.cpu().numpy(), uncertainties.cpu().numpy()
    )
    auroc = metrics.roc_auc_score(bin_labels_uncertainties.cpu().numpy(), uncertainties.cpu().numpy())
    auprc = metrics.average_precision_score(bin_labels_uncertainties.cpu().numpy(), uncertainties.cpu().numpy())

    return (fpr, tpr, roc_thresholds), (precision, recall, prc_thresholds), auroc, auprc


def get_unc_ensemble(model_ensemble, test_loader, uncertainty, device):
    for model in model_ensemble:
        model.eval()

    uncertainties = []
    with torch.no_grad():
        for data, label in test_loader:
            data = data.to(device)

            if uncertainty == "mutual_information":
                net_output, _, unc = ensemble_forward_pass(model_ensemble, data)
            else:
                net_output, unc, _ = ensemble_forward_pass(model_ensemble, data)

            uncertainties.append(unc)

    uncertainties = torch.cat(uncertainties)

    return uncertainties
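A small sketch of get_roc_auc_logits on synthetic logits, scoring OOD-ness by predictive entropy (the scales 6.0 and 0.3 are arbitrary choices that make the ID set confident and the OOD set diffuse):

import torch
from metrics.ood_metrics import get_roc_auc_logits
from metrics.uncertainty_confidence import entropy

id_logits = 6.0 * torch.randn(500, 10)    # sharp, confident in-distribution logits
ood_logits = 0.3 * torch.randn(500, 10)   # nearly flat OOD logits
roc, prc, auroc, auprc = get_roc_auc_logits(id_logits, ood_logits, entropy, torch.device("cpu"))
print(f"AUROC={auroc:.3f} AUPRC={auprc:.3f}")  # entropy is higher for the flat OOD logits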
SPC-UQ/Image_Classification/metrics/uncertainty_confidence.py
ADDED
@@ -0,0 +1,67 @@
"""
Metrics measuring either uncertainty or confidence of a model.
"""
import torch
import torch.nn.functional as F


def entropy(logits):
    p = F.softmax(logits, dim=1)
    logp = F.log_softmax(logits, dim=1)
    plogp = p * logp
    entropy = -torch.sum(plogp, dim=1)
    return entropy


def logsumexp(logits):
    return torch.logsumexp(logits, dim=1, keepdim=False)


def confidence(logits):
    p = F.softmax(logits, dim=1)
    confidence, _ = torch.max(p, dim=1)
    return confidence


def entropy_prob(probs):
    p = probs
    eps = 1e-12
    logp = torch.log(p + eps)
    plogp = p * logp
    entropy = -torch.sum(plogp, dim=1)
    return entropy


def mutual_information_prob(probs):
    mean_output = torch.mean(probs, dim=0)
    predictive_entropy = entropy_prob(mean_output)

    # Computing expectation of entropies
    p = probs
    eps = 1e-12
    logp = torch.log(p + eps)
    plogp = p * logp
    exp_entropies = torch.mean(-torch.sum(plogp, dim=2), dim=0)

    # Computing mutual information
    mi = predictive_entropy - exp_entropies
    return mi


def self_consistency(mars):
    logits = mars[0]
    logits = torch.nn.functional.softmax(logits, dim=1)
    mar = mars[1].squeeze()
    mar_up = 1 - logits
    mar_down = logits
    uncertainty = abs(2 * mar_up * mar_down - mar)  # (mar_up + mar_down) = 1
    uncertainty = torch.sum(uncertainty, dim=1)

    return uncertainty


def edl_unc(logits):
    num_classes = logits.shape[1]
    uncertainty = num_classes / logits.sum(dim=1)
    return uncertainty


def certificate(OCs):
    return OCs
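Quick numeric checks of the scoring functions above (the values in the comments are approximate):

import torch
from metrics.uncertainty_confidence import entropy, confidence, edl_unc

logits = torch.tensor([[4.0, 0.0, 0.0],     # sharp distribution
                       [0.1, 0.0, 0.0]])    # near-uniform distribution
print(entropy(logits))      # ~ [0.18, 1.10]: higher for the flat row
print(confidence(logits))   # ~ [0.96, 0.36]: higher for the sharp row

alphas = torch.tensor([[10.0, 1.0, 1.0],    # Dirichlet parameters (EDL evidence)
                       [1.1, 1.0, 1.0]])
print(edl_unc(alphas))      # K / sum(alpha): ~ [0.25, 0.97]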
SPC-UQ/Image_Classification/net/__init__.py
ADDED
File without changes
SPC-UQ/Image_Classification/net/imagenet_vgg.py
ADDED
@@ -0,0 +1,106 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from torchvision.models import vgg16, VGG16_Weights


class ImagenetVGG16(nn.Module):
    """
    VGG16 wrapper for ImageNet-like classification with:
    - Optional pretrained backbone
    - Optional feature freezing
    - Temperature-scaled logits
    - Exposes penultimate features via `self.feature` (detached)
    """

    def __init__(
        self,
        num_classes: int = 1000,
        pretrained: bool = True,
        temp: float = 1.0,
        freeze_features: bool = False,
    ) -> None:
        super().__init__()

        # Load base model (weights imply specific preprocessing; handle in dataloader)
        base_model = vgg16(weights=VGG16_Weights.DEFAULT if pretrained else None)

        # Convolutional feature extractor and avgpool
        self.features: nn.Module = base_model.features
        self.avgpool: nn.Module = base_model.avgpool

        # Penultimate FC stack from original VGG16 (remove final classifier layer)
        # VGG16 classifier:
        # [Linear(25088->4096), ReLU, Dropout, Linear(4096->4096), ReLU, Dropout, Linear(4096->1000)]
        self.fc_pre: nn.Sequential = nn.Sequential(*list(base_model.classifier.children())[:-1])

        # New classification head
        self.classifier: nn.Linear = nn.Linear(4096, num_classes)

        # If using pretrained and num_classes matches 1000, copy the final layer weights
        if pretrained and num_classes == 1000:
            with torch.no_grad():
                self.classifier.weight.copy_(base_model.classifier[-1].weight)
                self.classifier.bias.copy_(base_model.classifier[-1].bias)

        # Optional: freeze convolutional features (+ fc_pre) for linear probing
        if freeze_features:
            for p in self.features.parameters():
                p.requires_grad = False
            for p in self.fc_pre.parameters():
                p.requires_grad = False

        # Temperature (applied to logits at inference/training)
        self.register_buffer("temperature", torch.tensor(float(temp)))
        self.feature: Optional[torch.Tensor] = None  # cached detached penultimate features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input batch of shape (B, 3, H, W)

        Returns:
            logits: Tensor of shape (B, num_classes)
        """
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc_pre(x)

        # Cache penultimate features (detached) for downstream use
        self.feature = x.detach()

        logits = self.classifier(x)
        if self.temperature is not None and float(self.temperature) != 1.0:
            logits = logits / self.temperature
        return logits


def imagenet_vgg16(
    temp: float = 1.0,
    pretrained: bool = True,
    num_classes: int = 1000,
    freeze_features: bool = False,
    **kwargs,
) -> ImagenetVGG16:
    """
    Factory function for ImagenetVGG16.

    Args:
        temp: Temperature applied to logits (T=1.0 disables scaling).
        pretrained: Load pretrained ImageNet weights for the backbone.
        num_classes: Output classes for the final classifier.
        freeze_features: If True, freeze backbone + fc_pre for linear probing.
        **kwargs: Forwarded to the ImagenetVGG16 init (future-proof).

    Returns:
        Initialized ImagenetVGG16 model.
    """
    return ImagenetVGG16(
        num_classes=num_classes,
        pretrained=pretrained,
        temp=temp,
        freeze_features=freeze_features,
        **kwargs,
    )
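A minimal shape check of the factory, with pretrained=False so no weights are downloaded (the 224x224 input size matches the standard VGG preprocessing):

import torch
from net.imagenet_vgg import imagenet_vgg16

# Linear-probe style configuration: frozen backbone, new 100-way head, T=2 scaling.
model = imagenet_vgg16(temp=2.0, pretrained=False, num_classes=100, freeze_features=True)
logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape, model.feature.shape)  # torch.Size([2, 100]) torch.Size([2, 4096])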
SPC-UQ/Image_Classification/net/imagenet_vit.py
ADDED
@@ -0,0 +1,101 @@
import torch
import torch.nn as nn
from typing import Optional
from torchvision.models import vit_b_16, ViT_B_16_Weights


class ImagenetViT(nn.Module):
    """
    ViT-B/16 wrapper with:
    - Optional pretrained backbone (torchvision)
    - Proper CLS token + positional embeddings usage
    - Temperature-scaled logits
    - Exposed CLS feature via `self.feature` (detached)
    - Optional backbone freezing for linear probing
    """

    def __init__(
        self,
        num_classes: int = 1000,
        pretrained: bool = True,
        temp: float = 1.0,
        freeze_backbone: bool = False,
    ) -> None:
        super().__init__()
        self.backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT if pretrained else None)

        # Hidden size & final norm
        self.hidden_dim: int = self.backbone.hidden_dim
        self.norm: nn.Module = self.backbone.encoder.ln  # final LayerNorm

        # New classification head
        self.head: nn.Linear = nn.Linear(self.hidden_dim, num_classes)

        # If using pretrained and keeping 1000-class head, copy weights
        if pretrained and num_classes == 1000:
            with torch.no_grad():
                src = self.backbone.heads.head
                self.head.weight.copy_(src.weight)
                self.head.bias.copy_(src.bias)

        # Optionally freeze everything except the head
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

        # Temperature buffer and CLS feature cache
        self.register_buffer("temperature", torch.tensor(float(temp)))
        self.feature: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, 3, H, W)

        Returns:
            logits: (B, num_classes)
        """
        # Patchify + linear proj (handles resizing logic)
        x = self.backbone._process_input(x)  # (B, N, hidden_dim)

        # CLS token + positional embeddings (use pretrained params)
        n = x.shape[0]
        cls_token = self.backbone.class_token.expand(n, -1, -1)  # (B, 1, hidden_dim)
        x = torch.cat((cls_token, x), dim=1)  # (B, N+1, hidden_dim)

        # Positional embeddings + dropout
        x = x + self.backbone.encoder.pos_embedding  # (B, N+1, hidden_dim)
        x = self.backbone.encoder.dropout(x)

        # Transformer encoder (layers) + final norm
        x = self.backbone.encoder.layers(x)
        x = self.norm(x)

        # CLS feature
        cls = x[:, 0]  # (B, hidden_dim)
        self.feature = cls.detach()  # cache detached feature

        # Head + optional temperature scaling
        logits = self.head(cls)
        if float(self.temperature) != 1.0:
            logits = logits / self.temperature
        return logits


def imagenet_vit(
    temp: float = 1.0,
    pretrained: bool = True,
    num_classes: int = 1000,
    freeze_backbone: bool = False,
    **kwargs,
) -> ImagenetViT:
    """
    Factory for ImagenetViT.
    """
    return ImagenetViT(
        num_classes=num_classes,
        pretrained=pretrained,
        temp=temp,
        freeze_backbone=freeze_backbone,
        **kwargs,
    )
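A corresponding shape check for the ViT wrapper (ViT-B/16 expects 224x224 inputs and has hidden_dim 768):

import torch
from net.imagenet_vit import imagenet_vit

model = imagenet_vit(pretrained=False, num_classes=10, freeze_backbone=True)
logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape, model.feature.shape)  # torch.Size([2, 10]) torch.Size([2, 768])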
SPC-UQ/Image_Classification/net/imagenet_wide.py
ADDED
@@ -0,0 +1,46 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import wide_resnet50_2, Wide_ResNet50_2_Weights


class ImagenetWideResNet(nn.Module):
    def __init__(self, num_classes=1000, pretrained=True, temp=1.0):
        super().__init__()

        base_model = wide_resnet50_2(weights=Wide_ResNet50_2_Weights.DEFAULT if pretrained else None)

        # Adapt to match the original WRN structure
        self.features = nn.Sequential(
            base_model.conv1,
            base_model.bn1,
            base_model.relu,
            base_model.maxpool,
            base_model.layer1,
            base_model.layer2,
            base_model.layer3,
            base_model.layer4,
            base_model.avgpool
        )

        self.linear = nn.Linear(2048, num_classes)
        self.temp = temp
        self.feature = None

        # Copy the pretrained head only when the shapes match
        # (mirrors the guard used in the VGG/ViT wrappers above).
        if pretrained and num_classes == 1000:
            self.linear.load_state_dict(base_model.fc.state_dict())

    def forward(self, x):
        out = self.features(x)
        out = torch.flatten(out, 1)
        self.feature = out.clone().detach()
        if self.temp == 1:
            out = self.linear(out)
        else:
            out = self.linear(out) / self.temp
        return out


def imagenet_wide(temp=1.0, pretrained=True, **kwargs):
    model = ImagenetWideResNet(pretrained=pretrained, temp=temp, **kwargs)
    return model
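Shape check for the wide-ResNet wrapper (pretrained=False keeps it offline; with the guard above, a non-1000-way head simply starts from random init):

import torch
from net.imagenet_wide import imagenet_wide

model = imagenet_wide(temp=1.5, pretrained=False, num_classes=10)
print(model(torch.randn(2, 3, 224, 224)).shape)  # torch.Size([2, 10]); logits divided by temp=1.5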
SPC-UQ/Image_Classification/net/lenet.py
ADDED
@@ -0,0 +1,37 @@
"""Implementation of LeNet in pytorch.
Reference:
[1] LeCun, Y., Bottou, L., Bengio, Y., & Haffner, P. (1998).
    Gradient-based learning applied to document recognition.
    Proceedings of the IEEE, 86, 2278-2324.
"""
import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    def __init__(self, num_classes, temp=1.0, mnist=True, **kwargs):
        super(LeNet, self).__init__()
        self.num_classes = num_classes
        self.conv1 = nn.Conv2d(1 if mnist else 3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(256, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)
        self.temp = temp
        self.feature = None

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        self.feature = out
        out = self.fc3(out) / self.temp
        return out


def lenet(num_classes=10, temp=1.0, mnist=True, **kwargs):
    # Forward the mnist flag instead of hard-coding True, so the 3-channel branch is reachable.
    return LeNet(num_classes=num_classes, temp=temp, mnist=mnist, **kwargs)
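Shape check for LeNet on MNIST-sized inputs; a 28x28 image shrinks to 16 x 4 x 4 = 256 features, which is why fc1 is Linear(256, 120):

import torch
from net.lenet import lenet

model = lenet(num_classes=10, mnist=True)
print(model(torch.randn(4, 1, 28, 28)).shape)  # torch.Size([4, 10])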
SPC-UQ/Image_Classification/net/resnet.py
ADDED
@@ -0,0 +1,245 @@
"""
Pytorch implementation of ResNet models.
Reference:
[1] He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: CVPR, 2016.
"""
import torch
import math
import torch.nn as nn
import torch.nn.functional as F

from net.spectral_normalization.spectral_norm_conv_inplace import spectral_norm_conv
from net.spectral_normalization.spectral_norm_fc import spectral_norm_fc


class AvgPoolShortCut(nn.Module):
    def __init__(self, stride, out_c, in_c):
        super(AvgPoolShortCut, self).__init__()
        self.stride = stride
        self.out_c = out_c
        self.in_c = in_c

    def forward(self, x):
        if x.shape[2] % 2 != 0:
            x = F.avg_pool2d(x, 1, self.stride)
        else:
            x = F.avg_pool2d(x, self.stride, self.stride)
        pad = torch.zeros(x.shape[0], self.out_c - self.in_c, x.shape[2], x.shape[3], device=x.device)
        x = torch.cat((x, pad), dim=1)
        return x


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, input_size, wrapped_conv, in_planes, planes, stride=1, mod=True):
        super(BasicBlock, self).__init__()
        self.conv1 = wrapped_conv(input_size, in_planes, planes, kernel_size=3, stride=stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = wrapped_conv(math.ceil(input_size / stride), planes, planes, kernel_size=3, stride=1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.mod = mod
        self.activation = F.leaky_relu if self.mod else F.relu

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            if mod:
                self.shortcut = nn.Sequential(AvgPoolShortCut(stride, self.expansion * planes, in_planes))
            else:
                self.shortcut = nn.Sequential(
                    wrapped_conv(input_size, in_planes, self.expansion * planes, kernel_size=1, stride=stride),
                    nn.BatchNorm2d(planes),
                )

    def forward(self, x):
        out = self.activation(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = self.activation(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, input_size, wrapped_conv, in_planes, planes, stride=1, mod=True):
        super(Bottleneck, self).__init__()
        self.conv1 = wrapped_conv(input_size, in_planes, planes, kernel_size=1, stride=1)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = wrapped_conv(input_size, planes, planes, kernel_size=3, stride=stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = wrapped_conv(math.ceil(input_size / stride), planes, self.expansion * planes, kernel_size=1, stride=1)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)
        self.mod = mod
        self.activation = F.leaky_relu if self.mod else F.relu

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            if mod:
                self.shortcut = nn.Sequential(AvgPoolShortCut(stride, self.expansion * planes, in_planes))
            else:
                self.shortcut = nn.Sequential(
                    wrapped_conv(input_size, in_planes, self.expansion * planes, kernel_size=1, stride=stride),
                    nn.BatchNorm2d(self.expansion * planes),
                )

    def forward(self, x):
        out = self.activation(self.bn1(self.conv1(x)))
        out = self.activation(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = self.activation(out)
        return out


class ResNet(nn.Module):
    def __init__(
        self,
        block,
        num_blocks,
        num_classes=10,
        temp=1.0,
        spectral_normalization=True,
        mod=True,
        coeff=3,
        n_power_iterations=1,
        mnist=False,
    ):
        """
        If the "mod" parameter is set to True, the architecture uses 2 modifications:
        1. LeakyReLU instead of normal ReLU
        2. Average Pooling on the residual connections.
        """
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.mod = mod

        def wrapped_conv(input_size, in_c, out_c, kernel_size, stride):
            padding = 1 if kernel_size == 3 else 0

            conv = nn.Conv2d(in_c, out_c, kernel_size, stride, padding, bias=False)

            if not spectral_normalization:
                return conv

            # NOTE: Google uses the spectral_norm_fc in all cases
            if kernel_size == 1:
                # use spectral_norm_fc, because the bounds are tight for 1x1 convolutions
                wrapped_conv = spectral_norm_fc(conv, coeff, n_power_iterations)
            else:
                # Otherwise use spectral_norm_conv, with a looser bound
                shapes = (in_c, input_size, input_size)
                wrapped_conv = spectral_norm_conv(conv, coeff, shapes, n_power_iterations)

            return wrapped_conv

        self.wrapped_conv = wrapped_conv

        self.bn1 = nn.BatchNorm2d(64)

        if mnist:
            self.conv1 = wrapped_conv(28, 1, 64, kernel_size=3, stride=1)
            self.layer1 = self._make_layer(block, 28, 64, num_blocks[0], stride=1)
            self.layer2 = self._make_layer(block, 28, 128, num_blocks[1], stride=2)
            self.layer3 = self._make_layer(block, 14, 256, num_blocks[2], stride=2)
            self.layer4 = self._make_layer(block, 7, 512, num_blocks[3], stride=2)
        else:
            self.conv1 = wrapped_conv(32, 3, 64, kernel_size=3, stride=1)
            self.layer1 = self._make_layer(block, 32, 64, num_blocks[0], stride=1)
            self.layer2 = self._make_layer(block, 32, 128, num_blocks[1], stride=2)
            self.layer3 = self._make_layer(block, 16, 256, num_blocks[2], stride=2)
            self.layer4 = self._make_layer(block, 8, 512, num_blocks[3], stride=2)
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        self.activation = F.leaky_relu if self.mod else F.relu
        self.feature = None
        self.temp = temp

    def _make_layer(self, block, input_size, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(input_size, self.wrapped_conv, self.in_planes, planes, stride, self.mod))
            self.in_planes = planes * block.expansion
            input_size = math.ceil(input_size / stride)
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.activation(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        self.feature = out.clone().detach()
        if self.temp == 1:
            out = self.fc(out)
        else:
            out = self.fc(out) / self.temp
        return out


def resnet18(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
    model = ResNet(
        BasicBlock,
        [2, 2, 2, 2],
        spectral_normalization=spectral_normalization,
        mod=mod,
        temp=temp,
        mnist=mnist,
        **kwargs
    )
    return model


def resnet50(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
    model = ResNet(
        Bottleneck,
        [3, 4, 6, 3],
        spectral_normalization=spectral_normalization,
        mod=mod,
        temp=temp,
        mnist=mnist,
        **kwargs
    )
    return model


def resnet101(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
    model = ResNet(
        Bottleneck,
        [3, 4, 23, 3],
        spectral_normalization=spectral_normalization,
        mod=mod,
        temp=temp,
        mnist=mnist,
        **kwargs
    )
    return model


def resnet110(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
    model = ResNet(
        Bottleneck,
        [3, 4, 26, 3],
        spectral_normalization=spectral_normalization,
        mod=mod,
        temp=temp,
        mnist=mnist,
        **kwargs
    )
    return model


def resnet152(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
    model = ResNet(
        Bottleneck,
        [3, 8, 36, 3],
        spectral_normalization=spectral_normalization,
        mod=mod,
        temp=temp,
        mnist=mnist,
        **kwargs
    )
    return model
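A minimal instantiation sketch of the spectrally-normalized CIFAR-style ResNet-18 (assumes the net.spectral_normalization modules in this repo import cleanly; coeff caps each wrapped layer's spectral norm):

import torch
from net.resnet import resnet18

model = resnet18(spectral_normalization=True, mod=True, coeff=3, num_classes=10)
out = model(torch.randn(2, 3, 32, 32))
print(out.shape, model.feature.shape)  # torch.Size([2, 10]) torch.Size([2, 512])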
SPC-UQ/Image_Classification/net/resnet_edl.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
PyTorch implementation of ResNet models.
Reference:
[1] He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: CVPR, 2016.
"""
import torch
import math
import torch.nn as nn
import torch.nn.functional as F

from net.spectral_normalization.spectral_norm_conv_inplace import spectral_norm_conv
from net.spectral_normalization.spectral_norm_fc import spectral_norm_fc


class AvgPoolShortCut(nn.Module):
    def __init__(self, stride, out_c, in_c):
        super(AvgPoolShortCut, self).__init__()
        self.stride = stride
        self.out_c = out_c
        self.in_c = in_c

    def forward(self, x):
        # Downsample spatially with average pooling, then zero-pad the channel
        # dimension so the shortcut matches the block's output width.
        if x.shape[2] % 2 != 0:
            x = F.avg_pool2d(x, 1, self.stride)
        else:
            x = F.avg_pool2d(x, self.stride, self.stride)
        pad = torch.zeros(x.shape[0], self.out_c - self.in_c, x.shape[2], x.shape[3], device=x.device,)
        x = torch.cat((x, pad), dim=1)
        return x


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, input_size, wrapped_conv, in_planes, planes, stride=1, mod=True):
        super(BasicBlock, self).__init__()
        self.conv1 = wrapped_conv(input_size, in_planes, planes, kernel_size=3, stride=stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = wrapped_conv(math.ceil(input_size / stride), planes, planes, kernel_size=3, stride=1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.mod = mod
        self.activation = F.leaky_relu if self.mod else F.relu

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            if mod:
                self.shortcut = nn.Sequential(AvgPoolShortCut(stride, self.expansion * planes, in_planes))
            else:
                self.shortcut = nn.Sequential(
                    wrapped_conv(input_size, in_planes, self.expansion * planes, kernel_size=1, stride=stride,),
                    nn.BatchNorm2d(planes),
                )

    def forward(self, x):
        out = self.activation(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = self.activation(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, input_size, wrapped_conv, in_planes, planes, stride=1, mod=True):
        super(Bottleneck, self).__init__()
        self.conv1 = wrapped_conv(input_size, in_planes, planes, kernel_size=1, stride=1)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = wrapped_conv(input_size, planes, planes, kernel_size=3, stride=stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = wrapped_conv(math.ceil(input_size / stride), planes, self.expansion * planes, kernel_size=1, stride=1)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)
        self.mod = mod
        self.activation = F.leaky_relu if self.mod else F.relu

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            if mod:
                self.shortcut = nn.Sequential(AvgPoolShortCut(stride, self.expansion * planes, in_planes))
            else:
                self.shortcut = nn.Sequential(
                    wrapped_conv(input_size, in_planes, self.expansion * planes, kernel_size=1, stride=stride,),
                    nn.BatchNorm2d(self.expansion * planes),
                )

    def forward(self, x):
        out = self.activation(self.bn1(self.conv1(x)))
        out = self.activation(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = self.activation(out)
        return out


class ResNet(nn.Module):
    def __init__(
        self,
        block,
        num_blocks,
        num_classes=10,
        temp=1.0,
        spectral_normalization=True,
        mod=True,
        coeff=3,
        n_power_iterations=1,
        mnist=False,
    ):
        """
        If the "mod" parameter is set to True, the architecture uses 2 modifications:
        1. LeakyReLU instead of normal ReLU
        2. Average Pooling on the residual connections.
        """
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.mod = mod

        def wrapped_conv(input_size, in_c, out_c, kernel_size, stride):
            padding = 1 if kernel_size == 3 else 0

            conv = nn.Conv2d(in_c, out_c, kernel_size, stride, padding, bias=False)

            if not spectral_normalization:
                return conv

            # NOTE: Google uses the spectral_norm_fc in all cases
            if kernel_size == 1:
                # use spectral_norm_fc, because the bounds are tight for 1x1 convolutions
                wrapped_conv = spectral_norm_fc(conv, coeff, n_power_iterations)
            else:
                # otherwise use spectral_norm_conv, with a looser bound
                shapes = (in_c, input_size, input_size)
                wrapped_conv = spectral_norm_conv(conv, coeff, shapes, n_power_iterations)

            return wrapped_conv

        self.wrapped_conv = wrapped_conv

        self.bn1 = nn.BatchNorm2d(64)

        if mnist:
            self.conv1 = wrapped_conv(28, 1, 64, kernel_size=3, stride=1)
            self.layer1 = self._make_layer(block, 28, 64, num_blocks[0], stride=1)
            self.layer2 = self._make_layer(block, 28, 128, num_blocks[1], stride=2)
            self.layer3 = self._make_layer(block, 14, 256, num_blocks[2], stride=2)
            self.layer4 = self._make_layer(block, 7, 512, num_blocks[3], stride=2)
        else:
            self.conv1 = wrapped_conv(32, 3, 64, kernel_size=3, stride=1)
            self.layer1 = self._make_layer(block, 32, 64, num_blocks[0], stride=1)
            self.layer2 = self._make_layer(block, 32, 128, num_blocks[1], stride=2)
            self.layer3 = self._make_layer(block, 16, 256, num_blocks[2], stride=2)
            self.layer4 = self._make_layer(block, 8, 512, num_blocks[3], stride=2)
        self.activation = F.leaky_relu if self.mod else F.relu
        self.feature = None
        self.temp = temp
        # Separate linear heads for the evidential output parameters.
        self.output_alpha = nn.Linear(512 * block.expansion, num_classes)
        self.output_beta = nn.Linear(512 * block.expansion, num_classes)
        self.output_nu = nn.Linear(512 * block.expansion, num_classes)
        self.output_gamma = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, input_size, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(input_size, self.wrapped_conv, self.in_planes, planes, stride, self.mod,))
            self.in_planes = planes * block.expansion
            input_size = math.ceil(input_size / stride)
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.activation(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        self.feature = out.clone().detach()
        alpha = self.output_alpha(out)
        beta = self.output_beta(out)
        nu = self.output_nu(out)
        gamma = self.output_gamma(out)
        # Softplus keeps every evidential parameter positive (and alpha > 1).
        alpha = F.softplus(alpha) + 1
        beta = F.softplus(beta) + 1e-6
        nu = F.softplus(nu) + 1e-6
        gamma = F.softplus(gamma) + 1e-6
        return alpha  # [alpha, beta, nu, gamma]


def resnet18_edl(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
    model = ResNet(
        BasicBlock,
        [2, 2, 2, 2],
        spectral_normalization=spectral_normalization,
        mod=mod,
        temp=temp,
        mnist=mnist,
        **kwargs
    )
    return model


def resnet50_edl(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
    model = ResNet(
        Bottleneck,
        [3, 4, 6, 3],
        spectral_normalization=spectral_normalization,
        mod=mod,
        temp=temp,
        mnist=mnist,
        **kwargs
    )
    return model


def resnet101_edl(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
    model = ResNet(
        Bottleneck,
        [3, 4, 23, 3],
        spectral_normalization=spectral_normalization,
        mod=mod,
        temp=temp,
        mnist=mnist,
        **kwargs
    )
    return model


def resnet110_edl(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
    model = ResNet(
        Bottleneck,
        [3, 4, 26, 3],
        spectral_normalization=spectral_normalization,
        mod=mod,
        temp=temp,
        mnist=mnist,
        **kwargs
    )
    return model


def resnet152_edl(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
    model = ResNet(
        Bottleneck,
        [3, 8, 36, 3],
        spectral_normalization=spectral_normalization,
        mod=mod,
        temp=temp,
        mnist=mnist,
        **kwargs
    )
    return model
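
The forward pass above returns only the Dirichlet concentration parameters alpha (the beta, nu, and gamma heads are computed but commented out of the return). One common way such parameters are consumed in evidential deep learning is sketched below; these are the standard Dirichlet formulas, not necessarily the exact quantities used by this repository's training or evaluation scripts.

    # Hypothetical post-processing of the EDL head's output (not repository code).
    import torch
    from net.resnet_edl import resnet18_edl

    model = resnet18_edl()
    model.eval()
    alpha = model(torch.randn(2, 3, 32, 32))   # Dirichlet concentrations, alpha > 1
    S = alpha.sum(dim=1, keepdim=True)         # Dirichlet strength per sample
    probs = alpha / S                          # expected class probabilities
    vacuity = alpha.shape[1] / S.squeeze(1)    # K / S: high when total evidence is low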

SPC-UQ/Image_Classification/net/resnet_uq.py
ADDED
@@ -0,0 +1,272 @@
"""
PyTorch implementation of ResNet models.
Reference:
[1] He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: CVPR, 2016.
"""
import torch
import math
import torch.nn as nn
import torch.nn.functional as F

from net.spectral_normalization.spectral_norm_conv_inplace import spectral_norm_conv
from net.spectral_normalization.spectral_norm_fc import spectral_norm_fc


class AvgPoolShortCut(nn.Module):
    def __init__(self, stride, out_c, in_c):
        super(AvgPoolShortCut, self).__init__()
        self.stride = stride
        self.out_c = out_c
        self.in_c = in_c

    def forward(self, x):
        # Downsample spatially with average pooling, then zero-pad the channel
        # dimension so the shortcut matches the block's output width.
        if x.shape[2] % 2 != 0:
            x = F.avg_pool2d(x, 1, self.stride)
        else:
            x = F.avg_pool2d(x, self.stride, self.stride)
        pad = torch.zeros(x.shape[0], self.out_c - self.in_c, x.shape[2], x.shape[3], device=x.device,)
        x = torch.cat((x, pad), dim=1)
        return x


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, input_size, wrapped_conv, in_planes, planes, stride=1, mod=True):
        super(BasicBlock, self).__init__()
        self.conv1 = wrapped_conv(input_size, in_planes, planes, kernel_size=3, stride=stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = wrapped_conv(math.ceil(input_size / stride), planes, planes, kernel_size=3, stride=1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.mod = mod
        self.activation = F.leaky_relu if self.mod else F.relu

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            if mod:
                self.shortcut = nn.Sequential(AvgPoolShortCut(stride, self.expansion * planes, in_planes))
            else:
                self.shortcut = nn.Sequential(
                    wrapped_conv(input_size, in_planes, self.expansion * planes, kernel_size=1, stride=stride,),
                    nn.BatchNorm2d(planes),
                )

    def forward(self, x):
        out = self.activation(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = self.activation(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, input_size, wrapped_conv, in_planes, planes, stride=1, mod=True):
        super(Bottleneck, self).__init__()
        self.conv1 = wrapped_conv(input_size, in_planes, planes, kernel_size=1, stride=1)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = wrapped_conv(input_size, planes, planes, kernel_size=3, stride=stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = wrapped_conv(math.ceil(input_size / stride), planes, self.expansion * planes, kernel_size=1, stride=1)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)
        self.mod = mod
        self.activation = F.leaky_relu if self.mod else F.relu

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            if mod:
                self.shortcut = nn.Sequential(AvgPoolShortCut(stride, self.expansion * planes, in_planes))
            else:
                self.shortcut = nn.Sequential(
                    wrapped_conv(input_size, in_planes, self.expansion * planes, kernel_size=1, stride=stride,),
                    nn.BatchNorm2d(self.expansion * planes),
                )

    def forward(self, x):
        out = self.activation(self.bn1(self.conv1(x)))
        out = self.activation(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = self.activation(out)
        return out


class ResNet(nn.Module):
    def __init__(
        self,
        block,
        num_blocks,
        num_classes=10,
        temp=1.0,
        spectral_normalization=True,
        mod=True,
        coeff=3,
        n_power_iterations=1,
        mnist=False,
    ):
        """
        If the "mod" parameter is set to True, the architecture uses 2 modifications:
        1. LeakyReLU instead of normal ReLU
        2. Average Pooling on the residual connections.
        """
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.mod = mod

        def wrapped_conv(input_size, in_c, out_c, kernel_size, stride):
            padding = 1 if kernel_size == 3 else 0

            conv = nn.Conv2d(in_c, out_c, kernel_size, stride, padding, bias=False)

            if not spectral_normalization:
                return conv

            # NOTE: Google uses the spectral_norm_fc in all cases
            if kernel_size == 1:
                # use spectral_norm_fc, because the bounds are tight for 1x1 convolutions
                wrapped_conv = spectral_norm_fc(conv, coeff, n_power_iterations)
            else:
                # otherwise use spectral_norm_conv, with a looser bound
                shapes = (in_c, input_size, input_size)
                wrapped_conv = spectral_norm_conv(conv, coeff, shapes, n_power_iterations)

            return wrapped_conv

        self.wrapped_conv = wrapped_conv

        self.bn1 = nn.BatchNorm2d(64)

        if mnist:
            self.conv1 = wrapped_conv(28, 1, 64, kernel_size=3, stride=1)
            self.layer1 = self._make_layer(block, 28, 64, num_blocks[0], stride=1)
            self.layer2 = self._make_layer(block, 28, 128, num_blocks[1], stride=2)
            self.layer3 = self._make_layer(block, 14, 256, num_blocks[2], stride=2)
            self.layer4 = self._make_layer(block, 7, 512, num_blocks[3], stride=2)
        else:
            self.conv1 = wrapped_conv(32, 3, 64, kernel_size=3, stride=1)
            self.layer1 = self._make_layer(block, 32, 64, num_blocks[0], stride=1)
            self.layer2 = self._make_layer(block, 32, 128, num_blocks[1], stride=2)
            self.layer3 = self._make_layer(block, 16, 256, num_blocks[2], stride=2)
            self.layer4 = self._make_layer(block, 8, 512, num_blocks[3], stride=2)
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        self.activation = F.leaky_relu if self.mod else F.relu
        self.feature = None
        self.temp = temp

        def make_branch():
            # One hidden layer per auxiliary branch on top of the shared features.
            layers = []
            in_features = 512 * block.expansion
            neurons = 512 * block.expansion
            for _ in range(1):
                layers.append(nn.Linear(in_features, neurons))
                layers.append(nn.ReLU())  # must be an nn.Module; nn.Sequential rejects the bare function F.relu
                # layers.append(nn.Dropout(dropout_p))
                in_features = neurons
                # neurons //= 2
            return nn.Sequential(*layers), nn.Linear(in_features, num_classes)

        self.hidden_mar, self.mar = make_branch()
        self.hidden_mar_up, self.mar_up = make_branch()
        self.hidden_mar_down, self.mar_down = make_branch()

    def _make_layer(self, block, input_size, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(input_size, self.wrapped_conv, self.in_planes, planes, stride, self.mod,))
            self.in_planes = planes * block.expansion
            input_size = math.ceil(input_size / stride)
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.activation(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        self.feature = out.clone().detach()
        if self.temp == 1:
            pred = self.fc(out)
        else:
            pred = self.fc(out) / self.temp

        mar = self.mar(self.hidden_mar(out))
        mar_up = self.mar_up(self.hidden_mar_up(out))
        mar_down = self.mar_down(self.hidden_mar_down(out))
        return pred, mar, mar_up, mar_down


def resnet18(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
    model = ResNet(
        BasicBlock,
        [2, 2, 2, 2],
        spectral_normalization=spectral_normalization,
        mod=mod,
        temp=temp,
        mnist=mnist,
        **kwargs
    )
    return model


def resnet50(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
    model = ResNet(
        Bottleneck,
        [3, 4, 6, 3],
        spectral_normalization=spectral_normalization,
        mod=mod,
        temp=temp,
        mnist=mnist,
        **kwargs
    )
    return model


def resnet101(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
    model = ResNet(
        Bottleneck,
        [3, 4, 23, 3],
        spectral_normalization=spectral_normalization,
        mod=mod,
        temp=temp,
        mnist=mnist,
        **kwargs
    )
    return model


def resnet110(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
    model = ResNet(
        Bottleneck,
        [3, 4, 26, 3],
        spectral_normalization=spectral_normalization,
        mod=mod,
        temp=temp,
        mnist=mnist,
        **kwargs
    )
    return model


def resnet152(spectral_normalization=True, mod=True, temp=1.0, mnist=False, imagenet=False, **kwargs):
    model = ResNet(
        Bottleneck,
        [3, 8, 36, 3],
        spectral_normalization=spectral_normalization,
        mod=mod,
        temp=temp,
        mnist=mnist,
        **kwargs
    )
    return model
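
Unlike the EDL variant, this network returns four tensors: the (optionally temperature-scaled) logits plus three auxiliary heads mar, mar_up, and mar_down, each mapping the shared 512-d feature through its own hidden layer to a per-class output. How these heads are trained and combined into an uncertainty score is defined elsewhere in the repository, not in this file; the sketch below only demonstrates the forward-pass interface and is not repository code.

    # Hypothetical shape check for the UQ forward pass.
    import torch
    from net.resnet_uq import resnet18

    model = resnet18()
    model.eval()
    pred, mar, mar_up, mar_down = model(torch.randn(2, 3, 32, 32))
    print(pred.shape, mar.shape, mar_up.shape, mar_down.shape)  # each (2, 10)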

SPC-UQ/Image_Classification/net/spectral_normalization/__init__.py
ADDED
File without changes