해당 코드를 참고하여 작성하였습니다. https://github.com/jack-willturner/deep-compression
논문에 대한 요약은 https://shshin9812.tistory.com/22 참고바랍니다.
Pruning을 적용하여 학습하기 위해서는 train.py와 prune.py를 사용해준다. train.py는 주요 connections들을 배우는 학습 과정이므로 생략하고 prune.py 부터 살펴보자.
python train.py --model='resnet34' --checkpoint='resnet34'
python prune.py --model='resnet34' --checkpoint='resnet34'
prune.py
prune.py에 있는 주요 Argument parsing부터 살펴보자. 위 코드를 적용하여 해석해보면 model은 resnet34를 사용하고 resnet34(model이름명)을 사용하여 pruned model을 저장한다. Pruning type은 unstructured를 사용하고, prune_iters를 통해 prune을 100번 반복하는 것을 알 수 있다.
################################################################## ARGUMENT PARSING
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 pruning")
parser.add_argument(
"--model",
default="resnet18",
help="resnet9, resnet18, resnet34, resnet50, wrn_40_2, wrn_16_2, wrn_40_1",
)
parser.add_argument("--data_loc", default="./datasets/cifar10/", type=str)
parser.add_argument(
"--checkpoint", default=None, type=str, help="Pretrained model to start from"
)
parser.add_argument(
"--prune_checkpoint", default=None, type=str, help="Where to save pruned models"
)
parser.add_argument("--n_gpus", default=0, type=int, help="Number of GPUs to use")
parser.add_argument(
"--save_every",
default=5,
type=int,
help="How often to save checkpoints in number of prunes (e.g. 10 = every 10 prunes)",
)
parser.add_argument("--seed", default=1, type=int)
parser.add_argument("--cutout", action="store_true")
### pruning specific args
parser.add_argument("--pruner", default="L1Pruner", type=str)
parser.add_argument(
"--pruning_type",
default="unstructured",
type=str,
help="structured or unstructured",
)
parser.add_argument(
"--prune_iters",
default=100,
help="how many times to repeat the prune->finetune process",
)
parser.add_argument(
"--target_prune_rate",
default=99,
type=int,
help="Percentage of parameters to prune",
)
parser.add_argument("--finetune_steps", default=100)
parser.add_argument("--lr", default=0.001)
parser.add_argument("--weight_decay", default=0.0005, type=float)
args = parser.parse_args()
아래 코드를 보면, prune_rates는 0부터 99까지 정수를 numpy 형태로 지니고 있으며, for문을 통해 0부터 99까지 pruner.prune, finetune, validate을 반복한다. 0~99의 숫자를 가지는 prune_rate이 5의 배수가 되면 아래와 같은 방식으로 checkpoint에 저장해준다. 5의 배수가 아닌 경우에는 저장을 하지 않는다. 이후 finetuning을 해주고, checkpoint가 True이면 즉, 저장되어있는 값이 있다면, validation 을 해준다.
prune_rates = np.linspace(0, args.target_prune_rate, args.prune_iters) # 0~99 까지
for prune_rate in tqdm(prune_rates): # 0 ~ 99
# pruner = L1pruner(unstructured)
# pruning 단계
pruner.prune(model, prune_rate)
# 5
if prune_rate % args.save_every == 0:
checkpoint = args.prune_checkpoint + str(prune_rate)
else:
checkpoint = None # don't bother saving anything
finetune(model, trainloader, criterion, optimizer, args.finetune_steps)
if checkpoint:
validate(model, prune_rate, testloader, criterion, checkpoint=checkpoint)
l1_pruner.py
pruner.prune(model, prune_rate) 을 자세히 보면, model은 여기서 resnet34이고 prune_rate은 0~99까지의 for문에서의 숫자이다. pruner는 l1_pruner.py의 L1Pruner Class를 나타내므로 l1_pruner.py를 보면 아래 코드와 같다.
L1Pruner Class안에있는 prune 함수는 "unstructed"인 경우 unstructured_prune 함수를 불러온다.
def prune(self, model, prune_rate):
# self.pruning_type = unstructrued
if self.pruning_type.lower() == "unstructured":
self.unstructured_prune(model, prune_rate)
elif self.pruning_type.lower() == "structured":
self.structured_prune(model, prune_rate)
else:
raise ValueError("Invalid type of pruning")
unstructured_prune 함수 코드는 아래와 같다. resnet.py에 있는 resnet34 모델에서 정의된 get_prunable_layers (나중에 resnet.py에서 설명예정) 를 통해 convolutional layers들을 convs에 정의해준다.
convs에 있는 conv ( 각 convolutional layer ) 에 대해 weights들을 concat 해주고 이를 all_weights에 저장한다. 예를들어 (64,3,3,3) 인 convolutional layer의 weight(conv.conv.weight) .view(-1) 를 통해 1차원 tensor를 만들어주고 (즉, [1728] ) 이를 이전에 정의되었던 all_weight과 concat해준다. 모든 conv의 weight에 대해 concat를 수행해주고 난 후 detach를 한 것을 abs_weight에 저장한다.
Threshold에는 prune_rate 만큼의 퍼센테지에 해당하는 data 값을 계산하여 기준점으로 삼아준다.
model.get_prunable_layers에서 받아온 convolutional layer는 conv_bn_relu.py에 정의된 ConvBNReLU class이므로, mask.update 함수를 통해 threshold보다 높은 값의 conv.weight 만 (낮은 값은 0) conv.mask.mask.weight과 곱해주어 이 값을 mask로 update 해준다 (new mask) .
def unstructured_prune(self, model, prune_rate=50.0):
# get all the prunable convolutions
convs = model.get_prunable_layers(pruning_type=self.pruning_type)
# collate all weights into a single vector so l1-threshold can be calculated
all_weights = torch.Tensor()
if torch.cuda.is_available():
all_weights = all_weights.cuda()
for conv in convs:
all_weights = torch.cat((all_weights.view(-1), conv.conv.weight.view(-1))) # 1728 + 36864 (64,64,3,3)
abs_weights = torch.abs(all_weights.detach())
# abs : absolute value
# abs_weights on cuda
threshold = np.percentile(abs_weights.cpu(), prune_rate) # computing prune_rate'th percentile of data
# prune anything beneath l1-threshold
for conv in model.get_prunable_layers(pruning_type=self.pruning_type):
# conv.conv.weight on cuda,
# torch.gt(input, other, *, out=None) input > other : returns boolean / false = 0, true = 1
# conv.mask.update -> conv.mask.mask.weight * 0 or 1
conv.mask.update(
torch.mul(
torch.gt(torch.abs(conv.conv.weight), threshold).float(),
conv.mask.mask.weight.cuda(),
)
)
conv_bn_relu.py
위의 코드와 좀 더 깊은 이해를 위해 conv_bn_relu.py를 살펴보자. unstructured_prune 함수에서는 model의 convolutional layer인 ConvBNReLU Class 와 Class 안 함수들을 불러온다.
ConvBNReLU Class : ConvBNReLU 함수는 nn.Conv2d를 통해 convolutional layer를 정의(self.conv)하고, self.bn 에는 BatchNorm2d를, relu가 True라면 ReLU 함수를 self.relu에 정의한다.
forward 함수에서 알 수 있듯, 이 convolutional layer에 mask를 적용(self.mask.apply ← UnstructuredMask Class 안 apply 함수)하는데, conv의 weight와 mask.weight을 곱해주어 conv.weight.data에 저장한다. update되는 weight값이 없다면, mask는 1로 구성된 tensor 값으로, conv.weight.data에 변화가 없다.
weight가 새롭게 적용된 conv와 bn, relu를 거친 input값을 return 해준다.
class ConvBNReLU(nn.Module):
def __init__(
self,
in_planes,
planes,
kernel_size=3,
stride=1,
padding=1,
bias=False,
relu=True,
):
super(ConvBNReLU, self).__init__()
self.conv = nn.Conv2d(
in_planes,
planes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias,
)
self.bn = nn.BatchNorm2d(planes)
if relu:
self.relu = nn.ReLU()
else:
self.relu = nn.Identity()
self.mask = UnstructuredMask(
in_planes, planes, kernel_size, stride, padding, bias
)
def forward(self, x):
self.mask.apply(self.conv, self.bn)
return self.relu(self.bn(self.conv(x)))
class UnstructuredMask:
def __init__(self, in_planes, planes, kernel_size, stride, padding, bias=None):
self.mask = nn.Conv2d(
in_planes,
planes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=False,
)
self.mask.weight.data = torch.ones(self.mask.weight.size()) # weight size(64,64,3,3) 만큼 1 생성
def update(self, new_mask):
self.mask.weight.data = new_mask
def apply(self, conv, bn=None):
conv.weight.data = torch.mul(conv.weight, self.mask.weight.cuda())
# conv.weight in cuda , self.mask.weight not in cuda <- figured
resnet.py
https://shshin9812.tistory.com/3 : ResNet에 대한 간략한 설명 참고!
resnet.py 을 보면 다양한 depth를 가진 Network가 정의되어있다. 이 중 코드에서 사용된 resnet34 를 중점적으로 보려고 한다.
Shortcut Class를 먼저 살펴보자. 들어온 input값은 self.conv_bn 을 지나게 되는데, 이는 ConvBNReLU Class로 정의된 convolutional - batchnormalization 와의 maks.apply 와 convolutional - batchnormalization - relu 로 이루어진 layer이다.
class Shortcut(nn.Module):
def __init__(
self,
in_planes,
planes,
expansion=1,
kernel_size=1,
stride=1,
padding=0,
bias=False,
):
super(Shortcut, self).__init__()
self.conv_bn = ConvBNReLU(
in_planes,
expansion * planes,
kernel_size=kernel_size,
stride=stride,
padding=0,
bias=False,
relu=False,
)
def forward(self, x):
return self.conv_bn(x)
resnet34는 BasicBlock을 사용하므로 BasicBlock만 살펴보자. forward 함수에서 알 수 있듯, input 값은 self.conv1 와 self.conv2 를 순차적으로 통과한 뒤, self.shortcut을 지난 처음 input 값을 통과한 결과 값과 더해준다. 이후 relu 함수를 통과한다. 여기서 self.shortcut은 input과 output의 dimension이 다른 경우 또는 stride가 1이 아닌 경우에 nn.Identity() 가 아닌 Shortcut Class 에 정의된 self.conv_bn 이다.
여기서 정의된 get_prunable_layers는 "strutured" 와 "unstructured" 로 나뉘는데, "unstructured"의 경우 shortcut이 Identity가 아닌 경우와 맞는 경우로 나누어 layer들을 return 해준다.
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = ConvBNReLU(
in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
)
self.conv2 = ConvBNReLU(
planes, planes, kernel_size=3, stride=1, padding=1, bias=False, relu=False
)
self.shortcut = nn.Identity() # returns input
# input =/= output dimension -> self.shortcut = Shortcut()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = Shortcut(
in_planes,
planes,
self.expansion,
kernel_size=1,
stride=stride,
bias=False,
)
def forward(self, x):
out = self.conv1(x)
out = self.conv2(out)
out += self.shortcut(x)
out = F.relu(out)
return out
def get_prunable_layers(self, pruning_type="unstructured"):
if pruning_type == "unstructured":
# if self.shortcut is nn.Identity
if isinstance(self.shortcut, nn.Identity):
return [self.conv1, self.conv2]
# if self.shortcut is not nn.Identity
else:
return [self.conv1, self.conv2, self.shortcut.conv_bn]
elif pruning_type == "structured":
return [self.conv1]
else:
raise NotImplementedError
드디어 ResNet Class를 보자. (ResNet의 기본 구조) forward 함수를 보면 먼저 input 값은 ConvBNReLU로 정의된 self.conv_bn_relu를 통과한다. 이후 self.layer1에 정의된 layer들을 통과하는데, self.layer1은 basicblock이 num_block[0] 개수만큼 있다. 마찬가지로 self.layer2, self.layer3, self.layer4 도 여러 blocks 들로 이루어져 있다. 각 self.layer 안 layer들을 통과한 이후에는 avg_pool2d 와 self.linear 함수를 지난다.
여기서 정의된 get_prunable_layers의 "unstructured"의 경우 convs list에 conv_bn_relu를 append해준다. 이후 각 self.layer 안에 있는 layer(BasicBlock 안 get_prunable_layers 함수와 같이)들을 마찬가지로 append 해준다.
class ResNet(nn.Module):
def __init__(self, block, num_blocks, num_classes=10):
super(ResNet, self).__init__()
self.in_planes = 64
self.conv_bn_relu = ConvBNReLU(
3, 64, kernel_size=3, stride=1, padding=1, bias=False
)
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.linear = nn.Linear(512 * block.expansion, num_classes)
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
out = self.conv_bn_relu(x)
self.activations = []
for layer in self.layer1:
out = layer(out)
self.activations.append(out)
for layer in self.layer2:
out = layer(out)
self.activations.append(out)
for layer in self.layer3:
out = layer(out)
self.activations.append(out)
for layer in self.layer4:
out = layer(out)
self.activations.append(out)
out = F.avg_pool2d(out, 4)
out = out.view(out.size(0), -1)
out = self.linear(out)
return out
def get_prunable_layers(self, pruning_type="unstructured"):
convs = []
if pruning_type == "unstructured":
convs.append(self.conv_bn_relu)
for stage in [self.layer1, self.layer2, self.layer3, self.layer4]:
for layer in stage:
for conv in layer.get_prunable_layers(pruning_type):
convs.append(conv)
return convs
'코드 구현 및 오류' 카테고리의 다른 글
[코드 구현] CIFAR100 Image Classification with Noisy Labels (0) | 2022.05.14 |
---|---|
[논문 구현] MobileNetV2 논문 구현 (0) | 2022.04.17 |
[코드 오류] LiteHRNet 설정 오류 (0) | 2022.04.11 |
댓글