Add EtaExpand to transform API (#3406)
authorWei Chen <ipondering.weic@gmail.com>
Thu, 20 Jun 2019 20:41:41 +0000 (13:41 -0700)
committerHaichen Shen <shenhaichen@gmail.com>
Thu, 20 Jun 2019 20:41:41 +0000 (13:41 -0700)
* Add EtaExpand to transform API

* Add test case

python/tvm/relay/transform.py
src/relay/pass/eta_expand.cc
tests/python/relay/test_pass_eta_expand.py

index 3fae615..5f47e5b 100644 (file)
@@ -406,6 +406,15 @@ def ToANormalForm():
     """
     return _transform.ToANormalForm()
 
+def EtaExpand():
+    """Add abstraction over a function
+
+    Returns
+    -------
+    ret: tvm.relay.Pass
+        The registered pass that eta expands an expression.
+    """
+    return _transform.EtaExpand()
 
 def ToGraphNormalForm():
     """Turn A Normal Form expression into Graph Normal Form expression
index 0193b9a..3139d41 100644 (file)
@@ -67,5 +67,20 @@ Expr EtaExpand(const Expr& e, const Module& mod) {
 
 TVM_REGISTER_API("relay._ir_pass.eta_expand").set_body_typed(EtaExpand);
 
+namespace transform {
+
+Pass EtaExpand() {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+    return Downcast<Function>(EtaExpand(f, m));
+  };
+  return CreateFunctionPass(pass_func, 1, "EtaExpand", {});
+}
+
+TVM_REGISTER_API("relay._transform.EtaExpand")
+.set_body_typed(EtaExpand);
+
+}  // namespace transform
+
 }  // namespace relay
 }  // namespace tvm
index 40a8428..4e20b02 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 from tvm import relay
+import tvm.relay.module as _module
+import tvm.relay.transform as _transform
 
 def test_eta_expand_basic():
-    mod = relay.Module()
     x = relay.var('x', 'int32')
-    y = relay.var('y', 'int32')
     orig = relay.Function([x], x)
-    got = relay.ir_pass.eta_expand(orig, mod)
+    mod = _module.Module.from_expr(orig)
+    seq = _transform.Sequential([_transform.EtaExpand()])
+    with _transform.PassContext(opt_level=3):
+        mod = seq(mod)
+
+    got = mod[mod.entry_func.name_hint]
+
+    y = relay.var('y', 'int32')
     expected = relay.Function([y], orig(y))
 
     got = relay.ir_pass.infer_type(got, mod)