[mlir][python] provide bindings for ops from the sparse_tensor dialect
authorAlex Zinenko <zinenko@google.com>
Thu, 30 Sep 2021 13:09:30 +0000 (15:09 +0200)
committerAlex Zinenko <zinenko@google.com>
Thu, 30 Sep 2021 13:53:16 +0000 (15:53 +0200)
Previously, the dialect was exposed for linking and pass management purposes,
but we did not generate op classes for it. Generate them.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D110819

mlir/python/CMakeLists.txt
mlir/python/mlir/dialects/SparseTensorOps.td [new file with mode: 0644]
mlir/python/mlir/dialects/sparse_tensor.py

index 2ab3a9a..eb7e1e4 100644 (file)
@@ -128,6 +128,7 @@ declare_mlir_dialect_python_bindings(
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/SparseTensorOps.td
   SOURCES dialects/sparse_tensor.py
   DIALECT_NAME sparse_tensor)
 
diff --git a/mlir/python/mlir/dialects/SparseTensorOps.td b/mlir/python/mlir/dialects/SparseTensorOps.td
new file mode 100644 (file)
index 0000000..b3b4846
--- /dev/null
@@ -0,0 +1,15 @@
+//===-- SparseTensorOps.td - Entry point for bindings ------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_SPARSE_TENSOR_OPS
+#define PYTHON_BINDINGS_SPARSE_TENSOR_OPS
+
+include "mlir/Bindings/Python/Attributes.td"
+include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.td"
+
+#endif
index 4a89ef8..4f6b675 100644 (file)
@@ -2,5 +2,6 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from ._sparse_tensor_ops_gen import *
 from .._mlir_libs._mlir.dialects.sparse_tensor import *
 from .._mlir_libs import _mlirSparseTensorPasses as _cextSparseTensorPasses