Legalize - Use Non-recursive Rewriter. (#5296)
authorAnimesh Jain <anijain@umich.edu>
Fri, 10 Apr 2020 04:01:35 +0000 (21:01 -0700)
committerGitHub <noreply@github.com>
Fri, 10 Apr 2020 04:01:35 +0000 (21:01 -0700)
* Legalize - Use Non-recursive Rewriter.

* Cleanup.

include/tvm/relay/expr_functor.h
src/relay/transforms/legalize.cc

index 6f8ac69..04b2754 100644 (file)
@@ -330,7 +330,7 @@ class MixedModeMutator : public ::tvm::relay::ExprMutator {
  *
  *  ExprRewriter provides a Rewrite interface for modifying graphs in Post-DFS order.
  *
- *  The expectation is that ExprRewriter objects will be passed to PostOrderRewrite, which will
+ * The expectation is that ExprRewriter objects will be passed to PostOrderRewrite, which will
  * non-recursively unroll the graph and call Rewriting on inputs. It will then pass the original
  * node, called `pre`, and a node recreated with any alterned inputs, called `post`, to the
  * ExprRewriter. The ExprRewriter can then use the information in those two nodes to do more complex
@@ -408,7 +408,7 @@ class ExprRewriter {
 
 /*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes
  *
- *  PostOrderRewrite does a non-recursive traversal of the graph in Post-DFS order and calls the
+ * PostOrderRewrite does a non-recursive traversal of the graph in Post-DFS order and calls the
  * ExprRewriter's Rewrite functions on nodes once their inputs are rewritten. At each rewrite call,
  * PostOrderRewrite provides the original node and the node with altered inputs for use by the
  * ExprRewriter.
index 250dd69..01411a6 100644 (file)
@@ -35,19 +35,18 @@ namespace legalize {
 
 // Call registered FTVMLegalize of an op
 // Returns the legalized expression
-class Legalizer : public ExprMutator {
+class Legalizer : public ExprRewriter {
  public:
   explicit Legalizer(const std::string& legalize_map_attr_name)
       : legalize_map_attr_name_{legalize_map_attr_name} {}
 
-  Expr VisitExpr_(const CallNode* call_node) {
+  Expr Rewrite_(const CallNode* call_node, const Expr& post) override {
     // Get the new_call node without any changes to current call node.
-    Expr new_e = ExprMutator::VisitExpr_(call_node);
-    Call new_call = Downcast<Call>(new_e);
+    Call new_call = Downcast<Call>(post);
 
     // Check if the string is registered in the OpRegistry.
     if (!Op::HasAttr(legalize_map_attr_name_)) {
-      return new_e;
+      return post;
     }
 
     // Collect the registered legalize function.
@@ -70,19 +69,18 @@ class Legalizer : public ExprMutator {
         // Transform the op by calling the registered legalize function.
         Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types);
 
-        // Reassign new_e if the transformation succeeded.
+        // Return the new expr if the transformation succeeded.
         if (legalized_value.defined()) {
           // Check that the returned Expr from legalize is CallNode.
           const CallNode* legalized_call_node = legalized_value.as<CallNode>();
           CHECK(legalized_call_node)
               << "Can only replace the original operator with another call node";
-
-          new_e = legalized_value;
+          return legalized_value;
         }
       }
     }
 
-    return new_e;
+    return post;
   }
 
  private:
@@ -90,7 +88,8 @@ class Legalizer : public ExprMutator {
 };
 
 Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name) {
-  return Legalizer(legalize_map_attr_name).Mutate(expr);
+  auto rewriter = Legalizer(legalize_map_attr_name);
+  return PostOrderRewrite(expr, &rewriter);
 }
 
 }  // namespace legalize