import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

# 对 conv1 层进行剪枝
module = model.conv1
print("conv1 的参数:\n", list(module.named_parameters()))  # 查看 conv1 层有哪些参数，以及对应的参数值是什么

# 进行非结构化剪枝

# 对 bias 参数进行 L1非结构化剪枝
# 即随机把参数的某些值置为 0
# amount 表示剪枝比例，即将 30% 值随机置 0
# importance_scores 是用来指定重要性指标的接口
# 例如可以设置参数的绝对值信息作为指标
# 要求是该指标的形状必须与被剪枝的参数相同
importance_scores = module.bias.abs()
prune.l1_unstructured(module, name='bias', amount=0.3, importance_scores=importance_scores)

# named_buffers() 是 torch 中的一个用于获取模型或模块中所有注册的 buffers 及其名称的方法
# buffers 是模型中需要保存但不需要梯度更新的张量
# 这里保存的是剪枝的 mask，剪枝并不是直接把参数的值设为 0，而是通过 mask 机制把被剪枝的掩码置 0
# 这样，被剪枝的权重还可以被恢复，方法为 prune.remove()
print("bias 的 mask:\n", list(module.named_buffers()))
print("查看剪枝后的结果:\n", module.bias)    # 查看剪枝后的结果

# 查看原始权重
print(module.bias_orig)

# 恢复权重
prune.remove(module, name='bias')
print("查看恢复的结果:\n", module.bias)    # 查看恢复的结果
