昇思25天学习打卡营第16天 | DCGAN生成漫画头像

这两天把minspore配置到我的电脑上了,然后运行就没什么问题了✨😊

今天学这个DCGAN生成漫画头像,我超级感兴趣的嘞🦄🥰

GAN基础原理

这部分原理介绍参考GAN图像生成。

DCGAN原理

DCGAN(深度卷积对抗生成网络,Deep Convolutional Generative Adversarial Networks)是GAN的直接扩展。不同之处在于,DCGAN会分别在判别器和生成器中使用卷积和转置卷积层。

它最早由Radford等人在论文Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks中进行描述。判别器由分层的卷积层、BatchNorm层和LeakyReLU激活层组成。输入是3x64x64的图像,输出是该图像为真图像的概率。生成器则是由转置卷积层、BatchNorm层和ReLU激活层组成。输入是标准正态分布中提取出的隐向量𝑧𝑧,输出是3x64x64的RGB图像。

本教程将使用动漫头像数据集来训练一个生成式对抗网络,接着使用该网络生成动漫头像图片。

数据准备与处理

from download import download

url = "https://download.mindspore.cn/dataset/Faces/faces.zip"

path = download(url, "./faces", kind="zip", replace=True)

数据处理

batch_size = 128          # 批量大小
image_size = 64           # 训练图像空间大小
nc = 3                    # 图像彩色通道数
nz = 100                  # 隐向量的长度
ngf = 64                  # 特征图在生成器中的大小
ndf = 64                  # 特征图在判别器中的大小
num_epochs = 3           # 训练周期数
lr = 0.0002               # 学习率
beta1 = 0.5               # Adam优化器的beta1超参数

import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as vision

def create_dataset_imagenet(dataset_path):
    """数据加载"""
    dataset = ds.ImageFolderDataset(dataset_path,
                                    num_parallel_workers=4,
                                    shuffle=True,
                                    decode=True)

    # 数据增强操作
    transforms = [
        vision.Resize(image_size),
        vision.CenterCrop(image_size),
        vision.HWC2CHW(),
        lambda x: ((x / 255).astype("float32"))
    ]

    # 数据映射操作
    dataset = dataset.project('image')
    dataset = dataset.map(transforms, 'image')

    # 批量操作
    dataset = dataset.batch(batch_size)
    return dataset

dataset = create_dataset_imagenet('./faces')

import matplotlib.pyplot as plt

def plot_data(data):
    # 可视化部分训练数据
    plt.figure(figsize=(10, 3), dpi=140)
    for i, image in enumerate(data[0][:30], 1):
        plt.subplot(3, 10, i)
        plt.axis("off")
        plt.imshow(image.transpose(1, 2, 0))
    plt.show()

sample_data = next(dataset.create_tuple_iterator(output_numpy=True))
plot_data(sample_data)

构造网络

当处理完数据后,就可以来进行网络的搭建了。按照DCGAN论文中的描述,所有模型权重均应从mean为0,sigma为0.02的正态分布中随机初始化。

生成器

生成器G的功能是将隐向量z映射到数据空间。由于数据是图像,这一过程也会创建与真实图像大小相同的 RGB 图像。在实践场景中,该功能是通过一系列Conv2dTranspose转置卷积层来完成的,每个层都与BatchNorm2d层和ReLu激活层配对,输出数据会经过tanh函数,使其返回[-1,1]的数据范围内。

DCGAN论文生成图像如下所示:

dcgangenerator

我们通过输入部分中设置的nzngfnc来影响代码中的生成器结构。nz是隐向量z的长度,ngf与通过生成器传播的特征图的大小有关,nc是输出图像中的通道数。

以下是生成器的代码实现:

import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normal

weight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)

class Generator(nn.Cell):
    """DCGAN网络生成器"""

    def __init__(self):
        super(Generator, self).__init__()
        self.generator = nn.SequentialCell(
            nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),
            nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),
            nn.ReLU(),
            nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),
            nn.ReLU(),
            nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),
            nn.ReLU(),
            nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf, gamma_init=gamma_init),
            nn.ReLU(),
            nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.Tanh()
            )

    def construct(self, x):
        return self.generator(x)

generator = Generator()

判别器

如前所述,判别器D是一个二分类网络模型,输出判定该图像为真实图的概率。通过一系列的Conv2dBatchNorm2dLeakyReLU层对其进行处理,最后通过Sigmoid激活函数得到最终概率。

DCGAN论文提到,使用卷积而不是通过池化来进行下采样是一个好方法,因为它可以让网络学习自己的池化特征。

判别器的代码实现如下:

class Discriminator(nn.Cell):
    """DCGAN网络判别器"""

    def __init__(self):
        super(Discriminator, self).__init__()
        self.discriminator = nn.SequentialCell(
            nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),
            )
        self.adv_layer = nn.Sigmoid()

    def construct(self, x):
        out = self.discriminator(x)
        out = out.reshape(out.shape[0], -1)
        return self.adv_layer(out)

discriminator = Discriminator()

模型训练

损失函数

当定义了DG后,接下来将使用MindSpore中定义的二进制交叉熵损失函数BCELoss。

优化器

这里设置了两个单独的优化器,一个用于D,另一个用于G。这两个都是lr = 0.0002beta1 = 0.5的Adam优化器。

训练模型

训练分为两个主要部分:训练判别器和训练生成器。

  • 训练判别器

    训练判别器的目的是最大程度地提高判别图像真伪的概率。按照Goodfellow的方法,是希望通过提高其随机梯度来更新判别器,所以我们要最大化𝑙𝑜𝑔𝐷(𝑥)+𝑙𝑜𝑔(1−𝐷(𝐺(𝑧))𝑙𝑜𝑔𝐷(𝑥)+𝑙𝑜𝑔(1−𝐷(𝐺(𝑧))的值。

  • 训练生成器

    如DCGAN论文所述,我们希望通过最小化𝑙𝑜𝑔(1−𝐷(𝐺(𝑧)))𝑙𝑜𝑔(1−𝐷(𝐺(𝑧)))来训练生成器,以产生更好的虚假图像。

在这两个部分中,分别获取训练过程中的损失,并在每个周期结束时进行统计,将fixed_noise批量推送到生成器中,以直观地跟踪G的训练进度。

下面实现模型训练正向逻辑:

# 定义损失函数
adversarial_loss = nn.BCELoss(reduction='mean')

# 为生成器和判别器设置优化器
optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G.update_parameters_name('optim_g.')
optimizer_D.update_parameters_name('optim_d.')

def generator_forward(real_imgs, valid):
    # 将噪声采样为发生器的输入
    z = ops.standard_normal((real_imgs.shape[0], nz, 1, 1))

    # 生成一批图像
    gen_imgs = generator(z)

    # 损失衡量发生器绕过判别器的能力
    g_loss = adversarial_loss(discriminator(gen_imgs), valid)

    return g_loss, gen_imgs

def discriminator_forward(real_imgs, gen_imgs, valid, fake):
    # 衡量鉴别器从生成的样本中对真实样本进行分类的能力
    real_loss = adversarial_loss(discriminator(real_imgs), valid)
    fake_loss = adversarial_loss(discriminator(gen_imgs), fake)
    d_loss = (real_loss + fake_loss) / 2
    return d_loss

grad_generator_fn = ms.value_and_grad(generator_forward, None,
                                      optimizer_G.parameters,
                                      has_aux=True)
grad_discriminator_fn = ms.value_and_grad(discriminator_forward, None,
                                          optimizer_D.parameters)

@ms.jit
def train_step(imgs):
    valid = ops.ones((imgs.shape[0], 1), mindspore.float32)
    fake = ops.zeros((imgs.shape[0], 1), mindspore.float32)

    (g_loss, gen_imgs), g_grads = grad_generator_fn(imgs, valid)
    optimizer_G(g_grads)
    d_loss, d_grads = grad_discriminator_fn(imgs, gen_imgs, valid, fake)
    optimizer_D(d_grads)

    return g_loss, d_loss, gen_imgs

import mindspore

G_losses = []
D_losses = []
image_list = []

total = dataset.get_dataset_size()
for epoch in range(num_epochs):
    generator.set_train()
    discriminator.set_train()
    # 为每轮训练读入数据
    for i, (imgs, ) in enumerate(dataset.create_tuple_iterator()):
        g_loss, d_loss, gen_imgs = train_step(imgs)
        if i % 100 == 0 or i == total - 1:
            # 输出训练记录
            print('[%2d/%d][%3d/%d]   Loss_D:%7.4f  Loss_G:%7.4f' % (
                epoch + 1, num_epochs, i + 1, total, d_loss.asnumpy(), g_loss.asnumpy()))
        D_losses.append(d_loss.asnumpy())
        G_losses.append(g_loss.asnumpy())

    # 每个epoch结束后,使用生成器生成一组图片
    generator.set_train(False)
    fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))
    img = generator(fixed_noise)
    image_list.append(img.transpose(0, 2, 3, 1).asnumpy())

    # 保存网络模型参数为ckpt文件
    mindspore.save_checkpoint(generator, "./generator.ckpt")
    mindspore.save_checkpoint(discriminator, "./discriminator.ckpt")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/773704.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

一本超简单能用Python实现办公自动化的神书!让我轻松摆脱办公烦恼!

《超简单:用Python让Excel飞起来》 这本书旨在通过Python与Excel的“强强联手”,为办公人员提供一套高效的数据处理方案。书中还介绍了如何在Excel中调用Python代码,进一步拓宽了办公自动化的应用范围。 全书共9章。第1~3章主要讲解Python编…

【数据结构】06.栈队列

一、栈 1.1栈的概念及结构 栈:一种特殊的线性表,其只允许在固定的一端进行插入和删除元素操作。进行数据插入和删除操作的一端称为栈顶,另一端称为栈底。栈中的数据元素遵守后进先出LIFO(Last In First Out)的原则。 压栈&#…

JAVA 对象存储OSS工具类(腾讯云)

对象存储OSS工具类 import com.qcloud.cos.COSClient; import com.qcloud.cos.ClientConfig; import com.qcloud.cos.auth.BasicCOSCredentials; import com.qcloud.cos.auth.COSCredentials; import com.qcloud.cos.model.ObjectMetadata; import com.qcloud.cos.model.PutObj…

洗地机品牌哪个最好用?硬核推荐五大实力爆款洗地机

在这个忙碌的时代,家就是我们放松的港湾,但要保持它的清洁与舒适常常很不容易。每天拖着疲惫的身体回家,还要面对地板上那些难缠的灰尘、污渍,真是非常让人头疼。不过,洗地机的出现就像是给家务清洁装上了智能引擎&…

idea中maven全局配置

配置了就不需要每次创建项目都来设置maven仓库了。 1.先把项目全关了 2. 进入全局设置 3.设置maven的仓库就可以了

一篇文章带你完全理解C语言数组

文章目录 1.一维数组的创建和初始化数组的创建1.2数组的初始化1.3 一维数组的使用1.4一维数组在内存中的存储 2.二维数组的创建和初始化2.1二维数组的创建2.2 二维数组的初始化2.3 二维数组的使用2.4 二维数组在内存中的存储 3.数组越界4.数组作为函数参数4.1 冒泡排序函数的错…

从零开始开发美颜SDK:打造属于平台的主播美颜工具

本篇文章,小编将从零开始,介绍如何打造一款属于平台的主播美颜工具。 一、需求分析 首先,明确开发美颜SDK的需求是至关重要的。当前市场上,美颜工具的功能主要包括: 1.实时美颜:磨皮、美白、瘦脸等基础功…

Static关键字的用法详解

Static关键字的用法详解 1、Static修饰内部类2、Static修饰方法3、Static修饰变量4、Static修饰代码块5、总结 💖The Begin💖点点关注,收藏不迷路💖 在Java编程语言中,static是一个关键字,它可以用于多种上…

项目机会:4万平:智能仓,AGV,穿梭车,AMR,WMS,提升机,机器人……

导语 大家好,我是社长,老K。专注分享智能制造和智能仓储物流等内容。 如下为近期国内智能仓储物流相关项目的公开信息线索,这些项目具体信息会发布到知识星球,请感兴趣的球友先人一步到知识星球【智能仓储物流技术研习社】自行下载…

时钟系统框图(时钟树)解析

时钟系统框图(时钟树)解析 文章目录 时钟系统框图(时钟树)解析1、时钟树2、 4个时钟源:$HSI、HSE、LSI、LSE$3、PLL锁相环倍频输出4、系统时钟的来源5、Enable CSS(时钟监视系统)6、几个重要的时…

微信开发者工具使用

1.下载微信开发者工具 https://developers.weixin.qq.com/miniprogram/dev/devtools/stable.html 2.下载小程序项目代码 3.用微信开发者工具导入项目代码 4.npm安装依赖 5.构建 6.修改测试环境 7.清除缓存 观察切换test后,登录时是否test字样提醒,若…

使用Python+OpenCV实现姿态估计--20240705

姿态估计使用Opencv+Mediapipe来时实现 什么是Mediapipe? Mediapipe是主要用于构建多模式音频,视频或任何时间序列数据的框架。借助MediaPipe框架,可以构建令人印象深刻的ML管道,例如TensorFlow,TFLite等推理模型以及媒体处理功能。 安装命令: pip install mediapipe如果…

大模型提示词工程和落地思考

本文是一篇内部的个人分享(已无敏感信息) ,目的是增加产品、开发同学对 LLM 的理解,以降低沟通中的阻力,更好推进落地。 以下经脱敏后的原文: 大模型并不神奇 很多人听到’大模型’这个词可能会觉得很神秘&#xff…

centos7固定ip

1.查看虚拟网络配置 2.修改网卡配置文件 [jiajinglocalhost ~]$ su - Password: Last login: Thu Jul 4 19:06:16 PDT 2024 on pts/0 [rootlocalhost ~]# vim /etc/sysconfig/network-scripts/ifcfg-ens33 TYPE"Ethernet" PROXY_METHOD"none" BROWSER_ON…

倘若你的的B端系统如此漂亮,还担心拿不出手吗,尤其是面对客户

如果你的B端系统设计如此漂亮,那么通常来说,你不太需要担心在客户那里拿不出手。一个漂亮和易用的设计可以提升用户体验,增加客户对系统的满意度。 然而,还是有一些因素需要考虑,以确保你的B端系统在客户那里能够得到良…

综合项目实战--jenkins流水线

一、流水线定义 软件生产环节,如:需求调研、需求设计、概要设计、详细设计、编码、单元测试、集成测试、系统测试、用户验收测试、交付等,这些流程就组成一条完整的流水线。脚本式流水线(pipeline)的出现代表企业人员可以更自由的通过代码来实现不同的工作流程。 二、pi…

【夏季跨境】7-9月热门类目选品指南(附自用选品工具免费插件)

一、夏季跨境热门类目及品类 1 、运动与户外 随着夏季的到来,户外活动变得更加频繁,运动和户外装备的需求显著上升 自行车配件 自行车部件:踏板、自行车配件存储、自行车灯、自行车手机支架 整车及相关配件:成人/儿童自行车、…

JavaEE——计算机工作原理

冯诺依曼体系(VonNeumannArchitecture) 现代计算机,大多遵守冯诺依曼体系结构 CPU中央处理器:进行算术运算与逻辑判断 存储器:分为外存和内存,用于存储数据(使用二进制存储) 输入…

【python】OpenCV—Nighttime Low Illumination Image Enhancement

文章目录 1 背景介绍2 代码实现3 原理分析4 效果展示5 附录np.ndindexnumpy.ravelnumpy.argsortcv2.detailEnhancecv2.edgePreservingFilter 1 背景介绍 学习参考来自:OpenCV基础(24)改善夜间图像的照明 源码: 链接&#xff1a…

Golang 数组+切片+映射

数组 什么是数组 数组是一种数据类型,属于值类型数组可以存放多个同一类型数据 数组定义 var 数组名 [数组大小]数据类型 var a [5]int数组初始化的4种方式 var numArray01 [3]int [3]int{1,2,3} var numArray02 [3]int{1,2,3} var numArray03 [...]int{1,2…