[Relay] Bitserial ops (#3844)
authorJosh Fromm <jwfromm@uw.edu>
Sun, 1 Sep 2019 00:51:51 +0000 (17:51 -0700)
committerJared Roesch <roeschinc@gmail.com>
Sun, 1 Sep 2019 00:51:51 +0000 (17:51 -0700)
* Added arm_cpu NHWC schedules.

* Fixed kernel shape legalization.

* Added bitserial ops to relay.

* Snapshot and more missing files.

* Added dense testing.

* Added tests

* Added ASF header to new files.

* cc lint

* Pylint change.

* pylint fixes.

* Change arm legalize test.

* Added assert check to arm legalize.

* Added better documentation, fixed some bad style

* Reverted arm conv2d nhwc changes.

include/tvm/relay/attrs/bitserial.h [new file with mode: 0644]
python/tvm/relay/op/nn/_nn.py
python/tvm/relay/op/nn/nn.py
python/tvm/relay/op/op_attrs.py
src/relay/op/nn/bitserial.cc [new file with mode: 0644]
tests/python/relay/test_op_level1.py
tests/python/relay/test_op_level2.py
topi/python/topi/arm_cpu/bitserial_conv2d.py
topi/python/topi/generic/nn.py
topi/python/topi/nn/bitserial_conv2d.py

diff --git a/include/tvm/relay/attrs/bitserial.h b/include/tvm/relay/attrs/bitserial.h
new file mode 100644 (file)
index 0000000..2a7376b
--- /dev/null
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/attrs/bitserial.h
+ * \brief Auxiliary attributes for bitserial operators.
+ */
+
+#ifndef TVM_RELAY_ATTRS_BITSERIAL_H_
+#define TVM_RELAY_ATTRS_BITSERIAL_H_
+
+#include <tvm/attrs.h>
+#include <tvm/relay/base.h>
+#include <string>
+
+namespace tvm {
+namespace relay {
+
+/*! \brief Attributes used in bitpack operators */
+struct BitPackAttrs : public tvm::AttrsNode<BitPackAttrs> {
+  int bits;
+  int pack_axis;
+  int bit_axis;
+  DataType pack_type;
+  std::string name;
+
+  TVM_DECLARE_ATTRS(BitPackAttrs, "relay.attrs.BitPackAttrs") {
+    TVM_ATTR_FIELD(bits).set_default(1).describe("Number of bits to quantize with.");
+    TVM_ATTR_FIELD(pack_axis).set_default(1).describe(
+        "Axis that should be compressed, typically channels.");
+    TVM_ATTR_FIELD(bit_axis).set_default(-1).describe("New axis for packed bits.");
+    TVM_ATTR_FIELD(pack_type)
+        .set_default(NullValue<DataType>())
+        .describe("Type of int to pack bits into.");
+    TVM_ATTR_FIELD(name).set_default("BitPack").describe("Name of operation.");
+  }
+};
+
+/*! \brief Attribues used in bitserial convolution operators */
+struct BinaryConv2DAttrs : public tvm::AttrsNode<BinaryConv2DAttrs> {
+  Array<IndexExpr> strides;
+  Array<IndexExpr> padding;
+  IndexExpr channels;
+  Array<IndexExpr> kernel_size;
+  int activation_bits;
+  int weight_bits;
+  std::string data_layout;
+  std::string kernel_layout;
+  DataType pack_dtype;
+  DataType out_dtype;
+  bool unipolar;
+
+  TVM_DECLARE_ATTRS(BinaryConv2DAttrs, "relay.attrs.BinaryConv2DAttrs") {
+    TVM_ATTR_FIELD(strides)
+        .set_default(Array<IndexExpr>({1, 1}))
+        .describe("Specifies the strides of the convolution.");
+    TVM_ATTR_FIELD(padding)
+        .set_default(Array<IndexExpr>({0, 0}))
+        .describe(
+            "If padding is non-zero the input is implicitly zero-padded"
+            "on both sides for padding number of points.");
+    TVM_ATTR_FIELD(kernel_size)
+        .set_default(Array<IndexExpr>({3, 3}))
+        .describe("Specifies the dimensions of the convolution window.");
+    TVM_ATTR_FIELD(channels)
+        .set_default(NullValue<IndexExpr>())
+        .describe("Number of output channels, needed for shape inference.");
+    TVM_ATTR_FIELD(activation_bits)
+        .set_default(1)
+        .describe("Number of bits activation should be packed with.");
+    TVM_ATTR_FIELD(weight_bits)
+        .set_default(1)
+        .describe("Number of bits kernel should be packed with.");
+    TVM_ATTR_FIELD(data_layout)
+        .set_default("NCHW")
+        .describe("Dimension ordering of input data, can be 'NCHW' or NHWC'.");
+    TVM_ATTR_FIELD(kernel_layout)
+        .set_default("OIHW")
+        .describe("Dimension ordering of kernel data, can be 'OIHW' or HWIO'.");
+    TVM_ATTR_FIELD(pack_dtype)
+        .set_default(NullValue<DataType>())
+        .describe("Datatype to pack bits into.");
+    TVM_ATTR_FIELD(out_dtype).set_default(NullValue<DataType>()).describe("Output datatype.");
+    TVM_ATTR_FIELD(unipolar).set_default(true).describe(
+        "Whether to use unipolar or bipolar quantization.");
+  }
+};
+
+/*~ \brief Attributes for bitserial dense operator */
+struct BinaryDenseAttrs : public tvm::AttrsNode<BinaryDenseAttrs> {
+  IndexExpr units;
+  int data_bits;
+  int weight_bits;
+  DataType pack_dtype;
+  DataType out_dtype;
+  bool unipolar;
+
+  TVM_DECLARE_ATTRS(BinaryDenseAttrs, "relay.attrs.BinaryDenseAttrs") {
+    TVM_ATTR_FIELD(units)
+      .describe("Number of hidden units of the dense transformation.");
+    TVM_ATTR_FIELD(data_bits)
+      .set_default(1)
+      .describe("Number of bits to pack for incoming tensor.");
+    TVM_ATTR_FIELD(weight_bits)
+      .set_default(1)
+      .describe("Number of bits to pack for weight tensor.");
+    TVM_ATTR_FIELD(pack_dtype)
+      .set_default(NullValue<DataType>())
+      .describe("Datatype to pack bits into before computation.");
+    TVM_ATTR_FIELD(out_dtype)
+      .set_default(NullValue<DataType>())
+      .describe("Output data type.");
+    TVM_ATTR_FIELD(unipolar)
+      .set_default(true)
+      .describe("Whether to use unipolar or bipolar quantization for inputs.");
+  }
+};
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_ATTRS_BITSERIAL_H_
index 03a04c9..d652977 100644 (file)
@@ -600,3 +600,120 @@ def schedule_deformable_conv2d(attrs, outs, target):
 
 
 reg.register_pattern("nn.deformable_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
+@reg.register_compute("nn.bitpack")
+def compute_bitpack(attrs, inputs, out_dtype, target):
+    """Compute definition for bitpack"""
+    bits = attrs.bits
+    pack_axis = attrs.pack_axis
+    bit_axis = attrs.bit_axis
+    pack_type = attrs.pack_type
+    name = attrs.name
+    with target:
+        out = topi.nn.bitpack(inputs[0], bits, pack_axis, bit_axis, pack_type,
+                              name)
+    return [out]
+
+@reg.register_schedule("nn.bitpack")
+def schedule_bitpack(attrs, outs, target):
+    with target:
+        return topi.generic.schedule_bitpack(outs)
+
+reg.register_pattern("nn.bitpack", OpPattern.INJECTIVE)
+
+
+@reg.register_compute("nn.bitserial_conv2d")
+def compute_bitserial_conv2d(attrs, inputs, out_dtype, target):
+    """Compute definition for bitserial conv2d."""
+    padding = get_const_tuple(attrs.padding)
+    strides = get_const_tuple(attrs.strides)
+    activation_bits = attrs.activation_bits
+    weight_bits = attrs.weight_bits
+    layout = attrs.data_layout
+    pack_dtype = attrs.pack_dtype
+    out_dtype = attrs.out_dtype
+    unipolar = attrs.unipolar
+    if layout == 'NCHW':
+        with target:
+            out = topi.nn.bitserial_conv2d_nchw(
+                inputs[0], inputs[1], strides, padding, activation_bits,
+                weight_bits, pack_dtype, out_dtype, unipolar)
+    elif layout == 'NHWC':
+        with target:
+            out = topi.nn.bitserial_conv2d_nhwc(
+                inputs[0], inputs[1], strides, padding, activation_bits,
+                weight_bits, pack_dtype, out_dtype, unipolar)
+    else:
+        raise ValueError("Data layout not supported.")
+
+    return [out]
+
+
+@reg.register_schedule("nn.bitserial_conv2d")
+def schedule_bitserial_conv2d(attrs, outs, target):
+    """Schedule definition for bitserial conv2d."""
+    layout = attrs.data_layout
+    if layout == 'NCHW':
+        with target:
+            return topi.generic.schedule_bitserial_conv2d_nchw(outs)
+    elif layout == 'NHWC':
+        with target:
+            return topi.generic.schedule_bitserial_conv2d_nhwc(outs)
+    else:
+        raise ValueError("Data layout not supported.")
+
+@reg.register_legalize("nn.bitserial_conv2d")
+def legalize_bitserial_conv2d(attrs, inputs, types):
+    """Legalize bitserial_conv2d op.
+
+    Parameters
+    ----------
+    attrs : tvm.attrs.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+    return topi.nn.bitserial_conv2d_legalize(attrs, inputs, types)
+
+
+reg.register_pattern("nn.bitserial_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
+# bitserial_dense
+@reg.register_compute("nn.bitserial_dense")
+def compute_bitserial_dense(attrs, inputs, out_type, target):
+    """Compute definition of bitserial_dense"""
+    data_bits = attrs.data_bits
+    weight_bits = attrs.weight_bits
+    pack_dtype = attrs.pack_dtype
+    out_dtype = attrs.out_dtype
+    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
+    unipolar = attrs.unipolar
+    return [
+        topi.nn.bitserial_dense(
+            inputs[0],
+            inputs[1],
+            data_bits,
+            weight_bits,
+            pack_dtype,
+            out_dtype,
+            unipolar)
+    ]
+
+
+@reg.register_schedule("nn.bitserial_dense")
+def schedule_bitserial_dense(attrs, outputs, target):
+    """Schedule definition of bitserial_dense"""
+    with target:
+        return topi.generic.schedule_bitserial_dense(outputs)
+
+
+reg.register_pattern("nn.bitserial_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
index 946ea33..19c50d6 100644 (file)
@@ -1459,3 +1459,165 @@ def deformable_conv2d(data,
     return _make.deformable_conv2d(data, offset, weight, strides, padding, dilation,
                                    deformable_groups, groups, channels, kernel_size, data_layout,
                                    kernel_layout, out_layout, out_dtype)
+
+
+def bitpack(data,
+            bits=1,
+            pack_axis=1,
+            bit_axis=2,
+            pack_type="uint32",
+            name="BitPack"):
+    r"""Tensor packing for bitserial operations.
+    The values along the input tensor's pack_axis are quantized
+    and packed together into the specified pack_type in a new
+    bit axis.
+
+    For example, consider bitpacking with data to be a tensor with shape [1, 64, 128, 128],
+    pack_axis=1, bit_axis=4, pack_type=uint8, and bits=2. The output in this case will
+    be of shape [1, 8, 128, 128, 2]. The dimension of axis 1 has been reduced by a factor
+    of 8 since each value is packed into an 8-bit uint8. Axis 4 is now two bitplanes
+    representing the quantized value of the incoming data. The output tensor is now
+    ready to be used in a bitserial operation.
+
+    Parameters
+    ----------
+    data : tvm.relay.expr
+        The incoming tensor to be packed.
+
+    bits : int
+        Number of bits that should be packed.
+
+    pack_axis : int
+        Axis that should be decomposed and packed.
+
+    bit_axis : int
+        New axis containing bitplane.
+
+    pack_type : str
+        Datatype to pack bits into.
+
+    name : str, optional
+        Name of the operation.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The packed tensor.
+    """
+    return _make.bitpack(data, bits, pack_axis, bit_axis, pack_type, name)
+
+
+def bitserial_conv2d(data,
+                     weight,
+                     strides=(1, 1),
+                     padding=(0, 0),
+                     channels=None,
+                     kernel_size=(3, 3),
+                     activation_bits=1,
+                     weight_bits=1,
+                     data_layout='NCHW',
+                     kernel_layout='OIHW',
+                     pack_dtype='uint32',
+                     out_dtype='int16',
+                     unipolar=True):
+    r"""2D convolution using bitserial computation.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator.
+
+    weight : tvm.relay.Expr
+        The weight expressions.
+
+    strides : tuple of int, optional
+        The strides of convolution.
+
+    padding : tuple of int, optional
+        The padding of convolution on both sides of inputs before convolution.
+
+    channels : int, optional
+        Number of output channels of this convolution.
+
+    kernel_size : tuple of int, optional
+        The spatial of the convolution kernel.
+
+    activation_bits : int
+        Number of bits to pack for activations.
+
+    weight_bits : int
+        Number of bits to pack for weights.
+
+    data_layout : str, optional
+        Layout of the input.
+
+    kernel_layout : str, optional
+        Layout of the kernel
+
+    pack_dtype: str, optional
+        Datatype to pack bits into.
+
+    out_dtype : str, optional
+        Specifies the output data type for mixed precision conv2d.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+
+    return _make.bitserial_conv2d(data, weight, strides, padding, channels,
+                                  kernel_size, activation_bits, weight_bits,
+                                  data_layout, kernel_layout, pack_dtype,
+                                  out_dtype, unipolar)
+
+
+def bitserial_dense(data,
+                    weight,
+                    units=None,
+                    data_bits=1,
+                    weight_bits=1,
+                    pack_dtype='uint32',
+                    out_dtype='int16',
+                    unipolar=True):
+    """Bitserial Dense operator.
+    Applies matrix multiplication of two quantized matrices
+    using a fast bitserial algorithm.
+
+    .. math::
+
+    `Y = X * W`
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator.
+
+    weight : tvm.relay.Expr
+        The weight expressions.
+
+    units : int, optional
+        Number of hidden units of the dense transformation.
+
+    data_bits : int
+        Number of bits incoming tensor should be packed with.
+
+    weight_bits : int
+        Number of bits weight tensor should be packed with.
+
+    pack_dtype : str, optional
+        Datatype to pack individual bits into before computation.
+
+    out_dtype : str, optional
+        Specifies the output data type for mixed precision dense.
+
+    unipolar : bool, optional
+        Whether to use unipolar or bipolar quantization for inputs.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+    return _make.bitserial_dense(data, weight, units, data_bits, weight_bits,
+                                 pack_dtype, out_dtype, unipolar)
index 48d3d20..11f8ad1 100644 (file)
@@ -264,3 +264,18 @@ class MaxPool2DAttrs(Attrs):
 @register_relay_attr_node
 class AvgPool2DAttrs(Attrs):
     """Attributes used in avg_pool2d operators"""
+
+
+@register_relay_attr_node
+class BitPackAttrs(Attrs):
+    """Attributes used in bitpack operator"""
+
+
+@register_relay_attr_node
+class BinaryConv2DAttrs(Attrs):
+    """Attributes used in bitserial conv2d operators"""
+
+
+@register_relay_attr_node
+class BinaryDenseAttrs(Attrs):
+    """Attributes used in bitserial dense operators"""
diff --git a/src/relay/op/nn/bitserial.cc b/src/relay/op/nn/bitserial.cc
new file mode 100644 (file)
index 0000000..6ee1ee6
--- /dev/null
@@ -0,0 +1,257 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file bitserial.cc
+ * \brief Property def of bitserial operators.
+ */
+
+#include <tvm/data_layout.h>
+#include <tvm/relay/attrs/bitserial.h>
+#include <tvm/relay/op.h>
+
+#include "../../pass/alter_op_layout.h"
+
+namespace tvm {
+namespace relay {
+
+// relay.nn.bitpack
+TVM_REGISTER_NODE_TYPE(BitPackAttrs);
+
+template <typename T>
+Array<Array<Layout>> BinaryConv2DInferCorrectLayout(const Attrs& attrs,
+                                                    const Array<Layout>& new_in_layouts,
+                                                    const Array<Layout>& old_in_layouts,
+                                                    const Array<Array<IndexExpr>>& old_in_shapes) {
+  const T* params = attrs.as<T>();
+
+  // We always make other operators to fit the layouts of convolution layers
+  // So this inference ignores all inputs
+  return Array<Array<Layout>>{{params->data_layout, params->kernel_layout}, {params->data_layout}};
+}
+
+bool BitPackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                const TypeReporter& reporter) {
+  const BitPackAttrs* param = attrs.as<BitPackAttrs>();
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  CHECK(data);
+  int ndim = data->shape.size();
+  int bits = param->bits;
+  int pack_axis = param->pack_axis;
+  int bit_axis = param->bit_axis;
+  DataType pack_type = param->pack_type;
+
+  int pack_bits = pack_type.bits();
+
+  Array<IndexExpr> out_shape;
+  for (int i = 0; i < ndim; ++i) {
+    if (i == bit_axis) {
+      out_shape.push_back(bits);
+      if (i == pack_axis) {
+        out_shape.push_back(data->shape[i] / pack_bits);
+      } else {
+        out_shape.push_back(data->shape[i]);
+      }
+    } else if (i == pack_axis) {
+      out_shape.push_back(data->shape[i] / pack_bits);
+    } else {
+      out_shape.push_back(data->shape[i]);
+    }
+  }
+  // Add extra check for last axis expansion.
+  if (bit_axis == ndim) {
+    out_shape.push_back(bits);
+  }
+
+  reporter->Assign(types[1], TensorTypeNode::make(out_shape, pack_type));
+  return true;
+}
+
+Expr MakeBitPack(Expr data, int bits, int pack_axis, int bit_axis, DataType pack_type,
+                 std::string name) {
+  auto attrs = make_node<BitPackAttrs>();
+  attrs->bits = bits;
+  attrs->pack_axis = pack_axis;
+  attrs->bit_axis = bit_axis;
+  attrs->pack_type = pack_type;
+  attrs->name = name;
+  static const Op& op = Op::Get("nn.bitpack");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op.nn._make.bitpack").set_body_typed(MakeBitPack);
+
+RELAY_REGISTER_OP("nn.bitpack")
+    .describe(R"code(Bitpack layer that prepares data for bitserial operations.
+
+This layer backs the bits of an input into a single datatype, allowing 
+efficient implementation of bitserial operations.
+
+- **data**: Input tensor of any shape, dimension that is to be
+            packed must be divisible by number of bits.
+- **out**:  Packed tensor with shape appropriately compressed. 
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(1)
+    .set_attrs_type_key("relay.attrs.BitPackAttrs")
+    .add_argument("data", "Tensor", "Input data.")
+    .set_support_level(2)
+    .add_type_rel("BitPack", BitPackRel);
+
+// relay.nn.bitserial_conv2d
+TVM_REGISTER_NODE_TYPE(BinaryConv2DAttrs);
+
+bool BinaryConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                     const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  const BinaryConv2DAttrs* param = attrs.as<BinaryConv2DAttrs>();
+  CHECK(param != nullptr);
+
+  static const Layout kNCHW("NCHW");
+
+  const Layout in_layout(param->data_layout);
+  const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW);
+  Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
+  CHECK(param->channels.defined());
+  CHECK(param->kernel_size.defined());
+  Array<IndexExpr> oshape({dshape_nchw[0], param->channels, 0, 0});
+  oshape.Set(
+      2, (dshape_nchw[2] + param->padding[0] * 2 - param->kernel_size[0]) / param->strides[0] + 1);
+  oshape.Set(
+      3, (dshape_nchw[3] + param->padding[1] * 2 - param->kernel_size[1]) / param->strides[1] + 1);
+  DataType out_dtype = param->out_dtype;
+  oshape = trans_in_layout.BackwardShape(oshape);
+  // assign output type
+  reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+  return true;
+}
+
+// Positional relay function to create binaryconv2d operator
+// used by frontend FFI.
+Expr MakeBinaryConv2D(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
+                      IndexExpr channels, Array<IndexExpr> kernel_size, int activation_bits,
+                      int weight_bits, std::string data_layout, std::string kernel_layout,
+                      DataType pack_dtype, DataType out_dtype, bool unipolar) {
+  auto attrs = make_node<BinaryConv2DAttrs>();
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->channels = std::move(channels);
+  attrs->kernel_size = std::move(kernel_size);
+  attrs->activation_bits = activation_bits;
+  attrs->weight_bits = weight_bits;
+  attrs->data_layout = std::move(data_layout);
+  attrs->kernel_layout = std::move(kernel_layout);
+  attrs->pack_dtype = std::move(pack_dtype);
+  attrs->out_dtype = std::move(out_dtype);
+  attrs->unipolar = unipolar;
+  static const Op& op = Op::Get("nn.bitserial_conv2d");
+  return CallNode::make(op, {data, weight}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op.nn._make.bitserial_conv2d").set_body_typed(MakeBinaryConv2D);
+
+RELAY_REGISTER_OP("nn.bitserial_conv2d")
+    .describe(R"code(2D convolution using packed binary computation.
+
+This layer creates a convolution kernel that is convolved with the
+layer input using bitserial computation. This enables faster processing
+on some platforms.
+
+- **data**:   4D input tensor that can be either `NCHW` or `NHWC` layout.
+
+- **weight**: Weight tensor that can either be prepacked (5D) or unpacked (4D).
+              When data is NCHW, weight is expected to be OIHW or OIHWi.
+              When data is NHWC weight is expected to be HWIO or HWIOi.
+
+- **out**:    Output with same layout as input.            
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type_key("relay.attrs.BinaryConv2DAttrs")
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("weight", "Tensor", "The weight tensor.")
+    .set_support_level(2)
+    .add_type_rel("BinaryConv2D", BinaryConv2DRel)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+                                   BinaryConv2DInferCorrectLayout<BinaryConv2DAttrs>);
+
+// relay.nn.bitserial_dense
+TVM_REGISTER_NODE_TYPE(BinaryDenseAttrs);
+
+bool BinaryDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                    const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  const BinaryDenseAttrs* param = attrs.as<BinaryDenseAttrs>();
+  CHECK(param != nullptr);
+
+  CHECK(static_cast<int>(data->shape.size()) != 0);
+  CHECK(param->units.defined());
+
+  Array<tvm::Expr> oshape = data->shape;
+  oshape.Set((oshape.size() - 1), param->units);
+
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+  }
+
+  // Assign output type.
+  reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+  return true;
+}
+
+// Positional relay function to create bitserial dense operator used by frontend FFI.
+Expr MakeBinaryDense(Expr data, Expr weight, IndexExpr units, int data_bits, int weight_bits,
+                     DataType pack_dtype, DataType out_dtype, bool unipolar) {
+  auto attrs = make_node<BinaryDenseAttrs>();
+  attrs->units = units;
+  attrs->data_bits = data_bits;
+  attrs->weight_bits = weight_bits;
+  attrs->pack_dtype = pack_dtype;
+  attrs->out_dtype = out_dtype;
+  attrs->unipolar = unipolar;
+  static const Op& op = Op::Get("nn.bitserial_dense");
+  return CallNode::make(op, {data, weight}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op.nn._make.bitserial_dense").set_body_typed(MakeBinaryDense);
+
+RELAY_REGISTER_OP("nn.bitserial_dense")
+    .describe(R"code(Applies a quantized linear transformation: :math:`Y = XW^T`.
+
+- **data**: `(x1, x2, ..., xn, input_dim)`
+- **weight**: `(units, input_dim)`
+- **out**: `(x1, x2, ..., xn, units)`.
+
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type_key("relay.attrs.BinaryDenseAttrs")
+    .set_num_inputs(2)
+    .add_argument("data", "2D Tensor", "Input data.")
+    .add_argument("weight", "2D Tensor", "Weight matrix.")
+    .set_support_level(1)
+    .add_type_rel("BinaryDense", BinaryDenseRel);
+
+}  // namespace relay
+}  // namespace tvm
index 66e65c5..c25393c 100644 (file)
@@ -337,6 +337,16 @@ def test_dense():
         tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
 
 
+def test_bitserial_dense():
+    m, k = tvm.var("m"), tvm.var("k")
+    x = relay.var("x", relay.TensorType((m, k), "int16"))
+    w = relay.var("w", relay.TensorType((k, 32), "int16"))
+    y = relay.nn.bitserial_dense(x, w, units=32)
+    "units=8" in y.astext()
+    yy = run_infer_type(y)
+    assert yy.checked_type == relay.TensorType((m, 32), "int16")
+
+
 if __name__ == "__main__":
     test_concatenate()
     test_bias_add()
@@ -349,3 +359,4 @@ if __name__ == "__main__":
     test_dropout()
     test_batch_norm()
     test_dense()
+    test_bitserial_dense()
index 5e9abdf..a94a203 100644 (file)
@@ -105,8 +105,8 @@ def test_conv2d_run():
                         except_targets=None,
                         **attrs):
         if except_targets is None:
-          except_targets = []
-          
+            except_targets = []
+
         x = relay.var("x", shape=dshape, dtype=dtype)
         w = relay.var("w", dtype=dtype)
         y = relay.nn.conv2d(x, w,
@@ -599,12 +599,35 @@ def test_conv2d_int8_intrinsics():
     assert "vpmulld" in asm and "vpadd" in asm
 
 
+def test_bitserial_conv2d_infer_type():
+    # Basic shape test with ambiguous batch.
+    n, c, h, w = tvm.var("n"), 32, 224, 224
+    x = relay.var("x", relay.ty.TensorType((n, c, h, w), "int16"))
+    w = relay.var("w", relay.ty.TensorType((32, 32, 3, 3), "int16"))
+    y = relay.nn.bitserial_conv2d(
+        x, w, kernel_size=(3, 3), padding=(0, 0), channels=32)
+    yy = run_infer_type(y)
+    assert yy.checked_type ==  relay.TensorType(
+        (n, 32, 222, 222), "int16")
+
+
+def test_bitpack_infer_type():
+    # Test axis packing shape inference.
+    o, i, h, w = 32, 32, 128, 128
+    x = relay.var("x", relay.ty.TensorType((o, i, h, w), "int16"))
+    y = relay.nn.bitpack(x, bit_axis=4, pack_axis=1, pack_type='uint16', bits=1)
+    yy = run_infer_type(y)
+    assert yy.checked_type ==  relay.TensorType(
+        (32, 2, 128, 128, 1), "uint16")
+
+
 if __name__ == "__main__":
     test_pool2d()
     test_avg_pool2d_no_count_pad()
     test_lrn()
     test_l2_normalize()
     test_conv2d_infer_type()
+    test_bitpack_infer_type()
     test_upsampling_infer_type()
     test_flatten_infer_type()
     test_pad_infer_type()
@@ -612,6 +635,7 @@ if __name__ == "__main__":
     test_conv2d_transpose_infer_type()
     test_conv2d_transpose_run()
     test_conv2d_run()
+    test_bitserial_conv2d_infer_type()
     test_batch_flatten()
     test_upsampling()
     test_conv2d_int8_intrinsics()
index 4198267..af9c5be 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name,unused-variable,invalid-name
+# pylint: disable=invalid-name,unused-variable,invalid-name,unused-argument
 """Bitserial conv2d schedule on arm cpu"""
 from __future__ import absolute_import as _abs
 import tvm
 from tvm import autotvm
+from tvm import relay
 from .. import tag
 from ..nn.pad import pad
-from ..nn.bitserial_conv2d import bitserial_conv2d_nhwc
+from ..nn.bitserial_conv2d import bitserial_conv2d_nhwc, bitserial_conv2d_legalize
 from ..nn.bitserial_util import bitpack, binary_op_multiplier
 from ..nn.util import get_pad_tuple
 from ..util import get_const_int, get_const_tuple
@@ -350,3 +351,40 @@ def schedule_bitserial_conv2d_nhwc(cfg, outs):
 
     traverse(outs[0].op)
     return s
+
+@bitserial_conv2d_legalize.register("arm_cpu")
+def _bitserial_conv2d_legalize(attrs, inputs, arg_types):
+    """Legalizes Bitserial Conv2D op.
+
+    Parameters
+    ----------
+    attrs : tvm.attrs.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+
+    # Fix different kernel layouts where possible.
+    if attrs['data_layout'] == 'NHWC':
+        data, kernel = inputs
+        if len(kernel.data.shape) == 4:
+            # HWIO layout is expected for NHWC input.
+            if attrs['kernel_layout'] == 'HWOI':
+                # Handle HWOI layout. This is common in TF depthwise conv2d graph.
+                kernel = relay.transpose(kernel, axes=(0, 1, 3, 2))
+            elif attrs['kernel_layout'] == 'OIHW':
+                kernel = relay.transpose(kernel, axes=(2, 3, 1, 0))
+            ## Set new attrs for the tranposed conv.
+            new_attrs = {k: attrs[k] for k in attrs.keys()}
+            new_attrs['kernel_layout'] = 'HWIO'
+
+            conv = relay.nn.bitserial_conv2d(data, kernel, **new_attrs)
+            return conv
+    return None
index 38b6632..8fbedec 100644 (file)
@@ -470,6 +470,23 @@ def schedule_binarize_pack(outs):
     return _default_schedule(outs, False)
 
 
+@tvm.target.override_native_generic_func("schedule_bitpack")
+def schedule_bitpack(outs):
+    """Schedule for bitpack
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of bitpack
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
 @tvm.target.override_native_generic_func("schedule_binary_dense")
 def schedule_binary_dense(outs):
     """Schedule for binary_dense
index 99cac88..21abdf0 100644 (file)
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=invalid-name, too-many-locals, too-many-arguments
+# pylint: disable=unused-argument, redefined-builtin
 """Bitserial Conv2D operators"""
 from __future__ import absolute_import as _abs
 import tvm
@@ -65,7 +66,10 @@ def bitserial_conv2d_nchw(data, kernel, stride, padding, activation_bits, weight
     """
     assert isinstance(stride, int) or len(stride) == 2
     Input_q = bitpack(data, activation_bits, pack_axis=1, bit_axis=2, pack_type=pack_dtype)
-    Filter_q = bitpack(filter, weight_bits, pack_axis=1, bit_axis=4, pack_type=pack_dtype)
+    if len(filter.shape) == 4:
+        Filter_q = bitpack(filter, weight_bits, pack_axis=1, bit_axis=4, pack_type=pack_dtype)
+    else:
+        Filter_q = filter
     batch, in_channel, activation_bits, in_height, in_width = Input_q.shape
     num_filter, _, kernel_h, kernel_w, weight_bits = Filter_q.shape
 
@@ -414,3 +418,24 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
     return tvm.compute(oshape, lambda n, h, w, co:
                        conv[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC],
                        name='output_unpack', tag='spatial_bitserial_conv_nhwc')
+
+@tvm.target.generic_func
+def bitserial_conv2d_legalize(attrs, inputs, types):
+    """Legalizes Bitserial Conv2D op.
+
+    Parameters
+    ----------
+    attrs : tvm.attrs.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+    # not to change by default
+    return None