最近在抽时间学习TensorFlow这个DL库的使用,学的断断续续的,看官网上第一个案例就是训练手写字符识别, 我之前在做Weibo.cn验证码识别的时候,自己搞了一个数据集,当时用的c++库tiny-dnn进行训练的(见:验证码破解技术四部曲之使用卷积神经网络(四)),现在我把它移植到TensorFlow上试试。
完整代码见:weibo.cn/tensorflow-impl
使用的库
- TensorFlow-1.0
- scikit-learn-0.18
- pillow
加载数据集
数据集下载地址:training_set.zip
解压过后如下图:
我把同一类的图片放到了一个文件夹里,文件夹的名字也就是图片的label,打开文件夹后可以看到字符的图片信息。
下面,我们把数据加载到一个pickle文件里面,它需要有train_dataset、train_labels、test_dataset、test_labels四个变量代表训练集和测试集的数据和标签。
此外,还需要有个label_map,用来把训练的标签和实际的标签对应,比如说3对应字母M,4对应字母N。
此部分的代码见:load_models.py。注:很多的代码参考自udacity的deeplearning课程。
首先根据文件夹的来加载所有的数据,index代表训练里的标签,label代表实际的标签,使用PIL读取图片,并转换成numpy数组。
1 | import numpy as np |
接下来,把数据打乱。
1 | def randomize(dataset, labels): |
然后使用scikit-learn的函数,把训练集和测试集分开。
1 | from sklearn.model_selection import train_test_split |
在TensorFlow官网给的例子中,会把label进行One-Hot Encoding
,并把28*28的图片转换成了一维向量(784)。如下图,查看官网例子的模型。
我也把数据转换了一下,把32*32的图片转换成一维向量(1024),并对标签进行One-Hot Encoding。
1 | def reformat(dataset, labels, image_size, num_labels): |
转换后,格式就和minist一样了。
最后,把数据保存到save.pickle里面。
1 | save = { |
验证数据集加载是否正确
加载完数据后,需要验证一下数据是否正确。我选择的方法很简单,就是把trainset的第1个(或者第2个、第n个)图片打开,看看它的标签和看到的能不能对上。
1 | import cPickle as pickle |
运行后,可以看到第一张图片是Y,标签也是正确的。
训练
数据加载好了之后,就可以开始训练了,训练的网络就使用TensorFlow官网在Deep MNIST for Experts里提供的就好了。
此部分的代码见:train.py。
先加载一下模型:
1 | import cPickle as pickle |
minist的数据都是28*28的,把里面的网络改完了之后,如下:
1 | def weight_variable(shape): |
主要改动就是输入层把28*28改成了image_size*image_size(32*32),然后第三层的全连接网络把7*7改成了image_size/4*image_size/4(8*8),以及把10(手写字符一共10类)改成了num_labels。
然后训练,我这里把batch_size改成了128,训练批次改少了。
1 | batch_size = 128 |
运行,可以看到识别率在不断的上升。
最后,有了接近98%的识别率,只有4000个训练数据,感觉不错了。