Autodiff/Backpropagation
์ด ๊ธ์ ์ฌ์ธต์ ๊ฒฝ๋ง ์์ ์์ ๊ณต๋ถํ ๋ด์ฉ์ ๊ธฐ๋ฐํ์ง๋ง, ์ ๊ฐ ๋๋ฆ๋๋ก ์ดํดํ ๋ฐ๋ฅผ ๋ง๋ถ์ฌ์ ์์ฑํ์ต๋๋ค. ํนํ, ์ค๋ช ํ๋ ๋ฐฉ๋ฒ์ด ์กฐ๊ธ ๋ค๋ฆ ๋๋ค.
Introduction
Multi layer perceptron์ด ๋์๋ , Convolutionary neural network๊ฐ ๋์๋ ๊ธฐ๋ณธ์ ์ธ ํ์ logistic regression๊ณผ ๋ค๋ฅผ ๊ฒ์ด ์์ต๋๋ค.
์ค๋ช ์ ํธ์๋ฅผ ์ํด ์ด ๊ธ์์๋ MLP์ ๋ํด ์ค๋ช ํ๊ฒ ์ง๋ง, CNN๋ ์ฌ์ค์ ๋ณ๋ก ๋ค๋ฅด์ง ์์ต๋๋ค.
- MLP๋ ๊ฒฐ๊ตญ, ๋ค์๊ณผ ๊ฐ์ ํํ์ ํจ์๋ก ๋ํ๋๋ ํ๊ท ๋ชจํ์ด๋ผ๊ณ ๋ณผ ์ ์์ต๋๋ค. \(\begin{align*} y_L &= W_L y_{L-1} + b_L \\ y_{L - 1} &= \sigma(W_{L-1} y_{L - 2} + b_{L - 1}) \\ \cdots & \cdots \\ y_2 &= \sigma (W_2 y_1 + b_2) \\ y_1 &= \sigma (W_1 x + b_1) \end{align*}\)
- ์ฌ๊ธฐ์ $W_i, b_i$ ๋ค์ ๋ชจ๋ trainable weight ์ด๊ณ , $\sigma$๋ ์ด๋ค activation function ์ ๋๋ค.
- ์ฐ๋ฆฌ๋, $y_L$์ ์ฐธ๊ฐ (์ด๋ผ๊ณ ๋งํ๋ฉด ์ข ์ ๋งคํ์ง๋งโฆ) ์ ๋ฐํํ๋ ํจ์ $\tilde{y_L} = f(x)$ ๊ฐ ์กด์ฌํ๋ค๊ณ ์๊ฐํฉ๋๋ค. ์ด๋ฅผ ์ต๋ํ ๊ทผ์ฌ ํ๋ ๊ฒ์ด ๋ชฉํ์ ๋๋ค. ์ฆ, ์ ์ ํํ์ ํจ์ ($W_i, b_i$ ๋ฅผ ์ด์ฉํ์ฌ ํํ๋๋ ํจ์) ๋ฅผ ํํ ๊ฐ๋ฅํ๋ค ๋ผ๊ณ ์ ์ํ๋ฉด, Ground-truth ํจ์์ ๊ฐ์ฅ ๊ฐ๊น์ด ํํ๊ฐ๋ฅํ ํจ์ ๋ฅผ ์ฐพ๊ณ ์ถ์ต๋๋ค.
- ๊ทธ๋ฌ๋ ์ฐ๋ฆฌ๋ ground truth๋ฅผ ๋ชจ๋ ์๋๊ฒ ์๋๋ผ, ๋ช๋ช ๋ฐ์ดํฐ $x^1, x^2, \dots$ ์ ๋ํด ์๊ณ ์์ต๋๋ค.
๋ฐ๋ผ์, ์ด๋ค Loss function์ ์ ์ํ์ฌ \(\sum_{i = 1}^{n} \mathcal{L}(y_L^{i}, \tilde{y_L}^{i})\) ์ ์ ์ํ ๋ค์, ์ด $\mathcal{L}$ ์ด ์ด๋ค ์ค์ $\tilde{y_L}^i$ ๊ณผ $y_L^i$ ๊ฐ์ ๊ฑฐ๋ฆฌ๋ฅผ ์ ์ํ๋ฏ๋ก, ์ด๋ฅผ ๊ฐ๋ฅํ ์ต์ํํ๋ ๋ฐฉํฅ์ผ๋ก ๋์๊ฐ๋ ค๊ณ ํฉ๋๋ค. - ๊ทธ๋ฌ๋ฏ๋ก, ์ฐ๋ฆฌ๋ ์ฌ๊ธฐ์ SGD ๋๋ ๊ทธ ๋น์ทํ ์๊ณ ๋ฆฌ์ฆ๋ค์ ์ฌ์ฉํฉ๋๋ค. ์ฆ, $W_k, b_k$ ํ๋ ฌ ๋๋ ๋ฒกํฐ์ ๋ค์ด ์๋ ๊ฐ ๋ณ์ $W_k(i, j)$ ๋ $b_k(i)$ ๋ฅผ ์ด์ฉํด์ ์ ์ฒด ๊ณต๊ฐ์์์ Loss function์ ๊ทธ๋ ค๋๊ณ , ๊ทธ minimum์ (iterativeํ๊ฒ) ์ฐพ์ ์ ์๊ธฐ๋ฅผ ๋ฐ๋๋๋ค.
- SGD๋ ๋ค๋ฅธ ๋ฐฉ๋ฒ์ ์ฐ๋ ค๋ฉด, ๊ฒฐ๊ตญ์ ์ด๋ฐ ๋๋์ ํธ๋ฏธ๋ถ๊ณ์๋ค์ ๊ผญ ์์์ผ ํฉ๋๋ค. \(\pdv{\mathcal{L}}{W_k(i, j)} \quad \quad \pdv{\mathcal{L}}{b_k(i)}\)
๋ชจ๋ธ์ด ๊ฐ๋จํ๋ฉด ๋ญ ์ง์ ๋ฏธ๋ถํ๋ค๊ณ ์น์ง๋ง, ์์ ์๋ MLP ์ ๊ฐ์ด ์๊ธด ๋ณต์กํ ํจ์๋ฅผ ์ด๋ป๊ฒ ๋ฏธ๋ถํ ์ ์์๊น์?
Backpropagation
์ ํํ๋ฅผ ์ ๋ณด๋ฉด, ํฉ์ฑํจ์ ํํ์์ ์ ์ ์์ต๋๋ค. ํฉ์ฑํจ์์ ๋ฏธ๋ถ์ Chain rule์ ์ด์ฉํด์ ์ํํ ์ ์์ต๋๋ค.
$x \in \R^m, y \in \R^n$, $g : \R^m \to \R^n, f : \R^n \to \R$ ์ ๋์ ์ธํ
์ ์๊ฐํด ๋ด
์๋ค. $\mathbf{y} = g(\mathbf{x}), z = \mathbf{y}$ ๋ผ ํ ๋, ๋ค์์ด ์ฑ๋ฆฝํฉ๋๋ค.
\(\pdv{z}{x_i} = \sum_{j} \pdv{z}{y_j} \pdv{y_j}{x_i}\)
์ด ๋ฐฉ๋ฒ์ ์ด์ฉํด์, ์ฐ๋ฆฌ๋ ์ ์ฒด $P$๊ฐ์ ๋ชจ๋ ํ๋ผ๋ฏธํฐ์ ๋ํด $\pdv{\mathcal{L}}{w_i}$ ๋ฅผ ๊ตฌํด์ผ ํฉ๋๋ค.
์ด๋ฅผ ๊ณ์ฐํ๊ธฐ ์ํด, ๋จผ์ Computational graph๋ฅผ ๋ง๋ญ๋๋ค. Computational graph๋, ์๋์ ๊ฐ์ด ๊ฐ ๊ฐ๋ค์ ๋
ธ๋๋ก, ๊ณ์ฐ์ ํ์ํ dependency๋ค์ edge๋ก ์ฐ๊ฒฐํด์ ๊ทธ๋ํ ํํ๋ก ๋ง๋ ๊ฒ์
๋๋ค.
(์ฌ์ง์ถ์ฒ : ์์ธ๋ํ๊ต ์ฌ์ธต์ ๊ฒฝ๋ง์ ์ํ์ ๊ธฐ์ด ๊ฐ์์๋ฃ)
์ฌ๊ธฐ์, โ๋ณ์๋ฅผ ๋ค๋ฅธ ๋ณ์๋ก ๋ฏธ๋ถํ ๋ฏธ๋ถ๊ณ์โ ๋ค์ ๊ตฌํ๊ณ , โ์ต์ข ๊ฒฐ๊ณผ๋ฅผ ๋ณ์๋ก ๋ฏธ๋ถํ ๋ฏธ๋ถ๊ณ์โ ๋ฅผ ๊ทธ ๊ฒฐ๊ณผ๋ก ์ป์ ๊ฒ์ ๋๋ค. ๊ตฌ์ฒด์ ์ผ๋ก,
- ๊ฐ edge์ ๋ํด, ๋ณ์๋ฅผ ๋ณ์๋ก ๋ฏธ๋ถํ ์ค๊ฐ ๋ฏธ๋ถ๊ณ์๋ฅผ edge์ ์ ์ด๋ฃ๊ณ ,
- ๋ง์ง๋ง์, root ๋ ธ๋ (i.e, ๊ณ์ฐ์ ์ต์ข ๊ฐ) ์์ ์ถ๋ฐํด์, ์์์ ๋ ธ๋๊น์ง ๊ฐ๋ ๊ฒฝ๋ก๋ฅผ ๋ชจ๋ ๋ฐ๋ผ๊ฐ๋ฉด์ ๊ณฑํด์ ๋ํ๋ฉด โ์ต์ข ๊ฒฐ๊ณผ๋ฅผ ๋ณ์๋ก ๋ฏธ๋ถํโ ๋ฏธ๋ถ๊ณ์๋ฅผ ์ป์ต๋๋ค.
์ฆ, ์๊ณ ๋ฆฌ์ฆ์ ์ธ์ด๋ก ๋งํ์๋ฉด DAG ์์์ depth๊ฐ ๋ฎ์ ๋ ธ๋๋ถํฐ ๊ฑฐ๊พธ๋ก ์ฌ๋ผ๊ฐ๋ฉด์ edge์ ๊ฐ์ ๊ณ์ฐํ๊ณ (DP), ๋์์ฌ๋๋ topological order๋ก ๊ณ์ฐํ๊ฒ ๋ค๋ ์๋ฏธ์ ๋๋ค.
์ด ์ฌ์ง์ ์๋ ํจ์๋ฅผ ์ง์ ๊ณ์ฐํ๋ฉด์ ๊ณผ์ ์ ๋ฐ๋ผ๊ฐ ๋ณด๊ฒ ์ต๋๋ค.
- step์ด ๋ฎ์ ๊ฒ๋ถํฐ ์ฌ๋ผ๊ฐ๋๋ค. ์ฆ, ์ฒ์์๋ step 1์ธ $a$-๋ ธ๋์ ๋ฏธ๋ถ๊ณ์๋ค์ ๊ณ์ฐํ๊ธฐ ์ํด $\pdv{a}{x}$ ๋ฅผ ๊ตฌํ๋ฉฐ, ์ด ๊ฐ์ $1/x$ ์ด๋ฏ๋ก 1/3์ ๋๋ค. ์ฌ๊ธฐ์ ์ฃผ๋ชฉํ ์ ์, ์ผ๋ฐ์ ์ธ symbolic differntiation์ ์ํํ ๋๋ $1/x$๋ฅผ ๋ค๊ณ ๊ฐ์ง๋ง, ์ฐ๋ฆฌ๋ ์ด์ฐจํผ ์ต์ข ์ ์ผ๋ก ์์น์ฐ์ฐ์ ํ ๊ฒ์ด๋ฏ๋ก $1/3$ ์ด๋ผ๋ ์ฌ์ค๋ง ๊ธฐ์ตํ๋ฉด $1/x$ ๋ผ๋ ๊ฐ์ ์์ด๋ฒ๋ ค๋ ๋ฉ๋๋ค. ์ด ๊ฐ์ edge์ ์ ์ด ๋ฃ์ต๋๋ค. ๋ํ ์ดํ์ $a$๊ฐ๋ ํ์ํ๊ธฐ ๋๋ฌธ์ $a = \log 3$ ์ด๋ผ๋ ๊ฒฐ๊ณผ๋ฅผ ๋ ธ๋์ ์ ์ด๋ฃ์ต๋๋ค. ์ด์ , step 1๊น์ง ์์ต๋๋ค.
- step 2์ ํด๋นํ๋ $b$๋ฅผ ๊ตฌํด์ผ ํฉ๋๋ค. $\pdv{b}{a} = y, \pdv{b}{y} = a$ ์ด๋ฉฐ, ์ด๋ ๊ฐ๊ฐ $a, y$์ ์ด๋ฏธ ๊ณ์ฐํ ๋ ธ๋๊ฐ ์ ์ฐธ์กฐํด์ ๊ณ์ฐํ ์ ์์ต๋๋ค. ๊ฐ๊ฐ $2, \log 3$ ์ด ๋ ๊ฒ์ด๋ฉฐ, ์ด๋ฅผ edge์ ์ ์ด ๋ฃ์ต๋๋ค. $b$ ๋ $2 \log 3$ ์ด๊ณ . ์ด๊ฑด ๋ ธ๋์ ์ ์ด๋ฃ์ต๋๋ค.
- step 3์ ํด๋นํ๋ $\pdv{c}{b}$ ๋ $\frac{1}{2\sqrt{b}} = \frac{1}{2\sqrt{\log 3}}$ ์ ๋๋ค. $c = \sqrt{log 3}$ ์ ๋๋ค.
- step 4๋ ๋ง์ง๋ง์ผ๋ก, $\pdv{f}{c} = 1$, $\pdv{f}{b} = 1$ ์ด๋ฉฐ, $f$ ์ ์ต์ข ์ ์ธ ๊ฐ์ $\sqrt{2 \log 3} + 2 \log 3$ ์ ๋๋ค.
์ฌ๊ธฐ๊น์ง๊ฐ ์ง๊ธ 1 ๊ณผ์ ์ด ๋๋ ๊ฒ์
๋๋ค. ์ด๋ฅผ โForward passโ ๋ฑ์ผ๋ก ๋ถ๋ฆ
๋๋ค. ์ฌ๊ธฐ๊น์ง ๊ณ์ฐํ ๊ฒฐ๊ณผ๋ ์๋์ ๊ฐ์ต๋๋ค.
์ด์ , ๋ค์ ๊ฑฐ๊พธ๋ก ๋์๊ฐ๋ฉด์ ๊ณ์ฐํฉ๋๋ค.
- $\pdv{f}{c} = 1$. ์ด๋ฒ์๋, $c$์ ํด๋นํ๋ ๋ ธ๋์ ์ด ๊ฐ์ ์ ์ด๋ฃ์ต๋๋ค.
- $\pdv{c}{b} = \frac{1}{2 \sqrt{2 \log 3}}$ ์ด๋ฉฐ, $\pdv{f}{b}$ ๋ ์ฌ๊ธฐ์ ๋ฐ๋ก $b$๊ฐ $f$์ ์ํฅ์ ๋ฏธ์น๋ 1์ด ์์ผ๋ฏ๋ก (๊ฐ์ฅ ์๋ edge), $\pdv{f}{b} = \frac{1}{2 \sqrt{2 \log 3}} + 1$ ์ ๋๋ค. ๋ง์ฐฌ๊ฐ์ง๋ก $b$ ๋ ธ๋์ ์ ์ด ๋ฃ์ต๋๋ค.
- ๊ฐ์ ๋ฐฉ๋ฒ์ผ๋ก ๋ค๋ก ๊ณ์ ๋ฌ๋ฆฝ๋๋ค. $\pdv{f}{a} = \frac{1}{\sqrt{2 \log 3}} + 2$.
- ๊ฒฐ๊ตญ ๋ค ๊ณ์ฐํ๋ฉด, $\pdv{f}{x} = \frac{1}{3\sqrt{2\log 3}} + \frac{2}{3}$ ๊ณผ $\pdv{f}{y} = \sqrt{\frac{\log 3}{8}} + \log 3$ ์ด ๋จ์ ๊ฒ์ ๋๋ค. ์ด ๋ฐฉ๋ฒ์ Backpropagation์ด๋ผ๊ณ ๋ถ๋ฆ ๋๋ค.
Notes
- Autodiff๋, ์ฌ์ค์ forward autodiff ๋ฑ ๋ช๊ฐ์ง ๋ฐฉ๋ฒ์ด ๋ ์์ต๋๋ค. ๊ทธ์ค ๊ฐ์ฅ ๋ํ์ ์ธ ๋ฐฉ๋ฒ์ธ backpropagation์ ์๊ฐํ๋๋ฐ, ์ฌ์ค ์๊ฐํด ๋ณด๋ฉด ๋ฐ๋๋ก forward pass๋ง์ผ๋ก ๊ณ์ฐํ๋ ๋ฐฉ๋ฒ์ด ์์ต๋๋ค. ์ ๊ณผ์ ์์ ์๋ก ์ฌ๋ผ๊ฐ๋ DP๋ฅผ ํ ๋, $\pdv{b}{a}$ ๊ฐ์ ๊ฐ๋ค์ ๊ณ์ฐํด์ edge์ ์ ์ด๋๊ณ , ๋ฐ๋ก $b$ ๋ ธ๋์๋ $\pdv{b}{x}, \pdv{b}{y}$ ๋ฅผ ๊ทธ์๋ฆฌ์์ ๊ณ์ฐํด์ (์์ $\pdv{a}{x}$ ๋ ๋ ธ๋์ ์ ์ด๋จ์ ๊ฒ์ด๋ฏ๋ก) ๊ธฐ์ตํ๋ ๋ฐฉ๋ฒ์ด ์์ต๋๋ค. ์ด๋ ๊ฒ ๊ณ์ฐํ๋ฉด ํ๋ฒ forward๋ฅผ ๋ฌ๋ฆด ๋ ๋ชจ๋ ๊ณ์ฐ์ด ๋๋ฉ๋๋ค.
- ๊ทธ๋ผ์๋ ๋ถ๊ตฌํ๊ณ , ์ค์ ๋ก ์ฌ์ฉํ๋ deep learning์์์ gradient ๊ณ์ฐ์ ๋๋ถ๋ถ backpropagation์ ๋๋ค. ๊ทธ ์ด์ ๋, ์ง๊ธ ์ ์์์์๋ ์ ์ ์๋ ๋ถ๋ถ์ด๊ธด ํ์ง๋ง MLP๋ฅผ ๋ค์ ์๊ฐํด ๋ณด๋ฉด ๋๋ถ๋ถ์ ์ฐ์ฐ์ด ํ๋ ฌ๊ณฑ์ด๋ฏ๋ก ์ ์์์๋ ๋ฌ๋ฆฌ ๊ฐ edge์ ์ค์นผ๋ผ๊ฐ์ด ์๋๋ผ ํ๋ ฌ์ด ์ฐ์ฌ์ง๊ฒ ๋ฉ๋๋ค. ์ด๊ฒ ์ ์๋ฏธ๊ฐ ์๋๋ฉด, ๊ฒฐ๊ตญ์ โํ๋ ฌ๋ค์ ์์๋๋ก ๋ง์ดโ ๊ณฑํด์ผ ํ๋ค๋ ์๊ธฐ๊ณ โฆ ํ๋ ฌ ์ฌ๋ฌ๊ฐ๋ฅผ ๊ณฑํ ๋๋ ์์ ํ๋ ฌ๋ถํฐ ๊ณฑํ๊ณ ๊ทธ ๊ฒฐ๊ณผ๋ฅผ ํฐ ํ๋ ฌ๊ณผ ๊ณฑํ๋ ๊ฒ์ด ๋์ฒด๋ก ๋ณด๋ค ํจ์จ์ ์ ๋๋ค. (์ด ํํ์ ์๋ฒฝํ๊ฒ ์ ํํ์ง๋ ์์ง๋ง, ํ๋ ฌ์ ํฌ๊ธฐ๊ฐ ๋จ์กฐ์ฆ๊ฐํ๋ค๋ฉด ์ฐธ์ ๋๋ค. ์ฌ์ค์ ์ด ์์ฒด๊ฐ ํ๋ ฌ ๊ณฑ์ ์์ ๋ผ๋ (๋ฐฑ์ค์๋ ์๋..ใ ใ ) ๋งค์ฐ ์ ๋ช ํ DP ๋ฌธ์ ์ ๋๋ค.) ๊ทธ๋ฐ๋ฐ MLP๋ CNN์ด๋ , ๋คํธ์ํฌ ๋์ชฝ (์ถ๋ ฅ์ ๊ฐ๊น์ด ์ชฝ) ์ผ๋ก ํฅํ๋ฉด์ ์ ์ feature์ ๊ฐ์๋ฅผ ์ค์ฌ๋๊ฐ๋ ๊ฒ์ด ์ผ๋ฐ์ ์ด๋ฉฐ, ๋ฐ๋ผ์ backpropagation ๋ฐฉ๋ฒ์ผ๋ก ๋ค์์๋ถํฐ ๊ณฑํ๋ฉด์ ์ค๋๊ฒ ํ๋ ฌ ๊ณฑ์ ์ ๋ ๋นจ๋ฆฌ ํ ์ ์๊ธฐ ๋๋ฌธ์ ๋๋ค.
- Torch ๋ฑ ๋ฅ๋ฌ๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ค์ ์ด backpropagation์ ์๋์ผ๋ก ์ ๋ฐ๋ผ๊ฐ ์ฃผ๊ธฐ ๋๋ฌธ์ ์ผ๋ฐ์ ์ผ๋ก๋ ๊ฑฑ์ ํ ํ์๊ฐ ์์ง๋ง, ์๋ก์ด loss function์ ์ ์ํ ๋๋ ํญ์ ๋ฏธ๋ถ๊ฐ๋ฅํ์ง๋ฅผ ์๊ฐํด์ผ ํฉ๋๋ค.
Reference
- ์์ธ๋ํ๊ต ์ฌ์ธต์ ๊ฒฝ๋ง์ ์ํ์ ๊ธฐ์ด ๊ฐ์์๋ฃ (๋งํฌ)
- Ian Goodfellow, Yoshua Bengio, & Aaron Courville (2016). Deep Learning. MIT Press.