[논문 리뷰] Learningn both Weights and Connections for Efficient Neural Network
https://arxiv.org/pdf/1506.02626.pdf Learning both Weights and Connections for Efficient Neural Network, 논문을 바탕으로 작성하였습니다.
https://github.com/jack-willturner/deep-compression 코드 참고
1 Abstract
기존 Network들은 학습을 하기 이전에 architecture들을 고정시키기 때문에 학습 단계에서 구조를 발전시키는 방법에 제한이 있었다. 이 제한을 해결하기 위해서 본 논문은 Accuracy를 낮추지 않으면서, 중요한 connection만 학습시켜 저장공간을 줄이고 계산량을 줄이는 방법을 제안한다.
불필요한 connections 들을 없애는 작업인 pruning은 3가지 steps로 진행된다.
1. 어떤 connection이 중요한지 학습하기
2. 주요하지 않은 connection prune 하기
3. 남아있는 connections들의 weight을 fine tune하기 위해 retrain
2 Introduction
본 논문의 목적은 large network를 사용하는데 필요한 에너지를 pruning networks를 통해 감소시켜 mobile devices에도 사용이 가능하게 하는 것이다. 이 목적을 달성하기 위해서는 pruning method를 사용하되, 전체 정확도가 감소하지 않게 유의하였다. Initial training phase (step1) 이후 threshold보다 낮은 weight값을 가지는 connection들은 모두 제거하여 dense한 layer를 sparse한 layer로 변환시킨다. 첫번째 step에서 network의 topology를 학습하는데, 즉 어떤 connection이 중요한지, 불필요한 connection들을 제거하는 단계이다. 이후 sparse한 network을 retrain하면서 remaining connections들이 제거된 connection을 보수할 수 있도록 한다. pruning과 retrainig 단계를 반복적으로 진행하여 network 자체의 복잡도를 줄여나간다. 이로 인하여 결과적으로 학습 단계에서 네트워크의 connectivity를 학습한다. 이를 포유류의 뇌, 신생아 때 생성된 synapses는 사용하지 않는 connection들의 gradual 한 pruning으로 인해 성인이 되는 것으로 비유하였다.
3 Learning Connections in Addition to Weights
앞서 말한 3가지 step들을 시행한다. Normal network 학습을 통해서 connectivity를 학습한다. 기존 학습과 달리 weight의 final value를 학습하는 것이 아닌, "어떤 connection이 중요"한지를 학습한다.
Unlike conventional training, however, we are not learning the final values of the weights, but rather we are learning which connections are important
2번째 step에서는 Low weight connection들을 삭제하는 단계로, threshold 보다 낮은 값을 가진 connections들은 제거한다. 그리고 마지막 단계에서는 remaining하는 sparse connection에 대해 다시 retraining을 한다. Retraining 없이는 낮은 accuracy를 지니게 되어, 아주 중요한 단계임을 강조한다.
3-1 Regularization
알맞은 Regularization은 pruning 과 retraining 의 성능에 많은 영향을 끼치는데, L1-regularization은 non-zero 파라미터를 0에 가깝게끔 값을 낮게 조정하기 때문에 pruning 이후에는 좋은 영향을 끼치나, retraining 단계에서는 좋지 않은 성능을 지닌다. 그러나 L2-regularization은 best pruning result을 가져온다.
3-2 Dropout Ratio Adjustment
Dropout을 학습 때 적용하여도 reference 단계에서는 dropout을 반영하지 않는 반면에, pruning은 학습과 reference 단계 모두 sparse한 network를 사용한다. 때문에, dropout rate을 기존에 사용했던 값을 사용하게 되면, network 자체가 너무 sparse하게 변해 오히려 좋지 않은 결과를 초래한다. 따라서 dropout rate을 조정하는 작업이 필요하다.
(1) 번 식에서 Ci는 i번째 layer의 connection 개수, Ni는 i번째 layer의 neuron 개수로, i-1 번째 neuron 개수와 i번째 neuron 개수를 곱하면 fully connected 상태의 connection 개수는 Ci가 된다. (2) 번식에서 Dr는 retraining 단계에서의 새로 정의된 drop-out rate이고 Do 는 original dropout rate이다. 마찬가지로 Cir는 retraining 단계에서 i번째 layer의 connection 개수, Cio는 original i번째 layer의 connection 개수로, 이를 나누고 루트를 씌운 값에 Do를 곱하면, retraining 단계에서의 새로 정의한 Drop out rate Dr가 된다.
3-3 Local pruning and Parameter Co-adaption
Retraining 단계에서 remaining connection들을 초기화하여 retrain 하는 것보다 처음 training phase에서 얻은 weight들로 retrain하는 것이 훨씬 더 높은 성능을 띄는 것을 확인하였다.
gradient descent is able to find a good solution when the network is initially trained, but not after re-initializing some layers and retraining them. So when we retrain the pruned layers, we should keep the surviving parameters instead of re-initializing them.
Gradient descent가 initially trained된 network에서는 좋은 값을 찾지만, layer들을 재초기화하여 retraining 하는 network에서는 좋은 값을 찾지 못한다. 따라서 pruned layer를 retrain할때는 re-initialize하기 보다는 surviving parameter를 사용하는 것이 좋다.
Retraining the pruned layers starting with retained weights requires less computation because we don’t have to back propagate through the entire network. Also, neural networks are prone to suffer the vanishing gradient problem as the networks get deeper, which makes pruning errors harder to recover for deep networks. To prevent this, we fix the parameters for CONV layers and only retrain the FC layers after pruning the FC layers, and vice versa.
Retained된 (이미 값을 구한) weight을 사용하는 것이 계산량을 줄인다. (역전파 불필요) 또한, Neural Network는 깊어질 수록 vanishing gradient 문제에 취약해지는데, 이를 방지하기 위해 Convolutional layer의 파라미터를 고정하고, Fully-Connected layer를 pruning한 후 FC layer 만을 retrain 하는 방법(또는 이 반대)을 사용한다.
3-4 Local pruning and Parameter Co-adaption
필요한 connection을 고르는데에는 반복적인 process가 필요하다. Pruning 이후 retraining을 1개의 iteration이고 이 iteration의 많은 반복 이후에 minimum 개수의 connection을 얻는다.
3-5 Pruning Neurons
0의 값을 가진 neurons들은 input output connection이 0개이므로 final loss에서 사용되지 않기에 제거 대상이다. 이를 Dead Neuron이라 부르는데, dead neuron은 gradient descent과 Regularization을 통해 만들어진다. 이러한 neurons들은 자동적으로 retraining 단계에서 제거 된다.
4 Regularization
위 그림은 정확도와 파라미터 개수의 trade off 관계를 나타낸다. 파라미터가 pruned 될수록 Acc는 낮아지는 것을 알 수 있다. 위 사진에서는 L1 & L2 regularization을 w/, w/o retraining을 실험한 것으로, L2 regularization w/ iterative prune and retrain이 가장 좋은 성능을 나타낸 것을 알 수 있다.