Optimizers for Deep Learning
- ๋ง์ ๋ฅ๋ฌ๋ ๋ฌธํ์์ $x^i$๋ฅผ $i$๋ฒ์งธ ๋ฐ์ดํฐ๋ฅผ ์๋ฏธํ๋๋ฐ ์ฐ๊ณ , $x_i$๋ $x$๋ฒกํฐ์ $i$๋ฒ์งธ๋ฅผ ์๋ฏธํ๋๋ฐ ์๋๋ค. ์ด ๊ดํ์ ๋ฐ๋ฅด๊ฒ ์ต๋๋ค.
๋ฅ๋ฌ๋์ ๋ฌธ์ ๋ ์ํ์ ๊ด์ ์์ ํ์ํ๋ฉด ๊ฒฐ๊ตญ ๋ค์๊ณผ ๊ฐ์ด ์์ฝํ ์ ์์ต๋๋ค.
- ๋ฏธ์ง์ ํจ์ $f$์ ๋ํด ์๊ณ ์ ํ๋๋ฐ,
- ๋ชจ๋ ์ง์ ์ด ์๋ ์ด๋ค ์ง์ $x_i$ ๋ค์์๋ง ๊ทธ ๊ฐ $f(x^i) = y^i$ ๋ฅผ ์๊ณ ์๊ณ ,
- ๊ทธ๋์ ์ด๋ค ํ๋ํฐ $\ell$ ์ ์ ์ํด์, $\sum_i \ell(f(x^i), g(x^i))$๊ฐ ์์ $g$๋ฅผ $f$์ ๊ทผ์ฌ-ํจ์๋ก ์๊ฐํ๊ณ ์ถ์ต๋๋ค.
- ๊ทธ๋ฐ๋ฐ ์ด $g$๋ฅผ ๋ชจ๋ ํจ์์ ๊ณต๊ฐ์์ ์ต์ ํํ๋ ๊ฒ์ ์ผ๋ฐ์ ์ผ๋ก ๊ฐ๋ฅํ์ง ์์ผ๋ฏ๋ก,
- ์ด๋ค parameter $\theta$ ์ ์ํด ํํ๋๋ ํจ์๊ณต๊ฐ์ ๋ถ๋ถ์งํฉ $g_\theta$๋ง์ ์๊ฐํ๋ฉฐ,
- $\minimize \sum_i \ell(f(x^i), g_\theta(x^i))$ by moving $\theta$๋ก ์๊ฐํฉ๋๋ค.
๊ทธ๋ฐ๋ฐ $f(x_i)$๋ ์ด๋ฏธ ์๊ณ ์์ผ๋ฏ๋ก, ๊ฒฐ๊ตญ์ $\ell$ ์ด๋ผ๋ ํจ์๋ $\theta$์๋ง ์์กดํ๊ฒ ๋ฉ๋๋ค. ๋ฐ๋ผ์, ์ฐ๋ฆฌ๋ $\ell(\theta)$๋ฅผ ์ต์ํํ๋ $\theta$๋ฅผ ์ฐพ๋ ๊ฒ์ ๋ชฉํ๋ก ํฉ๋๋ค. ๋ค์ ์ด๋ฅผ ์ผ๋ฐํํด์, $\theta \in \R^n$ ์ผ๋ก ์๊ฐํ๋ฉด, $\ell : \R^n \to \R$ ํจ์์ ์ต์ ํ ๋ฌธ์ ๊ฐ ๋ฉ๋๋ค.
- ์ผ๋ฐ์ ์ผ๋ก, $f$๊ฐ ๋ณผ๋ก(Convex) ํจ์์ธ ๊ฒฝ์ฐ์๋ ์ ์ผํ ์ต์ ๊ฐ์ด ์กด์ฌํ๋ฉฐ, ์ถ๊ฐ๋ก ๊ฐ๋ณผ๋ก (Strictly Convex) ํจ์์ธ ๊ฒฝ์ฐ์๋ ์ต์ ๊ฐ์ ์ฃผ๋ ์ต์ $\theta$๋ ์ ์ผํจ์ ์๊ณ ์์ต๋๋ค.
- ๋ณผ๋กํจ์๋ฅผ ๋น๋กฏํ ์ฌ๋ฌ ์ข์ ์ฑ์ง์ ๊ฐ์ง ํจ์๋ค์ ์ต์ ํํ๋๊ฒ๋ ๋งค์ฐ ํฅ๋ฏธ๋ก์ด ์ด๋ก ๋ค์ด ๋ง์ด ์์ง๋ง, ์ฌ๊ธฐ์๋ ๊น์ด ๋ค๋ฃจ์ง๋ ์๊ฒ ์ต๋๋ค.
- Deep Learning์์ ์ฌ์ฉ๋๋ ํจ์๋ ๋ณผ๋ก์ฑ์ ๊ธฐ๋ํ ์ ์๊ธฐ ๋๋ฌธ์, ๋น์ฐํ global minimum์ ์ฐพ์ ์๋ ์์ต๋๋ค. ์ฐ๋ฆฌ๋ Heuristicํ๊ฒ ๊ทธ๋ด๋ฏํ ์์ ๊ฐ์ ์ฐพ๋ ๊ฒ์ ๋ชฉํํฉ๋๋ค.
- ๊ทธ๋ ๊ธฐ ๋๋ฌธ์, ์ฌ๊ธฐ ๋์จ ์ด๋ค ๋ง๋ ์๋ฐํ์ง ์์ต๋๋ค. ์๋ฆฌ๊ณผํ๋ถ์ ํ๋ฐ ๊ฑธ์น๊ณ ์๋ ์ ๊ฒ๋ ์ฝ๊ฐ ์์ํ ๋ถ๋ถ์ด๊ธฐ๋ ํ๋ฐ, convex๋ฅผ ๊ธฐ๋ํ ์ ์์ผ๋ฉด์๋ convexํ ํจ์์ ๋ํด ์ฆ๋ช ํ ๋ค์ ๊ทธ๋ฌ๋ฏ๋ก ์ ๋๋ค~ ๋ผ๊ณ ์ฃผ์ฅํ๊ฑฐ๋, convexํ ํจ์์์์กฐ์ฐจ ์๋ ด์ฑ์ด ์ฆ๋ช ์ด ์๋๋ ์๊ณ ๋ฆฌ์ฆ(ADAM) ์ ๋ง์๊ป ์ฌ์ฉํ๋ ๋ฑ(โฆ) ์กฐ๊ธ์ ์ ๊ธฐํ ๋ถ์ผ์ ๋๋ค. ํ์ง๋ง ์ค์ฉ์ ์ธ ๊ฐ์น๊ฐ ์ด์ชฝ์์๋ ๊ฐ์ฅ ์ฐ์ ์๋๊ธฐ ๋๋ฌธ์ ์ด์ฉ์ ์๋๋ฏ ์ถ์ต๋๋ค.
์ฃผ์ ์ด ๊ธ์ ๋น๋กฏํ์ฌ, optimizer๋ฅผ ์ค๋ช ํ๋ ๋ง์ ๊ธ์์ ์ด๋ค ํ ์คํธ ํจ์์ ๋ํ ์๋ฆ๋ค์ด visualization์ ์ด์ฉํ์ฌ ์ดํด๋ฅผ ๋์ต๋๋ค. ๋งค์ฐ ํจ๊ณผ์ ์ด๋ผ๊ณ ์๊ฐํ๊ณ ์ ๋ ์คํํด๋ณด๋ฉด์ ์ฌ๋ฐ์๊ธฐ ๋๋ฌธ์ ์ ๋ ์ง์ visualization์ ๋ง์ด ๋ง๋ค์ด ๋ดค์ง๋ง, ๊ทธ๋ํฝ์ ์ด์ฉํด์ ์ดํดํ ๋ ์ฃผ์ํ ์ ์ด ์์ต๋๋ค.
์ฐ๋ฆฌ๊ฐ ์ผ๋ฐ์ ์ผ๋ก ์ฌ์ฉํ๋ ๋ฅ ๋ฌ๋ ๋ชจ๋ธ์ ์์ญ๋ง, ์๋ฐฑ๋ง ๊ฐ์ ํ๋ผ๋ฏธํฐ๊ฐ ์์ต๋๋ค. ๋ค์ ๋งํด, $\R^{1,000,000}$ ๊ฐ์ ๊ณต๊ฐ์์ ๋ญ๊ฐ๋ฅผ ์ต์ ํํ๋ค๋ ๋ป์ ๋๋ค. ์ด ๊ธ์์ ๋ค์ ์ค๋ช ํ โ์ข์โ ์ต์ ํ ์๊ณ ๋ฆฌ์ฆ๋ค์ธ RMSProp, Adam ๋ฑ๋ฑ์ ์ด๋ฐ ๊ณ ์ฐจ์ ๊ณต๊ฐ์์์ ์ต์ ํ๋ฅผ ๋น ๋ฅด๊ฒ ํ ๋ชฉ์ ์ผ๋ก ๊ฐ๋ฐ๋์์ต๋๋ค. ์ด๋ฐ ์ํฉ๊ณผ๋ ๋ฌ๋ฆฌ ์ฐ๋ฆฌ์ ๊ทธ๋ํฝ์ $\R^2 \to \R$ ํจ์์ ์ต์ ํ๋ฅผ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์, ์ธํ ์ด ๋ง์ด ๋ค๋ฆ ๋๋ค. ๊ฐ์ฅ ํฐ ์ฐจ์ด๋ learning rate์ธ๋ฐ, ํ๋ผ๋ฏธํฐ๊ฐ ๋ง๋ค๋๊ฒ์ ๊ทธ๋งํผ ๊ฐ ํ๋ผ๋ฏธํฐ์ ๋ํ ์์กด๋๋ ๋ฎ์์ง ๊ฒ์ด๊ณ ๊ทธ๋ฌ๋ฉด ๋ฏธ์ธ์กฐ์ ์ ์ํ ์์ learning rate๊ฐ ์ผ๋ฐ์ ์ ๋๋ค. ํนํ $\R^2$ ๊ฐ์ ๋๋ฌด ์ ์ฐจ์์ ๊ณต๊ฐ์์๋ Adam๊ฐ์ ์๊ณ ๋ฆฌ์ฆ๋ค์ learning rate๊ฐ ๊ฐ์๋ ํจ์ฌ ์๋ ด์ด ๋๋ฆฝ๋๋ค.
์ด๊ฑธ ๋ณด์ ํ๊ธฐ ์ํด, ๋ฌ๋ ค๋๊ฐ๋ ์ ๋ค์ ์ด๋ฐ์ ์ด๋ ์๋๊ฐ ๋น์ทํ๋๋ก ์๋์ ์ผ๋ก learning rate๋ฅผ ์กฐ์ ํ์ต๋๋ค. ์ฆ, ์ด ๊ทธ๋ํฝ์ ๋ณด๊ณ โ์ Adam์ด ๋น ๋ฅด๊ตฐโ ์ด๋ฐ ์๊ฐ์ ํ๋ ๊ฒ์ ์ฌ๋ฐ๋ฅธ ์ดํด๊ฐ ์๋ ์๋ ์๋ค๋ ๊ฒ์ ๋๋ค. ๊ทธ๋ํฝ์ ์ด๋๊น์ง๋ ๊ธ๋ก ์ค๋ช ํ ๋ฐ๋ฅผ ๋ณด์ฌ์ฃผ๊ธฐ ์ํ ์์์ด๋ฏ๋ก, ๋ง์ ํ๋์ด ๊ฐํด์ก์์ ์ผ๋์ ๋๊ณ , ๋งํ๊ณ ์ ํ๋ ๋ฐ๊ฐ ๋ฌด์์ธ์ง๋ฅผ ํ์ ํ๋ฉด ๋ ๊ฒ ๊ฐ์ต๋๋ค.
์ด ๋ถ์ผ์ ๋ํ ์ ์ดํด๋ ์์ง ๋ง์ด ๋ถ์กฑํ๊ธฐ ๋๋ฌธ์, ์ค๋ช ์ ์ค๋ฅ๊ฐ ์๊ฑฐ๋ ๋ณด๊ฐํ ์ , ๋์น ์ ์ด ์์ ์ ์์ต๋๋ค. ๋๊ธ ๋ฑ์ผ๋ก ํผ๋๋ฐฑ์ ํญ์ ํ์ํฉ๋๋ค :)
(Stochastic) Gradient Descent
์ด ๊ธ์์๋ ๋ ์๊ฐ gradient descent๋ฅผ ์ดํดํ๊ณ ์๋ค๊ณ ๊ฐ์ ํ์ง๋ง, ํน์ ์๋๋ผ๋ฉด Gradient Descent์ ๋ํ ํฌ์คํ ๋ฅผ ์ฐธ๊ณ ํด ์ฃผ์ธ์.
Gradient Descent๋ ํ ์คํ ํ ์คํ ์ด ๋๋ฌด ๋๋ฆฌ๊ธฐ ๋๋ฌธ์ (๋ชจ๋ ๋ฐ์ดํฐ๋ฅผ ํ๋ฐํด ๋์์ผ ํด์), ๋์ ํ ๋ฐ์ดํฐ ๋๋ ์์์ ๋ฐ์ดํฐ๋ก ์์ฃผ ๋ฐ๋ stochastic gradient descent๊ฐ ๊ธฐ๋ณธ์ด ๋ฉ๋๋ค. ๊ธฐ๋ณธ์ ์ธ SGD์ ๋ํด์๋ Stochastic Gradient Descent ํฌ์คํ ์์ ๋ค๋ฃจ์์ต๋๋ค. (์ธ์ ๊ฐ ๋ฆฌํผ๋ ์์ ) ์ดํ์ ๋ชจ๋ ์๊ณ ๋ฆฌ์ฆ๋ค์ ์ด๋ค์์ผ๋ก๋ SGD์ ๊ธฐ๋ฐํ๊ธฐ ๋๋ฌธ์, SGD์ ์์ ํ๋ฒ ๋ฆฌ๋ทฐํ ๊ฐ์น๊ฐ ์์ต๋๋ค.
\[i(k) \sim \uniform{1}{N},\quad \theta^{k+1} = \theta^k - \alpha \nabla{f_{i(k)}(\theta^k)}\]์ฌ๊ธฐ์ $\nabla f_{i(k)}$ ๋์ ๋ค๋ฅธ ์ ๋นํ $g_k$๋ฅผ ์ก์๋ ๋๋๋ฐ (Batched-gradient), ๋์ $g_k$๋ $\nabla F(x^k)$ ์ Unbiased Estimator ์ฌ์ผ ํฉ๋๋ค. ๋ํ ์ข๋ ์์ ๊ฐ๋จํ๊ฒ ์ฐ๊ธฐ ์ํด, ์์ผ๋ก $i(k)$ ์ ์ ํ์ ๋
ผ์ํ์ง ์๊ฒ ์ต๋๋ค. ์ด๊ฑด ๊ทธ๋ฅ ๋๋คํ๊ฒ ๋๋ฆฌ๋ฉด ๋ฉ๋๋ค. ์ฆ, ์ SGD๋ฅผ ๋ค์ ์ธ ๋,
\(\theta^{k+1} = \theta^k - \alpha g^k\)
์ด๋ ๊ฒ๋ง ์ฐ๋๋ผ๋, $g^k$๋ฅผ ๋๋คํ๊ฒ ๊ณจ๋ผ์ง index $i(k)$์ ๋ํ (๋๋, batch๋ฅผ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ batch-gradient) $\nabla f_i(k)$์ ๊ฐ์ผ๋ก ์ฝ์ด์ฃผ๋ฉด ๋ฉ๋๋ค. Batch์ ๋ํ ์์ธํ ์๊ธฐ๋ ์์ ๋งํฌ๊ฑธ๋ฆฐ SGD ํฌ์คํ
์ ์ฝ์ด์ฃผ์ธ์.
์ฌ๊ธฐ์ ๋ค๋ฅธ๊ฑด ๋๋ถ๋ถ ํฐ ๋ฌธ์ ๊ฐ ์๋๋ฐ, $\alpha$, learning rate์ ์ ํ์ด ๋ฌธ์ ์ ๋๋ค. Learning rate๊ฐ ๋๋ฌด ํฌ๊ฑฐ๋ ์์ผ๋ฉด ์ต์ ํ๊ฐ ์ ์ด๋ฃจ์ด์ง์ง ์์ต๋๋ค.
- Learning rate๊ฐ ๋๋ฌด ์์ผ๋ฉด, ํจ์๊ฐ์ ์๋ ด์ด ๋๋ฌด ๋๋ฆฝ๋๋ค.
- Learning rate๊ฐ ๋๋ฌด ํฌ๋ฉด, ๋ชฉํํ๋ ์ ์ ์ง๋์ณ์ ์๋ ดํ์ง ์์ ์๋ ์์ต๋๋ค.
(์ฌ์ง์ถ์ฒ : https://www.deeplearningwizard.com/deep_learning/boosting_models_pytorch/lr_scheduling/)
์ผ๋ฐ์ ์ผ๋ก ์ข์ Learning rate๋ฅผ ์ก๋ ๋ฐฉ๋ฒ์ด ์๋๊ฒ์ ์๋๊ณ , ๋๋ ค๋ณด๋ฉด์ ์ฐพ์์ผ ํฉ๋๋ค.
Momentum SGD
์ $\alpha$๊ฐ ์์์ฌ์ผ ํ ์ด์ ๋ ๋ณ๋ก ์์ต๋๋ค. ์ฆ, ๋ง์ฝ ๋งค ์๊ฐ ํ์ฌ ์์ ๊น์ง ์๊ณ ์๋ ์ด๋ค ์ ๋ณด๋ฅผ ์ด์ฉํด์ learning rate๋ฅผ ์กฐ์ ํด ์ค ์ ์๋ค๋ฉด ๋ ์ข์ ์๊ณ ๋ฆฌ์ฆ์ด ๋ ๊ฒ์ ๋๋ค.
Momentum์ด๋, SGD์ โ์ด์ ์ ๊ฐ๋ ๋ฐฉํฅ์ผ๋ก ์กฐ๊ธ ๋ ๊ฐ๊ณ ์ ํ๋โ ๊ด์ฑ์ ์ถ๊ฐํ๋ ๋ฐฉ๋ฒ์
๋๋ค. ์ด ๋ฐฉ๋ฒ์ ๋จผ์ ์์์ผ๋ก ์ฐ๋ฉดโฆ
\(\begin{align*}
v^{k+1} &= g^k + \beta v^k \\
\theta^{k+1} &= \theta^k - \alpha v^{k+1}
\end{align*}\)
์ฆ, โ๋ฐฉ๊ธ ์ ์ ๊ฐ๋ ์๋๋ฒกํฐ์ $\beta$๋ฐฐโ ๋ฅผ ํ์ฌ ๊ฐ๊ณ ์ถ์ ๋ฐฉํฅ (gradient)์ ๋ํด์ฃผ๋ ๊ฒ์
๋๋ค. ์ด์ ์ ๊ฐ๋ ๋ฐฉํฅ์ผ๋ก ๊ณ์ ๊ฐ๋ ค๋ ๊ฒฝํฅ์ฑ์ด ์๊ธฐ ๋๋ฌธ์, ์ด๋ฐ ํจ์๋ฅผ ์๋์ ์ผ๋ก ๋น ๋ฅด๊ฒ ํ์ถํ ์ ์์ต๋๋ค.
์ ๋ณด๋ฉด, $x$๋ฐฉํฅ์ผ๋ก๋ gradient์ ๋นํด ๊ฐ์ผํ ๊ฑฐ๋ฆฌ๊ฐ ๋จผ๋ฐ ๋นํด $y$๋ฐฉํฅ์ gradient๊ฐ ๋ง์ด ๋ณํ๊ณ ์์ต๋๋ค. ์ผ๋ฐ SGD๋ ํ๋ฐ (valley๋ฅผ ๋ด๋ ค์จ ํ)์ gradient๊ฐ ๋๋ฌด ์์์ ํ์ต์ด ์ ๋์ง ์๋ ๋ฐ๋ฉด, Momentum์ด ์ถ๊ฐ๋ SGD๋ ์ด์ ์ ๊ฐ๋ ๋ฐฉํฅ์ ์ ์ด์ฉํด์ ๋น ๋ฅด๊ฒ ์งํํ๋ ๊ฒ์ ๊ด์ฐฐํ ์ ์์ต๋๋ค.
๋ํ, ์์ learning rate๋ฅผ ์ธ ๋ ๋ฐ์ํ๋ ๋ํ๋์ ๋ฌธ์ ๋ ์์ local minima๋ฅผ ์ง๋์น์ง ๋ชปํ๊ณ ๋ฌถ์ฌ๋ฒ๋ฆฌ๋ ํ์์ ๋๋ค. ์ด๋ momentum์ local minima์์ ๋น ์ ธ๋์ฌ ์ ์๋ ์๋ก์ด ๊ฐ๋ฅ์ฑ์ ์ ๊ณตํ๊ธฐ๋ ํฉ๋๋ค.
์๋ ๊ทธ๋ฆผ์ ์ข ์ด๊ฑฐ์ง๋ก ๋ผ์๋ง์ถ ํจ์๋ฅผ ์ด์ฉํด์ ๋ง๋ ์์์ง๋ง, (4, 1) ๋ถ๊ทผ์ ์์ฃผ ์์ local minima๋ฅผ ์ง๋์น์ง ๋ชปํ๋ SGD์ momentum์ ํ (๊ด์ฑ) ์ผ๋ก ์ง๋์ณ์ global minima๋ฅผ ํฅํด ๊ฐ๋ momentum SGD๋ฅผ ๋น๊ตํด ๋ณผ ์ ์์ต๋๋ค.
Nesterov Accelerated Gradient
NAG๋ผ๊ณ ๋ ๋ถ๋ฆฌ๋ ์ด ๋ฐฉ๋ฒ์, ์๋ ๋์์์๋ ์ ์ ์๋ฏ momentum์ ์ฐ๋, ๋ฐฉ๊ธ์ ์ momentum๊ณผ ์ฌ๊ธฐ์ gradient๋ฅผ ํฉ์น๋๊ฒ ์๋๋ผ ๋ฐฉ๊ธ์ ์ momentum์ ์ด์ฐจํผ ๊ฐ ๊ฒ์ด๋ฏ๋ก ๊ทธ๋งํผ์ ์ผ๋จ ๋ฌด์์ ๊ฐ ๋ค์ ๊ฑฐ๊ธฐ์ gradient๋ฅผ ์ฐพ๋ ๋ฐฉ๋ฒ์
๋๋ค.
(์ฌ์ง ์ถ์ฒ : stanford CS231)
์ง๊ด์ ์ผ๋ก, ์ด ๋ฐฉ๋ฒ์ด ๋ณด๋ค ์ ์๋ํ ์ ์์ ๊ฒ ๊ฐ์ ์ด์ ๋ momentum SGD๊ฐ ๊ฐ๋ ๋น ๋ฅธ ์๋ ด์๋ (๊ด์ฑ) ์ ์ ์งํ๋ฉด์๋ ์ข๋ ์์ ์ ์ด๊ธฐ ๋๋ฌธ์ ๋๋ค. ์๋ฅผ ๋ค์ด momentum ๋ฐฉํฅ์ด $-y$ ๋ฐฉํฅ์ด๊ณ , ์ง๊ธ์ gradient๋ $-y$๋ฐฉํฅ์ด๋ผ์ $-y$๋ก ํ์ฐธ๋์ ๊ฐ๋ณด๋๊น ๋ค์ $+y$ ๋ฐฉํฅ gradient๊ฐ ์๋ ์ํฉ์ ์๊ฐํด ๋ณผ ์ ์์ต๋๋ค. ์ด๋ momentum SGD๊ฐ ๊ด์ฑ๋๋ฌธ์ ๋ฐ๋๋ก ์ง๋ํ๊ฒ ๋ง๋ ๋ค๋ ๋ป์ธ๋ฐ, NAG๋ ์ด๋ momentum๋งํผ์ ๊ฐ๋ณธ๋ค์ ๊ฑฐ๊ธฐ์ gradient๋ฅผ ์ฌ๊ธฐ ๋๋ฌธ์ ์ด๋ฐ ์ง๋์ด ๋ํฉ๋๋ค.
์์์ ๋ณธ parabolaํํ์์, ์ข ํฐ learning rate๋ฅผ ์ก์์ momentum์ด ์ข ํฌ๊ฒ ์ง๋ํ๊ฒ ๋ง๋ค๋ฉด ์ด๋ฐ ์ฌ๋ฐ๋ ์์๋ฅผ ์ป์ต๋๋ค.
SGD๋ ์ ๋ฉ๋ฆฌ์ ๋ค์ ํ๋ฉด์ ๊ธฐ์ด๋ค๋๊ณ ์๊ณ (โฆ) ๋ชจ๋ฉํ
์ ์์๋ ์ง๋์ด ๋๋ฌด ํฐ๋ฐ ๋นํด, NAG๊ฐ ์ข๋ ์์ ์ ์ผ๋ก ์์ง์ด๊ณ ์์ต๋๋ค.
RMSProp
G.Hinton. Lecture Notes for csc321
RMSProp์ ์ธ๊ณต์ง๋ฅ ๋ถ์ผ์ ์ ๊ตฌ์์ด์ ์ต๊ณ ์ big name์ค ํ๋์ธ Geoffry Hinton ๊ต์๋์ Coursera ๊ฐ์์์ ์ฒ์ ์๊ฐ๋ ์๊ณ ๋ฆฌ์ฆ์ ๋๋ค. ๋๋ต์ ์ธ ์์ด๋์ด๋, $g^k$์ ์ง๊ธ๊น์ง์ ํ์คํ ๋ฆฌ๋ฅผ ๋ณผ ๋, ๋ณ๋์ด ์ข ์ฌํ ๊ฐ์ ์ง๋์ค์ธ ๊ฐ์ผ ๊ฐ๋ฅ์ฑ์ด ๋์ ๋ณด์ด๋ฏ๋ก ๊ทธ๋งํผ learning rate๋ฅผ ๊น์์ ์ ์ฉํ๊ฒ ๋ค๋ ์์ด๋์ด์ ๋๋ค. ์ฆ $g^k$์ ๊ฐ โ๋ฐฉํฅโ ๋ง๋ค, ๋ค๋ฅธ learning rate๋ฅผ ์ ์ฉํ๋ค๋ ์์ด๋์ด๊ฐ ํต์ฌ์ด๋ผ๊ณ ํ ์ ์๊ฒ ์ต๋๋ค.
์ด๋ฅผ ์ํด, $(g^k_i)^2$์ estimation์ $(m^k)_i$ ๋ก ์ ์ํ๊ณ , ์ด ๊ฐ์ square root๋ก ๋๋ ์ค๋๋ค. ์ฆ ์ด๋ค ํ๋ผ๋ฏธํฐ ๋ฐฉํฅ์ผ๋ก์ ์ง๊ธ๊น์ง์ gradient๊ฐ์ โRoot Mean Square averageโ ๋งํผ์ ์ด์ฉ, ๋ณ๋์ด ์ฌํ ํ๋ผ๋ฏธํฐ๊ฐ ๋ฌด์์ธ์ง๋ฅผ ์ฐพ์ต๋๋ค. ๋ฐ๋๋ก, gradient์ ๋ณ๋์ด ๊ฑฐ์ ์์ผ๋ฉด ๋น๊ต์ smoothํ ๊ธธ์ (๊ธฐ์ธ์ด์ ธ ์์ ์๋ ์์ง๋ง, ๋ณ๋์ด ์ฌํ์ง ์์) ๊ฐ๊ณ ์์ ๊ฒ์ด๋ฏ๋ก, ์ข ๊ณผ๊ฐํ๊ฒ ๊ฐ๋ ๋ฉ๋๋ค. ์ด๋ฅผ ์์์ผ๋ก ํํํ๋ฉด ๋ค์๊ณผ ๊ฐ์ต๋๋ค. \(\begin{align*} m^k &= \beta_2 m^{k-1} + (1 - \beta_2) (g^k \odot g^k) \\ x^{k+1} &= x^k - \alpha g^k \oslash \sqrt{m^k + \epsilon} \end{align*}\) ์ฌ๊ธฐ์ ๋ณดํต $\beta_2$ ๊ฐ์ 0.99 ์ ๋๋ฅผ ์ฐ๊ณ , ์ ์์์ $\odot$ ์ $\oslash$ ๋ elementwise ๊ณฑ์ ๊ณผ ๋๋์ ์ ์๋ฏธํฉ๋๋ค. ์ ์์ ๋ณด๋ฉด, $m^k$๋ (์ด์ $m^k$์ 0.99๋ฐฐ) ์ (ํ์ฌ gradient์ ๊ฐ ๋ฐฉํฅ๋ณ ๊ฐ์ ์ ๊ณฑ์ 0.01๋ฐฐ)๋ฅผ ๋ํด์ ๋ง๋ค์ด์ง๋๋ฐ, ์ด๋ ์ฆ ์ด์ ์ gradient๋ค์ด ์ ์ ์ํฅ์ ๋ ๋ฏธ์น๋ (์ฆ, ๊ฐ์คํ๊ท ์ ๋ด๋, ์ต๊ทผ๊ฐ์ด ์ข๋ ์ ๋ฐ์๋๋๋ก ํ๋ ๊ฐ์คํ๊ท , ์ด๋ฅผ โExponentially (decaying) moving averageโ ๋ผ ํด์ EMA๋ผ ์๋๋ค) ํํ์ ์์์ ์๋ฏธํฉ๋๋ค. $\epsilon$์ ์ค์์ค์ฐจ์ ๋ฐ๋ผ ์์์ square root๋ฅผ ์ทจํ๋ (์ํ์ ์ผ๋ก๋ $m^k$์ ๊ฐ ๊ฐ์ด ์์๊ฐ ๋์ฌ ์ ์์ง๋ง, ๊ณ์ฐํด์ 0์ธ ๊ฐ์ ์ค์์ค์ฐจ๋๋ฌธ์ 0์ ๋งค์ฐ ๊ฐ๊น์ด ์์์ผ ์ ์์ต๋๋ค) ํ๋ก๊ทธ๋จ์ ์ธ ์ค๋ฅ๋ฅผ ๋ฐฉ์งํ๊ธฐ ์ํด ์ถ๊ฐ๋ ๊ฐ์ผ๋ก, ์ํ์ ์ผ๋ก๋ ์๋ฏธ๊ฐ ์์ต๋๋ค.
RMSProp์ ๋น๋กฏํ์ฌ, ์ดํ์ ๋ค๋ฃฐ ์๊ณ ๋ฆฌ์ฆ๋ค์๊ฒ parabola๋ ๋๋ฌด ์ฌ์ด ์ผ์ด์ค์ด๋ฏ๋ก Bealeโs function์ผ๋ก ๊ทธ๋ํฝ์ ๊ทธ๋ ค๋ณด๊ฒ ์ต๋๋ค.
๋ง์ฐฌ๊ฐ์ง๋ก ๊ทน์ด๋ฐ ๋ช๋ฒ์ ์๋๊ฐ ๋น์ทํ์ง๋ง, SGD๋ gradient๊ฐ ์์์ง๋ฉด ๊ฐ๊ณณ์ ์๋ ๋ฐ๋ฉด NAG๋ RMSProp์ ๊ด์ฑ์ ํ์ผ๋ก ์ข ๋น ๋ฅด๊ฒ ๋์๊ฐ ์ ์์ต๋๋ค. ๋์ ์ฝ๊ฐ ๋ถ์์ ํ ๋ชจ์ต์ ๋ณด์ด๋๋ฐ, RMSProp์ด ์ข๋ ๋น ๋ฅด๊ฒ ์ ๊ธธ์ ๋์์ค๋ ๋ชจ์ต๋ ํ์ธํ ์ ์์ต๋๋ค.
Adam
D. P. Kingma and J. Ba, Adam: A method for stochastic optimization, ICLR, 2015.
Adam์ Adaptive Moment ์ ์ฝ์๋ก, ๊ฐ๋จํ ์์ฝํ๋ฉด Momentum-SGD์ RMSProp์ ์์ด๋์ ์๊ณ ๋ฆฌ์ฆ์ ๋๋ค.
- Momentum-SGD์์ ๊ฐ์ ธ์ค๋ ์์ด๋์ด๋, $g^k$ ๋์ ๋ชจ๋ฉํ ์ ํฌํจํ gradient ๊ฐ์ ์ฌ์ฉํ๋ ๊ฒ์ ๋๋ค.
- RMSProp์ $(g^k_i)^2$์ running average๋ฅผ ์ทจํ๋ ์์ด๋์ด๋ฅผ ๊ทธ๋๋ก ๊ฐ์ ธ์ต๋๋ค.
- ๋จ, Momentum๋ RMSProp์ฒ๋ผ ๊ณ์ฐํฉ๋๋ค. ์ด ๋ง์ด ๋ฌด์จ ๋ป์ธ์ง๋ ์์์ ๋ณด๋ฉด์ ์๊ธฐํ๊ฒ ์ต๋๋ค. \(\begin{align*} m_1^k &= \beta_1m_1^{k-1} + (1-\beta_1)g^k, \tilde{m_1^k} = \frac{m_1^k}{1 - \beta_1^{k+1}} \\ m_2^k &= \beta_2 m_2^{k-1} + (1 - \beta_2) g^k \odot g^k, \tilde{m_2^k} = \frac{m_2^k}{1 - \beta_2^{k+1}} \\ x^{k+1} &= x^k - \alpha \tilde{m_1}^{k} \oslash \sqrt{\tilde{m_2}^k + \epsilon} \end{align*}\)
์์์์, $m_1$์ $g$๊ฐ์ EMA๊ณ (์ด ํํ์ ๋ฐ๋ก ์ RMSProp์์ ์ผ์ต๋๋ค), $m_2$ ๊ฐ์ $g \odot g$ ์ EMA์ ๋๋ค.
- ์ผ๋ฐ์ ์ผ๋ก, $\beta_1 = 0.9$, $\beta_2 = 0.99$๋ฅผ ๋ง์ด ์๋๋ค.
- $\tilde{m_1}$ ์ ์ฐ๋ ๊ฒ์ ๋ณผ ์ ์๋๋ฐ, ์ด๊ฑด ์ฒ์์ $m_1, m_2$ ์ ๊ฐ์ ๋ชจ๋ฅด๊ธฐ ๋๋ฌธ์ ์ด๊ฑธ 0์ผ๋ก ์ด๊ธฐํํ ์๋ฐ์ ์์ด์์ ๋๋ค. 0์ผ๋ก ์ด๊ธฐํํ $m_1, m_2$ ๋๋ฌธ์ ์ด๋ฐ gradient๊ฐ ์ฌ๋ฐ๋ฅด๊ฒ ๋ฐ์๋์ง ์๋ ์ผ์ ๋ง๊ธฐ ์ํด (์๊ทธ๋ฌ๋ฉด ์ฒ์ $m_1$์ $g_1$์ 1/10์ ๋ถ๊ณผํ๊ฒ ๋ฉ๋๋ค), ์ด๋ฅผ bias-correctionํด ์ค๋๋ค.
Adam์ ํ๋ Deep Learning์์ ๊ฐ์ฅ ๋ง์ด ์ฌ์ฉ๋๋ ์ต์ ํ ์๊ณ ๋ฆฌ์ฆ์ ๋๋ค. ๊ฐ๋จํ ๋งํด์, CNN๊ฐ์๊ฑธ trainingํ๋๋ฐ ๋ญ ์ธ์ง ๋ชจ๋ฅด๋ฉด ๊ทธ๋ฅ Adam์ ์จ๋ณด๋ฉด ๋ฉ๋๋ค.
๊ทธ์ ๋์ ์ค์๋๋ฅผ ๊ฐ์ง ์๊ณ ๋ฆฌ์ฆ์ด๊ธฐ ๋๋ฌธ์, ์ญ์ผ๋ก ์ฌ๊ธฐ์๋ ์งง๊ฒ๋ง ๋ค๋ฃจ๊ณ ์ธ์ ๊ฐ CS-Adventure ํฌ์คํ ์ผ๋ก Adam ๋ ผ๋ฌธ์ ์ฝ๊ณ ์ ๋ฆฌํ๋ ค๊ณ ํฉ๋๋ค. ์ฌ๋ฐ๋ ์ผํ๋ก, 2018 ICML์๋ Adam์ ๋ณผ๋กํจ์์ ๋ํ ์๋ ด์ฑ์ ๋ฐ๋ก๋ฅผ ์ฐพ์ ๋ ผ๋ฌธ์ด ๋ฐํ๋์์ต๋๋ค. Adam์ ์๋ณธ ๋ ผ๋ฌธ์๋ โConvexํ๋ฉด ์๋ ดํ๋คโ ๋ ์ฆ๋ช ์ด ์๋๋ฐ, ์ด ์ฆ๋ช ์๋ ์ค๋ฅ๊ฐ ์๊ณ , ์ค์ ๋ก๋ ๋ณผ๋กํจ์๋ผ๋ ์กฐ๊ฑด ํ์์๋ ์๋ ดํ์ง ์๋๋ค๊ณ ํฉ๋๋ค. Nevertheless, ์ด์ฐจํผ ๊ทธ๋์ Adam ์ฐ๋๊ฒ ๋ณผ๋กํจ์์์ ์ ์๋ ดํด์ ์ฐ๋๊ฒ ์๋์๊ธฐ ๋๋ฌธ์ ๊ฐ์์น ์๊ณ ๋ค๋ค ์ ์ฐ๊ณ ์์ต๋๋ค.
๋ง์ฐฌ๊ฐ์ง๋ก Bealeโs function์ผ๋ก ๊ทธ๋ํฝ์ ๊ทธ๋ ค๋ณด๊ฒ ์ต๋๋ค.
๋งจ ์์ ์ฃผ์์์๋ ๋งํ์ง๋ง, ์ด๊ฑด ์ ๋ ๊ฐ ์๊ณ ๋ฆฌ์ฆ๋ค์ ๊ณต์ ํ ๋น๊ต๊ฐ ์๋๋๋ค. ์ด๋ฐ ๋ช๋ฒ์ ์๋๊ฐ (๋์ผ๋ก ๋ณด๊ธฐ์) ๋น์ทํ๋๋ก ์ต์ง๋ก learning rate๋ฅผ ํค์์ ๋ง์ท๋๋ฐ (์๊ทธ๋ฌ๋ฉด โ์์ ์ฑโ ๋๋ฌธ์ Adam์ด ์ด๋ฐ์ ๋๋ฌด ๋๋ฆฝ๋๋ค), ์ฐ๋ฆฌ๊ฐ ๊ฒจ์ฐ ์๋ฐฑ๊ฐ์ ์ ์ ๋ํด ํด๋ณด๋๊ฒ๊ณผ๋ ๋ฌ๋ฆฌ ์ค์ ๋ฅ๋ฌ๋์์๋ ์์ฒ ์๋ง๊ฐ์ ๋ฐ์ดํฐ๋ฅผ ์์ญ๋ฐํด ์ด์์ฉ ๋๋ฆด๊ฑฐ๋ผ์ ๊ทผ๋ณธ์ ์ธ ๋ชฉํ๊ฐ ์ข ๋ค๋ฆ ๋๋ค. ํจ์์ ๋ณต์ก๋๋ ๋ง์ด ๋ค๋ฅด๊ณ โฆ์ด ๊ทธ๋ฆผ์, SGD-NAG์ lr์ด 0.0015์ด๊ณ , Adam์ LR์ด 0.7์ ๋๋ค. ๋ณ๋ช ์ ์กฐ๊ธ ํ์๋ฉด, SGD๋ LR 0.1๋ ๊ฐ๋น์ ๋ชปํ๊ณ ์๋ ด์ ๋ชป์ํค๋๋ฐ ๋นํด ์ด๋ ๊ฒ ๋์ LR๋ก๋ ์๋ ด์ํค๋๊ฒ ๋ฌ๋ฆฌ๋งํ๋ฉด Adam์ ์์ ์ฑ์ ๋ณด์ฌ์ค๋ค๊ณ ๋ณผ ์ ์์ต๋๋ค.
์ผ๋ฐ์ ์ผ๋ก, Adam๊ณผ ๊ฐ์ Adaptive rate (RMSProp๋ ๋ง์ฐฌ๊ฐ์ง) ์๊ณ ๋ฆฌ์ฆ๋ค์ ๊ทธ ์๋ ด ์๋๊ฐ ๋งค์ฐ ๋น ๋ฅด์ง๋ง, ๊ฐ๋ ์ด๋ค ์์์์๋ ์ ์๋ ดํ์ง ๋ชปํ๊ธฐ๋ ํฉ๋๋ค. ์ด๋๋ SGD๊ฐ์ ์๊ณ ๋ฆฌ์ฆ๋ค์ด (์ถฉ๋ถํ learning rate๋ฅผ ์ ํ๋ํ๋ฉด) ์ผ๋ฐ์ ์ผ๋ก ์๋ ด์ฑ์ด ์ข๋ค๊ณ ์๋ ค์ ธ ์์ต๋๋ค.
Other algorithms
์ฌ๊ธฐ์ ๋ค๋ฃจ์ง ์์ ์๊ณ ๋ฆฌ์ฆ ์ค์๋ ์ฌ๋ฐ๊ณ ์ ์ฉํ๊ฒ ๋ง์ต๋๋ค.
- AMSGrad (Reddi et al, 2018) : Adam์ ์๋ ด์ฑ์ด ์ฑ๋ฆฝํ์ง ์์์ ์ง์ ํ ๋ ผ๋ฌธ์์ ๊ทธ ๋์์ผ๋ก ๋ค๊ณ ๋์จ ์๊ณ ๋ฆฌ์ฆ์ ๋๋ค. ์ธ์ ๊ฐ Adam์ ๋ํด ์์ธํ ๊ณต๋ถํ ๋ Adam + Convexity์ ๋ฐ๋ก + AMSGrad๋ก ๋ค๋ฃจ๋ ค๊ณ ์๊ฐํ๊ณ ์์ต๋๋ค.
- NAdam (Dozat, 2016) : Adam์ $m_1$์ Nesterov ๋ฐฉ์์ผ๋ก ๋ฐ๊พผ ๋ฐฉ๋ฒ์ ๋๋ค.
- AdamW (Loshchilov et al, 2017) : Adam ์ regularization์ ์ ๊ทน์ ์ผ๋ก ๋์ ํด์ Adam์ generalization issue๋ฅผ ํด๊ฒฐํ๋ ๋ฐฉ๋ฒ์ ๋๋ค. ์ธ์ ๊ฐ ๋ค๋ฃฐ ์๋? ์์ต๋๋ค.
- AdaDelta, AdaGrad, AdaMax : Adaptive rate ์๊ณ ๋ฆฌ์ฆ๋ค์ด ๊ต์ฅํ ๋ค์ํ๊ฒ ์๋๋ฐ, ์ฝ๊ฐ์ฉ ๋ค๋ฅธ ์์ด๋์ด๊ฐ ์์ง๋ง ๊ธฐ๋ณธ ์์ด๋์ด (gradient์ ๊ทธ square ๋ฑ์ EMAํด์ ๊ณ์ฐ) ๋ ๋ฒจ์์๋ ๋ค๋ฅด์ง ์์ผ๋ฏ๋ก ๋์ด๊ฐ๋๋ค.
- BFGS๋ฅผ ๋น๋กฏํ 2nd-order methods : ์ํ์ ์ผ๋ก ํฅ๋ฏธ๋กญ์ง๋ง ์ด ๊ธ์์๋ ์ผ๋ถ๋ฌ ์ธ๊ธ์ ํผํ์ต๋๋ค. ์ฐจํ์ ๋ค๋ฃฐ ์์ ์ ๋๋ค.