
Loss error when training on my own data, asking Dr. Chen for help #43

Open
wfs123456 opened this issue Dec 24, 2022 · 2 comments

Comments

@wfs123456

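Hello, I trained your segv3 on a batch of my own data with 19 classes (including the background class), one class fewer than in Kitti.yaml. After modifying the dataloader, the input and output dimensions all line up, but the loss function raises an error. I would appreciate your help.
Shapes: in_vol([16,5,64,2048]), output([16,20,64,2048]), proj_labels([16,64,2048]), self.loss_w([19]). Is there anything else I need to modify?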

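```python
proj_labels = proj_labels.squeeze(1).cuda(non_blocking=True).long()  # int64 class-index targets
[output, z2, z3, z4, z5] = model(in_vol, proj_mask)  # main output plus auxiliary outputs
# criterion is an NLLLoss, so it takes log-probabilities; clamp avoids log(0)
loss = (criterion(torch.log(output.clamp(min=1e-8)), proj_labels)
        + criterion(torch.log(z5.clamp(min=1e-8)), proj_labels_5)
        # + ... (further terms are cut off in the original post)
        )
```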

```
File "/home/penglei/anaconda3/envs/segv3/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 213, in forward
return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)
File "/home/penglei/anaconda3/envs/segv3/lib/python3.8/site-packages/torch/nn/functional.py", line 2266, in nll_loss
ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: weight tensor should be defined either for all or no classes at /pytorch/aten/src/THCUNN/generic/SpatialClassNLLCriterion.cu:27
```
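In the shapes above, `output` has 20 channels while `self.loss_w` has 19 entries. Assuming the trainer follows the lidar-bonnetal (RangeNet++) convention that SqueezeSegV3's code base appears to derive from (an assumption, not confirmed in this thread), `self.loss_w` gets one inverse-frequency weight per training class, built from the `learning_map`, `learning_map_inv`, `learning_ignore` and `content` sections of the data yaml. The sketch below mimics that scheme with a hypothetical helper `build_loss_weights` and made-up config values. It illustrates that every per-class section must agree on one class count, and that the network's output channels must equal that same count.

```python
def build_loss_weights(learning_map, learning_map_inv, learning_ignore,
                       content, epsilon_w=0.001):
    """Hypothetical helper: inverse-frequency weights, one per training class."""
    n_classes = len(learning_map_inv)              # class count behind the weight vector
    if len(learning_ignore) != n_classes:
        raise ValueError(f"learning_ignore has {len(learning_ignore)} entries, "
                         f"expected {n_classes}")
    counts = [0.0] * n_classes
    for raw_id, freq in content.items():
        train_id = learning_map[raw_id]            # raw label id -> training id
        if not 0 <= train_id < n_classes:
            raise ValueError(f"raw label {raw_id} maps to {train_id}, "
                             f"outside 0..{n_classes - 1}")
        counts[train_id] += freq
    weights = [1.0 / (c + epsilon_w) for c in counts]
    for train_id, ignore in learning_ignore.items():
        if ignore:
            weights[train_id] = 0.0                # ignored classes contribute no loss
    return weights


# Hypothetical 19-class config: raw ids 0..18 map 1:1 to training ids, class 0 ignored.
n = 19
learning_map = {i: i for i in range(n)}
learning_map_inv = {i: i for i in range(n)}
learning_ignore = {i: (i == 0) for i in range(n)}
content = {i: 1.0 / n for i in range(n)}

loss_w = build_loss_weights(learning_map, learning_map_inv, learning_ignore, content)
print(len(loss_w))  # 19, so the segmentation head must also output 19 channels
```

The `epsilon_w` default and all config values here are illustrative, not taken from SqueezeSegV3's yaml files.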

@chenfengxu714
Owner

Hello, it looks like the dimension of your per-class weights does not match the number of classes.
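A minimal CPU reproduction, assuming a recent PyTorch: `nn.NLLLoss` requires its `weight` to have exactly C entries, where C is dim 1 of its input. With 20 output channels and a 19-entry weight, as in the shapes reported in the first post, it raises a RuntimeError of the same kind (the exact wording differs between versions). The helper `check_head_matches_weight` is hypothetical and only fails earlier with a clearer message; tensor sizes are small stand-ins for the real ones.

```python
import torch
import torch.nn as nn


def check_head_matches_weight(output: torch.Tensor, weight: torch.Tensor) -> None:
    """Hypothetical sanity check: output is [B, C, H, W], weight must have C entries."""
    if output.shape[1] != weight.numel():
        raise ValueError(f"model outputs {output.shape[1]} classes, "
                         f"but the loss weight has {weight.numel()} entries")


def weighted_nll(probs: torch.Tensor, target: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Same pattern as the post: probabilities -> log -> NLLLoss with class weights.
    criterion = nn.NLLLoss(weight=weight)
    return criterion(torch.log(probs.clamp(min=1e-8)), target)


torch.manual_seed(0)
B, H, W = 2, 4, 8                                # small stand-in for [16, C, 64, 2048]
loss_w = torch.ones(19)                          # 19 per-class weights
target = torch.randint(0, 19, (B, H, W))         # labels 0..18

probs_20 = torch.softmax(torch.randn(B, 20, H, W), dim=1)  # head still has 20 channels
try:
    weighted_nll(probs_20, target, loss_w)
except RuntimeError as err:
    print("RuntimeError:", err)

probs_19 = torch.softmax(torch.randn(B, 19, H, W), dim=1)  # head with 19 channels
check_head_matches_weight(probs_19, loss_w)
print(weighted_nll(probs_19, target, loss_w).item())
```

In the thread's code the same constraint applies to the auxiliary outputs such as `z5`, since each one is passed to the same weighted criterion.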

@wfs123456
Author

> ...it looks like the dimension of your per-class weights

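Thanks for the reply. I'm new to point clouds and have two questions:
(1) Yes, that is the cause. My data has only 19 classes and I have changed the kitti yaml file to 19 classes, but the loss computation still fails and I don't know what else needs to change. If I add an extra class to my data to reach KITTI's 20 classes, it runs normally. Is there anything else I need to modify?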
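(2) Another question: my data (461196.6, 4407146.5, 17.5) differs from KITTI's. When projecting the point cloud to an image, do the sensor parameters in SSGV321.yaml need to be changed? While debugging, I found that proj_x (545, 545, 545, 545, ...) and proj_y (7, 7, 7, 7, 7, 7, 7, 7, 7, ...) produced by the do_range_projection function in laserscan.py each take only a single value, so the resulting proj_xyz is wrong. What causes this?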
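As far as I know, do_range_projection in laserscan.py is the SemanticKITTI-API spherical projection: each point's yaw and pitch, measured from the coordinate origin, select its column and row. If the points are stored in a map frame whose origin is millions of metres away (the quoted coordinate looks like UTM-style easting/northing, which is an assumption), every point in a scene subtends almost the same yaw and pitch, so all of them land in one pixel. The NumPy sketch below re-implements that projection as a hypothetical `range_projection` with KITTI-style defaults (fov_up = 3°, fov_down = -25°, 64 × 2048, also assumptions) and shows the collapse for a made-up 40 m patch around the quoted coordinate; for that point the formula gives column 545, consistent with the proj_x reported above. Subtracting the sensor position first restores a spread of columns and rows. `sensor_origin` is only a placeholder: for real scans it has to come from the scan's pose, not from the point centroid.

```python
import numpy as np


def range_projection(points, H=64, W=2048, fov_up_deg=3.0, fov_down_deg=-25.0):
    """Spherical (range-image) projection of (N, 3) points given in the sensor frame.

    Returns integer column (proj_x) and row (proj_y) indices per point.
    """
    fov_up = np.radians(fov_up_deg)
    fov_down = np.radians(fov_down_deg)
    fov = abs(fov_down) + abs(fov_up)

    x, y, z = points[:, 0], points[:, 1], points[:, 2]
    depth = np.maximum(np.linalg.norm(points, axis=1), 1e-8)  # avoid division by zero
    yaw = -np.arctan2(y, x)          # horizontal angle -> column
    pitch = np.arcsin(z / depth)     # vertical angle -> row

    proj_x = 0.5 * (yaw / np.pi + 1.0) * W
    proj_y = (1.0 - (pitch + abs(fov_down)) / fov) * H
    proj_x = np.clip(np.floor(proj_x), 0, W - 1).astype(np.int32)
    proj_y = np.clip(np.floor(proj_y), 0, H - 1).astype(np.int32)
    return proj_x, proj_y


rng = np.random.default_rng(0)
# Hypothetical 40 m x 40 m patch around the coordinate quoted in the post (map frame).
center = np.array([461196.6, 4407146.5, 17.5])
pts_map = center + rng.uniform([-20.0, -20.0, -2.5], [20.0, 20.0, 2.5], size=(1000, 3))

px, py = range_projection(pts_map)
print(np.unique(px), np.unique(py))   # one distinct column and one distinct row

# Placeholder: for real scans, sensor_origin must come from the scan's pose.
sensor_origin = pts_map.mean(axis=0)
px2, py2 = range_projection(pts_map - sensor_origin)
print(len(np.unique(px2)), len(np.unique(py2)))   # many distinct columns and rows
```

The exact row value depends on the fov settings, which may explain why this sketch and the post differ for proj_y. If the data was not captured by a rotating LiDAR at a single pose (for example, an aggregated map tile), a range-image projection may not be meaningful at all; in any case the fov_up, fov_down, width and height in the sensor section should presumably match the actual sensor.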
