#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import
-from ..expr import const
+from ..expr import const, Tuple, TupleGetItem
from .op import register_gradient
from .transform import collapse_sum_like, broadcast_to_like, where
-from .tensor import exp, negative, power, less, cos, sin
+from .tensor import exp, negative, power, less, cos, sin, zeros_like
layout=attrs.layout, ceil_mode=attrs.ceil_mode,
count_include_pad=attrs.count_include_pad)
return [pool_grad]
+
+# Not a real implementation; this dummy gradient exists only for testing.
+@register_gradient("concatenate")
+def concatenate_grad(orig, grad):
+ assert len(orig.args) == 1
+ t = orig.args[0]
+ x = TupleGetItem(t, 0)
+ y = TupleGetItem(t, 1)
+    # Assume the tuple has exactly two elements for now.
+    # A real implementation probably needs a dedicated concatenate_grad operator.
+ return [Tuple([zeros_like(x), zeros_like(y)])]
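+
+
+# A hedged sketch (not registered) of what a real concatenate gradient could
+# look like: split the incoming output gradient back along the concat axis.
+# `_concatenate_grad_sketch` is hypothetical and assumes exactly two inputs
+# with equal extent along the axis; a real implementation would slice
+# according to the actual input shapes.
+def _concatenate_grad_sketch(orig, grad):
+    from .transform import split
+    axis = orig.attrs.axis.value
+    # split returns a TupleWrapper; indexing yields each piece.
+    pieces = split(grad, indices_or_sections=2, axis=axis)
+    return [Tuple([pieces[0], pieces[1]])]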
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
#include "pattern_util.h"
+#include "pass_util.h"
#include "let_list.h"
#include "../ir/type_functor.h"
}
};
+Type ReverseType(const Type& t) {
+ return ReverseADType()(t);
+}
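+
+// For example, ReverseType maps Tensor[float32] to the pair
+// (Tensor[float32], Ref[Tensor[float32]]) (the primal value together with a
+// mutable cell that accumulates its gradient), and maps tuple types
+// pointwise, so (a, b) becomes (ReverseType(a), ReverseType(b)).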
+
+/*! \brief Lift a function that transforms Tensors into a function that also
+ * transforms tuples of Tensors, by doing a structure-preserving map.
+ */
+Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
+ const Type& t,
+ const Expr& e,
+ LetList* ll) {
+ CHECK(IsAtomic(e)) << e;
+ if (t.as<TensorTypeNode>()) {
+ return f(e);
+ } else if (auto* tt = t.as<TupleTypeNode>()) {
+ tvm::Array<Expr> fields;
+ for (size_t i = 0; i < tt->fields.size(); ++i) {
+ fields.push_back(LiftTensor(f,
+ tt->fields[i],
+ ll->Push(GetField(e, i)),
+ ll));
+ }
+ return TupleNode::make(fields);
+ } else {
+    LOG(FATAL) << "unsupported input/output type: " << t;
+ throw;
+ }
+}
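+
+// For example, with f = (\x -> x.0) and t = (Tensor, Tensor), LiftTensor
+// produces the expression (e.0.0, e.1.0), modulo the intermediate
+// let-bindings pushed onto ll: the tuple structure of t is preserved and f
+// is applied at every tensor leaf.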
+
+/*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */
+Expr GetRev(const Type& t, const Expr& e, LetList* ll) {
+ auto rev = [&](const Expr& e) {
+ return Pair(e, ll->Push(RefCreateNode::make(ZerosLike(e))));
+ };
+ return LiftTensor(rev, t, e, ll);
+}
+
+/*! \brief ReverseType(t) -> t. Get the original value. */
+Expr GetValue(const Type& t, const Expr& e, LetList* ll) {
+ return LiftTensor([&](const Expr& e) { return GetField(e, 0); }, t, e, ll);
+}
+
+/*! \brief ReverseType(t) -> t. Get the gradient. */
+Expr GetGrad(const Type& t, const Expr& e, LetList* ll) {
+ auto grad = [&](const Expr& e) {
+ return ll->Push(RefReadNode::make(GetField(e, 1)));
+ };
+ return LiftTensor(grad, t, e, ll);
+}
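+
+// GetRev, GetValue and GetGrad are three instantiations of the same lift:
+// at each tensor leaf, GetRev pairs the primal value with a fresh
+// zero-initialized gradient cell, GetValue projects out field 0, and
+// GetGrad dereferences the gradient cell stored in field 1.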
+
+void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
+ if (t.as<TensorTypeNode>()) {
+ ll->Push(RefWriteNode::make(GetField(arg, 1),
+ Add(ll->Push(RefReadNode::make(GetField(arg, 1))),
+ grad)));
+ } else if (auto* tt = t.as<TupleTypeNode>()) {
+ for (size_t i = 0; i < tt->fields.size(); ++i) {
+ UpdateGrad(tt->fields[i],
+ ll->Push(GetField(arg, i)),
+ ll->Push(GetField(grad, i)),
+ ll);
+ }
+ } else {
+ LOG(FATAL) << "unsupported arg type of operator: " << t;
+ throw;
+ }
+}
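+
+// At a tensor leaf this emits ref_write(arg.1, ref_read(arg.1) + grad),
+// i.e. in-place gradient accumulation; for tuples it recurses
+// field-by-field so nested arguments are updated as well.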
+
struct ReverseAD : ExprMutator {
Var bp;
const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
- ReverseAD(const Var& bp) : bp(bp) { } /// NOLINT(*)
+ explicit ReverseAD(const Var& bp) : bp(bp) { }
Expr VisitExpr_(const OpNode* op) final {
LOG(FATAL) << "op should only be inside call";
args.push_back(ll->Push(VisitExpr(arg)));
}
std::vector<Expr> orig_args;
- for (const auto& arg : args) {
- orig_args.push_back(GetField(arg, 0));
+ for (size_t i = 0; i < args.size(); ++i) {
+ orig_args.push_back(GetValue(op->args[i]->checked_type(), args[i], ll));
}
Expr orig = CallNode::make(op->op, orig_args, op->attrs, op->type_args);
- Var orig_var = ll->Push(orig);
- auto ref = ll->Push(RefCreateNode::make(ZerosLike(orig_var)));
+ auto ret = ll->Push(GetRev(op->checked_type(), ll->Push(orig), ll));
auto bpv = ll->Push(RefReadNode::make(bp));
Expr nbp = FunctionNode::make(
{},
LetList::With([&](LetList* ll) {
- tvm::Array<Expr> rev = rev_map[op_ref](orig, ll->Push(RefReadNode::make(ref)));
- CHECK(args.size() == rev.size());
- for (size_t i = 0; i < args.size(); ++i) {
- ll->Push(RefWriteNode::make(GetField(args[i], 1),
- Add(ll->Push(RefReadNode::make(GetField(args[i], 1))),
- rev[i])));
- }
+ tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(op->checked_type(), ret, ll));
+ CHECK(args.size() == rev.size());
+ for (size_t i = 0; i < args.size(); ++i) {
+ UpdateGrad(op->args[i]->checked_type(), args[i], rev[i], ll);
+ }
return CallNode::make(bpv, {});
- }),
+ }),
TupleTypeNode::make({}),
{});
ll->Push(RefWriteNode::make(bp, nbp));
- return Pair(orig_var, ref);
+ return ret;
});
}
return ExprMutator::VisitExpr_(op);
}
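+  // Net effect of the call case above: a primitive call evaluates the
+  // original op on GetValue-ed arguments, wraps the result with GetRev
+  // (pairing it with fresh zero gradient cells), then replaces the backprop
+  // closure in `bp` with one that feeds the result's gradient (GetGrad) to
+  // the op's FPrimalGradient, accumulates each returned piece into the
+  // corresponding argument via UpdateGrad, and chains to the previous closure.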
Type VisitType(const Type& t) final {
- return t.defined() ? ReverseADType()(t) : t;
+ return t.defined() ? ReverseType(t) : t;
}
};
import tvm
from tvm import relay
-from tvm.relay.analysis import free_vars, free_type_vars
+from tvm.relay.analysis import free_vars, free_type_vars, assert_alpha_equal
from tvm.relay import create_executor, transform
from tvm.relay.transform import gradient
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, check_grad, rand
+import tvm.relay.op as op
def test_id():
tvm.testing.assert_allclose(grad.asnumpy(), 4 * np.ones_like(x.asnumpy()))
+def test_concat():
+ shape = (10, 10)
+ dtype = 'float32'
+ t = relay.TensorType(shape, dtype)
+ rt = relay.TensorType((10, 20), dtype)
+ x = relay.var("x", t)
+ y = op.concatenate([x, x], axis=1)
+ func = relay.Function([x], y)
+ func = run_infer_type(func)
+ back_func = run_infer_type(gradient(func))
+    assert_alpha_equal(back_func.checked_type,
+                       relay.FuncType([t], relay.TupleType([rt, relay.TupleType([t])])))
+    # No value validation, as concatenate only has a dummy gradient right now.
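+    # Once a real gradient is registered, value validation could look like
+    # this hedged sketch (deliberately left as a comment):
+    #   ex = create_executor()
+    #   data = rand(dtype, *shape)
+    #   forward, (grad_x,) = ex.evaluate(back_func)(data)
+    #   tvm.testing.assert_allclose(
+    #       forward.asnumpy(),
+    #       np.concatenate([data.asnumpy(), data.asnumpy()], axis=1))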
+
+
if __name__ == "__main__":
test_id()
test_add()