From ffc9e29844207f4befe548d79fcf18b46f2142e9 Mon Sep 17 00:00:00 2001 From: Jongsoo Park Date: Mon, 15 Apr 2019 14:35:25 -0700 Subject: [PATCH] unit test with multiple op invocations (#19118) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19118 A bug introduced by D14700576 reported by Yufei (fixed by D14778810 and D14785256) was not detected by our unit tests. This diff improves unit tests to catch such errors (with this diff and without D14778810, we can reproduce the bug Yufei reported). This improvement also revealed a bug that affects the accuracy when we pre-pack weight and bias together and the pre-packed weight/bias are used by multiple nets. We were modifying the pre-packed bias in place, even though it was supposed to be constant. Reviewed By: csummersea Differential Revision: D14806077 fbshipit-source-id: aa9049c74b6ea98d21fbd097de306447a662a46d --- .../server/conv_depthwise_dnnlowp_op_test.py | 21 +++------ .../server/conv_dnnlowp_acc16_op_test.py | 24 +++++------ caffe2/quantization/server/conv_dnnlowp_op.cc | 5 ++- caffe2/quantization/server/conv_dnnlowp_op_test.py | 46 ++++++-------------- .../server/conv_groupwise_dnnlowp_acc16_op_test.py | 24 +++++------ .../server/conv_groupwise_dnnlowp_op_test.py | 25 +++++------ caffe2/quantization/server/dnnlowp_test_utils.py | 50 +++++++++++++++++++++- .../fully_connected_dnnlowp_acc16_op_test.py | 19 +++----- .../server/fully_connected_dnnlowp_op.cc | 5 ++- .../server/fully_connected_dnnlowp_op_test.py | 10 ++--- .../fully_connected_rowwise_dnnlowp_op_test.py | 10 ++--- 11 files changed, 117 insertions(+), 122 deletions(-) diff --git a/caffe2/quantization/server/conv_depthwise_dnnlowp_op_test.py b/caffe2/quantization/server/conv_depthwise_dnnlowp_op_test.py index fe1b886..2cbbdcd 100644 --- a/caffe2/quantization/server/conv_depthwise_dnnlowp_op_test.py +++ b/caffe2/quantization/server/conv_depthwise_dnnlowp_op_test.py @@ -10,6 +10,7 @@ from dnnlowp_test_utils import ( 
check_quantized_results_close, generate_conv_inputs, generate_convnd_inputs, + run_conv_or_fc, ) from hypothesis import given @@ -159,13 +160,9 @@ class DNNLowPOpConvDepthWiseTest(hu.HypothesisTestCase): ) net.Proto().op.extend([relu_op]) - self.ws.create_blob("X").feed(X, device_option=gc) - self.ws.create_blob("W").feed(W, device_option=gc) - self.ws.create_blob("b").feed(b, device_option=gc) - self.ws.run(init_net) - self.ws.run(net) - Y = self.ws.blobs["Y"].fetch() - outputs.append(Output(Y=Y, op_type=op_type, engine=engine, order=order)) + run_conv_or_fc( + self, init_net, net, X, W, b, op_type, engine, order, gc, outputs + ) check_quantized_results_close(outputs, symmetric=preserve_activation_sparsity) @@ -294,12 +291,8 @@ class DNNLowPOpConvDepthWiseTest(hu.HypothesisTestCase): ) net.Proto().op.extend([dequantize]) - self.ws.create_blob("X").feed(X, device_option=gc) - self.ws.create_blob("W").feed(W, device_option=gc) - self.ws.create_blob("b").feed(b, device_option=gc) - self.ws.run(init_net) - self.ws.run(net) - Y = self.ws.blobs["Y"].fetch() - outputs.append(Output(Y=Y, op_type=op_type, engine=engine, order=order)) + run_conv_or_fc( + self, init_net, net, X, W, b, op_type, engine, order, gc, outputs + ) check_quantized_results_close(outputs, symmetric=preserve_activation_sparsity) diff --git a/caffe2/quantization/server/conv_dnnlowp_acc16_op_test.py b/caffe2/quantization/server/conv_dnnlowp_acc16_op_test.py index d14b7dc..336508d 100644 --- a/caffe2/quantization/server/conv_dnnlowp_acc16_op_test.py +++ b/caffe2/quantization/server/conv_dnnlowp_acc16_op_test.py @@ -7,7 +7,10 @@ import hypothesis.strategies as st import numpy as np from caffe2.python import core, dyndep, utils, workspace from caffe2.quantization.server import utils as dnnlowp_utils -from dnnlowp_test_utils import check_quantized_results_close +from dnnlowp_test_utils import ( + check_quantized_results_close, + run_conv_or_fc, +) from hypothesis import assume, given @@ -189,12 +192,9 @@ 
class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase): ) net.Proto().op.extend([dequantize]) - self.ws.create_blob("X").feed(X, device_option=gc) - self.ws.create_blob("W").feed(W, device_option=gc) - self.ws.create_blob("b").feed(b, device_option=gc) - self.ws.run(net) - Y = self.ws.blobs["Y"].fetch() - outputs.append(Output(Y=Y, op_type=op_type, engine=engine, order=order)) + run_conv_or_fc( + self, None, net, X, W, b, op_type, engine, order, gc, outputs + ) check_quantized_results_close(outputs, symmetric=preserve_activation_sparsity) @@ -374,12 +374,8 @@ class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase): ) net.Proto().op.extend([dequantize]) - self.ws.create_blob("X").feed(X, device_option=gc) - self.ws.create_blob("W").feed(W, device_option=gc) - self.ws.create_blob("b").feed(b, device_option=gc) - self.ws.run(init_net) - self.ws.run(net) - Y = self.ws.blobs["Y"].fetch() - outputs.append(Output(Y=Y, op_type=op_type, engine=engine, order=order)) + run_conv_or_fc( + self, init_net, net, X, W, b, op_type, engine, order, gc, outputs + ) check_quantized_results_close(outputs, symmetric=preserve_activation_sparsity) diff --git a/caffe2/quantization/server/conv_dnnlowp_op.cc b/caffe2/quantization/server/conv_dnnlowp_op.cc index 635a63e..19a2851 100644 --- a/caffe2/quantization/server/conv_dnnlowp_op.cc +++ b/caffe2/quantization/server/conv_dnnlowp_op.cc @@ -245,8 +245,7 @@ void ConvDNNLowPOp::QuantizeBias_() { if (has_packed_bias) { const auto& packed_filter = this->template Input(FILTER); - b_quantized_ = packed_filter.bias; - b_quantized_data_ = b_quantized_->data(); + b_quantized_data_ = packed_filter.bias->data(); } else { const auto& bias = InputTensorCPU_(BIAS); if (this->template InputIsType(BIAS)) { @@ -290,6 +289,8 @@ void ConvDNNLowPOp::QuantizeBias_() { if (this->order_ == StorageOrder::NHWC && in_qparams_[INPUT].zero_point && column_offsets_->empty()) { if (b_quantized_->empty()) { + // When b_quantized_data_ is from pre-packed bias or Int8TensorCPU, 
+ // we can't inplace modify so copy to internal b_quantized_ vector. b_quantized_->assign(b_quantized_data_, b_quantized_data_ + M); b_quantized_data_ = b_quantized_->data(); } diff --git a/caffe2/quantization/server/conv_dnnlowp_op_test.py b/caffe2/quantization/server/conv_dnnlowp_op_test.py index 03d31b5..cc2ff90 100644 --- a/caffe2/quantization/server/conv_dnnlowp_op_test.py +++ b/caffe2/quantization/server/conv_dnnlowp_op_test.py @@ -10,6 +10,7 @@ from dnnlowp_test_utils import ( check_quantized_results_close, generate_conv_inputs, generate_convnd_inputs, + run_conv_or_fc, ) from hypothesis import assume, given @@ -112,7 +113,9 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase): ) net.Proto().op.extend([quantize]) - x_q_param = dnnlowp_utils.choose_quantization_params(X.min(), X.max(), preserve_activation_sparsity) # noqa + x_q_param = dnnlowp_utils.choose_quantization_params( + X.min(), X.max(), preserve_activation_sparsity + ) if do_quantize_weight: int8_given_tensor_fill, w_q_param = dnnlowp_utils.create_int8_given_tensor_fill( W, "W_q", preserve_weight_sparsity @@ -178,13 +181,9 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase): ) net.Proto().op.extend([dequantize]) - self.ws.create_blob("X").feed(X, device_option=gc) - self.ws.create_blob("W").feed(W, device_option=gc) - self.ws.create_blob("b").feed(b, device_option=gc) - self.ws.run(init_net) - self.ws.run(net) - Y = self.ws.blobs["Y"].fetch() - outputs.append(Output(Y=Y, op_type=op_type, engine=engine, order=order)) + run_conv_or_fc( + self, init_net, net, X, W, b, op_type, engine, order, gc, outputs + ) check_quantized_results_close(outputs, symmetric=preserve_activation_sparsity) @@ -295,12 +294,9 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase): ) net.Proto().op.extend([relu]) - self.ws.create_blob("X").feed(X, device_option=gc) - self.ws.create_blob("W").feed(W, device_option=gc) - self.ws.create_blob("b").feed(b, device_option=gc) - self.ws.run(net) - Y = self.ws.blobs["Y"].fetch() - 
outputs.append(Output(Y=Y, op_type=op_type, engine=engine, order=order)) + run_conv_or_fc( + self, None, net, X, W, b, op_type, engine, order, gc, outputs + ) check_quantized_results_close(outputs) @@ -346,12 +342,6 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase): init_net = core.Net("test_init_net") net = core.Net("test_net") - fall_back_to_NCHW = "DNNLOWP" not in engine and order == "NHWC" - - if fall_back_to_NCHW: - X_nchw = utils.NHWC2NCHW(X) - W_nchw = utils.NHWC2NCHW(W) - do_quantize = "DNNLOWP" in engine do_dequantize = "DNNLOWP" in engine # If output scale/zp aren't set, it gets computed from ref fp32 op @@ -408,7 +398,7 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase): kernels=kernels, dilations=[dilation] * ndim, pads=[pad] * (ndim * 2), - order="NCHW" if fall_back_to_NCHW else order, + order=order, dequantize_output=not do_dequantize, engine=engine, group=group, @@ -428,19 +418,9 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase): ) net.Proto().op.extend([dequantize]) - self.ws.create_blob("X").feed( - X_nchw if fall_back_to_NCHW else X, device_option=gc - ) - self.ws.create_blob("W").feed( - W_nchw if fall_back_to_NCHW else W, device_option=gc + run_conv_or_fc( + self, init_net, net, X, W, b, op_type, engine, order, gc, outputs ) - self.ws.create_blob("b").feed(b, device_option=gc) - self.ws.run(init_net) - self.ws.run(net) - Y = self.ws.blobs["Y"].fetch() - if fall_back_to_NCHW: - Y = utils.NCHW2NHWC(Y) - outputs.append(Output(Y=Y, op_type=op_type, engine=engine, order=order)) check_quantized_results_close(outputs) diff --git a/caffe2/quantization/server/conv_groupwise_dnnlowp_acc16_op_test.py b/caffe2/quantization/server/conv_groupwise_dnnlowp_acc16_op_test.py index 1cd91dc..e1884db 100644 --- a/caffe2/quantization/server/conv_groupwise_dnnlowp_acc16_op_test.py +++ b/caffe2/quantization/server/conv_groupwise_dnnlowp_acc16_op_test.py @@ -7,7 +7,10 @@ import hypothesis.strategies as st import numpy as np from caffe2.python import core, dyndep, utils, 
workspace from caffe2.quantization.server import utils as dnnlowp_utils -from dnnlowp_test_utils import check_quantized_results_close +from dnnlowp_test_utils import ( + check_quantized_results_close, + run_conv_or_fc, +) from hypothesis import assume, given @@ -169,12 +172,9 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase): ) net.Proto().op.extend([dequantize]) - self.ws.create_blob("X").feed(X, device_option=gc) - self.ws.create_blob("W").feed(W, device_option=gc) - self.ws.create_blob("b").feed(b, device_option=gc) - self.ws.run(net) - Y = self.ws.blobs["Y"].fetch() - outputs.append(Output(Y=Y, op_type=op_type, engine=engine, order=order)) + run_conv_or_fc( + self, None, net, X, W, b, op_type, engine, order, gc, outputs + ) check_quantized_results_close(outputs, symmetric=preserve_activation_sparsity) @@ -319,12 +319,8 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase): ) net.Proto().op.extend([dequantize]) - self.ws.create_blob("X").feed(X, device_option=gc) - self.ws.create_blob("W").feed(W, device_option=gc) - self.ws.create_blob("b").feed(b, device_option=gc) - self.ws.run(init_net) - self.ws.run(net) - Y = self.ws.blobs["Y"].fetch() - outputs.append(Output(Y=Y, op_type=op_type, engine=engine, order=order)) + run_conv_or_fc( + self, init_net, net, X, W, b, op_type, engine, order, gc, outputs + ) check_quantized_results_close(outputs) diff --git a/caffe2/quantization/server/conv_groupwise_dnnlowp_op_test.py b/caffe2/quantization/server/conv_groupwise_dnnlowp_op_test.py index f582d9e..c524b7a 100644 --- a/caffe2/quantization/server/conv_groupwise_dnnlowp_op_test.py +++ b/caffe2/quantization/server/conv_groupwise_dnnlowp_op_test.py @@ -7,7 +7,11 @@ import hypothesis.strategies as st from caffe2.python import core, dyndep, workspace from caffe2.python.fb import hardcode_scale_zp from caffe2.quantization.server import utils as dnnlowp_utils -from dnnlowp_test_utils import check_quantized_results_close, generate_conv_inputs +from 
dnnlowp_test_utils import ( + check_quantized_results_close, + generate_conv_inputs, + run_conv_or_fc, +) from hypothesis import assume, given @@ -157,13 +161,9 @@ class GroupWiseDNNLowPOpConvTest(hu.HypothesisTestCase): ) net.Proto().op.extend([dequantize]) - self.ws.create_blob("X").feed(X, device_option=gc) - self.ws.create_blob("W").feed(W, device_option=gc) - self.ws.create_blob("b").feed(b, device_option=gc) - self.ws.run(init_net) - self.ws.run(net) - Y = self.ws.blobs["Y"].fetch() - outputs.append(Output(Y=Y, op_type=op_type, engine=engine, order=order)) + run_conv_or_fc( + self, init_net, net, X, W, b, op_type, engine, order, gc, outputs + ) check_quantized_results_close(outputs, symmetric=preserve_activation_sparsity) @@ -275,11 +275,8 @@ class GroupWiseDNNLowPOpConvTest(hu.HypothesisTestCase): ) net.Proto().op.extend([relu]) - self.ws.create_blob("X").feed(X, device_option=gc) - self.ws.create_blob("W").feed(W, device_option=gc) - self.ws.create_blob("b").feed(b, device_option=gc) - self.ws.run(net) - Y = self.ws.blobs["Y"].fetch() - outputs.append(Output(Y=Y, op_type=op_type, engine=engine, order=order)) + run_conv_or_fc( + self, None, net, X, W, b, op_type, engine, order, gc, outputs + ) check_quantized_results_close(outputs) diff --git a/caffe2/quantization/server/dnnlowp_test_utils.py b/caffe2/quantization/server/dnnlowp_test_utils.py index 67fcc99..ff15dd4 100644 --- a/caffe2/quantization/server/dnnlowp_test_utils.py +++ b/caffe2/quantization/server/dnnlowp_test_utils.py @@ -1,7 +1,9 @@ from __future__ import absolute_import, division, print_function, unicode_literals +import collections + import numpy as np -from caffe2.python import utils +from caffe2.python import utils, workspace from hypothesis import assume @@ -362,3 +364,49 @@ def generate_conv_inputs( preserve_activation_sparsity, preserve_weight_sparsity, ) + + +def run_conv_or_fc( + test_case, init_net, net, X, W, b, op_type, engine, order, gc, outputs +): + if order: + # Conv + Output = 
collections.namedtuple("Output", ["Y", "op_type", "engine", "order"]) + else: + # FC + Output = collections.namedtuple("Output", ["Y", "op_type", "engine"]) + + # We run DNNLOWP ops multiple times to test their first runs that + # do caching so exercises different code paths from the subsequent + # runs + + # self.ws.run re-creates operator everytime so this test covers + # cases when we have multiple nets sharing the same workspace + test_case.ws.create_blob("X").feed(X, device_option=gc) + test_case.ws.create_blob("W").feed(W, device_option=gc) + test_case.ws.create_blob("b").feed(b, device_option=gc) + if init_net: + test_case.ws.run(init_net) + for i in range(1 if engine == "" else 2): + test_case.ws.run(net) + Y = test_case.ws.blobs["Y"].fetch() + if order: + outputs.append(Output(Y=Y, op_type=op_type, engine=engine, order=order)) + else: + outputs.append(Output(Y=Y, op_type=op_type, engine=engine)) + + # workspace.CreateNet + workspace.RunNet reuses the same operator + if engine != "": + workspace.FeedBlob("X", X) + workspace.FeedBlob("W", W) + workspace.FeedBlob("b", b) + if init_net: + workspace.RunNetOnce(init_net) + workspace.CreateNet(net) + for i in range(2): + workspace.RunNet(net) + Y = workspace.FetchBlob("Y") + if order: + outputs.append(Output(Y=Y, op_type=op_type, engine=engine, order=order)) + else: + outputs.append(Output(Y=Y, op_type=op_type, engine=engine)) diff --git a/caffe2/quantization/server/fully_connected_dnnlowp_acc16_op_test.py b/caffe2/quantization/server/fully_connected_dnnlowp_acc16_op_test.py index ba1a8c8..6cc8dbf 100644 --- a/caffe2/quantization/server/fully_connected_dnnlowp_acc16_op_test.py +++ b/caffe2/quantization/server/fully_connected_dnnlowp_acc16_op_test.py @@ -7,7 +7,7 @@ import hypothesis.strategies as st import numpy as np from caffe2.python import core, dyndep, workspace from caffe2.quantization.server import utils as dnnlowp_utils -from dnnlowp_test_utils import check_quantized_results_close +from dnnlowp_test_utils 
import check_quantized_results_close, run_conv_or_fc from hypothesis import given @@ -104,12 +104,8 @@ class DNNLowPFullyConnectedAcc16OpTest(hu.HypothesisTestCase): ) net.Proto().op.extend([dequantize]) - self.ws.create_blob("X").feed(X, device_option=gc) - self.ws.create_blob("W").feed(W, device_option=gc) - self.ws.create_blob("b").feed(b, device_option=gc) - self.ws.run(net) - outputs.append( - Output(Y=self.ws.blobs["Y"].fetch(), op_type=op_type, engine=engine) + run_conv_or_fc( + self, None, net, X, W, b, op_type, engine, None, gc, outputs ) check_quantized_results_close(outputs) @@ -224,13 +220,8 @@ class DNNLowPFullyConnectedAcc16OpTest(hu.HypothesisTestCase): ) net.Proto().op.extend([dequantize]) - self.ws.create_blob("X").feed(X, device_option=gc) - self.ws.create_blob("W").feed(W, device_option=gc) - self.ws.create_blob("b").feed(b, device_option=gc) - self.ws.run(init_net) - self.ws.run(net) - outputs.append( - Output(Y=self.ws.blobs["Y"].fetch(), op_type=op_type, engine=engine) + run_conv_or_fc( + self, init_net, net, X, W, b, op_type, engine, None, gc, outputs ) check_quantized_results_close(outputs) diff --git a/caffe2/quantization/server/fully_connected_dnnlowp_op.cc b/caffe2/quantization/server/fully_connected_dnnlowp_op.cc index e5ff334..ca8639a 100644 --- a/caffe2/quantization/server/fully_connected_dnnlowp_op.cc +++ b/caffe2/quantization/server/fully_connected_dnnlowp_op.cc @@ -711,8 +711,7 @@ bool FullyConnectedDNNLowPOp::GetQuantizationParameters_() { const auto& packed_filter = this->template Input(1); CAFFE_ENFORCE(!dequantize_output_); - b_quantized_ = packed_filter.bias; - b_quantized_data_ = b_quantized_->data(); + b_quantized_data_ = packed_filter.bias->data(); } else { const auto& bias = InputTensorCPU_(2); if (this->template InputIsType(2)) { @@ -758,6 +757,8 @@ bool FullyConnectedDNNLowPOp::GetQuantizationParameters_() { if (in_qparams_[0].zero_point && column_offsets_->empty() && b_quantized_data_) { if (b_quantized_->empty()) { + // 
When b_quantized_data_ is from pre-packed bias or Int8TensorCPU, + // we can't inplace modify so copy to internal b_quantized_ vector. b_quantized_->assign(b_quantized_data_, b_quantized_data_ + N); b_quantized_data_ = b_quantized_->data(); } diff --git a/caffe2/quantization/server/fully_connected_dnnlowp_op_test.py b/caffe2/quantization/server/fully_connected_dnnlowp_op_test.py index 04f42eb..5468a11 100644 --- a/caffe2/quantization/server/fully_connected_dnnlowp_op_test.py +++ b/caffe2/quantization/server/fully_connected_dnnlowp_op_test.py @@ -10,6 +10,7 @@ from caffe2.quantization.server import utils as dnnlowp_utils from dnnlowp_test_utils import ( avoid_vpmaddubsw_overflow_fc, check_quantized_results_close, + run_conv_or_fc, ) from hypothesis import given @@ -191,13 +192,8 @@ class DNNLowPFullyConnectedOpTest(hu.HypothesisTestCase): ) net.Proto().op.extend([dequantize]) - self.ws.create_blob("X").feed(X, device_option=gc) - self.ws.create_blob("W").feed(W, device_option=gc) - self.ws.create_blob("b").feed(b, device_option=gc) - self.ws.run(init_net) - self.ws.run(net) - outputs.append( - Output(Y=self.ws.blobs["Y"].fetch(), op_type=op_type, engine=engine) + run_conv_or_fc( + self, init_net, net, X, W, b, op_type, engine, None, gc, outputs ) check_quantized_results_close(outputs, symmetric=preserve_activation_sparsity) diff --git a/caffe2/quantization/server/fully_connected_rowwise_dnnlowp_op_test.py b/caffe2/quantization/server/fully_connected_rowwise_dnnlowp_op_test.py index 14a926e..cd06c58 100644 --- a/caffe2/quantization/server/fully_connected_rowwise_dnnlowp_op_test.py +++ b/caffe2/quantization/server/fully_connected_rowwise_dnnlowp_op_test.py @@ -10,6 +10,7 @@ from caffe2.quantization.server import utils as dnnlowp_utils from dnnlowp_test_utils import ( avoid_vpmaddubsw_overflow_fc, check_quantized_results_close, + run_conv_or_fc, ) from hypothesis import given @@ -150,13 +151,8 @@ class RowWiseDNNLowPFullyConnectedOpTest(hu.HypothesisTestCase): ) 
net.Proto().op.extend([dequantize]) - self.ws.create_blob("X").feed(X, device_option=gc) - self.ws.create_blob("W").feed(W, device_option=gc) - self.ws.create_blob("b").feed(b, device_option=gc) - self.ws.run(init_net) - self.ws.run(net) - outputs.append( - Output(Y=self.ws.blobs["Y"].fetch(), op_type=op_type, engine=engine) + run_conv_or_fc( + self, init_net, net, X, W, b, op_type, engine, None, gc, outputs ) check_quantized_results_close(outputs) -- 2.7.4