Merge pull request #51 from Immortalise/main
add "Improving Generalization of Adversarial Training via Robust Critical Fine-Tuning" ICCV 2023
jindongwang authored Aug 18, 2023
2 parents 8b8bde5 + 8558848 commit 3eff2b7
Showing 32 changed files with 4,365 additions and 1 deletion.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -21,8 +21,12 @@
Latest research in robust machine learning, including adversarial/backdoor attack and defense, out-of-distribution (OOD) generalization, and safe transfer learning.

Hosted projects:

- **RiFT** (ICCV 2023, #Adversarial Robustness, #Generalization, #OOD):
- [Code](./RiFT/) | [Improving Generalization of Adversarial Training via Robust Critical Fine-Tuning](https://arxiv.org/abs/2308.02533)

- **Diversify** (ICLR 2023, #OOD):
- [Code](./diversify/) | [Out-of-distribution Representation Learning for Time Series Classification](https://arxiv.org/abs/2209.07027)
- **DRM** (KDD 2023, #OOD):
- [Code](./drm/) | [Domain-Specific Risk Minimization for Out-of-Distribution Generalization](https://arxiv.org/abs/2208.08661)
- **DDLearn** (KDD 2023, #OOD):
166 changes: 166 additions & 0 deletions RiFT/.gitignore
@@ -0,0 +1,166 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

data/
results*/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
105 changes: 105 additions & 0 deletions RiFT/README.md
@@ -0,0 +1,105 @@
![](https://files.mdnice.com/user/45288/023bf2cb-1685-43ce-bba8-1ba9b66f80b4.png)

# Improving Generalization of Adversarial Training via Robust Critical Fine-Tuning

This is the official implementation of ICCV2023 [Improving Generalization of Adversarial Training via Robust Critical Fine-Tuning](https://arxiv.org/abs/2308.02533).

**Abstract**: Deep neural networks are susceptible to adversarial examples, posing a significant security risk in critical applications. Adversarial Training (AT) is a well-established technique to enhance adversarial robustness, but it often comes at the cost of decreased generalization ability. This paper proposes Robustness Critical Fine-Tuning (RiFT), a novel approach to enhance generalization without compromising adversarial robustness. The core idea of RiFT is to exploit the redundant capacity for robustness by fine-tuning the adversarially trained model on its non-robust-critical module. To do so, we introduce module robust criticality (MRC), a measure that evaluates the significance of a given module to model robustness under worst-case weight perturbations. Using this measure, we identify the module with the lowest MRC value as the non-robust-critical module and fine-tune its weights to obtain fine-tuned weights. Subsequently, we linearly interpolate between the adversarially trained weights and fine-tuned weights to derive the optimal fine-tuned model weights. We demonstrate the efficacy of RiFT on ResNet18, ResNet34, and WideResNet34-10 models trained on CIFAR10, CIFAR100, and Tiny-ImageNet datasets. Our experiments show that RiFT can significantly improve both generalization and out-of-distribution robustness by around 1.5% while maintaining or even slightly enhancing adversarial robustness. Code is available at https://github.com/microsoft/robustlearn.
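The central quantity above, module robust criticality (MRC), can be written schematically as follows. This formalization is our own reading of the abstract's description, not the paper's exact definition; the norm constraint $\epsilon$ and the notation are assumptions:

```latex
% Hypothetical formalization (our assumption) of Module Robust Criticality:
% the worst-case increase in adversarial loss when only module m's weights
% theta_m are perturbed within an epsilon-ball, all other weights frozen.
\mathrm{MRC}(\theta_m) = \max_{\lVert \Delta\theta_m \rVert \le \epsilon}
  \mathcal{L}_{\mathrm{adv}}\bigl(\theta_{\setminus m},\; \theta_m + \Delta\theta_m\bigr)
  - \mathcal{L}_{\mathrm{adv}}(\theta)
```

A module with a low MRC value tolerates weight perturbations with little loss of robustness, which is what makes it safe to fine-tune.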

## Requirements

### Running Environments

To install requirements:

```
conda env create -f env.yaml
conda activate rift
```

### Datasets

CIFAR10 and CIFAR100 can be downloaded via PyTorch.

For other datasets:

1. [Tiny-ImageNet](http://cs231n.stanford.edu/tiny-imagenet-200.zip)
2. [CIFAR10-C](https://drive.google.com/drive/folders/1HDVw6CmX3HiG0ODFtI75iIfBDxSiSz2K)
3. [CIFAR100-C](https://drive.google.com/drive/folders/1HDVw6CmX3HiG0ODFtI75iIfBDxSiSz2K)
4. [Tiny-ImageNet-C](https://berkeley.app.box.com/s/6zt1qzwm34hgdzcvi45svsb10zspop8a)

After downloading these datasets, move them to `./data`.

The Tiny-ImageNet dataset contains 64x64 images spanning 200 classes.

## Robust Critical Fine-Tuning

### Demo

Here we present an example of applying RiFT to ResNet18 on CIFAR10.

Download the adversarially trained model weights [here](https://drive.google.com/drive/folders/1Uzqm1cOYFXLa97GZjjwfiVS2OcbpJK4o?usp=drive_link).

```
python main.py --layer=layer2.1.conv2 --resume="./ResNet18_CIFAR10.pth"
```

- `--layer`: the name of the layer to fine-tune.

Here, `layer2.1.conv2` is a non-robust-critical module.

The non-robust-critical modules of each model on each dataset are summarized as follows:

| | CIFAR10 | CIFAR100 | Tiny-ImageNet |
| -------- | -------------------- | -------------------- | -------------------- |
| ResNet18 | layer2.1.conv2 | layer2.1.conv2 | layer3.1.conv2 |
| ResNet34 | layer2.3.conv2 | layer2.3.conv2 | layer3.5.conv2 |
| WRN34-10 | block1.layer.3.conv2 | block1.layer.2.conv2 | block1.layer.2.conv2 |
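The dotted names in the table above follow PyTorch's module-naming convention, where each dot-separated part is either an attribute name or an index into a sequential container. Below is a minimal, dependency-free sketch (our illustration, not the repo's code) of how such a name resolves to a submodule; the toy `Model`/`Block`/`Conv` classes are hypothetical stand-ins for real `nn.Module` containers:

```python
# Resolve a dotted layer name such as "layer2.1.conv2" to a submodule:
# numeric parts index into a sequence, other parts are attribute lookups.

def resolve_module(root, dotted_name):
    """Walk `root` following `dotted_name`, one dot-separated part at a time."""
    module = root
    for part in dotted_name.split("."):
        module = module[int(part)] if part.isdigit() else getattr(module, part)
    return module

# Hypothetical toy stand-ins for nn.Module containers, for illustration only.
class Conv:
    def __init__(self, name):
        self.name = name

class Block:
    def __init__(self, name):
        self.conv2 = Conv(name)

class Model:
    def __init__(self):
        self.layer2 = [Block("layer2.0.conv2"), Block("layer2.1.conv2")]

target = resolve_module(Model(), "layer2.1.conv2")
print(target.name)  # -> layer2.1.conv2
```

In real code the same walk works against a PyTorch model, since attribute access and integer indexing on `Sequential`-style containers behave the same way.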

### Pipeline

1. Characterize the MRC for each module:
   `python main.py --cal_mrc --resume=/path/to/your/model`
   This outputs the MRC value for each module.
2. Fine-tune on a non-robust-critical module:
   Based on the MRC output, choose the module with the lowest MRC value to fine-tune.
   We suggest choosing one of the **middle layers**, based on our experience.
   Try different learning rates! Usually a small learning rate is preferred.
   `python main.py --layer=xxx --lr=yyy --resume=zzz`
   When fine-tuning finishes, the script automatically interpolates between the adversarially trained weights and the fine-tuned weights.
   The robust accuracy and in-distribution test accuracy are evaluated during the interpolation procedure.
3. Test OOD performance. Pick the best interpolation factor (the one that maximizes the in-distribution generalization gain while barely reducing robustness):
   `python eval_ood.py --resume=xxx`
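The interpolation performed at the end of step 2 can be sketched as follows. This is a minimal illustration (not the repo's implementation): weights are modeled as dicts of float lists, whereas real code would blend tensor state dicts:

```python
# Linearly interpolate adversarially trained weights w_AT and fine-tuned
# weights w_FT as (1 - alpha) * w_AT + alpha * w_FT, sweeping the factor alpha.

def interpolate_weights(w_at, w_ft, alpha):
    """Blend two state dicts entry by entry with interpolation factor alpha."""
    assert w_at.keys() == w_ft.keys()
    return {
        name: [(1 - alpha) * a + alpha * f
               for a, f in zip(w_at[name], w_ft[name])]
        for name in w_at
    }

# Hypothetical toy weights for the fine-tuned module.
w_at = {"layer2.1.conv2": [0.0, 2.0]}
w_ft = {"layer2.1.conv2": [2.0, 4.0]}
for alpha in (0.0, 0.25, 0.5, 1.0):
    print(alpha, interpolate_weights(w_at, w_ft, alpha)["layer2.1.conv2"])
```

In the actual pipeline, each interpolated checkpoint is then evaluated for robust and in-distribution accuracy so that the best factor can be picked in step 3.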

## Results

![](https://files.mdnice.com/user/45288/c3c98491-a292-4888-82cc-081bc8d3c3c6.png)




![](https://files.mdnice.com/user/45288/bad5bb9f-788d-4350-ac5c-ddd850ade04f.png)



## References & Open-Source Code

- Classification models [code](https://github.com/kuangliu/pytorch-cifar)
- Adversarial training [code](https://github.com/P2333/Bag-of-Tricks-for-AT)

## Contact

- Kaijie Zhu: [email protected]
- Jindong Wang: [email protected]

## Citation

```
@inproceedings{zhu2023improving,
  title={Improving Generalization of Adversarial Training via Robust Critical Fine-Tuning},
  author={Zhu, Kaijie and Hu, Xixu and Wang, Jindong and Xie, Xing and Yang, Ge},
  booktitle={International Conference on Computer Vision},
  year={2023},
}
```
