본문 바로가기

ML&DL/PyTorch

[PyTorch] Gradient clipping (그래디언트 클리핑)

Gradient clipping을 하는 이유

주로 RNN계열에서 gradient vanishing이나 gradient exploding이 많이 발생하는데, gradient exploding을 방지하여 학습의 안정화를 도모하기 위해 사용하는 방법이다.

 

Gradient clipping과 L2norm

내용이 사실 굉장히 간단하다. clipping이란 단어에서 유추할 수 있듯이 gradient가 일정 threshold를 넘어가면 clipping을 해준다. clipping은 gradient의 L2norm(norm이지만 보통 L2 norm사용)으로 나눠주는 방식으로 하게된다. threshold의 경우 gradient가 가질 수 있는 최대 L2norm을 뜻하고 이는 하이퍼파라미터로 사용자가 설정해주어야 한다. 논문에서 실험파트를 읽다보면 심심치 않게 maximum L2norm 파라미터를 만날 수 있다. 

 

 

clipping이 없으면 gradient가 너무 뛰어서 global minimum에 도달하지 못하고 너무 엉뚱한 방향으로 향하게 되지만, clipping을 하게 되면 gradient vector가 방향은 유지하되 적은 값만큼 이동하여 도달하려고 하는 곳으로 안정적으로 내려가게 된다. 

 

기울기의 방향은 유지, 크기만 달라진다.

Gradient clipping의 pytorch 예제

 

import torch

max_norm = 5
optimizier = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0)

# you can set it in trainning phase
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
optimizer.step()

Reference

kh-kim.gitbook.io/natural-language-processing-with-pytorch/00-cover-6/05-gradient-clipping

towardsdatascience.com/what-is-gradient-clipping-b8e815cdfb48