引言
本文主要介绍如何在tensorflow上仅使用200个带标签的mnist图像,实现在一万张测试图片上99%的测试精度,原理在于使用GAN做半监督学习。前文主要介绍一些原理部分,后文详细介绍代码及其实现原理。前文介绍比较简单,有基础的同学请略过直接看第二部分,文章末尾给出了代码GitHub链接。对GAN不了解的同学可以查看微信公众号:机器学习算法全栈工程师 的GAN入门文章。1
本博客中的代码最终以GitHub中的代码为准,GitHub链接在文章底部,另外,本文已投稿至微信公众号:机器学习算法全栈工程师,欢迎关注此公众号
1.监督,无监督,半监督学习介绍
在正式介绍实现半监督学习之前,我在这里首先介绍一下监督学习(supervised learning),半监督学习(semi-supervised learning)和无监督学习(unsupervised learning)的区别。
- 监督学习是指在训练集中包含训练数据的标签(label),比如类别标签,位置标签等等。最普遍使用标签学习的是分类任务,对于分类任务,输入给网络训练样本(samples)的一些特征(feature)以及此样本对应的标签(label),通过神经网络拟合的方法,神经网络可以在特征和标签之间找到一个合适的映射关系(mapping),这样当训练完成后,输入给网络没有label的样本,神经网络可以通过这一个映射关系猜出它属于哪一类。典型机器学习的监督学习的例子是KNN和SVM。目前机器视觉领域的急速发展离不开监督学习。
- 而无监督学习的训练事先没有训练标签,直接输入给算法一些数据,算法会努力学习数据的共同点,寻找样本之间的规律性。无监督学习是很典型的学习,人的学习有时候就是基于无监督的,比如我并不懂音乐,但是我听了上百首歌曲后,我可以根据我听的结果将音乐分为摇滚乐(记为0类)、民谣(记为1类)、纯音乐(记为2类)等等,事实上,我并不知道具体是哪一类,所以将它们记为0,1,2三类。典型的无监督学习方法是聚类算法,比如k-means。
- 东方快车电影里面大侦探有过一个台词,人们的话只有对与错,没有中间地带,最后经过一系列事件后他找到了对与错之间的betweeness。在监督学习和无监督学习之间,同样存在着中间地带——半监督学习。半监督学习简单来说就是将无监督学习和监督学习相结合,一部分包含了监督学习一部分包含了无监督学习,比如给一个分类任务,此分类任务的训练集中有精确标签的数据非常少,但是包含了大量的没有标注的数据,如果直接用监督学习的方法去做的话,效果不一定很好,有标注的训练数据太少很容易导致过拟合,而且大量的无标注的数据都没有充分的利用,最常见的例子是在医学图像的分析检测任务中,医学图像本身就不容易获得,要获得精标注的图像就需要有经验的医生去一个一个标注,显然他们并没有那么多的时间。这时候就是半监督学习的用武之地了,半监督学习很适合用在标签数据少,训练数据又比较多的情况。
常见的半监督学习方法主要有:
1.Self training
2.Generative model
3.S3VMs
4.Graph-Based AIgorithems
5.Multiview AIgorithems
接下来我会结合Improved Techniques for Training GANs这篇论文详细介绍如何使用目前最火的生成对抗模型GAN去实现半监督学习,也即是半监督学习的第二种方法,并给出详细的代码解释,对理论不是很熟悉的同学可以直接看代码。另外注明:我只复现了论文半监督学习的部分,之前也有人复现了此部分,但是我感觉他对原文有很大的曲解,他使用了所有的标签去帮助生成,并不在分类上,不太符合半监督学习的本质,而且代码很复杂,感兴趣的可以去GitHub上搜ssgan,希望能帮助你。
2. Improved Techniques for Training GANs
GAN是无监督学习的代表,它可以不断学习模拟数据的分布进而生成和训练数据相似分布的样本,在训练过程不需要标签,GAN在无监督学习领域,生成领域,半监督学习领域以及强化学习领域都有广泛的应用。但是GAN存在很多的训练不稳定等等的问题,作者good fellow在2016年放出了Improved Techniques for Training GANs,对GAN训练不稳定的问题做了一些解释和经验上的解决方案,并给出了和半监督学习结合的方法。
从平衡点角度解释GAN的不稳定性来说,GAN的纳什均衡点是一个鞍点,并不是一个局部最小值点,基于梯度的方法主要是寻找高维空间中的极小值点,因此使用梯度训练的方法很难使GAN收敛到平衡点。为此,为了进一部分缓解这个问题,goodfellow联合提出了一些改进方案,
主要有:
- Feature matching,
- Minibatch discrimination
- weight Historical averaging (相当于一个正则化的方式)
- One-sided label smoothing
- Virtual batch normalization
后来发现Feature matching在半监督学习上表现良好,mini-batch discrimination表现很差。
3. semi-supervised GAN
对于一个普通的分类器来说,假设对MNIST分类,一共有10类数据,分别是0-9,分类器模型以数据x作为输入,输出一个K=10维的向量,经过softmax后计算出分类概率最大的那个类别。在监督学习领域,往往是通过最小化类别标签和预测分布 的交叉熵来实现最好的结果。
但是将GAN用在半监督学习领域的时候需要做一些改变,生成器不做改变,仍然负责从输入噪声数据中生成图像,判别器D不在是一个简单的真假分类(二分类)器,假设输入数据有K类,D就是K+1的分类器,多出的那一类是判别输入是否是生成器G生成的图像。网络的流程图见下图:
网络结构确定了之后就是损失函数的设计部分,借助GAN我们就可以从无标签数据中学习,只要知道输入数据是真实数据,那就可以通过最大化\(logP_{model}(y\in{1,2,…,K}|x)\)来实现,上述式子可解释为不管输入的是哪一类真的图片(不是生成器G生成的假图片),只要最大化输出它是真图像的概率就可以了,不需要具体分出是哪一类。由于GAN的生成器的参与,训练数据中有一半都是生成的假数据。
下面给出判别器D的损失函数设计,D损失函数包括两个部分,一个是监督学习损失,一个是半监督学习损失,具体公式如下:
其中:
对于无监督学习来说,只需要输出真假就可以了,不需要确定是哪一类,因此我们令
其中\( P_{model} \)表示判别是假图像的概率,那么D(x)就代表了输出是真图像的概率,那么无监督学习的损失函数就可以表示为
这不就是GAN的损失函数嘛!好了,到这里得出结论,在半监督学习中,判别器的分类要多分一类,多出的这一类表示的是生成器生成的假图像这一类,另外判别器的损失函数不仅包括了监督损失函数而且还有无监督的损失函数,在训练过程中同时最小化这两者。损失函数介绍完毕,接下来介绍代码实现部分。
4.代码实现及解读
注:完整代码的GitHub连接在文章底部。这里只截取关键部分做介绍
在代码中,我使用feature matching,one side label smoothing方式,并没有使用论文中介绍的Historical averaging,而是只对判别器D使用了简单的l2正则化,防止过拟合,另外论文中介绍的Minibatch discrimination, Virtual batch normalization等等都没有使用,主要是这两者在半监督学习中表现不是很好,但是如果想获得好的生成结果还是很有用的。
首先介绍网络结构部分,因为是在mnist数据集比较简单,所以随便搭了一个判别器和生成器,具体如下:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22def discriminator(self, name, inputs, reuse):
l = tf.shape(inputs)[0]
inputs = tf.reshape(inputs, (l,self.img_size,self.img_size,self.dim))
with tf.variable_scope(name,reuse=reuse):
out = []
output = conv2d('d_con1',inputs,5, 64, stride=2, padding='SAME') #14*14
output1 = lrelu(self.bn('d_bn1',output))
out.append(output1)
# output1 = tf.contrib.keras.layers.GaussianNoise
output = conv2d('d_con2', output1, 3, 64*2, stride=2, padding='SAME')#7*7
output2 = lrelu(self.bn('d_bn2', output))
out.append(output2)
output = conv2d('d_con3', output2, 3, 64*4, stride=1, padding='VALID')#5*5
output3 = lrelu(self.bn('d_bn3', output))
out.append(output3)
output = conv2d('d_con4', output3, 3, 64*4, stride=2, padding='VALID')#2*2
output4 = lrelu(self.bn('d_bn4', output))
out.append(output4)
output = tf.reshape(output4, [l, 2*2*64*4])# 2*2*64*4
output = fc('d_fc', output, self.num_class)
# output = tf.nn.softmax(output)
return output, out
其中conv2d()是卷积操作,参数依次是,层的名字,输入tensor,卷积核大小,输出通道数,步长,padding。判别器中每一层都加了归一化层,这里使用最简单的归一化,函数如下所示,另外每一层的激活函数使用leakyrelu。判别器D最终返回两个值,第一个是计算的logits,另外一个是一个列表,列表的每一个元素代表判别器每一层的输出,为接下来实现feature matching做准备。
生成器结构如下所示:其最后一层激活函数使用tanh1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19def generator(self,name, noise, reuse):
with tf.variable_scope(name,reuse=reuse):
l = self.batch_size
output = fc('g_dc', noise, 2*2*64)
output = tf.reshape(output, [-1, 2, 2, 64])
output = tf.nn.relu(self.bn('g_bn1',output))
output = deconv2d('g_dcon1',output,5,outshape=[l, 4, 4, 64*4])
output = tf.nn.relu(self.bn('g_bn2',output))
output = deconv2d('g_dcon2', output, 5, outshape=[l, 8, 8, 64 * 2])
output = tf.nn.relu(self.bn('g_bn3', output))
output = deconv2d('g_dcon3', output, 5, outshape=[l, 16, 16,64 * 1])
output = tf.nn.relu(self.bn('g_bn4', output))
output = deconv2d('g_dcon4', output, 5, outshape=[l, 32, 32, self.dim])
output = tf.image.resize_images(output, (28, 28))
# output = tf.nn.relu(self.bn('g_bn4', output))
return tf.nn.tanh(output)
网络结构是根据DCGAN的结构改的,所以网络简要介绍到这里。
接下来介绍网络初始化方面:
首先在train.py里建立一个Train的类,并做一些初始化1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215#coding:utf-8
from glob import glob
from PIL import Image
import matplotlib.pyplot as plt
import scipy.misc as scm
from vlib.layers import *
import tensorflow as tf
import numpy as np
from vlib.load_data import *
import os
import vlib.plot as plot
import vlib.my_extract as dataload
import vlib.save_images as save_img
import time
from tensorflow.examples.tutorials.mnist import input_data #as mnist_data
mnist = input_data.read_data_sets('data/', one_hot=True)
# temp = 0.89
class Train(object):
def __init__(self, sess, args):
#sess=tf.Session()
self.sess = sess
self.img_size = 28 # the size of image
self.trainable = True
self.batch_size = 50 # must be even number
self.lr = 0.0002
self.mm = 0.5 # momentum term for adam
self.z_dim = 128 # the dimension of noise z
self.EPOCH = 50 # the number of max epoch
self.LAMBDA = 0.1 # parameter of WGAN-GP
self.model = args.model # 'DCGAN' or 'WGAN'
self.dim = 1 # RGB is different with gray pic
self.num_class = 11
self.load_model = args.load_model
self.build_model() # initializer
def build_model(self):
# build placeholders
self.x=tf.placeholder(tf.float32,shape=[self.batch_size,self.img_size*self.img_size*self.dim],name='real_img')
self.z = tf.placeholder(tf.float32, shape=[self.batch_size, self.z_dim], name='noise')
self.label = tf.placeholder(tf.float32, shape=[self.batch_size, self.num_class - 1], name='label')
self.flag = tf.placeholder(tf.float32, shape=[], name='flag')
self.flag2 = tf.placeholder(tf.float32, shape=[], name='flag2')
# define the network
self.G_img = self.generator('gen', self.z, reuse=False)
d_logits_r, layer_out_r = self.discriminator('dis', self.x, reuse=False)
d_logits_f, layer_out_f = self.discriminator('dis', self.G_img, reuse=True)
d_regular = tf.add_n(tf.get_collection('regularizer', 'dis'), 'loss') # D regular loss
# caculate the unsupervised loss
un_label_r = tf.concat([tf.ones_like(self.label), tf.zeros(shape=(self.batch_size, 1))], axis=1)
un_label_f = tf.concat([tf.zeros_like(self.label), tf.ones(shape=(self.batch_size, 1))], axis=1)
logits_r, logits_f = tf.nn.softmax(d_logits_r), tf.nn.softmax(d_logits_f)
d_loss_r = -tf.log(tf.reduce_sum(logits_r[:, :-1])/tf.reduce_sum(logits_r[:,:]))
d_loss_f = -tf.log(tf.reduce_sum(logits_f[:, -1])/tf.reduce_sum(logits_f[:,:]))
# d_loss_r = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=un_label_r*0.9, logits=d_logits_r))
# d_loss_f = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=un_label_f*0.9, logits=d_logits_f))
# feature match
f_match = tf.constant(0., dtype=tf.float32)
for i in range(4):
f_match += tf.reduce_mean(tf.multiply(layer_out_f[i]-layer_out_r[i], layer_out_f[i]-layer_out_r[i]))
# caculate the supervised loss
s_label = tf.concat([self.label, tf.zeros(shape=(self.batch_size,1))], axis=1)
s_l_r = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=s_label*0.9, logits=d_logits_r))
s_l_f = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=un_label_f*0.9, logits=d_logits_f)) # same as d_loss_f
self.d_l_1, self.d_l_2 = d_loss_r + d_loss_f, s_l_r
self.d_loss = d_loss_r + d_loss_f + s_l_r*self.flag*10 + d_regular
self.g_loss = d_loss_f + 0.01*f_match
all_vars = tf.global_variables()
g_vars = [v for v in all_vars if 'gen' in v.name]
d_vars = [v for v in all_vars if 'dis' in v.name]
for v in all_vars:
print v
if self.model == 'DCGAN':
self.opt_d = tf.train.AdamOptimizer(self.lr, beta1=self.mm).minimize(self.d_loss, var_list=d_vars)
self.opt_g = tf.train.AdamOptimizer(self.lr, beta1=self.mm).minimize(self.g_loss, var_list=g_vars)
elif self.model == 'WGAN_GP':
self.opt_d = tf.train.AdamOptimizer(1e-5, beta1=0.5, beta2=0.9).minimize(self.d_loss, var_list=d_vars)
self.opt_g = tf.train.AdamOptimizer(1e-5, beta1=0.5, beta2=0.9).minimize(self.g_loss, var_list=g_vars)
else:
print ('model can only be "DCGAN","WGAN_GP" !')
return
# test
test_logits, _ = self.discriminator('dis', self.x, reuse=True)
test_logits = tf.nn.softmax(test_logits)
temp = tf.reshape(test_logits[:, -1],shape=[self.batch_size, 1])
for i in range(10):
temp = tf.concat([temp, tf.reshape(test_logits[:, -1],shape=[self.batch_size, 1])], axis=1)
test_logits -= temp
self.prediction = tf.nn.in_top_k(test_logits, tf.argmax(s_label, axis=1), 1)
self.saver = tf.train.Saver()
if not self.load_model:
init = tf.global_variables_initializer()
self.sess.run(init)
elif self.load_model:
self.saver.restore(self.sess, os.getcwd()+'/model_saved/model.ckpt')
print 'model load done'
self.sess.graph.finalize()
def train(self):
if not os.path.exists('model_saved'):
os.mkdir('model_saved')
if not os.path.exists('gen_picture'):
os.mkdir('gen_picture')
noise = np.random.normal(-1, 1, [self.batch_size, 128])
temp = 0.80
print 'training'
for epoch in range(self.EPOCH):
# iters = int(156191//self.batch_size)
iters = 50000//self.batch_size
flag2 = 1 # if epoch>10 else 0
for idx in range(iters):
start_t = time.time()
flag = 1 if idx < 4 else 0 # set we use 2*batch_size=200 train data labeled.
batchx, batchl = mnist.train.next_batch(self.batch_size)
# batchx, batchl = self.sess.run([batchx, batchl])
g_opt = [self.opt_g, self.g_loss]
d_opt = [self.opt_d, self.d_loss, self.d_l_1, self.d_l_2]
feed = {self.x:batchx, self.z:noise, self.label:batchl, self.flag:flag, self.flag2:flag2}
# update the Discrimater k times
_, loss_d, d1,d2 = self.sess.run(d_opt, feed_dict=feed)
# update the Generator one time
_, loss_g = self.sess.run(g_opt, feed_dict=feed)
print ("[%3f][epoch:%2d/%2d][iter:%4d/%4d],loss_d:%5f,loss_g:%4f, d1:%4f, d2:%4f"%
(time.time()-start_t, epoch, self.EPOCH,idx,iters, loss_d, loss_g,d1,d2)), 'flag:',flag
plot.plot('d_loss', loss_d)
plot.plot('g_loss', loss_g)
if ((idx+1) % 100) == 0: # flush plot picture per 1000 iters
plot.flush()
plot.tick()
if (idx+1)%500==0:
print ('images saving............')
img = self.sess.run(self.G_img, feed_dict=feed)
save_img.save_images(img, os.getcwd()+'/gen_picture/'+'sample{}_{}.jpg'\
.format(epoch, (idx+1)/500))
print 'images save done'
test_acc = self.test()
plot.plot('test acc', test_acc)
plot.flush()
plot.tick()
print 'test acc:{}'.format(test_acc), 'temp:%3f'%(temp)
if test_acc > temp:
print ('model saving..............')
path = os.getcwd() + '/model_saved'
save_path = os.path.join(path, "model.ckpt")
self.saver.save(self.sess, save_path=save_path)
print ('model saved...............')
temp = test_acc
# output = conv2d('Z_cona{}'.format(i), output, 3, 64, stride=1, padding='SAME')
def generator(self,name, noise, reuse):
with tf.variable_scope(name,reuse=reuse):
l = self.batch_size
output = fc('g_dc', noise, 2*2*64)
output = tf.reshape(output, [-1, 2, 2, 64])
output = tf.nn.relu(self.bn('g_bn1',output))
output = deconv2d('g_dcon1',output,5,outshape=[l, 4, 4, 64*4])
output = tf.nn.relu(self.bn('g_bn2',output))
output = deconv2d('g_dcon2', output, 5, outshape=[l, 8, 8, 64 * 2])
output = tf.nn.relu(self.bn('g_bn3', output))
output = deconv2d('g_dcon3', output, 5, outshape=[l, 16, 16,64 * 1])
output = tf.nn.relu(self.bn('g_bn4', output))
output = deconv2d('g_dcon4', output, 5, outshape=[l, 32, 32, self.dim])
output = tf.image.resize_images(output, (28, 28))
# output = tf.nn.relu(self.bn('g_bn4', output))
return tf.nn.tanh(output)
def discriminator(self, name, inputs, reuse):
l = tf.shape(inputs)[0]
inputs = tf.reshape(inputs, (l,self.img_size,self.img_size,self.dim))
with tf.variable_scope(name,reuse=reuse):
out = []
output = conv2d('d_con1',inputs,5, 64, stride=2, padding='SAME') #14*14
output1 = lrelu(self.bn('d_bn1',output))
out.append(output1)
# output1 = tf.contrib.keras.layers.GaussianNoise
output = conv2d('d_con2', output1, 3, 64*2, stride=2, padding='SAME')#7*7
output2 = lrelu(self.bn('d_bn2', output))
out.append(output2)
output = conv2d('d_con3', output2, 3, 64*4, stride=1, padding='VALID')#5*5
output3 = lrelu(self.bn('d_bn3', output))
out.append(output3)
output = conv2d('d_con4', output3, 3, 64*4, stride=2, padding='VALID')#2*2
output4 = lrelu(self.bn('d_bn4', output))
out.append(output4)
output = tf.reshape(output4, [l, 2*2*64*4])# 2*2*64*4
output = fc('d_fc', output, self.num_class)
# output = tf.nn.softmax(output)
return output, out
def bn(self, name, input):
val = tf.contrib.layers.batch_norm(input, decay=0.9,
updates_collections=None,
epsilon=1e-5,
scale=True,
is_training=True,
scope=name)
return val
# def get_loss(self, logits, layer_out):
def test(self):
count = 0.
print 'testing................'
for i in range(10000//self.batch_size):
testx, textl = mnist.test.next_batch(self.batch_size)
prediction = self.sess.run(self.prediction, feed_dict={self.x:testx, self.label:textl})
count += np.sum(prediction)
return count/10000.
args是传进来的参数,主要包括三个,
- 一个是args.model,选择DCGAN模式还是WGAN-GP模式,二者的不同主要在于损失函数不同和优化器的学习率不同,其他都一样。
- 第二个参数是args.trainable,训练还是测试,训练时为True,测试是False。
- Loadmodel表示是否选择加载训练好的权重。
Build_model函数里面主要包括了网络训练前的准备工作,主要包括损失函数的设计和优化器的设计。下文将详细做出介绍,尤其是损失函数部分。
首先,建立了五个placeholder,flag表示两个标志位,只有0-1两种情况,注意到我num_class是11,也就是做11分类,但是lable的placeholder中shape是(batchsize,10)。为了方便,我将生成器的生成结果和真实数据X级联在一起作为判别器的输入,输出再把他它们结果split分开。
d_regular 表示正则化,这里我将判别器中所有的weights做了l2正则。
监督学习的损失函数使用常见的交叉熵损失函数,对生成器生成的图像的label的one_hot型为:
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
将原始的label扩展到(batchsize,11)后再和生成器生成的假数据的label在第一维度concat到一起得到batchl,另外乘以0.9,做单边标签平滑(one side smoothing),由此计算得到监督学习的损失函数值s_l,。
生成器G的损失函数
生成器G的损失函数包括两部分,一个是来自GAN训练的部分,另外一个是feature matching , 论文中提到的feature matching意思是特征匹配,主要思想是希望生成器生成的假数据输入到判别器,经过判别器每一层计算的结果和将真实数据X输入到判别器,判别器每一层的结果尽可能的相似,公式如下:
其中\(f(x)\)是D的每一层的输出。Feature matching 是指导G进行训练,所以我将他放在了G的损失函数里。
分类器D的损失函数:
相比较G的损失函数,D的损失函数就比较麻烦了
接下来介绍无监督学习的损失函数实现:
在前面介绍的无监督学习的损失函数中,有一部分和GAN的损失函数很相似,所以在代码中我们使用了
无监督学习的时候没有标签的指导,此时判别器或者称为分类器D无法正确对输入进行分类,此时只要求D能够区分真假就可以了,由此我们得到了无监督学习的损失un_s,直观上也很好理解,假设输入给判别器D真图像,它结果经过softmax后输出类似下面表格的形式,其中前十个黄色区域表示对0-9的分类概率,最后一个灰色的表示对假图像的分类概率,由于无监督学习中判别器D并不知道具体是哪一类数据,所以干脆D的损失函数最小化输出假图像的概率就可以了,当输入为生成器生成的假图像时,只要最小化D输出为真图像的概率,由此我们得到了un_s.。但是此时有一个问题,即是有监督学习的时候不就没有用了吗,因为这时候应该使用s_l.为了解决这个问题,我使用了一个标志位flag作为控制他们之间的使用,具体代码:
有标签的时候flag是1,表示使用s_l,无监督的时候flag是0,表示使用无监督损失函数。此时已经完成了判别器D损失函数的一部分设计,剩下的一部分和GAN中的D的损失一样,在代码中我给出了两种损失函数,一个是原始GAN的交叉熵损失函数,和DCGAN使用的一样,另外一个是improved wgan论文中使用的损失函数,但是在做了对比之后,我强烈建议使用DCGAN来做,improved wgan的损失函数虽然在生成结果的优化上有很大帮助,但是并不适合半监督学习中。
训练
接下来就是训练部分:
此时可能有一个疑问,我们是如何实现只使用200带标签的数据训练的,答案就在flag这个标志位里,在训练部分代码中,当迭代次数小于200的时候,flag=1, 此时表示使用s_l作为损失函数的一部分,当flag=0的时候,un_s起作用而s_l并没有起作用,这时,即使我们feed了正确的标签数据,但是s_l不起作用,就相当于没有使用标签。
flag的作用本来是使用他控制feature matching是否工作的,因为这部分损失相当的大,后来发现影响不大,暂时就放在这里了。
测试
1 | def test(self): |
测试精度结果变化图
本文实验代码
使用GAN实现半监督学习代码https://github.com/LDOUBLEV/semi-supervised-GAN
如果感觉有用的话,欢迎star, fork
备注
详细代码请以github中为准,另关于结果不理想的问题,可能和之前做的迁移学习有关,下面是最近跑出来的结果,最好的精度是0.95,这个问题有时间会慢慢解决。另:链接中的模型精度是很高的,可以直接调用