def _impl(inputs, attr, params):
# Tensorflow: (data, gamma, beta, moving_mean, moving_variance)
# Relay: (data, gamma, beta, moving_mean, moving_varience)
+ assert len(inputs) == 5
axis = 3
need_cast = False
if 'U' in attr:
need_cast = True
inputs[0] = _op.cast(inputs[0], dtype=attr['U'].name)
-
+ # Check if mean and variance are empty
+ # If so, replace them with Mean and Variance Ops
+ # For run-time calculation
+ moving_mean_shape = [int(n) for n in inputs[3].type_annotation.shape]
+ moving_variance_shape = [int(n) for n in inputs[4].type_annotation.shape]
+ if (moving_mean_shape[0] == 0 and moving_variance_shape[0] == 0):
+ inputs[3] = _op.mean(inputs[0], axis=axis, keepdims=False, exclude=True)
+ inputs[4] = _op.variance(inputs[0], axis=axis, keepdims=False, exclude=True)
out = AttrCvt(op_name='batch_norm',
transforms={'scale_after_normalization':'scale',
'variance_epsilon':'epsilon'},
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+BatchNorm without given mean and variance given testcases
+====================
+This is a test script to test fused_batch_norm operators
+in TensorFlow frontend when mean and variance are not given.
+"""
+import tvm
+import numpy as np
+import tensorflow as tf
+from tvm import relay
+from tensorflow.python.framework import graph_util
+
+def verify_fused_batch_norm(shape):
+ g = tf.Graph()
+ with g.as_default():
+ input_tensor = tf.placeholder(tf.float32, shape=shape, name='input')
+ alpha = tf.constant(np.random.rand(shape[-1],), dtype=tf.float32, name='alpha')
+ beta = tf.constant(np.random.rand(shape[-1],), dtype=tf.float32, name='beta')
+ bn = tf.nn.fused_batch_norm(x=input_tensor, offset=beta, scale=alpha, name='bn')
+ out = tf.identity(bn[0], name='output')
+ data = np.random.rand(*shape)
+ with tf.Session(graph=out.graph) as sess:
+ sess.run([tf.global_variables_initializer()])
+ tf_out = sess.run(out, feed_dict={input_tensor:data})
+ constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['output'])
+
+ for device in ["llvm"]:
+ ctx = tvm.context(device, 0)
+ if not ctx.exist:
+ print("Skip because %s is not enabled" % device)
+ continue
+ mod, params = relay.frontend.from_tensorflow(constant_graph,
+ outputs=['output'])
+ with relay.build_config(opt_level=3):
+ graph, lib, params = relay.build(mod,
+ target=device,
+ params=params)
+ from tvm.contrib import graph_runtime
+ m = graph_runtime.create(graph, lib, ctx)
+ m.set_input(**params)
+ m.set_input('input', data)
+ m.run()
+ tvm_out = m.get_output(0)
+ tvm.testing.assert_allclose(tvm_out.asnumpy(), tf_out.astype(tvm_out.dtype),
+ atol=1e-3, rtol=1e-3)
+
+def test_fused_batch_norm():
+ verify_fused_batch_norm(shape=(1, 12, 12, 32))
+ verify_fused_batch_norm(shape=(1, 24, 24, 64))
+ verify_fused_batch_norm(shape=(1, 64, 64, 128))
+ verify_fused_batch_norm(shape=(8, 12, 12, 32))
+ verify_fused_batch_norm(shape=(16, 12, 12, 32))
+ verify_fused_batch_norm(shape=(32, 12, 12, 32))
+
+if __name__ == "__main__":
+ test_fused_batch_norm()