[Relay][TensorFlow Frontend] SoftPlus Sqrt (#3187)
authorYong Wu <ywu118@alumni.jh.edu>
Wed, 15 May 2019 05:42:34 +0000 (22:42 -0700)
committerYizhi Liu <liuyizhi@apache.org>
Wed, 15 May 2019 05:42:34 +0000 (22:42 -0700)
python/tvm/relay/frontend/tensorflow.py
tests/python/frontend/tensorflow/test_forward.py

index 48f7883..4bd78b4 100644 (file)
@@ -990,6 +990,16 @@ def _softmax():
                        transforms={'axis': ('axis', 1)})([inputs[0]], attr)
     return _impl
 
+def _softplus():
+    # op description: https://www.tensorflow.org/api_docs/python/tf/math/softplus
+    def _impl(inputs, attr, params):
+        exp_out = AttrCvt('exp')(inputs, attr)
+        inputs.append(tvm.relay.const(1, attr['T'].name))
+        rh = tvm.relay.const(1, attr['T'].name)
+        add_out = _get_relay_op('add')(exp_out, rh)
+        return _get_relay_op('log')(add_out)
+    return _impl
+
 def _logical(name):
     def _impl(inputs, attr, params):
         return AttrCvt(op_name=name)(inputs, attr)
@@ -1163,9 +1173,11 @@ _convert_map = {
     'Sign'                              : AttrCvt('sign'),
     'Slice'                             : _slice(),
     'Softmax'                           : _softmax(),
+    'Softplus'                          : _softplus(),
     'SpaceToBatchND'                    : _space_to_batch_nd(),
     'Split'                             : _split(False),
     'SplitV'                            : _split(True),
+    'Sqrt'                              : AttrCvt('sqrt'),
     'Square'                            : _square(),
     'Squeeze'                           : _squeeze(),
     'StridedSlice'                      : _stridedSlice(),
index 58bbdab..2f1cc2f 100644 (file)
@@ -1151,7 +1151,6 @@ def test_forward_placeholder():
             graph_def = tf_testing.AddShapesToGraphDef(sess, out_node)
             tf_output = run_tf_graph(sess, data, 'Placeholder:0', out_node + ':0')
             tvm_output = run_tvm_graph(graph_def, data, 'Placeholder')
-            print("tf_output is {}\ntvm_output is {}".format(tf_output, tvm_output))
             tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
 
 #######################################################################
@@ -1440,22 +1439,37 @@ def test_forward_pow_exp():
     compare_tf_with_tvm([np_in1], ['in1:0'], 'exp:0')
 
 def test_forward_log():
-    """test Log """
+    """test operator Log """
     np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
     tf.reset_default_graph()
     in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
     tf.log(in_data, name="log")
     compare_tf_with_tvm([np_data], ['in_data:0'], 'log:0')
 
+def test_forward_softplus():
+    """test operator Softplus"""
+    np_data = np.random.uniform(1, 10, size=(2, 3, 5)).astype(np.float32)
+    tf.reset_default_graph()
+    in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
+    tf.nn.softplus(in_data, name="softplus")
+    compare_tf_with_tvm([np_data], ['in_data:0'], 'softplus:0')
+
 def test_forward_rsqrt():
     """test Rsqrt """
     np_data = np.random.uniform(1, 100, size=(5, 7, 11)).astype(np.float32)
     tf.reset_default_graph()
     in_data = tf.placeholder(tf.float32, (5, 7, 11), name="in_data")
     tf.rsqrt(in_data, name="rsqrt")
-    print(tf.get_default_graph().as_graph_def())
     compare_tf_with_tvm([np_data], ['in_data:0'], 'rsqrt:0')
 
+def test_forward_sqrt():
+    """test Sqrt """
+    np_data = np.random.uniform(1, 100, size=(5, 7, 11)).astype(np.float32)
+    tf.reset_default_graph()
+    in_data = tf.placeholder(tf.float32, (5, 7, 11), name="in_data")
+    tf.sqrt(in_data, name="sqrt")
+    compare_tf_with_tvm([np_data], ['in_data:0'], 'sqrt:0')
+
 #######################################################################
 # Mean
 # ----
@@ -1561,6 +1575,8 @@ if __name__ == '__main__':
     test_forward_pow_exp()
     test_forward_sign()
     test_forward_log()
+    test_forward_softplus()
+    test_forward_sqrt()
     test_forward_rsqrt()
     test_forward_expand_dims()