[Relay] [Parser] fix parser for cast. (#3873)
author雾雨魔理沙 <lolisa@marisa.moe>
Mon, 2 Sep 2019 15:22:11 +0000 (08:22 -0700)
committerWuwei Lin <wuwei@apache.org>
Mon, 2 Sep 2019 15:22:11 +0000 (11:22 -0400)
* fix

* lint

python/tvm/relay/_parser.py
tests/python/relay/test_ir_text_printer.py

index 3e559df..f7024fe 100644 (file)
@@ -77,7 +77,8 @@ class ExprOp(OpWrapper):
         try:
             return expr.Call(self.operator, args, attrs, type_args)
         except Exception:
-            raise Exception(str(self.operator) + " " + str(attrs))
+            raise Exception("Operator {} is not registered. It's attributes are {}"
+                            .format(self.operator, attrs))
 
 class FuncOp(OpWrapper):
     """Convert the attrs, call the python function with the attrs passed in as keyword arguments.
@@ -132,6 +133,7 @@ FUNC_OPS = {
     "nn.dropout": op.nn.dropout_raw,
     "zeros": op.zeros,
     "split": op.split,
+    "cast": op.cast
 }
 
 TYPE_PREFIXES = [
index b55261c..c6f59d9 100644 (file)
@@ -169,19 +169,23 @@ def test_inception_v3():
     net, params = tvm.relay.testing.inception_v3.get_workload(batch_size=1)
     astext(net)
 
+
 def test_squeezenet():
     for version in ['1.0', '1.1']:
         net, params = tvm.relay.testing.squeezenet.get_workload(batch_size=1, version=version)
         astext(net)
 
+
 def test_vgg():
     net, params = tvm.relay.testing.vgg.get_workload(batch_size=1)
     astext(net)
 
+
 def test_densenet():
     net, params = tvm.relay.testing.densenet.get_workload(batch_size=1)
     astext(net)
 
+
 def test_call_node_order():
     x = relay.var("x")
     y = relay.var("y")
@@ -196,6 +200,7 @@ def test_call_node_order():
          "};\n"
          "%2(%1)")
 
+
 def test_let_inlining():
     tup = relay.Tuple([relay.const(0), relay.const(0)])
     x = relay.var("x")
@@ -208,10 +213,19 @@ def test_let_inlining():
         ("let %x = (0, 0);\n"
          "%x")
 
+
 def test_zeros():
     x = relay.op.zeros([], "float32")
     astext(x)
 
+
+def test_cast():
+    data = relay.var('data', dtype='float32')
+    fp16_cast = relay.cast(data, dtype='float16')
+    cast_func = relay.Function(relay.analysis.free_vars(fp16_cast), fp16_cast)
+    astext(cast_func)
+
+
 if __name__ == "__main__":
     do_print[0] = True
     test_lstm()
@@ -233,3 +247,4 @@ if __name__ == "__main__":
     test_let_if_scope()
     test_variable_name()
     test_call_node_order()
+    test_cast()