From: bindog
Date: Thu, 3 Oct 2019 00:01:36 +0000 (+0800)
Subject: [Relay][Op] Add instance norm op (#4004)
X-Git-Tag: upstream/0.7.0~1834
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=7d911f46c3a3a02cd541435aa2495ceca57a88ba;p=platform%2Fupstream%2Ftvm.git

[Relay][Op] Add instance norm op (#4004)

* [Relay][Op] Add instance norm op

* mend

[Relay][Op] Add instance norm op
---

diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index 5d2c551..d6f9bee 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -492,6 +492,29 @@ struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
 };  // struct BatchNormAttrs
 
+/*! \brief Attributes used in instance_norm operator */
+struct InstanceNormAttrs : public tvm::AttrsNode<InstanceNormAttrs> {
+  int axis;
+  double epsilon;
+  bool center;
+  bool scale;
+
+  TVM_DECLARE_ATTRS(InstanceNormAttrs, "relay.attrs.InstanceNormAttrs") {
+    TVM_ATTR_FIELD(axis)
+      .describe("Specify which shape axis denotes the channel.")
+      .set_default(1);
+    TVM_ATTR_FIELD(epsilon)
+      .describe("Small float added to variance to avoid dividing by zero")
+      .set_default(1e-5);
+    TVM_ATTR_FIELD(center).set_default(true)
+      .describe("If true, add offset of beta to normalized tensor; "
+                "otherwise, beta is ignored.");
+    TVM_ATTR_FIELD(scale).set_default(true)
+      .describe("If true, multiply by gamma; otherwise, gamma is ignored.");
+  }
+};  // struct InstanceNormAttrs
+
+
 /*! \brief Attributes used in layer_norm operator */
 struct LayerNormAttrs : public tvm::AttrsNode<LayerNormAttrs> {
   int axis;
diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index ec922ab..0febfdd 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -324,6 +324,14 @@ def _mx_batch_norm(inputs, attrs):
     return _op.nn.batch_norm(*inputs, **new_attrs)
 
 
+def _mx_instance_norm(inputs, attrs):
+    assert len(inputs) == 3
+    new_attrs = {}
+    new_attrs["axis"] = attrs.get_int("axis", 1)
+    new_attrs["epsilon"] = attrs.get_float("eps", 1e-5)
+    return _op.nn.instance_norm(*inputs, **new_attrs)
+
+
 def _mx_layer_norm(inputs, attrs):
     assert len(inputs) == 3
     if attrs.get_bool("output_mean_var", False):
@@ -1133,6 +1141,7 @@ _convert_map = {
     "Dropout"       : _mx_dropout,
     "BatchNorm"     : _mx_batch_norm,
     "BatchNorm_v1"  : _mx_batch_norm,
+    "InstanceNorm"  : _mx_instance_norm,
     "LayerNorm"     : _mx_layer_norm,
     "LRN"           : _mx_lrn,
     "L2Normalization"  : _mx_l2_normalize,
""" @@ -999,7 +1008,7 @@ def _get_convert_map(opset): 'GlobalAveragePool': Renamer('global_avg_pool2d'), 'GlobalMaxPool': Renamer('global_max_pool2d'), 'BatchNormalization': BatchNorm.get_converter(opset), - # 'InstanceNormalization' + 'InstanceNormalization': InstanceNorm.get_converter(opset), # 'LpNormalization' 'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']), 'Flatten': Flatten.get_converter(opset), diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 3f1e3bc..43ae06f 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -935,6 +935,73 @@ def batch_norm(data, return TupleWrapper(result, 3) +def instance_norm(data, + gamma, + beta, + axis=1, + epsilon=1e-5, + center=True, + scale=True): + r""" + Instance Normalization (Ulyanov and et al., 2016) + Applies instance normalization to the n-dimensional input array. + + .. math:: + + out = \frac{data - mean(data)}{\sqrt{var(data)+\epsilon}} + * gamma + beta + + The instance normalization is similar to batch normalization, but unlike + batch normalization, the mean and var are calculated per-dimension + separately for each object(instance) in a mini-batch, not over a batch. + And the same normalization is applied both at test and train time. + + Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta`` + have shape *(k,)*. + + The parameter ``axis`` specifies which axis of the input shape denotes + the 'channel'. The default is 1. Specifying -1 sets the channel axis + to be the last item in the input shape. + + .. note:: + + This operator can be optimized away for inference. + + Parameters + ---------- + data : tvm.relay.Expr + Input to which instance_norm will be applied. + + gamma : tvm.relay.Expr + The gamma scale factor. + + beta : tvm.relay.Expr + The beta offset factor. + + axis : int, optional, default=1 + Specify along which shape axis the channel is specified. + + epsilon : double, optional, default=1e-5 + Small float added to variance to avoid dividing by zero. + + center : boolean, optional, default=True + If True, add offset of beta to normalized tensor, If False, + beta is ignored. + + scale : boolean, optional, default=True + If True, multiply by gamma. If False, gamma is not used. + + Returns + ------- + result : tvm.relay.Expr + The normalized data. + + .. _`Instance Normalization: The Missing Ingredient for Fast Stylization`: + https://arxiv.org/abs/1607.08022 + """ + return _make.instance_norm(data, gamma, beta, axis, epsilon, center, scale) + + def layer_norm(data, gamma, beta, @@ -964,7 +1031,7 @@ def layer_norm(data, Parameters ---------- data : tvm.relay.Expr - Input to which batch_norm will be applied. + Input to which layer_norm will be applied. gamma : tvm.relay.Expr The gamma scale factor. diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 42a0f01..9115c69 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -640,6 +640,76 @@ axis to be the last item in the input shape. .add_type_rel("BatchNorm", BatchNormRel); +// instance_norm +TVM_REGISTER_NODE_TYPE(InstanceNormAttrs); + +bool InstanceNormRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 4); + const auto* data = types[0].as(); + if (data == nullptr) return false; + const InstanceNormAttrs* param = attrs.as(); + int axis = param->axis >= 0 ? 
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 42a0f01..9115c69 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -640,6 +640,76 @@ axis to be the last item in the input shape.
 .add_type_rel("BatchNorm", BatchNormRel);
 
+
+// instance_norm
+TVM_REGISTER_NODE_TYPE(InstanceNormAttrs);
+
+bool InstanceNormRel(const Array<Type>& types,
+                     int num_inputs,
+                     const Attrs& attrs,
+                     const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 4);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+  const InstanceNormAttrs* param = attrs.as<InstanceNormAttrs>();
+  int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size();
+  CHECK(axis >= 0 && axis < (int)data->shape.size());
+  reporter->Assign(types[1], TensorTypeNode::make({data->shape[axis]}, data->dtype));
+  reporter->Assign(types[2], TensorTypeNode::make({data->shape[axis]}, data->dtype));
+  reporter->Assign(types[3], TensorTypeNode::make(data->shape, data->dtype));
+
+  return true;
+}
+
+Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon,
+                      bool center, bool scale) {
+  auto attrs = make_node<InstanceNormAttrs>();
+  attrs->axis = axis;
+  attrs->epsilon = epsilon;
+  attrs->center = center;
+  attrs->scale = scale;
+  static const Op& op = Op::Get("nn.instance_norm");
+  return CallNode::make(op, {data, gamma, beta}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op.nn._make.instance_norm")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 7>(MakeInstanceNorm, args, rv);
+  });
+
+RELAY_REGISTER_OP("nn.instance_norm")
+.describe(R"code(Instance Normalization (Ulyanov et al., 2016)
+Applies instance normalization to the n-dimensional input array.
+
+.. math::
+
+    out = \frac{data - mean(data)}{\sqrt{var(data)+\epsilon}}
+        * gamma + beta
+
+Instance normalization is similar to batch normalization, but the mean
+and variance are computed per channel, separately for each object
+(instance) in a mini-batch rather than over the whole batch, and the
+same normalization is applied at both training and test time.
+
+Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
+have shape *(k,)*.
+
+The parameter ``axis`` specifies which axis of the input shape denotes
+the 'channel'. The default is 1. Specifying -1 sets the channel axis
+to be the last item in the input shape.
+
+.. note::
+
+    This operator can be optimized away for inference.
+)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.InstanceNormAttrs")
+.set_num_inputs(3)
+.add_argument("data", "Tensor", "Input to which instance_norm will be applied.")
+.add_argument("gamma", "Tensor", "The gamma scale factor.")
+.add_argument("beta", "Tensor", "The beta offset factor.")
+.set_support_level(1)
+.add_type_rel("InstanceNorm", InstanceNormRel);
+
 
 // layer_norm
 TVM_REGISTER_NODE_TYPE(LayerNormAttrs);
diff --git a/src/relay/pass/simplify_inference.cc b/src/relay/pass/simplify_inference.cc
index 3790dbf..70586de 100644
--- a/src/relay/pass/simplify_inference.cc
+++ b/src/relay/pass/simplify_inference.cc
@@ -92,6 +92,41 @@ Expr LayerNormToInferUnpack(const Attrs attrs,
   return out;
 }
 
+
+Expr InstanceNormToInferUnpack(const Attrs attrs,
+                               Expr data,
+                               Expr gamma,
+                               Expr beta,
+                               Type tdata) {
+  auto ttype = tdata.as<TensorTypeNode>();
+  CHECK(ttype);
+  const auto param = attrs.as<InstanceNormAttrs>();
+  CHECK(param);
+
+  int ndim = ttype->shape.size();
+  int axis = (param->axis < 0) ? param->axis + ndim : param->axis;
+  Array<Integer> reduced_axes;
+  for (int i = 1; i < ndim; ++i) {
+    if (i != axis)
+      reduced_axes.push_back(i);
+  }
+
+  Expr epsilon = MakeConstantScalar(Float(32), static_cast<float>(param->epsilon));
+  Expr mean = Mean(data, reduced_axes, true, false);
+  Expr var = Variance(data, mean, reduced_axes, true, false);
+  Expr denom = Sqrt(Add(var, epsilon));
+  Expr out = Divide(Subtract(data, mean), denom);
+
+  if (param->scale) {
+    out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis}));
+  }
+  if (param->center) {
+    out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis}));
+  }
+  return out;
+}
+
+
 class InferenceSimplifier : public ExprMutator {
  public:
   Expr VisitExpr_(const TupleGetItemNode* n) final {
@@ -116,6 +151,7 @@ class InferenceSimplifier : public ExprMutator {
 
   Expr VisitExpr_(const CallNode* n) {
     static const Op& batch_norm = Op::Get("nn.batch_norm");
+    static const Op& instance_norm = Op::Get("nn.instance_norm");
     static const Op& layer_norm = Op::Get("nn.layer_norm");
     auto new_n = ExprMutator::VisitExpr_(n);
     if (n->op.same_as(batch_norm)) {
@@ -124,6 +160,10 @@ class InferenceSimplifier : public ExprMutator {
       const auto* call = new_n.as<CallNode>();
       return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1],
                                     call->args[2], n->args[0]->checked_type());
+    } else if (n->op.same_as(instance_norm)) {
+      const auto* call = new_n.as<CallNode>();
+      return InstanceNormToInferUnpack(call->attrs, call->args[0], call->args[1],
+                                       call->args[2], n->args[0]->checked_type());
     }
     return new_n;
   }
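
As a plain-NumPy cross-check of what ``InstanceNormToInferUnpack`` computes at inference time, the sketch below reduces over every axis except the batch axis 0 and the channel axis, then rescales with gamma and beta. The helper name and shapes are made up for illustration and are not part of this patch::

    import numpy as np

    def instance_norm_ref(x, gamma, beta, axis=1, epsilon=1e-5):
        # Mirror the reduced_axes loop above: reduce every axis except batch (0) and channel (axis).
        reduced_axes = tuple(i for i in range(1, x.ndim) if i != axis)
        mean = x.mean(axis=reduced_axes, keepdims=True)
        var = x.var(axis=reduced_axes, keepdims=True)
        out = (x - mean) / np.sqrt(var + epsilon)
        # Broadcast gamma/beta along the channel axis, as ExpandBiasToMatchAxis does.
        bshape = [1] * x.ndim
        bshape[axis] = -1
        return out * gamma.reshape(bshape) + beta.reshape(bshape)

    x = np.random.randn(2, 3, 4, 5).astype("float32")
    gamma = np.random.randn(3).astype("float32")
    beta = np.random.randn(3).astype("float32")
    print(instance_norm_ref(x, gamma, beta).shape)  # (2, 3, 4, 5)
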
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index 4530585..f45f152 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -758,6 +758,26 @@ def test_forward_batch_norm():
     verify((2, 3, 4, 5), fix_gamma=True)
 
 
+def test_forward_instance_norm():
+    def verify(shape, axis=1, epsilon=1e-5):
+        x = np.random.uniform(size=shape).astype("float32")
+        gamma = np.random.uniform(size=(shape[axis])).astype("float32")
+        beta = np.random.uniform(size=(shape[axis])).astype("float32")
+        ref_res = mx.nd.InstanceNorm(mx.nd.array(x), mx.nd.array(gamma), mx.nd.array(beta), epsilon)
+        mx_sym = mx.sym.InstanceNorm(mx.sym.var("x"), mx.sym.var("gamma"), mx.sym.var("beta"), epsilon)
+        shape_dict = {"x": x.shape, "gamma": gamma.shape, "beta": beta.shape}
+        mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(x, gamma, beta)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5)
+    verify((2, 3, 4, 5))
+    verify((32, 64, 80, 64))
+    verify((8, 6, 5))
+    verify((8, 7, 6, 5, 4))
+
+
 def test_forward_layer_norm():
     def verify(shape, axis=-1):
         x = np.random.uniform(size=shape).astype("float32")
@@ -938,6 +958,7 @@ if __name__ == '__main__':
     test_forward_sequence_mask()
     test_forward_contrib_div_sqrt_dim()
     test_forward_batch_norm()
+    test_forward_instance_norm()
     test_forward_layer_norm()
     test_forward_one_hot()
     test_forward_convolution()
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index dc9493f..16e7174 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -416,6 +416,50 @@ def test_lrn():
     verify_lrn((5, 5, 5, 5), 3, 'float32')
     verify_lrn((5, 5, 5, 5), 3, 'float32', alpha=0.0002, beta=0.5, bias=2.0)
 
+
+def verify_instance_norm(shape, axis=1):
+
+    def _get_python_instance_norm(x, gamma, beta, epsilon=1e-5):
+        dims_x = len(x.shape)
+        axis = tuple(range(2, dims_x))
+        mean = np.mean(x, axis=axis, keepdims=True)
+        var = np.var(x, axis=axis, keepdims=True)
+        dim_ones = (1,) * (dims_x - 2)
+        gamma = gamma.reshape(-1, *dim_ones)
+        beta = beta.reshape(-1, *dim_ones)
+        return gamma * (x - mean) / np.sqrt(var + epsilon) + beta
+
+    x = np.random.randn(*shape).astype(np.float32)
+    gamma = np.random.randn(shape[1]).astype(np.float32)
+    beta = np.random.randn(shape[1]).astype(np.float32)
+    epsilon = 1e-5
+    y = _get_python_instance_norm(x, gamma, beta, epsilon).astype(np.float32)
+
+    node = onnx.helper.make_node(
+        'InstanceNormalization',
+        inputs=['x', 'gamma', 'beta'],
+        outputs=['y'],
+        epsilon=epsilon,
+    )
+    graph = helper.make_graph([node],
+                              "instance_norm_test",
+                              inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(shape)),
+                                      helper.make_tensor_value_info("gamma", TensorProto.FLOAT, (shape[1],)),
+                                      helper.make_tensor_value_info("beta", TensorProto.FLOAT, (shape[1],))],
+                              outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(shape))])
+    model = helper.make_model(graph, producer_name='instance_norm_test')
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, [x, gamma, beta], target, ctx, shape, 'float32')
+        tvm.testing.assert_allclose(y, tvm_out, rtol=1e-5, atol=1e-5)
+
+
+def test_instance_norm():
+    verify_instance_norm((2, 3, 4, 5))
+    verify_instance_norm((32, 64, 80, 64))
+    verify_instance_norm((8, 6, 5))
+    verify_instance_norm((8, 7, 6, 5, 4))
+
+
 def _test_upsample_nearest():
     scale = 2
     in_shape = (1, 1, 3, 3)
@@ -1270,6 +1314,7 @@ if __name__ == '__main__':
     test_matmul()
     test_gather()
     test_lrn()
+    test_instance_norm()
     test_upsample()
    test_forward_min()
     test_forward_max()
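
Since the point of the ``simplify_inference`` change is that the fused op disappears before codegen, a quick way to eyeball that is to run the pass by hand. The sketch below assumes a TVM build from around this commit where ``relay.Module`` and ``relay.transform.SimplifyInference`` are the current names; it is an illustration, not part of the patch::

    import tvm
    from tvm import relay

    data = relay.var("data", shape=(1, 3, 8, 8), dtype="float32")
    gamma = relay.var("gamma", shape=(3,), dtype="float32")
    beta = relay.var("beta", shape=(3,), dtype="float32")
    func = relay.Function([data, gamma, beta],
                          relay.nn.instance_norm(data, gamma, beta))
    mod = relay.Module.from_expr(func)

    # After SimplifyInference, nn.instance_norm should be rewritten into
    # mean/variance arithmetic and no longer appear in the printed module.
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.SimplifyInference()(mod)
    print(mod)
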