본문 바로가기

Python/AI 수학 with Python

[Python] 경사하강법 구현

[In]

# 경사하강법 구하기
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

def my_func(x): # 최솟값을 구하는 함수
    return x**2  - 2*x

def grad_func(x): # 도함수
    return 2*x - 2

lr = 0.1 # 학습계수
x = 4.0 # 초깃값 설정

record_x = []
record_y = []
for i in range(20): # x를 20번 갱신
    y = my_func(x)
    record_x.append(x)
    record_y.append(y)
    x -= lr * grad_func(x) # 경사하강법

x_f = np.linspace(-2, 4, 1000)
y_f = my_func(x_f)

plt.plot(x_f, y_f, linestyle = 'dashed') # 함수를 점선으로 표시
plt.scatter(record_x, record_y) # x와 y의 기록을 표시

plt.xlabel('x', size = 14)
plt.ylabel('y', size = 14)
plt.show()

[Out]

f(x)가 계속 감소하다가 기울기가 0일 때 멈추는 것을 확인할 수 있다.

 

# 지역최소점과 전역최소점
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt

def my_func(x): # 최솟값을 구하는 함수
    return x**4 + 2*x**3 - 3*x**2 - 2*x

def grad_func(x): # 도함수를 구하는 함수
    return 4*x**3 + 6*x**2 - 6*x -2

lr = 0.01 # 학습계수
x = 1.6 # x의 초깃값

record_x = [] # x에 대한 기록
record_y = [] # y에 대한 기록

for i in range(20):
    y = my_func(x)
    record_x.append(x)
    record_y.append(y)
    x -= lr*grad_func(x) # 최소하강법
    
x_f = np.linspace(-2.8, 1.6, 1000)
y_f = my_func(x_f)

plt.plot(x_f, y_f, linestyle = 'dashed')
plt.scatter(record_x, record_y) # x와 y의 기록을 표시

plt.xlabel('x', size = 14)
plt.ylabel('y', size = 14)
plt.grid()

plt.show()

[Out]

결과를 확인해보면, f(x)가 계속 감소하다가 기울기가 0일 때 더이상 진행이 되지 않고 지역최소점의 트랩에 빠지게 된다. 이러한 결과를 방지하기 위해, 초기값 x를 적절하게 선택을 하거나, 랜덤하게 선택하는 방향으로 전역최소점에 다다르도록 조절한다.

 

다음 예시를 보며 전역최소점일 때를 구해보자.

[In]

# 지역최소점의 트랩을 회피하도록 코드를 설정
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

def my_func(x): # 최솟값을 구하는 함수
    return x**4 - 2*x**3 - 3*x**2 + 2*x

def grad_func(x): # 도함수
    return 4*x**3 - 6*x**2 - 6*x + 2

lr = 0.01 # 학습계수
x = 1.0 # 초깃값을 적절하게 설정

record_x = []
record_y = []

for i in range(20):
    y = my_func(x)
    record_x.append(x)
    record_y.append(y)
    x -= lr * grad_func(x)
    
x_f = np.linspace(-1.6, 2.8, 1000)
y_f = my_func(x_f)

plt.plot(x_f, y_f, linestyle = 'dashed')
plt.scatter(record_x, record_y)

plt.xlabel('x', size = 14)
plt.ylabel('y', size = 14)
plt.grid()

[Out]

초깃값 x를 1.0으로 선택함으로써, 함수가 전역최소점에 다다를 수 있게 된 것을 확인할 수 있다.