diff --git a/deep_sort/deep/resnet.py b/deep_sort/deep/resnet.py index 6912b13..207fc36 100644 --- a/deep_sort/deep/resnet.py +++ b/deep_sort/deep/resnet.py @@ -85,7 +85,7 @@ def __init__(self, block, blocks_num, reid=False, num_classes=1000, groups=1, wi self.groups = groups self.width_per_group = width_per_group - self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2, + self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=1, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(self.in_channel) self.relu = nn.ReLU(inplace=True) @@ -93,10 +93,10 @@ def __init__(self, block, blocks_num, reid=False, num_classes=1000, groups=1, wi self.layer1 = self._make_layers(block, 64, blocks_num[0]) self.layer2 = self._make_layers(block, 128, blocks_num[1], stride=2) self.layer3 = self._make_layers(block, 256, blocks_num[2], stride=2) - # self.layer4 = self._make_layers(block, 512, blocks_num[3], stride=1) + self.layer4 = self._make_layers(block, 512, blocks_num[3], stride=2) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.fc = nn.Linear(256 * block.expansion, num_classes) + self.fc = nn.Linear(512 * block.expansion, num_classes) for m in self.modules(): if isinstance(m, nn.Conv2d): @@ -131,7 +131,7 @@ def forward(self, x): x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) - # x = self.layer4(x) + x = self.layer4(x) x = self.avgpool(x) x = torch.flatten(x, 1) @@ -168,6 +168,6 @@ def resnext50_32x4d(num_classes=1000, reid=False): if __name__ == '__main__': - net = resnet18(reid=True) + net = resnet18(reid=False) x = torch.randn(4, 3, 128, 64) y = net(x) diff --git a/deepsort.py b/deepsort.py index a9b6bdd..b4139a8 100644 --- a/deepsort.py +++ b/deepsort.py @@ -148,7 +148,7 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--VIDEO_PATH", type=str, default='demo.avi') parser.add_argument("--config_mmdetection", type=str, default="./configs/mmdet.yaml") - parser.add_argument("--config_detection", type=str, default="./configs/mask_rcnn.yaml") + parser.add_argument("--config_detection", type=str, default="./configs/yolov5m.yaml") parser.add_argument("--config_deepsort", type=str, default="./configs/deep_sort.yaml") parser.add_argument("--config_fastreid", type=str, default="./configs/fastreid.yaml") parser.add_argument("--fastreid", action="store_true")