Commit 5ef9804

updated vit

1 parent 4833d54 commit 5ef9804

39 files changed: +1185 −122 lines

README.md

Lines changed: 2 additions & 2 deletions

@@ -95,8 +95,8 @@ The supported methods are as follows:
 - [x] [ArcFace_mxnet (CVPR'2019)](recognition/arcface_mxnet)
 - [x] [ArcFace_torch (CVPR'2019)](recognition/arcface_torch)
 - [x] [SubCenter ArcFace (ECCV'2020)](recognition/subcenter_arcface)
-- [x] [PartialFC_mxnet (Arxiv'2020)](recognition/partial_fc)
-- [x] [PartialFC_torch (Arxiv'2020)](recognition/arcface_torch)
+- [x] [PartialFC_mxnet (CVPR'2022)](recognition/partial_fc)
+- [x] [PartialFC_torch (CVPR'2022)](recognition/arcface_torch)
 - [x] [VPL (CVPR'2021)](recognition/vpl)
 - [x] [Arcface_oneflow](recognition/arcface_oneflow)
 - [x] [ArcFace_Paddle (CVPR'2019)](recognition/arcface_paddle)

recognition/arcface_torch/.gitignore

Lines changed: 0 additions & 5 deletions
This file was deleted.

recognition/arcface_torch/README.md

Lines changed: 58 additions & 19 deletions

@@ -5,7 +5,10 @@ identity on a single server.

 ## Requirements

-- Install [PyTorch](http://pytorch.org) (torch>=1.6.0), our doc for [install.md](docs/install.md).
+To take advantage of the features in recent PyTorch releases, we have upgraded torch to 1.9.0.
+Torch versions earlier than 1.9.0 may not work in the future.
+
+- Install [PyTorch](http://pytorch.org) (torch>=1.9.0), our doc for [install.md](docs/install.md).
 - (Optional) Install [DALI](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/), our doc for [install_dali.md](docs/install_dali.md).
 - `pip install -r requirement.txt`.

@@ -58,26 +61,55 @@ For **ICCV2021-MFR-ALL** set, TAR is measured on all-to-all 1:1 protocol, with F
 globalised multi-racial testset contains 242,143 identities and 1,624,305 images.

+> 1. Large Scale Datasets
+
+| Datasets         | Backbone    | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | Training Throughput | log |
+|:-----------------|:------------|:------------|:------------|:------------|:--------------------|:----|
+| MS1MV3           | mobileface  | 65.76 | 94.44 | 91.85 | ~13000         | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_mobileface_lr02/training.log) |
+| Glint360K        | mobileface  | 69.83 | 95.17 | 92.58 | ~11000         | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_mobileface_lr02_bs4k/training.log) |
+| WF42M-PFC-0.2    | mobileface  | 73.80 | 95.40 | 92.64 | (16GPUs)~18583 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_mobilefacenet_pfc02_bs8k_16gpus/training.log) |
+| MS1MV3           | r100        | 83.23 | 96.88 | 95.31 | ~3400          | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_r100_lr02/training.log) |
+| Glint360K        | r100        | 90.86 | 97.53 | 96.43 | ~5000          | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_r100_lr02_bs4k_16gpus/training.log) |
+| WF42M-PFC-0.2    | r50(bs4k)   | 93.83 | 97.53 | 96.16 | (8 GPUs)~5900  | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r50_bs4k_pfc02/training.log) |
+| WF42M-PFC-0.2    | r50(bs8k)   | 93.96 | 97.46 | 96.12 | (16GPUs)~11000 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r50_lr01_pfc02_bs8k_16gpus/training.log) |
+| WF42M-PFC-0.2    | r50(bs4k)   | 94.04 | 97.48 | 95.94 | (32GPUs)~17000 | click me |
+| WF42M-PFC-0.0018 | r100(bs16k) | 93.08 | 97.51 | 95.88 | (32GPUs)~10000 | click me |
+| WF42M-PFC-0.2    | r100(bs4k)  | 96.69 | 97.85 | 96.63 | (16GPUs)~5200  | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r100_bs4k_pfc02/training.log) |
+
+> 2. ViT For Face Recognition
+
+| Datasets      | Backbone     | FLOPs | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | Training Throughput | log      |
+|:--------------|:-------------|:------|:------------|:------------|:------------|:--------------------|:---------|
+| WF42M-PFC-0.3 | R18(bs4k)    | 2.6   | 79.13 | 95.77 | 93.36 | -              | click me |
+| WF42M-PFC-0.3 | R50(bs4k)    | 6.3   | 94.03 | 97.48 | 95.94 | -              | click me |
+| WF42M-PFC-0.3 | R100(bs4k)   | 12.1  | 96.69 | 97.82 | 96.45 | -              | click me |
+| WF42M-PFC-0.3 | R200(bs4k)   | 23.5  | 97.70 | 97.97 | 96.93 | -              | click me |
+| WF42M-PFC-0.3 | VIT-T(bs24k) | 1.5   | 92.24 | 97.31 | 95.97 | (64GPUs)~35000 | click me |
+| WF42M-PFC-0.3 | VIT-S(bs24k) | 5.7   | 95.87 | 97.73 | 96.57 | (64GPUs)~25000 | click me |
+| WF42M-PFC-0.3 | VIT-B(bs24k) | 11.4  | 97.42 | 97.90 | 97.04 | (64GPUs)~13800 | click me |
+| WF42M-PFC-0.3 | VIT-L(bs24k) | 25.3  | 97.85 | 98.00 | 97.23 | (64GPUs)~9406  | click me |
+
+WF42M means WebFace42M; `PFC-0.3` means the negative class centers are sampled at a rate of 0.3.
+
+> 3. Noisy Datasets
+
+| Datasets                 | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | log      |
+|:-------------------------|:---------|:------------|:------------|:------------|:---------|
+| WF12M-Flip(40%)          | R50      | 43.87 | 88.35 | 80.78 | click me |
+| WF12M-Flip(40%)-PFC-0.3* | R50      | 80.20 | 96.11 | 93.79 | click me |
+| WF12M-Conflict           | R50      | 79.93 | 95.30 | 91.56 | click me |
+| WF12M-Conflict-PFC-0.3*  | R50      | 91.68 | 97.28 | 95.75 | click me |
+
+WF12M means WebFace12M; `+PFC-0.3*` denotes additional abnormal inter-class filtering.

-| Datasets                 | Backbone   | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | Training Throughout | log |
-|:-------------------------|:-----------|:------------|:------------|:------------|:--------------------|:----|
-| MS1MV3                   | mobileface | 65.76 | 94.44 | 91.85 | ~13000         | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_mobileface_lr02/training.log)\|[config](configs/ms1mv3_mobileface_lr02.py) |
-| Glint360K                | mobileface | 69.83 | 95.17 | 92.58 | -11000         | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_mobileface_lr02_bs4k/training.log)\|[config](configs/glint360k_mobileface_lr02_bs4k.py) |
-| WebFace42M-PartialFC-0.2 | mobileface | 73.80 | 95.40 | 92.64 | (16GPUs)~18583 | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_mobilefacenet_pfc02_bs8k_16gpus/training.log)\|[config](configs/webface42m_mobilefacenet_pfc02_bs8k_16gpus.py) |
-| MS1MV3                   | r100       | 83.23 | 96.88 | 95.31 | ~3400          | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_r100_lr02/training.log)\|[config](configs/ms1mv3_r100_lr02.py) |
-| Glint360K                | r100       | 90.86 | 97.53 | 96.43 | ~5000          | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_r100_lr02_bs4k_16gpus/training.log)\|[config](configs/glint360k_r100_lr02_bs4k_16gpus.py) |
-| WebFace42M-PartialFC-0.2 | r50(bs4k)  | 93.83 | 97.53 | 96.16 | (8 GPUs)~5900  | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r50_bs4k_pfc02/training.log)\|[config](configs/webface42m_r50_lr01_pfc02_bs4k_8gpus.py) |
-| WebFace42M-PartialFC-0.2 | r50(bs8k)  | 93.96 | 97.46 | 96.12 | (16GPUs)~11000 | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r50_lr01_pfc02_bs8k_16gpus/training.log)\|[config](configs/webface42m_r50_lr01_pfc02_bs8k_16gpus.py) |
-| WebFace42M-PartialFC-0.2 | r50(bs4k)  | 94.04 | 97.48 | 95.94 | (32GPUs)~17000 | log\|[config](configs/webface42m_r50_lr01_pfc02_bs4k_32gpus.py) |
-| WebFace42M-PartialFC-0.2 | r100(bs4k) | 96.69 | 97.85 | 96.63 | (16GPUs)~5200  | [log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r100_bs4k_pfc02/training.log)\|[config](configs/webface42m_r100_lr01_pfc02_bs4k_16gpus.py) |
-| WebFace42M-PartialFC-0.2 | r200       | - | - | - | - | log\|config |

-`PartialFC-0.2` means negivate class centers sample rate is 0.2.

 ## Speed Benchmark
+<div><img src="https://github.com/anxiangsir/insightface_arcface_log/blob/master/pfc_exp.png" width = "90%" /></div>
+

-`arcface_torch` can train large-scale face recognition training set efficiently and quickly. When the number of
+**Arcface-Torch** can train large-scale face recognition training sets efficiently and quickly. When the number of
 classes in training sets is greater than 1 million, the partial FC sampling strategy achieves the same
 accuracy with several times faster training performance and smaller GPU memory.
 Partial FC is a sparse variant of the model parallel architecture for large scale face recognition. Partial FC uses a
@@ -86,12 +118,12 @@ sparse part of the parameters will be updated, which can reduce a lot of GPU mem
 we can scale the training set to 29 million identities, the largest to date. Partial FC also supports multi-machine distributed
 training and mixed precision training.

-![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png)
+

 More details see
 [speed_benchmark.md](docs/speed_benchmark.md) in docs.

-### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)
+> 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)

 `-` means training failed because of GPU memory limitations.

@@ -104,7 +136,7 @@ More details see
 | 16000000 | **-** | **-** | 2679     |
 | 29000000 | **-** | **-** | **1855** |

-### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
+> 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)

 | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
 |:--------------------------------|:--------------|:---------------|:---------------|
@@ -126,11 +158,18 @@ More details see
   pages={4690--4699},
   year={2019}
 }
+@inproceedings{an2022pfc,
+  title={Killing Two Birds with One Stone: Efficient and Robust Training of Face Recognition CNNs by Partial FC},
+  author={An, Xiang and Deng, Jiankang and Guo, Jia and Feng, Ziyong and Zhu, Xuhan and Yang, Jing and Liu, Tongliang},
+  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
+  year={2022}
+}
 @inproceedings{an2020partical_fc,
   title={Partial FC: Training 10 Million Identities on a Single Machine},
   author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and
           Zhang, Debing and Fu, Ying},
-  booktitle={Arxiv 2010.05222},
+  booktitle={Proceedings of the International Conference on Computer Vision Workshops},
+  pages={1445-1449},
   year={2020}
 }
 ```
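The negative class-center sampling that `PFC-0.2` / `PFC-0.3` refer to in the README above can be illustrated with a short sketch. This is a toy, single-process illustration of the sampling step only, not the repository's distributed implementation; the helper `sample_centers` and all sizes are invented for the example:

```python
import torch

def sample_centers(weight, labels, sample_rate=0.2):
    """Toy Partial FC sampling: keep all positive class centers plus random negatives."""
    num_classes = weight.size(0)
    positive = labels.unique()
    num_sample = max(int(sample_rate * num_classes), positive.numel())
    # Random score per class; positives are forced above any random score,
    # so top-k always keeps them and random negatives fill the remaining slots.
    scores = torch.rand(num_classes, device=weight.device)
    scores[positive] = 2.0
    index = torch.topk(scores, k=num_sample)[1].sort()[0]
    # Remap each label to its position inside the sampled, sorted index.
    new_labels = torch.searchsorted(index, labels)
    return index, new_labels

num_classes, dim, batch = 100_000, 512, 128
weight = torch.randn(num_classes, dim, requires_grad=True)  # full class-center matrix
features = torch.randn(batch, dim)
labels = torch.randint(0, num_classes, (batch,))

index, new_labels = sample_centers(weight, labels, sample_rate=0.2)
logits = features @ weight[index].t()  # (128, 20000) instead of (128, 100000)
loss = torch.nn.functional.cross_entropy(logits, new_labels)
loss.backward()  # only the sampled rows of `weight` receive nonzero gradient
```

Because only a fraction of the class centers enters the matmul and the softmax, the FC layer's compute and activation memory shrink roughly in proportion to the sample rate.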

recognition/arcface_torch/backbones/__init__.py

Lines changed: 61 additions & 1 deletion

@@ -17,9 +17,69 @@ def get_model(name, **kwargs):
     elif name == "r2060":
         from .iresnet2060 import iresnet2060
         return iresnet2060(False, **kwargs)
+
     elif name == "mbf":
         fp16 = kwargs.get("fp16", False)
         num_features = kwargs.get("num_features", 512)
         return get_mbf(fp16=fp16, num_features=num_features)
+
+    elif name == "mbf_large":
+        from .mobilefacenet import get_mbf_large
+        fp16 = kwargs.get("fp16", False)
+        num_features = kwargs.get("num_features", 512)
+        return get_mbf_large(fp16=fp16, num_features=num_features)
+
+    elif name == "vit_t":
+        num_features = kwargs.get("num_features", 512)
+        from .vit import VisionTransformer
+        return VisionTransformer(
+            img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12,
+            num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1)
+
+    elif name == "vit_t_dp005_mask0":  # For WebFace42M
+        num_features = kwargs.get("num_features", 512)
+        from .vit import VisionTransformer
+        return VisionTransformer(
+            img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12,
+            num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0)
+
+    elif name == "vit_s":
+        num_features = kwargs.get("num_features", 512)
+        from .vit import VisionTransformer
+        return VisionTransformer(
+            img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12,
+            num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1)
+
+    elif name == "vit_s_dp005_mask_0":  # For WebFace42M
+        num_features = kwargs.get("num_features", 512)
+        from .vit import VisionTransformer
+        return VisionTransformer(
+            img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12,
+            num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0)
+
+    elif name == "vit_b":
+        # this is a feature
+        num_features = kwargs.get("num_features", 512)
+        from .vit import VisionTransformer
+        return VisionTransformer(
+            img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24,
+            num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1, using_checkpoint=True)
+
+    elif name == "vit_b_dp005_mask_005":  # For WebFace42M
+        # this is a feature
+        num_features = kwargs.get("num_features", 512)
+        from .vit import VisionTransformer
+        return VisionTransformer(
+            img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24,
+            num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True)
+
+    elif name == "vit_l_dp005_mask_005":  # For WebFace42M
+        # this is a feature
+        num_features = kwargs.get("num_features", 512)
+        from .vit import VisionTransformer
+        return VisionTransformer(
+            img_size=112, patch_size=9, num_classes=num_features, embed_dim=768, depth=24,
+            num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True)
+
     else:
-        raise ValueError()
+        raise ValueError()
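For reference, a hypothetical usage sketch of the factory above; it assumes the package is importable as `backbones` and that the `.vit` module resolves as in this diff:

```python
import torch
from backbones import get_model

# ViT-Tiny backbone with 512-d embeddings (model name taken from the diff above).
model = get_model("vit_t", num_features=512)
model.eval()

with torch.no_grad():
    faces = torch.randn(4, 3, 112, 112)  # batch of 112x112 aligned face crops
    embeddings = model(faces)            # expected shape: (4, 512)
```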

recognition/arcface_torch/backbones/iresnet.py

Lines changed: 11 additions & 3 deletions

@@ -1,8 +1,9 @@
 import torch
 from torch import nn
+from torch.utils.checkpoint import checkpoint

 __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
-
+using_ckpt = False

 def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
     """3x3 convolution with padding"""
@@ -43,7 +44,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None,
         self.downsample = downsample
         self.stride = stride

-    def forward(self, x):
+    def forward_impl(self, x):
         identity = x
         out = self.bn1(x)
         out = self.conv1(out)
@@ -54,7 +55,13 @@ def forward(self, x):
         if self.downsample is not None:
             identity = self.downsample(x)
         out += identity
-        return out
+        return out
+
+    def forward(self, x):
+        if self.training and using_ckpt:
+            return checkpoint(self.forward_impl, x)
+        else:
+            return self.forward_impl(x)


 class IResNet(nn.Module):
@@ -63,6 +70,7 @@ def __init__(self,
                  block, layers, dropout=0, num_features=512, zero_init_residual=False,
                  groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
         super(IResNet, self).__init__()
+        self.extra_gflops = 0.0
         self.fp16 = fp16
         self.inplanes = 64
         self.dilation = 1
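The `using_ckpt` switch added here routes each residual block through `torch.utils.checkpoint`, which frees intermediate activations in the forward pass and recomputes them during backward, trading extra compute for lower activation memory. A minimal self-contained sketch of the same pattern (the `Block` module and its sizes are invented for illustration):

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward_impl(self, x):
        return self.body(x) + x  # residual add, as in IBasicBlock

    def forward(self, x):
        # Checkpointing only helps (and only works) when gradients are needed,
        # hence the same `self.training` guard the diff uses.
        if self.training:
            return checkpoint(self.forward_impl, x)
        return self.forward_impl(x)

x = torch.randn(8, 64, requires_grad=True)
loss = Block()(x).sum()
loss.backward()  # activations inside forward_impl are recomputed here
```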
