[Relay][Frontend] Simplify parameter handling in Tensorflow frontend (#2993)
authorAlexey Romanov <romanov.alexey1@huawei.com>
Thu, 6 Jun 2019 18:00:19 +0000 (21:00 +0300)
committerTianqi Chen <tqchen@users.noreply.github.com>
Thu, 6 Jun 2019 18:00:19 +0000 (11:00 -0700)
python/tvm/relay/frontend/tensorflow.py
tests/python/frontend/tensorflow/test_forward.py
topi/python/topi/util.py

index 307fb20693f4d6e5b2adc1b86073675fff8e3c34..f709a63e79e8ecbef94a13faa8f60a7ae86d78e2 100644 (file)
@@ -63,7 +63,7 @@ def _get_relay_op(op_name):
     return op
 
 class AttrCvt(object):
-    """Common attribute conveter. An AttrConverter instance is a callable:
+    """Common attribute converter. An AttrConverter instance is a callable:
     ```
     attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)})
     new_op_name, new_attr = attr_converter(attrs)
@@ -222,17 +222,37 @@ def _dimension_constraint():
         return False
     return _dim_check, "Only 2d kernel supported."
 
-def _infer_channels(inputs, params, transpose=False):
-    """A hack for getting 'channles' or 'units' since tensorflow don't provide
+def _infer_channels(node, params, transpose=False):
+    """A hack for getting 'channels' or 'units' since tensorflow don't provide
     these attributes. We check the shape of weights provided to get the number.
     """
-    out_type = ir_pass.infer_type(inputs)
-    out_shapes = [get_const_tuple(out_type.checked_type.shape)]
-    channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
+    out_shape = _infer_shape(node, params)
+    channels = out_shape[0] if not transpose else out_shape[1]
     return channels
 
+def _infer_out_shapes(inputs, params):
+    """A method to get the output shape of intermediate nodes in the relay graph."""
+    return [_infer_shape(inputs, params)]
+
+def _infer_shape(node, params=None):
+    """A method to get the output shape of an intermediate node in the relay graph."""
+    out_type = ir_pass.infer_type(node)
+    return get_const_tuple(out_type.checked_type.shape)
+
+def _get_param(params, input_node):
+    return params.pop(input_node.name_hint).asnumpy()
+
+def _get_num_param(params, input_node):
+    return _get_param(params, input_node)[0]
+
+def _get_list_param(params, input_node):
+    return _get_param(params, input_node).tolist()
+
+def _get_tuple_param(params, input_node):
+    return tuple(_get_param(params, input_node))
+
 def _rsqrt():
-    def _impl(inputs, attr, *args):
+    def _impl(inputs, attr, params):
         inputs.append(tvm.relay.const(-0.5, attr['T'].name))
         return AttrCvt(op_name="power")(inputs, attr)
     return _impl
@@ -243,16 +263,15 @@ def _argx(func, func_name):
         try:
             # In Tensorflow, `axis` argument is a Tensor, not attribute. We
             # support the case where it inputs from a scalar constant.
-            axis_input_name = inputs[1].name_hint
-            axis_input_vlaue = [params[axis_input_name].asnumpy()[0]]
+            axis_input_value = [_get_num_param(params, inputs[1])]
         except (IndexError, KeyError):
             raise TypeError( \
                 "Unsupported argument for `{}` : `axis` should be a constant".format(func_name))
-        return func(inputs[0], axis=axis_input_vlaue, keepdims=False)
+        return func(inputs[0], axis=axis_input_value, keepdims=False)
     return _impl
 
 def _elemwise(name):
-    def _impl(inputs, attr, *args):
+    def _impl(inputs, attr, params):
         assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs))
         return _get_relay_op(name)(*inputs)
     return _impl
@@ -472,7 +491,7 @@ def _cast():
 def _expand_dims():
     def _impl(inputs, attr, params):
         dim_input = inputs.pop(1)
-        axis = params.pop(_get_name_hint(dim_input)).asnumpy()[0]
+        axis = _get_num_param(params, dim_input)
         return AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'],
                        extras={'axis': int(axis), 'num_newaxis': 1})(inputs, attr)
     return _impl
@@ -527,21 +546,19 @@ def _identity():
 def _concatV2():
     def _impl(inputs, attr, params):
         pop_node = inputs.pop(len(inputs)-1)
-        axis = params[pop_node.name_hint]
-        params.pop(pop_node.name_hint)
+        axis = int(_get_num_param(params, pop_node))
         return AttrCvt(
             op_name="concatenate", ignores=['T', 'N', 'Tidx'],
-            extras={'axis': int(axis.asnumpy()[0])})([inputs], attr)
+            extras={'axis': axis})([inputs], attr)
     return _impl
 
 def _concat():
     def _impl(inputs, attr, params):
         pop_node = inputs.pop(0)
-        axis = params[pop_node.name_hint]
-        params.pop(pop_node.name_hint)
+        axis = int(_get_num_param(params, pop_node))
         return AttrCvt(
             op_name="concatenate", ignores=['N'],
-            extras={'axis': int(axis.asnumpy()[0])})([inputs], attr)
+            extras={'axis': axis})([inputs], attr)
     return _impl
 
 def _pack():
@@ -565,8 +582,8 @@ def _tile():
 
 def _slice():
     def _impl(inputs, attr, params):
-        begin = params.pop(_get_name_hint(inputs[1])).asnumpy().tolist()
-        size = params.pop(_get_name_hint(inputs[2])).asnumpy().tolist()
+        begin = _get_list_param(params, inputs[1])
+        size = _get_list_param(params, inputs[2])
         data_shape = attr['_input_shapes'][inputs[0]]
         data_dim = len(data_shape)
         end = size
@@ -581,24 +598,18 @@ def _slice():
 
 def _reshape():
     def _impl(inputs, attr, params):
+        pop_node = inputs.pop(1)
         try:
-            pop_node = inputs[1]
-            shape_arg = params.pop(pop_node.name_hint)
-            inputs.pop(1)
-
-            return AttrCvt(
-                op_name="reshape",
-                extras={'newshape':tuple(shape_arg.asnumpy())},
-                ignores=['Tshape'])(inputs, attr)
+            shape_arg = _get_tuple_param(params, pop_node)
         except AttributeError:
             # Shape operator is already pruned, hence
             # try to infer shape by precompute prune if possible.
-            params_new = _infer_value(inputs[1], params)
-            inputs.pop(1)
-            return AttrCvt(
-                op_name="reshape",
-                extras={'newshape':tuple(params_new.asnumpy().astype('int64').flatten())},
-                ignores=['Tshape'])(inputs, attr)
+            params_new = _infer_value(pop_node, params)
+            shape_arg = tuple(params_new.asnumpy().astype('int64').flatten())
+        return AttrCvt(
+            op_name="reshape",
+            extras={'newshape': shape_arg},
+            ignores=['Tshape'])(inputs, attr)
     return _impl
 
 
@@ -737,9 +748,10 @@ def _fill():
         if -1 in output_shape:
             output_shape = _infer_value(inputs[0], params).asnumpy().reshape([-1]).tolist()
 
-        fill_arg = params.pop(inputs.pop(1).name_hint)
-        return _op.full(tvm.relay.const(fill_arg.asnumpy()[0], attr['T'].name),
-                        output_shape, attr['T'].name)
+        fill_arg = _get_num_param(params, inputs.pop(1))
+        dtype = attr['T'].name
+        return _op.full(tvm.relay.const(fill_arg, dtype),
+                        output_shape, dtype)
     return _impl
 
 def _lrn():
@@ -757,9 +769,7 @@ def _lrn():
 
 def _sum():
     def _impl(inputs, attr, params):
-        axis = params.pop(inputs[1].name_hint).asnumpy()
-        # convert to tuple for preventing invalid parameter format error
-        axis = tuple(axis)
+        axis = _get_tuple_param(params, inputs[1])
         return AttrCvt(
             op_name='sum',
             extras={'axis': axis},
@@ -786,25 +796,17 @@ def _square():
 def _gather():
     "GatherV2, Gather"
     def _impl(inputs, attr, params):
-
-        axis = 0
         if len(inputs) > 2:
-            axis = params[inputs.pop(2).name_hint].asnumpy()[0]
-        new_input = []
-        new_input.append(inputs.pop(0))
-        new_input.append(inputs.pop(0))
+            axis = _get_num_param(params, inputs.pop(2))
+        else:
+            axis = 0
+        new_input = inputs[0:2]
         return AttrCvt(op_name="take",
                        extras={'axis': tvm.const(axis, 'int32')},
-                       ignores=['Tindices', 'Tparams', 'validate_indices', \
+                       ignores=['Tindices', 'Tparams', 'validate_indices',
                                 'Taxis', '_class'])(new_input, attr)
     return _impl
 
-def _infer_out_shapes(inputs, params):
-    """A method to get the output shape of an intermediate node in the relay graph."""
-    out_type = ir_pass.infer_type(inputs)
-    out_shapes = [get_const_tuple(out_type.checked_type.shape)]
-    return out_shapes
-
 def _stridedSlice():
     def _impl(inputs, attr, params):
         """Strided Slice.
@@ -812,9 +814,9 @@ def _stridedSlice():
         Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/
         tensorflow/core/util/strided_slice_op.cc#L147-L368
         """
-        begin = params.pop(inputs[1].name_hint).asnumpy().tolist()
-        end = params.pop(inputs[2].name_hint).asnumpy().tolist()
-        stride = params.pop(inputs[3].name_hint).asnumpy().tolist()
+        begin = _get_list_param(params, inputs[1])
+        end = _get_list_param(params, inputs[2])
+        stride = _get_list_param(params, inputs[3])
         begin_mask = int(attr.get('begin_mask', 0))
         end_mask = int(attr.get('end_mask', 0))
         ellipsis_mask = int(attr.get('ellipsis_mask', 0))
@@ -889,7 +891,7 @@ def _stridedSlice():
         if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
             begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
         out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride)
-        out_shape = _infer_out_shapes(out, params)[0]
+        out_shape = _infer_shape(out, params)
         if not fshape_indices:
             fshape_indices = range(len(out_shape))
 
@@ -910,19 +912,14 @@ def _stridedSlice():
 
 def _pad(name):
     def _impl(inputs, attr, params):
-        padlist_key = inputs[1].name_hint
-        if padlist_key in params:
-            padlist = params.pop(padlist_key).asnumpy()
-        else:
-            raise tvm.error.OpAttributeRequired(
-                'Attribute {} not found in operator Pad.'.format(padlist_key))
-        paddings = tuple([tuple(l) for l in padlist])
+        padlist = _get_param(params, inputs[1])
+        paddings = tuple(tuple(l) for l in padlist)
         attr['pad_width'] = paddings
         attr['pad_value'] = 0
         new_inputs = [inputs[0]]
         if name == 'PadV2':
-            constant_values = params.pop(inputs[2].name_hint).asnumpy()
-            attr['pad_value'] = constant_values[0]
+            constant_values = _get_num_param(params, inputs[2])
+            attr['pad_value'] = constant_values
         return AttrCvt(
             op_name='pad',
             ignores=['Tpaddings'],)(new_inputs, attr)
@@ -932,10 +929,9 @@ def _transpose():
     def _impl(inputs, attr, params):
         # If perm is not specified, axes is left empty,
         # otherwise its value is get from params
-        param_name = _get_name_hint(inputs[1])
-        if param_name in params:
-            axes = tuple(params.get(param_name).asnumpy())
-        else:
+        try:
+            axes = _get_list_param(params, inputs[1])
+        except (IndexError, KeyError):
             axes = None
         return _op.transpose(inputs[0], axes=axes)
     return _impl
@@ -947,7 +943,7 @@ def _where():
 
 def _reverse_v2():
     def _impl(inputs, attr, params):
-        axis = params.pop(inputs[1].name_hint).asnumpy()[0]
+        axis = _get_num_param(params, inputs[1])
         return AttrCvt(
             op_name="reverse",
             ignores=['Tidx'],
@@ -968,9 +964,9 @@ def _rank():
 
 def _range():
     def _impl(inputs, attr, params):
-        start = params.pop(inputs[0].name_hint).asnumpy()[0]
-        limit = params.pop(inputs[1].name_hint).asnumpy()[0]
-        delta = params.pop(inputs[2].name_hint).asnumpy()[0]
+        start = _get_num_param(params, inputs[0])
+        limit = _get_num_param(params, inputs[1])
+        delta = _get_num_param(params, inputs[2])
 
         name = attr["_node_name"]
         params[name] = tvm.nd.array([start, limit, delta])
@@ -981,25 +977,27 @@ def _range():
 
 def _elu():
     def _impl(inputs, attr, params):
-        alpha = tvm.relay.const(-1.0, attr['T'].name)
-        return alpha * _op.nn.relu(tvm.relay.const(1, attr['T'].name) \
+        dtype = attr['T'].name
+        alpha = tvm.relay.const(-1.0, dtype)
+        return alpha * _op.nn.relu(tvm.relay.const(1, dtype) \
                                    - _op.exp(inputs[0])) + _op.nn.relu(inputs[0])
     return _impl
 
 def _selu():
     def _impl(inputs, attr, params):
-        alpha = tvm.relay.const(-1.6732632423543772848170429916717, attr['T'].name)
-        gamma = tvm.relay.const(1.0507009873554804934193349852946, attr['T'].name)
-        return gamma * (alpha * _op.nn.relu(tvm.relay.const(1, attr['T'].name) \
+        dtype = attr['T'].name
+        alpha = tvm.relay.const(-1.6732632423543772848170429916717, dtype)
+        gamma = tvm.relay.const(1.0507009873554804934193349852946, dtype)
+        return gamma * (alpha * _op.nn.relu(tvm.relay.const(1, dtype) \
                                             - _op.exp(inputs[0])) + _op.nn.relu(inputs[0]))
     return _impl
 
 def _mean():
     def _impl(inputs, attr, params):
-        axis = params.pop(inputs[1].name_hint)
+        axis = _get_tuple_param(params, inputs[1])
         return AttrCvt(op_name="mean", ignores=['Tdim', 'Tidx'],
                        transforms={'keep_dims': 'keepdims'},
-                       extras={'axis': tuple(axis.asnumpy())})([inputs[0]], attr)
+                       extras={'axis': axis})([inputs[0]], attr)
     return _impl
 
 def _broadcast(name):
@@ -1025,8 +1023,7 @@ def _split(has_size_vector):
             if has_size_vector:
                 input_node_index = 0
                 input_axis_index = 2
-                size_splits_input_name = _get_name_hint(inputs[1])
-                size_splits = params[size_splits_input_name].asnumpy()
+                size_splits = _get_param(params, inputs[1])
                 section_beginnings = np.cumsum(size_splits)[:-1]
                 indices_or_sections = tuple(section_beginnings)
             else:
@@ -1034,8 +1031,7 @@ def _split(has_size_vector):
                 input_axis_index = 0
                 indices_or_sections = attr['num_split']
             input_node = inputs[input_node_index]
-            axis_input_name = _get_name_hint(inputs[input_axis_index])
-            axis_input_value = params[axis_input_name].asnumpy()[0]
+            axis_input_value = _get_num_param(params, inputs[input_axis_index])
         except (IndexError, KeyError):
             raise TypeError( \
                 "Unsupported argument for split: `axis` and `num_or_size_splits` " \
@@ -1105,8 +1101,8 @@ def _space_to_batch_nd():
     def _impl(inputs, attr, params):
         input_node = inputs[0]
         input_shape = attr['_input_shapes'][input_node]
-        block_shape = params.pop(inputs[1].name_hint).asnumpy().tolist()
-        paddings = params.pop(inputs[2].name_hint).asnumpy().tolist()
+        block_shape = _get_list_param(params, inputs[1])
+        paddings = _get_list_param(params, inputs[2])
         N = len(input_shape)
         M = len(block_shape)
         batch = input_shape[0]
@@ -1127,7 +1123,7 @@ def _space_to_batch_nd():
         axes = [2 * i + 2 for i in range(M)] + [0] + [2 * i + 1 for i in range(M)] + \
                list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length))
         permuted_reshaped_padded = tvm.relay.transpose(reshaped_padded, axes=axes)
-        permuted_reshaped_padded_shape = _infer_out_shapes(permuted_reshaped_padded, params)[0]
+        permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded, params)
         # Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension,
         # producing an output tensor of shape:
         # [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ...,
@@ -1144,8 +1140,8 @@ def _batch_to_space_nd():
     def _impl(inputs, attr, params):
         input_node = inputs[0]
         input_shape = attr['_input_shapes'][input_node]
-        block_shape = params.pop(inputs[1].name_hint).asnumpy().tolist()
-        crops = params.pop(inputs[2].name_hint).asnumpy().tolist()
+        block_shape = _get_list_param(params, inputs[1])
+        crops = _get_list_param(params, inputs[2])
         M = len(block_shape)
         batch = input_shape[0]
         # From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d:
@@ -1170,7 +1166,7 @@ def _batch_to_space_nd():
         # [batch / prod(block_shape), input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
         #  ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
         #  input_shape[M+1], ..., input_shape[N-1]]
-        reshaped_permuted_shape = _infer_out_shapes(reshaped_permuted, params)[0]
+        reshaped_permuted_shape = _infer_shape(reshaped_permuted, params)
         cropped = reshaped_permuted
         for axis in range(1, M+1):
             crop = crops[axis - 1]
@@ -1971,23 +1967,17 @@ class GraphProto(object):
 
                 # Infer shapes even without specifying "add_shapes=True"
                 if output_shapes == [None]:
-                    out_shapes = []
-                    for node_item in self._nodes[node.name]:
-                        out_type = ir_pass.infer_type(node_item)
-                        out_shapes.append(get_const_tuple(out_type.checked_type.shape))
+                    out_shapes = [_infer_shape(node_item) for node_item in self._nodes[node.name]]
                     self._output_shapes[node.name] = out_shapes
 
                 if self._output_shapes[node.name] and shape and node.name in shape:
                     assert self._output_shapes[node.name] == list(shape[node.name])
 
-            # Infer shapes if passed explicitely
+            # Infer shapes if passed explicitly
             node_output = self._nodes[node.name]
             if shape and (not self._output_shapes[node.name][0]
                           or -1 in self._output_shapes[node.name][0]):
-                out_shapes = []
-                for node_item in node_output:
-                    out_type = ir_pass.infer_type(node_item)
-                    out_shapes.append(get_const_tuple(out_type.checked_type.shape))
+                out_shapes = [_infer_shape(node_item) for node_item in node_output]
                 self._output_shapes[node.name] = out_shapes
 
         out = []
index eebb73c95b1b75186262552083a809d15cb9f09a..3899bc04d5c673d77fffc430f82196fb44ea78bc 100644 (file)
@@ -56,31 +56,23 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
     layout = None
     if target == "cuda":
         layout = "NCHW"
-    target_host = 'llvm'
-
-    if isinstance(input_data, list):
-        shape_dict = {}
-        dtype_dict = {}
-        for i, e in enumerate(input_node):
-            shape_dict[e] = input_data[i].shape
-            dtype_dict[e] = input_data[i].dtype
-    else:
-        shape_dict = {input_node: input_data.shape}
-        dtype_dict = {input_node: input_data.dtype}
+    target_host = None
+
+    shape_dict = {e: i.shape for e, i in zip(input_node, input_data)}
 
     sym, params = relay.frontend.from_tensorflow(graph_def,
                                                  layout=layout,
                                                  shape=shape_dict,
                                                  outputs=out_names)
     with relay.build_config(opt_level=opt_level):
-        graph, lib, params = relay.build(sym, target, params=params)
+        graph, lib, params = relay.build(sym, target, target_host, params)
 
     ctx = tvm.context(target, 0)
     from tvm.contrib import graph_runtime
     m = graph_runtime.create(graph, lib, ctx)
     # set inputs
-    for i, e in enumerate(input_node):
-        m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype)))
+    for e, i in zip(input_node, input_data):
+        m.set_input(e, tvm.nd.array(i))
 
     m.set_input(**params)
     # execute
@@ -88,10 +80,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
     # get outputs
     assert out_names is None or num_output == len(out_names), (
         "out_names: {} num_output: {}".format(out_names, num_output))
-    tvm_output_list = []
-    for i in range(0, num_output):
-        tvm_output = m.get_output(i)
-        tvm_output_list.append(tvm_output.asnumpy())
+    tvm_output_list = [m.get_output(i).asnumpy() for i in range(num_output)]
     return tvm_output_list
 
 def run_tf_graph(sess, input_data, input_node, output_node):
@@ -100,13 +89,9 @@ def run_tf_graph(sess, input_data, input_node, output_node):
     input_node = convert_to_list(input_node)
     output_node = convert_to_list(output_node)
 
-    tensor = [0] * len(output_node)
-    for i in range(len(output_node)):
-        tensor[i] = sess.graph.get_tensor_by_name(output_node[i])
+    tensor = [sess.graph.get_tensor_by_name(output_name) for output_name in output_node]
 
-    input_dict = {}
-    for i, e in enumerate(input_node):
-        input_dict[e] = input_data[i]
+    input_dict = {e: input_data[i] for i, e in enumerate(input_node)}
 
     output_data = sess.run(tensor, input_dict)
     return output_data
@@ -115,17 +100,15 @@ def run_tf_graph(sess, input_data, input_node, output_node):
 def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
                         no_gpu=False, opt_level=3):
     """Generic function to generate and compare tensorflow and TVM output"""
+    def name_without_num(name):
+        return name.split(':')[0] if ":" in name else name
 
     out_name = convert_to_list(out_name)
-    out_node = [0]*len(out_name)
-    for i in range(len(out_name)):
-        out_node[i] = out_name[i].split(':')[0] if ":" in out_name[i] else out_name[i]
+    out_node = [name_without_num(name) for name in out_name]
 
     in_data = convert_to_list(in_data)
     in_name = convert_to_list(in_name)
-    in_node = [0]*len(in_name)
-    for i in range(len(in_name)):
-        in_node[i] = in_name[i].split(':')[0] if ":" in in_name[i] else in_name[i]
+    in_node = [name_without_num(name) for name in in_name]
     with tf.Session() as sess:
         if init_global_variables:
             sess.run(variables.global_variables_initializer())
@@ -577,6 +560,38 @@ def test_forward_variable():
     _test_variable(np.random.uniform(size=(32, 100)).astype('float32'))
 
 
+#######################################################################
+# MatMul
+# ------
+
+def _test_matmul(i, j, k, dtype, outer=None):
+    """ One iteration of matmul """
+
+    A_shape_init = [i, j]
+    B_shape_init = [j, k]
+
+    for transpose_a in [False, True]:
+        for transpose_b in [False, True]:
+            outer = outer or []
+            A_shape = outer + (A_shape_init[::-1] if transpose_a else A_shape_init)
+            B_shape = outer + (B_shape_init[::-1] if transpose_b else B_shape_init)
+
+            with tf.Graph().as_default():
+                A = tf.placeholder(shape=A_shape, dtype=dtype, name='A')
+                B = tf.placeholder(shape=B_shape, dtype=dtype, name='B')
+                result = tf.matmul(A, B, transpose_a=transpose_a, transpose_b=transpose_b)
+
+                A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype)
+                B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
+                compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name)
+
+def test_forward_matmul():
+    """ Matmul op test"""
+    _test_matmul(1, 3, 6, 'int32')
+    _test_matmul(5, 3, 1, 'float64')
+    # TODO non-empty outer requires BatchMatMul (BatchMatMulV2 for some cases?) support
+
+
 #######################################################################
 # StridedSlice
 # ------------
@@ -1785,3 +1800,6 @@ if __name__ == '__main__':
     test_forward_rel_ops()
     test_forward_logical()
     test_where()
+
+    test_forward_matmul()
+    # TODO missing tests: rank, range
\ No newline at end of file
index f648245c6bb78381da3da1ec9a31d37a2a9c7824..623c81a07da8c7eb352a39dcf8588e6b164b28a9 100644 (file)
@@ -151,11 +151,7 @@ def get_const_tuple(in_tuple):
     out_tuple : tuple of int
         The output.
     """
-    out_tuple = ()
-    for elem in in_tuple:
-        value = get_const_int(elem)
-        out_tuple = out_tuple + (value, )
-    return out_tuple
+    return tuple(get_const_int(elem) for elem in in_tuple)
 
 
 def get_float_tuple(in_tuple):
@@ -171,11 +167,7 @@ def get_float_tuple(in_tuple):
     out_tuple : tuple of float
         The output.
     """
-    out_tuple = ()
-    for elem in in_tuple:
-        value = get_const_float(elem)
-        out_tuple = out_tuple + (value, )
-    return out_tuple
+    return tuple(get_const_float(elem) for elem in in_tuple)
 
 
 def simplify(expr):