日本a√视频在线,久久青青亚洲国产,亚洲一区欧美二区,免费g片在线观看网站

        <style id="k3y6c"><u id="k3y6c"></u></style>
        <s id="k3y6c"></s>
        <mark id="k3y6c"></mark>
          
          

          <mark id="k3y6c"></mark>

          "); //-->

          博客專欄

          EEPW首頁 > 博客 > 征程 6EM 常見 QConfig 配置解讀與示例

          征程 6EM 常見 QConfig 配置解讀與示例

          發(fā)布人:地平線開發(fā)者 時間:2025-06-01 來源:工程師 發(fā)布文章

          一、引言

          在工具鏈用戶手冊《量化感知訓(xùn)練(QAT)-開發(fā)指南-QConfig 詳解》章節(jié)專門介紹了在 J6EM 上 qconfig 是怎么回事,從經(jīng)歷看,大家可能會存在看了依舊不懂,或懂了不知道怎么配置的情況,特別是一些 OE 包中示例沒有的配置,例如固定某節(jié)點 scale、配置 linear weight int16 等操作。

          qconfig 控制了模型所有節(jié)點的量化類型,例如是采用 int8 還是 int16 量化,是固定校準(zhǔn)階段的 scale 去 qat 還是不固定 scale 去 qat。

          提供的模板可分為三類:基礎(chǔ)模板、敏感度模板、自定義模板。本文將常見配置通過示例方式進行呈現(xiàn)。

          二、基礎(chǔ)模板

          基礎(chǔ)模板中 calibration / qat / qat_fixed_act_scale 區(qū)別在于使用的 observer 類型和 scale 更新邏輯,分別用于校準(zhǔn),不固定 activation scaleqat 訓(xùn)練,固定 activation scale qat 訓(xùn)練。

          default 模板 ( default_calibration_qconfig_setter / default_qat_qconfig_setter / default_qat_fixed_act_qconfig_setter ) 會做三件事:

          首先,將可以設(shè)置的高精度輸出都設(shè)置上,對于不支持高精度的輸出將給出提示;

          然后,從 grid sample 算子的 grid 輸入向前搜索,直到出現(xiàn)第一個 gemm 類算子或者 QuantStub,將中間的所有算子都設(shè)置為 int16。根據(jù)經(jīng)驗這里的 grid 一般表達(dá)范圍較寬,int8 有較大可能不滿足精度需求;

          最后,將其余算子設(shè)置為 int8。

          int16 模板 ( qat_8bit_weight_16bit_act_qconfig_setter / qat_8bit_weight_16bit_fixed_act_qconfig_setter / calibration_8bit_weight_16bit_act_qconfig_setter ) 會做兩件事:

          首先,將可以設(shè)置的高精度輸出都設(shè)置上,對于不支持高精度的輸出將給出提示;

          其次,將其余算子設(shè)置為 int16。

          from horizon_plugin_pytorch.quantization.qconfig_template import (
             default_calibration_qconfig_setter,
             default_qat_qconfig_setter,
             default_qat_fixed_act_qconfig_setter,
             qat_8bit_weight_16bit_act_qconfig_setter,
             qat_8bit_weight_16bit_fixed_act_qconfig_setter,
             calibration_8bit_weight_16bit_act_qconfig_setter,
          )
          qat_or_calib_model = prepare(
             float_model,
             example_inputs=example_inputs,  # 用來感知圖結(jié)構(gòu)
             qconfig_setter=(

                 default_qat_qconfig_setter,    # 根據(jù)需要配置setter模板
             ),
          )

          三、敏感度模板

          敏感度模板有三個:

          sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter
          sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter
          sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter

          三者的區(qū)別和基礎(chǔ)模板中三者的區(qū)別類似,也是分別用于校準(zhǔn),不固定 activation scale qat 訓(xùn)練,固定 activation scale qat 訓(xùn)練。

          敏感度模板的第一個輸入是精度 debug 工具產(chǎn)生的敏感度結(jié)果,第二個參數(shù)可以指定 ratio 或 topk,敏感度模板會根據(jù)配置,將量化敏感度最高的 topk 個算子設(shè)置為 int16。搭配固定模板,可以實現(xiàn)混合精度調(diào)優(yōu)。

          若模型有多個輸出,每個輸出都會產(chǎn)生一個敏感度表,您可以設(shè)置多個敏感度模版。示例如下:

          from horizon_plugin_pytorch.quantization.qconfig_template import (
             default_calibration_qconfig_setter,
             sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter,
             sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter,
             sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter,
          )

          # 這兩個pt文件是通過debug工具得到的
          table1 = torch.load("output_0-0_L1_sensitive_ops.pt")
          table2 = torch.load("output_0-1_L1_sensitive_ops.pt")

          calibration_model = prepare(
             float_model,
             example_inputs=example_input,
             qconfig_setter=(
                 sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter(table1, ratio=0.2),
                 sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter(table2, ratio=0.2),
                 default_calibration_qconfig_setter,
             ),
          )

          四、自定義模板

          自定義模板為 ModuleNameQconfigSetter,需要傳入模塊名和對應(yīng)自定義的 qconfig,一般用于設(shè)置 fixed scale、配置 linear weight int16 等特殊需求,可以和固定模板,敏感度模板搭配使用。示例如下:

          from horizon_plugin_pytorch.quantization.qconfig_template import (
             calibration_8bit_weight_16bit_act_qconfig_setter,
             ModuleNameQconfigSetter,
          )
          from horizon_plugin_pytorch.quantization.qconfig import (
             get_qconfig,
             MSEObserver,
             MinMaxObserver,
             FixedScaleObserver,
             QConfig,
          )
          from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize

          # 手動設(shè)置某個算子的輸出scale
          op_name_output_fix_scale_qconfig = QConfig(
             output=FakeQuantize.with_args(
                 observer=FixedScaleObserver,
                 dtype=qint16,
                 scale=0.0625,
             )
          )

          # 設(shè)置某個算子weight與輸出activation的量化類型
          # 校準(zhǔn)時用MSEObserver,qat時用MinMaxObserver
          # 沒有weight的算子,配置了weight_dtype也不會起作用
          calib_weight_act_both_int16_qconfig = get_qconfig(
             observer=MSEObserver,
             weight_dtype=qint16,
             out_dtype=qint16,
          )

          calib_weight_act_both_int8_qconfig = get_qconfig(
             observer=MSEObserver,
             weight_dtype=qint8,
             out_dtype=qint8,
          )

          qat_weight_act_both_int16_qconfig = get_qconfig(
             observer=MinMaxObserver,
             weight_dtype=qint16,
             out_dtype=qint16,
             fix_scale=True,    # 是否固定scale
          )

          放在一塊簡單示例如下:

          from horizon_plugin_pytorch.quantization.qconfig_template import (
             default_qat_qconfig_setter,
             sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter,
             ModuleNameQconfigSetter,
          )

          table = torch.load("output_0-0_dataindex_1_sensitive_ops.pt")

          # 自動替換生成的算子只能通過 ModuleNameQconfigSetter 配置自定義 qconfig。
          module_name_to_qconfig = {
             "_generated_add_0": op_name_output_fix_scale_qconfig ,
          }

          qat_model = prepare(
             float_model,
             example_inputs=example_input,
             qconfig_setter=(
                 ModuleNameQconfigSetter(module_name_to_qconfig),
                 sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter(table, ratio=0.2),
                 default_qat_qconfig_setter,
             ),
          )

          五、可運行的示例

          將網(wǎng)絡(luò)中 linear2 的 weight 配置為 int16 量化、輸入配置為 int8 量化、輸出配置為 int16 量化,其他算子激活使用 int16 量化,weight 使用 int8 量化。

          import torch
          from horizon_plugin_pytorch import set_march, March
          set_march(March.NASH_M)
          from horizon_plugin_pytorch.quantization import prepare, set_fake_quantize, FakeQuantState
          from horizon_plugin_pytorch.quantization import QuantStub
          from horizon_plugin_pytorch.quantization.hbdk4 import export
          from horizon_plugin_pytorch.quantization.qconfig_template import calibration_8bit_weight_16bit_act_qconfig_setter, ModuleNameQconfigSetter
          from horizon_plugin_pytorch.quantization.qconfig import get_qconfig, MSEObserver, MinMaxObserver
          from horizon_plugin_pytorch.dtype import qint8, qint16
          from torch.quantization import DeQuantStub
          import torch.nn as nn


          # 定義網(wǎng)絡(luò)結(jié)構(gòu)
          class SmallModel(nn.Module):
             def __init__(self):
                 super(SmallModel, self).__init__()
                 # 第一個 Linear: 輸入 [2, 100, 256] -> 輸出 [2, 100, 256]
                 self.linear1 = nn.Linear(256, 256)
                 self.layernorm = nn.LayerNorm(256)  # 對最后一維進行歸一化
                 self.relu = nn.ReLU()
                 # 第二個 Linear: 輸入 [2, 100, 256] -> 輸出 [2, 100, 60]
                 self.linear2 = nn.Linear(256, 60)
                 # 第三個 Linear: 輸入 [2, 100, 60] -> 輸出 [2, 100, 60]
                 self.linear3 = nn.Linear(60, 60)
                 self.quant = QuantStub()
                 self.dequant = DeQuantStub()

             def forward(self, x):
                 x = self.quant(x)
                 # 第一個 Linear
                 x = self.linear1(x)  # [2, 100, 256]
                 x = self.layernorm(x)  # [2, 100, 256]
                 x = self.relu(x)  # [2, 100, 256]
                 # 第二個 Linear
                 x = self.linear2(x)  # [2, 100, 60]
                 # 第三個 Linear
                 x = self.linear3(x)
                 x = self.dequant(x)
                 return x

          example_input = torch.randn(2, 100, 256)
          model = SmallModel()

          # 前向傳播
          output = model(example_input)
          print("輸出形狀:", output.shape)

          # A global march indicating the target hardware version must be setted before prepare qat.
          set_march(March.NASH_M)

          calib_weight_act_both_int16_qconfig = get_qconfig(
             observer=MSEObserver,
             weight_dtype=qint16,
             out_dtype=qint16,
          )

          # layernorm沒有weight,配置了weight_dtype也不會起作用
          calib_weight_act_both_int8_qconfig = get_qconfig(
             observer=MSEObserver,
             weight_dtype=qint8,
             out_dtype=qint8,
          )

          qat_weight_act_both_int16_qconfig = get_qconfig(
             observer=MinMaxObserver,
             weight_dtype=qint16,
             out_dtype=qint16,
             fix_scale=True,
          )
          # 節(jié)點名稱,可以從model_check_result.txt中獲取,也可以從敏感度文件中獲取
          module_name_to_qconfig = {
             "layernorm": calib_weight_act_both_int8_qconfig,
             "linear2": calib_weight_act_both_int16_qconfig,  
          }

          calib_model = prepare(model.eval(), example_input,
                               qconfig_setter=(
                                   ModuleNameQconfigSetter(module_name_to_qconfig),
                                   calibration_8bit_weight_16bit_act_qconfig_setter,
                                   ),
                               )

          calib_model.eval()
          set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)
          calib_model(example_input)

          calib_model.eval()                            
          set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
          calib_out = calib_model(example_input)

          qat_bc = export(calib_model, example_input)

          配置 add 單算子輸入和輸出均使用固定 scale

          import torch
          from horizon_plugin_pytorch import set_march, March
          set_march(March.NASH_E)
          from horizon_plugin_pytorch.quantization import prepare, set_fake_quantize, FakeQuantState
          from horizon_plugin_pytorch.quantization import QuantStub
          from horizon_plugin_pytorch.quantization.hbdk4 import export
          from horizon_plugin_pytorch.quantization.qconfig_template import calibration_8bit_weight_16bit_act_qconfig_setter, ModuleNameQconfigSetter
          from horizon_plugin_pytorch.quantization.qconfig import get_qconfig, MSEObserver, MinMaxObserver, FixedScaleObserver, QConfig
          from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize
          from horizon_plugin_pytorch.dtype import qint8, qint16
          from torch.quantization import DeQuantStub
          import torch.nn as nn


          class AddNet(nn.Module):
             def __init__(self):
                 super(AddNet, self).__init__()
                 self.quant_x = QuantStub()
                 self.quant_y = QuantStub()
                 self.dequant = DeQuantStub()

             def forward(self, x, y):
                 x = self.quant_x(x)
                 y = self.quant_y(y)
                 z = torch.add(x, y)
                 z = self.dequant(z)
                 return z

          # 創(chuàng)建模型
          model = AddNet()

          # 生成兩個相同形狀的輸入張量
          torch.manual_seed(42)
          x = torch.randn(1, 1, 2, 6)
          y = torch.randn(1, 2, 2, 6)
          example_input = (x,y)

          # 前向傳播
          output = model(example_input[0], example_input[1])
          print("float輸出數(shù)據(jù):", output)
          print("輸入形狀:", example_input[0].shape)
          print("輸出形狀:", output.shape)

          # A global march indicating the target hardware version must be setted before prepare qat.
          set_march(March.NASH_E)

          add_input_fix_scale_qconfig = QConfig(
             output=FakeQuantize.with_args(
                 observer=FixedScaleObserver,
                 dtype=qint16,
                 scale=0.03125,
             )
          )
          add_output_fix_scale_qconfig = QConfig(
             output=FakeQuantize.with_args(
                 observer=FixedScaleObserver,
                 dtype=qint16,
                 scale=0.0625,
             )
          )

          # 節(jié)點名稱,可以從model_check_result.txt中獲取,也可以從敏感度文件中獲取
          module_name_to_qconfig = {
             "quant_x": add_input_fix_scale_qconfig,

             "quant_y": add_input_fix_scale_qconfig,

             "_generated_add_0": add_output_fix_scale_qconfig,
          }

          calib_model = prepare(model.eval(), example_input,
                               qconfig_setter=(
                                   ModuleNameQconfigSetter(module_name_to_qconfig),
                                   calibration_8bit_weight_16bit_act_qconfig_setter,
                                   ),
                               )

          calib_model.eval()
          set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)
          calib_model(example_input[0], example_input[1])

          calib_model.eval()                            
          set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
          calib_out = calib_model(example_input[0], example_input[1])
          print("calib輸出數(shù)據(jù):", calib_out)

          qat_bc = export(calib_model, example_input)

          六、凍結(jié)部分網(wǎng)絡(luò)結(jié)構(gòu) qat 的配置

          補充常見凍結(jié)網(wǎng)絡(luò)結(jié)構(gòu),去進行 qat 的做法

          from horizon_plugin_pytorch.quantization import (
             QuantStub,
             prepare,
             set_fake_quantize,
             FakeQuantState,
          )
          #prepare QAT模型
          qat_model = prepare(
             model,
             example_inputs=xxx,
             qconfig_setter=(
                 xxx,
             )
          )
          #加載calib權(quán)重
          qat_model.load_state_dict(torch.load("calib-checkpoint.ckpt"))
          #QAT訓(xùn)練
          qat_model.train()
          #固定backbone部分的權(quán)重,requires_grad不影響drop bn的行為,需要與eval聯(lián)合用
          for param in qat_model.backbone.parameters():
             param.requires_grad = False
          #固定backbone部分的scale,eval只影響drop bn的行為,如果發(fā)生了backward仍然會改變權(quán)重,需要與requires_grad聯(lián)合使用
          qat_model.backbone.eval()
          set_fake_quantize(qat_model.backbone, FakeQuantState.VALIDATION)
          #配置head的FakeQuant為QAT狀態(tài)
          set_fake_quantize(qat_model.head, FakeQuantState.QAT)


          *博客內(nèi)容為網(wǎng)友個人發(fā)布,僅代表博主個人觀點,如有侵權(quán)請聯(lián)系工作人員刪除。




          相關(guān)推薦

          技術(shù)專區(qū)

          關(guān)閉