// Test helper: splits the outermost loop of `target` into chunks of `width`
// (producing an inner loop plus a tail loop for the remainder) and then
// vectorizes the inner loop, asserting that vectorization succeeds.
// NOTE(review): the `-`/`+` lines below are unresolved diff markers from the
// raw-pointer -> NodePtr migration (te::For* -> te::ForPtr); presumably the
// `+` form is the intended current state — confirm before applying.
static void vectorize(te::LoopNest* ln, te::Tensor* target, int width) {
auto loops = ln->getLoopStmtsFor(target);
- te::For *inner, *tail;
+ te::ForPtr inner, tail;
ln->splitWithTail(loops[0], width, &inner, &tail);
ASSERT_TRUE(te::LoopNest::vectorize(inner));
}
te::LoopNest ln({B});
ln.prepareForCodegen();
vectorize(&ln, B, 8);
- te::Stmt* s = ln.root_stmt();
+ te::StmtPtr s = ln.root_stmt();
s = te::IRSimplifier::simplify(s);
te::LLVMCodeGen cg(s, {A, B, N});
VarHandle index = VarHandle("index", kInt);
ExprHandle load_a = a_buf.load(index);
ExprHandle to_float = Cast::make(kFloat, load_a);
- Stmt* store_b = b_buf.store({index}, to_float);
- Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
+ StmtPtr store_b = b_buf.store({index}, to_float);
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<int> a_v(kTotalSize);
PaddedBuffer<float> b_v(kTotalSize);
VarHandle index = VarHandle("index", kInt);
ExprHandle load_a = a_buf.load(index);
ExprHandle to_float = Sub::make(0, load_a);
- Stmt* store_b = b_buf.store({index}, to_float);
- Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
+ StmtPtr store_b = b_buf.store({index}, to_float);
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<int> a_v(kTotalSize);
PaddedBuffer<int> b_v(kTotalSize);
VarHandle index = VarHandle("index", kInt);
ExprHandle load_a = a_buf.load(index);
ExprHandle to_float = Sub::make(0, load_a);
- Stmt* store_b = b_buf.store({index}, to_float);
- Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
+ StmtPtr store_b = b_buf.store({index}, to_float);
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<float> a_v(kTotalSize);
PaddedBuffer<float> b_v(kTotalSize);
ExprHandle load_a = a_buf.load(index);
ExprHandle load_b = b_buf.load(index);
ExprHandle load_c = c_buf.load(index);
- Stmt* store_d = d_buf.store({index}, load_a + load_b * load_c);
- Stmt* stmt = For::make(index, 0, kTotalSize, store_d);
+ StmtPtr store_d = d_buf.store({index}, load_a + load_b * load_c);
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_d);
PaddedBuffer<int> a_v(kTotalSize);
PaddedBuffer<int> b_v(kTotalSize);
ExprHandle load_a = a_buf.load(index);
ExprHandle load_b = b_buf.load(index);
ExprHandle load_c = c_buf.load(index);
- Stmt* store_d = d_buf.store({index}, load_a + load_b * load_c);
- Stmt* stmt = For::make(index, 0, kTotalSize, store_d);
+ StmtPtr store_d = d_buf.store({index}, load_a + load_b * load_c);
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_d);
PaddedBuffer<float> a_v(kTotalSize);
PaddedBuffer<float> b_v(kTotalSize);
ExprHandle load_a = a_buf.load(index);
ExprHandle load_b = b_buf.load(index);
ExprHandle load_c = c_buf.load(index);
- Stmt* store_d = d_buf.store({index}, load_a - load_b * load_c);
- Stmt* stmt = For::make(index, 0, kTotalSize, store_d);
+ StmtPtr store_d = d_buf.store({index}, load_a - load_b * load_c);
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_d);
PaddedBuffer<int> a_v(kTotalSize);
PaddedBuffer<int> b_v(kTotalSize);
ExprHandle load_a = a_buf.load(index);
ExprHandle load_b = b_buf.load(index);
ExprHandle load_c = c_buf.load(index);
- Stmt* store_d = d_buf.store({index}, load_a - load_b * load_c);
- Stmt* stmt = For::make(index, 0, kTotalSize, store_d);
+ StmtPtr store_d = d_buf.store({index}, load_a - load_b * load_c);
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_d);
PaddedBuffer<float> a_v(kTotalSize);
PaddedBuffer<float> b_v(kTotalSize);
ExprHandle load_a = a_buf.load(index);
ExprHandle load_b = b_buf.load(index);
ExprHandle load_c = c_buf.load(index);
- Stmt* store_d = d_buf.store({index}, load_a + load_c * (load_b - load_a));
- Stmt* stmt = For::make(index, 0, kTotalSize, store_d);
+ StmtPtr store_d = d_buf.store({index}, load_a + load_c * (load_b - load_a));
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_d);
PaddedBuffer<float> a_v(kTotalSize);
PaddedBuffer<float> b_v(kTotalSize);
ExprHandle load_b = b_buf.load(index);
ExprHandle load_c = c_buf.load(index);
ExprHandle load_d = d_buf.load(index);
- Stmt* store_e = e_buf.store({index}, load_a + load_b * load_c * load_d);
- Stmt* stmt = For::make(index, 0, kTotalSize, store_e);
+ StmtPtr store_e = e_buf.store({index}, load_a + load_b * load_c * load_d);
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_e);
PaddedBuffer<int> a_v(kTotalSize);
PaddedBuffer<int> b_v(kTotalSize);
ExprHandle load_b = b_buf.load(index);
ExprHandle load_c = c_buf.load(index);
ExprHandle load_d = d_buf.load(index);
- Stmt* store_e = e_buf.store({index}, load_a + load_b * load_c * load_d);
- Stmt* stmt = For::make(index, 0, kTotalSize, store_e);
+ StmtPtr store_e = e_buf.store({index}, load_a + load_b * load_c * load_d);
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_e);
PaddedBuffer<float> a_v(kTotalSize);
PaddedBuffer<float> b_v(kTotalSize);
VarHandle index = VarHandle("index", kInt);
ExprHandle load_a = a_buf.load(index);
ExprHandle load_b = b_buf.load(index);
- Stmt* store_c = c_buf.store({index}, load_a * load_b);
- Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
+ StmtPtr store_c = c_buf.store({index}, load_a * load_b);
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_c);
PaddedBuffer<int> a_v(kTotalSize);
PaddedBuffer<int> b_v(kTotalSize);
VarHandle index = VarHandle("index", kInt);
ExprHandle load_a = a_buf.load(index);
ExprHandle load_b = b_buf.load(index);
- Stmt* store_c = c_buf.store({index}, load_a * load_b);
- Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
+ StmtPtr store_c = c_buf.store({index}, load_a * load_b);
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_c);
PaddedBuffer<float> a_v(kTotalSize);
PaddedBuffer<float> b_v(kTotalSize);
VarHandle index = VarHandle("index", kInt);
ExprHandle load_a = a_buf.load(index);
ExprHandle load_b = b_buf.load(index);
- Stmt* store_c = c_buf.store({index}, load_a / load_b);
- Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
+ StmtPtr store_c = c_buf.store({index}, load_a / load_b);
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_c);
PaddedBuffer<int> a_v(kTotalSize);
PaddedBuffer<int> b_v(kTotalSize);
VarHandle index = VarHandle("index", kInt);
ExprHandle load_a = a_buf.load(index);
ExprHandle load_b = b_buf.load(index);
- Stmt* store_c = c_buf.store({index}, load_a / load_b);
- Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
+ StmtPtr store_c = c_buf.store({index}, load_a / load_b);
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_c);
PaddedBuffer<float> a_v(kTotalSize);
PaddedBuffer<float> b_v(kTotalSize);
VarHandle index = VarHandle("index", kInt);
ExprHandle load_a = a_buf.load(index);
ExprHandle load_b = b_buf.load(index);
- Stmt* store_c = c_buf.store({index}, Max::make(load_a, load_b, true));
- Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
+ StmtPtr store_c = c_buf.store({index}, Max::make(load_a, load_b, true));
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_c);
PaddedBuffer<int> a_v(kTotalSize);
PaddedBuffer<int> b_v(kTotalSize);
VarHandle index = VarHandle("index", kInt);
ExprHandle load_a = a_buf.load(index);
ExprHandle load_b = b_buf.load(index);
- Stmt* store_c = c_buf.store({index}, Max::make(load_a, load_b, true));
- Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
+ StmtPtr store_c = c_buf.store({index}, Max::make(load_a, load_b, true));
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_c);
PaddedBuffer<float> a_v(kTotalSize);
PaddedBuffer<float> b_v(kTotalSize);
VarHandle index = VarHandle("index", kInt);
ExprHandle load_a = a_buf.load(index);
ExprHandle load_b = b_buf.load(index);
- Stmt* store_c = c_buf.store({index}, Min::make(load_a, load_b, true));
- Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
+ StmtPtr store_c = c_buf.store({index}, Min::make(load_a, load_b, true));
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_c);
PaddedBuffer<int> a_v(kTotalSize);
PaddedBuffer<int> b_v(kTotalSize);
VarHandle index = VarHandle("index", kInt);
ExprHandle load_a = a_buf.load(index);
ExprHandle load_b = b_buf.load(index);
- Stmt* store_c = c_buf.store({index}, Min::make(load_a, load_b, true));
- Stmt* stmt = For::make(index, 0, kTotalSize, store_c);
+ StmtPtr store_c = c_buf.store({index}, Min::make(load_a, load_b, true));
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_c);
PaddedBuffer<float> a_v(kTotalSize);
PaddedBuffer<float> b_v(kTotalSize);
VarHandle index = VarHandle("index", kInt);
ExprHandle load_a = a_buf.load(index);
- Stmt* store_b = b_buf.store({index}, FloatImm::make(1.0f) / load_a);
- Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
+ StmtPtr store_b = b_buf.store({index}, FloatImm::make(1.0f) / load_a);
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<float> a_v(kTotalSize);
PaddedBuffer<float> b_v(kTotalSize);
VarHandle index = VarHandle("index", kInt);
ExprHandle load_a = a_buf.load(index);
- Stmt* store_b = b_buf.store({index}, Max::make(load_a, 0, false));
- Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
+ StmtPtr store_b = b_buf.store({index}, Max::make(load_a, 0, false));
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<int> a_v(kTotalSize);
PaddedBuffer<int> b_v(kTotalSize);
VarHandle index = VarHandle("index", kInt);
ExprHandle load_a = a_buf.load(index);
- Stmt* store_b = b_buf.store(
+ StmtPtr store_b = b_buf.store(
{index}, Max::make(load_a, 0, false) // relu does not propagate nans
);
- Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<float> a_v(kTotalSize);
PaddedBuffer<float> b_v(kTotalSize);
VarHandle index = VarHandle("index", kInt);
ExprHandle load_a = a_buf.load(index);
- Stmt* store_b = b_buf.store({index}, log(load_a));
- Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
+ StmtPtr store_b = b_buf.store({index}, log(load_a));
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<float> a_v(kTotalSize);
PaddedBuffer<float> b_v(kTotalSize);
VarHandle index = VarHandle("index", kInt);
ExprHandle load_a = a_buf.load(index);
- Stmt* store_b = b_buf.store({index}, fast_log(load_a));
- Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
+ StmtPtr store_b = b_buf.store({index}, fast_log(load_a));
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<float> a_v(kTotalSize);
PaddedBuffer<float> b_v(kTotalSize);
VarHandle index = VarHandle("index", kInt);
ExprHandle load_a = a_buf.load(index);
- Stmt* store_b = b_buf.store({index}, fast_tanh(load_a));
- Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
+ StmtPtr store_b = b_buf.store({index}, fast_tanh(load_a));
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<float> a_v(kTotalSize);
PaddedBuffer<float> b_v(kTotalSize);
VarHandle index = VarHandle("index", kInt);
ExprHandle load_a = a_buf.load(index);
- Stmt* store_b = b_buf.store({index}, fast_sigmoid(load_a));
- Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
+ StmtPtr store_b = b_buf.store({index}, fast_sigmoid(load_a));
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<float> a_v(kTotalSize);
PaddedBuffer<float> b_v(kTotalSize);
VarHandle index = VarHandle("index", kInt);
ExprHandle load_a = a_buf.load(index);
- Stmt* store_b = b_buf.store({index}, log10(load_a));
- Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
+ StmtPtr store_b = b_buf.store({index}, log10(load_a));
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<float> a_v(kTotalSize);
PaddedBuffer<float> b_v(kTotalSize);
VarHandle index = VarHandle("index", kInt);
ExprHandle load_a = a_buf.load(index);
- Stmt* store_b = b_buf.store({index}, log2(load_a));
- Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
+ StmtPtr store_b = b_buf.store({index}, log2(load_a));
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<float> a_v(kTotalSize);
PaddedBuffer<float> b_v(kTotalSize);
VarHandle index = VarHandle("index", kInt);
ExprHandle load_a = a_buf.load(index);
- Stmt* store_b = b_buf.store({index}, exp(load_a));
- Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
+ StmtPtr store_b = b_buf.store({index}, exp(load_a));
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<float> a_v(kTotalSize);
PaddedBuffer<float> b_v(kTotalSize);
VarHandle index = VarHandle("index", kInt);
ExprHandle load_a = a_buf.load(index);
- Stmt* store_b = b_buf.store({index}, erf(load_a));
- Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
+ StmtPtr store_b = b_buf.store({index}, erf(load_a));
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<float> a_v(kTotalSize);
PaddedBuffer<float> b_v(kTotalSize);
VarHandle index = VarHandle("index", kInt);
ExprHandle load_a = a_buf.load(index);
- Stmt* store_b = b_buf.store({index}, cos(load_a));
- Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
+ StmtPtr store_b = b_buf.store({index}, cos(load_a));
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<float> a_v(kTotalSize);
PaddedBuffer<float> b_v(kTotalSize);
return a.load(y, x) * b->load(y, x);
});
LoopNest l({c});
- std::vector<For*> loops = l.getLoopStmtsFor(c);
- Stmt* body = l.getLoopBodyFor(c);
+ std::vector<ForPtr> loops = l.getLoopStmtsFor(c);
+ StmtPtr body = l.getLoopBodyFor(c);
{
// Infer bounds on the top-level loop scope
auto bounds_info = inferBounds(loops[0]);
LoopNest l({b});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* inner;
+ ForPtr inner;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* tail;
- std::vector<For*> loops = l.getLoopStmtsFor(b);
+ ForPtr tail;
+ std::vector<ForPtr> loops = l.getLoopStmtsFor(b);
LoopNest::splitWithTail(loops[0], 16, &inner, &tail);
- For* outer = loops[0];
+ ForPtr outer = loops[0];
{
// Verify inferred bounds for the outer loop
return a.load(y + 100, x + 100) * b->load(y * 2, x * 5);
});
LoopNest l({c});
- std::vector<For*> loops = l.getLoopStmtsFor(c);
- Stmt* body = l.getLoopBodyFor(c);
+ std::vector<ForPtr> loops = l.getLoopStmtsFor(c);
+ StmtPtr body = l.getLoopBodyFor(c);
{
// Infer bounds on the top-level loop scope
auto bounds_info = inferBounds(loops[0]);
Tensor* c = Compute(
"c", {{H, "x"}}, [&](const VarHandle& x) { return a.load(x + H); });
LoopNest l({b, c});
- std::vector<For*> loops = NodeFinder<For>::find(l.root_stmt());
+ std::vector<ForPtr> loops = NodeFinder<For>::find(l.root_stmt());
{
// Infer bounds on the top-level loop scope
// Same as above but the offsets are on the Store now.
// Can't do this through ComputeAPI without transforms we don't have yet.
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{For::make(x, 0, 64, Store::make(b, {x}, Load::make(a, {x}))),
For::make(x, 0, 32, Store::make(c, {x + 10}, Load::make(a, {x}))),
For::make(x, 0, 96, Store::make(d, {x + 2}, Load::make(a, {x})))});
LoopNest l({B, C});
auto bounds_info_before = inferBounds(l.root_stmt());
- Stmt* j_loop = l.getLoopStmtsFor(B)[1];
+ StmtPtr j_loop = l.getLoopStmtsFor(B)[1];
LoopNest::cacheAccesses(A->buf(), "A_local", j_loop);
auto bounds_info_after = inferBounds(l.root_stmt());
ASSERT_EQ(TABI.stop.size(), 1);
// Bounds should be 0 -> (3*4*5)-1
- ASSERT_TRUE(exprEquals(TABI.start[0], new IntImm(0)));
- ASSERT_TRUE(exprEquals(TABI.stop[0], new IntImm(3 * 4 * 5 - 1)));
+ ASSERT_TRUE(exprEquals(TABI.start[0], alloc<IntImm>(0)));
+ ASSERT_TRUE(exprEquals(TABI.stop[0], alloc<IntImm>(3 * 4 * 5 - 1)));
}
TEST(BoundsInference, GetPotentialHazards) {
* C[0] = 5;
*/
- Store* store1 = Store::make(a, {0}, Load::make(b, {0}));
- Store* store2 = Store::make(b, {0}, 3);
- Store* store3 = Store::make(a, {0}, Load::make(b, {0}));
- Store* store4 = Store::make(c, {0}, 5);
- Stmt* stmt = Block::make({store1, store2, store3, store4});
+ StorePtr store1 = Store::make(a, {0}, Load::make(b, {0}));
+ StorePtr store2 = Store::make(b, {0}, 3);
+ StorePtr store3 = Store::make(a, {0}, Load::make(b, {0}));
+ StorePtr store4 = Store::make(c, {0}, 5);
+ StmtPtr stmt = Block::make({store1, store2, store3, store4});
MemDependencyChecker analyzer;
stmt->accept(&analyzer);
MemDependencyChecker analyzer;
l.root_stmt()->accept(&analyzer);
- For* loopRootA = l.getLoopStmtsFor(A)[0];
- For* loopRootB = l.getLoopStmtsFor(B)[0];
+ ForPtr loopRootA = l.getLoopStmtsFor(A)[0];
+ ForPtr loopRootB = l.getLoopStmtsFor(B)[0];
// No dependencies between loops.
ASSERT_EQ(
MemDependencyChecker analyzer;
l.root_stmt()->accept(&analyzer);
- For* loopRootA = l.getLoopStmtsFor(A)[0];
- For* loopRootB = l.getLoopStmtsFor(B)[0];
+ ForPtr loopRootA = l.getLoopStmtsFor(A)[0];
+ ForPtr loopRootB = l.getLoopStmtsFor(B)[0];
ASSERT_EQ(
HazardKind::ReadAfterWrite,
LoopNest l({A});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For *inner, *tail;
+ ForPtr inner, tail;
// Splitting with tail by something offset creates a tail which also writes to
// A.
- For* outer = l.getLoopStmtsFor(A)[0];
+ ForPtr outer = l.getLoopStmtsFor(A)[0];
// `outer` loop get transformed to the outer loop after splitting.
LoopNest::splitWithTail(outer, 5, &inner, &tail);
i, 0, 100, Block::make({storeA1, storeB, storeC, storeA2, storeA3}));
tensorexpr::analysis::MemDependencyChecker analyzer;
forI->accept(&analyzer);
- ASSERT_TRUE(
- isOverlapping(analyzer, storeA1, dynamic_cast<Load*>(loadA1.node())));
- ASSERT_FALSE(
- isOverlapping(analyzer, storeA1, dynamic_cast<Load*>(loadA2.node())));
+ ASSERT_TRUE(isOverlapping(analyzer, storeA1, to<Load>(loadA1.node())));
+ ASSERT_FALSE(isOverlapping(analyzer, storeA1, to<Load>(loadA2.node())));
ASSERT_TRUE(isOverlapping(analyzer, storeA1, storeA2));
ASSERT_FALSE(isOverlapping(analyzer, storeA1, storeA3));
}
// LoopNest, IRSimplifier, etc.
te::LoopNest loop({conv});
loop.prepareForCodegen();
- te::Stmt* s = loop.root_stmt();
+ te::StmtPtr s = loop.root_stmt();
s = te::IRSimplifier::simplify(s);
at::Tensor result = at::empty_like(ref);
TEST(CppPrinter, AllocateOnStackThenFree) {
KernelScope kernel_scope;
- std::vector<Expr*> dims = {new IntImm(2), new IntImm(3)};
- Buf* buf = new Buf("x", dims, kInt);
- Allocate* alloc = new Allocate(buf);
- Free* free = new Free(buf);
- Block* block = Block::make({alloc, free});
+ std::vector<ExprPtr> dims = {alloc<IntImm>(2), alloc<IntImm>(3)};
+ BufPtr buf = alloc<Buf>("x", dims, kInt);
+ AllocatePtr alloc_ = alloc<Allocate>(buf);
+ FreePtr free_ = alloc<Free>(buf);
+ BlockPtr block = Block::make({alloc_, free_});
std::stringstream ss;
CppPrinter printer(&ss);
TEST(CppPrinter, AllocateOnHeapThenFree) {
KernelScope kernel_scope;
- std::vector<Expr*> dims = {new IntImm(20), new IntImm(50), new IntImm(3)};
- Buf* buf = new Buf("y", dims, kLong);
- Allocate* alloc = new Allocate(buf);
- Free* free = new Free(buf);
- Block* block = Block::make({alloc, free});
+ std::vector<ExprPtr> dims = {
+ alloc<IntImm>(20), alloc<IntImm>(50), alloc<IntImm>(3)};
+ BufPtr buf = alloc<Buf>("y", dims, kLong);
+ AllocatePtr alloc_ = alloc<Allocate>(buf);
+ FreePtr free_ = alloc<Free>(buf);
+ BlockPtr block = Block::make({alloc_, free_});
std::stringstream ss;
CppPrinter printer(&ss);
return a_buf.load(n, b_id, t_id) + b_buf.load(n, b_id, t_id);
});
LoopNest l({c});
- std::vector<For*> loops = l.getLoopStmtsFor(c);
+ std::vector<ForPtr> loops = l.getLoopStmtsFor(c);
loops[1]->set_gpu_block_index(0);
loops[2]->set_gpu_thread_index(0);
l.prepareForCodegen();
- Stmt* stmt = l.root_stmt();
+ StmtPtr stmt = l.root_stmt();
CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf);
const int N = block_count * block_size * num_iter;
PaddedBuffer<ctype> a_v(N);
return sigmoid(sigmoid(a_buf.load(n, b_id, t_id)));
});
LoopNest l({c});
- std::vector<For*> loops = l.getLoopStmtsFor(c);
+ std::vector<ForPtr> loops = l.getLoopStmtsFor(c);
loops[1]->set_gpu_block_index(0);
loops[2]->set_gpu_thread_index(0);
l.prepareForCodegen();
- Stmt* stmt = l.root_stmt();
+ StmtPtr stmt = l.root_stmt();
CudaCodeGen cuda_cg(stmt, c, a_buf);
const int N = block_count * block_size * num_iter;
PaddedBuffer<float> a_v(N);
},
[&](const VarHandle& n) { return a_buf.load(n) + b_buf.load(n); });
LoopNest l({c});
- For* n_inner;
- std::vector<For*> loops = l.getLoopStmtsFor(c);
+ ForPtr n_inner;
+ std::vector<ForPtr> loops = l.getLoopStmtsFor(c);
l.splitWithMask(loops[0], block_size, &n_inner);
loops[0]->set_gpu_block_index(0);
n_inner->set_gpu_thread_index(0);
l.prepareForCodegen();
- Stmt* stmt = l.root_stmt();
+ StmtPtr stmt = l.root_stmt();
CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf);
PaddedBuffer<float> a_v(N);
PaddedBuffer<float> b_v(N);
LoopNest l({b});
l.prepareForCodegen();
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
CudaCodeGen cg(s, {a, b});
std::vector<at::Half> aData(4, 2.0f);
});
LoopNest l({c});
l.prepareForCodegen();
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
CudaCodeGen cg(s, {a, b, c, m, n});
std::vector<float> aData(M * N, 1.0f);
return Intrinsics::make(IntrinsicsOp::kRand, kFloat);
});
LoopNest l({c});
- std::vector<For*> loops = l.getLoopStmtsFor(c);
+ std::vector<ForPtr> loops = l.getLoopStmtsFor(c);
loops[1]->set_gpu_block_index(0);
loops[2]->set_gpu_thread_index(0);
l.prepareForCodegen();
- Stmt* stmt = l.root_stmt();
+ StmtPtr stmt = l.root_stmt();
CudaCodeGen cuda_cg(stmt, c);
const int N = block_count * block_size * num_iter;
PaddedBuffer<float> c_v(N);
Tensor* b = Compute(
"b", {{n, "n"}}, [&](const VarHandle& i) { return a.load(i) * 2.0f; });
LoopNest l({b});
- For* inner;
- std::vector<For*> loops = l.getLoopStmtsFor(b);
+ ForPtr inner;
+ std::vector<ForPtr> loops = l.getLoopStmtsFor(b);
l.splitWithMask(loops[0], 1024, &inner);
loops[0]->set_gpu_block_index(0);
inner->set_gpu_thread_index(0);
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
CudaCodeGen cg(s, {a, b, n});
std::vector<float> aData(N, 1.0f);
// }
// }
- Store* init_store = output_buf.store({0}, 0.f);
+ StorePtr init_store = output_buf.store({0}, 0.f);
VarHandle i1("i1", kInt);
ExprHandle load_data = Load::make(BufHandle(data_buf.data()), {i1});
ExprHandle load_output = Load::make(BufHandle(output_buf.data()), {0});
ExprHandle add_value = load_output + load_data;
- Store* store_output = output_buf.store({0}, add_value);
- For* for_output = For::make(i1, 0, N, store_output);
- Stmt* reduce_block = Block::make({init_store, for_output});
+ StorePtr store_output = output_buf.store({0}, add_value);
+ ForPtr for_output = For::make(i1, 0, N, store_output);
+ StmtPtr reduce_block = Block::make({init_store, for_output});
VarHandle thread_idx("tidx", kInt);
LoopOptions thread_idx_options;
thread_idx_options.set_gpu_thread_index(0);
- For* thread_idx_loop =
+ ForPtr thread_idx_loop =
For::make(thread_idx, 0, 1, reduce_block, thread_idx_options);
VarHandle block_idx("bidx", kInt);
LoopOptions block_idx_options;
block_idx_options.set_gpu_block_index(0);
- For* block_idx_loop =
+ ForPtr block_idx_loop =
For::make(block_idx, 0, 1, thread_idx_loop, block_idx_options);
CudaCodeGen cuda_cg(block_idx_loop, data_buf, output_buf);
Placeholder a_buf("a", kFloat, {N});
Placeholder b_buf("b", kFloat, {1});
- Store* init_store = b_buf.store({0}, 0.f);
+ StorePtr init_store = b_buf.store({0}, 0.f);
VarHandle t("t", kInt);
VarHandle b("b", kInt);
// b[0] = 0
ExprHandle cond_t_lt_1 =
CompareSelect::make(t, 1, CompareSelectOperation::kLT);
- Cond* masked_init_b = Cond::make(cond_t_lt_1, init_store, nullptr);
+ CondPtr masked_init_b = Cond::make(cond_t_lt_1, init_store, nullptr);
LoopOptions thread_idx_options;
thread_idx_options.set_gpu_thread_index(0);
- For* for_init = For::make(t, 0, N, masked_init_b, thread_idx_options);
+ ForPtr for_init = For::make(t, 0, N, masked_init_b, thread_idx_options);
// for t in 0..1024: // thread-idx
// b[0] = b[0] + a[t] // implied atomic
ExprHandle load_a = Load::make(BufHandle(a_buf.data()), {t});
ExprHandle load_b = Load::make(BufHandle(b_buf.data()), {0});
ExprHandle add_value = load_b + load_a;
- Store* store_b = b_buf.store({0}, add_value);
- For* for_b = For::make(t, 0, N, store_b, thread_idx_options);
+ StorePtr store_b = b_buf.store({0}, add_value);
+ ForPtr for_b = For::make(t, 0, N, store_b, thread_idx_options);
- Stmt* reduce_block = Block::make({for_init, for_b});
+ StmtPtr reduce_block = Block::make({for_init, for_b});
VarHandle block_idx("bidx", kInt);
LoopOptions block_idx_options;
block_idx_options.set_gpu_block_index(0);
- For* block_idx_loop =
+ ForPtr block_idx_loop =
For::make(block_idx, 0, 1, reduce_block, block_idx_options);
CudaCodeGen cuda_cg(block_idx_loop, a_buf, b_buf);
// a[0] = 0
// for n in 0..2:
// a[0] = a[0] + n
- Store* store_a0_0 = a_buf.store({0}, 0.f);
+ StorePtr store_a0_0 = a_buf.store({0}, 0.f);
ExprHandle load_a0 = Load::make(BufHandle(a_buf.data()), {0});
ExprHandle v1 = load_a0 + n;
- Store* store_a0_v1 = a_buf.store({0}, v1);
- For* loop_a_0 = For::make(n, 0, 2, store_a0_v1);
+ StorePtr store_a0_v1 = a_buf.store({0}, v1);
+ ForPtr loop_a_0 = For::make(n, 0, 2, store_a0_v1);
// for m in 0..1024: // thread-idx
// b[m] = m
- Store* store_bm_m = b_buf.store({m}, m + 0.f);
+ StorePtr store_bm_m = b_buf.store({m}, m + 0.f);
LoopOptions thread_idx_options;
thread_idx_options.set_gpu_thread_index(0);
- For* loop_b_1 = For::make(m, 0, N, store_bm_m, thread_idx_options);
+ ForPtr loop_b_1 = For::make(m, 0, N, store_bm_m, thread_idx_options);
// a[1] = 1
// for l in 0..2:
// a[1] = a[1] + l
- Store* store_a1_1 = a_buf.store({1}, 1.f);
+ StorePtr store_a1_1 = a_buf.store({1}, 1.f);
ExprHandle load_a1 = a_buf.load(1);
ExprHandle v2 = load_a1 + l;
- Store* store_a1_v2 = a_buf.store({1}, v2);
- For* loop_a_1 = For::make(l, 0, 2, store_a1_v2);
+ StorePtr store_a1_v2 = a_buf.store({1}, v2);
+ ForPtr loop_a_1 = For::make(l, 0, 2, store_a1_v2);
- Stmt* reduce_block =
+ StmtPtr reduce_block =
Block::make({store_a0_0, loop_a_0, loop_b_1, store_a1_1, loop_a_1});
VarHandle block_idx("bidx", kInt);
LoopOptions block_idx_options;
block_idx_options.set_gpu_block_index(0);
- For* block_idx_loop =
+ ForPtr block_idx_loop =
For::make(block_idx, 0, 1, reduce_block, block_idx_options);
CudaCodeGen cuda_cg(block_idx_loop, a_buf, b_buf);
VarHandle m("m", kInt);
VarHandle n("n", kInt);
- std::vector<Stmt*> block;
- std::vector<Expr*> dims;
+ std::vector<StmtPtr> block;
+ std::vector<ExprPtr> dims;
dims.push_back(ExprHandle(N).node());
- BufHandle c{new Buf("c", dims, kFloat)};
+ BufHandle c{alloc<Buf>("c", dims, kFloat)};
{
// alloc(c, 64);
- Allocate* alloc = Allocate::make(c);
+ AllocatePtr alloc = Allocate::make(c);
block.push_back(alloc);
}
{
// for n in 0..64: // thread-idx
// c(n) = 0
- Store* store_cn_0 = Store::make(c, {n}, 0.f);
- For* loop_n1 = For::make(n, 0, N, store_cn_0, thread_idx_opt);
+ StorePtr store_cn_0 = Store::make(c, {n}, 0.f);
+ ForPtr loop_n1 = For::make(n, 0, N, store_cn_0, thread_idx_opt);
block.push_back(loop_n1);
}
ExprHandle a_kmn =
Load::make(BufHandle(a.data()), {k * (M * N) + m * N + n});
ExprHandle v_add = load_cn + a_kmn;
- Store* store_cn_v = Store::make(c, {n}, v_add);
- For* loop_n2 = For::make(n, 0, N, store_cn_v, thread_idx_opt);
- For* loop_m1 = For::make(m, 0, M, loop_n2);
+ StorePtr store_cn_v = Store::make(c, {n}, v_add);
+ ForPtr loop_n2 = For::make(n, 0, N, store_cn_v, thread_idx_opt);
+ ForPtr loop_m1 = For::make(m, 0, M, loop_n2);
block.push_back(loop_m1);
}
// b(k) = 0
// for n in 0..64: // thread_idx
// b(k) = b(k) + c(n)
- Store* store_bk_0 = b.store({k}, 0.f);
+ StorePtr store_bk_0 = b.store({k}, 0.f);
block.push_back(store_bk_0);
ExprHandle load_bk = b.load(k);
ExprHandle load_cn = Load::make(kFloat, c, {n});
ExprHandle v_add = load_bk + load_cn;
- Store* store_bk = b.store({k}, v_add);
- For* loop_n3 = For::make(n, 0, N, store_bk, thread_idx_opt);
+ StorePtr store_bk = b.store({k}, v_add);
+ ForPtr loop_n3 = For::make(n, 0, N, store_bk, thread_idx_opt);
block.push_back(loop_n3);
}
{
// free(c)
- Free* free_stmt = Free::make(c);
+ FreePtr free_stmt = Free::make(c);
block.push_back(free_stmt);
}
- Block* reduce_body = Block::make(block);
- For* loop_k1 = For::make(k, 0, 1, reduce_body, block_idx_opt);
+ BlockPtr reduce_body = Block::make(block);
+ ForPtr loop_k1 = For::make(k, 0, 1, reduce_body, block_idx_opt);
// TODO: check the generated code for correctness.
CudaCodeGen cuda_cg(loop_k1, a, b);
VarHandle m("m", kInt);
VarHandle n("n", kInt);
- BufHandle c{new Buf("c", {new IntImm(1)}, kFloat)};
- std::vector<Stmt*> block_k;
+ BufHandle c{
+ alloc<Buf>("c", std::vector<ExprPtr>({alloc<IntImm>(1)}), kFloat)};
+ std::vector<StmtPtr> block_k;
{
// b(k) = 0
- Store* store_bk_0 = b.store({k}, 0.f);
+ StorePtr store_bk_0 = b.store({k}, 0.f);
block_k.push_back(store_bk_0);
}
- std::vector<Stmt*> block_n;
+ std::vector<StmtPtr> block_n;
{
// alloc(c, 1);
- Allocate* alloc = Allocate::make(c);
+ AllocatePtr alloc = Allocate::make(c);
block_n.push_back(alloc);
}
{
// c(0) = 0
- Store* store_c0_0 = Store::make(c, {0}, 0.f);
+ StorePtr store_c0_0 = Store::make(c, {0}, 0.f);
block_n.push_back(store_c0_0);
}
{
ExprHandle load_c0 = Load::make(kFloat, c, {0});
ExprHandle a_kmn = a.load(k * (M * N) + m * N + n);
ExprHandle v_add = load_c0 + a_kmn;
- Store* store_c0_v = Store::make(c, {0}, v_add);
- For* loop_m = For::make(m, 0, M, store_c0_v);
+ StorePtr store_c0_v = Store::make(c, {0}, v_add);
+ ForPtr loop_m = For::make(m, 0, M, store_c0_v);
block_n.push_back(loop_m);
}
{
ExprHandle load_bk = b.load(k);
ExprHandle load_c0 = Load::make(kFloat, c, {0});
ExprHandle v_add = load_bk + load_c0;
- Store* store_bk = b.store({k}, v_add);
+ StorePtr store_bk = b.store({k}, v_add);
block_n.push_back(store_bk);
}
{
// free(c)
- Free* free_stmt = Free::make(c);
+ FreePtr free_stmt = Free::make(c);
block_n.push_back(free_stmt);
}
{
- Block* block_n_stmt = Block::make(block_n);
- For* for_n = For::make(n, 0, N, block_n_stmt, thread_idx_opt);
+ BlockPtr block_n_stmt = Block::make(block_n);
+ ForPtr for_n = For::make(n, 0, N, block_n_stmt, thread_idx_opt);
block_k.push_back(for_n);
}
- Block* block_k_stmt = Block::make(block_k);
- For* loop_k = For::make(k, 0, 1, block_k_stmt, block_idx_opt);
+ BlockPtr block_k_stmt = Block::make(block_k);
+ ForPtr loop_k = For::make(k, 0, 1, block_k_stmt, block_idx_opt);
CudaCodeGen cuda_cg(loop_k, a, b);
PaddedBuffer<float> a_v(1, M, N, "a_v");
LoopNest l({b, c, d});
l.prepareForCodegen();
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
CudaCodeGen cg(s, {a, b, c, d});
std::vector<at::Half> aData(4, 2.0f);
auto half = ToDtype<at::Half>();
Placeholder a("a", half, {4});
Tensor* relu = Compute("relu", {{4, "n"}}, [&](const VarHandle& i) {
- return Max::make(a.load(i), ExprHandle(new HalfImm(0)), true);
+ return Max::make(a.load(i), ExprHandle(alloc<HalfImm>(0)), true);
});
LoopNest l({relu});
l.prepareForCodegen();
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
CudaCodeGen cg(s, {a, relu});
std::ostringstream oss;
auto half = ToDtype<at::Half>();
Placeholder b("b", half, {4});
Tensor* relu = Compute("relu", {{4, "n"}}, [&](const VarHandle& i) {
- return Max::make(a.load(i), ExprHandle(new FloatImm(0)), true);
+ return Max::make(a.load(i), ExprHandle(alloc<FloatImm>(0)), true);
});
LoopNest l({relu});
l.prepareForCodegen();
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
CudaCodeGen cg(s, {a, b, relu});
std::ostringstream oss;
ExprHandle cmp = CompareSelect::make(i, 10, CompareSelectOperation::kLT);
ExprHandle ite = IfThenElse::make(cmp, Add::make(load_a, load_b), load_b);
- For* loop =
+ ForPtr loop =
For::make(i, 0, 12, Block::make({c.store({i}, ite)}), block_idx_opt);
CudaCodeGen cuda_cg(loop, a, b, c);
});
LoopNest l({c, d});
- std::vector<For*> loops = l.getLoopStmtsFor(c);
+ std::vector<ForPtr> loops = l.getLoopStmtsFor(c);
loops[0]->set_gpu_block_index(0);
loops = l.getLoopStmtsFor(d);
loops[0]->set_gpu_block_index(0);
l.prepareForCodegen();
- Stmt* stmt = l.root_stmt();
+ StmtPtr stmt = l.root_stmt();
CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf);
std::ostringstream oss;
auto blockExtents = cuda_cg.gpu_block_extents();
auto threadExtents = cuda_cg.gpu_thread_extents();
- ASSERT_TRUE(exprEquals(blockExtents[0], new IntImm(A_SIZE)));
- ASSERT_TRUE(exprEquals(threadExtents[0], new IntImm(1)));
+ ASSERT_TRUE(exprEquals(blockExtents[0], alloc<IntImm>(A_SIZE)));
+ ASSERT_TRUE(exprEquals(threadExtents[0], alloc<IntImm>(1)));
// Sanity check that the kernel works.
PaddedBuffer<float> a_v(A_SIZE);
});
LoopNest l({c, d});
- std::vector<For*> loops = l.getLoopStmtsFor(c);
+ std::vector<ForPtr> loops = l.getLoopStmtsFor(c);
loops[0]->set_gpu_thread_index(0);
loops = l.getLoopStmtsFor(d);
loops[0]->set_gpu_thread_index(0);
l.prepareForCodegen();
- Stmt* stmt = l.root_stmt();
+ StmtPtr stmt = l.root_stmt();
CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf);
std::ostringstream oss;
auto blockExtents = cuda_cg.gpu_block_extents();
auto threadExtents = cuda_cg.gpu_thread_extents();
- ASSERT_TRUE(exprEquals(blockExtents[0], new IntImm(1)));
- ASSERT_TRUE(exprEquals(threadExtents[0], new IntImm(B_SIZE)));
+ ASSERT_TRUE(exprEquals(blockExtents[0], alloc<IntImm>(1)));
+ ASSERT_TRUE(exprEquals(threadExtents[0], alloc<IntImm>(B_SIZE)));
PaddedBuffer<float> a_v(A_SIZE);
PaddedBuffer<float> b_v(B_SIZE);
});
LoopNest l({c, d});
- std::vector<For*> loops = l.getLoopStmtsFor(c);
+ std::vector<ForPtr> loops = l.getLoopStmtsFor(c);
loops[0]->set_gpu_block_index(0);
loops = l.getLoopStmtsFor(d);
loops[0]->set_gpu_block_index(1);
l.prepareForCodegen();
- Stmt* stmt = l.root_stmt();
+ StmtPtr stmt = l.root_stmt();
CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf);
std::ostringstream oss;
auto blockExtents = cuda_cg.gpu_block_extents();
auto threadExtents = cuda_cg.gpu_thread_extents();
- ASSERT_TRUE(exprEquals(blockExtents[0], new IntImm(A_SIZE)));
- ASSERT_TRUE(exprEquals(blockExtents[1], new IntImm(B_SIZE)));
+ ASSERT_TRUE(exprEquals(blockExtents[0], alloc<IntImm>(A_SIZE)));
+ ASSERT_TRUE(exprEquals(blockExtents[1], alloc<IntImm>(B_SIZE)));
PaddedBuffer<float> a_v(A_SIZE);
PaddedBuffer<float> b_v(B_SIZE);
});
LoopNest l({c, d});
- std::vector<For*> loops = l.getLoopStmtsFor(c);
+ std::vector<ForPtr> loops = l.getLoopStmtsFor(c);
loops[0]->set_gpu_block_index(0);
loops = l.getLoopStmtsFor(d);
loops[0]->set_gpu_thread_index(0);
l.prepareForCodegen();
- Stmt* stmt = l.root_stmt();
+ StmtPtr stmt = l.root_stmt();
CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf);
std::ostringstream oss;
auto blockExtents = cuda_cg.gpu_block_extents();
auto threadExtents = cuda_cg.gpu_thread_extents();
- ASSERT_TRUE(exprEquals(blockExtents[0], new IntImm(A_SIZE)));
- ASSERT_TRUE(exprEquals(threadExtents[0], new IntImm(B_SIZE)));
+ ASSERT_TRUE(exprEquals(blockExtents[0], alloc<IntImm>(A_SIZE)));
+ ASSERT_TRUE(exprEquals(threadExtents[0], alloc<IntImm>(B_SIZE)));
PaddedBuffer<float> a_v(A_SIZE);
PaddedBuffer<float> b_v(B_SIZE);
});
LoopNest l({c, d});
- std::vector<For*> loops = l.getLoopStmtsFor(c);
+ std::vector<ForPtr> loops = l.getLoopStmtsFor(c);
loops[0]->set_gpu_block_index(0);
loops[1]->set_gpu_thread_index(0);
loops = l.getLoopStmtsFor(d);
loops[1]->set_gpu_thread_index(0);
l.prepareForCodegen();
- Stmt* stmt = l.root_stmt();
+ StmtPtr stmt = l.root_stmt();
CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf);
std::ostringstream oss;
auto blockExtents = cuda_cg.gpu_block_extents();
auto threadExtents = cuda_cg.gpu_thread_extents();
- ASSERT_TRUE(exprEquals(blockExtents[0], new IntImm(OUTER_SIZE)));
- ASSERT_TRUE(exprEquals(threadExtents[0], new IntImm(A_SIZE)));
+ ASSERT_TRUE(exprEquals(blockExtents[0], alloc<IntImm>(OUTER_SIZE)));
+ ASSERT_TRUE(exprEquals(threadExtents[0], alloc<IntImm>(A_SIZE)));
PaddedBuffer<float> a_v(OUTER_SIZE, A_SIZE);
PaddedBuffer<float> b_v(OUTER_SIZE, B_SIZE);
});
LoopNest l({c, d});
- std::vector<For*> loops = l.getLoopStmtsFor(c);
+ std::vector<ForPtr> loops = l.getLoopStmtsFor(c);
loops[0]->set_gpu_block_index(0);
loops[1]->set_gpu_thread_index(0);
loops = l.getLoopStmtsFor(d);
loops[1]->set_gpu_thread_index(0);
l.prepareForCodegen();
- Stmt* stmt = l.root_stmt();
+ StmtPtr stmt = l.root_stmt();
CudaCodeGen cuda_cg(stmt, c, d, OUTER_SIZE, A_SIZE, B_SIZE, a_buf, b_buf);
std::ostringstream oss;
auto threadExtents = cuda_cg.gpu_thread_extents();
ASSERT_TRUE(exprEquals(blockExtents[0], OUTER_SIZE.node()));
ASSERT_TRUE(exprEquals(
- threadExtents[0], new Max(A_SIZE.node(), B_SIZE.node(), true)));
+ threadExtents[0], alloc<Max>(A_SIZE.node(), B_SIZE.node(), true)));
int OUTER_EXTENT = 10;
int A_EXTENT = 100;
VarHandle j("j", kInt);
VarHandle k("k", kInt);
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
i,
0,
OUTER_SIZE,
auto blockExtents = cuda_cg.gpu_block_extents();
auto threadExtents = cuda_cg.gpu_thread_extents();
- ASSERT_TRUE(exprEquals(blockExtents[0], new IntImm(OUTER_SIZE)));
- ASSERT_TRUE(exprEquals(threadExtents[0], new IntImm(A_SIZE)));
+ ASSERT_TRUE(exprEquals(blockExtents[0], alloc<IntImm>(OUTER_SIZE)));
+ ASSERT_TRUE(exprEquals(threadExtents[0], alloc<IntImm>(A_SIZE)));
PaddedBuffer<float> a_v(OUTER_SIZE, A_SIZE);
PaddedBuffer<float> b_v(OUTER_SIZE, B_SIZE);
VarHandle j("j", kInt);
VarHandle k("k", kInt);
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
i,
0,
OUTER_SIZE,
auto blockExtents = cuda_cg.gpu_block_extents();
auto threadExtents = cuda_cg.gpu_thread_extents();
- ASSERT_TRUE(exprEquals(blockExtents[0], new IntImm(1)));
- ASSERT_TRUE(exprEquals(threadExtents[0], new IntImm(A_SIZE)));
+ ASSERT_TRUE(exprEquals(blockExtents[0], alloc<IntImm>(1)));
+ ASSERT_TRUE(exprEquals(threadExtents[0], alloc<IntImm>(A_SIZE)));
PaddedBuffer<float> a_v(OUTER_SIZE, A_SIZE);
PaddedBuffer<float> b_v(OUTER_SIZE, B_SIZE);
});
LoopNest l({c, d});
- std::vector<For*> loops = l.getLoopStmtsFor(c);
+ std::vector<ForPtr> loops = l.getLoopStmtsFor(c);
loops[0]->set_gpu_block_index(0);
loops[1]->set_gpu_thread_index(0);
loops = l.getLoopStmtsFor(d);
loops[1]->set_gpu_thread_index(1);
l.prepareForCodegen();
- Stmt* stmt = l.root_stmt();
+ StmtPtr stmt = l.root_stmt();
CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf);
std::ostringstream oss;
auto blockExtents = cuda_cg.gpu_block_extents();
auto threadExtents = cuda_cg.gpu_thread_extents();
- ASSERT_TRUE(exprEquals(blockExtents[0], new IntImm(OUTER_SIZE)));
- ASSERT_TRUE(exprEquals(threadExtents[0], new IntImm(A_SIZE)));
+ ASSERT_TRUE(exprEquals(blockExtents[0], alloc<IntImm>(OUTER_SIZE)));
+ ASSERT_TRUE(exprEquals(threadExtents[0], alloc<IntImm>(A_SIZE)));
PaddedBuffer<float> a_v(OUTER_SIZE, A_SIZE);
PaddedBuffer<float> b_v(OUTER_SIZE, B_SIZE);
});
LoopNest l({c, d});
- std::vector<For*> loops = l.getLoopStmtsFor(c);
+ std::vector<ForPtr> loops = l.getLoopStmtsFor(c);
loops[0]->set_gpu_block_index(0);
loops[1]->set_gpu_thread_index(0);
loops = l.getLoopStmtsFor(d);
loops[1]->set_gpu_thread_index(0);
l.prepareForCodegen();
- Stmt* stmt = l.root_stmt();
+ StmtPtr stmt = l.root_stmt();
CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf);
std::ostringstream oss;
auto blockExtents = cuda_cg.gpu_block_extents();
auto threadExtents = cuda_cg.gpu_thread_extents();
- ASSERT_TRUE(exprEquals(blockExtents[0], new IntImm(OUTER_A_SIZE)));
- ASSERT_TRUE(exprEquals(threadExtents[0], new IntImm(A_SIZE)));
+ ASSERT_TRUE(exprEquals(blockExtents[0], alloc<IntImm>(OUTER_A_SIZE)));
+ ASSERT_TRUE(exprEquals(threadExtents[0], alloc<IntImm>(A_SIZE)));
PaddedBuffer<float> a_v(OUTER_A_SIZE, A_SIZE);
PaddedBuffer<float> b_v(OUTER_B_SIZE, B_SIZE);
ExprHandle load_a = a_buf.load(0);
VarHandle var = VarHandle("v", kFloat);
- Stmt* let_store = Let::make(var, load_a);
- Stmt* store_b = b_buf.store({0}, var);
- Block* block = Block::make({let_store, store_b});
+ StmtPtr let_store = Let::make(var, load_a);
+ StmtPtr store_b = b_buf.store({0}, var);
+ BlockPtr block = Block::make({let_store, store_b});
SimpleIREvaluator eval(block, {a_buf, b_buf});
ExprHandle load_b =
b_buf.load({Ramp::make(index * kVectorSize, 1, kVectorSize)});
ExprHandle value = load_a + load_b;
- Stmt* store_c =
+ StmtPtr store_c =
c_buf.store({Ramp::make(index * kVectorSize, 1, kVectorSize)}, value);
- Stmt* stmt = For::make(index, 0, kVectorCount, store_c);
+ StmtPtr stmt = For::make(index, 0, kVectorCount, store_c);
ASSERT_EQ(load_a.dtype(), Dtype(kFloat, kVectorSize));
ASSERT_EQ(load_b.dtype(), Dtype(kFloat, kVectorSize));
TEST(Expr, Substitute01) {
KernelScope kernel_scope;
- Var* x = new Var("x", kFloat);
- Var* y = new Var("y", kFloat);
- Expr* e = new Mul(new Sub(x, new FloatImm(1.0f)), new Add(x, y));
-
- Var* z = new Var("z", kFloat);
- Expr* e2 = Substitute(e, {{x, new Add(z, new FloatImm(5.0f))}});
- Expr* e2_ref = new Mul(
- new Sub(new Add(z, new FloatImm(5.0f)), new FloatImm(1.0f)),
- new Add(new Add(z, new FloatImm(5.0f)), y));
+ VarPtr x = alloc<Var>("x", kFloat);
+ VarPtr y = alloc<Var>("y", kFloat);
+ ExprPtr e =
+ alloc<Mul>(alloc<Sub>(x, alloc<FloatImm>(1.0f)), alloc<Add>(x, y));
+
+ VarPtr z = alloc<Var>("z", kFloat);
+ ExprPtr e2 = Substitute(e, {{x, alloc<Add>(z, alloc<FloatImm>(5.0f))}});
+ ExprPtr e2_ref = alloc<Mul>(
+ alloc<Sub>(alloc<Add>(z, alloc<FloatImm>(5.0f)), alloc<FloatImm>(1.0f)),
+ alloc<Add>(alloc<Add>(z, alloc<FloatImm>(5.0f)), y));
std::ostringstream oss;
oss << *e2;
std::string e2_str = oss.str();
Placeholder b(BufHandle("b", {n}, kFloat));
Placeholder c(BufHandle("c", {n}, kFloat));
VarHandle i("i", kInt);
- Stmt* s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i)));
+ StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i)));
std::vector<float> aData(size, 1.0f);
std::vector<float> bData(size, 2.0f);
std::vector<float> cData(size, 0.0f);
PaddedBuffer<float> a_v(N);
Placeholder a_buf("a", kFloat, {N});
VarHandle index = VarHandle("index", kInt);
- Stmt* assign_x2 = a_buf.store({index}, cast<float>(index) * 2);
- Stmt* assign_x3 = a_buf.store({index}, cast<float>(index) * 3);
+ StmtPtr assign_x2 = a_buf.store({index}, cast<float>(index) * 2);
+ StmtPtr assign_x3 = a_buf.store({index}, cast<float>(index) * 3);
ExprHandle even_cond = CompareSelect::make(Mod::make(index, 2), 0, kEQ);
- Stmt* assign = Cond::make(even_cond, assign_x2, assign_x3);
- Stmt* for_stmt = For::make(index, 0, N, assign);
+ StmtPtr assign = Cond::make(even_cond, assign_x2, assign_x3);
+ StmtPtr for_stmt = For::make(index, 0, N, assign);
SimpleIREvaluator(for_stmt, {a_buf})(a_v);
PaddedBuffer<float> a_ref(N);
Placeholder a_buf("a", kInt, {N});
VarHandle index = VarHandle("index", kInt);
- Stmt* body = a_buf.store({index}, 5);
- Stmt* loop = For::make(index, 0, N, body);
+ StmtPtr body = a_buf.store({index}, 5);
+ StmtPtr loop = For::make(index, 0, N, body);
- Stmt* cloned_loop = Stmt::clone(loop);
+ StmtPtr cloned_loop = Stmt::clone(loop);
std::vector<int> orig_loop_results(N);
std::vector<int> cloned_loop_results(N);
SimpleIREvaluator(loop, {a_buf})(orig_loop_results);
// Let's add another assign to the body in the cloned loop and verify that the
// original statement hasn't changed while the cloned one has.
- Stmt* body_addition = a_buf.store({index}, 33);
- Block* cloned_body =
- static_cast<Block*>(static_cast<For*>(cloned_loop)->body());
+ StmtPtr body_addition = a_buf.store({index}, 33);
+ BlockPtr cloned_body = static_to<Block>(static_to<For>(cloned_loop)->body());
cloned_body->append_stmt(body_addition);
std::vector<int> orig_loop_results_after_mutation(N);
return MatmulResult->load(i, j) + FloatImm::make(3.0f);
});
- Stmt* root_stmt =
- new Block({A->stmt(), B->stmt(), MatmulResult->stmt(), Result->stmt()});
+ StmtPtr root_stmt = alloc<Block>(std::vector<StmtPtr>(
+ {A->stmt(), B->stmt(), MatmulResult->stmt(), Result->stmt()}));
LoopNest l(root_stmt, {Result->buf()});
// Inlining should not inline anything here since all Bufs are either
});
LoopNest l({chunk_0, chunk_1, consumer});
- auto* body = l.root_stmt();
+ auto body = l.root_stmt();
std::stringstream ss;
ss << *body;
TEST(IRVerifier, BitwiseOps) {
KernelScope kernel_scope;
- Var* X = new Var("x", kInt);
- Var* Y = new Var("y", kFloat);
+ VarPtr X = alloc<Var>("x", kInt);
+ VarPtr Y = alloc<Var>("y", kFloat);
{
- auto a = new And(X, Y);
+ auto a = alloc<And>(X, Y);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
EXPECT_ANY_THROW(verify(a));
}
{
- auto a = new Or(X, Y);
+ auto a = alloc<Or>(X, Y);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
EXPECT_ANY_THROW(verify(a));
}
{
- auto a = new Xor(X, Y);
+ auto a = alloc<Xor>(X, Y);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
EXPECT_ANY_THROW(verify(a));
}
{
- auto a = new Lshift(X, Y);
+ auto a = alloc<Lshift>(X, Y);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
EXPECT_ANY_THROW(verify(a));
}
{
- auto a = new Rshift(X, Y);
+ auto a = alloc<Rshift>(X, Y);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
EXPECT_ANY_THROW(verify(a));
}
TEST(IRVerifier, CompareSelect) {
KernelScope kernel_scope;
- Expr* X = new IntImm(1);
- Expr* Y = new FloatImm(3.14f);
+ ExprPtr X = alloc<IntImm>(1);
+ ExprPtr Y = alloc<FloatImm>(3.14f);
{
- auto a = new CompareSelect(X, X, X, Y, kEQ);
+ auto a = alloc<CompareSelect>(X, X, X, Y, kEQ);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
EXPECT_ANY_THROW(verify(a));
}
{
- auto a = new CompareSelect(X, Y, X, X, kEQ);
+ auto a = alloc<CompareSelect>(X, Y, X, X, kEQ);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
EXPECT_ANY_THROW(verify(a));
}
TEST(IRVerifier, Ramp) {
KernelScope kernel_scope;
- Var* I = new Var("i", kInt);
- Var* J = new Var("j", kFloat);
+ VarPtr I = alloc<Var>("i", kInt);
+ VarPtr J = alloc<Var>("j", kFloat);
{
- auto a = new Ramp(I, J, 4);
+ auto a = alloc<Ramp>(I, J, 4);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
EXPECT_ANY_THROW(verify(a));
}
TEST(IRVerifier, Load) {
KernelScope kernel_scope;
- Var* I = new Var("i", kInt);
- Var* J = new Var("j", kLong);
- Var* K = new Var("k", kFloat);
- Buf* B = new Buf("b", {new IntImm(10), new IntImm(20)}, kFloat);
+ VarPtr I = alloc<Var>("i", kInt);
+ VarPtr J = alloc<Var>("j", kLong);
+ VarPtr K = alloc<Var>("k", kFloat);
+ BufPtr B = alloc<Buf>(
+ "b",
+ std::vector<ExprPtr>({alloc<IntImm>(10), alloc<IntImm>(20)}),
+ kFloat);
{
// Indices with different int dtypes (kInt, kLong) are ok
- auto a = new Load(B, {I, J});
+ auto a = alloc<Load>(B, std::vector<ExprPtr>({I, J}));
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
EXPECT_NO_THROW(verify(a));
}
{
// Float index
- auto a = new Load(B, {K, K});
+ auto a = alloc<Load>(B, std::vector<ExprPtr>({K, K}));
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
EXPECT_ANY_THROW(verify(a));
}
{
// Multilanes are only allowed in flattened indices
- auto multilane_index = new Ramp(I, new IntImm(1), 4);
- auto a = new Load(B, {I, multilane_index});
+ auto multilane_index = alloc<Ramp>(I, alloc<IntImm>(1), 4);
+ auto a = alloc<Load>(B, std::vector<ExprPtr>({I, multilane_index}));
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
EXPECT_ANY_THROW(verify(a));
}
TEST(IRVerifier, IfThenElse) {
KernelScope kernel_scope;
- Var* I = new Var("i", kInt);
- Var* J = new Var("j", kLong);
- Var* K = new Var("k", kFloat);
+ VarPtr I = alloc<Var>("i", kInt);
+ VarPtr J = alloc<Var>("j", kLong);
+ VarPtr K = alloc<Var>("k", kFloat);
{
// Condition must be integral
- auto a = new IfThenElse(K, I, I);
+ auto a = alloc<IfThenElse>(K, I, I);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
EXPECT_ANY_THROW(verify(a));
}
{
// Dtypes of true and false exprs must match
- auto a = new IfThenElse(I, I, J);
+ auto a = alloc<IfThenElse>(I, I, J);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
EXPECT_ANY_THROW(verify(a));
}
{
// Can't have multiple lanes in condition expr
- auto a = new IfThenElse(new Broadcast(I, 4), I, I);
+ auto a = alloc<IfThenElse>(alloc<Broadcast>(I, 4), I, I);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
EXPECT_ANY_THROW(verify(a));
}
TEST(IRVerifier, For) {
KernelScope kernel_scope;
- Var* I = new Var("i", kInt);
- Var* J = new Var("j", kInt);
- Stmt* body = new Block({});
+ VarPtr I = alloc<Var>("i", kInt);
+ VarPtr J = alloc<Var>("j", kInt);
+ StmtPtr body = alloc<Block>(std::vector<StmtPtr>({}));
{
// Can't have nullptr as a Var
- auto a = new For(nullptr, I, J, body);
+ auto a = alloc<For>(nullptr, I, J, body);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_ANY_THROW(verify(a));
}
TEST(IRVerifier, Block) {
KernelScope kernel_scope;
- Var* I = new Var("i", kInt);
- Buf* B = new Buf("B", {new IntImm(10)}, kInt);
+ VarPtr I = alloc<Var>("i", kInt);
+ BufPtr B = alloc<Buf>("B", std::vector<ExprPtr>({alloc<IntImm>(10)}), kInt);
{
- Stmt* store = new Store(B, {I}, I);
+ StmtPtr store = alloc<Store>(B, std::vector<ExprPtr>({I}), I);
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
- Stmt* block1 = new Block({store});
+ StmtPtr block1 = alloc<Block>(std::vector<StmtPtr>({store}));
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
- Stmt* block2 = new Block({store});
+ StmtPtr block2 = alloc<Block>(std::vector<StmtPtr>({store}));
// Stmt can't have multiple parents, thus inserting it into several blocks
// is illegal
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
TEST(IRVerifier, Store) {
KernelScope kernel_scope;
- Var* I = new Var("i", kInt);
- Var* J = new Var("j", kLong);
- Var* K = new Var("k", kFloat);
- Buf* B = new Buf("b", {new IntImm(10), new IntImm(20)}, kFloat);
+ VarPtr I = alloc<Var>("i", kInt);
+ VarPtr J = alloc<Var>("j", kLong);
+ VarPtr K = alloc<Var>("k", kFloat);
+ BufPtr B = alloc<Buf>(
+ "b",
+ std::vector<ExprPtr>({alloc<IntImm>(10), alloc<IntImm>(20)}),
+ kFloat);
{
// Indices with different int dtypes (kInt, kLong) are ok
- auto a = new Store(B, {I, J}, K);
+ auto a = alloc<Store>(B, std::vector<ExprPtr>({I, J}), K);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
EXPECT_NO_THROW(verify(a));
}
{
// Float index
- auto a = new Store(B, {K, K}, K);
+ auto a = alloc<Store>(B, std::vector<ExprPtr>({K, K}), K);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
EXPECT_ANY_THROW(verify(a));
}
{
// Multilanes are only allowed in flattened indices
- auto multilane_index = new Ramp(I, new IntImm(1), 4);
- auto a = new Store(B, {I, multilane_index}, K);
+ auto multilane_index = alloc<Ramp>(I, alloc<IntImm>(1), 4);
+ auto a = alloc<Store>(B, std::vector<ExprPtr>({I, multilane_index}), K);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
EXPECT_ANY_THROW(verify(a));
}
{
// Value and buf dtypes mismatch
- auto a = new Store(B, {I}, I);
+ auto a = alloc<Store>(B, std::vector<ExprPtr>({I}), I);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
EXPECT_ANY_THROW(verify(a));
}
auto ref = a * (a * b);
TensorExprKernel k(graph);
std::vector<at::Tensor> inputs = {a, b};
- Stmt* s = k.getCodeGenStmt();
+ StmtPtr s = k.getCodeGenStmt();
std::ostringstream oss;
oss << *s;
auto ref = a * (a * b);
TensorExprKernel k(graph);
std::vector<at::Tensor> inputs = {a, b};
- Stmt* s = k.getCodeGenStmt();
+ StmtPtr s = k.getCodeGenStmt();
std::ostringstream oss;
oss << *s;
auto ref = a * (a * b);
TensorExprKernel k(graph);
std::vector<at::Tensor> inputs = {a, b};
- Stmt* s = k.getCodeGenStmt();
+ StmtPtr s = k.getCodeGenStmt();
std::ostringstream oss;
oss << *s;
auto ref = a * (a * b);
TensorExprKernel k(graph);
std::vector<at::Tensor> inputs = {a, b};
- Stmt* s = k.getCodeGenStmt();
+ StmtPtr s = k.getCodeGenStmt();
std::ostringstream oss;
oss << *s;
auto ref = t[0] * t[1];
TensorExprKernel k(graph);
std::vector<at::Tensor> inputs = {a, b};
- Stmt* s = k.getCodeGenStmt();
+ StmtPtr s = k.getCodeGenStmt();
std::ostringstream oss;
oss << *s;
TensorExprKernel k(graph);
std::vector<at::Tensor> inputs = {a, b, c};
- Stmt* s = k.getCodeGenStmt();
+ StmtPtr s = k.getCodeGenStmt();
std::ostringstream oss;
oss << *s;
TensorExprKernel k(graph);
std::vector<at::Tensor> inputs = {a, b, c};
- Stmt* s = k.getCodeGenStmt();
+ StmtPtr s = k.getCodeGenStmt();
std::ostringstream oss;
oss << *s;
TensorExprKernel k(graph);
std::vector<at::Tensor> inputs = {a, b, c};
- Stmt* s = k.getCodeGenStmt();
+ StmtPtr s = k.getCodeGenStmt();
std::ostringstream oss;
oss << *s;
parseIR(graph_string, &*graph);
TensorExprKernel k(graph);
- Stmt* s = k.getCodeGenStmt();
+ StmtPtr s = k.getCodeGenStmt();
std::ostringstream oss;
oss << *s;
parseIR(graph_string, &*graph);
TensorExprKernel k(graph);
- Stmt* s = k.getCodeGenStmt();
+ StmtPtr s = k.getCodeGenStmt();
std::ostringstream oss;
oss << *s;
auto ref = a.sum(/*dtype=*/dtype);
TensorExprKernel k(graph);
std::vector<at::Tensor> inputs = {a};
- Stmt* s = k.getCodeGenStmt();
+ StmtPtr s = k.getCodeGenStmt();
std::ostringstream oss;
oss << *s;
auto o = at::empty({}, TensorOptions(kCPU));
TensorExprKernel k(graph);
std::vector<at::Tensor> inputs = {a};
- Stmt* s = k.getCodeGenStmt();
+ StmtPtr s = k.getCodeGenStmt();
std::ostringstream oss;
oss << *s;
TensorExprKernel k(graph);
std::vector<at::Tensor> inputs = {a};
- Stmt* s = k.getCodeGenStmt();
+ StmtPtr s = k.getCodeGenStmt();
std::ostringstream oss;
oss << *s;
TensorExprKernel k(graph);
std::vector<at::Tensor> inputs = {a};
- Stmt* s = k.getCodeGenStmt();
+ StmtPtr s = k.getCodeGenStmt();
std::ostringstream oss;
oss << *s;
TensorExprKernel k(graph);
std::vector<at::Tensor> inputs = {a};
- Stmt* s = k.getCodeGenStmt();
+ StmtPtr s = k.getCodeGenStmt();
std::ostringstream oss;
oss << *s;
TensorExprKernel k(graph);
std::vector<at::Tensor> inputs = {a};
- Stmt* s = k.getCodeGenStmt();
+ StmtPtr s = k.getCodeGenStmt();
std::ostringstream oss;
oss << *s;
parseIR(graph_string, &*graph);
TensorExprKernel k(graph);
- Stmt* s = k.getCodeGenStmt();
+ StmtPtr s = k.getCodeGenStmt();
std::ostringstream oss;
oss << *s;
parseIR(graph_string, &*graph);
TensorExprKernel k(graph);
- Stmt* s = k.getCodeGenStmt();
+ StmtPtr s = k.getCodeGenStmt();
std::ostringstream oss;
oss << *s;
VarHandle index = VarHandle("index", kInt);
ExprHandle load_a = a_buf.load(index);
- Stmt* store_b = b_buf.store({index}, fast_log(load_a));
- Stmt* stmt = For::make(index, 0, kTotalSize, store_b);
+ StmtPtr store_b = b_buf.store({index}, fast_log(load_a));
+ StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);
PaddedBuffer<float> a_v(kTotalSize);
PaddedBuffer<float> b_v(kTotalSize);
BufHandle c("c", {M, N}, kFloat);
VarHandle m("m", kInt);
VarHandle n("n", kInt);
- Stmt* s = For::make(
+ StmtPtr s = For::make(
m,
0,
M,
Placeholder c_buf(BufHandle(c->buf()));
LoopNest l({c});
- Stmt* s = l.root_stmt();
- ASSERT_TRUE(LoopNest::vectorize(
- dynamic_cast<For*>(dynamic_cast<Block*>(s)->front())));
+ StmtPtr s = l.root_stmt();
+ ASSERT_TRUE(LoopNest::vectorize(to<For>(to<Block>(s)->front())));
- ASSERT_TRUE(dynamic_cast<For*>(dynamic_cast<Block*>(s)->front()) == nullptr);
+ ASSERT_TRUE(to<For>(to<Block>(s)->front()) == nullptr);
LLVMCodeGen cg(s, {a, c_buf});
Placeholder c_buf(BufHandle(c->buf()));
LoopNest l({c});
- Stmt* s = l.root_stmt();
- ASSERT_TRUE(LoopNest::vectorize(
- dynamic_cast<For*>(dynamic_cast<Block*>(s)->front())));
- ASSERT_TRUE(dynamic_cast<For*>(dynamic_cast<Block*>(s)->front()) == nullptr);
+ StmtPtr s = l.root_stmt();
+ ASSERT_TRUE(LoopNest::vectorize(to<For>(to<Block>(s)->front())));
+ ASSERT_TRUE(to<For>(to<Block>(s)->front()) == nullptr);
LLVMCodeGen cg(s, {a, c_buf});
return cast<float>(i * i + 1);
});
LoopNest l({tensor});
- Stmt* stmt = l.root_stmt();
+ StmtPtr stmt = l.root_stmt();
Placeholder f_buf(BufHandle(tensor->buf()));
LLVMCodeGen cg(stmt, {f_buf});
Placeholder c_buf(BufHandle(c->buf()));
LoopNest l({c});
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
LLVMCodeGen cg(s, {a, b, c_buf});
Placeholder c_buf(BufHandle(c->buf()));
LoopNest l({c});
l.prepareForCodegen();
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
LLVMCodeGen cg(s, {a, b, c_buf});
Placeholder b(BufHandle("b", {n}, kFloat));
Placeholder c(BufHandle("c", {n}, kFloat));
VarHandle i("i", kInt);
- Stmt* s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i)));
+ StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i)));
std::vector<float> aData(size, 1.0f);
std::vector<float> bData(size, 2.0f);
std::vector<float> cData(size, 0.0f);
Placeholder b(BufHandle("b", {n}, kFloat));
Placeholder c(BufHandle("c", {n}, kFloat));
VarHandle i("i", kInt);
- Stmt* s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i)));
+ StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i)));
std::vector<float> aData(size, 1.0f);
std::vector<float> bData(size, 2.0f);
std::vector<float> cData(size, 0.0f);
return a.load(i) + b.load(i);
});
LoopNest l({c});
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
LLVMCodeGen cg(s, {a, b, c, n});
std::vector<float> aData(size, 1.0f);
std::vector<float> bData(size, 2.0f);
});
LoopNest l({c});
l.prepareForCodegen();
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
LLVMCodeGen cg(s, {a, b, c, m, n});
std::vector<float> aData(M * N, 1.0f);
std::vector<float> bData(M * N, 2.0f);
TEST(LLVM, EmptyStmt) {
KernelScope kernel_scope;
- Stmt* s = new Block({});
+ StmtPtr s = alloc<Block>(std::vector<StmtPtr>({}));
LLVMCodeGen cg(s, {});
cg.call({});
LoopNest l({c});
l.prepareForCodegen();
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
s = IRSimplifier::simplify(s);
LLVMCodeGen cg(s, {a, c});
std::vector<float> aData(1, 1.0f);
LoopNest loop({b});
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
LLVMCodeGen cg(s, {a, b});
Tensor* b = Reduce("sum", axis, Sum(), a, reduce_axis);
LoopNest loop({b});
- std::vector<For*> loops = loop.getLoopStmtsFor(b);
- For* loop_m = loops.at(1);
- For* loop_n = loops.at(2);
+ std::vector<ForPtr> loops = loop.getLoopStmtsFor(b);
+ ForPtr loop_m = loops.at(1);
+ ForPtr loop_n = loops.at(2);
loop.reorderAxis(loop_m, loop_n);
loops = loop.getLoopStmtsFor(b);
loop_m = loops.at(2);
loop_n = loops.at(1);
- auto b_body = const_cast<Stmt*>(loop.getAllWritesToBuf(b->buf())[1]);
+ auto b_body = loop.getAllWritesToBuf(b->buf())[1];
ASSERT_TRUE(loop.rfactor(b_body, loop_n));
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
LLVMCodeGen cg(s, {a, b});
Tensor* b = Reduce("sum", {{1, "K"}}, Sum(), a, {{M, "M"}, {N, "N"}});
LoopNest loopnest({b});
- std::vector<For*> loops = loopnest.getLoopStmtsFor(b);
+ std::vector<ForPtr> loops = loopnest.getLoopStmtsFor(b);
// Reorder n and m loops
loopnest.reorderAxis(loops.at(1), loops.at(2));
- auto b_body = const_cast<Stmt*>(loopnest.getAllWritesToBuf(b->buf()).at(1));
+ auto b_body = loopnest.getAllWritesToBuf(b->buf()).at(1);
auto all_loops = loopnest.getAllLoopNestsWritingToBuf(b->buf());
ASSERT_TRUE(all_loops.size() == 2 && all_loops[1].size() == 3);
ASSERT_TRUE(loopnest.rfactor(b_body, all_loops[1][1]));
loopnest.prepareForCodegen();
- Stmt* s = IRSimplifier::simplify(loopnest.root_stmt());
+ StmtPtr s = IRSimplifier::simplify(loopnest.root_stmt());
LLVMCodeGen cg(s, {a, b});
PaddedBuffer<float> a_v(1, M, N, "a_v");
});
LoopNest loop_nest({f});
auto const& loops = loop_nest.getLoopStmtsFor(f);
- For* m = loops[0];
- For* n = loops[1];
+ ForPtr m = loops[0];
+ ForPtr n = loops[1];
if (test_cfg & 0x1) {
m->set_parallel();
}
n->set_parallel();
}
loop_nest.prepareForCodegen();
- Stmt* stmt = loop_nest.root_stmt();
+ StmtPtr stmt = loop_nest.root_stmt();
LLVMCodeGen cg(stmt, {f});
PaddedBuffer<float> f_v(M, N, "f_v");
return t3->load(m, n) + m + n;
});
LoopNest loop_nest({t4}, {t1, t2, t3, t4});
- std::vector<For*> loop_list;
+ std::vector<ForPtr> loop_list;
{
auto const& loops = loop_nest.getLoopStmtsFor(t1);
loop_list.push_back(loops[0]);
}
}
loop_nest.prepareForCodegen();
- Stmt* stmt = loop_nest.root_stmt();
+ StmtPtr stmt = loop_nest.root_stmt();
LLVMCodeGen cg(stmt, {t4});
PaddedBuffer<float> t4_v(M, N, "t4_v");
{
auto const& loops = loop.getLoopStmtsFor(CT);
- For* m = loops[0];
+ ForPtr m = loops[0];
loop.splitWithMask(m, 16);
}
{
auto const& loops = loop.getLoopStmtsFor(CT);
- For* n = loops[2];
+ ForPtr n = loops[2];
loop.splitWithMask(n, 16);
}
// mo, mi, no, ni, k ->
// mo, no, mi, ni, k
{
auto const& loops = loop.getLoopStmtsFor(CT);
- For* mi = loops[1];
- For* no = loops[2];
+ ForPtr mi = loops[1];
+ ForPtr no = loops[2];
loop.reorderAxis(mi, no);
}
// mo, no, mi, ni, k ->
// mo, no, mi, k, ni
{
auto const& loops = loop.getLoopStmtsFor(CT);
- For* ni = loops[3];
- For* k = loops[4];
+ ForPtr ni = loops[3];
+ ForPtr k = loops[4];
loop.reorderAxis(ni, k);
}
// mo, no, mi, k, ni ->
// mo, no, k, mi, ni
{
auto const& loops = loop.getLoopStmtsFor(CT);
- For* mi = loops[2];
- For* k = loops[3];
+ ForPtr mi = loops[2];
+ ForPtr k = loops[3];
loop.reorderAxis(mi, k);
}
{
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
LLVMCodeGen cg(s, {AP, BP, CT});
LoopNest l({c});
l.prepareForCodegen();
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
int32_t N_value = 1024;
std::vector<float> av(M * N_value);
using namespace torch::jit::tensorexpr;
-void checkIR(Stmt* s, const std::string& pattern) {
+void checkIR(StmtPtr s, const std::string& pattern) {
std::ostringstream oss;
oss << *s;
torch::jit::testing::FileCheck().run(pattern, oss.str());
return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
});
LoopNest l({tensor});
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
+ std::vector<ForPtr> loops =
+ l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
LoopNest::splitWithTail(loops[0], 2);
LoopNest::splitWithTail(loops[0], 2);
return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
});
LoopNest l({tensor});
- Stmt* stmt = l.root_stmt();
+ StmtPtr stmt = l.root_stmt();
std::ostringstream oss;
oss << *stmt;
ASSERT_GT(oss.str().size(), 20);
};
Tensor* tensor = Compute("f", {{26, "x"}, {5, "y"}}, func);
LoopNest l({tensor});
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
+ std::vector<ForPtr> loops =
+ l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
LoopNest::splitWithTail(loops[0], 4);
- Stmt* stmt = l.root_stmt();
+ StmtPtr stmt = l.root_stmt();
std::ostringstream oss;
oss << *stmt;
ASSERT_GT(oss.str().size(), 200);
BufHandle f("f", {26, 5}, kFloat);
ExprHandle x_1 = x_outer * 4 + x_inner;
ExprHandle x_outer_end = (ExprHandle(26) - 0) / 4;
- For* stmt1 = For::make(
+ ForPtr stmt1 = For::make(
x_outer,
0,
x_outer_end,
4,
For::make(y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y)))));
ExprHandle x_2 = x_tail + x_outer_end * 4;
- For* stmt2 = For::make(
+ ForPtr stmt2 = For::make(
x_tail,
0,
(ExprHandle(26) - 0) % 4,
For::make(y, 0, 5, Store::make(f, {x_2, y}, func(x_2, y))));
- Stmt* stmt = Block::make({stmt1, stmt2});
+ StmtPtr stmt = Block::make({stmt1, stmt2});
std::ostringstream oss_ref;
oss_ref << *stmt;
}
}
-Block* getSimplifiedBody(const LoopNest& l) {
- Stmt* stmt = l.root_stmt();
- Stmt* simplified = IRSimplifier::simplify(stmt);
- return dynamic_cast<Block*>(simplified);
+BlockPtr getSimplifiedBody(const LoopNest& l) {
+ StmtPtr stmt = l.root_stmt();
+ StmtPtr simplified = IRSimplifier::simplify(stmt);
+ return to<Block>(simplified);
}
-void assertForRange(For* f, int expected_start, int expected_stop) {
+void assertForRange(ForPtr f, int expected_start, int expected_stop) {
ASSERT_NE(f, nullptr);
- const IntImm* start = dynamic_cast<const IntImm*>(f->start());
+ IntImmPtr start = to<IntImm>(f->start());
ASSERT_NE(start, nullptr);
ASSERT_EQ(start->value(), expected_start);
- const IntImm* stop = dynamic_cast<const IntImm*>(f->stop());
+ IntImmPtr stop = to<IntImm>(f->stop());
ASSERT_NE(stop, nullptr);
ASSERT_EQ(stop->value(), expected_stop);
}
void assertForRanges(
- Block* body,
+ BlockPtr body,
const std::vector<std::pair<int, int>>& start_stops) {
ASSERT_EQ(body->nstmts(), start_stops.size());
auto it = body->begin();
for (size_t i = 0; i < start_stops.size(); i++, it++) {
- For* loop = dynamic_cast<For*>(*it);
+ ForPtr loop = to<For>(*it);
assertForRange(loop, start_stops[i].first, start_stops[i].second);
}
}
Tensor* tensor = Compute("f", {{10, "x"}}, func);
LoopNest l({tensor});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* head;
+ ForPtr head;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* tail;
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
+ ForPtr tail;
+ std::vector<ForPtr> loops =
+ l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
loops[0]->set_gpu_block_index(LoopOptions::IDX_Y);
LoopNest::sliceHead(loops[0], 2, &head, &tail);
- Block* body = getSimplifiedBody(l);
+ BlockPtr body = getSimplifiedBody(l);
assertForRanges(body, {{0, 2}, {0, 8}});
ASSERT_TRUE(tail->loop_options().is_gpu_block_index());
Tensor* tensor = Compute("f", {{10, "x"}}, func);
LoopNest l({tensor});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* head;
+ ForPtr head;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* tail;
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
+ ForPtr tail;
+ std::vector<ForPtr> loops =
+ l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
LoopNest::sliceTail(loops[0], 4, &head, &tail);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* tail_head;
+ ForPtr tail_head;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* tail_tail;
+ ForPtr tail_tail;
tail->set_gpu_block_index(LoopOptions::IDX_Y);
LoopNest::sliceTail(tail, 2, &tail_head, &tail_tail);
- Block* body = getSimplifiedBody(l);
+ BlockPtr body = getSimplifiedBody(l);
assertForRanges(body, {{0, 6}, {0, 2}, {8, 10}});
ASSERT_TRUE(tail_head->loop_options().is_gpu_block_index());
Tensor* tensor = Compute("f", {{10, "x"}}, func);
LoopNest l({tensor});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* head;
+ ForPtr head;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* tail;
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
+ ForPtr tail;
+ std::vector<ForPtr> loops =
+ l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
LoopNest::sliceHead(loops[0], 10, &head, &tail);
ASSERT_EQ(head, loops[0]);
ASSERT_EQ(tail, nullptr);
- Block* body = getSimplifiedBody(l);
+ BlockPtr body = getSimplifiedBody(l);
assertForRanges(body, {{0, 10}});
}
Tensor* tensor = Compute("f", {{10, "x"}}, func);
LoopNest l({tensor});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* head;
+ ForPtr head;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* tail;
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
+ ForPtr tail;
+ std::vector<ForPtr> loops =
+ l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
LoopNest::sliceHead(loops[0], 100, &head, &tail);
ASSERT_EQ(head, loops[0]);
ASSERT_EQ(tail, nullptr);
- Block* body = getSimplifiedBody(l);
+ BlockPtr body = getSimplifiedBody(l);
assertForRanges(body, {{0, 10}});
}
Tensor* tensor = Compute("f", {{10, "x"}}, func);
LoopNest l({tensor});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* head;
+ ForPtr head;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* tail;
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
+ ForPtr tail;
+ std::vector<ForPtr> loops =
+ l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
LoopNest::sliceHead(loops[0], 4, &head, &tail);
ASSERT_NE(head, nullptr);
ASSERT_NE(tail, nullptr);
ASSERT_NE(tail, loops[0]);
- Block* body = getSimplifiedBody(l);
+ BlockPtr body = getSimplifiedBody(l);
assertForRanges(body, {{0, 4}, {4, 10}});
}
};
Tensor* tensor = Compute("f", {{10, "x"}}, func);
LoopNest l({tensor});
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
+ std::vector<ForPtr> loops =
+ l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* head;
+ ForPtr head;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* tail;
+ ForPtr tail;
LoopNest::sliceTail(loops[0], 4, &head, &tail);
// head: [0, 6)
// tail: [6, 10)
// tail_head: [6, 8)
// tail_tail: [8, 10)
- Block* body = getSimplifiedBody(l);
+ BlockPtr body = getSimplifiedBody(l);
assertForRanges(body, {{0, 6}, {6, 8}, {8, 10}});
}
Tensor* tensor = Compute("f", {{10, "x"}}, func);
LoopNest l({tensor});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* head;
+ ForPtr head;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* tail;
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
+ ForPtr tail;
+ std::vector<ForPtr> loops =
+ l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
LoopNest::sliceTail(loops[0], 10, &head, &tail);
ASSERT_EQ(head, nullptr);
ASSERT_EQ(tail, loops[0]);
- Block* body = getSimplifiedBody(l);
+ BlockPtr body = getSimplifiedBody(l);
assertForRanges(body, {{0, 10}});
}
Tensor* tensor = Compute("f", {{10, "x"}}, func);
LoopNest l({tensor});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* head;
+ ForPtr head;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* tail;
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
+ ForPtr tail;
+ std::vector<ForPtr> loops =
+ l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
LoopNest::sliceTail(loops[0], 100, &head, &tail);
ASSERT_EQ(head, nullptr);
ASSERT_EQ(tail, loops[0]);
- Block* body = getSimplifiedBody(l);
+ BlockPtr body = getSimplifiedBody(l);
assertForRanges(body, {{0, 10}});
}
Tensor* tensor = Compute("f", {{10, "x"}}, func);
LoopNest l({tensor});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* head;
+ ForPtr head;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* tail;
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
+ ForPtr tail;
+ std::vector<ForPtr> loops =
+ l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
LoopNest::sliceTail(loops[0], 4, &head, &tail);
ASSERT_NE(head, nullptr);
ASSERT_NE(tail, nullptr);
ASSERT_NE(tail, loops[0]);
- Block* body = getSimplifiedBody(l);
+ BlockPtr body = getSimplifiedBody(l);
assertForRanges(body, {{0, 6}, {6, 10}});
}
LoopNest l({tensor});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* inner;
+ ForPtr inner;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* tail;
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
+ ForPtr tail;
+ std::vector<ForPtr> loops =
+ l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
// outer: [0, 4)
// inner: [0, 21)
// tail: [84, 100)
// for (int x_tail = 0; x_tail < 16; x_tail++) {
// f[x_tail + 84] = 1.f + float(x_tail + 84);
// }
- Block* body = getSimplifiedBody(l);
+ BlockPtr body = getSimplifiedBody(l);
assertForRanges(body, {{0, 2}, {2, 4}, {0, 16}});
auto biter = body->begin();
- For* loop = dynamic_cast<For*>(*biter++);
+ ForPtr loop = to<For>(*biter++);
assertForRanges(loop->body(), {{0, 19}, {19, 21}});
- loop = dynamic_cast<For*>(*biter);
+ loop = to<For>(*biter);
assertForRanges(loop->body(), {{0, 19}, {19, 21}});
}
};
Tensor* tensor = Compute("f", {{10, "x"}}, func);
LoopNest l({tensor});
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
+ std::vector<ForPtr> loops =
+ l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* head;
+ ForPtr head;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* tail;
+ ForPtr tail;
LoopNest::sliceHead(loops[0], 2, &head, &tail);
// head: [0, 2)
// tail: [2, 10)
LoopNest::normalize(tail);
// normalized_tail: [0, 8)
- Block* body = getSimplifiedBody(l);
+ BlockPtr body = getSimplifiedBody(l);
assertForRanges(body, {{0, 2}, {0, 8}});
}
Tensor* tensor =
Compute("f", {{dim, "x"}}, [](const ExprHandle& x) { return x; });
LoopNest l({tensor});
- std::vector<For*> loops =
+ std::vector<ForPtr> loops =
l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* head;
+ ForPtr head;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* tail;
+ ForPtr tail;
LoopNest::sliceHead(loops[0], 2, &head, &tail);
LoopNest::sliceTail(tail, 2);
- Block* body = getSimplifiedBody(l);
+ BlockPtr body = getSimplifiedBody(l);
ASSERT_EQ(expected_for_ranges.size(), 3);
auto it = body->begin();
for (auto& start_stop : expected_for_ranges) {
- For* loop = dynamic_cast<For*>(*it++);
+ ForPtr loop = to<For>(*it++);
int start = evalExpr<int>(ExprHandle(loop->start()), dim, dimension);
int stop = evalExpr<int>(ExprHandle(loop->stop()), dim, dimension);
ASSERT_EQ(start, start_stop.first);
};
Tensor* tensor = Compute("f", {{199, "x"}}, func);
LoopNest l({tensor});
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
+ std::vector<ForPtr> loops =
+ l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
LoopNest::splitWithTail(loops[0], 17);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
LoopNest::splitWithTail(loops[0], 7);
- Stmt* stmt = l.root_stmt();
- Stmt* simplified = IRSimplifier::simplify(stmt);
- Block* body = dynamic_cast<Block*>(simplified);
+ StmtPtr stmt = l.root_stmt();
+ StmtPtr simplified = IRSimplifier::simplify(stmt);
+ BlockPtr body = to<Block>(simplified);
ASSERT_EQ(body->nstmts(), 3);
auto biter = body->begin();
// Verify that the split loops are ordered correctly.
- For* loop = dynamic_cast<For*>(*biter++);
+ ForPtr loop = to<For>(*biter++);
assertForRange(loop, 0, 7);
- loop = dynamic_cast<For*>(*biter++);
+ loop = to<For>(*biter++);
assertForRange(loop, 0, 4);
- loop = dynamic_cast<For*>(*biter);
+ loop = to<For>(*biter);
assertForRange(loop, 0, 12);
}
};
Tensor* tensor = Compute("f", {{24, "x"}, {5, "y"}}, func);
LoopNest l({tensor});
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
+ std::vector<ForPtr> loops =
+ l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
LoopNest::splitWithTail(loops[0], 4);
- Stmt* stmt = l.root_stmt();
+ StmtPtr stmt = l.root_stmt();
std::ostringstream oss;
oss << *stmt;
ASSERT_GT(oss.str().size(), 200);
BufHandle f("f", {24, 5}, kFloat);
ExprHandle x_1 = x_outer * 4 + x_inner;
ExprHandle x_outer_end = (ExprHandle(24) - 0) / 4;
- Stmt* stmt = new Block({For::make(
+ StmtPtr stmt = alloc<Block>(std::vector<StmtPtr>({For::make(
x_outer,
0,
x_outer_end,
x_inner,
0,
4,
- For::make(y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y)))))});
+ For::make(y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y)))))}));
std::ostringstream oss_ref;
oss_ref << *stmt;
});
LoopNest l({tensor});
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
+ std::vector<ForPtr> loops =
+ l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
LoopNest::splitWithMask(loops[1], 4);
- Stmt* stmt = l.root_stmt();
+ StmtPtr stmt = l.root_stmt();
PaddedBuffer<float> a_v(M, N, "a");
PaddedBuffer<float> b_v(M, N, "b");
});
LoopNest l({tensor});
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
+ std::vector<ForPtr> loops =
+ l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
LoopNest::splitWithMask(loops[0], 4);
LoopNest::splitWithMask(loops[0], 4);
- Stmt* stmt1 = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr stmt1 = IRSimplifier::simplify(l.root_stmt());
// Two splits mean 3 loops, but should need no masks in this case.
checkIR(stmt1, R"IR(
// }
// }
// }
- Buf* A = new Buf("A", {new IntImm(100), new IntImm(100)}, kInt);
- Buf* B =
- new Buf("B", {new IntImm(100), new IntImm(100), new IntImm(200)}, kInt);
- Buf* C =
- new Buf("C", {new IntImm(100), new IntImm(100), new IntImm(300)}, kInt);
+ BufPtr A = alloc<Buf>(
+ "A",
+ std::vector<ExprPtr>({alloc<IntImm>(100), alloc<IntImm>(100)}),
+ kInt);
+ BufPtr B = alloc<Buf>(
+ "B",
+ std::vector<ExprPtr>(
+ {alloc<IntImm>(100), alloc<IntImm>(100), alloc<IntImm>(200)}),
+ kInt);
+ BufPtr C = alloc<Buf>(
+ "C",
+ std::vector<ExprPtr>(
+ {alloc<IntImm>(100), alloc<IntImm>(100), alloc<IntImm>(300)}),
+ kInt);
BufHandle a_buf(A);
BufHandle b_buf(B);
BufHandle c_buf(C);
});
LoopNest l({tensor});
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
+ std::vector<ForPtr> loops =
+ l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
l.tile(loops[0], loops[1], 4, 8);
// IR check
- Stmt* stmt = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr stmt = IRSimplifier::simplify(l.root_stmt());
checkIR(stmt, R"IR(
# CHECK: for (int m_outer
# CHECK: for (int n_outer
});
LoopNest l({tensor});
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
+ std::vector<ForPtr> loops =
+ l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
l.tile(loops[0], loops[1], 5, 9);
// IR check
- Stmt* stmt = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr stmt = IRSimplifier::simplify(l.root_stmt());
checkIR(stmt, R"IR(
# CHECK: for (int m_outer
# CHECK: for (int n_outer
});
LoopNest nest({tensor});
- std::vector<For*> loops =
+ std::vector<ForPtr> loops =
nest.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
nest.tile(loops[1], loops[2], 3, 3);
// IR check
- Stmt* stmt = IRSimplifier::simplify(nest.root_stmt());
+ StmtPtr stmt = IRSimplifier::simplify(nest.root_stmt());
checkIR(stmt, R"IR(
# CHECK: for (int m
# CHECK: for (int n_outer
return a_buf.load(m) + b_buf.load(m) + 1.0f;
});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For *inner, *tail;
+ ForPtr inner, tail;
LoopNest l({tensor});
auto loops = NodeFinder<For>::find(l.root_stmt());
LoopNest::splitWithTail(loops[0], 4, &inner, &tail);
ASSERT_NE(inner, nullptr);
ASSERT_NE(tail, nullptr);
- For* outer = loops[0];
+ ForPtr outer = loops[0];
// Outer loop carries loop axis bindings.
ASSERT_TRUE(outer->loop_options().is_gpu_block_index());
return a_buf.load(m) + b_buf.load(m) + 1.0f;
});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* inner;
+ ForPtr inner;
LoopNest l({tensor});
auto loops = NodeFinder<For>::find(l.root_stmt());
loops[0]->set_gpu_block_index(LoopOptions::IDX_Y);
LoopNest::splitWithMask(loops[0], 4, &inner);
- For* outer = loops[0];
+ ForPtr outer = loops[0];
// Outer loop carries loop axis bindings.
ASSERT_TRUE(outer->loop_options().is_gpu_block_index());
return a_buf.load(m, n) + b_buf.load(n, k);
});
LoopNest l({c});
- Stmt* stmt = l.root_stmt();
+ StmtPtr stmt = l.root_stmt();
PaddedBuffer<float> a_v(M, N, "a_v");
for (int m = 0; m < M; m++) {
LoopNest l({d}, {c, d});
l.prepareForCodegen();
- Stmt* stmt = l.root_stmt();
+ StmtPtr stmt = l.root_stmt();
std::ostringstream oss;
oss << *stmt;
ASSERT_GT(oss.str().size(), 100);
l1.prepareForCodegen();
l2.prepareForCodegen();
- Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt());
- Stmt* stmt2 = IRSimplifier::simplify(l2.root_stmt());
+ StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
+ StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt());
SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, c_buf, d_buf, y});
SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, c_buf, d_buf, y});
}
}
l.prepareForCodegen();
- Stmt* stmt = l.root_stmt();
+ StmtPtr stmt = l.root_stmt();
std::ostringstream oss;
oss << *stmt;
});
LoopNest l2({z2});
l2.prepareForCodegen();
- Stmt* stmt2 = l2.root_stmt();
+ StmtPtr stmt2 = l2.root_stmt();
std::ostringstream oss2;
oss2 << *stmt2;
// would normally compare results but Rand isn't implemented in the
// SimpleIREvaluator, even if we could seed it.
- Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt());
+ StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
// Check the IR we produced
checkIR(stmt1, R"IR(
// would normally compare results but Rand isn't implemented in the
// SimpleIREvaluator, even if we could seed it.
- Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt());
+ StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
// Check the IR we produced
checkIR(stmt1, R"IR(
// would normally compare results but Rand isn't implemented in the
// SimpleIREvaluator, even if we could seed it.
- Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt());
+ StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
// Check the IR we produced
checkIR(stmt1, R"IR(
l1.prepareForCodegen();
l2.prepareForCodegen();
- Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt());
- Stmt* stmt2 = IRSimplifier::simplify(l2.root_stmt());
+ StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
+ StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt());
SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y});
SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y});
LoopNest l1({y}, {x, y});
l1.computeInline(x->buf());
- Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt());
+ StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
// Check the IR we produced
checkIR(stmt1, R"IR(
});
LoopNest l({b}, {a, b});
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0);
+ std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0);
LoopNest::splitWithMask(loops[0], 4);
ASSERT_THROWS_WITH(l.computeInline(a->buf()), "compound indices");
}
});
LoopNest l({b}, {a, b});
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(b->buf()).at(0);
+ std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(b->buf()).at(0);
LoopNest::splitWithMask(loops[0], 3);
l.computeInline(a->buf());
l.prepareForCodegen();
- Stmt* s = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr s = IRSimplifier::simplify(l.root_stmt());
std::vector<int> output(6, 0);
SimpleIREvaluator eval(s, {b});
return a->load(j + ExprHandle(8));
});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* i_inner;
+ ForPtr i_inner;
LoopNest l({b}, {a, b});
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0);
+ std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0);
LoopNest::splitWithMask(loops[0], 4, &i_inner);
LoopNest::splitWithMask(i_inner, 2);
ASSERT_THROWS_WITH(l.computeInline(a->buf()), "compound indices");
LoopNest l({b}, {a, b});
l.computeInline(a->buf());
- std::vector<For*> loops = NodeFinder<For>::find(l.root_stmt());
+ std::vector<ForPtr> loops = NodeFinder<For>::find(l.root_stmt());
LoopNest::splitWithMask(loops.back(), 3);
l.prepareForCodegen();
- Stmt* s = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr s = IRSimplifier::simplify(l.root_stmt());
std::vector<int> output(6, 0);
SimpleIREvaluator eval(s, {b});
eval(output);
loops = NodeFinder<For>::find(l.root_stmt());
LoopNest::splitWithMask(loops.front(), 2);
l.prepareForCodegen();
- Stmt* s = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr s = IRSimplifier::simplify(l.root_stmt());
std::vector<int> output(16, 0);
SimpleIREvaluator eval(s, {b});
eval(output);
});
LoopNest l({b}, {a, b});
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0);
+ std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0);
LoopNest::splitWithMask(loops[0], 4);
ASSERT_THROWS_WITH(l.computeInline(a->buf()), "compound indices");
}
});
LoopNest l({c}, {a, b, c});
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0);
+ std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0);
l.computeInline(a->buf());
l.prepareForCodegen();
- Stmt* s = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr s = IRSimplifier::simplify(l.root_stmt());
std::vector<int> output(4 * 3, 0);
SimpleIREvaluator eval(s, {c});
eval(output);
});
LoopNest l({c}, {a, b, c});
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0);
+ std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0);
l.computeInline(a->buf());
l.computeInline(b->buf());
l.prepareForCodegen();
- Stmt* s = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr s = IRSimplifier::simplify(l.root_stmt());
std::vector<int> output(4 * 3, 0);
SimpleIREvaluator eval(s, {c});
eval(output);
});
LoopNest l({c}, {a, b, c});
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0);
+ std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0);
l.computeInline(b->buf());
l.prepareForCodegen();
- Stmt* s = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr s = IRSimplifier::simplify(l.root_stmt());
std::vector<int> output(4 * 3, 0);
SimpleIREvaluator eval(s, {c});
eval(output);
});
LoopNest l({c}, {a, b, c});
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0);
+ std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0);
LoopNest::splitWithMask(loops[0], 4);
loops = l.getAllLoopNestsWritingToBuf(b->buf()).at(0);
LoopNest::splitWithMask(loops[0], 3);
// would normally compare results but Rand isn't implemented in the
// SimpleIREvaluator, even if we could seed it.
- Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt());
+ StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
// Check the IR we produced
checkIR(stmt1, R"IR(
LoopNest l({b, c});
l.prepareForCodegen();
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
std::vector<float> a_data(kTotalSize, 7.0f);
std::vector<float> b_data(kTotalSize, 0.0f);
l.computeInline(l.getLoopBodyFor(e));
l.computeInline(l.getLoopBodyFor(f));
l.prepareForCodegen();
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
std::vector<float> a_data(kTotalSize, 1.0f);
std::vector<float> b_data(kTotalSize, 2.0f);
return a.load(i, j) + b.load(i, j);
});
LoopNest l({c});
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
SimpleIREvaluator cg(s, {a, b, c, m, n});
std::vector<float> aData(M * N, 1.0f);
std::vector<float> bData(M * N, 2.0f);
Tensor* B = Compute(
"B", {{N, "i_b"}}, [&](const VarHandle& i_b) { return A->load(i_b); });
LoopNest l({B}, {A, B});
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(B->buf()).at(0);
+ std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(B->buf()).at(0);
LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]);
l.prepareForCodegen();
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
checkIR(s, R"IR(
# CHECK: Allocate(temp); // dtype=int, dims=[1]
{
// First let's try to compute P at axis cy (the outer loop)
LoopNest l(orig_loopnest);
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(c->buf()).at(0);
+ std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(c->buf()).at(0);
LoopNest::computeAt(l.getLoopBodyFor(p), loops[0]);
l.prepareForCodegen();
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
// Check the IR we produced
checkIR(s, R"IR(
{
// Now let's try to compute P at axis cx (the inner loop)
LoopNest l(orig_loopnest);
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(c->buf()).at(0);
+ std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(c->buf()).at(0);
LoopNest::computeAt(l.getLoopBodyFor(p), loops[1]);
l.prepareForCodegen();
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
// Check the IR we produced
checkIR(s, R"IR(
{
// First let's try to compute A at axis dy (the outer loop)
LoopNest l(orig_loopnest);
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(D->buf()).at(0);
+ std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(D->buf()).at(0);
LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]);
l.prepareForCodegen();
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
// Check the IR we produced
checkIR(s, R"IR(
{
// Now let's try to compute A at axis dx (the inner loop)
LoopNest l(orig_loopnest);
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(D->buf()).at(0);
+ std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(D->buf()).at(0);
LoopNest::computeAt(l.getLoopBodyFor(A), loops[1]);
l.prepareForCodegen();
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
// Check the IR we produced
checkIR(s, R"IR(
# CHECK: }
# CHECK: Free(temp);
)IR");
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
// Now check that the loop still produces the correct result.
std::vector<int> c_data(kW * kH, 0);
{
// Now let's try to compute P at axis cx (the inner loop)
LoopNest l(orig_loopnest);
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(c->buf()).at(0);
+ std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(c->buf()).at(0);
LoopNest::computeAt(l.getLoopBodyFor(p), loops[1]);
l.simplify();
l.eliminateDeadStores();
# CHECK: }
# CHECK: Free(temp);
)IR");
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
// Now check that the loop still produces the correct result.
std::vector<int> c_data(kW * kH, 0);
# CHECK: }
# CHECK: }
)IR");
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(B->buf()).at(0);
+ std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(B->buf()).at(0);
LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]);
// FIXME: The current IR is totally broken. The body of the inlined loop is:
l.simplify();
l.prepareForCodegen();
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
SimpleIREvaluator cg(s, {IP, B});
// auto At = at::ones({N, H}, at::kFloat);
std::stringstream ordering;
public:
- std::string getOrder(Stmt* s) {
+ std::string getOrder(StmtPtr s) {
ordering.str("");
s->accept(this);
return ordering.str();
}
  // NOLINTNEXTLINE(cppcoreguidelines-explicit-virtual-functions,modernize-use-override)
- void visit(For* v) {
+ void visit(ForPtr v) {
ordering << v->var()->name_hint() << ",";
IRVisitor::visit(v);
}
return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
});
LoopNest l({tensor});
- Stmt* stmt1 = Stmt::clone(l.root_stmt());
+ StmtPtr stmt1 = Stmt::clone(l.root_stmt());
std::vector<int> stmt1_output(6, 0);
SimpleIREvaluator cg(stmt1, {tensor});
auto loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
LoopNest::reorderAxis(loops[0], loops[1]);
- Stmt* stmt2 = Stmt::clone(l.root_stmt());
+ StmtPtr stmt2 = Stmt::clone(l.root_stmt());
ASSERT_NE(stmt1, stmt2);
LoopOrderHelper loopOrderHelper;
// Reorder them back.
loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
LoopNest::reorderAxis(loops[0], loops[1]);
- Stmt* stmt3 = l.root_stmt();
+ StmtPtr stmt3 = l.root_stmt();
std::string order3 = loopOrderHelper.getOrder(stmt3);
ASSERT_EQ(order3, order1);
LoopNest l({tensor});
LoopOrderHelper loopOrderHelper;
- Stmt* stmt1 = Stmt::clone(l.root_stmt());
+ StmtPtr stmt1 = Stmt::clone(l.root_stmt());
ASSERT_EQ(loopOrderHelper.getOrder(stmt1), "x,y,z,");
std::vector<int> stmt1_output(24, 0);
LoopNest::reorderAxis(loops[0], loops[1]);
ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "y,x,z,");
- Stmt* stmt2 = Stmt::clone(l.root_stmt());
+ StmtPtr stmt2 = Stmt::clone(l.root_stmt());
std::vector<int> stmt2_output(24, 0);
SimpleIREvaluator cg2(stmt2, {tensor});
LoopNest::reorderAxis(loops[1], loops[2]);
ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "y,z,x,");
- Stmt* stmt3 = Stmt::clone(l.root_stmt());
+ StmtPtr stmt3 = Stmt::clone(l.root_stmt());
std::vector<int> stmt3_output(24, 0);
SimpleIREvaluator cg3(stmt3, {tensor});
LoopNest l({tensor});
LoopOrderHelper loopOrderHelper;
- Stmt* stmt1 = Stmt::clone(l.root_stmt());
+ StmtPtr stmt1 = Stmt::clone(l.root_stmt());
ASSERT_EQ(loopOrderHelper.getOrder(stmt1), "w,x,y,z,");
std::vector<int> stmt1_output(24, 0);
LoopNest::reorderAxis(loops[2], loops[1]);
ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "w,y,x,z,");
- Stmt* stmt2 = l.root_stmt();
+ StmtPtr stmt2 = l.root_stmt();
std::vector<int> stmt2_output(24, 0);
SimpleIREvaluator cg2(stmt2, {tensor});
LoopNest l({tensor});
LoopOrderHelper loopOrderHelper;
- Stmt* stmt1 = Stmt::clone(l.root_stmt());
+ StmtPtr stmt1 = Stmt::clone(l.root_stmt());
std::vector<int> stmt1_output(24, 0);
SimpleIREvaluator cg(stmt1, {tensor});
LoopNest::reorderAxis(loops[0], loops[3]);
ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "z,x,y,w,");
- Stmt* stmt2 = l.root_stmt();
+ StmtPtr stmt2 = l.root_stmt();
std::vector<int> stmt2_output(24, 0);
SimpleIREvaluator cg2(stmt2, {tensor});
return ExprHandle(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
});
LoopNest l({tensor});
- Stmt* stmt1 = Stmt::clone(l.root_stmt());
+ StmtPtr stmt1 = Stmt::clone(l.root_stmt());
auto loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
LoopNest::reorderAxis(loops[1], loops[1]);
- Stmt* stmt2 = Stmt::clone(l.root_stmt());
+ StmtPtr stmt2 = Stmt::clone(l.root_stmt());
std::ostringstream oss, oss2;
oss << *stmt1;
VarHandle i = VarHandle(loops[0]->var());
- Stmt* store_1 = Store::make(BufHandle(extra.data()), {i, 0}, ExprHandle(1.f));
- Stmt* store_2 = Store::make(BufHandle(extra.data()), {i, 1}, ExprHandle(2.f));
+ StmtPtr store_1 =
+ Store::make(BufHandle(extra.data()), {i, 0}, ExprHandle(1.f));
+ StmtPtr store_2 =
+ Store::make(BufHandle(extra.data()), {i, 1}, ExprHandle(2.f));
// stmt 3 is the Function body.
- Stmt* store_3 = Store::make(BufHandle(extra.data()), {i, 2}, ExprHandle(4.f));
+ StmtPtr store_3 =
+ Store::make(BufHandle(extra.data()), {i, 2}, ExprHandle(4.f));
loops[0]->body()->prepend_stmt(store_1);
loops[1]->body()->prepend_stmt(store_2);
loops[1]->body()->append_stmt(store_3);
- Stmt* stmt1 = Stmt::clone(l.root_stmt());
+ StmtPtr stmt1 = Stmt::clone(l.root_stmt());
std::vector<int> extra1(6, 0);
std::vector<int> res1(24, 0);
*/
LoopNest::reorderAxis(loops[1], loops[2]);
- Stmt* stmt2 = Stmt::clone(l.root_stmt());
+ StmtPtr stmt2 = Stmt::clone(l.root_stmt());
// Check the IR we produced
checkIR(stmt2, R"IR(
*/
loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0);
LoopNest::reorderAxis(loops[0], loops[2]);
- Stmt* stmt3 = Stmt::clone(l.root_stmt());
+ StmtPtr stmt3 = Stmt::clone(l.root_stmt());
// Check the IR we produced
checkIR(stmt3, R"IR(
auto loops = l.getAllLoopNestsWritingToBuf(c->buf()).at(0);
int j = 0;
- for (auto* l : loops) {
+ for (auto l : loops) {
// Add an increment at each layer of the loop which counts the number of
// times the loop executes.
- Load* load = new Load(extra.data(), {new IntImm(j)});
- Add* add = new Add(load, new IntImm(1));
- Stmt* store = new Store(extra.data(), {new IntImm(j)}, add);
+ LoadPtr load =
+ alloc<Load>(extra.data(), std::vector<ExprPtr>({alloc<IntImm>(j)}));
+ AddPtr add = alloc<Add>(load, alloc<IntImm>(1));
+ StmtPtr store = alloc<Store>(
+ extra.data(), std::vector<ExprPtr>({alloc<IntImm>(j)}), add);
if (prepend) {
l->body()->prepend_stmt(store);
}
j++;
}
- Stmt* stmt1 = Stmt::clone(l.root_stmt());
+ StmtPtr stmt1 = Stmt::clone(l.root_stmt());
std::vector<int> extra1(5, 0);
std::vector<int> res1(2 * 3 * 2 * 3 * 2, 0);
loops = l.getAllLoopNestsWritingToBuf(c->buf()).at(0);
LoopNest::reorderAxis(loops[index1], loops[index2]);
- Stmt* stmt2 = Stmt::clone(l.root_stmt());
+ StmtPtr stmt2 = Stmt::clone(l.root_stmt());
std::ostringstream oss, oss2;
oss << *stmt1;
});
LoopNest l({z}, {x, y, z});
- For* a = nullptr;
- For* b = nullptr;
+ ForPtr a = nullptr;
+ ForPtr b = nullptr;
auto fors = NodeFinder<For>::find(l.root_stmt());
- for (auto* f : fors) {
+ for (auto f : fors) {
if (f->var()->name_hint() == "m2") {
a = f;
} else if (f->var()->name_hint() == "k2") {
LoopNest::reorderAxis(a, b);
l.prepareForCodegen();
- Stmt* stmt = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr stmt = IRSimplifier::simplify(l.root_stmt());
// Check the IR we produced has the 3 nests in the right order, but k and m
// swapped in the middle.
ASSERT_TRUE(
LoopNest::vectorize(l.getAllLoopNestsWritingToBuf(tensor->buf())[0][0]));
- Stmt* root_stmt = l.root_stmt();
- Block* outer_block = dynamic_cast<Block*>(root_stmt);
+ StmtPtr root_stmt = l.root_stmt();
+ BlockPtr outer_block = to<Block>(root_stmt);
ASSERT_NE(outer_block, nullptr);
- while (Block* inner_block = dynamic_cast<Block*>(outer_block->front())) {
+ while (BlockPtr inner_block = to<Block>(outer_block->front())) {
outer_block = inner_block;
}
// Verify that we have only a single loop level remaining after
// vectorization.
ASSERT_EQ(outer_block->nstmts(), 1);
- For* for_loop = dynamic_cast<For*>(outer_block->front());
+ ForPtr for_loop = to<For>(outer_block->front());
ASSERT_NE(for_loop, nullptr);
- Block* for_body = for_loop->body();
+ BlockPtr for_body = for_loop->body();
ASSERT_EQ(for_body->nstmts(), 1);
- ASSERT_EQ(dynamic_cast<For*>(for_body->front()), nullptr);
+ ASSERT_EQ(to<For>(for_body->front()), nullptr);
}
TEST(LoopNest, VectorizeLoopNotNormalized) {
ASSERT_TRUE(LoopNest::vectorize(inner_for));
ASSERT_EQ(outer_for->body()->nstmts(), 1);
- ASSERT_EQ(dynamic_cast<For*>(outer_for->body()->front()), nullptr);
+ ASSERT_EQ(to<For>(outer_for->body()->front()), nullptr);
}
namespace {
Tensor* A = Compute(
"A", {{upper_bound, "x"}}, [&](const VarHandle& x) { return x * 2; });
LoopNest l({A});
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(A->buf())[0];
- Stmt* unrolled = nullptr;
+ std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A->buf())[0];
+ StmtPtr unrolled = nullptr;
LoopNest::unroll(loops[0], &unrolled);
std::ostringstream oss;
oss << *unrolled;
{{outer_bound, "x"}, {inner_bound, "y"}},
[&](const VarHandle& x, const VarHandle& y) { return x + y; });
LoopNest l({A});
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(A->buf())[0];
- Stmt* unrolled = nullptr;
+ std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A->buf())[0];
+ StmtPtr unrolled = nullptr;
LoopNest::unroll(loops[0], &unrolled);
checkIR(unrolled, R"IR(
# CHECK: for (int y = 0; y < 4; y++) {
{{outer_bound, "x"}, {inner_bound, "y"}},
[&](const VarHandle& x, const VarHandle& y) { return x + y; });
LoopNest l({A});
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(A->buf())[0];
- Stmt* unrolled = nullptr;
+ std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A->buf())[0];
+ StmtPtr unrolled = nullptr;
LoopNest::unroll(
- static_cast<For*>(loops[0]->body()->stmts().front()), &unrolled);
+ static_to<For>(loops[0]->body()->stmts().front()), &unrolled);
checkIR(loops[0], R"IR(
# CHECK: for (int x = 0; x < 3; x++) {
# CHECK: A[x, 0] = x;
{Store::make(a_buf, {x}, x * 2),
Store::make(b_buf, {x}, Load::make(a_buf, {x}))}));
Block::make({f});
- Stmt* unrolled = nullptr;
+ StmtPtr unrolled = nullptr;
LoopNest::unroll(f, &unrolled);
checkIR(unrolled, R"IR(
# CHECK: A[0] = 0;
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
auto b = Block::make({outer_for});
- std::vector<For*> loops = {outer_for, inner_for};
- Stmt* unrolled = nullptr;
+ std::vector<ForPtr> loops = {outer_for, inner_for};
+ StmtPtr unrolled = nullptr;
LoopNest::unroll(loops[0], &unrolled);
checkIR(unrolled, R"IR(
# CHECK: for (int j = 0; j < 4; j++) {
Tensor* A = Compute(
"A", {{upper_bound, "x"}}, [&](const VarHandle& x) { return x * 2; });
LoopNest l({A});
- std::vector<For*> loops = l.getAllLoopNestsWritingToBuf(A->buf())[0];
- Stmt* unrolled = nullptr;
+ std::vector<ForPtr> loops = l.getAllLoopNestsWritingToBuf(A->buf())[0];
+ StmtPtr unrolled = nullptr;
ASSERT_THROWS_WITH(
LoopNest::unroll(loops[0], &unrolled), "non-constant loop");
}
Store::make(a_buf, {x}, e),
Store::make(b_buf, {x}, e + 1)}));
Block::make({f});
- Stmt* unrolled = nullptr;
+ StmtPtr unrolled = nullptr;
LoopNest::unroll(f, &unrolled);
std::ostringstream oss;
oss << *unrolled;
Block::make({for_stmt});
ASSERT_FALSE(LoopNest::isNormalized(for_stmt));
- for_stmt->set_start(new IntImm(0));
+ for_stmt->set_start(alloc<IntImm>(0));
ASSERT_TRUE(LoopNest::isNormalized(for_stmt));
VarHandle N("N", kInt);
LoopNest::normalize(for_stmt);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* x_inner;
+ ForPtr x_inner;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* x_tail;
+ ForPtr x_tail;
LoopNest::splitWithTail(for_stmt, 10, &x_inner, &x_tail);
auto x_outer_result = IRSimplifier::simplify(for_stmt);
auto outer_for = For::make(i, 0, 10, inner_for);
Block::make({outer_for});
- std::vector<For*> loops = {outer_for, inner_for};
- For* flattened = nullptr;
+ std::vector<ForPtr> loops = {outer_for, inner_for};
+ ForPtr flattened = nullptr;
ASSERT_TRUE(LoopNest::flatten(loops, &flattened));
ASSERT_EQ(flattened, loops.front());
auto for3 = For::make(i, 0, 10, for2);
Block::make({for3});
- std::vector<For*> loops = {for3, for2, for1};
- For* flattened = nullptr;
+ std::vector<ForPtr> loops = {for3, for2, for1};
+ ForPtr flattened = nullptr;
ASSERT_TRUE(LoopNest::flatten(loops, &flattened));
ASSERT_EQ(flattened, loops.front());
auto outer_for = For::make(i, 2, 10, inner_for);
Block::make({outer_for});
- std::vector<For*> loops = {outer_for, inner_for};
- For* flattened = nullptr;
+ std::vector<ForPtr> loops = {outer_for, inner_for};
+ ForPtr flattened = nullptr;
ASSERT_TRUE(LoopNest::flatten(loops, &flattened));
ASSERT_EQ(flattened, loops.front());
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
auto b = Block::make({outer_for});
- std::vector<For*> loops = {outer_for, inner_for};
- For* flattened = nullptr;
+ std::vector<ForPtr> loops = {outer_for, inner_for};
+ ForPtr flattened = nullptr;
ASSERT_TRUE(LoopNest::flatten(loops, &flattened));
ASSERT_EQ(flattened, loops.front());
HashProvider hasher;
auto hash_before = hasher.hash(par);
- std::vector<For*> loops = {outer_for, inner_for};
- For* flattened = nullptr;
+ std::vector<ForPtr> loops = {outer_for, inner_for};
+ ForPtr flattened = nullptr;
ASSERT_FALSE(LoopNest::flatten(loops, &flattened));
ASSERT_EQ(flattened, nullptr);
auto hash_after = hasher.hash(par);
HashProvider hasher;
auto hash_before = hasher.hash(par);
- std::vector<For*> loops = {outer_for, inner_for};
- For* flattened = nullptr;
+ std::vector<ForPtr> loops = {outer_for, inner_for};
+ ForPtr flattened = nullptr;
ASSERT_FALSE(LoopNest::flatten(loops, &flattened));
ASSERT_EQ(flattened, nullptr);
auto hash_after = hasher.hash(par);
auto hash_before = hasher.hash(loop.root_stmt());
auto loops = loop.getAllLoopNestsWritingToBuf(c->buf())[1];
- For* flattened = nullptr;
+ ForPtr flattened = nullptr;
ASSERT_FALSE(LoopNest::flatten(loops, &flattened));
ASSERT_EQ(flattened, nullptr);
auto hash_after = hasher.hash(loop.root_stmt());
HashProvider hasher;
auto hash_before = hasher.hash(par);
- std::vector<For*> loops = {outer_for1, inner_for2};
- For* flattened = nullptr;
+ std::vector<ForPtr> loops = {outer_for1, inner_for2};
+ ForPtr flattened = nullptr;
ASSERT_FALSE(LoopNest::flatten(loops, &flattened));
ASSERT_EQ(flattened, nullptr);
auto hash_after = hasher.hash(par);
});
LoopNest l({B, C}, {A, B, C});
- Stmt* j_loop = l.getAllLoopNestsWritingToBuf(B->buf())[0][1];
+ StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B->buf())[0][1];
LoopNest::cacheAccesses(A->buf(), "A_local", j_loop);
l.prepareForCodegen();
- Stmt* result = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr result = IRSimplifier::simplify(l.root_stmt());
// just this once: verify the whole thing.
checkIR(result, R"IR(
});
LoopNest l({B, C}, {A, B, C});
- Stmt* i_loop = l.getAllLoopNestsWritingToBuf(B->buf())[0][0];
+ StmtPtr i_loop = l.getAllLoopNestsWritingToBuf(B->buf())[0][0];
LoopNest::cacheAccesses(A->buf(), "A_local", i_loop);
l.prepareForCodegen();
- Stmt* result = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr result = IRSimplifier::simplify(l.root_stmt());
checkIR(result, R"IR(
#CHECK: Allocate(A_local); // dtype=int, dims=[21, 11]
});
LoopNest l({B, C}, {A, B, C});
- Stmt* j_loop = l.getAllLoopNestsWritingToBuf(B->buf())[0][1];
+ StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B->buf())[0][1];
LoopNest::cacheAccesses(A->buf(), "A_local", j_loop);
l.prepareForCodegen();
- Stmt* result = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr result = IRSimplifier::simplify(l.root_stmt());
checkIR(result, R"IR(
#CHECK: Allocate(A_local); // dtype=int, dims=[2, 11]
});
LoopNest l({B, C}, {A, B, C});
- Stmt* body = l.getLoopBodyFor(B);
+ StmtPtr body = l.getLoopBodyFor(B);
LoopNest::cacheAccesses(A->buf(), "A_local", body);
l.prepareForCodegen();
- Stmt* result = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr result = IRSimplifier::simplify(l.root_stmt());
checkIR(result, R"IR(
#CHECK: Allocate(A_local); // dtype=int, dims=[5, 2]
});
LoopNest l({B, C}, {A, B, C});
- Stmt* a_loop = l.getAllLoopNestsWritingToBuf(A->buf())[0][1];
+ StmtPtr a_loop = l.getAllLoopNestsWritingToBuf(A->buf())[0][1];
LoopNest::cacheAccesses(A->buf(), "A_local", a_loop);
l.prepareForCodegen();
- Stmt* result = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr result = IRSimplifier::simplify(l.root_stmt());
checkIR(result, R"IR(
#CHECK: Allocate(A_local); // dtype=int, dims=[1, 64]
BufHandle g("g", {26, 5}, kInt);
ExprHandle x_outer_end = 5;
ExprHandle x_2 = x + x_outer_end * 4;
- For* stmt1 = For::make(
+ ForPtr stmt1 = For::make(
x,
0,
5,
Store::make(f, {x_2, y}, (x_2 + y)),
Store::make(g, {x_2, y}, (x_2 * y)),
})));
- Stmt* stmt = Block::make({stmt1});
+ StmtPtr stmt = Block::make({stmt1});
// Will eliminate if not used by an output.
LoopNest loop(Stmt::clone(stmt), {f.node()});
BufHandle h("h", {26, 5}, kInt);
ExprHandle x_outer_end = 5;
ExprHandle x_2 = x + x_outer_end * 4;
- For* stmt1 = For::make(x, 0, 26 * 5, Store::make(f, {x}, x));
- For* stmt2 = For::make(z, 0, 26 * 5, Store::make(g, {z}, z + 1));
- For* stmt3 = For::make(
+ ForPtr stmt1 = For::make(x, 0, 26 * 5, Store::make(f, {x}, x));
+ ForPtr stmt2 = For::make(z, 0, 26 * 5, Store::make(g, {z}, z + 1));
+ ForPtr stmt3 = For::make(
x,
0,
5,
Block::make({
Store::make(h, {x, y}, Load::make(f, {x * y})),
})));
- Stmt* stmt = Block::make({stmt1, stmt2, stmt3});
+ StmtPtr stmt = Block::make({stmt1, stmt2, stmt3});
// Will eliminate the write to g, but not f, since f is used by the producer of
// h.
{Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)});
auto inner_for2 = For::make(y, 0, 5, for_body2);
auto outer_for2 = For::make(x, 0, 10, inner_for2);
- Block* body = Block::make({outer_for1, outer_for2});
+ BlockPtr body = Block::make({outer_for1, outer_for2});
Tensor* A = new Tensor(a_buf.node(), body);
std::vector<int> a_data(50, 0);
- Stmt* s = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr s = IRSimplifier::simplify(l.root_stmt());
SimpleIREvaluator cg(s, {A});
std::vector<int> a_ref(50, 0);
{Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)});
auto inner_for2 = For::make(y, 0, 5, for_body2);
auto outer_for2 = For::make(x, 0, 10, inner_for2);
- Block* body = Block::make({outer_for1, outer_for2});
+ BlockPtr body = Block::make({outer_for1, outer_for2});
Tensor* A = new Tensor(a_buf.node(), body);
Tensor* B = Compute(
std::vector<int> a_data(50, 0);
std::vector<int> b_data(50, 0);
- Stmt* s = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr s = IRSimplifier::simplify(l.root_stmt());
SimpleIREvaluator cg(s, {B});
std::vector<int> b_ref(50, 0);
return {std::move(a), t};
}
-static Stmt* splitTailReorder(Tensor* b) {
+static StmtPtr splitTailReorder(Tensor* b) {
constexpr int kVectorWidth = 8;
LoopNest nest({b});
auto loops = nest.getAllLoopNestsWritingToBuf(b->buf())[0];
return nest.root_stmt();
}
-static Stmt* splitMaskReorder(Tensor* b) {
+static StmtPtr splitMaskReorder(Tensor* b) {
constexpr int kVectorWidth = 8;
LoopNest nest({b});
auto loops = nest.getAllLoopNestsWritingToBuf(b->buf())[1];
return nest.root_stmt();
}
-static void checkColReduce(Stmt* s, Placeholder& p, Tensor* t) {
+static void checkColReduce(StmtPtr s, Placeholder& p, Tensor* t) {
int M = immediateAs<int>(p.dim(0));
int N = immediateAs<int>(p.dim(1));
PaddedBuffer<float> a(M, N);
KernelScope kernel_scope;
constexpr int M = 76, N = 128;
auto p = colReduce(M, N);
- Stmt* s = splitTailReorder(p.second);
+ StmtPtr s = splitTailReorder(p.second);
std::ostringstream oss;
oss << *s;
KernelScope kernel_scope;
constexpr int M = 76, N = 100;
auto p = colReduce(M, N);
- Stmt* s = splitTailReorder(p.second);
+ StmtPtr s = splitTailReorder(p.second);
std::ostringstream oss;
oss << *s;
KernelScope kernel_scope;
constexpr int M = 76, N = 128;
auto p = colReduce(M, N);
- Stmt* s = splitMaskReorder(p.second);
+ StmtPtr s = splitMaskReorder(p.second);
checkColReduce(s, *p.first, p.second);
}
KernelScope kernel_scope;
constexpr int M = 76, N = 100;
auto p = colReduce(M, N);
- Stmt* s = splitMaskReorder(p.second);
+ StmtPtr s = splitMaskReorder(p.second);
checkColReduce(s, *p.first, p.second);
}
auto outer_cond =
Cond::make(CompareSelect::make(i, 5, kGT), inner_cond, nullptr);
auto forI = For::make(i, 0, 20, outer_cond);
- Stmt* par = Block::make({forI});
+ StmtPtr par = Block::make({forI});
LoopNest l(par, {a_buf.node()});
LoopNest::reorderAxis(forI, forJ);
ASSERT_EQ(par, l.root_stmt());
ASSERT_TRUE(LoopNest::vectorize(loops[0]));
nest.prepareForCodegen();
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
- Stmt* s = nest.root_stmt();
+ StmtPtr s = nest.root_stmt();
std::ostringstream oss;
oss << *nest.root_stmt();
torch::jit::testing::FileCheck().run(
Placeholder a("a", kLong, {N});
Placeholder b("b", kLong, {N});
VarHandle n("n", kLong);
- Stmt* s = For::make(n, 0, N, b.store({n}, a.load({n}) + LongImm::make(1l)));
+ StmtPtr s = For::make(n, 0, N, b.store({n}, a.load({n}) + LongImm::make(1l)));
s = IRSimplifier::simplify(s);
std::ostringstream oss;
oss << *s;
# CHECK-NOT: for (
)IR";
- auto newForI = dynamic_cast<For*>(Stmt::clone(forI));
+ auto newForI = to<For>(Stmt::clone(forI));
auto forM = For::make(m, 0, 50, newForI);
auto par = Block::make({forM});
LoopNest nest(par, {a_buf.node(), b_buf.node()});
# CHECK-NOT: for (
)IR";
- auto newForI = dynamic_cast<For*>(Stmt::clone(forI));
+ auto newForI = to<For>(Stmt::clone(forI));
auto forM = For::make(m, 0, 50, newForI);
auto par = Block::make({forM});
LoopNest nest(par, {a_buf.node(), b_buf.node()});
auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k)));
auto par = Block::make({forJ, forK});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
std::ostringstream oss;
auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k)));
auto par = Block::make({forI, forJ, forK});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_TRUE(LoopNest::fuseLoops({forI, forJ, forK}, &fused_loop));
std::ostringstream oss;
auto forN = For::make(n, 0, 20, Block::make({initB, forK}));
auto par = Block::make({forM, forN});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_TRUE(LoopNest::fuseLoops({forM, forN}, &fused_loop));
std::ostringstream oss;
Store::make(b_buf, {m, n}, Add::make(m, Mul::make(n, 100)))));
auto par = Block::make({forI, forM});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop));
std::ostringstream oss;
n, 0, 100, Store::make(b_buf, {i, n}, Add::make(i, Mul::make(n, 100))));
auto forI = For::make(i, 0, 20, Block::make({forJ, forN}));
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_TRUE(LoopNest::fuseLoops({forJ, forN}, &fused_loop));
std::ostringstream oss;
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
auto par = Block::make({forJ, forK});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
}
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
auto par = Block::make({forJ, forK});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
}
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
auto par = Block::make({forJ, initB, forK});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
}
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
auto par = Block::make({forI, initB, forK});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
}
auto forK = For::make(k, 0, N, Store::make(b_buf, {j}, Mul::make(20, k)));
auto par = Block::make({forJ, forK});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
std::ostringstream oss;
auto forK = For::make(k, 0, M + N, Store::make(b_buf, {j}, Mul::make(20, k)));
auto par = Block::make({forJ, forK});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
std::ostringstream oss;
auto forK = For::make(k, M, N + N, Store::make(b_buf, {j}, Mul::make(20, k)));
auto par = Block::make({forJ, forK});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
std::ostringstream oss;
auto par = Block::make({forJ, forK});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
std::ostringstream oss;
auto par = Block::make({forI, forM});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop));
std::ostringstream oss;
For::make(m, 0, 20, Store::make(c_buf, {m}, Load::make(a_buf, {m})));
auto par = Block::make({forI, forM});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop));
std::ostringstream oss;
auto par = Block::make({forI, forM});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop));
std::ostringstream oss;
auto par = Block::make({forI, forM});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop));
std::ostringstream oss;
auto par = Block::make({forI, forM});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop));
}
auto par = Block::make({forI, forM});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop));
}
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
auto par = Block::make({forJ, forK});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
}
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
auto par = Block::make({forJ, forK});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
}
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
auto par = Block::make({forM, forN});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_FALSE(LoopNest::fuseLoops({forM, forN}, &fused_loop));
}
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
auto par = Block::make({forI, forM});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop));
}
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,cppcoreguidelines-avoid-magic-numbers)
auto forI = For::make(i, 0, 20, Block::make({forJ, forN}));
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_FALSE(LoopNest::fuseLoops({forJ, forN}, &fused_loop));
}
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
auto par = Block::make({forJ, forK});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop));
}
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
auto par = Block::make({forK, forJ});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_loop;
+ ForPtr fused_loop;
ASSERT_FALSE(LoopNest::fuseLoops({forK, forJ}, &fused_loop));
}
using namespace analysis;
- auto CB = [](int s, int e) { return Bound(new IntImm(s), new IntImm(e)); };
+ auto CB = [](int s, int e) {
+ return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
+ };
// Sanity check 3 overlap cases.
ASSERT_EQ(ContainedOrEqual, boundOverlap(CB(0, 0), CB(0, 0)));
using namespace analysis;
- auto CB = [](int s, int e) { return Bound(new IntImm(s), new IntImm(e)); };
+ auto CB = [](int s, int e) {
+ return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
+ };
// Sanity check one dimensional cases.
ASSERT_EQ(ContainedOrEqual, overlaps({CB(0, 0)}, {CB(0, 0)}));
using namespace analysis;
- auto CB = [](int s, int e) { return Bound(new IntImm(s), new IntImm(e)); };
+ auto CB = [](int s, int e) {
+ return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
+ };
auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
return indexBoundsEquals(x, y);
};
using namespace analysis;
- auto CB = [](int s, int e) { return Bound(new IntImm(s), new IntImm(e)); };
+ auto CB = [](int s, int e) {
+ return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
+ };
auto EQ = [](std::vector<IndexBounds> x, std::vector<IndexBounds> y) {
if (x.size() != y.size()) {
return false;
* B[0] = A[0] + 1;
*/
- Store* aStore = Store::make(a, {0}, 3);
- Store* bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1));
+ StorePtr aStore = Store::make(a, {0}, 3);
+ StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1));
- Stmt* stmt = Block::make({aStore, bStore});
+ StmtPtr stmt = Block::make({aStore, bStore});
stmt->accept(&analyzer);
* C[0] = B[0] + 1;
*/
- Store* aStore = Store::make(a, {0}, 3);
- Store* bStore = Store::make(b, {0}, Load::make(a, {0}));
- Store* cStore = Store::make(c, {0}, Add::make(Load::make(b, {0}), 1));
+ StorePtr aStore = Store::make(a, {0}, 3);
+ StorePtr bStore = Store::make(b, {0}, Load::make(a, {0}));
+ StorePtr cStore = Store::make(c, {0}, Add::make(Load::make(b, {0}), 1));
- Stmt* stmt = Block::make({aStore, bStore, cStore});
+ StmtPtr stmt = Block::make({aStore, bStore, cStore});
stmt->accept(&analyzer);
* B[0] = A[0] + 1;
*/
- Store* aStore = Store::make(a, {0}, 3);
- Store* a2Store = Store::make(a, {0}, 6);
- Store* bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1));
+ StorePtr aStore = Store::make(a, {0}, 3);
+ StorePtr a2Store = Store::make(a, {0}, 6);
+ StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1));
- Stmt* stmt = Block::make({aStore, a2Store, bStore});
+ StmtPtr stmt = Block::make({aStore, a2Store, bStore});
stmt->accept(&analyzer);
* B[0] = A[0] + 1;
*/
- Store* aStore = Store::make(a, {x}, x);
- Stmt* loop = For::make(x, 0, 10, aStore);
- Store* bStore = Store::make(b, {0}, Add::make(Load::make(a, {4}), 1));
+ StorePtr aStore = Store::make(a, {x}, x);
+ StmtPtr loop = For::make(x, 0, 10, aStore);
+ StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {4}), 1));
- Stmt* stmt = Block::make({loop, bStore});
+ StmtPtr stmt = Block::make({loop, bStore});
stmt->accept(&analyzer);
// It should have bounds covering the range of x: 0 <= x < 10.
ASSERT_TRUE(indexBoundsEquals(
- aStoreAccess->bounds(), {Bound(new IntImm(0), new IntImm(9))}));
+ aStoreAccess->bounds(), {Bound(alloc<IntImm>(0), alloc<IntImm>(9))}));
}
// Reductions should promote dependencies as well.
* B[0] = A[0];
*/
- Store* aInit = Store::make(a, {0}, 0);
+ StorePtr aInit = Store::make(a, {0}, 0);
ExprHandle reduce =
ExprHandle(Sum()(a.node(), ExprHandle(1), {x.node()}, {x.node()}));
- Store* aReduce = Store::make(a, {0}, reduce);
- Stmt* loop = For::make(x, 0, 10, aReduce);
- Store* bStore = Store::make(b, {0}, Load::make(a, {0}));
+ StorePtr aReduce = Store::make(a, {0}, reduce);
+ StmtPtr loop = For::make(x, 0, 10, aReduce);
+ StorePtr bStore = Store::make(b, {0}, Load::make(a, {0}));
- Stmt* stmt = Block::make({aInit, loop, bStore});
+ StmtPtr stmt = Block::make({aInit, loop, bStore});
stmt->accept(&analyzer);
// Find loads within the reduction:
auto reduceLoads = NodeFinder<Load>::find(reduce.node());
// Pull out the access for the load inside the loop.
- for (auto* load : reduceLoads) {
+ for (auto load : reduceLoads) {
auto loopLoad = analyzer.accessFor(load);
// It should have 10 element long bounds.
ASSERT_TRUE(indexBoundsEquals(
- loopLoad->bounds(), {Bound(new IntImm(0), new IntImm(9))}));
+ loopLoad->bounds(), {Bound(alloc<IntImm>(0), alloc<IntImm>(9))}));
}
}
* B[0] = A[0];
*/
- Store* aInit = Store::make(a, {0}, 0);
+ StorePtr aInit = Store::make(a, {0}, 0);
ExprHandle aLoad = Load::make(a, {x});
- Store* aReduce = Store::make(a, {0}, Add::make(aLoad, 1));
- Stmt* loop = For::make(x, 0, 10, aReduce);
- Store* bStore = Store::make(b, {0}, Load::make(a, {0}));
+ StorePtr aReduce = Store::make(a, {0}, Add::make(aLoad, 1));
+ StmtPtr loop = For::make(x, 0, 10, aReduce);
+ StorePtr bStore = Store::make(b, {0}, Load::make(a, {0}));
- Stmt* stmt = Block::make({aInit, loop, bStore});
+ StmtPtr stmt = Block::make({aInit, loop, bStore});
stmt->accept(&analyzer);
auto loopLoad = analyzer.accessFor(aLoad.node());
// It should have 10 element long bounds.
ASSERT_TRUE(indexBoundsEquals(
- loopLoad->bounds(), {Bound(new IntImm(0), new IntImm(9))}));
+ loopLoad->bounds(), {Bound(alloc<IntImm>(0), alloc<IntImm>(9))}));
}
// Can determine dependencies of outputs, through to inputs.
*/
ExprHandle aLoad = Load::make(a, {x});
- Store* bStore = Store::make(b, {x}, Max::make(aLoad, 0, true));
- Stmt* loop = For::make(x, 0, 10, bStore);
+ StorePtr bStore = Store::make(b, {x}, Max::make(aLoad, 0, true));
+ StmtPtr loop = For::make(x, 0, 10, bStore);
- Stmt* stmt = Block::make({loop});
+ StmtPtr stmt = Block::make({loop});
stmt->accept(&analyzer);
* }
*/
- Store* bStore = Store::make(b, {x}, Max::make(x, 0, true));
- Stmt* loop = For::make(x, 0, 10, bStore);
+ StorePtr bStore = Store::make(b, {x}, Max::make(x, 0, true));
+ StmtPtr loop = For::make(x, 0, 10, bStore);
- Stmt* stmt = Block::make({loop});
+ StmtPtr stmt = Block::make({loop});
stmt->accept(&analyzer);
* }
*/
- std::vector<Stmt*> stmts(
+ std::vector<StmtPtr> stmts(
{For::make(x, 1, 10, Store::make(b, {x}, Load::make(a, {x}))),
For::make(
x, 1, 9, Store::make(b, {x}, Mul::make(Load::make(b, {x}), 2))),
For::make(x, 3, 4, Store::make(c, {x}, Load::make(a, {x}))),
For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x})))});
- Stmt* stmt = Block::make(stmts);
+ StmtPtr stmt = Block::make(stmts);
stmt->accept(&analyzer);
// The last write to C does not depend on the other write to C.
ASSERT_FALSE(analyzer.dependsIndirectly(stmts[3], stmts[2]));
- auto CB = [](int s, int e) { return Bound(new IntImm(s), new IntImm(e)); };
+ auto CB = [](int s, int e) {
+ return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
+ };
auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
return indexBoundsEquals(x, y);
};
// much.
auto history = analyzer.getHistory();
ASSERT_EQ(history.size(), 10);
- Var* aVar = a.node()->base_handle();
- Var* bVar = b.node()->base_handle();
- Var* cVar = c.node()->base_handle();
+ VarPtr aVar = a.node()->base_handle();
+ VarPtr bVar = b.node()->base_handle();
+ VarPtr cVar = c.node()->base_handle();
// The first access is the input A.
ASSERT_EQ(history[0]->type(), AccessType::Input);
* }
*/
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))),
For::make(x, 0, 9, Store::make(a, {x}, Load::make(a, {x + 1}))),
For::make(
// Sanity check output depends on Input.
ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
- auto CB = [](int s, int e) { return Bound(new IntImm(s), new IntImm(e)); };
+ auto CB = [](int s, int e) {
+ return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
+ };
auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
return indexBoundsEquals(x, y);
};
// Now let's look at the bounds of each access.
auto history = analyzer.getHistory();
ASSERT_EQ(history.size(), 12);
- Var* aVar = a.node()->base_handle();
- Var* bVar = b.node()->base_handle();
+ VarPtr aVar = a.node()->base_handle();
+ VarPtr bVar = b.node()->base_handle();
// The first access is the input A.
ASSERT_EQ(history[0]->type(), AccessType::Input);
// Not self dependent since all loop iterations use a different y.
MemDependencyChecker analyzer;
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
y,
0,
10,
// Not self dependent due to different y (with offset).
MemDependencyChecker analyzer;
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
y,
0,
10,
// Is self dependent since all loops use a common constant element of A.
MemDependencyChecker analyzer;
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
x,
0,
10,
// read.
MemDependencyChecker analyzer;
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
x,
0,
10,
// Is self dependent since all loops use a common symbolic element of A.
MemDependencyChecker analyzer;
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
x,
0,
10,
MemDependencyChecker analyzer;
- Stmt* stmt =
+ StmtPtr stmt =
For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1})));
stmt->accept(&analyzer);
MemDependencyChecker analyzer;
analyzer.allowLoopExecutionOrderAnalysis();
- Stmt* stmt =
+ StmtPtr stmt =
For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1})));
stmt->accept(&analyzer);
MemDependencyChecker analyzer;
- Stmt* stmt =
+ StmtPtr stmt =
For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})));
stmt->accept(&analyzer);
MemDependencyChecker analyzer;
analyzer.allowLoopExecutionOrderAnalysis();
- Stmt* stmt =
+ StmtPtr stmt =
For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})));
stmt->accept(&analyzer);
MemDependencyChecker analyzer;
analyzer.allowLoopExecutionOrderAnalysis();
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
x,
3,
10,
MemDependencyChecker analyzer;
analyzer.allowLoopExecutionOrderAnalysis();
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
x,
3,
10,
MemDependencyChecker analyzer;
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
x,
3,
10,
MemDependencyChecker analyzer;
analyzer.allowLoopExecutionOrderAnalysis();
- Stmt* stmt =
+ StmtPtr stmt =
For::make(x, 3, 10, Store::make(a, {x - 2}, Load::make(a, {x - 1})));
stmt->accept(&analyzer);
// Execution order doesn't matter since the read and the write are totally
// distinct.
- Stmt* stmt =
+ StmtPtr stmt =
For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2})));
stmt->accept(&analyzer);
// Execution order doesn't matter since the read and the write are totally
// distinct.
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 1})));
stmt->accept(&analyzer);
// same if the read is behind the write so long as they are distinct.
MemDependencyChecker analyzer;
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 1})));
stmt->accept(&analyzer);
// But not if the offset is in the stride.
MemDependencyChecker analyzer;
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 2})));
stmt->accept(&analyzer);
// Works with negative offsets too.
MemDependencyChecker analyzer;
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 2})));
stmt->accept(&analyzer);
// Detects accesses are distinct when offset is large but not a multiple
// of stride.
MemDependencyChecker analyzer;
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 7})));
stmt->accept(&analyzer);
// Works with offsets which are multiples of the stride.
MemDependencyChecker analyzer;
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 4})));
stmt->accept(&analyzer);
// within.
MemDependencyChecker analyzer;
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
x, 0, 10, Store::make(a, {x * 6}, Load::make(a, {x * 6 + 5})));
stmt->accept(&analyzer);
// multiple.
MemDependencyChecker analyzer;
- Stmt* stmt =
+ StmtPtr stmt =
For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6})));
stmt->accept(&analyzer);
// still works when the read axis is the smaller stride.
MemDependencyChecker analyzer;
- Stmt* stmt =
+ StmtPtr stmt =
For::make(x, 0, 10, Store::make(a, {x * 4}, Load::make(a, {x * 2})));
stmt->accept(&analyzer);
// and there is an offset.
MemDependencyChecker analyzer;
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 1})));
stmt->accept(&analyzer);
// The smaller stride determines whether there is overlap.
MemDependencyChecker analyzer;
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 4})));
stmt->accept(&analyzer);
// The smaller stride determines whether there is overlap, not the larger.
MemDependencyChecker analyzer;
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
x, 0, 10, Store::make(a, {x * 2 + 3}, Load::make(a, {x * 6})));
stmt->accept(&analyzer);
// If they have strides with no common multiple > 1, they overlap.
MemDependencyChecker analyzer;
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 3 + 1})));
stmt->accept(&analyzer);
// If the offset is greater than the size of the loop, they can't overlap.
MemDependencyChecker analyzer;
- Stmt* stmt =
+ StmtPtr stmt =
For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 10})));
stmt->accept(&analyzer);
// If they have different execution orders they may overlap.
MemDependencyChecker analyzer;
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
x, 0, 10, Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x})));
stmt->accept(&analyzer);
// Or they may not, depending on their start offset and strides.
MemDependencyChecker analyzer;
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
x,
0,
10,
// If the stride is not monotonic, they overlap.
MemDependencyChecker analyzer;
- Stmt* stmt =
+ StmtPtr stmt =
For::make(x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2})));
stmt->accept(&analyzer);
// If the stride is not monotonic, they overlap - even with an offset.
MemDependencyChecker analyzer;
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2 + 1})));
stmt->accept(&analyzer);
// Mod too...
analysis::MemDependencyChecker analyzer;
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
x,
0,
10,
{
MemDependencyChecker analyzer;
- Stmt* stmt =
+ StmtPtr stmt =
For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1})));
stmt->accept(&analyzer);
{
MemDependencyChecker analyzer;
analyzer.allowLoopExecutionOrderAnalysis();
- Stmt* stmt =
+ StmtPtr stmt =
For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1})));
stmt->accept(&analyzer);
using namespace analysis;
MemDependencyChecker analyzer({a.node()}, {b.node()});
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{For::make(
x, 0, 10, Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}))),
For::make(x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2})))
{
analysis::MemDependencyChecker analyzer({a.node()}, {c.node()});
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{For::make(
x,
0,
// Future usages may depend on accesses in both branches of a condition.
MemDependencyChecker analyzer({a, b}, {c});
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
Cond::make(
CompareSelect::make(y, 5, CompareSelectOperation::kLT),
// Future usages may depend on accesses in both branches of a condition.
MemDependencyChecker analyzer({a, b}, {c});
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
Cond::make(
CompareSelect::make(y, 5, CompareSelectOperation::kLT),
// Only has true branch.
MemDependencyChecker analyzer({a, b}, {c});
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
Cond::make(
CompareSelect::make(y, 5, CompareSelectOperation::kLT),
// Only has false branch.
MemDependencyChecker analyzer({a, b}, {c});
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
Cond::make(
CompareSelect::make(y, 5, CompareSelectOperation::kLT),
// Cond's Condition depends on a previous access.
MemDependencyChecker analyzer({a}, {c});
- Store* initStore = Store::make(c, {x}, Load::make(a, {x}));
+ StorePtr initStore = Store::make(c, {x}, Load::make(a, {x}));
ExprHandle conditionalLoad = Load::make(c, {0});
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{For::make(x, 0, 10, initStore),
Cond::make(
CompareSelect::make(
// Future usages may depend on accesses in both branches of a condition.
MemDependencyChecker analyzer({a, b}, {c});
- Store* ifStore = Store::make(
+ StorePtr ifStore = Store::make(
c,
{0},
IfThenElse::make(
CompareSelect::make(y, 5, CompareSelectOperation::kLT),
Add::make(Load::make(b, {0}), 1),
Add::make(Load::make(b, {1}), 1)));
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
ifStore});
// dependent on it.
MemDependencyChecker analyzer({a, b}, {c});
- Store* ifStore = Store::make(
+ StorePtr ifStore = Store::make(
c,
{0},
IfThenElse::make(
CompareSelect::make(y, 5, CompareSelectOperation::kLT),
Add::make(Load::make(b, {0}), 1),
42));
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
ifStore});
// uncertain if this would be helpful.
MemDependencyChecker analyzer({a, b}, {c});
- Store* ifStore = Store::make(
+ StorePtr ifStore = Store::make(
c,
{0},
IfThenElse::make(
CompareSelect::make(y, 5, CompareSelectOperation::kLT),
Load::make(b, {x}),
Load::make(a, {x})));
- Stmt* stmt = Block::make({For::make(x, 0, 10, ifStore)});
+ StmtPtr stmt = Block::make({For::make(x, 0, 10, ifStore)});
stmt->accept(&analyzer);
// Cutting a loop with single element writes.
MemDependencyChecker analyzer({a}, {b});
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x}))),
Store::make(b, {5}, 100)});
// loop with one element writes.
MemDependencyChecker analyzer({a}, {b});
- For* firstLoop =
+ ForPtr firstLoop =
For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x})));
- Store* secondStore = Store::make(b, {x}, Add::make(Load::make(b, {x}), 1));
- For* secondLoop = For::make(x, 4, 7, secondStore);
+ StorePtr secondStore =
+ Store::make(b, {x}, Add::make(Load::make(b, {x}), 1));
+ ForPtr secondLoop = For::make(x, 4, 7, secondStore);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{firstLoop,
secondLoop,
Store::make(b, {4}, 100),
* }
*/
MemDependencyChecker analyzer({a, b}, {c});
- Stmt* stmt = Block::make({For::make(
+ StmtPtr stmt = Block::make({For::make(
x, 0, Load::make(b, {0}), Store::make(c, {x}, Load::make(a, {x})))});
stmt->accept(&analyzer);
* }
*/
MemDependencyChecker analyzer({a, b}, {c});
- Stmt* stmt = Block::make({For::make(
+ StmtPtr stmt = Block::make({For::make(
x,
Load::make(b, {0}),
Load::make(b, {1}),
* }
*/
MemDependencyChecker analyzer({a, b}, {c});
- Stmt* stmt = Block::make({For::make(
+ StmtPtr stmt = Block::make({For::make(
x, 0, 10, Store::make(c, {x}, Load::make(a, {Load::make(b, {x})})))});
stmt->accept(&analyzer);
* }
*/
MemDependencyChecker analyzer({a, b}, {c});
- Stmt* stmt = Block::make({For::make(
+ StmtPtr stmt = Block::make({For::make(
x, 0, 10, Store::make(c, {Load::make(b, {x})}, Load::make(a, {x})))});
stmt->accept(&analyzer);
* }
*/
MemDependencyChecker analyzer({a, b}, {c});
- Stmt* stmt = Block::make({For::make(
+ StmtPtr stmt = Block::make({For::make(
x, 0, 10, Store::make(c, {Load::make(b, {Load::make(a, {x})})}, x))});
stmt->accept(&analyzer);
// Full range.
MemDependencyChecker analyzer({a}, {b});
- Stmt* stmt = Block::make({For::make(
+ StmtPtr stmt = Block::make({For::make(
x,
0,
M,
// Partial range.
MemDependencyChecker analyzer({a}, {b});
- Stmt* stmt = Block::make({For::make(
+ StmtPtr stmt = Block::make({For::make(
x,
0,
5,
// Partial loops.
MemDependencyChecker analyzer({a}, {b});
- Stmt* stmt = Block::make({For::make(
+ StmtPtr stmt = Block::make({For::make(
x,
0,
N,
// dimensionality.
MemDependencyChecker analyzer({a, c}, {b});
- Stmt* stmt = Block::make({For::make(
+ StmtPtr stmt = Block::make({For::make(
x,
0,
M,
// Multi-dim reductions.
MemDependencyChecker analyzer({a}, {b});
- Stmt* stmt = Block::make({For::make(
+ StmtPtr stmt = Block::make({For::make(
x,
0,
M,
ASSERT_TRUE(analyzer.dependsIndirectly(d->buf(), b_buf.data()));
// Second loop depends on first loop.
- auto* c_loop = l.getLoopStmtsFor(c)[0];
- auto* d_loop = l.getLoopStmtsFor(d)[0];
+ auto c_loop = l.getLoopStmtsFor(c)[0];
+ auto d_loop = l.getLoopStmtsFor(d)[0];
ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop));
}
l.splitWithTail(l.getLoopStmtsFor(c)[0], 2);
MemDependencyChecker analyzer_after({a_buf.data(), b_buf.data()}, {c->buf()});
- Stmt* stmt = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr stmt = IRSimplifier::simplify(l.root_stmt());
stmt->accept(&analyzer_after);
// Splitting should not change accesses at all.
l.reorderAxis(loops[0], loops[1]);
MemDependencyChecker analyzer_after({a_buf.data(), b_buf.data()}, {c->buf()});
- Stmt* stmt = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr stmt = IRSimplifier::simplify(l.root_stmt());
stmt->accept(&analyzer_after);
// Reordering should not change accesses at all.
ASSERT_TRUE(analyzer.dependsIndirectly(d->buf(), b.data()));
// Second loop depends on first loop.
- auto* c_loop = l.getLoopStmtsFor(c)[0];
- auto* d_loop = l.getLoopStmtsFor(d)[0];
+ auto c_loop = l.getLoopStmtsFor(c)[0];
+ auto d_loop = l.getLoopStmtsFor(d)[0];
ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop));
// Reduction depends on both inputs.
{
auto const& loops = loop.getLoopStmtsFor(CT);
- For* m = loops[0];
+ ForPtr m = loops[0];
loop.splitWithMask(m, 4);
}
{
auto const& loops = loop.getLoopStmtsFor(CT);
- For* n = loops[2];
+ ForPtr n = loops[2];
loop.splitWithMask(n, 16);
}
// mo, mi, no, ni, k ->
// mo, no, mi, ni, k
{
auto const& loops = loop.getLoopStmtsFor(CT);
- For* mi = loops[1];
- For* no = loops[2];
+ ForPtr mi = loops[1];
+ ForPtr no = loops[2];
loop.reorderAxis(mi, no);
}
// mo, no, mi, ni, k ->
// mo, no, mi, k, ni
{
auto const& loops = loop.getLoopStmtsFor(CT);
- For* ni = loops[3];
- For* k = loops[4];
+ ForPtr ni = loops[3];
+ ForPtr k = loops[4];
loop.reorderAxis(ni, k);
}
// mo, no, mi, k, ni ->
// mo, no, k, mi, ni
{
auto const& loops = loop.getLoopStmtsFor(CT);
- For* mi = loops[2];
- For* k = loops[3];
+ ForPtr mi = loops[2];
+ ForPtr k = loops[3];
loop.reorderAxis(mi, k);
}
{
// Test both unlowered and lowered form.
{
- Stmt* stmt = IRSimplifier::simplify(loop.root_stmt());
+ StmtPtr stmt = IRSimplifier::simplify(loop.root_stmt());
stmt->accept(&analyzer_unlowered);
// Outputs depend on inputs.
// now check lowered dependency graph.
{
- Stmt* stmt = IRSimplifier::simplify(loop.root_stmt());
+ StmtPtr stmt = IRSimplifier::simplify(loop.root_stmt());
stmt->accept(&analyzer_lowered);
// Lowering will change the dimensionality of all bounds due to index
history_before[i]->bounds(), history_after[i]->bounds()));
} else {
ASSERT_EQ(history_after[i]->bounds().size(), 1);
- Expr* flat_bounds = new IntImm(1);
+ ExprPtr flat_bounds = alloc<IntImm>(1);
for (auto& b : history_before[i]->bounds()) {
- flat_bounds = new Mul(flat_bounds, new Add(b.end, new IntImm(1)));
+ flat_bounds =
+ alloc<Mul>(flat_bounds, alloc<Add>(b.end, alloc<IntImm>(1)));
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
ASSERT_TRUE(exprEquals(b.start, history_after[i]->bounds()[0].start));
}
flat_bounds = IRSimplifier::simplify(flat_bounds);
- Expr* after_bounds = IRSimplifier::simplify(
- new Add(history_after[i]->bounds()[0].end, new IntImm(1)));
+ ExprPtr after_bounds = IRSimplifier::simplify(
+ alloc<Add>(history_after[i]->bounds()[0].end, alloc<IntImm>(1)));
ASSERT_TRUE(exprEquals(flat_bounds, after_bounds));
}
}
Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {});
LoopNest loop({c});
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {b, c});
Tensor* c = Reduce("sum", {}, Sum(), b, {});
LoopNest loop({c});
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {b, c});
Tensor* c = Reduce("sum", {}, Sum(), b, {{10, "m"}});
LoopNest loop({c});
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {b, c});
Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}});
LoopNest loop({c});
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {b, c, n, m});
Tensor* c = Reduce("sum", {{2, "l"}, {3, "n"}}, Sum(), b, {{m, "m"}});
LoopNest loop({c});
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {b, c, m});
Tensor* d = Reduce("sum2", {{2, "l"}}, Sum(), b, {{3, "n"}, {m, "m"}});
LoopNest loop2({d});
loop2.prepareForCodegen();
- Stmt* s2 = loop2.root_stmt();
+ StmtPtr s2 = loop2.root_stmt();
s2 = IRSimplifier::simplify(s2);
SimpleIREvaluator cg2(s2, {b, d, m});
Tensor* e = Reduce("sum3", {{2, "l"}}, Sum(), c_buf, {{3, "m"}});
LoopNest loop3({e});
loop3.prepareForCodegen();
- Stmt* s3 = loop3.root_stmt();
+ StmtPtr s3 = loop3.root_stmt();
s3 = IRSimplifier::simplify(s3);
SimpleIREvaluator cg3(s3, {c, e});
{{3, "f"}, {2, "g"}, {3, "h"}, {2, "i"}, {3, "j"}});
LoopNest loop({c});
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {in_, c});
Tensor* c = Reduce("product", {{M, "m"}}, product, b, {{N, "n"}});
LoopNest loop({c});
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {b, c});
LoopNest loop({dm1});
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {in_, dm1});
LoopNest loop({min});
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {in_, min, minInit});
LoopNest loop({any});
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {b, any, searchValue});
LoopNest loop({mm});
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {tA, tB, mm});
LoopNest loop({l1, l2});
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {in, l1, l2});
});
LoopNest loop({d}, {c, d});
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {a, b, d, m});
Tensor* d = Reduce("sum", {{2, "l1"}}, Sum(), c, {{3, "n1"}, {m, "m1"}});
LoopNest loop({d}, {c, d});
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {a, b, d, m});
Tensor* tensor = Reduce("sum", {{16, "m"}}, Sum(), in, {{8, "n"}});
LoopNest l({tensor});
- std::vector<For*> loops = l.getLoopStmtsFor(tensor);
+ std::vector<ForPtr> loops = l.getLoopStmtsFor(tensor);
LoopNest::splitWithTail(loops[1], 2);
l.prepareForCodegen();
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {in, tensor});
std::vector<float> out(16, -1.f);
Tensor* tensor = Reduce("sum", {{16, "m"}}, Sum(), in, {{8, "n"}});
LoopNest l({tensor});
- std::vector<For*> loops = l.getLoopStmtsFor(tensor);
+ std::vector<ForPtr> loops = l.getLoopStmtsFor(tensor);
LoopNest::splitWithTail(loops[0], 2);
LoopNest::splitWithTail(loops[0], 2);
l.prepareForCodegen();
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {in, tensor});
LoopNest l_({tensor_});
l_.prepareForCodegen();
- Stmt* s_ = Stmt::clone(l_.root_stmt());
+ StmtPtr s_ = Stmt::clone(l_.root_stmt());
s_ = IRSimplifier::simplify(s_);
Tensor* tensor = Reduce("sum", {{1, "k"}, {12, "n"}}, Sum(), in, {{6, "m"}});
LoopNest::reorderAxis(loops[1], loops[2]);
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
s = IRSimplifier::simplify(s);
Tensor* c = Reduce("sum", {}, Sum(), b, {{m, "m"}, {n, "n"}});
LoopNest loop({c});
- std::vector<For*> loops = loop.getLoopStmtsFor(c);
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
- auto c_body = const_cast<Stmt*>(loop.getAllWritesToBuf(c->buf())[1]);
+ std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
+ auto c_body = loop.getAllWritesToBuf(c->buf())[1];
ASSERT_TRUE(loop.rfactor(c_body, loops.at(0)));
auto rc = NodeFinder<ReduceOp>::find(loop.root_stmt());
ASSERT_EQ(rc.size(), 2);
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {b, c, m, n});
Tensor* c = Reduce("sum", {}, Sum(), b, {{m, "m"}, {n, "n"}, {k, "k"}});
LoopNest loop({c});
- std::vector<For*> loops = loop.getLoopStmtsFor(c);
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
- auto c_body = const_cast<Stmt*>(loop.getAllWritesToBuf(c->buf())[1]);
+ std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
+ auto c_body = loop.getAllWritesToBuf(c->buf())[1];
ASSERT_FALSE(loop.rfactor(c_body, loops.at(2)));
auto rc = NodeFinder<ReduceOp>::find(loop.root_stmt());
ASSERT_EQ(rc.size(), 1);
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {b, c, m, n, k});
Tensor* c = Reduce("sum", {}, Sum(), b, {{m, "m"}, {n, "n"}, {k, "k"}});
LoopNest loop({c});
- std::vector<For*> loops = loop.getLoopStmtsFor(c);
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
- auto c_body = const_cast<Stmt*>(loop.getAllWritesToBuf(c->buf())[1]);
+ std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
+ auto c_body = loop.getAllWritesToBuf(c->buf())[1];
ASSERT_TRUE(loop.rfactor(c_body, loops.at(0)));
auto rc = NodeFinder<ReduceOp>::find(loop.root_stmt());
ASSERT_EQ(rc.size(), 2);
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {b, c, m, n, k});
IRSimplifier::simplify(refloop.root_stmt()), {in_, c});
ref_cg.call({in, ref});
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
- Buf* tmp_buf = const_cast<Buf*>(c->buf());
+ BufPtr tmp_buf = c->buf();
for (int idx = 0; idx < rfac_number; idx++) {
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
- auto reduce = const_cast<Stmt*>(loop.getAllWritesToBuf(tmp_buf)[1]);
+ auto reduce = loop.getAllWritesToBuf(tmp_buf)[1];
ASSERT_TRUE(loop.rfactor(
reduce, loop.getLoopStmtsFor(tmp_buf).at(idx), &tmp_buf));
}
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {in_, c});
Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}});
LoopNest loop({c});
- std::vector<For*> loops = loop.getLoopStmtsFor(c);
+ std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
LoopNest::splitWithTail(loops[i], 8);
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {b, c});
Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}});
LoopNest loop({c});
- std::vector<For*> loops = loop.getLoopStmtsFor(c);
+ std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
LoopNest::splitWithTail(loops[i], 5);
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {b, c});
Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}});
LoopNest loop({c});
- std::vector<For*> loops = loop.getLoopStmtsFor(c);
+ std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
LoopNest::splitWithTail(loops[i], 16);
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {b, c});
Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}});
LoopNest loop({c});
- std::vector<For*> loops = loop.getLoopStmtsFor(c);
+ std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
LoopNest::splitWithMask(loops[i], 8);
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {b, c});
Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}});
LoopNest loop({c});
- std::vector<For*> loops = loop.getLoopStmtsFor(c);
+ std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
LoopNest::splitWithMask(loops[i], 5);
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {b, c});
Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}});
LoopNest loop({c});
- std::vector<For*> loops = loop.getLoopStmtsFor(c);
+ std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
LoopNest::splitWithMask(loops[i], 16);
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
s = IRSimplifier::simplify(s);
SimpleIREvaluator cg(s, {b, c});
Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}});
LoopNest loop({c});
- std::vector<For*> loops = loop.getLoopStmtsFor(c);
+ std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
LoopNest::splitWithTail(loops[2], SPLIT_FACTOR);
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
- auto c_body = const_cast<Stmt*>(loop.getAllWritesToBuf(c->buf())[2]);
+ auto c_body = loop.getAllWritesToBuf(c->buf())[2];
auto all_loops = loop.getAllLoopNestsWritingToBuf(c->buf());
ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(2).size() == 3);
LoopNest::reorderAxis(all_loops[2][1], all_loops[2][2]);
ASSERT_TRUE(loop.rfactor(c_body, all_loops[2][1]));
loop.prepareForCodegen();
loop.simplify();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
SimpleIREvaluator cg(s, {b, c});
Tensor* c = Reduce("sum", {}, Sum(), b, {{N, "n"}, {K, "k"}});
LoopNest loop({c});
- std::vector<For*> loops = loop.getLoopStmtsFor(c);
+ std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For *i, *t;
+ ForPtr i, t;
LoopNest::splitWithTail(loops[1], SPLIT_FACTOR, &i, &t);
LoopNest::reorderAxis(loops[0], i);
auto all_loops = loop.getAllLoopNestsWritingToBuf(c->buf());
ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(1).size() == 3);
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
- auto c_body = const_cast<Stmt*>(loop.getAllWritesToBuf(c->buf())[1]);
+ auto c_body = loop.getAllWritesToBuf(c->buf())[1];
ASSERT_TRUE(loop.rfactor(c_body, all_loops[1][0]));
LoopNest::reorderAxis(all_loops[1][0], all_loops[1][2]);
loop.prepareForCodegen();
loop.simplify();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
SimpleIREvaluator cg(s, {b, c});
l1.prepareForCodegen();
l2.prepareForCodegen();
- Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt());
- Stmt* stmt2 = IRSimplifier::simplify(l2.root_stmt());
+ StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
+ StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt());
SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y});
SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y});
l1.prepareForCodegen();
l2.prepareForCodegen();
- Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt());
- Stmt* stmt2 = IRSimplifier::simplify(l2.root_stmt());
+ StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
+ StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt());
SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y});
SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y});
l_before.prepareForCodegen();
SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e});
- Stmt* d_loop = l.getLoopStmtsFor(d)[0];
+ StmtPtr d_loop = l.getLoopStmtsFor(d)[0];
l.cacheAccesses(d->buf(), "d_local", d_loop);
l.prepareForCodegen();
- Stmt* result = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr result = IRSimplifier::simplify(l.root_stmt());
SimpleIREvaluator cg_after(result, {a, b, e});
std::ostringstream oss;
l_before.prepareForCodegen();
SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e});
- Stmt* d_loop = l.getLoopStmtsFor(d)[1];
+ StmtPtr d_loop = l.getLoopStmtsFor(d)[1];
l.cacheAccesses(d->buf(), "d_local", d_loop);
l.prepareForCodegen();
- Stmt* result = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr result = IRSimplifier::simplify(l.root_stmt());
SimpleIREvaluator cg_after(result, {a, b, e});
std::ostringstream oss;
l_before.prepareForCodegen();
SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e});
- Stmt* d_loop = l.getLoopStmtsFor(d)[2];
+ StmtPtr d_loop = l.getLoopStmtsFor(d)[2];
l.cacheAccesses(d->buf(), "d_local", d_loop);
l.prepareForCodegen();
- Stmt* result = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr result = IRSimplifier::simplify(l.root_stmt());
SimpleIREvaluator cg_after(result, {a, b, e});
std::ostringstream oss;
LoopNest l({e}, {c, d, e});
- Stmt* d_loop = l.getLoopStmtsFor(d)[1];
+ StmtPtr d_loop = l.getLoopStmtsFor(d)[1];
l.cacheAccesses(c->buf(), "scale_local", d_loop);
l.prepareForCodegen();
- Stmt* result = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr result = IRSimplifier::simplify(l.root_stmt());
std::ostringstream oss;
oss << *result;
LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4);
- Stmt* e_loop = l.getLoopStmtsFor(e)[1];
+ StmtPtr e_loop = l.getLoopStmtsFor(e)[1];
l.cacheAccesses(d->buf(), "sum_local", e_loop);
l.prepareForCodegen();
- Stmt* result = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr result = IRSimplifier::simplify(l.root_stmt());
std::ostringstream oss;
oss << *result;
LoopNest l({e}, {c, d, e});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* inner;
+ ForPtr inner;
// Split outer reduction axis.
LoopNest::splitWithMask(l.getLoopStmtsFor(d)[0], 4, &inner);
l.cacheAccesses(d->buf(), "sum_local", inner);
l.prepareForCodegen();
- Stmt* result = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr result = IRSimplifier::simplify(l.root_stmt());
// reduction changes but cache does not.
std::ostringstream oss;
LoopNest l({e}, {c, d, e});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* inner;
+ ForPtr inner;
// reorder outer reduction axes.
auto loops = l.getLoopStmtsFor(d);
l.cacheAccesses(d->buf(), "sum_local", inner);
l.prepareForCodegen();
- Stmt* result = IRSimplifier::simplify(l.root_stmt());
+ StmtPtr result = IRSimplifier::simplify(l.root_stmt());
// neither reduction body nor cache changes.
std::ostringstream oss;
Tensor* c = Reduce("sum", {}, Sum(), b, {{m, "a"}, {n, "b"}, {k, "c"}});
LoopNest loop({c});
- std::vector<For*> loops = loop.getLoopStmtsFor(c);
+ std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
LoopNest::reorderAxis(loops.at(0), loops.at(1));
loops = loop.getLoopStmtsFor(c);
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
- auto c_body = const_cast<Stmt*>(loop.getAllWritesToBuf(c->buf())[1]);
+ auto c_body = loop.getAllWritesToBuf(c->buf())[1];
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- Buf* rfac_buf;
+ BufPtr rfac_buf;
ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf));
loop.distributeLoop(loops.at(0));
LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][1]);
loop.simplify();
loop.prepareForCodegen();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
std::ostringstream oss;
oss << *s;
Tensor* c = Reduce("sum", {}, Sum(), b, {{m, "a"}, {n, "b"}, {k, "c"}});
LoopNest loop({c});
- std::vector<For*> loops = loop.getLoopStmtsFor(c);
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
- auto c_body = const_cast<Stmt*>(loop.getAllWritesToBuf(c->buf())[1]);
+ std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
+ auto c_body = loop.getAllWritesToBuf(c->buf())[1];
LoopNest::reorderAxis(loops.at(0), loops.at(1));
loops = loop.getLoopStmtsFor(c);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- Buf* rfac_buf;
+ BufPtr rfac_buf;
ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf));
loop.distributeLoop(loops.at(0));
auto all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf);
LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][2]);
loop.prepareForCodegen();
loop.simplify();
- Stmt* s = loop.root_stmt();
+ StmtPtr s = loop.root_stmt();
std::ostringstream oss;
oss << *s;
ASSERT_TRUE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[0]));
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
s = IRSimplifier::simplify(s);
std::ostringstream oss;
// But if we rfactor this so it's not a reduce axis we can vectorize that
// loop.
- std::vector<For*> loops = l.getLoopStmtsFor(tensor);
+ std::vector<ForPtr> loops = l.getLoopStmtsFor(tensor);
LoopNest::reorderAxis(loops[0], loops[1]);
loops = l.getLoopStmtsFor(tensor);
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
- auto tensor_body = const_cast<Stmt*>(l.getAllWritesToBuf(tensor->buf())[1]);
- Buf* rfac_buf = nullptr;
+ auto tensor_body = l.getAllWritesToBuf(tensor->buf())[1];
+ BufPtr rfac_buf = nullptr;
ASSERT_TRUE(LoopNest::rfactor(tensor_body, loops.at(0), &rfac_buf));
LoopNest::distributeLoop(loops.at(0));
ASSERT_TRUE(LoopNest::vectorize(rfac_loops[1][0]));
l.simplify();
- Stmt* s = l.root_stmt();
+ StmtPtr s = l.root_stmt();
std::ostringstream oss;
oss << *s;
{{M, "m"}});
LoopNest nest({C});
nest.prepareForCodegen();
- Stmt* s = IRSimplifier::simplify(nest.root_stmt());
+ StmtPtr s = IRSimplifier::simplify(nest.root_stmt());
std::ostringstream oss;
oss << *s << "\n";
const std::string& expected_ir =
KernelScope kernel_scope;
BufHandle a("A", {1}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {0}, 0),
For::make(
x,
KernelScope kernel_scope;
BufHandle a("A", {10}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {0}, 0),
For::make(
x,
KernelScope kernel_scope;
BufHandle a("A", {1}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {0}, 0),
For::make(
x,
KernelScope kernel_scope;
BufHandle a("A", {1}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make({For::make(
+ StmtPtr stmt = Block::make({For::make(
x,
0,
10,
VarHandle x("x", kInt);
VarHandle y("y", kInt);
VarHandle z("z", kInt);
- Stmt* stmt = Block::make({For::make(
+ StmtPtr stmt = Block::make({For::make(
x,
0,
10,
KernelScope kernel_scope;
BufHandle a("A", {1}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{For::make(
x,
0,
KernelScope kernel_scope;
BufHandle a("A", {1}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{For::make(
x,
0,
BufHandle a("A", {1}, kInt);
VarHandle x("x", kInt);
VarHandle y("y", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{For::make(
x,
0,
KernelScope kernel_scope;
BufHandle a("A", {2}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make({
+ StmtPtr stmt = Block::make({
Store::make(a, {0}, 0),
Store::make(a, {1}, 0),
For::make(
BufHandle b("B", {10}, kInt);
VarHandle x("x", kInt);
VarHandle x2("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {0}, 0),
For::make(x, 0, 10, Store::make(b, {x}, x)),
For::make(
VarHandle N("N", kInt);
BufHandle a("A", {N}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {i}, 0),
For::make(
x,
BufHandle a("A", {1}, kInt);
VarHandle x("x", kInt);
VarHandle y("y", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {0}, 0),
For::make(
x,
KernelScope kernel_scope;
BufHandle a("A", {2}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make({
+ StmtPtr stmt = Block::make({
Store::make(a, {0}, 0),
Store::make(a, {1}, 0),
For::make(
KernelScope kernel_scope;
BufHandle a("A", {1}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {0}, 0),
For::make(
x, 0, 10, Block::make({Store::make(a, {0}, Add::make(x, 1))}))});
BufHandle a("A", {1}, kInt);
BufHandle b("B", {10}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {0}, 0),
For::make(
x,
KernelScope kernel_scope;
BufHandle a("A", {2}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make({
+ StmtPtr stmt = Block::make({
Store::make(a, {0}, 0),
Store::make(a, {1}, 0),
For::make(
BufHandle b("B", {Load::make(c, {0})}, kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Allocate::make(b),
Store::make(a, {0}, Load::make(c, {0})),
Store::make(b, {0}, 0),
KernelScope kernel_scope;
BufHandle a("A", {1}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make({For::make(
+ StmtPtr stmt = Block::make({For::make(
x,
0,
10,
KernelScope kernel_scope;
BufHandle a("A", {1}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make({For::make(
+ StmtPtr stmt = Block::make({For::make(
x,
0,
10,
BufHandle a("A", {1}, kInt);
BufHandle b("B", {1}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make({For::make(
+ StmtPtr stmt = Block::make({For::make(
x,
0,
10,
VarHandle x("x", kInt);
LoopOptions loopOpts;
loopOpts.set_gpu_block_index(0);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {0}, 0),
For::make(
x,
BufHandle c("C", {5}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {x}, Load::make(b, {x})),
Store::make(c, {x}, Load::make(a, {x})),
Cond::make(
BufHandle c("C", {5}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Cond::make(
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
BufHandle c("C", {5}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {x}, Load::make(b, {x})),
Store::make(c, {x}, Load::make(a, {x})),
Cond::make(
VarHandle x("x", kInt);
VarHandle y("y", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
{Store::make(a, {x}, Load::make(b, {x})),
Store::make(c, {x}, Load::make(a, {x})),
VarHandle x("x", kInt);
VarHandle y("y", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
{Store::make(a, {x}, Load::make(b, {x})),
Store::make(a, {x}, Load::make(b, {x + 1})),
BufHandle c("C", {5}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Cond::make(
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
BufHandle c("C", {5}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Cond::make(
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
BufHandle c("C", {5}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {x}, Load::make(b, {x})),
Store::make(c, {x}, Load::make(a, {x})),
Cond::make(
BufHandle c("C", {5}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make({Cond::make(
+ StmtPtr stmt = Block::make({Cond::make(
CompareSelect::make(Load::make(a, {x}), 5, CompareSelectOperation::kLT),
Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
Store::make(a, {x}, Add::make(Load::make(a, {x}), 10)))});
VarHandle x("x", kInt);
VarHandle y("y", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(
b,
{y},
VarHandle x("x", kInt);
VarHandle y("y", kInt);
- Stmt* stmt = Block::make({
+ StmtPtr stmt = Block::make({
Store::make(a, {x}, 0),
Store::make(
b,
BufHandle d("D", {5}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make({Store::make(
+ StmtPtr stmt = Block::make({Store::make(
a,
{x},
IfThenElse::make(
BufHandle b("B", {5}, kFloat);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make({Store::make(
+ StmtPtr stmt = Block::make({Store::make(
a,
{x},
IfThenElse::make(
BufHandle c("C", {5}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {x}, Load::make(a, {x})),
Store::make(
a,
BufHandle c("C", {5}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make({Store::make(
+ StmtPtr stmt = Block::make({Store::make(
b,
{x},
IfThenElse::make(
KernelScope kernel_scope;
BufHandle a("A", {5}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make({For::make(
+ StmtPtr stmt = Block::make({For::make(
x,
0,
10,
BufHandle c("C", {5}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make({Cond::make(
+ StmtPtr stmt = Block::make({Cond::make(
CompareSelect::make(
IfThenElse::make(
CompareSelect::make(
VarHandle x("x", kInt);
VarHandle y("y", kInt);
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
y,
0,
10,
VarHandle x("x", kInt);
VarHandle y("y", kInt);
- Stmt* stmt = Block::make({For::make(
+ StmtPtr stmt = Block::make({For::make(
y,
0,
10,
KernelScope kernel_scope;
BufHandle a("A", {1}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {0}, 0),
For::make(
x,
KernelScope kernel_scope;
BufHandle a("A", {1}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))),
Store::make(a, {0}, 0),
For::make(
VarHandle x1("x1", kInt);
VarHandle x2("x2", kInt);
VarHandle x3("x3", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {0}, 2),
For::make(
x1, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x1))),
KernelScope kernel_scope;
BufHandle a("A", {1}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {0}, 2),
For::make(
x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x))),
KernelScope kernel_scope;
BufHandle a("A", {1}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {0}, 1),
Store::make(a, {0}, 3),
Cond::make(
KernelScope kernel_scope;
BufHandle a("A", {1}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {0}, 1),
Store::make(a, {0}, 3),
Cond::make(
KernelScope kernel_scope;
BufHandle a("A", {1}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {1}, Load::make(a, {0})),
Store::make(a, {0}, Load::make(a, {1})),
Store::make(a, {0}, Load::make(a, {1})),
KernelScope kernel_scope;
BufHandle a("A", {1}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
{Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), 2))}),
KernelScope kernel_scope;
BufHandle a("A", {1}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make({Cond::make(
+ StmtPtr stmt = Block::make({Cond::make(
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
Block::make(
{Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
KernelScope kernel_scope;
BufHandle a("A", {1}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
Cond::make(
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
KernelScope kernel_scope;
BufHandle a("A", {1}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Cond::make(
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
KernelScope kernel_scope;
BufHandle a("A", {1}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Cond::make(
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
Block::make({Cond::make(
KernelScope kernel_scope;
BufHandle a("A", {1}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
Cond::make(
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
BufHandle a("A", {10}, kInt);
BufHandle b("B", {10}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Cond::make(
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
BufHandle a("A", {10}, kInt);
BufHandle b("B", {10}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {4}, 0),
Cond::make(
CompareSelect::make(x, 2, CompareSelectOperation::kGT),
BufHandle a("A", {1}, kInt);
VarHandle x("x", kInt);
VarHandle y("y", kInt);
- Stmt* stmt = Block::make({For::make(
+ StmtPtr stmt = Block::make({For::make(
y,
0,
10,
BufHandle b("B", {10}, kInt);
VarHandle x("x", kInt);
VarHandle y("y", kInt);
- Stmt* stmt = Block::make({Cond::make(
+ StmtPtr stmt = Block::make({Cond::make(
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
Block::make(
{Store::make(a, {0}, 0),
BufHandle b("B", {10}, kInt);
VarHandle x("x", kInt);
VarHandle y("y", kInt);
- Stmt* stmt = Block::make({Cond::make(
+ StmtPtr stmt = Block::make({Cond::make(
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
Block::make({For::make(
x,
BufHandle b("B", {10}, kInt);
VarHandle x("x", kInt);
VarHandle y("y", kInt);
- Stmt* stmt = Block::make({Cond::make(
+ StmtPtr stmt = Block::make({Cond::make(
CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
Block::make(
{Store::make(a, {0}, 0),
KernelScope kernel_scope;
BufHandle a("A", {1}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Cond::make(
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
For::make(
KernelScope kernel_scope;
BufHandle a("A", {1}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Cond::make(
CompareSelect::make(x, 5, CompareSelectOperation::kLT),
For::make(
BufHandle a("A", {10}, kInt);
VarHandle x("x", kInt);
VarHandle y("y", kInt);
- Stmt* stmt = Block::make({For::make(
+ StmtPtr stmt = Block::make({For::make(
x,
0,
10,
BufHandle a("A", {10}, kInt);
VarHandle x("x", kInt);
VarHandle y("y", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Let::make(y, 30),
For::make(
x,
KernelScope kernel_scope;
BufHandle a("A", {3, 4, 5}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {0, 1, 2}, 0),
For::make(
x,
KernelScope kernel_scope;
BufHandle a("A", {3, 4, 5}, kInt);
VarHandle x("x", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {0, 1, 2}, 0),
For::make(
x,
BufHandle a("A", {3, 4, 5}, kInt);
VarHandle x("x", kInt);
VarHandle y("y", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {0, 1, 2}, 0),
For::make(
x,
BufHandle a("A", {3, 4, 5}, kInt);
VarHandle x("x", kInt);
VarHandle y("y", kInt);
- Stmt* stmt = Block::make(
+ StmtPtr stmt = Block::make(
{Store::make(a, {0, 1, 2}, 0),
For::make(
x,
VarHandle x("x", kInt);
VarHandle y("y", kInt);
VarHandle z("z", kInt);
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
x,
0,
10,
VarHandle x("x", kInt);
VarHandle y("y", kInt);
VarHandle z("z", kInt);
- Stmt* stmt = For::make(
+ StmtPtr stmt = For::make(
x,
0,
10,
ExprHandle body = x * (ExprHandle(2) + ExprHandle(4));
ExprHandle newF = IRSimplifier::simplify(body);
- Mul* root = newF.AsNode<Mul>();
+ MulPtr root = newF.AsNode<Mul>();
ASSERT_NE(root, nullptr);
- ASSERT_NE(dynamic_cast<const IntImm*>(root->lhs()), nullptr);
+ ASSERT_NE(to<IntImm>(root->lhs()), nullptr);
SimpleIRExprEval eval(newF);
eval.bindVar(x, ExprHandle(3));
ExprHandle body = x * (ExprHandle(2.f) + ExprHandle(4.f));
ExprHandle newF = IRSimplifier::simplify(body);
- Mul* root = newF.AsNode<Mul>();
+ MulPtr root = newF.AsNode<Mul>();
ASSERT_NE(root, nullptr);
- ASSERT_NE(dynamic_cast<const FloatImm*>(root->rhs()), nullptr);
+ ASSERT_NE(to<FloatImm>(root->rhs()), nullptr);
SimpleIRExprEval eval(newF);
eval.bindVar(x, ExprHandle(3.f));
ExprHandle f = x < 4.f;
ExprHandle newF = IRSimplifier::simplify(f);
- const IntImm* folded = newF.AsNode<IntImm>();
+ IntImmPtr folded = newF.AsNode<IntImm>();
ASSERT_EQ(folded, nullptr);
{
ExprHandle body = (ExprHandle(3) * x) + (ExprHandle(5) * y);
ExprHandle newF = IRSimplifier::simplify(body);
- Add* root = newF.AsNode<Add>();
+ AddPtr root = newF.AsNode<Add>();
ASSERT_NE(root, nullptr);
- ASSERT_EQ(dynamic_cast<const FloatImm*>(root->lhs()), nullptr);
- ASSERT_EQ(dynamic_cast<const FloatImm*>(root->rhs()), nullptr);
+ ASSERT_EQ(to<FloatImm>(root->lhs()), nullptr);
+ ASSERT_EQ(to<FloatImm>(root->rhs()), nullptr);
SimpleIRExprEval eval(newF);
eval.bindVar(x, ExprHandle(3.f));
VarHandle y("y", kFloat);
ExprHandle f = (x * y) + (x * y);
- Add* root = f.AsNode<Add>();
+ AddPtr root = f.AsNode<Add>();
ASSERT_NE(root, nullptr);
HashProvider hasher;
ExprHandle f =
Intrinsics::make(kRand, kFloat) + Intrinsics::make(kRand, kInt);
- Add* root = f.AsNode<Add>();
+ AddPtr root = f.AsNode<Add>();
ASSERT_NE(root, nullptr);
HashProvider hasher;
KernelScope kernel_scope;
HashProvider hasher;
- std::vector<Expr*> immediates;
+ std::vector<ExprPtr> immediates;
- immediates.push_back(new DoubleImm(1));
- immediates.push_back(new FloatImm(1));
- immediates.push_back(new HalfImm(1));
+ immediates.push_back(alloc<DoubleImm>(1));
+ immediates.push_back(alloc<FloatImm>(1));
+ immediates.push_back(alloc<HalfImm>(1));
// NOLINTNEXTLINE(modernize-use-bool-literals)
- immediates.push_back(new BoolImm(1));
- immediates.push_back(new CharImm(1));
- immediates.push_back(new ByteImm(1));
- immediates.push_back(new ShortImm(1));
- immediates.push_back(new IntImm(1));
- immediates.push_back(new LongImm(1));
+ immediates.push_back(alloc<BoolImm>(1));
+ immediates.push_back(alloc<CharImm>(1));
+ immediates.push_back(alloc<ByteImm>(1));
+ immediates.push_back(alloc<ShortImm>(1));
+ immediates.push_back(alloc<IntImm>(1));
+ immediates.push_back(alloc<LongImm>(1));
// Immediates of different types are not equal.
for (unsigned int i = 0; i < immediates.size(); ++i) {
ExprHandle body = (ExprHandle(2) + x) + ExprHandle(4);
ExprHandle simplified = IRSimplifier::simplify(body);
- Add* root = simplified.AsNode<Add>();
+ AddPtr root = simplified.AsNode<Add>();
ASSERT_NE(root, nullptr);
- Var* lhs = dynamic_cast<Var*>(root->lhs());
+ VarPtr lhs = to<Var>(root->lhs());
ASSERT_NE(lhs, nullptr);
ASSERT_EQ(lhs->name_hint(), "x");
- const IntImm* rhs = dynamic_cast<const IntImm*>(root->rhs());
+ IntImmPtr rhs = to<IntImm>(root->rhs());
ASSERT_NE(rhs, nullptr);
ASSERT_EQ(rhs->value(), 6.f);
}
ExprHandle body = (ExprHandle(2) - x) - ExprHandle(4);
ExprHandle simplified = IRSimplifier::simplify(body);
- Sub* root = simplified.AsNode<Sub>();
+ SubPtr root = simplified.AsNode<Sub>();
ASSERT_NE(root, nullptr);
- const IntImm* lhs = dynamic_cast<const IntImm*>(root->lhs());
+ IntImmPtr lhs = to<IntImm>(root->lhs());
ASSERT_NE(lhs, nullptr);
ASSERT_EQ(lhs->value(), -2.f);
- Var* rhs = dynamic_cast<Var*>(root->rhs());
+ VarPtr rhs = to<Var>(root->rhs());
ASSERT_NE(rhs, nullptr);
ASSERT_EQ(rhs->name_hint(), "x");
}
(ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4)));
ExprHandle simplified = IRSimplifier::simplify(body);
- Mul* root = simplified.AsNode<Mul>();
+ MulPtr root = simplified.AsNode<Mul>();
ASSERT_NE(root, nullptr);
- const IntImm* lhs = dynamic_cast<const IntImm*>(root->lhs());
+ IntImmPtr lhs = to<IntImm>(root->lhs());
ASSERT_NE(lhs, nullptr);
ASSERT_EQ(lhs->value(), 2);
- Var* rhs = dynamic_cast<Var*>(root->rhs());
+ VarPtr rhs = to<Var>(root->rhs());
ASSERT_NE(rhs, nullptr);
ASSERT_EQ(rhs->name_hint(), "x");
}
(ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4)));
ExprHandle simplified = IRSimplifier::simplify(body);
- Mul* root = simplified.AsNode<Mul>();
+ MulPtr root = simplified.AsNode<Mul>();
ASSERT_NE(root, nullptr);
- const LongImm* lhs = dynamic_cast<const LongImm*>(root->lhs());
+ LongImmPtr lhs = to<LongImm>(root->lhs());
ASSERT_NE(lhs, nullptr);
ASSERT_EQ(lhs->value(), 2);
- Var* rhs = dynamic_cast<Var*>(root->rhs());
+ VarPtr rhs = to<Var>(root->rhs());
ASSERT_NE(rhs, nullptr);
ASSERT_EQ(rhs->name_hint(), "x");
}
ExprHandle body = (x + ExprHandle(0)) * 1;
ExprHandle simplified = IRSimplifier::simplify(body);
- Var* root = simplified.AsNode<Var>();
+ VarPtr root = simplified.AsNode<Var>();
ASSERT_NE(root, nullptr);
ASSERT_EQ(root->name_hint(), "x");
}
ExprHandle simplified = IRSimplifier::simplify(body);
- Add* root = simplified.AsNode<Add>();
+ AddPtr root = simplified.AsNode<Add>();
ASSERT_NE(root, nullptr);
- Mul* lhs = dynamic_cast<Mul*>(root->lhs());
+ MulPtr lhs = to<Mul>(root->lhs());
ASSERT_NE(lhs, nullptr);
- Var* varX = dynamic_cast<Var*>(lhs->rhs());
+ VarPtr varX = to<Var>(lhs->rhs());
ASSERT_NE(varX, nullptr);
ASSERT_EQ(varX->name_hint(), "y");
- Mul* rhs = dynamic_cast<Mul*>(root->rhs());
+ MulPtr rhs = to<Mul>(root->rhs());
ASSERT_NE(rhs, nullptr);
- Var* varY = dynamic_cast<Var*>(rhs->rhs());
+ VarPtr varY = to<Var>(rhs->rhs());
ASSERT_NE(varY, nullptr);
ASSERT_EQ(varY->name_hint(), "x");
}
ExprHandle body = x + 2 + y;
ExprHandle simplified = IRSimplifier::simplify(body);
- Add* root = simplified.AsNode<Add>();
+ AddPtr root = simplified.AsNode<Add>();
ASSERT_NE(root, nullptr);
IS_NODE_WITH_NAME(Add, root->lhs(), rhs);
BufHandle a_buf("A", {6}, kInt);
auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) / 6));
- const Stmt* simplified = IRSimplifier::simplify(for_stmt);
+ const StmtPtr simplified = IRSimplifier::simplify(for_stmt);
std::ostringstream oss;
oss << *(simplified);
BufHandle a_buf("A", {5}, kInt);
auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + 25) / 6));
- const Stmt* simplified = IRSimplifier::simplify(for_stmt);
+ const StmtPtr simplified = IRSimplifier::simplify(for_stmt);
std::ostringstream oss;
oss << *(simplified);
BufHandle a_buf("A", {6}, kInt);
auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) / (-6)));
- const Stmt* simplified = IRSimplifier::simplify(for_stmt);
+ const StmtPtr simplified = IRSimplifier::simplify(for_stmt);
std::ostringstream oss;
oss << *(simplified);
BufHandle a_buf("A", {5}, kInt);
auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + (-5)) / 6));
- const Stmt* simplified = IRSimplifier::simplify(for_stmt);
+ const StmtPtr simplified = IRSimplifier::simplify(for_stmt);
std::ostringstream oss;
oss << *(simplified);
auto for_j = For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) / 6));
auto for_i = For::make(i, 0, 6, for_j);
- const Stmt* simplified = IRSimplifier::simplify(for_i);
+ const StmtPtr simplified = IRSimplifier::simplify(for_i);
std::ostringstream oss;
oss << *(simplified);
For::make(j, -1, 9, Store::make(a_buf, {i, j + 1}, (i + j * 6) / 6));
auto for_i = For::make(i, 0, 6, for_j);
- const Stmt* simplified = IRSimplifier::simplify(for_i);
+ const StmtPtr simplified = IRSimplifier::simplify(for_i);
std::ostringstream oss;
oss << *(simplified);
For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) / (-6)));
auto for_i = For::make(i, 0, 6, for_j);
- const Stmt* simplified = IRSimplifier::simplify(for_i);
+ const StmtPtr simplified = IRSimplifier::simplify(for_i);
std::ostringstream oss;
oss << *(simplified);
BufHandle a_buf("A", {100}, kInt);
auto for_stmt = For::make(i, 0, 100, Store::make(a_buf, {i}, (i % 100)));
- const Stmt* simplified = IRSimplifier::simplify(for_stmt);
+ const StmtPtr simplified = IRSimplifier::simplify(for_stmt);
std::ostringstream oss;
oss << *(simplified);
BufHandle a_buf("A", {6}, kInt);
auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) % 6));
- const Stmt* simplified = IRSimplifier::simplify(for_stmt);
+ const StmtPtr simplified = IRSimplifier::simplify(for_stmt);
std::ostringstream oss;
oss << *(simplified);
BufHandle a_buf("A", {5}, kInt);
auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + 25) % 6));
- const Stmt* simplified = IRSimplifier::simplify(for_stmt);
+ const StmtPtr simplified = IRSimplifier::simplify(for_stmt);
std::ostringstream oss;
oss << *(simplified);
BufHandle a_buf("A", {6}, kInt);
auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) % (-6)));
- const Stmt* simplified = IRSimplifier::simplify(for_stmt);
+ const StmtPtr simplified = IRSimplifier::simplify(for_stmt);
std::ostringstream oss;
oss << *(simplified);
BufHandle a_buf("A", {5}, kInt);
auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + (-5)) % 6));
- const Stmt* simplified = IRSimplifier::simplify(for_stmt);
+ const StmtPtr simplified = IRSimplifier::simplify(for_stmt);
std::ostringstream oss;
oss << *(simplified);
auto for_j = For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) % 6));
auto for_i = For::make(i, 0, 6, for_j);
- const Stmt* simplified = IRSimplifier::simplify(for_i);
+ const StmtPtr simplified = IRSimplifier::simplify(for_i);
std::ostringstream oss;
oss << *(simplified);
For::make(j, -1, 9, Store::make(a_buf, {i, j + 1}, (i + j * 6) % 6));
auto for_i = For::make(i, 0, 6, for_j);
- const Stmt* simplified = IRSimplifier::simplify(for_i);
+ const StmtPtr simplified = IRSimplifier::simplify(for_i);
std::ostringstream oss;
oss << *(simplified);
For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) % (-6)));
auto for_i = For::make(i, 0, 6, for_j);
- const Stmt* simplified = IRSimplifier::simplify(for_i);
+ const StmtPtr simplified = IRSimplifier::simplify(for_i);
std::ostringstream oss;
oss << *(simplified);
BufHandle a("A", {1}, kInt);
BufHandle b("B", {1}, kInt);
ExprHandle condition(1);
- Stmt* true_val = Store::make(a, {0}, 1);
- Stmt* false_val = Store::make(b, {0}, 1);
+ StmtPtr true_val = Store::make(a, {0}, 1);
+ StmtPtr false_val = Store::make(b, {0}, 1);
- Cond* body = new Cond(condition.node(), true_val, false_val);
- Stmt* simplified = IRSimplifier::simplify(body);
- Block* block = dynamic_cast<Block*>(simplified);
+ CondPtr body = alloc<Cond>(condition.node(), true_val, false_val);
+ StmtPtr simplified = IRSimplifier::simplify(body);
+ BlockPtr block = to<Block>(simplified);
IS_NODE_WITH_NAME(Store, block->front(), store);
IS_VAR_WITH_NAME(store->base_handle(), "A");
}
BufHandle a("A", {1}, kInt);
BufHandle b("B", {1}, kInt);
ExprHandle condition(0);
- Stmt* true_val = Store::make(a, {0}, 1);
- Stmt* false_val = Store::make(b, {0}, 1);
+ StmtPtr true_val = Store::make(a, {0}, 1);
+ StmtPtr false_val = Store::make(b, {0}, 1);
- Stmt* body = new Cond(condition.node(), true_val, false_val);
- Stmt* simplified = IRSimplifier::simplify(body);
- Block* block = dynamic_cast<Block*>(simplified);
+ StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
+ StmtPtr simplified = IRSimplifier::simplify(body);
+ BlockPtr block = to<Block>(simplified);
IS_NODE_WITH_NAME(Store, block->front(), store);
IS_VAR_WITH_NAME(store->base_handle(), "B");
}
BufHandle a("A", {1}, kInt);
BufHandle b("B", {1}, kInt);
ExprHandle condition(x - x);
- Stmt* true_val = Store::make(a, {0}, 1);
- Stmt* false_val = Store::make(b, {0}, 1);
+ StmtPtr true_val = Store::make(a, {0}, 1);
+ StmtPtr false_val = Store::make(b, {0}, 1);
- Stmt* body = new Cond(condition.node(), true_val, false_val);
- Stmt* simplified = IRSimplifier::simplify(body);
- Block* block = dynamic_cast<Block*>(simplified);
+ StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
+ StmtPtr simplified = IRSimplifier::simplify(body);
+ BlockPtr block = to<Block>(simplified);
IS_NODE_WITH_NAME(Store, block->front(), store);
IS_VAR_WITH_NAME(store->base_handle(), "B");
}
VarHandle x("x", kInt);
BufHandle a("A", {1}, kInt);
ExprHandle condition(x - x);
- Stmt* true_val = Store::make(a, {0}, x);
- Stmt* false_val = Store::make(a, {0}, x);
+ StmtPtr true_val = Store::make(a, {0}, x);
+ StmtPtr false_val = Store::make(a, {0}, x);
- Stmt* body = new Cond(condition.node(), true_val, false_val);
- Stmt* simplified = IRSimplifier::simplify(body);
- Block* block = dynamic_cast<Block*>(simplified);
+ StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
+ StmtPtr simplified = IRSimplifier::simplify(body);
+ BlockPtr block = to<Block>(simplified);
IS_NODE_WITH_NAME(Store, block->front(), store);
IS_VAR_WITH_NAME(store->base_handle(), "A");
}
VarHandle x("x", kInt);
BufHandle a("A", {1}, kInt);
ExprHandle condition(x - x);
- Stmt* true_val = Store::make(a, {0}, ExprHandle(2) * x);
- Stmt* false_val = Store::make(a, {0}, x + x);
+ StmtPtr true_val = Store::make(a, {0}, ExprHandle(2) * x);
+ StmtPtr false_val = Store::make(a, {0}, x + x);
- Stmt* body = new Cond(condition.node(), true_val, false_val);
- Stmt* simplified = IRSimplifier::simplify(body);
- Block* block = dynamic_cast<Block*>(simplified);
+ StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
+ StmtPtr simplified = IRSimplifier::simplify(body);
+ BlockPtr block = to<Block>(simplified);
IS_NODE_WITH_NAME(Store, block->front(), store);
IS_VAR_WITH_NAME(store->base_handle(), "A");
}
VarHandle x("x", kInt);
BufHandle a("A", {1}, kInt);
ExprHandle condition(x);
- Stmt* true_val = Store::make(a, {0}, x);
- Stmt* false_val = Store::make(a, {0}, ExprHandle(2) * x);
+ StmtPtr true_val = Store::make(a, {0}, x);
+ StmtPtr false_val = Store::make(a, {0}, ExprHandle(2) * x);
- Stmt* body = new Cond(condition.node(), true_val, false_val);
- Stmt* simplified = IRSimplifier::simplify(body);
- Block* block = dynamic_cast<Block*>(simplified);
+ StmtPtr body = alloc<Cond>(condition.node(), true_val, false_val);
+ StmtPtr simplified = IRSimplifier::simplify(body);
+ BlockPtr block = to<Block>(simplified);
ASSERT_EQ(block, nullptr);
}
{
- Stmt* cond = new Cond(ExprHandle(false).node(), new Block({}), nullptr);
- Stmt* simplified = IRSimplifier::simplify(cond);
+ StmtPtr cond = alloc<Cond>(
+ ExprHandle(false).node(),
+ alloc<Block>(std::vector<StmtPtr>({})),
+ nullptr);
+ StmtPtr simplified = IRSimplifier::simplify(cond);
ASSERT_EQ(simplified, nullptr);
}
{
- Stmt* cond = new Cond(ExprHandle(true).node(), nullptr, new Block({}));
- Stmt* simplified = IRSimplifier::simplify(cond);
+ StmtPtr cond = alloc<Cond>(
+ ExprHandle(true).node(),
+ nullptr,
+ alloc<Block>(std::vector<StmtPtr>({})));
+ StmtPtr simplified = IRSimplifier::simplify(cond);
ASSERT_EQ(simplified, nullptr);
}
}
{
VarHandle x("x", kInt);
ExprHandle condition(x);
- Stmt* true_val = new Block({});
+ StmtPtr true_val = alloc<Block>(std::vector<StmtPtr>({}));
- Stmt* body = new Cond(condition.node(), true_val, nullptr);
- Stmt* simplified = IRSimplifier::simplify(body);
- Block* block = dynamic_cast<Block*>(simplified);
+ StmtPtr body = alloc<Cond>(condition.node(), true_val, nullptr);
+ StmtPtr simplified = IRSimplifier::simplify(body);
+ BlockPtr block = to<Block>(simplified);
ASSERT_NE(block, nullptr);
ASSERT_EQ(block->nstmts(), 0);
}
{
VarHandle x("x", kInt);
ExprHandle condition(x);
- Stmt* false_val = new Block({});
+ StmtPtr false_val = alloc<Block>(std::vector<StmtPtr>({}));
- Stmt* body = new Cond(condition.node(), nullptr, false_val);
- Stmt* simplified = IRSimplifier::simplify(body);
- Block* block = dynamic_cast<Block*>(simplified);
+ StmtPtr body = alloc<Cond>(condition.node(), nullptr, false_val);
+ StmtPtr simplified = IRSimplifier::simplify(body);
+ BlockPtr block = to<Block>(simplified);
ASSERT_NE(block, nullptr);
ASSERT_EQ(block->nstmts(), 0);
}
BufHandle c("C", {4}, kInt);
VarHandle i("i", kInt);
auto body = For::make(i, 0, 0, Store::make(c, {i}, Load::make(a, {i})));
- Stmt* simplified = IRSimplifier::simplify(body);
- Block* block = dynamic_cast<Block*>(simplified);
+ StmtPtr simplified = IRSimplifier::simplify(body);
+ BlockPtr block = to<Block>(simplified);
ASSERT_EQ(block->nstmts(), 0);
}
BufHandle c("C", {4}, kInt);
VarHandle i("i", kInt);
auto body = For::make(i, 2, 2, Store::make(c, {i}, Load::make(a, {i})));
- Stmt* simplified = IRSimplifier::simplify(body);
- Block* block = dynamic_cast<Block*>(simplified);
+ StmtPtr simplified = IRSimplifier::simplify(body);
+ BlockPtr block = to<Block>(simplified);
ASSERT_EQ(block->nstmts(), 0);
}
BufHandle c("C", {4}, kInt);
VarHandle i("i", kInt);
auto body = For::make(i, x, x, Store::make(c, {i}, Load::make(a, {i})));
- Stmt* simplified = IRSimplifier::simplify(body);
- Block* block = dynamic_cast<Block*>(simplified);
+ StmtPtr simplified = IRSimplifier::simplify(body);
+ BlockPtr block = to<Block>(simplified);
ASSERT_EQ(block->nstmts(), 0);
}
BufHandle c("C", {4}, kInt);
VarHandle i("i", kInt);
auto body = For::make(i, 0, x - x, Store::make(c, {i}, Load::make(a, {i})));
- Stmt* simplified = IRSimplifier::simplify(body);
- Block* block = dynamic_cast<Block*>(simplified);
+ StmtPtr simplified = IRSimplifier::simplify(body);
+ BlockPtr block = to<Block>(simplified);
ASSERT_EQ(block->nstmts(), 0);
}
BufHandle c("C", {4}, kInt);
VarHandle i("i", kInt);
auto body = For::make(i, 0, 3, Store::make(c, {i}, Load::make(a, {i})));
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE(For, simplified);
}
}
BufHandle c("C", {4}, kInt);
VarHandle i("i", kInt);
auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})));
- Stmt* simplified = IRSimplifier::simplify(body);
- Block* block = dynamic_cast<Block*>(simplified);
+ StmtPtr simplified = IRSimplifier::simplify(body);
+ BlockPtr block = to<Block>(simplified);
IS_NODE_WITH_NAME(Store, block->front(), store);
IS_VAR_WITH_NAME(store->base_handle(), "C");
IS_IMM_WITH_VAL(Int, store->flat_index(), 0);
BufHandle c("C", {4}, kInt);
VarHandle i("i", kInt);
auto body = For::make(i, 2, 3, Store::make(c, {i}, Load::make(a, {i})));
- Stmt* simplified = IRSimplifier::simplify(body);
- Block* block = dynamic_cast<Block*>(simplified);
+ StmtPtr simplified = IRSimplifier::simplify(body);
+ BlockPtr block = to<Block>(simplified);
IS_NODE_WITH_NAME(Store, block->front(), store);
IS_VAR_WITH_NAME(store->base_handle(), "C");
IS_IMM_WITH_VAL(Int, store->flat_index(), 2);
BufHandle c("C", {4}, kInt);
VarHandle i("i", kInt);
auto body = For::make(i, x, x + 1, Store::make(c, {i}, Load::make(a, {i})));
- Stmt* simplified = IRSimplifier::simplify(body);
- Block* block = dynamic_cast<Block*>(simplified);
+ StmtPtr simplified = IRSimplifier::simplify(body);
+ BlockPtr block = to<Block>(simplified);
IS_NODE_WITH_NAME(Store, block->front(), store);
IS_VAR_WITH_NAME(store->base_handle(), "C");
IS_VAR_WITH_NAME(store->flat_index(), "x");
VarHandle i("i", kInt);
auto body =
For::make(i, 0, x - x + 1, Store::make(c, {i}, Load::make(a, {i})));
- Stmt* simplified = IRSimplifier::simplify(body);
- Block* block = dynamic_cast<Block*>(simplified);
+ StmtPtr simplified = IRSimplifier::simplify(body);
+ BlockPtr block = to<Block>(simplified);
IS_NODE_WITH_NAME(Store, block->front(), store);
IS_VAR_WITH_NAME(store->base_handle(), "C");
IS_IMM_WITH_VAL(Int, store->flat_index(), 0);
BufHandle c("C", {4}, kInt);
VarHandle i("i", kInt);
auto body = For::make(i, 0, 3, Store::make(c, {i}, Load::make(a, {i})));
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE(For, simplified);
}
}
options.set_gpu_block_index(12);
auto body =
For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})), options);
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(For, simplified, for_);
LoopOptions options2 = for_->loop_options();
ASSERT_EQ(options.gpu_block_index(), options2.gpu_block_index());
BufHandle c("C", {4}, kInt);
VarHandle i("i", kInt);
VarHandle j("j", kInt);
- auto* body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})));
+ auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})));
auto outer = For::make(j, 0, 1, body);
- Stmt* simplified = IRSimplifier::simplify(outer);
- Block* block = dynamic_cast<Block*>(simplified);
+ StmtPtr simplified = IRSimplifier::simplify(outer);
+ BlockPtr block = to<Block>(simplified);
IS_NODE_WITH_NAME(Store, block->front(), store);
IS_VAR_WITH_NAME(store->base_handle(), "C");
IS_IMM_WITH_VAL(Int, store->flat_index(), 0);
BufHandle c("C", {4}, kInt);
VarHandle i("i", kInt);
VarHandle j("j", kInt);
- auto* body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})));
+ auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})));
auto outer = For::make(j, 0, 2, body);
- Stmt* simplified = IRSimplifier::simplify(outer);
- For* for__ = static_cast<For*>(simplified);
+ StmtPtr simplified = IRSimplifier::simplify(outer);
+ ForPtr for__ = static_to<For>(simplified);
IS_NODE_WITH_NAME(For, for__, for_);
IS_VAR_WITH_NAME(for_->var(), "j");
IS_IMM_WITH_VAL(Int, for_->start(), 0);
IS_IMM_WITH_VAL(Int, for_->stop(), 2);
- Block* block = dynamic_cast<Block*>(for_->body());
+ BlockPtr block = to<Block>(for_->body());
ASSERT_NE(block, nullptr);
IS_NODE_WITH_NAME(Store, block->front(), store);
IS_VAR_WITH_NAME(store->base_handle(), "C");
BufHandle c("C", {4}, kInt);
VarHandle i("i", kInt);
VarHandle j("j", kInt);
- auto* body = For::make(i, 0, 2, Store::make(c, {i}, Load::make(a, {i})));
+ auto body = For::make(i, 0, 2, Store::make(c, {i}, Load::make(a, {i})));
auto outer = For::make(j, 0, 1, body);
- Stmt* simplified = IRSimplifier::simplify(outer);
- Block* block = dynamic_cast<Block*>(simplified);
+ StmtPtr simplified = IRSimplifier::simplify(outer);
+ BlockPtr block = to<Block>(simplified);
IS_NODE_WITH_NAME(For, block->front(), for_);
IS_VAR_WITH_NAME(for_->var(), "i");
IS_IMM_WITH_VAL(Int, for_->start(), 0);
LoopNest l({b});
l.prepareForCodegen();
- Stmt* body = l.root_stmt();
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr body = l.root_stmt();
+ StmtPtr simplified = IRSimplifier::simplify(body);
- Block* block = dynamic_cast<Block*>(simplified);
+ BlockPtr block = to<Block>(simplified);
IS_NODE_WITH_NAME(For, block->front(), for_);
// for is over "m".
IS_VAR_WITH_NAME(for_->var(), "m");
{
// Flatten many layers around an empty block to an empty block.
- Stmt* last = new Block({});
+ StmtPtr last = alloc<Block>(std::vector<StmtPtr>({}));
for (int i = 0; i < 11; ++i) {
VarHandle loopVar("loopVar", kInt);
last = For::make(loopVar, 0, 10, last);
}
- Stmt* simplified = IRSimplifier::simplify(last);
+ StmtPtr simplified = IRSimplifier::simplify(last);
IS_NODE_WITH_NAME(Block, simplified, block);
ASSERT_EQ(block->nstmts(), 0);
}
// Flatten multiple blocks down to one.
// { { { stmt1, stmt2 } } } => { stmt1, stmt2 }
BufHandle a("A", {1}, kInt);
- Store* store1 = Store::make(a, {0}, 1);
- Store* store2 = Store::make(a, {0}, 0);
+ StorePtr store1 = Store::make(a, {0}, 1);
+ StorePtr store2 = Store::make(a, {0}, 0);
- Block* block1 = new Block({store1, store2});
- Block* block2 = new Block({block1});
+ BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({store1, store2}));
+ BlockPtr block2 = alloc<Block>(std::vector<StmtPtr>({block1}));
- Block* enclosing = new Block({block2});
- Stmt* simplified = IRSimplifier::simplify(enclosing);
+ BlockPtr enclosing = alloc<Block>(std::vector<StmtPtr>({block2}));
+ StmtPtr simplified = IRSimplifier::simplify(enclosing);
IS_NODE_WITH_NAME(Block, simplified, block);
ASSERT_EQ(block->nstmts(), 2);
// Flatten multiple sub blocks containing statements.
// { { stmt1 }, { stmt2 } } => { stmt1, stmt2 }
BufHandle a("A", {1}, kInt);
- Store* store1 = Store::make(a, {0}, 1);
- Store* store2 = Store::make(a, {0}, 0);
+ StorePtr store1 = Store::make(a, {0}, 1);
+ StorePtr store2 = Store::make(a, {0}, 0);
- Block* block1 = new Block({store1});
- Block* block2 = new Block({store2});
+ BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({store1}));
+ BlockPtr block2 = alloc<Block>(std::vector<StmtPtr>({store2}));
- Block* enclosing = new Block({block1, block2});
- Stmt* simplified = IRSimplifier::simplify(enclosing);
+ BlockPtr enclosing = alloc<Block>(std::vector<StmtPtr>({block1, block2}));
+ StmtPtr simplified = IRSimplifier::simplify(enclosing);
IS_NODE_WITH_NAME(Block, simplified, block);
ASSERT_EQ(block->nstmts(), 2);
// Flatten sub blocks with different depths.
// { stmt1 , { { stmt2 } } } => { stmt1, stmt2 }
BufHandle a("A", {1}, kInt);
- Store* store1 = Store::make(a, {0}, 1);
- Store* store2 = Store::make(a, {0}, 0);
+ StorePtr store1 = Store::make(a, {0}, 1);
+ StorePtr store2 = Store::make(a, {0}, 0);
- Block* block1 = new Block({store2});
- Block* block2 = new Block({block1});
+ BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({store2}));
+ BlockPtr block2 = alloc<Block>(std::vector<StmtPtr>({block1}));
- Block* enclosing = new Block({store1, block2});
- Stmt* simplified = IRSimplifier::simplify(enclosing);
+ BlockPtr enclosing = alloc<Block>(std::vector<StmtPtr>({store1, block2}));
+ StmtPtr simplified = IRSimplifier::simplify(enclosing);
IS_NODE_WITH_NAME(Block, simplified, block);
ASSERT_EQ(block->nstmts(), 2);
{
// Flatten many layers around an empty block to an empty block.
- Stmt* last = new Block({});
+ StmtPtr last = alloc<Block>(std::vector<StmtPtr>({}));
for (int i = 0; i < 11; ++i) {
- last = new Block({last});
+ last = alloc<Block>(std::vector<StmtPtr>({last}));
}
- Stmt* simplified = IRSimplifier::simplify(last);
+ StmtPtr simplified = IRSimplifier::simplify(last);
IS_NODE_WITH_NAME(Block, simplified, block);
ASSERT_EQ(block->nstmts(), 0);
}
// Simple positive case.
BufHandle b("x", {0}, kInt);
- Allocate* alloc = Allocate::make(b);
- Free* free_ = Free::make(b);
+ AllocatePtr alloc_ = Allocate::make(b);
+ FreePtr free_ = Free::make(b);
- Block* block1 = new Block({alloc, free_});
+ BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({alloc_, free_}));
ASSERT_EQ(block1->nstmts(), 2);
- Stmt* simplified = IRSimplifier::simplify(block1);
+ StmtPtr simplified = IRSimplifier::simplify(block1);
IS_NODE_WITH_NAME(Block, simplified, block2);
ASSERT_EQ(block2->nstmts(), 0);
}
// Simple negative case.
BufHandle b("x", {2}, kInt);
- Allocate* alloc = Allocate::make(b);
- Free* free_ = Free::make(b);
+ AllocatePtr alloc_ = Allocate::make(b);
+ FreePtr free_ = Free::make(b);
- Block* block1 = new Block({alloc, free_});
+ BlockPtr block1 = alloc<Block>(std::vector<StmtPtr>({alloc_, free_}));
ASSERT_EQ(block1->nstmts(), 2);
- Stmt* simplified = IRSimplifier::simplify(block1);
+ StmtPtr simplified = IRSimplifier::simplify(block1);
IS_NODE_WITH_NAME(Block, simplified, block2);
ASSERT_EQ(block2->nstmts(), 2);
}
BufHandle b1("x", {0}, kInt);
BufHandle b2("y", {2}, kInt);
- Allocate* alloc1 = Allocate::make(b1);
- Allocate* alloc2 = Allocate::make(b2);
- Free* free2_ = Free::make(b2);
- Free* free1_ = Free::make(b1);
+ AllocatePtr alloc1 = Allocate::make(b1);
+ AllocatePtr alloc2 = Allocate::make(b2);
+ FreePtr free2_ = Free::make(b2);
+ FreePtr free1_ = Free::make(b1);
- Block* block1 = new Block({alloc1, alloc2, free2_, free1_});
+ BlockPtr block1 =
+ alloc<Block>(std::vector<StmtPtr>({alloc1, alloc2, free2_, free1_}));
ASSERT_EQ(block1->nstmts(), 4);
- Stmt* simplified = IRSimplifier::simplify(block1);
+ StmtPtr simplified = IRSimplifier::simplify(block1);
IS_NODE_WITH_NAME(Block, simplified, block2);
ASSERT_EQ(block2->nstmts(), 2);
IS_NODE_WITH_NAME(Allocate, block2->stmts().front(), simplified_alloc);
BufHandle b1("x", {0}, kInt);
BufHandle b2("y", {z}, kInt);
- Allocate* alloc1 = Allocate::make(b1);
- Allocate* alloc2 = Allocate::make(b2);
- Free* free2_ = Free::make(b2);
- Free* free1_ = Free::make(b1);
+ AllocatePtr alloc1 = Allocate::make(b1);
+ AllocatePtr alloc2 = Allocate::make(b2);
+ FreePtr free2_ = Free::make(b2);
+ FreePtr free1_ = Free::make(b1);
- Block* block1 = new Block({alloc1, alloc2, free2_, free1_});
+ BlockPtr block1 =
+ alloc<Block>(std::vector<StmtPtr>({alloc1, alloc2, free2_, free1_}));
ASSERT_EQ(block1->nstmts(), 4);
- Stmt* simplified = IRSimplifier::simplify(block1);
+ StmtPtr simplified = IRSimplifier::simplify(block1);
IS_NODE_WITH_NAME(Block, simplified, block2);
ASSERT_EQ(block2->nstmts(), 2);
}
Store::make(c, {i}, Load::make(a, {i})),
nullptr));
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Cond, simplified, cond);
IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block);
IS_NODE_WITH_NAME(For, true_block->front(), loop);
Store::make(c, {i}, Load::make(a, {i})),
nullptr));
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(For, simplified, loop);
IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond);
}
Store::make(c, {0}, Load::make(a, {i})),
nullptr));
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(For, simplified, loop);
IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond);
}
Store::make(c, {0}, Load::make(a, {i})),
nullptr));
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Cond, simplified, cond);
IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block);
IS_NODE_WITH_NAME(For, true_block->front(), loop);
Store::make(c, {0}, Load::make(a, {i})),
nullptr));
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Cond, simplified, cond);
IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block);
IS_NODE_WITH_NAME(For, true_block->front(), loop);
Store::make(c, {0}, Load::make(a, {i})),
nullptr)}));
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(For, simplified, loop);
IS_NODE_WITH_NAME(Let, loop->body()->front(), let);
IS_NODE_WITH_NAME(Cond, loop->body()->back(), cond);
nullptr),
nullptr));
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Cond, simplified, cond);
IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block);
IS_NODE_WITH_NAME(Cond, true_block->front(), cond2);
nullptr),
nullptr));
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Cond, simplified, cond);
IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block);
IS_NODE_WITH_NAME(For, true_block->front(), loop);
Store::make(c, {0}, Load::make(a, {i})),
Store::make(c, {0}, 0)));
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(For, simplified, loop);
IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond);
}
Store::make(c, {1}, Load::make(a, {i})),
nullptr));
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(For, simplified, loop);
IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond);
}
Store::make(a, {1}, i),
nullptr)});
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
IS_NODE_WITH_NAME(Block, simplified, block);
ASSERT_EQ(block->nstmts(), 1);
Store::make(a, {1}, i),
nullptr)});
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Block, simplified, block);
ASSERT_EQ(block->nstmts(), 2);
IS_NODE_WITH_NAME(Cond, block->front(), cond1);
Store::make(a, {1}, i),
nullptr)});
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Block, simplified, block);
ASSERT_EQ(block->nstmts(), 2);
IS_NODE_WITH_NAME(Cond, block->front(), cond1);
Store::make(a, {1}, i),
nullptr)});
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Block, simplified, block);
ASSERT_EQ(block->nstmts(), 2);
IS_NODE_WITH_NAME(Cond, block->front(), cond1);
// TODO for later.
auto body = Block::make(
{Cond::make(
- CompareSelect::make(
- i,
- 10,
- new IntImm(1),
- new IntImm(0),
- CompareSelectOperation::kLT),
+ CompareSelect::make(i, 10, 1, 0, CompareSelectOperation::kLT),
Store::make(a, {0}, i),
nullptr),
Cond::make(
- CompareSelect::make(
- j,
- 10,
- new IntImm(2),
- new IntImm(0),
- CompareSelectOperation::kLT),
+ CompareSelect::make(j, 10, 2, 0, CompareSelectOperation::kLT),
Store::make(a, {1}, i),
nullptr)});
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Block, simplified, block);
ASSERT_EQ(block->nstmts(), 2);
IS_NODE_WITH_NAME(Cond, block->front(), cond1);
nullptr,
Store::make(a, {1}, i))});
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Block, simplified, block);
ASSERT_EQ(block->nstmts(), 1);
IS_NODE_WITH_NAME(Cond, block->front(), cond);
Store::make(a, {1}, i),
Store::make(b, {1}, i))});
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Block, simplified, block);
ASSERT_EQ(block->nstmts(), 1);
IS_NODE_WITH_NAME(Cond, block->front(), cond);
nullptr,
Store::make(b, {1}, i))});
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Block, simplified, block);
ASSERT_EQ(block->nstmts(), 1);
IS_NODE_WITH_NAME(Cond, block->front(), cond);
Store::make(a, {1}, j),
nullptr),
});
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Block, simplified, block);
ASSERT_EQ(block->nstmts(), 3);
auto it = block->begin();
Store::make(a, {1}, j),
nullptr),
});
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Block, simplified, block);
ASSERT_EQ(block->nstmts(), 1);
IS_NODE_WITH_NAME(Cond, block->front(), cond);
Store::make(a, {1}, j),
nullptr),
});
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Block, simplified, block);
ASSERT_EQ(block->nstmts(), 3);
IS_NODE_WITH_NAME(Cond, block->front(), cond);
CompareSelectOperation::kLT),
Store::make(a, {1}, i),
nullptr)});
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Block, simplified, block);
ASSERT_EQ(block->nstmts(), 1);
IS_NODE_WITH_NAME(Cond, block->front(), cond);
{Cond::make(i, Store::make(a, {0}, i), nullptr),
Cond::make(i, Store::make(a, {1}, i), nullptr)});
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Block, simplified, block);
ASSERT_EQ(block->nstmts(), 1);
IS_NODE_WITH_NAME(Cond, block->front(), cond);
{Cond::make(i, Store::make(a, {0}, i), nullptr),
Cond::make(j, Store::make(a, {1}, i), nullptr)});
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Block, simplified, block);
ASSERT_EQ(block->nstmts(), 2);
IS_NODE_WITH_NAME(Cond, block->front(), cond1);
auto body = Block::make(
{Cond::make(1, Store::make(a, {0}, i), nullptr),
Cond::make(1, Store::make(a, {1}, i), nullptr)});
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Block, simplified, block);
ASSERT_EQ(block->nstmts(), 2);
IS_NODE_WITH_NAME(Store, block->front(), store1);
Store::make(a, {2}, Load::make(b, {0})),
nullptr)}));
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Cond, simplified, cond);
IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block);
IS_NODE_WITH_NAME(For, true_block->front(), loop);
auto body = Block::make(
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
{Store::make(a, {0}, 1),
- new SyncThreads(),
- new SyncThreads(),
+ alloc<SyncThreads>(),
+ alloc<SyncThreads>(),
Store::make(a, {1}, 0)});
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Block, simplified, block);
ASSERT_EQ(block->nstmts(), 3);
auto it = block->begin();
{
// Eliminate outer SyncThreads.
auto body = Block::make(
- {new SyncThreads(), Store::make(a, {1}, 0), new SyncThreads()});
+ {alloc<SyncThreads>(), Store::make(a, {1}, 0), alloc<SyncThreads>()});
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Block, simplified, block);
ASSERT_EQ(block->nstmts(), 1);
auto it = block->begin();
// Merge many inner SyncThreads.
auto body = Block::make(
{Store::make(a, {0}, 1),
- new SyncThreads(),
- new SyncThreads(),
- new SyncThreads(),
- new SyncThreads(),
- new SyncThreads(),
+ alloc<SyncThreads>(),
+ alloc<SyncThreads>(),
+ alloc<SyncThreads>(),
+ alloc<SyncThreads>(),
+ alloc<SyncThreads>(),
Store::make(a, {1}, 0)});
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Block, simplified, block);
ASSERT_EQ(block->nstmts(), 3);
auto it = block->begin();
{
// Merge multiple outer SyncThreads.
auto body = Block::make(
- {new SyncThreads(),
- new SyncThreads(),
+ {alloc<SyncThreads>(),
+ alloc<SyncThreads>(),
Store::make(a, {1}, 0),
- new SyncThreads(),
- new SyncThreads(),
- new SyncThreads(),
- new SyncThreads()});
+ alloc<SyncThreads>(),
+ alloc<SyncThreads>(),
+ alloc<SyncThreads>(),
+ alloc<SyncThreads>()});
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Block, simplified, block);
ASSERT_EQ(block->nstmts(), 1);
auto it = block->begin();
// Merge multiple sections;
auto body = Block::make(
{Store::make(a, {0}, 1),
- new SyncThreads(),
- new SyncThreads(),
+ alloc<SyncThreads>(),
+ alloc<SyncThreads>(),
Store::make(a, {1}, 0),
Store::make(a, {2}, 0),
- new SyncThreads(),
- new SyncThreads(),
- new SyncThreads(),
+ alloc<SyncThreads>(),
+ alloc<SyncThreads>(),
+ alloc<SyncThreads>(),
Store::make(a, {3}, 0)});
- Stmt* simplified = IRSimplifier::simplify(body);
+ StmtPtr simplified = IRSimplifier::simplify(body);
IS_NODE_WITH_NAME(Block, simplified, block);
ASSERT_EQ(block->nstmts(), 6);
auto it = block->begin();
ExprHandle ramp = Ramp::make(ExprHandle(0), ExprHandle(6), num_lanes);
ExprHandle broadcast = Broadcast::make(ExprHandle(-5), num_lanes);
ExprHandle simplified = IRSimplifier::simplify(ramp - broadcast);
- Ramp* newRamp = simplified.AsNode<Ramp>();
+ RampPtr newRamp = simplified.AsNode<Ramp>();
IS_NODE_WITH_NAME(IntImm, newRamp->base(), base);
ASSERT_EQ(base->value(), 5);
IS_NODE_WITH_NAME(IntImm, newRamp->stride(), stride);
constexpr int N = 8;
Placeholder b("b", kFloat, {N});
VarHandle n("n", kInt);
- Stmt* s = For::make(
+ StmtPtr s = For::make(
n, 1, N, b.store({n}, CompareSelect::make(n, 1, 0.f, 1.0f, kLT)));
s = IRSimplifier::simplify(s);
std::ostringstream oss;
constexpr int N = 8;
Placeholder b("b", kFloat, {N});
VarHandle n("n", kInt);
- Stmt* s =
+ StmtPtr s =
For::make(n, 1, N, b.store({n}, IfThenElse::make(n < 1, 0.f, 1.0f)));
s = IRSimplifier::simplify(s);
std::ostringstream oss;
csel = CompareSelect::make(j, 1, 1, csel, kLT);
csel = CompareSelect::make(i, N - 1, 1, csel, kGE);
csel = CompareSelect::make(j, N - 1, 1, csel, kGE);
- Stmt* s = b.store({i, j}, IfThenElse::make(csel, 0.f, 1.0f));
+ StmtPtr s = b.store({i, j}, IfThenElse::make(csel, 0.f, 1.0f));
s = For::make(j, 1, N - 1, s);
s = For::make(i, 1, N - 1, s);
s = IRSimplifier::simplify(s);
csel = CompareSelect::make(j, 1, 1, csel, kLT);
csel = CompareSelect::make(i, N - 1, 1, csel, kGE);
csel = CompareSelect::make(j, N - 1, 1, csel, kGE);
- Stmt* s = b.store(
+ StmtPtr s = b.store(
{i, j}, b.load({i, j}) + IfThenElse::make(csel, 0.f, a.load({i, j})));
s = For::make(j, 0, K, s);
s = For::make(i, 0, K, s);
namespace jit {
using namespace torch::jit::tensorexpr;
-#define IS_NODE(T, node) \
- { \
- auto* node_ = dynamic_cast<const T*>(node); \
- ASSERT_NE(nullptr, node_); \
+#define IS_NODE(T, node) \
+ { \
+ auto node_ = to<T>(node); \
+ ASSERT_NE(nullptr, node_); \
}
-#define IS_NODE_WITH_NAME(T, node, name) \
- auto* name = dynamic_cast<const T*>(node); \
+#define IS_NODE_WITH_NAME(T, node, name) \
+ auto name = to<T>(node); \
ASSERT_NE(nullptr, name);
#define IS_NODE_WITH_NAME_AND_CAST(T, node, name, Type) \
- const T* name = nullptr; \
+ NodePtr<T> name = nullptr; \
{ \
- auto* node_ = dynamic_cast<const Cast*>(node); \
+ auto node_ = to<Cast>(node); \
ASSERT_NE(nullptr, node_); \
ASSERT_EQ(node_->dtype().scalar_type(), ScalarType::Type); \
- name = dynamic_cast<const T*>(node_->src_value()); \
+ name = to<T>(node_->src_value()); \
} \
ASSERT_NE(nullptr, name);
-#define IS_IMM_WITH_VAL(T, node, val) \
- { \
- auto* node_ = dynamic_cast<const T##Imm*>(node); \
- ASSERT_NE(nullptr, node_); \
- ASSERT_EQ(node_->value(), val); \
+#define IS_IMM_WITH_VAL(T, node, val) \
+ { \
+ auto node_ = to<T##Imm>(node); \
+ ASSERT_NE(nullptr, node_); \
+ ASSERT_EQ(node_->value(), val); \
}
-#define IS_VAR_WITH_NAME(node, name) \
- { \
- auto* node_ = dynamic_cast<const Var*>(node); \
- ASSERT_NE(nullptr, node_); \
- ASSERT_EQ(node_->name_hint(), name); \
+#define IS_VAR_WITH_NAME(node, name) \
+ { \
+ auto node_ = to<Var>(node); \
+ ASSERT_NE(nullptr, node_); \
+ ASSERT_EQ(node_->name_hint(), name); \
}
#define IS_BINOP_W_VARS(T, node, name, v1, v2) \
- const T* name = nullptr; \
+ NodePtr<T> name = nullptr; \
{ \
- name = dynamic_cast<const T*>(node); \
+ name = to<T>(node); \
ASSERT_NE(nullptr, name); \
IS_VAR_WITH_NAME(name->lhs(), v1); \
IS_VAR_WITH_NAME(name->rhs(), v2); \
}
#define IS_BINOP_W_CONST(T, node, name, v, c) \
- const T* name = nullptr; \
+ NodePtr<T> name = nullptr; \
{ \
- name = dynamic_cast<const T*>(node); \
+ name = to<T>(node); \
ASSERT_NE(nullptr, name); \
IS_VAR_WITH_NAME(name->lhs(), v); \
IS_IMM_WITH_VAL(Int, name->rhs(), c); \
}
-#define IS_RAND(node) \
- { \
- auto* node_ = dynamic_cast<const Intrinsics*>(node); \
- ASSERT_NE(nullptr, node_); \
- ASSERT_EQ(node_->op_type(), kRand); \
+#define IS_RAND(node) \
+ { \
+ auto node_ = to<Intrinsics>(node); \
+ ASSERT_NE(nullptr, node_); \
+ ASSERT_EQ(node_->op_type(), kRand); \
}
} // namespace jit
// also be a 'Mul' or some other expression.
//
// Let's construct a simple TE:
- Expr* lhs = new IntImm(5);
- Expr* rhs = new Var("x", kInt);
- Expr* mul = new Mul(lhs, rhs);
+ ExprPtr lhs = alloc<IntImm>(5);
+ ExprPtr rhs = alloc<Var>("x", kInt);
+ ExprPtr mul = alloc<Mul>(lhs, rhs);
std::cout << "Tensor expression: " << *mul << std::endl;
// Prints: Tensor expression: 5 * x
// Let's start with defining a domain. We do this by creating a Buf object.
// First, let's specify the sizes:
- std::vector<Expr*> dims = {
- new IntImm(64), new IntImm(32)}; // IntImm stands for Integer Immediate
+ std::vector<ExprPtr> dims = {
+ alloc<IntImm>(64),
+ alloc<IntImm>(32)}; // IntImm stands for Integer Immediate
// and represents an integer constant
// Now we can create a Buf object by providing a name, dimensions, and a
// data type of the elements:
- Buf* buf = new Buf("X", dims, kInt);
+ BufPtr buf = alloc<Buf>("X", dims, kInt);
// Next we need to specify the computation. We can do that by either
// constructing a complete tensor statement for it (statements are
// Let's define two variables, i and j - they will be axis in our
// computation.
- Var* i = new Var("i", kInt);
- Var* j = new Var("j", kInt);
- std::vector<Var*> args = {i, j};
+ VarPtr i = alloc<Var>("i", kInt);
+ VarPtr j = alloc<Var>("j", kInt);
+ std::vector<VarPtr> args = {i, j};
// Now we can define the body of the tensor computation using these
// variables. What this means is that values in our tensor are:
// X[i, j] = i * j
- Expr* body = new Mul(i, j);
+ ExprPtr body = alloc<Mul>(i, j);
// Finally, we pass all these pieces together to Tensor constructor:
Tensor* X = new Tensor(buf, args, body);
// Loop transformations can be composed, so we can do something else with
// our loop nest now. Let's split the inner loop with a factor of 9, for
// instance.
- std::vector<For*> loops = loopnest.getLoopStmtsFor(Y);
+ std::vector<ForPtr> loops = loopnest.getLoopStmtsFor(Y);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* j_inner;
+ ForPtr j_inner;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* j_tail;
+ ForPtr j_tail;
int split_factor = 9;
loopnest.splitWithTail(
loops[1], // loops[0] is the outer loop, loops[1] is inner
tensorexpr::Tensor* target,
int width) {
using namespace torch::jit::tensorexpr;
- std::vector<For*> loops = ln->getLoopStmtsFor(target);
- For *inner, *tail;
+ std::vector<ForPtr> loops = ln->getLoopStmtsFor(target);
+ ForPtr inner, tail;
TORCH_CHECK(loops.size() > 0, "No loops created for pointwise op");
ln->splitWithTail(loops[0], width, &inner, &tail);
ln->vectorize(inner);
LoopNest ln({out});
optimizePointwise(&ln, out, width);
ln.prepareForCodegen();
- Stmt* s = ln.root_stmt();
+ StmtPtr s = ln.root_stmt();
s = tensorexpr::IRSimplifier::simplify(s);
std::vector<CodeGen::BufferArg> args;
args.emplace_back(out);
namespace tensorexpr {
class HasRand : public IRVisitor {
public:
- HasRand(Stmt* stmt) : stmt_(stmt) {
+ HasRand(StmtPtr stmt) : stmt_(stmt) {
stmt_->accept(this);
}
}
private:
- void visit(Intrinsics* v) override {
+ void visit(IntrinsicsPtr v) override {
if (v->op_type() == IntrinsicsOp::kRand) {
has_rand_ = true;
} else {
IRVisitor::visit(v);
}
}
- Stmt* stmt_;
+ StmtPtr stmt_;
bool has_rand_ = false;
};
-template <typename Node>
+template <typename Op>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class NodeFinder : public IRVisitor {
public:
- void visit(Node* v) override {
- nodes.push_back((Node*)v);
+ void visit(NodePtr<Op> v) override {
+ nodes.push_back((NodePtr<Op>)v);
IRVisitor::visit(v);
}
- static std::vector<Node*> find(Stmt* s) {
- NodeFinder<Node> nf;
+ static std::vector<NodePtr<Op>> find(StmtPtr s) {
+ NodeFinder<Op> nf;
s->accept(&nf);
return nf.nodes;
}
- static std::vector<Node*> find(Expr* e) {
- NodeFinder<Node> nf;
+ static std::vector<NodePtr<Op>> find(ExprPtr e) {
+ NodeFinder<Op> nf;
e->accept(&nf);
return nf.nodes;
}
- std::vector<Node*> nodes;
+ std::vector<NodePtr<Op>> nodes;
};
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class VarFinder : public IRVisitor {
public:
- void visit(Var* v) override {
+ void visit(VarPtr v) override {
vars_.insert(v);
IRVisitor::visit(v);
}
- static std::unordered_set<Var*> find(Stmt* s) {
+ static std::unordered_set<VarPtr> find(StmtPtr s) {
VarFinder nf;
s->accept(&nf);
return nf.vars();
}
- static std::unordered_set<Var*> find(Expr* e) {
+ static std::unordered_set<VarPtr> find(ExprPtr e) {
VarFinder nf;
e->accept(&nf);
return nf.vars();
}
- const std::unordered_set<Var*>& vars() {
+ const std::unordered_set<VarPtr>& vars() {
return vars_;
}
private:
- std::unordered_set<Var*> vars_;
+ std::unordered_set<VarPtr> vars_;
};
class BufFinder : public IRVisitor {
public:
- void visit(Buf* v) override {
+ void visit(BufPtr v) override {
bufs_.insert(v);
IRVisitor::visit(v);
}
- static std::unordered_set<Buf*> find(Stmt* s) {
+ static std::unordered_set<BufPtr> find(StmtPtr s) {
BufFinder nf;
s->accept(&nf);
return nf.bufs();
}
- static std::unordered_set<Buf*> find(Expr* e) {
+ static std::unordered_set<BufPtr> find(ExprPtr e) {
BufFinder nf;
e->accept(&nf);
return nf.bufs();
}
- const std::unordered_set<Buf*>& bufs() {
+ const std::unordered_set<BufPtr>& bufs() {
return bufs_;
}
private:
- std::unordered_set<Buf*> bufs_;
+ std::unordered_set<BufPtr> bufs_;
};
// Finds all kinds of write operations to the provided Buf.
class WritesToBuf : public IRVisitor {
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- WritesToBuf(Buf* target) : target_(target) {}
+ WritesToBuf(BufPtr target) : target_(target) {}
- std::vector<Stmt*> writes() {
+ std::vector<StmtPtr> writes() {
return writes_;
}
- static std::vector<Stmt*> find(Stmt* s, Buf* b) {
+ static std::vector<StmtPtr> find(StmtPtr s, BufPtr b) {
WritesToBuf finder(b);
s->accept(&finder);
return finder.writes();
}
private:
- void visit(Store* v) override {
+ void visit(StorePtr v) override {
if (v->buf() == target_) {
writes_.push_back(v);
}
}
- void visit(AtomicAdd* v) override {
+ void visit(AtomicAddPtr v) override {
if (v->buf() == target_) {
writes_.push_back(v);
}
}
- Buf* target_;
- std::vector<Stmt*> writes_;
+ BufPtr target_;
+ std::vector<StmtPtr> writes_;
};
class StmtsReadingBuf : public IRVisitor {
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- StmtsReadingBuf(Buf* target) : target_(target) {}
+ StmtsReadingBuf(BufPtr target) : target_(target) {}
- std::vector<Stmt*> reads() {
+ std::vector<StmtPtr> reads() {
return reads_;
}
- static std::vector<Stmt*> find(Stmt* s, Buf* b) {
+ static std::vector<StmtPtr> find(StmtPtr s, BufPtr b) {
StmtsReadingBuf finder(b);
s->accept(&finder);
return finder.reads();
}
private:
- bool readsBuffer(Stmt* s) {
+ bool readsBuffer(StmtPtr s) {
auto loads = NodeFinder<Load>::find(s);
for (auto l : loads) {
if (l->buf() == target_) {
return false;
}
- void visit(Store* v) override {
+ void visit(StorePtr v) override {
if (readsBuffer(v)) {
reads_.push_back(v);
}
}
- void visit(Let* v) override {
+ void visit(LetPtr v) override {
if (readsBuffer(v)) {
reads_.push_back(v);
}
}
- void visit(Cond* v) override {
+ void visit(CondPtr v) override {
if (readsBuffer(v)) {
reads_.push_back(v);
}
}
- void visit(AtomicAdd* v) override {
+ void visit(AtomicAddPtr v) override {
if (readsBuffer(v)) {
reads_.push_back(v);
}
}
- Buf* target_;
- std::vector<Stmt*> reads_;
+ BufPtr target_;
+ std::vector<StmtPtr> reads_;
};
// Traverses the IR to determine if a particular Var is modified within it.
class ModifiesVarChecker : public IRVisitor {
public:
- ModifiesVarChecker(Var* v) : var_(v) {}
+ ModifiesVarChecker(VarPtr v) : var_(v) {}
- static bool check(Stmt* s, Var* v) {
+ static bool check(StmtPtr s, VarPtr v) {
ModifiesVarChecker checker(v);
s->accept(&checker);
return checker.found();
}
private:
- void visit(Store* v) override {
+ void visit(StorePtr v) override {
if (v->buf()->base_handle() == var_) {
found_ = true;
return;
IRVisitor::visit(v);
}
- void visit(AtomicAdd* v) override {
+ void visit(AtomicAddPtr v) override {
if (v->buf()->base_handle() == var_) {
found_ = true;
return;
IRVisitor::visit(v);
}
- void visit(Let* v) override {
+ void visit(LetPtr v) override {
if (v->var() == var_) {
found_ = true;
return;
IRVisitor::visit(v);
}
- void visit(For* v) override {
+ void visit(ForPtr v) override {
if (v->var() == var_) {
found_ = true;
return;
IRVisitor::visit(v);
}
- Var* var_;
+ VarPtr var_;
bool found_{false};
};
// Creates a map of multi-dim buffers and their flat versions
class CreateBufferMap : public IRVisitor {
public:
- const std::unordered_map<std::string, Buf*>& getBufferMap() const {
+ const std::unordered_map<std::string, BufPtr>& getBufferMap() const {
return map_input_to_tensor_bufs_;
}
private:
- void visit(Store* v) override {
- auto load_node = dynamic_cast<Load*>(v->value());
+ void visit(StorePtr v) override {
+ auto load_node = to<Load>(v->value());
if (load_node) {
auto t_buf = load_node->buf();
map_input_to_tensor_bufs_.emplace(t_buf->name_hint(), v->buf());
} else {
- auto add_node = dynamic_cast<Add*>(v->value());
- auto mul_node = dynamic_cast<Mul*>(v->value());
+ auto add_node = to<Add>(v->value());
+ auto mul_node = to<Mul>(v->value());
// This means for now, v->value() can be Add or Mul
TORCH_INTERNAL_ASSERT((add_node || mul_node));
map_input_to_tensor_bufs_.emplace(v->buf()->name_hint(), v->buf());
}
v->value()->accept(this);
}
- std::unordered_map<std::string, Buf*> map_input_to_tensor_bufs_;
+ std::unordered_map<std::string, BufPtr> map_input_to_tensor_bufs_;
};
} // namespace tensorexpr
}
}
-bool BlockAnalysis::areBufsInMap(const std::unordered_set<Buf*>& bufs) const {
+bool BlockAnalysis::areBufsInMap(const std::unordered_set<BufPtr>& bufs) const {
for (auto const& arg : bufs) {
auto got = map_input_to_tensor_bufs_.find(arg->name_hint());
if (got == map_input_to_tensor_bufs_.end()) {
return true;
}
-Buf* BlockAnalysis::getMultiDimBuf(Buf* buf) const {
+BufPtr BlockAnalysis::getMultiDimBuf(BufPtr buf) const {
auto input_ = map_input_to_tensor_bufs_.find(buf->name_hint());
if (input_ != map_input_to_tensor_bufs_.end()) {
return input_->second;
}
}
-std::string BlockAnalysis::getInputName(Buf* buf) const {
+std::string BlockAnalysis::getInputName(BufPtr buf) const {
auto input_ = map_input_to_tensor_bufs_.find(buf->name_hint());
if (input_ != map_input_to_tensor_bufs_.end()) {
return input_->second->name_hint();
}
}
-void BlockAnalysis::visit(Store* v) {
+void BlockAnalysis::visit(StorePtr v) {
store_targets_.insert(v->buf());
v->value()->accept(this);
}
-void BlockAnalysis::visit(Load* v) {
+void BlockAnalysis::visit(LoadPtr v) {
loads_.insert(v->buf());
}
-void BlockAnalysis::visit(For* v) {
+void BlockAnalysis::visit(ForPtr v) {
const LoopOptions& loop_options = v->loop_options();
if (loop_options.is_gpu_block_index()) {
map_input_to_tensor_bufs_ = loop_options.get_buffer_mapping();
v->body()->accept(this);
} else if (loop_options.is_gpu_thread_index()) {
auto block_size = v->stop();
- block_size_ = dynamic_cast<IntImm*>(block_size)->value();
+ block_size_ = to<IntImm>(block_size)->value();
v->body()->accept(this);
} else {
IRVisitor::visit(v);
// TODO: When handling fused ops d = a + b + c, the correct
// way would be to mutate the expression to Block version and print.
-void BlockPrinter::visit(Add* v) {
+void BlockPrinter::visit(AddPtr v) {
emitIndent();
os() << "add(";
v->lhs()->accept(this);
v->rhs()->accept(this);
}
-void BlockPrinter::visit(Mul* v) {
+void BlockPrinter::visit(MulPtr v) {
emitIndent();
os() << "mul(";
v->lhs()->accept(this);
v->rhs()->accept(this);
}
-void BlockPrinter::visit(For* v) {
+void BlockPrinter::visit(ForPtr v) {
const LoopOptions& loop_options = v->loop_options();
auto buf_reads = block_analysis_->loads();
auto buf_writes = block_analysis_->stores();
- std::unordered_set<Buf*> bufs(buf_reads.begin(), buf_reads.end());
+ std::unordered_set<BufPtr> bufs(buf_reads.begin(), buf_reads.end());
bufs.insert(buf_writes.begin(), buf_writes.end());
if (loop_options.is_gpu_block_index()) {
}
}
-void BlockPrinter::PrintTensorInfo(const std::unordered_set<Buf*>& bufs) {
+void BlockPrinter::PrintTensorInfo(const std::unordered_set<BufPtr>& bufs) {
os() << "tensors {";
for (auto& buf : bufs) {
os() << std::endl;
os() << "}" << std::endl << std::endl;
}
-void BlockPrinter::PrintArguments(const std::unordered_set<Buf*>& bufs) {
+void BlockPrinter::PrintArguments(const std::unordered_set<BufPtr>& bufs) {
for (auto& buf : bufs) {
auto multidimbuf = block_analysis_->getMultiDimBuf(buf);
auto num_dims = multidimbuf->dims().size();
// The dims for the multi-dim tensors
for (unsigned long d = 0; d < num_dims; d++) {
- auto dim_val = dynamic_cast<IntImm*>(multidimbuf->dim(d));
+ auto dim_val = to<IntImm>(multidimbuf->dim(d));
this->dim_values_map.emplace(this->dim_names[d], dim_val->value());
}
// The dimensions for the flattened tensors
- auto val = dynamic_cast<IntImm*>(buf->dim(0));
+ auto val = to<IntImm>(buf->dim(0));
if (block_analysis_->is_buf_store_target(buf)) {
this->dim_values_map.emplace(
this->flat_dim_names[num_dims - 1], val->value());
os() << "}" << std::endl << std::endl;
}
-void BlockPrinter::PrintBufferInfo(const std::unordered_set<Buf*>& bufs) {
+void BlockPrinter::PrintBufferInfo(const std::unordered_set<BufPtr>& bufs) {
emitIndent();
os() << "buffers {";
for (auto& read : bufs) {
os() << "}" << std::endl << std::endl;
}
-void BlockPrinter::PrintDistribution(const std::unordered_set<Buf*>& bufs) {
+void BlockPrinter::PrintDistribution(const std::unordered_set<BufPtr>& bufs) {
emitIndent();
os() << "distribution {" << std::endl;
for (auto& buf : bufs) {
}
void BlockPrinter::PrintLoop(
- const std::unordered_set<Buf*>& bufs,
+ const std::unordered_set<BufPtr>& bufs,
bool block_idx) {
emitIndent();
os() << "loop (";
}
void BlockPrinter::PrintReshapeInfo(
- const std::unordered_set<Buf*>& bufs,
+ const std::unordered_set<BufPtr>& bufs,
bool reverse) {
for (auto& buf : bufs) {
emitIndent();
}
}
-void BlockPrinter::PrintDMAs(const std::unordered_set<Buf*>& bufs) {
+void BlockPrinter::PrintDMAs(const std::unordered_set<BufPtr>& bufs) {
for (auto& read : bufs) {
emitIndent();
os() << "dma_in(";
os() << ")" << std::endl;
}
}
-void BlockPrinter::PrintAdjustBuffers(const std::unordered_set<Buf*>& bufs) {
+void BlockPrinter::PrintAdjustBuffers(const std::unordered_set<BufPtr>& bufs) {
for (auto& read : bufs) {
emitIndent();
os() << "adjust_buffer(";
}
}
-void BlockPrinter::visit(Load* v) {
+void BlockPrinter::visit(LoadPtr v) {
os() << block_analysis_->getFlatInputName(v->buf()) << ".buffer, ";
}
-void BlockPrinter::visit(Store* v) {
+void BlockPrinter::visit(StorePtr v) {
emitIndent();
os() << *v->value() << block_analysis_->getFlatInputName(v->buf())
<< ".tensor)" << std::endl;
}
-void BlockPrinter::visit(Block* v) {
+void BlockPrinter::visit(BlockPtr v) {
os() << "{" << std::endl;
indent_++;
- for (Stmt* s : v->stmts()) {
+ for (StmtPtr s : v->stmts()) {
s->accept(this);
}
indent_--;
block_analysis_ = std::make_unique<BlockAnalysis>();
printer_ = std::make_unique<BlockPrinter>(&oss_, block_analysis_.get());
- Stmt* stmt_v = stmt();
+ StmtPtr stmt_v = stmt();
stmt_v->accept(block_analysis_.get());
auto buf_reads = block_analysis_->loads();
auto buf_writes = block_analysis_->stores();
// Ensure all Bufs in reads/writes are in the map
- std::unordered_set<Buf*> bufs(buf_reads.begin(), buf_reads.end());
+ std::unordered_set<BufPtr> bufs(buf_reads.begin(), buf_reads.end());
bufs.insert(buf_writes.begin(), buf_writes.end());
if (!block_analysis_->areBufsInMap(bufs)) {
throw std::runtime_error("BlockCodeGen: Entry not in input/Buffer map");
// A class that analyzes the given program relevant for Block backend.
class BlockAnalysis : public IRVisitor {
public:
- bool is_buf_store_target(Buf* buf) const {
+ bool is_buf_store_target(BufPtr buf) const {
return store_targets_.count(buf) > 0;
}
- const std::unordered_set<Buf*>& loads() const {
+ const std::unordered_set<BufPtr>& loads() const {
return loads_;
}
- const std::unordered_set<Buf*>& stores() const {
+ const std::unordered_set<BufPtr>& stores() const {
return store_targets_;
}
return block_size_;
}
- bool areBufsInMap(const std::unordered_set<Buf*>& bufs) const;
+ bool areBufsInMap(const std::unordered_set<BufPtr>& bufs) const;
- Buf* getMultiDimBuf(Buf* buf) const;
+ BufPtr getMultiDimBuf(BufPtr buf) const;
- std::string getInputName(Buf* buf) const;
+ std::string getInputName(BufPtr buf) const;
- std::string getFlatInputName(Buf* buf) const {
+ std::string getFlatInputName(BufPtr buf) const {
return getInputName(buf) + "_flat";
}
- std::unordered_map<std::string, Buf*> getBufferMap() const {
+ std::unordered_map<std::string, BufPtr> getBufferMap() const {
return map_input_to_tensor_bufs_;
}
private:
- void visit(Store* v) override;
- void visit(Load* v) override;
- void visit(For* v) override;
+ void visit(StorePtr v) override;
+ void visit(LoadPtr v) override;
+ void visit(ForPtr v) override;
- std::unordered_map<std::string, Buf*> map_input_to_tensor_bufs_;
- std::unordered_set<Buf*> store_targets_;
- std::unordered_set<Buf*> loads_;
+ std::unordered_map<std::string, BufPtr> map_input_to_tensor_bufs_;
+ std::unordered_set<BufPtr> store_targets_;
+ std::unordered_set<BufPtr> loads_;
int block_size_ = 32;
};
std::unordered_map<std::string, int> dim_values_map;
std::vector<std::string> dim_names = {"N", "H", "W", "C"};
std::vector<std::string> flat_dim_names = {"N", "NH", "NHW", "NHWC"};
- void PrintTensorInfo(const std::unordered_set<Buf*>& bufs);
- void PrintArguments(const std::unordered_set<Buf*>& bufs);
- void PrintBufferInfo(const std::unordered_set<Buf*>& bufs);
- void PrintDistribution(const std::unordered_set<Buf*>& bufs);
- void PrintLoop(const std::unordered_set<Buf*>& bufs, bool block_idx = true);
+ void PrintTensorInfo(const std::unordered_set<BufPtr>& bufs);
+ void PrintArguments(const std::unordered_set<BufPtr>& bufs);
+ void PrintBufferInfo(const std::unordered_set<BufPtr>& bufs);
+ void PrintDistribution(const std::unordered_set<BufPtr>& bufs);
+ void PrintLoop(const std::unordered_set<BufPtr>& bufs, bool block_idx = true);
void PrintReshapeInfo(
- const std::unordered_set<Buf*>& bufs,
+ const std::unordered_set<BufPtr>& bufs,
bool reverse = false);
- void PrintDMAs(const std::unordered_set<Buf*>& bufs);
- void PrintAdjustBuffers(const std::unordered_set<Buf*>& bufs);
-
- void visit(For* v) override;
- void visit(Load* v) override;
- void visit(Store* v) override;
- void visit(Block* v) override;
- void visit(Add* v) override;
- void visit(Mul* v) override;
+ void PrintDMAs(const std::unordered_set<BufPtr>& bufs);
+ void PrintAdjustBuffers(const std::unordered_set<BufPtr>& bufs);
+
+ void visit(ForPtr v) override;
+ void visit(LoadPtr v) override;
+ void visit(StorePtr v) override;
+ void visit(BlockPtr v) override;
+ void visit(AddPtr v) override;
+ void visit(MulPtr v) override;
};
class TORCH_API BlockCodeGen : public CodeGen {
public:
template <typename... Ts>
/* implicit */
- BlockCodeGen(Stmt* stmt, Ts... ts)
+ BlockCodeGen(StmtPtr stmt, Ts... ts)
: CodeGen(
stmt,
std::vector<BufferArg>({BufferArg(ts)...}),
}
BlockCodeGen(
- Stmt* stmt,
+ StmtPtr stmt,
const std::vector<BufferArg>& buffer_args,
at::Device device = at::Device(at::kCPU),
const std::string& kernel_func_name = "func")
template <typename Container>
BoundsInfo mergeTensorAccesses(
const Container& accesses,
- const std::unordered_map<Var*, Buf*>& varToBuf,
+ const std::unordered_map<VarPtr, BufPtr>& varToBuf,
bool distinctAccessKinds) {
BoundsInfo ret;
for (auto& access : accesses) {
auto vtbIt = varToBuf.find(access->var());
TORCH_INTERNAL_ASSERT(vtbIt != varToBuf.end());
- Buf* buf = vtbIt->second;
+ BufPtr buf = vtbIt->second;
std::vector<TensorAccessBoundsInfo>& infos = ret[buf];
bool added = false;
TORCH_INTERNAL_ASSERT(TABI.stop.size() == access->bounds().size());
for (size_t i = 0; i < TABI.start.size(); ++i) {
TABI.start[i] = IRSimplifier::simplify(
- new Min(TABI.start[i], access->bounds()[i].start, true));
+ alloc<Min>(TABI.start[i], access->bounds()[i].start, true));
TABI.stop[i] = IRSimplifier::simplify(
- new Max(TABI.stop[i], access->bounds()[i].end, true));
+ alloc<Max>(TABI.stop[i], access->bounds()[i].end, true));
added = true;
if (kind != TABI.kind) {
return ret;
}
-std::unordered_map<Var*, Buf*> getAllBufs(Stmt* s) {
- std::unordered_map<Var*, Buf*> varToBuf;
+std::unordered_map<VarPtr, BufPtr> getAllBufs(StmtPtr s) {
+ std::unordered_map<VarPtr, BufPtr> varToBuf;
auto bufs = NodeFinder<Buf>::find(s);
- for (auto* b : bufs) {
+ for (auto b : bufs) {
varToBuf[b->base_handle()] = b;
}
return varToBuf;
}
-std::unordered_map<Var*, Buf*> getAllBufs(Expr* e) {
- std::unordered_map<Var*, Buf*> varToBuf;
+std::unordered_map<VarPtr, BufPtr> getAllBufs(ExprPtr e) {
+ std::unordered_map<VarPtr, BufPtr> varToBuf;
auto bufs = NodeFinder<Buf>::find(e);
- for (auto* b : bufs) {
+ for (auto b : bufs) {
varToBuf[b->base_handle()] = b;
}
return varToBuf;
}
-BoundsInfo inferBounds(Stmt* s, bool distinctAccessKinds) {
+BoundsInfo inferBounds(StmtPtr s, bool distinctAccessKinds) {
auto varToBuf = getAllBufs(s);
MemDependencyChecker checker;
BoundsInfo getInferredBounds(
MemDependencyChecker& analyzer,
- Stmt* s,
+ StmtPtr s,
bool distinctAccessKinds) {
return mergeTensorAccesses(
analyzer.accessesWithin(s), getAllBufs(s), distinctAccessKinds);
BoundsInfo getInferredBounds(
MemDependencyChecker& analyzer,
- Expr* e,
+ ExprPtr e,
bool distinctAccessKinds) {
return mergeTensorAccesses(
analyzer.accessesWithin(e), getAllBufs(e), distinctAccessKinds);
std::cerr << "}\n";
}
-std::vector<Expr*> getBoundExtents(
+std::vector<ExprPtr> getBoundExtents(
const std::vector<TensorAccessBoundsInfo>& infos) {
- std::vector<Expr*> starts;
- std::vector<Expr*> stops;
+ std::vector<ExprPtr> starts;
+ std::vector<ExprPtr> stops;
  // Find the safe size of the temporary buffer by determining the outer
// extents of a union of all bounds.
starts.push_back(p.start[i]);
} else {
starts[i] =
- IRSimplifier::simplify(new Min(starts[i], p.start[i], true));
+ IRSimplifier::simplify(alloc<Min>(starts[i], p.start[i], true));
}
if (stops.size() <= i) {
stops.push_back(p.stop[i]);
} else {
- stops[i] = IRSimplifier::simplify(new Max(stops[i], p.stop[i], true));
+ stops[i] =
+ IRSimplifier::simplify(alloc<Max>(stops[i], p.stop[i], true));
}
}
}
- std::vector<Expr*> extents;
+ std::vector<ExprPtr> extents;
for (size_t i = 0; i < starts.size(); ++i) {
- Expr* dim = IRSimplifier::simplify(
- new Add(new Sub(stops[i], starts[i]), new IntImm(1)));
+ ExprPtr dim = IRSimplifier::simplify(
+ alloc<Add>(alloc<Sub>(stops[i], starts[i]), alloc<IntImm>(1)));
extents.push_back(dim);
}
BoundSet convertBounds(
BoundsInfo& bounds,
- Buf* buf,
+ BufPtr buf,
TensorAccessKind filter = kMutate) {
auto it = bounds.find(buf);
if (it == bounds.end()) {
HazardKind getPotentialHazards(
MemDependencyChecker& analyzer,
- Stmt* A,
- Stmt* B) {
+ StmtPtr A,
+ StmtPtr B) {
BoundsInfo aBounds = getInferredBounds(analyzer, A, true);
BoundsInfo bBounds = getInferredBounds(analyzer, B, true);
BoundSet aReads;
for (auto& pair : bBounds) {
- Buf* buf = pair.first;
+ BufPtr buf = pair.first;
if (aBounds.find(buf) == aBounds.end()) {
continue;
}
const BoundsInfo& bBounds,
TensorAccessKind aFilter = kMutate,
TensorAccessKind bFilter = kMutate) {
- using IndexBoundsInfo = std::unordered_map<Buf*, std::vector<IndexBounds>>;
+ using IndexBoundsInfo = std::unordered_map<BufPtr, std::vector<IndexBounds>>;
IndexBoundsInfo aIndexBoundsInfo;
for (auto& aBound : aBounds) {
aIndexBoundsInfo[aBound.first] = getIndexBounds(aBound.second, aFilter);
bool hasConflictingOverlap(
analysis::MemDependencyChecker& analyzer,
- Stmt* A,
- Stmt* B) {
+ StmtPtr A,
+ StmtPtr B) {
BoundsInfo aBounds = getInferredBounds(analyzer, A, true);
BoundsInfo bBounds = getInferredBounds(analyzer, B, true);
return hasConflictingOverlap(aBounds, bBounds);
bool isOverlapping(
analysis::MemDependencyChecker& analyzer,
- Store* S1,
- Store* S2) {
+ StorePtr S1,
+ StorePtr S2) {
BoundsInfo s1Bounds = getInferredBounds(analyzer, S1, true);
BoundsInfo s2Bounds = getInferredBounds(analyzer, S2, true);
return hasConflictingOverlap(s1Bounds, s2Bounds, kStore, kStore);
bool isOverlapping(
analysis::MemDependencyChecker& analyzer,
- Store* S,
- Load* L) {
+ StorePtr S,
+ LoadPtr L) {
BoundsInfo sBounds = getInferredBounds(analyzer, S, true);
BoundsInfo lBounds = getInferredBounds(analyzer, L, true);
return hasConflictingOverlap(sBounds, lBounds, kStore, kLoad);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct TORCH_API TensorAccessBoundsInfo {
TensorAccessKind kind;
- std::vector<Expr*> start;
- std::vector<Expr*> stop;
+ std::vector<ExprPtr> start;
+ std::vector<ExprPtr> stop;
};
using BoundsInfo =
- std::unordered_map<Buf*, std::vector<TensorAccessBoundsInfo>>;
+ std::unordered_map<BufPtr, std::vector<TensorAccessBoundsInfo>>;
-TORCH_API BoundsInfo inferBounds(Stmt* s, bool distinctAccessKinds = true);
+TORCH_API BoundsInfo inferBounds(StmtPtr s, bool distinctAccessKinds = true);
// Bounds inference caching the analysis. The MemDependencyChecker must already
// have been run.
TORCH_API BoundsInfo getInferredBounds(
analysis::MemDependencyChecker& analyzer,
- Stmt* s,
+ StmtPtr s,
bool distinctAccessKinds = true);
TORCH_API BoundsInfo getInferredBounds(
analysis::MemDependencyChecker& analyzer,
- Expr* e,
+ ExprPtr e,
bool distinctAccessKinds = true);
TORCH_API void printBoundsInfo(const BoundsInfo& v);
-TORCH_API std::vector<Expr*> getBoundExtents(
+TORCH_API std::vector<ExprPtr> getBoundExtents(
const std::vector<TensorAccessBoundsInfo>& infos);
// The kind of dependency found, in increasing order of exclusivity.
WriteAfterWrite,
NoDependency,
};
-TORCH_API HazardKind
-getPotentialHazards(analysis::MemDependencyChecker& analyzer, Stmt* A, Stmt* B);
+TORCH_API HazardKind getPotentialHazards(
+ analysis::MemDependencyChecker& analyzer,
+ StmtPtr A,
+ StmtPtr B);
// Returns true if there is a conflicting overlap between accesses in
// statements A and B. A conflicting overlap is an overlap in buffer accesses
// where at least one of the accesses is a Store.
TORCH_API bool hasConflictingOverlap(
analysis::MemDependencyChecker& analyzer,
- Stmt* A,
- Stmt* B);
+ StmtPtr A,
+ StmtPtr B);
// Same as above, between accesses in stores S1 and S2.
TORCH_API bool isOverlapping(
analysis::MemDependencyChecker& analyzer,
- Store* S1,
- Store* S2);
+ StorePtr S1,
+ StorePtr S2);
// Same as above, between accesses in store S and load L.
TORCH_API bool isOverlapping(
analysis::MemDependencyChecker& analyzer,
- Store* S,
- Load* L);
+ StorePtr S,
+ LoadPtr L);
} // namespace tensorexpr
} // namespace jit
return ContainedOrEqual;
}
- Expr* lowDiff = IRSimplifier::simplify(new Sub(a.start, b.end));
- Expr* highDiff = IRSimplifier::simplify(new Sub(b.start, a.end));
+ ExprPtr lowDiff = IRSimplifier::simplify(alloc<Sub>(a.start, b.end));
+ ExprPtr highDiff = IRSimplifier::simplify(alloc<Sub>(b.start, a.end));
if (lowDiff->isConstant() && highDiff->isConstant()) {
int low = immediateAs<int>(lowDiff);
}
}
- Expr* diff_start = IRSimplifier::simplify(new Sub(b.start, a.start));
- Expr* diff_end = IRSimplifier::simplify(new Sub(b.end, a.end));
+ ExprPtr diff_start = IRSimplifier::simplify(alloc<Sub>(b.start, a.start));
+ ExprPtr diff_end = IRSimplifier::simplify(alloc<Sub>(b.end, a.end));
// If one side fully encloses the other, they're adjacent.
if (diff_start->isConstant() && diff_end->isConstant()) {
Bound ret = a[0];
for (size_t i = 1; i < a.size(); ++i) {
- ret.start = new Mul(ret.start, a[i].start);
- ret.end = new Mul(ret.end, a[i].end);
+ ret.start = alloc<Mul>(ret.start, a[i].start);
+ ret.end = alloc<Mul>(ret.end, a[i].end);
}
ret.start = IRSimplifier::simplify(ret.start);
return {a};
}
- Expr* lowDiff = IRSimplifier::simplify(new Sub(b.start, a.start));
- Expr* highDiff = IRSimplifier::simplify(new Sub(b.end, a.end));
+ ExprPtr lowDiff = IRSimplifier::simplify(alloc<Sub>(b.start, a.start));
+ ExprPtr highDiff = IRSimplifier::simplify(alloc<Sub>(b.end, a.end));
// If the diff has only a single var, we can try to guess sign.
if (!lowDiff->isConstant()) {
auto vars = VarFinder::find(lowDiff);
if (vars.size() == 1) {
- lowDiff = IRSimplifier::simplify(new Sub(
- SubstituteInClone(b.start, {{*vars.begin(), new IntImm(1)}}),
- SubstituteInClone(a.start, {{*vars.begin(), new IntImm(1)}})));
+ lowDiff = IRSimplifier::simplify(alloc<Sub>(
+ SubstituteInClone(b.start, {{*vars.begin(), alloc<IntImm>(1)}}),
+ SubstituteInClone(a.start, {{*vars.begin(), alloc<IntImm>(1)}})));
}
}
if (!highDiff->isConstant()) {
auto vars = VarFinder::find(highDiff);
if (vars.size() == 1) {
- highDiff = IRSimplifier::simplify(new Sub(
- SubstituteInClone(b.end, {{*vars.begin(), new IntImm(1)}}),
- SubstituteInClone(a.end, {{*vars.begin(), new IntImm(1)}})));
+ highDiff = IRSimplifier::simplify(alloc<Sub>(
+ SubstituteInClone(b.end, {{*vars.begin(), alloc<IntImm>(1)}}),
+ SubstituteInClone(a.end, {{*vars.begin(), alloc<IntImm>(1)}})));
}
}
if (hasHead) {
res.emplace_back(
- a.start, IRSimplifier::simplify(new Sub(b.start, new IntImm(1))));
+ a.start, IRSimplifier::simplify(alloc<Sub>(b.start, alloc<IntImm>(1))));
}
if (hasTail) {
- Expr* tailStart = IRSimplifier::simplify(new Add(b.end, new IntImm(1)));
+ ExprPtr tailStart =
+ IRSimplifier::simplify(alloc<Add>(b.end, alloc<IntImm>(1)));
res.emplace_back(tailStart, a.end);
}
// A simple class containing the start and end of a range in a single dimension.
struct TORCH_API Bound {
- Expr* start{nullptr};
- Expr* end{nullptr};
+ ExprPtr start{nullptr};
+ ExprPtr end{nullptr};
// This stores whether or not the start and end of this Bound have previously
// been swapped. This occurs when the bound is in a loop with a negative
bool swapped{false};
Bound() = default;
- Bound(Expr* s, Expr* e) : start(s), end(e) {}
+ Bound(ExprPtr s, ExprPtr e) : start(s), end(e) {}
void print() const {
std::cout << "(" << *start << ", " << *end << ")";
struct BoundHash {
size_t operator()(const Bound& b) const {
- return std::hash<Expr*>()(b.start) ^ std::hash<Expr*>()(b.end);
+ return std::hash<ExprPtr>()(b.start) ^ std::hash<ExprPtr>()(b.end);
}
};
std::unique_ptr<CodeGen> CreateCodeGen(
const std::string& name,
- Stmt* stmt,
+ StmtPtr stmt,
const std::vector<CodeGen::BufferArg>& params,
at::Device device,
const std::string& kernel_func_name) {
return method(stmt, params, device, kernel_func_name);
}
-Expr* GenericIntrinsicsExpander::mutate(Intrinsics* v) {
+ExprPtr GenericIntrinsicsExpander::mutate(IntrinsicsPtr v) {
if (v->op_type() == kSigmoid) {
auto x = v->param(0)->accept_mutator(this);
auto one = expr_to_vec(
template <typename... Ts>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- CodeGen(Stmt* stmt, Ts... ts)
+ CodeGen(StmtPtr stmt, Ts... ts)
: stmt_(stmt), buffer_args_({BufferArg(ts)...}) {}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
CodeGen(
- Stmt* stmt,
+ StmtPtr stmt,
std::vector<BufferArg> buffer_args,
at::Device device = at::kCPU,
std::string kernel_func_name = "func")
virtual ~CodeGen() = default;
- Stmt* stmt() const {
+ StmtPtr stmt() const {
return stmt_;
}
- void set_stmt(Stmt* s) {
+ void set_stmt(StmtPtr s) {
stmt_ = s;
}
static void* argToPtr(const BufferArg& bufferArg, const CallArg& callArg);
private:
- Stmt* stmt_;
+ StmtPtr stmt_;
std::vector<BufferArg> buffer_args_;
at::Device device_ = at::kCPU;
std::string kernel_func_name_ = "func";
BufferArg(const VarHandle& var) : var_(var.node()), isVar_(true) {}
BufferArg(const BufHandle& buf) : buf_(buf.node()) {}
- Var* var() const {
+ VarPtr var() const {
return isVar_ ? var_ : buf_->base_handle();
}
- Buf* buf() const {
+ BufPtr buf() const {
return buf_;
}
}
private:
- Var* var_ = nullptr;
- Buf* buf_ = nullptr;
+ VarPtr var_ = nullptr;
+ BufPtr buf_ = nullptr;
bool isVar_ = false;
};
}
using StmtFactoryMethod = std::function<std::unique_ptr<CodeGen>(
- Stmt* stmt,
+ StmtPtr stmt,
const std::vector<CodeGen::BufferArg>&,
at::Device device,
const std::string& kernel_func_name)>;
RegisterCodeGenList& codegen_list = RegisterCodeGenList::GetInstance();
codegen_list.AddStmtFactoryMethod(
name,
- [](Stmt* stmt,
+ [](StmtPtr stmt,
const std::vector<CodeGen::BufferArg>& params,
at::Device device,
const std::string& kernel_func_name) {
TORCH_API std::unique_ptr<CodeGen> CreateCodeGen(
const std::string& name,
- Stmt* stmt,
+ StmtPtr stmt,
const std::vector<CodeGen::BufferArg>& params,
at::Device device = at::kCPU,
const std::string& kernel_func_name = "func");
class TORCH_API GenericIntrinsicsExpander : public IRMutator {
protected:
- Expr* mutate(Intrinsics* v) override;
+ ExprPtr mutate(IntrinsicsPtr v) override;
};
} // namespace tensorexpr
namespace jit {
namespace tensorexpr {
-void CppPrinter::visit(Allocate* alloc) {
+void CppPrinter::visit(AllocatePtr alloc) {
constexpr size_t kAllocOnStackThresholdSize = 512;
size_t size = 1;
for (auto dim : alloc->dims()) {
- IntImm* v = dynamic_cast<IntImm*>(dim);
+ IntImmPtr v = to<IntImm>(dim);
if (v) {
size *= v->value();
} else {
}
}
-void CppPrinter::visit(Free* free) {
- Var* var = free->buffer_var();
+void CppPrinter::visit(FreePtr free) {
+ VarPtr var = free->buffer_var();
if (allocated_on_heap_.count(var)) {
emitIndent();
os() << "free(" << name_manager()->get_unique_name(var) << ");"
explicit CppPrinter(std::ostream* os) : IRPrinter(*os) {}
using IRPrinter::visit;
- void visit(Allocate*) override;
- void visit(Free*) override;
+ void visit(AllocatePtr) override;
+ void visit(FreePtr) override;
private:
- std::unordered_set<Var*> allocated_on_heap_;
+ std::unordered_set<VarPtr> allocated_on_heap_;
};
} // namespace tensorexpr
// TODO: move this to a more shared place.
class ScopedVarName {
public:
- ScopedVarName(VarNameMap* mapping, Var* var, const std::string& name)
+ ScopedVarName(VarNameMap* mapping, VarPtr var, const std::string& name)
: mapping_(mapping), var_(var) {
auto iter = mapping->find(var);
if (iter != mapping->end()) {
mapping->insert(std::make_pair(var, name));
}
- ScopedVarName(UniqueNameManager* manager, Var* var, const std::string& name)
+ ScopedVarName(UniqueNameManager* manager, VarPtr var, const std::string& name)
: ScopedVarName(&manager->unique_name_mapping_, var, name) {}
ScopedVarName(const ScopedVarName&) = delete;
private:
VarNameMap* mapping_ = nullptr;
- Var* var_ = nullptr;
+ VarPtr var_ = nullptr;
};
-static int as_int(Expr* expr) {
- auto v = dynamic_cast<IntImm*>(expr);
+static int as_int(ExprPtr expr) {
+ auto v = to<IntImm>(expr);
if (!v) {
throw malformed_input(
"cuda_codegen: non Int expr interpreted as int", expr);
return v->value();
}
-static bool is_zero(Expr* expr) {
+static bool is_zero(ExprPtr expr) {
return as_int(expr) == 0;
}
}
}
-void CudaAnalysis::visit(Free* v) {
+void CudaAnalysis::visit(FreePtr v) {
if (thread_local_bufs_.count(v->buffer_var()) == 0 &&
cross_block_bufs_.count(v->buffer_var()) == 0) {
throw std::runtime_error("Global free not supported yet");
}
}
-void CudaAnalysis::visit(Allocate* v) {
- Stmt* p = v->get_parent();
+void CudaAnalysis::visit(AllocatePtr v) {
+ StmtPtr p = v->get_parent();
while (p) {
- For* for_v = dynamic_cast<For*>(p);
+ ForPtr for_v = to<For>(p);
if (for_v) {
// NOLINTNEXTLINE(bugprone-branch-clone)
if (for_v->loop_options().is_gpu_block_index()) {
throw std::runtime_error("Global alloc not supported yet");
}
-void CudaAnalysis::visit(For* v) {
+void CudaAnalysis::visit(ForPtr v) {
// Recurse first.
v->body()->accept(this);
if (gpu_block_index >= 3) {
throw std::runtime_error("support only 3D gpu_block_index");
}
- Expr* prev = nullptr;
+ ExprPtr prev = nullptr;
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
// NOLINTNEXTLINE(bugprone-branch-clone)
if (gpu_block_extents_.size() <= gpu_block_index) {
gpu_block_extents_[gpu_block_index] = v->stop();
} else {
gpu_block_extents_[gpu_block_index] =
- IRSimplifier::simplify(new Max(prev, v->stop(), true));
+ IRSimplifier::simplify(alloc<Max>(prev, v->stop(), true));
}
} else if (loop_options.is_gpu_thread_index()) {
int gpu_thread_index = loop_options.gpu_thread_index();
if (gpu_thread_index >= 3) {
throw std::runtime_error("support only 3D gpu_thread_index");
}
- Expr* prev = nullptr;
+ ExprPtr prev = nullptr;
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
// NOLINTNEXTLINE(bugprone-branch-clone)
if (gpu_thread_extents_.size() <= gpu_thread_index) {
gpu_thread_extents_[gpu_thread_index] = v->stop();
} else {
gpu_thread_extents_[gpu_thread_index] =
- IRSimplifier::simplify(new Max(prev, v->stop(), true));
+ IRSimplifier::simplify(alloc<Max>(prev, v->stop(), true));
}
}
}
-void CudaPrinter::print_flat_alloc(Allocate* alloc) {
+void CudaPrinter::print_flat_alloc(AllocatePtr alloc) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::vector<Expr*> dims = alloc->dims();
+ std::vector<ExprPtr> dims = alloc->dims();
// TODO: this should be merged with the storage flattener.
int64_t flat_size = 1;
for (auto dim : dims) {
- IntImm* dim_i = dynamic_cast<IntImm*>(dim);
+ IntImmPtr dim_i = to<IntImm>(dim);
if (dim_i) {
flat_size *= dim_i->value();
} else {
<< "[" << flat_size << "];" << std::endl;
}
-void CudaPrinter::visit(Allocate* v) {
+void CudaPrinter::visit(AllocatePtr v) {
// TODO: handle dynamic shapes here.
if (cuda_analysis_->cross_block_bufs().count(v->buffer_var()) != 0) {
emitIndent();
throw std::runtime_error("Encountered Alloc not local to block or thread");
}
-void CudaPrinter::visit(Free* v) {
+void CudaPrinter::visit(FreePtr v) {
// do nothing
}
-void CudaPrinter::visit(For* v) {
+void CudaPrinter::visit(ForPtr v) {
IRPrinter::visit(v);
}
-void CudaPrinter::visit(Cast* v) {
+void CudaPrinter::visit(CastPtr v) {
if (v->dtype().scalar_type() == ScalarType::Half) {
os() << "__float2half(";
v->src_value()->accept(this);
os() << ")";
}
-void CudaPrinter::visit(Intrinsics* v) {
+void CudaPrinter::visit(IntrinsicsPtr v) {
if (v->op_type() == IntrinsicsOp::kRand) {
os() << "Uint32ToFloat(" << *rand_func_ << "())";
return;
os() << ")";
}
-void CudaPrinter::visit(ExternalCall* v) {
+void CudaPrinter::visit(ExternalCallPtr v) {
throw unimplemented_lowering(v);
}
-void CudaPrinter::visit(Load* v) {
+void CudaPrinter::visit(LoadPtr v) {
// TODO: find a better metric in using ldg or not. Support different dtypes.
// Detects whether the load target is also a store target.
// TODO: this is currently too wide. It detects whether a store-target
}
// TODO: maybe this should be a more shared location?
-// TODO: investigate how "Expr*" can be implicitly converted to "ExprHandle" as
-// a bool.
-static bool CheckEqual(Expr* lhs, Expr* rhs) {
+// TODO: investigate how "ExprPtr" can be implicitly converted to "ExprHandle"
+// as a bool.
+static bool CheckEqual(ExprPtr lhs, ExprPtr rhs) {
// The fast path. Checks if the pointers are the same.
if (lhs == rhs) {
return true;
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
AtomicAddFuser(
- const std::unordered_set<Var*>& thread_local_bufs,
+ const std::unordered_set<VarPtr>& thread_local_bufs,
const GPUMetaVarRewriter& metavars)
: thread_local_bufs_(thread_local_bufs) {
- const std::vector<Expr*>& block_extents = metavars.gpu_block_extents();
- const std::vector<Var*>& block_vars = metavars.gpu_block_vars();
+ const std::vector<ExprPtr>& block_extents = metavars.gpu_block_extents();
+ const std::vector<VarPtr>& block_vars = metavars.gpu_block_vars();
for (size_t i = 0; i < block_extents.size(); ++i) {
MetaVarExtent extent{block_extents[i], false};
if (extent.expr->isConstant() && immediateEquals(extent.expr, 1)) {
metavars_[block_vars[i]] = extent;
}
- const std::vector<Expr*>& thread_extents = metavars.gpu_thread_extents();
- const std::vector<Var*>& thread_vars = metavars.gpu_thread_vars();
+ const std::vector<ExprPtr>& thread_extents = metavars.gpu_thread_extents();
+ const std::vector<VarPtr>& thread_vars = metavars.gpu_thread_vars();
for (size_t i = 0; i < thread_extents.size(); ++i) {
MetaVarExtent extent{thread_extents[i], false};
if (extent.expr->isConstant() && immediateEquals(extent.expr, 1)) {
}
}
- Stmt* mutate(Store* v) override {
- Buf* buf = v->buf();
- Store* orig = const_cast<Store*>(v); // NOLINT
+ StmtPtr mutate(StorePtr v) override {
+ BufPtr buf = v->buf();
+ StorePtr orig = const_cast<StorePtr>(v); // NOLINT
// Thread locals never need to be atomic.
if (thread_local_bufs_.count(buf->base_handle()) != 0) {
if (dtype != ScalarType::Float && dtype != ScalarType::Double) {
return orig;
}
- Add* add_v = dynamic_cast<Add*>(v->value());
+ AddPtr add_v = to<Add>(v->value());
if (!add_v) {
return orig;
}
- Load* load_v = dynamic_cast<Load*>(add_v->lhs());
+ LoadPtr load_v = to<Load>(add_v->lhs());
if (!load_v) {
return orig;
}
// TODO: this checks that the metavars occur directly as an index, but this
// is pessimistic, blockIdx.x + 1 is fine too if there is no overlapping.
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::unordered_set<Var*> vars_to_find = nontrivial_metavars_;
- for (Expr* e : v->indices()) {
- if (Var* v = dynamic_cast<Var*>(e)) {
+ std::unordered_set<VarPtr> vars_to_find = nontrivial_metavars_;
+ for (ExprPtr e : v->indices()) {
+ if (VarPtr v = to<Var>(e)) {
vars_to_find.erase(v);
}
}
return orig;
}
- return new AtomicAdd(buf, v->indices(), add_v->rhs());
+ return alloc<AtomicAdd>(buf, v->indices(), add_v->rhs());
}
private:
- const std::unordered_set<Var*>& thread_local_bufs_;
+ const std::unordered_set<VarPtr>& thread_local_bufs_;
struct MetaVarExtent {
- Expr* expr{nullptr};
+ ExprPtr expr{nullptr};
bool trivial{false};
};
- std::unordered_map<Var*, MetaVarExtent> metavars_;
- std::unordered_set<Var*> nontrivial_metavars_;
+ std::unordered_map<VarPtr, MetaVarExtent> metavars_;
+ std::unordered_set<VarPtr> nontrivial_metavars_;
};
-void CudaPrinter::visit(Store* v) {
+void CudaPrinter::visit(StorePtr v) {
emitIndent();
if (v->indices().empty()) {
os() << *v->base_handle() << " = ";
os() << std::endl;
}
-void CudaPrinter::visit(AtomicAdd* v) {
+void CudaPrinter::visit(AtomicAddPtr v) {
emitIndent();
if (cuda_analysis_->thread_local_bufs().count(v->base_handle()) > 0) {
// atomicAdd only works on global and shared memory
os() << std::endl;
}
-void CudaPrinter::visit(Max* v) {
+void CudaPrinter::visit(MaxPtr v) {
if (v->dtype().is_integral()) {
os() << "max(";
} else {
os() << ")";
}
-void CudaPrinter::visit(Min* v) {
+void CudaPrinter::visit(MinPtr v) {
if (v->dtype().is_integral()) {
os() << "min(";
} else {
os() << ")";
}
-void CudaPrinter::visit(IfThenElse* v) {
+void CudaPrinter::visit(IfThenElsePtr v) {
os() << "((";
v->condition()->accept(this);
os() << ") ? ";
os() << ")";
}
-void CudaPrinter::visit(Block* v) {
+void CudaPrinter::visit(BlockPtr v) {
os() << "{" << std::endl;
indent_++;
- for (Stmt* s : v->stmts()) {
+ for (StmtPtr s : v->stmts()) {
s->accept(this);
}
os() << "}";
}
-void CudaPrinter::visit(Let* v) {
+void CudaPrinter::visit(LetPtr v) {
emitIndent();
os() << dtypeToCppString(v->dtype());
os() << " " << *v->var() << " = ";
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class PrioritizeLoad : public IRMutator {
public:
- Expr* mutate(Load* v) override {
+ ExprPtr mutate(LoadPtr v) override {
// Look at the declaration of this variable for more details.
if (nested_if_then_else_ > 0) {
return IRMutator::mutate(v);
}
MemLoadList& load_list = load_stack_.back();
- Var* load_new_var = new Var("v", v->dtype());
- Expr* new_value = IRMutator::mutate(v);
+ VarPtr load_new_var = alloc<Var>("v", v->dtype());
+ ExprPtr new_value = IRMutator::mutate(v);
load_list.push_back(std::make_pair(load_new_var, new_value));
return load_new_var;
}
- Expr* mutate(Cast* v) override {
- Load* src_load = dynamic_cast<Load*>(v->src_value());
- Expr* new_src = v->src_value()->accept_mutator(this);
- Var* new_var = dynamic_cast<Var*>(new_src);
+ ExprPtr mutate(CastPtr v) override {
+ LoadPtr src_load = to<Load>(v->src_value());
+ ExprPtr new_src = v->src_value()->accept_mutator(this);
+ VarPtr new_var = to<Var>(new_src);
if (!src_load || !new_var) {
- return new Cast(v->dtype(), new_src);
+ return alloc<Cast>(v->dtype(), new_src);
}
// We just did the prioritize load, let's fold in the Cast.
assert(pair.first == new_var);
load_list.pop_back();
- new_var = new Var("v", v->dtype());
+ new_var = alloc<Var>("v", v->dtype());
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- Expr* new_value = new Cast(v->dtype(), pair.second);
+ ExprPtr new_value = alloc<Cast>(v->dtype(), pair.second);
load_list.push_back(std::make_pair(new_var, new_value));
return new_var;
}
- Stmt* mutate(Store* v) override {
- Store* last = nested_store_;
+ StmtPtr mutate(StorePtr v) override {
+ StorePtr last = nested_store_;
nested_store_ = v;
- Stmt* s = IRMutator::mutate(v);
+ StmtPtr s = IRMutator::mutate(v);
nested_store_ = last;
return s;
}
- Stmt* mutate(Let* v) override {
+ StmtPtr mutate(LetPtr v) override {
nested_let_ = true;
- Stmt* s = IRMutator::mutate(v);
+ StmtPtr s = IRMutator::mutate(v);
nested_let_ = false;
return s;
}
- Stmt* mutate(Block* v) override {
- Block* v1 = const_cast<Block*>(v); // NOLINT
+ StmtPtr mutate(BlockPtr v) override {
+ BlockPtr v1 = const_cast<BlockPtr>(v); // NOLINT
assert(v1);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::list<Stmt*> stmts = v1->stmts();
- for (Stmt* stmt : stmts) {
+ std::list<StmtPtr> stmts = v1->stmts();
+ for (StmtPtr stmt : stmts) {
PushList();
- Stmt* stmt_new = stmt->accept_mutator(this);
+ StmtPtr stmt_new = stmt->accept_mutator(this);
AddMemLoadsFromList(v1, stmt);
PopList();
return v1;
}
- Expr* mutate(IfThenElse* v) override {
+ ExprPtr mutate(IfThenElsePtr v) override {
nested_if_then_else_++;
- Expr* new_v = IRMutator::mutate(v);
+ ExprPtr new_v = IRMutator::mutate(v);
nested_if_then_else_--;
return new_v;
}
private:
- using MemLoadEntry = std::pair<Var*, Expr*>;
+ using MemLoadEntry = std::pair<VarPtr, ExprPtr>;
using MemLoadList = std::vector<MemLoadEntry>;
using MemoryLoadStack = std::vector<MemLoadList>;
load_stack_.pop_back();
}
- void AddMemLoadsFromList(Block* block, Stmt* last) {
+ void AddMemLoadsFromList(BlockPtr block, StmtPtr last) {
MemLoadList& load_list = load_stack_.back();
if (load_list.empty()) {
return;
}
for (auto& pair : load_list) {
- Stmt* news = new Let(pair.first, pair.second);
+ StmtPtr news = alloc<Let>(pair.first, pair.second);
block->insert_stmt_before(news, last);
}
}
// }
// int v2 = v + 2;
int nested_if_then_else_{0};
- Store* nested_store_{nullptr};
+ StorePtr nested_store_{nullptr};
bool nested_let_{false};
- std::unordered_set<Var*> thread_local_bufs_;
+ std::unordered_set<VarPtr> thread_local_bufs_;
};
std::string CudaCodeGen::GetUniqueFuncName(const std::string& func_prefix) {
return true;
}
-Stmt* GPUMetaVarRewriter::mutate(For* v) {
- Stmt* body = v->body();
- Expr* old_reach = nullptr;
+StmtPtr GPUMetaVarRewriter::mutate(ForPtr v) {
+ StmtPtr body = v->body();
+ ExprPtr old_reach = nullptr;
const LoopOptions& loop_options = v->loop_options();
if (loop_options.is_gpu_block_index()) {
int gpu_block_index = loop_options.gpu_block_index();
current_block_reach_[gpu_block_index] = v->stop();
} else {
current_block_reach_[gpu_block_index] =
- IRSimplifier::simplify(new Max(old_reach, v->stop(), true));
+ IRSimplifier::simplify(alloc<Max>(old_reach, v->stop(), true));
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- Var* metaVar = gpu_block_vars_[gpu_block_index];
+ VarPtr metaVar = gpu_block_vars_[gpu_block_index];
body = Substitute(Stmt::clone(body), {{v->var(), metaVar}});
} else if (loop_options.is_gpu_thread_index()) {
int gpu_thread_index = loop_options.gpu_thread_index();
current_thread_reach_[gpu_thread_index] = v->stop();
} else {
current_thread_reach_[gpu_thread_index] =
- IRSimplifier::simplify(new Max(old_reach, v->stop(), true));
+ IRSimplifier::simplify(alloc<Max>(old_reach, v->stop(), true));
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- Var* metaVar = gpu_thread_vars_[gpu_thread_index];
+ VarPtr metaVar = gpu_thread_vars_[gpu_thread_index];
body = Substitute(Stmt::clone(body), {{v->var(), metaVar}});
}
return v->cloneWithNewBody(body);
}
-Stmt* GPUMetaVarRewriter::mutate(Block* v) {
+StmtPtr GPUMetaVarRewriter::mutate(BlockPtr v) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<Segment> innerSegments;
Segment current;
// the same launch reach. Segments are comprised of all statements that aren't
// loops - which are their own segments. Some operations, such as threading
// and memory ops should never be masked and so also get their own segment.
- for (Stmt* stmt : *v) {
- Stmt* stmt_new = stmt->accept_mutator(this);
+ for (StmtPtr stmt : *v) {
+ StmtPtr stmt_new = stmt->accept_mutator(this);
if (stmt == stmt_new) {
stmt_new = Stmt::clone(stmt_new);
}
// Likewise, Allocate and Free should never be masked.
- if (dynamic_cast<Allocate*>(stmt) || dynamic_cast<Free*>(stmt)) {
+ if (to<Allocate>(stmt) || to<Free>(stmt)) {
pushAndReset(false);
}
// If the current stmt *was* a loop, it's a segment boundary.
- if (For* f = dynamic_cast<For*>(stmt)) {
+ if (ForPtr f = to<For>(stmt)) {
pushAndReset(false);
}
if (isFullExtent()) {
// flatten inner segments.
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::vector<Stmt*> stmts;
+ std::vector<StmtPtr> stmts;
for (auto& v : innerSegments) {
for (auto* s : v.stmts()) {
stmts.push_back(s);
}
}
- return new Block(stmts);
+ return alloc<Block>(stmts);
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::vector<Stmt*> stmts;
+ std::vector<StmtPtr> stmts;
for (auto& segment : innerSegments) {
bool need_sync = false;
// We never mask loops, they'll mask their contents.
// If we get here, we must mask since we're not full reach and our direct
// child isn't a For.
- Stmt* inner = new Block(segment.stmts());
+ StmtPtr inner = alloc<Block>(segment.stmts());
// threads inside blocks.
auto& thread_extents = cuda_analysis_->gpu_thread_extents();
for (size_t i = 0; i < gpu_thread_vars_.size(); ++i) {
if (!exprEquals(current_thread_reach_[i], thread_extents[i])) {
need_sync = true;
// Mask it against the current dimensions.
- inner = new Cond(
- new CompareSelect(
+ inner = alloc<Cond>(
+ alloc<CompareSelect>(
gpu_thread_vars_[i],
current_thread_reach_[i],
CompareSelectOperation::kLT),
for (size_t i = 0; i < gpu_block_vars_.size(); ++i) {
if (!exprEquals(current_block_reach_[i], block_extents[i])) {
// Mask it against the current dimensions.
- inner = new Cond(
- new CompareSelect(
+ inner = alloc<Cond>(
+ alloc<CompareSelect>(
gpu_block_vars_[i],
current_block_reach_[i],
CompareSelectOperation::kLT),
}
if (need_sync) {
- stmts.push_back(new SyncThreads());
+ stmts.push_back(alloc<SyncThreads>());
}
stmts.push_back(inner);
if (need_sync) {
- stmts.push_back(new SyncThreads());
+ stmts.push_back(alloc<SyncThreads>());
}
}
- return new Block(stmts);
+ return alloc<Block>(stmts);
}
static std::ostream& operator<<(
std::ostream& out,
- const std::vector<Expr*>& exprs) {
+ const std::vector<ExprPtr>& exprs) {
size_t i = 0;
for (auto expr : exprs) {
if (i++ > 0) {
// Check whether the statement uses the Half type, if so add the
// half_support_literal.
- Stmt* stmt_v = stmt();
+ StmtPtr stmt_v = stmt();
HalfChecker halfChecker(buffer_args());
stmt_v->accept(&halfChecker);
os() << ", ";
}
const BufferArg& buffer_arg = buffer_args[i];
- Var* var = buffer_arg.var();
+ VarPtr var = buffer_arg.var();
Dtype dtype = buffer_arg.dtype();
os() << printer_->dtypeToCppString(dtype)
<< name_manager()->get_unique_name(var);
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- Var* rand_seed;
+ VarPtr rand_seed;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- Var* rand_offset;
+ VarPtr rand_offset;
if (has_random_) {
// TODO: switch to kUint64 when it is available.
- rand_seed = new Var("rand_seed", kInt);
- rand_offset = new Var("rand_offset", kInt);
+ rand_seed = alloc<Var>("rand_seed", kInt);
+ rand_offset = alloc<Var>("rand_offset", kInt);
std::string uint64_str = "unsigned long long";
os() << ", " << uint64_str << " " << *rand_seed << ", " << uint64_str << " "
<< *rand_offset;
os() << std::endl;
if (has_random_) {
- Var* idx = new Var("idx", kInt);
+ VarPtr idx = alloc<Var>("idx", kInt);
os() << "int " << *idx << " = blockIdx.x*blockDim.x + threadIdx.x;"
<< std::endl;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- Var* rand_func = printer_->rand_func();
+ VarPtr rand_func = printer_->rand_func();
os() << "Philox " << *rand_func << "(" << *rand_seed << ", " << *idx << ", "
<< *rand_offset << ");" << std::endl;
os() << std::endl;
os() << "}";
// Check that all block extents had been set.
- const std::vector<Expr*>& gpu_block_extents =
+ const std::vector<ExprPtr>& gpu_block_extents =
metavar_rewriter_->gpu_block_extents();
for (size_t i = 0; i < gpu_block_extents.size(); i++) {
if (!gpu_block_extents[i]) {
auto const& buffer_args = this->buffer_args();
// TODO: move as much of this into the constructors.
- const std::vector<Expr*>& gpu_block_extents =
+ const std::vector<ExprPtr>& gpu_block_extents =
metavar_rewriter_->gpu_block_extents();
- const std::vector<Expr*>& gpu_thread_extents =
+ const std::vector<ExprPtr>& gpu_thread_extents =
metavar_rewriter_->gpu_thread_extents();
if (gpu_block_extents.size() > 3 || gpu_thread_extents.size() > 3) {
throw malformed_input(
class CudaAnalysis : public IRVisitor {
public:
CudaAnalysis() {
- gpu_block_extents_ = {new IntImm(1), new IntImm(1), new IntImm(1)};
- gpu_thread_extents_ = {new IntImm(1), new IntImm(1), new IntImm(1)};
+ gpu_block_extents_ = {alloc<IntImm>(1), alloc<IntImm>(1), alloc<IntImm>(1)};
+ gpu_thread_extents_ = {
+ alloc<IntImm>(1), alloc<IntImm>(1), alloc<IntImm>(1)};
}
- bool is_buf_store_target(Buf* buf) const {
+ bool is_buf_store_target(BufPtr buf) const {
return store_targets_.count(buf) > 0;
}
- const std::unordered_set<Var*>& thread_local_bufs() const {
+ const std::unordered_set<VarPtr>& thread_local_bufs() const {
return thread_local_bufs_;
}
- const std::unordered_set<Var*>& cross_block_bufs() const {
+ const std::unordered_set<VarPtr>& cross_block_bufs() const {
return cross_block_bufs_;
}
- const std::vector<Expr*>& gpu_block_extents() const {
+ const std::vector<ExprPtr>& gpu_block_extents() const {
return gpu_block_extents_;
}
- const std::vector<Expr*>& gpu_thread_extents() const {
+ const std::vector<ExprPtr>& gpu_thread_extents() const {
return gpu_thread_extents_;
}
private:
- void visit(Store* v) override {
+ void visit(StorePtr v) override {
store_targets_.insert(v->buf());
}
- void visit(Allocate* v) override;
- void visit(Free* v) override;
- void visit(For* v) override;
+ void visit(AllocatePtr v) override;
+ void visit(FreePtr v) override;
+ void visit(ForPtr v) override;
- std::unordered_set<Buf*> store_targets_;
- std::unordered_set<Var*> thread_local_bufs_;
- std::unordered_set<Var*> cross_block_bufs_;
+ std::unordered_set<BufPtr> store_targets_;
+ std::unordered_set<VarPtr> thread_local_bufs_;
+ std::unordered_set<VarPtr> cross_block_bufs_;
- std::vector<Expr*> gpu_block_extents_;
- std::vector<Expr*> gpu_thread_extents_;
+ std::vector<ExprPtr> gpu_block_extents_;
+ std::vector<ExprPtr> gpu_thread_extents_;
};
// An IRMutator that replaces binding loop options with Cuda metavars, and masks
explicit GPUMetaVarRewriter(const CudaAnalysis* cuda_analysis)
: cuda_analysis_(cuda_analysis) {
gpu_block_vars_ = {
- new Var("blockIdx.x", kInt),
- new Var("blockIdx.y", kInt),
- new Var("blockIdx.z", kInt)};
+ alloc<Var>("blockIdx.x", kInt),
+ alloc<Var>("blockIdx.y", kInt),
+ alloc<Var>("blockIdx.z", kInt)};
gpu_thread_vars_ = {
- new Var("threadIdx.x", kInt),
- new Var("threadIdx.y", kInt),
- new Var("threadIdx.z", kInt)};
-
- current_block_reach_ = {new IntImm(1), new IntImm(1), new IntImm(1)};
- current_thread_reach_ = {new IntImm(1), new IntImm(1), new IntImm(1)};
+ alloc<Var>("threadIdx.x", kInt),
+ alloc<Var>("threadIdx.y", kInt),
+ alloc<Var>("threadIdx.z", kInt)};
+
+ current_block_reach_ = {
+ alloc<IntImm>(1), alloc<IntImm>(1), alloc<IntImm>(1)};
+ current_thread_reach_ = {
+ alloc<IntImm>(1), alloc<IntImm>(1), alloc<IntImm>(1)};
}
- Stmt* mutate(For* v) override;
- Stmt* mutate(Block* v) override;
+ StmtPtr mutate(ForPtr v) override;
+ StmtPtr mutate(BlockPtr v) override;
- const std::vector<Var*>& gpu_block_vars() const {
+ const std::vector<VarPtr>& gpu_block_vars() const {
return gpu_block_vars_;
}
- const std::vector<Var*>& gpu_thread_vars() const {
+ const std::vector<VarPtr>& gpu_thread_vars() const {
return gpu_thread_vars_;
}
- const std::vector<Expr*>& gpu_block_extents() const {
+ const std::vector<ExprPtr>& gpu_block_extents() const {
return cuda_analysis_->gpu_block_extents();
}
- const std::vector<Expr*>& gpu_thread_extents() const {
+ const std::vector<ExprPtr>& gpu_thread_extents() const {
return cuda_analysis_->gpu_thread_extents();
}
return stmts_.empty();
}
- std::vector<Stmt*>& stmts() {
+ std::vector<StmtPtr>& stmts() {
return stmts_;
}
bool mask() {
}
private:
- std::vector<Stmt*> stmts_;
+ std::vector<StmtPtr> stmts_;
bool mask_{true};
};
// parameters.
bool isFullExtent();
- std::vector<Var*> gpu_block_vars_;
- std::vector<Var*> gpu_thread_vars_;
+ std::vector<VarPtr> gpu_block_vars_;
+ std::vector<VarPtr> gpu_thread_vars_;
- std::vector<Expr*> current_block_reach_;
- std::vector<Expr*> current_thread_reach_;
+ std::vector<ExprPtr> current_block_reach_;
+ std::vector<ExprPtr> current_thread_reach_;
const CudaAnalysis* cuda_analysis_;
};
bool has_random)
: IRPrinter(*os), cuda_analysis_(cuda_analysis) {
if (has_random) {
- rand_func_ = new Var("rand", kHandle);
+ rand_func_ = alloc<Var>("rand", kHandle);
}
}
- void visit(Cast* v) override;
- void visit(Intrinsics* v) override;
- void visit(For* v) override;
+ void visit(CastPtr v) override;
+ void visit(IntrinsicsPtr v) override;
+ void visit(ForPtr v) override;
- void visit(Load* v) override;
- void visit(Store* v) override;
- void visit(AtomicAdd* v) override;
- void visit(Max* v) override;
- void visit(Min* v) override;
- void visit(IfThenElse* v) override;
- void visit(Block* v) override;
- void visit(Allocate* v) override;
- void visit(Free* v) override;
- void visit(Let* v) override;
+ void visit(LoadPtr v) override;
+ void visit(StorePtr v) override;
+ void visit(AtomicAddPtr v) override;
+ void visit(MaxPtr v) override;
+ void visit(MinPtr v) override;
+ void visit(IfThenElsePtr v) override;
+ void visit(BlockPtr v) override;
+ void visit(AllocatePtr v) override;
+ void visit(FreePtr v) override;
+ void visit(LetPtr v) override;
- void visit(ExternalCall* v) override;
+ void visit(ExternalCallPtr v) override;
- Var* rand_func() const {
+ VarPtr rand_func() const {
return rand_func_;
}
using IRPrinter::visit;
private:
- Var* rand_func_;
+ VarPtr rand_func_;
const CudaAnalysis* cuda_analysis_;
- void print_flat_alloc(Allocate* alloc);
+ void print_flat_alloc(AllocatePtr alloc);
};
// Construct Cuda C from the buffer and tensor input, and invoke the kernel
public:
template <typename... Ts>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- CudaCodeGen(Stmt* stmt, Ts... ts)
+ CudaCodeGen(StmtPtr stmt, Ts... ts)
: CodeGen(
stmt,
std::vector<BufferArg>({BufferArg(ts)...}),
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
CudaCodeGen(
- Stmt* stmt,
+ StmtPtr stmt,
const std::vector<BufferArg>& buffer_args,
at::Device device = at::Device(at::kCUDA, at::cuda::current_device()),
const std::string& kernel_func_name = "func")
c10::optional<c10::Device> device_opt,
c10::optional<bool> pin_memory_opt) override;
- const std::vector<Expr*>& gpu_block_extents() const {
+ const std::vector<ExprPtr>& gpu_block_extents() const {
return cuda_analysis_->gpu_block_extents();
}
- const std::vector<Expr*>& gpu_thread_extents() const {
+ const std::vector<ExprPtr>& gpu_thread_extents() const {
return cuda_analysis_->gpu_thread_extents();
}
~SimpleIREvaluatorImpl() override = default;
- void bindBuf(Buf* buf, void* ptr) {
+ void bindBuf(BufPtr buf, void* ptr) {
buffer_mapping_[buf] = ptr;
}
- void bindVar(Var* var, const Value& val) {
+ void bindVar(VarPtr var, const Value& val) {
eval_context_[var] = val;
}
- Value evaluateExpr(Expr* e) {
+ Value evaluateExpr(ExprPtr e) {
e->accept(this);
return value_;
}
internal_buffers_.clear();
}
- TORCH_API void visit(Add* v) override {
+ TORCH_API void visit(AddPtr v) override {
visit_binary_op(v);
}
- TORCH_API void visit(Sub* v) override {
+ TORCH_API void visit(SubPtr v) override {
visit_binary_op(v);
}
- TORCH_API void visit(Mul* v) override {
+ TORCH_API void visit(MulPtr v) override {
visit_binary_op(v);
}
- TORCH_API void visit(Div* v) override {
+ TORCH_API void visit(DivPtr v) override {
visit_binary_op(v);
}
- TORCH_API void visit(Mod* v) override {
+ TORCH_API void visit(ModPtr v) override {
visit_binary_op(v);
}
- TORCH_API void visit(Max* v) override {
+ TORCH_API void visit(MaxPtr v) override {
visit_binary_op(v, v->propagate_nans());
}
- TORCH_API void visit(Min* v) override {
+ TORCH_API void visit(MinPtr v) override {
visit_binary_op(v, v->propagate_nans());
}
- TORCH_API void visit(And* v) override {
+ TORCH_API void visit(AndPtr v) override {
visit_binary_op(v);
}
- TORCH_API void visit(Or* v) override {
+ TORCH_API void visit(OrPtr v) override {
visit_binary_op(v);
}
- TORCH_API void visit(Xor* v) override {
+ TORCH_API void visit(XorPtr v) override {
visit_binary_op(v);
}
- TORCH_API void visit(Lshift* v) override {
+ TORCH_API void visit(LshiftPtr v) override {
visit_binary_op(v);
}
- TORCH_API void visit(Rshift* v) override {
+ TORCH_API void visit(RshiftPtr v) override {
visit_binary_op(v);
}
- void visit(CompareSelect* v) override {
+ void visit(CompareSelectPtr v) override {
visit_compare_select_op(v, v->compare_select_op());
}
}
void visit_compare_select_op(
- CompareSelect* v,
+ CompareSelectPtr v,
CompareSelectOperation cmp_op) {
v->lhs()->accept(this);
Value lhs_v = value_;
}
}
-#define IMM_VISIT(Type, Name) \
- TORCH_API void visit(Name##Imm* v) override { \
- value_ = Value(v->value()); \
+#define IMM_VISIT(Type, Name) \
+ TORCH_API void visit(Name##ImmPtr v) override { \
+ value_ = Value(v->value()); \
}
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT);
#undef IMM_VISIT
- TORCH_API void visit(Block* v) override {
- Block* last = scope_;
+ TORCH_API void visit(BlockPtr v) override {
+ BlockPtr last = scope_;
scope_ = v;
- for (Stmt* s : v->stmts()) {
+ for (StmtPtr s : v->stmts()) {
s->accept(this);
}
auto it = var_by_scope_.find(v);
if (it != var_by_scope_.end()) {
- for (Expr* v : it->second) {
+ for (ExprPtr v : it->second) {
eval_context_.erase(v);
}
var_by_scope_.erase(it);
scope_ = last;
}
- TORCH_API void visit(Var* v) override {
+ TORCH_API void visit(VarPtr v) override {
auto iter = eval_context_.find(v);
if (iter == eval_context_.end()) {
throw malformed_input("could not find Var in context", v);
}
}
- TORCH_API void visit(Cast* v) override {
- Expr* src_value = v->src_value();
+ TORCH_API void visit(CastPtr v) override {
+ ExprPtr src_value = v->src_value();
src_value->accept(this);
Dtype dst_dtype = v->dtype();
Dtype src_dtype = src_value->dtype();
}
}
- TORCH_API void visit(BitCast* v) override {
- Expr* src_value = v->src_value();
+ TORCH_API void visit(BitCastPtr v) override {
+ ExprPtr src_value = v->src_value();
src_value->accept(this);
Dtype dst_dtype = v->dtype();
Dtype src_dtype = src_value->dtype();
}
}
- TORCH_API void visit(For* v) override {
- Expr* var_node = v->var();
+ TORCH_API void visit(ForPtr v) override {
+ ExprPtr var_node = v->var();
v->start()->accept(this);
int start = value_.as<int>();
v->stop()->accept(this);
eval_context_.erase(var_node);
}
- TORCH_API void visit(Ramp* v) override {
+ TORCH_API void visit(RampPtr v) override {
v->base()->accept(this);
int base = value().as<int>();
v->stride()->accept(this);
value_ = Value(values);
}
- TORCH_API void visit(Broadcast* v) override {
+ TORCH_API void visit(BroadcastPtr v) override {
v->value()->accept(this);
Value value = this->value();
int lanes = v->lanes();
}
}
- TORCH_API void visit(IfThenElse* v) override {
+ TORCH_API void visit(IfThenElsePtr v) override {
v->condition()->accept(this);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
bool cond_v;
}
}
- TORCH_API void visit(Load* v) override {
+ TORCH_API void visit(LoadPtr v) override {
auto iter = buffer_mapping_.find(v->buf());
if (iter == buffer_mapping_.end()) {
throw malformed_input("could not find base node in Load", v);
}
void* ptr = iter->second;
- Expr* flat_idx = flatten_index(v->buf()->dims(), v->indices());
+ ExprPtr flat_idx = flatten_index(v->buf()->dims(), v->indices());
flat_idx->accept(this);
std::vector<int> index = value().as_vec<int>();
ScalarType v_sdtype = v->dtype().scalar_type();
}
}
- TORCH_API void visit(Store* v) override {
+ TORCH_API void visit(StorePtr v) override {
auto iter = buffer_mapping_.find(v->buf());
if (iter == buffer_mapping_.end()) {
throw malformed_input("could not find base node in Store", v);
void* ptr = iter->second;
- Expr* flat_idx = flatten_index(v->buf()->dims(), v->indices());
+ ExprPtr flat_idx = flatten_index(v->buf()->dims(), v->indices());
flat_idx->accept(this);
std::vector<int> index = value().as_vec<int>();
ScalarType v_sdtype = v->value()->dtype().scalar_type();
}
}
- void visit(ExternalCall* v) override {
+ void visit(ExternalCallPtr v) override {
auto& func_registry = getNNCFunctionRegistry();
if (!func_registry.count(v->func_name())) {
throw unimplemented_lowering(v);
}
- std::vector<Buf*> bufs(v->buf_args());
+ std::vector<BufPtr> bufs(v->buf_args());
bufs.insert(bufs.begin(), v->buf());
std::vector<void*> buf_ptrs;
std::vector<int8_t> buf_dtypes;
std::vector<int64_t> extra_args;
- for (Buf* b : bufs) {
+ for (BufPtr b : bufs) {
auto iter = buffer_mapping_.find(b);
if (iter == buffer_mapping_.end()) {
throw malformed_input("could not find buf", v);
buf_ptrs.push_back(iter->second);
buf_ranks.push_back(b->dims().size());
buf_dtypes.push_back((int8_t)b->dtype().scalar_type());
- for (Expr* dim_expr : b->dims()) {
+ for (ExprPtr dim_expr : b->dims()) {
dim_expr->accept(this);
buf_dims.push_back(value().as<int>());
}
}
- for (Expr* a : v->args()) {
+ for (ExprPtr a : v->args()) {
a->accept(this);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t val;
}
template <typename TReturn, typename TInput>
- void visit_intrinsics_helper(Intrinsics* v) {
+ void visit_intrinsics_helper(IntrinsicsPtr v) {
std::vector<Value> values(v->nparams());
for (const auto i : c10::irange(v->nparams())) {
v->param(i)->accept(this);
value_ = Value(result);
}
- TORCH_API void visit(Intrinsics* v) override {
+ TORCH_API void visit(IntrinsicsPtr v) override {
auto ty = v->dtype().scalar_type();
if (v->op_type() == kIsNan) {
auto inp_dtype = v->params().at(0)->dtype().scalar_type();
}
}
- void visit(Allocate* v) override {
- Buf* b = v->buf();
- std::vector<Expr*> dims = b->dims();
+ void visit(AllocatePtr v) override {
+ BufPtr b = v->buf();
+ std::vector<ExprPtr> dims = b->dims();
int total_byte_size = b->dtype().byte_size();
for (auto& dim : dims) {
dim->accept(this);
internal_buffers_.insert(std::make_pair(b, std::move(buffer)));
}
- void visit(Free* v) override {
- Buf* b = v->buf();
+ void visit(FreePtr v) override {
+ BufPtr b = v->buf();
int count = internal_buffers_.erase(b);
if (count == 0) {
throw std::runtime_error(
buffer_mapping_.erase(b);
}
- void visit(Let* v) override {
+ void visit(LetPtr v) override {
var_by_scope_[scope_].push_back(v->var());
bindVar(v->var(), evaluateExpr(v->value()));
}
- void visit(Cond* v) override {
+ void visit(CondPtr v) override {
v->condition()->accept(this);
if (value().as<int>()) {
if (v->true_stmt()) {
}
Value value_;
- Block* scope_;
- std::unordered_map<Expr*, Value> eval_context_;
- std::unordered_map<Block*, std::vector<Expr*>> var_by_scope_;
- std::unordered_map<Buf*, void*> buffer_mapping_;
- std::unordered_map<Buf*, std::unique_ptr<std::vector<int>>> internal_buffers_;
+ BlockPtr scope_;
+ std::unordered_map<ExprPtr, Value> eval_context_;
+ std::unordered_map<BlockPtr, std::vector<ExprPtr>> var_by_scope_;
+ std::unordered_map<BufPtr, void*> buffer_mapping_;
+ std::unordered_map<BufPtr, std::unique_ptr<std::vector<int>>>
+ internal_buffers_;
};
SimpleIREvaluator::SimpleIREvaluator(
- Stmt* stmt,
+ StmtPtr stmt,
const std::vector<BufferArg>& buffer_args,
at::Device device,
const std::string& kernel_func_name)
}
}
-void SimpleIREvaluator::bindVar(Var* v, Expr* e) {
+void SimpleIREvaluator::bindVar(VarPtr v, ExprPtr e) {
impl_->bindVar(v, impl_->evaluateExpr(e));
}
class TORCH_API SimpleIREvaluator : public CodeGen {
public:
SimpleIREvaluator(
- Stmt* stmt,
+ StmtPtr stmt,
const std::vector<BufferArg>& buffer_args,
at::Device device = at::kCPU,
const std::string& kernel_func_name = "func");
call(args);
}
- void bindVar(Var* v, Expr* e);
+ void bindVar(VarPtr v, ExprPtr e);
Value value() const;
private:
std::vector<BufferArg> buffer_args_extended = buffer_args;
Placeholder ret_buf("ret_val", dtype_, {1});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::vector<Expr*> indices;
- Expr* zero = new IntImm(0);
+ std::vector<ExprPtr> indices;
+ ExprPtr zero = alloc<IntImm>(0);
for (size_t i = 0; i < ret_buf.data()->ndim(); i++) {
indices.push_back(zero);
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- Stmt* store_stmt =
+ StmtPtr store_stmt =
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
- new Store(ret_buf.data(), indices, expr.node());
+ alloc<Store>(ret_buf.data(), indices, expr.node());
buffer_args_extended.emplace_back(ret_buf);
codegen_.reset(new CodeGenType(store_stmt, buffer_args_extended));
}
call(call_args);
}
- void bindVar(Var* v, Expr* e) {
+ void bindVar(VarPtr v, ExprPtr e) {
codegen_->bindVar(v, e);
}
// Substitutes the given vars with their corresponding expressions in the input
// expression.
-inline Expr* Substitute(Expr* expr, const VarMapping& var_mapping) {
+// NOTE: substitution is applied in place via the mutator; callers must use the
+// returned pointer. See SubstituteInClone below for a non-destructive variant.
+inline ExprPtr Substitute(ExprPtr expr, const VarMapping& var_mapping) {
VarSubMutator var_sub(var_mapping);
return expr->accept_mutator(&var_sub);
}
// Substitutes the given vars with their corresponding expressions in the input
// statement.
-inline Stmt* Substitute(Stmt* stmt, const VarMapping& var_mapping) {
+// Statement counterpart of Substitute(ExprPtr, ...): mutates in place and
+// returns the (possibly replaced) root statement.
+inline StmtPtr Substitute(StmtPtr stmt, const VarMapping& var_mapping) {
VarSubMutator var_sub(var_mapping);
return stmt->accept_mutator(&var_sub);
}
// their corresponding expressions in the clone.
// NOTE: This works because cloning reuses variables and does not create new
// ones, and `VarMapping` input has variables as the key.
-inline Expr* SubstituteInClone(Expr* expr, const VarMapping& var_mapping) {
+// Non-destructive variant: the expression is cloned first and substitution is
+// applied to the clone, leaving the original tree untouched.
+inline ExprPtr SubstituteInClone(ExprPtr expr, const VarMapping& var_mapping) {
VarSubMutator var_sub(var_mapping);
return Expr::clone(expr)->accept_mutator(&var_sub);
}
// their corresponding expressions in the clone.
// NOTE: This works because cloning reuses variables and does not create new
// ones, and `VarMapping` input has variables as the key.
-inline Stmt* SubstituteInClone(Stmt* stmt, const VarMapping& var_mapping) {
+// Non-destructive statement variant: substitution is applied to a clone.
+inline StmtPtr SubstituteInClone(StmtPtr stmt, const VarMapping& var_mapping) {
VarSubMutator var_sub(var_mapping);
return Stmt::clone(stmt)->accept_mutator(&var_sub);
}
#pragma once
#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
#include <sstream>
#include <stdexcept>
// Forward declarations of functions
namespace std {
-TORCH_API std::string to_string(const torch::jit::tensorexpr::Expr*);
-TORCH_API std::string to_string(const torch::jit::tensorexpr::Stmt*);
+TORCH_API std::string to_string(const torch::jit::tensorexpr::ExprPtr);
+TORCH_API std::string to_string(const torch::jit::tensorexpr::StmtPtr);
} // namespace std
namespace torch {
public:
explicit unimplemented_lowering()
: std::runtime_error("UNIMPLEMENTED LOWERING") {}
- explicit unimplemented_lowering(Expr* expr)
+ explicit unimplemented_lowering(ExprPtr expr)
: std::runtime_error("UNIMPLEMENTED LOWERING: " + std::to_string(expr)) {}
- explicit unimplemented_lowering(Stmt* stmt)
+ explicit unimplemented_lowering(StmtPtr stmt)
: std::runtime_error("UNIMPLEMENTED LOWERING: " + std::to_string(stmt)) {}
};
explicit malformed_input() : std::runtime_error("MALFORMED INPUT") {}
explicit malformed_input(const std::string& err)
: std::runtime_error("MALFORMED INPUT: " + err) {}
- explicit malformed_input(Expr* expr)
+ explicit malformed_input(ExprPtr expr)
: std::runtime_error("MALFORMED INPUT: " + std::to_string(expr)) {}
- explicit malformed_input(const std::string& err, Expr* expr)
+ explicit malformed_input(const std::string& err, ExprPtr expr)
: std::runtime_error(
"MALFORMED INPUT: " + err + " - " + std::to_string(expr)) {}
- explicit malformed_input(Stmt* stmt)
+ explicit malformed_input(StmtPtr stmt)
: std::runtime_error("MALFORMED INPUT: " + std::to_string(stmt)) {}
- explicit malformed_input(const std::string& err, Stmt* stmt)
+ explicit malformed_input(const std::string& err, StmtPtr stmt)
: std::runtime_error(
"MALFORMED INPUT: " + err + " - " + std::to_string(stmt)) {}
};
explicit malformed_ir() : std::runtime_error("MALFORMED IR") {}
explicit malformed_ir(const std::string& err)
: std::runtime_error("MALFORMED IR: " + err) {}
- explicit malformed_ir(Expr* expr)
+ explicit malformed_ir(ExprPtr expr)
: std::runtime_error("MALFORMED IR: " + std::to_string(expr)) {}
- explicit malformed_ir(const std::string& err, Expr* expr)
+ explicit malformed_ir(const std::string& err, ExprPtr expr)
: std::runtime_error(
"MALFORMED IR: " + err + " - " + std::to_string(expr)) {}
- explicit malformed_ir(Stmt* stmt)
+ explicit malformed_ir(StmtPtr stmt)
: std::runtime_error("MALFORMED IR: " + std::to_string(stmt)) {}
- explicit malformed_ir(const std::string& err, Stmt* stmt)
+ explicit malformed_ir(const std::string& err, StmtPtr stmt)
: std::runtime_error(
"MALFORMED IR: " + err + " - " + std::to_string(stmt)) {}
};
const std::vector<ExprHandle>& dims,
Dtype dtype) {
return ExprHandle(
- new Buf(name_hint, ExprHandleVectorToExprVector(dims), dtype));
+ alloc<Buf>(name_hint, ExprHandleVectorToExprVector(dims), dtype));
}
ExprHandle Buf::make(const std::vector<ExprHandle>& dims, Dtype dtype) {
*/
#pragma once
+#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
#include <torch/csrc/jit/tensorexpr/ir_mutator.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/mem_arena.h>
return dtype_;
}
virtual void accept(IRVisitor* visitor) = 0;
- virtual Expr* accept_mutator(IRMutator* mutator) = 0;
+ virtual ExprPtr accept_mutator(IRMutator* mutator) = 0;
IRNodeType expr_type() const {
return expr_type_;
* All sub-expressions inside the given expressions are also cloned. Note
* that the variables are not deep-copied since they are immutable.
*/
- static Expr* clone(Expr* s);
+ static ExprPtr clone(ExprPtr s);
private:
Dtype dtype_;
public:
using ExprNodeBase = ExprNode<Op>;
void accept(IRVisitor* visitor) override {
- visitor->visit(static_cast<Op*>(this));
+ visitor->visit(static_to<Op>(this));
}
- Expr* accept_mutator(IRMutator* mutator) override;
+ ExprPtr accept_mutator(IRMutator* mutator) override;
// pass the constructor to the base class
using Base::Base;
};
class TORCH_API ExprHandle {
public:
ExprHandle() = default;
- explicit ExprHandle(Expr* node)
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
- : base_expr_node_(const_cast<Expr*>(node)) {}
+ explicit ExprHandle(ExprPtr node) : base_expr_node_(node) {}
- Expr* node() {
+ ExprPtr node() {
return base_expr_node_;
}
- Expr* node() const {
+ ExprPtr node() const {
return base_expr_node_;
}
#undef IMM_EXPR_DECLARE
template <class Op>
- Op* AsNode() {
- return dynamic_cast<Op*>(this->node());
+ NodePtr<Op> AsNode() {
+ return to<Op>(this->node());
}
template <class Op>
- Op* AsNode() const {
+ NodePtr<Op> AsNode() const {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
return const_cast<ExprHandle*>(this)->AsNode<Op>();
}
ExprHandle operator>>(const ExprHandle& other) const;
private:
- Expr* base_expr_node_ = nullptr;
+ ExprPtr base_expr_node_ = nullptr;
};
// The underlying representation node to a Var.
class TORCH_API Var : public ExprNode<Var> {
public:
static ExprHandle make(const std::string& name_hint, Dtype dtype) {
- return ExprHandle(new Var(name_hint, dtype));
+ return ExprHandle(alloc<Var>(name_hint, dtype));
}
static ExprHandle make(Dtype dtype) {
- return ExprHandle(new Var("", dtype));
+ return ExprHandle(alloc<Var>("", dtype));
}
// TODO: unique_name
static ExprHandle make(const std::vector<ExprHandle>& dims, Dtype dtype);
// TODO: unique_name
- Var* base_handle() const {
+ VarPtr base_handle() const {
return base_handle_;
}
- void set_base_handle(Var* base_handle) {
+ void set_base_handle(VarPtr base_handle) {
base_handle_ = base_handle;
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Buf(const std::string& name_hint,
- const std::vector<Expr*>& dims,
+ const std::vector<ExprPtr>& dims,
Dtype dtype,
- Expr* initializer = nullptr)
- : Buf(new Var(name_hint, kHandle), dims, dtype, initializer) {}
+ ExprPtr initializer = nullptr)
+ : Buf(alloc<Var>(name_hint, kHandle), dims, dtype, initializer) {}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- Buf(Var* var,
- std::vector<Expr*> dims,
+ Buf(VarPtr var,
+ std::vector<ExprPtr> dims,
Dtype dtype,
- Expr* initializer = nullptr)
+ ExprPtr initializer = nullptr)
: ExprNodeBase(dtype, kPrimitive),
base_handle_(var),
dims_(std::move(dims)),
size_t ndim() const {
return dims_.size();
}
- Expr* dim(size_t index) const {
+ ExprPtr dim(size_t index) const {
if (index >= ndim()) {
throw out_of_range_index();
}
return dims_[index];
}
- std::vector<Expr*> dims() const {
+ std::vector<ExprPtr> dims() const {
return dims_;
}
- void set_dims(std::vector<Expr*> dims) {
+ void set_dims(std::vector<ExprPtr> dims) {
dims_ = dims;
};
- Expr* initializer() const {
+ ExprPtr initializer() const {
return initializer_;
};
}
private:
- Var* base_handle_;
- std::vector<Expr*> dims_;
- Expr* initializer_;
+ VarPtr base_handle_;
+ std::vector<ExprPtr> dims_;
+ ExprPtr initializer_;
};
class TORCH_API BufHandle : public ExprHandle {
explicit BufHandle(Dtype dtype) : ExprHandle(Buf::make("_", {}, dtype)) {}
- explicit BufHandle(Buf* node) : ExprHandle(node) {}
- Buf* node() const {
- return static_cast<Buf*>(ExprHandle::node());
+ explicit BufHandle(BufPtr node) : ExprHandle(node) {}
+ BufPtr node() const {
+ return static_to<Buf>(ExprHandle::node());
}
- Buf* node() {
- return static_cast<Buf*>(ExprHandle::node());
+ BufPtr node() {
+ return static_to<Buf>(ExprHandle::node());
}
template <typename... Ts>
explicit VarHandle(Dtype dtype) : ExprHandle(Var::make(dtype)) {}
VarHandle(const std::string& name_hint, Dtype dtype)
: ExprHandle(Var::make(name_hint, dtype)) {}
- explicit VarHandle(Var* node) : ExprHandle(node) {}
- Var* node() const {
- return static_cast<Var*>(ExprHandle::node());
+ explicit VarHandle(VarPtr node) : ExprHandle(node) {}
+ VarPtr node() const {
+ return static_to<Var>(ExprHandle::node());
}
bool operator==(const VarHandle& other) const {
return this->node() == other.node();
};
template <class Op, class Base>
-Expr* ExprNode<Op, Base>::accept_mutator(IRMutator* mutator) {
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
- ExprNode* this_mutable = const_cast<ExprNode*>(this);
- return mutator->mutate(static_cast<Op*>(this_mutable));
+ExprPtr ExprNode<Op, Base>::accept_mutator(IRMutator* mutator) {
+ return mutator->mutate(static_to<Op>(this));
}
inline bool same_node(const ExprHandle& expr1, const ExprHandle& expr2) {
--- /dev/null
+#pragma once
+#include <c10/core/ScalarType.h>
+
+namespace torch {
+namespace jit {
+namespace tensorexpr {
+
+// NodePtr is the single alias through which all TensorExpr IR nodes are
+// referenced. It is currently a raw pointer; funneling every use through this
+// alias (and the to/static_to/alloc helpers below) keeps the representation
+// swappable in one place (e.g. to a smart pointer later).
+template <typename Node>
+using NodePtr = Node*;
+
+// Checked downcast between IR node types; returns nullptr on type mismatch.
+template <typename To, typename From>
+NodePtr<To> to(NodePtr<From> x) {
+  return dynamic_cast<NodePtr<To>>(x);
+}
+
+// Unchecked downcast; the caller must already know the dynamic type of `x`.
+template <typename To, typename From>
+NodePtr<To> static_to(NodePtr<From> x) {
+  return static_cast<NodePtr<To>>(x);
+}
+
+// Factory for IR nodes — currently plain `new`. NOTE(review): nothing here
+// deletes nodes; ownership appears to be managed by the surrounding IR
+// infrastructure (e.g. an arena) — confirm before relying on lifetimes.
+template <typename Node, typename... Args>
+NodePtr<Node> alloc(Args&&... args) {
+  return new Node(std::forward<Args>(args)...);
+}
+
+class Buf;
+class Expr;
+class Stmt;
+class Var;
+
+using BufPtr = NodePtr<Buf>;
+using ExprPtr = NodePtr<Expr>;
+using StmtPtr = NodePtr<Stmt>;
+using VarPtr = NodePtr<Var>;
+
+class ExprHandle;
+
+// Expression node kinds.
+class Add;
+class And;
+class BitCast;
+class Broadcast;
+class Cast;
+class CompareSelect;
+class Div;
+class IfThenElse;
+class Intrinsics;
+class Let;
+class Load;
+class Lshift;
+class Max;
+class MaxTerm;
+class Min;
+class MinTerm;
+class Mod;
+class Mul;
+class Or;
+class Polynomial;
+class Ramp;
+class ReduceOp;
+class RoundOff;
+class Rshift;
+class Store;
+class Sub;
+class Term;
+class Xor;
+using AddPtr = NodePtr<Add>;
+using AndPtr = NodePtr<And>;
+using BitCastPtr = NodePtr<BitCast>;
+using BroadcastPtr = NodePtr<Broadcast>;
+using CastPtr = NodePtr<Cast>;
+using CompareSelectPtr = NodePtr<CompareSelect>;
+using DivPtr = NodePtr<Div>;
+using IfThenElsePtr = NodePtr<IfThenElse>;
+using IntrinsicsPtr = NodePtr<Intrinsics>;
+using LetPtr = NodePtr<Let>;
+using LoadPtr = NodePtr<Load>;
+using LshiftPtr = NodePtr<Lshift>;
+using MaxPtr = NodePtr<Max>;
+using MaxTermPtr = NodePtr<MaxTerm>;
+using MinPtr = NodePtr<Min>;
+using MinTermPtr = NodePtr<MinTerm>;
+using ModPtr = NodePtr<Mod>;
+using MulPtr = NodePtr<Mul>;
+using OrPtr = NodePtr<Or>;
+using PolynomialPtr = NodePtr<Polynomial>;
+using RampPtr = NodePtr<Ramp>;
+using ReduceOpPtr = NodePtr<ReduceOp>;
+using RoundOffPtr = NodePtr<RoundOff>;
+using RshiftPtr = NodePtr<Rshift>;
+using StorePtr = NodePtr<Store>;
+using SubPtr = NodePtr<Sub>;
+using TermPtr = NodePtr<Term>;
+using XorPtr = NodePtr<Xor>;
+
+// Statement node kinds.
+class Allocate;
+class AtomicAdd;
+class Block;
+class Cond;
+class ExternalCall;
+class For;
+class Free;
+class SyncThreads;
+using AllocatePtr = NodePtr<Allocate>;
+using AtomicAddPtr = NodePtr<AtomicAdd>;
+using BlockPtr = NodePtr<Block>;
+using CondPtr = NodePtr<Cond>;
+using ExternalCallPtr = NodePtr<ExternalCall>;
+using ForPtr = NodePtr<For>;
+using FreePtr = NodePtr<Free>;
+using SyncThreadsPtr = NodePtr<SyncThreads>;
+
+// Declares one NameImm class + NameImmPtr alias per scalar type (plus
+// Bool/Half via the AND2 variant of the FORALL macro).
+#define IMM_DECLARE(Type, Name) \
+  class Name##Imm; \
+  using Name##ImmPtr = NodePtr<Name##Imm>;
+AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE);
+#undef IMM_DECLARE
+
+} // namespace tensorexpr
+} // namespace jit
+} // namespace torch
return hasHalf_;
}
- void visit(Load* v) override {
+ void visit(LoadPtr v) override {
hasHalf_ |= v->dtype().scalar_type() == ScalarType::Half;
IRVisitor::visit(v);
}
- void visit(Store* v) override {
+ void visit(StorePtr v) override {
hasHalf_ |= v->buf()->dtype().scalar_type() == ScalarType::Half;
IRVisitor::visit(v);
}
- void visit(HalfImm* v) override {
+ void visit(HalfImmPtr v) override {
hasHalf_ = true;
}
- void visit(Cast* v) override {
+ void visit(CastPtr v) override {
hasHalf_ |= v->dtype().scalar_type() == ScalarType::Half;
IRVisitor::visit(v);
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class HalfRewriter : public IRMutator {
- Expr* mutate(Load* v) override {
- Expr* child = IRMutator::mutate(v);
+ ExprPtr mutate(LoadPtr v) override {
+ ExprPtr child = IRMutator::mutate(v);
if (child->dtype().scalar_type() != ScalarType::Half) {
return child;
}
- Expr* ret =
- new Cast(child->dtype().cloneWithScalarType(ScalarType::Float), child);
+ ExprPtr ret = alloc<Cast>(
+ child->dtype().cloneWithScalarType(ScalarType::Float), child);
inserted_half_casts_.insert(ret);
return ret;
}
- Stmt* mutate(Store* v) override {
+ StmtPtr mutate(StorePtr v) override {
// Since mutation changes the `value()` expression in-place, we need to
// get the dtype of the `value()` before that is mutated.
Dtype newType = v->value()->dtype();
- Expr* new_val = v->value()->accept_mutator(this);
+ ExprPtr new_val = v->value()->accept_mutator(this);
if (newType.scalar_type() == ScalarType::Half) {
new_val =
- new Cast(newType.cloneWithScalarType(ScalarType::Half), new_val);
+ alloc<Cast>(newType.cloneWithScalarType(ScalarType::Half), new_val);
inserted_half_casts_.insert(new_val);
}
- return new Store(v->buf(), v->indices(), new_val);
+ return alloc<Store>(v->buf(), v->indices(), new_val);
}
- Expr* mutate(HalfImm* v) override {
- return new Cast(kFloat, v);
+ ExprPtr mutate(HalfImmPtr v) override {
+ return alloc<Cast>(kFloat, v);
}
- Expr* mutate(Cast* v) override {
- Expr* child = v->src_value()->accept_mutator(this);
+ ExprPtr mutate(CastPtr v) override {
+ ExprPtr child = v->src_value()->accept_mutator(this);
// just don't allow half casts we didn't insert.
if (v->dtype().scalar_type() == ScalarType::Half) {
}
// Remove Half(Float()) and friends.
- Cast* cast_child = dynamic_cast<Cast*>(child);
+ CastPtr cast_child = to<Cast>(child);
if (cast_child) {
if (v->dtype().is_floating_point() &&
cast_child->dtype().is_floating_point()) {
- return new Cast(v->dtype(), cast_child->src_value());
+ return alloc<Cast>(v->dtype(), cast_child->src_value());
}
}
return v;
}
- return new Cast(v->dtype(), child);
+ return alloc<Cast>(v->dtype(), child);
}
- Stmt* mutate(Let* v) override {
+ StmtPtr mutate(LetPtr v) override {
if (v->dtype().scalar_type() == ScalarType::Half) {
- Var* load_new_var = new Var(v->var()->name_hint(), kFloat);
- Expr* new_value = new Cast(
+ VarPtr load_new_var = alloc<Var>(v->var()->name_hint(), kFloat);
+ ExprPtr new_value = alloc<Cast>(
v->dtype().cloneWithScalarType(ScalarType::Float),
v->value()->accept_mutator(this));
var_map[v->var()] = load_new_var;
- return new Let(load_new_var, new_value);
+ return alloc<Let>(load_new_var, new_value);
}
return IRMutator::mutate(v);
}
- Expr* mutate(Var* v) override {
+ ExprPtr mutate(VarPtr v) override {
auto it = var_map.find(v);
if (it != var_map.end()) {
return it->second;
}
private:
- std::unordered_set<Expr*> inserted_half_casts_;
- std::unordered_map<Var*, Var*> var_map;
+ std::unordered_set<ExprPtr> inserted_half_casts_;
+ std::unordered_map<VarPtr, VarPtr> var_map;
};
} // namespace tensorexpr
return _h != other;
}
-void HashProvider::visit(Add* v) {
+// Binary ops hash as combine(hash(lhs), <op tag>, hash(rhs)). The distinct
+// string tag per operator keeps structurally different expressions (e.g. a+b
+// vs a-b) from hashing equal. Note operand order matters, so a+b and b+a hash
+// differently even for commutative ops.
+void HashProvider::visit(AddPtr v) {
+// CACHE_GUARD: presumably short-circuits when this node already has a cached
+// hash — NOTE(review): confirm the macro's semantics.
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "+", hashOf(v->rhs())));
}
-void HashProvider::visit(Sub* v) {
+void HashProvider::visit(SubPtr v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "-", hashOf(v->rhs())));
}
-void HashProvider::visit(Mul* v) {
+void HashProvider::visit(MulPtr v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "*", hashOf(v->rhs())));
}
-void HashProvider::visit(Div* v) {
+void HashProvider::visit(DivPtr v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "/", hashOf(v->rhs())));
}
-void HashProvider::visit(Mod* v) {
+void HashProvider::visit(ModPtr v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "%", hashOf(v->rhs())));
}
-void HashProvider::visit(Max* v) {
+void HashProvider::visit(MaxPtr v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "Mx", hashOf(v->rhs())));
}
-void HashProvider::visit(Min* v) {
+void HashProvider::visit(MinPtr v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "Mn", hashOf(v->rhs())));
}
-void HashProvider::visit(And* v) {
+void HashProvider::visit(AndPtr v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "&", hashOf(v->rhs())));
}
-void HashProvider::visit(Or* v) {
+void HashProvider::visit(OrPtr v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "|", hashOf(v->rhs())));
}
-void HashProvider::visit(Xor* v) {
+void HashProvider::visit(XorPtr v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "^", hashOf(v->rhs())));
}
-void HashProvider::visit(Lshift* v) {
+void HashProvider::visit(LshiftPtr v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "<<", hashOf(v->rhs())));
}
-void HashProvider::visit(Rshift* v) {
+void HashProvider::visit(RshiftPtr v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), ">>", hashOf(v->rhs())));
}
-void HashProvider::visit(CompareSelect* v) {
+void HashProvider::visit(CompareSelectPtr v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
hashOf(v->ret_val2())));
}
-void HashProvider::visit(Cast* v) {
+// A cast hashes its target dtype along with the source expression, so the same
+// source value cast to different types hashes differently.
+void HashProvider::visit(CastPtr v) {
CACHE_GUARD();
v->src_value()->accept(this);
putHash(v, hash_combine("cast", v->dtype(), hashOf(v->src_value())));
}
-void HashProvider::visit(Var* v) {
+// Variables hash by their unique name, i.e. by identity rather than by the
+// (possibly duplicated) user-supplied name hint.
+void HashProvider::visit(VarPtr v) {
CACHE_GUARD();
putHash(v, hash_combine("var", name_manager_.get_unique_name(v)));
}
-void HashProvider::visit(Ramp* v) {
+void HashProvider::visit(RampPtr v) {
CACHE_GUARD();
v->base()->accept(this);
v->stride()->accept(this);
hash_combine("ramp", hashOf(v->base()), hashOf(v->stride()), v->lanes()));
}
-void HashProvider::visit(Load* v) {
+// A load hashes its base buffer plus the fold of all index hashes, so loads of
+// the same buffer at different indices hash differently.
+void HashProvider::visit(LoadPtr v) {
CACHE_GUARD();
v->base_handle()->accept(this);
+// NOTE(review): indices_hash starts default-constructed — assumes
+// SimplifierHashType value-initializes to a fixed seed; confirm.
SimplifierHashType indices_hash;
- for (Expr* ind : v->indices()) {
+ for (ExprPtr ind : v->indices()) {
ind->accept(this);
indices_hash = hash_combine(indices_hash, hashOf(ind));
}
putHash(v, hash_combine("load", hashOf(v->base_handle()), indices_hash));
}
-void HashProvider::visit(Store* v) {
+void HashProvider::visit(StorePtr v) {
CACHE_GUARD();
v->base_handle()->accept(this);
SimplifierHashType indices_hash;
- for (Expr* ind : v->indices()) {
+ for (ExprPtr ind : v->indices()) {
ind->accept(this);
indices_hash = hash_combine(indices_hash, hashOf(ind));
}
"store", hashOf(v->base_handle()), indices_hash, hashOf(v->value())));
}
-void HashProvider::visit(Block* v) {
+// A block hashes as the ordered fold of its child statements' hashes, so
+// reordering statements changes the block hash.
+void HashProvider::visit(BlockPtr v) {
CACHE_GUARD();
+// NOTE(review): hash starts default-constructed — assumes SimplifierHashType
+// value-initializes deterministically; confirm.
SimplifierHashType hash;
- for (Stmt* s : *v) {
+ for (StmtPtr s : *v) {
s->accept(this);
hash = hash_combine(hash, hashOf(s));
}
putHash(v, hash);
}
-void HashProvider::visit(For* v) {
+void HashProvider::visit(ForPtr v) {
CACHE_GUARD();
v->var()->accept(this);
v->start()->accept(this);
putHash(v, hash);
}
-void HashProvider::visit(Broadcast* v) {
+// Broadcast hashes its scalar value plus the lane count, so the same value
+// broadcast to different widths hashes differently.
+void HashProvider::visit(BroadcastPtr v) {
CACHE_GUARD();
v->value()->accept(this);
putHash(v, hash_combine("broadcast", hashOf(v->value()), v->lanes()));
}
-void HashProvider::visit(IfThenElse* v) {
+void HashProvider::visit(IfThenElsePtr v) {
CACHE_GUARD();
v->condition()->accept(this);
v->true_value()->accept(this);
hashOf(v->false_value())));
}
-void HashProvider::visit(Intrinsics* v) {
+void HashProvider::visit(IntrinsicsPtr v) {
CACHE_GUARD();
// calls to rand are not symbolic and have a different value each time, they
// should not hash to anything and this is the best we can do.
putHash(v, hash);
}
-void HashProvider::visit(Allocate* v) {
+// Allocate hashes its buffer variable, dtype, and every dimension expression.
+void HashProvider::visit(AllocatePtr v) {
CACHE_GUARD();
- Var* buffer_var = v->buffer_var();
+ VarPtr buffer_var = v->buffer_var();
buffer_var->accept(this);
SimplifierHashType hash =
hash_combine("allocate", hashOf(buffer_var), v->dtype());
- std::vector<Expr*> dims = v->dims();
- for (Expr* dim : dims) {
+ std::vector<ExprPtr> dims = v->dims();
+ for (ExprPtr dim : dims) {
dim->accept(this);
hash = hash_combine(hash, hashOf(dim));
}
putHash(v, hash);
}
-void HashProvider::visit(Free* v) {
+// Free hashes only the buffer variable being released.
+void HashProvider::visit(FreePtr v) {
CACHE_GUARD();
- Var* buffer_var = v->buffer_var();
+ VarPtr buffer_var = v->buffer_var();
buffer_var->accept(this);
putHash(v, hash_combine("free", hashOf(buffer_var)));
}
-void HashProvider::visit(Cond* v) {
+void HashProvider::visit(CondPtr v) {
CACHE_GUARD();
- Expr* condition = v->condition();
- Stmt* true_stmt = v->true_stmt();
- Stmt* false_stmt = v->false_stmt();
+ ExprPtr condition = v->condition();
+ StmtPtr true_stmt = v->true_stmt();
+ StmtPtr false_stmt = v->false_stmt();
condition->accept(this);
SimplifierHashType hash = hash_combine("cond", hashOf(condition));
putHash(v, hash);
}
-void HashProvider::visit(Term* v) {
+// Term hashes its scalar coefficient plus the ordered fold of its variables.
+void HashProvider::visit(TermPtr v) {
CACHE_GUARD();
v->scalar()->accept(this);
SimplifierHashType hash = hash_combine("term", hashOf(v->scalar()));
- for (auto* c : v->variables()) {
+ for (auto c : v->variables()) {
c->accept(this);
hash = hash_combine(hash, hashOf(c));
}
putHash(v, hash);
}
-void HashProvider::visit(Polynomial* v) {
+// NOTE(review): this reuses the "term" tag from Term::visit above, so a Term
+// and a Polynomial with identical scalar/variables hash equal — likely should
+// be "polynomial"; confirm whether the collision is intentional.
+void HashProvider::visit(PolynomialPtr v) {
CACHE_GUARD();
v->scalar()->accept(this);
SimplifierHashType hash = hash_combine("term", hashOf(v->scalar()));
- for (auto* c : v->variables()) {
+ for (auto c : v->variables()) {
c->accept(this);
hash = hash_combine(hash, hashOf(c));
}
putHash(v, hash);
}
-void HashProvider::visit(MaxTerm* v) {
+// MaxTerm's scalar is optional, hence the null check before folding it in.
+void HashProvider::visit(MaxTermPtr v) {
CACHE_GUARD();
SimplifierHashType hash = hash_combine("maxterm");
if (v->scalar()) {
hash = hash_combine(hash, hashOf(v->scalar()));
}
- for (auto* c : v->variables()) {
+ for (auto c : v->variables()) {
c->accept(this);
hash = hash_combine(hash, hashOf(c));
}
putHash(v, hash);
}
-void HashProvider::visit(MinTerm* v) {
+void HashProvider::visit(MinTermPtr v) {
CACHE_GUARD();
SimplifierHashType hash = hash_combine("minterm");
if (v->scalar()) {
hash = hash_combine(hash, hashOf(v->scalar()));
}
- for (auto* c : v->variables()) {
+ for (auto c : v->variables()) {
c->accept(this);
hash = hash_combine(hash, hashOf(c));
}
class TORCH_API HashProvider : public IRVisitor {
public:
template <class T>
- SimplifierHashType hash(T* e) {
+ SimplifierHashType hash(T e) {
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
e->accept(this);
return hashOf(e);
exprToHash_.clear();
}
- void visit(Add* v) override;
- void visit(Sub* v) override;
- void visit(Mul* v) override;
- void visit(Div* v) override;
- void visit(Mod* v) override;
- void visit(Max* v) override;
- void visit(Min* v) override;
- void visit(And* v) override;
- void visit(Or* v) override;
- void visit(Xor* v) override;
- void visit(Lshift* v) override;
- void visit(Rshift* v) override;
- void visit(CompareSelect* v) override;
+ void visit(AddPtr v) override;
+ void visit(SubPtr v) override;
+ void visit(MulPtr v) override;
+ void visit(DivPtr v) override;
+ void visit(ModPtr v) override;
+ void visit(MaxPtr v) override;
+ void visit(MinPtr v) override;
+ void visit(AndPtr v) override;
+ void visit(OrPtr v) override;
+ void visit(XorPtr v) override;
+ void visit(LshiftPtr v) override;
+ void visit(RshiftPtr v) override;
+ void visit(CompareSelectPtr v) override;
// NOLINTNEXTLINE
#define IMM_VISIT(Type, Name) \
- void visit(Name##Imm* v) override { \
+ void visit(Name##ImmPtr v) override { \
CACHE_GUARD(); \
putHash(v, hash_combine(#Name, v->value())); \
}
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT);
#undef IMM_VISIT
- void visit(Cast* v) override;
- void visit(Var* v) override;
- void visit(Ramp* v) override;
- void visit(Load* v) override;
- void visit(Store* v) override;
- void visit(Block* v) override;
- void visit(For* v) override;
- void visit(Broadcast* v) override;
- void visit(IfThenElse* v) override;
- void visit(Intrinsics* v) override;
- void visit(Allocate* v) override;
- void visit(Free* v) override;
- void visit(Cond* v) override;
- void visit(Term* v) override;
- void visit(Polynomial* v) override;
- void visit(MaxTerm* v) override;
- void visit(MinTerm* v) override;
+ void visit(CastPtr v) override;
+ void visit(VarPtr v) override;
+ void visit(RampPtr v) override;
+ void visit(LoadPtr v) override;
+ void visit(StorePtr v) override;
+ void visit(BlockPtr v) override;
+ void visit(ForPtr v) override;
+ void visit(BroadcastPtr v) override;
+ void visit(IfThenElsePtr v) override;
+ void visit(IntrinsicsPtr v) override;
+ void visit(AllocatePtr v) override;
+ void visit(FreePtr v) override;
+ void visit(CondPtr v) override;
+ void visit(TermPtr v) override;
+ void visit(PolynomialPtr v) override;
+ void visit(MaxTermPtr v) override;
+ void visit(MinTermPtr v) override;
template <typename... Types>
SimplifierHashType hash_combine(const Types&... args) {
}
private:
- SimplifierHashType hashOf(Expr* e) {
+ SimplifierHashType hashOf(ExprPtr e) {
auto it = exprToHash_.find(e);
if (it != exprToHash_.end()) {
return it->second;
return hash;
}
- SimplifierHashType hashOf(Stmt* s) {
+ SimplifierHashType hashOf(StmtPtr s) {
auto it = exprToHash_.find(s);
if (it != exprToHash_.end()) {
return it->second;
(seed._h >> 4);
}
- void _hash_combine(SimplifierHashType& seed, Expr* e) {
+ void _hash_combine(SimplifierHashType& seed, ExprPtr e) {
_hash_combine(seed, hash(e));
}
return Dtype(buffer_dtype, index_dtype.lanes());
}
-static Dtype dtypeOfIndices(const std::vector<Expr*>& indices) {
+static Dtype dtypeOfIndices(const std::vector<ExprPtr>& indices) {
if (!indices.size()) {
// Return something so we can handle scalar buffers.
return kInt;
return indices.at(0)->dtype();
}
-void castIndicesToInts(std::vector<Expr*>& indices) {
+// Normalizes buffer indices to a single integer dtype by wrapping mismatched
+// integral indices in a Cast (lane count preserved). Non-integral indices are
+// left untouched here.
+void castIndicesToInts(std::vector<ExprPtr>& indices) {
// Cast all indices to either Int or Long
auto index_dtype = ScalarType::Int;
+// NOTE(review): only Int is used as the target in this view despite the
+// comment mentioning Long — confirm the Long-detection logic isn't elsewhere.
for (auto& index : indices) {
const Dtype& dt = index->dtype();
if (c10::isIntegralType(dt.scalar_type(), true) &&
dt.scalar_type() != index_dtype) {
- index = new Cast(Dtype(index_dtype, dt.lanes()), index);
+ index = alloc<Cast>(Dtype(index_dtype, dt.lanes()), index);
}
}
}
-Load::Load(Dtype dtype, Buf* buf, std::vector<Expr*> indices)
+// Primary constructor: takes an explicit result dtype; indices are moved into
+// the member and then normalized to an integer dtype in place.
+Load::Load(Dtype dtype, BufPtr buf, std::vector<ExprPtr> indices)
: ExprNodeBase(dtype), buf_(buf), indices_(std::move(indices)) {
castIndicesToInts(indices_);
}
-Load::Load(Buf* buf, const std::vector<Expr*>& indices)
+// Convenience constructor: derives the result dtype from the buffer's dtype
+// and the lane count of the indices.
+Load::Load(BufPtr buf, const std::vector<ExprPtr>& indices)
: Load(ChooseDtype(buf->dtype(), dtypeOfIndices(indices)), buf, indices) {}
ExprHandle Load::make(
const BufHandle& buf,
const std::vector<ExprHandle>& indices) {
return ExprHandle(
- new Load(dtype, buf.node(), ExprHandleVectorToExprVector(indices)));
+ alloc<Load>(dtype, buf.node(), ExprHandleVectorToExprVector(indices)));
}
ExprHandle Load::make(
return Load::make(buf.dtype(), buf, indices);
}
-Store::Store(Buf* buf, std::vector<Expr*> indices, Expr* value)
+// Stores `value` into `buf` at `indices`; indices are moved into the member
+// and then normalized to an integer dtype in place.
+Store::Store(BufPtr buf, std::vector<ExprPtr> indices, ExprPtr value)
: buf_(buf), indices_(std::move(indices)), value_(value) {
castIndicesToInts(indices_);
}
-Store* Store::make(
+// Handle-level factory: unwraps the BufHandle/ExprHandle wrappers to raw IR
+// node pointers and allocates the Store node.
+StorePtr Store::make(
const BufHandle& buf,
const std::vector<ExprHandle>& indices,
const ExprHandle& value) {
- return new Store(
+ return alloc<Store>(
buf.node(), ExprHandleVectorToExprVector(indices), value.node());
}
-Expr* flatten_index(
- const std::vector<Expr*>& dims,
- const std::vector<Expr*>& indices) {
+ExprPtr flatten_index(
+ const std::vector<ExprPtr>& dims,
+ const std::vector<ExprPtr>& indices) {
// Handle already flattened indices first
if (indices.size() == 1) {
return indices[0];
throw malformed_input("dimensions mismatch in flatten_index");
}
if (ndim == 0) {
- return new IntImm(0);
+ return alloc<IntImm>(0);
}
- std::vector<Expr*> strides(ndim);
+ std::vector<ExprPtr> strides(ndim);
// stride[i] = stride[i+1]*dims[i+1], i < ndim-1
// stride[i] = 1, i = ndim-1
- strides[ndim - 1] = new IntImm(1);
+ strides[ndim - 1] = alloc<IntImm>(1);
for (size_t i = 1; i < ndim; i++) {
- strides[ndim - 1 - i] = new Mul(strides[ndim - i], dims[ndim - i]);
+ strides[ndim - 1 - i] = alloc<Mul>(strides[ndim - i], dims[ndim - i]);
}
- Expr* total_index = new IntImm(0);
+ ExprPtr total_index = alloc<IntImm>(0);
for (const auto i : c10::irange(ndim)) {
- total_index = new Add(total_index, new Mul(indices[i], strides[i]));
+ total_index = alloc<Add>(total_index, alloc<Mul>(indices[i], strides[i]));
}
return total_index;
}
Dtype Intrinsics::IntrinsicsDtype(
IntrinsicsOp op_type,
- const std::vector<Expr*>& params) {
+ const std::vector<ExprPtr>& params) {
// TODO: check the op_type and make a real decision
// Doesnt this fail with kRand?
if (params.size() == 0) {
}
}
-ExternalCall* ExternalCall::make(
+// Handle-level factory for ExternalCall: unwraps each BufHandle argument to
+// its underlying Buf node before allocating the call node.
+ExternalCallPtr ExternalCall::make(
BufHandle buf,
const std::string& func_name,
const std::vector<BufHandle>& buf_args,
const std::vector<ExprHandle>& args) {
- std::vector<Buf*> buf_arg_nodes;
+ std::vector<BufPtr> buf_arg_nodes;
buf_arg_nodes.reserve(buf_args.size());
for (const BufHandle& buf_arg : buf_args) {
buf_arg_nodes.push_back(buf_arg.node());
}
- return new ExternalCall(
+ return alloc<ExternalCall>(
buf.node(), func_name, buf_arg_nodes, ExprHandleVectorToExprVector(args));
}
-std::vector<Expr*> ExprHandleVectorToExprVector(
+std::vector<ExprPtr> ExprHandleVectorToExprVector(
const std::vector<ExprHandle>& v) {
- std::vector<Expr*> result(v.size());
+ std::vector<ExprPtr> result(v.size());
for (const auto i : c10::irange(v.size())) {
result[i] = v[i].node();
}
}
std::vector<ExprHandle> ExprVectorToExprHandleVector(
- const std::vector<Expr*>& v) {
+ const std::vector<ExprPtr>& v) {
std::vector<ExprHandle> result(v.size());
for (const auto i : c10::irange(v.size())) {
result[i] = ExprHandle(v[i]);
return result;
}
-std::vector<Var*> VarHandleVectorToVarVector(const std::vector<VarHandle>& v) {
+// Unwraps a vector of VarHandle wrappers into the underlying Var node
+// pointers, preserving order.
+std::vector<VarPtr> VarHandleVectorToVarVector(
+ const std::vector<VarHandle>& v) {
+ std::vector<VarPtr> result(v.size());
for (const auto i : c10::irange(v.size())) {
result[i] = v[i].node();
}
return result;
}
-std::vector<VarHandle> VarVectorToVarHandleVector(const std::vector<Var*>& v) {
+std::vector<VarHandle> VarVectorToVarHandleVector(
+ const std::vector<VarPtr>& v) {
std::vector<VarHandle> result(v.size());
for (const auto i : c10::irange(v.size())) {
result[i] = VarHandle(v[i]);
return result;
}
-bool immediateIsNegative(Expr* e) {
-#define TYPE_CASE(Type, Name) \
- if (Name##Imm* imm = dynamic_cast<Name##Imm*>(e)) { \
- return imm->value() < 0; \
+bool immediateIsNegative(ExprPtr e) {
+#define TYPE_CASE(Type, Name) \
+ if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
+ return imm->value() < 0; \
}
AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
#undef TYPE_CASE
#include <c10/util/string_utils.h>
#include <torch/csrc/jit/tensorexpr/exceptions.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
+#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
#include <torch/csrc/jit/tensorexpr/stmt.h>
#include <ATen/core/ivalue.h>
class TORCH_API Cast : public ExprNode<Cast> {
public:
- Expr* src_value() const {
+ ExprPtr src_value() const {
return src_value_;
}
- void set_src_value(Expr* src_value) {
+ void set_src_value(ExprPtr src_value) {
src_value_ = src_value;
}
static ExprHandle make(Dtype dtype, const ExprHandle& src_value) {
- return ExprHandle(new Cast(dtype, src_value.node()));
+ return ExprHandle(alloc<Cast>(dtype, src_value.node()));
}
- Cast(Dtype dtype, Expr* src_value)
+ Cast(Dtype dtype, ExprPtr src_value)
: ExprNodeBase(dtype, kCast), src_value_(src_value) {}
bool isConstant() const override {
}
private:
- Expr* src_value_;
+ ExprPtr src_value_;
};
template <typename T>
// This is a bitwise cast, akin to bitcast in LLVM
class TORCH_API BitCast : public ExprNode<BitCast> {
public:
- Expr* src_value() const {
+ ExprPtr src_value() const {
return src_value_;
}
- void set_src_value(Expr* src_value) {
+ void set_src_value(ExprPtr src_value) {
src_value_ = src_value;
}
static ExprHandle make(Dtype dtype, const ExprHandle& src_value) {
- return ExprHandle(new BitCast(dtype, src_value.node()));
+ return ExprHandle(alloc<BitCast>(dtype, src_value.node()));
}
- BitCast(Dtype dtype, Expr* src_value)
+ BitCast(Dtype dtype, ExprPtr src_value)
: ExprNodeBase(dtype, kBitCast), src_value_(src_value) {
TORCH_CHECK(src_value_->dtype().byte_size() == dtype.byte_size());
}
}
private:
- Expr* src_value_;
+ ExprPtr src_value_;
};
template <typename T>
template <typename Op>
class BinaryOpNode : public ExprNode<Op> {
public:
- Expr* lhs() const {
+ ExprPtr lhs() const {
return this->lhs_;
}
- Expr* rhs() const {
+ ExprPtr rhs() const {
return this->rhs_;
}
- void set_lhs(Expr* lhs) {
+ void set_lhs(ExprPtr lhs) {
lhs_ = lhs;
}
- void set_rhs(Expr* rhs) {
+ void set_rhs(ExprPtr rhs) {
rhs_ = rhs;
}
static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) {
- return ExprHandle(new Op(lhs.node(), rhs.node()));
+ return ExprHandle(alloc<Op>(lhs.node(), rhs.node()));
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
BinaryOpNode(
- Expr* lhs_v,
- Expr* rhs_v,
+ ExprPtr lhs_v,
+ ExprPtr rhs_v,
IRNodeType expr_type,
ScalarType ret_type = ScalarType::Undefined)
: ExprNode<Op>(
rhs_(CastIfNeeded(rhs_v, ExprNode<Op>::dtype())) {}
private:
- static Expr* CastIfNeeded(Expr* expr, Dtype dst_dtype) {
+ static ExprPtr CastIfNeeded(ExprPtr expr, Dtype dst_dtype) {
if (expr->dtype() == dst_dtype) {
return expr;
}
return Cast::make(dst_dtype, ExprHandle(expr)).node();
}
- Expr* lhs_;
- Expr* rhs_;
+ ExprPtr lhs_;
+ ExprPtr rhs_;
};
class TORCH_API Add : public BinaryOpNode<Add> {
public:
- Add(Expr* lhs, Expr* rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kAdd) {}
+ Add(ExprPtr lhs, ExprPtr rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kAdd) {}
};
class TORCH_API Sub : public BinaryOpNode<Sub> {
public:
- Sub(Expr* lhs, Expr* rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kSub) {}
+ Sub(ExprPtr lhs, ExprPtr rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kSub) {}
};
class TORCH_API Mul : public BinaryOpNode<Mul> {
public:
- Mul(Expr* lhs, Expr* rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kMul) {}
+ Mul(ExprPtr lhs, ExprPtr rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kMul) {}
};
class TORCH_API Div : public BinaryOpNode<Div> {
public:
- Div(Expr* lhs, Expr* rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kDiv) {}
+ Div(ExprPtr lhs, ExprPtr rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kDiv) {}
};
class TORCH_API Mod : public BinaryOpNode<Mod> {
public:
- Mod(Expr* lhs, Expr* rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kMod) {}
+ Mod(ExprPtr lhs, ExprPtr rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kMod) {}
};
template <typename Op>
class BitwiseOpNode : public BinaryOpNode<Op> {
public:
- BitwiseOpNode(Expr* lhs, Expr* rhs, IRNodeType type)
+ BitwiseOpNode(ExprPtr lhs, ExprPtr rhs, IRNodeType type)
: BinaryOpNode<Op>(lhs, rhs, type) {}
static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) {
class TORCH_API And : public BitwiseOpNode<And> {
public:
- And(Expr* lhs, Expr* rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kAnd) {}
+ And(ExprPtr lhs, ExprPtr rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kAnd) {}
};
class TORCH_API Or : public BitwiseOpNode<Or> {
public:
- Or(Expr* lhs, Expr* rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kOr) {}
+ Or(ExprPtr lhs, ExprPtr rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kOr) {}
};
class TORCH_API Xor : public BitwiseOpNode<Xor> {
public:
- Xor(Expr* lhs, Expr* rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kXor) {}
+ Xor(ExprPtr lhs, ExprPtr rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kXor) {}
};
class TORCH_API Lshift : public BitwiseOpNode<Lshift> {
public:
- Lshift(Expr* lhs, Expr* rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kLshift) {}
+ Lshift(ExprPtr lhs, ExprPtr rhs)
+ : BitwiseOpNode(lhs, rhs, IRNodeType::kLshift) {}
};
class TORCH_API Rshift : public BitwiseOpNode<Rshift> {
public:
- Rshift(Expr* lhs, Expr* rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kRshift) {}
+ Rshift(ExprPtr lhs, ExprPtr rhs)
+ : BitwiseOpNode(lhs, rhs, IRNodeType::kRshift) {}
};
// TODO: add TORCH_API
bool propagate_nans_;
public:
- Max(Expr* lhs, Expr* rhs, bool propagate_nans)
+ Max(ExprPtr lhs, ExprPtr rhs, bool propagate_nans)
: BinaryOpNode(lhs, rhs, IRNodeType::kMax),
propagate_nans_(propagate_nans) {}
const ExprHandle& lhs,
const ExprHandle& rhs,
bool propagate_nans) {
- return ExprHandle(new Max(lhs.node(), rhs.node(), propagate_nans));
+ return ExprHandle(alloc<Max>(lhs.node(), rhs.node(), propagate_nans));
}
};
bool propagate_nans_;
public:
- Min(Expr* lhs, Expr* rhs, bool propagate_nans)
+ Min(ExprPtr lhs, ExprPtr rhs, bool propagate_nans)
: BinaryOpNode(lhs, rhs, IRNodeType::kMin),
propagate_nans_(propagate_nans) {}
const ExprHandle& lhs,
const ExprHandle& rhs,
bool propagate_nans) {
- return ExprHandle(new Min(lhs.node(), rhs.node(), propagate_nans));
+ return ExprHandle(alloc<Min>(lhs.node(), rhs.node(), propagate_nans));
}
};
return value_; \
} \
static ExprHandle make(Type value) { \
- return ExprHandle(new Name##Imm(value)); \
+ return ExprHandle(alloc<Name##Imm>(value)); \
} \
\
private: \
// Get immediate by ScalarType.
template <typename T>
-Expr* getImmediateByType(ScalarType immType, T initialVal) {
+ExprPtr getImmediateByType(ScalarType immType, T initialVal) {
switch (immType) {
#define TYPE_CASE(Type, Name) \
case ScalarType::Name: \
- return new Name##Imm(initialVal);
+ return alloc<Name##Imm>(initialVal);
// NOLINTNEXTLINE(bugprone-branch-clone)
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
#undef TYPE_CASE
}
template <typename T>
-Expr* getImmediateByType(Dtype dtype, T initialVal) {
+ExprPtr getImmediateByType(Dtype dtype, T initialVal) {
return getImmediateByType<T>(dtype.scalar_type(), initialVal);
}
template <typename T>
-T immediateAs(Expr* e) {
-#define TYPE_CASE(Type, Name) \
- if (Name##Imm* imm = dynamic_cast<Name##Imm*>(e)) { \
- return imm->value(); \
+T immediateAs(ExprPtr e) {
+#define TYPE_CASE(Type, Name) \
+ if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
+ return imm->value(); \
}
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
#undef TYPE_CASE
}
template <typename T>
-bool immediateEquals(Expr* e, T val) {
-#define TYPE_CASE(Type, Name) \
- if (Name##Imm* imm = dynamic_cast<Name##Imm*>(e)) { \
- return imm->value() == val; \
+bool immediateEquals(ExprPtr e, T val) {
+#define TYPE_CASE(Type, Name) \
+ if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
+ return imm->value() == val; \
}
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
#undef TYPE_CASE
return false;
}
-TORCH_API bool immediateIsNegative(Expr* e);
+TORCH_API bool immediateIsNegative(ExprPtr e);
// Represents a ramp vector node:
// [base, base + 1 * stride, ... , base + (lanes - 1) * stride]
class TORCH_API Ramp : public ExprNode<Ramp> {
public:
- Expr* base() const {
+ ExprPtr base() const {
return base_;
}
- Expr* stride() const {
+ ExprPtr stride() const {
return stride_;
}
- void set_base(Expr* base) {
+ void set_base(ExprPtr base) {
base_ = base;
}
- void set_stride(Expr* stride) {
+ void set_stride(ExprPtr stride) {
stride_ = stride;
}
if (stride.dtype() != base.dtype()) {
throw malformed_input("Bad stride in Ramp");
}
- return ExprHandle(new Ramp(base.node(), stride.node(), lanes));
+ return ExprHandle(alloc<Ramp>(base.node(), stride.node(), lanes));
}
int lanes() const {
return lanes_;
}
- Ramp(Expr* base, Expr* stride, int lanes)
+ Ramp(ExprPtr base, ExprPtr stride, int lanes)
: ExprNodeBase(Dtype(base->dtype(), lanes)),
base_(base),
stride_(stride),
lanes_(lanes) {}
private:
- Expr* base_;
- Expr* stride_;
+ ExprPtr base_;
+ ExprPtr stride_;
int lanes_;
};
class TORCH_API Load : public ExprNode<Load> {
public:
- Var* base_handle() const {
+ VarPtr base_handle() const {
return buf_->base_handle();
}
- std::vector<Expr*> indices() const {
+ std::vector<ExprPtr> indices() const {
return indices_;
}
- Expr* flat_index() const {
+ ExprPtr flat_index() const {
TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened.");
return indices_[0];
}
- Buf* buf() const {
+ BufPtr buf() const {
return buf_;
}
- void set_buf(Buf* buf) {
+ void set_buf(BufPtr buf) {
buf_ = buf;
}
- void set_indices(std::vector<Expr*> indices) {
+ void set_indices(std::vector<ExprPtr> indices) {
indices_ = indices;
}
const BufHandle& buf,
const std::vector<ExprHandle>& indices);
- Load(Dtype dtype, Buf* base_handle, std::vector<Expr*> indices);
- Load(Buf* base_handle, const std::vector<Expr*>& indices);
+ Load(Dtype dtype, BufPtr base_handle, std::vector<ExprPtr> indices);
+ Load(BufPtr base_handle, const std::vector<ExprPtr>& indices);
private:
- Buf* buf_;
- std::vector<Expr*> indices_;
+ BufPtr buf_;
+ std::vector<ExprPtr> indices_;
};
class TORCH_API Broadcast : public ExprNode<Broadcast> {
public:
- Expr* value() const {
+ ExprPtr value() const {
return value_;
}
- void set_value(Expr* value) {
+ void set_value(ExprPtr value) {
value_ = value;
}
return lanes_;
}
static ExprHandle make(const ExprHandle& value, int lanes) {
- return ExprHandle(new Broadcast(value.node(), lanes));
+ return ExprHandle(alloc<Broadcast>(value.node(), lanes));
}
- Broadcast(Expr* value, int lanes)
+ Broadcast(ExprPtr value, int lanes)
: ExprNodeBase(Dtype(value->dtype(), lanes)),
value_(value),
lanes_(lanes) {}
private:
- Expr* value_;
+ ExprPtr value_;
int lanes_;
};
class TORCH_API IfThenElse : public ExprNode<IfThenElse> {
public:
- Expr* condition() const {
+ ExprPtr condition() const {
return condition_;
}
// Lazily evaluated only if condition is true
- Expr* true_value() const {
+ ExprPtr true_value() const {
return true_;
}
// Lazily evaluated only if condition is false
- Expr* false_value() const {
+ ExprPtr false_value() const {
return false_;
}
- void set_condition(Expr* condition) {
+ void set_condition(ExprPtr condition) {
condition_ = condition;
}
- void set_true_value(Expr* true_value) {
+ void set_true_value(ExprPtr true_value) {
true_ = true_value;
}
- void set_false_value(Expr* false_value) {
+ void set_false_value(ExprPtr false_value) {
false_ = false_value;
}
if (t.dtype() != f.dtype()) {
throw malformed_input("Bad dtype in IfThenElse");
}
- return ExprHandle(new IfThenElse(c.node(), t.node(), f.node()));
+ return ExprHandle(alloc<IfThenElse>(c.node(), t.node(), f.node()));
}
- IfThenElse(Expr* c, Expr* t, Expr* f)
+ IfThenElse(ExprPtr c, ExprPtr t, ExprPtr f)
: ExprNodeBase(t->dtype()), condition_(c), true_(t), false_(f) {}
private:
- Expr* condition_;
- Expr* true_;
- Expr* false_;
+ ExprPtr condition_;
+ ExprPtr true_;
+ ExprPtr false_;
};
class TORCH_API CompareSelect : public ExprNode<CompareSelect> {
CompareSelectOperation compare_select_op() const {
return compare_op_;
}
- Expr* lhs() const {
+ ExprPtr lhs() const {
return this->lhs_;
}
- Expr* rhs() const {
+ ExprPtr rhs() const {
return this->rhs_;
}
- Expr* ret_val1() const {
+ ExprPtr ret_val1() const {
return this->ret_val1_;
}
- Expr* ret_val2() const {
+ ExprPtr ret_val2() const {
return this->ret_val2_;
}
- void set_lhs(Expr* lhs) {
+ void set_lhs(ExprPtr lhs) {
lhs_ = lhs;
}
- void set_rhs(Expr* rhs) {
+ void set_rhs(ExprPtr rhs) {
rhs_ = rhs;
}
- void set_ret_val1(Expr* ret_val1) {
+ void set_ret_val1(ExprPtr ret_val1) {
ret_val1_ = ret_val1;
}
- void set_ret_val2(Expr* ret_val2) {
+ void set_ret_val2(ExprPtr ret_val2) {
ret_val2_ = ret_val2;
}
if (lhs.dtype() != rhs.dtype()) {
throw malformed_input("bad dtype in CompareSelect");
}
- return ExprHandle(new CompareSelect(
+ return ExprHandle(alloc<CompareSelect>(
lhs.node(),
rhs.node(),
IntImm::make(1).node(),
if (lhs.dtype() != rhs.dtype() || ret_val1.dtype() != ret_val2.dtype()) {
throw malformed_input("bad dtype in CompareSelect");
}
- return ExprHandle(new CompareSelect(
+ return ExprHandle(alloc<CompareSelect>(
lhs.node(),
rhs.node(),
ret_val1.node(),
}
CompareSelect(
- Expr* lhs,
- Expr* rhs,
- Expr* ret_val1,
- Expr* ret_val2,
+ ExprPtr lhs,
+ ExprPtr rhs,
+ ExprPtr ret_val1,
+ ExprPtr ret_val2,
CompareSelectOperation cmp_op,
CompareSelectBias bias = kUnbiased)
: ExprNodeBase(ret_val1->dtype()),
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
CompareSelect(
- Expr* lhs,
- Expr* rhs,
+ ExprPtr lhs,
+ ExprPtr rhs,
CompareSelectOperation cmp_op,
CompareSelectBias bias = kUnbiased)
: ExprNodeBase(kInt),
lhs_(lhs),
rhs_(rhs),
- ret_val1_(new IntImm(1)),
- ret_val2_(new IntImm(0)),
+ ret_val1_(alloc<IntImm>(1)),
+ ret_val2_(alloc<IntImm>(0)),
compare_op_(cmp_op),
bias_(bias) {}
private:
- Expr* lhs_;
- Expr* rhs_;
- Expr* ret_val1_;
- Expr* ret_val2_;
+ ExprPtr lhs_;
+ ExprPtr rhs_;
+ ExprPtr ret_val1_;
+ ExprPtr ret_val2_;
CompareSelectOperation compare_op_;
CompareSelectBias bias_;
};
class TORCH_API Intrinsics : public ExprNode<Intrinsics> {
public:
static ExprHandle make(IntrinsicsOp op_type, const ExprHandle& v1) {
- return ExprHandle(new Intrinsics(op_type, v1.node()));
+ return ExprHandle(alloc<Intrinsics>(op_type, v1.node()));
}
static ExprHandle make(
IntrinsicsOp op_type,
const ExprHandle& v1,
const ExprHandle& v2) {
- return ExprHandle(new Intrinsics(op_type, v1.node(), v2.node()));
+ return ExprHandle(alloc<Intrinsics>(op_type, v1.node(), v2.node()));
}
static ExprHandle make(
IntrinsicsOp op_type,
const std::vector<ExprHandle>& params) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::vector<Expr*> params_nodes(params.size());
+ std::vector<ExprPtr> params_nodes(params.size());
for (size_t i = 0; i < params.size(); i++) {
params_nodes[i] = params[i].node();
}
- return ExprHandle(new Intrinsics(op_type, params_nodes));
+ return ExprHandle(alloc<Intrinsics>(op_type, params_nodes));
}
static ExprHandle make(IntrinsicsOp op_type, Dtype dtype) {
- return ExprHandle(new Intrinsics(op_type, dtype));
+ return ExprHandle(alloc<Intrinsics>(op_type, dtype));
}
IntrinsicsOp op_type() const {
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- Intrinsics(IntrinsicsOp op_type, Expr* v1)
+ Intrinsics(IntrinsicsOp op_type, ExprPtr v1)
: ExprNodeBase(IntrinsicsDtype(op_type, v1->dtype())),
params_({v1}),
op_type_(op_type) {
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- Intrinsics(IntrinsicsOp op_type, Expr* v1, Expr* v2)
+ Intrinsics(IntrinsicsOp op_type, ExprPtr v1, ExprPtr v2)
: ExprNodeBase(IntrinsicsDtype(op_type, v1->dtype(), v2->dtype())),
params_({v1, v2}),
op_type_(op_type) {
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- Intrinsics(IntrinsicsOp op_type, const std::vector<Expr*>& params)
+ Intrinsics(IntrinsicsOp op_type, const std::vector<ExprPtr>& params)
: ExprNodeBase(IntrinsicsDtype(op_type, params)),
params_(params),
op_type_(op_type) {
Intrinsics(
IntrinsicsOp op_type,
Dtype dtype,
- const std::vector<Expr*>& params)
+ const std::vector<ExprPtr>& params)
: ExprNodeBase(IntrinsicsDtype(op_type, dtype)),
params_(params),
op_type_(op_type) {
return params_.size();
}
- Expr* param(int index) const {
+ ExprPtr param(int index) const {
return params_[index];
}
- const std::vector<Expr*>& params() const {
+ const std::vector<ExprPtr>& params() const {
return params_;
}
- void set_params(std::vector<Expr*> params) {
+ void set_params(std::vector<ExprPtr> params) {
params_ = std::move(params);
}
static Dtype IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1, Dtype dt2);
static Dtype IntrinsicsDtype(
IntrinsicsOp op_type,
- const std::vector<Expr*>& params);
+ const std::vector<ExprPtr>& params);
- std::vector<Expr*> params_;
+ std::vector<ExprPtr> params_;
IntrinsicsOp op_type_;
};
class MaxTerm;
class MinTerm;
-TORCH_API std::vector<Expr*> ExprHandleVectorToExprVector(
+TORCH_API std::vector<ExprPtr> ExprHandleVectorToExprVector(
const std::vector<ExprHandle>&);
TORCH_API std::vector<ExprHandle> ExprVectorToExprHandleVector(
- const std::vector<Expr*>&);
-TORCH_API std::vector<Var*> VarHandleVectorToVarVector(
+ const std::vector<ExprPtr>&);
+TORCH_API std::vector<VarPtr> VarHandleVectorToVarVector(
const std::vector<VarHandle>&);
TORCH_API std::vector<VarHandle> VarVectorToVarHandleVector(
- const std::vector<Var*>&);
-TORCH_API Expr* flatten_index(
- const std::vector<Expr*>& dims,
- const std::vector<Expr*>& indices);
+ const std::vector<VarPtr>&);
+TORCH_API ExprPtr flatten_index(
+ const std::vector<ExprPtr>& dims,
+ const std::vector<ExprPtr>& indices);
} // namespace tensorexpr
} // namespace jit
namespace tensorexpr {
template <typename Op>
-static Expr* mutate_binary_op(
- BinaryOpNode<Op>* v,
+static ExprPtr mutate_binary_op(
+ NodePtr<BinaryOpNode<Op>> v,
IRCloner* cloner,
bool option = false) {
- Expr* lhs_new = v->lhs()->accept_mutator(cloner);
- Expr* rhs_new = v->rhs()->accept_mutator(cloner);
+ ExprPtr lhs_new = v->lhs()->accept_mutator(cloner);
+ ExprPtr rhs_new = v->rhs()->accept_mutator(cloner);
IRNodeType expr_type = v->expr_type();
switch (expr_type) {
case IRNodeType::kAdd:
- return new Add(lhs_new, rhs_new);
+ return alloc<Add>(lhs_new, rhs_new);
case IRNodeType::kSub:
- return new Sub(lhs_new, rhs_new);
+ return alloc<Sub>(lhs_new, rhs_new);
case IRNodeType::kMul:
- return new Mul(lhs_new, rhs_new);
+ return alloc<Mul>(lhs_new, rhs_new);
case IRNodeType::kDiv:
- return new Div(lhs_new, rhs_new);
+ return alloc<Div>(lhs_new, rhs_new);
case IRNodeType::kMod:
- return new Mod(lhs_new, rhs_new);
+ return alloc<Mod>(lhs_new, rhs_new);
case IRNodeType::kMax:
- return new Max(lhs_new, rhs_new, option);
+ return alloc<Max>(lhs_new, rhs_new, option);
case IRNodeType::kMin:
- return new Min(lhs_new, rhs_new, option);
+ return alloc<Min>(lhs_new, rhs_new, option);
case IRNodeType::kAnd:
- return new And(lhs_new, rhs_new);
+ return alloc<And>(lhs_new, rhs_new);
case IRNodeType::kOr:
- return new Or(lhs_new, rhs_new);
+ return alloc<Or>(lhs_new, rhs_new);
case IRNodeType::kXor:
- return new Xor(lhs_new, rhs_new);
+ return alloc<Xor>(lhs_new, rhs_new);
case IRNodeType::kLshift:
- return new Lshift(lhs_new, rhs_new);
+ return alloc<Lshift>(lhs_new, rhs_new);
case IRNodeType::kRshift:
- return new Rshift(lhs_new, rhs_new);
+ return alloc<Rshift>(lhs_new, rhs_new);
default:
throw unimplemented_lowering(v);
}
}
-Expr* IRCloner::mutate(Add* v) {
+ExprPtr IRCloner::mutate(AddPtr v) {
return mutate_binary_op(v, this);
}
-Expr* IRCloner::mutate(Sub* v) {
+ExprPtr IRCloner::mutate(SubPtr v) {
return mutate_binary_op(v, this);
}
-Expr* IRCloner::mutate(Mul* v) {
+ExprPtr IRCloner::mutate(MulPtr v) {
return mutate_binary_op(v, this);
}
-Expr* IRCloner::mutate(Div* v) {
+ExprPtr IRCloner::mutate(DivPtr v) {
return mutate_binary_op(v, this);
}
-Expr* IRCloner::mutate(Mod* v) {
+ExprPtr IRCloner::mutate(ModPtr v) {
return mutate_binary_op(v, this);
}
-Expr* IRCloner::mutate(And* v) {
+ExprPtr IRCloner::mutate(AndPtr v) {
return mutate_binary_op(v, this);
}
-Expr* IRCloner::mutate(Or* v) {
+ExprPtr IRCloner::mutate(OrPtr v) {
return mutate_binary_op(v, this);
}
-Expr* IRCloner::mutate(Xor* v) {
+ExprPtr IRCloner::mutate(XorPtr v) {
return mutate_binary_op(v, this);
}
-Expr* IRCloner::mutate(Lshift* v) {
+ExprPtr IRCloner::mutate(LshiftPtr v) {
return mutate_binary_op(v, this);
}
-Expr* IRCloner::mutate(Rshift* v) {
+ExprPtr IRCloner::mutate(RshiftPtr v) {
return mutate_binary_op(v, this);
}
-Expr* IRCloner::mutate(Max* v) {
+ExprPtr IRCloner::mutate(MaxPtr v) {
return mutate_binary_op(v, this, v->propagate_nans());
}
-Expr* IRCloner::mutate(Min* v) {
+ExprPtr IRCloner::mutate(MinPtr v) {
return mutate_binary_op(v, this, v->propagate_nans());
}
-Expr* IRCloner::mutate(CompareSelect* v) {
- Expr* lhs_new = v->lhs()->accept_mutator(this);
- Expr* rhs_new = v->rhs()->accept_mutator(this);
- Expr* retval1_new = v->ret_val1()->accept_mutator(this);
- Expr* retval2_new = v->ret_val2()->accept_mutator(this);
- return new CompareSelect(
+ExprPtr IRCloner::mutate(CompareSelectPtr v) {
+ ExprPtr lhs_new = v->lhs()->accept_mutator(this);
+ ExprPtr rhs_new = v->rhs()->accept_mutator(this);
+ ExprPtr retval1_new = v->ret_val1()->accept_mutator(this);
+ ExprPtr retval2_new = v->ret_val2()->accept_mutator(this);
+ return alloc<CompareSelect>(
lhs_new,
rhs_new,
retval1_new,
}
// NOLINTNEXTLINE
-#define IMM_MUTATE_DEFINE(_1, Name) \
- Expr* IRCloner::mutate(Name##Imm* v) { \
- return v; \
+#define IMM_MUTATE_DEFINE(_1, Name) \
+ ExprPtr IRCloner::mutate(Name##ImmPtr v) { \
+ return v; \
}
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DEFINE);
#undef IMM_MUTATE_DEFINE
-Expr* IRCloner::mutate(Cast* v) {
- Expr* src_value_new = v->src_value()->accept_mutator(this);
- return new Cast(v->dtype(), src_value_new);
+ExprPtr IRCloner::mutate(CastPtr v) {
+ ExprPtr src_value_new = v->src_value()->accept_mutator(this);
+ return alloc<Cast>(v->dtype(), src_value_new);
}
-Expr* IRCloner::mutate(BitCast* v) {
- Expr* src_value_new = v->src_value()->accept_mutator(this);
- return new BitCast(v->dtype(), src_value_new);
+ExprPtr IRCloner::mutate(BitCastPtr v) {
+ ExprPtr src_value_new = v->src_value()->accept_mutator(this);
+ return alloc<BitCast>(v->dtype(), src_value_new);
}
-Expr* IRCloner::mutate(Ramp* v) {
- Expr* base_new = v->base()->accept_mutator(this);
- Expr* stride_new = v->stride()->accept_mutator(this);
- return new Ramp(base_new, stride_new, v->lanes());
+ExprPtr IRCloner::mutate(RampPtr v) {
+ ExprPtr base_new = v->base()->accept_mutator(this);
+ ExprPtr stride_new = v->stride()->accept_mutator(this);
+ return alloc<Ramp>(base_new, stride_new, v->lanes());
}
-Expr* IRCloner::mutate(Load* v) {
- std::vector<Expr*> indices_new;
+ExprPtr IRCloner::mutate(LoadPtr v) {
+ std::vector<ExprPtr> indices_new;
indices_new.reserve(v->indices().size());
- for (Expr* ind : v->indices()) {
+ for (ExprPtr ind : v->indices()) {
indices_new.push_back(ind->accept_mutator(this));
}
- Buf* buf_new = dynamic_cast<Buf*>(v->buf()->accept_mutator(this));
- return new Load(v->dtype(), buf_new, indices_new);
+ BufPtr buf_new = to<Buf>(v->buf()->accept_mutator(this));
+ return alloc<Load>(v->dtype(), buf_new, indices_new);
}
// We do not clone Vars since the original IR and cloned IR are expected to
// share the underlying variables.
-Expr* IRCloner::mutate(Var* v) {
+ExprPtr IRCloner::mutate(VarPtr v) {
return v;
}
// initializers, this is the expected usage of clone at this point.
//
// TODO: Revisit this if Bufs need to be cloned as well.
-Expr* IRCloner::mutate(Buf* v) {
+ExprPtr IRCloner::mutate(BufPtr v) {
return v;
}
-Expr* IRCloner::mutate(Broadcast* v) {
+ExprPtr IRCloner::mutate(BroadcastPtr v) {
int lanes = v->lanes();
- Expr* value_new = v->value()->accept_mutator(this);
- return new Broadcast(value_new, lanes);
+ ExprPtr value_new = v->value()->accept_mutator(this);
+ return alloc<Broadcast>(value_new, lanes);
}
-Expr* IRCloner::mutate(IfThenElse* v) {
- Expr* condition_new = v->condition()->accept_mutator(this);
- Expr* true_value_new = v->true_value()->accept_mutator(this);
- Expr* false_value_new = v->false_value()->accept_mutator(this);
+ExprPtr IRCloner::mutate(IfThenElsePtr v) {
+ ExprPtr condition_new = v->condition()->accept_mutator(this);
+ ExprPtr true_value_new = v->true_value()->accept_mutator(this);
+ ExprPtr false_value_new = v->false_value()->accept_mutator(this);
- return new IfThenElse(condition_new, true_value_new, false_value_new);
+ return alloc<IfThenElse>(condition_new, true_value_new, false_value_new);
}
-Expr* IRCloner::mutate(Intrinsics* v) {
- std::vector<Expr*> params_new;
+ExprPtr IRCloner::mutate(IntrinsicsPtr v) {
+ std::vector<ExprPtr> params_new;
params_new.reserve(v->nparams());
for (auto param : v->params()) {
params_new.push_back(param->accept_mutator(this));
}
- return new Intrinsics(v->op_type(), v->dtype(), params_new);
+ return alloc<Intrinsics>(v->op_type(), v->dtype(), params_new);
}
-Expr* IRCloner::mutate(Term* v) {
- Expr* scalar_new = v->scalar()->accept_mutator(this);
+ExprPtr IRCloner::mutate(TermPtr v) {
+ ExprPtr scalar_new = v->scalar()->accept_mutator(this);
- std::vector<Expr*> variables_new;
+ std::vector<ExprPtr> variables_new;
variables_new.reserve(v->variables().size());
- for (auto* t : v->variables()) {
+ for (auto t : v->variables()) {
variables_new.push_back(t->accept_mutator(this));
}
- return new Term(v->hasher(), scalar_new, variables_new);
+ return alloc<Term>(v->hasher(), scalar_new, variables_new);
}
-Expr* IRCloner::mutate(Polynomial* v) {
- Expr* scalar_new = v->scalar()->accept_mutator(this);
+ExprPtr IRCloner::mutate(PolynomialPtr v) {
+ ExprPtr scalar_new = v->scalar()->accept_mutator(this);
- std::vector<Term*> variables_new;
+ std::vector<TermPtr> variables_new;
variables_new.reserve(v->variables().size());
- for (auto* t : v->variables()) {
- variables_new.push_back(static_cast<Term*>(t->accept_mutator(this)));
+ for (auto t : v->variables()) {
+ variables_new.push_back(static_to<Term>(t->accept_mutator(this)));
}
- return new Polynomial(v->hasher(), scalar_new, variables_new);
+ return alloc<Polynomial>(v->hasher(), scalar_new, variables_new);
}
-Expr* IRCloner::mutate(RoundOff* v) {
- return new RoundOff(
+ExprPtr IRCloner::mutate(RoundOffPtr v) {
+ return alloc<RoundOff>(
v->lhs()->accept_mutator(this), v->rhs()->accept_mutator(this));
}
-Expr* IRCloner::mutate(MaxTerm* v) {
- Expr* scalar_new = v->scalar() ? v->scalar()->accept_mutator(this) : nullptr;
+ExprPtr IRCloner::mutate(MaxTermPtr v) {
+ ExprPtr scalar_new =
+ v->scalar() ? v->scalar()->accept_mutator(this) : nullptr;
- std::vector<Expr*> variables_new;
+ std::vector<ExprPtr> variables_new;
variables_new.reserve(v->variables().size());
- for (auto* t : v->variables()) {
+ for (auto t : v->variables()) {
variables_new.push_back(t->accept_mutator(this));
}
- return new MaxTerm(
+ return alloc<MaxTerm>(
v->hasher(), scalar_new, v->propagate_nans(), variables_new);
}
-Expr* IRCloner::mutate(MinTerm* v) {
- Expr* scalar_new = v->scalar() ? v->scalar()->accept_mutator(this) : nullptr;
+ExprPtr IRCloner::mutate(MinTermPtr v) {
+ ExprPtr scalar_new =
+ v->scalar() ? v->scalar()->accept_mutator(this) : nullptr;
- std::vector<Expr*> variables_new;
+ std::vector<ExprPtr> variables_new;
variables_new.reserve(v->variables().size());
- for (auto* t : v->variables()) {
+ for (auto t : v->variables()) {
variables_new.push_back(t->accept_mutator(this));
}
- return new MinTerm(
+ return alloc<MinTerm>(
v->hasher(), scalar_new, v->propagate_nans(), variables_new);
}
-Expr* IRCloner::mutate(ReduceOp* v) {
- Expr* body_new = v->body()->accept_mutator(this);
+ExprPtr IRCloner::mutate(ReduceOpPtr v) {
+ ExprPtr body_new = v->body()->accept_mutator(this);
- std::vector<Var*> reduce_args_new;
+ std::vector<VarPtr> reduce_args_new;
reduce_args_new.reserve(v->reduce_args().size());
- for (auto* r : v->reduce_args()) {
- reduce_args_new.push_back(static_cast<Var*>(r->accept_mutator(this)));
+ for (auto r : v->reduce_args()) {
+ reduce_args_new.push_back(static_to<Var>(r->accept_mutator(this)));
}
- return new ReduceOp(body_new, reduce_args_new, v->reducer());
+ return alloc<ReduceOp>(body_new, reduce_args_new, v->reducer());
}
-Stmt* IRCloner::mutate(For* v) {
+StmtPtr IRCloner::mutate(ForPtr v) {
auto start_new = v->start()->accept_mutator(this);
auto stop_new = v->stop()->accept_mutator(this);
auto body_new = v->body()->accept_mutator(this);
- return new For(v->var(), start_new, stop_new, body_new, v->loop_options());
+ return alloc<For>(v->var(), start_new, stop_new, body_new, v->loop_options());
}
-Stmt* IRCloner::mutate(Block* v) {
- std::vector<Stmt*> stmts_new;
+StmtPtr IRCloner::mutate(BlockPtr v) {
+ std::vector<StmtPtr> stmts_new;
stmts_new.reserve(v->nstmts());
- for (Stmt* stmt : *v) {
+ for (StmtPtr stmt : *v) {
stmts_new.push_back(stmt->accept_mutator(this));
}
- return new Block(stmts_new);
+ return alloc<Block>(stmts_new);
}
-Stmt* IRCloner::mutate(Store* v) {
- std::vector<Expr*> indices_new;
+StmtPtr IRCloner::mutate(StorePtr v) {
+ std::vector<ExprPtr> indices_new;
indices_new.reserve(v->indices().size());
for (auto ind : v->indices()) {
indices_new.push_back(ind->accept_mutator(this));
}
auto value_new = v->value()->accept_mutator(this);
- Buf* buf_new = dynamic_cast<Buf*>(v->buf()->accept_mutator(this));
- return new Store(buf_new, indices_new, value_new);
+ BufPtr buf_new = to<Buf>(v->buf()->accept_mutator(this));
+ return alloc<Store>(buf_new, indices_new, value_new);
}
-Stmt* IRCloner::mutate(AtomicAdd* v) {
- std::vector<Expr*> indices_new;
+StmtPtr IRCloner::mutate(AtomicAddPtr v) {
+ std::vector<ExprPtr> indices_new;
indices_new.reserve(v->indices().size());
for (auto ind : v->indices()) {
indices_new.push_back(ind->accept_mutator(this));
}
auto value_new = v->value()->accept_mutator(this);
- Buf* buf_new = dynamic_cast<Buf*>(v->buf()->accept_mutator(this));
- return new AtomicAdd(buf_new, indices_new, value_new);
+ BufPtr buf_new = to<Buf>(v->buf()->accept_mutator(this));
+ return alloc<AtomicAdd>(buf_new, indices_new, value_new);
}
-Stmt* IRCloner::mutate(Allocate* v) {
- Buf* buf_new = dynamic_cast<Buf*>(v->buf()->accept_mutator(this));
- return new Allocate(buf_new);
+StmtPtr IRCloner::mutate(AllocatePtr v) {
+ BufPtr buf_new = to<Buf>(v->buf()->accept_mutator(this));
+ return alloc<Allocate>(buf_new);
}
-Stmt* IRCloner::mutate(Free* v) {
- Buf* buf_new = dynamic_cast<Buf*>(v->buf()->accept_mutator(this));
- return new Free(buf_new);
+StmtPtr IRCloner::mutate(FreePtr v) {
+ BufPtr buf_new = to<Buf>(v->buf()->accept_mutator(this));
+ return alloc<Free>(buf_new);
}
-Stmt* IRCloner::mutate(SyncThreads* v) {
- return new SyncThreads();
+StmtPtr IRCloner::mutate(SyncThreadsPtr v) {
+ return alloc<SyncThreads>();
}
-Stmt* IRCloner::mutate(ExternalCall* v) {
- Buf* buf_new = dynamic_cast<Buf*>(v->buf()->accept_mutator(this));
+StmtPtr IRCloner::mutate(ExternalCallPtr v) {
+ BufPtr buf_new = to<Buf>(v->buf()->accept_mutator(this));
- std::vector<Buf*> buf_args_new;
+ std::vector<BufPtr> buf_args_new;
buf_args_new.reserve(v->buf_args().size());
- for (Buf* buf_arg : v->buf_args()) {
- buf_args_new.push_back(dynamic_cast<Buf*>(buf_arg->accept_mutator(this)));
+ for (BufPtr buf_arg : v->buf_args()) {
+ buf_args_new.push_back(to<Buf>(buf_arg->accept_mutator(this)));
}
- std::vector<Expr*> args_new;
+ std::vector<ExprPtr> args_new;
args_new.reserve(v->args().size());
- for (Expr* arg : v->args()) {
+ for (ExprPtr arg : v->args()) {
args_new.push_back(arg->accept_mutator(this));
}
- return new ExternalCall(buf_new, v->func_name(), buf_args_new, args_new);
+ return alloc<ExternalCall>(buf_new, v->func_name(), buf_args_new, args_new);
}
-Stmt* IRCloner::mutate(Let* v) {
+StmtPtr IRCloner::mutate(LetPtr v) {
auto value_new = v->value()->accept_mutator(this);
- return new Let(v->var(), value_new);
+ return alloc<Let>(v->var(), value_new);
}
-Stmt* IRCloner::mutate(Cond* v) {
+StmtPtr IRCloner::mutate(CondPtr v) {
auto condition_new = v->condition()->accept_mutator(this);
- Stmt* true_old = v->true_stmt();
- Stmt* false_old = v->false_stmt();
- Stmt* true_new = true_old ? true_old->accept_mutator(this) : true_old;
- Stmt* false_new = false_old ? false_old->accept_mutator(this) : false_old;
- return new Cond(condition_new, true_new, false_new);
+ StmtPtr true_old = v->true_stmt();
+ StmtPtr false_old = v->false_stmt();
+ StmtPtr true_new = true_old ? true_old->accept_mutator(this) : true_old;
+ StmtPtr false_new = false_old ? false_old->accept_mutator(this) : false_old;
+ return alloc<Cond>(condition_new, true_new, false_new);
}
-Stmt* Stmt::clone(Stmt* s) {
+StmtPtr Stmt::clone(StmtPtr s) {
IRCloner cloner;
- Stmt* cloned = s->accept_mutator(&cloner);
+ StmtPtr cloned = s->accept_mutator(&cloner);
set_parent(cloned, nullptr);
return cloned;
}
-Expr* Expr::clone(Expr* e) {
+ExprPtr Expr::clone(ExprPtr e) {
IRCloner cloner;
return e->accept_mutator(&cloner);
}
namespace jit {
namespace tensorexpr {
-class Add;
-class Sub;
-class Mul;
-class Div;
-class Mod;
-class Max;
-class Min;
-class And;
-class Or;
-class Xor;
-class Lshift;
-class Rshift;
-class CompareSelect;
-
-#define IMM_DECLARE(Type, Name) class Name##Imm;
-AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE);
-#undef IMM_DECLARE
-
-class Cast;
-class BitCast;
-class Var;
-class Buf;
-class Ramp;
-class Load;
-class For;
-class Block;
-class Store;
-class Broadcast;
-class IfThenElse;
-class ExprHandle;
-class Expr;
-class Intrinsics;
-class Allocate;
-class Free;
-class Let;
-class Cond;
-class Stmt;
-class Term;
-class Polynomial;
-class RoundOff;
-class MaxTerm;
-class MinTerm;
-class ReduceOp;
-class AtomicAdd;
-class SyncThreads;
-class ExternalCall;
-
class TORCH_API IRCloner : public IRMutator {
public:
~IRCloner() override = default;
- Expr* mutate(Add* v) override;
- Expr* mutate(Sub* v) override;
- Expr* mutate(Mul* v) override;
- Expr* mutate(Div* v) override;
- Expr* mutate(Mod* v) override;
- Expr* mutate(Max* v) override;
- Expr* mutate(Min* v) override;
- Expr* mutate(And* v) override;
- Expr* mutate(Or* v) override;
- Expr* mutate(Xor* v) override;
- Expr* mutate(Lshift* v) override;
- Expr* mutate(Rshift* v) override;
- Expr* mutate(CompareSelect* v) override;
-#define IMM_MUTATE_DECLARE(Type, Name) Expr* mutate(Name##Imm* v) override;
+ ExprPtr mutate(AddPtr v) override;
+ ExprPtr mutate(SubPtr v) override;
+ ExprPtr mutate(MulPtr v) override;
+ ExprPtr mutate(DivPtr v) override;
+ ExprPtr mutate(ModPtr v) override;
+ ExprPtr mutate(MaxPtr v) override;
+ ExprPtr mutate(MinPtr v) override;
+ ExprPtr mutate(AndPtr v) override;
+ ExprPtr mutate(OrPtr v) override;
+ ExprPtr mutate(XorPtr v) override;
+ ExprPtr mutate(LshiftPtr v) override;
+ ExprPtr mutate(RshiftPtr v) override;
+ ExprPtr mutate(CompareSelectPtr v) override;
+#define IMM_MUTATE_DECLARE(Type, Name) ExprPtr mutate(Name##ImmPtr v) override;
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DECLARE);
#undef IMM_MUTATE_DECLARE
- Expr* mutate(Cast* v) override;
- Expr* mutate(BitCast* v) override;
- Expr* mutate(Var* v) override;
- Expr* mutate(Buf* v) override;
- Expr* mutate(Ramp* v) override;
- Expr* mutate(Load* v) override;
- Expr* mutate(Broadcast* v) override;
- Expr* mutate(IfThenElse* v) override;
- Expr* mutate(Intrinsics* v) override;
+ ExprPtr mutate(CastPtr v) override;
+ ExprPtr mutate(BitCastPtr v) override;
+ ExprPtr mutate(VarPtr v) override;
+ ExprPtr mutate(BufPtr v) override;
+ ExprPtr mutate(RampPtr v) override;
+ ExprPtr mutate(LoadPtr v) override;
+ ExprPtr mutate(BroadcastPtr v) override;
+ ExprPtr mutate(IfThenElsePtr v) override;
+ ExprPtr mutate(IntrinsicsPtr v) override;
- Expr* mutate(Term* v) override;
- Expr* mutate(Polynomial* v) override;
- Expr* mutate(RoundOff* v) override;
- Expr* mutate(MaxTerm* v) override;
- Expr* mutate(MinTerm* v) override;
+ ExprPtr mutate(TermPtr v) override;
+ ExprPtr mutate(PolynomialPtr v) override;
+ ExprPtr mutate(RoundOffPtr v) override;
+ ExprPtr mutate(MaxTermPtr v) override;
+ ExprPtr mutate(MinTermPtr v) override;
- Expr* mutate(ReduceOp* v) override;
+ ExprPtr mutate(ReduceOpPtr v) override;
- Stmt* mutate(For* v) override;
- Stmt* mutate(Block* v) override;
- Stmt* mutate(Store* v) override;
- Stmt* mutate(AtomicAdd* v) override;
- Stmt* mutate(SyncThreads* v) override;
- Stmt* mutate(ExternalCall* v) override;
+ StmtPtr mutate(ForPtr v) override;
+ StmtPtr mutate(BlockPtr v) override;
+ StmtPtr mutate(StorePtr v) override;
+ StmtPtr mutate(AtomicAddPtr v) override;
+ StmtPtr mutate(SyncThreadsPtr v) override;
+ StmtPtr mutate(ExternalCallPtr v) override;
- Stmt* mutate(Allocate* v) override;
- Stmt* mutate(Free* v) override;
- Stmt* mutate(Let* v) override;
- Stmt* mutate(Cond* v) override;
+ StmtPtr mutate(AllocatePtr v) override;
+ StmtPtr mutate(FreePtr v) override;
+ StmtPtr mutate(LetPtr v) override;
+ StmtPtr mutate(CondPtr v) override;
};
} // namespace tensorexpr
namespace tensorexpr {
template <typename Op>
-static Expr* mutate_binary_op(
+static ExprPtr mutate_binary_op(
BinaryOpNode<Op>* v,
IRMutator* mutator,
bool option = false) {
- Expr* lhs = v->lhs();
- Expr* rhs = v->rhs();
- Expr* lhs_new = lhs->accept_mutator(mutator);
- Expr* rhs_new = rhs->accept_mutator(mutator);
+ ExprPtr lhs = v->lhs();
+ ExprPtr rhs = v->rhs();
+ ExprPtr lhs_new = lhs->accept_mutator(mutator);
+ ExprPtr rhs_new = rhs->accept_mutator(mutator);
if (lhs != lhs_new) {
v->set_lhs(lhs_new);
}
return v;
}
-Expr* IRMutator::mutate(Add* v) {
+ExprPtr IRMutator::mutate(AddPtr v) {
return mutate_binary_op(v, this);
}
-Expr* IRMutator::mutate(Sub* v) {
+ExprPtr IRMutator::mutate(SubPtr v) {
return mutate_binary_op(v, this);
}
-Expr* IRMutator::mutate(Mul* v) {
+ExprPtr IRMutator::mutate(MulPtr v) {
return mutate_binary_op(v, this);
}
-Expr* IRMutator::mutate(Div* v) {
+ExprPtr IRMutator::mutate(DivPtr v) {
return mutate_binary_op(v, this);
}
-Expr* IRMutator::mutate(Mod* v) {
+ExprPtr IRMutator::mutate(ModPtr v) {
return mutate_binary_op(v, this);
}
-Expr* IRMutator::mutate(And* v) {
+ExprPtr IRMutator::mutate(AndPtr v) {
return mutate_binary_op(v, this);
}
-Expr* IRMutator::mutate(Or* v) {
+ExprPtr IRMutator::mutate(OrPtr v) {
return mutate_binary_op(v, this);
}
-Expr* IRMutator::mutate(Xor* v) {
+ExprPtr IRMutator::mutate(XorPtr v) {
return mutate_binary_op(v, this);
}
-Expr* IRMutator::mutate(Lshift* v) {
+ExprPtr IRMutator::mutate(LshiftPtr v) {
return mutate_binary_op(v, this);
}
-Expr* IRMutator::mutate(Rshift* v) {
+ExprPtr IRMutator::mutate(RshiftPtr v) {
return mutate_binary_op(v, this);
}
-Expr* IRMutator::mutate(Max* v) {
+ExprPtr IRMutator::mutate(MaxPtr v) {
return mutate_binary_op(v, this, v->propagate_nans());
}
-Expr* IRMutator::mutate(Min* v) {
+ExprPtr IRMutator::mutate(MinPtr v) {
return mutate_binary_op(v, this, v->propagate_nans());
}
-Expr* IRMutator::mutate(CompareSelect* v) {
- Expr* lhs = v->lhs();
- Expr* rhs = v->rhs();
- Expr* ret_val1 = v->ret_val1();
- Expr* ret_val2 = v->ret_val2();
- Expr* lhs_new = lhs->accept_mutator(this);
- Expr* rhs_new = rhs->accept_mutator(this);
- Expr* ret_val1_new = ret_val1->accept_mutator(this);
- Expr* ret_val2_new = ret_val2->accept_mutator(this);
+ExprPtr IRMutator::mutate(CompareSelectPtr v) {
+ ExprPtr lhs = v->lhs();
+ ExprPtr rhs = v->rhs();
+ ExprPtr ret_val1 = v->ret_val1();
+ ExprPtr ret_val2 = v->ret_val2();
+ ExprPtr lhs_new = lhs->accept_mutator(this);
+ ExprPtr rhs_new = rhs->accept_mutator(this);
+ ExprPtr ret_val1_new = ret_val1->accept_mutator(this);
+ ExprPtr ret_val2_new = ret_val2->accept_mutator(this);
if (lhs != lhs_new) {
v->set_lhs(lhs_new);
}
}
// NOLINTNEXTLINE
-#define IMM_MUTATE_DEFINE(_1, Name) \
- Expr* IRMutator::mutate(Name##Imm* v) { \
- return v; \
+#define IMM_MUTATE_DEFINE(_1, Name) \
+ ExprPtr IRMutator::mutate(Name##ImmPtr v) { \
+ return v; \
}
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DEFINE);
#undef IMM_MUTATE_DEFINE
-Expr* IRMutator::mutate(Cast* v) {
- Expr* src_value = v->src_value();
- Expr* src_value_new = src_value->accept_mutator(this);
+ExprPtr IRMutator::mutate(CastPtr v) {
+ ExprPtr src_value = v->src_value();
+ ExprPtr src_value_new = src_value->accept_mutator(this);
if (src_value != src_value_new) {
v->set_src_value(src_value_new);
}
return v;
}
-Expr* IRMutator::mutate(BitCast* v) {
- Expr* src_value = v->src_value();
- Expr* src_value_new = src_value->accept_mutator(this);
+ExprPtr IRMutator::mutate(BitCastPtr v) {
+ ExprPtr src_value = v->src_value();
+ ExprPtr src_value_new = src_value->accept_mutator(this);
if (src_value != src_value_new) {
v->set_src_value(src_value_new);
}
return v;
}
-Expr* IRMutator::mutate(Var* v) {
+ExprPtr IRMutator::mutate(VarPtr v) {
return v;
}
-Expr* IRMutator::mutate(Ramp* v) {
- Expr* base = v->base();
- Expr* stride = v->stride();
- Expr* base_new = base->accept_mutator(this);
- Expr* stride_new = stride->accept_mutator(this);
+ExprPtr IRMutator::mutate(RampPtr v) {
+ ExprPtr base = v->base();
+ ExprPtr stride = v->stride();
+ ExprPtr base_new = base->accept_mutator(this);
+ ExprPtr stride_new = stride->accept_mutator(this);
if (base != base_new) {
v->set_base(base_new);
}
return v;
}
-Expr* IRMutator::mutate(Load* v) {
- Buf* buf = v->buf();
+ExprPtr IRMutator::mutate(LoadPtr v) {
+ BufPtr buf = v->buf();
bool any_index_changed = false;
- std::vector<Expr*> indices_new;
+ std::vector<ExprPtr> indices_new;
indices_new.reserve(v->indices().size());
- for (Expr* ind : v->indices()) {
- Expr* new_ind = ind->accept_mutator(this);
+ for (ExprPtr ind : v->indices()) {
+ ExprPtr new_ind = ind->accept_mutator(this);
if (new_ind != ind) {
any_index_changed = true;
}
indices_new.push_back(new_ind);
}
- Buf* buf_new = dynamic_cast<Buf*>(buf->accept_mutator(this));
+ BufPtr buf_new = to<Buf>(buf->accept_mutator(this));
if (buf != buf_new) {
v->set_buf(buf_new);
return v;
}
-Expr* IRMutator::mutate(Buf* v) {
- Var* var = v->base_handle();
- Var* var_new =
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
- dynamic_cast<Var*>(const_cast<Expr*>(var->accept_mutator(this)));
+ExprPtr IRMutator::mutate(BufPtr v) {
+ VarPtr var = v->base_handle();
+ VarPtr var_new = to<Var>(var->accept_mutator(this));
if (!var_new) {
return nullptr;
}
bool dims_changed = false;
- std::vector<Expr*> dims_old = v->dims();
- std::vector<Expr*> dims_new(dims_old.size());
+ std::vector<ExprPtr> dims_old = v->dims();
+ std::vector<ExprPtr> dims_new(dims_old.size());
for (const auto i : c10::irange(dims_old.size())) {
dims_new[i] = dims_old[i]->accept_mutator(this);
dims_changed |= (dims_new[i] != dims_old[i]);
return v;
}
-Expr* IRMutator::mutate(Broadcast* v) {
- Expr* value = v->value();
- Expr* value_new = value->accept_mutator(this);
+ExprPtr IRMutator::mutate(BroadcastPtr v) {
+ ExprPtr value = v->value();
+ ExprPtr value_new = value->accept_mutator(this);
if (value != value_new) {
v->set_value(value_new);
}
return v;
}
-Expr* IRMutator::mutate(IfThenElse* v) {
- Expr* condition = v->condition();
- Expr* true_value = v->true_value();
- Expr* false_value = v->false_value();
- Expr* condition_new = condition->accept_mutator(this);
- Expr* true_value_new = true_value->accept_mutator(this);
- Expr* false_value_new = false_value->accept_mutator(this);
+ExprPtr IRMutator::mutate(IfThenElsePtr v) {
+ ExprPtr condition = v->condition();
+ ExprPtr true_value = v->true_value();
+ ExprPtr false_value = v->false_value();
+ ExprPtr condition_new = condition->accept_mutator(this);
+ ExprPtr true_value_new = true_value->accept_mutator(this);
+ ExprPtr false_value_new = false_value->accept_mutator(this);
if (condition != condition_new) {
v->set_condition(condition_new);
return v;
}
-Expr* IRMutator::mutate(Intrinsics* v) {
- std::vector<Expr*> params(v->nparams());
+ExprPtr IRMutator::mutate(IntrinsicsPtr v) {
+ std::vector<ExprPtr> params(v->nparams());
bool any_change = false;
for (int i = 0; i < v->nparams(); i++) {
- Expr* value = v->param(i);
- Expr* value_new = value->accept_mutator(this);
+ ExprPtr value = v->param(i);
+ ExprPtr value_new = value->accept_mutator(this);
if (value != value_new) {
any_change = true;
}
return v;
}
-Expr* IRMutator::mutate(Term* v) {
- Expr* newScalar = v->scalar()->accept_mutator(this);
+ExprPtr IRMutator::mutate(TermPtr v) {
+ ExprPtr newScalar = v->scalar()->accept_mutator(this);
- std::vector<Expr*> variables;
- for (auto* t : v->variables()) {
+ std::vector<ExprPtr> variables;
+ for (auto t : v->variables()) {
variables.push_back(t->accept_mutator(this));
}
- return new Term(v->hasher(), newScalar, variables);
+ return alloc<Term>(v->hasher(), newScalar, variables);
}
-Expr* IRMutator::mutate(Polynomial* v) {
- Expr* newScalar = v->scalar()->accept_mutator(this);
+ExprPtr IRMutator::mutate(PolynomialPtr v) {
+ ExprPtr newScalar = v->scalar()->accept_mutator(this);
- std::vector<Term*> variables;
- for (auto* t : v->variables()) {
- variables.push_back(static_cast<Term*>(t->accept_mutator(this)));
+ std::vector<TermPtr> variables;
+ for (auto t : v->variables()) {
+ variables.push_back(static_to<Term>(t->accept_mutator(this)));
}
- return new Polynomial(v->hasher(), newScalar, variables);
+ return alloc<Polynomial>(v->hasher(), newScalar, variables);
}
-Expr* IRMutator::mutate(RoundOff* v) {
- return new RoundOff(
+ExprPtr IRMutator::mutate(RoundOffPtr v) {
+ return alloc<RoundOff>(
v->lhs()->accept_mutator(this), v->rhs()->accept_mutator(this));
}
-Expr* IRMutator::mutate(MaxTerm* v) {
- Expr* newScalar = nullptr;
+ExprPtr IRMutator::mutate(MaxTermPtr v) {
+ ExprPtr newScalar = nullptr;
if (v->scalar()) {
newScalar = v->scalar()->accept_mutator(this);
}
- std::vector<Expr*> variables;
- for (auto* t : v->variables()) {
+ std::vector<ExprPtr> variables;
+ for (auto t : v->variables()) {
variables.push_back(t->accept_mutator(this));
}
- return new MaxTerm(v->hasher(), newScalar, v->propagate_nans(), variables);
+ return alloc<MaxTerm>(v->hasher(), newScalar, v->propagate_nans(), variables);
}
-Expr* IRMutator::mutate(MinTerm* v) {
- Expr* newScalar = nullptr;
+ExprPtr IRMutator::mutate(MinTermPtr v) {
+ ExprPtr newScalar = nullptr;
if (v->scalar()) {
newScalar = v->scalar()->accept_mutator(this);
}
- std::vector<Expr*> variables;
- for (auto* t : v->variables()) {
+ std::vector<ExprPtr> variables;
+ for (auto t : v->variables()) {
variables.push_back(t->accept_mutator(this));
}
- return new MinTerm(v->hasher(), newScalar, v->propagate_nans(), variables);
+ return alloc<MinTerm>(v->hasher(), newScalar, v->propagate_nans(), variables);
}
-Expr* IRMutator::mutate(ReduceOp* v) {
- Expr* body_new = v->body()->accept_mutator(this);
+ExprPtr IRMutator::mutate(ReduceOpPtr v) {
+ ExprPtr body_new = v->body()->accept_mutator(this);
- std::vector<Var*> new_reduce_args;
- for (auto* r : v->reduce_args()) {
- new_reduce_args.push_back(static_cast<Var*>(r->accept_mutator(this)));
+ std::vector<VarPtr> new_reduce_args;
+ for (auto r : v->reduce_args()) {
+ new_reduce_args.push_back(static_to<Var>(r->accept_mutator(this)));
}
- return new ReduceOp(body_new, new_reduce_args, v->reducer());
+ return alloc<ReduceOp>(body_new, new_reduce_args, v->reducer());
}
-Stmt* IRMutator::mutate(For* v) {
- Expr* var = v->var();
- Expr* start = v->start();
- Expr* stop = v->stop();
- Stmt* body = v->body();
+StmtPtr IRMutator::mutate(ForPtr v) {
+ ExprPtr var = v->var();
+ ExprPtr start = v->start();
+ ExprPtr stop = v->stop();
+ StmtPtr body = v->body();
LoopOptions loop_options = v->loop_options();
- Expr* var_new_expr = var->accept_mutator(this);
- Var* var_new = dynamic_cast<Var*>(var_new_expr);
- Expr* start_new = start->accept_mutator(this);
- Expr* stop_new = stop->accept_mutator(this);
- Stmt* body_new = body->accept_mutator(this);
+ ExprPtr var_new_expr = var->accept_mutator(this);
+ VarPtr var_new = to<Var>(var_new_expr);
+ ExprPtr start_new = start->accept_mutator(this);
+ ExprPtr stop_new = stop->accept_mutator(this);
+ StmtPtr body_new = body->accept_mutator(this);
if (!body_new) {
return nullptr;
}
return v;
}
-Stmt* IRMutator::mutate(Block* v) {
+StmtPtr IRMutator::mutate(BlockPtr v) {
bool any_change = false;
- std::vector<Stmt*> stmts;
- for (Stmt* stmt : *v) {
- Stmt* stmt_new = stmt->accept_mutator(this);
+ std::vector<StmtPtr> stmts;
+ for (StmtPtr stmt : *v) {
+ StmtPtr stmt_new = stmt->accept_mutator(this);
if (stmt != stmt_new) {
any_change = true;
} else {
return v;
}
-Stmt* IRMutator::mutate(Store* v) {
- Buf* buf = v->buf();
+StmtPtr IRMutator::mutate(StorePtr v) {
+ BufPtr buf = v->buf();
bool any_index_changed = false;
- std::vector<Expr*> indices_new;
- for (Expr* ind : v->indices()) {
- Expr* new_ind = ind->accept_mutator(this);
+ std::vector<ExprPtr> indices_new;
+ for (ExprPtr ind : v->indices()) {
+ ExprPtr new_ind = ind->accept_mutator(this);
if (new_ind != ind) {
any_index_changed = true;
}
indices_new.push_back(new_ind);
}
- Expr* value = v->value();
- Buf* buf_new = dynamic_cast<Buf*>(buf->accept_mutator(this));
- Expr* value_new = value->accept_mutator(this);
+ ExprPtr value = v->value();
+ BufPtr buf_new = to<Buf>(buf->accept_mutator(this));
+ ExprPtr value_new = value->accept_mutator(this);
if (buf != buf_new) {
v->set_buf(buf_new);
return v;
}
-Stmt* IRMutator::mutate(AtomicAdd* v) {
- Buf* buf = v->buf();
+StmtPtr IRMutator::mutate(AtomicAddPtr v) {
+ BufPtr buf = v->buf();
bool any_index_changed = false;
- std::vector<Expr*> indices_new;
- for (Expr* ind : v->indices()) {
- Expr* new_ind = ind->accept_mutator(this);
+ std::vector<ExprPtr> indices_new;
+ for (ExprPtr ind : v->indices()) {
+ ExprPtr new_ind = ind->accept_mutator(this);
if (new_ind != ind) {
any_index_changed = true;
}
indices_new.push_back(new_ind);
}
- Expr* value = v->value();
- Buf* buf_new = dynamic_cast<Buf*>(buf->accept_mutator(this));
- Expr* value_new = value->accept_mutator(this);
+ ExprPtr value = v->value();
+ BufPtr buf_new = to<Buf>(buf->accept_mutator(this));
+ ExprPtr value_new = value->accept_mutator(this);
if (buf != buf_new) {
v->set_buf(buf_new);
return v;
}
-Stmt* IRMutator::mutate(SyncThreads* v) {
- return new SyncThreads();
+StmtPtr IRMutator::mutate(SyncThreadsPtr v) {
+ return alloc<SyncThreads>();
}
-Stmt* IRMutator::mutate(ExternalCall* v) {
- Buf* buf = v->buf();
- Buf* buf_new = dynamic_cast<Buf*>(buf->accept_mutator(this));
+StmtPtr IRMutator::mutate(ExternalCallPtr v) {
+ BufPtr buf = v->buf();
+ BufPtr buf_new = to<Buf>(buf->accept_mutator(this));
TORCH_INTERNAL_ASSERT(buf_new);
bool buf_args_changed = false;
- std::vector<Buf*> buf_args_new;
+ std::vector<BufPtr> buf_args_new;
buf_args_new.reserve(v->buf_args().size());
- for (Buf* buf_arg : v->buf_args()) {
- Buf* buf_arg_new = dynamic_cast<Buf*>(buf_arg->accept_mutator(this));
+ for (BufPtr buf_arg : v->buf_args()) {
+ BufPtr buf_arg_new = to<Buf>(buf_arg->accept_mutator(this));
TORCH_INTERNAL_ASSERT(buf_arg_new);
buf_args_new.push_back(buf_arg_new);
buf_args_changed |= buf_arg_new != buf_arg;
}
bool args_changed = false;
- std::vector<Expr*> args_new;
+ std::vector<ExprPtr> args_new;
args_new.reserve(v->args().size());
- for (Expr* arg : v->args()) {
- Expr* arg_new = arg->accept_mutator(this);
+ for (ExprPtr arg : v->args()) {
+ ExprPtr arg_new = arg->accept_mutator(this);
args_new.push_back(arg_new);
args_changed |= arg_new != arg;
}
return v;
}
-Stmt* IRMutator::mutate(Allocate* v) {
- Buf* buf = v->buf();
- Buf* buf_new = dynamic_cast<Buf*>(buf->accept_mutator(this));
+StmtPtr IRMutator::mutate(AllocatePtr v) {
+ BufPtr buf = v->buf();
+ BufPtr buf_new = to<Buf>(buf->accept_mutator(this));
TORCH_INTERNAL_ASSERT(buf_new);
if (buf != buf_new) {
v->set_buf(buf_new);
return v;
}
-Stmt* IRMutator::mutate(Free* v) {
- Buf* buf = v->buf();
- Buf* buf_new = dynamic_cast<Buf*>(buf->accept_mutator(this));
+StmtPtr IRMutator::mutate(FreePtr v) {
+ BufPtr buf = v->buf();
+ BufPtr buf_new = to<Buf>(buf->accept_mutator(this));
TORCH_INTERNAL_ASSERT(buf_new);
if (buf != buf_new) {
v->set_buf(buf_new);
return v;
}
-Stmt* IRMutator::mutate(Let* v) {
- Var* var_old = v->var();
- Var* var_new = dynamic_cast<Var*>(var_old->accept_mutator(this));
+StmtPtr IRMutator::mutate(LetPtr v) {
+ VarPtr var_old = v->var();
+ VarPtr var_new = to<Var>(var_old->accept_mutator(this));
- Expr* val_old = v->value();
- Expr* val_new = val_old->accept_mutator(this);
+ ExprPtr val_old = v->value();
+ ExprPtr val_new = val_old->accept_mutator(this);
if (var_old != var_new) {
v->set_var(var_new);
return v;
}
-Stmt* IRMutator::mutate(Cond* v) {
- Expr* cond_old = v->condition();
- Stmt* true_old = v->true_stmt();
- Stmt* false_old = v->false_stmt();
+StmtPtr IRMutator::mutate(CondPtr v) {
+ ExprPtr cond_old = v->condition();
+ StmtPtr true_old = v->true_stmt();
+ StmtPtr false_old = v->false_stmt();
- Expr* cond_new = cond_old->accept_mutator(this);
- Stmt* true_new = true_old ? true_old->accept_mutator(this) : true_old;
- Stmt* false_new = false_old ? false_old->accept_mutator(this) : false_old;
+ ExprPtr cond_new = cond_old->accept_mutator(this);
+ StmtPtr true_new = true_old ? true_old->accept_mutator(this) : true_old;
+ StmtPtr false_new = false_old ? false_old->accept_mutator(this) : false_old;
if (cond_old != cond_new) {
v->set_condition(cond_new);
#pragma once
#include <c10/core/ScalarType.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
#include <vector>
namespace torch {
namespace jit {
namespace tensorexpr {
-class Add;
-class Sub;
-class Mul;
-class Div;
-class Mod;
-class Max;
-class Min;
-class And;
-class Or;
-class Xor;
-class Lshift;
-class Rshift;
-class CompareSelect;
-
-#define IMM_DECLARE(Type, Name) class Name##Imm;
-AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE);
-#undef IMM_DECLARE
-
-class Cast;
-class BitCast;
-class Var;
-class Buf;
-class Ramp;
-class Load;
-class For;
-class Block;
-class Store;
-class Broadcast;
-class IfThenElse;
-class ExprHandle;
-class Expr;
-class Intrinsics;
-class Allocate;
-class Free;
-class Let;
-class Cond;
-class Stmt;
-class Term;
-class Polynomial;
-class RoundOff;
-class MaxTerm;
-class MinTerm;
-class ReduceOp;
-class AtomicAdd;
-class SyncThreads;
-class ExternalCall;
-
class TORCH_API IRMutator {
public:
virtual ~IRMutator() = default;
- virtual Expr* mutate(Add* v);
- virtual Expr* mutate(Sub* v);
- virtual Expr* mutate(Mul* v);
- virtual Expr* mutate(Div* v);
- virtual Expr* mutate(Mod* v);
- virtual Expr* mutate(Max* v);
- virtual Expr* mutate(Min* v);
- virtual Expr* mutate(And* v);
- virtual Expr* mutate(Or* v);
- virtual Expr* mutate(Xor* v);
- virtual Expr* mutate(Lshift* v);
- virtual Expr* mutate(Rshift* v);
- virtual Expr* mutate(CompareSelect* v);
-#define IMM_MUTATE_DECLARE(Type, Name) virtual Expr* mutate(Name##Imm* v);
+ virtual ExprPtr mutate(AddPtr v);
+ virtual ExprPtr mutate(SubPtr v);
+ virtual ExprPtr mutate(MulPtr v);
+ virtual ExprPtr mutate(DivPtr v);
+ virtual ExprPtr mutate(ModPtr v);
+ virtual ExprPtr mutate(MaxPtr v);
+ virtual ExprPtr mutate(MinPtr v);
+ virtual ExprPtr mutate(AndPtr v);
+ virtual ExprPtr mutate(OrPtr v);
+ virtual ExprPtr mutate(XorPtr v);
+ virtual ExprPtr mutate(LshiftPtr v);
+ virtual ExprPtr mutate(RshiftPtr v);
+ virtual ExprPtr mutate(CompareSelectPtr v);
+#define IMM_MUTATE_DECLARE(Type, Name) virtual ExprPtr mutate(Name##ImmPtr v);
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DECLARE);
#undef IMM_MUTATE_DECLARE
- virtual Expr* mutate(Cast* v);
- virtual Expr* mutate(BitCast* v);
- virtual Expr* mutate(Var* v);
- virtual Expr* mutate(Buf* v);
- virtual Expr* mutate(Ramp* v);
- virtual Expr* mutate(Load* v);
- virtual Expr* mutate(Broadcast* v);
- virtual Expr* mutate(IfThenElse* v);
- virtual Expr* mutate(Intrinsics* v);
-
- virtual Expr* mutate(Term* v);
- virtual Expr* mutate(Polynomial* v);
- virtual Expr* mutate(RoundOff* v);
- virtual Expr* mutate(MaxTerm* v);
- virtual Expr* mutate(MinTerm* v);
-
- virtual Expr* mutate(ReduceOp* v);
-
- virtual Stmt* mutate(For* v);
- virtual Stmt* mutate(Block* v);
- virtual Stmt* mutate(Store* v);
- virtual Stmt* mutate(AtomicAdd* v);
- virtual Stmt* mutate(SyncThreads* v);
- virtual Stmt* mutate(ExternalCall* v);
-
- virtual Stmt* mutate(Allocate* v);
- virtual Stmt* mutate(Free* v);
- virtual Stmt* mutate(Let* v);
- virtual Stmt* mutate(Cond* v);
+ virtual ExprPtr mutate(CastPtr v);
+ virtual ExprPtr mutate(BitCastPtr v);
+ virtual ExprPtr mutate(VarPtr v);
+ virtual ExprPtr mutate(BufPtr v);
+ virtual ExprPtr mutate(RampPtr v);
+ virtual ExprPtr mutate(LoadPtr v);
+ virtual ExprPtr mutate(BroadcastPtr v);
+ virtual ExprPtr mutate(IfThenElsePtr v);
+ virtual ExprPtr mutate(IntrinsicsPtr v);
+
+ virtual ExprPtr mutate(TermPtr v);
+ virtual ExprPtr mutate(PolynomialPtr v);
+ virtual ExprPtr mutate(RoundOffPtr v);
+ virtual ExprPtr mutate(MaxTermPtr v);
+ virtual ExprPtr mutate(MinTermPtr v);
+
+ virtual ExprPtr mutate(ReduceOpPtr v);
+
+ virtual StmtPtr mutate(ForPtr v);
+ virtual StmtPtr mutate(BlockPtr v);
+ virtual StmtPtr mutate(StorePtr v);
+ virtual StmtPtr mutate(AtomicAddPtr v);
+ virtual StmtPtr mutate(SyncThreadsPtr v);
+ virtual StmtPtr mutate(ExternalCallPtr v);
+
+ virtual StmtPtr mutate(AllocatePtr v);
+ virtual StmtPtr mutate(FreePtr v);
+ virtual StmtPtr mutate(LetPtr v);
+ virtual StmtPtr mutate(CondPtr v);
};
} // namespace tensorexpr
}
}
-void IRPrinter::visit(Add* v) {
+void IRPrinter::visit(AddPtr v) {
visitBinaryOp(v, "+", this);
}
-void IRPrinter::visit(Sub* v) {
+void IRPrinter::visit(SubPtr v) {
visitBinaryOp(v, "-", this);
}
-void IRPrinter::visit(Mul* v) {
+void IRPrinter::visit(MulPtr v) {
visitBinaryOp(v, "*", this);
}
-void IRPrinter::visit(Div* v) {
+void IRPrinter::visit(DivPtr v) {
visitBinaryOp(v, "/", this);
}
-void IRPrinter::visit(And* v) {
+void IRPrinter::visit(AndPtr v) {
visitBinaryOp(v, "&", this);
}
-void IRPrinter::visit(Or* v) {
+void IRPrinter::visit(OrPtr v) {
visitBinaryOp(v, "|", this);
}
-void IRPrinter::visit(Xor* v) {
+void IRPrinter::visit(XorPtr v) {
visitBinaryOp(v, "^", this);
}
-void IRPrinter::visit(Lshift* v) {
+void IRPrinter::visit(LshiftPtr v) {
visitBinaryOp(v, "<<", this);
}
-void IRPrinter::visit(Rshift* v) {
+void IRPrinter::visit(RshiftPtr v) {
visitBinaryOp(v, ">>", this);
}
-void IRPrinter::visit(Mod* v) {
+void IRPrinter::visit(ModPtr v) {
if (v->dtype().is_integral()) {
visitBinaryOp(v, "%", this);
} else if (v->dtype().is_floating_point()) {
}
}
-void IRPrinter::visit(Max* v) {
+void IRPrinter::visit(MaxPtr v) {
os() << "Max(";
v->lhs()->accept(this);
os() << ", ";
os() << ", " << (unsigned int)v->propagate_nans() << ")";
}
-void IRPrinter::visit(Min* v) {
+void IRPrinter::visit(MinPtr v) {
os() << "Min(";
v->lhs()->accept(this);
os() << ", ";
os() << ", " << (unsigned int)v->propagate_nans() << ")";
}
-void IRPrinter::visit(CompareSelect* v) {
+void IRPrinter::visit(CompareSelectPtr v) {
CompareSelectOperation cmp_op = v->compare_select_op();
int self_prec = getPrecedence(v->expr_type());
int lhs_prec = getPrecedence(v->lhs()->expr_type());
}
os() << " ? ";
- auto withParens = [&](Expr* e) {
+ auto withParens = [&](ExprPtr e) {
auto prec = getPrecedence(e->expr_type());
if (prec >= self_prec) {
os() << "(";
}
// NOLINTNEXTLINE
-#define IMM_PRINT_VISIT(Type, Name) \
- void IRPrinter::visit(Name##Imm* v) { \
- formatImm(os(), v->value()); \
+#define IMM_PRINT_VISIT(Type, Name) \
+ void IRPrinter::visit(Name##ImmPtr v) { \
+ formatImm(os(), v->value()); \
}
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_PRINT_VISIT);
#undef IMM_PRINT_VISIT
-void IRPrinter::visit(Cast* v) {
+void IRPrinter::visit(CastPtr v) {
auto dtype = v->dtype();
os() << dtypeToCppString(dtype) << "(";
v->src_value()->accept(this);
os() << ")";
}
-void IRPrinter::visit(Var* v) {
+void IRPrinter::visit(VarPtr v) {
os() << name_manager_.get_unique_name(v);
}
-void IRPrinter::visit(Ramp* v) {
+void IRPrinter::visit(RampPtr v) {
os() << "Ramp(" << *v->base() << ", " << *v->stride() << ", " << v->lanes()
<< ")";
}
-void IRPrinter::visit(Load* v) {
+void IRPrinter::visit(LoadPtr v) {
// TODO: support the mask case
if (v->indices().size() == 0) {
os() << *v->base_handle();
} else {
os() << *v->base_handle() << "[";
size_t i = 0;
- for (Expr* ind : v->indices()) {
+ for (ExprPtr ind : v->indices()) {
if (i++) {
os() << ", ";
}
}
}
-void IRPrinter::visit(Broadcast* v) {
+void IRPrinter::visit(BroadcastPtr v) {
os() << "Broadcast(" << *v->value() << ", " << v->lanes() << ")";
}
-void IRPrinter::visit(IfThenElse* v) {
+void IRPrinter::visit(IfThenElsePtr v) {
os() << "IfThenElse(" << *v->condition() << ", " << *v->true_value() << ", "
<< *v->false_value() << ")";
}
-void IRPrinter::visit(Intrinsics* v) {
+void IRPrinter::visit(IntrinsicsPtr v) {
os() << v->func_name() << "(";
for (const auto i : c10::irange(v->nparams())) {
if (i > 0) {
os() << ")";
}
-void IRPrinter::visit(Term* v) {
+void IRPrinter::visit(TermPtr v) {
os() << "Term(";
v->scalar()->accept(this);
- for (auto* t : v->variables()) {
+ for (auto t : v->variables()) {
os() << ",";
t->accept(this);
}
os() << ")";
}
-void IRPrinter::visit(Polynomial* v) {
+void IRPrinter::visit(PolynomialPtr v) {
bool first = true;
os() << "Polynomial(";
- for (auto* t : v->variables()) {
+ for (auto t : v->variables()) {
if (!first) {
os() << " + ";
}
os() << ")";
}
-void IRPrinter::visit(RoundOff* v) {
+void IRPrinter::visit(RoundOffPtr v) {
os() << "RoundOff(";
v->lhs()->accept(this);
os() << ", ";
os() << ")";
}
-void IRPrinter::visit(MaxTerm* v) {
+void IRPrinter::visit(MaxTermPtr v) {
os() << "MaxTerm(";
if (v->scalar()) {
v->scalar()->accept(this);
os() << ")";
}
-void IRPrinter::visit(MinTerm* v) {
+void IRPrinter::visit(MinTermPtr v) {
os() << "MinTerm(";
if (v->scalar()) {
v->scalar()->accept(this);
os() << ")";
}
-void IRPrinter::visit(ReduceOp* v) {
+void IRPrinter::visit(ReduceOpPtr v) {
os() << "ReduceOp(";
os() << *v->body() << ", ";
bool first = true;
os() << "reduce_args={";
- for (auto* d : v->reduce_args()) {
+ for (auto d : v->reduce_args()) {
if (!first) {
os() << ", ";
}
// each statement in a `Block` the printer will insert indentation before
// the statement and a newline after the statement.
-void IRPrinter::visit(Store* v) {
+void IRPrinter::visit(StorePtr v) {
// TODO: handle the mask
if (v->indices().size() == 0) {
os() << *v->base_handle() << " = " << *v->value() << ";";
os() << *v->base_handle() << "[";
size_t i = 0;
- for (Expr* ind : v->indices()) {
+ for (ExprPtr ind : v->indices()) {
if (i++) {
os() << ", ";
}
os() << "] = " << *v->value() << ";";
}
-void IRPrinter::visit(For* v) {
- Var* var = v->var();
+void IRPrinter::visit(ForPtr v) {
+ VarPtr var = v->var();
VarHandle vv(var);
os() << "for (" << dtypeToCppString(var->dtype()) << " " << vv << " = "
<< ExprHandle(v->start()) << "; " << vv << " < " << ExprHandle(v->stop())
}
}
-void IRPrinter::visit(Block* v) {
+void IRPrinter::visit(BlockPtr v) {
os() << "{\n";
indent_++;
- for (Stmt* s : *v) {
+ for (StmtPtr s : *v) {
emitIndent();
os() << *s << "\n";
}
os() << "}";
}
-void IRPrinter::visit(Allocate* v) {
+void IRPrinter::visit(AllocatePtr v) {
os() << "Allocate(" << *v->buffer_var()
<< "); // dtype=" << dtypeToCppString(v->dtype());
os() << ", dims=[";
- const std::vector<Expr*>& dims = v->dims();
+ const std::vector<ExprPtr>& dims = v->dims();
for (const auto i : c10::irange(dims.size())) {
if (i != 0) {
os() << ", ";
os() << "]";
}
-void IRPrinter::visit(Free* v) {
+void IRPrinter::visit(FreePtr v) {
os() << "Free(" << *v->buffer_var() << ");";
}
-void IRPrinter::visit(Let* v) {
+void IRPrinter::visit(LetPtr v) {
os() << dtypeToCppString(v->dtype()) << " " << *v->var();
os() << " = " << *v->value();
os() << ";";
}
-void IRPrinter::visit(Cond* v) {
- Expr* cond = v->condition();
- Stmt* true_stmt = v->true_stmt();
- Stmt* false_stmt = v->false_stmt();
+void IRPrinter::visit(CondPtr v) {
+ ExprPtr cond = v->condition();
+ StmtPtr true_stmt = v->true_stmt();
+ StmtPtr false_stmt = v->false_stmt();
if (!true_stmt) {
os() << "if (!" << *cond << ") ";
os() << *false_stmt;
}
}
-void IRPrinter::visit(AtomicAdd* v) {
+void IRPrinter::visit(AtomicAddPtr v) {
os() << "atomicAdd(&" << *v->base_handle() << "[";
size_t i = 0;
- for (Expr* ind : v->indices()) {
+ for (ExprPtr ind : v->indices()) {
if (i++) {
os() << ", ";
}
os() << "], " << *v->value() << ");";
}
-void IRPrinter::visit(SyncThreads* v) {
+void IRPrinter::visit(SyncThreadsPtr v) {
os() << "__syncthreads();";
}
-void IRPrinter::visit(ExternalCall* v) {
+void IRPrinter::visit(ExternalCallPtr v) {
os() << *v->buf() << " = " << v->func_name() << "(";
os() << "buf_args={";
int i = 0;
- for (Buf* buf_arg : v->buf_args()) {
+ for (BufPtr buf_arg : v->buf_args()) {
if (i++ > 0) {
os() << ", ";
}
os() << "}, args={";
i = 0;
- for (Expr* arg : v->args()) {
+ for (ExprPtr arg : v->args()) {
if (i++ > 0) {
os() << ", ";
}
return stream;
}
-void print(const Expr* expr) {
+void print(ExprPtr expr) {
if (expr) {
- Expr* mutable_expr = const_cast<Expr*>(expr);
IRPrinter p(std::cout);
- p.print(*mutable_expr);
+ p.print(*expr);
} else {
std::cout << "(null expr)";
}
std::cout << "\n";
}
-void print(const Stmt* stmt) {
+void print(StmtPtr stmt) {
if (stmt) {
- Stmt* mutable_stmt = const_cast<Stmt*>(stmt);
IRPrinter p(std::cout);
- p.print(*mutable_stmt);
+ p.print(*stmt);
} else {
std::cout << "(null stmt)\n";
}
} // namespace torch
namespace std {
-std::string to_string(const Expr* expr) {
+std::string to_string(ExprPtr expr) {
std::ostringstream oss;
oss << *expr;
return oss.str();
}
-std::string to_string(const Stmt* stmt) {
+std::string to_string(StmtPtr stmt) {
std::ostringstream oss;
oss << *stmt;
return oss.str();
#include <iostream>
+#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/unique_name_manager.h>
void print(ExprHandle);
void print(Expr&);
void print(Stmt&);
- void visit(Add* v) override;
- void visit(Sub* v) override;
- void visit(Mul* v) override;
- void visit(Div* v) override;
- void visit(Mod* v) override;
- void visit(Max* v) override;
- void visit(Min* v) override;
- void visit(And* v) override;
- void visit(Or* v) override;
- void visit(Xor* v) override;
- void visit(Lshift* v) override;
- void visit(Rshift* v) override;
- void visit(CompareSelect* v) override;
-#define IMM_PRINT_VISIT(Type, Name) void visit(Name##Imm* v) override;
+ void visit(AddPtr v) override;
+ void visit(SubPtr v) override;
+ void visit(MulPtr v) override;
+ void visit(DivPtr v) override;
+ void visit(ModPtr v) override;
+ void visit(MaxPtr v) override;
+ void visit(MinPtr v) override;
+ void visit(AndPtr v) override;
+ void visit(OrPtr v) override;
+ void visit(XorPtr v) override;
+ void visit(LshiftPtr v) override;
+ void visit(RshiftPtr v) override;
+ void visit(CompareSelectPtr v) override;
+#define IMM_PRINT_VISIT(Type, Name) void visit(Name##ImmPtr v) override;
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_PRINT_VISIT);
#undef IMM_PRINT_VISIT
- void visit(Cast* v) override;
- void visit(Var* v) override;
- void visit(Ramp* v) override;
- void visit(Load* v) override;
- void visit(Broadcast* v) override;
- void visit(IfThenElse* v) override;
- void visit(Intrinsics* v) override;
- void visit(Term* v) override;
- void visit(Polynomial* v) override;
- void visit(RoundOff* v) override;
- void visit(MaxTerm* v) override;
- void visit(MinTerm* v) override;
- void visit(ReduceOp* v) override;
-
- void visit(AtomicAdd* v) override;
- void visit(SyncThreads* v) override;
- void visit(ExternalCall* v) override;
- void visit(Store* v) override;
- void visit(For* v) override;
- void visit(Cond* v) override;
- void visit(Block* v) override;
- void visit(Allocate* v) override;
- void visit(Free* v) override;
- void visit(Let* v) override;
+ void visit(CastPtr v) override;
+ void visit(VarPtr v) override;
+ void visit(RampPtr v) override;
+ void visit(LoadPtr v) override;
+ void visit(BroadcastPtr v) override;
+ void visit(IfThenElsePtr v) override;
+ void visit(IntrinsicsPtr v) override;
+ void visit(TermPtr v) override;
+ void visit(PolynomialPtr v) override;
+ void visit(RoundOffPtr v) override;
+ void visit(MaxTermPtr v) override;
+ void visit(MinTermPtr v) override;
+ void visit(ReduceOpPtr v) override;
+
+ void visit(AtomicAddPtr v) override;
+ void visit(SyncThreadsPtr v) override;
+ void visit(ExternalCallPtr v) override;
+ void visit(StorePtr v) override;
+ void visit(ForPtr v) override;
+ void visit(CondPtr v) override;
+ void visit(BlockPtr v) override;
+ void visit(AllocatePtr v) override;
+ void visit(FreePtr v) override;
+ void visit(LetPtr v) override;
  // A child class may have a different rule for generating dtype
// string, e.g. CUDA needs int64_t to be generated as long long.
TORCH_API std::ostream& operator<<(std::ostream& stream, const Stmt&);
TORCH_API std::ostream& operator<<(std::ostream& stream, const Tensor&);
-TORCH_API void print(const Expr* expr);
-TORCH_API void print(const Stmt* stmt);
+TORCH_API void print(ExprPtr expr);
+TORCH_API void print(StmtPtr stmt);
TORCH_API void print(const Tensor* t);
} // namespace tensorexpr
namespace std {
using torch::jit::tensorexpr::Expr;
+using torch::jit::tensorexpr::ExprPtr;
using torch::jit::tensorexpr::Stmt;
+using torch::jit::tensorexpr::StmtPtr;
using torch::jit::tensorexpr::Tensor;
-TORCH_API std::string to_string(const Expr* expr);
-TORCH_API std::string to_string(const Stmt* stmt);
+TORCH_API std::string to_string(ExprPtr expr);
+TORCH_API std::string to_string(StmtPtr stmt);
TORCH_API std::string to_string(const Tensor* t);
} // namespace std
// Helper for determining if an Expr is a multi-lane primitive (e.g. Broadcast
// or Ramp).
-bool isMultilanePrimitive(Expr* e) {
- return dynamic_cast<Broadcast*>(e) || dynamic_cast<Ramp*>(e);
+bool isMultilanePrimitive(ExprPtr e) {
+ return to<Broadcast>(e) || to<Ramp>(e);
}
SimplifierHashType Term::hashVars() const {
SimplifierHashType hash;
- for (auto* v : variables_) {
+ for (auto v : variables_) {
hash = hasher_.hash_combine(hash, hasher_.hash(v));
}
if (dtype().is_floating_point()) {
throw std::logic_error("reordering FP ops");
}
- std::sort(variables_.begin(), variables_.end(), [&](Expr* a, Expr* b) {
+ std::sort(variables_.begin(), variables_.end(), [&](ExprPtr a, ExprPtr b) {
return hasher_.hash(a) < hasher_.hash(b);
});
}
SimplifierHashType Polynomial::hashVars() const {
SimplifierHashType hash;
- for (auto* v : variables_) {
+ for (auto v : variables_) {
hash = hasher_.hash_combine(hash, hasher_.hash(v));
}
return hash;
if (dtype().is_floating_point()) {
throw std::logic_error("reordering FP ops");
}
- std::sort(variables_.begin(), variables_.end(), [&](Expr* a, Expr* b) {
+ std::sort(variables_.begin(), variables_.end(), [&](ExprPtr a, ExprPtr b) {
return hasher_.hash(a) < hasher_.hash(b);
});
}
void MaxTerm::uniquefy() {
- std::sort(variables_.begin(), variables_.end(), [&](Expr* a, Expr* b) {
+ std::sort(variables_.begin(), variables_.end(), [&](ExprPtr a, ExprPtr b) {
return hasher_.hash(a) < hasher_.hash(b);
});
- auto it =
- std::unique(variables_.begin(), variables_.end(), [&](Expr* a, Expr* b) {
+ auto it = std::unique(
+ variables_.begin(), variables_.end(), [&](ExprPtr a, ExprPtr b) {
return hasher_.hash(a) == hasher_.hash(b);
});
variables_.resize(std::distance(variables_.begin(), it));
}
void MinTerm::uniquefy() {
- std::sort(variables_.begin(), variables_.end(), [&](Expr* a, Expr* b) {
+ std::sort(variables_.begin(), variables_.end(), [&](ExprPtr a, ExprPtr b) {
return hasher_.hash(a) < hasher_.hash(b);
});
- auto it =
- std::unique(variables_.begin(), variables_.end(), [&](Expr* a, Expr* b) {
+ auto it = std::unique(
+ variables_.begin(), variables_.end(), [&](ExprPtr a, ExprPtr b) {
return hasher_.hash(a) == hasher_.hash(b);
});
variables_.resize(std::distance(variables_.begin(), it));
// Handles optimization cases for Broadcast/Ramp +/- Broadcast/Ramp
template <class Op>
-Expr* combineMultilane(Expr* lhs, Expr* rhs) {
- if (Broadcast* bc = dynamic_cast<Broadcast*>(lhs)) {
- if (Broadcast* bcother = dynamic_cast<Broadcast*>(rhs)) {
+ExprPtr combineMultilane(ExprPtr lhs, ExprPtr rhs) {
+ if (BroadcastPtr bc = to<Broadcast>(lhs)) {
+ if (BroadcastPtr bcother = to<Broadcast>(rhs)) {
if (bc->lanes() != bcother->lanes()) {
throw malformed_input("multilane lane mismatch");
}
- Expr* ret =
- new Broadcast(new Op(bc->value(), bcother->value()), bc->lanes());
+ ExprPtr ret = alloc<Broadcast>(
+ alloc<Op>(bc->value(), bcother->value()), bc->lanes());
return ret;
}
- if (Ramp* r = dynamic_cast<Ramp*>(rhs)) {
+ if (RampPtr r = to<Ramp>(rhs)) {
if (bc->lanes() != r->lanes()) {
throw malformed_input("multilane lane mismatch");
}
- Expr* ret =
- new Ramp(new Op(bc->value(), r->base()), r->stride(), r->lanes());
+ ExprPtr ret = alloc<Ramp>(
+ alloc<Op>(bc->value(), r->base()), r->stride(), r->lanes());
return ret;
}
- } else if (Ramp* ramp = dynamic_cast<Ramp*>(lhs)) {
- if (Ramp* rother = dynamic_cast<Ramp*>(rhs)) {
+ } else if (RampPtr ramp = to<Ramp>(lhs)) {
+ if (RampPtr rother = to<Ramp>(rhs)) {
if (ramp->lanes() != rother->lanes()) {
throw malformed_input("multilane lane mismatch");
}
- Expr* ret = new Ramp(
- new Op(ramp->base(), rother->base()),
- new Op(ramp->stride(), rother->stride()),
+ ExprPtr ret = alloc<Ramp>(
+ alloc<Op>(ramp->base(), rother->base()),
+ alloc<Op>(ramp->stride(), rother->stride()),
ramp->lanes());
return ret;
}
- if (Broadcast* bc = dynamic_cast<Broadcast*>(rhs)) {
+ if (BroadcastPtr bc = to<Broadcast>(rhs)) {
if (ramp->lanes() != bc->lanes()) {
throw malformed_input("multilane lane mismatch");
}
- Expr* ret = new Ramp(
- new Op(ramp->base(), bc->value()), ramp->stride(), ramp->lanes());
+ ExprPtr ret = alloc<Ramp>(
+ alloc<Op>(ramp->base(), bc->value()), ramp->stride(), ramp->lanes());
return ret;
}
}
}
// Handles optimization cases for Broadcast/Ramp * Broadcast/Ramp
-Expr* mulMultilane(Expr* lhs, Expr* rhs) {
- if (Broadcast* bc = dynamic_cast<Broadcast*>(lhs)) {
- if (Broadcast* bcother = dynamic_cast<Broadcast*>(rhs)) {
+ExprPtr mulMultilane(ExprPtr lhs, ExprPtr rhs) {
+ if (BroadcastPtr bc = to<Broadcast>(lhs)) {
+ if (BroadcastPtr bcother = to<Broadcast>(rhs)) {
if (bc->lanes() != bcother->lanes()) {
throw malformed_input("multilane lane mismatch");
}
- Expr* ret =
- new Broadcast(new Mul(bc->value(), bcother->value()), bc->lanes());
+ ExprPtr ret = alloc<Broadcast>(
+ alloc<Mul>(bc->value(), bcother->value()), bc->lanes());
return ret;
}
- if (Ramp* r = dynamic_cast<Ramp*>(rhs)) {
+ if (RampPtr r = to<Ramp>(rhs)) {
if (bc->lanes() != r->lanes()) {
throw malformed_input("multilane lane mismatch");
}
- Expr* ret = new Ramp(
- new Mul(bc->value(), r->base()),
- new Mul(bc->value(), r->stride()),
+ ExprPtr ret = alloc<Ramp>(
+ alloc<Mul>(bc->value(), r->base()),
+ alloc<Mul>(bc->value(), r->stride()),
r->lanes());
return ret;
}
- } else if (Ramp* ramp = dynamic_cast<Ramp*>(lhs)) {
- if (Ramp* r = dynamic_cast<Ramp*>(rhs)) {
+ } else if (RampPtr ramp = to<Ramp>(lhs)) {
+ if (RampPtr r = to<Ramp>(rhs)) {
if (ramp->lanes() != r->lanes()) {
throw malformed_input("multilane lane mismatch");
}
- Expr* ret = new Ramp(
- new Mul(ramp->base(), r->base()),
- new Mul(ramp->stride(), r->stride()),
+ ExprPtr ret = alloc<Ramp>(
+ alloc<Mul>(ramp->base(), r->base()),
+ alloc<Mul>(ramp->stride(), r->stride()),
r->lanes());
return ret;
}
- if (Broadcast* bc = dynamic_cast<Broadcast*>(rhs)) {
+ if (BroadcastPtr bc = to<Broadcast>(rhs)) {
if (ramp->lanes() != bc->lanes()) {
throw malformed_input("multilane lane mismatch");
}
- Expr* ret = new Ramp(
- new Mul(bc->value(), ramp->base()),
- new Mul(bc->value(), ramp->stride()),
+ ExprPtr ret = alloc<Ramp>(
+ alloc<Mul>(bc->value(), ramp->base()),
+ alloc<Mul>(bc->value(), ramp->stride()),
ramp->lanes());
return ret;
}
}
void PolynomialTransformer::addOrUpdateTerm(
- std::unordered_map<SimplifierHashType, Term*>& varmap,
- Term* term) {
+ std::unordered_map<SimplifierHashType, TermPtr>& varmap,
+ TermPtr term) {
SimplifierHashType hash = term->hashVars();
auto insertRes = varmap.emplace(hash, term);
if (insertRes.second == false) {
- Term* lt = insertRes.first->second;
- Expr* termScalar = evaluateOp(new Add(lt->scalar(), term->scalar()));
+ TermPtr lt = insertRes.first->second;
+ ExprPtr termScalar = evaluateOp(alloc<Add>(lt->scalar(), term->scalar()));
// If the term is canceled out, remove from the map.
if (immediateEquals(termScalar, 0)) {
return;
}
- varmap[hash] = new Term(hasher_, termScalar, lt->variables());
+ varmap[hash] = alloc<Term>(hasher_, termScalar, lt->variables());
}
}
-Expr* PolynomialTransformer::addPolynomials(Polynomial* lhs, Polynomial* rhs) {
+ExprPtr PolynomialTransformer::addPolynomials(
+ PolynomialPtr lhs,
+ PolynomialPtr rhs) {
// simplify common components
// The key here is the variable hash, not the term's hash since we do want
// to combine terms that have the same vars but different scalar components.
- std::unordered_map<SimplifierHashType, Term*> varmap;
+ std::unordered_map<SimplifierHashType, TermPtr> varmap;
- for (auto* lt : lhs->variables()) {
+ for (auto lt : lhs->variables()) {
addOrUpdateTerm(varmap, lt);
}
- for (auto* rt : rhs->variables()) {
+ for (auto rt : rhs->variables()) {
addOrUpdateTerm(varmap, rt);
}
- Expr* newScalar = evaluateOp(new Add(lhs->scalar(), rhs->scalar()));
- return new Polynomial(hasher_, newScalar, varmap);
+ ExprPtr newScalar = evaluateOp(alloc<Add>(lhs->scalar(), rhs->scalar()));
+ return alloc<Polynomial>(hasher_, newScalar, varmap);
}
// Insert a new Term into the provided polynomial. If the new term has common
// variables to an existing term it is combined.
-Expr* PolynomialTransformer::insertTerm(Polynomial* poly, Term* term) {
+ExprPtr PolynomialTransformer::insertTerm(PolynomialPtr poly, TermPtr term) {
SimplifierHashType tHash = term->hashVars();
- std::vector<Term*> newVars;
+ std::vector<TermPtr> newVars;
bool found = false;
- for (auto* v : poly->variables()) {
+ for (auto v : poly->variables()) {
if (v->hashVars() == tHash) {
- Expr* newScalar = evaluateOp(new Add(term->scalar(), v->scalar()));
+ ExprPtr newScalar = evaluateOp(alloc<Add>(term->scalar(), v->scalar()));
found = true;
// Skip this term if we cancelled it out.
if (immediateEquals(newScalar, 0)) {
continue;
}
- auto* term = new Term(hasher_, newScalar, v->variables());
+ auto term = alloc<Term>(hasher_, newScalar, v->variables());
newVars.push_back(term);
} else {
newVars.push_back(v);
return poly->scalar();
}
- auto* Poly = new Polynomial(hasher_, poly->scalar(), newVars);
+ auto Poly = alloc<Polynomial>(hasher_, poly->scalar(), newVars);
return Poly;
}
-Expr* PolynomialTransformer::mutate(Add* v) {
- Expr* lhs_new = v->lhs()->accept_mutator(this);
- Expr* rhs_new = v->rhs()->accept_mutator(this);
+ExprPtr PolynomialTransformer::mutate(AddPtr v) {
+ ExprPtr lhs_new = v->lhs()->accept_mutator(this);
+ ExprPtr rhs_new = v->rhs()->accept_mutator(this);
// Constant Folding.
if (lhs_new->isConstant() && rhs_new->isConstant()) {
- Expr* result = evaluateOp(new Add(lhs_new, rhs_new));
+ ExprPtr result = evaluateOp(alloc<Add>(lhs_new, rhs_new));
return result;
}
// Multilane folding.
if (isMultilanePrimitive(lhs_new)) {
- if (auto* ret = combineMultilane<Add>(lhs_new, rhs_new)) {
+ if (auto ret = combineMultilane<Add>(lhs_new, rhs_new)) {
return ret->accept_mutator(this);
}
}
- Expr* scalar = nullptr;
- Expr* variable = nullptr;
+ ExprPtr scalar = nullptr;
+ ExprPtr variable = nullptr;
if (lhs_new->isConstant()) {
scalar = evaluateOp(lhs_new);
variable = rhs_new;
// If there is a scalar, and it's zero: short circuit and return the other
// side.
if (scalar && immediateEquals(scalar, 0)) {
- auto* c = new Cast(v->dtype(), variable);
+ auto c = alloc<Cast>(v->dtype(), variable);
return c->accept_mutator(this);
}
// don't want to combine ops.
if (lhs_new->dtype().is_floating_point() ||
rhs_new->dtype().is_floating_point()) {
- return new Add(lhs_new, rhs_new);
+ return alloc<Add>(lhs_new, rhs_new);
}
- Polynomial* lhsPoly = dynamic_cast<Polynomial*>(lhs_new);
- Polynomial* rhsPoly = dynamic_cast<Polynomial*>(rhs_new);
+ PolynomialPtr lhsPoly = to<Polynomial>(lhs_new);
+ PolynomialPtr rhsPoly = to<Polynomial>(rhs_new);
if (lhsPoly && rhsPoly) {
return addPolynomials(lhsPoly, rhsPoly);
}
- Term* lhsTerm = dynamic_cast<Term*>(lhs_new);
- Term* rhsTerm = dynamic_cast<Term*>(rhs_new);
+ TermPtr lhsTerm = to<Term>(lhs_new);
+ TermPtr rhsTerm = to<Term>(rhs_new);
if (lhsPoly && rhsTerm) {
return insertTerm(lhsPoly, rhsTerm);
if (lhsTerm && rhsTerm) {
// If the terms refer to the same variables: combine them.
if (lhsTerm->hashVars() == rhsTerm->hashVars()) {
- Expr* newScalar =
- evaluateOp(new Add(lhsTerm->scalar(), rhsTerm->scalar()));
+ ExprPtr newScalar =
+ evaluateOp(alloc<Add>(lhsTerm->scalar(), rhsTerm->scalar()));
// If the terms cancelled out, return zero.
if (immediateEquals(newScalar, 0)) {
return newScalar->accept_mutator(this);
}
- return new Term(hasher_, newScalar, lhsTerm->variables());
+ return alloc<Term>(hasher_, newScalar, lhsTerm->variables());
}
// Otherwise this is a new polynomial with no scalar and two variable
// terms.
- return new Polynomial(
+ return alloc<Polynomial>(
hasher_, getImmediateByType(v->dtype(), 0), lhsTerm, rhsTerm);
}
// Adds are commutative.
- Polynomial* poly = lhsPoly ? lhsPoly : rhsPoly;
+ PolynomialPtr poly = lhsPoly ? lhsPoly : rhsPoly;
// Add to Polynomial->scalar().
if (scalar && poly) {
- Expr* newScalar = evaluateOp(new Add(scalar, poly->scalar()));
- return new Polynomial(hasher_, newScalar, poly->variables());
+ ExprPtr newScalar = evaluateOp(alloc<Add>(scalar, poly->scalar()));
+ return alloc<Polynomial>(hasher_, newScalar, poly->variables());
}
// Simple Polynomial with a scalar and Term.
- Term* term = lhsTerm ? lhsTerm : rhsTerm;
+ TermPtr term = lhsTerm ? lhsTerm : rhsTerm;
if (scalar && term) {
- return new Polynomial(hasher_, scalar, term);
+ return alloc<Polynomial>(hasher_, scalar, term);
}
// Simple Term with a scalar and variable type.
if (scalar) {
- return new Polynomial(
+ return alloc<Polynomial>(
hasher_,
scalar,
- new Term(hasher_, getImmediateByType(v->dtype(), 1), variable));
+ alloc<Term>(hasher_, getImmediateByType(v->dtype(), 1), variable));
}
// If LHS is neither Term nor Polynomial, wrap it in a Term.
if (!lhsTerm && !lhsPoly) {
- lhsTerm = new Term(hasher_, getImmediateByType(v->dtype(), 1), lhs_new);
+ lhsTerm = alloc<Term>(hasher_, getImmediateByType(v->dtype(), 1), lhs_new);
}
// Same for RHS.
if (!rhsTerm && !rhsPoly) {
- rhsTerm = new Term(hasher_, getImmediateByType(v->dtype(), 1), rhs_new);
+ rhsTerm = alloc<Term>(hasher_, getImmediateByType(v->dtype(), 1), rhs_new);
}
// If we now have a poly and a term, we can insert.
}
if (lhsTerm->hashVars() == rhsTerm->hashVars()) {
- return new Term(
+ return alloc<Term>(
hasher_,
- evaluateOp(new Add(lhsTerm->scalar(), rhsTerm->scalar())),
+ evaluateOp(alloc<Add>(lhsTerm->scalar(), rhsTerm->scalar())),
lhsTerm->variables());
}
// If all else fails we have a new Polynomial with two new variable Terms.
- return new Polynomial(
+ return alloc<Polynomial>(
hasher_, getImmediateByType(v->dtype(), 0), lhsTerm, rhsTerm);
}
-Expr* PolynomialTransformer::subTerms(Term* lhs, Term* rhs, bool negated) {
+ExprPtr PolynomialTransformer::subTerms(
+ TermPtr lhs,
+ TermPtr rhs,
+ bool negated) {
// If RHS not already negated, negate it.
if (!negated) {
- Expr* minusOne = getImmediateByType(rhs->dtype(), -1);
- Expr* negateScalar = evaluateOp(new Mul(minusOne, rhs->scalar()));
- rhs = new Term(hasher_, negateScalar, rhs->variables());
+ ExprPtr minusOne = getImmediateByType(rhs->dtype(), -1);
+ ExprPtr negateScalar = evaluateOp(alloc<Mul>(minusOne, rhs->scalar()));
+ rhs = alloc<Term>(hasher_, negateScalar, rhs->variables());
}
if (lhs->hashVars() == rhs->hashVars()) {
- Expr* newScalar = evaluateOp(new Add(lhs->scalar(), rhs->scalar()));
+ ExprPtr newScalar = evaluateOp(alloc<Add>(lhs->scalar(), rhs->scalar()));
// If the terms cancel out, return zero.
if (immediateEquals(newScalar, 0)) {
return newScalar;
}
- return new Term(hasher_, newScalar, lhs->variables());
+ return alloc<Term>(hasher_, newScalar, lhs->variables());
}
- return new Polynomial(
+ return alloc<Polynomial>(
hasher_,
getImmediateByType(promoteTypes(lhs->dtype(), rhs->dtype()), 0),
lhs,
// Subtract the RHS Polynomial from the LHS Polynomial, cancelling out where
// possible.
-Expr* PolynomialTransformer::subPolynomials(Polynomial* lhs, Polynomial* rhs) {
+ExprPtr PolynomialTransformer::subPolynomials(
+ PolynomialPtr lhs,
+ PolynomialPtr rhs) {
// simplify common components
// The key here is the variable hash, not the term's hash since we do want
// to combine terms that have the same vars but different scalar components.
- std::unordered_map<SimplifierHashType, Term*> varmap;
+ std::unordered_map<SimplifierHashType, TermPtr> varmap;
- for (auto* lt : lhs->variables()) {
+ for (auto lt : lhs->variables()) {
addOrUpdateTerm(varmap, lt);
}
- for (auto* rt : rhs->variables()) {
+ for (auto rt : rhs->variables()) {
// Polynomials add their terms, so negate the RHS's Terms.
- Expr* negated =
- evaluateOp(new Mul(getImmediateByType(rt->dtype(), -1), rt->scalar()));
- Term* newRHS = new Term(hasher_, negated, rt->variables());
+ ExprPtr negated = evaluateOp(
+ alloc<Mul>(getImmediateByType(rt->dtype(), -1), rt->scalar()));
+ TermPtr newRHS = alloc<Term>(hasher_, negated, rt->variables());
addOrUpdateTerm(varmap, newRHS);
}
- Expr* newScalar = evaluateOp(new Sub(lhs->scalar(), rhs->scalar()));
+ ExprPtr newScalar = evaluateOp(alloc<Sub>(lhs->scalar(), rhs->scalar()));
// No vars means this cancelled out to a scalar, return it unwrapped.
if (varmap.empty()) {
}
// Wrap new variables in a Polynomial.
- return new Polynomial(hasher_, newScalar, varmap);
+ return alloc<Polynomial>(hasher_, newScalar, varmap);
}
-Expr* PolynomialTransformer::mutate(Sub* v) {
- Expr* lhs_new = v->lhs()->accept_mutator(this);
- Expr* rhs_new = v->rhs()->accept_mutator(this);
+ExprPtr PolynomialTransformer::mutate(SubPtr v) {
+ ExprPtr lhs_new = v->lhs()->accept_mutator(this);
+ ExprPtr rhs_new = v->rhs()->accept_mutator(this);
// Constant Folding.
if (lhs_new->isConstant() && rhs_new->isConstant()) {
- Expr* result = evaluateOp(new Sub(lhs_new, rhs_new));
+ ExprPtr result = evaluateOp(alloc<Sub>(lhs_new, rhs_new));
return result;
}
// Multilane folding.
if (isMultilanePrimitive(lhs_new)) {
- if (auto* ret = combineMultilane<Sub>(lhs_new, rhs_new)) {
+ if (auto ret = combineMultilane<Sub>(lhs_new, rhs_new)) {
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
return ret->accept_mutator(this);
}
}
if (rhs_new->isConstant() && immediateEquals(rhs_new, 0)) {
- auto* c = new Cast(v->dtype(), lhs_new);
+ auto c = alloc<Cast>(v->dtype(), lhs_new);
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
return c->accept_mutator(this);
}
// don't want to combine ops.
if (lhs_new->dtype().is_floating_point() ||
rhs_new->dtype().is_floating_point()) {
- return new Sub(lhs_new, rhs_new);
+ return alloc<Sub>(lhs_new, rhs_new);
}
- Polynomial* lhsPoly = dynamic_cast<Polynomial*>(lhs_new);
- Polynomial* rhsPoly = dynamic_cast<Polynomial*>(rhs_new);
+ PolynomialPtr lhsPoly = to<Polynomial>(lhs_new);
+ PolynomialPtr rhsPoly = to<Polynomial>(rhs_new);
if (lhsPoly && rhsPoly) {
- auto* ret = subPolynomials(lhsPoly, rhsPoly);
+ auto ret = subPolynomials(lhsPoly, rhsPoly);
if (!ret) {
// Cancelled out completely.
return getImmediateByType(v->dtype(), 0);
return ret;
}
- Term* lhsTerm = dynamic_cast<Term*>(lhs_new);
- Term* rhsTerm = dynamic_cast<Term*>(rhs_new);
+ TermPtr lhsTerm = to<Term>(lhs_new);
+ TermPtr rhsTerm = to<Term>(rhs_new);
// Polynomial - Term.
if (lhsPoly && rhsTerm) {
// Negate the term.
- Expr* negate = evaluateOp(
- new Mul(getImmediateByType(rhsTerm->dtype(), -1), rhsTerm->scalar()));
- Term* newTerm = new Term(hasher_, negate, rhsTerm->variables());
+ ExprPtr negate = evaluateOp(alloc<Mul>(
+ getImmediateByType(rhsTerm->dtype(), -1), rhsTerm->scalar()));
+ TermPtr newTerm = alloc<Term>(hasher_, negate, rhsTerm->variables());
return insertTerm(lhsPoly, newTerm);
}
// Term - Polynomial.
if (rhsPoly && lhsTerm) {
// Negate every part of the Polynomial.
- Expr* minusOne = getImmediateByType(lhsTerm->dtype(), -1);
- Expr* negateScalar = evaluateOp(new Mul(minusOne, rhsPoly->scalar()));
+ ExprPtr minusOne = getImmediateByType(lhsTerm->dtype(), -1);
+ ExprPtr negateScalar = evaluateOp(alloc<Mul>(minusOne, rhsPoly->scalar()));
- std::vector<Term*> variables;
- for (auto* t : rhsPoly->variables()) {
- Expr* negate = evaluateOp(new Mul(minusOne, t->scalar()));
- variables.push_back(new Term(hasher_, negate, t->variables()));
+ std::vector<TermPtr> variables;
+ for (auto t : rhsPoly->variables()) {
+ ExprPtr negate = evaluateOp(alloc<Mul>(minusOne, t->scalar()));
+ variables.push_back(alloc<Term>(hasher_, negate, t->variables()));
}
- Polynomial* newPoly = new Polynomial(hasher_, negateScalar, variables);
+ PolynomialPtr newPoly = alloc<Polynomial>(hasher_, negateScalar, variables);
return insertTerm(newPoly, lhsTerm);
}
if (lhsPoly && rhsScalar) {
// Easy path, just sub the scalar component.
- Expr* newScalar = evaluateOp(new Sub(lhsPoly->scalar(), rhs_new));
- return new Polynomial(hasher_, newScalar, lhsPoly->variables());
+ ExprPtr newScalar = evaluateOp(alloc<Sub>(lhsPoly->scalar(), rhs_new));
+ return alloc<Polynomial>(hasher_, newScalar, lhsPoly->variables());
}
if (lhsScalar && rhsPoly) {
// Sub the scalar component.
- Expr* newScalar = evaluateOp(new Sub(lhs_new, rhsPoly->scalar()));
+ ExprPtr newScalar = evaluateOp(alloc<Sub>(lhs_new, rhsPoly->scalar()));
// Negate each term in the Polynomial RHS.
- Expr* minusOne = getImmediateByType(rhsPoly->dtype(), -1);
- std::vector<Term*> variables;
- for (auto* t : rhsPoly->variables()) {
- Expr* negate = evaluateOp(new Mul(minusOne, t->scalar()));
- variables.push_back(new Term(hasher_, negate, t->variables()));
+ ExprPtr minusOne = getImmediateByType(rhsPoly->dtype(), -1);
+ std::vector<TermPtr> variables;
+ for (auto t : rhsPoly->variables()) {
+ ExprPtr negate = evaluateOp(alloc<Mul>(minusOne, t->scalar()));
+ variables.push_back(alloc<Term>(hasher_, negate, t->variables()));
}
- return new Polynomial(hasher_, newScalar, variables);
+ return alloc<Polynomial>(hasher_, newScalar, variables);
}
if (lhsTerm && rhsScalar) {
// Negate the constant.
- Expr* negate =
- evaluateOp(new Mul(getImmediateByType(rhs_new->dtype(), -1), rhs_new));
- return new Polynomial(hasher_, negate, lhsTerm);
+ ExprPtr negate = evaluateOp(
+ alloc<Mul>(getImmediateByType(rhs_new->dtype(), -1), rhs_new));
+ return alloc<Polynomial>(hasher_, negate, lhsTerm);
}
if (lhsScalar && rhsTerm) {
// Negate the RHS Term.
- Expr* negate = evaluateOp(new Mul(
+ ExprPtr negate = evaluateOp(alloc<Mul>(
getImmediateByType(rhsTerm->scalar()->dtype(), -1), rhsTerm->scalar()));
- return new Polynomial(
- hasher_, lhs_new, new Term(hasher_, negate, rhsTerm->variables()));
+ return alloc<Polynomial>(
+ hasher_, lhs_new, alloc<Term>(hasher_, negate, rhsTerm->variables()));
}
// simple term with a scalar and variable type.
if (lhsScalar) {
// Create a negated term.
- return new Polynomial(
+ return alloc<Polynomial>(
hasher_,
lhs_new,
- new Term(hasher_, getImmediateByType(v->dtype(), -1), rhs_new));
+ alloc<Term>(hasher_, getImmediateByType(v->dtype(), -1), rhs_new));
}
if (rhsScalar) {
// Negate the scalar.
- Expr* negate =
- evaluateOp(new Mul(getImmediateByType(rhs_new->dtype(), -1), rhs_new));
- return new Polynomial(
+ ExprPtr negate = evaluateOp(
+ alloc<Mul>(getImmediateByType(rhs_new->dtype(), -1), rhs_new));
+ return alloc<Polynomial>(
hasher_,
negate,
- new Term(hasher_, getImmediateByType(v->dtype(), 1), lhs_new));
+ alloc<Term>(hasher_, getImmediateByType(v->dtype(), 1), lhs_new));
}
// no scalar...
if (!lhsTerm && !lhsPoly) {
- lhsTerm = new Term(hasher_, getImmediateByType(v->dtype(), 1), lhs_new);
+ lhsTerm = alloc<Term>(hasher_, getImmediateByType(v->dtype(), 1), lhs_new);
}
bool createdRHSnegated = false;
if (!rhsTerm && !rhsPoly) {
- rhsTerm = new Term(hasher_, getImmediateByType(v->dtype(), -1), rhs_new);
+ rhsTerm = alloc<Term>(hasher_, getImmediateByType(v->dtype(), -1), rhs_new);
createdRHSnegated = true;
}
// Insert wrapper Term into negated RHS Poly.
if (rhsPoly) {
CHECK(lhsTerm);
- Expr* minusOne = getImmediateByType(rhsPoly->dtype(), -1);
- Expr* newScalar = evaluateOp(new Mul(minusOne, rhsPoly->scalar()));
+ ExprPtr minusOne = getImmediateByType(rhsPoly->dtype(), -1);
+ ExprPtr newScalar = evaluateOp(alloc<Mul>(minusOne, rhsPoly->scalar()));
// Negate each term in the Polynomial RHS.
- std::vector<Term*> variables;
- for (auto* t : rhsPoly->variables()) {
- Expr* negate = evaluateOp(new Mul(minusOne, t->scalar()));
- variables.push_back(new Term(hasher_, negate, t->variables()));
+ std::vector<TermPtr> variables;
+ for (auto t : rhsPoly->variables()) {
+ ExprPtr negate = evaluateOp(alloc<Mul>(minusOne, t->scalar()));
+ variables.push_back(alloc<Term>(hasher_, negate, t->variables()));
}
- auto* poly = new Polynomial(hasher_, newScalar, variables);
+ auto poly = alloc<Polynomial>(hasher_, newScalar, variables);
return insertTerm(poly, lhsTerm);
}
- return new Polynomial(
+ return alloc<Polynomial>(
hasher_, getImmediateByType(v->dtype(), 0), lhsTerm, rhsTerm);
}
// Multiply two terms together, usually creating a new term with the variable
// lists concatenated.
-Term* PolynomialTransformer::mulTerms(Term* lhs, Term* rhs) {
- Expr* scalar = evaluateOp(new Mul(lhs->scalar(), rhs->scalar()));
+TermPtr PolynomialTransformer::mulTerms(TermPtr lhs, TermPtr rhs) {
+ ExprPtr scalar = evaluateOp(alloc<Mul>(lhs->scalar(), rhs->scalar()));
if (immediateEquals(scalar, 0)) {
return nullptr;
}
// Can reorder here since floating point ops don't get put into Terms.
- std::vector<Expr*> variables;
- std::vector<Expr*> multilaneVariables;
+ std::vector<ExprPtr> variables;
+ std::vector<ExprPtr> multilaneVariables;
// For now don't handle exponents.
- for (auto* c : lhs->variables()) {
+ for (auto c : lhs->variables()) {
if (isMultilanePrimitive(c)) {
multilaneVariables.push_back(c);
} else {
variables.push_back(c);
}
}
- for (auto* c : rhs->variables()) {
+ for (auto c : rhs->variables()) {
if (isMultilanePrimitive(c)) {
multilaneVariables.push_back(c);
} else {
}
// Merge all the multilane vars:
- Expr* lastNode{nullptr};
- for (auto* node : multilaneVariables) {
+ ExprPtr lastNode{nullptr};
+ for (auto node : multilaneVariables) {
if (lastNode == nullptr) {
lastNode = node;
} else {
- if (auto* next = mulMultilane(lastNode, node)) {
+ if (auto next = mulMultilane(lastNode, node)) {
lastNode = next->accept_mutator(this);
} else {
variables.push_back(lastNode);
variables.push_back(lastNode);
}
- return new Term(hasher_, scalar, variables);
+ return alloc<Term>(hasher_, scalar, variables);
}
// Multiply a Polynomial by a Term.
-Expr* PolynomialTransformer::polyByTerm(Polynomial* poly, Term* term) {
+ExprPtr PolynomialTransformer::polyByTerm(PolynomialPtr poly, TermPtr term) {
// poly * term
// = (poly_terms + poly_scalar) * term
// = poly_terms * term + poly_scalar * term
// First, multiply all variables (terms) in the polynomial by the input
// term.
- std::vector<Term*> newTerms;
- for (auto* var : poly->variables()) {
- Term* newTerm = mulTerms(var, term);
+ std::vector<TermPtr> newTerms;
+ for (auto var : poly->variables()) {
+ TermPtr newTerm = mulTerms(var, term);
if (newTerm) {
newTerms.push_back(newTerm);
}
// polynomial. If there are variables in term, this becomes a new term in
// the result polynomial.
if (!immediateEquals(poly->scalar(), 0)) {
- Expr* scalar = evaluateOp(new Mul(poly->scalar(), term->scalar()));
+ ExprPtr scalar = evaluateOp(alloc<Mul>(poly->scalar(), term->scalar()));
if (term->variables().empty()) {
- return new Polynomial(hasher_, scalar, newTerms);
+ return alloc<Polynomial>(hasher_, scalar, newTerms);
}
- newTerms.push_back(new Term(hasher_, scalar, term->variables()));
+ newTerms.push_back(alloc<Term>(hasher_, scalar, term->variables()));
}
// The only case when the result polynomial has a scalar is when the input
// term does not have any variables and the input polynomial has a non-zero
// scalar. That case is handled above. So, at this point, we do not have any
// scalars in the result polynomial.
- return new Polynomial(hasher_, std::move(newTerms));
+ return alloc<Polynomial>(hasher_, std::move(newTerms));
}
// Does multiplying these two expressions make a Rounding Off operation.
// e.g. LHS = (x/y), RHS = y => (x / y) * y => RoundOff(x, y).
-Expr* PolynomialTransformer::isRoundOff(Expr* lhs, Expr* rhs) {
- Div* div{nullptr};
- Expr* other{nullptr};
+ExprPtr PolynomialTransformer::isRoundOff(ExprPtr lhs, ExprPtr rhs) {
+ DivPtr div{nullptr};
+ ExprPtr other{nullptr};
- if ((div = dynamic_cast<Div*>(lhs))) {
+ if ((div = to<Div>(lhs))) {
other = rhs;
- } else if ((div = dynamic_cast<Div*>(rhs))) {
+ } else if ((div = to<Div>(rhs))) {
other = lhs;
} else {
return nullptr;
}
- Expr* denom = div->rhs();
+ ExprPtr denom = div->rhs();
- if (Term* denomTerm = dynamic_cast<Term*>(denom)) {
+ if (TermPtr denomTerm = to<Term>(denom)) {
if (immediateEquals(denomTerm->scalar(), 1) &&
denomTerm->variables().size() == 1) {
denom = denomTerm->variables()[0];
if (hasher_.hash(denom) == hasher_.hash(other)) {
// If the denominator is equal to the other, then yes it's a RoundOff.
- return new RoundOff(div->lhs(), div->rhs());
+ return alloc<RoundOff>(div->lhs(), div->rhs());
}
if (denom->isConstant() && other->isConstant()) {
return nullptr;
}
// If they are both scalar we may be able to find a common factor.
- if (immediateEquals(evaluateOp(new Mod(other, denom)), 0)) {
- Expr* scalar = evaluateOp(new Div(other, denom));
- Expr* newDenom = evaluateOp(new Div(other, scalar));
- return new Term(hasher_, scalar, new RoundOff(div->lhs(), newDenom));
+ if (immediateEquals(evaluateOp(alloc<Mod>(other, denom)), 0)) {
+ ExprPtr scalar = evaluateOp(alloc<Div>(other, denom));
+ ExprPtr newDenom = evaluateOp(alloc<Div>(other, scalar));
+ return alloc<Term>(
+ hasher_, scalar, alloc<RoundOff>(div->lhs(), newDenom));
}
}
}
// Inserts a new component into a term, looking for opportunities to simplify.
-Expr* PolynomialTransformer::insertIntoTerm(Term* term, Expr* expr) {
- std::vector<Expr*> vars;
+ExprPtr PolynomialTransformer::insertIntoTerm(TermPtr term, ExprPtr expr) {
+ std::vector<ExprPtr> vars;
// Search for RoundOffs.
bool merged{false};
- for (auto* component : term->variables()) {
- if (auto* roundoff = isRoundOff(component, expr)) {
+ for (auto component : term->variables()) {
+ if (auto roundoff = isRoundOff(component, expr)) {
vars.push_back(roundoff);
merged = true;
} else {
return vars[0];
}
- return new Term(hasher_, term->scalar(), vars);
+ return alloc<Term>(hasher_, term->scalar(), vars);
}
-Expr* PolynomialTransformer::mutate(Mul* v) {
- Expr* lhs_new = v->lhs()->accept_mutator(this);
- Expr* rhs_new = v->rhs()->accept_mutator(this);
+ExprPtr PolynomialTransformer::mutate(MulPtr v) {
+ ExprPtr lhs_new = v->lhs()->accept_mutator(this);
+ ExprPtr rhs_new = v->rhs()->accept_mutator(this);
// Constant Folding.
if (lhs_new->isConstant() && rhs_new->isConstant()) {
- return evaluateOp(new Mul(lhs_new, rhs_new));
+ return evaluateOp(alloc<Mul>(lhs_new, rhs_new));
}
// Multilane folding.
if (isMultilanePrimitive(lhs_new)) {
- if (auto* ret = mulMultilane(lhs_new, rhs_new)) {
+ if (auto ret = mulMultilane(lhs_new, rhs_new)) {
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
return ret->accept_mutator(this);
}
}
// Order doesn't matter.
- Expr* scalar = nullptr;
- Expr* variable = nullptr;
+ ExprPtr scalar = nullptr;
+ ExprPtr variable = nullptr;
if (lhs_new->isConstant()) {
scalar = lhs_new;
variable = rhs_new;
// Handle special case mul by 1 since thats safe for floating point, even if
// it's Nan/Inf.
if (scalar && immediateEquals(scalar, 1)) {
- auto* c = new Cast(v->dtype(), variable);
+ auto c = alloc<Cast>(v->dtype(), variable);
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
return c->accept_mutator(this);
}
// dont want to combine ops.
if (lhs_new->dtype().is_floating_point() ||
rhs_new->dtype().is_floating_point()) {
- return new Mul(lhs_new, rhs_new);
+ return alloc<Mul>(lhs_new, rhs_new);
}
// Handle special case mul by 0.
}
// Catch cases of rounding (Div(A/B) * B).
- if (auto* ret = isRoundOff(lhs_new, rhs_new)) {
+ if (auto ret = isRoundOff(lhs_new, rhs_new)) {
return ret;
- } else if (auto* ret = isRoundOff(v->lhs(), v->rhs())) {
+ } else if (auto ret = isRoundOff(v->lhs(), v->rhs())) {
// We can break the Round + Mod pattern via factorization of the Div, so
// check whether it would have worked on the unsimplified tree. If so, we
// need to simplify again.
return ret->accept_mutator(this);
}
- Polynomial* lhsPoly = dynamic_cast<Polynomial*>(lhs_new);
- Polynomial* rhsPoly = dynamic_cast<Polynomial*>(rhs_new);
+ PolynomialPtr lhsPoly = to<Polynomial>(lhs_new);
+ PolynomialPtr rhsPoly = to<Polynomial>(rhs_new);
if (lhsPoly && rhsPoly) {
// This expands to more terms that we can't generally fix without variable
// factorization, it's more efficient to just leave these as Muls.
- return new Mul(lhsPoly, rhsPoly);
+ return alloc<Mul>(lhsPoly, rhsPoly);
}
- Term* lhsTerm = dynamic_cast<Term*>(lhs_new);
- Term* rhsTerm = dynamic_cast<Term*>(rhs_new);
+ TermPtr lhsTerm = to<Term>(lhs_new);
+ TermPtr rhsTerm = to<Term>(rhs_new);
if (lhsPoly && rhsTerm) {
return polyByTerm(lhsPoly, rhsTerm);
}
if (scalar && lhsTerm) {
- Expr* newScalar = evaluateOp(new Mul(scalar, lhsTerm->scalar()));
- return new Term(hasher_, newScalar, lhsTerm->variables());
+ ExprPtr newScalar = evaluateOp(alloc<Mul>(scalar, lhsTerm->scalar()));
+ return alloc<Term>(hasher_, newScalar, lhsTerm->variables());
}
if (scalar && rhsTerm) {
- Expr* newScalar = evaluateOp(new Mul(scalar, rhsTerm->scalar()));
- return new Term(hasher_, newScalar, rhsTerm->variables());
+ ExprPtr newScalar = evaluateOp(alloc<Mul>(scalar, rhsTerm->scalar()));
+ return alloc<Term>(hasher_, newScalar, rhsTerm->variables());
}
// If this is a scalar * a Polynomial, push the scalar term down.
// We can wrap the scalar with a Term and use polyByTerm.
if (scalar && lhsPoly) {
- return polyByTerm(lhsPoly, new Term(hasher_, scalar));
+ return polyByTerm(lhsPoly, alloc<Term>(hasher_, scalar));
}
if (scalar && rhsPoly) {
- return polyByTerm(rhsPoly, new Term(hasher_, scalar));
+ return polyByTerm(rhsPoly, alloc<Term>(hasher_, scalar));
}
// simple term with a scalar and variable type.
if (scalar) {
- return new Term(hasher_, scalar, variable);
+ return alloc<Term>(hasher_, scalar, variable);
}
// Multiplying Polynomial by variable can be wrapped in a term and handled
// by polyByTerm also.
if (lhsPoly) {
- auto* term =
- new Term(hasher_, getImmediateByType(rhs_new->dtype(), 1), rhs_new);
+ auto term =
+ alloc<Term>(hasher_, getImmediateByType(rhs_new->dtype(), 1), rhs_new);
return polyByTerm(lhsPoly, term);
}
if (rhsPoly) {
- auto* term =
- new Term(hasher_, getImmediateByType(lhs_new->dtype(), 1), lhs_new);
+ auto term =
+ alloc<Term>(hasher_, getImmediateByType(lhs_new->dtype(), 1), lhs_new);
return polyByTerm(rhsPoly, term);
}
}
// Two variables, create a new Term.
- return new Term(hasher_, getImmediateByType(v->dtype(), 1), lhs_new, rhs_new);
+ return alloc<Term>(
+ hasher_, getImmediateByType(v->dtype(), 1), lhs_new, rhs_new);
}
-Expr* factorizeDivision(Expr* lhs_new, Expr* rhs_new) {
+ExprPtr factorizeDivision(ExprPtr lhs_new, ExprPtr rhs_new) {
if (!lhs_new || !rhs_new) {
return nullptr;
}
- Expr* leftScalar = lhs_new->isConstant() ? lhs_new : nullptr;
- Expr* rightScalar = rhs_new->isConstant() ? rhs_new : nullptr;
+ ExprPtr leftScalar = lhs_new->isConstant() ? lhs_new : nullptr;
+ ExprPtr rightScalar = rhs_new->isConstant() ? rhs_new : nullptr;
- auto* lhsTerm = dynamic_cast<Term*>(lhs_new);
- auto* rhsTerm = dynamic_cast<Term*>(rhs_new);
+ auto lhsTerm = to<Term>(lhs_new);
+ auto rhsTerm = to<Term>(rhs_new);
if (lhsTerm) {
leftScalar = lhsTerm->scalar();
}
}
leftScalar = evaluateOp(
- new Div(leftScalar, getImmediateByType(leftScalar->dtype(), GCD)));
+ alloc<Div>(leftScalar, getImmediateByType(leftScalar->dtype(), GCD)));
rightScalar = evaluateOp(
- new Div(rightScalar, getImmediateByType(rightScalar->dtype(), GCD)));
+ alloc<Div>(rightScalar, getImmediateByType(rightScalar->dtype(), GCD)));
if (lhsTerm) {
- lhs_new = new Term(lhsTerm->hasher(), leftScalar, lhsTerm->variables());
+ lhs_new = alloc<Term>(lhsTerm->hasher(), leftScalar, lhsTerm->variables());
} else {
lhs_new = leftScalar;
}
if (rhsTerm) {
- rhs_new = new Term(rhsTerm->hasher(), rightScalar, rhsTerm->variables());
+ rhs_new = alloc<Term>(rhsTerm->hasher(), rightScalar, rhsTerm->variables());
} else {
rhs_new = rightScalar;
}
- return new Div(lhs_new, rhs_new);
+ return alloc<Div>(lhs_new, rhs_new);
}
-Expr* PolynomialTransformer::mutate(Div* v) {
- Expr* lhs_new = v->lhs()->accept_mutator(this);
- Expr* rhs_new = v->rhs()->accept_mutator(this);
+ExprPtr PolynomialTransformer::mutate(DivPtr v) {
+ ExprPtr lhs_new = v->lhs()->accept_mutator(this);
+ ExprPtr rhs_new = v->rhs()->accept_mutator(this);
// Constant Folding.
if (lhs_new->isConstant() && rhs_new->isConstant()) {
- return evaluateOp(new Div(lhs_new, rhs_new));
+ return evaluateOp(alloc<Div>(lhs_new, rhs_new));
}
// If this is a floating point Div then order of operations is important, we
// dont want to combine ops.
if (lhs_new->dtype().is_floating_point() ||
rhs_new->dtype().is_floating_point()) {
- return new Div(lhs_new, rhs_new);
+ return alloc<Div>(lhs_new, rhs_new);
}
// If the numerator is zero, so is the result.
return ret->accept_mutator(this);
}
- return new Div(lhs_new, rhs_new);
+ return alloc<Div>(lhs_new, rhs_new);
}
-Expr* PolynomialTransformer::mutate(Mod* v) {
- Expr* lhs_new = v->lhs()->accept_mutator(this);
- Expr* rhs_new = v->rhs()->accept_mutator(this);
+ExprPtr PolynomialTransformer::mutate(ModPtr v) {
+ ExprPtr lhs_new = v->lhs()->accept_mutator(this);
+ ExprPtr rhs_new = v->rhs()->accept_mutator(this);
// Constant Folding.
if (lhs_new->isConstant() && rhs_new->isConstant()) {
- return evaluateOp(new Mod(lhs_new, rhs_new));
+ return evaluateOp(alloc<Mod>(lhs_new, rhs_new));
}
// 0 % x => 0.
return getImmediateByType(v->dtype(), 0);
}
- Term* lhsTerm = dynamic_cast<Term*>(lhs_new);
+ TermPtr lhsTerm = to<Term>(lhs_new);
if (!lhsTerm) {
- Polynomial* lhsPoly = dynamic_cast<Polynomial*>(lhs_new);
+ PolynomialPtr lhsPoly = to<Polynomial>(lhs_new);
if (lhsPoly) {
// Can still optimize this out if we can factorize the polynomial.
lhsTerm = factorizePolynomial(lhsPoly);
if (lhsTerm) {
// ((C1 * C2) * x) % C1 => 0.
if (rhs_new->isConstant() &&
- immediateEquals(evaluateOp(new Mod(lhsTerm->scalar(), rhs_new)), 0)) {
+ immediateEquals(
+ evaluateOp(alloc<Mod>(lhsTerm->scalar(), rhs_new)), 0)) {
return getImmediateByType(v->dtype(), 0);
}
// (x * y * z) % x => 0.
- for (auto* component : lhsTerm->variables()) {
+ for (auto component : lhsTerm->variables()) {
if (hasher_.hash(component) == hasher_.hash(rhs_new)) {
return getImmediateByType(v->dtype(), 0);
}
// also, (x * y * z) % (z * y) => 0.
// This requires all variable terms found in the RHS to be present in the
// LHS.
- Term* rhsTerm = dynamic_cast<Term*>(rhs_new);
+ TermPtr rhsTerm = to<Term>(rhs_new);
if (rhsTerm) {
auto& lVars = lhsTerm->variables();
auto& rVars = rhsTerm->variables();
if (rLeft == 0 &&
immediateEquals(
- evaluateOp(new Mod(lhsTerm->scalar(), rhsTerm->scalar())), 0)) {
+ evaluateOp(alloc<Mod>(lhsTerm->scalar(), rhsTerm->scalar())),
+ 0)) {
return getImmediateByType(v->dtype(), 0);
}
}
}
- return new Mod(lhs_new, rhs_new);
+ return alloc<Mod>(lhs_new, rhs_new);
}
namespace {
// The first type on the template refers to the op, as in Min or Max and the
// second type refers to the corresponding term, as in MinTerm or MaxTerm.
template <class Op, class OpTerm>
-Expr* combineMinMaxTerms(
- Expr* lhs,
- Expr* rhs,
+ExprPtr combineMinMaxTerms(
+ ExprPtr lhs,
+ ExprPtr rhs,
bool propagate_nans,
HashProvider& hasher) {
- auto combine_scalars = [&](Expr* c1, Expr* c2) -> Expr* {
+ auto combine_scalars = [&](ExprPtr c1, ExprPtr c2) -> ExprPtr {
if (c1 && c2) {
- return evaluateOp(new Op(c1, c2, propagate_nans));
+ return evaluateOp(alloc<Op>(c1, c2, propagate_nans));
}
if (c1) {
return c1;
return c2;
};
- auto combine_opterms = [&](OpTerm* m1, OpTerm* m2) {
- Expr* scalar = combine_scalars(m1->scalar(), m2->scalar());
- std::vector<Expr*> variables;
+ auto combine_opterms = [&](NodePtr<OpTerm> m1, NodePtr<OpTerm> m2) {
+ ExprPtr scalar = combine_scalars(m1->scalar(), m2->scalar());
+ std::vector<ExprPtr> variables;
for (auto v : m1->variables()) {
variables.push_back(v);
}
for (auto v : m2->variables()) {
variables.push_back(v);
}
- return new OpTerm(hasher, scalar, propagate_nans, std::move(variables));
+ return alloc<OpTerm>(hasher, scalar, propagate_nans, std::move(variables));
};
- auto add_expr_to_opterm = [&](Expr* expr, OpTerm* opterm) {
- Expr* scalar = nullptr;
- std::vector<Expr*> variables;
+ auto add_expr_to_opterm = [&](ExprPtr expr, NodePtr<OpTerm> opterm) {
+ ExprPtr scalar = nullptr;
+ std::vector<ExprPtr> variables;
if (opterm) {
scalar = opterm->scalar();
variables = opterm->variables();
} else {
variables.push_back(expr);
}
- return new OpTerm(hasher, scalar, propagate_nans, std::move(variables));
+ return alloc<OpTerm>(hasher, scalar, propagate_nans, std::move(variables));
};
- OpTerm* lhs_opterm = dynamic_cast<OpTerm*>(lhs);
- OpTerm* rhs_opterm = dynamic_cast<OpTerm*>(rhs);
+ auto lhs_opterm = to<OpTerm>(lhs);
+ auto rhs_opterm = to<OpTerm>(rhs);
if (lhs_opterm && lhs_opterm->propagate_nans() != propagate_nans) {
- return new Op(lhs, rhs, propagate_nans);
+ return alloc<Op>(lhs, rhs, propagate_nans);
}
if (rhs_opterm && rhs_opterm->propagate_nans() != propagate_nans) {
- return new Op(lhs, rhs, propagate_nans);
+ return alloc<Op>(lhs, rhs, propagate_nans);
}
if (lhs_opterm && rhs_opterm) {
// the other op of opterm in other_op.
template <class OpTerm>
bool isOperandInMinMaxTerm(
- OpTerm* opterm,
- Expr* op,
+ NodePtr<OpTerm> opterm,
+ ExprPtr op,
HashProvider& hasher,
- Expr** other_op) {
+ ExprPtr* other_op) {
if (opterm->variables().size() != 2) {
return false;
}
// type corresponding to the expected inner op (e.g. MinTerm).
template <class OpTerm, class OtherOpTerm>
bool simplifyNestedMinMax(
- Expr* lhs,
- Expr* rhs,
+ ExprPtr lhs,
+ ExprPtr rhs,
bool propagate_nans,
HashProvider& hasher,
- Expr** new_op) {
- auto lhs_opterm = dynamic_cast<OtherOpTerm*>(lhs);
- auto rhs_opterm = dynamic_cast<OtherOpTerm*>(rhs);
+ ExprPtr* new_op) {
+ auto lhs_opterm = to<OtherOpTerm>(lhs);
+ auto rhs_opterm = to<OtherOpTerm>(rhs);
if (lhs_opterm && rhs_opterm &&
lhs_opterm->propagate_nans() == propagate_nans &&
rhs_opterm->propagate_nans() == propagate_nans) {
auto rhs_v1 = rhs_opterm->variables()[0];
auto rhs_v2 = rhs_opterm->variables()[1];
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- Expr* new_op_lhs;
+ ExprPtr new_op_lhs;
if (isOperandInMinMaxTerm<OtherOpTerm>(
lhs_opterm, rhs_v1, hasher, &new_op_lhs)) {
- auto inner_op =
- new OpTerm(hasher, nullptr, propagate_nans, new_op_lhs, rhs_v2);
- *new_op = new OtherOpTerm(
+ auto inner_op = alloc<OpTerm>(
+ hasher, nullptr, propagate_nans, new_op_lhs, rhs_v2);
+ *new_op = alloc<OtherOpTerm>(
hasher, nullptr, propagate_nans, rhs_v1, inner_op);
return true;
}
if (isOperandInMinMaxTerm<OtherOpTerm>(
lhs_opterm, rhs_v2, hasher, &new_op_lhs)) {
- auto inner_op =
- new OpTerm(hasher, nullptr, propagate_nans, new_op_lhs, rhs_v1);
- *new_op = new OtherOpTerm(
+ auto inner_op = alloc<OpTerm>(
+ hasher, nullptr, propagate_nans, new_op_lhs, rhs_v1);
+ *new_op = alloc<OtherOpTerm>(
hasher, nullptr, propagate_nans, rhs_v2, inner_op);
return true;
}
} // namespace
-Expr* PolynomialTransformer::mutate(Max* v) {
- Expr* lhs_new = v->lhs()->accept_mutator(this);
- Expr* rhs_new = v->rhs()->accept_mutator(this);
+ExprPtr PolynomialTransformer::mutate(MaxPtr v) {
+ ExprPtr lhs_new = v->lhs()->accept_mutator(this);
+ ExprPtr rhs_new = v->rhs()->accept_mutator(this);
// Constant Folding.
if (lhs_new->isConstant() && rhs_new->isConstant()) {
- return evaluateOp(new Max(lhs_new, rhs_new, v->propagate_nans()));
+ return evaluateOp(alloc<Max>(lhs_new, rhs_new, v->propagate_nans()));
}
// If diff is constant, return the appropriate operand.
- Expr* diff = new Sub(lhs_new, rhs_new);
+ ExprPtr diff = alloc<Sub>(lhs_new, rhs_new);
diff = diff->accept_mutator(this);
if (diff->isConstant()) {
if (immediateAs<int>(diff) > 0) {
// Max(Min(x, y), Min(x, z)) => Min(x, Max(y, z))
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- Expr* new_op;
+ ExprPtr new_op;
if (simplifyNestedMinMax<MaxTerm, MinTerm>(
lhs_new, rhs_new, v->propagate_nans(), hasher_, &new_op)) {
return new_op;
lhs_new, rhs_new, v->propagate_nans(), hasher_);
}
-Expr* PolynomialTransformer::mutate(Min* v) {
- Expr* lhs_new = v->lhs()->accept_mutator(this);
- Expr* rhs_new = v->rhs()->accept_mutator(this);
+ExprPtr PolynomialTransformer::mutate(MinPtr v) {
+ ExprPtr lhs_new = v->lhs()->accept_mutator(this);
+ ExprPtr rhs_new = v->rhs()->accept_mutator(this);
// Constant Folding.
if (lhs_new->isConstant() && rhs_new->isConstant()) {
- return evaluateOp(new Min(lhs_new, rhs_new, v->propagate_nans()));
+ return evaluateOp(alloc<Min>(lhs_new, rhs_new, v->propagate_nans()));
}
// If diff is constant, return the appropriate operand.
- Expr* diff = new Sub(lhs_new, rhs_new);
+ ExprPtr diff = alloc<Sub>(lhs_new, rhs_new);
diff = diff->accept_mutator(this);
if (diff->isConstant()) {
if (immediateAs<int>(diff) < 0) {
// Min(Max(x, y), Max(x, z)) => Max(x, Min(y, z))
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- Expr* new_op;
+ ExprPtr new_op;
if (simplifyNestedMinMax<MinTerm, MaxTerm>(
lhs_new, rhs_new, v->propagate_nans(), hasher_, &new_op)) {
return new_op;
lhs_new, rhs_new, v->propagate_nans(), hasher_);
}
-Expr* PolynomialTransformer::mutate(CompareSelect* v) {
- Expr* lhs_new = v->lhs()->accept_mutator(this);
- Expr* rhs_new = v->rhs()->accept_mutator(this);
- Expr* true_branch = v->ret_val1()->accept_mutator(this);
- Expr* false_branch = v->ret_val2()->accept_mutator(this);
+ExprPtr PolynomialTransformer::mutate(CompareSelectPtr v) {
+ ExprPtr lhs_new = v->lhs()->accept_mutator(this);
+ ExprPtr rhs_new = v->rhs()->accept_mutator(this);
+ ExprPtr true_branch = v->ret_val1()->accept_mutator(this);
+ ExprPtr false_branch = v->ret_val2()->accept_mutator(this);
// Constant Folding.
if (lhs_new->isConstant() && rhs_new->isConstant()) {
- Expr* v_new = new CompareSelect(
+ ExprPtr v_new = alloc<CompareSelect>(
lhs_new,
rhs_new,
true_branch,
// since we can't correctly handle NaN.
if (lhs_new->dtype().is_floating_point() ||
rhs_new->dtype().is_floating_point()) {
- return new CompareSelect(
+ return alloc<CompareSelect>(
lhs_new,
rhs_new,
true_branch,
}
// If diff is constant, we can determine it.
- Expr* diff = new Sub(rhs_new, lhs_new);
+ ExprPtr diff = alloc<Sub>(rhs_new, lhs_new);
diff = diff->accept_mutator(this);
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
if (!diff->isConstant()) {
- return new CompareSelect(
+ return alloc<CompareSelect>(
lhs_new,
rhs_new,
true_branch,
}
// should not be possible but just in case.
- return new CompareSelect(
+ return alloc<CompareSelect>(
lhs_new,
rhs_new,
true_branch,
v->bias());
}
-Expr* PolynomialTransformer::mutate(Intrinsics* v) {
- std::vector<Expr*> new_params;
+ExprPtr PolynomialTransformer::mutate(IntrinsicsPtr v) {
+ std::vector<ExprPtr> new_params;
bool changed = false;
bool allConstant = true;
- for (auto* p : v->params()) {
- Expr* new_child = p->accept_mutator(this);
+ for (auto p : v->params()) {
+ ExprPtr new_child = p->accept_mutator(this);
new_params.push_back(new_child);
changed |= p != new_child;
allConstant &= new_child->isConstant();
}
- Expr* node = v;
+ ExprPtr node = v;
if (changed) {
- node = new Intrinsics(v->op_type(), new_params);
+ node = alloc<Intrinsics>(v->op_type(), new_params);
}
if (!allConstant || !v->isPure()) {
}
// we're evaluating, but the evaluator only supports float intrinsics.
- std::vector<Expr*> const_params;
+ std::vector<ExprPtr> const_params;
changed = false;
- for (auto* p : new_params) {
+ for (auto p : new_params) {
if (p->dtype().scalar_type() == ScalarType::Float) {
const_params.push_back(p);
} else {
const_params.push_back(
- new Cast(Dtype(ScalarType::Float, p->dtype().lanes()), p));
+ alloc<Cast>(Dtype(ScalarType::Float, p->dtype().lanes()), p));
changed = true;
}
}
if (changed) {
- node = new Intrinsics(v->op_type(), const_params);
+ node = alloc<Intrinsics>(v->op_type(), const_params);
}
return evaluateOp(node);
}
-Expr* PolynomialTransformer::mutate(Cast* v) {
- Expr* node = v->src_value()->accept_mutator(this);
+ExprPtr PolynomialTransformer::mutate(CastPtr v) {
+ ExprPtr node = v->src_value()->accept_mutator(this);
if (node->isConstant()) {
- return evaluateOp(new Cast(v->dtype(), node));
+ return evaluateOp(alloc<Cast>(v->dtype(), node));
}
if (v->dtype() == node->dtype()) {
return node;
}
- return new Cast(v->dtype(), node);
+ return alloc<Cast>(v->dtype(), node);
}
-Expr* PolynomialTransformer::mutate(IfThenElse* v) {
- Expr* condition = v->condition();
- Expr* true_value = v->true_value();
- Expr* false_value = v->false_value();
- Expr* condition_new = condition->accept_mutator(this);
- Expr* true_value_new = true_value->accept_mutator(this);
- Expr* false_value_new = false_value->accept_mutator(this);
+ExprPtr PolynomialTransformer::mutate(IfThenElsePtr v) {
+ ExprPtr condition = v->condition();
+ ExprPtr true_value = v->true_value();
+ ExprPtr false_value = v->false_value();
+ ExprPtr condition_new = condition->accept_mutator(this);
+ ExprPtr true_value_new = true_value->accept_mutator(this);
+ ExprPtr false_value_new = false_value->accept_mutator(this);
// If the condition is constant then we can choose the right branch now.
if (condition_new->isConstant()) {
return v;
}
- return new IfThenElse(condition_new, true_value_new, false_value_new);
+ return alloc<IfThenElse>(condition_new, true_value_new, false_value_new);
}
-Stmt* PolynomialBase::mutate(Cond* v) {
- Expr* cond_old = v->condition();
- Stmt* true_old = v->true_stmt();
- Stmt* false_old = v->false_stmt();
+StmtPtr PolynomialBase::mutate(CondPtr v) {
+ ExprPtr cond_old = v->condition();
+ StmtPtr true_old = v->true_stmt();
+ StmtPtr false_old = v->false_stmt();
- Expr* cond_new = cond_old->accept_mutator(this);
- Stmt* true_new = true_old ? true_old->accept_mutator(this) : true_old;
- Stmt* false_new = false_old ? false_old->accept_mutator(this) : false_old;
+ ExprPtr cond_new = cond_old->accept_mutator(this);
+ StmtPtr true_new = true_old ? true_old->accept_mutator(this) : true_old;
+ StmtPtr false_new = false_old ? false_old->accept_mutator(this) : false_old;
// If the condition is constant then we can choose the right branch now.
if (cond_new->isConstant()) {
return true_new;
}
- Block* true_block = dynamic_cast<Block*>(true_new);
- Block* false_block = dynamic_cast<Block*>(false_new);
+ BlockPtr true_block = to<Block>(true_new);
+ BlockPtr false_block = to<Block>(false_new);
bool true_empty = !true_new || (true_block && true_block->nstmts() == 0);
bool false_empty = !false_new || (false_block && false_block->nstmts() == 0);
if (true_empty && false_empty) {
- return new Block({});
+ return alloc<Block>(std::vector<StmtPtr>({}));
}
if (cond_old != cond_new) {
v->set_condition(cond_new);
return v;
}
-Stmt* handleForCondReordering(For* loop, Cond* cond) {
+StmtPtr handleForCondReordering(ForPtr loop, CondPtr cond) {
if (cond->false_stmt()) {
return nullptr;
}
auto condition_vars = VarFinder::find(cond->condition());
- for (auto* v : condition_vars) {
+ for (auto v : condition_vars) {
// If the condition depends on a Var that is modified in the loop body, it
// may not be safe to reorder.
if (ModifiesVarChecker::check(loop, v)) {
}
}
- For* new_f = loop->cloneWithNewBody(Stmt::clone(cond->true_stmt()));
+ ForPtr new_f = loop->cloneWithNewBody(Stmt::clone(cond->true_stmt()));
return cond->cloneWithNewBody(new_f);
}
-Stmt* PolynomialBase::mutate(For* v) {
- Expr* var = v->var();
- Expr* start = v->start();
- Expr* stop = v->stop();
- Stmt* body = v->body();
+StmtPtr PolynomialBase::mutate(ForPtr v) {
+ ExprPtr var = v->var();
+ ExprPtr start = v->start();
+ ExprPtr stop = v->stop();
+ StmtPtr body = v->body();
LoopOptions loop_options = v->loop_options();
- Expr* var_new_expr = var->accept_mutator(this);
- Var* var_new = dynamic_cast<Var*>(var_new_expr);
- Expr* start_new = start->accept_mutator(this);
- Expr* stop_new = stop->accept_mutator(this);
- Stmt* body_new = body;
+ ExprPtr var_new_expr = var->accept_mutator(this);
+ VarPtr var_new = to<Var>(var_new_expr);
+ ExprPtr start_new = start->accept_mutator(this);
+ ExprPtr stop_new = stop->accept_mutator(this);
+ StmtPtr body_new = body;
- Expr* loops = new Sub(stop_new, start_new);
+ ExprPtr loops = alloc<Sub>(stop_new, start_new);
loops = loops->accept_mutator(this);
if (loop_options.isDefault() && loops->isConstant()) {
if (immediateEquals(loops, 0)) {
- return new Block({});
+ return alloc<Block>(std::vector<StmtPtr>({}));
} else if (immediateEquals(loops, 1)) {
body_new = Substitute(body, {{var_new, start_new}});
body_new = body_new->accept_mutator(this);
body_new = body_new->accept_mutator(this);
if (!body_new) {
- return new Block({});
+ return alloc<Block>(std::vector<StmtPtr>({}));
}
- if (auto* block = dynamic_cast<Block*>(body_new)) {
+ if (auto block = to<Block>(body_new)) {
if (block->nstmts() == 0) {
- return new Block({});
+ return alloc<Block>(std::vector<StmtPtr>({}));
}
if (block->nstmts() == 1) {
- if (auto* cond = dynamic_cast<Cond*>(block->front())) {
- Stmt* reordered = handleForCondReordering(v, cond);
+ if (auto cond = to<Cond>(block->front())) {
+ StmtPtr reordered = handleForCondReordering(v, cond);
if (reordered) {
return reordered->accept_mutator(this);
}
return v;
}
-Stmt* PolynomialBase::mutate(Block* v) {
- std::vector<Stmt*> stmts;
+StmtPtr PolynomialBase::mutate(BlockPtr v) {
+ std::vector<StmtPtr> stmts;
// Flatten sub-blocks:
bool stmts_changed = false;
- for (Stmt* stmt : *v) {
- Stmt* stmt_new = stmt->accept_mutator(this);
+ for (StmtPtr stmt : *v) {
+ StmtPtr stmt_new = stmt->accept_mutator(this);
stmts_changed |= stmt != stmt_new;
if (stmt_new == nullptr) {
continue;
}
- if (auto* subBlock = dynamic_cast<Block*>(stmt_new)) {
+ if (auto subBlock = to<Block>(stmt_new)) {
for (Block::iterator I = subBlock->begin(), E = subBlock->end();
I != E;) {
// Be careful to avoid invalidating the iterator.
- Stmt* s = *(I++);
+ StmtPtr s = *(I++);
subBlock->remove_stmt(s);
stmts.push_back(s);
}
// TermExpander
-Expr* TermExpander::mutate(Term* v) {
- Expr* newScalar = v->scalar()->accept_mutator(this);
+ExprPtr TermExpander::mutate(TermPtr v) {
+ ExprPtr newScalar = v->scalar()->accept_mutator(this);
if (immediateEquals(newScalar, 0)) {
return newScalar;
}
- std::vector<Expr*> vars;
- std::vector<Expr*> multilaneVars;
+ std::vector<ExprPtr> vars;
+ std::vector<ExprPtr> multilaneVars;
// Assume we can reorder here because we wont merge floating terms.
- Expr* lastNode{nullptr};
- for (auto* var : v->variables()) {
- Expr* node = var->accept_mutator(this);
- if (Mul* mul = dynamic_cast<Mul*>(node)) {
+ ExprPtr lastNode{nullptr};
+ for (auto var : v->variables()) {
+ ExprPtr node = var->accept_mutator(this);
+ if (MulPtr mul = to<Mul>(node)) {
// If the sub-Expr resolved to a multiplication, lift it into this
// term.
if (isMultilanePrimitive(mul->lhs())) {
}
}
- for (auto* node : multilaneVars) {
+ for (auto node : multilaneVars) {
if (lastNode == nullptr) {
lastNode = node;
} else {
}
}
- for (auto* node : vars) {
+ for (auto node : vars) {
if (lastNode == nullptr) {
lastNode = node;
} else {
- lastNode = new Mul(lastNode, node);
+ lastNode = alloc<Mul>(lastNode, node);
}
}
auto termDtype = v->scalar()->dtype();
auto lastNodeDtype = lastNode->dtype();
if (termDtype != lastNodeDtype) {
- Expr* castV = v->scalar();
+ ExprPtr castV = v->scalar();
// Take care of lane mismatch first.
if (termDtype.lanes() != lastNodeDtype.lanes()) {
- castV = new Broadcast(v->scalar(), lastNodeDtype.lanes());
+ castV = alloc<Broadcast>(v->scalar(), lastNodeDtype.lanes());
}
// Now take care of scalar type as well.
if (termDtype.scalar_type() != lastNodeDtype.scalar_type()) {
- castV = new Cast(lastNode->dtype(), castV);
+ castV = alloc<Cast>(lastNode->dtype(), castV);
// For scalars, we can simplify the cast further.
if (lastNodeDtype.lanes() == 1) {
castV = evaluateOp(castV);
}
}
- lastNode = new Mul(castV, lastNode);
+ lastNode = alloc<Mul>(castV, lastNode);
} else {
- lastNode = new Mul(v->scalar(), lastNode);
+ lastNode = alloc<Mul>(v->scalar(), lastNode);
}
} else {
lastNode = v->scalar();
// Returns an immediate containing the greatest common divisor of all terms
// (inc. the scalar term) in the polynomial. If the GCD is uninteresting
// (e.g. 1) then returns nullptr.
-Expr* polyGCD(Polynomial* poly) {
- Expr* scalar = poly->scalar();
- const std::vector<Term*>& variables = poly->variables();
+ExprPtr polyGCD(PolynomialPtr poly) {
+ ExprPtr scalar = poly->scalar();
+ const std::vector<TermPtr>& variables = poly->variables();
// We ony want to factorize if we're saving complete operations, i.e. no
// value in factorizing 6x + 4y into 2 * (3x + 2y) since we don't save work.
int opsSaved = 1; // default to saving the scalar.
long GCD = std::abs(immediateAs<long>(scalar));
- for (auto* t : variables) {
+ for (auto t : variables) {
long termScalar = std::abs(immediateAs<long>(t->scalar()));
long newGCD = gcd(std::max(GCD, termScalar), std::min(GCD, termScalar));
if (newGCD == 1) {
// denotes x, 'divisor' denotes y and 'mod_divisor' denotes z.
class ModRound {
public:
- ModRound(Expr* scalar, Expr* denom, Expr* divisor, Expr* mod_divisor)
+ ModRound(ExprPtr scalar, ExprPtr denom, ExprPtr divisor, ExprPtr mod_divisor)
: scalar(scalar),
denom(denom),
divisor(divisor),
mod_divisor(mod_divisor) {}
- Expr* scalar;
- Expr* denom;
- Expr* divisor;
- Expr* mod_divisor;
+ ExprPtr scalar;
+ ExprPtr denom;
+ ExprPtr divisor;
+ ExprPtr mod_divisor;
};
-c10::optional<class ModRound*> isModRound(Term* e) {
- Div* div{nullptr};
- Mod* mod{nullptr};
- Expr* denom{nullptr};
- Expr* divisor{nullptr};
- Expr* mod_divisor{nullptr};
- Expr* multiplier = e->scalar();
- Expr* scalar{nullptr};
- Expr* other{nullptr};
-
- for (auto* m : e->variables()) {
+c10::optional<class ModRound*> isModRound(TermPtr e) {
+ DivPtr div{nullptr};
+ ModPtr mod{nullptr};
+ ExprPtr denom{nullptr};
+ ExprPtr divisor{nullptr};
+ ExprPtr mod_divisor{nullptr};
+ ExprPtr multiplier = e->scalar();
+ ExprPtr scalar{nullptr};
+ ExprPtr other{nullptr};
+
+ for (auto m : e->variables()) {
if (m->expr_type() == IRNodeType::kMod) {
// TODO: currently only identify terms with one variable being mod; it is
// possible to extend this if we have to handle terms like (t/(x%2 * y) %
// z) * (x%2 *y).
if (!mod) {
- mod = dynamic_cast<Mod*>(m);
+ mod = to<Mod>(m);
} else {
return c10::nullopt;
}
if (multiplier->isConstant()) {
// Take care of lane mismatch first.
if (multiplier->dtype().lanes() != m->dtype().lanes()) {
- multiplier = new Broadcast(multiplier, m->dtype().lanes());
+ multiplier = alloc<Broadcast>(multiplier, m->dtype().lanes());
}
// Take care of scalar type mismatch.
if (multiplier->dtype().scalar_type() != m->dtype().scalar_type()) {
- multiplier = new Cast(m->dtype(), multiplier);
+ multiplier = alloc<Cast>(m->dtype(), multiplier);
if (m->dtype().lanes() == 1) {
multiplier = evaluateOp(multiplier);
}
}
// All non-mod vairables are considered as part of the multiplier.
- multiplier = new Mul(multiplier, m);
+ multiplier = alloc<Mul>(multiplier, m);
}
}
multiplier = IRSimplifier::simplify(multiplier);
mod_divisor = IRSimplifier::simplify(mod->rhs());
other = mod->lhs();
- if (!(div = dynamic_cast<Div*>(other))) {
+ if (!(div = to<Div>(other))) {
return c10::nullopt;
}
// transformations.
if (divisor->isConstant() && multiplier->isConstant()) {
// If both are scalar we may be able to find a common factor.
- if (immediateEquals(evaluateOp(new Mod(multiplier, divisor)), 0)) {
+ if (immediateEquals(evaluateOp(alloc<Mod>(multiplier, divisor)), 0)) {
// The common factor becomes 'scalar' of the term, e.g.,in t/3%7*6,
// divisor=multiplier=3, scalar=2.
- Expr* c = evaluateOp(new Div(multiplier, divisor));
+ ExprPtr c = evaluateOp(alloc<Div>(multiplier, divisor));
scalar = c;
- } else if (immediateEquals(evaluateOp(new Mod(divisor, multiplier)), 0)) {
+ } else if (immediateEquals(
+ evaluateOp(alloc<Mod>(divisor, multiplier)), 0)) {
// The common factor becomes part of 'denom', e.g., in t/14%7*2,
// divisor=multiplier=2, denom=t/7.
- Expr* c = evaluateOp(new Div(divisor, multiplier));
+ ExprPtr c = evaluateOp(alloc<Div>(divisor, multiplier));
divisor = multiplier;
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
- denom = IRSimplifier::simplify(new Div(other, c));
+ denom = IRSimplifier::simplify(alloc<Div>(other, c));
} else {
return c10::nullopt;
}
// (1) Round + Mod pattern: (x/y) * y + x % y => RoundOff(x,y) + Mod(x, y) => x
// (2) Mod round + Mod pattern: (x/y % z)*y + x%y => ModRound(x, y, z) + Mod(x,
// y) => x % (y*z)
-Expr* simplifyRoundModPattern(Polynomial* poly) {
- std::vector<Term*> rounds;
- std::vector<Term*> mods;
- std::vector<Term*> mod_rounds;
- std::vector<Term*> others;
+ExprPtr simplifyRoundModPattern(PolynomialPtr poly) {
+ std::vector<TermPtr> rounds;
+ std::vector<TermPtr> mods;
+ std::vector<TermPtr> mod_rounds;
+ std::vector<TermPtr> others;
// Split out the Mod, ModRounds and RoundOffs operations so we can inspect.
- for (auto* c : poly->variables()) {
+ for (auto c : poly->variables()) {
if (c->variables().size() > 1) {
if (auto a = isModRound(c)) {
mod_rounds.push_back(c);
continue;
}
- Expr* e = c->variables()[0];
+ ExprPtr e = c->variables()[0];
- if (dynamic_cast<RoundOff*>(e)) {
+ if (to<RoundOff>(e)) {
rounds.push_back(c);
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
} else if (e->expr_type() == IRNodeType::kMod) {
HashProvider& hasher = poly->hasher();
bool didAnything = false;
- std::vector<Term*> mods_merged;
+ std::vector<TermPtr> mods_merged;
bool repeat = true;
// Repeat merging terms till there are no Mods or the terms cannot be merged
// any further.
repeat = false;
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
for (int64_t i = mods.size() - 1; i >= 0; i--) {
- Term* m = mods[i];
- Mod* mod = dynamic_cast<Mod*>(m->variables()[0]);
+ TermPtr m = mods[i];
+ ModPtr mod = to<Mod>(m->variables()[0]);
CHECK(mod);
- Expr* mod_lhs = IRSimplifier::simplify(mod->lhs());
- Expr* mod_rhs = IRSimplifier::simplify(mod->rhs());
+ ExprPtr mod_lhs = IRSimplifier::simplify(mod->lhs());
+ ExprPtr mod_rhs = IRSimplifier::simplify(mod->rhs());
bool merged = false;
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
for (int64_t j = mod_rounds.size() - 1; j >= 0; j--) {
- Term* mr = mod_rounds[j];
+ TermPtr mr = mod_rounds[j];
auto a = isModRound(mr);
CHECK(a);
ModRound* mod_round = dynamic_cast<ModRound*>(*a);
// optimization. E.g. it's possible to do: 2 * (x/y%z) * y + (x%y) =>
// x%(y*z) + (x/y%z) * y
if (!immediateEquals(
- evaluateOp(new Sub(mod_round->scalar, m->scalar())), 0)) {
+ evaluateOp(alloc<Sub>(mod_round->scalar, m->scalar())), 0)) {
continue;
}
// Valid optimization if mod LHS matches denom and mod RHS matches
if (hasher.hash(mod_round->denom) == hasher.hash(mod_lhs) &&
hasher.hash(mod_round->divisor) == hasher.hash(mod_rhs)) {
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
- Term* merged_m = new Term(
+ TermPtr merged_m = alloc<Term>(
hasher,
mod_round->scalar,
- IRSimplifier::simplify(new Mod(
+ IRSimplifier::simplify(alloc<Mod>(
mod_round->denom,
- new Mul(mod_round->divisor, mod_round->mod_divisor))));
+ alloc<Mul>(mod_round->divisor, mod_round->mod_divisor))));
mods_merged.push_back(merged_m);
merged = true;
repeat = true;
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
for (int64_t k = rounds.size() - 1; k >= 0; k--) {
- Term* r = rounds[k];
- RoundOff* roundoff = dynamic_cast<RoundOff*>(r->variables()[0]);
+ TermPtr r = rounds[k];
+ RoundOffPtr roundoff = to<RoundOff>(r->variables()[0]);
CHECK(roundoff);
// TODO: for now don't attempt partial factorization of this
          // (x/y) * y but unsure that's actually much better, particularly with
// CSE.
if (!immediateEquals(
- evaluateOp(new Sub(r->scalar(), m->scalar())), 0)) {
+ evaluateOp(alloc<Sub>(r->scalar(), m->scalar())), 0)) {
continue;
}
- Expr* round_lhs = IRSimplifier::simplify(roundoff->lhs());
- Expr* round_rhs = IRSimplifier::simplify(roundoff->rhs());
+ ExprPtr round_lhs = IRSimplifier::simplify(roundoff->lhs());
+ ExprPtr round_rhs = IRSimplifier::simplify(roundoff->rhs());
// Valid optimization if LHS and RHS are equal for both.
if (hasher.hash(round_lhs) == hasher.hash(mod_lhs) &&
hasher.hash(round_rhs) == hasher.hash(mod_rhs)) {
- Term* merged_r = new Term(hasher, r->scalar(), round_lhs);
+ TermPtr merged_r = alloc<Term>(hasher, r->scalar(), round_lhs);
others.push_back(merged_r);
merged = true;
didAnything = true;
others.insert(others.end(), rounds.begin(), rounds.end());
}
- return new Polynomial(hasher, poly->scalar(), others);
+ return alloc<Polynomial>(hasher, poly->scalar(), others);
}
// Trivially factorize terms by GCD of scalar components.
-Term* PolynomialBase::factorizePolynomial(Polynomial* poly) {
- Expr* scalar = poly->scalar();
- const std::vector<Term*>& variables = poly->variables();
+TermPtr PolynomialBase::factorizePolynomial(PolynomialPtr poly) {
+ ExprPtr scalar = poly->scalar();
+ const std::vector<TermPtr>& variables = poly->variables();
// Compute the GCD of terms.
- Expr* GCD = polyGCD(poly);
+ ExprPtr GCD = polyGCD(poly);
// No GCD means 0 or 1 and can't be factored.
if (!GCD) {
}
 // Create new structure.
- std::vector<Term*> newPolyTerms;
+ std::vector<TermPtr> newPolyTerms;
newPolyTerms.reserve(variables.size());
- for (auto* t : variables) {
+ for (auto t : variables) {
// New term with the scalar divided by the GCD.
- newPolyTerms.push_back(new Term(
- poly->hasher(), evaluateOp(new Div(t->scalar(), GCD)), t->variables()));
+ newPolyTerms.push_back(alloc<Term>(
+ poly->hasher(),
+ evaluateOp(alloc<Div>(t->scalar(), GCD)),
+ t->variables()));
}
- Polynomial* newPoly = new Polynomial(
- poly->hasher(), evaluateOp(new Div(scalar, GCD)), newPolyTerms);
+ PolynomialPtr newPoly = alloc<Polynomial>(
+ poly->hasher(), evaluateOp(alloc<Div>(scalar, GCD)), newPolyTerms);
- return new Term(poly->hasher(), GCD, newPoly);
+ return alloc<Term>(poly->hasher(), GCD, newPoly);
}
-Expr* TermExpander::mutate(Polynomial* v) {
+ExprPtr TermExpander::mutate(PolynomialPtr v) {
if (v->variables().empty()) {
return v->scalar();
}
// If this Polynomial can be factorized: do it, then expand the result.
- if (Expr* simplified = simplifyRoundModPattern(v)) {
+ if (ExprPtr simplified = simplifyRoundModPattern(v)) {
return simplified->accept_mutator(this);
}
// If this Polynomial can be factorized: do it, then expand the result.
- if (Expr* factorized = factorizePolynomial(v)) {
+ if (ExprPtr factorized = factorizePolynomial(v)) {
return factorized->accept_mutator(this);
}
- std::vector<Term*> addTerms;
- std::vector<Term*> subTerms;
+ std::vector<TermPtr> addTerms;
+ std::vector<TermPtr> subTerms;
// partition the terms into a list to add and list to subtract.
- for (auto* node : v->variables()) {
+ for (auto node : v->variables()) {
if (immediateIsNegative(node->scalar())) {
subTerms.push_back(node);
} else if (!immediateEquals(node->scalar(), 0)) {
}
// The last node constructed.
- Expr* lastNode{nullptr};
+ ExprPtr lastNode{nullptr};
- for (auto* node : addTerms) {
- Expr* simpleNode = node->accept_mutator(this);
+ for (auto node : addTerms) {
+ ExprPtr simpleNode = node->accept_mutator(this);
if (lastNode == nullptr) {
lastNode = simpleNode;
}
if (isMultilanePrimitive(simpleNode)) {
- auto* ret = combineMultilane<Add>(lastNode, simpleNode);
+ auto ret = combineMultilane<Add>(lastNode, simpleNode);
if (ret) {
// simplify result first, then expand.
lastNode = ret->accept_mutator(simplifier_);
}
}
- lastNode = new Add(lastNode, simpleNode);
+ lastNode = alloc<Add>(lastNode, simpleNode);
}
// If we have no add terms the scalar should go first.
// E.g. 1 - x.
bool scalarWritten = false;
if (lastNode == nullptr) {
- auto* scalarNode = v->scalar()->accept_mutator(simplifier_);
+ auto scalarNode = v->scalar()->accept_mutator(simplifier_);
if (!immediateEquals(scalarNode, 0)) {
lastNode = scalarNode;
}
}
- for (auto* node : subTerms) {
+ for (auto node : subTerms) {
// Can still be first node if scalarVal is 0.
if (lastNode == nullptr) {
lastNode = node->accept_mutator(this);
}
// Negate the term back to positive since we'll be subtracting it.
- Expr* negated = evaluateOp(new Mul(
+ ExprPtr negated = evaluateOp(alloc<Mul>(
getImmediateByType(node->scalar()->dtype(), -1), node->scalar()));
- Term* newRHS = new Term(node->hasher(), negated, node->variables());
- lastNode = new Sub(lastNode, newRHS->accept_mutator(this));
+ TermPtr newRHS = alloc<Term>(node->hasher(), negated, node->variables());
+ lastNode = alloc<Sub>(lastNode, newRHS->accept_mutator(this));
}
if (scalarWritten || immediateEquals(v->scalar(), 0)) {
if (immediateIsNegative(v->scalar())) {
// Negate the scalar and subtract.
- Expr* negated = evaluateOp(
- new Mul(getImmediateByType(lastNode->dtype(), -1), v->scalar()));
- lastNode = new Sub(lastNode, evaluateOp(negated));
+ ExprPtr negated = evaluateOp(
+ alloc<Mul>(getImmediateByType(lastNode->dtype(), -1), v->scalar()));
+ lastNode = alloc<Sub>(lastNode, evaluateOp(negated));
} else {
// we want to avoid a cast to the scalar if it would happen.
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
if (v->scalar()->dtype() != lastNode->dtype()) {
- lastNode = new Add(
- lastNode, evaluateOp(new Cast(lastNode->dtype(), v->scalar())));
+ lastNode = alloc<Add>(
+ lastNode, evaluateOp(alloc<Cast>(lastNode->dtype(), v->scalar())));
} else {
- lastNode = new Add(lastNode, v->scalar());
+ lastNode = alloc<Add>(lastNode, v->scalar());
}
}
return lastNode;
}
-Expr* TermExpander::mutate(MaxTerm* v) {
+ExprPtr TermExpander::mutate(MaxTermPtr v) {
auto& variables = v->variables();
if (variables.empty()) {
if (!v->scalar()) {
return v->scalar();
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- Expr* max;
+ ExprPtr max;
if (v->scalar()) {
- max = new Max(variables[0], v->scalar(), v->propagate_nans());
+ max = alloc<Max>(variables[0], v->scalar(), v->propagate_nans());
} else {
max = variables[0];
}
for (size_t i = 1; i < variables.size(); i++) {
- max = new Max(max, variables[i], v->propagate_nans());
+ max = alloc<Max>(max, variables[i], v->propagate_nans());
}
return max->accept_mutator(this);
}
-Expr* TermExpander::mutate(MinTerm* v) {
+ExprPtr TermExpander::mutate(MinTermPtr v) {
auto& variables = v->variables();
if (variables.empty()) {
if (!v->scalar()) {
return v->scalar();
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- Expr* min;
+ ExprPtr min;
if (v->scalar()) {
- min = new Min(variables[0], v->scalar(), v->propagate_nans());
+ min = alloc<Min>(variables[0], v->scalar(), v->propagate_nans());
} else {
min = variables[0];
}
for (size_t i = 1; i < variables.size(); i++) {
- min = new Min(min, variables[i], v->propagate_nans());
+ min = alloc<Min>(min, variables[i], v->propagate_nans());
}
return min->accept_mutator(this);
}
// Expands RoundOff(x, y) => Term(1, Div(x, y), y), which will later be expanded
// to Mul(Div(x, y), y).
-Expr* TermExpander::mutate(RoundOff* v) {
- Term* term = new Term(
+ExprPtr TermExpander::mutate(RoundOffPtr v) {
+ TermPtr term = alloc<Term>(
simplifier_->hasher(),
getImmediateByType(v->dtype(), 1),
- new Div(v->lhs(), v->rhs()),
+ alloc<Div>(v->lhs(), v->rhs()),
v->rhs());
return term->accept_mutator(this);
}
-Expr* buf_flat_size(Buf* v) {
- std::vector<Expr*> dims = v->dims();
+ExprPtr buf_flat_size(BufPtr v) {
+ std::vector<ExprPtr> dims = v->dims();
- Expr* flattened = getImmediateByType(kInt, 1);
+ ExprPtr flattened = getImmediateByType(kInt, 1);
for (auto& dim : dims) {
- flattened = new Mul(flattened, dim);
+ flattened = alloc<Mul>(flattened, dim);
}
flattened = IRSimplifier::simplify(flattened);
return flattened;
}
-Stmt* TermExpander::mutate(Allocate* v) {
- Buf* buf = v->buf();
- Buf* buf_new = dynamic_cast<Buf*>(v->buf()->accept_mutator(this));
+StmtPtr TermExpander::mutate(AllocatePtr v) {
+ BufPtr buf = v->buf();
+ BufPtr buf_new = to<Buf>(v->buf()->accept_mutator(this));
TORCH_INTERNAL_ASSERT(buf_new);
- Expr* flattened = buf_flat_size(buf_new);
+ ExprPtr flattened = buf_flat_size(buf_new);
if (flattened->isConstant() && immediateEquals(flattened, 0)) {
eliminated_allocations_.insert(buf_new->base_handle());
return v;
}
-Stmt* TermExpander::mutate(Free* v) {
- Buf* buf = v->buf();
- Buf* buf_new = dynamic_cast<Buf*>(v->buf()->accept_mutator(this));
+StmtPtr TermExpander::mutate(FreePtr v) {
+ BufPtr buf = v->buf();
+ BufPtr buf_new = to<Buf>(v->buf()->accept_mutator(this));
TORCH_INTERNAL_ASSERT(buf_new);
if (eliminated_allocations_.count(buf_new->base_handle())) {
}
 // Combines adjacent Cond nodes with identical conditions.
-Block* TermExpander::fuseConditions(Block* v) {
- std::vector<Stmt*> stmts;
+BlockPtr TermExpander::fuseConditions(BlockPtr v) {
+ std::vector<StmtPtr> stmts;
bool did_anything = false;
- Cond* prev_cond = nullptr;
+ CondPtr prev_cond = nullptr;
- for (auto* s : *v) {
- Cond* cond = dynamic_cast<Cond*>(s);
+ for (auto s : *v) {
+ CondPtr cond = to<Cond>(s);
if (!cond) {
prev_cond = nullptr;
stmts.push_back(s);
// Fuse the two Conds by appending the bodies of the second Cond to the
// first.
- Block* true_block = new Block({});
- Block* false_block = new Block({});
+ BlockPtr true_block = alloc<Block>(std::vector<StmtPtr>({}));
+ BlockPtr false_block = alloc<Block>(std::vector<StmtPtr>({}));
if (prev_cond->true_stmt()) {
true_block->splice(true_block->end(), prev_cond->true_stmt());
}
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
- Stmt* new_cond = prev_cond->cloneWithNewBodies(true_block, false_block)
- ->accept_mutator(this);
- prev_cond = dynamic_cast<Cond*>(new_cond);
+ StmtPtr new_cond = prev_cond->cloneWithNewBodies(true_block, false_block)
+ ->accept_mutator(this);
+ prev_cond = to<Cond>(new_cond);
// erase, which shortens the list.
stmts.pop_back();
}
// clean up parents.
- for (auto* s : stmts) {
+ for (auto s : stmts) {
if (s->get_parent() == v) {
v->remove_stmt(s);
}
}
- return new Block(stmts);
+ return alloc<Block>(stmts);
}
-Stmt* TermExpander::fuseSyncThreads(Block* block) {
+StmtPtr TermExpander::fuseSyncThreads(BlockPtr block) {
// only really first if highest level Block.
bool first = block->get_parent() == nullptr;
- SyncThreads* last = nullptr;
- std::vector<Stmt*> stmts;
+ SyncThreadsPtr last = nullptr;
+ std::vector<StmtPtr> stmts;
bool did_anything = false;
- for (auto* s : *block) {
- SyncThreads* sync = dynamic_cast<SyncThreads*>(s);
+ for (auto s : *block) {
+ SyncThreadsPtr sync = to<SyncThreads>(s);
if (!sync) {
first = false;
last = nullptr;
}
// clean up parents.
- for (auto* s : stmts) {
+ for (auto s : stmts) {
if (s->get_parent() == block) {
block->remove_stmt(s);
}
}
- return new Block({stmts});
+ return alloc<Block>(std::vector<StmtPtr>({stmts}));
}
-Stmt* TermExpander::mutate(Block* v) {
- Stmt* new_stmt = PolynomialBase::mutate(v);
- Block* new_block = dynamic_cast<Block*>(new_stmt);
+StmtPtr TermExpander::mutate(BlockPtr v) {
+ StmtPtr new_stmt = PolynomialBase::mutate(v);
+ BlockPtr new_block = to<Block>(new_stmt);
if (!new_block) {
return new_stmt;
}
// This function records the bounds(range) info of the index var in a for-stmt.
// The bounds info will be used later when simplifying expressions with the
// index var.
-Stmt* SimplifierUnderContext::mutate(For* v) {
- Expr* var = v->var();
- Expr* start = v->start();
- Expr* stop = v->stop();
- Stmt* body = v->body();
+StmtPtr SimplifierUnderContext::mutate(ForPtr v) {
+ ExprPtr var = v->var();
+ ExprPtr start = v->start();
+ ExprPtr stop = v->stop();
+ StmtPtr body = v->body();
LoopOptions loop_options = v->loop_options();
- Expr* var_new_expr = var->accept_mutator(this);
- Var* var_new = dynamic_cast<Var*>(var_new_expr);
- Expr* start_new = start->accept_mutator(this);
- Expr* stop_new = stop->accept_mutator(this);
- Stmt* body_new = body;
+ ExprPtr var_new_expr = var->accept_mutator(this);
+ VarPtr var_new = to<Var>(var_new_expr);
+ ExprPtr start_new = start->accept_mutator(this);
+ ExprPtr stop_new = stop->accept_mutator(this);
+ StmtPtr body_new = body;
// save bounds info before this for-stmt
//
// bound info after the for stmt, we can use it to simplify the assignment
// stmt x = (i+20)/5 to x = 4.
bool has_bounds = false;
- std::pair<Expr*, Expr*> bound_old;
- Var* var_key = dynamic_cast<Var*>(var);
+ std::pair<ExprPtr, ExprPtr> bound_old;
+ VarPtr var_key = to<Var>(var);
auto got = var_bound_info_.find(var_key);
if (got != var_bound_info_.end()) {
has_bounds = true;
bound_old = got->second;
}
// set bounds info for index var
- const std::pair<Expr*, Expr*> bound_new = std::make_pair(start_new, stop_new);
+ const std::pair<ExprPtr, ExprPtr> bound_new =
+ std::make_pair(start_new, stop_new);
var_bound_info_[var_key] = bound_new;
- Expr* iters = new Sub(stop_new, start_new);
+ ExprPtr iters = alloc<Sub>(stop_new, start_new);
iters = iters->accept_mutator(this);
if (loop_options.isDefault() && iters->isConstant()) {
if (immediateEquals(iters, 0)) {
- return new Block({});
+ return alloc<Block>(std::vector<StmtPtr>({}));
} else if (immediateEquals(iters, 1)) {
body_new = Substitute(body, {{var_new, start_new}});
body_new = body_new->accept_mutator(this);
}
if (!body_new) {
- return new Block({});
+ return alloc<Block>(std::vector<StmtPtr>({}));
}
- if (auto* block = dynamic_cast<Block*>(body_new)) {
+ if (auto block = to<Block>(body_new)) {
if (block->nstmts() == 0) {
- return new Block({});
+ return alloc<Block>(std::vector<StmtPtr>({}));
}
if (block->nstmts() == 1) {
// if the stmt in the loop body is a if-stmt, try to move the branching
// out of the loop
- if (auto* cond = dynamic_cast<Cond*>(block->front())) {
- Stmt* reordered = handleForCondReordering(v, cond);
+ if (auto cond = to<Cond>(block->front())) {
+ StmtPtr reordered = handleForCondReordering(v, cond);
if (reordered) {
return reordered->accept_mutator(this);
}
// returns -1. But currently, both Pytorch and NNC are performing an incorrect
// integer division: (-1)/6 = 0. With the current implementation of integer
// division, x has to be not negative. d) j is not negative
-Expr* distributeDiv(Expr* lhs, Expr* rhs, VarBoundInfo var_bound_info) {
+ExprPtr distributeDiv(ExprPtr lhs, ExprPtr rhs, VarBoundInfo var_bound_info) {
if (!lhs || !rhs) {
return nullptr;
}
}
// identify n: a positive integer constant
- Expr* rhsScalar = rhs->isConstant() ? rhs : nullptr;
+ ExprPtr rhsScalar = rhs->isConstant() ? rhs : nullptr;
if (!rhsScalar) {
return nullptr;
}
- Expr* check_n_value =
- IRSimplifier::simplify(new CompareSelect(rhsScalar, new IntImm(0), kGT));
+ ExprPtr check_n_value = IRSimplifier::simplify(
+ alloc<CompareSelect>(rhsScalar, alloc<IntImm>(0), kGT));
if (!immediateEquals(check_n_value, 1)) {
return nullptr;
}
- auto* lhsAdd = dynamic_cast<Add*>(lhs);
+ auto lhsAdd = to<Add>(lhs);
if (!lhsAdd) {
return nullptr;
}
- Expr* lhsAdd1 = lhsAdd->lhs();
- Expr* lhsAdd2 = lhsAdd->rhs();
+ ExprPtr lhsAdd1 = lhsAdd->lhs();
+ ExprPtr lhsAdd2 = lhsAdd->rhs();
// identify index var 'i'
- Var* var_key = dynamic_cast<Var*>(lhsAdd1);
- Expr* main = lhsAdd2;
+ VarPtr var_key = to<Var>(lhsAdd1);
+ ExprPtr main = lhsAdd2;
if (var_key == nullptr) {
- var_key = dynamic_cast<Var*>(lhsAdd2);
+ var_key = to<Var>(lhsAdd2);
main = lhsAdd1;
}
// open upper bound, i.e., end is one more than the maximum value in the
// range
auto end = got->second.second;
- Expr* check_start =
- IRSimplifier::simplify(new CompareSelect(start, new IntImm(0), kGE));
- Expr* check_end =
- IRSimplifier::simplify(new CompareSelect(end, rhsScalar, kLE));
+ ExprPtr check_start = IRSimplifier::simplify(
+ alloc<CompareSelect>(start, alloc<IntImm>(0), kGE));
+ ExprPtr check_end =
+ IRSimplifier::simplify(alloc<CompareSelect>(end, rhsScalar, kLE));
if (!check_start->isConstant() || !check_end->isConstant() ||
!immediateEquals(check_start, 1) || !immediateEquals(check_end, 1)) {
return nullptr;
}
- Expr* ret = IRSimplifier::simplify(new Div(main, rhsScalar));
+ ExprPtr ret = IRSimplifier::simplify(alloc<Div>(main, rhsScalar));
// simplify type 1) exprs: '(i+x)/n' => 'x/n'
- Expr* sign_check =
- IRSimplifier::simplify(new CompareSelect(main, new IntImm(0), kGE));
- Expr* main_mod = IRSimplifier::simplify(new Mod(main, rhsScalar));
- Expr* mod_check = IRSimplifier::simplify(
- new CompareSelect(new Add(main_mod, end), rhsScalar, kLE));
+ ExprPtr sign_check =
+ IRSimplifier::simplify(alloc<CompareSelect>(main, alloc<IntImm>(0), kGE));
+ ExprPtr main_mod = IRSimplifier::simplify(alloc<Mod>(main, rhsScalar));
+ ExprPtr mod_check = IRSimplifier::simplify(
+ alloc<CompareSelect>(alloc<Add>(main_mod, end), rhsScalar, kLE));
if (sign_check->isConstant() && immediateEquals(sign_check, 1) &&
mod_check->isConstant() && immediateEquals(mod_check, 1)) {
return ret;
}
// simplify type 2 exprs: '(i+j*n)/n' => 'j'
- auto ret_var = dynamic_cast<Var*>(ret);
+ auto ret_var = to<Var>(ret);
if (ret_var && ret_var->dtype() == kInt) {
// retrieve j's range info
auto got = var_bound_info.find(ret_var);
// check if j is not negative
sign_check = IRSimplifier::simplify(
- new CompareSelect(got->second.first, new IntImm(0), kGE));
+ alloc<CompareSelect>(got->second.first, alloc<IntImm>(0), kGE));
if (sign_check->isConstant() && immediateEquals(sign_check, 1)) {
return ret_var;
}
// returns -1. But currently, both Pytorch and NNC are performing an incorrect
// integer division: (-1)/6 = 0. With the current implementation of integer
// division, j has to be not negative. d) j is not negative
-Expr* distributeMod(Expr* lhs, Expr* rhs, VarBoundInfo var_bound_info) {
+ExprPtr distributeMod(ExprPtr lhs, ExprPtr rhs, VarBoundInfo var_bound_info) {
if (!lhs || !rhs) {
return nullptr;
}
}
// identify n: a positive integer constant
- Expr* rhsScalar = rhs->isConstant() ? rhs : nullptr;
+ ExprPtr rhsScalar = rhs->isConstant() ? rhs : nullptr;
if (!rhsScalar) {
return nullptr;
}
- Expr* check_n_value =
- IRSimplifier::simplify(new CompareSelect(rhsScalar, new IntImm(0), kGT));
+ ExprPtr check_n_value = IRSimplifier::simplify(
+ alloc<CompareSelect>(rhsScalar, alloc<IntImm>(0), kGT));
if (!immediateEquals(check_n_value, 1)) {
return nullptr;
}
- auto* lhsAdd = dynamic_cast<Add*>(lhs);
+ auto lhsAdd = to<Add>(lhs);
if (!lhsAdd) {
return nullptr;
}
if (!lhsAdd || !rhsScalar) {
return nullptr;
}
- Expr* lhsAdd1 = lhsAdd->lhs();
- Expr* lhsAdd2 = lhsAdd->rhs();
+ ExprPtr lhsAdd1 = lhsAdd->lhs();
+ ExprPtr lhsAdd2 = lhsAdd->rhs();
// identify index var 'i'
- Var* var_key = dynamic_cast<Var*>(lhsAdd1);
- Expr* main = lhsAdd2;
+ VarPtr var_key = to<Var>(lhsAdd1);
+ ExprPtr main = lhsAdd2;
if (var_key == nullptr) {
- var_key = dynamic_cast<Var*>(lhsAdd2);
+ var_key = to<Var>(lhsAdd2);
main = lhsAdd1;
}
if (var_key == nullptr) {
// open upper bound, i.e., end is one more than the maximum value in the
// range
auto end = got->second.second;
- Expr* check_start =
- IRSimplifier::simplify(new CompareSelect(start, new IntImm(0), kGE));
- Expr* check_end =
- IRSimplifier::simplify(new CompareSelect(end, rhsScalar, kLE));
+ ExprPtr check_start = IRSimplifier::simplify(
+ alloc<CompareSelect>(start, alloc<IntImm>(0), kGE));
+ ExprPtr check_end =
+ IRSimplifier::simplify(alloc<CompareSelect>(end, rhsScalar, kLE));
if (!check_start->isConstant() || !check_end->isConstant() ||
!immediateEquals(check_start, 1) || !immediateEquals(check_end, 1)) {
return nullptr;
}
// simplify type 1) exprs: '(i+x)%n' => 'i+x%n'
- Expr* sign_check =
- IRSimplifier::simplify(new CompareSelect(main, new IntImm(0), kGE));
- Expr* main_mod = IRSimplifier::simplify(new Mod(main, rhsScalar));
- Expr* mod_check = IRSimplifier::simplify(
- new CompareSelect(new Add(main_mod, end), rhsScalar, kLE));
+ ExprPtr sign_check =
+ IRSimplifier::simplify(alloc<CompareSelect>(main, alloc<IntImm>(0), kGE));
+ ExprPtr main_mod = IRSimplifier::simplify(alloc<Mod>(main, rhsScalar));
+ ExprPtr mod_check = IRSimplifier::simplify(
+ alloc<CompareSelect>(alloc<Add>(main_mod, end), rhsScalar, kLE));
if (sign_check->isConstant() && immediateEquals(sign_check, 1) &&
mod_check->isConstant() && immediateEquals(mod_check, 1)) {
- return new Add(var_key, main_mod);
+ return alloc<Add>(var_key, main_mod);
}
// simplify type 2) exprs: '(i+j*n)%n' => 'i'
- Expr* main_div = IRSimplifier::simplify(new Div(main, rhsScalar));
- auto j_var = dynamic_cast<Var*>(main_div);
+ ExprPtr main_div = IRSimplifier::simplify(alloc<Div>(main, rhsScalar));
+ auto j_var = to<Var>(main_div);
if (j_var && j_var->dtype() == kInt) {
// retrieve j's range info
auto got = var_bound_info.find(j_var);
// check if j is not negative
sign_check = IRSimplifier::simplify(
- new CompareSelect(got->second.first, new IntImm(0), kGE));
+ alloc<CompareSelect>(got->second.first, alloc<IntImm>(0), kGE));
if (sign_check->isConstant() && immediateEquals(sign_check, 1)) {
return var_key;
}
return nullptr;
}
-Expr* SimplifierUnderContext::mutate(Div* v) {
- Expr* lhs = v->lhs();
- Expr* rhs = v->rhs();
+ExprPtr SimplifierUnderContext::mutate(DivPtr v) {
+ ExprPtr lhs = v->lhs();
+ ExprPtr rhs = v->rhs();
std::ostringstream oss;
if (auto ret = distributeDiv(lhs, rhs, var_bound_info_)) {
return ret->accept_mutator(this);
}
- Expr* lhs_new = lhs->accept_mutator(this);
- Expr* rhs_new = rhs->accept_mutator(this);
+ ExprPtr lhs_new = lhs->accept_mutator(this);
+ ExprPtr rhs_new = rhs->accept_mutator(this);
if (lhs == lhs_new && rhs == rhs_new) {
return v;
}
- return new Div(lhs_new, rhs_new);
+ return alloc<Div>(lhs_new, rhs_new);
}
-Expr* SimplifierUnderContext::mutate(Mod* v) {
- Expr* lhs = v->lhs();
- Expr* rhs = v->rhs();
+ExprPtr SimplifierUnderContext::mutate(ModPtr v) {
+ ExprPtr lhs = v->lhs();
+ ExprPtr rhs = v->rhs();
std::ostringstream oss;
if (auto ret = distributeMod(lhs, rhs, var_bound_info_)) {
// i % N -> i if the range of i's values is a subset of [0, N)
// where N is an integer constant
- auto* lhsVar = dynamic_cast<Var*>(lhs);
- Expr* rhsScalar = rhs->isConstant() ? rhs : nullptr;
+ auto lhsVar = to<Var>(lhs);
+ ExprPtr rhsScalar = rhs->isConstant() ? rhs : nullptr;
if (lhsVar && rhsScalar && !rhsScalar->dtype().is_floating_point()) {
auto got = var_bound_info_.find(lhsVar);
if (got != var_bound_info_.end()) {
auto start = got->second.first;
auto end = got->second.second;
- Expr* check_start =
- IRSimplifier::simplify(new CompareSelect(start, new IntImm(0), kGE));
- Expr* check_end =
- IRSimplifier::simplify(new CompareSelect(end, rhsScalar, kLE));
+ ExprPtr check_start = IRSimplifier::simplify(
+ alloc<CompareSelect>(start, alloc<IntImm>(0), kGE));
+ ExprPtr check_end =
+ IRSimplifier::simplify(alloc<CompareSelect>(end, rhsScalar, kLE));
if (check_start->isConstant() && check_end->isConstant() &&
immediateEquals(check_start, 1) && immediateEquals(check_end, 1)) {
oss << "SimplifierUnderContext: " << *v << " => " << *lhsVar << "\n";
}
}
- Expr* lhs_new = lhs->accept_mutator(this);
- Expr* rhs_new = rhs->accept_mutator(this);
+ ExprPtr lhs_new = lhs->accept_mutator(this);
+ ExprPtr rhs_new = rhs->accept_mutator(this);
if (lhs == lhs_new && rhs == rhs_new) {
return v;
}
- return new Mod(lhs_new, rhs_new);
+ return alloc<Mod>(lhs_new, rhs_new);
}
-bool exprEquals(Expr* A, Expr* B) {
+bool exprEquals(ExprPtr A, ExprPtr B) {
try {
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
- Expr* diff = IRSimplifier::simplify(new Sub(A, B));
+ ExprPtr diff = IRSimplifier::simplify(alloc<Sub>(A, B));
if (!diff->isConstant()) {
return false;
}
// A bunch of helpers for determine the Dtype of the output of a multi argument
// Term or Polynomial.
template <class ExprType>
-Dtype promoteTypesVec(Expr* s, std::vector<ExprType*>& v) {
+Dtype promoteTypesVec(ExprPtr s, std::vector<ExprType>& v) {
Dtype t = s->dtype();
bool first = true;
- for (auto* e : v) {
+ for (auto e : v) {
if (first) {
t = Dtype(t.scalar_type(), e->dtype().lanes());
first = false;
}
template <class ExprType>
-Dtype promoteTypesVec(std::vector<ExprType*>& v) {
+Dtype promoteTypesVec(std::vector<ExprType>& v) {
if (v.empty()) {
throw malformed_input("empty list of types");
}
Dtype t = v[0]->dtype();
- for (auto* e : v) {
+ for (auto e : v) {
t = promoteTypes(t, e->dtype());
}
return t;
template <class ExprType>
Dtype promoteTypesMap(
- Expr* s,
+ ExprPtr s,
std::unordered_map<SimplifierHashType, ExprType*>& m) {
Dtype t = s->dtype();
bool first = true;
}
// Creates a new Expr of the given type with the provided lhs and rhs.
-inline Expr* newBinaryOpOfType(
+inline ExprPtr newBinaryOpOfType(
IRNodeType expr_type,
- Expr* lhs,
- Expr* rhs,
+ ExprPtr lhs,
+ ExprPtr rhs,
bool option) {
switch (expr_type) {
// NOLINTNEXTLINE(bugprone-branch-clone)
case IRNodeType::kAdd:
- return new Add(lhs, rhs);
+ return alloc<Add>(lhs, rhs);
case IRNodeType::kSub:
- return new Sub(lhs, rhs);
+ return alloc<Sub>(lhs, rhs);
case IRNodeType::kMul:
- return new Mul(lhs, rhs);
+ return alloc<Mul>(lhs, rhs);
case IRNodeType::kDiv:
- return new Div(lhs, rhs);
+ return alloc<Div>(lhs, rhs);
case IRNodeType::kMod:
- return new Mod(lhs, rhs);
+ return alloc<Mod>(lhs, rhs);
case IRNodeType::kMax:
- return new Max(lhs, rhs, option);
+ return alloc<Max>(lhs, rhs, option);
case IRNodeType::kMin:
- return new Min(lhs, rhs, option);
+ return alloc<Min>(lhs, rhs, option);
case IRNodeType::kAnd:
- return new And(lhs, rhs);
+ return alloc<And>(lhs, rhs);
case IRNodeType::kXor:
- return new Xor(lhs, rhs);
+ return alloc<Xor>(lhs, rhs);
case IRNodeType::kLshift:
- return new Lshift(lhs, rhs);
+ return alloc<Lshift>(lhs, rhs);
case IRNodeType::kRshift:
- return new Rshift(lhs, rhs);
+ return alloc<Rshift>(lhs, rhs);
default:
LOG(FATAL) << "unsupported expr_type: " << static_cast<int>(expr_type);
return nullptr;
// Uses the evaluator to fold an Expression with constant terms.
// E.g. evaluateOp(Add(3, 4)) => 7.
// Expr v must not have any unbound Vars.
-inline Expr* evaluateOp(Expr* v) {
+inline ExprPtr evaluateOp(ExprPtr v) {
ExprHandle handle(v);
ExprEval<SimpleIREvaluator> eval(handle);
public:
template <class... Args>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- Term(HashProvider& hasher, Expr* s, Args... ts)
+ Term(HashProvider& hasher, ExprPtr s, Args... ts)
: ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) {
CHECK(s->isConstant());
addComponent(ts...);
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- Term(HashProvider& hasher, Expr* s, std::vector<Expr*> v)
+ Term(HashProvider& hasher, ExprPtr s, std::vector<ExprPtr> v)
: ExprNodeBase(promoteTypesVec(s, v)),
variables_(std::move(v)),
scalar_(s),
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Term(
HashProvider& hasher,
- Expr* s,
- std::unordered_map<SimplifierHashType, Expr*> varmap)
+ ExprPtr s,
+ std::unordered_map<SimplifierHashType, ExprPtr> varmap)
: ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) {
for (auto& p : varmap) {
addComponent(p.second);
sort();
}
- Expr* scalar() const {
+ ExprPtr scalar() const {
return scalar_;
}
- const std::vector<Expr*>& variables() const {
+ const std::vector<ExprPtr>& variables() const {
return variables_;
}
HashProvider& hasher() const {
SimplifierHashType hashVars() const;
private:
- std::vector<Expr*> variables_;
- Expr* scalar_;
+ std::vector<ExprPtr> variables_;
+ ExprPtr scalar_;
HashProvider& hasher_;
void addComponent() {}
- void addComponent(Expr* e) {
+ void addComponent(ExprPtr e) {
variables_.push_back(e);
}
template <class... Es>
- void addComponent(Expr* e, Es... es) {
+ void addComponent(ExprPtr e, Es... es) {
addComponent(e);
addComponent(es...);
}
public:
template <class... Args>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- Polynomial(HashProvider& hasher, Expr* s, Args... ts)
+ Polynomial(HashProvider& hasher, ExprPtr s, Args... ts)
: ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) {
CHECK(s->isConstant());
addTerm(ts...);
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- Polynomial(HashProvider& hasher, Expr* s, std::vector<Term*> v)
+ Polynomial(HashProvider& hasher, ExprPtr s, std::vector<TermPtr> v)
: ExprNodeBase(promoteTypesVec(s, v)),
variables_(std::move(v)),
scalar_(s),
// Helper constructor for list of terms with no scalar component.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- Polynomial(HashProvider& hasher, std::vector<Term*> terms)
+ Polynomial(HashProvider& hasher, std::vector<TermPtr> terms)
: ExprNodeBase(promoteTypesVec(terms)),
variables_(std::move(terms)),
scalar_(getImmediateByType(dtype(), 0)),
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Polynomial(
HashProvider& hasher,
- Expr* s,
- std::unordered_map<SimplifierHashType, Term*> varmap)
+ ExprPtr s,
+ std::unordered_map<SimplifierHashType, TermPtr> varmap)
: ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) {
for (auto& p : varmap) {
addTerm(p.second);
sort();
}
- Expr* scalar() const {
+ ExprPtr scalar() const {
return scalar_;
}
- const std::vector<Term*>& variables() const {
+ const std::vector<TermPtr>& variables() const {
return variables_;
}
HashProvider& hasher() const {
SimplifierHashType hashVars() const;
private:
- std::vector<Term*> variables_;
- Expr* scalar_;
+ std::vector<TermPtr> variables_;
+ ExprPtr scalar_;
HashProvider& hasher_;
- void addTerm(Term* t) {
+ void addTerm(TermPtr t) {
variables_.push_back(t);
}
template <class... Ts>
- void addTerm(Term* t, Ts... ts) {
+ void addTerm(TermPtr t, Ts... ts) {
addTerm(t);
addTerm(ts...);
}
class RoundOff : public BinaryOpNode<RoundOff> {
public:
- RoundOff(Expr* lhs, Expr* rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kOther) {}
+ RoundOff(ExprPtr lhs, ExprPtr rhs)
+ : BinaryOpNode(lhs, rhs, IRNodeType::kOther) {}
};
class MaxTerm : public ExprNode<MaxTerm> {
public:
template <class... Args>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- MaxTerm(HashProvider& hasher, Expr* s, bool p, Args... ts)
+ MaxTerm(HashProvider& hasher, ExprPtr s, bool p, Args... ts)
: ExprNodeBase(s ? promoteTypesVar(s, ts...) : promoteTypesVar(ts...)),
scalar_(s),
hasher_(hasher),
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- MaxTerm(HashProvider& hasher, Expr* s, bool p, std::vector<Expr*> v)
+ MaxTerm(HashProvider& hasher, ExprPtr s, bool p, std::vector<ExprPtr> v)
: ExprNodeBase(s ? promoteTypesVec(s, v) : promoteTypesVec(v)),
variables_(std::move(v)),
scalar_(s),
return propagate_nans_;
}
- Expr* scalar() const {
+ ExprPtr scalar() const {
return scalar_;
}
- const std::vector<Expr*>& variables() const {
+ const std::vector<ExprPtr>& variables() const {
return variables_;
}
HashProvider& hasher() const {
}
private:
- std::vector<Expr*> variables_;
- Expr* scalar_;
+ std::vector<ExprPtr> variables_;
+ ExprPtr scalar_;
HashProvider& hasher_;
bool propagate_nans_;
void addComponent() {}
- void addComponent(Expr* e) {
+ void addComponent(ExprPtr e) {
variables_.push_back(e);
}
template <class... Es>
- void addComponent(Expr* e, Es... es) {
+ void addComponent(ExprPtr e, Es... es) {
addComponent(e);
addComponent(es...);
}
public:
template <class... Args>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- MinTerm(HashProvider& hasher, Expr* s, bool p, Args... ts)
+ MinTerm(HashProvider& hasher, ExprPtr s, bool p, Args... ts)
: ExprNodeBase(s ? promoteTypesVar(s, ts...) : promoteTypesVar(ts...)),
scalar_(s),
hasher_(hasher),
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- MinTerm(HashProvider& hasher, Expr* s, bool p, std::vector<Expr*> v)
+ MinTerm(HashProvider& hasher, ExprPtr s, bool p, std::vector<ExprPtr> v)
: ExprNodeBase(s ? promoteTypesVec(s, v) : promoteTypesVec(v)),
variables_(std::move(v)),
scalar_(s),
return propagate_nans_;
}
- Expr* scalar() const {
+ ExprPtr scalar() const {
return scalar_;
}
- const std::vector<Expr*>& variables() const {
+ const std::vector<ExprPtr>& variables() const {
return variables_;
}
HashProvider& hasher() const {
}
private:
- std::vector<Expr*> variables_;
- Expr* scalar_;
+ std::vector<ExprPtr> variables_;
+ ExprPtr scalar_;
HashProvider& hasher_;
bool propagate_nans_;
void addComponent() {}
- void addComponent(Expr* e) {
+ void addComponent(ExprPtr e) {
variables_.push_back(e);
}
template <class... Es>
- void addComponent(Expr* e, Es... es) {
+ void addComponent(ExprPtr e, Es... es) {
addComponent(e);
addComponent(es...);
}
};
// Context-sensitive IR simplification
-using VarBoundInfo = std::unordered_map<Var*, std::pair<Expr*, Expr*>>;
+using VarBoundInfo = std::unordered_map<VarPtr, std::pair<ExprPtr, ExprPtr>>;
class TORCH_API SimplifierUnderContext : public IRMutator {
public:
~SimplifierUnderContext() override = default;
// Add boundary info for index variables in for-loops
- Stmt* mutate(For* v) override;
+ StmtPtr mutate(ForPtr v) override;
- Expr* mutate(Div* v) override;
- Expr* mutate(Mod* v) override;
+ ExprPtr mutate(DivPtr v) override;
+ ExprPtr mutate(ModPtr v) override;
protected:
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
public:
~PolynomialBase() override = default;
- Stmt* mutate(Block* v) override;
+ StmtPtr mutate(BlockPtr v) override;
- Stmt* mutate(Cond* v) override;
+ StmtPtr mutate(CondPtr v) override;
- Stmt* mutate(For* v) override;
+ StmtPtr mutate(ForPtr v) override;
// Trivially factorize terms by GCD of scalar components.
- Term* factorizePolynomial(Polynomial* poly);
+ TermPtr factorizePolynomial(PolynomialPtr poly);
HashProvider& hasher() {
return hasher_;
// Inserts term into the provided map, in the case of a hash collision
// combines the term with the existing and updates the map.
void addOrUpdateTerm(
- std::unordered_map<SimplifierHashType, Term*>& varmap,
- Term* term);
+ std::unordered_map<SimplifierHashType, TermPtr>& varmap,
+ TermPtr term);
// Add Polynomial expressions, combining Terms representing the same
// variables.
- Expr* addPolynomials(Polynomial* lhs, Polynomial* rhs);
+ ExprPtr addPolynomials(PolynomialPtr lhs, PolynomialPtr rhs);
- // Insert a new Term into the provided polynomial. If the new term has common
- // variables to an existing term it is combined.
- Expr* insertTerm(Polynomial* poly, Term* term);
+ // Insert a new Term into the provided polynomial. If the new term has
+ // common variables to an existing term it is combined.
+ ExprPtr insertTerm(PolynomialPtr poly, TermPtr term);
// Merge and simplify addition.
- Expr* mutate(Add* v) override;
+ ExprPtr mutate(AddPtr v) override;
// Subtract one term from another, cancelling if necessary.
- Expr* subTerms(Term* lhs, Term* rhs, bool negated);
+ ExprPtr subTerms(TermPtr lhs, TermPtr rhs, bool negated);
// Subtract the RHS Polynomial from the LHS Polynomial, cancelling out where
// possible.
- Expr* subPolynomials(Polynomial* lhs, Polynomial* rhs);
+ ExprPtr subPolynomials(PolynomialPtr lhs, PolynomialPtr rhs);
// Merge and simplify subtraction.
- Expr* mutate(Sub* v) override;
+ ExprPtr mutate(SubPtr v) override;
// Multiply two terms together, usually creating a new term with the variable
// lists concatenated.
- Term* mulTerms(Term* lhs, Term* rhs);
+ TermPtr mulTerms(TermPtr lhs, TermPtr rhs);
// Multiply a Polynomial by a Term.
- Expr* polyByTerm(Polynomial* poly, Term* term);
+ ExprPtr polyByTerm(PolynomialPtr poly, TermPtr term);
// Match a rounding pattern and create a RoundOff if found.
- Expr* isRoundOff(Expr* lhs, Expr* rhs);
+ ExprPtr isRoundOff(ExprPtr lhs, ExprPtr rhs);
// Inserts a new component into a term, simplifying if possible.
- Expr* insertIntoTerm(Term* term, Expr* expr);
+ ExprPtr insertIntoTerm(TermPtr term, ExprPtr expr);
// Merge and simplify multiplication.
- Expr* mutate(Mul* v) override;
+ ExprPtr mutate(MulPtr v) override;
- Expr* mutate(Div* v) override;
+ ExprPtr mutate(DivPtr v) override;
- Expr* mutate(Mod* v) override;
+ ExprPtr mutate(ModPtr v) override;
- Expr* mutate(And* v) override {
+ ExprPtr mutate(AndPtr v) override {
return mutateBinaryOp(v, this);
}
- Expr* mutate(Xor* v) override {
+ ExprPtr mutate(XorPtr v) override {
return mutateBinaryOp(v, this);
}
- Expr* mutate(Lshift* v) override {
+ ExprPtr mutate(LshiftPtr v) override {
return mutateBinaryOp(v, this);
}
- Expr* mutate(Rshift* v) override {
+ ExprPtr mutate(RshiftPtr v) override {
return mutateBinaryOp(v, this);
}
- Expr* mutate(Max* v) override;
+ ExprPtr mutate(MaxPtr v) override;
- Expr* mutate(Min* v) override;
+ ExprPtr mutate(MinPtr v) override;
- Expr* mutate(CompareSelect* v) override;
+ ExprPtr mutate(CompareSelectPtr v) override;
- Expr* mutate(Intrinsics* v) override;
+ ExprPtr mutate(IntrinsicsPtr v) override;
- Expr* mutate(Cast* v) override;
+ ExprPtr mutate(CastPtr v) override;
- Expr* mutate(IfThenElse* v) override;
+ ExprPtr mutate(IfThenElsePtr v) override;
template <typename Op>
- static Expr* mutateBinaryOp(
+ static ExprPtr mutateBinaryOp(
BinaryOpNode<Op>* v,
IRMutator* mutator,
bool option = false) {
- Expr* lhs = v->lhs();
- Expr* rhs = v->rhs();
- Expr* lhs_new = lhs->accept_mutator(mutator);
- Expr* rhs_new = rhs->accept_mutator(mutator);
+ ExprPtr lhs = v->lhs();
+ ExprPtr rhs = v->rhs();
+ ExprPtr lhs_new = lhs->accept_mutator(mutator);
+ ExprPtr rhs_new = rhs->accept_mutator(mutator);
- Expr* node = v;
+ ExprPtr node = v;
if (lhs != lhs_new || rhs != rhs_new) {
node = newBinaryOpOfType(v->expr_type(), lhs_new, rhs_new, option);
return evaluateOp(node);
}
- static Expr* simplify(Expr* e);
+ static ExprPtr simplify(ExprPtr e);
static ExprHandle simplify(const ExprHandle& e);
- static Stmt* simplify(Stmt* e);
+ static StmtPtr simplify(StmtPtr e);
};
// Expands Terms and Polynomial expressions into primitive operations.
// Does some simple factorization and reordering.
class TORCH_API TermExpander : public PolynomialBase {
PolynomialTransformer* simplifier_;
- std::set<Var*> eliminated_allocations_;
+ std::set<VarPtr> eliminated_allocations_;
public:
using PolynomialBase::mutate;
}
// Expand Terms out to a series of Muls.
- Expr* mutate(Term* v) override;
+ ExprPtr mutate(TermPtr v) override;
// Expand Polynomials out to a series of Adds.
- Expr* mutate(Polynomial* v) override;
+ ExprPtr mutate(PolynomialPtr v) override;
// Expand MaxTerms to a series of Max ops.
- Expr* mutate(MaxTerm* v) override;
+ ExprPtr mutate(MaxTermPtr v) override;
// Expand MinTerms to a series of Min ops.
- Expr* mutate(MinTerm* v) override;
+ ExprPtr mutate(MinTermPtr v) override;
  // Expand RoundOff to its components: Mul(Div(lhs, rhs), rhs).
- Expr* mutate(RoundOff* v) override;
+ ExprPtr mutate(RoundOffPtr v) override;
// Eliminate zero length allocations.
- Stmt* mutate(Allocate* v) override;
- Stmt* mutate(Free* v) override;
+ StmtPtr mutate(AllocatePtr v) override;
+ StmtPtr mutate(FreePtr v) override;
// Override to enable condition fusing.
- Block* fuseConditions(Block* v);
- Stmt* fuseSyncThreads(Block* block);
- Stmt* mutate(Block* v) override;
+ BlockPtr fuseConditions(BlockPtr v);
+ StmtPtr fuseSyncThreads(BlockPtr block);
+ StmtPtr mutate(BlockPtr v) override;
};
class TORCH_API IRSimplifier {
public:
- static Expr* simplify(Expr* e) {
+ static ExprPtr simplify(ExprPtr e) {
SimplifierUnderContext ctxsimplifier;
e = e->accept_mutator(&ctxsimplifier);
return ExprHandle(simplify(e.node()));
}
- static Stmt* simplify(Stmt* s) {
+ static StmtPtr simplify(StmtPtr s) {
SimplifierUnderContext ctxsimplifier;
s = s->accept_mutator(&ctxsimplifier);
};
// Flattens the buf and performs the simplifier on the flattened dims.
-Expr* buf_flat_size(Buf* v);
+ExprPtr buf_flat_size(BufPtr v);
// Returns true if expressions A and B can be simplified to an equal expression.
-TORCH_API bool exprEquals(Expr* A, Expr* B);
+TORCH_API bool exprEquals(ExprPtr A, ExprPtr B);
} // namespace tensorexpr
} // namespace jit
}
}
-void IRVerifier::visit(And* v) {
+void IRVerifier::visit(AndPtr v) {
verifyBitwiseOp(v, this);
IRVisitor::visit(v);
}
-void IRVerifier::visit(Or* v) {
+void IRVerifier::visit(OrPtr v) {
verifyBitwiseOp(v, this);
IRVisitor::visit(v);
}
-void IRVerifier::visit(Xor* v) {
+void IRVerifier::visit(XorPtr v) {
verifyBitwiseOp(v, this);
IRVisitor::visit(v);
}
-void IRVerifier::visit(Lshift* v) {
+void IRVerifier::visit(LshiftPtr v) {
verifyBitwiseOp(v, this);
IRVisitor::visit(v);
}
-void IRVerifier::visit(Rshift* v) {
+void IRVerifier::visit(RshiftPtr v) {
verifyBitwiseOp(v, this);
IRVisitor::visit(v);
}
-void IRVerifier::visit(Mod* v) {
+void IRVerifier::visit(ModPtr v) {
if (!v->dtype().is_integral() && !v->dtype().is_floating_point()) {
throw std::runtime_error("invalid dtype: " + std::to_string(v->dtype()));
}
IRVisitor::visit(v);
}
-void IRVerifier::visit(CompareSelect* v) {
+void IRVerifier::visit(CompareSelectPtr v) {
if (v->ret_val1()->dtype() != v->ret_val2()->dtype()) {
throw malformed_ir("bad dtype in CompareSelect");
}
IRVisitor::visit(v);
}
-void IRVerifier::visit(Ramp* v) {
+void IRVerifier::visit(RampPtr v) {
if (v->stride()->dtype() != v->base()->dtype()) {
throw malformed_ir("Bad stride in Ramp");
}
IRVisitor::visit(v);
}
-void IRVerifier::visit(Load* v) {
+void IRVerifier::visit(LoadPtr v) {
auto indices = v->indices();
if (indices.size() > 0 && v->buf()->base_handle()->dtype() != kHandle) {
throw malformed_ir(
IRVisitor::visit(v);
}
-void IRVerifier::visit(IfThenElse* v) {
+void IRVerifier::visit(IfThenElsePtr v) {
if (!v->condition()->dtype().is_integral()) {
throw unsupported_dtype();
}
IRVisitor::visit(v);
}
-void IRVerifier::visit(Intrinsics* v) {
+void IRVerifier::visit(IntrinsicsPtr v) {
// TODO: add a check for OpArgCount and op_type
IRVisitor::visit(v);
}
-void IRVerifier::visit(Store* v) {
+void IRVerifier::visit(StorePtr v) {
auto indices = v->indices();
if (indices.size() > 0 && v->buf()->base_handle()->dtype() != kHandle) {
throw malformed_ir(
IRVisitor::visit(v);
}
-void IRVerifier::visit(For* v) {
+void IRVerifier::visit(ForPtr v) {
if (!v->var()) {
throw malformed_ir("nullptr Var in For loop");
} else if (!v->start()) {
IRVisitor::visit(v);
}
-void IRVerifier::visit(Block* v) {
- for (Stmt* s : v->stmts()) {
+void IRVerifier::visit(BlockPtr v) {
+ for (StmtPtr s : v->stmts()) {
if (s->get_parent() != v) {
throw malformed_ir("Broken child-parent link inside a Block");
}
IRVisitor::visit(v);
}
-void IRVerifier::visit(ExternalCall* v) {
+void IRVerifier::visit(ExternalCallPtr v) {
IRVisitor::visit(v);
}
-void verify(Stmt* s) {
+void verify(StmtPtr s) {
IRVerifier verifier;
s->accept(&verifier);
}
-void verify(Expr* e) {
+void verify(ExprPtr e) {
IRVerifier verifier;
e->accept(&verifier);
}
#include <iostream>
+#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
namespace torch {
public:
IRVerifier() = default;
- void visit(Mod* v) override;
- void visit(And* v) override;
- void visit(Or* v) override;
- void visit(Xor* v) override;
- void visit(Lshift* v) override;
- void visit(Rshift* v) override;
- void visit(CompareSelect* v) override;
- void visit(Ramp* v) override;
- void visit(Load* v) override;
- void visit(IfThenElse* v) override;
- void visit(Intrinsics* v) override;
-
- void visit(ExternalCall* v) override;
- void visit(Store* v) override;
- void visit(For* v) override;
- void visit(Block* v) override;
+ void visit(ModPtr v) override;
+ void visit(AndPtr v) override;
+ void visit(OrPtr v) override;
+ void visit(XorPtr v) override;
+ void visit(LshiftPtr v) override;
+ void visit(RshiftPtr v) override;
+ void visit(CompareSelectPtr v) override;
+ void visit(RampPtr v) override;
+ void visit(LoadPtr v) override;
+ void visit(IfThenElsePtr v) override;
+ void visit(IntrinsicsPtr v) override;
+
+ void visit(ExternalCallPtr v) override;
+ void visit(StorePtr v) override;
+ void visit(ForPtr v) override;
+ void visit(BlockPtr v) override;
};
-TORCH_API void verify(Stmt*);
-TORCH_API void verify(Expr*);
+TORCH_API void verify(StmtPtr);
+TORCH_API void verify(ExprPtr);
TORCH_API void verify(ExprHandle);
} // namespace tensorexpr
v->rhs()->accept(visitor);
}
-void IRVisitor::visit(Add* v) {
+void IRVisitor::visit(AddPtr v) {
visit_binary_op(v, this);
}
-void IRVisitor::visit(Sub* v) {
+void IRVisitor::visit(SubPtr v) {
visit_binary_op(v, this);
}
-void IRVisitor::visit(Mul* v) {
+void IRVisitor::visit(MulPtr v) {
visit_binary_op(v, this);
}
-void IRVisitor::visit(Div* v) {
+void IRVisitor::visit(DivPtr v) {
visit_binary_op(v, this);
}
-void IRVisitor::visit(Mod* v) {
+void IRVisitor::visit(ModPtr v) {
visit_binary_op(v, this);
}
-void IRVisitor::visit(Max* v) {
+void IRVisitor::visit(MaxPtr v) {
visit_binary_op(v, this);
}
-void IRVisitor::visit(Min* v) {
+void IRVisitor::visit(MinPtr v) {
visit_binary_op(v, this);
}
-void IRVisitor::visit(And* v) {
+void IRVisitor::visit(AndPtr v) {
visit_binary_op(v, this);
}
-void IRVisitor::visit(Or* v) {
+void IRVisitor::visit(OrPtr v) {
visit_binary_op(v, this);
}
-void IRVisitor::visit(Xor* v) {
+void IRVisitor::visit(XorPtr v) {
visit_binary_op(v, this);
}
-void IRVisitor::visit(Lshift* v) {
+void IRVisitor::visit(LshiftPtr v) {
visit_binary_op(v, this);
}
-void IRVisitor::visit(Rshift* v) {
+void IRVisitor::visit(RshiftPtr v) {
visit_binary_op(v, this);
}
-void IRVisitor::visit(CompareSelect* v) {
+void IRVisitor::visit(CompareSelectPtr v) {
v->lhs()->accept(this);
v->rhs()->accept(this);
v->ret_val1()->accept(this);
// NOLINTNEXTLINE
#define IMM_VISIT(Type, Name) \
- void IRVisitor::visit(Name##Imm* v) {}
+ void IRVisitor::visit(Name##ImmPtr v) {}
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT);
#undef IMM_VISIT
-void IRVisitor::visit(Cast* v) {
+void IRVisitor::visit(CastPtr v) {
v->src_value()->accept(this);
}
-void IRVisitor::visit(BitCast* v) {
+void IRVisitor::visit(BitCastPtr v) {
v->src_value()->accept(this);
}
-void IRVisitor::visit(Var* v) {}
+void IRVisitor::visit(VarPtr v) {}
-void IRVisitor::visit(Ramp* v) {
+void IRVisitor::visit(RampPtr v) {
v->base()->accept(this);
v->stride()->accept(this);
}
-void IRVisitor::visit(Load* v) {
+void IRVisitor::visit(LoadPtr v) {
v->buf()->accept(this);
- for (Expr* ind : v->indices()) {
+ for (ExprPtr ind : v->indices()) {
ind->accept(this);
}
}
-void IRVisitor::visit(Buf* v) {
+void IRVisitor::visit(BufPtr v) {
v->base_handle()->accept(this);
}
-void IRVisitor::visit(Store* v) {
+void IRVisitor::visit(StorePtr v) {
v->buf()->accept(this);
- for (Expr* ind : v->indices()) {
+ for (ExprPtr ind : v->indices()) {
ind->accept(this);
}
v->value()->accept(this);
}
-void IRVisitor::visit(AtomicAdd* v) {
+void IRVisitor::visit(AtomicAddPtr v) {
v->buf()->accept(this);
- for (Expr* ind : v->indices()) {
+ for (ExprPtr ind : v->indices()) {
ind->accept(this);
}
v->value()->accept(this);
}
-void IRVisitor::visit(SyncThreads* v) {}
+void IRVisitor::visit(SyncThreadsPtr v) {}
-void IRVisitor::visit(ExternalCall* v) {
+void IRVisitor::visit(ExternalCallPtr v) {
v->buf()->accept(this);
- for (Buf* buf_arg : v->buf_args()) {
+ for (BufPtr buf_arg : v->buf_args()) {
buf_arg->accept(this);
}
- for (Expr* arg : v->args()) {
+ for (ExprPtr arg : v->args()) {
arg->accept(this);
}
}
-void IRVisitor::visit(Block* v) {
- for (Stmt* s : *v) {
+void IRVisitor::visit(BlockPtr v) {
+ for (StmtPtr s : *v) {
s->accept(this);
}
}
-void IRVisitor::visit(For* v) {
+void IRVisitor::visit(ForPtr v) {
v->var()->accept(this);
v->start()->accept(this);
v->stop()->accept(this);
}
}
-void IRVisitor::visit(Broadcast* v) {
+void IRVisitor::visit(BroadcastPtr v) {
v->value()->accept(this);
}
-void IRVisitor::visit(IfThenElse* v) {
+void IRVisitor::visit(IfThenElsePtr v) {
v->condition()->accept(this);
v->true_value()->accept(this);
v->false_value()->accept(this);
}
-void IRVisitor::visit(Intrinsics* v) {
+void IRVisitor::visit(IntrinsicsPtr v) {
for (const auto i : c10::irange(v->nparams())) {
v->param(i)->accept(this);
}
}
-void IRVisitor::visit(Allocate* v) {
+void IRVisitor::visit(AllocatePtr v) {
v->buffer_var()->accept(this);
- std::vector<Expr*> dims = v->dims();
- for (Expr* dim : dims) {
+ std::vector<ExprPtr> dims = v->dims();
+ for (ExprPtr dim : dims) {
dim->accept(this);
}
}
-void IRVisitor::visit(Free* v) {
+void IRVisitor::visit(FreePtr v) {
v->buffer_var()->accept(this);
}
-void IRVisitor::visit(Let* v) {
+void IRVisitor::visit(LetPtr v) {
v->var()->accept(this);
v->value()->accept(this);
}
-void IRVisitor::visit(Cond* v) {
- Expr* condition = v->condition();
- Stmt* true_stmt = v->true_stmt();
- Stmt* false_stmt = v->false_stmt();
+void IRVisitor::visit(CondPtr v) {
+ ExprPtr condition = v->condition();
+ StmtPtr true_stmt = v->true_stmt();
+ StmtPtr false_stmt = v->false_stmt();
condition->accept(this);
if (true_stmt) {
true_stmt->accept(this);
}
}
-void IRVisitor::visit(Term* v) {
+void IRVisitor::visit(TermPtr v) {
v->scalar()->accept(this);
- for (auto* t : v->variables()) {
+ for (auto t : v->variables()) {
t->accept(this);
}
}
-void IRVisitor::visit(Polynomial* v) {
+void IRVisitor::visit(PolynomialPtr v) {
v->scalar()->accept(this);
- for (auto* t : v->variables()) {
+ for (auto t : v->variables()) {
t->accept(this);
}
}
-void IRVisitor::visit(RoundOff* v) {
+void IRVisitor::visit(RoundOffPtr v) {
v->lhs()->accept(this);
v->rhs()->accept(this);
}
-void IRVisitor::visit(MaxTerm* v) {
+void IRVisitor::visit(MaxTermPtr v) {
if (v->scalar()) {
v->scalar()->accept(this);
}
- for (auto* t : v->variables()) {
+ for (auto t : v->variables()) {
t->accept(this);
}
}
-void IRVisitor::visit(MinTerm* v) {
+void IRVisitor::visit(MinTermPtr v) {
if (v->scalar()) {
v->scalar()->accept(this);
}
- for (auto* t : v->variables()) {
+ for (auto t : v->variables()) {
t->accept(this);
}
}
-void IRVisitor::visit(ReduceOp* v) {
+void IRVisitor::visit(ReduceOpPtr v) {
v->body()->accept(this);
- for (auto* r : v->reduce_args()) {
+ for (auto r : v->reduce_args()) {
r->accept(this);
}
}
#pragma once
#include <c10/core/ScalarType.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
namespace torch {
namespace jit {
namespace tensorexpr {
-class Add;
-class Sub;
-class Mul;
-class Div;
-class Mod;
-class Max;
-class Min;
-class And;
-class Or;
-class Xor;
-class Lshift;
-class Rshift;
-class CompareSelect;
-
-#define IMM_DECLARE(Type, Name) class Name##Imm;
-
-AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE)
-#undef IMM_DECLARE
-
-class Cast;
-class BitCast;
-class Var;
-class Buf;
-class Ramp;
-class Load;
-class For;
-class Block;
-class Store;
-class Broadcast;
-class IfThenElse;
-class Intrinsics;
-class Allocate;
-class Free;
-class Let;
-class Cond;
-class Term;
-class Polynomial;
-class RoundOff;
-class MaxTerm;
-class MinTerm;
-class ReduceOp;
-class AtomicAdd;
-class SyncThreads;
-class ExternalCall;
-
class TORCH_API IRVisitor {
public:
virtual ~IRVisitor() = default;
- virtual void visit(Add* v);
- virtual void visit(Sub* v);
- virtual void visit(Mul* v);
- virtual void visit(Div* v);
- virtual void visit(Mod* v);
- virtual void visit(Max* v);
- virtual void visit(Min* v);
- virtual void visit(And* v);
- virtual void visit(Or* v);
- virtual void visit(Xor* v);
- virtual void visit(Lshift* v);
- virtual void visit(Rshift* v);
- virtual void visit(CompareSelect* v);
-
-#define IMM_PRINT_VISIT(Type, Name) virtual void visit(Name##Imm* v);
+ virtual void visit(AddPtr v);
+ virtual void visit(SubPtr v);
+ virtual void visit(MulPtr v);
+ virtual void visit(DivPtr v);
+ virtual void visit(ModPtr v);
+ virtual void visit(MaxPtr v);
+ virtual void visit(MinPtr v);
+ virtual void visit(AndPtr v);
+ virtual void visit(OrPtr v);
+ virtual void visit(XorPtr v);
+ virtual void visit(LshiftPtr v);
+ virtual void visit(RshiftPtr v);
+ virtual void visit(CompareSelectPtr v);
+
+#define IMM_PRINT_VISIT(Type, Name) virtual void visit(Name##ImmPtr v);
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_PRINT_VISIT)
#undef IMM_PRINT_VISIT
- virtual void visit(Cast* v);
- virtual void visit(BitCast* v);
- virtual void visit(Var* v);
- virtual void visit(Buf* v);
- virtual void visit(Ramp* v);
- virtual void visit(Load* v);
- virtual void visit(For* v);
- virtual void visit(Block* v);
- virtual void visit(Store* v);
- virtual void visit(Broadcast* v);
- virtual void visit(IfThenElse* v);
- virtual void visit(Intrinsics* v);
- virtual void visit(Allocate* v);
- virtual void visit(Free* v);
- virtual void visit(Let* v);
- virtual void visit(Cond* v);
- virtual void visit(Term* v);
- virtual void visit(Polynomial* v);
- virtual void visit(RoundOff* v);
- virtual void visit(MaxTerm* v);
- virtual void visit(MinTerm* v);
- virtual void visit(ReduceOp* v);
- virtual void visit(AtomicAdd* v);
- virtual void visit(SyncThreads* v);
- virtual void visit(ExternalCall* v);
+ virtual void visit(CastPtr v);
+ virtual void visit(BitCastPtr v);
+ virtual void visit(VarPtr v);
+ virtual void visit(BufPtr v);
+ virtual void visit(RampPtr v);
+ virtual void visit(LoadPtr v);
+ virtual void visit(ForPtr v);
+ virtual void visit(BlockPtr v);
+ virtual void visit(StorePtr v);
+ virtual void visit(BroadcastPtr v);
+ virtual void visit(IfThenElsePtr v);
+ virtual void visit(IntrinsicsPtr v);
+ virtual void visit(AllocatePtr v);
+ virtual void visit(FreePtr v);
+ virtual void visit(LetPtr v);
+ virtual void visit(CondPtr v);
+ virtual void visit(TermPtr v);
+ virtual void visit(PolynomialPtr v);
+ virtual void visit(RoundOffPtr v);
+ virtual void visit(MaxTermPtr v);
+ virtual void visit(MinTermPtr v);
+ virtual void visit(ReduceOpPtr v);
+ virtual void visit(AtomicAddPtr v);
+ virtual void visit(SyncThreadsPtr v);
+ virtual void visit(ExternalCallPtr v);
};
} // namespace tensorexpr
c10::optional<TensorInfo> getTensorInfo(BufHandle b) {
std::vector<int64_t> dims;
for (auto dim : b.dims()) {
- auto val = dynamic_cast<IntImm*>(dim.node());
+ auto val = to<IntImm>(dim.node());
if (!val) {
return c10::nullopt;
}
} // namespace jit
} // namespace torch
-static at::ScalarType tensorType(Buf* b) {
+static at::ScalarType tensorType(BufPtr b) {
return static_cast<at::ScalarType>(b->dtype().scalar_type());
}
-std::vector<int64_t> bufferSizes(Buf* b) {
+std::vector<int64_t> bufferSizes(BufPtr b) {
std::vector<int64_t> sizes;
for (size_t i = 0; i < b->ndim(); i++) {
- sizes.push_back(dynamic_cast<IntImm*>(b->dim(i))->value());
+ sizes.push_back(to<IntImm>(b->dim(i))->value());
}
return sizes;
}
ExprHandle TensorExprKernel::chunk(
- Buf* b,
+ BufPtr b,
size_t chunkIdx,
int64_t dim,
int64_t chunks,
// output[i,j+l2,k] = inp3[i,j,k]
auto output_sizes_expr = ExprHandleVectorToExprVector(outputShape);
- auto output_buf = new Buf("aten_cat", output_sizes_expr, ToDtype(high_type));
+ auto output_buf =
+ alloc<Buf>("aten_cat", output_sizes_expr, ToDtype(high_type));
if (non_empty_inputs.size() == 0) {
- return new Tensor(output_buf, new tensorexpr::Block({}));
+ return new Tensor(
+ output_buf, alloc<tensorexpr::Block>(std::vector<StmtPtr>({})));
}
int64_t concat_dim = c10::get<int64_t>(arg_dim);
auto gen_code_for_input = [&](const BufHandle& inp,
size_t inp_pos,
- Expr* concat_dim_size,
+ ExprPtr concat_dim_size,
const std::vector<ExprHandle>& dims) {
- std::vector<Var*> for_vars(dims.size());
- std::vector<Expr*> load_indices(dims.size());
- std::vector<Expr*> store_indices(dims.size());
+ std::vector<VarPtr> for_vars(dims.size());
+ std::vector<ExprPtr> load_indices(dims.size());
+ std::vector<ExprPtr> store_indices(dims.size());
for (size_t i = 0; i < dims.size(); ++i) {
- for_vars[i] = new Var(
+ for_vars[i] = alloc<Var>(
"i" + c10::to_string(inp_pos) + "_" + c10::to_string(i), kInt);
load_indices[i] = for_vars[i];
if (i == norm_concat_dim) {
- store_indices[i] = new Add(for_vars[i], concat_dim_size);
+ store_indices[i] = alloc<Add>(for_vars[i], concat_dim_size);
} else {
store_indices[i] = for_vars[i];
}
}
auto inp_buf = inp.node();
- auto load_expr = new Load(inp_buf, load_indices);
+ auto load_expr = alloc<Load>(inp_buf, load_indices);
auto load_promoted = promoteToDtype(ExprHandle(load_expr), high_type);
- Stmt* st = new Store(output_buf, store_indices, load_promoted.node());
+ StmtPtr st = alloc<Store>(output_buf, store_indices, load_promoted.node());
for (size_t i = dims.size(); i > 0; --i) {
- st = new For(for_vars[i - 1], new IntImm(0), dims[i - 1].node(), st);
+ st =
+ alloc<For>(for_vars[i - 1], alloc<IntImm>(0), dims[i - 1].node(), st);
}
return st;
};
- Expr* concat_dim_size = nullptr;
- auto block = new tensorexpr::Block({});
+ ExprPtr concat_dim_size = nullptr;
+ auto block = alloc<tensorexpr::Block>(std::vector<StmtPtr>({}));
for (size_t i = 0; i < non_empty_inputs.size(); ++i) {
auto input_dims =
ExprVectorToExprHandleVector(non_empty_inputs[i].node()->dims());
if (concat_dim_size == nullptr) {
- concat_dim_size = new IntImm(0);
+ concat_dim_size = alloc<IntImm>(0);
}
block->append_stmt(gen_code_for_input(
non_empty_inputs[i], i, concat_dim_size, input_dims));
concat_dim_size =
- new Add(concat_dim_size, input_dims[norm_concat_dim].node());
+ alloc<Add>(concat_dim_size, input_dims[norm_concat_dim].node());
}
return new Tensor(output_buf, IRSimplifier::simplify(block));
}
std::vector<ExprHandle> newAxes(axes.begin(), axes.end());
ExprHandle load = promoteToDtype(
tensorOrConstant(nonEmptyInputs[0], newAxes), highType);
- size_t offset =
- dynamic_cast<IntImm*>(nonEmptyInputs[0].node()->dim(dim))->value();
+ size_t offset = to<IntImm>(nonEmptyInputs[0].node()->dim(dim))->value();
newAxes[dim] = newAxes[dim] - IntImm::make(offset);
for (size_t ii = 1; ii < nonEmptyInputs.size(); ++ii) {
load,
promoteToDtype(tensorOrConstant(input, newAxes), highType));
- offset += dynamic_cast<IntImm*>(input.node()->dim(dim))->value();
+ offset += to<IntImm>(input.node()->dim(dim))->value();
newAxes[dim] = axes[dim] - IntImm::make(offset);
}
// Once we have a performant TE representation for conv2d, we could use it
// here instead of the external call!
- Stmt* s = ExternalCall::make(
+ StmtPtr s = ExternalCall::make(
ResultBuf,
"nnc_aten_conv2d",
{inp, w, b},
*/
// NOLINTNEXTLINE(clang-diagnostic-unused-variable)
ExprHandle cur_stride = 1;
- std::vector<Expr*> dims, indices;
+ std::vector<ExprPtr> dims, indices;
for (size_t idx = 0; idx < view_dims.size(); idx++) {
- dims.push_back(new IntImm(view_dims[idx]));
+ dims.push_back(alloc<IntImm>(view_dims[idx]));
indices.push_back(axes[idx].node());
}
ExprHandle flat_idx = ExprHandle(flatten_index(dims, indices));
}
// Return the (lower, upper) loop bounds if they are constants, else nullopt.
-c10::optional<std::pair<int64_t, int64_t>> loopBounds(For* loop) {
+c10::optional<std::pair<int64_t, int64_t>> loopBounds(ForPtr loop) {
auto start = IRSimplifier::simplify(loop->start());
auto stop = IRSimplifier::simplify(loop->stop());
if (!start->isConstant() || !stop->isConstant()) {
}
// True if all the loops in this vector have equal bounds.
-bool loopBoundsAllEqual(const std::vector<For*>& loops) {
+bool loopBoundsAllEqual(const std::vector<ForPtr>& loops) {
auto bounds = loopBounds(loops[0]);
if (!bounds) {
return false;
// on matching bounds exists to avoid inserting conditionals on the loop
// indices where none would be needed, which would significantly complicate
// vectorization.
-void fuseAllLoops(Stmt* st) {
- if (auto block = dynamic_cast<tensorexpr::Block*>(st)) {
- std::vector<For*> loopsToFuse;
+void fuseAllLoops(StmtPtr st) {
+ if (auto block = to<tensorexpr::Block>(st)) {
+ std::vector<ForPtr> loopsToFuse;
for (auto stmt : *block) {
- auto loop = dynamic_cast<For*>(stmt);
+ auto loop = to<For>(stmt);
if (!loop) {
// Block contains something that's not a loop. Quit.
return;
return;
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fusedLoop;
+ ForPtr fusedLoop;
if (!LoopNest::fuseLoops(loopsToFuse, &fusedLoop)) {
return;
}
}
}
-Stmt* TensorExprKernel::transformLoops(BackendType backendType, Stmt* st) {
+StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) {
torch::jit::tensorexpr::LoopNest l(st, bufOutputs_);
GRAPH_DEBUG("Original Stmt:\n", std::to_string(l.root_stmt()), "\n");
if (backendType == kCudaCodeGen) {
for (auto buf : bufOutputs_) {
- std::vector<For*> loops = l.getLoopStmtsFor(buf);
+ std::vector<ForPtr> loops = l.getLoopStmtsFor(buf);
if (loops.empty()) {
// This happens when Buf is 0-dim
continue;
}
- For* flattened = nullptr;
+ ForPtr flattened = nullptr;
LoopNest::flatten(loops, &flattened);
assert(flattened);
if (loopLevels == 2) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* inner;
+ ForPtr inner;
const int kDefaultBlockSize = 512;
if (blockSize < 0) {
blockSize = kDefaultBlockSize;
inner->set_gpu_thread_index(0);
} else if (loopLevels == 3) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* inner;
+ ForPtr inner;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* inner1;
+ ForPtr inner1;
// TODO: change the number of microprocessors
const int kDefaultBlockCount = 1280;
const int kDefaultBlockSize = 256;
if (buf->dtype().scalar_type() == ScalarType::Byte) {
blockSize = default_uint8_blocksize;
}
- std::vector<For*> loops = l.getLoopStmtsFor(buf);
+ std::vector<ForPtr> loops = l.getLoopStmtsFor(buf);
TORCH_INTERNAL_ASSERT(!loops.empty(), "loops should not be empty");
- For* flattened = nullptr;
+ ForPtr flattened = nullptr;
LoopNest::flatten(loops, &flattened);
assert(flattened);
- For* inner = nullptr;
+ ForPtr inner = nullptr;
LoopNest::splitWithMask(flattened, blockSize, &inner);
flattened->set_gpu_block_index(0);
inner->set_gpu_thread_index(0);
l.vectorizeInnerLoops();
}
- Stmt* stmt = l.root_stmt();
+ StmtPtr stmt = l.root_stmt();
// Arithmetic Simplification.
stmt = IRSimplifier::simplify(stmt);
GRAPH_DEBUG("Final Stmt:\n", std::to_string(stmt), "\n");
Tensor* TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) {
const TensorTypePtr& tt = v->type()->expect<TensorType>();
TORCH_INTERNAL_ASSERT(bufs_.count(v));
- Buf* buf = bufs_.at(v);
+ BufPtr buf = bufs_.at(v);
// No shape info is present in the graph
if (!tt->sizes().concrete_sizes()) {
te_sizes.push_back(IntImm::make(s));
}
- Buf* buf = new Buf(
+ BufPtr buf = alloc<Buf>(
"const_" + v->debugName(),
ExprHandleVectorToExprVector(te_sizes),
ToDtype(static_cast<ScalarType>(*tt->scalarType())));
OptimizeCat(graph_);
// Block to collect the Stmts corresponding to all tensors.
- auto block = new Block({});
+ auto block = alloc<Block>(std::vector<StmtPtr>({}));
// Bind inputs to buffers.
nInputs_ = graph_->inputs().size();
}
BackendType backendType = inferBackendTypeFromDevice(device_);
- Stmt* stmt = transformLoops(backendType, block);
+ StmtPtr stmt = transformLoops(backendType, block);
// Generate code.
codegen_ = CreateCodeGen(
return runArgs;
}
-Stmt* TensorExprKernel::getCodeGenStmt() {
+StmtPtr TensorExprKernel::getCodeGenStmt() {
return codegen_->stmt();
}
inline std::vector<int64_t> bufferSizes(const T& t) {
std::vector<int64_t> sizes;
for (size_t i = 0; i < t->ndim(); i++) {
- sizes.push_back(dynamic_cast<IntImm*>(t->dim(i))->value());
+ sizes.push_back(to<IntImm>(t->dim(i))->value());
}
return sizes;
}
class TORCH_API TensorExprKernel {
struct ConstantDescr {
- Buf* buf;
+ BufPtr buf;
void* ptr;
};
InterpreterState(code_).run(stack);
}
- Stmt* getCodeGenStmt();
+ StmtPtr getCodeGenStmt();
std::string getCodeText(const std::string& attr = "") {
return codegen_->getCodeText(attr);
std::vector<std::vector<ExprHandle>> shapes);
ExprHandle chunk(
- Buf* b,
+ BufPtr b,
size_t chunkIdx,
int64_t dim,
int64_t chunks,
void bindConstant(const torch::jit::Value* v);
- Stmt* transformLoops(BackendType backendType, Stmt* st);
+ StmtPtr transformLoops(BackendType backendType, StmtPtr st);
std::string getCodeGenName(BackendType backendType);
std::vector<std::vector<int64_t>> tensorOutputSizes_;
std::vector<std::vector<int64_t>> tensorOutputStrides_;
std::vector<UnpackedTensorOptions> tensorOutputTensorOptions_;
- std::unordered_set<Buf*> bufOutputs_;
- std::unordered_map<const torch::jit::Value*, Buf*> bufs_;
+ std::unordered_set<BufPtr> bufOutputs_;
+ std::unordered_map<const torch::jit::Value*, BufPtr> bufs_;
std::unordered_map<const torch::jit::Value*, VarHandle> scalars_;
std::unordered_map<const torch::jit::Value*, std::string> input_name_map_;
std::unique_ptr<CodeGen> codegen_;
llvm::Type* Int8PtrTy_;
llvm::Type* VoidTy_;
- std::unordered_map<const Var*, int> varToArg_;
- std::unordered_map<const Var*, llvm::Value*> varToVal_;
- std::unordered_map<Block*, std::vector<Var*>> scopeToVar_;
- Block* scope_;
+ std::unordered_map<VarPtr, int> varToArg_;
+ std::unordered_map<VarPtr, llvm::Value*> varToVal_;
+ std::unordered_map<BlockPtr, std::vector<VarPtr>> scopeToVar_;
+ BlockPtr scope_;
std::string llvmCode_;
std::string asmCode_;
llvm::Type* dtypeToLLVM(Dtype dtype);
llvm::Type* dtypeToLLVMPtr(Dtype dtype);
void emitWrapper(const std::vector<llvm::Type*>& params);
- void emitKernel(Stmt* stmt, const std::vector<llvm::Type*>& params);
+ void emitKernel(StmtPtr stmt, const std::vector<llvm::Type*>& params);
llvm::Value* toVec(llvm::Value* v, int lanes);
enum Arity {
Arity arity,
int lanes);
- llvm::Value* varToValue(Var* var);
+ llvm::Value* varToValue(VarPtr var);
void replaceVarMapping(
- const std::vector<Var*>& vars,
+ const std::vector<VarPtr>& vars,
const std::vector<llvm::Value*>& vals);
llvm::Value* packFuncArgs(const std::vector<llvm::Value*>& func_args);
std::vector<llvm::Value*> unpackFuncArgs(llvm::Value* packed, int arg_count);
- void processParallelFor(For* v);
+ void processParallelFor(ForPtr v);
public:
LLVMCodeGenImpl(
- Stmt* stmt,
+ StmtPtr stmt,
const std::vector<CodeGen::BufferArg>& args,
at::Device device,
Dtype dtype,
llvm::JITTargetAddress getKernelAddress() const;
- void visit(Add* v) override;
- void visit(Sub* v) override;
- void visit(Mul* v) override;
- void visit(Div* v) override;
- void visit(Mod* v) override;
- void visit(Max* v) override;
- void visit(Min* v) override;
- void visit(And* v) override;
- void visit(Or* v) override;
- void visit(Xor* v) override;
- void visit(Lshift* v) override;
- void visit(Rshift* v) override;
- void visit(CompareSelect* v) override;
-
-#define IMM_VISIT_DECLARE(_1, Name) void visit(Name##Imm* v) override;
+ void visit(AddPtr v) override;
+ void visit(SubPtr v) override;
+ void visit(MulPtr v) override;
+ void visit(DivPtr v) override;
+ void visit(ModPtr v) override;
+ void visit(MaxPtr v) override;
+ void visit(MinPtr v) override;
+ void visit(AndPtr v) override;
+ void visit(OrPtr v) override;
+ void visit(XorPtr v) override;
+ void visit(LshiftPtr v) override;
+ void visit(RshiftPtr v) override;
+ void visit(CompareSelectPtr v) override;
+
+#define IMM_VISIT_DECLARE(_1, Name) void visit(Name##ImmPtr v) override;
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT_DECLARE);
#undef IMM_VISIT_DECLARE
- void visit(Cast* v) override;
- void visit(BitCast* v) override;
- void visit(Var* v) override;
- void visit(Ramp* v) override;
- void visit(Load* v) override;
- void visit(For* v) override;
- void visit(Block* v) override;
- void visit(Store* v) override;
- void visit(Broadcast* v) override;
- void visit(IfThenElse* v) override;
- void visit(Intrinsics* v) override;
- void visit(Allocate* v) override;
- void visit(Free* v) override;
- void visit(Let* v) override;
- void visit(Cond* v) override;
- void visit(ExternalCall* v) override;
-
- void emitIsNan(Intrinsics* v);
+ void visit(CastPtr v) override;
+ void visit(BitCastPtr v) override;
+ void visit(VarPtr v) override;
+ void visit(RampPtr v) override;
+ void visit(LoadPtr v) override;
+ void visit(ForPtr v) override;
+ void visit(BlockPtr v) override;
+ void visit(StorePtr v) override;
+ void visit(BroadcastPtr v) override;
+ void visit(IfThenElsePtr v) override;
+ void visit(IntrinsicsPtr v) override;
+ void visit(AllocatePtr v) override;
+ void visit(FreePtr v) override;
+ void visit(LetPtr v) override;
+ void visit(CondPtr v) override;
+ void visit(ExternalCallPtr v) override;
+
+ void emitIsNan(IntrinsicsPtr v);
llvm::Value* emitUnmaskedLoad(llvm::Value* addr, llvm::Value* idx);
llvm::Value* emitMaskedLoad(
LLVMCodeGen::~LLVMCodeGen() = default;
-LLVMCodeGen::LLVMCodeGen(Stmt* stmt)
+LLVMCodeGen::LLVMCodeGen(StmtPtr stmt)
: LLVMCodeGen(stmt, std::vector<CodeGen::BufferArg>()) {}
LLVMCodeGen::LLVMCodeGen(
- Stmt* stmt,
+ StmtPtr stmt,
const std::vector<BufferArg>& args,
at::Device device,
const std::string& kernel_func_name,
} // namespace
LLVMCodeGenImpl::LLVMCodeGenImpl(
- Stmt* stmt,
+ StmtPtr stmt,
const std::vector<CodeGen::BufferArg>& args,
at::Device device,
Dtype dtype,
class LLVMIntrinsicsExpander : public GenericIntrinsicsExpander {
private:
- Expr* mutate(Intrinsics* v) {
+ ExprPtr mutate(IntrinsicsPtr v) {
if (v->op_type() == kTanh) {
ScalarType stype = v->dtype().scalar_type();
if (stype == ScalarType::Float) {
};
void LLVMCodeGenImpl::emitKernel(
- Stmt* stmt,
+ StmtPtr stmt,
const std::vector<llvm::Type*>& params) {
// Set insert point to the real function.
bb_ = llvm::BasicBlock::Create(getContext(), "entry", fn_);
// TODO: The binary ops are copypasta.
-void LLVMCodeGenImpl::visit(Add* v) {
+void LLVMCodeGenImpl::visit(AddPtr v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
}
}
-void LLVMCodeGenImpl::visit(Sub* v) {
+void LLVMCodeGenImpl::visit(SubPtr v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
}
}
-void LLVMCodeGenImpl::visit(Mul* v) {
+void LLVMCodeGenImpl::visit(MulPtr v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
}
}
-void LLVMCodeGenImpl::visit(Div* v) {
+void LLVMCodeGenImpl::visit(DivPtr v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
}
}
-void LLVMCodeGenImpl::visit(And* v) {
+void LLVMCodeGenImpl::visit(AndPtr v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
}
}
-void LLVMCodeGenImpl::visit(Or* v) {
+void LLVMCodeGenImpl::visit(OrPtr v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
}
}
-void LLVMCodeGenImpl::visit(Xor* v) {
+void LLVMCodeGenImpl::visit(XorPtr v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
}
}
-void LLVMCodeGenImpl::visit(Lshift* v) {
+void LLVMCodeGenImpl::visit(LshiftPtr v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
}
}
-void LLVMCodeGenImpl::visit(Rshift* v) {
+void LLVMCodeGenImpl::visit(RshiftPtr v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
}
}
-void LLVMCodeGenImpl::visit(Mod* v) {
+void LLVMCodeGenImpl::visit(ModPtr v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
}
}
-void LLVMCodeGenImpl::visit(Max* v) {
+void LLVMCodeGenImpl::visit(MaxPtr v) {
v->lhs()->accept(this);
auto lhs = this->value_;
v->rhs()->accept(this);
irb_.CreateFCmp(llvm::FCmpInst::FCMP_OGT, lhs, rhs), lhs, rhs));
}
-void LLVMCodeGenImpl::visit(Min* v) {
+void LLVMCodeGenImpl::visit(MinPtr v) {
v->lhs()->accept(this);
auto lhs = this->value_;
v->rhs()->accept(this);
irb_.CreateFCmp(llvm::FCmpInst::FCMP_OLT, lhs, rhs), lhs, rhs));
}
-void LLVMCodeGenImpl::visit(CompareSelect* v) {
+void LLVMCodeGenImpl::visit(CompareSelectPtr v) {
auto genUnbiased = [this, v]() -> llvm::Value* {
v->lhs()->accept(this);
auto lhs = this->value_;
}
#define IMM_VISIT_DECLARE(Type, Name) \
- void LLVMCodeGenImpl::visit(Name##Imm* v) { \
+ void LLVMCodeGenImpl::visit(Name##ImmPtr v) { \
value_ = getFromType<Type>(Name##Ty_, v->value()); \
}
AT_FORALL_SCALAR_TYPES(IMM_VISIT_DECLARE);
#undef IMM_VISIT_DECLARE
-void LLVMCodeGenImpl::visit(HalfImm* v) {
+void LLVMCodeGenImpl::visit(HalfImmPtr v) {
value_ = llvm::ConstantFP::get(HalfTy_, v->value());
}
-void LLVMCodeGenImpl::visit(BoolImm* v) {
+void LLVMCodeGenImpl::visit(BoolImmPtr v) {
value_ = llvm::ConstantInt::get(BoolTy_, v->value());
}
}
}
-void LLVMCodeGenImpl::visit(Cast* v) {
+void LLVMCodeGenImpl::visit(CastPtr v) {
v->src_value()->accept(this);
llvm::Type* dstType =
}
}
-void LLVMCodeGenImpl::visit(BitCast* v) {
+void LLVMCodeGenImpl::visit(BitCastPtr v) {
v->src_value()->accept(this);
llvm::Type* dstType = dtypeToLLVM(v->dtype());
value_ = irb_.CreateBitOrPointerCast(value_, dstType);
}
-void LLVMCodeGenImpl::visit(Var* v) {
+void LLVMCodeGenImpl::visit(VarPtr v) {
value_ = varToValue(v);
}
-llvm::Value* LLVMCodeGenImpl::varToValue(Var* v) {
+llvm::Value* LLVMCodeGenImpl::varToValue(VarPtr v) {
// It is possible for v to be in both varToVal_ and varToArgs.
// In that case, varToVal_ takes precedence.
if (varToVal_.count(v)) {
}
void LLVMCodeGenImpl::replaceVarMapping(
- const std::vector<Var*>& vars,
+ const std::vector<VarPtr>& vars,
const std::vector<llvm::Value*>& vals) {
TORCH_CHECK(vars.size() == vals.size());
for (const auto i : c10::irange(vars.size())) {
- Var* var = vars[i];
+ VarPtr var = vars[i];
llvm::Value* val = vals[i];
if (val) {
varToVal_[var] = val;
}
}
-void LLVMCodeGenImpl::visit(Ramp* v) {
+void LLVMCodeGenImpl::visit(RampPtr v) {
v->base()->accept(this);
auto base = this->value_;
v->stride()->accept(this);
return phi;
}
-void LLVMCodeGenImpl::visit(Load* v) {
+void LLVMCodeGenImpl::visit(LoadPtr v) {
if (v->dtype().lanes() == 1) {
v->base_handle()->accept(this);
auto base = this->value_;
bool unmasked_load = true;
// Handle the case where the load is contiguous and unmasked efficiently
- auto* idx_ramp = dynamic_cast<Ramp*>(v->flat_index());
+ auto idx_ramp = to<Ramp>(v->flat_index());
if (idx_ramp) {
- auto* stride_imm = dynamic_cast<IntImm*>(idx_ramp->stride());
+ auto stride_imm = to<IntImm>(idx_ramp->stride());
if (stride_imm && stride_imm->value() == 1) {
v->base_handle()->accept(this);
auto base = this->value_;
// * Move the body into its own closure.
// * Identify var across the boundary into arguments and forward them.
// * Send the closure and range to the dispatcher for execution.
-void LLVMCodeGenImpl::processParallelFor(For* v) {
+void LLVMCodeGenImpl::processParallelFor(ForPtr v) {
// Create "start" and "stop" values.
v->start()->accept(this);
auto start = this->value_;
auto stop = this->value_;
// The Vars that need to be forward in the body closure.
- std::vector<Var*> body_arg_vars;
+ std::vector<VarPtr> body_arg_vars;
// Corresponding Value* that was used in the old body for the caller.
std::vector<llvm::Value*> body_caller_vals;
// Corresponding Value* that will be used in the new body closure.
std::vector<llvm::Value*> body_closure_args;
- // Identify the Var* used in the body, and generated outside.
+ // Identify the VarPtr used in the body, and generated outside.
VarFinder var_finder;
v->body()->accept(&var_finder);
auto& vars = var_finder.vars();
value_ = llvm::ConstantInt::get(IntTy_, 0);
}
-void LLVMCodeGenImpl::visit(For* v) {
+void LLVMCodeGenImpl::visit(ForPtr v) {
if (v->is_parallel()) {
processParallelFor(v);
return;
value_ = llvm::ConstantInt::get(IntTy_, 0);
}
-void LLVMCodeGenImpl::visit(Block* v) {
- Block* last = scope_;
+void LLVMCodeGenImpl::visit(BlockPtr v) {
+ BlockPtr last = scope_;
scope_ = v;
- for (Stmt* s : *v) {
+ for (StmtPtr s : *v) {
s->accept(this);
}
auto it = scopeToVar_.find(v);
if (it != scopeToVar_.end()) {
- for (Var* e : it->second) {
+ for (VarPtr e : it->second) {
if (varToVal_.erase(e) != 1) {
throw std::runtime_error("erasing var that doesn't exist");
}
irb_.SetInsertPoint(tailblock);
}
-void LLVMCodeGenImpl::visit(Store* v) {
+void LLVMCodeGenImpl::visit(StorePtr v) {
if (v->value()->dtype().lanes() == 1) {
v->base_handle()->accept(this);
auto base = this->value_;
auto val = this->value_;
// Handle the case where the store is contiguous and unmasked efficiently
- auto* idx_ramp = dynamic_cast<Ramp*>(v->flat_index());
+ auto idx_ramp = to<Ramp>(v->flat_index());
if (idx_ramp) {
- auto* stride_imm = dynamic_cast<IntImm*>(idx_ramp->stride());
+ auto stride_imm = to<IntImm>(idx_ramp->stride());
if (stride_imm && stride_imm->value() == 1) {
idx_ramp->base()->accept(this);
auto first_idx = value_;
value_ = llvm::ConstantInt::get(IntTy_, 0);
}
-void LLVMCodeGenImpl::visit(Broadcast* v) {
+void LLVMCodeGenImpl::visit(BroadcastPtr v) {
v->value()->accept(this);
int lanes = v->lanes();
value_ = irb_.CreateVectorSplat(lanes, value_);
}
-void LLVMCodeGenImpl::visit(IfThenElse* v) {
+void LLVMCodeGenImpl::visit(IfThenElsePtr v) {
v->condition()->accept(this);
llvm::Value* condition = value_;
llvm::Value* c = irb_.CreateICmpNE(
}
}
-void LLVMCodeGenImpl::emitIsNan(Intrinsics* v) {
+void LLVMCodeGenImpl::emitIsNan(IntrinsicsPtr v) {
v->param(0)->accept(this);
llvm::Type* dstType = dtypeToLLVM(v->dtype());
if (!v->param(0)->dtype().is_floating_point()) {
return SimdCallee{callee.getFunctionType(), callee.getCallee(), useSimd};
}
-void LLVMCodeGenImpl::visit(Intrinsics* v) {
+void LLVMCodeGenImpl::visit(IntrinsicsPtr v) {
llvm::FunctionType* call_ty = nullptr;
llvm::Value* call_fn = nullptr;
bool call_simd_sleef = false;
}
}
-void LLVMCodeGenImpl::visit(ExternalCall* v) {
+void LLVMCodeGenImpl::visit(ExternalCallPtr v) {
constexpr int max_buffers = 10;
constexpr int max_dimensions = 40;
// Prepare a vector of bufs that we need to pass to the external function.
// This vector is the output buf followed by the buf_args.
- std::vector<Buf*> bufs(v->buf_args());
+ std::vector<BufPtr> bufs(v->buf_args());
bufs.insert(bufs.begin(), v->buf());
int64_t bufs_num = bufs.size();
// Count the size of dims array - it consists of dimension of all bufs
// concatenated together.
int64_t dims_num = 0;
- for (Buf* b : bufs) {
+ for (BufPtr b : bufs) {
dims_num += b->dims().size();
}
int i = 0;
int dim_idx = 0;
- for (Buf* b : bufs) {
+ for (BufPtr b : bufs) {
// Store value for buf pointer
auto gep = irb_.CreateInBoundsGEP(
buf_ptrs, {llvm::ConstantInt::getSigned(IntTy_, i)});
}
i = 0;
- for (Expr* arg : v->args()) {
+ for (ExprPtr arg : v->args()) {
auto gep = irb_.CreateInBoundsGEP(
extra_args, {llvm::ConstantInt::getSigned(IntTy_, i)});
arg->accept(this);
value_ = llvm::ConstantInt::get(IntTy_, 0);
}
-void LLVMCodeGenImpl::visit(Allocate* v) {
+void LLVMCodeGenImpl::visit(AllocatePtr v) {
llvm::Value* size =
llvm::ConstantInt::getSigned(LongTy_, v->dtype().byte_size());
- for (Expr* e : v->dims()) {
+ for (ExprPtr e : v->dims()) {
e->accept(this);
size = irb_.CreateMul(size, irb_.CreateZExt(value_, LongTy_));
}
varToVal_[v->buffer_var()] = malloc;
}
-void LLVMCodeGenImpl::visit(Free* v) {
+void LLVMCodeGenImpl::visit(FreePtr v) {
value_ = llvm::ConstantInt::get(IntTy_, 0);
llvm::Value* ptr = varToVal_.at(v->buffer_var());
if (!llvm::isa<llvm::AllocaInst>(ptr)) {
}
}
-void LLVMCodeGenImpl::visit(Let* v) {
+void LLVMCodeGenImpl::visit(LetPtr v) {
v->value()->accept(this);
if (!varToVal_.count(v->var())) {
varToVal_.emplace(v->var(), value_);
}
}
-void LLVMCodeGenImpl::visit(Cond* v) {
+void LLVMCodeGenImpl::visit(CondPtr v) {
// Even if true_stmt and false_stmt are nullptr,
// in case condition is a function call with side effect,
// we still evaluate it.
class TORCH_API LLVMCodeGen : public CodeGen {
public:
explicit LLVMCodeGen(
- Stmt* stmt,
+ StmtPtr stmt,
const std::vector<BufferArg>& args,
at::Device device = at::kCPU,
const std::string& kernel_func_name = "func",
c10::optional<std::string> triple = c10::nullopt,
c10::optional<std::string> cpu = c10::nullopt,
c10::optional<std::string> attrs = c10::nullopt);
- explicit LLVMCodeGen(Stmt* stmt);
+ explicit LLVMCodeGen(StmtPtr stmt);
LLVMCodeGen() = delete;
~LLVMCodeGen() override;
struct TORCH_API LLVMCodeGenBuilder {
using BufferArg = CodeGen::BufferArg;
- LLVMCodeGenBuilder(Stmt* stmt, std::vector<BufferArg> args)
+ LLVMCodeGenBuilder(StmtPtr stmt, std::vector<BufferArg> args)
: stmt_(stmt), args_(std::move(args)) {}
LLVMCodeGenBuilder& device(at::Device device) {
}
private:
- Stmt* stmt_;
+ StmtPtr stmt_;
std::vector<BufferArg> args_;
at::Device device_ = at::kCPU;
std::string kernelFuncName_ = "func";
verify(root_stmt_);
}
-LoopNest::LoopNest(Stmt* stmt, std::unordered_set<Buf*> output_bufs)
+LoopNest::LoopNest(StmtPtr stmt, std::unordered_set<BufPtr> output_bufs)
: root_stmt_(stmt), output_bufs_(std::move(output_bufs)) {
verify(root_stmt_);
}
verify(root_stmt_);
}
-const std::unordered_set<Buf*> LoopNest::getIntermediateBufs() const {
- std::unordered_set<Buf*> result;
+const std::unordered_set<BufPtr> LoopNest::getIntermediateBufs() const {
+ std::unordered_set<BufPtr> result;
auto input_bufs = getInputBufs();
auto bufs = NodeFinder<Buf>::find(root_stmt_);
- for (auto* buf : bufs) {
+ for (auto buf : bufs) {
if (!output_bufs_.count(buf) && !input_bufs.count(buf)) {
result.insert(buf);
}
return result;
}
-const std::unordered_set<Buf*> LoopNest::getInputBufs() const {
- std::unordered_set<Buf*> result;
+const std::unordered_set<BufPtr> LoopNest::getInputBufs() const {
+ std::unordered_set<BufPtr> result;
auto buf_load_store_uses = findLoadOrStoreUses(root_stmt_);
for (auto& kv : buf_load_store_uses) {
bool has_store = false;
class IndexFlattener : public IRMutator {
public:
- Stmt* flatten(Stmt* s) {
+ StmtPtr flatten(StmtPtr s) {
return s->accept_mutator(this);
}
- Expr* mutate(Load* v) override {
+ ExprPtr mutate(LoadPtr v) override {
if (v->indices().size() == 1) {
return v;
}
- return new Load(
- v->dtype(), v->buf(), {flatten_index(v->buf()->dims(), v->indices())});
+ return alloc<Load>(
+ v->dtype(),
+ v->buf(),
+ std::vector<ExprPtr>({flatten_index(v->buf()->dims(), v->indices())}));
}
- Stmt* mutate(Store* v) override {
- Expr* value = v->value();
- Expr* new_value = value->accept_mutator(this);
+ StmtPtr mutate(StorePtr v) override {
+ ExprPtr value = v->value();
+ ExprPtr new_value = value->accept_mutator(this);
if (v->indices().size() == 1 && value == new_value) {
- return (Stmt*)v;
+ return (StmtPtr)v;
}
- return new Store(
- v->buf(), {flatten_index(v->buf()->dims(), v->indices())}, new_value);
+ return alloc<Store>(
+ v->buf(),
+ std::vector<ExprPtr>({flatten_index(v->buf()->dims(), v->indices())}),
+ new_value);
}
};
class Vectorizer : public IRMutator {
public:
- Stmt* vectorize(For* v) {
- Stmt* body = v->body();
- Var* var = v->var();
- Expr* start = v->start();
- Expr* stop = v->stop();
-
- IntImm* start_imm = dynamic_cast<IntImm*>(start);
- IntImm* stop_imm = dynamic_cast<IntImm*>(stop);
+ StmtPtr vectorize(ForPtr v) {
+ StmtPtr body = v->body();
+ VarPtr var = v->var();
+ ExprPtr start = v->start();
+ ExprPtr stop = v->stop();
+
+ IntImmPtr start_imm = to<IntImm>(start);
+ IntImmPtr stop_imm = to<IntImm>(stop);
if (!start_imm) {
throw std::runtime_error(
"Can't vectorize due to non-constant loop start!");
start_ = start_imm;
lanes_ = stop_imm->value();
- Stmt* new_body = body->accept_mutator(this);
+ StmtPtr new_body = body->accept_mutator(this);
if (new_body == body) {
throw std::runtime_error("Vectorization failed!");
}
return new_body;
}
- Expr* mutate(Add* v) override {
- std::vector<Expr*> inputs = {v->lhs(), v->rhs()};
+ ExprPtr mutate(AddPtr v) override {
+ std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
return try_vectorize(v, inputs, [&]() {
return ExprHandle(inputs[0]) + ExprHandle(inputs[1]);
});
}
- Expr* mutate(Sub* v) override {
- std::vector<Expr*> inputs = {v->lhs(), v->rhs()};
+ ExprPtr mutate(SubPtr v) override {
+ std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
return try_vectorize(v, inputs, [&]() {
return ExprHandle(inputs[0]) - ExprHandle(inputs[1]);
});
}
- Expr* mutate(Mul* v) override {
- std::vector<Expr*> inputs = {v->lhs(), v->rhs()};
+ ExprPtr mutate(MulPtr v) override {
+ std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
return try_vectorize(v, inputs, [&]() {
return ExprHandle(inputs[0]) * ExprHandle(inputs[1]);
});
}
- Expr* mutate(Div* v) override {
- std::vector<Expr*> inputs = {v->lhs(), v->rhs()};
+ ExprPtr mutate(DivPtr v) override {
+ std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
return try_vectorize(v, inputs, [&]() {
return ExprHandle(inputs[0]) / ExprHandle(inputs[1]);
});
}
- Expr* mutate(And* v) override {
- std::vector<Expr*> inputs = {v->lhs(), v->rhs()};
+ ExprPtr mutate(AndPtr v) override {
+ std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
return try_vectorize(v, inputs, [&]() {
return ExprHandle(inputs[0]) & ExprHandle(inputs[1]);
});
}
- Expr* mutate(Or* v) override {
- std::vector<Expr*> inputs = {v->lhs(), v->rhs()};
+ ExprPtr mutate(OrPtr v) override {
+ std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
return try_vectorize(v, inputs, [&]() {
return ExprHandle(inputs[0]) | ExprHandle(inputs[1]);
});
}
- Expr* mutate(Xor* v) override {
- std::vector<Expr*> inputs = {v->lhs(), v->rhs()};
+ ExprPtr mutate(XorPtr v) override {
+ std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
return try_vectorize(v, inputs, [&]() {
return ExprHandle(inputs[0]) ^ ExprHandle(inputs[1]);
});
}
- Expr* mutate(Lshift* v) override {
- std::vector<Expr*> inputs = {v->lhs(), v->rhs()};
+ ExprPtr mutate(LshiftPtr v) override {
+ std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
return try_vectorize(v, inputs, [&]() {
return ExprHandle(inputs[0]) << ExprHandle(inputs[1]);
});
}
- Expr* mutate(Rshift* v) override {
- std::vector<Expr*> inputs = {v->lhs(), v->rhs()};
+ ExprPtr mutate(RshiftPtr v) override {
+ std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
return try_vectorize(v, inputs, [&]() {
return ExprHandle(inputs[0]) >> ExprHandle(inputs[1]);
});
}
- Expr* mutate(Max* v) override {
- std::vector<Expr*> inputs = {v->lhs(), v->rhs()};
+ ExprPtr mutate(MaxPtr v) override {
+ std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
return try_vectorize(v, inputs, [&]() {
return Max::make(
ExprHandle(inputs[0]), ExprHandle(inputs[1]), v->propagate_nans());
});
}
- Expr* mutate(Min* v) override {
- std::vector<Expr*> inputs = {v->lhs(), v->rhs()};
+ ExprPtr mutate(MinPtr v) override {
+ std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
return try_vectorize(v, inputs, [&]() {
return Min::make(
ExprHandle(inputs[0]), ExprHandle(inputs[1]), v->propagate_nans());
});
}
- Expr* mutate(CompareSelect* v) override {
- std::vector<Expr*> inputs = {
+ ExprPtr mutate(CompareSelectPtr v) override {
+ std::vector<ExprPtr> inputs = {
v->lhs(), v->rhs(), v->ret_val1(), v->ret_val2()};
return try_vectorize(v, inputs, [&]() {
return CompareSelect::make(
});
}
- Expr* mutate(BitCast* v) override {
- std::vector<Expr*> inputs = {v->src_value()};
+ ExprPtr mutate(BitCastPtr v) override {
+ std::vector<ExprPtr> inputs = {v->src_value()};
return try_vectorize(v, inputs, [&]() {
return BitCast::make(
Dtype(v->dtype().scalar_type(), lanes_), ExprHandle(inputs[0]));
});
}
- Expr* mutate(Cast* v) override {
- std::vector<Expr*> inputs = {v->src_value()};
+ ExprPtr mutate(CastPtr v) override {
+ std::vector<ExprPtr> inputs = {v->src_value()};
return try_vectorize(v, inputs, [&]() {
return Cast::make(
Dtype(v->dtype().scalar_type(), lanes_), ExprHandle(inputs[0]));
});
}
- Expr* mutate(Var* v) override {
+ ExprPtr mutate(VarPtr v) override {
if (v == var_) {
return Ramp::make(ExprHandle(start_), 1, lanes_).node();
}
return v;
}
- Expr* mutate(Ramp* v) override {
- Expr* base = v->base();
- Expr* stride = v->stride();
+ ExprPtr mutate(RampPtr v) override {
+ ExprPtr base = v->base();
+ ExprPtr stride = v->stride();
- Expr* base_new = base->accept_mutator(this);
- Expr* stride_new = stride->accept_mutator(this);
+ ExprPtr base_new = base->accept_mutator(this);
+ ExprPtr stride_new = stride->accept_mutator(this);
if (base_new == base && stride_new == stride) {
return v;
throw std::runtime_error("Can't vectorize a Ramp!");
}
- Expr* mutate(Load* v) override {
+ ExprPtr mutate(LoadPtr v) override {
Dtype dtype(v->dtype().scalar_type(), lanes_);
- Buf* buf = v->buf();
- std::vector<Expr*> inputs = {v->flat_index()};
+ BufPtr buf = v->buf();
+ std::vector<ExprPtr> inputs = {v->flat_index()};
return try_vectorize(v, inputs, [&]() {
return Load::make(dtype, BufHandle(buf), {ExprHandle(inputs[0])});
});
}
- Expr* mutate(ReduceOp* v) override {
+ ExprPtr mutate(ReduceOpPtr v) override {
Dtype dtype(v->dtype().scalar_type(), lanes_);
- std::vector<Expr*> inputs = {v->body()};
+ std::vector<ExprPtr> inputs = {v->body()};
- auto* out = try_vectorize(v, inputs, [&]() {
+ auto out = try_vectorize(v, inputs, [&]() {
return ExprHandle(
- new ReduceOp(inputs[0], v->reduce_args(), v->reducer()));
+ alloc<ReduceOp>(inputs[0], v->reduce_args(), v->reducer()));
});
return out;
}
- Expr* mutate(Broadcast* v) override {
- Expr* val = v->value();
- Expr* new_val = val->accept_mutator(this);
+ ExprPtr mutate(BroadcastPtr v) override {
+ ExprPtr val = v->value();
+ ExprPtr new_val = val->accept_mutator(this);
if (new_val == val) {
return v;
}
throw std::runtime_error("Can't vectorize a Broadcast!");
}
- Expr* mutate(IfThenElse* v) override {
- Expr* condition = v->condition();
- Expr* new_condition = condition->accept_mutator(this);
+ ExprPtr mutate(IfThenElsePtr v) override {
+ ExprPtr condition = v->condition();
+ ExprPtr new_condition = condition->accept_mutator(this);
if (new_condition != condition) {
throw std::runtime_error("Can't vectorize an IfThenElse condition!");
}
- std::vector<Expr*> inputs = {v->true_value(), v->false_value()};
+ std::vector<ExprPtr> inputs = {v->true_value(), v->false_value()};
return try_vectorize(v, inputs, [&]() {
return IfThenElse::make(
ExprHandle(condition), ExprHandle(inputs[0]), ExprHandle(inputs[1]));
});
}
- Expr* mutate(Intrinsics* v) override {
- std::vector<Expr*> inputs = v->params();
+ ExprPtr mutate(IntrinsicsPtr v) override {
+ std::vector<ExprPtr> inputs = v->params();
return try_vectorize(v, inputs, [&]() {
- return ExprHandle(new Intrinsics(v->op_type(), inputs));
+ return ExprHandle(alloc<Intrinsics>(v->op_type(), inputs));
});
}
- Stmt* mutate(Store* v) override {
- Buf* buf = v->buf();
- std::vector<Expr*> inputs = {v->flat_index(), v->value()};
+ StmtPtr mutate(StorePtr v) override {
+ BufPtr buf = v->buf();
+ std::vector<ExprPtr> inputs = {v->flat_index(), v->value()};
return try_vectorize(v, inputs, [&]() {
return Store::make(
BufHandle(buf), {ExprHandle(inputs[0])}, ExprHandle(inputs[1]));
});
}
- Stmt* mutate(For* v) override {
- Var* var = v->var();
- Expr* start = v->start();
- Expr* stop = v->stop();
+ StmtPtr mutate(ForPtr v) override {
+ VarPtr var = v->var();
+ ExprPtr start = v->start();
+ ExprPtr stop = v->stop();
LoopOptions loop_options = v->loop_options();
- Expr* new_start = start->accept_mutator(this);
- Expr* new_stop = stop->accept_mutator(this);
+ ExprPtr new_start = start->accept_mutator(this);
+ ExprPtr new_stop = stop->accept_mutator(this);
if (new_start != start || new_stop != stop) {
throw std::runtime_error(
"Can't vectorize nested For with dependent loop bounds!");
}
- Stmt* body = v->body();
- Stmt* new_body = body->accept_mutator(this);
+ StmtPtr body = v->body();
+ StmtPtr new_body = body->accept_mutator(this);
if (new_body == body) {
- return (For*)v;
+ return (ForPtr)v;
}
- return new For(var, new_start, new_stop, new_body, loop_options);
+ return alloc<For>(var, new_start, new_stop, new_body, loop_options);
}
- Stmt* mutate(Block* v) override {
+ StmtPtr mutate(BlockPtr v) override {
// IRMutator does in-place mutations. But the logic in vectorization checks
// for success by looking for a new stmt. So, we override the in-place
// mutations and create a clone here if any of its statements change.
// TODO: Can we change the logic of vectorizer so that we don't need this?
bool any_change = false;
- std::vector<Stmt*> stmts;
- for (Stmt* stmt : *v) {
- Stmt* stmt_new = stmt->accept_mutator(this);
+ std::vector<StmtPtr> stmts;
+ for (StmtPtr stmt : *v) {
+ StmtPtr stmt_new = stmt->accept_mutator(this);
if (stmt != stmt_new) {
any_change = true;
} else {
}
}
if (any_change) {
- return new Block(stmts);
+ return alloc<Block>(stmts);
}
return v;
}
template <typename T>
- Expr* try_vectorize(Expr* e, std::vector<Expr*>& inputs, T&& vec_ctor) {
+ ExprPtr try_vectorize(ExprPtr e, std::vector<ExprPtr>& inputs, T&& vec_ctor) {
bool vectorize = vectorize_inputs(inputs);
if (vectorize) {
return vec_ctor().node();
}
template <typename T>
- Stmt* try_vectorize(Stmt* s, std::vector<Expr*>& inputs, T&& vec_ctor) {
+ StmtPtr try_vectorize(StmtPtr s, std::vector<ExprPtr>& inputs, T&& vec_ctor) {
bool vectorize = vectorize_inputs(inputs);
if (vectorize) {
return vec_ctor();
}
- return (Stmt*)s;
+ return (StmtPtr)s;
}
- bool vectorize_inputs(std::vector<Expr*>& inputs) {
+ bool vectorize_inputs(std::vector<ExprPtr>& inputs) {
bool any_vectorized = false;
- std::vector<Expr*> new_inputs;
+ std::vector<ExprPtr> new_inputs;
// Attempt to vectorize each input.
- for (Expr*& in : inputs) {
- Expr* new_in = in->accept_mutator(this);
+ for (ExprPtr& in : inputs) {
+ ExprPtr new_in = in->accept_mutator(this);
new_inputs.push_back(new_in);
if (new_in != in) {
any_vectorized = true;
return true;
}
- Var* var_ = nullptr;
+ VarPtr var_ = nullptr;
int lanes_ = 0;
- Expr* start_ = nullptr;
+ ExprPtr start_ = nullptr;
};
-bool LoopNest::vectorize(For* f) {
- Block* b = dynamic_cast<Block*>(f->get_parent());
+bool LoopNest::vectorize(ForPtr f) {
+ BlockPtr b = to<Block>(f->get_parent());
if (!b) {
return false;
}
// Can't vectorize reduction axes.
auto reductions = NodeFinder<ReduceOp>::find(f);
- for (auto* r : reductions) {
+ for (auto r : reductions) {
if (std::find(r->reduce_args().begin(), r->reduce_args().end(), f->var()) !=
r->reduce_args().end()) {
return false;
}
Vectorizer v;
- Stmt* new_f = nullptr;
+ StmtPtr new_f = nullptr;
try {
new_f = Stmt::clone(f);
- normalize(dynamic_cast<For*>(new_f));
+ normalize(to<For>(new_f));
new_f = FlattenIndexes(new_f);
- new_f = v.vectorize(dynamic_cast<For*>(new_f));
+ new_f = v.vectorize(to<For>(new_f));
} catch (std::runtime_error& e) {
// We clone f before vectorizing. So, any partial vectorization will
// have modified the clone. In case of an exception, we can continue
output_bufs_.insert(t->buf());
}
- std::vector<Stmt*> loops;
+ std::vector<StmtPtr> loops;
for (Tensor* t : tensors_to_compute) {
- Stmt* loop = t->stmt();
+ StmtPtr loop = t->stmt();
if (loop->get_parent()) {
std::cerr << "Error: creating a loopnest from already used Tensors\n";
loops = {};
break;
}
// Flatten initializers.
- if (Block* block = dynamic_cast<Block*>(loop)) {
- for (auto* s : block->stmts()) {
+ if (BlockPtr block = to<Block>(loop)) {
+ for (auto s : block->stmts()) {
block->remove_stmt(s);
loops.push_back(s);
}
}
}
- root_stmt_ = new Block(loops);
+ root_stmt_ = alloc<Block>(loops);
}
class FunctionInliner : public IRMutator {
public:
- FunctionInliner(Store* producer, std::unordered_set<Buf*> outputs)
+ FunctionInliner(StorePtr producer, std::unordered_set<BufPtr> outputs)
: buf_(producer->buf()),
producer_(producer),
outputs_(std::move(outputs)) {
- for (auto* i : producer->indices()) {
- if (auto index_var = dynamic_cast<Var*>(i)) {
+ for (auto i : producer->indices()) {
+ if (auto index_var = to<Var>(i)) {
index_vars_.insert(index_var);
producer_index_vars_.push_back(index_var);
- } else if (dynamic_cast<IntImm*>(i) != nullptr) {
+ } else if (to<IntImm>(i) != nullptr) {
// If the index can be a constant, then that dimension must have size 1
// (since we don't support in-place writes). Resolves issue 52581.
TORCH_INTERNAL_ASSERT(
- dynamic_cast<IntImm*>(i)->value() == 0,
+ to<IntImm>(i)->value() == 0,
"Constant index impression should always be zero");
producer_index_vars_.push_back(nullptr);
} else {
}
private:
- Expr* mutate_loads(Buf* buf, std::vector<Expr*> dims) {
- std::vector<Var*> index_vars;
+ ExprPtr mutate_loads(BufPtr buf, std::vector<ExprPtr> dims) {
+ std::vector<VarPtr> index_vars;
TORCH_INTERNAL_ASSERT(buf->ndim() == producer_index_vars_.size());
for (const auto i : c10::irange(buf->ndim())) {
- Var* func_callee_arg = producer_index_vars_.at(i);
- Expr* func_caller_param = dims.at(i);
+ VarPtr func_callee_arg = producer_index_vars_.at(i);
+ ExprPtr func_caller_param = dims.at(i);
if (func_callee_arg == nullptr) {
TORCH_INTERNAL_ASSERT(
- dynamic_cast<IntImm*>(func_caller_param) != nullptr &&
- dynamic_cast<IntImm*>(func_caller_param)->value() == 0,
+ to<IntImm>(func_caller_param) != nullptr &&
+ to<IntImm>(func_caller_param)->value() == 0,
"We are implicitly assuming that if you have an index of 0, that must also be inlined into an index of 0");
continue;
}
}
// Call the actual replacement.
- Expr* body = producer_->value();
- Expr* result = Expr::clone(body)->accept_mutator(this);
+ ExprPtr body = producer_->value();
+ ExprPtr result = Expr::clone(body)->accept_mutator(this);
// Remove the mappings we created for this function parameters.
- for (auto* v : index_vars) {
+ for (auto v : index_vars) {
for (auto& pair : random_bindings_) {
if (pair.second.erase(v)) {
- Expr* inlined = inline_mapping_[v];
- for (auto* nv : VarFinder::find(inlined)) {
+ ExprPtr inlined = inline_mapping_[v];
+ for (auto nv : VarFinder::find(inlined)) {
pair.second.insert(nv);
}
}
return result;
}
- Expr* mutate(Load* v) override {
- Buf* buf = v->buf();
+ ExprPtr mutate(LoadPtr v) override {
+ BufPtr buf = v->buf();
if (buf != buf_) {
return IRMutator::mutate(v);
}
}
// Replace the target variable with the caller expressions.
- Expr* mutate(Var* v) override {
+ ExprPtr mutate(VarPtr v) override {
auto iter = inline_mapping_.find(v);
if (iter == inline_mapping_.end()) {
return v;
} else {
- Expr* expr = iter->second;
+ ExprPtr expr = iter->second;
// Continue to transform the value from the lookup table.
return expr->accept_mutator(this);
}
}
// Handle random intrinsics which should be cached.
- Expr* mutate(Intrinsics* v) override {
+ ExprPtr mutate(IntrinsicsPtr v) override {
if (!in_producer_ || v->op_type() != kRand) {
return IRMutator::mutate(v);
}
- // Create a new Let Statment for the random variable, which we can refer to
- // multiple times and resolve the same value (ie. store it in a scalar
+ // Create a new Let Statement for the random variable, which we can refer
+ // to multiple times and resolve the same value (ie. store it in a scalar
// rather than the Tensor).
const std::string& name = buf_->name_hint();
- Var* new_var = new Var(name, v->dtype());
- random_bindings_[new Let(new_var, v)] = index_vars_;
+ VarPtr new_var = alloc<Var>(name, v->dtype());
+ random_bindings_[alloc<Let>(new_var, v)] = index_vars_;
return new_var;
}
// Remove the buffer write from the inlined function.
- Stmt* mutate(Store* v) override {
+ StmtPtr mutate(StorePtr v) override {
// If the buf_ is in the outputs set, keep its statement intact. Otherwise,
// remove it.
if (v == producer_ && !outputs_.count(buf_)) {
in_producer_ = true;
- producer_ = dynamic_cast<Store*>(IRMutator::mutate(v));
+ producer_ = to<Store>(IRMutator::mutate(v));
TORCH_INTERNAL_ASSERT(producer_ != nullptr);
in_producer_ = false;
return nullptr;
}
// Any Random Instrinsics that were turned into vars must be inserted here.
- Stmt* mutate(Block* v) override {
- std::vector<Stmt*> stmts;
- for (Stmt* stmt : *v) {
- Stmt* stmt_new = stmt->accept_mutator(this);
+ StmtPtr mutate(BlockPtr v) override {
+ std::vector<StmtPtr> stmts;
+ for (StmtPtr stmt : *v) {
+ StmtPtr stmt_new = stmt->accept_mutator(this);
if (!stmt_new) {
continue;
}
return Block::make(stmts);
}
- Stmt* mutate(For* v) override {
- For* res = dynamic_cast<For*>(IRMutator::mutate(v));
+ StmtPtr mutate(ForPtr v) override {
+ ForPtr res = to<For>(IRMutator::mutate(v));
if (!res) {
return nullptr;
}
// Find any random bindings that should be defined in this loops body.
- std::vector<Let*> bindings_this_loop;
- Var* fv = v->var();
+ std::vector<LetPtr> bindings_this_loop;
+ VarPtr fv = v->var();
for (auto& pair : random_bindings_) {
auto& index_var = pair.second;
if (index_var.erase(fv)) {
}
}
- for (auto* l : bindings_this_loop) {
+ for (auto l : bindings_this_loop) {
res->body()->prepend_stmt(l);
random_bindings_.erase(l);
}
}
private:
- Buf* buf_;
- Store* producer_;
+ BufPtr buf_;
+ StorePtr producer_;
// Index Vars present in the producer.
- std::unordered_set<Var*> index_vars_;
- std::vector<Var*> producer_index_vars_;
+ std::unordered_set<VarPtr> index_vars_;
+ std::vector<VarPtr> producer_index_vars_;
- std::unordered_map<Var*, Expr*> inline_mapping_;
+ std::unordered_map<VarPtr, ExprPtr> inline_mapping_;
// In the producer's scope - we need to bind any calls to rand().
bool in_producer_ = false;
- std::unordered_map<Let*, std::unordered_set<Var*>> random_bindings_;
- std::unordered_set<Buf*> outputs_;
+ std::unordered_map<LetPtr, std::unordered_set<VarPtr>> random_bindings_;
+ std::unordered_set<BufPtr> outputs_;
};
-bool LoopNest::computeInline(Stmt* s) {
- auto* s_store = dynamic_cast<Store*>(s);
+bool LoopNest::computeInline(StmtPtr s) {
+ auto s_store = to<Store>(s);
if (s_store == nullptr) {
throw std::logic_error("Could not find buffer producer to inline");
}
return computeInline(s_store->buf());
}
-bool LoopNest::computeInline(Buf* b) {
+bool LoopNest::computeInline(BufPtr b) {
// If buf is used or defined in an ExternalCall, we cannot inline it
auto buf_load_store_uses = findLoadOrStoreUses(root_stmt_);
for (auto& use : buf_load_store_uses.at(b)) {
- Stmt* s = use.s;
- if (dynamic_cast<ExternalCall*>(s)) {
+ StmtPtr s = use.s;
+ if (to<ExternalCall>(s)) {
return false;
}
}
// Find producers.
- Store* relevant_store{nullptr};
+ StorePtr relevant_store{nullptr};
auto stores = NodeFinder<Store>::find(root_stmt_);
- for (auto* s : stores) {
+ for (auto s : stores) {
if (s->buf() == b) {
auto reductions = NodeFinder<ReduceOp>::find(s);
if (!reductions.empty()) {
// difficult synchronization logic across blocks. Inlining trivial reads does
// not duplicate work
void LoopNest::inlineIntermediateBufs(bool allow_duplicated_work) {
- std::unordered_set<Buf*> bufs_to_inline;
+ std::unordered_set<BufPtr> bufs_to_inline;
auto intermediate_bufs = getIntermediateBufs();
if (allow_duplicated_work) {
// tensors, always inline, bc we are not duplicating any work
// and avoiding an intermediary buffer
if (stores.size() == 1) {
- if (auto store = dynamic_cast<Store*>(stores[0].s)) {
- auto input_as_load = dynamic_cast<Load*>(store->value());
+ if (auto store = to<Store>(stores[0].s)) {
+ auto input_as_load = to<Load>(store->value());
if (input_as_load && input_bufs.count(input_as_load->buf())) {
bufs_to_inline.insert(buf);
continue;
}
} else {
// If S is not a store, it must be an ExternalCall.
- TORCH_INTERNAL_ASSERT(dynamic_cast<ExternalCall*>(stores[0].s));
+ TORCH_INTERNAL_ASSERT(to<ExternalCall>(stores[0].s));
}
}
// TODO: Unify with DepTracker
class LoadOrStoreUseFinder : public IRVisitor {
public:
- std::unordered_map<Buf*, std::vector<BufLoadOrStoreUse>> findUses(Stmt* s) {
+ std::unordered_map<BufPtr, std::vector<BufLoadOrStoreUse>> findUses(
+ StmtPtr s) {
uses_.clear();
s->accept(this);
return uses_;
}
private:
- void visit(Store* v) override {
+ void visit(StorePtr v) override {
if (stores_[v->buf()].insert(last_stmt_).second) {
- uses_[v->buf()].push_back({(Stmt*)v, true});
+ uses_[v->buf()].push_back({(StmtPtr)v, true});
}
- last_stmt_ = (Stmt*)v;
+ last_stmt_ = (StmtPtr)v;
IRVisitor::visit(v);
}
- void visit(ExternalCall* v) override {
+ void visit(ExternalCallPtr v) override {
if (stores_[v->buf()].insert(last_stmt_).second) {
- uses_[v->buf()].push_back({(Stmt*)v, true});
+ uses_[v->buf()].push_back({(StmtPtr)v, true});
}
- last_stmt_ = (Stmt*)v;
+ last_stmt_ = (StmtPtr)v;
- for (Buf* input_buf : v->buf_args()) {
+ for (BufPtr input_buf : v->buf_args()) {
if (loads_[input_buf].insert(last_stmt_).second) {
uses_[input_buf].push_back({last_stmt_, false});
}
IRVisitor::visit(v);
}
- void visit(Load* v) override {
+ void visit(LoadPtr v) override {
if (loads_[v->buf()].insert(last_stmt_).second) {
uses_[v->buf()].push_back({last_stmt_, false});
}
IRVisitor::visit(v);
}
- Stmt* last_stmt_ = nullptr;
- std::unordered_map<Buf*, std::vector<BufLoadOrStoreUse>> uses_;
+ StmtPtr last_stmt_ = nullptr;
+ std::unordered_map<BufPtr, std::vector<BufLoadOrStoreUse>> uses_;
// Sets of loads and stores in order to keep the results unique
- std::unordered_map<Buf*, std::unordered_set<Stmt*>> loads_;
- std::unordered_map<Buf*, std::unordered_set<Stmt*>> stores_;
+ std::unordered_map<BufPtr, std::unordered_set<StmtPtr>> loads_;
+ std::unordered_map<BufPtr, std::unordered_set<StmtPtr>> stores_;
};
-std::unordered_map<Buf*, std::vector<BufLoadOrStoreUse>> findLoadOrStoreUses(
- Stmt* s) {
+std::unordered_map<BufPtr, std::vector<BufLoadOrStoreUse>> findLoadOrStoreUses(
+ StmtPtr s) {
LoadOrStoreUseFinder uf;
return uf.findUses(s);
}
class ContainedStmtsFinder : public IRVisitor {
public:
// Simply list all Stores and Block that are children of the given stmt
- const std::unordered_set<Stmt*>& findContainedStmts(Stmt* s) {
+ const std::unordered_set<StmtPtr>& findContainedStmts(StmtPtr s) {
contained_.clear();
s->accept(this);
return contained_;
}
private:
- void visit(Store* v) override {
- contained_.insert((Stmt*)v);
+ void visit(StorePtr v) override {
+ contained_.insert((StmtPtr)v);
IRVisitor::visit(v);
}
- void visit(ExternalCall* v) override {
- contained_.insert((Stmt*)v);
+ void visit(ExternalCallPtr v) override {
+ contained_.insert((StmtPtr)v);
IRVisitor::visit(v);
}
- void visit(Block* v) override {
- contained_.insert((Stmt*)v);
+ void visit(BlockPtr v) override {
+ contained_.insert((StmtPtr)v);
IRVisitor::visit(v);
}
- std::unordered_set<Stmt*> contained_;
+ std::unordered_set<StmtPtr> contained_;
};
-bool containsAll(const std::vector<BufLoadOrStoreUse>& uses, Block* b) {
- std::unordered_set<Stmt*> not_found;
+bool containsAll(const std::vector<BufLoadOrStoreUse>& uses, BlockPtr b) {
+ std::unordered_set<StmtPtr> not_found;
for (auto use : uses) {
not_found.insert(use.s);
}
ContainedStmtsFinder csf;
- const std::unordered_set<Stmt*>& contained = csf.findContainedStmts(b);
+ const std::unordered_set<StmtPtr>& contained = csf.findContainedStmts(b);
for (auto s : contained) {
not_found.erase(s);
}
return not_found.empty();
}
-Block* findParentBlock(Stmt* s) {
+BlockPtr findParentBlock(StmtPtr s) {
while (s) {
- if (auto b = dynamic_cast<Block*>(s)) {
+ if (auto b = to<Block>(s)) {
return b;
}
s = s->get_parent();
return nullptr;
}
-Block* findLowestContainingBlock(const std::vector<BufLoadOrStoreUse>& uses) {
+BlockPtr findLowestContainingBlock(const std::vector<BufLoadOrStoreUse>& uses) {
// TODO: we're not using the most efficient algorithm here for simplicity.
// Replace with something more performant in case it becomes a bottleneck.
- Block* b = findParentBlock(uses[0].s);
+ BlockPtr b = findParentBlock(uses[0].s);
while (b && !containsAll(uses, b)) {
b = findParentBlock(b->get_parent());
}
return b;
}
-Stmt* LoopNest::insertAllocFree(Stmt* stmt) {
+StmtPtr LoopNest::insertAllocFree(StmtPtr stmt) {
auto intermediate_bufs = getIntermediateBufs();
if (intermediate_bufs.size() == 0ULL) {
return stmt;
}
- Block* b = dynamic_cast<Block*>(stmt);
+ BlockPtr b = to<Block>(stmt);
if (!b) {
- b = new Block({stmt});
+ b = alloc<Block>(std::vector<StmtPtr>({stmt}));
}
- std::unordered_map<Buf*, std::vector<BufLoadOrStoreUse>> uses =
+ std::unordered_map<BufPtr, std::vector<BufLoadOrStoreUse>> uses =
findLoadOrStoreUses(stmt);
// Insert allocations and frees for temporary buffers at global scope.
- for (Buf* buf : intermediate_bufs) {
- b->prepend_stmt(new Allocate(buf));
- b->append_stmt(new Free(buf));
+ for (BufPtr buf : intermediate_bufs) {
+ b->prepend_stmt(alloc<Allocate>(buf));
+ b->append_stmt(alloc<Free>(buf));
}
return b;
class StmtDeleter : public IRMutator {
public:
- StmtDeleter(const std::unordered_set<Stmt*>& targets) : targets_(targets) {}
+ StmtDeleter(const std::unordered_set<StmtPtr>& targets) : targets_(targets) {}
private:
- Stmt* mutate(Block* v) override {
- std::vector<Stmt*> stmts;
+ StmtPtr mutate(BlockPtr v) override {
+ std::vector<StmtPtr> stmts;
- for (auto* s : v->stmts()) {
+ for (auto s : v->stmts()) {
if (targets_.count(s) == 0) {
- Stmt* ns = s->accept_mutator(this);
+ StmtPtr ns = s->accept_mutator(this);
if (ns) {
stmts.push_back(Stmt::clone(ns));
}
return Block::make(stmts);
}
- const std::unordered_set<Stmt*>& targets_;
+ const std::unordered_set<StmtPtr>& targets_;
};
void LoopNest::eliminateDeadStores() {
MemDependencyChecker checker(getInputBufs(), getOutputBufs());
root_stmt_->accept(&checker);
- std::unordered_set<Stmt*> deadStores;
+ std::unordered_set<StmtPtr> deadStores;
std::vector<std::shared_ptr<AccessInfo>> outputAccesses;
- for (auto* o : getOutputBufs()) {
+ for (auto o : getOutputBufs()) {
outputAccesses.push_back(checker.output(o));
}
// the rest of the IR nodes (the ones not touched directly) to be cloned.
class IfThenElseReplacer : public IRCloner {
public:
- IfThenElseReplacer(IfThenElse* to_replace, Expr* new_expr)
+ IfThenElseReplacer(IfThenElsePtr to_replace, ExprPtr new_expr)
: to_replace_(to_replace), new_expr_(new_expr) {}
- Expr* mutate(IfThenElse* i) override {
+ ExprPtr mutate(IfThenElsePtr i) override {
if (i == to_replace_) {
return new_expr_;
}
}
private:
- IfThenElse* to_replace_;
- Expr* new_expr_;
+ IfThenElsePtr to_replace_;
+ ExprPtr new_expr_;
};
// Check if the given condition is optimizable.
// * sets `compared_value` to `expr`, and
// * returns true.
bool isConditionOptimizable(
- Expr* condition,
- Var** cond_var,
- Expr** compared_value) {
- auto cs = dynamic_cast<CompareSelect*>(condition);
+ ExprPtr condition,
+ VarPtr* cond_var,
+ ExprPtr* compared_value) {
+ auto cs = to<CompareSelect>(condition);
if (cs && cs->compare_select_op() == kLT) {
- auto var = dynamic_cast<Var*>(cs->lhs());
+ auto var = to<Var>(cs->lhs());
if (var) {
*cond_var = var;
*compared_value = cs->rhs();
// * sub_exprs to the list of sub-expressions that are the result of this
// if-then-else expression.
bool isConditionalFromCat(
- IfThenElse* ite,
- Var** cond_var,
- std::vector<Expr*>* comp_values,
- std::vector<Expr*>* sub_exprs) {
- Var* var = nullptr;
+ IfThenElsePtr ite,
+ VarPtr* cond_var,
+ std::vector<ExprPtr>* comp_values,
+ std::vector<ExprPtr>* sub_exprs) {
+ VarPtr var = nullptr;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- Expr* comp_value;
+ ExprPtr comp_value;
if (isConditionOptimizable(ite->condition(), &var, &comp_value)) {
if (*cond_var == nullptr) {
*cond_var = var;
// expressions. Can not optimize such cases.
return false;
}
- auto true_ite = dynamic_cast<IfThenElse*>(ite->true_value());
+ auto true_ite = to<IfThenElse>(ite->true_value());
if (true_ite) {
if (!isConditionalFromCat(true_ite, cond_var, comp_values, sub_exprs)) {
return false;
} else {
sub_exprs->push_back(ite->true_value());
}
- auto false_ite = dynamic_cast<IfThenElse*>(ite->false_value());
+ auto false_ite = to<IfThenElse>(ite->false_value());
if (false_ite) {
return false;
}
return false;
}
-bool areConstantsAndSorted(const std::vector<Expr*>& comp_values) {
+bool areConstantsAndSorted(const std::vector<ExprPtr>& comp_values) {
std::vector<int> comp_consts;
comp_consts.reserve(comp_values.size());
for (auto c : comp_values) {
// Consider every store in the root_stmt_ and try to optimize the
// conditionals in that store.
auto stores = NodeFinder<Store>::find(root_stmt_);
- std::unordered_set<For*> split_fors;
+ std::unordered_set<ForPtr> split_fors;
for (auto store : stores) {
- Var* cond_var = nullptr;
+ VarPtr cond_var = nullptr;
// `comp_values` represent the list of compared values that will be
// collected as we check for the expected pattern. Since that will
// only include the RHS of the conditions in the if-then-else expressions
// we need to start with `0` which is the initial bound, given that we
// only handle normalized loops (check for this is done below).
- std::vector<Expr*> comp_values = {new IntImm(0)};
- std::vector<Expr*> sub_exprs;
+ std::vector<ExprPtr> comp_values = {alloc<IntImm>(0)};
+ std::vector<ExprPtr> sub_exprs;
auto ifthenelse_exprs = NodeFinder<IfThenElse>::find(store);
if (ifthenelse_exprs.empty()) {
continue;
// Remove all the if-then-else expressions from this store and create
// one loop per sub-expression.
- std::vector<Stmt*> split_loops;
+ std::vector<StmtPtr> split_loops;
auto cond_to_replace = ifthenelse_exprs.front();
for (size_t i = 0; i < sub_exprs.size(); ++i) {
IfThenElseReplacer ifthenelseReplacer(cond_to_replace, sub_exprs[i]);
auto new_store = store->accept_mutator(&ifthenelseReplacer);
auto new_for_body =
for_to_split->body()->clone_and_replace(store, new_store);
- auto new_for = new For(
+ auto new_for = alloc<For>(
for_to_split->var(),
comp_values[i],
comp_values[i + 1],
LoopNest::normalize(new_for);
split_loops.push_back(new_for);
}
- auto par = dynamic_cast<Block*>(for_to_split->get_parent());
- par->replace_stmt(for_to_split, new Block(split_loops));
+ auto par = to<Block>(for_to_split->get_parent());
+ par->replace_stmt(for_to_split, alloc<Block>(split_loops));
}
root_stmt_ = IRSimplifier::simplify(root_stmt_);
return true;
}
void LoopNest::vectorizeInnerLoops() {
- std::vector<For*> innerLoops;
- std::vector<For*> worklist;
+ std::vector<ForPtr> innerLoops;
+ std::vector<ForPtr> worklist;
// Find outer-most For loops
- if (For* rootF = dynamic_cast<For*>(root_stmt_)) {
+ if (ForPtr rootF = to<For>(root_stmt_)) {
worklist.push_back(rootF);
- } else if (Block* body = dynamic_cast<Block*>(root_stmt_)) {
- std::vector<Block*> blocks = {body};
+ } else if (BlockPtr body = to<Block>(root_stmt_)) {
+ std::vector<BlockPtr> blocks = {body};
while (blocks.size()) {
- Block* b = blocks.back();
+ BlockPtr b = blocks.back();
blocks.pop_back();
- for (Stmt* s : *b) {
- if (For* f = dynamic_cast<For*>(s)) {
+ for (StmtPtr s : *b) {
+ if (ForPtr f = to<For>(s)) {
worklist.push_back(f);
- } else if (Block* b2 = dynamic_cast<Block*>(s)) {
+ } else if (BlockPtr b2 = to<Block>(s)) {
blocks.push_back(b2);
}
}
// Traverse the For loop nest find inner-most loops, which are
// vectorization candidates.
while (worklist.size()) {
- For* f = worklist.back();
+ ForPtr f = worklist.back();
worklist.pop_back();
bool containsSubLoops = false;
- if (Block* body = dynamic_cast<Block*>(f->body())) {
- for (Stmt* s2 : *body) {
- if (For* f2 = dynamic_cast<For*>(s2)) {
+ if (BlockPtr body = to<Block>(f->body())) {
+ for (StmtPtr s2 : *body) {
+ if (ForPtr f2 = to<For>(s2)) {
containsSubLoops = true;
worklist.push_back(f2);
}
}
// vectorize inner loops.
- for (For* loop : innerLoops) {
+ for (ForPtr loop : innerLoops) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* split1;
+ ForPtr split1;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* tail1;
+ ForPtr tail1;
static const int kBodyVectorWidth = 8;
splitWithTail(loop, kBodyVectorWidth, &split1, &tail1);
if (tail1) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* split2;
+ ForPtr split2;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* tail2;
+ ForPtr tail2;
static const int kTailVectorWidth = 4;
splitWithTail(tail1, kTailVectorWidth, &split2, &tail2);
vectorize(split2);
}
}
-void LoopNest::sliceHead(For* f, int factor, For** head, For** tail) {
- if (dynamic_cast<IntImm*>(f->start()) && dynamic_cast<IntImm*>(f->stop())) {
- int start_val = dynamic_cast<IntImm*>(f->start())->value();
- int stop_val = dynamic_cast<IntImm*>(f->stop())->value();
+void LoopNest::sliceHead(ForPtr f, int factor, ForPtr* head, ForPtr* tail) {
+ if (to<IntImm>(f->start()) && to<IntImm>(f->stop())) {
+ int start_val = to<IntImm>(f->start())->value();
+ int stop_val = to<IntImm>(f->stop())->value();
int size_val = stop_val - start_val;
if (factor >= size_val) {
*head = f;
throw malformed_input("sliceHead attempted on null loop", f);
}
- Block* p = dynamic_cast<Block*>(f->get_parent());
+ BlockPtr p = to<Block>(f->get_parent());
if (!p) {
throw malformed_input("sliceHead attempted on loop with no parent", p);
}
- Expr* head_end =
- new Min(new Add(f->start(), new IntImm(factor)), f->stop(), true);
- *head = new For(f->var(), f->start(), head_end, Stmt::clone(f->body()));
- *tail = new For(
+ ExprPtr head_end = alloc<Min>(
+ alloc<Add>(f->start(), alloc<IntImm>(factor)), f->stop(), true);
+ *head = alloc<For>(f->var(), f->start(), head_end, Stmt::clone(f->body()));
+ *tail = alloc<For>(
f->var(), head_end, f->stop(), Stmt::clone(f->body()), f->loop_options());
p->replace_stmt(f, *head);
LoopNest::normalize(*tail);
}
}
-void LoopNest::sliceHead(For* f, int factor) {
+void LoopNest::sliceHead(ForPtr f, int factor) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For *head, *tail;
+ ForPtr head, tail;
sliceHead(f, factor, &head, &tail);
}
-void LoopNest::sliceTail(For* f, int factor, For** head, For** tail) {
- if (dynamic_cast<IntImm*>(f->start()) && dynamic_cast<IntImm*>(f->stop())) {
- int start_val = dynamic_cast<IntImm*>(f->start())->value();
- int stop_val = dynamic_cast<IntImm*>(f->stop())->value();
+void LoopNest::sliceTail(ForPtr f, int factor, ForPtr* head, ForPtr* tail) {
+ if (to<IntImm>(f->start()) && to<IntImm>(f->stop())) {
+ int start_val = to<IntImm>(f->start())->value();
+ int stop_val = to<IntImm>(f->stop())->value();
int size_val = stop_val - start_val;
if (factor >= size_val) {
*head = nullptr;
throw malformed_input("sliceTail attempted on null loop", f);
}
- Block* p = dynamic_cast<Block*>(f->get_parent());
+ BlockPtr p = to<Block>(f->get_parent());
if (!p) {
throw malformed_input("sliceTail attempted on loop with no parent", p);
}
- Expr* tail_start =
- new Max(f->start(), new Sub(f->stop(), new IntImm(factor)), true);
- *head = new For(
+ ExprPtr tail_start = alloc<Max>(
+ f->start(), alloc<Sub>(f->stop(), alloc<IntImm>(factor)), true);
+ *head = alloc<For>(
f->var(),
f->start(),
tail_start,
Stmt::clone(f->body()),
f->loop_options());
- *tail = new For(f->var(), tail_start, f->stop(), Stmt::clone(f->body()));
+ *tail = alloc<For>(f->var(), tail_start, f->stop(), Stmt::clone(f->body()));
p->replace_stmt(f, *head);
p->insert_stmt_after(*tail, *head);
LoopNest::normalize(*head);
}
}
-void LoopNest::sliceTail(For* f, int factor) {
+void LoopNest::sliceTail(ForPtr f, int factor) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For *head, *tail;
+ ForPtr head, tail;
sliceTail(f, factor, &head, &tail);
}
-void LoopNest::splitWithTail(For* f, int factor) {
+void LoopNest::splitWithTail(ForPtr f, int factor) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For *inner, *tail;
+ ForPtr inner, tail;
splitWithTail(f, factor, &inner, &tail);
}
-void LoopNest::splitWithTail(For* f, int factor, For** inner, For** tail) {
+void LoopNest::splitWithTail(
+ ForPtr f,
+ int factor,
+ ForPtr* inner,
+ ForPtr* tail) {
if (!f) {
throw malformed_input("splitWithTail attempted on null loop", f);
}
- Block* p = dynamic_cast<Block*>(f->get_parent());
+ BlockPtr p = to<Block>(f->get_parent());
if (!p) {
throw malformed_input("splitWithTail attempted on loop with no parent", p);
}
bool tail_is_needed = true;
- if (dynamic_cast<IntImm*>(f->start()) && dynamic_cast<IntImm*>(f->stop())) {
- int start_val = dynamic_cast<IntImm*>(f->start())->value();
- int stop_val = dynamic_cast<IntImm*>(f->stop())->value();
+ if (to<IntImm>(f->start()) && to<IntImm>(f->stop())) {
+ int start_val = to<IntImm>(f->start())->value();
+ int stop_val = to<IntImm>(f->stop())->value();
int size_val = stop_val - start_val;
int tail_size = size_val % factor;
if (tail_size == 0) {
}
}
- IntImm* factor_expr = new IntImm(factor);
- Expr* size = new Sub(f->stop(), f->start());
- Expr* split_count = new Div(size, factor_expr);
- Expr* tail_size = new Mod(size, factor_expr);
+ IntImmPtr factor_expr = alloc<IntImm>(factor);
+ ExprPtr size = alloc<Sub>(f->stop(), f->start());
+ ExprPtr split_count = alloc<Div>(size, factor_expr);
+ ExprPtr tail_size = alloc<Mod>(size, factor_expr);
const std::string& loop_var_name = f->var()->name_hint();
Dtype loop_var_dtype = f->var()->dtype();
- Var* i_inner = new Var(loop_var_name + "_inner", loop_var_dtype);
- Var* i_outer = new Var(loop_var_name + "_outer", loop_var_dtype);
+ VarPtr i_inner = alloc<Var>(loop_var_name + "_inner", loop_var_dtype);
+ VarPtr i_outer = alloc<Var>(loop_var_name + "_outer", loop_var_dtype);
// x -> x.outer * inner.size + x.inner
- Expr* combined_index1 = new Add(new Mul(i_outer, factor_expr), i_inner);
+ ExprPtr combined_index1 =
+ alloc<Add>(alloc<Mul>(i_outer, factor_expr), i_inner);
if (tail_is_needed) {
- Var* i_tail = new Var(loop_var_name + "_tail", loop_var_dtype);
+ VarPtr i_tail = alloc<Var>(loop_var_name + "_tail", loop_var_dtype);
// x -> x.tail + outer.size * inner.size
- Expr* combined_index2 = new Add(i_tail, new Mul(split_count, factor_expr));
+ ExprPtr combined_index2 =
+ alloc<Add>(i_tail, alloc<Mul>(split_count, factor_expr));
- Stmt* body_tail =
+ StmtPtr body_tail =
SubstituteInClone(f->body(), {{f->var(), combined_index2}});
- *tail = new For(i_tail, new IntImm(0), tail_size, body_tail);
+ *tail = alloc<For>(i_tail, alloc<IntImm>(0), tail_size, body_tail);
p->insert_stmt_after(*tail, f);
} else {
*tail = nullptr;
}
- Stmt* body_inner = Substitute(f->removeBody(), {{f->var(), combined_index1}});
+ StmtPtr body_inner =
+ Substitute(f->removeBody(), {{f->var(), combined_index1}});
- *inner = new For(i_inner, new IntImm(0), factor_expr, body_inner);
+ *inner = alloc<For>(i_inner, alloc<IntImm>(0), factor_expr, body_inner);
// The input loop `f` will be the outer loop after split.
f->set_var(i_outer);
- f->set_start(new IntImm(0));
+ f->set_start(alloc<IntImm>(0));
f->set_stop(split_count);
f->set_body(*inner);
}
-void LoopNest::splitWithMask(For* f, int factor) {
+void LoopNest::splitWithMask(ForPtr f, int factor) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* inner;
+ ForPtr inner;
splitWithMask(f, factor, &inner);
}
-void LoopNest::splitWithMask(For* f, int factor, For** inner) {
- Block* p = dynamic_cast<Block*>(f->get_parent());
+void LoopNest::splitWithMask(ForPtr f, int factor, ForPtr* inner) {
+ BlockPtr p = to<Block>(f->get_parent());
if (!p) {
std::cerr << "Parent is not a Block!\n";
return;
}
bool tail_is_needed = true;
- Expr* start = IRSimplifier::simplify(f->start());
- Expr* stop = IRSimplifier::simplify(f->stop());
+ ExprPtr start = IRSimplifier::simplify(f->start());
+ ExprPtr stop = IRSimplifier::simplify(f->stop());
if (start->isConstant() && stop->isConstant()) {
int start_val = immediateAs<int>(start);
int stop_val = immediateAs<int>(stop);
}
}
- IntImm* factor_expr = new IntImm(factor);
- Expr* size = new Sub(f->stop(), f->start());
+ IntImmPtr factor_expr = alloc<IntImm>(factor);
+ ExprPtr size = alloc<Sub>(f->stop(), f->start());
// split_count = (size + factor - 1) / factor
- Expr* split_count =
- new Div(new Sub(new Add(size, factor_expr), new IntImm(1)), factor_expr);
+ ExprPtr split_count = alloc<Div>(
+ alloc<Sub>(alloc<Add>(size, factor_expr), alloc<IntImm>(1)), factor_expr);
const std::string& loop_var_name = f->var()->name_hint();
Dtype loop_var_dtype = f->var()->dtype();
- Var* i_inner = new Var(loop_var_name + "_inner", loop_var_dtype);
- Var* i_outer = new Var(loop_var_name + "_outer", loop_var_dtype);
+ VarPtr i_inner = alloc<Var>(loop_var_name + "_inner", loop_var_dtype);
+ VarPtr i_outer = alloc<Var>(loop_var_name + "_outer", loop_var_dtype);
// x -> x.outer * inner.size + x.inner
- Expr* combined_index = new Add(new Mul(i_outer, factor_expr), i_inner);
+ ExprPtr combined_index =
+ alloc<Add>(alloc<Mul>(i_outer, factor_expr), i_inner);
- Stmt* body_inner = f->removeBody();
+ StmtPtr body_inner = f->removeBody();
// TODO: is it ok that we're doing it eagerly? In the other implementation we
// are only materializing predicates at the last, lowering, step.
if (tail_is_needed) {
- IntImm* start = dynamic_cast<IntImm*>(f->start());
+ IntImmPtr start = to<IntImm>(f->start());
if (!start || start->value() != 0) {
throw unimplemented_lowering();
}
- Expr* predicate =
+ ExprPtr predicate =
CompareSelect::make(ExprHandle(f->var()), ExprHandle(f->stop()), kLT)
.node();
body_inner = Cond::make(ExprHandle(predicate), body_inner, nullptr);
}
body_inner = Substitute(body_inner, {{f->var(), combined_index}});
- *inner = new For(i_inner, new IntImm(0), factor_expr, body_inner);
+ *inner = alloc<For>(i_inner, alloc<IntImm>(0), factor_expr, body_inner);
// The input loop `f` will be the outer loop after split.
f->set_var(i_outer);
- f->set_start(new IntImm(0));
+ f->set_start(alloc<IntImm>(0));
f->set_stop(split_count);
f->set_body(*inner);
}
-std::vector<For*> LoopNest::distributeLoop(
- For* loop,
- const std::unordered_set<Stmt*>& pivots) {
+std::vector<ForPtr> LoopNest::distributeLoop(
+ ForPtr loop,
+ const std::unordered_set<StmtPtr>& pivots) {
TORCH_INTERNAL_ASSERT(loop);
auto root = loop->get_parent();
if (root == nullptr) {
throw malformed_input("Loop without parent: ", loop);
}
- auto root_block = dynamic_cast<Block*>(root);
+ auto root_block = to<Block>(root);
if (root_block == nullptr) {
throw malformed_input(
"Loop's parent must be a Block, instead found ", root);
}
// Extract bodies for all the loops after distribution.
- std::vector<Block*> new_loop_bodies;
- auto new_loop_body = new Block({});
+ std::vector<BlockPtr> new_loop_bodies;
+ auto new_loop_body = alloc<Block>(std::vector<StmtPtr>({}));
while (!loop->body()->empty()) {
auto s = loop->body()->front();
loop->body()->remove_stmt(s);
new_loop_body->append_stmt(s);
if (pivots.count(s)) {
new_loop_bodies.push_back(new_loop_body);
- new_loop_body = new Block({});
+ new_loop_body = alloc<Block>(std::vector<StmtPtr>({}));
}
}
if (!new_loop_body->empty()) {
// The first loop body has to be in the original loop.
loop->body()->splice(loop->body()->begin(), new_loop_bodies.front());
- std::vector<For*> new_loops = {loop};
+ std::vector<ForPtr> new_loops = {loop};
// Create loops for all the remaining blocks.
// Add all the new loops to the parent block.
return new_loops;
}
// Distributes `loop` using every statement in its body as a pivot, i.e.
// the loop is split around each of its top-level body statements.
-std::vector<For*> LoopNest::distributeLoop(For* loop) {
-  std::unordered_set<Stmt*> stmtsInBlock(
+std::vector<ForPtr> LoopNest::distributeLoop(ForPtr loop) {
+  std::unordered_set<StmtPtr> stmtsInBlock(
      loop->body()->begin(), loop->body()->end());
  return distributeLoop(loop, stmtsInBlock);
}
-std::vector<For*> LoopNest::distributeLoopAndParents(For* loop) {
+std::vector<ForPtr> LoopNest::distributeLoopAndParents(ForPtr loop) {
auto parentLoop = getParentLoop(loop);
auto result = distributeLoop(loop);
if (parentLoop) {
return result;
}
// Distributes `loop` using every `For` statement found within it (via
// NodeFinder) as a pivot, so the split happens over the inner loops.
-std::vector<For*> LoopNest::distributeLoopOverInnerLoops(For* loop) {
+std::vector<ForPtr> LoopNest::distributeLoopOverInnerLoops(ForPtr loop) {
  auto loops = NodeFinder<For>::find(loop);
-  std::unordered_set<Stmt*> loopsSet(loops.begin(), loops.end());
+  std::unordered_set<StmtPtr> loopsSet(loops.begin(), loops.end());
  return distributeLoop(loop, loopsSet);
}
-std::vector<For*> LoopNest::distributeLoopAndParentsOverInnerLoops(For* loop) {
+std::vector<ForPtr> LoopNest::distributeLoopAndParentsOverInnerLoops(
+ ForPtr loop) {
auto parentLoop = getParentLoop(loop);
auto result = distributeLoopOverInnerLoops(loop);
if (parentLoop) {
return result;
}
// Returns true iff (expr1 - expr2) simplifies to the constant 0, i.e. the
// two expressions are provably equal under IRSimplifier.
-bool areEqual(Expr* expr1, Expr* expr2) {
-  auto diff = IRSimplifier::simplify(new Sub(expr1, expr2));
+bool areEqual(ExprPtr expr1, ExprPtr expr2) {
+  auto diff = IRSimplifier::simplify(alloc<Sub>(expr1, expr2));
  return diff->isConstant() && (immediateAs<int>(diff) == 0);
};
-bool doesExprContainAnyVar(Expr* expr, const std::unordered_set<Var*>& vars) {
- for (auto* v : VarFinder::find(expr)) {
+bool doesExprContainAnyVar(
+ ExprPtr expr,
+ const std::unordered_set<VarPtr>& vars) {
+ for (auto v : VarFinder::find(expr)) {
if (vars.count(v)) {
return true;
}
// that are loop-independent w.r.t. the given list of outer loop
// variables.
bool areIndicesLoopIndependent(
- const std::vector<Expr*>& expr_list1,
- const std::vector<Expr*>& expr_list2,
- const std::unordered_set<Var*>& outer_loop_vars) {
+ const std::vector<ExprPtr>& expr_list1,
+ const std::vector<ExprPtr>& expr_list2,
+ const std::unordered_set<VarPtr>& outer_loop_vars) {
if (expr_list1.size() != expr_list2.size()) {
return false;
}
return true;
}
-bool LoopNest::hasLoopCarriedDependence(For* loop) {
+bool LoopNest::hasLoopCarriedDependence(ForPtr loop) {
analysis::MemDependencyChecker analyzer;
loop->accept(&analyzer);
- std::unordered_set<Var*> outer_loop_vars = {loop->var()};
+ std::unordered_set<VarPtr> outer_loop_vars = {loop->var()};
auto outer_loops = LoopNest::getEnclosingLoopNest(loop);
for (auto l : outer_loops) {
outer_loop_vars.insert(l->var());
return false;
}
-bool LoopNest::unsafeFuseLoops(const std::vector<For*>& loops, For** fused) {
+bool LoopNest::unsafeFuseLoops(
+ const std::vector<ForPtr>& loops,
+ ForPtr* fused) {
if (loops.empty()) {
return false;
}
return false;
}
}
- auto root_block = dynamic_cast<Block*>(root);
+ auto root_block = to<Block>(root);
if (root_block == nullptr) {
return false;
}
// onwards and moving them into the first loop's body.
// This way the final fused loop will be the same as the first loop.
for (size_t i = 1; i < loops.size(); ++i) {
- auto body = dynamic_cast<Block*>(SubstituteInClone(
+ auto body = to<Block>(SubstituteInClone(
loops[i]->body(), {{loops[i]->var(), first_loop->var()}}));
first_loop->body()->splice(first_loop->body()->end(), body);
root_block->remove_stmt(loops[i]);
return true;
}
-bool LoopNest::fuseLoops(const std::vector<For*>& loops, For** fused) {
+bool LoopNest::fuseLoops(const std::vector<ForPtr>& loops, ForPtr* fused) {
if (loops.empty()) {
return false;
}
// This check can be done only after the loops are fused into one. But if the
// check is violated, we need to return the given loops in the original form.
// So, we create a clone of all the loops, fuse them and check for this.
- std::vector<For*> loops_copy;
+ std::vector<ForPtr> loops_copy;
loops_copy.reserve(loops.size());
- Block* parent = new Block({});
+ BlockPtr parent = alloc<Block>(std::vector<StmtPtr>({}));
for (auto& l : loops) {
auto l_copy = Stmt::clone(l);
- loops_copy.push_back(dynamic_cast<For*>(l_copy));
+ loops_copy.push_back(to<For>(l_copy));
parent->append_stmt(l_copy);
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For* fused_copy;
+ ForPtr fused_copy;
bool ret = unsafeFuseLoops(loops_copy, &fused_copy);
if (!ret || hasLoopCarriedDependence(fused_copy)) {
return false;
return unsafeFuseLoops(loops, fused);
}
-For* findOuterFor(For* a, For* b) {
- Stmt* s = b; // guess b is the latter.
+ForPtr findOuterFor(ForPtr a, ForPtr b) {
+ StmtPtr s = b; // guess b is the latter.
while (s != nullptr) {
if (s == a) {
// yes, b is after a.
return nullptr;
}
-void LoopNest::reorderAxis(For* a, For* b) {
+void LoopNest::reorderAxis(ForPtr a, ForPtr b) {
if (a == b) {
// nothing to do.
return;
}
// find inner and outer.
- For* outer = findOuterFor(a, b);
+ ForPtr outer = findOuterFor(a, b);
if (outer == nullptr) {
throw std::runtime_error("Reordered a loop not in LoopNest");
}
- For* inner = a == outer ? b : a;
- std::deque<For*> internal_axes;
+ ForPtr inner = a == outer ? b : a;
+ std::deque<ForPtr> internal_axes;
// Find relevant axes, store reversed.
- Stmt* s = inner;
+ StmtPtr s = inner;
while (s != outer) {
- if (For* f = dynamic_cast<For*>(s)) {
+ if (ForPtr f = to<For>(s)) {
internal_axes.push_back(f);
}
internal_axes.push_back(outer);
- Block* root = dynamic_cast<Block*>(outer->get_parent());
+ BlockPtr root = to<Block>(outer->get_parent());
CHECK(root);
// Do a shallow copy of the inner blocks.
- Block* body = new Block({});
+ BlockPtr body = alloc<Block>(std::vector<StmtPtr>({}));
body->splice(body->end(), inner->body());
- For* before{outer};
- For* after{nullptr};
- For* last = internal_axes.front();
- Stmt* newInner = body;
+ ForPtr before{outer};
+ ForPtr after{nullptr};
+ ForPtr last = internal_axes.front();
+ StmtPtr newInner = body;
s = inner;
while (s != outer) {
- if (auto cond = dynamic_cast<Cond*>(s->get_parent())) {
+ if (auto cond = to<Cond>(s->get_parent())) {
if (s == cond->true_stmt()) {
newInner = cond->cloneWithNewBody(newInner);
} else {
// s is the false branch of Cond
- newInner = cond->cloneWithNewBodies(new Block({}), newInner);
+ newInner = cond->cloneWithNewBodies(
+ alloc<Block>(std::vector<StmtPtr>({})), newInner);
}
}
s = s->get_parent();
  // When reordering loops i and j we need to ensure that Statements A and C are
// still both executed with the loop extents of i, and that the three
// statements are not reordered (as much as possible).
- for (auto* loop : internal_axes) {
+ for (auto loop : internal_axes) {
// If the inner loop had a component after the loop we must wrap it in a For
// loop matching this level of the tree.
if (after != nullptr) {
bool hadBeforeStmts = false;
for (auto I = loop->body()->begin(), E = loop->body()->end(); I != E;) {
// Be careful not to invalidate the iterator.
- Stmt* s = *(I++);
+ StmtPtr s = *(I++);
if (s == last) {
// This is the midpoint.
loop->body()->remove_stmt(s);
std::swap(internal_axes.front(), internal_axes.back());
// Create the reordered internals:
- for (auto* loop : internal_axes) {
+ for (auto loop : internal_axes) {
newInner = loop->cloneWithNewBody(newInner);
}
return isTrivialPermutation(permutation);
}
-std::vector<For*> LoopNest::reorder(
- const std::vector<For*>& loops,
+std::vector<ForPtr> LoopNest::reorder(
+ const std::vector<ForPtr>& loops,
const std::vector<size_t>& permutation) {
if (loops.size() != permutation.size()) {
throw malformed_input("invalid permutation size");
throw malformed_input("reorder is only allowed on perfectly nested loops");
}
- auto parent = dynamic_cast<Block*>(loops.front()->get_parent());
+ auto parent = to<Block>(loops.front()->get_parent());
if (parent == nullptr) {
throw malformed_input("parent of the loops must be a Block");
}
// Reorder the loops according to the permutation.
- std::vector<For*> result(loops.size());
+ std::vector<ForPtr> result(loops.size());
for (size_t i = 0; i < loops.size(); ++i) {
result[i] = loops[permutation[i]];
}
// We use an empty block statement to replace the outermost loop
// so that we know the position where the outermost reordered loop
// is to be inserted.
- auto empty_block = new Block({});
+ auto empty_block = alloc<Block>(std::vector<StmtPtr>({}));
parent->replace_stmt(loops.front(), empty_block);
for (size_t i = 1; i < loops.size(); ++i) {
- auto block = dynamic_cast<Block*>(loops[i]->get_parent());
+ auto block = to<Block>(loops[i]->get_parent());
TORCH_INTERNAL_ASSERT(block);
block->remove_stmt(loops[i]);
}
return result;
}
-For* LoopNest::getLoopAt(For* root, const std::vector<int>& indices) const {
+ForPtr LoopNest::getLoopAt(ForPtr root, const std::vector<int>& indices) const {
if (indices.empty()) {
return root;
}
throw malformed_input("root loop is null");
}
- For* curr = root;
+ ForPtr curr = root;
for (auto i : indices) {
if (i < 0 || curr->body()->nstmts() <= i) {
return nullptr;
}
- std::list<Stmt*>::iterator stmtp = curr->body()->begin();
+ std::list<StmtPtr>::iterator stmtp = curr->body()->begin();
std::advance(stmtp, i);
- curr = dynamic_cast<For*>(*stmtp);
+ curr = to<For>(*stmtp);
if (curr == nullptr) {
return nullptr;
}
return curr;
}
-For* LoopNest::tile(For* x, For* y, int x_factor, int y_factor) {
- auto parent = dynamic_cast<Block*>(x->get_parent());
+ForPtr LoopNest::tile(ForPtr x, ForPtr y, int x_factor, int y_factor) {
+ auto parent = to<Block>(x->get_parent());
if (parent == nullptr) {
throw malformed_input("parent of the loops must be a Block");
}
// Split x, y axes by x_factor and y_factor
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For *yi, *ytail;
+ ForPtr yi, ytail;
splitWithTail(y, y_factor, &yi, &ytail);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For *xi, *xtail;
+ ForPtr xi, xtail;
splitWithTail(x, x_factor, &xi, &xtail);
// Distribute xi over yo and ytail so we can manipulate the loop order of {xo,
// For {xi, yo, yi}, reorder the axes to be yo, xi, yi
xi = loops.front();
- For* yo = dynamic_cast<For*>(xi->body()->stmts().front());
+ ForPtr yo = to<For>(xi->body()->stmts().front());
CHECK(yo);
reorder({xi, yo}, {1, 0});
// For {xi, ytail}, reorder the axes to be ytail, xi
if (loops.size() == 2) {
xi = loops.back();
- ytail = dynamic_cast<For*>(xi->body()->stmts().front());
+ ytail = to<For>(xi->body()->stmts().front());
CHECK(ytail);
reorder({xi, ytail}, {1, 0});
}
return xtail;
}
-bool LoopNest::areLoopsPerfectlyNested(const std::vector<For*>& loops) {
+bool LoopNest::areLoopsPerfectlyNested(const std::vector<ForPtr>& loops) {
if (loops.size() < 2) {
return true;
}
return true;
}
-void LoopNest::unroll(For* f, Stmt** unrolled) {
- Block* p = dynamic_cast<Block*>(f->get_parent());
+void LoopNest::unroll(ForPtr f, StmtPtr* unrolled) {
+ BlockPtr p = to<Block>(f->get_parent());
if (!f) {
throw malformed_input("unroll attempted on null loop");
} else if (!p) {
throw std::runtime_error("Can't unroll due to non-constant loop stop!");
}
- std::vector<Stmt*> unrolled_stmts;
+ std::vector<StmtPtr> unrolled_stmts;
int start_val = immediateAs<int>(start_expr);
int stop_val = immediateAs<int>(stop_expr);
for (int current = start_val; current < stop_val; ++current) {
stmt, {{f->var(), getImmediateByType(f->var()->dtype(), current)}}));
}
}
- *unrolled = new Block(unrolled_stmts);
+ *unrolled = alloc<Block>(unrolled_stmts);
*unrolled = IRSimplifier::simplify(*unrolled);
p->replace_stmt(f, *unrolled);
}
// Convenience overload: fully unrolls `f`, discarding the statement that
// replaces it (the two-argument overload does the actual work).
-void LoopNest::unroll(For* f) {
+void LoopNest::unroll(ForPtr f) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  Stmt* unrolled;
+  StmtPtr unrolled;
  unroll(f, &unrolled);
}
// A loop is considered normalized when its start bound is the constant 0.
// Non-constant start bounds are conservatively reported as not normalized.
-bool LoopNest::isNormalized(For* f) {
+bool LoopNest::isNormalized(ForPtr f) {
  if (f->start()->isConstant()) {
    return immediateAs<int>(f->start()) == 0;
  }
  return false;
}
-bool LoopNest::normalize(For* f) {
+bool LoopNest::normalize(ForPtr f) {
if (!f) {
throw malformed_input("normalize attempted on null loop");
}
f->body(),
{{f->var(), (VarHandle(f->var()) + ExprHandle(f->start())).node()}});
f->set_body(IRSimplifier::simplify(for_body_normalized));
- f->set_stop(IRSimplifier::simplify(new Sub(f->stop(), f->start())));
- f->set_start(new IntImm(0));
+ f->set_stop(IRSimplifier::simplify(alloc<Sub>(f->stop(), f->start())));
+ f->set_start(alloc<IntImm>(0));
return true;
}
// This function expects that there are 'num' loops perfectly nested within
// and including 'f'.
// Collects the chain of `num` perfectly nested loops rooted at `f`:
// loops[0] is `f` itself, and each following entry is the sole statement of
// the previous loop's body (enforced via TORCH_INTERNAL_ASSERT).
-std::vector<For*> LoopNest::getLoopStmtsInLoopNest(For* f, size_t num) {
-  std::vector<For*> loops(num);
-  For* curr_for = f;
+std::vector<ForPtr> LoopNest::getLoopStmtsInLoopNest(ForPtr f, size_t num) {
+  std::vector<ForPtr> loops(num);
+  ForPtr curr_for = f;
  loops[0] = curr_for;
  for (size_t i = 1; i < num; ++i) {
    TORCH_INTERNAL_ASSERT(curr_for->body()->nstmts() == 1);
-    curr_for = dynamic_cast<For*>(curr_for->body()->front());
+    curr_for = to<For>(curr_for->body()->front());
    TORCH_INTERNAL_ASSERT(curr_for);
    loops[i] = curr_for;
  }
  return loops;
}
-bool LoopNest::flatten(const std::vector<For*>& loops, For** flattened) {
+bool LoopNest::flatten(const std::vector<ForPtr>& loops, ForPtr* flattened) {
if (loops.empty()) {
throw malformed_input("flatten attempted on empty set of loops");
}
- Block* p = dynamic_cast<Block*>(loops[0]->get_parent());
+ BlockPtr p = to<Block>(loops[0]->get_parent());
if (!p) {
throw malformed_input("flatten attempted on loops with no parent");
}
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
auto normalized_loops = getLoopStmtsInLoopNest(loops.front(), loops.size());
- auto flat_var = new Var(
+ auto flat_var = alloc<Var>(
normalized_loops[0]->var()->name_hint() + "_flat",
normalized_loops[0]->var()->dtype());
VarMapping var_mapping;
- Expr* stop = new IntImm(1);
+ ExprPtr stop = alloc<IntImm>(1);
for (size_t i = 0; i < normalized_loops.size(); ++i) {
size_t idx = normalized_loops.size() - i - 1;
auto curr_loop = normalized_loops[idx];
- Expr* div = new Div(flat_var, stop);
- Expr* sub_expr = idx == 0 ? div : new Mod(div, curr_loop->stop());
+ ExprPtr div = alloc<Div>(flat_var, stop);
+ ExprPtr sub_expr = idx == 0 ? div : alloc<Mod>(div, curr_loop->stop());
var_mapping.push_back(std::make_pair(curr_loop->var(), sub_expr));
- stop = new Mul(curr_loop->stop(), stop);
+ stop = alloc<Mul>(curr_loop->stop(), stop);
}
auto flattened_body =
Substitute(normalized_loops.back()->removeBody(), var_mapping);
normalized_loops.front()->set_var(flat_var);
- normalized_loops.front()->set_start(new IntImm(0));
+ normalized_loops.front()->set_start(alloc<IntImm>(0));
normalized_loops.front()->set_stop(stop);
normalized_loops.front()->set_body(flattened_body);
*flattened = normalized_loops.front();
return true;
}
// Convenience overload: flattens `loops` into a single loop and discards the
// resulting flattened loop, returning only the success flag.
-bool LoopNest::flatten(const std::vector<For*>& loops) {
+bool LoopNest::flatten(const std::vector<ForPtr>& loops) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  For* flattened;
+  ForPtr flattened;
  return flatten(loops, &flattened);
}
-void LoopNest::compressBuffer(Buf* buf, Stmt* stmt) {
+void LoopNest::compressBuffer(BufPtr buf, StmtPtr stmt) {
// Loop iterations in NNC IR do not follow sequential semantics by default.
// In other words, the iterations of the loops could be executed in any
// random order without affecting correctness. This constraint in turn
auto reads = StmtsReadingBuf::find(stmt, buf);
// Find the parent common to all the buffer accesses.
- Block* parent = dynamic_cast<Block*>(writes.front()->get_parent());
+ BlockPtr parent = to<Block>(writes.front()->get_parent());
TORCH_INTERNAL_ASSERT(parent);
for (auto w : writes) {
parent = Block::getSharedParent(parent, w);
// Collect all the loops that are above the common parent.
auto loops = LoopNest::getEnclosingLoopNest(parent);
- std::unordered_set<Var*> loop_vars;
+ std::unordered_set<VarPtr> loop_vars;
for (auto l : loops) {
loop_vars.insert(l->var());
}
// Vector to indicate which dimensions could be compressed away.
std::vector<bool> dims(buf->dims().size(), true);
- auto check_indices = [&](const std::vector<Expr*>& indices) {
+ auto check_indices = [&](const std::vector<ExprPtr>& indices) {
TORCH_INTERNAL_ASSERT(indices.size() == dims.size());
for (size_t i = 0; i < indices.size(); ++i) {
auto index_vars = NodeFinder<Var>::find(indices[i]);
}
// Compress buffer by removing the marked dims.
- std::vector<Expr*> new_dims(buf->dims());
+ std::vector<ExprPtr> new_dims(buf->dims());
for (size_t i = 0; i < dims.size(); ++i) {
if (dims[i]) {
- new_dims[i] = new IntImm(1);
+ new_dims[i] = alloc<IntImm>(1);
}
}
buf->set_dims(new_dims);
// Modify all access to reflect the removed dims.
- auto get_new_indices = [&](const std::vector<Expr*>& indices) {
+ auto get_new_indices = [&](const std::vector<ExprPtr>& indices) {
TORCH_INTERNAL_ASSERT(indices.size() == dims.size());
- std::vector<Expr*> new_indices(indices);
+ std::vector<ExprPtr> new_indices(indices);
for (size_t i = 0; i < dims.size(); ++i) {
if (dims[i]) {
- new_indices[i] = new IntImm(0);
+ new_indices[i] = alloc<IntImm>(0);
}
}
return new_indices;
}
}
// Applies compressBuffer to every buffer referenced anywhere in `stmt`.
// NOTE(review): the const_cast strips constness from BufFinder's results so
// compressBuffer can mutate the buffer's dims in place — confirm this is the
// intended ownership model.
-void LoopNest::compressAllBuffers(Stmt* stmt) {
+void LoopNest::compressAllBuffers(StmtPtr stmt) {
  for (auto buf : BufFinder::find(stmt)) {
-    compressBuffer(const_cast<Buf*>(buf), stmt);
+    compressBuffer(const_cast<BufPtr>(buf), stmt);
  }
}
// Finds the loop body that computes tensor `t`, then returns the loops
// enclosing that statement (delegates to the StmtPtr overload).
-std::vector<For*> LoopNest::getLoopStmtsFor(Tensor* t) const {
-  Stmt* cur_stmt = getLoopBodyFor(t);
+std::vector<ForPtr> LoopNest::getLoopStmtsFor(Tensor* t) const {
+  StmtPtr cur_stmt = getLoopBodyFor(t);
  return getLoopStmtsFor(cur_stmt);
}
// Same as the Tensor overload, but starts from the statement writing `buf`.
-std::vector<For*> LoopNest::getLoopStmtsFor(Buf* buf) const {
-  Stmt* cur_stmt = getLoopBodyFor(buf);
+std::vector<ForPtr> LoopNest::getLoopStmtsFor(BufPtr buf) const {
+  StmtPtr cur_stmt = getLoopBodyFor(buf);
  return getLoopStmtsFor(cur_stmt);
}
-std::vector<For*> LoopNest::getLoopStmtsFor(Stmt* s) const {
- std::vector<For*> result;
+std::vector<ForPtr> LoopNest::getLoopStmtsFor(StmtPtr s) const {
+ std::vector<ForPtr> result;
while (s) {
- if (auto* loop = dynamic_cast<For*>(s)) {
+ if (auto loop = to<For>(s)) {
result.push_back(loop);
}
s = s->get_parent();
return result;
}
// Delegates to the Buf overload using the tensor's underlying buffer.
-Stmt* LoopNest::getLoopBodyFor(Tensor* t) const {
+StmtPtr LoopNest::getLoopBodyFor(Tensor* t) const {
  return getLoopBodyFor(t->buf());
}
-Stmt* LoopNest::getLoopBodyFor(Buf* buf) const {
+StmtPtr LoopNest::getLoopBodyFor(BufPtr buf) const {
auto writes = WritesToBuf::find(root_stmt_, buf);
// special case for reduction Tensors, ignore the initializer if it's the only
// op:
if (writes.size() == 2) {
- if (Store* s = dynamic_cast<Store*>(writes.back())) {
- if (ReduceOp* r = dynamic_cast<ReduceOp*>(s->value())) {
- return (Stmt*)s; // NOLINT
+ if (StorePtr s = to<Store>(writes.back())) {
+ if (ReduceOpPtr r = to<ReduceOp>(s->value())) {
+ return (StmtPtr)s; // NOLINT
}
}
}
- Stmt* res = nullptr;
- for (auto* s : writes) {
+ StmtPtr res = nullptr;
+ for (auto s : writes) {
if (!res) {
res = s;
continue;
res = Block::getSharedParent(res, s);
}
- return (Stmt*)res; // NOLINT
+ return (StmtPtr)res; // NOLINT
}
// Walks up the parent chain from `st` (recursively) and returns the nearest
// enclosing For loop, or nullptr if `st` is null / has no enclosing loop.
-For* LoopNest::getParentLoop(Stmt* st) {
+ForPtr LoopNest::getParentLoop(StmtPtr st) {
  if (st == nullptr) {
    return nullptr;
  }
  auto par = st->get_parent();
-  if (auto f = dynamic_cast<For*>(par)) {
+  if (auto f = to<For>(par)) {
    return f;
  }
  return getParentLoop(par);
}
-std::vector<For*> LoopNest::getEnclosingLoopNest(Stmt* st) {
- std::vector<For*> loops;
+std::vector<ForPtr> LoopNest::getEnclosingLoopNest(StmtPtr st) {
+ std::vector<ForPtr> loops;
auto f = getParentLoop(st);
while (f) {
loops.push_back(f);
return loops;
}
// Returns every statement under the root that writes to `buf`, as found by
// the WritesToBuf analysis.
-std::vector<Stmt*> LoopNest::getAllWritesToBuf(Buf* buf) const {
+std::vector<StmtPtr> LoopNest::getAllWritesToBuf(BufPtr buf) const {
  return WritesToBuf::find(root_stmt_, buf);
}
-std::vector<For*> LoopNest::getAllInnermostLoopsWritingToBuf(Buf* buf) const {
+std::vector<ForPtr> LoopNest::getAllInnermostLoopsWritingToBuf(
+ BufPtr buf) const {
auto writes = getAllWritesToBuf(buf);
- std::vector<For*> innermost_loops;
+ std::vector<ForPtr> innermost_loops;
innermost_loops.reserve(writes.size());
for (auto w : writes) {
innermost_loops.push_back(LoopNest::getParentLoop(w));
return innermost_loops;
}
-std::vector<std::vector<For*>> LoopNest::getAllLoopNestsWritingToBuf(
- Buf* buf) const {
+std::vector<std::vector<ForPtr>> LoopNest::getAllLoopNestsWritingToBuf(
+ BufPtr buf) const {
auto writes = getAllWritesToBuf(buf);
- std::vector<std::vector<For*>> loopnests;
+ std::vector<std::vector<ForPtr>> loopnests;
loopnests.reserve(writes.size());
for (auto w : writes) {
loopnests.emplace_back(LoopNest::getEnclosingLoopNest(w));
return loopnests;
}
// Simplifies the loop nest's root statement in place and returns it.
-Stmt* LoopNest::simplify() {
+StmtPtr LoopNest::simplify() {
  root_stmt_ = IRSimplifier::simplify(root_stmt_);
  return root_stmt_;
}
// Runs IndexFlattener over `s` — presumably rewriting multi-dimensional
// buffer accesses into flat index form; the actual transformation lives in
// IndexFlattener.
-Stmt* FlattenIndexes(Stmt* s) {
+StmtPtr FlattenIndexes(StmtPtr s) {
  IndexFlattener idx_flattener;
  return idx_flattener.flatten(s);
}
// LoopNest::computeAt for more details.
class LoopComputeAtRewriter : public IRMutator {
public:
- LoopComputeAtRewriter(Buf* buf, Buf* new_buf, std::vector<Expr*> offsets)
+ LoopComputeAtRewriter(
+ BufPtr buf,
+ BufPtr new_buf,
+ std::vector<ExprPtr> offsets)
: buf_(buf), new_buf_(new_buf), offsets_(std::move(offsets)) {}
private:
- Buf* buf_;
- Buf* new_buf_;
- std::vector<Expr*> offsets_;
+ BufPtr buf_;
+ BufPtr new_buf_;
+ std::vector<ExprPtr> offsets_;
- Expr* mutate(Load* v) override {
+ ExprPtr mutate(LoadPtr v) override {
if (v->buf() != buf_) {
return v;
}
- std::vector<Expr*> new_indices(v->indices().size());
+ std::vector<ExprPtr> new_indices(v->indices().size());
for (const auto i : c10::irange(v->indices().size())) {
new_indices[i] =
- IRSimplifier::simplify(new Sub(v->indices()[i], offsets_[i]));
+ IRSimplifier::simplify(alloc<Sub>(v->indices()[i], offsets_[i]));
}
- return new Load(v->dtype(), new_buf_, new_indices);
+ return alloc<Load>(v->dtype(), new_buf_, new_indices);
}
};
-static Store* getStoreStmtOfProducer(Stmt* s) {
- if (Store* st = dynamic_cast<Store*>(s)) {
+static StorePtr getStoreStmtOfProducer(StmtPtr s) {
+ if (StorePtr st = to<Store>(s)) {
return st;
}
- if (Block* b = dynamic_cast<Block*>(s)) {
- for (Stmt* ss : *b) {
- if (Store* st = dynamic_cast<Store*>(ss)) {
+ if (BlockPtr b = to<Block>(s)) {
+ for (StmtPtr ss : *b) {
+ if (StorePtr st = to<Store>(ss)) {
return st;
}
}
return nullptr;
}
-static std::vector<Var*> getOuterLoopIndexes(Stmt* s) {
- std::vector<Var*> res;
- Stmt* cur = s;
+static std::vector<VarPtr> getOuterLoopIndexes(StmtPtr s) {
+ std::vector<VarPtr> res;
+ StmtPtr cur = s;
while (cur) {
- if (auto l = dynamic_cast<For*>(cur)) {
+ if (auto l = to<For>(cur)) {
res.push_back(l->var());
}
cur = cur->get_parent();
class CacheReplacer : public IRMutator {
public:
- CacheReplacer(Buf* buffer, Buf* cache, std::vector<Expr*>& offsets)
+ CacheReplacer(BufPtr buffer, BufPtr cache, std::vector<ExprPtr>& offsets)
: buf_(buffer), cache_(cache), offsets_(offsets) {}
private:
- Expr* mutate(Load* v) override {
- Buf* buf = v->buf();
+ ExprPtr mutate(LoadPtr v) override {
+ BufPtr buf = v->buf();
if (buf != buf_) {
return IRMutator::mutate(v);
}
// Map indices to call-parameters.
- std::vector<Expr*> newIndices;
+ std::vector<ExprPtr> newIndices;
TORCH_INTERNAL_ASSERT(offsets_.size() == v->indices().size());
for (size_t i = 0; i < v->indices().size(); ++i) {
- Expr* index = v->indices()[i]->accept_mutator(this);
- Expr* offset = offsets_[i];
- Expr* sub = IRSimplifier::simplify(new Sub(index, offset));
+ ExprPtr index = v->indices()[i]->accept_mutator(this);
+ ExprPtr offset = offsets_[i];
+ ExprPtr sub = IRSimplifier::simplify(alloc<Sub>(index, offset));
newIndices.push_back(sub);
}
- return new Load(cache_, newIndices);
+ return alloc<Load>(cache_, newIndices);
}
- Stmt* mutate(Store* v) override {
- Buf* buf = v->buf();
+ StmtPtr mutate(StorePtr v) override {
+ BufPtr buf = v->buf();
if (buf != buf_) {
return IRMutator::mutate(v);
}
- Expr* newValue = v->value()->accept_mutator(this);
+ ExprPtr newValue = v->value()->accept_mutator(this);
// Map indices to call-parameters.
- std::vector<Expr*> newIndices;
+ std::vector<ExprPtr> newIndices;
TORCH_INTERNAL_ASSERT(offsets_.size() == v->indices().size());
for (size_t i = 0; i < v->indices().size(); ++i) {
- Expr* index = v->indices()[i]->accept_mutator(this);
- Expr* offset = offsets_[i];
- Expr* sub = IRSimplifier::simplify(new Sub(index, offset));
+ ExprPtr index = v->indices()[i]->accept_mutator(this);
+ ExprPtr offset = offsets_[i];
+ ExprPtr sub = IRSimplifier::simplify(alloc<Sub>(index, offset));
newIndices.push_back(sub);
}
- return new Store(cache_, newIndices, newValue);
+ return alloc<Store>(cache_, newIndices, newValue);
}
- Buf* buf_;
- Buf* cache_;
- std::vector<Expr*>& offsets_;
+ BufPtr buf_;
+ BufPtr cache_;
+ std::vector<ExprPtr>& offsets_;
};
LoopNest::AccessResult LoopNest::cacheAccesses(
- Buf* producer,
+ BufPtr producer,
const std::string& name,
- Stmt* consumer) {
- ReduceOp* reduceOp{nullptr};
+ StmtPtr consumer) {
+ ReduceOpPtr reduceOp{nullptr};
auto stores = NodeFinder<Store>::find(consumer);
- for (auto* store : stores) {
- if (auto ro = dynamic_cast<ReduceOp*>(store->value())) {
+ for (auto store : stores) {
+ if (auto ro = to<ReduceOp>(store->value())) {
if (store->buf() != producer) {
continue;
}
bool hasWrites = info.kind == kStore || info.kind == kMutate;
std::vector<std::string> var_names = {"i", "j", "k", "l", "m", "n", "o", "p"};
- std::vector<Expr*> tmp_dims;
- std::vector<Var*> new_loop_vars;
- std::vector<Expr*> new_loop_vars_expr;
+ std::vector<ExprPtr> tmp_dims;
+ std::vector<VarPtr> new_loop_vars;
+ std::vector<ExprPtr> new_loop_vars_expr;
// Determine the size of the cache, and create a loop var for each dimension.
for (size_t i = 0; i < info.start.size(); ++i) {
- Expr* dim = IRSimplifier::simplify(
- new Add(new Sub(info.stop[i], info.start[i]), new IntImm(1)));
+ ExprPtr dim = IRSimplifier::simplify(
+ alloc<Add>(alloc<Sub>(info.stop[i], info.start[i]), alloc<IntImm>(1)));
tmp_dims.push_back(dim);
- new_loop_vars.push_back(new Var(var_names[i % var_names.size()], kInt));
+ new_loop_vars.push_back(alloc<Var>(var_names[i % var_names.size()], kInt));
new_loop_vars_expr.push_back(new_loop_vars[i]);
}
// Create the var.
- Buf* tmp_buf = new Buf(new Var(name, kHandle), tmp_dims, producer->dtype());
+ BufPtr tmp_buf =
+ alloc<Buf>(alloc<Var>(name, kHandle), tmp_dims, producer->dtype());
// determine the offsets for calls into the cache based off the loop start of
// each axis.
- std::vector<Expr*> tmp_params;
+ std::vector<ExprPtr> tmp_params;
for (size_t i = 0; i < new_loop_vars.size(); ++i) {
- tmp_params.push_back(new Add(new_loop_vars[i], info.start[i]));
+ tmp_params.push_back(alloc<Add>(new_loop_vars[i], info.start[i]));
}
  // Replace accesses to the producer in the consumer with the cache.
CacheReplacer replacer(producer, tmp_buf, info.start);
// TODO: Can we reuse 'consumer' below without cloning?
- Stmt* new_consumer =
+ StmtPtr new_consumer =
IRSimplifier::simplify(Stmt::clone(consumer)->accept_mutator(&replacer));
// replace the old consumer with the replaced consumer.
- Block* consumer_block = nullptr;
+ BlockPtr consumer_block = nullptr;
// if the consumer is a block, we should mutate it in place.
- if ((consumer_block = dynamic_cast<Block*>(consumer))) {
+ if ((consumer_block = to<Block>(consumer))) {
consumer_block->clear();
consumer_block->append_stmt(new_consumer);
} else {
- consumer_block = dynamic_cast<Block*>(consumer->get_parent());
+ consumer_block = to<Block>(consumer->get_parent());
assert(consumer_block);
consumer_block->replace_stmt(consumer, new_consumer);
}
// Instead we need to create a new ReduceOp.
bool on_reduce_axis = false;
if (reduceOp) {
- std::set<Var*> reduce_args(
+ std::set<VarPtr> reduce_args(
reduceOp->reduce_args().begin(), reduceOp->reduce_args().end());
- std::set<Var*> enclosing_vars;
+ std::set<VarPtr> enclosing_vars;
for (auto enclosing_for_stmt : NodeFinder<For>::find(consumer)) {
enclosing_vars.insert(enclosing_for_stmt->var());
}
// reduceOp means we had both loads and stores.
// Init cache to 0.
- Stmt* tmp_init = new Store(
+ StmtPtr tmp_init = alloc<Store>(
tmp_buf, new_loop_vars_expr, getImmediateByType(tmp_buf->dtype(), 0));
for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) {
tmp_init =
- new For(new_loop_vars[i], new IntImm(0), tmp_dims[i], tmp_init);
+ alloc<For>(new_loop_vars[i], alloc<IntImm>(0), tmp_dims[i], tmp_init);
}
consumer_block->insert_stmt_before(tmp_init, new_consumer);
// Reduce back to the original buffer:
- Stmt* tmp_store = new Store(
+ StmtPtr tmp_store = alloc<Store>(
producer,
tmp_params,
reduceOp->reducer()(
producer,
- ExprHandle(new Load(tmp_buf, new_loop_vars_expr)),
+ ExprHandle(alloc<Load>(tmp_buf, new_loop_vars_expr)),
tmp_params,
{}));
for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) {
- tmp_store =
- new For(new_loop_vars[i], new IntImm(0), tmp_dims[i], tmp_store);
+ tmp_store = alloc<For>(
+ new_loop_vars[i], alloc<IntImm>(0), tmp_dims[i], tmp_store);
}
consumer_block->insert_stmt_after(tmp_store, new_consumer);
if (hasReads) {
// Fill the cache with values from the consumer.
- Stmt* tmp_store =
- new Store(tmp_buf, new_loop_vars_expr, new Load(producer, tmp_params));
+ StmtPtr tmp_store = alloc<Store>(
+ tmp_buf, new_loop_vars_expr, alloc<Load>(producer, tmp_params));
for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) {
- tmp_store =
- new For(new_loop_vars[i], new IntImm(0), tmp_dims[i], tmp_store);
+ tmp_store = alloc<For>(
+ new_loop_vars[i], alloc<IntImm>(0), tmp_dims[i], tmp_store);
}
consumer_block->insert_stmt_before(tmp_store, new_consumer);
if (hasWrites) {
// sync the cache back to the producer buf.
- Stmt* tmp_store =
- new Store(producer, tmp_params, new Load(tmp_buf, new_loop_vars_expr));
+ StmtPtr tmp_store = alloc<Store>(
+ producer, tmp_params, alloc<Load>(tmp_buf, new_loop_vars_expr));
for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) {
- tmp_store =
- new For(new_loop_vars[i], new IntImm(0), tmp_dims[i], tmp_store);
+ tmp_store = alloc<For>(
+ new_loop_vars[i], alloc<IntImm>(0), tmp_dims[i], tmp_store);
}
consumer_block->insert_stmt_after(tmp_store, new_consumer);
* `temp` instead of `producer`. The indices in the corresponding accesses
* also need to be offset.
*/
-void LoopNest::computeAt(Stmt* s, For* f) {
- Store* st = getStoreStmtOfProducer(s);
+void LoopNest::computeAt(StmtPtr s, ForPtr f) {
+ StorePtr st = getStoreStmtOfProducer(s);
if (!st) {
return;
}
}
// Compute dimensions of the temp buffer we would need to allocate
- std::vector<Expr*> dims = getBoundExtents(bounds_it->second);
+ std::vector<ExprPtr> dims = getBoundExtents(bounds_it->second);
// TODO: Use name-hint of the producer instead of "temp"
- Buf* temp_buf = new Buf("temp", dims, st->value()->dtype());
+ BufPtr temp_buf = alloc<Buf>("temp", dims, st->value()->dtype());
// Generate index variables for 'temp'
- std::vector<Expr*> temp_indices(dims.size());
+ std::vector<ExprPtr> temp_indices(dims.size());
for (const auto i : c10::irange(dims.size())) {
// TODO: Use name-hint of the producer indices instead of 'idx'
- temp_indices[i] = new Var(std::string("idx") + c10::to_string(i), kInt);
+ temp_indices[i] = alloc<Var>(std::string("idx") + c10::to_string(i), kInt);
}
// Prepare substitute rules for constructing the temp statement from the prod
// modified (e.g. split or merged) so that the loop indices no longer
// correspond to the indices of the original expression and even their number
// might be different. In that case, the loop below would crash.
- std::vector<Var*> prod_indices = getOuterLoopIndexes(s);
- std::vector<std::pair<Var*, Expr*>> rewrite_indices_map;
- std::vector<Expr*> offsets;
+ std::vector<VarPtr> prod_indices = getOuterLoopIndexes(s);
+ std::vector<std::pair<VarPtr, ExprPtr>> rewrite_indices_map;
+ std::vector<ExprPtr> offsets;
for (const TensorAccessBoundsInfo& p : bounds_it->second) {
for (const auto i : c10::irange(p.start.size())) {
if (offsets.size() <= i) {
offsets.push_back(p.start[i]);
} else {
offsets[i] =
- IRSimplifier::simplify(new Min(offsets[i], p.start[i], true));
+ IRSimplifier::simplify(alloc<Min>(offsets[i], p.start[i], true));
}
}
}
for (const auto i : c10::irange(prod_indices.size())) {
rewrite_indices_map.push_back(
- {prod_indices[i], new Add(temp_indices[i], offsets[i])});
+ {prod_indices[i], alloc<Add>(temp_indices[i], offsets[i])});
}
// Construct the temp statement
- Stmt* bd = new Store(
+ StmtPtr bd = alloc<Store>(
temp_buf,
temp_indices,
SubstituteInClone(st->value(), rewrite_indices_map));
// We're creating loops from innermost to outermost, so we need to access
// dimensions in reversed order.
size_t dim_idx = dims.size() - 1 - i;
- bd = new For(
- dynamic_cast<Var*>(temp_indices[dim_idx]),
- new IntImm(0),
- dims[dim_idx],
- bd);
+ bd = alloc<For>(
+ to<Var>(temp_indices[dim_idx]), alloc<IntImm>(0), dims[dim_idx], bd);
}
// Add constructed stmts to the consumer loop
// Rewrite accesses to producer in consumer with accesses to temp
LoopComputeAtRewriter lr(st->buf(), temp_buf, offsets);
- Stmt* new_f = f->accept_mutator(&lr);
+ StmtPtr new_f = f->accept_mutator(&lr);
if (f != new_f) {
- Block* bb = dynamic_cast<Block*>(f->get_parent());
+ BlockPtr bb = to<Block>(f->get_parent());
bb->replace_stmt(f, new_f);
}
}
class RfactorStoreRewriter : public IRMutator {
public:
RfactorStoreRewriter(
- Buf* old_buf,
- const std::vector<Expr*>& old_indices,
- Buf* new_buf,
- Var* reduction_var)
+ BufPtr old_buf,
+ const std::vector<ExprPtr>& old_indices,
+ BufPtr new_buf,
+ VarPtr reduction_var)
: old_buf_(old_buf),
old_indices_(old_indices),
new_buf_(new_buf),
new_indices_.push_back(reduction_var_);
}
- Expr* mutate(Load* v) override {
+ ExprPtr mutate(LoadPtr v) override {
if (v->buf() != old_buf_) {
return IRMutator::mutate(v);
}
return IRMutator::mutate(v);
}
- return new Load(new_buf_, new_indices_);
+ return alloc<Load>(new_buf_, new_indices_);
}
- Expr* mutate(ReduceOp* v) override {
- Expr* body_new = v->body()->accept_mutator(this);
+ ExprPtr mutate(ReduceOpPtr v) override {
+ ExprPtr body_new = v->body()->accept_mutator(this);
- std::vector<Var*> new_reduce_args;
- for (auto* r : v->reduce_args()) {
+ std::vector<VarPtr> new_reduce_args;
+ for (auto r : v->reduce_args()) {
if (r != reduction_var_) {
new_reduce_args.push_back(r);
}
}
- return new ReduceOp(body_new, new_reduce_args, v->reducer());
+ return alloc<ReduceOp>(body_new, new_reduce_args, v->reducer());
}
- Stmt* mutate(Store* v) override {
+ StmtPtr mutate(StorePtr v) override {
if (v->buf() != old_buf_) {
return IRMutator::mutate(v);
}
return IRMutator::mutate(v);
}
- Expr* new_value = v->value()->accept_mutator(this);
- return new Store(new_buf_, new_indices_, new_value);
+ ExprPtr new_value = v->value()->accept_mutator(this);
+ return alloc<Store>(new_buf_, new_indices_, new_value);
}
private:
- Buf* old_buf_;
- const std::vector<Expr*>& old_indices_;
- Buf* new_buf_;
- Var* reduction_var_;
- std::vector<Expr*> new_indices_;
+ BufPtr old_buf_;
+ const std::vector<ExprPtr>& old_indices_;
+ BufPtr new_buf_;
+ VarPtr reduction_var_;
+ std::vector<ExprPtr> new_indices_;
};
-bool LoopNest::rfactor(Stmt* st, For* target_for) {
- Buf* tmp_buf = nullptr;
+bool LoopNest::rfactor(StmtPtr st, ForPtr target_for) {
+ BufPtr tmp_buf = nullptr;
return rfactor(st, target_for, &tmp_buf);
}
-bool LoopNest::rfactor(Stmt* st, For* outer_reduction_for, Buf** rfac_buf_ptr) {
- Store* reduction_store = dynamic_cast<Store*>(st);
- ReduceOp* reduce_op = dynamic_cast<ReduceOp*>(reduction_store->value());
+bool LoopNest::rfactor(
+ StmtPtr st,
+ ForPtr outer_reduction_for,
+ BufPtr* rfac_buf_ptr) {
+ StorePtr reduction_store = to<Store>(st);
+ ReduceOpPtr reduce_op = to<ReduceOp>(reduction_store->value());
if (!reduce_op) {
// Not a reduction store
return false;
auto orig_buf = reduction_store->buf();
auto orig_buf_indices = reduction_store->indices();
- Var* reduction_var = outer_reduction_for->var();
+ VarPtr reduction_var = outer_reduction_for->var();
- std::set<Var*> reduce_args = {
+ std::set<VarPtr> reduce_args = {
reduce_op->reduce_args().begin(), reduce_op->reduce_args().end()};
if (reduce_args.size() < 2) {
// Verify that outer_reduction_for is a perfect loop nest with all loops being
// reductions
- Stmt* cur = outer_reduction_for;
- while (For* cur_for = dynamic_cast<For*>(cur)) {
+ StmtPtr cur = outer_reduction_for;
+ while (ForPtr cur_for = to<For>(cur)) {
if (!reduce_args.count(cur_for->var())) {
// output axis inside outer_reduction_for are not allowed
return false;
}
reduce_args.erase(cur_for->var());
- Block* b = cur_for->body();
+ BlockPtr b = cur_for->body();
if (b->nstmts() != 1) {
return false;
}
// assert: reduce_axis match loop vars from outer_reduction_for and inside
// assert: no other stmts in outer_reduction_for or its child loops
- std::vector<Expr*> rfac_dims = orig_buf->dims();
- Expr* extra_dim = IRSimplifier::simplify(
- new Sub(outer_reduction_for->stop(), outer_reduction_for->start()));
+ std::vector<ExprPtr> rfac_dims = orig_buf->dims();
+ ExprPtr extra_dim = IRSimplifier::simplify(
+ alloc<Sub>(outer_reduction_for->stop(), outer_reduction_for->start()));
rfac_dims.push_back(extra_dim);
- Expr* rfac_init =
- new Cast(reduce_op->dtype(), reduce_op->reducer().initializer());
+ ExprPtr rfac_init =
+ alloc<Cast>(reduce_op->dtype(), reduce_op->reducer().initializer());
- *rfac_buf_ptr = new Buf(
+ *rfac_buf_ptr = alloc<Buf>(
orig_buf->name_hint() + "_rfac",
rfac_dims,
reduce_op->dtype(),
rfac_init);
- Buf* rfac_buf = *rfac_buf_ptr;
+ BufPtr rfac_buf = *rfac_buf_ptr;
// Rewrite the original reduction store to use the temporary rfac buffer:
// 1) X[*indexes] --> T[*indexes + {reduction_var}]
// 2) reduce_axis -= {reduction_var}
RfactorStoreRewriter rfac_rewriter(
orig_buf, orig_buf_indices, rfac_buf, reduction_var);
- dynamic_cast<Block*>(st->get_parent())
+ to<Block>(st->get_parent())
->replace_stmt(st, st->accept_mutator(&rfac_rewriter));
// Insert a store for the final reduction over the temp buffer into the
// original buffer:
// X[*indexes] = ReduceOp(X[*indexes] + T[*indexes + {reduction_var}],
// reduce_axis={reduction_var})
- Block* b = outer_reduction_for->body();
+ BlockPtr b = outer_reduction_for->body();
TORCH_INTERNAL_ASSERT(b->nstmts() == 1);
- Stmt* first_reduction_loop = b->stmts().front();
+ StmtPtr first_reduction_loop = b->stmts().front();
auto rfac_buf_indices = orig_buf_indices;
rfac_buf_indices.emplace_back(reduction_var);
- Expr* final_reduce_load = new Load(rfac_buf, rfac_buf_indices);
+ ExprPtr final_reduce_load = alloc<Load>(rfac_buf, rfac_buf_indices);
outer_reduction_for->body()->insert_stmt_after(
- new Store(
+ alloc<Store>(
orig_buf,
orig_buf_indices,
reduce_op->reducer()(
// Insert an initialization store for the temp buffer:
// T[a,b,c] = init
outer_reduction_for->body()->insert_stmt_before(
- new Store(rfac_buf, rfac_buf_indices, rfac_init), first_reduction_loop);
+ alloc<Store>(rfac_buf, rfac_buf_indices, rfac_init),
+ first_reduction_loop);
return true;
}
#include <vector>
#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
namespace torch {
namespace jit {
// A constructor for building a LoopNest from a Stmt and a list of output
// buffers.
- LoopNest(Stmt* stmt, std::unordered_set<Buf*> output_bufs);
+ LoopNest(StmtPtr stmt, std::unordered_set<BufPtr> output_bufs);
// A constructor for building a LoopNest from another loopnest. It clones the
// other loopnest's stmt.
LoopNest(const LoopNest& other);
- Stmt* root_stmt() const {
+ StmtPtr root_stmt() const {
return root_stmt_;
}
- std::vector<For*> getLoopStmtsFor(Tensor*) const;
- std::vector<For*> getLoopStmtsFor(Buf*) const;
- std::vector<For*> getLoopStmtsFor(Stmt*) const;
- Stmt* getLoopBodyFor(Tensor*) const;
- Stmt* getLoopBodyFor(Buf*) const;
+ std::vector<ForPtr> getLoopStmtsFor(Tensor*) const;
+ std::vector<ForPtr> getLoopStmtsFor(BufPtr) const;
+ std::vector<ForPtr> getLoopStmtsFor(StmtPtr) const;
+ StmtPtr getLoopBodyFor(Tensor*) const;
+ StmtPtr getLoopBodyFor(BufPtr) const;
// Returns the For stmt indexed by 'indices' in the 'root' For stmt.
//'indices' indicates the path to the returned loop from 'root' in AST, e.g.,
// the path from 'root' to 'j_loop' is [0]
// the path from 'root' to 'k1_loop' is [0, 0]
// the path from 'root' to 'k2_loop' is [0, 2]
- For* getLoopAt(For* root, const std::vector<int>& indices) const;
+ ForPtr getLoopAt(ForPtr root, const std::vector<int>& indices) const;
// Returns the For stmt that is immediately enclosing the given stmt.
- static For* getParentLoop(Stmt* st);
+ static ForPtr getParentLoop(StmtPtr st);
// Returns the list of For stmts corresponding to the loopnest that is
// enclosing the given stmt.
- static std::vector<For*> getEnclosingLoopNest(Stmt* st);
+ static std::vector<ForPtr> getEnclosingLoopNest(StmtPtr st);
// Returns a list of all Stmts that write to the given buf.
- std::vector<Stmt*> getAllWritesToBuf(Buf*) const;
+ std::vector<StmtPtr> getAllWritesToBuf(BufPtr) const;
// The following methods return the For loops that contain writes to
// the given buf.
// to buf.
// For the above example:
// getAllInnermostLoopsWritingToBuf(a) => {j1, k2, j3}
- std::vector<For*> getAllInnermostLoopsWritingToBuf(Buf*) const;
+ std::vector<ForPtr> getAllInnermostLoopsWritingToBuf(BufPtr) const;
// Returns a list of For loopnests which contain a Stmt that writes to
// the given buf. Each loopnest here is a vector For loops.
// For the above example:
// getAllLoopNestsWritingToBuf(a) => {{i1,j1}, {i2,j2,k2}, {i2,j3}}
- std::vector<std::vector<For*>> getAllLoopNestsWritingToBuf(Buf*) const;
+ std::vector<std::vector<ForPtr>> getAllLoopNestsWritingToBuf(BufPtr) const;
- Stmt* simplify();
+ StmtPtr simplify();
- bool computeInline(Stmt* s);
- bool computeInline(Buf* b);
+ bool computeInline(StmtPtr s);
+ bool computeInline(BufPtr b);
void inlineIntermediateBufs(bool allow_duplicated_work);
// Optimizes conditionals.
// So, the pointer to the input loop should be valid after splitting and
// will point to the outer loop. The `inner` and `tail` parameters will be
// set to point to the inner and tail loops that are generated.
- static void splitWithTail(For* f, int factor, For** inner, For** tail);
+ static void splitWithTail(ForPtr f, int factor, ForPtr* inner, ForPtr* tail);
// A convenience wrapper when the caller does not need to access the
// split loops.
- static void splitWithTail(For* f, int factor);
+ static void splitWithTail(ForPtr f, int factor);
// Splits the given loop into 2 nested loops with the given factor as the
// inner loop bound. If the factor does not evenly divide the loop bound,
// So, the pointer to the input loop should be valid after splitting and
// will point to the outer loop. The `inner` parameter will be set to point
// to the inner loop that is generated.
- static void splitWithMask(For* f, int factor, For** inner);
+ static void splitWithMask(ForPtr f, int factor, ForPtr* inner);
// A convenience wrapper when the caller does not need to access the
// split loops.
- static void splitWithMask(For* f, int factor);
+ static void splitWithMask(ForPtr f, int factor);
// The following methods support loop distribution.
// For example, consider the following code. This will be used to
// : for i
// S6: for k
// S7: B[i] = B[i] +
- static std::vector<For*> distributeLoop(
- For* loop,
- const std::unordered_set<Stmt*>& pivots);
+ static std::vector<ForPtr> distributeLoop(
+ ForPtr loop,
+ const std::unordered_set<StmtPtr>& pivots);
// This method distributes the given loop over every stmt in its body.
//
// : for i
// S6: for k
// S7: B[i] = B[i] +
- static std::vector<For*> distributeLoop(For* loop);
+ static std::vector<ForPtr> distributeLoop(ForPtr loop);
// Same as above, but also distribute parent loops.
// Returns the result of distributing the outermost loop.
//
// : for i
// S6: for k
// S7: B[i] = B[i] +
- static std::vector<For*> distributeLoopAndParents(For* loop);
+ static std::vector<ForPtr> distributeLoopAndParents(ForPtr loop);
// This method distributes the given loop over its body by splitting
// after every For stmt in its body.
// S5: B[i] = A[i]
// S6: for k
// S7: B[i] = B[i] +
- static std::vector<For*> distributeLoopOverInnerLoops(For* loop);
+ static std::vector<ForPtr> distributeLoopOverInnerLoops(ForPtr loop);
// Same as above, but also distribute parent loops.
// Returns the result of distributing the outermost loop.
//
// S5: B[i] = A[i]
// S6: for k
// S7: B[i] = B[i] +
- static std::vector<For*> distributeLoopAndParentsOverInnerLoops(For* loop);
+ static std::vector<ForPtr> distributeLoopAndParentsOverInnerLoops(
+ ForPtr loop);
// This method performs loop fusion.
// For example, consider the following code.
// Below are the two requirements to apply unsafeFuseLoops:
// * All the loops have the same parent.
// * There are no statements between these loops in their parent body.
- static bool unsafeFuseLoops(const std::vector<For*>& loops, For** fused);
+ static bool unsafeFuseLoops(const std::vector<ForPtr>& loops, ForPtr* fused);
// Loop fusion is done only when all the conditions below are satisfied.
// * All the loops have the same parent.
// * The start bounds are the same for all loops.
// * The stop bounds are the same for all loops.
// * Fusing the loops does not violate or add any dependencies.
- static bool fuseLoops(const std::vector<For*>& loops, For** fused);
+ static bool fuseLoops(const std::vector<ForPtr>& loops, ForPtr* fused);
- static void reorderAxis(For* a, For* b);
+ static void reorderAxis(ForPtr a, ForPtr b);
// Reorder the given list of loops according to the permutation specified.
// Here `permutation[i]` represents the position of the loop in the input
// for p
// for q
// A[p,q,r,s] =
- static std::vector<For*> reorder(
- const std::vector<For*>& loops,
+ static std::vector<ForPtr> reorder(
+ const std::vector<ForPtr>& loops,
const std::vector<size_t>& permutation);
// Tile takes a 2d domain (x, y) and splits it into small rectangular blocks
// for k: (0, 32)
// A[i_outer * 4 + i_inner, 7 * 9 + j_tail] =
// B[i_outer * 4 + i_inner, k] + C[7 * 9 + j_tail, k]
- For* tile(For* x, For* y, int x_factor, int y_factor);
+ ForPtr tile(ForPtr x, ForPtr y, int x_factor, int y_factor);
// Returns true if the given loops are perfectly nested, i.e., every loop
// (except the innermost) should have exactly one statement in its body
// and that statement must be the next inner loop.
- static bool areLoopsPerfectlyNested(const std::vector<For*>& loops);
+ static bool areLoopsPerfectlyNested(const std::vector<ForPtr>& loops);
// Returns true if the given loop has a loop-carried dependence.
- static bool hasLoopCarriedDependence(For* loop);
+ static bool hasLoopCarriedDependence(ForPtr loop);
- static void unroll(For* f, Stmt** unrolled);
- static void unroll(For* f);
+ static void unroll(ForPtr f, StmtPtr* unrolled);
+ static void unroll(ForPtr f);
- static bool normalize(For* f);
- static bool isNormalized(For* f);
+ static bool normalize(ForPtr f);
+ static bool isNormalized(ForPtr f);
- static bool flatten(const std::vector<For*>& f, For** flattened);
- static bool flatten(const std::vector<For*>& f);
+ static bool flatten(const std::vector<ForPtr>& f, ForPtr* flattened);
+ static bool flatten(const std::vector<ForPtr>& f);
// Compresses the given buffer based on its use in the given Stmts.
//
// B[i,j] = A[0,j] + A[0, j+1]
// }
// }
- static void compressBuffer(Buf* buf, Stmt* stmt);
+ static void compressBuffer(BufPtr buf, StmtPtr stmt);
// Compresses all buffers in the given statement.
//
// kernel statement to avoid incorrect buffer compressions.
//
// TODO: Add an IR verifier check to detect invalidly compressed buffers.
- static void compressAllBuffers(Stmt* stmt);
+ static void compressAllBuffers(StmtPtr stmt);
// Get 'num' loops from the loopnest starting at 'f'.
- static std::vector<For*> getLoopStmtsInLoopNest(For* f, size_t num);
+ static std::vector<ForPtr> getLoopStmtsInLoopNest(ForPtr f, size_t num);
// LoopOptions are propagated to tail.
- static void sliceHead(For* f, int factor, For** head, For** tail);
- static void sliceHead(For* f, int factor);
+ static void sliceHead(ForPtr f, int factor, ForPtr* head, ForPtr* tail);
+ static void sliceHead(ForPtr f, int factor);
// LoopOptions are propagated to head.
- static void sliceTail(For* f, int factor, For** head, For** tail);
- static void sliceTail(For* f, int factor);
+ static void sliceTail(ForPtr f, int factor, ForPtr* head, ForPtr* tail);
+ static void sliceTail(ForPtr f, int factor);
- using AccessResult = std::pair<Buf*, Stmt*>;
+ using AccessResult = std::pair<BufPtr, StmtPtr>;
// Insert a cache for the consumer's usages of the buffer produced in
// consumer, and redirect reads and writes in the consumer to that cache.
// Returns a pair of the new cache buffer, and the new rewritten consumer.
static AccessResult cacheAccesses(
- Buf* producer,
+ BufPtr producer,
const std::string& name,
- Stmt* consumer);
+ StmtPtr consumer);
// Insert a temporary computation of statement S in the scope of loop AT.
// S is assumed to be a Store or a Block containing a Store. Along with the
// computation itself, this transformation inserts Alloc/Free statements for
// the temporary buffer used in the computation.
- static void computeAt(Stmt* s, For* at);
+ static void computeAt(StmtPtr s, ForPtr at);
// Rfactor a reduction axis into a normal axis.
//
// S4: for k # reduction axis
// X_rfac[i,j] = ReduceOp(X_rfac[i,j] + Y[i,j,k], reduce_axis={k})
// X[i] = ReduceOp(X[i] + X_rfac[i,j], reduce_axis={j})
- static bool rfactor(Stmt* s, For* outer_reduction_for);
- static bool rfactor(Stmt* s, For* outer_reduction_for, Buf** rfac_buf_ptr);
+ static bool rfactor(StmtPtr s, ForPtr outer_reduction_for);
+ static bool rfactor(
+ StmtPtr s,
+ ForPtr outer_reduction_for,
+ BufPtr* rfac_buf_ptr);
// Vectorize the given loop. This method requires that the given loop
// does not perform a reduction.
// It returns true if vectorization is successful and false otherwise.
- static bool vectorize(For*);
+ static bool vectorize(ForPtr);
// Find the inner-most loops and vectorize them. Currently, this only works
// for the LLVM backend, when no reductions are involved.
void eliminateDeadStores();
void prepareForCodegen();
- const std::unordered_set<Buf*> getInputBufs() const;
- const std::unordered_set<Buf*> getOutputBufs() const {
+ const std::unordered_set<BufPtr> getInputBufs() const;
+ const std::unordered_set<BufPtr> getOutputBufs() const {
return output_bufs_;
}
void initialize(
const std::vector<Tensor*>& output_tensors,
const std::vector<Tensor*>& tensors_to_compute);
- Stmt* insertAllocFree(Stmt* stmt);
- const std::unordered_set<Buf*> getIntermediateBufs() const;
+ StmtPtr insertAllocFree(StmtPtr stmt);
+ const std::unordered_set<BufPtr> getIntermediateBufs() const;
- Stmt* root_stmt_;
+ StmtPtr root_stmt_;
- std::unordered_set<Buf*> output_bufs_;
+ std::unordered_set<BufPtr> output_bufs_;
};
-TORCH_API Stmt* FlattenIndexes(Stmt* s);
+TORCH_API StmtPtr FlattenIndexes(StmtPtr s);
// TODO: Revisit this once we decide on how dependencies analysis should look
// like. Maybe we would choose to use a different API and BufUse would be
// removed, or if we decide to keep it we need to properly document its API.
struct BufLoadOrStoreUse {
- Stmt* s;
+ StmtPtr s;
bool isStore;
};
/*
* Returns a map ( Buf -> uses of this Buf), uses are represented as vectors of
- * BufUse elements, which are Stmt* and a bool isStore flag. The order of uses
+ * BufUse elements, which are StmtPtr and a bool isStore flag. The order of uses
* in the vectors reflects the order in which the uses appear in the given
* statement.
*/
-std::unordered_map<Buf*, std::vector<BufLoadOrStoreUse>> findLoadOrStoreUses(
- Stmt* s);
+std::unordered_map<BufPtr, std::vector<BufLoadOrStoreUse>> findLoadOrStoreUses(
+ StmtPtr s);
} // namespace tensorexpr
} // namespace jit
// AccessInfo
-std::vector<Expr*> AccessInfo::getIndices() const {
- std::vector<Expr*> indices;
+std::vector<ExprPtr> AccessInfo::getIndices() const {
+ std::vector<ExprPtr> indices;
if (expr_) {
- if (auto* load = dynamic_cast<Load*>(expr_)) {
+ if (auto load = to<Load>(expr_)) {
indices = load->indices();
}
} else {
- if (auto* store = dynamic_cast<Store*>(stmt_)) {
+ if (auto store = to<Store>(stmt_)) {
indices = store->indices();
}
}
os << "label = \"" << AccessToString(type_) << "\\n " << *var_ << "[";
if (bounds_.size() > 0) {
for (size_t i = 0; i < bounds_.size() - 1; ++i) {
- os << *IRSimplifier::simplify(new Add(bounds_[i].end, new IntImm(1)))
+ os << *IRSimplifier::simplify(
+ alloc<Add>(bounds_[i].end, alloc<IntImm>(1)))
<< ", ";
}
size_t i = bounds_.size() - 1;
- os << *IRSimplifier::simplify(new Add(bounds_[i].end, new IntImm(1)));
+ os << *IRSimplifier::simplify(
+ alloc<Add>(bounds_[i].end, alloc<IntImm>(1)));
os << "]\"\n ";
}
if (isWrite()) {
}
MemDependencyChecker::MemDependencyChecker(
- const std::unordered_set<Buf*>& inputs,
- const std::unordered_set<Buf*>& outputs) {
- for (auto* s : inputs) {
+ const std::unordered_set<BufPtr>& inputs,
+ const std::unordered_set<BufPtr>& outputs) {
+ for (auto s : inputs) {
inputs_[s] = nullptr;
}
- for (auto* s : outputs) {
+ for (auto s : outputs) {
outputs_[s] = nullptr;
}
return writes;
}
-bool MemDependencyChecker::dependsDirectly(Expr* A, Stmt* B) {
+bool MemDependencyChecker::dependsDirectly(ExprPtr A, StmtPtr B) {
return dependsDirectlyHelper(A, B);
}
-bool MemDependencyChecker::dependsDirectly(Stmt* A, Stmt* B) {
+bool MemDependencyChecker::dependsDirectly(StmtPtr A, StmtPtr B) {
return dependsDirectlyHelper(A, B);
}
-bool MemDependencyChecker::dependsDirectly(Buf* O, Stmt* B) {
+bool MemDependencyChecker::dependsDirectly(BufPtr O, StmtPtr B) {
auto outputAccess = output(O);
auto bWrites = getAllWritesWithin(B);
return false;
}
-bool MemDependencyChecker::dependsDirectly(Stmt* A, Buf* I) {
+bool MemDependencyChecker::dependsDirectly(StmtPtr A, BufPtr I) {
auto aReads = getAllReadsWithin(A);
auto inputAccess = input(I);
return false;
}
-bool MemDependencyChecker::dependsDirectly(Expr* A, Buf* I) {
+bool MemDependencyChecker::dependsDirectly(ExprPtr A, BufPtr I) {
auto aReads = getAllReadsWithin(A);
auto inputAccess = input(I);
return A->hasDependency(B) && B->isWrite();
}
-bool MemDependencyChecker::dependsIndirectly(Expr* A, Stmt* B) {
+bool MemDependencyChecker::dependsIndirectly(ExprPtr A, StmtPtr B) {
return dependsIndirectlyHelper(A, B);
}
-bool MemDependencyChecker::dependsIndirectly(Stmt* A, Stmt* B) {
+bool MemDependencyChecker::dependsIndirectly(StmtPtr A, StmtPtr B) {
return dependsIndirectlyHelper(A, B);
}
-bool MemDependencyChecker::dependsIndirectly(Buf* O, Stmt* B) {
+bool MemDependencyChecker::dependsIndirectly(BufPtr O, StmtPtr B) {
auto outputAccess = output(O);
DependencySet dependencies;
return false;
}
-bool MemDependencyChecker::dependsIndirectly(Stmt* A, Buf* I) {
+bool MemDependencyChecker::dependsIndirectly(StmtPtr A, BufPtr I) {
auto aReads = getAllReadsWithin(A);
auto inputAccess = input(I);
return aDeps.count(inputAccess) != 0;
}
-bool MemDependencyChecker::dependsIndirectly(Expr* A, Buf* I) {
+bool MemDependencyChecker::dependsIndirectly(ExprPtr A, BufPtr I) {
auto aReads = getAllReadsWithin(A);
auto inputAccess = input(I);
return aDeps.count(inputAccess) != 0;
}
-bool MemDependencyChecker::dependsIndirectly(Buf* O, Buf* I) {
+bool MemDependencyChecker::dependsIndirectly(BufPtr O, BufPtr I) {
auto outputAccess = output(O);
auto inputAccess = input(I);
return true;
}
-std::shared_ptr<AccessInfo> MemDependencyChecker::accessFor(Stmt* A) const {
+std::shared_ptr<AccessInfo> MemDependencyChecker::accessFor(StmtPtr A) const {
auto bound = stmtToAccess_.equal_range(A);
for (auto it = bound.first; it != bound.second; ++it) {
if (it->second->expr() == nullptr) {
return nullptr;
}
-std::shared_ptr<AccessInfo> MemDependencyChecker::accessFor(Expr* A) const {
+std::shared_ptr<AccessInfo> MemDependencyChecker::accessFor(ExprPtr A) const {
// TODO exprs can have multiple accesses... we're returning the first but that
// isn't great. Can't do much here.
auto bound = exprToAccess_.equal_range(A);
}
std::unordered_set<std::shared_ptr<AccessInfo>> MemDependencyChecker::
- accessesWithin(Stmt* A) const {
+ accessesWithin(StmtPtr A) const {
auto it = scopeToAccesses_.find(A);
if (it != scopeToAccesses_.end()) {
return std::unordered_set<std::shared_ptr<AccessInfo>>(
}
std::unordered_set<std::shared_ptr<AccessInfo>> MemDependencyChecker::
- accessesWithin(Expr* A) const {
+ accessesWithin(ExprPtr A) const {
return {accessFor(A)};
}
-std::shared_ptr<AccessInfo> MemDependencyChecker::input(Buf* b) const {
+std::shared_ptr<AccessInfo> MemDependencyChecker::input(BufPtr b) const {
auto it = inputs_.find(b);
if (it == inputs_.end()) {
return nullptr;
return it->second;
}
-std::shared_ptr<AccessInfo> MemDependencyChecker::output(Buf* b) const {
+std::shared_ptr<AccessInfo> MemDependencyChecker::output(BufPtr b) const {
auto it = outputs_.find(b);
if (it == outputs_.end()) {
return nullptr;
// Node visitors:
-void MemDependencyChecker::visit(Store* v) {
- Stmt* last = lastStmt_;
+void MemDependencyChecker::visit(StorePtr v) {
+ StmtPtr last = lastStmt_;
lastStmt_ = v;
v->value()->accept(this);
- for (Expr* ind : v->indices()) {
+ for (ExprPtr ind : v->indices()) {
ind->accept(this);
}
lastStmt_ = last;
// Create a new AccessInfo for the store.
- Var* var = v->buf()->base_handle();
+ VarPtr var = v->buf()->base_handle();
auto info = std::make_shared<AccessInfo>(
nextAccess_++, AccessType::Store, v, var, getIndicesBounds(v->indices()));
currentScope_->accesses_.push_back(info);
}
-void MemDependencyChecker::visit(Load* v) {
+void MemDependencyChecker::visit(LoadPtr v) {
// Create a temporary scope to hold any loads that occur within the indices of
// this load.
auto indicesScope =
std::make_shared<Scope>(currentScope_->block, currentScope_);
currentScope_ = indicesScope;
- for (Expr* ind : v->indices()) {
+ for (ExprPtr ind : v->indices()) {
ind->accept(this);
}
// Create a new AccessInfo for the load.
- Var* var = v->buf()->base_handle();
+ VarPtr var = v->buf()->base_handle();
auto load = std::make_shared<AccessInfo>(
nextAccess_++,
AccessType::Load,
bool executionSafetyCheck(
const std::shared_ptr<AccessInfo>& info,
const std::shared_ptr<AccessInfo>& other,
- const std::vector<Expr*>& aStrides,
- const std::vector<Expr*>& oStrides,
+ const std::vector<ExprPtr>& aStrides,
+ const std::vector<ExprPtr>& oStrides,
bool parallelized) {
if (aStrides.empty() || oStrides.empty()) {
return false;
}
TORCH_INTERNAL_ASSERT(info->bounds().size() == other->bounds().size());
for (size_t b = 0; b < info->bounds().size(); ++b) {
- Expr* aIndexStride = aStrides[b];
- Expr* oIndexStride = oStrides[b];
+ ExprPtr aIndexStride = aStrides[b];
+ ExprPtr oIndexStride = oStrides[b];
// can't be safe on this index if we can't determine stride.
if (!aIndexStride->isConstant() || !oIndexStride->isConstant()) {
continue;
}
- Expr* minStride =
- IRSimplifier::simplify(new Min(aIndexStride, oIndexStride, true));
- Expr* maxStride =
- IRSimplifier::simplify(new Max(aIndexStride, oIndexStride, true));
+ ExprPtr minStride =
+ IRSimplifier::simplify(alloc<Min>(aIndexStride, oIndexStride, true));
+ ExprPtr maxStride =
+ IRSimplifier::simplify(alloc<Max>(aIndexStride, oIndexStride, true));
// If the first access has no stride, don't apply safety.
if (immediateEquals(minStride, 0)) {
continue;
}
- Expr* modCheck = IRSimplifier::simplify(new Mod(maxStride, minStride));
+ ExprPtr modCheck = IRSimplifier::simplify(alloc<Mod>(maxStride, minStride));
// if the strides can't have easily inferable distinct offsets, they're not
// safe.
// axis is the same sign as the common stride, then they will not
// overlap.
- Expr* startDiff = IRSimplifier::simplify(
- new Sub(info->bounds()[b].start, other->bounds()[b].start));
+ ExprPtr startDiff = IRSimplifier::simplify(
+ alloc<Sub>(info->bounds()[b].start, other->bounds()[b].start));
bool diffNegative = immediateIsNegative(startDiff);
bool strideNegative = immediateIsNegative(minStride);
// Invert the startDiff so mod works.
if (diffNegative != strideNegative) {
- startDiff = IRSimplifier::simplify(new Sub(new IntImm(0), startDiff));
+ startDiff =
+ IRSimplifier::simplify(alloc<Sub>(alloc<IntImm>(0), startDiff));
}
// If both accesses have the same stride, and the difference in start
// element is smaller than this stride then the entire range is distinct.
if (exprEquals(minStride, maxStride)) {
- Expr* check1 =
- IRSimplifier::simplify(new CompareSelect(startDiff, minStride, kLT));
+ ExprPtr check1 = IRSimplifier::simplify(
+ alloc<CompareSelect>(startDiff, minStride, kLT));
if (check1->isConstant() && immediateEquals(check1, 1)) {
return true;
}
}
- startDiff = IRSimplifier::simplify(new Mod(startDiff, minStride));
+ startDiff = IRSimplifier::simplify(alloc<Mod>(startDiff, minStride));
CompareSelectOperation op = strideNegative ? kLT : kGT;
- Expr* check =
- IRSimplifier::simplify(new CompareSelect(startDiff, new IntImm(0), op));
+ ExprPtr check = IRSimplifier::simplify(
+ alloc<CompareSelect>(startDiff, alloc<IntImm>(0), op));
// If the start difference modulo the minimum stride is offset from that
// stride, then the ranges have distinct strides.
return false;
}
-void MemDependencyChecker::visit(For* v) {
- Var* var = v->var();
+void MemDependencyChecker::visit(ForPtr v) {
+ VarPtr var = v->var();
- Stmt* last = lastStmt_;
+ StmtPtr last = lastStmt_;
lastStmt_ = v;
v->var()->accept(this);
// access, which we do via substituting the loop var with (var+1) into the
// indices expr.
- std::vector<std::vector<Expr*>> loopStrides;
+ std::vector<std::vector<ExprPtr>> loopStrides;
loopStrides.resize(currentScope_->accesses_.size());
for (size_t a = 0; a < currentScope_->accesses_.size(); ++a) {
auto& info = currentScope_->accesses_[a];
- std::vector<Expr*> indices = info->getIndices();
+ std::vector<ExprPtr> indices = info->getIndices();
- std::vector<Expr*>& loopIndicesStride = loopStrides[a];
+ std::vector<ExprPtr>& loopIndicesStride = loopStrides[a];
loopIndicesStride.resize(indices.size());
// index expr must depend on the loop var in some way to have a stride.
for (const auto i : c10::irange(indices.size())) {
VarFinder vf;
if (vf.find(indices[i]).count(var) == 0) {
- loopIndicesStride[i] = new IntImm(0);
+ loopIndicesStride[i] = alloc<IntImm>(0);
} else {
// If we've previously swapped the start and end of this bound, we
// should apply the substitution to the reverse of the bounds.
SubstituteInClone(info->bounds()[i].end, {{var, v->start()}}));
info->bounds()[i].start = IRSimplifier::simplify(SubstituteInClone(
info->bounds()[i].start,
- {{var, new Sub(v->stop(), new IntImm(1))}}));
+ {{var, alloc<Sub>(v->stop(), alloc<IntImm>(1))}}));
} else {
info->bounds()[i].start = IRSimplifier::simplify(
SubstituteInClone(info->bounds()[i].start, {{var, v->start()}}));
info->bounds()[i].end = IRSimplifier::simplify(SubstituteInClone(
info->bounds()[i].end,
- {{var, new Sub(v->stop(), new IntImm(1))}}));
+ {{var, alloc<Sub>(v->stop(), alloc<IntImm>(1))}}));
}
- Expr* zeroStep = indices[i];
- Expr* oneStep =
- SubstituteInClone(indices[i], {{var, new Add(var, new IntImm(1))}});
+ ExprPtr zeroStep = indices[i];
+ ExprPtr oneStep = SubstituteInClone(
+ indices[i], {{var, alloc<Add>(var, alloc<IntImm>(1))}});
loopIndicesStride[i] =
- IRSimplifier::simplify(new Sub(oneStep, zeroStep));
+ IRSimplifier::simplify(alloc<Sub>(oneStep, zeroStep));
// If the start < end then swap the order of the bound.
- Expr* diff = IRSimplifier::simplify(
- new Sub(info->bounds()[i].end, info->bounds()[i].start));
+ ExprPtr diff = IRSimplifier::simplify(
+ alloc<Sub>(info->bounds()[i].end, info->bounds()[i].start));
if (diff->isConstant() && immediateIsNegative(diff)) {
info->bounds()[i].swap();
}
bound.start = IRSimplifier::simplify(
SubstituteInClone(bound.start, {{var, v->start()}}));
bound.end = IRSimplifier::simplify(SubstituteInClone(
- bound.end, {{var, new Sub(v->stop(), new IntImm(1))}}));
+ bound.end, {{var, alloc<Sub>(v->stop(), alloc<IntImm>(1))}}));
// If the start < end then swap the order of the bound.
- Expr* diff = IRSimplifier::simplify(new Sub(bound.end, bound.start));
+ ExprPtr diff =
+ IRSimplifier::simplify(alloc<Sub>(bound.end, bound.start));
if (diff->isConstant() && immediateIsNegative(diff)) {
bound.swap();
}
v->loop_options().is_gpu_thread_index();
// Store buffers allocated at this scope.
- std::unordered_set<Var*> local_intermediates;
+ std::unordered_set<VarPtr> local_intermediates;
// Scanning from the top of the loop, we look for accesses which may depend
// on a previous or parallel loop iteration.
currentScope_ = currentScope_->parent;
}
-void MemDependencyChecker::visit(Cond* v) {
- Stmt* last = lastStmt_;
+void MemDependencyChecker::visit(CondPtr v) {
+ StmtPtr last = lastStmt_;
lastStmt_ = v;
auto enclosingScope =
// condition is in enclosing scope.
v->condition()->accept(this);
- Block* true_stmt = v->true_stmt();
- Block* false_stmt = v->false_stmt();
+ BlockPtr true_stmt = v->true_stmt();
+ BlockPtr false_stmt = v->false_stmt();
// Create scopes so the Block visitor doesn't create and merge a new scope.
auto trueScope = std::make_shared<Scope>(true_stmt, enclosingScope);
lastStmt_ = last;
}
-void MemDependencyChecker::visit(IfThenElse* v) {
+void MemDependencyChecker::visit(IfThenElsePtr v) {
// condition is in enclosing scope.
v->condition()->accept(this);
- Expr* true_value = v->true_value();
- Expr* false_value = v->false_value();
+ ExprPtr true_value = v->true_value();
+ ExprPtr false_value = v->false_value();
auto enclosingScope = currentScope_;
currentScope_ = enclosingScope;
}
-void MemDependencyChecker::visit(CompareSelect* v) {
+void MemDependencyChecker::visit(CompareSelectPtr v) {
// condition is in enclosing scope.
v->lhs()->accept(this);
v->rhs()->accept(this);
- Expr* true_value = v->ret_val1();
- Expr* false_value = v->ret_val2();
+ ExprPtr true_value = v->ret_val1();
+ ExprPtr false_value = v->ret_val2();
auto enclosingScope = currentScope_;
// Inserts accesses for a map of buffers (ie. for inputs and outputs).
void MemDependencyChecker::insertBuffers(
- std::unordered_map<Buf*, std::shared_ptr<AccessInfo>>& bufs,
+ std::unordered_map<BufPtr, std::shared_ptr<AccessInfo>>& bufs,
AccessType type) {
for (auto& pair : bufs) {
- Buf* b = pair.first;
- Var* var = b->base_handle();
+ BufPtr b = pair.first;
+ VarPtr var = b->base_handle();
IndexBounds bounds;
- for (auto* d : b->dims()) {
+ for (auto d : b->dims()) {
bounds.push_back(
- {new IntImm(0), IRSimplifier::simplify(new Sub(d, new IntImm(1)))});
+ {alloc<IntImm>(0),
+ IRSimplifier::simplify(alloc<Sub>(d, alloc<IntImm>(1)))});
}
auto info =
std::make_shared<AccessInfo>(nextAccess_++, type, nullptr, var, bounds);
}
}
-void MemDependencyChecker::visit(Block* v) {
+void MemDependencyChecker::visit(BlockPtr v) {
auto prev_scope = currentScope_;
// handle kernel inputs.
}
if (currentScope_->block != v) {
- currentScope_ = std::make_shared<Scope>((Block*)v, prev_scope);
+ currentScope_ = std::make_shared<Scope>((BlockPtr)v, prev_scope);
}
- for (auto* s : *v) {
+ for (auto s : *v) {
s->accept(this);
}
- for (auto* v : currentScope_->localVars) {
+ for (auto v : currentScope_->localVars) {
knownVarBounds_.erase(v);
}
for (auto& pair : currentScope_->shadowedVarBounds) {
}
}
-void MemDependencyChecker::visit(Let* v) {
- Stmt* last = lastStmt_;
+void MemDependencyChecker::visit(LetPtr v) {
+ StmtPtr last = lastStmt_;
lastStmt_ = v;
IRVisitor::visit(v);
lastStmt_ = last;
- Var* var = v->var();
+ VarPtr var = v->var();
if (knownVarBounds_.count(var) != 0) {
currentScope_->shadowedVarBounds[var] = knownVarBounds_[var];
}
// Don't support AtomicAdd yet, it's a bit more complex since it's both a read
// and a write. It's only inserted during Cuda codegen so this should be okay.
// Patch note: only the parameter type changes (raw AtomicAdd* -> AtomicAddPtr
// alias); the unconditional throw is preserved byte-for-byte.
-void MemDependencyChecker::visit(AtomicAdd* v) {
+void MemDependencyChecker::visit(AtomicAddPtr v) {
throw std::runtime_error("MemDependencyChecker AtomicAdd unimplemented");
}
-void MemDependencyChecker::visit(Allocate* v) {
- Stmt* last = lastStmt_;
+void MemDependencyChecker::visit(AllocatePtr v) {
+ StmtPtr last = lastStmt_;
lastStmt_ = v;
IRVisitor::visit(v);
- Var* var = v->buffer_var();
+ VarPtr var = v->buffer_var();
IndexBounds bounds;
// TODO: remove the "buf_flat_size" process below and extend the buf bound
// check to support N-d indices access and 1-d index access.
// identify 1-d index access for N-d bufs. Thus we flatten N-d bufs here to
// avoid failing the bound check. But this is not the correct approach and
// should be fixed.
- Expr* flat_size = buf_flat_size(v->buf());
- flat_size = IRSimplifier::simplify(new Sub(flat_size, new IntImm(1)));
- bounds.push_back({new IntImm(0), flat_size});
+ ExprPtr flat_size = buf_flat_size(v->buf());
+ flat_size = IRSimplifier::simplify(alloc<Sub>(flat_size, alloc<IntImm>(1)));
+ bounds.push_back({alloc<IntImm>(0), flat_size});
auto info = std::make_shared<AccessInfo>(
nextAccess_++, AccessType::Alloc, nullptr, var, bounds);
lastStmt_ = last;
}
-void MemDependencyChecker::visit(Free* v) {
- Stmt* last = lastStmt_;
+void MemDependencyChecker::visit(FreePtr v) {
+ StmtPtr last = lastStmt_;
lastStmt_ = v;
IRVisitor::visit(v);
- Var* var = v->buffer_var();
+ VarPtr var = v->buffer_var();
auto it = intermediates_.find(var);
TORCH_INTERNAL_ASSERT(it != intermediates_.end());
// Copy open writes up.
for (auto& pair : child->openWrites_) {
- Var* var = pair.first;
+ VarPtr var = pair.first;
// Intentionally using operator[], we want it to be created if it does not
// exist.
public:
VarBoundBinder(const VarBoundMap& vars) : vars_(vars) {}
- Bound getBounds(Expr* e) {
+ Bound getBounds(ExprPtr e) {
min_ = e;
max_ = e;
e->accept(this);
}
private:
- void visit(Var* v) override {
+ void visit(VarPtr v) override {
auto it = vars_.find(v);
if (it == vars_.end()) {
return;
max_ = SubstituteInClone(max_, {{v, it->second.end}});
}
- Expr* min_{nullptr};
- Expr* max_{nullptr};
+ ExprPtr min_{nullptr};
+ ExprPtr max_{nullptr};
const VarBoundMap& vars_;
};
std::vector<Bound> MemDependencyChecker::getIndicesBounds(
- const std::vector<Expr*>& indices) {
+ const std::vector<ExprPtr>& indices) {
std::vector<Bound> bounds;
bounds.reserve(indices.size());
VarBoundBinder binder(knownVarBounds_);
- for (auto* s : indices) {
+ for (auto s : indices) {
bounds.push_back(binder.getBounds(s));
}
return bounds;
AccessInfo(
size_t id,
AccessType type,
- Stmt* stmt,
- Var* var,
+ StmtPtr stmt,
+ VarPtr var,
IndexBounds bounds)
: id_(id),
type_(type),
AccessInfo(
size_t id,
AccessType type,
- Expr* expr,
- Stmt* stmt,
- Var* var,
+ ExprPtr expr,
+ StmtPtr stmt,
+ VarPtr var,
IndexBounds bounds)
: id_(id),
type_(type),
// The enclosing Stmt this access represents. E.g. if this is a Store then
// Stmt is the Store itself, while if the access is caused by an Expr, this is
// the most immediate parent Stmt.
- Stmt* stmt() const {
+ StmtPtr stmt() const {
return stmt_;
}
// If the access is represented by an Expr (such as Load or Call) then this is
// it, otherwise it's nullptr.
- Expr* expr() const {
+ ExprPtr expr() const {
return expr_;
}
// The Var representing the underlying Buffer.
- Var* var() const {
+ VarPtr var() const {
return var_;
}
}
// Returns the symbolic expression of the indices of this access.
- std::vector<Expr*> getIndices() const;
+ std::vector<ExprPtr> getIndices() const;
// Establishes a dependency or dependent relationship with another access.
void addDependency(const std::shared_ptr<AccessInfo>& write);
private:
size_t id_;
AccessType type_;
- Stmt* stmt_;
- Expr* expr_;
- Var* var_;
+ StmtPtr stmt_;
+ ExprPtr expr_;
+ VarPtr var_;
IndexBounds bounds_;
// Yes these should be sorted.
std::map<size_t, std::shared_ptr<AccessInfo>> dependents_;
};
-using VarBoundMap = std::unordered_map<Var*, Bound>;
+using VarBoundMap = std::unordered_map<VarPtr, Bound>;
/* MemDepedencyChecker analyses a IR fragment and builds a dependency graph of
* accesses contained within.
public:
MemDependencyChecker();
MemDependencyChecker(
- const std::unordered_set<Buf*>& inputs,
- const std::unordered_set<Buf*>& outputs);
+ const std::unordered_set<BufPtr>& inputs,
+ const std::unordered_set<BufPtr>& outputs);
MemDependencyChecker(
const std::vector<BufHandle>& inputs,
const std::vector<BufHandle>& outputs);
// about it.
// Returns true if any read in A has a direct dependence on a write in B.
- bool dependsDirectly(Stmt* A, Stmt* B);
- bool dependsDirectly(Expr* A, Stmt* B);
+ bool dependsDirectly(StmtPtr A, StmtPtr B);
+ bool dependsDirectly(ExprPtr A, StmtPtr B);
// Returns true of the output depends directly on a write contained in B.
- bool dependsDirectly(Buf* output, Stmt* B);
+ bool dependsDirectly(BufPtr output, StmtPtr B);
// Returns true if a read in A depends directly on the provided input.
- bool dependsDirectly(Stmt* A, Buf* input);
- bool dependsDirectly(Expr* A, Buf* input);
+ bool dependsDirectly(StmtPtr A, BufPtr input);
+ bool dependsDirectly(ExprPtr A, BufPtr input);
// Outputs/inputs cannot depend directly.
const std::shared_ptr<AccessInfo>& B);
// Returns true if any read in A has an ancestor write contained in B.
- bool dependsIndirectly(Stmt* A, Stmt* B);
- bool dependsIndirectly(Expr* A, Stmt* B);
+ bool dependsIndirectly(StmtPtr A, StmtPtr B);
+ bool dependsIndirectly(ExprPtr A, StmtPtr B);
// Returns true of the output depends indirectly on a write contained in B.
- bool dependsIndirectly(Buf* output, Stmt* B);
+ bool dependsIndirectly(BufPtr output, StmtPtr B);
// Returns true if a read in A depends indirectly on the provided input.
- bool dependsIndirectly(Stmt* A, Buf* input);
- bool dependsIndirectly(Expr* A, Buf* input);
+ bool dependsIndirectly(StmtPtr A, BufPtr input);
+ bool dependsIndirectly(ExprPtr A, BufPtr input);
// returns true if the output uses any load of the input.
- bool dependsIndirectly(Buf* output, Buf* input);
+ bool dependsIndirectly(BufPtr output, BufPtr input);
// Returns true if the access A has a dependency chain to access B.
bool dependsIndirectly(
const std::shared_ptr<AccessInfo>& B);
// Returns the AccessInfo
- std::shared_ptr<AccessInfo> accessFor(Stmt* A) const;
- std::shared_ptr<AccessInfo> accessFor(Expr* A) const;
+ std::shared_ptr<AccessInfo> accessFor(StmtPtr A) const;
+ std::shared_ptr<AccessInfo> accessFor(ExprPtr A) const;
// Returns all AccessInfos.
- std::unordered_set<std::shared_ptr<AccessInfo>> accessesWithin(Stmt* A) const;
+ std::unordered_set<std::shared_ptr<AccessInfo>> accessesWithin(
+ StmtPtr A) const;
// TODO: this will return only the AccessInfo for A. It's included for
// completeness but be aware it wont return accesses used in the computation
// of A.
- std::unordered_set<std::shared_ptr<AccessInfo>> accessesWithin(Expr* A) const;
+ std::unordered_set<std::shared_ptr<AccessInfo>> accessesWithin(
+ ExprPtr A) const;
// Accesses relating to input and output buffers.
- std::shared_ptr<AccessInfo> input(Buf* B) const;
- std::shared_ptr<AccessInfo> output(Buf* B) const;
+ std::shared_ptr<AccessInfo> input(BufPtr B) const;
+ std::shared_ptr<AccessInfo> output(BufPtr B) const;
// Returns the full history of reads and writes.
const std::vector<std::shared_ptr<AccessInfo>>& getHistory() const;
private:
// Node visitors.
- void visit(Store* v) override;
- void visit(Load* v) override;
- void visit(For* v) override;
- void visit(Cond* v) override;
- void visit(IfThenElse* v) override;
- void visit(CompareSelect* v) override;
- void visit(Block* v) override;
- void visit(Let* v) override;
- void visit(AtomicAdd* v) override;
- void visit(Allocate* v) override;
- void visit(Free* v) override;
+ void visit(StorePtr v) override;
+ void visit(LoadPtr v) override;
+ void visit(ForPtr v) override;
+ void visit(CondPtr v) override;
+ void visit(IfThenElsePtr v) override;
+ void visit(CompareSelectPtr v) override;
+ void visit(BlockPtr v) override;
+ void visit(LetPtr v) override;
+ void visit(AtomicAddPtr v) override;
+ void visit(AllocatePtr v) override;
+ void visit(FreePtr v) override;
using BoundRelationship = std::pair<IndexBounds, std::shared_ptr<AccessInfo>>;
// An internal struct holding the accesses found within a scope Block.
struct Scope {
- Scope(Block* b, std::shared_ptr<Scope> p)
+ Scope(BlockPtr b, std::shared_ptr<Scope> p)
: block(b), parent(std::move(p)) {}
- Block* block;
+ BlockPtr block;
std::shared_ptr<Scope> parent;
- std::unordered_map<Var*, Bound> shadowedVarBounds;
- std::unordered_set<Var*> localVars;
+ std::unordered_map<VarPtr, Bound> shadowedVarBounds;
+ std::unordered_set<VarPtr> localVars;
std::vector<std::shared_ptr<AccessInfo>> accesses_;
- std::unordered_map<Var*, std::list<BoundRelationship>> openWrites_;
+ std::unordered_map<VarPtr, std::list<BoundRelationship>> openWrites_;
};
std::shared_ptr<Scope> currentScope_;
bool allowExecutionOrderAnalysis_{false};
- std::unordered_multimap<Stmt*, std::shared_ptr<AccessInfo>> stmtToAccess_;
- std::unordered_multimap<Expr*, std::shared_ptr<AccessInfo>> exprToAccess_;
- std::unordered_map<Stmt*, std::vector<std::shared_ptr<AccessInfo>>>
+ std::unordered_multimap<StmtPtr, std::shared_ptr<AccessInfo>> stmtToAccess_;
+ std::unordered_multimap<ExprPtr, std::shared_ptr<AccessInfo>> exprToAccess_;
+ std::unordered_map<StmtPtr, std::vector<std::shared_ptr<AccessInfo>>>
scopeToAccesses_;
VarBoundMap knownVarBounds_;
// Finds all accesses that are reads within the scope of v.
- template <typename StmtOrExpr>
- DependencySet getAllReadsWithin(StmtOrExpr* v) {
+ template <typename StmtOrExprPtr>
+ DependencySet getAllReadsWithin(StmtOrExprPtr v) {
DependencySet reads;
auto insertAllReads = [&](const auto& nodes) {
for (auto* l : nodes) {
// Finds all accesses that are writes within the scope of v.
// Writes cannot occur in Exprs, so this is a little simpler.
- DependencySet getAllWritesWithin(Stmt* v) {
+ DependencySet getAllWritesWithin(StmtPtr v) {
DependencySet writes;
// writes just Store currently.
}
// Templated helpers to work on either Exprs or Stmts.
- template <typename StmtOrExpr>
- bool dependsDirectlyHelper(StmtOrExpr* A, Stmt* B) {
+ template <typename StmtOrExprPtr>
+ bool dependsDirectlyHelper(StmtOrExprPtr A, StmtPtr B) {
auto aReads = getAllReadsWithin(A);
auto bWrites = getAllWritesWithin(B);
return false;
}
- template <typename StmtOrExpr>
- bool dependsIndirectlyHelper(StmtOrExpr* A, Stmt* B) {
+ template <typename StmtOrExprPtr>
+ bool dependsIndirectlyHelper(StmtOrExprPtr A, StmtPtr B) {
auto aReads = getAllReadsWithin(A);
auto bWrites = getAllWritesWithin(B);
DependencySet getAllWriteDependencies(const DependencySet& products);
// Maps for inputs and outputs, since they aren't present directly in the IR.
- std::unordered_map<Buf*, std::shared_ptr<AccessInfo>> inputs_;
- std::unordered_map<Buf*, std::shared_ptr<AccessInfo>> outputs_;
- std::unordered_map<Var*, std::shared_ptr<AccessInfo>> intermediates_;
+ std::unordered_map<BufPtr, std::shared_ptr<AccessInfo>> inputs_;
+ std::unordered_map<BufPtr, std::shared_ptr<AccessInfo>> outputs_;
+ std::unordered_map<VarPtr, std::shared_ptr<AccessInfo>> intermediates_;
// Inserts accesses for Buf's: specifically for inputs and outputs.
void insertBuffers(
- std::unordered_map<Buf*, std::shared_ptr<AccessInfo>>& bufs,
+ std::unordered_map<BufPtr, std::shared_ptr<AccessInfo>>& bufs,
AccessType type);
// Update the write history with a new write, adding dependencies and closing
bool closeOverlapped = true);
// Binds symbolic vars in indices with the low and high bound for those vars.
- std::vector<Bound> getIndicesBounds(const std::vector<Expr*>& indices);
+ std::vector<Bound> getIndicesBounds(const std::vector<ExprPtr>& indices);
size_t nextAccess_{0};
- Stmt* lastStmt_{nullptr};
+ StmtPtr lastStmt_{nullptr};
};
} // namespace analysis
constexpr int kLoopH = 2, kLoopW = 3;
if (R == 3 && stride == 2 && pad == 1) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For *head, *tail;
+ ForPtr head, tail;
auto loops = nest.getLoopStmtsFor(conv);
nest.sliceHead(loops[kLoopW], 2, &head, &tail);
loops = nest.getLoopStmtsFor(conv);
nest.sliceHead(loops[kLoopH], 2, &head, &tail);
} else if (R == 3 && stride == 1 && pad == 1) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- For *main, *peeled;
+ ForPtr main, peeled;
auto loops = nest.getAllLoopNestsWritingToBuf(conv->buf());
main = loops[1][kLoopW];
nest.sliceHead(main, 1, &peeled, &main);
auto size_b = b.dims();
// We currently only support rank 2 matmuls
TORCH_INTERNAL_ASSERT(size_a.size() == 2 && size_b.size() == 2);
- auto total_size = dynamic_cast<LongImm*>(
- IRSimplifier::simplify(
- cast<int64_t>(size_a[0]) * cast<int64_t>(size_a[1]) *
- cast<int64_t>(size_b[1]))
- .node());
+ auto total_size =
+ to<LongImm>(IRSimplifier::simplify(
+ cast<int64_t>(size_a[0]) * cast<int64_t>(size_a[1]) *
+ cast<int64_t>(size_b[1]))
+ .node());
// For small sizes, where N*M*K < 1000, lower matmul to a naive 3-level
// loopnest. The number is not tuned very carefully, and in future we should
});
return new Tensor(
result->buf(),
- new tensorexpr::Block(
- {max->stmt(), e->stmt(), sum->stmt(), result->stmt()}));
+ alloc<tensorexpr::Block>(std::vector<StmtPtr>(
+ {max->stmt(), e->stmt(), sum->stmt(), result->stmt()})));
}
auto log_sum = Compute(
});
return new Tensor(
result->buf(),
- new tensorexpr::Block(
+ alloc<tensorexpr::Block>(std::vector<StmtPtr>(
{max->stmt(),
e->stmt(),
sum->stmt(),
log_sum->stmt(),
- result->stmt()}));
+ result->stmt()})));
}
} // namespace tensorexpr
namespace jit {
namespace tensorexpr {
// Builds a ReduceOp over `result_buf` from an ExprHandle body, the output
// index expressions, and the reduction-axis vars. Patch note: return and
// parameter types move to the ReduceOpPtr/BufPtr/ExprPtr/VarPtr aliases and
// the bare `new ReduceOp` becomes the alloc<ReduceOp> factory; the call to
// complete(...) and its argument order are unchanged.
-ReduceOp* Reducer::operator()(
- Buf* result_buf,
+ReduceOpPtr Reducer::operator()(
+ BufPtr result_buf,
ExprHandle body,
- const std::vector<Expr*>& output,
- const std::vector<Var*>& inner) const {
- return new ReduceOp(
+ const std::vector<ExprPtr>& output,
+ const std::vector<VarPtr>& inner) const {
+ return alloc<ReduceOp>(
complete(result_buf, interaction_, body, output, inner), inner, *this);
}
// Overload taking a raw body expression instead of an ExprHandle; it wraps
// the body in ExprHandle(body) before delegating to complete(...). Patch
// note: same pointer-alias migration as the ExprHandle overload above, with
// `new ReduceOp` replaced by alloc<ReduceOp>; behavior is unchanged.
-ReduceOp* Reducer::operator()(
- Buf* result_buf,
- Expr* body,
- const std::vector<Expr*>& output,
- const std::vector<Var*>& inner) const {
- return new ReduceOp(
+ReduceOpPtr Reducer::operator()(
+ BufPtr result_buf,
+ ExprPtr body,
+ const std::vector<ExprPtr>& output,
+ const std::vector<VarPtr>& inner) const {
+ return alloc<ReduceOp>(
complete(result_buf, interaction_, ExprHandle(body), output, inner),
inner,
*this);
}
virtual ~Reducer() = default;
- Expr* initializer() const {
+ ExprPtr initializer() const {
return init_;
}
- ReduceOp* operator()(
- Buf* result_buf,
+ ReduceOpPtr operator()(
+ BufPtr result_buf,
ExprHandle body,
- const std::vector<Expr*>& output,
- const std::vector<Var*>& inner) const;
+ const std::vector<ExprPtr>& output,
+ const std::vector<VarPtr>& inner) const;
- ReduceOp* operator()(
- Buf* result_buf,
- Expr* body,
- const std::vector<Expr*>& output,
- const std::vector<Var*>& inner) const;
+ ReduceOpPtr operator()(
+ BufPtr result_buf,
+ ExprPtr body,
+ const std::vector<ExprPtr>& output,
+ const std::vector<VarPtr>& inner) const;
// Polymorphic handling of Body functions with a variety of parameters.
static ExprHandle getReduceBody(
// Completes the reduction operator by applying the interaction function to
// the accumulation and the body expression.
// Completes the reduction: loads the current accumulator value at
// `output_args`, applies the interaction function to (accumulator, body),
// and returns the resulting expression node. Patch note: signature moves to
// ExprPtr/BufPtr/VarPtr aliases and the Load is created via alloc<Load>
// instead of `new Load`; the computation itself is untouched.
- static Expr* complete(
- Buf* accumulator,
+ static ExprPtr complete(
+ BufPtr accumulator,
ReduceInteraction interaction,
ExprHandle body,
- const std::vector<Expr*>& output_args,
- const std::vector<Var*>& reduce_args) {
+ const std::vector<ExprPtr>& output_args,
+ const std::vector<VarPtr>& reduce_args) {
ExprHandle accum =
- ExprHandle(new Load(body.dtype(), accumulator, output_args));
+ ExprHandle(alloc<Load>(body.dtype(), accumulator, output_args));
auto e = interaction(accum, body);
return e.node();
}
private:
- Expr* init_;
+ ExprPtr init_;
ReduceInteraction interaction_;
};
class TORCH_API ReduceOp : public ExprNode<ReduceOp> {
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- ReduceOp(Expr* body, std::vector<Var*> reduce_args, const Reducer& reducer)
+ ReduceOp(
+ ExprPtr body,
+ std::vector<VarPtr> reduce_args,
+ const Reducer& reducer)
: ExprNodeBase(body->dtype()),
body_(body),
reduce_args_(std::move(reduce_args)),
reducer_(reducer) {}
// return the body expression which obtains the value to be reduced.
- Expr* body() const {
+ ExprPtr body() const {
return body_;
}
}
// returns variables associated with the axes of reduction.
- const std::vector<Var*>& reduce_args() const {
+ const std::vector<VarPtr>& reduce_args() const {
return reduce_args_;
}
private:
- Expr* body_;
- std::vector<Var*> reduce_args_;
+ ExprPtr body_;
+ std::vector<VarPtr> reduce_args_;
const Reducer reducer_;
};
// IRMutator that strips ReduceOp wrappers: expand() runs the mutator over a
// statement tree, and mutate(ReduceOp) replaces each ReduceOp with its inner
// body expression. Patch note: only the Stmt*/Expr*/ReduceOp* types become
// the corresponding Ptr aliases; the two one-line bodies are unchanged.
class ReductionExpander : public IRMutator {
public:
- Stmt* expand(Stmt* s) {
+ StmtPtr expand(StmtPtr s) {
return s->accept_mutator(this);
}
- Expr* mutate(ReduceOp* v) override {
+ ExprPtr mutate(ReduceOpPtr v) override {
return v->body();
}
};
// AccessInfo
// Records a Store against this access: widens block_ to the shared parent of
// the current block and the store's scope, updates first/last usage, bumps
// store_cost_ by one (simplified eagerly), appends to stores_, and captures
// the scope's condition id. Patch note: parameter becomes StorePtr and the
// cost increment uses alloc<Add>/alloc<IntImm> in place of `new`; the update
// sequence is identical.
-void AccessInfo::addStore(Store* store, const std::shared_ptr<Scope>& scope) {
+void AccessInfo::addStore(StorePtr store, const std::shared_ptr<Scope>& scope) {
block_ =
block_ ? Block::getSharedParent(block_, scope->block()) : scope->block();
first_usage_ = first_usage_ ? block_->getEnclosedRoot(first_usage_) : store;
last_usage_ = store;
- store_cost_ = IRSimplifier::simplify(new Add(store_cost_, new IntImm(1)));
+ store_cost_ =
+ IRSimplifier::simplify(alloc<Add>(store_cost_, alloc<IntImm>(1)));
stores_.push_back(store);
conditionId_ = scope->conditionId();
}
void AccessInfo::addLoad(
- Load* load,
+ LoadPtr load,
const std::shared_ptr<Scope>& scope,
- Stmt* usage) {
+ StmtPtr usage) {
block_ =
block_ ? Block::getSharedParent(block_, scope->block()) : scope->block();
first_usage_ = first_usage_ ? block_->getEnclosedRoot(first_usage_) : usage;
last_usage_ = usage;
- load_cost_ = IRSimplifier::simplify(new Add(load_cost_, new IntImm(1)));
+ load_cost_ = IRSimplifier::simplify(alloc<Add>(load_cost_, alloc<IntImm>(1)));
loads_.push_back(load);
conditionId_ = scope->conditionId();
TORCH_INTERNAL_ASSERT(indices_.size() == other->indices().size());
last_usage_ = other->last_usage();
- for (auto* s : other->stores()) {
+ for (auto s : other->stores()) {
stores_.push_back(s);
}
- for (auto* l : other->loads()) {
+ for (auto l : other->loads()) {
loads_.push_back(l);
}
store_cost_ =
- IRSimplifier::simplify(new Add(store_cost_, other->store_cost()));
- load_cost_ = IRSimplifier::simplify(new Add(load_cost_, other->load_cost()));
+ IRSimplifier::simplify(alloc<Add>(store_cost_, other->store_cost()));
+ load_cost_ =
+ IRSimplifier::simplify(alloc<Add>(load_cost_, other->load_cost()));
block_ = Block::getSharedParent(block_, other->block());
// update first and last usage to be in the parent Block.
// dimension.
bool overlap = true;
for (size_t i = 0; i < indices_.size(); ++i) {
- Expr* diff = new Sub(indices_[i], other_indices[i]);
+ ExprPtr diff = alloc<Sub>(indices_[i], other_indices[i]);
diff = IRSimplifier::simplify(diff);
if (diff->isConstant() && !immediateEquals(diff, 0)) {
return overlap;
}
-bool AccessInfo::dependsOnVar(Var* v) {
+bool AccessInfo::dependsOnVar(VarPtr v) {
VarFinder vf;
- for (auto* i : indices_) {
+ for (auto i : indices_) {
i->accept(&vf);
}
newInfo->firstUsageOverlapped_ = orig->firstUsageOverlapped_;
newInfo->store_cost_ = orig->store_cost_;
newInfo->load_cost_ = orig->load_cost_;
- for (auto* s : orig->stores_) {
+ for (auto s : orig->stores_) {
newInfo->stores_.push_back(s);
}
- for (auto* s : orig->loads_) {
+ for (auto s : orig->loads_) {
newInfo->loads_.push_back(s);
}
void AccessInfo::print() const {
std::cout << "Access: " << *buf_ << "{";
- for (auto* i : indices_) {
+ for (auto i : indices_) {
std::cout << *i << " ";
}
std::cout << "} stores: " << stores_.size() << " (" << *store_cost_ << ") -";
closedAccesses_.push_back(info);
}
-AccessHashMap& Scope::getAccessMapByBuf(Buf* b) {
+AccessHashMap& Scope::getAccessMapByBuf(BufPtr b) {
auto it = openAccesses_.find(b);
if (it == openAccesses_.end()) {
// create and return
scope->closeAccess(info);
}
-void RegisterizerAnalysis::visit(For* v) {
+void RegisterizerAnalysis::visit(ForPtr v) {
if (v->loop_options().is_gpu_block_index() ||
v->loop_options().is_gpu_thread_index()) {
throw malformed_input(
v->body()->accept(this);
stmtStack_.pop_front();
- Expr* loopExtent = IRSimplifier::simplify(new Sub(v->stop(), v->start()));
+ ExprPtr loopExtent =
+ IRSimplifier::simplify(alloc<Sub>(v->stop(), v->start()));
// now we need to see which accesses we can hoist out of the for loop, their
// costs should be multiplied by the loop extent.
bool closed = false;
// If this access depends on a locally scoped variable, it cannot be
// hosted out of the loop.
- for (auto* v : currentScope_->localVars()) {
+ for (auto v : currentScope_->localVars()) {
if (candidate->dependsOnVar(v)) {
closeAccessIntoScope(candidate, currentScope_);
closed = true;
mergeCurrentScopeIntoParent();
};
-void RegisterizerAnalysis::visit(Cond* v) {
- Expr* condition = v->condition();
- Block* true_stmt = v->true_stmt();
- Block* false_stmt = v->false_stmt();
+void RegisterizerAnalysis::visit(CondPtr v) {
+ ExprPtr condition = v->condition();
+ BlockPtr true_stmt = v->true_stmt();
+ BlockPtr false_stmt = v->false_stmt();
stmtStack_.push_front(v);
// IfThenElses are just like Conds except they are not Stmts, which means no
// registerization can occur internally. However, the first reference to an
// access can occur within one if its visible outside the condition.
-void RegisterizerAnalysis::visit(IfThenElse* v) {
- Expr* condition = v->condition();
- Expr* true_value = v->true_value();
- Expr* false_value = v->false_value();
+void RegisterizerAnalysis::visit(IfThenElsePtr v) {
+ ExprPtr condition = v->condition();
+ ExprPtr true_value = v->true_value();
+ ExprPtr false_value = v->false_value();
// condition is in enclosing scope.
condition->accept(this);
}
}
-void RegisterizerAnalysis::visit(Let* v) {
+void RegisterizerAnalysis::visit(LetPtr v) {
currentScope_->addLocalVar(v->var());
stmtStack_.push_front(v);
stmtStack_.pop_front();
}
-void RegisterizerAnalysis::visit(Block* v) {
+void RegisterizerAnalysis::visit(BlockPtr v) {
auto prev_scope = currentScope_;
if (currentScope_->block() != v) {
currentScope_ = std::make_shared<Scope>(v, prev_scope);
stmtStack_.push_front(v);
- for (auto* s : *v) {
+ for (auto s : *v) {
s->accept(this);
if (currentScope_->block() != v) {
// merge the inner block's accesses into this Block's accesses.
}
}
-void RegisterizerAnalysis::visit(Store* v) {
+void RegisterizerAnalysis::visit(StorePtr v) {
stmtStack_.push_front(v);
v->value()->accept(this);
stmtStack_.pop_front();
// hash the Store:
SimplifierHashType accessHash = hasher_.hash(v->buf());
- for (auto* i : v->indices()) {
+ for (auto i : v->indices()) {
accessHash = hasher_.hash_combine(accessHash, i);
}
}
}
-void RegisterizerAnalysis::visit(Load* v) {
+void RegisterizerAnalysis::visit(LoadPtr v) {
if (v->indices().empty()) {
// already a scalar.
return;
}
// hash the Load:
SimplifierHashType accessHash = hasher_.hash(v->buf());
- for (auto* i : v->indices()) {
+ for (auto i : v->indices()) {
accessHash = hasher_.hash_combine(accessHash, i);
}
// copy across current open accesses, merging as necessary.
// for each Buf with an open access:
for (auto& pair : currentScope_->openAccesses()) {
- Buf* buf = pair.first;
+ BufPtr buf = pair.first;
if (pair.second.empty()) {
continue;
}
// If this access depends on a locally scoped variable, it cannot be
// lifted out of the loop.
- for (auto* v : currentScope_->localVars()) {
+ for (auto v : currentScope_->localVars()) {
if (candidate->dependsOnVar(v)) {
closeAccessIntoScope(candidate, parent);
handled = true;
// RegisterizerReplacer
-Expr* RegisterizerReplacer::mutate(Load* v) {
+ExprPtr RegisterizerReplacer::mutate(LoadPtr v) {
auto it = loadToAccess_.find(v);
if (it == loadToAccess_.end()) {
// This access cannot be registerized.
return info->replacement().var;
}
-Stmt* RegisterizerReplacer::mutate(Store* v) {
+StmtPtr RegisterizerReplacer::mutate(StorePtr v) {
if (eliminatedIntializers_.count(v) != 0) {
// This store is the intializer for a scalar var that is already inserted.
return nullptr;
auto& info = it->second;
- Expr* new_val = v->value()->accept_mutator(this);
+ ExprPtr new_val = v->value()->accept_mutator(this);
- return new Store(info->replacement().var_wrapper, {}, new_val);
+ return alloc<Store>(
+ info->replacement().var_wrapper, std::vector<ExprPtr>({}), new_val);
}
-Stmt* RegisterizerReplacer::mutate(Block* v) {
+StmtPtr RegisterizerReplacer::mutate(BlockPtr v) {
auto& scope = parentToAccesses_[v];
- std::vector<Stmt*> stmts;
- for (Stmt* stmt : v->stmts()) {
+ std::vector<StmtPtr> stmts;
+ for (StmtPtr stmt : v->stmts()) {
{
// Insert the initializer for any Scalars scoped to this block.
auto it = scope.initializerPoints_.find(stmt);
if (it != scope.initializerPoints_.end()) {
for (auto& info : it->second) {
- Stmt* initializer =
+ StmtPtr initializer =
info->replacement().initializer->accept_mutator(this);
stmts.push_back(initializer);
}
}
}
- Stmt* stmt_new = stmt->accept_mutator(this);
+ StmtPtr stmt_new = stmt->accept_mutator(this);
if (stmt_new) {
if (stmt_new->get_parent()) {
stmt_new = Stmt::clone(stmt_new);
auto it = scope.finalizePoints_.find(stmt);
if (it != scope.finalizePoints_.end()) {
for (auto& info : it->second) {
- Store* finalizer =
- new Store(info->buf(), info->indices(), info->replacement().var);
+ StorePtr finalizer = alloc<Store>(
+ info->buf(), info->indices(), info->replacement().var);
stmts.push_back(finalizer);
}
scope.finalizePoints_.erase(it);
}
}
- return new Block(stmts);
+ return alloc<Block>(stmts);
}
void RegisterizerReplacer::buildReplacements() {
// Traverse the list of replacements, creating vars and updating our local
// maps.
for (auto& info : infoSet_) {
- Var* v = new Var(
+ VarPtr v = alloc<Var>(
info->buf()->name_hint() + "_" +
c10::to_string(getBufferAccessCount(info->buf())),
info->buf()->dtype());
info->replacement().var = v;
// we need to wrap the Var in a Buf so we can Load or Store it.
- info->replacement().var_wrapper = new Buf(v, {}, info->buf()->dtype());
+ info->replacement().var_wrapper =
+ alloc<Buf>(v, std::vector<ExprPtr>({}), info->buf()->dtype());
bool first = true;
- for (auto* s : info->stores()) {
+ for (auto s : info->stores()) {
if (first && info->first_usage() == s && !info->firstUsageOverlapped()) {
- info->replacement().initializer = new Let(v, s->value());
+ info->replacement().initializer = alloc<Let>(v, s->value());
eliminatedIntializers_.insert(s);
} else {
storeToAccess_[s] = info;
first = false;
}
- for (auto* s : info->loads()) {
+ for (auto s : info->loads()) {
loadToAccess_[s] = info;
}
// create a default initializer by reading the access.
if (info->replacement().initializer == nullptr) {
- info->replacement().initializer = new Let(
- v, new Load(info->buf()->dtype(), info->buf(), info->indices()));
+ info->replacement().initializer = alloc<Let>(
+ v, alloc<Load>(info->buf()->dtype(), info->buf(), info->indices()));
}
}
}
} // namespace registerizer
// Apply scalar replacement to all accesses in s.
-Stmt* registerize(Stmt* s) {
+StmtPtr registerize(StmtPtr s) {
s = IRSimplifier::simplify(s);
// The outermost node must be a Block so we have somewhere to put outer scope
// scalars.
- if (!dynamic_cast<Block*>(s)) {
- s = new Block({s});
+ if (!to<Block>(s)) {
+ s = alloc<Block>(std::vector<StmtPtr>({s}));
}
registerizer::RegisterizerAnalysis analysis;
s->accept(&analysis);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
AccessInfo(
SimplifierHashType h,
- Buf* b,
- std::vector<Expr*> i,
+ BufPtr b,
+ std::vector<ExprPtr> i,
size_t accessOrder)
: hash_(h),
buf_(b),
indices_(std::move(i)),
- store_cost_(new IntImm(0)),
- load_cost_(new IntImm(0)),
+ store_cost_(alloc<IntImm>(0)),
+ load_cost_(alloc<IntImm>(0)),
accessOrder_(accessOrder) {}
// Adds a Store to this access, which is in the provided scope.
- void addStore(Store* store, const std::shared_ptr<Scope>& scope);
+ void addStore(StorePtr store, const std::shared_ptr<Scope>& scope);
// Adds a Load to this access, which occurs in the usage Stmt in the provided
// scope.
- void addLoad(Load* load, const std::shared_ptr<Scope>& scope, Stmt* usage);
+ void addLoad(
+ LoadPtr load,
+ const std::shared_ptr<Scope>& scope,
+ StmtPtr usage);
// Merge another AccessInfo into this one.
void merge(const std::shared_ptr<AccessInfo>& other);
bool overlaps(const std::shared_ptr<AccessInfo>& other);
// Returns true if the indices of this access depend on the provided Var.
- bool dependsOnVar(Var* v);
+ bool dependsOnVar(VarPtr v);
// Clone this AccessInfo, and set this as the new accesses' hiddenAccess.
static std::shared_ptr<AccessInfo> cloneWithHiddenInfo(
return hash_;
}
- Buf* buf() const {
+ BufPtr buf() const {
return buf_;
}
- const std::vector<Expr*>& indices() const {
+ const std::vector<ExprPtr>& indices() const {
return indices_;
}
- Block* block() const {
+ BlockPtr block() const {
return block_;
}
- void setEnclosingBlock(Block* b) {
+ void setEnclosingBlock(BlockPtr b) {
block_ = b;
}
- Stmt* first_usage() const {
+ StmtPtr first_usage() const {
return first_usage_;
}
- Stmt* last_usage() const {
+ StmtPtr last_usage() const {
return last_usage_;
}
- void setUsageMarks(Stmt* first, Stmt* last) {
+ void setUsageMarks(StmtPtr first, StmtPtr last) {
first_usage_ = first;
last_usage_ = last;
}
return firstUsageOverlapped_;
}
- Expr* store_cost() const {
+ ExprPtr store_cost() const {
return store_cost_;
}
- Expr* load_cost() const {
+ ExprPtr load_cost() const {
return load_cost_;
}
- const std::vector<Store*>& stores() const {
+ const std::vector<StorePtr>& stores() const {
return stores_;
}
- const std::vector<Load*>& loads() const {
+ const std::vector<LoadPtr>& loads() const {
return loads_;
}
- void hoistCosts(Expr* extent) {
- store_cost_ = IRSimplifier::simplify(new Mul(store_cost_, extent));
- load_cost_ = IRSimplifier::simplify(new Mul(load_cost_, extent));
+ void hoistCosts(ExprPtr extent) {
+ store_cost_ = IRSimplifier::simplify(alloc<Mul>(store_cost_, extent));
+ load_cost_ = IRSimplifier::simplify(alloc<Mul>(load_cost_, extent));
}
size_t conditionId() const {
// Holds state relating to the scalar variable we will insert to replace some
// number of loads and stores.
struct ScalarReplacement {
- Var* var{nullptr};
- Buf* var_wrapper{nullptr};
- Let* initializer{nullptr};
+ VarPtr var{nullptr};
+ BufPtr var_wrapper{nullptr};
+ LetPtr initializer{nullptr};
};
ScalarReplacement& replacement() {
private:
SimplifierHashType hash_;
- Buf* buf_;
- std::vector<Expr*> indices_;
- Block* block_{nullptr};
+ BufPtr buf_;
+ std::vector<ExprPtr> indices_;
+ BlockPtr block_{nullptr};
- Stmt* first_usage_{nullptr};
- Stmt* last_usage_{nullptr};
+ StmtPtr first_usage_{nullptr};
+ StmtPtr last_usage_{nullptr};
// Whether or not this access is overlapped in the first Stmt it appears. This
// means we cannot use its first Store as the initializer.
// The cost in real ops that this access represents, to enable
// filtering accesses that won't save any loads or stores.
- Expr* store_cost_;
- Expr* load_cost_;
+ ExprPtr store_cost_;
+ ExprPtr load_cost_;
// The actual Stores and Loads which represent this access.
// Be careful with these, any mutator will invalidate these pointers.
- std::vector<Store*> stores_;
- std::vector<Load*> loads_;
+ std::vector<StorePtr> stores_;
+ std::vector<LoadPtr> loads_;
// An identifier representing the conditional block, if any, this access
// depends on.
class Scope {
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- Scope(Block* b, std::shared_ptr<Scope> parent, size_t conditionId = 0)
+ Scope(BlockPtr b, std::shared_ptr<Scope> parent, size_t conditionId = 0)
: block_(b), parent_(std::move(parent)), conditionId_(conditionId) {}
- AccessHashMap& getAccessMapByBuf(Buf* b);
+ AccessHashMap& getAccessMapByBuf(BufPtr b);
- std::unordered_map<Buf*, AccessHashMap>& openAccesses() {
+ std::unordered_map<BufPtr, AccessHashMap>& openAccesses() {
return openAccesses_;
}
return closedAccesses_;
}
- Block* block() const {
+ BlockPtr block() const {
return block_;
}
return conditionId_;
}
- const std::unordered_set<Var*>& localVars() const {
+ const std::unordered_set<VarPtr>& localVars() const {
return localVars_;
}
- void addLocalVar(Var* v) {
+ void addLocalVar(VarPtr v) {
localVars_.insert(v);
}
// overlap with other accesses to the same buf. Buf ->
// Hash ->
// Access
- std::unordered_map<Buf*, AccessHashMap> openAccesses_;
+ std::unordered_map<BufPtr, AccessHashMap> openAccesses_;
std::vector<std::shared_ptr<AccessInfo>> closedAccesses_;
// The Block object this scope represents.
- Block* block_;
+ BlockPtr block_;
// The enclosing scope object.
std::shared_ptr<Scope> parent_;
size_t conditionId_;
// A set of variables local to this scope (e.g. loop vars).
- std::unordered_set<Var*> localVars_;
+ std::unordered_set<VarPtr> localVars_;
};
/* Analyzes the graph and collects accesses to the same symbolic tensor element
: currentScope_(std::make_shared<Scope>(nullptr, nullptr, 0)) {}
~RegisterizerAnalysis() override = default;
- void visit(For* v) override;
+ void visit(ForPtr v) override;
- void visit(Cond* v) override;
+ void visit(CondPtr v) override;
- void visit(Block* v) override;
+ void visit(BlockPtr v) override;
- void visit(Store* v) override;
+ void visit(StorePtr v) override;
- void visit(Load* v) override;
+ void visit(LoadPtr v) override;
- void visit(IfThenElse* v) override;
+ void visit(IfThenElsePtr v) override;
- void visit(Let* v) override;
+ void visit(LetPtr v) override;
-#define STMT_ON_STACK(Op) \
- void visit(Op* v) override { \
- stmtStack_.push_front(v); \
- IRVisitor::visit(v); \
- stmtStack_.pop_front(); \
+#define STMT_ON_STACK(Op) \
+ void visit(Op##Ptr v) override { \
+ stmtStack_.push_front(v); \
+ IRVisitor::visit(v); \
+ stmtStack_.pop_front(); \
}
STMT_ON_STACK(AtomicAdd);
std::unordered_set<size_t> exprConditionals_;
// A stack of enclosing Stmts for tracking the usage Stmt of Loads.
- std::deque<Stmt*> stmtStack_;
+ std::deque<StmtPtr> stmtStack_;
// The current scope being analyzed.
std::shared_ptr<Scope> currentScope_;
buildReplacements();
}
- Expr* mutate(Load* v) override;
+ ExprPtr mutate(LoadPtr v) override;
- Stmt* mutate(Store* v) override;
+ StmtPtr mutate(StorePtr v) override;
- Stmt* mutate(Block* v) override;
+ StmtPtr mutate(BlockPtr v) override;
private:
struct ReplacerScope {
- std::unordered_map<Stmt*, std::deque<std::shared_ptr<AccessInfo>>>
+ std::unordered_map<StmtPtr, std::deque<std::shared_ptr<AccessInfo>>>
initializerPoints_;
- std::unordered_map<Stmt*, std::deque<std::shared_ptr<AccessInfo>>>
+ std::unordered_map<StmtPtr, std::deque<std::shared_ptr<AccessInfo>>>
finalizePoints_;
};
// State relating to the accesses yet to be replaced.
std::vector<std::shared_ptr<AccessInfo>>& infoSet_;
- std::unordered_map<Store*, std::shared_ptr<AccessInfo>> storeToAccess_;
- std::unordered_map<Load*, std::shared_ptr<AccessInfo>> loadToAccess_;
- std::unordered_map<Block*, ReplacerScope> parentToAccesses_;
+ std::unordered_map<StorePtr, std::shared_ptr<AccessInfo>> storeToAccess_;
+ std::unordered_map<LoadPtr, std::shared_ptr<AccessInfo>> loadToAccess_;
+ std::unordered_map<BlockPtr, ReplacerScope> parentToAccesses_;
// Holds the set of Stores that should be pulled into an initializer, so they
// can be eliminated.
- std::set<Store*> eliminatedIntializers_;
+ std::set<StorePtr> eliminatedIntializers_;
// Tracks the number of times we've seen each buffer, so we can name the
// scalar Vars appropriately.
- std::unordered_map<Buf*, unsigned int> bufferAccessCounts_;
- unsigned int getBufferAccessCount(Buf* b) {
+ std::unordered_map<BufPtr, unsigned int> bufferAccessCounts_;
+ unsigned int getBufferAccessCount(BufPtr b) {
return ++bufferAccessCounts_[b];
}
};
// Apply scalar replacement to all accesses in s.
// To produce safe code, this must occur after handling parallelized axes and
// atomics.
-TORCH_API Stmt* registerize(Stmt* s);
+TORCH_API StmtPtr registerize(StmtPtr s);
} // namespace tensorexpr
} // namespace jit
public:
Stmt() = default;
virtual void accept(IRVisitor* visitor) = 0;
- virtual Stmt* accept_mutator(IRMutator* mutator) = 0;
+ virtual StmtPtr accept_mutator(IRMutator* mutator) = 0;
- Stmt* get_parent() const {
+ StmtPtr get_parent() const {
return parent_;
}
* cloned. Note that the variables are not deep-copied since they are
* immutable.
*/
- static Stmt* clone(Stmt* s);
+ static StmtPtr clone(StmtPtr s);
protected:
- static void set_parent(Stmt* s, Stmt* new_parent) {
+ static void set_parent(StmtPtr s, StmtPtr new_parent) {
s->parent_ = new_parent;
}
private:
- Stmt* parent_ = nullptr;
+ StmtPtr parent_ = nullptr;
};
template <class Op>
public:
using StmtNodeBase = StmtNode<Op>;
void accept(IRVisitor* visitor) override {
- visitor->visit(static_cast<Op*>(this));
+ visitor->visit(static_to<Op>(this));
}
- Stmt* accept_mutator(IRMutator* mutator) override;
+ StmtPtr accept_mutator(IRMutator* mutator) override;
StmtNode() = default;
};
template <class Op>
-Stmt* StmtNode<Op>::accept_mutator(IRMutator* mutator) {
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
- StmtNode* this_mutable = const_cast<StmtNode*>(this);
- return mutator->mutate(static_cast<Op*>(this_mutable));
+StmtPtr StmtNode<Op>::accept_mutator(IRMutator* mutator) {
+ return mutator->mutate(static_to<Op>(this));
}
// Concrete Stmt classes
class TORCH_API Block : public StmtNode<Block> {
public:
- static Block* make(const std::vector<Stmt*>& stmts) {
+ static BlockPtr make(const std::vector<StmtPtr>& stmts) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::vector<Stmt*> valid_stmts;
+ std::vector<StmtPtr> valid_stmts;
for (auto& stmt : stmts) {
if (!stmt) {
continue;
if (valid_stmts.empty()) {
return nullptr;
}
- return new Block(valid_stmts);
+ return alloc<Block>(valid_stmts);
}
int nstmts() const {
return stmts_.empty();
}
- void prepend_stmt(Stmt* s) {
+ void prepend_stmt(StmtPtr s) {
if (s->get_parent()) {
throw malformed_input("Block prepend Stmt with existing parent", s);
}
stmts_.push_front(s);
set_parent(s, this);
}
- void append_stmt(Stmt* s) {
+ void append_stmt(StmtPtr s) {
if (s->get_parent()) {
throw malformed_input("Block append Stmt with existing parent", s);
}
set_parent(s, this);
}
- void insert_stmt_before(Stmt* s, Stmt* before) {
+ void insert_stmt_before(StmtPtr s, StmtPtr before) {
if (s->get_parent()) {
throw malformed_input("Block append Stmt with existing parent", s);
}
set_parent(s, this);
}
- void insert_stmt_after(Stmt* s, Stmt* after) {
+ void insert_stmt_after(StmtPtr s, StmtPtr after) {
if (s->get_parent()) {
throw malformed_input("Block append Stmt with existing parent", s);
}
set_parent(s, this);
}
- bool replace_stmt(Stmt* old_stmt, Stmt* new_stmt) {
+ bool replace_stmt(StmtPtr old_stmt, StmtPtr new_stmt) {
if (new_stmt->get_parent()) {
throw malformed_input(
"Block replace Stmt with existing parent", new_stmt);
// Creates a new block by cloning `this` block and replacing the given
// statement with a new statement. Note that `old_stmt` refers to a statement
// in `this` block. If the `old_stmt` is not found, it will return `nullptr`.
- Block* clone_and_replace(Stmt* old_stmt, Stmt* new_stmt) {
+ BlockPtr clone_and_replace(StmtPtr old_stmt, StmtPtr new_stmt) {
if (new_stmt->get_parent()) {
throw malformed_input(
"Block replace Stmt with existing parent", new_stmt);
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::vector<Stmt*> stmts(stmts_.begin(), stmts_.end());
+ std::vector<StmtPtr> stmts(stmts_.begin(), stmts_.end());
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::vector<Stmt*> cloned_stmts(stmts.size());
+ std::vector<StmtPtr> cloned_stmts(stmts.size());
bool found = false;
for (int i = 0; i < static_cast<int>(stmts.size()); ++i) {
if (stmts[i] == old_stmt) {
if (!found) {
return nullptr;
}
- return new Block(cloned_stmts);
+ return alloc<Block>(cloned_stmts);
}
- bool remove_stmt(Stmt* stmt) {
+ bool remove_stmt(StmtPtr stmt) {
auto pos = std::find(stmts_.begin(), stmts_.end(), stmt);
if (pos == stmts_.end()) {
return false;
return true;
}
- std::list<Stmt*> stmts() const {
+ std::list<StmtPtr> stmts() const {
return stmts_;
}
stmts_.clear();
}
- void set_stmts(const std::vector<Stmt*>& stmts) {
+ void set_stmts(const std::vector<StmtPtr>& stmts) {
clear();
init(stmts);
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- explicit Block(const std::vector<Stmt*>& stmts) {
+ explicit Block(const std::vector<StmtPtr>& stmts) {
init(stmts);
}
- typedef std::list<Stmt*>::iterator iterator;
- typedef std::list<Stmt*>::const_iterator const_iterator;
+ typedef std::list<StmtPtr>::iterator iterator;
+ typedef std::list<StmtPtr>::const_iterator const_iterator;
iterator begin() {
return stmts_.begin();
return stmts_.end();
}
- Stmt* front() {
+ StmtPtr front() {
return stmts_.front();
}
- Stmt* front() const {
+ StmtPtr front() const {
return stmts_.front();
}
- Stmt* back() {
+ StmtPtr back() {
return stmts_.back();
}
- Stmt* back() const {
+ StmtPtr back() const {
return stmts_.back();
}
- void splice(Block::iterator it, Block* other) {
- for (Stmt* s : *other) {
+ void splice(Block::iterator it, BlockPtr other) {
+ for (StmtPtr s : *other) {
set_parent(s, this);
}
stmts_.splice(it, other->stmts_);
}
- static Block* getSharedParent(Stmt* p1, Stmt* p2) {
+ static BlockPtr getSharedParent(StmtPtr p1, StmtPtr p2) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::unordered_set<Block*> enclosing;
+ std::unordered_set<BlockPtr> enclosing;
- Stmt* p1_p = p1;
+ StmtPtr p1_p = p1;
while (p1_p) {
- if (Block* b = dynamic_cast<Block*>(p1_p)) {
+ if (BlockPtr b = to<Block>(p1_p)) {
if (b) {
enclosing.insert(b);
}
p1_p = p1_p->get_parent();
}
- Stmt* p2_p = p2;
+ StmtPtr p2_p = p2;
while (p2_p) {
- if (Block* b = dynamic_cast<Block*>(p2_p)) {
+ if (BlockPtr b = to<Block>(p2_p)) {
if (enclosing.count(b) != 0) {
return b;
}
}
// returns the immediate child containing statement s.
- Stmt* getEnclosedRoot(Stmt* s) const {
+ StmtPtr getEnclosedRoot(StmtPtr s) const {
while (s && s->get_parent() != this) {
s = s->get_parent();
}
}
private:
- std::list<Stmt*> stmts_;
+ std::list<StmtPtr> stmts_;
- void init(const std::vector<Stmt*>& stmts) {
- for (Stmt* s : stmts) {
+ void init(const std::vector<StmtPtr>& stmts) {
+ for (StmtPtr s : stmts) {
if (!s) {
continue;
}
class TORCH_API Store : public StmtNode<Store> {
public:
- Var* base_handle() const {
+ VarPtr base_handle() const {
return buf_->base_handle();
}
- std::vector<Expr*> indices() const {
+ std::vector<ExprPtr> indices() const {
return indices_;
}
- Expr* flat_index() const {
+ ExprPtr flat_index() const {
TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened.");
return indices_[0];
}
- Expr* value() const {
+ ExprPtr value() const {
return value_;
}
- Buf* buf() const {
+ BufPtr buf() const {
return buf_;
}
- void set_buf(Buf* buf) {
+ void set_buf(BufPtr buf) {
buf_ = buf;
}
- void set_indices(std::vector<Expr*> indices) {
+ void set_indices(std::vector<ExprPtr> indices) {
indices_ = std::move(indices);
}
- void set_value(Expr* value) {
+ void set_value(ExprPtr value) {
value_ = value;
}
- static Store* make(
+ static StorePtr make(
const BufHandle& buf,
const std::vector<ExprHandle>& indices,
const ExprHandle& value);
- Store(Buf* buf, std::vector<Expr*> indices, Expr* value);
+ Store(BufPtr buf, std::vector<ExprPtr> indices, ExprPtr value);
private:
- Buf* buf_;
- std::vector<Expr*> indices_;
- Expr* value_;
+ BufPtr buf_;
+ std::vector<ExprPtr> indices_;
+ ExprPtr value_;
};
// Allocate a buffer of given shapes and dtypes and bind it with the given
// explicitly freed. An unfreed memory is likely considered an error.
class TORCH_API Allocate : public StmtNode<Allocate> {
public:
- static Allocate* make(const BufHandle& buf_handle) {
- return new Allocate(buf_handle.node());
+ static AllocatePtr make(const BufHandle& buf_handle) {
+ return alloc<Allocate>(buf_handle.node());
}
- Var* buffer_var() const {
+ VarPtr buffer_var() const {
return buf_->base_handle();
}
return buf_->dtype();
}
- const std::vector<Expr*> dims() const {
+ const std::vector<ExprPtr> dims() const {
return buf_->dims();
}
- Buf* buf() const {
+ BufPtr buf() const {
return buf_;
}
- void set_buf(Buf* buf) {
+ void set_buf(BufPtr buf) {
buf_ = buf;
}
- explicit Allocate(Buf* buf) : buf_(buf) {}
+ explicit Allocate(BufPtr buf) : buf_(buf) {}
private:
- Buf* buf_;
+ BufPtr buf_;
// TODO: add memory types.
};
// Free the specific buffer. It is an error.
class TORCH_API Free : public StmtNode<Free> {
public:
- static Free* make(const BufHandle& buf_handle) {
- return new Free(buf_handle.node());
+ static FreePtr make(const BufHandle& buf_handle) {
+ return alloc<Free>(buf_handle.node());
}
- Var* buffer_var() const {
+ VarPtr buffer_var() const {
return buf_->base_handle();
}
- Buf* buf() const {
+ BufPtr buf() const {
return buf_;
}
- void set_buf(Buf* buf) {
+ void set_buf(BufPtr buf) {
buf_ = buf;
}
- explicit Free(Buf* buf) : buf_(buf) {}
+ explicit Free(BufPtr buf) : buf_(buf) {}
private:
- Buf* buf_;
+ BufPtr buf_;
};
class TORCH_API Let : public StmtNode<Let> {
public:
- static Let* make(const VarHandle& var, const ExprHandle& val) {
- return new Let(var.node(), val.node());
+ static LetPtr make(const VarHandle& var, const ExprHandle& val) {
+ return alloc<Let>(var.node(), val.node());
}
- Let(Var* var, Expr* val) : dtype_(var->dtype()), var_(var), val_(val) {}
+ Let(VarPtr var, ExprPtr val) : dtype_(var->dtype()), var_(var), val_(val) {}
Dtype dtype() const {
return dtype_;
}
- Var* var() const {
+ VarPtr var() const {
return var_;
}
- Expr* value() const {
+ ExprPtr value() const {
return val_;
}
- void set_var(Var* var) {
+ void set_var(VarPtr var) {
var_ = var;
}
- void set_val(Expr* val) {
+ void set_val(ExprPtr val) {
val_ = val;
}
private:
Dtype dtype_;
- Var* var_;
- Expr* val_;
+ VarPtr var_;
+ ExprPtr val_;
};
class TORCH_API Cond : public StmtNode<Cond> {
public:
- static Cond* make(
+ static CondPtr make(
const ExprHandle& condition,
- Stmt* true_stmt,
- Stmt* false_stmt) {
- return new Cond(condition.node(), true_stmt, false_stmt);
+ StmtPtr true_stmt,
+ StmtPtr false_stmt) {
+ return alloc<Cond>(condition.node(), true_stmt, false_stmt);
}
- Expr* condition() const {
+ ExprPtr condition() const {
return condition_;
}
- Block* true_stmt() const {
+ BlockPtr true_stmt() const {
return true_stmt_;
}
- Block* false_stmt() const {
+ BlockPtr false_stmt() const {
return false_stmt_;
}
- void set_condition(Expr* condition) {
+ void set_condition(ExprPtr condition) {
condition_ = condition;
}
- void set_true_stmt(Stmt* true_stmt) {
+ void set_true_stmt(StmtPtr true_stmt) {
if (true_stmt) {
- Block* b = dynamic_cast<Block*>(true_stmt);
+ BlockPtr b = to<Block>(true_stmt);
if (!b) {
- b = new Block({true_stmt});
+ b = alloc<Block>(std::vector<StmtPtr>({true_stmt}));
}
true_stmt_ = b;
set_parent(true_stmt_, this);
}
}
- void set_false_stmt(Stmt* false_stmt) {
+ void set_false_stmt(StmtPtr false_stmt) {
if (false_stmt) {
- Block* b = dynamic_cast<Block*>(false_stmt);
+ BlockPtr b = to<Block>(false_stmt);
if (!b) {
- b = new Block({false_stmt});
+ b = alloc<Block>(std::vector<StmtPtr>({false_stmt}));
}
false_stmt_ = b;
set_parent(false_stmt_, this);
}
}
- Cond(Expr* condition, Stmt* true_stmt, Stmt* false_stmt)
+ Cond(ExprPtr condition, StmtPtr true_stmt, StmtPtr false_stmt)
: condition_(condition) {
set_true_stmt(true_stmt);
set_false_stmt(false_stmt);
}
- Cond* cloneWithNewBodies(Stmt* true_stmt, Stmt* false_stmt) {
- return new Cond(condition_, true_stmt, false_stmt);
+ CondPtr cloneWithNewBodies(StmtPtr true_stmt, StmtPtr false_stmt) {
+ return alloc<Cond>(condition_, true_stmt, false_stmt);
}
- Cond* cloneWithNewBody(Stmt* true_stmt) {
- return new Cond(condition_, true_stmt, nullptr);
+ CondPtr cloneWithNewBody(StmtPtr true_stmt) {
+ return alloc<Cond>(condition_, true_stmt, nullptr);
}
private:
- Expr* condition_;
- Block* true_stmt_ = nullptr;
- Block* false_stmt_ = nullptr;
+ ExprPtr condition_;
+ BlockPtr true_stmt_ = nullptr;
+ BlockPtr false_stmt_ = nullptr;
};
class TORCH_API LoopOptions {
!is_parallel_;
}
- void set_buffer_mapping(const std::unordered_map<std::string, Buf*>& map) {
+ void set_buffer_mapping(const std::unordered_map<std::string, BufPtr>& map) {
map_input_to_tensor_bufs_ = map;
}
- std::unordered_map<std::string, Buf*> get_buffer_mapping() const {
+ std::unordered_map<std::string, BufPtr> get_buffer_mapping() const {
return map_input_to_tensor_bufs_;
}
int gpu_block_index_{IDX_UNSET};
int gpu_thread_index_{IDX_UNSET};
bool is_parallel_{false};
- std::unordered_map<std::string, Buf*> map_input_to_tensor_bufs_;
+ std::unordered_map<std::string, BufPtr> map_input_to_tensor_bufs_;
};
class TORCH_API For : public StmtNode<For> {
public:
- Var* var() const {
+ VarPtr var() const {
return var_;
}
- Expr* start() const {
+ ExprPtr start() const {
return start_;
}
- Expr* stop() const {
+ ExprPtr stop() const {
return stop_;
}
- Block* body() const {
+ BlockPtr body() const {
return body_;
}
- static For* make(
+ static ForPtr make(
const VarHandle& var,
const ExprHandle& start,
const ExprHandle& stop,
- Stmt* body) {
+ StmtPtr body) {
if (!body) {
return nullptr;
}
- return new For(var.node(), start.node(), stop.node(), body);
+ return alloc<For>(var.node(), start.node(), stop.node(), body);
}
- static For* make(
+ static ForPtr make(
const VarHandle& var,
const ExprHandle& start,
const ExprHandle& stop,
- Stmt* body,
+ StmtPtr body,
const LoopOptions& loop_options) {
if (!body) {
return nullptr;
}
- return new For(var.node(), start.node(), stop.node(), body, loop_options);
+ return alloc<For>(
+ var.node(), start.node(), stop.node(), body, loop_options);
}
const LoopOptions loop_options() const {
return loop_options_;
}
- For(Var* var, Expr* start, Expr* stop, Stmt* body)
+ For(VarPtr var, ExprPtr start, ExprPtr stop, StmtPtr body)
: var_(var), start_(start), stop_(stop) {
- Block* b = dynamic_cast<Block*>(body);
+ BlockPtr b = to<Block>(body);
if (!b) {
- b = new Block({body});
+ b = alloc<Block>(std::vector<StmtPtr>({body}));
}
body_ = b;
set_parent(body_, this);
}
- For(Var* var, Expr* start, Expr* stop, Stmt* body, LoopOptions loop_options)
+ For(VarPtr var,
+ ExprPtr start,
+ ExprPtr stop,
+ StmtPtr body,
+ LoopOptions loop_options)
: var_(var),
start_(start),
stop_(stop),
throw malformed_input("invalid Body in For loop", body);
}
- Block* b = dynamic_cast<Block*>(body);
+ BlockPtr b = to<Block>(body);
if (!b) {
- b = new Block({body});
+ b = alloc<Block>(std::vector<StmtPtr>({body}));
}
body_ = b;
set_parent(body_, this);
return loop_options_.is_parallel();
}
- void set_buffer_map(const std::unordered_map<std::string, Buf*>& map) {
+ void set_buffer_map(const std::unordered_map<std::string, BufPtr>& map) {
loop_options_.set_buffer_mapping(map);
}
- For* cloneWithNewBody(Stmt* body) const {
- return new For(var_, start_, stop_, body, loop_options_);
+ ForPtr cloneWithNewBody(StmtPtr body) const {
+ return alloc<For>(var_, start_, stop_, body, loop_options_);
}
- Block* removeBody() {
+ BlockPtr removeBody() {
auto res = body_;
set_parent(res, nullptr);
body_ = nullptr;
return res;
}
- void set_body(Stmt* body) {
- Block* b = dynamic_cast<Block*>(body);
+ void set_body(StmtPtr body) {
+ BlockPtr b = to<Block>(body);
if (!b) {
- b = new Block({body});
+ b = alloc<Block>(std::vector<StmtPtr>({body}));
}
body_ = b;
set_parent(body_, this);
}
- void set_start(Expr* start) {
+ void set_start(ExprPtr start) {
start_ = start;
}
- void set_stop(Expr* stop) {
+ void set_stop(ExprPtr stop) {
stop_ = stop;
}
- void set_var(Var* var) {
+ void set_var(VarPtr var) {
var_ = var;
}
private:
- Var* var_;
- Expr* start_;
- Expr* stop_;
- Block* body_;
+ VarPtr var_;
+ ExprPtr start_;
+ ExprPtr stop_;
+ BlockPtr body_;
LoopOptions loop_options_;
};
class TORCH_API AtomicAdd : public StmtNode<AtomicAdd> {
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- AtomicAdd(Buf* buf, std::vector<Expr*> indices, Expr* value)
+ AtomicAdd(BufPtr buf, std::vector<ExprPtr> indices, ExprPtr value)
: buf_(buf), indices_(std::move(indices)), value_(value) {}
- Var* base_handle() const {
+ VarPtr base_handle() const {
return buf_->base_handle();
}
- Buf* buf() const {
+ BufPtr buf() const {
return buf_;
}
- Expr* flat_index() const {
+ ExprPtr flat_index() const {
TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened.");
return indices_[0];
}
- Expr* value() const {
+ ExprPtr value() const {
return value_;
}
- const std::vector<Expr*>& indices() const {
+ const std::vector<ExprPtr>& indices() const {
return indices_;
}
- void set_buf(Buf* buf) {
+ void set_buf(BufPtr buf) {
buf_ = buf;
}
- void set_indices(std::vector<Expr*> indices) {
+ void set_indices(std::vector<ExprPtr> indices) {
indices_ = std::move(indices);
}
- void set_value(Expr* value) {
+ void set_value(ExprPtr value) {
value_ = value;
}
private:
- Buf* buf_;
- std::vector<Expr*> indices_;
- Expr* value_;
+ BufPtr buf_;
+ std::vector<ExprPtr> indices_;
+ ExprPtr value_;
};
class TORCH_API SyncThreads : public StmtNode<SyncThreads> {
*/
class TORCH_API ExternalCall : public StmtNode<ExternalCall> {
public:
- static ExternalCall* make(
+ static ExternalCallPtr make(
BufHandle buf,
const std::string& func_name,
const std::vector<BufHandle>& buf_args,
const std::vector<ExprHandle>& args);
- Buf* buf() const {
+ BufPtr buf() const {
return buf_;
}
return func_name_;
}
- std::vector<Buf*> buf_args() const {
+ std::vector<BufPtr> buf_args() const {
return buf_args_;
}
- std::vector<Expr*> args() const {
+ std::vector<ExprPtr> args() const {
return args_;
}
- void set_buf(Buf* buf) {
+ void set_buf(BufPtr buf) {
buf_ = buf;
}
- void set_buf_args(std::vector<Buf*> buf_args) {
+ void set_buf_args(std::vector<BufPtr> buf_args) {
buf_args_ = std::move(buf_args);
}
- void set_args(std::vector<Expr*> args) {
+ void set_args(std::vector<ExprPtr> args) {
args_ = std::move(args);
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
ExternalCall(
- Buf* buf,
+ BufPtr buf,
std::string func_name,
- std::vector<Buf*> buf_args,
- std::vector<Expr*> args)
+ std::vector<BufPtr> buf_args,
+ std::vector<ExprPtr> args)
: buf_(buf),
func_name_(std::move(func_name)),
buf_args_(std::move(buf_args)),
args_(std::move(args)) {}
private:
- Buf* buf_;
+ BufPtr buf_;
std::string func_name_;
- std::vector<Buf*> buf_args_;
- std::vector<Expr*> args_;
+ std::vector<BufPtr> buf_args_;
+ std::vector<ExprPtr> args_;
};
} // namespace tensorexpr
namespace jit {
namespace tensorexpr {
-Stmt* Tensor::constructStmt(
- const std::vector<Var*>& args,
- Expr* body,
- const std::vector<Expr*>& reduce_dims,
- const std::vector<Var*>& reduce_args) const {
- std::vector<Expr*> indices(args.begin(), args.end());
+StmtPtr Tensor::constructStmt(
+ const std::vector<VarPtr>& args,
+ ExprPtr body,
+ const std::vector<ExprPtr>& reduce_dims,
+ const std::vector<VarPtr>& reduce_args) const {
+ std::vector<ExprPtr> indices(args.begin(), args.end());
- Stmt* s = new Store(buf_, indices, body);
+ StmtPtr s = alloc<Store>(buf_, indices, body);
size_t ndim = buf()->ndim();
size_t reduce_ndim = reduce_dims.size();
return s;
}
- Expr* init_expr = buf()->initializer();
+ ExprPtr init_expr = buf()->initializer();
if (reduce_ndim > 0) {
for (const auto i : c10::irange(reduce_ndim)) {
// Going in reverse order: from innermost loop to the outermost
size_t dim_index = reduce_ndim - i - 1;
- s = new For(
- reduce_args[dim_index], new IntImm(0), reduce_dims[dim_index], s);
+ s = alloc<For>(
+ reduce_args[dim_index], alloc<IntImm>(0), reduce_dims[dim_index], s);
}
if (init_expr) {
- Store* init_stmt = new Store(buf(), indices, init_expr);
- s = new Block({init_stmt, s});
+ StorePtr init_stmt = alloc<Store>(buf(), indices, init_expr);
+ s = alloc<Block>(std::vector<StmtPtr>({init_stmt, s}));
}
}
for (const auto i : c10::irange(ndim)) {
// Going in reverse order: from innermost loop to the outermost
size_t dim_index = ndim - i - 1;
- s = new For(args[dim_index], new IntImm(0), buf()->dim(dim_index), s);
+ s = alloc<For>(args[dim_index], alloc<IntImm>(0), buf()->dim(dim_index), s);
}
return s;
}
const std::string& name,
const std::vector<DimArg>& dim_args,
const std::function<ExprHandle(const std::vector<VarHandle>&)>& body_func) {
- std::vector<Expr*> dims;
- std::vector<Var*> args;
+ std::vector<ExprPtr> dims;
+ std::vector<VarPtr> args;
unpack_dim_args(dim_args, &dims, &args);
- Expr* body = body_func(VarVectorToVarHandleVector(args)).node();
- Buf* buf = new Buf(name, dims, body->dtype());
+ ExprPtr body = body_func(VarVectorToVarHandleVector(args)).node();
+ BufPtr buf = alloc<Buf>(name, dims, body->dtype());
return new Tensor(buf, args, body);
}
throw malformed_input("mismatch between body and arg size (1)");
}
- std::vector<Expr*> dims;
- std::vector<Var*> args;
+ std::vector<ExprPtr> dims;
+ std::vector<VarPtr> args;
unpack_dim_args(dim_args, &dims, &args);
- Expr* body = body_func(VarHandle(args[0])).node();
- Buf* buf = new Buf(name, dims, body->dtype());
+ ExprPtr body = body_func(VarHandle(args[0])).node();
+ BufPtr buf = alloc<Buf>(name, dims, body->dtype());
return new Tensor(buf, args, body);
}
if (dim_args.size() != 2) {
throw malformed_input("mismatch between body and arg size (2)");
}
- std::vector<Expr*> dims;
- std::vector<Var*> args;
+ std::vector<ExprPtr> dims;
+ std::vector<VarPtr> args;
unpack_dim_args(dim_args, &dims, &args);
- Expr* body = body_func(VarHandle(args[0]), VarHandle(args[1])).node();
- Buf* buf = new Buf(name, dims, body->dtype());
+ ExprPtr body = body_func(VarHandle(args[0]), VarHandle(args[1])).node();
+ BufPtr buf = alloc<Buf>(name, dims, body->dtype());
return new Tensor(buf, args, body);
}
if (dim_args.size() != 3) {
throw malformed_input("mismatch between body and arg size (3)");
}
- std::vector<Expr*> dims;
- std::vector<Var*> args;
+ std::vector<ExprPtr> dims;
+ std::vector<VarPtr> args;
unpack_dim_args(dim_args, &dims, &args);
- Expr* body =
+ ExprPtr body =
body_func(VarHandle(args[0]), VarHandle(args[1]), VarHandle(args[2]))
.node();
- Buf* buf = new Buf(name, dims, body->dtype());
+ BufPtr buf = alloc<Buf>(name, dims, body->dtype());
return new Tensor(buf, args, body);
}
if (dim_args.size() != 4) {
throw malformed_input("mismatch between body and arg size (4)");
}
- std::vector<Expr*> dims;
- std::vector<Var*> args;
+ std::vector<ExprPtr> dims;
+ std::vector<VarPtr> args;
unpack_dim_args(dim_args, &dims, &args);
- Expr* body = body_func(
- VarHandle(args[0]),
- VarHandle(args[1]),
- VarHandle(args[2]),
- VarHandle(args[3]))
- .node();
- Buf* buf = new Buf(name, dims, body->dtype());
+ ExprPtr body = body_func(
+ VarHandle(args[0]),
+ VarHandle(args[1]),
+ VarHandle(args[2]),
+ VarHandle(args[3]))
+ .node();
+ BufPtr buf = alloc<Buf>(name, dims, body->dtype());
return new Tensor(buf, args, body);
}
class TORCH_API Tensor : KernelScopedObject {
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
- Tensor(Buf* buf, const std::vector<Var*>& args, Expr* body) : buf_(buf) {
+ Tensor(BufPtr buf, const std::vector<VarPtr>& args, ExprPtr body)
+ : buf_(buf) {
stmt_ = constructStmt(args, body, {}, {});
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Tensor(
- Buf* buf,
- const std::vector<Var*>& args,
- const std::vector<Expr*>& reduce_dims,
- const std::vector<Var*>& reduce_args,
- Expr* body)
+ BufPtr buf,
+ const std::vector<VarPtr>& args,
+ const std::vector<ExprPtr>& reduce_dims,
+ const std::vector<VarPtr>& reduce_args,
+ ExprPtr body)
: buf_(buf) {
stmt_ = constructStmt(args, body, reduce_dims, reduce_args);
}
- Tensor(Buf* buf, Stmt* stmt) : buf_(buf), stmt_(stmt) {}
+ Tensor(BufPtr buf, StmtPtr stmt) : buf_(buf), stmt_(stmt) {}
- Buf* buf() const {
+ BufPtr buf() const {
return buf_;
}
- Stmt* stmt() const {
+ StmtPtr stmt() const {
return stmt_;
}
inline ExprHandle load(const Ts&... ts);
private:
- Stmt* constructStmt(
- const std::vector<Var*>& args,
- Expr* body,
- const std::vector<Expr*>& reduce_dims,
- const std::vector<Var*>& reduce_args) const;
-
- Buf* buf_;
- Stmt* stmt_;
+ StmtPtr constructStmt(
+ const std::vector<VarPtr>& args,
+ ExprPtr body,
+ const std::vector<ExprPtr>& reduce_dims,
+ const std::vector<VarPtr>& reduce_args) const;
+
+ BufPtr buf_;
+ StmtPtr stmt_;
};
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
explicit Placeholder(const std::vector<ExprHandle>& dims)
: Placeholder(BufHandle("_", dims, kFloat)) {}
- Buf* data() const {
+ BufPtr data() const {
return data_;
}
BufHandle handle() const {
int ndim() const {
return data_->ndim();
}
- Expr* dim(int index) const {
+ ExprPtr dim(int index) const {
return data_->dim(index);
}
- std::vector<Expr*> dims() const {
+ std::vector<ExprPtr> dims() const {
return data_->dims();
}
inline ExprHandle load(const std::vector<ExprHandle>& args) const;
- inline Store* store(
+ inline StorePtr store(
const std::vector<ExprHandle>& args,
const ExprHandle& val) const {
- return new Store(data(), ExprHandleVectorToExprVector(args), val.node());
+ return alloc<Store>(data(), ExprHandleVectorToExprVector(args), val.node());
}
private:
- Buf* data_;
- std::vector<Expr*> strides_;
+ BufPtr data_;
+ std::vector<ExprPtr> strides_;
};
TORCH_API Tensor* Compute(
inline void unpack_dim_args(
const std::vector<DimArg>& dim_args,
- std::vector<Expr*>* dims,
- std::vector<Var*>* vars) {
+ std::vector<ExprPtr>* dims,
+ std::vector<VarPtr>* vars) {
dims->clear();
vars->clear();
for (const DimArg& dim_arg : dim_args) {
- Expr* expr = dim_arg.dim().node();
+ ExprPtr expr = dim_arg.dim().node();
dims->push_back(expr);
- vars->push_back(new Var(
+ vars->push_back(alloc<Var>(
dim_arg.name_hint(),
expr->dtype().scalar_type() == ScalarType::Long ? kLong : kInt));
}
const BodyFunc& body_func,
const std::vector<DimArg>& reduce_args) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::vector<Expr*> dims;
+ std::vector<ExprPtr> dims;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::vector<Var*> vars;
+ std::vector<VarPtr> vars;
unpack_dim_args(dim_args, &dims, &vars);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::vector<Expr*> reduce_dims;
+ std::vector<ExprPtr> reduce_dims;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::vector<Var*> reduce_vars;
+ std::vector<VarPtr> reduce_vars;
unpack_dim_args(reduce_args, &reduce_dims, &reduce_vars);
// If reduce_vars is empty, then it's not a reduction, but rather a simple
// copy
if (reduce_vars.empty()) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- Expr* body =
+ ExprPtr body =
Reducer::getReduceBody(body_func, VarVectorToVarHandleVector(vars))
.node();
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- Buf* func_result = new Buf(func_name, dims, body->dtype());
+ BufPtr func_result = alloc<Buf>(func_name, dims, body->dtype());
return new Tensor(func_result, vars, body);
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::vector<Var*> all_vars;
+ std::vector<VarPtr> all_vars;
all_vars.insert(all_vars.end(), vars.begin(), vars.end());
all_vars.insert(all_vars.end(), reduce_vars.begin(), reduce_vars.end());
ExprHandle body =
Reducer::getReduceBody(body_func, VarVectorToVarHandleVector(all_vars));
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::vector<Expr*> output_args(vars.begin(), vars.end());
+ std::vector<ExprPtr> output_args(vars.begin(), vars.end());
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- Expr* init_expr = new Cast(
+ ExprPtr init_expr = alloc<Cast>(
body.dtype(), init_func(VarVectorToVarHandleVector(vars)).node());
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- Buf* func_result = new Buf(func_name, dims, body.dtype(), init_expr);
+ BufPtr func_result = alloc<Buf>(func_name, dims, body.dtype(), init_expr);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- ReduceOp* reduce_op = reducer(func_result, body, output_args, reduce_vars);
+ ReduceOpPtr reduce_op = reducer(func_result, body, output_args, reduce_vars);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
Tensor* t =
new Tensor(func_result, vars, reduce_dims, reduce_vars, reduce_op);
inline ExprHandle Placeholder::load(const Ts&... ts) const {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<ExprHandle> params({ExprHandle(ts)...});
- return ExprHandle(new Load(data(), ExprHandleVectorToExprVector(params)));
+ return ExprHandle(alloc<Load>(data(), ExprHandleVectorToExprVector(params)));
}
template <typename T>
inline ExprHandle Placeholder::load(const std::vector<T>& args) const {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<ExprHandle> params(args.begin(), args.end());
- return ExprHandle(new Load(data(), ExprHandleVectorToExprVector(params)));
+ return ExprHandle(alloc<Load>(data(), ExprHandleVectorToExprVector(params)));
}
inline ExprHandle Placeholder::load(const std::vector<ExprHandle>& args) const {
py::return_value_policy::reference);
py::class_<Tensor, std::unique_ptr<Tensor, py::nodelete>>(te, "Tensor")
.def(py::init(
- [](BufHandle& b, Stmt* s) { return new Tensor(b.node(), s); }))
+ [](BufHandle& b, StmtPtr s) { return new Tensor(b.node(), s); }))
.def(
"load",
[](Tensor& self, const std::vector<ExprHandle>& v) {
py::return_value_policy::reference);
py::class_<Stmt, std::unique_ptr<Stmt, py::nodelete>>(te, "Stmt")
- .def(py::init([](const std::vector<Stmt*>& stmts) {
+ .def(py::init([](const std::vector<StmtPtr>& stmts) {
return tensorexpr::Block::make(stmts);
}))
.def("__str__", [](Stmt& self) {
[](const VarHandle& var,
const ExprHandle& start,
const ExprHandle& stop,
- Stmt* body) { return For::make(var, start, stop, body); },
+ StmtPtr body) { return For::make(var, start, stop, body); },
py::return_value_policy::reference);
py::class_<Cond, Stmt, std::unique_ptr<Cond, py::nodelete>>(te, "Cond")
.def_static(
"make",
- [](const ExprHandle& condition, Stmt* true_stmt, Stmt* false_stmt) {
- return new Cond(condition.node(), true_stmt, false_stmt);
+ [](const ExprHandle& condition,
+ StmtPtr true_stmt,
+ StmtPtr false_stmt) {
+ return alloc<Cond>(condition.node(), true_stmt, false_stmt);
},
py::return_value_policy::reference)
.def("true_stmt", &Cond::true_stmt, py::return_value_policy::reference)
tensorexpr::Block,
Stmt,
std::unique_ptr<tensorexpr::Block, py::nodelete>>(te, "Block")
- .def(py::init([](const std::vector<Stmt*>& stmts) {
+ .def(py::init([](const std::vector<StmtPtr>& stmts) {
return tensorexpr::Block::make(stmts);
}))
.def(
py::class_<LoopNest>(te, "LoopNest")
.def(py::init<const std::vector<Tensor*>&>())
- .def(py::init([](Stmt* s, const std::vector<BufHandle>& bufs) {
- std::unordered_set<Buf*> buf_nodes;
+ .def(py::init([](StmtPtr s, const std::vector<BufHandle>& bufs) {
+ std::unordered_set<BufPtr> buf_nodes;
for (auto& buf : bufs) {
buf_nodes.insert(buf.node());
}
py::return_value_policy::reference)
.def(
"get_enclosing_loopnest",
- [](const LoopNest& self, Stmt* s) {
+ [](const LoopNest& self, StmtPtr s) {
return self.getEnclosingLoopNest(s);
},
py::return_value_policy::reference)
py::return_value_policy::reference)
.def(
"get_loop_at",
- [](const LoopNest& self, For* root, const std::vector<int>& indices) {
+ [](const LoopNest& self,
+ ForPtr root,
+ const std::vector<int>& indices) {
return self.getLoopAt(root, indices);
},
py::return_value_policy::reference)
.def(
"get_parent_loop",
- [](const LoopNest& self, Stmt* s) { return self.getParentLoop(s); },
+ [](const LoopNest& self, StmtPtr s) { return self.getParentLoop(s); },
py::return_value_policy::reference)
.def_static(
"get_loop_stmts_in_loopnest",
- [](For* f, size_t num) {
+ [](ForPtr f, size_t num) {
return LoopNest::getLoopStmtsInLoopNest(f, num);
},
py::return_value_policy::reference)
.def(
"split_with_tail",
- [](For* f, int factor) {
- For *inner = nullptr, *tail = nullptr;
+ [](ForPtr f, int factor) {
+ ForPtr inner = nullptr, tail = nullptr;
LoopNest::splitWithTail(f, factor, &inner, &tail);
return std::make_tuple(inner, tail);
},
py::return_value_policy::reference)
.def(
"split_with_mask",
- [](For* f, int factor) {
- For* inner = nullptr;
+ [](ForPtr f, int factor) {
+ ForPtr inner = nullptr;
LoopNest::splitWithMask(f, factor, &inner);
return inner;
},
py::return_value_policy::reference)
.def(
"slice_head",
- [](For* f, int factor) {
- For *head = nullptr, *tail = nullptr;
+ [](ForPtr f, int factor) {
+ ForPtr head = nullptr, tail = nullptr;
LoopNest::sliceHead(f, factor, &head, &tail);
return std::make_tuple(head, tail);
},
py::return_value_policy::reference)
.def(
"slice_tail",
- [](For* f, int factor) {
- For *head = nullptr, *tail = nullptr;
+ [](ForPtr f, int factor) {
+ ForPtr head = nullptr, tail = nullptr;
LoopNest::sliceTail(f, factor, &head, &tail);
return std::make_tuple(head, tail);
},
py::return_value_policy::reference)
.def_static(
"normalize",
- [](For* f) {
+ [](ForPtr f) {
LoopNest::normalize(f);
return f;
},
py::return_value_policy::reference)
.def(
"tile",
- [](LoopNest& self, For* x, For* y, int x_factor, int y_factor) {
+ [](LoopNest& self, ForPtr x, ForPtr y, int x_factor, int y_factor) {
return self.tile(x, y, x_factor, y_factor);
},
py::return_value_policy::reference)
.def_static(
"distribute_loop",
- [](For* f) { return LoopNest::distributeLoop(f); },
+ [](ForPtr f) { return LoopNest::distributeLoop(f); },
py::return_value_policy::reference)
.def_static(
"distribute_loop",
- [](For* f, const std::unordered_set<Stmt*>& pivots) {
+ [](ForPtr f, const std::unordered_set<StmtPtr>& pivots) {
return LoopNest::distributeLoop(f, pivots);
},
py::return_value_policy::reference)
.def_static(
"distribute_loop_over_inner_loops",
- [](For* f) { return LoopNest::distributeLoopOverInnerLoops(f); },
+ [](ForPtr f) { return LoopNest::distributeLoopOverInnerLoops(f); },
py::return_value_policy::reference)
.def_static(
"unsafe_fuse_loops",
- [](const std::vector<For*>& loops) {
- For* fused_loop = nullptr;
+ [](const std::vector<ForPtr>& loops) {
+ ForPtr fused_loop = nullptr;
LoopNest::unsafeFuseLoops(loops, &fused_loop);
return fused_loop;
},
py::return_value_policy::reference)
.def_static(
"fuse_loops",
- [](const std::vector<For*>& loops) {
- For* fused_loop = nullptr;
+ [](const std::vector<ForPtr>& loops) {
+ ForPtr fused_loop = nullptr;
LoopNest::fuseLoops(loops, &fused_loop);
return fused_loop;
},
py::return_value_policy::reference)
.def_static(
"reorder",
- [](const std::vector<For*>& loops,
+ [](const std::vector<ForPtr>& loops,
const std::vector<size_t>& permutation) {
return LoopNest::reorder(loops, permutation);
},
py::return_value_policy::reference)
.def(
"unroll",
- [](const LoopNest& self, For* f) {
- Stmt* unrolled = nullptr;
+ [](const LoopNest& self, ForPtr f) {
+ StmtPtr unrolled = nullptr;
self.unroll(f, &unrolled);
return unrolled;
},
py::return_value_policy::reference)
.def(
"vectorize",
- [](For* f) { LoopNest::vectorize(f); },
+ [](ForPtr f) { LoopNest::vectorize(f); },
py::return_value_policy::reference)
.def_static(
"compress_buffer",
- [](BufHandle& buf, Stmt* stmt) {
+ [](BufHandle& buf, StmtPtr stmt) {
return LoopNest::compressBuffer(buf.node(), stmt);
},
py::return_value_policy::reference)
"cache_accesses",
[](const BufHandle& producer,
const std::string& name,
- Stmt* consumer) {
- std::pair<Buf*, Stmt*> ret =
+ StmtPtr consumer) {
+ std::pair<BufPtr, StmtPtr> ret =
LoopNest::cacheAccesses(producer.node(), name, consumer);
return std::make_pair(BufHandle(ret.first), ret.second);
},
py::return_value_policy::reference)
- .def("compute_at", [](Stmt* s, For* at) { LoopNest::computeAt(s, at); })
+ .def(
+ "compute_at",
+ [](StmtPtr s, ForPtr at) { LoopNest::computeAt(s, at); })
.def(
"compute_inline",
- [](LoopNest& self, Stmt* s) { self.computeInline(s); },
+ [](LoopNest& self, StmtPtr s) { self.computeInline(s); },
py::return_value_policy::reference)
.def(
"compute_inline",
py::return_value_policy::reference)
.def(
"rfactor",
- [](Stmt* s, For* target_for) {
- Buf* rfac_buf = nullptr;
+ [](StmtPtr s, ForPtr target_for) {
+ BufPtr rfac_buf = nullptr;
LoopNest::rfactor(s, target_for, &rfac_buf);
return BufHandle(rfac_buf);
},
py::return_value_policy::reference)
.def(
"flatten",
- [](const std::vector<For*>& loops) {
- For* flattened = nullptr;
+ [](const std::vector<ForPtr>& loops) {
+ ForPtr flattened = nullptr;
LoopNest::flatten(loops, &flattened);
return flattened;
},
te.def(
"simplify",
- [](Stmt* stmt) { return IRSimplifier::simplify(stmt); },
+ [](StmtPtr stmt) { return IRSimplifier::simplify(stmt); },
py::return_value_policy::reference);
te.def(
te.def(
"construct_codegen",
[](const std::string& name,
- Stmt* stmt,
+ StmtPtr stmt,
const std::vector<CodeGen::BufferArg>& args) {
CodeGen* cg = nullptr;
if (name == "llvm") {
namespace jit {
namespace tensorexpr {
-const std::string& UniqueNameManager::get_unique_name(Var* v) {
+const std::string& UniqueNameManager::get_unique_name(VarPtr v) {
// Find if we have already encountered this variable.
auto iter = unique_name_mapping_.find(v);
if (iter != unique_name_mapping_.end()) {
#include <unordered_set>
#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
namespace torch {
namespace jit {
class VarHandle;
class Var;
-using VarNameMap = std::unordered_map<Var*, std::string>;
+using VarNameMap = std::unordered_map<VarPtr, std::string>;
// A manager to get unique names from vars.
// It starts with the name hints of the var and append "_" + $counter until it
public:
const std::string& get_unique_name(const VarHandle& v);
- const std::string& get_unique_name(Var* v);
+ const std::string& get_unique_name(VarPtr v);
private:
friend class ScopedVarName;
namespace jit {
namespace tensorexpr {
-using VarMapping = std::vector<std::pair<Var*, Expr*>>;
+using VarMapping = std::vector<std::pair<VarPtr, ExprPtr>>;
class VarSubMutator : public IRMutator {
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
VarSubMutator(const VarMapping& var_mapping) {
for (auto& entry : var_mapping) {
- Var* key_var = entry.first;
- Expr* value = entry.second;
+ VarPtr key_var = entry.first;
+ ExprPtr value = entry.second;
if (!key_var) {
throw malformed_input("missing key in VarSubMutator");
}
}
}
- Expr* mutate(Var* var) override {
+ ExprPtr mutate(VarPtr var) override {
auto iter = var_mapping_.find(var);
if (iter == var_mapping_.end()) {
return var;
return iter->second;
}
- Expr* mutate(ReduceOp* var) override {
+ ExprPtr mutate(ReduceOpPtr var) override {
auto body = var->body()->accept_mutator(this);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::vector<Var*> new_inner;
+ std::vector<VarPtr> new_inner;
- for (auto* v : var->reduce_args()) {
- Expr* e = v->accept_mutator(this);
- if (Var* new_var = dynamic_cast<Var*>(e)) {
+ for (auto v : var->reduce_args()) {
+ ExprPtr e = v->accept_mutator(this);
+ if (VarPtr new_var = to<Var>(e)) {
new_inner.push_back(new_var);
} else {
VarFinder varFinder;
}
}
- return new ReduceOp(body, new_inner, var->reducer());
+ return alloc<ReduceOp>(body, new_inner, var->reducer());
}
private:
- std::unordered_map<Var*, Expr*> var_mapping_;
+ std::unordered_map<VarPtr, ExprPtr> var_mapping_;
};
} // namespace tensorexpr