[tensorexpr] Add 'pre_alloc' argument to the Python API of the TensorExpr kernel (#64718)
author Hui Guo <huiguo@fb.com>
Fri, 10 Sep 2021 16:59:25 +0000 (09:59 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 10 Sep 2021 17:03:00 +0000 (10:03 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64718

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D30826582

Pulled By: huiguoo

fbshipit-source-id: 6c173c8964f2643039273cdc83e64fb02bb5f381

torch/csrc/jit/tensorexpr/tensorexpr_init.cpp

index ad8962d..7e93092 100644 (file)
@@ -666,15 +666,23 @@ void initTensorExprBindings(PyObject* module) {
   using TSGraph = std::shared_ptr<Graph>;
   py::class_<TensorExprKernel>(te, "TensorExprKernel")
       .def(py::init<const TSGraph&>())
-      .def(py::init([](const TSGraph& g,
-                       std::unordered_map<std::string, NNCLoweringFunction>
-                           custom_lowerings_str) {
-        std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings;
-        for (auto& kv : custom_lowerings_str) {
-          custom_lowerings[c10::Symbol::fromQualString(kv.first)] = kv.second;
-        }
-        return std::make_unique<TensorExprKernel>(g, custom_lowerings);
-      }))
+      .def(
+          py::init([](const TSGraph& g,
+                      std::unordered_map<std::string, NNCLoweringFunction>
+                          custom_lowerings_str,
+                      bool pre_alloc = false) {
+            std::unordered_map<c10::Symbol, NNCLoweringFunction>
+                custom_lowerings;
+            for (auto& kv : custom_lowerings_str) {
+              custom_lowerings[c10::Symbol::fromQualString(kv.first)] =
+                  kv.second;
+            }
+            return std::make_unique<TensorExprKernel>(
+                g, custom_lowerings, pre_alloc);
+          }),
+          py::arg("g"),
+          py::arg("custom_lowerings_str"),
+          py::arg("pre_alloc") = false)
       .def(
           "run",
           [](TensorExprKernel& self, const py::tuple& inputs) {