Gradient clipping을 하는 이유
주로 RNN계열에서 gradient vanishing이나 gradient exploding이 많이 발생하는데, gradient exploding을 방지하여 학습의 안정화를 도모하기 위해 사용하는 방법이다.
Gradient clipping과 L2norm
내용이 사실 굉장히 간단하다. clipping이란 단어에서 유추할 수 있듯이 gradient가 일정 threshold를 넘어가면 clipping을 해준다. clipping은 gradient의 L2norm(norm이지만 보통 L2 norm사용)으로 나눠주는 방식으로 하게된다. threshold의 경우 gradient가 가질 수 있는 최대 L2norm을 뜻하고 이는 하이퍼파라미터로 사용자가 설정해주어야 한다. 논문에서 실험파트를 읽다보면 심심치 않게 maximum L2norm 파라미터를 만날 수 있다.
clipping이 없으면 gradient가 너무 뛰어서 global minimum에 도달하지 못하고 너무 엉뚱한 방향으로 향하게 되지만, clipping을 하게 되면 gradient vector가 방향은 유지하되 적은 값만큼 이동하여 도달하려고 하는 곳으로 안정적으로 내려가게 된다.
Gradient clipping의 pytorch 예제
import torch
max_norm = 5
optimizier = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0)
# you can set it in trainning phase
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
optimizer.step()
Reference
kh-kim.gitbook.io/natural-language-processing-with-pytorch/00-cover-6/05-gradient-clipping
towardsdatascience.com/what-is-gradient-clipping-b8e815cdfb48
'ML&DL > PyTorch' 카테고리의 다른 글
[PyTorch] Dataset과 Dataloader 설명 및 custom dataset & dataloader 만들기 (12) | 2020.09.30 |
---|---|
[PyTorch] Weight Decay (L2 penalty) (0) | 2020.09.28 |
[PyTorch] squeeze, unsqueeze함수: 차원 삭제와 차원 삽입 (3) | 2020.09.17 |
[PyTorch] Tensor 합치기: cat(), stack() (9) | 2020.09.16 |
[PyTorch] 시계열 데이터를 위한 RNN/LSTM/GRU 사용법과 팁 (8) | 2020.08.05 |