PTQ (Post Training Quantization) Source Code Reading, Part 1
I've been working on model quantization recently, so I decided to study the principles and code behind PTQ. There are already many good articles on the theory of PTQ, and I may write my own summary of the principles later. This post focuses on the code that implements PTQ.
Before walking through the code, let's look at how PTQ is used:

```python
# load model
model = load_model(model_path)
model.eval()

# register quant_forward_post_hook in forward_post_hooks
ptq = PTQ()
model = ptq.quantize(model)

# calibration
for key, input in reader:
    model(input)

# compute quant params
ptq._convert(model)

# save quant model
jit.save(model, quant_model_path)
```
Let's first look at how activation quantization statistics are collected.

ImperativePTQ

```python
class ImperativePTQ(object):
    """
    Static post training quantization.
    """

    def __init__(self, quant_config=ptq_config.default_ptq_config):
        """
        Constructor.

        Args:
            quant_config(PTQConfig): the config of post training quantization.
                The config has weight_quantizer and activation_quantizer.
                In default, the weight_quantizer is PerChannelAbsmaxQuantizer
                and the activation_quantizer is KLQuantizer.
        """
        super().__init__()

        assert isinstance(quant_config, ptq_config.PTQConfig)

        self._quant_config = quant_config
```
ImperativePTQ is the implementation class of PTQ. Its input parameter quant_config specifies the quantization methods for weights and activations. By default the activation_quantizer is KLQuantizer and the weight_quantizer is PerChannelAbsmaxQuantizer.

```python
class PTQConfig(object):
    """
    The PTQ config shows how to quantize the inputs and outputs.
    """

    def __init__(self, activation_quantizer, weight_quantizer):
        """
        Constructor.

        Args:
            activation_quantizer(BaseQuantizer): The activation quantizer.
                It should be the instance of BaseQuantizer.
            weight_quantizer(BaseQuantizer): The weight quantizer.
                It should be the instance of BaseQuantizer.
        """
        super().__init__()
        assert isinstance(activation_quantizer, tuple(SUPPORT_ACT_QUANTIZERS))
        assert isinstance(weight_quantizer, tuple(SUPPORT_WT_QUANTIZERS))

        self.in_act_quantizer = copy.deepcopy(activation_quantizer)
        self.out_act_quantizer = copy.deepcopy(activation_quantizer)
        self.wt_quantizer = copy.deepcopy(weight_quantizer)

        self.quant_hook_handle = None

        # In order to wrap simulated layers, use in_act_quantizer
        # to calculate the input thresholds for conv2d, linear and etc.
        self.enable_in_act_quantizer = False


default_ptq_config = PTQConfig(KLQuantizer(), PerChannelAbsmaxQuantizer())
```
Here quant_hook_handle is the handle of the Layer's forward post hook.
enable_in_act_quantizer controls whether in_act_quantizer is used to compute quantization parameters for the input activations. By default, activations use the KLQuantizer and weights use the PerChannelAbsmaxQuantizer.
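To use different quantizers, you can build your own PTQConfig and pass it to ImperativePTQ. Below is a minimal sketch; the import paths are assumptions based on the source tree being read here and may differ across Paddle/PaddleSlim versions:

```python
# NOTE: import paths are assumptions; adjust to your Paddle version.
from paddle.fluid.contrib.slim.quantization.imperative.ptq import ImperativePTQ
from paddle.fluid.contrib.slim.quantization.imperative.ptq_config import PTQConfig
from paddle.fluid.contrib.slim.quantization.imperative.ptq_quantizer import (
    HistQuantizer,
    PerChannelAbsmaxQuantizer,
)

# Swap the default KLQuantizer for a histogram-based activation quantizer;
# keep the default per-channel abs-max quantizer for weights.
config = PTQConfig(HistQuantizer(), PerChannelAbsmaxQuantizer())
ptq = ImperativePTQ(quant_config=config)
```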
_is_skip_layer and _is_quant_layer

A model is usually built by stacking layers, with framework primitives such as nn.Conv2D and nn.Linear serving as the basic building blocks. During quantization we need to know which layers should be quantized and which should not. The two static methods _is_skip_layer and _is_quant_layer answer this (a usage sketch follows below):

```python
@staticmethod
def _is_skip_layer(layer):
    return hasattr(layer, "skip_quant") and layer.skip_quant == True

@staticmethod
def _is_quant_layer(layer):
    return hasattr(layer, "_quant_config")
```

is_leaf_layer

```python
def is_leaf_layer(layer):
    """
    Whether the layer is leaf layer.
    """
    return isinstance(layer, paddle.nn.Layer) and len(layer.sublayers()) == 0
```

A layer is a leaf when its sublayers() list is empty.
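For instance, setting skip_quant on a sublayer before calling quantize makes _is_skip_layer return True for it, so it is left unquantized. A small illustration (the model and layer names below are made up):

```python
import paddle.nn as nn

class TinyNet(nn.Layer):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2D(3, 8, 3)
        self.head = nn.Linear(8, 10)

model = TinyNet()
model.head.skip_quant = True  # _is_skip_layer(model.head) is now True

for name, layer in model.named_sublayers():
    print(name, "skipped" if getattr(layer, "skip_quant", False) else "candidate")
```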
quantize

Now let's look at the entry point for quantizing a model. model is the model instance, inplace specifies whether to modify the input model in place, and fuse / fuse_list let the user fuse layers before quantization. The method ultimately returns the processed model, instrumented to collect activation statistics for each layer.

```python
def quantize(self, model, inplace=False, fuse=False, fuse_list=None):
    """
    Add quant config and hook to the target layer.

    Args:
        model(paddle.nn.Layer): The model to be quantized.
        inplace(bool): Whether apply quantization to the input model.
                       Default: False.
        fuse(bool): Whether to fuse layers. Default: False.
        fuse_list(list): The layers' names to be fused. For example,
            "fuse_list = [["conv1", "bn1"], ["conv2", "bn2"]]".
            A TypeError would be raised if "fuse" was set as True but
            "fuse_list" was None. Default: None.
    Return
        quantized_model(paddle.nn.Layer): The quantized model.
    """
    assert isinstance(
        model, paddle.nn.Layer
    ), "The model must be the instance of paddle.nn.Layer."

    if not inplace:
        model = copy.deepcopy(model)
    if fuse:
        model.eval()
        model = fuse_utils.fuse_layers(model, fuse_list)

    for name, layer in model.named_sublayers():
        if (
            PTQRegistry.is_supported_layer(layer)
            and utils.is_leaf_layer(layer)
            and not self._is_skip_layer(layer)
        ):
            # Add quant config
            quant_config = copy.deepcopy(self._quant_config)
            if PTQRegistry.is_simulated_quant_layer(layer):
                ## quant activation
                quant_config.enable_in_act_quantizer = True
            layer._quant_config = quant_config

            # register hook
            hook = ptq_hooks.quant_forward_post_hook
            quant_hook_handle = layer.register_forward_post_hook(hook)
            quant_config.quant_hook_handle = quant_hook_handle
            layer._forward_post_hooks.move_to_end(
                quant_hook_handle._hook_id, last=False
            )

    return model
```
The loop first iterates over all sublayers and checks each layer for:
- whether quantization is supported for it;
- whether it is a leaf layer;
- whether it is marked to be skipped.
PTQRegistry is essentially a dictionary; we will look at its implementation later.
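Conceptually, it maps layer types to quantization information. A simplified sketch of the idea (hypothetical, not the real PTQRegistry code):

```python
import paddle.nn as nn

# Hypothetical, simplified: which layer types PTQ can handle at all,
# and which of those are simulated quant layers (inputs/weights quantized).
_SUPPORTED_LAYERS = (nn.Conv2D, nn.Linear, nn.ReLU, nn.MaxPool2D)
_SIMULATED_LAYERS = (nn.Conv2D, nn.Linear)

def is_supported_layer(layer):
    return isinstance(layer, _SUPPORTED_LAYERS)

def is_simulated_quant_layer(layer):
    return isinstance(layer, _SIMULATED_LAYERS)
```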
If all of the above hold, quantization handling is added to the layer:
- The quantization config quant_config is saved on the layer. If the layer is a simulated quant layer (one whose inputs/weights are quantized), enable_in_act_quantizer is turned on.
- A forward post hook is registered on the layer via register_forward_post_hook; its implementation is ptq_hooks.quant_forward_post_hook. The hook is then moved to the front of _forward_post_hooks (move_to_end with last=False) so it runs before any other post hooks.
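As a standalone illustration of the hook mechanism itself (independent of PTQ), a forward post hook receives the layer, its inputs, and its outputs after each forward pass:

```python
import paddle
import paddle.nn as nn

def print_shapes_hook(layer, inputs, outputs):
    # inputs is a tuple of tensors; outputs is whatever forward() returned
    print(type(layer).__name__, [x.shape for x in inputs], outputs.shape)

fc = nn.Linear(4, 2)
handle = fc.register_forward_post_hook(print_shapes_hook)
fc(paddle.randn([1, 4]))  # the hook fires right after forward()
handle.remove()           # detach the hook when it is no longer needed
```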
Let's look at the implementation of quant_forward_post_hook:

```python
def quant_forward_post_hook(layer, inputs, outputs):
    """
    The forward_post_hook for PTQ.
    """
    assert hasattr(
        layer, "_quant_config"
    ), "The layer should have _quant_config attr"

    qc = layer._quant_config
    if qc.enable_in_act_quantizer:
        qc.in_act_quantizer.sample_data(layer, inputs)
    qc.out_act_quantizer.sample_data(layer, (outputs,))
```
After the forward pass finishes, qc.out_act_quantizer collects the activation data of outputs.

Whether the activation data of inputs is also collected depends on the qc.enable_in_act_quantizer setting.

As we saw above, qc.enable_in_act_quantizer is True only when PTQRegistry.is_simulated_quant_layer(layer) is True (currently only for nn.Conv2D / nn.Linear).
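To build intuition for what sample_data does, here is a rough sketch in the style of an abs-max quantizer (illustrative only, not the actual KLQuantizer code):

```python
import paddle

class AbsmaxLikeQuantizer:
    """Sketch: keep a running abs-max for each sampled tensor position."""

    def __init__(self):
        self.abs_max_vals = []

    def sample_data(self, layer, tensors):
        # 'tensors' is always a tuple, matching how the hook passes
        # 'inputs' and '(outputs,)'.
        for i, t in enumerate(tensors):
            abs_max = float(paddle.max(paddle.abs(t)))
            if i < len(self.abs_max_vals):
                self.abs_max_vals[i] = max(self.abs_max_vals[i], abs_max)
            else:
                self.abs_max_vals.append(abs_max)
```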
We will discuss the implementations of KLQuantizer and PerChannelAbsmaxQuantizer later.
At this point, every layer has been processed and the model object is returned. The next step is to feed calibration data through model to collect the activation distributions.
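Calibration is then just ordinary inference over representative data; every forward pass triggers the hooks registered above. A sketch (calib_loader and the batch budget are placeholders):

```python
import paddle

quant_model = ptq.quantize(model)
quant_model.eval()

with paddle.no_grad():
    for step, (images, _) in enumerate(calib_loader):  # your calibration data
        quant_model(images)
        if step + 1 >= 100:  # a limited number of batches usually suffices
            break
```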