# Result of torch._C._llvm_enabled(); tests that need the LLVM backend are
# skipped via unittest.skipIf when this is falsy.
LLVM_ENABLED = torch._C._llvm_enabled()
-def construct_adder(n: int, dtype=te.Dtype.Float):
- dN = te.ExprHandle.int(n)
- A = te.Placeholder('A', dtype, [dN])
- B = te.Placeholder('B', dtype, [dN])
+def construct_adder(n: int, dtype=torch.float32):
+ A = te.BufHandle('A', [n], dtype)
+ B = te.BufHandle('B', [n], dtype)
def compute(i):
return A.load([i]) + B.load([i])
- C = te.Compute('C', [te.DimArg(dN, 'i')], compute)
+ C = te.Compute('C', [n], compute)
loopnest = te.LoopNest([C])
loopnest.prepare_for_codegen()
def test_external_calls(self):
dtype = torch.float32
- ONE = te.ExprHandle.int(1)
- FOUR = te.ExprHandle.int(4)
- A = te.BufHandle('A', [ONE, FOUR], dtype)
- B = te.BufHandle('B', [FOUR, ONE], dtype)
- C = te.BufHandle('C', [ONE, ONE], dtype)
+ A = te.BufHandle('A', [1, 4], dtype)
+ B = te.BufHandle('B', [4, 1], dtype)
+ C = te.BufHandle('C', [1, 1], dtype)
s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])
loopnest = te.LoopNest(s, [C])
loopnest.prepare_for_codegen()
- codegen = te.construct_codegen('ir_eval', s, [te.BufferArg(x) for x in [A, B, C]])
+ codegen = te.construct_codegen('ir_eval', s, [A, B, C])
tA = torch.ones(1, 4)
tB = torch.ones(4, 1)
test_with_shape(31)
# Checks dtype validation when constructing a buffer: passing torch.float32 is
# accepted, while the unrecognized dtype string "float55" is expected to raise
# TypeError. This hunk moves the test from te.Placeholder to te.BufHandle and
# replaces the te.ExprHandle.int(1) dimension with a plain Python int.
def test_dtype_error(self):
- one = te.ExprHandle.int(1)
- te.Placeholder([one], torch.float32) # ok
- te.Placeholder([one]) # ok
- self.assertRaises(TypeError,
- lambda: te.Placeholder([one], "float55"))
+ te.BufHandle('a', [1], torch.float32) # ok
+ self.assertRaises(TypeError, lambda: te.BufHandle('a', [1], "float55"))
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
def test_kernel_with_tensor_inputs(self):
"""
graph = torch._C.parse_ir(graph_str)
- kernel = torch._C._te.TensorExprKernel(graph)
+ kernel = te.TensorExprKernel(graph)
res1 = kernel.run((x, y, z))
res2 = kernel.fallback((x, y, z))
correct = f(x, y, z)
"""
graph = torch._C.parse_ir(graph_str)
- kernel = torch._C._te.TensorExprKernel(graph)
+ kernel = te.TensorExprKernel(graph)
res1 = kernel.run((x, y, z))
res2 = kernel.fallback((x, y, z))
correct = f(x, y, z)
exception_thrown = False
try:
- kernel = torch._C._te.TensorExprKernel(graph)
+ kernel = te.TensorExprKernel(graph)
except RuntimeError:
# Graph doesn't have shape info for inputs => compilation should
# fail
torch._C._jit_pass_propagate_shapes_on_graph(graph)
# Now compilation should pass
- kernel = torch._C._te.TensorExprKernel(graph)
+ kernel = te.TensorExprKernel(graph)
res = kernel.run((x, y))
correct = torch.mul(x, y)
# shape info.
exception_thrown = False
try:
- kernel = torch._C._te.TensorExprKernel(graph)
+ kernel = te.TensorExprKernel(graph)
except RuntimeError:
exception_thrown = True
pass
torch._C._jit_pass_propagate_shapes_on_graph(graph)
# Now compilation should pass
- kernel = torch._C._te.TensorExprKernel(graph)
+ kernel = te.TensorExprKernel(graph)
device, size = 'cpu', (4, 4)
x = torch.rand(size, device=device)
"""
graph = torch._C.parse_ir(graph_str)
- kernel = torch._C._te.TensorExprKernel(graph)
+ kernel = te.TensorExprKernel(graph)
res1 = kernel.run((x,))
res2 = kernel.fallback((x,))
correct = f(x)
"""
graph = torch._C.parse_ir(graph_str)
- kernel = torch._C._te.TensorExprKernel(graph)
+ kernel = te.TensorExprKernel(graph)
res1 = kernel.run((x,))
res2 = kernel.fallback((x,))
correct = f(x)
"""
graph = torch._C.parse_ir(graph_str)
- kernel = torch._C._te.TensorExprKernel(graph)
+ kernel = te.TensorExprKernel(graph)
res1 = kernel.run((x,))
res2 = kernel.fallback((x,))
correct = f(x)
return te.ifThenElse(te.ExprHandle.isnan(load), te.ExprHandle.float(0.), load)
return te.Compute2("custom_nan_to_num", get_dim_args(out_shape), compute)
- kernel = torch._C._te.TensorExprKernel(graph, {'aten::nan_to_num' : my_custom_lowering})
+ kernel = te.TensorExprKernel(graph, {'aten::nan_to_num' : my_custom_lowering})
res1 = kernel.run((x,))
res2 = kernel.fallback((x,))
correct = f(x)
"""
graph = torch._C.parse_ir(graph_str)
- kernel = torch._C._te.TensorExprKernel(graph)
+ kernel = te.TensorExprKernel(graph)
res1 = kernel.run((x,))
res2 = kernel.fallback((x,))
correct = f(x)
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
def test_alloc_in_loop(self):
- a, tmp, b = [
- te.Placeholder(name, te.Dtype.Float, [te.ExprHandle.int(1)])
- for name in ["a", "tmp", "b"]]
- t0, t100 = [te.ExprHandle.int(n) for n in [0, 100]]
+ a, tmp, b = [te.BufHandle(name, [1], torch.float32) for name in ["a", "tmp", "b"]]
body = te.Block([
- tmp.store([t0], a.load([t0])),
- b.store([t0], tmp.load([t0]))
+ tmp.store([0], a.load([0])),
+ b.store([0], tmp.load([0]))
])
for _ in range(4):
- i = te.VarHandle("i", te.Dtype.Int)
- body = te.For.make(i, t0, t100, body)
- nest = te.LoopNest(body, [b.data()])
+ i = te.VarHandle("i", torch.int32)
+ body = te.For.make(i, 0, 100, body)
+ nest = te.LoopNest(body, [b])
nest.prepare_for_codegen()
f = te.construct_codegen("llvm", nest.simplify(), [a, b])
ta, tb = [torch.ones(1) for _ in range(2)]
auto expr_handle_class =
py::class_<ExprHandle>(te, "ExprHandle")
+ .def(
+ "__str__",
+ [](const ExprHandle& self) {
+ std::stringstream ss;
+ ss << self;
+ return ss.str();
+ })
.def(py::self + py::self)
.def(py::self * py::self)
.def(py::self - py::self)
.def("trunc", [](const ExprHandle& self) { return trunc(self); })
.def("frac", [](const ExprHandle& self) { return frac(self); })
.def("lgamma", [](const ExprHandle& self) { return lgamma(self); })
- .def("isnan", [](const ExprHandle& self) { return isnan(self); });
+ .def("isnan", [](const ExprHandle& self) { return isnan(self); })
+ .def(
+ "cast",
+ [](const ExprHandle& self, const Dtype& dt) {
+ return Cast::make(dt, self);
+ })
+#define EXPRHANDLE_INIT(ctype, name) \
+ .def(py::init([](ctype val) { return name##Imm::make(val); }))
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, EXPRHANDLE_INIT)
+#undef EXPRHANDLE_INIT
+ ;
+
+#define EXPRHANDLE_IMPL_CONV(ctype, name) \
+ py::implicitly_convertible<ctype, ExprHandle>();
+ AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, EXPRHANDLE_IMPL_CONV)
+#undef EXPRHANDLE_IMPL_CONV
+
te.def(
"ifThenElse",
[](const ExprHandle& c, const ExprHandle& t, const ExprHandle& f) {
#undef EXPRHANDLE_CTOR
py::class_<VarHandle, ExprHandle>(te, "VarHandle")
+ .def(
+ "__str__",
+ [](const ExprHandle& self) {
+ std::stringstream ss;
+ ss << self;
+ return ss.str();
+ })
.def(py::init<Dtype>())
.def(py::init<const std::string&, Dtype>());
py::class_<BufHandle, ExprHandle>( // NOLINT
[](BufHandle& self, const std::vector<ExprHandle>& v) {
return Load::make(self, v);
})
- .def("load", [](BufHandle& self, const ExprHandle& v) {
- return Load::make(self, {v});
- });
+ .def(
+ "load",
+ [](BufHandle& self, const ExprHandle& v) {
+ return Load::make(self, {v});
+ })
+ .def(
+ "store",
+ [](BufHandle& self,
+ const std::vector<ExprHandle>& args,
+ const ExprHandle& val) { return Store::make(self, args, val); });
py::class_<Placeholder>(te, "Placeholder")
.def(py::init<
.def("buf", [](Tensor& self) { return BufHandle(self.buf()); })
.def("stmt", &Tensor::stmt);
py::class_<Cast, std::shared_ptr<Cast>>(te, "Cast")
- .def_static("make", &Cast::make);
+ .def_static("make", &Cast::make)
+ .def(
+ "src_value",
+ [](CastPtr& self) { return ExprHandle(self->src_value()); })
+ .def("set_src_value", [](CastPtr& self, const ExprHandle& value) {
+ self->set_src_value(value.node());
+ });
py::class_<DimArg>(te, "DimArg")
.def(py::init<const ExprHandle&>())
.def(py::init<const ExprHandle&, const std::string&>());
py::implicitly_convertible<ExprHandle, DimArg>();
+ py::implicitly_convertible<int32_t, DimArg>();
+ py::implicitly_convertible<int64_t, DimArg>();
te.def(
"Compute",
py::return_value_policy::reference)
.def(
"flatten",
- [](const std::vector<ForPtr>& loops) {
+ [](LoopNest& self, const std::vector<ForPtr>& loops) {
ForPtr flattened = nullptr;
LoopNest::flatten(loops, &flattened);
return flattened;