[Relay] Make check stricter: disallow inserting function with free vars into module...
author雾雨魔理沙 <lolisa@marisa.moe>
Sat, 22 Aug 2020 05:11:50 +0000 (22:11 -0700)
committerGitHub <noreply@github.com>
Sat, 22 Aug 2020 05:11:50 +0000 (22:11 -0700)
* save

lint

lint

fix test

fix test

* fix

python/tvm/relay/prelude.py
src/ir/module.cc
tests/python/frontend/tensorflow/test_forward.py
tests/python/relay/test_type_infer.py
tests/python/relay/test_vm_serialization.py
tutorials/dev/use_pass_infra.py

index 5b2ecc2..1b7ed77 100644 (file)
@@ -25,7 +25,7 @@ from .op.tensor import add, subtract, equal
 from .adt import Constructor, TypeData, Clause, Match
 from .adt import PatternConstructor, PatternVar, PatternWildcard
 from . import op, transform
-
+from .analysis import free_vars
 
 def get_tensor_array_shape(expr, dtype, prelude):
     """Get the static shape of a tensor array if it has fixed rank shape.
@@ -51,7 +51,7 @@ def get_tensor_array_shape(expr, dtype, prelude):
         has dynamic shape.
     """
     mod = prelude.mod
-    mod["main"] = Function([], expr)
+    mod["main"] = Function(free_vars(expr), expr)
     mod = transform.InferType()(mod)
     checked_type = mod["main"].body.checked_type
     assert isinstance(checked_type, TypeCall), "Input must be a tensor array."
index b347408..bcab39a 100644 (file)
@@ -189,16 +189,10 @@ relay::Function RunTypeCheck(const IRModule& mod, const GlobalVar& var, relay::F
   // Type check the item before we add it to the module.
   auto fv = relay::FreeVars(func);
   auto ftv = relay::FreeTypeVars(func, mod);
-  if (fv.size() != 0) {
-    LOG(WARNING) << "There are free variables: " << fv << " in function: " << AsText(func, false)
-                 << std::endl;
-  }
-  if (ftv.size() != 0) {
-    LOG(WARNING) << "There are free type variables: " << ftv
-                 << " in function: " << AsText(func, false) << std::endl;
-  }
-  func = relay::Function(concat(func->params, fv), func->body, func->ret_type,
-                         concat(func->type_params, ftv), func->attrs);
+  CHECK_EQ(fv.size(), 0) << "There are free variables: " << fv
+                         << " in function: " << AsText(func, false);
+  CHECK_EQ(ftv.size(), 0) << "There are free type variables: " << fv
+                          << " in function: " << AsText(func, false);
   // Type check the item before we add it to the module.
   relay::Function checked_func = InferType(func, mod, var);
   return checked_func;
index e6f1687..799d9c2 100644 (file)
@@ -3852,143 +3852,5 @@ def test_forward_dynmaic_rnn_lstmblockcell():
                 tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
 
 
-#######################################################################
-# Main
-# ----
 if __name__ == '__main__':
-    # Transforms
-    test_forward_slice()
-    test_forward_transpose()
-    test_forward_reshape()
-    test_forward_depthtospace()
-    test_forward_spacetodepth()
-    test_forward_squeeze()
-    test_forward_pack()
-    test_forward_size()
-    test_forward_broadcast_to()
-    test_forward_fill()
-    test_forward_crop()
-    test_forward_resize()
-    test_forward_crop_and_resize()
-    test_forward_pad()
-    test_forward_unpack()
-    test_forward_gather()
-    test_forward_gather_nd()
-    test_forward_stridedslice()
-    test_forward_split()
-    test_forward_unstack()
-    test_forward_tile()
-    test_forward_top_k_v2()
-    test_forward_clip_by_value()
-    test_forward_maximum()
-    test_forward_minimum()
-    test_forward_range()
-    test_forward_right_shift()
-    test_forward_left_shift()
-    test_forward_truncatemod()
-    test_forward_one_hot()
-    test_forward_atan2()
-    test_forward_nms()
-
-    # Activations
-    test_forward_sigmoid()
-    test_forward_relu()
-    test_forward_leaky_relu()
-    test_forward_elu()
-    test_forward_selu()
-    test_forward_tanh()
-
-    # Tensor
-    test_forward_round()
-    test_forward_reverse_v2()
-    test_forward_pow_exp()
-    test_forward_sign()
-    test_forward_negative()
-    test_forward_divide()
-    test_forward_abs()
-    test_forward_softplus()
-    test_forward_sqrt()
-    test_forward_rsqrt()
-    test_forward_expand_dims()
-    test_forward_square()
-    test_forward_softmax()
-    test_forward_log_softmax()
-    test_forward_bias_add()
-    test_forward_zeros_like()
-    test_forward_squared_difference()
-    test_forward_add_n()
-    test_forward_floormod()
-    test_forward_isfinite()
-    test_forward_isinf()
-    test_forward_unravel_index()
-    test_forward_unary()
-
-    # Reductions
-    test_forward_argminmax()
-    test_forward_reduce()
-    test_forward_mean()
-
-    # TensorArray
-    test_tensor_array_write_read()
-    test_tensor_array_concat()
-    test_tensor_array_scatter()
-    test_tensor_array_gather()
-    test_tensor_array_size()
-    test_tensor_array_split()
-    test_tensor_array_stack()
-    test_tensor_array_unstack()
-
-    # General
-    test_forward_multi_input()
-    test_forward_multi_output()
-    test_forward_variable()
-    test_placeholder()
-
-    # NN
-    test_forward_convolution()
-    test_forward_convolution3d()
-    test_forward_convolution3d_transpose()
-    test_forward_pooling()
-    test_forward_concat_v2()
-    test_forward_lrn()
-    test_forward_l2_normalize()
-    test_forward_space_to_batch_nd()
-    test_forward_batch_to_space_nd()
-    test_forward_dilation()
-
-    # End to End
-    test_forward_inception_v3()
-    test_forward_inception_v1()
-    test_forward_mobilenet()
-    test_forward_resnetv2()
-    test_forward_ssd()
-    test_forward_placeholder()
-    test_forward_ptb()
-
-    # RNN
-    test_forward_lstm()
-
-    # Elementwise
-    test_forward_ceil()
-    test_forward_floor()
-
-    # Relational ops
-    test_forward_rel_ops()
-    test_forward_logical()
-    test_forward_where()
-    test_forward_matmul()
-    test_forward_batch_matmul()
-
-    # Internal misc. ops
-    test_read_variable_op()
-
-    # Sharing params case using Mean ops
-    test_sharing_node()
-
-    # StatefulPartitionedCall
-    test_forward_spop()
-
-    # Test dynamic input shape
-    test_forward_dynamic_input_shape()
-
-    test_forward_dynmaic_rnn_lstmblockcell()
+    pytest.main([__file__])
index e5082db..cc4748c 100644 (file)
 """Test that type checker correcly computes types
    for expressions.
 """
+import pytest
 import tvm
 from tvm import te
 from tvm import relay
 from tvm.relay import op, transform, analysis
 from tvm.relay import Any
 
-
 def run_infer_type(expr, mod=None):
     if not mod:
         mod = tvm.IRModule.from_expr(expr)
@@ -368,26 +368,9 @@ def test_if():
     f = relay.Var('f', choice_t)
     true_branch = relay.Var('True', relay.TensorType([Any(), 1], dtype='float32'))
     false_branch = relay.Var('False', relay.TensorType([Any(), Any()], dtype='float32'))
-    top = relay.Function([true_branch, false_branch], relay.If(f(), true_branch, false_branch))
+    top = relay.Function([f, true_branch, false_branch], relay.If(f(), true_branch, false_branch))
     ft = run_infer_type(top)
     tvm.ir.assert_structural_equal(ft.ret_type, relay.TensorType([Any(), 1], dtype='float32'))
 
 if __name__ == "__main__":
-    test_free_expr()
-    test_dual_op()
-    test_single_op()
-    test_recursion()
-    test_monomorphic_let()
-    test_decl()
-    test_recursion()
-    test_tuple()
-    test_incomplete_call()
-    test_type_args()
-    test_global_var_recursion()
-    test_equal()
-    test_ref()
-    test_constructor_type()
-    test_constructor_call()
-    test_adt_match()
-    test_let_polymorphism()
-    test_if()
+    pytest.main([__file__])
index 2aae431..df3bbc1 100644 (file)
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name, missing-docstring, no-else-return
 """Unit tests for the Relay VM serialization and deserialization."""
+import pytest
 import numpy as np
 
 import tvm
@@ -291,22 +292,11 @@ def test_vm_shape_of():
 
     newshape_var = relay.var('newshape', shape=(2,), dtype='int64')
     args.append(np.array((1, -1), dtype='int64'))
-    main = relay.reshape(relu_x, newshape=newshape_var)
+    main = relay.Function([x, newshape_var], relay.reshape(relu_x, newshape=newshape_var))
 
     res = get_serialized_output(main, *args).asnumpy()
     tvm.testing.assert_allclose(res.flatten(), data.flatten())
 
 
 if __name__ == "__main__":
-    test_serializer()
-    test_save_load()
-    test_const()
-    test_if()
-    test_loop()
-    test_tuple()
-    test_adt_list()
-    test_adt_compose()
-    test_closure()
-    test_synthetic()
-    test_mobilenet()
-    test_vm_shape_of()
+    pytest.main([__file__])
index 8212334..4b842b9 100644 (file)
@@ -65,7 +65,7 @@ def example():
     z = relay.add(y, c)
     z1 = relay.add(y, c)
     z2 = relay.add(z, z1)
-    return relay.Function([x], z2)
+    return relay.Function([x, weight], z2)
 
 ###############################################################################
 # Let us register layout alteration for a conv2d op so that we can apply the