[BYOC][ETHOSN] Introduce further operator support (#6355)
authormbaret <55580676+mbaret@users.noreply.github.com>
Tue, 15 Sep 2020 22:05:19 +0000 (23:05 +0100)
committerGitHub <noreply@github.com>
Tue, 15 Sep 2020 22:05:19 +0000 (15:05 -0700)
* [BYOC][ETHOSN] Introduce further operator support

This PR introduces support for the following operators:
 - Quantized Fully Connected
 - Quantized Addition
 - Depth-to-space
 - Max/Avg Pool 2D
 - Quantized Relu (Clip)
 - Reshape
 - Quantized Sigmoid

Co-authored-by: Leo Blonk <Leo.Blonk@arm.com>
Co-authored-by: Tristan O'Connor <tristan.oconnor@arm.com>
Co-authored-by: Ramana Radhakrishnan <ramana.radhakrishnan@arm.com>
* Skip tf imports if not available

Change-Id: I11bcf4a78014fa63e7b8e3b0cb00eecfd6cb7760

* ethos -> ethosn

Change-Id: I1fb1a11d0765f6d69f04c24b9c24e08665b8af6a

* Reduce random testing in test_addition

Change-Id: Id06063a0a0cf5f01356df23dc5d4bbbcb47cfa99

* Reduce random testing in test fullyconnected

Change-Id: I330408dfabc4bd804373f100581ce909ff724052

* Fix dumb mistake with rename

Change-Id: I2c5007be485b323116a0e8bab0f9106ea5ec834b

* Added comments to update the hashes in network tests when necessary

Change-Id: I13828c918c959daa492b9ed942a882c86d6690d1

* Fix github name

Change-Id: Idaa70ab9c2ec8db2828d51d15e7c23f28670ec82

* Use black formatting

Change-Id: I538171bd547a16395bef155a1dad28e8b3e347f2

Co-authored-by: Leo Blonk <Leo.Blonk@arm.com>
Co-authored-by: Tristan O'Connor <tristan.oconnor@arm.com>
Co-authored-by: Ramana Radhakrishnan <ramana.radhakrishnan@arm.com>
18 files changed:
python/tvm/relay/op/contrib/ethosn.py
src/relay/backend/contrib/ethosn/codegen.cc
src/relay/backend/contrib/ethosn/codegen_ethosn.h
src/relay/backend/contrib/ethosn/ethosn_api.cc
src/relay/backend/contrib/ethosn/ethosn_api.h
src/runtime/contrib/ethosn/ethosn_runtime.cc
src/runtime/contrib/ethosn/ethosn_runtime.h
tests/python/contrib/test_ethosn/infrastructure.py
tests/python/contrib/test_ethosn/test_addition.py [new file with mode: 0644]
tests/python/contrib/test_ethosn/test_conv2d.py
tests/python/contrib/test_ethosn/test_depth_to_space.py [new file with mode: 0644]
tests/python/contrib/test_ethosn/test_fullyconnected.py [new file with mode: 0644]
tests/python/contrib/test_ethosn/test_networks.py [new file with mode: 0644]
tests/python/contrib/test_ethosn/test_pooling.py [new file with mode: 0644]
tests/python/contrib/test_ethosn/test_relu.py [new file with mode: 0644]
tests/python/contrib/test_ethosn/test_reshape.py [new file with mode: 0644]
tests/python/contrib/test_ethosn/test_sigmoid.py [new file with mode: 0644]
tests/python/contrib/test_ethosn/test_topologies.py

index 213f4d3..3c676f4 100644 (file)
@@ -57,6 +57,28 @@ def pattern_table():
         )
         return pattern
 
+    def qnn_fc_pattern():
+        pattern = is_op("qnn.dense")(
+            wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
+        )
+        pattern = is_op("nn.bias_add")(pattern, is_constant())
+        pattern = is_op("qnn.requantize")(
+            pattern, is_constant(), is_constant(), is_constant(), is_constant()
+        )
+        return pattern
+
+    def qnn_avg_pool2d_pattern():
+        pattern = is_op("cast")(wildcard())
+        pattern = is_op("nn.avg_pool2d")(pattern)
+        pattern = is_op("cast")(pattern)
+        return pattern
+
+    def qnn_sigmoid_pattern():
+        pattern = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
+        pattern = is_op("sigmoid")(pattern)
+        pattern = is_op("qnn.quantize")(pattern, is_constant(), is_constant())
+        return pattern
+
     def check_conv2d(extract):
         """Check if a conv2d is supported by Ethos-N."""
         if not ethosn_available():
@@ -64,11 +86,80 @@ def pattern_table():
 
         return support.conv2d(extract)
 
+    def check_fc(extract):
+        """Check if a fully connected is supported by Ethos-N."""
+        if not ethosn_available():
+            return False
+
+        return support.fc(extract)
+
+    def check_avg_pool2d(extract):
+        """Check if a avg pool2d is supported by Ethos-N."""
+        if not ethosn_available():
+            return False
+
+        return support.avg_pool2d(extract)
+
+    def check_sigmoid(extract):
+        """Check if a sigmoid is supported by Ethos-N."""
+        if not ethosn_available():
+            return False
+
+        if extract.attrs.out_dtype != "uint8":
+            return False
+
+        return support.sigmoid(extract)
+
     return [
         ("ethos-n.qnn_conv2d", qnn_conv_pattern(), check_conv2d),
+        ("ethos-n.qnn_avg_pool2d", qnn_avg_pool2d_pattern(), check_avg_pool2d),
+        ("ethos-n.qnn_sigmoid", qnn_sigmoid_pattern(), check_sigmoid),
+        ("ethos-n.qnn_fc", qnn_fc_pattern(), check_fc),
     ]
 
 
+def _is_ethosn_composite(node):
+    if isinstance(node, tvm.relay.expr.Call) and isinstance(node.op, tvm.relay.Function):
+        if "Composite" in node.op.attrs:
+            comp_name = node.op.attrs["Composite"]
+            return comp_name.split(".")[0] == "ethos-n"
+
+    return False
+
+
+@tvm.ir.register_op_attr("nn.max_pool2d", "target.ethos-n")
+def max_pool2d(attrs, args):
+    """Check if a max pool2d is supported by Ethos-N."""
+    if not ethosn_available():
+        return False
+
+    pool = tvm.relay.nn.max_pool2d(*args, **attrs)
+    return support.max_pool2d(pool)
+
+
+@tvm.ir.register_op_attr("reshape", "target.ethos-n")
+def reshape(attrs, args):
+    """Check if a reshape is supported by Ethos-N."""
+    if not ethosn_available():
+        return False
+
+    if not _is_ethosn_composite(args[0]):
+        return False
+
+    rs = tvm.relay.op.reshape(*args, attrs["newshape"])
+    return support.reshape(rs)
+
+
+@tvm.ir.register_op_attr("qnn.add", "target.ethos-n")
+def qnn_add(attrs, args):
+    """Check if an addition is supported by Ethos-N."""
+    if not ethosn_available():
+        return False
+
+    add = _qnn.op.add(*args)
+    return support.addition(add)
+
+
 @tvm.ir.register_op_attr("qnn.concatenate", "target.ethos-n")
 def qnn_concatenate(attrs, args):
     """Check if a concatenate is supported by Ethos-N."""
@@ -116,3 +207,29 @@ def split(attrs, args):
         return False
 
     return True
+
+
+@tvm.ir.register_op_attr("nn.depth_to_space", "target.ethos-n")
+def depth_to_space(attrs, args):
+    """Check if a depth_to_space is supported by Ethos-N."""
+    if not ethosn_available():
+        return False
+
+    depth = tvm.relay.nn.depth_to_space(*args, **attrs)
+    if not support.depth_to_space(depth):
+        return False
+
+    return True
+
+
+@tvm.ir.register_op_attr("clip", "target.ethos-n")
+def clip(attrs, args):
+    """Check if a clip is supported by Ethos-N."""
+    if not ethosn_available():
+        return False
+
+    c = tvm.relay.clip(*args, **attrs)
+    if not support.relu(c):
+        return False
+
+    return True
index 58cd5bf..b331ced 100644 (file)
@@ -83,6 +83,34 @@ void InferTensorsVisitor::InferCall(const CallNode* cn) {
     ConvolutionParams params;
     err += EthosnAPI::QnnConv2d(cn->op.as<FunctionNode>()->body, &params);
     tensor_table_[cn->args[0]] = {params.activation_info};
+  } else if (IsEthosnFunc(call, "ethos-n.qnn_fc")) {
+    FullyConnectedParams params;
+    err += EthosnAPI::QnnFullyConnected(cn->op.as<FunctionNode>()->body, &params);
+    tensor_table_[cn->args[0]] = {params.input_info};
+  } else if (IsEthosnOp(call, "nn.max_pool2d")) {
+    MaxPool2DParams params;
+    params.input_info = GetTensorInfo(tensor_table_, call);
+    err += EthosnAPI::MaxPool2D(call, &params);
+    tensor_table_[cn->args[0]] = {params.input_info};
+  } else if (IsEthosnFunc(call, "ethos-n.qnn_avg_pool2d")) {
+    AvgPool2DParams params;
+    params.input_info = GetTensorInfo(tensor_table_, call);
+    err += EthosnAPI::AvgPool2D(cn->op.as<FunctionNode>()->body, &params);
+    tensor_table_[cn->args[0]] = {params.input_info};
+  } else if (IsEthosnOp(call, "reshape")) {
+    ReshapeParams params;
+    params.input_info = GetTensorInfo(tensor_table_, call);
+    err += EthosnAPI::Reshape(call, &params);
+    tensor_table_[cn->args[0]] = {params.input_info};
+  } else if (IsEthosnOp(call, "qnn.add")) {
+    AdditionParams params;
+    err += EthosnAPI::Addition(call, &params);
+    tensor_table_[cn->args[0]] = {params.lhs_info};
+    tensor_table_[cn->args[1]] = {params.rhs_info};
+  } else if (IsEthosnFunc(call, "ethos-n.qnn_sigmoid")) {
+    SigmoidParams params;
+    err += EthosnAPI::Sigmoid(cn->op.as<FunctionNode>()->body, &params);
+    tensor_table_[cn->args[0]] = {params.input_info};
   } else if (IsEthosnOp(call, "qnn.concatenate")) {
     ConcatenateParams params;
     err = EthosnAPI::Concatenate(call, &params);
@@ -92,6 +120,16 @@ void InferTensorsVisitor::InferCall(const CallNode* cn) {
     params.input_info = GetTensorInfo(tensor_table_, call);
     err = EthosnAPI::Split(call, &params);
     tensor_table_[cn->args[0]] = {params.input_info};
+  } else if (IsEthosnOp(call, "nn.depth_to_space")) {
+    DepthToSpaceParams params;
+    params.input_info = GetTensorInfo(tensor_table_, call);
+    err += EthosnAPI::DepthToSpace(call, &params);
+    tensor_table_[cn->args[0]] = {params.input_info};
+  } else if (IsEthosnOp(call, "clip")) {
+    ReluParams params;
+    params.input_info = GetTensorInfo(tensor_table_, call);
+    err = EthosnAPI::Relu(call, &params);
+    tensor_table_[cn->args[0]] = {params.input_info};
   } else {
     err = EthosnError("unknown operator");
   }
@@ -198,12 +236,36 @@ sl::TensorsAndId ConstructNetworkVisitor::HandleCall(const CallNode* cn) {
   if (IsEthosnFunc(call, "ethos-n.qnn_conv2d")) {
     if ((err = MakeConvolutionLayer(call, &tensor))) ReportFatalError(call, err);
     return MakeOps(tensor);
+  } else if (IsEthosnFunc(call, "ethos-n.qnn_fc")) {
+    if ((err = MakeFullyConnectedLayer(call, &tensor))) ReportFatalError(call, err);
+    return MakeOps(tensor);
+  } else if (IsEthosnOp(call, "nn.max_pool2d")) {
+    if ((err = MakeMaxPool2DLayer(call, &tensor))) ReportFatalError(call, err);
+    return MakeOps(tensor);
+  } else if (IsEthosnFunc(call, "ethos-n.qnn_avg_pool2d")) {
+    if ((err = MakeAvgPool2DLayer(call, &tensor))) ReportFatalError(call, err);
+    return MakeOps(tensor);
+  } else if (IsEthosnOp(call, "reshape")) {
+    if ((err = MakeReshapeLayer(call, &tensor))) ReportFatalError(call, err);
+    return MakeOps(tensor);
+  } else if (IsEthosnOp(call, "qnn.add")) {
+    if ((err = MakeAdditionLayer(call, &tensor))) ReportFatalError(call, err);
+    return MakeOps(tensor);
+  } else if (IsEthosnFunc(call, "ethos-n.qnn_sigmoid")) {
+    if ((err = MakeSigmoidLayer(call, &tensor))) ReportFatalError(call, err);
+    return MakeOps(tensor);
   } else if (IsEthosnOp(call, "qnn.concatenate")) {
     if ((err = MakeConcatenateLayer(call, &tensor))) ReportFatalError(call, err);
     return MakeOps(tensor);
   } else if (IsEthosnOp(call, "split")) {
     if ((err = MakeSplitLayer(call, &tensors))) ReportFatalError(call, err);
     return tensors;
+  } else if (IsEthosnOp(call, "nn.depth_to_space")) {
+    if ((err = MakeDepthToSpaceLayer(call, &tensor))) ReportFatalError(call, err);
+    return MakeOps(tensor);
+  } else if (IsEthosnOp(call, "clip")) {
+    if ((err = MakeReluLayer(call, &tensor))) ReportFatalError(call, err);
+    return MakeOps(tensor);
   } else {
     ReportFatalError(call, EthosnError("unknown operator"));
     return {};
@@ -266,6 +328,115 @@ EthosnError ConstructNetworkVisitor::MakeConvolutionLayer(const Call& call,
   return EthosnError();
 }
 
+EthosnError ConstructNetworkVisitor::MakeFullyConnectedLayer(const Call& call,
+                                                             sl::TensorAndId<sl::Operand>* out) {
+  FullyConnectedParams params;
+  if (auto err = EthosnAPI::QnnFullyConnected(call->op.as<FunctionNode>()->body, &params)) {
+    return err;
+  }
+
+  auto weights = AddConstant(network_, params.weights_info, params.raw_weights).tensor;
+  auto bias = AddConstant(network_, params.bias_info, params.raw_bias).tensor;
+  try {
+    auto input =
+        AddReshape(network_, *operand_table_[call->args[0]][0], params.input_info.m_Dimensions)
+            .tensor;
+    *out = AddFullyConnected(network_, *input, *bias, *weights, params.fc_info);
+  } catch (const sl::NotSupportedException& e) {
+    return EthosnError(e.what());
+  }
+  return EthosnError();
+}
+
+EthosnError ConstructNetworkVisitor::MakeMaxPool2DLayer(const Call& call,
+                                                        sl::TensorAndId<sl::Operand>* out) {
+  MaxPool2DParams params;
+  params.input_info = GetTensorInfo(tensor_table_, call);
+  if (auto err = EthosnAPI::MaxPool2D(call, &params)) {
+    return err;
+  }
+
+  auto input = operand_table_[call->args[0]][0];
+
+  try {
+    *out = AddPooling(network_, *input, params.pool_info);
+  } catch (const sl::NotSupportedException& e) {
+    return EthosnError(e.what());
+  }
+  return EthosnError();
+}
+
+EthosnError ConstructNetworkVisitor::MakeAvgPool2DLayer(const Call& call,
+                                                        sl::TensorAndId<sl::Operand>* out) {
+  AvgPool2DParams params;
+  params.input_info = GetTensorInfo(tensor_table_, call);
+  if (auto err = EthosnAPI::AvgPool2D(call->op.as<FunctionNode>()->body, &params)) {
+    return err;
+  }
+
+  auto input = operand_table_[call->args[0]][0];
+
+  try {
+    *out = AddPooling(network_, *input, params.pool_info);
+  } catch (const sl::NotSupportedException& e) {
+    return EthosnError(e.what());
+  }
+  return EthosnError();
+}
+
+EthosnError ConstructNetworkVisitor::MakeReshapeLayer(const Call& call,
+                                                      sl::TensorAndId<sl::Operand>* out) {
+  ReshapeParams params;
+  params.input_info = GetTensorInfo(tensor_table_, call);
+  if (auto err = EthosnAPI::Reshape(call, &params)) {
+    return err;
+  }
+
+  auto input = operand_table_[call->args[0]][0];
+
+  try {
+    *out = AddReshape(network_, *input, params.new_shape);
+  } catch (const sl::NotSupportedException& e) {
+    return EthosnError(e.what());
+  }
+  return EthosnError();
+}
+
+EthosnError ConstructNetworkVisitor::MakeAdditionLayer(const Call& call,
+                                                       sl::TensorAndId<sl::Operand>* out) {
+  AdditionParams params;
+  if (auto err = EthosnAPI::Addition(call, &params)) {
+    return err;
+  }
+
+  auto lhs = operand_table_[call->args[0]][0];
+  auto rhs = operand_table_[call->args[1]][0];
+
+  try {
+    *out = AddAddition(network_, *lhs, *rhs, params.output_quantization_info);
+  } catch (const sl::NotSupportedException& e) {
+    return EthosnError(e.what());
+  }
+  return EthosnError();
+}
+
+EthosnError ConstructNetworkVisitor::MakeSigmoidLayer(const Call& call,
+                                                      sl::TensorAndId<sl::Operand>* out) {
+  SigmoidParams params;
+  if (auto err = EthosnAPI::Sigmoid(call->op.as<FunctionNode>()->body, &params)) {
+    return err;
+  }
+
+  auto input = operand_table_[call->args[0]][0];
+
+  try {
+    *out = AddSigmoid(network_, *input);
+  } catch (const sl::NotSupportedException& e) {
+    return EthosnError(e.what());
+  }
+  return EthosnError();
+}
+
 EthosnError ConstructNetworkVisitor::MakeConcatenateLayer(const Call& call,
                                                           sl::TensorAndId<sl::Operand>* out) {
   ConcatenateParams params;
@@ -304,6 +475,42 @@ EthosnError ConstructNetworkVisitor::MakeSplitLayer(const Call& call, sl::Tensor
   return EthosnError();
 }
 
+EthosnError ConstructNetworkVisitor::MakeDepthToSpaceLayer(const Call& call,
+                                                           sl::TensorAndId<sl::Operand>* out) {
+  DepthToSpaceParams params;
+  params.input_info = GetTensorInfo(tensor_table_, call);
+  if (auto err = EthosnAPI::DepthToSpace(call, &params)) {
+    return err;
+  }
+
+  auto input = operand_table_[call->args[0]][0];
+
+  try {
+    *out = AddDepthToSpace(network_, *input, params.depth_info);
+  } catch (const sl::NotSupportedException& e) {
+    return EthosnError(e.what());
+  }
+  return EthosnError();
+}
+
+EthosnError ConstructNetworkVisitor::MakeReluLayer(const Call& call,
+                                                   sl::TensorAndId<sl::Operand>* out) {
+  ReluParams params;
+  params.input_info = GetTensorInfo(tensor_table_, call);
+  if (auto err = EthosnAPI::Relu(call, &params)) {
+    return err;
+  }
+
+  auto input = operand_table_[call->args[0]][0];
+
+  try {
+    *out = AddRelu(network_, *input, params.relu_info);
+  } catch (const sl::NotSupportedException& e) {
+    return EthosnError(e.what());
+  }
+  return EthosnError();
+}
+
 runtime::Module EthosnCompiler::CreateRuntimeModule(const ObjectRef& ref) {
   std::vector<runtime::ethosn::OrderedCompiledNetwork> cmms;
   if (ref->IsInstance<FunctionNode>()) {
index 7d1fe9c..a42db8f 100644 (file)
@@ -198,8 +198,16 @@ class ConstructNetworkVisitor : public MixedModeVisitor, private ErrorReportingP
 
   // Make a support library operand from a Call
   EthosnError MakeConvolutionLayer(const Call& call, sl::TensorAndId<sl::Operand>* out);
+  EthosnError MakeFullyConnectedLayer(const Call&, sl::TensorAndId<sl::Operand>* out);
+  EthosnError MakeMaxPool2DLayer(const Call& call, sl::TensorAndId<sl::Operand>* out);
+  EthosnError MakeAvgPool2DLayer(const Call& call, sl::TensorAndId<sl::Operand>* out);
+  EthosnError MakeReshapeLayer(const Call& call, sl::TensorAndId<sl::Operand>* out);
+  EthosnError MakeAdditionLayer(const Call& call, sl::TensorAndId<sl::Operand>* out);
+  EthosnError MakeSigmoidLayer(const Call& call, sl::TensorAndId<sl::Operand>* out);
   EthosnError MakeConcatenateLayer(const Call& call, sl::TensorAndId<sl::Operand>* out);
   EthosnError MakeSplitLayer(const Call& call, sl::TensorsAndId* outs);
+  EthosnError MakeDepthToSpaceLayer(const Call& call, sl::TensorAndId<sl::Operand>* out);
+  EthosnError MakeReluLayer(const Call& call, sl::TensorAndId<sl::Operand>* out);
 
   /*! \brief A look-up table from Expr to layers. */
   std::map<Expr, std::vector<std::shared_ptr<sl::Operand>>> operand_table_;
index b7cac65..2aa2632 100644 (file)
@@ -139,6 +139,235 @@ EthosnError EthosnAPI::QnnConv2d(const Expr& expr, ConvolutionParams* params) {
   return err;
 }
 
+EthosnError EthosnAPI::QnnFullyConnected(const Expr& expr, FullyConnectedParams* params) {
+  Call requantize = Downcast<Call>(expr);
+  Call bias_add = Downcast<Call>(requantize->args[0]);
+  Call dense = Downcast<Call>(bias_add->args[0]);
+
+  // Extract the quantization params from the arguments
+  int input_zero_point;
+  int kernel_zero_point;
+  int output_zero_point;
+  float input_scale;
+  float kernel_scale;
+  float output_scale;
+  EthosnError err = AsConstant<int>(dense->args[2], &input_zero_point);
+  err += AsConstant<int>(dense->args[3], &kernel_zero_point);
+  err += AsConstant<int>(requantize->args[4], &output_zero_point);
+  err += AsConstant<float>(dense->args[4], &input_scale);
+  err += AsConstant<float>(dense->args[5], &kernel_scale);
+  err += AsConstant<float>(requantize->args[3], &output_scale);
+
+  // Convert quantization params
+  sl::QuantizationInfo data_q_info;
+  sl::QuantizationInfo weights_q_info;
+  sl::QuantizationInfo bias_q_info;
+  sl::QuantizationInfo output_q_info;
+  err += Tvm2Npu(input_zero_point, input_scale, &data_q_info);
+  err += Tvm2Npu(kernel_zero_point, kernel_scale, &weights_q_info);
+  err += Tvm2Npu(0, data_q_info.m_Scale * weights_q_info.m_Scale, &bias_q_info);
+  err += Tvm2Npu(output_zero_point, output_scale, &output_q_info);
+
+  // Create fc info
+  params->fc_info = sl::FullyConnectedInfo(output_q_info);
+
+  // Create data info
+  const TensorTypeNode* data_dtype = dense->args[0]->checked_type().as<TensorTypeNode>();
+  sl::TensorShape data_tensor_shape;
+  sl::DataType data_data_type;
+  err += Tvm2Npu(data_dtype->shape, &data_tensor_shape);
+  err += Tvm2Npu(data_dtype->dtype, &data_data_type);
+  params->input_info = sl::TensorInfo({data_tensor_shape[0], 1, 1, data_tensor_shape[1]},
+                                      data_data_type, sl::DataFormat::NHWC, data_q_info);
+
+  // Create weights info
+  const auto* weights_dtype = dense->args[1]->checked_type().as<TensorTypeNode>();
+  sl::TensorShape weights_tensor_shape;
+  sl::DataType weights_data_type;
+  sl::DataFormat weights_data_format;
+  // Ignore the error here because weights don't have a batch axis
+  Tvm2Npu(weights_dtype->shape, &weights_tensor_shape);
+  err += Tvm2Npu(weights_dtype->dtype, &weights_data_type);
+  err += Tvm2Npu("HWIO", &weights_data_format);
+  params->weights_info = sl::TensorInfo({1, 1, weights_tensor_shape[1], weights_tensor_shape[0]},
+                                        weights_data_type, weights_data_format, weights_q_info);
+  params->raw_weights = dense->args[1].as<ConstantNode>()->data->data;
+
+  // Create bias info
+  params->bias_info =
+      sl::TensorInfo({1, 1, 1, weights_tensor_shape[0]}, sl::DataType::INT32_QUANTIZED,
+                     sl::DataFormat::NHWC, bias_q_info);
+  params->raw_bias = bias_add->args[1].as<ConstantNode>()->data->data;
+
+  return err;
+}
+
+EthosnError EthosnAPI::Pool2d(const Call& pool, Array<IndexExpr> size, Array<IndexExpr> strides,
+                              Array<IndexExpr> padding, sl::PoolingType pooling_type,
+                              sl::PoolingInfo* pool_info, sl::TensorInfo* input_info,
+                              std::string layout) {
+  uint32_t npu_sizex, npu_sizey;
+  sl::Padding npu_padding;
+  sl::Stride npu_stride;
+  EthosnError err = Tvm2Npu(size, &npu_sizex, &npu_sizey);
+  err += Tvm2Npu(padding, &npu_padding);
+  err += Tvm2Npu(strides, &npu_stride);
+  *pool_info = sl::PoolingInfo(npu_sizex, npu_sizey, npu_stride.m_X, npu_stride.m_Y, npu_padding,
+                               pooling_type);
+
+  // Create input info
+  const auto* input_dtype = pool->args[0]->checked_type().as<TensorTypeNode>();
+  sl::TensorShape input_tensor_shape;
+  sl::DataType input_data_type;
+  sl::DataFormat input_data_format;
+  err += Tvm2Npu(input_dtype->shape, &input_tensor_shape);
+  err += Tvm2Npu(input_dtype->dtype, &input_data_type);
+  err += Tvm2Npu(layout, &input_data_format);
+  if (input_data_format != sl::DataFormat::NHWC) {
+    return EthosnError(ErrStrm() << "data format=" << layout << ", data format must = NHWC");
+  }
+  *input_info = sl::TensorInfo(input_tensor_shape, input_data_type, input_data_format,
+                               input_info->m_QuantizationInfo);
+  return err;
+}
+
+EthosnError EthosnAPI::MaxPool2D(const Expr& expr, MaxPool2DParams* params) {
+  Call pool = Downcast<Call>(expr);
+  const auto pool_attrs = pool->attrs.as<MaxPool2DAttrs>();
+  return Pool2d(pool, pool_attrs->pool_size, pool_attrs->strides, pool_attrs->padding,
+                sl::PoolingType::MAX, &params->pool_info, &params->input_info, pool_attrs->layout);
+}
+
+EthosnError EthosnAPI::AvgPool2D(const Expr& expr, AvgPool2DParams* params) {
+  Call cast_0 = Downcast<Call>(expr);
+  Call pool = Downcast<Call>(cast_0->args[0]);
+  Call cast_1 = Downcast<Call>(pool->args[0]);
+  const auto pool_attrs = pool->attrs.as<AvgPool2DAttrs>();
+  return Pool2d(cast_1, pool_attrs->pool_size, pool_attrs->strides, pool_attrs->padding,
+                sl::PoolingType::AVG, &params->pool_info, &params->input_info, pool_attrs->layout);
+}
+
+EthosnError EthosnAPI::Reshape(const Expr& expr, ReshapeParams* params) {
+  // Create input info
+  Call reshape = Downcast<Call>(expr);
+  const auto* input_dtype = reshape->args[0]->checked_type().as<TensorTypeNode>();
+  const auto& reshape_attrs = reshape->attrs.as<ReshapeAttrs>();
+
+  sl::TensorShape input_tensor_shape = {1, 1, 1, 1};
+  sl::DataType input_data_type;
+  EthosnError err = Tvm2Npu(input_dtype->shape, &input_tensor_shape);
+  err += Tvm2Npu(input_dtype->dtype, &input_data_type);
+  int tensor_size = 1;
+  for (const auto& dim : input_tensor_shape) {
+    tensor_size *= dim;
+  }
+
+  int infer_index = -1;
+  int reshaped_size = 1;
+  Array<Integer> inferred_shape = {1, 1, 1, 1};
+  for (size_t i = 0; i < reshape_attrs->newshape.size(); i++) {
+    int value = reshape_attrs->newshape[i].as<IntImmNode>()->value;
+    if (value < -1) {
+      return EthosnError(ErrStrm()
+                         << "reshape dimension=" << value << ", reshape dimension must be >= -1");
+    }
+    if (value == -1) {
+      if (infer_index != -1) {
+        return EthosnError("only one reshape dimension can be inferred");
+      }
+      infer_index = i;
+    } else {
+      inferred_shape.Set(i, value);
+      reshaped_size *= value;
+    }
+  }
+
+  if (infer_index != -1) {
+    if (tensor_size % reshaped_size != 0) {
+      return EthosnError(ErrStrm()
+                         << "reshaped size=" << reshaped_size
+                         << ", must be an integer factor of the input size " << tensor_size);
+    }
+    int value = tensor_size / reshaped_size;
+    inferred_shape.Set(infer_index, Integer(value));
+  }
+  err += Tvm2Npu(inferred_shape, &params->new_shape);
+  params->input_info =
+      sl::TensorInfo(input_tensor_shape, input_data_type, params->input_info.m_DataFormat,
+                     params->input_info.m_QuantizationInfo);
+
+  return err;
+}
+
+EthosnError EthosnAPI::Addition(const Expr& expr, AdditionParams* params) {
+  Call call = Downcast<Call>(expr);
+  // Extract the quantization params from the arguments
+  float lhs_scale;
+  int lhs_zero_point;
+  float rhs_scale;
+  int rhs_zero_point;
+  float output_scale;
+  int output_zero_point;
+  EthosnError err = AsConstant<float>(call->args[2], &lhs_scale);
+  err += AsConstant<int>(call->args[3], &lhs_zero_point);
+  err += AsConstant<float>(call->args[4], &rhs_scale);
+  err += AsConstant<int>(call->args[5], &rhs_zero_point);
+  err += AsConstant<float>(call->args[6], &output_scale);
+  err += AsConstant<int>(call->args[7], &output_zero_point);
+
+  sl::QuantizationInfo lhs_q_info;
+  sl::QuantizationInfo rhs_q_info;
+  err += Tvm2Npu(lhs_zero_point, lhs_scale, &lhs_q_info);
+  err += Tvm2Npu(rhs_zero_point, rhs_scale, &rhs_q_info);
+  err += Tvm2Npu(output_zero_point, output_scale, &params->output_quantization_info);
+
+  // Create input info
+  const auto* lhs_dtype = call->args[0]->checked_type().as<TensorTypeNode>();
+  sl::TensorShape lhs_tensor_shape;
+  sl::DataType lhs_data_type;
+  err += Tvm2Npu(lhs_dtype->shape, &lhs_tensor_shape);
+  err += Tvm2Npu(lhs_dtype->dtype, &lhs_data_type);
+  params->lhs_info =
+      sl::TensorInfo(lhs_tensor_shape, lhs_data_type, sl::DataFormat::NHWC, lhs_q_info);
+
+  const auto* rhs_dtype = call->args[1]->checked_type().as<TensorTypeNode>();
+  sl::TensorShape rhs_tensor_shape;
+  sl::DataType rhs_data_type;
+  err += Tvm2Npu(rhs_dtype->shape, &rhs_tensor_shape);
+  err += Tvm2Npu(rhs_dtype->dtype, &rhs_data_type);
+  params->rhs_info =
+      sl::TensorInfo(rhs_tensor_shape, rhs_data_type, sl::DataFormat::NHWC, rhs_q_info);
+  return err;
+}
+
+EthosnError EthosnAPI::Sigmoid(const Expr& expr, SigmoidParams* params) {
+  Call quantize = Downcast<Call>(expr);
+  Call sigmoid = Downcast<Call>(quantize->args[0]);
+  Call dequantize = Downcast<Call>(sigmoid->args[0]);
+
+  // Create input info
+  const auto* input_dtype = quantize->checked_type().as<TensorTypeNode>();
+  sl::TensorShape input_tensor_shape = {1, 1, 1, 1};
+  sl::DataType input_tensor_dtype;
+  EthosnError err = Tvm2Npu(input_dtype->shape, &input_tensor_shape);
+  err += Tvm2Npu(input_dtype->dtype, &input_tensor_dtype);
+  float input_sc;
+  int input_zp;
+  err += AsConstant<int>(dequantize->args[2], &input_zp);
+  err += AsConstant<float>(dequantize->args[1], &input_sc);
+  float output_sc;
+  int output_zp;
+  err += AsConstant<int>(quantize->args[2], &output_zp);
+  err += AsConstant<float>(quantize->args[1], &output_sc);
+  if (output_zp != 0 || output_sc != 1.0f / 256.0f) {
+    err += EthosnError(ErrStrm() << "output quantization params=(" << output_zp << ", " << output_sc
+                                 << "), must = (0, 1/256)");
+  }
+  params->input_info = sl::TensorInfo(input_tensor_shape, input_tensor_dtype, sl::DataFormat::NHWC,
+                                      sl::QuantizationInfo(input_zp, input_sc));
+  return err;
+}
+
 EthosnError EthosnAPI::Concatenate(const Expr& expr, ConcatenateParams* params) {
   Call call = Downcast<Call>(expr);
   const auto& attrs = call->attrs.as<ConcatenateAttrs>();
@@ -206,6 +435,46 @@ EthosnError EthosnAPI::Split(const Expr& expr, SplitParams* params) {
   return err;
 }
 
+EthosnError EthosnAPI::DepthToSpace(const Expr& expr, DepthToSpaceParams* params) {
+  Call call = Downcast<Call>(expr);
+  const auto* input_dtype = call->args[0]->checked_type().as<TensorTypeNode>();
+  const auto* attrs = call->attrs.as<SubPixelAttrs>();
+  if (attrs->mode != "DCR") {
+    return EthosnError(ErrStrm() << "mode=" << attrs->mode << ", mode must = DCR");
+  }
+  params->depth_info.m_BlockSize = attrs->block_size;
+
+  sl::TensorShape input_tensor_shape;
+  sl::DataType input_data_type;
+  sl::DataFormat input_data_format;
+  EthosnError err = Tvm2Npu(input_dtype->shape, &input_tensor_shape);
+  err += Tvm2Npu(input_dtype->dtype, &input_data_type);
+  err += Tvm2Npu(attrs->layout, &input_data_format);
+  if (input_data_format != sl::DataFormat::NHWC) {
+    err += EthosnError(ErrStrm() << "layout=" << attrs->layout << ", layout must = NHWC");
+  }
+  params->input_info = sl::TensorInfo(input_tensor_shape, input_data_type, input_data_format,
+                                      params->input_info.m_QuantizationInfo);
+  return err;
+}
+
+EthosnError EthosnAPI::Relu(const Expr& expr, ReluParams* params) {
+  Call call = Downcast<Call>(expr);
+  const auto* input_dtype = call->args[0]->checked_type().as<TensorTypeNode>();
+  const auto* attrs = call->attrs.as<ClipAttrs>();
+  params->relu_info.m_LowerBound = attrs->a_min;
+  params->relu_info.m_UpperBound = attrs->a_max;
+
+  sl::TensorShape input_tensor_shape = {1, 1, 1, 1};
+  sl::DataType input_data_type;
+  EthosnError err = Tvm2Npu(input_dtype->shape, &input_tensor_shape);
+  err += Tvm2Npu(input_dtype->dtype, &input_data_type);
+  params->input_info =
+      sl::TensorInfo(input_tensor_shape, input_data_type, params->input_info.m_DataFormat,
+                     params->input_info.m_QuantizationInfo);
+  return err;
+}
+
 EthosnError EthosnAPI::Tvm2Npu(const Array<IndexExpr>& padding, sl::Padding* npu_padding) {
   std::array<uint32_t, 4> dim;
   if (EthosnError err = AsArray<IndexExpr, uint32_t>(padding, &dim)) {
@@ -242,6 +511,19 @@ EthosnError EthosnAPI::Tvm2Npu(const Array<IndexExpr>& strides, sl::Stride* npu_
   return EthosnError();
 }
 
+EthosnError EthosnAPI::Tvm2Npu(const Array<IndexExpr>& size, uint32_t* x, uint32_t* y) {
+  if (size.size() != 2) {
+    return EthosnError(ErrStrm() << "dimensions=" << size.size() << ", dimensions must = 2");
+  }
+  std::array<uint32_t, 4> dim;
+  if (EthosnError err = AsArray<IndexExpr, uint32_t>(size, &dim)) {
+    return err;
+  }
+  *x = dim[0];
+  *y = dim[1];
+  return EthosnError();
+}
+
 EthosnError EthosnAPI::Tvm2Npu(const std::string& dformat, sl::DataFormat* data_format) {
   if (dformat == "NCHW") {
     *data_format = sl::DataFormat::NCHW;
@@ -286,6 +568,10 @@ EthosnError EthosnAPI::Tvm2Npu(int32_t zero_point, float scale, sl::Quantization
   return EthosnError();
 }
 
+EthosnError EthosnAPI::Tvm2Npu(const Array<Integer>& shape, sl::TensorShape* npu_shape) {
+  return AsArray<Integer, uint32_t>(shape, npu_shape);
+}
+
 EthosnError EthosnAPI::Tvm2Npu(const Array<Array<Integer>>& padding, sl::Padding* npu_padding) {
   if (padding.size() != 4) {
     return EthosnError(ErrStrm() << "padding tuple size=" << padding.size()
@@ -348,6 +634,56 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.conv2d")
       }
     });
 
+TVM_REGISTER_GLOBAL("relay.ethos-n.support.fc")
+    .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
+      Call call = args[0];
+      FullyConnectedParams params;
+      auto err = EthosnAPI::QnnFullyConnected(call, &params);
+      *rv = !err && sl::IsFullyConnectedSupported(params.bias_info, params.weights_info,
+                                                  params.fc_info, params.input_info);
+    });
+
+TVM_REGISTER_GLOBAL("relay.ethos-n.support.max_pool2d")
+    .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
+      Call call = args[0];
+      MaxPool2DParams params;
+      auto err = EthosnAPI::MaxPool2D(call, &params);
+      *rv = !err && sl::IsPoolingSupported(params.pool_info, params.input_info);
+    });
+
+TVM_REGISTER_GLOBAL("relay.ethos-n.support.avg_pool2d")
+    .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
+      Call call = args[0];
+      AvgPool2DParams params;
+      auto err = EthosnAPI::AvgPool2D(call, &params);
+      *rv = !err && sl::IsPoolingSupported(params.pool_info, params.input_info);
+    });
+
+TVM_REGISTER_GLOBAL("relay.ethos-n.support.reshape")
+    .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
+      Call call = args[0];
+      ReshapeParams params;
+      auto err = EthosnAPI::Reshape(call, &params);
+      *rv = !err && sl::IsReshapeSupported(params.new_shape, params.input_info);
+    });
+
+TVM_REGISTER_GLOBAL("relay.ethos-n.support.addition")
+    .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
+      Call call = args[0];
+      AdditionParams params;
+      auto err = EthosnAPI::Addition(call, &params);
+      *rv = !err && sl::IsAdditionSupported(params.lhs_info, params.rhs_info,
+                                            params.output_quantization_info);
+    });
+
+TVM_REGISTER_GLOBAL("relay.ethos-n.support.sigmoid")
+    .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
+      Call call = args[0];
+      SigmoidParams params;
+      auto err = EthosnAPI::Sigmoid(call, &params);
+      *rv = !err && sl::IsSigmoidSupported(params.input_info);
+    });
+
 TVM_REGISTER_GLOBAL("relay.ethos-n.support.concatenate")
     .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
       Call call = args[0];
@@ -364,6 +700,22 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.split")
       *rv = !err && sl::IsSplitSupported(params.input_info, params.split_info);
     });
 
+TVM_REGISTER_GLOBAL("relay.ethos-n.support.depth_to_space")
+    .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
+      Call call = args[0];
+      DepthToSpaceParams params;
+      auto err = EthosnAPI::DepthToSpace(call, &params);
+      *rv = !err && sl::IsDepthToSpaceSupported(params.input_info, params.depth_info);
+    });
+
+TVM_REGISTER_GLOBAL("relay.ethos-n.support.relu")
+    .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
+      Call call = args[0];
+      ReluParams params;
+      auto err = EthosnAPI::Relu(call, &params);
+      *rv = !err && sl::IsReluSupported(params.relu_info, params.input_info);
+    });
+
 TVM_REGISTER_GLOBAL("relay.ethos-n.query").set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
 #if defined ETHOSN_HW
   *rv = true;
index 20fe8be..e1b57b8 100644 (file)
@@ -54,6 +54,40 @@ struct ConvolutionParams {
   bool is_depthwise = false;
 };
 
+struct FullyConnectedParams {
+  sl::FullyConnectedInfo fc_info;
+  sl::TensorInfo input_info;
+  sl::TensorInfo weights_info;
+  sl::TensorInfo bias_info;
+  void* raw_weights = nullptr;
+  void* raw_bias = nullptr;
+};
+
+struct MaxPool2DParams {
+  sl::PoolingInfo pool_info = sl::PoolingInfo(0, 0, 0, 0, sl::Padding(), sl::PoolingType::MAX);
+  sl::TensorInfo input_info;
+};
+
+struct AvgPool2DParams {
+  sl::PoolingInfo pool_info = sl::PoolingInfo(0, 0, 0, 0, sl::Padding(), sl::PoolingType::AVG);
+  sl::TensorInfo input_info;
+};
+
+struct ReshapeParams {
+  sl::TensorShape new_shape{};
+  sl::TensorInfo input_info;
+};
+
+struct AdditionParams {
+  sl::QuantizationInfo output_quantization_info;
+  sl::TensorInfo lhs_info;
+  sl::TensorInfo rhs_info;
+};
+
+struct SigmoidParams {
+  sl::TensorInfo input_info;
+};
+
 struct ConcatenateParams {
   sl::QuantizationInfo qInfo;
   sl::ConcatenationInfo concat_info = sl::ConcatenationInfo(1, qInfo);
@@ -65,6 +99,16 @@ struct SplitParams {
   sl::TensorInfo input_info;
 };
 
+struct DepthToSpaceParams {
+  sl::DepthToSpaceInfo depth_info = sl::DepthToSpaceInfo(0);
+  sl::TensorInfo input_info;
+};
+
+struct ReluParams {
+  sl::ReluInfo relu_info;
+  sl::TensorInfo input_info;
+};
+
 /*!
  * \brief A wrapper around std::stringstream to build an EthosnError.
  */
@@ -127,13 +171,29 @@ class EthosnAPI {
  public:
   /*! \brief Extract the Support Library convolution params from an ethos-n.qnn_conv2d func */
   static EthosnError QnnConv2d(const Expr& expr, ConvolutionParams* params);
+  /*! \brief Extract the Support Library dense params from an ethos-n.qnn_fc func */
+  static EthosnError QnnFullyConnected(const Expr& expr, FullyConnectedParams* params);
+  /*! \brief Extract the Support Library max_pool2d params from a Relay max_pool2d call */
+  static EthosnError MaxPool2D(const Expr& expr, MaxPool2DParams* params);
+  /*! \brief Extract the Support Library avg_pool params from a Relay ethos-n.qnn_avg_pool2d func */
+  static EthosnError AvgPool2D(const Expr& expr, AvgPool2DParams* params);
+  /*! \brief Extract the Support Library reshape params from a Relay reshape call */
+  static EthosnError Reshape(const Expr& expr, ReshapeParams* params);
+  /*! \brief Extract the Support Library addition params from a Relay qnn.addition call */
+  static EthosnError Addition(const Expr& expr, AdditionParams* params);
+  /*! \brief Extract the Support Library sigmoid params from a Relay an ethos-n.qnn_sigmoid func */
+  static EthosnError Sigmoid(const Expr& expr, SigmoidParams* params);
   /*! \brief Extract the Support Library concatenate params from a Relay qnn.concatenate call */
   static EthosnError Concatenate(const Expr& expr, ConcatenateParams* params);
   /*! \brief Extract the Support Library split params from a Relay split call */
   static EthosnError Split(const Expr& expr, SplitParams* params);
+  /*! \brief Extract the Support Library depth_to_space params from a Relay depth_to_space call */
+  static EthosnError DepthToSpace(const Expr& expr, DepthToSpaceParams* params);
+  /*! \brief Extract the Support Library relu params from a Relay relu call */
+  static EthosnError Relu(const Expr& expr, ReluParams* params);
 
  private:
-  /*! \brief Convert a TVM tensor shape to a SL tensor shape */
+  /*! \brief Convert a TVM IndexExpr array to a SL tensor shape */
   static EthosnError Tvm2Npu(const Array<IndexExpr>& shape, sl::TensorShape* npu_shape);
   /*! \brief Convert a TVM data type to a SL data type */
   static EthosnError Tvm2Npu(const tvm::DataType& dtype, sl::DataType* data_type);
@@ -143,10 +203,19 @@ class EthosnAPI {
   static EthosnError Tvm2Npu(const Array<IndexExpr>& strides, sl::Stride* npu_stride);
   /*! \brief Convert TVM data format to SL data format */
   static EthosnError Tvm2Npu(const std::string& dformat, sl::DataFormat* data_format);
+  /*! \brief Convert TVM size array for pooling size to x and y values */
+  static EthosnError Tvm2Npu(const Array<IndexExpr>& size, uint32_t* x, uint32_t* y);
   /*! \brief Convert TVM quantization info to SL quantization info */
   static EthosnError Tvm2Npu(int32_t zero_point, float scale, sl::QuantizationInfo* npu_qinfo);
   /*! \brief Convert TVM 2D padding to SL padding */
   static EthosnError Tvm2Npu(const Array<Array<Integer>>& padding, sl::Padding* npu_padding);
+  /*! \brief Convert a TVM Integer array to a SL tensor shape */
+  static EthosnError Tvm2Npu(const Array<Integer>& shape, sl::TensorShape* npu_shape);
+  /*! \brief Convert a TVM pooling call to SL pooling information */
+  static EthosnError Pool2d(const Call& pool, Array<IndexExpr> size, Array<IndexExpr> strides,
+                            Array<IndexExpr> padding, sl::PoolingType pooling_type,
+                            sl::PoolingInfo* pool_info, sl::TensorInfo* input_info,
+                            std::string layout);
 
   // Convert an array of IntImmNodes into ValueT
   // IndexT type of Array indexing variable
index 0fbebcf..f5164e5 100644 (file)
@@ -120,6 +120,14 @@ Module EthosnModule::LoadFromBinary(void* strm) {
   return Module(n);
 }
 
+void EthosnModule::SaveToFile(const std::string& path, const std::string& format) {
+  std::string data;
+  dmlc::MemoryStringStream writer(&data);
+  dmlc::SeekStream* strm = &writer;
+  SaveToBinary(strm);
+  SaveBinaryToFile(path, data);
+}
+
 TVM_REGISTER_GLOBAL("runtime.module.loadbinary_ethos-n")
     .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = EthosnModule::LoadFromBinary(args[0]); });
 }  // namespace ethosn
index 730739c..7a111e8 100644 (file)
@@ -86,6 +86,11 @@ class EthosnModule : public ModuleNode {
    *       ] * number of functions
    */
   static Module LoadFromBinary(void* strm);
+  /*!
+   * \brief Save a module to a specified path.
+   * \param path Where to save the serialized module.
+   */
+  void SaveToFile(const std::string& path, const std::string& format) override;
 
   const char* type_key() const override { return "ethos-n"; }
 
index 2c88d56..31ebb1a 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 
-"""Expose Ethos test functions to the Python front end"""
+"""Ethos-N test functions"""
 
 from __future__ import absolute_import, print_function
 import tvm
 from tvm import relay
 from tvm.contrib import util, graph_runtime, download
-from tvm.relay.testing import run_opt_pass
-from enum import Enum
 from hashlib import md5
 from itertools import zip_longest, combinations
 import numpy as np
@@ -33,6 +31,25 @@ from . import _infrastructure
 from tvm.relay.op.contrib import get_pattern_table
 
 
+def get_real_image(im_height, im_width):
+    repo_base = "https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/"
+    img_name = "elephant-299.jpg"
+    image_url = os.path.join(repo_base, img_name)
+    img_path = download.download_testdata(image_url, img_name, module="data")
+    image = Image.open(img_path).resize((im_height, im_width))
+    x = np.array(image).astype("uint8")
+    data = np.reshape(x, (1, im_height, im_width, 3))
+    return data
+
+
+def assert_lib_hash(lib, golden):
+    temp = util.tempdir()
+    path = temp.relpath("lib.cmm")
+    lib.imported_modules[1].save(path)
+    lib_hash = md5(open(path, "rb").read()).hexdigest()
+    assert lib_hash == golden, "Expected hash: {} Got hash: {}".format(golden, lib_hash)
+
+
 def make_module(func, params):
     func = relay.Function(relay.analysis.free_vars(func), func)
     if params:
@@ -177,3 +194,54 @@ def test_error(mod, params, err_msg):
 
     assert caught is not None
     assert err_msg in caught, caught
+
+
+def get_conv2d(var, shape):
+    """Standard convolution to test activation functions"""
+
+    weight_shape = (1, 1, shape[3], 1)
+    w = tvm.nd.array(np.ones(weight_shape, "uint8"))
+    weights = relay.const(w, "uint8")
+    conv = relay.qnn.op.conv2d(
+        var,
+        weights,
+        input_zero_point=relay.const(0, "int32"),
+        kernel_zero_point=relay.const(0, "int32"),
+        input_scale=relay.const(1.0, "float32"),
+        kernel_scale=relay.const(1.0, "float32"),
+        kernel_size=(1, 1),
+        channels=1,
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+    )
+    b = tvm.nd.array(np.zeros((shape[0],), "int32"))
+    biasc = relay.const(b, "int32")
+    bias = relay.nn.bias_add(conv, biasc, axis=0)
+    req = relay.qnn.op.requantize(
+        bias,
+        relay.const(1.0, "float32"),  # input zero scale
+        relay.const(0, "int32"),  # input zero point
+        relay.const(1.1, "float32"),  # output zero scale
+        relay.const(0, "int32"),  # output zero point
+        out_dtype="uint8",
+    )
+    params = {"w": w, "b": b}
+    return req, params
+
+
+def get_conv2d_qnn_params(input_zp, input_sc, kernel_zp, kernel_sc, kernel_h, kernel_w, channels):
+    input_max = input_sc * (255 - input_zp)
+    input_min = -input_sc * input_zp
+    kernel_max = kernel_sc * (255 - kernel_zp)
+    kernel_min = -kernel_sc * kernel_zp
+    output_limits = [
+        kernel_max * kernel_h * kernel_w * channels * input_max,
+        kernel_min * kernel_h * kernel_w * channels * input_max,
+        kernel_min * kernel_h * kernel_w * channels * input_min,
+        kernel_max * kernel_h * kernel_w * channels * input_min,
+    ]
+    output_max = max(output_limits)
+    output_min = min(output_limits)
+    output_sc = (output_max - output_min) / 255
+    output_zp = -int(output_min / output_sc)
+    return output_zp, output_sc
diff --git a/tests/python/contrib/test_ethosn/test_addition.py b/tests/python/contrib/test_ethosn/test_addition.py
new file mode 100644 (file)
index 0000000..a332ab9
--- /dev/null
@@ -0,0 +1,116 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Ethos-N integration addition tests"""
+
+import tvm
+from tvm import relay
+from tvm.relay.op.contrib.ethosn import ethosn_available
+from . import infrastructure as tei
+import numpy as np
+
+
+def _get_model(input_shape, lhs_zp, lhs_sc, rhs_zp, rhs_sc, out_zp, out_sc, dtype):
+    """Return a model and any parameters it may have"""
+
+    a = relay.var("a", shape=input_shape, dtype=dtype)
+    b = relay.var("b", shape=input_shape, dtype=dtype)
+    model = relay.qnn.op.add(
+        lhs=a,
+        rhs=b,
+        lhs_scale=relay.const(lhs_sc, "float32"),
+        lhs_zero_point=relay.const(lhs_zp, "int32"),
+        rhs_scale=relay.const(rhs_sc, "float32"),
+        rhs_zero_point=relay.const(rhs_zp, "int32"),
+        output_scale=relay.const(out_sc, "float32"),
+        output_zero_point=relay.const(out_zp, "int32"),
+    )
+    return model
+
+
+def _get_addition_qnn_params(input1_zp, input1_sc, input2_zp, input2_sc):
+    input1_max = input1_sc * (255 - input1_zp)
+    input1_min = -input1_sc * input1_zp
+    input2_max = input2_sc * (255 - input2_zp)
+    input2_min = -input2_sc * input2_zp
+    output_max = input1_max + input2_max
+    output_min = input1_min + input2_min
+    output_sc = (output_max - output_min) / 255
+    output_zp = -int(output_min / output_sc)
+    return output_zp, output_sc
+
+
+def test_addition():
+    if not ethosn_available():
+        return
+
+    trials = [
+        ((1, 22, 9, 9), 24, 1.057, 253, 0.452),
+        ((1, 27, 21, 16), 79, 0.850, 24, 0.380),
+        ((1, 7, 12, 28), 125, 1.293, 239, 0.320),
+        ((1, 14, 9, 6), 14, 0.942, 227, 1.562),
+        ((1, 13, 16, 22), 15, 0.727, 180, 0.461),
+    ]
+    np.random.seed(0)
+    for shape, rhs_zp, rhs_sc, lhs_zp, lhs_sc in trials:
+        outputs = []
+        inputs = {
+            "a": tvm.nd.array(np.random.randint(0, high=255, size=shape, dtype="uint8")),
+            "b": tvm.nd.array(np.random.randint(0, high=255, size=shape, dtype="uint8")),
+        }
+        out_zp, out_sc = _get_addition_qnn_params(lhs_zp, lhs_sc, rhs_zp, rhs_sc)
+        model = _get_model(shape, lhs_zp, lhs_sc, rhs_zp, rhs_sc, out_zp, out_sc, "uint8")
+        for npu in [False, True]:
+            mod = tei.make_module(model, [])
+            outputs.append(tei.build_and_run(mod, inputs, 1, {}, npu=npu))
+
+        tei.verify(outputs, 2)
+
+
+def test_addition_failure():
+    if not ethosn_available():
+        return
+
+    trials = [
+        (
+            (2, 4, 4, 4),
+            "uint8",
+            0,
+            1,
+            0,
+            1,
+            0,
+            1,
+            "batch size=2, batch size must = 1; batch size=2, batch size must = 1",
+        ),
+        (
+            (1, 4, 4, 4),
+            "int8",
+            0,
+            1,
+            0,
+            1,
+            0,
+            1,
+            "dtype='int8', dtype must be either uint8 or int32; dtype='int8', dtype must be either uint8 or int32",
+        ),
+    ]
+
+    for shape, dtype, lhs_zp, lhs_sc, rhs_zp, rhs_sc, out_zp, out_sc, err_msg in trials:
+        model = _get_model(shape, lhs_zp, lhs_sc, rhs_zp, rhs_sc, out_zp, out_sc, dtype)
+        mod = tei.make_ethosn_partition(model)
+        tei.test_error(mod, {}, err_msg)
index 64052ce..f9fdcf5 100644 (file)
@@ -113,24 +113,6 @@ def _get_model(
     return req, params
 
 
-def _get_conv2d_qnn_params(input_zp, input_sc, kernel_zp, kernel_sc, kernel_h, kernel_w, channels):
-    input_max = input_sc * (255 - input_zp)
-    input_min = -input_sc * input_zp
-    kernel_max = kernel_sc * (255 - kernel_zp)
-    kernel_min = -kernel_sc * kernel_zp
-    output_limits = [
-        kernel_max * kernel_h * kernel_w * channels * input_max,
-        kernel_min * kernel_h * kernel_w * channels * input_max,
-        kernel_min * kernel_h * kernel_w * channels * input_min,
-        kernel_max * kernel_h * kernel_w * channels * input_min,
-    ]
-    output_max = max(output_limits)
-    output_min = min(output_limits)
-    output_sc = (output_max - output_min) / 255
-    output_zp = -int(output_min / output_sc)
-    return output_zp, output_sc
-
-
 def test_conv2d():
     if not ethosn_available():
         return
@@ -171,7 +153,7 @@ def test_conv2d():
             input_sc = np.random.random() * 2
             kernel_zp = np.random.randint(0, 255)
             kernel_sc = np.random.random() * 2
-            output_zp, output_sc = _get_conv2d_qnn_params(
+            output_zp, output_sc = tei.get_conv2d_qnn_params(
                 input_zp, input_sc, kernel_zp, kernel_sc, kernel_h, kernel_w, shape[3]
             )
             model, params = _get_model(
diff --git a/tests/python/contrib/test_ethosn/test_depth_to_space.py b/tests/python/contrib/test_ethosn/test_depth_to_space.py
new file mode 100644 (file)
index 0000000..7daf888
--- /dev/null
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Ethos-N integration depth-to-space tests"""
+
+import tvm
+from tvm import relay
+from tvm.relay.op.contrib.ethosn import ethosn_available
+from . import infrastructure as tei
+import numpy as np
+
+
+def _get_model(shape, block, dtype, layout):
+    a = relay.var("a", shape=shape, dtype=dtype)
+    depth = relay.nn.depth_to_space(a, layout=layout, block_size=block)
+    return depth
+
+
+def test_depth_to_space():
+    if not ethosn_available():
+        return
+
+    trials = [
+        (1, 16, 16, 16),
+        (1, 64, 32, 16),
+    ]
+
+    for shape in trials:
+        inputs = {
+            "a": tvm.nd.array(np.random.randint(0, high=255, size=shape, dtype="uint8")),
+        }
+        outputs = []
+        for npu in [False, True]:
+            model = _get_model(shape, 2, "uint8", "NHWC")
+            mod = tei.make_module(model, {})
+            outputs.append(tei.build_and_run(mod, inputs, 1, {}, npu=npu))
+
+        tei.verify(outputs, 1)
+
+
+def test_depth_to_space_failure():
+    if not ethosn_available():
+        return
+
+    trials = [
+        ((2, 16, 16, 16), 2, "uint8", "NHWC", "batch size=2, batch size must = 1"),
+        ((1, 16, 16, 16), 2, "int8", "NHWC", "dtype='int8', dtype must be either uint8 or int32"),
+        ((1, 16, 16, 16), 4, "uint8", "NHWC", "Only block size of 2 is supported"),
+        ((1, 16, 16, 16), 2, "uint8", "NCHW", "layout=NCHW, layout must = NHWC"),
+    ]
+
+    for shape, block, dtype, layout, err_msg in trials:
+        model = _get_model(shape, block, dtype, layout)
+        mod = tei.make_ethosn_partition(model)
+        tei.test_error(mod, {}, err_msg)
diff --git a/tests/python/contrib/test_ethosn/test_fullyconnected.py b/tests/python/contrib/test_ethosn/test_fullyconnected.py
new file mode 100644 (file)
index 0000000..09d07f6
--- /dev/null
@@ -0,0 +1,155 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Ethos-N integration fully connected tests"""
+
+import numpy as np
+import tvm
+from tvm import relay
+from tvm.relay.op.contrib.ethosn import ethosn_available
+from . import infrastructure as tei
+
+
+def _get_model(
+    shape, weight_shape, input_zp, input_sc, kernel_zp, kernel_sc, output_zp, output_sc, dtype
+):
+    """Return a model an any parameters it may have"""
+    a = relay.var("a", shape=shape, dtype=dtype)
+    w = tvm.nd.array(np.ones(weight_shape, dtype))
+    weights = relay.const(w, dtype)
+    fc = relay.qnn.op.dense(
+        a,
+        weights,
+        input_zero_point=relay.const(input_zp, "int32"),
+        kernel_zero_point=relay.const(kernel_zp, "int32"),
+        input_scale=relay.const(input_sc, "float32"),
+        kernel_scale=relay.const(kernel_sc, "float32"),
+        units=weight_shape[0],
+        out_dtype="int32",
+    )
+    b = tvm.nd.array(np.random.randint(0, high=255, size=(shape[0],), dtype="int32"))
+    biasc = relay.const(b, "int32")
+    bias = relay.nn.bias_add(fc, biasc, axis=0)
+    req = relay.qnn.op.requantize(
+        bias,
+        relay.const(input_sc * kernel_sc, "float32"),  # input zero scale
+        relay.const(input_zp * kernel_zp, "int32"),  # input zero point
+        relay.const(output_sc, "float32"),  # output zero scale
+        relay.const(output_zp, "int32"),  # output zero point
+        out_dtype="uint8",
+    )
+    params = {"w": w, "b": b}
+    return req, params
+
+
+def test_fullyconnected():
+    if not ethosn_available():
+        return
+
+    trials = [
+        ((1, 1024), 71, 0.580, 79, 1.498),
+        ((1, 4096), 166, 1.724, 117, 0.180),
+        ((1, 16384), 101, 1.372, 21, 1.346),
+    ]
+    np.random.seed(0)
+    for shape, input_zp, input_sc, kernel_zp, kernel_sc in trials:
+        inputs = {
+            "a": tvm.nd.array(np.random.randint(0, high=255, size=shape, dtype="uint8")),
+        }
+        outputs = []
+        output_zp, output_sc = tei.get_conv2d_qnn_params(
+            input_zp, input_sc, kernel_zp, kernel_sc, shape[0], shape[1], 1
+        )
+        for npu in [False, True]:
+            model, params = _get_model(
+                shape,
+                shape,
+                input_zp,
+                input_sc,  # input zp, sc
+                kernel_zp,
+                kernel_sc,  # kernel
+                output_zp,
+                output_sc,  # output
+                "uint8",
+            )
+            mod = tei.make_module(model, params)
+            outputs.append(tei.build_and_run(mod, inputs, 1, params, npu=npu))
+        tei.verify(outputs, 1)
+
+
+def test_fullyconnected_failure():
+    if not ethosn_available():
+        return
+
+    trials = [
+        (
+            (1, 64),
+            (1, 64),
+            0,
+            1,
+            0,
+            1,
+            0,
+            1,
+            "uint8",
+            "Overall scale (of the input * weights / output) should be in the range [0, 1)",
+        ),
+        (
+            (1, 1, 1, 64),
+            (1, 64),
+            0,
+            1,
+            0,
+            1,
+            0,
+            1,
+            "uint8",
+            "Weights tensor must have I dimension equal to the number of channels of the input tensor.;",
+        ),
+        ((1024, 64), (1, 64), 0, 1, 0, 1, 0, 1, "uint8", "batch size=1024, batch size must = 1;"),
+    ]
+
+    np.random.seed(0)
+    for (
+        shape,
+        weight_shape,
+        input_zp,
+        input_sc,
+        kernel_zp,
+        kernel_sc,
+        output_zp,
+        output_sc,
+        dtype,
+        err_msg,
+    ) in trials:
+        inputs = {
+            "a": tvm.nd.array(np.random.randint(0, high=255, size=shape, dtype=dtype)),
+        }
+        model, params = _get_model(
+            shape,
+            weight_shape,
+            input_zp,
+            input_sc,
+            kernel_zp,
+            kernel_sc,
+            output_zp,
+            output_sc,
+            dtype,
+        )
+        model = tei.make_ethosn_composite(model, "ethos-n.qnn_fc")
+        mod = tei.make_ethosn_partition(model)
+        tei.test_error(mod, {}, err_msg)
diff --git a/tests/python/contrib/test_ethosn/test_networks.py b/tests/python/contrib/test_ethosn/test_networks.py
new file mode 100644 (file)
index 0000000..8c6fd43
--- /dev/null
@@ -0,0 +1,175 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Ethos-N integration end-to-end network tests"""
+
+import pytest
+
+pytest.importorskip("tflite")
+pytest.importorskip("tensorflow")
+
+from tvm import relay
+from tvm.relay.op.contrib.ethosn import ethosn_available, Available
+from tvm.contrib import download
+import tvm.relay.testing.tf as tf_testing
+import tflite.Model
+from . import infrastructure as tei
+
+
+def _get_tflite_model(tflite_model_path, inputs_dict, dtype):
+    with open(tflite_model_path, "rb") as f:
+        tflite_model_buffer = f.read()
+
+    try:
+        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buffer, 0)
+    except AttributeError:
+        tflite_model = tflite.Model.GetRootAsModel(tflite_model_buffer, 0)
+    shape_dict = {}
+    dtype_dict = {}
+    for input in inputs_dict:
+        input_shape = inputs_dict[input]
+        shape_dict[input] = input_shape
+        dtype_dict[input] = dtype
+
+    return relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict=shape_dict,
+        dtype_dict=dtype_dict,
+    )
+
+
+def _test_image_network(
+    model_url,
+    model_sub_path,
+    input_dict,
+    compile_hash,
+    output_count,
+    run=True,
+    host_ops=0,
+    npu_partitions=1,
+):
+    if not ethosn_available():
+        return
+
+    def get_model():
+        if model_url[-3:] in ("tgz", "zip"):
+            model_path = tf_testing.get_workload_official(
+                model_url,
+                model_sub_path,
+            )
+        else:
+            model_path = download.download_testdata(
+                model_url,
+                model_sub_path,
+            )
+        return _get_tflite_model(model_path, input_dict, "uint8")
+
+    outputs = []
+    inputs = {}
+    for input_name in input_dict:
+        input_shape = input_dict[input_name]
+        inputs[input_name] = tei.get_real_image(input_shape[1], input_shape[2])
+
+    for npu in [False, True]:
+        mod, params = get_model()
+        graph, lib, params = tei.build(
+            mod, params, npu=npu, expected_host_ops=host_ops, npu_partitions=npu_partitions
+        )
+        if npu:
+            tei.assert_lib_hash(lib, compile_hash)
+        if run:
+            outputs.append(tei.run(graph, lib, params, inputs, output_count, npu=npu))
+
+    if run:
+        tei.verify(outputs, 1, verify_saturation=False)
+
+
+def test_mobilenet_v1():
+    # If this test is failing due to a hash mismatch, please notify @mbaret and
+    # @Leo-arm. The hash is there to catch any changes in the behaviour of the
+    # codegen, which could come about from either a change in Support Library
+    # version or a change in the Ethos-N codegen. To update this requires running
+    # on hardware that isn't available in CI.
+    hw = ethosn_available()
+    _test_image_network(
+        model_url="https://storage.googleapis.com/download.tensorflow.org/"
+        "models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
+        model_sub_path="mobilenet_v1_1.0_224_quant.tflite",
+        input_dict={"input": (1, 224, 224, 3)},
+        compile_hash="81637c89339201a07dc96e3b5dbf836a",
+        output_count=1,
+        run=(hw == Available.SW_AND_HW),
+        host_ops=3,
+        npu_partitions=1,
+    )
+
+
+def test_inception_v3():
+    # If this test is failing due to a hash mismatch, please notify @mbaret and
+    # @Leo-arm. The hash is there to catch any changes in the behaviour of the
+    # codegen, which could come about from either a change in Support Library
+    # version or a change in the Ethos-N codegen. To update this requires running
+    # on hardware that isn't available in CI.
+    _test_image_network(
+        model_url="https://storage.googleapis.com/download.tensorflow.org/"
+        "models/tflite_11_05_08/inception_v3_quant.tgz",
+        model_sub_path="inception_v3_quant.tflite",
+        input_dict={"input": (1, 299, 299, 3)},
+        compile_hash="de0e175af610ebd45ccb03d170dc9664",
+        output_count=1,
+        run=False,
+        host_ops=0,
+        npu_partitions=1,
+    )
+
+
+def test_inception_v4():
+    # If this test is failing due to a hash mismatch, please notify @mbaret and
+    # @Leo-arm. The hash is there to catch any changes in the behaviour of the
+    # codegen, which could come about from either a change in Support Library
+    # version or a change in the Ethos-N codegen. To update this requires running
+    # on hardware that isn't available in CI.
+    _test_image_network(
+        model_url="https://storage.googleapis.com/download.tensorflow.org/"
+        "models/inception_v4_299_quant_20181026.tgz",
+        model_sub_path="inception_v4_299_quant.tflite",
+        input_dict={"input": (1, 299, 299, 3)},
+        compile_hash="06bf6cb56344f3904bcb108e54edfe87",
+        output_count=1,
+        run=False,
+        host_ops=3,
+        npu_partitions=1,
+    )
+
+
+def test_ssd_mobilenet_v1():
+    # If this test is failing due to a hash mismatch, please notify @mbaret and
+    # @Leo-arm. The hash is there to catch any changes in the behaviour of the
+    # codegen, which could come about from either a change in Support Library
+    # version or a change in the Ethos-N codegen. To update this requires running
+    # on hardware that isn't available in CI.
+    _test_image_network(
+        model_url="https://storage.googleapis.com/download.tensorflow.org/"
+        "models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip",
+        model_sub_path="detect.tflite",
+        input_dict={"normalized_input_image_tensor": (1, 300, 300, 3)},
+        compile_hash="6211d96103880b016baa85e638abddef",
+        output_count=4,
+        run=False,
+        host_ops=28,
+        npu_partitions=2,
+    )
diff --git a/tests/python/contrib/test_ethosn/test_pooling.py b/tests/python/contrib/test_ethosn/test_pooling.py
new file mode 100644 (file)
index 0000000..6b2330f
--- /dev/null
@@ -0,0 +1,123 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Ethos-N integration pooling tests"""
+
+import numpy as np
+import tvm
+from tvm import relay
+from tvm.relay.op.contrib.ethosn import ethosn_available
+from . import infrastructure as tei
+
+
+def _get_model(shape, typef, sizes, strides, pads, layout, dtype):
+    """Return a model and any parameters it may have"""
+    req = relay.var("a", shape=shape, dtype=dtype)
+    if typef == relay.nn.avg_pool2d:
+        req = relay.cast(req, "int32")
+    req = typef(req, pool_size=sizes, strides=strides, padding=pads, ceil_mode=True, layout=layout)
+    if typef == relay.nn.avg_pool2d:
+        req = relay.cast(req, dtype)
+    return req
+
+
+def test_pooling():
+    if not ethosn_available():
+        return
+
+    trials = [
+        ((1, 8, 8, 8), relay.nn.max_pool2d, (2, 2), (2, 2), (0, 0, 0, 0), "NHWC"),
+        ((1, 9, 9, 9), relay.nn.max_pool2d, (2, 2), (2, 2), (0, 0, 1, 1), "NHWC"),
+        ((1, 9, 9, 9), relay.nn.max_pool2d, (3, 3), (2, 2), (0, 0, 0, 0), "NHWC"),
+        ((1, 8, 8, 8), relay.nn.max_pool2d, (3, 3), (2, 2), (0, 0, 1, 1), "NHWC"),
+        ((1, 8, 8, 8), relay.nn.avg_pool2d, (3, 3), (1, 1), (1, 1, 1, 1), "NHWC"),
+    ]
+
+    for shape, typef, size, stride, pad, layout in trials:
+        inputs = {
+            "a": tvm.nd.array(np.random.randint(low=0, high=255, size=shape, dtype="uint8")),
+        }
+        outputs = []
+        model = _get_model(shape, typef, size, stride, pad, layout, "uint8")
+        for npu in [False, True]:
+            mod = tei.make_module(model, {})
+            outputs.append(tei.build_and_run(mod, inputs, 1, {}, npu=npu))
+
+        tei.verify(outputs, 1)
+
+
+def test_pooling_failure():
+    if not ethosn_available():
+        return
+
+    trials = [
+        (
+            (2, 8, 8, 8),
+            relay.nn.max_pool2d,
+            (2, 2),
+            (2, 2),
+            (0, 0, 0, 0),
+            "NHWC",
+            "uint8",
+            "batch size=2, batch size must = 1",
+        ),
+        (
+            (1, 8, 8, 8),
+            relay.nn.max_pool2d,
+            (2, 2),
+            (2, 2),
+            (0, 0, 0, 0),
+            "NHWC",
+            "int8",
+            "dtype='int8', dtype must be either uint8 or int32",
+        ),
+        (
+            (1, 8, 8, 8),
+            relay.nn.max_pool2d,
+            (2, 2),
+            (2, 2),
+            (0, 0, 0, 0),
+            "NCHW",
+            "uint8",
+            "data format=NCHW, data format must = NHWC",
+        ),
+        (
+            (1, 8, 8, 8),
+            relay.nn.max_pool2d,
+            (2, 2),
+            (2, 2, 2),
+            (0, 0, 0, 0),
+            "NHWC",
+            "uint8",
+            "stride size=3, stride size must = 2",
+        ),
+        (
+            (1, 8, 8, 8),
+            relay.nn.max_pool2d,
+            (2, 2, 2),
+            (2, 2),
+            (0, 0, 0, 0),
+            "NHWC",
+            "uint8",
+            "dimensions=3, dimensions must = 2",
+        ),
+    ]
+
+    for shape, typef, size, stride, pad, layout, dtype, err_msg in trials:
+        model = _get_model(shape, typef, size, stride, pad, layout, dtype)
+        mod = tei.make_ethosn_partition(model)
+        tei.test_error(mod, {}, err_msg)
diff --git a/tests/python/contrib/test_ethosn/test_relu.py b/tests/python/contrib/test_ethosn/test_relu.py
new file mode 100644 (file)
index 0000000..6b366e6
--- /dev/null
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Ethos-N integration relu tests"""
+
+import tvm
+from tvm import relay
+from tvm.relay.op.contrib.ethosn import ethosn_available
+from . import infrastructure as tei
+import numpy as np
+
+
+def _get_model(shape, dtype, a_min, a_max):
+    a = relay.var("a", shape=shape, dtype=dtype)
+    relu = relay.clip(a, a_min=a_min, a_max=a_max)
+    return relu
+
+
+def test_relu():
+    if not ethosn_available():
+        return
+
+    trials = [
+        ((1, 4, 4, 4), 65, 178),
+        ((1, 8, 4, 2), 1, 254),
+        ((1, 16), 12, 76),
+    ]
+
+    for shape, a_min, a_max in trials:
+        inputs = {
+            "a": tvm.nd.array(np.random.randint(0, high=255, size=shape, dtype="uint8")),
+        }
+        outputs = []
+        for npu in [False, True]:
+            model = _get_model(inputs["a"].shape, "uint8", a_min, a_max)
+            mod = tei.make_module(model, {})
+            outputs.append(tei.build_and_run(mod, inputs, 1, {}, npu=npu))
+
+        tei.verify(outputs, 1)
+
+
+def test_relu_failure():
+    if not ethosn_available():
+        return
+
+    trials = [
+        ((1, 4, 4, 4, 4), "uint8", 65, 78, "dimensions=5, dimensions must be <= 4"),
+        ((1, 8, 4, 2), "int8", 1, 254, "dtype='int8', dtype must be either uint8 or int32"),
+        ((1, 8, 4, 2), "uint8", 254, 1, "Relu has lower bound > upper bound"),
+        ((2, 2, 2, 2), "uint8", 1, 63, "batch size=2, batch size must = 1; "),
+    ]
+
+    for shape, dtype, a_min, a_max, err_msg in trials:
+        model = _get_model(shape, dtype, a_min, a_max)
+        mod = tei.make_ethosn_partition(model)
+        tei.test_error(mod, {}, err_msg)
diff --git a/tests/python/contrib/test_ethosn/test_reshape.py b/tests/python/contrib/test_ethosn/test_reshape.py
new file mode 100644 (file)
index 0000000..e15ddd6
--- /dev/null
@@ -0,0 +1,81 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Ethos-N integration reshape tests"""
+
+import tvm
+from tvm import relay
+from tvm.relay.op.contrib import ethosn_available, get_pattern_table
+from . import infrastructure as tei
+import numpy as np
+
+
+def _get_model(input_shape, output_shape, dtype):
+    """Return a model and any parameters it may have"""
+    a = relay.var("a", shape=input_shape, dtype=dtype)
+    conv, params = tei.get_conv2d(a, input_shape)
+    req = relay.reshape(conv, output_shape)
+    return req, params
+
+
+def test_reshape():
+    if not ethosn_available():
+        return
+
+    trials = [
+        ((1, 15, 4, 1), (60,)),
+        ((1, 15, 4, 1), (30, 2)),
+        ((1, 15, 4, 1), (1, 4, 15, 1)),
+        ((1, 15, 4, 1), (1, 12, 5, 1)),
+        ((1, 15, 4, 1), (1, -1, 2, 1)),
+    ]
+
+    np.random.seed(0)
+    for input_shape, output_shape in trials:
+        inputs = {
+            "a": tvm.nd.array(np.random.randint(0, high=255, size=input_shape, dtype="uint8"))
+        }
+        outputs = []
+        for npu in [False, True]:
+            model, params = _get_model(input_shape, output_shape, "uint8")
+            mod = tei.make_module(model, params)
+            outputs.append(tei.build_and_run(mod, inputs, 1, params, npu=npu))
+
+        tei.verify(outputs, 1)
+
+
+def test_reshape_failure():
+    if not ethosn_available():
+        return
+
+    trials = [
+        (
+            (1, 15, 4, 1),
+            (1, 15, -2),
+            "uint8",
+            "reshape dimension=-2, reshape dimension must be >= -1",
+        ),
+    ]
+
+    np.random.seed(0)
+    for input_shape, output_shape, dtype, err_msg in trials:
+        model, params = _get_model(input_shape, output_shape, dtype)
+        mod = tei.make_module(model, params)
+        pattern = get_pattern_table("ethos-n")
+        mod = relay.transform.MergeComposite(pattern)(mod)
+        mod = tei.make_ethosn_partition(mod["main"].body)
+        tei.test_error(mod, {}, err_msg)
diff --git a/tests/python/contrib/test_ethosn/test_sigmoid.py b/tests/python/contrib/test_ethosn/test_sigmoid.py
new file mode 100644 (file)
index 0000000..f3018dd
--- /dev/null
@@ -0,0 +1,97 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Ethos-N integration sigmoid tests"""
+
+import tvm
+from tvm import relay
+from tvm.relay.op.contrib.ethosn import ethosn_available
+from . import infrastructure as tei
+import numpy as np
+
+
+def _get_model(shape, input_zp, input_sc, output_zp, output_sc, dtype):
+    a = relay.var("a", shape=shape, dtype=dtype)
+    dequantize = relay.qnn.op.dequantize(
+        a,
+        input_scale=relay.const(input_sc, "float32"),
+        input_zero_point=relay.const(input_zp, "int32"),
+    )
+    sigmoid = relay.sigmoid(dequantize)
+    model = relay.qnn.op.quantize(
+        sigmoid,
+        output_scale=relay.const(output_sc, "float32"),
+        output_zero_point=relay.const(output_zp, "int32"),
+        out_dtype=dtype,
+    )
+    return model
+
+
+def test_sigmoid():
+    if not ethosn_available():
+        return
+
+    trials = [
+        (1, 16, 16, 16),
+        (1, 8, 8),
+    ]
+
+    np.random.seed(0)
+    for shape in trials:
+        inputs = {
+            "a": tvm.nd.array(np.random.randint(0, high=255, size=shape, dtype="uint8")),
+        }
+        outputs = []
+        for npu in [False, True]:
+            model = _get_model(shape, 64, 0.02, 0, 1 / 256, "uint8")
+            mod = tei.make_module(model, [])
+            outputs.append(tei.build_and_run(mod, inputs, 1, {}, npu=npu))
+
+        tei.verify(outputs, 1)
+
+
+def test_sigmoid_failure():
+    if not ethosn_available():
+        return
+
+    trials = [
+        ((2, 4, 4, 4), 64, 0.2, 0, 1 / 256, "uint8", "batch size=2, batch size must = 1"),
+        (
+            (1, 4, 4, 4),
+            64,
+            0.2,
+            0,
+            1 / 256,
+            "int8",
+            "dtype='int8', dtype must be either uint8 or int32",
+        ),
+        (
+            (1, 4, 4, 4),
+            64,
+            0.2,
+            0,
+            1,
+            "uint8",
+            "output quantization params=(0, 1), must = (0, 1/256)",
+        ),
+    ]
+
+    for shape, input_zp, input_sc, output_zp, output_sc, dtype, err_msg in trials:
+        model = _get_model(shape, input_zp, input_sc, output_zp, output_sc, dtype)
+        model = tei.make_ethosn_composite(model, "ethos-n.qnn_sigmoid")
+        mod = tei.make_ethosn_partition(model)
+        tei.test_error(mod, {}, err_msg)
index 0cf5720..89099db 100644 (file)
 import numpy as np
 import tvm
 from tvm import relay
-from tvm.relay.op.contrib.ethosn import ethosn_available
+from tvm.relay.op.contrib.ethosn import ethosn_available, Available
 from . import infrastructure as tei
 
 
+def test_split_add_concat():
+    if not ethosn_available():
+        return
+
+    def get_model(input_shape, var_names):
+        """Return a model"""
+
+        a = relay.var(next(var_names), shape=input_shape, dtype="uint8")
+        split_scale = relay.const(0.25, "float32")
+        split_zp = relay.const(100, "int32")
+        add_scale = relay.const(0.75, "float32")
+        add_zp = relay.const(120, "int32")
+        axis = 2
+
+        split = relay.split(a, indices_or_sections=4, axis=axis)
+        b = relay.qnn.op.add(
+            split[0],
+            split[1],
+            lhs_scale=split_scale,
+            lhs_zero_point=split_zp,
+            rhs_scale=split_scale,
+            rhs_zero_point=split_zp,
+            output_scale=add_scale,
+            output_zero_point=add_zp,
+        )
+        conc = relay.qnn.op.concatenate(
+            [b, split[2], split[3]],
+            input_scales=(add_scale, split_scale, split_scale),
+            input_zero_points=(add_zp, split_zp, split_zp),
+            output_scale=add_scale,
+            output_zero_point=add_zp,
+            axis=axis,
+        )
+        return conc
+
+    inputs = {
+        "a": tvm.nd.array(np.random.randint(0, high=255, size=(1, 16, 16, 4), dtype="uint8")),
+    }
+
+    outputs = []
+    for npu in [False, True]:
+        model = get_model(inputs["a"].shape, iter(inputs))
+        mod = tei.make_module(model, [])
+        outputs.append(tei.build_and_run(mod, inputs, 1, {}, npu=npu))
+
+    tei.verify(outputs, 2)
+
+
+def test_multiple_command_streams():
+    """Check that multiple Ethos-N partitions are correctly handled.
+
+    If there's more than one Ethos-N graph partition, more than one command
+    stream will be created. This should be handled correctly by both the
+    Ethos-N codegen and Ethos-N runtime module. This test checks against a
+    simple graph which creates two Ethos-N partitions and checks the result
+    against an 'all-CPU' run through TVM.
+    """
+    if ethosn_available() != Available.SW_AND_HW:
+        return
+
+    def get_model():
+        """
+        max_pool2d
+             |
+            abs
+             |
+        max_pool2d
+        """
+        x = relay.var("x", shape=(1, 4, 4, 4), dtype="uint8")
+        out = relay.nn.max_pool2d(x, (2, 2), (2, 2), layout="NHWC")  # supported
+        out = relay.op.abs(out)  # not supported
+        out = relay.nn.max_pool2d(out, (2, 2), (2, 2), layout="NHWC")  # supported
+        return out
+
+    np.random.seed(0)
+    outputs = []
+    inputs = {"x": tvm.nd.array(np.random.randint(0, high=256, size=(1, 4, 4, 4), dtype="uint8"))}
+    for npu in [False, True]:
+        model = get_model()
+        mod = tei.make_module(model, {})
+        outputs.append(
+            tei.build_and_run(mod, inputs, 1, {}, npu=npu, expected_host_ops=1, npu_partitions=2)
+        )
+
+    tei.verify(outputs, 0)
+
+
+def test_output_order():
+    if not ethosn_available():
+        return
+
+    def get_model(input_shape, var_names):
+        """Return a model"""
+
+        a = relay.var(next(var_names), shape=input_shape, dtype="uint8")
+
+        z = relay.op.clip(a, 0, 255)
+        b = relay.op.clip(z, 0, 15)
+        c = relay.op.clip(z, 16, 31)
+        d = relay.op.clip(z, 32, 47)
+        e = relay.op.clip(z, 48, 63)
+        f = relay.op.clip(z, 64, 79)
+        g = relay.op.clip(z, 80, 95)
+        h = relay.op.clip(z, 96, 111)
+        i = relay.op.clip(z, 112, 127)
+        return relay.Tuple((d, c, e, f, i, b, h, g))
+
+    inputs = {
+        "a": tvm.nd.array(np.random.randint(0, high=255, size=(1, 16, 16, 4), dtype="uint8")),
+    }
+
+    outputs = []
+    for npu in [False, True]:
+        model = get_model(inputs["a"].shape, iter(inputs))
+        mod = tei.make_module(model, [])
+        outputs.append(tei.build_and_run(mod, inputs, 8, {}, npu=npu))
+
+    tei.verify(outputs, 1)
+
+
 def test_split_with_asym_concats():
     if not ethosn_available():
         return