"""
return _transform.ToANormalForm()
+def EtaExpand():
+ """Add abstraction over a function
+
+ Returns
+ -------
+ ret: tvm.relay.Pass
+ The registered pass that eta expands an expression.
+ """
+ return _transform.EtaExpand()
def ToGraphNormalForm():
"""Turn A Normal Form expression into Graph Normal Form expression
TVM_REGISTER_API("relay._ir_pass.eta_expand").set_body_typed(EtaExpand);
+namespace transform {
+
+Pass EtaExpand() {
+ runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+ [=](Function f, Module m, PassContext pc) {
+ return Downcast<Function>(EtaExpand(f, m));
+ };
+ return CreateFunctionPass(pass_func, 1, "EtaExpand", {});
+}
+
+TVM_REGISTER_API("relay._transform.EtaExpand")
+.set_body_typed(EtaExpand);
+
+} // namespace transform
+
} // namespace relay
} // namespace tvm
# specific language governing permissions and limitations
# under the License.
from tvm import relay
+import tvm.relay.module as _module
+import tvm.relay.transform as _transform
def test_eta_expand_basic():
- mod = relay.Module()
x = relay.var('x', 'int32')
- y = relay.var('y', 'int32')
orig = relay.Function([x], x)
- got = relay.ir_pass.eta_expand(orig, mod)
+ mod = _module.Module.from_expr(orig)
+ seq = _transform.Sequential([_transform.EtaExpand()])
+ with _transform.PassContext(opt_level=3):
+ mod = seq(mod)
+
+ got = mod[mod.entry_func.name_hint]
+
+ y = relay.var('y', 'int32')
expected = relay.Function([y], orig(y))
got = relay.ir_pass.infer_type(got, mod)