}
#endif
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.capi.h.inc"
+
#endif // MLIR_C_DIALECT_SPARSE_TENSOR_H
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name SparseTensor)
+mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix SparseTensor)
+mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix SparseTensor)
add_public_tablegen_target(MLIRSparseTensorPassIncGen)
add_mlir_doc(Passes SparseTensorPasses ./ -gen-pass-doc)
)
add_dependencies(MLIRBindingsPythonExtension MLIRAsyncPassesBindingsPythonExtension)
+add_mlir_python_extension(MLIRSparseTensorPassesBindingsPythonExtension _mlirSparseTensorPasses
+ INSTALL_DIR
+ python
+ SOURCES
+ SparseTensorPasses.cpp
+)
+add_dependencies(MLIRBindingsPythonExtension MLIRSparseTensorPassesBindingsPythonExtension)
+
add_mlir_python_extension(MLIRGPUPassesBindingsPythonExtension _mlirGPUPasses
INSTALL_DIR
python
--- /dev/null
+//===- SparseTensorPasses.cpp - Pybind module for the SparseTensor passes -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/SparseTensor.h"
+
+#include <pybind11/pybind11.h>
+
+// -----------------------------------------------------------------------------
+// Module initialization.
+// -----------------------------------------------------------------------------
+
+PYBIND11_MODULE(_mlirSparseTensorPasses, m) {
+ m.doc() = "MLIR SparseTensor Dialect Passes";
+
+ // Register all SparseTensor passes on load.
+ mlirRegisterSparseTensorPasses();
+}
add_mlir_public_c_api_library(MLIRCAPISparseTensor
SparseTensor.cpp
+ SparseTensorPasses.cpp
PARTIAL_SOURCES_INTENDED
LINK_LIBS PUBLIC
MLIRCAPIIR
MLIRSparseTensor
+ MLIRSparseTensorTransforms
)
add_mlir_public_c_api_library(MLIRCAPIStandard
--- /dev/null
+//===- SparseTensorPasses.cpp - C API for SparseTensor Dialect Passes -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/CAPI/Pass.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+
+// Must include the declarations as they carry important visibility attributes.
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.capi.h.inc"
+
+using namespace mlir;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.capi.cpp.inc"
+
+#ifdef __cplusplus
+}
+#endif
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._cext_loader import _reexport_cext
+from .._cext_loader import _load_extension
+
_reexport_cext("dialects.sparse_tensor", __name__)
+_cextSparseTensorPasses = _load_extension("_mlirSparseTensorPasses")
+
del _reexport_cext
+del _load_extension
--- /dev/null
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.passmanager import *
+
+from mlir.dialects import sparse_tensor as st
+
+
+def run(f):
+ print('\nTEST:', f.__name__)
+ f()
+ return f
+
+
+# CHECK-LABEL: TEST: testSparseTensorPass
+@run
+def testSparseTensorPass():
+ with Context() as context:
+ PassManager.parse('sparsification')
+ PassManager.parse('sparse-tensor-conversion')
+ # CHECK: SUCCESS
+ print('SUCCESS')