torch::jit::testing::FileCheck().check("isnan")->run(oss.str());
}
+// Verifies that a simple elementwise graph over contiguous 100x16 tensors is
+// vectorized by the LLVM-backed codegen: the lowered statement must contain a
+// Ramp node, and the kernel output must match the eager-mode reference.
+TEST_F(Kernel, Vectorize) {
+#ifdef TORCH_ENABLE_LLVM
+ const auto graph_string = R"IR(
+ graph(%0 : Float(100, 16, strides=[16, 1], device=cpu),
+ %1 : Float(100, 16, strides=[16, 1], device=cpu)):
+ %2 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %1)
+ %3 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %2)
+ return (%3))IR";
+ auto graph = std::make_shared<Graph>();
+ parseIR(graph_string, &*graph);
+
+ auto a = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat));
+ auto b = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat));
+ auto o = at::zeros({100, 16}, TensorOptions(kCPU).dtype(at::kFloat));
+ // Eager reference for %3 = %0 * (%0 * %1); same elementwise evaluation
+ // order as the kernel, so exact float equality below is expected.
+ auto ref = a * (a * b);
+ TensorExprKernel k(graph);
+ std::vector<at::Tensor> inputs = {a, b};
+ StmtPtr s = k.getCodeGenStmt();
+
+ std::ostringstream oss;
+ oss << *s;
+
+ // Check the IR we produced
+ // A Ramp node only appears when a loop was actually vectorized.
+ const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR";
+ torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
+
+ std::vector<IValue> stack = fmap<IValue>(inputs);
+ k.run(stack);
+ // k.run() leaves the kernel's outputs on the stack; take the first.
+ o = stack[0].toTensor();
+ for (size_t i = 0; i < 100 * 16; i++) {
+ CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
+ }
+#endif
+}
+
+// TODO: To vectorize loopnest for 100x3 case, we need to flatten loops first.
+// Disabled: the inner dimension (3) is too small to vectorize on its own, so
+// this case requires loop flattening before vectorization can apply. The body
+// mirrors the Vectorize test above but on 100x3 tensors.
+TEST_F(Kernel, DISABLED_FlattenVectorize) {
+#ifdef TORCH_ENABLE_LLVM
+ const auto graph_string = R"IR(
+ graph(%0 : Float(100, 3, strides=[3, 1], device=cpu),
+ %1 : Float(100, 3, strides=[3, 1], device=cpu)):
+ %2 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %1)
+ %3 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %2)
+ return (%3))IR";
+ auto graph = std::make_shared<Graph>();
+ parseIR(graph_string, &*graph);
+
+ auto a = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat));
+ auto b = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat));
+ auto o = at::zeros({100, 3}, TensorOptions(kCPU).dtype(at::kFloat));
+ // Eager reference for %3 = %0 * (%0 * %1).
+ auto ref = a * (a * b);
+ TensorExprKernel k(graph);
+ std::vector<at::Tensor> inputs = {a, b};
+ StmtPtr s = k.getCodeGenStmt();
+
+ std::ostringstream oss;
+ oss << *s;
+
+ // Check the IR we produced
+ // Once flattening+vectorization works, a Ramp node must be present.
+ const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR";
+ torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
+
+ std::vector<IValue> stack = fmap<IValue>(inputs);
+ k.run(stack);
+ o = stack[0].toTensor();
+ for (size_t i = 0; i < 100 * 3; i++) {
+ CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
+ }
+#endif
+}
+
} // namespace jit
} // namespace torch
auto start_imm = intValue(start);
auto stop_imm = intValue(stop);
if (!start_imm) {
- throw std::runtime_error(
- "Can't vectorize due to non-constant loop start!");
+ // Can't vectorize due to non-constant loop start!
+ success_ = false;
+ return v;
}
if (!stop_imm) {
- throw std::runtime_error(
- "Can't vectorize due to non-constant loop stop!");
+ // Can't vectorize due to non-constant loop stop!
+ success_ = false;
+ return v;
}
var_ = var;
StmtPtr new_body = body->accept_mutator(this);
if (new_body == body) {
- throw std::runtime_error("Vectorization failed!");
+ // Vectorization failed!
+ success_ = false;
+ return v;
}
return new_body;
}
+ // Returns false if any mutation step found the loop unvectorizable.
+ // With this change the vectorizer no longer throws on failure, so callers
+ // must consult success() after vectorize() to know whether to keep the
+ // transformed statement or fall back to the original loop.
+ bool success() const {
+ return success_;
+ }
+
ExprPtr mutate(AddPtr v) override {
std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
return try_vectorize(v, inputs, [&]() {
ExprPtr mutate(VarPtr v) override {
if (v == var_) {
- return Ramp::make(ExprHandle(start_), 1, lanes_).node();
+ return Ramp::make(
+ ExprHandle(start_), ExprHandle(immLike(start_, 1)), lanes_)
+ .node();
}
return v;
return v;
}
- throw std::runtime_error("Can't vectorize a Ramp!");
+ // Can't vectorize a Ramp!
+ success_ = false;
+ return v;
}
ExprPtr mutate(LoadPtr v) override {
return v;
}
- throw std::runtime_error("Can't vectorize a Broadcast!");
+ // Can't vectorize a Broadcast!
+ success_ = false;
+ return v;
}
ExprPtr mutate(IfThenElsePtr v) override {
ExprPtr condition = v->condition();
ExprPtr new_condition = condition->accept_mutator(this);
if (new_condition != condition) {
- throw std::runtime_error("Can't vectorize an IfThenElse condition!");
+ // Can't vectorize an IfThenElse condition!
+ success_ = false;
+ return v;
}
std::vector<ExprPtr> inputs = {v->true_value(), v->false_value()};
ExprPtr new_stop = stop->accept_mutator(this);
if (new_start != start || new_stop != stop) {
- throw std::runtime_error(
- "Can't vectorize nested For with dependent loop bounds!");
+ // Can't vectorize nested For with dependent loop bounds!
+ success_ = false;
+ return v;
}
StmtPtr body = v->body();
VarPtr var_ = nullptr;
int lanes_ = 0;
ExprPtr start_ = nullptr;
+ bool success_ = true;
};
bool LoopNest::vectorize(ForPtr f) {
Vectorizer v;
StmtPtr new_f = nullptr;
- try {
- new_f = Stmt::clone(f);
- normalize(to<For>(new_f));
- new_f = FlattenIndexes(new_f);
- new_f = v.vectorize(to<For>(new_f));
- } catch (std::runtime_error& e) {
+ new_f = Stmt::clone(f);
+ normalize(to<For>(new_f));
+ new_f = FlattenIndexes(new_f);
+ new_f = v.vectorize(to<For>(new_f));
+ if (!v.success()) {
    // We clone f before vectorizing. So, any partial vectorization will
    // have modified the clone. If vectorization failed (success() is
    // false), we can continue using f.