add missing gradient check to gradient pass (#4169)
authorAltan Haan <altanh@cs.washington.edu>
Tue, 22 Oct 2019 06:13:55 +0000 (23:13 -0700)
committerJared Roesch <roeschinc@gmail.com>
Tue, 22 Oct 2019 06:13:55 +0000 (23:13 -0700)
src/relay/pass/gradient.cc

index 2606910..8b06b87 100644 (file)
@@ -351,8 +351,6 @@ struct ReverseAD : ExprMutator {
   Expr VisitExpr_(const CallNode* op) final {
     if (const OpNode* op_node = op->op.as<OpNode>()) {
       Op op_ref = GetRef<Op>(op_node);
-      CHECK(rev_map.count(op_ref))
-        << op_node->name << " does not have reverse mode defined";
       return LetList::With([&](LetList* ll) {
         std::vector<Var> args;
         for (const auto& arg : op->args) {
@@ -408,6 +406,34 @@ Expr BPEmpty() {
   return RefCreateNode::make(unitF);
 }
 
+bool MissingGrad(const Expr& e) {
+  struct MGVisitor : ExprVisitor {
+    const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
+    std::unordered_set<std::string> op_names;
+
+    void VisitExpr_(const OpNode* op) final {
+      Op op_ref = GetRef<Op>(op);
+      if (!rev_map.count(op_ref)) {
+        op_names.insert(op_ref->name);
+      }
+      ExprVisitor::VisitExpr_(op);
+    }
+  };
+
+  MGVisitor mg;
+  mg.VisitExpr(e);
+
+  if (mg.op_names.size() > 0) {
+    LOG(WARNING) << "found operators with missing gradients:";
+    for (const auto& op : mg.op_names) {
+      LOG(WARNING) << "    " << op;
+    }
+    return true;
+  }
+
+  return false;
+}
+
 Expr Gradient(const Expr& re, const Module& mod) {
   auto e = DeGlobal(mod, re);
   auto f = e.as<FunctionNode>();
@@ -416,6 +442,7 @@ Expr Gradient(const Expr& re, const Module& mod) {
   for (const auto& p : f->params) {
     CHECK(p->checked_type().as<TensorTypeNode>()) << "input parameters need to be tensor";
   }
+  CHECK(!MissingGrad(e)) << "input has operators with missing gradients";
   Expr body = LetList::With([&](LetList* ll) {
     Var bp = ll->Push(BPEmpty());
     Expr rev = ReverseAD(bp)(e);