From b3730e575ac0ab4a54fe6984bde35408e93b5e6f Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Thu, 7 May 2020 16:06:20 +0530 Subject: [PATCH] [FRONTEND][TFLite] Fully connected op conversion made in sync with TFLite (#5510) * [FRONTEND][TFLite] Fully connected op conversion made in sync with TFLite * [1] Test case added * [2] Review comments handled * [3] Prints removed --- python/tvm/relay/frontend/tflite.py | 33 ++++++++++++++++++++-------- tests/python/frontend/tflite/test_forward.py | 26 ++++++++++++++++++++++ 2 files changed, 50 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index a55a57f..bb456f1 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -1331,16 +1331,28 @@ class OperatorConverter(object): input_tensor_shape = input_tensor.tensor.ShapeAsNumpy() weight_tensor_shape = weight_tensor.tensor.ShapeAsNumpy() - # reshape input tensor from N H W C to N H*W*C - input_size_per_batch = 1 - for s in range(1, len(input_tensor_shape)): - input_size_per_batch *= input_tensor_shape[s] - assert input_size_per_batch == weight_tensor_shape[1], \ - "input size and weight size are mismatched" - target_shape = tuple((input_tensor_shape[0], input_size_per_batch)) + # Weight should have only 2 dimensions(TFLite convention) + assert len(weight_tensor_shape) == 2, "Weight should be only 2-dim" + + # Input shape: [i_batch_size, ..., n_inputs] + # Filter shape: [n_inputs, n_units] + # + # As we will transform Fully_Connected Input to Dense Op inputs as below + # Dense expected Input shape: [batch_size, n_units] + # Dense expected Weight shape: [out_dim, n_units] + # Dense output shape: [batch_size, out_dim] + # So it is evident that input shape: [batch_size = input_size / n_units, n_units] + input_size = 1 + for _, shape in enumerate(input_tensor_shape): + input_size *= shape + + # First get the batch size + batch_size = int(input_size / weight_tensor_shape[1]) + target_shape = tuple((batch_size, weight_tensor_shape[1])) in_expr = self.get_expr(input_tensor_idx) in_expr = _op.reshape(in_expr, target_shape) + #TODO: Change the output shape calculation based on keep_dim option assert op.BuiltinOptionsType() == BuiltinOptions.FullyConnectedOptions op_options = op.BuiltinOptions() fully_connected_options = FullyConnectedOptions() @@ -1352,8 +1364,11 @@ class OperatorConverter(object): assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32) weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type) - weight_value = self.get_tensor_value(weight_tensor) - weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str) + if self.has_expr(weight_tensor.tensor_idx): + weight_expr = self.get_expr(weight_tensor.tensor_idx) + else: + weight_value = self.get_tensor_value(weight_tensor) + weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str) weight_shape = _infer_shape(weight_expr) if input_tensor.qnn_params: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 7257748..20a077f 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -420,6 +420,29 @@ def test_forward_cast(): _test_cast(np.arange(6.0, dtype=np.int32).reshape((1, 6)), cast_dtype=tf.int64) ####################################################################### +# Batch Mat Mul +# ---- +def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False): + with tf.Graph().as_default(): + A = array_ops.placeholder(shape=A_shape, dtype=dtype, name='A') + B = array_ops.placeholder(shape=B_shape, dtype=dtype, name='B') + result = math_ops.matmul(A, B, adjoint_a=adjoint_a, + adjoint_b=adjoint_b, name='batchmatmul') + + A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype) + B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype) + compare_tflite_with_tvm([A_np, B_np], [A.name, B.name], [A, B], [result]) + + +def test_forward_batch_matmul(): + """ BATCH_MAT_MUL """ + _test_batch_matmul((3, 5, 4), (3, 4, 5), 'float32') + _test_batch_matmul((3, 5, 4), (3, 4, 5), 'float32', True, True) + _test_batch_matmul((3, 5, 4), (3, 5, 4), 'float32', True, False) + _test_batch_matmul((3, 5, 4), (3, 5, 4), 'float32', False, True) + _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), 'float32') + +####################################################################### # Tile # ---- @@ -2001,6 +2024,9 @@ if __name__ == '__main__': # Cast test_forward_cast() + # BatchMatMul + test_forward_batch_matmul() + # Tile test_forward_tile() -- 2.7.4