0 引言
在利用Caffe进行训练的时候,最终的训练结果会保存下来,在做预测的时候可以直接加载训练好的模型。但是目前接触的Tensorflow案例中,都是直接训练、然后在测试集上验证,最后退出整个程序。下次再使用的时候,就需要重新训练、预测。这样就很不科学。心想,肯定会有办法来保存和加载模型。于是,在莫烦的教程中看到保存和读取模型参数的教程[1],他说Tensorflow初期,只支持网络参数的保存和读取,不能保存网络结构,如果想使用训练好的参数,必须重新搭建一模一样的网络结构,才能完成预测。但是后来可能Tensorflow觉得这样不方便,于是推出MetaGraph[2],可以保存网络结构。本文主要介绍网络参数的保存和读取,以及网络结构的保存和读取。
最主要的类:tf.train.Saver
初始化参数:
max_to_keep 参数用来设置保存模型的个数,默认为5,即保存最近的5个模型。如果想每训练一代epoch就想保存一次模型,可以将 max_to_keep设置为None或者0,如:saver=tf.train.Saver(max_to_keep=0)
方法:
1 网络参数的保存与读取
网络参数也就是Tensorflow中的Variables类型。
网络参数的保存
|
|
运行上面的程序,会在当前文件夹下面创建model文件夹,并在model文件夹下,生成四个文件:checkpoint,params.ckpt.data-00000-of-00001,params.ckpt.index,params.ckpt.meta,它们的含义为:
- checkpoint文件保存了一个目录下所有的模型文件列表,这个文件是tf.train.Saver类自动生成且自动维护的。在 checkpoint文件中维护了由一个tf.train.Saver类持久化的所有TensorFlow模型文件的文件名。当某个保存的TensorFlow模型文件被删除时,这个模型所对应的文件名也会从checkpoint文件中删除。checkpoint中内容的格式为CheckpointState Protocol Buffer.
- params.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构。TensorFlow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。TensorFlow中元图是由MetaGraphDef Protocol Buffer定义的。MetaGraphDef 中的内容构成了TensorFlow持久化时的第一个文件。保存MetaGraphDef 信息的文件默认以.meta为后缀名,文件model.ckpt.meta中存储的就是元图数据。
params.ckpt.data-00000-of-00001文件保存了网络参数的值,但是数据是没有结构的。为了在网络中恢复模型,需要这样使用:
12saver = tf.train.import_meta_graph(path_to_ckpt_meta)saver.restore(sess, path_to_ckpt_data)params.ckpt.index还不清楚是干什么用的,猜想是一种映射关系。
网络参数的读取
|
|
2 网络模型的保存与读取
网络模型的导入导出是通过元图(MetaGraph)实现的。MetaGraph包含以下内容:
- MetaInfoDef for meta information, such as version and other user information.
- GraphDef for describing the graph.
- SaverDef for the saver.
- CollectionDef map that further describes additional components of the model, such as Variables, tf.train.QueueRunner, etc. In order for a Python object to be serialized to and from MetaGraphDef, the Python class must implement to_proto() and from_proto() methods, and register them with the system using register_proto_function.
导出
|
|
导入
最简单的情况:
|
|
训练到一半,停下来,导入接着训练
|
|
上面的训练,训练没有完成,停下来了,要接着训练:
利用之前的训练结果,训练扩展后的网络模型
首先,定义一个网络,训练,保存结果
然后,加载训练结果,扩展网络,接着训练。
- .pd文件:the .pb file can save your whole graph (meta + data). To load and use (but not train) a graph in c++ you’ll usually use it, created with freeze_graph, which creates the .pb file from the meta and data. Be careful, (at least in previous TF versions and for some people) the py function provided by freeze_graph did not work properly, so you’d have to use the script version. Tensorflow also provides a tf.train.Saver.to_proto() method, but I don’t know what it does exactly.
参考文献
[1] Morvan tensorflow
[2] tensorflow.org python api
Tensorflow C++ 编译和调用图模型:http://blog.csdn.net/rockingdingo/article/details/75452711