[NNC] Add C++ codegen backend to NNC (#62869)
authorCheng Chang <xcc@fb.com>
Thu, 26 Aug 2021 16:52:42 +0000 (09:52 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 26 Aug 2021 16:56:37 +0000 (09:56 -0700)
Summary:
Adds a C++ codegen backend to NNC to generate C++ for CPU instead of generating LLVM IR.
Tensors are represented as blobs of float. Vector operations are devectorized/unrolled.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62869

Test Plan:
https://github.com/pytorch/pytorch/tree/mvz-nnc-aot-prototype makes it able to AOT compile the whole MobileNetV3 model into binary code through LLVM codegen in NNC.

I forked that branch to https://github.com/cheng-chang/pytorch/tree/cc-aot-cpp, merged this PR into it, and modified `fancy_compile` to compile MobileNetV3 into C++ through

```
import torch

m = torch.jit.load('mobnet.pt')
m.eval()
f = torch.jit.freeze(m)
torch._C._fancy_compile(f.graph, [1, 3, 224, 224])
```

The generated C++ file `mobnet.cc` can be found at https://gist.github.com/cheng-chang/e2830cc6920b39204ebf368035b2bcec.

I manually compiled the generated C++ through `g++ -o mobnet -std=c++14 -L./build/lib -ltorch_cpu -ltorch mobnet.cc`, and it succeeded.

Reviewed By: ZolotukhinM

Differential Revision: D30149482

Pulled By: cheng-chang

fbshipit-source-id: e77b189f0353e37cd309423a48a513e668d07675

test/cpp/tensorexpr/test_cpp_codegen.cpp
torch/csrc/jit/tensorexpr/codegen.h
torch/csrc/jit/tensorexpr/cpp_codegen.cpp
torch/csrc/jit/tensorexpr/cpp_codegen.h
torch/csrc/jit/tensorexpr/cpp_intrinsics.h [new file with mode: 0644]
torch/csrc/jit/tensorexpr/expr.h
torch/csrc/jit/tensorexpr/ir.h
torch/csrc/jit/tensorexpr/ir_printer.cpp
torch/csrc/jit/tensorexpr/ir_printer.h

index df9166b..d40caa1 100644 (file)
@@ -1,9 +1,11 @@
 #include <gtest/gtest.h>
 
-#include <test/cpp/tensorexpr/test_base.h>
+#include "test/cpp/tensorexpr/test_base.h"
 
 #include <torch/csrc/jit/tensorexpr/cpp_codegen.h>
+#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
 #include <torch/csrc/jit/tensorexpr/stmt.h>
+#include <torch/csrc/jit/tensorexpr/tensor.h>
 #include <torch/csrc/jit/testing/file_check.h>
 
 namespace torch {
@@ -11,44 +13,245 @@ namespace jit {
 
 using namespace torch::jit::tensorexpr;
 
-TEST(CppPrinter, AllocateOnStackThenFree) {
-  std::vector<ExprPtr> dims = {alloc<IntImm>(2), alloc<IntImm>(3)};
-  BufPtr buf = alloc<Buf>("x", dims, kInt);
-  AllocatePtr alloc_ = alloc<Allocate>(buf);
-  FreePtr free_ = alloc<Free>(buf);
-  BlockPtr block = Block::make({alloc_, free_});
-
-  std::stringstream ss;
-  CppPrinter printer(&ss);
-  printer.visit(block);
-  const std::string expected = R"(
-    # CHECK: {
-    # CHECK:   int x[6];
-    # CHECK: }
+#define STR_CHECK(node, expected) \
+  std::stringstream ss;           \
+  CppPrinter printer(&ss);        \
+  printer.visit(node);            \
+  ASSERT_EQ(ss.str(), expected)
+
+#define FILE_CHECK(node, pattern) \
+  std::stringstream ss;           \
+  CppPrinter printer(&ss);        \
+  printer.visit(node);            \
+  torch::jit::testing::FileCheck().run(pattern, ss.str())
+
+TEST(CppPrinter, IntImm) {
+  auto i = alloc<IntImm>(10);
+  STR_CHECK(i, "10");
+}
+
+TEST(CppPrinter, FloatImm) {
+  auto f = alloc<FloatImm>(10);
+  STR_CHECK(f, "10.f");
+}
+
+TEST(CppPrinter, FloatImm1) {
+  auto f = alloc<FloatImm>(10);
+  STR_CHECK(f, "10.f");
+}
+
+TEST(CppPrinter, DoubleImm) {
+  auto d = alloc<DoubleImm>(10);
+  STR_CHECK(d, "10.0");
+}
+
+TEST(CppPrinter, DoubleImm1) {
+  auto d = alloc<DoubleImm>(10.1);
+  STR_CHECK(d, "10.1");
+}
+
+TEST(CppPrinter, HalfImm) {
+  auto h = alloc<HalfImm>(10);
+  STR_CHECK(h, "10");
+}
+
+TEST(CppPrinter, Add) {
+  auto add = alloc<Add>(alloc<IntImm>(1), alloc<IntImm>(2));
+  STR_CHECK(add, "1 + 2");
+}
+
+TEST(CppPrinter, AddExpr1) {
+  auto add = alloc<Add>(
+      alloc<Add>(alloc<IntImm>(0), alloc<IntImm>(1)),
+      alloc<Sub>(alloc<IntImm>(2), alloc<IntImm>(3)));
+  STR_CHECK(add, "(0 + 1) + (2 - 3)");
+}
+
+TEST(CppPrinter, AddExpr2) {
+  auto add = alloc<Add>(
+      alloc<Mul>(alloc<IntImm>(0), alloc<IntImm>(1)),
+      alloc<Sub>(alloc<IntImm>(2), alloc<IntImm>(3)));
+  STR_CHECK(add, "0 * 1 + (2 - 3)");
+}
+
+TEST(CppPrinter, AddExpr3) {
+  auto add = alloc<Add>(
+      alloc<Add>(alloc<IntImm>(0), alloc<IntImm>(1)),
+      alloc<Div>(alloc<IntImm>(2), alloc<IntImm>(3)));
+  STR_CHECK(add, "(0 + 1) + 2 / 3");
+}
+
+TEST(CppPrinter, Mod) {
+  auto mod = alloc<Mod>(alloc<IntImm>(1), alloc<IntImm>(2));
+  STR_CHECK(mod, "1 % 2");
+}
+
+TEST(CppPrinter, ModFloat) {
+  auto mod = alloc<Mod>(alloc<FloatImm>(1), alloc<FloatImm>(2));
+  STR_CHECK(mod, "std::fmod(1.f, 2.f)");
+}
+
+TEST(CppPrinter, Max) {
+  auto max = alloc<Max>(alloc<IntImm>(1), alloc<IntImm>(2), false);
+  STR_CHECK(max, "std::max(1, 2)");
+}
+
+TEST(CppPrinter, MaxFloat) {
+  auto max = alloc<Max>(alloc<FloatImm>(1), alloc<FloatImm>(2), false);
+  STR_CHECK(max, "std::max(1.f, 2.f)");
+}
+
+TEST(CppPrinter, MaxHalf) {
+  auto max = alloc<Max>(alloc<HalfImm>(1), alloc<HalfImm>(2), false);
+  STR_CHECK(max, "(1 < 2) ? 2 : 1");
+}
+
+TEST(CppPrinter, And) {
+  auto v = alloc<And>(alloc<IntImm>(1), alloc<IntImm>(2));
+  STR_CHECK(v, "1 & 2");
+}
+
+TEST(CppPrinter, CompareSelect) {
+  auto cs = alloc<CompareSelect>(
+      alloc<IntImm>(1),
+      alloc<IntImm>(2),
+      alloc<FloatImm>(1),
+      alloc<FloatImm>(2),
+      CompareSelectOperation::kLE);
+  STR_CHECK(cs, "((1 <= 2) ? 1.f : 2.f)");
+}
+
+TEST(CppPrinter, IfThenElse) {
+  auto cond = alloc<Add>(alloc<IntImm>(1), alloc<IntImm>(2));
+  auto true_value = alloc<Sub>(alloc<IntImm>(0), alloc<IntImm>(1));
+  auto false_value = alloc<Mul>(alloc<IntImm>(2), alloc<IntImm>(3));
+  auto v = alloc<IfThenElse>(cond, true_value, false_value);
+  STR_CHECK(v, "((1 + 2) ? 0 - 1 : 2 * 3)");
+}
+
+TEST(CppPrinter, AllocateFree) {
+  BufHandle buf("x", {2, 3}, kInt);
+  AllocatePtr alloc = Allocate::make(buf);
+  FreePtr free = Free::make(buf);
+  BlockPtr block = Block::make({alloc, free});
+
+  const std::string pattern = R"(
+   # CHECK: {
+   # CHECK:   int* x = static_cast<int*>(malloc(24));
+   # CHECK:   free(x);
+   # CHECK: }
+  )";
+  FILE_CHECK(block, pattern);
+}
+
+TEST(CppPrinter, LoadStore) {
+  Placeholder a(BufHandle("A", {2, 3}, kInt));
+  Placeholder b(BufHandle("B", {3, 4}, kInt));
+  auto store = b.store({2, 2}, a.load(1, 1));
+  STR_CHECK(
+      store, "B[(0 + 2 * (1 * 4)) + 2 * 1] = A[(0 + 1 * (1 * 3)) + 1 * 1];\n");
+}
+
+TEST(CppPrinter, Var) {
+  auto var = alloc<Var>("x", kInt);
+  STR_CHECK(var, "x");
+}
+
+TEST(CppPrinter, Cast) {
+  auto cast = alloc<Cast>(kFloat, alloc<IntImm>(1));
+  STR_CHECK(cast, "static_cast<float>(1)");
+}
+
+TEST(CppPrinter, BitCast) {
+  auto cast = alloc<BitCast>(kInt, alloc<FloatImm>(20));
+  STR_CHECK(cast, "std::bitcast<float, int>(20.f)");
+}
+
+TEST(CppPrinter, Let) {
+  auto var = alloc<Var>("x", kFloat);
+  auto val = alloc<FloatImm>(2);
+  auto let = alloc<Let>(var, val);
+  STR_CHECK(let, "float x = 2.f;\n");
+}
+
+TEST(CppPrinter, For) {
+  constexpr int N = 1024;
+  Placeholder a(BufHandle("A", {N}, kInt));
+  Placeholder b(BufHandle("B", {N}, kInt));
+  Placeholder c(BufHandle("C", {N}, kInt));
+  VarHandle i("i", kInt);
+  auto f = For::make(i, 0, N, c.store({i}, Add::make(a.load(i), b.load(i))));
+  const std::string pattern = R"(
+   # CHECK: for (int i = 0; i < 1024; i++) {
+   # CHECK:   C[i] = (A[i]) + (B[i]);
+   # CHECK: }
   )";
-  torch::jit::testing::FileCheck().run(expected, ss.str());
-}
-
-TEST(CppPrinter, AllocateOnHeapThenFree) {
-  std::vector<ExprPtr> dims = {
-      alloc<IntImm>(20), alloc<IntImm>(50), alloc<IntImm>(3)};
-  BufPtr buf = alloc<Buf>("y", dims, kLong);
-  AllocatePtr alloc_ = alloc<Allocate>(buf);
-  FreePtr free_ = alloc<Free>(buf);
-  BlockPtr block = Block::make({alloc_, free_});
-
-  std::stringstream ss;
-  CppPrinter printer(&ss);
-  printer.visit(block);
-  // size(long) = 8;
-  // dim0 * dim1 * dim2 * size(long) = 24000.
-  const std::string expected = R"(
-    # CHECK: {
-    # CHECK:   int64_t* y = static_cast<int64_t*>(malloc(24000));
-    # CHECK:   free(y);
+  FILE_CHECK(f, pattern);
+}
+
+TEST(CppPrinter, Cond) {
+  Placeholder x(BufHandle("X", {1}, kInt));
+  auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT);
+  auto cond =
+      Cond::make(cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1));
+  const std::string pattern = R"(
+    # CHECK: if (((X[0] < 10) ? 1 : 0)) {
+    # CHECK:   X[0] = (X[0]) + 1;
+    # CHECK: } else {
+    # CHECK:   X[0] = (X[0]) - 1;
     # CHECK: }
   )";
-  torch::jit::testing::FileCheck().run(expected, ss.str());
+  FILE_CHECK(cond, pattern);
+}
+
+TEST(CppPrinter, Intrinsics) {
+  const std::unordered_set<IntrinsicsOp, std::hash<int>> unsupported_ops{
+      kRand, kSigmoid};
+  for (int i = 0; i < kMaxIntrinsicsOp; i++) {
+    IntrinsicsOp op = static_cast<IntrinsicsOp>(i);
+    if (unsupported_ops.count(op)) {
+      continue;
+    }
+
+    if (Intrinsics::OpArgCount(op) == 1) {
+      auto v = alloc<Intrinsics>(op, alloc<FloatImm>(2.0f));
+      STR_CHECK(v, "std::" + v->func_name() + "(2.f)");
+    } else {
+      auto v =
+          alloc<Intrinsics>(op, alloc<FloatImm>(1.0f), alloc<FloatImm>(2.0f));
+      STR_CHECK(v, "std::" + v->func_name() + "(1.f, 2.f)");
+    }
+  }
+}
+
+TEST(CppPrinter, ExternalCall) {
+  std::vector<ExprPtr> dims{alloc<IntImm>(2), alloc<IntImm>(2)};
+  auto output = alloc<Buf>("out", dims, kFloat);
+  auto buf_arg1 = alloc<Buf>("a", dims, kFloat);
+  auto buf_arg2 = alloc<Buf>("b", dims, kFloat);
+  auto scalar_arg = alloc<Add>(alloc<IntImm>(1), alloc<IntImm>(2));
+  std::vector<BufPtr> buf_args{buf_arg1, buf_arg2};
+  std::vector<ExprPtr> scalar_args{scalar_arg};
+  auto call =
+      alloc<ExternalCall>(output, "nnc_aten_matmul", buf_args, scalar_args);
+  const std::string pattern = R"(
+   # CHECK: {
+   # CHECK:   void* buf_ptrs[]{out, a, b};
+   # CHECK:   int64_t buf_ranks[]{2, 2, 2};
+   # CHECK:   int64_t buf_dims[]{2, 2, 2, 2, 2, 2};
+   # CHECK:   int8_t buf_dtypes[]{6, 6, 6};
+   # CHECK:   int64_t extra_args[]{1 + 2};
+   # CHECK:   nnc_aten_matmul(
+   # CHECK:       3,
+   # CHECK:       buf_ptrs,
+   # CHECK:       buf_ranks,
+   # CHECK:       buf_dims,
+   # CHECK:       buf_dtypes,
+   # CHECK:       1,
+   # CHECK:       extra_args);
+   # CHECK: }
+  )";
+  FILE_CHECK(call, pattern);
 }
 
 } // namespace jit
index d7cfe78..29255aa 100644 (file)
@@ -46,6 +46,10 @@ class TORCH_API CodeGen {
     stmt_ = stmt_->accept_mutator(mutator);
   }
 
+  void apply_visitor(IRVisitor* visitor) {
+    stmt_->accept(visitor);
+  }
+
   std::vector<BufferArg>& buffer_args() {
     return buffer_args_;
   }
index 39a5615..20795e4 100644 (file)
+#include <algorithm>
+#include <type_traits>
+#include <vector>
+
 #include <torch/csrc/jit/tensorexpr/cpp_codegen.h>
+#include <torch/csrc/jit/tensorexpr/cpp_intrinsics.h>
+#include <torch/csrc/jit/tensorexpr/external_functions_registry.h>
+#include <torch/csrc/jit/tensorexpr/types.h>
 
 namespace torch {
 namespace jit {
 namespace tensorexpr {
 
-void CppPrinter::visit(AllocatePtr alloc) {
-  constexpr size_t kAllocOnStackThresholdSize = 512;
+// Rewrites the variables' name according to valid C++ naming convention.
+// E.g. in Graph IR, variable name may contain '.', in C++, they are replaced
+// with '_'.
+class CppVarNameRewriter : public IRVisitor {
+ public:
+  void visit(VarPtr v) override {
+    constexpr char kDot = '.';
+    constexpr char kUnderscore = '_';
+    if (v->name_hint().find(kDot) == std::string::npos) {
+      return;
+    }
+    std::string name = v->name_hint();
+    std::replace(name.begin(), name.end(), kDot, kUnderscore);
+    v->set_name_hint(std::move(name));
+  }
+
+  void visit(BufPtr v) override {
+    v->base_handle()->accept(this);
+  }
+};
+
+static std::string declareExternalFunction(const std::string& func_name) {
+  return "void " + func_name +
+      "("
+      "int64_t bufs_num, "
+      "void** buf_data, "
+      "int64_t* buf_ranks, "
+      "int64_t* buf_dims, "
+      "int8_t* buf_dtypes, "
+      "int64_t args_num, "
+      "int64_t* extra_args);";
+}
+
+CppPrinter::CppPrinter(std::ostream* os) : IRPrinter(*os), lane_(0) {}
+
+CppPrinter::~CppPrinter() = default;
+
+void CppPrinter::printPrologue() {
+  os() << "#include <cassert>" << std::endl;
+  os() << "#include <cmath>" << std::endl;
+  os() << "#include <algorithm>" << std::endl;
+  os() << "#include <type_traits>" << std::endl;
+  os() << std::endl;
+
+  os() << "#define POS_INFINITY INFINITY" << std::endl;
+  os() << "#define NEG_INFINITY -INFINITY" << std::endl;
+  os() << std::endl;
+
+  os() << cpp_intrinsics_definition << std::endl;
+  os() << std::endl;
+
+  os() << "namespace torch {" << std::endl;
+  os() << "namespace jit {" << std::endl;
+  os() << "namespace tensorexpr {" << std::endl;
+  for (auto const& it : getNNCFunctionRegistry()) {
+    os() << declareExternalFunction(it.first) << std::endl;
+  }
+  os() << "} // namespace tensorexpr" << std::endl;
+  os() << "} // namespace jit" << std::endl;
+  os() << "} // namespace torch" << std::endl;
+  os() << std::endl;
+
+  os() << "using namespace torch::jit::tensorexpr;" << std::endl;
+  os() << std::endl;
+}
+
+template <typename T>
+inline typename std::enable_if<!std::is_floating_point<T>::value, void>::type
+visit_mod(std::ostream& os, const ExprPtr lhs, const ExprPtr rhs) {
+  os << *lhs << " % " << *rhs;
+}
+
+template <typename T>
+inline typename std::enable_if<std::is_floating_point<T>::value, void>::type
+visit_mod(std::ostream& os, const ExprPtr lhs, const ExprPtr rhs) {
+  os << "std::fmod(" << *lhs << ", " << *rhs << ")";
+}
+
+template <typename T>
+inline typename std::enable_if<
+    std::is_floating_point<T>::value || std::is_integral<T>::value,
+    void>::type
+visit_max(std::ostream& os, const ExprPtr lhs, const ExprPtr rhs) {
+  os << "std::max(" << *lhs << ", " << *rhs << ")";
+}
 
-  size_t size = 1;
-  for (auto dim : alloc->dims()) {
-    IntImmPtr v = to<IntImm>(dim);
-    if (v) {
-      size *= v->value();
+template <typename T>
+inline typename std::enable_if<
+    !std::is_floating_point<T>::value && !std::is_integral<T>::value,
+    void>::type
+visit_max(std::ostream& os, const ExprPtr lhs, const ExprPtr rhs) {
+  os << "(" << *lhs << " < " << *rhs << ") ? " << *rhs << " : " << *lhs;
+}
+
+template <typename T>
+inline typename std::enable_if<
+    std::is_floating_point<T>::value || std::is_integral<T>::value,
+    void>::type
+visit_min(std::ostream& os, const ExprPtr lhs, const ExprPtr rhs) {
+  os << "std::min(" << *lhs << ", " << *rhs << ")";
+}
+
+template <typename T>
+inline typename std::enable_if<
+    !std::is_floating_point<T>::value && !std::is_integral<T>::value,
+    void>::type
+visit_min(std::ostream& os, const ExprPtr lhs, const ExprPtr rhs) {
+  os << *lhs << " < " << *rhs << " ? " << *lhs << " : " << *rhs;
+}
+
+template <typename T>
+void visit_binary_op(
+    std::ostream& os,
+    const ExprPtr lhs,
+    const ExprPtr rhs,
+    IRNodeType op_type) {
+  switch (op_type) {
+    case IRNodeType::kMod:
+      visit_mod<T>(os, lhs, rhs);
+      break;
+    case IRNodeType::kMax:
+      visit_max<T>(os, lhs, rhs);
+      break;
+    case IRNodeType::kMin:
+      visit_min<T>(os, lhs, rhs);
+      break;
+    default:
+      throw std::runtime_error("invalid op type");
+  }
+}
+
+template <typename Op>
+void dispatch_binary_op(std::ostream& os, const BinaryOpNode<Op>* v) {
+  switch (v->lhs()->dtype().scalar_type()) {
+#define TYPE_CASE(Type, Name)                                      \
+  case ScalarType::Name:                                           \
+    visit_binary_op<Type>(os, v->lhs(), v->rhs(), v->expr_type()); \
+    break;
+    AT_FORALL_SCALAR_TYPES_AND2(Half, Bool, TYPE_CASE);
+#undef TYPE_CASE
+    default:
+      throw unsupported_dtype();
+  }
+}
+
+void CppPrinter::visit(RampPtr v) {
+  visit(alloc<Add>(v->base(), alloc<Mul>(alloc<IntImm>(lane_), v->stride())));
+}
+
+void CppPrinter::visit(BroadcastPtr v) {
+  v->value()->accept(this);
+}
+
+void CppPrinter::visit(ModPtr v) {
+  dispatch_binary_op(os(), v.get());
+}
+
+void CppPrinter::visit(MaxPtr v) {
+  dispatch_binary_op(os(), v.get());
+}
+
+void CppPrinter::visit(MinPtr v) {
+  dispatch_binary_op(os(), v.get());
+}
+
+void CppPrinter::visit(CompareSelectPtr v) {
+  os() << "((" << *v->lhs() << " "
+       << IRPrinter::to_string(v->compare_select_op()) << " " << *v->rhs()
+       << ") ? " << *v->ret_val1() << " : " << *v->ret_val2() << ")";
+}
+
+void CppPrinter::visit(IfThenElsePtr v) {
+  os() << "((" << *v->condition() << ") ? " << *v->true_value() << " : "
+       << *v->false_value() << ")";
+}
+
+void CppPrinter::visit(AllocatePtr v) {
+  size_t size = v->dtype().byte_size();
+  for (const auto& dim : v->dims()) {
+    IntImmPtr d = to<IntImm>(dim);
+    if (d) {
+      size *= d->value();
     } else {
       throw std::runtime_error("Only IntImm dimensions are supported for now");
     }
   }
 
   emitIndent();
-  if (size <= kAllocOnStackThresholdSize) {
-    os() << alloc->dtype().ToCppString() << " " << (*alloc->buffer_var()) << "["
-         << size << "];" << std::endl;
-  } else {
-    size *= alloc->dtype().byte_size();
-    os() << alloc->dtype().ToCppString() << "* " << (*alloc->buffer_var())
-         << " = static_cast<" << alloc->dtype().ToCppString() << "*>(malloc("
-         << size << "));" << std::endl;
-    allocated_on_heap_.insert(alloc->buffer_var());
+  os() << v->dtype().ToCppString() << "* " << (*v->buffer_var())
+       << " = static_cast<" << v->dtype().ToCppString() << "*>(malloc(" << size
+       << "));" << std::endl;
+}
+
+void CppPrinter::visit(FreePtr v) {
+  emitIndent();
+  os() << "free(" << *v->buffer_var() << ");" << std::endl;
+}
+
+void CppPrinter::visit(LoadPtr v) {
+  auto flat_idx = flatten_index(v->buf()->dims(), v->indices());
+  os() << *v->base_handle() << "[" << *flat_idx << "]";
+}
+
+void CppPrinter::visit(StorePtr v) {
+  auto flat_idx = flatten_index(v->buf()->dims(), v->indices());
+  const int lanes = v->value()->dtype().lanes();
+  for (int lane = 0; lane < lanes; lane++) {
+    lane_ = lane;
+    emitIndent();
+    os() << *v->base_handle() << "[" << *flat_idx << "] = " << *v->value()
+         << ";" << std::endl;
+  }
+}
+
+void CppPrinter::visit(CastPtr v) {
+  os() << "static_cast<" << v->dtype().ToCppString() << ">(" << *v->src_value()
+       << ")";
+}
+
+void CppPrinter::visit(BitCastPtr v) {
+  os() << "std::bitcast<" << v->src_value()->dtype().ToCppString() << ", "
+       << v->dtype().ToCppString() << ">(" << *v->src_value() << ")";
+}
+
+void CppPrinter::visit(IntrinsicsPtr v) {
+  if (v->op_type() == kRand || v->op_type() == kSigmoid) {
+    throw std::runtime_error("kRand and kSigmoid are not supported");
+  }
+
+  os() << "std::" << v->func_name() << "(";
+  for (int i = 0; i < v->nparams(); i++) {
+    if (i > 0) {
+      os() << ", ";
+    }
+    os() << *v->param(i);
   }
+  os() << ")";
 }
 
-void CppPrinter::visit(FreePtr free) {
-  VarPtr var = free->buffer_var();
-  if (allocated_on_heap_.count(var)) {
+void CppPrinter::visit(ExternalCallPtr v) {
+  // The generated code needs to link against functions defined
+  // in external_functions.cpp.
+
+  auto& func_registry = getNNCFunctionRegistry();
+  if (!func_registry.count(v->func_name())) {
+    throw unimplemented_lowering(v);
+  }
+
+  std::vector<BufPtr> bufs(v->buf_args());
+  bufs.insert(bufs.begin(), v->buf());
+  auto for_buf = [&](const std::function<void(const BufPtr)>& print_buf) {
+    for (size_t i = 0; i < bufs.size(); i++) {
+      if (i > 0) {
+        os() << ", ";
+      }
+      print_buf(bufs[i]);
+    }
+  };
+
+  emitIndent();
+  os() << "{" << std::endl;
+  indent_++;
+
+  emitIndent();
+  os() << "void* buf_ptrs[]{";
+  for_buf([&](const BufPtr b) { os() << *b->base_handle(); });
+  os() << "};" << std::endl;
+
+  emitIndent();
+  os() << "int64_t buf_ranks[]{";
+  for_buf([&](const BufPtr b) { os() << b->ndim(); });
+  os() << "};" << std::endl;
+
+  emitIndent();
+  os() << "int64_t buf_dims[]{";
+  for_buf([&](const BufPtr buf) {
+    for (size_t i = 0; i < buf->ndim(); i++) {
+      if (i > 0) {
+        os() << ", ";
+      }
+      os() << *buf->dim(i);
+    }
+  });
+  os() << "};" << std::endl;
+
+  emitIndent();
+  os() << "int8_t buf_dtypes[]{";
+  for_buf([&](const BufPtr buf) {
+    os() << static_cast<int>(buf->dtype().scalar_type());
+  });
+  os() << "};" << std::endl;
+
+  emitIndent();
+  os() << "int64_t extra_args[]{";
+  for (size_t i = 0; i < v->args().size(); i++) {
+    if (i > 0) {
+      os() << ", ";
+    }
+    os() << *v->args()[i];
+  }
+  os() << "};" << std::endl;
+
+  emitIndent();
+  os() << v->func_name() << "(" << std::endl;
+  emitIndent();
+  os() << "    " << bufs.size() << "," << std::endl;
+  emitIndent();
+  os() << "    buf_ptrs," << std::endl;
+  emitIndent();
+  os() << "    buf_ranks," << std::endl;
+  emitIndent();
+  os() << "    buf_dims," << std::endl;
+  emitIndent();
+  os() << "    buf_dtypes," << std::endl;
+  emitIndent();
+  os() << "    " << v->args().size() << "," << std::endl;
+  emitIndent();
+  os() << "    extra_args);" << std::endl;
+
+  indent_--;
+  emitIndent();
+  os() << "}" << std::endl;
+}
+
+void CppPrinter::visit(LetPtr v) {
+  if (v->dtype().lanes() == 1) {
     emitIndent();
-    os() << "free(" << name_manager()->get_unique_name(var) << ");"
-         << std::endl;
+    os() << v->dtype().ToCppString() << " " << *v->var() << " = " << *v->value()
+         << ";" << std::endl;
+  } else {
+    vector_vars_[v->var()] = v->value();
+  }
+}
+
+void CppPrinter::visit(VarPtr v) {
+  if (v->dtype().lanes() == 1) {
+    os() << name_manager()->get_unique_name(v);
+  } else {
+    os() << *vector_vars_.at(v);
   }
 }
 
+CppCodeGen::CppCodeGen(
+    StmtPtr stmt,
+    const std::vector<BufferArg>& buffer_args,
+    at::Device device,
+    const std::string& kernel_func_name)
+    : CodeGen(stmt, buffer_args, device, kernel_func_name) {
+  init();
+}
+
+void CppCodeGen::init() {
+  printer_ = std::make_unique<CppPrinter>(&oss_);
+  var_name_rewriter_ = std::make_unique<CppVarNameRewriter>();
+
+  apply_visitor(var_name_rewriter_.get());
+
+  printer_->printPrologue();
+  os() << "void " << kernel_func_name() << "(";
+  const std::vector<BufferArg> buffer_args = this->buffer_args();
+  for (size_t i = 0; i < buffer_args.size(); i++) {
+    if (i > 0) {
+      os() << ", ";
+    }
+    const BufferArg& buffer_arg = buffer_args[i];
+    const VarPtr var = buffer_arg.var();
+    Dtype dtype = buffer_arg.dtype();
+    os() << dtype.ToCppString() << (buffer_arg.isVar() ? " " : "* ") << *var;
+  }
+  os() << ")";
+  stmt()->accept(printer_.get());
+  os() << std::endl;
+}
+
+CppCodeGen::~CppCodeGen() = default;
+
+void CppCodeGen::call(const std::vector<CallArg>& args) {
+  // TODO: compile the generated C++ kernel into a library,
+  // and call the library here.
+  os() << "int main() {}" << std::endl;
+}
+
+void CppCodeGen::call_raw(const std::vector<void*>& args) {
+  // TODO: compile the generated C++ kernel into a library,
+  // and call the library here.
+  os() << "int main() {}" << std::endl;
+}
+
+RegisterCodeGen<CppCodeGen> cpp_codegen_reg("cpp_codegen");
+
 } // namespace tensorexpr
 } // namespace jit
 } // namespace torch
index 1cf1565..a6d583e 100644 (file)
 #pragma once
 
+#include <torch/csrc/jit/tensorexpr/codegen.h>
 #include <torch/csrc/jit/tensorexpr/ir_printer.h>
 
-#include <unordered_set>
-
 namespace torch {
 namespace jit {
 namespace tensorexpr {
 
+class CppVarNameRewriter;
+
 // Generates C++ code from the IR.
+//
+// Vector operations are unrolled.
+// For example:
+// C[Ramp(0, 1, 3)] = A[Ramp(0, 2, 3)] + B[Ramp(0, 3, 3)];
+// is unrolled into:
+// C[0] = A[0] + B[0];
+// C[1] = A[2] + B[3];
+// C[2] = A[4] + B[6];
 class TORCH_API CppPrinter : public IRPrinter {
  public:
-  explicit CppPrinter(std::ostream* os) : IRPrinter(*os) {}
+  explicit CppPrinter(std::ostream* os);
+  ~CppPrinter() override;
+
+  void printPrologue();
 
   using IRPrinter::visit;
+
+  // Binary expressions.
+  void visit(ModPtr) override;
+  void visit(MaxPtr) override;
+  void visit(MinPtr) override;
+
+  // Conditional expressions.
+  void visit(CompareSelectPtr) override;
+  void visit(IfThenElsePtr) override;
+
+  // Tensor operations.
   void visit(AllocatePtr) override;
   void visit(FreePtr) override;
+  void visit(LoadPtr) override;
+  void visit(StorePtr) override;
+
+  // Casts.
+  void visit(CastPtr) override;
+  void visit(BitCastPtr) override;
+
+  // Calls.
+  void visit(IntrinsicsPtr) override;
+  void visit(ExternalCallPtr) override;
+
+  // Vars.
+  void visit(LetPtr) override;
+  void visit(VarPtr) override;
+
+  // Vector data types.
+  void visit(RampPtr) override;
+  void visit(BroadcastPtr) override;
 
  private:
-  std::unordered_set<VarPtr> allocated_on_heap_;
+  int lane_;
+  std::unordered_map<VarPtr, ExprPtr> vector_vars_;
+};
+
+class TORCH_API CppCodeGen : public CodeGen {
+ public:
+  CppCodeGen(
+      StmtPtr stmt,
+      const std::vector<BufferArg>& buffer_args,
+      at::Device device = at::kCPU,
+      const std::string& kernel_func_name = "func");
+
+  ~CppCodeGen() override;
+
+  void call(const std::vector<CallArg>& args) override;
+  void call_raw(const std::vector<void*>& args) override;
+
+  template <typename... Ts>
+  void operator()(const Ts&... ts) {
+    call(std::vector<CallArg>({CallArg(ts)...}));
+  }
+
+  std::string getCodeText(const std::string& attr = "") override {
+    return oss_.str();
+  }
+
+ private:
+  void init();
+
+  std::ostream& os() {
+    return printer_->os();
+  }
+
+  std::ostringstream oss_;
+  std::unique_ptr<CppPrinter> printer_;
+  std::unique_ptr<CppVarNameRewriter> var_name_rewriter_;
 };
 
 } // namespace tensorexpr
diff --git a/torch/csrc/jit/tensorexpr/cpp_intrinsics.h b/torch/csrc/jit/tensorexpr/cpp_intrinsics.h
new file mode 100644 (file)
index 0000000..caeeed6
--- /dev/null
@@ -0,0 +1,36 @@
+#pragma once
+
+namespace torch {
+namespace jit {
+namespace tensorexpr {
+
+constexpr auto cpp_intrinsics_definition = R"(
+namespace std {
+
+template <typename T,
+          typename std::enable_if<std::is_floating_point<T>::value, int>::type = 0>
+T rsqrt(T v) {
+  return 1.0f / std::sqrt(v);
+}
+
+template <typename T,
+          typename std::enable_if<std::is_floating_point<T>::value, int>::type = 0>
+T frac(T v) {
+  T intpart;
+  return std::modf(v, &intpart);
+}
+
+template <typename From, typename To>
+To bitcast(const From& v) {
+  assert(sizeof(To) == sizeof(From));
+  To res;
+  std::memcpy(&res, &v, sizeof(From));
+  return res;
+}
+
+} // namespace std
+)";
+
+} // namespace tensorexpr
+} // namespace jit
+} // namespace torch
index 108236e..a4f317f 100644 (file)
@@ -169,8 +169,12 @@ class TORCH_API Var : public ExprNode<Var> {
     return name_hint_;
   }
 
-  void set_name_hint(const std::string& name_hint) {
-    name_hint_ = name_hint;
+  void set_name_hint(const std::string& name) {
+    name_hint_ = name;
+  }
+
+  void set_name_hint(std::string&& name) {
+    name_hint_ = name;
   }
 
   Var(std::string name_hint, Dtype dtype)
index f9fc7dc..7fe1fd1 100644 (file)
@@ -684,6 +684,7 @@ enum IntrinsicsOp {
   kFrac,
   kIsNan,
   kRand, // We need more discussions on this. Should we consider stateful?
+  kMaxIntrinsicsOp,
 };
 
 class TORCH_API Intrinsics : public ExprNode<Intrinsics> {
@@ -864,8 +865,9 @@ class TORCH_API Intrinsics : public ExprNode<Intrinsics> {
     params_ = std::move(params);
   }
 
- private:
   static int OpArgCount(IntrinsicsOp op_type);
+
+ private:
   static Dtype IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1);
   static Dtype IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1, Dtype dt2);
   static Dtype IntrinsicsDtype(
index 27b56e2..2e1fc6e 100644 (file)
@@ -25,6 +25,24 @@ void IRPrinter::print(Expr& expr) {
 void IRPrinter::print(Stmt& stmt) {
   stmt.accept(this);
 }
+std::string IRPrinter::to_string(CompareSelectOperation op) {
+  switch (op) {
+    case CompareSelectOperation::kEQ:
+      return "==";
+    case CompareSelectOperation::kNE:
+      return "!=";
+    case CompareSelectOperation::kGT:
+      return ">";
+    case CompareSelectOperation::kGE:
+      return ">=";
+    case CompareSelectOperation::kLT:
+      return "<";
+    case CompareSelectOperation::kLE:
+      return "<=";
+    default:
+      throw std::runtime_error("invalid compare select operator");
+  }
+}
 
 // TODO: change whether to include the parenthesis to the parent expression,
 // we need to look at the operator precedence to make the output simpler.
@@ -137,28 +155,8 @@ void IRPrinter::visit(CompareSelectPtr v) {
   if (lhs_prec >= self_prec) {
     os() << ")";
   }
-  switch (cmp_op) {
-    case CompareSelectOperation::kEQ:
-      os() << "==";
-      break;
-    case CompareSelectOperation::kNE:
-      os() << "!=";
-      break;
-    case CompareSelectOperation::kGT:
-      os() << ">";
-      break;
-    case CompareSelectOperation::kGE:
-      os() << ">=";
-      break;
-    case CompareSelectOperation::kLT:
-      os() << "<";
-      break;
-    case CompareSelectOperation::kLE:
-      os() << "<=";
-      break;
-    default:
-      throw std::runtime_error("invalid compare select operator");
-  }
+
+  os() << to_string(cmp_op);
 
   if (rhs_prec >= self_prec) {
     os() << "(";
@@ -230,6 +228,13 @@ void IRPrinter::visit(CastPtr v) {
   os() << ")";
 }
 
+void IRPrinter::visit(BitCastPtr v) {
+  auto dtype = v->dtype();
+  os() << "BitCast<" << dtype.ToCppString() << ">(";
+  v->src_value()->accept(this);
+  os() << ")";
+}
+
 void IRPrinter::visit(VarPtr v) {
   os() << name_manager_.get_unique_name(v);
 }
@@ -439,7 +444,7 @@ void IRPrinter::visit(FreePtr v) {
 void IRPrinter::visit(LetPtr v) {
   os() << dtypeToCppString(v->dtype()) << " " << *v->var();
   os() << " = " << *v->value();
-  os() << ";";
+  os() << ";" << std::endl;
 }
 
 void IRPrinter::visit(CondPtr v) {
index 321d1ef..327119d 100644 (file)
@@ -37,6 +37,7 @@ class TORCH_API IRPrinter : public IRVisitor {
   AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_PRINT_VISIT);
 #undef IMM_PRINT_VISIT
   void visit(CastPtr v) override;
+  void visit(BitCastPtr v) override;
   void visit(VarPtr v) override;
   void visit(RampPtr v) override;
   void visit(LoadPtr v) override;
@@ -83,6 +84,8 @@ class TORCH_API IRPrinter : public IRVisitor {
   };
 
  protected:
+  std::string to_string(CompareSelectOperation op);
+
   UniqueNameManager* name_manager() {
     return &name_manager_;
   }