[Relay][Op] Add instance norm op (#4004)
authorbindog <bindog@outlook.com>
Thu, 3 Oct 2019 00:01:36 +0000 (08:01 +0800)
committerWuwei Lin <wuwei@apache.org>
Thu, 3 Oct 2019 00:01:36 +0000 (20:01 -0400)
* [Relay][Op] Add instance norm op

* mend

[Relay][Op] Add instance norm op

include/tvm/relay/attrs/nn.h
python/tvm/relay/frontend/mxnet.py
python/tvm/relay/frontend/onnx.py
python/tvm/relay/op/nn/nn.py
src/relay/op/nn/nn.cc
src/relay/pass/simplify_inference.cc
tests/python/frontend/mxnet/test_forward.py
tests/python/frontend/onnx/test_forward.py

index 5d2c551..d6f9bee 100644 (file)
@@ -492,6 +492,29 @@ struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
 };  // struct BatchNormAttrs
 
 
+/*! \brief Attributes used in instance_norm operator */
+struct InstanceNormAttrs : public tvm::AttrsNode<InstanceNormAttrs> {
+  int axis;
+  double epsilon;
+  bool center;
+  bool scale;
+
+  TVM_DECLARE_ATTRS(InstanceNormAttrs, "relay.attrs.InstanceNormAttrs") {
+    TVM_ATTR_FIELD(axis)
+      .describe("Specify which shape axis denotes the channel.")
+      .set_default(1);
+    TVM_ATTR_FIELD(epsilon)
+      .describe("Small float added to variance to avoid dividing by zero")
+      .set_default(1e-5);
+    TVM_ATTR_FIELD(center).set_default(true)
+      .describe("If true, add offset of beta to normalized tensor; "
+                "otherwise, beta is ignored.");
+    TVM_ATTR_FIELD(scale).set_default(true)
+      .describe("If true, multiply by gamma; otherwise, gamma is ignored.");
+  }
+};  // struct InstanceNormAttrs
+
+
 /*! \brief Attributes used in layer_norm operator */
 struct LayerNormAttrs : public tvm::AttrsNode<LayerNormAttrs> {
   int axis;
index ec922ab..0febfdd 100644 (file)
@@ -324,6 +324,14 @@ def _mx_batch_norm(inputs, attrs):
     return _op.nn.batch_norm(*inputs, **new_attrs)
 
 
+def _mx_instance_norm(inputs, attrs):
+    assert len(inputs) == 3
+    new_attrs = {}
+    new_attrs["axis"] = attrs.get_int("axis", 1)
+    new_attrs["epsilon"] = attrs.get_float("eps", 1e-5)
+    return _op.nn.instance_norm(*inputs, **new_attrs)
+
+
 def _mx_layer_norm(inputs, attrs):
     assert len(inputs) == 3
     if attrs.get_bool("output_mean_var", False):
@@ -1133,6 +1141,7 @@ _convert_map = {
     "Dropout"       : _mx_dropout,
     "BatchNorm"     : _mx_batch_norm,
     "BatchNorm_v1"  : _mx_batch_norm,
+    "InstanceNorm"  : _mx_instance_norm,
     "LayerNorm"     : _mx_layer_norm,
     "LRN"           : _mx_lrn,
     "L2Normalization"  : _mx_l2_normalize,
index 387e9cf..a7f7874 100644 (file)
@@ -176,6 +176,15 @@ class BatchNorm(OnnxOpConverter):
         return out[0]
 
 
+class InstanceNorm(OnnxOpConverter):
+    """ Operator converter for BatchNorm.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        return AttrCvt(op_name='instance_norm')(inputs, attr, params)
+
+
 class Conv(OnnxOpConverter):
     """ Operator converter for Conv.
     """
@@ -999,7 +1008,7 @@ def _get_convert_map(opset):
         'GlobalAveragePool': Renamer('global_avg_pool2d'),
         'GlobalMaxPool': Renamer('global_max_pool2d'),
         'BatchNormalization': BatchNorm.get_converter(opset),
-        # 'InstanceNormalization'
+        'InstanceNormalization': InstanceNorm.get_converter(opset),
         # 'LpNormalization'
         'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
         'Flatten': Flatten.get_converter(opset),
index 3f1e3bc..43ae06f 100644 (file)
@@ -935,6 +935,73 @@ def batch_norm(data,
     return TupleWrapper(result, 3)
 
 
+def instance_norm(data,
+                  gamma,
+                  beta,
+                  axis=1,
+                  epsilon=1e-5,
+                  center=True,
+                  scale=True):
+    r"""
+    Instance Normalization (Ulyanov and et al., 2016)
+    Applies instance normalization to the n-dimensional input array.
+
+    .. math::
+
+        out = \frac{data - mean(data)}{\sqrt{var(data)+\epsilon}}
+            * gamma + beta
+
+    The instance normalization is similar to batch normalization, but unlike
+    batch normalization, the mean and var are calculated per-dimension
+    separately for each object(instance) in a mini-batch, not over a batch.
+    And the same normalization is applied both at test and train time.
+
+    Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
+    have shape *(k,)*.
+
+    The parameter ``axis`` specifies which axis of the input shape denotes
+    the 'channel'.  The default is 1. Specifying -1 sets the channel axis
+    to be the last item in the input shape.
+
+    .. note::
+
+        This operator can be optimized away for inference.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        Input to which instance_norm will be applied.
+
+    gamma : tvm.relay.Expr
+        The gamma scale factor.
+
+    beta : tvm.relay.Expr
+        The beta offset factor.
+
+    axis : int, optional, default=1
+        Specify along which shape axis the channel is specified.
+
+    epsilon : double, optional, default=1e-5
+        Small float added to variance to avoid dividing by zero.
+
+    center : boolean, optional, default=True
+        If True, add offset of beta to normalized tensor, If False,
+        beta is ignored.
+
+    scale : boolean, optional, default=True
+        If True, multiply by gamma. If False, gamma is not used.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The normalized data.
+
+    .. _`Instance Normalization: The Missing Ingredient for Fast Stylization`:
+        https://arxiv.org/abs/1607.08022
+    """
+    return _make.instance_norm(data, gamma, beta, axis, epsilon, center, scale)
+
+
 def layer_norm(data,
                gamma,
                beta,
@@ -964,7 +1031,7 @@ def layer_norm(data,
     Parameters
     ----------
     data : tvm.relay.Expr
-        Input to which batch_norm will be applied.
+        Input to which layer_norm will be applied.
 
     gamma : tvm.relay.Expr
         The gamma scale factor.
index 42a0f01..9115c69 100644 (file)
@@ -640,6 +640,76 @@ axis to be the last item in the input shape.
 .add_type_rel("BatchNorm", BatchNormRel);
 
 
+// instance_norm
+TVM_REGISTER_NODE_TYPE(InstanceNormAttrs);
+
+bool InstanceNormRel(const Array<Type>& types,
+                     int num_inputs,
+                     const Attrs& attrs,
+                     const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 4);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+  const InstanceNormAttrs* param = attrs.as<InstanceNormAttrs>();
+  int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size();
+  CHECK(axis >= 0 && axis < (int)data->shape.size());
+  reporter->Assign(types[1], TensorTypeNode::make({data->shape[axis]}, data->dtype));
+  reporter->Assign(types[2], TensorTypeNode::make({data->shape[axis]}, data->dtype));
+  reporter->Assign(types[3], TensorTypeNode::make(data->shape, data->dtype));
+
+  return true;
+}
+
+Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon,
+                      bool center, bool scale) {
+  auto attrs = make_node<InstanceNormAttrs>();
+  attrs->axis = axis;
+  attrs->epsilon = epsilon;
+  attrs->center = center;
+  attrs->scale = scale;
+  static const Op& op = Op::Get("nn.instance_norm");
+  return CallNode::make(op, {data, gamma, beta}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op.nn._make.instance_norm")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 7>(MakeInstanceNorm, args, rv);
+  });
+
+RELAY_REGISTER_OP("nn.instance_norm")
+.describe(R"code(Instance Normalization (Ulyanov and et al., 2016)
+Applies instance normalization to the n-dimensional input array.
+
+.. math::
+
+    out = \frac{data - mean(data)}{\sqrt{var(data)+\epsilon}}
+        * gamma + beta
+
+The instance normalization is similar to batch normalization, but unlike
+batch normalization, the mean and var are calculated per-dimension
+separately for each object(instance) in a mini-batch, not over a batch.
+And the same normalization is applied both at test and train time.
+
+Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
+have shape *(k,)*.
+
+The parameter ``axis`` specifies which axis of the input shape denotes
+the 'channel'.  The default is 1. Specifying -1 sets the channel axis
+to be the last item in the input shape.
+
+.. note::
+
+    This operator can be optimized away for inference.
+)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.InstanceNormAttrs")
+.set_num_inputs(3)
+.add_argument("data", "Tensor", "Input to which instance_norm will be applied.")
+.add_argument("gamma", "Tensor", "The gamma scale factor.")
+.add_argument("beta", "Tensor", "The beta offset factor.")
+.set_support_level(1)
+.add_type_rel("InstanceNorm", InstanceNormRel);
+
+
 // layer_norm
 TVM_REGISTER_NODE_TYPE(LayerNormAttrs);
 
index 3790dbf..70586de 100644 (file)
@@ -92,6 +92,41 @@ Expr LayerNormToInferUnpack(const Attrs attrs,
   return out;
 }
 
+
+Expr InstanceNormToInferUnpack(const Attrs attrs,
+                               Expr data,
+                               Expr gamma,
+                               Expr beta,
+                               Type tdata) {
+  auto ttype = tdata.as<TensorTypeNode>();
+  CHECK(ttype);
+  const auto param = attrs.as<InstanceNormAttrs>();
+  CHECK(param);
+
+  int ndim = ttype->shape.size();
+  int axis = (param->axis < 0) ? param->axis + ndim : param->axis;
+  Array<Integer> reduced_axes;
+  for (int i = 1; i < ndim; ++i) {
+      if (i != axis)
+          reduced_axes.push_back(i);
+  }
+
+  Expr epsilon = MakeConstantScalar(Float(32), static_cast<float>(param->epsilon));
+  Expr mean = Mean(data, reduced_axes, true, false);
+  Expr var = Variance(data, mean, reduced_axes, true, false);
+  Expr denom = Sqrt(Add(var, epsilon));
+  Expr out = Divide(Subtract(data, mean), denom);
+
+  if (param->scale) {
+    out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis}));
+  }
+  if (param->center) {
+    out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis}));
+  }
+  return out;
+}
+
+
 class InferenceSimplifier : public ExprMutator {
  public:
   Expr VisitExpr_(const TupleGetItemNode* n) final {
@@ -116,6 +151,7 @@ class InferenceSimplifier : public ExprMutator {
 
   Expr VisitExpr_(const CallNode* n) {
     static const Op& batch_norm = Op::Get("nn.batch_norm");
+    static const Op& instance_norm = Op::Get("nn.instance_norm");
     static const Op& layer_norm = Op::Get("nn.layer_norm");
     auto new_n = ExprMutator::VisitExpr_(n);
     if (n->op.same_as(batch_norm)) {
@@ -124,6 +160,10 @@ class InferenceSimplifier : public ExprMutator {
       const auto* call = new_n.as<CallNode>();
       return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1],
                                     call->args[2], n->args[0]->checked_type());
+    } else if (n->op.same_as(instance_norm)) {
+      const auto* call = new_n.as<CallNode>();
+      return InstanceNormToInferUnpack(call->attrs, call->args[0], call->args[1],
+                                    call->args[2], n->args[0]->checked_type());
     }
     return new_n;
   }
index 4530585..f45f152 100644 (file)
@@ -758,6 +758,26 @@ def test_forward_batch_norm():
     verify((2, 3, 4, 5), fix_gamma=True)
 
 
+def test_forward_instance_norm():
+    def verify(shape, axis=1, epsilon=1e-5):
+        x = np.random.uniform(size=shape).astype("float32")
+        gamma = np.random.uniform(size=(shape[axis])).astype("float32")
+        beta = np.random.uniform(size=(shape[axis])).astype("float32")
+        ref_res = mx.nd.InstanceNorm(mx.nd.array(x), mx.nd.array(gamma), mx.nd.array(beta), epsilon)
+        mx_sym = mx.sym.InstanceNorm(mx.sym.var("x"), mx.sym.var("gamma"), mx.sym.var("beta"), epsilon)
+        shape_dict = {"x": x.shape, "gamma": gamma.shape, "beta": beta.shape}
+        mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(x, gamma, beta)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5)
+    verify((2, 3, 4, 5))
+    verify((32, 64, 80, 64))
+    verify((8, 6, 5))
+    verify((8, 7, 6, 5, 4))
+
+
 def test_forward_layer_norm():
     def verify(shape, axis=-1):
         x = np.random.uniform(size=shape).astype("float32")
@@ -938,6 +958,7 @@ if __name__ == '__main__':
     test_forward_sequence_mask()
     test_forward_contrib_div_sqrt_dim()
     test_forward_batch_norm()
+    test_forward_instance_norm()
     test_forward_layer_norm()
     test_forward_one_hot()
     test_forward_convolution()
index dc9493f..16e7174 100644 (file)
@@ -416,6 +416,50 @@ def test_lrn():
     verify_lrn((5, 5, 5, 5), 3, 'float32')
     verify_lrn((5, 5, 5, 5), 3, 'float32', alpha=0.0002, beta=0.5, bias=2.0)
 
+
+def verify_instance_norm(shape, axis=1):
+
+    def _get_python_instance_norm(x, gamma, beta, epsilon=1e-5):
+        dims_x = len(x.shape)
+        axis = tuple(range(2, dims_x))
+        mean = np.mean(x, axis=axis, keepdims=True)
+        var = np.var(x, axis=axis, keepdims=True)
+        dim_ones = (1,) * (dims_x - 2)
+        gamma = gamma.reshape(-1, *dim_ones)
+        beta = beta.reshape(-1, *dim_ones)
+        return gamma * (x - mean) / np.sqrt(var + epsilon) + beta
+
+    x = np.random.randn(*shape).astype(np.float32)
+    gamma = np.random.randn(shape[1]).astype(np.float32)
+    beta = np.random.randn(shape[1]).astype(np.float32)
+    epsilon = 1e-5
+    y = _get_python_instance_norm(x, gamma, beta, epsilon).astype(np.float32)
+
+    node = onnx.helper.make_node(
+            'InstanceNormalization',
+            inputs=['x', 'gamma', 'beta'],
+            outputs=['y'],
+            epsilon=epsilon,
+        )
+    graph = helper.make_graph([node],
+                              "instance_norm_test",
+                              inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(shape)),
+                                      helper.make_tensor_value_info("gamma", TensorProto.FLOAT, (shape[1],)),
+                                      helper.make_tensor_value_info("beta", TensorProto.FLOAT, (shape[1],))],
+                              outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(shape))])
+    model = helper.make_model(graph, producer_name='instance_norm_test')
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, [x, gamma, beta], target, ctx, shape, 'float32')
+        tvm.testing.assert_allclose(y, tvm_out, rtol=1e-5, atol=1e-5)
+
+
+def test_instance_norm():
+    verify_instance_norm((2, 3, 4, 5))
+    verify_instance_norm((32, 64, 80, 64))
+    verify_instance_norm((8, 6, 5))
+    verify_instance_norm((8, 7, 6, 5, 4))
+
+
 def _test_upsample_nearest():
     scale = 2
     in_shape = (1, 1, 3, 3)
@@ -1270,6 +1314,7 @@ if __name__ == '__main__':
     test_matmul()
     test_gather()
     test_lrn()
+    test_instance_norm()
     test_upsample()
     test_forward_min()
     test_forward_max()