Skip to content

Commit

Permalink
Pass train-crop-mode to create_loader/transforms from train.py args
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Jan 25, 2024
1 parent 53a4888 commit 809a9e1
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 0 deletions.
2 changes: 2 additions & 0 deletions timm/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def create_loader(
re_mode: str = 'const',
re_count: int = 1,
re_split: bool = False,
train_crop_mode: Optional[str] = None,
scale: Optional[Tuple[float, float]] = None,
ratio: Optional[Tuple[float, float]] = None,
hflip: float = 0.5,
Expand Down Expand Up @@ -280,6 +281,7 @@ def create_loader(
input_size,
is_training=is_training,
no_aug=no_aug,
train_crop_mode=train_crop_mode,
scale=scale,
ratio=ratio,
hflip=hflip,
Expand Down
5 changes: 5 additions & 0 deletions timm/data/transforms_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def transforms_imagenet_train(
Args:
img_size: Target image size.
train_crop_mode: Training random crop mode ('rrc', 'rkrc', 'rkrr').
scale: Random resize scale range (crop area, < 1.0 => zoom in).
ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR).
hflip: Horizontal flip probability.
Expand Down Expand Up @@ -112,6 +113,7 @@ def transforms_imagenet_train(
* normalizes and converts the branches above with the third, final transform
"""
train_crop_mode = train_crop_mode or 'rrc'
assert train_crop_mode in {'rrc', 'rkrc', 'rkrr'}
if train_crop_mode in ('rkrc', 'rkrr'):
# FIXME integration of RKR is a WIP
scale = tuple(scale or (0.8, 1.00))
Expand Down Expand Up @@ -318,6 +320,7 @@ def create_transform(
input_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = 224,
is_training: bool = False,
no_aug: bool = False,
train_crop_mode: Optional[str] = None,
scale: Optional[Tuple[float, float]] = None,
ratio: Optional[Tuple[float, float]] = None,
hflip: float = 0.5,
Expand Down Expand Up @@ -347,6 +350,7 @@ def create_transform(
input_size: Target input size (channels, height, width) tuple or size scalar.
is_training: Return training (random) transforms.
no_aug: Disable augmentation for training (useful for debug).
train_crop_mode: Training random crop mode ('rrc', 'rkrc', 'rkrr').
scale: Random resize scale range (crop area, < 1.0 => zoom in).
ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR).
hflip: Horizontal flip probability.
Expand Down Expand Up @@ -400,6 +404,7 @@ def create_transform(
elif is_training:
transform = transforms_imagenet_train(
img_size,
train_crop_mode=train_crop_mode,
scale=scale,
ratio=ratio,
hflip=hflip,
Expand Down
3 changes: 3 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@
group = parser.add_argument_group('Augmentation and regularization parameters')
group.add_argument('--no-aug', action='store_true', default=False,
help='Disable all training augmentation, override other train aug args')
group.add_argument('--train-crop-mode', type=str, default=None,
help='Crop-mode in train'),
group.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
help='Random resize scale (default: 0.08 1.0)')
group.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
Expand Down Expand Up @@ -685,6 +687,7 @@ def main():
re_mode=args.remode,
re_count=args.recount,
re_split=args.resplit,
train_crop_mode=args.train_crop_mode,
scale=args.scale,
ratio=args.ratio,
hflip=args.hflip,
Expand Down

0 comments on commit 809a9e1

Please sign in to comment.