From e93b5a1acc9e5317b5347ee22f1282540ffc127a Mon Sep 17 00:00:00 2001
From: felix
Date: Sat, 23 Sep 2023 14:31:29 +0200
Subject: [PATCH] update train scripts

---
 references/detection/train_pytorch.py       |  2 +-
 references/recognition/train_pytorch.py     |  8 ++++++
 references/recognition/train_pytorch_ddp.py | 31 +++++++++++++++++++--
 references/recognition/train_tensorflow.py  |  8 ++++++
 4 files changed, 45 insertions(+), 4 deletions(-)

diff --git a/references/detection/train_pytorch.py b/references/detection/train_pytorch.py
index b65e7aff41..41f421a341 100644
--- a/references/detection/train_pytorch.py
+++ b/references/detection/train_pytorch.py
@@ -317,7 +317,7 @@ def main(args):
     # Backbone freezing
     if args.freeze_backbone:
         for p in model.feat_extractor.parameters():
-            p.reguires_grad_(False)
+            p.requires_grad = False
 
     # Optimizer
     optimizer = torch.optim.Adam(
diff --git a/references/recognition/train_pytorch.py b/references/recognition/train_pytorch.py
index c419571d8d..d27f7cc64d 100644
--- a/references/recognition/train_pytorch.py
+++ b/references/recognition/train_pytorch.py
@@ -339,6 +339,11 @@ def main(args):
         plot_samples(x, target)
         return
 
+    # Backbone freezing
+    if args.freeze_backbone:
+        for p in model.feat_extractor.parameters():
+            p.requires_grad = False
+
     # Optimizer
     optimizer = torch.optim.Adam(
         [p for p in model.parameters() if p.requires_grad],
@@ -457,6 +462,9 @@ def parse_args():
     parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint")
     parser.add_argument("--vocab", type=str, default="french", help="Vocab to be used for training")
     parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop")
+    parser.add_argument(
+        "--freeze-backbone", dest="freeze_backbone", action="store_true", help="Freeze model backbone for fine-tuning"
+    )
     parser.add_argument(
         "--show-samples", dest="show_samples", action="store_true", help="Display unormalized training samples"
     )
diff --git a/references/recognition/train_pytorch_ddp.py b/references/recognition/train_pytorch_ddp.py
index ec18ce0298..337aebb700 100644
--- a/references/recognition/train_pytorch_ddp.py
+++ b/references/recognition/train_pytorch_ddp.py
@@ -23,7 +23,14 @@
 from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
 from torch.utils.data import DataLoader, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
-from torchvision.transforms import ColorJitter, Compose, Normalize
+from torchvision.transforms.v2 import (
+    Compose,
+    GaussianBlur,
+    Normalize,
+    RandomGrayscale,
+    RandomPerspective,
+    RandomPhotometricDistort,
+)
 
 from doctr import transforms as T
 from doctr.datasets import VOCABS, RecognitionDataset, WordGenerator
@@ -170,6 +177,11 @@ def main(rank: int, world_size: int, args):
         checkpoint = torch.load(args.resume, map_location="cpu")
         model.load_state_dict(checkpoint)
 
+    # Backbone freezing
+    if args.freeze_backbone:
+        for p in model.feat_extractor.parameters():
+            p.requires_grad = False
+
     # create default process group
     device = torch.device("cuda", args.devices[rank])
     dist.init_process_group(args.backend, rank=rank, world_size=world_size)
@@ -211,7 +223,12 @@ def main(rank: int, world_size: int, args):
                 T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
                 # Augmentations
                 T.RandomApply(T.ColorInversion(), 0.1),
-                ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02),
+                RandomGrayscale(p=0.1),
+                RandomPhotometricDistort(p=0.1),
+                T.RandomApply(T.RandomShadow(), p=0.4),
+                T.RandomApply(T.GaussianNoise(mean=0, std=0.1), 0.1),
+                T.RandomApply(GaussianBlur(3), 0.3),
+                RandomPerspective(distortion_scale=0.2, p=0.3),
             ]
         ),
     )
@@ -234,7 +251,12 @@ def main(rank: int, world_size: int, args):
                 T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
                 # Ensure we have a 90% split of white-background images
                 T.RandomApply(T.ColorInversion(), 0.9),
-                ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02),
+                RandomGrayscale(p=0.1),
+                RandomPhotometricDistort(p=0.1),
+                T.RandomApply(T.RandomShadow(), p=0.4),
+                T.RandomApply(T.GaussianNoise(mean=0, std=0.1), 0.1),
+                T.RandomApply(GaussianBlur(3), 0.3),
+                RandomPerspective(distortion_scale=0.2, p=0.3),
             ]
         ),
     )
@@ -376,6 +398,9 @@ def parse_args():
     parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint")
     parser.add_argument("--vocab", type=str, default="french", help="Vocab to be used for training")
    parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop")
+    parser.add_argument(
+        "--freeze-backbone", dest="freeze_backbone", action="store_true", help="Freeze model backbone for fine-tuning"
+    )
     parser.add_argument(
         "--show-samples", dest="show_samples", action="store_true", help="Display unormalized training samples"
     )
diff --git a/references/recognition/train_tensorflow.py b/references/recognition/train_tensorflow.py
index 8064a1dde5..9dfa957dd9 100644
--- a/references/recognition/train_tensorflow.py
+++ b/references/recognition/train_tensorflow.py
@@ -333,6 +333,11 @@ def main(args):
         task = Task.init(project_name="docTR/text-recognition", task_name=exp_name, reuse_last_task_id=False)
         task.upload_artifact("config", config)
 
+    # Backbone freezing
+    if args.freeze_backbone:
+        for layer in model.feat_extractor.layers:
+            layer.trainable = False
+
     min_loss = np.inf
 
     # Training loop
@@ -413,6 +418,9 @@ def parse_args():
     parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint")
     parser.add_argument("--vocab", type=str, default="french", help="Vocab to be used for training")
     parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop")
+    parser.add_argument(
+        "--freeze-backbone", dest="freeze_backbone", action="store_true", help="Freeze model backbone for fine-tuning"
+    )
     parser.add_argument(
         "--show-samples", dest="show_samples", action="store_true", help="Display unormalized training samples"
     )
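
Note: the three recognition scripts gain the same `--freeze-backbone` flag that the detection script already had, and the PyTorch freezing logic is identical across them. Below is a minimal standalone sketch of that pattern, not part of the patch itself; it assumes a doctr recognition model exposing a `feat_extractor` backbone, and `crnn_vgg16_bn` plus the learning rate are example choices only:

    import torch

    from doctr.models import recognition

    # Example architecture; any doctr recognition model with a
    # `feat_extractor` backbone works the same way
    model = recognition.crnn_vgg16_bn(pretrained=True)

    # Freeze the backbone with plain attribute assignment (the in-place
    # method is spelled `requires_grad_`; the old detection script had
    # misspelled it as `reguires_grad_`)
    for p in model.feat_extractor.parameters():
        p.requires_grad = False

    # Build the optimizer over the still-trainable (head) parameters only,
    # mirroring how the training scripts construct it
    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad],
        lr=0.001,  # example value
    )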