0. 前言
- 为什么要学
tf.keras
模型相关源码?- 在TF1.x和TF2.x中,模型构建都推荐使用
tf.keras
,两者可以通用。 - 在TF1.x中参数管理都是使用 global variables,即
tf.get_collection
相关,但TF2.x中不推荐使用这种方法。 - 使用
tf.keras
管理参数(如通过model.variables
获取参数等),是TF1.x和TF2.x的通用方法。但由于没有具体使用过,相关的资料好像也没有特别全的,所以希望自己研究源码学习一遍,理解底层实现原理。
- 在TF1.x和TF2.x中,模型构建都推荐使用
- 本文源码基于
1.14.0
。1.13.2
中没有tf.Module
。1.14.0
和2.0.0rc1
的keras模型源码相同。1.15.0
和2.0.0
的keras源码相同,但这版本源码在VSCode上代码跳转上存在一些问题,目前无法处理。- 与
1.14.0
的主要区别在于整理了API,新建了tensorflow_core
,具体keras模型相关的代码应该没什么改变。 - 跳转不能正常使用的原因如下:tensorflow issue #31973、vscode-python-language-server issue #818
- 与
- 本文主要包含以下几个模块的学习:
tf.keras.Model
、Network
、tf.keras.layers.Layer
、tf.Module
- 参考资料:
- 提一句:如果要看
tf.keras.Model
有具体哪些模型相关功能,主要参考Network
类即可。
1. tf.Module
1.1. 基本情况
- 位于
tensorflow.python.module.module.py
。 - 定位:是一个容器,包含
tf.Variable
、tf.Module
、一些函数。 - 作用有以下三个:
- 参数管理。
- 子模块管理。
name_scope
应用。
1.2. 功能简介
- 参数管理:
- 概述:提供
variables
、trainable_variables
属性管理tf.Variable
。 - 相关属性(property):
variables
、trainable_variables
- 相关方法:
_flatten
、_flatten_module
- 概述:提供
- 子模块管理:
- 概述:提供
submodules
属性管理其他tf.Module
。 - 相关属性(property):
submodules
。 - 相关方法:
_flatten
、_flatten_module
。
- 概述:提供
name_scope
应用:- 概述:为当前module提供
tf.name_scope
对象,主要作用是在tensorboard中group operations以及为变量名create hierarchies。 - 相关属性:
name_scope
- 相关方法:
classmethod
-name_scope
。 - 建议使用方法:
- 方法一:在代码块中添加
with self.name_scope:
。 - 方法二:使用
@tf.Module.with_name_scope
修饰除构造器以外的方法。
- 方法一:在代码块中添加
- 概述:为当前module提供
1.3. 参数、子模块代码分析
这两部分代码都调用了以下两个方法:
_flatten
、_flatten_module
。参数、子模块管理的思路都一样,大致思路如下:
- 使用
tf.Module
使用成员变量保存参数或子模块。 - 在
_flatten_module
中使用内置函数vars
获取当前对象所有属性(即成员变量)。
- 使用
_flatten
源码分析- 该方法包含详细文档,需要仔细看下。
- 主要功能就是调用了
_flatten_module
方法。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42def _flatten(self,
recursive=True,
predicate=None,
attribute_traversal_key=None,
with_path=False):
"""
Args:
recursive: Whether to recurse into child modules or not.
是否要递归寻找子模块的对应内容
predicate: (Optional) If set then only values matching predicate are
yielded. A value of `None` (the default) means no items will be
filtered.
使用该方法过滤对象成员变量(如判断是variables, trainable_variables, modules)
attribute_traversal_key: (Optional) Method to rekey object attributes
before they are sorted. Contract is the same as `key` argument to
builtin `sorted` and only applies to object properties.
重新生成key
with_path: (Optional) Whether to include the path to the object as well
as the object itself. If `with_path` is `True` then leaves will not be
de-duplicated (e.g. if the same leaf instance is reachable via multiple
modules then it will be yielded multiple times with different paths).
如果是True,则返回的key为tuple,包含多个元素;
如果是False,则只返回当前元素名称(即成员变量名称)
Returns:
Flat generator for leaves of the current module and optionally all
submodules.
返回的是生成器(即yield结果)
"""
if predicate is None:
predicate = lambda _: True
return _flatten_module(
self,
recursive=recursive,
predicate=predicate,
attributes_to_ignore=self._TF_MODULE_IGNORED_PROPERTIES,
attribute_traversal_key=attribute_traversal_key,
with_path=with_path)
_flatten_module
源码分析- 几个理解重点:
vars
:内置函数,获取对象所有成员变量。id
:内置函数,获取对象内存地址。flatten_with_tuple_paths
:返回一组元组,每个元组包含属性path和属性值。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57def _flatten_module(module,
recursive,
predicate,
attribute_traversal_key,
attributes_to_ignore,
with_path,
module_path=(),
seen=None):
"""Implementation of `flatten`."""
if seen is None:
seen = set([id(module)])
module_dict = vars(module)
submodules = []
for key in sorted(module_dict, key=attribute_traversal_key):
if key in attributes_to_ignore:
continue
for leaf_path, leaf in nest.flatten_with_tuple_paths(module_dict[key]):
leaf_path = (key,) + leaf_path
# TODO(tomhennigan) Handle cycles for `with_path=True` (e.g. `a.a = a`).
# 在 with_path 为 False 时,确定每个leaf只迭代一次
if not with_path:
leaf_id = id(leaf)
if leaf_id in seen:
continue
seen.add(leaf_id)
# 判断当前成员变量/属性是否符合要求,符合则输出
if predicate(leaf):
if with_path:
yield module_path + leaf_path, leaf
else:
yield leaf
# 如果当前成员变量/属性是 `tf.Module`,则继续递归查询叶节点
if recursive and _is_module(leaf):
# Walk direct properties first then recurse.
submodules.append((module_path + leaf_path, leaf))
# 递归查询叶节点
for submodule_path, submodule in submodules:
subvalues = _flatten_module(
submodule,
recursive=recursive,
predicate=predicate,
attribute_traversal_key=attribute_traversal_key,
attributes_to_ignore=submodule._TF_MODULE_IGNORED_PROPERTIES,
with_path=with_path,
module_path=submodule_path,
seen=seen)
for subvalue in subvalues:
# Predicate is already tested for these values.
yield subvalue
- 几个理解重点:
2. tf.keras.layers.Layer
2.1. 基本情况
- 位于
tensorflow.python.keras.engine.base_layer
- 参考资料:
- 概述:
- 所有layer的父类。
- 继承了
tf.Module
,但参数管理并没有使用tf.Module
中那一套,而是自己构建的一套。
- 主要功能:
- 构建模型正向过程。
- 参数管理(weight/loss/metric/update)。
- 自定义实现时,建议实现三个方法:
__init__()
:使用成员变量保存配置信息。build()
:在知道输入数据的shape和dtype后,在调用__call__
时会调用一次该方法。call()
:在确认build
被调用后,在__call__
中被调用。
- 初始化方法
- 定义:
def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
kwargs
可以设置的参数包括:input_shape
,batch_input_shape
,batch_size
,weights
,activity_regularizer
。- 作用:看了源码,就是初始化了一大串成员变量,没什么太多值得一提的。
- 定义:
2.2. __call__
方法
- 定义:
def __call__(self, inputs, *args, **kwargs):
- 作用:模型的前向过程,主要就是调用了
call
方法,注意args
和kwargs
都是要导入call
方法中的。 - 相关方法:
__call__
build
call
- 流程:
- 将输入数据
inputs
(如tuple, dict, 单个元素等)转换为 Python list。 - 以下操作都包裹在
name_scope(self.name)
。- 调用
_maybe_build()
,其中查看built
的取值,如果为True则直接返回。 - 判断输入数据是否符合要求。
- 调用
call
方法。
- 调用
- 如果有
_initial_weights
则调用set_weights
设置数据。
- 将输入数据
2.3. config 相关
- 概述:Layer的config是一个Python字典对象,可使用该对象reinstantiated相同的Layer对象。
- 主要的参数就是:name, trainable, _batch_input_shape, dtype
- 相关方法:
get_config
:获取dict对象。from_config
:通过dict对象和class类,获取一个新的对象。
2.4. nodes 相关
- 概述:用于处理Layer之间的相互连接关系。
- 相关属性:
inbound_nodes
:其实就是调用了成员变量_inbound_nodes
outbound_nodes
:其实就是调用了成员变量_outbound_nodes
output_shape
input_shape
output
input
- 相关方法:
_add_inbound_node
get_input_shape_at
get_output_shape_at
get_input_at
get_output_at
get_input_mask_at
get_output_mask_at
class Node
- 在
Layer
所在文件中定义。 - 作用:连接多个Layer的就是Node。
inbound_nodes
是别的Layer连接到本Layer。outbound_nodes
是本Layer连接到别的Layer。- 初始化参数:
outbound_layer
:输入input_tensors
并得到output_tensors
的层。inbound_layers
:Layer序列,长度与input_tensors
相同,即得到每个input_tensor
的来源。node_indices
:整数序列,长度与inbound_layers
相同,与node_indices[i]
中node对应input_tensors[i]
。tensor_indices
:整数序列,长度与 inbound_layers 相同,与input_tenosrs[i]
中output tensors 对应。input_tensors
output_tensors
- node 对象在
call
方法被调用时生成。
- 在
- 举例:若产生A层到B层的连接,则Node对象被建立,
A.outbound_nodes
和B.inbound_nodes
分别添加该Node。 - 自己画了图感受了一下,猜测调用一次Layer就创建一个Node,像Siamese那样调用多次Layer就创建多个Node。
- input/output相关(包括大多数shape相关)属性/操作都是调用了
inbound_nodes
来处理。
2.6. 参数管理
- weights
- 也就是
variables
的别名 - 从底层实现来看,包括当前Layer直接保存的weights和当前Layer包含的其他layers对象中的weights参数。
- 当前Layer的weights主要保存在
_trainable_weights
和_non_trainable_weights
两个序列中。 - 相关方法:
add_wegiht
:类似于tf.get_variable
,创建好后加入上述对应序列中。set_weights
:通过ndarray对象赋值。get_weights
:获取得到 ndarray 结果。
- 相关属性
trainable_weights
:获取本Layer以及其他Layer的所有trainable weights。non_trainable_weights
:获取本Layer以及其他Layer的所有non_trainable weights。weights
:上述两个属性的集合。
- 也就是
- loss
- 对应了三个成员变量,分别是
_losses
,_eager_losses
,callable_losses
。第三个参数一般就是 regularizer。 - 相关属性 -
losses
:- 分为当前Layer losses以及子Layer losses。
- 当前Layer Losses:根据是否是eager模式获取
_losses
或_eager_losses
,之后通过_callable_losses
获取对应数据。 - 子Layer Losses 类似。
- 相关方法:
add_loss
:eager模式下不支持。输入数据可以是tensors,也可以是无参数的callable对象,确定类型后添加到对应的序列中。get_loss_for(self, inputs)
:- 与对应 inputs 对应的losses。
- 如果inputs为空,则返回所有
_unconditional_loss
。 - 通过
tf_utils.get_reachable_from_inputs
获取所有能够连接到的tensor,判断losses
中哪些是能够连接到的tensor,并返回。
- 对应了三个成员变量,分别是
- update
- 用于存储weights更新操作,如BN。
- 相关方法:
add_update
:将输入的updates存入_updates
中,并根据是否输入inputs设置对应的_unconditional_update
属性。get_updates_for
:如果输入inputs为None,则输入所有_unconditional_update
为True的updates;如果inputs不为None,则通过tf_utils.get_reachable_from_inputs
获取对应的updates。
- 相关属性:
- 成员变量
_updates
:保存当前Layer的相关操作。 updates
:会返回当前Layer以及对应子Layer的所有updates。
- 成员变量
- metric
- 性能指标没有属性,只有成员变量
_metrics
以及_metrics_tensors
。 - 相关方法:
add_metric(add_metric(self, value, aggregation=None, name=None))
value
是tensor,也没啥难度。获取tensor和对应updates应该是在之前做的。
- 相关属性:
_metrics
:list_metrics_tensors
:dict,key是string,value是tensor。
- 性能指标没有属性,只有成员变量
3. Network
3.1. 基本情况
- 位于
tensorflow.python.keras.engine.network
- 概述:
Network
是由一系列Layer
的拓扑组成。Model
对象仅仅是在Network
的基础上加上一些训练过程。 - 作用:使用keras构建模型其实就是构建了一个
Network
对象。 - Network提供了两种构建模型的方式
- Graph Networks:就是使用 functional api 构建模型。
- Subclass Networks:新建class(继承
tf.keras.Model
)。
- Graph Networks 支持的但 Subclass Networks 不支持的特性:
- Model cloning (
keras.models.clone
) - Serialization (
model.get_config()/from_config
,model.to_json()/to_yaml()
- Whole-model saving (
model.save()
)
- Model cloning (
- Subclass Networks 模式可通过 name/dynamic 来创建对象。
- 初始化源码分析
- 作用:
- 通过输入参数,确定是使用 Graph Network 还是 Subclassed Network。
- 确认无 legacy layers,好像意思就是
tf.layers.Layers
中的对象。
- 使用 Graph Network 的情况:
- 可变参数数量为2。
- 可变参数数量为1且命名参数中存在
outputs
。 - 可变查宿数量为0且命名参数中存在
inputs
和outputs
。
- 除上述情况外,其他均使用 Subclassed Network。
1
2
3
4
5
6
7
8
9
10
11
12def __init__(self, *args, **kwargs): # pylint: disable=super-init-not-called
# Signature detection
if (len(args) == 2 or
len(args) == 1 and 'outputs' in kwargs or
'inputs' in kwargs and 'outputs' in kwargs):
# Graph network
self._init_graph_network(*args, **kwargs)
else:
# Subclassed network
self._init_subclassed_network(**kwargs)
tf_utils.assert_no_legacy_layers(self.layers)
- 作用:
3.2. Graph Network 初始化
- 流程:
- 命名参数确认:只允许有 name/trainable 这两个命名参数,其他参数均不允许。
- 输入输出初始化。
- 包括为所有 outputs 设置
create_keras_history
属性。
- 包括为所有 outputs 设置
- 基本初始化(与 Subclassed Model 通用):初始化了一票成员变量。
- 判断输入输出是否合法。
- 每个输入tensor不能重复出现,否则抛出异常。
- 每个输入tensor必须有
_keras_history
,否则抛出异常。 - 每个输入必须是 input tensor,否则输出 warning 日志。
- 所有输入的batch size应该不矛盾,否则。
- 每个输出tensor必须有
_keras_history
,否则抛出异常。
- 初始化一票成员变量。
- 获取所有相关的nodes和layers。
- 感觉大概就是从 outputs 开始,使用DFS往前找Layer和Node。
- 把找到的nodes和layers存放到对应的成员被变量中。
- 创建Node,关联inputs和outputs。
- 设置
input_names
和output_names
。
- 正向流程:
- 相关函数
call
,其中主要就是调用了_run_internal_graph
方法。 _run_internal_graph
方法的主要流程就是 Depth 依次获取各层nodes,然后依次调用layer计算output tensors。
- 相关函数
3.3. Subclassed Network 初始化
- 好像没啥多的工作。
def build(self, input_shape):
方法:- 作用:根据 input shapes 构建模型,用于 Subclassed Network。
- 在实现过程中,是创建了
tf.placeholder
后调用call
方法。
3.4. config 相关
- 相关方法:
to_json
,to_yaml
,get_config
,from_config
。 - 作用:获取模型的配置文件(json/yaml 形式),可以通过这个配置文件重新获取模型。
3.5. 模型权重相关
- 相关方法:
save
,save_weights
,load_weights
- 相关属性:
trainable_weights
,non_trainable_weights
,weights
weights
:所有_layers
的weights以及当前Model的weights。
3.6. Layers 相关
- 相关属性:
layers
- 相关方法:
get_layer
- 概述:
- 在构建模型的过程中初始化了
_layers
成员变量。 get_layer
就是根据名称来获取对应的对象。layers
就是在_layers
的基础上过滤掉 Layer-like containers。
- 在构建模型的过程中初始化了
3.7. summary
- 作用:打印网络信息。
- 细节:
- 必须在调用过
build
方法后使用,否则报错。 - 可以选择每行输出长度、输出信息位置、输出方法(默认使用
print
)。 - 输出信息主要包括 Layer type, output shape, param, connected to 四个方面,分别使用
layer.name + '(' + layer.__class__.__name__ + ')'
,layer.output_shape
,layer.count_params()
来获取。
- 必须在调用过
4. tf.keras.Model
4.1. 概述
- 功能:在 Module 的基础上增加一系列模型训练、预测、评估相关的操作。
- 主要模块:
- distribution_strategy 相关。
- 在 TF2.x中,
distribute
不可用,需要使用 distribution strategy scope 替代。 - 如
get_weights
,load_weights
。
- 在 TF2.x中,
compile
metrics
,metrics_names
,reset_metrics
fit
,fit_generator
,train_on_batch
evaluate
,evaluate_generator
predict
,predict_generator
,predict_on_batch
,test_on_batch
- distribution_strategy 相关。
- 这些都不是模型相关,而是训练/预测/评估相关,这不是本文的主要内容。