Expr VisitExpr_(const CallNode* call) final {
static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
+
+ std::unordered_set<std::string> skip_list{"zeros_like", "ones_like", "full_like", "full"};
+
auto origin_args = call->args;
Expr res = ExprMutator::VisitExpr_(call);
call = res.as<CallNode>();
if (call->args.size() == 0) return res;
const OpNode* op = call->op.as<OpNode>();
if (op == nullptr) return res;
+ if (skip_list.count(op->name)) {
+ return res;
+ }
// skip stateful ops.
if (op_stateful.get(GetRef<Op>(op), false)) return res;
// Try to evaluate shape_of op
assert relay.analysis.graph_equal(zz, zexpected)
+def test_fold_full():
+ c_shape = (8, 9, 10)
+ def before():
+ dtype = 'float32'
+ return relay.full(relay.const(1.0, dtype), c_shape, dtype=dtype)
+
+ def expected():
+ # expect no changes
+ return before()
+
+ zz = run_opt_pass(before(), transform.FoldConstant())
+ zexpected = run_opt_pass(expected(), transform.InferType())
+ assert relay.analysis.graph_equal(zz, zexpected)
+
+
if __name__ == "__main__":
test_fold_const()
test_fold_let()
test_fold_tuple()
test_fold_concat()
test_fold_shape_of()
+ test_fold_full()