这是使用Pytorch实现的CapsNet,参考论文《Dynamic Routing Between Capsules》
为了便于使用,已经将 Capsule 层封装成了一个模块,可以单独使用。
初始化时只需要定义这个模块的输入特征个数,输出特征个数,输入特征长度,输出特征长度,路由次数,如
cap = Capsule(input_features=3266, output_features=10, input_feature_length=8, output_feature_length=16, routing_iterators=3)
输入张量大小定义为一个 (批量大小,输入特征个数,输入特征长度),即 (batch_size, input_features, input_feature_length)
输出张量大小定义为一个 (批量大小,输出特征个数,输出特征长度),即 (batch_size, output_features, output_feature_length)
output = cap(input)
其中 input 大小为 (batch_size, 3266, 8)
其中 output 大小为 (batch_size, 10, 16)