Batch normalization 推导

前向传播阶段

首先计算一个 batch 的平均:

然后计算这个 batch 的方差:

然后就可以进行 normalization 了:

这一步是为了让输入分布满足标准正态分布,其中在分母上加上 $\epsilon$ 是为了防止分母过小或为 0 。

进行缩放和平移,得到最终结果:

从另一个角度看,这个操作就是上一步的反操作,抵消上面 bn 的影响,增强算法的表达能力,在之后,我们还能进行一些其他变换,比如 relu 等。

在训练时就不需要进行 normalization 了。

1
2
3
4
5
6
7
8
9
10
11
12
13
if mode == 'train':
sample_mean = np.mean(x, axis=0)
sample_var = np.var(x, axis=0)
x_hat = (x - sample_mean) / np.sqrt(sample_var + eps)
out = gamma * x_hat + beta
cache = (gamma, x, sample_mean, sample_var, eps, x_hat)

running_mean = momentum * running_mean + (1 - momentum) * sample_mean
running_var = momentum * running_var + (1 - momentum) * sample_var
else:
# 不需要进行 normalization
x_hat = (x - running_mean) / np.sqrt(running_var + eps)
out = gamma * x_hat + beta

反向传播阶段

对于 bn 的反向传播,能一步步进行推导,具体可以参考:

Understanding the backward pass through Batch Normalization Layer

关键在于仔细。

对于 alternative backward 的推导:

What does the gradient flowing through batch normalization looks like

其中最关键的是对于链式损失函数的推导:


假设,我们对于 x 的 normalization 的情况如下:

也就是说只有位移,没有缩放,那么:

这里就是该推导的核心部分,对于 $\delta _{i,j}$ 只有当 $i=j$ 时,$\delta$ 为 1 否则为 0 。所以对于上面的式子而言,第一个式子为 1 只有当 $k=i$ 且 $l=j$ ,第二个式子为 $1/N$ 只有当 $l=j$ 。


从最简单 beta 的开始:


然后是次简单的 gamma :


最后的第一条链式推导:

首先我们知道:

所以:

还没完,我们知道:

于是:

组合一下:

最后:

该推导对应的代码:

1
2
3
4
5
6
mean = 1./N*np.sum(x, axis = 0)
var = 1./N*np.sum((x-mean)**2, axis = 0)
dbeta = np.sum(dy, axis=0)
dgamma = np.sum((h - mean) * (var + eps)**(-1. / 2.) * dy, axis=0)
dx = (1. / N) * gamma * (var + eps)**(-1. / 2.) * (N * dy - np.sum(dy, axis=0)
- (x - mean) * (var + eps)**(-1.0) * np.sum(dy * (x - mean), axis=0))

这个推导非常富有想象力,然而有一种更简单的推导方式:

Deriving the Gradient for the Backward Pass of Batch Normalization

新代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def batchnorm_backward(dout, cache):
N, D = dout.shape
x_mu, inv_var, x_hat, gamma = cache

# intermediate partial derivatives
dxhat = dout * gamma

# final partial derivatives
dx = (1. / N) * inv_var * (N*dxhat - np.sum(dxhat, axis=0)
- x_hat*np.sum(dxhat*x_hat, axis=0))
dbeta = np.sum(dout, axis=0)
dgamma = np.sum(x_hat*dout, axis=0)

return dx, dgamma, dbeta

LN 和 BN 的不同在于方向不同, LN 是纵向, BN 是横向。不再赘述。

LN:$\mathbf { x } : \mathbf { N } \times D \rightarrow \boldsymbol { \mu } , \boldsymbol { \sigma } : \boldsymbol { 1 } \times \mathbf { D }$

BN:$\mathbf { x } : \mathbf { N } \times D \rightarrow \boldsymbol { \mu } , \boldsymbol { \sigma } : \mathbf { N } \times \mathbf { 1 }$