from __future__ import print_function
import numpy as np
import pytest
-import tensorflow as tf
+try:
+ import tensorflow.compat.v1 as tf
+except ImportError:
+ import tensorflow as tf
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import graph_util
from tensorflow.python.ops import nn_ops
from tvm import te
from tvm import relay
import tvm.relay.testing.tf as tf_testing
+from packaging import version as package_version
#######################################################################
# Generic run functions for TVM & tensorflow
""" One iteration of a variable """
tf.reset_default_graph()
- input_op = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
- input_tensor = array_ops.reshape(input_op, data.shape)
+ with tf.Graph().as_default():
+ input_op = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+ input_tensor = array_ops.reshape(input_op, data.shape)
- size = input_tensor.shape.dims[1]
- with variable_scope.variable_scope("linear", reuse=None):
- w = variable_scope.get_variable(
- "w", shape=[size, size], dtype=input_tensor.dtype)
- math_ops.matmul(input_tensor, w)
+ size = input_tensor.shape.dims[1]
+ with variable_scope.variable_scope("linear", reuse=None):
+ w = variable_scope.get_variable(
+ "w", shape=[size, size], dtype=input_tensor.dtype)
+ math_ops.matmul(input_tensor, w)
- compare_tf_with_tvm(data, 'Placeholder:0', 'MatMul:0',
- init_global_variables=True)
+ compare_tf_with_tvm(data, 'Placeholder:0', 'MatMul:0',
+ init_global_variables=True)
def test_forward_variable():
""" One iteration of a Stridedslice """
tf.reset_default_graph()
- in_data = tf.placeholder(dtype, ip_shape, name="in_data")
- tf.strided_slice(in_data, begin, end, stride, begin_mask=begin_mask,
- end_mask=end_mask, new_axis_mask=new_axis_mask,
- shrink_axis_mask=shrink_axis_mask,
- ellipsis_mask=ellipsis_mask, name="strided_slice")
- np_data = np.random.uniform(size=ip_shape).astype(dtype)
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(dtype, ip_shape, name="in_data")
+ tf.strided_slice(in_data, begin, end, stride, begin_mask=begin_mask,
+ end_mask=end_mask, new_axis_mask=new_axis_mask,
+ shrink_axis_mask=shrink_axis_mask,
+ ellipsis_mask=ellipsis_mask, name="strided_slice")
+ np_data = np.random.uniform(size=ip_shape).astype(dtype)
- compare_tf_with_tvm(np_data, 'in_data:0', 'strided_slice:0')
+ compare_tf_with_tvm(np_data, 'in_data:0', 'strided_slice:0')
def test_forward_stridedslice():
np_numer = np.random.uniform(-100, 100, size=ip_shape).astype(dtype)
np_denomin = np.random.uniform(1, 100, size=ip_shape).astype(dtype)
tf.reset_default_graph()
- numerator = tf.placeholder(dtype, ip_shape, name="numer")
- denominator = tf.placeholder(dtype, ip_shape, name="denomin")
- tf.math.divide(numerator, denominator, name='RealDiv')
- compare_tf_with_tvm([np_numer, np_denomin], [
- 'numer:0', 'denomin:0'], 'RealDiv:0')
+ with tf.Graph().as_default():
+ numerator = tf.placeholder(dtype, ip_shape, name="numer")
+ denominator = tf.placeholder(dtype, ip_shape, name="denomin")
+ tf.math.divide(numerator, denominator, name='RealDiv')
+ compare_tf_with_tvm([np_numer, np_denomin], [
+ 'numer:0', 'denomin:0'], 'RealDiv:0')
def _test_forward_floordiv(ip_shape, dtype):
np_numer = np.random.uniform(1, 100, size=ip_shape).astype(dtype)
tf.reset_default_graph()
- numerator = tf.placeholder(dtype, ip_shape, name="numer")
- tf.math.floordiv(numerator, tf.constant(5, dtype=dtype), name='FloorDiv')
- compare_tf_with_tvm([np_numer], ['numer:0'], 'FloorDiv:0')
+ with tf.Graph().as_default():
+ numerator = tf.placeholder(dtype, ip_shape, name="numer")
+ tf.math.floordiv(numerator, tf.constant(5, dtype=dtype), name='FloorDiv')
+ compare_tf_with_tvm([np_numer], ['numer:0'], 'FloorDiv:0')
def test_forward_divide():
np_numer = np.random.uniform(1, 100, size=in_shape).astype(dtype)
np_factor = np.random.uniform(1, 100, size=if_shape).astype(dtype)
tf.reset_default_graph()
- numerator = tf.placeholder(dtype, in_shape, name="numer")
- factor = tf.placeholder(dtype, if_shape, name="factor")
- tf.floormod(numerator, factor, name='FloorMod')
- compare_tf_with_tvm([np_numer, np_factor], ['numer:0', 'factor:0'], 'FloorMod:0')
+ with tf.Graph().as_default():
+ numerator = tf.placeholder(dtype, in_shape, name="numer")
+ factor = tf.placeholder(dtype, if_shape, name="factor")
+ tf.floormod(numerator, factor, name='FloorMod')
+ compare_tf_with_tvm([np_numer, np_factor], ['numer:0', 'factor:0'], 'FloorMod:0')
def test_forward_floormod():
'''test FloorMod'''
np_data_1 = np.random.uniform(-100, 100, size=ip_shape).astype(dtype)
np_data_2 = np.random.uniform(1, 10, size=ip_shape).astype(dtype)
tf.reset_default_graph()
- in_data_1 = tf.placeholder(dtype, ip_shape, name="in_data_1")
- in_data_2 = tf.placeholder(dtype, ip_shape, name="in_data_2")
- tf.truncatemod(in_data_1, in_data_2, name='truncatemod')
- compare_tf_with_tvm([np_data_1, np_data_2], [
- 'in_data_1:0', 'in_data_2:0'], 'truncatemod:0')
+ with tf.Graph().as_default():
+ in_data_1 = tf.placeholder(dtype, ip_shape, name="in_data_1")
+ in_data_2 = tf.placeholder(dtype, ip_shape, name="in_data_2")
+ tf.truncatemod(in_data_1, in_data_2, name='truncatemod')
+ compare_tf_with_tvm([np_data_1, np_data_2], [
+ 'in_data_1:0', 'in_data_2:0'], 'truncatemod:0')
def test_forward_truncatemod():
""" One iteration of a GatherV2 """
tf.reset_default_graph()
- in_data = tf.placeholder(dtype, ip_shape, name="in_data")
- indices = tf.placeholder("int32", indice_shape, name="indices")
- out = tf.gather(in_data, indices, axis=axis)
- np_data = np.random.uniform(1, 10, size=ip_shape).astype(dtype)
-
- def _fill_indices(indice_value):
- indices = np.array(ip_shape, dtype=dtype)
- if isinstance(indice_value, int):
- indices = np.array([indice_value], dtype='int32')
- else:
- indices = np.asarray(indice_value, dtype='int32')
- return indices
- np_indices = _fill_indices(indice_value)
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(dtype, ip_shape, name="in_data")
+ indices = tf.placeholder("int32", indice_shape, name="indices")
+ out = tf.gather(in_data, indices, axis=axis)
+ np_data = np.random.uniform(1, 10, size=ip_shape).astype(dtype)
+
+ def _fill_indices(indice_value):
+ indices = np.array(ip_shape, dtype=dtype)
+ if isinstance(indice_value, int):
+ indices = np.array([indice_value], dtype='int32')
+ else:
+ indices = np.asarray(indice_value, dtype='int32')
+ return indices
+ np_indices = _fill_indices(indice_value)
- compare_tf_with_tvm([np_data, np_indices], [
- 'in_data:0', 'indices:0'], out.name)
+ compare_tf_with_tvm([np_data, np_indices], [
+ 'in_data:0', 'indices:0'], out.name)
def test_forward_gather():
"""test operator GatherNd"""
np_data = np.random.uniform(1, 100, size=(2, 2)).astype(np.float32)
tf.reset_default_graph()
- in_data = tf.placeholder(tf.float32, (2, 2), name="in_data")
- tf.gather_nd(in_data, indices=[[1, 0], [0, 1]], name="gather_nd")
- compare_tf_with_tvm([np_data], ['in_data:0'], 'gather_nd:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(tf.float32, (2, 2), name="in_data")
+ tf.gather_nd(in_data, indices=[[1, 0], [0, 1]], name="gather_nd")
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'gather_nd:0')
#######################################################################
tf.reset_default_graph()
lh_data = np.random.uniform(size=lh_shpae).astype(dtype)
rh_data = np.random.uniform(size=rh_shape).astype(dtype)
- lft_data = tf.placeholder(dtype, name="lft_data")
- rgt_data = tf.placeholder(dtype, name="rgt_data")
- tf.nn.bias_add(lft_data, rgt_data, name="BiasAdd")
- compare_tf_with_tvm([lh_data, rh_data], [
- 'lft_data:0', 'rgt_data:0'], 'BiasAdd:0')
+ with tf.Graph().as_default():
+ lft_data = tf.placeholder(dtype, name="lft_data")
+ rgt_data = tf.placeholder(dtype, name="rgt_data")
+ tf.nn.bias_add(lft_data, rgt_data, name="BiasAdd")
+ compare_tf_with_tvm([lh_data, rh_data], [
+ 'lft_data:0', 'rgt_data:0'], 'BiasAdd:0')
check_bias_add((10, 8, 16, 32), (32,), dtype="int32")
check_bias_add((10, 20), (20,), dtype="float32")
""" One iteration of a Split """
tf.reset_default_graph()
- in_data = tf.placeholder(dtype, in_shape, name="in_data")
- num_split = len(num_or_size_splits) if isinstance(num_or_size_splits, list)\
- else num_or_size_splits
- split = tf.split(in_data, num_or_size_splits, axis=axis)
- relu = [tf.nn.relu(i) for i in split]
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(dtype, in_shape, name="in_data")
+ num_split = len(num_or_size_splits) if isinstance(num_or_size_splits, list)\
+ else num_or_size_splits
+ split = tf.split(in_data, num_or_size_splits, axis=axis)
+ relu = [tf.nn.relu(i) for i in split]
- compare_tf_with_tvm([np_data], ['in_data:0'], [n.name for n in relu])
+ compare_tf_with_tvm([np_data], ['in_data:0'], [n.name for n in relu])
# and now test together with concat
tf.reset_default_graph()
- in_data = tf.placeholder(dtype, in_shape, name="in_data")
- splitted = tf.split(in_data, num_or_size_splits, axis=axis)
- tf.concat(splitted, axis)
-
- compare_tf_with_tvm([np_data], 'in_data:0', 'concat:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(dtype, in_shape, name="in_data")
+ splitted = tf.split(in_data, num_or_size_splits, axis=axis)
+ concat = tf.concat(splitted, axis)
+ compare_tf_with_tvm([np_data], 'in_data:0', concat.name)
def test_forward_split():
def _test_forward_top_k_v2(in_shape, k):
np_data = np.random.uniform(-100, 100, size=in_shape).astype("float32")
tf.reset_default_graph()
- in_data = tf.placeholder("float32", in_shape, name="in_data")
- tf.math.top_k(in_data, k, name='TopK')
- compare_tf_with_tvm([np_data], ['in_data:0'], 'TopK:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder("float32", in_shape, name="in_data")
+ tf.math.top_k(in_data, k, name='TopK')
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'TopK:0')
def test_forward_top_k_v2():
np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype)
tf.reset_default_graph()
- in_data = tf.placeholder(dtype, ip_shape, name="in_data")
- unstack = tf.unstack(in_data, axis=axis)
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(dtype, ip_shape, name="in_data")
+ unstack = tf.unstack(in_data, axis=axis)
- compare_tf_with_tvm([np_data], ['in_data:0'], [n.name for n in unstack])
+ compare_tf_with_tvm([np_data], ['in_data:0'], [n.name for n in unstack])
tf.reset_default_graph()
- in_data = tf.placeholder(dtype, ip_shape, name="in_data")
- tf.stack(tf.unstack(in_data, axis=axis), axis=axis)
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(dtype, ip_shape, name="in_data")
+ tf.stack(tf.unstack(in_data, axis=axis), axis=axis)
- compare_tf_with_tvm([np_data], ['in_data:0'], 'stack:0')
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'stack:0')
def test_forward_unstack():
def _test_tile(in_shape, multiples, dtype):
np_data = np.random.uniform(-5, 5, size=in_shape).astype(dtype)
tf.reset_default_graph()
- in_data = tf.placeholder(dtype, in_shape, name="in_data")
- tf.tile(in_data, multiples=multiples, name="tile")
- compare_tf_with_tvm([np_data], ['in_data:0'], 'tile:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(dtype, in_shape, name="in_data")
+ tf.tile(in_data, multiples=multiples, name="tile")
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'tile:0')
def test_forward_tile():
def _test_forward_clip_by_value(ip_shape, clip_value_min, clip_value_max, dtype):
tf.reset_default_graph()
- in_data = tf.placeholder(dtype, ip_shape, name="in_data")
- tf.clip_by_value(in_data, clip_value_min,
- clip_value_max, name="ClipByValue")
- np_data = np.random.uniform(-100, 100, size=ip_shape).astype(dtype)
- compare_tf_with_tvm([np_data], ['in_data:0'], 'ClipByValue:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(dtype, ip_shape, name="in_data")
+ tf.clip_by_value(in_data, clip_value_min,
+ clip_value_max, name="ClipByValue")
+ np_data = np.random.uniform(-100, 100, size=ip_shape).astype(dtype)
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'ClipByValue:0')
def test_forward_clip_by_value():
extrapolation_value=0.0, method='bilinear', dtype="float32"):
image = np.random.uniform(0, 10, size=img_shape).astype(dtype)
tf.reset_default_graph()
- in_data = tf.placeholder(dtype, image.shape, name="in_data")
- tf.image.crop_and_resize(in_data, boxes=boxes, box_ind=box_idx,
- crop_size=crop_size, method=method,
- extrapolation_value=extrapolation_value,
- name="crop_and_resize")
- compare_tf_with_tvm([image], ['in_data:0'], 'crop_and_resize:0')
+ with tf.Graph().as_default():
+ in_data = array_ops.placeholder(dtype, image.shape, name="in_data")
+ tf.image.crop_and_resize(in_data, boxes=boxes, box_ind=box_idx,
+ crop_size=crop_size, method=method,
+ extrapolation_value=extrapolation_value,
+ name="crop_and_resize")
+ compare_tf_with_tvm([image], ['in_data:0'], 'crop_and_resize:0')
def test_forward_crop_and_resize():
m1 = array_ops.zeros([batch_size, num_hidden])
x = tf.placeholder(shape=(batch_size, input_size), dtype=dtype)
g, ((out_m0, out_m1)) = \
- tf.contrib.rnn.LSTMBlockCell(num_hidden,
- forget_bias=forget_bias)(x, ((m0, m1)))
+ tensorflow.contrib.rnn.LSTMBlockCell(num_hidden,
+ forget_bias=forget_bias)(x, (m0, m1))
sess.run([variables.global_variables_initializer()])
res = sess.run([g, out_m0, out_m1], {
x.name: np.array([[1., 1.]]),
"""test operator Unpack"""
np_data = np.random.uniform(-100, 100, size=in_shape).astype(dtype)
tf.reset_default_graph()
- in_data = tf.placeholder(dtype, in_shape, name="in_data")
- tf.unstack(in_data, axis=axis, name="Unpack")
- compare_tf_with_tvm([np_data], ['in_data:0'], 'Unpack:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(dtype, in_shape, name="in_data")
+ tf.unstack(in_data, axis=axis, name="Unpack")
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'Unpack:0')
def test_forward_unpack():
def test_forward_range():
"""test operator Range"""
tf.reset_default_graph()
- tf.range(1, 18, 3, name="range")
- compare_tf_with_tvm([], [], 'range:0')
+ with tf.Graph().as_default():
+ tf.range(1, 18, 3, name="range")
+ compare_tf_with_tvm([], [], 'range:0')
"""test type assignment for operator Range"""
tf.reset_default_graph()
- tf.range(1, 256 + 1, 1, dtype=tf.float32)
- compare_tf_with_tvm([], [], 'range:0')
+ with tf.Graph().as_default():
+ tf.range(1, 256 + 1, 1, dtype=tf.float32)
+ compare_tf_with_tvm([], [], 'range:0')
#######################################################################
# Pad
#######################################################################
# PTB
# ---
-dir(tf.contrib)
-
+try:
+ #Load contrib for running ptb model in tf version before 2.0
+ import tensorflow.contrib
+except ImportError:
+ pass
def test_forward_ptb():
'''test ptb model'''
def check_softmax(in_shape, axis, dtype):
np_data = np.random.uniform(-100, 100, size=in_shape).astype(dtype)
tf.reset_default_graph()
- in_data = tf.placeholder(dtype, in_shape, name="in_data")
- tf.nn.softmax(in_data, axis=axis, name="Softmax")
- compare_tf_with_tvm([np_data], ['in_data:0'], 'Softmax:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(dtype, in_shape, name="in_data")
+ tf.nn.softmax(in_data, axis=axis, name="Softmax")
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'Softmax:0')
check_softmax((2, 3, 5), 2, "float32")
check_softmax((2, 3, 5), -1, "float32")
"""test Round"""
np_data = np.random.uniform(-10, 10, size=(5, 7)).astype(np.float32)
tf.reset_default_graph()
- in_data = tf.placeholder(tf.float32, (5, 7), name="in_data")
- tf.round(in_data, name="round")
- compare_tf_with_tvm([np_data], ['in_data:0'], 'round:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(tf.float32, (5, 7), name="in_data")
+ tf.round(in_data, name="round")
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'round:0')
def test_forward_abs():
"""test operator Abs"""
np_data = np.random.uniform(1, 100, size=(9, 11)).astype(np.float32)
tf.reset_default_graph()
- in_data = tf.placeholder(tf.float32, (9, 11), name="in_data")
- tf.math.abs(in_data, name="abs")
- compare_tf_with_tvm([np_data], ['in_data:0'], 'abs:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(tf.float32, (9, 11), name="in_data")
+ tf.math.abs(in_data, name="abs")
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'abs:0')
def _test_forward_zeros_like(in_shape, dtype):
np_data = np.random.uniform(-10, 10, size=in_shape).astype(dtype)
tf.reset_default_graph()
- in_data = tf.placeholder(dtype, in_shape, name="in_data")
- tf.zeros_like(in_data, name="zeros_like")
- compare_tf_with_tvm([np_data], ['in_data:0'], 'zeros_like:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(dtype, in_shape, name="in_data")
+ tf.zeros_like(in_data, name="zeros_like")
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'zeros_like:0')
def test_forward_zeros_like():
def _test_forward_reverse_v2(in_shape, axis, dtype):
np_data = np.random.uniform(-10, 10, size=in_shape).astype(dtype)
tf.reset_default_graph()
- in_data = tf.placeholder(dtype, in_shape, name="in_data")
- tf.reverse(in_data, axis=[axis], name="reverse")
- compare_tf_with_tvm([np_data], ['in_data:0'], 'reverse:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(dtype, in_shape, name="in_data")
+ tf.reverse(in_data, axis=[axis], name="reverse")
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'reverse:0')
def test_forward_reverse_v2():
"""test Sign"""
np_data = np.random.uniform(-10, 10, size=(5, 7, 11)).astype(np.float32)
tf.reset_default_graph()
- in_data = tf.placeholder(tf.float32, (5, 7, 11), name="in_data")
- tf.sign(in_data, name="sign")
- compare_tf_with_tvm([np_data], ['in_data:0'], 'sign:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(tf.float32, (5, 7, 11), name="in_data")
+ tf.sign(in_data, name="sign")
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'sign:0')
def test_forward_square():
"""test operator Square """
np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
tf.reset_default_graph()
- in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
- tf.square(in_data, name="square")
- compare_tf_with_tvm([np_data], ['in_data:0'], 'square:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
+ tf.square(in_data, name="square")
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'square:0')
def test_forward_pow_exp():
np_in1 = np.random.uniform(-2, 2, size=(5, 7, 11)).astype(np.float32)
np_in2 = np.random.uniform(-2, 2, size=(5, 7, 11)).astype(np.float32)
tf.reset_default_graph()
- in1 = tf.placeholder(tf.float32, (5, 7, 11), name="in1")
- in2 = tf.placeholder(tf.float32, (5, 7, 11), name="in2")
- out1 = tf.pow(in1, in2, name="pow")
- out = tf.exp(in1, name='exp')
- compare_tf_with_tvm([np_in1, np_in2], ['in1:0', 'in2:0'], 'pow:0')
- compare_tf_with_tvm([np_in1], ['in1:0'], 'exp:0')
+ with tf.Graph().as_default():
+ in1 = tf.placeholder(tf.float32, (5, 7, 11), name="in1")
+ in2 = tf.placeholder(tf.float32, (5, 7, 11), name="in2")
+ out1 = tf.pow(in1, in2, name="pow")
+ out = tf.exp(in1, name='exp')
+ compare_tf_with_tvm([np_in1, np_in2], ['in1:0', 'in2:0'], 'pow:0')
+ compare_tf_with_tvm([np_in1], ['in1:0'], 'exp:0')
def test_forward_log():
"""test operator Log """
np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
tf.reset_default_graph()
- in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
- tf.log(in_data, name="log")
- compare_tf_with_tvm([np_data], ['in_data:0'], 'log:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
+ tf.log(in_data, name="log")
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'log:0')
def test_forward_log1p():
"""test operator Log1p """
np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
tf.reset_default_graph()
- in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
- tf.log1p(in_data, name="log1p")
- compare_tf_with_tvm([np_data], ['in_data:0'], 'log1p:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
+ tf.log1p(in_data, name="log1p")
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'log1p:0')
def test_forward_cos():
"""test operator cos """
np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
tf.reset_default_graph()
- in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
- tf.cos(in_data, name="cos")
- compare_tf_with_tvm([np_data], ['in_data:0'], 'cos:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
+ tf.cos(in_data, name="cos")
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'cos:0')
def test_forward_tan():
"""test operator sin """
np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
tf.reset_default_graph()
- in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
- tf.sin(in_data, name="sin")
- compare_tf_with_tvm([np_data], ['in_data:0'], 'sin:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
+ tf.sin(in_data, name="sin")
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'sin:0')
def test_forward_negative():
np_data = np.random.uniform(-100, 255,
size=(224, 224, 3)).astype(np.float32)
tf.reset_default_graph()
- in_data = tf.placeholder(tf.float32, (224, 224, 3), name="in_data")
- tf.negative(in_data, name="negative")
- compare_tf_with_tvm([np_data], ['in_data:0'], 'negative:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(tf.float32, (224, 224, 3), name="in_data")
+ tf.negative(in_data, name="negative")
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'negative:0')
def test_forward_log_softmax():
"""test operator LogSoftmax"""
np_data = np.random.uniform(1, 100, size=(9, 11)).astype(np.float32)
tf.reset_default_graph()
- in_data = tf.placeholder(tf.float32, (9, 11), name="in_data")
- tf.math.log_softmax(in_data, name="LogSoftmax")
- compare_tf_with_tvm([np_data], ['in_data:0'], 'LogSoftmax:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(tf.float32, (9, 11), name="in_data")
+ tf.math.log_softmax(in_data, name="LogSoftmax")
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'LogSoftmax:0')
def test_forward_softplus():
"""test operator Softplus"""
np_data = np.random.uniform(1, 10, size=(2, 3, 5)).astype(np.float32)
tf.reset_default_graph()
- in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
- tf.nn.softplus(in_data, name="softplus")
- compare_tf_with_tvm([np_data], ['in_data:0'], 'softplus:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
+ tf.nn.softplus(in_data, name="softplus")
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'softplus:0')
def test_forward_rsqrt():
"""test Rsqrt """
np_data = np.random.uniform(1, 100, size=(5, 7, 11)).astype(np.float32)
tf.reset_default_graph()
- in_data = tf.placeholder(tf.float32, (5, 7, 11), name="in_data")
- tf.rsqrt(in_data, name="rsqrt")
- compare_tf_with_tvm([np_data], ['in_data:0'], 'rsqrt:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(tf.float32, (5, 7, 11), name="in_data")
+ tf.rsqrt(in_data, name="rsqrt")
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'rsqrt:0')
def test_forward_sqrt():
"""test Sqrt """
np_data = np.random.uniform(1, 100, size=(5, 7, 11)).astype(np.float32)
tf.reset_default_graph()
- in_data = tf.placeholder(tf.float32, (5, 7, 11), name="in_data")
- tf.sqrt(in_data, name="sqrt")
- compare_tf_with_tvm([np_data], ['in_data:0'], 'sqrt:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(tf.float32, (5, 7, 11), name="in_data")
+ tf.sqrt(in_data, name="sqrt")
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'sqrt:0')
def _test_forward_right_shift(in_shape, dtype):
lh_data = np.random.randint(1, 3, size=in_shape).astype(dtype)
rh_data = np.random.randint(1, 8, size=in_shape).astype(dtype)
tf.reset_default_graph()
- lft_data = tf.placeholder(dtype, in_shape, name="lft_data")
- rgt_data = tf.placeholder(dtype, in_shape, name="rgt_data")
- tf.bitwise.right_shift(lft_data, rgt_data, name="RightShift")
- compare_tf_with_tvm([lh_data, rh_data], [
- 'lft_data:0', 'rgt_data:0'], 'RightShift:0')
+ with tf.Graph().as_default():
+ lft_data = tf.placeholder(dtype, in_shape, name="lft_data")
+ rgt_data = tf.placeholder(dtype, in_shape, name="rgt_data")
+ tf.bitwise.right_shift(lft_data, rgt_data, name="RightShift")
+ compare_tf_with_tvm([lh_data, rh_data], [
+ 'lft_data:0', 'rgt_data:0'], 'RightShift:0')
def test_forward_right_shift():
lh_data = np.random.randint(100, 1000000, size=in_shape).astype(dtype)
rh_data = np.random.randint(1, 3, size=in_shape).astype(dtype)
tf.reset_default_graph()
- lft_data = tf.placeholder(dtype, in_shape, name="lft_data")
- rgt_data = tf.placeholder(dtype, in_shape, name="rgt_data")
- tf.bitwise.left_shift(lft_data, rgt_data, name="LeftShift")
- compare_tf_with_tvm([lh_data, rh_data], [
- 'lft_data:0', 'rgt_data:0'], 'LeftShift:0')
+ with tf.Graph().as_default():
+ lft_data = tf.placeholder(dtype, in_shape, name="lft_data")
+ rgt_data = tf.placeholder(dtype, in_shape, name="rgt_data")
+ tf.bitwise.left_shift(lft_data, rgt_data, name="LeftShift")
+ compare_tf_with_tvm([lh_data, rh_data], [
+ 'lft_data:0', 'rgt_data:0'], 'LeftShift:0')
def test_forward_left_shift():
"""Test the All operator."""
np_data = np.random.choice([True, False], size=(5, 7, 11))
tf.reset_default_graph()
- in_data = tf.placeholder(tf.bool, (5, 7, 11), name="in_data")
- tf.reduce_all(in_data, name="all")
- compare_tf_with_tvm([np_data], ['in_data:0'], 'all:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(tf.bool, (5, 7, 11), name="in_data")
+ tf.reduce_all(in_data, name="all")
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'all:0')
def test_forward_reduce_any():
"""Test the Any operator."""
np_data = np.random.choice([True, False], size=(5, 7, 11))
tf.reset_default_graph()
- in_data = tf.placeholder(tf.bool, (5, 7, 11), name="in_data")
- tf.reduce_any(in_data, name="any")
- compare_tf_with_tvm([np_data], ['in_data:0'], 'any:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(tf.bool, (5, 7, 11), name="in_data")
+ tf.reduce_any(in_data, name="any")
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'any:0')
def test_forward_reduce_max():
def check_max(ishape, axis, keepdims, dtype):
tf.reset_default_graph()
np_data = np.random.uniform(size=ishape).astype(dtype)
- in_data = tf.placeholder(dtype, name="in_data")
- tf.math.reduce_max(in_data, axis=axis,
- keepdims=keepdims, name="reduce_max")
- compare_tf_with_tvm([np_data], ['in_data:0'], 'reduce_max:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(dtype, name="in_data")
+ tf.math.reduce_max(in_data, axis=axis,
+ keepdims=keepdims, name="reduce_max")
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'reduce_max:0')
check_max((10, 8, 16, 32), axis=(-1), keepdims=True, dtype="int32")
check_max((10, 8, 16, 32), axis=(2, 3), keepdims=True, dtype="float32")
def check_min(ishape, axis, keepdims, dtype):
tf.reset_default_graph()
np_data = np.random.uniform(size=ishape).astype(dtype)
- in_data = tf.placeholder(dtype, name="in_data")
- tf.math.reduce_min(in_data, axis=axis,
- keepdims=keepdims, name="reduce_max")
- compare_tf_with_tvm([np_data], ['in_data:0'], 'reduce_max:0')
+ with tf.Graph().as_default():
+ in_data = tf.placeholder(dtype, name="in_data")
+ tf.math.reduce_min(in_data, axis=axis,
+ keepdims=keepdims, name="reduce_min")
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'reduce_min:0')
check_min((10, 8, 16, 32), axis=(-1), keepdims=True, dtype="int32")
check_min((10, 8, 16, 32), axis=(2, 3), keepdims=True, dtype="float32")
def _test_forward_expand_dims(data, axis):
- in1 = tf.placeholder(shape=data.shape, dtype=data.dtype, name='in1')
- out = tf.expand_dims(in1, axis)
- compare_tf_with_tvm([data], [in1.name], out.name)
+ with tf.Graph().as_default():
+ in1 = tf.placeholder(shape=data.shape, dtype=data.dtype, name='in1')
+ out = tf.expand_dims(in1, axis)
+ compare_tf_with_tvm([data], [in1.name], out.name)
def test_forward_expand_dims():
tf.reset_default_graph()
lh_data = np.random.uniform(size=lh_shape).astype(dtype)
rh_data = np.random.uniform(size=rh_shape).astype(dtype)
- lft_data = tf.placeholder(dtype, name="lft_data")
- rgt_data = tf.placeholder(dtype, name="rgt_data")
- tf.math.maximum(lft_data, rgt_data, name="maximum")
- compare_tf_with_tvm([lh_data, rh_data], [
- 'lft_data:0', 'rgt_data:0'], 'maximum:0')
+ with tf.Graph().as_default():
+ lft_data = tf.placeholder(dtype, name="lft_data")
+ rgt_data = tf.placeholder(dtype, name="rgt_data")
+ tf.math.maximum(lft_data, rgt_data, name="maximum")
+ compare_tf_with_tvm([lh_data, rh_data], [
+ 'lft_data:0', 'rgt_data:0'], 'maximum:0')
check_maximum((10, 8, 16, 32), (1,), dtype="int32")
check_maximum((10, 8, 16, 32), (10, 8, 16, 32), dtype="float32")
tf.reset_default_graph()
lh_data = np.random.uniform(size=lh_shape).astype(dtype)
rh_data = np.random.uniform(size=rh_shape).astype(dtype)
- lft_data = tf.placeholder(dtype, name="lft_data")
- rgt_data = tf.placeholder(dtype, name="rgt_data")
- tf.math.minimum(lft_data, rgt_data, name="minimum")
- compare_tf_with_tvm([lh_data, rh_data], [
- 'lft_data:0', 'rgt_data:0'], 'minimum:0')
+ with tf.Graph().as_default():
+ lft_data = tf.placeholder(dtype, name="lft_data")
+ rgt_data = tf.placeholder(dtype, name="rgt_data")
+ tf.math.minimum(lft_data, rgt_data, name="minimum")
+ compare_tf_with_tvm([lh_data, rh_data], [
+ 'lft_data:0', 'rgt_data:0'], 'minimum:0')
check_minimum((10, 8, 16, 32), (1,), dtype="int32")
check_minimum((10, 8, 16, 32), (10, 8, 16, 32), dtype="float32")
test_forward_ptb()
# RNN
- test_forward_lstm()
+ if package_version.parse(tf.VERSION) < package_version.parse('2.0.0'):
+ #in 2.0, tf.contrib.rnn.LSTMBlockCell is removed
+ test_forward_lstm()
# Elementwise
test_forward_ceil()