Allow the use of phikon features in the benchmark script (#297)
* feat: Use torchvision normalize function for ResNet and improve weights for Camelyon16

* feat: Add phikon feature extractor

* feat: Improve Camelyon16 benchmark script

* feat: Allow Camelyon16 models to use features of dimension different from 2048

* feat: Use 45 epochs

* feat: Use ssl features for fed_benchmark

* feat: Add plot
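
For example, with these changes the benchmark can be pointed at the phikon features with an invocation along these lines (only flags introduced or shown in this diff; dataset and strategy selection are unchanged):

python flamby/benchmarks/fed_benchmark.py --GPU 0 --seed 0 --use-ssl-features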
xavier-owkin authored Jan 9, 2024
1 parent ec5af8e commit c658e54
Showing 7 changed files with 5,836 additions and 25 deletions.
62 changes: 44 additions & 18 deletions flamby/benchmarks/fed_benchmark.py
@@ -112,7 +112,10 @@ def main(args_cli):
     NUM_EPOCHS_POOLED = 0

     # We can now instantiate the dataset specific model on CPU
-    global_init = Baseline()
+    if args_cli.use_ssl_features and dataset_name == "fed_camelyon16":
+        global_init = Baseline(768)
+    else:
+        global_init = Baseline()

     # We parse the hyperparams from the config or from the CLI if strategy is given
     strategy_specific_hp_dicts = get_strategies(
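Here 768 is the embedding width of the phikon ViT features, while the default Baseline() keeps the 2048-dimensional ResNet embeddings (see the model.py change below).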
@@ -461,7 +464,6 @@ def main(args_cli):


 if __name__ == "__main__":
-
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--GPU", type=int, default=0, help="GPU to run the training on (if available)"
@@ -513,33 +515,41 @@ def main(args_cli):
         "-nft",
         type=int,
         default=None,
-        help="The number of SGD fine-tuning updates to be"
-        "performed on the model at the personalization step,"
-        "if strategy is given and that it is FedAvgFineTuning",
+        help=(
+            "The number of SGD fine-tuning updates to be "
+            "performed on the model at the personalization step, "
+            "if a strategy is given and it is FedAvgFineTuning"
+        ),
     )
     parser.add_argument(
         "--tau",
         "-tau",
         type=float,
         default=None,
-        help="FedOpt tau parameter used only if strategy is "
-        "given and that it is a fedopt strategy",
+        help=(
+            "FedOpt tau parameter used only if a strategy is "
+            "given and it is a FedOpt strategy"
+        ),
     )
     parser.add_argument(
         "--beta1",
         "-b1",
         type=float,
         default=None,
-        help="FedOpt beta1 parameter used only if strategy is "
-        "given and that it is a fedopt strategy",
+        help=(
+            "FedOpt beta1 parameter used only if a strategy is "
+            "given and it is a FedOpt strategy"
+        ),
     )
     parser.add_argument(
         "--beta2",
         "-b2",
         type=float,
         default=None,
-        help="FedOpt beta2 parameter used only if strategy is"
-        " given and that it is a fedopt strategy",
+        help=(
+            "FedOpt beta2 parameter used only if a strategy is"
+            " given and it is a FedOpt strategy"
+        ),
     )
     parser.add_argument(
         "--strategy",
@@ -578,22 +588,24 @@ def main(args_cli):
         "-dpe",
         type=float,
         default=None,
-        help="the target epsilon for (epsilon, delta)-differential" "private guarantee",
+        help="the target epsilon for the (epsilon, delta)-differential privacy guarantee",
     )
     parser.add_argument(
         "--dp_target_delta",
         "-dpd",
         type=float,
         default=None,
-        help="the target delta for (epsilon, delta)-differential" "private guarantee",
+        help="the target delta for the (epsilon, delta)-differential privacy guarantee",
     )
     parser.add_argument(
         "--dp_max_grad_norm",
         "-mgn",
         type=float,
         default=None,
-        help="the maximum L2 norm of per-sample gradients; "
-        "used to enforce differential privacy",
+        help=(
+            "the maximum L2 norm of per-sample gradients; "
+            "used to enforce differential privacy"
+        ),
     )
     parser.add_argument(
         "--log",
@@ -621,15 +633,29 @@ def main(args_cli):
         "-scb",
         default=None,
         type=str,
-        help="Whether or not to compute only one single-centric baseline and which one.",
+        help=(
+            "Whether or not to compute only one single-centric baseline and which one."
+        ),
        choices=["Pooled", "Local"],
     )
     parser.add_argument(
         "--nlocal",
         default=0,
         type=int,
-        help="Will only be used if --single-centric-baseline Local, will test"
-        "only training on Local {nlocal}.",
+        help=(
+            "Will only be used if --single-centric-baseline Local; will test "
+            "only training on Local {nlocal}."
+        ),
     )
+    parser.add_argument(
+        "--use-ssl-features",
+        action="store_true",
+        help=(
+            "Whether to use the much more performant phikon feature extractor on"
+            " Camelyon16, trained with self-supervised learning on histology datasets"
+            " from https://www.medrxiv.org/content/10.1101/2023.07.21.23292757v2"
+            " instead of an ImageNet-trained ResNet."
+        ),
+    )
     parser.add_argument("--seed", default=0, type=int, help="Seed")

7 changes: 5 additions & 2 deletions flamby/datasets/fed_camelyon16/benchmark.py
@@ -39,8 +39,11 @@ def main(num_workers_torch, log=False, log_period=10, debug=False, cpu_only=False):
     """
     metrics_dict = {"AUC": metric}
     use_gpu = torch.cuda.is_available() and not (cpu_only)
+    training_set = FedCamelyon16(train=True, pooled=True, debug=debug)
+    # extract feature dimension used
+    features_dimension = training_set[0][0].size(1)
     training_dl = dl(
-        FedCamelyon16(train=True, pooled=True, debug=debug),
+        training_set,
         num_workers=num_workers_torch,
         batch_size=BATCH_SIZE,
         collate_fn=collate_fn,
@@ -69,7 +72,7 @@ def main(num_workers_torch, log=False, log_period=10, debug=False, cpu_only=False):
         # At each new seed we re-initialize the model
         # and training_dl is shuffled as well
         torch.manual_seed(seed)
-        m = Baseline()
+        m = Baseline(features_dimension)
         # We put the model on GPU whenever it is possible
         if use_gpu:
             m = m.cuda()
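
Pulling the changed lines together, the benchmark now adapts to whichever extractor produced the stored features; a minimal sketch of the assumed item layout:

training_set = FedCamelyon16(train=True, pooled=True, debug=False)
features = training_set[0][0]           # assumed shape: (n_tiles, d)
features_dimension = features.size(1)   # 2048 for ResNet features, 768 for phikon
m = Baseline(features_dimension)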
7 changes: 4 additions & 3 deletions flamby/datasets/fed_camelyon16/dataset.py
@@ -258,21 +258,22 @@ def collate_fn(dataset_elements_list, max_tiles=10000):
     Parameters
     ----------
     dataset_elements_list : List[torch.Tensor]
-        A list of torch tensors of dimensions [n, 2048] with uneven distribution of ns.
+        A list of torch tensors of dimensions [n, m] with an uneven distribution of ns.
     max_tiles : int, optional
         The maximum number of tiles per tensor, by default 10000

     Returns
     -------
     Tuple(torch.Tensor, torch.Tensor)
-        X, y two torch tensors of size (len(dataset_elements_list), max_tiles, 2048) and
+        X, y two torch tensors of size (len(dataset_elements_list), max_tiles, m) and
         (len(dataset_elements_list), 1)
     """
     n = len(dataset_elements_list)
     X0, y0, _ = dataset_elements_list[0]
+    feature_dim = X0.size(1)
     X_dtype = X0.dtype
     y_dtype = y0.dtype
-    X = torch.zeros((n, max_tiles, 2048), dtype=X_dtype)
+    X = torch.zeros((n, max_tiles, feature_dim), dtype=X_dtype)
     y = torch.empty((n, 1), dtype=y_dtype)

     for i in range(n):
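
A quick sketch of the adapted collate_fn (the tile counts and the unused third tuple element are assumptions for illustration):

import torch
# two slides with different tile counts and 768-d (phikon) tile features
batch = [
    (torch.randn(500, 768), torch.tensor([0.0]), None),
    (torch.randn(1200, 768), torch.tensor([1.0]), None),
]
X, y = collate_fn(batch, max_tiles=10000)
# X: (2, 10000, 768), zero-padded along the tile axis; y: (2, 1)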
@@ -19,6 +19,7 @@
 from torchvision.transforms import Compose, ToTensor
 from tqdm import tqdm
+from transformers import AutoImageProcessor, ViTModel

 from flamby.utils import read_config, write_value_in_config
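
The newly imported transformers classes are presumably used along these lines to embed each tile; a hedged sketch (the owkin/phikon checkpoint name and the CLS-token readout follow the public phikon release, not this diff):

import torch
from PIL import Image
from transformers import AutoImageProcessor, ViTModel

processor = AutoImageProcessor.from_pretrained("owkin/phikon")
model = ViTModel.from_pretrained("owkin/phikon", add_pooling_layer=False)
model.eval()

tile = Image.open("tile.png").convert("RGB")  # hypothetical tile image
inputs = processor(images=tile, return_tensors="pt")
with torch.no_grad():
    features = model(**inputs).last_hidden_state[:, 0]  # (1, 768) CLS embedding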
4 changes: 2 additions & 2 deletions flamby/datasets/fed_camelyon16/model.py
@@ -7,10 +7,10 @@


 class Baseline(nn.Module):
-    def __init__(self):
+    def __init__(self, original_dimension: int = 2048):
         super(Baseline, self).__init__()
         # As per the article
-        self.Od = 2048  # Original dimension of the input embeddings
+        self.Od = original_dimension  # Original dimension of the input embeddings
         self.M = 128  # New dimension of the input embedding

         self.L = 128  # Dimension of the new features after query and value projections
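
With the widened constructor, both feature extractors plug into the same downstream model; a minimal usage sketch (assuming the usual FLamby import path):

from flamby.datasets.fed_camelyon16 import Baseline

model_resnet = Baseline()     # Od = 2048, matches ResNet tile features
model_phikon = Baseline(768)  # Od = 768, matches phikon tile features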
