np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
- def test_forgot_kernel_arena(self):
- self.assertRaises(RuntimeError, lambda: torch._C._te.VarHandle("n", torch._C._te.Dtype.Int))
-
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
def test_alloc_in_loop(self):
with kernel_arena_scope():
return Value(result_v);
}
- template <typename Op>
- void visit_binary_op(BinaryOpNode<Op>* v, bool option = false) {
+ template <
+ typename D,
+ typename std::enable_if<std::is_same<
+ decltype(detail::bin_op_deducer(std::declval<D>())),
+ void>::value>::type* = nullptr>
+ void visit_binary_op(NodePtr<D> v, bool option = false) {
v->lhs()->accept(this);
Value lhs_v = value_;
v->rhs()->accept(this);
};
// The common base class of all expression nodes.
-class TORCH_API Expr : public KernelScopedObject {
+class TORCH_API Expr : public std::enable_shared_from_this<Expr> {
public:
explicit Expr(Dtype dtype, IRNodeType expr_type = kOther)
: dtype_(dtype), expr_type_(expr_type) {}
+ virtual ~Expr() = default;
Dtype dtype() const {
return dtype_;
}
*/
static ExprPtr clone(ExprPtr s);
+ protected:
+ std::shared_ptr<Expr> getptr() {
+ return shared_from_this();
+ }
+
private:
Dtype dtype_;
IRNodeType expr_type_;
public:
using ExprNodeBase = ExprNode<Op>;
void accept(IRVisitor* visitor) override {
- visitor->visit(static_to<Op>(this));
+ visitor->visit(static_to<Op>(Base::getptr()));
}
ExprPtr accept_mutator(IRMutator* mutator) override;
// pass the constructor to the base class
template <class Op, class Base>
ExprPtr ExprNode<Op, Base>::accept_mutator(IRMutator* mutator) {
- return mutator->mutate(static_to<Op>(this));
+ return mutator->mutate(static_to<Op>(Base::getptr()));
}
inline bool same_node(const ExprHandle& expr1, const ExprHandle& expr2) {
#pragma once
#include <c10/core/ScalarType.h>
+#include <memory>
namespace torch {
namespace jit {
namespace tensorexpr {
template <typename Node>
-using NodePtr = Node*;
+using NodePtr = std::shared_ptr<Node>;
template <typename To, typename From>
NodePtr<To> to(NodePtr<From> x) {
- return dynamic_cast<NodePtr<To>>(x);
+ return std::dynamic_pointer_cast<To>(x);
}
template <typename To, typename From>
NodePtr<To> static_to(NodePtr<From> x) {
- return static_cast<NodePtr<To>>(x);
+ return std::static_pointer_cast<To>(x);
}
template <typename Node, typename... Args>
NodePtr<Node> alloc(Args&&... args) {
- return new Node(std::forward<Args>(args)...);
+ return std::make_shared<Node>(std::forward<Args>(args)...);
}
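// A minimal usage sketch of the aliases above. DemoExpr and DemoAdd are
// hypothetical stand-ins rather than real IR classes: alloc<> wraps
// make_shared, to<> is a checked downcast that yields nullptr on a type
// mismatch, and a node is destroyed when the last NodePtr to it goes away.
struct DemoExpr {
  virtual ~DemoExpr() = default;
};
struct DemoAdd : public DemoExpr {
  int lhs = 1, rhs = 2;
};
inline void node_ptr_demo() {
  NodePtr<DemoExpr> e = alloc<DemoAdd>();     // std::make_shared<DemoAdd>()
  if (NodePtr<DemoAdd> a = to<DemoAdd>(e)) {  // std::dynamic_pointer_cast
    int sum = a->lhs + a->rhs;                // sum == 3
    (void)sum;
  }
}  // e (and a) release the node here; no arena or manual delete is involved.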
class Buf;
putHash(v, hash_combine(hashOf(v->lhs()), "%", hashOf(v->rhs())));
}
+void HashProvider::visit(RoundOffPtr v) {
+ CACHE_GUARD();
+ v->lhs()->accept(this);
+ v->rhs()->accept(this);
+ putHash(v, hash_combine(hashOf(v->lhs()), "rof", hashOf(v->rhs())));
+}
+
void HashProvider::visit(MaxPtr v) {
CACHE_GUARD();
v->lhs()->accept(this);
return hashOf(e);
}
- bool cachedHash(const KernelScopedObject* e) {
+ bool cachedHash(ExprPtr e) {
return exprToHash_.find(e) != exprToHash_.end();
}
+ bool cachedHash(StmtPtr s) {
+ return stmtToHash_.find(s) != stmtToHash_.end();
+ }
void clearCache() {
exprToHash_.clear();
+ stmtToHash_.clear();
}
void visit(AddPtr v) override;
void visit(MulPtr v) override;
void visit(DivPtr v) override;
void visit(ModPtr v) override;
+ void visit(RoundOffPtr v) override;
void visit(MaxPtr v) override;
void visit(MinPtr v) override;
void visit(AndPtr v) override;
}
SimplifierHashType hashOf(StmtPtr s) {
- auto it = exprToHash_.find(s);
- if (it != exprToHash_.end()) {
+ auto it = stmtToHash_.find(s);
+ if (it != stmtToHash_.end()) {
return it->second;
}
_hash_combine(seed, args...);
}
- void putHash(const KernelScopedObject* e, SimplifierHashType h) {
+ void putHash(ExprPtr e, SimplifierHashType h) {
auto res = exprToHash_.emplace(e, h);
if (res.second == false) {
// This is always a logic bug since we should check the cache first.
throw std::runtime_error("hash collision");
}
}
+ void putHash(StmtPtr s, SimplifierHashType h) {
+ auto res = stmtToHash_.emplace(s, h);
+ if (res.second == false) {
+ // This is always a logic bug since we should check the cache first.
+ throw std::runtime_error("hash collision");
+ }
+ }
- std::unordered_map<const KernelScopedObject*, SimplifierHashType> exprToHash_;
+ std::unordered_map<ExprPtr, SimplifierHashType> exprToHash_;
+ std::unordered_map<StmtPtr, SimplifierHashType> stmtToHash_;
UniqueNameManager name_manager_;
size_t te_hash(SimplifierHashType val) {
ExprPtr rhs_;
};
+namespace detail {
+template <typename T>
+void bin_op_deducer(BinaryOpNode<T>);
+bool bin_op_deducer(...);
+} // namespace detail
+
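+// How the deducer gates these templates: if D derives from BinaryOpNode<T>
+// for some T, the first overload is viable (T is deduced through the
+// derived-to-base conversion) and the call expression has type void;
+// otherwise the variadic fallback wins and the type is bool, so the
+// enable_if constraint admits only binary-op nodes. A hypothetical spot
+// check, assuming <type_traits> and <utility> are in scope (DemoVar stands
+// in for any non-binary node):
+struct DemoVar {};
+static_assert(
+    std::is_same<
+        decltype(detail::bin_op_deducer(std::declval<DemoVar>())),
+        bool>::value,
+    "non-binary nodes fall through to the variadic overload");
+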
class TORCH_API Add : public BinaryOpNode<Add> {
public:
Add(ExprPtr lhs, ExprPtr rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kAdd) {}
namespace jit {
namespace tensorexpr {
-template <typename Op>
+template <
+ typename Op,
+ typename std::enable_if<std::is_same<
+ decltype(detail::bin_op_deducer(std::declval<Op>())),
+ void>::value>::type* = nullptr>
static ExprPtr mutate_binary_op(
- NodePtr<BinaryOpNode<Op>> v,
+ NodePtr<Op> v,
IRCloner* cloner,
bool option = false) {
ExprPtr lhs_new = v->lhs()->accept_mutator(cloner);
namespace jit {
namespace tensorexpr {
-template <typename Op>
+template <
+ typename Op,
+ typename std::enable_if<std::is_same<
+ decltype(detail::bin_op_deducer(std::declval<Op>())),
+ void>::value>::type* = nullptr>
static ExprPtr mutate_binary_op(
- BinaryOpNode<Op>* v,
+ NodePtr<Op> v,
IRMutator* mutator,
bool option = false) {
ExprPtr lhs = v->lhs();
// TODO: let the parent expression decide whether to add parentheses;
// we should consult operator precedence to keep the output simpler.
-template <typename Op>
+template <
+ typename Op,
+ typename std::enable_if<std::is_same<
+ decltype(detail::bin_op_deducer(std::declval<Op>())),
+ void>::value>::type* = nullptr>
void visitBinaryOp(
- BinaryOpNode<Op>* v,
+ NodePtr<Op> v,
const std::string& op_str,
IRPrinter* printer,
bool parens = true) {
namespace jit {
namespace tensorexpr {
+// Creates a new Expr of the given type with the provided lhs and rhs.
+inline ExprPtr newBinaryOpOfType(
+ IRNodeType expr_type,
+ ExprPtr lhs,
+ ExprPtr rhs,
+ bool option) {
+ switch (expr_type) {
+ // NOLINTNEXTLINE(bugprone-branch-clone)
+ case IRNodeType::kAdd:
+ return alloc<Add>(lhs, rhs);
+ case IRNodeType::kSub:
+ return alloc<Sub>(lhs, rhs);
+ case IRNodeType::kMul:
+ return alloc<Mul>(lhs, rhs);
+ case IRNodeType::kDiv:
+ return alloc<Div>(lhs, rhs);
+ case IRNodeType::kMod:
+ return alloc<Mod>(lhs, rhs);
+ case IRNodeType::kMax:
+ return alloc<Max>(lhs, rhs, option);
+ case IRNodeType::kMin:
+ return alloc<Min>(lhs, rhs, option);
+ case IRNodeType::kAnd:
+ return alloc<And>(lhs, rhs);
+ case IRNodeType::kXor:
+ return alloc<Xor>(lhs, rhs);
+ case IRNodeType::kLshift:
+ return alloc<Lshift>(lhs, rhs);
+ case IRNodeType::kRshift:
+ return alloc<Rshift>(lhs, rhs);
+ default:
+ LOG(FATAL) << "unsupported expr_type: " << static_cast<int>(expr_type);
+ return nullptr;
+ }
+}
+
+template <
+ typename Op,
+ typename std::enable_if<std::is_same<
+ decltype(detail::bin_op_deducer(std::declval<Op>())),
+ void>::value>::type* = nullptr>
+static ExprPtr mutateBinaryOp(
+ NodePtr<Op> v,
+ IRMutator* mutator,
+ bool option = false) {
+ ExprPtr lhs = v->lhs();
+ ExprPtr rhs = v->rhs();
+ ExprPtr lhs_new = lhs->accept_mutator(mutator);
+ ExprPtr rhs_new = rhs->accept_mutator(mutator);
+
+ ExprPtr node = v;
+
+ if (lhs != lhs_new || rhs != rhs_new) {
+ node = newBinaryOpOfType(v->expr_type(), lhs_new, rhs_new, option);
+ }
+
+ // Can only fold if both sides are constant.
+ if (!lhs_new->isConstant() || !rhs_new->isConstant()) {
+ return node;
+ }
+
+ return evaluateOp(node);
+}
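+// End to end this performs constant folding: with both operands already
+// immediates, e.g. Xor(3, 5), the node is rebuilt if its children changed
+// and then collapsed by evaluateOp into the immediate 6.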
+
// Simple recursive GCD.
template <typename T>
T gcd(T a, T b) {
return alloc<IfThenElse>(condition_new, true_value_new, false_value_new);
}
+ExprPtr PolynomialTransformer::mutate(AndPtr v) {
+ return mutateBinaryOp(v, this);
+}
+
+ExprPtr PolynomialTransformer::mutate(XorPtr v) {
+ return mutateBinaryOp(v, this);
+}
+
+ExprPtr PolynomialTransformer::mutate(LshiftPtr v) {
+ return mutateBinaryOp(v, this);
+}
+
+ExprPtr PolynomialTransformer::mutate(RshiftPtr v) {
+ return mutateBinaryOp(v, this);
+}
+
StmtPtr PolynomialBase::mutate(CondPtr v) {
ExprPtr cond_old = v->condition();
StmtPtr true_old = v->true_stmt();
scalar = getImmediateByType(multiplier->dtype(), 1);
}
+ // TODO: this leaks memory!
return new ModRound(scalar, denom, divisor, mod_divisor);
}
template <class ExprType>
Dtype promoteTypesMap(
ExprPtr s,
- std::unordered_map<SimplifierHashType, ExprType*>& m) {
+ std::unordered_map<SimplifierHashType, ExprType>& m) {
Dtype t = s->dtype();
bool first = true;
for (auto& e : m) {
}
template <class ExprType>
-Dtype promoteTypesVar(ExprType* e) {
+Dtype promoteTypesVar(ExprType e) {
return e->dtype();
}
template <class ExprType, class... Args>
-Dtype promoteTypesVar(ExprType* e, Args... es) {
+Dtype promoteTypesVar(ExprType e, Args... es) {
Dtype lhs = e->dtype();
Dtype rhs = promoteTypesVar(es...);
if (e->isConstant()) {
return promoteTypes(lhs, rhs);
}
-// Creates a new Expr of the given type with the provided lhs and rhs.
-inline ExprPtr newBinaryOpOfType(
- IRNodeType expr_type,
- ExprPtr lhs,
- ExprPtr rhs,
- bool option) {
- switch (expr_type) {
- // NOLINTNEXTLINE(bugprone-branch-clone)
- case IRNodeType::kAdd:
- return alloc<Add>(lhs, rhs);
- case IRNodeType::kSub:
- return alloc<Sub>(lhs, rhs);
- case IRNodeType::kMul:
- return alloc<Mul>(lhs, rhs);
- case IRNodeType::kDiv:
- return alloc<Div>(lhs, rhs);
- case IRNodeType::kMod:
- return alloc<Mod>(lhs, rhs);
- case IRNodeType::kMax:
- return alloc<Max>(lhs, rhs, option);
- case IRNodeType::kMin:
- return alloc<Min>(lhs, rhs, option);
- case IRNodeType::kAnd:
- return alloc<And>(lhs, rhs);
- case IRNodeType::kXor:
- return alloc<Xor>(lhs, rhs);
- case IRNodeType::kLshift:
- return alloc<Lshift>(lhs, rhs);
- case IRNodeType::kRshift:
- return alloc<Rshift>(lhs, rhs);
- default:
- LOG(FATAL) << "unsupported expr_type: " << static_cast<int>(expr_type);
- return nullptr;
- }
-}
-
// Uses the evaluator to fold an Expression with constant terms.
// E.g. evaluateOp(Add(3, 4)) => 7.
// Expr v must not have any unbound Vars.
ExprPtr mutate(ModPtr v) override;
- ExprPtr mutate(AndPtr v) override {
- return mutateBinaryOp(v, this);
- }
+ ExprPtr mutate(AndPtr v) override;
- ExprPtr mutate(XorPtr v) override {
- return mutateBinaryOp(v, this);
- }
+ ExprPtr mutate(XorPtr v) override;
- ExprPtr mutate(LshiftPtr v) override {
- return mutateBinaryOp(v, this);
- }
+ ExprPtr mutate(LshiftPtr v) override;
- ExprPtr mutate(RshiftPtr v) override {
- return mutateBinaryOp(v, this);
- }
+ ExprPtr mutate(RshiftPtr v) override;
ExprPtr mutate(MaxPtr v) override;
ExprPtr mutate(IfThenElsePtr v) override;
- template <typename Op>
- static ExprPtr mutateBinaryOp(
- BinaryOpNode<Op>* v,
- IRMutator* mutator,
- bool option = false) {
- ExprPtr lhs = v->lhs();
- ExprPtr rhs = v->rhs();
- ExprPtr lhs_new = lhs->accept_mutator(mutator);
- ExprPtr rhs_new = rhs->accept_mutator(mutator);
-
- ExprPtr node = v;
-
- if (lhs != lhs_new || rhs != rhs_new) {
- node = newBinaryOpOfType(v->expr_type(), lhs_new, rhs_new, option);
- }
-
- // Can only fold if both sides are constant.
- if (!lhs_new->isConstant() || !rhs_new->isConstant()) {
- return node;
- }
-
- return evaluateOp(node);
- }
-
static ExprPtr simplify(ExprPtr e);
static ExprHandle simplify(const ExprHandle& e);
static StmtPtr simplify(StmtPtr e);
namespace jit {
namespace tensorexpr {
-template <typename Op>
-void verifyBitwiseOp(const BitwiseOpNode<Op>* v, IRVerifier* verifier) {
+namespace detail {
+template <typename T>
+void deducer(BinaryOpNode<T>);
+
+bool deducer(...);
+} // namespace detail
+
+template <
+ typename D,
+ typename std::enable_if<std::is_same<
+ decltype(detail::deducer(std::declval<D>())),
+ void>::value>::type* = nullptr>
+void verifyBitwiseOp(NodePtr<D> v, IRVerifier* verifier) {
if (!v->lhs()->dtype().is_integral()) {
throw unsupported_dtype();
}
namespace jit {
namespace tensorexpr {
-template <typename Op>
-static void visit_binary_op(BinaryOpNode<Op>* v, IRVisitor* visitor) {
+template <
+ typename Op,
+ typename std::enable_if<std::is_same<
+ decltype(detail::bin_op_deducer(std::declval<Op>())),
+ void>::value>::type* = nullptr>
+static void visit_binary_op(NodePtr<Op> v, IRVisitor* visitor) {
v->lhs()->accept(visitor);
v->rhs()->accept(visitor);
}
if (v->op_type() == kTanh) {
ScalarType stype = v->dtype().scalar_type();
if (stype == ScalarType::Float) {
- return fast_tanh(v->param(0)->accept_mutator(this)).node();
+ return fast_tanh(ExprHandle(v->param(0)->accept_mutator(this))).node();
}
} else if (v->op_type() == kSigmoid) {
ScalarType stype = v->dtype().scalar_type();
if (stype == ScalarType::Float) {
- return fast_sigmoid(v->param(0)->accept_mutator(this)).node();
+ return fast_sigmoid(ExprHandle(v->param(0)->accept_mutator(this)))
+ .node();
}
}
// TODO: fast exp
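// (With accept_mutator now returning an ExprPtr rather than a raw Expr*, the
// mutated parameter above is wrapped in an ExprHandle explicitly before being
// passed to the fast_tanh/fast_sigmoid approximations.)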
void LoopNest::compressAllBuffers(StmtPtr stmt) {
for (auto buf : BufFinder::find(stmt)) {
- compressBuffer(const_cast<BufPtr>(buf), stmt);
+ compressBuffer(buf, stmt);
}
}
class Placeholder;
// The common base class of all statement nodes.
-class TORCH_API Stmt : public KernelScopedObject {
+class TORCH_API Stmt : public std::enable_shared_from_this<Stmt> {
public:
Stmt() = default;
+ virtual ~Stmt() = default;
virtual void accept(IRVisitor* visitor) = 0;
virtual StmtPtr accept_mutator(IRMutator* mutator) = 0;
StmtPtr get_parent() const {
- return parent_;
+ return parent_ ? parent_->getptr() : nullptr;
}
/*
static StmtPtr clone(StmtPtr s);
protected:
- static void set_parent(StmtPtr s, StmtPtr new_parent) {
+ static void set_parent(StmtPtr s, Stmt* new_parent) {
s->parent_ = new_parent;
}
+ std::shared_ptr<Stmt> getptr() {
+ return shared_from_this();
+ }
private:
- StmtPtr parent_ = nullptr;
+ Stmt* parent_ = nullptr;
};
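// Why parent_ stays a raw pointer while everything else moved to shared_ptr:
// a Block owns its children through StmtPtr, so an owning back edge from
// child to parent would form a reference cycle and the subtree would never
// be freed. The non-owning Stmt* breaks the cycle, and get_parent() re-wraps
// it on demand via enable_shared_from_this. A hypothetical standalone sketch
// (DemoNode is not a real IR class; assumes <memory> is in scope):
struct DemoNode : std::enable_shared_from_this<DemoNode> {
  std::shared_ptr<DemoNode> child;  // owning edge
  DemoNode* parent = nullptr;       // non-owning back edge
  std::shared_ptr<DemoNode> get_parent() {
    return parent ? parent->shared_from_this() : nullptr;
  }
};
inline void parent_link_demo() {
  auto root = std::make_shared<DemoNode>();
  root->child = std::make_shared<DemoNode>();
  root->child->parent = root.get();
  // Both nodes die when root goes out of scope; a shared_ptr back edge
  // would have pinned the use count above zero forever.
}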
template <class Op>
public:
using StmtNodeBase = StmtNode<Op>;
void accept(IRVisitor* visitor) override {
- visitor->visit(static_to<Op>(this));
+ visitor->visit(static_to<Op>(getptr()));
}
StmtPtr accept_mutator(IRMutator* mutator) override;
StmtNode() = default;
template <class Op>
StmtPtr StmtNode<Op>::accept_mutator(IRMutator* mutator) {
- return mutator->mutate(static_to<Op>(this));
+ return mutator->mutate(static_to<Op>(getptr()));
}
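// Note on getptr(): wrapping `this` directly, as in
// std::shared_ptr<Stmt>(this), would mint a second control block and lead
// to a double free; shared_from_this() returns a pointer that shares the
// existing owners' control block. The caveat is that the node must already
// be owned by a shared_ptr, so accept()/accept_mutator() may only be called
// on nodes created through alloc<>() (or make_shared), never on
// stack-allocated ones.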
// Concrete Stmt classes
}
void clear() {
- for (auto* s : stmts_) {
+ for (auto s : stmts_) {
set_parent(s, nullptr);
}
stmts_.clear();
// Returns this block's immediate child that contains the statement s.
StmtPtr getEnclosedRoot(StmtPtr s) const {
- while (s && s->get_parent() != this) {
+ while (s && s->get_parent().get() != this) {
s = s->get_parent();
}
return s;
[](Placeholder& self,
const std::vector<ExprHandle>& args,
const ExprHandle& val) { return self.store(args, val); })
- .def(
- "data",
- [](Placeholder& self) { return BufHandle(self.data()); },
- py::return_value_policy::reference);
+ .def("data", [](Placeholder& self) { return BufHandle(self.data()); });
py::class_<Tensor, std::unique_ptr<Tensor, py::nodelete>>(te, "Tensor")
.def(py::init(
[](BufHandle& b, StmtPtr s) { return new Tensor(b.node(), s); }))
return self.load(v);
})
.def("buf", [](Tensor& self) { return BufHandle(self.buf()); })
- .def("stmt", &Tensor::stmt, py::return_value_policy::reference);
- py::class_<Cast>(te, "Cast").def_static("make", &Cast::make);
+ .def("stmt", &Tensor::stmt);
+ py::class_<Cast, std::shared_ptr<Cast>>(te, "Cast")
+ .def_static("make", &Cast::make);
py::class_<DimArg>(te, "DimArg")
.def(py::init<const ExprHandle&>())
},
py::return_value_policy::reference);
- py::class_<Stmt, std::unique_ptr<Stmt, py::nodelete>>(te, "Stmt")
+ py::class_<Stmt, std::shared_ptr<Stmt>>(te, "Stmt")
.def(py::init([](const std::vector<StmtPtr>& stmts) {
return tensorexpr::Block::make(stmts);
}))
ss << self;
return ss.str();
});
- py::class_<Store, Stmt, std::unique_ptr<Store, py::nodelete>>(te, "Store")
+ py::class_<Store, Stmt, std::shared_ptr<Store>>(te, "Store")
.def_static(
"make",
[](const BufHandle& buf,
std::vector<ExprHandle>& indices,
const ExprHandle& value) {
return Store::make(buf, indices, value);
- },
- py::return_value_policy::reference);
+ });
- py::class_<For, Stmt, std::unique_ptr<For, py::nodelete>>(te, "For")
- .def(
- "index_var",
- [](For& self) { return VarHandle(self.var()); },
- py::return_value_policy::reference)
- .def("body", &For::body, py::return_value_policy::reference)
+ py::class_<For, Stmt, std::shared_ptr<For>>(te, "For")
+ .def("index_var", [](For& self) { return VarHandle(self.var()); })
+ .def("body", &For::body)
.def("set_parallel", &For::set_parallel)
.def(
"set_gpu_block_index",
[](const VarHandle& var,
const ExprHandle& start,
const ExprHandle& stop,
- StmtPtr body) { return For::make(var, start, stop, body); },
- py::return_value_policy::reference);
+ StmtPtr body) { return For::make(var, start, stop, body); });
- py::class_<Cond, Stmt, std::unique_ptr<Cond, py::nodelete>>(te, "Cond")
+ py::class_<Cond, Stmt, std::shared_ptr<Cond>>(te, "Cond")
.def_static(
"make",
[](const ExprHandle& condition,
StmtPtr true_stmt,
StmtPtr false_stmt) {
- return alloc<Cond>(condition.node(), true_stmt, false_stmt);
- },
- py::return_value_policy::reference)
- .def("true_stmt", &Cond::true_stmt, py::return_value_policy::reference)
- .def("false_stmt", &Cond::false_stmt, py::return_value_policy::reference);
+ return Cond::make(condition, true_stmt, false_stmt);
+ })
+ .def("true_stmt", &Cond::true_stmt)
+ .def("false_stmt", &Cond::false_stmt);
- py::class_<
- tensorexpr::Block,
- Stmt,
- std::unique_ptr<tensorexpr::Block, py::nodelete>>(te, "Block")
+ py::class_<tensorexpr::Block, Stmt, std::shared_ptr<tensorexpr::Block>>(
+ te, "Block")
.def(py::init([](const std::vector<StmtPtr>& stmts) {
return tensorexpr::Block::make(stmts);
}))
- .def(
- "stmts",
- &tensorexpr::Block::stmts,
- py::return_value_policy::reference);
- py::class_<ExternalCall, Stmt, std::unique_ptr<ExternalCall, py::nodelete>>(
+ .def("stmts", &tensorexpr::Block::stmts);
+ py::class_<ExternalCall, Stmt, std::shared_ptr<ExternalCall>>(
te, "ExternalCall")
- .def(py::init(&ExternalCall::make), py::return_value_policy::reference);
+ .def(py::init(&ExternalCall::make));
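+// The holders above change from std::unique_ptr<T, py::nodelete> plus
+// py::return_value_policy::reference (which told pybind11 "never delete;
+// the kernel arena owns this") to std::shared_ptr<T>, so Python simply
+// co-owns the nodes. A minimal sketch of the pattern with a hypothetical
+// DemoStmt, not a real binding:
+//
+//   py::class_<DemoStmt, Stmt, std::shared_ptr<DemoStmt>>(te, "DemoStmt")
+//       .def(py::init<>());
+//
+// Functions returning std::shared_ptr<DemoStmt> then need no explicit
+// return_value_policy; the default keeps the object alive from both sides.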
py::class_<LoopNest>(te, "LoopNest")
.def(py::init<const std::vector<Tensor*>&>())