* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
return Pair(e, RefCreateNode::make(ZerosLike(e)));
}
+ Expr VisitExpr_(const IfNode* op) final {
+ return IfNode::make(TupleGetItemNode::make(VisitExpr(op->cond), 0),
+ VisitExpr(op->true_branch),
+ VisitExpr(op->false_branch));
+ }
+
Type VisitType(const Type& t) final {
return t.defined() ? ReverseADType()(t) : t;
}
tvm.testing.assert_allclose(grad_x.asnumpy(), 2 * np.ones_like(grad_x.asnumpy()))
+def test_if():
+ x = relay.var("x", shape=(1, 16, 64, 64))
+ y = relay.var("y", shape=(1, 16, 64, 64))
+ cond = relay.var("cond", shape=(), dtype='uint1')
+ net = relay.If(cond, x, y)
+ net = relay.log(net)
+ net = relay.ir_pass.infer_type(relay.Function(relay.ir_pass.free_vars(net), net))
+ back_func = relay.ir_pass.infer_type(relay.ir_pass.gradient(net, mode='higher_order'))
+
+
if __name__ == "__main__":
test_id()
test_add()
test_pow()
test_ref()
test_square_second_order()
+ test_if()