高级搜索

留言板

尊敬的读者、作者、审稿人, 关于本刊的投稿、审稿、编辑和出版的任何问题, 您可以本页添加留言。我们将尽快给您答复。谢谢您的支持!

姓名
邮箱
手机号码
标题
留言内容
验证码

面向不平衡图像数据的对抗自编码器过采样算法

职为梅 常智 卢俊华 耿正乾

职为梅, 常智, 卢俊华, 耿正乾. 面向不平衡图像数据的对抗自编码器过采样算法[J]. 电子与信息学报. doi: 10.11999/JEIT240330
引用本文: 职为梅, 常智, 卢俊华, 耿正乾. 面向不平衡图像数据的对抗自编码器过采样算法[J]. 电子与信息学报. doi: 10.11999/JEIT240330
ZHI Weimei, CHANG Zhi, LU Junhua, GENG Zhengqian. Adversarial Autoencoders Oversampling Algorithm for Imbalanced Image Data[J]. Journal of Electronics & Information Technology. doi: 10.11999/JEIT240330
Citation: ZHI Weimei, CHANG Zhi, LU Junhua, GENG Zhengqian. Adversarial Autoencoders Oversampling Algorithm for Imbalanced Image Data[J]. Journal of Electronics & Information Technology. doi: 10.11999/JEIT240330

面向不平衡图像数据的对抗自编码器过采样算法

doi: 10.11999/JEIT240330
基金项目: 国家重点研发计划 (2023YFC2206404)
详细信息
    作者简介:

    职为梅:女,副教授,研究方向为数据挖掘、机器学习

    常智:男,硕士生,研究方向为数据挖掘、生成对抗网络

    卢俊华:女,硕士生,研究方向为数据挖掘、深度学习

    耿正乾:男,硕士生,研究方向为数据挖掘、机器学习

    通讯作者:

    常智 cszchang@163.com

  • 中图分类号: TN911.73; TP181

Adversarial Autoencoders Oversampling Algorithm for Imbalanced Image Data

Funds: The National Key Research and Development Project (2023YFC2206404)
  • 摘要: 许多适用于低维数据的传统不平衡学习算法在图像数据上的效果并不理想。基于生成对抗网络(GAN)的过采样算法虽然可以生成高质量图像,但在类不平衡情况下容易产生模式崩溃问题。基于自编码器(AE)的过采样算法容易训练,但生成的图像质量较低。为进一步提高过采样算法在不平衡图像中生成样本的质量和训练的稳定性,该文基于生成对抗网络和自编码器的思想提出一种融合自编码器和生成对抗网络的过采样算法(BAEGAN)。首先在自编码器中引入一个条件嵌入层,使用预训练的条件自编码器初始化GAN以稳定模型训练;然后改进判别器的输出结构,引入一种融合焦点损失和梯度惩罚的损失函数以减轻类不平衡的影响;最后从潜在向量的分布映射中使用合成少数类过采样技术(SMOTE)来生成高质量的图像。在4个图像数据集上的实验结果表明该算法在生成图像质量和过采样后的分类性能上优于具有辅助分类器的条件生成对抗网络(ACGAN)、平衡生成对抗网络 (BAGAN)等过采样算法,能有效解决图像数据中的类不平衡问题。
  • 图  1  生成对抗网络

    图  2  对抗自编码器

    图  3  BAEGAN算法架构

    图  4  不同过采样算法在MNIST不平衡数据集上生成的图像

    图  5  不同过采样算法在FMNIST不平衡数据集上生成的图像

    图  6  不同过采样算法在SVHN不平衡数据集上生成的图像

    图  7  不同过采样算法在CIFAR-10不平衡数据集上生成的图像

    图  8  不同过采样算法在MNIST不平衡数据集上过采样后的样本分布图

    图  9  CIFAR-10上消融实验的图像生成效果

    1  BAEGAN算法描述

     输入:从不平衡的训练集$X$中划分一批数据$B = \{ {b_1},{b_2},\cdots,$
     ${b_{|X|/m}}\} $;批量大小$m$;类别数量$n$;预先设定的模型超参数;
     先验分布$p({\boldsymbol{z}})$;
     输出:平衡后的数据集${X_{\text{b}}}$
     (1) (a) 初始化所有网络参数(编码器${\theta _E}$、解码器${\theta _{{\text{De}}}}$、生成器
     ${\theta _G}$、判别器${\theta _D}$),预训练条件自编码器:
     (2) WHILE预训练轮数 DO
     (3)  FOR 从$B$中选取一组数据$({\boldsymbol{x}},{\boldsymbol{c}})$ DO
     (4)   将数据${\boldsymbol{x}}$送入编码器$E$,获得${\boldsymbol{z}}$;
     (5)   将${\boldsymbol{z}}$和${\boldsymbol{c}}$输入嵌入层,获得${{\boldsymbol{z}}_{\text{c}}}$;
     (6)   将${{\boldsymbol{z}}_{\text{c}}}$送入解码器${\text{De}}$,获得重构图像$\hat {\boldsymbol{x}}$;
     (7)   由式(2)计算损失,更新${\theta _E}$和${\theta _{{\text{De}}}}$。
     (8)  END
     (9) END
     (10) (b) 预训练的条件自编码器初始化${\theta _G}$和${\theta _{{\text{De}}}}$,训练模型:
     (11) WHILE 模型未收敛或未达到训练轮数 DO
     (12) FOR 从$B$中选取一组数据$({\boldsymbol{x}},{\boldsymbol{c}})$ DO
     (13)   将数据${\boldsymbol{x}}$送入编码器$E$中,获得${\boldsymbol{z}}$;
     (14)   将${\boldsymbol{z}}$和${\boldsymbol{c}}$输入嵌入层中,获得${{\boldsymbol{z}}_{\text{c}}}$;
     (15)   将${{\boldsymbol{z}}_{\text{c}}}$送入解码器${\text{De}}$,获得重构图像$\hat {\boldsymbol{x}}$;
     (16)   根据式(2)计算损失,更新${\theta _E}$和${\theta _{{\text{De}}}}$。
     (17)   将${\boldsymbol{x}}$送入$G$,获得${{\boldsymbol{z}}_{{\text{fake}}}}$ ,从$p({\boldsymbol{z}})$中获得${{\boldsymbol{z}}_{{\text{real}}}}$;
     (18)   将${{\boldsymbol{z}}_{{\text{fake}}}}$和${{\boldsymbol{z}}_{{\text{real}}}}$输入判别器$D$,由式(4)计算判别器损失,
         更新${\theta _D}$;
     (19)   ${{\boldsymbol{z}}_{{\text{fake}}}}$送入$D$,由式(5)计算生成器损失,更新${\theta _G}$;
     (20) END
     (21) END
     (22) (c) 生成样本,平衡数据集:
     (23) WHILE 选取少数类${\text{c}}$中的所有样本$({{\boldsymbol{x}}_{\text{c}}},{\boldsymbol{c}})$ ,直至所有少数
       类选取完毕DO
     (24) 将数据${{\boldsymbol{x}}_{\mathrm{c}}}$送入$E$中,获得潜在向量${\boldsymbol{z}}$;
     (25) 将${\boldsymbol{z}}$和${\boldsymbol{c}}$送入SMOTE中,获得平衡后的潜在向量${{\boldsymbol{z}}^{\text{b}}}$和类
        标签${{\boldsymbol{c}}^{\text{b}}}$;
     (26) 将${{\boldsymbol{z}}^{\text{b}}}$和${{\boldsymbol{c}}^{\text{b}}}$输入嵌入层中,获得嵌入条件的向量$ {\boldsymbol{z}}_{\text{c}}^{\text{b}} $;
     (27) 将$ {\boldsymbol{z}}_{\text{c}}^{\text{b}} $送入解码器${\text{De}}$,获得平衡后属于类${\text{c}}$的样本集;
     (28) END
     (29) 获得平衡数据集${X_{\text{b}}}$。
    下载: 导出CSV

    表  1  网络结构设置

    层数卷积核数量卷积核大小步长填充
    判别器或编码器164421
    2128421
    3256421
    4512421
    生成器或解码器1512410
    2256421
    3128421
    464421
    5图像通道数421
    下载: 导出CSV

    表  2  不同过采样算法在各不平衡数据集上的FID分数

    算法MNISTFMNISTSVHNCIFAR-10
    CGAN[12]280.482290.239340.472363.291
    ACGAN[13]140.239188.182190.384210.356
    BAGAN[16]119.293100.231183.753199.088
    DeepSMOTE[11]100.31596.449161.483170.104
    BAEGAN82.63394.546175..332142.333
    下载: 导出CSV

    表  3  不同过采样算法在各不平衡数据集上的分类性能

    MNISTFMNISTSVHNCIFAR-10
    ACSAF1GMACSAF1GMACSAF1GMACSAF1GM
    CGAN[12]0.87920.85440.90570.65280.63620.72630.72590.69080.79360.33190.30880.5302
    ACGAN[13]0.92120.91230.94920.81440.78950.86060.77200.74030.82390.40060.34100.5918
    BAGAN[16]0.93060.92770.95980.81480.80930.89310.80230.77750.86770.43380.40250.6373
    DeepSMOTE[11]0.96090.96030.97800.83630.83270.90610.80940.78730.87390.45380.43350.6530
    BAEGAN0.98070.97150.98420.87990.81560.91330.83570.77690.89420.54430.52540.7301
    下载: 导出CSV

    表  4  在CIFAR-10上消融实验分类结果

    算法ACSAF1GM
    BAEGAN-AE0.42260.39460.5802
    BAEGAN-L0.35840.31420.4098
    BAEGAN-S0.27320.22330.3083
    BAEGAN0.54430.52540.7301
    下载: 导出CSV

    表  5  算法运行时间分析(s)

    算法MNISTFMNISTSVHNCIFAR-10
    CGAN[12]1774232326123284
    ACGAN[13]1476167519352249
    BAGAN[16]72638430982311038
    DeepSMOTE[11]71672014151941
    BAEGAN1827193420763374
    下载: 导出CSV
  • [1] FAN Xi, GUO Xin, CHEN Qi, et al. Data augmentation of credit default swap transactions based on a sequence GAN[J]. Information Processing & Management, 2022, 59(3): 102889. doi: 10.1016/j.ipm.2022.102889.
    [2] 刘侠, 吕志伟, 李博, 等. 基于多尺度残差双域注意力网络的乳腺动态对比度增强磁共振成像肿瘤分割方法[J]. 电子与信息学报, 2023, 45(5): 1774–1785. doi: 10.11999/JEIT220362.

    LIU Xia, LÜ Zhiwei, LI Bo, et al. Segmentation algorithm of breast tumor in dynamic contrast-enhanced magnetic resonance imaging based on network with multi-scale residuals and dual-domain attention[J]. Journal of Electronics & Information Technology, 2023, 45(5): 1774–1785. doi: 10.11999/JEIT220362.
    [3] 尹梓诺, 马海龙, 胡涛. 基于联合注意力机制和一维卷积神经网络-双向长短期记忆网络模型的流量异常检测方法[J]. 电子与信息学报, 2023, 45(10): 3719–3728. doi: 10.11999/JEIT220959.

    YIN Zinuo, MA Hailong, and HU Tao. A traffic anomaly detection method based on the joint model of attention mechanism and one-dimensional convolutional neural network-bidirectional long short term memory[J]. Journal of Electronics & Information Technology, 2023, 45(10): 3719–3728. doi: 10.11999/JEIT220959.
    [4] FERNÁNDEZ A, GARCÍA S, GALAR M, et al. Learning From Imbalanced Data Sets[M]. Cham: Springer, 2018: 327–349. doi: 10.1007/978-3-319-98074-4.
    [5] HUANG Zhan’ao, SANG Yongsheng, SUN Yanan, et al. A neural network learning algorithm for highly imbalanced data classification[J]. Information Sciences, 2022, 612: 496–513. doi: 10.1016/j.ins.2022.08.074.
    [6] FU Saiji, YU Xiaotong, and TIAN Yingjie. Cost sensitive ν-support vector machine with LINEX loss[J]. Information Processing & Management, 2022, 59(2): 102809. doi: 10.1016/j.ipm.2021.102809.
    [7] LIN T Y, GOYAL P, GIRSHICK R, et al. Focal loss for dense object detection[C]. Proceedings of the IEEE International Conference on Computer Vision, Venice, Italy, 2017: 2999–3007. doi: 10.1109/ICCV.2017.324.
    [8] LI Buyu, LIU Yu, and WANG Xiaogang. Gradient harmonized single-stage detector[C]. Proceedings of the 33rd AAAI Conference on Artificial Intelligence, Honolulu, USA, 2019: 8577–8584. doi: 10.1609/aaai.v33i01.33018577.
    [9] MICHELUCCI U. An introduction to autoencoders[J]. arXiv preprint arXiv: 2201.03898, 2022. doi: 10.48550/arXiv.2201.03898.
    [10] GOODFELLOW I J, POUGET-ABADIE J, MIRZA M, et al. Generative adversarial nets[C]. Proceedings of the 27th International Conference on Neural Information Processing Systems, Montreal, Canada, 2014: 2672–2680.
    [11] DABLAIN D, KRAWCZYK B, and CHAWLA N V. DeepSMOTE: Fusing deep learning and SMOTE for imbalanced data[J]. IEEE Transactions on Neural Networks and Learning Systems, 2023, 34(9): 6390–6404. doi: 10.1109/TNNLS.2021.3136503.
    [12] MIRZA M and OSINDERO S. Conditional generative adversarial nets[J]. arXiv preprint arXiv: 1411.1784, 2014. doi: 10.48550/arXiv.1411.1784.
    [13] ODENA A, OLAH C, and SHLENS J. Conditional image synthesis with auxiliary classifier GANs[C]. Proceedings of the 34th International Conference on Machine Learning, Sydney, Australia, 2017: 2642–2651.
    [14] GULRAJANI I, AHMED F, ARJOVSKY M, et al. Improved training of wasserstein GANs[C]. Proceedings of the 31st International Conference on Neural Information Processing Systems, Long Beach, USA, 2017: 5769–5779.
    [15] CHAWLA N V, BOWYER K W, HALL L O, et al. SMOTE: Synthetic minority over-sampling technique[J]. Journal of Artificial Intelligence Research, 2002, 16: 321–357. doi: 10.1613/jair.953.
    [16] MARIANI G, SCHEIDEGGER F, ISTRATE R, et al. BAGAN: Data augmentation with balancing GAN[J]. arXiv preprint arXiv: 1803.09655, 2018. doi: 10.48550/arXiv.1803.09655.
    [17] HUANG Gaofeng and JAFARI A H. Enhanced balancing GAN: Minority-class image generation[J]. Neural Computing and Applications, 2023, 35(7): 5145–5154. doi: 10.1007/s00521-021-06163-8.
    [18] BAO Jianmin, CHEN Dong, WEN Fang, et al. CVAE-GAN: Fine-grained image generation through asymmetric training[C]. Proceedings of the IEEE International Conference on Computer Vision (ICCV), Venice, Italy, 2017: 2764–2773. doi: 10.1109/ICCV.2017.299.
    [19] MAKHZANI A, SHLENS J, JAITLY N, et al. Adversarial autoencoders[J]. arXiv preprint arXiv: 1511.05644, 2015. doi: 10.48550/arXiv.1511.05644.
    [20] CUI Yin, JIA Menglin, LIN T Y, et al. Class-balanced loss based on effective number of samples[C]. Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), Long Beach, USA, 2019: 9260–9269. doi: 10.1109/CVPR.2019.00949.
    [21] DOWSON D C and LANDAU B V. The Fréchet distance between multivariate normal distributions[J]. Journal of Multivariate Analysis, 1982, 12(3): 450–455. doi: 10.1016/0047-259X(82)90077-X.
    [22] HUANG Chen, LI Yining, LOY C C, et al. Learning deep representation for imbalanced classification[C]. Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, Las Vegas, USA, 2016: 5375–5384. doi: 10.1109/CVPR.2016.580.
    [23] KUBAT M and MATWIN S. Addressing the curse of imbalanced training sets: One-sided selection[C]. 14th International Conference on Machine Learning, Nashville, USA, 1997: 179–186.
    [24] HRIPCSAK G and ROTHSCHILD A S. Agreement, the F-measure, and reliability in information retrieval[J]. Journal of the American Medical Informatics Association, 2005, 12(3): 296–298. doi: 10.1197/jamia.M1733.
    [25] SOKOLOVA M and LAPALME G. A systematic analysis of performance measures for classification tasks[J]. Information Processing & Management, 2009, 45(4): 427–437. doi: 10.1016/j.ipm.2009.03.002.
    [26] RADFORD A, METZ L, and CHINTALA S. Unsupervised representation learning with deep convolutional generative adversarial networks[C]. 4th International Conference on Learning Representations, San Juan, Puerto Rico, 2016.
  • 加载中
图(9) / 表(6)
计量
  • 文章访问数:  26
  • HTML全文浏览量:  15
  • PDF下载量:  3
  • 被引次数: 0
出版历程
  • 收稿日期:  2024-04-24
  • 修回日期:  2024-09-19
  • 网络出版日期:  2024-09-24

目录

    /

    返回文章
    返回