Optimization in Capsule Network Based on Mutual Information Autoencoder and Variational Routing
Abstract: The capsule network is a network model that differs from the convolutional neural network. This paper aims to improve its generalization and accuracy. First, variational routing is used to alleviate two problems of classic routing: strong dependence on prior information and a tendency to overfit. A Gaussian Mixture Model (GMM) is fitted to the low-level matrix capsules, and a variational method is used to obtain an approximate distribution, which avoids the error of maximum-likelihood point estimation of the parameters; confidence evaluation is then used to improve generalization. Second, because real-world data are mostly unlabeled or difficult to label, a capsule autoencoder with a mutual-information evaluation criterion is constructed to select feature parameters effectively. Specifically, a local encoder is introduced so that only the capsule features most effective for recognizing the original input are retained, which reduces the computational burden of the network and improves classification accuracy. The method is compared with baseline models on the MNIST, FashionMNIST, CIFAR-10, and CIFAR-100 datasets. The experimental results show that it improves on the classic capsule network on these benchmarks, for example raising CIFAR-100 accuracy from 46.98% to 52.33%.

Keywords:
- Capsule network
- Variational routing
- Capsule autoencoder based on mutual information evaluation
Table 1. Pseudocode of the variational routing algorithm

Input: vote matrix $ {\boldsymbol{\nu }}_{n} $, activation $ a $, number of iterations $ T $
(1) Initialization: set $ {\alpha }_{0}=0.001 $, $ {\boldsymbol{m}}_{0}=\boldsymbol{0} $, $ {r}_{nk}=1/K $ (with $ K $ the number of mixture components), $ {\boldsymbol{W}}_{0} $ to the identity matrix, and $ {\beta }_{0},{\nu }_{0} $ to constants.
(2) VBM step:
(3) Update $ {r}_{nk}\leftarrow {r}_{nk}\cdot a $
(4) Update $ {N}_{k},{\tilde{\boldsymbol{\nu}}}_{k},{\boldsymbol{S}}_{k} $ (by Eqs. (12) to (14))
(5) Update $ {\alpha }_{k} $ (by Eq. (16))
(6) Update $ {\boldsymbol{m}}_{k},{\beta }_{k},{\boldsymbol{W}}_{k}^{-1},{\nu }_{k},{\boldsymbol{W}}_{k} $ (by Eqs. (18) to (22))
(7) $ T=T-1 $
(8) VBE step:
(9) Update $ \ln{\rho }_{nk} $ (by Eq. (9))
(10) where $ \ln{\tilde{\pi }}_{k}=\varphi \left({\alpha }_{k}\right)-\varphi \left(\displaystyle\sum\nolimits_{j=1}^{K}{\alpha }_{j}\right) $
(11) $ \ln{\tilde{\varLambda }}_{k}=\displaystyle\sum\nolimits_{i=1}^{D}\varphi \left(\frac{{\nu }_{k}+1-i}{2}\right)+D\ln 2+\ln\left|{\boldsymbol{W}}_{k}\right| $
(12) $ {E}_{{\boldsymbol{\mu }}_{k},{\varLambda }_{k}}\left[{\left({\boldsymbol{\nu }}_{n}-{\boldsymbol{\mu }}_{k}\right)}^{\mathrm{T}}{\varLambda }_{k}\left({\boldsymbol{\nu }}_{n}-{\boldsymbol{\mu }}_{k}\right)\right]=D{\beta }_{k}^{-1}+{\nu }_{k}{\left({\boldsymbol{\nu }}_{n}-{\boldsymbol{m}}_{k}\right)}^{\mathrm{T}}{\boldsymbol{W}}_{k}\left({\boldsymbol{\nu }}_{n}-{\boldsymbol{m}}_{k}\right) $
(13) $ {\boldsymbol{M}}_{k}=\mathrm{squeeze}\left({\boldsymbol{m}}_{k}\right) $ ($ \mathrm{squeeze} $ is a dimension-conversion function)
(14) $ a=\mathrm{squeeze}\left({N}_{k}\right) $
(15) Output: $ {\boldsymbol{M}}_{k},a $

Here bold $ {\boldsymbol{\nu }}_{n} $ denotes a vote and plain $ {\nu }_{k} $ a degrees-of-freedom parameter. First complete the input and initialization in step (1). Then iterate the VBM step, (2) to (7), and the VBE step, (8) to (12), until $ T $ reaches 0. Finally, compute $ {\boldsymbol{M}}_{k} $ and $ a $ in steps (13) and (14) and output them in step (15).

Table 2. Pseudocode of routing based on encoding capsules
Input: $ x $, $ t=3 $; initialize $ b=0 $
Step 1:
  Compute $ c $: $ c\leftarrow \mathrm{Softmax}\left(b\right) $
  Compute $ h $: $ h\leftarrow H\left(x\right) $
  Compute $ g $: $ g\leftarrow G\left(h\right) $
  Compute $ u $: $ u=w\cdot \mathrm{concat}\left(g,h\right) $
  Update $ s $: $ s\leftarrow \sum \left(c\cdot u\right) $
  Update $ v $: $ v\leftarrow \mathrm{Squash}\left(s\right) $
  Update $ t $: $ t\leftarrow t-1 $
Step 2: Update $ b $: $ b\leftarrow b+\left(g+h\right)\cdot v $
Step 3: Output $ v $

Run Step 1. While $ t $ is not 0, run Step 2 to update $ b $ and feed $ b $ back into Step 1 to recompute $ v $ and $ t $. When $ t $ reaches 0, stop iterating and run Step 3.

Table 3. Comparison of classification accuracy (%)
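| Model | MNIST accuracy | FashionMNIST accuracy |
|---|---|---|
| CNN | 98.00 | 90.30 |
| ResNet | 99.27 | 94.90 |
| Inception-V3 | 99.29 | 94.97 |
| CN | 99.30 | 92.50 |
| VBCN | 99.50 | 93.50 |

Table 4. Comparison of generalization (%)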
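| Model | Two_MNIST accuracy | Two_FashionMNIST accuracy |
|---|---|---|
| CNN | 45.30 | 41.40 |
| ResNet | 89.09 | 59.60 |
| Inception-V3 | 77.35 | 68.45 |
| CN | 93.15 | 82.60 |
| VBCN | 95.65 | 86.20 |

Table 5. Comparison of test accuracy on CIFAR-10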
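| Label | Class | Classic CN accuracy | Improved CN accuracy |
|---|---|---|---|
| 0 | Airplane | 0.73 | 0.81 |
| 1 | Automobile | 0.76 | 0.87 |
| 2 | Bird | 0.71 | 0.74 |
| 3 | Cat | 0.45 | 0.54 |
| 4 | Deer | 0.66 | 0.76 |
| 5 | Dog | 0.55 | 0.60 |
| 6 | Frog | 0.58 | 0.64 |
| 7 | Horse | 0.77 | 0.80 |
| 8 | Ship | 0.59 | 0.67 |
| 9 | Truck | 0.71 | 0.77 |
| Mean | - | 0.65 | 0.72 |

Table 6. Comparison of test accuracy on CIFAR-100 (%)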
| Model | CIFAR-100 accuracy |
|---|---|
| Classic CN | 46.98 |
| Improved CN | 52.33 |
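
Steps (3) and (4) of Table 1 can be illustrated with a short sketch. It assumes that Eqs. (12) to (14), which are not reproduced here, take the standard variational Gaussian-mixture form: $ N_k=\sum_n r_{nk} $, a responsibility-weighted mean, and a responsibility-weighted covariance. The function name `vbm_statistics`, the list-based data layout, and the use of one activation value per vote are assumptions made for illustration, not the authors' implementation.

```python
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def vbm_statistics(
    votes: List[Vector],
    resp: List[List[float]],
    activation: List[float],
) -> Tuple[List[float], List[Vector], List[Matrix]]:
    """Return (N_k, weighted means, weighted covariances), one entry per component.

    votes[n]      -- vote vector of low-level capsule n (length D)
    resp[n][k]    -- responsibility r_nk of component k for vote n
    activation[n] -- activation of low-level capsule n (assumed per-vote)
    """
    n_votes, dim, n_comp = len(votes), len(votes[0]), len(resp[0])

    # Step (3): scale every responsibility by the activation of its capsule.
    r = [[resp[n][k] * activation[n] for k in range(n_comp)]
         for n in range(n_votes)]

    counts, means, covs = [], [], []
    for k in range(n_comp):
        # Step (4): effective number of votes assigned to component k.
        n_k = sum(r[n][k] for n in range(n_votes))
        # Responsibility-weighted mean of the votes.
        mean = [sum(r[n][k] * votes[n][d] for n in range(n_votes)) / n_k
                for d in range(dim)]
        # Responsibility-weighted covariance around that mean.
        cov = [[sum(r[n][k] * (votes[n][i] - mean[i]) * (votes[n][j] - mean[j])
                    for n in range(n_votes)) / n_k
                for j in range(dim)]
               for i in range(dim)]
        counts.append(n_k)
        means.append(mean)
        covs.append(cov)
    return counts, means, covs


# Toy data (hypothetical): three 2-D votes, two components.
counts, means, covs = vbm_statistics(
    votes=[[0.0, 0.0], [2.0, 0.0], [10.0, 10.0]],
    resp=[[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]],
    activation=[1.0, 1.0, 0.5],
)
print(counts)  # [2.0, 0.5]
print(means)   # [[1.0, 0.0], [10.0, 10.0]]
```

A component with $ N_k=0 $ would cause a division by zero here, so a practical version would need a small floor on $ N_k $.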
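
The routing loop of Table 2 can be sketched in plain Python. The table does not give tensor shapes or define the encoders, so the sketch makes several assumptions: `H` and `G` are arbitrary callables that return vectors of equal length, `W[i][j]` is a hypothetical weight matrix mapping `concat(g_i, h_i)` to output capsule `j`, the output dimension equals the length of `g_i + h_i`, and Softmax is taken over output capsules as in classic dynamic routing. The squash function follows the usual capsule-network definition, which scales a vector to a norm of $ \|s\|^2/(1+\|s\|^2) $.

```python
import math
from typing import Callable, List

Vector = List[float]


def softmax(b: Vector) -> Vector:
    """Numerically stable softmax over one row of routing logits."""
    m = max(b)
    exps = [math.exp(x - m) for x in b]
    total = sum(exps)
    return [e / total for e in exps]


def squash(s: Vector) -> Vector:
    """Keep the direction of s and map its norm into [0, 1)."""
    sq = sum(x * x for x in s)
    if sq == 0.0:
        return [0.0] * len(s)
    scale = sq / (1.0 + sq) / math.sqrt(sq)
    return [scale * x for x in s]


def route(x: List[Vector], H: Callable, G: Callable, W, t: int = 3) -> List[Vector]:
    """Run t routing iterations; W[i][j] is a matrix (assumed layout)."""
    n_in, n_out = len(x), len(W[0])
    # h, g and u do not depend on b, so they are computed once.
    h = [H(xi) for xi in x]
    g = [G(hi) for hi in h]
    # u[i][j] = W[i][j] applied to concat(g_i, h_i); list "+" concatenates.
    u = [[[sum(w * z for w, z in zip(row, g[i] + h[i])) for row in W[i][j]]
          for j in range(n_out)] for i in range(n_in)]
    # Elementwise sum g_i + h_i, used in the Step 2 logit update.
    gh = [[p + q for p, q in zip(g[i], h[i])] for i in range(n_in)]
    dim = len(u[0][0])
    assert len(gh[0]) == dim, "sketch assumes matching dimensions"
    b = [[0.0] * n_out for _ in range(n_in)]
    while True:
        # Step 1: coupling coefficients, weighted sum, squash, decrement t.
        c = [softmax(row) for row in b]
        s = [[sum(c[i][j] * u[i][j][d] for i in range(n_in)) for d in range(dim)]
             for j in range(n_out)]
        v = [squash(sj) for sj in s]
        t -= 1
        if t <= 0:
            return v  # Step 3: output v
        # Step 2: b <- b + (g + h) . v
        for i in range(n_in):
            for j in range(n_out):
                b[i][j] += sum(p * q for p, q in zip(gh[i], v[j]))


def identity_encoder(xi: Vector) -> Vector:  # hypothetical stand-in for H
    return list(xi)


def half_encoder(hi: Vector) -> Vector:  # hypothetical stand-in for G
    return [0.5 * z for z in hi]


# Two input capsules, two output capsules, 2-D features; each W[i][j] is 2 x 4.
keep_g = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
keep_h = [[0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
W = [[keep_g, keep_h], [keep_g, keep_h]]
v = route([[1.0, 0.0], [0.0, 1.0]], identity_encoder, half_encoder, W)
print(v)
```

Because `h`, `g`, and `u` do not change between iterations, the sketch computes them once before the loop; Table 2 lists them inside Step 1, which gives the same result at a higher cost.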