From 880c26039431ee9069e6fa79a323ea8ba2ab6c17 Mon Sep 17 00:00:00 2001 From: shoubhik Date: Thu, 12 Sep 2019 16:34:20 -0700 Subject: [PATCH] Do type checking for the input and kernel in the qnn conv2d (#3904) * [QNN] Convolution 2D Implementation. Rebasing. Empty commit. Clang-format styling. * Reformatting code. * Fixing lint issues. --- src/relay/qnn/op/convolution.cc | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 6e1d13e..2fdb509 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -40,6 +40,26 @@ namespace qnn { // relay.op.qnn.conv2d TVM_REGISTER_NODE_TYPE(QnnConv2DAttrs); +bool QnnConv2DRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + const auto* weight = types[1].as(); + if (data == nullptr || weight == nullptr) return false; + const auto* param = attrs.as(); + CHECK(param != nullptr) << "QnnConv2DAttrs cannot be nullptr."; + CHECK(data->dtype == Int(8) || data->dtype == UInt(8)) + << "Expected qnn conv2d type(int8, uint8) for input but was " << data->dtype; + CHECK(weight->dtype == Int(8) || weight->dtype == UInt(8)) + << "Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype; + CHECK(param->out_dtype == Int(16) || param->out_dtype == Int(32)) + << "Expected qnn conv2d type(int32, int16) for output but was " << param->out_dtype; + CHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0."; + return Conv2DRel(types, num_inputs, attrs, reporter); +} + // Workload - batch_size, in_channels, out_channels, kernel_h, kernel_w using WorkloadType = std::tuple; @@ -475,7 +495,7 @@ operator to understand how to scale back the int32 output to (u)int8. .add_argument("data", "Tensor", "The quantized input data tensor.") .add_argument("weight", "Tensor", "The quantized weight tensor.") .set_support_level(11) -.add_type_rel("QnnConv2D", Conv2DRel) +.add_type_rel("QnnConv2D", QnnConv2DRel) .set_attr("FTVMQnnCanonicalize", QnnConv2DCanonicalize); TVM_REGISTER_API("relay.qnn.op._make.conv2d").set_body_typed(MakeQnnConv2D); -- 2.7.4