[In]
# 경사하강법 구하기
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
def my_func(x): # 최솟값을 구하는 함수
return x**2 - 2*x
def grad_func(x): # 도함수
return 2*x - 2
lr = 0.1 # 학습계수
x = 4.0 # 초깃값 설정
record_x = []
record_y = []
for i in range(20): # x를 20번 갱신
y = my_func(x)
record_x.append(x)
record_y.append(y)
x -= lr * grad_func(x) # 경사하강법
x_f = np.linspace(-2, 4, 1000)
y_f = my_func(x_f)
plt.plot(x_f, y_f, linestyle = 'dashed') # 함수를 점선으로 표시
plt.scatter(record_x, record_y) # x와 y의 기록을 표시
plt.xlabel('x', size = 14)
plt.ylabel('y', size = 14)
plt.show()
[Out]
f(x)가 계속 감소하다가 기울기가 0일 때 멈추는 것을 확인할 수 있다.
# 지역최소점과 전역최소점
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
def my_func(x): # 최솟값을 구하는 함수
return x**4 + 2*x**3 - 3*x**2 - 2*x
def grad_func(x): # 도함수를 구하는 함수
return 4*x**3 + 6*x**2 - 6*x -2
lr = 0.01 # 학습계수
x = 1.6 # x의 초깃값
record_x = [] # x에 대한 기록
record_y = [] # y에 대한 기록
for i in range(20):
y = my_func(x)
record_x.append(x)
record_y.append(y)
x -= lr*grad_func(x) # 최소하강법
x_f = np.linspace(-2.8, 1.6, 1000)
y_f = my_func(x_f)
plt.plot(x_f, y_f, linestyle = 'dashed')
plt.scatter(record_x, record_y) # x와 y의 기록을 표시
plt.xlabel('x', size = 14)
plt.ylabel('y', size = 14)
plt.grid()
plt.show()
[Out]
결과를 확인해보면, f(x)가 계속 감소하다가 기울기가 0일 때 더이상 진행이 되지 않고 지역최소점의 트랩에 빠지게 된다. 이러한 결과를 방지하기 위해, 초기값 x를 적절하게 선택을 하거나, 랜덤하게 선택하는 방향으로 전역최소점에 다다르도록 조절한다.
다음 예시를 보며 전역최소점일 때를 구해보자.
[In]
# 지역최소점의 트랩을 회피하도록 코드를 설정
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
def my_func(x): # 최솟값을 구하는 함수
return x**4 - 2*x**3 - 3*x**2 + 2*x
def grad_func(x): # 도함수
return 4*x**3 - 6*x**2 - 6*x + 2
lr = 0.01 # 학습계수
x = 1.0 # 초깃값을 적절하게 설정
record_x = []
record_y = []
for i in range(20):
y = my_func(x)
record_x.append(x)
record_y.append(y)
x -= lr * grad_func(x)
x_f = np.linspace(-1.6, 2.8, 1000)
y_f = my_func(x_f)
plt.plot(x_f, y_f, linestyle = 'dashed')
plt.scatter(record_x, record_y)
plt.xlabel('x', size = 14)
plt.ylabel('y', size = 14)
plt.grid()
[Out]
초깃값 x를 1.0으로 선택함으로써, 함수가 전역최소점에 다다를 수 있게 된 것을 확인할 수 있다.