首页 文章详情

【TensorFlow】笔记:基础知识-自定义层

深度学习入门笔记 | 701 2021-02-11 02:51 0 0 0
UniSMS (合一短信)

点击上方“公众号”可订阅哦!



在TensorFlow2.0中,任何一个自定义层都继承自tf.keras.layers.Layer。

Layer层中需要自定义的函数有很多,但是在实际使用时一般只需要定义那些必须使用的函数即可,以下是对__init__、build和call三个主要函数的小结。



01

__init__函数

__init__ 函数首先是一些必要参数的初始化,这些参数的初始化写在 def __init__(self,) 中,然后是一些参数的初始化。
class Mylayer(tf.keras.layers.Layer):     # 显示继承自Layer层    def __init__(self, unit):             # init中显示地确定参数        super().__init__()                # 调用父层        self.unit = unit                  # 把参数加载到类

init函数最重要的就是显式的确定需要的一些参数。对于输入的init中的参数,输入Tensor是不会在这里进行标注的,init初始化的是模型参数。



02

build函数


build() 可自定义网络的权重的维度,可以根据输入来指定权重的维度。


    def build(self, input_shape):        self.weight = self.add_weight(shape=(input_shape[-1], self.unit),                                     initializer=tf.keras.initializers.RandomNormal(),                                     trainable=True)        self.bias = self.add_weight(shape=(self.unit,),                                   initializer=tf.keras.initializers.Zeros(),                                   trainable=True)


在Layer() 类中有一个__call__() 魔法方法(上述两个函数已经被tf集成在该函数下面),会被自动调用,因此不用外部调用。




03

call函数


call函数是最重要的函数,这部分代码包含了主要层的实现,即完成前向传播。

init函数,定义并声明参数,build函数声明了权重可变参数,而这只是定义了一些初始化的参数以及一些需要更新的参数变量,真正实现所定义类的功能是在call函数中。


    def call(self, inputs):        return tf.matmul(inputs, self.weight) + self.bias


call中的一系列操作是对init和build中变量的引用,所有的计算在call中完成。


输入的参数在这里出现,经过计算后将计算值返回。



完整代码:

import tensorflow as tf
class MyLayer(tf.keras.Model): def __init__(self, unit=32): super(MyLayer, self).__init__() self.unit = unit
def build(self, input_shape): self.weight = self.add_weight(shape=(input_shape[-1], self.unit), initializer=tf.keras.initializers.RandomNormal(), trainable=True) self.bias = self.add_weight(shape=(self.unit,), initializer=tf.keras.initializers.Zeros(), trainable=True) def call(self, inputs): return tf.matmul(inputs, self.weight) + self.bias
my_layer = MyLayer(3)x = tf.ones((3,5))out = my_layer(x)print(out)

输出:

tf.Tensor([[ 0.16174725 -0.03372785 -0.01657906] [ 0.16174725 -0.03372785 -0.01657906] [ 0.16174725 -0.03372785 -0.01657906]], shape=(3, 3), dtype=float32)





 END

深度学习入门笔记

微信号:sdxx_rmbj

日常更新学习笔记、论文简述

good-icon 0
favorite-icon 0
收藏
回复数量: 0
    暂无评论~~
    Ctrl+Enter