  Expr VisitExpr_(const CallNode* op) final {
    if (const OpNode* op_node = op->op.as<OpNode>()) {
      Op op_ref = GetRef<Op>(op_node);
-      CHECK(rev_map.count(op_ref))
-        << op_node->name << " does not have reverse mode defined";
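+      // Gradient registration for every op is now validated up front by
+      // MissingGrad() in Gradient(), so no per-call check is needed here.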
      return LetList::With([&](LetList* ll) {
        std::vector<Var> args;
        for (const auto& arg : op->args) {
@@ ... @@
  return RefCreateNode::make(unitF);
}
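+
+// Returns true if `e` contains any operator with no FPrimalGradient
+// registered, logging the name of each offending operator.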
+bool MissingGrad(const Expr& e) {
+  struct MGVisitor : ExprVisitor {
+    const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
+    std::unordered_set<std::string> op_names;
+
+    void VisitExpr_(const OpNode* op) final {
+      Op op_ref = GetRef<Op>(op);
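+      // Collect the names of ops that lack a registered gradient.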
+      if (!rev_map.count(op_ref)) {
+        op_names.insert(op_ref->name);
+      }
+      ExprVisitor::VisitExpr_(op);
+    }
+  };
+
+  MGVisitor mg;
+  mg.VisitExpr(e);
+
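+  // Warn about every missing gradient at once, rather than failing on the
+  // first one encountered during reverse-mode AD.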
+  if (mg.op_names.size() > 0) {
+    LOG(WARNING) << "found operators with missing gradients:";
+    for (const auto& op : mg.op_names) {
+      LOG(WARNING) << "    " << op;
+    }
+    return true;
+  }
+
+  return false;
+}
+
Expr Gradient(const Expr& re, const Module& mod) {
  auto e = DeGlobal(mod, re);
  auto f = e.as<FunctionNode>();
  for (const auto& p : f->params) {
    CHECK(p->checked_type().as<TensorTypeNode>()) << "input parameters need to be tensor";
  }
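+  // Reject the input early; MissingGrad has already logged each offending op.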
+  CHECK(!MissingGrad(e)) << "input has operators with missing gradients";
  Expr body = LetList::With([&](LetList* ll) {
    Var bp = ll->Push(BPEmpty());
    Expr rev = ReverseAD(bp)(e);