본문으로 건너뛰기

Forward Backward & Chain Rule

· 4분 읽기

딥러닝의 핵심 메커니즘 중 하나인 순전파와 역전파 과정을 수식적으로 깊이 있게 다뤄보자.

먼저, 신경망의 역전파를 이해하기 위해 필요한 두 가지 중요한 미적분학 관련 개념인 편미분과 합성 함수의 미분에 대해 간단히 알아보자.

우선 편미분은 다변수 함수에서 한 변수의 변화가 함수에 미치는 영향을 측정하기 위해 사용된다. 다른 모든 변수는 상수로 간주하면서 해당 변수에 대해서만 미분을 진행하는 것이다.

예를 들어, 함수 f(x,y)=x2y+3xy+y2f(x, y) = x^2 y + 3xy + y^2 에서 𝑥에 대한 편미분을 계산하면, 𝑦를 상수로 간주하고 𝑥에 대해서만 미분을 수행한다. 수식으로 표현하면 아래와 같다.

fx=2xy+3y\frac{\partial f}{\partial x} = 2xy + 3y

이번에는 𝑥를 상수로 간주하고 𝑦에 대해서만 미분을 해 본다. 아래처럼 결과가 나온다.

fy=x2+3x+2y\frac{\partial f}{\partial y} = x^2 + 3x + 2y

합성 함수의 미분은 복잡한 함수들이 서로 결합되어 있을 때 사용된다. 이는 체인룰(Chain Rule)이라고 하는 규칙을 사용하여 외부 함수의 미분과 내부 함수의 미분을 곱해 전체 미분을 구하는 방식인데.

예를 들어, 𝑦=𝑓(𝑔(𝑥))의 형태를 갖는 함수에서 𝑥에 대한 𝑦의 미분을 구하면, 𝑔(𝑥)의 미분과 𝑓(𝑔)의 미분을 곱하여 아래와 같이 계산한다.

dydx=dfdgdgdx\frac{dy}{dx} = \frac{df}{dg} \cdot \frac{dg}{dx}

신경망은 이렇게 복잡한 합성 함수로 구성되어 있다고 이해할 수 있다. 신경망에서 각 레이어는 가중합 연산과 활성화 함수를 포함하고 있으며, 이 두 요소가 결합되어 하나의 함수를 형성한다. 이렇게 각 레이어를 거치면서 다양한 함수들이 결합되어, 전체 신경망은 연속적인 합성 함수로 구성된다고 볼 수 있다. 이러한 구조 때문에 신경망을 미분할 때는 합성 함수의 미분법을 적용한다.

이어서 역전파, 순전파 과정에서의 계산 과정을 알아보겠다. 흔히 딥 러닝 모델에서 데이터가 처리되는 방식은 크게 순전파(Forward Propagation)와 역전파(Backpropagation) 두 과정으로 나뉘는데요. 각 개념을 간단히 복습해 보겠다.

순전파는 입력 데이터가 네트워크의 처음부터 끝까지 흘러가는 과정이다. 이 과정에서 각 레이어는 입력 데이터에 대해 선형 변환(예: 가중치와 편향을 이용한 계산)과 활성화 함수(예: ReLU, Sigmoid)를 적용한다. 이러한 계산을 통해 입력 데이터는 최종적으로 출력층에 도달하고, 모델은 예측 결과를 생성한다.

역전파는 모델의 예측 결과와 실제 값과의 차이(손실)를 계산한 후, 이 손실 값을 사용하여 모델의 각 파라미터(가중치)를 조정하는 과정이다. 손실 함수를 통해 계산된 손실을 기반으로, 앞서 배운 Chain Rule을 이용해 각 레이어와 뉴런의 가중치에 대한 손실의 미분값(기울기)을 계산하고, 이 기울기를 사용하여 가중치를 업데이트한다.

계산 과정을 알아보기 위해, 일단 간단한 예시를 들어보겠다. 입력 데이터는 xx, 각 층에서 적용할 가중치 값은 w1w_1w2w_2 이고. 일단 입력 xx에 가중치w1w_1을 곱한 후 시그모이드 함수를 적용하여 출력값 a1a_1을 생성한다.

z1=w1xz_1 = w_1 \cdot x
a1=σ(z1)a_1 = \sigma(z_1)

이 출력값 a1a_1 은 다음 층으로 전달된다. 이게 출력층이라고 하면, a1a_1 에 가중치 w2w_2 를 곱하여 최종적으로 예측값 ypredy_{pred} 를 계산하게 되는데. 이 과정을 통해 입력 데이터는 초기 단계에서 처리되어 최종적인 예측 결과를 생성하게 된다.

z2=w2a1z_2 = w_2 \cdot a_1
ypred=σ(z2)y_{\text{pred}} = \sigma(z_2)

다음에는 역전파 과정에 대해서도 알아보겠다. 우리가 현재 고려하고 있는 단층 퍼셉트론에서는 순전파 과정을 통해 입력 데이터를 받아 최종 예측값 ypredy_{pred}를 생성하고, 이 값을 사용하여 신경망의 성능을 평가하는데.

이때 손실 함수로는 아래와 같은 평균 제곱 오차(Mean Squared Error, MSE)를 사용해 보겠다.

L=12(ypredytrue)2L = \frac{1}{2}(y_{\text{pred}} - y_{\text{true}})^2

신경망의 주된 목표는 이 손실 함수의 값을 최소화하는 것이며, 이를 위해 역전파 과정을 통해 가중치 w1w_1w2w_2 를 업데이트한다.

역전파 과정은 순전파의 반대 방향으로 진행되며, 이 과정에서 가중치에 대한 손실 함수의 그래디언트(미분값)를 계산하는데. Chain Rule을 사용하여 이 그래디언트를 계산하고, 계산된 그래디언트는 가중치의 업데이트를 위해 사용된다.

구체적으로, 손실 함수 𝐿에 대해 w2w_2의 그래디언트를 먼저 계산하고, 이어서 w1w_1의 그래디언트를 계산한다. 이러한 그래디언트 계산은 각 가중치의 영향을 평가하고, 가중치를 조정함으로써 손실을 줄이는 방향으로 신경망을 조정하는 데 중요한 역할을 한다.

이렇게 역전파를 통한 가중치 업데이트는 신경망이 학습하고 성능을 개선하도록 돕는다. 먼저, 위 손실함수를 ypredy_{pred} 에 대해 미분한 L/ypred{\partial L}/{\partial y_{\text{pred}}} 값은 아래와 같다. 우리는 이 값을 Chain Rule 적용 과정에서 사용할 것이다. Lypred=(ypredytrue)\frac{\partial L}{\partial y_{\text{pred}}} = (y_{\text{pred}} - y_{\text{true}})

역전파 과정에서, 출력층의 가중치 w2w_2 에 대한 손실 함수의 그래디언트는 Chain Rule을 사용하여 계산된다.

Lw2=Lypredypredz2z2w2\frac{\partial L}{\partial w_2} = \frac{\partial L}{\partial y_{\text{pred}}} \cdot \frac{\partial y_{\text{pred}}}{\partial z_2} \cdot \frac{\partial z_2}{\partial w_2}

가중치 w1w_1에 대한 그래디언트도 비슷한 방식으로 계산된다. 이 그래디언트는 다음과 같은 세 개의 미분값을 연쇄적으로 곱해서 구할 수 있는데요. 각각z2/a1∂z_2/∂a_1w2w_2 , z1/w1∂z_1/∂w_1은 입력 xx, a11/z11∂a_11/∂z_11σ(z1)σ′(z1) 와 같다. 이 모든 미분값들을 연쇄적으로 곱하는 것으로 w1w_1 에 대한 그래디언트를 계산하며, 이는 모델의 가중치를 업데이트하는 데 사용된다.

Lw1=Lypredypredz2z2a1a1z1z1w1\frac{\partial L}{\partial w_1} = \frac{\partial L}{\partial y_{\text{pred}}} \cdot \frac{\partial y_{\text{pred}}}{\partial z_2} \cdot \frac{\partial z_2}{\partial a_1} \cdot \frac{\partial a_1}{\partial z_1} \cdot \frac{\partial z_1}{\partial w_1}

이제 본격적으로 경사 하강법을 통해 w1w_1w2w_2 를 업데이트해 보겠다. 경사 하강법은 머신 러닝과 딥 러닝에서 널리 사용되는 최적화 기법으로, 모델의 손실 함수를 최소화하기 위해 가중치를 반복적으로 조정하는 방법이다. 이 과정에서 계산된 각 가중치에 대한 손실 함수의 그래디언트를 사용하여 가중치를 업데이트한다.

w=wηLww = w - \eta \cdot \frac{\partial L}{\partial w}

위 공식을 이용해 위에서 계산한 각 가중치의 그래디언트에 대해서 업데이트 하는 과정은 아래와 같다. 참고로 w1w_1 도 동일하게 적용된다고 보면 된다.

w2=w2ηLw2w_2 = w_2 - \eta \cdot \frac{\partial L}{\partial w_2}

Lw2(ypredytrue)σ(z2)a1\frac{\partial L}{\partial w_2} (y_{\text{pred}} - y_{\text{true}}) \cdot \sigma'(z_2) \cdot a_1

위 과정을 통해 각 가중치 w1w_1w2w_2 는 주어진 데이터에 대해 손실을 최소화하는 방향으로 조정된다.