tf.keras 模型相关源码分析

0. 前言

  • 为什么要学 tf.keras 模型相关源码?
    • 在TF1.x和TF2.x中,模型构建都推荐使用 tf.keras,两者可以通用。
    • 在TF1.x中参数管理都是使用 global variables,即 tf.get_collection 相关,但TF2.x中不推荐使用这种方法。
    • 使用 tf.keras 管理参数(如通过 model.variables 获取参数等),是TF1.x和TF2.x的通用方法。但由于没有具体使用过,相关的资料好像也没有特别全的,所以希望自己研究源码学习一遍,理解底层实现原理。
  • 本文源码基于 1.14.0
    • 1.13.2中没有tf.Module
    • 1.14.02.0.0rc1的keras模型源码相同。
    • 1.15.02.0.0的keras源码相同,但这版本源码在VSCode上代码跳转上存在一些问题,目前无法处理。
  • 本文主要包含以下几个模块的学习:tf.keras.ModelNetworktf.keras.layers.Layertf.Module
  • 参考资料:
  • 提一句:如果要看tf.keras.Model有具体哪些模型相关功能,主要参考Network类即可。

1. tf.Module

1.1. 基本情况

  • 位于 tensorflow.python.module.module.py
  • 定位:是一个容器,包含 tf.Variabletf.Module、一些函数。
  • 作用有以下三个:
    • 参数管理。
    • 子模块管理。
    • name_scope应用。

1.2. 功能简介

  • 参数管理:
    • 概述:提供variablestrainable_variables属性管理tf.Variable
    • 相关属性(property):variablestrainable_variables
    • 相关方法:_flatten_flatten_module
  • 子模块管理:
    • 概述:提供submodules属性管理其他 tf.Module
    • 相关属性(property):submodules
    • 相关方法:_flatten_flatten_module
  • name_scope 应用:
    • 概述:为当前module提供tf.name_scope对象,主要作用是在tensorboard中group operations以及为变量名create hierarchies。
    • 相关属性:name_scope
    • 相关方法:classmethod - name_scope
    • 建议使用方法:
      • 方法一:在代码块中添加 with self.name_scope:
      • 方法二:使用@tf.Module.with_name_scope修饰除构造器以外的方法。

1.3. 参数、子模块代码分析

  • 这两部分代码都调用了以下两个方法:_flatten_flatten_module

  • 参数、子模块管理的思路都一样,大致思路如下:

    • 使用tf.Module使用成员变量保存参数或子模块。
    • _flatten_module中使用内置函数 vars 获取当前对象所有属性(即成员变量)。
  • _flatten 源码分析

    • 该方法包含详细文档,需要仔细看下。
    • 主要功能就是调用了 _flatten_module 方法。
      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      27
      28
      29
      30
      31
      32
      33
      34
      35
      36
      37
      38
      39
      40
      41
      42
      def _flatten(self,
      recursive=True,
      predicate=None,
      attribute_traversal_key=None,
      with_path=False):
      """
      Args:
      recursive: Whether to recurse into child modules or not.
      是否要递归寻找子模块的对应内容

      predicate: (Optional) If set then only values matching predicate are
      yielded. A value of `None` (the default) means no items will be
      filtered.
      使用该方法过滤对象成员变量(如判断是variables, trainable_variables, modules)

      attribute_traversal_key: (Optional) Method to rekey object attributes
      before they are sorted. Contract is the same as `key` argument to
      builtin `sorted` and only applies to object properties.
      重新生成key

      with_path: (Optional) Whether to include the path to the object as well
      as the object itself. If `with_path` is `True` then leaves will not be
      de-duplicated (e.g. if the same leaf instance is reachable via multiple
      modules then it will be yielded multiple times with different paths).
      如果是True,则返回的key为tuple,包含多个元素;
      如果是False,则只返回当前元素名称(即成员变量名称)

      Returns:
      Flat generator for leaves of the current module and optionally all
      submodules.
      返回的是生成器(即yield结果)
      """
      if predicate is None:
      predicate = lambda _: True

      return _flatten_module(
      self,
      recursive=recursive,
      predicate=predicate,
      attributes_to_ignore=self._TF_MODULE_IGNORED_PROPERTIES,
      attribute_traversal_key=attribute_traversal_key,
      with_path=with_path)
  • _flatten_module 源码分析

    • 几个理解重点:
      • vars:内置函数,获取对象所有成员变量。
      • id:内置函数,获取对象内存地址。
      • flatten_with_tuple_paths:返回一组元组,每个元组包含属性path和属性值。
        1
        2
        3
        4
        5
        6
        7
        8
        9
        10
        11
        12
        13
        14
        15
        16
        17
        18
        19
        20
        21
        22
        23
        24
        25
        26
        27
        28
        29
        30
        31
        32
        33
        34
        35
        36
        37
        38
        39
        40
        41
        42
        43
        44
        45
        46
        47
        48
        49
        50
        51
        52
        53
        54
        55
        56
        57
        def _flatten_module(module,
        recursive,
        predicate,
        attribute_traversal_key,
        attributes_to_ignore,
        with_path,
        module_path=(),
        seen=None):
        """Implementation of `flatten`."""
        if seen is None:
        seen = set([id(module)])

        module_dict = vars(module)
        submodules = []

        for key in sorted(module_dict, key=attribute_traversal_key):
        if key in attributes_to_ignore:
        continue

        for leaf_path, leaf in nest.flatten_with_tuple_paths(module_dict[key]):
        leaf_path = (key,) + leaf_path

        # TODO(tomhennigan) Handle cycles for `with_path=True` (e.g. `a.a = a`).
        # 在 with_path 为 False 时,确定每个leaf只迭代一次
        if not with_path:
        leaf_id = id(leaf)
        if leaf_id in seen:
        continue
        seen.add(leaf_id)

        # 判断当前成员变量/属性是否符合要求,符合则输出
        if predicate(leaf):
        if with_path:
        yield module_path + leaf_path, leaf
        else:
        yield leaf

        # 如果当前成员变量/属性是 `tf.Module`,则继续递归查询叶节点
        if recursive and _is_module(leaf):
        # Walk direct properties first then recurse.
        submodules.append((module_path + leaf_path, leaf))

        # 递归查询叶节点
        for submodule_path, submodule in submodules:
        subvalues = _flatten_module(
        submodule,
        recursive=recursive,
        predicate=predicate,
        attribute_traversal_key=attribute_traversal_key,
        attributes_to_ignore=submodule._TF_MODULE_IGNORED_PROPERTIES,
        with_path=with_path,
        module_path=submodule_path,
        seen=seen)

        for subvalue in subvalues:
        # Predicate is already tested for these values.
        yield subvalue

2. tf.keras.layers.Layer

2.1. 基本情况

  • 位于 tensorflow.python.keras.engine.base_layer
  • 参考资料:
  • 概述:
    • 所有layer的父类。
    • 继承了 tf.Module,但参数管理并没有使用 tf.Module 中那一套,而是自己构建的一套。
  • 主要功能:
    • 构建模型正向过程。
    • 参数管理(weight/loss/metric/update)。
  • 自定义实现时,建议实现三个方法:
    • __init__():使用成员变量保存配置信息。
    • build():在知道输入数据的shape和dtype后,在调用__call__时会调用一次该方法。
    • call():在确认build被调用后,在__call__中被调用。
  • 初始化方法
    • 定义:def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
    • kwargs 可以设置的参数包括:input_shape, batch_input_shape, batch_size, weights, activity_regularizer
    • 作用:看了源码,就是初始化了一大串成员变量,没什么太多值得一提的。

2.2. __call__ 方法

  • 定义:def __call__(self, inputs, *args, **kwargs):
  • 作用:模型的前向过程,主要就是调用了 call 方法,注意 argskwargs 都是要导入 call 方法中的。
  • 相关方法:
    • __call__
    • build
    • call
  • 流程:
    • 将输入数据 inputs (如tuple, dict, 单个元素等)转换为 Python list。
    • 以下操作都包裹在 name_scope(self.name)
      • 调用 _maybe_build(),其中查看built的取值,如果为True则直接返回。
      • 判断输入数据是否符合要求。
      • 调用 call 方法。
    • 如果有 _initial_weights 则调用 set_weights 设置数据。

2.3. config 相关

  • 概述:Layer的config是一个Python字典对象,可使用该对象reinstantiated相同的Layer对象。
  • 主要的参数就是:name, trainable, _batch_input_shape, dtype
  • 相关方法:
    • get_config:获取dict对象。
    • from_config:通过dict对象和class类,获取一个新的对象。

2.4. nodes 相关

  • 概述:用于处理Layer之间的相互连接关系。
  • 相关属性:
    • inbound_nodes:其实就是调用了成员变量 _inbound_nodes
    • outbound_nodes:其实就是调用了成员变量 _outbound_nodes
    • output_shape
    • input_shape
    • output
    • input
  • 相关方法:
    • _add_inbound_node
    • get_input_shape_at
    • get_output_shape_at
    • get_input_at
    • get_output_at
    • get_input_mask_at
    • get_output_mask_at
  • class Node
    • Layer 所在文件中定义。
    • 作用:连接多个Layer的就是Node。
    • inbound_nodes 是别的Layer连接到本Layer。
    • outbound_nodes 是本Layer连接到别的Layer。
    • 初始化参数:
      • outbound_layer:输入 input_tensors 并得到 output_tensors 的层。
      • inbound_layers:Layer序列,长度与 input_tensors 相同,即得到每个 input_tensor 的来源。
      • node_indices:整数序列,长度与 inbound_layers 相同,与 node_indices[i] 中node对应 input_tensors[i]
      • tensor_indices:整数序列,长度与 inbound_layers 相同,与 input_tenosrs[i]中output tensors 对应。
      • input_tensors
      • output_tensors
    • node 对象在 call 方法被调用时生成。
  • 举例:若产生A层到B层的连接,则Node对象被建立,A.outbound_nodesB.inbound_nodes 分别添加该Node。
  • 自己画了图感受了一下,猜测调用一次Layer就创建一个Node,像Siamese那样调用多次Layer就创建多个Node。
  • input/output相关(包括大多数shape相关)属性/操作都是调用了 inbound_nodes 来处理。

2.6. 参数管理

  • weights
    • 也就是 variables 的别名
    • 从底层实现来看,包括当前Layer直接保存的weights和当前Layer包含的其他layers对象中的weights参数。
    • 当前Layer的weights主要保存在 _trainable_weights_non_trainable_weights 两个序列中。
    • 相关方法:
      • add_wegiht:类似于 tf.get_variable,创建好后加入上述对应序列中。
      • set_weights:通过ndarray对象赋值。
      • get_weights:获取得到 ndarray 结果。
    • 相关属性
      • trainable_weights:获取本Layer以及其他Layer的所有trainable weights。
      • non_trainable_weights:获取本Layer以及其他Layer的所有non_trainable weights。
      • weights:上述两个属性的集合。
  • loss
    • 对应了三个成员变量,分别是 _losses, _eager_losses, callable_losses。第三个参数一般就是 regularizer。
    • 相关属性 - losses
      • 分为当前Layer losses以及子Layer losses。
      • 当前Layer Losses:根据是否是eager模式获取 _losses_eager_losses,之后通过 _callable_losses 获取对应数据。
      • 子Layer Losses 类似。
    • 相关方法:
      • add_loss:eager模式下不支持。输入数据可以是tensors,也可以是无参数的callable对象,确定类型后添加到对应的序列中。
      • get_loss_for(self, inputs)
        • 与对应 inputs 对应的losses。
        • 如果inputs为空,则返回所有 _unconditional_loss
        • 通过 tf_utils.get_reachable_from_inputs 获取所有能够连接到的tensor,判断losses中哪些是能够连接到的tensor,并返回。
  • update
    • 用于存储weights更新操作,如BN。
    • 相关方法:
      • add_update:将输入的updates存入_updates中,并根据是否输入inputs设置对应的_unconditional_update属性。
      • get_updates_for:如果输入inputs为None,则输入所有_unconditional_update为True的updates;如果inputs不为None,则通过tf_utils.get_reachable_from_inputs获取对应的updates。
    • 相关属性:
      • 成员变量_updates:保存当前Layer的相关操作。
      • updates:会返回当前Layer以及对应子Layer的所有updates。
  • metric
    • 性能指标没有属性,只有成员变量_metrics以及_metrics_tensors
    • 相关方法:add_metric(add_metric(self, value, aggregation=None, name=None))
      • value是tensor,也没啥难度。获取tensor和对应updates应该是在之前做的。
    • 相关属性:
      • _metrics:list
      • _metrics_tensors:dict,key是string,value是tensor。

3. Network

3.1. 基本情况

  • 位于 tensorflow.python.keras.engine.network
  • 概述:Network 是由一系列 Layer 的拓扑组成。Model 对象仅仅是在 Network 的基础上加上一些训练过程。
  • 作用:使用keras构建模型其实就是构建了一个 Network 对象。
  • Network提供了两种构建模型的方式
    • Graph Networks:就是使用 functional api 构建模型。
    • Subclass Networks:新建class(继承tf.keras.Model)。
  • Graph Networks 支持的但 Subclass Networks 不支持的特性:
    • Model cloning (keras.models.clone)
    • Serialization (model.get_config()/from_config, model.to_json()/to_yaml()
    • Whole-model saving (model.save())
  • Subclass Networks 模式可通过 name/dynamic 来创建对象。
  • 初始化源码分析
    • 作用:
      • 通过输入参数,确定是使用 Graph Network 还是 Subclassed Network。
      • 确认无 legacy layers,好像意思就是 tf.layers.Layers 中的对象。
    • 使用 Graph Network 的情况:
      • 可变参数数量为2。
      • 可变参数数量为1且命名参数中存在 outputs
      • 可变查宿数量为0且命名参数中存在 inputsoutputs
    • 除上述情况外,其他均使用 Subclassed Network。
      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      def __init__(self, *args, **kwargs):  # pylint: disable=super-init-not-called
      # Signature detection
      if (len(args) == 2 or
      len(args) == 1 and 'outputs' in kwargs or
      'inputs' in kwargs and 'outputs' in kwargs):
      # Graph network
      self._init_graph_network(*args, **kwargs)
      else:
      # Subclassed network
      self._init_subclassed_network(**kwargs)

      tf_utils.assert_no_legacy_layers(self.layers)

3.2. Graph Network 初始化

  • 流程:
    • 命名参数确认:只允许有 name/trainable 这两个命名参数,其他参数均不允许。
    • 输入输出初始化。
      • 包括为所有 outputs 设置 create_keras_history 属性。
    • 基本初始化(与 Subclassed Model 通用):初始化了一票成员变量。
    • 判断输入输出是否合法。
      • 每个输入tensor不能重复出现,否则抛出异常。
      • 每个输入tensor必须有 _keras_history,否则抛出异常。
      • 每个输入必须是 input tensor,否则输出 warning 日志。
      • 所有输入的batch size应该不矛盾,否则。
      • 每个输出tensor必须有 _keras_history,否则抛出异常。
    • 初始化一票成员变量。
    • 获取所有相关的nodes和layers。
      • 感觉大概就是从 outputs 开始,使用DFS往前找Layer和Node。
      • 把找到的nodes和layers存放到对应的成员被变量中。
    • 创建Node,关联inputs和outputs。
    • 设置 input_namesoutput_names
  • 正向流程:
    • 相关函数call,其中主要就是调用了 _run_internal_graph 方法。
    • _run_internal_graph 方法的主要流程就是 Depth 依次获取各层nodes,然后依次调用layer计算output tensors。

3.3. Subclassed Network 初始化

  • 好像没啥多的工作。
  • def build(self, input_shape): 方法:
    • 作用:根据 input shapes 构建模型,用于 Subclassed Network。
    • 在实现过程中,是创建了 tf.placeholder 后调用 call 方法。

3.4. config 相关

  • 相关方法:to_json, to_yaml, get_config, from_config
  • 作用:获取模型的配置文件(json/yaml 形式),可以通过这个配置文件重新获取模型。

3.5. 模型权重相关

  • 相关方法:save, save_weights, load_weights
  • 相关属性:trainable_weightsnon_trainable_weights, weights
    • weights:所有_layers的weights以及当前Model的weights。

3.6. Layers 相关

  • 相关属性:layers
  • 相关方法:get_layer
  • 概述:
    • 在构建模型的过程中初始化了 _layers 成员变量。
    • get_layer 就是根据名称来获取对应的对象。
    • layers 就是在 _layers 的基础上过滤掉 Layer-like containers。

3.7. summary

  • 作用:打印网络信息。
  • 细节:
    • 必须在调用过 build 方法后使用,否则报错。
    • 可以选择每行输出长度、输出信息位置、输出方法(默认使用print)。
    • 输出信息主要包括 Layer type, output shape, param, connected to 四个方面,分别使用 layer.name + '(' + layer.__class__.__name__ + ')', layer.output_shape, layer.count_params() 来获取。

4. tf.keras.Model

4.1. 概述

  • 功能:在 Module 的基础上增加一系列模型训练、预测、评估相关的操作。
  • 主要模块:
    • distribution_strategy 相关。
      • 在 TF2.x中,distribute 不可用,需要使用 distribution strategy scope 替代。
      • get_weights, load_weights
    • compile
    • metrics, metrics_names, reset_metrics
    • fit, fit_generator, train_on_batch
    • evaluate, evaluate_generator
    • predict, predict_generator, predict_on_batch, test_on_batch
  • 这些都不是模型相关,而是训练/预测/评估相关,这不是本文的主要内容。