from .. import op as _op
from ... import nd as _nd
from .common import ExprTable
+from .common import infer_shape as _infer_shape
__all__ = ['from_tflite']
'LOGISTIC': self.convert_logistic,
'SPLIT': self.convert_split,
'TRANSPOSE': self.convert_transpose,
- 'TILE': self.convert_tile
+ 'TILE': self.convert_tile,
+ 'BATCH_TO_SPACE_ND': self.convert_batch_to_space_nd,
+ 'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd
}
def check_unsupported_ops(self):
out = _op.concatenate(in_exprs_reshaped, pack_axis)
return out
+ def convert_batch_to_space_nd(self, op):
+ """batch_to_space_nd implementation."""
+ try:
+ from tflite.Operator import Operator
+ except ImportError:
+ raise ImportError("The tflite package must be installed")
+
+ assert isinstance(op, Operator)
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 3, "input tensors length should be 3"
+
+ input_tensor = input_tensors[0]
+ input_tensor_idx = input_tensor.tensor_idx
+ in_expr = self.get_expr(input_tensor_idx)
+
+ input_shape = list(input_tensor.tensor.ShapeAsNumpy())
+ batch = input_shape[0]
+
+ block_shape = list(self.get_tensor_value(input_tensors[1]))
+ M = len(block_shape)
+
+ crops = list(self.get_tensor_value(input_tensors[2]))
+
+ # From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d:
+ # Reshape input to reshaped of shape
+ shape1 = block_shape + [batch // np.prod(block_shape)] + input_shape[1:]
+ reshaped = _op.reshape(in_expr, newshape=shape1)
+
+ # Permute dimensions of reshaped to produce permuted of shape
+ axes = [M] + [axis for i in range(M) for axis in [M + i + 1, i]] + \
+ list(range(2 * M + 1, len(shape1)))
+ permuted = _op.transpose(reshaped, axes=axes)
+
+ # Reshape permuted to produce reshaped_permuted of shape
+ shape2 = [0] + [-3] * M + [-2]
+ reshaped_permuted = _op.reshape(permuted, newshape=shape2)
+
+ # Crop the start and end of dimensions [1, ..., M] of reshaped_permuted according to crops
+ # to produce the output of shape:
+ reshaped_permuted_shape = _infer_shape(reshaped_permuted)
+ cropped = reshaped_permuted
+ for axis in range(1, M + 1):
+ crop = crops[axis - 1]
+ if (crop != [0, 0]).all():
+ indices = _op.arange(
+ _expr.const(crop[0]),
+ _expr.const(reshaped_permuted_shape[axis] - crop[1]),
+ dtype='int32'
+ )
+ cropped = _op.take(cropped, indices=indices, axis=axis)
+
+ return cropped
+
+ def convert_space_to_batch_nd(self, op):
+ """space_to_batch_nd implementation."""
+ try:
+ from tflite.Operator import Operator
+ except ImportError:
+ raise ImportError("The tflite package must be installed")
+
+ assert isinstance(op, Operator)
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 3, "input tensors length should be 3"
+
+ input_tensor = input_tensors[0]
+ input_tensor_idx = input_tensor.tensor_idx
+ in_expr = self.get_expr(input_tensor_idx)
+
+ input_shape = list(input_tensor.tensor.ShapeAsNumpy())
+ batch = input_shape[0]
+ N = len(input_shape)
+
+ block_shape = list(self.get_tensor_value(input_tensors[1]))
+ M = len(block_shape)
+
+ paddings = list(self.get_tensor_value(input_tensors[2]))
+
+ # From https://www.tensorflow.org/api_docs/python/tf/space_to_batch_nd:
+ # Zero-pad the start and end of dimensions [1, ..., M] of the input according to paddings
+ # to produce padded of shape padded_shape.
+ remaining_shape_length = N - M - 1
+ padded_list = [(0, 0)] + paddings + [(0, 0)] * remaining_shape_length
+
+ padded_shape = []
+ for element in padded_list:
+ if isinstance(element, np.ndarray):
+ element = element.tolist()
+
+ padded_shape.append(element)
+
+ padded_shape = tuple(padded_shape)
+ padded = _op.nn.pad(in_expr, pad_width=tuple(padded_shape))
+
+ # Reshape padded to reshaped_padded of shape:
+ shape1 = [batch] + [item for i in range(M) for item in [-4, -1, block_shape[i]]] + [-2]
+ reshaped_padded = _op.reshape(padded, newshape=shape1)
+
+ # Permute dimensions of reshaped_padded to produce permuted_reshaped_padded of shape:
+ axes = [2 * i + 2 for i in range(M)] + [0] + [2 * i + 1 for i in range(M)] + \
+ list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length))
+ permuted_reshaped_padded = _op.transpose(reshaped_padded, axes=axes)
+ permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded)
+
+ # Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension,
+ # producing an output tensor of shape:
+ shape2 = [batch * np.prod(block_shape)] + list(permuted_reshaped_padded_shape)[M + 1:]
+ reshaped_permuted_reshaped_padded = _op.reshape(permuted_reshaped_padded, newshape=shape2)
+
+ return reshaped_permuted_reshaped_padded
+
def get_expr(self, input_tensor_idx):
return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))
_test_forward_tile((2, ), (3, ), "int32")
_test_forward_tile((2, 2), (2, 3), "float32")
+######################################################################
+# BatchToSpaceND
+# --------------
+
+
+def _test_batch_to_space_nd(input_shape, block_shape, crops, dtype='int32'):
+ data = np.random.uniform(0, 5, size=input_shape).astype(dtype)
+
+ with tf.Graph().as_default():
+ in_data = array_ops.placeholder(shape=input_shape, dtype=dtype)
+
+ out = array_ops.batch_to_space_nd(in_data, block_shape, crops)
+
+ compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+
+
+def test_forward_batch_to_space_nd():
+ # test cases: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d
+ _test_batch_to_space_nd(
+ input_shape=[4, 1, 1, 1],
+ block_shape=[2, 2],
+ crops=[[0, 0], [0, 0]]
+ )
+
+ _test_batch_to_space_nd(
+ input_shape=[4, 1, 1, 3],
+ block_shape=[2, 2],
+ crops=[[0, 0], [0, 0]]
+ )
+
+ _test_batch_to_space_nd(
+ input_shape=[4, 2, 2, 1],
+ block_shape=[2, 2],
+ crops=[[0, 0], [0, 0]]
+ )
+
+######################################################################
+# SpaceToBatchND
+# --------------
+
+
+def _test_space_to_batch_nd(input_shape, block_shape, paddings, dtype='int32'):
+ data = np.random.uniform(0, 5, size=input_shape).astype(dtype)
+
+ with tf.Graph().as_default():
+ in_data = array_ops.placeholder(shape=input_shape, dtype=dtype)
+
+ out = array_ops.space_to_batch_nd(in_data, block_shape, paddings)
+
+ compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+
+
+def test_forward_space_to_batch_nd():
+ # test cases: https://www.tensorflow.org/api_docs/python/tf/space_to_batch_nd
+ _test_space_to_batch_nd(
+ input_shape=[1, 2, 2, 1],
+ block_shape=[2, 2],
+ paddings=[[0, 0], [0, 0]]
+ )
+
+ _test_space_to_batch_nd(
+ input_shape=[1, 2, 2, 3],
+ block_shape=[2, 2],
+ paddings=[[0, 0], [0, 0]]
+ )
+
+ _test_space_to_batch_nd(
+ input_shape=[1, 4, 4, 1],
+ block_shape=[2, 2],
+ paddings=[[0, 0], [0, 0]]
+ )
+
+ _test_space_to_batch_nd(
+ input_shape=[2, 2, 4, 1],
+ block_shape=[2, 2],
+ paddings=[[0, 0], [2, 0]]
+ )
#######################################################################
# Pooling
# Main
# ----
if __name__ == '__main__':
+ # BatchToSpaceND
+ test_forward_batch_to_space_nd()
+
+ # SpaceToBatchND
+ test_forward_space_to_batch_nd()
+
# Split
test_forward_split()