using TSGraph = std::shared_ptr<Graph>;
py::class_<TensorExprKernel>(te, "TensorExprKernel")
.def(py::init<const TSGraph&>())
- .def(py::init([](const TSGraph& g,
- std::unordered_map<std::string, NNCLoweringFunction>
- custom_lowerings_str) {
- std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings;
- for (auto& kv : custom_lowerings_str) {
- custom_lowerings[c10::Symbol::fromQualString(kv.first)] = kv.second;
- }
- return std::make_unique<TensorExprKernel>(g, custom_lowerings);
- }))
+ .def( // Constructor binding: graph + string-keyed custom lowerings + pre_alloc flag.
+ py::init([](const TSGraph& g, // factory lambda wrapped by py::init
+ std::unordered_map<std::string, NNCLoweringFunction>
+ custom_lowerings_str, // keys are qualified-name strings, converted to symbols below
+ bool pre_alloc = false) { // NOTE(review): C++ default looks redundant with py::arg("pre_alloc") = false below - confirm
+ std::unordered_map<c10::Symbol, NNCLoweringFunction>
+ custom_lowerings; // same lowerings, re-keyed by c10::Symbol
+ for (auto& kv : custom_lowerings_str) {
+ custom_lowerings[c10::Symbol::fromQualString(kv.first)] = // string key -> c10::Symbol
+ kv.second;
+ }
+ return std::make_unique<TensorExprKernel>( // kernel is returned as a unique_ptr
+ g, custom_lowerings, pre_alloc); // pre_alloc is forwarded unchanged to the kernel constructor
+ }),
+ py::arg("g"), // keyword names for the three arguments
+ py::arg("custom_lowerings_str"),
+ py::arg("pre_alloc") = false) // omitted pre_alloc defaults to false
.def(
"run",
[](TensorExprKernel& self, const py::tuple& inputs) {