[mlir][sparse] refactor sparse compiler pipeline to single place
authorAart Bik <ajcbik@google.com>
Tue, 22 Feb 2022 20:21:07 +0000 (12:21 -0800)
committerAart Bik <ajcbik@google.com>
Wed, 23 Feb 2022 00:23:56 +0000 (16:23 -0800)
Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D120347

mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py
mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
mlir/test/Integration/Dialect/SparseTensor/python/tools/sparse_compiler.py [new file with mode: 0644]

index c52b30c..538d5c8 100644 (file)
@@ -4,18 +4,19 @@
 import ctypes
 import numpy as np
 import os
-
-import mlir.all_passes_registration
+import sys
 
 from mlir import ir
 from mlir import runtime as rt
 from mlir import execution_engine
-from mlir import passmanager
 
 from mlir.dialects import sparse_tensor as st
 from mlir.dialects import builtin
 from mlir.dialects.linalg.opdsl import lang as dsl
 
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+from tools import sparse_compiler
 
 @dsl.linalg_structured_op
 def sddmm_dsl(
@@ -119,18 +120,6 @@ def build_compile_and_run_SDDMMM(attr: st.EncodingAttr, opt: str,
     quit(f'FAILURE')
 
 
-class SparseCompiler:
-  """Sparse compiler passes."""
-
-  def __init__(self, options: str):
-    pipeline = (
-        f'sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}}')
-    self.pipeline = pipeline
-
-  def __call__(self, module: ir.Module):
-    passmanager.PassManager.parse(self.pipeline).run(module)
-
-
 def main():
   support_lib = os.getenv('SUPPORT_LIB')
   assert support_lib is not None, 'SUPPORT_LIB is undefined'
@@ -166,7 +155,7 @@ def main():
                   opt = (f'parallelization-strategy={par} '
                          f'vectorization-strategy={vec} '
                          f'vl={vl} enable-simd-index32={e}')
-                  compiler = SparseCompiler(options=opt)
+                  compiler = sparse_compiler.SparseCompiler(options=opt)
                   build_compile_and_run_SDDMMM(attr, opt, support_lib, compiler)
                   count = count + 1
   # CHECK: Passed 16 tests
index 1b66628..77b94ea 100644 (file)
@@ -4,18 +4,19 @@
 import ctypes
 import numpy as np
 import os
-
-import mlir.all_passes_registration
+import sys
 
 from mlir import ir
 from mlir import runtime as rt
 from mlir import execution_engine
-from mlir import passmanager
 
 from mlir.dialects import sparse_tensor as st
 from mlir.dialects import builtin
 from mlir.dialects.linalg.opdsl import lang as dsl
 
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+from tools import sparse_compiler
 
 @dsl.linalg_structured_op
 def matmul_dsl(
@@ -108,18 +109,6 @@ def build_compile_and_run_SpMM(attr: st.EncodingAttr, support_lib: str,
     quit(f'FAILURE')
 
 
-class SparseCompiler:
-  """Sparse compiler passes."""
-
-  def __init__(self, options: str):
-    pipeline = (
-        f'sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}}')
-    self.pipeline = pipeline
-
-  def __call__(self, module: ir.Module):
-    passmanager.PassManager.parse(self.pipeline).run(module)
-
-
 def main():
   support_lib = os.getenv('SUPPORT_LIB')
   assert support_lib is not None, 'SUPPORT_LIB is undefined'
@@ -155,7 +144,7 @@ def main():
         for pwidth in bitwidths:
           for iwidth in bitwidths:
             attr = st.EncodingAttr.get(level, ordering, pwidth, iwidth)
-            compiler = SparseCompiler(options=opt)
+            compiler = sparse_compiler.SparseCompiler(options=opt)
             build_compile_and_run_SpMM(attr, support_lib, compiler)
             count = count + 1
     # CHECK: Passed 8 tests
index 52e089e..1cc79c4 100644 (file)
@@ -5,12 +5,9 @@ import numpy as np
 import os
 import sys
 
-import mlir.all_passes_registration
-
 from mlir import ir
 from mlir import runtime as rt
 from mlir import execution_engine
-from mlir import passmanager
 from mlir.dialects import sparse_tensor as st
 from mlir.dialects import builtin
 from mlir.dialects.linalg.opdsl import lang as dsl
@@ -18,6 +15,7 @@ from mlir.dialects.linalg.opdsl import lang as dsl
 _SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(_SCRIPT_PATH)
 from tools import np_to_sparse_tensor as test_tools
+from tools import sparse_compiler
 
 # TODO: Use linalg_structured_op to generate the kernel after making it to
 # handle sparse tensor outputs.
@@ -61,21 +59,10 @@ func @main(%ad: tensor<3x4xf64>, %bd: tensor<3x4xf64>) -> tensor<3x4xf64, #DCSR>
 """
 
 
-class _SparseCompiler:
-  """Sparse compiler passes."""
-
-  def __init__(self):
-    self.pipeline = (
-        f'sparse-compiler{{reassociate-fp-reductions=1 enable-index-optimizations=1}}')
-
-  def __call__(self, module: ir.Module):
-    passmanager.PassManager.parse(self.pipeline).run(module)
-
-
 def _run_test(support_lib, kernel):
   """Compiles, runs and checks results."""
   module = ir.Module.parse(kernel)
-  _SparseCompiler()(module)
+  sparse_compiler.SparseCompiler(options='')(module)
   engine = execution_engine.ExecutionEngine(
       module, opt_level=0, shared_libs=[support_lib])
 
index c29f618..5e2210b 100644 (file)
@@ -3,18 +3,19 @@
 
 import ctypes
 import os
+import sys
 import tempfile
 
-import mlir.all_passes_registration
-
 from mlir import execution_engine
 from mlir import ir
-from mlir import passmanager
 from mlir import runtime as rt
 
 from mlir.dialects import builtin
 from mlir.dialects import sparse_tensor as st
 
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+from tools import sparse_compiler
 
 # TODO: move more into actual IR building.
 def boilerplate(attr: st.EncodingAttr):
@@ -68,18 +69,6 @@ def build_compile_and_run_output(attr: st.EncodingAttr, support_lib: str,
       quit('FAILURE')
 
 
-class SparseCompiler:
-  """Sparse compiler passes."""
-
-  def __init__(self):
-    pipeline = (
-        f'sparse-compiler{{reassociate-fp-reductions=1 enable-index-optimizations=1}}')
-    self.pipeline = pipeline
-
-  def __call__(self, module: ir.Module):
-    passmanager.PassManager.parse(self.pipeline).run(module)
-
-
 def main():
   support_lib = os.getenv('SUPPORT_LIB')
   assert support_lib is not None, 'SUPPORT_LIB is undefined'
@@ -103,7 +92,7 @@ def main():
       for ordering in orderings:
         for bwidth in bitwidths:
           attr = st.EncodingAttr.get(level, ordering, bwidth, bwidth)
-          compiler = SparseCompiler()
+          compiler = sparse_compiler.SparseCompiler(options='')
           build_compile_and_run_output(attr, support_lib, compiler)
           count = count + 1
 
index ccf1ffd..7958e76 100644 (file)
@@ -6,21 +6,23 @@ import errno
 import itertools
 import os
 import sys
+
 from typing import List, Callable
 
 import numpy as np
 
-import mlir.all_passes_registration
-
 from mlir import ir
 from mlir import runtime as rt
 from mlir.execution_engine import ExecutionEngine
-from mlir.passmanager import PassManager
 
 from mlir.dialects import builtin
 from mlir.dialects import std
 from mlir.dialects import sparse_tensor as st
 
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+from tools import sparse_compiler
+
 # ===----------------------------------------------------------------------=== #
 
 # TODO: move this boilerplate to its own module, so it can be used by
@@ -137,13 +139,15 @@ class StressTest:
       f.write(str(self._module))
     return self
 
-  def compile(self, compiler: Callable[[ir.Module], ExecutionEngine]):
+  def compile(self, compiler, support_lib: str):
     """Compile the ir.Module."""
     assert self._module is not None, \
         'StressTest: must call build() before compile()'
     assert self._engine is None, \
         'StressTest: must not call compile() repeatedly'
-    self._engine = compiler(self._module)
+    compiler(self._module)
+    self._engine = ExecutionEngine(
+        self._module, opt_level=0, shared_libs=[support_lib])
     return self
 
   def run(self, np_arg0: np.ndarray) -> np.ndarray:
@@ -163,24 +167,6 @@ class StressTest:
 
 # ===----------------------------------------------------------------------=== #
 
-# TODO: move this boilerplate to its own module, so it can be used by
-# other tests and programs.
-class SparseCompiler:
-  """Sparse compiler passes."""
-
-  def __init__(self, sparsification_options: str, support_lib: str):
-    self._support_lib = support_lib
-    self._pipeline = (
-        f'sparse-compiler{{{sparsification_options} reassociate-fp-reductions=1 enable-index-optimizations=1}}')
-    # Must be in the scope of a `with ir.Context():`
-    self._passmanager = PassManager.parse(self._pipeline)
-
-  def __call__(self, module: ir.Module) -> ExecutionEngine:
-    self._passmanager.run(module)
-    return ExecutionEngine(module, opt_level=0, shared_libs=[self._support_lib])
-
-# ===----------------------------------------------------------------------=== #
-
 def main():
   """
   USAGE: python3 test_stress.py [raw_module.mlir [compiled_module.mlir]]
@@ -208,7 +194,7 @@ def main():
         f'vectorization-strategy={vec} '
         f'vl={vl} '
         f'enable-simd-index32={e}')
-    compiler = SparseCompiler(sparsification_options, support_lib)
+    compiler = sparse_compiler.SparseCompiler(options=sparsification_options)
     f64 = ir.F64Type.get()
     # Be careful about increasing this because
     #     len(types) = 1 + 2^rank * rank! * len(bitwidths)^2
@@ -243,12 +229,10 @@ def main():
       size *= d
     np_arg0 = np.arange(size, dtype=tyconv.irtype_to_dtype(f64)).reshape(*shape)
     np_out = (
-        StressTest(tyconv)
-        .build(types)
-        .writeTo(sys.argv[1] if len(sys.argv) > 1 else None)
-        .compile(compiler)
-        .writeTo(sys.argv[2] if len(sys.argv) > 2 else None)
-        .run(np_arg0))
+        StressTest(tyconv).build(types).writeTo(
+            sys.argv[1] if len(sys.argv) > 1 else None).compile(
+                compiler, support_lib).writeTo(
+                    sys.argv[2] if len(sys.argv) > 2 else None).run(np_arg0))
     # CHECK: Passed
     if np.allclose(np_out, np_arg0):
       print('Passed')
diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/tools/sparse_compiler.py b/mlir/test/Integration/Dialect/SparseTensor/python/tools/sparse_compiler.py
new file mode 100644 (file)
index 0000000..47b145f
--- /dev/null
@@ -0,0 +1,19 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#  This file contains the sparse compiler class.
+
+from mlir import all_passes_registration
+from mlir import ir
+from mlir import passmanager
+
+class SparseCompiler:
+  """Sparse compiler definition."""
+
+  def __init__(self, options: str):
+    pipeline = f'sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}}'
+    self.pipeline = pipeline
+
+  def __call__(self, module: ir.Module):
+    passmanager.PassManager.parse(self.pipeline).run(module)