From 082f27ebf8b14f537f0d7686e8161db1684f3110 Mon Sep 17 00:00:00 2001 From: Rishabh Jain <56974688+jainris@users.noreply.github.com> Date: Thu, 27 Aug 2020 08:51:45 +0530 Subject: [PATCH] [Relay/TOPI][TFLite] Implemented MATRIX_SET_DIAG Operator for Relay/TOPI and TFLite Frontend. (#6303) * Corrected docstring error. * Minor changes. * Changed MATRIX_SET_DIAG registration from broadcast to injective. --- include/tvm/topi/transform.h | 29 ++++++++++ python/tvm/relay/frontend/tflite.py | 28 ++++++++++ python/tvm/relay/op/_transform.py | 1 + python/tvm/relay/op/transform.py | 41 ++++++++++++++ python/tvm/topi/testing/__init__.py | 1 + python/tvm/topi/testing/matrix_set_diag.py | 47 ++++++++++++++++ python/tvm/topi/transform.py | 40 ++++++++++++++ src/relay/op/tensor/transform.cc | 50 +++++++++++++++++ src/topi/transform.cc | 4 ++ tests/python/frontend/tflite/test_forward.py | 72 +++++++++++++++++++++++++ tests/python/relay/test_op_level10.py | 28 ++++++++++ tests/python/topi/python/test_topi_transform.py | 36 +++++++++++++ 12 files changed, 377 insertions(+) create mode 100644 python/tvm/topi/testing/matrix_set_diag.py diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 19b2ef4..eb69fc5 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1511,6 +1511,35 @@ inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Array name, tag); } +/*! + * \brief Returns a tensor with the diagonal of input tensor replaced with the provided diagonal. + * \param input input tensor. + * \param diagonal values to be filled in the diagonal. + * \param name output tensor name. + * \param tag output tensor tag. + * \return new tensor with given diagonal values. + */ +inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, + const std::string name = "T_matrix_set_diag", + const std::string tag = kInjective) { + size_t ndim = input->shape.size() - 1; + + return compute( + input->shape, + [&](const Array& iter_vars) { + auto get_diag = [&]() { + Array diagonal_indices; + for (size_t i = 0; i < ndim; i++) { + diagonal_indices.push_back(iter_vars[i]); + } + return diagonal(diagonal_indices); + }; + return if_then_else((PrimExpr)iter_vars[ndim] == iter_vars[ndim - 1], get_diag(), + input(iter_vars)); + }, + name, tag); +} + } // namespace topi } // namespace tvm #endif // TVM_TOPI_TRANSFORM_H_ diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 200352c..31ff871 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -107,6 +107,7 @@ class OperatorConverter(object): 'LOGICAL_NOT': self.convert_logical_not, 'LOGICAL_OR': self.convert_logical_or, 'LOGISTIC': self.convert_logistic, + 'MATRIX_SET_DIAG': self.convert_matrix_set_diag, 'MAX_POOL_2D': self.convert_max_pool2d, 'MAXIMUM': self.convert_maximum, 'MEAN': self.convert_reduce_mean, @@ -2989,6 +2990,33 @@ class OperatorConverter(object): out = _op.reverse(input_expr, axis) return out + def convert_matrix_set_diag(self, op): + """Convert TFLite MATRIX_SET_DIAG""" + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensor's length should be 2" + + assert input_tensors[0].tensor.Type() == input_tensors[1].tensor.Type(), \ + "input and diagonal should be the same type of tensors" + + if input_tensors[0].qnn_params: + # Check that input and output tensor have same qnn params. + output_tensors = self.get_output_tensors(op) + assert self.has_same_qnn_params(input_tensors[0], output_tensors[0]), \ + "TFLite MATRIX_SET_DIAG requires input and output tensors' \ + scale and zero points to be equal" + + # Check that input and diagonal tensor have same qnn params. + assert self.has_same_qnn_params(input_tensors[0], input_tensors[1]), \ + "TFLite MATRIX_SET_DIAG requires input and diagonal tensors' \ + scale and zero points to be equal" + + input_expr = self.get_tensor_expr(input_tensors[0]) + diagonal_expr = self.get_tensor_expr(input_tensors[1]) + + out = _op.matrix_set_diag(input_expr, diagonal_expr) + return out + def get_expr(self, input_tensor_idx): return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx)) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index a69eb8c..b562233 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -61,6 +61,7 @@ _reg.register_reduce_schedule("collapse_sum_like") _reg.register_reduce_schedule("collapse_sum_to") _reg.register_injective_schedule("unravel_index") _reg.register_injective_schedule("sparse_to_dense") +_reg.register_injective_schedule("matrix_set_diag") # concatenate _reg.register_schedule("concatenate", strategy.schedule_concatenate) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index b46b156..6d3c8be 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1167,3 +1167,44 @@ def sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0 if default_value == 0: default_value = const(0) return _make.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value) + + +def matrix_set_diag(data, diagonal): + """ + Returns a tensor with the diagonal of input tensor replaced with the provided diagonal values. + + Parameters + ---------- + data : relay.Expr + Input Tensor. + diagonal : relay.Expr + Values to be filled in the diagonal. + + Returns + ------- + result : relay.Expr + New tensor with given diagonal values. + + Examples + -------- + .. code-block:: python + + data = [[[7, 7, 7, 7], + [7, 7, 7, 7], + [7, 7, 7, 7]], + [[7, 7, 7, 7], + [7, 7, 7, 7], + [7, 7, 7, 7]]] + + diagonal = [[1, 2, 3], + [4, 5, 6]] + + relay.matrix_set_diag(input, diagonal) = + [[[1, 7, 7, 7], + [7, 2, 7, 7], + [7, 7, 3, 7]], + [[4, 7, 7, 7], + [7, 5, 7, 7], + [7, 7, 6, 7]]] + """ + return _make.matrix_set_diag(data, diagonal) diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index 70ee8e9..ce0554f 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -60,3 +60,4 @@ from .common import get_injective_schedule, get_reduce_schedule, get_broadcast_s get_elemwise_schedule, get_conv2d_nchw_implement, dispatch from .adaptive_pool_python import adaptive_pool from .grid_sample_python import affine_grid_python, grid_sample_nchw_python +from .matrix_set_diag import matrix_set_diag diff --git a/python/tvm/topi/testing/matrix_set_diag.py b/python/tvm/topi/testing/matrix_set_diag.py new file mode 100644 index 0000000..e0a8914 --- /dev/null +++ b/python/tvm/topi/testing/matrix_set_diag.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""MatrixSetDiag in Python""" +import numpy as np + +def matrix_set_diag(input_np, diagonal): + """matrix_set_diag operator implemented in numpy. + + Returns a numpy array with the diagonal of input array + replaced with the provided diagonal values. + + Parameters + ---------- + input : numpy.ndarray + Input Array. + Shape = [D1, D2, D3, ... , Dn-1 , Dn] + diagonal : numpy.ndarray + Values to be filled in the diagonal. + Shape = [D1, D2, D3, ... , Dn-1] + + Returns + ------- + result : numpy.ndarray + New Array with given diagonal values. + Shape = [D1, D2, D3, ... , Dn-1 , Dn] + """ + out = np.array(input_np, copy=True) + n = min(input_np.shape[-1], input_np.shape[-2]) + for i in range(n): + out[..., i, i] = diagonal[..., i] + + return out diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 6bfa473..f3e5a6a 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -798,3 +798,43 @@ def sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0 """ return cpp.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value) + +def matrix_set_diag(data, diagonal): + """ + Returns a tensor with the diagonal of input tensor replaced with the provided diagonal values. + + Parameters + ---------- + data : relay.Expr + Input Tensor. + diagonal : relay.Expr + Values to be filled in the diagonal. + + Returns + ------- + result : relay.Expr + New tensor with given diagonal values. + + Examples + -------- + .. code-block:: python + + data = [[[7, 7, 7, 7], + [7, 7, 7, 7], + [7, 7, 7, 7]], + [[7, 7, 7, 7], + [7, 7, 7, 7], + [7, 7, 7, 7]]] + + diagonal = [[1, 2, 3], + [4, 5, 6]] + + relay.matrix_set_diag(input, diagonal) = + [[[1, 7, 7, 7], + [7, 2, 7, 7], + [7, 7, 3, 7]], + [[4, 7, 7, 7], + [7, 5, 7, 7], + [7, 7, 6, 7]]] + """ + return cpp.matrix_set_diag(data, diagonal) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index f1d5b7a..1e223b7 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3093,5 +3093,55 @@ RELAY_REGISTER_OP("sparse_to_dense") .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) .set_attr("FTVMCompute", SparseToDenseCompute); +// relay.matrix_set_diag +bool MatrixSetDiagRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // `types` contains: [input, diagonal, result] + CHECK_EQ(types.size(), 3); + + const auto* input = types[0].as(); + CHECK(input); + + const auto* diagonal = types[1].as(); + CHECK(diagonal); + + int d_ndims = diagonal->shape.size(); + for (int i = 0; i < d_ndims - 1; i++) { + reporter->AssertEQ(input->shape[i], diagonal->shape[i]); + } + auto min_dim = if_then_else(input->shape[d_ndims - 1] >= input->shape[d_ndims], + input->shape[d_ndims], input->shape[d_ndims - 1]); + reporter->Assert(diagonal->shape[d_ndims - 1] >= min_dim); + + reporter->Assign(types[2], TensorType(input->shape, input->dtype)); + return true; +} + +Array MatrixSetDiagCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type) { + return Array{topi::matrix_set_diag(inputs[0], inputs[1])}; +} + +Expr MakeMatrixSetDiag(Expr input, Expr diagonal) { + static const Op& op = Op::Get("matrix_set_diag"); + return Call(op, {input, diagonal}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.matrix_set_diag").set_body_typed(MakeMatrixSetDiag); + +RELAY_REGISTER_OP("matrix_set_diag") + .describe( + R"code(Returns a tensor with the diagonal of input tensor replaced with the provided diagonal values. + **input** Input tensor. + **diagonal** Values to be filled in the diagonal. + )code" TVM_ADD_FILELINE) + .set_num_inputs(2) + .add_argument("input", "Tensor", "Input Tensor.") + .add_argument("diagonal", "Tensor", "Values to be filled in the diagonal.") + .set_support_level(10) + .add_type_rel("MatrixSetDiag", MatrixSetDiagRel) + .set_attr("FTVMCompute", MatrixSetDiagCompute) + .set_attr("TOpPattern", kInjective); + } // namespace relay } // namespace tvm diff --git a/src/topi/transform.cc b/src/topi/transform.cc index 7a76c60..154933f 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -176,5 +176,9 @@ TVM_REGISTER_GLOBAL("topi.one_hot").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = one_hot(args[0], args[1], args[2], depth, axis, dtype); }); +TVM_REGISTER_GLOBAL("topi.matrix_set_diag").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = matrix_set_diag(args[0], args[1]); +}); + } // namespace topi } // namespace tvm diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 4d7fc06..70a629d 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -2653,6 +2653,77 @@ def test_forward_reverse_v2(): ####################################################################### +# MATRIX_SET_DIAG +# --------------- + +def _test_matrix_set_diag(input_shape, input_type, quantized=False): + """ One iteration of MATRIX_SET_DIAG """ + with tf.Graph().as_default(): + diagonal_shape = list(input_shape[:-2]) + diagonal_shape.append(min(input_shape[-2], input_shape[-1])) + + if quantized: + # ignoring input_type as quantized requires uint8 + input = np.random.uniform(0, 256, input_shape).astype('uint8') + in_input = tf.placeholder(dtype='float32', shape=input.shape, name="input") + inq_input = tf.quantization.fake_quant_with_min_max_args( + in_input, + min=-100, + max=100, + name="q_input") + + diagonal = np.random.uniform(0, 256, diagonal_shape).astype('uint8') + in_diagonal = tf.placeholder(dtype='float32', shape=diagonal.shape, name="diagonal") + inq_diagonal = tf.quantization.fake_quant_with_min_max_args( + in_diagonal, + min=-100, + max=100, + name="q_diagonal") + + input_range = {'q_input': (-100, 100), 'q_diagonal': (-100, 100)} + + out = array_ops.matrix_set_diag(inq_input, inq_diagonal) + out = tf.quantization.fake_quant_with_min_max_args( + out, + min=-100, + max=100, + name="out") + + compare_tflite_with_tvm( + [input, diagonal], + ["q_input", "q_diagonal"], + [inq_input, inq_diagonal], + [out], + quantized=True, + input_range=input_range) + else: + input = np.random.uniform(0, 100, input_shape).astype(input_type) + diagonal = np.random.uniform(0, 100, diagonal_shape).astype(input_type) + + in_input = tf.placeholder(dtype=input.dtype, shape=input.shape, name="input") + in_diagonal = tf.placeholder(dtype=diagonal.dtype, shape=diagonal.shape, name="diagonal") + + out = array_ops.matrix_set_diag(in_input, in_diagonal) + + compare_tflite_with_tvm( + [input, diagonal], + ["input", "diagonal"], + [in_input, in_diagonal], + [out]) + +def test_forward_matrix_set_diag(): + """ MATRIX_SET_DIAG """ + for dtype in [np.float32, np.int32]: + _test_matrix_set_diag((4, 4), dtype) + _test_matrix_set_diag((5, 4, 3, 4), dtype) + _test_matrix_set_diag((4, 4, 2), dtype) + + _test_matrix_set_diag((4, 4), np.uint8, quantized=True) + _test_matrix_set_diag((5, 4, 3, 4), np.uint8, quantized=True) + _test_matrix_set_diag((4, 4, 2), np.uint8, quantized=True) + + +####################################################################### # Custom Operators # ---------------- @@ -3131,6 +3202,7 @@ if __name__ == '__main__': test_forward_arg_min_max() test_forward_expand_dims() test_forward_reverse_v2() + test_forward_matrix_set_diag() # NN test_forward_convolution() diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index c0a990b..a65b17f 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -471,6 +471,33 @@ def test_one_hot(): _verify((3, 2, 4, 5), 6, 1, 0, 1, "int32") _verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32") +def test_matrix_set_diag(): + def _verify(input_shape, dtype): + diagonal_shape = list(input_shape[:-2]) + diagonal_shape.append(min(input_shape[-2], input_shape[-1])) + input = relay.var("input", relay.TensorType(input_shape, dtype)) + diagonal = relay.var("diagonal", relay.TensorType(diagonal_shape, dtype)) + out = relay.matrix_set_diag(input, diagonal) + + in_type = run_infer_type(input) + out_type = run_infer_type(out) + assert in_type.checked_type == out_type.checked_type + + func = relay.Function([input, diagonal], out) + input_np = np.random.randint(-100, 100, size=input_shape).astype(dtype) + diagonal_np = np.random.randint(-100, 100, size=diagonal_shape).astype(dtype) + out_np = tvm.topi.testing.matrix_set_diag(input_np, diagonal_np) + + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + out_relay = intrp.evaluate(func)(input_np, diagonal_np) + tvm.testing.assert_allclose(out_relay.asnumpy(), out_np) + + _verify((2, 2), 'float32') + _verify((4, 3, 3), 'int32') + _verify((2, 3, 4), 'float32') + if __name__ == "__main__": test_adaptive_pool() test_collapse_sum_like() @@ -483,3 +510,4 @@ if __name__ == "__main__": test_sequence_mask() test_one_hot() test_ndarray_size() + test_matrix_set_diag() diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py index 13d24d5..d8c51b8 100644 --- a/tests/python/topi/python/test_topi_transform.py +++ b/tests/python/topi/python/test_topi_transform.py @@ -746,6 +746,35 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ for device in get_all_backend(): check_device(device) +def verify_matrix_set_diag(input_shape, dtype): + diagonal_shape = list(input_shape[:-2]) + diagonal_shape.append(min(input_shape[-2], input_shape[-1])) + input = te.placeholder(shape=input_shape, name="input", dtype=dtype) + diagonal = te.placeholder(shape=diagonal_shape, name="diagonal", dtype=dtype) + matrix_set_diag_result = topi.transform.matrix_set_diag(input, diagonal) + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + s = tvm.topi.testing.get_injective_schedule(device)(matrix_set_diag_result) + fn = tvm.build(s, [input, diagonal, matrix_set_diag_result], device, name="matrix_set_diag") + input_npy = np.random.randint(-100, 100, size=input_shape).astype(dtype) + diagonal_npy = np.random.randint(-100, 100, size=diagonal_shape).astype(dtype) + out_npy = tvm.topi.testing.matrix_set_diag(input_npy, diagonal_npy) + input_nd = tvm.nd.array(input_npy, ctx) + diagonal_nd = tvm.nd.array(diagonal_npy, ctx) + out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(matrix_set_diag_result.dtype), ctx) + fn(input_nd, diagonal_nd, out_nd) + out_topi = out_nd.asnumpy() + tvm.testing.assert_allclose(out_topi, out_npy) + + for device in get_all_backend(): + check_device(device) + + def test_strided_slice(): verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2]) verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1]) @@ -1105,6 +1134,12 @@ def test_sparse_to_dense(): #sparse_indices should not be > 2d tensor #verify_sparse_to_dense([[[[0, 1, 4], [0, 2, 4]]]], [[[3.1, 3.1, 3.1]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1]) +def test_matrix_set_diag(): + for dtype in ['float32', 'int32']: + verify_matrix_set_diag((2, 2), dtype) + verify_matrix_set_diag((4, 3, 3), dtype) + verify_matrix_set_diag((2, 3, 4), dtype) + if __name__ == "__main__": test_strided_slice() test_concatenate() @@ -1130,3 +1165,4 @@ if __name__ == "__main__": test_one_hot() test_unravel_index() test_sparse_to_dense() + test_matrix_set_diag() -- 2.7.4