[Relay] Fix ad for conditional expression (#3453)
author雾雨魔理沙 <lolisa@marisa.moe>
Fri, 28 Jun 2019 09:09:17 +0000 (02:09 -0700)
committerWuwei Lin <wuwelin@amazon.com>
Fri, 28 Jun 2019 09:09:17 +0000 (17:09 +0800)
* save

* fix

src/relay/pass/gradient.cc
tests/python/relay/test_pass_gradient.py

index 91072b3..5d26f7a 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -311,6 +311,12 @@ struct ReverseAD : ExprMutator {
     return Pair(e, RefCreateNode::make(ZerosLike(e)));
   }
 
+  Expr VisitExpr_(const IfNode* op) final {
+    return IfNode::make(TupleGetItemNode::make(VisitExpr(op->cond), 0),
+                        VisitExpr(op->true_branch),
+                        VisitExpr(op->false_branch));
+  }
+
   Type VisitType(const Type& t) final {
     return t.defined() ? ReverseADType()(t) : t;
   }
index d99bee5..6fece1b 100644 (file)
@@ -231,6 +231,16 @@ def test_square_second_order():
     tvm.testing.assert_allclose(grad_x.asnumpy(), 2 * np.ones_like(grad_x.asnumpy()))
 
 
+def test_if():
+    x = relay.var("x", shape=(1, 16, 64, 64))
+    y = relay.var("y", shape=(1, 16, 64, 64))
+    cond = relay.var("cond", shape=(), dtype='uint1')
+    net = relay.If(cond, x, y)
+    net = relay.log(net)
+    net = relay.ir_pass.infer_type(relay.Function(relay.ir_pass.free_vars(net), net))
+    back_func = relay.ir_pass.infer_type(relay.ir_pass.gradient(net, mode='higher_order'))
+
+
 if __name__ == "__main__":
     test_id()
     test_add()
@@ -242,3 +252,4 @@ if __name__ == "__main__":
     test_pow()
     test_ref()
     test_square_second_order()
+    test_if()