MENU

Numerical Analysis: 什么是梯度下降

• October 27, 2019 • Read: 5620 • Knowledge,Numerical Analysis,深度学习

梯度下降法的原理

梯度下降法是一种数值计算方法,但是在数值分析中往往并没有直接采用它,但是在这一方法基础上发展而来的共轭梯度法却被得到应用,而梯度下降法是其核心思想与方法。目前,较为火热的深度学习方面,由于广泛采用了自动求导原理,对损失函数求极值时,梯度下降算法在该领域得到了十分广泛的应用。而我们感兴趣的是利用共轭梯度法求线性方程组,该方法的思想就是梯度下降原理,因此,很有必要对该梯度下降法有一个认识。

什么是梯度下降

在微积分中,很常见的一个公式就是泰勒展开,这个公式意在描述函数$f(x)$的值与$f(x+\Delta x)$之间的关系,更直白一点的说,研究的是应变量变化的幅度。对于一元函数而言,利用泰勒展开可以方便的获得函数值的变化。

$$ f(x + \Delta x) = f(x) + f(x)' \cdot \Delta x + \frac{1}{2!} f(x)'' \cdot \Delta x ^2 $$

当我们不考虑二阶以上的项时,便得到$f(x + \Delta x)$的一个近似值。这个式子其几何意义在很多数值方法中被应用,比如在ODE方程的数值法中,欧拉方法本质上就是采用了这个公式去完成函数值的求解,利用该点的函数值、该点的斜率以及步长,即可得到下一个点的函数值。

在实际应用中往往是多元函数,尤其在深度学习中,神经网络每一层基本都是多个神经元,往往都是多元函数问题。从简单一点的情况来看,即是二元函数,此时泰勒公式的描述似乎有点繁琐,不是那么容易理解,但是从全微分的概念来看却十分清晰。二元函数的全微分公式为:

$$ \mathrm{d}f(x, y) = \frac{\partial f}{\partial x} \mathrm{d}x + \frac{\partial f}{\partial y} \mathrm{d}y $$

其中,$\mathrm{d}f(x, y)$表示函数的变化量,即$f(x + \Delta x, y + \Delta y) - f(x, y)$,这样很容易可以得到二元函数的泰勒展开形式。当然,这里对函数值变化量更感兴趣。对于上式的表达可以写为两个向量的内积,这样可以很直观地从几何角度来理解什么是梯度下降。函数的导数向量记为$a$

$$ a = (\frac{\partial f}{\partial x}, \frac{\partial f}{\partial y}) $$

而自变量的增量向量为$b$

$$ b = (\mathrm{d}x, \mathrm{d}y) $$

这两个向量的内积就表示函数的变化量的大小,而从向量内积的定义和几何两个方面来看,

$$ a \cdot b = |a||b|\mathrm{cos} \theta $$

当这两个向量平行且反向的时候,夹角为$\pi$,内积取得最小值,因为此时$\cos \pi = -1$,那么,从这个角度来看,只要这两个向量保持共线且反向,函数的变化量最小

$$ b = - k a \quad (k是一个正数) $$

$$ (\mathrm{d}x,\mathrm{d}y) = - k (\frac{\partial f}{\partial x}, \frac{\partial f}{\partial y}) $$

这样,通过这个式子就可得到函数沿着梯度方向的一个自变量变化值。因为沿着这个方向是下降最快的一个路径,此时的$k$可以看做是步长,所以该方法也称为最速下降法。上式虽然是以二元函数为例的,但是毫无疑问的是可以推广到$n$元函数中去。更为一个通用的表达是

$$ \nabla f = - k \Delta x_{i} $$

Example

比如函数$f(x, y) = x^2 + y ^ 2$,当然这个函数极小值很明显就是0,即当$x= 0, y=0$时取得最小值。

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# 函数图像
x = np.linspace(-2, 2, 50)
y = np.linspace(-2, 2, 50)
xx, yy = np.meshgrid(x, y)
z = xx**2 + yy**2

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
cb = ax.plot_surface(xx, yy, z, cmap='RdBu_r', alpha=0.7)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
fig.colorbar(cb, ax=ax)
ax.set_title('$z = x^2 + y^2$')
fig.savefig('function image.jpg', dpi=600)
plt.show()

function

此时,任意选择一个初始的点,如点$(1, 1)$,即$x_0 = 1, x_1 = 1$,函数的梯度向量为

$$ \nabla f = (2x, 2y) $$

带入初始值后,得到$\nabla f = (2, 2)$,步长k的取值如果过大,可能会使得自变量增量过大,导致跨越过最小值点,因此,适当的较小的$k$值是获得正确解答的必要条件,这里取$k=0.1$,由此得到自变量变化量为$(-0.2, -0.2)$,即下一个新的点为$(1 - 0.2, 1 - 0.2)$,这一算法的意义为,通过梯度向量计算得到的自变量变化量,就是函数下降最快的方向,以此反复计算,知道前后计算步变化误差满足精度要求即可。
对于这个函数,$x_0 = 1, y_0 = 1$,$x_1 = 0.8, y_1 = 0.8$, $x_2 = 0.64, y_2 = 0.64$,$x_3 = 0.516, y_3 = 0.516$

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# 函数图像
x = np.linspace(-2, 2, 50)
y = np.linspace(-2, 2, 50)
xx, yy = np.meshgrid(x, y)
z = xx**2 + yy**2

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
cb = ax.plot_surface(xx, yy, z, cmap='RdBu_r', alpha=0.7)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
fig.colorbar(cb, ax=ax)
ax.set_title('$z = x^2 + y^2$')

ax.plot([1, 0.8, 0.64, 0.516],
[1, 0.8, 0.64, 0.516],
[2, 1.28, 0.8192, 0.532512],
color='k', linewidth=2)

fig.savefig('gradient.jpg', dpi=600)
plt.show()

png

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# 函数图像
x = np.linspace(-2, 2, 50)
y = np.linspace(-2, 2, 50)
xx, yy = np.meshgrid(x, y)
z = xx**2 + yy**2

plt.style.use('ggplot')
fig,ax = plt.subplots(1,1)
cs = ax.contour(xx, yy, z, cmap='viridis')
ax.clabel(cs, inline=True, fontsize=12)
ax.set_xlabel('x')
ax.set_ylabel('y')
fig.colorbar(cb, ax=ax)
ax.set_title('$z = x^2 + y^2$')

ax.plot([1, 0.8, 0.64, 0.516], [1, 0.8, 0.64, 0.516], linewidth=2)

fig.savefig('contour.jpg', dpi=600)
plt.show()

contour

补充一点

这里为了显示出梯度下降的过程,对之前那个函数进行编程实现,获得每一步的自变量、应变量值,并进行可视化。

import numpy as np

f = lambda x,y: x**2 + y**2
x0 = 4
y0 = 5
fx = lambda x: 2 * x
fy = lambda y: 2 * y
step = 0.1
max_time = 400
tol = 0.00001
x = np.zeros(max_time)
y = np.zeros(max_time)
z = np.zeros(max_time)
number = 0
while f(x0, y0) > tol:
    dx, dy = np.array([fx(x0), fy(y0)])
    x0 = x0 - dx * step
    y0 = y0 - dy * step
    z0 = f(x0, y0)
    x[number] = x0
    y[number] = y0
    z[number] = z0
    number += 1
    if number > max_time:
        break

print(number)
print(x[0:number])
print(y[0:number])
print(z[0:number])

结果为

35
[3.20000000e+00 2.56000000e+00 2.04800000e+00 1.63840000e+00
 1.31072000e+00 1.04857600e+00 8.38860800e-01 6.71088640e-01
 5.36870912e-01 4.29496730e-01 3.43597384e-01 2.74877907e-01
 2.19902326e-01 1.75921860e-01 1.40737488e-01 1.12589991e-01
 9.00719925e-02 7.20575940e-02 5.76460752e-02 4.61168602e-02
 3.68934881e-02 2.95147905e-02 2.36118324e-02 1.88894659e-02
 1.51115727e-02 1.20892582e-02 9.67140656e-03 7.73712525e-03
 6.18970020e-03 4.95176016e-03 3.96140813e-03 3.16912650e-03
 2.53530120e-03 2.02824096e-03 1.62259277e-03]
[4.00000000e+00 3.20000000e+00 2.56000000e+00 2.04800000e+00
 1.63840000e+00 1.31072000e+00 1.04857600e+00 8.38860800e-01
 6.71088640e-01 5.36870912e-01 4.29496730e-01 3.43597384e-01
 2.74877907e-01 2.19902326e-01 1.75921860e-01 1.40737488e-01
 1.12589991e-01 9.00719925e-02 7.20575940e-02 5.76460752e-02
 4.61168602e-02 3.68934881e-02 2.95147905e-02 2.36118324e-02
 1.88894659e-02 1.51115727e-02 1.20892582e-02 9.67140656e-03
 7.73712525e-03 6.18970020e-03 4.95176016e-03 3.96140813e-03
 3.16912650e-03 2.53530120e-03 2.02824096e-03]
[2.62400000e+01 1.67936000e+01 1.07479040e+01 6.87865856e+00
 4.40234148e+00 2.81749855e+00 1.80319907e+00 1.15404740e+00
 7.38590339e-01 4.72697817e-01 3.02526603e-01 1.93617026e-01
 1.23914897e-01 7.93055338e-02 5.07555416e-02 3.24835466e-02
 2.07894698e-02 1.33052607e-02 8.51536685e-03 5.44983478e-03
 3.48789426e-03 2.23225233e-03 1.42864149e-03 9.14330553e-04
 5.85171554e-04 3.74509795e-04 2.39686269e-04 1.53399212e-04
 9.81754956e-05 6.28323172e-05 4.02126830e-05 2.57361171e-05
 1.64711150e-05 1.05415136e-05 6.74656869e-06]

可视化编程:


from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

xi = np.linspace(-6, 6, 50)
yi = np.linspace(-6, 6, 50)
xx, yy = np.meshgrid(xi, yi)
zi = xx**2 + yy**2
fig = plt.figure()
ax = Axes3D(fig)
ax.plot_surface(xx, yy, zi, cmap='RdBu_r', alpha=0.5)
ax.plot(x[0:number], y[0:number], z[0:number], color='tab:blue')
fig.savefig('demo.jpg',dpi=600)
plt.show()

demo

Last Modified: October 30, 2019
Archives Tip
QR Code for this page
Tipping QR Code