[NFC][Py Reformat] Reformat python files in mlir subdir
author    Tobias Hieta <tobias@hieta.se>
Wed, 17 May 2023 14:53:39 +0000 (16:53 +0200)
committer Tobias Hieta <tobias@hieta.se>
Fri, 26 May 2023 06:05:40 +0000 (08:05 +0200)
This commit is part of an ongoing series that reformats our
Python code.

Reformatting is done with `black`.
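
For reference, a minimal sketch of reformatting a single file
locally (the exact black version and flags used for this commit
are not recorded here; the path is just an example taken from
the list below):

    black mlir/python/mlir/ir.py

Note that when given a directory, black only formats *.py files
by default, so non-.py files such as the lit .cfg files below
have to be named explicitly or matched via configuration.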

If you have problems merging this commit because you have made
changes to a Python file, the best way to handle that is to run
`git checkout --ours <yourfile>` and then reformat it with
`black`.
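
For example, a sketch of that recovery workflow during a merge
conflict, using mlir/python/mlir/ir.py as a stand-in for the
conflicting file (substitute your own path):

    git checkout --ours mlir/python/mlir/ir.py
    black mlir/python/mlir/ir.py
    git add mlir/python/mlir/ir.py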

If you run into any problems, post about it on Discourse and
we will try to help.

RFC Thread below:

https://discourse.llvm.org/t/rfc-document-and-standardize-python-code-style

Differential Revision: https://reviews.llvm.org/D150782

163 files changed:
mlir/benchmark/python/benchmark_sparse.py
mlir/benchmark/python/common.py
mlir/examples/standalone/test/CAPI/lit.local.cfg
mlir/examples/standalone/test/lit.cfg.py
mlir/examples/standalone/test/python/lit.local.cfg
mlir/examples/standalone/test/python/smoketest.py
mlir/python/mlir/_mlir_libs/__init__.py
mlir/python/mlir/dialects/_arith_ops_ext.py
mlir/python/mlir/dialects/_bufferization_ops_ext.py
mlir/python/mlir/dialects/_builtin_ops_ext.py
mlir/python/mlir/dialects/_func_ops_ext.py
mlir/python/mlir/dialects/_linalg_ops_ext.py
mlir/python/mlir/dialects/_loop_transform_ops_ext.py
mlir/python/mlir/dialects/_memref_ops_ext.py
mlir/python/mlir/dialects/_ml_program_ops_ext.py
mlir/python/mlir/dialects/_ods_common.py
mlir/python/mlir/dialects/_pdl_ops_ext.py
mlir/python/mlir/dialects/_scf_ops_ext.py
mlir/python/mlir/dialects/_structured_transform_ops_ext.py
mlir/python/mlir/dialects/_tensor_ops_ext.py
mlir/python/mlir/dialects/_transform_ops_ext.py
mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py
mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py
mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
mlir/python/mlir/dialects/linalg/opdsl/lang/types.py
mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
mlir/python/mlir/dialects/python_test.py
mlir/python/mlir/dialects/transform/__init__.py
mlir/python/mlir/execution_engine.py
mlir/python/mlir/ir.py
mlir/python/mlir/runtime/np_to_memref.py
mlir/test/CAPI/lit.local.cfg
mlir/test/Conversion/GPUToCUDA/lit.local.cfg
mlir/test/Conversion/GPUToROCm/lit.local.cfg
mlir/test/Examples/Toy/Ch6/lit.local.cfg
mlir/test/Examples/Toy/Ch7/lit.local.cfg
mlir/test/Examples/lit.local.cfg
mlir/test/Examples/standalone/lit.local.cfg
mlir/test/Integration/Dialect/Async/CPU/lit.local.cfg
mlir/test/Integration/Dialect/LLVMIR/CPU/X86/lit.local.cfg
mlir/test/Integration/Dialect/LLVMIR/CPU/lit.local.cfg
mlir/test/Integration/Dialect/SparseTensor/CPU/lit.local.cfg
mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/lit.local.cfg
mlir/test/Integration/Dialect/SparseTensor/python/lit.local.cfg
mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py
mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py
mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
mlir/test/Integration/Dialect/SparseTensor/python/tools/np_to_sparse_tensor.py
mlir/test/Integration/Dialect/SparseTensor/python/tools/sparse_compiler.py
mlir/test/Integration/Dialect/SparseTensor/taco/lit.local.cfg
mlir/test/Integration/Dialect/SparseTensor/taco/test_MTTKRP.py
mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py
mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMM.py
mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMV.py
mlir/test/Integration/Dialect/SparseTensor/taco/test_Tensor.py
mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py
mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_complex.py
mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py
mlir/test/Integration/Dialect/SparseTensor/taco/test_true_dense_tensor_algebra.py
mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py
mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_io.py
mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py
mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_sparse_compiler.py
mlir/test/Integration/Dialect/SparseTensor/taco/tools/testing_utils.py
mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py
mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_io.py
mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_utils.py
mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg
mlir/test/Integration/Dialect/Vector/CPU/ArmSME/lit.local.cfg
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/lit.local.cfg
mlir/test/Integration/Dialect/Vector/CPU/X86Vector/lit.local.cfg
mlir/test/Integration/Dialect/Vector/GPU/CUDA/lit.local.cfg
mlir/test/Integration/GPU/CUDA/TensorCore/lit.local.cfg
mlir/test/Integration/GPU/CUDA/lit.local.cfg
mlir/test/Integration/GPU/ROCM/lit.local.cfg
mlir/test/Integration/lit.local.cfg
mlir/test/Unit/lit.cfg.py
mlir/test/lib/Dialect/Test/lit.local.cfg
mlir/test/lib/Dialect/Transform/lit.local.cfg
mlir/test/lib/Tools/PDLL/lit.local.cfg
mlir/test/lib/Transforms/lit.local.cfg
mlir/test/lit.cfg.py
mlir/test/mlir-cpu-runner/lit.local.cfg
mlir/test/mlir-pdll-lsp-server/lit.local.cfg
mlir/test/mlir-pdll/lit.local.cfg
mlir/test/mlir-spirv-cpu-runner/lit.local.cfg
mlir/test/mlir-vulkan-runner/lit.local.cfg
mlir/test/python/develoment_files.py
mlir/test/python/dialects/arith_dialect.py
mlir/test/python/dialects/async_dialect.py
mlir/test/python/dialects/builtin.py
mlir/test/python/dialects/complex_dialect.py
mlir/test/python/dialects/func.py
mlir/test/python/dialects/gpu.py
mlir/test/python/dialects/linalg/opdsl/arguments.py
mlir/test/python/dialects/linalg/opdsl/assignments.py
mlir/test/python/dialects/linalg/opdsl/doctests.py
mlir/test/python/dialects/linalg/opdsl/emit_convolution.py
mlir/test/python/dialects/linalg/opdsl/emit_fill.py
mlir/test/python/dialects/linalg/opdsl/emit_matmul.py
mlir/test/python/dialects/linalg/opdsl/emit_misc.py
mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
mlir/test/python/dialects/linalg/opdsl/lit.local.cfg
mlir/test/python/dialects/linalg/opdsl/metadata.py
mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
mlir/test/python/dialects/linalg/ops.py
mlir/test/python/dialects/math_dialect.py
mlir/test/python/dialects/memref.py
mlir/test/python/dialects/ml_program.py
mlir/test/python/dialects/ods_helpers.py
mlir/test/python/dialects/pdl_ops.py
mlir/test/python/dialects/python_test.py
mlir/test/python/dialects/quant.py
mlir/test/python/dialects/scf.py
mlir/test/python/dialects/shape.py
mlir/test/python/dialects/sparse_tensor/dialect.py
mlir/test/python/dialects/sparse_tensor/passes.py
mlir/test/python/dialects/tensor.py
mlir/test/python/dialects/transform.py
mlir/test/python/dialects/transform_loop_ext.py
mlir/test/python/dialects/transform_structured_ext.py
mlir/test/python/dialects/vector.py
mlir/test/python/execution_engine.py
mlir/test/python/integration/dialects/linalg/opsrun.py
mlir/test/python/ir/affine_expr.py
mlir/test/python/ir/affine_map.py
mlir/test/python/ir/array_attributes.py
mlir/test/python/ir/attributes.py
mlir/test/python/ir/blocks.py
mlir/test/python/ir/builtin_types.py
mlir/test/python/ir/context_managers.py
mlir/test/python/ir/debug.py
mlir/test/python/ir/diagnostic_handler.py
mlir/test/python/ir/dialects.py
mlir/test/python/ir/exception.py
mlir/test/python/ir/insertion_point.py
mlir/test/python/ir/integer_set.py
mlir/test/python/ir/location.py
mlir/test/python/ir/module.py
mlir/test/python/ir/operation.py
mlir/test/python/ir/symbol_table.py
mlir/test/python/ir/value.py
mlir/test/python/lit.local.cfg
mlir/test/python/pass_manager.py
mlir/test/tblgen-lsp-server/lit.local.cfg
mlir/utils/gdb-scripts/prettyprinters.py
mlir/utils/generate-test-checks.py
mlir/utils/jupyter/mlir_opt_kernel/__main__.py
mlir/utils/jupyter/mlir_opt_kernel/install.py
mlir/utils/jupyter/mlir_opt_kernel/kernel.py
mlir/utils/lldb-scripts/mlirDataFormatters.py
mlir/utils/mbr/mbr/__init__.py
mlir/utils/mbr/mbr/discovery.py
mlir/utils/mbr/mbr/main.py
mlir/utils/mbr/mbr/stats.py
mlir/utils/spirv/gen_spirv_dialect.py

index 6d7a396..72b3ef1 100644 (file)
@@ -25,7 +25,7 @@ from common import setup_passes
 def matmul_dsl(
     A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
     B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
-    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)
+    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True),
 ):
     """Helper function for mlir sparse matrix multiplication benchmark."""
     C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
@@ -43,6 +43,7 @@ def benchmark_sparse_mlir_multiplication():
         param2_type = ir.RankedTensorType.get([1500, 2000], f64)
         result_type = ir.RankedTensorType.get([1000, 2000], f64)
         with ir.InsertionPoint(module.body):
+
             @func.FuncOp.from_py_func(param1_type, param2_type, result_type)
             def sparse_kernel(x, y, z):
                 return matmul_dsl(x, y, outs=[z])
@@ -51,37 +52,34 @@ def benchmark_sparse_mlir_multiplication():
         with ir.Context(), ir.Location.unknown():
             kernel_func = get_kernel_func_from_module(module)
             timer_func = emit_timer_func()
-            wrapped_func = emit_benchmark_wrapped_main_func(
-                kernel_func,
-                timer_func
-            )
+            wrapped_func = emit_benchmark_wrapped_main_func(kernel_func, timer_func)
             main_module_with_benchmark = ir.Module.parse(
                 str(timer_func) + str(wrapped_func) + str(kernel_func)
             )
             setup_passes(main_module_with_benchmark)
             c_runner_utils = os.getenv("MLIR_C_RUNNER_UTILS", "")
-            assert os.path.exists(c_runner_utils),\
-                f"{c_runner_utils} does not exist." \
-                f" Please pass a valid value for" \
+            assert os.path.exists(c_runner_utils), (
+                f"{c_runner_utils} does not exist."
+                f" Please pass a valid value for"
                 f" MLIR_C_RUNNER_UTILS environment variable."
+            )
             runner_utils = os.getenv("MLIR_RUNNER_UTILS", "")
-            assert os.path.exists(runner_utils),\
-                f"{runner_utils} does not exist." \
-                f" Please pass a valid value for MLIR_RUNNER_UTILS" \
+            assert os.path.exists(runner_utils), (
+                f"{runner_utils} does not exist."
+                f" Please pass a valid value for MLIR_RUNNER_UTILS"
                 f" environment variable."
+            )
 
             engine = ExecutionEngine(
                 main_module_with_benchmark,
                 3,
-                shared_libs=[c_runner_utils, runner_utils]
+                shared_libs=[c_runner_utils, runner_utils],
             )
             return engine.invoke
 
     def runner(engine_invoke):
         compiled_program_args = []
-        for argument_type in [
-            result_type, param1_type, param2_type, result_type
-        ]:
+        for argument_type in [result_type, param1_type, param2_type, result_type]:
             argument_type_str = str(argument_type)
             dimensions_str = re.sub("<|>|tensor", "", argument_type_str)
             dimensions = [int(dim) for dim in dimensions_str.split("x")[:-1]]
@@ -111,6 +109,7 @@ def benchmark_np_matrix_multiplication():
     benchmark, we don't have any `compiler` function returned. We just return
     the `runner` function.
     """
+
     def runner():
         argument1 = np.random.uniform(low=0.0, high=100.0, size=(1000, 1500))
         argument2 = np.random.uniform(low=0.0, high=100.0, size=(1500, 2000))
index 3634641..c605726 100644 (file)
@@ -10,8 +10,7 @@ from mlir.passmanager import PassManager
 
 
 def setup_passes(mlir_module):
-    """Setup pass pipeline parameters for benchmark functions.
-    """
+    """Setup pass pipeline parameters for benchmark functions."""
     opt = (
         "parallelization-strategy=none"
         " vectorization-strategy=none vl=1 enable-simd-index32=False"
@@ -43,12 +42,15 @@ def get_kernel_func_from_module(module: ir.Module) -> func.FuncOp:
     This function only works for a module with one region, one block, and one
     operation.
     """
-    assert len(module.operation.regions) == 1, \
-        "Expected kernel module to have only one region"
-    assert len(module.operation.regions[0].blocks) == 1, \
-        "Expected kernel module to have only one block"
-    assert len(module.operation.regions[0].blocks[0].operations) == 1, \
-        "Expected kernel module to have only one operation"
+    assert (
+        len(module.operation.regions) == 1
+    ), "Expected kernel module to have only one region"
+    assert (
+        len(module.operation.regions[0].blocks) == 1
+    ), "Expected kernel module to have only one block"
+    assert (
+        len(module.operation.regions[0].blocks[0].operations) == 1
+    ), "Expected kernel module to have only one operation"
     return module.operation.regions[0].blocks[0].operations[0]
 
 
@@ -57,8 +59,7 @@ def emit_timer_func() -> func.FuncOp:
     used, the `MLIR_RUNNER_UTILS` and `MLIR_C_RUNNER_UTILS` must be included.
     """
     i64_type = ir.IntegerType.get_signless(64)
-    nanoTime = func.FuncOp(
-        "nanoTime", ([], [i64_type]), visibility="private")
+    nanoTime = func.FuncOp("nanoTime", ([], [i64_type]), visibility="private")
     nanoTime.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
     return nanoTime
 
@@ -76,9 +77,8 @@ def emit_benchmark_wrapped_main_func(kernel_func, timer_func):
     wrapped_func = func.FuncOp(
         # Same signature and an extra buffer of indices to save timings.
         "main",
-        (kernel_func.arguments.types + [memref_of_i64_type],
-         kernel_func.type.results),
-        visibility="public"
+        (kernel_func.arguments.types + [memref_of_i64_type], kernel_func.type.results),
+        visibility="public",
     )
     wrapped_func.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
 
@@ -88,13 +88,13 @@ def emit_benchmark_wrapped_main_func(kernel_func, timer_func):
         zero = arith.ConstantOp.create_index(0)
         n_iterations = memref.DimOp(ir.IndexType.get(), timer_buffer, zero)
         one = arith.ConstantOp.create_index(1)
-        iter_args = list(wrapped_func.arguments[-num_results - 1:-1])
+        iter_args = list(wrapped_func.arguments[-num_results - 1 : -1])
         loop = scf.ForOp(zero, n_iterations, one, iter_args)
         with ir.InsertionPoint(loop.body):
             start = func.CallOp(timer_func, [])
             call = func.CallOp(
                 kernel_func,
-                wrapped_func.arguments[:-num_results - 1] + loop.inner_iter_args
+                wrapped_func.arguments[: -num_results - 1] + loop.inner_iter_args,
             )
             end = func.CallOp(timer_func, [])
             time_taken = arith.SubIOp(end, start)
index 601ac8f..e27dddd 100644 (file)
@@ -16,52 +16,55 @@ from lit.llvm.subst import FindTool
 # Configuration file for the 'lit' test runner.
 
 # name: The name of this test suite.
-config.name = 'STANDALONE'
+config.name = "STANDALONE"
 
 config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
 
 # suffixes: A list of file extensions to treat as test files.
-config.suffixes = ['.mlir']
+config.suffixes = [".mlir"]
 
 # test_source_root: The root path where tests are located.
 config.test_source_root = os.path.dirname(__file__)
 
 # test_exec_root: The root path where tests should be run.
-config.test_exec_root = os.path.join(config.standalone_obj_root, 'test')
+config.test_exec_root = os.path.join(config.standalone_obj_root, "test")
 
-config.substitutions.append(('%PATH%', config.environment['PATH']))
-config.substitutions.append(('%shlibext', config.llvm_shlib_ext))
+config.substitutions.append(("%PATH%", config.environment["PATH"]))
+config.substitutions.append(("%shlibext", config.llvm_shlib_ext))
 
-llvm_config.with_system_environment(
-    ['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP'])
+llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"])
 
 llvm_config.use_default_substitutions()
 
 # excludes: A list of directories to exclude from the testsuite. The 'Inputs'
 # subdirectories contain auxiliary inputs for various tests in their parent
 # directories.
-config.excludes = ['Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt']
+config.excludes = ["Inputs", "Examples", "CMakeLists.txt", "README.txt", "LICENSE.txt"]
 
 # test_exec_root: The root path where tests should be run.
-config.test_exec_root = os.path.join(config.standalone_obj_root, 'test')
-config.standalone_tools_dir = os.path.join(config.standalone_obj_root, 'bin')
-config.standalone_libs_dir = os.path.join(config.standalone_obj_root, 'lib')
+config.test_exec_root = os.path.join(config.standalone_obj_root, "test")
+config.standalone_tools_dir = os.path.join(config.standalone_obj_root, "bin")
+config.standalone_libs_dir = os.path.join(config.standalone_obj_root, "lib")
 
-config.substitutions.append(('%standalone_libs', config.standalone_libs_dir))
+config.substitutions.append(("%standalone_libs", config.standalone_libs_dir))
 
 # Tweak the PATH to include the tools dir.
-llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
+llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True)
 
 tool_dirs = [config.standalone_tools_dir, config.llvm_tools_dir]
 tools = [
-    'mlir-opt',
-    'standalone-capi-test',
-    'standalone-opt',
-    'standalone-translate',
+    "mlir-opt",
+    "standalone-capi-test",
+    "standalone-opt",
+    "standalone-translate",
 ]
 
 llvm_config.add_tool_substitutions(tools, tool_dirs)
 
-llvm_config.with_environment('PYTHONPATH', [
-    os.path.join(config.mlir_obj_dir, 'python_packages', 'standalone'),
-], append_path=True)
+llvm_config.with_environment(
+    "PYTHONPATH",
+    [
+        os.path.join(config.mlir_obj_dir, "python_packages", "standalone"),
+    ],
+    append_path=True,
+)
index b70b9d7..3394f18 100644 (file)
@@ -1,4 +1,4 @@
-config.suffixes.add('.py')
+config.suffixes.add(".py")
 
 if not config.enable_bindings_python:
-  config.unsupported = True
+    config.unsupported = True
index 0d8f41c..08e08cb 100644 (file)
@@ -1,17 +1,16 @@
 # RUN: %python %s | FileCheck %s
 
 from mlir_standalone.ir import *
-from mlir_standalone.dialects import (
-  builtin as builtin_d,
-  standalone as standalone_d
-)
+from mlir_standalone.dialects import builtin as builtin_d, standalone as standalone_d
 
 with Context():
-  standalone_d.register_dialect()
-  module = Module.parse("""
+    standalone_d.register_dialect()
+    module = Module.parse(
+        """
     %0 = arith.constant 2 : i32
     %1 = standalone.foo %0 : i32
-    """)
-  # CHECK: %[[C:.*]] = arith.constant 2 : i32
-  # CHECK: standalone.foo %[[C]] : i32
-  print(str(module))
+    """
+    )
+    # CHECK: %[[C:.*]] = arith.constant 2 : i32
+    # CHECK: standalone.foo %[[C]] : i32
+    print(str(module))
index 7d3d1f6..03fcb10 100644 (file)
@@ -10,26 +10,26 @@ _this_dir = os.path.dirname(__file__)
 
 
 def get_lib_dirs() -> Sequence[str]:
-  """Gets the lib directory for linking to shared libraries.
+    """Gets the lib directory for linking to shared libraries.
 
-  On some platforms, the package may need to be built specially to export
-  development libraries.
-  """
-  return [_this_dir]
+    On some platforms, the package may need to be built specially to export
+    development libraries.
+    """
+    return [_this_dir]
 
 
 def get_include_dirs() -> Sequence[str]:
-  """Gets the include directory for compiling against exported C libraries.
+    """Gets the include directory for compiling against exported C libraries.
 
-  Depending on how the package was build, development C libraries may or may
-  not be present.
-  """
-  return [os.path.join(_this_dir, "include")]
+    Depending on how the package was build, development C libraries may or may
+    not be present.
+    """
+    return [os.path.join(_this_dir, "include")]
 
 
 # Perform Python level site initialization. This involves:
 #   1. Attempting to load initializer modules, specific to the distribution.
-#   2. Defining the concrete mlir.ir.Context that does site specific 
+#   2. Defining the concrete mlir.ir.Context that does site specific
 #      initialization.
 #
 # Aside from just being far more convenient to do this at the Python level,
@@ -38,91 +38,106 @@ def get_include_dirs() -> Sequence[str]:
 # in the scope of the base class __init__).
 #
 # For #1, we:
-#   a. Probe for modules named '_mlirRegisterEverything' and 
-#     '_site_initialize_{i}', where 'i' is a number starting at zero and 
+#   a. Probe for modules named '_mlirRegisterEverything' and
+#     '_site_initialize_{i}', where 'i' is a number starting at zero and
 #     proceeding so long as a module with the name is found.
 #   b. If the module has a 'register_dialects' attribute, it will be called
 #     immediately with a DialectRegistry to populate.
 #   c. If the module has a 'context_init_hook', it will be added to a list
-#     of callbacks that are invoked as the last step of Context 
+#     of callbacks that are invoked as the last step of Context
 #     initialization (and passed the Context under construction).
 #
 # This facility allows downstreams to customize Context creation to their
 # needs.
 def _site_initialize():
-  import importlib
-  import itertools
-  import logging
-  from ._mlir import ir
-  logger = logging.getLogger(__name__)
-  registry = ir.DialectRegistry()
-  post_init_hooks = []
-
-  def process_initializer_module(module_name):
-    try:
-      m = importlib.import_module(f".{module_name}", __name__)
-    except ModuleNotFoundError:
-      return False
-    except ImportError:
-      message = (f"Error importing mlir initializer {module_name}. This may "
-      "happen in unclean incremental builds but is likely a real bug if "
-      "encountered otherwise and the MLIR Python API may not function.")
-      logger.warning(message, exc_info=True)
-
-    logger.debug("Initializing MLIR with module: %s", module_name)
-    if hasattr(m, "register_dialects"):
-      logger.debug("Registering dialects from initializer %r", m)
-      m.register_dialects(registry)
-    if hasattr(m, "context_init_hook"):
-      logger.debug("Adding context init hook from %r", m)
-      post_init_hooks.append(m.context_init_hook)
-    return True
-
-
-  # If _mlirRegisterEverything is built, then include it as an initializer
-  # module.
-  process_initializer_module("_mlirRegisterEverything")
-
-  # Load all _site_initialize_{i} modules, where 'i' is a number starting
-  # at 0.
-  for i in itertools.count():
-    module_name = f"_site_initialize_{i}"
-    if not process_initializer_module(module_name):
-      break
-
-  class Context(ir._BaseContext):
-    def __init__(self, *args, **kwargs):
-      super().__init__(*args, **kwargs)
-      self.append_dialect_registry(registry)
-      for hook in post_init_hooks:
-        hook(self)
-      # TODO: There is some debate about whether we should eagerly load
-      # all dialects. It is being done here in order to preserve existing
-      # behavior. See: https://github.com/llvm/llvm-project/issues/56037
-      self.load_all_available_dialects()
-  ir.Context = Context
-
-  class MLIRError(Exception):
-    """
-    An exception with diagnostic information. Has the following fields:
-      message: str
-      error_diagnostics: List[ir.DiagnosticInfo]
-    """
-    def __init__(self, message, error_diagnostics):
-      self.message = message
-      self.error_diagnostics = error_diagnostics
-      super().__init__(message, error_diagnostics)
-
-    def __str__(self):
-      s = self.message
-      if self.error_diagnostics:
-        s += ':'
-      for diag in self.error_diagnostics:
-        s += "\nerror: "  + str(diag.location)[4:-1] + ": " + diag.message.replace('\n', '\n  ')
-        for note in diag.notes:
-          s += "\n note: "  + str(note.location)[4:-1] + ": " + note.message.replace('\n', '\n  ')
-      return s
-  ir.MLIRError = MLIRError
+    import importlib
+    import itertools
+    import logging
+    from ._mlir import ir
+
+    logger = logging.getLogger(__name__)
+    registry = ir.DialectRegistry()
+    post_init_hooks = []
+
+    def process_initializer_module(module_name):
+        try:
+            m = importlib.import_module(f".{module_name}", __name__)
+        except ModuleNotFoundError:
+            return False
+        except ImportError:
+            message = (
+                f"Error importing mlir initializer {module_name}. This may "
+                "happen in unclean incremental builds but is likely a real bug if "
+                "encountered otherwise and the MLIR Python API may not function."
+            )
+            logger.warning(message, exc_info=True)
+
+        logger.debug("Initializing MLIR with module: %s", module_name)
+        if hasattr(m, "register_dialects"):
+            logger.debug("Registering dialects from initializer %r", m)
+            m.register_dialects(registry)
+        if hasattr(m, "context_init_hook"):
+            logger.debug("Adding context init hook from %r", m)
+            post_init_hooks.append(m.context_init_hook)
+        return True
+
+    # If _mlirRegisterEverything is built, then include it as an initializer
+    # module.
+    process_initializer_module("_mlirRegisterEverything")
+
+    # Load all _site_initialize_{i} modules, where 'i' is a number starting
+    # at 0.
+    for i in itertools.count():
+        module_name = f"_site_initialize_{i}"
+        if not process_initializer_module(module_name):
+            break
+
+    class Context(ir._BaseContext):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.append_dialect_registry(registry)
+            for hook in post_init_hooks:
+                hook(self)
+            # TODO: There is some debate about whether we should eagerly load
+            # all dialects. It is being done here in order to preserve existing
+            # behavior. See: https://github.com/llvm/llvm-project/issues/56037
+            self.load_all_available_dialects()
+
+    ir.Context = Context
+
+    class MLIRError(Exception):
+        """
+        An exception with diagnostic information. Has the following fields:
+          message: str
+          error_diagnostics: List[ir.DiagnosticInfo]
+        """
+
+        def __init__(self, message, error_diagnostics):
+            self.message = message
+            self.error_diagnostics = error_diagnostics
+            super().__init__(message, error_diagnostics)
+
+        def __str__(self):
+            s = self.message
+            if self.error_diagnostics:
+                s += ":"
+            for diag in self.error_diagnostics:
+                s += (
+                    "\nerror: "
+                    + str(diag.location)[4:-1]
+                    + ": "
+                    + diag.message.replace("\n", "\n  ")
+                )
+                for note in diag.notes:
+                    s += (
+                        "\n note: "
+                        + str(note.location)[4:-1]
+                        + ": "
+                        + note.message.replace("\n", "\n  ")
+                    )
+            return s
+
+    ir.MLIRError = MLIRError
 
 
 _site_initialize()
index 2408593..df38f87 100644 (file)
@@ -3,72 +3,67 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from ..ir import *
-  from ._ods_common import get_default_loc_context as _get_default_loc_context
+    from ..ir import *
+    from ._ods_common import get_default_loc_context as _get_default_loc_context
 
-  from typing import Any, List, Union
+    from typing import Any, List, Union
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
 
 
 def _isa(obj: Any, cls: type):
-  try:
-    cls(obj)
-  except ValueError:
-    return False
-  return True
+    try:
+        cls(obj)
+    except ValueError:
+        return False
+    return True
 
 
 def _is_any_of(obj: Any, classes: List[type]):
-  return any(_isa(obj, cls) for cls in classes)
+    return any(_isa(obj, cls) for cls in classes)
 
 
 def _is_integer_like_type(type: Type):
-  return _is_any_of(type, [IntegerType, IndexType])
+    return _is_any_of(type, [IntegerType, IndexType])
 
 
 def _is_float_type(type: Type):
-  return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
+    return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
 
 
 class ConstantOp:
-  """Specialization for the constant op class."""
-
-  def __init__(self,
-               result: Type,
-               value: Union[int, float, Attribute],
-               *,
-               loc=None,
-               ip=None):
-    if isinstance(value, int):
-      super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
-    elif isinstance(value, float):
-      super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
-    else:
-      super().__init__(value, loc=loc, ip=ip)
-
-  @classmethod
-  def create_index(cls, value: int, *, loc=None, ip=None):
-    """Create an index-typed constant."""
-    return cls(
-        IndexType.get(context=_get_default_loc_context(loc)),
-        value,
-        loc=loc,
-        ip=ip)
-
-  @property
-  def type(self):
-    return self.results[0].type
-
-  @property
-  def value(self):
-    return Attribute(self.operation.attributes["value"])
-
-  @property
-  def literal_value(self) -> Union[int, float]:
-    if _is_integer_like_type(self.type):
-      return IntegerAttr(self.value).value
-    elif _is_float_type(self.type):
-      return FloatAttr(self.value).value
-    else:
-      raise ValueError("only integer and float constants have literal values")
+    """Specialization for the constant op class."""
+
+    def __init__(
+        self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
+    ):
+        if isinstance(value, int):
+            super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
+        elif isinstance(value, float):
+            super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
+        else:
+            super().__init__(value, loc=loc, ip=ip)
+
+    @classmethod
+    def create_index(cls, value: int, *, loc=None, ip=None):
+        """Create an index-typed constant."""
+        return cls(
+            IndexType.get(context=_get_default_loc_context(loc)), value, loc=loc, ip=ip
+        )
+
+    @property
+    def type(self):
+        return self.results[0].type
+
+    @property
+    def value(self):
+        return Attribute(self.operation.attributes["value"])
+
+    @property
+    def literal_value(self) -> Union[int, float]:
+        if _is_integer_like_type(self.type):
+            return IntegerAttr(self.value).value
+        elif _is_float_type(self.type):
+            return FloatAttr(self.value).value
+        else:
+            raise ValueError("only integer and float constants have literal values")
index 6ed35f4..1066cb4 100644 (file)
@@ -3,36 +3,39 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from typing import Sequence, Union
-  from ..ir import *
-  from ._ods_common import get_default_loc_context
+    from typing import Sequence, Union
+    from ..ir import *
+    from ._ods_common import get_default_loc_context
 
-  from typing import Any, List, Union
+    from typing import Any, List, Union
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
 
 
 class AllocTensorOp:
-  """Extends the bufferization.alloc_tensor op."""
+    """Extends the bufferization.alloc_tensor op."""
 
-  def __init__(self,
-               tensor_type: Type,
-               dynamic_sizes: Sequence[Value],
-               copy: Value,
-               size_hint: Value,
-               escape: BoolAttr,
-               *,
-               loc=None,
-               ip=None):
-    """Constructs an `alloc_tensor` with static and/or dynamic sizes."""
-    context = get_default_loc_context(loc)
-    attributes = {}
-    if escape:
-      attributes["escape"] = escape
-    op = self.build_generic(
-        results=[tensor_type],
-        operands=[dynamic_sizes, copy, size_hint],
-        attributes=attributes,
-        loc=loc,
-        ip=ip)
-    OpView.__init__(self, op)
+    def __init__(
+        self,
+        tensor_type: Type,
+        dynamic_sizes: Sequence[Value],
+        copy: Value,
+        size_hint: Value,
+        escape: BoolAttr,
+        *,
+        loc=None,
+        ip=None
+    ):
+        """Constructs an `alloc_tensor` with static and/or dynamic sizes."""
+        context = get_default_loc_context(loc)
+        attributes = {}
+        if escape:
+            attributes["escape"] = escape
+        op = self.build_generic(
+            results=[tensor_type],
+            operands=[dynamic_sizes, copy, size_hint],
+            attributes=attributes,
+            loc=loc,
+            ip=ip,
+        )
+        OpView.__init__(self, op)
index b69163f..27a6012 100644 (file)
@@ -3,18 +3,18 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from ..ir import *
+    from ..ir import *
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
+
 
 class ModuleOp:
-  """Specialization for the module op class."""
+    """Specialization for the module op class."""
 
-  def __init__(self, *, loc=None, ip=None):
-    super().__init__(self.build_generic(results=[], operands=[], loc=loc,
-                                        ip=ip))
-    body = self.regions[0].blocks.append()
+    def __init__(self, *, loc=None, ip=None):
+        super().__init__(self.build_generic(results=[], operands=[], loc=loc, ip=ip))
+        body = self.regions[0].blocks.append()
 
-  @property
-  def body(self):
-    return self.regions[0].blocks[0]
+    @property
+    def body(self):
+        return self.regions[0].blocks[0]
index 56df423..6d264c3 100644 (file)
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from ..ir import *
-  from ._ods_common import get_default_loc_context as _get_default_loc_context
+    from ..ir import *
+    from ._ods_common import get_default_loc_context as _get_default_loc_context
 
-  import inspect
+    import inspect
 
-  from typing import Any, List, Optional, Sequence, Union
+    from typing import Any, List, Optional, Sequence, Union
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
 
 ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
 RESULT_ATTRIBUTE_NAME = "res_attrs"
 
+
 class ConstantOp:
-  """Specialization for the constant op class."""
+    """Specialization for the constant op class."""
 
-  def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
-    super().__init__(result, value, loc=loc, ip=ip)
+    def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
+        super().__init__(result, value, loc=loc, ip=ip)
 
-  @property
-  def type(self):
-    return self.results[0].type
+    @property
+    def type(self):
+        return self.results[0].type
 
 
 class FuncOp:
-  """Specialization for the func op class."""
-
-  def __init__(self,
-               name,
-               type,
-               *,
-               visibility=None,
-               body_builder=None,
-               loc=None,
-               ip=None):
-    """
-    Create a FuncOp with the provided `name`, `type`, and `visibility`.
-    - `name` is a string representing the function name.
-    - `type` is either a FunctionType or a pair of list describing inputs and
-      results.
-    - `visibility` is a string matching `public`, `private`, or `nested`. None
-      implies private visibility.
-    - `body_builder` is an optional callback, when provided a new entry block
-      is created and the callback is invoked with the new op as argument within
-      an InsertionPoint context already set for the block. The callback is
-      expected to insert a terminator in the block.
-    """
-    sym_name = StringAttr.get(str(name))
-
-    # If the type is passed as a tuple, build a FunctionType on the fly.
-    if isinstance(type, tuple):
-      type = FunctionType.get(inputs=type[0], results=type[1])
-
-    type = TypeAttr.get(type)
-    sym_visibility = StringAttr.get(
-        str(visibility)) if visibility is not None else None
-    super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
-    if body_builder:
-      entry_block = self.add_entry_block()
-      with InsertionPoint(entry_block):
-        body_builder(self)
-
-  @property
-  def is_external(self):
-    return len(self.regions[0].blocks) == 0
-
-  @property
-  def body(self):
-    return self.regions[0]
-
-  @property
-  def type(self):
-    return FunctionType(TypeAttr(self.attributes["function_type"]).value)
-
-  @property
-  def visibility(self):
-    return self.attributes["sym_visibility"]
-
-  @property
-  def name(self) -> StringAttr:
-    return StringAttr(self.attributes["sym_name"])
-
-  @property
-  def entry_block(self):
-    if self.is_external:
-      raise IndexError('External function does not have a body')
-    return self.regions[0].blocks[0]
-
-  def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
-    """
-    Add an entry block to the function body using the function signature to
-    infer block arguments.
-    Returns the newly created block
-    """
-    if not self.is_external:
-      raise IndexError('The function already has an entry block!')
-    self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
-    return self.body.blocks[0]
-
-  @property
-  def arg_attrs(self):
-    return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
-
-  @arg_attrs.setter
-  def arg_attrs(self, attribute: Union[ArrayAttr, list]):
-    if isinstance(attribute, ArrayAttr):
-      self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
-    else:
-      self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
-          attribute, context=self.context)
-
-  @property
-  def arguments(self):
-    return self.entry_block.arguments
-
-  @property
-  def result_attrs(self):
-    return self.attributes[RESULT_ATTRIBUTE_NAME]
-
-  @result_attrs.setter
-  def result_attrs(self, attribute: ArrayAttr):
-    self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
-
-  @classmethod
-  def from_py_func(FuncOp,
-                   *inputs: Type,
-                   results: Optional[Sequence[Type]] = None,
-                   name: Optional[str] = None):
-    """Decorator to define an MLIR FuncOp specified as a python function.
-
-    Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
-    active for the current thread (i.e. established in a `with` block).
-
-    When applied as a decorator to a Python function, an entry block will
-    be constructed for the FuncOp with types as specified in `*inputs`. The
-    block arguments will be passed positionally to the Python function. In
-    addition, if the Python function accepts keyword arguments generally or
-    has a corresponding keyword argument, the following will be passed:
-      * `func_op`: The `func` op being defined.
-
-    By default, the function name will be the Python function `__name__`. This
-    can be overriden by passing the `name` argument to the decorator.
-
-    If `results` is not specified, then the decorator will implicitly
-    insert a `ReturnOp` with the `Value`'s returned from the decorated
-    function. It will also set the `FuncOp` type with the actual return
-    value types. If `results` is specified, then the decorated function
-    must return `None` and no implicit `ReturnOp` is added (nor are the result
-    types updated). The implicit behavior is intended for simple, single-block
-    cases, and users should specify result types explicitly for any complicated
-    cases.
-
-    The decorated function can further be called from Python and will insert
-    a `CallOp` at the then-current insertion point, returning either None (
-    if no return values), a unary Value (for one result), or a list of Values).
-    This mechanism cannot be used to emit recursive calls (by construction).
-    """
-
-    def decorator(f):
-      from . import func
-      # Introspect the callable for optional features.
-      sig = inspect.signature(f)
-      has_arg_func_op = False
-      for param in sig.parameters.values():
-        if param.kind == param.VAR_KEYWORD:
-          has_arg_func_op = True
-        if param.name == "func_op" and (param.kind
-                                        == param.POSITIONAL_OR_KEYWORD or
-                                        param.kind == param.KEYWORD_ONLY):
-          has_arg_func_op = True
-
-      # Emit the FuncOp.
-      implicit_return = results is None
-      symbol_name = name or f.__name__
-      function_type = FunctionType.get(
-          inputs=inputs, results=[] if implicit_return else results)
-      func_op = FuncOp(name=symbol_name, type=function_type)
-      with InsertionPoint(func_op.add_entry_block()):
-        func_args = func_op.entry_block.arguments
-        func_kwargs = {}
-        if has_arg_func_op:
-          func_kwargs["func_op"] = func_op
-        return_values = f(*func_args, **func_kwargs)
-        if not implicit_return:
-          return_types = list(results)
-          assert return_values is None, (
-              "Capturing a python function with explicit `results=` "
-              "requires that the wrapped function returns None.")
-        else:
-          # Coerce return values, add ReturnOp and rewrite func type.
-          if return_values is None:
-            return_values = []
-          elif isinstance(return_values, tuple):
-            return_values = list(return_values)
-          elif isinstance(return_values, Value):
-            # Returning a single value is fine, coerce it into a list.
-            return_values = [return_values]
-          elif isinstance(return_values, OpView):
-            # Returning a single operation is fine, coerce its results a list.
-            return_values = return_values.operation.results
-          elif isinstance(return_values, Operation):
-            # Returning a single operation is fine, coerce its results a list.
-            return_values = return_values.results
-          else:
-            return_values = list(return_values)
-          func.ReturnOp(return_values)
-          # Recompute the function type.
-          return_types = [v.type for v in return_values]
-          function_type = FunctionType.get(inputs=inputs, results=return_types)
-          func_op.attributes["function_type"] = TypeAttr.get(function_type)
-
-      def emit_call_op(*call_args):
-        call_op = func.CallOp(return_types, FlatSymbolRefAttr.get(symbol_name),
-                              call_args)
-        if return_types is None:
-          return None
-        elif len(return_types) == 1:
-          return call_op.result
+    """Specialization for the func op class."""
+
+    def __init__(
+        self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
+    ):
+        """
+        Create a FuncOp with the provided `name`, `type`, and `visibility`.
+        - `name` is a string representing the function name.
+        - `type` is either a FunctionType or a pair of list describing inputs and
+          results.
+        - `visibility` is a string matching `public`, `private`, or `nested`. None
+          implies private visibility.
+        - `body_builder` is an optional callback, when provided a new entry block
+          is created and the callback is invoked with the new op as argument within
+          an InsertionPoint context already set for the block. The callback is
+          expected to insert a terminator in the block.
+        """
+        sym_name = StringAttr.get(str(name))
+
+        # If the type is passed as a tuple, build a FunctionType on the fly.
+        if isinstance(type, tuple):
+            type = FunctionType.get(inputs=type[0], results=type[1])
+
+        type = TypeAttr.get(type)
+        sym_visibility = (
+            StringAttr.get(str(visibility)) if visibility is not None else None
+        )
+        super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
+        if body_builder:
+            entry_block = self.add_entry_block()
+            with InsertionPoint(entry_block):
+                body_builder(self)
+
+    @property
+    def is_external(self):
+        return len(self.regions[0].blocks) == 0
+
+    @property
+    def body(self):
+        return self.regions[0]
+
+    @property
+    def type(self):
+        return FunctionType(TypeAttr(self.attributes["function_type"]).value)
+
+    @property
+    def visibility(self):
+        return self.attributes["sym_visibility"]
+
+    @property
+    def name(self) -> StringAttr:
+        return StringAttr(self.attributes["sym_name"])
+
+    @property
+    def entry_block(self):
+        if self.is_external:
+            raise IndexError("External function does not have a body")
+        return self.regions[0].blocks[0]
+
+    def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
+        """
+        Add an entry block to the function body using the function signature to
+        infer block arguments.
+        Returns the newly created block
+        """
+        if not self.is_external:
+            raise IndexError("The function already has an entry block!")
+        self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
+        return self.body.blocks[0]
+
+    @property
+    def arg_attrs(self):
+        return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
+
+    @arg_attrs.setter
+    def arg_attrs(self, attribute: Union[ArrayAttr, list]):
+        if isinstance(attribute, ArrayAttr):
+            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
         else:
-          return call_op.results
-
-      wrapped = emit_call_op
-      wrapped.__name__ = f.__name__
-      wrapped.func_op = func_op
-      return wrapped
+            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
+                attribute, context=self.context
+            )
+
+    @property
+    def arguments(self):
+        return self.entry_block.arguments
+
+    @property
+    def result_attrs(self):
+        return self.attributes[RESULT_ATTRIBUTE_NAME]
+
+    @result_attrs.setter
+    def result_attrs(self, attribute: ArrayAttr):
+        self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
+
+    @classmethod
+    def from_py_func(
+        FuncOp,
+        *inputs: Type,
+        results: Optional[Sequence[Type]] = None,
+        name: Optional[str] = None,
+    ):
+        """Decorator to define an MLIR FuncOp specified as a python function.
+
+        Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
+        active for the current thread (i.e. established in a `with` block).
+
+        When applied as a decorator to a Python function, an entry block will
+        be constructed for the FuncOp with types as specified in `*inputs`. The
+        block arguments will be passed positionally to the Python function. In
+        addition, if the Python function accepts keyword arguments generally or
+        has a corresponding keyword argument, the following will be passed:
+          * `func_op`: The `func` op being defined.
+
+        By default, the function name will be the Python function `__name__`. This
+        can be overriden by passing the `name` argument to the decorator.
+
+        If `results` is not specified, then the decorator will implicitly
+        insert a `ReturnOp` with the `Value`'s returned from the decorated
+        function. It will also set the `FuncOp` type with the actual return
+        value types. If `results` is specified, then the decorated function
+        must return `None` and no implicit `ReturnOp` is added (nor are the result
+        types updated). The implicit behavior is intended for simple, single-block
+        cases, and users should specify result types explicitly for any complicated
+        cases.
+
+        The decorated function can further be called from Python and will insert
+        a `CallOp` at the then-current insertion point, returning either None (
+        if no return values), a unary Value (for one result), or a list of Values).
+        This mechanism cannot be used to emit recursive calls (by construction).
+        """
+
+        def decorator(f):
+            from . import func
+
+            # Introspect the callable for optional features.
+            sig = inspect.signature(f)
+            has_arg_func_op = False
+            for param in sig.parameters.values():
+                if param.kind == param.VAR_KEYWORD:
+                    has_arg_func_op = True
+                if param.name == "func_op" and (
+                    param.kind == param.POSITIONAL_OR_KEYWORD
+                    or param.kind == param.KEYWORD_ONLY
+                ):
+                    has_arg_func_op = True
+
+            # Emit the FuncOp.
+            implicit_return = results is None
+            symbol_name = name or f.__name__
+            function_type = FunctionType.get(
+                inputs=inputs, results=[] if implicit_return else results
+            )
+            func_op = FuncOp(name=symbol_name, type=function_type)
+            with InsertionPoint(func_op.add_entry_block()):
+                func_args = func_op.entry_block.arguments
+                func_kwargs = {}
+                if has_arg_func_op:
+                    func_kwargs["func_op"] = func_op
+                return_values = f(*func_args, **func_kwargs)
+                if not implicit_return:
+                    return_types = list(results)
+                    assert return_values is None, (
+                        "Capturing a python function with explicit `results=` "
+                        "requires that the wrapped function returns None."
+                    )
+                else:
+                    # Coerce return values, add ReturnOp and rewrite func type.
+                    if return_values is None:
+                        return_values = []
+                    elif isinstance(return_values, tuple):
+                        return_values = list(return_values)
+                    elif isinstance(return_values, Value):
+                        # Returning a single value is fine, coerce it into a list.
+                        return_values = [return_values]
+                    elif isinstance(return_values, OpView):
+                        # Returning a single operation is fine, coerce its results a list.
+                        return_values = return_values.operation.results
+                    elif isinstance(return_values, Operation):
+                        # Returning a single operation is fine, coerce its results a list.
+                        return_values = return_values.results
+                    else:
+                        return_values = list(return_values)
+                    func.ReturnOp(return_values)
+                    # Recompute the function type.
+                    return_types = [v.type for v in return_values]
+                    function_type = FunctionType.get(
+                        inputs=inputs, results=return_types
+                    )
+                    func_op.attributes["function_type"] = TypeAttr.get(function_type)
+
+            def emit_call_op(*call_args):
+                call_op = func.CallOp(
+                    return_types, FlatSymbolRefAttr.get(symbol_name), call_args
+                )
+                if return_types is None:
+                    return None
+                elif len(return_types) == 1:
+                    return call_op.result
+                else:
+                    return call_op.results
+
+            wrapped = emit_call_op
+            wrapped.__name__ = f.__name__
+            wrapped.func_op = func_op
+            return wrapped
+
+        return decorator
 
-    return decorator
 
 class CallOp:
-  """Specialization for the call op class."""
-
-  def __init__(self,
-               calleeOrResults: Union[FuncOp, List[Type]],
-               argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
-               arguments: Optional[List] = None,
-               *,
-               loc=None,
-               ip=None):
-    """Creates an call operation.
-
-    The constructor accepts three different forms:
-
-      1. A function op to be called followed by a list of arguments.
-      2. A list of result types, followed by the name of the function to be
-         called as string, following by a list of arguments.
-      3. A list of result types, followed by the name of the function to be
-         called as symbol reference attribute, followed by a list of arguments.
-
-    For example
-
-        f = func.FuncOp("foo", ...)
-        func.CallOp(f, [args])
-        func.CallOp([result_types], "foo", [args])
-
-    In all cases, the location and insertion point may be specified as keyword
-    arguments if not provided by the surrounding context managers.
-    """
-
-    # TODO: consider supporting constructor "overloads", e.g., through a custom
-    # or pybind-provided metaclass.
-    if isinstance(calleeOrResults, FuncOp):
-      if not isinstance(argumentsOrCallee, list):
-        raise ValueError(
-            "when constructing a call to a function, expected " +
-            "the second argument to be a list of call arguments, " +
-            f"got {type(argumentsOrCallee)}")
-      if arguments is not None:
-        raise ValueError("unexpected third argument when constructing a call" +
-                         "to a function")
-
-      super().__init__(
-          calleeOrResults.type.results,
-          FlatSymbolRefAttr.get(
-              calleeOrResults.name.value,
-              context=_get_default_loc_context(loc)),
-          argumentsOrCallee,
-          loc=loc,
-          ip=ip)
-      return
-
-    if isinstance(argumentsOrCallee, list):
-      raise ValueError("when constructing a call to a function by name, " +
-                       "expected the second argument to be a string or a " +
-                       f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}")
-
-    if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
-      super().__init__(
-          calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip)
-    elif isinstance(argumentsOrCallee, str):
-      super().__init__(
-          calleeOrResults,
-          FlatSymbolRefAttr.get(
-              argumentsOrCallee, context=_get_default_loc_context(loc)),
-          arguments,
-          loc=loc,
-          ip=ip)
+    """Specialization for the call op class."""
+
+    def __init__(
+        self,
+        calleeOrResults: Union[FuncOp, List[Type]],
+        argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
+        arguments: Optional[List] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        """Creates an call operation.
+
+        The constructor accepts three different forms:
+
+          1. A function op to be called followed by a list of arguments.
+          2. A list of result types, followed by the name of the function to be
+             called as string, following by a list of arguments.
+          3. A list of result types, followed by the name of the function to be
+             called as symbol reference attribute, followed by a list of arguments.
+
+        For example
+
+            f = func.FuncOp("foo", ...)
+            func.CallOp(f, [args])
+            func.CallOp([result_types], "foo", [args])
+
+        In all cases, the location and insertion point may be specified as keyword
+        arguments if not provided by the surrounding context managers.
+        """
+
+        # TODO: consider supporting constructor "overloads", e.g., through a custom
+        # or pybind-provided metaclass.
+        if isinstance(calleeOrResults, FuncOp):
+            if not isinstance(argumentsOrCallee, list):
+                raise ValueError(
+                    "when constructing a call to a function, expected "
+                    + "the second argument to be a list of call arguments, "
+                    + f"got {type(argumentsOrCallee)}"
+                )
+            if arguments is not None:
+                raise ValueError(
+                    "unexpected third argument when constructing a call "
+                    + "to a function"
+                )
+
+            super().__init__(
+                calleeOrResults.type.results,
+                FlatSymbolRefAttr.get(
+                    calleeOrResults.name.value, context=_get_default_loc_context(loc)
+                ),
+                argumentsOrCallee,
+                loc=loc,
+                ip=ip,
+            )
+            return
+
+        if isinstance(argumentsOrCallee, list):
+            raise ValueError(
+                "when constructing a call to a function by name, "
+                + "expected the second argument to be a string or a "
+                + f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}"
+            )
+
+        if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
+            super().__init__(
+                calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip
+            )
+        elif isinstance(argumentsOrCallee, str):
+            super().__init__(
+                calleeOrResults,
+                FlatSymbolRefAttr.get(
+                    argumentsOrCallee, context=_get_default_loc_context(loc)
+                ),
+                arguments,
+                loc=loc,
+                ip=ip,
+            )
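
As a quick sketch of the three forms above (this assumes the upstream `mlir` Python package and an ambient Context/Location; the function names are illustrative):

from mlir.ir import Context, Location, Module, InsertionPoint, F32Type
from mlir.dialects import func

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        f32 = F32Type.get()
        callee = func.FuncOp("callee", ((f32,), (f32,)), visibility="private")
        caller = func.FuncOp("caller", ((f32,), (f32,)))
        with InsertionPoint(caller.add_entry_block()):
            arg = caller.arguments[0]
            # Form 1: pass the FuncOp itself.
            a = func.CallOp(callee, [arg])
            # Form 2: result types plus the callee symbol as a string.
            b = func.CallOp([f32], "callee", [a.results[0]])
            func.ReturnOp([b.results[0]])
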
index eb9e969..3f6d854 100644 (file)
@@ -3,39 +3,45 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from typing import Optional, Sequence, Union
-  from ..ir import *
-  from ._ods_common import get_default_loc_context
-  from .._mlir_libs._mlirDialectsLinalg import fill_builtin_region
+    from typing import Optional, Sequence, Union
+    from ..ir import *
+    from ._ods_common import get_default_loc_context
+    from .._mlir_libs._mlirDialectsLinalg import fill_builtin_region
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
 
 from ._ods_common import get_op_result_or_value as _get_op_result_or_value
 
+
 def isa(cls: Type, ty: Type):
-  try:
-    cls(ty)
-    return True
-  except ValueError:
-    return False
+    try:
+        cls(ty)
+        return True
+    except ValueError:
+        return False
 
 
 class StructuredOpMixin:
-  """All structured ops use the same mixin class."""
+    """All structured ops use the same mixin class."""
 
-  def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
-    super().__init__(
-        self.build_generic(results=list(results),
-                           operands=[list(inputs), list(outputs)],
-                           loc=loc,
-                           ip=ip))
+    def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
+        super().__init__(
+            self.build_generic(
+                results=list(results),
+                operands=[list(inputs), list(outputs)],
+                loc=loc,
+                ip=ip,
+            )
+        )
 
 
 def select_opview_mixin(parent_opview_cls):
-  # TODO: This shouldn't be a heuristic: we should have a way to annotate
-  # the OpView to note that it is a structured op.
-  if ("__init__" not in parent_opview_cls.__dict__ and
-      hasattr(parent_opview_cls, "inputs") and
-      hasattr(parent_opview_cls, "outputs") and
-      hasattr(parent_opview_cls, "result_tensors")):
-    return StructuredOpMixin
+    # TODO: This shouldn't be a heuristic: we should have a way to annotate
+    # the OpView to note that it is a structured op.
+    if (
+        "__init__" not in parent_opview_cls.__dict__
+        and hasattr(parent_opview_cls, "inputs")
+        and hasattr(parent_opview_cls, "outputs")
+        and hasattr(parent_opview_cls, "result_tensors")
+    ):
+        return StructuredOpMixin
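
For reference, a tiny sketch of how the `isa` helper behaves (it lives in this internal module, so the import below is illustrative; an ambient Context is assumed):

from mlir.ir import Context, F32Type, RankedTensorType, Type
from mlir.dialects._linalg_ops_ext import isa

with Context():
    ty = Type.parse("tensor<4xf32>")
    assert isa(RankedTensorType, ty)  # the cast constructor succeeds
    assert not isa(F32Type, ty)       # the cast raises ValueError, so False
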
index 10079d3..3536d45 100644 (file)
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from ..ir import *
-  from ._ods_common import get_op_result_or_value as _get_op_result_or_value
+    from ..ir import *
+    from ._ods_common import get_op_result_or_value as _get_op_result_or_value
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
 
 from typing import Optional, Union
 
 
 class GetParentForOp:
-  """Extension for GetParentForOp."""
-
-  def __init__(
-      self,
-      result_type: Type,
-      target: Union[Operation, Value],
-      *,
-      num_loops: Optional[int] = None,
-      ip=None,
-      loc=None,
-  ):
-    if num_loops is None:
-      num_loops = 1
-    super().__init__(
-        result_type,
-        _get_op_result_or_value(target),
-        num_loops=num_loops,
-        ip=ip,
-        loc=loc,
-    )
+    """Extension for GetParentForOp."""
+
+    def __init__(
+        self,
+        result_type: Type,
+        target: Union[Operation, Value],
+        *,
+        num_loops: Optional[int] = None,
+        ip=None,
+        loc=None,
+    ):
+        if num_loops is None:
+            num_loops = 1
+        super().__init__(
+            result_type,
+            _get_op_result_or_value(target),
+            num_loops=num_loops,
+            ip=ip,
+            loc=loc,
+        )
 
 
 class LoopOutlineOp:
-  """Extension for LoopOutlineOp."""
-
-  def __init__(
-      self,
-      function_type: Type,
-      call_type: Type,
-      target: Union[Operation, Value],
-      *,
-      func_name: Union[str, StringAttr],
-      ip=None,
-      loc=None,
-  ):
-    super().__init__(
-        function_type,
-        call_type,
-        _get_op_result_or_value(target),
-        func_name=(func_name if isinstance(func_name, StringAttr) else
-                   StringAttr.get(func_name)),
-        ip=ip,
-        loc=loc,
-    )
+    """Extension for LoopOutlineOp."""
+
+    def __init__(
+        self,
+        function_type: Type,
+        call_type: Type,
+        target: Union[Operation, Value],
+        *,
+        func_name: Union[str, StringAttr],
+        ip=None,
+        loc=None,
+    ):
+        super().__init__(
+            function_type,
+            call_type,
+            _get_op_result_or_value(target),
+            func_name=(
+                func_name
+                if isinstance(func_name, StringAttr)
+                else StringAttr.get(func_name)
+            ),
+            ip=ip,
+            loc=loc,
+        )
 
 
 class LoopPeelOp:
-  """Extension for LoopPeelOp."""
-
-  def __init__(
-      self,
-      result_type: Type,
-      target: Union[Operation, Value],
-      *,
-      fail_if_already_divisible: Union[bool, BoolAttr] = False,
-      ip=None,
-      loc=None,
-  ):
-    super().__init__(
-        result_type,
-        _get_op_result_or_value(target),
-        fail_if_already_divisible=(fail_if_already_divisible if isinstance(
-            fail_if_already_divisible, BoolAttr) else
-                                   BoolAttr.get(fail_if_already_divisible)),
-        ip=ip,
-        loc=loc,
-    )
+    """Extension for LoopPeelOp."""
+
+    def __init__(
+        self,
+        result_type: Type,
+        target: Union[Operation, Value],
+        *,
+        fail_if_already_divisible: Union[bool, BoolAttr] = False,
+        ip=None,
+        loc=None,
+    ):
+        super().__init__(
+            result_type,
+            _get_op_result_or_value(target),
+            fail_if_already_divisible=(
+                fail_if_already_divisible
+                if isinstance(fail_if_already_divisible, BoolAttr)
+                else BoolAttr.get(fail_if_already_divisible)
+            ),
+            ip=ip,
+            loc=loc,
+        )
 
 
 class LoopPipelineOp:
-  """Extension for LoopPipelineOp."""
-
-  def __init__(
-      self,
-      result_type: Type,
-      target: Union[Operation, Value],
-      *,
-      iteration_interval: Optional[Union[int, IntegerAttr]] = None,
-      read_latency: Optional[Union[int, IntegerAttr]] = None,
-      ip=None,
-      loc=None,
-  ):
-    if iteration_interval is None:
-      iteration_interval = 1
-    if read_latency is None:
-      read_latency = 10
-    super().__init__(
-        result_type,
-        _get_op_result_or_value(target),
-        iteration_interval=iteration_interval,
-        read_latency=read_latency,
-        ip=ip,
-        loc=loc,
-    )
+    """Extension for LoopPipelineOp."""
+
+    def __init__(
+        self,
+        result_type: Type,
+        target: Union[Operation, Value],
+        *,
+        iteration_interval: Optional[Union[int, IntegerAttr]] = None,
+        read_latency: Optional[Union[int, IntegerAttr]] = None,
+        ip=None,
+        loc=None,
+    ):
+        if iteration_interval is None:
+            iteration_interval = 1
+        if read_latency is None:
+            read_latency = 10
+        super().__init__(
+            result_type,
+            _get_op_result_or_value(target),
+            iteration_interval=iteration_interval,
+            read_latency=read_latency,
+            ip=ip,
+            loc=loc,
+        )
 
 
 class LoopUnrollOp:
-  """Extension for LoopUnrollOp."""
-
-  def __init__(
-      self,
-      target: Union[Operation, Value],
-      *,
-      factor: Union[int, IntegerAttr],
-      ip=None,
-      loc=None,
-  ):
-    super().__init__(
-        _get_op_result_or_value(target),
-        factor=factor,
-        ip=ip,
-        loc=loc,
-    )
+    """Extension for LoopUnrollOp."""
+
+    def __init__(
+        self,
+        target: Union[Operation, Value],
+        *,
+        factor: Union[int, IntegerAttr],
+        ip=None,
+        loc=None,
+    ):
+        super().__init__(
+            _get_op_result_or_value(target),
+            factor=factor,
+            ip=ip,
+            loc=loc,
+        )
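
A rough sketch of chaining these helpers; the `mlir.dialects.transform.loop` module path, the use of `pdl.OperationType` for handles, and the `handle` argument (a value produced inside an enclosing transform sequence) are all assumptions rather than something this diff establishes:

from mlir.dialects import pdl
from mlir.dialects.transform import loop


def unroll_enclosing_loop(handle):
    # Walk up to the scf.for that contains the matched payload op ...
    parent = loop.GetParentForOp(pdl.OperationType.get(), handle, num_loops=1)
    # ... peel it, then unroll the resulting main loop by a factor of 4.
    peeled = loop.LoopPeelOp(pdl.OperationType.get(), parent)
    loop.LoopUnrollOp(peeled, factor=4)
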
index a00a087..825f1a0 100644 (file)
@@ -3,34 +3,34 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from ..ir import *
-  from ._ods_common import get_op_result_or_value as _get_op_result_or_value
-  from ._ods_common import get_op_results_or_values as _get_op_results_or_values
+    from ..ir import *
+    from ._ods_common import get_op_result_or_value as _get_op_result_or_value
+    from ._ods_common import get_op_results_or_values as _get_op_results_or_values
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
 
 from typing import Optional, Sequence, Union
 
 
 class LoadOp:
-  """Specialization for the MemRef load operation."""
+    """Specialization for the MemRef load operation."""
 
-  def __init__(self,
-               memref: Union[Operation, OpView, Value],
-               indices: Optional[Union[Operation, OpView,
-                                       Sequence[Value]]] = None,
-               *,
-               loc=None,
-               ip=None):
-    """Creates a memref load operation.
+    def __init__(
+        self,
+        memref: Union[Operation, OpView, Value],
+        indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+        *,
+        loc=None,
+        ip=None
+    ):
+        """Creates a memref load operation.
 
-    Args:
-      memref: the buffer to load from.
-      indices: the list of subscripts, may be empty for zero-dimensional
-        buffers.
-      loc: user-visible location of the operation.
-      ip: insertion point.
-    """
-    indices_resolved = [] if indices is None else _get_op_results_or_values(
-        indices)
-    super().__init__(memref, indices_resolved, loc=loc, ip=ip)
+        Args:
+          memref: the buffer to load from.
+          indices: the list of subscripts, may be empty for zero-dimensional
+            buffers.
+          loc: user-visible location of the operation.
+          ip: insertion point.
+        """
+        indices_resolved = [] if indices is None else _get_op_results_or_values(indices)
+        super().__init__(memref, indices_resolved, loc=loc, ip=ip)
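
A minimal sketch of the load helper (assuming the upstream `mlir` package; the buffer shape and function name are arbitrary):

from mlir.ir import Context, Location, Module, InsertionPoint
from mlir.ir import MemRefType, F32Type, IndexType
from mlir.dialects import func, memref

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        f32, index = F32Type.get(), IndexType.get()
        buf_ty = MemRefType.get([4], f32)

        @func.FuncOp.from_py_func(buf_ty, index)
        def load_one(buf, i):
            # For a 0-d memref the indices list could be omitted entirely.
            return memref.LoadOp(buf, [i]).result
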
index 8db82cf..c84d23c 100644 (file)
@@ -3,11 +3,11 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from typing import Union
-  from ..ir import *
-  from ._ods_common import get_default_loc_context as _get_default_loc_context
+    from typing import Union
+    from ..ir import *
+    from ._ods_common import get_default_loc_context as _get_default_loc_context
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
 
 from ._ml_program_ops_gen import *
 
@@ -17,100 +17,97 @@ RESULT_ATTRIBUTE_NAME = "res_attrs"
 
 
 class FuncOp:
-  """Specialization for the func op class."""
-
-  def __init__(self,
-               name,
-               type,
-               *,
-               visibility=None,
-               body_builder=None,
-               loc=None,
-               ip=None):
-    """
-    Create a FuncOp with the provided `name`, `type`, and `visibility`.
-    - `name` is a string representing the function name.
-    - `type` is either a FunctionType or a pair of list describing inputs and
-      results.
-    - `visibility` is a string matching `public`, `private`, or `nested`. None
-      implies private visibility.
-    - `body_builder` is an optional callback, when provided a new entry block
-      is created and the callback is invoked with the new op as argument within
-      an InsertionPoint context already set for the block. The callback is
-      expected to insert a terminator in the block.
-    """
-    sym_name = StringAttr.get(str(name))
-
-    # If the type is passed as a tuple, build a FunctionType on the fly.
-    if isinstance(type, tuple):
-      type = FunctionType.get(inputs=type[0], results=type[1])
-
-    type = TypeAttr.get(type)
-    sym_visibility = StringAttr.get(
-        str(visibility)) if visibility is not None else None
-    super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
-    if body_builder:
-      entry_block = self.add_entry_block()
-      with InsertionPoint(entry_block):
-        body_builder(self)
-
-  @property
-  def is_external(self):
-    return len(self.regions[0].blocks) == 0
-
-  @property
-  def body(self):
-    return self.regions[0]
-
-  @property
-  def type(self):
-    return FunctionType(TypeAttr(self.attributes["function_type"]).value)
-
-  @property
-  def visibility(self):
-    return self.attributes["sym_visibility"]
-
-  @property
-  def name(self) -> StringAttr:
-    return StringAttr(self.attributes["sym_name"])
-
-  @property
-  def entry_block(self):
-    if self.is_external:
-      raise IndexError('External function does not have a body')
-    return self.regions[0].blocks[0]
-
-  def add_entry_block(self):
-    """
-    Add an entry block to the function body using the function signature to
-    infer block arguments.
-    Returns the newly created block
-    """
-    if not self.is_external:
-      raise IndexError('The function already has an entry block!')
-    self.body.blocks.append(*self.type.inputs)
-    return self.body.blocks[0]
-
-  @property
-  def arg_attrs(self):
-    return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
-
-  @arg_attrs.setter
-  def arg_attrs(self, attribute: Union[ArrayAttr, list]):
-    if isinstance(attribute, ArrayAttr):
-      self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
-    else:
-      self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
-          attribute, context=self.context)
-
-  @property
-  def arguments(self):
-    return self.entry_block.arguments
-
-  @property
-  def result_attrs(self):
-    return self.attributes[RESULT_ATTRIBUTE_NAME]
-
-  @result_attrs.setter
-  def result_attrs(self, attribute: ArrayAttr):
-    self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
+    """Specialization for the func op class."""
+
+    def __init__(
+        self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
+    ):
+        """
+        Create a FuncOp with the provided `name`, `type`, and `visibility`.
+        - `name` is a string representing the function name.
+        - `type` is either a FunctionType or a pair of lists describing inputs and
+          results.
+        - `visibility` is a string matching `public`, `private`, or `nested`. None
+          implies private visibility.
+        - `body_builder` is an optional callback, when provided a new entry block
+          is created and the callback is invoked with the new op as argument within
+          an InsertionPoint context already set for the block. The callback is
+          expected to insert a terminator in the block.
+        """
+        sym_name = StringAttr.get(str(name))
+
+        # If the type is passed as a tuple, build a FunctionType on the fly.
+        if isinstance(type, tuple):
+            type = FunctionType.get(inputs=type[0], results=type[1])
+
+        type = TypeAttr.get(type)
+        sym_visibility = (
+            StringAttr.get(str(visibility)) if visibility is not None else None
+        )
+        super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
+        if body_builder:
+            entry_block = self.add_entry_block()
+            with InsertionPoint(entry_block):
+                body_builder(self)
+
+    @property
+    def is_external(self):
+        return len(self.regions[0].blocks) == 0
+
+    @property
+    def body(self):
+        return self.regions[0]
+
+    @property
+    def type(self):
+        return FunctionType(TypeAttr(self.attributes["function_type"]).value)
+
+    @property
+    def visibility(self):
+        return self.attributes["sym_visibility"]
+
+    @property
+    def name(self) -> StringAttr:
+        return StringAttr(self.attributes["sym_name"])
+
+    @property
+    def entry_block(self):
+        if self.is_external:
+            raise IndexError("External function does not have a body")
+        return self.regions[0].blocks[0]
+
+    def add_entry_block(self):
+        """
+        Add an entry block to the function body using the function signature to
+        infer block arguments.
+        Returns the newly created block.
+        """
+        if not self.is_external:
+            raise IndexError("The function already has an entry block!")
+        self.body.blocks.append(*self.type.inputs)
+        return self.body.blocks[0]
+
+    @property
+    def arg_attrs(self):
+        return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
+
+    @arg_attrs.setter
+    def arg_attrs(self, attribute: Union[ArrayAttr, list]):
+        if isinstance(attribute, ArrayAttr):
+            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
+        else:
+            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
+                attribute, context=self.context
+            )
+
+    @property
+    def arguments(self):
+        return self.entry_block.arguments
+
+    @property
+    def result_attrs(self):
+        return self.attributes[RESULT_ATTRIBUTE_NAME]
+
+    @result_attrs.setter
+    def result_attrs(self, attribute: ArrayAttr):
+        self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
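
For instance, a small private function could be sketched as follows (the `ml_program` module path and its `ReturnOp` terminator are assumed here, as is an ambient Context/Location):

from mlir.ir import Context, Location, Module, InsertionPoint, F32Type
from mlir.dialects import ml_program

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        f32 = F32Type.get()
        # `type` may be a FunctionType or an (inputs, results) pair.
        fn = ml_program.FuncOp(
            "forward",
            ((f32,), (f32,)),
            visibility="private",
            body_builder=lambda op: ml_program.ReturnOp([op.arguments[0]]),
        )
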
index 51b9008..7655629 100644 (file)
@@ -18,144 +18,152 @@ __all__ = [
 
 
 def extend_opview_class(ext_module):
-  """Decorator to extend an OpView class from an extension module.
-
-  Extension modules can expose various entry-points:
-    Stand-alone class with the same name as a parent OpView class (i.e.
-    "ReturnOp"). A name-based match is attempted first before falling back
-    to a below mechanism.
-
-    def select_opview_mixin(parent_opview_cls):
-      If defined, allows an appropriate mixin class to be selected dynamically
-      based on the parent OpView class. Should return NotImplemented if a
-      decision is not made.
-
-  Args:
-    ext_module: A module from which to locate extensions. Can be None if not
-      available.
-
-  Returns:
-    A decorator that takes an OpView subclass and further extends it as
-    needed.
-  """
-
-  def class_decorator(parent_opview_cls: type):
-    if ext_module is None:
-      return parent_opview_cls
-    mixin_cls = NotImplemented
-    # First try to resolve by name.
-    try:
-      mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
-    except AttributeError:
-      # Fall back to a select_opview_mixin hook.
-      try:
-        select_mixin = getattr(ext_module, "select_opview_mixin")
-      except AttributeError:
-        pass
-      else:
-        mixin_cls = select_mixin(parent_opview_cls)
-
-    if mixin_cls is NotImplemented or mixin_cls is None:
-      return parent_opview_cls
-
-    # Have a mixin_cls. Create an appropriate subclass.
-    try:
-
-      class LocalOpView(mixin_cls, parent_opview_cls):
-        pass
-    except TypeError as e:
-      raise TypeError(
-          f"Could not mixin {mixin_cls} into {parent_opview_cls}") from e
-    LocalOpView.__name__ = parent_opview_cls.__name__
-    LocalOpView.__qualname__ = parent_opview_cls.__qualname__
-    return LocalOpView
-
-  return class_decorator
+    """Decorator to extend an OpView class from an extension module.
+
+    Extension modules can expose various entry-points:
+      Stand-alone class with the same name as a parent OpView class (e.g.
+      "ReturnOp"). A name-based match is attempted first before falling back
+      to the mechanism below.
+
+      def select_opview_mixin(parent_opview_cls):
+        If defined, allows an appropriate mixin class to be selected dynamically
+        based on the parent OpView class. Should return NotImplemented if a
+        decision is not made.
+
+    Args:
+      ext_module: A module from which to locate extensions. Can be None if not
+        available.
+
+    Returns:
+      A decorator that takes an OpView subclass and further extends it as
+      needed.
+    """
+
+    def class_decorator(parent_opview_cls: type):
+        if ext_module is None:
+            return parent_opview_cls
+        mixin_cls = NotImplemented
+        # First try to resolve by name.
+        try:
+            mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
+        except AttributeError:
+            # Fall back to a select_opview_mixin hook.
+            try:
+                select_mixin = getattr(ext_module, "select_opview_mixin")
+            except AttributeError:
+                pass
+            else:
+                mixin_cls = select_mixin(parent_opview_cls)
+
+        if mixin_cls is NotImplemented or mixin_cls is None:
+            return parent_opview_cls
+
+        # Have a mixin_cls. Create an appropriate subclass.
+        try:
+
+            class LocalOpView(mixin_cls, parent_opview_cls):
+                pass
+
+        except TypeError as e:
+            raise TypeError(
+                f"Could not mixin {mixin_cls} into {parent_opview_cls}"
+            ) from e
+        LocalOpView.__name__ = parent_opview_cls.__name__
+        LocalOpView.__qualname__ = parent_opview_cls.__qualname__
+        return LocalOpView
+
+    return class_decorator
 
 
 def segmented_accessor(elements, raw_segments, idx):
-  """
-  Returns a slice of elements corresponding to the idx-th segment.
-
-    elements: a sliceable container (operands or results).
-    raw_segments: an mlir.ir.Attribute, of DenseI32Array subclass containing
-        sizes of the segments.
-    idx: index of the segment.
-  """
-  segments = _cext.ir.DenseI32ArrayAttr(raw_segments)
-  start = sum(segments[i] for i in range(idx))
-  end = start + segments[idx]
-  return elements[start:end]
-
-
-def equally_sized_accessor(elements, n_variadic, n_preceding_simple,
-                           n_preceding_variadic):
-  """
-  Returns a starting position and a number of elements per variadic group
-  assuming equally-sized groups and the given numbers of preceding groups.
-
-    elements: a sequential container.
-    n_variadic: the number of variadic groups in the container.
-    n_preceding_simple: the number of non-variadic groups preceding the current
-        group.
-    n_preceding_variadic: the number of variadic groups preceding the current
-        group.
-  """
-
-  total_variadic_length = len(elements) - n_variadic + 1
-  # This should be enforced by the C++-side trait verifier.
-  assert total_variadic_length % n_variadic == 0
-
-  elements_per_group = total_variadic_length // n_variadic
-  start = n_preceding_simple + n_preceding_variadic * elements_per_group
-  return start, elements_per_group
+    """
+    Returns a slice of elements corresponding to the idx-th segment.
+
+      elements: a sliceable container (operands or results).
+      raw_segments: an mlir.ir.Attribute of the DenseI32ArrayAttr subclass
+          containing the sizes of the segments.
+      idx: index of the segment.
+    """
+    segments = _cext.ir.DenseI32ArrayAttr(raw_segments)
+    start = sum(segments[i] for i in range(idx))
+    end = start + segments[idx]
+    return elements[start:end]
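
A worked example of the slicing (a sketch that calls the helper as this module itself would; it assumes `DenseI32ArrayAttr.get` and an ambient Context, and uses a plain Python list for the elements):

from mlir.ir import Context, DenseI32ArrayAttr

with Context():
    sizes = DenseI32ArrayAttr.get([2, 1, 3])
    elements = ["a", "b", "c", "d", "e", "f"]
    # Segment 2 starts at 2 + 1 = 3 and ends at 3 + 3 = 6.
    assert segmented_accessor(elements, sizes, 2) == ["d", "e", "f"]
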
+
+
+def equally_sized_accessor(
+    elements, n_variadic, n_preceding_simple, n_preceding_variadic
+):
+    """
+    Returns a starting position and a number of elements per variadic group
+    assuming equally-sized groups and the given numbers of preceding groups.
+
+      elements: a sequential container.
+      n_variadic: the number of variadic groups in the container.
+      n_preceding_simple: the number of non-variadic groups preceding the current
+          group.
+      n_preceding_variadic: the number of variadic groups preceding the current
+          group.
+    """
+
+    total_variadic_length = len(elements) - n_variadic + 1
+    # This should be enforced by the C++-side trait verifier.
+    assert total_variadic_length % n_variadic == 0
+
+    elements_per_group = total_variadic_length // n_variadic
+    start = n_preceding_simple + n_preceding_variadic * elements_per_group
+    return start, elements_per_group
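
Concretely, with 7 elements, 2 variadic groups, 1 preceding simple group and 1 preceding variadic group, the variadic elements number 7 - 2 + 1 = 6, so each group holds 3 and the current group starts at 1 + 1 * 3 = 4. As plain Python:

start, per_group = equally_sized_accessor(list(range(7)), 2, 1, 1)
assert (start, per_group) == (4, 3)
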
 
 
 def get_default_loc_context(location=None):
-  """
-  Returns a context in which the defaulted location is created. If the location
-  is None, takes the current location from the stack, raises ValueError if there
-  is no location on the stack.
-  """
-  if location is None:
-    # Location.current raises ValueError if there is no current location.
-    return _cext.ir.Location.current.context
-  return location.context
+    """
+    Returns a context in which the defaulted location is created. If the location
+    is None, takes the current location from the stack and raises ValueError if
+    there is no location on the stack.
+    """
+    if location is None:
+        # Location.current raises ValueError if there is no current location.
+        return _cext.ir.Location.current.context
+    return location.context
 
 
 def get_op_result_or_value(
-    arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList]
+    arg: _Union[
+        _cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList
+    ]
 ) -> _cext.ir.Value:
-  """Returns the given value or the single result of the given op.
-
-  This is useful to implement op constructors so that they can take other ops as
-  arguments instead of requiring the caller to extract results for every op.
-  Raises ValueError if provided with an op that doesn't have a single result.
-  """
-  if isinstance(arg, _cext.ir.OpView):
-    return arg.operation.result
-  elif isinstance(arg, _cext.ir.Operation):
-    return arg.result
-  elif isinstance(arg, _cext.ir.OpResultList):
-    return arg[0]
-  else:
-    assert isinstance(arg, _cext.ir.Value)
-    return arg
+    """Returns the given value or the single result of the given op.
+
+    This is useful to implement op constructors so that they can take other ops as
+    arguments instead of requiring the caller to extract results for every op.
+    Raises ValueError if provided with an op that doesn't have a single result.
+    """
+    if isinstance(arg, _cext.ir.OpView):
+        return arg.operation.result
+    elif isinstance(arg, _cext.ir.Operation):
+        return arg.result
+    elif isinstance(arg, _cext.ir.OpResultList):
+        return arg[0]
+    else:
+        assert isinstance(arg, _cext.ir.Value)
+        return arg
 
 
 def get_op_results_or_values(
-    arg: _Union[_cext.ir.OpView, _cext.ir.Operation,
-                _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]]]
+    arg: _Union[
+        _cext.ir.OpView,
+        _cext.ir.Operation,
+        _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]],
+    ]
 ) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]:
-  """Returns the given sequence of values or the results of the given op.
-
-  This is useful to implement op constructors so that they can take other ops as
-  lists of arguments instead of requiring the caller to extract results for
-  every op.
-  """
-  if isinstance(arg, _cext.ir.OpView):
-    return arg.operation.results
-  elif isinstance(arg, _cext.ir.Operation):
-    return arg.results
-  else:
-    return [get_op_result_or_value(element) for element in arg]
+    """Returns the given sequence of values or the results of the given op.
+
+    This is useful to implement op constructors so that they can take other ops as
+    lists of arguments instead of requiring the caller to extract results for
+    every op.
+    """
+    if isinstance(arg, _cext.ir.OpView):
+        return arg.operation.results
+    elif isinstance(arg, _cext.ir.Operation):
+        return arg.results
+    else:
+        return [get_op_result_or_value(element) for element in arg]
index 40ccbef..fc9de0b 100644 (file)
@@ -3,10 +3,10 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from ..ir import *
-  from ..dialects import pdl
+    from ..ir import *
+    from ..dialects import pdl
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
 
 from typing import Union, Optional, Sequence, Mapping
 from ._ods_common import (
@@ -16,264 +16,256 @@ from ._ods_common import (
 
 
 class ApplyNativeConstraintOp:
-  """Specialization for PDL apply native constraint op class."""
-
-  def __init__(
-      self,
-      name: Union[str, StringAttr],
-      args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
-      *,
-      loc=None,
-      ip=None,
-  ):
-    if args is None:
-      args = []
-    args = _get_values(args)
-    super().__init__(name, args, loc=loc, ip=ip)
+    """Specialization for PDL apply native constraint op class."""
+
+    def __init__(
+        self,
+        name: Union[str, StringAttr],
+        args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        if args is None:
+            args = []
+        args = _get_values(args)
+        super().__init__(name, args, loc=loc, ip=ip)
 
 
 class ApplyNativeRewriteOp:
-  """Specialization for PDL apply native rewrite op class."""
-
-  def __init__(
-      self,
-      results: Sequence[Type],
-      name: Union[str, StringAttr],
-      args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
-      *,
-      loc=None,
-      ip=None,
-  ):
-    if args is None:
-      args = []
-    args = _get_values(args)
-    super().__init__(results, name, args, loc=loc, ip=ip)
+    """Specialization for PDL apply native rewrite op class."""
+
+    def __init__(
+        self,
+        results: Sequence[Type],
+        name: Union[str, StringAttr],
+        args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        if args is None:
+            args = []
+        args = _get_values(args)
+        super().__init__(results, name, args, loc=loc, ip=ip)
 
 
 class AttributeOp:
-  """Specialization for PDL attribute op class."""
+    """Specialization for PDL attribute op class."""
 
-  def __init__(
-      self,
-      valueType: Optional[Union[OpView, Operation, Value]] = None,
-      value: Optional[Attribute] = None,
-      *,
-      loc=None,
-      ip=None,
-  ):
-    valueType = valueType if valueType is None else _get_value(valueType)
-    result = pdl.AttributeType.get()
-    super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
+    def __init__(
+        self,
+        valueType: Optional[Union[OpView, Operation, Value]] = None,
+        value: Optional[Attribute] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        valueType = valueType if valueType is None else _get_value(valueType)
+        result = pdl.AttributeType.get()
+        super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
 
 
 class EraseOp:
-  """Specialization for PDL erase op class."""
+    """Specialization for PDL erase op class."""
 
-  def __init__(
-      self,
-      operation: Optional[Union[OpView, Operation, Value]] = None,
-      *,
-      loc=None,
-      ip=None,
-  ):
-    operation = _get_value(operation)
-    super().__init__(operation, loc=loc, ip=ip)
+    def __init__(
+        self,
+        operation: Optional[Union[OpView, Operation, Value]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        operation = _get_value(operation)
+        super().__init__(operation, loc=loc, ip=ip)
 
 
 class OperandOp:
-  """Specialization for PDL operand op class."""
+    """Specialization for PDL operand op class."""
 
-  def __init__(
-      self,
-      type: Optional[Union[OpView, Operation, Value]] = None,
-      *,
-      loc=None,
-      ip=None,
-  ):
-    type = type if type is None else _get_value(type)
-    result = pdl.ValueType.get()
-    super().__init__(result, valueType=type, loc=loc, ip=ip)
+    def __init__(
+        self,
+        type: Optional[Union[OpView, Operation, Value]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        type = type if type is None else _get_value(type)
+        result = pdl.ValueType.get()
+        super().__init__(result, valueType=type, loc=loc, ip=ip)
 
 
 class OperandsOp:
-  """Specialization for PDL operands op class."""
+    """Specialization for PDL operands op class."""
 
-  def __init__(
-      self,
-      types: Optional[Union[OpView, Operation, Value]] = None,
-      *,
-      loc=None,
-      ip=None,
-  ):
-    types = types if types is None else _get_value(types)
-    result = pdl.RangeType.get(pdl.ValueType.get())
-    super().__init__(result, valueType=types, loc=loc, ip=ip)
+    def __init__(
+        self,
+        types: Optional[Union[OpView, Operation, Value]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        types = types if types is None else _get_value(types)
+        result = pdl.RangeType.get(pdl.ValueType.get())
+        super().__init__(result, valueType=types, loc=loc, ip=ip)
 
 
 class OperationOp:
-  """Specialization for PDL operand op class."""
-
-  def __init__(
-      self,
-      name: Optional[Union[str, StringAttr]] = None,
-      args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
-      attributes: Optional[Mapping[str, Union[OpView, Operation,
-                                              Value]]] = None,
-      types: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
-      *,
-      loc=None,
-      ip=None,
-  ):
-    if types is None:
-      types = []
-    if attributes is None:
-      attributes = {}
-    if args is None:
-      args = []
-    args = _get_values(args)
-    attrNames = []
-    attrValues = []
-    for attrName, attrValue in attributes.items():
-      attrNames.append(StringAttr.get(attrName))
-      attrValues.append(_get_value(attrValue))
-    attrNames = ArrayAttr.get(attrNames)
-    types = _get_values(types)
-    result = pdl.OperationType.get()
-    super().__init__(result,
-                     args,
-                     attrValues,
-                     attrNames,
-                     types,
-                     opName=name,
-                     loc=loc,
-                     ip=ip)
+    """Specialization for PDL operand op class."""
+
+    def __init__(
+        self,
+        name: Optional[Union[str, StringAttr]] = None,
+        args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+        attributes: Optional[Mapping[str, Union[OpView, Operation, Value]]] = None,
+        types: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        if types is None:
+            types = []
+        if attributes is None:
+            attributes = {}
+        if args is None:
+            args = []
+        args = _get_values(args)
+        attrNames = []
+        attrValues = []
+        for attrName, attrValue in attributes.items():
+            attrNames.append(StringAttr.get(attrName))
+            attrValues.append(_get_value(attrValue))
+        attrNames = ArrayAttr.get(attrNames)
+        types = _get_values(types)
+        result = pdl.OperationType.get()
+        super().__init__(
+            result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip
+        )
 
 
 class PatternOp:
-  """Specialization for PDL pattern op class."""
-
-  def __init__(
-      self,
-      benefit: Union[IntegerAttr, int],
-      name: Optional[Union[StringAttr, str]] = None,
-      *,
-      loc=None,
-      ip=None,
-  ):
-    """Creates an PDL `pattern` operation."""
-    super().__init__(benefit, sym_name=name, loc=loc, ip=ip)
-    self.regions[0].blocks.append()
-
-  @property
-  def body(self):
-    """Return the body (block) of the pattern."""
-    return self.regions[0].blocks[0]
+    """Specialization for PDL pattern op class."""
+
+    def __init__(
+        self,
+        benefit: Union[IntegerAttr, int],
+        name: Optional[Union[StringAttr, str]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        """Creates an PDL `pattern` operation."""
+        super().__init__(benefit, sym_name=name, loc=loc, ip=ip)
+        self.regions[0].blocks.append()
+
+    @property
+    def body(self):
+        """Return the body (block) of the pattern."""
+        return self.regions[0].blocks[0]
 
 
 class ReplaceOp:
-  """Specialization for PDL replace op class."""
-
-  def __init__(
-      self,
-      op: Union[OpView, Operation, Value],
-      *,
-      with_op: Optional[Union[OpView, Operation, Value]] = None,
-      with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
-      loc=None,
-      ip=None,
-  ):
-    if with_values is None:
-      with_values = []
-    op = _get_value(op)
-    with_op = with_op if with_op is None else _get_value(with_op)
-    with_values = _get_values(with_values)
-    super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip)
+    """Specialization for PDL replace op class."""
+
+    def __init__(
+        self,
+        op: Union[OpView, Operation, Value],
+        *,
+        with_op: Optional[Union[OpView, Operation, Value]] = None,
+        with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+        loc=None,
+        ip=None,
+    ):
+        if with_values is None:
+            with_values = []
+        op = _get_value(op)
+        with_op = with_op if with_op is None else _get_value(with_op)
+        with_values = _get_values(with_values)
+        super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip)
 
 
 class ResultOp:
-  """Specialization for PDL result op class."""
+    """Specialization for PDL result op class."""
 
-  def __init__(
-      self,
-      parent: Union[OpView, Operation, Value],
-      index: Union[IntegerAttr, int],
-      *,
-      loc=None,
-      ip=None,
-  ):
-    parent = _get_value(parent)
-    result = pdl.ValueType.get()
-    super().__init__(result, parent, index, loc=loc, ip=ip)
+    def __init__(
+        self,
+        parent: Union[OpView, Operation, Value],
+        index: Union[IntegerAttr, int],
+        *,
+        loc=None,
+        ip=None,
+    ):
+        parent = _get_value(parent)
+        result = pdl.ValueType.get()
+        super().__init__(result, parent, index, loc=loc, ip=ip)
 
 
 class ResultsOp:
-  """Specialization for PDL results op class."""
+    """Specialization for PDL results op class."""
 
-  def __init__(
-      self,
-      result: Type,
-      parent: Union[OpView, Operation, Value],
-      index: Optional[Union[IntegerAttr, int]] = None,
-      *,
-      loc=None,
-      ip=None,
-  ):
-    parent = _get_value(parent)
-    super().__init__(result, parent, index=index, loc=loc, ip=ip)
+    def __init__(
+        self,
+        result: Type,
+        parent: Union[OpView, Operation, Value],
+        index: Optional[Union[IntegerAttr, int]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        parent = _get_value(parent)
+        super().__init__(result, parent, index=index, loc=loc, ip=ip)
 
 
 class RewriteOp:
-  """Specialization for PDL rewrite op class."""
-
-  def __init__(
-      self,
-      root: Optional[Union[OpView, Operation, Value]] = None,
-      name: Optional[Union[StringAttr, str]] = None,
-      args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
-      *,
-      loc=None,
-      ip=None,
-  ):
-    if args is None:
-      args = []
-    root = root if root is None else _get_value(root)
-    args = _get_values(args)
-    super().__init__(args, root=root, name=name, loc=loc, ip=ip)
-
-  def add_body(self):
-    """Add body (block) to the rewrite."""
-    self.regions[0].blocks.append()
-    return self.body
-
-  @property
-  def body(self):
-    """Return the body (block) of the rewrite."""
-    return self.regions[0].blocks[0]
+    """Specialization for PDL rewrite op class."""
+
+    def __init__(
+        self,
+        root: Optional[Union[OpView, Operation, Value]] = None,
+        name: Optional[Union[StringAttr, str]] = None,
+        args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        if args is None:
+            args = []
+        root = root if root is None else _get_value(root)
+        args = _get_values(args)
+        super().__init__(args, root=root, name=name, loc=loc, ip=ip)
+
+    def add_body(self):
+        """Add body (block) to the rewrite."""
+        self.regions[0].blocks.append()
+        return self.body
+
+    @property
+    def body(self):
+        """Return the body (block) of the rewrite."""
+        return self.regions[0].blocks[0]
 
 
 class TypeOp:
-  """Specialization for PDL type op class."""
+    """Specialization for PDL type op class."""
 
-  def __init__(self,
-               constantType: Optional[Union[TypeAttr, Type]] = None,
-               *,
-               loc=None,
-               ip=None):
-    result = pdl.TypeType.get()
-    super().__init__(result, constantType=constantType, loc=loc, ip=ip)
+    def __init__(
+        self, constantType: Optional[Union[TypeAttr, Type]] = None, *, loc=None, ip=None
+    ):
+        result = pdl.TypeType.get()
+        super().__init__(result, constantType=constantType, loc=loc, ip=ip)
 
 
 class TypesOp:
-  """Specialization for PDL types op class."""
-
-  def __init__(
-      self,
-      constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None,
-      *,
-      loc=None,
-      ip=None,
-  ):
-    if constantTypes is None:
-      constantTypes = []
-    result = pdl.RangeType.get(pdl.TypeType.get())
-    super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
+    """Specialization for PDL types op class."""
+
+    def __init__(
+        self,
+        constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        if constantTypes is None:
+            constantTypes = []
+        result = pdl.RangeType.get(pdl.TypeType.get())
+        super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
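
Putting a few of these together, a pattern that matches any "my.op" and erases it might look roughly like this (an ambient Context/Location is assumed; the op and pattern names are placeholders):

from mlir.ir import Context, Location, Module, InsertionPoint
from mlir.dialects import pdl

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        pattern = pdl.PatternOp(1, "erase_my_op")
        with InsertionPoint(pattern.body):
            root = pdl.OperationOp(name="my.op")
            rewrite = pdl.RewriteOp(root)
            with InsertionPoint(rewrite.add_body()):
                pdl.EraseOp(root)
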
index 3c3e673..4b2519e 100644 (file)
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from ..ir import *
+    from ..ir import *
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
 
 from typing import Any, Optional, Sequence, Union
-from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+from ._ods_common import (
+    get_op_result_or_value as _get_op_result_or_value,
+    get_op_results_or_values as _get_op_results_or_values,
+)
+
 
 class ForOp:
-  """Specialization for the SCF for op class."""
-
-  def __init__(self,
-               lower_bound,
-               upper_bound,
-               step,
-               iter_args: Optional[Union[Operation, OpView,
-                                         Sequence[Value]]] = None,
-               *,
-               loc=None,
-               ip=None):
-    """Creates an SCF `for` operation.
-
-    - `lower_bound` is the value to use as lower bound of the loop.
-    - `upper_bound` is the value to use as upper bound of the loop.
-    - `step` is the value to use as loop step.
-    - `iter_args` is a list of additional loop-carried arguments or an operation
-      producing them as results.
-    """
-    if iter_args is None:
-      iter_args = []
-    iter_args = _get_op_results_or_values(iter_args)
-
-    results = [arg.type for arg in iter_args]
-    super().__init__(
-        self.build_generic(
-            regions=1,
-            results=results,
-            operands=[
-                _get_op_result_or_value(o)
-                for o in [lower_bound, upper_bound, step]
-            ] + list(iter_args),
-            loc=loc,
-            ip=ip))
-    self.regions[0].blocks.append(IndexType.get(), *results)
-
-  @property
-  def body(self):
-    """Returns the body (block) of the loop."""
-    return self.regions[0].blocks[0]
-
-  @property
-  def induction_variable(self):
-    """Returns the induction variable of the loop."""
-    return self.body.arguments[0]
-
-  @property
-  def inner_iter_args(self):
-    """Returns the loop-carried arguments usable within the loop.
-
-    To obtain the loop-carried operands, use `iter_args`.
-    """
-    return self.body.arguments[1:]
+    """Specialization for the SCF for op class."""
+
+    def __init__(
+        self,
+        lower_bound,
+        upper_bound,
+        step,
+        iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+        *,
+        loc=None,
+        ip=None
+    ):
+        """Creates an SCF `for` operation.
+
+        - `lower_bound` is the value to use as lower bound of the loop.
+        - `upper_bound` is the value to use as upper bound of the loop.
+        - `step` is the value to use as loop step.
+        - `iter_args` is a list of additional loop-carried arguments or an operation
+          producing them as results.
+        """
+        if iter_args is None:
+            iter_args = []
+        iter_args = _get_op_results_or_values(iter_args)
+
+        results = [arg.type for arg in iter_args]
+        super().__init__(
+            self.build_generic(
+                regions=1,
+                results=results,
+                operands=[
+                    _get_op_result_or_value(o) for o in [lower_bound, upper_bound, step]
+                ]
+                + list(iter_args),
+                loc=loc,
+                ip=ip,
+            )
+        )
+        self.regions[0].blocks.append(IndexType.get(), *results)
+
+    @property
+    def body(self):
+        """Returns the body (block) of the loop."""
+        return self.regions[0].blocks[0]
+
+    @property
+    def induction_variable(self):
+        """Returns the induction variable of the loop."""
+        return self.body.arguments[0]
+
+    @property
+    def inner_iter_args(self):
+        """Returns the loop-carried arguments usable within the loop.
+
+        To obtain the loop-carried operands, use `iter_args`.
+        """
+        return self.body.arguments[1:]
 
 
 class IfOp:
-  """Specialization for the SCF if op class."""
-
-  def __init__(self,
-               cond,
-               results_=[],
-               *,
-               hasElse=False,
-               loc=None,
-               ip=None):
-    """Creates an SCF `if` operation.
-
-    - `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed.
-    - `hasElse` determines whether the if operation has the else branch.
-    """
-    operands = []
-    operands.append(cond)
-    results = []
-    results.extend(results_)
-    super().__init__(
-        self.build_generic(
-            regions=2,
-            results=results,
-            operands=operands,
-            loc=loc,
-            ip=ip))
-    self.regions[0].blocks.append(*[])
-    if hasElse:
-        self.regions[1].blocks.append(*[])
-
-  @property
-  def then_block(self):
-    """Returns the then block of the if operation."""
-    return self.regions[0].blocks[0]
-
-  @property
-  def else_block(self):
-    """Returns the else block of the if operation."""
-    return self.regions[1].blocks[0]
+    """Specialization for the SCF if op class."""
+
+    def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None):
+        """Creates an SCF `if` operation.
+
+        - `cond` is an MLIR value of 'i1' type that determines which region of code will be executed.
+        - `hasElse` determines whether the if operation has an else branch.
+        """
+        operands = []
+        operands.append(cond)
+        results = []
+        results.extend(results_)
+        super().__init__(
+            self.build_generic(
+                regions=2, results=results, operands=operands, loc=loc, ip=ip
+            )
+        )
+        self.regions[0].blocks.append(*[])
+        if hasElse:
+            self.regions[1].blocks.append(*[])
+
+    @property
+    def then_block(self):
+        """Returns the then block of the if operation."""
+        return self.regions[0].blocks[0]
+
+    @property
+    def else_block(self):
+        """Returns the else block of the if operation."""
+        return self.regions[1].blocks[0]
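
A compact sketch of ForOp (assuming the upstream `mlir` package; the bounds and the carried value come from function arguments to keep the example short):

from mlir.ir import Context, Location, Module, InsertionPoint, IndexType, F32Type
from mlir.dialects import func, scf

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        index, f32 = IndexType.get(), F32Type.get()

        @func.FuncOp.from_py_func(index, index, index, f32)
        def carried(lb, ub, step, init):
            loop = scf.ForOp(lb, ub, step, [init])
            with InsertionPoint(loop.body):
                # loop.induction_variable and loop.inner_iter_args are the
                # block arguments documented above; yield the value unchanged.
                scf.YieldOp([loop.inner_iter_args[0]])
            return loop.results[0]
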
index 9c051cd..30dafff 100644 (file)
@@ -3,11 +3,11 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from ..ir import *
-  from ._ods_common import get_op_result_or_value as _get_op_result_or_value
-  from ..dialects import pdl, transform
+    from ..ir import *
+    from ._ods_common import get_op_result_or_value as _get_op_result_or_value
+    from ..dialects import pdl, transform
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
 
 from typing import List, Optional, Sequence, Union, overload
 
@@ -16,312 +16,315 @@ OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
 
 
 def _get_int_int_array_attr(
-    values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr,
-                                                     IntOrAttrList]]]]
+    values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
 ) -> ArrayAttr:
-  """Creates an array attribute containing array attributes of integers.
+    """Creates an array attribute containing array attributes of integers.
 
     If the operand is already an array attribute, forwards it. Otherwise treats
     the operand as a list of attributes or integers, potentially interspersed, to
     create a new array-of-array attribute. Expects the thread-local MLIR context
     to have been set by the context manager.
     """
-  if values is None:
-    return ArrayAttr.get([])
-  if isinstance(values, ArrayAttr):
-    return values
-  if isinstance(values, list):
-    values = [
-        ArrayAttr.get(
-            [IntegerAttr.get(IntegerType.get_signless(64), v)
-             for v in value])
-        for value in values
-    ]
+    if values is None:
+        return ArrayAttr.get([])
+    if isinstance(values, ArrayAttr):
+        return values
+    if isinstance(values, list):
+        values = [
+            ArrayAttr.get(
+                [IntegerAttr.get(IntegerType.get_signless(64), v) for v in value]
+            )
+            for value in values
+        ]
 
-  return ArrayAttr.get(values)
+    return ArrayAttr.get(values)
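
For example, called from within this module under an ambient Context, the helper turns nested Python lists into an ArrayAttr of ArrayAttrs of signless i64 IntegerAttrs, passes an existing ArrayAttr through unchanged, and maps None to an empty ArrayAttr (a sketch):

from mlir.ir import Context

with Context():
    nested = _get_int_int_array_attr([[1, 2], [0]])  # ArrayAttr: [[1, 2], [0]]
    empty = _get_int_int_array_attr(None)            # ArrayAttr: []
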
 
 
 class DecomposeOp:
-  """Specialization for DecomposeOp class."""
+    """Specialization for DecomposeOp class."""
 
-  def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
-    super().__init__(pdl.OperationType.get(),
-                     _get_op_result_or_value(target),
-                     loc=loc,
-                     ip=ip)
+    def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+        super().__init__(
+            pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip
+        )
 
 
 class GeneralizeOp:
-  """Specialization for GeneralizeOp class."""
+    """Specialization for GeneralizeOp class."""
 
-  def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
-    super().__init__(pdl.OperationType.get(),
-                     _get_op_result_or_value(target),
-                     loc=loc,
-                     ip=ip)
+    def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+        super().__init__(
+            pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip
+        )
 
 
 class InterchangeOp:
-  """Specialization for InterchangeOp class."""
-
-  def __init__(
-      self,
-      target: Union[Operation, Value],
-      *,
-      iterator_interchange: OptionalIntList = None,
-      loc=None,
-      ip=None,
-  ):
-    pdl_operation_type = pdl.OperationType.get()
-    super().__init__(
-        pdl_operation_type,
-        _get_op_result_or_value(target),
-        iterator_interchange=iterator_interchange,
-        loc=loc,
-        ip=ip,
-    )
+    """Specialization for InterchangeOp class."""
+
+    def __init__(
+        self,
+        target: Union[Operation, Value],
+        *,
+        iterator_interchange: OptionalIntList = None,
+        loc=None,
+        ip=None,
+    ):
+        pdl_operation_type = pdl.OperationType.get()
+        super().__init__(
+            pdl_operation_type,
+            _get_op_result_or_value(target),
+            iterator_interchange=iterator_interchange,
+            loc=loc,
+            ip=ip,
+        )
 
 
 class MatchOp:
-  """Specialization for MatchOp class."""
-
-  @classmethod
-  def match_op_names(
-      MatchOp,
-      target: Union[Operation, Value],
-      names: Sequence[str],
-      loc=None,
-      ip=None,
-  ):
-    pdl_operation_type = pdl.OperationType.get()
-    return MatchOp(
-        pdl_operation_type,
-        _get_op_result_or_value(target),
-        ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
-        loc=loc,
-        ip=ip,
-    )
+    """Specialization for MatchOp class."""
+
+    @classmethod
+    def match_op_names(
+        MatchOp,
+        target: Union[Operation, Value],
+        names: Sequence[str],
+        loc=None,
+        ip=None,
+    ):
+        pdl_operation_type = pdl.OperationType.get()
+        return MatchOp(
+            pdl_operation_type,
+            _get_op_result_or_value(target),
+            ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
+            loc=loc,
+            ip=ip,
+        )
 
 
 class MultiTileSizesOp:
-  """Specialization for MultitileSizesOp class."""
-
-  def __init__(
-      self,
-      result_type: Type,
-      target: Union[Operation, Value],
-      *,
-      dimension: Union[int, IntegerAttr],
-      target_size: Union[int, IntegerAttr],
-      divisor: Optional[Optional[Union[int, IntegerAttr]]] = None,
-      loc=None,
-      ip=None,
-  ):
-    if divisor is None:
-      divisor = 1
-    super().__init__(
-        result_type,
-        result_type,
-        result_type,
-        _get_op_result_or_value(target),
-        dimension=dimension,
-        target_size=target_size,
-        divisor=divisor,
-        loc=loc,
-        ip=ip,
-    )
+    """Specialization for MultitileSizesOp class."""
+
+    def __init__(
+        self,
+        result_type: Type,
+        target: Union[Operation, Value],
+        *,
+        dimension: Union[int, IntegerAttr],
+        target_size: Union[int, IntegerAttr],
+        divisor: Optional[Optional[Union[int, IntegerAttr]]] = None,
+        loc=None,
+        ip=None,
+    ):
+        if divisor is None:
+            divisor = 1
+        super().__init__(
+            result_type,
+            result_type,
+            result_type,
+            _get_op_result_or_value(target),
+            dimension=dimension,
+            target_size=target_size,
+            divisor=divisor,
+            loc=loc,
+            ip=ip,
+        )
 
 
 class PadOp:
-  """Specialization for PadOp class."""
-
-  def __init__(
-      self,
-      target: Union[Operation, Value],
-      *,
-      padding_values: Optional[Optional[Union[ArrayAttr,
-                                              Sequence[Attribute]]]] = None,
-      padding_dimensions: OptionalIntList = None,
-      pack_paddings: OptionalIntList = None,
-      transpose_paddings: Optional[Union[ArrayAttr, Sequence[Union[
-          ArrayAttr, IntOrAttrList]]]] = None,
-      loc=None,
-      ip=None,
-  ):
-    if transpose_paddings is None:
-      transpose_paddings = []
-    if pack_paddings is None:
-      pack_paddings = []
-    if padding_dimensions is None:
-      padding_dimensions = []
-    if padding_values is None:
-      padding_values = []
-    pdl_operation_type = pdl.OperationType.get()
-    transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings)
-    super().__init__(
-        pdl_operation_type,
-        _get_op_result_or_value(target),
-        padding_values=padding_values,
-        padding_dimensions=padding_dimensions,
-        pack_paddings=pack_paddings,
-        transpose_paddings=transpose_paddings_attr,
-        loc=loc,
-        ip=ip,
-    )
+    """Specialization for PadOp class."""
+
+    def __init__(
+        self,
+        target: Union[Operation, Value],
+        *,
+        padding_values: Optional[
+            Optional[Union[ArrayAttr, Sequence[Attribute]]]
+        ] = None,
+        padding_dimensions: OptionalIntList = None,
+        pack_paddings: OptionalIntList = None,
+        transpose_paddings: Optional[
+            Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
+        ] = None,
+        loc=None,
+        ip=None,
+    ):
+        if transpose_paddings is None:
+            transpose_paddings = []
+        if pack_paddings is None:
+            pack_paddings = []
+        if padding_dimensions is None:
+            padding_dimensions = []
+        if padding_values is None:
+            padding_values = []
+        pdl_operation_type = pdl.OperationType.get()
+        transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings)
+        super().__init__(
+            pdl_operation_type,
+            _get_op_result_or_value(target),
+            padding_values=padding_values,
+            padding_dimensions=padding_dimensions,
+            pack_paddings=pack_paddings,
+            transpose_paddings=transpose_paddings_attr,
+            loc=loc,
+            ip=ip,
+        )
 
 
 class ScalarizeOp:
-  """Specialization for ScalarizeOp class."""
+    """Specialization for ScalarizeOp class."""
 
-  def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
-    pdl_operation_type = pdl.OperationType.get()
-    super().__init__(pdl_operation_type,
-                     _get_op_result_or_value(target),
-                     loc=loc,
-                     ip=ip)
+    def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+        pdl_operation_type = pdl.OperationType.get()
+        super().__init__(
+            pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip
+        )
 
 
 class SplitOp:
-  """Specialization for SplitOp class."""
-
-  def __init__(
-      self,
-      target: Union[Operation, Value],
-      dimension: Union[int, Attribute],
-      split_point: Union[int, Operation, Value, Attribute],
-      *,
-      loc=None,
-      ip=None,
-  ):
-    if isinstance(split_point, int):
-      static_split_point = split_point
-      dynamic_split_point = None
-    else:
-      static_split_point = ShapedType.get_dynamic_size()
-      dynamic_split_point = _get_op_result_or_value(split_point)
-
-    target = _get_op_result_or_value(target)
-
-    super().__init__(
-        target.type,
-        target.type,
-        target,
-        dimension=dimension,
-        static_split_point=static_split_point,
-        dynamic_split_point=dynamic_split_point,
-        loc=loc,
-        ip=ip,
-    )
+    """Specialization for SplitOp class."""
+
+    def __init__(
+        self,
+        target: Union[Operation, Value],
+        dimension: Union[int, Attribute],
+        split_point: Union[int, Operation, Value, Attribute],
+        *,
+        loc=None,
+        ip=None,
+    ):
+        if isinstance(split_point, int):
+            static_split_point = split_point
+            dynamic_split_point = None
+        else:
+            static_split_point = ShapedType.get_dynamic_size()
+            dynamic_split_point = _get_op_result_or_value(split_point)
+
+        target = _get_op_result_or_value(target)
+
+        super().__init__(
+            target.type,
+            target.type,
+            target,
+            dimension=dimension,
+            static_split_point=static_split_point,
+            dynamic_split_point=dynamic_split_point,
+            loc=loc,
+            ip=ip,
+        )
 
 
 class TileOp:
-  """Specialization for TileOp class."""
-
-  @overload
-  def __init__(
-      self,
-      loop_types: Union[Type, List[Type]],
-      target: Union[Operation, Value],
-      *,
-      sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]],
-                            ArrayAttr]] = None,
-      interchange: OptionalIntList = None,
-      loc=None,
-      ip=None,
-  ):
-    ...
-
-  @overload
-  def __init__(
-      self,
-      target: Union[Operation, Value, OpView],
-      *,
-      sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]],
-                            ArrayAttr]] = None,
-      interchange: OptionalIntList = None,
-      loc=None,
-      ip=None,
-  ):
-    ...
-
-  def __init__(
-      self,
-      loop_types_or_target: Union[Type, List[Type], Operation, Value],
-      target_or_none: Optional[Union[Operation, Value, OpView]] = None,
-      *,
-      sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]],
-                            ArrayAttr]] = None,
-      interchange: OptionalIntList = None,
-      loc=None,
-      ip=None,
-  ):
-    if interchange is None:
-      interchange = []
-    if sizes is None:
-      sizes = []
-
-    static_sizes = []
-    dynamic_sizes = []
-    if isinstance(sizes, ArrayAttr):
-      sizes_attr = sizes
-    else:
-      for size in sizes:
-        if isinstance(size, int):
-          static_sizes.append(size)
+    """Specialization for TileOp class."""
+
+    @overload
+    def __init__(
+        self,
+        loop_types: Union[Type, List[Type]],
+        target: Union[Operation, Value],
+        *,
+        sizes: Optional[
+            Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
+        ] = None,
+        interchange: OptionalIntList = None,
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self,
+        target: Union[Operation, Value, OpView],
+        *,
+        sizes: Optional[
+            Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
+        ] = None,
+        interchange: OptionalIntList = None,
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    def __init__(
+        self,
+        loop_types_or_target: Union[Type, List[Type], Operation, Value],
+        target_or_none: Optional[Union[Operation, Value, OpView]] = None,
+        *,
+        sizes: Optional[
+            Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
+        ] = None,
+        interchange: OptionalIntList = None,
+        loc=None,
+        ip=None,
+    ):
+        if interchange is None:
+            interchange = []
+        if sizes is None:
+            sizes = []
+
+        static_sizes = []
+        dynamic_sizes = []
+        if isinstance(sizes, ArrayAttr):
+            sizes_attr = sizes
+        else:
+            for size in sizes:
+                if isinstance(size, int):
+                    static_sizes.append(size)
+                else:
+                    static_sizes.append(ShapedType.get_dynamic_size())
+                    dynamic_sizes.append(_get_op_result_or_value(size))
+            sizes_attr = DenseI64ArrayAttr.get(static_sizes)
+
+        num_loops = sum(v if v == 0 else 1 for v in self.__extract_values(sizes_attr))
+
+        if isinstance(loop_types_or_target, (Operation, Value, OpView)):
+            loop_types = [transform.AnyOpType.get()] * num_loops
+            target = loop_types_or_target
+            assert target_or_none is None, "Cannot construct TileOp with two targets."
         else:
-          static_sizes.append(ShapedType.get_dynamic_size())
-          dynamic_sizes.append(_get_op_result_or_value(size))
-      sizes_attr = DenseI64ArrayAttr.get(static_sizes)
-
-    num_loops = sum(
-        v if v == 0 else 1 for v in self.__extract_values(sizes_attr))
-
-    if isinstance(loop_types_or_target, (Operation, Value, OpView)):
-      loop_types = [transform.AnyOpType.get()] * num_loops
-      target = loop_types_or_target
-      assert target_or_none is None, "Cannot construct TileOp with two targets."
-    else:
-      loop_types = (([loop_types_or_target] * num_loops) if isinstance(
-          loop_types_or_target, Type) else loop_types_or_target)
-      target = target_or_none
-
-    target = _get_op_result_or_value(target)
-
-    super().__init__(
-        target.type,
-        loop_types,
-        target,
-        dynamic_sizes=dynamic_sizes,
-        static_sizes=sizes_attr,
-        interchange=interchange,
-        loc=loc,
-        ip=ip,
-    )
-
-  def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]:
-    if not attr:
-      return []
-    return [element for element in attr]
+            loop_types = (
+                ([loop_types_or_target] * num_loops)
+                if isinstance(loop_types_or_target, Type)
+                else loop_types_or_target
+            )
+            target = target_or_none
+
+        target = _get_op_result_or_value(target)
+
+        super().__init__(
+            target.type,
+            loop_types,
+            target,
+            dynamic_sizes=dynamic_sizes,
+            static_sizes=sizes_attr,
+            interchange=interchange,
+            loc=loc,
+            ip=ip,
+        )
+
+    def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]:
+        if not attr:
+            return []
+        return [element for element in attr]
 
 
 class VectorizeOp:
-  """Specialization for VectorizeOp class."""
-
-  def __init__(
-      self,
-      target: Union[Operation, Value],
-      *,
-      vectorize_padding: Union[bool, BoolAttr] = False,
-      loc=None,
-      ip=None,
-  ):
-    pdl_operation_type = pdl.OperationType.get()
-    if isinstance(vectorize_padding, bool):
-      vectorize_padding = UnitAttr.get()
-    super().__init__(
-        pdl_operation_type,
-        _get_op_result_or_value(target),
-        vectorize_padding=vectorize_padding,
-        loc=loc,
-        ip=ip,
-    )
+    """Specialization for VectorizeOp class."""
+
+    def __init__(
+        self,
+        target: Union[Operation, Value],
+        *,
+        vectorize_padding: Union[bool, BoolAttr] = False,
+        loc=None,
+        ip=None,
+    ):
+        pdl_operation_type = pdl.OperationType.get()
+        if isinstance(vectorize_padding, bool):
+            vectorize_padding = UnitAttr.get()
+        super().__init__(
+            pdl_operation_type,
+            _get_op_result_or_value(target),
+            vectorize_padding=vectorize_padding,
+            loc=loc,
+            ip=ip,
+        )
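
For reference, a minimal usage sketch of the structured-transform builders above. It is illustrative only: it assumes an active mlir.ir Context/Location, an insertion point inside the body of an enclosing transform.sequence, and that `target` is that sequence's block-argument handle; the module path `mlir.dialects.transform.structured` is how these extensions are surfaced to users.

    # Sketch only: `target` is assumed to be the handle (block argument) of an
    # enclosing transform.sequence; an insertion point into its body is active.
    from mlir.dialects.transform import structured as structured_transform

    # Match all linalg.matmul ops under `target`, then tile the matches by 4x8
    # with the two generated loops interchanged.
    matched = structured_transform.MatchOp.match_op_names(target, ["linalg.matmul"])
    tiled = structured_transform.TileOp(matched, sizes=[4, 8], interchange=[1, 0])
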
index 51d998b..09b9ec6 100644 (file)
@@ -3,40 +3,42 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from ..ir import *
+    from ..ir import *
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
 
 from typing import Any, Optional, Sequence, Union
-from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+from ._ods_common import (
+    get_op_result_or_value as _get_op_result_or_value,
+    get_op_results_or_values as _get_op_results_or_values,
+)
 
 
 class EmptyOp:
-  """Extends the tensor.empty op."""
+    """Extends the tensor.empty op."""
 
-  def __init__(self,
-               sizes: Sequence[Union[int, Value]],
-               element_type: Type,
-               *,
-               loc=None,
-               ip=None):
-    """Constructs an `empty` with mixed static/dynamic sizes."""
-    # TODO: Refactor the EmptyOp to take an element type attribute and
-    # then use normal result type inference, unifying the Python and C++ side
-    # with a standard mechanism (versus stashing that in builders).
-    dynamic_sizes = []
-    static_sizes = []
-    for s in sizes:
-      if isinstance(s, int):
-        static_sizes.append(s)
-      else:
-        static_sizes.append(ShapedType.get_dynamic_size())
-        dynamic_sizes.append(s)
-    result_type = RankedTensorType.get(static_sizes, element_type)
-    op = self.build_generic(
-        results=[result_type],
-        operands=dynamic_sizes,
-        attributes={},
-        loc=loc,
-        ip=ip)
-    OpView.__init__(self, op)
+    def __init__(
+        self,
+        sizes: Sequence[Union[int, Value]],
+        element_type: Type,
+        *,
+        loc=None,
+        ip=None
+    ):
+        """Constructs an `empty` with mixed static/dynamic sizes."""
+        # TODO: Refactor the EmptyOp to take an element type attribute and
+        # then use normal result type inference, unifying the Python and C++ side
+        # with a standard mechanism (versus stashing that in builders).
+        dynamic_sizes = []
+        static_sizes = []
+        for s in sizes:
+            if isinstance(s, int):
+                static_sizes.append(s)
+            else:
+                static_sizes.append(ShapedType.get_dynamic_size())
+                dynamic_sizes.append(s)
+        result_type = RankedTensorType.get(static_sizes, element_type)
+        op = self.build_generic(
+            results=[result_type], operands=dynamic_sizes, attributes={}, loc=loc, ip=ip
+        )
+        OpView.__init__(self, op)
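
A small usage sketch for the builder above; it assumes an active Context/Location, an insertion point inside a function body, and an index-typed Value `n` supplying the dynamic extent (all assumptions of this example, not shown in the diff).

    # Sketch only: `n` is assumed to be an index-typed Value for the dynamic size.
    from mlir.dialects import tensor
    from mlir.ir import F32Type

    # Mixed static/dynamic sizes produce a tensor.empty of type tensor<4x?x8xf32>.
    t = tensor.EmptyOp([4, n, 8], F32Type.get())
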
index cc4428e..425ec65 100644 (file)
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 try:
-  from ..ir import *
-  from ._ods_common import (
-      get_op_result_or_value as _get_op_result_or_value,
-      get_op_results_or_values as _get_op_results_or_values,
-  )
+    from ..ir import *
+    from ._ods_common import (
+        get_op_result_or_value as _get_op_result_or_value,
+        get_op_results_or_values as _get_op_results_or_values,
+    )
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
 
 from typing import Optional, Sequence, Union
 
 
 class CastOp:
-
-  def __init__(self,
-               result_type: Type,
-               target: Union[Operation, Value],
-               *,
-               loc=None,
-               ip=None):
-    super().__init__(result_type,
-                     _get_op_result_or_value(target),
-                     loc=loc,
-                     ip=ip)
+    def __init__(
+        self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None
+    ):
+        super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
 
 
 class GetClosestIsolatedParentOp:
-
-  def __init__(self,
-               result_type: Type,
-               target: Union[Operation, Value],
-               *,
-               loc=None,
-               ip=None):
-    super().__init__(result_type,
-                     _get_op_result_or_value(target),
-                     loc=loc,
-                     ip=ip)
+    def __init__(
+        self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None
+    ):
+        super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
 
 
 class MergeHandlesOp:
-
-  def __init__(
-      self,
-      handles: Sequence[Union[Operation, Value]],
-      *,
-      deduplicate: bool = False,
-      loc=None,
-      ip=None,
-  ):
-    super().__init__(
-        [_get_op_result_or_value(h) for h in handles],
-        deduplicate=deduplicate,
-        loc=loc,
-        ip=ip,
-    )
+    def __init__(
+        self,
+        handles: Sequence[Union[Operation, Value]],
+        *,
+        deduplicate: bool = False,
+        loc=None,
+        ip=None,
+    ):
+        super().__init__(
+            [_get_op_result_or_value(h) for h in handles],
+            deduplicate=deduplicate,
+            loc=loc,
+            ip=ip,
+        )
 
 
 class ReplicateOp:
-
-  def __init__(
-      self,
-      pattern: Union[Operation, Value],
-      handles: Sequence[Union[Operation, Value]],
-      *,
-      loc=None,
-      ip=None,
-  ):
-    super().__init__(
-        [_get_op_result_or_value(h).type for h in handles],
-        _get_op_result_or_value(pattern),
-        [_get_op_result_or_value(h) for h in handles],
-        loc=loc,
-        ip=ip,
-    )
+    def __init__(
+        self,
+        pattern: Union[Operation, Value],
+        handles: Sequence[Union[Operation, Value]],
+        *,
+        loc=None,
+        ip=None,
+    ):
+        super().__init__(
+            [_get_op_result_or_value(h).type for h in handles],
+            _get_op_result_or_value(pattern),
+            [_get_op_result_or_value(h) for h in handles],
+            loc=loc,
+            ip=ip,
+        )
 
 
 class SequenceOp:
-
-  def __init__(
-      self,
-      failure_propagation_mode,
-      results: Sequence[Type],
-      target: Union[Operation, Value, Type],
-      extra_bindings: Optional[Union[Sequence[Value], Sequence[Type], Operation,
-                                     OpView]] = None,
-  ):
-    root = (_get_op_result_or_value(target) if isinstance(
-        target, (Operation, Value)) else None)
-    root_type = root.type if not isinstance(target, Type) else target
-    if not isinstance(failure_propagation_mode, Attribute):
-      failure_propagation_mode_attr = IntegerAttr.get(
-          IntegerType.get_signless(32), failure_propagation_mode._as_int())
-    else:
-      failure_propagation_mode_attr = failure_propagation_mode
-
-    if extra_bindings is None:
-      extra_bindings = []
-    if isinstance(extra_bindings, (Operation, OpView)):
-      extra_bindings = _get_op_results_or_values(extra_bindings)
-
-    extra_binding_types = []
-    if len(extra_bindings) != 0:
-      if isinstance(extra_bindings[0], Type):
-        extra_binding_types = extra_bindings
-        extra_bindings = []
-      else:
-        extra_binding_types = [v.type for v in extra_bindings]
-
-    super().__init__(
-        results_=results,
-        failure_propagation_mode=failure_propagation_mode_attr,
-        root=root,
-        extra_bindings=extra_bindings,
-    )
-    self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types))
-
-  @property
-  def body(self) -> Block:
-    return self.regions[0].blocks[0]
-
-  @property
-  def bodyTarget(self) -> Value:
-    return self.body.arguments[0]
-
-  @property
-  def bodyExtraArgs(self) -> BlockArgumentList:
-    return self.body.arguments[1:]
+    def __init__(
+        self,
+        failure_propagation_mode,
+        results: Sequence[Type],
+        target: Union[Operation, Value, Type],
+        extra_bindings: Optional[
+            Union[Sequence[Value], Sequence[Type], Operation, OpView]
+        ] = None,
+    ):
+        root = (
+            _get_op_result_or_value(target)
+            if isinstance(target, (Operation, Value))
+            else None
+        )
+        root_type = root.type if not isinstance(target, Type) else target
+        if not isinstance(failure_propagation_mode, Attribute):
+            failure_propagation_mode_attr = IntegerAttr.get(
+                IntegerType.get_signless(32), failure_propagation_mode._as_int()
+            )
+        else:
+            failure_propagation_mode_attr = failure_propagation_mode
+
+        if extra_bindings is None:
+            extra_bindings = []
+        if isinstance(extra_bindings, (Operation, OpView)):
+            extra_bindings = _get_op_results_or_values(extra_bindings)
+
+        extra_binding_types = []
+        if len(extra_bindings) != 0:
+            if isinstance(extra_bindings[0], Type):
+                extra_binding_types = extra_bindings
+                extra_bindings = []
+            else:
+                extra_binding_types = [v.type for v in extra_bindings]
+
+        super().__init__(
+            results_=results,
+            failure_propagation_mode=failure_propagation_mode_attr,
+            root=root,
+            extra_bindings=extra_bindings,
+        )
+        self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types))
+
+    @property
+    def body(self) -> Block:
+        return self.regions[0].blocks[0]
+
+    @property
+    def bodyTarget(self) -> Value:
+        return self.body.arguments[0]
+
+    @property
+    def bodyExtraArgs(self) -> BlockArgumentList:
+        return self.body.arguments[1:]
 
 
 class YieldOp:
-
-  def __init__(
-      self,
-      operands: Optional[Union[Operation, Sequence[Value]]] = None,
-      *,
-      loc=None,
-      ip=None,
-  ):
-    if operands is None:
-      operands = []
-    super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
+    def __init__(
+        self,
+        operands: Optional[Union[Operation, Sequence[Value]]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        if operands is None:
+            operands = []
+        super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
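
A minimal sketch of composing the ops above: a top-level SequenceOp whose body is terminated by YieldOp. It assumes an active Context/Location and a module-level insertion point; the FailurePropagationMode member name is taken from the transform bindings and is shown here only as an illustration.

    # Sketch only: assumes an active Context/Location and an insertion point at
    # module level; PROPAGATE is assumed to be the enum member exposed by the
    # transform bindings.
    from mlir.ir import InsertionPoint
    from mlir.dialects import transform

    seq = transform.SequenceOp(
        transform.FailurePropagationMode.PROPAGATE,  # how silenceable failures propagate
        [],                                          # the sequence yields no results
        transform.AnyOpType.get(),                   # type of the root handle argument
    )
    with InsertionPoint(seq.body):
        # seq.bodyTarget is the root handle; transform ops would be emitted here.
        transform.YieldOp()
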
index 5a695d6..2f65131 100644 (file)
@@ -31,61 +31,60 @@ from .lang.yaml_helper import *
 
 
 def create_arg_parser() -> argparse.ArgumentParser:
-  p = argparse.ArgumentParser(description="Dump an oplib in various formats")
-  p.add_argument("modules",
-                 metavar="M",
-                 type=str,
-                 nargs="*",
-                 help="Op module to dump")
-  p.add_argument("--file",
-                 metavar="F",
-                 type=str,
-                 nargs="*",
-                 help="Python op file to dump")
-  p.add_argument("--format",
-                 type=str,
-                 dest="format",
-                 default="yaml",
-                 choices=("yaml", "repr"),
-                 help="Format in which to dump")
-  return p
+    p = argparse.ArgumentParser(description="Dump an oplib in various formats")
+    p.add_argument(
+        "modules", metavar="M", type=str, nargs="*", help="Op module to dump"
+    )
+    p.add_argument(
+        "--file", metavar="F", type=str, nargs="*", help="Python op file to dump"
+    )
+    p.add_argument(
+        "--format",
+        type=str,
+        dest="format",
+        default="yaml",
+        choices=("yaml", "repr"),
+        help="Format in which to dump",
+    )
+    return p
 
 
 def load_module_from_file(module_name, file_path):
-  spec = importlib.util.spec_from_file_location(module_name, file_path)
-  m = importlib.util.module_from_spec(spec)
-  spec.loader.exec_module(m)
-  return m
+    spec = importlib.util.spec_from_file_location(module_name, file_path)
+    m = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(m)
+    return m
 
 
 def main(args):
-  # Load all configs.
-  configs = []
-  modules = []
-  for module_name in args.modules:
-    modules.append(
-        importlib.import_module(module_name,
-                                package="mlir.dialects.linalg.opdsl"))
-  for i, file_path in enumerate(args.file or []):
-    modules.append(load_module_from_file(f"_mlir_eval_oplib{i}", file_path))
-  for m in modules:
-    for attr_name, value in m.__dict__.items():
-      # TODO: This class layering is awkward.
-      if isinstance(value, DefinedOpCallable):
-        try:
-          linalg_config = LinalgOpConfig.from_linalg_op_def(value.op_def)
-        except Exception as e:
-          raise ValueError(
-              f"Could not create LinalgOpConfig from {value.op_def}") from e
-        configs.extend(linalg_config)
-
-  # Print.
-  if args.format == "yaml":
-    print(yaml_dump_all(configs))
-  elif args.format == "repr":
-    for config in configs:
-      print(repr(config))
+    # Load all configs.
+    configs = []
+    modules = []
+    for module_name in args.modules:
+        modules.append(
+            importlib.import_module(module_name, package="mlir.dialects.linalg.opdsl")
+        )
+    for i, file_path in enumerate(args.file or []):
+        modules.append(load_module_from_file(f"_mlir_eval_oplib{i}", file_path))
+    for m in modules:
+        for attr_name, value in m.__dict__.items():
+            # TODO: This class layering is awkward.
+            if isinstance(value, DefinedOpCallable):
+                try:
+                    linalg_config = LinalgOpConfig.from_linalg_op_def(value.op_def)
+                except Exception as e:
+                    raise ValueError(
+                        f"Could not create LinalgOpConfig from {value.op_def}"
+                    ) from e
+                configs.extend(linalg_config)
+
+    # Print.
+    if args.format == "yaml":
+        print(yaml_dump_all(configs))
+    elif args.format == "repr":
+        for config in configs:
+            print(repr(config))
 
 
 if __name__ == "__main__":
-  main(create_arg_parser().parse_args())
+    main(create_arg_parser().parse_args())
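
For context, a sketch of driving the dumper programmatically through the parser defined above; the relative module path is an example and resolves against the `mlir.dialects.linalg.opdsl` package as handled in main().

    # Sketch only: equivalent to running the module from the command line, e.g.
    #   python -m mlir.dialects.linalg.opdsl.dump_oplib .ops.core_named_ops --format yaml
    from mlir.dialects.linalg.opdsl.dump_oplib import create_arg_parser, main

    main(create_arg_parser().parse_args([".ops.core_named_ops", "--format", "yaml"]))
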
index 038f068..9fa626d 100644 (file)
@@ -66,201 +66,201 @@ __all__ = [
 
 
 class AffineBuildState:
-  """Internal state for the AffineExprDef._create impls.
-
-  Note that a "local" AffineBuildState can be created relative to a "global"
-  AffineBuildState. In that case, any affine expressions built will inherit
-  symbol and dim bindings from the global state and will update both as new
-  ones are discovered. This allows for building expressions across contexts
-  which share a common symbol and dim space.
-  """
-
-  def __init__(self,
-               *,
-               global_state: "AffineBuildState" = None,
-               allow_new_symbols: bool = True,
-               allow_new_dims: bool = True):
-    if not global_state:
-      self.all_symbols = dict()  # type: Dict[str, int]
-      self.all_dims = dict()  # type: Dict[str, int]
-    else:
-      # Alias the global dict.
-      self.all_symbols = global_state.all_symbols
-      self.all_dims = global_state.all_dims
-
-    # Map of symbols and dims in the current build.
-    self.local_symbols = dict()  # type: Dict[str, int]
-    self.local_dims = dict()  # type: Dict[str, int]
-    self.allow_new_symbols = allow_new_symbols
-    self.allow_new_dims = allow_new_dims
-
-  def get_dim(self, dimname: str) -> int:
-    """Gets the dim position given a name."""
-    pos = self.all_dims.get(dimname)
-    if pos is None:
-      if not self.allow_new_dims:
-        raise ValueError(
-            f"New dimensions not allowed in the current affine expression: "
-            f"Requested '{dimname}', Availble: {self.all_dims}")
-      pos = len(self.all_dims)
-      self.all_dims[dimname] = pos
-    self.local_dims[dimname] = pos
-    return pos
-
-  def get_symbol(self, symname: str) -> int:
-    """Geta a symbol position given a name."""
-    pos = self.all_symbols.get(symname)
-    if pos is None:
-      if not self.allow_new_symbols:
-        raise ValueError(
-            f"New symbols not allowed in the current affine expression: "
-            f"Requested '{symname}', Availble: {self.all_symbols}")
-      pos = len(self.all_symbols)
-      self.all_symbols[symname] = pos
-    self.local_symbols[symname] = pos
-    return pos
-
-  @property
-  def local_dim_count(self) -> int:
-    return len(self.local_dims)
-
-  @property
-  def local_symbol_count(self) -> int:
-    return len(self.local_symbols)
-
-  @property
-  def dim_count(self) -> int:
-    return len(self.all_dims)
-
-  @property
-  def symbol_count(self) -> int:
-    return len(self.all_symbols)
-
-  def __repr__(self):
-    lines = [f"AffineBuildState<"]
-    lines.append(f"  symbols={self.local_symbols}")
-    lines.append(f"  dims={self.local_dims}>")
-    return "\n".join(lines)
+    """Internal state for the AffineExprDef._create impls.
+
+    Note that a "local" AffineBuildState can be created relative to a "global"
+    AffineBuildState. In that case, any affine expressions built will inherit
+    symbol and dim bindings from the global state and will update both as new
+    ones are discovered. This allows for building expressions across contexts
+    which share a common symbol and dim space.
+    """
+
+    def __init__(
+        self,
+        *,
+        global_state: "AffineBuildState" = None,
+        allow_new_symbols: bool = True,
+        allow_new_dims: bool = True,
+    ):
+        if not global_state:
+            self.all_symbols = dict()  # type: Dict[str, int]
+            self.all_dims = dict()  # type: Dict[str, int]
+        else:
+            # Alias the global dict.
+            self.all_symbols = global_state.all_symbols
+            self.all_dims = global_state.all_dims
+
+        # Map of symbols and dims in the current build.
+        self.local_symbols = dict()  # type: Dict[str, int]
+        self.local_dims = dict()  # type: Dict[str, int]
+        self.allow_new_symbols = allow_new_symbols
+        self.allow_new_dims = allow_new_dims
+
+    def get_dim(self, dimname: str) -> int:
+        """Gets the dim position given a name."""
+        pos = self.all_dims.get(dimname)
+        if pos is None:
+            if not self.allow_new_dims:
+                raise ValueError(
+                    f"New dimensions not allowed in the current affine expression: "
+                    f"Requested '{dimname}', Availble: {self.all_dims}"
+                )
+            pos = len(self.all_dims)
+            self.all_dims[dimname] = pos
+        self.local_dims[dimname] = pos
+        return pos
+
+    def get_symbol(self, symname: str) -> int:
+        """Geta a symbol position given a name."""
+        pos = self.all_symbols.get(symname)
+        if pos is None:
+            if not self.allow_new_symbols:
+                raise ValueError(
+                    f"New symbols not allowed in the current affine expression: "
+                    f"Requested '{symname}', Availble: {self.all_symbols}"
+                )
+            pos = len(self.all_symbols)
+            self.all_symbols[symname] = pos
+        self.local_symbols[symname] = pos
+        return pos
+
+    @property
+    def local_dim_count(self) -> int:
+        return len(self.local_dims)
+
+    @property
+    def local_symbol_count(self) -> int:
+        return len(self.local_symbols)
+
+    @property
+    def dim_count(self) -> int:
+        return len(self.all_dims)
+
+    @property
+    def symbol_count(self) -> int:
+        return len(self.all_symbols)
+
+    def __repr__(self):
+        lines = [f"AffineBuildState<"]
+        lines.append(f"  symbols={self.local_symbols}")
+        lines.append(f"  dims={self.local_dims}>")
+        return "\n".join(lines)
 
 
 class AffineExprDef:
-  """Base class for an affine expression being defined."""
+    """Base class for an affine expression being defined."""
 
-  def build(self, state: Optional[AffineBuildState] = None) -> _ir.AffineExpr:
-    """Builds the corresponding _ir.AffineExpr from the definitions.
-    """
-    state = AffineBuildState() if state is None else state
-    expr = self._create(state)
-    return expr
+    def build(self, state: Optional[AffineBuildState] = None) -> _ir.AffineExpr:
+        """Builds the corresponding _ir.AffineExpr from the definitions."""
+        state = AffineBuildState() if state is None else state
+        expr = self._create(state)
+        return expr
 
-  def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
-    raise NotImplementedError()
+    def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
+        raise NotImplementedError()
 
-  @staticmethod
-  def coerce_from(py_value):
-    if isinstance(py_value, int):
-      return AffineConstantExpr(py_value)
-    assert isinstance(py_value, AffineExprDef)
-    return py_value
+    @staticmethod
+    def coerce_from(py_value):
+        if isinstance(py_value, int):
+            return AffineConstantExpr(py_value)
+        assert isinstance(py_value, AffineExprDef)
+        return py_value
 
-  def visit_affine_exprs(self, callback):
-    """Visits all AffineExprDefs including self."""
-    callback(self)
+    def visit_affine_exprs(self, callback):
+        """Visits all AffineExprDefs including self."""
+        callback(self)
 
-  def __add__(lhs, rhs):
-    rhs = AffineExprDef.coerce_from(rhs)
-    return AffineBinaryExprDef(_ir.AffineAddExpr, lhs, rhs)
+    def __add__(lhs, rhs):
+        rhs = AffineExprDef.coerce_from(rhs)
+        return AffineBinaryExprDef(_ir.AffineAddExpr, lhs, rhs)
 
-  def __mul__(lhs, rhs):
-    rhs = AffineExprDef.coerce_from(rhs)
-    return AffineBinaryExprDef(_ir.AffineMulExpr, lhs, rhs)
+    def __mul__(lhs, rhs):
+        rhs = AffineExprDef.coerce_from(rhs)
+        return AffineBinaryExprDef(_ir.AffineMulExpr, lhs, rhs)
 
-  def __mod__(lhs, rhs):
-    rhs = AffineExprDef.coerce_from(rhs)
-    return AffineBinaryExprDef(_ir.AffineModExpr, lhs, rhs)
+    def __mod__(lhs, rhs):
+        rhs = AffineExprDef.coerce_from(rhs)
+        return AffineBinaryExprDef(_ir.AffineModExpr, lhs, rhs)
 
-  def __floordiv__(lhs, rhs):
-    rhs = AffineExprDef.coerce_from(rhs)
-    return AffineBinaryExprDef(_ir.AffineFloorDivExpr, lhs, rhs)
+    def __floordiv__(lhs, rhs):
+        rhs = AffineExprDef.coerce_from(rhs)
+        return AffineBinaryExprDef(_ir.AffineFloorDivExpr, lhs, rhs)
 
-  def __truediv__(lhs, rhs):
-    # TODO: Not really a ceil div - taking liberties for the DSL.
-    rhs = AffineExprDef.coerce_from(rhs)
-    return AffineBinaryExprDef(_ir.AffineCeilDivExpr, lhs, rhs)
+    def __truediv__(lhs, rhs):
+        # TODO: Not really a ceil div - taking liberties for the DSL.
+        rhs = AffineExprDef.coerce_from(rhs)
+        return AffineBinaryExprDef(_ir.AffineCeilDivExpr, lhs, rhs)
 
 
 class AffineConstantExpr(AffineExprDef):
-  """An affine constant being defined."""
+    """An affine constant being defined."""
 
-  def __init__(self, value: int):
-    assert isinstance(value, int)
-    self.value = value
+    def __init__(self, value: int):
+        assert isinstance(value, int)
+        self.value = value
 
-  def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
-    return _ir.AffineConstantExpr.get(self.value)
+    def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
+        return _ir.AffineConstantExpr.get(self.value)
 
-  def __repr__(self):
-    return f"Const({self.value})"
+    def __repr__(self):
+        return f"Const({self.value})"
 
 
 class AffineBinaryExprDef(AffineExprDef):
-  """An affine binary expression being defined."""
+    """An affine binary expression being defined."""
 
-  def __init__(self, ir_ctor, lhs: AffineExprDef, rhs: AffineExprDef):
-    self.ir_ctor = ir_ctor
-    self.lhs = lhs
-    self.rhs = rhs
+    def __init__(self, ir_ctor, lhs: AffineExprDef, rhs: AffineExprDef):
+        self.ir_ctor = ir_ctor
+        self.lhs = lhs
+        self.rhs = rhs
 
-  def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
-    return self.ir_ctor.get(self.lhs._create(state), self.rhs._create(state))
+    def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
+        return self.ir_ctor.get(self.lhs._create(state), self.rhs._create(state))
 
-  def visit_affine_exprs(self, callback):
-    """Visits all AffineExprDefs including self."""
-    super().visit_affine_exprs(callback)
-    self.lhs.visit_affine_exprs(callback)
-    self.rhs.visit_affine_exprs(callback)
+    def visit_affine_exprs(self, callback):
+        """Visits all AffineExprDefs including self."""
+        super().visit_affine_exprs(callback)
+        self.lhs.visit_affine_exprs(callback)
+        self.rhs.visit_affine_exprs(callback)
 
-  def __repr__(self):
-    return f"{self.ir_ctor.__name__}({repr(self.lhs)}, {repr(self.rhs)})"
+    def __repr__(self):
+        return f"{self.ir_ctor.__name__}({repr(self.lhs)}, {repr(self.rhs)})"
 
 
 class DimDef(AffineExprDef):
-  """Represents a named dimension.
-
-  """
-  ALL_DIMS = dict()  # type: Dict[str, "DimDef"]
-
-  def __new__(cls, dimname: str):
-    existing = cls.ALL_DIMS.get(dimname)
-    if existing is not None:
-      return existing
-    new = super().__new__(cls)
-    new.dimname = dimname
-    cls.ALL_DIMS[dimname] = new
-    return new
-
-  def __repr__(self):
-    return f"Dim({self.dimname})"
-
-  def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
-    pos = state.get_dim(self.dimname)
-    return _ir.AffineDimExpr.get(position=pos)
-
-  @classmethod
-  def create_expando(cls):
-    """Create an expando class that creates unique symbols based on attr access.
-    """
+    """Represents a named dimension."""
+
+    ALL_DIMS = dict()  # type: Dict[str, "DimDef"]
+
+    def __new__(cls, dimname: str):
+        existing = cls.ALL_DIMS.get(dimname)
+        if existing is not None:
+            return existing
+        new = super().__new__(cls)
+        new.dimname = dimname
+        cls.ALL_DIMS[dimname] = new
+        return new
 
-    class ExpandoDims:
+    def __repr__(self):
+        return f"Dim({self.dimname})"
 
-      def __getattr__(self, n):
-        return cls(n)
+    def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
+        pos = state.get_dim(self.dimname)
+        return _ir.AffineDimExpr.get(position=pos)
 
-    return ExpandoDims()
+    @classmethod
+    def create_expando(cls):
+        """Create an expando class that creates unique symbols based on attr access."""
+
+        class ExpandoDims:
+            def __getattr__(self, n):
+                return cls(n)
+
+        return ExpandoDims()
 
 
 class SymbolDef(AffineExprDef):
-  """Represents a named symbol.
+    """Represents a named symbol.
 
     >>> s1 = SymbolDef("s1")
     >>> s1
@@ -270,36 +270,35 @@ class SymbolDef(AffineExprDef):
     False
     >>> s1 is SymbolDef("s1")
     True
-  """
-  ALL_SYMBOLS = dict()  # type: Dict[str, "SymbolDef"]
-
-  def __new__(cls, symname: str):
-    existing = cls.ALL_SYMBOLS.get(symname)
-    if existing is not None:
-      return existing
-    new = super().__new__(cls)
-    new.symname = symname
-    cls.ALL_SYMBOLS[symname] = new
-    return new
-
-  def __repr__(self):
-    return f"Symbol({self.symname})"
-
-  def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
-    pos = state.get_symbol(self.symname)
-    return _ir.AffineSymbolExpr.get(position=pos)
-
-  @classmethod
-  def create_expando(cls):
-    """Create an expando class that creates unique symbols based on attr access.
     """
 
-    class ExpandoSymbols:
+    ALL_SYMBOLS = dict()  # type: Dict[str, "SymbolDef"]
+
+    def __new__(cls, symname: str):
+        existing = cls.ALL_SYMBOLS.get(symname)
+        if existing is not None:
+            return existing
+        new = super().__new__(cls)
+        new.symname = symname
+        cls.ALL_SYMBOLS[symname] = new
+        return new
+
+    def __repr__(self):
+        return f"Symbol({self.symname})"
+
+    def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
+        pos = state.get_symbol(self.symname)
+        return _ir.AffineSymbolExpr.get(position=pos)
+
+    @classmethod
+    def create_expando(cls):
+        """Create an expando class that creates unique symbols based on attr access."""
 
-      def __getattr__(self, n):
-        return cls(n)
+        class ExpandoSymbols:
+            def __getattr__(self, n):
+                return cls(n)
 
-    return ExpandoSymbols()
+        return ExpandoSymbols()
 
 
 # Global accessor for on-demand dims and symbols.
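
A short sketch of the DSL defined above: dims and symbols are interned by name, and building an expression records their positions in a shared AffineBuildState. An active mlir.ir.Context is assumed; the expando accessors are created locally here purely for illustration.

    # Sketch only: requires an active mlir.ir.Context to build the ir.AffineExprs.
    from mlir import ir
    from mlir.dialects.linalg.opdsl.lang.affine import (
        AffineBuildState,
        DimDef,
        SymbolDef,
    )

    D = DimDef.create_expando()      # on-demand dims:    D.m, D.n, ...
    S = SymbolDef.create_expando()   # on-demand symbols: S.s0, S.s1, ...

    state = AffineBuildState()
    with ir.Context():
        expr = (D.m * S.s0 + D.n).build(state)  # d0 * s0 + d1
    # Positions discovered while building are recorded in the shared state.
    assert state.dim_count == 2 and state.symbol_count == 1
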
index 135f55e..5d5866f 100644 (file)
@@ -23,223 +23,232 @@ from .yaml_helper import *
 
 
 class TensorExpression:
-  """An expression that can appear on the RHS of a comprehension."""
+    """An expression that can appear on the RHS of a comprehension."""
 
-  def to_scalar_expression(self) -> ScalarExpression:
-    raise NotImplementedError()
+    def to_scalar_expression(self) -> ScalarExpression:
+        raise NotImplementedError()
 
-  def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
-    """Visits all tensor expression reachable by the expression."""
-    callback(self)
+    def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
+        """Visits all tensor expression reachable by the expression."""
+        callback(self)
 
-  def collect_dim_uses(self, uses: Set["DimDef"]):
-    """Collects all DimDefs reachable through this expression."""
+    def collect_dim_uses(self, uses: Set["DimDef"]):
+        """Collects all DimDefs reachable through this expression."""
 
-    def visit_dim_def(dim_def: AffineExprDef):
-      if isinstance(dim_def, DimDef):
-        uses.add(dim_def)
+        def visit_dim_def(dim_def: AffineExprDef):
+            if isinstance(dim_def, DimDef):
+                uses.add(dim_def)
 
-    def visit_affine_exprs(expr: "TensorExpression"):
-      if isinstance(expr, TensorUse):
-        for ind in expr.indices:
-          ind.visit_affine_exprs(visit_dim_def)
-      if isinstance(expr, TensorReduceFn):
-        for ind in expr.reduce_fn.reduce_dims:
-          ind.visit_affine_exprs(visit_dim_def)
+        def visit_affine_exprs(expr: "TensorExpression"):
+            if isinstance(expr, TensorUse):
+                for ind in expr.indices:
+                    ind.visit_affine_exprs(visit_dim_def)
+            if isinstance(expr, TensorReduceFn):
+                for ind in expr.reduce_fn.reduce_dims:
+                    ind.visit_affine_exprs(visit_dim_def)
 
-    self.visit_tensor_exprs(visit_affine_exprs)
+        self.visit_tensor_exprs(visit_affine_exprs)
 
-  def collect_tensor_uses(self, uses: Set["TensorUse"]):
-    """Collects all TensorUses reachable through this expression."""
+    def collect_tensor_uses(self, uses: Set["TensorUse"]):
+        """Collects all TensorUses reachable through this expression."""
 
-    def visit_tensor_use(expr: "TensorExpression"):
-      if isinstance(expr, TensorUse):
-        uses.add(expr)
+        def visit_tensor_use(expr: "TensorExpression"):
+            if isinstance(expr, TensorUse):
+                uses.add(expr)
 
-    self.visit_tensor_exprs(visit_tensor_use)
+        self.visit_tensor_exprs(visit_tensor_use)
 
-  def collect_indices(self, indices: Set["index"]):
-    """Collects all index accesses reachable through this expression."""
+    def collect_indices(self, indices: Set["index"]):
+        """Collects all index accesses reachable through this expression."""
 
-    def visit_index(expr: "TensorExpression"):
-      if isinstance(expr, index):
-        indices.add(expr)
+        def visit_index(expr: "TensorExpression"):
+            if isinstance(expr, index):
+                indices.add(expr)
 
-    self.visit_tensor_exprs(visit_index)
+        self.visit_tensor_exprs(visit_index)
 
-  def collect_scalar_uses(self, uses: Set["ScalarDef"]):
-    """Collects all ScalarDefs reachable through this expression."""
+    def collect_scalar_uses(self, uses: Set["ScalarDef"]):
+        """Collects all ScalarDefs reachable through this expression."""
 
-    def visit_scalar_def(expr: "TensorExpression"):
-      if isinstance(expr, ScalarDef):
-        uses.add(expr)
+        def visit_scalar_def(expr: "TensorExpression"):
+            if isinstance(expr, ScalarDef):
+                uses.add(expr)
 
-    self.visit_tensor_exprs(visit_scalar_def)
+        self.visit_tensor_exprs(visit_scalar_def)
 
-  def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
-    return BinaryFn.add(self, rhs)
+    def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
+        return BinaryFn.add(self, rhs)
 
-  def __mul__(self, rhs) -> "TensorExpression":
-    return BinaryFn.mul(self, rhs)
+    def __mul__(self, rhs) -> "TensorExpression":
+        return BinaryFn.mul(self, rhs)
 
-  def __sub__(self, rhs) -> "TensorExpression":
-    return BinaryFn.sub(self, rhs)
+    def __sub__(self, rhs) -> "TensorExpression":
+        return BinaryFn.sub(self, rhs)
 
-  def __hash__(self):
-    return hash(id(self))
+    def __hash__(self):
+        return hash(id(self))
 
 
 class TensorUse(TensorExpression):
-  """A used tensor represented by its (tensor_name, indices).
-
-  Note that forming a comprehension via direct assignment is performed through
-  __setitem__ on the TensorDef level. However, performing a reduction with
-  compound ops (+=, *=, etc) is done by doing a:
-    TensorDef.__getitem__
-    TensorUse.__iadd__
-    TensorDef.__setitem__
-  """
-
-  def __init__(self, operand_def: "OperandDef",
-               indices: Sequence[AffineExprDef]):
-    self.operand_def = operand_def
-    self.indices = tuple(indices)
-
-  def to_scalar_expression(self) -> ScalarExpression:
-    return ScalarArg(self.tensor_name).expr()
-
-  @property
-  def tensor_name(self) -> str:
-    name = self.operand_def.name
-    assert name is not None, "TensorDef not registered with an op"
-    return name
-
-  def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
-    # Computes the reduction dims for implicit reductions. Assumes that the rhs
-    # is the expression being reduced and self is being reduced into. Any
-    # indices referenced on the rhs and not in self are considered reduction
-    # dims and will be ordered as encountered on the rhs.
-    rhs_dims = set()
-    lhs_dims = set()
-    rhs.collect_dim_uses(rhs_dims)
-    self.collect_dim_uses(lhs_dims)
-    return rhs_dims - lhs_dims
-
-  def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn":
-    return ReduceFnUse(BinaryFn.add, None, *self._compute_reduce_dims(rhs))(rhs)
-
-  def __repr__(self):
-    return (f"{self.operand_def.name}"
-            f"[{', '.join([repr(i) for i in self.indices])}]")
+    """A used tensor represented by its (tensor_name, indices).
+
+    Note that forming a comprehension via direct assignment is performed through
+    __setitem__ on the TensorDef level. However, performing a reduction with
+    compound ops (+=, *=, etc.) is done via the sequence:
+      TensorDef.__getitem__
+      TensorUse.__iadd__
+      TensorDef.__setitem__
+    """
+
+    def __init__(self, operand_def: "OperandDef", indices: Sequence[AffineExprDef]):
+        self.operand_def = operand_def
+        self.indices = tuple(indices)
+
+    def to_scalar_expression(self) -> ScalarExpression:
+        return ScalarArg(self.tensor_name).expr()
+
+    @property
+    def tensor_name(self) -> str:
+        name = self.operand_def.name
+        assert name is not None, "TensorDef not registered with an op"
+        return name
+
+    def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
+        # Computes the reduction dims for implicit reductions. Assumes that the rhs
+        # is the expression being reduced and self is being reduced into. Any
+        # indices referenced on the rhs and not in self are considered reduction
+        # dims and will be ordered as encountered on the rhs.
+        rhs_dims = set()
+        lhs_dims = set()
+        rhs.collect_dim_uses(rhs_dims)
+        self.collect_dim_uses(lhs_dims)
+        return rhs_dims - lhs_dims
+
+    def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn":
+        return ReduceFnUse(BinaryFn.add, None, *self._compute_reduce_dims(rhs))(rhs)
+
+    def __repr__(self):
+        return (
+            f"{self.operand_def.name}" f"[{', '.join([repr(i) for i in self.indices])}]"
+        )
 
 
 class TensorFn(TensorExpression):
-  """Application of a tensor function."""
-
-  def __init__(self, kind: "FunctionKind", name: Optional[str],
-               operand_def: Optional["OperandDef"], type_var: Optional[TypeVar],
-               args: Sequence[TensorExpression]):
-    if bool(name) + bool(operand_def) != 1:
-      raise ValueError("One of 'name', 'operand_def' must be specified")
-    self.name = name
-    self.kind = kind
-    self.operand_def = operand_def
-    self.type_var = type_var
-    self.args = args
-
-  def to_scalar_expression(self) -> ScalarExpression:
-    if self.operand_def:
-      assert self.operand_def.name, "TensorFn not registered with an op"
-    attr_name = self.operand_def.name if self.operand_def else None
-    args = [arg.to_scalar_expression() for arg in self.args]
-    return ScalarFn(self.kind, self.name, attr_name, self.type_var, args).expr()
-
-  def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
-    super().visit_tensor_exprs(callback)
-    for arg in self.args:
-      arg.visit_tensor_exprs(callback)
-
-  def __repr__(self):
-    name = self.operand_def.name if self.operand_def else self.name
-    return (f"{self.kind.name}.{name}(type_var={self.type_var}, "
-            f"args={', '.join(repr(a) for a in self.args)})")
+    """Application of a tensor function."""
+
+    def __init__(
+        self,
+        kind: "FunctionKind",
+        name: Optional[str],
+        operand_def: Optional["OperandDef"],
+        type_var: Optional[TypeVar],
+        args: Sequence[TensorExpression],
+    ):
+        if bool(name) + bool(operand_def) != 1:
+            raise ValueError("One of 'name', 'operand_def' must be specified")
+        self.name = name
+        self.kind = kind
+        self.operand_def = operand_def
+        self.type_var = type_var
+        self.args = args
+
+    def to_scalar_expression(self) -> ScalarExpression:
+        if self.operand_def:
+            assert self.operand_def.name, "TensorFn not registered with an op"
+        attr_name = self.operand_def.name if self.operand_def else None
+        args = [arg.to_scalar_expression() for arg in self.args]
+        return ScalarFn(self.kind, self.name, attr_name, self.type_var, args).expr()
+
+    def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
+        super().visit_tensor_exprs(callback)
+        for arg in self.args:
+            arg.visit_tensor_exprs(callback)
+
+    def __repr__(self):
+        name = self.operand_def.name if self.operand_def else self.name
+        return (
+            f"{self.kind.name}.{name}(type_var={self.type_var}, "
+            f"args={', '.join(repr(a) for a in self.args)})"
+        )
 
 
 class TensorReduceFn(TensorExpression):
-  """Application of a reduction function.
-
-  This captures the lhs (initial value) separately from the rhs.
-  """
-
-  def __init__(self, reduce_use: "ReduceFnUse",
-               args: Sequence[TensorExpression]):
-    self.reduce_use = reduce_use
-    self.lhs = None  # type: Optional[TensorUse]
-    self.args = args
-
-  def to_scalar_expression(self) -> ScalarExpression:
-    if self.lhs is None:
-      raise ValueError(f"Cannot scalarize a TensorReduceFn that has not been "
-                       f"bound to its lhs: {self}")
-    full_args = [self.lhs.to_scalar_expression()
-                ] + [arg.to_scalar_expression() for arg in self.args]
-    fn_name = None
-    attr_name = None
-    if self.reduce_use.binary_fn:
-      fn_name = self.reduce_use.binary_fn.fn_name
-    if self.reduce_use.binary_attr:
-      attr_name = self.reduce_use.binary_attr.operand_def.name
-    return ScalarFn(FunctionKind.BINARY, fn_name, attr_name, None,
-                    full_args).expr()
-
-  def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
-    for arg in self.args:
-      arg.visit_tensor_exprs(callback)
-
-  def __repr__(self):
-    return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})"
+    """Application of a reduction function.
+
+    This captures the lhs (initial value) separately from the rhs.
+    """
+
+    def __init__(self, reduce_use: "ReduceFnUse", args: Sequence[TensorExpression]):
+        self.reduce_use = reduce_use
+        self.lhs = None  # type: Optional[TensorUse]
+        self.args = args
+
+    def to_scalar_expression(self) -> ScalarExpression:
+        if self.lhs is None:
+            raise ValueError(
+                f"Cannot scalarize a TensorReduceFn that has not been "
+                f"bound to its lhs: {self}"
+            )
+        full_args = [self.lhs.to_scalar_expression()] + [
+            arg.to_scalar_expression() for arg in self.args
+        ]
+        fn_name = None
+        attr_name = None
+        if self.reduce_use.binary_fn:
+            fn_name = self.reduce_use.binary_fn.fn_name
+        if self.reduce_use.binary_attr:
+            attr_name = self.reduce_use.binary_attr.operand_def.name
+        return ScalarFn(FunctionKind.BINARY, fn_name, attr_name, None, full_args).expr()
+
+    def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
+        for arg in self.args:
+            arg.visit_tensor_exprs(callback)
+
+    def __repr__(self):
+        return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})"
 
 
 class const(TensorExpression):
-  """Returns the given constant floating point or integer value."""
+    """Returns the given constant floating point or integer value."""
 
-  def __init__(self, value: Any):
-    with _ir.Context():
-      if isinstance(value, float):
-        self.value = str(_ir.FloatAttr.get_f64(float(value)))
-      elif isinstance(value, int):
-        self.value = str(
-            _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value)))
-      else:
-        raise ValueError(f"const requires int or float but got {type(value)}")
+    def __init__(self, value: Any):
+        with _ir.Context():
+            if isinstance(value, float):
+                self.value = str(_ir.FloatAttr.get_f64(float(value)))
+            elif isinstance(value, int):
+                self.value = str(
+                    _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value))
+                )
+            else:
+                raise ValueError(f"const requires int or float but got {type(value)}")
 
-  def to_scalar_expression(self) -> ScalarExpression:
-    return ScalarConst(self.value).expr()
+    def to_scalar_expression(self) -> ScalarExpression:
+        return ScalarConst(self.value).expr()
 
-  def __repr__(self):
-    return f"const({self.value})"
+    def __repr__(self):
+        return f"const({self.value})"
 
 
 class index(TensorExpression):
-  """Returns the iteration index for a given dimension name.
+    """Returns the iteration index for a given dimension name.
 
-  Resolves the given dimension name to obtain its position in the iteration
-  domain of the operation.
-  """
+    Resolves the given dimension name to obtain its position in the iteration
+    domain of the operation.
+    """
 
-  def __init__(self, dim: DimDef):
-    self.dim_def = dim
-    self.dim = -1
+    def __init__(self, dim: DimDef):
+        self.dim_def = dim
+        self.dim = -1
 
-  def resolve_dimension_name(self, affine_state: AffineBuildState):
-    self.dim = affine_state.get_dim(self.dim_def.dimname)
+    def resolve_dimension_name(self, affine_state: AffineBuildState):
+        self.dim = affine_state.get_dim(self.dim_def.dimname)
 
-  def to_scalar_expression(self) -> ScalarExpression:
-    assert self.dim != -1, "Dimension name not resolved"
-    return ScalarIndex(self.dim).expr()
+    def to_scalar_expression(self) -> ScalarExpression:
+        assert self.dim != -1, "Dimension name not resolved"
+        return ScalarIndex(self.dim).expr()
 
-  def __repr__(self):
-    return f"index({repr(self.dim)})"
+    def __repr__(self):
+        return f"index({repr(self.dim)})"
 
 
 ###############################################################################
@@ -248,155 +257,160 @@ class index(TensorExpression):
 
 
 class FunctionKind(Enum):
-  UNARY = 0
-  BINARY = 1
-  TYPE = 2
+    UNARY = 0
+    BINARY = 1
+    TYPE = 2
 
 
 class UnaryFnType:
-  """Unary function.
+    """Unary function.
 
-  A unary function takes one tensor expression and returns the
-  function evaluation result.
-  """
+    A unary function takes one tensor expression and returns the
+    function evaluation result.
+    """
 
-  def __init__(self, fn_name: str):
-    self.fn_name = fn_name
+    def __init__(self, fn_name: str):
+        self.fn_name = fn_name
 
-  def __call__(self, arg: TensorExpression) -> "TensorFn":
-    return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [arg])
+    def __call__(self, arg: TensorExpression) -> "TensorFn":
+        return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [arg])
 
-  def __repr__(self):
-    return f"{self.fn_name}"
+    def __repr__(self):
+        return f"{self.fn_name}"
 
 
 class UnaryFn:
-  """Unary function namespace."""
-  exp = UnaryFnType("exp")
-  log = UnaryFnType("log")
-  abs = UnaryFnType("abs")
-  ceil = UnaryFnType("ceil")
-  floor = UnaryFnType("floor")
-  negf = UnaryFnType("negf")
+    """Unary function namespace."""
+
+    exp = UnaryFnType("exp")
+    log = UnaryFnType("log")
+    abs = UnaryFnType("abs")
+    ceil = UnaryFnType("ceil")
+    floor = UnaryFnType("floor")
+    negf = UnaryFnType("negf")
 
 
 class BinaryFnType:
-  """Binary function.
+    """Binary function.
 
-  A binary function takes two tensor expressions and returns the
-  function evaluation result.
-  """
+    A binary function takes two tensor expressions and returns the
+    function evaluation result.
+    """
 
-  def __init__(self, fn_name: str):
-    self.fn_name = fn_name
+    def __init__(self, fn_name: str):
+        self.fn_name = fn_name
 
-  def __call__(self, arg0: TensorExpression,
-               arg1: TensorExpression) -> "TensorFn":
-    return TensorFn(FunctionKind.BINARY, self.fn_name, None, None, [arg0, arg1])
+    def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> "TensorFn":
+        return TensorFn(FunctionKind.BINARY, self.fn_name, None, None, [arg0, arg1])
 
-  def __repr__(self):
-    return f"{self.fn_name}"
+    def __repr__(self):
+        return f"{self.fn_name}"
 
 
 class BinaryFn:
-  """Binary function namespace.
+    """Binary function namespace.
 
-  As the integer types are signless, signedness is implement by different
-  functions that treat integers as signed or unsigned values.
+    As the integer types are signless, signedness is implemented by different
+    functions that treat integers as signed or unsigned values.
+
+    Examples:
+    - max -> `arith.MaxSIOp`
+    - max_unsigned -> `arith.MaxUIOp`
+    """
 
-  Examples:
-  - max -> `arith.MaxSIOp`
-  - max_unsinged -> `arith.MaxUIOp`
-  """
-  add = BinaryFnType("add")
-  sub = BinaryFnType("sub")
-  mul = BinaryFnType("mul")
-  max_signed = BinaryFnType("max_signed")
-  min_signed = BinaryFnType("min_signed")
-  max_unsigned = BinaryFnType("max_unsigned")
-  min_unsigned = BinaryFnType("min_unsigned")
+    add = BinaryFnType("add")
+    sub = BinaryFnType("sub")
+    mul = BinaryFnType("mul")
+    max_signed = BinaryFnType("max_signed")
+    min_signed = BinaryFnType("min_signed")
+    max_unsigned = BinaryFnType("max_unsigned")
+    min_unsigned = BinaryFnType("min_unsigned")
 
 
 class TypeFnType:
-  """Type conversion function.
+    """Type conversion function.
 
-  A type conversion function takes a target type and a tensor expression and
-  returns the casted tensor expression.
-  """
+    A type conversion function takes a target type and a tensor expression and
+    returns the casted tensor expression.
+    """
 
-  def __init__(self, fn_name: str):
-    self.fn_name = fn_name
+    def __init__(self, fn_name: str):
+        self.fn_name = fn_name
 
-  def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn":
-    return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg])
+    def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn":
+        return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg])
 
-  def __repr__(self):
-    return f"{self.fn_name}"
+    def __repr__(self):
+        return f"{self.fn_name}"
 
 
 class TypeFn:
-  """Type conversion function namespace.
+    """Type conversion function namespace.
+
+    As the integer types are signless, signedness is implemented by different cast
+    functions that treat integers as signed (`cast_signed`) or unsigned
+    (`cast_unsigned`) values.
 
-  As the integer types are signless, signedness is implement by different cast
-  functions that treat integers as signed (`cast_signed`) or unsigned
-  (`cast_unsigned`) values.
+    Examples:
+    - cast_signed(I32 -> I64) -> `arith.ExtSIOp`
+    - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
+    """
 
-  Examples:
-  - cast_signed(I32 -> I64) -> `arith.ExtSIOp`
-  - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
-  """
-  cast_signed = TypeFnType("cast_signed")
-  cast_unsigned = TypeFnType("cast_unsigned")
+    cast_signed = TypeFnType("cast_signed")
+    cast_unsigned = TypeFnType("cast_unsigned")
 
 
 class ReduceFnUse:
-  """Reduction function use.
+    """Reduction function use.
 
-  A reduction use specifies the reduction function and dimensions.
-  """
+    A reduction use specifies the reduction function and dimensions.
+    """
 
-  def __init__(self, binary_fn: Optional[BinaryFnType],
-               binary_attr: Optional["BinaryFnAttrDef"], *reduce_dims: DimDef):
-    if bool(binary_fn) + bool(binary_attr) != 1:
-      raise ValueError("One of 'binary_fn', 'binary_attr' must be specified")
-    self.binary_fn = binary_fn
-    self.binary_attr = binary_attr
-    self.reduce_dims = reduce_dims
+    def __init__(
+        self,
+        binary_fn: Optional[BinaryFnType],
+        binary_attr: Optional["BinaryFnAttrDef"],
+        *reduce_dims: DimDef,
+    ):
+        if bool(binary_fn) + bool(binary_attr) != 1:
+            raise ValueError("One of 'binary_fn', 'binary_attr' must be specified")
+        self.binary_fn = binary_fn
+        self.binary_attr = binary_attr
+        self.reduce_dims = reduce_dims
 
-  def __call__(self, *args: TensorExpression) -> "TensorReduceFn":
-    return TensorReduceFn(self, args)
+    def __call__(self, *args: TensorExpression) -> "TensorReduceFn":
+        return TensorReduceFn(self, args)
 
-  def __repr__(self):
-    fn = self.binary_fn if self.binary_fn else self.binary_attr
-    return (
-        f"reduce_{repr(fn)}({', '.join(repr(d) for d in self.reduce_dims)})")
+    def __repr__(self):
+        fn = self.binary_fn if self.binary_fn else self.binary_attr
+        return f"reduce_{repr(fn)}({', '.join(repr(d) for d in self.reduce_dims)})"
 
 
 class ReduceFnType:
-  """Reduction function.
+    """Reduction function.
 
-  A binary function that reduces its RHS into its LHS.
-  """
+    A binary function that reduces its RHS into its LHS.
+    """
 
-  def __init__(self, binary_fn: BinaryFnType):
-    if not isinstance(binary_fn, BinaryFnType):
-      raise ValueError(f"Reduce expected a BinaryFnType but got {binary_fn}")
-    self.binary_fn = binary_fn
+    def __init__(self, binary_fn: BinaryFnType):
+        if not isinstance(binary_fn, BinaryFnType):
+            raise ValueError(f"Reduce expected a BinaryFnType but got {binary_fn}")
+        self.binary_fn = binary_fn
 
-  def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
-    return ReduceFnUse(self.binary_fn, None, *reduce_dims)
+    def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
+        return ReduceFnUse(self.binary_fn, None, *reduce_dims)
 
-  def __repr__(self):
-    return f"reduce_{repr(self.binary_fn)}"
+    def __repr__(self):
+        return f"reduce_{repr(self.binary_fn)}"
 
 
 class ReduceFn:
-  add = ReduceFnType(BinaryFn.add)
-  mul = ReduceFnType(BinaryFn.mul)
-  max_signed = ReduceFnType(BinaryFn.max_signed)
-  min_signed = ReduceFnType(BinaryFn.min_signed)
-  max_unsigned = ReduceFnType(BinaryFn.max_unsigned)
-  min_unsigned = ReduceFnType(BinaryFn.min_unsigned)
+    add = ReduceFnType(BinaryFn.add)
+    mul = ReduceFnType(BinaryFn.mul)
+    max_signed = ReduceFnType(BinaryFn.max_signed)
+    min_signed = ReduceFnType(BinaryFn.min_signed)
+    max_unsigned = ReduceFnType(BinaryFn.max_unsigned)
+    min_unsigned = ReduceFnType(BinaryFn.min_unsigned)
 
 
 ###############################################################################
@@ -405,237 +419,265 @@ class ReduceFn:
 
 
 class OperandKind(Enum):
-  INPUT_TENSOR = 0
-  SCALAR = 1
-  OUTPUT_TENSOR = 2
-  INDEX_ATTR = 3
-  UNARY_FN_ATTR = 4
-  BINARY_FN_ATTR = 5
-  TYPE_FN_ATTR = 6
+    INPUT_TENSOR = 0
+    SCALAR = 1
+    OUTPUT_TENSOR = 2
+    INDEX_ATTR = 3
+    UNARY_FN_ATTR = 4
+    BINARY_FN_ATTR = 5
+    TYPE_FN_ATTR = 6
 
 
 class OperandDef:
-  """Definition of an operand passed to an operation.
-
-  Keep the meta information of Tensor, Scalar, and Attribute operands and
-  provide the shared registration functionality.
-  """
-
-  def __init__(self,
-               kind: OperandKind,
-               type_var: Optional[TypeVar] = None,
-               size_exprs: Optional[Sequence[AffineExprDef]] = None,
-               index_dims: Optional[Sequence[DimDef]] = None,
-               default_indices: Optional[Sequence[int]] = None,
-               default_fn: Optional[str] = None):
-    if type_var and not isinstance(type_var, TypeVar):
-      raise ValueError(
-          f"OperandDef requires a TypeVar but got {repr(type_var)}")
-    self.owner = None  # type: Optional["LinalgOpDef"]
-    self.type_var = type_var
-    self.size_exprs = size_exprs
-    self.index_dims = index_dims
-    self.default_indices = default_indices
-    self.default_fn = default_fn
-    self.kind = kind
-    self.name = None  # type: Optional[str]
-    self.registered_index = -1  # type: int
-
-  def attach(self, index: int, name: str, owner: "LinalgOpDef"):
-    if self.owner:
-      raise ValueError(f"OperandDef already registered with an op: {self}")
-    self.registered_index = index
-    self.name = name
-    self.owner = owner
-
-  def is_input(self) -> bool:
-    return (self.kind == OperandKind.SCALAR or
-            self.kind == OperandKind.INPUT_TENSOR)
-
-  def is_tensor(self) -> bool:
-    return (self.kind == OperandKind.INPUT_TENSOR or
-            self.kind == OperandKind.OUTPUT_TENSOR)
-
-  def is_attribute(self) -> bool:
-    return (self.kind == OperandKind.INDEX_ATTR or
-            self.kind == OperandKind.UNARY_FN_ATTR or
-            self.kind == OperandKind.BINARY_FN_ATTR or
-            self.kind == OperandKind.TYPE_FN_ATTR)
-
-  def __hash__(self):
-    return hash(id(self))
-
-  def __repr__(self):
-    return (f"{self.name}:OperandDef(kind={self.kind.name}, "
+    """Definition of an operand passed to an operation.
+
+    Keep the meta information of Tensor, Scalar, and Attribute operands and
+    provide the shared registration functionality.
+    """
+
+    def __init__(
+        self,
+        kind: OperandKind,
+        type_var: Optional[TypeVar] = None,
+        size_exprs: Optional[Sequence[AffineExprDef]] = None,
+        index_dims: Optional[Sequence[DimDef]] = None,
+        default_indices: Optional[Sequence[int]] = None,
+        default_fn: Optional[str] = None,
+    ):
+        if type_var and not isinstance(type_var, TypeVar):
+            raise ValueError(f"OperandDef requires a TypeVar but got {repr(type_var)}")
+        self.owner = None  # type: Optional["LinalgOpDef"]
+        self.type_var = type_var
+        self.size_exprs = size_exprs
+        self.index_dims = index_dims
+        self.default_indices = default_indices
+        self.default_fn = default_fn
+        self.kind = kind
+        self.name = None  # type: Optional[str]
+        self.registered_index = -1  # type: int
+
+    def attach(self, index: int, name: str, owner: "LinalgOpDef"):
+        if self.owner:
+            raise ValueError(f"OperandDef already registered with an op: {self}")
+        self.registered_index = index
+        self.name = name
+        self.owner = owner
+
+    def is_input(self) -> bool:
+        return self.kind == OperandKind.SCALAR or self.kind == OperandKind.INPUT_TENSOR
+
+    def is_tensor(self) -> bool:
+        return (
+            self.kind == OperandKind.INPUT_TENSOR
+            or self.kind == OperandKind.OUTPUT_TENSOR
+        )
+
+    def is_attribute(self) -> bool:
+        return (
+            self.kind == OperandKind.INDEX_ATTR
+            or self.kind == OperandKind.UNARY_FN_ATTR
+            or self.kind == OperandKind.BINARY_FN_ATTR
+            or self.kind == OperandKind.TYPE_FN_ATTR
+        )
+
+    def __hash__(self):
+        return hash(id(self))
+
+    def __repr__(self):
+        return (
+            f"{self.name}:OperandDef(kind={self.kind.name}, "
             f"type={repr(self.type_var)}, size_exprs={self.size_exprs}, "
             f"index_dims={self.index_dims}, "
             f"default_indices={self.default_indices}, "
-            f"default_fn={self.default_fn})")
+            f"default_fn={self.default_fn})"
+        )
 
 
 class TensorDef:
-  """Tensor operand definition.
-
-  Tensor operands are indexed using the associated indexing_map when forwarded
-  to the body of the structured op. A unique name identifies the tensor operands
-  and an index determines their position in the operation's parameter list. A
-  tensor definition takes type, a shape, and an optional flag to mark output
-  tensors. Additionally, a tuple of index dimensions may be used to map the
-  tensor to the loop dimensions of the operation. This mapping is needed to
-  compute the indexing map of shape-only tensors that have no uses.
-  """
-
-  def __init__(self,
-               type_var: TypeVar,
-               *shape: AffineExprDef,
-               index_dims: Optional[Sequence[DimDef]] = None,
-               output: bool = False):
-    if index_dims and len(shape) != len(index_dims):
-      raise ValueError(f"Expected the shape rank {len(shape)} to match the "
-                       f"number of index_dims {len(index_dims)}")
-    if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims):
-      raise ValueError(f"TensorDef requires index dims of type DimDef but "
-                       f"got {index_dims}")
-    kind = OperandKind.OUTPUT_TENSOR if output else OperandKind.INPUT_TENSOR
-    self.operand_def = OperandDef(
-        kind, type_var=type_var, size_exprs=shape, index_dims=index_dims)
-
-  def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse:
-    assert self.operand_def.owner, "TensorDef is not registered with an op"
-    state = AffineBuildState(
-        global_state=self.operand_def.owner._affine_state,
-        allow_new_symbols=False)
-    if not isinstance(dims, tuple):
-      dims = (dims,)  # Handle single subscript case.
-    # Special case: (None) is a 0d-scalar use.
-    if dims == (None,):
-      dims = ()
-
-    exprs = []
-    for expr_def in dims:
-      if not isinstance(expr_def, AffineExprDef):
-        raise KeyError(
-            "A TensorDef can only be subscripted by a tuple of affine dims")
-      exprs.append(expr_def)
-    return TensorUse(self.operand_def, exprs)
-
-  def __setitem__(self, dims: Sequence[AffineExprDef], value: TensorExpression):
-    """Creates a new 1:1 comprehension by binding this tensor to an expression.
-
-    Note that due to the way assignment works in Python, we have to capture
-    direct assignment as a setitem on the TensorDef.
+    """Tensor operand definition.
+
+    Tensor operands are indexed using the associated indexing_map when forwarded
+    to the body of the structured op. A unique name identifies the tensor operands
+    and an index determines their position in the operation's parameter list. A
+    tensor definition takes a type, a shape, and an optional flag to mark output
+    tensors. Additionally, a tuple of index dimensions may be used to map the
+    tensor to the loop dimensions of the operation. This mapping is needed to
+    compute the indexing map of shape-only tensors that have no uses.
     """
-    if not isinstance(value, TensorExpression):
-      raise ValueError(f"Only TensorExpressions can be assigned to TensorDefs. "
-                       f"Got: {repr(value)}")
-    use = self[dims]
-    comp = Comprehension((use, value))
-    self.operand_def.owner.comprehensions.append(comp)
+
+    def __init__(
+        self,
+        type_var: TypeVar,
+        *shape: AffineExprDef,
+        index_dims: Optional[Sequence[DimDef]] = None,
+        output: bool = False,
+    ):
+        if index_dims and len(shape) != len(index_dims):
+            raise ValueError(
+                f"Expected the shape rank {len(shape)} to match the "
+                f"number of index_dims {len(index_dims)}"
+            )
+        if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims):
+            raise ValueError(
+                f"TensorDef requires index dims of type DimDef but " f"got {index_dims}"
+            )
+        kind = OperandKind.OUTPUT_TENSOR if output else OperandKind.INPUT_TENSOR
+        self.operand_def = OperandDef(
+            kind, type_var=type_var, size_exprs=shape, index_dims=index_dims
+        )
+
+    def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse:
+        assert self.operand_def.owner, "TensorDef is not registered with an op"
+        state = AffineBuildState(
+            global_state=self.operand_def.owner._affine_state, allow_new_symbols=False
+        )
+        if not isinstance(dims, tuple):
+            dims = (dims,)  # Handle single subscript case.
+        # Special case: (None) is a 0d-scalar use.
+        if dims == (None,):
+            dims = ()
+
+        exprs = []
+        for expr_def in dims:
+            if not isinstance(expr_def, AffineExprDef):
+                raise KeyError(
+                    "A TensorDef can only be subscripted by a tuple of affine dims"
+                )
+            exprs.append(expr_def)
+        return TensorUse(self.operand_def, exprs)
+
+    def __setitem__(self, dims: Sequence[AffineExprDef], value: TensorExpression):
+        """Creates a new 1:1 comprehension by binding this tensor to an expression.
+
+        Note that due to the way assignment works in Python, we have to capture
+        direct assignment as a setitem on the TensorDef.
+        """
+        if not isinstance(value, TensorExpression):
+            raise ValueError(
+                f"Only TensorExpressions can be assigned to TensorDefs. "
+                f"Got: {repr(value)}"
+            )
+        use = self[dims]
+        comp = Comprehension((use, value))
+        self.operand_def.owner.comprehensions.append(comp)
 
 
 class ScalarDef(TensorExpression):
-  """Scalar operand definition.
+    """Scalar operand definition.
 
-  Scalar operands are forwarded to the body of the structured op as they are.
-  A unique name identifies the scalars and an index determines their position in
-  the operation's parameter list.
-  """
+    Scalar operands are forwarded to the body of the structured op as they are.
+    A unique name identifies the scalars and an index determines their position in
+    the operation's parameter list.
+    """
 
-  def __init__(self, type_var: TypeVar):
-    self.operand_def = OperandDef(OperandKind.SCALAR, type_var=type_var)
+    def __init__(self, type_var: TypeVar):
+        self.operand_def = OperandDef(OperandKind.SCALAR, type_var=type_var)
 
-  @property
-  def scalar_name(self) -> str:
-    name = self.operand_def.name
-    assert name is not None, "ScalarDef not registered with an op"
-    return name
+    @property
+    def scalar_name(self) -> str:
+        name = self.operand_def.name
+        assert name is not None, "ScalarDef not registered with an op"
+        return name
 
-  def to_scalar_expression(self) -> ScalarExpression:
-    return ScalarArg(self.scalar_name).expr()
+    def to_scalar_expression(self) -> ScalarExpression:
+        return ScalarArg(self.scalar_name).expr()
 
 
 class IndexAttrDef:
-  """Index attribute definition.
-
-  Index attributes provide a way to define and set symbols that can be used in
-  indexing expressions. Every attribute specifies a tuple of symbols that at
-  compile-time are replaced by integer values as well as their default values.
-  """
-
-  def __init__(self, *sizes: SymbolDef, default: Sequence[int]):
-    if any(not isinstance(size, SymbolDef) for size in sizes):
-      raise ValueError(f"IndexAttrDef requires sizes of type SymbolDef "
-                       f"but got {sizes}")
-    if any(not isinstance(default_val, int) for default_val in default):
-      raise ValueError(f"IndexAttrDef requires default values of type int "
-                       f"but got {default}")
-    if len(sizes) != len(default):
-      raise ValueError(f"IndexAttrDef expects {len(sizes)} default values "
-                       f"but got {len(default)}")
-    self.operand_def = OperandDef(
-        OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default)
+    """Index attribute definition.
+
+    Index attributes provide a way to define and set symbols that can be used in
+    indexing expressions. Every attribute specifies a tuple of symbols that at
+    compile-time are replaced by integer values as well as their default values.
+    """
+
+    def __init__(self, *sizes: SymbolDef, default: Sequence[int]):
+        if any(not isinstance(size, SymbolDef) for size in sizes):
+            raise ValueError(
+                f"IndexAttrDef requires sizes of type SymbolDef " f"but got {sizes}"
+            )
+        if any(not isinstance(default_val, int) for default_val in default):
+            raise ValueError(
+                f"IndexAttrDef requires default values of type int "
+                f"but got {default}"
+            )
+        if len(sizes) != len(default):
+            raise ValueError(
+                f"IndexAttrDef expects {len(sizes)} default values "
+                f"but got {len(default)}"
+            )
+        self.operand_def = OperandDef(
+            OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default
+        )
 
 
 class UnaryFnAttrDef:
-  """Unary function attribute definition.
+    """Unary function attribute definition.
 
-  Unary function attributes provide a way to make the arithmetic computation
-  parametrizable. Every attribute specifies a default unary function
-  that may be overwritten at operation instantiation time.
-  """
+    Unary function attributes provide a way to make the arithmetic computation
+    parametrizable. Every attribute specifies a default unary function
+    that may be overwritten at operation instantiation time.
+    """
 
-  def __init__(self, default: "UnaryFnType"):
-    if not isinstance(default, UnaryFnType):
-      raise ValueError(f"UnaryFnAttrDef requires default of type UnaryFnType "
-                       f"but got {default}")
-    self.operand_def = OperandDef(
-        OperandKind.UNARY_FN_ATTR, default_fn=default.fn_name)
+    def __init__(self, default: "UnaryFnType"):
+        if not isinstance(default, UnaryFnType):
+            raise ValueError(
+                f"UnaryFnAttrDef requires default of type UnaryFnType "
+                f"but got {default}"
+            )
+        self.operand_def = OperandDef(
+            OperandKind.UNARY_FN_ATTR, default_fn=default.fn_name
+        )
 
-  def __call__(self, arg: TensorExpression) -> TensorFn:
-    return TensorFn(FunctionKind.UNARY, None, self.operand_def, None, [arg])
+    def __call__(self, arg: TensorExpression) -> TensorFn:
+        return TensorFn(FunctionKind.UNARY, None, self.operand_def, None, [arg])
 
 
 class BinaryFnAttrDef:
-  """Binary function attribute definition.
+    """Binary function attribute definition.
 
-  Binary function attributes provide a way to make the arithmetic computation
-  parametrizable. Every attribute specifies a default binary function
-  that may be overwritten at operation instantiation time.
-  """
+    Binary function attributes provide a way to make the arithmetic computation
+    parametrizable. Every attribute specifies a default binary function
+    that may be overwritten at operation instantiation time.
+    """
 
-  def __init__(self, default: "BinaryFnType"):
-    if not isinstance(default, BinaryFnType):
-      raise ValueError(f"BinaryFnAttrDef requires default of type BinaryFnType "
-                       f"but got {default}")
-    self.operand_def = OperandDef(
-        OperandKind.BINARY_FN_ATTR, default_fn=default.fn_name)
+    def __init__(self, default: "BinaryFnType"):
+        if not isinstance(default, BinaryFnType):
+            raise ValueError(
+                f"BinaryFnAttrDef requires default of type BinaryFnType "
+                f"but got {default}"
+            )
+        self.operand_def = OperandDef(
+            OperandKind.BINARY_FN_ATTR, default_fn=default.fn_name
+        )
 
-  def __call__(self, arg0: TensorExpression,
-               arg1: TensorExpression) -> TensorFn:
-    return TensorFn(FunctionKind.BINARY, None, self.operand_def, None,
-                    [arg0, arg1])
+    def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> TensorFn:
+        return TensorFn(FunctionKind.BINARY, None, self.operand_def, None, [arg0, arg1])
 
-  def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
-    return ReduceFnUse(None, self, *reduce_dims)
+    def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
+        return ReduceFnUse(None, self, *reduce_dims)
 
 
 class TypeFnAttrDef:
-  """Type conversion function attribute definition.
+    """Type conversion function attribute definition.
 
-  Type conversion function attributes provide a way to make type conversions
-  parameterizable. Every attribute specifies a default type conversion function
-  that may be overwritten at operation instantiation time.
-  """
+    Type conversion function attributes provide a way to make type conversions
+    parameterizable. Every attribute specifies a default type conversion function
+    that may be overwritten at operation instantiation time.
+    """
 
-  def __init__(self, default: "TypeFnType"):
-    if not isinstance(default, TypeFnType):
-      raise ValueError(f"TypeFnAttrDef requires default of type TypeFnType "
-                       f"but got {default}")
-    self.operand_def = OperandDef(
-        OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name)
+    def __init__(self, default: "TypeFnType"):
+        if not isinstance(default, TypeFnType):
+            raise ValueError(
+                f"TypeFnAttrDef requires default of type TypeFnType "
+                f"but got {default}"
+            )
+        self.operand_def = OperandDef(
+            OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name
+        )
 
-  def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn:
-    return TensorFn(FunctionKind.TYPE, None, self.operand_def, type_var, [arg])
+    def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn:
+        return TensorFn(FunctionKind.TYPE, None, self.operand_def, type_var, [arg])
 
 
 ###############################################################################
@@ -644,48 +686,48 @@ class TypeFnAttrDef:
 
 
 class Comprehension:
-  """Represents a single comprehension."""
-
-  def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]):
-    self.definitions = list()  # List[TensorUse]
-    self.values = list()  # List[TensorExpression]
-
-    # Find the lhs to reduction rhs.
-    for assign, value in bindings:
-      if isinstance(value, TensorReduceFn):
-        if value.lhs:
-          raise ValueError(f"Reduction expression already assigns: {value}")
-        value.lhs = assign
-      self.definitions.append(assign)
-      self.values.append(value)
-
-  @property
-  def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]:
-    """Gets the reduction dims for the comprehension or None."""
-    result = set()
-    for use in self.values:
-      if isinstance(use, TensorReduceFn):
-        result.add(use.reduce_use.reduce_dims)
-      else:
-        result.add(tuple())
-    return result
-
-  def __repr__(self):
-    if len(self.definitions) > 1:
-      defs_repr = f"({', '.join(repr(d) for d in self.definitions)})"
-      values_repr = f"({', '.join(repr(v) for v in self.values)})"
-    else:
-      defs_repr = f"{repr(self.definitions[0])}"
-      values_repr = f"{repr(self.values[0])}"
-
-    return f"{defs_repr} = {values_repr}"
+    """Represents a single comprehension."""
+
+    def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]):
+        self.definitions = list()  # List[TensorUse]
+        self.values = list()  # List[TensorExpression]
+
+        # Find the lhs to reduction rhs.
+        for assign, value in bindings:
+            if isinstance(value, TensorReduceFn):
+                if value.lhs:
+                    raise ValueError(f"Reduction expression already assigns: {value}")
+                value.lhs = assign
+            self.definitions.append(assign)
+            self.values.append(value)
+
+    @property
+    def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]:
+        """Gets the reduction dims for the comprehension or None."""
+        result = set()
+        for use in self.values:
+            if isinstance(use, TensorReduceFn):
+                result.add(use.reduce_use.reduce_dims)
+            else:
+                result.add(tuple())
+        return result
+
+    def __repr__(self):
+        if len(self.definitions) > 1:
+            defs_repr = f"({', '.join(repr(d) for d in self.definitions)})"
+            values_repr = f"({', '.join(repr(v) for v in self.values)})"
+        else:
+            defs_repr = f"{repr(self.definitions[0])}"
+            values_repr = f"{repr(self.values[0])}"
+
+        return f"{defs_repr} = {values_repr}"
 
 
 class OpInterfaceDef:
-  """An interface that an op implements."""
+    """An interface that an op implements."""
 
-  def __init__(self, cpp_name: str):
-    self.cpp_name = cpp_name
+    def __init__(self, cpp_name: str):
+        self.cpp_name = cpp_name
 
 
 ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface")
@@ -694,86 +736,94 @@ FillOpInterface = OpInterfaceDef("LinalgFillOpInterface")
 
 
 class OpDefinitionDef:
-  """A method that an op implements."""
+    """A method that an op implements."""
 
-  def __init__(self, def_name: str):
-    self.def_name = def_name
+    def __init__(self, def_name: str):
+        self.def_name = def_name
 
 
 Canonicalizer = OpDefinitionDef("hasCanonicalizer")
 
 
 class OpMetadataDef(YAMLObject):
-  """Metadata about the op (generally not behavior impacting)."""
-  yaml_tag = "!LinalgOpMetadata"
-
-  def __init__(self, name: str, cpp_class_name: Optional[str],
-               doc: Optional[str]):
-    self.name = name
-    self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name
-    self.doc = doc
-    self.implements = []  # type: List[OpInterfaceDef]
-    self.defines = []  # type: List[OpDefinitionsDef]
-
-  def to_yaml_custom_dict(self):
-    d = dict(
-        name=self.name,
-        cpp_class_name=self.cpp_class_name,
-        doc=self.doc,
-    )
-    if self.implements:
-      d["implements"] = [intr.cpp_name for intr in self.implements]
-    if self.defines:
-      d["defines"] = [defi.def_name for defi in self.defines]
-    return d
+    """Metadata about the op (generally not behavior impacting)."""
+
+    yaml_tag = "!LinalgOpMetadata"
+
+    def __init__(self, name: str, cpp_class_name: Optional[str], doc: Optional[str]):
+        self.name = name
+        self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name
+        self.doc = doc
+        self.implements = []  # type: List[OpInterfaceDef]
+        self.defines = []  # type: List[OpDefinitionDef]
+
+    def to_yaml_custom_dict(self):
+        d = dict(
+            name=self.name,
+            cpp_class_name=self.cpp_class_name,
+            doc=self.doc,
+        )
+        if self.implements:
+            d["implements"] = [intr.cpp_name for intr in self.implements]
+        if self.defines:
+            d["defines"] = [defi.def_name for defi in self.defines]
+        return d
 
 
 class LinalgOpDef:
-  """Definition of a linalg op."""
-
-  def __init__(self,
-               name: str,
-               cpp_class_name: Optional[str] = None,
-               doc: Optional[str] = None):
-    self.metadata = OpMetadataDef(
-        name=name, cpp_class_name=cpp_class_name, doc=doc)
-    self.registered_operands = dict()  # type: Dict[str, OperandDef]
-    self.domain = list()  # type: List[DimDef]
-    self.comprehensions = list()  # type: List[Comprehension]
-    self._affine_state = AffineBuildState()
-
-  def add_operand(self, name: str, operand: OperandDef):
-    """Registers an operand."""
-    if name in self.registered_operands:
-      raise ValueError(f"The operand {name} is already registered "
-                       f"to {self.registered_operands['name']}")
-    structured_op_methods = [
-        "inputs", "outputs", "result_tensors", "region", "iterator_types",
-        "indexing_maps", "getRegionBuilder", "getLibraryCallName"
-    ]
-    if operand.is_attribute() and name in structured_op_methods:
-      raise ValueError(f"The attribute name {name} conflicts with a structured "
-                       f"op method name")
-    # Ensure output tensors are registered after input tensors and scalars and
-    # attributes are registered after all other operand types.
-    if operand.is_input() and any(
-        not op_def.is_input() for op_def in self.registered_operands.values()):
-      raise ValueError(f"Input {name} registered after an output or attribute")
-    if operand.kind == OperandKind.OUTPUT_TENSOR and any(
-        op_def.is_attribute() for op_def in self.registered_operands.values()):
-      raise ValueError(f"Output {name} registered after an attribute")
-    operand.attach(len(self.registered_operands), name, self)
-    self.registered_operands[name] = operand
-
-  def __repr__(self):
-    lines = [
-        f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name},"
-    ]
-    for name, operand in self.registered_operands.items():
-      lines.append(f"  {operand}")
-    if self.comprehensions:
-      lines[-1] += " {"
-      for comprehension in self.comprehensions:
-        lines.append(f"    {comprehension}")
-      lines.append("}")
-    return "\n".join(lines)
+    """Definition of a linalg op."""
+
+    def __init__(
+        self, name: str, cpp_class_name: Optional[str] = None, doc: Optional[str] = None
+    ):
+        self.metadata = OpMetadataDef(name=name, cpp_class_name=cpp_class_name, doc=doc)
+        self.registered_operands = dict()  # type: Dict[str, OperandDef]
+        self.domain = list()  # type: List[DimDef]
+        self.comprehensions = list()  # type: List[Comprehension]
+        self._affine_state = AffineBuildState()
+
+    def add_operand(self, name: str, operand: OperandDef):
+        """Registers an operand."""
+        if name in self.registered_operands:
+            raise ValueError(
+                f"The operand {name} is already registered "
+                f"to {self.registered_operands['name']}"
+            )
+        structured_op_methods = [
+            "inputs",
+            "outputs",
+            "result_tensors",
+            "region",
+            "iterator_types",
+            "indexing_maps",
+            "getRegionBuilder",
+            "getLibraryCallName",
+        ]
+        if operand.is_attribute() and name in structured_op_methods:
+            raise ValueError(
+                f"The attribute name {name} conflicts with a structured "
+                f"op method name"
+            )
+        # Ensure output tensors are registered after input tensors and scalars and
+        # attributes are registered after all other operand types.
+        if operand.is_input() and any(
+            not op_def.is_input() for op_def in self.registered_operands.values()
+        ):
+            raise ValueError(f"Input {name} registered after an output or attribute")
+        if operand.kind == OperandKind.OUTPUT_TENSOR and any(
+            op_def.is_attribute() for op_def in self.registered_operands.values()
+        ):
+            raise ValueError(f"Output {name} registered after an attribute")
+        operand.attach(len(self.registered_operands), name, self)
+        self.registered_operands[name] = operand
+
+    def __repr__(self):
+        lines = [f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name},"]
+        for name, operand in self.registered_operands.items():
+            lines.append(f"  {operand}")
+        if self.comprehensions:
+            lines[-1] += " {"
+            for comprehension in self.comprehensions:
+                lines.append(f"    {comprehension}")
+            lines.append("}")
+        return "\n".join(lines)
index 2a0da68..d522d57 100644
@@ -21,422 +21,468 @@ __all__ = ["LinalgStructuredOpConfig", "LinalgOpConfig", "OperandDefConfig"]
 
 
 def _serialize_affine_map(affine_map: _ir.AffineMap) -> str:
-  with affine_map.context:
-    # Affine map printing/parsing is via an AffineMap attr.
-    attr = _ir.AffineMapAttr.get(affine_map)
-    return str(attr)
+    with affine_map.context:
+        # Affine map printing/parsing is via an AffineMap attr.
+        attr = _ir.AffineMapAttr.get(affine_map)
+        return str(attr)
 
 
 class TensorUseConfig:
-  """Wrapper around a TensorUse with additional context-bound state."""
+    """Wrapper around a TensorUse with additional context-bound state."""
 
-  def __init__(self, tensor_use: TensorUse, indexing_map: _ir.AffineMap):
-    self.tensor_use = tensor_use
-    self.indexing_map = indexing_map
+    def __init__(self, tensor_use: TensorUse, indexing_map: _ir.AffineMap):
+        self.tensor_use = tensor_use
+        self.indexing_map = indexing_map
 
-  def __repr__(self):
-    return f"Use({self.tensor_use}, indexing_map={self.indexing_map})"
+    def __repr__(self):
+        return f"Use({self.tensor_use}, indexing_map={self.indexing_map})"
 
 
 class OperandDefConfig(YAMLObject):
-  """Wrapper containing an operand definition with additional state."""
-  yaml_tag = "!LinalgOperandDefConfig"
-
-  def __init__(self,
-               operand_def: OperandDef,
-               shape_map: Optional[_ir.AffineMap] = None,
-               index_attr_map: Optional[_ir.AffineMap] = None):
-    self.operand_def = operand_def
-    self.shape_map = shape_map  # type: Optional[_ir.AffineMap]
-    self.index_attr_map = index_attr_map  # type: Optional[_ir.AffineMap]
-    self.indexing_map = None  # type: Optional[_ir.AffineMap]
-
-  @property
-  def name(self) -> str:
-    return self.operand_def.name
-
-  @property
-  def kind(self) -> OperandKind:
-    return self.operand_def.kind
-
-  @property
-  def type_var(self) -> TypeVar:
-    return self.operand_def.type_var
-
-  def to_yaml_custom_dict(self):
-    self_dict = dict(name=self.name, kind=self.operand_def.kind.name.lower())
-    if self.type_var:
-      self_dict["type_var"] = self.type_var.name
-    if self.shape_map:
-      self_dict["shape_map"] = _serialize_affine_map(self.shape_map)
-    if self.index_attr_map:
-      self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map)
-    if self.operand_def.default_indices:
-      self_dict["default_indices"] = self.operand_def.default_indices
-    if self.operand_def.default_fn:
-      self_dict["default_fn"] = self.operand_def.default_fn
-    return self_dict
-
-  def __repr__(self):
-    return (f"OperandDefConfig({self.operand_def}, "
+    """Wrapper containing an operand definition with additional state."""
+
+    yaml_tag = "!LinalgOperandDefConfig"
+
+    def __init__(
+        self,
+        operand_def: OperandDef,
+        shape_map: Optional[_ir.AffineMap] = None,
+        index_attr_map: Optional[_ir.AffineMap] = None,
+    ):
+        self.operand_def = operand_def
+        self.shape_map = shape_map  # type: Optional[_ir.AffineMap]
+        self.index_attr_map = index_attr_map  # type: Optional[_ir.AffineMap]
+        self.indexing_map = None  # type: Optional[_ir.AffineMap]
+
+    @property
+    def name(self) -> str:
+        return self.operand_def.name
+
+    @property
+    def kind(self) -> OperandKind:
+        return self.operand_def.kind
+
+    @property
+    def type_var(self) -> TypeVar:
+        return self.operand_def.type_var
+
+    def to_yaml_custom_dict(self):
+        self_dict = dict(name=self.name, kind=self.operand_def.kind.name.lower())
+        if self.type_var:
+            self_dict["type_var"] = self.type_var.name
+        if self.shape_map:
+            self_dict["shape_map"] = _serialize_affine_map(self.shape_map)
+        if self.index_attr_map:
+            self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map)
+        if self.operand_def.default_indices:
+            self_dict["default_indices"] = self.operand_def.default_indices
+        if self.operand_def.default_fn:
+            self_dict["default_fn"] = self.operand_def.default_fn
+        return self_dict
+
+    def __repr__(self):
+        return (
+            f"OperandDefConfig({self.operand_def}, "
             f"shape_map={self.shape_map}, "
             f"index_attr_map={self.index_attr_map}, "
-            f"indexing_map={self.indexing_map})")
+            f"indexing_map={self.indexing_map})"
+        )
 
 
 class LinalgIndexingMapsConfig(YAMLObject):
-  """Abstracts the style of indexing maps that the op exports.
-
-  Presently only static (tied to the op name) indexing maps are supported. In
-  the future, it is expected that we will have additional variants:
-    - Dynamic based on attributes
-    - Dynamic based on operands
-  Each is expected to require a different variant of specification.
-  """
-  yaml_tag = "!LinalgIndexingMapsConfig"
-
-  def __init__(self,
-               static_indexing_maps: Optional[Sequence[_ir.AffineMap]] = None):
-    self.static_indexing_maps = static_indexing_maps
-
-  def to_yaml_custom_dict(self):
-    if self.static_indexing_maps is not None:
-      return dict(static_indexing_maps=[
-          _serialize_affine_map(m) for m in self.static_indexing_maps
-      ])
-    raise ValueError(
-        f"LinalgIndexingMapsConfig must have one type of indexing map"
-        f"(got none)")
+    """Abstracts the style of indexing maps that the op exports.
 
+    Presently only static (tied to the op name) indexing maps are supported. In
+    the future, it is expected that we will have additional variants:
+      - Dynamic based on attributes
+      - Dynamic based on operands
+    Each is expected to require a different variant of specification.
+    """
 
-class LinalgStructuredOpConfig(YAMLObject):
-  """Configuration for metadata sufficient to construct a linalg named op."""
-
-  yaml_tag = "!LinalgStructuredOpConfig"
-
-  def __init__(self,
-               comprehension: Comprehension,
-               domain: Sequence[DimDef],
-               registered_operands: Sequence[OperandDef],
-               context: Optional[_ir.Context] = None):
-    self.context = context if context is not None else _ir.Context()
-    self.affine_state = AffineBuildState()
-    self.writes = list()  # type: List[Tuple[TensorUse, TensorExpression]]
-    self.operands = dict()  # type: Dict[OperandDef, OperandDefConfig]
-    self.uses = dict()  # type: Dict[TensorUse, TensorUseConfig]
-
-    # Compute the ordered set of writes and collect the tensor, capture, dims,
-    # and index uses.
-    collected_tensor_uses = set()
-    collected_scalar_uses = set()
-    collected_dim_uses = set()
-    collected_indices = set()
-    for write_use, read_use in zip(comprehension.definitions,
-                                   comprehension.values):
-      self.writes.append((write_use, read_use))
-
-    for write_use, read_use in self.writes:
-      collected_tensor_uses.add(write_use)
-      read_use.collect_tensor_uses(collected_tensor_uses)
-      read_use.collect_scalar_uses(collected_scalar_uses)
-      read_use.collect_dim_uses(collected_dim_uses)
-      write_use.collect_dim_uses(collected_dim_uses)
-      read_use.collect_indices(collected_indices)
-
-    # Set domain to the sorted list of uses if no domain annotation is given.
-    if not domain:
-      domain = sorted(collected_dim_uses, key=lambda dim: dim.dimname)
-
-    # Verify the domain dimensions match the used dimensions.
-    if (len(domain) != len(collected_dim_uses) or
-        any(dim not in collected_dim_uses for dim in domain)):
-      raise ValueError(f"Expected the annotated domain dimensions {domain} to "
-                       f"match the set of dimension used by the tensor "
-                       f"comprehension {collected_dim_uses}")
-
-    # Instantiate the dimensions in the given order.
-    with self.context:
-      local_state = AffineBuildState(
-          global_state=self.affine_state, allow_new_symbols=False)
-      for dim in domain:
-        dim.build(state=local_state)
-
-    # Collect all attribute definitions.
-    collected_attr_defs = list()
-    for operand in registered_operands:
-      if operand.is_attribute():
-        collected_attr_defs.append(operand)
-
-    # Collect all tensors with manual indexing annotation.
-    collected_index_defs = list()
-    for operand in registered_operands:
-      if operand.index_dims:
-        if any(dim not in collected_dim_uses for dim in operand.index_dims):
-          raise ValueError(f"Expected all index dims {operand.index_dims} of "
-                           f"operand {operand.name} to have uses.")
-        collected_index_defs.append(operand)
-
-    # Collect the operand definitions of all tensor/scalar uses, attributes, and
-    # shape-only tensors.
-    all_operand_defs = list()
-    for use in collected_tensor_uses:
-      all_operand_defs.append(use.operand_def)
-    for use in collected_scalar_uses:
-      all_operand_defs.append(use.operand_def)
-    for definition in collected_attr_defs:
-      all_operand_defs.append(definition)
-    for definition in collected_index_defs:
-      all_operand_defs.append(definition)
-
-    # Add all operands in registration order to ensure the symbols are
-    # registered in the order they appear.
-    all_operand_defs = sorted(
-        all_operand_defs, key=lambda operand_def: operand_def.registered_index)
-    for operand_def in all_operand_defs:
-      self.add_operand(operand_def)
-
-    # Add all shape-only tensor index_dim annotations and all tensor uses.
-    for definition in collected_index_defs:
-      self.add_indexed_operand(definition)
-    for use in collected_tensor_uses:
-      self.add_tensor_use(use)
-
-    # Normalize all shape and indexing maps now that full count of dims and
-    # symbols are known.
-    for cuse in self.uses.values():
-      cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map)
-    for definition in collected_index_defs:
-      self.operands[definition].indexing_map = self._normalize_affine_map(
-          self.operands[definition].indexing_map)
-    for operand_config in self.operands.values():
-      if operand_config.shape_map:
-        operand_config.shape_map = self._normalize_affine_map(
-            operand_config.shape_map, with_dims=False)
-      if operand_config.index_attr_map:
-        operand_config.index_attr_map = self._normalize_affine_map(
-            operand_config.index_attr_map, with_dims=False)
-
-    # Now for each write use, propagate the indexing maps from the use to the
-    # tensor, ensuring that there are not conflicts.
-    for write_use, _ in self.writes:
-      write_tensor_config = self.operands[write_use.operand_def]
-      if write_tensor_config.indexing_map:
-        raise ValueError(
-            f"Unexpected multi-write to a single tensor: {write_tensor_config}")
-      write_tensor_config.indexing_map = self.uses[write_use].indexing_map
-
-    # For each read use, propagate the indexing maps from the use to the
-    # tensor, ensuring that there are not conflicts.
-    for _, read_expr in self.writes:
-      read_uses = set()  # type: Set[TensorUse]
-      read_expr.collect_tensor_uses(read_uses)
-      for read_use in read_uses:
-        read_operand_config = self.operands[read_use.operand_def]
-        if (read_operand_config.indexing_map and
-            read_operand_config.indexing_map !=
-            self.uses[read_use].indexing_map):
-          raise ValueError(
-              f"Unexpected multi-read of a tensor with different accesses:"
-              f"{read_operand_config} vs {read_use}")
-        read_operand_config.indexing_map = self.uses[read_use].indexing_map
-
-    # Set the indexing map of all scalar uses to the empty map.
-    for operand_config in self.operands.values():
-      if operand_config.operand_def.kind == OperandKind.SCALAR:
-        operand_config.indexing_map = self._get_scalar_map()
-
-    # Check all registered tensor and scalar operands have an indexing map.
-    for operand in registered_operands:
-      if operand.is_attribute():
-        continue
-      if not (operand in self.operands and self.operands[operand].indexing_map):
-        raise ValueError(f"Failed to compute an indexing map for operand "
-                         f"{operand.name}")
-
-    # Collect reduction dims and ensure all the same.
-    all_reduction_dims = set(comprehension.all_reduction_dims)
-    if len(all_reduction_dims) != 1:
-      raise ValueError(
-          f"All writes within a generic must have the same reduction "
-          f"dims. Got: {all_reduction_dims}")
-    self.reduction_dims = next(iter(all_reduction_dims))
-
-    # Check the index dimension exists and resolve.
-    for index in collected_indices:
-      if index.dim_def.dimname not in self.affine_state.all_dims:
+    yaml_tag = "!LinalgIndexingMapsConfig"
+
+    def __init__(self, static_indexing_maps: Optional[Sequence[_ir.AffineMap]] = None):
+        self.static_indexing_maps = static_indexing_maps
+
+    def to_yaml_custom_dict(self):
+        if self.static_indexing_maps is not None:
+            return dict(
+                static_indexing_maps=[
+                    _serialize_affine_map(m) for m in self.static_indexing_maps
+                ]
+            )
         raise ValueError(
-            f"The dimension {index.dim_def.dimname} is not part of the "
-            f"iteration domain {self.affine_state.all_dims}")
-      index.resolve_dimension_name(self.affine_state)
-
-    # Generate the scalar assignments (used to build a body).
-    self.assignments = [
-        ScalarAssign(write_use.tensor_name, read_expr.to_scalar_expression())
-        for write_use, read_expr in self.writes
-    ]
-
-  @property
-  def ordered_operands(self) -> Sequence[OperandDefConfig]:
-    return sorted(
-        self.operands.values(),
-        key=lambda operand: operand.operand_def.registered_index)
-
-  @property
-  def ordered_dims(self) -> Sequence[Tuple[str, int]]:
-    """Gets the ordered list of dim bindings (symbolic name, position).
-
-    TODO: The original parser relies on parse ordering to arrive at the
-    iterator types, but that ordering is not defined on the Python side, so
-    this may be ambiguous.
-    """
-    return list(self.affine_state.all_dims.items())
-
-  @property
-  def indexing_maps(self) -> Sequence[_ir.AffineMap]:
-    return [o.indexing_map for o in self.ordered_operands if o.indexing_map]
-
-  @property
-  def iterator_types(self) -> Sequence[str]:
-
-    def get_type(symbolic_name, position):
-      for reduction_dim_expr in self.reduction_dims:
-        if reduction_dim_expr.dimname == symbolic_name:
-          return "reduction"
-      return "parallel"
-
-    return [get_type(*dim) for dim in self.ordered_dims]
-
-  def add_operand(self, operand_def: OperandDef):
-    if operand_def in self.operands:
-      return
-    if not (operand_def.is_tensor() or
-            operand_def.kind == OperandKind.INDEX_ATTR):
-      self.operands[operand_def] = OperandDefConfig(operand_def)
-      return
-    with self.context:
-      local_state = AffineBuildState(
-          global_state=self.affine_state, allow_new_dims=False)
-      exprs = []
-      for expr in operand_def.size_exprs:
-        exprs.append(expr.build(state=local_state))
-      assert local_state.local_dim_count == 0
-      affine_map = _ir.AffineMap.get(
-          dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs)
-      if operand_def.kind == OperandKind.INDEX_ATTR:
-        self.operands[operand_def] = OperandDefConfig(
-            operand_def, index_attr_map=affine_map)
-      else:
-        self.operands[operand_def] = OperandDefConfig(
-            operand_def, shape_map=affine_map)
-
-  def add_indexed_operand(self, operand_def: OperandDef):
-    with self.context:
-      local_state = AffineBuildState(
-          global_state=self.affine_state, allow_new_symbols=False)
-      exprs = []
-      for expr in operand_def.index_dims:
-        exprs.append(expr.build(state=local_state))
-      self.operands[operand_def].indexing_map = _ir.AffineMap.get(
-          dim_count=local_state.dim_count,
-          symbol_count=local_state.symbol_count,
-          exprs=exprs)
-
-  def add_tensor_use(self, tensor_use: TensorUse):
-    if tensor_use in self.uses:
-      return
-    with self.context:
-      local_state = AffineBuildState(
-          global_state=self.affine_state, allow_new_symbols=False)
-      exprs = []
-      for expr in tensor_use.indices:
-        exprs.append(expr.build(state=local_state))
-      indexing_map = _ir.AffineMap.get(
-          dim_count=local_state.dim_count,
-          symbol_count=local_state.symbol_count,
-          exprs=exprs)
-
-      use_config = TensorUseConfig(tensor_use, indexing_map)
-      self.uses[tensor_use] = use_config
-
-  def _get_scalar_map(self) -> _ir.AffineMap:
-    """Create an empty affine map used to index a scalar."""
-    with self.context:
-      return _ir.AffineMap.get(
-          dim_count=self.affine_state.dim_count,
-          symbol_count=self.affine_state.symbol_count,
-          exprs=list())
-
-  def _normalize_affine_map(self,
-                            affine_map: _ir.AffineMap,
-                            with_dims: bool = True) -> _ir.AffineMap:
-    """Normalizes an indexing map to have the max known symbols and dims."""
-    with self.context:
-      return _ir.AffineMap.get(
-          dim_count=self.affine_state.dim_count if with_dims else 0,
-          symbol_count=self.affine_state.symbol_count,
-          exprs=list(affine_map.results))
-
-  def to_yaml_custom_dict(self):
-    self_dict = dict(args=self.ordered_operands)
-    # TODO: Refactor the hierarchy internally when supporting more
-    # than static (preserving this serialized form).
-    self_dict["indexing_maps"] = LinalgIndexingMapsConfig(
-        static_indexing_maps=self.indexing_maps)
-    self_dict["iterator_types"] = self.iterator_types
-    self_dict["assignments"] = self.assignments
-    return self_dict
-
-  def __repr__(self):
-    lines = [f"LinalgGenericOpConfig(reduction_dims={self.reduction_dims},"]
-    lines.append("operands=[")
-    for def_config in self.ordered_operands:
-      lines.append(f"  {repr(def_config)}")
-    lines.append("], indexing_maps=[")
-    for m in self.indexing_maps:
-      lines.append(f"  {repr(m)}")
-    lines.append(f"], iterator_types=[")
-    for t in self.iterator_types:
-      lines.append(f"  {t}")
-    lines.append("])")
-    return "\n".join(lines)
+            f"LinalgIndexingMapsConfig must have one type of indexing map" f"(got none)"
+        )
+
+
+class LinalgStructuredOpConfig(YAMLObject):
+    """Configuration for metadata sufficient to construct a linalg named op."""
+
+    yaml_tag = "!LinalgStructuredOpConfig"
+
+    def __init__(
+        self,
+        comprehension: Comprehension,
+        domain: Sequence[DimDef],
+        registered_operands: Sequence[OperandDef],
+        context: Optional[_ir.Context] = None,
+    ):
+        self.context = context if context is not None else _ir.Context()
+        self.affine_state = AffineBuildState()
+        self.writes = list()  # type: List[Tuple[TensorUse, TensorExpression]]
+        self.operands = dict()  # type: Dict[OperandDef, OperandDefConfig]
+        self.uses = dict()  # type: Dict[TensorUse, TensorUseConfig]
+
+        # Compute the ordered set of writes and collect the tensor, capture, dims,
+        # and index uses.
+        collected_tensor_uses = set()
+        collected_scalar_uses = set()
+        collected_dim_uses = set()
+        collected_indices = set()
+        for write_use, read_use in zip(comprehension.definitions, comprehension.values):
+            self.writes.append((write_use, read_use))
+
+        for write_use, read_use in self.writes:
+            collected_tensor_uses.add(write_use)
+            read_use.collect_tensor_uses(collected_tensor_uses)
+            read_use.collect_scalar_uses(collected_scalar_uses)
+            read_use.collect_dim_uses(collected_dim_uses)
+            write_use.collect_dim_uses(collected_dim_uses)
+            read_use.collect_indices(collected_indices)
+
+        # Set domain to the sorted list of uses if no domain annotation is given.
+        if not domain:
+            domain = sorted(collected_dim_uses, key=lambda dim: dim.dimname)
+
+        # Verify the domain dimensions match the used dimensions.
+        if len(domain) != len(collected_dim_uses) or any(
+            dim not in collected_dim_uses for dim in domain
+        ):
+            raise ValueError(
+                f"Expected the annotated domain dimensions {domain} to "
+                f"match the set of dimension used by the tensor "
+                f"comprehension {collected_dim_uses}"
+            )
+
+        # Instantiate the dimensions in the given order.
+        with self.context:
+            local_state = AffineBuildState(
+                global_state=self.affine_state, allow_new_symbols=False
+            )
+            for dim in domain:
+                dim.build(state=local_state)
+
+        # Collect all attribute definitions.
+        collected_attr_defs = list()
+        for operand in registered_operands:
+            if operand.is_attribute():
+                collected_attr_defs.append(operand)
+
+        # Collect all tensors with manual indexing annotation.
+        collected_index_defs = list()
+        for operand in registered_operands:
+            if operand.index_dims:
+                if any(dim not in collected_dim_uses for dim in operand.index_dims):
+                    raise ValueError(
+                        f"Expected all index dims {operand.index_dims} of "
+                        f"operand {operand.name} to have uses."
+                    )
+                collected_index_defs.append(operand)
+
+        # Collect the operand definitions of all tensor/scalar uses, attributes, and
+        # shape-only tensors.
+        all_operand_defs = list()
+        for use in collected_tensor_uses:
+            all_operand_defs.append(use.operand_def)
+        for use in collected_scalar_uses:
+            all_operand_defs.append(use.operand_def)
+        for definition in collected_attr_defs:
+            all_operand_defs.append(definition)
+        for definition in collected_index_defs:
+            all_operand_defs.append(definition)
+
+        # Add all operands in registration order to ensure the symbols are
+        # registered in the order they appear.
+        all_operand_defs = sorted(
+            all_operand_defs, key=lambda operand_def: operand_def.registered_index
+        )
+        for operand_def in all_operand_defs:
+            self.add_operand(operand_def)
+
+        # Add all shape-only tensor index_dim annotations and all tensor uses.
+        for definition in collected_index_defs:
+            self.add_indexed_operand(definition)
+        for use in collected_tensor_uses:
+            self.add_tensor_use(use)
+
+        # Normalize all shape and indexing maps now that full count of dims and
+        # symbols are known.
+        for cuse in self.uses.values():
+            cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map)
+        for definition in collected_index_defs:
+            self.operands[definition].indexing_map = self._normalize_affine_map(
+                self.operands[definition].indexing_map
+            )
+        for operand_config in self.operands.values():
+            if operand_config.shape_map:
+                operand_config.shape_map = self._normalize_affine_map(
+                    operand_config.shape_map, with_dims=False
+                )
+            if operand_config.index_attr_map:
+                operand_config.index_attr_map = self._normalize_affine_map(
+                    operand_config.index_attr_map, with_dims=False
+                )
+
+        # Now for each write use, propagate the indexing maps from the use to the
+        # tensor, ensuring that there are no conflicts.
+        for write_use, _ in self.writes:
+            write_tensor_config = self.operands[write_use.operand_def]
+            if write_tensor_config.indexing_map:
+                raise ValueError(
+                    f"Unexpected multi-write to a single tensor: {write_tensor_config}"
+                )
+            write_tensor_config.indexing_map = self.uses[write_use].indexing_map
+
+        # For each read use, propagate the indexing maps from the use to the
+        # tensor, ensuring that there are no conflicts.
+        for _, read_expr in self.writes:
+            read_uses = set()  # type: Set[TensorUse]
+            read_expr.collect_tensor_uses(read_uses)
+            for read_use in read_uses:
+                read_operand_config = self.operands[read_use.operand_def]
+                if (
+                    read_operand_config.indexing_map
+                    and read_operand_config.indexing_map
+                    != self.uses[read_use].indexing_map
+                ):
+                    raise ValueError(
+                        f"Unexpected multi-read of a tensor with different accesses:"
+                        f"{read_operand_config} vs {read_use}"
+                    )
+                read_operand_config.indexing_map = self.uses[read_use].indexing_map
+
+        # Set the indexing map of all scalar uses to the empty map.
+        for operand_config in self.operands.values():
+            if operand_config.operand_def.kind == OperandKind.SCALAR:
+                operand_config.indexing_map = self._get_scalar_map()
+
+        # Check all registered tensor and scalar operands have an indexing map.
+        for operand in registered_operands:
+            if operand.is_attribute():
+                continue
+            if not (operand in self.operands and self.operands[operand].indexing_map):
+                raise ValueError(
+                    f"Failed to compute an indexing map for operand " f"{operand.name}"
+                )
+
+        # Collect reduction dims and ensure all the same.
+        all_reduction_dims = set(comprehension.all_reduction_dims)
+        if len(all_reduction_dims) != 1:
+            raise ValueError(
+                f"All writes within a generic must have the same reduction "
+                f"dims. Got: {all_reduction_dims}"
+            )
+        self.reduction_dims = next(iter(all_reduction_dims))
+
+        # Check the index dimension exists and resolve.
+        for index in collected_indices:
+            if index.dim_def.dimname not in self.affine_state.all_dims:
+                raise ValueError(
+                    f"The dimension {index.dim_def.dimname} is not part of the "
+                    f"iteration domain {self.affine_state.all_dims}"
+                )
+            index.resolve_dimension_name(self.affine_state)
+
+        # Generate the scalar assignments (used to build a body).
+        self.assignments = [
+            ScalarAssign(write_use.tensor_name, read_expr.to_scalar_expression())
+            for write_use, read_expr in self.writes
+        ]
+
+    @property
+    def ordered_operands(self) -> Sequence[OperandDefConfig]:
+        return sorted(
+            self.operands.values(),
+            key=lambda operand: operand.operand_def.registered_index,
+        )
+
+    @property
+    def ordered_dims(self) -> Sequence[Tuple[str, int]]:
+        """Gets the ordered list of dim bindings (symbolic name, position).
+
+        TODO: The original parser relies on parse ordering to arrive at the
+        iterator types, but that ordering is not defined on the Python side, so
+        this may be ambiguous.
+        """
+        return list(self.affine_state.all_dims.items())
+
+    @property
+    def indexing_maps(self) -> Sequence[_ir.AffineMap]:
+        return [o.indexing_map for o in self.ordered_operands if o.indexing_map]
+
+    @property
+    def iterator_types(self) -> Sequence[str]:
+        def get_type(symbolic_name, position):
+            for reduction_dim_expr in self.reduction_dims:
+                if reduction_dim_expr.dimname == symbolic_name:
+                    return "reduction"
+            return "parallel"
+
+        return [get_type(*dim) for dim in self.ordered_dims]
+
+    def add_operand(self, operand_def: OperandDef):
+        if operand_def in self.operands:
+            return
+        if not (operand_def.is_tensor() or operand_def.kind == OperandKind.INDEX_ATTR):
+            self.operands[operand_def] = OperandDefConfig(operand_def)
+            return
+        with self.context:
+            local_state = AffineBuildState(
+                global_state=self.affine_state, allow_new_dims=False
+            )
+            exprs = []
+            for expr in operand_def.size_exprs:
+                exprs.append(expr.build(state=local_state))
+            assert local_state.local_dim_count == 0
+            affine_map = _ir.AffineMap.get(
+                dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs
+            )
+            if operand_def.kind == OperandKind.INDEX_ATTR:
+                self.operands[operand_def] = OperandDefConfig(
+                    operand_def, index_attr_map=affine_map
+                )
+            else:
+                self.operands[operand_def] = OperandDefConfig(
+                    operand_def, shape_map=affine_map
+                )
+
+    def add_indexed_operand(self, operand_def: OperandDef):
+        with self.context:
+            local_state = AffineBuildState(
+                global_state=self.affine_state, allow_new_symbols=False
+            )
+            exprs = []
+            for expr in operand_def.index_dims:
+                exprs.append(expr.build(state=local_state))
+            self.operands[operand_def].indexing_map = _ir.AffineMap.get(
+                dim_count=local_state.dim_count,
+                symbol_count=local_state.symbol_count,
+                exprs=exprs,
+            )
+
+    def add_tensor_use(self, tensor_use: TensorUse):
+        if tensor_use in self.uses:
+            return
+        with self.context:
+            local_state = AffineBuildState(
+                global_state=self.affine_state, allow_new_symbols=False
+            )
+            exprs = []
+            for expr in tensor_use.indices:
+                exprs.append(expr.build(state=local_state))
+            indexing_map = _ir.AffineMap.get(
+                dim_count=local_state.dim_count,
+                symbol_count=local_state.symbol_count,
+                exprs=exprs,
+            )
+
+            use_config = TensorUseConfig(tensor_use, indexing_map)
+            self.uses[tensor_use] = use_config
+
+    def _get_scalar_map(self) -> _ir.AffineMap:
+        """Create an empty affine map used to index a scalar."""
+        with self.context:
+            return _ir.AffineMap.get(
+                dim_count=self.affine_state.dim_count,
+                symbol_count=self.affine_state.symbol_count,
+                exprs=list(),
+            )
+
+    def _normalize_affine_map(
+        self, affine_map: _ir.AffineMap, with_dims: bool = True
+    ) -> _ir.AffineMap:
+        """Normalizes an indexing map to have the max known symbols and dims."""
+        with self.context:
+            return _ir.AffineMap.get(
+                dim_count=self.affine_state.dim_count if with_dims else 0,
+                symbol_count=self.affine_state.symbol_count,
+                exprs=list(affine_map.results),
+            )
+
+    def to_yaml_custom_dict(self):
+        self_dict = dict(args=self.ordered_operands)
+        # TODO: Refactor the hierarchy internally when supporting more
+        # than static (preserving this serialized form).
+        self_dict["indexing_maps"] = LinalgIndexingMapsConfig(
+            static_indexing_maps=self.indexing_maps
+        )
+        self_dict["iterator_types"] = self.iterator_types
+        self_dict["assignments"] = self.assignments
+        return self_dict
+
+    def __repr__(self):
+        lines = [f"LinalgGenericOpConfig(reduction_dims={self.reduction_dims},"]
+        lines.append("operands=[")
+        for def_config in self.ordered_operands:
+            lines.append(f"  {repr(def_config)}")
+        lines.append("], indexing_maps=[")
+        for m in self.indexing_maps:
+            lines.append(f"  {repr(m)}")
+        lines.append(f"], iterator_types=[")
+        for t in self.iterator_types:
+            lines.append(f"  {t}")
+        lines.append("])")
+        return "\n".join(lines)
 
 
 class LinalgOpConfig(YAMLObject):
-  """Container for any supported linalg op type.
-
-  This includes the concrete type by name for ease of parsing by systems
-  that ignore tags.
-  """
-  yaml_tag = "!LinalgOpConfig"
-
-  def __init__(self,
-               metadata: OpMetadataDef,
-               *,
-               structured_op: Optional[LinalgStructuredOpConfig] = None):
-    self.metadata = metadata
-    self.structured_op = structured_op
-
-  def to_yaml_custom_dict(self):
-    self_dict = dict(metadata=self.metadata,)
-    if self.structured_op:
-      self_dict["structured_op"] = self.structured_op
-    return self_dict
-
-  @staticmethod
-  def from_linalg_op_def(
-      op_def: LinalgOpDef,
-      context: Optional[_ir.Context] = None) -> Sequence["LinalgOpConfig"]:
-    """Expands a LinalgOpDef into corresponding Linalg configured ops."""
-    # TODO: Many LinalgOpDef patterns need to expand to multiple generics.
-    assert len(op_def.comprehensions) == 1, "Only one comprehension supported"
-    return [
-        LinalgOpConfig(
-            op_def.metadata,
-            structured_op=LinalgStructuredOpConfig(
-                op_def.comprehensions[0], op_def.domain,
-                op_def.registered_operands.values(), context)),
-    ]
-
-  def __repr__(self):
-    return (f"LinalgOpConfig(metadata={self.metadata},\n"
-            f"structured_op={self.structured_op})")
+    """Container for any supported linalg op type.
+
+    This includes the concrete type by name for ease of parsing by systems
+    that ignore tags.
+    """
+
+    yaml_tag = "!LinalgOpConfig"
+
+    def __init__(
+        self,
+        metadata: OpMetadataDef,
+        *,
+        structured_op: Optional[LinalgStructuredOpConfig] = None,
+    ):
+        self.metadata = metadata
+        self.structured_op = structured_op
+
+    def to_yaml_custom_dict(self):
+        self_dict = dict(
+            metadata=self.metadata,
+        )
+        if self.structured_op:
+            self_dict["structured_op"] = self.structured_op
+        return self_dict
+
+    @staticmethod
+    def from_linalg_op_def(
+        op_def: LinalgOpDef, context: Optional[_ir.Context] = None
+    ) -> Sequence["LinalgOpConfig"]:
+        """Expands a LinalgOpDef into corresponding Linalg configured ops."""
+        # TODO: Many LinalgOpDef patterns need to expand to multiple generics.
+        assert len(op_def.comprehensions) == 1, "Only one comprehension supported"
+        return [
+            LinalgOpConfig(
+                op_def.metadata,
+                structured_op=LinalgStructuredOpConfig(
+                    op_def.comprehensions[0],
+                    op_def.domain,
+                    op_def.registered_operands.values(),
+                    context,
+                ),
+            ),
+        ]
+
+    def __repr__(self):
+        return (
+            f"LinalgOpConfig(metadata={self.metadata},\n"
+            f"structured_op={self.structured_op})"
+        )
index 45b8d5c..8b8726f 100644 (file)
@@ -10,160 +10,192 @@ import inspect
 import threading
 
 from ..... import ir
-from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+from ...._ods_common import (
+    get_op_result_or_value as _get_op_result_or_value,
+    get_op_results_or_values as _get_op_results_or_values,
+)
 from .comprehension import *
 from .config import *
 from .emitter import *
 
 _CONTEXT = threading.local()
 
-StructuredOpOuts = Union[ir.Operation, ir.OpView, ir.OpResultList,
-                         Sequence[Union[ir.Value, ir.Operation, ir.OpView]]]
+StructuredOpOuts = Union[
+    ir.Operation,
+    ir.OpView,
+    ir.OpResultList,
+    Sequence[Union[ir.Value, ir.Operation, ir.OpView]],
+]
 
 
 @contextmanager
 def bind_op_def(op_def: LinalgOpDef):
-  if hasattr(_CONTEXT, "current_op_def"):
-    raise ValueError("Cannot recursively define an operation")
-  _CONTEXT.current_op_def = op_def
-  try:
-    yield op_def
-  finally:
-    del _CONTEXT.current_op_def
+    if hasattr(_CONTEXT, "current_op_def"):
+        raise ValueError("Cannot recursively define an operation")
+    _CONTEXT.current_op_def = op_def
+    try:
+        yield op_def
+    finally:
+        del _CONTEXT.current_op_def
 
 
 def current_op_def() -> LinalgOpDef:
-  try:
-    return _CONTEXT.current_op_def
-  except AttributeError:
-    raise ValueError(
-        "Attempt to access the current op definition being defined "
-        "but none is set. Did you mean to call this in an op definition?")
+    try:
+        return _CONTEXT.current_op_def
+    except AttributeError:
+        raise ValueError(
+            "Attempt to access the current op definition being defined "
+            "but none is set. Did you mean to call this in an op definition?"
+        )
 
 
 def _prepare_structured_op_outs(outs: StructuredOpOuts) -> ValueList:
-  if isinstance(outs, (ir.Operation, ir.OpView)):
-    return _get_op_results_or_values(outs)
-  elif isinstance(outs, ir.OpResultList):
-    return outs
+    if isinstance(outs, (ir.Operation, ir.OpView)):
+        return _get_op_results_or_values(outs)
+    elif isinstance(outs, ir.OpResultList):
+        return outs
 
-  return [_get_op_result_or_value(o) for o in outs]
+    return [_get_op_result_or_value(o) for o in outs]
 
 
 class DefinedOpCallable:
-  """Callable that wraps any defined op function."""
-
-  def __init__(self, op_name: str, op_def: LinalgOpDef):
-    self.op_name = op_name
-    self.op_def = op_def
-
-  def __call__(self, *ins: Union[ir.Operation, ir.OpView, ir.Value],
-               outs: StructuredOpOuts, **kwargs):
-    """Emits the corresponding op definition as IR.
-
-    Most arguments are passed through to the underlying emitter. The following
-    keyword argument is interpreted here:
-      emit_generic: Emits a generic form as appropriate (default False). If
-        False, a named form is emitted (which must have been built into the
-        compiler).
-    """
-    emit_generic = kwargs.pop("emit_generic", False)
-    if not isinstance(emit_generic, bool):
-      raise ValueError(f"The named argument 'emit_generic' needs to be "
-                       f" of type bool but got {type(emit_generic)}")
-
-    op_configs = LinalgOpConfig.from_linalg_op_def(
-        self.op_def, context=ir.Context.current)
-
-    if len(op_configs) != 1:
-      # TODO: Support composite ops.
-      raise NotImplementedError(
-          f"Emission of composite linalg ops not supported: {op_configs}")
-
-    ctx = ir.Context.current
-    linalgDialect = ctx.get_dialect_descriptor("linalg")
-    fully_qualified_name = "linalg." + self.op_name
-    emit_generic = (
-        emit_generic or not ctx.is_registered_operation(fully_qualified_name))
-
-    op_config = op_configs[0]
-    out_values = _prepare_structured_op_outs(outs)
-    in_values = [_get_op_result_or_value(i) for i in ins]
-    if op_config.structured_op:
-      if emit_generic:
-        return emit_generic_structured_op(
-            op_config.structured_op, *in_values, outs=out_values, **kwargs)
-      else:
-        return emit_named_structured_op(
-            op_config.structured_op,
-            self.op_name,
-            self.op_def.metadata.cpp_class_name,
-            *in_values,
-            outs=out_values,
-            **kwargs)
-
-    raise NotImplementedError(
-        f"Emission of linalg op type not supported: {op_config}")
-
-
-def linalg_structured_op(dsl_func=None,
-                         *,
-                         op_name=None,
-                         op_class_name=None) -> DefinedOpCallable:
-  if dsl_func is None:
-    # Curry the keyword args in for delayed application.
-    return functools.partial(
-        linalg_structured_op, op_name=op_name, op_class_name=op_class_name)
-  # Determine default names by introspecting the function.
-  if op_name is None:
-    op_name = dsl_func.__name__
-  if op_class_name is None:
-    # Camel case it.
-    op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op"
-
-  op_def = LinalgOpDef(
-      name=op_name, cpp_class_name=op_class_name, doc=inspect.getdoc(dsl_func))
-
-  # Extract arguments and TensorDefs from the signature.
-  dsl_func_args = list()
-  sig = inspect.signature(dsl_func)
-  for param_name, param in sig.parameters.items():
-    param_default = param.default
-    if isinstance(param_default,
-                  (TensorDef, ScalarDef, IndexAttrDef, UnaryFnAttrDef,
-                   BinaryFnAttrDef, TypeFnAttrDef)):
-      op_def.add_operand(param_name, param_default.operand_def)
-    else:
-      raise ValueError(
-          f"@linalg_structured_op function parameters must be defaulted as "
-          f"TensorDef(...), ScalarDef(...), or IndexAttrDef(...): "
-          f"Found {param_name}: {param_default}")
-    dsl_func_args.append(param_default)
-
-  # Invoke the DSL func to finish populating the op definition.
-  with bind_op_def(op_def):
-    dsl_func(*dsl_func_args)
-
-  # TODO: The returned callable should be an IR emitter but that is not
-  # upstreamed yet.
-  return DefinedOpCallable(op_name, op_def)
+    """Callable that wraps any defined op function."""
+
+    def __init__(self, op_name: str, op_def: LinalgOpDef):
+        self.op_name = op_name
+        self.op_def = op_def
+
+    def __call__(
+        self,
+        *ins: Union[ir.Operation, ir.OpView, ir.Value],
+        outs: StructuredOpOuts,
+        **kwargs,
+    ):
+        """Emits the corresponding op definition as IR.
+
+        Most arguments are passed through to the underlying emitter. The following
+        keyword argument is interpreted here:
+          emit_generic: Emits a generic form as appropriate (default False). If
+            False, a named form is emitted (which must have been built into the
+            compiler).
+        """
+        emit_generic = kwargs.pop("emit_generic", False)
+        if not isinstance(emit_generic, bool):
+            raise ValueError(
+                f"The named argument 'emit_generic' needs to be "
+                f" of type bool but got {type(emit_generic)}"
+            )
+
+        op_configs = LinalgOpConfig.from_linalg_op_def(
+            self.op_def, context=ir.Context.current
+        )
+
+        if len(op_configs) != 1:
+            # TODO: Support composite ops.
+            raise NotImplementedError(
+                f"Emission of composite linalg ops not supported: {op_configs}"
+            )
+
+        ctx = ir.Context.current
+        linalgDialect = ctx.get_dialect_descriptor("linalg")
+        fully_qualified_name = "linalg." + self.op_name
+        emit_generic = emit_generic or not ctx.is_registered_operation(
+            fully_qualified_name
+        )
+
+        op_config = op_configs[0]
+        out_values = _prepare_structured_op_outs(outs)
+        in_values = [_get_op_result_or_value(i) for i in ins]
+        if op_config.structured_op:
+            if emit_generic:
+                return emit_generic_structured_op(
+                    op_config.structured_op, *in_values, outs=out_values, **kwargs
+                )
+            else:
+                return emit_named_structured_op(
+                    op_config.structured_op,
+                    self.op_name,
+                    self.op_def.metadata.cpp_class_name,
+                    *in_values,
+                    outs=out_values,
+                    **kwargs,
+                )
+
+        raise NotImplementedError(
+            f"Emission of linalg op type not supported: {op_config}"
+        )
+
+
+def linalg_structured_op(
+    dsl_func=None, *, op_name=None, op_class_name=None
+) -> DefinedOpCallable:
+    if dsl_func is None:
+        # Curry the keyword args in for delayed application.
+        return functools.partial(
+            linalg_structured_op, op_name=op_name, op_class_name=op_class_name
+        )
+    # Determine default names by introspecting the function.
+    if op_name is None:
+        op_name = dsl_func.__name__
+    if op_class_name is None:
+        # Camel case it.
+        op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op"
+
+    op_def = LinalgOpDef(
+        name=op_name, cpp_class_name=op_class_name, doc=inspect.getdoc(dsl_func)
+    )
+
+    # Extract arguments and TensorDefs from the signature.
+    dsl_func_args = list()
+    sig = inspect.signature(dsl_func)
+    for param_name, param in sig.parameters.items():
+        param_default = param.default
+        if isinstance(
+            param_default,
+            (
+                TensorDef,
+                ScalarDef,
+                IndexAttrDef,
+                UnaryFnAttrDef,
+                BinaryFnAttrDef,
+                TypeFnAttrDef,
+            ),
+        ):
+            op_def.add_operand(param_name, param_default.operand_def)
+        else:
+            raise ValueError(
+                f"@linalg_structured_op function parameters must be defaulted as "
+                f"TensorDef(...), ScalarDef(...), or IndexAttrDef(...): "
+                f"Found {param_name}: {param_default}"
+            )
+        dsl_func_args.append(param_default)
+
+    # Invoke the DSL func to finish populating the op definition.
+    with bind_op_def(op_def):
+        dsl_func(*dsl_func_args)
+
+    # TODO: The returned callable should be an IR emitter but that is not
+    # upstreamed yet.
+    return DefinedOpCallable(op_name, op_def)
 
 
 def domain(*dimensions: DimDef):
-  if any(not isinstance(d, DimDef) for d in dimensions):
-    raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
-  current_op_def().domain.extend(dimensions)
+    if any(not isinstance(d, DimDef) for d in dimensions):
+        raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
+    current_op_def().domain.extend(dimensions)
 
 
 def implements(*interfaces: OpInterfaceDef):
-  if any(not isinstance(intr, OpInterfaceDef) for intr in interfaces):
-    raise ValueError(
-        f"Expected interfaces of type OpInterfaceDef but got {interfaces}")
-  current_op_def().metadata.implements.extend(interfaces)
+    if any(not isinstance(intr, OpInterfaceDef) for intr in interfaces):
+        raise ValueError(
+            f"Expected interfaces of type OpInterfaceDef but got {interfaces}"
+        )
+    current_op_def().metadata.implements.extend(interfaces)
 
 
 def defines(*definitions: OpDefinitionDef):
-  if any(not isinstance(defi, OpDefinitionDef) for defi in definitions):
-    raise ValueError(
-        f"Expected definitions of type OpDefinitionDef but got {definitions}")
-  current_op_def().metadata.defines.extend(definitions)
+    if any(not isinstance(defi, OpDefinitionDef) for defi in definitions):
+        raise ValueError(
+            f"Expected definitions of type OpDefinitionDef but got {definitions}"
+        )
+    current_op_def().metadata.defines.extend(definitions)
index b63cb40..62730d9 100644 (file)
@@ -11,7 +11,10 @@ from .... import linalg
 from .... import math
 from .... import arith
 from .... import complex
-from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+from ...._ods_common import (
+    get_op_result_or_value as _get_op_result_or_value,
+    get_op_results_or_values as _get_op_results_or_values,
+)
 
 from .scalar_expr import *
 from .config import *
@@ -29,529 +32,618 @@ ValueList = Union[Sequence[Value], OpResultList]
 
 
 def isa(cls: Type, ty: Type):
-  try:
-    cls(ty)
-    return True
-  except ValueError:
-    return False
-
-
-def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
-                                 *ins: Value, outs: ValueList,
-                                 **attrs: Union[Sequence[int], TypeFnType]):
-  all_arg_defs = op_config.ordered_operands
-  in_arg_defs = [
-      d for d in all_arg_defs
-      if d.kind in [OperandKind.SCALAR, OperandKind.INPUT_TENSOR]
-  ]
-  out_arg_defs = [
-      d for d in all_arg_defs if d.kind == OperandKind.OUTPUT_TENSOR
-  ]
-  index_attr_arg_defs = [
-      d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR
-  ]
-  fn_attr_arg_defs = [
-      d for d in all_arg_defs if d.kind in [
-          OperandKind.UNARY_FN_ATTR, OperandKind.BINARY_FN_ATTR,
-          OperandKind.TYPE_FN_ATTR
-      ]
-  ]
-
-  # Verify outs is a sequence or a list of results.
-  if not isinstance(outs, (Sequence, OpResultList)):
-    raise ValueError(f"Expected named argument outs to have type Sequence or "
-                     f"OpResultLis but got {type(outs)}")
-
-  # Arity validation.
-  if len(ins) != len(in_arg_defs):
-    raise ValueError(f"Expected {len(in_arg_defs)} inputs but got "
-                     f"{len(ins)} for {op_config}")
-  if outs and len(outs) != len(out_arg_defs):
-    raise ValueError(f"Expected {len(out_arg_defs)} outputs but got "
-                     f"{len(outs)} for {op_config}")
-
-  # Compute a replacement list for all index attribute symbols.
-  expressions = []  # type: Sequence[AffineExpr]
-  replacements = []  # type: Sequence[AffineExpr]
-  for index_attr in index_attr_arg_defs:
-    index_attr_vals = index_attr.operand_def.default_indices
-    if index_attr.name in attrs:
-      index_attr_vals = attrs.get(index_attr.name)
-    assert index_attr_vals, "Index attribute has no value"
-    if not all(isinstance(value, int) for value in index_attr_vals):
-      raise ValueError(f"Attribute {index_attr.name} needs to be of type "
-                       f"Sequence[int] but got {type(index_attr_vals)}")
-    results = index_attr.index_attr_map.results  # type: AffineExprList
-    if len(index_attr_vals) != len(results):
-      raise ValueError(f"Attribute {index_attr.name} has length {len(results)} "
-                       f"but got {len(index_attr_vals)} values")
-    for expr, value in zip(results, index_attr_vals):
-      expressions.append(expr)
-      replacements.append(AffineConstantExpr.get(value))
-
-  # Replace all index attribute symbols by their value.
-  # TODO: Add support for shape symbols.
-  indexing_maps = []  # type: Sequence[AffineMap]
-  for curr in op_config.indexing_maps:
-    for expression, replacement in zip(expressions, replacements):
-      curr = curr.replace(expression, replacement, curr.n_dims, curr.n_symbols)
-    indexing_maps.append(curr)
-
-  # TODO: Linalg verification does not currently allow symbols.
-  # Compress them for now and verify none are left.
-  indexing_maps = AffineMap.compress_unused_symbols(indexing_maps,
-                                                    Context.current)
-  if any(indexing_map.n_symbols != 0 for indexing_map in indexing_maps):
-    raise ValueError(f"Expected indexing_maps to use no symbols after "
-                     f"replacement and compression but got {indexing_maps}")
-
-  outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins,
-                                           out_arg_defs, outs)
-
-  result_types = [t for t in out_types if isa(RankedTensorType, t)]
-
-  # Initialize the type dictionary with the predefined types.
-  type_mapping = dict()  # type: Dict[str, Type]
-  type_mapping["F32"] = F32Type.get()
-  type_mapping["F64"] = F64Type.get()
-  type_mapping["I32"] = IntegerType.get_signless(32)
-  type_mapping["I64"] = IntegerType.get_signless(64)
-
-  # Extract type vars for input/output based types.
-  block_arg_types = list()  # type: List[Type]
-  for arg_def, arg_element_type in zip(in_arg_defs + out_arg_defs,
-                                       _get_types_from_values(*ins, *outs)):
-    _add_type_mapping(arg_def, arg_element_type, type_mapping, block_arg_types)
-
-  # Emit the generic op.
-  # TODO: Support emission of pure memref form.
-  indexing_maps_attr = ArrayAttr.get(
-      [AffineMapAttr.get(am) for am in indexing_maps])
-  iterator_types_attr = ArrayAttr.get([
-      Attribute.parse(f"#linalg.iterator_type<{s}>")
-      for s in op_config.iterator_types
-  ])
-
-  # Compute the index attributes used when emitting a named structured op.
-  index_attrs = {}  # type: Dict[str, DenseElementAttr]
-  for index_attr in index_attr_arg_defs:
-    index_attr_vals = attrs.get(index_attr.name)
-    # Only forward attributes set to a non-default value.
-    if index_attr_vals:
-      array = np.array(index_attr_vals, dtype=np.int64)
-      index_attrs[index_attr.name] = DenseElementsAttr.get(array)
-
-  # Compute the function attribute mapping.
-  fn_attr_mapping = {}
-  for fn_attr in fn_attr_arg_defs:
-    attr_val = fn_attr.operand_def.default_fn
-    attr_kind = fn_attr.kind
-    if fn_attr.name in attrs:
-      fn = attrs.get(fn_attr.name)
-      if attr_kind == OperandKind.UNARY_FN_ATTR:
-        if not isinstance(fn, UnaryFnType):
-          raise ValueError(f"Attribute {fn_attr.name} needs to be of type "
-                           f"UnaryFnType but got {type(attr_val)}")
-      elif attr_kind == OperandKind.BINARY_FN_ATTR:
-        if not isinstance(fn, BinaryFnType):
-          raise ValueError(f"Attribute {fn_attr.name} needs to be of type "
-                           f"BinaryFnType but got {type(attr_val)}")
-      else:
-        if not isinstance(fn, TypeFnType):
-          raise ValueError(f"Attribute {fn_attr.name} needs to be of type "
-                           f"TypeFnType but got {type(attr_val)}")
-      attr_val = fn.fn_name
-    assert attr_val, "Function attribute has no value"
-    fn_attr_mapping[fn_attr.name] = (attr_val, attr_kind)
-
-  return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types,
-          type_mapping, indexing_maps_attr, iterator_types_attr, index_attrs,
-          fn_attr_mapping, block_arg_types)
-
-
-def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
-                               outs: ValueList, **attrs: Sequence[int]):
-  all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
-  indexing_maps_attr, iterator_types_attr, index_attrs, fn_attr_mapping, \
-  block_arg_types = \
-     prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
-
-  # An operation that accesses only scalars and scalar/rank zero tensors is
-  # rank polymorphic. We implement rank polymorphism by generating different
-  # indexing maps and iterators that match the rank of the first output tensor.
-  # An operation is rank polymorphic if the iteration domain has rank zero.
-  if not iterator_types_attr:
-    rank = ShapedType(outs[0].type).rank
-    iterator_types_attr = ArrayAttr.get(
-        [Attribute.parse("#linalg.iterator_type<parallel>")] * rank)
-    scalar_map = AffineMap.get(rank, 0, [])
-    tensor_map = AffineMap.get_identity(rank)
-    indexing_maps = []
-    for arg_def in all_arg_defs:
-      if arg_def.operand_def.kind == OperandKind.SCALAR:
-        indexing_maps.append(scalar_map)
-      if arg_def.operand_def.is_tensor():
-        idx = arg_def.operand_def.registered_index
-        if idx < len(ins) and ShapedType(ins[idx].type).rank == 0:
-          indexing_maps.append(scalar_map)
-        else:
-          indexing_maps.append(tensor_map)
-    indexing_maps_attr = ArrayAttr.get(
-        [AffineMapAttr.get(am) for am in indexing_maps])
-
-  generic_op = linalg.GenericOp(
-      result_tensors=result_types,
-      inputs=ins,
-      outputs=outs,
-      indexing_maps=indexing_maps_attr,
-      iterator_types=iterator_types_attr,
-      doc=None,  # TODO: Make optional.
-      library_call=None)  # TODO: Make optional.
-
-  # Construct the body.
-  block_arg_names = _get_operand_def_names(*in_arg_defs, *out_arg_defs)
-  block = generic_op.regions[0].blocks.append(*block_arg_types)
-  block_arg_mapping = dict(zip(block_arg_names, block.arguments))
-  with InsertionPoint(block):
-    body_builder = _BodyBuilder(type_mapping, block_arg_mapping,
-                                fn_attr_mapping)
-    for assignment in op_config.assignments:
-      body_builder.assign(assignment)
-    body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs))
-
-  if len(result_types) == 1:
-    return generic_op.result
-  else:
-    return generic_op.results
-
-
-def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
-                             op_class_name: str, *ins: Value, outs: ValueList,
-                             **attrs: Sequence[int]):
-  all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
-  indexing_maps_attr, iterator_types_attr, index_attrs, fn_attr_mapping, \
-  block_arg_types = \
-     prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
-
-  # If we get here, there must exist a builtin class `op_class_name`.
-  ctx = Context.current
-  fully_qualified_name = "linalg." + op_name
-  if (not ctx.is_registered_operation(fully_qualified_name) or
-      not op_class_name in linalg.__dict__.keys()):
-    raise NotImplementedError(
-        f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}")
-
-  # Set the index attributes used to compute the indexing maps.
-  named_op = getattr(linalg, op_class_name)(ins, outs, result_types)
-  for name, value in index_attrs.items():
-    named_op.operation.attributes[name] = value
-
-  # Compute the function attributes by combining operand kind and function name.
-  for name, (fn_name, kind) in fn_attr_mapping.items():
-    assert kind.name.lower().endswith("_attr")
-    enum_name = kind.name.lower()[:-5]
-    named_op.operation.attributes[name] = Attribute.parse(
-        f"#linalg.{enum_name}<{fn_name}>")
+    try:
+        cls(ty)
+        return True
+    except ValueError:
+        return False
 
-  linalg.fill_builtin_region(named_op.operation)
 
-  if len(result_types) == 1:
-    return named_op.result
-  else:
-    return named_op.results
+def prepare_common_structured_op(
+    op_config: LinalgStructuredOpConfig,
+    *ins: Value,
+    outs: ValueList,
+    **attrs: Union[Sequence[int], TypeFnType],
+):
+    all_arg_defs = op_config.ordered_operands
+    in_arg_defs = [
+        d
+        for d in all_arg_defs
+        if d.kind in [OperandKind.SCALAR, OperandKind.INPUT_TENSOR]
+    ]
+    out_arg_defs = [d for d in all_arg_defs if d.kind == OperandKind.OUTPUT_TENSOR]
+    index_attr_arg_defs = [d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR]
+    fn_attr_arg_defs = [
+        d
+        for d in all_arg_defs
+        if d.kind
+        in [
+            OperandKind.UNARY_FN_ATTR,
+            OperandKind.BINARY_FN_ATTR,
+            OperandKind.TYPE_FN_ATTR,
+        ]
+    ]
+
+    # Verify outs is a sequence or a list of results.
+    if not isinstance(outs, (Sequence, OpResultList)):
+        raise ValueError(
+            f"Expected named argument outs to have type Sequence or "
+            f"OpResultLis but got {type(outs)}"
+        )
+
+    # Arity validation.
+    if len(ins) != len(in_arg_defs):
+        raise ValueError(
+            f"Expected {len(in_arg_defs)} inputs but got " f"{len(ins)} for {op_config}"
+        )
+    if outs and len(outs) != len(out_arg_defs):
+        raise ValueError(
+            f"Expected {len(out_arg_defs)} outputs but got "
+            f"{len(outs)} for {op_config}"
+        )
+
+    # Compute a replacement list for all index attribute symbols.
+    expressions = []  # type: Sequence[AffineExpr]
+    replacements = []  # type: Sequence[AffineExpr]
+    for index_attr in index_attr_arg_defs:
+        index_attr_vals = index_attr.operand_def.default_indices
+        if index_attr.name in attrs:
+            index_attr_vals = attrs.get(index_attr.name)
+        assert index_attr_vals, "Index attribute has no value"
+        if not all(isinstance(value, int) for value in index_attr_vals):
+            raise ValueError(
+                f"Attribute {index_attr.name} needs to be of type "
+                f"Sequence[int] but got {type(index_attr_vals)}"
+            )
+        results = index_attr.index_attr_map.results  # type: AffineExprList
+        if len(index_attr_vals) != len(results):
+            raise ValueError(
+                f"Attribute {index_attr.name} has length {len(results)} "
+                f"but got {len(index_attr_vals)} values"
+            )
+        for expr, value in zip(results, index_attr_vals):
+            expressions.append(expr)
+            replacements.append(AffineConstantExpr.get(value))
+
+    # Replace all index attribute symbols by their value.
+    # TODO: Add support for shape symbols.
+    indexing_maps = []  # type: Sequence[AffineMap]
+    for curr in op_config.indexing_maps:
+        for expression, replacement in zip(expressions, replacements):
+            curr = curr.replace(expression, replacement, curr.n_dims, curr.n_symbols)
+        indexing_maps.append(curr)
+
+    # TODO: Linalg verification does not currently allow symbols.
+    # Compress them for now and verify none are left.
+    indexing_maps = AffineMap.compress_unused_symbols(indexing_maps, Context.current)
+    if any(indexing_map.n_symbols != 0 for indexing_map in indexing_maps):
+        raise ValueError(
+            f"Expected indexing_maps to use no symbols after "
+            f"replacement and compression but got {indexing_maps}"
+        )
+
+    outs, out_types = _infer_structured_outs(
+        op_config, in_arg_defs, ins, out_arg_defs, outs
+    )
+
+    result_types = [t for t in out_types if isa(RankedTensorType, t)]
+
+    # Initialize the type dictionary with the predefined types.
+    type_mapping = dict()  # type: Dict[str, Type]
+    type_mapping["F32"] = F32Type.get()
+    type_mapping["F64"] = F64Type.get()
+    type_mapping["I32"] = IntegerType.get_signless(32)
+    type_mapping["I64"] = IntegerType.get_signless(64)
+
+    # Extract type vars for input/output based types.
+    block_arg_types = list()  # type: List[Type]
+    for arg_def, arg_element_type in zip(
+        in_arg_defs + out_arg_defs, _get_types_from_values(*ins, *outs)
+    ):
+        _add_type_mapping(arg_def, arg_element_type, type_mapping, block_arg_types)
+
+    # Emit the generic op.
+    # TODO: Support emission of pure memref form.
+    indexing_maps_attr = ArrayAttr.get([AffineMapAttr.get(am) for am in indexing_maps])
+    iterator_types_attr = ArrayAttr.get(
+        [
+            Attribute.parse(f"#linalg.iterator_type<{s}>")
+            for s in op_config.iterator_types
+        ]
+    )
+
+    # Compute the index attributes used when emitting a named structured op.
+    index_attrs = {}  # type: Dict[str, DenseElementAttr]
+    for index_attr in index_attr_arg_defs:
+        index_attr_vals = attrs.get(index_attr.name)
+        # Only forward attributes set to a non-default value.
+        if index_attr_vals:
+            array = np.array(index_attr_vals, dtype=np.int64)
+            index_attrs[index_attr.name] = DenseElementsAttr.get(array)
+
+    # Compute the function attribute mapping.
+    fn_attr_mapping = {}
+    for fn_attr in fn_attr_arg_defs:
+        attr_val = fn_attr.operand_def.default_fn
+        attr_kind = fn_attr.kind
+        if fn_attr.name in attrs:
+            fn = attrs.get(fn_attr.name)
+            if attr_kind == OperandKind.UNARY_FN_ATTR:
+                if not isinstance(fn, UnaryFnType):
+                    raise ValueError(
+                        f"Attribute {fn_attr.name} needs to be of type "
+                        f"UnaryFnType but got {type(attr_val)}"
+                    )
+            elif attr_kind == OperandKind.BINARY_FN_ATTR:
+                if not isinstance(fn, BinaryFnType):
+                    raise ValueError(
+                        f"Attribute {fn_attr.name} needs to be of type "
+                        f"BinaryFnType but got {type(attr_val)}"
+                    )
+            else:
+                if not isinstance(fn, TypeFnType):
+                    raise ValueError(
+                        f"Attribute {fn_attr.name} needs to be of type "
+                        f"TypeFnType but got {type(attr_val)}"
+                    )
+            attr_val = fn.fn_name
+        assert attr_val, "Function attribute has no value"
+        fn_attr_mapping[fn_attr.name] = (attr_val, attr_kind)
+
+    return (
+        all_arg_defs,
+        in_arg_defs,
+        out_arg_defs,
+        outs,
+        result_types,
+        type_mapping,
+        indexing_maps_attr,
+        iterator_types_attr,
+        index_attrs,
+        fn_attr_mapping,
+        block_arg_types,
+    )
+
+
+def emit_generic_structured_op(
+    op_config: LinalgStructuredOpConfig,
+    *ins: Value,
+    outs: ValueList,
+    **attrs: Sequence[int],
+):
+    (
+        all_arg_defs,
+        in_arg_defs,
+        out_arg_defs,
+        outs,
+        result_types,
+        type_mapping,
+        indexing_maps_attr,
+        iterator_types_attr,
+        index_attrs,
+        fn_attr_mapping,
+        block_arg_types,
+    ) = prepare_common_structured_op(op_config, *ins, outs=outs, **attrs)
+
+    # An operation that accesses only scalars and scalar/rank zero tensors is
+    # rank polymorphic. We implement rank polymorphism by generating different
+    # indexing maps and iterators that match the rank of the first output tensor.
+    # An operation is rank polymorphic if the iteration domain has rank zero.
+    if not iterator_types_attr:
+        rank = ShapedType(outs[0].type).rank
+        iterator_types_attr = ArrayAttr.get(
+            [Attribute.parse("#linalg.iterator_type<parallel>")] * rank
+        )
+        scalar_map = AffineMap.get(rank, 0, [])
+        tensor_map = AffineMap.get_identity(rank)
+        indexing_maps = []
+        for arg_def in all_arg_defs:
+            if arg_def.operand_def.kind == OperandKind.SCALAR:
+                indexing_maps.append(scalar_map)
+            if arg_def.operand_def.is_tensor():
+                idx = arg_def.operand_def.registered_index
+                if idx < len(ins) and ShapedType(ins[idx].type).rank == 0:
+                    indexing_maps.append(scalar_map)
+                else:
+                    indexing_maps.append(tensor_map)
+        indexing_maps_attr = ArrayAttr.get(
+            [AffineMapAttr.get(am) for am in indexing_maps]
+        )
+
+    generic_op = linalg.GenericOp(
+        result_tensors=result_types,
+        inputs=ins,
+        outputs=outs,
+        indexing_maps=indexing_maps_attr,
+        iterator_types=iterator_types_attr,
+        doc=None,  # TODO: Make optional.
+        library_call=None,
+    )  # TODO: Make optional.
+
+    # Construct the body.
+    block_arg_names = _get_operand_def_names(*in_arg_defs, *out_arg_defs)
+    block = generic_op.regions[0].blocks.append(*block_arg_types)
+    block_arg_mapping = dict(zip(block_arg_names, block.arguments))
+    with InsertionPoint(block):
+        body_builder = _BodyBuilder(type_mapping, block_arg_mapping, fn_attr_mapping)
+        for assignment in op_config.assignments:
+            body_builder.assign(assignment)
+        body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs))
+
+    if len(result_types) == 1:
+        return generic_op.result
+    else:
+        return generic_op.results
+
+
+def emit_named_structured_op(
+    op_config: LinalgStructuredOpConfig,
+    op_name: str,
+    op_class_name: str,
+    *ins: Value,
+    outs: ValueList,
+    **attrs: Sequence[int],
+):
+    (
+        all_arg_defs,
+        in_arg_defs,
+        out_arg_defs,
+        outs,
+        result_types,
+        type_mapping,
+        indexing_maps_attr,
+        iterator_types_attr,
+        index_attrs,
+        fn_attr_mapping,
+        block_arg_types,
+    ) = prepare_common_structured_op(op_config, *ins, outs=outs, **attrs)
+
+    # If we get here, there must exist a builtin class `op_class_name`.
+    ctx = Context.current
+    fully_qualified_name = "linalg." + op_name
+    if (
+        not ctx.is_registered_operation(fully_qualified_name)
+        or not op_class_name in linalg.__dict__.keys()
+    ):
+        raise NotImplementedError(
+            f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}"
+        )
+
+    # Set the index attributes used to compute the indexing maps.
+    named_op = getattr(linalg, op_class_name)(ins, outs, result_types)
+    for name, value in index_attrs.items():
+        named_op.operation.attributes[name] = value
+
+    # Compute the function attributes by combining operand kind and function name.
+    for name, (fn_name, kind) in fn_attr_mapping.items():
+        assert kind.name.lower().endswith("_attr")
+        enum_name = kind.name.lower()[:-5]
+        named_op.operation.attributes[name] = Attribute.parse(
+            f"#linalg.{enum_name}<{fn_name}>"
+        )
+
+    linalg.fill_builtin_region(named_op.operation)
+
+    if len(result_types) == 1:
+        return named_op.result
+    else:
+        return named_op.results
 
 
 class _BodyBuilder:
-  """Constructs a structured op body by evaluating assignments."""
-
-  def __init__(self, type_mapping: Dict[str, Type],
-               block_arg_mapping: Dict[str, Value], fn_attr_mapping: Dict[str,
-                                                                          str]):
-    self.type_mapping = type_mapping
-    self.block_arg_mapping = block_arg_mapping
-    self.fn_attr_mapping = fn_attr_mapping
-    self.yield_mapping = dict()  # type: Dict[str, Value]
-
-  def assign(self, assignment: ScalarAssign):
-    if assignment.arg in self.yield_mapping:
-      raise ValueError(
-          f"Multiple assignments to the same argument are forbidden: "
-          f"{assignment}")
-    self.yield_mapping[assignment.arg] = self.expression(assignment.value)
-
-  def expression(self, expr: ScalarExpression) -> Value:
-    if expr.scalar_arg:
-      try:
-        return self.block_arg_mapping[expr.scalar_arg.arg]
-      except KeyError:
-        raise ValueError(f"Argument {expr.scalar_arg.arg} is not bound for "
-                         f"this structured op.")
-    elif expr.scalar_const:
-      value_attr = Attribute.parse(expr.scalar_const.value)
-      return arith.ConstantOp(value_attr.type, value_attr).result
-    elif expr.scalar_index:
-      dim_attr = IntegerAttr.get(
-          IntegerType.get_signless(64), expr.scalar_index.dim)
-      return linalg.IndexOp(dim_attr).result
-    elif expr.scalar_fn:
-      kind = expr.scalar_fn.kind.name.lower()
-      fn_name = expr.scalar_fn.fn_name
-      if expr.scalar_fn.attr_name:
-        fn_name, _ = self.fn_attr_mapping[expr.scalar_fn.attr_name]
-      fn = self._get_function(f"_{kind}_{fn_name}")
-      operand_values = [
-          self.expression(operand) for operand in expr.scalar_fn.operands
-      ]
-      if expr.scalar_fn.kind == FunctionKind.TYPE:
-        operand_values = [expr.scalar_fn.type_var.name] + operand_values
-      return fn(*operand_values)
-    raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
-
-  def yield_outputs(self, *output_names: str):
-    output_values = []
-    for n in output_names:
-      try:
-        output_values.append(self.yield_mapping[n])
-      except KeyError:
-        raise ValueError(f"Body assignments do not assign all outputs: "
-                         f"missing '{n}'")
-    linalg.YieldOp(output_values)
-
-  def _get_function(self, fn_name: str) -> Callable:
-    try:
-      fn = getattr(self, f"{fn_name}")
-    except AttributeError:
-      raise ValueError(f"Function '{fn_name}' is not a known function")
-    return fn
-
-  def _cast(self,
-            type_var_name: str,
-            operand: Value,
-            is_unsigned_cast: bool = False) -> Value:
-    try:
-      to_type = self.type_mapping[type_var_name]
-    except KeyError:
-      raise ValueError(f"Unbound type variable '{type_var_name}' ("
-                       f"expected one of {self.type_mapping.keys()}")
-    if operand.type == to_type:
-      return operand
-    if _is_integer_type(to_type):
-      return self._cast_to_integer(to_type, operand, is_unsigned_cast)
-    elif _is_floating_point_type(to_type):
-      return self._cast_to_floating_point(to_type, operand, is_unsigned_cast)
-
-  def _cast_to_integer(self, to_type: Type, operand: Value,
-                       is_unsigned_cast: bool) -> Value:
-    to_width = IntegerType(to_type).width
-    operand_type = operand.type
-    if _is_floating_point_type(operand_type):
-      if is_unsigned_cast:
-        return arith.FPToUIOp(to_type, operand).result
-      return arith.FPToSIOp(to_type, operand).result
-    if _is_index_type(operand_type):
-      return arith.IndexCastOp(to_type, operand).result
-    # Assume integer.
-    from_width = IntegerType(operand_type).width
-    if to_width > from_width:
-      if is_unsigned_cast:
-        return arith.ExtUIOp(to_type, operand).result
-      return arith.ExtSIOp(to_type, operand).result
-    elif to_width < from_width:
-      return arith.TruncIOp(to_type, operand).result
-    raise ValueError(f"Unable to cast body expression from {operand_type} to "
-                     f"{to_type}")
-
-  def _cast_to_floating_point(self, to_type: Type, operand: Value,
-                              is_unsigned_cast: bool) -> Value:
-    operand_type = operand.type
-    if _is_integer_type(operand_type):
-      if is_unsigned_cast:
-        return arith.UIToFPOp(to_type, operand).result
-      return arith.SIToFPOp(to_type, operand).result
-    # Assume FloatType.
-    to_width = _get_floating_point_width(to_type)
-    from_width = _get_floating_point_width(operand_type)
-    if to_width > from_width:
-      return arith.ExtFOp(to_type, operand).result
-    elif to_width < from_width:
-      return arith.TruncFOp(to_type, operand).result
-    raise ValueError(f"Unable to cast body expression from {operand_type} to "
-                     f"{to_type}")
-
-  def _type_cast_signed(self, type_var_name: str, operand: Value) -> Value:
-    return self._cast(type_var_name, operand, False)
-
-  def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
-    return self._cast(type_var_name, operand, True)
-
-  def _unary_exp(self, x: Value) -> Value:
-    if _is_floating_point_type(x.type):
-      return math.ExpOp(x).result
-    raise NotImplementedError("Unsupported 'exp' operand: {x}")
-
-  def _unary_log(self, x: Value) -> Value:
-    if _is_floating_point_type(x.type):
-      return math.LogOp(x).result
-    raise NotImplementedError("Unsupported 'log' operand: {x}")
-
-  def _unary_abs(self, x: Value) -> Value:
-    if _is_floating_point_type(x.type):
-      return math.AbsFOp(x).result
-    raise NotImplementedError("Unsupported 'abs' operand: {x}")
-
-  def _unary_ceil(self, x: Value) -> Value:
-    if _is_floating_point_type(x.type):
-      return math.CeilOp(x).result
-    raise NotImplementedError("Unsupported 'ceil' operand: {x}")
-
-  def _unary_floor(self, x: Value) -> Value:
-    if _is_floating_point_type(x.type):
-      return math.FloorOp(x).result
-    raise NotImplementedError("Unsupported 'floor' operand: {x}")
-
-  def _unary_negf(self, x: Value) -> Value:
-    if _is_floating_point_type(x.type):
-      return arith.NegFOp(x).result
-    if _is_complex_type(x.type):
-      return complex.NegOp(x).result
-    raise NotImplementedError("Unsupported 'negf' operand: {x}")
-
-  def _binary_add(self, lhs: Value, rhs: Value) -> Value:
-    if _is_floating_point_type(lhs.type):
-      return arith.AddFOp(lhs, rhs).result
-    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      return arith.AddIOp(lhs, rhs).result
-    if _is_complex_type(lhs.type):
-      return complex.AddOp(lhs, rhs).result
-    raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}")
-
-  def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
-    if _is_floating_point_type(lhs.type):
-      return arith.SubFOp(lhs, rhs).result
-    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      return arith.SubIOp(lhs, rhs).result
-    if _is_complex_type(lhs.type):
-      return complex.SubOp(lhs, rhs).result
-    raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}")
-
-  def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
-    if _is_floating_point_type(lhs.type):
-      return arith.MulFOp(lhs, rhs).result
-    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      return arith.MulIOp(lhs, rhs).result
-    if _is_complex_type(lhs.type):
-      return complex.MulOp(lhs, rhs).result
-    raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}")
-
-  def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
-    if _is_floating_point_type(lhs.type):
-      return arith.MaxFOp(lhs, rhs).result
-    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      return arith.MaxSIOp(lhs, rhs).result
-    raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}")
-
-  def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
-    if _is_floating_point_type(lhs.type):
-      return arith.MaxFOp(lhs, rhs).result
-    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      return arith.MaxUIOp(lhs, rhs).result
-    raise NotImplementedError(
-        "Unsupported 'max_unsigned' operands: {lhs}, {rhs}")
-
-  def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value:
-    if _is_floating_point_type(lhs.type):
-      return arith.MinFOp(lhs, rhs).result
-    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      return arith.MinSIOp(lhs, rhs).result
-    raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}")
-
-  def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
-    if _is_floating_point_type(lhs.type):
-      return arith.MinFOp(lhs, rhs).result
-    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      return arith.MinUIOp(lhs, rhs).result
-    raise NotImplementedError(
-        "Unsupported 'min_unsigned' operands: {lhs}, {rhs}")
+    """Constructs a structured op body by evaluating assignments."""
+
+    def __init__(
+        self,
+        type_mapping: Dict[str, Type],
+        block_arg_mapping: Dict[str, Value],
+        fn_attr_mapping: Dict[str, str],
+    ):
+        self.type_mapping = type_mapping
+        self.block_arg_mapping = block_arg_mapping
+        self.fn_attr_mapping = fn_attr_mapping
+        self.yield_mapping = dict()  # type: Dict[str, Value]
+
+    def assign(self, assignment: ScalarAssign):
+        if assignment.arg in self.yield_mapping:
+            raise ValueError(
+                f"Multiple assignments to the same argument are forbidden: "
+                f"{assignment}"
+            )
+        self.yield_mapping[assignment.arg] = self.expression(assignment.value)
+
+    def expression(self, expr: ScalarExpression) -> Value:
+        if expr.scalar_arg:
+            try:
+                return self.block_arg_mapping[expr.scalar_arg.arg]
+            except KeyError:
+                raise ValueError(
+                    f"Argument {expr.scalar_arg.arg} is not bound for "
+                    f"this structured op."
+                )
+        elif expr.scalar_const:
+            value_attr = Attribute.parse(expr.scalar_const.value)
+            return arith.ConstantOp(value_attr.type, value_attr).result
+        elif expr.scalar_index:
+            dim_attr = IntegerAttr.get(
+                IntegerType.get_signless(64), expr.scalar_index.dim
+            )
+            return linalg.IndexOp(dim_attr).result
+        elif expr.scalar_fn:
+            kind = expr.scalar_fn.kind.name.lower()
+            fn_name = expr.scalar_fn.fn_name
+            if expr.scalar_fn.attr_name:
+                fn_name, _ = self.fn_attr_mapping[expr.scalar_fn.attr_name]
+            fn = self._get_function(f"_{kind}_{fn_name}")
+            operand_values = [
+                self.expression(operand) for operand in expr.scalar_fn.operands
+            ]
+            if expr.scalar_fn.kind == FunctionKind.TYPE:
+                operand_values = [expr.scalar_fn.type_var.name] + operand_values
+            return fn(*operand_values)
+        raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
+
+    def yield_outputs(self, *output_names: str):
+        output_values = []
+        for n in output_names:
+            try:
+                output_values.append(self.yield_mapping[n])
+            except KeyError:
+                raise ValueError(
+                    f"Body assignments do not assign all outputs: " f"missing '{n}'"
+                )
+        linalg.YieldOp(output_values)
+
+    def _get_function(self, fn_name: str) -> Callable:
+        try:
+            fn = getattr(self, f"{fn_name}")
+        except AttributeError:
+            raise ValueError(f"Function '{fn_name}' is not a known function")
+        return fn
+
+    def _cast(
+        self, type_var_name: str, operand: Value, is_unsigned_cast: bool = False
+    ) -> Value:
+        try:
+            to_type = self.type_mapping[type_var_name]
+        except KeyError:
+            raise ValueError(
+                f"Unbound type variable '{type_var_name}' ("
+                f"expected one of {self.type_mapping.keys()}"
+            )
+        if operand.type == to_type:
+            return operand
+        if _is_integer_type(to_type):
+            return self._cast_to_integer(to_type, operand, is_unsigned_cast)
+        elif _is_floating_point_type(to_type):
+            return self._cast_to_floating_point(to_type, operand, is_unsigned_cast)
+
+    def _cast_to_integer(
+        self, to_type: Type, operand: Value, is_unsigned_cast: bool
+    ) -> Value:
+        to_width = IntegerType(to_type).width
+        operand_type = operand.type
+        if _is_floating_point_type(operand_type):
+            if is_unsigned_cast:
+                return arith.FPToUIOp(to_type, operand).result
+            return arith.FPToSIOp(to_type, operand).result
+        if _is_index_type(operand_type):
+            return arith.IndexCastOp(to_type, operand).result
+        # Assume integer.
+        from_width = IntegerType(operand_type).width
+        if to_width > from_width:
+            if is_unsigned_cast:
+                return arith.ExtUIOp(to_type, operand).result
+            return arith.ExtSIOp(to_type, operand).result
+        elif to_width < from_width:
+            return arith.TruncIOp(to_type, operand).result
+        raise ValueError(
+            f"Unable to cast body expression from {operand_type} to " f"{to_type}"
+        )
+
+    def _cast_to_floating_point(
+        self, to_type: Type, operand: Value, is_unsigned_cast: bool
+    ) -> Value:
+        operand_type = operand.type
+        if _is_integer_type(operand_type):
+            if is_unsigned_cast:
+                return arith.UIToFPOp(to_type, operand).result
+            return arith.SIToFPOp(to_type, operand).result
+        # Assume FloatType.
+        to_width = _get_floating_point_width(to_type)
+        from_width = _get_floating_point_width(operand_type)
+        if to_width > from_width:
+            return arith.ExtFOp(to_type, operand).result
+        elif to_width < from_width:
+            return arith.TruncFOp(to_type, operand).result
+        raise ValueError(
+            f"Unable to cast body expression from {operand_type} to " f"{to_type}"
+        )
+
+    def _type_cast_signed(self, type_var_name: str, operand: Value) -> Value:
+        return self._cast(type_var_name, operand, False)
+
+    def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
+        return self._cast(type_var_name, operand, True)
+
+    def _unary_exp(self, x: Value) -> Value:
+        if _is_floating_point_type(x.type):
+            return math.ExpOp(x).result
+        raise NotImplementedError("Unsupported 'exp' operand: {x}")
+
+    def _unary_log(self, x: Value) -> Value:
+        if _is_floating_point_type(x.type):
+            return math.LogOp(x).result
+        raise NotImplementedError("Unsupported 'log' operand: {x}")
+
+    def _unary_abs(self, x: Value) -> Value:
+        if _is_floating_point_type(x.type):
+            return math.AbsFOp(x).result
+        raise NotImplementedError("Unsupported 'abs' operand: {x}")
+
+    def _unary_ceil(self, x: Value) -> Value:
+        if _is_floating_point_type(x.type):
+            return math.CeilOp(x).result
+        raise NotImplementedError("Unsupported 'ceil' operand: {x}")
+
+    def _unary_floor(self, x: Value) -> Value:
+        if _is_floating_point_type(x.type):
+            return math.FloorOp(x).result
+        raise NotImplementedError("Unsupported 'floor' operand: {x}")
+
+    def _unary_negf(self, x: Value) -> Value:
+        if _is_floating_point_type(x.type):
+            return arith.NegFOp(x).result
+        if _is_complex_type(x.type):
+            return complex.NegOp(x).result
+        raise NotImplementedError("Unsupported 'negf' operand: {x}")
+
+    def _binary_add(self, lhs: Value, rhs: Value) -> Value:
+        if _is_floating_point_type(lhs.type):
+            return arith.AddFOp(lhs, rhs).result
+        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+            return arith.AddIOp(lhs, rhs).result
+        if _is_complex_type(lhs.type):
+            return complex.AddOp(lhs, rhs).result
+        raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}")
+
+    def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
+        if _is_floating_point_type(lhs.type):
+            return arith.SubFOp(lhs, rhs).result
+        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+            return arith.SubIOp(lhs, rhs).result
+        if _is_complex_type(lhs.type):
+            return complex.SubOp(lhs, rhs).result
+        raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}")
+
+    def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
+        if _is_floating_point_type(lhs.type):
+            return arith.MulFOp(lhs, rhs).result
+        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+            return arith.MulIOp(lhs, rhs).result
+        if _is_complex_type(lhs.type):
+            return complex.MulOp(lhs, rhs).result
+        raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}")
+
+    def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
+        if _is_floating_point_type(lhs.type):
+            return arith.MaxFOp(lhs, rhs).result
+        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+            return arith.MaxSIOp(lhs, rhs).result
+        raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}")
+
+    def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
+        if _is_floating_point_type(lhs.type):
+            return arith.MaxFOp(lhs, rhs).result
+        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+            return arith.MaxUIOp(lhs, rhs).result
+        raise NotImplementedError("Unsupported 'max_unsigned' operands: {lhs}, {rhs}")
+
+    def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value:
+        if _is_floating_point_type(lhs.type):
+            return arith.MinFOp(lhs, rhs).result
+        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+            return arith.MinSIOp(lhs, rhs).result
+        raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}")
+
+    def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
+        if _is_floating_point_type(lhs.type):
+            return arith.MinFOp(lhs, rhs).result
+        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+            return arith.MinUIOp(lhs, rhs).result
+        raise NotImplementedError("Unsupported 'min_unsigned' operands: {lhs}, {rhs}")
 
 
 def _infer_structured_outs(
     op_config: LinalgStructuredOpConfig,
-    in_arg_defs: Sequence[OperandDefConfig], ins: Sequence[Value],
+    in_arg_defs: Sequence[OperandDefConfig],
+    ins: Sequence[Value],
     out_arg_defs: Sequence[OperandDefConfig],
-    outs: Union[Sequence[Value], OpResultList]) -> Tuple[ValueList, List[Type]]:
-  """Infers implicit outs and output types.
+    outs: Union[Sequence[Value], OpResultList],
+) -> Tuple[ValueList, List[Type]]:
+    """Infers implicit outs and output types.
 
-  Respects existing contents of outs if not empty.
+    Respects existing contents of outs if not empty.
 
-  Returns:
-    normalized outs, output types
-  """
-  # If outs were explicitly provided, we accept them verbatim.
-  if outs:
-    return outs, [out.type for out in outs]
+    Returns:
+      normalized outs, output types
+    """
+    # If outs were explicitly provided, we accept them verbatim.
+    if outs:
+        return outs, [out.type for out in outs]
 
-  raise NotImplementedError(f"Output tensor inference not yet supported for "
-                            "structured ops")
+    raise NotImplementedError(
+        f"Output tensor inference not yet supported for " "structured ops"
+    )
 
 
 def _get_types_from_values(*values: Value) -> Sequence[Type]:
-  types = []
-  for v in values:
-    types.append(v.type)
-  return types
+    types = []
+    for v in values:
+        types.append(v.type)
+    return types
 
 
 def _get_operand_def_names(*operand_configs: OperandDefConfig) -> Sequence[str]:
-  return [odc.operand_def.name for odc in operand_configs]
-
-
-def _add_type_mapping(operand_config: OperandDefConfig, operand_type: Type,
-                      type_mapping: Dict[str, Type],
-                      block_arg_types: Sequence[Type]):
-  element_or_self_type = operand_type
-  # Get the element type for tensor operands and the type itself for scalars.
-  if operand_config.shape_map:
-    try:
-      element_or_self_type = ShapedType(operand_type).element_type
-    except Exception as e:
-      raise ValueError(f"Expected ShapedType but got {operand_type}") from e
-  name = operand_config.type_var.name
-  if name in type_mapping:
-    if type_mapping[name] != element_or_self_type:
-      raise ValueError(f"Cannot overwrite type mapping {name} = "
-                       f"{type_mapping[name]} by type {element_or_self_type}")
-  type_mapping[name] = element_or_self_type
-  block_arg_types.append(element_or_self_type)
+    return [odc.operand_def.name for odc in operand_configs]
+
+
+def _add_type_mapping(
+    operand_config: OperandDefConfig,
+    operand_type: Type,
+    type_mapping: Dict[str, Type],
+    block_arg_types: Sequence[Type],
+):
+    element_or_self_type = operand_type
+    # Get the element type for tensor operands and the type itself for scalars.
+    if operand_config.shape_map:
+        try:
+            element_or_self_type = ShapedType(operand_type).element_type
+        except Exception as e:
+            raise ValueError(f"Expected ShapedType but got {operand_type}") from e
+    name = operand_config.type_var.name
+    if name in type_mapping:
+        if type_mapping[name] != element_or_self_type:
+            raise ValueError(
+                f"Cannot overwrite type mapping {name} = "
+                f"{type_mapping[name]} by type {element_or_self_type}"
+            )
+    type_mapping[name] = element_or_self_type
+    block_arg_types.append(element_or_self_type)
 
 
 def _is_complex_type(t: Type) -> bool:
-  return ComplexType.isinstance(t)
+    return ComplexType.isinstance(t)
 
 
 def _is_floating_point_type(t: Type) -> bool:
-  # TODO: Create a FloatType in the Python API and implement the switch
-  # there.
-  return (F64Type.isinstance(t) or F32Type.isinstance(t) or
-          F16Type.isinstance(t) or BF16Type.isinstance(t))
+    # TODO: Create a FloatType in the Python API and implement the switch
+    # there.
+    return (
+        F64Type.isinstance(t)
+        or F32Type.isinstance(t)
+        or F16Type.isinstance(t)
+        or BF16Type.isinstance(t)
+    )
 
 
 def _is_integer_type(t: Type) -> bool:
-  return IntegerType.isinstance(t)
+    return IntegerType.isinstance(t)
 
 
 def _is_index_type(t: Type) -> bool:
-  return IndexType.isinstance(t)
+    return IndexType.isinstance(t)
 
 
 def _get_floating_point_width(t: Type) -> int:
-  # TODO: Create a FloatType in the Python API and implement the switch
-  # there.
-  if F64Type.isinstance(t):
-    return 64
-  if F32Type.isinstance(t):
-    return 32
-  if F16Type.isinstance(t):
-    return 16
-  if BF16Type.isinstance(t):
-    return 16
-  raise NotImplementedError(f"Unhandled floating point type switch {t}")
+    # TODO: Create a FloatType in the Python API and implement the switch
+    # there.
+    if F64Type.isinstance(t):
+        return 64
+    if F32Type.isinstance(t):
+        return 32
+    if F16Type.isinstance(t):
+        return 16
+    if BF16Type.isinstance(t):
+        return 16
+    raise NotImplementedError(f"Unhandled floating point type switch {t}")
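As a reading aid for the reformatted emitter helpers above: once the operand and target kinds are known, the integer cast reduces to a width comparison. Below is a minimal, MLIR-free sketch of that rule, for illustration only (the arith op names are plain strings here, not real op constructors); the floating-point path in _cast_to_floating_point follows the same pattern with extf/truncf.

    def pick_integer_cast(from_width: int, to_width: int, is_unsigned: bool) -> str:
        # Widening extends (zero- or sign-extension depending on signedness),
        # narrowing truncates, and equal widths are rejected -- mirroring
        # _cast_to_integer above. Op names are illustrative strings only.
        if to_width > from_width:
            return "arith.extui" if is_unsigned else "arith.extsi"
        if to_width < from_width:
            return "arith.trunci"
        raise ValueError("unable to cast between equal-width integer types")

    assert pick_integer_cast(16, 32, is_unsigned=False) == "arith.extsi"
    assert pick_integer_cast(16, 32, is_unsigned=True) == "arith.extui"
    assert pick_integer_cast(64, 32, is_unsigned=False) == "arith.trunci"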
index aa894dc..8685399 100644
@@ -30,123 +30,137 @@ __all__ = [
 
 
 class ScalarFn:
-  """A type of ScalarExpression that applies a function."""
-
-  def __init__(self, kind: "FunctionKind", fn_name: Optional[str],
-               attr_name: Optional[str], type_var: Optional["TypeVar"],
-               operands: Sequence["ScalarExpression"]):
-    if bool(fn_name) + bool(attr_name) != 1:
-      raise ValueError("One of 'fn_name', 'attr_name' must be specified")
-    self.kind = kind
-    self.fn_name = fn_name
-    self.attr_name = attr_name
-    self.type_var = type_var
-    self.operands = operands
-
-  def expr(self) -> "ScalarExpression":
-    return ScalarExpression(scalar_fn=self)
-
-  def __repr__(self):
-    name = self.fn_name if self.fn_name else self.attr_name
-    return (f"ScalarFn<{self.kind.name}.{name}>(type_var={self.type_var}, "
-            f"operands=[{', '.join(self.operands)}])")
+    """A type of ScalarExpression that applies a function."""
+
+    def __init__(
+        self,
+        kind: "FunctionKind",
+        fn_name: Optional[str],
+        attr_name: Optional[str],
+        type_var: Optional["TypeVar"],
+        operands: Sequence["ScalarExpression"],
+    ):
+        if bool(fn_name) + bool(attr_name) != 1:
+            raise ValueError("One of 'fn_name', 'attr_name' must be specified")
+        self.kind = kind
+        self.fn_name = fn_name
+        self.attr_name = attr_name
+        self.type_var = type_var
+        self.operands = operands
+
+    def expr(self) -> "ScalarExpression":
+        return ScalarExpression(scalar_fn=self)
+
+    def __repr__(self):
+        name = self.fn_name if self.fn_name else self.attr_name
+        return (
+            f"ScalarFn<{self.kind.name}.{name}>(type_var={self.type_var}, "
+            f"operands=[{', '.join(self.operands)}])"
+        )
 
 
 class ScalarArg:
-  """A type of ScalarExpression that references a named argument."""
+    """A type of ScalarExpression that references a named argument."""
 
-  def __init__(self, arg: str):
-    self.arg = arg
+    def __init__(self, arg: str):
+        self.arg = arg
 
-  def expr(self) -> "ScalarExpression":
-    return ScalarExpression(scalar_arg=self)
+    def expr(self) -> "ScalarExpression":
+        return ScalarExpression(scalar_arg=self)
 
-  def __repr__(self):
-    return f"(ScalarArg({self.arg})"
+    def __repr__(self):
+        return f"(ScalarArg({self.arg})"
 
 
 class ScalarConst:
-  """A type of ScalarExpression representing a constant."""
+    """A type of ScalarExpression representing a constant."""
 
-  def __init__(self, value: str):
-    self.value = value
+    def __init__(self, value: str):
+        self.value = value
 
-  def expr(self) -> "ScalarExpression":
-    return ScalarExpression(scalar_const=self)
+    def expr(self) -> "ScalarExpression":
+        return ScalarExpression(scalar_const=self)
 
-  def __repr__(self):
-    return f"(ScalarConst({self.value})"
+    def __repr__(self):
+        return f"(ScalarConst({self.value})"
 
 
 class ScalarIndex:
-  """A type of ScalarExpression accessing an iteration index."""
+    """A type of ScalarExpression accessing an iteration index."""
 
-  def __init__(self, dim: int):
-    self.dim = dim
+    def __init__(self, dim: int):
+        self.dim = dim
 
-  def expr(self) -> "ScalarExpression":
-    return ScalarExpression(scalar_index=self)
+    def expr(self) -> "ScalarExpression":
+        return ScalarExpression(scalar_index=self)
 
-  def __repr__(self):
-    return f"(ScalarIndex({self.dim})"
+    def __repr__(self):
+        return f"(ScalarIndex({self.dim})"
 
 
 class ScalarExpression(YAMLObject):
-  """An expression on scalar values.
-
-  Can be one of:
-    - ScalarFn
-    - ScalarArg
-    - ScalarConst
-    - ScalarIndex
-  """
-  yaml_tag = "!ScalarExpression"
-
-  def __init__(self,
-               scalar_fn: Optional[ScalarFn] = None,
-               scalar_arg: Optional[ScalarArg] = None,
-               scalar_const: Optional[ScalarConst] = None,
-               scalar_index: Optional[ScalarIndex] = None):
-    if (bool(scalar_fn) + bool(scalar_arg) + bool(scalar_const) +
-        bool(scalar_index)) != 1:
-      raise ValueError("One of 'scalar_fn', 'scalar_arg', 'scalar_const', or "
-                       "'scalar_index' must be specified")
-    self.scalar_fn = scalar_fn
-    self.scalar_arg = scalar_arg
-    self.scalar_const = scalar_const
-    self.scalar_index = scalar_index
-
-  def to_yaml_custom_dict(self):
-    if self.scalar_fn:
-      scalar_fn_dict = dict(kind=self.scalar_fn.kind.name.lower())
-      if self.scalar_fn.fn_name:
-        scalar_fn_dict["fn_name"] = self.scalar_fn.fn_name
-      if self.scalar_fn.attr_name:
-        scalar_fn_dict["attr_name"] = self.scalar_fn.attr_name
-      if self.scalar_fn.type_var:
-        scalar_fn_dict["type_var"] = self.scalar_fn.type_var.name
-      scalar_fn_dict["operands"] = list(self.scalar_fn.operands)
-      return dict(scalar_fn=scalar_fn_dict)
-    elif self.scalar_arg:
-      return dict(scalar_arg=self.scalar_arg.arg)
-    elif self.scalar_const:
-      return dict(scalar_const=self.scalar_const.value)
-    elif self.scalar_index:
-      return dict(scalar_index=self.scalar_index.dim)
-    else:
-      raise ValueError(f"Unexpected ScalarExpression type: {self}")
+    """An expression on scalar values.
+
+    Can be one of:
+      - ScalarFn
+      - ScalarArg
+      - ScalarConst
+      - ScalarIndex
+    """
+
+    yaml_tag = "!ScalarExpression"
+
+    def __init__(
+        self,
+        scalar_fn: Optional[ScalarFn] = None,
+        scalar_arg: Optional[ScalarArg] = None,
+        scalar_const: Optional[ScalarConst] = None,
+        scalar_index: Optional[ScalarIndex] = None,
+    ):
+        if (
+            bool(scalar_fn) + bool(scalar_arg) + bool(scalar_const) + bool(scalar_index)
+        ) != 1:
+            raise ValueError(
+                "One of 'scalar_fn', 'scalar_arg', 'scalar_const', or "
+                "'scalar_index' must be specified"
+            )
+        self.scalar_fn = scalar_fn
+        self.scalar_arg = scalar_arg
+        self.scalar_const = scalar_const
+        self.scalar_index = scalar_index
+
+    def to_yaml_custom_dict(self):
+        if self.scalar_fn:
+            scalar_fn_dict = dict(kind=self.scalar_fn.kind.name.lower())
+            if self.scalar_fn.fn_name:
+                scalar_fn_dict["fn_name"] = self.scalar_fn.fn_name
+            if self.scalar_fn.attr_name:
+                scalar_fn_dict["attr_name"] = self.scalar_fn.attr_name
+            if self.scalar_fn.type_var:
+                scalar_fn_dict["type_var"] = self.scalar_fn.type_var.name
+            scalar_fn_dict["operands"] = list(self.scalar_fn.operands)
+            return dict(scalar_fn=scalar_fn_dict)
+        elif self.scalar_arg:
+            return dict(scalar_arg=self.scalar_arg.arg)
+        elif self.scalar_const:
+            return dict(scalar_const=self.scalar_const.value)
+        elif self.scalar_index:
+            return dict(scalar_index=self.scalar_index.dim)
+        else:
+            raise ValueError(f"Unexpected ScalarExpression type: {self}")
 
 
 class ScalarAssign(YAMLObject):
-  """An assignment to a named argument (LHS of a comprehension)."""
-  yaml_tag = "!ScalarAssign"
+    """An assignment to a named argument (LHS of a comprehension)."""
+
+    yaml_tag = "!ScalarAssign"
 
-  def __init__(self, arg: str, value: ScalarExpression):
-    self.arg = arg
-    self.value = value
+    def __init__(self, arg: str, value: ScalarExpression):
+        self.arg = arg
+        self.value = value
 
-  def to_yaml_custom_dict(self):
-    return dict(arg=self.arg, value=self.value)
+    def to_yaml_custom_dict(self):
+        return dict(arg=self.arg, value=self.value)
 
-  def __repr__(self):
-    return f"ScalarAssign({self.arg}, {self.value})"
+    def __repr__(self):
+        return f"ScalarAssign({self.arg}, {self.value})"
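As a reading aid for the scalar-expression classes above: they compose into small trees in which exactly one field of ScalarExpression is populated, and to_yaml_custom_dict drives the YAML serialization. A minimal sketch, assuming the module is importable from the opdsl lang package as shown (for illustration only):

    # Import path assumed from the opdsl package layout; illustrative only.
    from mlir.dialects.linalg.opdsl.lang.scalar_expr import ScalarArg, ScalarAssign

    # A scalar-level assignment such as `O = I`: the RHS wraps a named argument
    # in a ScalarExpression, the LHS names the output argument.
    assign = ScalarAssign("O", ScalarArg("I").expr())
    assert assign.arg == "O"
    assert assign.value.scalar_arg.arg == "I"
    # to_yaml_custom_dict() is what the YAML dump of an op definition consumes.
    assert assign.to_yaml_custom_dict()["arg"] == "O"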
index ddac872..4f36029 100644
@@ -21,13 +21,11 @@ from typing import Dict
 __all__ = [
     "TypeVar",
     "TV",
-
     # Predefined types.
     "I32",
     "I64",
     "F32",
     "F64",
-
     # TypeVar aliases.
     "T",
     "U",
@@ -36,34 +34,34 @@ __all__ = [
 
 
 class TypeVar:
-  """A replaceable type variable.
+    """A replaceable type variable.
 
-  Type variables are uniqued by name.
-  """
-  ALL_TYPEVARS = dict()  # type: Dict[str, "TypeVar"]
+    Type variables are uniqued by name.
+    """
 
-  def __new__(cls, name: str):
-    existing = cls.ALL_TYPEVARS.get(name)
-    if existing is not None:
-      return existing
-    new = super().__new__(cls)
-    new.name = name
-    cls.ALL_TYPEVARS[name] = new
-    return new
+    ALL_TYPEVARS = dict()  # type: Dict[str, "TypeVar"]
 
-  def __repr__(self):
-    return f"TypeVar({self.name})"
+    def __new__(cls, name: str):
+        existing = cls.ALL_TYPEVARS.get(name)
+        if existing is not None:
+            return existing
+        new = super().__new__(cls)
+        new.name = name
+        cls.ALL_TYPEVARS[name] = new
+        return new
 
-  @classmethod
-  def create_expando(cls):
-    """Create an expando class that creates unique type vars on attr access."""
+    def __repr__(self):
+        return f"TypeVar({self.name})"
 
-    class ExpandoTypeVars:
+    @classmethod
+    def create_expando(cls):
+        """Create an expando class that creates unique type vars on attr access."""
 
-      def __getattr__(self, n):
-        return cls(n)
+        class ExpandoTypeVars:
+            def __getattr__(self, n):
+                return cls(n)
 
-    return ExpandoTypeVars()
+        return ExpandoTypeVars()
 
 
 # Expando access via TV.foo
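As a reading aid for the TypeVar class above: instances are uniqued by name through __new__, and the expando created by create_expando turns attribute access into TypeVar construction. A minimal sketch, assuming TV and TypeVar are importable from the opdsl lang package (for illustration only):

    # Import path assumed from the opdsl package layout; illustrative only.
    from mlir.dialects.linalg.opdsl.lang.types import TV, TypeVar

    # Constructing the same name twice yields the same object.
    assert TypeVar("T") is TypeVar("T")

    # Attribute access on the expando mints (or retrieves) the uniqued TypeVar.
    assert TV.AccumType is TypeVar("AccumType")
    print(TV.AccumType)  # TypeVar(AccumType)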
index 1945eea..1672656 100644
@@ -6,11 +6,12 @@
 import sys
 
 try:
-  import yaml
+    import yaml
 except ModuleNotFoundError as e:
-  raise ModuleNotFoundError(
-      f"This tool requires PyYAML but it was not installed. "
-      f"Recommend: {sys.executable} -m pip install PyYAML") from e
+    raise ModuleNotFoundError(
+        f"This tool requires PyYAML but it was not installed. "
+        f"Recommend: {sys.executable} -m pip install PyYAML"
+    ) from e
 
 __all__ = [
     "yaml_dump",
@@ -20,35 +21,33 @@ __all__ = [
 
 
 class YAMLObject(yaml.YAMLObject):
+    @classmethod
+    def to_yaml(cls, dumper, self):
+        """Default to a custom dictionary mapping."""
+        return dumper.represent_mapping(cls.yaml_tag, self.to_yaml_custom_dict())
 
-  @classmethod
-  def to_yaml(cls, dumper, self):
-    """Default to a custom dictionary mapping."""
-    return dumper.represent_mapping(cls.yaml_tag, self.to_yaml_custom_dict())
+    def to_yaml_custom_dict(self):
+        raise NotImplementedError()
 
-  def to_yaml_custom_dict(self):
-    raise NotImplementedError()
-
-  def as_linalg_yaml(self):
-    return yaml_dump(self)
+    def as_linalg_yaml(self):
+        return yaml_dump(self)
 
 
 def multiline_str_representer(dumper, data):
-  if len(data.splitlines()) > 1:
-    return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
-  else:
-    return dumper.represent_scalar('tag:yaml.org,2002:str', data)
+    if len(data.splitlines()) > 1:
+        return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")
+    else:
+        return dumper.represent_scalar("tag:yaml.org,2002:str", data)
 
 
 yaml.add_representer(str, multiline_str_representer)
 
 
 def yaml_dump(data, sort_keys=False, **kwargs):
-  return yaml.dump(data, sort_keys=sort_keys, **kwargs)
+    return yaml.dump(data, sort_keys=sort_keys, **kwargs)
 
 
 def yaml_dump_all(data, sort_keys=False, explicit_start=True, **kwargs):
-  return yaml.dump_all(data,
-                       sort_keys=sort_keys,
-                       explicit_start=explicit_start,
-                       **kwargs)
+    return yaml.dump_all(
+        data, sort_keys=sort_keys, explicit_start=explicit_start, **kwargs
+    )
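As a reading aid for the YAML helper above: the registered str representer switches multi-line strings to block style while leaving single-line strings as plain scalars. A standalone sketch using PyYAML directly, mirroring that representer (for illustration only):

    import yaml

    def multiline_str_representer(dumper, data):
        # More than one line -> block scalar ('|'); otherwise a plain scalar.
        if len(data.splitlines()) > 1:
            return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")
        return dumper.represent_scalar("tag:yaml.org,2002:str", data)

    yaml.add_representer(str, multiline_str_representer)

    print(yaml.dump({"doc": "line one\nline two\n"}, sort_keys=False))
    # doc: |
    #   line one
    #   line two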
index 9c96868..bac22a2 100644
@@ -7,99 +7,113 @@ Batch = S.Batch
 
 
 @linalg_structured_op
-def copy(I=TensorDef(T1),
-         O=TensorDef(U, output=True),
-         cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
-  """Copies the tensor elementwise.
+def copy(
+    I=TensorDef(T1),
+    O=TensorDef(U, output=True),
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+    """Copies the tensor elementwise.
 
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  O[None] = cast(U, I[None])
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    O[None] = cast(U, I[None])
 
 
 @linalg_structured_op
-def elemwise_unary(I=TensorDef(T1),
-                   O=TensorDef(U, output=True),
-                   fun=UnaryFnAttrDef(default=UnaryFn.exp),
-                   cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
-  """Applies the unary function fun elementwise.
+def elemwise_unary(
+    I=TensorDef(T1),
+    O=TensorDef(U, output=True),
+    fun=UnaryFnAttrDef(default=UnaryFn.exp),
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+    """Applies the unary function fun elementwise.
 
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  O[None] = fun(cast(U, I[None]))
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    O[None] = fun(cast(U, I[None]))
 
 
 @linalg_structured_op
-def elemwise_binary(lhs=TensorDef(T1),
-                    rhs=TensorDef(T2),
-                    O=TensorDef(U, output=True),
-                    fun=BinaryFnAttrDef(default=BinaryFn.add),
-                    cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
-  """Applies the binary function fun elementwise.
+def elemwise_binary(
+    lhs=TensorDef(T1),
+    rhs=TensorDef(T2),
+    O=TensorDef(U, output=True),
+    fun=BinaryFnAttrDef(default=BinaryFn.add),
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+    """Applies the binary function fun elementwise.
 
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None]))
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None]))
 
 
 @linalg_structured_op
-def matmul(A=TensorDef(T1, S.M, S.K),
-           B=TensorDef(T2, S.K, S.N),
-           C=TensorDef(U, S.M, S.N, output=True),
-           cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
-  """Performs a matrix multiplication of two 2D inputs.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  domain(D.m, D.n, D.k)
-  implements(ContractionOpInterface)
-  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+def matmul(
+    A=TensorDef(T1, S.M, S.K),
+    B=TensorDef(T2, S.K, S.N),
+    C=TensorDef(U, S.M, S.N, output=True),
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+    """Performs a matrix multiplication of two 2D inputs.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    domain(D.m, D.n, D.k)
+    implements(ContractionOpInterface)
+    C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
 
 
 @linalg_structured_op
-def matmul_unsigned(A=TensorDef(T1, S.M, S.K),
-                    B=TensorDef(T2, S.K, S.N),
-                    C=TensorDef(U, S.M, S.N, output=True)):
-  """Performs an unsigned matrix multiplication of two 2D inputs.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  domain(D.m, D.n, D.k)
-  implements(ContractionOpInterface)
-  C[D.m, D.n] += TypeFn.cast_unsigned(U, A[D.m, D.k]) * TypeFn.cast_unsigned(
-      U, B[D.k, D.n])
+def matmul_unsigned(
+    A=TensorDef(T1, S.M, S.K),
+    B=TensorDef(T2, S.K, S.N),
+    C=TensorDef(U, S.M, S.N, output=True),
+):
+    """Performs an unsigned matrix multiplication of two 2D inputs.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    domain(D.m, D.n, D.k)
+    implements(ContractionOpInterface)
+    C[D.m, D.n] += TypeFn.cast_unsigned(U, A[D.m, D.k]) * TypeFn.cast_unsigned(
+        U, B[D.k, D.n]
+    )
 
 
 @linalg_structured_op
-def quantized_matmul(A=TensorDef(T1, S.M, S.K),
-                     B=TensorDef(T2, S.K, S.N),
-                     AZp=ScalarDef(I32),
-                     BZp=ScalarDef(I32),
-                     C=TensorDef(U, S.M, S.N, output=True)):
-  """Performs a matrix multiplication of two 2D inputs.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output. The quantized variant
-  includes zero-point adjustments for the left and right operands of the
-  matmul.
-  """
-  domain(D.m, D.n, D.k)
-  C[D.m,
-    D.n] += (TypeFn.cast_signed(U, A[D.m, D.k]) -
-             TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed(U, B[D.k, D.n]) -
-                                            TypeFn.cast_signed(U, BZp))
+def quantized_matmul(
+    A=TensorDef(T1, S.M, S.K),
+    B=TensorDef(T2, S.K, S.N),
+    AZp=ScalarDef(I32),
+    BZp=ScalarDef(I32),
+    C=TensorDef(U, S.M, S.N, output=True),
+):
+    """Performs a matrix multiplication of two 2D inputs.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. The quantized variant
+    includes zero-point adjustments for the left and right operands of the
+    matmul.
+    """
+    domain(D.m, D.n, D.k)
+    C[D.m, D.n] += (TypeFn.cast_signed(U, A[D.m, D.k]) - TypeFn.cast_signed(U, AZp)) * (
+        TypeFn.cast_signed(U, B[D.k, D.n]) - TypeFn.cast_signed(U, BZp)
+    )
 
 
 @linalg_structured_op
-def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
-          rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0),
-          accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, output=True)):
-  """Performs a matrix-matrix-transpose multiplication of two 4D inputs.
+def mmt4d(
+    lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
+    rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0),
+    accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, output=True),
+):
+    """Performs a matrix-matrix-transpose multiplication of two 4D inputs.
 
     Differences from linalg.matmul:
     * The right hand side is transposed, whence the 't' in 'mmt'.
@@ -108,1132 +122,1201 @@ def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
       whence the 2+2=4 dimensions. The inner tile dimensions are identified with
       '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads
       as: MxK tiles, each of shape M0xK0.
-  """
-  domain(D.m, D.n, D.k, D.m0, D.n0, D.k0)
-  implements(ContractionOpInterface)
-  accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast_signed(
-      TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * TypeFn.cast_signed(
-          TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])
+    """
+    domain(D.m, D.n, D.k, D.m0, D.n0, D.k0)
+    implements(ContractionOpInterface)
+    accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast_signed(
+        TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]
+    ) * TypeFn.cast_signed(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])
 
 
 @linalg_structured_op
-def batch_matmul(A=TensorDef(T1, Batch, S.M, S.K),
-                 B=TensorDef(T2, Batch, S.K, S.N),
-                 C=TensorDef(U, Batch, S.M, S.N, output=True)):
-  """Performs a batched matrix multiplication of two 3D inputs.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  domain(D.b, D.m, D.n, D.k)
-  implements(ContractionOpInterface)
-  C[D.b, D.m,
-    D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
-        U, B[D.b, D.k, D.n])
+def batch_matmul(
+    A=TensorDef(T1, Batch, S.M, S.K),
+    B=TensorDef(T2, Batch, S.K, S.N),
+    C=TensorDef(U, Batch, S.M, S.N, output=True),
+):
+    """Performs a batched matrix multiplication of two 3D inputs.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    domain(D.b, D.m, D.n, D.k)
+    implements(ContractionOpInterface)
+    C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
+        U, B[D.b, D.k, D.n]
+    )
 
 
 @linalg_structured_op
-def quantized_batch_matmul(A=TensorDef(T1, Batch, S.M, S.K),
-                           B=TensorDef(T2, Batch, S.K, S.N),
-                           AZp=ScalarDef(I32),
-                           BZp=ScalarDef(I32),
-                           C=TensorDef(U, Batch, S.M, S.N, output=True)):
-  """Performs a batched matrix multiplication of two 3D inputs.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output. The quantized variant
-  includes zero-point adjustments for the left and right operands of the
-  matmul.
-  """
-  domain(D.b, D.m, D.n, D.k)
-  C[D.b, D.m, D.n] += (TypeFn.cast_signed(U, A[D.b, D.m, D.k]) -
-                       TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed(
-                           U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp))
+def quantized_batch_matmul(
+    A=TensorDef(T1, Batch, S.M, S.K),
+    B=TensorDef(T2, Batch, S.K, S.N),
+    AZp=ScalarDef(I32),
+    BZp=ScalarDef(I32),
+    C=TensorDef(U, Batch, S.M, S.N, output=True),
+):
+    """Performs a batched matrix multiplication of two 3D inputs.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. The quantized variant
+    includes zero-point adjustments for the left and right operands of the
+    matmul.
+    """
+    domain(D.b, D.m, D.n, D.k)
+    C[D.b, D.m, D.n] += (
+        TypeFn.cast_signed(U, A[D.b, D.m, D.k]) - TypeFn.cast_signed(U, AZp)
+    ) * (TypeFn.cast_signed(U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp))
 
 
 @linalg_structured_op
-def batch_reduce_matmul(A=TensorDef(T1, Batch, S.M, S.K),
-                        B=TensorDef(T2, Batch, S.K, S.N),
-                        C=TensorDef(U, S.M, S.N, output=True)):
-  """Performs a batch-reduce matrix multiplication of two 3D inputs.
-  The partial multiplication results are reduced into a 2D output.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  domain(D.b, D.m, D.n, D.k)
-  implements(ContractionOpInterface)
-  C[D.m, D.n] += TypeFn.cast_signed(
-      U, A[D.b, D.m, D.k] * TypeFn.cast_signed(U, B[D.b, D.k, D.n]))
+def batch_reduce_matmul(
+    A=TensorDef(T1, Batch, S.M, S.K),
+    B=TensorDef(T2, Batch, S.K, S.N),
+    C=TensorDef(U, S.M, S.N, output=True),
+):
+    """Performs a batch-reduce matrix multiplication of two 3D inputs.
+    The partial multiplication results are reduced into a 2D output.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    domain(D.b, D.m, D.n, D.k)
+    implements(ContractionOpInterface)
+    C[D.m, D.n] += TypeFn.cast_signed(
+        U, A[D.b, D.m, D.k] * TypeFn.cast_signed(U, B[D.b, D.k, D.n])
+    )
 
 
 @linalg_structured_op
-def matvec(A=TensorDef(T1, S.M, S.N),
-           y=TensorDef(T2, S.N),
-           x=TensorDef(U, S.M, output=True)):
-  """Performs a matrix-vector multiplication.
+def matvec(
+    A=TensorDef(T1, S.M, S.N), y=TensorDef(T2, S.N), x=TensorDef(U, S.M, output=True)
+):
+    """Performs a matrix-vector multiplication.
 
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  domain(D.m, D.n)
-  implements(ContractionOpInterface)
-  x[D.m] += TypeFn.cast_signed(U, A[D.m, D.n]) * TypeFn.cast_signed(U, y[D.n])
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    domain(D.m, D.n)
+    implements(ContractionOpInterface)
+    x[D.m] += TypeFn.cast_signed(U, A[D.m, D.n]) * TypeFn.cast_signed(U, y[D.n])
 
 
 @linalg_structured_op
-def vecmat(y=TensorDef(T1, S.M),
-           A=TensorDef(T2, S.M, S.N),
-           x=TensorDef(U, S.N, output=True)):
-  """Performs a vector-matrix multiplication.
+def vecmat(
+    y=TensorDef(T1, S.M), A=TensorDef(T2, S.M, S.N), x=TensorDef(U, S.N, output=True)
+):
+    """Performs a vector-matrix multiplication.
 
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  domain(D.n, D.m)
-  implements(ContractionOpInterface)
-  x[D.n] += TypeFn.cast_signed(U, y[D.m]) * TypeFn.cast_signed(U, A[D.m, D.n])
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    domain(D.n, D.m)
+    implements(ContractionOpInterface)
+    x[D.n] += TypeFn.cast_signed(U, y[D.m]) * TypeFn.cast_signed(U, A[D.m, D.n])
 
 
 @linalg_structured_op
-def batch_matvec(A=TensorDef(T1, Batch, S.M, S.K),
-                 B=TensorDef(T2, Batch, S.K),
-                 C=TensorDef(U, Batch, S.M, output=True)):
-  """Performs a batched matrix-vector multiplication.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  domain(D.b, D.m, D.k)
-  implements(ContractionOpInterface)
-  C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
-      U, B[D.b, D.k])
+def batch_matvec(
+    A=TensorDef(T1, Batch, S.M, S.K),
+    B=TensorDef(T2, Batch, S.K),
+    C=TensorDef(U, Batch, S.M, output=True),
+):
+    """Performs a batched matrix-vector multiplication.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    domain(D.b, D.m, D.k)
+    implements(ContractionOpInterface)
+    C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
+        U, B[D.b, D.k]
+    )
 
 
 @linalg_structured_op
-def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U,
-                                                                output=True)):
-  """Performs a dot product of two vectors to a scalar result.
+def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)):
+    """Performs a dot product of two vectors to a scalar result.
 
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ContractionOpInterface)
-  C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m])
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ContractionOpInterface)
+    C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m])
 
 
 @linalg_structured_op
-def conv_1d(I=TensorDef(T1, S.OW + S.KW),
-            K=TensorDef(T2, S.KW),
-            O=TensorDef(U, S.OW, output=True)):
-  """Performs 1-D convolution with no channels.
+def conv_1d(
+    I=TensorDef(T1, S.OW + S.KW),
+    K=TensorDef(T2, S.KW),
+    O=TensorDef(U, S.OW, output=True),
+):
+    """Performs 1-D convolution with no channels.
 
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.ow, D.kw)
-  O[D.ow] += TypeFn.cast_signed(U, I[D.ow + D.kw]) * TypeFn.cast_signed(
-      U, K[D.kw])
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.ow, D.kw)
+    O[D.ow] += TypeFn.cast_signed(U, I[D.ow + D.kw]) * TypeFn.cast_signed(U, K[D.kw])
 
 
 @linalg_structured_op
-def conv_2d(I=TensorDef(T1, S.OH + S.KH, S.OW + S.KW),
-            K=TensorDef(T2, S.KH, S.KW),
-            O=TensorDef(U, S.OH, S.OW, output=True)):
-  """Performs 2-D convolution with no channels.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.oh, D.ow, D.kh, D.kw)
-  O[D.oh, D.ow] += TypeFn.cast_signed(
-      U, I[D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast_signed(U, K[D.kh, D.kw])
+def conv_2d(
+    I=TensorDef(T1, S.OH + S.KH, S.OW + S.KW),
+    K=TensorDef(T2, S.KH, S.KW),
+    O=TensorDef(U, S.OH, S.OW, output=True),
+):
+    """Performs 2-D convolution with no channels.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.oh, D.ow, D.kh, D.kw)
+    O[D.oh, D.ow] += TypeFn.cast_signed(
+        U, I[D.oh + D.kh, D.ow + D.kw]
+    ) * TypeFn.cast_signed(U, K[D.kh, D.kw])
 
 
 @linalg_structured_op
-def conv_3d(I=TensorDef(T1, S.OD + S.KD, S.OH + S.KH, S.OW + S.KW),
-            K=TensorDef(T2, S.KD, S.KH, S.KW),
-            O=TensorDef(U, S.OD, S.OH, S.OW, output=True)):
-  """Performs 3-D convolution with no channels.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw)
-  O[D.od, D.oh, D.ow] += TypeFn.cast_signed(
-      U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast_signed(
-          U, K[D.kd, D.kh, D.kw])
+def conv_3d(
+    I=TensorDef(T1, S.OD + S.KD, S.OH + S.KH, S.OW + S.KW),
+    K=TensorDef(T2, S.KD, S.KH, S.KW),
+    O=TensorDef(U, S.OD, S.OH, S.OW, output=True),
+):
+    """Performs 3-D convolution with no channels.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw)
+    O[D.od, D.oh, D.ow] += TypeFn.cast_signed(
+        U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw]
+    ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw])
 
 
 @linalg_structured_op
-def conv_1d_nwc_wcf(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
-                    K=TensorDef(T2, S.KW, S.C, S.F),
-                    O=TensorDef(U, S.N, S.OW, S.F, output=True),
-                    strides=IndexAttrDef(S.SW, default=[1]),
-                    dilations=IndexAttrDef(S.DW, default=[1])):
-  """Performs 1-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.ow, D.f, D.kw, D.c)
-  O[D.n, D.ow, D.f] += TypeFn.cast_signed(
-      U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast_signed(
-          U, K[D.kw, D.c, D.f])
+def conv_1d_nwc_wcf(
+    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KW, S.C, S.F),
+    O=TensorDef(U, S.N, S.OW, S.F, output=True),
+    strides=IndexAttrDef(S.SW, default=[1]),
+    dilations=IndexAttrDef(S.DW, default=[1]),
+):
+    """Performs 1-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.ow, D.f, D.kw, D.c)
+    O[D.n, D.ow, D.f] += TypeFn.cast_signed(
+        U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]
+    ) * TypeFn.cast_signed(U, K[D.kw, D.c, D.f])
 
 
 @linalg_structured_op
-def conv_1d_ncw_fcw(I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW),
-                    K=TensorDef(T2, S.F, S.C, S.KW),
-                    O=TensorDef(U, S.N, S.F, S.OW, output=True),
-                    strides=IndexAttrDef(S.SW, default=[1]),
-                    dilations=IndexAttrDef(S.DW, default=[1])):
-  """Performs 1-D convolution.
-
-  Layout:
-    * Input: NCW.
-    * Kernel: FCW.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.f, D.ow, D.c, D.kw)
-  O[D.n, D.f, D.ow] += TypeFn.cast_signed(
-      U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed(
-          U, K[D.f, D.c, D.kw])
+def conv_1d_ncw_fcw(
+    I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW),
+    K=TensorDef(T2, S.F, S.C, S.KW),
+    O=TensorDef(U, S.N, S.F, S.OW, output=True),
+    strides=IndexAttrDef(S.SW, default=[1]),
+    dilations=IndexAttrDef(S.DW, default=[1]),
+):
+    """Performs 1-D convolution.
+
+    Layout:
+      * Input: NCW.
+      * Kernel: FCW.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.f, D.ow, D.c, D.kw)
+    O[D.n, D.f, D.ow] += TypeFn.cast_signed(
+        U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW]
+    ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kw])
 
 
 @linalg_structured_op
-def conv_2d_nhwc_hwcf(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
-                                  S.OW * S.SW + S.KW * S.DW, S.C),
-                      K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
-                      O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
-                      strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                      dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  """Performs 2-D convolution.
-
-  Layout:
-    * Input: NHWC.
-    * Kernel: HWCF.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed(
-      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
-           D.c]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f])
+def conv_2d_nhwc_hwcf(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs 2-D convolution.
+
+    Layout:
+      * Input: NHWC.
+      * Kernel: HWCF.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
+    O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed(
+        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+    ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f])
 
 
 @linalg_structured_op
-def conv_2d_nhwc_fhwc(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
-                                  S.OW * S.SW + S.KW * S.DW, S.C),
-                      K=TensorDef(T2, S.F, S.KH, S.KW, S.C),
-                      O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
-                      strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                      dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  """Performs 2-D convolution.
-
-  Layout:
-    * Input: NHWC.
-    * Kernel: FHWC.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed(
-      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
-           D.c]) * TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c])
+def conv_2d_nhwc_fhwc(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.F, S.KH, S.KW, S.C),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs 2-D convolution.
+
+    Layout:
+      * Input: NHWC.
+      * Kernel: FHWC.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
+    O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed(
+        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+    ) * TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c])
 
 
 @linalg_structured_op
-def conv_2d_nhwc_hwcf_q(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
-                                    S.OW * S.SW + S.KW * S.DW, S.C),
-                        K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
-                        IZp=ScalarDef(I32),
-                        KZp=ScalarDef(I32),
-                        O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
-                        strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                        dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  """Performs 2-D convolution with zero point offsets.
-
-  Layout:
-    * Input: NHWC.
-    * Kernel: HWCF.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output. This includes the zero
-  point offsets common to quantized operations.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow,
-    D.f] += (TypeFn.cast_signed(
-        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) -
-             TypeFn.cast_signed(U, IZp)) * (TypeFn.cast_signed(
-                 U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast_signed(U, KZp))
+def conv_2d_nhwc_hwcf_q(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
+    IZp=ScalarDef(I32),
+    KZp=ScalarDef(I32),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs 2-D convolution with zero point offsets.
+
+    Layout:
+      * Input: NHWC.
+      * Kernel: HWCF.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. This includes the zero
+    point offsets common to quantized operations.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
+    O[D.n, D.oh, D.ow, D.f] += (
+        TypeFn.cast_signed(
+            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+        )
+        - TypeFn.cast_signed(U, IZp)
+    ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast_signed(U, KZp))
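
A hedged sketch of what the quantized variant computes, assuming an i32 accumulator; as above, the names are invented and this is not part of the commit. The zero points IZp and KZp are cast to the accumulator type and subtracted before the multiply, exactly as the comprehension above spells out.

    import numpy as np

    def conv_2d_nhwc_hwcf_q_ref(ifm, kernel, izp, kzp, out,
                                strides=(1, 1), dilations=(1, 1)):
        sh, sw = strides
        dh, dw = dilations
        N, OH, OW, F = out.shape
        KH, KW, C, _ = kernel.shape
        for n, oh, ow, f, kh, kw, c in np.ndindex(N, OH, OW, F, KH, KW, C):
            # Cast operands and zero points to the accumulator type, subtract the
            # offsets, then multiply-accumulate.
            lhs = np.int32(ifm[n, oh * sh + kh * dh, ow * sw + kw * dw, c]) - np.int32(izp)
            rhs = np.int32(kernel[kh, kw, c, f]) - np.int32(kzp)
            out[n, oh, ow, f] += lhs * rhs
        return out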
 
 
 @linalg_structured_op
-def conv_2d_nchw_fchw(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH,
-                                  S.OW * S.SW + S.KW * S.DW),
-                      K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
-                      O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True),
-                      strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                      dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  """Performs 2-D convolution.
-
-  Layout:
-    * Input: NCHW.
-    * Kernel: FCHW.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
-  O[D.n, D.f, D.oh, D.ow] += TypeFn.cast_signed(
-      U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW +
-           D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw])
+def conv_2d_nchw_fchw(
+    I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
+    K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
+    O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs 2-D convolution.
+
+    Layout:
+      * Input: NCHW.
+      * Kernel: FCHW.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
+    O[D.n, D.f, D.oh, D.ow] += TypeFn.cast_signed(
+        U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
+    ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw])
 
 
 @linalg_structured_op
-def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C,
-                                    S.OH * S.SH + S.KH * S.DH,
-                                    S.OW * S.SW + S.KW * S.DW),
-                        K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW),
-                        O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True),
-                        strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                        dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  """Performs 2-D grouped convolution.
-
-  Layout:
-    * Input: NGCHW.
-    * Kernel: FGCHW.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw)
-  O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed(
-      U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW +
-           D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw])
+def conv_2d_ngchw_fgchw(
+    I=TensorDef(
+        T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW
+    ),
+    K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW),
+    O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs 2-D grouped convolution.
+
+    Layout:
+      * Input: NGCHW.
+      * Kernel: FGCHW.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw)
+    O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed(
+        U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
+    ) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw])
 
 
 @linalg_structured_op
-def conv_3d_ndhwc_dhwcf(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
-                                    S.OH * S.SH + S.KH * S.DH,
-                                    S.OW * S.SW + S.KW * S.DW, S.C),
-                        K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F),
-                        O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True),
-                        strides=IndexAttrDef(S.SD,
-                                             S.SH,
-                                             S.SW,
-                                             default=[1, 1, 1]),
-                        dilations=IndexAttrDef(S.DD,
-                                               S.DH,
-                                               S.DW,
-                                               default=[1, 1, 1])):
-  """Performs 3-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
-  O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast_signed(
-      U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
-           D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast_signed(
-               U, K[D.kd, D.kh, D.kw, D.c, D.f])
+def conv_3d_ndhwc_dhwcf(
+    I=TensorDef(
+        T1,
+        S.N,
+        S.OD * S.SD + S.KD * S.DD,
+        S.OH * S.SH + S.KH * S.DH,
+        S.OW * S.SW + S.KW * S.DW,
+        S.C,
+    ),
+    K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F),
+    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True),
+    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+    """Performs 3-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
+    O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast_signed(
+        U,
+        I[
+            D.n,
+            D.od * S.SD + D.kd * S.DD,
+            D.oh * S.SH + D.kh * S.DH,
+            D.ow * S.SW + D.kw * S.DW,
+            D.c,
+        ],
+    ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f])
 
 
 @linalg_structured_op
-def conv_3d_ndhwc_dhwcf_q(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
-                                      S.OH * S.SH + S.KH * S.DH,
-                                      S.OW * S.SW + S.KW * S.DW, S.C),
-                          K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F),
-                          IZp=ScalarDef(I32),
-                          KZp=ScalarDef(I32),
-                          O=TensorDef(U,
-                                      S.N,
-                                      S.OD,
-                                      S.OH,
-                                      S.OW,
-                                      S.F,
-                                      output=True),
-                          strides=IndexAttrDef(S.SD,
-                                               S.SH,
-                                               S.SW,
-                                               default=[1, 1, 1]),
-                          dilations=IndexAttrDef(S.DD,
-                                                 S.DH,
-                                                 S.DW,
-                                                 default=[1, 1, 1])):
-  """Performs 3-D convolution with zero point offsets.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output. This includes the zero
-  point offsets common to quantized operations.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
-  O[D.n, D.od, D.oh, D.ow, D.f] += (TypeFn.cast_signed(
-      U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
-           D.ow * S.SW + D.kw * S.DW, D.c]) - TypeFn.cast_signed(U, IZp)) * (
-               TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f]) -
-               TypeFn.cast_signed(U, KZp))
+def conv_3d_ndhwc_dhwcf_q(
+    I=TensorDef(
+        T1,
+        S.N,
+        S.OD * S.SD + S.KD * S.DD,
+        S.OH * S.SH + S.KH * S.DH,
+        S.OW * S.SW + S.KW * S.DW,
+        S.C,
+    ),
+    K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F),
+    IZp=ScalarDef(I32),
+    KZp=ScalarDef(I32),
+    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True),
+    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+    """Performs 3-D convolution with zero point offsets.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. This includes the zero
+    point offsets common to quantized operations.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
+    O[D.n, D.od, D.oh, D.ow, D.f] += (
+        TypeFn.cast_signed(
+            U,
+            I[
+                D.n,
+                D.od * S.SD + D.kd * S.DD,
+                D.oh * S.SH + D.kh * S.DH,
+                D.ow * S.SW + D.kw * S.DW,
+                D.c,
+            ],
+        )
+        - TypeFn.cast_signed(U, IZp)
+    ) * (
+        TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f])
+        - TypeFn.cast_signed(U, KZp)
+    )
 
 
 @linalg_structured_op
-def conv_3d_ncdhw_fcdhw(I=TensorDef(T1, S.N, S.C, S.OD * S.SD + S.KD * S.DD,
-                                    S.OH * S.SH + S.KH * S.DH,
-                                    S.OW * S.SW + S.KW * S.DW),
-                        K=TensorDef(T2, S.F, S.C, S.KD, S.KH, S.KW),
-                        O=TensorDef(U, S.N, S.F, S.OD, S.OH, S.OW, output=True),
-                        strides=IndexAttrDef(S.SD,
-                                             S.SH,
-                                             S.SW,
-                                             default=[1, 1, 1]),
-                        dilations=IndexAttrDef(S.DD,
-                                               S.DH,
-                                               S.DW,
-                                               default=[1, 1, 1])):
-  """Performs 3-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
-  O[D.n, D.f, D.od, D.oh, D.ow] += TypeFn.cast_signed(
-      U, I[D.n, D.c, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
-           D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed(
-               U, K[D.f, D.c, D.kd, D.kh, D.kw])
+def conv_3d_ncdhw_fcdhw(
+    I=TensorDef(
+        T1,
+        S.N,
+        S.C,
+        S.OD * S.SD + S.KD * S.DD,
+        S.OH * S.SH + S.KH * S.DH,
+        S.OW * S.SW + S.KW * S.DW,
+    ),
+    K=TensorDef(T2, S.F, S.C, S.KD, S.KH, S.KW),
+    O=TensorDef(U, S.N, S.F, S.OD, S.OH, S.OW, output=True),
+    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+    """Performs 3-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
+    O[D.n, D.f, D.od, D.oh, D.ow] += TypeFn.cast_signed(
+        U,
+        I[
+            D.n,
+            D.c,
+            D.od * S.SD + D.kd * S.DD,
+            D.oh * S.SH + D.kh * S.DH,
+            D.ow * S.SW + D.kw * S.DW,
+        ],
+    ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kd, D.kh, D.kw])
 
 
 @linalg_structured_op
-def depthwise_conv_1d_nwc_wc(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW,
-                                         S.IC),
-                             K=TensorDef(T2, S.KW, S.IC),
-                             O=TensorDef(U, S.N, S.OW, S.IC, output=True),
-                             strides=IndexAttrDef(S.SW, default=[1]),
-                             dilations=IndexAttrDef(S.DW, default=[1])):
-  """Performs depth-wise 1-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output. Multiplier is set to 1
-  which is a special case for most depthwise convolutions.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.ow, D.ic, D.kw)
-  O[D.n, D.ow, D.ic] += \
-      TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \
-      TypeFn.cast_signed(U, K[D.kw, D.ic])
+def depthwise_conv_1d_nwc_wc(
+    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC),
+    K=TensorDef(T2, S.KW, S.IC),
+    O=TensorDef(U, S.N, S.OW, S.IC, output=True),
+    strides=IndexAttrDef(S.SW, default=[1]),
+    dilations=IndexAttrDef(S.DW, default=[1]),
+):
+    """Performs depth-wise 1-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. Multiplier is set to 1
+    which is a special case for most depthwise convolutions.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.ow, D.ic, D.kw)
+    O[D.n, D.ow, D.ic] += TypeFn.cast_signed(
+        U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]
+    ) * TypeFn.cast_signed(U, K[D.kw, D.ic])
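
As a sketch of what "multiplier is set to 1" means here: there is no channel-multiplier dimension, so input channel ic feeds only output channel ic. The depthwise_conv_1d_nwc_wcm op later in this file adds the S.CM dimension and produces an (N, OW, IC, CM) output instead. The names below are invented; this is not part of the commit.

    import numpy as np

    def depthwise_conv_1d_nwc_wc_ref(ifm, kernel, out, stride=1, dilation=1):
        # ifm: (N, IW, IC), kernel: (KW, IC), out: (N, OW, IC); channel
        # multiplier == 1, so each input channel maps to exactly one output channel.
        N, OW, IC = out.shape
        KW, _ = kernel.shape
        for n, ow, ic, kw in np.ndindex(N, OW, IC, KW):
            out[n, ow, ic] += out.dtype.type(
                ifm[n, ow * stride + kw * dilation, ic]
            ) * out.dtype.type(kernel[kw, ic])
        return out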
 
 
 @linalg_structured_op
-def depthwise_conv_1d_ncw_cw(I=TensorDef(T1, S.N, S.IC,
-                                         S.OW * S.SW + S.KW * S.DW),
-                             K=TensorDef(T2, S.IC, S.KW),
-                             O=TensorDef(U, S.N, S.IC, S.OW, output=True),
-                             strides=IndexAttrDef(S.SW, default=[1]),
-                             dilations=IndexAttrDef(S.DW, default=[1])):
-  """Performs depth-wise 1-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output. Multiplier is set to 1
-  which is a special case for most depthwise convolutions.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.ow, D.ic, D.kw)
-  O[D.n, D.ic, D.ow] += \
-      TypeFn.cast_signed(U, I[D.n, D.ic, D.ow * S.SW + D.kw * S.DW]) * \
-      TypeFn.cast_signed(U, K[D.ic, D.kw])
+def depthwise_conv_1d_ncw_cw(
+    I=TensorDef(T1, S.N, S.IC, S.OW * S.SW + S.KW * S.DW),
+    K=TensorDef(T2, S.IC, S.KW),
+    O=TensorDef(U, S.N, S.IC, S.OW, output=True),
+    strides=IndexAttrDef(S.SW, default=[1]),
+    dilations=IndexAttrDef(S.DW, default=[1]),
+):
+    """Performs depth-wise 1-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. Multiplier is set to 1
+    which is a special case for most depthwise convolutions.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.ow, D.ic, D.kw)
+    O[D.n, D.ic, D.ow] += TypeFn.cast_signed(
+        U, I[D.n, D.ic, D.ow * S.SW + D.kw * S.DW]
+    ) * TypeFn.cast_signed(U, K[D.ic, D.kw])
 
 
 @linalg_structured_op
-def depthwise_conv_1d_nwc_wcm(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW,
-                                          S.IC),
-                              K=TensorDef(T2, S.KW, S.IC, S.CM),
-                              O=TensorDef(U, S.N, S.OW, S.IC, S.CM,
-                                          output=True),
-                              strides=IndexAttrDef(S.SW, default=[1]),
-                              dilations=IndexAttrDef(S.DW, default=[1])):
-  """Performs depth-wise 1-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.ow, D.ic, D.cm, D.kw)
-  O[D.n, D.ow, D.ic, D.cm] += \
-      TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \
-      TypeFn.cast_signed(U, K[D.kw, D.ic, D.cm])
+def depthwise_conv_1d_nwc_wcm(
+    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC),
+    K=TensorDef(T2, S.KW, S.IC, S.CM),
+    O=TensorDef(U, S.N, S.OW, S.IC, S.CM, output=True),
+    strides=IndexAttrDef(S.SW, default=[1]),
+    dilations=IndexAttrDef(S.DW, default=[1]),
+):
+    """Performs depth-wise 1-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.ow, D.ic, D.cm, D.kw)
+    O[D.n, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
+        U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]
+    ) * TypeFn.cast_signed(U, K[D.kw, D.ic, D.cm])
 
 
 @linalg_structured_op
-def depthwise_conv_2d_nhwc_hwc(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
-                                           S.OW * S.SW + S.KW * S.DW, S.IC),
-                               K=TensorDef(T2, S.KH, S.KW, S.IC),
-                               O=TensorDef(U,
-                                           S.N,
-                                           S.OH,
-                                           S.OW,
-                                           S.IC,
-                                           output=True),
-                               strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                               dilations=IndexAttrDef(S.DH,
-                                                      S.DW,
-                                                      default=[1, 1])):
-  """Performs depth-wise 2-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output. Multiplier is set to 1
-  which is a special case for most depthwise convolutions.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
-  O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast_signed(
-      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
-           D.ic]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic])
+def depthwise_conv_2d_nhwc_hwc(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC),
+    K=TensorDef(T2, S.KH, S.KW, S.IC),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs depth-wise 2-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. Multiplier is set to 1
+    which is a special case for most depthwise convolutions.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
+    O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast_signed(
+        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]
+    ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic])
 
 
 @linalg_structured_op
-def depthwise_conv_2d_nchw_chw(I=TensorDef(T1, S.N, S.IC,
-                                           S.OH * S.SH + S.KH * S.DH,
-                                           S.OW * S.SW + S.KW * S.DW),
-                               K=TensorDef(T2, S.IC, S.KH, S.KW),
-                               O=TensorDef(U,
-                                           S.N,
-                                           S.IC,
-                                           S.OH,
-                                           S.OW,
-                                           output=True),
-                               strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                               dilations=IndexAttrDef(S.DH,
-                                                      S.DW,
-                                                      default=[1, 1])):
-  """Performs depth-wise 2-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output. Multiplier is set to 1
-  which is a special case for most depthwise convolutions.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
-  O[D.n, D.ic, D.oh, D.ow] += TypeFn.cast_signed(
-      U, I[D.n, D.ic, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW +
-           D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.ic, D.kh, D.kw])
+def depthwise_conv_2d_nchw_chw(
+    I=TensorDef(T1, S.N, S.IC, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
+    K=TensorDef(T2, S.IC, S.KH, S.KW),
+    O=TensorDef(U, S.N, S.IC, S.OH, S.OW, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs depth-wise 2-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. Multiplier is set to 1
+    which is a special case for most depthwise convolutions.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
+    O[D.n, D.ic, D.oh, D.ow] += TypeFn.cast_signed(
+        U, I[D.n, D.ic, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
+    ) * TypeFn.cast_signed(U, K[D.ic, D.kh, D.kw])
 
 
 @linalg_structured_op
-def depthwise_conv_2d_nhwc_hwc_q(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
-                                             S.OW * S.SW + S.KW * S.DW, S.IC),
-                                 K=TensorDef(T2, S.KH, S.KW, S.IC),
-                                 IZp=ScalarDef(I32),
-                                 KZp=ScalarDef(I32),
-                                 O=TensorDef(U,
-                                             S.N,
-                                             S.OH,
-                                             S.OW,
-                                             S.IC,
-                                             output=True),
-                                 strides=IndexAttrDef(S.SH,
-                                                      S.SW,
-                                                      default=[1, 1]),
-                                 dilations=IndexAttrDef(S.DH,
-                                                        S.DW,
-                                                        default=[1, 1])):
-  """Performs depth-wise 2-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
-  O[D.n, D.oh, D.ow, D.ic] += ((TypeFn.cast_signed(
-      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) -
-                                TypeFn.cast_signed(U, IZp)) *
-                               (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) -
-                                TypeFn.cast_signed(U, KZp)))
+def depthwise_conv_2d_nhwc_hwc_q(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC),
+    K=TensorDef(T2, S.KH, S.KW, S.IC),
+    IZp=ScalarDef(I32),
+    KZp=ScalarDef(I32),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs depth-wise 2-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
+    O[D.n, D.oh, D.ow, D.ic] += (
+        TypeFn.cast_signed(
+            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]
+        )
+        - TypeFn.cast_signed(U, IZp)
+    ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) - TypeFn.cast_signed(U, KZp))
 
 
 @linalg_structured_op
-def depthwise_conv_2d_nhwc_hwcm(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
-                                            S.OW * S.SW + S.KW * S.DW, S.IC),
-                                K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
-                                O=TensorDef(U,
-                                            S.N,
-                                            S.OH,
-                                            S.OW,
-                                            S.IC,
-                                            S.CM,
-                                            output=True),
-                                strides=IndexAttrDef(S.SH, S.SW, default=[1,
-                                                                          1]),
-                                dilations=IndexAttrDef(S.DH,
-                                                       S.DW,
-                                                       default=[1, 1])):
-  """Performs depth-wise 2-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
-  O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
-      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
-           D.ic]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm])
+def depthwise_conv_2d_nhwc_hwcm(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC),
+    K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs depth-wise 2-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
+    O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
+        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]
+    ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm])
 
 
 @linalg_structured_op
-def depthwise_conv_2d_nhwc_hwcm_q(I=TensorDef(T1, S.N,
-                                              S.OH * S.SH + S.KH * S.DH,
-                                              S.OW * S.SW + S.KW * S.DW, S.IC),
-                                  K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
-                                  IZp=ScalarDef(I32),
-                                  KZp=ScalarDef(I32),
-                                  O=TensorDef(U,
-                                              S.N,
-                                              S.OH,
-                                              S.OW,
-                                              S.IC,
-                                              S.CM,
-                                              output=True),
-                                  strides=IndexAttrDef(S.SH,
-                                                       S.SW,
-                                                       default=[1, 1]),
-                                  dilations=IndexAttrDef(S.DH,
-                                                         S.DW,
-                                                         default=[1, 1])):
-  """Performs depth-wise 2-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
-  O[D.n, D.oh, D.ow, D.ic,
-    D.cm] += ((TypeFn.cast_signed(
-        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) -
-               TypeFn.cast_signed(U, IZp)) *
-              (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) -
-               TypeFn.cast_signed(U, KZp)))
+def depthwise_conv_2d_nhwc_hwcm_q(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC),
+    K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
+    IZp=ScalarDef(I32),
+    KZp=ScalarDef(I32),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs depth-wise 2-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
+    O[D.n, D.oh, D.ow, D.ic, D.cm] += (
+        TypeFn.cast_signed(
+            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]
+        )
+        - TypeFn.cast_signed(U, IZp)
+    ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) - TypeFn.cast_signed(U, KZp))
 
 
 @linalg_structured_op
-def depthwise_conv_3d_ndhwc_dhwc(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
-                                             S.OH * S.SH + S.KH * S.DH,
-                                             S.OW * S.SW + S.KW * S.DW, S.IC),
-                                 K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC),
-                                 O=TensorDef(U,
-                                             S.N,
-                                             S.OD,
-                                             S.OH,
-                                             S.OW,
-                                             output=True),
-                                 strides=IndexAttrDef(S.SD,
-                                                      S.SH,
-                                                      S.SW,
-                                                      default=[1, 1, 1]),
-                                 dilations=IndexAttrDef(S.DD,
-                                                        S.DH,
-                                                        S.DW,
-                                                        default=[1, 1, 1])):
-  """Performs depth-wise 3-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output. Multiplier is set to 1
-  which is a special case for most depthwise convolutions.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic)
-  O[D.n, D.od, D.oh, D.ow, D.ic] += TypeFn.cast_signed(
-      U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
-           D.ow * S.SW + D.kw * S.DW, D.ic]) * TypeFn.cast_signed(
-               U, K[D.kd, D.kh, D.kw, D.ic])
+def depthwise_conv_3d_ndhwc_dhwc(
+    I=TensorDef(
+        T1,
+        S.N,
+        S.OD * S.SD + S.KD * S.DD,
+        S.OH * S.SH + S.KH * S.DH,
+        S.OW * S.SW + S.KW * S.DW,
+        S.IC,
+    ),
+    K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC),
+    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, output=True),
+    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+    """Performs depth-wise 3-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. Multiplier is set to 1
+    which is a special case for most depthwise convolutions.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic)
+    O[D.n, D.od, D.oh, D.ow, D.ic] += TypeFn.cast_signed(
+        U,
+        I[
+            D.n,
+            D.od * S.SD + D.kd * S.DD,
+            D.oh * S.SH + D.kh * S.DH,
+            D.ow * S.SW + D.kw * S.DW,
+            D.ic,
+        ],
+    ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic])
 
 
 @linalg_structured_op
-def depthwise_conv_3d_ncdhw_cdhw(I=TensorDef(T1, S.N, S.IC,
-                                             S.OD * S.SD + S.KD * S.DD,
-                                             S.OH * S.SH + S.KH * S.DH,
-                                             S.OW * S.SW + S.KW * S.DW),
-                                 K=TensorDef(T2, S.IC, S.KD, S.KH, S.KW),
-                                 O=TensorDef(U,
-                                             S.N,
-                                             S.IC,
-                                             S.OD,
-                                             S.OH,
-                                             S.OW,
-                                             output=True),
-                                 strides=IndexAttrDef(S.SD,
-                                                      S.SH,
-                                                      S.SW,
-                                                      default=[1, 1, 1]),
-                                 dilations=IndexAttrDef(S.DD,
-                                                        S.DH,
-                                                        S.DW,
-                                                        default=[1, 1, 1])):
-  """Performs depth-wise 3-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output. Multiplier is set to 1
-  which is a special case for most depthwise convolutions.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic)
-  O[D.n, D.ic, D.od, D.oh, D.ow] += TypeFn.cast_signed(
-      U, I[D.n, D.ic, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
-           D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed(
-               U, K[D.ic, D.kd, D.kh, D.kw])
+def depthwise_conv_3d_ncdhw_cdhw(
+    I=TensorDef(
+        T1,
+        S.N,
+        S.IC,
+        S.OD * S.SD + S.KD * S.DD,
+        S.OH * S.SH + S.KH * S.DH,
+        S.OW * S.SW + S.KW * S.DW,
+    ),
+    K=TensorDef(T2, S.IC, S.KD, S.KH, S.KW),
+    O=TensorDef(U, S.N, S.IC, S.OD, S.OH, S.OW, output=True),
+    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+    """Performs depth-wise 3-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. Multiplier is set to 1
+    which is a special case for most depthwise convolutions.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic)
+    O[D.n, D.ic, D.od, D.oh, D.ow] += TypeFn.cast_signed(
+        U,
+        I[
+            D.n,
+            D.ic,
+            D.od * S.SD + D.kd * S.DD,
+            D.oh * S.SH + D.kh * S.DH,
+            D.ow * S.SW + D.kw * S.DW,
+        ],
+    ) * TypeFn.cast_signed(U, K[D.ic, D.kd, D.kh, D.kw])
 
 
 @linalg_structured_op
-def depthwise_conv_3d_ndhwc_dhwcm(I=TensorDef(T1, S.N,
-                                              S.OD * S.SD + S.KD * S.DD,
-                                              S.OH * S.SH + S.KH * S.DH,
-                                              S.OW * S.SW + S.KW * S.DW, S.IC),
-                                  K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC, S.CM),
-                                  O=TensorDef(U,
-                                              S.N,
-                                              S.OD,
-                                              S.OH,
-                                              S.OW,
-                                              S.CM,
-                                              output=True),
-                                  strides=IndexAttrDef(S.SD,
-                                                       S.SH,
-                                                       S.SW,
-                                                       default=[1, 1, 1]),
-                                  dilations=IndexAttrDef(S.DD,
-                                                         S.DH,
-                                                         S.DW,
-                                                         default=[1, 1, 1])):
-  """Performs depth-wise 3-D convolution.
-
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.od, D.oh, D.ow, D.cm, D.kd, D.kh, D.kw, D.ic)
-  O[D.n, D.od, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
-      U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
-           D.ow * S.SW + D.kw * S.DW, D.ic]) * TypeFn.cast_signed(
-               U, K[D.kd, D.kh, D.kw, D.ic, D.cm])
+def depthwise_conv_3d_ndhwc_dhwcm(
+    I=TensorDef(
+        T1,
+        S.N,
+        S.OD * S.SD + S.KD * S.DD,
+        S.OH * S.SH + S.KH * S.DH,
+        S.OW * S.SW + S.KW * S.DW,
+        S.IC,
+    ),
+    K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC, S.CM),
+    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.CM, output=True),
+    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+    """Performs depth-wise 3-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.od, D.oh, D.ow, D.cm, D.kd, D.kh, D.kw, D.ic)
+    O[D.n, D.od, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
+        U,
+        I[
+            D.n,
+            D.od * S.SD + D.kd * S.DD,
+            D.oh * S.SH + D.kh * S.DH,
+            D.ow * S.SW + D.kw * S.DW,
+            D.ic,
+        ],
+    ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic, D.cm])
 
 
 @linalg_structured_op
-def pooling_nhwc_sum(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
-                                 S.OW * S.SW + S.KW * S.DW, S.C),
-                     K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
-                     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
-                     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                     dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  """Performs sum pooling.
-
-  Layout:
-    * Input: NHWC.
-    * Kernel: HW.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
-  O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed(
-      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
+def pooling_nhwc_sum(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs sum pooling.
+
+    Layout:
+      * Input: NHWC.
+      * Kernel: HW.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
+    O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed(
+        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+    )
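
Note that in the pooling ops the kernel operand only supplies the window shape through index_dims=[D.kh, D.kw]; its values are never read by the comprehension. A hedged NumPy sketch of the sum-pooling computation, with invented names, not part of the commit:

    import numpy as np

    def pooling_nhwc_sum_ref(ifm, out, kernel_shape, strides=(1, 1), dilations=(1, 1)):
        sh, sw = strides
        dh, dw = dilations
        N, OH, OW, C = out.shape
        KH, KW = kernel_shape  # shape of the shape-only kernel operand
        for n, oh, ow, c, kh, kw in np.ndindex(N, OH, OW, C, KH, KW):
            out[n, oh, ow, c] += out.dtype.type(
                ifm[n, oh * sh + kh * dh, ow * sw + kw * dw, c]
            )
        return out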
 
 
 @linalg_structured_op
-def pooling_nchw_sum(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH,
-                                 S.OW * S.SW + S.KW * S.DW),
-                     K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
-                     O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
-                     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                     dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  """Performs sum pooling.
-
-  Layout:
-    * Input: NCHW.
-    * Kernel: HW.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
-  O[D.n, D.c, D.oh, D.ow] += TypeFn.cast_signed(
-      U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW])
+def pooling_nchw_sum(
+    I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs sum pooling.
+
+    Layout:
+      * Input: NCHW.
+      * Kernel: HW.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
+    O[D.n, D.c, D.oh, D.ow] += TypeFn.cast_signed(
+        U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
+    )
 
 
 @linalg_structured_op
-def pooling_nhwc_max(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
-                                 S.OW * S.SW + S.KW * S.DW, S.C),
-                     K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
-                     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
-                     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                     dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  """Performs max pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kh, D.kw](TypeFn.cast_signed(
-      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_nhwc_max(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs max pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
+    O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kh, D.kw](
+        TypeFn.cast_signed(
+            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+        )
+    )
 
 
 @linalg_structured_op
-def pooling_nhwc_max_unsigned(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
-                                          S.OW * S.SW + S.KW * S.DW, S.C),
-                              K=TensorDef(T2,
-                                          S.KH,
-                                          S.KW,
-                                          index_dims=[D.kh, D.kw]),
-                              O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
-                              strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                              dilations=IndexAttrDef(S.DH, S.DW, default=[1,
-                                                                          1])):
-  """Performs unsigned max pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
-  O[D.n, D.oh, D.ow,
-    D.c] = ReduceFn.max_unsigned[D.kh, D.kw](TypeFn.cast_unsigned(
-        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_nhwc_max_unsigned(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs unsigned max pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
+    O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw](
+        TypeFn.cast_unsigned(
+            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+        )
+    )
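
The max/min pooling ops swap the += accumulation for a ReduceFn over the window dimensions, and the *_unsigned variants differ only in using cast_unsigned (zero-extension) instead of cast_signed. A hedged sketch of the signed max case, seeding the reduction with the output's initial value; names are invented and this is not part of the commit:

    import numpy as np

    def pooling_nhwc_max_ref(ifm, out, kernel_shape, strides=(1, 1), dilations=(1, 1)):
        sh, sw = strides
        dh, dw = dilations
        N, OH, OW, C = out.shape
        KH, KW = kernel_shape
        for n, oh, ow, c in np.ndindex(N, OH, OW, C):
            acc = out[n, oh, ow, c]  # the existing output value seeds the reduction
            for kh, kw in np.ndindex(KH, KW):
                acc = max(acc, out.dtype.type(
                    ifm[n, oh * sh + kh * dh, ow * sw + kw * dw, c]
                ))
            out[n, oh, ow, c] = acc
        return out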
 
 
 @linalg_structured_op
-def pooling_nchw_max(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH,
-                                 S.OW * S.SW + S.KW * S.DW),
-                     K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
-                     O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
-                     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                     dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  """Performs max pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
-  O[D.n, D.c, D.oh, D.ow] = ReduceFn.max_signed[D.kh, D.kw](TypeFn.cast_signed(
-      U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,]))
+def pooling_nchw_max(
+    I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs max pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
+    O[D.n, D.c, D.oh, D.ow] = ReduceFn.max_signed[D.kh, D.kw](
+        TypeFn.cast_signed(
+            U,
+            I[
+                D.n,
+                D.c,
+                D.oh * S.SH + D.kh * S.DH,
+                D.ow * S.SW + D.kw * S.DW,
+            ],
+        )
+    )
 
 
 @linalg_structured_op
-def pooling_nhwc_min(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
-                                 S.OW * S.SW + S.KW * S.DW, S.C),
-                     K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
-                     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
-                     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                     dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  """Performs min pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kh, D.kw](TypeFn.cast_signed(
-      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_nhwc_min(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs min pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
+    O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kh, D.kw](
+        TypeFn.cast_signed(
+            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+        )
+    )
 
 
 @linalg_structured_op
-def pooling_nhwc_min_unsigned(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
-                                          S.OW * S.SW + S.KW * S.DW, S.C),
-                              K=TensorDef(T2,
-                                          S.KH,
-                                          S.KW,
-                                          index_dims=[D.kh, D.kw]),
-                              O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
-                              strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-                              dilations=IndexAttrDef(S.DH, S.DW, default=[1,
-                                                                          1])):
-  """Performs unsigned min pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
-  O[D.n, D.oh, D.ow,
-    D.c] = ReduceFn.min_unsigned[D.kh, D.kw](TypeFn.cast_unsigned(
-        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_nhwc_min_unsigned(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs unsigned min pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw)
+    O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw](
+        TypeFn.cast_unsigned(
+            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+        )
+    )
+
 
 @linalg_structured_op
-def pooling_nwc_sum(I=TensorDef(T1, S.N,
-                                S.OW * S.SW + S.KW * S.DW, S.C),
-                    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
-                    O=TensorDef(U, S.N, S.OW, S.C, output=True),
-                    strides=IndexAttrDef(S.SW, default=[1]),
-                    dilations=IndexAttrDef(S.DW, default=[1])):
-  """Performs sum pooling.
-
-  Layout:
-    * Input: NWC.
-    * Kernel: W.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.ow, D.c, D.kw)
-  O[D.n, D.ow, D.c] += TypeFn.cast_signed(
-      U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
+def pooling_nwc_sum(
+    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
+    O=TensorDef(U, S.N, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SW, default=[1]),
+    dilations=IndexAttrDef(S.DW, default=[1]),
+):
+    """Performs sum pooling.
+
+    Layout:
+      * Input: NWC.
+      * Kernel: W.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.ow, D.c, D.kw)
+    O[D.n, D.ow, D.c] += TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
 
 
 @linalg_structured_op
-def pooling_ncw_sum(I=TensorDef(T1, S.N, S.C,
-                                S.OW * S.SW + S.KW * S.DW),
-                    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
-                    O=TensorDef(U, S.N, S.C, S.OW, output=True),
-                    strides=IndexAttrDef(S.SW, default=[1]),
-                    dilations=IndexAttrDef(S.DW, default=[1])):
-  """Performs sum pooling.
-
-  Layout:
-    * Input: NCW.
-    * Kernel: W.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.c, D.ow, D.kw)
-  O[D.n, D.c, D.ow] += TypeFn.cast_signed(
-      U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW])
+def pooling_ncw_sum(
+    I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW),
+    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
+    O=TensorDef(U, S.N, S.C, S.OW, output=True),
+    strides=IndexAttrDef(S.SW, default=[1]),
+    dilations=IndexAttrDef(S.DW, default=[1]),
+):
+    """Performs sum pooling.
+
+    Layout:
+      * Input: NCW.
+      * Kernel: W.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.c, D.ow, D.kw)
+    O[D.n, D.c, D.ow] += TypeFn.cast_signed(U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW])
 
 
 @linalg_structured_op
-def pooling_nwc_max(I=TensorDef(T1, S.N,
-                                S.OW * S.SW + S.KW * S.DW, S.C),
-                    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
-                    O=TensorDef(U, S.N, S.OW, S.C, output=True),
-                    strides=IndexAttrDef(S.SW, default=[1]),
-                    dilations=IndexAttrDef(S.DW, default=[1])):
-  """Performs max pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.ow, D.c, D.kw)
-  O[D.n, D.ow, D.c] = ReduceFn.max_signed[[D.kw]](TypeFn.cast_signed(
-      U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_nwc_max(
+    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
+    O=TensorDef(U, S.N, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SW, default=[1]),
+    dilations=IndexAttrDef(S.DW, default=[1]),
+):
+    """Performs max pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.ow, D.c, D.kw)
+    O[D.n, D.ow, D.c] = ReduceFn.max_signed[[D.kw]](
+        TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
+    )
 
 
 @linalg_structured_op
-def pooling_nwc_max_unsigned(I=TensorDef(T1, S.N,
-                                         S.OW * S.SW + S.KW * S.DW, S.C),
-                             K=TensorDef(T2,
-                                         S.KW,
-                                         index_dims=[D.kw]),
-                             O=TensorDef(U, S.N, S.OW, S.C, output=True),
-                             strides=IndexAttrDef(S.SW, default=[1]),
-                             dilations=IndexAttrDef(S.DW, default=[1])):
-  """Performs unsigned max pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.ow, D.c, D.kw)
-  O[D.n, D.ow,
-    D.c] = ReduceFn.max_unsigned[[D.kw]](TypeFn.cast_unsigned(
-        U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_nwc_max_unsigned(
+    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
+    O=TensorDef(U, S.N, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SW, default=[1]),
+    dilations=IndexAttrDef(S.DW, default=[1]),
+):
+    """Performs unsigned max pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.ow, D.c, D.kw)
+    O[D.n, D.ow, D.c] = ReduceFn.max_unsigned[[D.kw]](
+        TypeFn.cast_unsigned(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
+    )
 
 
 @linalg_structured_op
-def pooling_ncw_max(I=TensorDef(T1, S.N, S.C,
-                                S.OW * S.SW + S.KW * S.DW),
-                    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
-                    O=TensorDef(U, S.N, S.C, S.OW, output=True),
-                    strides=IndexAttrDef(S.SW, default=[1]),
-                    dilations=IndexAttrDef(S.DW, default=[1])):
-  """Performs max pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.c, D.ow, D.kw)
-  O[D.n, D.c, D.ow] = ReduceFn.max_signed[[D.kw]](TypeFn.cast_signed(
-      U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW,]))
+def pooling_ncw_max(
+    I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW),
+    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
+    O=TensorDef(U, S.N, S.C, S.OW, output=True),
+    strides=IndexAttrDef(S.SW, default=[1]),
+    dilations=IndexAttrDef(S.DW, default=[1]),
+):
+    """Performs max pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.c, D.ow, D.kw)
+    O[D.n, D.c, D.ow] = ReduceFn.max_signed[[D.kw]](
+        TypeFn.cast_signed(
+            U,
+            I[
+                D.n,
+                D.c,
+                D.ow * S.SW + D.kw * S.DW,
+            ],
+        )
+    )
 
 
 @linalg_structured_op
-def pooling_nwc_min(I=TensorDef(T1, S.N,
-                                S.OW * S.SW + S.KW * S.DW, S.C),
-                    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
-                    O=TensorDef(U, S.N, S.OW, S.C, output=True),
-                    strides=IndexAttrDef(S.SW, default=[1]),
-                    dilations=IndexAttrDef(S.DW, default=[1])):
-  """Performs min pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.ow, D.c, D.kw)
-  O[D.n, D.ow, D.c] = ReduceFn.min_signed[[D.kw]](TypeFn.cast_signed(
-      U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_nwc_min(
+    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
+    O=TensorDef(U, S.N, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SW, default=[1]),
+    dilations=IndexAttrDef(S.DW, default=[1]),
+):
+    """Performs min pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.ow, D.c, D.kw)
+    O[D.n, D.ow, D.c] = ReduceFn.min_signed[[D.kw]](
+        TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
+    )
 
 
 @linalg_structured_op
-def pooling_nwc_min_unsigned(I=TensorDef(T1, S.N,
-                                         S.OW * S.SW + S.KW * S.DW, S.C),
-                             K=TensorDef(T2,
-                                         S.KW,
-                                         index_dims=[D.kw]),
-                             O=TensorDef(U, S.N, S.OW, S.C, output=True),
-                             strides=IndexAttrDef(S.SW, default=[1]),
-                             dilations=IndexAttrDef(S.DW, default=[1])):
-  """Performs unsigned min pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.ow, D.c, D.kw)
-  O[D.n, D.ow,
-    D.c] = ReduceFn.min_unsigned[[D.kw]](TypeFn.cast_unsigned(
-        U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]))
-
+def pooling_nwc_min_unsigned(
+    I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KW, index_dims=[D.kw]),
+    O=TensorDef(U, S.N, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SW, default=[1]),
+    dilations=IndexAttrDef(S.DW, default=[1]),
+):
+    """Performs unsigned min pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.ow, D.c, D.kw)
+    O[D.n, D.ow, D.c] = ReduceFn.min_unsigned[[D.kw]](
+        TypeFn.cast_unsigned(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])
+    )
 
 
 @linalg_structured_op
-def pooling_ndhwc_sum(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
-                                  S.OH * S.SH + S.KH * S.DH,
-                                  S.OW * S.SW + S.KW * S.DW, S.C),
-                      K=TensorDef(T2,
-                                  S.KD,
-                                  S.KH,
-                                  S.KW,
-                                  index_dims=[D.kd, D.kh, D.kw]),
-                      O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
-                      strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
-                      dilations=IndexAttrDef(S.DD,
-                                             S.DH,
-                                             S.DW,
-                                             default=[1, 1, 1])):
-  """Performs 3D sum pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
-  O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast_signed(
-      U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
-           D.ow * S.SW + D.kw * S.DW, D.c])
+def pooling_ndhwc_sum(
+    I=TensorDef(
+        T1,
+        S.N,
+        S.OD * S.SD + S.KD * S.DD,
+        S.OH * S.SH + S.KH * S.DH,
+        S.OW * S.SW + S.KW * S.DW,
+        S.C,
+    ),
+    K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+    """Performs 3D sum pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
+    O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast_signed(
+        U,
+        I[
+            D.n,
+            D.od * S.SD + D.kd * S.DD,
+            D.oh * S.SH + D.kh * S.DH,
+            D.ow * S.SW + D.kw * S.DW,
+            D.c,
+        ],
+    )
 
 
 @linalg_structured_op
-def pooling_ndhwc_max(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
-                                  S.OH * S.SH + S.KH * S.DH,
-                                  S.OW * S.SW + S.KW * S.DW, S.C),
-                      K=TensorDef(T2,
-                                  S.KD,
-                                  S.KH,
-                                  S.KW,
-                                  index_dims=[D.kd, D.kh, D.kw]),
-                      O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
-                      strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
-                      dilations=IndexAttrDef(S.DD,
-                                             S.DH,
-                                             S.DW,
-                                             default=[1, 1, 1])):
-  """Performs 3D max pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
-  O[D.n, D.od, D.oh, D.ow,
-    D.c] = ReduceFn.max_signed[D.kd, D.kh, D.kw](TypeFn.cast_signed(
-        U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
-             D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_ndhwc_max(
+    I=TensorDef(
+        T1,
+        S.N,
+        S.OD * S.SD + S.KD * S.DD,
+        S.OH * S.SH + S.KH * S.DH,
+        S.OW * S.SW + S.KW * S.DW,
+        S.C,
+    ),
+    K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+    """Performs 3D max pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
+    O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kd, D.kh, D.kw](
+        TypeFn.cast_signed(
+            U,
+            I[
+                D.n,
+                D.od * S.SD + D.kd * S.DD,
+                D.oh * S.SH + D.kh * S.DH,
+                D.ow * S.SW + D.kw * S.DW,
+                D.c,
+            ],
+        )
+    )
 
 
 @linalg_structured_op
-def pooling_ndhwc_min(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD,
-                                  S.OH * S.SH + S.KH * S.DH,
-                                  S.OW * S.SW + S.KW * S.DW, S.C),
-                      K=TensorDef(T2,
-                                  S.KD,
-                                  S.KH,
-                                  S.KW,
-                                  index_dims=[D.kd, D.kh, D.kw]),
-                      O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
-                      strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
-                      dilations=IndexAttrDef(S.DD,
-                                             S.DH,
-                                             S.DW,
-                                             default=[1, 1, 1])):
-  """Performs 3D min pooling.
-
-  Numeric casting is performed on the input operand, promoting it to the same
-  data type as the accumulator/output.
-  """
-  implements(ConvolutionOpInterface)
-  domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
-  O[D.n, D.od, D.oh, D.ow,
-    D.c] = ReduceFn.min_signed[D.kd, D.kh, D.kw](TypeFn.cast_signed(
-        U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
-             D.ow * S.SW + D.kw * S.DW, D.c]))
+def pooling_ndhwc_min(
+    I=TensorDef(
+        T1,
+        S.N,
+        S.OD * S.SD + S.KD * S.DD,
+        S.OH * S.SH + S.KH * S.DH,
+        S.OW * S.SW + S.KW * S.DW,
+        S.C,
+    ),
+    K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
+    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+    """Performs 3D min pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw)
+    O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kd, D.kh, D.kw](
+        TypeFn.cast_signed(
+            U,
+            I[
+                D.n,
+                D.od * S.SD + D.kd * S.DD,
+                D.oh * S.SH + D.kh * S.DH,
+                D.ow * S.SW + D.kw * S.DW,
+                D.c,
+            ],
+        )
+    )
 
 
 @linalg_structured_op
 def fill(value=ScalarDef(T1), O=TensorDef(U, output=True)):
-  """Fills the output tensor with the given value.
+    """Fills the output tensor with the given value.
 
-  Works for arbitrary ranked output tensors since the operation performs scalar
-  accesses only and is thus rank polymorphic. Numeric casting is performed on
-  the value operand, promoting it to the same data type as the output.
-  """
-  implements(FillOpInterface)
-  defines(Canonicalizer)
-  O[None] = TypeFn.cast_signed(U, value)
+    Works for arbitrary ranked output tensors since the operation performs scalar
+    accesses only and is thus rank polymorphic. Numeric casting is performed on
+    the value operand, promoting it to the same data type as the output.
+    """
+    implements(FillOpInterface)
+    defines(Canonicalizer)
+    O[None] = TypeFn.cast_signed(U, value)
 
 
 @linalg_structured_op
-def fill_rng_2d(min=ScalarDef(F64),
-                max=ScalarDef(F64),
-                seed=ScalarDef(I32),
-                O=TensorDef(T, S.M, S.N, output=True)):
-  """Fills the output tensor with pseudo random numbers.
-
-  The operation generations pseudo random numbers using a linear congruential
-  generator. It provides no guarantees regarding the distribution of the
-  generated random numbers. Instead of generating the random numbers
-  sequentially, it instantiates one random number generator per data element
-  and runs them in parallel. The seed operand and the indices of the data
-  element seed the random number generation. The min and max operands limit
-  the range of the generated random numbers.
-  """
-  domain(D.m, D.n)
-  multiplier = TypeFn.cast_signed(I32, const(1103515245))
-  increment = TypeFn.cast_signed(I32, const(12345))
-  rand1 = (TypeFn.cast_signed(I32, index(D.m)) + seed) * multiplier + increment
-  rand2 = (TypeFn.cast_signed(I32, index(D.n)) + rand1) * multiplier + increment
-  inv_range = TypeFn.cast_signed(F64, const(2.3283064e-10))
-  offset = TypeFn.cast_signed(F64, const(2147483647))
-  scaling = (max - min) * inv_range
-  O[D.m, D.n] = TypeFn.cast_signed(
-      T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min)
+def fill_rng_2d(
+    min=ScalarDef(F64),
+    max=ScalarDef(F64),
+    seed=ScalarDef(I32),
+    O=TensorDef(T, S.M, S.N, output=True),
+):
+    """Fills the output tensor with pseudo random numbers.
+
+    The operation generates pseudo random numbers using a linear congruential
+    generator. It provides no guarantees regarding the distribution of the
+    generated random numbers. Instead of generating the random numbers
+    sequentially, it instantiates one random number generator per data element
+    and runs them in parallel. The seed operand and the indices of the data
+    element seed the random number generation. The min and max operands limit
+    the range of the generated random numbers.
+    """
+    domain(D.m, D.n)
+    multiplier = TypeFn.cast_signed(I32, const(1103515245))
+    increment = TypeFn.cast_signed(I32, const(12345))
+    rand1 = (TypeFn.cast_signed(I32, index(D.m)) + seed) * multiplier + increment
+    rand2 = (TypeFn.cast_signed(I32, index(D.n)) + rand1) * multiplier + increment
+    inv_range = TypeFn.cast_signed(F64, const(2.3283064e-10))
+    offset = TypeFn.cast_signed(F64, const(2147483647))
+    scaling = (max - min) * inv_range
+    O[D.m, D.n] = TypeFn.cast_signed(
+        T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min
+    )
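A per-element model of the generator above, as a hedged sketch (the i32 wraparound is made explicit; function names are illustrative):

    def to_i32(x):
        # Wrap a Python int to signed 32-bit, mirroring the op's i32 arithmetic.
        return (x + 2**31) % 2**32 - 2**31

    def fill_rng_2d_ref(i, j, seed, minv, maxv):
        multiplier, increment = 1103515245, 12345
        rand1 = to_i32(to_i32(i + seed) * multiplier + increment)
        rand2 = to_i32(to_i32(j + rand1) * multiplier + increment)
        inv_range = 2.3283064e-10   # roughly 2^-32
        offset = 2147483647.0       # shifts rand2 from [-2^31, 2^31) to ~[0, 2^32)
        scaling = (maxv - minv) * inv_range
        return (offset + float(rand2)) * scaling + minv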
index ca0d479..980f237 100644 (file)
@@ -5,6 +5,8 @@
 from ._python_test_ops_gen import *
 from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestTensorType
 
+
 def register_python_test_dialect(context, load=True):
-  from .._mlir_libs import _mlirPythonTest
-  _mlirPythonTest.register_python_test_dialect(context, load)
+    from .._mlir_libs import _mlirPythonTest
+
+    _mlirPythonTest.register_python_test_dialect(context, load)
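A minimal usage sketch for the helper above (assumes a build where the Python test dialect is available):

    from mlir.ir import Context
    from mlir.dialects import python_test

    with Context() as ctx:
        python_test.register_python_test_dialect(ctx)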
index 78956c4..b505a49 100644 (file)
@@ -6,16 +6,18 @@ from enum import Enum
 
 
 class FailurePropagationMode(Enum):
-  """Propagation mode for silenceable errors."""
-  PROPAGATE = 1
-  SUPPRESS = 2
+    """Propagation mode for silenceable errors."""
 
-  def _as_int(self):
-    if self is FailurePropagationMode.PROPAGATE:
-      return 1
+    PROPAGATE = 1
+    SUPPRESS = 2
+
+    def _as_int(self):
+        if self is FailurePropagationMode.PROPAGATE:
+            return 1
+
+        assert self is FailurePropagationMode.SUPPRESS
+        return 2
 
-    assert self is FailurePropagationMode.SUPPRESS
-    return 2
 
 from .._transform_ops_gen import *
 from ..._mlir_libs._mlirDialectsTransform import *
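The mapping that `_as_int` encodes can be exercised directly, e.g. (illustrative only):

    from mlir.dialects.transform import FailurePropagationMode

    assert FailurePropagationMode.PROPAGATE._as_int() == 1
    assert FailurePropagationMode.SUPPRESS._as_int() == 2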
index 262545b..4739231 100644 (file)
@@ -7,37 +7,37 @@ from ._mlir_libs import _mlirExecutionEngine as _execution_engine
 import ctypes
 
 __all__ = [
-  "ExecutionEngine",
+    "ExecutionEngine",
 ]
 
-class ExecutionEngine(_execution_engine.ExecutionEngine):
 
-  def lookup(self, name):
-    """Lookup a function emitted with the `llvm.emit_c_interface`
-    attribute and returns a ctype callable.
-    Raise a RuntimeError if the function isn't found.
-    """
-    func = self.raw_lookup("_mlir_ciface_" + name)
-    if not func:
-      raise RuntimeError("Unknown function " + name)
-    prototype = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
-    return prototype(func)
+class ExecutionEngine(_execution_engine.ExecutionEngine):
+    def lookup(self, name):
+        """Lookup a function emitted with the `llvm.emit_c_interface`
+        attribute and returns a ctype callable.
+        Raise a RuntimeError if the function isn't found.
+        """
+        func = self.raw_lookup("_mlir_ciface_" + name)
+        if not func:
+            raise RuntimeError("Unknown function " + name)
+        prototype = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
+        return prototype(func)
 
-  def invoke(self, name, *ctypes_args):
-    """Invoke a function with the list of ctypes arguments.
-    All arguments must be pointers.
-    Raise a RuntimeError if the function isn't found.
-    """
-    func = self.lookup(name)
-    packed_args = (ctypes.c_void_p * len(ctypes_args))()
-    for argNum in range(len(ctypes_args)):
-      packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p)
-    func(packed_args)
+    def invoke(self, name, *ctypes_args):
+        """Invoke a function with the list of ctypes arguments.
+        All arguments must be pointers.
+        Raise a RuntimeError if the function isn't found.
+        """
+        func = self.lookup(name)
+        packed_args = (ctypes.c_void_p * len(ctypes_args))()
+        for argNum in range(len(ctypes_args)):
+            packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p)
+        func(packed_args)
 
-  def register_runtime(self, name, ctypes_callback):
-    """Register a runtime function available to the jitted code
-    under the provided `name`. The `ctypes_callback` must be a
-    `CFuncType` that outlives the execution engine.
-    """
-    callback = ctypes.cast(ctypes_callback, ctypes.c_void_p)
-    self.raw_register_runtime("_mlir_ciface_" + name, callback)
+    def register_runtime(self, name, ctypes_callback):
+        """Register a runtime function available to the jitted code
+        under the provided `name`. The `ctypes_callback` must be a
+        `CFuncType` that outlives the execution engine.
+        """
+        callback = ctypes.cast(ctypes_callback, ctypes.c_void_p)
+        self.raw_register_runtime("_mlir_ciface_" + name, callback)
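A hedged sketch of the ctypes calling convention these methods expect; the engine construction and the compiled module are assumed, and the symbol names are hypothetical:

    import ctypes

    # Assume `engine` is an already-constructed ExecutionEngine whose module
    # defines `foo` with the llvm.emit_c_interface attribute.
    arg = ctypes.c_longlong(42)
    engine.invoke("foo", ctypes.pointer(arg))      # every argument is a pointer

    # Make a Python callback visible to the jitted code as `some_hook`.
    CB_TYPE = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
    keep_alive = CB_TYPE(lambda p: None)           # must outlive the engine
    engine.register_runtime("some_hook", keep_alive)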
index be065d4..99c21ff 100644 (file)
@@ -8,124 +8,123 @@ from ._mlir_libs._mlir.ir import _GlobalDebug
 
 # Convenience decorator for registering user-friendly Attribute builders.
 def register_attribute_builder(kind):
+    def decorator_builder(func):
+        AttrBuilder.insert(kind, func)
+        return func
 
-  def decorator_builder(func):
-    AttrBuilder.insert(kind, func)
-    return func
-
-  return decorator_builder
+    return decorator_builder
 
 
 @register_attribute_builder("BoolAttr")
 def _boolAttr(x, context):
-  return BoolAttr.get(x, context=context)
+    return BoolAttr.get(x, context=context)
 
 
 @register_attribute_builder("IndexAttr")
 def _indexAttr(x, context):
-  return IntegerAttr.get(IndexType.get(context=context), x)
+    return IntegerAttr.get(IndexType.get(context=context), x)
 
 
 @register_attribute_builder("I16Attr")
 def _i16Attr(x, context):
-  return IntegerAttr.get(IntegerType.get_signless(16, context=context), x)
+    return IntegerAttr.get(IntegerType.get_signless(16, context=context), x)
 
 
 @register_attribute_builder("I32Attr")
 def _i32Attr(x, context):
-  return IntegerAttr.get(IntegerType.get_signless(32, context=context), x)
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), x)
 
 
 @register_attribute_builder("I64Attr")
 def _i64Attr(x, context):
-  return IntegerAttr.get(IntegerType.get_signless(64, context=context), x)
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), x)
 
 
 @register_attribute_builder("SI16Attr")
 def _si16Attr(x, context):
-  return IntegerAttr.get(IntegerType.get_signed(16, context=context), x)
+    return IntegerAttr.get(IntegerType.get_signed(16, context=context), x)
 
 
 @register_attribute_builder("SI32Attr")
 def _si32Attr(x, context):
-  return IntegerAttr.get(IntegerType.get_signed(32, context=context), x)
+    return IntegerAttr.get(IntegerType.get_signed(32, context=context), x)
 
 
 @register_attribute_builder("F32Attr")
 def _f32Attr(x, context):
-  return FloatAttr.get_f32(x, context=context)
+    return FloatAttr.get_f32(x, context=context)
 
 
 @register_attribute_builder("F64Attr")
 def _f64Attr(x, context):
-  return FloatAttr.get_f64(x, context=context)
+    return FloatAttr.get_f64(x, context=context)
 
 
 @register_attribute_builder("StrAttr")
 def _stringAttr(x, context):
-  return StringAttr.get(x, context=context)
+    return StringAttr.get(x, context=context)
 
 
 @register_attribute_builder("SymbolNameAttr")
 def _symbolNameAttr(x, context):
-  return StringAttr.get(x, context=context)
+    return StringAttr.get(x, context=context)
 
 
 @register_attribute_builder("SymbolRefAttr")
 def _symbolRefAttr(x, context):
-  return FlatSymbolRefAttr.get(x, context=context)
+    return FlatSymbolRefAttr.get(x, context=context)
 
 
 @register_attribute_builder("ArrayAttr")
 def _arrayAttr(x, context):
-  return ArrayAttr.get(x, context=context)
+    return ArrayAttr.get(x, context=context)
 
 
 @register_attribute_builder("I32ArrayAttr")
 def _i32ArrayAttr(x, context):
-  return ArrayAttr.get([_i32Attr(v, context) for v in x])
+    return ArrayAttr.get([_i32Attr(v, context) for v in x])
 
 
 @register_attribute_builder("I64ArrayAttr")
 def _i64ArrayAttr(x, context):
-  return ArrayAttr.get([_i64Attr(v, context) for v in x])
+    return ArrayAttr.get([_i64Attr(v, context) for v in x])
 
 
 @register_attribute_builder("F32ArrayAttr")
 def _f32ArrayAttr(x, context):
-  return ArrayAttr.get([_f32Attr(v, context) for v in x])
+    return ArrayAttr.get([_f32Attr(v, context) for v in x])
 
 
 @register_attribute_builder("F64ArrayAttr")
 def _f64ArrayAttr(x, context):
-  return ArrayAttr.get([_f64Attr(v, context) for v in x])
+    return ArrayAttr.get([_f64Attr(v, context) for v in x])
 
 
 @register_attribute_builder("DenseI64ArrayAttr")
 def _denseI64ArrayAttr(x, context):
-  return DenseI64ArrayAttr.get(x, context=context)
+    return DenseI64ArrayAttr.get(x, context=context)
 
 
 @register_attribute_builder("TypeAttr")
 def _typeAttr(x, context):
-  return TypeAttr.get(x, context=context)
+    return TypeAttr.get(x, context=context)
 
 
 @register_attribute_builder("TypeArrayAttr")
 def _typeArrayAttr(x, context):
-  return _arrayAttr([TypeAttr.get(t, context=context) for t in x], context)
+    return _arrayAttr([TypeAttr.get(t, context=context) for t in x], context)
 
 
 try:
-  import numpy as np
+    import numpy as np
 
-  @register_attribute_builder("IndexElementsAttr")
-  def _indexElementsAttr(x, context):
-    return DenseElementsAttr.get(
-        np.array(x, dtype=np.int64),
-        type=IndexType.get(context=context),
-        context=context,
-    )
+    @register_attribute_builder("IndexElementsAttr")
+    def _indexElementsAttr(x, context):
+        return DenseElementsAttr.get(
+            np.array(x, dtype=np.int64),
+            type=IndexType.get(context=context),
+            context=context,
+        )
 
 except ImportError:
-  pass
+    pass
index d709679..51433d7 100644 (file)
@@ -9,131 +9,134 @@ import ctypes
 
 
 class C128(ctypes.Structure):
-  """A ctype representation for MLIR's Double Complex."""
-  _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)]
+    """A ctype representation for MLIR's Double Complex."""
+
+    _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)]
 
 
 class C64(ctypes.Structure):
-  """A ctype representation for MLIR's Float Complex."""
-  _fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)]
+    """A ctype representation for MLIR's Float Complex."""
+
+    _fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)]
 
 
 class F16(ctypes.Structure):
-  """A ctype representation for MLIR's Float16."""
-  _fields_ = [("f16", ctypes.c_int16)]
+    """A ctype representation for MLIR's Float16."""
+
+    _fields_ = [("f16", ctypes.c_int16)]
 
 
 # https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
 def as_ctype(dtp):
-  """Converts dtype to ctype."""
-  if dtp == np.dtype(np.complex128):
-    return C128
-  if dtp == np.dtype(np.complex64):
-    return C64
-  if dtp == np.dtype(np.float16):
-    return F16
-  return np.ctypeslib.as_ctypes_type(dtp)
+    """Converts dtype to ctype."""
+    if dtp == np.dtype(np.complex128):
+        return C128
+    if dtp == np.dtype(np.complex64):
+        return C64
+    if dtp == np.dtype(np.float16):
+        return F16
+    return np.ctypeslib.as_ctypes_type(dtp)
 
 
 def to_numpy(array):
-  """Converts ctypes array back to numpy dtype array."""
-  if array.dtype == C128:
-    return array.view("complex128")
-  if array.dtype == C64:
-    return array.view("complex64")
-  if array.dtype == F16:
-    return array.view("float16")
-  return array
+    """Converts ctypes array back to numpy dtype array."""
+    if array.dtype == C128:
+        return array.view("complex128")
+    if array.dtype == C64:
+        return array.view("complex64")
+    if array.dtype == F16:
+        return array.view("float16")
+    return array
 
 
 def make_nd_memref_descriptor(rank, dtype):
+    class MemRefDescriptor(ctypes.Structure):
+        """Builds an empty descriptor for the given rank/dtype, where rank>0."""
 
-  class MemRefDescriptor(ctypes.Structure):
-    """Builds an empty descriptor for the given rank/dtype, where rank>0."""
+        _fields_ = [
+            ("allocated", ctypes.c_longlong),
+            ("aligned", ctypes.POINTER(dtype)),
+            ("offset", ctypes.c_longlong),
+            ("shape", ctypes.c_longlong * rank),
+            ("strides", ctypes.c_longlong * rank),
+        ]
 
-    _fields_ = [
-        ("allocated", ctypes.c_longlong),
-        ("aligned", ctypes.POINTER(dtype)),
-        ("offset", ctypes.c_longlong),
-        ("shape", ctypes.c_longlong * rank),
-        ("strides", ctypes.c_longlong * rank),
-    ]
-
-  return MemRefDescriptor
+    return MemRefDescriptor
 
 
 def make_zero_d_memref_descriptor(dtype):
+    class MemRefDescriptor(ctypes.Structure):
+        """Builds an empty descriptor for the given dtype, where rank=0."""
 
-  class MemRefDescriptor(ctypes.Structure):
-    """Builds an empty descriptor for the given dtype, where rank=0."""
-
-    _fields_ = [
-        ("allocated", ctypes.c_longlong),
-        ("aligned", ctypes.POINTER(dtype)),
-        ("offset", ctypes.c_longlong),
-    ]
+        _fields_ = [
+            ("allocated", ctypes.c_longlong),
+            ("aligned", ctypes.POINTER(dtype)),
+            ("offset", ctypes.c_longlong),
+        ]
 
-  return MemRefDescriptor
+    return MemRefDescriptor
 
 
 class UnrankedMemRefDescriptor(ctypes.Structure):
-  """Creates a ctype struct for memref descriptor"""
-  _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)]
+    """Creates a ctype struct for memref descriptor"""
+
+    _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)]
 
 
 def get_ranked_memref_descriptor(nparray):
-  """Returns a ranked memref descriptor for the given numpy array."""
-  ctp = as_ctype(nparray.dtype)
-  if nparray.ndim == 0:
-    x = make_zero_d_memref_descriptor(ctp)()
+    """Returns a ranked memref descriptor for the given numpy array."""
+    ctp = as_ctype(nparray.dtype)
+    if nparray.ndim == 0:
+        x = make_zero_d_memref_descriptor(ctp)()
+        x.allocated = nparray.ctypes.data
+        x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
+        x.offset = ctypes.c_longlong(0)
+        return x
+
+    x = make_nd_memref_descriptor(nparray.ndim, ctp)()
     x.allocated = nparray.ctypes.data
     x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
     x.offset = ctypes.c_longlong(0)
-    return x
+    x.shape = nparray.ctypes.shape
 
-  x = make_nd_memref_descriptor(nparray.ndim, ctp)()
-  x.allocated = nparray.ctypes.data
-  x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
-  x.offset = ctypes.c_longlong(0)
-  x.shape = nparray.ctypes.shape
-
-  # Numpy uses byte quantities to express strides, MLIR OTOH uses the
-  # torch abstraction which specifies strides in terms of elements.
-  strides_ctype_t = ctypes.c_longlong * nparray.ndim
-  x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides])
-  return x
+    # Numpy uses byte quantities to express strides, whereas MLIR uses the
+    # torch abstraction, which specifies strides in terms of elements.
+    strides_ctype_t = ctypes.c_longlong * nparray.ndim
+    x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides])
+    return x
 
 
 def get_unranked_memref_descriptor(nparray):
-  """Returns a generic/unranked memref descriptor for the given numpy array."""
-  d = UnrankedMemRefDescriptor()
-  d.rank = nparray.ndim
-  x = get_ranked_memref_descriptor(nparray)
-  d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
-  return d
+    """Returns a generic/unranked memref descriptor for the given numpy array."""
+    d = UnrankedMemRefDescriptor()
+    d.rank = nparray.ndim
+    x = get_ranked_memref_descriptor(nparray)
+    d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
+    return d
 
 
 def unranked_memref_to_numpy(unranked_memref, np_dtype):
-  """Converts unranked memrefs to numpy arrays."""
-  ctp = as_ctype(np_dtype)
-  descriptor = make_nd_memref_descriptor(unranked_memref[0].rank, ctp)
-  val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor))
-  np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape)
-  strided_arr = np.lib.stride_tricks.as_strided(
-      np_arr,
-      np.ctypeslib.as_array(val[0].shape),
-      np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize,
-  )
-  return to_numpy(strided_arr)
+    """Converts unranked memrefs to numpy arrays."""
+    ctp = as_ctype(np_dtype)
+    descriptor = make_nd_memref_descriptor(unranked_memref[0].rank, ctp)
+    val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor))
+    np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape)
+    strided_arr = np.lib.stride_tricks.as_strided(
+        np_arr,
+        np.ctypeslib.as_array(val[0].shape),
+        np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize,
+    )
+    return to_numpy(strided_arr)
 
 
 def ranked_memref_to_numpy(ranked_memref):
-  """Converts ranked memrefs to numpy arrays."""
-  np_arr = np.ctypeslib.as_array(
-      ranked_memref[0].aligned, shape=ranked_memref[0].shape)
-  strided_arr = np.lib.stride_tricks.as_strided(
-      np_arr,
-      np.ctypeslib.as_array(ranked_memref[0].shape),
-      np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize,
-  )
-  return to_numpy(strided_arr)
+    """Converts ranked memrefs to numpy arrays."""
+    np_arr = np.ctypeslib.as_array(
+        ranked_memref[0].aligned, shape=ranked_memref[0].shape
+    )
+    strided_arr = np.lib.stride_tricks.as_strided(
+        np_arr,
+        np.ctypeslib.as_array(ranked_memref[0].shape),
+        np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize,
+    )
+    return to_numpy(strided_arr)
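A small round-trip sketch using the helpers above, mirroring how the sparse tensor tests below drive them (it assumes `mlir.runtime` re-exports these functions, as the tests do via the `rt.` prefix):

    import ctypes
    import numpy as np
    from mlir import runtime as rt

    a = np.arange(6, dtype=np.float32).reshape(2, 3)
    desc = rt.get_ranked_memref_descriptor(a)             # rank-2, element strides
    back = rt.ranked_memref_to_numpy(ctypes.pointer(desc))
    assert np.array_equal(back, a)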
index f08a0de..bb0c17c 100644 (file)
@@ -1 +1 @@
-config.suffixes.add('.c')
+config.suffixes.add(".c")
index 847c3ef..bc470cc 100644 (file)
@@ -1,2 +1,2 @@
 if not config.run_cuda_tests:
-  config.unsupported = True
\ No newline at end of file
+    config.unsupported = True
index 6eb5617..2f5cc9f 100644 (file)
@@ -1,2 +1,2 @@
 if not config.run_rocm_tests:
-  config.unsupported = True
+    config.unsupported = True
index c5aeb13..0d9aa10 100644 (file)
@@ -1,5 +1,3 @@
 # Requires native execution.
-if 'host-supports-jit' not in config.available_features:
+if "host-supports-jit" not in config.available_features:
     config.unsupported = True
-
-
index c5aeb13..0d9aa10 100644 (file)
@@ -1,5 +1,3 @@
 # Requires native execution.
-if 'host-supports-jit' not in config.available_features:
+if "host-supports-jit" not in config.available_features:
     config.unsupported = True
-
-
index 97db322..1a51296 100644 (file)
@@ -1,2 +1,2 @@
 if not config.build_examples:
-  config.unsupported = True
+    config.unsupported = True
index cf7c8ff..fe8397c 100644 (file)
@@ -1,13 +1,12 @@
 # Disable with sanitizers for now; this requires some more setup apparently.
-for san in ['asan', 'msan', 'ubsan']:
-   if (san in config.available_features):
-      config.unsupported = True
+for san in ["asan", "msan", "ubsan"]:
+    if san in config.available_features:
+        config.unsupported = True
 
 config.substitutions.append(("%cmake_exe", config.host_cmake))
 config.substitutions.append(("%cmake_generator", config.host_cmake_generator))
 config.substitutions.append(("%host_cxx", config.host_cxx))
 config.substitutions.append(("%host_cc", config.host_cc))
 config.substitutions.append(("%enable_libcxx", config.enable_libcxx))
-config.substitutions.append(
-    ("%mlir_cmake_dir", config.mlir_cmake_dir))
+config.substitutions.append(("%mlir_cmake_dir", config.mlir_cmake_dir))
 config.substitutions.append(("%llvm_use_linker", config.llvm_use_linker))
index 7215eda..073f637 100644 (file)
@@ -1,5 +1,5 @@
 import sys
 
 # Windows does not have aligned_alloc
-if sys.platform == 'win32':
+if sys.platform == "win32":
     config.unsupported = True
index 263c8f8..071a13c 100644 (file)
@@ -1,4 +1,4 @@
 import platform
 
-if platform.machine() != 'x86_64':
+if platform.machine() != "x86_64":
     config.unsupported = True
index 7d1e494..3214a11 100644 (file)
@@ -1,18 +1,22 @@
 import sys
 
-lli_cmd = 'lli'
+lli_cmd = "lli"
 if config.riscv_emulator_lli_executable:
     lli_cmd = config.riscv_emulator_lli_executable
 
-config.substitutions.append(('%mlir_native_utils_lib_dir',
-    config.riscv_emulator_utils_lib_dir or config.mlir_lib_dir))
+config.substitutions.append(
+    (
+        "%mlir_native_utils_lib_dir",
+        config.riscv_emulator_utils_lib_dir or config.mlir_lib_dir,
+    )
+)
 
 if config.riscv_vector_emulator_executable:
     # Run test in qemu emulator.
     emulation_cmd = config.riscv_vector_emulator_executable
     if config.riscv_vector_emulator_options:
-        emulation_cmd = emulation_cmd + ' ' + config.riscv_vector_emulator_options
-    emulation_cmd = emulation_cmd + ' ' + lli_cmd + ' --march=riscv64 -mattr=+v '
-    config.substitutions.append(('%lli', emulation_cmd))
+        emulation_cmd = emulation_cmd + " " + config.riscv_vector_emulator_options
+    emulation_cmd = emulation_cmd + " " + lli_cmd + " --march=riscv64 -mattr=+v "
+    config.substitutions.append(("%lli", emulation_cmd))
 else:
-    config.substitutions.append(('%lli', lli_cmd))
+    config.substitutions.append(("%lli", lli_cmd))
index 9bf49cc..6e07eb8 100644 (file)
@@ -2,14 +2,16 @@ import sys
 from lit.llvm import llvm_config
 
 # FIXME: %mlir_native_utils_lib_dir is set incorrectly on Windows
-if sys.platform == 'win32':
+if sys.platform == "win32":
     config.unsupported = True
 
 # ArmSVE tests must be enabled via build flag.
 if config.mlir_run_arm_sve_tests:
-    config.substitutions.append(('%ENABLE_VLA', 'true'))
-    config.substitutions.append(('%VLA_ARCH_ATTR_OPTIONS', '--march=aarch64 --mattr="+sve"'))
+    config.substitutions.append(("%ENABLE_VLA", "true"))
+    config.substitutions.append(
+        ("%VLA_ARCH_ATTR_OPTIONS", '--march=aarch64 --mattr="+sve"')
+    )
 else:
-    config.substitutions.append(('%ENABLE_VLA', 'false'))
-    config.substitutions.append(('%VLA_ARCH_ATTR_OPTIONS', ''))
-    config.substitutions.append(('%mlir_native_utils_lib_dir', config.mlir_lib_dir))
+    config.substitutions.append(("%ENABLE_VLA", "false"))
+    config.substitutions.append(("%VLA_ARCH_ATTR_OPTIONS", ""))
+    config.substitutions.append(("%mlir_native_utils_lib_dir", config.mlir_lib_dir))
index c586aae..6788cce 100644 (file)
@@ -1,2 +1,2 @@
 if not config.enable_cuda_runner or not config.mlir_run_cuda_sm80_tests:
-  config.unsupported = True
+    config.unsupported = True
index cf04454..361b657 100644 (file)
@@ -1,5 +1,5 @@
 # Disable ASAN's leak detection for python OpsDSL tests.
-config.environment['ASAN_OPTIONS'] = 'detect_leaks=0'
+config.environment["ASAN_OPTIONS"] = "detect_leaks=0"
 # Only run when python bindings are enabled.
 if not config.enable_bindings_python:
-  config.unsupported = True
+    config.unsupported = True
index 958aa86..1f9b636 100644 (file)
@@ -18,42 +18,45 @@ _SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(_SCRIPT_PATH)
 from tools import sparse_compiler
 
+
 @dsl.linalg_structured_op
 def sddmm_dsl(
     A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
     B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
     S=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N),
-    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)):
-  C[dsl.D.m,
-    dsl.D.n] += S[dsl.D.m, dsl.D.n] * A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
+    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True),
+):
+    C[dsl.D.m, dsl.D.n] += (
+        S[dsl.D.m, dsl.D.n] * A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
+    )
 
 
 def build_SDDMM(attr: st.EncodingAttr):
-  """Build SDDMM kernel.
+    """Build SDDMM kernel.
 
-  This method generates a linalg op with for matrix multiplication using
-  just the Python API. Effectively, a generic linalg op is constructed
-  that computes C(i,j) += S(i,j) SUM_k A(i,k) B(k,j) for sparse S.
-  """
-  module = ir.Module.create()
-  f64 = ir.F64Type.get()
-  a = ir.RankedTensorType.get([8, 8], f64)
-  b = ir.RankedTensorType.get([8, 8], f64)
-  c = ir.RankedTensorType.get([8, 8], f64)
-  s = ir.RankedTensorType.get([8, 8], f64, attr)
-  arguments = [a, b, s, c]
-  with ir.InsertionPoint(module.body):
+    This method generates a linalg op for matrix multiplication using
+    just the Python API. Effectively, a generic linalg op is constructed
+    that computes C(i,j) += S(i,j) SUM_k A(i,k) B(k,j) for sparse S.
+    """
+    module = ir.Module.create()
+    f64 = ir.F64Type.get()
+    a = ir.RankedTensorType.get([8, 8], f64)
+    b = ir.RankedTensorType.get([8, 8], f64)
+    c = ir.RankedTensorType.get([8, 8], f64)
+    s = ir.RankedTensorType.get([8, 8], f64, attr)
+    arguments = [a, b, s, c]
+    with ir.InsertionPoint(module.body):
 
-    @func.FuncOp.from_py_func(*arguments)
-    def sddmm(*args):
-      return sddmm_dsl(args[0], args[1], args[2], outs=[args[3]])
+        @func.FuncOp.from_py_func(*arguments)
+        def sddmm(*args):
+            return sddmm_dsl(args[0], args[1], args[2], outs=[args[3]])
 
-  return module
+    return module
 
 
 def boilerplate(attr: st.EncodingAttr):
-  """Returns boilerplate code for main driver."""
-  return f"""
+    """Returns boilerplate code for main driver."""
+    return f"""
 func.func @main(%a: tensor<8x8xf64>,
            %b: tensor<8x8xf64>,
            %c: tensor<8x8xf64>) -> tensor<8x8xf64> attributes {{ llvm.emit_c_interface }} {{
@@ -69,92 +72,100 @@ func.func @main(%a: tensor<8x8xf64>,
 
 
 def build_compile_and_run_SDDMMM(attr: st.EncodingAttr, compiler):
-  # Build.
-  module = build_SDDMM(attr)
-  func = str(module.operation.regions[0].blocks[0].operations[0].operation)
-  module = ir.Module.parse(func + boilerplate(attr))
-
-  # Compile.
-  engine = compiler.compile_and_jit(module)
-
-  # Set up numpy input and buffer for output.
-  a = np.array([[1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1],
-                [1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2],
-                [1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3],
-                [1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4],
-                [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5],
-                [1.6, 2.6, 3.6, 4.6, 5.6, 6.6, 7.6, 8.6],
-                [1.7, 2.7, 3.7, 4.7, 5.7, 6.7, 7.7, 8.7],
-                [1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8]], np.float64)
-  b = np.ones((8, 8), np.float64)
-  c = np.zeros((8, 8), np.float64)
-
-  mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
-  mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
-  mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
-
-  # Allocate a MemRefDescriptor to receive the output tensor.
-  # The buffer itself is allocated inside the MLIR code generation.
-  ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
-  mem_out = ctypes.pointer(ctypes.pointer(ref_out))
-
-  # Invoke the kernel and get numpy output.
-  # Built-in bufferization uses in-out buffers.
-  # TODO: replace with inplace comprehensive bufferization.
-  engine.invoke('main', mem_out, mem_a, mem_b, mem_c)
-
-  # Sanity check on computed result. Only a few elements
-  # are sampled from the full dense matrix multiplication.
-  full_matmul = np.matmul(a, b)
-  expected = np.zeros((8, 8), np.float64)
-  expected[0, 0] = 1.0 * full_matmul[0, 0]
-  expected[0, 2] = 2.0 * full_matmul[0, 2]
-  expected[4, 1] = 3.0 * full_matmul[4, 1]
-  c = rt.ranked_memref_to_numpy(mem_out[0])
-  if np.allclose(c, expected):
-    pass
-  else:
-    quit(f'FAILURE')
+    # Build.
+    module = build_SDDMM(attr)
+    func = str(module.operation.regions[0].blocks[0].operations[0].operation)
+    module = ir.Module.parse(func + boilerplate(attr))
+
+    # Compile.
+    engine = compiler.compile_and_jit(module)
+
+    # Set up numpy input and buffer for output.
+    a = np.array(
+        [
+            [1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1],
+            [1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2],
+            [1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3],
+            [1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4],
+            [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5],
+            [1.6, 2.6, 3.6, 4.6, 5.6, 6.6, 7.6, 8.6],
+            [1.7, 2.7, 3.7, 4.7, 5.7, 6.7, 7.7, 8.7],
+            [1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8],
+        ],
+        np.float64,
+    )
+    b = np.ones((8, 8), np.float64)
+    c = np.zeros((8, 8), np.float64)
+
+    mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
+    mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
+    mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
+
+    # Allocate a MemRefDescriptor to receive the output tensor.
+    # The buffer itself is allocated inside the MLIR code generation.
+    ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
+    mem_out = ctypes.pointer(ctypes.pointer(ref_out))
+
+    # Invoke the kernel and get numpy output.
+    # Built-in bufferization uses in-out buffers.
+    # TODO: replace with inplace comprehensive bufferization.
+    engine.invoke("main", mem_out, mem_a, mem_b, mem_c)
+
+    # Sanity check on computed result. Only a few elements
+    # are sampled from the full dense matrix multiplication.
+    full_matmul = np.matmul(a, b)
+    expected = np.zeros((8, 8), np.float64)
+    expected[0, 0] = 1.0 * full_matmul[0, 0]
+    expected[0, 2] = 2.0 * full_matmul[0, 2]
+    expected[4, 1] = 3.0 * full_matmul[4, 1]
+    c = rt.ranked_memref_to_numpy(mem_out[0])
+    if np.allclose(c, expected):
+        pass
+    else:
+        quit(f"FAILURE")
 
 
 def main():
-  support_lib = os.getenv('SUPPORT_LIB')
-  assert support_lib is not None, 'SUPPORT_LIB is undefined'
-  if not os.path.exists(support_lib):
-    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
-                            support_lib)
-
-  # CHECK-LABEL: TEST: testSDDMMM
-  print('\nTEST: testSDDMMM')
-  with ir.Context() as ctx, ir.Location.unknown():
-    count = 0
-    # Loop over various ways to compile and annotate the SDDMM kernel with
-    # a *single* sparse tensor. Note that we deliberate do not exhaustively
-    # search the full state space to reduce runtime of the test. It is
-    # straightforward to adapt the code below to explore more combinations.
-    levels = [[st.DimLevelType.dense, st.DimLevelType.dense],
-              [st.DimLevelType.dense, st.DimLevelType.compressed],
-              [st.DimLevelType.compressed, st.DimLevelType.dense],
-              [st.DimLevelType.compressed, st.DimLevelType.compressed]]
-    orderings = [
-        ir.AffineMap.get_permutation([0, 1]),
-        ir.AffineMap.get_permutation([1, 0])
-    ]
-    for level in levels:
-      for ordering in orderings:
-        for pwidth in [32]:
-          for iwidth in [32]:
-            for e in [True]:
-              attr = st.EncodingAttr.get(level, ordering, None, pwidth,
-                                         iwidth)
-              opt = (f'parallelization-strategy=none')
-              compiler = sparse_compiler.SparseCompiler(
-                  options=opt, opt_level=0, shared_libs=[support_lib])
-              build_compile_and_run_SDDMMM(attr, compiler)
-              count = count + 1
-  # CHECK: Passed 8 tests
-  print('Passed ', count, 'tests')
-
-
-if __name__ == '__main__':
-  main()
+    support_lib = os.getenv("SUPPORT_LIB")
+    assert support_lib is not None, "SUPPORT_LIB is undefined"
+    if not os.path.exists(support_lib):
+        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
+
+    # CHECK-LABEL: TEST: testSDDMMM
+    print("\nTEST: testSDDMMM")
+    with ir.Context() as ctx, ir.Location.unknown():
+        count = 0
+        # Loop over various ways to compile and annotate the SDDMM kernel with
+        # a *single* sparse tensor. Note that we deliberately do not exhaustively
+        # search the full state space to reduce runtime of the test. It is
+        # straightforward to adapt the code below to explore more combinations.
+        levels = [
+            [st.DimLevelType.dense, st.DimLevelType.dense],
+            [st.DimLevelType.dense, st.DimLevelType.compressed],
+            [st.DimLevelType.compressed, st.DimLevelType.dense],
+            [st.DimLevelType.compressed, st.DimLevelType.compressed],
+        ]
+        orderings = [
+            ir.AffineMap.get_permutation([0, 1]),
+            ir.AffineMap.get_permutation([1, 0]),
+        ]
+        for level in levels:
+            for ordering in orderings:
+                for pwidth in [32]:
+                    for iwidth in [32]:
+                        for e in [True]:
+                            attr = st.EncodingAttr.get(
+                                level, ordering, None, pwidth, iwidth
+                            )
+                            opt = f"parallelization-strategy=none"
+                            compiler = sparse_compiler.SparseCompiler(
+                                options=opt, opt_level=0, shared_libs=[support_lib]
+                            )
+                            build_compile_and_run_SDDMMM(attr, compiler)
+                            count = count + 1
+    # CHECK: Passed 8 tests
+    print("Passed ", count, "tests")
+
+
+if __name__ == "__main__":
+    main()
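The sampled check above follows the SDDMM definition C(i,j) += S(i,j) * SUM_k A(i,k) * B(k,j); a dense NumPy reference for the full result, for illustration only, would be:

    import numpy as np

    def sddmm_ref(a, b, s, c):
        # Scale each element of the dense product A @ B by the sampling matrix S.
        return c + s * (a @ b)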
index 97954ce..69f6cdc 100644 (file)
@@ -18,45 +18,47 @@ _SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(_SCRIPT_PATH)
 from tools import sparse_compiler
 
+
 @dsl.linalg_structured_op
 def matmul_dsl(
     A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
     B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
-    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)):
-  C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
+    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True),
+):
+    C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
 
 
 def build_SpMM(attr: st.EncodingAttr):
-  """Build SpMM kernel.
+    """Build SpMM kernel.
 
-  This method generates a linalg op with for matrix multiplication using
-  just the Python API. Effectively, a generic linalg op is constructed
-  that computes C(i,j) += A(i,k) * B(k,j) for annotated matrix A.
-  """
-  module = ir.Module.create()
-  f64 = ir.F64Type.get()
-  a = ir.RankedTensorType.get([3, 4], f64, attr)
-  b = ir.RankedTensorType.get([4, 2], f64)
-  c = ir.RankedTensorType.get([3, 2], f64)
-  arguments = [a, b, c]
-  with ir.InsertionPoint(module.body):
+    This method generates a linalg op for matrix multiplication using
+    just the Python API. Effectively, a generic linalg op is constructed
+    that computes C(i,j) += A(i,k) * B(k,j) for annotated matrix A.
+    """
+    module = ir.Module.create()
+    f64 = ir.F64Type.get()
+    a = ir.RankedTensorType.get([3, 4], f64, attr)
+    b = ir.RankedTensorType.get([4, 2], f64)
+    c = ir.RankedTensorType.get([3, 2], f64)
+    arguments = [a, b, c]
+    with ir.InsertionPoint(module.body):
 
-    @func.FuncOp.from_py_func(*arguments)
-    def spMxM(*args):
-      return matmul_dsl(args[0], args[1], outs=[args[2]])
+        @func.FuncOp.from_py_func(*arguments)
+        def spMxM(*args):
+            return matmul_dsl(args[0], args[1], outs=[args[2]])
 
-  return module
+    return module
 
 
 def boilerplate(attr: st.EncodingAttr):
-  """Returns boilerplate main method.
-
-  This method sets up a boilerplate main method that takes three tensors
-  (a, b, c), converts the first tensor a into s sparse tensor, and then
-  calls the sparse kernel for matrix multiplication. For convenience,
-  this part is purely done as string input.
-  """
-  return f"""
+    """Returns boilerplate main method.
+
+    This method sets up a boilerplate main method that takes three tensors
+    (a, b, c), converts the first tensor a into a sparse tensor, and then
+    calls the sparse kernel for matrix multiplication. For convenience,
+    this part is purely done as string input.
+    """
+    return f"""
 func.func @main(%ad: tensor<3x4xf64>, %b: tensor<4x2xf64>, %c: tensor<3x2xf64>) -> tensor<3x2xf64>
   attributes {{ llvm.emit_c_interface }} {{
   %a = sparse_tensor.convert %ad : tensor<3x4xf64> to tensor<3x4xf64, {attr}>
@@ -69,82 +71,87 @@ func.func @main(%ad: tensor<3x4xf64>, %b: tensor<4x2xf64>, %c: tensor<3x2xf64>)
 
 
 def build_compile_and_run_SpMM(attr: st.EncodingAttr, compiler):
-  # Build.
-  module = build_SpMM(attr)
-  func = str(module.operation.regions[0].blocks[0].operations[0].operation)
-  module = ir.Module.parse(func + boilerplate(attr))
-
-  # Compile.
-  engine = compiler.compile_and_jit(module)
-
-  # Set up numpy input and buffer for output.
-  a = np.array(
-      [[1.1, 0.0, 0.0, 1.4], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 3.3, 0.0]],
-      np.float64)
-  b = np.array([[1.0, 2.0], [4.0, 3.0], [5.0, 6.0], [8.0, 7.0]], np.float64)
-  c = np.zeros((3, 2), np.float64)
-
-  mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
-  mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
-  mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
-  # Allocate a MemRefDescriptor to receive the output tensor.
-  # The buffer itself is allocated inside the MLIR code generation.
-  ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
-  mem_out = ctypes.pointer(ctypes.pointer(ref_out))
-
-  # Invoke the kernel and get numpy output.
-  # Built-in bufferization uses in-out buffers.
-  # TODO: replace with inplace comprehensive bufferization.
-  engine.invoke('main', mem_out, mem_a, mem_b, mem_c)
-
-  # Sanity check on computed result.
-  expected = np.matmul(a, b);
-  c = rt.ranked_memref_to_numpy(mem_out[0])
-  if np.allclose(c, expected):
-    pass
-  else:
-    quit(f'FAILURE')
+    # Build.
+    module = build_SpMM(attr)
+    func = str(module.operation.regions[0].blocks[0].operations[0].operation)
+    module = ir.Module.parse(func + boilerplate(attr))
+
+    # Compile.
+    engine = compiler.compile_and_jit(module)
+
+    # Set up numpy input and buffer for output.
+    a = np.array(
+        [[1.1, 0.0, 0.0, 1.4], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 3.3, 0.0]], np.float64
+    )
+    b = np.array([[1.0, 2.0], [4.0, 3.0], [5.0, 6.0], [8.0, 7.0]], np.float64)
+    c = np.zeros((3, 2), np.float64)
+
+    mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
+    mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
+    mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
+    # Allocate a MemRefDescriptor to receive the output tensor.
+    # The buffer itself is allocated inside the MLIR code generation.
+    ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
+    mem_out = ctypes.pointer(ctypes.pointer(ref_out))
+
+    # Invoke the kernel and get numpy output.
+    # Built-in bufferization uses in-out buffers.
+    # TODO: replace with inplace comprehensive bufferization.
+    engine.invoke("main", mem_out, mem_a, mem_b, mem_c)
+
+    # Sanity check on computed result.
+    expected = np.matmul(a, b)
+    c = rt.ranked_memref_to_numpy(mem_out[0])
+    if np.allclose(c, expected):
+        pass
+    else:
+        quit(f"FAILURE")
 
 
 def main():
-  support_lib = os.getenv('SUPPORT_LIB')
-  assert support_lib is not None, 'SUPPORT_LIB is undefined'
-  if not os.path.exists(support_lib):
-    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
-
-  # CHECK-LABEL: TEST: testSpMM
-  print('\nTEST: testSpMM')
-  with ir.Context() as ctx, ir.Location.unknown():
-    count = 0
-    # Loop over various ways to compile and annotate the SpMM kernel with
-    # a *single* sparse tensor. Note that we deliberate do not exhaustively
-    # search the full state space to reduce runtime of the test. It is
-    # straightforward to adapt the code below to explore more combinations.
-
-    vl = 1
-    e = False
-    opt = (f'parallelization-strategy=none')
-    levels = [[st.DimLevelType.dense, st.DimLevelType.dense],
-              [st.DimLevelType.dense, st.DimLevelType.compressed],
-              [st.DimLevelType.compressed, st.DimLevelType.dense],
-              [st.DimLevelType.compressed, st.DimLevelType.compressed]]
-    orderings = [
-        ir.AffineMap.get_permutation([0, 1]),
-        ir.AffineMap.get_permutation([1, 0])
-    ]
-    bitwidths = [0]
-    compiler = sparse_compiler.SparseCompiler(
-        options=opt, opt_level=0, shared_libs=[support_lib])
-    for level in levels:
-      for ordering in orderings:
-        for pwidth in bitwidths:
-          for iwidth in bitwidths:
-            attr = st.EncodingAttr.get(level, ordering, None, pwidth, iwidth)
-            build_compile_and_run_SpMM(attr, compiler)
-            count = count + 1
-    # CHECK: Passed 8 tests
-    print('Passed ', count, 'tests')
-
-
-if __name__ == '__main__':
-  main()
+    support_lib = os.getenv("SUPPORT_LIB")
+    assert support_lib is not None, "SUPPORT_LIB is undefined"
+    if not os.path.exists(support_lib):
+        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
+
+    # CHECK-LABEL: TEST: testSpMM
+    print("\nTEST: testSpMM")
+    with ir.Context() as ctx, ir.Location.unknown():
+        count = 0
+        # Loop over various ways to compile and annotate the SpMM kernel with
+        # a *single* sparse tensor. Note that we deliberately do not exhaustively
+        # search the full state space to reduce runtime of the test. It is
+        # straightforward to adapt the code below to explore more combinations.
+
+        vl = 1
+        e = False
+        opt = f"parallelization-strategy=none"
+        levels = [
+            [st.DimLevelType.dense, st.DimLevelType.dense],
+            [st.DimLevelType.dense, st.DimLevelType.compressed],
+            [st.DimLevelType.compressed, st.DimLevelType.dense],
+            [st.DimLevelType.compressed, st.DimLevelType.compressed],
+        ]
+        orderings = [
+            ir.AffineMap.get_permutation([0, 1]),
+            ir.AffineMap.get_permutation([1, 0]),
+        ]
+        bitwidths = [0]
+        compiler = sparse_compiler.SparseCompiler(
+            options=opt, opt_level=0, shared_libs=[support_lib]
+        )
+        for level in levels:
+            for ordering in orderings:
+                for pwidth in bitwidths:
+                    for iwidth in bitwidths:
+                        attr = st.EncodingAttr.get(
+                            level, ordering, None, pwidth, iwidth
+                        )
+                        build_compile_and_run_SpMM(attr, compiler)
+                        count = count + 1
+        # CHECK: Passed 8 tests
+        print("Passed ", count, "tests")
+
+
+if __name__ == "__main__":
+    main()
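
For reference, the "Passed 8 tests" CHECK line above follows directly from the size of the swept configuration space; a minimal arithmetic sketch, with the counts read off the lists defined in main() (the variable names below are illustrative only):

    # Sketch: the expected test count is the product of the swept option lists.
    num_levels = 4      # dense/dense, dense/compressed, compressed/dense, compressed/compressed
    num_orderings = 2   # the permutations [0, 1] and [1, 0]
    num_bitwidths = 1   # bitwidths = [0], swept once for pwidth and once for iwidth
    assert num_levels * num_orderings * num_bitwidths * num_bitwidths == 8
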
index b29b029..a41bde1 100644 (file)
@@ -57,49 +57,52 @@ func.func @main(%ad: tensor<3x4xf64>, %bd: tensor<3x4xf64>) -> tensor<3x4xf64, #
 
 
 def _run_test(support_lib, kernel):
-  """Compiles, runs and checks results."""
-  compiler = sparse_compiler.SparseCompiler(
-      options='', opt_level=2, shared_libs=[support_lib])
-  module = ir.Module.parse(kernel)
-  engine = compiler.compile_and_jit(module)
-
-  # Set up numpy inputs and buffer for output.
-  a = np.array(
-      [[1.1, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 6.6, 0.0]],
-      np.float64)
-  b = np.array(
-      [[1.1, 0.0, 0.0, 2.8], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
-      np.float64)
-
-  mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
-  mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
-
-  # The sparse tensor output is a pointer to pointer of char.
-  out = ctypes.c_char(0)
-  mem_out = ctypes.pointer(ctypes.pointer(out))
-
-  # Invoke the kernel.
-  engine.invoke('main', mem_a, mem_b, mem_out)
-
-  # Retrieve and check the result.
-  rank, nse, shape, values, indices = test_tools.sparse_tensor_to_coo_tensor(
-      support_lib, mem_out[0], np.float64)
-
-  # CHECK: PASSED
-  if np.allclose(values, [2.2, 2.8, 6.6]) and np.allclose(
-      indices, [[0, 0], [0, 3], [2, 2]]):
-    print('PASSED')
-  else:
-    quit('FAILURE')
+    """Compiles, runs and checks results."""
+    compiler = sparse_compiler.SparseCompiler(
+        options="", opt_level=2, shared_libs=[support_lib]
+    )
+    module = ir.Module.parse(kernel)
+    engine = compiler.compile_and_jit(module)
+
+    # Set up numpy inputs and buffer for output.
+    a = np.array(
+        [[1.1, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 6.6, 0.0]], np.float64
+    )
+    b = np.array(
+        [[1.1, 0.0, 0.0, 2.8], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], np.float64
+    )
+
+    mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
+    mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
+
+    # The sparse tensor output is a pointer to pointer of char.
+    out = ctypes.c_char(0)
+    mem_out = ctypes.pointer(ctypes.pointer(out))
+
+    # Invoke the kernel.
+    engine.invoke("main", mem_a, mem_b, mem_out)
+
+    # Retrieve and check the result.
+    rank, nse, shape, values, indices = test_tools.sparse_tensor_to_coo_tensor(
+        support_lib, mem_out[0], np.float64
+    )
+
+    # CHECK: PASSED
+    if np.allclose(values, [2.2, 2.8, 6.6]) and np.allclose(
+        indices, [[0, 0], [0, 3], [2, 2]]
+    ):
+        print("PASSED")
+    else:
+        quit("FAILURE")
 
 
 def test_elementwise_add():
-  # Obtain path to runtime support library.
-  support_lib = os.getenv('SUPPORT_LIB')
-  assert support_lib is not None, 'SUPPORT_LIB is undefined'
-  assert os.path.exists(support_lib), f'{support_lib} does not exist'
-  with ir.Context() as ctx, ir.Location.unknown():
-    _run_test(support_lib, _KERNEL_STR)
+    # Obtain path to runtime support library.
+    support_lib = os.getenv("SUPPORT_LIB")
+    assert support_lib is not None, "SUPPORT_LIB is undefined"
+    assert os.path.exists(support_lib), f"{support_lib} does not exist"
+    with ir.Context() as ctx, ir.Location.unknown():
+        _run_test(support_lib, _KERNEL_STR)
 
 
 test_elementwise_add()
index 7d57b1c..7d77490 100644 (file)
@@ -18,8 +18,8 @@ from tools import sparse_compiler
 
 # TODO: move more into actual IR building.
 def boilerplate(attr: st.EncodingAttr):
-  """Returns boilerplate main method."""
-  return f"""
+    """Returns boilerplate main method."""
+    return f"""
 func.func @main(%p : !llvm.ptr<i8>) -> () attributes {{ llvm.emit_c_interface }} {{
   %d = arith.constant sparse<[[0, 0], [1, 1], [0, 9], [9, 0], [4, 4]],
                              [1.0, 2.0, 3.0, 4.0, 5.0]> : tensor<10x10xf64>
@@ -31,13 +31,13 @@ func.func @main(%p : !llvm.ptr<i8>) -> () attributes {{ llvm.emit_c_interface }}
 
 
 def expected():
-  """Returns expected contents of output.
+    """Returns expected contents of output.
 
-  Regardless of the dimension ordering, compression, and bitwidths that are
-  used in the sparse tensor, the output is always lexicographically sorted
-  by natural index order.
-  """
-  return f"""; extended FROSTT format
+    Regardless of the dimension ordering, compression, and bitwidths that are
+    used in the sparse tensor, the output is always lexicographically sorted
+    by natural index order.
+    """
+    return f"""; extended FROSTT format
 2 5
 10 10
 1 1 1
@@ -49,53 +49,55 @@ def expected():
 
 
 def build_compile_and_run_output(attr: st.EncodingAttr, compiler):
-  # Build and Compile.
-  module = ir.Module.parse(boilerplate(attr))
-  engine = compiler.compile_and_jit(module)
+    # Build and Compile.
+    module = ir.Module.parse(boilerplate(attr))
+    engine = compiler.compile_and_jit(module)
 
-  # Invoke the kernel and compare output.
-  with tempfile.TemporaryDirectory() as test_dir:
-    out = os.path.join(test_dir, 'out.tns')
-    buf = out.encode('utf-8')
-    mem_a = ctypes.pointer(ctypes.pointer(ctypes.create_string_buffer(buf)))
-    engine.invoke('main', mem_a)
+    # Invoke the kernel and compare output.
+    with tempfile.TemporaryDirectory() as test_dir:
+        out = os.path.join(test_dir, "out.tns")
+        buf = out.encode("utf-8")
+        mem_a = ctypes.pointer(ctypes.pointer(ctypes.create_string_buffer(buf)))
+        engine.invoke("main", mem_a)
 
-    actual = open(out).read()
-    if actual != expected():
-      quit('FAILURE')
+        actual = open(out).read()
+        if actual != expected():
+            quit("FAILURE")
 
 
 def main():
-  support_lib = os.getenv('SUPPORT_LIB')
-  assert support_lib is not None, 'SUPPORT_LIB is undefined'
-  if not os.path.exists(support_lib):
-    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
-                            support_lib)
-
-  # CHECK-LABEL: TEST: test_output
-  print('\nTEST: test_output')
-  count = 0
-  with ir.Context() as ctx, ir.Location.unknown():
-    # Loop over various sparse types: CSR, DCSR, CSC, DCSC.
-    levels = [[st.DimLevelType.dense, st.DimLevelType.compressed],
-              [st.DimLevelType.compressed, st.DimLevelType.compressed]]
-    orderings = [
-        ir.AffineMap.get_permutation([0, 1]),
-        ir.AffineMap.get_permutation([1, 0])
-    ]
-    bitwidths = [8, 16, 32, 64]
-    compiler = sparse_compiler.SparseCompiler(
-        options='', opt_level=2, shared_libs=[support_lib])
-    for level in levels:
-      for ordering in orderings:
-        for bwidth in bitwidths:
-          attr = st.EncodingAttr.get(level, ordering, None, bwidth, bwidth)
-          build_compile_and_run_output(attr, compiler)
-          count = count + 1
-
-  # CHECK: Passed 16 tests
-  print('Passed', count, 'tests')
-
-
-if __name__ == '__main__':
-  main()
+    support_lib = os.getenv("SUPPORT_LIB")
+    assert support_lib is not None, "SUPPORT_LIB is undefined"
+    if not os.path.exists(support_lib):
+        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
+
+    # CHECK-LABEL: TEST: test_output
+    print("\nTEST: test_output")
+    count = 0
+    with ir.Context() as ctx, ir.Location.unknown():
+        # Loop over various sparse types: CSR, DCSR, CSC, DCSC.
+        levels = [
+            [st.DimLevelType.dense, st.DimLevelType.compressed],
+            [st.DimLevelType.compressed, st.DimLevelType.compressed],
+        ]
+        orderings = [
+            ir.AffineMap.get_permutation([0, 1]),
+            ir.AffineMap.get_permutation([1, 0]),
+        ]
+        bitwidths = [8, 16, 32, 64]
+        compiler = sparse_compiler.SparseCompiler(
+            options="", opt_level=2, shared_libs=[support_lib]
+        )
+        for level in levels:
+            for ordering in orderings:
+                for bwidth in bitwidths:
+                    attr = st.EncodingAttr.get(level, ordering, None, bwidth, bwidth)
+                    build_compile_and_run_output(attr, compiler)
+                    count = count + 1
+
+    # CHECK: Passed 16 tests
+    print("Passed", count, "tests")
+
+
+if __name__ == "__main__":
+    main()
index 3a04e5b..373f745 100644 (file)
@@ -28,216 +28,241 @@ from tools import sparse_compiler
 # TODO: move this boilerplate to its own module, so it can be used by
 # other tests and programs.
 class TypeConverter:
-  """Converter between NumPy types and MLIR types."""
-
-  def __init__(self, context: ir.Context):
-    # Note 1: these are numpy "scalar types" (i.e., the values of
-    # np.sctypeDict) not numpy "dtypes" (i.e., the np.dtype class).
-    #
-    # Note 2: we must construct the MLIR types in the same context as the
-    # types that'll be passed to irtype_to_sctype() or irtype_to_dtype();
-    # otherwise, those methods will raise a KeyError.
-    types_list = [
-      (np.float64, ir.F64Type.get(context=context)),
-      (np.float32, ir.F32Type.get(context=context)),
-      (np.int64, ir.IntegerType.get_signless(64, context=context)),
-      (np.int32, ir.IntegerType.get_signless(32, context=context)),
-      (np.int16, ir.IntegerType.get_signless(16, context=context)),
-      (np.int8, ir.IntegerType.get_signless(8, context=context)),
-    ]
-    self._sc2ir = dict(types_list)
-    self._ir2sc = dict(( (ir,sc) for sc,ir in types_list ))
-
-  def dtype_to_irtype(self, dtype: np.dtype) -> ir.Type:
-    """Returns the MLIR equivalent of a NumPy dtype."""
-    try:
-      return self.sctype_to_irtype(dtype.type)
-    except KeyError as e:
-      raise KeyError(f'Unknown dtype: {dtype}') from e
-
-  def sctype_to_irtype(self, sctype) -> ir.Type:
-    """Returns the MLIR equivalent of a NumPy scalar type."""
-    if sctype in self._sc2ir:
-      return self._sc2ir[sctype]
-    else:
-      raise KeyError(f'Unknown sctype: {sctype}')
-
-  def irtype_to_dtype(self, tp: ir.Type) -> np.dtype:
-    """Returns the NumPy dtype equivalent of an MLIR type."""
-    return np.dtype(self.irtype_to_sctype(tp))
-
-  def irtype_to_sctype(self, tp: ir.Type):
-    """Returns the NumPy scalar-type equivalent of an MLIR type."""
-    if tp in self._ir2sc:
-      return self._ir2sc[tp]
-    else:
-      raise KeyError(f'Unknown ir.Type: {tp}')
-
-  def get_RankedTensorType_of_nparray(self, nparray: np.ndarray) -> ir.RankedTensorType:
-    """Returns the ir.RankedTensorType of a NumPy array.  Note that NumPy
-    arrays can only be converted to/from dense tensors, not sparse tensors."""
-    # TODO: handle strides as well?
-    return ir.RankedTensorType.get(nparray.shape,
-                                   self.dtype_to_irtype(nparray.dtype))
+    """Converter between NumPy types and MLIR types."""
+
+    def __init__(self, context: ir.Context):
+        # Note 1: these are numpy "scalar types" (i.e., the values of
+        # np.sctypeDict) not numpy "dtypes" (i.e., the np.dtype class).
+        #
+        # Note 2: we must construct the MLIR types in the same context as the
+        # types that'll be passed to irtype_to_sctype() or irtype_to_dtype();
+        # otherwise, those methods will raise a KeyError.
+        types_list = [
+            (np.float64, ir.F64Type.get(context=context)),
+            (np.float32, ir.F32Type.get(context=context)),
+            (np.int64, ir.IntegerType.get_signless(64, context=context)),
+            (np.int32, ir.IntegerType.get_signless(32, context=context)),
+            (np.int16, ir.IntegerType.get_signless(16, context=context)),
+            (np.int8, ir.IntegerType.get_signless(8, context=context)),
+        ]
+        self._sc2ir = dict(types_list)
+        self._ir2sc = dict(((ir, sc) for sc, ir in types_list))
+
+    def dtype_to_irtype(self, dtype: np.dtype) -> ir.Type:
+        """Returns the MLIR equivalent of a NumPy dtype."""
+        try:
+            return self.sctype_to_irtype(dtype.type)
+        except KeyError as e:
+            raise KeyError(f"Unknown dtype: {dtype}") from e
+
+    def sctype_to_irtype(self, sctype) -> ir.Type:
+        """Returns the MLIR equivalent of a NumPy scalar type."""
+        if sctype in self._sc2ir:
+            return self._sc2ir[sctype]
+        else:
+            raise KeyError(f"Unknown sctype: {sctype}")
+
+    def irtype_to_dtype(self, tp: ir.Type) -> np.dtype:
+        """Returns the NumPy dtype equivalent of an MLIR type."""
+        return np.dtype(self.irtype_to_sctype(tp))
+
+    def irtype_to_sctype(self, tp: ir.Type):
+        """Returns the NumPy scalar-type equivalent of an MLIR type."""
+        if tp in self._ir2sc:
+            return self._ir2sc[tp]
+        else:
+            raise KeyError(f"Unknown ir.Type: {tp}")
+
+    def get_RankedTensorType_of_nparray(
+        self, nparray: np.ndarray
+    ) -> ir.RankedTensorType:
+        """Returns the ir.RankedTensorType of a NumPy array.  Note that NumPy
+        arrays can only be converted to/from dense tensors, not sparse tensors."""
+        # TODO: handle strides as well?
+        return ir.RankedTensorType.get(
+            nparray.shape, self.dtype_to_irtype(nparray.dtype)
+        )
+
 
 # ===----------------------------------------------------------------------=== #
 
+
 class StressTest:
-  def __init__(self, tyconv: TypeConverter):
-    self._tyconv = tyconv
-    self._roundtripTp = None
-    self._module = None
-    self._engine = None
-
-  def _assertEqualsRoundtripTp(self, tp: ir.RankedTensorType):
-    assert self._roundtripTp is not None, \
-        'StressTest: uninitialized roundtrip type'
-    if tp != self._roundtripTp:
-      raise AssertionError(
-          f"Type is not equal to the roundtrip type.\n"
-          f"\tExpected: {self._roundtripTp}\n"
-          f"\tFound:    {tp}\n")
-
-  def build(self, types: List[ir.Type]):
-    """Builds the ir.Module.  The module has only the @main function,
-    which will convert the input through the list of types and then back
-    to the initial type.  The roundtrip type must be a dense tensor."""
-    assert self._module is None, 'StressTest: must not call build() repeatedly'
-    self._module = ir.Module.create()
-    with ir.InsertionPoint(self._module.body):
-      tp0 = types.pop(0)
-      self._roundtripTp = tp0
-      # TODO: assert dense? assert element type is recognised by the TypeConverter?
-      types.append(tp0)
-      funcTp = ir.FunctionType.get(inputs=[tp0], results=[tp0])
-      funcOp = func.FuncOp(name='main', type=funcTp)
-      funcOp.attributes['llvm.emit_c_interface'] = ir.UnitAttr.get()
-      with ir.InsertionPoint(funcOp.add_entry_block()):
-        arg0 = funcOp.entry_block.arguments[0]
-        self._assertEqualsRoundtripTp(arg0.type)
-        v = st.ConvertOp(types.pop(0), arg0)
-        for tp in types:
-          w = st.ConvertOp(tp, v)
-          # Release intermediate tensors before they fall out of scope.
-          bufferization.DeallocTensorOp(v.result)
-          v = w
-        self._assertEqualsRoundtripTp(v.result.type)
-        func.ReturnOp(v)
-    return self
-
-  def writeTo(self, filename):
-    """Write the ir.Module to the given file.  If the file already exists,
-    then raises an error.  If the filename is None, then is a no-op."""
-    assert self._module is not None, \
-        'StressTest: must call build() before writeTo()'
-    if filename is None:
-      # Silent no-op, for convenience.
-      return self
-    if os.path.exists(filename):
-      raise FileExistsError(errno.EEXIST, os.strerror(errno.EEXIST), filename)
-    with open(filename, 'w') as f:
-      f.write(str(self._module))
-    return self
-
-  def compile(self, compiler):
-    """Compile the ir.Module."""
-    assert self._module is not None, \
-        'StressTest: must call build() before compile()'
-    assert self._engine is None, \
-        'StressTest: must not call compile() repeatedly'
-    self._engine = compiler.compile_and_jit(self._module)
-    return self
-
-  def run(self, np_arg0: np.ndarray) -> np.ndarray:
-    """Runs the test on the given numpy array, and returns the resulting
-    numpy array."""
-    assert self._engine is not None, \
-        'StressTest: must call compile() before run()'
-    self._assertEqualsRoundtripTp(
-        self._tyconv.get_RankedTensorType_of_nparray(np_arg0))
-    np_out = np.zeros(np_arg0.shape, dtype=np_arg0.dtype)
-    self._assertEqualsRoundtripTp(
-        self._tyconv.get_RankedTensorType_of_nparray(np_out))
-    mem_arg0 = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_arg0)))
-    mem_out = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_out)))
-    self._engine.invoke('main', mem_out, mem_arg0)
-    return rt.ranked_memref_to_numpy(mem_out[0])
+    def __init__(self, tyconv: TypeConverter):
+        self._tyconv = tyconv
+        self._roundtripTp = None
+        self._module = None
+        self._engine = None
+
+    def _assertEqualsRoundtripTp(self, tp: ir.RankedTensorType):
+        assert self._roundtripTp is not None, "StressTest: uninitialized roundtrip type"
+        if tp != self._roundtripTp:
+            raise AssertionError(
+                f"Type is not equal to the roundtrip type.\n"
+                f"\tExpected: {self._roundtripTp}\n"
+                f"\tFound:    {tp}\n"
+            )
+
+    def build(self, types: List[ir.Type]):
+        """Builds the ir.Module.  The module has only the @main function,
+        which will convert the input through the list of types and then back
+        to the initial type.  The roundtrip type must be a dense tensor."""
+        assert self._module is None, "StressTest: must not call build() repeatedly"
+        self._module = ir.Module.create()
+        with ir.InsertionPoint(self._module.body):
+            tp0 = types.pop(0)
+            self._roundtripTp = tp0
+            # TODO: assert dense? assert element type is recognised by the TypeConverter?
+            types.append(tp0)
+            funcTp = ir.FunctionType.get(inputs=[tp0], results=[tp0])
+            funcOp = func.FuncOp(name="main", type=funcTp)
+            funcOp.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+            with ir.InsertionPoint(funcOp.add_entry_block()):
+                arg0 = funcOp.entry_block.arguments[0]
+                self._assertEqualsRoundtripTp(arg0.type)
+                v = st.ConvertOp(types.pop(0), arg0)
+                for tp in types:
+                    w = st.ConvertOp(tp, v)
+                    # Release intermediate tensors before they fall out of scope.
+                    bufferization.DeallocTensorOp(v.result)
+                    v = w
+                self._assertEqualsRoundtripTp(v.result.type)
+                func.ReturnOp(v)
+        return self
+
+    def writeTo(self, filename):
+        """Write the ir.Module to the given file.  If the file already exists,
+        then an error is raised.  If the filename is None, this is a no-op."""
+        assert (
+            self._module is not None
+        ), "StressTest: must call build() before writeTo()"
+        if filename is None:
+            # Silent no-op, for convenience.
+            return self
+        if os.path.exists(filename):
+            raise FileExistsError(errno.EEXIST, os.strerror(errno.EEXIST), filename)
+        with open(filename, "w") as f:
+            f.write(str(self._module))
+        return self
+
+    def compile(self, compiler):
+        """Compile the ir.Module."""
+        assert (
+            self._module is not None
+        ), "StressTest: must call build() before compile()"
+        assert self._engine is None, "StressTest: must not call compile() repeatedly"
+        self._engine = compiler.compile_and_jit(self._module)
+        return self
+
+    def run(self, np_arg0: np.ndarray) -> np.ndarray:
+        """Runs the test on the given numpy array, and returns the resulting
+        numpy array."""
+        assert self._engine is not None, "StressTest: must call compile() before run()"
+        self._assertEqualsRoundtripTp(
+            self._tyconv.get_RankedTensorType_of_nparray(np_arg0)
+        )
+        np_out = np.zeros(np_arg0.shape, dtype=np_arg0.dtype)
+        self._assertEqualsRoundtripTp(
+            self._tyconv.get_RankedTensorType_of_nparray(np_out)
+        )
+        mem_arg0 = ctypes.pointer(
+            ctypes.pointer(rt.get_ranked_memref_descriptor(np_arg0))
+        )
+        mem_out = ctypes.pointer(
+            ctypes.pointer(rt.get_ranked_memref_descriptor(np_out))
+        )
+        self._engine.invoke("main", mem_out, mem_arg0)
+        return rt.ranked_memref_to_numpy(mem_out[0])
+
 
 # ===----------------------------------------------------------------------=== #
 
+
 def main():
-  """
-  USAGE: python3 test_stress.py [raw_module.mlir [compiled_module.mlir]]
-
-  The environment variable SUPPORT_LIB must be set to point to the
-  libmlir_c_runner_utils shared library.  There are two optional
-  arguments, for debugging purposes.  The first argument specifies where
-  to write out the raw/generated ir.Module.  The second argument specifies
-  where to write out the compiled version of that ir.Module.
-  """
-  support_lib = os.getenv('SUPPORT_LIB')
-  assert support_lib is not None, 'SUPPORT_LIB is undefined'
-  if not os.path.exists(support_lib):
-    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
-
-  # CHECK-LABEL: TEST: test_stress
-  print("\nTEST: test_stress")
-  with ir.Context() as ctx, ir.Location.unknown():
-    # Disable direct sparse2sparse conversion, because it doubles the time!
-    # TODO: While direct s2s is far too slow for per-commit testing,
-    # we should have some framework ensure that we run this test with
-    # `s2s=0` on a regular basis, to ensure that it does continue to work.
-    # TODO: be sure to test s2s=0 together with singletons.
-    s2s = 1
-    sparsification_options = (
-        f'parallelization-strategy=none '
-        f's2s-strategy={s2s}')
-    compiler = sparse_compiler.SparseCompiler(
-        options=sparsification_options, opt_level=0, shared_libs=[support_lib])
-    f64 = ir.F64Type.get()
-    # Be careful about increasing this because
-    #     len(types) = 1 + len(level_choices)^rank * rank! * len(bitwidths)^2
-    shape = range(2, 3)
-    rank = len(shape)
-    # All combinations.
-    # TODO: add singleton here too; which requires updating how `np_arg0`
-    # is initialized below.
-    levels = list(itertools.product(*itertools.repeat(
-      [st.DimLevelType.dense, st.DimLevelType.compressed], rank)))
-    # All permutations.
-    orderings = list(map(ir.AffineMap.get_permutation,
-      itertools.permutations(range(rank))))
-    bitwidths = [0]
-    # The first type must be a dense tensor for numpy conversion to work.
-    types = [ir.RankedTensorType.get(shape, f64)]
-    for level in levels:
-      for ordering in orderings:
-        for pwidth in bitwidths:
-          for iwidth in bitwidths:
-            attr = st.EncodingAttr.get(level, ordering, None, pwidth, iwidth)
-            types.append(ir.RankedTensorType.get(shape, f64, attr))
-    #
-    # For exhaustiveness we should have one or more StressTest, such
-    # that their paths cover all 2*n*(n-1) directed pairwise combinations
-    # of the `types` set.  However, since n is already superexponential,
-    # such exhaustiveness would be prohibitive for a test that runs on
-    # every commit.  So for now we'll just pick one particular path that
-    # at least hits all n elements of the `types` set.
-    #
-    tyconv = TypeConverter(ctx)
-    size = 1
-    for d in shape:
-      size *= d
-    np_arg0 = np.arange(size, dtype=tyconv.irtype_to_dtype(f64)).reshape(*shape)
-    np_out = (
-        StressTest(tyconv).build(types).writeTo(
-            sys.argv[1] if len(sys.argv) > 1 else None).compile(compiler)
-        .writeTo(sys.argv[2] if len(sys.argv) > 2 else None).run(np_arg0))
-    # CHECK: Passed
-    if np.allclose(np_out, np_arg0):
-      print('Passed')
-    else:
-      sys.exit('FAILURE')
-
-if __name__ == '__main__':
-  main()
+    """
+    USAGE: python3 test_stress.py [raw_module.mlir [compiled_module.mlir]]
+
+    The environment variable SUPPORT_LIB must be set to point to the
+    libmlir_c_runner_utils shared library.  There are two optional
+    arguments, for debugging purposes.  The first argument specifies where
+    to write out the raw/generated ir.Module.  The second argument specifies
+    where to write out the compiled version of that ir.Module.
+    """
+    support_lib = os.getenv("SUPPORT_LIB")
+    assert support_lib is not None, "SUPPORT_LIB is undefined"
+    if not os.path.exists(support_lib):
+        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
+
+    # CHECK-LABEL: TEST: test_stress
+    print("\nTEST: test_stress")
+    with ir.Context() as ctx, ir.Location.unknown():
+        # Disable direct sparse2sparse conversion, because it doubles the time!
+        # TODO: While direct s2s is far too slow for per-commit testing,
+        # we should have some framework ensure that we run this test with
+        # `s2s=0` on a regular basis, to ensure that it does continue to work.
+        # TODO: be sure to test s2s=0 together with singletons.
+        s2s = 1
+        sparsification_options = f"parallelization-strategy=none " f"s2s-strategy={s2s}"
+        compiler = sparse_compiler.SparseCompiler(
+            options=sparsification_options, opt_level=0, shared_libs=[support_lib]
+        )
+        f64 = ir.F64Type.get()
+        # Be careful about increasing this because
+        #     len(types) = 1 + len(level_choices)^rank * rank! * len(bitwidths)^2
+        shape = range(2, 3)
+        rank = len(shape)
+        # All combinations.
+        # TODO: add singleton here too; which requires updating how `np_arg0`
+        # is initialized below.
+        levels = list(
+            itertools.product(
+                *itertools.repeat(
+                    [st.DimLevelType.dense, st.DimLevelType.compressed], rank
+                )
+            )
+        )
+        # All permutations.
+        orderings = list(
+            map(ir.AffineMap.get_permutation, itertools.permutations(range(rank)))
+        )
+        bitwidths = [0]
+        # The first type must be a dense tensor for numpy conversion to work.
+        types = [ir.RankedTensorType.get(shape, f64)]
+        for level in levels:
+            for ordering in orderings:
+                for pwidth in bitwidths:
+                    for iwidth in bitwidths:
+                        attr = st.EncodingAttr.get(
+                            level, ordering, None, pwidth, iwidth
+                        )
+                        types.append(ir.RankedTensorType.get(shape, f64, attr))
+        #
+        # For exhaustiveness we should have one or more StressTest, such
+        # that their paths cover all 2*n*(n-1) directed pairwise combinations
+        # of the `types` set.  However, since n is already superexponential,
+        # such exhaustiveness would be prohibitive for a test that runs on
+        # every commit.  So for now we'll just pick one particular path that
+        # at least hits all n elements of the `types` set.
+        #
+        tyconv = TypeConverter(ctx)
+        size = 1
+        for d in shape:
+            size *= d
+        np_arg0 = np.arange(size, dtype=tyconv.irtype_to_dtype(f64)).reshape(*shape)
+        np_out = (
+            StressTest(tyconv)
+            .build(types)
+            .writeTo(sys.argv[1] if len(sys.argv) > 1 else None)
+            .compile(compiler)
+            .writeTo(sys.argv[2] if len(sys.argv) > 2 else None)
+            .run(np_arg0)
+        )
+        # CHECK: Passed
+        if np.allclose(np_out, np_arg0):
+            print("Passed")
+        else:
+            sys.exit("FAILURE")
+
+
+if __name__ == "__main__":
+    main()
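
A minimal usage sketch of the TypeConverter defined above, assuming the same imports as this test (np, ir); the assertions merely restate the mappings listed in __init__:

    # Sketch: round-trip between NumPy and MLIR types inside one ir.Context.
    with ir.Context() as ctx:
        tyconv = TypeConverter(ctx)
        f64 = ir.F64Type.get()
        # MLIR f64 maps back to the NumPy float64 dtype.
        assert tyconv.irtype_to_dtype(f64) == np.dtype(np.float64)
        # NumPy int32 maps to a signless MLIR i32 in the same context.
        assert tyconv.dtype_to_irtype(np.dtype(np.int32)) == ir.IntegerType.get_signless(32)
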
index f5b0ab6..785d42c 100644 (file)
@@ -11,65 +11,71 @@ import numpy as np
 
 @functools.lru_cache()
 def _get_c_shared_lib(lib_name: str):
-  """Loads and returns the requested C shared library.
+    """Loads and returns the requested C shared library.
 
-  Args:
-    lib_name: A string representing the C shared library.
+    Args:
+      lib_name: A string representing the C shared library.
 
-  Returns:
-    The C shared library.
+    Returns:
+      The C shared library.
 
-  Raises:
-    OSError: If there is any problem in loading the shared library.
-    ValueError:  If the shared library doesn't contain the needed routine.
-  """
-  # This raises OSError exception if there is any problem in loading the shared
-  # library.
-  c_lib = ctypes.CDLL(lib_name)
+    Raises:
+      OSError: If there is any problem in loading the shared library.
+      ValueError:  If the shared library doesn't contain the needed routine.
+    """
+    # This raises an OSError exception if there is any problem in loading the shared
+    # library.
+    c_lib = ctypes.CDLL(lib_name)
 
-  try:
-    c_lib.convertFromMLIRSparseTensorF64.restype = ctypes.c_void_p
-  except Exception as e:
-    raise ValueError('Missing function convertFromMLIRSparseTensorF64 from '
-                     f'the C shared library: {e} ') from e
+    try:
+        c_lib.convertFromMLIRSparseTensorF64.restype = ctypes.c_void_p
+    except Exception as e:
+        raise ValueError(
+            "Missing function convertFromMLIRSparseTensorF64 from "
+            f"the C shared library: {e} "
+        ) from e
 
-  return c_lib
+    return c_lib
 
 
 def sparse_tensor_to_coo_tensor(support_lib, sparse, dtype):
-  """Converts a sparse tensor to COO-flavored format.
+    """Converts a sparse tensor to COO-flavored format.
 
-  Args:
-     support_lib: A string for the supporting C shared library.
-     sparse: A ctypes.pointer to the sparse tensor descriptor.
-     dtype: The numpy data type for the tensor elements.
+    Args:
+       support_lib: A string for the supporting C shared library.
+       sparse: A ctypes.pointer to the sparse tensor descriptor.
+       dtype: The numpy data type for the tensor elements.
 
-  Returns:
-    A tuple that contains the following values:
-    rank: An integer for the rank of the tensor.
-    nse: An integer for the number of non-zero values in the tensor.
-    shape: A 1D numpy array of integers, for the shape of the tensor.
-    values: A 1D numpy array, for the non-zero values in the tensor.
-    indices: A 2D numpy array of integers, representing the indices for the
-      non-zero values in the tensor.
+    Returns:
+      A tuple that contains the following values:
+      rank: An integer for the rank of the tensor.
+      nse: An integer for the number of non-zero values in the tensor.
+      shape: A 1D numpy array of integers, for the shape of the tensor.
+      values: A 1D numpy array, for the non-zero values in the tensor.
+      indices: A 2D numpy array of integers, representing the indices for the
+        non-zero values in the tensor.
 
-  Raises:
-    OSError: If there is any problem in loading the shared library.
-    ValueError:  If the shared library doesn't contain the needed routine.
-  """
-  c_lib = _get_c_shared_lib(support_lib)
+    Raises:
+      OSError: If there is any problem in loading the shared library.
+      ValueError:  If the shared library doesn't contain the needed routine.
+    """
+    c_lib = _get_c_shared_lib(support_lib)
 
-  rank = ctypes.c_ulonglong(0)
-  nse = ctypes.c_ulonglong(0)
-  shape = ctypes.POINTER(ctypes.c_ulonglong)()
-  values = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))()
-  indices = ctypes.POINTER(ctypes.c_ulonglong)()
-  c_lib.convertFromMLIRSparseTensorF64(sparse, ctypes.byref(rank),
-                                       ctypes.byref(nse), ctypes.byref(shape),
-                                       ctypes.byref(values),
-                                       ctypes.byref(indices))
-  # Convert the returned values to the corresponding numpy types.
-  shape = np.ctypeslib.as_array(shape, shape=[rank.value])
-  values = np.ctypeslib.as_array(values, shape=[nse.value])
-  indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
-  return rank, nse, shape, values, indices
+    rank = ctypes.c_ulonglong(0)
+    nse = ctypes.c_ulonglong(0)
+    shape = ctypes.POINTER(ctypes.c_ulonglong)()
+    values = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))()
+    indices = ctypes.POINTER(ctypes.c_ulonglong)()
+    c_lib.convertFromMLIRSparseTensorF64(
+        sparse,
+        ctypes.byref(rank),
+        ctypes.byref(nse),
+        ctypes.byref(shape),
+        ctypes.byref(values),
+        ctypes.byref(indices),
+    )
+    # Convert the returned values to the corresponding numpy types.
+    shape = np.ctypeslib.as_array(shape, shape=[rank.value])
+    values = np.ctypeslib.as_array(values, shape=[nse.value])
+    indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
+    return rank, nse, shape, values, indices
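
A minimal sketch of how the returned tuple is shaped, modeled on the elementwise-add test earlier in this commit; support_lib and mem_out are stand-ins borrowed from that test, not new API:

    # Sketch: unpack the COO result of a kernel run.
    rank, nse, shape, values, indices = sparse_tensor_to_coo_tensor(
        support_lib, mem_out[0], np.float64
    )
    assert shape.shape == (rank.value,)              # one extent per dimension
    assert values.shape == (nse.value,)              # one entry per stored nonzero
    assert indices.shape == (nse.value, rank.value)  # one coordinate tuple per nonzero
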
index 25004f9..d549a9a 100644 (file)
@@ -9,30 +9,31 @@ from mlir import ir
 from mlir import passmanager
 from typing import Sequence
 
+
 class SparseCompiler:
-  """Sparse compiler class for compiling and building MLIR modules."""
-
-  def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
-    pipeline = f'builtin.module(sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})'
-    self.pipeline = pipeline
-    self.opt_level = opt_level
-    self.shared_libs = shared_libs
-
-  def __call__(self, module: ir.Module):
-    """Convenience application method."""
-    self.compile(module)
-
-  def compile(self, module: ir.Module):
-    """Compiles the module by invoking the sparse copmiler pipeline."""
-    passmanager.PassManager.parse(self.pipeline).run(module.operation)
-
-  def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
-    """Wraps the module in a JIT execution engine."""
-    return execution_engine.ExecutionEngine(
-        module, opt_level=self.opt_level, shared_libs=self.shared_libs)
-
-  def compile_and_jit(self,
-                      module: ir.Module) -> execution_engine.ExecutionEngine:
-    """Compiles and jits the module."""
-    self.compile(module)
-    return self.jit(module)
+    """Sparse compiler class for compiling and building MLIR modules."""
+
+    def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
+        pipeline = f"builtin.module(sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})"
+        self.pipeline = pipeline
+        self.opt_level = opt_level
+        self.shared_libs = shared_libs
+
+    def __call__(self, module: ir.Module):
+        """Convenience application method."""
+        self.compile(module)
+
+    def compile(self, module: ir.Module):
+        """Compiles the module by invoking the sparse copmiler pipeline."""
+        passmanager.PassManager.parse(self.pipeline).run(module.operation)
+
+    def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+        """Wraps the module in a JIT execution engine."""
+        return execution_engine.ExecutionEngine(
+            module, opt_level=self.opt_level, shared_libs=self.shared_libs
+        )
+
+    def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+        """Compiles and jits the module."""
+        self.compile(module)
+        return self.jit(module)
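
The test drivers above all use this class in the same way; a minimal sketch of that pattern, where support_lib, kernel, and the mem_* pointers are assumed to be set up as in those tests:

    # Sketch: run the sparse-compiler pipeline, JIT the module, and invoke it.
    compiler = SparseCompiler(options="", opt_level=2, shared_libs=[support_lib])
    module = ir.Module.parse(kernel)           # requires an active ir.Context
    engine = compiler.compile_and_jit(module)  # compile() then jit()
    engine.invoke("main", mem_out, mem_a, mem_b)
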
index 7137d0f..f1bbcf4 100644 (file)
@@ -1,5 +1,5 @@
 # Disable ASAN's leak detection for python taco tests.
-config.environment['ASAN_OPTIONS'] = 'detect_leaks=0'
+config.environment["ASAN_OPTIONS"] = "detect_leaks=0"
 # Only run when python bindings are enabled.
 if not config.enable_bindings_python:
-  config.unsupported = True
+    config.unsupported = True
index 88b13ae..2d558f8 100644 (file)
@@ -46,10 +46,10 @@ A[i, j] = B[i, k, l] * D[l, j] * C[k, j]
 
 # Perform the MTTKRP computation and write the result to file.
 with tempfile.TemporaryDirectory() as test_dir:
-  golden_file = os.path.join(_SCRIPT_PATH, "data/gold_A.tns")
-  out_file = os.path.join(test_dir, "A.tns")
-  pt.write(out_file, A)
-  #
-  # CHECK: Compare result True
-  #
-  print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
+    golden_file = os.path.join(_SCRIPT_PATH, "data/gold_A.tns")
+    out_file = os.path.join(test_dir, "A.tns")
+    pt.write(out_file, A)
+    #
+    # CHECK: Compare result True
+    #
+    print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
index ba4ea9c..ef94ea9 100644 (file)
@@ -46,13 +46,13 @@ expected = """; extended FROSTT format
 
 # Force evaluation of the kernels by writing out X and Y.
 with tempfile.TemporaryDirectory() as test_dir:
-  x_file = os.path.join(test_dir, "X.tns")
-  y_file = os.path.join(test_dir, "Y.tns")
-  pt.write(x_file, X)
-  pt.write(y_file, Y)
-  #
-  # CHECK: Compare result True True
-  #
-  x_data = utils.file_as_string(x_file)
-  y_data = utils.file_as_string(y_file)
-  print(f"Compare result {x_data == expected} {y_data == expected}")
+    x_file = os.path.join(test_dir, "X.tns")
+    y_file = os.path.join(test_dir, "Y.tns")
+    pt.write(x_file, X)
+    pt.write(y_file, Y)
+    #
+    # CHECK: Compare result True True
+    #
+    x_data = utils.file_as_string(x_file)
+    y_data = utils.file_as_string(y_file)
+    print(f"Compare result {x_data == expected} {y_data == expected}")
index 10309cb..02bbbc0 100644 (file)
@@ -26,10 +26,10 @@ C[i, j] = A[i, k] * B[k, j]
 
 # Force evaluation of the kernel by writing out C.
 with tempfile.TemporaryDirectory() as test_dir:
-  golden_file = os.path.join(_SCRIPT_PATH, "data/gold_C.tns")
-  out_file = os.path.join(test_dir, "C.tns")
-  pt.write(out_file, C)
-  #
-  # CHECK: Compare result True
-  #
-  print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
+    golden_file = os.path.join(_SCRIPT_PATH, "data/gold_C.tns")
+    out_file = os.path.join(test_dir, "C.tns")
+    pt.write(out_file, C)
+    #
+    # CHECK: Compare result True
+    #
+    print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
index de150ea..2038a47 100644 (file)
@@ -47,10 +47,10 @@ y[i] = A[i, j] * x[j] + z[i]
 
 # Perform the SpMV computation and write the result to file
 with tempfile.TemporaryDirectory() as test_dir:
-  golden_file = os.path.join(_SCRIPT_PATH, "data/gold_y.tns")
-  out_file = os.path.join(test_dir, "y.tns")
-  pt.write(out_file, y)
-  #
-  # CHECK: Compare result True
-  #
-  print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
+    golden_file = os.path.join(_SCRIPT_PATH, "data/gold_y.tns")
+    out_file = os.path.join(test_dir, "y.tns")
+    pt.write(out_file, y)
+    #
+    # CHECK: Compare result True
+    #
+    print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
index c1e6c87..cd24e0d 100644 (file)
@@ -18,11 +18,10 @@ i, j, k, l, m = pt.get_index_vars(5)
 alpha = pt.tensor(42.0)
 
 # Set up some sparse tensors with different dim annotations and ordering.
-S = pt.tensor([8, 8, 8],
-              pt.format([pt.compressed, pt.dense, pt.compressed], [1, 0, 2]))
-X = pt.tensor([8, 8, 8],
-              pt.format([pt.compressed, pt.compressed, pt.compressed],
-                        [1, 0, 2]))
+S = pt.tensor([8, 8, 8], pt.format([pt.compressed, pt.dense, pt.compressed], [1, 0, 2]))
+X = pt.tensor(
+    [8, 8, 8], pt.format([pt.compressed, pt.compressed, pt.compressed], [1, 0, 2])
+)
 S.insert([0, 0, 0], 2.0)
 S.insert([1, 1, 1], 3.0)
 S.insert([4, 4, 4], 4.0)
@@ -32,16 +31,14 @@ X[i, j, k] = alpha[0] * S[i, j, k]
 
 # Set up tensors with a dense last dimension. This results in a full
 # enveloping storage of all last "rows" with one or more nonzeros.
-T = pt.tensor([1, 2, 3, 4, 5],
-              pt.format([
-                  pt.compressed, pt.compressed, pt.compressed, pt.compressed,
-                  pt.dense
-              ]))
-Y = pt.tensor([1, 2, 3, 4, 5],
-              pt.format([
-                  pt.compressed, pt.compressed, pt.compressed, pt.compressed,
-                  pt.dense
-              ]))
+T = pt.tensor(
+    [1, 2, 3, 4, 5],
+    pt.format([pt.compressed, pt.compressed, pt.compressed, pt.compressed, pt.dense]),
+)
+Y = pt.tensor(
+    [1, 2, 3, 4, 5],
+    pt.format([pt.compressed, pt.compressed, pt.compressed, pt.compressed, pt.dense]),
+)
 T.insert([0, 1, 2, 3, 4], -2.0)
 
 Y[i, j, k, l, m] = alpha[0] * T[i, j, k, l, m]
@@ -85,18 +82,18 @@ z_expected = """; extended FROSTT format
 
 # Force evaluation of the kernel by writing out X.
 with tempfile.TemporaryDirectory() as test_dir:
-  x_file = os.path.join(test_dir, 'X.tns')
-  pt.write(x_file, X)
-  y_file = os.path.join(test_dir, 'Y.tns')
-  pt.write(y_file, Y)
-  z_file = os.path.join(test_dir, 'Z.tns')
-  pt.write(z_file, Z)
-  #
-  # CHECK: Compare result True True True
-  #
-  x_data = utils.file_as_string(x_file)
-  y_data = utils.file_as_string(y_file)
-  z_data = utils.file_as_string(z_file)
-  print(
-      f'Compare result {x_data == x_expected} {y_data == y_expected} {z_data == z_expected}'
-  )
+    x_file = os.path.join(test_dir, "X.tns")
+    pt.write(x_file, X)
+    y_file = os.path.join(test_dir, "Y.tns")
+    pt.write(y_file, Y)
+    z_file = os.path.join(test_dir, "Z.tns")
+    pt.write(z_file, Z)
+    #
+    # CHECK: Compare result True True True
+    #
+    x_data = utils.file_as_string(x_file)
+    y_data = utils.file_as_string(y_file)
+    z_data = utils.file_as_string(z_file)
+    print(
+        f"Compare result {x_data == x_expected} {y_data == y_expected} {z_data == z_expected}"
+    )
index 60b91de..206ffa9 100644 (file)
@@ -12,7 +12,7 @@ compressed = pt.compressed
 
 i, j = pt.get_index_vars(2)
 A = pt.tensor([2, 3])
-S = pt.tensor(3) # S is a scalar tensor.
+S = pt.tensor(3)  # S is a scalar tensor.
 B = pt.tensor([2, 3], compressed)
 A.insert([0, 1], 10)
 A.insert([1, 2], 40)
@@ -26,11 +26,11 @@ passed += np.array_equal(values, [30.0, 120.0])
 
 # Sum all the values in A.
 S[0] = A[i, j]
-passed += (S.get_scalar_value() == 50.0)
+passed += S.get_scalar_value() == 50.0
 
 indices, values = S.get_coordinates_and_values()
-passed += (len(indices)==0)
-passed += (values == 50.0)
+passed += len(indices) == 0
+passed += values == 50.0
 
 # CHECK: Number of passed: 5
 print("Number of passed:", passed)
index 8fd545b..b0fed50 100644 (file)
@@ -12,20 +12,20 @@ compressed = pt.compressed
 passed = 0
 all_types = [pt.complex64, pt.complex128]
 for t in all_types:
-  i, j = pt.get_index_vars(2)
-  A = pt.tensor([2, 3], dtype=t)
-  B = pt.tensor([2, 3], dtype=t)
-  C = pt.tensor([2, 3], compressed, dtype=t)
-  A.insert([0, 1], 10 + 20j)
-  A.insert([1, 2], 40 + 0.5j)
-  B.insert([0, 0], 20)
-  B.insert([1, 2], 30 + 15j)
-  C[i, j] = A[i, j] + B[i, j]
+    i, j = pt.get_index_vars(2)
+    A = pt.tensor([2, 3], dtype=t)
+    B = pt.tensor([2, 3], dtype=t)
+    C = pt.tensor([2, 3], compressed, dtype=t)
+    A.insert([0, 1], 10 + 20j)
+    A.insert([1, 2], 40 + 0.5j)
+    B.insert([0, 0], 20)
+    B.insert([1, 2], 30 + 15j)
+    C[i, j] = A[i, j] + B[i, j]
 
-  indices, values = C.get_coordinates_and_values()
-  passed += isinstance(values[0], t.value)
-  passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
-  passed += np.allclose(values, [20, 10 + 20j, 70 + 15.5j])
+    indices, values = C.get_coordinates_and_values()
+    passed += isinstance(values[0], t.value)
+    passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
+    passed += np.allclose(values, [20, 10 + 20j, 70 + 15.5j])
 
 # CHECK: Number of passed: 6
 print("Number of passed:", passed)
index cec687f..4ba2836 100644 (file)
@@ -12,24 +12,22 @@ compressed = pt.compressed
 dense = pt.dense
 
 passed = 0
-all_types = [
-    pt.int8, pt.int16, pt.int32, pt.int64, pt.float16, pt.float32, pt.float64
-]
+all_types = [pt.int8, pt.int16, pt.int32, pt.int64, pt.float16, pt.float32, pt.float64]
 for t in all_types:
-  i, j = pt.get_index_vars(2)
-  A = pt.tensor([2, 3], dtype=t)
-  B = pt.tensor([2, 3], dtype=t)
-  C = pt.tensor([2, 3], compressed, dtype=t)
-  A.insert([0, 1], 10)
-  A.insert([1, 2], 40)
-  B.insert([0, 0], 20)
-  B.insert([1, 2], 30)
-  C[i, j] = A[i, j] + B[i, j]
+    i, j = pt.get_index_vars(2)
+    A = pt.tensor([2, 3], dtype=t)
+    B = pt.tensor([2, 3], dtype=t)
+    C = pt.tensor([2, 3], compressed, dtype=t)
+    A.insert([0, 1], 10)
+    A.insert([1, 2], 40)
+    B.insert([0, 0], 20)
+    B.insert([1, 2], 30)
+    C[i, j] = A[i, j] + B[i, j]
 
-  indices, values = C.get_coordinates_and_values()
-  passed += isinstance(values[0], t.value)
-  passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
-  passed += np.allclose(values, [20.0, 10.0, 70.0])
+    indices, values = C.get_coordinates_and_values()
+    passed += isinstance(values[0], t.value)
+    passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
+    passed += np.allclose(values, [20.0, 10.0, 70.0])
 
 # CHECK: Number of passed: 21
 print("Number of passed:", passed)
index a138678..78bce34 100644 (file)
@@ -10,8 +10,8 @@ from tools import mlir_pytaco_api as pt
 
 i, j = pt.get_index_vars(2)
 # Both tensors are true dense tensors.
-A = pt.from_array(np.full([2,3], 1, dtype=np.float64))
-B = pt.from_array(np.full([2,3], 2, dtype=np.float64))
+A = pt.from_array(np.full([2, 3], 1, dtype=np.float64))
+B = pt.from_array(np.full([2, 3], 2, dtype=np.float64))
 # Define the result tensor as a true dense tensor. The parameter is_dense=True
 # is an MLIR-PyTACO extension.
 C = pt.tensor([2, 3], dtype=pt.float64, is_dense=True)
index 44d28b0..b3194f7 100644 (file)
@@ -65,19 +65,20 @@ _SubtreeLeafChecker = Optional[Callable[..., bool]]
 
 
 class Type(enum.Enum):
-  """The data types supported by TACO.
+    """The data types supported by TACO.
 
-  We use numpy data types to implement the enum data types.
-  """
-  INT8 = np.int8
-  INT16 = np.int16
-  INT32 = np.int32
-  INT64 = np.int64
-  FLOAT16 = np.float16
-  FLOAT32 = np.float32
-  FLOAT64 = np.float64
-  COMPLEX64 = np.complex64
-  COMPLEX128 = np.complex128
+    We use numpy data types to implement the enum data types.
+    """
+
+    INT8 = np.int8
+    INT16 = np.int16
+    INT32 = np.int32
+    INT64 = np.int64
+    FLOAT16 = np.float16
+    FLOAT32 = np.float32
+    FLOAT64 = np.float64
+    COMPLEX64 = np.complex64
+    COMPLEX128 = np.complex128
 
 
 # All floating point type enums.
@@ -88,1732 +89,1810 @@ _INT_TYPES = (Type.INT8, Type.INT16, Type.INT32, Type.INT64)
 _COMPLEX_TYPES = (Type.COMPLEX64, Type.COMPLEX128)
 # Type alias for any numpy type used to implement the runtime support for the
 # enum data types.
-_AnyRuntimeType = Union[np.int8, np.int16, np.int32, np.int64, np.float16,
-                        np.float32, np.float64, np.complex64, np.complex128]
+_AnyRuntimeType = Union[
+    np.int8,
+    np.int16,
+    np.int32,
+    np.int64,
+    np.float16,
+    np.float32,
+    np.float64,
+    np.complex64,
+    np.complex128,
+]
 
 
 @dataclasses.dataclass(frozen=True)
 class DType:
-  """The data type class.
+    """The data type class.
 
-  We support the TACO API dtype class with an alias of this class.
+    We support the TACO API dtype class with an alias of this class.
 
-  The following methods are defined by the TACO API:
-    is_float: Returns whether the data type represents a floating point value.
-    is_int:   Returns whether the data type represents an integral value.
+    The following methods are defined by the TACO API:
+      is_float: Returns whether the data type represents a floating point value.
+      is_int:   Returns whether the data type represents an integral value.
 
-  Attributes:
-    kind: A Type enum representing the data type.
-    value: The numpy data type for the TACO data type.
-  """
-  kind: Type = Type.FLOAT32
+    Attributes:
+      kind: A Type enum representing the data type.
+      value: The numpy data type for the TACO data type.
+    """
 
-  def is_float(self) -> bool:
-    """Returns whether the data type represents a floating point value."""
-    return self.kind in _FLOAT_TYPES
+    kind: Type = Type.FLOAT32
 
-  def is_int(self) -> bool:
-    """Returns whether the data type represents an integral value."""
-    return self.kind in _INT_TYPES
+    def is_float(self) -> bool:
+        """Returns whether the data type represents a floating point value."""
+        return self.kind in _FLOAT_TYPES
 
-  def is_complex(self) -> bool:
-    """Returns whether the data type represents a complex value."""
-    return self.kind in _COMPLEX_TYPES
+    def is_int(self) -> bool:
+        """Returns whether the data type represents an integral value."""
+        return self.kind in _INT_TYPES
 
-  @property
-  def value(self) -> _AnyRuntimeType:
-    """Returns the numpy dtype for the data type."""
-    return self.kind.value
+    def is_complex(self) -> bool:
+        """Returns whether the data type represents a complex value."""
+        return self.kind in _COMPLEX_TYPES
+
+    @property
+    def value(self) -> _AnyRuntimeType:
+        """Returns the numpy dtype for the data type."""
+        return self.kind.value
 
 
 def _dtype_to_mlir_str(dtype: DType) -> str:
-  """Returns the MLIR string for the given dtype."""
-  dtype_to_str = {
-      Type.INT16: "i8",
-      Type.INT16: "i16",
-      Type.INT32: "i32",
-      Type.INT64: "i64",
-      Type.FLOAT16: "f16",
-      Type.FLOAT32: "f32",
-      Type.FLOAT64: "f64",
-      Type.COMPLEX64: "complex<f32>",
-      Type.COMPLEX128: "complex<f64>"
-  }
-  return dtype_to_str[dtype.kind]
+    """Returns the MLIR string for the given dtype."""
+    dtype_to_str = {
+        Type.INT16: "i8",
+        Type.INT16: "i16",
+        Type.INT32: "i32",
+        Type.INT64: "i64",
+        Type.FLOAT16: "f16",
+        Type.FLOAT32: "f32",
+        Type.FLOAT64: "f64",
+        Type.COMPLEX64: "complex<f32>",
+        Type.COMPLEX128: "complex<f64>",
+    }
+    return dtype_to_str[dtype.kind]
 
 
 def _nptype_to_taco_type(ty: np.dtype) -> DType:
-  """Returns the TACO type for the given numpy type."""
-  nptype_to_dtype = {
-      np.int8: Type.INT8,
-      np.int16: Type.INT16,
-      np.int32: Type.INT32,
-      np.int64: Type.INT64,
-      np.float16: Type.FLOAT16,
-      np.float32: Type.FLOAT32,
-      np.float64: Type.FLOAT64,
-      np.complex64: Type.COMPLEX64,
-      np.complex128: Type.COMPLEX128
-  }
-  return DType(nptype_to_dtype[ty])
+    """Returns the TACO type for the given numpy type."""
+    nptype_to_dtype = {
+        np.int8: Type.INT8,
+        np.int16: Type.INT16,
+        np.int32: Type.INT32,
+        np.int64: Type.INT64,
+        np.float16: Type.FLOAT16,
+        np.float32: Type.FLOAT32,
+        np.float64: Type.FLOAT64,
+        np.complex64: Type.COMPLEX64,
+        np.complex128: Type.COMPLEX128,
+    }
+    return DType(nptype_to_dtype[ty])
 
 
 def _mlir_type_from_taco_type(dtype: DType) -> ir.Type:
-  """Returns the MLIR type corresponding to the given TACO type."""
-  dtype_to_irtype = {
-      Type.INT8: ir.IntegerType.get_signless(8),
-      Type.INT16: ir.IntegerType.get_signless(16),
-      Type.INT32: ir.IntegerType.get_signless(32),
-      Type.INT64: ir.IntegerType.get_signless(64),
-      Type.FLOAT16: ir.F16Type.get(),
-      Type.FLOAT32: ir.F32Type.get(),
-      Type.FLOAT64: ir.F64Type.get(),
-      Type.COMPLEX64: ir.ComplexType.get(ir.F32Type.get()),
-      Type.COMPLEX128: ir.ComplexType.get(ir.F64Type.get())
-  }
-  return dtype_to_irtype[dtype.kind]
+    """Returns the MLIR type corresponding to the given TACO type."""
+    dtype_to_irtype = {
+        Type.INT8: ir.IntegerType.get_signless(8),
+        Type.INT16: ir.IntegerType.get_signless(16),
+        Type.INT32: ir.IntegerType.get_signless(32),
+        Type.INT64: ir.IntegerType.get_signless(64),
+        Type.FLOAT16: ir.F16Type.get(),
+        Type.FLOAT32: ir.F32Type.get(),
+        Type.FLOAT64: ir.F64Type.get(),
+        Type.COMPLEX64: ir.ComplexType.get(ir.F32Type.get()),
+        Type.COMPLEX128: ir.ComplexType.get(ir.F64Type.get()),
+    }
+    return dtype_to_irtype[dtype.kind]
+
 
 def _ctype_pointer_from_array(array: np.ndarray) -> ctypes.pointer:
-  """Returns the ctype pointer for the given numpy array."""
-  return ctypes.pointer(
-      ctypes.pointer(runtime.get_ranked_memref_descriptor(array)))
+    """Returns the ctype pointer for the given numpy array."""
+    return ctypes.pointer(ctypes.pointer(runtime.get_ranked_memref_descriptor(array)))
 
 
 class ModeFormat(enum.Enum):
-  """The tensor dimension storage format class.
+    """The tensor dimension storage format class.
 
-  We support the TACO API mode_format class with an alias of this class.
+    We support the TACO API mode_format class with an alias of this class.
+
+    In TACO, a tensor dimension is called a mode and the storage format for a
+    tensor dimension is called a mode format.
+    """
 
-  In TACO, a tensor dimension is called a mode and the storage format for a
-  tensor dimension is called a mode format.
-  """
-  DENSE = sparse_tensor.DimLevelType.dense
-  COMPRESSED = sparse_tensor.DimLevelType.compressed
+    DENSE = sparse_tensor.DimLevelType.dense
+    COMPRESSED = sparse_tensor.DimLevelType.compressed
 
 
-def _mode_format_operation(a: ModeFormat, b: ModeFormat,
-                           op: _LogicalOp) -> ModeFormat:
-  """Implements the given operator on ModeFormat."""
-  return (ModeFormat.COMPRESSED
-          if op(a == ModeFormat.COMPRESSED, b == ModeFormat.COMPRESSED) else
-          ModeFormat.DENSE)
+def _mode_format_operation(a: ModeFormat, b: ModeFormat, op: _LogicalOp) -> ModeFormat:
+    """Implements the given operator on ModeFormat."""
+    return (
+        ModeFormat.COMPRESSED
+        if op(a == ModeFormat.COMPRESSED, b == ModeFormat.COMPRESSED)
+        else ModeFormat.DENSE
+    )
 
 
 def _mode_format_estimator(op: _BinaryOp) -> _ModeFormatOp:
-  """Produces a ModeFormat operator for the given binary operator.
+    """Produces a ModeFormat operator for the given binary operator.
 
-  The ModeFormat operator is used as a heuristic to derive the destination
-  dimension sparsity from the source dimension sparsity. In particular, if the
-  binary operator produces a disjunction of the zero values from its source
-  operands, such as the MUL operator, we return a ModeFormat operator that
-  uses operator.or_. That is, we estimate that a dimension for the MUL
-  operation result to be sparse if either of its source operands is sparse.
+    The ModeFormat operator is used as a heuristic to derive the destination
+    dimension sparsity from the source dimension sparsity. In particular, if the
+    binary operator produces a disjunction of the zero values from its source
+    operands, such as the MUL operator, we return a ModeFormat operator that
+    uses operator.or_. That is, we estimate a dimension of the MUL
+    operation result to be sparse if either of its source operands is sparse.
 
-  On the other hand, if the binary operator produces a conjunction of the
-  zero values from its source operands, such as the ADD operator, we return
-  a ModeFormat operator that uses operator.and_. In this case, we estimate
-  that a dimension for the ADD operation result to be sparse if both of its
-  source operands are sparse.
+    On the other hand, if the binary operator produces a conjunction of the
+    zero values from its source operands, such as the ADD operator, we return
+    a ModeFormat operator that uses operator.and_. In this case, we estimate
+    that a dimension for the ADD operation result to be sparse if both of its
+    source operands are sparse.
 
-  Args:
-    op: A _BinaryOp object representing a supporting operator on tensors.
+    Args:
+      op: A _BinaryOp object representing a supporting operator on tensors.
 
-  Returns:
-    A ModeFormatOp for estimating the destination dimension sparsity from
-    the source dimension sparsity.
-  """
-  conjunction = functools.partial(_mode_format_operation, op=operator.and_)
-  disjunction = functools.partial(_mode_format_operation, op=operator.or_)
-  return conjunction if op(0, 1) != 0 else disjunction
+    Returns:
+      A ModeFormatOp for estimating the destination dimension sparsity from
+      the source dimension sparsity.
+    """
+    conjunction = functools.partial(_mode_format_operation, op=operator.and_)
+    disjunction = functools.partial(_mode_format_operation, op=operator.or_)
+    return conjunction if op(0, 1) != 0 else disjunction
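
As an illustrative sketch of the heuristic above (using only operator and the helpers defined in this file), the estimator picks the disjunction for MUL and the conjunction for ADD:

    import operator

    # MUL keeps a zero whenever either operand is zero, so the estimator uses
    # operator.or_; ADD only keeps a zero when both operands are zero, so it
    # uses operator.and_.
    mul_est = _mode_format_estimator(operator.mul)
    add_est = _mode_format_estimator(operator.add)
    assert mul_est(ModeFormat.DENSE, ModeFormat.COMPRESSED) == ModeFormat.COMPRESSED
    assert add_est(ModeFormat.DENSE, ModeFormat.COMPRESSED) == ModeFormat.DENSE
    assert add_est(ModeFormat.COMPRESSED, ModeFormat.COMPRESSED) == ModeFormat.COMPRESSED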
 
 
 def _all_instance_of(collection: Iterable, cls: Any) -> bool:
-  """Returns true if all elements of the iterable is an instance of cls."""
-  return all(isinstance(e, cls) for e in collection)
+    """Returns true if all elements of the iterable is an instance of cls."""
+    return all(isinstance(e, cls) for e in collection)
 
 
 def _identity_ordering(rank: int) -> List[int]:
-  """Returns the identity ordering for tensor of given rank."""
-  return list(range(rank))
+    """Returns the identity ordering for tensor of given rank."""
+    return list(range(rank))
 
 
 @dataclasses.dataclass(frozen=True)
 class ModeOrdering:
-  """The tensor dimension ordering class.
-
-  We support the TACO API mode_ordering class with an alias of this class.
+    """The tensor dimension ordering class.
 
-  Attributes:
-    ordering: A list of integers representing the ordering of the tensor
-      dimensions.
-  """
-  ordering: List[int]
+    We support the TACO API mode_ordering class with an alias of this class.
 
-  def __post_init__(self) -> None:
-    """Verifies the value in ordering.
-
-    Raises:
-       ValueError: If ordering is not a list of integers.
+    Attributes:
+      ordering: A list of integers representing the ordering of the tensor
+        dimensions.
     """
-    if (not isinstance(self.ordering, list) or
-        not _all_instance_of(self.ordering, int)):
-      raise ValueError("Ordering must be a list of integers: "
-                       f"{self.ordering}")
-    # Check that ordering is a permutation of the dimension numbers.
-    if sorted(self.ordering) != _identity_ordering(self.rank()):
-      raise ValueError(f"Invalid ordering: {self.ordering} != "
-                       f"permutation{_identity_ordering(self.rank())}.")
-
-  def rank(self) -> int:
-    """Returns the number of dimensions represented by the ordering."""
-    return len(self.ordering)
 
+    ordering: List[int]
 
-@dataclasses.dataclass(frozen=True)
-class ModeFormatPack:
-  """The tensor dimension format class.
+    def __post_init__(self) -> None:
+        """Verifies the value in ordering.
 
-  We support the TACO API mode_format_pack class with an alias of this class.
+        Raises:
+           ValueError: If ordering is not a list of integers.
+        """
+        if not isinstance(self.ordering, list) or not _all_instance_of(
+            self.ordering, int
+        ):
+            raise ValueError("Ordering must be a list of integers: " f"{self.ordering}")
+        # Check that ordering is a permutation of the dimension numbers.
+        if sorted(self.ordering) != _identity_ordering(self.rank()):
+            raise ValueError(
+                f"Invalid ordering: {self.ordering} != "
+                f"permutation{_identity_ordering(self.rank())}."
+            )
 
-  The storage format of a tensor contains one mode_format for each tensor
-  dimension.
+    def rank(self) -> int:
+        """Returns the number of dimensions represented by the ordering."""
+        return len(self.ordering)
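
A minimal sketch of the validation above: the ordering must be a permutation of the dimension numbers 0..rank-1.

    ModeOrdering([1, 0, 2])   # accepted: a permutation of [0, 1, 2]
    try:
        ModeOrdering([0, 2])  # rejected: not a permutation of [0, 1]
    except ValueError:
        pass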
 
-  Attributes:
-    formats: A list of ModeFormat representing the storage format for each of
-      the tensor dimension.
-  """
-  formats: List[ModeFormat]
 
-  def __post_init__(self) -> None:
-    """Verifies the value in formats.
-
-    Raises:
-       ValueError: If formats is not a list of ModeFormats.
-    """
-    if (not isinstance(self.formats, list) or
-        not _all_instance_of(self.formats, ModeFormat)):
-      raise ValueError("Formats must be a list of ModeFormat: "
-                       f"{self.formats}")
-
-  def rank(self) -> int:
-    """Returns the number of dimensions represented by the format pack."""
-    return len(self.formats)
+@dataclasses.dataclass(frozen=True)
+class ModeFormatPack:
+    """The tensor dimension format class.
 
+    We support the TACO API mode_format_pack class with an alias of this class.
 
-@dataclasses.dataclass
-class Format:
-  """The tensor format class defined by the TACO API.
-
-  Attributes:
-    format_pack: A ModeFormatPack representing the storage format for the tensor
-      dimensions.
-    ordering: A ModeOrdering representing the tensor dimension ordering in the
-      storage.
-  """
-  format_pack: ModeFormatPack
-  ordering: Optional[ModeOrdering] = None
-
-  def __post_init__(self) -> None:
-    """Verifies and fixes up the values in format_pack and ordering.
-
-    Verifies and fixes up the values in format_pack and ordering to supports the
-    initializer syntax defined by the TACO API. If format_pack is a list of
-    ModeFormat, replaces it with ModeFormatPack constructed from the list. If
-    ordering is not provided, set ordering to the natural ordering for the rank
-    corresponding to format_pack.
+    The storage format of a tensor contains one mode_format for each tensor
+    dimension.
 
-    Raises:
-       ValueError: If format_pack is not an instance of ModeFormatPack or if
-         ordering is not an instance of ModeOrdering.
+    Attributes:
+      formats: A list of ModeFormat representing the storage format for each of
+        the tensor dimension.
     """
-    if isinstance(self.format_pack, list):
-      if not _all_instance_of(self.format_pack, ModeFormat):
-        raise ValueError(f"Expected a list of ModeFormat: {self.format_pack}")
-      self.format_pack = ModeFormatPack(self.format_pack)
-    if not isinstance(self.format_pack, ModeFormatPack):
-      raise ValueError(f"Expected ModeFormatpack: {self.format_pack}")
-
-    if self.ordering is None:
-      self.ordering = ModeOrdering(list(range(self.rank())))
-    if isinstance(self.ordering, list):
-      if not _all_instance_of(self.ordering, int):
-        raise ValueError(f"Expected a list of integer: {self.ordering}")
-      self.ordering = ModeOrdering(self.ordering)
-    if not isinstance(self.ordering, ModeOrdering):
-      raise ValueError(f"Expected ModeOrdering: {self.ordering}")
-
-    if self.format_pack.rank() != self.ordering.rank():
-      raise ValueError("Inconsistent ModeFormatPack and ModeOrdering: "
-                       f"len({self.format_pack}) != "
-                       f"len({self.ordering})")
-
-  def rank(self) -> int:
-    """Returns the number of dimensions represented by the format."""
-    return self.format_pack.rank()
-
-  def get_permutation_and_sparsity(self) -> Tuple[np.ndarray, np.ndarray]:
-    """Constructs the numpy arrays for the permutation and sparsity."""
-    perm = np.array(self.ordering.ordering, dtype=np.ulonglong)
-    a = [f.value for f in self.format_pack.formats]
-    sparse = np.array(a, dtype=np.uint8)
-    return (perm, sparse)
-
-  def mlir_tensor_attr(self) -> Optional[sparse_tensor.EncodingAttr]:
-    """Constructs the MLIR attributes for the tensor format."""
-    order = (
-        range(self.rank()) if
-        (self.ordering is None) else self.ordering.ordering)
-    mlir_storage_format = [f.value for f in self.format_pack.formats]
-    return sparse_tensor.EncodingAttr.get(mlir_storage_format,
-                                          ir.AffineMap.get_permutation(order),
-                                          None, _POS_WIDTH, _CRD_WIDTH)
-
-
-def _make_format(formats: List[ModeFormat],
-                 ordering: Optional[List[int]] = None) -> Format:
-  """Constructs a format from a list of ModeFormat and an optional ordering.
-
-  Args:
-    formats: A list of ModeFormat, one for each dimension of a tensor.
-    ordering: An optional list of integer, for the ordering of the tensor
-      dimensions. When an ordering is not given, the identity ordering is used.
-
-  Returns:
-    A tensor format object.
-
-  Raises:
-    ValueError: If formats is not a list of ModeFormat or the length of formats
-      is not consistent with the len of ordering.
-  """
-  ordering = ordering or _identity_ordering(len(formats))
-  return Format(ModeFormatPack(formats), ModeOrdering(ordering))
 
+    formats: List[ModeFormat]
 
-class IndexExpr(abc.ABC):
-  """The index notation base class.
+    def __post_init__(self) -> None:
+        """Verifies the value in formats.
 
-  We support the TACO API index_expression class with an alias of this class.
-  """
+        Raises:
+           ValueError: If formats is not a list of ModeFormats.
+        """
+        if not isinstance(self.formats, list) or not _all_instance_of(
+            self.formats, ModeFormat
+        ):
+            raise ValueError("Formats must be a list of ModeFormat: " f"{self.formats}")
 
-  def _verify_operand_and_build_expr(self, rhs, op: _BinaryOp) -> "_BinaryExpr":
-    """Verifies the RHS operand and returns a binary expression.
+    def rank(self) -> int:
+        """Returns the number of dimensions represented by the format pack."""
+        return len(self.formats)
 
-    Args:
-      rhs: The RHS of the binary operation, which could be any Python object
-        from user inputs.
-      op: A _BinaryOp object representing the binary operator.
 
-    Raises:
-      ValueError: If rhs is not an IndexExpr.
-    """
-    if not isinstance(rhs, IndexExpr):
-      raise ValueError(f"Expected IndexExpr: {rhs}")
-    return _BinaryExpr(op, self, rhs)
-
-  def _build_unary_expr(self, op: _UnaryOp) -> "_UnaryExpr":
-    """Build a unary expression.
+@dataclasses.dataclass
+class Format:
+    """The tensor format class defined by the TACO API.
 
-    Args:
-      op: A _UnaryOp object representing the unary operation.
+    Attributes:
+      format_pack: A ModeFormatPack representing the storage format for the tensor
+        dimensions.
+      ordering: A ModeOrdering representing the tensor dimension ordering in the
+        storage.
     """
-    return _UnaryExpr(op, self)
 
-  def __add__(self, rhs) -> "_BinaryExpr":
-    """Defines the operator +.
+    format_pack: ModeFormatPack
+    ordering: Optional[ModeOrdering] = None
+
+    def __post_init__(self) -> None:
+        """Verifies and fixes up the values in format_pack and ordering.
+
+        Verifies and fixes up the values in format_pack and ordering to support the
+        initializer syntax defined by the TACO API. If format_pack is a list of
+        ModeFormat, replaces it with a ModeFormatPack constructed from the list. If
+        ordering is not provided, sets ordering to the natural ordering for the rank
+        corresponding to format_pack.
+
+        Raises:
+           ValueError: If format_pack is not an instance of ModeFormatPack or if
+             ordering is not an instance of ModeOrdering.
+        """
+        if isinstance(self.format_pack, list):
+            if not _all_instance_of(self.format_pack, ModeFormat):
+                raise ValueError(f"Expected a list of ModeFormat: {self.format_pack}")
+            self.format_pack = ModeFormatPack(self.format_pack)
+        if not isinstance(self.format_pack, ModeFormatPack):
+            raise ValueError(f"Expected ModeFormatpack: {self.format_pack}")
+
+        if self.ordering is None:
+            self.ordering = ModeOrdering(list(range(self.rank())))
+        if isinstance(self.ordering, list):
+            if not _all_instance_of(self.ordering, int):
+                raise ValueError(f"Expected a list of integer: {self.ordering}")
+            self.ordering = ModeOrdering(self.ordering)
+        if not isinstance(self.ordering, ModeOrdering):
+            raise ValueError(f"Expected ModeOrdering: {self.ordering}")
+
+        if self.format_pack.rank() != self.ordering.rank():
+            raise ValueError(
+                "Inconsistent ModeFormatPack and ModeOrdering: "
+                f"len({self.format_pack}) != "
+                f"len({self.ordering})"
+            )
+
+    def rank(self) -> int:
+        """Returns the number of dimensions represented by the format."""
+        return self.format_pack.rank()
+
+    def get_permutation_and_sparsity(self) -> Tuple[np.ndarray, np.ndarray]:
+        """Constructs the numpy arrays for the permutation and sparsity."""
+        perm = np.array(self.ordering.ordering, dtype=np.ulonglong)
+        a = [f.value for f in self.format_pack.formats]
+        sparse = np.array(a, dtype=np.uint8)
+        return (perm, sparse)
+
+    def mlir_tensor_attr(self) -> Optional[sparse_tensor.EncodingAttr]:
+        """Constructs the MLIR attributes for the tensor format."""
+        order = (
+            range(self.rank()) if (self.ordering is None) else self.ordering.ordering
+        )
+        mlir_storage_format = [f.value for f in self.format_pack.formats]
+        return sparse_tensor.EncodingAttr.get(
+            mlir_storage_format,
+            ir.AffineMap.get_permutation(order),
+            None,
+            _POS_WIDTH,
+            _CRD_WIDTH,
+        )
+
+
+def _make_format(
+    formats: List[ModeFormat], ordering: Optional[List[int]] = None
+) -> Format:
+    """Constructs a format from a list of ModeFormat and an optional ordering.
 
     Args:
-      rhs: The value being added, which could be any Python object from user
-        inputs.
+      formats: A list of ModeFormat, one for each dimension of a tensor.
+      ordering: An optional list of integers for the ordering of the tensor
+        dimensions. When an ordering is not given, the identity ordering is used.
 
     Returns:
-      A _BinaryExpr object representing the operation.
+      A tensor format object.
 
     Raises:
-      ValueError: If rhs is not an IndexExpr.
+      ValueError: If formats is not a list of ModeFormat or the length of formats
+        is not consistent with the length of ordering.
     """
-    return self._verify_operand_and_build_expr(rhs, operator.add)
+    ordering = ordering or _identity_ordering(len(formats))
+    return Format(ModeFormatPack(formats), ModeOrdering(ordering))
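
An illustrative sketch, using only names defined above: a CSR-like rank-2 format (dense rows, compressed columns) and a variant with an explicit dimension ordering.

    csr = _make_format([ModeFormat.DENSE, ModeFormat.COMPRESSED])
    csc = _make_format([ModeFormat.DENSE, ModeFormat.COMPRESSED], ordering=[1, 0])
    assert csr.rank() == csc.rank() == 2
    # The dimension permutation and per-dimension sparsity as numpy arrays.
    perm, sparsity = csc.get_permutation_and_sparsity()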
 
-  def __mul__(self, rhs) -> "_BinaryExpr":
-    """Defines the operator *.
-
-    Args:
-      rhs: The value being multiplied, which could be any Python object from
-        user inputs.
 
-    Returns:
-      A _BinaryExpr object representing the operation.
+class IndexExpr(abc.ABC):
+    """The index notation base class.
 
-    Raises:
-      ValueError: If rhs is not an IndexExpr.
+    We support the TACO API index_expression class with an alias of this class.
     """
-    return self._verify_operand_and_build_expr(rhs, operator.mul)
 
-  def __abs__(self) -> "_UnaryExpr":
-    """Defines the operator abs.
-
-    Returns:
-      A _UnaryExpr object representing the operation.
-    """
-    return self._build_unary_expr(operator.abs)
+    def _verify_operand_and_build_expr(self, rhs, op: _BinaryOp) -> "_BinaryExpr":
+        """Verifies the RHS operand and returns a binary expression.
+
+        Args:
+          rhs: The RHS of the binary operation, which could be any Python object
+            from user inputs.
+          op: A _BinaryOp object representing the binary operator.
+
+        Raises:
+          ValueError: If rhs is not an IndexExpr.
+        """
+        if not isinstance(rhs, IndexExpr):
+            raise ValueError(f"Expected IndexExpr: {rhs}")
+        return _BinaryExpr(op, self, rhs)
+
+    def _build_unary_expr(self, op: _UnaryOp) -> "_UnaryExpr":
+        """Build a unary expression.
+
+        Args:
+          op: A _UnaryOp object representing the unary operation.
+        """
+        return _UnaryExpr(op, self)
+
+    def __add__(self, rhs) -> "_BinaryExpr":
+        """Defines the operator +.
+
+        Args:
+          rhs: The value being added, which could be any Python object from user
+            inputs.
+
+        Returns:
+          A _BinaryExpr object representing the operation.
+
+        Raises:
+          ValueError: If rhs is not an IndexExpr.
+        """
+        return self._verify_operand_and_build_expr(rhs, operator.add)
+
+    def __mul__(self, rhs) -> "_BinaryExpr":
+        """Defines the operator *.
+
+        Args:
+          rhs: The value being multiplied, which could be any Python object from
+            user inputs.
+
+        Returns:
+          A _BinaryExpr object representing the operation.
+
+        Raises:
+          ValueError: If rhs is not an IndexExpr.
+        """
+        return self._verify_operand_and_build_expr(rhs, operator.mul)
+
+    def __abs__(self) -> "_UnaryExpr":
+        """Defines the operator abs.
+
+        Returns:
+          A _UnaryExpr object representing the operation.
+        """
+        return self._build_unary_expr(operator.abs)
+
+    def __neg__(self) -> "_UnaryExpr":
+        """Defines the operator neg.
+
+        Returns:
+          A _UnaryExpr object representing the operation.
+        """
+        return self._build_unary_expr(operator.neg)
+
+    def __sub__(self, rhs) -> "_BinaryExpr":
+        """Defines the operator -.
+
+        Args:
+          rhs: The value being subtracted, which could be any Python object from
+            user inputs.
+
+        Returns:
+          A _BinaryExpr object representing the operation.
+
+        Raises:
+          ValueError: If rhs is not an IndexExpr.
+        """
+        return self._verify_operand_and_build_expr(rhs, operator.sub)
+
+    @abc.abstractmethod
+    def _visit(
+        self, func: _ExprVisitor, args, *, leaf_checker: _SubtreeLeafChecker = None
+    ) -> None:
+        """A post-order visitor.
+
+        Args:
+          func: A callable applied to each node in the expression tree.
+          args: The variable-length arguments passed to the callable. These
+            arguments are grouped as an iterable and will be unpacked before being
+            passed to the callable. This is to enable the keyword-only argument
+            syntax after this argument.
+          leaf_checker: A callable object to identify nodes that should be treated
+            as leaf nodes to support partial tree visiting.
+        """
+        pass
+
+    @abc.abstractmethod
+    def _emit_expression(
+        self,
+        expr_to_opnd: Dict["IndexExpr", lang.OperandDef],
+        expr_to_info: _ExprInfoDict,
+    ) -> lang.ScalarExpression:
+        """Emits MLIR for the expression tree.
+
+        Args:
+          expr_to_opnd: A dictionary for looking up structured op input operands for
+            the input nodes of the structured op.
+          expr_to_info: A dictionary for looking up code generation information for
+            expressions.
+
+        Returns:
+          A linalg dialect ScalarExpression for the expression.
+        """
+        pass
+
+    @abc.abstractmethod
+    def dtype(self) -> DType:
+        """Returns the data type for the result of the expression."""
+        pass
+
+    def _emit_structured_op(self, expr_to_info: _ExprInfoDict) -> None:
+        """Emits a structured op in the linalg dialect for the expression tree.
+
+        We define a DefinedOpCallable in the domain-specific language for the linalg
+        dialect and execute the callable to generate the structured op. Self is the
+        root of the expression tree for the structured op.
+
+        Args:
+          expr_to_info: A dictionary for looking up code generation information for
+            expressions.
+        """
+        op_info = expr_to_info[self].structop_info
+        op_name = op_info.dst_name
+        op_def = lang.LinalgOpDef(name=op_name)
+        op_callable = lang.DefinedOpCallable(op_name, op_def)
+
+        # Collect the input expression nodes for the structured op.
+        expr_inputs = []
+        self._visit(
+            _gather_structured_op_input,
+            (self, expr_to_info, expr_inputs),
+            leaf_checker=_is_structured_op_leaf,
+        )
+
+        # Create a linalg structured op operand for each input expression node and
+        # build a dictionary for looking up the information.
+        expr_to_input_opnd = {
+            e: _emit_structured_op_input(e, expr_to_info, op_def) for e in expr_inputs
+        }
+
+        # Emit the expression tree, which produces the value assigned to the
+        # destination tensor.
+        value = self._emit_expression(expr_to_input_opnd, expr_to_info)
+        # Emit the structured op representation for the destination tensor.
+        dst_opnd = _emit_operand(
+            op_def,
+            op_info.dst_indices,
+            op_info.dst_name,
+            lang.OperandKind.OUTPUT_TENSOR,
+        )
+        dst_dim_syms = _mlir_dimensions_from_index_vars(op_info.dst_indices)
+        dst_use = lang.TensorUse(dst_opnd, dst_dim_syms)
+
+        expr_info = expr_to_info[self]
+        # If the structured op reduces some indices, explicitly represent the
+        # reduction. This is done by generating a ReduceFn for the dimensions being
+        # reduced in the linalg dialect and calling the function with the value
+        # being reduced. We only support add reduction currently.
+        if expr_info.reduce_indices:
+            reduce_dims = _mlir_dimensions_from_index_vars(expr_info.reduce_indices)
+            value = lang.ReduceFn.add[reduce_dims](value)
+
+        # Emit the assignment as a comprehension in the linalg dialect.
+        comp = lang.Comprehension((dst_use, value))
+        op_def.comprehensions.append(comp)
+
+        # The structured op in the linalg dialect requires an explicit
+        # initialization for the destination tensor. Emit MLIR to initialize the
+        # destination tensor.
+        init = op_info.emit_tensor_init()
+
+        # Collect MLIR values for the linalg input operands, with the assumption
+        # that dictionary preserves the insertion order.
+        args = [
+            expr_to_info[expr].mlir_value for expr, opnd in expr_to_input_opnd.items()
+        ]
+        # Execute the DefinedOpCallable object for the linalg dialect operation to
+        # emit MLIR for the linalg structured op.
+        expr_info.mlir_value = op_callable(*args, outs=[init])
+
+    def _identify_structured_ops(
+        self,
+        expr_to_info: _ExprInfoDict,
+        dst: "Tensor",
+        dst_indices: Tuple["IndexVar", ...],
+    ) -> List["IndexExpr"]:
+        """Returns expression nodes for the roots of the identified structured ops.
+
+        A structured op in the linalg dialect only supports reductions performed on
+        the whole expression. If the expression tree contains reductions that are
+        performed on part of the expression tree, the expression tree needs to be
+        implemented with multiple structured ops. This routine identifies all the
+        expression nodes that contain reductions as the roots of structured ops in
+        the linalg dialect.
+
+        Args:
+          expr_to_info: A dictionary for looking up code generation information for
+            expressions.
+          dst: A destination Tensor that accepts the value of the expression tree.
+          dst_indices: The indices used by the destination index expression.
+
+        Returns:
+          An ordered list of IndexExpr for the root expressions of the structured
+          ops, where child expressions go before parent expressions that use their
+          results.
+        """
+        reduce_indices = tuple(set(expr_to_info[self].src_indices) - set(dst_indices))
+        for reduce_index in reduce_indices:
+            _mark_structured_op_root(self, reduce_index, expr_to_info)
+
+        self._visit(_accumulate_reduce_indices, (expr_to_info,))
+        structop_roots = []
+        self._visit(_gather_structured_op, (expr_to_info, structop_roots))
+
+        # Handle the root of the top level expression.
+        if not structop_roots or structop_roots[-1] != self:
+            # The top level expression is not a reduction. Add the top level
+            # expression as a structured op root.
+            structop_roots.append(self)
+
+        # Use user specified information for the destination tensor to build an
+        # _StructOpInfo for the top level expression.
+        expr_to_info[self].structop_info = _StructOpInfo(
+            dst_indices, tuple(dst.shape), dst.dtype, dst.name, dst.format
+        )
+
+        return structop_roots
+
+    def _validate_and_collect_expr_info(
+        self,
+        dst: "Tensor",
+        dst_indices: Tuple["IndexVar", ...],
+    ) -> _ExprInfoDict:
+        """Propagates expression information for validation.
+
+        Propagates the indices used by child expression nodes to parent expression
+        nodes. Also collects and validates the sizes for the dimensions
+        corresponding to the indices.
+
+        Args:
+          dst: A destination Tensor that accepts the value of the expression tree.
+          dst_indices: The indices used by the destination index expression.
+
+        Raises:
+          ValueError if there is any inconsistency in indices or dimensional
+          values.
+
+        Returns:
+          A dictionary of (IndexExpr, _ExprInfo).
+        """
+        expr_to_info = {}
+        # Validate the expression tree and construct expression information.
+        self._visit(_validate_and_collect_expr_info, (expr_to_info,))
+
+        # Validate the destination dimension information.
+        info = expr_to_info[self]
+        index_to_dim_info = {i: d for i, d in zip(info.src_indices, info.dim_infos)}
+        for i, d in zip(dst_indices, dst.shape):
+            if i not in index_to_dim_info:
+                raise ValueError(
+                    "Destination IndexVar not used in the " f"source expression: {i}"
+                )
+            else:
+                if d != index_to_dim_info[i].dim and index_to_dim_info[i].dim != -1:
+                    raise ValueError(
+                        f"Inconsistent destination dimension for {i}: "
+                        f"{d} vs {index_to_dim_info[i].dim}"
+                    )
+
+        return expr_to_info
+
+    def _emit_assignment(
+        self,
+        module: ir.Module,
+        dst: "Tensor",
+        dst_indices: Tuple["IndexVar", ...],
+        expr_to_info: _ExprInfoDict,
+        input_accesses: List["Access"],
+    ) -> None:
+        """Emits an MLIR function for assigning the expression to a tensor."""
+        input_types = [a.tensor.mlir_tensor_type() for a in input_accesses]
+
+        # Build the kernel for the operations.
+        with ir.InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(*input_types, name=_ENTRY_NAME)
+            def linalg_funcop(*args):
+                # Set up the mapping from the Access nodes to their MLIR values.
+                for e, mlir in zip(input_accesses, args):
+                    expr_to_info[e].mlir_value = mlir
+
+                # Emit structured ops in the linalg dialect to implement the assignment.
+                for structop_root in self._identify_structured_ops(
+                    expr_to_info, dst, dst_indices
+                ):
+                    structop_root._emit_structured_op(expr_to_info)
+                    dst._record_stats(expr_to_info[structop_root].structop_info)
+
+                # The function returns the MLIR value of the root expression.
+                return expr_to_info[self].mlir_value
+
+            linalg_funcop.func_op.attributes[
+                "llvm.emit_c_interface"
+            ] = ir.UnitAttr.get()
+
+    def get_input_accesses(self) -> List["Access"]:
+        """Compute the list of input accesses for the expression."""
+        input_accesses = []
+        self._visit(_gather_input_accesses_index_vars, (input_accesses,))
+        return input_accesses
+
+    def compile(
+        self,
+        dst: "Tensor",
+        dst_indices: Tuple["IndexVar", ...],
+    ) -> execution_engine.ExecutionEngine:
+        """Compiles the tensor assignment dst[dst_indices] = expression.
+
+        Args:
+          dst: The destination tensor.
+          dst_indices: The tuple of IndexVar used to access the destination tensor.
+
+        Returns:
+          The execution engine for the tensor assignment.
+
+        Raises:
+          ValueError: If the expression is not proper or not supported.
+        """
+        expr_to_info = self._validate_and_collect_expr_info(dst, dst_indices)
+        input_accesses = self.get_input_accesses()
+
+        # Build and compile the module to produce the execution engine.
+        with ir.Context(), ir.Location.unknown():
+            module = ir.Module.create()
+            self._emit_assignment(
+                module, dst, dst_indices, expr_to_info, input_accesses
+            )
+            engine = utils.compile_and_build_engine(module)
+
+        return engine
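
A hypothetical usage sketch of the operator overloads and compile() above (the tensors A, B, C, dst and the index variables are made-up names; tensor indexing via Access and the _BinaryExpr/_UnaryExpr node classes are defined elsewhere in this file): each overloaded Python operator adds one node to the IndexExpr tree, and compile() validates the tree, splits it into linalg structured ops, and JIT-compiles the result.

    i, j, k = get_index_vars(3)
    expr = A[i, k] * B[k, j] + abs(C[i, j])  # nested _BinaryExpr/_UnaryExpr nodes
    engine = expr.compile(dst, (i, j))       # lowers to linalg and JIT-compiles

Because the k reduction happens on only part of this tree, _identify_structured_ops would make the product subtree the root of its own structured op before combining it with abs(C[i, j]).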
 
-  def __neg__(self) -> "_UnaryExpr":
-    """Defines the operator neg.
 
-    Returns:
-      A _UnaryExpr object representing the operation.
-    """
-    return self._build_unary_expr(operator.neg)
+class _AtomicCounter:
+    """An atomic counter."""
 
-  def __sub__(self, rhs) -> "_BinaryExpr":
-    """Defines the operator -.
+    def __init__(self):
+        self._counter = 0
+        self._counter_lock = threading.Lock()
 
-    Args:
-      rhs: The value being subtracted, which could be any Python object from
-        user inputs.
+    def increment(self) -> int:
+        """Increments the counter by one and returns the old value."""
+        with self._counter_lock:
+            old_value = self._counter
+            self._counter = old_value + 1
+        return old_value
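
A small sketch of the counter's contract: increment() hands out consecutive ids starting at 0, which IndexVar and Tensor use to build unique names.

    counter = _AtomicCounter()
    assert counter.increment() == 0  # returns the old value
    assert counter.increment() == 1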
 
-    Returns:
-      A _BinaryExpr object representing the operation.
 
-    Raises:
-      ValueError: If rhs is not an IndexExpr.
-    """
-    return self._verify_operand_and_build_expr(rhs, operator.sub)
-
-  @abc.abstractmethod
-  def _visit(self,
-             func: _ExprVisitor,
-             args,
-             *,
-             leaf_checker: _SubtreeLeafChecker = None) -> None:
-    """A post-order visitor.
-
-    Args:
-      func: A callable applied to each node in the expression tree.
-      args: The variable-length arguments passed to the callable. These
-        arguments are grouped as an iterable and will be unpacked before passing
-        to the callable. This is to enable the keyword argument only syntax
-        after this argument.
-      leaf_checker: A callable object to identify nodes that should be treated
-        as leaf nodes to support partial tree visiting.
-    """
-    pass
+class IndexVar(IndexExpr):
+    """The tensor index class.
 
-  @abc.abstractmethod
-  def _emit_expression(
-      self,
-      expr_to_opnd: Dict["IndexExpr", lang.OperandDef],
-      expr_to_info: _ExprInfoDict,
-  ) -> lang.ScalarExpression:
-    """Emits MLIR for the expression tree.
+    We support the TACO API index_var class with an alias of this class.
 
-    Args:
-      expr_to_opnd: A dictionary for looking up structured op input operands for
-        the input nodes of the structured op.
-      expr_to_info: A dictionary for looking up code generation information for
-        expressions.
+    An IndexVar object represents an index variable in tensor index notation.
 
-    Returns:
-      A linalg dialect ScalarExpression for the expression.
+    Attributes:
+      name: A unique string name of the IndexVar.
     """
-    pass
 
-  @abc.abstractmethod
-  def dtype(self) -> DType:
-    """Returns the data type for the result of the expression."""
-    pass
+    _counter = _AtomicCounter()
 
-  def _emit_structured_op(self, expr_to_info: _ExprInfoDict) -> None:
-    """Emits a structured op in the linalg dialect for the expression tree.
+    def __init__(self):
+        id = self._counter.increment()
+        self._name = f"{_TACO_INDEX_PREFIX}{id}"
 
-    We define a DefineOpcallable in the domain specific language for the linalg
-    dialect and execute the callable to generate the structured op. Self is the
-    root of the expression tree for the structured op.
+    def __repr__(self) -> str:
+        return f"IndexVar(name={repr(self._name)})"
 
-    Args:
-      expr_to_info: A dictionary for looking up code generation information for
-        expressions.
-    """
-    op_info = expr_to_info[self].structop_info
-    op_name = op_info.dst_name
-    op_def = lang.LinalgOpDef(name=op_name)
-    op_callable = lang.DefinedOpCallable(op_name, op_def)
-
-    # Collect the input expression nodes for the structured op.
-    expr_inputs = []
-    self._visit(
-        _gather_structured_op_input,
-        (self, expr_to_info, expr_inputs),
-        leaf_checker=_is_structured_op_leaf,
-    )
+    @property
+    def name(self) -> str:
+        """Returns the name of the IndexVar."""
+        return self._name
 
-    # Create a linalg structured op operand for each input expression node and
-    # build a dictionary for looking up the information.
-    expr_to_input_opnd = {
-        e: _emit_structured_op_input(e, expr_to_info, op_def)
-        for e in expr_inputs
-    }
+    def _visit(
+        self, func: _ExprVisitor, args, *, leaf_checker: _SubtreeLeafChecker = None
+    ) -> None:
+        """A post-order visitor."""
+        if leaf_checker:
+            assert leaf_checker(self, *args)
+        func(self, *args)
 
-    # Emit the expression tree, which produces the value assigned to the
-    # destination tensor.
-    value = self._emit_expression(expr_to_input_opnd, expr_to_info)
-    # Emit the structured op representation for the destination tensor.
-    dst_opnd = _emit_operand(op_def, op_info.dst_indices, op_info.dst_name,
-                             lang.OperandKind.OUTPUT_TENSOR)
-    dst_dim_syms = _mlir_dimensions_from_index_vars(op_info.dst_indices)
-    dst_use = lang.TensorUse(dst_opnd, dst_dim_syms)
-
-    expr_info = expr_to_info[self]
-    # If the structured op reduces some indices, explicitly represent the
-    # reduction. This is done by generating a ReduceFn for the dimensions being
-    # reduced in the linalg dialect and calling the function with the value
-    # being reduced. We only support add reduction currently.
-    if expr_info.reduce_indices:
-      reduce_dims = _mlir_dimensions_from_index_vars(expr_info.reduce_indices)
-      value = lang.ReduceFn.add[reduce_dims](value)
-
-    # Emit the assignment as a comprehension in the linalg dialect.
-    comp = lang.Comprehension((dst_use, value))
-    op_def.comprehensions.append(comp)
-
-    # The structured op in the linalg dialect requires an explicit
-    # initialization for the destination tensor. Emit MLIR to initialize the
-    # destination tensor.
-    init = op_info.emit_tensor_init()
-
-    # Collect MLIR values for the linalg input operands, with the assumption
-    # that dictionary preserves the insertion order.
-    args = [
-        expr_to_info[expr].mlir_value
-        for expr, opnd in expr_to_input_opnd.items()
-    ]
-    # Execute the DefineOpcallable object for the linalg dialect operation to
-    # emit MLIR for the linalg structured op.
-    expr_info.mlir_value = op_callable(*args, outs=[init])
-
-  def _identify_structured_ops(
-      self,
-      expr_to_info: _ExprInfoDict,
-      dst: "Tensor",
-      dst_indices: Tuple["IndexVar", ...],
-  ) -> List["IndexExpr"]:
-    """Returns expression nodes for the roots of the identified structured ops.
-
-    A structured op in the linalg dialect only supports reduction performed on
-    the whole expression. If the expression tree contains reduction that are
-    performed on part of the expression tree, the expression tree needs to be
-    implemented with multiple structured ops. This routine identifies all the
-    expression nodes that contain reduction as the root of structured ops in the
-    linalg dialect.
+    def _emit_expression(
+        self,
+        expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
+        expr_to_info: _ExprInfoDict,
+    ) -> lang.ScalarExpression:
+        """Emits a index value casted to the data type of the tensor expression."""
+        dim = getattr(lang.D, self.name)
+        index = lang.index(dim)
+        int_value = lang.TypeFn.cast_unsigned(lang.TV.I64, index)
+        return lang.TypeFn.cast_unsigned(lang.T, int_value)
 
-    Args:
-      expr_to_info: A dictionary for looking up code generation information for
-        expressions.
-      dst: A destination Tensor that accepts the value of the expression tree.
-      dst_indices: The indices used by the destination index expression.
+    def dtype(self) -> DType:
+        """Returns the data type for the index value.
 
-    Returns:
-      An ordered list of IndexExpr for the root expressions of the structured
-      ops, where child expressions go before parent expressions that use their
-      results.
-    """
-    reduce_indices = tuple(
-        set(expr_to_info[self].src_indices) - set(dst_indices))
-    for reduce_index in reduce_indices:
-      _mark_structured_op_root(self, reduce_index, expr_to_info)
-
-    self._visit(_accumulate_reduce_indices, (expr_to_info,))
-    structop_roots = []
-    self._visit(_gather_structured_op, (expr_to_info, structop_roots))
-
-    # Handle the root of the top level expression.
-    if not structop_roots or structop_roots[-1] != self:
-      # The top level expression is not a reduction. Add the top level
-      # expression as a structured op root.
-      structop_roots.append(self)
-
-    # Use user specified information for the destination tensor to build an
-    # _StructOpInfo for the top level expression.
-    expr_to_info[self].structop_info = _StructOpInfo(dst_indices,
-                                                     tuple(dst.shape),
-                                                     dst.dtype, dst.name,
-                                                     dst.format)
-
-    return structop_roots
-
-  def _validate_and_collect_expr_info(
-      self,
-      dst: "Tensor",
-      dst_indices: Tuple["IndexVar", ...],
-  ) -> _ExprInfoDict:
-    """Propagates expression information for validation.
-
-    Propagates the indices used by child expression nodes to parent expression
-    nodes. Also collects and validates the sizes for the dimensions
-    corresponding to the indices.
+        This is unreachable for IndexVar.
+        """
+        assert 0
 
-    Args:
-      dst: A destination Tensor that accepts the value of the expression tree.
-      dst_indices: The indices used by the destination index expression.
 
-    Raises:
-      ValueError if there is any inconsistency in indices or dimensional
-      values.
+def get_index_vars(n: int) -> List[IndexVar]:
+    """Returns a list of n IndexVar.
 
-    Returns:
-      A dictionary of (IndexExpr, _ExprInfo).
-    """
-    expr_to_info = {}
-    # Validate the expression tree and construct expression information.
-    self._visit(_validate_and_collect_expr_info, (expr_to_info,))
-
-    # Validate the destination dimension information.
-    info = expr_to_info[self]
-    index_to_dim_info = {i: d for i, d in zip(info.src_indices, info.dim_infos)}
-    for i, d, in zip(dst_indices, dst.shape):
-      if i not in index_to_dim_info:
-        raise ValueError("Destination IndexVar not used in the "
-                         f"source expression: {i}")
-      else:
-        if d != index_to_dim_info[i].dim and index_to_dim_info[i].dim != -1:
-          raise ValueError(f"Inconsistent destination dimension for {i}: "
-                           f"{d} vs {index_to_dim_info[i].dim}")
-
-    return expr_to_info
-
-  def _emit_assignment(
-      self,
-      module: ir.Module,
-      dst: "Tensor",
-      dst_indices: Tuple["IndexVar", ...],
-      expr_to_info: _ExprInfoDict,
-      input_accesses: List["Access"],
-  ) -> None:
-    """Emits an MLIR function for assigning the expression to a tensor."""
-    input_types = [a.tensor.mlir_tensor_type() for a in input_accesses]
-
-    # Build the kernel for the operations.
-    with ir.InsertionPoint(module.body):
-
-      @func.FuncOp.from_py_func(*input_types, name=_ENTRY_NAME)
-      def linalg_funcop(*args):
-        # Set up the mapping from the Access nodes to their MLIR values.
-        for e, mlir in zip(input_accesses, args):
-          expr_to_info[e].mlir_value = mlir
-
-        # Emit structured ops in the linalg dialect to implement the assignment.
-        for structop_root in self._identify_structured_ops(
-            expr_to_info, dst, dst_indices):
-          structop_root._emit_structured_op(expr_to_info)
-          dst._record_stats(expr_to_info[structop_root].structop_info)
-
-        # The function returns the MLIR value of the root expression.
-        return expr_to_info[self].mlir_value
-
-      linalg_funcop.func_op.attributes[
-          "llvm.emit_c_interface"] = ir.UnitAttr.get()
-
-  def get_input_accesses(self) -> List["Access"]:
-    """Compute the list of input accesses for the expression."""
-    input_accesses = []
-    self._visit(_gather_input_accesses_index_vars, (input_accesses,))
-    return input_accesses
-
-  def compile(
-      self,
-      dst: "Tensor",
-      dst_indices: Tuple["IndexVar", ...],
-  ) -> execution_engine.ExecutionEngine:
-    """Compiles the tensor assignment dst[dst_indices] = expression.
+    This routine is defined by the TACO API.
 
     Args:
-      dst: The destination tensor.
-      dst_indices: The tuple of IndexVar used to access the destination tensor.
+      n: An integer representing the number of IndexVar to get.
 
     Returns:
-      The execution engine for the tensor assignment.
+      A list of IndexVar.
 
     Raises:
-      ValueError: If the expression is not proper or not supported.
+      ValueError: if n is not a positive integer.
     """
-    expr_to_info = self._validate_and_collect_expr_info(dst, dst_indices)
-    input_accesses = self.get_input_accesses()
-
-    # Build and compile the module to produce the execution engine.
-    with ir.Context(), ir.Location.unknown():
-      module = ir.Module.create()
-      self._emit_assignment(module, dst, dst_indices, expr_to_info,
-                            input_accesses)
-      engine = utils.compile_and_build_engine(module)
-
-    return engine
-
-
-class _AtomicCounter:
-  """An atomic counter."""
-
-  def __init__(self):
-    self._counter = 0
-    self._counter_lock = threading.Lock()
-
-  def increment(self) -> int:
-    """Increments the counter by one and returns the old value."""
-    old_value = self._counter
-    with self._counter_lock:
-      self._counter = self._counter + 1
-    return old_value
-
-
-class IndexVar(IndexExpr):
-  """The tensor index class.
-
-  We support the TACO API index_var class with an alias of this class.
-
-  An IndexVar object represents an index variable in tensor index notation.
-
-  Attributes:
-    name: A unique string name of the IndexVar.
-  """
-  _counter = _AtomicCounter()
-
-  def __init__(self):
-    id = self._counter.increment()
-    self._name = f"{_TACO_INDEX_PREFIX}{id}"
-
-  def __repr__(self) -> str:
-    return f"IndexVar(name={repr(self._name)})"
-
-  @property
-  def name(self) -> str:
-    """Returns the name of the IndexVar."""
-    return self._name
-
-  def _visit(self,
-             func: _ExprVisitor,
-             args,
-             *,
-             leaf_checker: _SubtreeLeafChecker = None) -> None:
-    """A post-order visitor."""
-    if leaf_checker:
-      assert leaf_checker(self, *args)
-    func(self, *args)
-
-  def _emit_expression(
-      self,
-      expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
-      expr_to_info: _ExprInfoDict,
-  ) -> lang.ScalarExpression:
-    """Emits a index value casted to the data type of the tensor expression."""
-    dim = getattr(lang.D, self.name)
-    index = lang.index(dim)
-    int_value = lang.TypeFn.cast_unsigned(lang.TV.I64, index)
-    return lang.TypeFn.cast_unsigned(lang.T, int_value)
-
-  def dtype(self) -> DType:
-    """Returns the data type for the index value.
-
-    This is unreachable for IndexVar.
-    """
-    assert 0
-
-
-def get_index_vars(n: int) -> List[IndexVar]:
-  """Returns a list of n IndexVar.
-
-  This routine is defined by the TACO API.
-
-  Args:
-    n: An integer representing the number of IndexVar to get.
-
-  Returns:
-    A list of IndexVar.
-
-  Raises:
-    ValueError: if n is not a positive integer.
-  """
-  if not isinstance(n, int) or n <= 0:
-    raise ValueError(f"Expected an integer: {n}.")
-  # If lock contention ever becomes an issue, we could implement a bulk getter
-  # that returns a range by only claiming the lock once.
-  return [IndexVar() for i in range(n)]
+    if not isinstance(n, int) or n <= 0:
+        raise ValueError(f"Expected an integer: {n}.")
+    # If lock contention ever becomes an issue, we could implement a bulk getter
+    # that returns a range by only claiming the lock once.
+    return [IndexVar() for i in range(n)]
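
A short illustrative sketch: get_index_vars is the TACO-style entry point, and every IndexVar draws a fresh id from the shared _AtomicCounter, so the generated names are unique within the process.

    i, j, k = get_index_vars(3)
    assert len({v.name for v in (i, j, k)}) == 3  # distinct auto-generated names
    try:
        get_index_vars(0)  # rejected: n must be a positive integer
    except ValueError:
        pass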
 
 
 def _mlir_symbols_from_index_vars(
-    index_vars: Tuple[IndexVar, ...]) -> Tuple[lang.SymbolDef, ...]:
-  """Returns a tuple of MLIR symbols for the given tuple of index_var."""
-  return tuple(getattr(lang.S, i.name) for i in index_vars)
+    index_vars: Tuple[IndexVar, ...]
+) -> Tuple[lang.SymbolDef, ...]:
+    """Returns a tuple of MLIR symbols for the given tuple of index_var."""
+    return tuple(getattr(lang.S, i.name) for i in index_vars)
 
 
 def _mlir_dimensions_from_index_vars(
-    index_vars: Tuple[IndexVar, ...]) -> Tuple[lang.DimDef, ...]:
-  """Returns a tuple of MLIR dimensions for the given tuple of index_var."""
-  return tuple(getattr(lang.D, i.name) for i in index_vars)
+    index_vars: Tuple[IndexVar, ...]
+) -> Tuple[lang.DimDef, ...]:
+    """Returns a tuple of MLIR dimensions for the given tuple of index_var."""
+    return tuple(getattr(lang.D, i.name) for i in index_vars)
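
A minimal sketch of the two helpers above: they look each IndexVar's name up on the opdsl symbol (lang.S) and dimension (lang.D) namespaces, so an IndexVar named, say, "i0" maps to lang.S.i0 and lang.D.i0.

    i, j = get_index_vars(2)
    syms = _mlir_symbols_from_index_vars((i, j))    # (lang.S.<i.name>, lang.S.<j.name>)
    dims = _mlir_dimensions_from_index_vars((i, j))  # (lang.D.<i.name>, lang.D.<j.name>)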
 
 
 def _mlir_tensor_type(
-    dtype: DType, shape: Tuple[int, ...],
-    attr: Optional[sparse_tensor.EncodingAttr]) -> ir.RankedTensorType:
-  """Returns an MLIR tensor type.
-
-  Args:
-    dtype: An DType object for the element data type of the tensor.
-    shape: A tuple of integer for the shape of the tensor.
-    attr: An optional MLIR sparse tensor attribute, only provided if the tensor
-      is a sparse tensor.
-
-  Returns:
-    An MLIR ranked tensor type.
-  """
-  ir_type = _mlir_type_from_taco_type(dtype)
-  return ir.RankedTensorType.get(shape, ir_type, attr)
-
-
-@dataclasses.dataclass(frozen=True)
-class _StructOpInfo:
-  """Information for generating a structured op in the linalg dialect.
-
-  This information is associated with an expression node that serves as the
-  root for an expression subtree implemented with a structured op.
-
-  Attributes:
-    dst_indices: A tuple of IndexVar, representing the result dimensions of the
-      structured op. This is used to construct the temporary variable for the
-      tensor to hold the structured op result.
-    dst_dims: A tuple of int, representing the result shape of the structured
-      op.
-    dst_dtype: A DType representing the data type of the structured op result.
-    dst_name: A string representing the name of the structured op result.
-    dst_format: An optional Format object representing the destination tensor
-      format. None represents a true dense tensor.
-  """
-  dst_indices: Tuple[IndexVar, ...]
-  dst_dims: Tuple[int, ...]
-  dst_dtype: DType
-  dst_name: str
-  dst_format: Optional[Format]
-
-  def __post_init__(self) -> None:
-    """Verifies the integrity of the attribute values."""
-    assert len(self.dst_indices) == len(self.dst_dims)
-
-  def emit_tensor_init(self) -> ir.RankedTensorType:
-    """Returns an initialization for the destination tensor."""
-    if self.dst_format is None or self.dst_format.rank() == 0:
-      # Initialize the dense tensor.
-      ir_type = _mlir_type_from_taco_type(self.dst_dtype)
-      empty = tensor.EmptyOp(self.dst_dims, ir_type).result
-      zero = arith.ConstantOp(ir_type, 0.0)
-      return linalg.fill(zero, outs=[empty])
-
-    # Initialize the sparse tensor.
-    mlir_type = _mlir_tensor_type(self.dst_dtype, self.dst_dims,
-                                  self.dst_format.mlir_tensor_attr())
-    index_type = ir.IndexType.get()
-    return bufferization.AllocTensorOp(mlir_type, [], None, None, None)
-
-
-class _Stats:
-  """Information to describe how a tensor expression is implemented.
-
-  Currently, we only record the temporary tensors introduced for splitting the
-  original expression.
-  """
-
-  def __init__(self):
-    self._temps = []
-
-  def __repr__(self) -> str:
-    return f"_Stats({repr(self._temps)})"
-
-  def add_element(self, structop: _StructOpInfo):
-    """Adds a temporary tensor."""
-    self._temps.append(structop)
-
-  def get_total(self) -> int:
-    """Gets the total number of temporary tensors."""
-    return len(self._temps)
-
-  def _get_element(self, idx: int) -> _StructOpInfo:
-    """Gets the ith temporary tensor."""
-    assert idx < self.get_total()
-    return self._temps[idx]
-
-  def get_dimensions(self, idx: int) -> Tuple[int]:
-    """Gets the dimensions for the ith temporary tensor."""
-    return self._get_element(idx).dst_dims
-
-  def get_formats(self, idx: int) -> Tuple[ModeFormat]:
-    """Gets the ModeFormats for the ith temporary tensor."""
-    return tuple(self._get_element(idx).dst_format.format_pack.formats)
-
-
-class _SparseValueInfo(enum.Enum):
-  """Describes how a sparse tensor value is stored.
-  _UNPACKED: The sparse tensor value is stored as (coordnates, values) in
-    Python.
-  _PACKED: The sparse tensor value is stored as a C pointer to a packed MLIR
-    sparse tensor.
-  """
-  _UNPACKED = 0
-  _PACKED = 1
-
-
-@dataclasses.dataclass(frozen=True)
-class _Assignment:
-  """Records an assignment to a tensor T as T[indices] = expression."""
-  indices: Tuple["IndexVar", ...]
-  expression: "IndexExpr"
-
-
-class Tensor:
-  """The tensor class.
-
-  We support the TACO API tensor class with an alias of this class.
-
-  This class is part of the TACO API with the following methods:
-    insert: Inserts a value to the given coordinate in the tensor.
-    to_array: Returns a numpy ndarray for the tensor.
-
-  TACO API also defines the following arrtibutes for the class:
-    dtype: A dtype object representing the data type of the tensor.
-    format: A format object representing the storage format of the tensor.
-    name: A string object representing the name of the tensor.
-    order: An integral rank of the tensor.
-    shape: A list of integers representing the shape of the tensor.
-
-  We currently ignore the tensor dimension ordering for dense tensor.
-  """
-  _counter = _AtomicCounter()
-
-  def _get_unique_name(self) -> str:
-    """Returns a unique name for creating a new Tensor."""
-    return f"{_TACO_TENSOR_PREFIX}{self._counter.increment()}"
-
-  def _init_format(self, fmt: Union[ModeFormat, List[ModeFormat],
-                                    Format]) -> None:
-    """Process the fmt argument for the Tensor constructor.
-
-    Args:
-      fmt: This argument can be a ModeFormat, List[ModeFormat], or format. If
-        this argument is a ModeFormat, uses this ModeFormat for all the tensor
-        dimensions. If this argument is a list of ModeFormat, the len of the
-        list should equal to the rank of the tensor. If this argument is a
-        format, uses it for the format of the tensor.
-
-    Raises:
-      ValueError: If fmt is not one of the expected type or is inconsistent
-        with the rank of the tensor. This is because fmt could be an users
-        input.
-    """
-    if isinstance(fmt, ModeFormat):
-      self._format = _make_format([fmt] * self.order)
-    elif isinstance(fmt, list):
-      if len(fmt) == self.order and isinstance(fmt[0], ModeFormat):
-        self._format = _make_format(fmt)
-      else:
-        raise ValueError("Inconsistent shape and format: "
-                         f"{self._shape}, {fmt}.")
-    elif isinstance(fmt, Format):
-      if fmt.rank() != self.order:
-        raise ValueError("Inconsistent shape and format: "
-                         f"{self._shape}, {fmt}.")
-      else:
-        self._format = fmt
-    else:
-      raise ValueError(f"Invalid format argument: {fmt}.")
-
-  def __init__(self,
-               value_or_shape: Optional[Union[List[int], Tuple[int, ...],
-                                              complex, float, int]] = None,
-               fmt: Optional[Union[ModeFormat, List[ModeFormat],
-                                   Format]] = None,
-               dtype: Optional[DType] = None,
-               name: Optional[str] = None,
-               is_dense: bool = False):
-    """The tensor constructor interface defined by TACO API.
-
-    Args:
-      value_or_shape: This argument is optional and can be int, float,
-        List[int], or Tuple[int, ...]. If this argument is an int or float,
-        creates a scalar tensor and initializes it with the value. If this
-        argument is a list or tuple of int, uses it as the shape to create a
-        tensor.
-      fmt: This argument can be a ModeFormat, List[ModeFormat], or format. If
-        this argument is a ModeFormat, uses this ModeFormat for all the tensor
-        dimensions. If this argument is a list of ModeFormat, the len of the
-        list should equal to the rank of the tensor. If this argument is a
-        format, uses it for the format of the tensor.
-      dtype: An object of dtype, representing the data type of the tensor.
-      name: A string name of the tensor. If a name is not given, creates a
-        unique name for the tensor.
-      is_dense: A boolean variable to indicate whether the tensor is a dense
-        tensor without any sparsity annotation.
-
-    Raises:
-      ValueError: If there is any inconsistency among the input arguments.
-    """
-    # Take care of the argument default values common to both sparse tensors
-    # and dense tensors.
-    dtype = dtype or DType(Type.FLOAT32)
-    self._name = name or self._get_unique_name()
-    self._assignment = None
-    self._engine = None
-    self._sparse_value_location = _SparseValueInfo._UNPACKED
-    self._dense_storage = None
-    self._dtype = dtype
-
-    if is_dense:
-      assert (fmt is None)
-      assert (isinstance(value_or_shape, tuple) or isinstance(
-          value_or_shape, list)) and _all_instance_of(value_or_shape, int)
-      self._shape = value_or_shape
-      self._format = None
-      return
-
-    fmt = fmt or ModeFormat.COMPRESSED
-    # We currently use _coords and _values to host the sparse tensor value with
-    # COO format, and _dense_storage to host the dense tensor value. We don't
-    # support the conversion between the two storages.
-    self._coords = []
-    self._values = []
-    self._stats = _Stats()
-    if value_or_shape is None or isinstance(value_or_shape, int) or isinstance(
-        value_or_shape, float) or isinstance(value_or_shape, complex):
-      # Create a scalar tensor and ignore the fmt parameter.
-      self._shape = []
-      self._format = _make_format([], [])
-      if value_or_shape is not None:
-        self._dense_storage = np.array(value_or_shape, dtype=self._dtype.value)
-    elif (isinstance(value_or_shape, tuple) or isinstance(
-        value_or_shape, list)) and _all_instance_of(value_or_shape, int):
-      # Create a tensor with the specified shape and format.
-      self._shape = list(value_or_shape)
-      self._init_format(fmt)
-    else:
-      raise ValueError("Invalid first argument. "
-                       "Must be a tuple or list for a shape or a single value"
-                       f"if initializing a scalar tensor: {value_or_shape}.")
-
-  def _set_packed_sparse_tensor(self, pointer: ctypes.c_void_p) -> None:
-    """Records the MLIR sparse tensor pointer."""
-    self._sparse_value_location = _SparseValueInfo._PACKED
-    self._packed_sparse_value = pointer
-
-  def is_unpacked(self) -> bool:
-    """Returns true if the tensor value is not packed as MLIR sparse tensor."""
-    return (self._sparse_value_location == _SparseValueInfo._UNPACKED)
-
-  def unpack(self) -> None:
-    """Unpacks the MLIR sparse tensor representation."""
-    if self.is_dense() or self.is_unpacked():
-      return
-
-    # Use the output MLIR sparse tensor pointer to retrieve the COO-flavored
-    # values and verify the values.
-    rank, nse, shape, values, indices = utils.sparse_tensor_to_coo_tensor(
-        self._packed_sparse_value, self._dtype.value)
-    assert rank == self.order
-    assert np.array_equal(self.shape, shape)
-    assert nse == len(values)
-    self._coords = indices
-    self._values = values
-    self._sparse_value_location = _SparseValueInfo._UNPACKED
-
-  def __repr__(self) -> str:
-    self._sync_value()
-    self.unpack()
-    value_str = (f"{repr(self._dense_storage)})" if self.is_dense() else
-                 f"{repr(self._coords)} {repr(self._values)})")
-    return (f"Tensor(_name={repr(self._name)} "
-            f"_dtype={repr(self._dtype)} : ") + value_str
-
-  def insert(self, coords: List[int], val: Union[complex, float, int]) -> None:
-    """Inserts a value to the given coordinate.
+    dtype: DType, shape: Tuple[int, ...], attr: Optional[sparse_tensor.EncodingAttr]
+) -> ir.RankedTensorType:
+    """Returns an MLIR tensor type.
 
     Args:
-      coords: A list of integer coordinates. The length of the list must be the
-        same as the rank of the tensor.
-      val: A value being inserted. It is either an integral or a floating point
-        value. This value will be converted to the data type of the tensor.
-
-    Raises:
-      ValueError: When there is any problem in the parameters.
-    """
-    if self.is_dense():
-      raise ValueError("Insert method is not supported for dense tensors.")
-    if self._assignment != None or not self.is_unpacked():
-      raise ValueError(
-          "Can't use Insert method for a tensor constructed from a file.")
-    if not isinstance(coords, list):
-      raise ValueError(f"Non list coordinate detected: {coords}.")
-    if not _all_instance_of(coords, int):
-      raise ValueError(f"Non integer coordinate detected: {coords}.")
-    if (len(coords) != self.order or
-        any([c < 0 or c >= self._shape[i] for i, c in enumerate(coords)])):
-      raise ValueError("Invalid coordinate for rank: "
-                       f"{self.order}, {coords}.")
-
-    if not isinstance(val, int) and not isinstance(
-        val, float) and not isinstance(val, complex):
-      raise ValueError(f"Value is neither int nor float: {val}.")
-
-    self._coords.append(tuple(coords))
-    self._values.append(self._dtype.value(val))
-
-  def is_dense(self) -> bool:
-    """Returns true if the tensor doesn't have sparsity annotation."""
-    return self.order == 0 or self._format is None
-
-  def to_array(self) -> np.ndarray:
-    """Returns the numpy array for the Tensor.
-
-    This is currenly only implemented for dense Tensor.
-    """
-    if not self.is_dense():
-      raise ValueError("Conversion from non-dense Tensor "
-                       "to numpy array not supported yet.")
-
-    self._sync_value()
-
-    return self._dense_storage
-
-  @staticmethod
-  def from_array(array: np.ndarray) -> "Tensor":
-    """Returns a dense tensor with the value copied from the input array.
-
-    We currently only support the conversion of float32 and float64 numpy arrays
-    to Tensor.
-
-    Args:
-      array: The numpy array that provides the data type, shape and value for
-        the tensor.
+      dtype: A DType object for the element data type of the tensor.
+      shape: A tuple of integers for the shape of the tensor.
+      attr: An optional MLIR sparse tensor attribute, only provided if the tensor
+        is a sparse tensor.
 
     Returns:
-      A Tensor object.
-
-    Raises:
-      ValueError if the data type of the numpy array is not supported.
+      An MLIR ranked tensor type.
     """
-    if array.dtype != np.float32 and array.dtype != np.float64:
-      raise ValueError(f"Expected floating point value type: {array.dtype}.")
-    t = Tensor(
-        array.shape,
-        dtype=_nptype_to_taco_type(array.dtype.type),
-        is_dense=True)
-    t._dense_storage = np.copy(array)
-    return t
-
-  @staticmethod
-  def from_coo(
-      coordinates: List[Tuple[int, ...]],
-      values: List[_AnyRuntimeType],
-      fmt: Format,
-      dtype: DType,
-  ) -> "Tensor":
-    """Converts coordinates and values to a sparse tensor representation.
+    ir_type = _mlir_type_from_taco_type(dtype)
+    return ir.RankedTensorType.get(shape, ir_type, attr)
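+
+
+# An illustrative sketch of how _mlir_tensor_type is used (comment only). It
+# assumes an active ir.Context, which callers in this module set up, and
+# `some_encoding` stands in for a sparse_tensor.EncodingAttr built from a Format.
+#
+#   with ir.Context():
+#       # Without an encoding attribute: a dense type such as tensor<4x4xf32>.
+#       dense_ty = _mlir_tensor_type(DType(Type.FLOAT32), (4, 4), None)
+#       # With an encoding attribute: the same shape annotated as sparse.
+#       sparse_ty = _mlir_tensor_type(DType(Type.FLOAT32), (4, 4), some_encoding)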
 
-    Args:
-      coordinates: A list of coordinates with non-zero values.
-      values: The non-zero values.
-      fmt: The tensor storage format.
-      dtype: The tensor element data type.
 
-    Returns:
-      A tensor with the given non-zero values and storage format. The shape of
-      the tensor has the minimum size for each dimension to make the given
-      coordinates valid.
-    """
-    assert (isinstance(coordinates, List) and
-            _all_instance_of(coordinates, Tuple))
-    assert (isinstance(values, List) and _all_instance_of(values, dtype.value))
-    assert isinstance(fmt, Format)
-
-    rank = fmt.rank()
-    assert all(len(c) == rank and _all_instance_of(c, int) for c in coordinates)
-
-    # Find the maximum coordinate value for each dimension.
-    max_coordinate = list(map(max, zip(*coordinates)))
-    # The size of each dimension is one more that such a maximum coordinate
-    # value.
-    shape = [c + 1 for c in max_coordinate]
-    t = Tensor(shape, fmt, dtype=dtype)
-    t._coords = coordinates
-    t._values = values
-
-    return tensor
-
-  @staticmethod
-  def from_file(
-      filename: str,
-      fmt: Format,
-      dtype: DType,
-  ) -> "Tensor":
-    """Constructs a sparse tensor using the COO-flavored values from a file.
-
-    Args:
-      filename: A string for the name of the file that contains the sparse
-        tensor data.
-      fmt: The tensor storage format.
-      dtype: The tensor element data type.
-
-    Returns:
-      A tensor with the given non-zero values and storage format. The tensor
-      value is stored as an MLIR sparse tensor.
+@dataclasses.dataclass(frozen=True)
+class _StructOpInfo:
+    """Information for generating a structured op in the linalg dialect.
+
+    This information is associated with an expression node that serves as the
+    root for an expression subtree implemented with a structured op.
+
+    Attributes:
+      dst_indices: A tuple of IndexVar, representing the result dimensions of the
+        structured op. This is used to construct the temporary variable for the
+        tensor to hold the structured op result.
+      dst_dims: A tuple of int, representing the result shape of the structured
+        op.
+      dst_dtype: A DType representing the data type of the structured op result.
+      dst_name: A string representing the name of the structured op result.
+      dst_format: An optional Format object representing the destination tensor
+        format. None represents a true dense tensor.
     """
-    sparse_tensor, shape = utils.create_sparse_tensor(filename,
-                                                      fmt.format_pack.formats,
-                                                      _dtype_to_mlir_str(dtype))
-    t = Tensor(shape.tolist(), fmt, dtype=dtype)
-    t._set_packed_sparse_tensor(sparse_tensor)
-
-    return t
 
-  def to_file(self, filename: str) -> None:
-    """Output the tensor value to a file.
+    dst_indices: Tuple[IndexVar, ...]
+    dst_dims: Tuple[int, ...]
+    dst_dtype: DType
+    dst_name: str
+    dst_format: Optional[Format]
+
+    def __post_init__(self) -> None:
+        """Verifies the integrity of the attribute values."""
+        assert len(self.dst_indices) == len(self.dst_dims)
+
+    def emit_tensor_init(self) -> ir.RankedTensorType:
+        """Returns an initialization for the destination tensor."""
+        if self.dst_format is None or self.dst_format.rank() == 0:
+            # Initialize the dense tensor.
+            ir_type = _mlir_type_from_taco_type(self.dst_dtype)
+            empty = tensor.EmptyOp(self.dst_dims, ir_type).result
+            zero = arith.ConstantOp(ir_type, 0.0)
+            return linalg.fill(zero, outs=[empty])
+
+        # Initialize the sparse tensor.
+        mlir_type = _mlir_tensor_type(
+            self.dst_dtype, self.dst_dims, self.dst_format.mlir_tensor_attr()
+        )
+        index_type = ir.IndexType.get()
+        return bufferization.AllocTensorOp(mlir_type, [], None, None, None)
 
-    This method evaluates any pending assignment to the tensor and outputs the
-    tensor value.
 
-    Args:
-      filename: A string file name.
+class _Stats:
+    """Information to describe how a tensor expression is implemented.
 
-    Raises:
-       ValueError: If the tensor is dense, or an unpacked sparse tensor.
+    Currently, we only record the temporary tensors introduced for splitting the
+    original expression.
     """
-    self._sync_value()
-
-    if self.is_dense():
-      raise ValueError("Writing dense tensors without sparsity annotation to "
-                       "file is not supported.")
 
-    if self.is_unpacked():
-      raise ValueError("Writing unpacked sparse tensors to file is not "
-                       "supported.")
+    def __init__(self):
+        self._temps = []
 
-    utils.output_sparse_tensor(self._packed_sparse_value, filename,
-                               self._format.format_pack.formats,
-                               _dtype_to_mlir_str(self._dtype))
+    def __repr__(self) -> str:
+        return f"_Stats({repr(self._temps)})"
 
-  @property
-  def dtype(self) -> DType:
-    """Returns the data type for the Tensor."""
-    return self._dtype
+    def add_element(self, structop: _StructOpInfo):
+        """Adds a temporary tensor."""
+        self._temps.append(structop)
 
-  @property
-  def format(self) -> Format:
-    """Returns the storage format for the Tensor."""
-    return self._format
+    def get_total(self) -> int:
+        """Gets the total number of temporary tensors."""
+        return len(self._temps)
 
-  @property
-  def name(self) -> str:
-    """Returns the name for the Tensor."""
-    return self._name
+    def _get_element(self, idx: int) -> _StructOpInfo:
+        """Gets the ith temporary tensor."""
+        assert idx < self.get_total()
+        return self._temps[idx]
 
-  @property
-  def order(self) -> int:
-    """Returns the rank of the Tensor."""
-    return len(self._shape)
+    def get_dimensions(self, idx: int) -> Tuple[int, ...]:
+        """Gets the dimensions for the ith temporary tensor."""
+        return self._get_element(idx).dst_dims
 
-  @property
-  def shape(self) -> List[int]:
-    """Returns the shape of the Tensor."""
-    return self._shape
+    def get_formats(self, idx: int) -> Tuple[ModeFormat, ...]:
+        """Gets the ModeFormats for the ith temporary tensor."""
+        return tuple(self._get_element(idx).dst_format.format_pack.formats)
 
-  def _verify_and_normalize_indices(self, indices) -> Tuple[IndexVar, ...]:
-    """Verifies and normalizes the indices to access the tensor.
 
-    Args:
-      indices: The index expression used to access a tensor, which could be any
-        Python object from user inputs.
-
-    Returns:
-      A tuple of IndexVar.
-
-    Raises:
-      ValueError: If indices is not 0 for scalar tensors, or not an IndexVar or
-        a tuple of IndexVar for other tensors.
+class _SparseValueInfo(enum.Enum):
+    """Describes how a sparse tensor value is stored.
+    _UNPACKED: The sparse tensor value is stored as (coordinates, values) in
+      Python.
+    _PACKED: The sparse tensor value is stored as a C pointer to a packed MLIR
+      sparse tensor.
     """
-    if self.order == 0:
-      if not isinstance(indices, int) or indices != 0:
-        raise ValueError(f"Expected 0 to index scalar tensors: {indices}")
-      return ()
-
-    if isinstance(indices, IndexVar):
-      return (indices,)
-    elif isinstance(indices, tuple) and _all_instance_of(indices, IndexVar):
-      return indices
 
-    raise ValueError(f"Expected IndexVars: {indices}")
+    _UNPACKED = 0
+    _PACKED = 1
 
-  def __getitem__(self, key) -> "Access":
-    """Verifies and processes a tensor access.
 
-    In the tensor index notation, a tensor access T[i, j] is represented as
-    retrieving a value with key (i, j) from the tensor object T in Python. This
-    routine verifies the key for the tensor access and returns a tensor access
-    object.
-
-    Args:
-      key: The key used to access the tensor, which could be any Python object
-        from user inputs.
+@dataclasses.dataclass(frozen=True)
+class _Assignment:
+    """Records an assignment to a tensor T as T[indices] = expression."""
 
-    Returns:
-      The corresponding tensor access object.
+    indices: Tuple["IndexVar", ...]
+    expression: "IndexExpr"
 
-    Raises:
-      ValueError: If key is not an IndexVar or a tuple of IndexVar.
-    """
-    indices = self._verify_and_normalize_indices(key)
-    return Access(self, indices)
 
-  def __setitem__(self, key, value) -> None:
-    """Verifies and processes a tensor assignment.
+class Tensor:
+    """The tensor class.
 
-    In the tensor index notation, a tensor assignment "T[i, j] = ..." is
-    represented as setting a value for a tensor object T via key (i, j) in
-    Python. This routine verifies the key, evaluates the value, and assigns the
-    value to the tensor.
+    We support the TACO API tensor class with an alias of this class.
 
-    We only support assignment of dense tensor currently.
+    This class is part of the TACO API with the following methods:
+      insert: Inserts a value to the given coordinate in the tensor.
+      to_array: Returns a numpy ndarray for the tensor.
 
-    Args:
-      key: The key used to access the tensor, which could be any Python object
-        from user inputs.
-      value: The value assigned to the tensor, which could be any Python object
-        from user inputs.
+    The TACO API also defines the following attributes for the class:
+      dtype: A dtype object representing the data type of the tensor.
+      format: A format object representing the storage format of the tensor.
+      name: A string object representing the name of the tensor.
+      order: An integer representing the rank of the tensor.
+      shape: A list of integers representing the shape of the tensor.
 
-    Raises:
-      ValueError: If tensor is not a dense tensor, or the key is not an IndexVar
-        or a tuple of IndexVar, or the length of the indices is not the same as
-        the rank of the tensor.
+    We currently ignore the tensor dimension ordering for dense tensors.
     """
-    indices = self._verify_and_normalize_indices(key)
-    if len(indices) != self.order:
-      raise ValueError("Mismatch between indices and tensor rank: "
-                       f"len({indices}) != {self.order}.")
 
-    self._assignment = _Assignment(indices, value)
-    self._engine = None
-
-  def compile(self, force_recompile: bool = False) -> None:
-    """Compiles the tensor assignment to an execution engine.
-
-    Calling compile the second time does not do anything unless
-    force_recompile is True.
+    _counter = _AtomicCounter()
+
+    def _get_unique_name(self) -> str:
+        """Returns a unique name for creating a new Tensor."""
+        return f"{_TACO_TENSOR_PREFIX}{self._counter.increment()}"
+
+    def _init_format(self, fmt: Union[ModeFormat, List[ModeFormat], Format]) -> None:
+        """Process the fmt argument for the Tensor constructor.
+
+        Args:
+          fmt: This argument can be a ModeFormat, List[ModeFormat], or Format. If
+            this argument is a ModeFormat, uses this ModeFormat for all the tensor
+            dimensions. If this argument is a list of ModeFormat, the length of the
+            list must equal the rank of the tensor. If this argument is a
+            Format, uses it as the format of the tensor.
+
+        Raises:
+          ValueError: If fmt is not one of the expected types or is inconsistent
+            with the rank of the tensor. This is because fmt could be a user
+            input.
+        """
+        if isinstance(fmt, ModeFormat):
+            self._format = _make_format([fmt] * self.order)
+        elif isinstance(fmt, list):
+            if len(fmt) == self.order and isinstance(fmt[0], ModeFormat):
+                self._format = _make_format(fmt)
+            else:
+                raise ValueError(
+                    "Inconsistent shape and format: " f"{self._shape}, {fmt}."
+                )
+        elif isinstance(fmt, Format):
+            if fmt.rank() != self.order:
+                raise ValueError(
+                    "Inconsistent shape and format: " f"{self._shape}, {fmt}."
+                )
+            else:
+                self._format = fmt
+        else:
+            raise ValueError(f"Invalid format argument: {fmt}.")
+
+    def __init__(
+        self,
+        value_or_shape: Optional[
+            Union[List[int], Tuple[int, ...], complex, float, int]
+        ] = None,
+        fmt: Optional[Union[ModeFormat, List[ModeFormat], Format]] = None,
+        dtype: Optional[DType] = None,
+        name: Optional[str] = None,
+        is_dense: bool = False,
+    ):
+        """The tensor constructor interface defined by TACO API.
+
+        Args:
+          value_or_shape: This argument is optional and can be int, float,
+            List[int], or Tuple[int, ...]. If this argument is an int or float,
+            creates a scalar tensor and initializes it with the value. If this
+            argument is a list or tuple of int, uses it as the shape to create a
+            tensor.
+          fmt: This argument can be a ModeFormat, List[ModeFormat], or Format. If
+            this argument is a ModeFormat, uses this ModeFormat for all the tensor
+            dimensions. If this argument is a list of ModeFormat, the length of the
+            list must equal the rank of the tensor. If this argument is a
+            Format, uses it as the format of the tensor.
+          dtype: A DType object representing the data type of the tensor.
+          name: A string name of the tensor. If a name is not given, creates a
+            unique name for the tensor.
+          is_dense: A boolean variable to indicate whether the tensor is a dense
+            tensor without any sparsity annotation.
+
+        Raises:
+          ValueError: If there is any inconsistency among the input arguments.
+        """
+        # Take care of the argument default values common to both sparse tensors
+        # and dense tensors.
+        dtype = dtype or DType(Type.FLOAT32)
+        self._name = name or self._get_unique_name()
+        self._assignment = None
+        self._engine = None
+        self._sparse_value_location = _SparseValueInfo._UNPACKED
+        self._dense_storage = None
+        self._dtype = dtype
+
+        if is_dense:
+            assert fmt is None
+            assert (
+                isinstance(value_or_shape, tuple) or isinstance(value_or_shape, list)
+            ) and _all_instance_of(value_or_shape, int)
+            self._shape = value_or_shape
+            self._format = None
+            return
+
+        fmt = fmt or ModeFormat.COMPRESSED
+        # We currently use _coords and _values to host the sparse tensor value with
+        # COO format, and _dense_storage to host the dense tensor value. We don't
+        # support the conversion between the two storages.
+        self._coords = []
+        self._values = []
+        self._stats = _Stats()
+        if (
+            value_or_shape is None
+            or isinstance(value_or_shape, int)
+            or isinstance(value_or_shape, float)
+            or isinstance(value_or_shape, complex)
+        ):
+            # Create a scalar tensor and ignore the fmt parameter.
+            self._shape = []
+            self._format = _make_format([], [])
+            if value_or_shape is not None:
+                self._dense_storage = np.array(value_or_shape, dtype=self._dtype.value)
+        elif (
+            isinstance(value_or_shape, tuple) or isinstance(value_or_shape, list)
+        ) and _all_instance_of(value_or_shape, int):
+            # Create a tensor with the specified shape and format.
+            self._shape = list(value_or_shape)
+            self._init_format(fmt)
+        else:
+            raise ValueError(
+                "Invalid first argument. "
+                "Must be a tuple or list for a shape or a single value"
+                f"if initializing a scalar tensor: {value_or_shape}."
+            )
+
+    def _set_packed_sparse_tensor(self, pointer: ctypes.c_void_p) -> None:
+        """Records the MLIR sparse tensor pointer."""
+        self._sparse_value_location = _SparseValueInfo._PACKED
+        self._packed_sparse_value = pointer
+
+    def is_unpacked(self) -> bool:
+        """Returns true if the tensor value is not packed as MLIR sparse tensor."""
+        return self._sparse_value_location == _SparseValueInfo._UNPACKED
+
+    def unpack(self) -> None:
+        """Unpacks the MLIR sparse tensor representation."""
+        if self.is_dense() or self.is_unpacked():
+            return
+
+        # Use the output MLIR sparse tensor pointer to retrieve the COO-flavored
+        # values and verify the values.
+        rank, nse, shape, values, indices = utils.sparse_tensor_to_coo_tensor(
+            self._packed_sparse_value, self._dtype.value
+        )
+        assert rank == self.order
+        assert np.array_equal(self.shape, shape)
+        assert nse == len(values)
+        self._coords = indices
+        self._values = values
+        self._sparse_value_location = _SparseValueInfo._UNPACKED
+
+    def __repr__(self) -> str:
+        self._sync_value()
+        self.unpack()
+        value_str = (
+            f"{repr(self._dense_storage)})"
+            if self.is_dense()
+            else f"{repr(self._coords)} {repr(self._values)})"
+        )
+        return (
+            f"Tensor(_name={repr(self._name)} " f"_dtype={repr(self._dtype)} : "
+        ) + value_str
+
+    def insert(self, coords: List[int], val: Union[complex, float, int]) -> None:
+        """Inserts a value to the given coordinate.
+
+        Args:
+          coords: A list of integer coordinates. The length of the list must be the
+            same as the rank of the tensor.
+          val: The value being inserted. It is either an integral, floating point,
+            or complex value. This value will be converted to the data type of the
+            tensor.
+
+        Raises:
+          ValueError: When there is any problem in the parameters.
+        """
+        if self.is_dense():
+            raise ValueError("Insert method is not supported for dense tensors.")
+        if self._assignment is not None or not self.is_unpacked():
+            raise ValueError(
+                "Can't use Insert method for a tensor constructed from a file."
+            )
+        if not isinstance(coords, list):
+            raise ValueError(f"Non list coordinate detected: {coords}.")
+        if not _all_instance_of(coords, int):
+            raise ValueError(f"Non integer coordinate detected: {coords}.")
+        if len(coords) != self.order or any(
+            [c < 0 or c >= self._shape[i] for i, c in enumerate(coords)]
+        ):
+            raise ValueError("Invalid coordinate for rank: " f"{self.order}, {coords}.")
+
+        if not isinstance(val, (int, float, complex)):
+            raise ValueError(f"Value is not an int, float, or complex number: {val}.")
+
+        self._coords.append(tuple(coords))
+        self._values.append(self._dtype.value(val))
+
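+    # A small, purely illustrative sketch of insert() (kept as a comment): it
+    # builds a 2x2 sparse tensor with the constructor's default compressed
+    # format and reads the inserted entries back.
+    #
+    #   t = Tensor([2, 2])
+    #   t.insert([0, 1], 5.0)
+    #   t.insert([1, 0], 2.0)
+    #   coords, values = t.get_coordinates_and_values()
+    #   # coords == [(0, 1), (1, 0)]; values hold 5.0 and 2.0 as float32.
+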
+    def is_dense(self) -> bool:
+        """Returns true if the tensor doesn't have sparsity annotation."""
+        return self.order == 0 or self._format is None
+
+    def to_array(self) -> np.ndarray:
+        """Returns the numpy array for the Tensor.
+
+        This is currently only implemented for dense Tensors.
+        """
+        if not self.is_dense():
+            raise ValueError(
+                "Conversion from non-dense Tensor " "to numpy array not supported yet."
+            )
+
+        self._sync_value()
+
+        return self._dense_storage
+
+    @staticmethod
+    def from_array(array: np.ndarray) -> "Tensor":
+        """Returns a dense tensor with the value copied from the input array.
+
+        We currently only support the conversion of float32 and float64 numpy arrays
+        to Tensor.
+
+        Args:
+          array: The numpy array that provides the data type, shape and value for
+            the tensor.
+
+        Returns:
+          A Tensor object.
+
+        Raises:
+          ValueError: If the data type of the numpy array is not supported.
+        """
+        if array.dtype != np.float32 and array.dtype != np.float64:
+            raise ValueError(f"Expected floating point value type: {array.dtype}.")
+        t = Tensor(
+            array.shape, dtype=_nptype_to_taco_type(array.dtype.type), is_dense=True
+        )
+        t._dense_storage = np.copy(array)
+        return t
+
+    @staticmethod
+    def from_coo(
+        coordinates: List[Tuple[int, ...]],
+        values: List[_AnyRuntimeType],
+        fmt: Format,
+        dtype: DType,
+    ) -> "Tensor":
+        """Converts coordinates and values to a sparse tensor representation.
+
+        Args:
+          coordinates: A list of coordinates with non-zero values.
+          values: The non-zero values.
+          fmt: The tensor storage format.
+          dtype: The tensor element data type.
+
+        Returns:
+          A tensor with the given non-zero values and storage format. The shape of
+          the tensor has the minimum size for each dimension to make the given
+          coordinates valid.
+        """
+        assert isinstance(coordinates, List) and _all_instance_of(coordinates, Tuple)
+        assert isinstance(values, List) and _all_instance_of(values, dtype.value)
+        assert isinstance(fmt, Format)
+
+        rank = fmt.rank()
+        assert all(len(c) == rank and _all_instance_of(c, int) for c in coordinates)
+
+        # Find the maximum coordinate value for each dimension.
+        max_coordinate = list(map(max, zip(*coordinates)))
+        # The size of each dimension is one more than the maximum coordinate
+        # value in that dimension.
+        shape = [c + 1 for c in max_coordinate]
+        t = Tensor(shape, fmt, dtype=dtype)
+        t._coords = coordinates
+        t._values = values
+
+        return t
+
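+    # An illustrative from_coo sketch (comment only). The shape is inferred as
+    # one more than the maximum coordinate in each dimension, so this yields a
+    # 2x3 tensor; `coo_format` stands in for a rank-2 Format built elsewhere
+    # from ModeFormat values, and np.float32 matches DType(Type.FLOAT32).value.
+    #
+    #   coords = [(0, 0), (1, 2)]
+    #   vals = [np.float32(1.0), np.float32(2.0)]
+    #   t = Tensor.from_coo(coords, vals, coo_format, DType(Type.FLOAT32))
+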
+    @staticmethod
+    def from_file(
+        filename: str,
+        fmt: Format,
+        dtype: DType,
+    ) -> "Tensor":
+        """Constructs a sparse tensor using the COO-flavored values from a file.
+
+        Args:
+          filename: A string for the name of the file that contains the sparse
+            tensor data.
+          fmt: The tensor storage format.
+          dtype: The tensor element data type.
+
+        Returns:
+          A tensor with the given non-zero values and storage format. The tensor
+          value is stored as an MLIR sparse tensor.
+        """
+        sparse_tensor, shape = utils.create_sparse_tensor(
+            filename, fmt.format_pack.formats, _dtype_to_mlir_str(dtype)
+        )
+        t = Tensor(shape.tolist(), fmt, dtype=dtype)
+        t._set_packed_sparse_tensor(sparse_tensor)
+
+        return t
+
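+    # A hedged sketch of the file-based path (comment only), assuming `a.tns`
+    # holds data that utils.create_sparse_tensor understands and `fmt` is a
+    # matching rank-2 Format: the value stays packed as an MLIR sparse tensor
+    # until unpack() is called, and to_file() writes it back out.
+    #
+    #   t = Tensor.from_file("a.tns", fmt, DType(Type.FLOAT32))
+    #   t.to_file("a_copy.tns")
+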
+    def to_file(self, filename: str) -> None:
+        """Output the tensor value to a file.
+
+        This method evaluates any pending assignment to the tensor and outputs the
+        tensor value.
+
+        Args:
+          filename: A string file name.
+
+        Raises:
+           ValueError: If the tensor is dense, or an unpacked sparse tensor.
+        """
+        self._sync_value()
+
+        if self.is_dense():
+            raise ValueError(
+                "Writing dense tensors without sparsity annotation to "
+                "file is not supported."
+            )
+
+        if self.is_unpacked():
+            raise ValueError(
+                "Writing unpacked sparse tensors to file is not " "supported."
+            )
+
+        utils.output_sparse_tensor(
+            self._packed_sparse_value,
+            filename,
+            self._format.format_pack.formats,
+            _dtype_to_mlir_str(self._dtype),
+        )
+
+    @property
+    def dtype(self) -> DType:
+        """Returns the data type for the Tensor."""
+        return self._dtype
+
+    @property
+    def format(self) -> Format:
+        """Returns the storage format for the Tensor."""
+        return self._format
+
+    @property
+    def name(self) -> str:
+        """Returns the name for the Tensor."""
+        return self._name
+
+    @property
+    def order(self) -> int:
+        """Returns the rank of the Tensor."""
+        return len(self._shape)
+
+    @property
+    def shape(self) -> List[int]:
+        """Returns the shape of the Tensor."""
+        return self._shape
+
+    def _verify_and_normalize_indices(self, indices) -> Tuple[IndexVar, ...]:
+        """Verifies and normalizes the indices to access the tensor.
+
+        Args:
+          indices: The index expression used to access a tensor, which could be any
+            Python object from user inputs.
+
+        Returns:
+          A tuple of IndexVar.
+
+        Raises:
+          ValueError: If indices is not 0 for scalar tensors, or not an IndexVar or
+            a tuple of IndexVar for other tensors.
+        """
+        if self.order == 0:
+            if not isinstance(indices, int) or indices != 0:
+                raise ValueError(f"Expected 0 to index scalar tensors: {indices}")
+            return ()
+
+        if isinstance(indices, IndexVar):
+            return (indices,)
+        elif isinstance(indices, tuple) and _all_instance_of(indices, IndexVar):
+            return indices
+
+        raise ValueError(f"Expected IndexVars: {indices}")
+
+    def __getitem__(self, key) -> "Access":
+        """Verifies and processes a tensor access.
+
+        In the tensor index notation, a tensor access T[i, j] is represented as
+        retrieving a value with key (i, j) from the tensor object T in Python. This
+        routine verifies the key for the tensor access and returns a tensor access
+        object.
+
+        Args:
+          key: The key used to access the tensor, which could be any Python object
+            from user inputs.
+
+        Returns:
+          The corresponding tensor access object.
+
+        Raises:
+          ValueError: If key is not an IndexVar or a tuple of IndexVar.
+        """
+        indices = self._verify_and_normalize_indices(key)
+        return Access(self, indices)
+
+    def __setitem__(self, key, value) -> None:
+        """Verifies and processes a tensor assignment.
+
+        In the tensor index notation, a tensor assignment "T[i, j] = ..." is
+        represented as setting a value for a tensor object T via key (i, j) in
+        Python. This routine verifies the key, evaluates the value, and assigns the
+        value to the tensor.
+
+        We only support assignment of dense tensor currently.
+
+        Args:
+          key: The key used to access the tensor, which could be any Python object
+            from user inputs.
+          value: The value assigned to the tensor, which could be any Python object
+            from user inputs.
+
+        Raises:
+          ValueError: If tensor is not a dense tensor, or the key is not an IndexVar
+            or a tuple of IndexVar, or the length of the indices is not the same as
+            the rank of the tensor.
+        """
+        indices = self._verify_and_normalize_indices(key)
+        if len(indices) != self.order:
+            raise ValueError(
+                "Mismatch between indices and tensor rank: "
+                f"len({indices}) != {self.order}."
+            )
+
+        self._assignment = _Assignment(indices, value)
+        self._engine = None
+
+    def compile(self, force_recompile: bool = False) -> None:
+        """Compiles the tensor assignment to an execution engine.
+
+        Calling compile the second time does not do anything unless
+        force_recompile is True.
+
+        Args:
+          force_recompile: A boolean value to enable recompilation, such as for the
+            purpose of timing.
+
+        Raises:
+          ValueError: If the assignment is not proper or not supported.
+        """
+        if self._assignment is None or (
+            self._engine is not None and not force_recompile
+        ):
+            return
+
+        self._engine = self._assignment.expression.compile(
+            self, self._assignment.indices
+        )
+
+    def compute(self) -> None:
+        """Executes the engine for the tensor assignment.
+
+        Raises:
+          ValueError: If the assignment hasn't been compiled yet.
+        """
+        if self._assignment is None:
+            return
+
+        if self._engine is None:
+            raise ValueError("Need to invoke compile() before invoking compute().")
+
+        input_accesses = self._assignment.expression.get_input_accesses()
+        # Gather the pointers for the input buffers.
+        input_pointers = [a.tensor.ctype_pointer() for a in input_accesses]
+        if self.is_dense():
+            # The pointer to receive dense output is the first argument to the
+            # execution engine.
+            arg_pointers = [self.dense_dst_ctype_pointer()] + input_pointers
+        else:
+            # The pointer to receive the sparse tensor output is the last argument
+            # to the execution engine and is a pointer to pointer of char.
+            arg_pointers = input_pointers + [
+                ctypes.pointer(ctypes.pointer(ctypes.c_char(0)))
+            ]
+
+        # Invoke the execution engine to run the module.
+        self._engine.invoke(_ENTRY_NAME, *arg_pointers)
+
+        # Retrieve the result.
+        if self.is_dense():
+            result = runtime.ranked_memref_to_numpy(arg_pointers[0][0])
+            assert isinstance(result, np.ndarray)
+            self._dense_storage = result
+        else:
+            self._set_packed_sparse_tensor(arg_pointers[-1][0])
+
+        self._assignment = None
+        self._engine = None
+
+    def evaluate(self) -> None:
+        """Evaluates the tensor assignment."""
+        self.compile()
+        self.compute()
+
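+    # An end-to-end sketch of the assignment workflow (illustrative comment
+    # only). It assumes index variables i, j come from this module's IndexVar
+    # helper (e.g. get_index_vars) and that A and B are 2x2 tensors.
+    #
+    #   C = Tensor([2, 2], is_dense=True)
+    #   C[i, j] = A[i, j] + B[i, j]   # records a pending _Assignment
+    #   C.evaluate()                  # compile() followed by compute()
+    #   result = C.to_array()
+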
+    def _sync_value(self) -> None:
+        """Updates the tensor value by evaluating the pending assignment."""
+        if self._assignment is not None:
+            self.evaluate()
+
+    def mlir_tensor_type(self) -> ir.RankedTensorType:
+        """Returns the MLIR type for the tensor."""
+        mlir_attr = (
+            None
+            if (self._format is None or self.order == 0)
+            else self._format.mlir_tensor_attr()
+        )
+        return _mlir_tensor_type(self._dtype, tuple(self._shape), mlir_attr)
+
+    def dense_dst_ctype_pointer(self) -> ctypes.pointer:
+        """Returns the ctypes pointer for the pointer to an MemRefDescriptor.
+
+        For a dense tensor output, the MLIR compiler allocates the storage for
+        the tensor. This routine returns the pointer to an MLIR MemRefDescriptor for
+        receiving the tensor.
+        """
+        assert self.is_dense()
+        mem_ref_desc = runtime.make_nd_memref_descriptor(
+            self.order, np.ctypeslib.as_ctypes_type(self.dtype.value)
+        )()
+        return ctypes.pointer(ctypes.pointer(mem_ref_desc))
+
+    def ctype_pointer(self) -> ctypes.pointer:
+        """Returns the ctypes pointer for the pointer to the input tensor."""
+        if self.is_dense():
+            if self._dense_storage is None:
+                self._dense_storage = np.zeros(self._shape, self._dtype.value)
+            return _ctype_pointer_from_array(self._dense_storage)
+
+        if self.is_unpacked():
+            shape = np.array(self._shape, np.int64)
+            indices = np.array(self._coords, np.int64)
+            values = np.array(self._values, self._dtype.value)
+            perm, sparse = self.format.get_permutation_and_sparsity()
+            ptr = utils.coo_tensor_to_sparse_tensor(
+                shape, values, indices, perm, sparse
+            )
+        else:
+            ptr = self._packed_sparse_value
+
+        return ctypes.pointer(ctypes.cast(ptr, ctypes.c_void_p))
+
+    def get_scalar_value(self) -> _AnyRuntimeType:
+        """Returns the value for the scalar tensor.
+
+        This method also evaluates the assignment to the tensor.
+
+        Raises:
+          ValueError: If the tensor is not a scalar.
+        """
+        if self.order != 0:
+            raise ValueError(f"Expected a scalar tensor, got: rank={self.order}")
+
+        self._sync_value()
+        return self._dense_storage
+
+    def get_coordinates_and_values(
+        self,
+    ) -> Tuple[List[Tuple[int, ...]], List[_AnyRuntimeType]]:
+        """Returns the coordinates and values for the non-zero elements.
+
+        This method also evaluates the assignment to the tensor and unpacks the
+        sparse tensor.
+        """
+        self._sync_value()
+
+        if not self.is_dense():
+            self.unpack()
+            return (self._coords, self._values)
+
+        if self.order == 0:
+            return ([], self._dense_storage)
+
+        # Coordinates for non-zero elements, grouped by dimensions.
+        coords_by_dims = self._dense_storage.nonzero()
+        # Coordinates for non-zero elements, grouped by elements.
+        coords = np.transpose(coords_by_dims)
+        values = self._dense_storage[coords_by_dims]
+        return (coords, values)
+
+    def _record_stats(self, structop: "_StructOpInfo"):
+        """Collects information for temporary tensors."""
+        # Exclude user-specified destination tensors.
+        if structop.dst_name == self.name:
+            return
+
+        self._stats.add_element(structop)
+
+
+def _emit_operand(
+    op_def: lang.LinalgOpDef,
+    indices: Tuple[IndexVar, ...],
+    name: str,
+    kind: lang.OperandKind,
+) -> lang.OperandDef:
+    """Emits an operand for a tensor access in the current linalg operation.
 
     Args:
-      force_recompile: A boolean value to enable recompilation, such as for the
-        purpose of timing.
-
-    Raises:
-      ValueError: If the assignment is not proper or not supported.
-    """
-    if self._assignment is None or (self._engine is not None and
-                                    not force_recompile):
-      return
-
-    self._engine = self._assignment.expression.compile(self,
-                                                       self._assignment.indices)
+      op_def: A LinalgOpDef representing the current linalg dialect operation.
+      indices: A tuple of IndexVar used to access the tensor.
+      name: A unique string name of the tensor.
+      kind: An OperandKind for the operand.
 
-  def compute(self) -> None:
-    """Executes the engine for the tensor assignment.
-
-    Raises:
-      ValueError: If the assignment hasn't been compiled yet.
-    """
-    if self._assignment is None:
-      return
-
-    if self._engine is None:
-      raise ValueError("Need to invoke compile() before invoking compute().")
-
-    input_accesses = self._assignment.expression.get_input_accesses()
-    # Gather the pointers for the input buffers.
-    input_pointers = [a.tensor.ctype_pointer() for a in input_accesses]
-    if self.is_dense():
-      # The pointer to receive dense output is the first argument to the
-      # execution engine.
-      arg_pointers = [self.dense_dst_ctype_pointer()] + input_pointers
-    else:
-      # The pointer to receive the sparse tensor output is the last argument
-      # to the execution engine and is a pointer to pointer of char.
-      arg_pointers = input_pointers + [
-          ctypes.pointer(ctypes.pointer(ctypes.c_char(0)))
-      ]
-
-    # Invoke the execution engine to run the module.
-    self._engine.invoke(_ENTRY_NAME, *arg_pointers)
-
-    # Retrieve the result.
-    if self.is_dense():
-      result = runtime.ranked_memref_to_numpy(arg_pointers[0][0])
-      assert isinstance(result, np.ndarray)
-      self._dense_storage = result
-    else:
-      self._set_packed_sparse_tensor(arg_pointers[-1][0])
-
-    self._assignment = None
-    self._engine = None
-
-  def evaluate(self) -> None:
-    """Evaluates the tensor assignment."""
-    self.compile()
-    self.compute()
-
-  def _sync_value(self) -> None:
-    """Updates the tensor value by evaluating the pending assignment."""
-    if self._assignment is not None:
-      self.evaluate()
-
-  def mlir_tensor_type(self) -> ir.RankedTensorType:
-    """Returns the MLIR type for the tensor."""
-    mlir_attr = (None if (self._format is None or self.order == 0) else
-                 self._format.mlir_tensor_attr())
-    return _mlir_tensor_type(self._dtype, tuple(self._shape), mlir_attr)
-
-  def dense_dst_ctype_pointer(self) -> ctypes.pointer:
-    """Returns the ctypes pointer for the pointer to an MemRefDescriptor.
-
-    For a dense tensor output, the MLIR compiler allocates the storage for
-    the tensor. This routine returns the pointer to an MLIR MemRefDescriptor for
-    receiving the tensor.
-    """
-    assert self.is_dense()
-    mem_ref_desc = runtime.make_nd_memref_descriptor(
-        self.order, np.ctypeslib.as_ctypes_type(self.dtype.value))()
-    return ctypes.pointer(ctypes.pointer(mem_ref_desc))
-
-  def ctype_pointer(self) -> ctypes.pointer:
-    """Returns the ctypes pointer for the pointer to the input tensor."""
-    if self.is_dense():
-      if self._dense_storage is None:
-        self._dense_storage = np.zeros(self._shape, self._dtype.value)
-      return _ctype_pointer_from_array(self._dense_storage)
-
-    if self.is_unpacked():
-      shape = np.array(self._shape, np.int64)
-      indices = np.array(self._coords, np.int64)
-      values = np.array(self._values, self._dtype.value)
-      perm, sparse = self.format.get_permutation_and_sparsity()
-      ptr = utils.coo_tensor_to_sparse_tensor(shape, values, indices, perm,
-                                              sparse)
-    else:
-      ptr = self._packed_sparse_value
-
-    return ctypes.pointer(ctypes.cast(ptr, ctypes.c_void_p))
-
-  def get_scalar_value(self) -> _AnyRuntimeType:
-    """Returns the value for the scalar tensor.
-
-    This method also evaluates the assignment to the tensor.
-
-    Raises:
-      ValueError: If the tensor is not a scalar.
-    """
-    if self.order != 0:
-      raise ValueError(f"Expected a scalar tensor, got: rank={self.order}")
-
-    self._sync_value()
-    return self._dense_storage
-
-
-  def get_coordinates_and_values(
-      self) -> Tuple[List[Tuple[int, ...]], List[_AnyRuntimeType]]:
-    """Returns the coordinates and values for the non-zero elements.
-
-    This method also evaluates the assignment to the tensor and unpack the
-    sparse tensor.
+    Returns:
+      An OperandDef representing the operand.
     """
-    self._sync_value()
-
-    if not self.is_dense():
-      self.unpack()
-      return (self._coords, self._values)
-
-    if self.order == 0:
-      return ([], self._dense_storage)
-
-    # Coordinates for non-zero elements, grouped by dimensions.
-    coords_by_dims = self._dense_storage.nonzero()
-    # Coordinates for non-zero elements, grouped by elements.
-    coords = np.transpose(coords_by_dims)
-    values = self._dense_storage[coords_by_dims]
-    return (coords, values)
-
-  def _record_stats(self, structop: "_StructOpInfo"):
-    """Collects information for temporary tensors."""
-    # Exclude user specified destination tensors.
-    if structop.dst_name == self.name:
-      return
-
-    self._stats.add_element(structop)
-
-
-def _emit_operand(op_def: lang.LinalgOpDef, indices: Tuple[IndexVar, ...],
-                  name: str, kind: lang.OperandKind) -> lang.OperandDef:
-  """Emits an operand for a tensor access in the current linalg operation.
-
-  Args:
-    op_def: A LinalgOpDef representing the current linalg dialect operation.
-    indices: A tuple of IndexVar used to access the tensor.
-    name: A unique string name of the tensor.
-    kind: An OperandKind for the operand.
-
-  Returns:
-    An OperandDef representing the operand.
-  """
-  dim_sym = _mlir_symbols_from_index_vars(indices)
-  opnd = lang.OperandDef(kind, lang.T, dim_sym)
-  op_def.add_operand(name, opnd)
-  return opnd
+    dim_sym = _mlir_symbols_from_index_vars(indices)
+    opnd = lang.OperandDef(kind, lang.T, dim_sym)
+    op_def.add_operand(name, opnd)
+    return opnd
 
 
 @dataclasses.dataclass(frozen=True)
 class _DimInfo:
-  """Information for an operand dimension.
+    """Information for an operand dimension.
 
-  Attributes:
-    dim: An integer for the size of the dimension.
-    mode_format: A ModeFormat for the dimension sparsity.
-  """
-  dim: int
-  mode_format: ModeFormat
+    Attributes:
+      dim: An integer for the size of the dimension.
+      mode_format: A ModeFormat for the dimension sparsity.
+    """
+
+    dim: int
+    mode_format: ModeFormat
 
 
 def _get_dummy_dim_info() -> _DimInfo:
-  """Constructs the _DimInfo for an index used in tensor expressions."""
-  return _DimInfo(-1, ModeFormat.DENSE)
+    """Constructs the _DimInfo for an index used in tensor expressions."""
+    return _DimInfo(-1, ModeFormat.DENSE)
 
 
 @dataclasses.dataclass()
 class _ExprInfo:
-  """Expression information for validation and code generation.
-
-  Attributes:
-    src_indices: A tuple of IndexVar for the indices used by the tensors in the
-      expression tree.
-    dim_infos: A tuple of _DimInfo, representing the dimension information
-      corresponding to the src_indices.
-    reduce_indices: A set of IndexVar for the indices reduced by the expression.
-    acc_reduce_indices: An accumulated set of IndexVar for the indices reduced
-      by the expression and its children.
-    structop_info: Information to support the code generation for a structured
-      op in the linalg dialect, if the corresponding expression node is the root
-      of a subtree for a structured op.
-    mlir_value: The MLIR value generated for the structured op.
-  """
-  src_indices: Tuple[IndexVar, ...]
-  dim_infos: Tuple[_DimInfo, ...]
-  reduce_indices: Optional[Set[IndexVar]] = None
-  acc_reduce_indices: Optional[Set[IndexVar]] = None
-  structop_info: Optional[_StructOpInfo] = None
-  mlir_value: Optional[ir.Value] = None
-
-  def __post_init__(self) -> None:
-    """Verifies and fix up attribute values.
-
-    Verifies the consistency of the attributes and modifies the default values
-    to support convenient initializer syntax.
+    """Expression information for validation and code generation.
+
+    Attributes:
+      src_indices: A tuple of IndexVar for the indices used by the tensors in the
+        expression tree.
+      dim_infos: A tuple of _DimInfo, representing the dimension information
+        corresponding to the src_indices.
+      reduce_indices: A set of IndexVar for the indices reduced by the expression.
+      acc_reduce_indices: An accumulated set of IndexVar for the indices reduced
+        by the expression and its children.
+      structop_info: Information to support the code generation for a structured
+        op in the linalg dialect, if the corresponding expression node is the root
+        of a subtree for a structured op.
+      mlir_value: The MLIR value generated for the structured op.
     """
-    assert len(self.src_indices) == len(self.dim_infos)
-    self.reduce_indices = self.reduce_indices or set()
-    self.acc_reduce_indices = self.acc_reduce_indices or set()
 
+    src_indices: Tuple[IndexVar, ...]
+    dim_infos: Tuple[_DimInfo, ...]
+    reduce_indices: Optional[Set[IndexVar]] = None
+    acc_reduce_indices: Optional[Set[IndexVar]] = None
+    structop_info: Optional[_StructOpInfo] = None
+    mlir_value: Optional[ir.Value] = None
 
-@dataclasses.dataclass(frozen=True)
-class Access(IndexExpr):
-  """The tensor access class.
+    def __post_init__(self) -> None:
+        """Verifies and fix up attribute values.
+
+        Verifies the consistency of the attributes and modifies the default values
+        to support convenient initializer syntax.
+        """
+        assert len(self.src_indices) == len(self.dim_infos)
+        self.reduce_indices = self.reduce_indices or set()
+        self.acc_reduce_indices = self.acc_reduce_indices or set()
 
-  We support the TACO API access class with an alias of this class.
 
-  Attributes:
-    tensor: A Tensor being accessed.
-    indices: A tuple of IndexVar, representing the indices used to access the
-      Tensor.
-  """
-  tensor: Tensor
-  indices: Tuple[IndexVar, ...]
+@dataclasses.dataclass(frozen=True)
+class Access(IndexExpr):
+    """The tensor access class.
 
-  def __post_init__(self) -> None:
-    """Verifies the tensor and indices for a tensor access.
+    We support the TACO API access class with an alias of this class.
 
-    Raises:
-       ValueError: If indices is not a list of IndexVar or the len of indices
-       doesn't equal to the rank of the tensor.
+    Attributes:
+      tensor: A Tensor being accessed.
+      indices: A tuple of IndexVar, representing the indices used to access the
+        Tensor.
     """
-    if (not isinstance(self.indices, tuple) or
-        not _all_instance_of(self.indices, IndexVar)):
-      raise ValueError(f"Indices contain non IndexVar: {str(self.indices)}.")
-    if self.tensor.order != len(self.indices):
-      raise ValueError("Invalid indices for rank: "
-                       f"str{self.tensor.order} != len({str(self.indices)}).")
-
-  def __repr__(self) -> str:
-    # The Tensor __repr__ method evaluates the pending assignment to the tensor.
-    # We want to define the __repr__ method here to avoid such evaluation of the
-    # tensor assignment.
-    indices_str = ", ".join(map(lambda i: i.name, self.indices))
-    return (f"Tensor({self.tensor.name}) " f"Indices({indices_str})")
-
-  def _emit_expression(
-      self,
-      expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
-      expr_to_info: _ExprInfoDict,
-  ) -> lang.ScalarExpression:
-    """Emits a linalg dialect TensorUse expression for the tensor access."""
-    assert self in expr_to_opnd
-    dims = _mlir_dimensions_from_index_vars(self.indices)
-    return lang.TensorUse(expr_to_opnd[self], dims)
-
-  def _visit(self,
-             func: _ExprVisitor,
-             args,
-             *,
-             leaf_checker: _SubtreeLeafChecker = None) -> None:
-    if leaf_checker:
-      assert leaf_checker(self, *args)
-    func(self, *args)
-
-  def dtype(self) -> DType:
-    return self.tensor.dtype
+
+    tensor: Tensor
+    indices: Tuple[IndexVar, ...]
+
+    def __post_init__(self) -> None:
+        """Verifies the tensor and indices for a tensor access.
+
+        Raises:
+           ValueError: If indices is not a tuple of IndexVar or the length of
+             indices doesn't equal the rank of the tensor.
+        """
+        if not isinstance(self.indices, tuple) or not _all_instance_of(
+            self.indices, IndexVar
+        ):
+            raise ValueError(f"Indices contain non IndexVar: {str(self.indices)}.")
+        if self.tensor.order != len(self.indices):
+            raise ValueError(
+                "Invalid indices for rank: "
+                f"str{self.tensor.order} != len({str(self.indices)})."
+            )
+
+    def __repr__(self) -> str:
+        # The Tensor __repr__ method evaluates the pending assignment to the tensor.
+        # We want to define the __repr__ method here to avoid such evaluation of the
+        # tensor assignment.
+        indices_str = ", ".join(map(lambda i: i.name, self.indices))
+        return f"Tensor({self.tensor.name}) " f"Indices({indices_str})"
+
+    def _emit_expression(
+        self,
+        expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
+        expr_to_info: _ExprInfoDict,
+    ) -> lang.ScalarExpression:
+        """Emits a linalg dialect TensorUse expression for the tensor access."""
+        assert self in expr_to_opnd
+        dims = _mlir_dimensions_from_index_vars(self.indices)
+        return lang.TensorUse(expr_to_opnd[self], dims)
+
+    def _visit(
+        self, func: _ExprVisitor, args, *, leaf_checker: _SubtreeLeafChecker = None
+    ) -> None:
+        if leaf_checker:
+            assert leaf_checker(self, *args)
+        func(self, *args)
+
+    def dtype(self) -> DType:
+        return self.tensor.dtype
 
 
 def _gather_input_accesses_index_vars(
     expr: IndexExpr,
     input_accesses: List[Access],
 ) -> None:
-  """Collects Access nodes."""
-  if isinstance(expr, Access) and expr not in input_accesses:
-    input_accesses.append(expr)
+    """Collects Access nodes."""
+    if isinstance(expr, Access) and expr not in input_accesses:
+        input_accesses.append(expr)
 
 
 def _op_ceil(__a: Any) -> Any:
-  """A _UnaryOp object for operation ceil."""
-  pass
+    """A _UnaryOp object for operation ceil."""
+    pass
 
 
 def _op_floor(__a: Any) -> Any:
-  """A _UnaryOp object for operation floor."""
-  pass
+    """A _UnaryOp object for operation floor."""
+    pass
 
 
 def _op_unary_to_callable(op: _UnaryOp) -> lang.UnaryFnType:
-  """Returns the linalg dialect function object for the given operation."""
-  op_to_callable = {
-      operator.abs: lang.UnaryFn.abs,
-      operator.neg: lang.UnaryFn.negf,
-      _op_ceil: lang.UnaryFn.ceil,
-      _op_floor: lang.UnaryFn.floor,
-  }
-  return op_to_callable[op]
+    """Returns the linalg dialect function object for the given operation."""
+    op_to_callable = {
+        operator.abs: lang.UnaryFn.abs,
+        operator.neg: lang.UnaryFn.negf,
+        _op_ceil: lang.UnaryFn.ceil,
+        _op_floor: lang.UnaryFn.floor,
+    }
+    return op_to_callable[op]
 
 
 @dataclasses.dataclass(frozen=True)
 class _UnaryExpr(IndexExpr):
-  """The representation for a Unary operation.
-
-  Attributes:
-  op: A _UnaryOp representing the operation.
-  a: An IndexExpr representing the operand for the operation.
-  """
-  op: _BinaryOp
-  a: IndexExpr
-
-  def __post_init__(self) -> None:
-    """Verifies that the operand being added is an IndexExpr."""
-    assert isinstance(self.a, IndexExpr)
-
-  def _emit_expression(
-      self,
-      expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
-      expr_to_info: _ExprInfoDict,
-  ) -> lang.ScalarExpression:
-    """Emits the expression tree and returns the expression."""
-    # The current expression node is an internal node of the structured op.
-    if self not in expr_to_opnd:
-      a = self.a._emit_expression(expr_to_opnd, expr_to_info)
-      return _op_unary_to_callable(self.op)(a)
-
-    # The current expression is a leaf node of the structured op. That is, it is
-    # a temporary tensor generated by its child structured op.
-    op_info = expr_to_info[self].structop_info
-    assert op_info is not None
-    dims = _mlir_dimensions_from_index_vars(op_info.dst_indices)
-    return lang.TensorUse(expr_to_opnd[self], dims)
-
-  def _visit(self,
-             func: _ExprVisitor,
-             args,
-             *,
-             leaf_checker: _SubtreeLeafChecker = None) -> None:
-    """A post-order visitor."""
-    if leaf_checker is None or not leaf_checker(self, *args):
-      self.a._visit(func, args, leaf_checker=leaf_checker)
-    func(self, *args)
-
-  def dtype(self) -> DType:
-    """Returns the data type of the operation."""
-    return self.a.dtype()
+    """The representation for a Unary operation.
+
+    Attributes:
+    op: A _UnaryOp representing the operation.
+    a: An IndexExpr representing the operand for the operation.
+    """
+
+    op: _UnaryOp
+    a: IndexExpr
+
+    def __post_init__(self) -> None:
+        """Verifies that the operand being added is an IndexExpr."""
+        assert isinstance(self.a, IndexExpr)
+
+    def _emit_expression(
+        self,
+        expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
+        expr_to_info: _ExprInfoDict,
+    ) -> lang.ScalarExpression:
+        """Emits the expression tree and returns the expression."""
+        # The current expression node is an internal node of the structured op.
+        if self not in expr_to_opnd:
+            a = self.a._emit_expression(expr_to_opnd, expr_to_info)
+            return _op_unary_to_callable(self.op)(a)
+
+        # The current expression is a leaf node of the structured op. That is, it is
+        # a temporary tensor generated by its child structured op.
+        op_info = expr_to_info[self].structop_info
+        assert op_info is not None
+        dims = _mlir_dimensions_from_index_vars(op_info.dst_indices)
+        return lang.TensorUse(expr_to_opnd[self], dims)
+
+    def _visit(
+        self, func: _ExprVisitor, args, *, leaf_checker: _SubtreeLeafChecker = None
+    ) -> None:
+        """A post-order visitor."""
+        if leaf_checker is None or not leaf_checker(self, *args):
+            self.a._visit(func, args, leaf_checker=leaf_checker)
+        func(self, *args)
+
+    def dtype(self) -> DType:
+        """Returns the data type of the operation."""
+        return self.a.dtype()
 
 
 def _op_to_callable(op: _BinaryOp) -> lang.BinaryFnType:
-  """Returns the linalg dialect function object for the given operation."""
-  op_to_callable = {
-      operator.add: lang.BinaryFn.add,
-      operator.sub: lang.BinaryFn.sub,
-      operator.mul: lang.BinaryFn.mul,
-  }
-  return op_to_callable[op]
+    """Returns the linalg dialect function object for the given operation."""
+    op_to_callable = {
+        operator.add: lang.BinaryFn.add,
+        operator.sub: lang.BinaryFn.sub,
+        operator.mul: lang.BinaryFn.mul,
+    }
+    return op_to_callable[op]
+
 
 @dataclasses.dataclass(frozen=True)
 class _BinaryExpr(IndexExpr):
-  """The representation for a binary operation.
-
-  Attributes:
-  op: A _BinaryOp representing the binary operation.
-  a: An IndexExpr representing the first operand of the operation.
-  b: An IndexExpr representing the second operand of the operation.
-  """
-  op: _BinaryOp
-  a: IndexExpr
-  b: IndexExpr
-
-  def __post_init__(self) -> None:
-    """Verifies that the operands being added are IndexExpr."""
-    assert isinstance(self.a, IndexExpr) and isinstance(self.b, IndexExpr)
-
-  def _emit_expression(
-      self,
-      expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
-      expr_to_info: _ExprInfoDict,
-  ) -> lang.ScalarExpression:
-    """Emits the expression tree and returns the expression."""
-    # The current expression node is an internal node of the structured op.
-    if self not in expr_to_opnd:
-      a = self.a._emit_expression(expr_to_opnd, expr_to_info)
-      b = self.b._emit_expression(expr_to_opnd, expr_to_info)
-      return _op_to_callable(self.op)(a, b)
-
-    # The current expression is a leaf node of the structured op. That is, it is
-    # a temporary tensor generated by its child structured op.
-    op_info = expr_to_info[self].structop_info
-    assert op_info is not None
-    dims = _mlir_dimensions_from_index_vars(op_info.dst_indices)
-    return lang.TensorUse(expr_to_opnd[self], dims)
-
-  def _visit(self,
-             func: _ExprVisitor,
-             args,
-             *,
-             leaf_checker: _SubtreeLeafChecker = None) -> None:
-    """A post-order visitor."""
-    if leaf_checker is None or not leaf_checker(self, *args):
-      self.a._visit(func, args, leaf_checker=leaf_checker)
-      self.b._visit(func, args, leaf_checker=leaf_checker)
-    func(self, *args)
-
-  def dtype(self) -> DType:
-    """Returns the data type of the binary operation."""
-    return self.a.dtype()
+    """The representation for a binary operation.
+
+    Attributes:
+    op: A _BinaryOp representing the binary operation.
+    a: An IndexExpr representing the first operand of the operation.
+    b: An IndexExpr representing the second operand of the operation.
+    """
+
+    op: _BinaryOp
+    a: IndexExpr
+    b: IndexExpr
+
+    def __post_init__(self) -> None:
+        """Verifies that the operands being added are IndexExpr."""
+        assert isinstance(self.a, IndexExpr) and isinstance(self.b, IndexExpr)
+
+    def _emit_expression(
+        self,
+        expr_to_opnd: Dict[IndexExpr, lang.OperandDef],
+        expr_to_info: _ExprInfoDict,
+    ) -> lang.ScalarExpression:
+        """Emits the expression tree and returns the expression."""
+        # The current expression node is an internal node of the structured op.
+        if self not in expr_to_opnd:
+            a = self.a._emit_expression(expr_to_opnd, expr_to_info)
+            b = self.b._emit_expression(expr_to_opnd, expr_to_info)
+            return _op_to_callable(self.op)(a, b)
+
+        # The current expression is a leaf node of the structured op. That is, it is
+        # a temporary tensor generated by its child structured op.
+        op_info = expr_to_info[self].structop_info
+        assert op_info is not None
+        dims = _mlir_dimensions_from_index_vars(op_info.dst_indices)
+        return lang.TensorUse(expr_to_opnd[self], dims)
+
+    def _visit(
+        self, func: _ExprVisitor, args, *, leaf_checker: _SubtreeLeafChecker = None
+    ) -> None:
+        """A post-order visitor."""
+        if leaf_checker is None or not leaf_checker(self, *args):
+            self.a._visit(func, args, leaf_checker=leaf_checker)
+            self.b._visit(func, args, leaf_checker=leaf_checker)
+        func(self, *args)
+
+    def dtype(self) -> DType:
+        """Returns the data type of the binary operation."""
+        return self.a.dtype()
 
 
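The _visit methods above implement a post-order traversal of the expression tree that can stop at subtree leaves. As a rough illustration, the stand-alone sketch below (every name in it is hypothetical, not taken from this file) shows the same traversal pattern in isolation:

import dataclasses
from typing import Callable, Optional, Tuple

@dataclasses.dataclass(frozen=True)
class Node:
    """A stand-in for IndexExpr in this illustration."""
    name: str
    children: Tuple["Node", ...] = ()

    def visit(self, func: Callable[["Node"], None],
              leaf_checker: Optional[Callable[["Node"], bool]] = None) -> None:
        # Children are visited first unless this node counts as a leaf,
        # mirroring _UnaryExpr._visit and _BinaryExpr._visit above.
        if leaf_checker is None or not leaf_checker(self):
            for child in self.children:
                child.visit(func, leaf_checker)
        func(self)

order = []
tree = Node("+", (Node("A"), Node("*", (Node("B"), Node("C")))))
tree.visit(lambda n: order.append(n.name))
assert order == ["A", "B", "C", "*", "+"]
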
 def _validate_and_collect_dim_info(
@@ -1822,105 +1901,104 @@ def _validate_and_collect_dim_info(
     dim_infos: Tuple[_DimInfo, ...],
     expr: _BinaryExpr,
 ) -> None:
-  """Validates and collects the dimension information for an index notation.
-
-  Validates (indices, dim_infos) against the information collected from other
-  source operands and is represented by index_to_dim_info. In particular, we
-  ensure that each IndexVar corresponds to only one dimension size. We also
-  aggregate the new information represented in (indices, dim_infos) to
-  index_to_dim_info.
-
-  Args:
-    index_to_dim: A dictionary of (IndexVar, _DimInfo) collected from the
-      previous operands.
-    indices: The IndexVars to be validated.
-    dim_infos: The dimension information for the IndexVars to be validated.
-    expr: The binary expression where (indices, dim_infos) is used.
-
-  Raises:
-    ValueError if there is any problem in the IndexVars or dimensional values.
-  """
-  assert len(indices) == len(dim_infos)
-  for i, d in zip(indices, dim_infos):
-    if i not in index_to_dim_info:
-      index_to_dim_info[i] = d
-    else:
-      dim = index_to_dim_info[i].dim
-      if dim == -1 or d.dim == -1:
-        dim = dim if dim != -1 else d.dim
-      elif dim != d.dim:
-        raise ValueError(f"Inconsistent source dimension for {i}: "
-                         f"{d.dim} vs {dim}")
-      mode_format = _mode_format_estimator(expr.op)(
-          index_to_dim_info[i].mode_format, d.mode_format)
-      index_to_dim_info[i] = _DimInfo(d.dim, mode_format)
+    """Validates and collects the dimension information for an index notation.
+
+    Validates (indices, dim_infos) against the information collected from other
+    source operands, which is represented by index_to_dim_info. In particular, we
+    ensure that each IndexVar corresponds to only one dimension size. We also
+    aggregate the new information represented in (indices, dim_infos) to
+    index_to_dim_info.
+
+    Args:
+      index_to_dim_info: A dictionary of (IndexVar, _DimInfo) collected from the
+        previous operands.
+      indices: The IndexVars to be validated.
+      dim_infos: The dimension information for the IndexVars to be validated.
+      expr: The binary expression where (indices, dim_infos) is used.
+
+    Raises:
+      ValueError if there is any problem in the IndexVars or dimensional values.
+    """
+    assert len(indices) == len(dim_infos)
+    for i, d in zip(indices, dim_infos):
+        if i not in index_to_dim_info:
+            index_to_dim_info[i] = d
+        else:
+            dim = index_to_dim_info[i].dim
+            if dim == -1 or d.dim == -1:
+                dim = dim if dim != -1 else d.dim
+            elif dim != d.dim:
+                raise ValueError(
+                    f"Inconsistent source dimension for {i}: " f"{d.dim} vs {dim}"
+                )
+            mode_format = _mode_format_estimator(expr.op)(
+                index_to_dim_info[i].mode_format, d.mode_format
+            )
+            index_to_dim_info[i] = _DimInfo(d.dim, mode_format)
 
 
 def _validate_and_collect_expr_info(
     expr: IndexExpr,
     expr_to_info: _ExprInfoDict,
 ) -> None:
-  """Validates dimension information and constructs _ExprInfo.
-
-  Validates that dimensional values for the same IndexVar are the same. Collects
-  a list of IndexVar used by the expression and their corresponding dimensional
-  values. Constructs an _ExprInfo object to record the information for the
-  IndexExpr.
-
-  This routine is passed to the post-order visitor as an _ExprVisitor object.
-
-  Args:
-    expr: The IndexExpr being validated.
-    expr_to_info: The dictionary of (IndexExpr, _ExprInfo) for recording the
-      expression information.
-
-  Raises:
-    ValueError if there is any problem in the IndexVars or dimensional values.
-  """
-  # Objects of class Access can be shared by different expressions. Avoid
-  # processing Access objects multiple times by skipping the processing if expr
-  # is already in the dictionary.
-  if expr in expr_to_info:
-    return
-
-  if isinstance(expr, IndexVar):
-    src_indices = expr,  # A tuple with one element.
-    dim_infos = _get_dummy_dim_info(),  # A tuple with one element.
-  elif isinstance(expr, Access):
-    src_indices = expr.indices
-    src_dims = tuple(expr.tensor.shape)
-    if expr.tensor.format is None:
-      # Treat each dimension of a dense tensor as DENSE for the purpose of
-      # calculating temporary tensor storage format.
-      mode_formats = tuple([ModeFormat.DENSE] * len(src_dims))
+    """Validates dimension information and constructs _ExprInfo.
+
+    Validates that dimensional values for the same IndexVar are the same. Collects
+    a list of IndexVar used by the expression and their corresponding dimensional
+    values. Constructs an _ExprInfo object to record the information for the
+    IndexExpr.
+
+    This routine is passed to the post-order visitor as an _ExprVisitor object.
+
+    Args:
+      expr: The IndexExpr being validated.
+      expr_to_info: The dictionary of (IndexExpr, _ExprInfo) for recording the
+        expression information.
+
+    Raises:
+      ValueError if there is any problem in the IndexVars or dimensional values.
+    """
+    # Objects of class Access can be shared by different expressions. Avoid
+    # processing Access objects multiple times by skipping the processing if expr
+    # is already in the dictionary.
+    if expr in expr_to_info:
+        return
+
+    if isinstance(expr, IndexVar):
+        src_indices = (expr,)  # A tuple with one element.
+        dim_infos = (_get_dummy_dim_info(),)  # A tuple with one element.
+    elif isinstance(expr, Access):
+        src_indices = expr.indices
+        src_dims = tuple(expr.tensor.shape)
+        if expr.tensor.format is None:
+            # Treat each dimension of a dense tensor as DENSE for the purpose of
+            # calculating temporary tensor storage format.
+            mode_formats = tuple([ModeFormat.DENSE] * len(src_dims))
+        else:
+            mode_formats = tuple(expr.tensor.format.format_pack.formats)
+        assert len(src_dims) == len(mode_formats)
+        dim_infos = tuple([_DimInfo(d, m) for d, m in zip(src_dims, mode_formats)])
+    elif isinstance(expr, _UnaryExpr):
+        a_info = expr_to_info[expr.a]
+        index_to_dim_info = {i: d for i, d in zip(a_info.src_indices, a_info.dim_infos)}
+        # Here we rely on the fact that dictionaries keep the insertion order for
+        # keys and values.
+        src_indices = tuple(index_to_dim_info.keys())
+        dim_infos = tuple(index_to_dim_info.values())
     else:
-      mode_formats = tuple(expr.tensor.format.format_pack.formats)
-    assert len(src_dims) == len(mode_formats)
-    dim_infos = tuple([_DimInfo(d, m) for d, m in zip(src_dims, mode_formats)])
-  elif isinstance(expr, _UnaryExpr):
-    a_info = expr_to_info[expr.a]
-    index_to_dim_info = {
-        i: d for i, d in zip(a_info.src_indices, a_info.dim_infos)
-    }
-    # Here we rely on the fact that dictionaries keep the insertion order for
-    # keys and values.
-    src_indices = tuple(index_to_dim_info.keys())
-    dim_infos = tuple(index_to_dim_info.values())
-  else:
-    assert isinstance(expr, _BinaryExpr)
-    a_info = expr_to_info[expr.a]
-    index_to_dim_info = {
-        i: d for i, d in zip(a_info.src_indices, a_info.dim_infos)
-    }
-    b_info = expr_to_info[expr.b]
-    _validate_and_collect_dim_info(index_to_dim_info, b_info.src_indices,
-                                   b_info.dim_infos, expr)
-    # Here we rely on the fact that dictionaries keep the insertion order for
-    # keys and values.
-    src_indices = tuple(index_to_dim_info.keys())
-    dim_infos = tuple(index_to_dim_info.values())
+        assert isinstance(expr, _BinaryExpr)
+        a_info = expr_to_info[expr.a]
+        index_to_dim_info = {i: d for i, d in zip(a_info.src_indices, a_info.dim_infos)}
+        b_info = expr_to_info[expr.b]
+        _validate_and_collect_dim_info(
+            index_to_dim_info, b_info.src_indices, b_info.dim_infos, expr
+        )
+        # Here we rely on the fact that dictionaries keep the insertion order for
+        # keys and values.
+        src_indices = tuple(index_to_dim_info.keys())
+        dim_infos = tuple(index_to_dim_info.values())
 
-  expr_to_info[expr] = _ExprInfo(src_indices, dim_infos)
+    expr_to_info[expr] = _ExprInfo(src_indices, dim_infos)
 
 
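The check above enforces that every IndexVar maps to a single dimension size, with -1 acting as an unknown size that can be refined by later operands. A simplified, self-contained sketch of that rule (the function and variables here are illustrative only, not from this file):

from typing import Dict, Sequence

def merge_dims(index_to_dim: Dict[str, int], indices: Sequence[str],
               dims: Sequence[int]) -> None:
    # -1 denotes an unknown dimension size, as in _validate_and_collect_dim_info.
    for i, d in zip(indices, dims):
        known = index_to_dim.get(i, -1)
        if known != -1 and d != -1 and known != d:
            raise ValueError(f"Inconsistent source dimension for {i}: {d} vs {known}")
        index_to_dim[i] = d if known == -1 else known

info: Dict[str, int] = {}
merge_dims(info, ["i", "j"], [8, 16])   # records i -> 8, j -> 16
merge_dims(info, ["j", "k"], [16, -1])  # consistent; k remains unknown
# merge_dims(info, ["i"], [4]) would raise ValueError.
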
 def _mark_structured_op_root(
@@ -1928,90 +2006,92 @@ def _mark_structured_op_root(
     reduce_index: IndexVar,
     expr_to_info: _ExprInfoDict,
 ) -> None:
-  """Identifies the root expression for a structured op in the linalg dialect.
-
-  An linalg structured op can only perform reduction on the whole expression.
-  For a TACO tensor algebra expression, the reduction on an IndexVar is done at
-  the smallest expression that contains all the uses of the IndexVar. If such an
-  expression is only part of the whole expression, we need to split this
-  sub-expression tree out from its parent and implement the sub-expression as a
-  structured op.
-
-  This routine identifies the root expression node for performing a reduction on
-  the given IndexVar. If the reduction of the given IndexVar should be performed
-  on expression X, then the IndexVar is added to expr_to_info[X].reduce_indices
-
-  Args:
-    expr: The root IndexExpr for the tensor algebra expression.
-    reduce_index: The IndexVar which we want to find out the proper expression
-      to perform a reduction.
-    expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
-
-  Raises:
-      ValueError: If the expression is not proper or not supported.
-  """
-  expr_info = expr_to_info[expr]
-  if isinstance(expr, Access):
-    # Handle simple reduction expression in the format of A[i] = B[i, j].
-    if reduce_index in expr_info.src_indices:
-      expr_info.reduce_indices.add(reduce_index)
-    return
-  elif isinstance(expr, IndexVar):
-    # A[i] = B[i] + j is not allowed.
-    raise ValueError(f"IndexVar is not part of the iteration domain: {expr}.")
-
-  assert (isinstance(expr, _BinaryExpr))
-  a_info = expr_to_info[expr.a]
-  b_info = expr_to_info[expr.b]
-
-  if reduce_index in a_info.src_indices and reduce_index in b_info.src_indices:
-    expr_info.reduce_indices.add(reduce_index)
-    return
-
-  if reduce_index in a_info.src_indices:
-    _mark_structured_op_root(expr.a, reduce_index, expr_to_info)
-  elif reduce_index in b_info.src_indices:
-    _mark_structured_op_root(expr.b, reduce_index, expr_to_info)
-  else:
-    assert False, "Unreachable path"
+    """Identifies the root expression for a structured op in the linalg dialect.
 
+    A linalg structured op can only perform a reduction on the whole expression.
+    For a TACO tensor algebra expression, the reduction on an IndexVar is done at
+    the smallest expression that contains all the uses of the IndexVar. If such an
+    expression is only part of the whole expression, we need to split this
+    sub-expression tree out from its parent and implement the sub-expression as a
+    structured op.
 
-def _accumulate_reduce_indices(
-    expr: IndexExpr,
-    expr_to_info: _ExprInfoDict,
-) -> None:
-  """Propagates reduction indices from child expressions to parent expressions.
+    This routine identifies the root expression node for performing a reduction on
+    the given IndexVar. If the reduction of the given IndexVar should be performed
+    on expression X, then the IndexVar is added to expr_to_info[X].reduce_indices.
 
-  This routine is passed to the post-order visitor as an _ExprVisitor object.
+    Args:
+      expr: The root IndexExpr for the tensor algebra expression.
+      reduce_index: The IndexVar for which we want to find the proper expression
+        on which to perform the reduction.
+      expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
 
-  Args:
-    expr: The IndexExpr being visited.
-    expr_to_info: The dictionary of (IndexExpr, _ExprInfo) for recording the
-      expression information.
-  """
-  assert expr in expr_to_info
-  expr_info = expr_to_info[expr]
+    Raises:
+        ValueError: If the expression is not proper or not supported.
+    """
+    expr_info = expr_to_info[expr]
+    if isinstance(expr, Access):
+        # Handle simple reduction expression in the format of A[i] = B[i, j].
+        if reduce_index in expr_info.src_indices:
+            expr_info.reduce_indices.add(reduce_index)
+        return
+    elif isinstance(expr, IndexVar):
+        # A[i] = B[i] + j is not allowed.
+        raise ValueError(f"IndexVar is not part of the iteration domain: {expr}.")
 
-  if isinstance(expr, _BinaryExpr):
+    assert isinstance(expr, _BinaryExpr)
     a_info = expr_to_info[expr.a]
     b_info = expr_to_info[expr.b]
-    expr_info.acc_reduce_indices = (
-        a_info.acc_reduce_indices | b_info.acc_reduce_indices
-        | expr_info.reduce_indices)
-  elif isinstance(expr, _UnaryExpr):
-    a_info = expr_to_info[expr.a]
-    expr_info.acc_reduce_indices = (
-        a_info.acc_reduce_indices | expr_info.reduce_indices)
-  elif isinstance(expr, IndexVar):
-    # If an IndexVar is reducing itself, it means the IndexVar is outside the
-    # iteration domain. This usage is now allowed and we should emit an error
-    # before reaching here.
-    assert not expr_info.reduce_indices
-  else:
-    assert isinstance(expr, Access)
-    # Handle simple reduction expression in the format of A[i] = B[i, j].
-    expr_info.acc_reduce_indices = expr_info.reduce_indices
 
+    if reduce_index in a_info.src_indices and reduce_index in b_info.src_indices:
+        expr_info.reduce_indices.add(reduce_index)
+        return
+
+    if reduce_index in a_info.src_indices:
+        _mark_structured_op_root(expr.a, reduce_index, expr_to_info)
+    elif reduce_index in b_info.src_indices:
+        _mark_structured_op_root(expr.b, reduce_index, expr_to_info)
+    else:
+        assert False, "Unreachable path"
+
+
+def _accumulate_reduce_indices(
+    expr: IndexExpr,
+    expr_to_info: _ExprInfoDict,
+) -> None:
+    """Propagates reduction indices from child expressions to parent expressions.
+
+    This routine is passed to the post-order visitor as an _ExprVisitor object.
+
+    Args:
+      expr: The IndexExpr being visited.
+      expr_to_info: The dictionary of (IndexExpr, _ExprInfo) for recording the
+        expression information.
+    """
+    assert expr in expr_to_info
+    expr_info = expr_to_info[expr]
+
+    if isinstance(expr, _BinaryExpr):
+        a_info = expr_to_info[expr.a]
+        b_info = expr_to_info[expr.b]
+        expr_info.acc_reduce_indices = (
+            a_info.acc_reduce_indices
+            | b_info.acc_reduce_indices
+            | expr_info.reduce_indices
+        )
+    elif isinstance(expr, _UnaryExpr):
+        a_info = expr_to_info[expr.a]
+        expr_info.acc_reduce_indices = (
+            a_info.acc_reduce_indices | expr_info.reduce_indices
+        )
+    elif isinstance(expr, IndexVar):
+        # If an IndexVar is reducing itself, it means the IndexVar is outside the
+        # iteration domain. This usage is not allowed and we should emit an error
+        # before reaching here.
+        assert not expr_info.reduce_indices
+    else:
+        assert isinstance(expr, Access)
+        # Handle simple reduction expression in the format of A[i] = B[i, j].
+        expr_info.acc_reduce_indices = expr_info.reduce_indices
 
 
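The accumulation above is a bottom-up union: a node's acc_reduce_indices combines its own reduce_indices with those accumulated by its children. A minimal stand-alone sketch of that rule (the class and values are illustrative, not from this file):

import dataclasses
from typing import Set, Tuple

@dataclasses.dataclass
class ExprNode:
    reduce_indices: Set[str]
    children: Tuple["ExprNode", ...] = ()
    acc_reduce_indices: Set[str] = dataclasses.field(default_factory=set)

def accumulate(node: ExprNode) -> Set[str]:
    # Post-order: accumulate the children first, then add this node's own indices.
    acc = set(node.reduce_indices)
    for child in node.children:
        acc |= accumulate(child)
    node.acc_reduce_indices = acc
    return acc

leaf = ExprNode(reduce_indices={"k"})
root = ExprNode(reduce_indices={"j"}, children=(leaf, ExprNode(reduce_indices=set())))
assert accumulate(root) == {"j", "k"}
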
 def _gather_structured_op(
@@ -2019,42 +2099,42 @@ def _gather_structured_op(
     expr_to_info: _ExprInfoDict,
     structop_roots: List[IndexExpr],
 ) -> None:
-  """Adds structured op root expression information to structop_roots.
-
-  This routine is passed to the post-order visitor as an _ExprVisitor object.
-
-  Args:
-    expr: The IndexExpr being visited.
-    expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
-    structop_roots: The resulting list of IndexExpr that are the roots for
-      linalg structured ops.
-  """
-  if not expr_to_info[expr].reduce_indices:
-    return
-
-  # If the expression is the root for reducing some indices, collect the indices
-  # and dimensions for the reduction result.
-  dst_indices = []
-  dst_dims = []
-  mode_fmts = []
-  for i, d in zip(expr_to_info[expr].src_indices, expr_to_info[expr].dim_infos):
-    if i not in expr_to_info[expr].acc_reduce_indices:
-      dst_indices.append(i)
-      dst_dims.append(d.dim)
-      mode_fmts.append(d.mode_format)
-
-  # Add the information to the dictionary.
-  op_info = _StructOpInfo(
-      tuple(dst_indices),
-      tuple(dst_dims),
-      expr.dtype(),
-      f"temp{len(structop_roots)}",
-      _make_format(mode_fmts),
-  )
-  expr_to_info[expr].structop_info = op_info
-
-  # Add the expression to the list of structured op roots.
-  structop_roots.append(expr)
+    """Adds structured op root expression information to structop_roots.
+
+    This routine is passed to the post-order visitor as an _ExprVisitor object.
+
+    Args:
+      expr: The IndexExpr being visited.
+      expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
+      structop_roots: The resulting list of IndexExpr that are the roots for
+        linalg structured ops.
+    """
+    if not expr_to_info[expr].reduce_indices:
+        return
+
+    # If the expression is the root for reducing some indices, collect the indices
+    # and dimensions for the reduction result.
+    dst_indices = []
+    dst_dims = []
+    mode_fmts = []
+    for i, d in zip(expr_to_info[expr].src_indices, expr_to_info[expr].dim_infos):
+        if i not in expr_to_info[expr].acc_reduce_indices:
+            dst_indices.append(i)
+            dst_dims.append(d.dim)
+            mode_fmts.append(d.mode_format)
+
+    # Add the information to the dictionary.
+    op_info = _StructOpInfo(
+        tuple(dst_indices),
+        tuple(dst_dims),
+        expr.dtype(),
+        f"temp{len(structop_roots)}",
+        _make_format(mode_fmts),
+    )
+    expr_to_info[expr].structop_info = op_info
+
+    # Add the expression to the list of structured op roots.
+    structop_roots.append(expr)
 
 
 def _is_structured_op_leaf(
@@ -2063,29 +2143,31 @@ def _is_structured_op_leaf(
     expr_to_info: _ExprInfoDict,
     *unused_args,
 ) -> bool:
-  """Returns true iff the expression is a leaf node for a structured op.
+    """Returns true iff the expression is a leaf node for a structured op.
 
-  The root of a structured op is a leaf of its parent structured op that uses
-  its result. An expression node is a leaf node for the current structured op if
-  it is an Access node or the root for a structured op that is not the current
-  structured op.
+    The root of a structured op is a leaf of its parent structured op that uses
+    its result. An expression node is a leaf node for the current structured op if
+    it is an Access node or the root for a structured op that is not the current
+    structured op.
 
-  This routine is passed to the post-order visitor as a _SubtreeLeafChecker
-  object. Because the post-order visitor pass the same parameters to both
-  _SubtreeLeafChecker and _ExprVisitor, this routine may received unused
-  parameters.
+    This routine is passed to the post-order visitor as a _SubtreeLeafChecker
+    object. Because the post-order visitor passes the same parameters to both
+    _SubtreeLeafChecker and _ExprVisitor, this routine may receive unused
+    parameters.
 
-  Args:
-    expr: The IndexExpr being visited.
-    root: The root of the current structured op.
-    expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
+    Args:
+      expr: The IndexExpr being visited.
+      root: The root of the current structured op.
+      expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
 
-  Returns:
-    True if the current IndexExpr is a leaf for the current structured op.
-  """
-  return (expr != root and
-          expr_to_info[expr].structop_info is not None) or isinstance(
-              expr, Access) or isinstance(expr, IndexVar)
+    Returns:
+      True if the current IndexExpr is a leaf for the current structured op.
+    """
+    return (
+        (expr != root and expr_to_info[expr].structop_info is not None)
+        or isinstance(expr, Access)
+        or isinstance(expr, IndexVar)
+    )
 
 
 def _gather_structured_op_input(
@@ -2094,26 +2176,28 @@ def _gather_structured_op_input(
     expr_to_info: _ExprInfoDict,
     structop_inputs: List[IndexExpr],
 ) -> None:
-  """Adds the IndexExpr to structop_inputs if it is an input.
+    """Adds the IndexExpr to structop_inputs if it is an input.
 
-  If the current IndexExpr is an input for the current structured op, adds it to
-  structop_inputs. The current IndexExpr is an input if it is an Access node or
-  if it is the root for a structured op that is not the current structured op.
+    If the current IndexExpr is an input for the current structured op, adds it to
+    structop_inputs. The current IndexExpr is an input if it is an Access node or
+    if it is the root for a structured op that is not the current structured op.
 
-  This routine is passed to the post-order visitor as an _ExprVisitor object.
+    This routine is passed to the post-order visitor as an _ExprVisitor object.
 
-  Args:
-    expr: The IndexExpr being visited.
-    root: The root of the current structured op.
-    expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
-    structop_inputs: The resulting list of IndexExpr that provide input to the
-      current structured op.
-  """
-  if ((expr != root or isinstance(expr, Access)) and
-      expr not in structop_inputs) and (isinstance(expr, Access) or
-                                        (expr in expr_to_info and
-                                         expr_to_info[expr].structop_info)):
-    structop_inputs.append(expr)
+    Args:
+      expr: The IndexExpr being visited.
+      root: The root of the current structured op.
+      expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
+      structop_inputs: The resulting list of IndexExpr that provide input to the
+        current structured op.
+    """
+    if (
+        (expr != root or isinstance(expr, Access)) and expr not in structop_inputs
+    ) and (
+        isinstance(expr, Access)
+        or (expr in expr_to_info and expr_to_info[expr].structop_info)
+    ):
+        structop_inputs.append(expr)
 
 
 def _emit_structured_op_input(
@@ -2121,35 +2205,35 @@ def _emit_structured_op_input(
     expr_to_info: _ExprInfoDict,
     op_def: lang.LinalgOpDef,
 ) -> lang.OperandDef:
-  """Emits OperandDef in the linalg dialect for the input IndexExpr.
-
-  Args:
-    expr: The input IndexExpr for the current structured op.
-    expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
-    op_def: The linalg operation for the current structured op.
-
-  Returns:
-    An OperandDef in the linalg dialect for the input IndexExpr.
-  """
-  op_info = expr_to_info[expr].structop_info
-  if op_info and not isinstance(expr, Access):
-    # The input is a temporary tensor produced by another structured op.
-    indices = op_info.dst_indices
-    name = op_info.dst_name
-  else:
-    # The input is a user provided tensor.
-    assert isinstance(expr, Access)
-    indices = expr.indices
-    name = expr.tensor.name
-
-  dim_sym = _mlir_symbols_from_index_vars(indices)
-  opnd = lang.OperandDef(lang.OperandKind.INPUT_TENSOR, lang.T, dim_sym)
-  op_def.add_operand(name, opnd)
-  return opnd
+    """Emits OperandDef in the linalg dialect for the input IndexExpr.
+
+    Args:
+      expr: The input IndexExpr for the current structured op.
+      expr_to_info: The dictionary to look up _ExprInfo for IndexExpr.
+      op_def: The linalg operation for the current structured op.
+
+    Returns:
+      An OperandDef in the linalg dialect for the input IndexExpr.
+    """
+    op_info = expr_to_info[expr].structop_info
+    if op_info and not isinstance(expr, Access):
+        # The input is a temporary tensor produced by another structured op.
+        indices = op_info.dst_indices
+        name = op_info.dst_name
+    else:
+        # The input is a user provided tensor.
+        assert isinstance(expr, Access)
+        indices = expr.indices
+        name = expr.tensor.name
+
+    dim_sym = _mlir_symbols_from_index_vars(indices)
+    opnd = lang.OperandDef(lang.OperandKind.INPUT_TENSOR, lang.T, dim_sym)
+    op_def.add_operand(name, opnd)
+    return opnd
 
 
 def _check_and_build_unary(a: Access, op: _UnaryOp) -> "_UnaryExpr":
-  """Build a unary operation ceil.
+    """Build a unary operation ceil.
 
     Args:
       a: The operand, which could be any Python object from user inputs.
@@ -2161,13 +2245,13 @@ def _check_and_build_unary(a: Access, op: _UnaryOp) -> "_UnaryExpr":
     Raises:
       ValueError: If a is not an IndexExpr.
     """
-  if not isinstance(a, Access):
-    raise ValueError(f"Expected an Access Operand: {a}")
-  return a._build_unary_expr(op)
+    if not isinstance(a, Access):
+        raise ValueError(f"Expected an Access Operand: {a}")
+    return a._build_unary_expr(op)
 
 
 def ceil(a: Access) -> "_UnaryExpr":
-  """Defines the operation ceil.
+    """Defines the operation ceil.
 
     Args:
       a: The operand, which could be any Python object from user inputs.
@@ -2178,11 +2262,11 @@ def ceil(a: Access) -> "_UnaryExpr":
     Raises:
       ValueError: If a is not an IndexExpr.
     """
-  return _check_and_build_unary(a, _op_ceil)
+    return _check_and_build_unary(a, _op_ceil)
 
 
 def floor(a: Access) -> "_UnaryExpr":
-  """Defines the operation floor.
+    """Defines the operation floor.
 
     Args:
       a: The operand, which could be any Python object from user inputs.
@@ -2193,4 +2277,4 @@ def floor(a: Access) -> "_UnaryExpr":
     Raises:
       ValueError: If a is not an IndexExpr.
     """
-  return _check_and_build_unary(a, _op_floor)
+    return _check_and_build_unary(a, _op_floor)
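
For context, a hedged usage sketch of the ceil and floor helpers defined in this file. The tensor and index-variable construction is assumed to follow this package's PyTACO-style front end and is not part of this diff; only ceil/floor and the error message shown are taken from the code above:

# Hypothetical usage; the construction helpers below are assumptions.
# i, j = get_index_vars(2)        # assumed helper returning IndexVar objects
# A = tensor([8, 8])              # assumed dense input tensor
# B = tensor([8, 8], fmt)         # assumed sparse output tensor
# B[i, j] = ceil(A[i, j])         # A[i, j] is an Access; ceil returns a _UnaryExpr
try:
    floor(3.7)                    # a plain float is not an Access
except ValueError as e:
    print(e)                      # -> Expected an Access Operand: 3.7
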
index e6a7d8e..785401c 100644 (file)
@@ -31,50 +31,52 @@ _MTX_FILENAME_SUFFIX = ".mtx"
 _TNS_FILENAME_SUFFIX = ".tns"
 
 
-def read(filename: str, fmt: Format,
-         dtype: DType = DType(Type.FLOAT32)) -> Tensor:
-  """Inputs a tensor from a given file.
-
-  The name suffix of the file specifies the format of the input tensor. We
-  currently only support .mtx format for support sparse tensors.
-
-  Args:
-    filename: A string input filename.
-    fmt: The storage format of the tensor.
-    dtype: The data type, default to float32.
-
-  Raises:
-    ValueError: If filename doesn't end with .mtx or .tns, or fmt is not an
-    instance of Format or fmt is not a sparse tensor.
-  """
-  if (not isinstance(filename, str) or
-      (not filename.endswith(_MTX_FILENAME_SUFFIX) and
-       not filename.endswith(_TNS_FILENAME_SUFFIX))):
-    raise ValueError("Expected string filename ends with "
-                     f"{_MTX_FILENAME_SUFFIX} or {_TNS_FILENAME_SUFFIX}: "
-                     f"{filename}.")
-
-  return Tensor.from_file(filename, fmt, dtype)
+def read(filename: str, fmt: Format, dtype: DType = DType(Type.FLOAT32)) -> Tensor:
+    """Inputs a tensor from a given file.
+
+    The name suffix of the file specifies the format of the input tensor. We
+    currently only support the .mtx and .tns formats for sparse tensors.
+
+    Args:
+      filename: A string input filename.
+      fmt: The storage format of the tensor.
+      dtype: The data type, default to float32.
+
+    Raises:
+      ValueError: If filename doesn't end with .mtx or .tns, or fmt is not an
+      instance of Format, or fmt does not describe a sparse tensor.
+    """
+    if not isinstance(filename, str) or (
+        not filename.endswith(_MTX_FILENAME_SUFFIX)
+        and not filename.endswith(_TNS_FILENAME_SUFFIX)
+    ):
+        raise ValueError(
+            "Expected string filename ends with "
+            f"{_MTX_FILENAME_SUFFIX} or {_TNS_FILENAME_SUFFIX}: "
+            f"{filename}."
+        )
+
+    return Tensor.from_file(filename, fmt, dtype)
 
 
 def write(filename: str, tensor: Tensor) -> None:
-  """Outputs a tensor to a given file.
-
-  The name suffix of the file specifies the format of the output. We currently
-  only support .tns format.
-
-  Args:
-    filename: A string output filename.
-    tensor: The tensor to output.
-
-  Raises:
-    ValueError: If filename doesn't end with .tns or tensor is not a Tensor.
-  """
-  if (not isinstance(filename, str) or
-      not filename.endswith(_TNS_FILENAME_SUFFIX)):
-    raise ValueError("Expected string filename ends with"
-                     f" {_TNS_FILENAME_SUFFIX}: {filename}.")
-  if not isinstance(tensor, Tensor):
-    raise ValueError(f"Expected a Tensor object: {tensor}.")
-
-  tensor.to_file(filename)
+    """Outputs a tensor to a given file.
+
+    The name suffix of the file specifies the format of the output. We currently
+    only support the .tns format.
+
+    Args:
+      filename: A string output filename.
+      tensor: The tensor to output.
+
+    Raises:
+      ValueError: If filename doesn't end with .tns or tensor is not a Tensor.
+    """
+    if not isinstance(filename, str) or not filename.endswith(_TNS_FILENAME_SUFFIX):
+        raise ValueError(
+            "Expected string filename ends with" f" {_TNS_FILENAME_SUFFIX}: {filename}."
+        )
+    if not isinstance(tensor, Tensor):
+        raise ValueError(f"Expected a Tensor object: {tensor}.")
+
+    tensor.to_file(filename)
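
A short usage sketch for read and write, assuming a sparse Format object fmt has already been built elsewhere (Format construction is outside this file) and that the listed input files exist:

# Hypothetical usage; the file names, `fmt`, and Type.FLOAT64 are assumptions.
# t = read("matrix.mtx", fmt)                        # dtype defaults to DType(Type.FLOAT32)
# t64 = read("tensor.tns", fmt, DType(Type.FLOAT64))
# write("out.tns", t)                                # only .tns output is supported
try:
    write("out.csv", None)                           # a non-.tns suffix is rejected first
except ValueError as e:
    print(e)
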
index 988c57b..1e1061b 100644 (file)
@@ -36,190 +36,234 @@ _ENTRY_NAME = "main"
 
 @functools.lru_cache()
 def _get_support_lib_name() -> str:
-  """Gets the string name for the supporting C shared library."""
-  return os.getenv(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB)
+    """Gets the string name for the supporting C shared library."""
+    return os.getenv(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB)
 
 
 @functools.lru_cache()
 def _get_sparse_compiler() -> mlir_sparse_compiler.SparseCompiler:
-  """Gets the MLIR sparse compiler with default setting."""
-  return mlir_sparse_compiler.SparseCompiler(
-      options="", opt_level=_OPT_LEVEL, shared_libs=[_get_support_lib_name()])
+    """Gets the MLIR sparse compiler with default setting."""
+    return mlir_sparse_compiler.SparseCompiler(
+        options="", opt_level=_OPT_LEVEL, shared_libs=[_get_support_lib_name()]
+    )
 
 
 def _record_support_funcs(
-    ty: np.dtype, to_func: _SupportFunc, from_func: _SupportFunc,
-    ty_to_funcs: Dict[np.dtype, Tuple[_SupportFunc, _SupportFunc]]) -> None:
-  """Records the two supporting functions for a given data type."""
-  to_func.restype = ctypes.c_void_p
-  from_func.restype = ctypes.c_void_p
-  ty_to_funcs[ty] = (to_func, from_func)
+    ty: np.dtype,
+    to_func: _SupportFunc,
+    from_func: _SupportFunc,
+    ty_to_funcs: Dict[np.dtype, Tuple[_SupportFunc, _SupportFunc]],
+) -> None:
+    """Records the two supporting functions for a given data type."""
+    to_func.restype = ctypes.c_void_p
+    from_func.restype = ctypes.c_void_p
+    ty_to_funcs[ty] = (to_func, from_func)
 
 
 @functools.lru_cache()
 def _get_support_func_locator() -> _SupportFuncLocator:
-  """Constructs a function to locate the supporting functions for a data type.
-
-  Loads the supporting C shared library with the needed routines. Constructs a
-  dictionary from the supported data types to the routines for the data types,
-  and then a function to look up the dictionary for a given data type.
-
-  The name of the supporting C shared library is either provided by an
-  an environment variable or a default value.
-
-  Returns:
-    The function to look up the supporting functions for a given data type.
-
-  Raises:
-    OSError: If there is any problem in loading the shared library.
-    ValueError: If the shared library doesn't contain the needed routines.
-  """
-  # This raises OSError exception if there is any problem in loading the shared
-  # library.
-  c_lib = ctypes.CDLL(_get_support_lib_name())
-
-  type_to_funcs = {}
-  try:
-    support_types = [(np.int8, c_lib.convertToMLIRSparseTensorI8,
-                      c_lib.convertFromMLIRSparseTensorI8),
-                     (np.int16, c_lib.convertToMLIRSparseTensorI16,
-                      c_lib.convertFromMLIRSparseTensorI16),
-                     (np.int32, c_lib.convertToMLIRSparseTensorI32,
-                      c_lib.convertFromMLIRSparseTensorI32),
-                     (np.int64, c_lib.convertToMLIRSparseTensorI64,
-                      c_lib.convertFromMLIRSparseTensorI64),
-                     (np.float16, c_lib.convertToMLIRSparseTensorF16,
-                      c_lib.convertFromMLIRSparseTensorF16),
-                     (np.float32, c_lib.convertToMLIRSparseTensorF32,
-                      c_lib.convertFromMLIRSparseTensorF32),
-                     (np.float64, c_lib.convertToMLIRSparseTensorF64,
-                      c_lib.convertFromMLIRSparseTensorF64),
-                     (np.complex64, c_lib.convertToMLIRSparseTensorC32,
-                      c_lib.convertFromMLIRSparseTensorC32),
-                     (np.complex128, c_lib.convertToMLIRSparseTensorC64,
-                      c_lib.convertFromMLIRSparseTensorC64)]
-  except Exception as e:
-    raise ValueError(f"Missing supporting function: {e}") from e
-  for i, info in enumerate(support_types):
-    _record_support_funcs(info[0], info[1], info[2], type_to_funcs)
-
-  def get_support_funcs(ty: np.dtype):
-    funcs = type_to_funcs[ty]
-    assert funcs is not None
-    return funcs
-
-  return get_support_funcs
+    """Constructs a function to locate the supporting functions for a data type.
+
+    Loads the supporting C shared library with the needed routines. Constructs a
+    dictionary from the supported data types to the routines for the data types,
+    and then a function to look up the dictionary for a given data type.
+
+    The name of the supporting C shared library is either provided by an
+    environment variable or given by a default value.
+
+    Returns:
+      The function to look up the supporting functions for a given data type.
+
+    Raises:
+      OSError: If there is any problem in loading the shared library.
+      ValueError: If the shared library doesn't contain the needed routines.
+    """
+    # This raises OSError exception if there is any problem in loading the shared
+    # library.
+    c_lib = ctypes.CDLL(_get_support_lib_name())
+
+    type_to_funcs = {}
+    try:
+        support_types = [
+            (
+                np.int8,
+                c_lib.convertToMLIRSparseTensorI8,
+                c_lib.convertFromMLIRSparseTensorI8,
+            ),
+            (
+                np.int16,
+                c_lib.convertToMLIRSparseTensorI16,
+                c_lib.convertFromMLIRSparseTensorI16,
+            ),
+            (
+                np.int32,
+                c_lib.convertToMLIRSparseTensorI32,
+                c_lib.convertFromMLIRSparseTensorI32,
+            ),
+            (
+                np.int64,
+                c_lib.convertToMLIRSparseTensorI64,
+                c_lib.convertFromMLIRSparseTensorI64,
+            ),
+            (
+                np.float16,
+                c_lib.convertToMLIRSparseTensorF16,
+                c_lib.convertFromMLIRSparseTensorF16,
+            ),
+            (
+                np.float32,
+                c_lib.convertToMLIRSparseTensorF32,
+                c_lib.convertFromMLIRSparseTensorF32,
+            ),
+            (
+                np.float64,
+                c_lib.convertToMLIRSparseTensorF64,
+                c_lib.convertFromMLIRSparseTensorF64,
+            ),
+            (
+                np.complex64,
+                c_lib.convertToMLIRSparseTensorC32,
+                c_lib.convertFromMLIRSparseTensorC32,
+            ),
+            (
+                np.complex128,
+                c_lib.convertToMLIRSparseTensorC64,
+                c_lib.convertFromMLIRSparseTensorC64,
+            ),
+        ]
+    except Exception as e:
+        raise ValueError(f"Missing supporting function: {e}") from e
+    for i, info in enumerate(support_types):
+        _record_support_funcs(info[0], info[1], info[2], type_to_funcs)
+
+    def get_support_funcs(ty: np.dtype):
+        funcs = type_to_funcs[ty]
+        assert funcs is not None
+        return funcs
+
+    return get_support_funcs
 
 
 def sparse_tensor_to_coo_tensor(
     sparse_tensor: ctypes.c_void_p,
     dtype: np.dtype,
 ) -> Tuple[int, int, np.ndarray, np.ndarray, np.ndarray]:
-  """Converts an MLIR sparse tensor to a COO-flavored format tensor.
-
-  Args:
-     sparse_tensor: A ctypes.c_void_p to the MLIR sparse tensor descriptor.
-     dtype: The numpy data type for the tensor elements.
-
-  Returns:
-    A tuple that contains the following values for the COO-flavored format
-    tensor:
-    rank: An integer for the rank of the tensor.
-    nse: An integer for the number of non-zero values in the tensor.
-    shape: A 1D numpy array of integers, for the shape of the tensor.
-    values: A 1D numpy array, for the non-zero values in the tensor.
-    indices: A 2D numpy array of integers, representing the indices for the
-      non-zero values in the tensor.
-
-  Raises:
-    OSError: If there is any problem in loading the shared library.
-    ValueError: If the shared library doesn't contain the needed routines.
-  """
-  convert_from = _get_support_func_locator()(dtype)[1]
-  rank = ctypes.c_ulonglong(0)
-  nse = ctypes.c_ulonglong(0)
-  shape = ctypes.POINTER(ctypes.c_ulonglong)()
-
-  values = ctypes.POINTER(runtime.as_ctype(np.dtype(dtype)))()
-  indices = ctypes.POINTER(ctypes.c_ulonglong)()
-  convert_from(sparse_tensor, ctypes.byref(rank), ctypes.byref(nse),
-               ctypes.byref(shape), ctypes.byref(values), ctypes.byref(indices))
-
-  # Convert the returned values to the corresponding numpy types.
-  shape = np.ctypeslib.as_array(shape, shape=[rank.value])
-  values = runtime.to_numpy(np.ctypeslib.as_array(values, shape=[nse.value]))
-  indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
-  return rank.value, nse.value, shape, values, indices
-
-
-def coo_tensor_to_sparse_tensor(np_shape: np.ndarray, np_values: np.ndarray,
-                                np_indices: np.ndarray, np_perm: np.ndarray,
-                                np_sparse: np.ndarray) -> int:
-  """Converts a COO-flavored format sparse tensor to an MLIR sparse tensor.
-
-  Args:
-     np_shape: A 1D numpy array of integers, for the shape of the tensor.
-     np_values: A 1D numpy array, for the non-zero values in the tensor.
-     np_indices: A 2D numpy array of integers, representing the indices for the
-       non-zero values in the tensor.
-     np_perm: A 1D numpy array of integers, representing the storage ordering
-       for the dimensions.
-     np_sparse: A 1D numpy array of uint8, representing the sparsity values
-       for the dimensions.
-
-  Returns:
-     An integer for the non-null ctypes.c_void_p to the MLIR sparse tensor
-     descriptor.
-
-  Raises:
-    OSError: If there is any problem in loading the shared library.
-    ValueError: If the shared library doesn't contain the needed routines.
-  """
-
-  r = len(np_shape)
-  rank = ctypes.c_ulonglong(r)
-  nse = ctypes.c_ulonglong(len(np_values))
-  shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
-  values = np_values.ctypes.data_as(
-      ctypes.POINTER(runtime.as_ctype(np.dtype(np_values.dtype))))
-  indices = np_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
-
-  perm = np_perm.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
-  sparse = np_sparse.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8))
-
-  convert_to = _get_support_func_locator()(np_values.dtype.type)[0]
-  ptr = convert_to(rank, nse, shape, values, indices, perm, sparse)
-  assert ptr is not None, "Problem with calling convertToMLIRSparseTensorF64"
-  return ptr
-
-
-def compile_and_build_engine(
-    module: ir.Module) -> execution_engine.ExecutionEngine:
-  """Compiles an MLIR module and builds a JIT execution engine.
-
-  Args:
-    module: The MLIR module.
-
-  Returns:
-    A JIT execution engine for the MLIR module.
-
-  """
-  return _get_sparse_compiler().compile_and_jit(module)
+    """Converts an MLIR sparse tensor to a COO-flavored format tensor.
+
+    Args:
+       sparse_tensor: A ctypes.c_void_p to the MLIR sparse tensor descriptor.
+       dtype: The numpy data type for the tensor elements.
+
+    Returns:
+      A tuple that contains the following values for the COO-flavored format
+      tensor:
+      rank: An integer for the rank of the tensor.
+      nse: An integer for the number of non-zero values in the tensor.
+      shape: A 1D numpy array of integers, for the shape of the tensor.
+      values: A 1D numpy array, for the non-zero values in the tensor.
+      indices: A 2D numpy array of integers, representing the indices for the
+        non-zero values in the tensor.
+
+    Raises:
+      OSError: If there is any problem in loading the shared library.
+      ValueError: If the shared library doesn't contain the needed routines.
+    """
+    convert_from = _get_support_func_locator()(dtype)[1]
+    rank = ctypes.c_ulonglong(0)
+    nse = ctypes.c_ulonglong(0)
+    shape = ctypes.POINTER(ctypes.c_ulonglong)()
+
+    values = ctypes.POINTER(runtime.as_ctype(np.dtype(dtype)))()
+    indices = ctypes.POINTER(ctypes.c_ulonglong)()
+    convert_from(
+        sparse_tensor,
+        ctypes.byref(rank),
+        ctypes.byref(nse),
+        ctypes.byref(shape),
+        ctypes.byref(values),
+        ctypes.byref(indices),
+    )
+
+    # Convert the returned values to the corresponding numpy types.
+    shape = np.ctypeslib.as_array(shape, shape=[rank.value])
+    values = runtime.to_numpy(np.ctypeslib.as_array(values, shape=[nse.value]))
+    indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
+    return rank.value, nse.value, shape, values, indices
+
+
+def coo_tensor_to_sparse_tensor(
+    np_shape: np.ndarray,
+    np_values: np.ndarray,
+    np_indices: np.ndarray,
+    np_perm: np.ndarray,
+    np_sparse: np.ndarray,
+) -> int:
+    """Converts a COO-flavored format sparse tensor to an MLIR sparse tensor.
+
+    Args:
+       np_shape: A 1D numpy array of integers, for the shape of the tensor.
+       np_values: A 1D numpy array, for the non-zero values in the tensor.
+       np_indices: A 2D numpy array of integers, representing the indices for the
+         non-zero values in the tensor.
+       np_perm: A 1D numpy array of integers, representing the storage ordering
+         for the dimensions.
+       np_sparse: A 1D numpy array of uint8, representing the sparsity values
+         for the dimensions.
+
+    Returns:
+       An integer for the non-null ctypes.c_void_p to the MLIR sparse tensor
+       descriptor.
+
+    Raises:
+      OSError: If there is any problem in loading the shared library.
+      ValueError: If the shared library doesn't contain the needed routines.
+    """
+
+    r = len(np_shape)
+    rank = ctypes.c_ulonglong(r)
+    nse = ctypes.c_ulonglong(len(np_values))
+    shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
+    values = np_values.ctypes.data_as(
+        ctypes.POINTER(runtime.as_ctype(np.dtype(np_values.dtype)))
+    )
+    indices = np_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
+
+    perm = np_perm.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
+    sparse = np_sparse.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8))
+
+    convert_to = _get_support_func_locator()(np_values.dtype.type)[0]
+    ptr = convert_to(rank, nse, shape, values, indices, perm, sparse)
+    assert ptr is not None, "Problem with calling the convertToMLIRSparseTensor routine"
+    return ptr
+
+
+def compile_and_build_engine(module: ir.Module) -> execution_engine.ExecutionEngine:
+    """Compiles an MLIR module and builds a JIT execution engine.
+
+    Args:
+      module: The MLIR module.
+
+    Returns:
+      A JIT execution engine for the MLIR module.
+    """
+    return _get_sparse_compiler().compile_and_jit(module)
 
 
 class _SparseTensorDescriptor(ctypes.Structure):
-  """A C structure for an MLIR sparse tensor."""
-  _fields_ = [
-      # A pointer for the MLIR sparse tensor storage.
-      ("storage", ctypes.POINTER(ctypes.c_ulonglong)),
-      # An MLIR MemRef descriptor for the shape of the sparse tensor.
-      ("shape", runtime.make_nd_memref_descriptor(1, ctypes.c_ulonglong)),
-  ]
+    """A C structure for an MLIR sparse tensor."""
+
+    _fields_ = [
+        # A pointer for the MLIR sparse tensor storage.
+        ("storage", ctypes.POINTER(ctypes.c_ulonglong)),
+        # An MLIR MemRef descriptor for the shape of the sparse tensor.
+        ("shape", runtime.make_nd_memref_descriptor(1, ctypes.c_ulonglong)),
+    ]
 
 
 def _output_one_dim(dim: int, rank: int, shape: str, type: str) -> str:
-  """Produces the MLIR text code to output the size for the given dimension."""
-  return f"""
+    """Produces the MLIR text code to output the size for the given dimension."""
+    return f"""
   %c{dim} = arith.constant {dim} : index
   %d{dim} = tensor.dim %t, %c{dim} : tensor<{shape}x{type}, #enc>
   memref.store %d{dim}, %b[%c{dim}] : memref<{rank}xindex>
@@ -233,26 +277,29 @@ def _output_one_dim(dim: int, rank: int, shape: str, type: str) -> str:
 # (2) Use scf.for instead of an unrolled loop to write out the dimension sizes
 #     when tensor.dim supports non-constant dimension value.
 def _get_create_sparse_tensor_kernel(
-    sparsity_codes: Sequence[sparse_tensor.DimLevelType], type: str) -> str:
-  """Creates an MLIR text kernel to contruct a sparse tensor from a file.
+    sparsity_codes: Sequence[sparse_tensor.DimLevelType], type: str
+) -> str:
+    """Creates an MLIR text kernel to contruct a sparse tensor from a file.
 
-  The kernel returns a _SparseTensorDescriptor structure.
-  """
-  rank = len(sparsity_codes)
+    The kernel returns a _SparseTensorDescriptor structure.
+    """
+    rank = len(sparsity_codes)
 
-  # Use ? to represent a dimension in the dynamic shape string representation.
-  shape = "x".join(map(lambda d: "?", range(rank)))
+    # Use ? to represent a dimension in the dynamic shape string representation.
+    shape = "x".join(map(lambda d: "?", range(rank)))
 
-  # Convert the encoded sparsity values to a string representation.
-  sparsity = ", ".join(
-      map(lambda s: '"compressed"' if s.value else '"dense"', sparsity_codes))
+    # Convert the encoded sparsity values to a string representation.
+    sparsity = ", ".join(
+        map(lambda s: '"compressed"' if s.value else '"dense"', sparsity_codes)
+    )
 
-  # Get the MLIR text code to write the dimension sizes to the output buffer.
-  output_dims = "\n".join(
-      map(lambda d: _output_one_dim(d, rank, shape, type), range(rank)))
+    # Get the MLIR text code to write the dimension sizes to the output buffer.
+    output_dims = "\n".join(
+        map(lambda d: _output_one_dim(d, rank, shape, type), range(rank))
+    )
 
-  # Return the MLIR text kernel.
-  return f"""
+    # Return the MLIR text kernel.
+    return f"""
 !Ptr = !llvm.ptr<i8>
 #enc = #sparse_tensor.encoding<{{
   lvlTypes = [ {sparsity} ]
@@ -266,69 +313,69 @@ attributes {{ llvm.emit_c_interface }} {{
 }}"""
 
 
-def create_sparse_tensor(filename: str,
-                         sparsity: Sequence[sparse_tensor.DimLevelType],
-                         type: str) -> Tuple[ctypes.c_void_p, np.ndarray]:
-  """Creates an MLIR sparse tensor from the input file.
+def create_sparse_tensor(
+    filename: str, sparsity: Sequence[sparse_tensor.DimLevelType], type: str
+) -> Tuple[ctypes.c_void_p, np.ndarray]:
+    """Creates an MLIR sparse tensor from the input file.
 
-  Args:
-    filename: A string for the name of the file that contains the tensor data in
-      a COO-flavored format.
-    sparsity: A sequence of DimLevelType values, one for each dimension of the
-      tensor.
+    Args:
+      filename: A string for the name of the file that contains the tensor data in
+        a COO-flavored format.
+      sparsity: A sequence of DimLevelType values, one for each dimension of the
+        tensor.
+      type: The MLIR string for the data type.
 
-  Returns:
-    A Tuple containing the following values:
-    storage: A ctypes.c_void_p for the MLIR sparse tensor storage.
-    shape: A 1D numpy array of integers, for the shape of the tensor.
+    Returns:
+      A Tuple containing the following values:
+      storage: A ctypes.c_void_p for the MLIR sparse tensor storage.
+      shape: A 1D numpy array of integers, for the shape of the tensor.
 
-  Raises:
-    OSError: If there is any problem in loading the supporting C shared library.
-    ValueError:  If the shared library doesn't contain the needed routine.
-  """
-  with ir.Context() as ctx, ir.Location.unknown():
-    module = _get_create_sparse_tensor_kernel(sparsity, type)
-    module = ir.Module.parse(module)
-    engine = compile_and_build_engine(module)
+    Raises:
+      OSError: If there is any problem in loading the supporting C shared library.
+      ValueError: If the shared library doesn't contain the needed routine.
+    """
+    with ir.Context() as ctx, ir.Location.unknown():
+        module = _get_create_sparse_tensor_kernel(sparsity, type)
+        module = ir.Module.parse(module)
+        engine = compile_and_build_engine(module)
 
-  # A sparse tensor descriptor to receive the kernel result.
-  c_tensor_desc = _SparseTensorDescriptor()
-  # Convert the filename to a byte stream.
-  c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))
+    # A sparse tensor descriptor to receive the kernel result.
+    c_tensor_desc = _SparseTensorDescriptor()
+    # Convert the filename to a byte stream.
+    c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))
 
-  arg_pointers = [
-      ctypes.byref(ctypes.pointer(c_tensor_desc)),
-      ctypes.byref(c_filename)
-  ]
+    arg_pointers = [
+        ctypes.byref(ctypes.pointer(c_tensor_desc)),
+        ctypes.byref(c_filename),
+    ]
 
-  # Invoke the execution engine to run the module and return the result.
-  engine.invoke(_ENTRY_NAME, *arg_pointers)
-  shape = runtime.ranked_memref_to_numpy(ctypes.pointer(c_tensor_desc.shape))
-  return c_tensor_desc.storage, shape
+    # Invoke the execution engine to run the module and return the result.
+    engine.invoke(_ENTRY_NAME, *arg_pointers)
+    shape = runtime.ranked_memref_to_numpy(ctypes.pointer(c_tensor_desc.shape))
+    return c_tensor_desc.storage, shape
 
 
 # TODO: With better support from MLIR, we may improve the current implementation
 # by using Python code to generate the kernel instead of doing MLIR text code
 # stitching.
 def _get_output_sparse_tensor_kernel(
-        sparsity_codes: Sequence[sparse_tensor.DimLevelType],
-        type: str) -> str:
-  """Creates an MLIR text kernel to output a sparse tensor to a file.
+    sparsity_codes: Sequence[sparse_tensor.DimLevelType], type: str
+) -> str:
+    """Creates an MLIR text kernel to output a sparse tensor to a file.
 
-  The kernel returns void.
-  """
-  rank = len(sparsity_codes)
+    The kernel returns void.
+    """
+    rank = len(sparsity_codes)
 
-  # Use ? to represent a dimension in the dynamic shape string representation.
-  shape = "x".join(map(lambda d: "?", range(rank)))
+    # Use ? to represent a dimension in the dynamic shape string representation.
+    shape = "x".join(map(lambda d: "?", range(rank)))
 
-  # Convert the encoded sparsity values to a string representation.
-  sparsity = ", ".join(
-      map(lambda s: '"compressed"'
-          if s.value else '"dense"', sparsity_codes))
+    # Convert the encoded sparsity values to a string representation.
+    sparsity = ", ".join(
+        map(lambda s: '"compressed"' if s.value else '"dense"', sparsity_codes)
+    )
 
-  # Return the MLIR text kernel.
-  return f"""
+    # Return the MLIR text kernel.
+    return f"""
 !Ptr = !llvm.ptr<i8>
 #enc = #sparse_tensor.encoding<{{
   lvlTypes = [ {sparsity} ]
@@ -340,35 +387,38 @@ attributes {{ llvm.emit_c_interface }} {{
 }}"""
 
 
-def output_sparse_tensor(tensor: ctypes.c_void_p, filename: str,
-                         sparsity: Sequence[sparse_tensor.DimLevelType],
-                         type: str) -> None:
-  """Outputs an MLIR sparse tensor to the given file.
-
-  Args:
-    tensor: A C pointer to the MLIR sparse tensor.
-    filename: A string for the name of the file that contains the tensor data in
-      a COO-flavored format.
-    sparsity: A sequence of DimLevelType values, one for each dimension of the
-      tensor.
-    type: The MLIR string for the data type.
-
-  Raises:
-    OSError: If there is any problem in loading the supporting C shared library.
-    ValueError:  If the shared library doesn't contain the needed routine.
-  """
-  with ir.Context() as ctx, ir.Location.unknown():
-    module = _get_output_sparse_tensor_kernel(sparsity, type)
-    module = ir.Module.parse(module)
-    engine = compile_and_build_engine(module)
-
-  # Convert the filename to a byte stream.
-  c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))
-
-  arg_pointers = [
-      ctypes.byref(ctypes.cast(tensor, ctypes.c_void_p)),
-      ctypes.byref(c_filename)
-  ]
-
-  # Invoke the execution engine to run the module and return the result.
-  engine.invoke(_ENTRY_NAME, *arg_pointers)
+def output_sparse_tensor(
+    tensor: ctypes.c_void_p,
+    filename: str,
+    sparsity: Sequence[sparse_tensor.DimLevelType],
+    type: str,
+) -> None:
+    """Outputs an MLIR sparse tensor to the given file.
+
+    Args:
+      tensor: A C pointer to the MLIR sparse tensor.
+      filename: A string for the name of the file that contains the tensor data in
+        a COO-flavored format.
+      sparsity: A sequence of DimLevelType values, one for each dimension of the
+        tensor.
+      type: The MLIR string for the data type.
+
+    Raises:
+      OSError: If there is any problem in loading the supporting C shared library.
+      ValueError: If the shared library doesn't contain the needed routine.
+    """
+    with ir.Context() as ctx, ir.Location.unknown():
+        module = _get_output_sparse_tensor_kernel(sparsity, type)
+        module = ir.Module.parse(module)
+        engine = compile_and_build_engine(module)
+
+    # Convert the filename to a byte stream.
+    c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))
+
+    arg_pointers = [
+        ctypes.byref(ctypes.cast(tensor, ctypes.c_void_p)),
+        ctypes.byref(c_filename),
+    ]
+
+    # Invoke the execution engine to run the module and return the result.
+    engine.invoke(_ENTRY_NAME, *arg_pointers)
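For context, the two helpers above form the read/write path of these PyTACO utilities: `create_sparse_tensor` builds an MLIR sparse tensor from a file and `output_sparse_tensor` writes one back out. A minimal usage sketch, with hypothetical file names and assuming `DimLevelType` exposes a `compressed` member (the exact enum spelling is an assumption, not shown in this patch):

    from mlir.dialects import sparse_tensor

    # Two compressed levels; DimLevelType.compressed is an assumed spelling.
    lvls = [sparse_tensor.DimLevelType.compressed] * 2
    # create_sparse_tensor(filename, sparsity, type) returns (ctypes.c_void_p, shape ndarray),
    # matching the docstring above and its use later in this patch.
    storage, shape = create_sparse_tensor("input.tns", lvls, "f64")
    output_sparse_tensor(storage, "copy.tns", lvls, "f64")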
index 69db28d..8f193b8 100644 (file)
@@ -13,29 +13,29 @@ from typing import Sequence
 
 
 class SparseCompiler:
-  """Sparse compiler class for compiling and building MLIR modules."""
-
-  def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
-    pipeline = f'builtin.module(sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})'
-    self.pipeline = pipeline
-    self.opt_level = opt_level
-    self.shared_libs = shared_libs
-
-  def __call__(self, module: ir.Module):
-    """Convenience application method."""
-    self.compile(module)
-
-  def compile(self, module: ir.Module):
-    """Compiles the module by invoking the sparse copmiler pipeline."""
-    passmanager.PassManager.parse(self.pipeline).run(module.operation)
-
-  def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
-    """Wraps the module in a JIT execution engine."""
-    return execution_engine.ExecutionEngine(
-        module, opt_level=self.opt_level, shared_libs=self.shared_libs)
-
-  def compile_and_jit(self,
-                      module: ir.Module) -> execution_engine.ExecutionEngine:
-    """Compiles and jits the module."""
-    self.compile(module)
-    return self.jit(module)
+    """Sparse compiler class for compiling and building MLIR modules."""
+
+    def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
+        pipeline = f"builtin.module(sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})"
+        self.pipeline = pipeline
+        self.opt_level = opt_level
+        self.shared_libs = shared_libs
+
+    def __call__(self, module: ir.Module):
+        """Convenience application method."""
+        self.compile(module)
+
+    def compile(self, module: ir.Module):
+        """Compiles the module by invoking the sparse copmiler pipeline."""
+        passmanager.PassManager.parse(self.pipeline).run(module.operation)
+
+    def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+        """Wraps the module in a JIT execution engine."""
+        return execution_engine.ExecutionEngine(
+            module, opt_level=self.opt_level, shared_libs=self.shared_libs
+        )
+
+    def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+        """Compiles and jits the module."""
+        self.compile(module)
+        return self.jit(module)
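The SparseCompiler class above couples the sparse-compiler pass pipeline with an ExecutionEngine. A hedged usage sketch; the options string, optimization level, and shared-library path are placeholders rather than values required by this file:

    KERNEL_TEXT = "module {}"  # placeholder; a real sparse kernel would go here
    with ir.Context(), ir.Location.unknown():
        module = ir.Module.parse(KERNEL_TEXT)
        compiler = SparseCompiler(
            options="", opt_level=2, shared_libs=["libmlir_c_runner_utils.so"]
        )
        # Lowers the module with the sparse-compiler pipeline, then JIT-wraps it.
        engine = compiler.compile_and_jit(module)
        # engine.invoke("main", *arg_pointers)  # invoke an emitted entry point, as elsewhere in this patch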
index 466c9df..1be88fa 100644 (file)
@@ -8,38 +8,40 @@ import numpy as np
 
 
 def compare_sparse_tns(expected: str, actual: str, rtol: float = 0.0001) -> bool:
-  """Compares sparse tensor actual output file with expected output file.
+    """Compares sparse tensor actual output file with expected output file.
 
-  This routine assumes the input files are in FROSTT format. See
-  http://frostt.io/tensors/file-formats.html for FROSTT (.tns) format.
+    This routine assumes the input files are in FROSTT format. See
+    http://frostt.io/tensors/file-formats.html for FROSTT (.tns) format.
 
-  It also assumes the first line in the output file is a comment line.
+    It also assumes the first line in the output file is a comment line.
 
-  """
-  with open(actual, "r") as actual_f:
-    with open(expected, "r") as expected_f:
-      # Skip the first comment line.
-      _ = actual_f.readline()
-      _ = expected_f.readline()
+    """
+    with open(actual, "r") as actual_f:
+        with open(expected, "r") as expected_f:
+            # Skip the first comment line.
+            _ = actual_f.readline()
+            _ = expected_f.readline()
 
-      # Compare the two lines of meta data
-      if (actual_f.readline() != expected_f.readline() or
-          actual_f.readline() != expected_f.readline()):
-        return FALSE
+            # Compare the two lines of metadata.
+            if (
+                actual_f.readline() != expected_f.readline()
+                or actual_f.readline() != expected_f.readline()
+            ):
+                return False
 
-  actual_data = np.loadtxt(actual, np.float64, skiprows=3)
-  expected_data = np.loadtxt(expected, np.float64, skiprows=3)
-  return np.allclose(actual_data, expected_data, rtol=rtol)
+    actual_data = np.loadtxt(actual, np.float64, skiprows=3)
+    expected_data = np.loadtxt(expected, np.float64, skiprows=3)
+    return np.allclose(actual_data, expected_data, rtol=rtol)
 
 
 def file_as_string(file: str) -> str:
-  """Returns contents of file as string."""
-  with open(file, "r") as f:
-    return f.read()
+    """Returns contents of file as string."""
+    with open(file, "r") as f:
+        return f.read()
 
 
 def run_test(f):
-  """Prints the test name and runs the test."""
-  print(f.__name__)
-  f()
-  return f
+    """Prints the test name and runs the test."""
+    print(f.__name__)
+    f()
+    return f
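These three helpers back the FileCheck-driven tests in this directory: `run_test` announces and runs a test, and `compare_sparse_tns` checks a produced .tns file against a reference. A short hedged sketch of how they compose (file names are hypothetical):

    @run_test  # prints "test_roundtrip", then executes the body
    def test_roundtrip():
        # True when both FROSTT files agree within the default tolerance.
        print(compare_sparse_tns("expected.tns", "actual.tns"))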
index 5b7e648..45ce446 100644 (file)
@@ -18,509 +18,630 @@ _DENSE = mlir_pytaco.ModeFormat.DENSE
 
 
 def _init_3d(T, I, J, K):
-  for i in range(I):
-    for j in range(J):
-      for k in range(K):
-        T.insert([i, j, k], i + j + k + 1)
+    for i in range(I):
+        for j in range(J):
+            for k in range(K):
+                T.insert([i, j, k], i + j + k + 1)
 
 
 def _init_2d(T, I, J):
-  for i in range(I):
-    for j in range(J):
-      T.insert([i, j], i + j + 1)
+    for i in range(I):
+        for j in range(J):
+            T.insert([i, j], i + j + 1)
 
 
 def _init_1d_with_value(T, I, v):
-  for i in range(I):
-    T.insert([i], v)
+    for i in range(I):
+        T.insert([i], v)
 
 
 def test_expect_error(name, code, error):
-  """Executes the code then verifies the expected error message."""
-  try:
-    exec(code)
-  except ValueError as e:
-    passed = "passed" if (str(e).startswith(error)) else "failed"
-    print(f"test_{name}: {passed}")
+    """Executes the code then verifies the expected error message."""
+    try:
+        exec(code)
+    except ValueError as e:
+        passed = "passed" if (str(e).startswith(error)) else "failed"
+        print(f"test_{name}: {passed}")
 
 
 # CHECK-LABEL: test_tensor_dtype
 @testing_utils.run_test
 def test_tensor_dtype():
-  passed = mlir_pytaco.DType(mlir_pytaco.Type.INT16).is_int()
-  passed += mlir_pytaco.DType(mlir_pytaco.Type.INT32).is_int()
-  passed += mlir_pytaco.DType(mlir_pytaco.Type.INT64).is_int()
-  passed += mlir_pytaco.DType(mlir_pytaco.Type.FLOAT32).is_float()
-  passed += mlir_pytaco.DType(mlir_pytaco.Type.FLOAT64).is_float()
-  # CHECK: Number of passed: 5
-  print("Number of passed:", passed)
+    passed = mlir_pytaco.DType(mlir_pytaco.Type.INT16).is_int()
+    passed += mlir_pytaco.DType(mlir_pytaco.Type.INT32).is_int()
+    passed += mlir_pytaco.DType(mlir_pytaco.Type.INT64).is_int()
+    passed += mlir_pytaco.DType(mlir_pytaco.Type.FLOAT32).is_float()
+    passed += mlir_pytaco.DType(mlir_pytaco.Type.FLOAT64).is_float()
+    # CHECK: Number of passed: 5
+    print("Number of passed:", passed)
 
 
 # CHECK: test_mode_ordering_not_int: passed
-test_expect_error("mode_ordering_not_int",
-                  "m = mlir_pytaco.ModeOrdering(['x'])",
-                  "Ordering must be a list of integers")
+test_expect_error(
+    "mode_ordering_not_int",
+    "m = mlir_pytaco.ModeOrdering(['x'])",
+    "Ordering must be a list of integers",
+)
 
 # CHECK: test_mode_ordering_not_permutation: passed
-test_expect_error("mode_ordering_not_permutation",
-                  "m = mlir_pytaco.ModeOrdering([2, 1])", "Invalid ordering")
+test_expect_error(
+    "mode_ordering_not_permutation",
+    "m = mlir_pytaco.ModeOrdering([2, 1])",
+    "Invalid ordering",
+)
 
 # CHECK: test_mode_format_invalid: passed
-test_expect_error("mode_format_invalid",
-                  "m = mlir_pytaco.ModeFormatPack(['y'])",
-                  "Formats must be a list of ModeFormat")
+test_expect_error(
+    "mode_format_invalid",
+    "m = mlir_pytaco.ModeFormatPack(['y'])",
+    "Formats must be a list of ModeFormat",
+)
 
 # CHECK: test_expect_mode_format_pack: passed
-test_expect_error("expect_mode_format_pack", ("""
+test_expect_error(
+    "expect_mode_format_pack",
+    (
+        """
 mode_ordering = mlir_pytaco.ModeOrdering([0, 1, 2])
 f = mlir_pytaco.Format(["x"], mode_ordering)
-    """), "Expected a list of ModeFormat")
+    """
+    ),
+    "Expected a list of ModeFormat",
+)
 
 # CHECK: test_expect_mode_ordering: passed
-test_expect_error("expect_mode_ordering", ("""
+test_expect_error(
+    "expect_mode_ordering",
+    (
+        """
 mode_format_pack = mlir_pytaco.ModeFormatPack([_COMPRESSED, _COMPRESSED])
 f = mlir_pytaco.Format(mode_format_pack, "x")
-    """), "Expected ModeOrdering")
+    """
+    ),
+    "Expected ModeOrdering",
+)
 
 # CHECK: test_inconsistent_mode_format_pack_and_mode_ordering: passed
-test_expect_error("inconsistent_mode_format_pack_and_mode_ordering", ("""
+test_expect_error(
+    "inconsistent_mode_format_pack_and_mode_ordering",
+    (
+        """
 mode_format_pack = mlir_pytaco.ModeFormatPack([_COMPRESSED, _COMPRESSED])
 mode_ordering = mlir_pytaco.ModeOrdering([0, 1, 2])
 f = mlir_pytaco.Format(mode_format_pack, mode_ordering)
-    """), "Inconsistent ModeFormatPack and ModeOrdering")
+    """
+    ),
+    "Inconsistent ModeFormatPack and ModeOrdering",
+)
 
 
 # CHECK-LABEL: test_format_default_ordering
 @testing_utils.run_test
 def test_format_default_ordering():
-  f = mlir_pytaco.Format([_COMPRESSED, _COMPRESSED])
-  passed = 0
-  passed += np.array_equal(f.ordering.ordering, [0, 1])
-  # CHECK: Number of passed: 1
-  print("Number of passed:", passed)
+    f = mlir_pytaco.Format([_COMPRESSED, _COMPRESSED])
+    passed = 0
+    passed += np.array_equal(f.ordering.ordering, [0, 1])
+    # CHECK: Number of passed: 1
+    print("Number of passed:", passed)
 
 
 # CHECK-LABEL: test_format_explicit_ordering
 @testing_utils.run_test
 def test_format_explicit_ordering():
-  f = mlir_pytaco.Format([_COMPRESSED, _DENSE], [1, 0])
-  passed = 0
-  passed += np.array_equal(f.ordering.ordering, [1, 0])
-  # CHECK: Number of passed: 1
-  print("Number of passed:", passed)
+    f = mlir_pytaco.Format([_COMPRESSED, _DENSE], [1, 0])
+    passed = 0
+    passed += np.array_equal(f.ordering.ordering, [1, 0])
+    # CHECK: Number of passed: 1
+    print("Number of passed:", passed)
 
 
 # CHECK-LABEL: test_index_var
 @testing_utils.run_test
 def test_index_var():
-  i = mlir_pytaco.IndexVar()
-  j = mlir_pytaco.IndexVar()
-  passed = (i.name != j.name)
+    i = mlir_pytaco.IndexVar()
+    j = mlir_pytaco.IndexVar()
+    passed = i.name != j.name
 
-  vars = mlir_pytaco.get_index_vars(10)
-  passed += (len(vars) == 10)
-  passed += (all([isinstance(e, mlir_pytaco.IndexVar) for e in vars]))
+    vars = mlir_pytaco.get_index_vars(10)
+    passed += len(vars) == 10
+    passed += all([isinstance(e, mlir_pytaco.IndexVar) for e in vars])
 
-  # CHECK: Number of passed: 3
-  print("Number of passed:", passed)
+    # CHECK: Number of passed: 3
+    print("Number of passed:", passed)
 
 
 # CHECK: test_tensor_invalid_first_argument: passed
-test_expect_error("tensor_invalid_first_argument",
-                  "t = mlir_pytaco.Tensor('f')", "Invalid first argument")
+test_expect_error(
+    "tensor_invalid_first_argument",
+    "t = mlir_pytaco.Tensor('f')",
+    "Invalid first argument",
+)
 
 # CHECK: test_tensor_inconsistent_shape_and_format: passed
-test_expect_error("tensor_inconsistent_shape_and_format", ("""
+test_expect_error(
+    "tensor_inconsistent_shape_and_format",
+    (
+        """
 mode_format_pack = mlir_pytaco.ModeFormatPack([_COMPRESSED, _COMPRESSED])
 mode_ordering = mlir_pytaco.ModeOrdering([0, 1])
 f = mlir_pytaco.Format(mode_format_pack, mode_ordering)
 t = mlir_pytaco.Tensor([3], f)
-    """), "Inconsistent shape and format")
+    """
+    ),
+    "Inconsistent shape and format",
+)
 
 # CHECK: test_tensor_invalid_format: passed
-test_expect_error("tensor_invalid_format", "t = mlir_pytaco.Tensor([3], 'f')",
-                  "Invalid format argument")
+test_expect_error(
+    "tensor_invalid_format",
+    "t = mlir_pytaco.Tensor([3], 'f')",
+    "Invalid format argument",
+)
 
 # CHECK: test_tensor_insert_nonlist_coordinate: passed
-test_expect_error("tensor_insert_nonlist_coordinate", ("""
+test_expect_error(
+    "tensor_insert_nonlist_coordinate",
+    (
+        """
 t = mlir_pytaco.Tensor([3])
 t.insert(1, 0)
-    """), "Non list coordinate detected")
+    """
+    ),
+    "Non list coordinate detected",
+)
 
 # CHECK: test_tensor_insert_too_much_coordinate: passed
-test_expect_error("tensor_insert_too_much_coordinate", ("""
+test_expect_error(
+    "tensor_insert_too_much_coordinate",
+    (
+        """
 t = mlir_pytaco.Tensor([3])
 t.insert([0, 0], 0)
-    """), "Invalid coordinate")
+    """
+    ),
+    "Invalid coordinate",
+)
 
 # CHECK: test_tensor_insert_coordinate_outof_range: passed
-test_expect_error("tensor_insert_coordinate_outof_range", ("""
+test_expect_error(
+    "tensor_insert_coordinate_outof_range",
+    (
+        """
 t = mlir_pytaco.Tensor([1, 1])
 t.insert([1, 0], 0)
-    """), "Invalid coordinate")
+    """
+    ),
+    "Invalid coordinate",
+)
 
 # CHECK: test_tensor_insert_coordinate_nonint: passed
-test_expect_error("tensor_insert_coordinate_nonint", ("""
+test_expect_error(
+    "tensor_insert_coordinate_nonint",
+    (
+        """
 t = mlir_pytaco.Tensor([1, 1])
 t.insert([0, "xy"], 0)
-    """), "Non integer coordinate detected")
+    """
+    ),
+    "Non integer coordinate detected",
+)
 
 # CHECK: test_tensor_insert_invalid_value: passed
-test_expect_error("tensor_insert_invalid_value", ("""
+test_expect_error(
+    "tensor_insert_invalid_value",
+    (
+        """
 t = mlir_pytaco.Tensor([1, 1])
 t.insert([0, 0], "x")
-    """), "Value is neither int nor float")
+    """
+    ),
+    "Value is neither int nor float",
+)
 
 # CHECK: test_access_non_index_var_index: passed
-test_expect_error("access_non_index_var_index", ("""
+test_expect_error(
+    "access_non_index_var_index",
+    (
+        """
 t = mlir_pytaco.Tensor([5, 6])
 i = mlir_pytaco.IndexVar()
 a = mlir_pytaco.Access(t, (i, "j"))
-    """), "Indices contain non IndexVar")
+    """
+    ),
+    "Indices contain non IndexVar",
+)
 
 # CHECK: test_access_inconsistent_rank_indices: passed
-test_expect_error("access_inconsistent_rank_indices", ("""
+test_expect_error(
+    "access_inconsistent_rank_indices",
+    (
+        """
 t = mlir_pytaco.Tensor([5, 6])
 i = mlir_pytaco.IndexVar()
 a = mlir_pytaco.Access(t, (i,))
-    """), "Invalid indices for rank")
+    """
+    ),
+    "Invalid indices for rank",
+)
 
 # CHECK: test_access_invalid_indices_for_rank: passed
-test_expect_error("access_invalid_indices_for_rank", ("""
+test_expect_error(
+    "access_invalid_indices_for_rank",
+    (
+        """
 t = mlir_pytaco.Tensor([5, 6])
 i, j, k = mlir_pytaco.get_index_vars(3)
 a = mlir_pytaco.Access(t, (i,j, k))
-    """), "Invalid indices for rank")
+    """
+    ),
+    "Invalid indices for rank",
+)
 
 # CHECK: test_invalid_indices: passed
-test_expect_error("invalid_indices", ("""
+test_expect_error(
+    "invalid_indices",
+    (
+        """
 i, j = mlir_pytaco.get_index_vars(2)
 A = mlir_pytaco.Tensor([2, 3])
 B = mlir_pytaco.Tensor([2, 3])
 C = mlir_pytaco.Tensor([2, 3], _DENSE)
 C[i, j] = A[1, j] + B[i, j]
-    """), "Expected IndexVars")
+    """
+    ),
+    "Expected IndexVars",
+)
 
 # CHECK: test_inconsistent_rank_indices: passed
-test_expect_error("inconsistent_rank_indices", ("""
+test_expect_error(
+    "inconsistent_rank_indices",
+    (
+        """
 i, j = mlir_pytaco.get_index_vars(2)
 A = mlir_pytaco.Tensor([2, 3])
 C = mlir_pytaco.Tensor([2, 3], _DENSE)
 C[i, j] = A[i]
-    """), "Invalid indices for rank")
+    """
+    ),
+    "Invalid indices for rank",
+)
 
 # CHECK: test_destination_index_not_used_in_source: passed
-test_expect_error("destination_index_not_used_in_source", ("""
+test_expect_error(
+    "destination_index_not_used_in_source",
+    (
+        """
 i, j = mlir_pytaco.get_index_vars(2)
 A = mlir_pytaco.Tensor([3])
 C = mlir_pytaco.Tensor([3], _DENSE)
 C[j] = A[i]
 C.evaluate()
-    """), "Destination IndexVar not used in the source expression")
+    """
+    ),
+    "Destination IndexVar not used in the source expression",
+)
 
 # CHECK: test_destination_dim_not_consistent_with_source: passed
-test_expect_error("destination_dim_not_consistent_with_source", ("""
+test_expect_error(
+    "destination_dim_not_consistent_with_source",
+    (
+        """
 i = mlir_pytaco.IndexVar()
 A = mlir_pytaco.Tensor([3])
 C = mlir_pytaco.Tensor([5], _DENSE)
 C[i] = A[i]
 C.evaluate()
-    """), "Inconsistent destination dimension for IndexVar")
+    """
+    ),
+    "Inconsistent destination dimension for IndexVar",
+)
 
 # CHECK: test_inconsistent_source_dim: passed
-test_expect_error("inconsistent_source_dim", ("""
+test_expect_error(
+    "inconsistent_source_dim",
+    (
+        """
 i = mlir_pytaco.IndexVar()
 A = mlir_pytaco.Tensor([3])
 B = mlir_pytaco.Tensor([5])
 C = mlir_pytaco.Tensor([3], _DENSE)
 C[i] = A[i] + B[i]
 C.evaluate()
-    """), "Inconsistent source dimension for IndexVar")
+    """
+    ),
+    "Inconsistent source dimension for IndexVar",
+)
 
 # CHECK: test_index_var_outside_domain: passed
-test_expect_error("index_var_outside_domain", ("""
+test_expect_error(
+    "index_var_outside_domain",
+    (
+        """
 i, j = mlir_pytaco.get_index_vars(2)
 A = mlir_pytaco.Tensor([3])
 B = mlir_pytaco.Tensor([3])
 B[i] = A[i] + j
 B.evaluate()
-    """), "IndexVar is not part of the iteration domain")
+    """
+    ),
+    "IndexVar is not part of the iteration domain",
+)
 
 
 # CHECK-LABEL: test_tensor_all_dense_sparse
 @testing_utils.run_test
 def test_tensor_all_dense_sparse():
-  a = mlir_pytaco.Tensor([4], [_DENSE])
-  passed = (not a.is_dense())
-  passed += (a.order == 1)
-  passed += (a.shape[0] == 4)
-  # CHECK: Number of passed: 3
-  print("Number of passed:", passed)
+    a = mlir_pytaco.Tensor([4], [_DENSE])
+    passed = not a.is_dense()
+    passed += a.order == 1
+    passed += a.shape[0] == 4
+    # CHECK: Number of passed: 3
+    print("Number of passed:", passed)
 
 
 # CHECK-LABEL: test_tensor_true_dense
 @testing_utils.run_test
 def test_tensor_true_dense():
-  a = mlir_pytaco.Tensor.from_array(np.random.uniform(size=5))
-  passed = a.is_dense()
-  passed += (a.order == 1)
-  passed += (a.shape[0] == 5)
-  # CHECK: Number of passed: 3
-  print("Number of passed:", passed)
+    a = mlir_pytaco.Tensor.from_array(np.random.uniform(size=5))
+    passed = a.is_dense()
+    passed += a.order == 1
+    passed += a.shape[0] == 5
+    # CHECK: Number of passed: 3
+    print("Number of passed:", passed)
 
 
 # CHECK-LABEL: test_tensor_copy
 @testing_utils.run_test
 def test_tensor_copy():
-  i, j = mlir_pytaco.get_index_vars(2)
-  I = 2
-  J = 3
-  A = mlir_pytaco.Tensor([I, J])
-  A.insert([0, 1], 5.0)
-  A.insert([1, 2], 6.0)
-  B = mlir_pytaco.Tensor([I, J])
-  B[i, j] = A[i, j]
-  passed = (B._assignment is not None)
-  passed += (B._engine is None)
-  try:
+    i, j = mlir_pytaco.get_index_vars(2)
+    I = 2
+    J = 3
+    A = mlir_pytaco.Tensor([I, J])
+    A.insert([0, 1], 5.0)
+    A.insert([1, 2], 6.0)
+    B = mlir_pytaco.Tensor([I, J])
+    B[i, j] = A[i, j]
+    passed = B._assignment is not None
+    passed += B._engine is None
+    try:
+        B.compute()
+    except ValueError as e:
+        passed += str(e).startswith("Need to invoke compile")
+    B.compile()
+    passed += B._engine is not None
     B.compute()
-  except ValueError as e:
-    passed += (str(e).startswith("Need to invoke compile"))
-  B.compile()
-  passed += (B._engine is not None)
-  B.compute()
-  passed += (B._assignment is None)
-  passed += (B._engine is None)
-  indices, values = B.get_coordinates_and_values()
-  passed += np.array_equal(indices, [[0, 1], [1, 2]])
-  passed += np.allclose(values, [5.0, 6.0])
-  # No temporary tensor is used.
-  passed += (B._stats.get_total() == 0)
-  # CHECK: Number of passed: 9
-  print("Number of passed:", passed)
+    passed += B._assignment is None
+    passed += B._engine is None
+    indices, values = B.get_coordinates_and_values()
+    passed += np.array_equal(indices, [[0, 1], [1, 2]])
+    passed += np.allclose(values, [5.0, 6.0])
+    # No temporary tensor is used.
+    passed += B._stats.get_total() == 0
+    # CHECK: Number of passed: 9
+    print("Number of passed:", passed)
 
 
 # CHECK-LABEL: test_tensor_trivial_reduction
 @testing_utils.run_test
 def test_tensor_trivial_reduction():
-  i, j = mlir_pytaco.get_index_vars(2)
-  I = 2
-  J = 3
-  A = mlir_pytaco.Tensor([I, J])
-  A.insert([0, 1], 5.0)
-  A.insert([0, 2], 3.0)
-  A.insert([1, 2], 6.0)
-  B = mlir_pytaco.Tensor([I])
-  B[i] = A[i, j]
-  indices, values = B.get_coordinates_and_values()
-  passed = np.array_equal(indices, [[0], [1]])
-  passed += np.allclose(values, [8.0, 6.0])
-  # No temporary tensor is used.
-  passed += (B._stats.get_total() == 0)
-
-  # CHECK: Number of passed: 3
-  print("Number of passed:", passed)
+    i, j = mlir_pytaco.get_index_vars(2)
+    I = 2
+    J = 3
+    A = mlir_pytaco.Tensor([I, J])
+    A.insert([0, 1], 5.0)
+    A.insert([0, 2], 3.0)
+    A.insert([1, 2], 6.0)
+    B = mlir_pytaco.Tensor([I])
+    B[i] = A[i, j]
+    indices, values = B.get_coordinates_and_values()
+    passed = np.array_equal(indices, [[0], [1]])
+    passed += np.allclose(values, [8.0, 6.0])
+    # No temporary tensor is used.
+    passed += B._stats.get_total() == 0
+
+    # CHECK: Number of passed: 3
+    print("Number of passed:", passed)
 
 
 # CHECK-LABEL: test_binary_add
 @testing_utils.run_test
 def test_binary_add():
-  i = mlir_pytaco.IndexVar()
-  A = mlir_pytaco.Tensor([4])
-  B = mlir_pytaco.Tensor([4])
-  C = mlir_pytaco.Tensor([4])
-  A.insert([1], 10)
-  A.insert([2], 1)
-  B.insert([3], 20)
-  B.insert([2], 2)
-  C[i] = A[i] + B[i]
-  indices, values = C.get_coordinates_and_values()
-  passed = np.array_equal(indices, [[1], [2], [3]])
-  passed += np.array_equal(values, [10., 3., 20.])
-  # No temporary tensor is used.
-  passed += (C._stats.get_total() == 0)
-  # CHECK: Number of passed: 3
-  print("Number of passed:", passed)
+    i = mlir_pytaco.IndexVar()
+    A = mlir_pytaco.Tensor([4])
+    B = mlir_pytaco.Tensor([4])
+    C = mlir_pytaco.Tensor([4])
+    A.insert([1], 10)
+    A.insert([2], 1)
+    B.insert([3], 20)
+    B.insert([2], 2)
+    C[i] = A[i] + B[i]
+    indices, values = C.get_coordinates_and_values()
+    passed = np.array_equal(indices, [[1], [2], [3]])
+    passed += np.array_equal(values, [10.0, 3.0, 20.0])
+    # No temporary tensor is used.
+    passed += C._stats.get_total() == 0
+    # CHECK: Number of passed: 3
+    print("Number of passed:", passed)
 
 
 # CHECK-LABEL: test_binary_add_sub
 @testing_utils.run_test
 def test_binary_add_sub():
-  i = mlir_pytaco.IndexVar()
-  j = mlir_pytaco.IndexVar()
-  A = mlir_pytaco.Tensor([2, 3])
-  B = mlir_pytaco.Tensor([2, 3])
-  C = mlir_pytaco.Tensor([2, 3])
-  D = mlir_pytaco.Tensor([2, 3])
-  A.insert([0, 1], 10)
-  A.insert([1, 2], 40)
-  B.insert([0, 0], 20)
-  B.insert([1, 2], 30)
-  C.insert([0, 1], 5)
-  C.insert([1, 2], 7)
-  D[i, j] = A[i, j] + B[i, j] - C[i, j]
-  indices, values = D.get_coordinates_and_values()
-  passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
-  passed += np.array_equal(values, [20., 5., 63.])
-  # No temporary tensor is used.
-  passed += (D._stats.get_total() == 0)
-  # CHECK: Number of passed: 3
-  print("Number of passed:", passed)
+    i = mlir_pytaco.IndexVar()
+    j = mlir_pytaco.IndexVar()
+    A = mlir_pytaco.Tensor([2, 3])
+    B = mlir_pytaco.Tensor([2, 3])
+    C = mlir_pytaco.Tensor([2, 3])
+    D = mlir_pytaco.Tensor([2, 3])
+    A.insert([0, 1], 10)
+    A.insert([1, 2], 40)
+    B.insert([0, 0], 20)
+    B.insert([1, 2], 30)
+    C.insert([0, 1], 5)
+    C.insert([1, 2], 7)
+    D[i, j] = A[i, j] + B[i, j] - C[i, j]
+    indices, values = D.get_coordinates_and_values()
+    passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
+    passed += np.array_equal(values, [20.0, 5.0, 63.0])
+    # No temporary tensor is used.
+    passed += D._stats.get_total() == 0
+    # CHECK: Number of passed: 3
+    print("Number of passed:", passed)
 
 
 # CHECK-LABEL: test_binary_mul_add
 @testing_utils.run_test
 def test_binary_mul_add():
-  i = mlir_pytaco.IndexVar()
-  j = mlir_pytaco.IndexVar()
-  A = mlir_pytaco.Tensor([2, 3])
-  B = mlir_pytaco.Tensor([2, 3])
-  C = mlir_pytaco.Tensor([2, 3])
-  D = mlir_pytaco.Tensor([2, 3])
-  A.insert([0, 1], 10)
-  A.insert([1, 2], 40)
-  B.insert([0, 0], 20)
-  B.insert([1, 2], 30)
-  C.insert([0, 1], 5)
-  C.insert([1, 2], 7)
-  D[i, j] = A[i, j] * C[i, j] + B[i, j]
-  indices, values = D.get_coordinates_and_values()
-  passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
-  passed += np.array_equal(values, [20., 50., 310.])
-  # No temporary tensor is used.
-  passed += (D._stats.get_total() == 0)
-  # CHECK: Number of passed: 3
-  print("Number of passed:", passed)
+    i = mlir_pytaco.IndexVar()
+    j = mlir_pytaco.IndexVar()
+    A = mlir_pytaco.Tensor([2, 3])
+    B = mlir_pytaco.Tensor([2, 3])
+    C = mlir_pytaco.Tensor([2, 3])
+    D = mlir_pytaco.Tensor([2, 3])
+    A.insert([0, 1], 10)
+    A.insert([1, 2], 40)
+    B.insert([0, 0], 20)
+    B.insert([1, 2], 30)
+    C.insert([0, 1], 5)
+    C.insert([1, 2], 7)
+    D[i, j] = A[i, j] * C[i, j] + B[i, j]
+    indices, values = D.get_coordinates_and_values()
+    passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
+    passed += np.array_equal(values, [20.0, 50.0, 310.0])
+    # No temporary tensor is used.
+    passed += D._stats.get_total() == 0
+    # CHECK: Number of passed: 3
+    print("Number of passed:", passed)
 
 
 # CHECK-LABEL: test_binary_add_reduce_at_root
 @testing_utils.run_test
 def test_binary_add_reduce_at_root():
-  i = mlir_pytaco.IndexVar()
-  j = mlir_pytaco.IndexVar()
-  A = mlir_pytaco.Tensor([2, 3])
-  B = mlir_pytaco.Tensor([2, 3])
-  C = mlir_pytaco.Tensor([2], _DENSE)
-  A.insert([0, 1], 10)
-  A.insert([1, 2], 40)
-  B.insert([0, 0], 20)
-  B.insert([1, 2], 30)
-  C[i] = A[i, j] + B[i, j]
-  indices, values = C.get_coordinates_and_values()
-  passed = np.array_equal(indices, [[0], [1]])
-  passed += np.array_equal(values, [30., 70.])
-  # No temporary tensor is used.
-  passed += (C._stats.get_total() == 0)
-  # CHECK: Number of passed: 3
-  print("Number of passed:", passed)
+    i = mlir_pytaco.IndexVar()
+    j = mlir_pytaco.IndexVar()
+    A = mlir_pytaco.Tensor([2, 3])
+    B = mlir_pytaco.Tensor([2, 3])
+    C = mlir_pytaco.Tensor([2], _DENSE)
+    A.insert([0, 1], 10)
+    A.insert([1, 2], 40)
+    B.insert([0, 0], 20)
+    B.insert([1, 2], 30)
+    C[i] = A[i, j] + B[i, j]
+    indices, values = C.get_coordinates_and_values()
+    passed = np.array_equal(indices, [[0], [1]])
+    passed += np.array_equal(values, [30.0, 70.0])
+    # No temporary tensor is used.
+    passed += C._stats.get_total() == 0
+    # CHECK: Number of passed: 3
+    print("Number of passed:", passed)
 
 
 # CHECK-LABEL: test_binary_add_reduce_at_child
 @testing_utils.run_test
 def test_binary_add_reduce_at_child():
-  i = mlir_pytaco.IndexVar()
-  j = mlir_pytaco.IndexVar()
-  I = 2
-  J = 3
-  A = mlir_pytaco.Tensor([I, J])
-  B = mlir_pytaco.Tensor([J])
-  C = mlir_pytaco.Tensor([I])
-  D = mlir_pytaco.Tensor([I], _DENSE)
-
-  _init_2d(A, I, J)
-  _init_1d_with_value(C, I, 2)
-  _init_1d_with_value(B, J, 1)
-
-  D[i] = A[i, j] * B[j] + C[i]
-  indices, values = D.get_coordinates_and_values()
-  passed = np.array_equal(indices, [[0], [1]])
-  passed += np.array_equal(values, [8., 11.])
-
-  # The expression is implemented as:
-  #    temp0[i] = A[i, j] * B[i]
-  #    D[i] = temp0[i] + C[i]
-  # Check the temporary tensor introduced by the implementation.
-  stats = D._stats
-  passed += (stats.get_total() == 1)
-  passed += (stats.get_formats(0) == (_COMPRESSED,))
-  passed += (stats.get_dimensions(0) == (I,))
-  # CHECK: Number of passed: 5
-  print("Number of passed:", passed)
+    i = mlir_pytaco.IndexVar()
+    j = mlir_pytaco.IndexVar()
+    I = 2
+    J = 3
+    A = mlir_pytaco.Tensor([I, J])
+    B = mlir_pytaco.Tensor([J])
+    C = mlir_pytaco.Tensor([I])
+    D = mlir_pytaco.Tensor([I], _DENSE)
+
+    _init_2d(A, I, J)
+    _init_1d_with_value(C, I, 2)
+    _init_1d_with_value(B, J, 1)
+
+    D[i] = A[i, j] * B[j] + C[i]
+    indices, values = D.get_coordinates_and_values()
+    passed = np.array_equal(indices, [[0], [1]])
+    passed += np.array_equal(values, [8.0, 11.0])
+
+    # The expression is implemented as:
+    #    temp0[i] = A[i, j] * B[j]
+    #    D[i] = temp0[i] + C[i]
+    # Check the temporary tensor introduced by the implementation.
+    stats = D._stats
+    passed += stats.get_total() == 1
+    passed += stats.get_formats(0) == (_COMPRESSED,)
+    passed += stats.get_dimensions(0) == (I,)
+    # CHECK: Number of passed: 5
+    print("Number of passed:", passed)
 
 
 # CHECK-LABEL: test_binary_add_reduce_3d_1
 @testing_utils.run_test
 def test_binary_add_reduce_3d_1():
-  i, j, k, l = mlir_pytaco.get_index_vars(4)
-  I = 2
-  J = 3
-  K = 4
-  L = 5
-  A = mlir_pytaco.Tensor([I, J, K])
-  B = mlir_pytaco.Tensor([I, J, L])
-  C = mlir_pytaco.Tensor([K])
-  D = mlir_pytaco.Tensor([L])
-  E = mlir_pytaco.Tensor([I], _DENSE)
-
-  _init_3d(A, I, J, K)
-  _init_3d(B, I, J, L)
-  _init_1d_with_value(C, K, 1)
-  _init_1d_with_value(D, L, 2)
-
-  E[i] = A[i, j, k] * C[k] + B[i, j, l] * D[l]
-  indices, values = E.get_coordinates_and_values()
-  passed = np.array_equal(indices, [[0], [1]])
-  passed += np.array_equal(values, [162., 204.])
-
-  # The expression is implemented as:
-  #    temp0[i, j] = A[i, j, k] * C[k]
-  #    temp1[i, j] = B[i, j, l] * D[l]
-  #    E[i] = temp0[i, j] + temp1[i, j]
-  # Check the two temporary tensors introduced by the implementation.
-  stats = E._stats
-  passed += (stats.get_total() == 2)
-  passed += (stats.get_formats(0) == (_COMPRESSED, _COMPRESSED))
-  passed += (stats.get_dimensions(0) == (I, J))
-  passed += (stats.get_formats(1) == (_COMPRESSED, _COMPRESSED))
-  passed += (stats.get_dimensions(1) == (I, J))
-  # CHECK: Number of passed: 7
-  print("Number of passed:", passed)
+    i, j, k, l = mlir_pytaco.get_index_vars(4)
+    I = 2
+    J = 3
+    K = 4
+    L = 5
+    A = mlir_pytaco.Tensor([I, J, K])
+    B = mlir_pytaco.Tensor([I, J, L])
+    C = mlir_pytaco.Tensor([K])
+    D = mlir_pytaco.Tensor([L])
+    E = mlir_pytaco.Tensor([I], _DENSE)
+
+    _init_3d(A, I, J, K)
+    _init_3d(B, I, J, L)
+    _init_1d_with_value(C, K, 1)
+    _init_1d_with_value(D, L, 2)
+
+    E[i] = A[i, j, k] * C[k] + B[i, j, l] * D[l]
+    indices, values = E.get_coordinates_and_values()
+    passed = np.array_equal(indices, [[0], [1]])
+    passed += np.array_equal(values, [162.0, 204.0])
+
+    # The expression is implemented as:
+    #    temp0[i, j] = A[i, j, k] * C[k]
+    #    temp1[i, j] = B[i, j, l] * D[l]
+    #    E[i] = temp0[i, j] + temp1[i, j]
+    # Check the two temporary tensors introduced by the implementation.
+    stats = E._stats
+    passed += stats.get_total() == 2
+    passed += stats.get_formats(0) == (_COMPRESSED, _COMPRESSED)
+    passed += stats.get_dimensions(0) == (I, J)
+    passed += stats.get_formats(1) == (_COMPRESSED, _COMPRESSED)
+    passed += stats.get_dimensions(1) == (I, J)
+    # CHECK: Number of passed: 7
+    print("Number of passed:", passed)
 
 
 # CHECK-LABEL: test_binary_add_reduce_3d_2
 @testing_utils.run_test
 def test_binary_add_reduce_3d_2():
-  i, j, k, l = mlir_pytaco.get_index_vars(4)
-  I = 2
-  J = 3
-  K = 4
-  L = 5
-  A = mlir_pytaco.Tensor([I, J, K], [_COMPRESSED, _COMPRESSED, _DENSE])
-  B = mlir_pytaco.Tensor([I, L, K], [_DENSE, _COMPRESSED, _COMPRESSED])
-  C = mlir_pytaco.Tensor([J, K], [_COMPRESSED, _COMPRESSED])
-  D = mlir_pytaco.Tensor([L])
-  E = mlir_pytaco.Tensor([I], _DENSE)
-
-  _init_3d(A, I, J, K)
-  _init_3d(B, I, L, K)
-  _init_2d(C, J, K)
-  _init_1d_with_value(D, L, 2)
-
-  E[i] = A[i, j, k] + C[j, k] + B[i, l, k] * D[l]
-  indices, values = E.get_coordinates_and_values()
-  passed = np.array_equal(indices, [[0], [1]])
-  passed += np.array_equal(values, [264., 316.])
-
-  # The expression is implemented as:
-  #    temp0[i, k] = A[i, j, k] + C[j, k]
-  #    temp1[i, k] = B[i, l, k] * D[l]
-  #    E[i] = temp0[i, k] + temp1[i, k]
-  # Check the two temporary tensors introduced by the implementation.
-  stats = E._stats
-  passed += (stats.get_total() == 2)
-  passed += (stats.get_formats(0) == (_COMPRESSED, _DENSE))
-  passed += (stats.get_dimensions(0) == (I, K))
-  passed += (stats.get_formats(1) == (_DENSE, _COMPRESSED))
-  passed += (stats.get_dimensions(1) == (I, K))
-  # CHECK: Number of passed: 7
-  print("Number of passed:", passed)
+    i, j, k, l = mlir_pytaco.get_index_vars(4)
+    I = 2
+    J = 3
+    K = 4
+    L = 5
+    A = mlir_pytaco.Tensor([I, J, K], [_COMPRESSED, _COMPRESSED, _DENSE])
+    B = mlir_pytaco.Tensor([I, L, K], [_DENSE, _COMPRESSED, _COMPRESSED])
+    C = mlir_pytaco.Tensor([J, K], [_COMPRESSED, _COMPRESSED])
+    D = mlir_pytaco.Tensor([L])
+    E = mlir_pytaco.Tensor([I], _DENSE)
+
+    _init_3d(A, I, J, K)
+    _init_3d(B, I, L, K)
+    _init_2d(C, J, K)
+    _init_1d_with_value(D, L, 2)
+
+    E[i] = A[i, j, k] + C[j, k] + B[i, l, k] * D[l]
+    indices, values = E.get_coordinates_and_values()
+    passed = np.array_equal(indices, [[0], [1]])
+    passed += np.array_equal(values, [264.0, 316.0])
+
+    # The expression is implemented as:
+    #    temp0[i, k] = A[i, j, k] + C[j, k]
+    #    temp1[i, k] = B[i, l, k] * D[l]
+    #    E[i] = temp0[i, k] + temp1[i, k]
+    # Check the two temporary tensors introduced by the implementation.
+    stats = E._stats
+    passed += stats.get_total() == 2
+    passed += stats.get_formats(0) == (_COMPRESSED, _DENSE)
+    passed += stats.get_dimensions(0) == (I, K)
+    passed += stats.get_formats(1) == (_DENSE, _COMPRESSED)
+    passed += stats.get_dimensions(1) == (I, K)
+    # CHECK: Number of passed: 7
+    print("Number of passed:", passed)
index cce97d6..1d52747 100644 (file)
@@ -32,21 +32,21 @@ _MTX_DATA = """%%MatrixMarket matrix coordinate real general
 # CHECK-LABEL: test_read_mtx_matrix_general
 @testing_utils.run_test
 def test_read_mtx_matrix_general():
-  with tempfile.TemporaryDirectory() as test_dir:
-    file_name = os.path.join(test_dir, "data.mtx")
-    with open(file_name, "w") as file:
-      file.write(_MTX_DATA)
-    a = mlir_pytaco_io.read(file_name, _FORMAT)
-  passed = 0
-  # The value of a is stored as an MLIR sparse tensor.
-  passed += (not a.is_unpacked())
-  a.unpack()
-  passed += (a.is_unpacked())
-  coords, values = a.get_coordinates_and_values()
-  passed += np.array_equal(coords, [[0, 1], [2, 0], [2, 1]])
-  passed += np.allclose(values, [2.0, 3.0, 4.0])
-  # CHECK: 4
-  print(passed)
+    with tempfile.TemporaryDirectory() as test_dir:
+        file_name = os.path.join(test_dir, "data.mtx")
+        with open(file_name, "w") as file:
+            file.write(_MTX_DATA)
+        a = mlir_pytaco_io.read(file_name, _FORMAT)
+    passed = 0
+    # The value of a is stored as an MLIR sparse tensor.
+    passed += not a.is_unpacked()
+    a.unpack()
+    passed += a.is_unpacked()
+    coords, values = a.get_coordinates_and_values()
+    passed += np.array_equal(coords, [[0, 1], [2, 0], [2, 1]])
+    passed += np.allclose(values, [2.0, 3.0, 4.0])
+    # CHECK: 4
+    print(passed)
 
 
 _TNS_DATA = """2 3
@@ -60,57 +60,57 @@ _TNS_DATA = """2 3
 # CHECK-LABEL: test_read_tns
 @testing_utils.run_test
 def test_read_tns():
-  with tempfile.TemporaryDirectory() as test_dir:
-    file_name = os.path.join(test_dir, "data.tns")
-    with open(file_name, "w") as file:
-      file.write(_TNS_DATA)
-    a = mlir_pytaco_io.read(file_name, _FORMAT)
-  passed = 0
-  # The value of a is stored as an MLIR sparse tensor.
-  passed += (not a.is_unpacked())
-  a.unpack()
-  passed += (a.is_unpacked())
-  coords, values = a.get_coordinates_and_values()
-  passed += np.array_equal(coords, [[0, 1], [2, 0], [2, 1]])
-  passed += np.allclose(values, [2.0, 3.0, 4.0])
-  # CHECK: 4
-  print(passed)
+    with tempfile.TemporaryDirectory() as test_dir:
+        file_name = os.path.join(test_dir, "data.tns")
+        with open(file_name, "w") as file:
+            file.write(_TNS_DATA)
+        a = mlir_pytaco_io.read(file_name, _FORMAT)
+    passed = 0
+    # The value of a is stored as an MLIR sparse tensor.
+    passed += not a.is_unpacked()
+    a.unpack()
+    passed += a.is_unpacked()
+    coords, values = a.get_coordinates_and_values()
+    passed += np.array_equal(coords, [[0, 1], [2, 0], [2, 1]])
+    passed += np.allclose(values, [2.0, 3.0, 4.0])
+    # CHECK: 4
+    print(passed)
 
 
 # CHECK-LABEL: test_write_unpacked_tns
 @testing_utils.run_test
 def test_write_unpacked_tns():
-  a = mlir_pytaco.Tensor([2, 3])
-  a.insert([0, 1], 10)
-  a.insert([1, 2], 40)
-  a.insert([0, 0], 20)
-  with tempfile.TemporaryDirectory() as test_dir:
-    file_name = os.path.join(test_dir, "data.tns")
-    try:
-      mlir_pytaco_io.write(file_name, a)
-    except ValueError as e:
-      # CHECK: Writing unpacked sparse tensors to file is not supported
-      print(e)
+    a = mlir_pytaco.Tensor([2, 3])
+    a.insert([0, 1], 10)
+    a.insert([1, 2], 40)
+    a.insert([0, 0], 20)
+    with tempfile.TemporaryDirectory() as test_dir:
+        file_name = os.path.join(test_dir, "data.tns")
+        try:
+            mlir_pytaco_io.write(file_name, a)
+        except ValueError as e:
+            # CHECK: Writing unpacked sparse tensors to file is not supported
+            print(e)
 
 
 # CHECK-LABEL: test_write_packed_tns
 @testing_utils.run_test
 def test_write_packed_tns():
-  a = mlir_pytaco.Tensor([2, 3])
-  a.insert([0, 1], 10)
-  a.insert([1, 2], 40)
-  a.insert([0, 0], 20)
-  b = mlir_pytaco.Tensor([2, 3])
-  i, j = mlir_pytaco.get_index_vars(2)
-  b[i, j] = a[i, j] + a[i, j]
-  with tempfile.TemporaryDirectory() as test_dir:
-    file_name = os.path.join(test_dir, "data.tns")
-    mlir_pytaco_io.write(file_name, b)
-    with open(file_name, "r") as file:
-      lines = file.readlines()
-  passed = 0
-  # Skip the comment line in the output.
-  if lines[1:] == ["2 3\n", "2 3\n", "1 1 40\n", "1 2 20\n", "2 3 80\n"]:
-    passed = 1
-  # CHECK: 1
-  print(passed)
+    a = mlir_pytaco.Tensor([2, 3])
+    a.insert([0, 1], 10)
+    a.insert([1, 2], 40)
+    a.insert([0, 0], 20)
+    b = mlir_pytaco.Tensor([2, 3])
+    i, j = mlir_pytaco.get_index_vars(2)
+    b[i, j] = a[i, j] + a[i, j]
+    with tempfile.TemporaryDirectory() as test_dir:
+        file_name = os.path.join(test_dir, "data.tns")
+        mlir_pytaco_io.write(file_name, b)
+        with open(file_name, "r") as file:
+            lines = file.readlines()
+    passed = 0
+    # Skip the comment line in the output.
+    if lines[1:] == ["2 3\n", "2 3\n", "1 1 40\n", "1 2 20\n", "2 3 80\n"]:
+        passed = 1
+    # CHECK: 1
+    print(passed)
index 1325969..1344f4a 100644 (file)
@@ -20,79 +20,93 @@ _DENSE = mlir_pytaco.ModeFormat.DENSE
 
 
 def _to_string(s: Sequence[int]) -> str:
-  """Converts a sequence of integer to a space separated value string."""
-  return " ".join(map(lambda e: str(e), s))
+    """Converts a sequence of integer to a space separated value string."""
+    return " ".join(map(lambda e: str(e), s))
 
 
 def _add_one(s: Sequence[int]) -> Sequence[int]:
-  """Adds one to each element in the sequence of integer."""
-  return [i + 1 for i in s]
+    """Adds one to each element in the sequence of integer."""
+    return [i + 1 for i in s]
 
 
 @dataclasses.dataclass(frozen=True)
 class _SparseTensorCOO:
-  """Values for a COO-flavored format sparse tensor.
-
-  Attributes:
-    rank: An integer rank for the tensor.
-    nse: An integer for the number of non-zero values.
-    shape: A sequence of integer for the dimension size.
-    values: A sequence of float for the non-zero values of the tensor.
-    indices: A sequence of coordinate, each coordinate is a sequence of integer.
-  """
-  rank: int
-  nse: int
-  shape: Sequence[int]
-  values: Sequence[float]
-  indices: Sequence[Sequence[int]]
+    """Values for a COO-flavored format sparse tensor.
+
+    Attributes:
+      rank: An integer rank for the tensor.
+      nse: An integer for the number of non-zero values.
+      shape: A sequence of integers for the dimension sizes.
+      values: A sequence of floats for the non-zero values of the tensor.
+      indices: A sequence of coordinates; each coordinate is a sequence of integers.
+    """
+
+    rank: int
+    nse: int
+    shape: Sequence[int]
+    values: Sequence[float]
+    indices: Sequence[Sequence[int]]
 
 
 def _coo_values_to_tns_format(t: _SparseTensorCOO) -> str:
-  """Converts a sparse tensor COO-flavored values to TNS text format."""
-  # The coo_value_str contains one line for each (coordinate value) pair.
-  # Indices are 1-based in TNS text format but 0-based in MLIR.
-  coo_value_str = "\n".join(
-      map(lambda i: _to_string(_add_one(t.indices[i])) + " " + str(t.values[i]),
-          range(t.nse)))
-
-  # Returns the TNS text format representation for the tensor.
-  return f"""{t.rank} {t.nse}
+    """Converts a sparse tensor COO-flavored values to TNS text format."""
+    # The coo_value_str contains one line for each (coordinate, value) pair.
+    # Indices are 1-based in TNS text format but 0-based in MLIR.
+    coo_value_str = "\n".join(
+        map(
+            lambda i: _to_string(_add_one(t.indices[i])) + " " + str(t.values[i]),
+            range(t.nse),
+        )
+    )
+
+    # Returns the TNS text format representation for the tensor.
+    return f"""{t.rank} {t.nse}
 {_to_string(t.shape)}
 {coo_value_str}
 """
 
 
 def _implement_read_tns_test(
-    t: _SparseTensorCOO,
-    sparsity_codes: Sequence[sparse_tensor.DimLevelType]) -> int:
-  tns_data = _coo_values_to_tns_format(t)
-
-  # Write sparse tensor data to a file.
-  with tempfile.TemporaryDirectory() as test_dir:
-    file_name = os.path.join(test_dir, "data.tns")
-    with open(file_name, "w") as file:
-      file.write(tns_data)
-
-    # Read the data from the file and construct an MLIR sparse tensor.
-    sparse_tensor, o_shape = pytaco_utils.create_sparse_tensor(
-        file_name, sparsity_codes, "f64")
-
-  passed = 0
-
-  # Verify the output shape for the tensor.
-  if np.array_equal(o_shape, t.shape):
-    passed += 1
-
-  # Use the output MLIR sparse tensor pointer to retrieve the COO-flavored
-  # values and verify the values.
-  o_rank, o_nse, o_shape, o_values, o_indices = (
-      pytaco_utils.sparse_tensor_to_coo_tensor(sparse_tensor, np.float64))
-  if o_rank == t.rank and o_nse == t.nse and np.array_equal(
-      o_shape, t.shape) and np.allclose(o_values, t.values) and np.array_equal(
-          o_indices, t.indices):
-    passed += 1
-
-  return passed
+    t: _SparseTensorCOO, sparsity_codes: Sequence[sparse_tensor.DimLevelType]
+) -> int:
+    tns_data = _coo_values_to_tns_format(t)
+
+    # Write sparse tensor data to a file.
+    with tempfile.TemporaryDirectory() as test_dir:
+        file_name = os.path.join(test_dir, "data.tns")
+        with open(file_name, "w") as file:
+            file.write(tns_data)
+
+        # Read the data from the file and construct an MLIR sparse tensor.
+        sparse_tensor, o_shape = pytaco_utils.create_sparse_tensor(
+            file_name, sparsity_codes, "f64"
+        )
+
+    passed = 0
+
+    # Verify the output shape for the tensor.
+    if np.array_equal(o_shape, t.shape):
+        passed += 1
+
+    # Use the output MLIR sparse tensor pointer to retrieve the COO-flavored
+    # values and verify the values.
+    (
+        o_rank,
+        o_nse,
+        o_shape,
+        o_values,
+        o_indices,
+    ) = pytaco_utils.sparse_tensor_to_coo_tensor(sparse_tensor, np.float64)
+    if (
+        o_rank == t.rank
+        and o_nse == t.nse
+        and np.array_equal(o_shape, t.shape)
+        and np.allclose(o_values, t.values)
+        and np.array_equal(o_indices, t.indices)
+    ):
+        passed += 1
+
+    return passed
 
 
 # A 2D sparse tensor data in COO-flavored format.
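To illustrate the helpers above, a tiny hypothetical COO tensor and its TNS rendering (values chosen only for this example):

    t = _SparseTensorCOO(
        rank=2, nse=2, shape=[2, 2], values=[1.5, 2.5], indices=[[0, 0], [1, 1]]
    )
    print(_coo_values_to_tns_format(t))
    # 2 2
    # 2 2
    # 1 1 1.5
    # 2 2 2.5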
index 12c97db..70b4b66 100644 (file)
@@ -5,11 +5,11 @@ if not config.mlir_run_amx_tests:
     config.unsupported = True
 
 # No JIT on win32.
-if sys.platform == 'win32':
+if sys.platform == "win32":
     config.unsupported = True
 
 if config.intel_sde_executable:
     # Run test in emulator (Intel SDE): AMX needs Sapphire Rapids CPU.
-    config.substitutions.append(('%lli', config.intel_sde_executable + ' -spr -- lli'))
+    config.substitutions.append(("%lli", config.intel_sde_executable + " -spr -- lli"))
 else:
-    config.substitutions.append(('%lli', 'lli'))
+    config.substitutions.append(("%lli", "lli"))
index 0423fc0..296b441 100644 (file)
@@ -5,5 +5,5 @@ if not config.mlir_run_arm_sme_tests:
     config.unsupported = True
 
 # No JIT on win32.
-if sys.platform == 'win32':
+if sys.platform == "win32":
     config.unsupported = True
index 8a0d884..37d3a74 100644 (file)
@@ -5,5 +5,5 @@ if not config.mlir_run_arm_sve_tests:
     config.unsupported = True
 
 # No JIT on win32.
-if sys.platform == 'win32':
+if sys.platform == "win32":
     config.unsupported = True
index 0e22874..bde8156 100644 (file)
@@ -5,11 +5,11 @@ if not config.mlir_run_x86vector_tests:
     config.unsupported = True
 
 # No JIT on win32.
-if sys.platform == 'win32':
+if sys.platform == "win32":
     config.unsupported = True
 
 if config.intel_sde_executable:
     # Run test in emulator (Intel SDE).
-    config.substitutions.append(('%lli', config.intel_sde_executable + ' -tgl -- lli'))
+    config.substitutions.append(("%lli", config.intel_sde_executable + " -tgl -- lli"))
 else:
-    config.substitutions.append(('%lli', 'lli'))
+    config.substitutions.append(("%lli", "lli"))
index 451b9fc..3bd7024 100644 (file)
@@ -2,4 +2,4 @@ import sys
 
 # TensorCore tests must be enabled via build flag.
 if not config.mlir_run_cuda_tensor_core_tests:
-  config.unsupported = True
+    config.unsupported = True
index 0bdebfe..acb8dd4 100644 (file)
@@ -1,2 +1,2 @@
 if not config.enable_cuda_runner:
-  config.unsupported = True
+    config.unsupported = True
index b0d086f..e1f8648 100644 (file)
@@ -1,4 +1,4 @@
 if not config.enable_rocm_runner or not config.rocm_test_chipset:
-  config.unsupported = True
+    config.unsupported = True
 
-config.substitutions.append(('%chip', config.rocm_test_chipset))
+config.substitutions.append(("%chip", config.rocm_test_chipset))
index 80a862a..1b4a323 100644 (file)
@@ -3,8 +3,9 @@ from lit.llvm import llvm_config
 if not config.mlir_include_integration_tests:
     config.unsupported = True
 
+
 def configure_aarch64_lli_cmd():
-    lli_cmd = 'lli'
+    lli_cmd = "lli"
 
     # NOTE: If the SVE tests are disabled and the SME tests are enabled to run
     # under emulation, the SVE specific RUN lines in the SparseTensor tests
@@ -12,8 +13,12 @@ def configure_aarch64_lli_cmd():
     if not (config.mlir_run_arm_sve_tests or config.mlir_run_arm_sme_tests):
         return lli_cmd
 
-    config.substitutions.append(('%mlir_native_utils_lib_dir',
-        config.arm_emulator_utils_lib_dir or config.mlir_lib_dir))
+    config.substitutions.append(
+        (
+            "%mlir_native_utils_lib_dir",
+            config.arm_emulator_utils_lib_dir or config.mlir_lib_dir,
+        )
+    )
 
     if config.arm_emulator_executable:
         if config.arm_emulator_lli_executable:
@@ -23,16 +28,22 @@ def configure_aarch64_lli_cmd():
             # when running under an emulator. If the user didn't specify an lli
             # executable, use absolute path %llvm_tools_dir/lli.
             lli_cmd = llvm_config.use_llvm_tool(
-                'lli', search_env='LLI', required=True,
-                search_paths=[config.llvm_tools_dir], use_installed=False
+                "lli",
+                search_env="LLI",
+                required=True,
+                search_paths=[config.llvm_tools_dir],
+                use_installed=False,
             )
 
         # Run test in emulator (qemu or armie)
-        emulation_cmd = f'{config.arm_emulator_executable} {config.arm_emulator_options}'
-        lli_cmd = f'{emulation_cmd} {lli_cmd}'
+        emulation_cmd = (
+            f"{config.arm_emulator_executable} {config.arm_emulator_options}"
+        )
+        lli_cmd = f"{emulation_cmd} {lli_cmd}"
 
     return lli_cmd
 
+
 aarch64_lli_cmd = configure_aarch64_lli_cmd()
 
 # Configure the following AArch64 substitutions:
@@ -52,5 +63,5 @@ aarch64_lli_cmd = configure_aarch64_lli_cmd()
 # could be used in the SparseTensor tests where necessary, but the meaning
 # conveyed by the substitution name would be a misnomer if the host target
 # is not AArch64 and MLIR_RUN_ARM_SVE_TESTS=OFF.
-config.substitutions.append(('%lli_aarch64_cmd', aarch64_lli_cmd))
-config.substitutions.append(('%lli_host_or_aarch64_cmd', aarch64_lli_cmd))
+config.substitutions.append(("%lli_aarch64_cmd", aarch64_lli_cmd))
+config.substitutions.append(("%lli_host_or_aarch64_cmd", aarch64_lli_cmd))
index 5b66517..1898b72 100644 (file)
@@ -8,43 +8,43 @@ import subprocess
 import lit.formats
 
 # name: The name of this test suite.
-config.name = 'MLIR-Unit'
+config.name = "MLIR-Unit"
 
 # suffixes: A list of file extensions to treat as test files.
 config.suffixes = []
 
 # test_source_root: The root path where tests are located.
 # test_exec_root: The root path where tests should be run.
-config.test_exec_root = os.path.join(config.mlir_obj_root, 'unittests')
+config.test_exec_root = os.path.join(config.mlir_obj_root, "unittests")
 config.test_source_root = config.test_exec_root
 
 # testFormat: The test format to use to interpret tests.
-config.test_format = lit.formats.GoogleTest(config.llvm_build_mode, 'Tests')
+config.test_format = lit.formats.GoogleTest(config.llvm_build_mode, "Tests")
 
 # Propagate the temp directory. Windows requires this because it uses \Windows\
 # if none of these are present.
-if 'TMP' in os.environ:
-    config.environment['TMP'] = os.environ['TMP']
-if 'TEMP' in os.environ:
-    config.environment['TEMP'] = os.environ['TEMP']
+if "TMP" in os.environ:
+    config.environment["TMP"] = os.environ["TMP"]
+if "TEMP" in os.environ:
+    config.environment["TEMP"] = os.environ["TEMP"]
 
 # Propagate HOME as it can be used to override incorrect homedir in passwd
 # that causes the tests to fail.
-if 'HOME' in os.environ:
-    config.environment['HOME'] = os.environ['HOME']
+if "HOME" in os.environ:
+    config.environment["HOME"] = os.environ["HOME"]
 
 # Propagate sanitizer options.
 for var in [
-    'ASAN_SYMBOLIZER_PATH',
-    'HWASAN_SYMBOLIZER_PATH',
-    'MSAN_SYMBOLIZER_PATH',
-    'TSAN_SYMBOLIZER_PATH',
-    'UBSAN_SYMBOLIZER_PATH',
-    'ASAN_OPTIONS',
-    'HWASAN_OPTIONS',
-    'MSAN_OPTIONS',
-    'TSAN_OPTIONS',
-    'UBSAN_OPTIONS',
+    "ASAN_SYMBOLIZER_PATH",
+    "HWASAN_SYMBOLIZER_PATH",
+    "MSAN_SYMBOLIZER_PATH",
+    "TSAN_SYMBOLIZER_PATH",
+    "UBSAN_SYMBOLIZER_PATH",
+    "ASAN_OPTIONS",
+    "HWASAN_OPTIONS",
+    "MSAN_OPTIONS",
+    "TSAN_OPTIONS",
+    "UBSAN_OPTIONS",
 ]:
     if var in os.environ:
         config.environment[var] = os.environ[var]
index edb5b44..65a7f20 100644 (file)
@@ -1 +1 @@
-config.suffixes.remove('.td')
\ No newline at end of file
+config.suffixes.remove(".td")
index edb5b44..65a7f20 100644 (file)
@@ -1 +1 @@
-config.suffixes.remove('.td')
\ No newline at end of file
+config.suffixes.remove(".td")
index 8cfe5cd..8ffccee 100644 (file)
@@ -1 +1 @@
-config.suffixes.remove('.pdll')
+config.suffixes.remove(".pdll")
index 8cfe5cd..8ffccee 100644 (file)
@@ -1 +1 @@
-config.suffixes.remove('.pdll')
+config.suffixes.remove(".pdll")
index 1fc2e31..ad0b0d5 100644 (file)
@@ -16,21 +16,32 @@ from lit.llvm.subst import FindTool
 # Configuration file for the 'lit' test runner.
 
 # name: The name of this test suite.
-config.name = 'MLIR'
+config.name = "MLIR"
 
 config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
 
 # suffixes: A list of file extensions to treat as test files.
-config.suffixes = ['.td', '.mlir', '.toy', '.ll', '.tc', '.py', '.yaml', '.test', '.pdll', '.c']
+config.suffixes = [
+    ".td",
+    ".mlir",
+    ".toy",
+    ".ll",
+    ".tc",
+    ".py",
+    ".yaml",
+    ".test",
+    ".pdll",
+    ".c",
+]
 
 # test_source_root: The root path where tests are located.
 config.test_source_root = os.path.dirname(__file__)
 
 # test_exec_root: The root path where tests should be run.
-config.test_exec_root = os.path.join(config.mlir_obj_root, 'test')
+config.test_exec_root = os.path.join(config.mlir_obj_root, "test")
 
-config.substitutions.append(('%PATH%', config.environment['PATH']))
-config.substitutions.append(('%shlibext', config.llvm_shlib_ext))
+config.substitutions.append(("%PATH%", config.environment["PATH"]))
+config.substitutions.append(("%shlibext", config.llvm_shlib_ext))
 config.substitutions.append(("%mlir_src_root", config.mlir_src_root))
 config.substitutions.append(("%host_cxx", config.host_cxx))
 config.substitutions.append(("%host_cc", config.host_cc))
@@ -40,94 +51,109 @@ config.substitutions.append(("%host_cc", config.host_cc))
 # substitution of the same name and the found path.
 # Correctly handles the platform's shared library directory and naming conventions.
 def add_runtime(name):
-    path = ''
-    for prefix in ['', 'lib']:
-        path = os.path.join(config.llvm_shlib_dir, f'{prefix}{name}{config.llvm_shlib_ext}')
+    path = ""
+    for prefix in ["", "lib"]:
+        path = os.path.join(
+            config.llvm_shlib_dir, f"{prefix}{name}{config.llvm_shlib_ext}"
+        )
         if os.path.isfile(path):
             break
-    return ToolSubst(f'%{name}', path)
+    return ToolSubst(f"%{name}", path)
 
 
-llvm_config.with_system_environment(
-    ['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP'])
+llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"])
 
 llvm_config.use_default_substitutions()
 
 # excludes: A list of directories to exclude from the testsuite. The 'Inputs'
 # subdirectories contain auxiliary inputs for various tests in their parent
 # directories.
-config.excludes = ['Inputs', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt',
-                   'lit.cfg.py', 'lit.site.cfg.py']
+config.excludes = [
+    "Inputs",
+    "CMakeLists.txt",
+    "README.txt",
+    "LICENSE.txt",
+    "lit.cfg.py",
+    "lit.site.cfg.py",
+]
 
 # Tweak the PATH to include the tools dir.
-llvm_config.with_environment('PATH', config.mlir_tools_dir, append_path=True)
-llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
+llvm_config.with_environment("PATH", config.mlir_tools_dir, append_path=True)
+llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True)
 
 tool_dirs = [config.mlir_tools_dir, config.llvm_tools_dir]
 tools = [
-    'mlir-tblgen',
-    'mlir-translate',
-    'mlir-lsp-server',
-    'mlir-capi-execution-engine-test',
-    'mlir-capi-ir-test',
-    'mlir-capi-llvm-test',
-    'mlir-capi-pass-test',
-    'mlir-capi-pdl-test',
-    'mlir-capi-quant-test',
-    'mlir-capi-sparse-tensor-test',
-    'mlir-capi-transform-test',
-    'mlir-cpu-runner',
-    add_runtime('mlir_runner_utils'),
-    add_runtime('mlir_c_runner_utils'),
-    add_runtime('mlir_async_runtime'),
-    'mlir-linalg-ods-yaml-gen',
-    'mlir-reduce',
-    'mlir-pdll',
-    'not',
+    "mlir-tblgen",
+    "mlir-translate",
+    "mlir-lsp-server",
+    "mlir-capi-execution-engine-test",
+    "mlir-capi-ir-test",
+    "mlir-capi-llvm-test",
+    "mlir-capi-pass-test",
+    "mlir-capi-pdl-test",
+    "mlir-capi-quant-test",
+    "mlir-capi-sparse-tensor-test",
+    "mlir-capi-transform-test",
+    "mlir-cpu-runner",
+    add_runtime("mlir_runner_utils"),
+    add_runtime("mlir_c_runner_utils"),
+    add_runtime("mlir_async_runtime"),
+    "mlir-linalg-ods-yaml-gen",
+    "mlir-reduce",
+    "mlir-pdll",
+    "not",
 ]
 
 if config.enable_spirv_cpu_runner:
-    tools.extend(['mlir-spirv-cpu-runner', add_runtime('mlir_test_spirv_cpu_runner_c_wrappers')])
+    tools.extend(
+        ["mlir-spirv-cpu-runner", add_runtime("mlir_test_spirv_cpu_runner_c_wrappers")]
+    )
 
 if config.enable_vulkan_runner:
-    tools.extend([add_runtime('vulkan-runtime-wrappers')])
+    tools.extend([add_runtime("vulkan-runtime-wrappers")])
 
 if config.enable_rocm_runner:
-    tools.extend([add_runtime('mlir_rocm_runtime')])
+    tools.extend([add_runtime("mlir_rocm_runtime")])
 
 if config.enable_cuda_runner:
-    tools.extend([add_runtime('mlir_cuda_runtime')])
+    tools.extend([add_runtime("mlir_cuda_runtime")])
 
 # The following tools are optional
-tools.extend([
-    ToolSubst('toyc-ch1', unresolved='ignore'),
-    ToolSubst('toyc-ch2', unresolved='ignore'),
-    ToolSubst('toyc-ch3', unresolved='ignore'),
-    ToolSubst('toyc-ch4', unresolved='ignore'),
-    ToolSubst('toyc-ch5', unresolved='ignore'),
-    ToolSubst('toyc-ch6', unresolved='ignore'),
-    ToolSubst('toyc-ch7', unresolved='ignore'),
-    ToolSubst('%mlir_lib_dir', config.mlir_lib_dir, unresolved='ignore'),
-    ToolSubst('%mlir_src_dir', config.mlir_src_root, unresolved='ignore'),
-])
+tools.extend(
+    [
+        ToolSubst("toyc-ch1", unresolved="ignore"),
+        ToolSubst("toyc-ch2", unresolved="ignore"),
+        ToolSubst("toyc-ch3", unresolved="ignore"),
+        ToolSubst("toyc-ch4", unresolved="ignore"),
+        ToolSubst("toyc-ch5", unresolved="ignore"),
+        ToolSubst("toyc-ch6", unresolved="ignore"),
+        ToolSubst("toyc-ch7", unresolved="ignore"),
+        ToolSubst("%mlir_lib_dir", config.mlir_lib_dir, unresolved="ignore"),
+        ToolSubst("%mlir_src_dir", config.mlir_src_root, unresolved="ignore"),
+    ]
+)
 
 python_executable = config.python_executable
 # Python configuration with sanitizer requires some magic preloading. This will only work on clang/linux.
 # TODO: detect Darwin/Windows situation (or mark these tests as unsupported on these platforms).
 if "asan" in config.available_features and "Linux" in config.host_os:
-  python_executable = f"LD_PRELOAD=$({config.host_cxx} -print-file-name=libclang_rt.asan-{config.host_arch}.so) {config.python_executable}"
+    python_executable = f"LD_PRELOAD=$({config.host_cxx} -print-file-name=libclang_rt.asan-{config.host_arch}.so) {config.python_executable}"
 # On Windows the path to python could contains spaces in which case it needs to be provided in quotes.
 # This is the equivalent of how %python is setup in llvm/utils/lit/lit/llvm/config.py.
 elif "Windows" in config.host_os:
-  python_executable = '"%s"' % (python_executable)
-tools.extend([
-  ToolSubst('%PYTHON', python_executable, unresolved='ignore'),
-])
+    python_executable = '"%s"' % (python_executable)
+tools.extend(
+    [
+        ToolSubst("%PYTHON", python_executable, unresolved="ignore"),
+    ]
+)
 
 if "MLIR_OPT_CHECK_IR_ROUNDTRIP" in os.environ:
-  tools.extend([
-    ToolSubst('mlir-opt', 'mlir-opt --verify-roundtrip', unresolved='fatal'),
-  ])
+    tools.extend(
+        [
+            ToolSubst("mlir-opt", "mlir-opt --verify-roundtrip", unresolved="fatal"),
+        ]
+    )
 
 llvm_config.add_tool_substitutions(tools, tool_dirs)
 
@@ -135,40 +161,48 @@ llvm_config.add_tool_substitutions(tools, tool_dirs)
 # FileCheck -enable-var-scope is enabled by default in MLIR test
 # This option avoids to accidentally reuse variable across -LABEL match,
 # it can be explicitly opted-in by prefixing the variable name with $
-config.environment['FILECHECK_OPTS'] = "-enable-var-scope --allow-unused-prefixes=false"
+config.environment["FILECHECK_OPTS"] = "-enable-var-scope --allow-unused-prefixes=false"
 
 # Add the python path for both the source and binary tree.
 # Note that presently, the python sources come from the source tree and the
 # binaries come from the build tree. This should be unified to the build tree
 # by copying/linking sources to build.
 if config.enable_bindings_python:
-  llvm_config.with_environment('PYTHONPATH', [
-      os.path.join(config.mlir_obj_root, 'python_packages', 'mlir_core'),
-      os.path.join(config.mlir_obj_root, 'python_packages', 'mlir_test'),
-  ], append_path=True)
+    llvm_config.with_environment(
+        "PYTHONPATH",
+        [
+            os.path.join(config.mlir_obj_root, "python_packages", "mlir_core"),
+            os.path.join(config.mlir_obj_root, "python_packages", "mlir_test"),
+        ],
+        append_path=True,
+    )
 
 if config.enable_assertions:
-  config.available_features.add('asserts')
+    config.available_features.add("asserts")
 else:
-  config.available_features.add('noasserts')
+    config.available_features.add("noasserts")
+
 
 def have_host_jit_feature_support(feature_name):
-  mlir_cpu_runner_exe = lit.util.which('mlir-cpu-runner', config.mlir_tools_dir)
+    mlir_cpu_runner_exe = lit.util.which("mlir-cpu-runner", config.mlir_tools_dir)
+
+    if not mlir_cpu_runner_exe:
+        return False
 
-  if not mlir_cpu_runner_exe:
-    return False
+    try:
+        mlir_cpu_runner_cmd = subprocess.Popen(
+            [mlir_cpu_runner_exe, "--host-supports-" + feature_name],
+            stdout=subprocess.PIPE,
+        )
+    except OSError:
+        print("could not exec mlir-cpu-runner")
+        return False
 
-  try:
-    mlir_cpu_runner_cmd = subprocess.Popen(
-        [mlir_cpu_runner_exe, '--host-supports-' + feature_name], stdout=subprocess.PIPE)
-  except OSError:
-    print('could not exec mlir-cpu-runner')
-    return False
+    mlir_cpu_runner_out = mlir_cpu_runner_cmd.stdout.read().decode("ascii")
+    mlir_cpu_runner_cmd.wait()
 
-  mlir_cpu_runner_out = mlir_cpu_runner_cmd.stdout.read().decode('ascii')
-  mlir_cpu_runner_cmd.wait()
+    return "true" in mlir_cpu_runner_out
 
-  return 'true' in mlir_cpu_runner_out
 
-if have_host_jit_feature_support('jit'):
-  config.available_features.add('host-supports-jit')
+if have_host_jit_feature_support("jit"):
+    config.available_features.add("host-supports-jit")
index 3f59ff1..3c20d20 100644 (file)
@@ -1,12 +1,11 @@
 import sys
 
 # MSAN does not work with JIT.
-if 'msan' in config.available_features:
-  config.unsupported = True
+if "msan" in config.available_features:
+    config.unsupported = True
 
 # Requires native execution.
-if 'host-supports-jit' not in config.available_features:
+if "host-supports-jit" not in config.available_features:
     config.unsupported = True
 
-config.available_features.add(
-        config.root.native_target.lower() + '-native-target')
+config.available_features.add(config.root.native_target.lower() + "-native-target")
index c438027..4cb5622 100644 (file)
@@ -1,2 +1,2 @@
-config.suffixes = ['.pdll', '.mlir']
-config.excludes = ['include']
+config.suffixes = [".pdll", ".mlir"]
+config.excludes = ["include"]
index 286bea4..8717dd0 100644 (file)
@@ -1,4 +1,4 @@
 import sys
 
 if not config.enable_spirv_cpu_runner:
-  config.unsupported = True
+    config.unsupported = True
index f99be2a..6da7fcd 100644 (file)
@@ -1,2 +1,2 @@
 if not config.enable_vulkan_runner:
-  config.unsupported = True
+    config.unsupported = True
index ea0a911..4dc3a0b 100644 (file)
@@ -14,5 +14,6 @@ expected_lib_name = "MLIRPythonCAPI"
 all_libs = os.listdir(get_lib_dirs()[0])
 found_lib = False
 for file_name in all_libs:
-  if expected_lib_name in file_name: found_lib = True
+    if expected_lib_name in file_name:
+        found_lib = True
 assert found_lib, f"Did not find '{expected_lib_name}' lib in {all_libs}"
index acae9b6..8e9613d 100644 (file)
@@ -4,16 +4,18 @@ from mlir.ir import *
 import mlir.dialects.func as func
 import mlir.dialects.arith as arith
 
+
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
+    print("\nTEST:", f.__name__)
+    f()
+
 
 # CHECK-LABEL: TEST: testConstantOp
 @run
 def testConstantOps():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    with InsertionPoint(module.body):
-      arith.ConstantOp(value=42.42, result=F32Type.get())
-    # CHECK:         %cst = arith.constant 4.242000e+01 : f32
-    print(module)
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            arith.ConstantOp(value=42.42, result=F32Type.get())
+        # CHECK:         %cst = arith.constant 4.242000e+01 : f32
+        print(module)
index da3103c..f6181cc 100644 (file)
@@ -5,14 +5,17 @@ import mlir.dialects.async_dialect
 import mlir.dialects.async_dialect.passes
 from mlir.passmanager import *
 
+
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
+    print("\nTEST:", f.__name__)
+    f()
+
 
 def testAsyncPass():
-  with Context() as context:
-    PassManager.parse('any(async-to-async-runtime)')
-  print('SUCCESS')
+    with Context() as context:
+        PassManager.parse("any(async-to-async-runtime)")
+    print("SUCCESS")
+
 
 # CHECK-LABEL: testAsyncPass
 #       CHECK: SUCCESS
index eab24b5..18ebba6 100644 (file)
@@ -7,232 +7,242 @@ import numpy as np
 
 
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    return f
 
 
 # CHECK-LABEL: TEST: testFromPyFunc
 @run
 def testFromPyFunc():
-  with Context() as ctx, Location.unknown() as loc:
-    ctx.allow_unregistered_dialects = True
-    m = builtin.ModuleOp()
-    f32 = F32Type.get()
-    f64 = F64Type.get()
-    with InsertionPoint(m.body):
-      # CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
-      # CHECK: return %arg0 : f64
-      @func.FuncOp.from_py_func(f64)
-      def unary_return(a):
-        return a
-
-      # CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64)
-      # CHECK: return %arg0, %arg1 : f32, f64
-      @func.FuncOp.from_py_func(f32, f64)
-      def binary_return(a, b):
-        return a, b
-
-      # CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64)
-      # CHECK: return
-      @func.FuncOp.from_py_func(f32, f64)
-      def none_return(a, b):
-        pass
-
-      # CHECK-LABEL: func @call_unary
-      # CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64
-      # CHECK: return %0 : f64
-      @func.FuncOp.from_py_func(f64)
-      def call_unary(a):
-        return unary_return(a)
-
-      # CHECK-LABEL: func @call_binary
-      # CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64)
-      # CHECK: return %0#0, %0#1 : f32, f64
-      @func.FuncOp.from_py_func(f32, f64)
-      def call_binary(a, b):
-        return binary_return(a, b)
-
-      # We expect coercion of a single result operation to a returned value.
-      # CHECK-LABEL: func @single_result_op
-      # CHECK: %0 = "custom.op1"() : () -> f32
-      # CHECK: return %0 : f32
-      @func.FuncOp.from_py_func()
-      def single_result_op():
-        return Operation.create("custom.op1", results=[f32])
-
-      # CHECK-LABEL: func @call_none
-      # CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
-      # CHECK: return
-      @func.FuncOp.from_py_func(f32, f64)
-      def call_none(a, b):
-        return none_return(a, b)
-
-      ## Variants and optional feature tests.
-      # CHECK-LABEL: func @from_name_arg
-      @func.FuncOp.from_py_func(f32, f64, name="from_name_arg")
-      def explicit_name(a, b):
-        return b
-
-      @func.FuncOp.from_py_func(f32, f64)
-      def positional_func_op(a, b, func_op):
-        assert isinstance(func_op, func.FuncOp)
-        return b
-
-      @func.FuncOp.from_py_func(f32, f64)
-      def kw_func_op(a, b=None, func_op=None):
-        assert isinstance(func_op, func.FuncOp)
-        return b
-
-      @func.FuncOp.from_py_func(f32, f64)
-      def kwargs_func_op(a, b=None, **kwargs):
-        assert isinstance(kwargs["func_op"], func.FuncOp)
-        return b
-
-      # CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64
-      # CHECK: return %arg1 : f64
-      @func.FuncOp.from_py_func(f32, f64, results=[f64])
-      def explicit_results(a, b):
-        func.ReturnOp([b])
-
-  print(m)
+    with Context() as ctx, Location.unknown() as loc:
+        ctx.allow_unregistered_dialects = True
+        m = builtin.ModuleOp()
+        f32 = F32Type.get()
+        f64 = F64Type.get()
+        with InsertionPoint(m.body):
+            # CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
+            # CHECK: return %arg0 : f64
+            @func.FuncOp.from_py_func(f64)
+            def unary_return(a):
+                return a
+
+            # CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64)
+            # CHECK: return %arg0, %arg1 : f32, f64
+            @func.FuncOp.from_py_func(f32, f64)
+            def binary_return(a, b):
+                return a, b
+
+            # CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64)
+            # CHECK: return
+            @func.FuncOp.from_py_func(f32, f64)
+            def none_return(a, b):
+                pass
+
+            # CHECK-LABEL: func @call_unary
+            # CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64
+            # CHECK: return %0 : f64
+            @func.FuncOp.from_py_func(f64)
+            def call_unary(a):
+                return unary_return(a)
+
+            # CHECK-LABEL: func @call_binary
+            # CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64)
+            # CHECK: return %0#0, %0#1 : f32, f64
+            @func.FuncOp.from_py_func(f32, f64)
+            def call_binary(a, b):
+                return binary_return(a, b)
+
+            # We expect coercion of a single result operation to a returned value.
+            # CHECK-LABEL: func @single_result_op
+            # CHECK: %0 = "custom.op1"() : () -> f32
+            # CHECK: return %0 : f32
+            @func.FuncOp.from_py_func()
+            def single_result_op():
+                return Operation.create("custom.op1", results=[f32])
+
+            # CHECK-LABEL: func @call_none
+            # CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
+            # CHECK: return
+            @func.FuncOp.from_py_func(f32, f64)
+            def call_none(a, b):
+                return none_return(a, b)
+
+            ## Variants and optional feature tests.
+            # CHECK-LABEL: func @from_name_arg
+            @func.FuncOp.from_py_func(f32, f64, name="from_name_arg")
+            def explicit_name(a, b):
+                return b
+
+            @func.FuncOp.from_py_func(f32, f64)
+            def positional_func_op(a, b, func_op):
+                assert isinstance(func_op, func.FuncOp)
+                return b
+
+            @func.FuncOp.from_py_func(f32, f64)
+            def kw_func_op(a, b=None, func_op=None):
+                assert isinstance(func_op, func.FuncOp)
+                return b
+
+            @func.FuncOp.from_py_func(f32, f64)
+            def kwargs_func_op(a, b=None, **kwargs):
+                assert isinstance(kwargs["func_op"], func.FuncOp)
+                return b
+
+            # CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64
+            # CHECK: return %arg1 : f64
+            @func.FuncOp.from_py_func(f32, f64, results=[f64])
+            def explicit_results(a, b):
+                func.ReturnOp([b])
+
+    print(m)
 
 
 # CHECK-LABEL: TEST: testFromPyFuncErrors
 @run
 def testFromPyFuncErrors():
-  with Context() as ctx, Location.unknown() as loc:
-    m = builtin.ModuleOp()
-    f32 = F32Type.get()
-    f64 = F64Type.get()
-    with InsertionPoint(m.body):
-      try:
+    with Context() as ctx, Location.unknown() as loc:
+        m = builtin.ModuleOp()
+        f32 = F32Type.get()
+        f64 = F64Type.get()
+        with InsertionPoint(m.body):
+            try:
 
-        @func.FuncOp.from_py_func(f64, results=[f64])
-        def unary_return(a):
-          return a
-      except AssertionError as e:
-        # CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None.
-        print(e)
+                @func.FuncOp.from_py_func(f64, results=[f64])
+                def unary_return(a):
+                    return a
+
+            except AssertionError as e:
+                # CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None.
+                print(e)
 
 
 # CHECK-LABEL: TEST: testBuildFuncOp
 @run
 def testBuildFuncOp():
-  ctx = Context()
-  with Location.unknown(ctx) as loc:
-    m = builtin.ModuleOp()
-
-    f32 = F32Type.get()
-    tensor_type = RankedTensorType.get((2, 3, 4), f32)
-    with InsertionPoint.at_block_begin(m.body):
-      f = func.FuncOp(name="some_func",
-                            type=FunctionType.get(
-                                inputs=[tensor_type, tensor_type],
-                                results=[tensor_type]),
-                            visibility="nested")
-      # CHECK: Name is: "some_func"
-      print("Name is: ", f.name)
-
-      # CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
-      print("Type is: ", f.type)
-
-      # CHECK: Visibility is: "nested"
-      print("Visibility is: ", f.visibility)
-
-      try:
-        entry_block = f.entry_block
-      except IndexError as e:
-        # CHECK: External function does not have a body
-        print(e)
-
-      with InsertionPoint(f.add_entry_block()):
-        func.ReturnOp([f.entry_block.arguments[0]])
-        pass
-
-      try:
-        f.add_entry_block()
-      except IndexError as e:
-        # CHECK: The function already has an entry block!
-        print(e)
-
-      # Try the callback builder and passing type as tuple.
-      f = func.FuncOp(name="some_other_func",
-                            type=([tensor_type, tensor_type], [tensor_type]),
-                            visibility="nested",
-                            body_builder=lambda f: func.ReturnOp(
-                                [f.entry_block.arguments[0]]))
-
-  # CHECK: module  {
-  # CHECK:  func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
-  # CHECK:   return %arg0 : tensor<2x3x4xf32>
-  # CHECK:  }
-  # CHECK:  func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
-  # CHECK:   return %arg0 : tensor<2x3x4xf32>
-  # CHECK:  }
-  print(m)
+    ctx = Context()
+    with Location.unknown(ctx) as loc:
+        m = builtin.ModuleOp()
+
+        f32 = F32Type.get()
+        tensor_type = RankedTensorType.get((2, 3, 4), f32)
+        with InsertionPoint.at_block_begin(m.body):
+            f = func.FuncOp(
+                name="some_func",
+                type=FunctionType.get(
+                    inputs=[tensor_type, tensor_type], results=[tensor_type]
+                ),
+                visibility="nested",
+            )
+            # CHECK: Name is: "some_func"
+            print("Name is: ", f.name)
+
+            # CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
+            print("Type is: ", f.type)
+
+            # CHECK: Visibility is: "nested"
+            print("Visibility is: ", f.visibility)
+
+            try:
+                entry_block = f.entry_block
+            except IndexError as e:
+                # CHECK: External function does not have a body
+                print(e)
+
+            with InsertionPoint(f.add_entry_block()):
+                func.ReturnOp([f.entry_block.arguments[0]])
+                pass
+
+            try:
+                f.add_entry_block()
+            except IndexError as e:
+                # CHECK: The function already has an entry block!
+                print(e)
+
+            # Try the callback builder and passing type as tuple.
+            f = func.FuncOp(
+                name="some_other_func",
+                type=([tensor_type, tensor_type], [tensor_type]),
+                visibility="nested",
+                body_builder=lambda f: func.ReturnOp([f.entry_block.arguments[0]]),
+            )
+
+    # CHECK: module  {
+    # CHECK:  func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+    # CHECK:   return %arg0 : tensor<2x3x4xf32>
+    # CHECK:  }
+    # CHECK:  func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+    # CHECK:   return %arg0 : tensor<2x3x4xf32>
+    # CHECK:  }
+    print(m)
 
 
 # CHECK-LABEL: TEST: testFuncArgumentAccess
 @run
 def testFuncArgumentAccess():
-  with Context() as ctx, Location.unknown():
-    ctx.allow_unregistered_dialects = True
-    module = Module.create()
-    f32 = F32Type.get()
-    f64 = F64Type.get()
-    with InsertionPoint(module.body):
-      f = func.FuncOp("some_func", ([f32, f32], [f32, f32]))
-      with InsertionPoint(f.add_entry_block()):
-        func.ReturnOp(f.arguments)
-      f.arg_attrs = ArrayAttr.get([
-          DictAttr.get({
-              "custom_dialect.foo": StringAttr.get("bar"),
-              "custom_dialect.baz": UnitAttr.get()
-          }),
-          DictAttr.get({"custom_dialect.qux": ArrayAttr.get([])})
-      ])
-      f.result_attrs = ArrayAttr.get([
-          DictAttr.get({"custom_dialect.res1": FloatAttr.get(f32, 42.0)}),
-          DictAttr.get({"custom_dialect.res2": FloatAttr.get(f64, 256.0)})
-      ])
-
-      other = func.FuncOp("other_func", ([f32, f32], []))
-      with InsertionPoint(other.add_entry_block()):
-        func.ReturnOp([])
-      other.arg_attrs = [
-          DictAttr.get({"custom_dialect.foo": StringAttr.get("qux")}),
-          DictAttr.get()
-      ]
-
-  # CHECK: [{custom_dialect.baz, custom_dialect.foo = "bar"}, {custom_dialect.qux = []}]
-  print(f.arg_attrs)
-
-  # CHECK: [{custom_dialect.res1 = 4.200000e+01 : f32}, {custom_dialect.res2 = 2.560000e+02 : f64}]
-  print(f.result_attrs)
-
-  # CHECK: func @some_func(
-  # CHECK: %[[ARG0:.*]]: f32 {custom_dialect.baz, custom_dialect.foo = "bar"},
-  # CHECK: %[[ARG1:.*]]: f32 {custom_dialect.qux = []}) ->
-  # CHECK: f32 {custom_dialect.res1 = 4.200000e+01 : f32},
-  # CHECK: f32 {custom_dialect.res2 = 2.560000e+02 : f64})
-  # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
-  #
-  # CHECK: func @other_func(
-  # CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"},
-  # CHECK: %{{.*}}: f32)
-  print(module)
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+        module = Module.create()
+        f32 = F32Type.get()
+        f64 = F64Type.get()
+        with InsertionPoint(module.body):
+            f = func.FuncOp("some_func", ([f32, f32], [f32, f32]))
+            with InsertionPoint(f.add_entry_block()):
+                func.ReturnOp(f.arguments)
+            f.arg_attrs = ArrayAttr.get(
+                [
+                    DictAttr.get(
+                        {
+                            "custom_dialect.foo": StringAttr.get("bar"),
+                            "custom_dialect.baz": UnitAttr.get(),
+                        }
+                    ),
+                    DictAttr.get({"custom_dialect.qux": ArrayAttr.get([])}),
+                ]
+            )
+            f.result_attrs = ArrayAttr.get(
+                [
+                    DictAttr.get({"custom_dialect.res1": FloatAttr.get(f32, 42.0)}),
+                    DictAttr.get({"custom_dialect.res2": FloatAttr.get(f64, 256.0)}),
+                ]
+            )
+
+            other = func.FuncOp("other_func", ([f32, f32], []))
+            with InsertionPoint(other.add_entry_block()):
+                func.ReturnOp([])
+            other.arg_attrs = [
+                DictAttr.get({"custom_dialect.foo": StringAttr.get("qux")}),
+                DictAttr.get(),
+            ]
+
+    # CHECK: [{custom_dialect.baz, custom_dialect.foo = "bar"}, {custom_dialect.qux = []}]
+    print(f.arg_attrs)
+
+    # CHECK: [{custom_dialect.res1 = 4.200000e+01 : f32}, {custom_dialect.res2 = 2.560000e+02 : f64}]
+    print(f.result_attrs)
+
+    # CHECK: func @some_func(
+    # CHECK: %[[ARG0:.*]]: f32 {custom_dialect.baz, custom_dialect.foo = "bar"},
+    # CHECK: %[[ARG1:.*]]: f32 {custom_dialect.qux = []}) ->
+    # CHECK: f32 {custom_dialect.res1 = 4.200000e+01 : f32},
+    # CHECK: f32 {custom_dialect.res2 = 2.560000e+02 : f64})
+    # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
+    #
+    # CHECK: func @other_func(
+    # CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"},
+    # CHECK: %{{.*}}: f32)
+    print(module)
 
 
 # CHECK-LABEL: testDenseElementsAttr
 @run
 def testDenseElementsAttr():
-  with Context(), Location.unknown():
-    values = np.arange(4, dtype=np.int32)
-    i32 = IntegerType.get_signless(32)
-    print(DenseElementsAttr.get(values, type=i32))
-    # CHECK{LITERAL}: dense<[0, 1, 2, 3]> : tensor<4xi32>
-    print(DenseElementsAttr.get(values, type=i32, shape=(2, 2)))
-    # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
-    print(DenseElementsAttr.get(values, type=VectorType.get((2, 2), i32)))
-    # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
+    with Context(), Location.unknown():
+        values = np.arange(4, dtype=np.int32)
+        i32 = IntegerType.get_signless(32)
+        print(DenseElementsAttr.get(values, type=i32))
+        # CHECK{LITERAL}: dense<[0, 1, 2, 3]> : tensor<4xi32>
+        print(DenseElementsAttr.get(values, type=i32, shape=(2, 2)))
+        # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
+        print(DenseElementsAttr.get(values, type=VectorType.get((2, 2), i32)))
+        # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
index e724575..afad217 100644 (file)
@@ -9,24 +9,24 @@ import mlir.dialects.complex as mlir_complex
 
 
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
+    print("\nTEST:", f.__name__)
+    f()
 
 
 # CHECK-LABEL: TEST: testComplexOps
 @run
 def testComplexOps():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    with InsertionPoint(module.body):
-
-      @func.FuncOp.from_py_func(ComplexType.get(F32Type.get()))
-      def emit_add(arg):
-        return mlir_complex.AddOp(arg, arg)
-
-    # CHECK-LABEL: func @emit_add(
-    # CHECK-SAME:                  %[[ARG:.*]]: complex<f32>) -> complex<f32> {
-    # CHECK:         %[[RES:.*]] = complex.add %[[ARG]], %[[ARG]] : complex<f32>
-    # CHECK:         return %[[RES]] : complex<f32>
-    # CHECK:       }
-    print(module)
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(ComplexType.get(F32Type.get()))
+            def emit_add(arg):
+                return mlir_complex.AddOp(arg, arg)
+
+        # CHECK-LABEL: func @emit_add(
+        # CHECK-SAME:                  %[[ARG:.*]]: complex<f32>) -> complex<f32> {
+        # CHECK:         %[[RES:.*]] = complex.add %[[ARG]], %[[ARG]] : complex<f32>
+        # CHECK:         return %[[RES]] : complex<f32>
+        # CHECK:       }
+        print(module)
index 3be9cac..161a12d 100644 (file)
@@ -7,13 +7,13 @@ from mlir.dialects import func
 
 
 def constructAndPrintInModule(f):
-  print("\nTEST:", f.__name__)
-  with Context(), Location.unknown():
-    module = Module.create()
-    with InsertionPoint(module.body):
-      f()
-    print(module)
-  return f
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            f()
+        print(module)
+    return f
 
 
 # CHECK-LABEL: TEST: testConstantOp
@@ -21,21 +21,21 @@ def constructAndPrintInModule(f):
 
 @constructAndPrintInModule
 def testConstantOp():
-  c1 = arith.ConstantOp(IntegerType.get_signless(32), 42)
-  c2 = arith.ConstantOp(IntegerType.get_signless(64), 100)
-  c3 = arith.ConstantOp(F32Type.get(), 3.14)
-  c4 = arith.ConstantOp(F64Type.get(), 1.23)
-  # CHECK: 42
-  print(c1.literal_value)
+    c1 = arith.ConstantOp(IntegerType.get_signless(32), 42)
+    c2 = arith.ConstantOp(IntegerType.get_signless(64), 100)
+    c3 = arith.ConstantOp(F32Type.get(), 3.14)
+    c4 = arith.ConstantOp(F64Type.get(), 1.23)
+    # CHECK: 42
+    print(c1.literal_value)
 
-  # CHECK: 100
-  print(c2.literal_value)
+    # CHECK: 100
+    print(c2.literal_value)
 
-  # CHECK: 3.140000104904175
-  print(c3.literal_value)
+    # CHECK: 3.140000104904175
+    print(c3.literal_value)
 
-  # CHECK: 1.23
-  print(c4.literal_value)
+    # CHECK: 1.23
+    print(c4.literal_value)
 
 
 # CHECK: = arith.constant 42 : i32
@@ -47,17 +47,17 @@ def testConstantOp():
 # CHECK-LABEL: TEST: testVectorConstantOp
 @constructAndPrintInModule
 def testVectorConstantOp():
-  int_type = IntegerType.get_signless(32)
-  vec_type = VectorType.get([2, 2], int_type)
-  c1 = arith.ConstantOp(
-      vec_type,
-      DenseElementsAttr.get_splat(vec_type, IntegerAttr.get(int_type, 42)))
-  try:
-    print(c1.literal_value)
-  except ValueError as e:
-    assert "only integer and float constants have literal values" in str(e)
-  else:
-    assert False
+    int_type = IntegerType.get_signless(32)
+    vec_type = VectorType.get([2, 2], int_type)
+    c1 = arith.ConstantOp(
+        vec_type, DenseElementsAttr.get_splat(vec_type, IntegerAttr.get(int_type, 42))
+    )
+    try:
+        print(c1.literal_value)
+    except ValueError as e:
+        assert "only integer and float constants have literal values" in str(e)
+    else:
+        assert False
 
 
 # CHECK: = arith.constant dense<42> : vector<2x2xi32>
@@ -66,9 +66,9 @@ def testVectorConstantOp():
 # CHECK-LABEL: TEST: testConstantIndexOp
 @constructAndPrintInModule
 def testConstantIndexOp():
-  c1 = arith.ConstantOp.create_index(10)
-  # CHECK: 10
-  print(c1.literal_value)
+    c1 = arith.ConstantOp.create_index(10)
+    # CHECK: 10
+    print(c1.literal_value)
 
 
 # CHECK: = arith.constant 10 : index
@@ -77,18 +77,18 @@ def testConstantIndexOp():
 # CHECK-LABEL: TEST: testFunctionCalls
 @constructAndPrintInModule
 def testFunctionCalls():
-  foo = func.FuncOp("foo", ([], []))
-  foo.sym_visibility = StringAttr.get("private")
-  bar = func.FuncOp("bar", ([], [IndexType.get()]))
-  bar.sym_visibility = StringAttr.get("private")
-  qux = func.FuncOp("qux", ([], [F32Type.get()]))
-  qux.sym_visibility = StringAttr.get("private")
-
-  with InsertionPoint(func.FuncOp("caller", ([], [])).add_entry_block()):
-    func.CallOp(foo, [])
-    func.CallOp([IndexType.get()], "bar", [])
-    func.CallOp([F32Type.get()], FlatSymbolRefAttr.get("qux"), [])
-    func.ReturnOp([])
+    foo = func.FuncOp("foo", ([], []))
+    foo.sym_visibility = StringAttr.get("private")
+    bar = func.FuncOp("bar", ([], [IndexType.get()]))
+    bar.sym_visibility = StringAttr.get("private")
+    qux = func.FuncOp("qux", ([], [F32Type.get()]))
+    qux.sym_visibility = StringAttr.get("private")
+
+    with InsertionPoint(func.FuncOp("caller", ([], [])).add_entry_block()):
+        func.CallOp(foo, [])
+        func.CallOp([IndexType.get()], "bar", [])
+        func.CallOp([F32Type.get()], FlatSymbolRefAttr.get("qux"), [])
+        func.ReturnOp([])
 
 
 # CHECK: func private @foo()
index 38bf038..7eefaed 100644 (file)
@@ -5,14 +5,17 @@ import mlir.dialects.gpu
 import mlir.dialects.gpu.passes
 from mlir.passmanager import *
 
+
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
+    print("\nTEST:", f.__name__)
+    f()
+
 
 def testGPUPass():
-  with Context() as context:
-    PassManager.parse('any(gpu-kernel-outlining)')
-  print('SUCCESS')
+    with Context() as context:
+        PassManager.parse("any(gpu-kernel-outlining)")
+    print("SUCCESS")
+
 
 # CHECK-LABEL: testGPUPass
 #       CHECK: SUCCESS
index d787c5f..7892d02 100644 (file)
@@ -34,8 +34,9 @@ def matmul(
     C=TensorDef(U, S.M, S.N, output=True),
     bfn=BinaryFnAttrDef(default=BinaryFn.mul),
     ufn=UnaryFnAttrDef(default=UnaryFn.exp),
-    cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
-  C[D.m, D.n] += bfn(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+    C[D.m, D.n] += bfn(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))
 
 
 # CHECK: ---
@@ -47,7 +48,7 @@ def matmul(
 # CHECK:     type_var: T
 @linalg_structured_op
 def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
-  O[D.m, D.n] = value
+    O[D.m, D.n] = value
 
 
 # CHECK: ---
@@ -71,5 +72,6 @@ def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
 def strided_copy(
     I=TensorDef(T, S.IH, S.IW),
     O=TensorDef(T, S.OH, S.OW, output=True),
-    strides=IndexAttrDef(S.SH, S.SW, default=[1, 2])):
-  O[D.oh, D.ow] = I[D.oh * S.SH, D.ow * S.SW]
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 2]),
+):
+    O[D.oh, D.ow] = I[D.oh * S.SH, D.ow * S.SW]
index eacf435..ad0a3ea 100644 (file)
@@ -35,8 +35,9 @@ def matmul(
     B=TensorDef(T, S.K, S.N),
     C=TensorDef(U, S.M, S.N, output=True),
     mul=BinaryFnAttrDef(default=BinaryFn.mul),
-    cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
-  C[D.m, D.n] += mul(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+    C[D.m, D.n] += mul(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))
 
 
 # CHECK: ---
@@ -79,12 +80,12 @@ def matmul(
 # CHECK:                  scalar_const: '1.{{[0]*}}e+03 : f64'
 @linalg_structured_op
 def constants(
-    O=TensorDef(T, S.M, S.K, output=True),
-    exp=UnaryFnAttrDef(default=UnaryFn.exp)):
-  pi = TypeFn.cast_signed(T, const(3.1415926535897931))
-  cst42 = TypeFn.cast_signed(T, const(42))
-  cst1000 = TypeFn.cast_signed(T, exp(const(1e+3)))
-  O[D.m, D.n] = UnaryFn.exp(pi) + cst42 - cst1000
+    O=TensorDef(T, S.M, S.K, output=True), exp=UnaryFnAttrDef(default=UnaryFn.exp)
+):
+    pi = TypeFn.cast_signed(T, const(3.1415926535897931))
+    cst42 = TypeFn.cast_signed(T, const(42))
+    cst1000 = TypeFn.cast_signed(T, exp(const(1e3)))
+    O[D.m, D.n] = UnaryFn.exp(pi) + cst42 - cst1000
 
 
 # CHECK: ---
@@ -100,7 +101,7 @@ def constants(
 # CHECK:          scalar_index: 0
 @linalg_structured_op
 def indices(O=TensorDef(T, S.M, S.K, output=True)):
-  O[D.m, D.n] = index(D.n) + index(D.m)
+    O[D.m, D.n] = index(D.n) + index(D.m)
 
 
 # CHECK: ---
@@ -111,4 +112,4 @@ def indices(O=TensorDef(T, S.M, S.K, output=True)):
 # CHECK:      scalar_arg: value
 @linalg_structured_op
 def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
-  O[D.m, D.n] = value
+    O[D.m, D.n] = value
index 4aae768..d2f9cec 100644 (file)
@@ -3,10 +3,11 @@
 import doctest
 import importlib
 
+
 def test_module(module_name):
-  print(f"--- Testing module: {module_name}")
-  m = importlib.import_module(module_name)
-  doctest.testmod(m, verbose=True, raise_on_error=True, report=True)
+    print(f"--- Testing module: {module_name}")
+    m = importlib.import_module(module_name)
+    doctest.testmod(m, verbose=True, raise_on_error=True, report=True)
 
 
 test_module("mlir.dialects.linalg.opdsl.lang.affine")
index ebe2c0f..d666d31 100644 (file)
@@ -17,43 +17,44 @@ def conv_poly(
     K=TensorDef(T2, S.KH, S.KW, S.C),
     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 2])):
-  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed(
-      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
-           D.c]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c])
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 2]),
+):
+    domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
+    O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed(
+        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+    ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c])
 
 
 with Context() as ctx, Location.unknown():
-  module = Module.create()
-  f32 = F32Type.get()
-  i32 = IntegerType.get_signless(32)
-  with InsertionPoint(module.body):
-
-    # Convolution indexing maps.
-    # CHECK: #[[$CONV_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d3, d2 * 4 + d4 * 2, d5)>
-    # CHECK: #[[$CONV_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
-    # CHECK: #[[$CONV_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
-
-    # CHECK-LABEL: @test_f32i32_conv
-    # CHECK: linalg.generic
-    # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$CONV_MAP_K]], #[[$CONV_MAP_O]]]
-    # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
-    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[FILTER:.+]]: f32, %[[OUT:.+]]: i32)
-    # CHECK-NEXT:   %[[IN_CAST:.+]] = arith.fptosi %[[IN:.+]] : f32 to i32
-    # CHECK-NEXT:   %[[FILTER_CAST:.+]] = arith.fptosi %[[FILTER:.+]] : f32 to i32
-    # CHECK-NEXT:   %[[PROD:.+]] = arith.muli %[[IN_CAST]], %[[FILTER_CAST]] : i32
-    # CHECK-NEXT:   %[[SUM:.+]] = arith.addi %[[OUT]], %[[PROD]] : i32
-    # CHECK-NEXT:   linalg.yield %[[SUM]] : i32
-    # CHECK-NEXT: -> tensor<1x2x4x1xi32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((1, 4, 16, 1), f32),
-        RankedTensorType.get((2, 2, 1), f32),
-        RankedTensorType.get((1, 2, 4, 1), i32))
-    def test_f32i32_conv(input, filter, init_result):
-      # Use default dilations and set non-default strides.
-      return conv_poly(
-          input, filter, outs=[init_result], strides=[2, 4])
+    module = Module.create()
+    f32 = F32Type.get()
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(module.body):
+
+        # Convolution indexing maps.
+        # CHECK: #[[$CONV_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d3, d2 * 4 + d4 * 2, d5)>
+        # CHECK: #[[$CONV_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+        # CHECK: #[[$CONV_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
+
+        # CHECK-LABEL: @test_f32i32_conv
+        # CHECK: linalg.generic
+        # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$CONV_MAP_K]], #[[$CONV_MAP_O]]]
+        # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[FILTER:.+]]: f32, %[[OUT:.+]]: i32)
+        # CHECK-NEXT:   %[[IN_CAST:.+]] = arith.fptosi %[[IN:.+]] : f32 to i32
+        # CHECK-NEXT:   %[[FILTER_CAST:.+]] = arith.fptosi %[[FILTER:.+]] : f32 to i32
+        # CHECK-NEXT:   %[[PROD:.+]] = arith.muli %[[IN_CAST]], %[[FILTER_CAST]] : i32
+        # CHECK-NEXT:   %[[SUM:.+]] = arith.addi %[[OUT]], %[[PROD]] : i32
+        # CHECK-NEXT:   linalg.yield %[[SUM]] : i32
+        # CHECK-NEXT: -> tensor<1x2x4x1xi32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((1, 4, 16, 1), f32),
+            RankedTensorType.get((2, 2, 1), f32),
+            RankedTensorType.get((1, 2, 4, 1), i32),
+        )
+        def test_f32i32_conv(input, filter, init_result):
+            # Use default dilations and set non-default strides.
+            return conv_poly(input, filter, outs=[init_result], strides=[2, 4])
 
 
 print(module)
index 1f840b0..ffef737 100644 (file)
@@ -13,47 +13,51 @@ T2 = TV.T2
 
 @linalg_structured_op
 def fill_poly(value=ScalarDef(T1), O=TensorDef(U, output=True)):
-  O[None] = TypeFn.cast_signed(U, value)
+    O[None] = TypeFn.cast_signed(U, value)
+
 
 @linalg_structured_op
 def fill_rank_zero_poly(I=TensorDef(T1), O=TensorDef(U, output=True)):
-  O[None] = TypeFn.cast_signed(U, I[None])
+    O[None] = TypeFn.cast_signed(U, I[None])
+
 
 with Context() as ctx, Location.unknown():
-  module = Module.create()
-  f32 = F32Type.get()
-  with InsertionPoint(module.body):
-
-    # Fill indexing maps.
-    # CHECK-DAG: #[[$MAP0:.+]] = affine_map<() -> ()>
-    # CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
-    # CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-    # CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2) -> ()>
-    # CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-
-    # CHECK-LABEL: @test_fill_0d
-    # CHECK: linalg.generic
-    # CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]
-    # CHECK-SAME: iterator_types = []
-    @func.FuncOp.from_py_func(f32, RankedTensorType.get([], f32))
-    def test_fill_0d(value, init_result):
-      return fill_poly(value, outs=[init_result])
-
-    # CHECK-LABEL: @test_fill_2d
-    # CHECK: linalg.generic
-    # CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]]]
-    # CHECK-SAME: iterator_types = ["parallel", "parallel"]
-    @func.FuncOp.from_py_func(f32, RankedTensorType.get([4, 16], f32))
-    def test_fill_2d(value, init_result):
-      return fill_poly(value, outs=[init_result])
-
-    # CHECK-LABEL: @test_fill_rank_zero_3d
-    # CHECK: linalg.generic
-    # CHECK-SAME: indexing_maps = [#[[$MAP3]], #[[$MAP4]]]
-    # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get([], f32), RankedTensorType.get([4, 8, 16], f32))
-    def test_fill_rank_zero_3d(input, init_result):
-      return fill_rank_zero_poly(input, outs=[init_result])
+    module = Module.create()
+    f32 = F32Type.get()
+    with InsertionPoint(module.body):
+
+        # Fill indexing maps.
+        # CHECK-DAG: #[[$MAP0:.+]] = affine_map<() -> ()>
+        # CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
+        # CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+        # CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2) -> ()>
+        # CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+        # CHECK-LABEL: @test_fill_0d
+        # CHECK: linalg.generic
+        # CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]
+        # CHECK-SAME: iterator_types = []
+        @func.FuncOp.from_py_func(f32, RankedTensorType.get([], f32))
+        def test_fill_0d(value, init_result):
+            return fill_poly(value, outs=[init_result])
+
+        # CHECK-LABEL: @test_fill_2d
+        # CHECK: linalg.generic
+        # CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]]]
+        # CHECK-SAME: iterator_types = ["parallel", "parallel"]
+        @func.FuncOp.from_py_func(f32, RankedTensorType.get([4, 16], f32))
+        def test_fill_2d(value, init_result):
+            return fill_poly(value, outs=[init_result])
+
+        # CHECK-LABEL: @test_fill_rank_zero_3d
+        # CHECK: linalg.generic
+        # CHECK-SAME: indexing_maps = [#[[$MAP3]], #[[$MAP4]]]
+        # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get([], f32), RankedTensorType.get([4, 8, 16], f32)
+        )
+        def test_fill_rank_zero_3d(input, init_result):
+            return fill_rank_zero_poly(input, outs=[init_result])
+
 
 print(module)
index 6dff754..18c237c 100644 (file)
@@ -16,9 +16,10 @@ T2 = TV.T2
 def matmul_mono(
     A=TensorDef(T, S.M, S.K),
     B=TensorDef(T, S.K, S.N),
-    C=TensorDef(T, S.M, S.N, output=True)):
-  domain(D.m, D.n, D.k)
-  C[D.m, D.n] += A[D.m, D.k] * B[D.k, D.n]
+    C=TensorDef(T, S.M, S.N, output=True),
+):
+    domain(D.m, D.n, D.k)
+    C[D.m, D.n] += A[D.m, D.k] * B[D.k, D.n]
 
 
 @linalg_structured_op
@@ -26,146 +27,162 @@ def matmul_poly(
     A=TensorDef(T1, S.M, S.K),
     B=TensorDef(T2, S.K, S.N),
     C=TensorDef(U, S.M, S.N, output=True),
-    cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
-  domain(D.m, D.n, D.k)
-  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+    domain(D.m, D.n, D.k)
+    C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
 
 
 with Context() as ctx, Location.unknown():
-  module = Module.create()
-  f16 = F16Type.get()
-  f32 = F32Type.get()
-  f64 = F64Type.get()
-  i8 = IntegerType.get_signless(8)
-  i16 = IntegerType.get_signless(16)
-  i32 = IntegerType.get_signless(32)
-  with InsertionPoint(module.body):
-
-    # Multiplication indexing maps. We verify only the indexing maps of the
-    # first multiplication and then do additional tests on casting and body
-    # generation behavior.
-    # CHECK: #[[$MUL_MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-    # CHECK: #[[$MUL_MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
-    # CHECK: #[[$MUL_MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-
-    # CHECK-LABEL: func @test_matmul_mono
-    # CHECK-SAME:  %[[A:.+]]: tensor<4x16xf32>
-    # CHECK-SAME:  %[[B:.+]]: tensor<16x8xf32>
-    # CHECK: %[[INITC:.+]] = tensor.empty() : tensor<4x8xf32>
-    # CHECK: linalg.generic
-    # CHECK-SAME: indexing_maps = [#[[$MUL_MAP_A]], #[[$MUL_MAP_B]], #[[$MUL_MAP_C]]]
-    # CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
-    # CHECK-SAME: ins(%[[A]], %[[B]]
-    # CHECK-SAME: outs(%[[INITC]]
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32))
-    def test_matmul_mono(lhs, rhs):
-      init_result = tensor.EmptyOp([4, 8], f32)
-      return matmul_mono(lhs, rhs, outs=[init_result.result])
-
-    # CHECK-LABEL: @test_i8i8i32_matmul
-    # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: i32)
-    # CHECK-NEXT:   %[[A_CAST:.+]] = arith.extsi %[[A_ARG]] : i8 to i32
-    # CHECK-NEXT:   %[[B_CAST:.+]] = arith.extsi %[[B_ARG]] : i8 to i32
-    # CHECK-NEXT:   %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i32
-    # CHECK-NEXT:   %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i32
-    # CHECK-NEXT:   linalg.yield %[[ADD]] : i32
-    # CHECK-NEXT: -> tensor<4x8xi32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
-        RankedTensorType.get((4, 8), i32))
-    def test_i8i8i32_matmul(lhs, rhs, init_result):
-      return matmul_poly(lhs, rhs, outs=[init_result])
-
-    # CHECK-LABEL: @test_i8i8i32_matmul_unsigned
-    # CHECK:   = arith.extui
-    # CHECK:   = arith.extui
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
-        RankedTensorType.get((4, 8), i32))
-    def test_i8i8i32_matmul_unsigned(lhs, rhs, init_result):
-      return matmul_poly(
-          lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned)
-
-    # CHECK-LABEL: @test_i8i16i32_matmul
-    # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32)
-    # CHECK-NEXT:   %[[A_CAST:.+]] = arith.extsi %[[A_ARG]] : i8 to i32
-    # CHECK-NEXT:   %[[B_CAST:.+]] = arith.extsi %[[B_ARG]] : i16 to i32
-    # CHECK-NEXT:   %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i32
-    # CHECK-NEXT:   %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i32
-    # CHECK-NEXT:   linalg.yield %[[ADD]] : i32
-    # CHECK-NEXT: -> tensor<4x8xi32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i16),
-        RankedTensorType.get((4, 8), i32))
-    def test_i8i16i32_matmul(lhs, rhs, init_result):
-      return matmul_poly(lhs, rhs, outs=[init_result])
-
-    # CHECK-LABEL: @test_i32i32i16_matmul
-    # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i16)
-    # CHECK-NEXT:   %[[A_CAST:.+]] = arith.trunci %[[A_ARG]] : i32 to i16
-    # CHECK-NEXT:   %[[B_CAST:.+]] = arith.trunci %[[B_ARG]] : i32 to i16
-    # CHECK-NEXT:   %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i16
-    # CHECK-NEXT:   %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i16
-    # CHECK-NEXT:   linalg.yield %[[ADD]] : i16
-    # CHECK-NEXT: -> tensor<4x8xi16>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), i32), RankedTensorType.get((16, 8), i32),
-        RankedTensorType.get((4, 8), i16))
-    def test_i32i32i16_matmul(lhs, rhs, init_result):
-      return matmul_poly(lhs, rhs, outs=[init_result])
-
-    # CHECK-LABEL: @test_i8i8f32_matmul
-    # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: f32)
-    # CHECK-NEXT:   %[[A_CAST:.+]] = arith.sitofp %[[A_ARG]] : i8 to f32
-    # CHECK-NEXT:   %[[B_CAST:.+]] = arith.sitofp %[[B_ARG]] : i8 to f32
-    # CHECK-NEXT:   %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
-    # CHECK-NEXT:   %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
-    # CHECK-NEXT:   linalg.yield %[[ADD]] : f32
-    # CHECK-NEXT: -> tensor<4x8xf32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
-        RankedTensorType.get((4, 8), f32))
-    def test_i8i8f32_matmul(lhs, rhs, init_result):
-      return matmul_poly(lhs, rhs, outs=[init_result])
-
-    # CHECK-LABEL: @test_i8i8f32_matmul_unsigned
-    # CHECK:   = arith.uitofp
-    # CHECK:   = arith.uitofp
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
-        RankedTensorType.get((4, 8), f32))
-    def test_i8i8f32_matmul_unsigned(lhs, rhs, init_result):
-      return matmul_poly(
-          lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned)
-
-    # CHECK-LABEL: @test_f16f16f32_matmul
-    # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
-    # CHECK-NEXT:   %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
-    # CHECK-NEXT:   %[[B_CAST:.+]] = arith.extf %[[B_ARG]] : f16 to f32
-    # CHECK-NEXT:   %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
-    # CHECK-NEXT:   %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
-    # CHECK-NEXT:   linalg.yield %[[ADD]] : f32
-    # CHECK-NEXT: -> tensor<4x8xf32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), f16), RankedTensorType.get((16, 8), f16),
-        RankedTensorType.get((4, 8), f32))
-    def test_f16f16f32_matmul(lhs, rhs, init_result):
-      return matmul_poly(lhs, rhs, outs=[init_result])
-
-    # CHECK-LABEL: @test_f64f64f32_matmul
-    # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f64, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
-    # CHECK-NEXT:   %[[A_CAST:.+]] = arith.truncf %[[A_ARG]] : f64 to f32
-    # CHECK-NEXT:   %[[B_CAST:.+]] = arith.truncf %[[B_ARG]] : f64 to f32
-    # CHECK-NEXT:   %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
-    # CHECK-NEXT:   %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
-    # CHECK-NEXT:   linalg.yield %[[ADD]] : f32
-    # CHECK-NEXT: -> tensor<4x8xf32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), f64), RankedTensorType.get((16, 8), f64),
-        RankedTensorType.get((4, 8), f32))
-    def test_f64f64f32_matmul(lhs, rhs, init_result):
-      return matmul_poly(lhs, rhs, outs=[init_result])
+    module = Module.create()
+    f16 = F16Type.get()
+    f32 = F32Type.get()
+    f64 = F64Type.get()
+    i8 = IntegerType.get_signless(8)
+    i16 = IntegerType.get_signless(16)
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(module.body):
+
+        # Multiplication indexing maps. We verify only the indexing maps of the
+        # first multiplication and then do additional tests on casting and body
+        # generation behavior.
+        # CHECK: #[[$MUL_MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+        # CHECK: #[[$MUL_MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+        # CHECK: #[[$MUL_MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+        # CHECK-LABEL: func @test_matmul_mono
+        # CHECK-SAME:  %[[A:.+]]: tensor<4x16xf32>
+        # CHECK-SAME:  %[[B:.+]]: tensor<16x8xf32>
+        # CHECK: %[[INITC:.+]] = tensor.empty() : tensor<4x8xf32>
+        # CHECK: linalg.generic
+        # CHECK-SAME: indexing_maps = [#[[$MUL_MAP_A]], #[[$MUL_MAP_B]], #[[$MUL_MAP_C]]]
+        # CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+        # CHECK-SAME: ins(%[[A]], %[[B]]
+        # CHECK-SAME: outs(%[[INITC]]
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
+        )
+        def test_matmul_mono(lhs, rhs):
+            init_result = tensor.EmptyOp([4, 8], f32)
+            return matmul_mono(lhs, rhs, outs=[init_result.result])
+
+        # CHECK-LABEL: @test_i8i8i32_matmul
+        # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: i32)
+        # CHECK-NEXT:   %[[A_CAST:.+]] = arith.extsi %[[A_ARG]] : i8 to i32
+        # CHECK-NEXT:   %[[B_CAST:.+]] = arith.extsi %[[B_ARG]] : i8 to i32
+        # CHECK-NEXT:   %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i32
+        # CHECK-NEXT:   %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i32
+        # CHECK-NEXT:   linalg.yield %[[ADD]] : i32
+        # CHECK-NEXT: -> tensor<4x8xi32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), i8),
+            RankedTensorType.get((16, 8), i8),
+            RankedTensorType.get((4, 8), i32),
+        )
+        def test_i8i8i32_matmul(lhs, rhs, init_result):
+            return matmul_poly(lhs, rhs, outs=[init_result])
+
+        # CHECK-LABEL: @test_i8i8i32_matmul_unsigned
+        # CHECK:   = arith.extui
+        # CHECK:   = arith.extui
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), i8),
+            RankedTensorType.get((16, 8), i8),
+            RankedTensorType.get((4, 8), i32),
+        )
+        def test_i8i8i32_matmul_unsigned(lhs, rhs, init_result):
+            return matmul_poly(lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned)
+
+        # CHECK-LABEL: @test_i8i16i32_matmul
+        # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32)
+        # CHECK-NEXT:   %[[A_CAST:.+]] = arith.extsi %[[A_ARG]] : i8 to i32
+        # CHECK-NEXT:   %[[B_CAST:.+]] = arith.extsi %[[B_ARG]] : i16 to i32
+        # CHECK-NEXT:   %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i32
+        # CHECK-NEXT:   %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i32
+        # CHECK-NEXT:   linalg.yield %[[ADD]] : i32
+        # CHECK-NEXT: -> tensor<4x8xi32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), i8),
+            RankedTensorType.get((16, 8), i16),
+            RankedTensorType.get((4, 8), i32),
+        )
+        def test_i8i16i32_matmul(lhs, rhs, init_result):
+            return matmul_poly(lhs, rhs, outs=[init_result])
+
+        # CHECK-LABEL: @test_i32i32i16_matmul
+        # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i16)
+        # CHECK-NEXT:   %[[A_CAST:.+]] = arith.trunci %[[A_ARG]] : i32 to i16
+        # CHECK-NEXT:   %[[B_CAST:.+]] = arith.trunci %[[B_ARG]] : i32 to i16
+        # CHECK-NEXT:   %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i16
+        # CHECK-NEXT:   %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i16
+        # CHECK-NEXT:   linalg.yield %[[ADD]] : i16
+        # CHECK-NEXT: -> tensor<4x8xi16>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), i32),
+            RankedTensorType.get((16, 8), i32),
+            RankedTensorType.get((4, 8), i16),
+        )
+        def test_i32i32i16_matmul(lhs, rhs, init_result):
+            return matmul_poly(lhs, rhs, outs=[init_result])
+
+        # CHECK-LABEL: @test_i8i8f32_matmul
+        # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: f32)
+        # CHECK-NEXT:   %[[A_CAST:.+]] = arith.sitofp %[[A_ARG]] : i8 to f32
+        # CHECK-NEXT:   %[[B_CAST:.+]] = arith.sitofp %[[B_ARG]] : i8 to f32
+        # CHECK-NEXT:   %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
+        # CHECK-NEXT:   %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
+        # CHECK-NEXT:   linalg.yield %[[ADD]] : f32
+        # CHECK-NEXT: -> tensor<4x8xf32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), i8),
+            RankedTensorType.get((16, 8), i8),
+            RankedTensorType.get((4, 8), f32),
+        )
+        def test_i8i8f32_matmul(lhs, rhs, init_result):
+            return matmul_poly(lhs, rhs, outs=[init_result])
+
+        # CHECK-LABEL: @test_i8i8f32_matmul_unsigned
+        # CHECK:   = arith.uitofp
+        # CHECK:   = arith.uitofp
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), i8),
+            RankedTensorType.get((16, 8), i8),
+            RankedTensorType.get((4, 8), f32),
+        )
+        def test_i8i8f32_matmul_unsigned(lhs, rhs, init_result):
+            return matmul_poly(lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned)
+
+        # CHECK-LABEL: @test_f16f16f32_matmul
+        # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
+        # CHECK-NEXT:   %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
+        # CHECK-NEXT:   %[[B_CAST:.+]] = arith.extf %[[B_ARG]] : f16 to f32
+        # CHECK-NEXT:   %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
+        # CHECK-NEXT:   %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
+        # CHECK-NEXT:   linalg.yield %[[ADD]] : f32
+        # CHECK-NEXT: -> tensor<4x8xf32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), f16),
+            RankedTensorType.get((16, 8), f16),
+            RankedTensorType.get((4, 8), f32),
+        )
+        def test_f16f16f32_matmul(lhs, rhs, init_result):
+            return matmul_poly(lhs, rhs, outs=[init_result])
+
+        # CHECK-LABEL: @test_f64f64f32_matmul
+        # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f64, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
+        # CHECK-NEXT:   %[[A_CAST:.+]] = arith.truncf %[[A_ARG]] : f64 to f32
+        # CHECK-NEXT:   %[[B_CAST:.+]] = arith.truncf %[[B_ARG]] : f64 to f32
+        # CHECK-NEXT:   %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32
+        # CHECK-NEXT:   %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32
+        # CHECK-NEXT:   linalg.yield %[[ADD]] : f32
+        # CHECK-NEXT: -> tensor<4x8xf32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), f64),
+            RankedTensorType.get((16, 8), f64),
+            RankedTensorType.get((4, 8), f32),
+        )
+        def test_f64f64f32_matmul(lhs, rhs, init_result):
+            return matmul_poly(lhs, rhs, outs=[init_result])
 
 
 print(module)
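For context when reading the casting checks above: matmul_mono and matmul_poly are presumably defined earlier in the same test file, outside this hunk. A minimal sketch of the polymorphic variant, assuming it follows the same opdsl pattern as the other ops shown in this commit (hypothetical, not the verbatim definition):

@linalg_structured_op
def matmul_poly(
    A=TensorDef(T1, S.M, S.K),
    B=TensorDef(T2, S.K, S.N),
    C=TensorDef(U, S.M, S.N, output=True),
    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
):
    # The cast attribute selects the conversions the CHECK lines verify:
    # extsi/extui for integer widening, sitofp/uitofp for int-to-float,
    # trunci/truncf for narrowing.
    domain(D.m, D.n, D.k)
    C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])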
index aad7149..f8e034f 100644 (file)
@@ -17,14 +17,16 @@ from mlir.dialects.linalg.opdsl.lang import *
 
 @linalg_structured_op
 def test_const(O=TensorDef(F32, S.M, S.N, output=True)):
-  O[D.m, D.n] = TypeFn.cast_unsigned(F32, const(42)) + TypeFn.cast_unsigned(
-      F32, const(2.3283064e-10))
+    O[D.m, D.n] = TypeFn.cast_unsigned(F32, const(42)) + TypeFn.cast_unsigned(
+        F32, const(2.3283064e-10)
+    )
 
 
 @linalg_structured_op
 def test_index(O=TensorDef(I32, S.M, S.N, output=True)):
-  O[D.m, D.n] = TypeFn.cast_signed(I32, index(D.m)) + TypeFn.cast_signed(
-      I32, index(D.n))
+    O[D.m, D.n] = TypeFn.cast_signed(I32, index(D.m)) + TypeFn.cast_signed(
+        I32, index(D.n)
+    )
 
 
 @linalg_structured_op
@@ -32,120 +34,129 @@ def elemwise_unary_poly(
     I=TensorDef(T),
     O=TensorDef(U, output=True),
     fun=UnaryFnAttrDef(default=UnaryFn.exp),
-    cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
-  O[None] = fun(cast(U, I[None]))
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+    O[None] = fun(cast(U, I[None]))
 
 
 @linalg_structured_op(op_name="custom_op_name")
 def non_default_op_name(I=TensorDef(T, S.N), O=TensorDef(T, S.N, output=True)):
-  O[D.n] = I[D.n]
+    O[D.n] = I[D.n]
 
 
 with Context() as ctx, Location.unknown():
-  module = Module.create()
-  f32 = F32Type.get()
-  c32 = ComplexType.get(f32)
-  i32 = IntegerType.get_signless(32)
-  with InsertionPoint(module.body):
-
-    # CHECK-LABEL: @test_f32_const
-    # CHECK-DAG:    %[[CST0:.+]] = arith.constant 42 : i64
-    # CHECK-DAG:    %[[CST0_CAST:.+]] = arith.uitofp %[[CST0]] : i64 to f32
-    # CHECK-DAG:    %[[CST1:.+]] = arith.constant 2.3283063999999999E-10 : f64
-    # CHECK-DAG:    %[[CST1_CAST:.+]] = arith.truncf %[[CST1]] : f64 to f32
-    # CHECK-DAG:    %[[SUM:.+]] = arith.addf %[[CST0_CAST]], %[[CST1_CAST]] : f32
-    # CHECK-NEXT:   linalg.yield %[[SUM]] : f32
-    @func.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32))
-    def test_f32_const(init_result):
-      return test_const(outs=[init_result])
-
-    # CHECK-LABEL: @test_i32_index
-    # CHECK-DAG:    %[[IDX0:.+]] = linalg.index 0 : index
-    # CHECK-DAG:    %[[IDX1:.+]] = linalg.index 1 : index
-    # CHECK-DAG:    %[[IDX0_CAST:.+]] = arith.index_cast %[[IDX0]] : index to i32
-    # CHECK-DAG:    %[[IDX1_CAST:.+]] = arith.index_cast %[[IDX1]] : index to i32
-    # CHECK-DAG:    %[[SUM:.+]] = arith.addi %[[IDX0_CAST]], %[[IDX1_CAST]] : i32
-    # CHECK-NEXT:   linalg.yield %[[SUM]] : i32
-    @func.FuncOp.from_py_func(RankedTensorType.get((4, 16), i32))
-    def test_i32_index(init_result):
-      return test_index(outs=[init_result])
-
-    # CHECK-LABEL: @test_f32_elemwise_exp
-    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-    # CHECK-NEXT:   %[[EXP:.+]] = math.exp %[[IN]] : f32
-    # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
-    # CHECK-NEXT: -> tensor<4x16xf32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
-    def test_f32_elemwise_exp(input, init_result):
-      return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.exp)
-
-    # CHECK-LABEL: @test_f32_elemwise_log
-    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-    # CHECK-NEXT:   %[[LOG:.+]] = math.log %[[IN]] : f32
-    # CHECK-NEXT:   linalg.yield %[[LOG]] : f32
-    # CHECK-NEXT: -> tensor<4x16xf32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
-    def test_f32_elemwise_log(input, init_result):
-      return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)
-
-    # CHECK-LABEL: @test_f32_elemwise_abs
-    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-    # CHECK-NEXT:   %[[EXP:.+]] = math.absf %[[IN]] : f32
-    # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
-    # CHECK-NEXT: -> tensor<4x16xf32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
-    def test_f32_elemwise_abs(input, init_result):
-      return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)
-
-    # CHECK-LABEL: @test_f32_elemwise_ceil
-    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-    # CHECK-NEXT:   %[[EXP:.+]] = math.ceil %[[IN]] : f32
-    # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
-    # CHECK-NEXT: -> tensor<4x16xf32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
-    def test_f32_elemwise_ceil(input, init_result):
-      return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.ceil)
-
-    # CHECK-LABEL: @test_f32_elemwise_floor
-    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-    # CHECK-NEXT:   %[[EXP:.+]] = math.floor %[[IN]] : f32
-    # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
-    # CHECK-NEXT: -> tensor<4x16xf32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
-    def test_f32_elemwise_floor(input, init_result):
-      return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.floor)
-
-    # CHECK-LABEL: @test_f32_elemwise_neg
-    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-    # CHECK-NEXT:   %[[EXP:.+]] = arith.negf %[[IN]] : f32
-    # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
-    # CHECK-NEXT: -> tensor<4x16xf32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
-    def test_f32_elemwise_neg(input, init_result):
-      return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
-
-    # CHECK-LABEL: @test_c32_elemwise_neg
-    # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
-    # CHECK-NEXT:   %[[EXP:.+]] = complex.neg %[[IN]] : complex<f32>
-    # CHECK-NEXT:   linalg.yield %[[EXP]] : complex<f32>
-    # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32))
-    def test_c32_elemwise_neg(input, init_result):
-      return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
-
-    # Just check that we don't assert out on name mismatch.
-    # CHECK-LABEL: @test_non_default_op_name
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((42,), f32), RankedTensorType.get((42,), f32))
-    def test_non_default_op_name(input, init_result):
-      return non_default_op_name(input, outs=[init_result])
+    module = Module.create()
+    f32 = F32Type.get()
+    c32 = ComplexType.get(f32)
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(module.body):
+
+        # CHECK-LABEL: @test_f32_const
+        # CHECK-DAG:    %[[CST0:.+]] = arith.constant 42 : i64
+        # CHECK-DAG:    %[[CST0_CAST:.+]] = arith.uitofp %[[CST0]] : i64 to f32
+        # CHECK-DAG:    %[[CST1:.+]] = arith.constant 2.3283063999999999E-10 : f64
+        # CHECK-DAG:    %[[CST1_CAST:.+]] = arith.truncf %[[CST1]] : f64 to f32
+        # CHECK-DAG:    %[[SUM:.+]] = arith.addf %[[CST0_CAST]], %[[CST1_CAST]] : f32
+        # CHECK-NEXT:   linalg.yield %[[SUM]] : f32
+        @func.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32))
+        def test_f32_const(init_result):
+            return test_const(outs=[init_result])
+
+        # CHECK-LABEL: @test_i32_index
+        # CHECK-DAG:    %[[IDX0:.+]] = linalg.index 0 : index
+        # CHECK-DAG:    %[[IDX1:.+]] = linalg.index 1 : index
+        # CHECK-DAG:    %[[IDX0_CAST:.+]] = arith.index_cast %[[IDX0]] : index to i32
+        # CHECK-DAG:    %[[IDX1_CAST:.+]] = arith.index_cast %[[IDX1]] : index to i32
+        # CHECK-DAG:    %[[SUM:.+]] = arith.addi %[[IDX0_CAST]], %[[IDX1_CAST]] : i32
+        # CHECK-NEXT:   linalg.yield %[[SUM]] : i32
+        @func.FuncOp.from_py_func(RankedTensorType.get((4, 16), i32))
+        def test_i32_index(init_result):
+            return test_index(outs=[init_result])
+
+        # CHECK-LABEL: @test_f32_elemwise_exp
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+        # CHECK-NEXT:   %[[EXP:.+]] = math.exp %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT: -> tensor<4x16xf32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
+        )
+        def test_f32_elemwise_exp(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.exp)
+
+        # CHECK-LABEL: @test_f32_elemwise_log
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+        # CHECK-NEXT:   %[[LOG:.+]] = math.log %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[LOG]] : f32
+        # CHECK-NEXT: -> tensor<4x16xf32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
+        )
+        def test_f32_elemwise_log(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)
+
+        # CHECK-LABEL: @test_f32_elemwise_abs
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+        # CHECK-NEXT:   %[[EXP:.+]] = math.absf %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT: -> tensor<4x16xf32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
+        )
+        def test_f32_elemwise_abs(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)
+
+        # CHECK-LABEL: @test_f32_elemwise_ceil
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+        # CHECK-NEXT:   %[[EXP:.+]] = math.ceil %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT: -> tensor<4x16xf32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
+        )
+        def test_f32_elemwise_ceil(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.ceil)
+
+        # CHECK-LABEL: @test_f32_elemwise_floor
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+        # CHECK-NEXT:   %[[EXP:.+]] = math.floor %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT: -> tensor<4x16xf32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
+        )
+        def test_f32_elemwise_floor(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.floor)
+
+        # CHECK-LABEL: @test_f32_elemwise_neg
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+        # CHECK-NEXT:   %[[EXP:.+]] = arith.negf %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT: -> tensor<4x16xf32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
+        )
+        def test_f32_elemwise_neg(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
+
+        # CHECK-LABEL: @test_c32_elemwise_neg
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+        # CHECK-NEXT:   %[[EXP:.+]] = complex.neg %[[IN]] : complex<f32>
+        # CHECK-NEXT:   linalg.yield %[[EXP]] : complex<f32>
+        # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
+        )
+        def test_c32_elemwise_neg(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
+
+        # Just check that we don't assert out on name mismatch.
+        # CHECK-LABEL: @test_non_default_op_name
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((42,), f32), RankedTensorType.get((42,), f32)
+        )
+        def test_non_default_op_name(input, init_result):
+            return non_default_op_name(input, outs=[init_result])
 
 
 print(module)
index 2fd6338..ab049d3 100644 (file)
@@ -19,121 +19,134 @@ def pooling_poly(
     reduce=BinaryFnAttrDef(default=BinaryFn.max_signed),
     cast=TypeFnAttrDef(default=TypeFn.cast_signed),
     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
-  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = reduce[D.kh, D.kw](
-      cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
-                D.c]))
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
+    O[D.n, D.oh, D.ow, D.c] = reduce[D.kh, D.kw](
+        cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
+    )
 
 
 with Context() as ctx, Location.unknown():
-  module = Module.create()
-  f32 = F32Type.get()
-  i32 = IntegerType.get_signless(32)
-  with InsertionPoint(module.body):
-
-    # Pooling indexing maps.
-    # CHECK: #[[$POOL_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d3, d2 * 4 + d4 * 2, d5)>
-    # CHECK: #[[$POOL_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>
-    # CHECK: #[[$POOL_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
-
-    # CHECK-LABEL: @test_f32i32_max_pooling
-    # CHECK: linalg.generic
-    # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], #[[$POOL_MAP_O]]]
-    # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
-    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32)
-    # CHECK-NEXT:   %[[IN_CAST:.+]] = arith.fptosi %[[IN:.+]] : f32 to i32
-    # CHECK-NEXT:   %[[MAX:.+]] = arith.maxsi %[[OUT]], %[[IN_CAST:.+]] : i32
-    # CHECK-NEXT:   linalg.yield %[[MAX]] : i32
-    # CHECK-NEXT: -> tensor<1x2x4x1xi32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((1, 4, 16, 1), f32),
-        RankedTensorType.get((2, 2), f32),
-        RankedTensorType.get((1, 2, 4, 1), i32))
-    def test_f32i32_max_pooling(input, shape, init_result):
-      return pooling_poly(
-          input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
-
-    # CHECK-LABEL: @test_f32i32_max_unsigned_pooling
-    # CHECK:   = arith.fptoui
-    # CHECK:   = arith.maxui
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((1, 4, 16, 1), f32),
-        RankedTensorType.get((2, 2), f32),
-        RankedTensorType.get((1, 2, 4, 1), i32))
-    def test_f32i32_max_unsigned_pooling(input, shape, init_result):
-      return pooling_poly(
-          input,
-          shape,
-          outs=[init_result],
-          reduce=BinaryFn.max_unsigned,
-          cast=TypeFn.cast_unsigned,
-          strides=[2, 4],
-          dilations=[1, 2])
-
-    # CHECK-LABEL: @test_f32f32_max_pooling
-    # CHECK: linalg.generic
-    # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], #[[$POOL_MAP_O]]]
-    # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
-    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: f32)
-    # CHECK-NEXT:   %[[MAX:.+]] = arith.maxf %[[OUT]], %[[IN:.+]] : f32
-    # CHECK-NEXT:   linalg.yield %[[MAX]] : f32
-    # CHECK-NEXT: -> tensor<1x2x4x1xf32>
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((1, 4, 16, 1), f32),
-        RankedTensorType.get((2, 2), f32),
-        RankedTensorType.get((1, 2, 4, 1), f32))
-    def test_f32f32_max_pooling(input, shape, init_result):
-      return pooling_poly(
-          input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
-
-    # CHECK-LABEL: @test_f32i32_min_pooling
-    # CHECK:   = arith.fptosi
-    # CHECK:   = arith.minsi
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((1, 4, 16, 1), f32),
-        RankedTensorType.get((2, 2), f32),
-        RankedTensorType.get((1, 2, 4, 1), i32))
-    def test_f32i32_min_pooling(input, shape, init_result):
-      return pooling_poly(
-          input,
-          shape,
-          outs=[init_result],
-          reduce=BinaryFn.min_signed,
-          strides=[2, 4],
-          dilations=[1, 2])
-
-    # CHECK-LABEL: @test_f32i32_min_unsigned_pooling
-    # CHECK:   = arith.fptoui
-    # CHECK:   = arith.minui
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((1, 4, 16, 1), f32),
-        RankedTensorType.get((2, 2), f32),
-        RankedTensorType.get((1, 2, 4, 1), i32))
-    def test_f32i32_min_unsigned_pooling(input, shape, init_result):
-      return pooling_poly(
-          input,
-          shape,
-          outs=[init_result],
-          reduce=BinaryFn.min_unsigned,
-          cast=TypeFn.cast_unsigned,
-          strides=[2, 4],
-          dilations=[1, 2])
-
-    # CHECK-LABEL: @test_f32f32_min_pooling
-    # CHECK:   = arith.minf
-    @func.FuncOp.from_py_func(
-        RankedTensorType.get((1, 4, 16, 1), f32),
-        RankedTensorType.get((2, 2), f32),
-        RankedTensorType.get((1, 2, 4, 1), f32))
-    def test_f32f32_min_pooling(input, shape, init_result):
-      return pooling_poly(
-          input,
-          shape,
-          outs=[init_result],
-          reduce=BinaryFn.min_signed,
-          strides=[2, 4],
-          dilations=[1, 2])
+    module = Module.create()
+    f32 = F32Type.get()
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(module.body):
+
+        # Pooling indexing maps.
+        # CHECK: #[[$POOL_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d3, d2 * 4 + d4 * 2, d5)>
+        # CHECK: #[[$POOL_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>
+        # CHECK: #[[$POOL_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
+
+        # CHECK-LABEL: @test_f32i32_max_pooling
+        # CHECK: linalg.generic
+        # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], #[[$POOL_MAP_O]]]
+        # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32)
+        # CHECK-NEXT:   %[[IN_CAST:.+]] = arith.fptosi %[[IN:.+]] : f32 to i32
+        # CHECK-NEXT:   %[[MAX:.+]] = arith.maxsi %[[OUT]], %[[IN_CAST:.+]] : i32
+        # CHECK-NEXT:   linalg.yield %[[MAX]] : i32
+        # CHECK-NEXT: -> tensor<1x2x4x1xi32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((1, 4, 16, 1), f32),
+            RankedTensorType.get((2, 2), f32),
+            RankedTensorType.get((1, 2, 4, 1), i32),
+        )
+        def test_f32i32_max_pooling(input, shape, init_result):
+            return pooling_poly(
+                input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2]
+            )
+
+        # CHECK-LABEL: @test_f32i32_max_unsigned_pooling
+        # CHECK:   = arith.fptoui
+        # CHECK:   = arith.maxui
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((1, 4, 16, 1), f32),
+            RankedTensorType.get((2, 2), f32),
+            RankedTensorType.get((1, 2, 4, 1), i32),
+        )
+        def test_f32i32_max_unsigned_pooling(input, shape, init_result):
+            return pooling_poly(
+                input,
+                shape,
+                outs=[init_result],
+                reduce=BinaryFn.max_unsigned,
+                cast=TypeFn.cast_unsigned,
+                strides=[2, 4],
+                dilations=[1, 2],
+            )
+
+        # CHECK-LABEL: @test_f32f32_max_pooling
+        # CHECK: linalg.generic
+        # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], #[[$POOL_MAP_O]]]
+        # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: f32)
+        # CHECK-NEXT:   %[[MAX:.+]] = arith.maxf %[[OUT]], %[[IN:.+]] : f32
+        # CHECK-NEXT:   linalg.yield %[[MAX]] : f32
+        # CHECK-NEXT: -> tensor<1x2x4x1xf32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((1, 4, 16, 1), f32),
+            RankedTensorType.get((2, 2), f32),
+            RankedTensorType.get((1, 2, 4, 1), f32),
+        )
+        def test_f32f32_max_pooling(input, shape, init_result):
+            return pooling_poly(
+                input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2]
+            )
+
+        # CHECK-LABEL: @test_f32i32_min_pooling
+        # CHECK:   = arith.fptosi
+        # CHECK:   = arith.minsi
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((1, 4, 16, 1), f32),
+            RankedTensorType.get((2, 2), f32),
+            RankedTensorType.get((1, 2, 4, 1), i32),
+        )
+        def test_f32i32_min_pooling(input, shape, init_result):
+            return pooling_poly(
+                input,
+                shape,
+                outs=[init_result],
+                reduce=BinaryFn.min_signed,
+                strides=[2, 4],
+                dilations=[1, 2],
+            )
+
+        # CHECK-LABEL: @test_f32i32_min_unsigned_pooling
+        # CHECK:   = arith.fptoui
+        # CHECK:   = arith.minui
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((1, 4, 16, 1), f32),
+            RankedTensorType.get((2, 2), f32),
+            RankedTensorType.get((1, 2, 4, 1), i32),
+        )
+        def test_f32i32_min_unsigned_pooling(input, shape, init_result):
+            return pooling_poly(
+                input,
+                shape,
+                outs=[init_result],
+                reduce=BinaryFn.min_unsigned,
+                cast=TypeFn.cast_unsigned,
+                strides=[2, 4],
+                dilations=[1, 2],
+            )
+
+        # CHECK-LABEL: @test_f32f32_min_pooling
+        # CHECK:   = arith.minf
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((1, 4, 16, 1), f32),
+            RankedTensorType.get((2, 2), f32),
+            RankedTensorType.get((1, 2, 4, 1), f32),
+        )
+        def test_f32f32_min_pooling(input, shape, init_result):
+            return pooling_poly(
+                input,
+                shape,
+                outs=[init_result],
+                reduce=BinaryFn.min_signed,
+                strides=[2, 4],
+                dilations=[1, 2],
+            )
 
 
 print(module)
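The pooling indexing maps checked at the top of this hunk follow directly from the access expression in pooling_poly: with strides=[2, 4] (S.SH=2, S.SW=4) and dilations=[1, 2] (S.DH=1, S.DW=2), the input read I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] becomes the affine map (d0, d1 * 2 + d3, d2 * 4 + d4 * 2, d5), i.e. #[[$POOL_MAP_I]].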
index cead85f..18d2d45 100644 (file)
@@ -4,6 +4,6 @@
 # Since both lit and the python bindings use the same python interpreter,
 # we can just check whether yaml can be imported here and exclude if not.
 try:
-  import yaml
+    import yaml
 except ModuleNotFoundError:
-  config.unsupported = True
+    config.unsupported = True
index a7502e9..9c940e1 100644 (file)
@@ -13,8 +13,10 @@ from mlir.dialects.linalg.opdsl.lang import *
 def matmul(
     A=TensorDef(T, S.M, S.K),
     B=TensorDef(T, S.K, S.N),
-    C=TensorDef(U, S.M, S.N, output=True)):
-  implements(ContractionOpInterface)
-  defines(Canonicalizer)
-  C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
-      U, B[D.k, D.n])
+    C=TensorDef(U, S.M, S.N, output=True),
+):
+    implements(ContractionOpInterface)
+    defines(Canonicalizer)
+    C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
+        U, B[D.k, D.n]
+    )
index 871341c..4f3569b 100644 (file)
@@ -22,10 +22,12 @@ from mlir.dialects.linalg.opdsl.lang import *
 def matmul(
     A=TensorDef(T, S.M, S.K),
     B=TensorDef(T, S.K, S.N),
-    C=TensorDef(U, S.M, S.N, output=True)):
-  domain(D.m, D.n, D.k)
-  C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
-      U, B[D.k, D.n])
+    C=TensorDef(U, S.M, S.N, output=True),
+):
+    domain(D.m, D.n, D.k)
+    C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
+        U, B[D.k, D.n]
+    )
 
 
 # Verifies that assignment to a scalar (represented as [None]) is represented
@@ -43,7 +45,7 @@ def matmul(
 # CHECK-NEXT: - reduction
 @linalg_structured_op
 def dot(A=TensorDef(T, S.M), B=TensorDef(T, S.M), C=TensorDef(U, output=True)):
-  C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m])
+    C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m])
 
 
 # Verifies that the index_dims of shape-only operands translate to correct
@@ -64,6 +66,7 @@ def dot(A=TensorDef(T, S.M), B=TensorDef(T, S.M), C=TensorDef(U, output=True)):
 def pool(
     I=TensorDef(T, S.I),
     K=TensorDef(T, S.K, index_dims=[D.k]),
-    O=TensorDef(U, S.O, output=True)):
-  domain(D.o, D.k)
-  O[D.o] += TypeFn.cast_signed(U, I[D.o * 2 + D.k])
+    O=TensorDef(U, S.O, output=True),
+):
+    domain(D.o, D.k)
+    O[D.o] += TypeFn.cast_signed(U, I[D.o * 2 + D.k])
index 1167abf..5e8414a 100644 (file)
@@ -6,145 +6,154 @@ from mlir.ir import *
 
 
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    return f
 
 
 # CHECK-LABEL: TEST: testFill
 @run
 def testFill():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    with InsertionPoint(module.body):
-      # CHECK-LABEL: func @fill_tensor
-      #  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: tensor<12x?xf32>
-      #  CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
-      #  CHECK-NEXT: %[[RES:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : tensor<12x?xf32>) -> tensor<12x?xf32>
-      #  CHECK-NEXT: return %[[RES]] : tensor<12x?xf32>
-      @func.FuncOp.from_py_func(
-          RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32))
-      def fill_tensor(out):
-        zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
-        return linalg.fill(zero, outs=[out])
-
-      # CHECK-LABEL: func @fill_buffer
-      #  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: memref<12x?xf32>
-      #  CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
-      #  CHECK-NEXT: linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : memref<12x?xf32>)
-      #  CHECK-NEXT: return
-      @func.FuncOp.from_py_func(
-          MemRefType.get((12, ShapedType.get_dynamic_size()), f32))
-      def fill_buffer(out):
-        zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
-        linalg.fill(zero, outs=[out])
-
-  print(module)
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            # CHECK-LABEL: func @fill_tensor
+            #  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: tensor<12x?xf32>
+            #  CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
+            #  CHECK-NEXT: %[[RES:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : tensor<12x?xf32>) -> tensor<12x?xf32>
+            #  CHECK-NEXT: return %[[RES]] : tensor<12x?xf32>
+            @func.FuncOp.from_py_func(
+                RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32)
+            )
+            def fill_tensor(out):
+                zero = arith.ConstantOp(
+                    value=FloatAttr.get(f32, 0.0), result=f32
+                ).result
+                return linalg.fill(zero, outs=[out])
+
+            # CHECK-LABEL: func @fill_buffer
+            #  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: memref<12x?xf32>
+            #  CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
+            #  CHECK-NEXT: linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : memref<12x?xf32>)
+            #  CHECK-NEXT: return
+            @func.FuncOp.from_py_func(
+                MemRefType.get((12, ShapedType.get_dynamic_size()), f32)
+            )
+            def fill_buffer(out):
+                zero = arith.ConstantOp(
+                    value=FloatAttr.get(f32, 0.0), result=f32
+                ).result
+                linalg.fill(zero, outs=[out])
+
+    print(module)
 
 
 # CHECK-LABEL: TEST: testNamedStructuredOpCustomForm
 @run
 def testNamedStructuredOpCustomForm():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    with InsertionPoint(module.body):
-
-      @func.FuncOp.from_py_func(
-          RankedTensorType.get((4, 8), f32), RankedTensorType.get((4, 8), f32))
-      def named_form(lhs, rhs):
-        init_result = tensor.EmptyOp([4, 8], f32)
-        # Check for the named form with custom format
-        #      CHECK: linalg.elemwise_unary
-        # CHECK-SAME:    cast = #linalg.type_fn<cast_signed>
-        # CHECK-SAME:    fun = #linalg.unary_fn<exp>
-        # CHECK-SAME:    ins(%{{.*}} : tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
-        unary_result = linalg.elemwise_unary(lhs, outs=[init_result.result])
-        #      CHECK: linalg.elemwise_binary
-        # CHECK-SAME:    cast = #linalg.type_fn<cast_unsigned>
-        # CHECK-SAME:    fun = #linalg.binary_fn<mul>
-        # CHECK-SAME:    ins(%{{.*}}, %{{.*}} : tensor<4x8xf32>, tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
-        #      CHECK: return
-        binary_result = linalg.elemwise_binary(
-            lhs,
-            rhs,
-            outs=[init_result.result],
-            fun=BinaryFn.mul,
-            cast=TypeFn.cast_unsigned)
-        return unary_result, binary_result
-
-  print(module)
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                RankedTensorType.get((4, 8), f32), RankedTensorType.get((4, 8), f32)
+            )
+            def named_form(lhs, rhs):
+                init_result = tensor.EmptyOp([4, 8], f32)
+                # Check for the named form with custom format
+                #      CHECK: linalg.elemwise_unary
+                # CHECK-SAME:    cast = #linalg.type_fn<cast_signed>
+                # CHECK-SAME:    fun = #linalg.unary_fn<exp>
+                # CHECK-SAME:    ins(%{{.*}} : tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
+                unary_result = linalg.elemwise_unary(lhs, outs=[init_result.result])
+                #      CHECK: linalg.elemwise_binary
+                # CHECK-SAME:    cast = #linalg.type_fn<cast_unsigned>
+                # CHECK-SAME:    fun = #linalg.binary_fn<mul>
+                # CHECK-SAME:    ins(%{{.*}}, %{{.*}} : tensor<4x8xf32>, tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
+                #      CHECK: return
+                binary_result = linalg.elemwise_binary(
+                    lhs,
+                    rhs,
+                    outs=[init_result.result],
+                    fun=BinaryFn.mul,
+                    cast=TypeFn.cast_unsigned,
+                )
+                return unary_result, binary_result
+
+    print(module)
 
 
 # CHECK-LABEL: TEST: testNamedStructuredOpGenericForm
 @run
 def testNamedStructuredOpGenericForm():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    with InsertionPoint(module.body):
-
-      @func.FuncOp.from_py_func(
-          RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8),
-                                                                   f32))
-      def named_form(lhs, rhs):
-        init_result = tensor.EmptyOp([4, 8], f32)
-        #      CHECK: "linalg.matmul"(%{{.*}})
-        # CHECK-SAME:    cast = #linalg.type_fn<cast_signed>
-        # CHECK-SAME:    operand_segment_sizes = array<i32: 2, 1>
-        # CHECK-NEXT:  ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
-        # CHECK-NEXT:    arith.mulf{{.*}} (f32, f32) -> f32
-        # CHECK-NEXT:    arith.addf{{.*}} (f32, f32) -> f32
-        # CHECK-NEXT:    linalg.yield{{.*}} (f32) -> ()
-        # CHECK-NEXT: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
-        return linalg.matmul(lhs, rhs, outs=[init_result.result])
-
-  module.operation.print(print_generic_op_form=True)
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
+            )
+            def named_form(lhs, rhs):
+                init_result = tensor.EmptyOp([4, 8], f32)
+                #      CHECK: "linalg.matmul"(%{{.*}})
+                # CHECK-SAME:    cast = #linalg.type_fn<cast_signed>
+                # CHECK-SAME:    operand_segment_sizes = array<i32: 2, 1>
+                # CHECK-NEXT:  ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
+                # CHECK-NEXT:    arith.mulf{{.*}} (f32, f32) -> f32
+                # CHECK-NEXT:    arith.addf{{.*}} (f32, f32) -> f32
+                # CHECK-NEXT:    linalg.yield{{.*}} (f32) -> ()
+                # CHECK-NEXT: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
+                return linalg.matmul(lhs, rhs, outs=[init_result.result])
+
+    module.operation.print(print_generic_op_form=True)
 
 
 # CHECK-LABEL: TEST: testNamedStructuredAsGenericOp
 @run
 def testNamedStructuredAsGenericOp():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    with InsertionPoint(module.body):
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
 
-      @func.FuncOp.from_py_func(
-          RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8),
-                                                                   f32))
-      def generic_form(lhs, rhs):
-        init_result = tensor.EmptyOp([4, 8], f32)
-        # CHECK: linalg.generic
-        return linalg.matmul(
-            lhs, rhs, outs=[init_result.result], emit_generic=True)
+            @func.FuncOp.from_py_func(
+                RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
+            )
+            def generic_form(lhs, rhs):
+                init_result = tensor.EmptyOp([4, 8], f32)
+                # CHECK: linalg.generic
+                return linalg.matmul(
+                    lhs, rhs, outs=[init_result.result], emit_generic=True
+                )
 
-  print(module)
+    print(module)
 
 
 # CHECK-LABEL: TEST: testOpResultFromOtherOp
 @run
 def testOpResultFromOtherOp():
-  with Context(), Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    with InsertionPoint(module.body):
-
-      @func.FuncOp.from_py_func(
-          RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8),
-                                                                   f32))
-      def pass_an_op_directly(arg0, arg1):
-        one = arith.ConstantOp(F32Type.get(), 1.0)
-        # CHECK: %[[LHS:.*]] = linalg.fill
-        lhs = linalg.fill(one, outs=[arg0])
-        # CHECK: %[[RHS:.*]] = linalg.fill
-        rhs = linalg.fill(one, outs=[arg1])
-        # CHECK: %[[INIT:.*]] = tensor.empty
-        init = tensor.EmptyOp([4, 8], f32)
-        # CHECK: linalg.matmul
-        # CHECK: ins(%[[LHS]], %[[RHS]]
-        # CHECK: outs(%[[INIT]]
-        return linalg.matmul(lhs, rhs, outs=init)
-
-  print(module)
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
+            )
+            def pass_an_op_directly(arg0, arg1):
+                one = arith.ConstantOp(F32Type.get(), 1.0)
+                # CHECK: %[[LHS:.*]] = linalg.fill
+                lhs = linalg.fill(one, outs=[arg0])
+                # CHECK: %[[RHS:.*]] = linalg.fill
+                rhs = linalg.fill(one, outs=[arg1])
+                # CHECK: %[[INIT:.*]] = tensor.empty
+                init = tensor.EmptyOp([4, 8], f32)
+                # CHECK: linalg.matmul
+                # CHECK: ins(%[[LHS]], %[[RHS]]
+                # CHECK: outs(%[[INIT]]
+                return linalg.matmul(lhs, rhs, outs=init)
+
+    print(module)
index 04b6d84..3d402c5 100644 (file)
@@ -7,23 +7,26 @@ from mlir.ir import *
 import mlir.dialects.func as func
 import mlir.dialects.math as mlir_math
 
+
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
+    print("\nTEST:", f.__name__)
+    f()
+
 
 # CHECK-LABEL: TEST: testMathOps
 @run
 def testMathOps():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    with InsertionPoint(module.body):
-      @func.FuncOp.from_py_func(F32Type.get())
-      def emit_sqrt(arg):
-        return mlir_math.SqrtOp(arg)
-
-    # CHECK-LABEL: func @emit_sqrt(
-    # CHECK-SAME:                  %[[ARG:.*]]: f32) -> f32 {
-    # CHECK:         math.sqrt %[[ARG]] : f32
-    # CHECK:         return
-    # CHECK:       }
-    print(module)
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(F32Type.get())
+            def emit_sqrt(arg):
+                return mlir_math.SqrtOp(arg)
+
+        # CHECK-LABEL: func @emit_sqrt(
+        # CHECK-SAME:                  %[[ARG:.*]]: f32) -> f32 {
+        # CHECK:         math.sqrt %[[ARG]] : f32
+        # CHECK:         return
+        # CHECK:       }
+        print(module)
index 59092fe..2e3cae6 100644 (file)
@@ -6,17 +6,17 @@ import mlir.dialects.memref as memref
 
 
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    return f
 
 
 # CHECK-LABEL: TEST: testSubViewAccessors
 @run
 def testSubViewAccessors():
-  ctx = Context()
-  module = Module.parse(
-      r"""
+    ctx = Context()
+    module = Module.parse(
+        r"""
     func.func @f1(%arg0: memref<?x?xf32>) {
       %0 = arith.constant 0 : index
       %1 = arith.constant 1 : index
@@ -27,48 +27,52 @@ def testSubViewAccessors():
       memref.subview %arg0[%0, %1][%2, %3][%4, %5] : memref<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
       return
     }
-  """, ctx)
-  func_body = module.body.operations[0].regions[0].blocks[0]
-  subview = func_body.operations[6]
+  """,
+        ctx,
+    )
+    func_body = module.body.operations[0].regions[0].blocks[0]
+    subview = func_body.operations[6]
 
-  assert subview.source == subview.operands[0]
-  assert len(subview.offsets) == 2
-  assert len(subview.sizes) == 2
-  assert len(subview.strides) == 2
-  assert subview.result == subview.results[0]
+    assert subview.source == subview.operands[0]
+    assert len(subview.offsets) == 2
+    assert len(subview.sizes) == 2
+    assert len(subview.strides) == 2
+    assert subview.result == subview.results[0]
 
-  # CHECK: SubViewOp
-  print(type(subview).__name__)
+    # CHECK: SubViewOp
+    print(type(subview).__name__)
 
-  # CHECK: constant 0
-  print(subview.offsets[0])
-  # CHECK: constant 1
-  print(subview.offsets[1])
-  # CHECK: constant 2
-  print(subview.sizes[0])
-  # CHECK: constant 3
-  print(subview.sizes[1])
-  # CHECK: constant 4
-  print(subview.strides[0])
-  # CHECK: constant 5
-  print(subview.strides[1])
+    # CHECK: constant 0
+    print(subview.offsets[0])
+    # CHECK: constant 1
+    print(subview.offsets[1])
+    # CHECK: constant 2
+    print(subview.sizes[0])
+    # CHECK: constant 3
+    print(subview.sizes[1])
+    # CHECK: constant 4
+    print(subview.strides[0])
+    # CHECK: constant 5
+    print(subview.strides[1])
 
 
 # CHECK-LABEL: TEST: testCustomBuidlers
 @run
 def testCustomBuidlers():
-  with Context() as ctx, Location.unknown(ctx):
-    module = Module.parse(r"""
+    with Context() as ctx, Location.unknown(ctx):
+        module = Module.parse(
+            r"""
       func.func @f1(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) {
         return
       }
-    """)
-    f = module.body.operations[0]
-    func_body = f.regions[0].blocks[0]
-    with InsertionPoint.at_block_terminator(func_body):
-      memref.LoadOp(f.arguments[0], f.arguments[1:])
+    """
+        )
+        f = module.body.operations[0]
+        func_body = f.regions[0].blocks[0]
+        with InsertionPoint.at_block_terminator(func_body):
+            memref.LoadOp(f.arguments[0], f.arguments[1:])
 
-    # CHECK: func @f1(%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
-    # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
-    print(module)
-    assert module.operation.verify()
+        # CHECK: func @f1(%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+        # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
+        print(module)
+        assert module.operation.verify()
index 4d9804f..f16de2a 100644 (file)
@@ -6,23 +6,23 @@ from mlir.dialects import ml_program
 
 
 def constructAndPrintInModule(f):
-  print("\nTEST:", f.__name__)
-  with Context(), Location.unknown():
-    module = Module.create()
-    with InsertionPoint(module.body):
-      f()
-    print(module)
-  return f
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            f()
+        print(module)
+    return f
 
 
 # CHECK-LABEL: testFuncOp
 @constructAndPrintInModule
 def testFuncOp():
-  # CHECK: ml_program.func @foobar(%arg0: si32) -> si32
-  f = ml_program.FuncOp(
-      name="foobar",
-      type=([IntegerType.get_signed(32)], [IntegerType.get_signed(32)]))
-  block = f.add_entry_block()
-  with InsertionPoint(block):
-    # CHECK: ml_program.return
-    ml_program.ReturnOp([block.arguments[0]])
+    # CHECK: ml_program.func @foobar(%arg0: si32) -> si32
+    f = ml_program.FuncOp(
+        name="foobar", type=([IntegerType.get_signed(32)], [IntegerType.get_signed(32)])
+    )
+    block = f.add_entry_block()
+    with InsertionPoint(block):
+        # CHECK: ml_program.return
+        ml_program.ReturnOp([block.arguments[0]])
index 802a1f2..71879bd 100644 (file)
@@ -6,207 +6,205 @@ from mlir.ir import *
 
 
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  gc.collect()
-  assert Context._get_live_count() == 0
+    print("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
 
 
 def add_dummy_value():
-  return Operation.create(
-      "custom.value",
-      results=[IntegerType.get_signless(32)]).result
+    return Operation.create(
+        "custom.value", results=[IntegerType.get_signless(32)]
+    ).result
 
 
 def testOdsBuildDefaultImplicitRegions():
-
-  class TestFixedRegionsOp(OpView):
-    OPERATION_NAME = "custom.test_op"
-    _ODS_REGIONS = (2, True)
-
-  class TestVariadicRegionsOp(OpView):
-    OPERATION_NAME = "custom.test_any_regions_op"
-    _ODS_REGIONS = (2, False)
-
-  with Context() as ctx, Location.unknown():
-    ctx.allow_unregistered_dialects = True
-    m = Module.create()
-    with InsertionPoint(m.body):
-      op = TestFixedRegionsOp.build_generic(results=[], operands=[])
-      # CHECK: NUM_REGIONS: 2
-      print(f"NUM_REGIONS: {len(op.regions)}")
-      # Including a regions= that matches should be fine.
-      op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=2)
-      print(f"NUM_REGIONS: {len(op.regions)}")
-      # Reject greater than.
-      try:
-        op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=3)
-      except ValueError as e:
-        # CHECK: ERROR:Operation "custom.test_op" requires a maximum of 2 regions but was built with regions=3
-        print(f"ERROR:{e}")
-      # Reject less than.
-      try:
-        op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=1)
-      except ValueError as e:
-        # CHECK: ERROR:Operation "custom.test_op" requires a minimum of 2 regions but was built with regions=1
-        print(f"ERROR:{e}")
-
-      # If no regions specified for a variadic region op, build the minimum.
-      op = TestVariadicRegionsOp.build_generic(results=[], operands=[])
-      # CHECK: DEFAULT_NUM_REGIONS: 2
-      print(f"DEFAULT_NUM_REGIONS: {len(op.regions)}")
-      # Should also accept an explicit regions= that matches the minimum.
-      op = TestVariadicRegionsOp.build_generic(
-          results=[], operands=[], regions=2)
-      # CHECK: EQ_NUM_REGIONS: 2
-      print(f"EQ_NUM_REGIONS: {len(op.regions)}")
-      # And accept greater than minimum.
-      # Should also accept an explicit regions= that matches the minimum.
-      op = TestVariadicRegionsOp.build_generic(
-          results=[], operands=[], regions=3)
-      # CHECK: GT_NUM_REGIONS: 3
-      print(f"GT_NUM_REGIONS: {len(op.regions)}")
-      # Should reject less than minimum.
-      try:
-        op = TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=1)
-      except ValueError as e:
-        # CHECK: ERROR:Operation "custom.test_any_regions_op" requires a minimum of 2 regions but was built with regions=1
-        print(f"ERROR:{e}")
-
+    class TestFixedRegionsOp(OpView):
+        OPERATION_NAME = "custom.test_op"
+        _ODS_REGIONS = (2, True)
+
+    class TestVariadicRegionsOp(OpView):
+        OPERATION_NAME = "custom.test_any_regions_op"
+        _ODS_REGIONS = (2, False)
+
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+        m = Module.create()
+        with InsertionPoint(m.body):
+            op = TestFixedRegionsOp.build_generic(results=[], operands=[])
+            # CHECK: NUM_REGIONS: 2
+            print(f"NUM_REGIONS: {len(op.regions)}")
+            # Including a regions= that matches should be fine.
+            op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=2)
+            print(f"NUM_REGIONS: {len(op.regions)}")
+            # Reject greater than.
+            try:
+                op = TestFixedRegionsOp.build_generic(
+                    results=[], operands=[], regions=3
+                )
+            except ValueError as e:
+                # CHECK: ERROR:Operation "custom.test_op" requires a maximum of 2 regions but was built with regions=3
+                print(f"ERROR:{e}")
+            # Reject less than.
+            try:
+                op = TestFixedRegionsOp.build_generic(
+                    results=[], operands=[], regions=1
+                )
+            except ValueError as e:
+                # CHECK: ERROR:Operation "custom.test_op" requires a minimum of 2 regions but was built with regions=1
+                print(f"ERROR:{e}")
+
+            # If no regions specified for a variadic region op, build the minimum.
+            op = TestVariadicRegionsOp.build_generic(results=[], operands=[])
+            # CHECK: DEFAULT_NUM_REGIONS: 2
+            print(f"DEFAULT_NUM_REGIONS: {len(op.regions)}")
+            # Should also accept an explicit regions= that matches the minimum.
+            op = TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=2)
+            # CHECK: EQ_NUM_REGIONS: 2
+            print(f"EQ_NUM_REGIONS: {len(op.regions)}")
+            # And accept greater than minimum.
+            # Should also accept an explicit regions= that matches the minimum.
+            op = TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=3)
+            # CHECK: GT_NUM_REGIONS: 3
+            print(f"GT_NUM_REGIONS: {len(op.regions)}")
+            # Should reject less than minimum.
+            try:
+                op = TestVariadicRegionsOp.build_generic(
+                    results=[], operands=[], regions=1
+                )
+            except ValueError as e:
+                # CHECK: ERROR:Operation "custom.test_any_regions_op" requires a minimum of 2 regions but was built with regions=1
+                print(f"ERROR:{e}")
 
 
 run(testOdsBuildDefaultImplicitRegions)
 
 
 def testOdsBuildDefaultNonVariadic():
+    class TestOp(OpView):
+        OPERATION_NAME = "custom.test_op"
+
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+        m = Module.create()
+        with InsertionPoint(m.body):
+            v0 = add_dummy_value()
+            v1 = add_dummy_value()
+            t0 = IntegerType.get_signless(8)
+            t1 = IntegerType.get_signless(16)
+            op = TestOp.build_generic(results=[t0, t1], operands=[v0, v1])
+            # CHECK: %[[V0:.+]] = "custom.value"
+            # CHECK: %[[V1:.+]] = "custom.value"
+            # CHECK: "custom.test_op"(%[[V0]], %[[V1]])
+            # CHECK-NOT: operand_segment_sizes
+            # CHECK-NOT: result_segment_sizes
+            # CHECK-SAME: : (i32, i32) -> (i8, i16)
+            print(m)
 
-  class TestOp(OpView):
-    OPERATION_NAME = "custom.test_op"
-
-  with Context() as ctx, Location.unknown():
-    ctx.allow_unregistered_dialects = True
-    m = Module.create()
-    with InsertionPoint(m.body):
-      v0 = add_dummy_value()
-      v1 = add_dummy_value()
-      t0 = IntegerType.get_signless(8)
-      t1 = IntegerType.get_signless(16)
-      op = TestOp.build_generic(results=[t0, t1], operands=[v0, v1])
-      # CHECK: %[[V0:.+]] = "custom.value"
-      # CHECK: %[[V1:.+]] = "custom.value"
-      # CHECK: "custom.test_op"(%[[V0]], %[[V1]])
-      # CHECK-NOT: operand_segment_sizes
-      # CHECK-NOT: result_segment_sizes
-      # CHECK-SAME: : (i32, i32) -> (i8, i16)
-      print(m)
 
 run(testOdsBuildDefaultNonVariadic)
 
 
 def testOdsBuildDefaultSizedVariadic():
+    class TestOp(OpView):
+        OPERATION_NAME = "custom.test_op"
+        _ODS_OPERAND_SEGMENTS = [1, -1, 0]
+        _ODS_RESULT_SEGMENTS = [-1, 0, 1]
+
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+        m = Module.create()
+        with InsertionPoint(m.body):
+            v0 = add_dummy_value()
+            v1 = add_dummy_value()
+            v2 = add_dummy_value()
+            v3 = add_dummy_value()
+            t0 = IntegerType.get_signless(8)
+            t1 = IntegerType.get_signless(16)
+            t2 = IntegerType.get_signless(32)
+            t3 = IntegerType.get_signless(64)
+            # CHECK: %[[V0:.+]] = "custom.value"
+            # CHECK: %[[V1:.+]] = "custom.value"
+            # CHECK: %[[V2:.+]] = "custom.value"
+            # CHECK: %[[V3:.+]] = "custom.value"
+            # CHECK: "custom.test_op"(%[[V0]], %[[V1]], %[[V2]], %[[V3]])
+            # CHECK-SAME: operand_segment_sizes = array<i32: 1, 2, 1>
+            # CHECK-SAME: result_segment_sizes = array<i32: 2, 1, 1>
+            # CHECK-SAME: : (i32, i32, i32, i32) -> (i8, i16, i32, i64)
+            op = TestOp.build_generic(
+                results=[[t0, t1], t2, t3], operands=[v0, [v1, v2], v3]
+            )
+
+            # Now test with optional omitted.
+            # CHECK: "custom.test_op"(%[[V0]])
+            # CHECK-SAME: operand_segment_sizes = array<i32: 1, 0, 0>
+            # CHECK-SAME: result_segment_sizes = array<i32: 0, 0, 1>
+            # CHECK-SAME: (i32) -> i64
+            op = TestOp.build_generic(
+                results=[None, None, t3], operands=[v0, None, None]
+            )
+            print(m)
+
+            # And verify that errors are raised for None in a required operand.
+            try:
+                op = TestOp.build_generic(
+                    results=[None, None, t3], operands=[None, None, None]
+                )
+            except ValueError as e:
+                # CHECK: OPERAND_CAST_ERROR:Operand 0 of operation "custom.test_op" must be a Value (was None and operand is not optional)
+                print(f"OPERAND_CAST_ERROR:{e}")
+
+            # And verify that errors are raised for None in a required result.
+            try:
+                op = TestOp.build_generic(
+                    results=[None, None, None], operands=[v0, None, None]
+                )
+            except ValueError as e:
+                # CHECK: RESULT_CAST_ERROR:Result 2 of operation "custom.test_op" must be a Type (was None and result is not optional)
+                print(f"RESULT_CAST_ERROR:{e}")
+
+            # Variadic lists with None elements should reject.
+            try:
+                op = TestOp.build_generic(
+                    results=[None, None, t3], operands=[v0, [None], None]
+                )
+            except ValueError as e:
+                # CHECK: OPERAND_LIST_CAST_ERROR:Operand 1 of operation "custom.test_op" must be a Sequence of Values (contained a None item)
+                print(f"OPERAND_LIST_CAST_ERROR:{e}")
+            try:
+                op = TestOp.build_generic(
+                    results=[[None], None, t3], operands=[v0, None, None]
+                )
+            except ValueError as e:
+                # CHECK: RESULT_LIST_CAST_ERROR:Result 0 of operation "custom.test_op" must be a Sequence of Types (contained a None item)
+                print(f"RESULT_LIST_CAST_ERROR:{e}")
 
-  class TestOp(OpView):
-    OPERATION_NAME = "custom.test_op"
-    _ODS_OPERAND_SEGMENTS = [1, -1, 0]
-    _ODS_RESULT_SEGMENTS = [-1, 0, 1]
-
-  with Context() as ctx, Location.unknown():
-    ctx.allow_unregistered_dialects = True
-    m = Module.create()
-    with InsertionPoint(m.body):
-      v0 = add_dummy_value()
-      v1 = add_dummy_value()
-      v2 = add_dummy_value()
-      v3 = add_dummy_value()
-      t0 = IntegerType.get_signless(8)
-      t1 = IntegerType.get_signless(16)
-      t2 = IntegerType.get_signless(32)
-      t3 = IntegerType.get_signless(64)
-      # CHECK: %[[V0:.+]] = "custom.value"
-      # CHECK: %[[V1:.+]] = "custom.value"
-      # CHECK: %[[V2:.+]] = "custom.value"
-      # CHECK: %[[V3:.+]] = "custom.value"
-      # CHECK: "custom.test_op"(%[[V0]], %[[V1]], %[[V2]], %[[V3]])
-      # CHECK-SAME: operand_segment_sizes = array<i32: 1, 2, 1>
-      # CHECK-SAME: result_segment_sizes = array<i32: 2, 1, 1>
-      # CHECK-SAME: : (i32, i32, i32, i32) -> (i8, i16, i32, i64)
-      op = TestOp.build_generic(
-          results=[[t0, t1], t2, t3],
-          operands=[v0, [v1, v2], v3])
-
-      # Now test with optional omitted.
-      # CHECK: "custom.test_op"(%[[V0]])
-      # CHECK-SAME: operand_segment_sizes = array<i32: 1, 0, 0>
-      # CHECK-SAME: result_segment_sizes = array<i32: 0, 0, 1>
-      # CHECK-SAME: (i32) -> i64
-      op = TestOp.build_generic(
-          results=[None, None, t3],
-          operands=[v0, None, None])
-      print(m)
-
-      # And verify that errors are raised for None in a required operand.
-      try:
-        op = TestOp.build_generic(
-            results=[None, None, t3],
-            operands=[None, None, None])
-      except ValueError as e:
-        # CHECK: OPERAND_CAST_ERROR:Operand 0 of operation "custom.test_op" must be a Value (was None and operand is not optional)
-        print(f"OPERAND_CAST_ERROR:{e}")
-
-      # And verify that errors are raised for None in a required result.
-      try:
-        op = TestOp.build_generic(
-            results=[None, None, None],
-            operands=[v0, None, None])
-      except ValueError as e:
-        # CHECK: RESULT_CAST_ERROR:Result 2 of operation "custom.test_op" must be a Type (was None and result is not optional)
-        print(f"RESULT_CAST_ERROR:{e}")
-
-      # Variadic lists with None elements should reject.
-      try:
-        op = TestOp.build_generic(
-            results=[None, None, t3],
-            operands=[v0, [None], None])
-      except ValueError as e:
-        # CHECK: OPERAND_LIST_CAST_ERROR:Operand 1 of operation "custom.test_op" must be a Sequence of Values (contained a None item)
-        print(f"OPERAND_LIST_CAST_ERROR:{e}")
-      try:
-        op = TestOp.build_generic(
-            results=[[None], None, t3],
-            operands=[v0, None, None])
-      except ValueError as e:
-        # CHECK: RESULT_LIST_CAST_ERROR:Result 0 of operation "custom.test_op" must be a Sequence of Types (contained a None item)
-        print(f"RESULT_LIST_CAST_ERROR:{e}")
 
 run(testOdsBuildDefaultSizedVariadic)
 
 
 def testOdsBuildDefaultCastError():
+    class TestOp(OpView):
+        OPERATION_NAME = "custom.test_op"
+
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+        m = Module.create()
+        with InsertionPoint(m.body):
+            v0 = add_dummy_value()
+            v1 = add_dummy_value()
+            t0 = IntegerType.get_signless(8)
+            t1 = IntegerType.get_signless(16)
+            try:
+                op = TestOp.build_generic(results=[t0, t1], operands=[None, v1])
+            except ValueError as e:
+                # CHECK: ERROR: Operand 0 of operation "custom.test_op" must be a Value
+                print(f"ERROR: {e}")
+            try:
+                op = TestOp.build_generic(results=[t0, None], operands=[v0, v1])
+            except ValueError as e:
+                # CHECK: Result 1 of operation "custom.test_op" must be a Type
+                print(f"ERROR: {e}")
 
-  class TestOp(OpView):
-    OPERATION_NAME = "custom.test_op"
-
-  with Context() as ctx, Location.unknown():
-    ctx.allow_unregistered_dialects = True
-    m = Module.create()
-    with InsertionPoint(m.body):
-      v0 = add_dummy_value()
-      v1 = add_dummy_value()
-      t0 = IntegerType.get_signless(8)
-      t1 = IntegerType.get_signless(16)
-      try:
-        op = TestOp.build_generic(
-            results=[t0, t1],
-            operands=[None, v1])
-      except ValueError as e:
-        # CHECK: ERROR: Operand 0 of operation "custom.test_op" must be a Value
-        print(f"ERROR: {e}")
-      try:
-        op = TestOp.build_generic(
-            results=[t0, None],
-            operands=[v0, v1])
-      except ValueError as e:
-        # CHECK: Result 1 of operation "custom.test_op" must be a Type
-        print(f"ERROR: {e}")
 
 run(testOdsBuildDefaultCastError)
index 3d9cd19..0d364f9 100644
@@ -5,13 +5,13 @@ from mlir.dialects.pdl import *
 
 
 def constructAndPrintInModule(f):
-  print("\nTEST:", f.__name__)
-  with Context(), Location.unknown():
-    module = Module.create()
-    with InsertionPoint(module.body):
-      f()
-    print(module)
-  return f
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            f()
+        print(module)
+    return f
 
 
 # CHECK: module  {
@@ -27,15 +27,15 @@ def constructAndPrintInModule(f):
 # CHECK: }
 @constructAndPrintInModule
 def test_operations():
-  pattern = PatternOp(1, "operations")
-  with InsertionPoint(pattern.body):
-    attr = AttributeOp()
-    ty = TypeOp()
-    op0 = OperationOp(attributes={"attr": attr}, types=[ty])
-    op0_result = ResultOp(op0, 0)
-    input = OperandOp()
-    root = OperationOp(args=[op0_result, input])
-    RewriteOp(root, "rewriter")
+    pattern = PatternOp(1, "operations")
+    with InsertionPoint(pattern.body):
+        attr = AttributeOp()
+        ty = TypeOp()
+        op0 = OperationOp(attributes={"attr": attr}, types=[ty])
+        op0_result = ResultOp(op0, 0)
+        input = OperandOp()
+        root = OperationOp(args=[op0_result, input])
+        RewriteOp(root, "rewriter")
 
 
 # CHECK: module  {
@@ -47,11 +47,12 @@ def test_operations():
 # CHECK: }
 @constructAndPrintInModule
 def test_rewrite_with_args():
-  pattern = PatternOp(1, "rewrite_with_args")
-  with InsertionPoint(pattern.body):
-    input = OperandOp()
-    root = OperationOp(args=[input])
-    RewriteOp(root, "rewriter", args=[input])
+    pattern = PatternOp(1, "rewrite_with_args")
+    with InsertionPoint(pattern.body):
+        input = OperandOp()
+        root = OperationOp(args=[input])
+        RewriteOp(root, "rewriter", args=[input])
+
 
 # CHECK: module  {
 # CHECK:   pdl.pattern @rewrite_multi_root_optimal : benefit(1)  {
@@ -69,18 +70,19 @@ def test_rewrite_with_args():
 # CHECK: }
 @constructAndPrintInModule
 def test_rewrite_multi_root_optimal():
-  pattern = PatternOp(1, "rewrite_multi_root_optimal")
-  with InsertionPoint(pattern.body):
-    input1 = OperandOp()
-    input2 = OperandOp()
-    ty = TypeOp()
-    op1 = OperationOp(args=[input1], types=[ty])
-    val1 = ResultOp(op1, 0)
-    root1 = OperationOp(args=[val1])
-    op2 = OperationOp(args=[input2], types=[ty])
-    val2 = ResultOp(op2, 0)
-    root2 = OperationOp(args=[val1, val2])
-    RewriteOp(name="rewriter", args=[root1, root2])
+    pattern = PatternOp(1, "rewrite_multi_root_optimal")
+    with InsertionPoint(pattern.body):
+        input1 = OperandOp()
+        input2 = OperandOp()
+        ty = TypeOp()
+        op1 = OperationOp(args=[input1], types=[ty])
+        val1 = ResultOp(op1, 0)
+        root1 = OperationOp(args=[val1])
+        op2 = OperationOp(args=[input2], types=[ty])
+        val2 = ResultOp(op2, 0)
+        root2 = OperationOp(args=[val1, val2])
+        RewriteOp(name="rewriter", args=[root1, root2])
+
 
 # CHECK: module  {
 # CHECK:   pdl.pattern @rewrite_multi_root_forced : benefit(1)  {
@@ -98,18 +100,19 @@ def test_rewrite_multi_root_optimal():
 # CHECK: }
 @constructAndPrintInModule
 def test_rewrite_multi_root_forced():
-  pattern = PatternOp(1, "rewrite_multi_root_forced")
-  with InsertionPoint(pattern.body):
-    input1 = OperandOp()
-    input2 = OperandOp()
-    ty = TypeOp()
-    op1 = OperationOp(args=[input1], types=[ty])
-    val1 = ResultOp(op1, 0)
-    root1 = OperationOp(args=[val1])
-    op2 = OperationOp(args=[input2], types=[ty])
-    val2 = ResultOp(op2, 0)
-    root2 = OperationOp(args=[val1, val2])
-    RewriteOp(root1, name="rewriter", args=[root2])
+    pattern = PatternOp(1, "rewrite_multi_root_forced")
+    with InsertionPoint(pattern.body):
+        input1 = OperandOp()
+        input2 = OperandOp()
+        ty = TypeOp()
+        op1 = OperationOp(args=[input1], types=[ty])
+        val1 = ResultOp(op1, 0)
+        root1 = OperationOp(args=[val1])
+        op2 = OperationOp(args=[input2], types=[ty])
+        val2 = ResultOp(op2, 0)
+        root2 = OperationOp(args=[val1, val2])
+        RewriteOp(root1, name="rewriter", args=[root2])
+
 
 # CHECK: module  {
 # CHECK:   pdl.pattern @rewrite_add_body : benefit(1)  {
@@ -125,16 +128,17 @@ def test_rewrite_multi_root_forced():
 # CHECK: }
 @constructAndPrintInModule
 def test_rewrite_add_body():
-  pattern = PatternOp(1, "rewrite_add_body")
-  with InsertionPoint(pattern.body):
-    ty1 = TypeOp(IntegerType.get_signless(32))
-    ty2 = TypeOp()
-    root = OperationOp(types=[ty1, ty2])
-    rewrite = RewriteOp(root)
-    with InsertionPoint(rewrite.add_body()):
-      ty3 = TypeOp()
-      newOp = OperationOp(name="foo.op", types=[ty1, ty3])
-      ReplaceOp(root, with_op=newOp)
+    pattern = PatternOp(1, "rewrite_add_body")
+    with InsertionPoint(pattern.body):
+        ty1 = TypeOp(IntegerType.get_signless(32))
+        ty2 = TypeOp()
+        root = OperationOp(types=[ty1, ty2])
+        rewrite = RewriteOp(root)
+        with InsertionPoint(rewrite.add_body()):
+            ty3 = TypeOp()
+            newOp = OperationOp(name="foo.op", types=[ty1, ty3])
+            ReplaceOp(root, with_op=newOp)
+
 
 # CHECK: module  {
 # CHECK:   pdl.pattern @rewrite_type : benefit(1)  {
@@ -148,14 +152,15 @@ def test_rewrite_add_body():
 # CHECK: }
 @constructAndPrintInModule
 def test_rewrite_type():
-  pattern = PatternOp(1, "rewrite_type")
-  with InsertionPoint(pattern.body):
-    ty1 = TypeOp(IntegerType.get_signless(32))
-    ty2 = TypeOp()
-    root = OperationOp(types=[ty1, ty2])
-    rewrite = RewriteOp(root)
-    with InsertionPoint(rewrite.add_body()):
-      newOp = OperationOp(name="foo.op", types=[ty1, ty2])
+    pattern = PatternOp(1, "rewrite_type")
+    with InsertionPoint(pattern.body):
+        ty1 = TypeOp(IntegerType.get_signless(32))
+        ty2 = TypeOp()
+        root = OperationOp(types=[ty1, ty2])
+        rewrite = RewriteOp(root)
+        with InsertionPoint(rewrite.add_body()):
+            newOp = OperationOp(name="foo.op", types=[ty1, ty2])
+
 
 # CHECK: module  {
 # CHECK:   pdl.pattern @rewrite_types : benefit(1)  {
@@ -169,14 +174,17 @@ def test_rewrite_type():
 # CHECK: }
 @constructAndPrintInModule
 def test_rewrite_types():
-  pattern = PatternOp(1, "rewrite_types")
-  with InsertionPoint(pattern.body):
-    types = TypesOp()
-    root = OperationOp(types=[types])
-    rewrite = RewriteOp(root)
-    with InsertionPoint(rewrite.add_body()):
-      otherTypes = TypesOp([IntegerType.get_signless(32), IntegerType.get_signless(64)])
-      newOp = OperationOp(name="foo.op", types=[types, otherTypes])
+    pattern = PatternOp(1, "rewrite_types")
+    with InsertionPoint(pattern.body):
+        types = TypesOp()
+        root = OperationOp(types=[types])
+        rewrite = RewriteOp(root)
+        with InsertionPoint(rewrite.add_body()):
+            otherTypes = TypesOp(
+                [IntegerType.get_signless(32), IntegerType.get_signless(64)]
+            )
+            newOp = OperationOp(name="foo.op", types=[types, otherTypes])
+
 
 # CHECK: module  {
 # CHECK:   pdl.pattern @rewrite_operands : benefit(1)  {
@@ -190,14 +198,15 @@ def test_rewrite_types():
 # CHECK: }
 @constructAndPrintInModule
 def test_rewrite_operands():
-  pattern = PatternOp(1, "rewrite_operands")
-  with InsertionPoint(pattern.body):
-    types = TypesOp()
-    operands = OperandsOp(types)
-    root = OperationOp(args=[operands])
-    rewrite = RewriteOp(root)
-    with InsertionPoint(rewrite.add_body()):
-      newOp = OperationOp(name="foo.op", types=[types])
+    pattern = PatternOp(1, "rewrite_operands")
+    with InsertionPoint(pattern.body):
+        types = TypesOp()
+        operands = OperandsOp(types)
+        root = OperationOp(args=[operands])
+        rewrite = RewriteOp(root)
+        with InsertionPoint(rewrite.add_body()):
+            newOp = OperationOp(name="foo.op", types=[types])
+
 
 # CHECK: module  {
 # CHECK:   pdl.pattern @native_rewrite : benefit(1)  {
@@ -209,12 +218,13 @@ def test_rewrite_operands():
 # CHECK: }
 @constructAndPrintInModule
 def test_native_rewrite():
-  pattern = PatternOp(1, "native_rewrite")
-  with InsertionPoint(pattern.body):
-    root = OperationOp()
-    rewrite = RewriteOp(root)
-    with InsertionPoint(rewrite.add_body()):
-      ApplyNativeRewriteOp([], "NativeRewrite", args=[root])
+    pattern = PatternOp(1, "native_rewrite")
+    with InsertionPoint(pattern.body):
+        root = OperationOp()
+        rewrite = RewriteOp(root)
+        with InsertionPoint(rewrite.add_body()):
+            ApplyNativeRewriteOp([], "NativeRewrite", args=[root])
+
 
 # CHECK: module  {
 # CHECK:   pdl.pattern @attribute_with_value : benefit(1)  {
@@ -227,13 +237,14 @@ def test_native_rewrite():
 # CHECK: }
 @constructAndPrintInModule
 def test_attribute_with_value():
-  pattern = PatternOp(1, "attribute_with_value")
-  with InsertionPoint(pattern.body):
-    root = OperationOp()
-    rewrite = RewriteOp(root)
-    with InsertionPoint(rewrite.add_body()):
-      attr = AttributeOp(value=Attribute.parse('"value"'))
-      ApplyNativeRewriteOp([], "NativeRewrite", args=[attr])
+    pattern = PatternOp(1, "attribute_with_value")
+    with InsertionPoint(pattern.body):
+        root = OperationOp()
+        rewrite = RewriteOp(root)
+        with InsertionPoint(rewrite.add_body()):
+            attr = AttributeOp(value=Attribute.parse('"value"'))
+            ApplyNativeRewriteOp([], "NativeRewrite", args=[attr])
+
 
 # CHECK: module  {
 # CHECK:   pdl.pattern @erase : benefit(1)  {
@@ -245,12 +256,13 @@ def test_attribute_with_value():
 # CHECK: }
 @constructAndPrintInModule
 def test_erase():
-  pattern = PatternOp(1, "erase")
-  with InsertionPoint(pattern.body):
-    root = OperationOp()
-    rewrite = RewriteOp(root)
-    with InsertionPoint(rewrite.add_body()):
-      EraseOp(root)
+    pattern = PatternOp(1, "erase")
+    with InsertionPoint(pattern.body):
+        root = OperationOp()
+        rewrite = RewriteOp(root)
+        with InsertionPoint(rewrite.add_body()):
+            EraseOp(root)
+
 
 # CHECK: module  {
 # CHECK:   pdl.pattern @operation_results : benefit(1)  {
@@ -263,14 +275,15 @@ def test_erase():
 # CHECK: }
 @constructAndPrintInModule
 def test_operation_results():
-  valueRange = RangeType.get(ValueType.get())
-  pattern = PatternOp(1, "operation_results")
-  with InsertionPoint(pattern.body):
-    types = TypesOp()
-    inputOp = OperationOp(types=[types])
-    results = ResultsOp(valueRange, inputOp)
-    root = OperationOp(args=[results])
-    RewriteOp(root, name="rewriter")
+    valueRange = RangeType.get(ValueType.get())
+    pattern = PatternOp(1, "operation_results")
+    with InsertionPoint(pattern.body):
+        types = TypesOp()
+        inputOp = OperationOp(types=[types])
+        results = ResultsOp(valueRange, inputOp)
+        root = OperationOp(args=[results])
+        RewriteOp(root, name="rewriter")
+
 
 # CHECK: module  {
 # CHECK:   pdl.pattern : benefit(1)  {
@@ -282,9 +295,9 @@ def test_operation_results():
 # CHECK: }
 @constructAndPrintInModule
 def test_apply_native_constraint():
-  pattern = PatternOp(1)
-  with InsertionPoint(pattern.body):
-    resultType = TypeOp()
-    ApplyNativeConstraintOp("typeConstraint", args=[resultType])
-    root = OperationOp(types=[resultType])
-    RewriteOp(root, name="rewrite")
+    pattern = PatternOp(1)
+    with InsertionPoint(pattern.body):
+        resultType = TypeOp()
+        ApplyNativeConstraintOp("typeConstraint", args=[resultType])
+        root = OperationOp(types=[resultType])
+        RewriteOp(root, name="rewrite")
index 2ca79b2..72a765c 100644
@@ -5,367 +5,373 @@ import mlir.dialects.func as func
 import mlir.dialects.python_test as test
 import mlir.dialects.tensor as tensor
 
+
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    return f
+
 
 # CHECK-LABEL: TEST: testAttributes
 @run
 def testAttributes():
-  with Context() as ctx, Location.unknown():
-    ctx.allow_unregistered_dialects = True
-
-    #
-    # Check op construction with attributes.
-    #
-
-    i32 = IntegerType.get_signless(32)
-    one = IntegerAttr.get(i32, 1)
-    two = IntegerAttr.get(i32, 2)
-    unit = UnitAttr.get()
-
-    # CHECK: "python_test.attributed_op"() {
-    # CHECK-DAG: mandatory_i32 = 1 : i32
-    # CHECK-DAG: optional_i32 = 2 : i32
-    # CHECK-DAG: unit
-    # CHECK: }
-    op = test.AttributedOp(one, optional_i32=two, unit=unit)
-    print(f"{op}")
-
-    # CHECK: "python_test.attributed_op"() {
-    # CHECK: mandatory_i32 = 2 : i32
-    # CHECK: }
-    op2 = test.AttributedOp(two)
-    print(f"{op2}")
-
-    #
-    # Check generic "attributes" access and mutation.
-    #
-
-    assert "additional" not in op.attributes
-
-    # CHECK: "python_test.attributed_op"() {
-    # CHECK-DAG: additional = 1 : i32
-    # CHECK-DAG: mandatory_i32 = 2 : i32
-    # CHECK: }
-    op2.attributes["additional"] = one
-    print(f"{op2}")
-
-    # CHECK: "python_test.attributed_op"() {
-    # CHECK-DAG: additional = 2 : i32
-    # CHECK-DAG: mandatory_i32 = 2 : i32
-    # CHECK: }
-    op2.attributes["additional"] = two
-    print(f"{op2}")
-
-    # CHECK: "python_test.attributed_op"() {
-    # CHECK-NOT: additional = 2 : i32
-    # CHECK:     mandatory_i32 = 2 : i32
-    # CHECK: }
-    del op2.attributes["additional"]
-    print(f"{op2}")
-
-    try:
-      print(op.attributes["additional"])
-    except KeyError:
-      pass
-    else:
-      assert False, "expected KeyError on unknown attribute key"
-
-    #
-    # Check accessors to defined attributes.
-    #
-
-    # CHECK: Mandatory: 1
-    # CHECK: Optional: 2
-    # CHECK: Unit: True
-    print(f"Mandatory: {op.mandatory_i32.value}")
-    print(f"Optional: {op.optional_i32.value}")
-    print(f"Unit: {op.unit}")
-
-    # CHECK: Mandatory: 2
-    # CHECK: Optional: None
-    # CHECK: Unit: False
-    print(f"Mandatory: {op2.mandatory_i32.value}")
-    print(f"Optional: {op2.optional_i32}")
-    print(f"Unit: {op2.unit}")
-
-    # CHECK: Mandatory: 2
-    # CHECK: Optional: None
-    # CHECK: Unit: False
-    op.mandatory_i32 = two
-    op.optional_i32 = None
-    op.unit = False
-    print(f"Mandatory: {op.mandatory_i32.value}")
-    print(f"Optional: {op.optional_i32}")
-    print(f"Unit: {op.unit}")
-    assert "optional_i32" not in op.attributes
-    assert "unit" not in op.attributes
-
-    try:
-      op.mandatory_i32 = None
-    except ValueError:
-      pass
-    else:
-      assert False, "expected ValueError on setting a mandatory attribute to None"
-
-    # CHECK: Optional: 2
-    op.optional_i32 = two
-    print(f"Optional: {op.optional_i32.value}")
-
-    # CHECK: Optional: None
-    del op.optional_i32
-    print(f"Optional: {op.optional_i32}")
-
-    # CHECK: Unit: False
-    op.unit = None
-    print(f"Unit: {op.unit}")
-    assert "unit" not in op.attributes
-
-    # CHECK: Unit: True
-    op.unit = True
-    print(f"Unit: {op.unit}")
-
-    # CHECK: Unit: False
-    del op.unit
-    print(f"Unit: {op.unit}")
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+
+        #
+        # Check op construction with attributes.
+        #
+
+        i32 = IntegerType.get_signless(32)
+        one = IntegerAttr.get(i32, 1)
+        two = IntegerAttr.get(i32, 2)
+        unit = UnitAttr.get()
+
+        # CHECK: "python_test.attributed_op"() {
+        # CHECK-DAG: mandatory_i32 = 1 : i32
+        # CHECK-DAG: optional_i32 = 2 : i32
+        # CHECK-DAG: unit
+        # CHECK: }
+        op = test.AttributedOp(one, optional_i32=two, unit=unit)
+        print(f"{op}")
+
+        # CHECK: "python_test.attributed_op"() {
+        # CHECK: mandatory_i32 = 2 : i32
+        # CHECK: }
+        op2 = test.AttributedOp(two)
+        print(f"{op2}")
+
+        #
+        # Check generic "attributes" access and mutation.
+        #
+
+        assert "additional" not in op.attributes
+
+        # CHECK: "python_test.attributed_op"() {
+        # CHECK-DAG: additional = 1 : i32
+        # CHECK-DAG: mandatory_i32 = 2 : i32
+        # CHECK: }
+        op2.attributes["additional"] = one
+        print(f"{op2}")
+
+        # CHECK: "python_test.attributed_op"() {
+        # CHECK-DAG: additional = 2 : i32
+        # CHECK-DAG: mandatory_i32 = 2 : i32
+        # CHECK: }
+        op2.attributes["additional"] = two
+        print(f"{op2}")
+
+        # CHECK: "python_test.attributed_op"() {
+        # CHECK-NOT: additional = 2 : i32
+        # CHECK:     mandatory_i32 = 2 : i32
+        # CHECK: }
+        del op2.attributes["additional"]
+        print(f"{op2}")
+
+        try:
+            print(op.attributes["additional"])
+        except KeyError:
+            pass
+        else:
+            assert False, "expected KeyError on unknown attribute key"
+
+        #
+        # Check accessors to defined attributes.
+        #
+
+        # CHECK: Mandatory: 1
+        # CHECK: Optional: 2
+        # CHECK: Unit: True
+        print(f"Mandatory: {op.mandatory_i32.value}")
+        print(f"Optional: {op.optional_i32.value}")
+        print(f"Unit: {op.unit}")
+
+        # CHECK: Mandatory: 2
+        # CHECK: Optional: None
+        # CHECK: Unit: False
+        print(f"Mandatory: {op2.mandatory_i32.value}")
+        print(f"Optional: {op2.optional_i32}")
+        print(f"Unit: {op2.unit}")
+
+        # CHECK: Mandatory: 2
+        # CHECK: Optional: None
+        # CHECK: Unit: False
+        op.mandatory_i32 = two
+        op.optional_i32 = None
+        op.unit = False
+        print(f"Mandatory: {op.mandatory_i32.value}")
+        print(f"Optional: {op.optional_i32}")
+        print(f"Unit: {op.unit}")
+        assert "optional_i32" not in op.attributes
+        assert "unit" not in op.attributes
+
+        try:
+            op.mandatory_i32 = None
+        except ValueError:
+            pass
+        else:
+            assert False, "expected ValueError on setting a mandatory attribute to None"
+
+        # CHECK: Optional: 2
+        op.optional_i32 = two
+        print(f"Optional: {op.optional_i32.value}")
+
+        # CHECK: Optional: None
+        del op.optional_i32
+        print(f"Optional: {op.optional_i32}")
+
+        # CHECK: Unit: False
+        op.unit = None
+        print(f"Unit: {op.unit}")
+        assert "unit" not in op.attributes
+
+        # CHECK: Unit: True
+        op.unit = True
+        print(f"Unit: {op.unit}")
+
+        # CHECK: Unit: False
+        del op.unit
+        print(f"Unit: {op.unit}")
+
 
 # CHECK-LABEL: TEST: attrBuilder
 @run
 def attrBuilder():
-  with Context() as ctx, Location.unknown():
-    ctx.allow_unregistered_dialects = True
-    op = test.AttributesOp(x_bool=True,
-                           x_i16=1,
-                           x_i32=2,
-                           x_i64=3,
-                           x_si16=-1,
-                           x_si32=-2,
-                           x_f32=1.5,
-                           x_f64=2.5,
-                           x_str='x_str',
-                           x_i32_array=[1, 2, 3],
-                           x_i64_array=[4, 5, 6],
-                           x_f32_array=[1.5, -2.5, 3.5],
-                           x_f64_array=[4.5, 5.5, -6.5],
-                           x_i64_dense=[1, 2, 3, 4, 5, 6])
-    print(op)
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+        op = test.AttributesOp(
+            x_bool=True,
+            x_i16=1,
+            x_i32=2,
+            x_i64=3,
+            x_si16=-1,
+            x_si32=-2,
+            x_f32=1.5,
+            x_f64=2.5,
+            x_str="x_str",
+            x_i32_array=[1, 2, 3],
+            x_i64_array=[4, 5, 6],
+            x_f32_array=[1.5, -2.5, 3.5],
+            x_f64_array=[4.5, 5.5, -6.5],
+            x_i64_dense=[1, 2, 3, 4, 5, 6],
+        )
+        print(op)
 
 
 # CHECK-LABEL: TEST: inferReturnTypes
 @run
 def inferReturnTypes():
-  with Context() as ctx, Location.unknown(ctx):
-    test.register_python_test_dialect(ctx)
-    module = Module.create()
-    with InsertionPoint(module.body):
-      op = test.InferResultsOp()
-      dummy = test.DummyOp()
-
-    # CHECK: [Type(i32), Type(i64)]
-    iface = InferTypeOpInterface(op)
-    print(iface.inferReturnTypes())
-
-    # CHECK: [Type(i32), Type(i64)]
-    iface_static = InferTypeOpInterface(test.InferResultsOp)
-    print(iface.inferReturnTypes())
-
-    assert isinstance(iface.opview, test.InferResultsOp)
-    assert iface.opview == iface.operation.opview
-
-    try:
-      iface_static.opview
-    except TypeError:
-      pass
-    else:
-      assert False, ("not expected to be able to obtain an opview from a static"
-                     " interface")
-
-    try:
-      InferTypeOpInterface(dummy)
-    except ValueError:
-      pass
-    else:
-      assert False, "not expected dummy op to implement the interface"
-
-    try:
-      InferTypeOpInterface(test.DummyOp)
-    except ValueError:
-      pass
-    else:
-      assert False, "not expected dummy op class to implement the interface"
+    with Context() as ctx, Location.unknown(ctx):
+        test.register_python_test_dialect(ctx)
+        module = Module.create()
+        with InsertionPoint(module.body):
+            op = test.InferResultsOp()
+            dummy = test.DummyOp()
+
+        # CHECK: [Type(i32), Type(i64)]
+        iface = InferTypeOpInterface(op)
+        print(iface.inferReturnTypes())
+
+        # CHECK: [Type(i32), Type(i64)]
+        iface_static = InferTypeOpInterface(test.InferResultsOp)
+        print(iface.inferReturnTypes())
+
+        assert isinstance(iface.opview, test.InferResultsOp)
+        assert iface.opview == iface.operation.opview
+
+        try:
+            iface_static.opview
+        except TypeError:
+            pass
+        else:
+            assert False, (
+                "not expected to be able to obtain an opview from a static" " interface"
+            )
+
+        try:
+            InferTypeOpInterface(dummy)
+        except ValueError:
+            pass
+        else:
+            assert False, "not expected dummy op to implement the interface"
+
+        try:
+            InferTypeOpInterface(test.DummyOp)
+        except ValueError:
+            pass
+        else:
+            assert False, "not expected dummy op class to implement the interface"
 
 
 # CHECK-LABEL: TEST: resultTypesDefinedByTraits
 @run
 def resultTypesDefinedByTraits():
-  with Context() as ctx, Location.unknown(ctx):
-    test.register_python_test_dialect(ctx)
-    module = Module.create()
-    with InsertionPoint(module.body):
-      inferred = test.InferResultsOp()
-      same = test.SameOperandAndResultTypeOp([inferred.results[0]])
-      # CHECK-COUNT-2: i32
-      print(same.one.type)
-      print(same.two.type)
-
-      first_type_attr = test.FirstAttrDeriveTypeAttrOp(
-          inferred.results[1], TypeAttr.get(IndexType.get()))
-      # CHECK-COUNT-2: index
-      print(first_type_attr.one.type)
-      print(first_type_attr.two.type)
-
-      first_attr = test.FirstAttrDeriveAttrOp(
-          FloatAttr.get(F32Type.get(), 3.14))
-      # CHECK-COUNT-3: f32
-      print(first_attr.one.type)
-      print(first_attr.two.type)
-      print(first_attr.three.type)
-
-      implied = test.InferResultsImpliedOp()
-      # CHECK: i32
-      print(implied.integer.type)
-      # CHECK: f64
-      print(implied.flt.type)
-      # CHECK: index
-      print(implied.index.type)
+    with Context() as ctx, Location.unknown(ctx):
+        test.register_python_test_dialect(ctx)
+        module = Module.create()
+        with InsertionPoint(module.body):
+            inferred = test.InferResultsOp()
+            same = test.SameOperandAndResultTypeOp([inferred.results[0]])
+            # CHECK-COUNT-2: i32
+            print(same.one.type)
+            print(same.two.type)
+
+            first_type_attr = test.FirstAttrDeriveTypeAttrOp(
+                inferred.results[1], TypeAttr.get(IndexType.get())
+            )
+            # CHECK-COUNT-2: index
+            print(first_type_attr.one.type)
+            print(first_type_attr.two.type)
+
+            first_attr = test.FirstAttrDeriveAttrOp(FloatAttr.get(F32Type.get(), 3.14))
+            # CHECK-COUNT-3: f32
+            print(first_attr.one.type)
+            print(first_attr.two.type)
+            print(first_attr.three.type)
+
+            implied = test.InferResultsImpliedOp()
+            # CHECK: i32
+            print(implied.integer.type)
+            # CHECK: f64
+            print(implied.flt.type)
+            # CHECK: index
+            print(implied.index.type)
 
 
 # CHECK-LABEL: TEST: testOptionalOperandOp
 @run
 def testOptionalOperandOp():
-  with Context() as ctx, Location.unknown():
-    test.register_python_test_dialect(ctx)
+    with Context() as ctx, Location.unknown():
+        test.register_python_test_dialect(ctx)
 
-    module = Module.create()
-    with InsertionPoint(module.body):
+        module = Module.create()
+        with InsertionPoint(module.body):
 
-      op1 = test.OptionalOperandOp()
-      # CHECK: op1.input is None: True
-      print(f"op1.input is None: {op1.input is None}")
+            op1 = test.OptionalOperandOp()
+            # CHECK: op1.input is None: True
+            print(f"op1.input is None: {op1.input is None}")
 
-      op2 = test.OptionalOperandOp(input=op1)
-      # CHECK: op2.input is None: False
-      print(f"op2.input is None: {op2.input is None}")
+            op2 = test.OptionalOperandOp(input=op1)
+            # CHECK: op2.input is None: False
+            print(f"op2.input is None: {op2.input is None}")
 
 
 # CHECK-LABEL: TEST: testCustomAttribute
 @run
 def testCustomAttribute():
-  with Context() as ctx:
-    test.register_python_test_dialect(ctx)
-    a = test.TestAttr.get()
-    # CHECK: #python_test.test_attr
-    print(a)
-
-    # The following cast must not assert.
-    b = test.TestAttr(a)
-
-    unit = UnitAttr.get()
-    try:
-      test.TestAttr(unit)
-    except ValueError as e:
-      assert "Cannot cast attribute to TestAttr" in str(e)
-    else:
-      raise
-
-    # The following must trigger a TypeError from our adaptors and must not
-    # crash.
-    try:
-      test.TestAttr(42)
-    except TypeError as e:
-      assert "Expected an MLIR object" in str(e)
-    else:
-      raise
-
-    # The following must trigger a TypeError from pybind (therefore, not
-    # checking its message) and must not crash.
-    try:
-      test.TestAttr(42, 56)
-    except TypeError:
-      pass
-    else:
-      raise
+    with Context() as ctx:
+        test.register_python_test_dialect(ctx)
+        a = test.TestAttr.get()
+        # CHECK: #python_test.test_attr
+        print(a)
+
+        # The following cast must not assert.
+        b = test.TestAttr(a)
+
+        unit = UnitAttr.get()
+        try:
+            test.TestAttr(unit)
+        except ValueError as e:
+            assert "Cannot cast attribute to TestAttr" in str(e)
+        else:
+            raise
+
+        # The following must trigger a TypeError from our adaptors and must not
+        # crash.
+        try:
+            test.TestAttr(42)
+        except TypeError as e:
+            assert "Expected an MLIR object" in str(e)
+        else:
+            raise
+
+        # The following must trigger a TypeError from pybind (therefore, not
+        # checking its message) and must not crash.
+        try:
+            test.TestAttr(42, 56)
+        except TypeError:
+            pass
+        else:
+            raise
 
 
 @run
 def testCustomType():
-  with Context() as ctx:
-    test.register_python_test_dialect(ctx)
-    a = test.TestType.get()
-    # CHECK: !python_test.test_type
-    print(a)
-
-    # The following cast must not assert.
-    b = test.TestType(a)
-    # Instance custom types should have typeids
-    assert isinstance(b.typeid, TypeID)
-    # Subclasses of ir.Type should not have a static_typeid
-    # CHECK: 'TestType' object has no attribute 'static_typeid'
-    try:
-      b.static_typeid
-    except AttributeError as e:
-      print(e)
-
-    i8 = IntegerType.get_signless(8)
-    try:
-      test.TestType(i8)
-    except ValueError as e:
-      assert "Cannot cast type to TestType" in str(e)
-    else:
-      raise
-
-    # The following must trigger a TypeError from our adaptors and must not
-    # crash.
-    try:
-      test.TestType(42)
-    except TypeError as e:
-      assert "Expected an MLIR object" in str(e)
-    else:
-      raise
-
-    # The following must trigger a TypeError from pybind (therefore, not
-    # checking its message) and must not crash.
-    try:
-      test.TestType(42, 56)
-    except TypeError:
-      pass
-    else:
-      raise
+    with Context() as ctx:
+        test.register_python_test_dialect(ctx)
+        a = test.TestType.get()
+        # CHECK: !python_test.test_type
+        print(a)
+
+        # The following cast must not assert.
+        b = test.TestType(a)
+        # Instance custom types should have typeids
+        assert isinstance(b.typeid, TypeID)
+        # Subclasses of ir.Type should not have a static_typeid
+        # CHECK: 'TestType' object has no attribute 'static_typeid'
+        try:
+            b.static_typeid
+        except AttributeError as e:
+            print(e)
+
+        i8 = IntegerType.get_signless(8)
+        try:
+            test.TestType(i8)
+        except ValueError as e:
+            assert "Cannot cast type to TestType" in str(e)
+        else:
+            raise
+
+        # The following must trigger a TypeError from our adaptors and must not
+        # crash.
+        try:
+            test.TestType(42)
+        except TypeError as e:
+            assert "Expected an MLIR object" in str(e)
+        else:
+            raise
+
+        # The following must trigger a TypeError from pybind (therefore, not
+        # checking its message) and must not crash.
+        try:
+            test.TestType(42, 56)
+        except TypeError:
+            pass
+        else:
+            raise
 
 
 @run
 # CHECK-LABEL: TEST: testTensorValue
 def testTensorValue():
-  with Context() as ctx, Location.unknown():
-    test.register_python_test_dialect(ctx)
+    with Context() as ctx, Location.unknown():
+        test.register_python_test_dialect(ctx)
 
-    i8 = IntegerType.get_signless(8)
+        i8 = IntegerType.get_signless(8)
 
-    class Tensor(test.TestTensorValue):
-      def __str__(self):
-        return super().__str__().replace("Value", "Tensor")
+        class Tensor(test.TestTensorValue):
+            def __str__(self):
+                return super().__str__().replace("Value", "Tensor")
 
-    module = Module.create()
-    with InsertionPoint(module.body):
-      t = tensor.EmptyOp([10, 10], i8).result
+        module = Module.create()
+        with InsertionPoint(module.body):
+            t = tensor.EmptyOp([10, 10], i8).result
 
-      # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
-      print(Value(t))
+            # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
+            print(Value(t))
 
-      tt = Tensor(t)
-      # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
-      print(tt)
+            tt = Tensor(t)
+            # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
+            print(tt)
 
-      # CHECK: False
-      print(tt.is_null())
+            # CHECK: False
+            print(tt.is_null())
 
-      # Classes of custom types that inherit from concrete types should have
-      # static_typeid
-      assert isinstance(test.TestTensorType.static_typeid, TypeID)
-      # And it should be equal to the in-tree concrete type
-      assert test.TestTensorType.static_typeid == t.type.typeid
+            # Classes of custom types that inherit from concrete types should have
+            # static_typeid
+            assert isinstance(test.TestTensorType.static_typeid, TypeID)
+            # And it should be equal to the in-tree concrete type
+            assert test.TestTensorType.static_typeid == t.type.typeid
 
 
 # CHECK-LABEL: TEST: inferReturnTypeComponents
@@ -412,7 +418,7 @@ def inferReturnTypeComponents():
         # CHECK: shape: None
         iface = InferShapedTypeOpInterface(unranked_op)
         shaped_type_components = iface.inferReturnTypeComponents(
-          operands=[unranked_op.operand]
+            operands=[unranked_op.operand]
         )[0]
         print("has rank:", shaped_type_components.has_rank)
         print("rank:", shaped_type_components.rank)
index 32614be..0ee3327 100644
@@ -5,127 +5,133 @@ from mlir.dialects import quant
 
 
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    return f
 
 
 # CHECK-LABEL: TEST: test_type_hierarchy
 @run
 def test_type_hierarchy():
-  with Context():
-    i8 = IntegerType.get_signless(8)
-    any = Type.parse("!quant.any<i8<-8:7>:f32>")
-    uniform = Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
-    per_axis = Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
-    calibrated = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")
+    with Context():
+        i8 = IntegerType.get_signless(8)
+        any = Type.parse("!quant.any<i8<-8:7>:f32>")
+        uniform = Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
+        per_axis = Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
+        calibrated = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")
 
-    assert not quant.QuantizedType.isinstance(i8)
-    assert quant.QuantizedType.isinstance(any)
-    assert quant.QuantizedType.isinstance(uniform)
-    assert quant.QuantizedType.isinstance(per_axis)
-    assert quant.QuantizedType.isinstance(calibrated)
+        assert not quant.QuantizedType.isinstance(i8)
+        assert quant.QuantizedType.isinstance(any)
+        assert quant.QuantizedType.isinstance(uniform)
+        assert quant.QuantizedType.isinstance(per_axis)
+        assert quant.QuantizedType.isinstance(calibrated)
 
-    assert quant.AnyQuantizedType.isinstance(any)
-    assert quant.UniformQuantizedType.isinstance(uniform)
-    assert quant.UniformQuantizedPerAxisType.isinstance(per_axis)
-    assert quant.CalibratedQuantizedType.isinstance(calibrated)
+        assert quant.AnyQuantizedType.isinstance(any)
+        assert quant.UniformQuantizedType.isinstance(uniform)
+        assert quant.UniformQuantizedPerAxisType.isinstance(per_axis)
+        assert quant.CalibratedQuantizedType.isinstance(calibrated)
 
-    assert not quant.AnyQuantizedType.isinstance(uniform)
-    assert not quant.UniformQuantizedType.isinstance(per_axis)
+        assert not quant.AnyQuantizedType.isinstance(uniform)
+        assert not quant.UniformQuantizedType.isinstance(per_axis)
 
 
 # CHECK-LABEL: TEST: test_any_quantized_type
 @run
 def test_any_quantized_type():
-  with Context():
-    i8 = IntegerType.get_signless(8)
-    f32 = F32Type.get()
-    any = quant.AnyQuantizedType.get(quant.QuantizedType.FLAG_SIGNED, i8, f32,
-                                     -8, 7)
-
-    # CHECK: flags: 1
-    print(f"flags: {any.flags}")
-    # CHECK: signed: True
-    print(f"signed: {any.is_signed}")
-    # CHECK: storage type: i8
-    print(f"storage type: {any.storage_type}")
-    # CHECK: expressed type: f32
-    print(f"expressed type: {any.expressed_type}")
-    # CHECK: storage min: -8
-    print(f"storage min: {any.storage_type_min}")
-    # CHECK: storage max: 7
-    print(f"storage max: {any.storage_type_max}")
-    # CHECK: storage width: 8
-    print(f"storage width: {any.storage_type_integral_width}")
-    # CHECK: quantized element type: !quant.any<i8<-8:7>:f32>
-    print(f"quantized element type: {any.quantized_element_type}")
-    # CHECK: !quant.any<i8<-8:7>:f32>
-    print(any)
-    assert any == Type.parse("!quant.any<i8<-8:7>:f32>")
+    with Context():
+        i8 = IntegerType.get_signless(8)
+        f32 = F32Type.get()
+        any = quant.AnyQuantizedType.get(
+            quant.QuantizedType.FLAG_SIGNED, i8, f32, -8, 7
+        )
+
+        # CHECK: flags: 1
+        print(f"flags: {any.flags}")
+        # CHECK: signed: True
+        print(f"signed: {any.is_signed}")
+        # CHECK: storage type: i8
+        print(f"storage type: {any.storage_type}")
+        # CHECK: expressed type: f32
+        print(f"expressed type: {any.expressed_type}")
+        # CHECK: storage min: -8
+        print(f"storage min: {any.storage_type_min}")
+        # CHECK: storage max: 7
+        print(f"storage max: {any.storage_type_max}")
+        # CHECK: storage width: 8
+        print(f"storage width: {any.storage_type_integral_width}")
+        # CHECK: quantized element type: !quant.any<i8<-8:7>:f32>
+        print(f"quantized element type: {any.quantized_element_type}")
+        # CHECK: !quant.any<i8<-8:7>:f32>
+        print(any)
+        assert any == Type.parse("!quant.any<i8<-8:7>:f32>")
 
 
 # CHECK-LABEL: TEST: test_uniform_type
 @run
 def test_uniform_type():
-  with Context():
-    i8 = IntegerType.get_signless(8)
-    f32 = F32Type.get()
-    uniform = quant.UniformQuantizedType.get(
-        quant.UniformQuantizedType.FLAG_SIGNED, i8, f32, 0.99872, 127, -8, 7)
-
-    # CHECK: scale: 0.99872
-    print(f"scale: {uniform.scale}")
-    # CHECK: zero point: 127
-    print(f"zero point: {uniform.zero_point}")
-    # CHECK: fixed point: False
-    print(f"fixed point: {uniform.is_fixed_point}")
-    # CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
-    print(uniform)
-    assert uniform == Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
+    with Context():
+        i8 = IntegerType.get_signless(8)
+        f32 = F32Type.get()
+        uniform = quant.UniformQuantizedType.get(
+            quant.UniformQuantizedType.FLAG_SIGNED, i8, f32, 0.99872, 127, -8, 7
+        )
+
+        # CHECK: scale: 0.99872
+        print(f"scale: {uniform.scale}")
+        # CHECK: zero point: 127
+        print(f"zero point: {uniform.zero_point}")
+        # CHECK: fixed point: False
+        print(f"fixed point: {uniform.is_fixed_point}")
+        # CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
+        print(uniform)
+        assert uniform == Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
 
 
 # CHECK-LABEL: TEST: test_uniform_per_axis_type
 @run
 def test_uniform_per_axis_type():
-  with Context():
-    i8 = IntegerType.get_signless(8)
-    f32 = F32Type.get()
-    per_axis = quant.UniformQuantizedPerAxisType.get(
-        quant.QuantizedType.FLAG_SIGNED,
-        i8,
-        f32, [200, 0.99872], [0, 120],
-        quantized_dimension=1,
-        storage_type_min=quant.QuantizedType.default_minimum_for_integer(
-            is_signed=True, integral_width=8),
-        storage_type_max=quant.QuantizedType.default_maximum_for_integer(
-            is_signed=True, integral_width=8))
-
-    # CHECK: scales: None
-    print(f"scales: {per_axis.scales}")
-    # CHECK: zero_points: None
-    print(f"zero_points: {per_axis.zero_points}")
-    # CHECK: quantized dim: 1
-    print(f"quantized dim: {per_axis.quantized_dimension}")
-    # CHECK: fixed point: False
-    print(f"fixed point: {per_axis.is_fixed_point}")
-    # CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
-    print(per_axis)
-    assert per_axis == Type.parse(
-        "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
+    with Context():
+        i8 = IntegerType.get_signless(8)
+        f32 = F32Type.get()
+        per_axis = quant.UniformQuantizedPerAxisType.get(
+            quant.QuantizedType.FLAG_SIGNED,
+            i8,
+            f32,
+            [200, 0.99872],
+            [0, 120],
+            quantized_dimension=1,
+            storage_type_min=quant.QuantizedType.default_minimum_for_integer(
+                is_signed=True, integral_width=8
+            ),
+            storage_type_max=quant.QuantizedType.default_maximum_for_integer(
+                is_signed=True, integral_width=8
+            ),
+        )
+
+        # CHECK: scales: None
+        print(f"scales: {per_axis.scales}")
+        # CHECK: zero_points: None
+        print(f"zero_points: {per_axis.zero_points}")
+        # CHECK: quantized dim: 1
+        print(f"quantized dim: {per_axis.quantized_dimension}")
+        # CHECK: fixed point: False
+        print(f"fixed point: {per_axis.is_fixed_point}")
+        # CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
+        print(per_axis)
+        assert per_axis == Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
 
 
 # CHECK-LABEL: TEST: test_calibrated_type
 @run
 def test_calibrated_type():
-  with Context():
-    f32 = F32Type.get()
-    calibrated = quant.CalibratedQuantizedType.get(f32, -0.998, 1.2321)
-
-    # CHECK: min: -0.998
-    print(f"min: {calibrated.min}")
-    # CHECK: max: 1.2321
-    print(f"max: {calibrated.max}")
-    # CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>
-    print(calibrated)
-    assert calibrated == Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")
+    with Context():
+        f32 = F32Type.get()
+        calibrated = quant.CalibratedQuantizedType.get(f32, -0.998, 1.2321)
+
+        # CHECK: min: -0.998
+        print(f"min: {calibrated.min}")
+        # CHECK: max: 1.2321
+        print(f"max: {calibrated.max}")
+        # CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>
+        print(calibrated)
+        assert calibrated == Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")
index 4a618ff..8cb55fd 100644
@@ -8,26 +8,26 @@ from mlir.dialects import builtin
 
 
 def constructAndPrintInModule(f):
-  print("\nTEST:", f.__name__)
-  with Context(), Location.unknown():
-    module = Module.create()
-    with InsertionPoint(module.body):
-      f()
-    print(module)
-  return f
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            f()
+        print(module)
+    return f
 
 
 # CHECK-LABEL: TEST: testSimpleLoop
 @constructAndPrintInModule
 def testSimpleLoop():
-  index_type = IndexType.get()
+    index_type = IndexType.get()
 
-  @func.FuncOp.from_py_func(index_type, index_type, index_type)
-  def simple_loop(lb, ub, step):
-    loop = scf.ForOp(lb, ub, step, [lb, lb])
-    with InsertionPoint(loop.body):
-      scf.YieldOp(loop.inner_iter_args)
-    return
+    @func.FuncOp.from_py_func(index_type, index_type, index_type)
+    def simple_loop(lb, ub, step):
+        loop = scf.ForOp(lb, ub, step, [lb, lb])
+        with InsertionPoint(loop.body):
+            scf.YieldOp(loop.inner_iter_args)
+        return
 
 
 # CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
@@ -39,14 +39,14 @@ def testSimpleLoop():
 # CHECK-LABEL: TEST: testInductionVar
 @constructAndPrintInModule
 def testInductionVar():
-  index_type = IndexType.get()
+    index_type = IndexType.get()
 
-  @func.FuncOp.from_py_func(index_type, index_type, index_type)
-  def induction_var(lb, ub, step):
-    loop = scf.ForOp(lb, ub, step, [lb])
-    with InsertionPoint(loop.body):
-      scf.YieldOp([loop.induction_variable])
-    return
+    @func.FuncOp.from_py_func(index_type, index_type, index_type)
+    def induction_var(lb, ub, step):
+        loop = scf.ForOp(lb, ub, step, [lb])
+        with InsertionPoint(loop.body):
+            scf.YieldOp([loop.induction_variable])
+        return
 
 
 # CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
@@ -56,19 +56,18 @@ def testInductionVar():
 
 @constructAndPrintInModule
 def testOpsAsArguments():
-  index_type = IndexType.get()
-  callee = func.FuncOp(
-      "callee", ([], [index_type, index_type]), visibility="private")
-  f = func.FuncOp("ops_as_arguments", ([], []))
-  with InsertionPoint(f.add_entry_block()):
-    lb = arith.ConstantOp.create_index(0)
-    ub = arith.ConstantOp.create_index(42)
-    step = arith.ConstantOp.create_index(2)
-    iter_args = func.CallOp(callee, [])
-    loop = scf.ForOp(lb, ub, step, iter_args)
-    with InsertionPoint(loop.body):
-      scf.YieldOp(loop.inner_iter_args)
-    func.ReturnOp([])
+    index_type = IndexType.get()
+    callee = func.FuncOp("callee", ([], [index_type, index_type]), visibility="private")
+    f = func.FuncOp("ops_as_arguments", ([], []))
+    with InsertionPoint(f.add_entry_block()):
+        lb = arith.ConstantOp.create_index(0)
+        ub = arith.ConstantOp.create_index(42)
+        step = arith.ConstantOp.create_index(2)
+        iter_args = func.CallOp(callee, [])
+        loop = scf.ForOp(lb, ub, step, iter_args)
+        with InsertionPoint(loop.body):
+            scf.YieldOp(loop.inner_iter_args)
+        func.ReturnOp([])
 
 
 # CHECK-LABEL: TEST: testOpsAsArguments
@@ -86,17 +85,17 @@ def testOpsAsArguments():
 
 @constructAndPrintInModule
 def testIfWithoutElse():
-  bool = IntegerType.get_signless(1)
-  i32 = IntegerType.get_signless(32)
+    bool = IntegerType.get_signless(1)
+    i32 = IntegerType.get_signless(32)
 
-  @func.FuncOp.from_py_func(bool)
-  def simple_if(cond):
-    if_op = scf.IfOp(cond)
-    with InsertionPoint(if_op.then_block):
-      one = arith.ConstantOp(i32, 1)
-      add = arith.AddIOp(one, one)
-      scf.YieldOp([])
-    return
+    @func.FuncOp.from_py_func(bool)
+    def simple_if(cond):
+        if_op = scf.IfOp(cond)
+        with InsertionPoint(if_op.then_block):
+            one = arith.ConstantOp(i32, 1)
+            add = arith.AddIOp(one, one)
+            scf.YieldOp([])
+        return
 
 
 # CHECK: func @simple_if(%[[ARG0:.*]]: i1)
@@ -108,22 +107,22 @@ def testIfWithoutElse():
 
 @constructAndPrintInModule
 def testIfWithElse():
-  bool = IntegerType.get_signless(1)
-  i32 = IntegerType.get_signless(32)
-
-  @func.FuncOp.from_py_func(bool)
-  def simple_if_else(cond):
-    if_op = scf.IfOp(cond, [i32, i32], hasElse=True)
-    with InsertionPoint(if_op.then_block):
-      x_true = arith.ConstantOp(i32, 0)
-      y_true = arith.ConstantOp(i32, 1)
-      scf.YieldOp([x_true, y_true])
-    with InsertionPoint(if_op.else_block):
-      x_false = arith.ConstantOp(i32, 2)
-      y_false = arith.ConstantOp(i32, 3)
-      scf.YieldOp([x_false, y_false])
-    add = arith.AddIOp(if_op.results[0], if_op.results[1])
-    return
+    bool = IntegerType.get_signless(1)
+    i32 = IntegerType.get_signless(32)
+
+    @func.FuncOp.from_py_func(bool)
+    def simple_if_else(cond):
+        if_op = scf.IfOp(cond, [i32, i32], hasElse=True)
+        with InsertionPoint(if_op.then_block):
+            x_true = arith.ConstantOp(i32, 0)
+            y_true = arith.ConstantOp(i32, 1)
+            scf.YieldOp([x_true, y_true])
+        with InsertionPoint(if_op.else_block):
+            x_false = arith.ConstantOp(i32, 2)
+            y_false = arith.ConstantOp(i32, 3)
+            scf.YieldOp([x_false, y_false])
+        add = arith.AddIOp(if_op.results[0], if_op.results[1])
+        return
 
 
 # CHECK: func @simple_if_else(%[[ARG0:.*]]: i1)
index 3e7a8b2..ad75585 100644
@@ -7,36 +7,38 @@ import mlir.dialects.shape as shape
 
 
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    return f
 
 
 # CHECK-LABEL: TEST: testConstShape
 @run
 def testConstShape():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    with InsertionPoint(module.body):
-      @func.FuncOp.from_py_func(
-          RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32))
-      def const_shape_tensor(arg):
-        shape.ConstWitnessOp(False)
-        shape.ConstSizeOp(30)
-        shape.ConstSizeOp(IntegerAttr.get(IndexType.get(), 40))
-        x = shape.ConstShapeOp([1, 2])
-        shape.MeetOp(x, x, error="impossible")
-        return shape.ConstShapeOp(
-            DenseElementsAttr.get(
-                np.array([3, 4], dtype=np.int64), type=IndexType.get()))
-
-
-
-    # CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>)
-    # CHECK-DAG: shape.const_witness false
-    # CHECK-DAG: shape.const_size 30
-    # CHECK-DAG: shape.const_size 40
-    # CHECK-DAG: shape.const_shape [1, 2] : tensor<2xindex>
-    # CHECK-DAG: shape.const_shape [3, 4] : tensor<2xindex>
-    print(module)
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32)
+            )
+            def const_shape_tensor(arg):
+                shape.ConstWitnessOp(False)
+                shape.ConstSizeOp(30)
+                shape.ConstSizeOp(IntegerAttr.get(IndexType.get(), 40))
+                x = shape.ConstShapeOp([1, 2])
+                shape.MeetOp(x, x, error="impossible")
+                return shape.ConstShapeOp(
+                    DenseElementsAttr.get(
+                        np.array([3, 4], dtype=np.int64), type=IndexType.get()
+                    )
+                )
+
+        # CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>)
+        # CHECK-DAG: shape.const_witness false
+        # CHECK-DAG: shape.const_size 30
+        # CHECK-DAG: shape.const_size 40
+        # CHECK-DAG: shape.const_shape [1, 2] : tensor<2xindex>
+        # CHECK-DAG: shape.const_shape [3, 4] : tensor<2xindex>
+        print(module)
index 6190beb..b7a0606 100644
 from mlir.ir import *
 from mlir.dialects import sparse_tensor as st
 
+
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    return f
 
 
 # CHECK-LABEL: TEST: testEncodingAttr1D
 @run
 def testEncodingAttr1D():
-  with Context() as ctx:
-    parsed = Attribute.parse('#sparse_tensor.encoding<{'
-                             '  lvlTypes = [ "compressed" ],'
-                             '  posWidth = 16,'
-                             '  crdWidth = 32'
-                             '}>')
-    # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 16, crdWidth = 32 }>
-    print(parsed)
-
-    casted = st.EncodingAttr(parsed)
-    # CHECK: equal: True
-    print(f"equal: {casted == parsed}")
-
-    # CHECK: lvl_types: [<DimLevelType.compressed: 8>]
-    print(f"lvl_types: {casted.lvl_types}")
-    # CHECK: dim_ordering: None
-    print(f"dim_ordering: {casted.dim_ordering}")
-    # CHECK: pos_width: 16
-    print(f"pos_width: {casted.pos_width}")
-    # CHECK: crd_width: 32
-    print(f"crd_width: {casted.crd_width}")
-
-    created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
-    # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>
-    print(created)
-    # CHECK: created_equal: False
-    print(f"created_equal: {created == casted}")
-
-    # Verify that the factory creates an instance of the proper type.
-    # CHECK: is_proper_instance: True
-    print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
-    # CHECK: created_pos_width: 0
-    print(f"created_pos_width: {created.pos_width}")
+    with Context() as ctx:
+        parsed = Attribute.parse(
+            "#sparse_tensor.encoding<{"
+            '  lvlTypes = [ "compressed" ],'
+            "  posWidth = 16,"
+            "  crdWidth = 32"
+            "}>"
+        )
+        # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 16, crdWidth = 32 }>
+        print(parsed)
+
+        casted = st.EncodingAttr(parsed)
+        # CHECK: equal: True
+        print(f"equal: {casted == parsed}")
+
+        # CHECK: lvl_types: [<DimLevelType.compressed: 8>]
+        print(f"lvl_types: {casted.lvl_types}")
+        # CHECK: dim_ordering: None
+        print(f"dim_ordering: {casted.dim_ordering}")
+        # CHECK: pos_width: 16
+        print(f"pos_width: {casted.pos_width}")
+        # CHECK: crd_width: 32
+        print(f"crd_width: {casted.crd_width}")
+
+        created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
+        # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>
+        print(created)
+        # CHECK: created_equal: False
+        print(f"created_equal: {created == casted}")
+
+        # Verify that the factory creates an instance of the proper type.
+        # CHECK: is_proper_instance: True
+        print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
+        # CHECK: created_pos_width: 0
+        print(f"created_pos_width: {created.pos_width}")
 
 
 # CHECK-LABEL: TEST: testEncodingAttr2D
 @run
 def testEncodingAttr2D():
-  with Context() as ctx:
-    parsed = Attribute.parse('#sparse_tensor.encoding<{'
-                             '  lvlTypes = [ "dense", "compressed" ],'
-                             '  dimOrdering = affine_map<(d0, d1) -> (d1, d0)>,'
-                             '  posWidth = 8,'
-                             '  crdWidth = 32'
-                             '}>')
-    # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, posWidth = 8, crdWidth = 32 }>
-    print(parsed)
-
-    casted = st.EncodingAttr(parsed)
-    # CHECK: equal: True
-    print(f"equal: {casted == parsed}")
-
-    # CHECK: lvl_types: [<DimLevelType.dense: 4>, <DimLevelType.compressed: 8>]
-    print(f"lvl_types: {casted.lvl_types}")
-    # CHECK: dim_ordering: (d0, d1) -> (d1, d0)
-    print(f"dim_ordering: {casted.dim_ordering}")
-    # CHECK: pos_width: 8
-    print(f"pos_width: {casted.pos_width}")
-    # CHECK: crd_width: 32
-    print(f"crd_width: {casted.crd_width}")
-
-    created = st.EncodingAttr.get(casted.lvl_types, casted.dim_ordering,
-                                  casted.higher_ordering, 8, 32)
-    # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, posWidth = 8, crdWidth = 32 }>
-    print(created)
-    # CHECK: created_equal: True
-    print(f"created_equal: {created == casted}")
+    with Context() as ctx:
+        parsed = Attribute.parse(
+            "#sparse_tensor.encoding<{"
+            '  lvlTypes = [ "dense", "compressed" ],'
+            "  dimOrdering = affine_map<(d0, d1) -> (d1, d0)>,"
+            "  posWidth = 8,"
+            "  crdWidth = 32"
+            "}>"
+        )
+        # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, posWidth = 8, crdWidth = 32 }>
+        print(parsed)
+
+        casted = st.EncodingAttr(parsed)
+        # CHECK: equal: True
+        print(f"equal: {casted == parsed}")
+
+        # CHECK: lvl_types: [<DimLevelType.dense: 4>, <DimLevelType.compressed: 8>]
+        print(f"lvl_types: {casted.lvl_types}")
+        # CHECK: dim_ordering: (d0, d1) -> (d1, d0)
+        print(f"dim_ordering: {casted.dim_ordering}")
+        # CHECK: pos_width: 8
+        print(f"pos_width: {casted.pos_width}")
+        # CHECK: crd_width: 32
+        print(f"crd_width: {casted.crd_width}")
+
+        created = st.EncodingAttr.get(
+            casted.lvl_types, casted.dim_ordering, casted.higher_ordering, 8, 32
+        )
+        # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, posWidth = 8, crdWidth = 32 }>
+        print(created)
+        # CHECK: created_equal: True
+        print(f"created_equal: {created == casted}")
 
 
 # CHECK-LABEL: TEST: testEncodingAttrOnTensorType
 @run
 def testEncodingAttrOnTensorType():
-  with Context() as ctx, Location.unknown():
-    encoding = st.EncodingAttr(
-        Attribute.parse('#sparse_tensor.encoding<{'
-                        '  lvlTypes = [ "compressed" ], '
-                        '  posWidth = 64,'
-                        '  crdWidth = 32'
-                        '}>'))
-    tt = RankedTensorType.get((1024,), F32Type.get(), encoding=encoding)
-    # CHECK: tensor<1024xf32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 64, crdWidth = 32 }>>
-    print(tt)
-    # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 64, crdWidth = 32 }>
-    print(tt.encoding)
-    assert tt.encoding == encoding
+    with Context() as ctx, Location.unknown():
+        encoding = st.EncodingAttr(
+            Attribute.parse(
+                "#sparse_tensor.encoding<{"
+                '  lvlTypes = [ "compressed" ], '
+                "  posWidth = 64,"
+                "  crdWidth = 32"
+                "}>"
+            )
+        )
+        tt = RankedTensorType.get((1024,), F32Type.get(), encoding=encoding)
+        # CHECK: tensor<1024xf32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 64, crdWidth = 32 }>>
+        print(tt)
+        # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 64, crdWidth = 32 }>
+        print(tt.encoding)
+        assert tt.encoding == encoding
index 9319e16..c37c520 100644
@@ -7,16 +7,16 @@ from mlir.dialects import sparse_tensor as st
 
 
 def run(f):
-  print('\nTEST:', f.__name__)
-  f()
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    return f
 
 
 # CHECK-LABEL: TEST: testSparseTensorPass
 @run
 def testSparseTensorPass():
-  with Context() as context:
-    PassManager.parse('any(sparsification)')
-    PassManager.parse('any(sparse-tensor-conversion)')
-  # CHECK: SUCCESS
-  print('SUCCESS')
+    with Context() as context:
+        PassManager.parse("any(sparsification)")
+        PassManager.parse("any(sparse-tensor-conversion)")
+    # CHECK: SUCCESS
+    print("SUCCESS")
index b0ad4b4..b690c93 100644
@@ -7,125 +7,135 @@ import mlir.dialects.tensor as tensor
 
 
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    return f
 
 
 # CHECK-LABEL: TEST: testDimOp
 @run
 def testDimOp():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32Type = F32Type.get()
-    indexType = IndexType.get()
-    with InsertionPoint(module.body):
-
-      @func.FuncOp.from_py_func(
-          RankedTensorType.get(
-              (ShapedType.get_dynamic_size(), ShapedType.get_dynamic_size()),
-              f32Type))
-      #      CHECK: func @tensor_static_dim
-      # CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
-      #  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-      #  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-      #      CHECK:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-      #      CHECK:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-      #      CHECK:   return %[[D0]], %[[D1]]
-      def tensor_static_dim(t):
-        c0 = arith.ConstantOp(indexType, 0)
-        c1 = arith.ConstantOp(indexType, 1)
-        d0 = tensor.DimOp(t, c0)
-        d1 = tensor.DimOp(t, c1)
-        return [d0.result, d1.result]
-
-    print(module)
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32Type = F32Type.get()
+        indexType = IndexType.get()
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                RankedTensorType.get(
+                    (ShapedType.get_dynamic_size(), ShapedType.get_dynamic_size()),
+                    f32Type,
+                )
+            )
+            #      CHECK: func @tensor_static_dim
+            # CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
+            #  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+            #  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+            #      CHECK:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+            #      CHECK:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+            #      CHECK:   return %[[D0]], %[[D1]]
+            def tensor_static_dim(t):
+                c0 = arith.ConstantOp(indexType, 0)
+                c1 = arith.ConstantOp(indexType, 1)
+                d0 = tensor.DimOp(t, c0)
+                d1 = tensor.DimOp(t, c1)
+                return [d0.result, d1.result]
+
+        print(module)
 
 
 # CHECK-LABEL: TEST: testEmptyOp
 @run
 def testEmptyOp():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    with InsertionPoint(module.body):
-      # CHECK-LABEL: func @static_sizes
-      # CHECK: %0 = tensor.empty() : tensor<3x4xf32>
-      @func.FuncOp.from_py_func()
-      def static_sizes():
-        return tensor.EmptyOp([3, 4], f32)
-
-      # CHECK-LABEL: func @dynamic_sizes
-      # CHECK: %0 = tensor.empty(%arg0, %arg1) : tensor<?x?xf32>
-      @func.FuncOp.from_py_func(IndexType.get(), IndexType.get())
-      def dynamic_sizes(d0, d1):
-        return tensor.EmptyOp([d0, d1], f32)
-
-      # CHECK-LABEL: func @mixed_static_dynamic_sizes
-      # CHECK: %0 = tensor.empty(%arg0) : tensor<?x4xf32>
-      @func.FuncOp.from_py_func(IndexType.get())
-      def mixed_static_dynamic_sizes(d0):
-        return tensor.EmptyOp([d0, 4], f32)
-
-      # CHECK-LABEL: func @zero_d
-      # CHECK: %0 = tensor.empty() : tensor<f32>
-      @func.FuncOp.from_py_func()
-      def zero_d():
-        return tensor.EmptyOp([], f32)
-
-  print(module)
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            # CHECK-LABEL: func @static_sizes
+            # CHECK: %0 = tensor.empty() : tensor<3x4xf32>
+            @func.FuncOp.from_py_func()
+            def static_sizes():
+                return tensor.EmptyOp([3, 4], f32)
+
+            # CHECK-LABEL: func @dynamic_sizes
+            # CHECK: %0 = tensor.empty(%arg0, %arg1) : tensor<?x?xf32>
+            @func.FuncOp.from_py_func(IndexType.get(), IndexType.get())
+            def dynamic_sizes(d0, d1):
+                return tensor.EmptyOp([d0, d1], f32)
+
+            # CHECK-LABEL: func @mixed_static_dynamic_sizes
+            # CHECK: %0 = tensor.empty(%arg0) : tensor<?x4xf32>
+            @func.FuncOp.from_py_func(IndexType.get())
+            def mixed_static_dynamic_sizes(d0):
+                return tensor.EmptyOp([d0, 4], f32)
+
+            # CHECK-LABEL: func @zero_d
+            # CHECK: %0 = tensor.empty() : tensor<f32>
+            @func.FuncOp.from_py_func()
+            def zero_d():
+                return tensor.EmptyOp([], f32)
+
+    print(module)
 
 
 # CHECK-LABEL: TEST: testInferTypesInsertSlice
 @run
 def testInferTypesInsertSlice():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32Type = F32Type.get()
-    with InsertionPoint(module.body):
-
-      @func.FuncOp.from_py_func(
-          RankedTensorType.get((1, 1), f32Type),
-          RankedTensorType.get((1, 1), f32Type))
-      # CHECK: func @f
-      # CHECK:      tensor.insert_slice %arg0 into %arg1[0, 0] [1, 1] [0, 0] :
-      # CHECK-SAME:   tensor<1x1xf32> into tensor<1x1xf32>
-      def f(source, dest):
-        d0 = tensor.InsertSliceOp(source, dest, [], [], [],
-                                  DenseI64ArrayAttr.get([0, 0]),
-                                  DenseI64ArrayAttr.get([1, 1]),
-                                  DenseI64ArrayAttr.get([0, 0]))
-        return [d0.result]
-
-  print(module)
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32Type = F32Type.get()
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                RankedTensorType.get((1, 1), f32Type),
+                RankedTensorType.get((1, 1), f32Type),
+            )
+            # CHECK: func @f
+            # CHECK:      tensor.insert_slice %arg0 into %arg1[0, 0] [1, 1] [0, 0] :
+            # CHECK-SAME:   tensor<1x1xf32> into tensor<1x1xf32>
+            def f(source, dest):
+                d0 = tensor.InsertSliceOp(
+                    source,
+                    dest,
+                    [],
+                    [],
+                    [],
+                    DenseI64ArrayAttr.get([0, 0]),
+                    DenseI64ArrayAttr.get([1, 1]),
+                    DenseI64ArrayAttr.get([0, 0]),
+                )
+                return [d0.result]
+
+    print(module)
 
 
 # CHECK-LABEL: TEST: testFromElementsOp
 @run
 def testFromElementsOp():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    with InsertionPoint(module.body):
-      @func.FuncOp.from_py_func()
-      def default_builder():
-        c0 = arith.ConstantOp(f32, 0.0)
-        # CHECK: %[[C0:.*]] = "arith.constant
-        # CHECK-SAME: value = 0.000000e+00 : f32
-        print(c0)
-        c1 = arith.ConstantOp(f32, 1.0)
-        # CHECK: %[[C1:.*]] = "arith.constant
-        # CHECK-SAME: value = 1.000000e+00 : f32
-        print(c1)
-
-        t = tensor.FromElementsOp(RankedTensorType.get((2,), f32), [c0, c1])
-        # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<2xf32>
-        print(t)
-
-        t = tensor.FromElementsOp(RankedTensorType.get((2, 1), f32), [c0, c1])
-        # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<2x1xf32>
-        print(t)
-
-        t = tensor.FromElementsOp(RankedTensorType.get((1, 2), f32), [c0, c1])
-        # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<1x2xf32>
-        print(t)
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func()
+            def default_builder():
+                c0 = arith.ConstantOp(f32, 0.0)
+                # CHECK: %[[C0:.*]] = "arith.constant
+                # CHECK-SAME: value = 0.000000e+00 : f32
+                print(c0)
+                c1 = arith.ConstantOp(f32, 1.0)
+                # CHECK: %[[C1:.*]] = "arith.constant
+                # CHECK-SAME: value = 1.000000e+00 : f32
+                print(c1)
+
+                t = tensor.FromElementsOp(RankedTensorType.get((2,), f32), [c0, c1])
+                # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<2xf32>
+                print(t)
+
+                t = tensor.FromElementsOp(RankedTensorType.get((2, 1), f32), [c0, c1])
+                # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<2x1xf32>
+                print(t)
+
+                t = tensor.FromElementsOp(RankedTensorType.get((1, 2), f32), [c0, c1])
+                # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<1x2xf32>
+                print(t)
index 6b36c02..ca6499b 100644
@@ -6,158 +6,188 @@ from mlir.dialects.transform import pdl as transform_pdl
 
 
 def run(f):
-  with Context(), Location.unknown():
-    module = Module.create()
-    with InsertionPoint(module.body):
-      print("\nTEST:", f.__name__)
-      f()
-    print(module)
-  return f
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            print("\nTEST:", f.__name__)
+            f()
+        print(module)
+    return f
 
 
 @run
 def testTypes():
-  # CHECK-LABEL: TEST: testTypes
-  # CHECK: !transform.any_op
-  any_op = transform.AnyOpType.get()
-  print(any_op)
+    # CHECK-LABEL: TEST: testTypes
+    # CHECK: !transform.any_op
+    any_op = transform.AnyOpType.get()
+    print(any_op)
 
-  # CHECK: !transform.op<"foo.bar">
-  # CHECK: foo.bar
-  concrete_op = transform.OperationType.get("foo.bar")
-  print(concrete_op)
-  print(concrete_op.operation_name)
+    # CHECK: !transform.op<"foo.bar">
+    # CHECK: foo.bar
+    concrete_op = transform.OperationType.get("foo.bar")
+    print(concrete_op)
+    print(concrete_op.operation_name)
 
 
 @run
 def testSequenceOp():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
-                                  [transform.AnyOpType.get()],
-                                  transform.AnyOpType.get())
-  with InsertionPoint(sequence.body):
-    transform.YieldOp([sequence.bodyTarget])
-  # CHECK-LABEL: TEST: testSequenceOp
-  # CHECK: = transform.sequence -> !transform.any_op failures(propagate) {
-  # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
-  # CHECK:   yield %[[ARG0]] : !transform.any_op
-  # CHECK: }
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [transform.AnyOpType.get()],
+        transform.AnyOpType.get(),
+    )
+    with InsertionPoint(sequence.body):
+        transform.YieldOp([sequence.bodyTarget])
+    # CHECK-LABEL: TEST: testSequenceOp
+    # CHECK: = transform.sequence -> !transform.any_op failures(propagate) {
+    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+    # CHECK:   yield %[[ARG0]] : !transform.any_op
+    # CHECK: }
 
 
 @run
 def testNestedSequenceOp():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get())
-  with InsertionPoint(sequence.body):
-    nested = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], sequence.bodyTarget)
-    with InsertionPoint(nested.body):
-      doubly_nested = transform.SequenceOp(
-          transform.FailurePropagationMode.PROPAGATE,
-          [transform.AnyOpType.get()], nested.bodyTarget)
-      with InsertionPoint(doubly_nested.body):
-        transform.YieldOp([doubly_nested.bodyTarget])
-      transform.YieldOp()
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testNestedSequenceOp
-  # CHECK: transform.sequence failures(propagate) {
-  # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
-  # CHECK:   sequence %[[ARG0]] : !transform.any_op failures(propagate) {
-  # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
-  # CHECK:     = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
-  # CHECK:     ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
-  # CHECK:       yield %[[ARG2]] : !transform.any_op
-  # CHECK:     }
-  # CHECK:   }
-  # CHECK: }
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        nested = transform.SequenceOp(
+            transform.FailurePropagationMode.PROPAGATE, [], sequence.bodyTarget
+        )
+        with InsertionPoint(nested.body):
+            doubly_nested = transform.SequenceOp(
+                transform.FailurePropagationMode.PROPAGATE,
+                [transform.AnyOpType.get()],
+                nested.bodyTarget,
+            )
+            with InsertionPoint(doubly_nested.body):
+                transform.YieldOp([doubly_nested.bodyTarget])
+            transform.YieldOp()
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testNestedSequenceOp
+    # CHECK: transform.sequence failures(propagate) {
+    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+    # CHECK:   sequence %[[ARG0]] : !transform.any_op failures(propagate) {
+    # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+    # CHECK:     = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
+    # CHECK:     ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
+    # CHECK:       yield %[[ARG2]] : !transform.any_op
+    # CHECK:     }
+    # CHECK:   }
+    # CHECK: }
 
 
 @run
 def testSequenceOpWithExtras():
-  sequence = transform.SequenceOp(
-      transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get(),
-      [transform.AnyOpType.get(),
-       transform.OperationType.get("foo.bar")])
-  with InsertionPoint(sequence.body):
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testSequenceOpWithExtras
-  # CHECK: transform.sequence failures(propagate)
-  # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [],
+        transform.AnyOpType.get(),
+        [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
+    )
+    with InsertionPoint(sequence.body):
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testSequenceOpWithExtras
+    # CHECK: transform.sequence failures(propagate)
+    # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
 
 
 @run
 def testNestedSequenceOpWithExtras():
-  sequence = transform.SequenceOp(
-      transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get(),
-      [transform.AnyOpType.get(),
-       transform.OperationType.get("foo.bar")])
-  with InsertionPoint(sequence.body):
-    nested = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
-                                  [], sequence.bodyTarget,
-                                  sequence.bodyExtraArgs)
-    with InsertionPoint(nested.body):
-      transform.YieldOp()
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
-  # CHECK: transform.sequence failures(propagate)
-  # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
-  # CHECK:   sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [],
+        transform.AnyOpType.get(),
+        [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
+    )
+    with InsertionPoint(sequence.body):
+        nested = transform.SequenceOp(
+            transform.FailurePropagationMode.PROPAGATE,
+            [],
+            sequence.bodyTarget,
+            sequence.bodyExtraArgs,
+        )
+        with InsertionPoint(nested.body):
+            transform.YieldOp()
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
+    # CHECK: transform.sequence failures(propagate)
+    # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
+    # CHECK:   sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
 
 
 @run
 def testTransformPDLOps():
-  withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
-  with InsertionPoint(withPdl.body):
-    sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
-                                    [transform.AnyOpType.get()],
-                                    withPdl.bodyTarget)
-    with InsertionPoint(sequence.body):
-      match = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher")
-      transform.YieldOp(match)
-  # CHECK-LABEL: TEST: testTransformPDLOps
-  # CHECK: transform.with_pdl_patterns {
-  # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
-  # CHECK:   = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
-  # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
-  # CHECK:     %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
-  # CHECK:     yield %[[RES]] : !transform.any_op
-  # CHECK:   }
-  # CHECK: }
+    withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
+    with InsertionPoint(withPdl.body):
+        sequence = transform.SequenceOp(
+            transform.FailurePropagationMode.PROPAGATE,
+            [transform.AnyOpType.get()],
+            withPdl.bodyTarget,
+        )
+        with InsertionPoint(sequence.body):
+            match = transform_pdl.PDLMatchOp(
+                transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher"
+            )
+            transform.YieldOp(match)
+    # CHECK-LABEL: TEST: testTransformPDLOps
+    # CHECK: transform.with_pdl_patterns {
+    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+    # CHECK:   = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
+    # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+    # CHECK:     %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
+    # CHECK:     yield %[[RES]] : !transform.any_op
+    # CHECK:   }
+    # CHECK: }
 
 
 @run
 def testGetClosestIsolatedParentOp():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get())
-  with InsertionPoint(sequence.body):
-    transform.GetClosestIsolatedParentOp(transform.AnyOpType.get(), sequence.bodyTarget)
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testGetClosestIsolatedParentOp
-  # CHECK: transform.sequence
-  # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
-  # CHECK:   = get_closest_isolated_parent %[[ARG1]]
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        transform.GetClosestIsolatedParentOp(
+            transform.AnyOpType.get(), sequence.bodyTarget
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testGetClosestIsolatedParentOp
+    # CHECK: transform.sequence
+    # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+    # CHECK:   = get_closest_isolated_parent %[[ARG1]]
 
 
 @run
 def testMergeHandlesOp():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get())
-  with InsertionPoint(sequence.body):
-    transform.MergeHandlesOp([sequence.bodyTarget])
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testMergeHandlesOp
-  # CHECK: transform.sequence
-  # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
-  # CHECK:   = merge_handles %[[ARG1]]
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        transform.MergeHandlesOp([sequence.bodyTarget])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMergeHandlesOp
+    # CHECK: transform.sequence
+    # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
+    # CHECK:   = merge_handles %[[ARG1]]
 
 
 @run
 def testReplicateOp():
-  with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
-  with InsertionPoint(with_pdl.body):
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget)
-    with InsertionPoint(sequence.body):
-      m1 = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "first")
-      m2 = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "second")
-      transform.ReplicateOp(m1, [m2])
-      transform.YieldOp()
-  # CHECK-LABEL: TEST: testReplicateOp
-  # CHECK: %[[FIRST:.+]] = pdl_match
-  # CHECK: %[[SECOND:.+]] = pdl_match
-  # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
+    with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
+    with InsertionPoint(with_pdl.body):
+        sequence = transform.SequenceOp(
+            transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget
+        )
+        with InsertionPoint(sequence.body):
+            m1 = transform_pdl.PDLMatchOp(
+                transform.AnyOpType.get(), sequence.bodyTarget, "first"
+            )
+            m2 = transform_pdl.PDLMatchOp(
+                transform.AnyOpType.get(), sequence.bodyTarget, "second"
+            )
+            transform.ReplicateOp(m1, [m2])
+            transform.YieldOp()
+    # CHECK-LABEL: TEST: testReplicateOp
+    # CHECK: %[[FIRST:.+]] = pdl_match
+    # CHECK: %[[SECOND:.+]] = pdl_match
+    # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
index 067a8b6..28a022a 100644
@@ -7,70 +7,92 @@ from mlir.dialects.transform import loop
 
 
 def run(f):
-  with Context(), Location.unknown():
-    module = Module.create()
-    with InsertionPoint(module.body):
-      print("\nTEST:", f.__name__)
-      f()
-    print(module)
-  return f
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            print("\nTEST:", f.__name__)
+            f()
+        print(module)
+    return f
 
 
 @run
 def getParentLoop():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
-                                  [], pdl.OperationType.get())
-  with InsertionPoint(sequence.body):
-    loop.GetParentForOp(transform.OperationType.get("scf.for"), sequence.bodyTarget, num_loops=2)
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: getParentLoop
-  # CHECK: = transform.loop.get_parent_for %
-  # CHECK: num_loops = 2
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        loop.GetParentForOp(
+            transform.OperationType.get("scf.for"), sequence.bodyTarget, num_loops=2
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: getParentLoop
+    # CHECK: = transform.loop.get_parent_for %
+    # CHECK: num_loops = 2
 
 
 @run
 def loopOutline():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
-                                  [], transform.OperationType.get("scf.for"))
-  with InsertionPoint(sequence.body):
-    loop.LoopOutlineOp(transform.AnyOpType.get(), transform.AnyOpType.get(), sequence.bodyTarget, func_name="foo")
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: loopOutline
-  # CHECK: = transform.loop.outline %
-  # CHECK: func_name = "foo"
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [],
+        transform.OperationType.get("scf.for"),
+    )
+    with InsertionPoint(sequence.body):
+        loop.LoopOutlineOp(
+            transform.AnyOpType.get(),
+            transform.AnyOpType.get(),
+            sequence.bodyTarget,
+            func_name="foo",
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: loopOutline
+    # CHECK: = transform.loop.outline %
+    # CHECK: func_name = "foo"
 
 
 @run
 def loopPeel():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
-                                  [], transform.OperationType.get("scf.for"))
-  with InsertionPoint(sequence.body):
-    loop.LoopPeelOp(pdl.OperationType.get(), sequence.bodyTarget)
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: loopPeel
-  # CHECK: = transform.loop.peel %
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [],
+        transform.OperationType.get("scf.for"),
+    )
+    with InsertionPoint(sequence.body):
+        loop.LoopPeelOp(pdl.OperationType.get(), sequence.bodyTarget)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: loopPeel
+    # CHECK: = transform.loop.peel %
 
 
 @run
 def loopPipeline():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
-                                  [], transform.OperationType.get("scf.for"))
-  with InsertionPoint(sequence.body):
-    loop.LoopPipelineOp(pdl.OperationType.get(), sequence.bodyTarget, iteration_interval=3)
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: loopPipeline
-  # CHECK: = transform.loop.pipeline %
-  # CHECK-DAG: iteration_interval = 3
-  # (read_latency has default value and is not printed)
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [],
+        transform.OperationType.get("scf.for"),
+    )
+    with InsertionPoint(sequence.body):
+        loop.LoopPipelineOp(
+            pdl.OperationType.get(), sequence.bodyTarget, iteration_interval=3
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: loopPipeline
+    # CHECK: = transform.loop.pipeline %
+    # CHECK-DAG: iteration_interval = 3
+    # (read_latency has default value and is not printed)
 
 
 @run
 def loopUnroll():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
-                                  [], transform.OperationType.get("scf.for"))
-  with InsertionPoint(sequence.body):
-    loop.LoopUnrollOp(sequence.bodyTarget, factor=42)
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: loopUnroll
-  # CHECK: transform.loop.unroll %
-  # CHECK: factor = 42
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [],
+        transform.OperationType.get("scf.for"),
+    )
+    with InsertionPoint(sequence.body):
+        loop.LoopUnrollOp(sequence.bodyTarget, factor=42)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: loopUnroll
+    # CHECK: transform.loop.unroll %
+    # CHECK: factor = 42
index d2a82b8..2dfae47 100644
@@ -8,204 +8,230 @@ from mlir.dialects.transform import pdl as transform_pdl
 
 
 def run(f):
-  with Context(), Location.unknown():
-    module = Module.create()
-    with InsertionPoint(module.body):
-      print("\nTEST:", f.__name__)
-      f()
-    print(module)
-  return f
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            print("\nTEST:", f.__name__)
+            f()
+        print(module)
+    return f
 
 
 @run
 def testDecompose():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
-  with InsertionPoint(sequence.body):
-    structured.DecomposeOp(sequence.bodyTarget)
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testDecompose
-  # CHECK: transform.sequence
-  # CHECK: transform.structured.decompose
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.DecomposeOp(sequence.bodyTarget)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testDecompose
+    # CHECK: transform.sequence
+    # CHECK: transform.structured.decompose
 
 
 @run
 def testGeneralize():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
-  with InsertionPoint(sequence.body):
-    structured.GeneralizeOp(sequence.bodyTarget)
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testGeneralize
-  # CHECK: transform.sequence
-  # CHECK: transform.structured.generalize
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.GeneralizeOp(sequence.bodyTarget)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testGeneralize
+    # CHECK: transform.sequence
+    # CHECK: transform.structured.generalize
 
 
 @run
 def testInterchange():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
-  with InsertionPoint(sequence.body):
-    structured.InterchangeOp(
-        sequence.bodyTarget,
-        iterator_interchange=[1, 0])
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testInterchange
-  # CHECK: transform.sequence
-  # CHECK: transform.structured.interchange
-  # CHECK: iterator_interchange = [1, 0]
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.InterchangeOp(sequence.bodyTarget, iterator_interchange=[1, 0])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testInterchange
+    # CHECK: transform.sequence
+    # CHECK: transform.structured.interchange
+    # CHECK: iterator_interchange = [1, 0]
 
 
 @run
 def testMultitileSizes():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
-  with InsertionPoint(sequence.body):
-    structured.MultiTileSizesOp(pdl.OperationType.get(),
-                                sequence.bodyTarget,
-                                dimension=1,
-                                target_size=42)
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testMultitileSizes
-  # CHECK: transform.sequence
-  # CHECK: transform.structured.multitile_sizes
-  # CHECK-DAG: dimension = 1
-  # CHECK-DAG: target_size = 42
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.MultiTileSizesOp(
+            pdl.OperationType.get(), sequence.bodyTarget, dimension=1, target_size=42
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMultitileSizes
+    # CHECK: transform.sequence
+    # CHECK: transform.structured.multitile_sizes
+    # CHECK-DAG: dimension = 1
+    # CHECK-DAG: target_size = 42
 
 
 @run
 def testPad():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
-  with InsertionPoint(sequence.body):
-    structured.PadOp(
-        sequence.bodyTarget,
-        padding_values=[FloatAttr.get_f32(42.0)],
-        padding_dimensions=[1],
-        transpose_paddings=[[1, 0]])
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testPad
-  # CHECK: transform.sequence
-  # CHECK: transform.structured.pad
-  # CHECK-DAG: padding_values = [4.200000e+01 : f32]
-  # CHECK-DAG: padding_dimensions = [1]
-  # CHECK-DAG: transpose_paddings = {{\[}}[1, 0]]
-  # (pack_paddings has default values)
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.PadOp(
+            sequence.bodyTarget,
+            padding_values=[FloatAttr.get_f32(42.0)],
+            padding_dimensions=[1],
+            transpose_paddings=[[1, 0]],
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testPad
+    # CHECK: transform.sequence
+    # CHECK: transform.structured.pad
+    # CHECK-DAG: padding_values = [4.200000e+01 : f32]
+    # CHECK-DAG: padding_dimensions = [1]
+    # CHECK-DAG: transpose_paddings = {{\[}}[1, 0]]
+    # (pack_paddings has default values)
+
 
 @run
 def testScalarize():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
-  with InsertionPoint(sequence.body):
-    structured.ScalarizeOp(sequence.bodyTarget)
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testScalarize
-  # CHECK: transform.structured.scalarize
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.ScalarizeOp(sequence.bodyTarget)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testScalarize
+    # CHECK: transform.structured.scalarize
 
 
 @run
 def testSplit():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
-  with InsertionPoint(sequence.body):
-    split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42)
-    structured.SplitOp(
-        split.results[0], dimension=3, split_point=split.results[1])
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testSplit
-  # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
-  # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42)
+        structured.SplitOp(split.results[0], dimension=3, split_point=split.results[1])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testSplit
+    # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
+    # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
+
 
 @run
 def testTileCompact():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
-  with InsertionPoint(sequence.body):
-    structured.TileOp(sequence.bodyTarget,
-                      sizes=[4, 8],
-                      interchange=[0, 1])
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testTileCompact
-  # CHECK: transform.sequence
-  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
-  # CHECK: interchange = [0, 1]
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileCompact
+    # CHECK: transform.sequence
+    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
+    # CHECK: interchange = [0, 1]
+
 
 @run
 def testTileAttributes():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
-  attr = DenseI64ArrayAttr.get([4, 8])
-  ichange = DenseI64ArrayAttr.get([0, 1])
-  with InsertionPoint(sequence.body):
-    structured.TileOp(sequence.bodyTarget,
-                      sizes=attr,
-                      interchange=ichange)
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testTileAttributes
-  # CHECK: transform.sequence
-  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
-  # CHECK: interchange = [0, 1]
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    attr = DenseI64ArrayAttr.get([4, 8])
+    ichange = DenseI64ArrayAttr.get([0, 1])
+    with InsertionPoint(sequence.body):
+        structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileAttributes
+    # CHECK: transform.sequence
+    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
+    # CHECK: interchange = [0, 1]
+
 
 @run
 def testTileZero():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
-  with InsertionPoint(sequence.body):
-    structured.TileOp(sequence.bodyTarget,
-                      sizes=[4, 0, 2, 0],
-                      interchange=[0, 1, 2, 3])
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testTileZero
-  # CHECK: transform.sequence
-  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 0, 2, 0]
-  # CHECK: interchange = [0, 1, 2, 3]
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.TileOp(
+            sequence.bodyTarget, sizes=[4, 0, 2, 0], interchange=[0, 1, 2, 3]
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileZero
+    # CHECK: transform.sequence
+    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 0, 2, 0]
+    # CHECK: interchange = [0, 1, 2, 3]
+
 
 @run
 def testTileDynamic():
-  with_pdl = transform_pdl.WithPDLPatternsOp(pdl.OperationType.get())
-  with InsertionPoint(with_pdl.body):
-    sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [],
-                                    with_pdl.bodyTarget)
-    with InsertionPoint(sequence.body):
-      m1 = transform_pdl.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "first")
-      m2 = transform_pdl.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "second")
-      structured.TileOp(sequence.bodyTarget,
-                        sizes=[m1, 3, m2, 0])
-      transform.YieldOp()
-  # CHECK-LABEL: TEST: testTileDynamic
-  # CHECK: %[[FIRST:.+]] = pdl_match
-  # CHECK: %[[SECOND:.+]] = pdl_match
-  # CHECK: %{{.+}}, %{{.+}}:3 = transform.structured.tile %{{.*}}[%[[FIRST]], 3, %[[SECOND]], 0]
+    with_pdl = transform_pdl.WithPDLPatternsOp(pdl.OperationType.get())
+    with InsertionPoint(with_pdl.body):
+        sequence = transform.SequenceOp(
+            transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget
+        )
+        with InsertionPoint(sequence.body):
+            m1 = transform_pdl.PDLMatchOp(
+                pdl.OperationType.get(), sequence.bodyTarget, "first"
+            )
+            m2 = transform_pdl.PDLMatchOp(
+                pdl.OperationType.get(), sequence.bodyTarget, "second"
+            )
+            structured.TileOp(sequence.bodyTarget, sizes=[m1, 3, m2, 0])
+            transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileDynamic
+    # CHECK: %[[FIRST:.+]] = pdl_match
+    # CHECK: %[[SECOND:.+]] = pdl_match
+    # CHECK: %{{.+}}, %{{.+}}:3 = transform.structured.tile %{{.*}}[%[[FIRST]], 3, %[[SECOND]], 0]
 
 
 @run
 def testTileExplicitLoopTypeSingle():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
-                                  [], transform.AnyOpType.get())
-  with InsertionPoint(sequence.body):
-    structured.TileOp(transform.OperationType.get("scf.for"),
-                      sequence.bodyTarget,
-                      sizes=[2, 3, 4])
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testTileExplicitLoopTypeSingle
-  # CHECK: = transform.structured.tile %{{.*}} : (!{{.*}}) ->
-  # CHECK-COUNT-3: !transform.op<"scf.for">
-
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.TileOp(
+            transform.OperationType.get("scf.for"), sequence.bodyTarget, sizes=[2, 3, 4]
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileExplicitLoopTypeSingle
+    # CHECK: = transform.structured.tile %{{.*}} : (!{{.*}}) ->
+    # CHECK-COUNT-3: !transform.op<"scf.for">
 
 
 @run
 def testTileExplicitLoopTypeAll():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
-                                  [], transform.AnyOpType.get())
-  types = [
-      transform.OperationType.get(x)
-      for x in ["scf.for", "scf.parallel", "scf.forall"]
-  ]
-  with InsertionPoint(sequence.body):
-    structured.TileOp(types, sequence.bodyTarget, sizes=[2, 3, 4])
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testTileExplicitLoopTypeAll
-  # CHECK: = transform.structured.tile
-  # CHECK-SAME : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">,
-  # CHECK-SAME: !transform.op<"scf.parallel">, !transform.op<"scf.forall">
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    types = [
+        transform.OperationType.get(x)
+        for x in ["scf.for", "scf.parallel", "scf.forall"]
+    ]
+    with InsertionPoint(sequence.body):
+        structured.TileOp(types, sequence.bodyTarget, sizes=[2, 3, 4])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileExplicitLoopTypeAll
+    # CHECK: = transform.structured.tile
+    # CHECK-SAME : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">,
+    # CHECK-SAME: !transform.op<"scf.parallel">, !transform.op<"scf.forall">
+
 
 @run
 def testVectorize():
-  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
-  with InsertionPoint(sequence.body):
-    structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True)
-    transform.YieldOp()
-  # CHECK-LABEL: TEST: testVectorize
-  # CHECK: transform.sequence
-  # CHECK: = transform.structured.vectorize
-  # CHECK: {vectorize_padding}
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testVectorize
+    # CHECK: transform.sequence
+    # CHECK: = transform.structured.vectorize
+    # CHECK: {vectorize_padding}
index 83c0961..2347abb 100644
@@ -5,57 +5,62 @@ import mlir.dialects.builtin as builtin
 import mlir.dialects.func as func
 import mlir.dialects.vector as vector
 
+
 def run(f):
-  print("\nTEST:", f.__name__)
-  with Context(), Location.unknown():
-    f()
-  return f
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        f()
+    return f
+
 
 # CHECK-LABEL: TEST: testPrintOp
 @run
 def testPrintOp():
-  module = Module.create()
-  with InsertionPoint(module.body):
+    module = Module.create()
+    with InsertionPoint(module.body):
 
-    @func.FuncOp.from_py_func(VectorType.get((12, 5), F32Type.get()))
-    def print_vector(arg):
-      return vector.PrintOp(arg)
+        @func.FuncOp.from_py_func(VectorType.get((12, 5), F32Type.get()))
+        def print_vector(arg):
+            return vector.PrintOp(arg)
 
-  # CHECK-LABEL: func @print_vector(
-  # CHECK-SAME:                     %[[ARG:.*]]: vector<12x5xf32>) {
-  #       CHECK:   vector.print %[[ARG]] : vector<12x5xf32>
-  #       CHECK:   return
-  #       CHECK: }
-  print(module)
+    # CHECK-LABEL: func @print_vector(
+    # CHECK-SAME:                     %[[ARG:.*]]: vector<12x5xf32>) {
+    #       CHECK:   vector.print %[[ARG]] : vector<12x5xf32>
+    #       CHECK:   return
+    #       CHECK: }
+    print(module)
 
 
 # CHECK-LABEL: TEST: testTransferReadOp
 @run
 def testTransferReadOp():
-  module = Module.create()
-  with InsertionPoint(module.body):
-    vector_type = VectorType.get([2, 3], F32Type.get())
-    memref_type = MemRefType.get(
-        [ShapedType.get_dynamic_size(),
-         ShapedType.get_dynamic_size()], F32Type.get())
-    index_type = IndexType.get()
-    mask_type = VectorType.get(vector_type.shape, IntegerType.get_signless(1))
-    identity_map = AffineMap.get_identity(vector_type.rank)
-    identity_map_attr = AffineMapAttr.get(identity_map)
-    f = func.FuncOp("transfer_read",
-                          ([memref_type, index_type,
-                            F32Type.get(), mask_type], []))
-    with InsertionPoint(f.add_entry_block()):
-      A, zero, padding, mask = f.arguments
-      vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr,
-                            padding, mask=mask)
-      vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr,
-                            padding)
-      func.ReturnOp([])
-
-  # CHECK: @transfer_read(%[[MEM:.*]]: memref<?x?xf32>, %[[IDX:.*]]: index,
-  # CHECK: %[[PAD:.*]]: f32, %[[MASK:.*]]: vector<2x3xi1>)
-  # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]], %[[MASK]]
-  # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]]
-  # CHECK-NOT: %[[MASK]]
-  print(module)
+    module = Module.create()
+    with InsertionPoint(module.body):
+        vector_type = VectorType.get([2, 3], F32Type.get())
+        memref_type = MemRefType.get(
+            [ShapedType.get_dynamic_size(), ShapedType.get_dynamic_size()],
+            F32Type.get(),
+        )
+        index_type = IndexType.get()
+        mask_type = VectorType.get(vector_type.shape, IntegerType.get_signless(1))
+        identity_map = AffineMap.get_identity(vector_type.rank)
+        identity_map_attr = AffineMapAttr.get(identity_map)
+        f = func.FuncOp(
+            "transfer_read", ([memref_type, index_type, F32Type.get(), mask_type], [])
+        )
+        with InsertionPoint(f.add_entry_block()):
+            A, zero, padding, mask = f.arguments
+            vector.TransferReadOp(
+                vector_type, A, [zero, zero], identity_map_attr, padding, mask=mask
+            )
+            vector.TransferReadOp(
+                vector_type, A, [zero, zero], identity_map_attr, padding
+            )
+            func.ReturnOp([])
+
+    # CHECK: @transfer_read(%[[MEM:.*]]: memref<?x?xf32>, %[[IDX:.*]]: index,
+    # CHECK: %[[PAD:.*]]: f32, %[[MASK:.*]]: vector<2x3xi1>)
+    # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]], %[[MASK]]
+    # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]]
+    # CHECK-NOT: %[[MASK]]
+    print(module)
index 973810d..50d6e82 100644
@@ -10,34 +10,36 @@ from mlir.runtime import *
 # Log everything to stderr and flush so that we have a unified stream to match
 # errors/info emitted by MLIR to stderr.
 def log(*args):
-  print(*args, file=sys.stderr)
-  sys.stderr.flush()
+    print(*args, file=sys.stderr)
+    sys.stderr.flush()
 
 
 def run(f):
-  log("\nTEST:", f.__name__)
-  f()
-  gc.collect()
-  assert Context._get_live_count() == 0
+    log("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
 
 
 # Verify capsule interop.
 # CHECK-LABEL: TEST: testCapsule
 def testCapsule():
-  with Context():
-    module = Module.parse(r"""
+    with Context():
+        module = Module.parse(
+            r"""
 llvm.func @none() {
   llvm.return
 }
-    """)
-    execution_engine = ExecutionEngine(module)
-    execution_engine_capsule = execution_engine._CAPIPtr
-    # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
-    log(repr(execution_engine_capsule))
-    execution_engine._testing_release()
-    execution_engine1 = ExecutionEngine._CAPICreate(execution_engine_capsule)
-    # CHECK: _mlirExecutionEngine.ExecutionEngine
-    log(repr(execution_engine1))
+    """
+        )
+        execution_engine = ExecutionEngine(module)
+        execution_engine_capsule = execution_engine._CAPIPtr
+        # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
+        log(repr(execution_engine_capsule))
+        execution_engine._testing_release()
+        execution_engine1 = ExecutionEngine._CAPICreate(execution_engine_capsule)
+        # CHECK: _mlirExecutionEngine.ExecutionEngine
+        log(repr(execution_engine1))
 
 
 run(testCapsule)
@@ -46,40 +48,45 @@ run(testCapsule)
 # Test invalid ExecutionEngine creation
 # CHECK-LABEL: TEST: testInvalidModule
 def testInvalidModule():
-  with Context():
-    # Builtin function
-    module = Module.parse(r"""
+    with Context():
+        # Builtin function
+        module = Module.parse(
+            r"""
     func.func @foo() { return }
-    """)
-    # CHECK: Got RuntimeError:  Failure while creating the ExecutionEngine.
-    try:
-      execution_engine = ExecutionEngine(module)
-    except RuntimeError as e:
-      log("Got RuntimeError: ", e)
+    """
+        )
+        # CHECK: Got RuntimeError:  Failure while creating the ExecutionEngine.
+        try:
+            execution_engine = ExecutionEngine(module)
+        except RuntimeError as e:
+            log("Got RuntimeError: ", e)
 
 
 run(testInvalidModule)
 
 
 def lowerToLLVM(module):
-  pm = PassManager.parse(
-      "builtin.module(convert-complex-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts)")
-  pm.run(module.operation)
-  return module
+    pm = PassManager.parse(
+        "builtin.module(convert-complex-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts)"
+    )
+    pm.run(module.operation)
+    return module
 
 
 # Test simple ExecutionEngine execution
 # CHECK-LABEL: TEST: testInvokeVoid
 def testInvokeVoid():
-  with Context():
-    module = Module.parse(r"""
+    with Context():
+        module = Module.parse(
+            r"""
 func.func @void() attributes { llvm.emit_c_interface } {
   return
 }
-    """)
-    execution_engine = ExecutionEngine(lowerToLLVM(module))
-    # Nothing to check other than no exception thrown here.
-    execution_engine.invoke("void")
+    """
+        )
+        execution_engine = ExecutionEngine(lowerToLLVM(module))
+        # Nothing to check other than no exception thrown here.
+        execution_engine.invoke("void")
 
 
 run(testInvokeVoid)
@@ -88,23 +95,25 @@ run(testInvokeVoid)
 # Test argument passing and result with a simple float addition.
 # CHECK-LABEL: TEST: testInvokeFloatAdd
 def testInvokeFloatAdd():
-  with Context():
-    module = Module.parse(r"""
+    with Context():
+        module = Module.parse(
+            r"""
 func.func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } {
   %add = arith.addf %arg0, %arg1 : f32
   return %add : f32
 }
-    """)
-    execution_engine = ExecutionEngine(lowerToLLVM(module))
-    # Prepare arguments: two input floats and one result.
-    # Arguments must be passed as pointers.
-    c_float_p = ctypes.c_float * 1
-    arg0 = c_float_p(42.)
-    arg1 = c_float_p(2.)
-    res = c_float_p(-1.)
-    execution_engine.invoke("add", arg0, arg1, res)
-    # CHECK: 42.0 + 2.0 = 44.0
-    log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]))
+    """
+        )
+        execution_engine = ExecutionEngine(lowerToLLVM(module))
+        # Prepare arguments: two input floats and one result.
+        # Arguments must be passed as pointers.
+        c_float_p = ctypes.c_float * 1
+        arg0 = c_float_p(42.0)
+        arg1 = c_float_p(2.0)
+        res = c_float_p(-1.0)
+        execution_engine.invoke("add", arg0, arg1, res)
+        # CHECK: 42.0 + 2.0 = 44.0
+        log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]))
 
 
 run(testInvokeFloatAdd)
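
The same calling convention applies to any scalar entry point: every argument and every result goes to invoke() as a pointer, which a one-element ctypes array provides. A minimal standalone sketch of that pattern, reusing the lowering pipeline from lowerToLLVM above (the @add2 function and its constants are illustrative):

import ctypes

from mlir.ir import Context, Module
from mlir.passmanager import PassManager
from mlir.execution_engine import ExecutionEngine

with Context():
    module = Module.parse(
        r"""
func.func @add2(%a: f32, %b: f32) -> f32 attributes { llvm.emit_c_interface } {
  %s = arith.addf %a, %b : f32
  return %s : f32
}
"""
    )
    # Lower to the LLVM dialect before constructing the JIT, as lowerToLLVM does.
    PassManager.parse(
        "builtin.module(convert-complex-to-llvm,finalize-memref-to-llvm,"
        "convert-func-to-llvm,reconcile-unrealized-casts)"
    ).run(module.operation)
    engine = ExecutionEngine(module)

    # Scalars are passed by pointer: a one-element ctypes array does the job.
    c_float_p = ctypes.c_float * 1
    a, b, out = c_float_p(40.0), c_float_p(2.0), c_float_p(-1.0)
    engine.invoke("add2", a, b, out)
    print(out[0])  # 42.0
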
@@ -113,33 +122,35 @@ run(testInvokeFloatAdd)
 # Test callback
 # CHECK-LABEL: TEST: testBasicCallback
 def testBasicCallback():
-  # Define a callback function that takes a float and an integer and returns a float.
-  @ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float, ctypes.c_int)
-  def callback(a, b):
-    return a / 2 + b / 2
-
-  with Context():
-    # The module just forwards to a runtime function known as "some_callback_into_python".
-    module = Module.parse(r"""
+    # Define a callback function that takes a float and an integer and returns a float.
+    @ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float, ctypes.c_int)
+    def callback(a, b):
+        return a / 2 + b / 2
+
+    with Context():
+        # The module just forwards to a runtime function known as "some_callback_into_python".
+        module = Module.parse(
+            r"""
 func.func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } {
   %resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32)
   return %resf : f32
 }
 func.func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface }
-    """)
-    execution_engine = ExecutionEngine(lowerToLLVM(module))
-    execution_engine.register_runtime("some_callback_into_python", callback)
-
-    # Prepare arguments: two input floats and one result.
-    # Arguments must be passed as pointers.
-    c_float_p = ctypes.c_float * 1
-    c_int_p = ctypes.c_int * 1
-    arg0 = c_float_p(42.)
-    arg1 = c_int_p(2)
-    res = c_float_p(-1.)
-    execution_engine.invoke("add", arg0, arg1, res)
-    # CHECK: 42.0 + 2 = 44.0
-    log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0] * 2))
+    """
+        )
+        execution_engine = ExecutionEngine(lowerToLLVM(module))
+        execution_engine.register_runtime("some_callback_into_python", callback)
+
+        # Prepare arguments: two input floats and one result.
+        # Arguments must be passed as pointers.
+        c_float_p = ctypes.c_float * 1
+        c_int_p = ctypes.c_int * 1
+        arg0 = c_float_p(42.0)
+        arg1 = c_int_p(2)
+        res = c_float_p(-1.0)
+        execution_engine.invoke("add", arg0, arg1, res)
+        # CHECK: 42.0 + 2 = 44.0
+        log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0] * 2))
 
 
 run(testBasicCallback)
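
register_runtime() is what makes a Python callable visible to the JIT-compiled code; the ctypes.CFUNCTYPE layout has to match the private func declaration it backs. A reduced sketch of the same wiring with a single argument (the @scale and @py_hook names are illustrative):

import ctypes

from mlir.ir import Context, Module
from mlir.passmanager import PassManager
from mlir.execution_engine import ExecutionEngine

# The CFUNCTYPE layout (f32) -> f32 mirrors the MLIR declaration below.
@ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float)
def py_hook(x):
    return x * 2.0

with Context():
    module = Module.parse(
        r"""
func.func @scale(%arg0: f32) -> f32 attributes { llvm.emit_c_interface } {
  %r = call @py_hook(%arg0) : (f32) -> f32
  return %r : f32
}
func.func private @py_hook(f32) -> f32 attributes { llvm.emit_c_interface }
"""
    )
    PassManager.parse(
        "builtin.module(convert-complex-to-llvm,finalize-memref-to-llvm,"
        "convert-func-to-llvm,reconcile-unrealized-casts)"
    ).run(module.operation)
    engine = ExecutionEngine(module)
    engine.register_runtime("py_hook", py_hook)

    c_float_p = ctypes.c_float * 1
    arg, res = c_float_p(21.0), c_float_p(-1.0)
    engine.invoke("scale", arg, res)
    print(res[0])  # 42.0
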
@@ -148,44 +159,46 @@ run(testBasicCallback)
 # Test callback with an unranked memref
 # CHECK-LABEL: TEST: testUnrankedMemRefCallback
 def testUnrankedMemRefCallback():
-  # Define a callback function that takes an unranked memref, converts it to a numpy array and prints it.
-  @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
-  def callback(a):
-    arr = unranked_memref_to_numpy(a, np.float32)
-    log("Inside callback: ")
-    log(arr)
-
-  with Context():
-    # The module just forwards to a runtime function known as "some_callback_into_python".
-    module = Module.parse(r"""
+    # Define a callback function that takes an unranked memref, converts it to a numpy array and prints it.
+    @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
+    def callback(a):
+        arr = unranked_memref_to_numpy(a, np.float32)
+        log("Inside callback: ")
+        log(arr)
+
+    with Context():
+        # The module just forwards to a runtime function known as "some_callback_into_python".
+        module = Module.parse(
+            r"""
 func.func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
   call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
   return
 }
 func.func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
-""")
-    execution_engine = ExecutionEngine(lowerToLLVM(module))
-    execution_engine.register_runtime("some_callback_into_python", callback)
-    inp_arr = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
-    # CHECK: Inside callback:
-    # CHECK{LITERAL}: [[1. 2.]
-    # CHECK{LITERAL}:  [3. 4.]]
-    execution_engine.invoke(
-        "callback_memref",
-        ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(inp_arr))),
-    )
-    inp_arr_1 = np.array([5, 6, 7], dtype=np.float32)
-    strided_arr = np.lib.stride_tricks.as_strided(
-        inp_arr_1, strides=(4, 0), shape=(3, 4))
-    # CHECK: Inside callback:
-    # CHECK{LITERAL}: [[5. 5. 5. 5.]
-    # CHECK{LITERAL}:  [6. 6. 6. 6.]
-    # CHECK{LITERAL}:  [7. 7. 7. 7.]]
-    execution_engine.invoke(
-        "callback_memref",
-        ctypes.pointer(
-            ctypes.pointer(get_unranked_memref_descriptor(strided_arr))),
-    )
+"""
+        )
+        execution_engine = ExecutionEngine(lowerToLLVM(module))
+        execution_engine.register_runtime("some_callback_into_python", callback)
+        inp_arr = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
+        # CHECK: Inside callback:
+        # CHECK{LITERAL}: [[1. 2.]
+        # CHECK{LITERAL}:  [3. 4.]]
+        execution_engine.invoke(
+            "callback_memref",
+            ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(inp_arr))),
+        )
+        inp_arr_1 = np.array([5, 6, 7], dtype=np.float32)
+        strided_arr = np.lib.stride_tricks.as_strided(
+            inp_arr_1, strides=(4, 0), shape=(3, 4)
+        )
+        # CHECK: Inside callback:
+        # CHECK{LITERAL}: [[5. 5. 5. 5.]
+        # CHECK{LITERAL}:  [6. 6. 6. 6.]
+        # CHECK{LITERAL}:  [7. 7. 7. 7.]]
+        execution_engine.invoke(
+            "callback_memref",
+            ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(strided_arr))),
+        )
 
 
 run(testUnrankedMemRefCallback)
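
The second invocation above exercises a non-contiguous input: as_strided with byte strides (4, 0) broadcasts each float32 of the base array across four columns, which is exactly what the callback prints. The numpy part in isolation:

import numpy as np

base = np.array([5, 6, 7], dtype=np.float32)
# Strides are in bytes: 4 advances one float32 per row, 0 repeats it per column.
view = np.lib.stride_tricks.as_strided(base, shape=(3, 4), strides=(4, 0))
print(view)
# [[5. 5. 5. 5.]
#  [6. 6. 6. 6.]
#  [7. 7. 7. 7.]]
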
@@ -194,36 +207,39 @@ run(testUnrankedMemRefCallback)
 # Test callback with a ranked memref.
 # CHECK-LABEL: TEST: testRankedMemRefCallback
 def testRankedMemRefCallback():
-  # Define a callback function that takes a ranked memref, converts it to a numpy array and prints it.
-  @ctypes.CFUNCTYPE(
-      None,
-      ctypes.POINTER(
-          make_nd_memref_descriptor(2,
-                                    np.ctypeslib.as_ctypes_type(np.float32))),
-  )
-  def callback(a):
-    arr = ranked_memref_to_numpy(a)
-    log("Inside Callback: ")
-    log(arr)
-
-  with Context():
-    # The module just forwards to a runtime function known as "some_callback_into_python".
-    module = Module.parse(r"""
+    # Define a callback function that takes a ranked memref, converts it to a numpy array and prints it.
+    @ctypes.CFUNCTYPE(
+        None,
+        ctypes.POINTER(
+            make_nd_memref_descriptor(2, np.ctypeslib.as_ctypes_type(np.float32))
+        ),
+    )
+    def callback(a):
+        arr = ranked_memref_to_numpy(a)
+        log("Inside Callback: ")
+        log(arr)
+
+    with Context():
+        # The module just forwards to a runtime function known as "some_callback_into_python".
+        module = Module.parse(
+            r"""
 func.func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
   call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
   return
 }
 func.func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
-""")
-    execution_engine = ExecutionEngine(lowerToLLVM(module))
-    execution_engine.register_runtime("some_callback_into_python", callback)
-    inp_arr = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32)
-    # CHECK: Inside Callback:
-    # CHECK{LITERAL}: [[1. 5.]
-    # CHECK{LITERAL}:  [6. 7.]]
-    execution_engine.invoke(
-        "callback_memref",
-        ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr))))
+"""
+        )
+        execution_engine = ExecutionEngine(lowerToLLVM(module))
+        execution_engine.register_runtime("some_callback_into_python", callback)
+        inp_arr = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32)
+        # CHECK: Inside Callback:
+        # CHECK{LITERAL}: [[1. 5.]
+        # CHECK{LITERAL}:  [6. 7.]]
+        execution_engine.invoke(
+            "callback_memref",
+            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr))),
+        )
 
 
 run(testRankedMemRefCallback)
@@ -232,8 +248,9 @@ run(testRankedMemRefCallback)
 #  Test addition of two memrefs.
 # CHECK-LABEL: TEST: testMemrefAdd
 def testMemrefAdd():
-  with Context():
-    module = Module.parse("""
+    with Context():
+        module = Module.parse(
+            """
     module  {
       func.func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
         %0 = arith.constant 0 : index
@@ -243,23 +260,28 @@ def testMemrefAdd():
         memref.store %3, %arg2[%0] : memref<1xf32>
         return
       }
-    } """)
-    arg1 = np.array([32.5]).astype(np.float32)
-    arg2 = np.array(6).astype(np.float32)
-    res = np.array([0]).astype(np.float32)
-
-    arg1_memref_ptr = ctypes.pointer(
-        ctypes.pointer(get_ranked_memref_descriptor(arg1)))
-    arg2_memref_ptr = ctypes.pointer(
-        ctypes.pointer(get_ranked_memref_descriptor(arg2)))
-    res_memref_ptr = ctypes.pointer(
-        ctypes.pointer(get_ranked_memref_descriptor(res)))
-
-    execution_engine = ExecutionEngine(lowerToLLVM(module))
-    execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr,
-                            res_memref_ptr)
-    # CHECK: [32.5] + 6.0 = [38.5]
-    log("{0} + {1} = {2}".format(arg1, arg2, res))
+    } """
+        )
+        arg1 = np.array([32.5]).astype(np.float32)
+        arg2 = np.array(6).astype(np.float32)
+        res = np.array([0]).astype(np.float32)
+
+        arg1_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg1))
+        )
+        arg2_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg2))
+        )
+        res_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(res))
+        )
+
+        execution_engine = ExecutionEngine(lowerToLLVM(module))
+        execution_engine.invoke(
+            "main", arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr
+        )
+        # CHECK: [32.5] + 6.0 = [38.5]
+        log("{0} + {1} = {2}".format(arg1, arg2, res))
 
 
 run(testMemrefAdd)
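
MemRef arguments follow the same by-pointer rule, just with one more level of indirection: the numpy array is wrapped in a ranked memref descriptor, and a pointer-to-pointer to that descriptor is what invoke() receives. The wrapping and unwrapping on their own, with no JIT involved:

import ctypes

import numpy as np

from mlir.runtime.np_to_memref import (
    get_ranked_memref_descriptor,
    ranked_memref_to_numpy,
)

a = np.array([1.0, 2.0, 3.0], dtype=np.float32)
# invoke() expects a pointer to a pointer to the descriptor.
a_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(a)))
# The inner pointer converts back to an ndarray, as the tests do for results.
print(ranked_memref_to_numpy(a_ptr[0]))  # [1. 2. 3.]
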
@@ -268,8 +290,9 @@ run(testMemrefAdd)
 # Test addition of two f16 memrefs
 # CHECK-LABEL: TEST: testF16MemrefAdd
 def testF16MemrefAdd():
-  with Context():
-    module = Module.parse("""
+    with Context():
+        module = Module.parse(
+            """
     module  {
       func.func @main(%arg0: memref<1xf16>,
                       %arg1: memref<1xf16>,
@@ -281,29 +304,34 @@ def testF16MemrefAdd():
         memref.store %3, %arg2[%0] : memref<1xf16>
         return
       }
-    } """)
-
-    arg1 = np.array([11.]).astype(np.float16)
-    arg2 = np.array([22.]).astype(np.float16)
-    arg3 = np.array([0.]).astype(np.float16)
-
-    arg1_memref_ptr = ctypes.pointer(
-        ctypes.pointer(get_ranked_memref_descriptor(arg1)))
-    arg2_memref_ptr = ctypes.pointer(
-        ctypes.pointer(get_ranked_memref_descriptor(arg2)))
-    arg3_memref_ptr = ctypes.pointer(
-        ctypes.pointer(get_ranked_memref_descriptor(arg3)))
-
-    execution_engine = ExecutionEngine(lowerToLLVM(module))
-    execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr,
-                            arg3_memref_ptr)
-    # CHECK: [11.] + [22.] = [33.]
-    log("{0} + {1} = {2}".format(arg1, arg2, arg3))
-
-    # test to-numpy utility
-    # CHECK: [33.]
-    npout = ranked_memref_to_numpy(arg3_memref_ptr[0])
-    log(npout)
+    } """
+        )
+
+        arg1 = np.array([11.0]).astype(np.float16)
+        arg2 = np.array([22.0]).astype(np.float16)
+        arg3 = np.array([0.0]).astype(np.float16)
+
+        arg1_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg1))
+        )
+        arg2_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg2))
+        )
+        arg3_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg3))
+        )
+
+        execution_engine = ExecutionEngine(lowerToLLVM(module))
+        execution_engine.invoke(
+            "main", arg1_memref_ptr, arg2_memref_ptr, arg3_memref_ptr
+        )
+        # CHECK: [11.] + [22.] = [33.]
+        log("{0} + {1} = {2}".format(arg1, arg2, arg3))
+
+        # test to-numpy utility
+        # CHECK: [33.]
+        npout = ranked_memref_to_numpy(arg3_memref_ptr[0])
+        log(npout)
 
 
 run(testF16MemrefAdd)
@@ -312,8 +340,9 @@ run(testF16MemrefAdd)
 # Test addition of two complex memrefs
 # CHECK-LABEL: TEST: testComplexMemrefAdd
 def testComplexMemrefAdd():
-  with Context():
-    module = Module.parse("""
+    with Context():
+        module = Module.parse(
+            """
     module  {
       func.func @main(%arg0: memref<1xcomplex<f64>>,
                       %arg1: memref<1xcomplex<f64>>,
@@ -325,31 +354,34 @@ def testComplexMemrefAdd():
         memref.store %3, %arg2[%0] : memref<1xcomplex<f64>>
         return
       }
-    } """)
-
-    arg1 = np.array([1.+2.j]).astype(np.complex128)
-    arg2 = np.array([3.+4.j]).astype(np.complex128)
-    arg3  = np.array([0.+0.j]).astype(np.complex128)
-
-    arg1_memref_ptr = ctypes.pointer(
-        ctypes.pointer(get_ranked_memref_descriptor(arg1)))
-    arg2_memref_ptr = ctypes.pointer(
-        ctypes.pointer(get_ranked_memref_descriptor(arg2)))
-    arg3_memref_ptr = ctypes.pointer(
-        ctypes.pointer(get_ranked_memref_descriptor(arg3)))
-
-    execution_engine = ExecutionEngine(lowerToLLVM(module))
-    execution_engine.invoke("main",
-                            arg1_memref_ptr,
-                            arg2_memref_ptr,
-                            arg3_memref_ptr)
-    # CHECK: [1.+2.j] + [3.+4.j] = [4.+6.j]
-    log("{0} + {1} = {2}".format(arg1, arg2, arg3))
-
-    # test to-numpy utility
-    # CHECK: [4.+6.j]
-    npout = ranked_memref_to_numpy(arg3_memref_ptr[0])
-    log(npout)
+    } """
+        )
+
+        arg1 = np.array([1.0 + 2.0j]).astype(np.complex128)
+        arg2 = np.array([3.0 + 4.0j]).astype(np.complex128)
+        arg3 = np.array([0.0 + 0.0j]).astype(np.complex128)
+
+        arg1_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg1))
+        )
+        arg2_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg2))
+        )
+        arg3_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg3))
+        )
+
+        execution_engine = ExecutionEngine(lowerToLLVM(module))
+        execution_engine.invoke(
+            "main", arg1_memref_ptr, arg2_memref_ptr, arg3_memref_ptr
+        )
+        # CHECK: [1.+2.j] + [3.+4.j] = [4.+6.j]
+        log("{0} + {1} = {2}".format(arg1, arg2, arg3))
+
+        # test to-numpy utility
+        # CHECK: [4.+6.j]
+        npout = ranked_memref_to_numpy(arg3_memref_ptr[0])
+        log(npout)
 
 
 run(testComplexMemrefAdd)
@@ -358,8 +390,9 @@ run(testComplexMemrefAdd)
 # Test addition of two complex unranked memrefs
 # CHECK-LABEL: TEST: testComplexUnrankedMemrefAdd
 def testComplexUnrankedMemrefAdd():
-  with Context():
-    module = Module.parse("""
+    with Context():
+        module = Module.parse(
+            """
     module  {
       func.func @main(%arg0: memref<*xcomplex<f32>>,
                       %arg1: memref<*xcomplex<f32>>,
@@ -374,32 +407,34 @@ def testComplexUnrankedMemrefAdd():
         memref.store %3, %C[%0] : memref<1xcomplex<f32>>
         return
       }
-    } """)
-
-    arg1 = np.array([5.+6.j]).astype(np.complex64)
-    arg2 = np.array([7.+8.j]).astype(np.complex64)
-    arg3  = np.array([0.+0.j]).astype(np.complex64)
-
-    arg1_memref_ptr = ctypes.pointer(
-        ctypes.pointer(get_unranked_memref_descriptor(arg1)))
-    arg2_memref_ptr = ctypes.pointer(
-        ctypes.pointer(get_unranked_memref_descriptor(arg2)))
-    arg3_memref_ptr = ctypes.pointer(
-        ctypes.pointer(get_unranked_memref_descriptor(arg3)))
-
-    execution_engine = ExecutionEngine(lowerToLLVM(module))
-    execution_engine.invoke("main",
-                            arg1_memref_ptr,
-                            arg2_memref_ptr,
-                            arg3_memref_ptr)
-    # CHECK: [5.+6.j] + [7.+8.j] = [12.+14.j]
-    log("{0} + {1} = {2}".format(arg1, arg2, arg3))
-
-    # test to-numpy utility
-    # CHECK: [12.+14.j]
-    npout = unranked_memref_to_numpy(arg3_memref_ptr[0],
-                                     np.dtype(np.complex64))
-    log(npout)
+    } """
+        )
+
+        arg1 = np.array([5.0 + 6.0j]).astype(np.complex64)
+        arg2 = np.array([7.0 + 8.0j]).astype(np.complex64)
+        arg3 = np.array([0.0 + 0.0j]).astype(np.complex64)
+
+        arg1_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_unranked_memref_descriptor(arg1))
+        )
+        arg2_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_unranked_memref_descriptor(arg2))
+        )
+        arg3_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_unranked_memref_descriptor(arg3))
+        )
+
+        execution_engine = ExecutionEngine(lowerToLLVM(module))
+        execution_engine.invoke(
+            "main", arg1_memref_ptr, arg2_memref_ptr, arg3_memref_ptr
+        )
+        # CHECK: [5.+6.j] + [7.+8.j] = [12.+14.j]
+        log("{0} + {1} = {2}".format(arg1, arg2, arg3))
+
+        # test to-numpy utility
+        # CHECK: [12.+14.j]
+        npout = unranked_memref_to_numpy(arg3_memref_ptr[0], np.dtype(np.complex64))
+        log(npout)
 
 
 run(testComplexUnrankedMemrefAdd)
@@ -408,8 +443,9 @@ run(testComplexUnrankedMemrefAdd)
 #  Test addition of two 2d_memref
 # CHECK-LABEL: TEST: testDynamicMemrefAdd2D
 def testDynamicMemrefAdd2D():
-  with Context():
-    module = Module.parse("""
+    with Context():
+        module = Module.parse(
+            """
       module  {
         func.func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} {
           %c0 = arith.constant 0 : index
@@ -441,23 +477,28 @@ def testDynamicMemrefAdd2D():
           return
         }
       }
-        """)
-    arg1 = np.random.randn(2, 2).astype(np.float32)
-    arg2 = np.random.randn(2, 2).astype(np.float32)
-    res = np.random.randn(2, 2).astype(np.float32)
-
-    arg1_memref_ptr = ctypes.pointer(
-        ctypes.pointer(get_ranked_memref_descriptor(arg1)))
-    arg2_memref_ptr = ctypes.pointer(
-        ctypes.pointer(get_ranked_memref_descriptor(arg2)))
-    res_memref_ptr = ctypes.pointer(
-        ctypes.pointer(get_ranked_memref_descriptor(res)))
-
-    execution_engine = ExecutionEngine(lowerToLLVM(module))
-    execution_engine.invoke("memref_add_2d", arg1_memref_ptr, arg2_memref_ptr,
-                            res_memref_ptr)
-    # CHECK: True
-    log(np.allclose(arg1 + arg2, res))
+        """
+        )
+        arg1 = np.random.randn(2, 2).astype(np.float32)
+        arg2 = np.random.randn(2, 2).astype(np.float32)
+        res = np.random.randn(2, 2).astype(np.float32)
+
+        arg1_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg1))
+        )
+        arg2_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg2))
+        )
+        res_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(res))
+        )
+
+        execution_engine = ExecutionEngine(lowerToLLVM(module))
+        execution_engine.invoke(
+            "memref_add_2d", arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr
+        )
+        # CHECK: True
+        log(np.allclose(arg1 + arg2, res))
 
 
 run(testDynamicMemrefAdd2D)
@@ -466,8 +507,9 @@ run(testDynamicMemrefAdd2D)
 #  Test loading of shared libraries.
 # CHECK-LABEL: TEST: testSharedLibLoad
 def testSharedLibLoad():
-  with Context():
-    module = Module.parse("""
+    with Context():
+        module = Module.parse(
+            """
       module  {
       func.func @main(%arg0: memref<1xf32>) attributes { llvm.emit_c_interface } {
         %c0 = arith.constant 0 : index
@@ -478,35 +520,36 @@ def testSharedLibLoad():
         return
       }
       func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }
-     } """)
-    arg0 = np.array([0.0]).astype(np.float32)
-
-    arg0_memref_ptr = ctypes.pointer(
-        ctypes.pointer(get_ranked_memref_descriptor(arg0)))
-
-    if sys.platform == 'win32':
-      shared_libs = [
-          "../../../../bin/mlir_runner_utils.dll",
-          "../../../../bin/mlir_c_runner_utils.dll"
-      ]
-    elif sys.platform == 'darwin':
-      shared_libs = [
-          "../../../../lib/libmlir_runner_utils.dylib",
-          "../../../../lib/libmlir_c_runner_utils.dylib"
-      ]
-    else:
-      shared_libs = [
-          "../../../../lib/libmlir_runner_utils.so",
-          "../../../../lib/libmlir_c_runner_utils.so"
-      ]
-
-    execution_engine = ExecutionEngine(
-        lowerToLLVM(module),
-        opt_level=3,
-        shared_libs=shared_libs)
-    execution_engine.invoke("main", arg0_memref_ptr)
-    # CHECK: Unranked Memref
-    # CHECK-NEXT: [42]
+     } """
+        )
+        arg0 = np.array([0.0]).astype(np.float32)
+
+        arg0_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg0))
+        )
+
+        if sys.platform == "win32":
+            shared_libs = [
+                "../../../../bin/mlir_runner_utils.dll",
+                "../../../../bin/mlir_c_runner_utils.dll",
+            ]
+        elif sys.platform == "darwin":
+            shared_libs = [
+                "../../../../lib/libmlir_runner_utils.dylib",
+                "../../../../lib/libmlir_c_runner_utils.dylib",
+            ]
+        else:
+            shared_libs = [
+                "../../../../lib/libmlir_runner_utils.so",
+                "../../../../lib/libmlir_c_runner_utils.so",
+            ]
+
+        execution_engine = ExecutionEngine(
+            lowerToLLVM(module), opt_level=3, shared_libs=shared_libs
+        )
+        execution_engine.invoke("main", arg0_memref_ptr)
+        # CHECK: Unranked Memref
+        # CHECK-NEXT: [42]
 
 
 run(testSharedLibLoad)
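
The shared_libs list is what makes @printMemrefF32 resolvable at runtime; the library names are platform-specific, and the relative paths above assume the lit working directory inside the MLIR build tree. A hedged sketch of the same selection against an arbitrary build directory (the MLIR_BUILD_DIR variable is illustrative, not something the test uses):

import os
import sys

build_dir = os.environ.get("MLIR_BUILD_DIR", ".")
if sys.platform == "win32":
    names = ["bin/mlir_runner_utils.dll", "bin/mlir_c_runner_utils.dll"]
elif sys.platform == "darwin":
    names = ["lib/libmlir_runner_utils.dylib", "lib/libmlir_c_runner_utils.dylib"]
else:
    names = ["lib/libmlir_runner_utils.so", "lib/libmlir_c_runner_utils.so"]
shared_libs = [os.path.join(build_dir, name) for name in names]
# ExecutionEngine(lowerToLLVM(module), opt_level=3, shared_libs=shared_libs)
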
@@ -515,8 +558,9 @@ run(testSharedLibLoad)
 #  Test that nano time clock is available.
 # CHECK-LABEL: TEST: testNanoTime
 def testNanoTime():
-  with Context():
-    module = Module.parse("""
+    with Context():
+        module = Module.parse(
+            """
       module {
       func.func @main() attributes { llvm.emit_c_interface } {
         %now = call @nanoTime() : () -> i64
@@ -529,26 +573,26 @@ def testNanoTime():
       }
       func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface }
       func.func private @printMemrefI64(memref<*xi64>) attributes { llvm.emit_c_interface }
-    }""")
-
-    if sys.platform == 'win32':
-      shared_libs = [
-          "../../../../bin/mlir_runner_utils.dll",
-          "../../../../bin/mlir_c_runner_utils.dll"
-      ]
-    else:
-      shared_libs = [
-          "../../../../lib/libmlir_runner_utils.so",
-          "../../../../lib/libmlir_c_runner_utils.so"
-      ]
-
-    execution_engine = ExecutionEngine(
-        lowerToLLVM(module),
-        opt_level=3,
-        shared_libs=shared_libs)
-    execution_engine.invoke("main")
-    # CHECK: Unranked Memref
-    # CHECK: [{{.*}}]
+    }"""
+        )
+
+        if sys.platform == "win32":
+            shared_libs = [
+                "../../../../bin/mlir_runner_utils.dll",
+                "../../../../bin/mlir_c_runner_utils.dll",
+            ]
+        else:
+            shared_libs = [
+                "../../../../lib/libmlir_runner_utils.so",
+                "../../../../lib/libmlir_c_runner_utils.so",
+            ]
+
+        execution_engine = ExecutionEngine(
+            lowerToLLVM(module), opt_level=3, shared_libs=shared_libs
+        )
+        execution_engine.invoke("main")
+        # CHECK: Unranked Memref
+        # CHECK: [{{.*}}]
 
 
 run(testNanoTime)
@@ -557,36 +601,36 @@ run(testNanoTime)
 #  Test that nano time clock is available.
 # CHECK-LABEL: TEST: testDumpToObjectFile
 def testDumpToObjectFile():
-  fd, object_path = tempfile.mkstemp(suffix=".o")
+    fd, object_path = tempfile.mkstemp(suffix=".o")
 
-  try:
-    with Context():
-      module = Module.parse("""
+    try:
+        with Context():
+            module = Module.parse(
+                """
         module {
         func.func @main() attributes { llvm.emit_c_interface } {
           return
         }
-      }""")
+      }"""
+            )
 
-      execution_engine = ExecutionEngine(
-          lowerToLLVM(module),
-          opt_level=3)
+            execution_engine = ExecutionEngine(lowerToLLVM(module), opt_level=3)
 
-      # CHECK: Object file exists: True
-      print(f"Object file exists: {os.path.exists(object_path)}")
-      # CHECK: Object file is empty: True
-      print(f"Object file is empty: {os.path.getsize(object_path) == 0}")
+            # CHECK: Object file exists: True
+            print(f"Object file exists: {os.path.exists(object_path)}")
+            # CHECK: Object file is empty: True
+            print(f"Object file is empty: {os.path.getsize(object_path) == 0}")
 
-      execution_engine.dump_to_object_file(object_path)
+            execution_engine.dump_to_object_file(object_path)
 
-      # CHECK: Object file exists: True
-      print(f"Object file exists: {os.path.exists(object_path)}")
-      # CHECK: Object file is empty: False
-      print(f"Object file is empty: {os.path.getsize(object_path) == 0}")
+            # CHECK: Object file exists: True
+            print(f"Object file exists: {os.path.exists(object_path)}")
+            # CHECK: Object file is empty: False
+            print(f"Object file is empty: {os.path.getsize(object_path) == 0}")
 
-  finally:
-    os.close(fd)
-    os.remove(object_path)
+    finally:
+        os.close(fd)
+        os.remove(object_path)
 
 
 run(testDumpToObjectFile)
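
dump_to_object_file() writes the JIT-compiled code out as a relocatable object, which is useful for inspecting what the ExecutionEngine actually produced. A self-contained sketch with a trivial module (the lowering pipeline is trimmed to what an empty function needs):

import os
import tempfile

from mlir.ir import Context, Module
from mlir.passmanager import PassManager
from mlir.execution_engine import ExecutionEngine

with Context():
    module = Module.parse(
        "module { func.func @main() attributes { llvm.emit_c_interface } { return } }"
    )
    PassManager.parse(
        "builtin.module(convert-func-to-llvm,reconcile-unrealized-casts)"
    ).run(module.operation)
    engine = ExecutionEngine(module, opt_level=3)

    fd, obj_path = tempfile.mkstemp(suffix=".o")
    try:
        engine.dump_to_object_file(obj_path)
        # The file is non-empty once the object code has been written.
        print(os.path.getsize(obj_path) > 0)
    finally:
        os.close(fd)
        os.remove(obj_path)
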
index 2cba577..f6519fb 100644
@@ -15,8 +15,8 @@ from mlir.dialects.linalg.opdsl.lang import *
 # Log everything to stderr and flush so that we have a unified stream to match
 # errors/info emitted by MLIR to stderr.
 def log(*args):
-  print(*args, file=sys.stderr)
-  sys.stderr.flush()
+    print(*args, file=sys.stderr)
+    sys.stderr.flush()
 
 
 elemwise_boiler = """
@@ -186,428 +186,458 @@ func.func @main() -> i32 attributes {llvm.emit_c_interface} {
 
 
 def transform(module, boilerplate):
-  # TODO: Allow cloning functions from one module to another.
-  # Atm we have to resort to string concatenation.
-  ops = module.operation.regions[0].blocks[0].operations
-  mod = Module.parse("\n".join([str(op) for op in ops]) + boilerplate)
-
-  pm = PassManager('builtin.module')
-  pm.add("func.func(convert-linalg-to-loops)")
-  pm.add("func.func(lower-affine)")
-  pm.add("func.func(convert-math-to-llvm)")
-  pm.add("func.func(convert-scf-to-cf)")
-  pm.add("func.func(arith-expand)")
-  pm.add("func.func(memref-expand)")
-  pm.add("convert-vector-to-llvm")
-  pm.add("finalize-memref-to-llvm")
-  pm.add("convert-func-to-llvm")
-  pm.add("reconcile-unrealized-casts")
-  pm.run(mod.operation)
-  return mod
+    # TODO: Allow cloning functions from one module to another.
+    # Atm we have to resort to string concatenation.
+    ops = module.operation.regions[0].blocks[0].operations
+    mod = Module.parse("\n".join([str(op) for op in ops]) + boilerplate)
+
+    pm = PassManager("builtin.module")
+    pm.add("func.func(convert-linalg-to-loops)")
+    pm.add("func.func(lower-affine)")
+    pm.add("func.func(convert-math-to-llvm)")
+    pm.add("func.func(convert-scf-to-cf)")
+    pm.add("func.func(arith-expand)")
+    pm.add("func.func(memref-expand)")
+    pm.add("convert-vector-to-llvm")
+    pm.add("finalize-memref-to-llvm")
+    pm.add("convert-func-to-llvm")
+    pm.add("reconcile-unrealized-casts")
+    pm.run(mod.operation)
+    return mod
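
The add() sequence above can equivalently be written as a single textual pipeline handed to PassManager.parse; a sketch with the same pass names, in the same order (equivalent in effect):

from mlir.passmanager import PassManager

pm = PassManager.parse(
    "builtin.module("
    "func.func(convert-linalg-to-loops,lower-affine,convert-math-to-llvm,"
    "convert-scf-to-cf,arith-expand,memref-expand),"
    "convert-vector-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,"
    "reconcile-unrealized-casts)"
)
# pm.run(mod.operation) would then lower `mod` exactly as transform() does.
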
 
 
 def test_elemwise_builtin():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    i8 = IntegerType.get_signless(8)
-    with InsertionPoint(module.body):
-
-      @func.FuncOp.from_py_func(
-          MemRefType.get((), f32), MemRefType.get((4, 8), f32),
-          MemRefType.get((4, 8), f32))
-      def elemwise_exp_add_on_buffers(lhs, rhs, out):
-        linalg.elemwise_unary(lhs, outs=[out])
-        linalg.elemwise_binary(out, rhs, outs=[out])
-
-      @func.FuncOp.from_py_func(
-          MemRefType.get((), f32), MemRefType.get((4, 8), f32),
-          MemRefType.get((4, 8), f32))
-      def elemwise_log_mul_on_buffers(lhs, rhs, out):
-        linalg.elemwise_unary(lhs, outs=[out], fun=UnaryFn.log)
-        linalg.elemwise_binary(out, rhs, outs=[out], fun=BinaryFn.mul)
-
-    execution_engine = ExecutionEngine(transform(module, elemwise_boiler))
-
-    # TODO: FFI-based solution to allow testing and printing with python code.
-    # Prepare arguments: one result f32.
-    # Arguments must be passed as pointers.
-    c_float_p = ctypes.c_float * 1
-    res = c_float_p(-1.)
-    execution_engine.invoke("main", res)
-
-    log("RESULT: ", res[0])
-    # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
-    # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
-    # CHECK: RESULT: 4.71828
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        i8 = IntegerType.get_signless(8)
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                MemRefType.get((), f32),
+                MemRefType.get((4, 8), f32),
+                MemRefType.get((4, 8), f32),
+            )
+            def elemwise_exp_add_on_buffers(lhs, rhs, out):
+                linalg.elemwise_unary(lhs, outs=[out])
+                linalg.elemwise_binary(out, rhs, outs=[out])
+
+            @func.FuncOp.from_py_func(
+                MemRefType.get((), f32),
+                MemRefType.get((4, 8), f32),
+                MemRefType.get((4, 8), f32),
+            )
+            def elemwise_log_mul_on_buffers(lhs, rhs, out):
+                linalg.elemwise_unary(lhs, outs=[out], fun=UnaryFn.log)
+                linalg.elemwise_binary(out, rhs, outs=[out], fun=BinaryFn.mul)
+
+        execution_engine = ExecutionEngine(transform(module, elemwise_boiler))
+
+        # TODO: FFI-based solution to allow testing and printing with python code.
+        # Prepare arguments: one result f32.
+        # Arguments must be passed as pointers.
+        c_float_p = ctypes.c_float * 1
+        res = c_float_p(-1.0)
+        execution_engine.invoke("main", res)
+
+        log("RESULT: ", res[0])
+        # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
+        # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
+        # CHECK: RESULT: 4.71828
 
 
 test_elemwise_builtin()
 
 
 def test_elemwise_generic():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    i8 = IntegerType.get_signless(8)
-    with InsertionPoint(module.body):
-
-      @func.FuncOp.from_py_func(
-          MemRefType.get((), f32), MemRefType.get((4, 8), f32),
-          MemRefType.get((4, 8), f32))
-      def elemwise_exp_add_on_buffers(lhs, rhs, out):
-        linalg.elemwise_unary(lhs, outs=[out], emit_generic=True)
-        linalg.elemwise_binary(out, rhs, outs=[out], emit_generic=True)
-
-      @func.FuncOp.from_py_func(
-          MemRefType.get((), f32), MemRefType.get((4, 8), f32),
-          MemRefType.get((4, 8), f32))
-      def elemwise_log_mul_on_buffers(lhs, rhs, out):
-        linalg.elemwise_unary(
-            lhs, outs=[out], fun=UnaryFn.log, emit_generic=True)
-        linalg.elemwise_binary(
-            out, rhs, outs=[out], fun=BinaryFn.mul, emit_generic=True)
-
-    execution_engine = ExecutionEngine(transform(module, elemwise_boiler))
-
-    # TODO: FFI-based solution to allow testing and printing with python code.
-    # Prepare arguments: one result f32.
-    # Arguments must be passed as pointers.
-    c_float_p = ctypes.c_float * 1
-    res = c_float_p(-1.)
-    execution_engine.invoke("main", res)
-
-    log("RESULT: ", res[0])
-    # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
-    # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
-    # CHECK: RESULT: 4.71828
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        i8 = IntegerType.get_signless(8)
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                MemRefType.get((), f32),
+                MemRefType.get((4, 8), f32),
+                MemRefType.get((4, 8), f32),
+            )
+            def elemwise_exp_add_on_buffers(lhs, rhs, out):
+                linalg.elemwise_unary(lhs, outs=[out], emit_generic=True)
+                linalg.elemwise_binary(out, rhs, outs=[out], emit_generic=True)
+
+            @func.FuncOp.from_py_func(
+                MemRefType.get((), f32),
+                MemRefType.get((4, 8), f32),
+                MemRefType.get((4, 8), f32),
+            )
+            def elemwise_log_mul_on_buffers(lhs, rhs, out):
+                linalg.elemwise_unary(
+                    lhs, outs=[out], fun=UnaryFn.log, emit_generic=True
+                )
+                linalg.elemwise_binary(
+                    out, rhs, outs=[out], fun=BinaryFn.mul, emit_generic=True
+                )
+
+        execution_engine = ExecutionEngine(transform(module, elemwise_boiler))
+
+        # TODO: FFI-based solution to allow testing and printing with python code.
+        # Prepare arguments: one result f32.
+        # Arguments must be passed as pointers.
+        c_float_p = ctypes.c_float * 1
+        res = c_float_p(-1.0)
+        execution_engine.invoke("main", res)
+
+        log("RESULT: ", res[0])
+        # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
+        # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
+        # CHECK: RESULT: 4.71828
 
 
 test_elemwise_generic()
 
 
 def test_matmul_builtin():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    i8 = IntegerType.get_signless(8)
-    with InsertionPoint(module.body):
-
-      @func.FuncOp.from_py_func(
-          MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32),
-          MemRefType.get((4, 8), f32))
-      def matmul_signed_on_buffers(lhs, rhs, out):
-        linalg.matmul(lhs, rhs, outs=[out])
-
-      @func.FuncOp.from_py_func(
-          MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32),
-          MemRefType.get((4, 8), f32))
-      def matmul_unsigned_on_buffers(lhs, rhs, out):
-        linalg.matmul(lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned)
-
-    execution_engine = ExecutionEngine(transform(module, matmul_boiler))
-
-    # TODO: FFI-based solution to allow testing and printing with python code.
-    # Prepare arguments: one result f32.
-    # Arguments must be passed as pointers.
-    c_float_p = ctypes.c_float * 1
-    res = c_float_p(-1.)
-    execution_engine.invoke("main", res)
-
-    log("RESULT: ", res[0])
-    # matmul_signed_on_buffers: -1 * 2.0 * 16 = -32
-    # matmul_unsigned_on_buffers: (2^8-1) * 2.0 * 16 = 8160
-    # CHECK: RESULT: 8128
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        i8 = IntegerType.get_signless(8)
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                MemRefType.get((4, 16), i8),
+                MemRefType.get((16, 8), f32),
+                MemRefType.get((4, 8), f32),
+            )
+            def matmul_signed_on_buffers(lhs, rhs, out):
+                linalg.matmul(lhs, rhs, outs=[out])
+
+            @func.FuncOp.from_py_func(
+                MemRefType.get((4, 16), i8),
+                MemRefType.get((16, 8), f32),
+                MemRefType.get((4, 8), f32),
+            )
+            def matmul_unsigned_on_buffers(lhs, rhs, out):
+                linalg.matmul(lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned)
+
+        execution_engine = ExecutionEngine(transform(module, matmul_boiler))
+
+        # TODO: FFI-based solution to allow testing and printing with python code.
+        # Prepare arguments: one result f32.
+        # Arguments must be passed as pointers.
+        c_float_p = ctypes.c_float * 1
+        res = c_float_p(-1.0)
+        execution_engine.invoke("main", res)
+
+        log("RESULT: ", res[0])
+        # matmul_signed_on_buffers: -1 * 2.0 * 16 = -32
+        # matmul_unsigned_on_buffers: (2^8-1) * 2.0 * 16 = 8160
+        # CHECK: RESULT: 8128
 
 
 test_matmul_builtin()
 
 
 def test_matmul_generic():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    i8 = IntegerType.get_signless(8)
-    with InsertionPoint(module.body):
-
-      @func.FuncOp.from_py_func(
-          MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32),
-          MemRefType.get((4, 8), f32))
-      def matmul_signed_on_buffers(lhs, rhs, out):
-        linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)
-
-      @func.FuncOp.from_py_func(
-          MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32),
-          MemRefType.get((4, 8), f32))
-      def matmul_unsigned_on_buffers(lhs, rhs, out):
-        linalg.matmul(
-            lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned, emit_generic=True)
-
-    execution_engine = ExecutionEngine(transform(module, matmul_boiler))
-
-    # TODO: FFI-based solution to allow testing and printing with python code.
-    # Prepare arguments: one result f32.
-    # Arguments must be passed as pointers.
-    c_float_p = ctypes.c_float * 1
-    res = c_float_p(-1.)
-    execution_engine.invoke("main", res)
-
-    log("RESULT: ", res[0])
-    # matmul_signed_on_buffers = -1 * 2.0 * 16 = -32
-    # matmul_unsigned_on_buffers = (2^8-1) * 2.0 * 16 = 8160
-    # CHECK: RESULT: 8128
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        i8 = IntegerType.get_signless(8)
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                MemRefType.get((4, 16), i8),
+                MemRefType.get((16, 8), f32),
+                MemRefType.get((4, 8), f32),
+            )
+            def matmul_signed_on_buffers(lhs, rhs, out):
+                linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)
+
+            @func.FuncOp.from_py_func(
+                MemRefType.get((4, 16), i8),
+                MemRefType.get((16, 8), f32),
+                MemRefType.get((4, 8), f32),
+            )
+            def matmul_unsigned_on_buffers(lhs, rhs, out):
+                linalg.matmul(
+                    lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned, emit_generic=True
+                )
+
+        execution_engine = ExecutionEngine(transform(module, matmul_boiler))
+
+        # TODO: FFI-based solution to allow testing and printing with python code.
+        # Prepare arguments: one result f32.
+        # Arguments must be passed as pointers.
+        c_float_p = ctypes.c_float * 1
+        res = c_float_p(-1.0)
+        execution_engine.invoke("main", res)
+
+        log("RESULT: ", res[0])
+        # matmul_signed_on_buffers = -1 * 2.0 * 16 = -32
+        # matmul_unsigned_on_buffers = (2^8-1) * 2.0 * 16 = 8160
+        # CHECK: RESULT: 8128
 
 
 test_matmul_generic()
 
 
 def test_fill_builtin():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    i32 = IntegerType.get_signless(32)
-    with InsertionPoint(module.body):
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        i32 = IntegerType.get_signless(32)
+        with InsertionPoint(module.body):
 
-      @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
-      def fill_0d_on_buffers(value, out):
-        linalg.fill(value, outs=[out])
+            @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
+            def fill_0d_on_buffers(value, out):
+                linalg.fill(value, outs=[out])
 
-      @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
-      def fill_1d_on_buffers(value, out):
-        linalg.fill(value, outs=[out])
+            @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
+            def fill_1d_on_buffers(value, out):
+                linalg.fill(value, outs=[out])
 
-      @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
-      def fill_2d_on_buffers(value, out):
-        linalg.fill(value, outs=[out])
+            @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
+            def fill_2d_on_buffers(value, out):
+                linalg.fill(value, outs=[out])
 
-    execution_engine = ExecutionEngine(transform(module, fill_boiler))
+        execution_engine = ExecutionEngine(transform(module, fill_boiler))
 
-    # TODO: FFI-based solution to allow testing and printing with python code.
-    # Prepare arguments: one result i32.
-    # Arguments must be passed as pointers.
-    c_int_p = ctypes.c_int * 1
-    res = c_int_p(-1)
-    execution_engine.invoke("main", res)
+        # TODO: FFI-based solution to allow testing and printing with python code.
+        # Prepare arguments: one result i32.
+        # Arguments must be passed as pointers.
+        c_int_p = ctypes.c_int * 1
+        res = c_int_p(-1)
+        execution_engine.invoke("main", res)
 
-    log("RESULT: ", res[0])
-    # CHECK: RESULT: 6
+        log("RESULT: ", res[0])
+        # CHECK: RESULT: 6
 
 
 test_fill_builtin()
 
 
 def test_fill_generic():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    i32 = IntegerType.get_signless(32)
-    with InsertionPoint(module.body):
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        i32 = IntegerType.get_signless(32)
+        with InsertionPoint(module.body):
 
-      @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
-      def fill_0d_on_buffers(value, out):
-        linalg.fill(value, outs=[out], emit_generic=True)
+            @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
+            def fill_0d_on_buffers(value, out):
+                linalg.fill(value, outs=[out], emit_generic=True)
 
-      @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
-      def fill_1d_on_buffers(value, out):
-        linalg.fill(value, outs=[out], emit_generic=True)
+            @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
+            def fill_1d_on_buffers(value, out):
+                linalg.fill(value, outs=[out], emit_generic=True)
 
-      @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
-      def fill_2d_on_buffers(value, out):
-        linalg.fill(value, outs=[out], emit_generic=True)
+            @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
+            def fill_2d_on_buffers(value, out):
+                linalg.fill(value, outs=[out], emit_generic=True)
 
-    execution_engine = ExecutionEngine(transform(module, fill_boiler))
+        execution_engine = ExecutionEngine(transform(module, fill_boiler))
 
-    # TODO: FFI-based solution to allow testing and printing with python code.
-    # Prepare arguments: one result i32.
-    # Arguments must be passed as pointers.
-    c_int_p = ctypes.c_int * 1
-    res = c_int_p(-1)
-    execution_engine.invoke("main", res)
+        # TODO: FFI-based solution to allow testing and printing with python code.
+        # Prepare arguments: one result i32.
+        # Arguments must be passed as pointers.
+        c_int_p = ctypes.c_int * 1
+        res = c_int_p(-1)
+        execution_engine.invoke("main", res)
 
-    log("RESULT: ", res[0])
-    # CHECK: RESULT: 6
+        log("RESULT: ", res[0])
+        # CHECK: RESULT: 6
 
 
 test_fill_generic()
 
 
 def test_fill_rng_builtin():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f64 = F64Type.get()
-    i32 = IntegerType.get_signless(32)
-    with InsertionPoint(module.body):
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f64 = F64Type.get()
+        i32 = IntegerType.get_signless(32)
+        with InsertionPoint(module.body):
 
-      @func.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
-      def fill_rng_on_buffers(min, max, seed, out):
-        linalg.fill_rng_2d(min, max, seed, outs=[out])
+            @func.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
+            def fill_rng_on_buffers(min, max, seed, out):
+                linalg.fill_rng_2d(min, max, seed, outs=[out])
 
-    execution_engine = ExecutionEngine(transform(module, fill_rng_boiler))
+        execution_engine = ExecutionEngine(transform(module, fill_rng_boiler))
 
-    # TODO: FFI-based solution to allow testing and printing with python code.
-    # Prepare arguments: one result i32.
-    # Arguments must be passed as pointers.
-    c_int_p = ctypes.c_int * 1
-    res = c_int_p(-1)
-    execution_engine.invoke("main", res)
+        # TODO: FFI-based solution to allow testing and printing with python code.
+        # Prepare arguments: one result i32.
+        # Arguments must be passed as pointers.
+        c_int_p = ctypes.c_int * 1
+        res = c_int_p(-1)
+        execution_engine.invoke("main", res)
 
-    log("RESULT: ", res[0])
-    # CHECK: RESULT: -480
+        log("RESULT: ", res[0])
+        # CHECK: RESULT: -480
 
 
 test_fill_rng_builtin()
 
 
 def test_fill_rng_generic():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f64 = F64Type.get()
-    i32 = IntegerType.get_signless(32)
-    with InsertionPoint(module.body):
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f64 = F64Type.get()
+        i32 = IntegerType.get_signless(32)
+        with InsertionPoint(module.body):
 
-      @func.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
-      def fill_rng_on_buffers(min, max, seed, out):
-        linalg.fill_rng_2d(min, max, seed, outs=[out], emit_generic=True)
+            @func.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
+            def fill_rng_on_buffers(min, max, seed, out):
+                linalg.fill_rng_2d(min, max, seed, outs=[out], emit_generic=True)
 
-    execution_engine = ExecutionEngine(transform(module, fill_rng_boiler))
+        execution_engine = ExecutionEngine(transform(module, fill_rng_boiler))
 
-    # TODO: FFI-based solution to allow testing and printing with python code.
-    # Prepare arguments: one result i32.
-    # Arguments must be passed as pointers.
-    c_int_p = ctypes.c_int * 1
-    res = c_int_p(-1)
-    execution_engine.invoke("main", res)
+        # TODO: FFI-based solution to allow testing and printing with python code.
+        # Prepare arguments: one result i32.
+        # Arguments must be passed as pointers.
+        c_int_p = ctypes.c_int * 1
+        res = c_int_p(-1)
+        execution_engine.invoke("main", res)
 
-    log("RESULT: ", res[0])
-    # CHECK: RESULT: -480
+        log("RESULT: ", res[0])
+        # CHECK: RESULT: -480
 
 
 test_fill_rng_generic()
 
 
 def test_max_pooling_builtin():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f64 = F64Type.get()
-    i32 = IntegerType.get_signless(32)
-    with InsertionPoint(module.body):
-
-      @func.FuncOp.from_py_func(
-          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
-          MemRefType.get((1, 2, 4, 1), i32))
-      def pooling_on_buffers(input, shape, output):
-        linalg.pooling_nhwc_max(
-            input, shape, outs=[output], strides=[2, 4], dilations=[1, 2])
-
-    execution_engine = ExecutionEngine(transform(module, pooling_boiler))
-
-    # TODO: FFI-based solution to allow testing and printing with python code.
-    # Prepare arguments: one result i32.
-    # Arguments must be passed as pointers.
-    c_int_p = ctypes.c_int * 1
-    res = c_int_p(-1)
-    execution_engine.invoke("main", res)
-
-    log("RESULT: ", res[0])
-    # 77 is not selected due to the dilation 2 in the second dimension.
-    # CHECK: RESULT: 42
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f64 = F64Type.get()
+        i32 = IntegerType.get_signless(32)
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                MemRefType.get((1, 4, 16, 1), f64),
+                MemRefType.get((2, 2), f64),
+                MemRefType.get((1, 2, 4, 1), i32),
+            )
+            def pooling_on_buffers(input, shape, output):
+                linalg.pooling_nhwc_max(
+                    input, shape, outs=[output], strides=[2, 4], dilations=[1, 2]
+                )
+
+        execution_engine = ExecutionEngine(transform(module, pooling_boiler))
+
+        # TODO: FFI-based solution to allow testing and printing with python code.
+        # Prepare arguments: one result i32.
+        # Arguments must be passed as pointers.
+        c_int_p = ctypes.c_int * 1
+        res = c_int_p(-1)
+        execution_engine.invoke("main", res)
+
+        log("RESULT: ", res[0])
+        # 77 is not selected due to the dilation 2 in the second dimension.
+        # CHECK: RESULT: 42
 
 
 test_max_pooling_builtin()
 
 
 def test_max_pooling_generic():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f64 = F64Type.get()
-    i32 = IntegerType.get_signless(32)
-    with InsertionPoint(module.body):
-
-      @func.FuncOp.from_py_func(
-          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
-          MemRefType.get((1, 2, 4, 1), i32))
-      def pooling_on_buffers(input, shape, output):
-        linalg.pooling_nhwc_max(
-            input,
-            shape,
-            outs=[output],
-            strides=[2, 4],
-            dilations=[1, 2],
-            emit_generic=True)
-
-    execution_engine = ExecutionEngine(transform(module, pooling_boiler))
-
-    # TODO: FFI-based solution to allow testing and printing with python code.
-    # Prepare arguments: one result i32.
-    # Arguments must be passed as pointers.
-    c_int_p = ctypes.c_int * 1
-    res = c_int_p(-1)
-    execution_engine.invoke("main", res)
-
-    log("RESULT: ", res[0])
-    # 77 is not selected due to the dilation 2 in the second dimension.
-    # CHECK: RESULT: 42
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f64 = F64Type.get()
+        i32 = IntegerType.get_signless(32)
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                MemRefType.get((1, 4, 16, 1), f64),
+                MemRefType.get((2, 2), f64),
+                MemRefType.get((1, 2, 4, 1), i32),
+            )
+            def pooling_on_buffers(input, shape, output):
+                linalg.pooling_nhwc_max(
+                    input,
+                    shape,
+                    outs=[output],
+                    strides=[2, 4],
+                    dilations=[1, 2],
+                    emit_generic=True,
+                )
+
+        execution_engine = ExecutionEngine(transform(module, pooling_boiler))
+
+        # TODO: FFI-based solution to allow testing and printing with python code.
+        # Prepare arguments: one result i32.
+        # Arguments must be passed as pointers.
+        c_int_p = ctypes.c_int * 1
+        res = c_int_p(-1)
+        execution_engine.invoke("main", res)
+
+        log("RESULT: ", res[0])
+        # 77 is not selected due to the dilation 2 in the second dimension.
+        # CHECK: RESULT: 42
 
 
 test_max_pooling_generic()
 
 
 def test_min_pooling_builtin():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f64 = F64Type.get()
-    i32 = IntegerType.get_signless(32)
-    with InsertionPoint(module.body):
-
-      @func.FuncOp.from_py_func(
-          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
-          MemRefType.get((1, 2, 4, 1), i32))
-      # Set the strides and use the default dilations.
-      def pooling_on_buffers(input, shape, output):
-        linalg.pooling_nhwc_min(input, shape, outs=[output], strides=[2, 4])
-
-    execution_engine = ExecutionEngine(transform(module, pooling_boiler))
-
-    # TODO: FFI-based solution to allow testing and printing with python code.
-    # Prepare arguments: one result i32.
-    # Arguments must be passed as pointers.
-    c_int_p = ctypes.c_int * 1
-    res = c_int_p(-1)
-    execution_engine.invoke("main", res)
-
-    log("RESULT: ", res[0])
-    # CHECK: RESULT: -13
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f64 = F64Type.get()
+        i32 = IntegerType.get_signless(32)
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                MemRefType.get((1, 4, 16, 1), f64),
+                MemRefType.get((2, 2), f64),
+                MemRefType.get((1, 2, 4, 1), i32),
+            )
+            # Set the strides and use the default dilations.
+            def pooling_on_buffers(input, shape, output):
+                linalg.pooling_nhwc_min(input, shape, outs=[output], strides=[2, 4])
+
+        execution_engine = ExecutionEngine(transform(module, pooling_boiler))
+
+        # TODO: FFI-based solution to allow testing and printing with python code.
+        # Prepare arguments: one result i32.
+        # Arguments must be passed as pointers.
+        c_int_p = ctypes.c_int * 1
+        res = c_int_p(-1)
+        execution_engine.invoke("main", res)
+
+        log("RESULT: ", res[0])
+        # CHECK: RESULT: -13
 
 
 test_min_pooling_builtin()
 
 
 def test_min_pooling_generic():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f64 = F64Type.get()
-    i32 = IntegerType.get_signless(32)
-    with InsertionPoint(module.body):
-
-      @func.FuncOp.from_py_func(
-          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
-          MemRefType.get((1, 2, 4, 1), i32))
-      # Set the strides and use the default dilations.
-      def pooling_on_buffers(input, shape, output):
-        linalg.pooling_nhwc_min(
-            input, shape, outs=[output], strides=[2, 4], emit_generic=True)
-
-    execution_engine = ExecutionEngine(transform(module, pooling_boiler))
-
-    # TODO: FFI-based solution to allow testing and printing with python code.
-    # Prepare arguments: one result i32.
-    # Arguments must be passed as pointers.
-    c_int_p = ctypes.c_int * 1
-    res = c_int_p(-1)
-    execution_engine.invoke("main", res)
-
-    log("RESULT: ", res[0])
-    # CHECK: RESULT: -13
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f64 = F64Type.get()
+        i32 = IntegerType.get_signless(32)
+        with InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(
+                MemRefType.get((1, 4, 16, 1), f64),
+                MemRefType.get((2, 2), f64),
+                MemRefType.get((1, 2, 4, 1), i32),
+            )
+            # Set the strides and use the default dilations.
+            def pooling_on_buffers(input, shape, output):
+                linalg.pooling_nhwc_min(
+                    input, shape, outs=[output], strides=[2, 4], emit_generic=True
+                )
+
+        execution_engine = ExecutionEngine(transform(module, pooling_boiler))
+
+        # TODO: FFI-based solution to allow testing and printing with python code.
+        # Prepare arguments: one result i32.
+        # Arguments must be passed as pointers.
+        c_int_p = ctypes.c_int * 1
+        res = c_int_p(-1)
+        execution_engine.invoke("main", res)
+
+        log("RESULT: ", res[0])
+        # CHECK: RESULT: -13
 
 
 test_min_pooling_generic()
index 6a3a6fc..6356430 100644
@@ -3,59 +3,61 @@
 import gc
 from mlir.ir import *
 
+
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  gc.collect()
-  assert Context._get_live_count() == 0
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
+    return f
 
 
 # CHECK-LABEL: TEST: testAffineExprCapsule
 @run
 def testAffineExprCapsule():
-  with Context() as ctx:
-    affine_expr = AffineExpr.get_constant(42)
+    with Context() as ctx:
+        affine_expr = AffineExpr.get_constant(42)
 
-  affine_expr_capsule = affine_expr._CAPIPtr
-  # CHECK: capsule object
-  # CHECK: mlir.ir.AffineExpr._CAPIPtr
-  print(affine_expr_capsule)
+    affine_expr_capsule = affine_expr._CAPIPtr
+    # CHECK: capsule object
+    # CHECK: mlir.ir.AffineExpr._CAPIPtr
+    print(affine_expr_capsule)
 
-  affine_expr_2 = AffineExpr._CAPICreate(affine_expr_capsule)
-  assert affine_expr == affine_expr_2
-  assert affine_expr_2.context == ctx
+    affine_expr_2 = AffineExpr._CAPICreate(affine_expr_capsule)
+    assert affine_expr == affine_expr_2
+    assert affine_expr_2.context == ctx
 
 
 # CHECK-LABEL: TEST: testAffineExprEq
 @run
 def testAffineExprEq():
-  with Context():
-    a1 = AffineExpr.get_constant(42)
-    a2 = AffineExpr.get_constant(42)
-    a3 = AffineExpr.get_constant(43)
-    # CHECK: True
-    print(a1 == a1)
-    # CHECK: True
-    print(a1 == a2)
-    # CHECK: False
-    print(a1 == a3)
-    # CHECK: False
-    print(a1 == None)
-    # CHECK: False
-    print(a1 == "foo")
+    with Context():
+        a1 = AffineExpr.get_constant(42)
+        a2 = AffineExpr.get_constant(42)
+        a3 = AffineExpr.get_constant(43)
+        # CHECK: True
+        print(a1 == a1)
+        # CHECK: True
+        print(a1 == a2)
+        # CHECK: False
+        print(a1 == a3)
+        # CHECK: False
+        print(a1 == None)
+        # CHECK: False
+        print(a1 == "foo")
 
 
 # CHECK-LABEL: TEST: testAffineExprContext
 @run
 def testAffineExprContext():
-  with Context():
-    a1 = AffineExpr.get_constant(42)
-  with Context():
-    a2 = AffineExpr.get_constant(42)
+    with Context():
+        a1 = AffineExpr.get_constant(42)
+    with Context():
+        a2 = AffineExpr.get_constant(42)
+
+    # CHECK: False
+    print(a1 == a2)
 
-  # CHECK: False
-  print(a1 == a2)
 
 run(testAffineExprContext)
 
@@ -63,340 +65,343 @@ run(testAffineExprContext)
 # CHECK-LABEL: TEST: testAffineExprConstant
 @run
 def testAffineExprConstant():
-  with Context():
-    a1 = AffineExpr.get_constant(42)
-    # CHECK: 42
-    print(a1.value)
-    # CHECK: 42
-    print(a1)
+    with Context():
+        a1 = AffineExpr.get_constant(42)
+        # CHECK: 42
+        print(a1.value)
+        # CHECK: 42
+        print(a1)
 
-    a2 = AffineConstantExpr.get(42)
-    # CHECK: 42
-    print(a2.value)
-    # CHECK: 42
-    print(a2)
+        a2 = AffineConstantExpr.get(42)
+        # CHECK: 42
+        print(a2.value)
+        # CHECK: 42
+        print(a2)
 
-    assert a1 == a2
+        assert a1 == a2
 
 
 # CHECK-LABEL: TEST: testAffineExprDim
 @run
 def testAffineExprDim():
-  with Context():
-    d1 = AffineExpr.get_dim(1)
-    d11 = AffineDimExpr.get(1)
-    d2 = AffineDimExpr.get(2)
+    with Context():
+        d1 = AffineExpr.get_dim(1)
+        d11 = AffineDimExpr.get(1)
+        d2 = AffineDimExpr.get(2)
 
-    # CHECK: 1
-    print(d1.position)
-    # CHECK: d1
-    print(d1)
+        # CHECK: 1
+        print(d1.position)
+        # CHECK: d1
+        print(d1)
 
-    # CHECK: 2
-    print(d2.position)
-    # CHECK: d2
-    print(d2)
+        # CHECK: 2
+        print(d2.position)
+        # CHECK: d2
+        print(d2)
 
-    assert d1 == d11
-    assert d1 != d2
+        assert d1 == d11
+        assert d1 != d2
 
 
 # CHECK-LABEL: TEST: testAffineExprSymbol
 @run
 def testAffineExprSymbol():
-  with Context():
-    s1 = AffineExpr.get_symbol(1)
-    s11 = AffineSymbolExpr.get(1)
-    s2 = AffineSymbolExpr.get(2)
+    with Context():
+        s1 = AffineExpr.get_symbol(1)
+        s11 = AffineSymbolExpr.get(1)
+        s2 = AffineSymbolExpr.get(2)
 
-    # CHECK: 1
-    print(s1.position)
-    # CHECK: s1
-    print(s1)
+        # CHECK: 1
+        print(s1.position)
+        # CHECK: s1
+        print(s1)
 
-    # CHECK: 2
-    print(s2.position)
-    # CHECK: s2
-    print(s2)
+        # CHECK: 2
+        print(s2.position)
+        # CHECK: s2
+        print(s2)
 
-    assert s1 == s11
-    assert s1 != s2
+        assert s1 == s11
+        assert s1 != s2
 
 
 # CHECK-LABEL: TEST: testAffineAddExpr
 @run
 def testAffineAddExpr():
-  with Context():
-    d1 = AffineDimExpr.get(1)
-    d2 = AffineDimExpr.get(2)
-    d12 = AffineExpr.get_add(d1, d2)
-    # CHECK: d1 + d2
-    print(d12)
+    with Context():
+        d1 = AffineDimExpr.get(1)
+        d2 = AffineDimExpr.get(2)
+        d12 = AffineExpr.get_add(d1, d2)
+        # CHECK: d1 + d2
+        print(d12)
 
-    d12op = d1 + d2
-    # CHECK: d1 + d2
-    print(d12op)
+        d12op = d1 + d2
+        # CHECK: d1 + d2
+        print(d12op)
 
-    d1cst_op = d1 + 2
-    # CHECK: d1 + 2
-    print(d1cst_op)
+        d1cst_op = d1 + 2
+        # CHECK: d1 + 2
+        print(d1cst_op)
 
-    d1cst_op2 = 2 + d1
-    # CHECK: d1 + 2
-    print(d1cst_op2)
+        d1cst_op2 = 2 + d1
+        # CHECK: d1 + 2
+        print(d1cst_op2)
 
-    assert d12 == d12op
-    assert d12.lhs == d1
-    assert d12.rhs == d2
+        assert d12 == d12op
+        assert d12.lhs == d1
+        assert d12.rhs == d2
 
 
 # CHECK-LABEL: TEST: testAffineMulExpr
 @run
 def testAffineMulExpr():
-  with Context():
-    d1 = AffineDimExpr.get(1)
-    c2 = AffineConstantExpr.get(2)
-    expr = AffineExpr.get_mul(d1, c2)
-    # CHECK: d1 * 2
-    print(expr)
+    with Context():
+        d1 = AffineDimExpr.get(1)
+        c2 = AffineConstantExpr.get(2)
+        expr = AffineExpr.get_mul(d1, c2)
+        # CHECK: d1 * 2
+        print(expr)
 
-    # CHECK: d1 * 2
-    op = d1 * c2
-    print(op)
+        # CHECK: d1 * 2
+        op = d1 * c2
+        print(op)
 
-    # CHECK: d1 * 2
-    op_cst = d1 * 2
-    print(op_cst)
+        # CHECK: d1 * 2
+        op_cst = d1 * 2
+        print(op_cst)
 
-    # CHECK: d1 * 2
-    op_cst2 = 2 * d1
-    print(op_cst2)
+        # CHECK: d1 * 2
+        op_cst2 = 2 * d1
+        print(op_cst2)
 
-    assert expr == op
-    assert expr == op_cst
-    assert expr.lhs == d1
-    assert expr.rhs == c2
+        assert expr == op
+        assert expr == op_cst
+        assert expr.lhs == d1
+        assert expr.rhs == c2
 
 
 # CHECK-LABEL: TEST: testAffineModExpr
 @run
 def testAffineModExpr():
-  with Context():
-    d1 = AffineDimExpr.get(1)
-    c2 = AffineConstantExpr.get(2)
-    expr = AffineExpr.get_mod(d1, c2)
-    # CHECK: d1 mod 2
-    print(expr)
+    with Context():
+        d1 = AffineDimExpr.get(1)
+        c2 = AffineConstantExpr.get(2)
+        expr = AffineExpr.get_mod(d1, c2)
+        # CHECK: d1 mod 2
+        print(expr)
 
-    # CHECK: d1 mod 2
-    op = d1 % c2
-    print(op)
+        # CHECK: d1 mod 2
+        op = d1 % c2
+        print(op)
 
-    # CHECK: d1 mod 2
-    op_cst = d1 % 2
-    print(op_cst)
+        # CHECK: d1 mod 2
+        op_cst = d1 % 2
+        print(op_cst)
 
-    # CHECK: 2 mod d1
-    print(2 % d1)
+        # CHECK: 2 mod d1
+        print(2 % d1)
 
-    assert expr == op
-    assert expr == op_cst
-    assert expr.lhs == d1
-    assert expr.rhs == c2
+        assert expr == op
+        assert expr == op_cst
+        assert expr.lhs == d1
+        assert expr.rhs == c2
 
-    expr2 = AffineExpr.get_mod(c2, d1)
-    expr3 = AffineExpr.get_mod(2, d1)
-    expr4 = AffineExpr.get_mod(d1, 2)
+        expr2 = AffineExpr.get_mod(c2, d1)
+        expr3 = AffineExpr.get_mod(2, d1)
+        expr4 = AffineExpr.get_mod(d1, 2)
 
-    # CHECK: 2 mod d1
-    print(expr2)
-    # CHECK: 2 mod d1
-    print(expr3)
-    # CHECK: d1 mod 2
-    print(expr4)
+        # CHECK: 2 mod d1
+        print(expr2)
+        # CHECK: 2 mod d1
+        print(expr3)
+        # CHECK: d1 mod 2
+        print(expr4)
 
-    assert expr2 == expr3
-    assert expr4 == expr
+        assert expr2 == expr3
+        assert expr4 == expr
 
 
 # CHECK-LABEL: TEST: testAffineFloorDivExpr
 @run
 def testAffineFloorDivExpr():
-  with Context():
-    d1 = AffineDimExpr.get(1)
-    c2 = AffineConstantExpr.get(2)
-    expr = AffineExpr.get_floor_div(d1, c2)
-    # CHECK: d1 floordiv 2
-    print(expr)
+    with Context():
+        d1 = AffineDimExpr.get(1)
+        c2 = AffineConstantExpr.get(2)
+        expr = AffineExpr.get_floor_div(d1, c2)
+        # CHECK: d1 floordiv 2
+        print(expr)
 
-    assert expr.lhs == d1
-    assert expr.rhs == c2
+        assert expr.lhs == d1
+        assert expr.rhs == c2
 
-    expr2 = AffineExpr.get_floor_div(c2, d1)
-    expr3 = AffineExpr.get_floor_div(2, d1)
-    expr4 = AffineExpr.get_floor_div(d1, 2)
+        expr2 = AffineExpr.get_floor_div(c2, d1)
+        expr3 = AffineExpr.get_floor_div(2, d1)
+        expr4 = AffineExpr.get_floor_div(d1, 2)
 
-    # CHECK: 2 floordiv d1
-    print(expr2)
-    # CHECK: 2 floordiv d1
-    print(expr3)
-    # CHECK: d1 floordiv 2
-    print(expr4)
+        # CHECK: 2 floordiv d1
+        print(expr2)
+        # CHECK: 2 floordiv d1
+        print(expr3)
+        # CHECK: d1 floordiv 2
+        print(expr4)
 
-    assert expr2 == expr3
-    assert expr4 == expr
+        assert expr2 == expr3
+        assert expr4 == expr
 
 
 # CHECK-LABEL: TEST: testAffineCeilDivExpr
 @run
 def testAffineCeilDivExpr():
-  with Context():
-    d1 = AffineDimExpr.get(1)
-    c2 = AffineConstantExpr.get(2)
-    expr = AffineExpr.get_ceil_div(d1, c2)
-    # CHECK: d1 ceildiv 2
-    print(expr)
+    with Context():
+        d1 = AffineDimExpr.get(1)
+        c2 = AffineConstantExpr.get(2)
+        expr = AffineExpr.get_ceil_div(d1, c2)
+        # CHECK: d1 ceildiv 2
+        print(expr)
 
-    assert expr.lhs == d1
-    assert expr.rhs == c2
+        assert expr.lhs == d1
+        assert expr.rhs == c2
 
-    expr2 = AffineExpr.get_ceil_div(c2, d1)
-    expr3 = AffineExpr.get_ceil_div(2, d1)
-    expr4 = AffineExpr.get_ceil_div(d1, 2)
+        expr2 = AffineExpr.get_ceil_div(c2, d1)
+        expr3 = AffineExpr.get_ceil_div(2, d1)
+        expr4 = AffineExpr.get_ceil_div(d1, 2)
 
-    # CHECK: 2 ceildiv d1
-    print(expr2)
-    # CHECK: 2 ceildiv d1
-    print(expr3)
-    # CHECK: d1 ceildiv 2
-    print(expr4)
+        # CHECK: 2 ceildiv d1
+        print(expr2)
+        # CHECK: 2 ceildiv d1
+        print(expr3)
+        # CHECK: d1 ceildiv 2
+        print(expr4)
 
-    assert expr2 == expr3
-    assert expr4 == expr
+        assert expr2 == expr3
+        assert expr4 == expr
 
 
 # CHECK-LABEL: TEST: testAffineExprSub
 @run
 def testAffineExprSub():
-  with Context():
-    d1 = AffineDimExpr.get(1)
-    d2 = AffineDimExpr.get(2)
-    expr = d1 - d2
-    # CHECK: d1 - d2
-    print(expr)
-
-    assert expr.lhs == d1
-    rhs = AffineMulExpr(expr.rhs)
-    # CHECK: d2
-    print(rhs.lhs)
-    # CHECK: -1
-    print(rhs.rhs)
-
-    # CHECK: d1 - 42
-    print(d1 - 42)
-    # CHECK: -d1 + 42
-    print(42 - d1)
-
-    c42 = AffineConstantExpr.get(42)
-    assert d1 - 42 == d1 - c42
-    assert 42 - d1 == c42 - d1
+    with Context():
+        d1 = AffineDimExpr.get(1)
+        d2 = AffineDimExpr.get(2)
+        expr = d1 - d2
+        # CHECK: d1 - d2
+        print(expr)
+
+        assert expr.lhs == d1
+        rhs = AffineMulExpr(expr.rhs)
+        # CHECK: d2
+        print(rhs.lhs)
+        # CHECK: -1
+        print(rhs.rhs)
+
+        # CHECK: d1 - 42
+        print(d1 - 42)
+        # CHECK: -d1 + 42
+        print(42 - d1)
+
+        c42 = AffineConstantExpr.get(42)
+        assert d1 - 42 == d1 - c42
+        assert 42 - d1 == c42 - d1
+
 
 # CHECK-LABEL: TEST: testClassHierarchy
 @run
 def testClassHierarchy():
-  with Context():
-    d1 = AffineDimExpr.get(1)
-    c2 = AffineConstantExpr.get(2)
-    add = AffineAddExpr.get(d1, c2)
-    mul = AffineMulExpr.get(d1, c2)
-    mod = AffineModExpr.get(d1, c2)
-    floor_div = AffineFloorDivExpr.get(d1, c2)
-    ceil_div = AffineCeilDivExpr.get(d1, c2)
+    with Context():
+        d1 = AffineDimExpr.get(1)
+        c2 = AffineConstantExpr.get(2)
+        add = AffineAddExpr.get(d1, c2)
+        mul = AffineMulExpr.get(d1, c2)
+        mod = AffineModExpr.get(d1, c2)
+        floor_div = AffineFloorDivExpr.get(d1, c2)
+        ceil_div = AffineCeilDivExpr.get(d1, c2)
+
+        # CHECK: False
+        print(isinstance(d1, AffineBinaryExpr))
+        # CHECK: False
+        print(isinstance(c2, AffineBinaryExpr))
+        # CHECK: True
+        print(isinstance(add, AffineBinaryExpr))
+        # CHECK: True
+        print(isinstance(mul, AffineBinaryExpr))
+        # CHECK: True
+        print(isinstance(mod, AffineBinaryExpr))
+        # CHECK: True
+        print(isinstance(floor_div, AffineBinaryExpr))
+        # CHECK: True
+        print(isinstance(ceil_div, AffineBinaryExpr))
+
+        try:
+            AffineBinaryExpr(d1)
+        except ValueError as e:
+            # CHECK: Cannot cast affine expression to AffineBinaryExpr
+            print(e)
+
+        try:
+            AffineBinaryExpr(c2)
+        except ValueError as e:
+            # CHECK: Cannot cast affine expression to AffineBinaryExpr
+            print(e)
 
-    # CHECK: False
-    print(isinstance(d1, AffineBinaryExpr))
-    # CHECK: False
-    print(isinstance(c2, AffineBinaryExpr))
-    # CHECK: True
-    print(isinstance(add, AffineBinaryExpr))
-    # CHECK: True
-    print(isinstance(mul, AffineBinaryExpr))
-    # CHECK: True
-    print(isinstance(mod, AffineBinaryExpr))
-    # CHECK: True
-    print(isinstance(floor_div, AffineBinaryExpr))
-    # CHECK: True
-    print(isinstance(ceil_div, AffineBinaryExpr))
-
-    try:
-      AffineBinaryExpr(d1)
-    except ValueError as e:
-      # CHECK: Cannot cast affine expression to AffineBinaryExpr
-      print(e)
-
-    try:
-      AffineBinaryExpr(c2)
-    except ValueError as e:
-      # CHECK: Cannot cast affine expression to AffineBinaryExpr
-      print(e)
 
 # CHECK-LABEL: TEST: testIsInstance
 @run
 def testIsInstance():
-  with Context():
-    d1 = AffineDimExpr.get(1)
-    c2 = AffineConstantExpr.get(2)
-    add = AffineAddExpr.get(d1, c2)
-    mul = AffineMulExpr.get(d1, c2)
-
-    # CHECK: True
-    print(AffineDimExpr.isinstance(d1))
-    # CHECK: False
-    print(AffineConstantExpr.isinstance(d1))
-    # CHECK: True
-    print(AffineConstantExpr.isinstance(c2))
-    # CHECK: False
-    print(AffineMulExpr.isinstance(c2))
-    # CHECK: True
-    print(AffineAddExpr.isinstance(add))
-    # CHECK: False
-    print(AffineMulExpr.isinstance(add))
-    # CHECK: True
-    print(AffineMulExpr.isinstance(mul))
-    # CHECK: False
-    print(AffineAddExpr.isinstance(mul))
+    with Context():
+        d1 = AffineDimExpr.get(1)
+        c2 = AffineConstantExpr.get(2)
+        add = AffineAddExpr.get(d1, c2)
+        mul = AffineMulExpr.get(d1, c2)
+
+        # CHECK: True
+        print(AffineDimExpr.isinstance(d1))
+        # CHECK: False
+        print(AffineConstantExpr.isinstance(d1))
+        # CHECK: True
+        print(AffineConstantExpr.isinstance(c2))
+        # CHECK: False
+        print(AffineMulExpr.isinstance(c2))
+        # CHECK: True
+        print(AffineAddExpr.isinstance(add))
+        # CHECK: False
+        print(AffineMulExpr.isinstance(add))
+        # CHECK: True
+        print(AffineMulExpr.isinstance(mul))
+        # CHECK: False
+        print(AffineAddExpr.isinstance(mul))
 
 
 # CHECK-LABEL: TEST: testCompose
 @run
 def testCompose():
-  with Context():
-    # d0 + d2.
-    expr = AffineAddExpr.get(AffineDimExpr.get(0), AffineDimExpr.get(2))
+    with Context():
+        # d0 + d2.
+        expr = AffineAddExpr.get(AffineDimExpr.get(0), AffineDimExpr.get(2))
 
-    # (d0, d1, d2)[s0, s1] -> (d0 + s1, d1 + s0, d0 + d1 + d2)
-    map1 = AffineAddExpr.get(AffineDimExpr.get(0), AffineSymbolExpr.get(1))
-    map2 = AffineAddExpr.get(AffineDimExpr.get(1), AffineSymbolExpr.get(0))
-    map3 = AffineAddExpr.get(
-        AffineAddExpr.get(AffineDimExpr.get(0), AffineDimExpr.get(1)),
-        AffineDimExpr.get(2))
-    map = AffineMap.get(3, 2, [map1, map2, map3])
+        # (d0, d1, d2)[s0, s1] -> (d0 + s1, d1 + s0, d0 + d1 + d2)
+        map1 = AffineAddExpr.get(AffineDimExpr.get(0), AffineSymbolExpr.get(1))
+        map2 = AffineAddExpr.get(AffineDimExpr.get(1), AffineSymbolExpr.get(0))
+        map3 = AffineAddExpr.get(
+            AffineAddExpr.get(AffineDimExpr.get(0), AffineDimExpr.get(1)),
+            AffineDimExpr.get(2),
+        )
+        map = AffineMap.get(3, 2, [map1, map2, map3])
 
-    # CHECK: d0 + s1 + d0 + d1 + d2
-    print(expr.compose(map))
+        # CHECK: d0 + s1 + d0 + d1 + d2
+        print(expr.compose(map))
 
 
 # CHECK-LABEL: TEST: testHash
 @run
 def testHash():
-  with Context():
-    d0 = AffineDimExpr.get(0)
-    s1 = AffineSymbolExpr.get(1)
-    assert hash(d0) == hash(AffineDimExpr.get(0))
-    assert hash(d0 + s1) == hash(AffineAddExpr.get(d0, s1))
-
-    dictionary = dict()
-    dictionary[d0] = 0
-    dictionary[s1] = 1
-    assert d0 in dictionary
-    assert s1 in dictionary
+    with Context():
+        d0 = AffineDimExpr.get(0)
+        s1 = AffineSymbolExpr.get(1)
+        assert hash(d0) == hash(AffineDimExpr.get(0))
+        assert hash(d0 + s1) == hash(AffineAddExpr.get(d0, s1))
+
+        dictionary = dict()
+        dictionary[d0] = 0
+        dictionary[s1] = 1
+        assert d0 in dictionary
+        assert s1 in dictionary
index 52c7261..672335e 100644
@@ -5,237 +5,241 @@ from mlir.ir import *
 
 
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  gc.collect()
-  assert Context._get_live_count() == 0
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
+    return f
 
 
 # CHECK-LABEL: TEST: testAffineMapCapsule
 @run
 def testAffineMapCapsule():
-  with Context() as ctx:
-    am1 = AffineMap.get_empty(ctx)
-  # CHECK: mlir.ir.AffineMap._CAPIPtr
-  affine_map_capsule = am1._CAPIPtr
-  print(affine_map_capsule)
-  am2 = AffineMap._CAPICreate(affine_map_capsule)
-  assert am2 == am1
-  assert am2.context is ctx
+    with Context() as ctx:
+        am1 = AffineMap.get_empty(ctx)
+    # CHECK: mlir.ir.AffineMap._CAPIPtr
+    affine_map_capsule = am1._CAPIPtr
+    print(affine_map_capsule)
+    am2 = AffineMap._CAPICreate(affine_map_capsule)
+    assert am2 == am1
+    assert am2.context is ctx
 
 
 # CHECK-LABEL: TEST: testAffineMapGet
 @run
 def testAffineMapGet():
-  with Context() as ctx:
-    d0 = AffineDimExpr.get(0)
-    d1 = AffineDimExpr.get(1)
-    c2 = AffineConstantExpr.get(2)
-
-    # CHECK: (d0, d1)[s0, s1, s2] -> ()
-    map0 = AffineMap.get(2, 3, [])
-    print(map0)
-
-    # CHECK: (d0, d1)[s0, s1, s2] -> (d1, 2)
-    map1 = AffineMap.get(2, 3, [d1, c2])
-    print(map1)
-
-    # CHECK: () -> (2)
-    map2 = AffineMap.get(0, 0, [c2])
-    print(map2)
-
-    # CHECK: (d0, d1) -> (d0, d1)
-    map3 = AffineMap.get(2, 0, [d0, d1])
-    print(map3)
-
-    # CHECK: (d0, d1) -> (d1)
-    map4 = AffineMap.get(2, 0, [d1])
-    print(map4)
-
-    # CHECK: (d0, d1, d2) -> (d2, d0, d1)
-    map5 = AffineMap.get_permutation([2, 0, 1])
-    print(map5)
-
-    assert map1 == AffineMap.get(2, 3, [d1, c2])
-    assert AffineMap.get(0, 0, []) == AffineMap.get_empty()
-    assert map2 == AffineMap.get_constant(2)
-    assert map3 == AffineMap.get_identity(2)
-    assert map4 == AffineMap.get_minor_identity(2, 1)
-
-    try:
-      AffineMap.get(1, 1, [1])
-    except RuntimeError as e:
-      # CHECK: Invalid expression when attempting to create an AffineMap
-      print(e)
-
-    try:
-      AffineMap.get(1, 1, [None])
-    except RuntimeError as e:
-      # CHECK: Invalid expression (None?) when attempting to create an AffineMap
-      print(e)
-
-    try:
-      AffineMap.get_permutation([1, 0, 1])
-    except RuntimeError as e:
-      # CHECK: Invalid permutation when attempting to create an AffineMap
-      print(e)
-
-    try:
-      map3.get_submap([42])
-    except ValueError as e:
-      # CHECK: result position out of bounds
-      print(e)
-
-    try:
-      map3.get_minor_submap(42)
-    except ValueError as e:
-      # CHECK: number of results out of bounds
-      print(e)
-
-    try:
-      map3.get_major_submap(42)
-    except ValueError as e:
-      # CHECK: number of results out of bounds
-      print(e)
+    with Context() as ctx:
+        d0 = AffineDimExpr.get(0)
+        d1 = AffineDimExpr.get(1)
+        c2 = AffineConstantExpr.get(2)
+
+        # CHECK: (d0, d1)[s0, s1, s2] -> ()
+        map0 = AffineMap.get(2, 3, [])
+        print(map0)
+
+        # CHECK: (d0, d1)[s0, s1, s2] -> (d1, 2)
+        map1 = AffineMap.get(2, 3, [d1, c2])
+        print(map1)
+
+        # CHECK: () -> (2)
+        map2 = AffineMap.get(0, 0, [c2])
+        print(map2)
+
+        # CHECK: (d0, d1) -> (d0, d1)
+        map3 = AffineMap.get(2, 0, [d0, d1])
+        print(map3)
+
+        # CHECK: (d0, d1) -> (d1)
+        map4 = AffineMap.get(2, 0, [d1])
+        print(map4)
+
+        # CHECK: (d0, d1, d2) -> (d2, d0, d1)
+        map5 = AffineMap.get_permutation([2, 0, 1])
+        print(map5)
+
+        assert map1 == AffineMap.get(2, 3, [d1, c2])
+        assert AffineMap.get(0, 0, []) == AffineMap.get_empty()
+        assert map2 == AffineMap.get_constant(2)
+        assert map3 == AffineMap.get_identity(2)
+        assert map4 == AffineMap.get_minor_identity(2, 1)
+
+        try:
+            AffineMap.get(1, 1, [1])
+        except RuntimeError as e:
+            # CHECK: Invalid expression when attempting to create an AffineMap
+            print(e)
+
+        try:
+            AffineMap.get(1, 1, [None])
+        except RuntimeError as e:
+            # CHECK: Invalid expression (None?) when attempting to create an AffineMap
+            print(e)
+
+        try:
+            AffineMap.get_permutation([1, 0, 1])
+        except RuntimeError as e:
+            # CHECK: Invalid permutation when attempting to create an AffineMap
+            print(e)
+
+        try:
+            map3.get_submap([42])
+        except ValueError as e:
+            # CHECK: result position out of bounds
+            print(e)
+
+        try:
+            map3.get_minor_submap(42)
+        except ValueError as e:
+            # CHECK: number of results out of bounds
+            print(e)
+
+        try:
+            map3.get_major_submap(42)
+        except ValueError as e:
+            # CHECK: number of results out of bounds
+            print(e)
 
 
 # CHECK-LABEL: TEST: testAffineMapDerive
 @run
 def testAffineMapDerive():
-  with Context() as ctx:
-    map5 = AffineMap.get_identity(5)
+    with Context() as ctx:
+        map5 = AffineMap.get_identity(5)
 
-    # CHECK: (d0, d1, d2, d3, d4) -> (d1, d2, d3)
-    map123 = map5.get_submap([1, 2, 3])
-    print(map123)
+        # CHECK: (d0, d1, d2, d3, d4) -> (d1, d2, d3)
+        map123 = map5.get_submap([1, 2, 3])
+        print(map123)
 
-    # CHECK: (d0, d1, d2, d3, d4) -> (d0, d1)
-    map01 = map5.get_major_submap(2)
-    print(map01)
+        # CHECK: (d0, d1, d2, d3, d4) -> (d0, d1)
+        map01 = map5.get_major_submap(2)
+        print(map01)
 
-    # CHECK: (d0, d1, d2, d3, d4) -> (d3, d4)
-    map34 = map5.get_minor_submap(2)
-    print(map34)
+        # CHECK: (d0, d1, d2, d3, d4) -> (d3, d4)
+        map34 = map5.get_minor_submap(2)
+        print(map34)
 
 
 # CHECK-LABEL: TEST: testAffineMapProperties
 @run
 def testAffineMapProperties():
-  with Context():
-    d0 = AffineDimExpr.get(0)
-    d1 = AffineDimExpr.get(1)
-    d2 = AffineDimExpr.get(2)
-    map1 = AffineMap.get(3, 0, [d2, d0])
-    map2 = AffineMap.get(3, 0, [d2, d0, d1])
-    map3 = AffineMap.get(3, 1, [d2, d0, d1])
-    # CHECK: False
-    print(map1.is_permutation)
-    # CHECK: True
-    print(map1.is_projected_permutation)
-    # CHECK: True
-    print(map2.is_permutation)
-    # CHECK: True
-    print(map2.is_projected_permutation)
-    # CHECK: False
-    print(map3.is_permutation)
-    # CHECK: False
-    print(map3.is_projected_permutation)
+    with Context():
+        d0 = AffineDimExpr.get(0)
+        d1 = AffineDimExpr.get(1)
+        d2 = AffineDimExpr.get(2)
+        map1 = AffineMap.get(3, 0, [d2, d0])
+        map2 = AffineMap.get(3, 0, [d2, d0, d1])
+        map3 = AffineMap.get(3, 1, [d2, d0, d1])
+        # CHECK: False
+        print(map1.is_permutation)
+        # CHECK: True
+        print(map1.is_projected_permutation)
+        # CHECK: True
+        print(map2.is_permutation)
+        # CHECK: True
+        print(map2.is_projected_permutation)
+        # CHECK: False
+        print(map3.is_permutation)
+        # CHECK: False
+        print(map3.is_projected_permutation)
 
 
 # CHECK-LABEL: TEST: testAffineMapExprs
 @run
 def testAffineMapExprs():
-  with Context():
-    d0 = AffineDimExpr.get(0)
-    d1 = AffineDimExpr.get(1)
-    d2 = AffineDimExpr.get(2)
-    map3 = AffineMap.get(3, 1, [d2, d0, d1])
-
-    # CHECK: 3
-    print(map3.n_dims)
-    # CHECK: 4
-    print(map3.n_inputs)
-    # CHECK: 1
-    print(map3.n_symbols)
-    assert map3.n_inputs == map3.n_dims + map3.n_symbols
-
-    # CHECK: 3
-    print(len(map3.results))
-    for expr in map3.results:
-      # CHECK: d2
-      # CHECK: d0
-      # CHECK: d1
-      print(expr)
-    for expr in map3.results[-1:-4:-1]:
-      # CHECK: d1
-      # CHECK: d0
-      # CHECK: d2
-      print(expr)
-    assert list(map3.results) == [d2, d0, d1]
+    with Context():
+        d0 = AffineDimExpr.get(0)
+        d1 = AffineDimExpr.get(1)
+        d2 = AffineDimExpr.get(2)
+        map3 = AffineMap.get(3, 1, [d2, d0, d1])
+
+        # CHECK: 3
+        print(map3.n_dims)
+        # CHECK: 4
+        print(map3.n_inputs)
+        # CHECK: 1
+        print(map3.n_symbols)
+        assert map3.n_inputs == map3.n_dims + map3.n_symbols
+
+        # CHECK: 3
+        print(len(map3.results))
+        for expr in map3.results:
+            # CHECK: d2
+            # CHECK: d0
+            # CHECK: d1
+            print(expr)
+        for expr in map3.results[-1:-4:-1]:
+            # CHECK: d1
+            # CHECK: d0
+            # CHECK: d2
+            print(expr)
+        assert list(map3.results) == [d2, d0, d1]
 
 
 # CHECK-LABEL: TEST: testCompressUnusedSymbols
 @run
 def testCompressUnusedSymbols():
-  with Context() as ctx:
-    d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1),
-                  AffineDimExpr.get(2))
-    s0, s1, s2 = (AffineSymbolExpr.get(0), AffineSymbolExpr.get(1),
-                  AffineSymbolExpr.get(2))
-    maps = [
-        AffineMap.get(3, 3, [d2, d0, d1]),
-        AffineMap.get(3, 3, [d2, d0 + s2, d1]),
-        AffineMap.get(3, 3, [d1, d2, d0])
-    ]
-
-    compressed_maps = AffineMap.compress_unused_symbols(maps, ctx)
-
-    #      CHECK: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0, d1))
-    # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2, d1))
-    # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d1, d2, d0))
-    print(maps)
-
-    #      CHECK: AffineMap((d0, d1, d2)[s0] -> (d2, d0, d1))
-    # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d2, d0 + s0, d1))
-    # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d1, d2, d0))
-    print(compressed_maps)
+    with Context() as ctx:
+        d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1), AffineDimExpr.get(2))
+        s0, s1, s2 = (
+            AffineSymbolExpr.get(0),
+            AffineSymbolExpr.get(1),
+            AffineSymbolExpr.get(2),
+        )
+        maps = [
+            AffineMap.get(3, 3, [d2, d0, d1]),
+            AffineMap.get(3, 3, [d2, d0 + s2, d1]),
+            AffineMap.get(3, 3, [d1, d2, d0]),
+        ]
+
+        compressed_maps = AffineMap.compress_unused_symbols(maps, ctx)
+
+        #      CHECK: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0, d1))
+        # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2, d1))
+        # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d1, d2, d0))
+        print(maps)
+
+        #      CHECK: AffineMap((d0, d1, d2)[s0] -> (d2, d0, d1))
+        # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d2, d0 + s0, d1))
+        # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d1, d2, d0))
+        print(compressed_maps)
 
 
 # CHECK-LABEL: TEST: testReplace
 @run
 def testReplace():
-  with Context() as ctx:
-    d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1),
-                  AffineDimExpr.get(2))
-    s0, s1, s2 = (AffineSymbolExpr.get(0), AffineSymbolExpr.get(1),
-                  AffineSymbolExpr.get(2))
-    map1 = AffineMap.get(3, 3, [d2, d0 + s1 + s2, d1 + s0])
+    with Context() as ctx:
+        d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1), AffineDimExpr.get(2))
+        s0, s1, s2 = (
+            AffineSymbolExpr.get(0),
+            AffineSymbolExpr.get(1),
+            AffineSymbolExpr.get(2),
+        )
+        map1 = AffineMap.get(3, 3, [d2, d0 + s1 + s2, d1 + s0])
 
-    replace0 = map1.replace(s0, AffineConstantExpr.get(42), 3, 3)
-    replace1 = map1.replace(s1, AffineConstantExpr.get(42), 3, 3)
-    replace3 = map1.replace(s2, AffineConstantExpr.get(42), 3, 2)
+        replace0 = map1.replace(s0, AffineConstantExpr.get(42), 3, 3)
+        replace1 = map1.replace(s1, AffineConstantExpr.get(42), 3, 3)
+        replace3 = map1.replace(s2, AffineConstantExpr.get(42), 3, 2)
 
-    # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s1 + s2, d1 + 42)
-    print(replace0)
+        # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s1 + s2, d1 + 42)
+        print(replace0)
 
-    # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2 + 42, d1 + s0)
-    print(replace1)
+        # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2 + 42, d1 + s0)
+        print(replace1)
 
-    # CHECK: (d0, d1, d2)[s0, s1] -> (d2, d0 + s1 + 42, d1 + s0)
-    print(replace3)
+        # CHECK: (d0, d1, d2)[s0, s1] -> (d2, d0 + s1 + 42, d1 + s0)
+        print(replace3)
 
 
 # CHECK-LABEL: TEST: testHash
 @run
 def testHash():
-  with Context():
-    d0, d1 = AffineDimExpr.get(0), AffineDimExpr.get(1)
-    m1 = AffineMap.get(2, 0, [d0, d1])
-    m2 = AffineMap.get(2, 0, [d1, d0])
-    assert hash(m1) == hash(AffineMap.get(2, 0, [d0, d1]))
-
-    dictionary = dict()
-    dictionary[m1] = 1
-    dictionary[m2] = 2
-    assert m1 in dictionary
+    with Context():
+        d0, d1 = AffineDimExpr.get(0), AffineDimExpr.get(1)
+        m1 = AffineMap.get(2, 0, [d0, d1])
+        m2 = AffineMap.get(2, 0, [d1, d0])
+        assert hash(m1) == hash(AffineMap.get(2, 0, [d0, d1]))
+
+        dictionary = dict()
+        dictionary[m1] = 1
+        dictionary[m2] = 2
+        assert m1 in dictionary
index 3de4edb..5ce8bc6 100644
@@ -6,26 +6,30 @@ import gc
 from mlir.ir import *
 import numpy as np
 
+
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  gc.collect()
-  assert Context._get_live_count() == 0
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
+    return f
+
 
 ################################################################################
 # Tests of the array/buffer .get() factory method on unsupported dtype.
 ################################################################################
 
+
 @run
 def testGetDenseElementsUnsupported():
-  with Context():
-    array = np.array([["hello", "goodbye"]])
-    try:
-      attr = DenseElementsAttr.get(array)
-    except ValueError as e:
-      # CHECK: unimplemented array format conversion from format:
-      print(e)
+    with Context():
+        array = np.array([["hello", "goodbye"]])
+        try:
+            attr = DenseElementsAttr.get(array)
+        except ValueError as e:
+            # CHECK: unimplemented array format conversion from format:
+            print(e)
+
 
 ################################################################################
 # Splats.
@@ -34,85 +38,85 @@ def testGetDenseElementsUnsupported():
 # CHECK-LABEL: TEST: testGetDenseElementsSplatInt
 @run
 def testGetDenseElementsSplatInt():
-  with Context(), Location.unknown():
-    t = IntegerType.get_signless(32)
-    element = IntegerAttr.get(t, 555)
-    shaped_type = RankedTensorType.get((2, 3, 4), t)
-    attr = DenseElementsAttr.get_splat(shaped_type, element)
-    # CHECK: dense<555> : tensor<2x3x4xi32>
-    print(attr)
-    # CHECK: is_splat: True
-    print("is_splat:", attr.is_splat)
-    assert attr.get_splat_value() == element
+    with Context(), Location.unknown():
+        t = IntegerType.get_signless(32)
+        element = IntegerAttr.get(t, 555)
+        shaped_type = RankedTensorType.get((2, 3, 4), t)
+        attr = DenseElementsAttr.get_splat(shaped_type, element)
+        # CHECK: dense<555> : tensor<2x3x4xi32>
+        print(attr)
+        # CHECK: is_splat: True
+        print("is_splat:", attr.is_splat)
+        assert attr.get_splat_value() == element
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsSplatFloat
 @run
 def testGetDenseElementsSplatFloat():
-  with Context(), Location.unknown():
-    t = F32Type.get()
-    element = FloatAttr.get(t, 1.2)
-    shaped_type = RankedTensorType.get((2, 3, 4), t)
-    attr = DenseElementsAttr.get_splat(shaped_type, element)
-    # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32>
-    print(attr)
-    assert attr.get_splat_value() == element
+    with Context(), Location.unknown():
+        t = F32Type.get()
+        element = FloatAttr.get(t, 1.2)
+        shaped_type = RankedTensorType.get((2, 3, 4), t)
+        attr = DenseElementsAttr.get_splat(shaped_type, element)
+        # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32>
+        print(attr)
+        assert attr.get_splat_value() == element
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsSplatErrors
 @run
 def testGetDenseElementsSplatErrors():
-  with Context(), Location.unknown():
-    t = F32Type.get()
-    other_t = F64Type.get()
-    element = FloatAttr.get(t, 1.2)
-    other_element = FloatAttr.get(other_t, 1.2)
-    shaped_type = RankedTensorType.get((2, 3, 4), t)
-    dynamic_shaped_type = UnrankedTensorType.get(t)
-    non_shaped_type = t
-
-    try:
-      attr = DenseElementsAttr.get_splat(non_shaped_type, element)
-    except ValueError as e:
-      # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(f32)
-      print(e)
-
-    try:
-      attr = DenseElementsAttr.get_splat(dynamic_shaped_type, element)
-    except ValueError as e:
-      # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(tensor<*xf32>)
-      print(e)
-
-    try:
-      attr = DenseElementsAttr.get_splat(shaped_type, other_element)
-    except ValueError as e:
-      # CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64)
-      print(e)
+    with Context(), Location.unknown():
+        t = F32Type.get()
+        other_t = F64Type.get()
+        element = FloatAttr.get(t, 1.2)
+        other_element = FloatAttr.get(other_t, 1.2)
+        shaped_type = RankedTensorType.get((2, 3, 4), t)
+        dynamic_shaped_type = UnrankedTensorType.get(t)
+        non_shaped_type = t
+
+        try:
+            attr = DenseElementsAttr.get_splat(non_shaped_type, element)
+        except ValueError as e:
+            # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(f32)
+            print(e)
+
+        try:
+            attr = DenseElementsAttr.get_splat(dynamic_shaped_type, element)
+        except ValueError as e:
+            # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(tensor<*xf32>)
+            print(e)
+
+        try:
+            attr = DenseElementsAttr.get_splat(shaped_type, other_element)
+        except ValueError as e:
+            # CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64)
+            print(e)
 
 
 # CHECK-LABEL: TEST: testRepeatedValuesSplat
 @run
 def testRepeatedValuesSplat():
-  with Context():
-    array = np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], dtype=np.float32)
-    attr = DenseElementsAttr.get(array)
-    # CHECK: dense<1.000000e+00> : tensor<2x3xf32>
-    print(attr)
-    # CHECK: is_splat: True
-    print("is_splat:", attr.is_splat)
-    # CHECK{LITERAL}: [[1. 1. 1.]
-    # CHECK{LITERAL}:  [1. 1. 1.]]
-    print(np.array(attr))
+    with Context():
+        array = np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], dtype=np.float32)
+        attr = DenseElementsAttr.get(array)
+        # CHECK: dense<1.000000e+00> : tensor<2x3xf32>
+        print(attr)
+        # CHECK: is_splat: True
+        print("is_splat:", attr.is_splat)
+        # CHECK{LITERAL}: [[1. 1. 1.]
+        # CHECK{LITERAL}:  [1. 1. 1.]]
+        print(np.array(attr))
 
 
 # CHECK-LABEL: TEST: testNonSplat
 @run
 def testNonSplat():
-  with Context():
-    array = np.array([2.0, 1.0, 1.0], dtype=np.float32)
-    attr = DenseElementsAttr.get(array)
-    # CHECK: is_splat: False
-    print("is_splat:", attr.is_splat)
+    with Context():
+        array = np.array([2.0, 1.0, 1.0], dtype=np.float32)
+        attr = DenseElementsAttr.get(array)
+        # CHECK: is_splat: False
+        print("is_splat:", attr.is_splat)
 
 
 ################################################################################
@@ -121,50 +125,59 @@ def testNonSplat():
 
 ### explicitly provided types
 
+
 @run
 def testGetDenseElementsBF16():
-  with Context():
-    array = np.array([[2, 4, 8], [16, 32, 64]], dtype=np.uint16)
-    attr = DenseElementsAttr.get(array, type=BF16Type.get())
-    # Note: These values don't mean much since just bit-casting. But they
-    # shouldn't change.
-    # CHECK: dense<{{\[}}[1.836710e-40, 3.673420e-40, 7.346840e-40], [1.469370e-39, 2.938740e-39, 5.877470e-39]]> : tensor<2x3xbf16>
-    print(attr)
+    with Context():
+        array = np.array([[2, 4, 8], [16, 32, 64]], dtype=np.uint16)
+        attr = DenseElementsAttr.get(array, type=BF16Type.get())
+        # Note: These values don't mean much since just bit-casting. But they
+        # shouldn't change.
+        # CHECK: dense<{{\[}}[1.836710e-40, 3.673420e-40, 7.346840e-40], [1.469370e-39, 2.938740e-39, 5.877470e-39]]> : tensor<2x3xbf16>
+        print(attr)
+
 
 @run
 def testGetDenseElementsInteger4():
-  with Context():
-    array = np.array([[2, 4, 7], [-2, -4, -8]], dtype=np.uint8)
-    attr = DenseElementsAttr.get(array, type=IntegerType.get_signless(4))
-    # Note: These values don't mean much since just bit-casting. But they
-    # shouldn't change.
-    # CHECK: dense<{{\[}}[2, 4, 7], [-2, -4, -8]]> : tensor<2x3xi4>
-    print(attr)
+    with Context():
+        array = np.array([[2, 4, 7], [-2, -4, -8]], dtype=np.uint8)
+        attr = DenseElementsAttr.get(array, type=IntegerType.get_signless(4))
+        # Note: These values don't mean much since just bit-casting. But they
+        # shouldn't change.
+        # CHECK: dense<{{\[}}[2, 4, 7], [-2, -4, -8]]> : tensor<2x3xi4>
+        print(attr)
 
 
 @run
 def testGetDenseElementsBool():
-  with Context():
-    bool_array = np.array([[1, 0, 1], [0, 1, 0]], dtype=np.bool_)
-    array = np.packbits(bool_array, axis=None, bitorder="little")
-    attr = DenseElementsAttr.get(
-        array, type=IntegerType.get_signless(1), shape=bool_array.shape)
-    # CHECK: dense<{{\[}}[true, false, true], [false, true, false]]> : tensor<2x3xi1>
-    print(attr)
+    with Context():
+        bool_array = np.array([[1, 0, 1], [0, 1, 0]], dtype=np.bool_)
+        array = np.packbits(bool_array, axis=None, bitorder="little")
+        attr = DenseElementsAttr.get(
+            array, type=IntegerType.get_signless(1), shape=bool_array.shape
+        )
+        # CHECK: dense<{{\[}}[true, false, true], [false, true, false]]> : tensor<2x3xi1>
+        print(attr)
 
 
 @run
 def testGetDenseElementsBoolSplat():
-  with Context():
-    zero = np.array(0, dtype=np.uint8)
-    one = np.array(255, dtype=np.uint8)
-    print(one)
-    # CHECK: dense<false> : tensor<4x2x5xi1>
-    print(DenseElementsAttr.get(
-        zero, type=IntegerType.get_signless(1), shape=(4, 2, 5)))
-    # CHECK: dense<true> : tensor<4x2x5xi1>
-    print(DenseElementsAttr.get(
-        one, type=IntegerType.get_signless(1), shape=(4, 2, 5)))
+    with Context():
+        zero = np.array(0, dtype=np.uint8)
+        one = np.array(255, dtype=np.uint8)
+        print(one)
+        # CHECK: dense<false> : tensor<4x2x5xi1>
+        print(
+            DenseElementsAttr.get(
+                zero, type=IntegerType.get_signless(1), shape=(4, 2, 5)
+            )
+        )
+        # CHECK: dense<true> : tensor<4x2x5xi1>
+        print(
+            DenseElementsAttr.get(
+                one, type=IntegerType.get_signless(1), shape=(4, 2, 5)
+            )
+        )
 
 
 ### float and double arrays.
@@ -172,213 +185,213 @@ def testGetDenseElementsBoolSplat():
 # CHECK-LABEL: TEST: testGetDenseElementsF16
 @run
 def testGetDenseElementsF16():
-  with Context():
-    array = np.array([[2.0, 4.0, 8.0], [16.0, 32.0, 64.0]], dtype=np.float16)
-    attr = DenseElementsAttr.get(array)
-    # CHECK: dense<{{\[}}[2.000000e+00, 4.000000e+00, 8.000000e+00], [1.600000e+01, 3.200000e+01, 6.400000e+01]]> : tensor<2x3xf16>
-    print(attr)
-    # CHECK: {{\[}}[ 2. 4. 8.]
-    # CHECK: {{\[}}16. 32. 64.]]
-    print(np.array(attr))
+    with Context():
+        array = np.array([[2.0, 4.0, 8.0], [16.0, 32.0, 64.0]], dtype=np.float16)
+        attr = DenseElementsAttr.get(array)
+        # CHECK: dense<{{\[}}[2.000000e+00, 4.000000e+00, 8.000000e+00], [1.600000e+01, 3.200000e+01, 6.400000e+01]]> : tensor<2x3xf16>
+        print(attr)
+        # CHECK: {{\[}}[ 2. 4. 8.]
+        # CHECK: {{\[}}16. 32. 64.]]
+        print(np.array(attr))
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsF32
 @run
 def testGetDenseElementsF32():
-  with Context():
-    array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32)
-    attr = DenseElementsAttr.get(array)
-    # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32>
-    print(attr)
-    # CHECK: {{\[}}[1.1 2.2 3.3]
-    # CHECK: {{\[}}4.4 5.5 6.6]]
-    print(np.array(attr))
+    with Context():
+        array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32)
+        attr = DenseElementsAttr.get(array)
+        # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32>
+        print(attr)
+        # CHECK: {{\[}}[1.1 2.2 3.3]
+        # CHECK: {{\[}}4.4 5.5 6.6]]
+        print(np.array(attr))
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsF64
 @run
 def testGetDenseElementsF64():
-  with Context():
-    array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64)
-    attr = DenseElementsAttr.get(array)
-    # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64>
-    print(attr)
-    # CHECK: {{\[}}[1.1 2.2 3.3]
-    # CHECK: {{\[}}4.4 5.5 6.6]]
-    print(np.array(attr))
+    with Context():
+        array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64)
+        attr = DenseElementsAttr.get(array)
+        # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64>
+        print(attr)
+        # CHECK: {{\[}}[1.1 2.2 3.3]
+        # CHECK: {{\[}}4.4 5.5 6.6]]
+        print(np.array(attr))
 
 
 ### 16 bit integer arrays
 # CHECK-LABEL: TEST: testGetDenseElementsI16Signless
 @run
 def testGetDenseElementsI16Signless():
-  with Context():
-    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)
-    attr = DenseElementsAttr.get(array)
-    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
-    print(attr)
-    # CHECK: {{\[}}[1 2 3]
-    # CHECK: {{\[}}4 5 6]]
-    print(np.array(attr))
+    with Context():
+        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)
+        attr = DenseElementsAttr.get(array)
+        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
+        print(attr)
+        # CHECK: {{\[}}[1 2 3]
+        # CHECK: {{\[}}4 5 6]]
+        print(np.array(attr))
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsUI16Signless
 @run
 def testGetDenseElementsUI16Signless():
-  with Context():
-    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16)
-    attr = DenseElementsAttr.get(array)
-    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
-    print(attr)
-    # CHECK: {{\[}}[1 2 3]
-    # CHECK: {{\[}}4 5 6]]
-    print(np.array(attr))
+    with Context():
+        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16)
+        attr = DenseElementsAttr.get(array)
+        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
+        print(attr)
+        # CHECK: {{\[}}[1 2 3]
+        # CHECK: {{\[}}4 5 6]]
+        print(np.array(attr))
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsI16
 @run
 def testGetDenseElementsI16():
-  with Context():
-    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)
-    attr = DenseElementsAttr.get(array, signless=False)
-    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi16>
-    print(attr)
-    # CHECK: {{\[}}[1 2 3]
-    # CHECK: {{\[}}4 5 6]]
-    print(np.array(attr))
+    with Context():
+        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)
+        attr = DenseElementsAttr.get(array, signless=False)
+        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi16>
+        print(attr)
+        # CHECK: {{\[}}[1 2 3]
+        # CHECK: {{\[}}4 5 6]]
+        print(np.array(attr))
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsUI16
 @run
 def testGetDenseElementsUI16():
-  with Context():
-    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16)
-    attr = DenseElementsAttr.get(array, signless=False)
-    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui16>
-    print(attr)
-    # CHECK: {{\[}}[1 2 3]
-    # CHECK: {{\[}}4 5 6]]
-    print(np.array(attr))
+    with Context():
+        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16)
+        attr = DenseElementsAttr.get(array, signless=False)
+        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui16>
+        print(attr)
+        # CHECK: {{\[}}[1 2 3]
+        # CHECK: {{\[}}4 5 6]]
+        print(np.array(attr))
+
 
 ### 32 bit integer arrays
 # CHECK-LABEL: TEST: testGetDenseElementsI32Signless
 @run
 def testGetDenseElementsI32Signless():
-  with Context():
-    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
-    attr = DenseElementsAttr.get(array)
-    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
-    print(attr)
-    # CHECK: {{\[}}[1 2 3]
-    # CHECK: {{\[}}4 5 6]]
-    print(np.array(attr))
+    with Context():
+        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
+        attr = DenseElementsAttr.get(array)
+        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
+        print(attr)
+        # CHECK: {{\[}}[1 2 3]
+        # CHECK: {{\[}}4 5 6]]
+        print(np.array(attr))
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsUI32Signless
 @run
 def testGetDenseElementsUI32Signless():
-  with Context():
-    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
-    attr = DenseElementsAttr.get(array)
-    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
-    print(attr)
-    # CHECK: {{\[}}[1 2 3]
-    # CHECK: {{\[}}4 5 6]]
-    print(np.array(attr))
+    with Context():
+        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
+        attr = DenseElementsAttr.get(array)
+        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
+        print(attr)
+        # CHECK: {{\[}}[1 2 3]
+        # CHECK: {{\[}}4 5 6]]
+        print(np.array(attr))
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsI32
 @run
 def testGetDenseElementsI32():
-  with Context():
-    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
-    attr = DenseElementsAttr.get(array, signless=False)
-    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32>
-    print(attr)
-    # CHECK: {{\[}}[1 2 3]
-    # CHECK: {{\[}}4 5 6]]
-    print(np.array(attr))
+    with Context():
+        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
+        attr = DenseElementsAttr.get(array, signless=False)
+        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32>
+        print(attr)
+        # CHECK: {{\[}}[1 2 3]
+        # CHECK: {{\[}}4 5 6]]
+        print(np.array(attr))
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsUI32
 @run
 def testGetDenseElementsUI32():
-  with Context():
-    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
-    attr = DenseElementsAttr.get(array, signless=False)
-    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32>
-    print(attr)
-    # CHECK: {{\[}}[1 2 3]
-    # CHECK: {{\[}}4 5 6]]
-    print(np.array(attr))
+    with Context():
+        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
+        attr = DenseElementsAttr.get(array, signless=False)
+        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32>
+        print(attr)
+        # CHECK: {{\[}}[1 2 3]
+        # CHECK: {{\[}}4 5 6]]
+        print(np.array(attr))
 
 
 ## 64bit integer arrays
 # CHECK-LABEL: TEST: testGetDenseElementsI64Signless
 @run
 def testGetDenseElementsI64Signless():
-  with Context():
-    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
-    attr = DenseElementsAttr.get(array)
-    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
-    print(attr)
-    # CHECK: {{\[}}[1 2 3]
-    # CHECK: {{\[}}4 5 6]]
-    print(np.array(attr))
+    with Context():
+        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
+        attr = DenseElementsAttr.get(array)
+        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
+        print(attr)
+        # CHECK: {{\[}}[1 2 3]
+        # CHECK: {{\[}}4 5 6]]
+        print(np.array(attr))
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsUI64Signless
 @run
 def testGetDenseElementsUI64Signless():
-  with Context():
-    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
-    attr = DenseElementsAttr.get(array)
-    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
-    print(attr)
-    # CHECK: {{\[}}[1 2 3]
-    # CHECK: {{\[}}4 5 6]]
-    print(np.array(attr))
+    with Context():
+        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
+        attr = DenseElementsAttr.get(array)
+        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
+        print(attr)
+        # CHECK: {{\[}}[1 2 3]
+        # CHECK: {{\[}}4 5 6]]
+        print(np.array(attr))
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsI64
 @run
 def testGetDenseElementsI64():
-  with Context():
-    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
-    attr = DenseElementsAttr.get(array, signless=False)
-    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64>
-    print(attr)
-    # CHECK: {{\[}}[1 2 3]
-    # CHECK: {{\[}}4 5 6]]
-    print(np.array(attr))
+    with Context():
+        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
+        attr = DenseElementsAttr.get(array, signless=False)
+        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64>
+        print(attr)
+        # CHECK: {{\[}}[1 2 3]
+        # CHECK: {{\[}}4 5 6]]
+        print(np.array(attr))
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsUI64
 @run
 def testGetDenseElementsUI64():
-  with Context():
-    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
-    attr = DenseElementsAttr.get(array, signless=False)
-    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64>
-    print(attr)
-    # CHECK: {{\[}}[1 2 3]
-    # CHECK: {{\[}}4 5 6]]
-    print(np.array(attr))
+    with Context():
+        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
+        attr = DenseElementsAttr.get(array, signless=False)
+        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64>
+        print(attr)
+        # CHECK: {{\[}}[1 2 3]
+        # CHECK: {{\[}}4 5 6]]
+        print(np.array(attr))
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsIndex
 @run
 def testGetDenseElementsIndex():
-  with Context(), Location.unknown():
-    idx_type = IndexType.get()
-    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
-    attr = DenseElementsAttr.get(array, type=idx_type)
-    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xindex>
-    print(attr)
-    arr = np.array(attr)
-    # CHECK: {{\[}}[1 2 3]
-    # CHECK: {{\[}}4 5 6]]
-    print(arr)
-    # CHECK: True
-    print(arr.dtype == np.int64)
-
+    with Context(), Location.unknown():
+        idx_type = IndexType.get()
+        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
+        attr = DenseElementsAttr.get(array, type=idx_type)
+        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xindex>
+        print(attr)
+        arr = np.array(attr)
+        # CHECK: {{\[}}[1 2 3]
+        # CHECK: {{\[}}4 5 6]]
+        print(arr)
+        # CHECK: True
+        print(arr.dtype == np.int64)
index 6aad943..2907405 100644
@@ -6,554 +6,550 @@ from mlir.ir import *
 
 
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  gc.collect()
-  assert Context._get_live_count() == 0
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
+    return f
 
 
 # CHECK-LABEL: TEST: testParsePrint
 @run
 def testParsePrint():
-  with Context() as ctx:
-    t = Attribute.parse('"hello"')
-  assert t.context is ctx
-  ctx = None
-  gc.collect()
-  # CHECK: "hello"
-  print(str(t))
-  # CHECK: Attribute("hello")
-  print(repr(t))
+    with Context() as ctx:
+        t = Attribute.parse('"hello"')
+    assert t.context is ctx
+    ctx = None
+    gc.collect()
+    # CHECK: "hello"
+    print(str(t))
+    # CHECK: Attribute("hello")
+    print(repr(t))
 
 
 # CHECK-LABEL: TEST: testParseError
 @run
 def testParseError():
-  with Context():
-    try:
-      t = Attribute.parse("BAD_ATTR_DOES_NOT_EXIST")
-    except MLIRError as e:
-      # CHECK: testParseError: <
-      # CHECK:   Unable to parse attribute:
-      # CHECK:   error: "BAD_ATTR_DOES_NOT_EXIST":1:1: expected attribute value
-      # CHECK: >
-      print(f"testParseError: <{e}>")
-    else:
-      print("Exception not produced")
+    with Context():
+        try:
+            t = Attribute.parse("BAD_ATTR_DOES_NOT_EXIST")
+        except MLIRError as e:
+            # CHECK: testParseError: <
+            # CHECK:   Unable to parse attribute:
+            # CHECK:   error: "BAD_ATTR_DOES_NOT_EXIST":1:1: expected attribute value
+            # CHECK: >
+            print(f"testParseError: <{e}>")
+        else:
+            print("Exception not produced")
 
 
 # CHECK-LABEL: TEST: testAttrEq
 @run
 def testAttrEq():
-  with Context():
-    a1 = Attribute.parse('"attr1"')
-    a2 = Attribute.parse('"attr2"')
-    a3 = Attribute.parse('"attr1"')
-    # CHECK: a1 == a1: True
-    print("a1 == a1:", a1 == a1)
-    # CHECK: a1 == a2: False
-    print("a1 == a2:", a1 == a2)
-    # CHECK: a1 == a3: True
-    print("a1 == a3:", a1 == a3)
-    # CHECK: a1 == None: False
-    print("a1 == None:", a1 == None)
+    with Context():
+        a1 = Attribute.parse('"attr1"')
+        a2 = Attribute.parse('"attr2"')
+        a3 = Attribute.parse('"attr1"')
+        # CHECK: a1 == a1: True
+        print("a1 == a1:", a1 == a1)
+        # CHECK: a1 == a2: False
+        print("a1 == a2:", a1 == a2)
+        # CHECK: a1 == a3: True
+        print("a1 == a3:", a1 == a3)
+        # CHECK: a1 == None: False
+        print("a1 == None:", a1 == None)
 
 
 # CHECK-LABEL: TEST: testAttrHash
 @run
 def testAttrHash():
-  with Context():
-    a1 = Attribute.parse('"attr1"')
-    a2 = Attribute.parse('"attr2"')
-    a3 = Attribute.parse('"attr1"')
-    # CHECK: hash(a1) == hash(a3): True
-    print("hash(a1) == hash(a3):", a1.__hash__() == a3.__hash__())
+    with Context():
+        a1 = Attribute.parse('"attr1"')
+        a2 = Attribute.parse('"attr2"')
+        a3 = Attribute.parse('"attr1"')
+        # CHECK: hash(a1) == hash(a3): True
+        print("hash(a1) == hash(a3):", a1.__hash__() == a3.__hash__())
 
-    s = set()
-    s.add(a1)
-    s.add(a2)
-    s.add(a3)
-    # CHECK: len(s): 2
-    print("len(s): ", len(s))
+        s = set()
+        s.add(a1)
+        s.add(a2)
+        s.add(a3)
+        # CHECK: len(s): 2
+        print("len(s): ", len(s))
 
 
 # CHECK-LABEL: TEST: testAttrCast
 @run
 def testAttrCast():
-  with Context():
-    a1 = Attribute.parse('"attr1"')
-    a2 = Attribute(a1)
-    # CHECK: a1 == a2: True
-    print("a1 == a2:", a1 == a2)
+    with Context():
+        a1 = Attribute.parse('"attr1"')
+        a2 = Attribute(a1)
+        # CHECK: a1 == a2: True
+        print("a1 == a2:", a1 == a2)
 
 
 # CHECK-LABEL: TEST: testAttrIsInstance
 @run
 def testAttrIsInstance():
-  with Context():
-    a1 = Attribute.parse("42")
-    a2 = Attribute.parse("[42]")
-    assert IntegerAttr.isinstance(a1)
-    assert not IntegerAttr.isinstance(a2)
-    assert not ArrayAttr.isinstance(a1)
-    assert ArrayAttr.isinstance(a2)
+    with Context():
+        a1 = Attribute.parse("42")
+        a2 = Attribute.parse("[42]")
+        assert IntegerAttr.isinstance(a1)
+        assert not IntegerAttr.isinstance(a2)
+        assert not ArrayAttr.isinstance(a1)
+        assert ArrayAttr.isinstance(a2)
 
 
 # CHECK-LABEL: TEST: testAttrEqDoesNotRaise
 @run
 def testAttrEqDoesNotRaise():
-  with Context():
-    a1 = Attribute.parse('"attr1"')
-    not_an_attr = "foo"
-    # CHECK: False
-    print(a1 == not_an_attr)
-    # CHECK: False
-    print(a1 == None)
-    # CHECK: True
-    print(a1 != None)
+    with Context():
+        a1 = Attribute.parse('"attr1"')
+        not_an_attr = "foo"
+        # CHECK: False
+        print(a1 == not_an_attr)
+        # CHECK: False
+        print(a1 == None)
+        # CHECK: True
+        print(a1 != None)
 
 
 # CHECK-LABEL: TEST: testAttrCapsule
 @run
 def testAttrCapsule():
-  with Context() as ctx:
-    a1 = Attribute.parse('"attr1"')
-  # CHECK: mlir.ir.Attribute._CAPIPtr
-  attr_capsule = a1._CAPIPtr
-  print(attr_capsule)
-  a2 = Attribute._CAPICreate(attr_capsule)
-  assert a2 == a1
-  assert a2.context is ctx
+    with Context() as ctx:
+        a1 = Attribute.parse('"attr1"')
+    # CHECK: mlir.ir.Attribute._CAPIPtr
+    attr_capsule = a1._CAPIPtr
+    print(attr_capsule)
+    a2 = Attribute._CAPICreate(attr_capsule)
+    assert a2 == a1
+    assert a2.context is ctx
 
 
 # CHECK-LABEL: TEST: testStandardAttrCasts
 @run
 def testStandardAttrCasts():
-  with Context():
-    a1 = Attribute.parse('"attr1"')
-    astr = StringAttr(a1)
-    aself = StringAttr(astr)
-    # CHECK: Attribute("attr1")
-    print(repr(astr))
-    try:
-      tillegal = StringAttr(Attribute.parse("1.0"))
-    except ValueError as e:
-      # CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64))
-      print("ValueError:", e)
-    else:
-      print("Exception not produced")
+    with Context():
+        a1 = Attribute.parse('"attr1"')
+        astr = StringAttr(a1)
+        aself = StringAttr(astr)
+        # CHECK: Attribute("attr1")
+        print(repr(astr))
+        try:
+            tillegal = StringAttr(Attribute.parse("1.0"))
+        except ValueError as e:
+            # CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64))
+            print("ValueError:", e)
+        else:
+            print("Exception not produced")
 
 
 # CHECK-LABEL: TEST: testAffineMapAttr
 @run
 def testAffineMapAttr():
-  with Context() as ctx:
-    d0 = AffineDimExpr.get(0)
-    d1 = AffineDimExpr.get(1)
-    c2 = AffineConstantExpr.get(2)
-    map0 = AffineMap.get(2, 3, [])
+    with Context() as ctx:
+        d0 = AffineDimExpr.get(0)
+        d1 = AffineDimExpr.get(1)
+        c2 = AffineConstantExpr.get(2)
+        map0 = AffineMap.get(2, 3, [])
 
-    # CHECK: affine_map<(d0, d1)[s0, s1, s2] -> ()>
-    attr_built = AffineMapAttr.get(map0)
-    print(str(attr_built))
+        # CHECK: affine_map<(d0, d1)[s0, s1, s2] -> ()>
+        attr_built = AffineMapAttr.get(map0)
+        print(str(attr_built))
 
-    attr_parsed = Attribute.parse(str(attr_built))
-    assert attr_built == attr_parsed
+        attr_parsed = Attribute.parse(str(attr_built))
+        assert attr_built == attr_parsed
 
 
 # CHECK-LABEL: TEST: testFloatAttr
 @run
 def testFloatAttr():
-  with Context(), Location.unknown():
-    fattr = FloatAttr(Attribute.parse("42.0 : f32"))
-    # CHECK: fattr value: 42.0
-    print("fattr value:", fattr.value)
-
-    # Test factory methods.
-    # CHECK: default_get: 4.200000e+01 : f32
-    print("default_get:", FloatAttr.get(
-        F32Type.get(), 42.0))
-    # CHECK: f32_get: 4.200000e+01 : f32
-    print("f32_get:", FloatAttr.get_f32(42.0))
-    # CHECK: f64_get: 4.200000e+01 : f64
-    print("f64_get:", FloatAttr.get_f64(42.0))
-    try:
-      fattr_invalid = FloatAttr.get(
-          IntegerType.get_signless(32), 42)
-    except MLIRError as e:
-      # CHECK: Invalid attribute:
-      # CHECK: error: unknown: expected floating point type
-      print(e)
-    else:
-      print("Exception not produced")
+    with Context(), Location.unknown():
+        fattr = FloatAttr(Attribute.parse("42.0 : f32"))
+        # CHECK: fattr value: 42.0
+        print("fattr value:", fattr.value)
+
+        # Test factory methods.
+        # CHECK: default_get: 4.200000e+01 : f32
+        print("default_get:", FloatAttr.get(F32Type.get(), 42.0))
+        # CHECK: f32_get: 4.200000e+01 : f32
+        print("f32_get:", FloatAttr.get_f32(42.0))
+        # CHECK: f64_get: 4.200000e+01 : f64
+        print("f64_get:", FloatAttr.get_f64(42.0))
+        try:
+            fattr_invalid = FloatAttr.get(IntegerType.get_signless(32), 42)
+        except MLIRError as e:
+            # CHECK: Invalid attribute:
+            # CHECK: error: unknown: expected floating point type
+            print(e)
+        else:
+            print("Exception not produced")
 
 
 # CHECK-LABEL: TEST: testIntegerAttr
 @run
 def testIntegerAttr():
-  with Context() as ctx:
-    i_attr = IntegerAttr(Attribute.parse("42"))
-    # CHECK: i_attr value: 42
-    print("i_attr value:", i_attr.value)
-    # CHECK: i_attr type: i64
-    print("i_attr type:", i_attr.type)
-    si_attr = IntegerAttr(Attribute.parse("-1 : si8"))
-    # CHECK: si_attr value: -1
-    print("si_attr value:", si_attr.value)
-    ui_attr = IntegerAttr(Attribute.parse("255 : ui8"))
-    # CHECK: ui_attr value: 255
-    print("ui_attr value:", ui_attr.value)
-    idx_attr = IntegerAttr(Attribute.parse("-1 : index"))
-    # CHECK: idx_attr value: -1
-    print("idx_attr value:", idx_attr.value)
-
-    # Test factory methods.
-    # CHECK: default_get: 42 : i32
-    print("default_get:", IntegerAttr.get(
-        IntegerType.get_signless(32), 42))
+    with Context() as ctx:
+        i_attr = IntegerAttr(Attribute.parse("42"))
+        # CHECK: i_attr value: 42
+        print("i_attr value:", i_attr.value)
+        # CHECK: i_attr type: i64
+        print("i_attr type:", i_attr.type)
+        si_attr = IntegerAttr(Attribute.parse("-1 : si8"))
+        # CHECK: si_attr value: -1
+        print("si_attr value:", si_attr.value)
+        ui_attr = IntegerAttr(Attribute.parse("255 : ui8"))
+        # CHECK: ui_attr value: 255
+        print("ui_attr value:", ui_attr.value)
+        idx_attr = IntegerAttr(Attribute.parse("-1 : index"))
+        # CHECK: idx_attr value: -1
+        print("idx_attr value:", idx_attr.value)
+
+        # Test factory methods.
+        # CHECK: default_get: 42 : i32
+        print("default_get:", IntegerAttr.get(IntegerType.get_signless(32), 42))
 
 
 # CHECK-LABEL: TEST: testBoolAttr
 @run
 def testBoolAttr():
-  with Context() as ctx:
-    battr = BoolAttr(Attribute.parse("true"))
-    # CHECK: iattr value: True
-    print("iattr value:", battr.value)
+    with Context() as ctx:
+        battr = BoolAttr(Attribute.parse("true"))
+        # CHECK: iattr value: True
+        print("iattr value:", battr.value)
 
-    # Test factory methods.
-    # CHECK: default_get: true
-    print("default_get:", BoolAttr.get(True))
+        # Test factory methods.
+        # CHECK: default_get: true
+        print("default_get:", BoolAttr.get(True))
 
 
 # CHECK-LABEL: TEST: testFlatSymbolRefAttr
 @run
 def testFlatSymbolRefAttr():
-  with Context() as ctx:
-    sattr = FlatSymbolRefAttr(Attribute.parse('@symbol'))
-    # CHECK: symattr value: symbol
-    print("symattr value:", sattr.value)
+    with Context() as ctx:
+        sattr = FlatSymbolRefAttr(Attribute.parse("@symbol"))
+        # CHECK: symattr value: symbol
+        print("symattr value:", sattr.value)
 
-    # Test factory methods.
-    # CHECK: default_get: @foobar
-    print("default_get:", FlatSymbolRefAttr.get("foobar"))
+        # Test factory methods.
+        # CHECK: default_get: @foobar
+        print("default_get:", FlatSymbolRefAttr.get("foobar"))
 
 
 # CHECK-LABEL: TEST: testOpaqueAttr
 @run
 def testOpaqueAttr():
-  with Context() as ctx:
-    ctx.allow_unregistered_dialects = True
-    oattr = OpaqueAttr(Attribute.parse("#pytest_dummy.dummyattr<>"))
-    # CHECK: oattr value: pytest_dummy
-    print("oattr value:", oattr.dialect_namespace)
-    # CHECK: oattr value: b'dummyattr<>'
-    print("oattr value:", oattr.data)
-
-    # Test factory methods.
-    # CHECK: default_get: #foobar<123>
-    print(
-        "default_get:",
-        OpaqueAttr.get("foobar", bytes("123", "utf-8"), NoneType.get()))
+    with Context() as ctx:
+        ctx.allow_unregistered_dialects = True
+        oattr = OpaqueAttr(Attribute.parse("#pytest_dummy.dummyattr<>"))
+        # CHECK: oattr value: pytest_dummy
+        print("oattr value:", oattr.dialect_namespace)
+        # CHECK: oattr value: b'dummyattr<>'
+        print("oattr value:", oattr.data)
+
+        # Test factory methods.
+        # CHECK: default_get: #foobar<123>
+        print(
+            "default_get:",
+            OpaqueAttr.get("foobar", bytes("123", "utf-8"), NoneType.get()),
+        )
 
 
 # CHECK-LABEL: TEST: testStringAttr
 @run
 def testStringAttr():
-  with Context() as ctx:
-    sattr = StringAttr(Attribute.parse('"stringattr"'))
-    # CHECK: sattr value: stringattr
-    print("sattr value:", sattr.value)
-    # CHECK: sattr value: b'stringattr'
-    print("sattr value:", sattr.value_bytes)
+    with Context() as ctx:
+        sattr = StringAttr(Attribute.parse('"stringattr"'))
+        # CHECK: sattr value: stringattr
+        print("sattr value:", sattr.value)
+        # CHECK: sattr value: b'stringattr'
+        print("sattr value:", sattr.value_bytes)
 
-    # Test factory methods.
-    # CHECK: default_get: "foobar"
-    print("default_get:", StringAttr.get("foobar"))
-    # CHECK: typed_get: "12345" : i32
-    print("typed_get:", StringAttr.get_typed(
-        IntegerType.get_signless(32), "12345"))
+        # Test factory methods.
+        # CHECK: default_get: "foobar"
+        print("default_get:", StringAttr.get("foobar"))
+        # CHECK: typed_get: "12345" : i32
+        print("typed_get:", StringAttr.get_typed(IntegerType.get_signless(32), "12345"))
 
 
 # CHECK-LABEL: TEST: testNamedAttr
 @run
 def testNamedAttr():
-  with Context():
-    a = Attribute.parse('"stringattr"')
-    named = a.get_named("foobar")  # Note: under the small object threshold
-    # CHECK: attr: "stringattr"
-    print("attr:", named.attr)
-    # CHECK: name: foobar
-    print("name:", named.name)
-    # CHECK: named: NamedAttribute(foobar="stringattr")
-    print("named:", named)
+    with Context():
+        a = Attribute.parse('"stringattr"')
+        named = a.get_named("foobar")  # Note: under the small object threshold
+        # CHECK: attr: "stringattr"
+        print("attr:", named.attr)
+        # CHECK: name: foobar
+        print("name:", named.name)
+        # CHECK: named: NamedAttribute(foobar="stringattr")
+        print("named:", named)
 
 
 # CHECK-LABEL: TEST: testDenseIntAttr
 @run
 def testDenseIntAttr():
-  with Context():
-    raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>")
-    # CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]>
-    print("attr:", raw)
+    with Context():
+        raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>")
+        # CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]>
+        print("attr:", raw)
 
-    a = DenseIntElementsAttr(raw)
-    assert len(a) == 6
+        a = DenseIntElementsAttr(raw)
+        assert len(a) == 6
 
-    # CHECK: 0 1 2 3 4 5
-    for value in a:
-      print(value, end=" ")
-    print()
+        # CHECK: 0 1 2 3 4 5
+        for value in a:
+            print(value, end=" ")
+        print()
 
-    # CHECK: i32
-    print(ShapedType(a.type).element_type)
+        # CHECK: i32
+        print(ShapedType(a.type).element_type)
 
-    raw = Attribute.parse("dense<[true,false,true,false]> : vector<4xi1>")
-    # CHECK: attr: dense<[true, false, true, false]>
-    print("attr:", raw)
+        raw = Attribute.parse("dense<[true,false,true,false]> : vector<4xi1>")
+        # CHECK: attr: dense<[true, false, true, false]>
+        print("attr:", raw)
 
-    a = DenseIntElementsAttr(raw)
-    assert len(a) == 4
+        a = DenseIntElementsAttr(raw)
+        assert len(a) == 4
 
-    # CHECK: 1 0 1 0
-    for value in a:
-      print(value, end=" ")
-    print()
+        # CHECK: 1 0 1 0
+        for value in a:
+            print(value, end=" ")
+        print()
 
-    # CHECK: i1
-    print(ShapedType(a.type).element_type)
+        # CHECK: i1
+        print(ShapedType(a.type).element_type)
 
 
 @run
 def testDenseArrayGetItem():
-  def print_item(AttrClass, attr_asm):
-    attr = AttrClass(Attribute.parse(attr_asm))
-    print(f"{len(attr)}: {attr[0]}, {attr[1]}")
-
-  with Context():
-    # CHECK: 2: 0, 1
-    print_item(DenseBoolArrayAttr, "array<i1: false, true>")
-    # CHECK: 2: 2, 3
-    print_item(DenseI8ArrayAttr, "array<i8: 2, 3>")
-    # CHECK: 2: 4, 5
-    print_item(DenseI16ArrayAttr, "array<i16: 4, 5>")
-    # CHECK: 2: 6, 7
-    print_item(DenseI32ArrayAttr, "array<i32: 6, 7>")
-    # CHECK: 2: 8, 9
-    print_item(DenseI64ArrayAttr, "array<i64: 8, 9>")
-    # CHECK: 2: 1.{{0+}}, 2.{{0+}}
-    print_item(DenseF32ArrayAttr, "array<f32: 1.0, 2.0>")
-    # CHECK: 2: 3.{{0+}}, 4.{{0+}}
-    print_item(DenseF64ArrayAttr, "array<f64: 3.0, 4.0>")
+    def print_item(AttrClass, attr_asm):
+        attr = AttrClass(Attribute.parse(attr_asm))
+        print(f"{len(attr)}: {attr[0]}, {attr[1]}")
+
+    with Context():
+        # CHECK: 2: 0, 1
+        print_item(DenseBoolArrayAttr, "array<i1: false, true>")
+        # CHECK: 2: 2, 3
+        print_item(DenseI8ArrayAttr, "array<i8: 2, 3>")
+        # CHECK: 2: 4, 5
+        print_item(DenseI16ArrayAttr, "array<i16: 4, 5>")
+        # CHECK: 2: 6, 7
+        print_item(DenseI32ArrayAttr, "array<i32: 6, 7>")
+        # CHECK: 2: 8, 9
+        print_item(DenseI64ArrayAttr, "array<i64: 8, 9>")
+        # CHECK: 2: 1.{{0+}}, 2.{{0+}}
+        print_item(DenseF32ArrayAttr, "array<f32: 1.0, 2.0>")
+        # CHECK: 2: 3.{{0+}}, 4.{{0+}}
+        print_item(DenseF64ArrayAttr, "array<f64: 3.0, 4.0>")
 
 
 # CHECK-LABEL: TEST: testDenseIntAttrGetItem
 @run
 def testDenseIntAttrGetItem():
-  def print_item(attr_asm):
-    attr = DenseIntElementsAttr(Attribute.parse(attr_asm))
-    dtype = ShapedType(attr.type).element_type
-    try:
-      item = attr[0]
-      print(f"{dtype}:", item)
-    except TypeError as e:
-      print(f"{dtype}:", e)
-
-  with Context():
-    # CHECK: i1: 1
-    print_item("dense<true> : tensor<i1>")
-    # CHECK: i8: 123
-    print_item("dense<123> : tensor<i8>")
-    # CHECK: i16: 123
-    print_item("dense<123> : tensor<i16>")
-    # CHECK: i32: 123
-    print_item("dense<123> : tensor<i32>")
-    # CHECK: i64: 123
-    print_item("dense<123> : tensor<i64>")
-    # CHECK: ui8: 123
-    print_item("dense<123> : tensor<ui8>")
-    # CHECK: ui16: 123
-    print_item("dense<123> : tensor<ui16>")
-    # CHECK: ui32: 123
-    print_item("dense<123> : tensor<ui32>")
-    # CHECK: ui64: 123
-    print_item("dense<123> : tensor<ui64>")
-    # CHECK: si8: -123
-    print_item("dense<-123> : tensor<si8>")
-    # CHECK: si16: -123
-    print_item("dense<-123> : tensor<si16>")
-    # CHECK: si32: -123
-    print_item("dense<-123> : tensor<si32>")
-    # CHECK: si64: -123
-    print_item("dense<-123> : tensor<si64>")
-
-    # CHECK: i7: Unsupported integer type
-    print_item("dense<123> : tensor<i7>")
+    def print_item(attr_asm):
+        attr = DenseIntElementsAttr(Attribute.parse(attr_asm))
+        dtype = ShapedType(attr.type).element_type
+        try:
+            item = attr[0]
+            print(f"{dtype}:", item)
+        except TypeError as e:
+            print(f"{dtype}:", e)
+
+    with Context():
+        # CHECK: i1: 1
+        print_item("dense<true> : tensor<i1>")
+        # CHECK: i8: 123
+        print_item("dense<123> : tensor<i8>")
+        # CHECK: i16: 123
+        print_item("dense<123> : tensor<i16>")
+        # CHECK: i32: 123
+        print_item("dense<123> : tensor<i32>")
+        # CHECK: i64: 123
+        print_item("dense<123> : tensor<i64>")
+        # CHECK: ui8: 123
+        print_item("dense<123> : tensor<ui8>")
+        # CHECK: ui16: 123
+        print_item("dense<123> : tensor<ui16>")
+        # CHECK: ui32: 123
+        print_item("dense<123> : tensor<ui32>")
+        # CHECK: ui64: 123
+        print_item("dense<123> : tensor<ui64>")
+        # CHECK: si8: -123
+        print_item("dense<-123> : tensor<si8>")
+        # CHECK: si16: -123
+        print_item("dense<-123> : tensor<si16>")
+        # CHECK: si32: -123
+        print_item("dense<-123> : tensor<si32>")
+        # CHECK: si64: -123
+        print_item("dense<-123> : tensor<si64>")
+
+        # CHECK: i7: Unsupported integer type
+        print_item("dense<123> : tensor<i7>")
 
 
 # CHECK-LABEL: TEST: testDenseFPAttr
 @run
 def testDenseFPAttr():
-  with Context():
-    raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
-    # CHECK: attr: dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
+    with Context():
+        raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
+        # CHECK: attr: dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
 
-    print("attr:", raw)
+        print("attr:", raw)
 
-    a = DenseFPElementsAttr(raw)
-    assert len(a) == 4
+        a = DenseFPElementsAttr(raw)
+        assert len(a) == 4
 
-    # CHECK: 0.0 1.0 2.0 3.0
-    for value in a:
-      print(value, end=" ")
-    print()
+        # CHECK: 0.0 1.0 2.0 3.0
+        for value in a:
+            print(value, end=" ")
+        print()
 
-    # CHECK: f32
-    print(ShapedType(a.type).element_type)
+        # CHECK: f32
+        print(ShapedType(a.type).element_type)
 
 
 # CHECK-LABEL: TEST: testDictAttr
 @run
 def testDictAttr():
-  with Context():
-    dict_attr = {
-      'stringattr':  StringAttr.get('string'),
-      'integerattr' : IntegerAttr.get(
-        IntegerType.get_signless(32), 42)
-    }
+    with Context():
+        dict_attr = {
+            "stringattr": StringAttr.get("string"),
+            "integerattr": IntegerAttr.get(IntegerType.get_signless(32), 42),
+        }
 
-    a = DictAttr.get(dict_attr)
+        a = DictAttr.get(dict_attr)
 
-    # CHECK attr: {integerattr = 42 : i32, stringattr = "string"}
-    print("attr:", a)
+        # CHECK attr: {integerattr = 42 : i32, stringattr = "string"}
+        print("attr:", a)
 
-    assert len(a) == 2
+        assert len(a) == 2
 
-    # CHECK: 42 : i32
-    print(a['integerattr'])
+        # CHECK: 42 : i32
+        print(a["integerattr"])
 
-    # CHECK: "string"
-    print(a['stringattr'])
+        # CHECK: "string"
+        print(a["stringattr"])
 
-    # CHECK: True
-    print('stringattr' in a)
+        # CHECK: True
+        print("stringattr" in a)
 
-    # CHECK: False
-    print('not_in_dict' in a)
+        # CHECK: False
+        print("not_in_dict" in a)
 
-    # Check that exceptions are raised as expected.
-    try:
-      _ = a['does_not_exist']
-    except KeyError:
-      pass
-    else:
-      assert False, "Exception not produced"
+        # Check that exceptions are raised as expected.
+        try:
+            _ = a["does_not_exist"]
+        except KeyError:
+            pass
+        else:
+            assert False, "Exception not produced"
 
-    try:
-      _ = a[42]
-    except IndexError:
-      pass
-    else:
-      assert False, "expected IndexError on accessing an out-of-bounds attribute"
+        try:
+            _ = a[42]
+        except IndexError:
+            pass
+        else:
+            assert False, "expected IndexError on accessing an out-of-bounds attribute"
 
-    # CHECK "empty: {}"
-    print("empty: ", DictAttr.get())
+        # CHECK "empty: {}"
+        print("empty: ", DictAttr.get())
 
 
 # CHECK-LABEL: TEST: testTypeAttr
 @run
 def testTypeAttr():
-  with Context():
-    raw = Attribute.parse("vector<4xf32>")
-    # CHECK: attr: vector<4xf32>
-    print("attr:", raw)
-    type_attr = TypeAttr(raw)
-    # CHECK: f32
-    print(ShapedType(type_attr.value).element_type)
+    with Context():
+        raw = Attribute.parse("vector<4xf32>")
+        # CHECK: attr: vector<4xf32>
+        print("attr:", raw)
+        type_attr = TypeAttr(raw)
+        # CHECK: f32
+        print(ShapedType(type_attr.value).element_type)
 
 
 # CHECK-LABEL: TEST: testArrayAttr
 @run
 def testArrayAttr():
-  with Context():
-    raw = Attribute.parse("[42, true, vector<4xf32>]")
-  # CHECK: attr: [42, true, vector<4xf32>]
-  print("raw attr:", raw)
-  # CHECK: - 42
-  # CHECK: - true
-  # CHECK: - vector<4xf32>
-  for attr in ArrayAttr(raw):
-    print("- ", attr)
-
-  with Context():
-    intAttr = Attribute.parse("42")
-    vecAttr = Attribute.parse("vector<4xf32>")
-    boolAttr = BoolAttr.get(True)
-    raw = ArrayAttr.get([vecAttr, boolAttr, intAttr])
-  # CHECK: attr: [vector<4xf32>, true, 42]
-  print("raw attr:", raw)
-  # CHECK: - vector<4xf32>
-  # CHECK: - true
-  # CHECK: - 42
-  arr = ArrayAttr(raw)
-  for attr in arr:
-    print("- ", attr)
-  # CHECK: attr[0]: vector<4xf32>
-  print("attr[0]:", arr[0])
-  # CHECK: attr[1]: true
-  print("attr[1]:", arr[1])
-  # CHECK: attr[2]: 42
-  print("attr[2]:", arr[2])
-  try:
-    print("attr[3]:", arr[3])
-  except IndexError as e:
-    # CHECK: Error: ArrayAttribute index out of range
-    print("Error: ", e)
-  with Context():
+    with Context():
+        raw = Attribute.parse("[42, true, vector<4xf32>]")
+    # CHECK: attr: [42, true, vector<4xf32>]
+    print("raw attr:", raw)
+    # CHECK: - 42
+    # CHECK: - true
+    # CHECK: - vector<4xf32>
+    for attr in ArrayAttr(raw):
+        print("- ", attr)
+
+    with Context():
+        intAttr = Attribute.parse("42")
+        vecAttr = Attribute.parse("vector<4xf32>")
+        boolAttr = BoolAttr.get(True)
+        raw = ArrayAttr.get([vecAttr, boolAttr, intAttr])
+    # CHECK: attr: [vector<4xf32>, true, 42]
+    print("raw attr:", raw)
+    # CHECK: - vector<4xf32>
+    # CHECK: - true
+    # CHECK: - 42
+    arr = ArrayAttr(raw)
+    for attr in arr:
+        print("- ", attr)
+    # CHECK: attr[0]: vector<4xf32>
+    print("attr[0]:", arr[0])
+    # CHECK: attr[1]: true
+    print("attr[1]:", arr[1])
+    # CHECK: attr[2]: 42
+    print("attr[2]:", arr[2])
     try:
-      ArrayAttr.get([None])
-    except RuntimeError as e:
-      # CHECK: Error: Invalid attribute (None?) when attempting to create an ArrayAttribute
-      print("Error: ", e)
-    try:
-      ArrayAttr.get([42])
-    except RuntimeError as e:
-      # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
-      print("Error: ", e)
-
-  with Context():
-    array = ArrayAttr.get([StringAttr.get("a"), StringAttr.get("b")])
-    array = array + [StringAttr.get("c")]
-    # CHECK: concat: ["a", "b", "c"]
-    print("concat: ", array)
+        print("attr[3]:", arr[3])
+    except IndexError as e:
+        # CHECK: Error: ArrayAttribute index out of range
+        print("Error: ", e)
+    with Context():
+        try:
+            ArrayAttr.get([None])
+        except RuntimeError as e:
+            # CHECK: Error: Invalid attribute (None?) when attempting to create an ArrayAttribute
+            print("Error: ", e)
+        try:
+            ArrayAttr.get([42])
+        except RuntimeError as e:
+            # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
+            print("Error: ", e)
+
+    with Context():
+        array = ArrayAttr.get([StringAttr.get("a"), StringAttr.get("b")])
+        array = array + [StringAttr.get("c")]
+        # CHECK: concat: ["a", "b", "c"]
+        print("concat: ", array)
 
 
 # CHECK-LABEL: TEST: testStridedLayoutAttr
 @run
 def testStridedLayoutAttr():
-  with Context():
-    attr = StridedLayoutAttr.get(42, [5, 7, 13])
-    # CHECK: strided<[5, 7, 13], offset: 42>
-    print(attr)
-    # CHECK: 42
-    print(attr.offset)
-    # CHECK: 3
-    print(len(attr.strides))
-    # CHECK: 5
-    print(attr.strides[0])
-    # CHECK: 7
-    print(attr.strides[1])
-    # CHECK: 13
-    print(attr.strides[2])
-
-    attr = StridedLayoutAttr.get_fully_dynamic(3)
-    dynamic = ShapedType.get_dynamic_stride_or_offset()
-    # CHECK: strided<[?, ?, ?], offset: ?>
-    print(attr)
-    # CHECK: offset is dynamic: True
-    print(f"offset is dynamic: {attr.offset == dynamic}")
-    # CHECK: rank: 3
-    print(f"rank: {len(attr.strides)}")
-    # CHECK: strides are dynamic: [True, True, True]
-    print(f"strides are dynamic: {[s == dynamic for s in attr.strides]}")
+    with Context():
+        attr = StridedLayoutAttr.get(42, [5, 7, 13])
+        # CHECK: strided<[5, 7, 13], offset: 42>
+        print(attr)
+        # CHECK: 42
+        print(attr.offset)
+        # CHECK: 3
+        print(len(attr.strides))
+        # CHECK: 5
+        print(attr.strides[0])
+        # CHECK: 7
+        print(attr.strides[1])
+        # CHECK: 13
+        print(attr.strides[2])
+
+        attr = StridedLayoutAttr.get_fully_dynamic(3)
+        dynamic = ShapedType.get_dynamic_stride_or_offset()
+        # CHECK: strided<[?, ?, ?], offset: ?>
+        print(attr)
+        # CHECK: offset is dynamic: True
+        print(f"offset is dynamic: {attr.offset == dynamic}")
+        # CHECK: rank: 3
+        print(f"rank: {len(attr.strides)}")
+        # CHECK: strides are dynamic: [True, True, True]
+        print(f"strides are dynamic: {[s == dynamic for s in attr.strides]}")
index e929d79..8b4d946 100644
@@ -10,11 +10,11 @@ from mlir.dialects import func
 
 
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  gc.collect()
-  assert Context._get_live_count() == 0
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
+    return f
 
 
 # CHECK-LABEL: TEST: testBlockCreation
@@ -26,60 +26,66 @@ def run(f):
 # CHECK:   return
 @run
 def testBlockCreation():
-  with Context() as ctx, Location.unknown():
-    module = builtin.ModuleOp()
-    with InsertionPoint(module.body):
-      f_type = FunctionType.get(
-          [IntegerType.get_signless(32),
-           IntegerType.get_signless(16)], [])
-      f_op = func.FuncOp("test", f_type)
-      entry_block = f_op.add_entry_block([Location.name("arg0"), Location.name("arg1")])
-      i32_arg, i16_arg = entry_block.arguments
-      successor_block = entry_block.create_after(i32_arg.type, arg_locs=[Location.name("successor")])
-      with InsertionPoint(successor_block) as successor_ip:
-        assert successor_ip.block == successor_block
-        func.ReturnOp([])
-      middle_block = successor_block.create_before(i16_arg.type, arg_locs=[Location.name("middle")])
-
-      with InsertionPoint(entry_block) as entry_ip:
-        assert entry_ip.block == entry_block
-        cf.BranchOp([i16_arg], dest=middle_block)
-
-      with InsertionPoint(middle_block) as middle_ip:
-        assert middle_ip.block == middle_block
-        cf.BranchOp([i32_arg], dest=successor_block)
-    module.print(enable_debug_info=True)
-    # Ensure region back references are coherent.
-    assert entry_block.region == middle_block.region == successor_block.region
+    with Context() as ctx, Location.unknown():
+        module = builtin.ModuleOp()
+        with InsertionPoint(module.body):
+            f_type = FunctionType.get(
+                [IntegerType.get_signless(32), IntegerType.get_signless(16)], []
+            )
+            f_op = func.FuncOp("test", f_type)
+            entry_block = f_op.add_entry_block(
+                [Location.name("arg0"), Location.name("arg1")]
+            )
+            i32_arg, i16_arg = entry_block.arguments
+            successor_block = entry_block.create_after(
+                i32_arg.type, arg_locs=[Location.name("successor")]
+            )
+            with InsertionPoint(successor_block) as successor_ip:
+                assert successor_ip.block == successor_block
+                func.ReturnOp([])
+            middle_block = successor_block.create_before(
+                i16_arg.type, arg_locs=[Location.name("middle")]
+            )
+
+            with InsertionPoint(entry_block) as entry_ip:
+                assert entry_ip.block == entry_block
+                cf.BranchOp([i16_arg], dest=middle_block)
+
+            with InsertionPoint(middle_block) as middle_ip:
+                assert middle_ip.block == middle_block
+                cf.BranchOp([i32_arg], dest=successor_block)
+        module.print(enable_debug_info=True)
+        # Ensure region back references are coherent.
+        assert entry_block.region == middle_block.region == successor_block.region
 
 
 # CHECK-LABEL: TEST: testBlockCreationArgLocs
 @run
 def testBlockCreationArgLocs():
-  with Context() as ctx:
-    ctx.allow_unregistered_dialects = True
-    f32 = F32Type.get()
-    op = Operation.create("test", regions=1, loc=Location.unknown())
-    blocks = op.regions[0].blocks
-
-    with Location.name("default_loc"):
-      blocks.append(f32)
-    blocks.append()
-    # CHECK:      ^bb0(%{{.+}}: f32 loc("default_loc")):
-    # CHECK-NEXT: ^bb1:
-    op.print(enable_debug_info=True)
-
-    try:
-      blocks.append(f32)
-    except RuntimeError as err:
-      # CHECK: Missing loc: An MLIR function requires a Location but none was provided
-      print("Missing loc:", err)
-
-    try:
-      blocks.append(f32, f32, arg_locs=[Location.unknown()])
-    except ValueError as err:
-      # CHECK: Wrong loc count: Expected 2 locations, got: 1
-      print("Wrong loc count:", err)
+    with Context() as ctx:
+        ctx.allow_unregistered_dialects = True
+        f32 = F32Type.get()
+        op = Operation.create("test", regions=1, loc=Location.unknown())
+        blocks = op.regions[0].blocks
+
+        with Location.name("default_loc"):
+            blocks.append(f32)
+        blocks.append()
+        # CHECK:      ^bb0(%{{.+}}: f32 loc("default_loc")):
+        # CHECK-NEXT: ^bb1:
+        op.print(enable_debug_info=True)
+
+        try:
+            blocks.append(f32)
+        except RuntimeError as err:
+            # CHECK: Missing loc: An MLIR function requires a Location but none was provided
+            print("Missing loc:", err)
+
+        try:
+            blocks.append(f32, f32, arg_locs=[Location.unknown()])
+        except ValueError as err:
+            # CHECK: Wrong loc count: Expected 2 locations, got: 1
+            print("Wrong loc count:", err)
 
 
 # CHECK-LABEL: TEST: testFirstBlockCreation
@@ -87,19 +93,20 @@ def testBlockCreationArgLocs():
 # CHECK:   return
 @run
 def testFirstBlockCreation():
-  with Context() as ctx, Location.unknown():
-    module = builtin.ModuleOp()
-    f32 = F32Type.get()
-    with InsertionPoint(module.body):
-      f = func.FuncOp("test", ([f32], []))
-      entry_block = Block.create_at_start(f.operation.regions[0],
-                                          [f32], [Location.name("arg_loc")])
-      with InsertionPoint(entry_block):
-        func.ReturnOp([])
-
-    module.print(enable_debug_info=True)
-    assert module.verify()
-    assert f.body.blocks[0] == entry_block
+    with Context() as ctx, Location.unknown():
+        module = builtin.ModuleOp()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            f = func.FuncOp("test", ([f32], []))
+            entry_block = Block.create_at_start(
+                f.operation.regions[0], [f32], [Location.name("arg_loc")]
+            )
+            with InsertionPoint(entry_block):
+                func.ReturnOp([])
+
+        module.print(enable_debug_info=True)
+        assert module.verify()
+        assert f.body.blocks[0] == entry_block
 
 
 # CHECK-LABEL: TEST: testBlockMove
@@ -109,32 +116,32 @@ def testFirstBlockCreation():
 # CHECK:  }) : () -> f32
 @run
 def testBlockMove():
-  with Context() as ctx, Location.unknown():
-    ctx.allow_unregistered_dialects = True
-    module = Module.create()
-    f32 = F32Type.get()
-    with InsertionPoint(module.body):
-      dummy = Operation.create("dummy", regions=1)
-      block = Block.create_at_start(dummy.operation.regions[0], [f32])
-      with InsertionPoint(block):
-        ret_op = Operation.create("ret", operands=[block.arguments[0]])
-      realop = Operation.create("realop",
-                                results=[r.type for r in ret_op.operands],
-                                regions=1)
-      block.append_to(realop.operation.regions[0])
-      dummy.operation.erase()
-    print(module)
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            dummy = Operation.create("dummy", regions=1)
+            block = Block.create_at_start(dummy.operation.regions[0], [f32])
+            with InsertionPoint(block):
+                ret_op = Operation.create("ret", operands=[block.arguments[0]])
+            realop = Operation.create(
+                "realop", results=[r.type for r in ret_op.operands], regions=1
+            )
+            block.append_to(realop.operation.regions[0])
+            dummy.operation.erase()
+        print(module)
 
 
 # CHECK-LABEL: TEST: testBlockHash
 @run
 def testBlockHash():
-  with Context() as ctx, Location.unknown():
-    ctx.allow_unregistered_dialects = True
-    module = Module.create()
-    f32 = F32Type.get()
-    with InsertionPoint(module.body):
-      dummy = Operation.create("dummy", regions=1)
-      block1 = Block.create_at_start(dummy.operation.regions[0], [f32])
-      block2 = Block.create_at_start(dummy.operation.regions[0], [f32])
-      assert hash(block1) != hash(block2)
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            dummy = Operation.create("dummy", regions=1)
+            block1 = Block.create_at_start(dummy.operation.regions[0], [f32])
+            block2 = Block.create_at_start(dummy.operation.regions[0], [f32])
+            assert hash(block1) != hash(block2)
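The block APIs tested above can also be driven directly. A minimal sketch of creating an entry block in a fresh region, assuming the mlir Python bindings are importable and using an unregistered "dummy" op as the tests do:

from mlir.ir import Context, Location, Module, InsertionPoint, Operation, Block, F32Type

with Context() as ctx, Location.unknown():
    ctx.allow_unregistered_dialects = True
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
        # Create an op with one empty region, then give that region an entry block.
        dummy = Operation.create("dummy", regions=1)
        block = Block.create_at_start(dummy.regions[0], [f32])
        assert block.arguments[0].type == f32
    # The module now contains the unregistered op with its new entry block.
    print(module)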
index 19e21ff..fc484a5 100644
@@ -5,246 +5,246 @@ from mlir.ir import *
 
 
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  gc.collect()
-  assert Context._get_live_count() == 0
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
+    return f
 
 
 # CHECK-LABEL: TEST: testParsePrint
 @run
 def testParsePrint():
-  ctx = Context()
-  t = Type.parse("i32", ctx)
-  assert t.context is ctx
-  ctx = None
-  gc.collect()
-  # CHECK: i32
-  print(str(t))
-  # CHECK: Type(i32)
-  print(repr(t))
+    ctx = Context()
+    t = Type.parse("i32", ctx)
+    assert t.context is ctx
+    ctx = None
+    gc.collect()
+    # CHECK: i32
+    print(str(t))
+    # CHECK: Type(i32)
+    print(repr(t))
 
 
 # CHECK-LABEL: TEST: testParseError
 @run
 def testParseError():
-  ctx = Context()
-  try:
-    t = Type.parse("BAD_TYPE_DOES_NOT_EXIST", ctx)
-  except MLIRError as e:
-    # CHECK: testParseError: <
-    # CHECK:   Unable to parse type:
-    # CHECK:   error: "BAD_TYPE_DOES_NOT_EXIST":1:1: expected non-function type
-    # CHECK: >
-    print(f"testParseError: <{e}>")
-  else:
-    print("Exception not produced")
+    ctx = Context()
+    try:
+        t = Type.parse("BAD_TYPE_DOES_NOT_EXIST", ctx)
+    except MLIRError as e:
+        # CHECK: testParseError: <
+        # CHECK:   Unable to parse type:
+        # CHECK:   error: "BAD_TYPE_DOES_NOT_EXIST":1:1: expected non-function type
+        # CHECK: >
+        print(f"testParseError: <{e}>")
+    else:
+        print("Exception not produced")
 
 
 # CHECK-LABEL: TEST: testTypeEq
 @run
 def testTypeEq():
-  ctx = Context()
-  t1 = Type.parse("i32", ctx)
-  t2 = Type.parse("f32", ctx)
-  t3 = Type.parse("i32", ctx)
-  # CHECK: t1 == t1: True
-  print("t1 == t1:", t1 == t1)
-  # CHECK: t1 == t2: False
-  print("t1 == t2:", t1 == t2)
-  # CHECK: t1 == t3: True
-  print("t1 == t3:", t1 == t3)
-  # CHECK: t1 == None: False
-  print("t1 == None:", t1 == None)
+    ctx = Context()
+    t1 = Type.parse("i32", ctx)
+    t2 = Type.parse("f32", ctx)
+    t3 = Type.parse("i32", ctx)
+    # CHECK: t1 == t1: True
+    print("t1 == t1:", t1 == t1)
+    # CHECK: t1 == t2: False
+    print("t1 == t2:", t1 == t2)
+    # CHECK: t1 == t3: True
+    print("t1 == t3:", t1 == t3)
+    # CHECK: t1 == None: False
+    print("t1 == None:", t1 == None)
 
 
 # CHECK-LABEL: TEST: testTypeHash
 @run
 def testTypeHash():
-  ctx = Context()
-  t1 = Type.parse("i32", ctx)
-  t2 = Type.parse("f32", ctx)
-  t3 = Type.parse("i32", ctx)
+    ctx = Context()
+    t1 = Type.parse("i32", ctx)
+    t2 = Type.parse("f32", ctx)
+    t3 = Type.parse("i32", ctx)
 
-  # CHECK: hash(t1) == hash(t3): True
-  print("hash(t1) == hash(t3):", t1.__hash__() == t3.__hash__())
+    # CHECK: hash(t1) == hash(t3): True
+    print("hash(t1) == hash(t3):", t1.__hash__() == t3.__hash__())
 
-  s = set()
-  s.add(t1)
-  s.add(t2)
-  s.add(t3)
-  # CHECK: len(s): 2
-  print("len(s): ", len(s))
+    s = set()
+    s.add(t1)
+    s.add(t2)
+    s.add(t3)
+    # CHECK: len(s): 2
+    print("len(s): ", len(s))
 
 
 # CHECK-LABEL: TEST: testTypeCast
 @run
 def testTypeCast():
-  ctx = Context()
-  t1 = Type.parse("i32", ctx)
-  t2 = Type(t1)
-  # CHECK: t1 == t2: True
-  print("t1 == t2:", t1 == t2)
+    ctx = Context()
+    t1 = Type.parse("i32", ctx)
+    t2 = Type(t1)
+    # CHECK: t1 == t2: True
+    print("t1 == t2:", t1 == t2)
 
 
 # CHECK-LABEL: TEST: testTypeIsInstance
 @run
 def testTypeIsInstance():
-  ctx = Context()
-  t1 = Type.parse("i32", ctx)
-  t2 = Type.parse("f32", ctx)
-  # CHECK: True
-  print(IntegerType.isinstance(t1))
-  # CHECK: False
-  print(F32Type.isinstance(t1))
-  # CHECK: True
-  print(F32Type.isinstance(t2))
+    ctx = Context()
+    t1 = Type.parse("i32", ctx)
+    t2 = Type.parse("f32", ctx)
+    # CHECK: True
+    print(IntegerType.isinstance(t1))
+    # CHECK: False
+    print(F32Type.isinstance(t1))
+    # CHECK: True
+    print(F32Type.isinstance(t2))
 
 
 # CHECK-LABEL: TEST: testTypeEqDoesNotRaise
 @run
 def testTypeEqDoesNotRaise():
-  ctx = Context()
-  t1 = Type.parse("i32", ctx)
-  not_a_type = "foo"
-  # CHECK: False
-  print(t1 == not_a_type)
-  # CHECK: False
-  print(t1 == None)
-  # CHECK: True
-  print(t1 != None)
+    ctx = Context()
+    t1 = Type.parse("i32", ctx)
+    not_a_type = "foo"
+    # CHECK: False
+    print(t1 == not_a_type)
+    # CHECK: False
+    print(t1 == None)
+    # CHECK: True
+    print(t1 != None)
 
 
 # CHECK-LABEL: TEST: testTypeCapsule
 @run
 def testTypeCapsule():
-  with Context() as ctx:
-    t1 = Type.parse("i32", ctx)
-  # CHECK: mlir.ir.Type._CAPIPtr
-  type_capsule = t1._CAPIPtr
-  print(type_capsule)
-  t2 = Type._CAPICreate(type_capsule)
-  assert t2 == t1
-  assert t2.context is ctx
+    with Context() as ctx:
+        t1 = Type.parse("i32", ctx)
+    # CHECK: mlir.ir.Type._CAPIPtr
+    type_capsule = t1._CAPIPtr
+    print(type_capsule)
+    t2 = Type._CAPICreate(type_capsule)
+    assert t2 == t1
+    assert t2.context is ctx
 
 
 # CHECK-LABEL: TEST: testStandardTypeCasts
 @run
 def testStandardTypeCasts():
-  ctx = Context()
-  t1 = Type.parse("i32", ctx)
-  tint = IntegerType(t1)
-  tself = IntegerType(tint)
-  # CHECK: Type(i32)
-  print(repr(tint))
-  try:
-    tillegal = IntegerType(Type.parse("f32", ctx))
-  except ValueError as e:
-    # CHECK: ValueError: Cannot cast type to IntegerType (from Type(f32))
-    print("ValueError:", e)
-  else:
-    print("Exception not produced")
+    ctx = Context()
+    t1 = Type.parse("i32", ctx)
+    tint = IntegerType(t1)
+    tself = IntegerType(tint)
+    # CHECK: Type(i32)
+    print(repr(tint))
+    try:
+        tillegal = IntegerType(Type.parse("f32", ctx))
+    except ValueError as e:
+        # CHECK: ValueError: Cannot cast type to IntegerType (from Type(f32))
+        print("ValueError:", e)
+    else:
+        print("Exception not produced")
 
 
 # CHECK-LABEL: TEST: testIntegerType
 @run
 def testIntegerType():
-  with Context() as ctx:
-    i32 = IntegerType(Type.parse("i32"))
-    # CHECK: i32 width: 32
-    print("i32 width:", i32.width)
-    # CHECK: i32 signless: True
-    print("i32 signless:", i32.is_signless)
-    # CHECK: i32 signed: False
-    print("i32 signed:", i32.is_signed)
-    # CHECK: i32 unsigned: False
-    print("i32 unsigned:", i32.is_unsigned)
-
-    s32 = IntegerType(Type.parse("si32"))
-    # CHECK: s32 signless: False
-    print("s32 signless:", s32.is_signless)
-    # CHECK: s32 signed: True
-    print("s32 signed:", s32.is_signed)
-    # CHECK: s32 unsigned: False
-    print("s32 unsigned:", s32.is_unsigned)
-
-    u32 = IntegerType(Type.parse("ui32"))
-    # CHECK: u32 signless: False
-    print("u32 signless:", u32.is_signless)
-    # CHECK: u32 signed: False
-    print("u32 signed:", u32.is_signed)
-    # CHECK: u32 unsigned: True
-    print("u32 unsigned:", u32.is_unsigned)
-
-    # CHECK: signless: i16
-    print("signless:", IntegerType.get_signless(16))
-    # CHECK: signed: si8
-    print("signed:", IntegerType.get_signed(8))
-    # CHECK: unsigned: ui64
-    print("unsigned:", IntegerType.get_unsigned(64))
+    with Context() as ctx:
+        i32 = IntegerType(Type.parse("i32"))
+        # CHECK: i32 width: 32
+        print("i32 width:", i32.width)
+        # CHECK: i32 signless: True
+        print("i32 signless:", i32.is_signless)
+        # CHECK: i32 signed: False
+        print("i32 signed:", i32.is_signed)
+        # CHECK: i32 unsigned: False
+        print("i32 unsigned:", i32.is_unsigned)
+
+        s32 = IntegerType(Type.parse("si32"))
+        # CHECK: s32 signless: False
+        print("s32 signless:", s32.is_signless)
+        # CHECK: s32 signed: True
+        print("s32 signed:", s32.is_signed)
+        # CHECK: s32 unsigned: False
+        print("s32 unsigned:", s32.is_unsigned)
+
+        u32 = IntegerType(Type.parse("ui32"))
+        # CHECK: u32 signless: False
+        print("u32 signless:", u32.is_signless)
+        # CHECK: u32 signed: False
+        print("u32 signed:", u32.is_signed)
+        # CHECK: u32 unsigned: True
+        print("u32 unsigned:", u32.is_unsigned)
+
+        # CHECK: signless: i16
+        print("signless:", IntegerType.get_signless(16))
+        # CHECK: signed: si8
+        print("signed:", IntegerType.get_signed(8))
+        # CHECK: unsigned: ui64
+        print("unsigned:", IntegerType.get_unsigned(64))
 
 
 # CHECK-LABEL: TEST: testIndexType
 @run
 def testIndexType():
-  with Context() as ctx:
-    # CHECK: index type: index
-    print("index type:", IndexType.get())
+    with Context() as ctx:
+        # CHECK: index type: index
+        print("index type:", IndexType.get())
 
 
 # CHECK-LABEL: TEST: testFloatType
 @run
 def testFloatType():
-  with Context():
-    # CHECK: float: f8E4M3FN
-    print("float:", Float8E4M3FNType.get())
-    # CHECK: float: f8E5M2
-    print("float:", Float8E5M2Type.get())
-    # CHECK: float: f8E5M2FNUZ
-    print("float:", Float8E5M2FNUZType.get())
-    # CHECK: float: f8E4M3FNUZ
-    print("float:", Float8E4M3FNUZType.get())
-    # CHECK: float: f8E4M3B11FNUZ
-    print("float:", Float8E4M3B11FNUZType.get())
-    # CHECK: float: bf16
-    print("float:", BF16Type.get())
-    # CHECK: float: f16
-    print("float:", F16Type.get())
-    # CHECK: float: f32
-    print("float:", F32Type.get())
-    # CHECK: float: f64
-    print("float:", F64Type.get())
+    with Context():
+        # CHECK: float: f8E4M3FN
+        print("float:", Float8E4M3FNType.get())
+        # CHECK: float: f8E5M2
+        print("float:", Float8E5M2Type.get())
+        # CHECK: float: f8E5M2FNUZ
+        print("float:", Float8E5M2FNUZType.get())
+        # CHECK: float: f8E4M3FNUZ
+        print("float:", Float8E4M3FNUZType.get())
+        # CHECK: float: f8E4M3B11FNUZ
+        print("float:", Float8E4M3B11FNUZType.get())
+        # CHECK: float: bf16
+        print("float:", BF16Type.get())
+        # CHECK: float: f16
+        print("float:", F16Type.get())
+        # CHECK: float: f32
+        print("float:", F32Type.get())
+        # CHECK: float: f64
+        print("float:", F64Type.get())
 
 
 # CHECK-LABEL: TEST: testNoneType
 @run
 def testNoneType():
-  with Context():
-    # CHECK: none type: none
-    print("none type:", NoneType.get())
+    with Context():
+        # CHECK: none type: none
+        print("none type:", NoneType.get())
 
 
 # CHECK-LABEL: TEST: testComplexType
 @run
 def testComplexType():
-  with Context() as ctx:
-    complex_i32 = ComplexType(Type.parse("complex<i32>"))
-    # CHECK: complex type element: i32
-    print("complex type element:", complex_i32.element_type)
+    with Context() as ctx:
+        complex_i32 = ComplexType(Type.parse("complex<i32>"))
+        # CHECK: complex type element: i32
+        print("complex type element:", complex_i32.element_type)
 
-    f32 = F32Type.get()
-    # CHECK: complex type: complex<f32>
-    print("complex type:", ComplexType.get(f32))
+        f32 = F32Type.get()
+        # CHECK: complex type: complex<f32>
+        print("complex type:", ComplexType.get(f32))
 
-    index = IndexType.get()
-    try:
-      complex_invalid = ComplexType.get(index)
-    except ValueError as e:
-      # CHECK: invalid 'Type(index)' and expected floating point or integer type.
-      print(e)
-    else:
-      print("Exception not produced")
+        index = IndexType.get()
+        try:
+            complex_invalid = ComplexType.get(index)
+        except ValueError as e:
+            # CHECK: invalid 'Type(index)' and expected floating point or integer type.
+            print(e)
+        else:
+            print("Exception not produced")
 
 
 # CHECK-LABEL: TEST: testConcreteShapedType
@@ -253,27 +253,26 @@ def testComplexType():
 # shaped type. The class hierarchy is preserved on the python side.
 @run
 def testConcreteShapedType():
-  with Context() as ctx:
-    vector = VectorType(Type.parse("vector<2x3xf32>"))
-    # CHECK: element type: f32
-    print("element type:", vector.element_type)
-    # CHECK: whether the given shaped type is ranked: True
-    print("whether the given shaped type is ranked:", vector.has_rank)
-    # CHECK: rank: 2
-    print("rank:", vector.rank)
-    # CHECK: whether the shaped type has a static shape: True
-    print("whether the shaped type has a static shape:",
-          vector.has_static_shape)
-    # CHECK: whether the dim-th dimension is dynamic: False
-    print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0))
-    # CHECK: dim size: 3
-    print("dim size:", vector.get_dim_size(1))
-    # CHECK: is_dynamic_size: False
-    print("is_dynamic_size:", vector.is_dynamic_size(3))
-    # CHECK: is_dynamic_stride_or_offset: False
-    print("is_dynamic_stride_or_offset:", vector.is_dynamic_stride_or_offset(1))
-    # CHECK: isinstance(ShapedType): True
-    print("isinstance(ShapedType):", isinstance(vector, ShapedType))
+    with Context() as ctx:
+        vector = VectorType(Type.parse("vector<2x3xf32>"))
+        # CHECK: element type: f32
+        print("element type:", vector.element_type)
+        # CHECK: whether the given shaped type is ranked: True
+        print("whether the given shaped type is ranked:", vector.has_rank)
+        # CHECK: rank: 2
+        print("rank:", vector.rank)
+        # CHECK: whether the shaped type has a static shape: True
+        print("whether the shaped type has a static shape:", vector.has_static_shape)
+        # CHECK: whether the dim-th dimension is dynamic: False
+        print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0))
+        # CHECK: dim size: 3
+        print("dim size:", vector.get_dim_size(1))
+        # CHECK: is_dynamic_size: False
+        print("is_dynamic_size:", vector.is_dynamic_size(3))
+        # CHECK: is_dynamic_stride_or_offset: False
+        print("is_dynamic_stride_or_offset:", vector.is_dynamic_stride_or_offset(1))
+        # CHECK: isinstance(ShapedType): True
+        print("isinstance(ShapedType):", isinstance(vector, ShapedType))
 
 
 # CHECK-LABEL: TEST: testAbstractShapedType
@@ -281,321 +280,322 @@ def testConcreteShapedType():
 # shaped type (using vector as an example).
 @run
 def testAbstractShapedType():
-  ctx = Context()
-  vector = ShapedType(Type.parse("vector<2x3xf32>", ctx))
-  # CHECK: element type: f32
-  print("element type:", vector.element_type)
+    ctx = Context()
+    vector = ShapedType(Type.parse("vector<2x3xf32>", ctx))
+    # CHECK: element type: f32
+    print("element type:", vector.element_type)
 
 
 # CHECK-LABEL: TEST: testVectorType
 @run
 def testVectorType():
-  with Context(), Location.unknown():
-    f32 = F32Type.get()
-    shape = [2, 3]
-    # CHECK: vector type: vector<2x3xf32>
-    print("vector type:", VectorType.get(shape, f32))
-
-    none = NoneType.get()
-    try:
-      vector_invalid = VectorType.get(shape, none)
-    except MLIRError as e:
-      # CHECK: Invalid type:
-      # CHECK: error: unknown: vector elements must be int/index/float type but got 'none'
-      print(e)
-    else:
-      print("Exception not produced")
+    with Context(), Location.unknown():
+        f32 = F32Type.get()
+        shape = [2, 3]
+        # CHECK: vector type: vector<2x3xf32>
+        print("vector type:", VectorType.get(shape, f32))
+
+        none = NoneType.get()
+        try:
+            vector_invalid = VectorType.get(shape, none)
+        except MLIRError as e:
+            # CHECK: Invalid type:
+            # CHECK: error: unknown: vector elements must be int/index/float type but got 'none'
+            print(e)
+        else:
+            print("Exception not produced")
 
 
 # CHECK-LABEL: TEST: testRankedTensorType
 @run
 def testRankedTensorType():
-  with Context(), Location.unknown():
-    f32 = F32Type.get()
-    shape = [2, 3]
-    loc = Location.unknown()
-    # CHECK: ranked tensor type: tensor<2x3xf32>
-    print("ranked tensor type:", RankedTensorType.get(shape, f32))
-
-    none = NoneType.get()
-    try:
-      tensor_invalid = RankedTensorType.get(shape, none)
-    except MLIRError as e:
-      # CHECK: Invalid type:
-      # CHECK: error: unknown: invalid tensor element type: 'none'
-      print(e)
-    else:
-      print("Exception not produced")
-
-    # Encoding should be None.
-    assert RankedTensorType.get(shape, f32).encoding is None
-
-    tensor = RankedTensorType.get(shape, f32)
-    assert tensor.shape == shape
+    with Context(), Location.unknown():
+        f32 = F32Type.get()
+        shape = [2, 3]
+        loc = Location.unknown()
+        # CHECK: ranked tensor type: tensor<2x3xf32>
+        print("ranked tensor type:", RankedTensorType.get(shape, f32))
+
+        none = NoneType.get()
+        try:
+            tensor_invalid = RankedTensorType.get(shape, none)
+        except MLIRError as e:
+            # CHECK: Invalid type:
+            # CHECK: error: unknown: invalid tensor element type: 'none'
+            print(e)
+        else:
+            print("Exception not produced")
+
+        # Encoding should be None.
+        assert RankedTensorType.get(shape, f32).encoding is None
+
+        tensor = RankedTensorType.get(shape, f32)
+        assert tensor.shape == shape
 
 
 # CHECK-LABEL: TEST: testUnrankedTensorType
 @run
 def testUnrankedTensorType():
-  with Context(), Location.unknown():
-    f32 = F32Type.get()
-    loc = Location.unknown()
-    unranked_tensor = UnrankedTensorType.get(f32)
-    # CHECK: unranked tensor type: tensor<*xf32>
-    print("unranked tensor type:", unranked_tensor)
-    try:
-      invalid_rank = unranked_tensor.rank
-    except ValueError as e:
-      # CHECK: calling this method requires that the type has a rank.
-      print(e)
-    else:
-      print("Exception not produced")
-    try:
-      invalid_is_dynamic_dim = unranked_tensor.is_dynamic_dim(0)
-    except ValueError as e:
-      # CHECK: calling this method requires that the type has a rank.
-      print(e)
-    else:
-      print("Exception not produced")
-    try:
-      invalid_get_dim_size = unranked_tensor.get_dim_size(1)
-    except ValueError as e:
-      # CHECK: calling this method requires that the type has a rank.
-      print(e)
-    else:
-      print("Exception not produced")
-
-    none = NoneType.get()
-    try:
-      tensor_invalid = UnrankedTensorType.get(none)
-    except MLIRError as e:
-      # CHECK: Invalid type:
-      # CHECK: error: unknown: invalid tensor element type: 'none'
-      print(e)
-    else:
-      print("Exception not produced")
+    with Context(), Location.unknown():
+        f32 = F32Type.get()
+        loc = Location.unknown()
+        unranked_tensor = UnrankedTensorType.get(f32)
+        # CHECK: unranked tensor type: tensor<*xf32>
+        print("unranked tensor type:", unranked_tensor)
+        try:
+            invalid_rank = unranked_tensor.rank
+        except ValueError as e:
+            # CHECK: calling this method requires that the type has a rank.
+            print(e)
+        else:
+            print("Exception not produced")
+        try:
+            invalid_is_dynamic_dim = unranked_tensor.is_dynamic_dim(0)
+        except ValueError as e:
+            # CHECK: calling this method requires that the type has a rank.
+            print(e)
+        else:
+            print("Exception not produced")
+        try:
+            invalid_get_dim_size = unranked_tensor.get_dim_size(1)
+        except ValueError as e:
+            # CHECK: calling this method requires that the type has a rank.
+            print(e)
+        else:
+            print("Exception not produced")
+
+        none = NoneType.get()
+        try:
+            tensor_invalid = UnrankedTensorType.get(none)
+        except MLIRError as e:
+            # CHECK: Invalid type:
+            # CHECK: error: unknown: invalid tensor element type: 'none'
+            print(e)
+        else:
+            print("Exception not produced")
 
 
 # CHECK-LABEL: TEST: testMemRefType
 @run
 def testMemRefType():
-  with Context(), Location.unknown():
-    f32 = F32Type.get()
-    shape = [2, 3]
-    loc = Location.unknown()
-    memref = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
-    # CHECK: memref type: memref<2x3xf32, 2>
-    print("memref type:", memref)
-    # CHECK: memref layout: affine_map<(d0, d1) -> (d0, d1)>
-    print("memref layout:", memref.layout)
-    # CHECK: memref affine map: (d0, d1) -> (d0, d1)
-    print("memref affine map:", memref.affine_map)
-    # CHECK: memory space: 2
-    print("memory space:", memref.memory_space)
-
-    layout = AffineMapAttr.get(AffineMap.get_permutation([1, 0]))
-    memref_layout = MemRefType.get(shape, f32, layout=layout)
-    # CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
-    print("memref type:", memref_layout)
-    # CHECK: memref layout: affine_map<(d0, d1) -> (d1, d0)>
-    print("memref layout:", memref_layout.layout)
-    # CHECK: memref affine map: (d0, d1) -> (d1, d0)
-    print("memref affine map:", memref_layout.affine_map)
-    # CHECK: memory space: <<NULL ATTRIBUTE>>
-    print("memory space:", memref_layout.memory_space)
-
-    none = NoneType.get()
-    try:
-      memref_invalid = MemRefType.get(shape, none)
-    except MLIRError as e:
-      # CHECK: Invalid type:
-      # CHECK: error: unknown: invalid memref element type
-      print(e)
-    else:
-      print("Exception not produced")
-
-    assert memref.shape == shape
+    with Context(), Location.unknown():
+        f32 = F32Type.get()
+        shape = [2, 3]
+        loc = Location.unknown()
+        memref = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
+        # CHECK: memref type: memref<2x3xf32, 2>
+        print("memref type:", memref)
+        # CHECK: memref layout: affine_map<(d0, d1) -> (d0, d1)>
+        print("memref layout:", memref.layout)
+        # CHECK: memref affine map: (d0, d1) -> (d0, d1)
+        print("memref affine map:", memref.affine_map)
+        # CHECK: memory space: 2
+        print("memory space:", memref.memory_space)
+
+        layout = AffineMapAttr.get(AffineMap.get_permutation([1, 0]))
+        memref_layout = MemRefType.get(shape, f32, layout=layout)
+        # CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
+        print("memref type:", memref_layout)
+        # CHECK: memref layout: affine_map<(d0, d1) -> (d1, d0)>
+        print("memref layout:", memref_layout.layout)
+        # CHECK: memref affine map: (d0, d1) -> (d1, d0)
+        print("memref affine map:", memref_layout.affine_map)
+        # CHECK: memory space: <<NULL ATTRIBUTE>>
+        print("memory space:", memref_layout.memory_space)
+
+        none = NoneType.get()
+        try:
+            memref_invalid = MemRefType.get(shape, none)
+        except MLIRError as e:
+            # CHECK: Invalid type:
+            # CHECK: error: unknown: invalid memref element type
+            print(e)
+        else:
+            print("Exception not produced")
+
+        assert memref.shape == shape
 
 
 # CHECK-LABEL: TEST: testUnrankedMemRefType
 @run
 def testUnrankedMemRefType():
-  with Context(), Location.unknown():
-    f32 = F32Type.get()
-    loc = Location.unknown()
-    unranked_memref = UnrankedMemRefType.get(f32, Attribute.parse("2"))
-    # CHECK: unranked memref type: memref<*xf32, 2>
-    print("unranked memref type:", unranked_memref)
-    try:
-      invalid_rank = unranked_memref.rank
-    except ValueError as e:
-      # CHECK: calling this method requires that the type has a rank.
-      print(e)
-    else:
-      print("Exception not produced")
-    try:
-      invalid_is_dynamic_dim = unranked_memref.is_dynamic_dim(0)
-    except ValueError as e:
-      # CHECK: calling this method requires that the type has a rank.
-      print(e)
-    else:
-      print("Exception not produced")
-    try:
-      invalid_get_dim_size = unranked_memref.get_dim_size(1)
-    except ValueError as e:
-      # CHECK: calling this method requires that the type has a rank.
-      print(e)
-    else:
-      print("Exception not produced")
-
-    none = NoneType.get()
-    try:
-      memref_invalid = UnrankedMemRefType.get(none, Attribute.parse("2"))
-    except MLIRError as e:
-      # CHECK: Invalid type:
-      # CHECK: error: unknown: invalid memref element type
-      print(e)
-    else:
-      print("Exception not produced")
+    with Context(), Location.unknown():
+        f32 = F32Type.get()
+        loc = Location.unknown()
+        unranked_memref = UnrankedMemRefType.get(f32, Attribute.parse("2"))
+        # CHECK: unranked memref type: memref<*xf32, 2>
+        print("unranked memref type:", unranked_memref)
+        try:
+            invalid_rank = unranked_memref.rank
+        except ValueError as e:
+            # CHECK: calling this method requires that the type has a rank.
+            print(e)
+        else:
+            print("Exception not produced")
+        try:
+            invalid_is_dynamic_dim = unranked_memref.is_dynamic_dim(0)
+        except ValueError as e:
+            # CHECK: calling this method requires that the type has a rank.
+            print(e)
+        else:
+            print("Exception not produced")
+        try:
+            invalid_get_dim_size = unranked_memref.get_dim_size(1)
+        except ValueError as e:
+            # CHECK: calling this method requires that the type has a rank.
+            print(e)
+        else:
+            print("Exception not produced")
+
+        none = NoneType.get()
+        try:
+            memref_invalid = UnrankedMemRefType.get(none, Attribute.parse("2"))
+        except MLIRError as e:
+            # CHECK: Invalid type:
+            # CHECK: error: unknown: invalid memref element type
+            print(e)
+        else:
+            print("Exception not produced")
 
 
 # CHECK-LABEL: TEST: testTupleType
 @run
 def testTupleType():
-  with Context() as ctx:
-    i32 = IntegerType(Type.parse("i32"))
-    f32 = F32Type.get()
-    vector = VectorType(Type.parse("vector<2x3xf32>"))
-    l = [i32, f32, vector]
-    tuple_type = TupleType.get_tuple(l)
-    # CHECK: tuple type: tuple<i32, f32, vector<2x3xf32>>
-    print("tuple type:", tuple_type)
-    # CHECK: number of types: 3
-    print("number of types:", tuple_type.num_types)
-    # CHECK: pos-th type in the tuple type: f32
-    print("pos-th type in the tuple type:", tuple_type.get_type(1))
+    with Context() as ctx:
+        i32 = IntegerType(Type.parse("i32"))
+        f32 = F32Type.get()
+        vector = VectorType(Type.parse("vector<2x3xf32>"))
+        l = [i32, f32, vector]
+        tuple_type = TupleType.get_tuple(l)
+        # CHECK: tuple type: tuple<i32, f32, vector<2x3xf32>>
+        print("tuple type:", tuple_type)
+        # CHECK: number of types: 3
+        print("number of types:", tuple_type.num_types)
+        # CHECK: pos-th type in the tuple type: f32
+        print("pos-th type in the tuple type:", tuple_type.get_type(1))
 
 
 # CHECK-LABEL: TEST: testFunctionType
 @run
 def testFunctionType():
-  with Context() as ctx:
-    input_types = [IntegerType.get_signless(32), IntegerType.get_signless(16)]
-    result_types = [IndexType.get()]
-    func = FunctionType.get(input_types, result_types)
-    # CHECK: INPUTS: [Type(i32), Type(i16)]
-    print("INPUTS:", func.inputs)
-    # CHECK: RESULTS: [Type(index)]
-    print("RESULTS:", func.results)
+    with Context() as ctx:
+        input_types = [IntegerType.get_signless(32), IntegerType.get_signless(16)]
+        result_types = [IndexType.get()]
+        func = FunctionType.get(input_types, result_types)
+        # CHECK: INPUTS: [Type(i32), Type(i16)]
+        print("INPUTS:", func.inputs)
+        # CHECK: RESULTS: [Type(index)]
+        print("RESULTS:", func.results)
 
 
 # CHECK-LABEL: TEST: testOpaqueType
 @run
 def testOpaqueType():
-  with Context() as ctx:
-    ctx.allow_unregistered_dialects = True
-    opaque = OpaqueType.get("dialect", "type")
-    # CHECK: opaque type: !dialect.type
-    print("opaque type:", opaque)
-    # CHECK: dialect namespace: dialect
-    print("dialect namespace:", opaque.dialect_namespace)
-    # CHECK: data: type
-    print("data:", opaque.data)
+    with Context() as ctx:
+        ctx.allow_unregistered_dialects = True
+        opaque = OpaqueType.get("dialect", "type")
+        # CHECK: opaque type: !dialect.type
+        print("opaque type:", opaque)
+        # CHECK: dialect namespace: dialect
+        print("dialect namespace:", opaque.dialect_namespace)
+        # CHECK: data: type
+        print("data:", opaque.data)
 
 
 # CHECK-LABEL: TEST: testShapedTypeConstants
 # Tests that ShapedType exposes magic value constants.
 @run
 def testShapedTypeConstants():
-  # CHECK: <class 'int'>
-  print(type(ShapedType.get_dynamic_size()))
-  # CHECK: <class 'int'>
-  print(type(ShapedType.get_dynamic_stride_or_offset()))
+    # CHECK: <class 'int'>
+    print(type(ShapedType.get_dynamic_size()))
+    # CHECK: <class 'int'>
+    print(type(ShapedType.get_dynamic_stride_or_offset()))
 
 
 # CHECK-LABEL: TEST: testTypeIDs
 @run
 def testTypeIDs():
-  with Context(), Location.unknown():
-    f32 = F32Type.get()
-
-    types = [
-      (IntegerType, IntegerType.get_signless(16)),
-      (IndexType, IndexType.get()),
-      (Float8E4M3FNType, Float8E4M3FNType.get()),
-      (Float8E5M2Type, Float8E5M2Type.get()),
-      (Float8E4M3FNUZType, Float8E4M3FNUZType.get()),
-      (Float8E4M3B11FNUZType, Float8E4M3B11FNUZType.get()),
-      (Float8E5M2FNUZType, Float8E5M2FNUZType.get()),
-      (BF16Type, BF16Type.get()),
-      (F16Type, F16Type.get()),
-      (F32Type, F32Type.get()),
-      (F64Type, F64Type.get()),
-      (NoneType, NoneType.get()),
-      (ComplexType, ComplexType.get(f32)),
-      (VectorType, VectorType.get([2, 3], f32)),
-      (RankedTensorType, RankedTensorType.get([2, 3], f32)),
-      (UnrankedTensorType, UnrankedTensorType.get(f32)),
-      (MemRefType, MemRefType.get([2, 3], f32)),
-      (UnrankedMemRefType, UnrankedMemRefType.get(f32, Attribute.parse("2"))),
-      (TupleType, TupleType.get_tuple([f32])),
-      (FunctionType, FunctionType.get([], [])),
-      (OpaqueType, OpaqueType.get("tensor", "bob")),
-    ]
-
-    # CHECK: IntegerType(i16)
-    # CHECK: IndexType(index)
-    # CHECK: Float8E4M3FNType(f8E4M3FN)
-    # CHECK: Float8E5M2Type(f8E5M2)
-    # CHECK: Float8E4M3FNUZType(f8E4M3FNUZ)
-    # CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ)
-    # CHECK: Float8E5M2FNUZType(f8E5M2FNUZ)
-    # CHECK: BF16Type(bf16)
-    # CHECK: F16Type(f16)
-    # CHECK: F32Type(f32)
-    # CHECK: F64Type(f64)
-    # CHECK: NoneType(none)
-    # CHECK: ComplexType(complex<f32>)
-    # CHECK: VectorType(vector<2x3xf32>)
-    # CHECK: RankedTensorType(tensor<2x3xf32>)
-    # CHECK: UnrankedTensorType(tensor<*xf32>)
-    # CHECK: MemRefType(memref<2x3xf32>)
-    # CHECK: UnrankedMemRefType(memref<*xf32, 2>)
-    # CHECK: TupleType(tuple<f32>)
-    # CHECK: FunctionType(() -> ())
-    # CHECK: OpaqueType(!tensor.bob)
-    for _, t in types:
-      print(repr(t))
-
-    # Test getTypeIdFunction agrees with
-    # mlirTypeGetTypeID(self) for an instance.
-    # CHECK: all equal
-    for t1, t2 in types:
-      tid1, tid2 = t1.static_typeid, Type(t2).typeid
-      assert tid1 == tid2 and hash(tid1) == hash(
-          tid2), f"expected hash and value equality {t1} {t2}"
-    else:
-      print("all equal")
-
-    # Test that storing PyTypeID in python dicts
-    # works as expected.
-    typeid_dict = dict(types)
-    assert len(typeid_dict)
-
-    # CHECK: all equal
-    for t1, t2 in typeid_dict.items():
-      assert t1.static_typeid == t2.typeid and hash(
-          t1.static_typeid) == hash(
-              t2.typeid), f"expected hash and value equality {t1} {t2}"
-    else:
-      print("all equal")
-
-    # CHECK: ShapedType has no typeid.
-    try:
-      print(ShapedType.static_typeid)
-    except AttributeError as e:
-      print(e)
-
-    vector_type = Type.parse("vector<2x3xf32>")
-    # CHECK: True
-    print(ShapedType(vector_type).typeid == vector_type.typeid)
+    with Context(), Location.unknown():
+        f32 = F32Type.get()
+
+        types = [
+            (IntegerType, IntegerType.get_signless(16)),
+            (IndexType, IndexType.get()),
+            (Float8E4M3FNType, Float8E4M3FNType.get()),
+            (Float8E5M2Type, Float8E5M2Type.get()),
+            (Float8E4M3FNUZType, Float8E4M3FNUZType.get()),
+            (Float8E4M3B11FNUZType, Float8E4M3B11FNUZType.get()),
+            (Float8E5M2FNUZType, Float8E5M2FNUZType.get()),
+            (BF16Type, BF16Type.get()),
+            (F16Type, F16Type.get()),
+            (F32Type, F32Type.get()),
+            (F64Type, F64Type.get()),
+            (NoneType, NoneType.get()),
+            (ComplexType, ComplexType.get(f32)),
+            (VectorType, VectorType.get([2, 3], f32)),
+            (RankedTensorType, RankedTensorType.get([2, 3], f32)),
+            (UnrankedTensorType, UnrankedTensorType.get(f32)),
+            (MemRefType, MemRefType.get([2, 3], f32)),
+            (UnrankedMemRefType, UnrankedMemRefType.get(f32, Attribute.parse("2"))),
+            (TupleType, TupleType.get_tuple([f32])),
+            (FunctionType, FunctionType.get([], [])),
+            (OpaqueType, OpaqueType.get("tensor", "bob")),
+        ]
+
+        # CHECK: IntegerType(i16)
+        # CHECK: IndexType(index)
+        # CHECK: Float8E4M3FNType(f8E4M3FN)
+        # CHECK: Float8E5M2Type(f8E5M2)
+        # CHECK: Float8E4M3FNUZType(f8E4M3FNUZ)
+        # CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ)
+        # CHECK: Float8E5M2FNUZType(f8E5M2FNUZ)
+        # CHECK: BF16Type(bf16)
+        # CHECK: F16Type(f16)
+        # CHECK: F32Type(f32)
+        # CHECK: F64Type(f64)
+        # CHECK: NoneType(none)
+        # CHECK: ComplexType(complex<f32>)
+        # CHECK: VectorType(vector<2x3xf32>)
+        # CHECK: RankedTensorType(tensor<2x3xf32>)
+        # CHECK: UnrankedTensorType(tensor<*xf32>)
+        # CHECK: MemRefType(memref<2x3xf32>)
+        # CHECK: UnrankedMemRefType(memref<*xf32, 2>)
+        # CHECK: TupleType(tuple<f32>)
+        # CHECK: FunctionType(() -> ())
+        # CHECK: OpaqueType(!tensor.bob)
+        for _, t in types:
+            print(repr(t))
+
+        # Test getTypeIdFunction agrees with
+        # mlirTypeGetTypeID(self) for an instance.
+        # CHECK: all equal
+        for t1, t2 in types:
+            tid1, tid2 = t1.static_typeid, Type(t2).typeid
+            assert tid1 == tid2 and hash(tid1) == hash(
+                tid2
+            ), f"expected hash and value equality {t1} {t2}"
+        else:
+            print("all equal")
+
+        # Test that storing PyTypeID in python dicts
+        # works as expected.
+        typeid_dict = dict(types)
+        assert len(typeid_dict)
+
+        # CHECK: all equal
+        for t1, t2 in typeid_dict.items():
+            assert t1.static_typeid == t2.typeid and hash(t1.static_typeid) == hash(
+                t2.typeid
+            ), f"expected hash and value equality {t1} {t2}"
+        else:
+            print("all equal")
+
+        # CHECK: ShapedType has no typeid.
+        try:
+            print(ShapedType.static_typeid)
+        except AttributeError as e:
+            print(e)
+
+        vector_type = Type.parse("vector<2x3xf32>")
+        # CHECK: True
+        print(ShapedType(vector_type).typeid == vector_type.typeid)
index b93fcf7..48d9e35 100644
 import gc
 from mlir.ir import *
 
+
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  gc.collect()
-  assert Context._get_live_count() == 0
+    print("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
 
 
 # CHECK-LABEL: TEST: testContextEnterExit
 def testContextEnterExit():
-  with Context() as ctx:
-    assert Context.current is ctx
-  try:
-    _ = Context.current
-  except ValueError as e:
-    # CHECK: No current Context
-    print(e)
-  else: assert False, "Expected exception"
+    with Context() as ctx:
+        assert Context.current is ctx
+    try:
+        _ = Context.current
+    except ValueError as e:
+        # CHECK: No current Context
+        print(e)
+    else:
+        assert False, "Expected exception"
+
 
 run(testContextEnterExit)
 
 
 # CHECK-LABEL: TEST: testLocationEnterExit
 def testLocationEnterExit():
-  ctx1 = Context()
-  with Location.unknown(ctx1) as loc1:
-    assert Context.current is ctx1
-    assert Location.current is loc1
-
-    # Re-asserting the same context should not change the location.
-    with ctx1:
-      assert Context.current is ctx1
-      assert Location.current is loc1
-      # Asserting a different context should clear it.
-      with Context() as ctx2:
-        assert Context.current is ctx2
-        try:
-          _ = Location.current
-        except ValueError: pass
-        else: assert False, "Expected exception"
-
-      # And should restore.
-      assert Context.current is ctx1
-      assert Location.current is loc1
-
-  # All should clear.
-  try:
-    _ = Location.current
-  except ValueError as e:
-    # CHECK: No current Location
-    print(e)
-  else: assert False, "Expected exception"
+    ctx1 = Context()
+    with Location.unknown(ctx1) as loc1:
+        assert Context.current is ctx1
+        assert Location.current is loc1
+
+        # Re-asserting the same context should not change the location.
+        with ctx1:
+            assert Context.current is ctx1
+            assert Location.current is loc1
+            # Asserting a different context should clear it.
+            with Context() as ctx2:
+                assert Context.current is ctx2
+                try:
+                    _ = Location.current
+                except ValueError:
+                    pass
+                else:
+                    assert False, "Expected exception"
+
+            # And should restore.
+            assert Context.current is ctx1
+            assert Location.current is loc1
+
+    # All should clear.
+    try:
+        _ = Location.current
+    except ValueError as e:
+        # CHECK: No current Location
+        print(e)
+    else:
+        assert False, "Expected exception"
+
 
 run(testLocationEnterExit)
 
 
 # CHECK-LABEL: TEST: testInsertionPointEnterExit
 def testInsertionPointEnterExit():
-  ctx1 = Context()
-  m = Module.create(Location.unknown(ctx1))
-  ip = InsertionPoint(m.body)
-
-  with ip:
-    assert InsertionPoint.current is ip
-    # Asserting a location from the same context should preserve.
-    with Location.unknown(ctx1) as loc1:
-      assert InsertionPoint.current is ip
-      assert Location.current is loc1
-    # Location should clear.
+    ctx1 = Context()
+    m = Module.create(Location.unknown(ctx1))
+    ip = InsertionPoint(m.body)
+
+    with ip:
+        assert InsertionPoint.current is ip
+        # Asserting a location from the same context should preserve.
+        with Location.unknown(ctx1) as loc1:
+            assert InsertionPoint.current is ip
+            assert Location.current is loc1
+        # Location should clear.
+        try:
+            _ = Location.current
+        except ValueError:
+            pass
+        else:
+            assert False, "Expected exception"
+
+        # Asserting the same Context should preserve.
+        with ctx1:
+            assert InsertionPoint.current is ip
+
+        # Asserting a different context should clear it.
+        with Context() as ctx2:
+            assert Context.current is ctx2
+            try:
+                _ = InsertionPoint.current
+            except ValueError:
+                pass
+            else:
+                assert False, "Expected exception"
+
+    # All should clear.
     try:
-      _ = Location.current
-    except ValueError: pass
-    else: assert False, "Expected exception"
-
-    # Asserting the same Context should preserve.
-    with ctx1:
-      assert InsertionPoint.current is ip
-
-    # Asserting a different context should clear it.
-    with Context() as ctx2:
-      assert Context.current is ctx2
-      try:
         _ = InsertionPoint.current
-      except ValueError: pass
-      else: assert False, "Expected exception"
-
-  # All should clear.
-  try:
-    _ = InsertionPoint.current
-  except ValueError as e:
-    # CHECK: No current InsertionPoint
-    print(e)
-  else: assert False, "Expected exception"
+    except ValueError as e:
+        # CHECK: No current InsertionPoint
+        print(e)
+    else:
+        assert False, "Expected exception"
+
 
 run(testInsertionPointEnterExit)
index 3268d9f..629a710 100644
@@ -2,38 +2,40 @@
 
 from mlir.ir import *
 
+
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
+    print("\nTEST:", f.__name__)
+    f()
 
 
 # CHECK-LABEL: TEST: testNameIsPrivate
 def testNameIsPrivate():
-  # `import *` ignores private names starting with an underscore, so the debug
-  # flag shouldn't be visible unless explicitly imported.
-  try:
-    _GlobalDebug.flag = True
-  except NameError:
-    pass
-  else:
-    assert False, "_GlobalDebug must not be available by default"
+    # `import *` ignores private names starting with an underscore, so the debug
+    # flag shouldn't be visible unless explicitly imported.
+    try:
+        _GlobalDebug.flag = True
+    except NameError:
+        pass
+    else:
+        assert False, "_GlobalDebug must not be available by default"
+
 
 run(testNameIsPrivate)
 
 
 # CHECK-LABEL: TEST: testDebugDlag
 def testDebugDlag():
-  # Private names must be imported explicitly.
-  from mlir.ir import _GlobalDebug
-
-  # CHECK: False
-  print(_GlobalDebug.flag)
-  _GlobalDebug.flag = True
-  # CHECK: True
-  print(_GlobalDebug.flag)
-  _GlobalDebug.flag = False
-  # CHECK: False
-  print(_GlobalDebug.flag)
+    # Private names must be imported explicitly.
+    from mlir.ir import _GlobalDebug
+
+    # CHECK: False
+    print(_GlobalDebug.flag)
+    _GlobalDebug.flag = True
+    # CHECK: True
+    print(_GlobalDebug.flag)
+    _GlobalDebug.flag = False
+    # CHECK: False
+    print(_GlobalDebug.flag)
 
-run(testDebugDlag)
 
+run(testDebugDlag)
index cc07f6e..2f4300d 100644
 import gc
 from mlir.ir import *
 
+
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  gc.collect()
-  assert Context._get_live_count() == 0
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
+    return f
 
 
 @run
 def testLifecycleContextDestroy():
-  ctx = Context()
-  def callback(foo): ...
-  handler = ctx.attach_diagnostic_handler(callback)
-  assert handler.attached
-  # If context is destroyed before the handler, it should auto-detach.
-  ctx = None
-  gc.collect()
-  assert not handler.attached
+    ctx = Context()
+
+    def callback(foo):
+        ...
+
+    handler = ctx.attach_diagnostic_handler(callback)
+    assert handler.attached
+    # If context is destroyed before the handler, it should auto-detach.
+    ctx = None
+    gc.collect()
+    assert not handler.attached
 
-  # And finally collecting the handler should be fine.
-  handler = None
-  gc.collect()
+    # And finally collecting the handler should be fine.
+    handler = None
+    gc.collect()
 
 
 @run
 def testLifecycleExplicitDetach():
-  ctx = Context()
-  def callback(foo): ...
-  handler = ctx.attach_diagnostic_handler(callback)
-  assert handler.attached
-  handler.detach()
-  assert not handler.attached
+    ctx = Context()
+
+    def callback(foo):
+        ...
+
+    handler = ctx.attach_diagnostic_handler(callback)
+    assert handler.attached
+    handler.detach()
+    assert not handler.attached
 
 
 @run
 def testLifecycleWith():
-  ctx = Context()
-  def callback(foo): ...
-  with ctx.attach_diagnostic_handler(callback) as handler:
-    assert handler.attached
-  assert not handler.attached
+    ctx = Context()
+
+    def callback(foo):
+        ...
+
+    with ctx.attach_diagnostic_handler(callback) as handler:
+        assert handler.attached
+    assert not handler.attached
 
 
 @run
 def testLifecycleWithAndExplicitDetach():
-  ctx = Context()
-  def callback(foo): ...
-  with ctx.attach_diagnostic_handler(callback) as handler:
-    assert handler.attached
-    handler.detach()
-  assert not handler.attached
+    ctx = Context()
+
+    def callback(foo):
+        ...
+
+    with ctx.attach_diagnostic_handler(callback) as handler:
+        assert handler.attached
+        handler.detach()
+    assert not handler.attached
 
 
 # CHECK-LABEL: TEST: testDiagnosticCallback
 @run
 def testDiagnosticCallback():
-  ctx = Context()
-  def callback(d):
-    # CHECK: DIAGNOSTIC: message='foobar', severity=DiagnosticSeverity.ERROR, loc=loc(unknown)
-    print(f"DIAGNOSTIC: message='{d.message}', severity={d.severity}, loc={d.location}")
-    return True
-  handler = ctx.attach_diagnostic_handler(callback)
-  loc = Location.unknown(ctx)
-  loc.emit_error("foobar")
-  assert not handler.had_error
+    ctx = Context()
+
+    def callback(d):
+        # CHECK: DIAGNOSTIC: message='foobar', severity=DiagnosticSeverity.ERROR, loc=loc(unknown)
+        print(
+            f"DIAGNOSTIC: message='{d.message}', severity={d.severity}, loc={d.location}"
+        )
+        return True
+
+    handler = ctx.attach_diagnostic_handler(callback)
+    loc = Location.unknown(ctx)
+    loc.emit_error("foobar")
+    assert not handler.had_error
 
 
 # CHECK-LABEL: TEST: testDiagnosticEmptyNotes
 # TODO: Come up with a way to inject a diagnostic with notes from this API.
 @run
 def testDiagnosticEmptyNotes():
-  ctx = Context()
-  def callback(d):
-    # CHECK: DIAGNOSTIC: notes=()
-    print(f"DIAGNOSTIC: notes={d.notes}")
-    return True
-  handler = ctx.attach_diagnostic_handler(callback)
-  loc = Location.unknown(ctx)
-  loc.emit_error("foobar")
-  assert not handler.had_error
+    ctx = Context()
+
+    def callback(d):
+        # CHECK: DIAGNOSTIC: notes=()
+        print(f"DIAGNOSTIC: notes={d.notes}")
+        return True
+
+    handler = ctx.attach_diagnostic_handler(callback)
+    loc = Location.unknown(ctx)
+    loc.emit_error("foobar")
+    assert not handler.had_error
 
 
 # CHECK-LABEL: TEST: testDiagnosticNonEmptyNotes
 @run
 def testDiagnosticNonEmptyNotes():
-  ctx = Context()
-  ctx.emit_error_diagnostics = True
-  def callback(d):
-    # CHECK: DIAGNOSTIC:
-    # CHECK:   message='arith.addi' op requires one result
-    # CHECK:   notes=['see current operation: "arith.addi"() : () -> ()']
-    print(f"DIAGNOSTIC:")
-    print(f"  message={d.message}")
-    print(f"  notes={list(map(str, d.notes))}")
-    return True
-  handler = ctx.attach_diagnostic_handler(callback)
-  loc = Location.unknown(ctx)
-  try:
-    Operation.create('arith.addi', loc=loc).verify()
-  except MLIRError:
-    pass
-  assert not handler.had_error
+    ctx = Context()
+    ctx.emit_error_diagnostics = True
+
+    def callback(d):
+        # CHECK: DIAGNOSTIC:
+        # CHECK:   message='arith.addi' op requires one result
+        # CHECK:   notes=['see current operation: "arith.addi"() : () -> ()']
+        print(f"DIAGNOSTIC:")
+        print(f"  message={d.message}")
+        print(f"  notes={list(map(str, d.notes))}")
+        return True
+
+    handler = ctx.attach_diagnostic_handler(callback)
+    loc = Location.unknown(ctx)
+    try:
+        Operation.create("arith.addi", loc=loc).verify()
+    except MLIRError:
+        pass
+    assert not handler.had_error
+
 
 # CHECK-LABEL: TEST: testDiagnosticCallbackException
 @run
 def testDiagnosticCallbackException():
-  ctx = Context()
-  def callback(d):
-    raise ValueError("Error in handler")
-  handler = ctx.attach_diagnostic_handler(callback)
-  loc = Location.unknown(ctx)
-  loc.emit_error("foobar")
-  assert handler.had_error
+    ctx = Context()
+
+    def callback(d):
+        raise ValueError("Error in handler")
+
+    handler = ctx.attach_diagnostic_handler(callback)
+    loc = Location.unknown(ctx)
+    loc.emit_error("foobar")
+    assert handler.had_error
 
 
 # CHECK-LABEL: TEST: testEscapingDiagnostic
 @run
 def testEscapingDiagnostic():
-  ctx = Context()
-  diags = []
-  def callback(d):
-    diags.append(d)
-    return True
-  handler = ctx.attach_diagnostic_handler(callback)
-  loc = Location.unknown(ctx)
-  loc.emit_error("foobar")
-  assert not handler.had_error
-
-  # CHECK: DIAGNOSTIC: <Invalid Diagnostic>
-  print(f"DIAGNOSTIC: {str(diags[0])}")
-  try:
-    diags[0].severity
-    raise RuntimeError("expected exception")
-  except ValueError:
-    pass
-  try:
-    diags[0].location
-    raise RuntimeError("expected exception")
-  except ValueError:
-    pass
-  try:
-    diags[0].message
-    raise RuntimeError("expected exception")
-  except ValueError:
-    pass
-  try:
-    diags[0].notes
-    raise RuntimeError("expected exception")
-  except ValueError:
-    pass
-
+    ctx = Context()
+    diags = []
+
+    def callback(d):
+        diags.append(d)
+        return True
+
+    handler = ctx.attach_diagnostic_handler(callback)
+    loc = Location.unknown(ctx)
+    loc.emit_error("foobar")
+    assert not handler.had_error
+
+    # CHECK: DIAGNOSTIC: <Invalid Diagnostic>
+    print(f"DIAGNOSTIC: {str(diags[0])}")
+    try:
+        diags[0].severity
+        raise RuntimeError("expected exception")
+    except ValueError:
+        pass
+    try:
+        diags[0].location
+        raise RuntimeError("expected exception")
+    except ValueError:
+        pass
+    try:
+        diags[0].message
+        raise RuntimeError("expected exception")
+    except ValueError:
+        pass
+    try:
+        diags[0].notes
+        raise RuntimeError("expected exception")
+    except ValueError:
+        pass
 
 
 # CHECK-LABEL: TEST: testDiagnosticReturnTrueHandles
 @run
 def testDiagnosticReturnTrueHandles():
-  ctx = Context()
-  def callback1(d):
-    print(f"CALLBACK1: {d}")
-    return True
-  def callback2(d):
-    print(f"CALLBACK2: {d}")
-    return True
-  ctx.attach_diagnostic_handler(callback1)
-  ctx.attach_diagnostic_handler(callback2)
-  loc = Location.unknown(ctx)
-  # CHECK-NOT: CALLBACK1
-  # CHECK: CALLBACK2: foobar
-  # CHECK-NOT: CALLBACK1
-  loc.emit_error("foobar")
+    ctx = Context()
+
+    def callback1(d):
+        print(f"CALLBACK1: {d}")
+        return True
+
+    def callback2(d):
+        print(f"CALLBACK2: {d}")
+        return True
+
+    ctx.attach_diagnostic_handler(callback1)
+    ctx.attach_diagnostic_handler(callback2)
+    loc = Location.unknown(ctx)
+    # CHECK-NOT: CALLBACK1
+    # CHECK: CALLBACK2: foobar
+    # CHECK-NOT: CALLBACK1
+    loc.emit_error("foobar")
 
 
 # CHECK-LABEL: TEST: testDiagnosticReturnFalseDoesNotHandle
 @run
 def testDiagnosticReturnFalseDoesNotHandle():
-  ctx = Context()
-  def callback1(d):
-    print(f"CALLBACK1: {d}")
-    return True
-  def callback2(d):
-    print(f"CALLBACK2: {d}")
-    return False
-  ctx.attach_diagnostic_handler(callback1)
-  ctx.attach_diagnostic_handler(callback2)
-  loc = Location.unknown(ctx)
-  # CHECK: CALLBACK2: foobar
-  # CHECK: CALLBACK1: foobar
-  loc.emit_error("foobar")
+    ctx = Context()
+
+    def callback1(d):
+        print(f"CALLBACK1: {d}")
+        return True
+
+    def callback2(d):
+        print(f"CALLBACK2: {d}")
+        return False
+
+    ctx.attach_diagnostic_handler(callback1)
+    ctx.attach_diagnostic_handler(callback2)
+    loc = Location.unknown(ctx)
+    # CHECK: CALLBACK2: foobar
+    # CHECK: CALLBACK1: foobar
+    loc.emit_error("foobar")
index 65e81e8..eebf7c3 100644
@@ -5,60 +5,60 @@ from mlir.ir import *
 
 
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  gc.collect()
-  assert Context._get_live_count() == 0
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
+    return f
 
 
 # CHECK-LABEL: TEST: testDialectDescriptor
 @run
 def testDialectDescriptor():
-  ctx = Context()
-  d = ctx.get_dialect_descriptor("func")
-  # CHECK: <DialectDescriptor func>
-  print(d)
-  # CHECK: func
-  print(d.namespace)
-  try:
-    _ = ctx.get_dialect_descriptor("not_existing")
-  except ValueError:
-    pass
-  else:
-    assert False, "Expected exception"
+    ctx = Context()
+    d = ctx.get_dialect_descriptor("func")
+    # CHECK: <DialectDescriptor func>
+    print(d)
+    # CHECK: func
+    print(d.namespace)
+    try:
+        _ = ctx.get_dialect_descriptor("not_existing")
+    except ValueError:
+        pass
+    else:
+        assert False, "Expected exception"
 
 
 # CHECK-LABEL: TEST: testUserDialectClass
 @run
 def testUserDialectClass():
-  ctx = Context()
-  # Access using attribute.
-  d = ctx.dialects.func
-  # CHECK: <Dialect func (class mlir.dialects._func_ops_gen._Dialect)>
-  print(d)
-  try:
-    _ = ctx.dialects.not_existing
-  except AttributeError:
-    pass
-  else:
-    assert False, "Expected exception"
-
-  # Access using index.
-  d = ctx.dialects["func"]
-  # CHECK: <Dialect func (class mlir.dialects._func_ops_gen._Dialect)>
-  print(d)
-  try:
-    _ = ctx.dialects["not_existing"]
-  except IndexError:
-    pass
-  else:
-    assert False, "Expected exception"
-
-  # Using the 'd' alias.
-  d = ctx.d["func"]
-  # CHECK: <Dialect func (class mlir.dialects._func_ops_gen._Dialect)>
-  print(d)
+    ctx = Context()
+    # Access using attribute.
+    d = ctx.dialects.func
+    # CHECK: <Dialect func (class mlir.dialects._func_ops_gen._Dialect)>
+    print(d)
+    try:
+        _ = ctx.dialects.not_existing
+    except AttributeError:
+        pass
+    else:
+        assert False, "Expected exception"
+
+    # Access using index.
+    d = ctx.dialects["func"]
+    # CHECK: <Dialect func (class mlir.dialects._func_ops_gen._Dialect)>
+    print(d)
+    try:
+        _ = ctx.dialects["not_existing"]
+    except IndexError:
+        pass
+    else:
+        assert False, "Expected exception"
+
+    # Using the 'd' alias.
+    d = ctx.d["func"]
+    # CHECK: <Dialect func (class mlir.dialects._func_ops_gen._Dialect)>
+    print(d)
 
 
 # CHECK-LABEL: TEST: testCustomOpView
@@ -67,40 +67,40 @@ def testUserDialectClass():
 # additional capabilities come online.
 @run
 def testCustomOpView():
+    def createInput():
+        op = Operation.create("pytest_dummy.intinput", results=[f32])
+        # TODO: Auto result cast from operation
+        return op.results[0]
 
-  def createInput():
-    op = Operation.create("pytest_dummy.intinput", results=[f32])
-    # TODO: Auto result cast from operation
-    return op.results[0]
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+        m = Module.create()
 
-  with Context() as ctx, Location.unknown():
-    ctx.allow_unregistered_dialects = True
-    m = Module.create()
+        with InsertionPoint(m.body):
+            f32 = F32Type.get()
+            # Create via dialects context collection.
+            input1 = createInput()
+            input2 = createInput()
+            op1 = ctx.dialects.arith.AddFOp(input1, input2)
 
-    with InsertionPoint(m.body):
-      f32 = F32Type.get()
-      # Create via dialects context collection.
-      input1 = createInput()
-      input2 = createInput()
-      op1 = ctx.dialects.arith.AddFOp(input1, input2)
+            # Create via an import
+            from mlir.dialects.arith import AddFOp
 
-      # Create via an import
-      from mlir.dialects.arith import AddFOp
-      AddFOp(input1, op1.result)
+            AddFOp(input1, op1.result)
 
-  # CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput"
-  # CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput"
-  # CHECK: %[[R0:.*]] = arith.addf %[[INPUT0]], %[[INPUT1]] : f32
-  # CHECK: %[[R1:.*]] = arith.addf %[[INPUT0]], %[[R0]] : f32
-  m.operation.print()
+    # CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput"
+    # CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput"
+    # CHECK: %[[R0:.*]] = arith.addf %[[INPUT0]], %[[INPUT1]] : f32
+    # CHECK: %[[R1:.*]] = arith.addf %[[INPUT0]], %[[R0]] : f32
+    m.operation.print()
 
 
 # CHECK-LABEL: TEST: testIsRegisteredOperation
 @run
 def testIsRegisteredOperation():
-  ctx = Context()
+    ctx = Context()
 
-  # CHECK: cf.cond_br: True
-  print(f"cf.cond_br: {ctx.is_registered_operation('cf.cond_br')}")
-  # CHECK: func.not_existing: False
-  print(f"func.not_existing: {ctx.is_registered_operation('func.not_existing')}")
+    # CHECK: cf.cond_br: True
+    print(f"cf.cond_br: {ctx.is_registered_operation('cf.cond_br')}")
+    # CHECK: func.not_existing: False
+    print(f"func.not_existing: {ctx.is_registered_operation('func.not_existing')}")
index 6cb2375..74085cd 100644
@@ -3,75 +3,93 @@
 import gc
 from mlir.ir import *
 
+
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  gc.collect()
-  assert Context._get_live_count() == 0
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
+    return f
 
 
 # CHECK-LABEL: TEST: test_exception
 @run
 def test_exception():
-  ctx =  Context()
-  ctx.allow_unregistered_dialects = True
-  try:
-    Operation.parse("""
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    try:
+        Operation.parse(
+            """
       func.func @foo() {
           "test.use"(%0) : (i64) -> ()  loc("use")
           %0 = "test.def"() : () -> i64 loc("def")
           return
       }
-    """, context=ctx)
-  except MLIRError as e:
-    # CHECK: Exception: <
-    # CHECK:   Unable to parse operation assembly:
-    # CHECK:   error: "use": operand #0 does not dominate this use
-    # CHECK:    note: "use": see current operation: "test.use"(%0) : (i64) -> ()
-    # CHECK:    note: "def": operand defined here (op in the same block)
-    # CHECK: >
-    print(f"Exception: <{e}>")
+    """,
+            context=ctx,
+        )
+    except MLIRError as e:
+        # CHECK: Exception: <
+        # CHECK:   Unable to parse operation assembly:
+        # CHECK:   error: "use": operand #0 does not dominate this use
+        # CHECK:    note: "use": see current operation: "test.use"(%0) : (i64) -> ()
+        # CHECK:    note: "def": operand defined here (op in the same block)
+        # CHECK: >
+        print(f"Exception: <{e}>")
 
-    # CHECK: message: Unable to parse operation assembly
-    print(f"message: {e.message}")
+        # CHECK: message: Unable to parse operation assembly
+        print(f"message: {e.message}")
 
-    # CHECK: error_diagnostics[0]:           loc("use") operand #0 does not dominate this use
-    # CHECK: error_diagnostics[0].notes[0]:  loc("use") see current operation: "test.use"(%0) : (i64) -> ()
-    # CHECK: error_diagnostics[0].notes[1]:  loc("def") operand defined here (op in the same block)
-    print("error_diagnostics[0]:          ", e.error_diagnostics[0].location, e.error_diagnostics[0].message)
-    print("error_diagnostics[0].notes[0]: ", e.error_diagnostics[0].notes[0].location, e.error_diagnostics[0].notes[0].message)
-    print("error_diagnostics[0].notes[1]: ", e.error_diagnostics[0].notes[1].location, e.error_diagnostics[0].notes[1].message)
+        # CHECK: error_diagnostics[0]:           loc("use") operand #0 does not dominate this use
+        # CHECK: error_diagnostics[0].notes[0]:  loc("use") see current operation: "test.use"(%0) : (i64) -> ()
+        # CHECK: error_diagnostics[0].notes[1]:  loc("def") operand defined here (op in the same block)
+        print(
+            "error_diagnostics[0]:          ",
+            e.error_diagnostics[0].location,
+            e.error_diagnostics[0].message,
+        )
+        print(
+            "error_diagnostics[0].notes[0]: ",
+            e.error_diagnostics[0].notes[0].location,
+            e.error_diagnostics[0].notes[0].message,
+        )
+        print(
+            "error_diagnostics[0].notes[1]: ",
+            e.error_diagnostics[0].notes[1].location,
+            e.error_diagnostics[0].notes[1].message,
+        )
 
 
 # CHECK-LABEL: test_emit_error_diagnostics
 @run
 def test_emit_error_diagnostics():
-  ctx = Context()
-  loc = Location.unknown(ctx)
-  handler_diags = []
-  def handler(d):
-    handler_diags.append(str(d))
-    return True
-  ctx.attach_diagnostic_handler(handler)
+    ctx = Context()
+    loc = Location.unknown(ctx)
+    handler_diags = []
+
+    def handler(d):
+        handler_diags.append(str(d))
+        return True
+
+    ctx.attach_diagnostic_handler(handler)
 
-  try:
-    Attribute.parse("not an attr", ctx)
-  except MLIRError as e:
-    # CHECK: emit_error_diagnostics=False:
-    # CHECK: e.error_diagnostics: ['expected attribute value']
-    # CHECK: handler_diags: []
-    print(f"emit_error_diagnostics=False:")
-    print(f"e.error_diagnostics: {[str(diag) for diag in e.error_diagnostics]}")
-    print(f"handler_diags: {handler_diags}")
+    try:
+        Attribute.parse("not an attr", ctx)
+    except MLIRError as e:
+        # CHECK: emit_error_diagnostics=False:
+        # CHECK: e.error_diagnostics: ['expected attribute value']
+        # CHECK: handler_diags: []
+        print(f"emit_error_diagnostics=False:")
+        print(f"e.error_diagnostics: {[str(diag) for diag in e.error_diagnostics]}")
+        print(f"handler_diags: {handler_diags}")
 
-  ctx.emit_error_diagnostics = True
-  try:
-    Attribute.parse("not an attr", ctx)
-  except MLIRError as e:
-    # CHECK: emit_error_diagnostics=True:
-    # CHECK: e.error_diagnostics: []
-    # CHECK: handler_diags: ['expected attribute value']
-    print(f"emit_error_diagnostics=True:")
-    print(f"e.error_diagnostics: {[str(diag) for diag in e.error_diagnostics]}")
-    print(f"handler_diags: {handler_diags}")
+    ctx.emit_error_diagnostics = True
+    try:
+        Attribute.parse("not an attr", ctx)
+    except MLIRError as e:
+        # CHECK: emit_error_diagnostics=True:
+        # CHECK: e.error_diagnostics: []
+        # CHECK: handler_diags: ['expected attribute value']
+        print(f"emit_error_diagnostics=True:")
+        print(f"e.error_diagnostics: {[str(diag) for diag in e.error_diagnostics]}")
+        print(f"handler_diags: {handler_diags}")
index 81a6ec2..0dc7d75 100644
@@ -5,168 +5,191 @@ import io
 import itertools
 from mlir.ir import *
 
+
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  gc.collect()
-  assert Context._get_live_count() == 0
+    print("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
 
 
 # CHECK-LABEL: TEST: test_insert_at_block_end
 def test_insert_at_block_end():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  with Location.unknown(ctx):
-    module = Module.parse(r"""
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        module = Module.parse(
+            r"""
       func.func @foo() -> () {
         "custom.op1"() : () -> ()
       }
-    """)
-    entry_block = module.body.operations[0].regions[0].blocks[0]
-    ip = InsertionPoint(entry_block)
-    ip.insert(Operation.create("custom.op2"))
-    # CHECK: "custom.op1"
-    # CHECK: "custom.op2"
-    module.operation.print()
+    """
+        )
+        entry_block = module.body.operations[0].regions[0].blocks[0]
+        ip = InsertionPoint(entry_block)
+        ip.insert(Operation.create("custom.op2"))
+        # CHECK: "custom.op1"
+        # CHECK: "custom.op2"
+        module.operation.print()
+
 
 run(test_insert_at_block_end)
 
 
 # CHECK-LABEL: TEST: test_insert_before_operation
 def test_insert_before_operation():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  with Location.unknown(ctx):
-    module = Module.parse(r"""
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        module = Module.parse(
+            r"""
       func.func @foo() -> () {
         "custom.op1"() : () -> ()
         "custom.op2"() : () -> ()
       }
-    """)
-    entry_block = module.body.operations[0].regions[0].blocks[0]
-    ip = InsertionPoint(entry_block.operations[1])
-    ip.insert(Operation.create("custom.op3"))
-    # CHECK: "custom.op1"
-    # CHECK: "custom.op3"
-    # CHECK: "custom.op2"
-    module.operation.print()
+    """
+        )
+        entry_block = module.body.operations[0].regions[0].blocks[0]
+        ip = InsertionPoint(entry_block.operations[1])
+        ip.insert(Operation.create("custom.op3"))
+        # CHECK: "custom.op1"
+        # CHECK: "custom.op3"
+        # CHECK: "custom.op2"
+        module.operation.print()
+
 
 run(test_insert_before_operation)
 
 
 # CHECK-LABEL: TEST: test_insert_at_block_begin
 def test_insert_at_block_begin():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  with Location.unknown(ctx):
-    module = Module.parse(r"""
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        module = Module.parse(
+            r"""
       func.func @foo() -> () {
         "custom.op2"() : () -> ()
       }
-    """)
-    entry_block = module.body.operations[0].regions[0].blocks[0]
-    ip = InsertionPoint.at_block_begin(entry_block)
-    ip.insert(Operation.create("custom.op1"))
-    # CHECK: "custom.op1"
-    # CHECK: "custom.op2"
-    module.operation.print()
+    """
+        )
+        entry_block = module.body.operations[0].regions[0].blocks[0]
+        ip = InsertionPoint.at_block_begin(entry_block)
+        ip.insert(Operation.create("custom.op1"))
+        # CHECK: "custom.op1"
+        # CHECK: "custom.op2"
+        module.operation.print()
+
 
 run(test_insert_at_block_begin)
 
 
 # CHECK-LABEL: TEST: test_insert_at_block_begin_empty
 def test_insert_at_block_begin_empty():
-  # TODO: Write this test case when we can create such a situation.
-  pass
+    # TODO: Write this test case when we can create such a situation.
+    pass
+
 
 run(test_insert_at_block_begin_empty)
 
 
 # CHECK-LABEL: TEST: test_insert_at_terminator
 def test_insert_at_terminator():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  with Location.unknown(ctx):
-    module = Module.parse(r"""
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        module = Module.parse(
+            r"""
       func.func @foo() -> () {
         "custom.op1"() : () -> ()
         return
       }
-    """)
-    entry_block = module.body.operations[0].regions[0].blocks[0]
-    ip = InsertionPoint.at_block_terminator(entry_block)
-    ip.insert(Operation.create("custom.op2"))
-    # CHECK: "custom.op1"
-    # CHECK: "custom.op2"
-    module.operation.print()
+    """
+        )
+        entry_block = module.body.operations[0].regions[0].blocks[0]
+        ip = InsertionPoint.at_block_terminator(entry_block)
+        ip.insert(Operation.create("custom.op2"))
+        # CHECK: "custom.op1"
+        # CHECK: "custom.op2"
+        module.operation.print()
+
 
 run(test_insert_at_terminator)
 
 
 # CHECK-LABEL: TEST: test_insert_at_block_terminator_missing
 def test_insert_at_block_terminator_missing():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  with ctx:
-    module = Module.parse(r"""
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with ctx:
+        module = Module.parse(
+            r"""
       func.func @foo() -> () {
         "custom.op1"() : () -> ()
       }
-    """)
-    entry_block = module.body.operations[0].regions[0].blocks[0]
-    try:
-      ip = InsertionPoint.at_block_terminator(entry_block)
-    except ValueError as e:
-      # CHECK: Block has no terminator
-      print(e)
-    else:
-      assert False, "Expected exception"
+    """
+        )
+        entry_block = module.body.operations[0].regions[0].blocks[0]
+        try:
+            ip = InsertionPoint.at_block_terminator(entry_block)
+        except ValueError as e:
+            # CHECK: Block has no terminator
+            print(e)
+        else:
+            assert False, "Expected exception"
+
 
 run(test_insert_at_block_terminator_missing)
 
 
 # CHECK-LABEL: TEST: test_insert_at_end_with_terminator_errors
 def test_insert_at_end_with_terminator_errors():
-  with Context() as ctx, Location.unknown():
-    ctx.allow_unregistered_dialects = True
-    module = Module.parse(r"""
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+        module = Module.parse(
+            r"""
       func.func @foo() -> () {
         return
       }
-    """)
-    entry_block = module.body.operations[0].regions[0].blocks[0]
-    with InsertionPoint(entry_block):
-      try:
-        Operation.create("custom.op1", results=[], operands=[])
-      except IndexError as e:
-        # CHECK: ERROR: Cannot insert operation at the end of a block that already has a terminator.
-        print(f"ERROR: {e}")
+    """
+        )
+        entry_block = module.body.operations[0].regions[0].blocks[0]
+        with InsertionPoint(entry_block):
+            try:
+                Operation.create("custom.op1", results=[], operands=[])
+            except IndexError as e:
+                # CHECK: ERROR: Cannot insert operation at the end of a block that already has a terminator.
+                print(f"ERROR: {e}")
+
 
 run(test_insert_at_end_with_terminator_errors)
 
 
 # CHECK-LABEL: TEST: test_insertion_point_context
 def test_insertion_point_context():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  with Location.unknown(ctx):
-    module = Module.parse(r"""
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        module = Module.parse(
+            r"""
       func.func @foo() -> () {
         "custom.op1"() : () -> ()
       }
-    """)
-    entry_block = module.body.operations[0].regions[0].blocks[0]
-    with InsertionPoint(entry_block):
-      Operation.create("custom.op2")
-      with InsertionPoint.at_block_begin(entry_block):
-        Operation.create("custom.opa")
-        Operation.create("custom.opb")
-      Operation.create("custom.op3")
-    # CHECK: "custom.opa"
-    # CHECK: "custom.opb"
-    # CHECK: "custom.op1"
-    # CHECK: "custom.op2"
-    # CHECK: "custom.op3"
-    module.operation.print()
+    """
+        )
+        entry_block = module.body.operations[0].regions[0].blocks[0]
+        with InsertionPoint(entry_block):
+            Operation.create("custom.op2")
+            with InsertionPoint.at_block_begin(entry_block):
+                Operation.create("custom.opa")
+                Operation.create("custom.opb")
+            Operation.create("custom.op3")
+        # CHECK: "custom.opa"
+        # CHECK: "custom.opb"
+        # CHECK: "custom.op1"
+        # CHECK: "custom.op2"
+        # CHECK: "custom.op3"
+        module.operation.print()
+
 
 run(test_insertion_point_context)
index d9f158c..9fe0480 100644
 import gc
 from mlir.ir import *
 
+
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  gc.collect()
-  assert Context._get_live_count() == 0
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
+    return f
 
 
 # CHECK-LABEL: TEST: testIntegerSetCapsule
 @run
 def testIntegerSetCapsule():
-  with Context() as ctx:
-    is1 = IntegerSet.get_empty(1, 1, ctx)
-  capsule = is1._CAPIPtr
-  # CHECK: mlir.ir.IntegerSet._CAPIPtr
-  print(capsule)
-  is2 = IntegerSet._CAPICreate(capsule)
-  assert is1 == is2
-  assert is2.context is ctx
+    with Context() as ctx:
+        is1 = IntegerSet.get_empty(1, 1, ctx)
+    capsule = is1._CAPIPtr
+    # CHECK: mlir.ir.IntegerSet._CAPIPtr
+    print(capsule)
+    is2 = IntegerSet._CAPICreate(capsule)
+    assert is1 == is2
+    assert is2.context is ctx
 
 
 # CHECK-LABEL: TEST: testIntegerSetGet
 @run
 def testIntegerSetGet():
-  with Context():
-    d0 = AffineDimExpr.get(0)
-    d1 = AffineDimExpr.get(1)
-    s0 = AffineSymbolExpr.get(0)
-    c42 = AffineConstantExpr.get(42)
-
-    # CHECK: (d0, d1)[s0] : (d0 - d1 == 0, s0 - 42 >= 0)
-    set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42], [True, False])
-    print(set0)
-
-    # CHECK: (d0)[s0] : (1 == 0)
-    set1 = IntegerSet.get_empty(1, 1)
-    print(set1)
-
-    # CHECK: (d0)[s0, s1] : (d0 - s1 == 0, s0 - 42 >= 0)
-    set2 = set0.get_replaced([d0, AffineSymbolExpr.get(1)], [s0], 1, 2)
-    print(set2)
-
-    try:
-      IntegerSet.get(2, 1, [], [])
-    except ValueError as e:
-      # CHECK: Expected non-empty list of constraints
-      print(e)
-
-    try:
-      IntegerSet.get(2, 1, [d0 - d1], [True, False])
-    except ValueError as e:
-      # CHECK: Expected the number of constraints to match that of equality flags
-      print(e)
-
-    try:
-      IntegerSet.get(2, 1, [0], [True])
-    except RuntimeError as e:
-      # CHECK: Invalid expression when attempting to create an IntegerSet
-      print(e)
-
-    try:
-      IntegerSet.get(2, 1, [None], [True])
-    except RuntimeError as e:
-      # CHECK: Invalid expression (None?) when attempting to create an IntegerSet
-      print(e)
-
-    try:
-      set0.get_replaced([d0], [s0], 1, 1)
-    except ValueError as e:
-      # CHECK: Expected the number of dimension replacement expressions to match that of dimensions
-      print(e)
-
-    try:
-      set0.get_replaced([d0, d1], [s0, s0], 1, 1)
-    except ValueError as e:
-      # CHECK: Expected the number of symbol replacement expressions to match that of symbols
-      print(e)
-
-    try:
-      set0.get_replaced([d0, 1], [s0], 1, 1)
-    except RuntimeError as e:
-      # CHECK: Invalid expression when attempting to create an IntegerSet by replacing dimensions
-      print(e)
-
-    try:
-      set0.get_replaced([d0, d1], [None], 1, 1)
-    except RuntimeError as e:
-      # CHECK: Invalid expression (None?) when attempting to create an IntegerSet by replacing symbols
-      print(e)
+    with Context():
+        d0 = AffineDimExpr.get(0)
+        d1 = AffineDimExpr.get(1)
+        s0 = AffineSymbolExpr.get(0)
+        c42 = AffineConstantExpr.get(42)
+
+        # CHECK: (d0, d1)[s0] : (d0 - d1 == 0, s0 - 42 >= 0)
+        set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42], [True, False])
+        print(set0)
+
+        # CHECK: (d0)[s0] : (1 == 0)
+        set1 = IntegerSet.get_empty(1, 1)
+        print(set1)
+
+        # CHECK: (d0)[s0, s1] : (d0 - s1 == 0, s0 - 42 >= 0)
+        set2 = set0.get_replaced([d0, AffineSymbolExpr.get(1)], [s0], 1, 2)
+        print(set2)
+
+        try:
+            IntegerSet.get(2, 1, [], [])
+        except ValueError as e:
+            # CHECK: Expected non-empty list of constraints
+            print(e)
+
+        try:
+            IntegerSet.get(2, 1, [d0 - d1], [True, False])
+        except ValueError as e:
+            # CHECK: Expected the number of constraints to match that of equality flags
+            print(e)
+
+        try:
+            IntegerSet.get(2, 1, [0], [True])
+        except RuntimeError as e:
+            # CHECK: Invalid expression when attempting to create an IntegerSet
+            print(e)
+
+        try:
+            IntegerSet.get(2, 1, [None], [True])
+        except RuntimeError as e:
+            # CHECK: Invalid expression (None?) when attempting to create an IntegerSet
+            print(e)
+
+        try:
+            set0.get_replaced([d0], [s0], 1, 1)
+        except ValueError as e:
+            # CHECK: Expected the number of dimension replacement expressions to match that of dimensions
+            print(e)
+
+        try:
+            set0.get_replaced([d0, d1], [s0, s0], 1, 1)
+        except ValueError as e:
+            # CHECK: Expected the number of symbol replacement expressions to match that of symbols
+            print(e)
+
+        try:
+            set0.get_replaced([d0, 1], [s0], 1, 1)
+        except RuntimeError as e:
+            # CHECK: Invalid expression when attempting to create an IntegerSet by replacing dimensions
+            print(e)
+
+        try:
+            set0.get_replaced([d0, d1], [None], 1, 1)
+        except RuntimeError as e:
+            # CHECK: Invalid expression (None?) when attempting to create an IntegerSet by replacing symbols
+            print(e)
 
 
 # CHECK-LABEL: TEST: testIntegerSetProperties
 @run
 def testIntegerSetProperties():
-  with Context():
-    d0 = AffineDimExpr.get(0)
-    d1 = AffineDimExpr.get(1)
-    s0 = AffineSymbolExpr.get(0)
-    c42 = AffineConstantExpr.get(42)
-
-    set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42, s0 - d0], [True, False, False])
-    # CHECK: 2
-    print(set0.n_dims)
-    # CHECK: 1
-    print(set0.n_symbols)
-    # CHECK: 3
-    print(set0.n_inputs)
-    # CHECK: 1
-    print(set0.n_equalities)
-    # CHECK: 2
-    print(set0.n_inequalities)
-
-    # CHECK: 3
-    print(len(set0.constraints))
-
-    # CHECK-DAG: d0 - d1 == 0
-    # CHECK-DAG: s0 - 42 >= 0
-    # CHECK-DAG: -d0 + s0 >= 0
-    for cstr in set0.constraints:
-      print(cstr.expr, end='')
-      print(" == 0" if cstr.is_eq else " >= 0")
+    with Context():
+        d0 = AffineDimExpr.get(0)
+        d1 = AffineDimExpr.get(1)
+        s0 = AffineSymbolExpr.get(0)
+        c42 = AffineConstantExpr.get(42)
+
+        set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42, s0 - d0], [True, False, False])
+        # CHECK: 2
+        print(set0.n_dims)
+        # CHECK: 1
+        print(set0.n_symbols)
+        # CHECK: 3
+        print(set0.n_inputs)
+        # CHECK: 1
+        print(set0.n_equalities)
+        # CHECK: 2
+        print(set0.n_inequalities)
+
+        # CHECK: 3
+        print(len(set0.constraints))
+
+        # CHECK-DAG: d0 - d1 == 0
+        # CHECK-DAG: s0 - 42 >= 0
+        # CHECK-DAG: -d0 + s0 >= 0
+        for cstr in set0.constraints:
+            print(cstr.expr, end="")
+            print(" == 0" if cstr.is_eq else " >= 0")
 
 
 # TODO-LABEL: TEST: testHash
 @run
 def testHash():
-  with Context():
-    d0 = AffineDimExpr.get(0)
-    d1 = AffineDimExpr.get(1)
-    set = IntegerSet.get(2, 0, [d0 + d1], [True])
+    with Context():
+        d0 = AffineDimExpr.get(0)
+        d1 = AffineDimExpr.get(1)
+        set = IntegerSet.get(2, 0, [d0 + d1], [True])
 
-    assert hash(set) == hash(IntegerSet.get(2, 0, [d0 + d1], [True]))
+        assert hash(set) == hash(IntegerSet.get(2, 0, [d0 + d1], [True]))
 
-    dictionary = dict()
-    dictionary[set] = 42
-    assert set in dictionary
+        dictionary = dict()
+        dictionary[set] = 42
+        assert set in dictionary
index 6a30a1d..f66d6c5 100644 (file)
 import gc
 from mlir.ir import *
 
+
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  gc.collect()
-  assert Context._get_live_count() == 0
+    print("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
 
 
 # CHECK-LABEL: TEST: testUnknown
 def testUnknown():
-  with Context() as ctx:
-    loc = Location.unknown()
-  assert loc.context is ctx
-  ctx = None
-  gc.collect()
-  # CHECK: unknown str: loc(unknown)
-  print("unknown str:", str(loc))
-  # CHECK: unknown repr: loc(unknown)
-  print("unknown repr:", repr(loc))
+    with Context() as ctx:
+        loc = Location.unknown()
+    assert loc.context is ctx
+    ctx = None
+    gc.collect()
+    # CHECK: unknown str: loc(unknown)
+    print("unknown str:", str(loc))
+    # CHECK: unknown repr: loc(unknown)
+    print("unknown repr:", repr(loc))
+
 
 run(testUnknown)
 
 
 # CHECK-LABEL: TEST: testLocationAttr
 def testLocationAttr():
-  with Context() as ctxt:
-    loc = Location.unknown()
-    attr = loc.attr
-    clone = Location.from_attr(attr)
-  gc.collect()
-  # CHECK: loc: loc(unknown)
-  print("loc:", str(loc))
-  # CHECK: clone: loc(unknown)
-  print("clone:", str(clone))
-  assert loc == clone
+    with Context() as ctxt:
+        loc = Location.unknown()
+        attr = loc.attr
+        clone = Location.from_attr(attr)
+    gc.collect()
+    # CHECK: loc: loc(unknown)
+    print("loc:", str(loc))
+    # CHECK: clone: loc(unknown)
+    print("clone:", str(clone))
+    assert loc == clone
+
 
 run(testLocationAttr)
 
 # CHECK-LABEL: TEST: testFileLineCol
 def testFileLineCol():
-  with Context() as ctx:
-    loc = Location.file("foo.txt", 123, 56)
-  ctx = None
-  gc.collect()
-  # CHECK: file str: loc("foo.txt":123:56)
-  print("file str:", str(loc))
-  # CHECK: file repr: loc("foo.txt":123:56)
-  print("file repr:", repr(loc))
+    with Context() as ctx:
+        loc = Location.file("foo.txt", 123, 56)
+    ctx = None
+    gc.collect()
+    # CHECK: file str: loc("foo.txt":123:56)
+    print("file str:", str(loc))
+    # CHECK: file repr: loc("foo.txt":123:56)
+    print("file repr:", repr(loc))
+
 
 run(testFileLineCol)
 
 
 # CHECK-LABEL: TEST: testName
 def testName():
-  with Context() as ctx:
-    loc = Location.name("nombre")
-    locWithChildLoc = Location.name("naam", loc)
-  ctx = None
-  gc.collect()
-  # CHECK: file str: loc("nombre")
-  print("file str:", str(loc))
-  # CHECK: file repr: loc("nombre")
-  print("file repr:", repr(loc))
-  # CHECK: file str: loc("naam"("nombre"))
-  print("file str:", str(locWithChildLoc))
-  # CHECK: file repr: loc("naam"("nombre"))
-  print("file repr:", repr(locWithChildLoc))
+    with Context() as ctx:
+        loc = Location.name("nombre")
+        locWithChildLoc = Location.name("naam", loc)
+    ctx = None
+    gc.collect()
+    # CHECK: file str: loc("nombre")
+    print("file str:", str(loc))
+    # CHECK: file repr: loc("nombre")
+    print("file repr:", repr(loc))
+    # CHECK: file str: loc("naam"("nombre"))
+    print("file str:", str(locWithChildLoc))
+    # CHECK: file repr: loc("naam"("nombre"))
+    print("file repr:", repr(locWithChildLoc))
+
 
 run(testName)
 
 
 # CHECK-LABEL: TEST: testCallSite
 def testCallSite():
-  with Context() as ctx:
-    loc = Location.callsite(
-        Location.file("foo.text", 123, 45), [
-            Location.file("util.foo", 379, 21),
-            Location.file("main.foo", 100, 63)
-        ])
-  ctx = None
-  # CHECK: file str: loc(callsite("foo.text":123:45 at callsite("util.foo":379:21 at "main.foo":100:63))
-  print("file str:", str(loc))
-  # CHECK: file repr: loc(callsite("foo.text":123:45 at callsite("util.foo":379:21 at "main.foo":100:63))
-  print("file repr:", repr(loc))
+    with Context() as ctx:
+        loc = Location.callsite(
+            Location.file("foo.text", 123, 45),
+            [Location.file("util.foo", 379, 21), Location.file("main.foo", 100, 63)],
+        )
+    ctx = None
+    # CHECK: file str: loc(callsite("foo.text":123:45 at callsite("util.foo":379:21 at "main.foo":100:63))
+    print("file str:", str(loc))
+    # CHECK: file repr: loc(callsite("foo.text":123:45 at callsite("util.foo":379:21 at "main.foo":100:63))
+    print("file repr:", repr(loc))
+
 
 run(testCallSite)
 
 
 # CHECK-LABEL: TEST: testFused
 def testFused():
-  with Context() as ctx:
-    loc_single = Location.fused([Location.name("apple")])
-    loc = Location.fused(
-        [Location.name("apple"), Location.name("banana")])
-    attr = Attribute.parse('"sauteed"')
-    loc_attr = Location.fused([Location.name("carrot"),
-                               Location.name("potatoes")], attr)
-    loc_empty = Location.fused([])
-    loc_empty_attr = Location.fused([], attr)
-    loc_single_attr = Location.fused([Location.name("apple")], attr)
-  ctx = None
-  # CHECK: file str: loc("apple")
-  print("file str:", str(loc_single))
-  # CHECK: file repr: loc("apple")
-  print("file repr:", repr(loc_single))
-  # CHECK: file str: loc(fused["apple", "banana"])
-  print("file str:", str(loc))
-  # CHECK: file repr: loc(fused["apple", "banana"])
-  print("file repr:", repr(loc))
-  # CHECK: file str: loc(fused<"sauteed">["carrot", "potatoes"])
-  print("file str:", str(loc_attr))
-  # CHECK: file repr: loc(fused<"sauteed">["carrot", "potatoes"])
-  print("file repr:", repr(loc_attr))
-  # CHECK: file str: loc(unknown)
-  print("file str:", str(loc_empty))
-  # CHECK: file repr: loc(unknown)
-  print("file repr:", repr(loc_empty))
-  # CHECK: file str: loc(fused<"sauteed">[unknown])
-  print("file str:", str(loc_empty_attr))
-  # CHECK: file repr: loc(fused<"sauteed">[unknown])
-  print("file repr:", repr(loc_empty_attr))
-  # CHECK: file str: loc(fused<"sauteed">["apple"])
-  print("file str:", str(loc_single_attr))
-  # CHECK: file repr: loc(fused<"sauteed">["apple"])
-  print("file repr:", repr(loc_single_attr))
+    with Context() as ctx:
+        loc_single = Location.fused([Location.name("apple")])
+        loc = Location.fused([Location.name("apple"), Location.name("banana")])
+        attr = Attribute.parse('"sauteed"')
+        loc_attr = Location.fused(
+            [Location.name("carrot"), Location.name("potatoes")], attr
+        )
+        loc_empty = Location.fused([])
+        loc_empty_attr = Location.fused([], attr)
+        loc_single_attr = Location.fused([Location.name("apple")], attr)
+    ctx = None
+    # CHECK: file str: loc("apple")
+    print("file str:", str(loc_single))
+    # CHECK: file repr: loc("apple")
+    print("file repr:", repr(loc_single))
+    # CHECK: file str: loc(fused["apple", "banana"])
+    print("file str:", str(loc))
+    # CHECK: file repr: loc(fused["apple", "banana"])
+    print("file repr:", repr(loc))
+    # CHECK: file str: loc(fused<"sauteed">["carrot", "potatoes"])
+    print("file str:", str(loc_attr))
+    # CHECK: file repr: loc(fused<"sauteed">["carrot", "potatoes"])
+    print("file repr:", repr(loc_attr))
+    # CHECK: file str: loc(unknown)
+    print("file str:", str(loc_empty))
+    # CHECK: file repr: loc(unknown)
+    print("file repr:", repr(loc_empty))
+    # CHECK: file str: loc(fused<"sauteed">[unknown])
+    print("file str:", str(loc_empty_attr))
+    # CHECK: file repr: loc(fused<"sauteed">[unknown])
+    print("file repr:", repr(loc_empty_attr))
+    # CHECK: file str: loc(fused<"sauteed">["apple"])
+    print("file str:", str(loc_single_attr))
+    # CHECK: file repr: loc(fused<"sauteed">["apple"])
+    print("file repr:", repr(loc_single_attr))
+
 
 run(testFused)
 
 
 # CHECK-LABEL: TEST: testLocationCapsule
 def testLocationCapsule():
-  with Context() as ctx:
-    loc1 = Location.file("foo.txt", 123, 56)
-  # CHECK: mlir.ir.Location._CAPIPtr
-  loc_capsule = loc1._CAPIPtr
-  print(loc_capsule)
-  loc2 = Location._CAPICreate(loc_capsule)
-  assert loc2 == loc1
-  assert loc2.context is ctx
+    with Context() as ctx:
+        loc1 = Location.file("foo.txt", 123, 56)
+    # CHECK: mlir.ir.Location._CAPIPtr
+    loc_capsule = loc1._CAPIPtr
+    print(loc_capsule)
+    loc2 = Location._CAPICreate(loc_capsule)
+    assert loc2 == loc1
+    assert loc2.context is ctx
+
 
 run(testLocationCapsule)
index 2d00923..a5c38a6 100644 (file)
@@ -3,12 +3,13 @@
 import gc
 from mlir.ir import *
 
+
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  gc.collect()
-  assert Context._get_live_count() == 0
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
+    return f
 
 
 # Verify successful parse.
@@ -16,14 +17,14 @@ def run(f):
 # CHECK: module @successfulParse
 @run
 def testParseSuccess():
-  ctx = Context()
-  module = Module.parse(r"""module @successfulParse {}""", ctx)
-  assert module.context is ctx
-  print("CLEAR CONTEXT")
-  ctx = None  # Ensure that module captures the context.
-  gc.collect()
-  module.dump()  # Just outputs to stderr. Verifies that it functions.
-  print(str(module))
+    ctx = Context()
+    module = Module.parse(r"""module @successfulParse {}""", ctx)
+    assert module.context is ctx
+    print("CLEAR CONTEXT")
+    ctx = None  # Ensure that module captures the context.
+    gc.collect()
+    module.dump()  # Just outputs to stderr. Verifies that it functions.
+    print(str(module))
 
 
 # Verify parse error.
@@ -34,13 +35,13 @@ def testParseSuccess():
 # CHECK: >
 @run
 def testParseError():
-  ctx = Context()
-  try:
-    module = Module.parse(r"""}SYNTAX ERROR{""", ctx)
-  except MLIRError as e:
-    print(f"testParseError: <{e}>")
-  else:
-    print("Exception not produced")
+    ctx = Context()
+    try:
+        module = Module.parse(r"""}SYNTAX ERROR{""", ctx)
+    except MLIRError as e:
+        print(f"testParseError: <{e}>")
+    else:
+        print("Exception not produced")
 
 
 # Verify successful parse.
@@ -48,13 +49,13 @@ def testParseError():
 # CHECK: module {
 @run
 def testCreateEmpty():
-  ctx = Context()
-  loc = Location.unknown(ctx)
-  module = Module.create(loc)
-  print("CLEAR CONTEXT")
-  ctx = None  # Ensure that module captures the context.
-  gc.collect()
-  print(str(module))
+    ctx = Context()
+    loc = Location.unknown(ctx)
+    module = Module.create(loc)
+    print("CLEAR CONTEXT")
+    ctx = None  # Ensure that module captures the context.
+    gc.collect()
+    print(str(module))
 
 
 # Verify round-trip of ASM that contains unicode.
@@ -65,11 +66,14 @@ def testCreateEmpty():
 # CHECK: foo = "\F0\9F\98\8A"
 @run
 def testRoundtripUnicode():
-  ctx = Context()
-  module = Module.parse(r"""
+    ctx = Context()
+    module = Module.parse(
+        r"""
     func.func private @roundtripUnicode() attributes { foo = "😊" }
-  """, ctx)
-  print(str(module))
+  """,
+        ctx,
+    )
+    print(str(module))
 
 
 # Verify round-trip of ASM that contains unicode.
@@ -80,73 +84,74 @@ def testRoundtripUnicode():
 # CHECK: foo = "\F0\9F\98\8A"
 @run
 def testRoundtripBinary():
-  with Context():
-    module = Module.parse(r"""
+    with Context():
+        module = Module.parse(
+            r"""
       func.func private @roundtripUnicode() attributes { foo = "😊" }
-    """)
-    binary_asm = module.operation.get_asm(binary=True)
-    assert isinstance(binary_asm, bytes)
-    module = Module.parse(binary_asm)
-    print(module)
+    """
+        )
+        binary_asm = module.operation.get_asm(binary=True)
+        assert isinstance(binary_asm, bytes)
+        module = Module.parse(binary_asm)
+        print(module)
 
 
 # Tests that module.operation works and correctly interns instances.
 # CHECK-LABEL: TEST: testModuleOperation
 @run
 def testModuleOperation():
-  ctx = Context()
-  module = Module.parse(r"""module @successfulParse {}""", ctx)
-  assert ctx._get_live_module_count() == 1
-  op1 = module.operation
-  assert ctx._get_live_operation_count() == 1
-  # CHECK: module @successfulParse
-  print(op1)
-
-  # Ensure that operations are the same on multiple calls.
-  op2 = module.operation
-  assert ctx._get_live_operation_count() == 1
-  assert op1 is op2
-
-  # Test live operation clearing.
-  op1 = module.operation
-  assert ctx._get_live_operation_count() == 1
-  num_invalidated = ctx._clear_live_operations()
-  assert num_invalidated == 1
-  assert ctx._get_live_operation_count() == 0
-  op1 = None
-  gc.collect()
-  op1 = module.operation
-
-  # Ensure that if module is de-referenced, the operations are still valid.
-  module = None
-  gc.collect()
-  print(op1)
-
-  # Collect and verify lifetime.
-  op1 = None
-  op2 = None
-  gc.collect()
-  print("LIVE OPERATIONS:", ctx._get_live_operation_count())
-  assert ctx._get_live_operation_count() == 0
-  assert ctx._get_live_module_count() == 0
+    ctx = Context()
+    module = Module.parse(r"""module @successfulParse {}""", ctx)
+    assert ctx._get_live_module_count() == 1
+    op1 = module.operation
+    assert ctx._get_live_operation_count() == 1
+    # CHECK: module @successfulParse
+    print(op1)
+
+    # Ensure that operations are the same on multiple calls.
+    op2 = module.operation
+    assert ctx._get_live_operation_count() == 1
+    assert op1 is op2
+
+    # Test live operation clearing.
+    op1 = module.operation
+    assert ctx._get_live_operation_count() == 1
+    num_invalidated = ctx._clear_live_operations()
+    assert num_invalidated == 1
+    assert ctx._get_live_operation_count() == 0
+    op1 = None
+    gc.collect()
+    op1 = module.operation
+
+    # Ensure that if the module is de-referenced, the operations are still valid.
+    module = None
+    gc.collect()
+    print(op1)
+
+    # Collect and verify lifetime.
+    op1 = None
+    op2 = None
+    gc.collect()
+    print("LIVE OPERATIONS:", ctx._get_live_operation_count())
+    assert ctx._get_live_operation_count() == 0
+    assert ctx._get_live_module_count() == 0
 
 
 # CHECK-LABEL: TEST: testModuleCapsule
 @run
 def testModuleCapsule():
-  ctx = Context()
-  module = Module.parse(r"""module @successfulParse {}""", ctx)
-  assert ctx._get_live_module_count() == 1
-  # CHECK: "mlir.ir.Module._CAPIPtr"
-  module_capsule = module._CAPIPtr
-  print(module_capsule)
-  module_dup = Module._CAPICreate(module_capsule)
-  assert module is module_dup
-  assert module_dup.context is ctx
-  # Gc and verify destructed.
-  module = None
-  module_capsule = None
-  module_dup = None
-  gc.collect()
-  assert ctx._get_live_module_count() == 0
-
+    ctx = Context()
+    module = Module.parse(r"""module @successfulParse {}""", ctx)
+    assert ctx._get_live_module_count() == 1
+    # CHECK: "mlir.ir.Module._CAPIPtr"
+    module_capsule = module._CAPIPtr
+    print(module_capsule)
+    module_dup = Module._CAPICreate(module_capsule)
+    assert module is module_dup
+    assert module_dup.context is ctx
+    # GC and verify destruction.
+    module = None
+    module_capsule = None
+    module_dup = None
+    gc.collect()
+    assert ctx._get_live_module_count() == 0
index 22a8089..639f8ff 100644 (file)
@@ -8,232 +8,242 @@ from mlir.dialects.builtin import ModuleOp
 
 
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  gc.collect()
-  assert Context._get_live_count() == 0
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
+    return f
 
 
 def expect_index_error(callback):
-  try:
-    _ = callback()
-    raise RuntimeError("Expected IndexError")
-  except IndexError:
-    pass
+    try:
+        _ = callback()
+        raise RuntimeError("Expected IndexError")
+    except IndexError:
+        pass
 
 
 # Verify iterator based traversal of the op/region/block hierarchy.
 # CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
 @run
 def testTraverseOpRegionBlockIterators():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  module = Module.parse(
-      r"""
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    module = Module.parse(
+        r"""
     func.func @f1(%arg0: i32) -> i32 {
       %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
       return %1 : i32
     }
-  """, ctx)
-  op = module.operation
-  assert op.context is ctx
-  # Get the block using iterators off of the named collections.
-  regions = list(op.regions)
-  blocks = list(regions[0].blocks)
-  # CHECK: MODULE REGIONS=1 BLOCKS=1
-  print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")
-
-  # Should verify.
-  # CHECK: .verify = True
-  print(f".verify = {module.operation.verify()}")
-
-  # Get the blocks from the default collection.
-  default_blocks = list(regions[0])
-  # They should compare equal regardless of how obtained.
-  assert default_blocks == blocks
-
-  # Should be able to get the operations from either the named collection
-  # or the block.
-  operations = list(blocks[0].operations)
-  default_operations = list(blocks[0])
-  assert default_operations == operations
-
-  def walk_operations(indent, op):
-    for i, region in enumerate(op.regions):
-      print(f"{indent}REGION {i}:")
-      for j, block in enumerate(region):
-        print(f"{indent}  BLOCK {j}:")
-        for k, child_op in enumerate(block):
-          print(f"{indent}    OP {k}: {child_op}")
-          walk_operations(indent + "      ", child_op)
-
-  # CHECK: REGION 0:
-  # CHECK:   BLOCK 0:
-  # CHECK:     OP 0: func
-  # CHECK:       REGION 0:
-  # CHECK:         BLOCK 0:
-  # CHECK:           OP 0: %0 = "custom.addi"
-  # CHECK:           OP 1: func.return
-  walk_operations("", op)
-
-  # CHECK:    Region iter: <mlir.{{.+}}.RegionIterator
-  # CHECK:     Block iter: <mlir.{{.+}}.BlockIterator
-  # CHECK: Operation iter: <mlir.{{.+}}.OperationIterator
-  print("   Region iter:", iter(op.regions))
-  print("    Block iter:", iter(op.regions[0]))
-  print("Operation iter:", iter(op.regions[0].blocks[0]))
+  """,
+        ctx,
+    )
+    op = module.operation
+    assert op.context is ctx
+    # Get the block using iterators off of the named collections.
+    regions = list(op.regions)
+    blocks = list(regions[0].blocks)
+    # CHECK: MODULE REGIONS=1 BLOCKS=1
+    print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")
+
+    # Should verify.
+    # CHECK: .verify = True
+    print(f".verify = {module.operation.verify()}")
+
+    # Get the blocks from the default collection.
+    default_blocks = list(regions[0])
+    # They should compare equal regardless of how obtained.
+    assert default_blocks == blocks
+
+    # Should be able to get the operations from either the named collection
+    # or the block.
+    operations = list(blocks[0].operations)
+    default_operations = list(blocks[0])
+    assert default_operations == operations
+
+    def walk_operations(indent, op):
+        for i, region in enumerate(op.regions):
+            print(f"{indent}REGION {i}:")
+            for j, block in enumerate(region):
+                print(f"{indent}  BLOCK {j}:")
+                for k, child_op in enumerate(block):
+                    print(f"{indent}    OP {k}: {child_op}")
+                    walk_operations(indent + "      ", child_op)
+
+    # CHECK: REGION 0:
+    # CHECK:   BLOCK 0:
+    # CHECK:     OP 0: func
+    # CHECK:       REGION 0:
+    # CHECK:         BLOCK 0:
+    # CHECK:           OP 0: %0 = "custom.addi"
+    # CHECK:           OP 1: func.return
+    walk_operations("", op)
+
+    # CHECK:    Region iter: <mlir.{{.+}}.RegionIterator
+    # CHECK:     Block iter: <mlir.{{.+}}.BlockIterator
+    # CHECK: Operation iter: <mlir.{{.+}}.OperationIterator
+    print("   Region iter:", iter(op.regions))
+    print("    Block iter:", iter(op.regions[0]))
+    print("Operation iter:", iter(op.regions[0].blocks[0]))
 
 
 # Verify index based traversal of the op/region/block hierarchy.
 # CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
 @run
 def testTraverseOpRegionBlockIndices():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  module = Module.parse(
-      r"""
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    module = Module.parse(
+        r"""
     func.func @f1(%arg0: i32) -> i32 {
       %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
       return %1 : i32
     }
-  """, ctx)
-
-  def walk_operations(indent, op):
-    for i in range(len(op.regions)):
-      region = op.regions[i]
-      print(f"{indent}REGION {i}:")
-      for j in range(len(region.blocks)):
-        block = region.blocks[j]
-        print(f"{indent}  BLOCK {j}:")
-        for k in range(len(block.operations)):
-          child_op = block.operations[k]
-          print(f"{indent}    OP {k}: {child_op}")
-          print(f"{indent}    OP {k}: parent {child_op.operation.parent.name}")
-          walk_operations(indent + "      ", child_op)
-
-  # CHECK: REGION 0:
-  # CHECK:   BLOCK 0:
-  # CHECK:     OP 0: func
-  # CHECK:     OP 0: parent builtin.module
-  # CHECK:       REGION 0:
-  # CHECK:         BLOCK 0:
-  # CHECK:           OP 0: %0 = "custom.addi"
-  # CHECK:           OP 0: parent func.func
-  # CHECK:           OP 1: func.return
-  # CHECK:           OP 1: parent func.func
-  walk_operations("", module.operation)
+  """,
+        ctx,
+    )
+
+    def walk_operations(indent, op):
+        for i in range(len(op.regions)):
+            region = op.regions[i]
+            print(f"{indent}REGION {i}:")
+            for j in range(len(region.blocks)):
+                block = region.blocks[j]
+                print(f"{indent}  BLOCK {j}:")
+                for k in range(len(block.operations)):
+                    child_op = block.operations[k]
+                    print(f"{indent}    OP {k}: {child_op}")
+                    print(
+                        f"{indent}    OP {k}: parent {child_op.operation.parent.name}"
+                    )
+                    walk_operations(indent + "      ", child_op)
+
+    # CHECK: REGION 0:
+    # CHECK:   BLOCK 0:
+    # CHECK:     OP 0: func
+    # CHECK:     OP 0: parent builtin.module
+    # CHECK:       REGION 0:
+    # CHECK:         BLOCK 0:
+    # CHECK:           OP 0: %0 = "custom.addi"
+    # CHECK:           OP 0: parent func.func
+    # CHECK:           OP 1: func.return
+    # CHECK:           OP 1: parent func.func
+    walk_operations("", module.operation)
 
 
 # CHECK-LABEL: TEST: testBlockAndRegionOwners
 @run
 def testBlockAndRegionOwners():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  module = Module.parse(
-      r"""
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    module = Module.parse(
+        r"""
     builtin.module {
       func.func @f() {
         func.return
       }
     }
-  """, ctx)
+  """,
+        ctx,
+    )
 
-  assert module.operation.regions[0].owner == module.operation
-  assert module.operation.regions[0].blocks[0].owner == module.operation
+    assert module.operation.regions[0].owner == module.operation
+    assert module.operation.regions[0].blocks[0].owner == module.operation
 
-  func = module.body.operations[0]
-  assert func.operation.regions[0].owner == func
-  assert func.operation.regions[0].blocks[0].owner == func
+    func = module.body.operations[0]
+    assert func.operation.regions[0].owner == func
+    assert func.operation.regions[0].blocks[0].owner == func
 
 
 # CHECK-LABEL: TEST: testBlockArgumentList
 @run
 def testBlockArgumentList():
-  with Context() as ctx:
-    module = Module.parse(
-        r"""
+    with Context() as ctx:
+        module = Module.parse(
+            r"""
       func.func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
         return
       }
-    """, ctx)
-    func = module.body.operations[0]
-    entry_block = func.regions[0].blocks[0]
-    assert len(entry_block.arguments) == 3
-    # CHECK: Argument 0, type i32
-    # CHECK: Argument 1, type f64
-    # CHECK: Argument 2, type index
-    for arg in entry_block.arguments:
-      print(f"Argument {arg.arg_number}, type {arg.type}")
-      new_type = IntegerType.get_signless(8 * (arg.arg_number + 1))
-      arg.set_type(new_type)
-
-    # CHECK: Argument 0, type i8
-    # CHECK: Argument 1, type i16
-    # CHECK: Argument 2, type i24
-    for arg in entry_block.arguments:
-      print(f"Argument {arg.arg_number}, type {arg.type}")
-
-    # Check that slicing works for block argument lists.
-    # CHECK: Argument 1, type i16
-    # CHECK: Argument 2, type i24
-    for arg in entry_block.arguments[1:]:
-      print(f"Argument {arg.arg_number}, type {arg.type}")
-
-    # Check that we can concatenate slices of argument lists.
-    # CHECK: Length: 4
-    print("Length: ",
-          len(entry_block.arguments[:2] + entry_block.arguments[1:]))
-
-    # CHECK: Type: i8
-    # CHECK: Type: i16
-    # CHECK: Type: i24
-    for t in entry_block.arguments.types:
-      print("Type: ", t)
-
-    # Check that slicing and type access compose.
-    # CHECK: Sliced type: i16
-    # CHECK: Sliced type: i24
-    for t in entry_block.arguments[1:].types:
-      print("Sliced type: ", t)
-
-    # Check that slice addition works as expected.
-    # CHECK: Argument 2, type i24
-    # CHECK: Argument 0, type i8
-    restructured = entry_block.arguments[-1:] + entry_block.arguments[:1]
-    for arg in restructured:
-      print(f"Argument {arg.arg_number}, type {arg.type}")
+    """,
+            ctx,
+        )
+        func = module.body.operations[0]
+        entry_block = func.regions[0].blocks[0]
+        assert len(entry_block.arguments) == 3
+        # CHECK: Argument 0, type i32
+        # CHECK: Argument 1, type f64
+        # CHECK: Argument 2, type index
+        for arg in entry_block.arguments:
+            print(f"Argument {arg.arg_number}, type {arg.type}")
+            new_type = IntegerType.get_signless(8 * (arg.arg_number + 1))
+            arg.set_type(new_type)
+
+        # CHECK: Argument 0, type i8
+        # CHECK: Argument 1, type i16
+        # CHECK: Argument 2, type i24
+        for arg in entry_block.arguments:
+            print(f"Argument {arg.arg_number}, type {arg.type}")
+
+        # Check that slicing works for block argument lists.
+        # CHECK: Argument 1, type i16
+        # CHECK: Argument 2, type i24
+        for arg in entry_block.arguments[1:]:
+            print(f"Argument {arg.arg_number}, type {arg.type}")
+
+        # Check that we can concatenate slices of argument lists.
+        # CHECK: Length: 4
+        print("Length: ", len(entry_block.arguments[:2] + entry_block.arguments[1:]))
+
+        # CHECK: Type: i8
+        # CHECK: Type: i16
+        # CHECK: Type: i24
+        for t in entry_block.arguments.types:
+            print("Type: ", t)
+
+        # Check that slicing and type access compose.
+        # CHECK: Sliced type: i16
+        # CHECK: Sliced type: i24
+        for t in entry_block.arguments[1:].types:
+            print("Sliced type: ", t)
+
+        # Check that slice addition works as expected.
+        # CHECK: Argument 2, type i24
+        # CHECK: Argument 0, type i8
+        restructured = entry_block.arguments[-1:] + entry_block.arguments[:1]
+        for arg in restructured:
+            print(f"Argument {arg.arg_number}, type {arg.type}")
 
 
 # CHECK-LABEL: TEST: testOperationOperands
 @run
 def testOperationOperands():
-  with Context() as ctx:
-    ctx.allow_unregistered_dialects = True
-    module = Module.parse(r"""
+    with Context() as ctx:
+        ctx.allow_unregistered_dialects = True
+        module = Module.parse(
+            r"""
       func.func @f1(%arg0: i32) {
         %0 = "test.producer"() : () -> i64
         "test.consumer"(%arg0, %0) : (i32, i64) -> ()
         return
-      }""")
-    func = module.body.operations[0]
-    entry_block = func.regions[0].blocks[0]
-    consumer = entry_block.operations[1]
-    assert len(consumer.operands) == 2
-    # CHECK: Operand 0, type i32
-    # CHECK: Operand 1, type i64
-    for i, operand in enumerate(consumer.operands):
-      print(f"Operand {i}, type {operand.type}")
-
-
+      }"""
+        )
+        func = module.body.operations[0]
+        entry_block = func.regions[0].blocks[0]
+        consumer = entry_block.operations[1]
+        assert len(consumer.operands) == 2
+        # CHECK: Operand 0, type i32
+        # CHECK: Operand 1, type i64
+        for i, operand in enumerate(consumer.operands):
+            print(f"Operand {i}, type {operand.type}")
 
 
 # CHECK-LABEL: TEST: testOperationOperandsSlice
 @run
 def testOperationOperandsSlice():
-  with Context() as ctx:
-    ctx.allow_unregistered_dialects = True
-    module = Module.parse(r"""
+    with Context() as ctx:
+        ctx.allow_unregistered_dialects = True
+        module = Module.parse(
+            r"""
       func.func @f1() {
         %0 = "test.producer0"() : () -> i64
         %1 = "test.producer1"() : () -> i64
@@ -242,708 +252,727 @@ def testOperationOperandsSlice():
         %4 = "test.producer4"() : () -> i64
         "test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
         return
-      }""")
-    func = module.body.operations[0]
-    entry_block = func.regions[0].blocks[0]
-    consumer = entry_block.operations[5]
-    assert len(consumer.operands) == 5
-    for left, right in zip(consumer.operands, consumer.operands[::-1][::-1]):
-      assert left == right
-
-    # CHECK: test.producer0
-    # CHECK: test.producer1
-    # CHECK: test.producer2
-    # CHECK: test.producer3
-    # CHECK: test.producer4
-    full_slice = consumer.operands[:]
-    for operand in full_slice:
-      print(operand)
-
-    # CHECK: test.producer0
-    # CHECK: test.producer1
-    first_two = consumer.operands[0:2]
-    for operand in first_two:
-      print(operand)
-
-    # CHECK: test.producer3
-    # CHECK: test.producer4
-    last_two = consumer.operands[3:]
-    for operand in last_two:
-      print(operand)
-
-    # CHECK: test.producer0
-    # CHECK: test.producer2
-    # CHECK: test.producer4
-    even = consumer.operands[::2]
-    for operand in even:
-      print(operand)
-
-    # CHECK: test.producer2
-    fourth = consumer.operands[::2][1::2]
-    for operand in fourth:
-      print(operand)
-
-
+      }"""
+        )
+        func = module.body.operations[0]
+        entry_block = func.regions[0].blocks[0]
+        consumer = entry_block.operations[5]
+        assert len(consumer.operands) == 5
+        for left, right in zip(consumer.operands, consumer.operands[::-1][::-1]):
+            assert left == right
+
+        # CHECK: test.producer0
+        # CHECK: test.producer1
+        # CHECK: test.producer2
+        # CHECK: test.producer3
+        # CHECK: test.producer4
+        full_slice = consumer.operands[:]
+        for operand in full_slice:
+            print(operand)
+
+        # CHECK: test.producer0
+        # CHECK: test.producer1
+        first_two = consumer.operands[0:2]
+        for operand in first_two:
+            print(operand)
+
+        # CHECK: test.producer3
+        # CHECK: test.producer4
+        last_two = consumer.operands[3:]
+        for operand in last_two:
+            print(operand)
+
+        # CHECK: test.producer0
+        # CHECK: test.producer2
+        # CHECK: test.producer4
+        even = consumer.operands[::2]
+        for operand in even:
+            print(operand)
+
+        # CHECK: test.producer2
+        fourth = consumer.operands[::2][1::2]
+        for operand in fourth:
+            print(operand)
 
 
 # CHECK-LABEL: TEST: testOperationOperandsSet
 @run
 def testOperationOperandsSet():
-  with Context() as ctx, Location.unknown(ctx):
-    ctx.allow_unregistered_dialects = True
-    module = Module.parse(r"""
+    with Context() as ctx, Location.unknown(ctx):
+        ctx.allow_unregistered_dialects = True
+        module = Module.parse(
+            r"""
       func.func @f1() {
         %0 = "test.producer0"() : () -> i64
         %1 = "test.producer1"() : () -> i64
         %2 = "test.producer2"() : () -> i64
         "test.consumer"(%0) : (i64) -> ()
         return
-      }""")
-    func = module.body.operations[0]
-    entry_block = func.regions[0].blocks[0]
-    producer1 = entry_block.operations[1]
-    producer2 = entry_block.operations[2]
-    consumer = entry_block.operations[3]
-    assert len(consumer.operands) == 1
-    type = consumer.operands[0].type
-
-    # CHECK: test.producer1
-    consumer.operands[0] = producer1.result
-    print(consumer.operands[0])
-
-    # CHECK: test.producer2
-    consumer.operands[-1] = producer2.result
-    print(consumer.operands[0])
+      }"""
+        )
+        func = module.body.operations[0]
+        entry_block = func.regions[0].blocks[0]
+        producer1 = entry_block.operations[1]
+        producer2 = entry_block.operations[2]
+        consumer = entry_block.operations[3]
+        assert len(consumer.operands) == 1
+        type = consumer.operands[0].type
 
+        # CHECK: test.producer1
+        consumer.operands[0] = producer1.result
+        print(consumer.operands[0])
 
+        # CHECK: test.producer2
+        consumer.operands[-1] = producer2.result
+        print(consumer.operands[0])
 
 
 # CHECK-LABEL: TEST: testDetachedOperation
 @run
 def testDetachedOperation():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  with Location.unknown(ctx):
-    i32 = IntegerType.get_signed(32)
-    op1 = Operation.create(
-        "custom.op1",
-        results=[i32, i32],
-        regions=1,
-        attributes={
-            "foo": StringAttr.get("foo_value"),
-            "bar": StringAttr.get("bar_value"),
-        })
-    # CHECK: %0:2 = "custom.op1"() ({
-    # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
-    print(op1)
-
-  # TODO: Check successors once enough infra exists to do it properly.
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        i32 = IntegerType.get_signed(32)
+        op1 = Operation.create(
+            "custom.op1",
+            results=[i32, i32],
+            regions=1,
+            attributes={
+                "foo": StringAttr.get("foo_value"),
+                "bar": StringAttr.get("bar_value"),
+            },
+        )
+        # CHECK: %0:2 = "custom.op1"() ({
+        # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
+        print(op1)
+
+    # TODO: Check successors once enough infra exists to do it properly.
 
 
 # CHECK-LABEL: TEST: testOperationInsertionPoint
 @run
 def testOperationInsertionPoint():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  module = Module.parse(
-      r"""
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    module = Module.parse(
+        r"""
     func.func @f1(%arg0: i32) -> i32 {
       %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
       return %1 : i32
     }
-  """, ctx)
-
-  # Create test op.
-  with Location.unknown(ctx):
-    op1 = Operation.create("custom.op1")
-    op2 = Operation.create("custom.op2")
-
-    func = module.body.operations[0]
-    entry_block = func.regions[0].blocks[0]
-    ip = InsertionPoint.at_block_begin(entry_block)
-    ip.insert(op1)
-    ip.insert(op2)
-    # CHECK: func @f1
-    # CHECK: "custom.op1"()
-    # CHECK: "custom.op2"()
-    # CHECK: %0 = "custom.addi"
-    print(module)
-
-  # Trying to add a previously added op should raise.
-  try:
-    ip.insert(op1)
-  except ValueError:
-    pass
-  else:
-    assert False, "expected insert of attached op to raise"
+  """,
+        ctx,
+    )
+
+    # Create test op.
+    with Location.unknown(ctx):
+        op1 = Operation.create("custom.op1")
+        op2 = Operation.create("custom.op2")
+
+        func = module.body.operations[0]
+        entry_block = func.regions[0].blocks[0]
+        ip = InsertionPoint.at_block_begin(entry_block)
+        ip.insert(op1)
+        ip.insert(op2)
+        # CHECK: func @f1
+        # CHECK: "custom.op1"()
+        # CHECK: "custom.op2"()
+        # CHECK: %0 = "custom.addi"
+        print(module)
+
+    # Trying to add a previously added op should raise.
+    try:
+        ip.insert(op1)
+    except ValueError:
+        pass
+    else:
+        assert False, "expected insert of attached op to raise"
 
 
 # CHECK-LABEL: TEST: testOperationWithRegion
 @run
 def testOperationWithRegion():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  with Location.unknown(ctx):
-    i32 = IntegerType.get_signed(32)
-    op1 = Operation.create("custom.op1", regions=1)
-    block = op1.regions[0].blocks.append(i32, i32)
-    # CHECK: "custom.op1"() ({
-    # CHECK: ^bb0(%arg0: si32, %arg1: si32):
-    # CHECK:   "custom.terminator"() : () -> ()
-    # CHECK: }) : () -> ()
-    terminator = Operation.create("custom.terminator")
-    ip = InsertionPoint(block)
-    ip.insert(terminator)
-    print(op1)
-
-    # Now add the whole operation to another op.
-    # TODO: Verify lifetime hazard by nulling out the new owning module and
-    # accessing op1.
-    # TODO: Also verify accessing the terminator once both parents are nulled
-    # out.
-    module = Module.parse(r"""
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        i32 = IntegerType.get_signed(32)
+        op1 = Operation.create("custom.op1", regions=1)
+        block = op1.regions[0].blocks.append(i32, i32)
+        # CHECK: "custom.op1"() ({
+        # CHECK: ^bb0(%arg0: si32, %arg1: si32):
+        # CHECK:   "custom.terminator"() : () -> ()
+        # CHECK: }) : () -> ()
+        terminator = Operation.create("custom.terminator")
+        ip = InsertionPoint(block)
+        ip.insert(terminator)
+        print(op1)
+
+        # Now add the whole operation to another op.
+        # TODO: Verify lifetime hazard by nulling out the new owning module and
+        # accessing op1.
+        # TODO: Also verify accessing the terminator once both parents are nulled
+        # out.
+        module = Module.parse(
+            r"""
       func.func @f1(%arg0: i32) -> i32 {
         %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
         return %1 : i32
       }
-    """)
-    func = module.body.operations[0]
-    entry_block = func.regions[0].blocks[0]
-    ip = InsertionPoint.at_block_begin(entry_block)
-    ip.insert(op1)
-    # CHECK: func @f1
-    # CHECK: "custom.op1"()
-    # CHECK:   "custom.terminator"
-    # CHECK: %0 = "custom.addi"
-    print(module)
+    """
+        )
+        func = module.body.operations[0]
+        entry_block = func.regions[0].blocks[0]
+        ip = InsertionPoint.at_block_begin(entry_block)
+        ip.insert(op1)
+        # CHECK: func @f1
+        # CHECK: "custom.op1"()
+        # CHECK:   "custom.terminator"
+        # CHECK: %0 = "custom.addi"
+        print(module)
 
 
 # CHECK-LABEL: TEST: testOperationResultList
 @run
 def testOperationResultList():
-  ctx = Context()
-  module = Module.parse(
-      r"""
+    ctx = Context()
+    module = Module.parse(
+        r"""
     func.func @f1() {
       %0:3 = call @f2() : () -> (i32, f64, index)
       return
     }
     func.func private @f2() -> (i32, f64, index)
-  """, ctx)
-  caller = module.body.operations[0]
-  call = caller.regions[0].blocks[0].operations[0]
-  assert len(call.results) == 3
-  # CHECK: Result 0, type i32
-  # CHECK: Result 1, type f64
-  # CHECK: Result 2, type index
-  for res in call.results:
-    print(f"Result {res.result_number}, type {res.type}")
-
-  # CHECK: Result type i32
-  # CHECK: Result type f64
-  # CHECK: Result type index
-  for t in call.results.types:
-    print(f"Result type {t}")
-
-  # Out of range
-  expect_index_error(lambda: call.results[3])
-  expect_index_error(lambda: call.results[-4])
+  """,
+        ctx,
+    )
+    caller = module.body.operations[0]
+    call = caller.regions[0].blocks[0].operations[0]
+    assert len(call.results) == 3
+    # CHECK: Result 0, type i32
+    # CHECK: Result 1, type f64
+    # CHECK: Result 2, type index
+    for res in call.results:
+        print(f"Result {res.result_number}, type {res.type}")
+
+    # CHECK: Result type i32
+    # CHECK: Result type f64
+    # CHECK: Result type index
+    for t in call.results.types:
+        print(f"Result type {t}")
+
+    # Out of range
+    expect_index_error(lambda: call.results[3])
+    expect_index_error(lambda: call.results[-4])
 
 
 # CHECK-LABEL: TEST: testOperationResultListSlice
 @run
 def testOperationResultListSlice():
-  with Context() as ctx:
-    ctx.allow_unregistered_dialects = True
-    module = Module.parse(r"""
+    with Context() as ctx:
+        ctx.allow_unregistered_dialects = True
+        module = Module.parse(
+            r"""
       func.func @f1() {
         "some.op"() : () -> (i1, i2, i3, i4, i5)
         return
       }
-    """)
-    func = module.body.operations[0]
-    entry_block = func.regions[0].blocks[0]
-    producer = entry_block.operations[0]
-
-    assert len(producer.results) == 5
-    for left, right in zip(producer.results, producer.results[::-1][::-1]):
-      assert left == right
-      assert left.result_number == right.result_number
-
-    # CHECK: Result 0, type i1
-    # CHECK: Result 1, type i2
-    # CHECK: Result 2, type i3
-    # CHECK: Result 3, type i4
-    # CHECK: Result 4, type i5
-    full_slice = producer.results[:]
-    for res in full_slice:
-      print(f"Result {res.result_number}, type {res.type}")
-
-    # CHECK: Result 1, type i2
-    # CHECK: Result 2, type i3
-    # CHECK: Result 3, type i4
-    middle = producer.results[1:4]
-    for res in middle:
-      print(f"Result {res.result_number}, type {res.type}")
-
-    # CHECK: Result 1, type i2
-    # CHECK: Result 3, type i4
-    odd = producer.results[1::2]
-    for res in odd:
-      print(f"Result {res.result_number}, type {res.type}")
-
-    # CHECK: Result 3, type i4
-    # CHECK: Result 1, type i2
-    inverted_middle = producer.results[-2:0:-2]
-    for res in inverted_middle:
-      print(f"Result {res.result_number}, type {res.type}")
+    """
+        )
+        func = module.body.operations[0]
+        entry_block = func.regions[0].blocks[0]
+        producer = entry_block.operations[0]
+
+        assert len(producer.results) == 5
+        for left, right in zip(producer.results, producer.results[::-1][::-1]):
+            assert left == right
+            assert left.result_number == right.result_number
+
+        # CHECK: Result 0, type i1
+        # CHECK: Result 1, type i2
+        # CHECK: Result 2, type i3
+        # CHECK: Result 3, type i4
+        # CHECK: Result 4, type i5
+        full_slice = producer.results[:]
+        for res in full_slice:
+            print(f"Result {res.result_number}, type {res.type}")
+
+        # CHECK: Result 1, type i2
+        # CHECK: Result 2, type i3
+        # CHECK: Result 3, type i4
+        middle = producer.results[1:4]
+        for res in middle:
+            print(f"Result {res.result_number}, type {res.type}")
+
+        # CHECK: Result 1, type i2
+        # CHECK: Result 3, type i4
+        odd = producer.results[1::2]
+        for res in odd:
+            print(f"Result {res.result_number}, type {res.type}")
+
+        # CHECK: Result 3, type i4
+        # CHECK: Result 1, type i2
+        inverted_middle = producer.results[-2:0:-2]
+        for res in inverted_middle:
+            print(f"Result {res.result_number}, type {res.type}")
 
 
 # CHECK-LABEL: TEST: testOperationAttributes
 @run
 def testOperationAttributes():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  module = Module.parse(
-      r"""
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    module = Module.parse(
+        r"""
     "some.op"() { some.attribute = 1 : i8,
                   other.attribute = 3.0,
                   dependent = "text" } : () -> ()
-  """, ctx)
-  op = module.body.operations[0]
-  assert len(op.attributes) == 3
-  iattr = IntegerAttr(op.attributes["some.attribute"])
-  fattr = FloatAttr(op.attributes["other.attribute"])
-  sattr = StringAttr(op.attributes["dependent"])
-  # CHECK: Attribute type i8, value 1
-  print(f"Attribute type {iattr.type}, value {iattr.value}")
-  # CHECK: Attribute type f64, value 3.0
-  print(f"Attribute type {fattr.type}, value {fattr.value}")
-  # CHECK: Attribute value text
-  print(f"Attribute value {sattr.value}")
-  # CHECK: Attribute value b'text'
-  print(f"Attribute value {sattr.value_bytes}")
-
-  # We don't know in which order the attributes are stored.
-  # CHECK-DAG: NamedAttribute(dependent="text")
-  # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
-  # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
-  for attr in op.attributes:
-    print(str(attr))
-
-  # Check that exceptions are raised as expected.
-  try:
-    op.attributes["does_not_exist"]
-  except KeyError:
-    pass
-  else:
-    assert False, "expected KeyError on accessing a non-existent attribute"
-
-  try:
-    op.attributes[42]
-  except IndexError:
-    pass
-  else:
-    assert False, "expected IndexError on accessing an out-of-bounds attribute"
-
+  """,
+        ctx,
+    )
+    op = module.body.operations[0]
+    assert len(op.attributes) == 3
+    iattr = IntegerAttr(op.attributes["some.attribute"])
+    fattr = FloatAttr(op.attributes["other.attribute"])
+    sattr = StringAttr(op.attributes["dependent"])
+    # CHECK: Attribute type i8, value 1
+    print(f"Attribute type {iattr.type}, value {iattr.value}")
+    # CHECK: Attribute type f64, value 3.0
+    print(f"Attribute type {fattr.type}, value {fattr.value}")
+    # CHECK: Attribute value text
+    print(f"Attribute value {sattr.value}")
+    # CHECK: Attribute value b'text'
+    print(f"Attribute value {sattr.value_bytes}")
+
+    # We don't know in which order the attributes are stored.
+    # CHECK-DAG: NamedAttribute(dependent="text")
+    # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
+    # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
+    for attr in op.attributes:
+        print(str(attr))
+
+    # Check that exceptions are raised as expected.
+    try:
+        op.attributes["does_not_exist"]
+    except KeyError:
+        pass
+    else:
+        assert False, "expected KeyError on accessing a non-existent attribute"
 
+    try:
+        op.attributes[42]
+    except IndexError:
+        pass
+    else:
+        assert False, "expected IndexError on accessing an out-of-bounds attribute"
 
 
 # CHECK-LABEL: TEST: testOperationPrint
 @run
 def testOperationPrint():
-  ctx = Context()
-  module = Module.parse(
-      r"""
+    ctx = Context()
+    module = Module.parse(
+        r"""
     func.func @f1(%arg0: i32) -> i32 {
       %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
       return %arg0 : i32
     }
-  """, ctx)
-
-  # Test print to stdout.
-  # CHECK: return %arg0 : i32
-  module.operation.print()
-
-  # Test print to text file.
-  f = io.StringIO()
-  # CHECK: <class 'str'>
-  # CHECK: return %arg0 : i32
-  module.operation.print(file=f)
-  str_value = f.getvalue()
-  print(str_value.__class__)
-  print(f.getvalue())
-
-  # Test roundtrip to bytecode.
-  bytecode_stream = io.BytesIO()
-  module.operation.write_bytecode(bytecode_stream, desired_version=1)
-  bytecode = bytecode_stream.getvalue()
-  assert bytecode.startswith(b'ML\xefR'), "Expected bytecode to start with MLïR"
-  module_roundtrip = Module.parse(bytecode, ctx)
-  f = io.StringIO()
-  module_roundtrip.operation.print(file=f)
-  roundtrip_value = f.getvalue()
-  assert str_value == roundtrip_value, "Mismatch after roundtrip bytecode"
-
-
-  # Test print to binary file.
-  f = io.BytesIO()
-  # CHECK: <class 'bytes'>
-  # CHECK: return %arg0 : i32
-  module.operation.print(file=f, binary=True)
-  bytes_value = f.getvalue()
-  print(bytes_value.__class__)
-  print(bytes_value)
-
-  # Test get_asm local_scope.
-  # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
-  module.operation.print(enable_debug_info=True, use_local_scope=True)
-
-  # Test get_asm with options.
-  # CHECK: value = dense_resource<__elided__> : tensor<4xi32>
-  # CHECK: "func.return"(%arg0) : (i32) -> () -:4:7
-  module.operation.print(
-      large_elements_limit=2,
-      enable_debug_info=True,
-      pretty_debug_info=True,
-      print_generic_op_form=True,
-      use_local_scope=True)
-
-
+  """,
+        ctx,
+    )
+
+    # Test print to stdout.
+    # CHECK: return %arg0 : i32
+    module.operation.print()
+
+    # Test print to text file.
+    f = io.StringIO()
+    # CHECK: <class 'str'>
+    # CHECK: return %arg0 : i32
+    module.operation.print(file=f)
+    str_value = f.getvalue()
+    print(str_value.__class__)
+    print(f.getvalue())
+
+    # Test roundtrip to bytecode.
+    bytecode_stream = io.BytesIO()
+    module.operation.write_bytecode(bytecode_stream, desired_version=1)
+    bytecode = bytecode_stream.getvalue()
+    assert bytecode.startswith(b"ML\xefR"), "Expected bytecode to start with MLïR"
+    module_roundtrip = Module.parse(bytecode, ctx)
+    f = io.StringIO()
+    module_roundtrip.operation.print(file=f)
+    roundtrip_value = f.getvalue()
+    assert str_value == roundtrip_value, "Mismatch after roundtrip bytecode"
+
+    # Test print to binary file.
+    f = io.BytesIO()
+    # CHECK: <class 'bytes'>
+    # CHECK: return %arg0 : i32
+    module.operation.print(file=f, binary=True)
+    bytes_value = f.getvalue()
+    print(bytes_value.__class__)
+    print(bytes_value)
+
+    # Test get_asm local_scope.
+    # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
+    module.operation.print(enable_debug_info=True, use_local_scope=True)
+
+    # Test get_asm with options.
+    # CHECK: value = dense_resource<__elided__> : tensor<4xi32>
+    # CHECK: "func.return"(%arg0) : (i32) -> () -:4:7
+    module.operation.print(
+        large_elements_limit=2,
+        enable_debug_info=True,
+        pretty_debug_info=True,
+        print_generic_op_form=True,
+        use_local_scope=True,
+    )
 
 
 # CHECK-LABEL: TEST: testKnownOpView
 @run
 def testKnownOpView():
-  with Context(), Location.unknown():
-    Context.current.allow_unregistered_dialects = True
-    module = Module.parse(r"""
+    with Context(), Location.unknown():
+        Context.current.allow_unregistered_dialects = True
+        module = Module.parse(
+            r"""
       %1 = "custom.f32"() : () -> f32
       %2 = "custom.f32"() : () -> f32
       %3 = arith.addf %1, %2 : f32
-    """)
-    print(module)
+    """
+        )
+        print(module)
 
-    # addf should map to a known OpView class in the arithmetic dialect.
-    # We know the OpView for it defines an 'lhs' attribute.
-    addf = module.body.operations[2]
-    # CHECK: <mlir.dialects._arith_ops_gen.AddFOp object
-    print(repr(addf))
-    # CHECK: "custom.f32"()
-    print(addf.lhs)
+        # addf should map to a known OpView class in the arithmetic dialect.
+        # We know the OpView for it defines an 'lhs' attribute.
+        addf = module.body.operations[2]
+        # CHECK: <mlir.dialects._arith_ops_gen.AddFOp object
+        print(repr(addf))
+        # CHECK: "custom.f32"()
+        print(addf.lhs)
 
-    # One of the custom ops should resolve to the default OpView.
-    custom = module.body.operations[0]
-    # CHECK: OpView object
-    print(repr(custom))
+        # One of the custom ops should resolve to the default OpView.
+        custom = module.body.operations[0]
+        # CHECK: OpView object
+        print(repr(custom))
 
-    # Check again to make sure negative caching works.
-    custom = module.body.operations[0]
-    # CHECK: OpView object
-    print(repr(custom))
+        # Check again to make sure negative caching works.
+        custom = module.body.operations[0]
+        # CHECK: OpView object
+        print(repr(custom))
 
 
 # CHECK-LABEL: TEST: testSingleResultProperty
 @run
 def testSingleResultProperty():
-  with Context(), Location.unknown():
-    Context.current.allow_unregistered_dialects = True
-    module = Module.parse(r"""
+    with Context(), Location.unknown():
+        Context.current.allow_unregistered_dialects = True
+        module = Module.parse(
+            r"""
       "custom.no_result"() : () -> ()
       %0:2 = "custom.two_result"() : () -> (f32, f32)
       %1 = "custom.one_result"() : () -> f32
-    """)
-    print(module)
+    """
+        )
+        print(module)
 
-  try:
-    module.body.operations[0].result
-  except ValueError as e:
-    # CHECK: Cannot call .result on operation custom.no_result which has 0 results
-    print(e)
-  else:
-    assert False, "Expected exception"
+    try:
+        module.body.operations[0].result
+    except ValueError as e:
+        # CHECK: Cannot call .result on operation custom.no_result which has 0 results
+        print(e)
+    else:
+        assert False, "Expected exception"
 
-  try:
-    module.body.operations[1].result
-  except ValueError as e:
-    # CHECK: Cannot call .result on operation custom.two_result which has 2 results
-    print(e)
-  else:
-    assert False, "Expected exception"
+    try:
+        module.body.operations[1].result
+    except ValueError as e:
+        # CHECK: Cannot call .result on operation custom.two_result which has 2 results
+        print(e)
+    else:
+        assert False, "Expected exception"
 
-  # CHECK: %1 = "custom.one_result"() : () -> f32
-  print(module.body.operations[2])
+    # CHECK: %1 = "custom.one_result"() : () -> f32
+    print(module.body.operations[2])
 
 
 def create_invalid_operation():
-  # This module has two region and is invalid verify that we fallback
-  # to the generic printer for safety.
-  op = Operation.create("builtin.module", regions=2)
-  op.regions[0].blocks.append()
-  return op
+    # This module has two regions and is invalid; verify that we fall back
+    # to the generic printer for safety.
+    op = Operation.create("builtin.module", regions=2)
+    op.regions[0].blocks.append()
+    return op
+
 
 # CHECK-LABEL: TEST: testInvalidOperationStrSoftFails
 @run
 def testInvalidOperationStrSoftFails():
-  ctx = Context()
-  with Location.unknown(ctx):
-    invalid_op = create_invalid_operation()
-    # Verify that we fallback to the generic printer for safety.
-    # CHECK: "builtin.module"() ({
-    # CHECK: }) : () -> ()
-    print(invalid_op)
-    try:
-      invalid_op.verify()
-    except MLIRError as e:
-      # CHECK: Exception: <
-      # CHECK:   Verification failed:
-      # CHECK:   error: unknown: 'builtin.module' op requires one region
-      # CHECK:    note: unknown: see current operation:
-      # CHECK:     "builtin.module"() ({
-      # CHECK:     ^bb0:
-      # CHECK:     }, {
-      # CHECK:     }) : () -> ()
-      # CHECK: >
-      print(f"Exception: <{e}>")
+    ctx = Context()
+    with Location.unknown(ctx):
+        invalid_op = create_invalid_operation()
+        # Verify that we fall back to the generic printer for safety.
+        # CHECK: "builtin.module"() ({
+        # CHECK: }) : () -> ()
+        print(invalid_op)
+        try:
+            invalid_op.verify()
+        except MLIRError as e:
+            # CHECK: Exception: <
+            # CHECK:   Verification failed:
+            # CHECK:   error: unknown: 'builtin.module' op requires one region
+            # CHECK:    note: unknown: see current operation:
+            # CHECK:     "builtin.module"() ({
+            # CHECK:     ^bb0:
+            # CHECK:     }, {
+            # CHECK:     }) : () -> ()
+            # CHECK: >
+            print(f"Exception: <{e}>")
 
 
 # CHECK-LABEL: TEST: testInvalidModuleStrSoftFails
 @run
 def testInvalidModuleStrSoftFails():
-  ctx = Context()
-  with Location.unknown(ctx):
-    module = Module.create()
-    with InsertionPoint(module.body):
-      invalid_op = create_invalid_operation()
-    # Verify that we fallback to the generic printer for safety.
-    # CHECK: "builtin.module"() ({
-    # CHECK: }) : () -> ()
-    print(module)
+    ctx = Context()
+    with Location.unknown(ctx):
+        module = Module.create()
+        with InsertionPoint(module.body):
+            invalid_op = create_invalid_operation()
+        # Verify that we fall back to the generic printer for safety.
+        # CHECK: "builtin.module"() ({
+        # CHECK: }) : () -> ()
+        print(module)
 
 
 # CHECK-LABEL: TEST: testInvalidOperationGetAsmBinarySoftFails
 @run
 def testInvalidOperationGetAsmBinarySoftFails():
-  ctx = Context()
-  with Location.unknown(ctx):
-    invalid_op = create_invalid_operation()
-    # Verify that we fallback to the generic printer for safety.
-    # CHECK: b'"builtin.module"() ({\n^bb0:\n}, {\n}) : () -> ()\n'
-    print(invalid_op.get_asm(binary=True))
+    ctx = Context()
+    with Location.unknown(ctx):
+        invalid_op = create_invalid_operation()
+        # Verify that we fall back to the generic printer for safety.
+        # CHECK: b'"builtin.module"() ({\n^bb0:\n}, {\n}) : () -> ()\n'
+        print(invalid_op.get_asm(binary=True))
 
 
 # CHECK-LABEL: TEST: testCreateWithInvalidAttributes
 @run
 def testCreateWithInvalidAttributes():
-  ctx = Context()
-  with Location.unknown(ctx):
-    try:
-      Operation.create(
-          "builtin.module", attributes={None: StringAttr.get("name")})
-    except Exception as e:
-      # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
-      print(e)
-    try:
-      Operation.create(
-          "builtin.module", attributes={42: StringAttr.get("name")})
-    except Exception as e:
-      # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
-      print(e)
-    try:
-      Operation.create("builtin.module", attributes={"some_key": ctx})
-    except Exception as e:
-      # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module"
-      print(e)
-    try:
-      Operation.create("builtin.module", attributes={"some_key": None})
-    except Exception as e:
-      # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module"
-      print(e)
+    ctx = Context()
+    with Location.unknown(ctx):
+        try:
+            Operation.create(
+                "builtin.module", attributes={None: StringAttr.get("name")}
+            )
+        except Exception as e:
+            # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
+            print(e)
+        try:
+            Operation.create("builtin.module", attributes={42: StringAttr.get("name")})
+        except Exception as e:
+            # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
+            print(e)
+        try:
+            Operation.create("builtin.module", attributes={"some_key": ctx})
+        except Exception as e:
+            # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module"
+            print(e)
+        try:
+            Operation.create("builtin.module", attributes={"some_key": None})
+        except Exception as e:
+            # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module"
+            print(e)
 
 
 # CHECK-LABEL: TEST: testOperationName
 @run
 def testOperationName():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  module = Module.parse(
-      r"""
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    module = Module.parse(
+        r"""
     %0 = "custom.op1"() : () -> f32
     %1 = "custom.op2"() : () -> i32
     %2 = "custom.op1"() : () -> f32
-  """, ctx)
+  """,
+        ctx,
+    )
 
-  # CHECK: custom.op1
-  # CHECK: custom.op2
-  # CHECK: custom.op1
-  for op in module.body.operations:
-    print(op.operation.name)
+    # CHECK: custom.op1
+    # CHECK: custom.op2
+    # CHECK: custom.op1
+    for op in module.body.operations:
+        print(op.operation.name)
 
 
 # CHECK-LABEL: TEST: testCapsuleConversions
 @run
 def testCapsuleConversions():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  with Location.unknown(ctx):
-    m = Operation.create("custom.op1").operation
-    m_capsule = m._CAPIPtr
-    assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule)
-    m2 = Operation._CAPICreate(m_capsule)
-    assert m2 is m
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        m = Operation.create("custom.op1").operation
+        m_capsule = m._CAPIPtr
+        assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule)
+        m2 = Operation._CAPICreate(m_capsule)
+        assert m2 is m
 
 
 # CHECK-LABEL: TEST: testOperationErase
 @run
 def testOperationErase():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  with Location.unknown(ctx):
-    m = Module.create()
-    with InsertionPoint(m.body):
-      op = Operation.create("custom.op1")
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        m = Module.create()
+        with InsertionPoint(m.body):
+            op = Operation.create("custom.op1")
 
-      # CHECK: "custom.op1"
-      print(m)
+            # CHECK: "custom.op1"
+            print(m)
 
-      op.operation.erase()
+            op.operation.erase()
 
-      # CHECK-NOT: "custom.op1"
-      print(m)
+            # CHECK-NOT: "custom.op1"
+            print(m)
 
-      # Ensure we can create another operation
-      Operation.create("custom.op2")
+            # Ensure we can create another operation
+            Operation.create("custom.op2")
 
 
 # CHECK-LABEL: TEST: testOperationClone
 @run
 def testOperationClone():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  with Location.unknown(ctx):
-    m = Module.create()
-    with InsertionPoint(m.body):
-      op = Operation.create("custom.op1")
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        m = Module.create()
+        with InsertionPoint(m.body):
+            op = Operation.create("custom.op1")
 
-      # CHECK: "custom.op1"
-      print(m)
+            # CHECK: "custom.op1"
+            print(m)
 
-      clone = op.operation.clone()
-      op.operation.erase()
+            clone = op.operation.clone()
+            op.operation.erase()
 
-      # CHECK: "custom.op1"
-      print(m)
+            # CHECK: "custom.op1"
+            print(m)
 
 
 # CHECK-LABEL: TEST: testOperationLoc
 @run
 def testOperationLoc():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  with ctx:
-    loc = Location.name("loc")
-    op = Operation.create("custom.op", loc=loc)
-    assert op.location == loc
-    assert op.operation.location == loc
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with ctx:
+        loc = Location.name("loc")
+        op = Operation.create("custom.op", loc=loc)
+        assert op.location == loc
+        assert op.operation.location == loc
 
 
 # CHECK-LABEL: TEST: testModuleMerge
 @run
 def testModuleMerge():
-  with Context():
-    m1 = Module.parse("func.func private @foo()")
-    m2 = Module.parse("""
+    with Context():
+        m1 = Module.parse("func.func private @foo()")
+        m2 = Module.parse(
+            """
       func.func private @bar()
       func.func private @qux()
-    """)
-    foo = m1.body.operations[0]
-    bar = m2.body.operations[0]
-    qux = m2.body.operations[1]
-    bar.move_before(foo)
-    qux.move_after(foo)
+    """
+        )
+        foo = m1.body.operations[0]
+        bar = m2.body.operations[0]
+        qux = m2.body.operations[1]
+        bar.move_before(foo)
+        qux.move_after(foo)
 
-    # CHECK: module
-    # CHECK: func private @bar
-    # CHECK: func private @foo
-    # CHECK: func private @qux
-    print(m1)
+        # CHECK: module
+        # CHECK: func private @bar
+        # CHECK: func private @foo
+        # CHECK: func private @qux
+        print(m1)
 
-    # CHECK: module {
-    # CHECK-NEXT: }
-    print(m2)
+        # CHECK: module {
+        # CHECK-NEXT: }
+        print(m2)
 
 
 # CHECK-LABEL: TEST: testAppendMoveFromAnotherBlock
 @run
 def testAppendMoveFromAnotherBlock():
-  with Context():
-    m1 = Module.parse("func.func private @foo()")
-    m2 = Module.parse("func.func private @bar()")
-    func = m1.body.operations[0]
-    m2.body.append(func)
+    with Context():
+        m1 = Module.parse("func.func private @foo()")
+        m2 = Module.parse("func.func private @bar()")
+        func = m1.body.operations[0]
+        m2.body.append(func)
 
-    # CHECK: module
-    # CHECK: func private @bar
-    # CHECK: func private @foo
+        # CHECK: module
+        # CHECK: func private @bar
+        # CHECK: func private @foo
 
-    print(m2)
-    # CHECK: module {
-    # CHECK-NEXT: }
-    print(m1)
+        print(m2)
+        # CHECK: module {
+        # CHECK-NEXT: }
+        print(m1)
 
 
 # CHECK-LABEL: TEST: testDetachFromParent
 @run
 def testDetachFromParent():
-  with Context():
-    m1 = Module.parse("func.func private @foo()")
-    func = m1.body.operations[0].detach_from_parent()
+    with Context():
+        m1 = Module.parse("func.func private @foo()")
+        func = m1.body.operations[0].detach_from_parent()
 
-    try:
-      func.detach_from_parent()
-    except ValueError as e:
-      if "has no parent" not in str(e):
-        raise
-    else:
-      assert False, "expected ValueError when detaching a detached operation"
+        try:
+            func.detach_from_parent()
+        except ValueError as e:
+            if "has no parent" not in str(e):
+                raise
+        else:
+            assert False, "expected ValueError when detaching a detached operation"
 
-    print(m1)
-    # CHECK-NOT: func private @foo
+        print(m1)
+        # CHECK-NOT: func private @foo
 
 
 # CHECK-LABEL: TEST: testOperationHash
 @run
 def testOperationHash():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  with ctx, Location.unknown():
-    op = Operation.create("custom.op1")
-    assert hash(op) == hash(op.operation)
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with ctx, Location.unknown():
+        op = Operation.create("custom.op1")
+        assert hash(op) == hash(op.operation)
 
 
 # CHECK-LABEL: TEST: testOperationParse
 @run
 def testOperationParse():
-  with Context() as ctx:
-    ctx.allow_unregistered_dialects = True
-
-    # Generic operation parsing.
-    m = Operation.parse('module {}')
-    o = Operation.parse('"test.foo"() : () -> ()')
-    assert isinstance(m, ModuleOp)
-    assert type(o) is OpView
-
-    # Parsing specific operation.
-    m = ModuleOp.parse('module {}')
-    assert isinstance(m, ModuleOp)
-    try:
-      ModuleOp.parse('"test.foo"() : () -> ()')
-    except MLIRError as e:
-      # CHECK: error: Expected a 'builtin.module' op, got: 'test.foo'
-      print(f"error: {e}")
-    else:
-      assert False, "expected error"
-
-    o = Operation.parse('"test.foo"() : () -> ()', source_name="my-source-string")
-    # CHECK: op_with_source_name: "test.foo"() : () -> () loc("my-source-string":1:1)
-    print(f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}")
+    with Context() as ctx:
+        ctx.allow_unregistered_dialects = True
+
+        # Generic operation parsing.
+        m = Operation.parse("module {}")
+        o = Operation.parse('"test.foo"() : () -> ()')
+        assert isinstance(m, ModuleOp)
+        assert type(o) is OpView
+
+        # Parsing specific operation.
+        m = ModuleOp.parse("module {}")
+        assert isinstance(m, ModuleOp)
+        try:
+            ModuleOp.parse('"test.foo"() : () -> ()')
+        except MLIRError as e:
+            # CHECK: error: Expected a 'builtin.module' op, got: 'test.foo'
+            print(f"error: {e}")
+        else:
+            assert False, "expected error"
+
+        o = Operation.parse('"test.foo"() : () -> ()', source_name="my-source-string")
+        # CHECK: op_with_source_name: "test.foo"() : () -> () loc("my-source-string":1:1)
+        print(
+            f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}"
+        )
index 9ce8959..17f3e35 100644 (file)
@@ -7,150 +7,162 @@ from mlir.ir import *
 
 
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  gc.collect()
-  assert Context._get_live_count() == 0
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
+    return f
 
 
 # CHECK-LABEL: TEST: testSymbolTableInsert
 @run
 def testSymbolTableInsert():
-  with Context() as ctx:
-    ctx.allow_unregistered_dialects = True
-    m1 = Module.parse("""
+    with Context() as ctx:
+        ctx.allow_unregistered_dialects = True
+        m1 = Module.parse(
+            """
       func.func private @foo()
-      func.func private @bar()""")
-    m2 = Module.parse("""
+      func.func private @bar()"""
+        )
+        m2 = Module.parse(
+            """
       func.func private @qux()
       func.func private @foo()
-      "foo.bar"() : () -> ()""")
-
-    symbol_table = SymbolTable(m1.operation)
-
-    # CHECK: func private @foo
-    # CHECK: func private @bar
-    assert "foo" in symbol_table
-    print(symbol_table["foo"])
-    assert "bar" in symbol_table
-    bar = symbol_table["bar"]
-    print(symbol_table["bar"])
-
-    assert "qux" not in symbol_table
-
-    del symbol_table["bar"]
-    try:
-      symbol_table.erase(symbol_table["bar"])
-    except KeyError:
-      pass
-    else:
-      assert False, "expected KeyError"
-
-    # CHECK: module
-    # CHECK:   func private @foo()
-    print(m1)
-    assert "bar" not in symbol_table
-
-    try:
-      print(bar)
-    except RuntimeError as e:
-      if "the operation has been invalidated" not in str(e):
-        raise
-    else:
-      assert False, "expected RuntimeError due to invalidated operation"
-
-    qux = m2.body.operations[0]
-    m1.body.append(qux)
-    symbol_table.insert(qux)
-    assert "qux" in symbol_table
-
-    # Check that insertion actually renames this symbol in the symbol table.
-    foo2 = m2.body.operations[0]
-    m1.body.append(foo2)
-    updated_name = symbol_table.insert(foo2)
-    assert foo2.name.value != "foo"
-    assert foo2.name == updated_name
-
-    # CHECK: module
-    # CHECK:   func private @foo()
-    # CHECK:   func private @qux()
-    # CHECK:   func private @foo{{.*}}
-    print(m1)
-
-    try:
-      symbol_table.insert(m2.body.operations[0])
-    except ValueError as e:
-      if "Expected operation to have a symbol name" not in str(e):
-        raise
-    else:
-      assert False, "exepcted ValueError when adding a non-symbol"
+      "foo.bar"() : () -> ()"""
+        )
+
+        symbol_table = SymbolTable(m1.operation)
+
+        # CHECK: func private @foo
+        # CHECK: func private @bar
+        assert "foo" in symbol_table
+        print(symbol_table["foo"])
+        assert "bar" in symbol_table
+        bar = symbol_table["bar"]
+        print(symbol_table["bar"])
+
+        assert "qux" not in symbol_table
+
+        del symbol_table["bar"]
+        try:
+            symbol_table.erase(symbol_table["bar"])
+        except KeyError:
+            pass
+        else:
+            assert False, "expected KeyError"
+
+        # CHECK: module
+        # CHECK:   func private @foo()
+        print(m1)
+        assert "bar" not in symbol_table
+
+        try:
+            print(bar)
+        except RuntimeError as e:
+            if "the operation has been invalidated" not in str(e):
+                raise
+        else:
+            assert False, "expected RuntimeError due to invalidated operation"
+
+        qux = m2.body.operations[0]
+        m1.body.append(qux)
+        symbol_table.insert(qux)
+        assert "qux" in symbol_table
+
+        # Check that insertion actually renames this symbol in the symbol table.
+        foo2 = m2.body.operations[0]
+        m1.body.append(foo2)
+        updated_name = symbol_table.insert(foo2)
+        assert foo2.name.value != "foo"
+        assert foo2.name == updated_name
+
+        # CHECK: module
+        # CHECK:   func private @foo()
+        # CHECK:   func private @qux()
+        # CHECK:   func private @foo{{.*}}
+        print(m1)
+
+        try:
+            symbol_table.insert(m2.body.operations[0])
+        except ValueError as e:
+            if "Expected operation to have a symbol name" not in str(e):
+                raise
+        else:
+            assert False, "exepcted ValueError when adding a non-symbol"
 
 
 # CHECK-LABEL: testSymbolTableRAUW
 @run
 def testSymbolTableRAUW():
-  with Context() as ctx:
-    m = Module.parse("""
+    with Context() as ctx:
+        m = Module.parse(
+            """
       func.func private @foo() {
         call @bar() : () -> ()
         return
       }
       func.func private @bar()
-      """)
-    foo, bar = list(m.operation.regions[0].blocks[0].operations)[0:2]
-    SymbolTable.set_symbol_name(bar, "bam")
-    # Note that module.operation counts as a "nested symbol table" which won't
-    # be traversed into, so it is necessary to traverse its children.
-    SymbolTable.replace_all_symbol_uses("bar", "bam", foo)
-    # CHECK: call @bam()
-    # CHECK: func private @bam
-    print(m)
-    # CHECK: Foo symbol: "foo"
-    # CHECK: Bar symbol: "bam"
-    print(f"Foo symbol: {SymbolTable.get_symbol_name(foo)}")
-    print(f"Bar symbol: {SymbolTable.get_symbol_name(bar)}")
+      """
+        )
+        foo, bar = list(m.operation.regions[0].blocks[0].operations)[0:2]
+        SymbolTable.set_symbol_name(bar, "bam")
+        # Note that module.operation counts as a "nested symbol table" which won't
+        # be traversed into, so it is necessary to traverse its children.
+        SymbolTable.replace_all_symbol_uses("bar", "bam", foo)
+        # CHECK: call @bam()
+        # CHECK: func private @bam
+        print(m)
+        # CHECK: Foo symbol: "foo"
+        # CHECK: Bar symbol: "bam"
+        print(f"Foo symbol: {SymbolTable.get_symbol_name(foo)}")
+        print(f"Bar symbol: {SymbolTable.get_symbol_name(bar)}")
 
 
 # CHECK-LABEL: testSymbolTableVisibility
 @run
 def testSymbolTableVisibility():
-  with Context() as ctx:
-    m = Module.parse("""
+    with Context() as ctx:
+        m = Module.parse(
+            """
       func.func private @foo() {
         return
       }
-      """)
-    foo = m.operation.regions[0].blocks[0].operations[0]
-    # CHECK: Existing visibility: "private"
-    print(f"Existing visibility: {SymbolTable.get_visibility(foo)}")
-    SymbolTable.set_visibility(foo, "public")
-    # CHECK: func public @foo
-    print(m)
+      """
+        )
+        foo = m.operation.regions[0].blocks[0].operations[0]
+        # CHECK: Existing visibility: "private"
+        print(f"Existing visibility: {SymbolTable.get_visibility(foo)}")
+        SymbolTable.set_visibility(foo, "public")
+        # CHECK: func public @foo
+        print(m)
 
 
 # CHECK: testWalkSymbolTables
 @run
 def testWalkSymbolTables():
-  with Context() as ctx:
-    m = Module.parse("""
+    with Context() as ctx:
+        m = Module.parse(
+            """
       module @outer {
         module @inner{
         }
       }
-      """)
-    def callback(symbol_table_op, uses_visible):
-      print(f"SYMBOL TABLE: {uses_visible}: {symbol_table_op}")
-    # CHECK: SYMBOL TABLE: True: module @inner
-    # CHECK: SYMBOL TABLE: True: module @outer
-    SymbolTable.walk_symbol_tables(m.operation, True, callback)
-
-    # Make sure exceptions in the callback are handled.
-    def error_callback(symbol_table_op, uses_visible):
-      assert False, "Raised from python"
-    try:
-      SymbolTable.walk_symbol_tables(m.operation, True, error_callback)
-    except RuntimeError as e:
-      # CHECK: GOT EXCEPTION: Exception raised in callback: AssertionError: Raised from python
-      print(f"GOT EXCEPTION: {e}")
+      """
+        )
 
+        def callback(symbol_table_op, uses_visible):
+            print(f"SYMBOL TABLE: {uses_visible}: {symbol_table_op}")
+
+        # CHECK: SYMBOL TABLE: True: module @inner
+        # CHECK: SYMBOL TABLE: True: module @outer
+        SymbolTable.walk_symbol_tables(m.operation, True, callback)
+
+        # Make sure exceptions in the callback are handled.
+        def error_callback(symbol_table_op, uses_visible):
+            assert False, "Raised from python"
+
+        try:
+            SymbolTable.walk_symbol_tables(m.operation, True, error_callback)
+        except RuntimeError as e:
+            # CHECK: GOT EXCEPTION: Exception raised in callback: AssertionError: Raised from python
+            print(f"GOT EXCEPTION: {e}")
index 66568c4..8a2ada1 100644 (file)
@@ -6,229 +6,235 @@ from mlir.dialects import func
 
 
 def run(f):
-  print("\nTEST:", f.__name__)
-  f()
-  gc.collect()
-  assert Context._get_live_count() == 0
-  return f
+    print("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
+    return f
 
 
 # CHECK-LABEL: TEST: testCapsuleConversions
 @run
 def testCapsuleConversions():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  with Location.unknown(ctx):
-    i32 = IntegerType.get_signless(32)
-    value = Operation.create("custom.op1", results=[i32]).result
-    value_capsule = value._CAPIPtr
-    assert '"mlir.ir.Value._CAPIPtr"' in repr(value_capsule)
-    value2 = Value._CAPICreate(value_capsule)
-    assert value2 == value
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        i32 = IntegerType.get_signless(32)
+        value = Operation.create("custom.op1", results=[i32]).result
+        value_capsule = value._CAPIPtr
+        assert '"mlir.ir.Value._CAPIPtr"' in repr(value_capsule)
+        value2 = Value._CAPICreate(value_capsule)
+        assert value2 == value
 
 
 # CHECK-LABEL: TEST: testOpResultOwner
 @run
 def testOpResultOwner():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  with Location.unknown(ctx):
-    i32 = IntegerType.get_signless(32)
-    op = Operation.create("custom.op1", results=[i32])
-    assert op.result.owner == op
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        i32 = IntegerType.get_signless(32)
+        op = Operation.create("custom.op1", results=[i32])
+        assert op.result.owner == op
 
 
 # CHECK-LABEL: TEST: testBlockArgOwner
 @run
 def testBlockArgOwner():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  module = Module.parse(
-      r"""
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    module = Module.parse(
+        r"""
     func.func @foo(%arg0: f32) {
       return
-    }""", ctx)
-  func = module.body.operations[0]
-  block = func.regions[0].blocks[0]
-  assert block.arguments[0].owner == block
+    }""",
+        ctx,
+    )
+    func = module.body.operations[0]
+    block = func.regions[0].blocks[0]
+    assert block.arguments[0].owner == block
 
 
 # CHECK-LABEL: TEST: testValueIsInstance
 @run
 def testValueIsInstance():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  module = Module.parse(
-      r"""
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    module = Module.parse(
+        r"""
     func.func @foo(%arg0: f32) {
       %0 = "some_dialect.some_op"() : () -> f64
       return
-    }""", ctx)
-  func = module.body.operations[0]
-  assert BlockArgument.isinstance(func.regions[0].blocks[0].arguments[0])
-  assert not OpResult.isinstance(func.regions[0].blocks[0].arguments[0])
+    }""",
+        ctx,
+    )
+    func = module.body.operations[0]
+    assert BlockArgument.isinstance(func.regions[0].blocks[0].arguments[0])
+    assert not OpResult.isinstance(func.regions[0].blocks[0].arguments[0])
 
-  op = func.regions[0].blocks[0].operations[0]
-  assert not BlockArgument.isinstance(op.results[0])
-  assert OpResult.isinstance(op.results[0])
+    op = func.regions[0].blocks[0].operations[0]
+    assert not BlockArgument.isinstance(op.results[0])
+    assert OpResult.isinstance(op.results[0])
 
 
 # CHECK-LABEL: TEST: testValueHash
 @run
 def testValueHash():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  module = Module.parse(
-      r"""
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    module = Module.parse(
+        r"""
     func.func @foo(%arg0: f32) -> f32 {
       %0 = "some_dialect.some_op"(%arg0) : (f32) -> f32
       return %0 : f32
-    }""", ctx)
+    }""",
+        ctx,
+    )
 
-  [func] = module.body.operations
-  block = func.entry_block
-  op, ret = block.operations
-  assert hash(block.arguments[0]) == hash(op.operands[0])
-  assert hash(op.result) == hash(ret.operands[0])
+    [func] = module.body.operations
+    block = func.entry_block
+    op, ret = block.operations
+    assert hash(block.arguments[0]) == hash(op.operands[0])
+    assert hash(op.result) == hash(ret.operands[0])
 
 
 # CHECK-LABEL: TEST: testValueUses
 @run
 def testValueUses():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  with Location.unknown(ctx):
-    i32 = IntegerType.get_signless(32)
-    module = Module.create()
-    with InsertionPoint(module.body):
-      value = Operation.create("custom.op1", results=[i32]).results[0]
-      op1 = Operation.create("custom.op2", operands=[value])
-      op2 = Operation.create("custom.op2", operands=[value])
-
-  # CHECK: Use owner: "custom.op2"
-  # CHECK: Use operand_number: 0
-  # CHECK: Use owner: "custom.op2"
-  # CHECK: Use operand_number: 0
-  for use in value.uses:
-    assert use.owner in [op1, op2]
-    print(f"Use owner: {use.owner}")
-    print(f"Use operand_number: {use.operand_number}")
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        i32 = IntegerType.get_signless(32)
+        module = Module.create()
+        with InsertionPoint(module.body):
+            value = Operation.create("custom.op1", results=[i32]).results[0]
+            op1 = Operation.create("custom.op2", operands=[value])
+            op2 = Operation.create("custom.op2", operands=[value])
+
+    # CHECK: Use owner: "custom.op2"
+    # CHECK: Use operand_number: 0
+    # CHECK: Use owner: "custom.op2"
+    # CHECK: Use operand_number: 0
+    for use in value.uses:
+        assert use.owner in [op1, op2]
+        print(f"Use owner: {use.owner}")
+        print(f"Use operand_number: {use.operand_number}")
 
 
 # CHECK-LABEL: TEST: testValueReplaceAllUsesWith
 @run
 def testValueReplaceAllUsesWith():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  with Location.unknown(ctx):
-    i32 = IntegerType.get_signless(32)
-    module = Module.create()
-    with InsertionPoint(module.body):
-      value = Operation.create("custom.op1", results=[i32]).results[0]
-      op1 = Operation.create("custom.op2", operands=[value])
-      op2 = Operation.create("custom.op2", operands=[value])
-      value2 = Operation.create("custom.op3", results=[i32]).results[0]
-      value.replace_all_uses_with(value2)
-
-  assert len(list(value.uses)) == 0
-
-  # CHECK: Use owner: "custom.op2"
-  # CHECK: Use operand_number: 0
-  # CHECK: Use owner: "custom.op2"
-  # CHECK: Use operand_number: 0
-  for use in value2.uses:
-    assert use.owner in [op1, op2]
-    print(f"Use owner: {use.owner}")
-    print(f"Use operand_number: {use.operand_number}")
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        i32 = IntegerType.get_signless(32)
+        module = Module.create()
+        with InsertionPoint(module.body):
+            value = Operation.create("custom.op1", results=[i32]).results[0]
+            op1 = Operation.create("custom.op2", operands=[value])
+            op2 = Operation.create("custom.op2", operands=[value])
+            value2 = Operation.create("custom.op3", results=[i32]).results[0]
+            value.replace_all_uses_with(value2)
+
+    assert len(list(value.uses)) == 0
+
+    # CHECK: Use owner: "custom.op2"
+    # CHECK: Use operand_number: 0
+    # CHECK: Use owner: "custom.op2"
+    # CHECK: Use operand_number: 0
+    for use in value2.uses:
+        assert use.owner in [op1, op2]
+        print(f"Use owner: {use.owner}")
+        print(f"Use operand_number: {use.operand_number}")
 
 
 # CHECK-LABEL: TEST: testValuePrintAsOperand
 @run
 def testValuePrintAsOperand():
-  ctx = Context()
-  ctx.allow_unregistered_dialects = True
-  with Location.unknown(ctx):
-    i32 = IntegerType.get_signless(32)
-    module = Module.create()
-    with InsertionPoint(module.body):
-      value = Operation.create("custom.op1", results=[i32]).results[0]
-      # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
-      print(value)
-
-      value2 = Operation.create("custom.op2", results=[i32]).results[0]
-      # CHECK: Value(%[[VAL2:.*]] = "custom.op2"() : () -> i32)
-      print(value2)
-
-      f = func.FuncOp("test", ([i32, i32], []))
-      entry_block1 = Block.create_at_start(f.operation.regions[0], [i32, i32])
-
-      with InsertionPoint(entry_block1):
-        value3 = Operation.create("custom.op3", results=[i32]).results[0]
-        # CHECK: Value(%[[VAL3:.*]] = "custom.op3"() : () -> i32)
-        print(value3)
-        value4 = Operation.create("custom.op4", results=[i32]).results[0]
-        # CHECK: Value(%[[VAL4:.*]] = "custom.op4"() : () -> i32)
-        print(value4)
-
-        f = func.FuncOp("test", ([i32, i32], []))
-        entry_block2 = Block.create_at_start(f.operation.regions[0], [i32, i32])
-        with InsertionPoint(entry_block2):
-          value5 = Operation.create("custom.op5", results=[i32]).results[0]
-          # CHECK: Value(%[[VAL5:.*]] = "custom.op5"() : () -> i32)
-          print(value5)
-          value6 = Operation.create("custom.op6", results=[i32]).results[0]
-          # CHECK: Value(%[[VAL6:.*]] = "custom.op6"() : () -> i32)
-          print(value6)
-
-          func.ReturnOp([])
-
-        func.ReturnOp([])
-
-    # CHECK: %[[VAL1]]
-    print(value.get_name())
-    # CHECK: %[[VAL2]]
-    print(value2.get_name())
-    # CHECK: %[[VAL3]]
-    print(value3.get_name())
-    # CHECK: %[[VAL4]]
-    print(value4.get_name())
-
-    # CHECK: %0
-    print(value3.get_name(use_local_scope=True))
-    # CHECK: %1
-    print(value4.get_name(use_local_scope=True))
-
-    # CHECK: %[[VAL5]]
-    print(value5.get_name())
-    # CHECK: %[[VAL6]]
-    print(value6.get_name())
-
-    # CHECK: %[[ARG0:.*]]
-    print(entry_block1.arguments[0].get_name())
-    # CHECK: %[[ARG1:.*]]
-    print(entry_block1.arguments[1].get_name())
-
-    # CHECK: %[[ARG2:.*]]
-    print(entry_block2.arguments[0].get_name())
-    # CHECK: %[[ARG3:.*]]
-    print(entry_block2.arguments[1].get_name())
-
-    # CHECK: module {
-    # CHECK:   %[[VAL1]] = "custom.op1"() : () -> i32
-    # CHECK:   %[[VAL2]] = "custom.op2"() : () -> i32
-    # CHECK:   func.func @test(%[[ARG0]]: i32, %[[ARG1]]: i32) {
-    # CHECK:     %[[VAL3]] = "custom.op3"() : () -> i32
-    # CHECK:     %[[VAL4]] = "custom.op4"() : () -> i32
-    # CHECK:     func @test(%[[ARG2]]: i32, %[[ARG3]]: i32) {
-    # CHECK:       %[[VAL5]] = "custom.op5"() : () -> i32
-    # CHECK:       %[[VAL6]] = "custom.op6"() : () -> i32
-    # CHECK:       return
-    # CHECK:     }
-    # CHECK:     return
-    # CHECK:   }
-    # CHECK: }
-    print(module)
-
-    value2.owner.detach_from_parent()
-    # CHECK: %0
-    print(value2.get_name())
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        i32 = IntegerType.get_signless(32)
+        module = Module.create()
+        with InsertionPoint(module.body):
+            value = Operation.create("custom.op1", results=[i32]).results[0]
+            # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
+            print(value)
+
+            value2 = Operation.create("custom.op2", results=[i32]).results[0]
+            # CHECK: Value(%[[VAL2:.*]] = "custom.op2"() : () -> i32)
+            print(value2)
+
+            f = func.FuncOp("test", ([i32, i32], []))
+            entry_block1 = Block.create_at_start(f.operation.regions[0], [i32, i32])
+
+            with InsertionPoint(entry_block1):
+                value3 = Operation.create("custom.op3", results=[i32]).results[0]
+                # CHECK: Value(%[[VAL3:.*]] = "custom.op3"() : () -> i32)
+                print(value3)
+                value4 = Operation.create("custom.op4", results=[i32]).results[0]
+                # CHECK: Value(%[[VAL4:.*]] = "custom.op4"() : () -> i32)
+                print(value4)
+
+                f = func.FuncOp("test", ([i32, i32], []))
+                entry_block2 = Block.create_at_start(f.operation.regions[0], [i32, i32])
+                with InsertionPoint(entry_block2):
+                    value5 = Operation.create("custom.op5", results=[i32]).results[0]
+                    # CHECK: Value(%[[VAL5:.*]] = "custom.op5"() : () -> i32)
+                    print(value5)
+                    value6 = Operation.create("custom.op6", results=[i32]).results[0]
+                    # CHECK: Value(%[[VAL6:.*]] = "custom.op6"() : () -> i32)
+                    print(value6)
+
+                    func.ReturnOp([])
+
+                func.ReturnOp([])
+
+        # CHECK: %[[VAL1]]
+        print(value.get_name())
+        # CHECK: %[[VAL2]]
+        print(value2.get_name())
+        # CHECK: %[[VAL3]]
+        print(value3.get_name())
+        # CHECK: %[[VAL4]]
+        print(value4.get_name())
+
+        # CHECK: %0
+        print(value3.get_name(use_local_scope=True))
+        # CHECK: %1
+        print(value4.get_name(use_local_scope=True))
+
+        # CHECK: %[[VAL5]]
+        print(value5.get_name())
+        # CHECK: %[[VAL6]]
+        print(value6.get_name())
+
+        # CHECK: %[[ARG0:.*]]
+        print(entry_block1.arguments[0].get_name())
+        # CHECK: %[[ARG1:.*]]
+        print(entry_block1.arguments[1].get_name())
+
+        # CHECK: %[[ARG2:.*]]
+        print(entry_block2.arguments[0].get_name())
+        # CHECK: %[[ARG3:.*]]
+        print(entry_block2.arguments[1].get_name())
+
+        # CHECK: module {
+        # CHECK:   %[[VAL1]] = "custom.op1"() : () -> i32
+        # CHECK:   %[[VAL2]] = "custom.op2"() : () -> i32
+        # CHECK:   func.func @test(%[[ARG0]]: i32, %[[ARG1]]: i32) {
+        # CHECK:     %[[VAL3]] = "custom.op3"() : () -> i32
+        # CHECK:     %[[VAL4]] = "custom.op4"() : () -> i32
+        # CHECK:     func @test(%[[ARG2]]: i32, %[[ARG3]]: i32) {
+        # CHECK:       %[[VAL5]] = "custom.op5"() : () -> i32
+        # CHECK:       %[[VAL6]] = "custom.op6"() : () -> i32
+        # CHECK:       return
+        # CHECK:     }
+        # CHECK:     return
+        # CHECK:   }
+        # CHECK: }
+        print(module)
+
+        value2.owner.detach_from_parent()
+        # CHECK: %0
+        print(value2.get_name())
index 8a98474..12d6e1f 100644 (file)
@@ -1,4 +1,4 @@
-config.environment['ASAN_OPTIONS'] = 'detect_leaks=0'
+config.environment["ASAN_OPTIONS"] = "detect_leaks=0"
 if not config.enable_bindings_python:
-  config.unsupported = True
-config.excludes.add('python_test_ops.td')
+    config.unsupported = True
+config.excludes.add("python_test_ops.td")
index 8b27653..4b3a02a 100644 (file)
@@ -8,122 +8,140 @@ from mlir.dialects.func import FuncOp
 # Log everything to stderr and flush so that we have a unified stream to match
 # errors/info emitted by MLIR to stderr.
 def log(*args):
-  print(*args, file=sys.stderr)
-  sys.stderr.flush()
+    print(*args, file=sys.stderr)
+    sys.stderr.flush()
+
 
 def run(f):
-  log("\nTEST:", f.__name__)
-  f()
-  gc.collect()
-  assert Context._get_live_count() == 0
+    log("\nTEST:", f.__name__)
+    f()
+    gc.collect()
+    assert Context._get_live_count() == 0
+
 
 # Verify capsule interop.
 # CHECK-LABEL: TEST: testCapsule
 def testCapsule():
-  with Context():
-    pm = PassManager()
-    pm_capsule = pm._CAPIPtr
-    assert '"mlir.passmanager.PassManager._CAPIPtr"' in repr(pm_capsule)
-    pm._testing_release()
-    pm1 = PassManager._CAPICreate(pm_capsule)
-    assert pm1 is not None  # And does not crash.
+    with Context():
+        pm = PassManager()
+        pm_capsule = pm._CAPIPtr
+        assert '"mlir.passmanager.PassManager._CAPIPtr"' in repr(pm_capsule)
+        pm._testing_release()
+        pm1 = PassManager._CAPICreate(pm_capsule)
+        assert pm1 is not None  # And does not crash.
+
+
 run(testCapsule)
 
 # CHECK-LABEL: TEST: testConstruct
 @run
 def testConstruct():
-  with Context():
-    # CHECK: pm1: 'any()'
-    # CHECK: pm2: 'builtin.module()'
-    pm1 = PassManager()
-    pm2 = PassManager("builtin.module")
-    log(f"pm1: '{pm1}'")
-    log(f"pm2: '{pm2}'")
+    with Context():
+        # CHECK: pm1: 'any()'
+        # CHECK: pm2: 'builtin.module()'
+        pm1 = PassManager()
+        pm2 = PassManager("builtin.module")
+        log(f"pm1: '{pm1}'")
+        log(f"pm2: '{pm2}'")
 
 
 # Verify successful round-trip.
 # CHECK-LABEL: TEST: testParseSuccess
 def testParseSuccess():
-  with Context():
-    # An unregistered pass should not parse.
-    try:
-      pm = PassManager.parse("builtin.module(func.func(not-existing-pass{json=false}))")
-    except ValueError as e:
-      # CHECK: ValueError exception: {{.+}} 'not-existing-pass' does not refer to a registered pass
-      log("ValueError exception:", e)
-    else:
-      log("Exception not produced")
-
-    # A registered pass should parse successfully.
-    pm = PassManager.parse("builtin.module(func.func(print-op-stats{json=false}))")
-    # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false}))
-    log("Roundtrip: ", pm)
+    with Context():
+        # An unregistered pass should not parse.
+        try:
+            pm = PassManager.parse(
+                "builtin.module(func.func(not-existing-pass{json=false}))"
+            )
+        except ValueError as e:
+            # CHECK: ValueError exception: {{.+}} 'not-existing-pass' does not refer to a registered pass
+            log("ValueError exception:", e)
+        else:
+            log("Exception not produced")
+
+        # A registered pass should parse successfully.
+        pm = PassManager.parse("builtin.module(func.func(print-op-stats{json=false}))")
+        # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false}))
+        log("Roundtrip: ", pm)
+
+
 run(testParseSuccess)
 
 # Verify successful round-trip.
 # CHECK-LABEL: TEST: testParseSpacedPipeline
 def testParseSpacedPipeline():
-  with Context():
-    # A registered pass should parse successfully even if has extras spaces for readability
-    pm = PassManager.parse("""builtin.module(
+    with Context():
+        # A registered pass should parse successfully even if it has extra spaces for readability
+        pm = PassManager.parse(
+            """builtin.module(
         func.func( print-op-stats{ json=false } )
-    )""")
-    # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false}))
-    log("Roundtrip: ", pm)
+    )"""
+        )
+        # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false}))
+        log("Roundtrip: ", pm)
+
+
 run(testParseSpacedPipeline)
 
 # Verify failure on unregistered pass.
 # CHECK-LABEL: TEST: testParseFail
 def testParseFail():
-  with Context():
-    try:
-      pm = PassManager.parse("any(unknown-pass)")
-    except ValueError as e:
-      #      CHECK: ValueError exception: MLIR Textual PassPipeline Parser:1:1: error:
-      # CHECK-SAME: 'unknown-pass' does not refer to a registered pass or pass pipeline
-      #      CHECK: unknown-pass
-      #      CHECK: ^
-      log("ValueError exception:", e)
-    else:
-      log("Exception not produced")
+    with Context():
+        try:
+            pm = PassManager.parse("any(unknown-pass)")
+        except ValueError as e:
+            #      CHECK: ValueError exception: MLIR Textual PassPipeline Parser:1:1: error:
+            # CHECK-SAME: 'unknown-pass' does not refer to a registered pass or pass pipeline
+            #      CHECK: unknown-pass
+            #      CHECK: ^
+            log("ValueError exception:", e)
+        else:
+            log("Exception not produced")
+
+
 run(testParseFail)
 
 # Check that adding to a pass manager works
 # CHECK-LABEL: TEST: testAdd
 @run
 def testAdd():
-  pm = PassManager("any", Context())
-  # CHECK: pm: 'any()'
-  log(f"pm: '{pm}'")
-  # CHECK: pm: 'any(cse)'
-  pm.add("cse")
-  log(f"pm: '{pm}'")
-  # CHECK: pm: 'any(cse,cse)'
-  pm.add("cse")
-  log(f"pm: '{pm}'")
+    pm = PassManager("any", Context())
+    # CHECK: pm: 'any()'
+    log(f"pm: '{pm}'")
+    # CHECK: pm: 'any(cse)'
+    pm.add("cse")
+    log(f"pm: '{pm}'")
+    # CHECK: pm: 'any(cse,cse)'
+    pm.add("cse")
+    log(f"pm: '{pm}'")
 
 
 # Verify failure on incorrect level of nesting.
 # CHECK-LABEL: TEST: testInvalidNesting
 def testInvalidNesting():
-  with Context():
-    try:
-      pm = PassManager.parse("func.func(normalize-memrefs)")
-    except ValueError as e:
-      # CHECK: ValueError exception: Can't add pass 'NormalizeMemRefs' restricted to 'builtin.module' on a PassManager intended to run on 'func.func', did you intend to nest?
-      log("ValueError exception:", e)
-    else:
-      log("Exception not produced")
+    with Context():
+        try:
+            pm = PassManager.parse("func.func(normalize-memrefs)")
+        except ValueError as e:
+            # CHECK: ValueError exception: Can't add pass 'NormalizeMemRefs' restricted to 'builtin.module' on a PassManager intended to run on 'func.func', did you intend to nest?
+            log("ValueError exception:", e)
+        else:
+            log("Exception not produced")
+
+
 run(testInvalidNesting)
 
 
 # Verify that a pass manager can execute on IR
 # CHECK-LABEL: TEST: testRunPipeline
 def testRunPipeline():
-  with Context():
-    pm = PassManager.parse("any(print-op-stats{json=false})")
-    func = FuncOp.parse(r"""func.func @successfulParse() { return }""")
-    pm.run(func)
+    with Context():
+        pm = PassManager.parse("any(print-op-stats{json=false})")
+        func = FuncOp.parse(r"""func.func @successfulParse() { return }""")
+        pm.run(func)
+
+
 # CHECK: Operations encountered:
 # CHECK: func.func      , 1
 # CHECK: func.return        , 1
@@ -132,16 +150,16 @@ run(testRunPipeline)
 # CHECK-LABEL: TEST: testRunPipelineError
 @run
 def testRunPipelineError():
-  with Context() as ctx:
-    ctx.allow_unregistered_dialects = True
-    op = Operation.parse('"test.op"() : () -> ()')
-    pm = PassManager.parse("any(cse)")
-    try:
-      pm.run(op)
-    except MLIRError as e:
-      # CHECK: Exception: <
-      # CHECK:   Failure while executing pass pipeline:
-      # CHECK:   error: "-":1:1: 'test.op' op trying to schedule a pass on an unregistered operation
-      # CHECK:    note: "-":1:1: see current operation: "test.op"() : () -> ()
-      # CHECK: >
-      print(f"Exception: <{e}>")
+    with Context() as ctx:
+        ctx.allow_unregistered_dialects = True
+        op = Operation.parse('"test.op"() : () -> ()')
+        pm = PassManager.parse("any(cse)")
+        try:
+            pm.run(op)
+        except MLIRError as e:
+            # CHECK: Exception: <
+            # CHECK:   Failure while executing pass pipeline:
+            # CHECK:   error: "-":1:1: 'test.op' op trying to schedule a pass on an unregistered operation
+            # CHECK:    note: "-":1:1: see current operation: "test.op"() : () -> ()
+            # CHECK: >
+            print(f"Exception: <{e}>")
index 25d08c7..aa35dbf 100644 (file)
@@ -1 +1 @@
-config.excludes = ['include']
+config.excludes = ["include"]
index 85a1a14..9ea8bdb 100644 (file)
@@ -4,216 +4,223 @@ import gdb.printing
 
 
 class StoragePrinter:
-  """Prints bases of a struct and its fields."""
+    """Prints bases of a struct and its fields."""
 
-  def __init__(self, val):
-    self.val = val
+    def __init__(self, val):
+        self.val = val
 
-  def children(self):
-    for field in self.val.type.fields():
-      if field.is_base_class:
-        yield '<%s>' % field.name, self.val.cast(field.type)
-      else:
-        yield field.name, self.val[field.name]
+    def children(self):
+        for field in self.val.type.fields():
+            if field.is_base_class:
+                yield "<%s>" % field.name, self.val.cast(field.type)
+            else:
+                yield field.name, self.val[field.name]
+
+    def to_string(self):
+        return "mlir::Storage"
 
-  def to_string(self):
-    return 'mlir::Storage'
 
 class TupleTypeStoragePrinter(StoragePrinter):
+    def children(self):
+        for child in StoragePrinter.children(self):
+            yield child
+        pointer_type = gdb.lookup_type("mlir::Type").pointer()
+        elements = (self.val.address + 1).cast(pointer_type)
+        for i in range(self.val["numElements"]):
+            yield "elements[%u]" % i, elements[i]
 
-  def children(self):
-    for child in StoragePrinter.children(self):
-      yield child
-    pointer_type = gdb.lookup_type('mlir::Type').pointer()
-    elements = (self.val.address + 1).cast(pointer_type)
-    for i in range(self.val['numElements']):
-      yield 'elements[%u]' % i, elements[i]
+    def to_string(self):
+        return "mlir::TupleTypeStorage of %u elements" % self.val["numElements"]
 
-  def to_string(self):
-    return 'mlir::TupleTypeStorage of %u elements' % self.val['numElements']
 
 class FusedLocationStoragePrinter(StoragePrinter):
+    def children(self):
+        for child in StoragePrinter.children(self):
+            yield child
+        pointer_type = gdb.lookup_type("mlir::Location").pointer()
+        elements = (self.val.address + 1).cast(pointer_type)
+        for i in range(self.val["numLocs"]):
+            yield "locs[%u]" % i, elements[i]
 
-  def children(self):
-    for child in StoragePrinter.children(self):
-      yield child
-    pointer_type = gdb.lookup_type('mlir::Location').pointer()
-    elements = (self.val.address + 1).cast(pointer_type)
-    for i in range(self.val['numLocs']):
-      yield 'locs[%u]' % i, elements[i]
-
-  def to_string(self):
-    return 'mlir::FusedLocationStorage of %u locs' % self.val['numLocs']
+    def to_string(self):
+        return "mlir::FusedLocationStorage of %u locs" % self.val["numLocs"]
 
 
 class StorageTypeMap:
-  """Maps a TypeID to the corresponding concrete type.
-
-  Types need to be registered by name before the first lookup.
-  """
-
-  def __init__(self):
-    self.map = None
-    self.type_names = []
-
-  def register_type(self, type_name):
-    assert not self.map, 'register_type called after __getitem__'
-    self.type_names += [type_name]
-
-  def _init_map(self):
-    """Lazy initialization  of self.map."""
-    if self.map:
-      return
-    self.map = {}
-    for type_name in self.type_names:
-      concrete_type = gdb.lookup_type(type_name)
-      try:
-        storage = gdb.parse_and_eval(
-            "&'mlir::detail::TypeIDExported::get<%s>()::instance'" % type_name)
-      except gdb.error:
-        # Skip when TypeID instance cannot be found in current context.
-        continue
-      if concrete_type and storage:
-        self.map[int(storage)] = concrete_type
-
-  def __getitem__(self, type_id):
-    self._init_map()
-    return self.map.get(int(type_id['storage']))
+    """Maps a TypeID to the corresponding concrete type.
+
+    Types need to be registered by name before the first lookup.
+    """
+
+    def __init__(self):
+        self.map = None
+        self.type_names = []
+
+    def register_type(self, type_name):
+        assert not self.map, "register_type called after __getitem__"
+        self.type_names += [type_name]
+
+    def _init_map(self):
+        """Lazy initialization  of self.map."""
+        if self.map:
+            return
+        self.map = {}
+        for type_name in self.type_names:
+            concrete_type = gdb.lookup_type(type_name)
+            try:
+                storage = gdb.parse_and_eval(
+                    "&'mlir::detail::TypeIDExported::get<%s>()::instance'" % type_name
+                )
+            except gdb.error:
+                # Skip when TypeID instance cannot be found in current context.
+                continue
+            if concrete_type and storage:
+                self.map[int(storage)] = concrete_type
+
+    def __getitem__(self, type_id):
+        self._init_map()
+        return self.map.get(int(type_id["storage"]))
 
 
 storage_type_map = StorageTypeMap()
 
 
 def get_type_id_printer(val):
-  """Returns a printer of the name of a mlir::TypeID."""
-
-  class TypeIdPrinter:
+    """Returns a printer of the name of a mlir::TypeID."""
 
-    def __init__(self, string):
-      self.string = string
+    class TypeIdPrinter:
+        def __init__(self, string):
+            self.string = string
 
-    def to_string(self):
-      return self.string
+        def to_string(self):
+            return self.string
 
-  concrete_type = storage_type_map[val]
-  if not concrete_type:
-    return None
-  return TypeIdPrinter('mlir::TypeID::get<%s>()' % concrete_type)
+    concrete_type = storage_type_map[val]
+    if not concrete_type:
+        return None
+    return TypeIdPrinter("mlir::TypeID::get<%s>()" % concrete_type)
 
 
 def get_attr_or_type_printer(val, get_type_id):
-  """Returns a printer for mlir::Attribute or mlir::Type."""
-
-  class AttrOrTypePrinter:
-
-    def __init__(self, type_id, impl):
-      self.type_id = type_id
-      self.impl = impl
-
-    def children(self):
-      yield 'typeID', self.type_id
-      yield 'impl', self.impl
-
-    def to_string(self):
-      return 'cast<%s>' % self.impl.type
-
-  if not val['impl']:
-    return None
-  impl = val['impl'].dereference()
-  type_id = get_type_id(impl)
-  concrete_type = storage_type_map[type_id]
-  if not concrete_type:
-    return None
-  # 3rd template argument of StorageUserBase is the storage type.
-  storage_type = concrete_type.fields()[0].type.template_argument(2)
-  if not storage_type:
-    return None
-  return AttrOrTypePrinter(type_id, impl.cast(storage_type))
+    """Returns a printer for mlir::Attribute or mlir::Type."""
+
+    class AttrOrTypePrinter:
+        def __init__(self, type_id, impl):
+            self.type_id = type_id
+            self.impl = impl
+
+        def children(self):
+            yield "typeID", self.type_id
+            yield "impl", self.impl
+
+        def to_string(self):
+            return "cast<%s>" % self.impl.type
+
+    if not val["impl"]:
+        return None
+    impl = val["impl"].dereference()
+    type_id = get_type_id(impl)
+    concrete_type = storage_type_map[type_id]
+    if not concrete_type:
+        return None
+    # 3rd template argument of StorageUserBase is the storage type.
+    storage_type = concrete_type.fields()[0].type.template_argument(2)
+    if not storage_type:
+        return None
+    return AttrOrTypePrinter(type_id, impl.cast(storage_type))
 
 
 class ImplPrinter:
-  """Printer for an instance with a single 'impl' member pointer."""
+    """Printer for an instance with a single 'impl' member pointer."""
 
-  def __init__(self, val):
-    self.val = val
-    self.impl = val['impl']
+    def __init__(self, val):
+        self.val = val
+        self.impl = val["impl"]
 
-  def children(self):
-    if self.impl:
-      yield 'impl', self.impl.dereference()
+    def children(self):
+        if self.impl:
+            yield "impl", self.impl.dereference()
 
-  def to_string(self):
-    return self.val.type.name
+    def to_string(self):
+        return self.val.type.name
 
 
 # Printers of types deriving from Attribute::AttrBase or Type::TypeBase.
 for name in [
     # mlir/IR/Attributes.h
-    'ArrayAttr',
-    'DictionaryAttr',
-    'FloatAttr',
-    'IntegerAttr',
-    'IntegerSetAttr',
-    'OpaqueAttr',
-    'StringAttr',
-    'SymbolRefAttr',
-    'TypeAttr',
-    'UnitAttr',
-    'DenseStringElementsAttr',
-    'DenseIntOrFPElementsAttr',
-    'SparseElementsAttr',
+    "ArrayAttr",
+    "DictionaryAttr",
+    "FloatAttr",
+    "IntegerAttr",
+    "IntegerSetAttr",
+    "OpaqueAttr",
+    "StringAttr",
+    "SymbolRefAttr",
+    "TypeAttr",
+    "UnitAttr",
+    "DenseStringElementsAttr",
+    "DenseIntOrFPElementsAttr",
+    "SparseElementsAttr",
     # mlir/IR/BuiltinTypes.h
-    'ComplexType',
-    'IndexType',
-    'IntegerType',
-    'Float16Type',
-    'Float32Type',
-    'Float64Type',
-    'Float80Type',
-    'Float128Type',
-    'NoneType',
-    'VectorType',
-    'RankedTensorType',
-    'UnrankedTensorType',
-    'MemRefType',
-    'UnrankedMemRefType',
-    'TupleType',
+    "ComplexType",
+    "IndexType",
+    "IntegerType",
+    "Float16Type",
+    "Float32Type",
+    "Float64Type",
+    "Float80Type",
+    "Float128Type",
+    "NoneType",
+    "VectorType",
+    "RankedTensorType",
+    "UnrankedTensorType",
+    "MemRefType",
+    "UnrankedMemRefType",
+    "TupleType",
     # mlir/IR/Location.h
-    'CallSiteLoc',
-    'FileLineColLoc',
-    'FusedLoc',
-    'NameLoc',
-    'OpaqueLoc',
-    'UnknownLoc'
+    "CallSiteLoc",
+    "FileLineColLoc",
+    "FusedLoc",
+    "NameLoc",
+    "OpaqueLoc",
+    "UnknownLoc",
 ]:
-  storage_type_map.register_type('mlir::%s' % name)  # Register for upcasting.
-storage_type_map.register_type('void')  # Register default.
+    storage_type_map.register_type("mlir::%s" % name)  # Register for upcasting.
+storage_type_map.register_type("void")  # Register default.
 
 
-pp = gdb.printing.RegexpCollectionPrettyPrinter('MLIRSupport')
+pp = gdb.printing.RegexpCollectionPrettyPrinter("MLIRSupport")
 
-pp.add_printer('mlir::OperationName', '^mlir::OperationName$', ImplPrinter)
-pp.add_printer('mlir::Value', '^mlir::Value$', ImplPrinter)
+pp.add_printer("mlir::OperationName", "^mlir::OperationName$", ImplPrinter)
+pp.add_printer("mlir::Value", "^mlir::Value$", ImplPrinter)
 
 # Printers for types deriving from AttributeStorage or TypeStorage.
-pp.add_printer('mlir::detail::FusedLocationStorage',
-               '^mlir::detail::FusedLocationStorage',
-               FusedLocationStoragePrinter)
-pp.add_printer('mlir::detail::TupleTypeStorage',
-               '^mlir::detail::TupleTypeStorage$', TupleTypeStoragePrinter)
+pp.add_printer(
+    "mlir::detail::FusedLocationStorage",
+    "^mlir::detail::FusedLocationStorage",
+    FusedLocationStoragePrinter,
+)
+pp.add_printer(
+    "mlir::detail::TupleTypeStorage",
+    "^mlir::detail::TupleTypeStorage$",
+    TupleTypeStoragePrinter,
+)
 
-pp.add_printer('mlir::TypeID', '^mlir::TypeID$', get_type_id_printer)
+pp.add_printer("mlir::TypeID", "^mlir::TypeID$", get_type_id_printer)
 
 
 def add_attr_or_type_printers(name):
-  """Adds printers for mlir::Attribute or mlir::Type and their Storage type."""
-  get_type_id = lambda val: val['abstract%s' % name]['typeID']
-  pp.add_printer('mlir::%s' % name, '^mlir::%s$' % name,
-                 lambda val: get_attr_or_type_printer(val, get_type_id))
+    """Adds printers for mlir::Attribute or mlir::Type and their Storage type."""
+    get_type_id = lambda val: val["abstract%s" % name]["typeID"]
+    pp.add_printer(
+        "mlir::%s" % name,
+        "^mlir::%s$" % name,
+        lambda val: get_attr_or_type_printer(val, get_type_id),
+    )
 
 
 # Upcasting printers of mlir::Attribute and mlir::Type.
-for name in ['Attribute', 'Type']:
-  add_attr_or_type_printers(name)
+for name in ["Attribute", "Type"]:
+    add_attr_or_type_printers(name)
 
 gdb.printing.register_pretty_printer(gdb.current_objfile(), pp)
index 474f812..0210d7a 100755 (executable)
@@ -32,7 +32,7 @@ import os  # Used to advertise this file's name ("autogenerated_note").
 import re
 import sys
 
-ADVERT_BEGIN = '// NOTE: Assertions have been autogenerated by '
+ADVERT_BEGIN = "// NOTE: Assertions have been autogenerated by "
 ADVERT_END = """
 // The script is designed to make adding checks to
 // a test case fast, it is *not* designed to be authoritative
@@ -42,250 +42,249 @@ ADVERT_END = """
 
 
 # Regex command to match an SSA identifier.
-SSA_RE_STR = '[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*'
+SSA_RE_STR = "[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*"
 SSA_RE = re.compile(SSA_RE_STR)
 
 
 # Class used to generate and manage string substitution blocks for SSA value
 # names.
 class SSAVariableNamer:
+    def __init__(self):
+        self.scopes = []
+        self.name_counter = 0
 
-  def __init__(self):
-    self.scopes = []
-    self.name_counter = 0
+    # Generate a substitution name for the given ssa value name.
+    def generate_name(self, ssa_name):
+        variable = "VAL_" + str(self.name_counter)
+        self.name_counter += 1
+        self.scopes[-1][ssa_name] = variable
+        return variable
 
-  # Generate a substitution name for the given ssa value name.
-  def generate_name(self, ssa_name):
-    variable = 'VAL_' + str(self.name_counter)
-    self.name_counter += 1
-    self.scopes[-1][ssa_name] = variable
-    return variable
+    # Push a new variable name scope.
+    def push_name_scope(self):
+        self.scopes.append({})
 
-  # Push a new variable name scope.
-  def push_name_scope(self):
-    self.scopes.append({})
+    # Pop the last variable name scope.
+    def pop_name_scope(self):
+        self.scopes.pop()
 
-  # Pop the last variable name scope.
-  def pop_name_scope(self):
-    self.scopes.pop()
+    # Return the level of nesting (number of pushed scopes).
+    def num_scopes(self):
+        return len(self.scopes)
 
-  # Return the level of nesting (number of pushed scopes).
-  def num_scopes(self):
-    return len(self.scopes)
-
-  # Reset the counter.
-  def clear_counter(self):
-    self.name_counter = 0
+    # Reset the counter.
+    def clear_counter(self):
+        self.name_counter = 0
 
 
 # Process a line of input that has been split at each SSA identifier '%'.
 def process_line(line_chunks, variable_namer):
-  output_line = ''
-
-  # Process the rest that contained an SSA value name.
-  for chunk in line_chunks:
-    m = SSA_RE.match(chunk)
-    ssa_name = m.group(0)
-
-    # Check if an existing variable exists for this name.
-    variable = None
-    for scope in variable_namer.scopes:
-      variable = scope.get(ssa_name)
-      if variable is not None:
-        break
-
-    # If one exists, then output the existing name.
-    if variable is not None:
-      output_line += '%[[' + variable + ']]'
-    else:
-      # Otherwise, generate a new variable.
-      variable = variable_namer.generate_name(ssa_name)
-      output_line += '%[[' + variable + ':.*]]'
+    output_line = ""
+
+    # Process the rest that contained an SSA value name.
+    for chunk in line_chunks:
+        m = SSA_RE.match(chunk)
+        ssa_name = m.group(0)
+
+        # Check if an existing variable exists for this name.
+        variable = None
+        for scope in variable_namer.scopes:
+            variable = scope.get(ssa_name)
+            if variable is not None:
+                break
 
-    # Append the non named group.
-    output_line += chunk[len(ssa_name):]
+        # If one exists, then output the existing name.
+        if variable is not None:
+            output_line += "%[[" + variable + "]]"
+        else:
+            # Otherwise, generate a new variable.
+            variable = variable_namer.generate_name(ssa_name)
+            output_line += "%[[" + variable + ":.*]]"
 
-  return output_line.rstrip() + '\n'
+        # Append the non named group.
+        output_line += chunk[len(ssa_name) :]
+
+    return output_line.rstrip() + "\n"
 
 
 # Process the source file lines. The source file doesn't have to be .mlir.
 def process_source_lines(source_lines, note, args):
-  source_split_re = re.compile(args.source_delim_regex)
+    source_split_re = re.compile(args.source_delim_regex)
 
-  source_segments = [[]]
-  for line in source_lines:
-    # Remove previous note.
-    if line == note:
-      continue
-    # Remove previous CHECK lines.
-    if line.find(args.check_prefix) != -1:
-      continue
-    # Segment the file based on --source_delim_regex.
-    if source_split_re.search(line):
-      source_segments.append([])
+    source_segments = [[]]
+    for line in source_lines:
+        # Remove previous note.
+        if line == note:
+            continue
+        # Remove previous CHECK lines.
+        if line.find(args.check_prefix) != -1:
+            continue
+        # Segment the file based on --source_delim_regex.
+        if source_split_re.search(line):
+            source_segments.append([])
 
-    source_segments[-1].append(line + '\n')
-  return source_segments
+        source_segments[-1].append(line + "\n")
+    return source_segments
 
 
 # Pre-process a line of input to remove any character sequences that will be
 # problematic with FileCheck.
 def preprocess_line(line):
-  # Replace any double brackets, '[[' with escaped replacements. '[['
-  # corresponds to variable names in FileCheck.
-  output_line = line.replace('[[', '{{\\[\\[}}')
+    # Replace any double brackets, '[[' with escaped replacements. '[['
+    # corresponds to variable names in FileCheck.
+    output_line = line.replace("[[", "{{\\[\\[}}")
 
-  # Replace any single brackets that are followed by an SSA identifier; the
-  # identifier will be replaced by a variable, creating the same situation as
-  # above.
-  output_line = output_line.replace('[%', '{{\\[}}%')
+    # Replace any single brackets that are followed by an SSA identifier; the
+    # identifier will be replaced by a variable, creating the same situation as
+    # above.
+    output_line = output_line.replace("[%", "{{\\[}}%")
 
-  return output_line
+    return output_line
 
 
 def main():
-  parser = argparse.ArgumentParser(
-      description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
-  parser.add_argument(
-      '--check-prefix', default='CHECK', help='Prefix to use from check file.')
-  parser.add_argument(
-      '-o',
-      '--output',
-      nargs='?',
-      type=argparse.FileType('w'),
-      default=None)
-  parser.add_argument(
-      'input',
-      nargs='?',
-      type=argparse.FileType('r'),
-      default=sys.stdin)
-  parser.add_argument(
-      '--source', type=str,
-      help='Print each CHECK chunk before each delimiter line in the source '
-           'file, respectively. The delimiter lines are identified by '
-           '--source_delim_regex.')
-  parser.add_argument('--source_delim_regex', type=str, default='func @')
-  parser.add_argument(
-      '--starts_from_scope', type=int, default=1,
-      help='Omit the top specified level of content. For example, by default '
-           'it omits "module {"')
-  parser.add_argument('-i', '--inplace', action='store_true', default=False)
-
-  args = parser.parse_args()
-
-  # Open the given input file.
-  input_lines = [l.rstrip() for l in args.input]
-  args.input.close()
-
-  # Generate a note used for the generated check file.
-  script_name = os.path.basename(__file__)
-  autogenerated_note = (ADVERT_BEGIN + 'utils/' + script_name + "\n" + ADVERT_END)
-
-  source_segments = None
-  if args.source:
-    source_segments = process_source_lines(
-        [l.rstrip() for l in open(args.source, 'r')],
-        autogenerated_note,
-        args
+    parser = argparse.ArgumentParser(
+        description=__doc__, formatter_class=argparse.RawTextHelpFormatter
+    )
+    parser.add_argument(
+        "--check-prefix", default="CHECK", help="Prefix to use from check file."
+    )
+    parser.add_argument(
+        "-o", "--output", nargs="?", type=argparse.FileType("w"), default=None
+    )
+    parser.add_argument(
+        "input", nargs="?", type=argparse.FileType("r"), default=sys.stdin
     )
+    parser.add_argument(
+        "--source",
+        type=str,
+        help="Print each CHECK chunk before each delimeter line in the source"
+        "file, respectively. The delimeter lines are identified by "
+        "--source_delim_regex.",
+    )
+    parser.add_argument("--source_delim_regex", type=str, default="func @")
+    parser.add_argument(
+        "--starts_from_scope",
+        type=int,
+        default=1,
+        help="Omit the top specified level of content. For example, by default "
+        'it omits "module {"',
+    )
+    parser.add_argument("-i", "--inplace", action="store_true", default=False)
+
+    args = parser.parse_args()
+
+    # Open the given input file.
+    input_lines = [l.rstrip() for l in args.input]
+    args.input.close()
 
-  if args.inplace:
-    assert args.output is None
-    output = open(args.source, 'w')
-  elif args.output is None:
-    output = sys.stdout
-  else:
-    output = args.output
-
-  output_segments = [[]]
-  # A map containing data used for naming SSA value names.
-  variable_namer = SSAVariableNamer()
-  for input_line in input_lines:
-    if not input_line:
-      continue
-    lstripped_input_line = input_line.lstrip()
-
-    # Lines with blocks begin with a ^. These lines have a trailing comment
-    # that needs to be stripped.
-    is_block = lstripped_input_line[0] == '^'
-    if is_block:
-      input_line = input_line.rsplit('//', 1)[0].rstrip()
-
-    cur_level = variable_namer.num_scopes()
-
-    # If the line starts with a '}', pop the last name scope.
-    if lstripped_input_line[0] == '}':
-      variable_namer.pop_name_scope()
-      cur_level = variable_namer.num_scopes()
-
-    # If the line ends with a '{', push a new name scope.
-    if input_line[-1] == '{':
-      variable_namer.push_name_scope()
-      if cur_level == args.starts_from_scope:
-        output_segments.append([])
-
-    # Omit lines near the top level, e.g. "module {".
-    if cur_level < args.starts_from_scope:
-      continue
-
-    if len(output_segments[-1]) == 0:
-      variable_namer.clear_counter()
-
-    # Preprocess the input to remove any sequences that may be problematic with
-    # FileCheck.
-    input_line = preprocess_line(input_line)
-
-    # Split the line at each SSA value name.
-    ssa_split = input_line.split('%')
-
-    # If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'.
-    if len(output_segments[-1]) != 0 or not ssa_split[0]:
-      output_line = '// ' + args.check_prefix + ': '
-      # Pad to align with the 'LABEL' statements.
-      output_line += (' ' * len('-LABEL'))
-
-      # Output the first line chunk that does not contain an SSA name.
-      output_line += ssa_split[0]
-
-      # Process the rest of the input line.
-      output_line += process_line(ssa_split[1:], variable_namer)
+    # Generate a note used for the generated check file.
+    script_name = os.path.basename(__file__)
+    autogenerated_note = ADVERT_BEGIN + "utils/" + script_name + "\n" + ADVERT_END
 
+    source_segments = None
+    if args.source:
+        source_segments = process_source_lines(
+            [l.rstrip() for l in open(args.source, "r")], autogenerated_note, args
+        )
+
+    if args.inplace:
+        assert args.output is None
+        output = open(args.source, "w")
+    elif args.output is None:
+        output = sys.stdout
+    else:
+        output = args.output
+
+    output_segments = [[]]
+    # A map containing data used for naming SSA value names.
+    variable_namer = SSAVariableNamer()
+    for input_line in input_lines:
+        if not input_line:
+            continue
+        lstripped_input_line = input_line.lstrip()
+
+        # Lines with blocks begin with a ^. These lines have a trailing comment
+        # that needs to be stripped.
+        is_block = lstripped_input_line[0] == "^"
+        if is_block:
+            input_line = input_line.rsplit("//", 1)[0].rstrip()
+
+        cur_level = variable_namer.num_scopes()
+
+        # If the line starts with a '}', pop the last name scope.
+        if lstripped_input_line[0] == "}":
+            variable_namer.pop_name_scope()
+            cur_level = variable_namer.num_scopes()
+
+        # If the line ends with a '{', push a new name scope.
+        if input_line[-1] == "{":
+            variable_namer.push_name_scope()
+            if cur_level == args.starts_from_scope:
+                output_segments.append([])
+
+        # Omit lines near the top level, e.g. "module {".
+        if cur_level < args.starts_from_scope:
+            continue
+
+        if len(output_segments[-1]) == 0:
+            variable_namer.clear_counter()
+
+        # Preprocess the input to remove any sequences that may be problematic with
+        # FileCheck.
+        input_line = preprocess_line(input_line)
+
+        # Split the line at each SSA value name.
+        ssa_split = input_line.split("%")
+
+        # If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'.
+        if len(output_segments[-1]) != 0 or not ssa_split[0]:
+            output_line = "// " + args.check_prefix + ": "
+            # Pad to align with the 'LABEL' statements.
+            output_line += " " * len("-LABEL")
+
+            # Output the first line chunk that does not contain an SSA name.
+            output_line += ssa_split[0]
+
+            # Process the rest of the input line.
+            output_line += process_line(ssa_split[1:], variable_namer)
+
+        else:
+            # Output the first line chunk that does not contain an SSA name for the
+            # label.
+            output_line = "// " + args.check_prefix + "-LABEL: " + ssa_split[0] + "\n"
+
+            # Process the rest of the input line on separate check lines.
+            for argument in ssa_split[1:]:
+                output_line += "// " + args.check_prefix + "-SAME:  "
+
+                # Pad to align with the original position in the line.
+                output_line += " " * len(ssa_split[0])
+
+                # Process the rest of the line.
+                output_line += process_line([argument], variable_namer)
+
+        # Append the output line.
+        output_segments[-1].append(output_line)
+
+    output.write(autogenerated_note + "\n")
+
+    # Write the output.
+    if source_segments:
+        assert len(output_segments) == len(source_segments)
+        for check_segment, source_segment in zip(output_segments, source_segments):
+            for line in check_segment:
+                output.write(line)
+            for line in source_segment:
+                output.write(line)
     else:
-      # Output the first line chunk that does not contain an SSA name for the
-      # label.
-      output_line = '// ' + args.check_prefix + '-LABEL: ' + ssa_split[0] + '\n'
-
-      # Process the rest of the input line on separate check lines.
-      for argument in ssa_split[1:]:
-        output_line += '// ' + args.check_prefix + '-SAME:  '
-
-        # Pad to align with the original position in the line.
-        output_line += ' ' * len(ssa_split[0])
-
-        # Process the rest of the line.
-        output_line += process_line([argument], variable_namer)
-
-    # Append the output line.
-    output_segments[-1].append(output_line)
-
-  output.write(autogenerated_note + '\n')
-
-  # Write the output.
-  if source_segments:
-    assert len(output_segments) == len(source_segments)
-    for check_segment, source_segment in zip(output_segments, source_segments):
-      for line in check_segment:
-        output.write(line)
-      for line in source_segment:
-        output.write(line)
-  else:
-    for segment in output_segments:
-      output.write('\n')
-      for output_line in segment:
-        output.write(output_line)
-    output.write('\n')
-  output.close()
-
-
-if __name__ == '__main__':
-  main()
+        for segment in output_segments:
+            output.write("\n")
+            for output_line in segment:
+                output.write(output_line)
+        output.write("\n")
+    output.close()
+
+
+if __name__ == "__main__":
+    main()
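
As an illustration of what the reformatted check-generation script above produces, here is a minimal sketch, assuming SSAVariableNamer and process_line are available as defined above; the sample IR line is made up:

    namer = SSAVariableNamer()
    namer.push_name_scope()
    line = "%0 = arith.addi %arg0, %arg1 : i32"
    chunks = line.split("%")[1:]        # each chunk starts with an SSA name
    print(process_line(chunks, namer))
    # -> %[[VAL_0:.*]] = arith.addi %[[VAL_1:.*]], %[[VAL_2:.*]] : i32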
index 02582f9..21994ff 100644 (file)
@@ -4,4 +4,5 @@
 
 from ipykernel.kernelapp import IPKernelApp
 from .kernel import MlirOptKernel
+
 IPKernelApp.launch_instance(kernel_class=MlirOptKernel)
index ddb37c8..bd7b1d1 100644 (file)
@@ -10,12 +10,11 @@ from jupyter_client.kernelspec import KernelSpecManager
 
 def install_my_kernel_spec(user=True, prefix=None):
     """Install the kernel spec for user in given prefix."""
-    print('Installing mlir-opt IPython kernel spec')
+    print("Installing mlir-opt IPython kernel spec")
     pkgroot = os.path.dirname(__file__)
-    KernelSpecManager().install_kernel_spec(os.path.join(pkgroot, 'assets'),
-                                            'mlir',
-                                            user=user,
-                                            prefix=prefix)
+    KernelSpecManager().install_kernel_spec(
+        os.path.join(pkgroot, "assets"), "mlir", user=user, prefix=prefix
+    )
 
 
 def _is_root():
@@ -29,15 +28,16 @@ def _is_root():
 
 def main(argv=None):
     parser = argparse.ArgumentParser(
-        description='Install KernelSpec for MlirOpt Kernel')
+        description="Install KernelSpec for MlirOpt Kernel"
+    )
     prefix_locations = parser.add_mutually_exclusive_group()
 
-    prefix_locations.add_argument('--user',
-                                  help='Install in user home directory',
-                                  action='store_true')
-    prefix_locations.add_argument('--prefix',
-                                  help='Install directory prefix',
-                                  default=None)
+    prefix_locations.add_argument(
+        "--user", help="Install in user home directory", action="store_true"
+    )
+    prefix_locations.add_argument(
+        "--prefix", help="Install directory prefix", default=None
+    )
 
     args = parser.parse_args(argv)
 
@@ -47,5 +47,5 @@ def main(argv=None):
     install_my_kernel_spec(user=user, prefix=prefix)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
index 85462da..c0e4fc1 100644 (file)
@@ -9,7 +9,7 @@ import tempfile
 import traceback
 from ipykernel.kernelbase import Kernel
 
-__version__ = '0.0.1'
+__version__ = "0.0.1"
 
 
 def _get_executable():
@@ -19,7 +19,7 @@ def _get_executable():
         """Returns whether executable file."""
         return os.path.isfile(fpath) and os.access(fpath, os.X_OK)
 
-    program = os.environ.get('MLIR_OPT_EXECUTABLE', 'mlir-opt')
+    program = os.environ.get("MLIR_OPT_EXECUTABLE", "mlir-opt")
     path, name = os.path.split(program)
     # Attempt to get the executable
     if path:
@@ -30,7 +30,7 @@ def _get_executable():
             file = os.path.join(path, name)
             if is_exe(file):
                 return file
-    raise OSError('mlir-opt not found, please see README')
+    raise OSError("mlir-opt not found, please see README")
 
 
 class MlirOptKernel(Kernel):
@@ -51,19 +51,17 @@ class MlirOptKernel(Kernel):
     ```
     """
 
-    implementation = 'mlir'
+    implementation = "mlir"
     implementation_version = __version__
 
     language_version = __version__
     language = "mlir"
     language_info = {
         "name": "mlir",
-        "codemirror_mode": {
-            "name": "mlir"
-        },
+        "codemirror_mode": {"name": "mlir"},
         "mimetype": "text/x-mlir",
         "file_extension": ".mlir",
-        "pygments_lexer": "text"
+        "pygments_lexer": "text",
     }
 
     @property
@@ -88,31 +86,28 @@ class MlirOptKernel(Kernel):
         """Reports regular command output."""
         if not self.silent:
             # Send standard output
-            stream_content = {'name': 'stdout', 'text': output}
-            self.send_response(self.iopub_socket, 'stream', stream_content)
+            stream_content = {"name": "stdout", "text": output}
+            self.send_response(self.iopub_socket, "stream", stream_content)
 
     def process_error(self, output):
         """Reports error response."""
         if not self.silent:
             # Send standard error
-            stream_content = {'name': 'stderr', 'text': output}
-            self.send_response(self.iopub_socket, 'stream', stream_content)
-
-    def do_execute(self,
-                   code,
-                   silent,
-                   store_history=True,
-                   user_expressions=None,
-                   allow_stdin=False):
+            stream_content = {"name": "stderr", "text": output}
+            self.send_response(self.iopub_socket, "stream", stream_content)
+
+    def do_execute(
+        self, code, silent, store_history=True, user_expressions=None, allow_stdin=False
+    ):
         """Execute user code using mlir-opt binary."""
 
         def ok_status():
             """Returns OK status."""
             return {
-                'status': 'ok',
-                'execution_count': self.execution_count,
-                'payload': [],
-                'user_expressions': {}
+                "status": "ok",
+                "execution_count": self.execution_count,
+                "payload": [],
+                "user_expressions": {},
             }
 
         def run(code):
@@ -123,29 +118,27 @@ class MlirOptKernel(Kernel):
                     # Specify input and output file to error out if also
                     # set as arg.
                     self.get_executable(),
-                    '--color',
+                    "--color",
                     inputmlir.name,
-                    '-o',
-                    '-'
+                    "-o",
+                    "-",
                 ]
                 # Simple handling of repeating last line.
-                if code.endswith('\n_'):
+                if code.endswith("\n_"):
                     if not self._:
-                        raise NameError('No previous result set')
+                        raise NameError("No previous result set")
                     code = code[:-1] + self._
                 inputmlir.write(code.encode("utf-8"))
                 inputmlir.close()
-                pipe = Popen(command,
-                             stdout=subprocess.PIPE,
-                             stderr=subprocess.PIPE)
+                pipe = Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                 output, errors = pipe.communicate()
                 exitcode = pipe.returncode
             finally:
                 os.unlink(inputmlir.name)
 
-# Replace temporary filename with placeholder. This handles the very
-# remote chance that the full input filename (generated above)
-# overlaps with something in the dump unrelated to the file.
+            # Replace temporary filename with placeholder. This handles the very
+            # remote chance that the full input filename (generated above)
+            # overlaps with something in the dump unrelated to the file.
             fname = inputmlir.name.encode("utf-8")
             output = output.replace(fname, b"<<input>>")
             errors = errors.replace(fname, b"<<input>>")
@@ -163,7 +156,7 @@ class MlirOptKernel(Kernel):
             else:
                 self._ = output.decode("utf-8")
         except KeyboardInterrupt:
-            return {'status': 'abort', 'execution_count': self.execution_count}
+            return {"status": "abort", "execution_count": self.execution_count}
         except Exception as error:
             # Print traceback for local debugging.
             traceback.print_exc()
@@ -172,24 +165,24 @@ class MlirOptKernel(Kernel):
             errors = repr(error).encode("utf-8")
 
         if exitcode:
-            content = {'ename': '', 'evalue': str(exitcode), 'traceback': []}
+            content = {"ename": "", "evalue": str(exitcode), "traceback": []}
 
-            self.send_response(self.iopub_socket, 'error', content)
+            self.send_response(self.iopub_socket, "error", content)
             self.process_error(errors.decode("utf-8"))
 
-            content['execution_count'] = self.execution_count
-            content['status'] = 'error'
+            content["execution_count"] = self.execution_count
+            content["status"] = "error"
             return content
 
         if not silent:
             data = {}
-            data['text/x-mlir'] = self._
+            data["text/x-mlir"] = self._
             content = {
-                'execution_count': self.execution_count,
-                'data': data,
-                'metadata': {}
+                "execution_count": self.execution_count,
+                "data": data,
+                "metadata": {},
             }
-            self.send_response(self.iopub_socket, 'execute_result', content)
+            self.send_response(self.iopub_socket, "execute_result", content)
             self.process_output(self._)
             self.process_error(errors.decode("utf-8"))
         return ok_status()
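
For context, the kernel above resolves the mlir-opt binary through the MLIR_OPT_EXECUTABLE environment variable and otherwise falls back to "mlir-opt". A minimal sketch of pointing it at a specific build before starting Jupyter; the path shown is a hypothetical placeholder:

    import os

    # Picked up by _get_executable() when the kernel invokes mlir-opt.
    os.environ["MLIR_OPT_EXECUTABLE"] = "/path/to/llvm-build/bin/mlir-opt"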
index bfd76a7..5d06b40 100644 (file)
@@ -521,8 +521,7 @@ class InDirectRangeSynthProvider:
 
 
 class IPListRangeSynthProvider:
-    """Define an LLDB synthetic children provider for an IPList.
-    """
+    """Define an LLDB synthetic children provider for an IPList."""
 
     def __init__(self, valobj, internal_dict):
         self.valobj = valobj
@@ -575,8 +574,7 @@ class IPListRangeSynthProvider:
 
 
 class ValueSynthProvider:
-    """Define an LLDB synthetic children provider for Values.
-    """
+    """Define an LLDB synthetic children provider for Values."""
 
     def __init__(self, valobj, internal_dict):
         self.valobj = valobj
@@ -677,8 +675,7 @@ class ValueSynthProvider:
 
 
 def ValueSummaryProvider(valobj: lldb.SBValue, internal_dict):
-    """Define an LLDB summary provider for Values.
-    """
+    """Define an LLDB summary provider for Values."""
 
     index = valobj.GetChildMemberWithName("index").GetValueAsUnsigned()
     # Check if this is a block argument or not (block arguments have locations).
index 3e47ec8..d01befd 100644 (file)
@@ -9,5 +9,6 @@ class BenchmarkRunConfig:
     class. The `compiler` attribute is optional, for example for python
     benchmarks.
     """
+
     runner: typing.Callable
     compiler: typing.Optional[typing.Callable] = None
index 37cc458..6c9803e 100644 (file)
@@ -16,21 +16,17 @@ def discover_benchmark_modules(top_level_path):
     defaults to "benchmark_"
     """
     config = configparser.ConfigParser()
-    config.read(
-        os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.ini")
-    )
+    config.read(os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.ini"))
     if "discovery" in config.sections():
         filename_prefix = config["discovery"]["filename_prefix"]
     else:
         filename_prefix = "benchmark_"
-    if re.search(fr"{filename_prefix}.*.py$", top_level_path):
+    if re.search(rf"{filename_prefix}.*.py$", top_level_path):
         # A specific python file so just include that.
         benchmark_files = [top_level_path]
     else:
         # A directory so recursively search for all python files.
-        benchmark_files = pathlib.Path(
-            top_level_path
-        ).rglob(f"{filename_prefix}*.py")
+        benchmark_files = pathlib.Path(top_level_path).rglob(f"{filename_prefix}*.py")
     for benchmark_filename in benchmark_files:
         benchmark_abs_dir = os.path.abspath(os.path.dirname(benchmark_filename))
         sys.path.append(benchmark_abs_dir)
@@ -46,9 +42,7 @@ def get_benchmark_functions(module, benchmark_function_name=None):
     a specific prefix, which defaults to "benchmark_".
     """
     config = configparser.ConfigParser()
-    config.read(
-        os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.ini")
-    )
+    config.read(os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.ini"))
     if "discovery" in config.sections():
         function_prefix = config["discovery"].get("function_prefix")
     else:
@@ -57,9 +51,8 @@ def get_benchmark_functions(module, benchmark_function_name=None):
     module_functions = []
     for attribute_name in dir(module):
         attribute = getattr(module, attribute_name)
-        if (
-            isinstance(attribute, types.FunctionType)
-            and attribute_name.startswith(function_prefix)
+        if isinstance(attribute, types.FunctionType) and attribute_name.startswith(
+            function_prefix
         ):
             module_functions.append(attribute)
 
index 0f67454..5d301ab 100644 (file)
@@ -12,8 +12,7 @@ from stats import has_enough_measurements
 
 
 def main(top_level_path, stop_on_error):
-    """Top level function called when the CLI is invoked.
-    """
+    """Top level function called when the CLI is invoked."""
     if "::" in top_level_path:
         if top_level_path.count("::") > 1:
             raise AssertionError(f"Invalid path {top_level_path}")
@@ -22,16 +21,14 @@ def main(top_level_path, stop_on_error):
         benchmark_function_name = None
 
     if not os.path.exists(top_level_path):
-        raise AssertionError(
-            f"The top-level path {top_level_path} doesn't exist"
-        )
+        raise AssertionError(f"The top-level path {top_level_path} doesn't exist")
 
     modules = [module for module in discover_benchmark_modules(top_level_path)]
     benchmark_dicts = []
     for module in modules:
         benchmark_functions = [
-            function for function in
-            get_benchmark_functions(module, benchmark_function_name)
+            function
+            for function in get_benchmark_functions(module, benchmark_function_name)
         ]
         for benchmark_function in benchmark_functions:
             try:
@@ -96,10 +93,9 @@ def main(top_level_path, stop_on_error):
 
             if len(measurements_ns) > 0:
                 measurements_s = [t * 1e-9 for t in measurements_ns]
-                benchmark_identifier = ":".join([
-                    module.__name__,
-                    benchmark_function.__name__
-                ])
+                benchmark_identifier = ":".join(
+                    [module.__name__, benchmark_function.__name__]
+                )
                 benchmark_dicts.append(
                     {
                         "name": benchmark_identifier,
index 3288021..9b7a3dc 100644 (file)
@@ -16,9 +16,7 @@ def has_enough_measurements(measurements):
     If 1. is true, 2. doesn't need to be true.
     """
     config = configparser.ConfigParser()
-    config.read(
-        os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.cfg")
-    )
+    config.read(os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.cfg"))
     if "stats" in config:
         stats_dict = {
             "max_number_of_measurements": int(
@@ -34,6 +32,6 @@ def has_enough_measurements(measurements):
             "max_time_for_a_benchmark_ns": 1e9,
         }
     return (
-        np.sum(measurements) >= stats_dict["max_time_for_a_benchmark_ns"] or
-        np.size(measurements) >= stats_dict["max_number_of_measurements"]
+        np.sum(measurements) >= stats_dict["max_time_for_a_benchmark_ns"]
+        or np.size(measurements) >= stats_dict["max_number_of_measurements"]
     )
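
To illustrate the stopping condition reformatted above: a benchmark has enough measurements once their sum reaches max_time_for_a_benchmark_ns (1e9 ns when no config overrides it) or their count reaches max_number_of_measurements. A small sketch with made-up sample values:

    import numpy as np

    measurements_ns = np.array([3e8, 4e8, 5e8])   # three runs: 0.3 s, 0.4 s, 0.5 s
    print(np.sum(measurements_ns) >= 1e9)          # True -> time budget reached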
index aeb1827..426bfca 100755 (executable)
@@ -23,1088 +23,1164 @@ import requests
 import textwrap
 import yaml
 
-SPIRV_HTML_SPEC_URL = 'https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html'
-SPIRV_JSON_SPEC_URL = 'https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/spirv.core.grammar.json'
+SPIRV_HTML_SPEC_URL = (
+    "https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html"
+)
+SPIRV_JSON_SPEC_URL = "https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/spirv.core.grammar.json"
 
-SPIRV_CL_EXT_HTML_SPEC_URL = 'https://www.khronos.org/registry/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html'
-SPIRV_CL_EXT_JSON_SPEC_URL = 'https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/extinst.opencl.std.100.grammar.json'
+SPIRV_CL_EXT_HTML_SPEC_URL = "https://www.khronos.org/registry/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html"
+SPIRV_CL_EXT_JSON_SPEC_URL = "https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/extinst.opencl.std.100.grammar.json"
 
-AUTOGEN_OP_DEF_SEPARATOR = '\n// -----\n\n'
-AUTOGEN_ENUM_SECTION_MARKER = 'enum section. Generated from SPIR-V spec; DO NOT MODIFY!'
+AUTOGEN_OP_DEF_SEPARATOR = "\n// -----\n\n"
+AUTOGEN_ENUM_SECTION_MARKER = "enum section. Generated from SPIR-V spec; DO NOT MODIFY!"
 AUTOGEN_OPCODE_SECTION_MARKER = (
-    'opcode section. Generated from SPIR-V spec; DO NOT MODIFY!')
+    "opcode section. Generated from SPIR-V spec; DO NOT MODIFY!"
+)
+
 
 def get_spirv_doc_from_html_spec(url, settings):
-  """Extracts instruction documentation from SPIR-V HTML spec.
-
-  Returns:
-    - A dict mapping from instruction opcode to documentation.
-  """
-  if url is None:
-    url = SPIRV_HTML_SPEC_URL
-
-  response = requests.get(url)
-  spec = response.content
-
-  from bs4 import BeautifulSoup
-  spirv = BeautifulSoup(spec, 'html.parser')
-
-  doc = {}
-
-  if settings.gen_cl_ops:
-    section_anchor = spirv.find('h2', {'id': '_binary_form'})
-    for section in section_anchor.parent.find_all('div', {'class': 'sect2'}):
-      for table in section.find_all('table'):
-        inst_html = table.tbody.tr.td
-        opname = inst_html.a['id']
-        # Ignore the first line, which is just the opname.
-        doc[opname] = inst_html.text.split('\n', 1)[1].strip()
-  else:
-    section_anchor = spirv.find('h3', {'id': '_instructions_3'})
-    for section in section_anchor.parent.find_all('div', {'class': 'sect3'}):
-      for table in section.find_all('table'):
-        inst_html = table.tbody.tr.td.p
-        opname = inst_html.a['id']
-        # Ignore the first line, which is just the opname.
-        doc[opname] = inst_html.text.split('\n', 1)[1].strip()
-
-  return doc
+    """Extracts instruction documentation from SPIR-V HTML spec.
+
+    Returns:
+      - A dict mapping from instruction opcode to documentation.
+    """
+    if url is None:
+        url = SPIRV_HTML_SPEC_URL
+
+    response = requests.get(url)
+    spec = response.content
+
+    from bs4 import BeautifulSoup
+
+    spirv = BeautifulSoup(spec, "html.parser")
+
+    doc = {}
+
+    if settings.gen_cl_ops:
+        section_anchor = spirv.find("h2", {"id": "_binary_form"})
+        for section in section_anchor.parent.find_all("div", {"class": "sect2"}):
+            for table in section.find_all("table"):
+                inst_html = table.tbody.tr.td
+                opname = inst_html.a["id"]
+                # Ignore the first line, which is just the opname.
+                doc[opname] = inst_html.text.split("\n", 1)[1].strip()
+    else:
+        section_anchor = spirv.find("h3", {"id": "_instructions_3"})
+        for section in section_anchor.parent.find_all("div", {"class": "sect3"}):
+            for table in section.find_all("table"):
+                inst_html = table.tbody.tr.td.p
+                opname = inst_html.a["id"]
+                # Ignore the first line, which is just the opname.
+                doc[opname] = inst_html.text.split("\n", 1)[1].strip()
+
+    return doc
 
 
 def get_spirv_grammar_from_json_spec(url):
-  """Extracts operand kind and instruction grammar from SPIR-V JSON spec.
+    """Extracts operand kind and instruction grammar from SPIR-V JSON spec.
 
-  Returns:
-    - A list containing all operand kinds' grammar
-    - A list containing all instructions' grammar
-  """
-  response = requests.get(SPIRV_JSON_SPEC_URL)
-  spec = response.content
+    Returns:
+      - A list containing all operand kinds' grammar
+      - A list containing all instructions' grammar
+    """
+    response = requests.get(SPIRV_JSON_SPEC_URL)
+    spec = response.content
 
-  import json
-  spirv = json.loads(spec)
+    import json
 
-  if url is None:
-    return spirv['operand_kinds'], spirv['instructions']
+    spirv = json.loads(spec)
 
-  response_ext = requests.get(url)
-  spec_ext = response_ext.content
-  spirv_ext = json.loads(spec_ext)
+    if url is None:
+        return spirv["operand_kinds"], spirv["instructions"]
 
-  return spirv['operand_kinds'], spirv_ext['instructions']
+    response_ext = requests.get(url)
+    spec_ext = response_ext.content
+    spirv_ext = json.loads(spec_ext)
+
+    return spirv["operand_kinds"], spirv_ext["instructions"]
 
 
 def split_list_into_sublists(items):
-  """Split the list of items into multiple sublists.
+    """Split the list of items into multiple sublists.
 
-  This is to make sure the string composed from each sublist won't exceed
-  80 characters.
+    This is to make sure the string composed from each sublist won't exceed
+    80 characters.
 
-  Arguments:
-    - items: a list of strings
-  """
-  chuncks = []
-  chunk = []
-  chunk_len = 0
+    Arguments:
+      - items: a list of strings
+    """
+    chuncks = []
+    chunk = []
+    chunk_len = 0
 
-  for item in items:
-    chunk_len += len(item) + 2
-    if chunk_len > 80:
-      chuncks.append(chunk)
-      chunk = []
-      chunk_len = len(item) + 2
-    chunk.append(item)
+    for item in items:
+        chunk_len += len(item) + 2
+        if chunk_len > 80:
+            chuncks.append(chunk)
+            chunk = []
+            chunk_len = len(item) + 2
+        chunk.append(item)
 
-  if len(chunk) != 0:
-    chuncks.append(chunk)
+    if len(chunk) != 0:
+        chuncks.append(chunk)
 
-  return chuncks
+    return chuncks
 
 
 def uniquify_enum_cases(lst):
-  """Prunes duplicate enum cases from the list.
-
-  Arguments:
-   - lst: List whose elements are to be uniqued. Assumes each element is a
-     (symbol, value) pair and elements already sorted according to value.
-
-  Returns:
-   - A list with all duplicates removed. The elements are sorted according to
-     value and, for each value, uniqued according to symbol.
-   - A map from deduplicated cases to the uniqued case.
-  """
-  cases = lst
-  uniqued_cases = []
-  duplicated_cases = {}
-
-  # First sort according to the value
-  cases.sort(key=lambda x: x[1])
-
-  # Then group them according to the value
-  for _, groups in itertools.groupby(cases, key=lambda x: x[1]):
-    # For each value, sort according to the enumerant symbol.
-    sorted_group = sorted(groups, key=lambda x: x[0])
-    # Keep the "smallest" case, which is typically the symbol without extension
-    # suffix. But we have special cases that we want to fix.
-    case = sorted_group[0]
-    for i in range(1, len(sorted_group)):
-      duplicated_cases[sorted_group[i][0]] = case[0]
-    if case[0] == 'HlslSemanticGOOGLE':
-      assert len(sorted_group) == 2, 'unexpected new variant for HlslSemantic'
-      case = sorted_group[1]
-      duplicated_cases[sorted_group[0][0]] = case[0]
-    uniqued_cases.append(case)
-
-  return uniqued_cases, duplicated_cases
+    """Prunes duplicate enum cases from the list.
+
+    Arguments:
+     - lst: List whose elements are to be uniqued. Assumes each element is a
+       (symbol, value) pair and elements already sorted according to value.
+
+    Returns:
+     - A list with all duplicates removed. The elements are sorted according to
+       value and, for each value, uniqued according to symbol.
+     - A map from deduplicated cases to the uniqued case.
+    """
+    cases = lst
+    uniqued_cases = []
+    duplicated_cases = {}
+
+    # First sort according to the value
+    cases.sort(key=lambda x: x[1])
+
+    # Then group them according to the value
+    for _, groups in itertools.groupby(cases, key=lambda x: x[1]):
+        # For each value, sort according to the enumerant symbol.
+        sorted_group = sorted(groups, key=lambda x: x[0])
+        # Keep the "smallest" case, which is typically the symbol without extension
+        # suffix. But we have special cases that we want to fix.
+        case = sorted_group[0]
+        for i in range(1, len(sorted_group)):
+            duplicated_cases[sorted_group[i][0]] = case[0]
+        if case[0] == "HlslSemanticGOOGLE":
+            assert len(sorted_group) == 2, "unexpected new variant for HlslSemantic"
+            case = sorted_group[1]
+            duplicated_cases[sorted_group[0][0]] = case[0]
+        uniqued_cases.append(case)
+
+    return uniqued_cases, duplicated_cases
 
 
 def toposort(dag, sort_fn):
-  """Topologically sorts the given dag.
+    """Topologically sorts the given dag.
 
-  Arguments:
-    - dag: a dict mapping from a node to its incoming nodes.
-    - sort_fn: a function for sorting nodes in the same batch.
+    Arguments:
+      - dag: a dict mapping from a node to its incoming nodes.
+      - sort_fn: a function for sorting nodes in the same batch.
 
-  Returns:
-    A list containing topologically sorted nodes.
-  """
+    Returns:
+      A list containing topologically sorted nodes.
+    """
 
-  # Returns the next batch of nodes without incoming edges
-  def get_next_batch(dag):
-    while True:
-      no_prev_nodes = set(node for node, prev in dag.items() if not prev)
-      if not no_prev_nodes:
-        break
-      yield sorted(no_prev_nodes, key=sort_fn)
-      dag = {
-          node: (prev - no_prev_nodes)
-          for node, prev in dag.items()
-          if node not in no_prev_nodes
-      }
-    assert not dag, 'found cyclic dependency'
+    # Returns the next batch of nodes without incoming edges
+    def get_next_batch(dag):
+        while True:
+            no_prev_nodes = set(node for node, prev in dag.items() if not prev)
+            if not no_prev_nodes:
+                break
+            yield sorted(no_prev_nodes, key=sort_fn)
+            dag = {
+                node: (prev - no_prev_nodes)
+                for node, prev in dag.items()
+                if node not in no_prev_nodes
+            }
+        assert not dag, "found cyclic dependency"
 
-  sorted_nodes = []
-  for batch in get_next_batch(dag):
-    sorted_nodes.extend(batch)
+    sorted_nodes = []
+    for batch in get_next_batch(dag):
+        sorted_nodes.extend(batch)
 
-  return sorted_nodes
+    return sorted_nodes
 
 
 def toposort_capabilities(all_cases, capability_mapping):
-  """Returns topologically sorted capability (symbol, value) pairs.
-
-  Arguments:
-    - all_cases: all capability cases (containing symbol, value, and implied
-      capabilities).
-    - capability_mapping: mapping from duplicated capability symbols to the
-      canonicalized symbol chosen for SPIRVBase.td.
-
-  Returns:
-    A list containing topologically sorted capability (symbol, value) pairs.
-  """
-  dag = {}
-  name_to_value = {}
-  for case in all_cases:
-    # Get the current capability.
-    cur = case['enumerant']
-    name_to_value[cur] = case['value']
-    # Ignore duplicated symbols.
-    if cur in capability_mapping:
-      continue
-
-    # Get capabilities implied by the current capability.
-    prev = case.get('capabilities', [])
-    uniqued_prev = set([capability_mapping.get(c, c) for c in prev])
-    dag[cur] = uniqued_prev
-
-  sorted_caps = toposort(dag, lambda x: name_to_value[x])
-  # Attach the capability's value as the second component of the pair.
-  return [(c, name_to_value[c]) for c in sorted_caps]
+    """Returns topologically sorted capability (symbol, value) pairs.
+
+    Arguments:
+      - all_cases: all capability cases (containing symbol, value, and implied
+        capabilities).
+      - capability_mapping: mapping from duplicated capability symbols to the
+        canonicalized symbol chosen for SPIRVBase.td.
+
+    Returns:
+      A list containing topologically sorted capability (symbol, value) pairs.
+    """
+    dag = {}
+    name_to_value = {}
+    for case in all_cases:
+        # Get the current capability.
+        cur = case["enumerant"]
+        name_to_value[cur] = case["value"]
+        # Ignore duplicated symbols.
+        if cur in capability_mapping:
+            continue
+
+        # Get capabilities implied by the current capability.
+        prev = case.get("capabilities", [])
+        uniqued_prev = set([capability_mapping.get(c, c) for c in prev])
+        dag[cur] = uniqued_prev
+
+    sorted_caps = toposort(dag, lambda x: name_to_value[x])
+    # Attach the capability's value as the second component of the pair.
+    return [(c, name_to_value[c]) for c in sorted_caps]
 
 
 def get_capability_mapping(operand_kinds):
-  """Returns the capability mapping from duplicated cases to canonicalized ones.
+    """Returns the capability mapping from duplicated cases to canonicalized ones.
 
-  Arguments:
-    - operand_kinds: all operand kinds' grammar spec
+    Arguments:
+      - operand_kinds: all operand kinds' grammar spec
 
-  Returns:
-    - A map mapping from duplicated capability symbols to the canonicalized
-      symbol chosen for SPIRVBase.td.
-  """
-  # Find the operand kind for capability
-  cap_kind = {}
-  for kind in operand_kinds:
-    if kind['kind'] == 'Capability':
-      cap_kind = kind
+    Returns:
+      - A map mapping from duplicated capability symbols to the canonicalized
+        symbol chosen for SPIRVBase.td.
+    """
+    # Find the operand kind for capability
+    cap_kind = {}
+    for kind in operand_kinds:
+        if kind["kind"] == "Capability":
+            cap_kind = kind
 
-  kind_cases = [
-      (case['enumerant'], case['value']) for case in cap_kind['enumerants']
-  ]
-  _, capability_mapping = uniquify_enum_cases(kind_cases)
+    kind_cases = [(case["enumerant"], case["value"]) for case in cap_kind["enumerants"]]
+    _, capability_mapping = uniquify_enum_cases(kind_cases)
 
-  return capability_mapping
+    return capability_mapping
 
 
 def get_availability_spec(enum_case, capability_mapping, for_op, for_cap):
-  """Returns the availability specification string for the given enum case.
-
-  Arguments:
-    - enum_case: the enum case to generate availability spec for. It may contain
-      'version', 'lastVersion', 'extensions', or 'capabilities'.
-    - capability_mapping: mapping from duplicated capability symbols to the
-      canonicalized symbol chosen for SPIRVBase.td.
-    - for_op: bool value indicating whether this is the availability spec for an
-      op itself.
-    - for_cap: bool value indicating whether this is the availability spec for
-      capabilities themselves.
-
-  Returns:
-    - A `let availability = [...];` string if there is an availability spec, or
-      an empty string if there is none
-  """
-  assert not (for_op and for_cap), 'cannot set both for_op and for_cap'
-
-  DEFAULT_MIN_VERSION = 'MinVersion<SPIRV_V_1_0>'
-  DEFAULT_MAX_VERSION = 'MaxVersion<SPIRV_V_1_6>'
-  DEFAULT_CAP = 'Capability<[]>'
-  DEFAULT_EXT = 'Extension<[]>'
-
-  min_version = enum_case.get('version', '')
-  if min_version == 'None':
-    min_version = ''
-  elif min_version:
-    min_version = 'MinVersion<SPIRV_V_{}>'.format(min_version.replace('.', '_'))
-  # TODO: delete this once ODS can support dialect-specific content
-  # and we can use omission to mean no requirements.
-  if for_op and not min_version:
-    min_version = DEFAULT_MIN_VERSION
-
-  max_version = enum_case.get('lastVersion', '')
-  if max_version:
-    max_version = 'MaxVersion<SPIRV_V_{}>'.format(max_version.replace('.', '_'))
-  # TODO: delete this once ODS can support dialect-specific content
-  # and we can use omission to mean no requirements.
-  if for_op and not max_version:
-    max_version = DEFAULT_MAX_VERSION
-
-  exts = enum_case.get('extensions', [])
-  if exts:
-    exts = 'Extension<[{}]>'.format(', '.join(sorted(set(exts))))
-    # We need to strip the minimal version requirement if this symbol is
-    # available via an extension, which means *any* SPIR-V version can support
-    # it as long as the extension is provided. The grammar's 'version' field
-    # under such case should be interpreted as this symbol is introduced as
-    # a core symbol since the given version, rather than a minimal version
-    # requirement.
-    min_version = DEFAULT_MIN_VERSION if for_op else ''
-  # TODO: delete this once ODS can support dialect-specific content
-  # and we can use omission to mean no requirements.
-  if for_op and not exts:
-    exts = DEFAULT_EXT
-
-  caps = enum_case.get('capabilities', [])
-  implies = ''
-  if caps:
-    canonicalized_caps = []
-    for c in caps:
-      if c in capability_mapping:
-        canonicalized_caps.append(capability_mapping[c])
-      else:
-        canonicalized_caps.append(c)
-    prefixed_caps = [
-        'SPIRV_C_{}'.format(c) for c in sorted(set(canonicalized_caps))
-    ]
-    if for_cap:
-      # If this is generating the availability for capabilities, we need to
-      # put the capability "requirements" in implies field because now
-      # the "capabilities" field in the source grammar means so.
-      caps = ''
-      implies = 'list<I32EnumAttrCase> implies = [{}];'.format(
-          ', '.join(prefixed_caps))
-    else:
-      caps = 'Capability<[{}]>'.format(', '.join(prefixed_caps))
-      implies = ''
-  # TODO: delete this once ODS can support dialect-specific content
-  # and we can use omission to mean no requirements.
-  if for_op and not caps:
-    caps = DEFAULT_CAP
-
-  avail = ''
-  # Compose availability spec if any of the requirements is not empty.
-  # For ops, because we have a default in SPIRV_Op class, omit if the spec
-  # is the same.
-  if (min_version or max_version or caps or exts) and not (
-      for_op and min_version == DEFAULT_MIN_VERSION and
-      max_version == DEFAULT_MAX_VERSION and caps == DEFAULT_CAP and
-      exts == DEFAULT_EXT):
-    joined_spec = ',\n    '.join(
-        [e for e in [min_version, max_version, exts, caps] if e])
-    avail = '{} availability = [\n    {}\n  ];'.format(
-        'let' if for_op else 'list<Availability>', joined_spec)
-
-  return '{}{}{}'.format(implies, '\n  ' if implies and avail else '', avail)
+    """Returns the availability specification string for the given enum case.
+
+    Arguments:
+      - enum_case: the enum case to generate availability spec for. It may contain
+        'version', 'lastVersion', 'extensions', or 'capabilities'.
+      - capability_mapping: mapping from duplicated capability symbols to the
+        canonicalized symbol chosen for SPIRVBase.td.
+      - for_op: bool value indicating whether this is the availability spec for an
+        op itself.
+      - for_cap: bool value indicating whether this is the availability spec for
+        capabilities themselves.
+
+    Returns:
+      - A `let availability = [...];` string if there is an availability spec, or
+        an empty string if there is none
+    """
+    assert not (for_op and for_cap), "cannot set both for_op and for_cap"
+
+    DEFAULT_MIN_VERSION = "MinVersion<SPIRV_V_1_0>"
+    DEFAULT_MAX_VERSION = "MaxVersion<SPIRV_V_1_6>"
+    DEFAULT_CAP = "Capability<[]>"
+    DEFAULT_EXT = "Extension<[]>"
+
+    min_version = enum_case.get("version", "")
+    if min_version == "None":
+        min_version = ""
+    elif min_version:
+        min_version = "MinVersion<SPIRV_V_{}>".format(min_version.replace(".", "_"))
+    # TODO: delete this once ODS can support dialect-specific content
+    # and we can use omission to mean no requirements.
+    if for_op and not min_version:
+        min_version = DEFAULT_MIN_VERSION
+
+    max_version = enum_case.get("lastVersion", "")
+    if max_version:
+        max_version = "MaxVersion<SPIRV_V_{}>".format(max_version.replace(".", "_"))
+    # TODO: delete this once ODS can support dialect-specific content
+    # and we can use omission to mean no requirements.
+    if for_op and not max_version:
+        max_version = DEFAULT_MAX_VERSION
+
+    exts = enum_case.get("extensions", [])
+    if exts:
+        exts = "Extension<[{}]>".format(", ".join(sorted(set(exts))))
+        # We need to strip the minimal version requirement if this symbol is
+        # available via an extension, which means *any* SPIR-V version can support
+        # it as long as the extension is provided. The grammar's 'version' field
+        # under such case should be interpreted as this symbol is introduced as
+        # a core symbol since the given version, rather than a minimal version
+        # requirement.
+        min_version = DEFAULT_MIN_VERSION if for_op else ""
+    # TODO: delete this once ODS can support dialect-specific content
+    # and we can use omission to mean no requirements.
+    if for_op and not exts:
+        exts = DEFAULT_EXT
+
+    caps = enum_case.get("capabilities", [])
+    implies = ""
+    if caps:
+        canonicalized_caps = []
+        for c in caps:
+            if c in capability_mapping:
+                canonicalized_caps.append(capability_mapping[c])
+            else:
+                canonicalized_caps.append(c)
+        prefixed_caps = [
+            "SPIRV_C_{}".format(c) for c in sorted(set(canonicalized_caps))
+        ]
+        if for_cap:
+            # If this is generating the availability for capabilities, we need to
+            # put the capability "requirements" in implies field because now
+            # the "capabilities" field in the source grammar means so.
+            caps = ""
+            implies = "list<I32EnumAttrCase> implies = [{}];".format(
+                ", ".join(prefixed_caps)
+            )
+        else:
+            caps = "Capability<[{}]>".format(", ".join(prefixed_caps))
+            implies = ""
+    # TODO: delete this once ODS can support dialect-specific content
+    # and we can use omission to mean no requirements.
+    if for_op and not caps:
+        caps = DEFAULT_CAP
+
+    avail = ""
+    # Compose availability spec if any of the requirements is not empty.
+    # For ops, because we have a default in SPIRV_Op class, omit if the spec
+    # is the same.
+    if (min_version or max_version or caps or exts) and not (
+        for_op
+        and min_version == DEFAULT_MIN_VERSION
+        and max_version == DEFAULT_MAX_VERSION
+        and caps == DEFAULT_CAP
+        and exts == DEFAULT_EXT
+    ):
+        joined_spec = ",\n    ".join(
+            [e for e in [min_version, max_version, exts, caps] if e]
+        )
+        avail = "{} availability = [\n    {}\n  ];".format(
+            "let" if for_op else "list<Availability>", joined_spec
+        )
+
+    return "{}{}{}".format(implies, "\n  " if implies and avail else "", avail)
 
 
 def gen_operand_kind_enum_attr(operand_kind, capability_mapping):
-  """Generates the TableGen EnumAttr definition for the given operand kind.
-
-  Returns:
-    - The operand kind's name
-    - A string containing the TableGen EnumAttr definition
-  """
-  if 'enumerants' not in operand_kind:
-    return '', ''
-
-  # Returns a symbol for the given case in the given kind. This function
-  # handles Dim specially to avoid having numbers as the start of symbols,
-  # which does not play well with C++ and the MLIR parser.
-  def get_case_symbol(kind_name, case_name):
-    if kind_name == 'Dim':
-      if case_name == '1D' or case_name == '2D' or case_name == '3D':
-        return 'Dim{}'.format(case_name)
-    return case_name
-
-  kind_name = operand_kind['kind']
-  is_bit_enum = operand_kind['category'] == 'BitEnum'
-  kind_acronym = ''.join([c for c in kind_name if c >= 'A' and c <= 'Z'])
-
-  name_to_case_dict = {}
-  for case in operand_kind['enumerants']:
-    name_to_case_dict[case['enumerant']] = case
-
-  if kind_name == 'Capability':
-    # Special treatment for capability cases: we need to sort them topologically
-    # because a capability can refer to another via the 'implies' field.
-    kind_cases = toposort_capabilities(operand_kind['enumerants'],
-                                       capability_mapping)
-  else:
-    kind_cases = [(case['enumerant'], case['value'])
-                  for case in operand_kind['enumerants']]
-    kind_cases, _ = uniquify_enum_cases(kind_cases)
-  max_len = max([len(symbol) for (symbol, _) in kind_cases])
-
-  # Generate the definition for each enum case
-  case_category = 'I32Bit' if is_bit_enum else 'I32'
-  fmt_str = 'def SPIRV_{acronym}_{case_name} {colon:>{offset}} '\
-            '{category}EnumAttrCase{suffix}<"{symbol}"{case_value_part}>{avail}'
-  case_defs = []
-  for case_pair in kind_cases:
-    name = case_pair[0]
-    if is_bit_enum:
-      value = int(case_pair[1], base=16)
+    """Generates the TableGen EnumAttr definition for the given operand kind.
+
+    Returns:
+      - The operand kind's name
+      - A string containing the TableGen EnumAttr definition
+    """
+    if "enumerants" not in operand_kind:
+        return "", ""
+
+    # Returns a symbol for the given case in the given kind. This function
+    # handles Dim specially to avoid having numbers as the start of symbols,
+    # which does not play well with C++ and the MLIR parser.
+    def get_case_symbol(kind_name, case_name):
+        if kind_name == "Dim":
+            if case_name == "1D" or case_name == "2D" or case_name == "3D":
+                return "Dim{}".format(case_name)
+        return case_name
+
+    kind_name = operand_kind["kind"]
+    is_bit_enum = operand_kind["category"] == "BitEnum"
+    kind_acronym = "".join([c for c in kind_name if c >= "A" and c <= "Z"])
+
+    name_to_case_dict = {}
+    for case in operand_kind["enumerants"]:
+        name_to_case_dict[case["enumerant"]] = case
+
+    if kind_name == "Capability":
+        # Special treatment for capability cases: we need to sort them topologically
+        # because a capability can refer to another via the 'implies' field.
+        kind_cases = toposort_capabilities(
+            operand_kind["enumerants"], capability_mapping
+        )
     else:
-      value = int(case_pair[1])
-    avail = get_availability_spec(name_to_case_dict[name],
-                                  capability_mapping,
-                                  False, kind_name == 'Capability')
-    if is_bit_enum:
-      if value == 0:
-        suffix = 'None'
-        value = ''
-      else:
-        suffix = "Bit"
-        value = ', {}'.format(int(math.log2(value)))
-    else:
-        suffix = ''
-        value = ', {}'.format(value)
-
-    case_def = fmt_str.format(
-        category=case_category,
-        suffix=suffix,
-        acronym=kind_acronym,
-        case_name=name,
-        symbol=get_case_symbol(kind_name, name),
-        case_value_part=value,
-        avail=' {{\n  {}\n}}'.format(avail) if avail else ';',
-        colon=':',
-        offset=(max_len + 1 - len(name)))
-    case_defs.append(case_def)
-  case_defs = '\n'.join(case_defs)
-
-  # Generate the list of enum case names
-  fmt_str = 'SPIRV_{acronym}_{symbol}';
-  case_names = [fmt_str.format(acronym=kind_acronym,symbol=case[0])
-                for case in kind_cases]
-
-  # Split them into sublists and concatenate into multiple lines
-  case_names = split_list_into_sublists(case_names)
-  case_names = ['{:6}'.format('') + ', '.join(sublist)
-                for sublist in case_names]
-  case_names = ',\n'.join(case_names)
-
-  # Generate the enum attribute definition
-  kind_category = 'Bit' if is_bit_enum else 'I32'
-  enum_attr = '''def SPIRV_{name}Attr :
+        kind_cases = [
+            (case["enumerant"], case["value"]) for case in operand_kind["enumerants"]
+        ]
+        kind_cases, _ = uniquify_enum_cases(kind_cases)
+    max_len = max([len(symbol) for (symbol, _) in kind_cases])
+
+    # Generate the definition for each enum case
+    case_category = "I32Bit" if is_bit_enum else "I32"
+    fmt_str = (
+        "def SPIRV_{acronym}_{case_name} {colon:>{offset}} "
+        '{category}EnumAttrCase{suffix}<"{symbol}"{case_value_part}>{avail}'
+    )
+    case_defs = []
+    for case_pair in kind_cases:
+        name = case_pair[0]
+        if is_bit_enum:
+            value = int(case_pair[1], base=16)
+        else:
+            value = int(case_pair[1])
+        avail = get_availability_spec(
+            name_to_case_dict[name],
+            capability_mapping,
+            False,
+            kind_name == "Capability",
+        )
+        if is_bit_enum:
+            if value == 0:
+                suffix = "None"
+                value = ""
+            else:
+                suffix = "Bit"
+                value = ", {}".format(int(math.log2(value)))
+        else:
+            suffix = ""
+            value = ", {}".format(value)
+
+        case_def = fmt_str.format(
+            category=case_category,
+            suffix=suffix,
+            acronym=kind_acronym,
+            case_name=name,
+            symbol=get_case_symbol(kind_name, name),
+            case_value_part=value,
+            avail=" {{\n  {}\n}}".format(avail) if avail else ";",
+            colon=":",
+            offset=(max_len + 1 - len(name)),
+        )
+        case_defs.append(case_def)
+    case_defs = "\n".join(case_defs)
+
+    # Generate the list of enum case names
+    fmt_str = "SPIRV_{acronym}_{symbol}"
+    case_names = [
+        fmt_str.format(acronym=kind_acronym, symbol=case[0]) for case in kind_cases
+    ]
+
+    # Split them into sublists and concatenate into multiple lines
+    case_names = split_list_into_sublists(case_names)
+    case_names = ["{:6}".format("") + ", ".join(sublist) for sublist in case_names]
+    case_names = ",\n".join(case_names)
+
+    # Generate the enum attribute definition
+    kind_category = "Bit" if is_bit_enum else "I32"
+    enum_attr = """def SPIRV_{name}Attr :
     SPIRV_{category}EnumAttr<"{name}", "valid SPIR-V {name}", "{snake_name}", [
 {cases}
-    ]>;'''.format(
-          name=kind_name,
-          snake_name=snake_casify(kind_name),
-          category=kind_category,
-          cases=case_names)
-  return kind_name, case_defs + '\n\n' + enum_attr
+    ]>;""".format(
+        name=kind_name,
+        snake_name=snake_casify(kind_name),
+        category=kind_category,
+        cases=case_names,
+    )
+    return kind_name, case_defs + "\n\n" + enum_attr
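+# Editor's sketch, illustrative only: for a hypothetical simple (non-bit) enum
+# kind such as AddressingModel, the returned definitions look roughly like
+# (availability blocks omitted for brevity):
+#
+#   def SPIRV_AM_Logical    : I32EnumAttrCase<"Logical", 0>;
+#   def SPIRV_AM_Physical32 : I32EnumAttrCase<"Physical32", 1>;
+#
+#   def SPIRV_AddressingModelAttr :
+#       SPIRV_I32EnumAttr<"AddressingModel", "valid SPIR-V AddressingModel",
+#                         "addressing_model", [
+#         SPIRV_AM_Logical, SPIRV_AM_Physical32
+#       ]>;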
 
 
 def gen_opcode(instructions):
-  """ Generates the TableGen definition to map opname to opcode
-
-  Returns:
-    - A string containing the TableGen SPIRV_OpCode definition
-  """
-
-  max_len = max([len(inst['opname']) for inst in instructions])
-  def_fmt_str = 'def SPIRV_OC_{name} {colon:>{offset}} '\
-            'I32EnumAttrCase<"{name}", {value}>;'
-  opcode_defs = [
-      def_fmt_str.format(
-          name=inst['opname'],
-          value=inst['opcode'],
-          colon=':',
-          offset=(max_len + 1 - len(inst['opname']))) for inst in instructions
-  ]
-  opcode_str = '\n'.join(opcode_defs)
-
-  decl_fmt_str = 'SPIRV_OC_{name}'
-  opcode_list = [
-      decl_fmt_str.format(name=inst['opname']) for inst in instructions
-  ]
-  opcode_list = split_list_into_sublists(opcode_list)
-  opcode_list = [
-      '{:6}'.format('') + ', '.join(sublist) for sublist in opcode_list
-  ]
-  opcode_list = ',\n'.join(opcode_list)
-  enum_attr = 'def SPIRV_OpcodeAttr :\n'\
-              '    SPIRV_I32EnumAttr<"{name}", "valid SPIR-V instructions", '\
-              '"opcode", [\n'\
-              '{lst}\n'\
-              '    ]>;'.format(name='Opcode', lst=opcode_list)
-  return opcode_str + '\n\n' + enum_attr
+    """Generates the TableGen definition to map opname to opcode
+
+    Returns:
+      - A string containing the TableGen SPIRV_OpCode definition
+    """
+
+    max_len = max([len(inst["opname"]) for inst in instructions])
+    def_fmt_str = (
+        "def SPIRV_OC_{name} {colon:>{offset}} " 'I32EnumAttrCase<"{name}", {value}>;'
+    )
+    opcode_defs = [
+        def_fmt_str.format(
+            name=inst["opname"],
+            value=inst["opcode"],
+            colon=":",
+            offset=(max_len + 1 - len(inst["opname"])),
+        )
+        for inst in instructions
+    ]
+    opcode_str = "\n".join(opcode_defs)
+
+    decl_fmt_str = "SPIRV_OC_{name}"
+    opcode_list = [decl_fmt_str.format(name=inst["opname"]) for inst in instructions]
+    opcode_list = split_list_into_sublists(opcode_list)
+    opcode_list = ["{:6}".format("") + ", ".join(sublist) for sublist in opcode_list]
+    opcode_list = ",\n".join(opcode_list)
+    enum_attr = (
+        "def SPIRV_OpcodeAttr :\n"
+        '    SPIRV_I32EnumAttr<"{name}", "valid SPIR-V instructions", '
+        '"opcode", [\n'
+        "{lst}\n"
+        "    ]>;".format(name="Opcode", lst=opcode_list)
+    )
+    return opcode_str + "\n\n" + enum_attr
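+# Editor's sketch, illustrative only (opcode values taken from the SPIR-V spec):
+#
+#   def SPIRV_OC_OpNop  : I32EnumAttrCase<"OpNop", 0>;
+#   def SPIRV_OC_OpName : I32EnumAttrCase<"OpName", 5>;
+#
+#   def SPIRV_OpcodeAttr :
+#       SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
+#         SPIRV_OC_OpNop, SPIRV_OC_OpName
+#       ]>;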
+
 
 def map_cap_to_opnames(instructions):
-  """Maps capabilities to instructions enabled by those capabilities
+    """Maps capabilities to instructions enabled by those capabilities
 
-  Arguments:
-    - instructions: a list containing a subset of SPIR-V instructions' grammar
-  Returns:
-    - A map with keys representing capabilities and values of lists of
-    instructions enabled by the corresponding key
-  """
-  cap_to_inst = {}
+    Arguments:
+      - instructions: a list containing a subset of SPIR-V instructions' grammar
+    Returns:
+      - A map whose keys are capabilities and whose values are lists of
+        instructions enabled by the corresponding capability
+    """
+    cap_to_inst = {}
 
-  for inst in instructions:
-    caps = inst['capabilities'] if 'capabilities' in inst else ['0_core_0']
-    for cap in caps:
-      if cap not in cap_to_inst:
-        cap_to_inst[cap] = []
-      cap_to_inst[cap].append(inst['opname'])
+    for inst in instructions:
+        caps = inst["capabilities"] if "capabilities" in inst else ["0_core_0"]
+        for cap in caps:
+            if cap not in cap_to_inst:
+                cap_to_inst[cap] = []
+            cap_to_inst[cap].append(inst["opname"])
+
+    return cap_to_inst
 
-  return cap_to_inst
 
 def gen_instr_coverage_report(path, instructions):
-  """Dumps to standard output a YAML report of current instruction coverage
+    """Dumps to standard output a YAML report of current instruction coverage
 
-  Arguments:
-    - path: the path to SPIRBase.td
-    - instructions: a list containing all SPIR-V instructions' grammar
-  """
-  with open(path, 'r') as f:
-    content = f.read()
+    Arguments:
+      - path: the path to SPIRVBase.td
+      - instructions: a list containing all SPIR-V instructions' grammar
+    """
+    with open(path, "r") as f:
+        content = f.read()
 
-  content = content.split(AUTOGEN_OPCODE_SECTION_MARKER)
+    content = content.split(AUTOGEN_OPCODE_SECTION_MARKER)
 
-  existing_opcodes = [k[11:] for k in re.findall('def SPIRV_OC_\w+', content[1])]
-  existing_instructions = list(
-          filter(lambda inst: (inst['opname'] in existing_opcodes),
-              instructions))
+    existing_opcodes = [k[11:] for k in re.findall("def SPIRV_OC_\w+", content[1])]
+    existing_instructions = list(
+        filter(lambda inst: (inst["opname"] in existing_opcodes), instructions)
+    )
 
-  instructions_opnames = [inst['opname'] for inst in instructions]
+    instructions_opnames = [inst["opname"] for inst in instructions]
 
-  remaining_opcodes = list(set(instructions_opnames) - set(existing_opcodes))
-  remaining_instructions = list(
-          filter(lambda inst: (inst['opname'] in remaining_opcodes),
-              instructions))
+    remaining_opcodes = list(set(instructions_opnames) - set(existing_opcodes))
+    remaining_instructions = list(
+        filter(lambda inst: (inst["opname"] in remaining_opcodes), instructions)
+    )
 
-  rem_cap_to_instr = map_cap_to_opnames(remaining_instructions)
-  ex_cap_to_instr = map_cap_to_opnames(existing_instructions)
+    rem_cap_to_instr = map_cap_to_opnames(remaining_instructions)
+    ex_cap_to_instr = map_cap_to_opnames(existing_instructions)
 
-  rem_cap_to_cov = {}
+    rem_cap_to_cov = {}
 
-  # Calculate coverage for each capability
-  for cap in rem_cap_to_instr:
-    if cap not in ex_cap_to_instr:
-      rem_cap_to_cov[cap] = 0.0
-    else:
-      rem_cap_to_cov[cap] = \
-              (len(ex_cap_to_instr[cap]) / (len(ex_cap_to_instr[cap]) \
-              + len(rem_cap_to_instr[cap])))
+    # Calculate coverage for each capability
+    for cap in rem_cap_to_instr:
+        if cap not in ex_cap_to_instr:
+            rem_cap_to_cov[cap] = 0.0
+        else:
+            rem_cap_to_cov[cap] = len(ex_cap_to_instr[cap]) / (
+                len(ex_cap_to_instr[cap]) + len(rem_cap_to_instr[cap])
+            )
 
-  report = {}
+    report = {}
 
-  # Merge the 3 maps into one report
-  for cap in rem_cap_to_instr:
-    report[cap] = {}
-    report[cap]['Supported Instructions'] = \
+    # Merge the 3 maps into one report
+    for cap in rem_cap_to_instr:
+        report[cap] = {}
+        report[cap]["Supported Instructions"] = (
             ex_cap_to_instr[cap] if cap in ex_cap_to_instr else []
-    report[cap]['Unsupported Instructions']  = rem_cap_to_instr[cap]
-    report[cap]['Coverage'] = '{}%'.format(int(rem_cap_to_cov[cap] * 100))
+        )
+        report[cap]["Unsupported Instructions"] = rem_cap_to_instr[cap]
+        report[cap]["Coverage"] = "{}%".format(int(rem_cap_to_cov[cap] * 100))
+
+    print(yaml.dump(report))
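+    # Editor's note, illustrative only: each entry of the dumped report has
+    # roughly this shape:
+    #   SomeCapability:
+    #     Supported Instructions: [...]
+    #     Unsupported Instructions: [...]
+    #     Coverage: "42%"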
 
-  print(yaml.dump(report))
 
 def update_td_opcodes(path, instructions, filter_list):
-  """Updates SPIRBase.td with new generated opcode cases.
-
-  Arguments:
-    - path: the path to SPIRBase.td
-    - instructions: a list containing all SPIR-V instructions' grammar
-    - filter_list: a list containing new opnames to add
-  """
-
-  with open(path, 'r') as f:
-    content = f.read()
-
-  content = content.split(AUTOGEN_OPCODE_SECTION_MARKER)
-  assert len(content) == 3
-
-  # Extend opcode list with existing list
-  prefix = 'def SPIRV_OC_'
-  existing_opcodes = [k[len(prefix):] for k in re.findall(prefix + '\w+', content[1])]
-  filter_list.extend(existing_opcodes)
-  filter_list = list(set(filter_list))
-
-  # Generate the opcode for all instructions in SPIR-V
-  filter_instrs = list(
-      filter(lambda inst: (inst['opname'] in filter_list), instructions))
-  # Sort instruction based on opcode
-  filter_instrs.sort(key=lambda inst: inst['opcode'])
-  opcode = gen_opcode(filter_instrs)
-
-  # Substitute the opcode
-  content = content[0] + AUTOGEN_OPCODE_SECTION_MARKER + '\n\n' + \
-        opcode + '\n\n// End ' + AUTOGEN_OPCODE_SECTION_MARKER \
+    """Updates SPIRBase.td with new generated opcode cases.
+
+    Arguments:
+      - path: the path to SPIRVBase.td
+      - instructions: a list containing all SPIR-V instructions' grammar
+      - filter_list: a list containing new opnames to add
+    """
+
+    with open(path, "r") as f:
+        content = f.read()
+
+    content = content.split(AUTOGEN_OPCODE_SECTION_MARKER)
+    assert len(content) == 3
+
+    # Extend opcode list with existing list
+    prefix = "def SPIRV_OC_"
+    existing_opcodes = [
+        k[len(prefix) :] for k in re.findall(prefix + "\w+", content[1])
+    ]
+    filter_list.extend(existing_opcodes)
+    filter_list = list(set(filter_list))
+
+    # Generate the opcode for all instructions in SPIR-V
+    filter_instrs = list(
+        filter(lambda inst: (inst["opname"] in filter_list), instructions)
+    )
+    # Sort instructions by opcode
+    filter_instrs.sort(key=lambda inst: inst["opcode"])
+    opcode = gen_opcode(filter_instrs)
+
+    # Substitute the opcode
+    content = (
+        content[0]
+        + AUTOGEN_OPCODE_SECTION_MARKER
+        + "\n\n"
+        + opcode
+        + "\n\n// End "
+        + AUTOGEN_OPCODE_SECTION_MARKER
         + content[2]
+    )
 
-  with open(path, 'w') as f:
-    f.write(content)
+    with open(path, "w") as f:
+        f.write(content)
 
 
 def update_td_enum_attrs(path, operand_kinds, filter_list):
-  """Updates SPIRBase.td with new generated enum definitions.
-
-  Arguments:
-    - path: the path to SPIRBase.td
-    - operand_kinds: a list containing all operand kinds' grammar
-    - filter_list: a list containing new enums to add
-  """
-  with open(path, 'r') as f:
-    content = f.read()
-
-  content = content.split(AUTOGEN_ENUM_SECTION_MARKER)
-  assert len(content) == 3
-
-  # Extend filter list with existing enum definitions
-  existing_kinds = [
-      k[8:-4] for k in re.findall('def SPIRV_\w+Attr', content[1])]
-  filter_list.extend(existing_kinds)
-
-  capability_mapping = get_capability_mapping(operand_kinds)
-
-  # Generate definitions for all enums in filter list
-  defs = [
-      gen_operand_kind_enum_attr(kind, capability_mapping)
-      for kind in operand_kinds
-      if kind['kind'] in filter_list
-  ]
-  # Sort alphabetically according to enum name
-  defs.sort(key=lambda enum : enum[0])
-  # Only keep the definitions from now on
-  # Put Capability's definition at the very beginning because capability cases
-  # will be referenced later
-  defs = [enum[1] for enum in defs if enum[0] == 'Capability'
-         ] + [enum[1] for enum in defs if enum[0] != 'Capability']
-
-  # Substitute the old section
-  content = content[0] + AUTOGEN_ENUM_SECTION_MARKER + '\n\n' + \
-      '\n\n'.join(defs) + "\n\n// End " + AUTOGEN_ENUM_SECTION_MARKER  \
-      + content[2];
-
-  with open(path, 'w') as f:
-    f.write(content)
+    """Updates SPIRBase.td with new generated enum definitions.
+
+    Arguments:
+      - path: the path to SPIRVBase.td
+      - operand_kinds: a list containing all operand kinds' grammar
+      - filter_list: a list containing new enums to add
+    """
+    with open(path, "r") as f:
+        content = f.read()
+
+    content = content.split(AUTOGEN_ENUM_SECTION_MARKER)
+    assert len(content) == 3
+
+    # Extend filter list with existing enum definitions
+    existing_kinds = [k[8:-4] for k in re.findall("def SPIRV_\w+Attr", content[1])]
+    filter_list.extend(existing_kinds)
+
+    capability_mapping = get_capability_mapping(operand_kinds)
+
+    # Generate definitions for all enums in filter list
+    defs = [
+        gen_operand_kind_enum_attr(kind, capability_mapping)
+        for kind in operand_kinds
+        if kind["kind"] in filter_list
+    ]
+    # Sort alphabetically according to enum name
+    defs.sort(key=lambda enum: enum[0])
+    # Only keep the definitions from now on
+    # Put Capability's definition at the very beginning because capability cases
+    # will be referenced later
+    defs = [enum[1] for enum in defs if enum[0] == "Capability"] + [
+        enum[1] for enum in defs if enum[0] != "Capability"
+    ]
+
+    # Substitute the old section
+    content = (
+        content[0]
+        + AUTOGEN_ENUM_SECTION_MARKER
+        + "\n\n"
+        + "\n\n".join(defs)
+        + "\n\n// End "
+        + AUTOGEN_ENUM_SECTION_MARKER
+        + content[2]
+    )
+
+    with open(path, "w") as f:
+        f.write(content)
 
 
 def snake_casify(name):
-  """Turns the given name to follow snake_case convention."""
-  return re.sub(r'(?<!^)(?=[A-Z])', '_', name).lower()
+    """Turns the given name to follow snake_case convention."""
+    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()
 
 
 def map_spec_operand_to_ods_argument(operand):
-  """Maps an operand in SPIR-V JSON spec to an op argument in ODS.
-
-  Arguments:
-    - A dict containing the operand's kind, quantifier, and name
-
-  Returns:
-    - A string containing both the type and name for the argument
-  """
-  kind = operand['kind']
-  quantifier = operand.get('quantifier', '')
-
-  # These instruction "operands" are for encoding the results; they should
-  # not be handled here.
-  assert kind != 'IdResultType', 'unexpected to handle "IdResultType" kind'
-  assert kind != 'IdResult', 'unexpected to handle "IdResult" kind'
-
-  if kind == 'IdRef':
-    if quantifier == '':
-      arg_type = 'SPIRV_Type'
-    elif quantifier == '?':
-      arg_type = 'Optional<SPIRV_Type>'
-    else:
-      arg_type = 'Variadic<SPIRV_Type>'
-  elif kind == 'IdMemorySemantics' or kind == 'IdScope':
-    # TODO: Need to further constrain 'IdMemorySemantics'
-    # and 'IdScope' given that they should be generated from OpConstant.
-    assert quantifier == '', ('unexpected to have optional/variadic memory '
-                              'semantics or scope <id>')
-    arg_type = 'SPIRV_' + kind[2:] + 'Attr'
-  elif kind == 'LiteralInteger':
-    if quantifier == '':
-      arg_type = 'I32Attr'
-    elif quantifier == '?':
-      arg_type = 'OptionalAttr<I32Attr>'
+    """Maps an operand in SPIR-V JSON spec to an op argument in ODS.
+
+    Arguments:
+      - A dict containing the operand's kind, quantifier, and name
+
+    Returns:
+      - A string containing both the type and name for the argument
+    """
+    kind = operand["kind"]
+    quantifier = operand.get("quantifier", "")
+
+    # These instruction "operands" are for encoding the results; they should
+    # not be handled here.
+    assert kind != "IdResultType", 'unexpected to handle "IdResultType" kind'
+    assert kind != "IdResult", 'unexpected to handle "IdResult" kind'
+
+    if kind == "IdRef":
+        if quantifier == "":
+            arg_type = "SPIRV_Type"
+        elif quantifier == "?":
+            arg_type = "Optional<SPIRV_Type>"
+        else:
+            arg_type = "Variadic<SPIRV_Type>"
+    elif kind == "IdMemorySemantics" or kind == "IdScope":
+        # TODO: Need to further constrain 'IdMemorySemantics'
+        # and 'IdScope' given that they should be generated from OpConstant.
+        assert quantifier == "", (
+            "unexpected to have optional/variadic memory " "semantics or scope <id>"
+        )
+        arg_type = "SPIRV_" + kind[2:] + "Attr"
+    elif kind == "LiteralInteger":
+        if quantifier == "":
+            arg_type = "I32Attr"
+        elif quantifier == "?":
+            arg_type = "OptionalAttr<I32Attr>"
+        else:
+            arg_type = "OptionalAttr<I32ArrayAttr>"
+    elif (
+        kind == "LiteralString"
+        or kind == "LiteralContextDependentNumber"
+        or kind == "LiteralExtInstInteger"
+        or kind == "LiteralSpecConstantOpInteger"
+        or kind == "PairLiteralIntegerIdRef"
+        or kind == "PairIdRefLiteralInteger"
+        or kind == "PairIdRefIdRef"
+    ):
+        assert False, '"{}" kind unimplemented'.format(kind)
     else:
-      arg_type = 'OptionalAttr<I32ArrayAttr>'
-  elif kind == 'LiteralString' or \
-      kind == 'LiteralContextDependentNumber' or \
-      kind == 'LiteralExtInstInteger' or \
-      kind == 'LiteralSpecConstantOpInteger' or \
-      kind == 'PairLiteralIntegerIdRef' or \
-      kind == 'PairIdRefLiteralInteger' or \
-      kind == 'PairIdRefIdRef':
-    assert False, '"{}" kind unimplemented'.format(kind)
-  else:
-    # The rest are all enum operands that we represent with op attributes.
-    assert quantifier != '*', 'unexpected to have variadic enum attribute'
-    arg_type = 'SPIRV_{}Attr'.format(kind)
-    if quantifier == '?':
-      arg_type = 'OptionalAttr<{}>'.format(arg_type)
-
-  name = operand.get('name', '')
-  name = snake_casify(name) if name else kind.lower()
-
-  return '{}:${}'.format(arg_type, name)
+        # The rest are all enum operands that we represent with op attributes.
+        assert quantifier != "*", "unexpected to have variadic enum attribute"
+        arg_type = "SPIRV_{}Attr".format(kind)
+        if quantifier == "?":
+            arg_type = "OptionalAttr<{}>".format(arg_type)
+
+    name = operand.get("name", "")
+    name = snake_casify(name) if name else kind.lower()
+
+    return "{}:${}".format(arg_type, name)
 
 
 def get_description(text, appendix):
-  """Generates the description for the given SPIR-V instruction.
-
-  Arguments:
-    - text: Textual description of the operation as string.
-    - appendix: Additional contents to attach in description as string,
-                includking IR examples, and others.
-
-  Returns:
-    - A string that corresponds to the description of the Tablegen op.
-  """
-  fmt_str = '{text}\n\n    <!-- End of AutoGen section -->\n{appendix}\n  '
-  return fmt_str.format(text=text, appendix=appendix)
-
-
-def get_op_definition(instruction, opname, doc, existing_info, capability_mapping, settings):
-  """Generates the TableGen op definition for the given SPIR-V instruction.
-
-  Arguments:
-    - instruction: the instruction's SPIR-V JSON grammar
-    - doc: the instruction's SPIR-V HTML doc
-    - existing_info: a dict containing potential manually specified sections for
-      this instruction
-    - capability_mapping: mapping from duplicated capability symbols to the
-                   canonicalized symbol chosen for SPIRVBase.td
-
-  Returns:
-    - A string containing the TableGen op definition
-  """
-  if settings.gen_cl_ops:
-    fmt_str = ('def SPIRV_{opname}Op : '
-               'SPIRV_{inst_category}<"{opname_src}", {opcode}, <<Insert result type>> > '
-               '{{\n  let summary = {summary};\n\n  let description = '
-               '[{{\n{description}}}];{availability}\n')
-  else:
-    fmt_str = ('def SPIRV_{vendor_name}{opname_src}Op : '
-               'SPIRV_{inst_category}<"{opname_src}"{category_args}, [{traits}]> '
-               '{{\n  let summary = {summary};\n\n  let description = '
-               '[{{\n{description}}}];{availability}\n')
-
-  vendor_name = ''
-  inst_category = existing_info.get('inst_category', 'Op')
-  if inst_category == 'Op':
-    fmt_str +='\n  let arguments = (ins{args});\n\n'\
-              '  let results = (outs{results});\n'
-  elif inst_category.endswith('VendorOp'):
-    vendor_name = inst_category.split('VendorOp')[0].upper()
-    assert len(vendor_name) != 0, 'Invalid instruction category'
-
-  fmt_str +='{extras}'\
-            '}}\n'
-
-  opname_src = instruction['opname']
-  if opname.startswith('Op'):
-    opname_src = opname_src[2:]
-  if len(vendor_name) > 0:
-    assert opname_src.endswith(vendor_name), "op name does not match the instruction category"
-    opname_src = opname_src[:-len(vendor_name)]
-
-  category_args = existing_info.get('category_args', '')
-
-  if '\n' in doc:
-    summary, text = doc.split('\n', 1)
-  else:
-    summary = doc
-    text = ''
-  wrapper = textwrap.TextWrapper(
-      width=76, initial_indent='    ', subsequent_indent='    ')
-
-  # Format summary. If the summary can fit in the same line, we print it out
-  # as a "-quoted string; otherwise, wrap the lines using "[{...}]".
-  summary = summary.strip()
-  if len(summary) + len('  let summary = "";') <= 80:
-    summary = '"{}"'.format(summary)
-  else:
-    summary = '[{{\n{}\n  }}]'.format(wrapper.fill(summary))
-
-  # Wrap text
-  text = text.split('\n')
-  text = [wrapper.fill(line) for line in text if line]
-  text = '\n\n'.join(text)
-
-  operands = instruction.get('operands', [])
-
-  # Op availability
-  avail = get_availability_spec(instruction, capability_mapping, True, False)
-  if avail:
-    avail = '\n\n  {0}'.format(avail)
-
-  # Set op's result
-  results = ''
-  if len(operands) > 0 and operands[0]['kind'] == 'IdResultType':
-    results = '\n    SPIRV_Type:$result\n  '
-    operands = operands[1:]
-  if 'results' in existing_info:
-    results = existing_info['results']
-
-  # Ignore the operand standing for the result <id>
-  if len(operands) > 0 and operands[0]['kind'] == 'IdResult':
-    operands = operands[1:]
-
-  # Set op' argument
-  arguments = existing_info.get('arguments', None)
-  if arguments is None:
-    arguments = [map_spec_operand_to_ods_argument(o) for o in operands]
-    arguments = ',\n    '.join(arguments)
-    if arguments:
-      # Prepend and append whitespace for formatting
-      arguments = '\n    {}\n  '.format(arguments)
-
-  description = existing_info.get('description', None)
-  if description is None:
-    assembly = '\n    ```\n'\
-               '    [TODO]\n'\
-               '    ```\n\n'\
-               '    #### Example:\n\n'\
-               '    ```mlir\n'\
-               '    [TODO]\n' \
-               '    ```'
-    description = get_description(text, assembly)
-
-  return fmt_str.format(
-      opname=opname,
-      opname_src=opname_src,
-      opcode=instruction['opcode'],
-      category_args=category_args,
-      inst_category=inst_category,
-      vendor_name=vendor_name,
-      traits=existing_info.get('traits', ''),
-      summary=summary,
-      description=description,
-      availability=avail,
-      args=arguments,
-      results=results,
-      extras=existing_info.get('extras', ''))
+    """Generates the description for the given SPIR-V instruction.
+
+    Arguments:
+      - text: Textual description of the operation as string.
+      - appendix: Additional contents to attach to the description as a string,
+                  including IR examples and others.
+
+    Returns:
+      - A string that corresponds to the description of the TableGen op.
+    """
+    fmt_str = "{text}\n\n    <!-- End of AutoGen section -->\n{appendix}\n  "
+    return fmt_str.format(text=text, appendix=appendix)
+
+
+def get_op_definition(
+    instruction, opname, doc, existing_info, capability_mapping, settings
+):
+    """Generates the TableGen op definition for the given SPIR-V instruction.
+
+    Arguments:
+      - instruction: the instruction's SPIR-V JSON grammar
+      - doc: the instruction's SPIR-V HTML doc
+      - existing_info: a dict containing potentially manually specified sections
+        for this instruction
+      - capability_mapping: mapping from duplicated capability symbols to the
+                     canonicalized symbol chosen for SPIRVBase.td
+
+    Returns:
+      - A string containing the TableGen op definition
+    """
+    if settings.gen_cl_ops:
+        fmt_str = (
+            "def SPIRV_{opname}Op : "
+            'SPIRV_{inst_category}<"{opname_src}", {opcode}, <<Insert result type>> > '
+            "{{\n  let summary = {summary};\n\n  let description = "
+            "[{{\n{description}}}];{availability}\n"
+        )
+    else:
+        fmt_str = (
+            "def SPIRV_{vendor_name}{opname_src}Op : "
+            'SPIRV_{inst_category}<"{opname_src}"{category_args}, [{traits}]> '
+            "{{\n  let summary = {summary};\n\n  let description = "
+            "[{{\n{description}}}];{availability}\n"
+        )
+
+    vendor_name = ""
+    inst_category = existing_info.get("inst_category", "Op")
+    if inst_category == "Op":
+        fmt_str += (
+            "\n  let arguments = (ins{args});\n\n" "  let results = (outs{results});\n"
+        )
+    elif inst_category.endswith("VendorOp"):
+        vendor_name = inst_category.split("VendorOp")[0].upper()
+        assert len(vendor_name) != 0, "Invalid instruction category"
+
+    fmt_str += "{extras}" "}}\n"
+
+    opname_src = instruction["opname"]
+    if opname.startswith("Op"):
+        opname_src = opname_src[2:]
+    if len(vendor_name) > 0:
+        assert opname_src.endswith(
+            vendor_name
+        ), "op name does not match the instruction category"
+        opname_src = opname_src[: -len(vendor_name)]
+
+    category_args = existing_info.get("category_args", "")
+
+    if "\n" in doc:
+        summary, text = doc.split("\n", 1)
+    else:
+        summary = doc
+        text = ""
+    wrapper = textwrap.TextWrapper(
+        width=76, initial_indent="    ", subsequent_indent="    "
+    )
+
+    # Format summary. If the summary fits on one line, print it out as a
+    # double-quoted string; otherwise, wrap the lines using "[{...}]".
+    summary = summary.strip()
+    if len(summary) + len('  let summary = "";') <= 80:
+        summary = '"{}"'.format(summary)
+    else:
+        summary = "[{{\n{}\n  }}]".format(wrapper.fill(summary))
+
+    # Wrap text
+    text = text.split("\n")
+    text = [wrapper.fill(line) for line in text if line]
+    text = "\n\n".join(text)
+
+    operands = instruction.get("operands", [])
+
+    # Op availability
+    avail = get_availability_spec(instruction, capability_mapping, True, False)
+    if avail:
+        avail = "\n\n  {0}".format(avail)
+
+    # Set op's result
+    results = ""
+    if len(operands) > 0 and operands[0]["kind"] == "IdResultType":
+        results = "\n    SPIRV_Type:$result\n  "
+        operands = operands[1:]
+    if "results" in existing_info:
+        results = existing_info["results"]
+
+    # Ignore the operand standing for the result <id>
+    if len(operands) > 0 and operands[0]["kind"] == "IdResult":
+        operands = operands[1:]
+
+    # Set op's arguments
+    arguments = existing_info.get("arguments", None)
+    if arguments is None:
+        arguments = [map_spec_operand_to_ods_argument(o) for o in operands]
+        arguments = ",\n    ".join(arguments)
+        if arguments:
+            # Prepend and append whitespace for formatting
+            arguments = "\n    {}\n  ".format(arguments)
+
+    description = existing_info.get("description", None)
+    if description is None:
+        assembly = (
+            "\n    ```\n"
+            "    [TODO]\n"
+            "    ```\n\n"
+            "    #### Example:\n\n"
+            "    ```mlir\n"
+            "    [TODO]\n"
+            "    ```"
+        )
+        description = get_description(text, assembly)
+
+    return fmt_str.format(
+        opname=opname,
+        opname_src=opname_src,
+        opcode=instruction["opcode"],
+        category_args=category_args,
+        inst_category=inst_category,
+        vendor_name=vendor_name,
+        traits=existing_info.get("traits", ""),
+        summary=summary,
+        description=description,
+        availability=avail,
+        args=arguments,
+        results=results,
+        extras=existing_info.get("extras", ""),
+    )
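+# Editor's sketch, illustrative only: for the default 'Op' instruction category
+# the formatted definition has roughly this skeleton (availability and extras
+# omitted; fields filled in above):
+#
+#   def SPIRV_<Name>Op : SPIRV_Op<"<Name>", [<traits>]> {
+#     let summary = "...";
+#
+#     let description = [{
+#       ...
+#     }];
+#
+#     let arguments = (ins ...);
+#
+#     let results = (outs ...);
+#   }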
 
 
 def get_string_between(base, start, end):
-  """Extracts a substring with a specified start and end from a string.
-
-  Arguments:
-    - base: string to extract from.
-    - start: string to use as the start of the substring.
-    - end: string to use as the end of the substring.
-
-  Returns:
-    - The substring if found
-    - The part of the base after end of the substring. Is the base string itself
-      if the substring wasnt found.
-  """
-  split = base.split(start, 1)
-  if len(split) == 2:
-    rest = split[1].split(end, 1)
-    assert len(rest) == 2, \
-           'cannot find end "{end}" while extracting substring '\
-           'starting with {start}'.format(start=start, end=end)
-    return rest[0].rstrip(end), rest[1]
-  return '', split[0]
+    """Extracts a substring with a specified start and end from a string.
+
+    Arguments:
+      - base: string to extract from.
+      - start: string to use as the start of the substring.
+      - end: string to use as the end of the substring.
+
+    Returns:
+      - The substring if found
+      - The part of the base after the end of the substring. This is the base
+        string itself if the substring wasn't found.
+    """
+    split = base.split(start, 1)
+    if len(split) == 2:
+        rest = split[1].split(end, 1)
+        assert len(rest) == 2, (
+            'cannot find end "{end}" while extracting substring '
+            "starting with {start}".format(start=start, end=end)
+        )
+        return rest[0].rstrip(end), rest[1]
+    return "", split[0]
 
 
 def get_string_between_nested(base, start, end):
-  """Extracts a substring with a nested start and end from a string.
-
-  Arguments:
-    - base: string to extract from.
-    - start: string to use as the start of the substring.
-    - end: string to use as the end of the substring.
-
-  Returns:
-    - The substring if found
-    - The part of the base after end of the substring. Is the base string itself
-      if the substring wasn't found.
-  """
-  split = base.split(start, 1)
-  if len(split) == 2:
-    # Handle nesting delimiters
-    rest = split[1]
-    unmatched_start = 1
-    index = 0
-    while unmatched_start > 0 and index < len(rest):
-      if rest[index:].startswith(end):
-        unmatched_start -= 1
-        if unmatched_start == 0:
-          break
-        index += len(end)
-      elif rest[index:].startswith(start):
-        unmatched_start += 1
-        index += len(start)
-      else:
-        index += 1
-
-    assert index < len(rest), \
-           'cannot find end "{end}" while extracting substring '\
-           'starting with "{start}"'.format(start=start, end=end)
-    return rest[:index], rest[index + len(end):]
-  return '', split[0]
+    """Extracts a substring with a nested start and end from a string.
+
+    Arguments:
+      - base: string to extract from.
+      - start: string to use as the start of the substring.
+      - end: string to use as the end of the substring.
+
+    Returns:
+      - The substring if found
+      - The part of the base after the end of the substring. This is the base
+        string itself if the substring wasn't found.
+    """
+    split = base.split(start, 1)
+    if len(split) == 2:
+        # Handle nesting delimiters
+        rest = split[1]
+        unmatched_start = 1
+        index = 0
+        while unmatched_start > 0 and index < len(rest):
+            if rest[index:].startswith(end):
+                unmatched_start -= 1
+                if unmatched_start == 0:
+                    break
+                index += len(end)
+            elif rest[index:].startswith(start):
+                unmatched_start += 1
+                index += len(start)
+            else:
+                index += 1
+
+        assert index < len(rest), (
+            'cannot find end "{end}" while extracting substring '
+            'starting with "{start}"'.format(start=start, end=end)
+        )
+        return rest[:index], rest[index + len(end) :]
+    return "", split[0]
 
 
 def extract_td_op_info(op_def):
-  """Extracts potentially manually specified sections in op's definition.
-
-  Arguments: - A string containing the op's TableGen definition
-
-  Returns:
-    - A dict containing potential manually specified sections
-  """
-  # Get opname
-  opname = [o[8:-2] for o in re.findall('def SPIRV_\w+Op', op_def)]
-  assert len(opname) == 1, 'more than one ops in the same section!'
-  opname = opname[0]
-
-  # Get instruction category
-  inst_category = [
-      o[4:] for o in re.findall('SPIRV_\w+Op',
-                                op_def.split(':', 1)[1])
-  ]
-  assert len(inst_category) <= 1, 'more than one ops in the same section!'
-  inst_category = inst_category[0] if len(inst_category) == 1 else 'Op'
-
-  # Get category_args
-  op_tmpl_params, _ = get_string_between_nested(op_def, '<', '>')
-  opstringname, rest = get_string_between(op_tmpl_params, '"', '"')
-  category_args = rest.split('[', 1)[0]
-
-  # Get traits
-  traits, _ = get_string_between_nested(rest, '[', ']')
-
-  # Get description
-  description, rest = get_string_between(op_def, 'let description = [{\n',
-                                         '}];\n')
-
-  # Get arguments
-  args, rest = get_string_between(rest, '  let arguments = (ins', ');\n')
-
-  # Get results
-  results, rest = get_string_between(rest, '  let results = (outs', ');\n')
-
-  extras = rest.strip(' }\n')
-  if extras:
-    extras = '\n  {}\n'.format(extras)
-
-  return {
-      # Prefix with 'Op' to make it consistent with SPIR-V spec
-      'opname': 'Op{}'.format(opname),
-      'inst_category': inst_category,
-      'category_args': category_args,
-      'traits': traits,
-      'description': description,
-      'arguments': args,
-      'results': results,
-      'extras': extras
-  }
-
-
-def update_td_op_definitions(path, instructions, docs, filter_list,
-                             inst_category, capability_mapping, settings):
-  """Updates SPIRVOps.td with newly generated op definition.
-
-  Arguments:
-    - path: path to SPIRVOps.td
-    - instructions: SPIR-V JSON grammar for all instructions
-    - docs: SPIR-V HTML doc for all instructions
-    - filter_list: a list containing new opnames to include
-    - capability_mapping: mapping from duplicated capability symbols to the
-                   canonicalized symbol chosen for SPIRVBase.td.
-
-  Returns:
-    - A string containing all the TableGen op definitions
-  """
-  with open(path, 'r') as f:
-    content = f.read()
-
-  # Split the file into chunks, each containing one op.
-  ops = content.split(AUTOGEN_OP_DEF_SEPARATOR)
-  header = ops[0]
-  footer = ops[-1]
-  ops = ops[1:-1]
-
-  # For each existing op, extract the manually-written sections out to retain
-  # them when re-generating the ops. Also append the existing ops to filter
-  # list.
-  name_op_map = {}  # Map from opname to its existing ODS definition
-  op_info_dict = {}
-  for op in ops:
-    info_dict = extract_td_op_info(op)
-    opname = info_dict['opname']
-    name_op_map[opname] = op
-    op_info_dict[opname] = info_dict
-    filter_list.append(opname)
-  filter_list = sorted(list(set(filter_list)))
-
-  op_defs = []
-
-  if settings.gen_cl_ops:
-    fix_opname = lambda src: src.replace('CL','').lower()
-  else:
-    fix_opname = lambda src: src
-
-  for opname in filter_list:
-    # Find the grammar spec for this op
-    try:
-      fixed_opname = fix_opname(opname)
-      instruction = next(
-          inst for inst in instructions if inst['opname'] == fixed_opname)
-
-      op_defs.append(
-          get_op_definition(
-              instruction, opname, docs[fixed_opname],
-              op_info_dict.get(opname, {'inst_category': inst_category}),
-              capability_mapping, settings))
-    except StopIteration:
-      # This is an op added by us; use the existing ODS definition.
-      op_defs.append(name_op_map[opname])
-
-  # Substitute the old op definitions
-  op_defs = [header] + op_defs + [footer]
-  content = AUTOGEN_OP_DEF_SEPARATOR.join(op_defs)
-
-  with open(path, 'w') as f:
-    f.write(content)
-
-
-if __name__ == '__main__':
-  import argparse
-
-  cli_parser = argparse.ArgumentParser(
-      description='Update SPIR-V dialect definitions using SPIR-V spec')
-
-  cli_parser.add_argument(
-      '--base-td-path',
-      dest='base_td_path',
-      type=str,
-      default=None,
-      help='Path to SPIRVBase.td')
-  cli_parser.add_argument(
-      '--op-td-path',
-      dest='op_td_path',
-      type=str,
-      default=None,
-      help='Path to SPIRVOps.td')
-
-  cli_parser.add_argument(
-      '--new-enum',
-      dest='new_enum',
-      type=str,
-      default=None,
-      help='SPIR-V enum to be added to SPIRVBase.td')
-  cli_parser.add_argument(
-      '--new-opcodes',
-      dest='new_opcodes',
-      type=str,
-      default=None,
-      nargs='*',
-      help='update SPIR-V opcodes in SPIRVBase.td')
-  cli_parser.add_argument(
-      '--new-inst',
-      dest='new_inst',
-      type=str,
-      default=None,
-      nargs='*',
-      help='SPIR-V instruction to be added to ops file')
-  cli_parser.add_argument(
-      '--inst-category',
-      dest='inst_category',
-      type=str,
-      default='Op',
-      help='SPIR-V instruction category used for choosing '\
-           'the TableGen base class to define this op')
-  cli_parser.add_argument(
-      '--gen-cl-ops',
-      dest='gen_cl_ops',
-      help='Generate OpenCL Extended Instruction Set op',
-      action='store_true')
-  cli_parser.set_defaults(gen_cl_ops=False)
-  cli_parser.add_argument('--gen-inst-coverage', dest='gen_inst_coverage', action='store_true')
-  cli_parser.set_defaults(gen_inst_coverage=False)
-
-  args = cli_parser.parse_args()
-
-  if args.gen_cl_ops:
-    ext_html_url = SPIRV_CL_EXT_HTML_SPEC_URL
-    ext_json_url = SPIRV_CL_EXT_JSON_SPEC_URL
-  else:
-    ext_html_url = None
-    ext_json_url = None
-
-  operand_kinds, instructions = get_spirv_grammar_from_json_spec(ext_json_url)
-
-  # Define new enum attr
-  if args.new_enum is not None:
-    assert args.base_td_path is not None
-    filter_list = [args.new_enum] if args.new_enum else []
-    update_td_enum_attrs(args.base_td_path, operand_kinds, filter_list)
-
-  # Define new opcode
-  if args.new_opcodes is not None:
-    assert args.base_td_path is not None
-    update_td_opcodes(args.base_td_path, instructions, args.new_opcodes)
-
-  # Define new op
-  if args.new_inst is not None:
-    assert args.op_td_path is not None
-    docs = get_spirv_doc_from_html_spec(ext_html_url, args)
-    capability_mapping = get_capability_mapping(operand_kinds)
-    update_td_op_definitions(args.op_td_path, instructions, docs, args.new_inst,
-                             args.inst_category, capability_mapping, args)
-    print('Done. Note that this script just generates a template; ', end='')
-    print('please read the spec and update traits, arguments, and ', end='')
-    print('results accordingly.')
-
-  if args.gen_inst_coverage:
-    gen_instr_coverage_report(args.base_td_path, instructions)
+    """Extracts potentially manually specified sections in op's definition.
+
+    Arguments:
+      - A string containing the op's TableGen definition
+
+    Returns:
+      - A dict containing potentially manually specified sections
+    """
+    # Get opname
+    opname = [o[8:-2] for o in re.findall("def SPIRV_\w+Op", op_def)]
+    assert len(opname) == 1, "more than one op in the same section!"
+    opname = opname[0]
+
+    # Get instruction category
+    inst_category = [o[4:] for o in re.findall("SPIRV_\w+Op", op_def.split(":", 1)[1])]
+    assert len(inst_category) <= 1, "more than one ops in the same section!"
+    inst_category = inst_category[0] if len(inst_category) == 1 else "Op"
+
+    # Get category_args
+    op_tmpl_params, _ = get_string_between_nested(op_def, "<", ">")
+    opstringname, rest = get_string_between(op_tmpl_params, '"', '"')
+    category_args = rest.split("[", 1)[0]
+
+    # Get traits
+    traits, _ = get_string_between_nested(rest, "[", "]")
+
+    # Get description
+    description, rest = get_string_between(op_def, "let description = [{\n", "}];\n")
+
+    # Get arguments
+    args, rest = get_string_between(rest, "  let arguments = (ins", ");\n")
+
+    # Get results
+    results, rest = get_string_between(rest, "  let results = (outs", ");\n")
+
+    extras = rest.strip(" }\n")
+    if extras:
+        extras = "\n  {}\n".format(extras)
+
+    return {
+        # Prefix with 'Op' to make it consistent with SPIR-V spec
+        "opname": "Op{}".format(opname),
+        "inst_category": inst_category,
+        "category_args": category_args,
+        "traits": traits,
+        "description": description,
+        "arguments": args,
+        "results": results,
+        "extras": extras,
+    }
+
+
+def update_td_op_definitions(
+    path, instructions, docs, filter_list, inst_category, capability_mapping, settings
+):
+    """Updates SPIRVOps.td with newly generated op definition.
+
+    Arguments:
+      - path: path to SPIRVOps.td
+      - instructions: SPIR-V JSON grammar for all instructions
+      - docs: SPIR-V HTML doc for all instructions
+      - filter_list: a list containing new opnames to include
+      - capability_mapping: mapping from duplicated capability symbols to the
+                     canonicalized symbol chosen for SPIRVBase.td.
+
+    The updated op definitions are written back to the given path.
+    """
+    with open(path, "r") as f:
+        content = f.read()
+
+    # Split the file into chunks, each containing one op.
+    ops = content.split(AUTOGEN_OP_DEF_SEPARATOR)
+    header = ops[0]
+    footer = ops[-1]
+    ops = ops[1:-1]
+
+    # For each existing op, extract the manually-written sections to retain them
+    # when re-generating the ops. Also append the existing ops to the filter
+    # list.
+    name_op_map = {}  # Map from opname to its existing ODS definition
+    op_info_dict = {}
+    for op in ops:
+        info_dict = extract_td_op_info(op)
+        opname = info_dict["opname"]
+        name_op_map[opname] = op
+        op_info_dict[opname] = info_dict
+        filter_list.append(opname)
+    filter_list = sorted(list(set(filter_list)))
+
+    op_defs = []
+
+    if settings.gen_cl_ops:
+        fix_opname = lambda src: src.replace("CL", "").lower()
+    else:
+        fix_opname = lambda src: src
+
+    for opname in filter_list:
+        # Find the grammar spec for this op
+        try:
+            fixed_opname = fix_opname(opname)
+            instruction = next(
+                inst for inst in instructions if inst["opname"] == fixed_opname
+            )
+
+            op_defs.append(
+                get_op_definition(
+                    instruction,
+                    opname,
+                    docs[fixed_opname],
+                    op_info_dict.get(opname, {"inst_category": inst_category}),
+                    capability_mapping,
+                    settings,
+                )
+            )
+        except StopIteration:
+            # This is an op added by us; use the existing ODS definition.
+            op_defs.append(name_op_map[opname])
+
+    # Substitute the old op definitions
+    op_defs = [header] + op_defs + [footer]
+    content = AUTOGEN_OP_DEF_SEPARATOR.join(op_defs)
+
+    with open(path, "w") as f:
+        f.write(content)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    cli_parser = argparse.ArgumentParser(
+        description="Update SPIR-V dialect definitions using SPIR-V spec"
+    )
+
+    cli_parser.add_argument(
+        "--base-td-path",
+        dest="base_td_path",
+        type=str,
+        default=None,
+        help="Path to SPIRVBase.td",
+    )
+    cli_parser.add_argument(
+        "--op-td-path",
+        dest="op_td_path",
+        type=str,
+        default=None,
+        help="Path to SPIRVOps.td",
+    )
+
+    cli_parser.add_argument(
+        "--new-enum",
+        dest="new_enum",
+        type=str,
+        default=None,
+        help="SPIR-V enum to be added to SPIRVBase.td",
+    )
+    cli_parser.add_argument(
+        "--new-opcodes",
+        dest="new_opcodes",
+        type=str,
+        default=None,
+        nargs="*",
+        help="update SPIR-V opcodes in SPIRVBase.td",
+    )
+    cli_parser.add_argument(
+        "--new-inst",
+        dest="new_inst",
+        type=str,
+        default=None,
+        nargs="*",
+        help="SPIR-V instruction to be added to ops file",
+    )
+    cli_parser.add_argument(
+        "--inst-category",
+        dest="inst_category",
+        type=str,
+        default="Op",
+        help="SPIR-V instruction category used for choosing "
+        "the TableGen base class to define this op",
+    )
+    cli_parser.add_argument(
+        "--gen-cl-ops",
+        dest="gen_cl_ops",
+        help="Generate OpenCL Extended Instruction Set op",
+        action="store_true",
+    )
+    cli_parser.set_defaults(gen_cl_ops=False)
+    cli_parser.add_argument(
+        "--gen-inst-coverage", dest="gen_inst_coverage", action="store_true"
+    )
+    cli_parser.set_defaults(gen_inst_coverage=False)
+
+    args = cli_parser.parse_args()
+
+    if args.gen_cl_ops:
+        ext_html_url = SPIRV_CL_EXT_HTML_SPEC_URL
+        ext_json_url = SPIRV_CL_EXT_JSON_SPEC_URL
+    else:
+        ext_html_url = None
+        ext_json_url = None
+
+    operand_kinds, instructions = get_spirv_grammar_from_json_spec(ext_json_url)
+
+    # Define new enum attr
+    if args.new_enum is not None:
+        assert args.base_td_path is not None
+        filter_list = [args.new_enum] if args.new_enum else []
+        update_td_enum_attrs(args.base_td_path, operand_kinds, filter_list)
+
+    # Define new opcode
+    if args.new_opcodes is not None:
+        assert args.base_td_path is not None
+        update_td_opcodes(args.base_td_path, instructions, args.new_opcodes)
+
+    # Define new op
+    if args.new_inst is not None:
+        assert args.op_td_path is not None
+        docs = get_spirv_doc_from_html_spec(ext_html_url, args)
+        capability_mapping = get_capability_mapping(operand_kinds)
+        update_td_op_definitions(
+            args.op_td_path,
+            instructions,
+            docs,
+            args.new_inst,
+            args.inst_category,
+            capability_mapping,
+            args,
+        )
+        print("Done. Note that this script just generates a template; ", end="")
+        print("please read the spec and update traits, arguments, and ", end="")
+        print("results accordingly.")
+
+    if args.gen_inst_coverage:
+        gen_instr_coverage_report(args.base_td_path, instructions)