
Commit

unify train scripts
felixdittrich92 committed Oct 8, 2024
1 parent 2ba9e80 commit 1b2f311
Showing 14 changed files with 27 additions and 43 deletions.
.github/workflows/references.yml (12 changes: 6 additions & 6 deletions)
@@ -114,16 +114,16 @@ jobs:
unzip toy_recogition_set-036a4d80.zip -d reco_set
- if: matrix.framework == 'tensorflow'
name: Train for a short epoch (TF) (document orientation)
-run: python references/classification/train_tensorflow_orientation.py ./det_set ./det_set resnet18 page -b 2 --epochs 1
+run: python references/classification/train_tensorflow_orientation.py resnet18 --type page --train_path ./det_set --val_path ./det_set -b 2 --epochs 1
- if: matrix.framework == 'pytorch'
name: Train for a short epoch (PT) (document orientation)
-run: python references/classification/train_pytorch_orientation.py ./det_set ./det_set resnet18 page -b 2 --epochs 1
+run: python references/classification/train_pytorch_orientation.py resnet18 --type page --train_path ./det_set --val_path ./det_set -b 2 --epochs 1
- if: matrix.framework == 'tensorflow'
name: Train for a short epoch (TF) (crop orientation)
-run: python references/classification/train_tensorflow_orientation.py ./reco_set ./reco_set resnet18 crop -b 4 --epochs 1
+run: python references/classification/train_tensorflow_orientation.py resnet18 --type crop --train_path ./reco_set --val_path ./reco_set -b 4 --epochs 1
- if: matrix.framework == 'pytorch'
name: Train for a short epoch (PT) (crop orientation)
-run: python references/classification/train_pytorch_orientation.py ./reco_set ./reco_set resnet18 crop -b 4 --epochs 1
+run: python references/classification/train_pytorch_orientation.py resnet18 --type crop --train_path ./reco_set --val_path ./reco_set -b 4 --epochs 1

train-text-recognition:
runs-on: ${{ matrix.os }}
@@ -318,10 +318,10 @@ jobs:
unzip toy_detection_set-bbbb4243.zip -d det_set
- if: matrix.framework == 'tensorflow'
name: Train for a short epoch (TF)
-run: python references/detection/train_tensorflow.py --train_path ./det_set --val_path ./det_set linknet_resnet18 -b 2 --epochs 1
+run: python references/detection/train_tensorflow.py linknet_resnet18 --train_path ./det_set --val_path ./det_set -b 2 --epochs 1
- if: matrix.framework == 'pytorch'
name: Train for a short epoch (PT)
-run: python references/detection/train_pytorch.py ./det_set ./det_set db_mobilenet_v3_large -b 2 --epochs 1
+run: python references/detection/train_pytorch.py db_mobilenet_v3_large --train_path ./det_set --val_path ./det_set -b 2 --epochs 1

evaluate-text-detection:
runs-on: ${{ matrix.os }}
doctr/models/classification/magc_resnet/tensorflow.py (2 changes: 1 addition & 1 deletion)
@@ -115,7 +115,7 @@ def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
# Context modeling: B, H, W, C -> B, 1, 1, C
context = self.context_modeling(inputs)
# Transform: B, 1, 1, C -> B, 1, 1, C
-transformed = self.transform(context)
+transformed = self.transform(context, **kwargs)
return inputs + transformed


doctr/models/detection/linknet/tensorflow.py (4 changes: 2 additions & 2 deletions)
@@ -85,10 +85,10 @@ def __init__(
for in_chan, out_chan, s, in_shape in zip(i_chans, o_chans, strides, in_shapes[::-1])
]

-def call(self, x: List[tf.Tensor]) -> tf.Tensor:
+def call(self, x: List[tf.Tensor], **kwargs: Any) -> tf.Tensor:
out = 0
for decoder, fmap in zip(self.decoders, x[::-1]):
-out = decoder(out + fmap)
+out = decoder(out + fmap, **kwargs)
return out

def extra_repr(self) -> str:
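Both TensorFlow changes above (the MAGC transform and the LinkNet decoder loop) apply the same fix: `call()` now forwards `**kwargs` so that Keras flags such as `training=...` reach the wrapped sublayers instead of stopping at the outer layer. A minimal sketch of the pattern, not the doctr code, with illustrative layer names and shapes:

```python
import tensorflow as tf


class TransformBlock(tf.keras.layers.Layer):
    """Toy layer showing why **kwargs is forwarded to sublayers."""

    def __init__(self) -> None:
        super().__init__()
        self.transform = tf.keras.Sequential([
            tf.keras.layers.Conv2D(8, 1),
            tf.keras.layers.BatchNormalization(),
        ])

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        # Forwarding **kwargs lets training=True/False reach BatchNormalization;
        # without it, the sublayer always runs in its default mode.
        return inputs + self.transform(inputs, **kwargs)


out = TransformBlock()(tf.random.normal((2, 4, 4, 8)), training=False)
```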
doctr/models/factory/hub.py (2 changes: 0 additions & 2 deletions)
@@ -27,8 +27,6 @@

if is_torch_available():
import torch
-elif is_tf_available():
-pass

__all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config_for_hf_hub"]

references/classification/README.md (4 changes: 2 additions & 2 deletions)
@@ -30,13 +30,13 @@ python references/classification/train_pytorch_character.py mobilenet_v3_large -
You can start your training in TensorFlow:

```shell
-python references/classification/train_tensorflow_orientation.py path/to/your/train_set path/to/your/val_set resnet18 page --epochs 5
+python references/classification/train_tensorflow_orientation.py resnet18 --type page --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5
```

or PyTorch:

```shell
-python references/classification/train_pytorch_orientation.py path/to/your/train_set path/to/your/val_set resnet18 page --epochs 5
+python references/classification/train_pytorch_orientation.py resnet18 --type page --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5
```

The type can be either `page` for document images or `crop` for word crops.
references/classification/train_pytorch_orientation.py (6 changes: 3 additions & 3 deletions)
@@ -375,10 +375,10 @@ def parse_args():
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument("train_path", type=str, help="path to training data folder")
parser.add_argument("val_path", type=str, help="path to validation data folder")
parser.add_argument("arch", type=str, help="classification model to train")
parser.add_argument("type", type=str, choices=["page", "crop"], help="type of data to train on")
parser.add_argument("--type", type=str, required=True, choices=["page", "crop"], help="type of data to train on")
parser.add_argument("--train_path", type=str, required=True, help="path to training data folder")
parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder")
parser.add_argument("--name", type=str, default=None, help="Name of your training experiment")
parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on")
parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for training")
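With this change the classification, detection, and recognition trainers share one CLI shape: the architecture stays positional while the data type and dataset paths become named, required flags, matching the commands used in the workflow and READMEs above. A condensed sketch of the resulting parser, limited to the arguments visible in this diff (the real script defines many more options):

```python
import argparse


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Orientation training (sketch of the unified CLI)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("arch", type=str, help="classification model to train")
    parser.add_argument("--type", type=str, required=True, choices=["page", "crop"], help="type of data to train on")
    parser.add_argument("--train_path", type=str, required=True, help="path to training data folder")
    parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder")
    parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for training")
    parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on")
    return parser.parse_args()


if __name__ == "__main__":
    # e.g. python train_pytorch_orientation.py resnet18 --type page --train_path ./det_set --val_path ./det_set
    print(parse_args())
```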
references/classification/train_tensorflow_orientation.py (8 changes: 4 additions & 4 deletions)
@@ -338,7 +338,7 @@ def main(args):

if args.export_onnx:
print("Exporting model to ONNX...")
-if args.arch == "vit_b":
+if args.arch in ["vit_s", "vit_b"]:
# fixed batch size for vit
dummy_input = [tf.TensorSpec([1, *(input_size), 3], tf.float32, name="input")]
else:
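For context, the branch above builds the ONNX export dummy input with a fixed batch dimension for ViT backbones; the non-ViT branch is cut off in this view, so the sketch below assumes it keeps the batch dimension dynamic. Illustrative only, with a placeholder input size:

```python
import tensorflow as tf

input_size = (512, 512)  # placeholder; the script picks the size from the data type
arch = "vit_s"

if arch in ["vit_s", "vit_b"]:
    # fixed batch size for vit
    dummy_input = [tf.TensorSpec([1, *input_size, 3], tf.float32, name="input")]
else:
    # assumption: other architectures export with a dynamic (None) batch dimension
    dummy_input = [tf.TensorSpec([None, *input_size, 3], tf.float32, name="input")]
```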
@@ -356,10 +356,10 @@ def parse_args():
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument("train_path", type=str, help="path to training data folder")
parser.add_argument("val_path", type=str, help="path to validation data folder")
parser.add_argument("arch", type=str, help="classification model to train")
parser.add_argument("type", type=str, choices=["page", "crop"], help="type of data to train on")
parser.add_argument("--type", type=str, required=True, choices=["page", "crop"], help="type of data to train on")
parser.add_argument("--train_path", type=str, help="path to training data folder")
parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder")
parser.add_argument("--name", type=str, default=None, help="Name of your training experiment")
parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on")
parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for training")
references/detection/README.md (4 changes: 2 additions & 2 deletions)
@@ -16,13 +16,13 @@ pip install -r references/requirements.txt
You can start your training in TensorFlow:

```shell
-python references/detection/train_tensorflow.py path/to/your/train_set path/to/your/val_set db_resnet50 --epochs 5
+python references/detection/train_tensorflow.py db_resnet50 --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5
```

or PyTorch:

```shell
-python references/detection/train_pytorch.py path/to/your/train_set path/to/your/val_set db_resnet50 --epochs 5 --device 0
+python references/detection/train_pytorch.py db_resnet50 --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5
```

## Data format
references/detection/evaluate_tensorflow.py (2 changes: 1 addition & 1 deletion)
@@ -40,7 +40,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric):
for images, targets in tqdm(val_loader):
images = batch_transforms(images)
targets = [{CLASS_NAME: t} for t in targets]
-out = model(images, targets, training=False, return_preds=True)
+out = model(images, target=targets, training=False, return_preds=True)
# Compute metric
loc_preds = out["preds"]
for target, loc_pred in zip(targets, loc_preds):
references/detection/train_pytorch.py (4 changes: 2 additions & 2 deletions)
@@ -427,9 +427,9 @@ def parse_args():
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument("train_path", type=str, help="path to training data folder")
parser.add_argument("val_path", type=str, help="path to validation data folder")
parser.add_argument("arch", type=str, help="text-detection model to train")
parser.add_argument("--train_path", type=str, required=True, help="path to training data folder")
parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder")
parser.add_argument("--name", type=str, default=None, help="Name of your training experiment")
parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on")
parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for training")
references/detection/train_tensorflow.py (10 changes: 2 additions & 8 deletions)
@@ -31,7 +31,7 @@
from doctr.datasets import DataLoader, DetectionDataset
from doctr.models import detection
from doctr.utils.metrics import LocalizationConfusion
-from utils import EarlyStopper, load_backbone, plot_recorder, plot_samples
+from utils import EarlyStopper, plot_recorder, plot_samples


def record_lr(
@@ -195,11 +195,6 @@ def main(args):
if isinstance(args.resume, str):
model.load_weights(args.resume)

-if isinstance(args.pretrained_backbone, str):
-print("Loading backbone weights.")
-model = load_backbone(model, args.pretrained_backbone)
-print("Done.")

# Metrics
val_metric = LocalizationConfusion(use_polygons=args.rotation and not args.eval_straight)

@@ -409,7 +404,7 @@ def parse_args():

parser.add_argument("arch", type=str, help="text-detection model to train")
parser.add_argument("--train_path", type=str, required=True, help="path to training data folder")
parser.add_argument("--val_path", type=str, help="path to validation data folder")
parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder")
parser.add_argument("--name", type=str, default=None, help="Name of your training experiment")
parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on")
parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for training")
@@ -419,7 +414,6 @@
parser.add_argument("--input_size", type=int, default=1024, help="model input size, H = W")
parser.add_argument("--lr", type=float, default=0.001, help="learning rate for the optimizer (Adam)")
parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint")
parser.add_argument("--pretrained-backbone", type=str, default=None, help="Path to your backbone weights")
parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop")
parser.add_argument(
"--freeze-backbone", dest="freeze_backbone", action="store_true", help="freeze model backbone for fine-tuning"
references/detection/utils.py (8 changes: 0 additions & 8 deletions)
@@ -3,7 +3,6 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

-import pickle
from typing import Dict, List

import cv2
@@ -86,13 +85,6 @@ def plot_recorder(lr_recorder, loss_recorder, beta: float = 0.95, **kwargs) -> N
plt.show(**kwargs)


-def load_backbone(model, weights_path):
-pretrained_backbone_weights = pickle.load(open(weights_path, "rb"))
-model.feat_extractor.set_weights(pretrained_backbone_weights[0])
-model.fpn.set_weights(pretrained_backbone_weights[1])
-return model


class EarlyStopper:
def __init__(self, patience: int = 5, min_delta: float = 0.01):
self.patience = patience
references/recognition/README.md (2 changes: 1 addition & 1 deletion)
@@ -22,7 +22,7 @@ python references/recognition/train_tensorflow.py crnn_vgg16_bn --train_path pat
or PyTorch:

```shell
-python references/recognition/train_pytorch.py crnn_vgg16_bn --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5 --device 0
+python references/recognition/train_pytorch.py crnn_vgg16_bn --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5
```

### Multi-GPU support (PyTorch only - Experimental)
references/recognition/evaluate_tensorflow.py (2 changes: 1 addition & 1 deletion)
@@ -38,7 +38,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric):
for images, targets in tqdm(val_iter):
try:
images = batch_transforms(images)
-out = model(images, targets, return_preds=True, training=False)
+out = model(images, target=targets, return_preds=True, training=False)
# Compute metric
if len(out["preds"]):
words, _ = zip(*out["preds"])
