From 14f004229f876f6c82687c955e53e8c7a83127f9 Mon Sep 17 00:00:00 2001 From: Rishabh Jain <56974688+jainris@users.noreply.github.com> Date: Tue, 8 Sep 2020 20:14:01 +0530 Subject: [PATCH] [TFLite] Implemented MATRIX_DIAG Operator for TFLite. (#6397) * Added implementation for MATRIX_DIAG Operator. * Added tests for MATRIX_DIAG Operator. --- python/tvm/relay/frontend/tflite.py | 25 +++++++++++++++++++++++++ tests/python/frontend/tflite/test_forward.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 6e4a62d..59ba9f4 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -107,6 +107,7 @@ class OperatorConverter(object): 'LOGICAL_NOT': self.convert_logical_not, 'LOGICAL_OR': self.convert_logical_or, 'LOGISTIC': self.convert_logistic, + 'MATRIX_DIAG': self.convert_matrix_diag, 'MATRIX_SET_DIAG': self.convert_matrix_set_diag, 'MAX_POOL_2D': self.convert_max_pool2d, 'MAXIMUM': self.convert_maximum, @@ -3020,6 +3021,30 @@ class OperatorConverter(object): out = _op.matrix_set_diag(input_expr, diagonal_expr) return out + def convert_matrix_diag(self, op): + """Convert TFLite MATRIX_DIAG""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensor's length should be 1" + + diagonal = input_tensors[0] + + if diagonal.qnn_params: + # Check that diagonal and output tensor have same qnn params. + output_tensors = self.get_output_tensors(op) + assert self.has_same_qnn_params(diagonal, output_tensors[0]), \ + "TFLite MATRIX_DIAG requires diagonal and output tensors' \ + scale and zero points to be equal" + + shape = diagonal.tensor.ShapeAsNumpy() + shape = np.append(shape, shape[-1]) + dtype = self.get_tensor_type_str(diagonal.tensor.Type()) + + input_expr = _op.zeros(tuple(shape), dtype) + diagonal_expr = self.get_tensor_expr(diagonal) + + out = _op.matrix_set_diag(input_expr, diagonal_expr) + return out + def get_expr(self, input_tensor_idx): return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx)) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 3577de3..89296a6 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -2761,6 +2761,33 @@ def test_forward_matrix_set_diag(): ####################################################################### +# MATRIX_DIAG +# ----------- + +def _test_matrix_diag(diagonal_shape, dtype): + """ One iteration of MATRIX_DIAG """ + with tf.Graph().as_default(): + diagonal = np.random.uniform(0, 100, diagonal_shape).astype(dtype) + in_diagonal = tf.placeholder(dtype=diagonal.dtype, shape=diagonal.shape, name="diagonal") + + out = array_ops.matrix_diag(in_diagonal) + + compare_tflite_with_tvm( + [diagonal], + ["diagonal"], + [in_diagonal], + [out], + experimental_new_converter=True) + +def test_forward_matrix_diag(): + """ MATRIX_DIAG """ + for dtype in [np.float32, np.int32]: + _test_matrix_diag((4), dtype) + _test_matrix_diag((5, 4, 3), dtype) + _test_matrix_diag((2, 3), dtype) + + +####################################################################### # Custom Operators # ---------------- @@ -3240,6 +3267,7 @@ if __name__ == '__main__': test_forward_expand_dims() test_forward_reverse_v2() test_forward_matrix_set_diag() + test_forward_matrix_diag() # NN test_forward_convolution() -- 2.7.4