[TensorExpr] PyBinds: improve QoL of pybind users. (#64886)
authorMikhail Zolotukhin <mvz@fb.com>
Tue, 14 Sep 2021 07:19:57 +0000 (00:19 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 14 Sep 2021 07:21:28 +0000 (00:21 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64886

Bind methods for implicit conversions and constructors to avoid
boilerplate code.

Differential Revision:
D30889193
D30889193

Test Plan: Imported from OSS

Reviewed By: jbschlosser

Pulled By: ZolotukhinM

fbshipit-source-id: 137c0c98f7f1576e1bb97c8de8a900b28407a30e

test/test_tensorexpr_pybind.py
torch/csrc/jit/tensorexpr/tensorexpr_init.cpp

index 9a70838..d04cd05 100644 (file)
@@ -9,15 +9,14 @@ import unittest
 LLVM_ENABLED = torch._C._llvm_enabled()
 
 
-def construct_adder(n: int, dtype=te.Dtype.Float):
-    dN = te.ExprHandle.int(n)
-    A = te.Placeholder('A', dtype, [dN])
-    B = te.Placeholder('B', dtype, [dN])
+def construct_adder(n: int, dtype=torch.float32):
+    A = te.BufHandle('A', [n], dtype)
+    B = te.BufHandle('B', [n], dtype)
 
     def compute(i):
         return A.load([i]) + B.load([i])
 
-    C = te.Compute('C', [te.DimArg(dN, 'i')], compute)
+    C = te.Compute('C', [n], compute)
 
     loopnest = te.LoopNest([C])
     loopnest.prepare_for_codegen()
@@ -50,17 +49,15 @@ class TestTensorExprPyBind(JitTestCase):
     def test_external_calls(self):
         dtype = torch.float32
 
-        ONE = te.ExprHandle.int(1)
-        FOUR = te.ExprHandle.int(4)
-        A = te.BufHandle('A', [ONE, FOUR], dtype)
-        B = te.BufHandle('B', [FOUR, ONE], dtype)
-        C = te.BufHandle('C', [ONE, ONE], dtype)
+        A = te.BufHandle('A', [1, 4], dtype)
+        B = te.BufHandle('B', [4, 1], dtype)
+        C = te.BufHandle('C', [1, 1], dtype)
 
         s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])
 
         loopnest = te.LoopNest(s, [C])
         loopnest.prepare_for_codegen()
-        codegen = te.construct_codegen('ir_eval', s, [te.BufferArg(x) for x in [A, B, C]])
+        codegen = te.construct_codegen('ir_eval', s, [A, B, C])
 
         tA = torch.ones(1, 4)
         tB = torch.ones(4, 1)
@@ -97,11 +94,8 @@ class TestTensorExprPyBind(JitTestCase):
         test_with_shape(31)
 
     def test_dtype_error(self):
-        one = te.ExprHandle.int(1)
-        te.Placeholder([one], torch.float32)  # ok
-        te.Placeholder([one])  # ok
-        self.assertRaises(TypeError,
-                          lambda: te.Placeholder([one], "float55"))
+        te.BufHandle('a', [1], torch.float32)  # ok
+        self.assertRaises(TypeError, lambda: te.BufHandle('a', [1], "float55"))
 
     @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
     def test_kernel_with_tensor_inputs(self):
@@ -124,7 +118,7 @@ graph(%a.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
         """
         graph = torch._C.parse_ir(graph_str)
 
-        kernel = torch._C._te.TensorExprKernel(graph)
+        kernel = te.TensorExprKernel(graph)
         res1 = kernel.run((x, y, z))
         res2 = kernel.fallback((x, y, z))
         correct = f(x, y, z)
@@ -151,7 +145,7 @@ graph(%a.1 : Float(requires_grad=0, device=cpu),
         """
         graph = torch._C.parse_ir(graph_str)
 
-        kernel = torch._C._te.TensorExprKernel(graph)
+        kernel = te.TensorExprKernel(graph)
         res1 = kernel.run((x, y, z))
         res2 = kernel.fallback((x, y, z))
         correct = f(x, y, z)
@@ -173,7 +167,7 @@ graph(%a : Tensor, %b : Tensor):
 
         exception_thrown = False
         try:
-            kernel = torch._C._te.TensorExprKernel(graph)
+            kernel = te.TensorExprKernel(graph)
         except RuntimeError:
             # Graph doesn't have shape info for inputs => compilation should
             # fail
@@ -187,7 +181,7 @@ graph(%a : Tensor, %b : Tensor):
         torch._C._jit_pass_propagate_shapes_on_graph(graph)
 
         # Now compilation should pass
-        kernel = torch._C._te.TensorExprKernel(graph)
+        kernel = te.TensorExprKernel(graph)
 
         res = kernel.run((x, y))
         correct = torch.mul(x, y)
@@ -205,7 +199,7 @@ graph(%a : Tensor, %b : Tensor):
         # shape info.
         exception_thrown = False
         try:
-            kernel = torch._C._te.TensorExprKernel(graph)
+            kernel = te.TensorExprKernel(graph)
         except RuntimeError:
             exception_thrown = True
             pass
@@ -231,7 +225,7 @@ graph(%a : Tensor, %b : Tensor):
         torch._C._jit_pass_propagate_shapes_on_graph(graph)
 
         # Now compilation should pass
-        kernel = torch._C._te.TensorExprKernel(graph)
+        kernel = te.TensorExprKernel(graph)
 
         device, size = 'cpu', (4, 4)
         x = torch.rand(size, device=device)
@@ -256,7 +250,7 @@ graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
         """
         graph = torch._C.parse_ir(graph_str)
 
-        kernel = torch._C._te.TensorExprKernel(graph)
+        kernel = te.TensorExprKernel(graph)
         res1 = kernel.run((x,))
         res2 = kernel.fallback((x,))
         correct = f(x)
@@ -280,7 +274,7 @@ graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
         """
         graph = torch._C.parse_ir(graph_str)
 
-        kernel = torch._C._te.TensorExprKernel(graph)
+        kernel = te.TensorExprKernel(graph)
         res1 = kernel.run((x,))
         res2 = kernel.fallback((x,))
         correct = f(x)
@@ -306,7 +300,7 @@ graph(%a.1 : Float(3, 4, 5, strides=[20, 5, 1], requires_grad=0, device=cpu)):
         """
         graph = torch._C.parse_ir(graph_str)
 
-        kernel = torch._C._te.TensorExprKernel(graph)
+        kernel = te.TensorExprKernel(graph)
         res1 = kernel.run((x,))
         res2 = kernel.fallback((x,))
         correct = f(x)
@@ -341,7 +335,7 @@ graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)):
                 return te.ifThenElse(te.ExprHandle.isnan(load), te.ExprHandle.float(0.), load)
             return te.Compute2("custom_nan_to_num", get_dim_args(out_shape), compute)
 
-        kernel = torch._C._te.TensorExprKernel(graph, {'aten::nan_to_num' : my_custom_lowering})
+        kernel = te.TensorExprKernel(graph, {'aten::nan_to_num' : my_custom_lowering})
         res1 = kernel.run((x,))
         res2 = kernel.fallback((x,))
         correct = f(x)
@@ -367,7 +361,7 @@ graph(%a : Float(1, 3, 1, strides=[3, 1, 1], requires_grad=0, device=cpu)):
         """
         graph = torch._C.parse_ir(graph_str)
 
-        kernel = torch._C._te.TensorExprKernel(graph)
+        kernel = te.TensorExprKernel(graph)
         res1 = kernel.run((x,))
         res2 = kernel.fallback((x,))
         correct = f(x)
@@ -376,18 +370,15 @@ graph(%a : Float(1, 3, 1, strides=[3, 1, 1], requires_grad=0, device=cpu)):
 
     @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
     def test_alloc_in_loop(self):
-        a, tmp, b = [
-            te.Placeholder(name, te.Dtype.Float, [te.ExprHandle.int(1)])
-            for name in ["a", "tmp", "b"]]
-        t0, t100 = [te.ExprHandle.int(n) for n in [0, 100]]
+        a, tmp, b = [te.BufHandle(name, [1], torch.float32) for name in ["a", "tmp", "b"]]
         body = te.Block([
-            tmp.store([t0], a.load([t0])),
-            b.store([t0], tmp.load([t0]))
+            tmp.store([0], a.load([0])),
+            b.store([0], tmp.load([0]))
         ])
         for _ in range(4):
-            i = te.VarHandle("i", te.Dtype.Int)
-            body = te.For.make(i, t0, t100, body)
-        nest = te.LoopNest(body, [b.data()])
+            i = te.VarHandle("i", torch.int32)
+            body = te.For.make(i, 0, 100, body)
+        nest = te.LoopNest(body, [b])
         nest.prepare_for_codegen()
         f = te.construct_codegen("llvm", nest.simplify(), [a, b])
         ta, tb = [torch.ones(1) for _ in range(2)]
index 27364cf..f0b7be9 100644 (file)
@@ -75,6 +75,13 @@ void initTensorExprBindings(PyObject* module) {
 
   auto expr_handle_class =
       py::class_<ExprHandle>(te, "ExprHandle")
+          .def(
+              "__str__",
+              [](const ExprHandle& self) {
+                std::stringstream ss;
+                ss << self;
+                return ss.str();
+              })
           .def(py::self + py::self)
           .def(py::self * py::self)
           .def(py::self - py::self)
@@ -124,7 +131,23 @@ void initTensorExprBindings(PyObject* module) {
           .def("trunc", [](const ExprHandle& self) { return trunc(self); })
           .def("frac", [](const ExprHandle& self) { return frac(self); })
           .def("lgamma", [](const ExprHandle& self) { return lgamma(self); })
-          .def("isnan", [](const ExprHandle& self) { return isnan(self); });
+          .def("isnan", [](const ExprHandle& self) { return isnan(self); })
+          .def(
+              "cast",
+              [](const ExprHandle& self, const Dtype& dt) {
+                return Cast::make(dt, self);
+              })
+#define EXPRHANDLE_INIT(ctype, name) \
+  .def(py::init([](ctype val) { return name##Imm::make(val); }))
+              AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, EXPRHANDLE_INIT)
+#undef EXPRHANDLE_INIT
+      ;
+
+#define EXPRHANDLE_IMPL_CONV(ctype, name) \
+  py::implicitly_convertible<ctype, ExprHandle>();
+  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, EXPRHANDLE_IMPL_CONV)
+#undef EXPRHANDLE_IMPL_CONV
+
   te.def(
       "ifThenElse",
       [](const ExprHandle& c, const ExprHandle& t, const ExprHandle& f) {
@@ -149,6 +172,13 @@ void initTensorExprBindings(PyObject* module) {
 #undef EXPRHANDLE_CTOR
 
   py::class_<VarHandle, ExprHandle>(te, "VarHandle")
+      .def(
+          "__str__",
+          [](const ExprHandle& self) {
+            std::stringstream ss;
+            ss << self;
+            return ss.str();
+          })
       .def(py::init<Dtype>())
       .def(py::init<const std::string&, Dtype>());
   py::class_<BufHandle, ExprHandle>( // NOLINT
@@ -163,9 +193,16 @@ void initTensorExprBindings(PyObject* module) {
           [](BufHandle& self, const std::vector<ExprHandle>& v) {
             return Load::make(self, v);
           })
-      .def("load", [](BufHandle& self, const ExprHandle& v) {
-        return Load::make(self, {v});
-      });
+      .def(
+          "load",
+          [](BufHandle& self, const ExprHandle& v) {
+            return Load::make(self, {v});
+          })
+      .def(
+          "store",
+          [](BufHandle& self,
+             const std::vector<ExprHandle>& args,
+             const ExprHandle& val) { return Store::make(self, args, val); });
 
   py::class_<Placeholder>(te, "Placeholder")
       .def(py::init<
@@ -196,12 +233,20 @@ void initTensorExprBindings(PyObject* module) {
       .def("buf", [](Tensor& self) { return BufHandle(self.buf()); })
       .def("stmt", &Tensor::stmt);
   py::class_<Cast, std::shared_ptr<Cast>>(te, "Cast")
-      .def_static("make", &Cast::make);
+      .def_static("make", &Cast::make)
+      .def(
+          "src_value",
+          [](CastPtr& self) { return ExprHandle(self->src_value()); })
+      .def("set_src_value", [](CastPtr& self, const ExprHandle& value) {
+        self->set_src_value(value.node());
+      });
 
   py::class_<DimArg>(te, "DimArg")
       .def(py::init<const ExprHandle&>())
       .def(py::init<const ExprHandle&, const std::string&>());
   py::implicitly_convertible<ExprHandle, DimArg>();
+  py::implicitly_convertible<int32_t, DimArg>();
+  py::implicitly_convertible<int64_t, DimArg>();
 
   te.def(
       "Compute",
@@ -584,7 +629,7 @@ void initTensorExprBindings(PyObject* module) {
           py::return_value_policy::reference)
       .def(
           "flatten",
-          [](const std::vector<ForPtr>& loops) {
+          [](LoopNest& self, const std::vector<ForPtr>& loops) {
             ForPtr flattened = nullptr;
             LoopNest::flatten(loops, &flattened);
             return flattened;