Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[scripts] Add backbone freeze for recognition scripts and update augmentations also for DDP script #1328

Merged
merged 1 commit into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion references/detection/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def main(args):
# Backbone freezing
if args.freeze_backbone:
for p in model.feat_extractor.parameters():
p.reguires_grad_(False)
p.requires_grad = False

# Optimizer
optimizer = torch.optim.Adam(
Expand Down
8 changes: 8 additions & 0 deletions references/recognition/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,11 @@ def main(args):
plot_samples(x, target)
return

# Backbone freezing
if args.freeze_backbone:
for p in model.feat_extractor.parameters():
p.requires_grad = False

# Optimizer
optimizer = torch.optim.Adam(
[p for p in model.parameters() if p.requires_grad],
Expand Down Expand Up @@ -457,6 +462,9 @@ def parse_args():
parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint")
parser.add_argument("--vocab", type=str, default="french", help="Vocab to be used for training")
parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop")
parser.add_argument(
"--freeze-backbone", dest="freeze_backbone", action="store_true", help="freeze model backbone for fine-tuning"
)
parser.add_argument(
"--show-samples", dest="show_samples", action="store_true", help="Display unormalized training samples"
)
Expand Down
31 changes: 28 additions & 3 deletions references/recognition/train_pytorch_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,14 @@
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
from torch.utils.data import DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from torchvision.transforms import ColorJitter, Compose, Normalize
from torchvision.transforms.v2 import (
Compose,
GaussianBlur,
Normalize,
RandomGrayscale,
RandomPerspective,
RandomPhotometricDistort,
)

from doctr import transforms as T
from doctr.datasets import VOCABS, RecognitionDataset, WordGenerator
Expand Down Expand Up @@ -170,6 +177,11 @@ def main(rank: int, world_size: int, args):
checkpoint = torch.load(args.resume, map_location="cpu")
model.load_state_dict(checkpoint)

# Backbone freezing
if args.freeze_backbone:
for p in model.feat_extractor.parameters():
p.requires_grad = False

# create default process group
device = torch.device("cuda", args.devices[rank])
dist.init_process_group(args.backend, rank=rank, world_size=world_size)
Expand Down Expand Up @@ -211,7 +223,12 @@ def main(rank: int, world_size: int, args):
T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
# Augmentations
T.RandomApply(T.ColorInversion(), 0.1),
ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02),
RandomGrayscale(p=0.1),
RandomPhotometricDistort(p=0.1),
T.RandomApply(T.RandomShadow(), p=0.4),
T.RandomApply(T.GaussianNoise(mean=0, std=0.1), 0.1),
T.RandomApply(GaussianBlur(3), 0.3),
RandomPerspective(distortion_scale=0.2, p=0.3),
]
),
)
Expand All @@ -234,7 +251,12 @@ def main(rank: int, world_size: int, args):
T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
# Ensure we have a 90% split of white-background images
T.RandomApply(T.ColorInversion(), 0.9),
ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02),
RandomGrayscale(p=0.1),
RandomPhotometricDistort(p=0.1),
T.RandomApply(T.RandomShadow(), p=0.4),
T.RandomApply(T.GaussianNoise(mean=0, std=0.1), 0.1),
T.RandomApply(GaussianBlur(3), 0.3),
RandomPerspective(distortion_scale=0.2, p=0.3),
]
),
)
Expand Down Expand Up @@ -376,6 +398,9 @@ def parse_args():
parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint")
parser.add_argument("--vocab", type=str, default="french", help="Vocab to be used for training")
parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop")
parser.add_argument(
"--freeze-backbone", dest="freeze_backbone", action="store_true", help="freeze model backbone for fine-tuning"
)
parser.add_argument(
"--show-samples", dest="show_samples", action="store_true", help="Display unormalized training samples"
)
Expand Down
8 changes: 8 additions & 0 deletions references/recognition/train_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,11 @@ def main(args):
task = Task.init(project_name="docTR/text-recognition", task_name=exp_name, reuse_last_task_id=False)
task.upload_artifact("config", config)

# Backbone freezing
if args.freeze_backbone:
for layer in model.feat_extractor.layers:
layer.trainable = False

min_loss = np.inf

# Training loop
Expand Down Expand Up @@ -413,6 +418,9 @@ def parse_args():
parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint")
parser.add_argument("--vocab", type=str, default="french", help="Vocab to be used for training")
parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop")
parser.add_argument(
"--freeze-backbone", dest="freeze_backbone", action="store_true", help="freeze model backbone for fine-tuning"
)
parser.add_argument(
"--show-samples", dest="show_samples", action="store_true", help="Display unormalized training samples"
)
Expand Down
Loading