From f9008e6366c2496b1ca1785b891d5578174ad63e Mon Sep 17 00:00:00 2001 From: Tobias Hieta Date: Wed, 17 May 2023 16:53:39 +0200 Subject: [PATCH] [NFC][Py Reformat] Reformat python files in mlir subdir This is an ongoing series of commits that are reformatting our Python code. Reformatting is done with `black`. If you end up having problems merging this commit because you have made changes to a python file, the best way to handle that is to run git checkout --ours <yourfile> and then reformat it with black. If you run into any problems, post to discourse about it and we will try to help. RFC Thread below: https://discourse.llvm.org/t/rfc-document-and-standardize-python-code-style Differential Revision: https://reviews.llvm.org/D150782 --- mlir/benchmark/python/benchmark_sparse.py | 29 +- mlir/benchmark/python/common.py | 30 +- mlir/examples/standalone/test/CAPI/lit.local.cfg | 2 +- mlir/examples/standalone/test/lit.cfg.py | 43 +- mlir/examples/standalone/test/python/lit.local.cfg | 4 +- mlir/examples/standalone/test/python/smoketest.py | 19 +- mlir/python/mlir/_mlir_libs/__init__.py | 189 +- mlir/python/mlir/dialects/_arith_ops_ext.py | 99 +- .../python/mlir/dialects/_bufferization_ops_ext.py | 57 +- mlir/python/mlir/dialects/_builtin_ops_ext.py | 20 +- mlir/python/mlir/dialects/_func_ops_ext.py | 575 +-- mlir/python/mlir/dialects/_linalg_ops_ext.py | 54 +- .../mlir/dialects/_loop_transform_ops_ext.py | 213 +- mlir/python/mlir/dialects/_memref_ops_ext.py | 46 +- mlir/python/mlir/dialects/_ml_program_ops_ext.py | 199 +- mlir/python/mlir/dialects/_ods_common.py | 262 +- mlir/python/mlir/dialects/_pdl_ops_ext.py | 428 ++- mlir/python/mlir/dialects/_scf_ops_ext.py | 187 +- .../mlir/dialects/_structured_transform_ops_ext.py | 545 +-- mlir/python/mlir/dialects/_tensor_ops_ext.py | 62 +- mlir/python/mlir/dialects/_transform_ops_ext.py | 227 +- .../mlir/dialects/linalg/opdsl/dump_oplib.py | 97 +- .../mlir/dialects/linalg/opdsl/lang/affine.py | 373 +- .../dialects/linalg/opdsl/lang/comprehension.py | 1202 ++++--- .../mlir/dialects/linalg/opdsl/lang/config.py | 846 ++--- mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py | 288 +- .../mlir/dialects/linalg/opdsl/lang/emitter.py | 1078 +++--- .../mlir/dialects/linalg/opdsl/lang/scalar_expr.py | 200 +- .../mlir/dialects/linalg/opdsl/lang/types.py | 44 +- .../mlir/dialects/linalg/opdsl/lang/yaml_helper.py | 43 +- .../dialects/linalg/opdsl/ops/core_named_ops.py | 2175 ++++++------ mlir/python/mlir/dialects/python_test.py | 6 +- mlir/python/mlir/dialects/transform/__init__.py | 18 +- mlir/python/mlir/execution_engine.py | 58 +- mlir/python/mlir/ir.py | 67 +- mlir/python/mlir/runtime/np_to_memref.py | 175 +- mlir/test/CAPI/lit.local.cfg | 2 +- mlir/test/Conversion/GPUToCUDA/lit.local.cfg | 2 +- mlir/test/Conversion/GPUToROCm/lit.local.cfg | 2 +- mlir/test/Examples/Toy/Ch6/lit.local.cfg | 4 +- mlir/test/Examples/Toy/Ch7/lit.local.cfg | 4 +- mlir/test/Examples/lit.local.cfg | 2 +- mlir/test/Examples/standalone/lit.local.cfg | 9 +- .../Integration/Dialect/Async/CPU/lit.local.cfg | 2 +- .../Dialect/LLVMIR/CPU/X86/lit.local.cfg | 2 +- .../Integration/Dialect/LLVMIR/CPU/lit.local.cfg | 18 +- .../Dialect/SparseTensor/CPU/lit.local.cfg | 14 +- .../Dialect/SparseTensor/GPU/CUDA/lit.local.cfg | 2 +- .../Dialect/SparseTensor/python/lit.local.cfg | 4 +- .../Dialect/SparseTensor/python/test_SDDMM.py | 227 +- .../Dialect/SparseTensor/python/test_SpMM.py | 211 +- .../python/test_elementwise_add_sparse_output.py | 83 +- .../Dialect/SparseTensor/python/test_output.py | 108 +- 
.../Dialect/SparseTensor/python/test_stress.py | 435 +-- .../python/tools/np_to_sparse_tensor.py | 106 +- .../SparseTensor/python/tools/sparse_compiler.py | 53 +- .../Dialect/SparseTensor/taco/lit.local.cfg | 4 +- .../Dialect/SparseTensor/taco/test_MTTKRP.py | 14 +- .../Dialect/SparseTensor/taco/test_SDDMM.py | 20 +- .../Dialect/SparseTensor/taco/test_SpMM.py | 14 +- .../Dialect/SparseTensor/taco/test_SpMV.py | 14 +- .../Dialect/SparseTensor/taco/test_Tensor.py | 57 +- .../taco/test_scalar_tensor_algebra.py | 8 +- .../SparseTensor/taco/test_tensor_complex.py | 26 +- .../Dialect/SparseTensor/taco/test_tensor_types.py | 30 +- .../taco/test_true_dense_tensor_algebra.py | 4 +- .../Dialect/SparseTensor/taco/tools/mlir_pytaco.py | 3702 ++++++++++---------- .../SparseTensor/taco/tools/mlir_pytaco_io.py | 90 +- .../SparseTensor/taco/tools/mlir_pytaco_utils.py | 560 +-- .../taco/tools/mlir_sparse_compiler.py | 52 +- .../SparseTensor/taco/tools/testing_utils.py | 50 +- .../SparseTensor/taco/unit_test_tensor_core.py | 743 ++-- .../SparseTensor/taco/unit_test_tensor_io.py | 118 +- .../SparseTensor/taco/unit_test_tensor_utils.py | 128 +- .../Dialect/Vector/CPU/AMX/lit.local.cfg | 6 +- .../Dialect/Vector/CPU/ArmSME/lit.local.cfg | 2 +- .../Dialect/Vector/CPU/ArmSVE/lit.local.cfg | 2 +- .../Dialect/Vector/CPU/X86Vector/lit.local.cfg | 6 +- .../Dialect/Vector/GPU/CUDA/lit.local.cfg | 2 +- .../Integration/GPU/CUDA/TensorCore/lit.local.cfg | 2 +- mlir/test/Integration/GPU/CUDA/lit.local.cfg | 2 +- mlir/test/Integration/GPU/ROCM/lit.local.cfg | 4 +- mlir/test/Integration/lit.local.cfg | 29 +- mlir/test/Unit/lit.cfg.py | 38 +- mlir/test/lib/Dialect/Test/lit.local.cfg | 2 +- mlir/test/lib/Dialect/Transform/lit.local.cfg | 2 +- mlir/test/lib/Tools/PDLL/lit.local.cfg | 2 +- mlir/test/lib/Transforms/lit.local.cfg | 2 +- mlir/test/lit.cfg.py | 190 +- mlir/test/mlir-cpu-runner/lit.local.cfg | 9 +- mlir/test/mlir-pdll-lsp-server/lit.local.cfg | 2 +- mlir/test/mlir-pdll/lit.local.cfg | 4 +- mlir/test/mlir-spirv-cpu-runner/lit.local.cfg | 2 +- mlir/test/mlir-vulkan-runner/lit.local.cfg | 2 +- mlir/test/python/develoment_files.py | 3 +- mlir/test/python/dialects/arith_dialect.py | 18 +- mlir/test/python/dialects/async_dialect.py | 13 +- mlir/test/python/dialects/builtin.py | 416 +-- mlir/test/python/dialects/complex_dialect.py | 32 +- mlir/test/python/dialects/func.py | 90 +- mlir/test/python/dialects/gpu.py | 13 +- .../test/python/dialects/linalg/opdsl/arguments.py | 12 +- .../python/dialects/linalg/opdsl/assignments.py | 21 +- mlir/test/python/dialects/linalg/opdsl/doctests.py | 7 +- .../dialects/linalg/opdsl/emit_convolution.py | 69 +- .../test/python/dialects/linalg/opdsl/emit_fill.py | 78 +- .../python/dialects/linalg/opdsl/emit_matmul.py | 297 +- .../test/python/dialects/linalg/opdsl/emit_misc.py | 233 +- .../python/dialects/linalg/opdsl/emit_pooling.py | 237 +- .../python/dialects/linalg/opdsl/lit.local.cfg | 4 +- mlir/test/python/dialects/linalg/opdsl/metadata.py | 12 +- .../dialects/linalg/opdsl/shape_maps_iteration.py | 19 +- mlir/test/python/dialects/linalg/ops.py | 239 +- mlir/test/python/dialects/math_dialect.py | 33 +- mlir/test/python/dialects/memref.py | 82 +- mlir/test/python/dialects/ml_program.py | 30 +- mlir/test/python/dialects/ods_helpers.py | 352 +- mlir/test/python/dialects/pdl_ops.py | 237 +- mlir/test/python/dialects/python_test.py | 634 ++-- mlir/test/python/dialects/quant.py | 198 +- mlir/test/python/dialects/scf.py | 119 +- mlir/test/python/dialects/shape.py | 58 +- 
mlir/test/python/dialects/sparse_tensor/dialect.py | 165 +- mlir/test/python/dialects/sparse_tensor/passes.py | 16 +- mlir/test/python/dialects/tensor.py | 214 +- mlir/test/python/dialects/transform.py | 266 +- mlir/test/python/dialects/transform_loop_ext.py | 116 +- .../python/dialects/transform_structured_ext.py | 328 +- mlir/test/python/dialects/vector.py | 89 +- mlir/test/python/execution_engine.py | 684 ++-- .../python/integration/dialects/linalg/opsrun.py | 700 ++-- mlir/test/python/ir/affine_expr.py | 563 +-- mlir/test/python/ir/affine_map.py | 378 +- mlir/test/python/ir/array_attributes.py | 483 +-- mlir/test/python/ir/attributes.py | 794 +++-- mlir/test/python/ir/blocks.py | 189 +- mlir/test/python/ir/builtin_types.py | 908 ++--- mlir/test/python/ir/context_managers.py | 159 +- mlir/test/python/ir/debug.py | 46 +- mlir/test/python/ir/diagnostic_handler.py | 301 +- mlir/test/python/ir/dialects.py | 140 +- mlir/test/python/ir/exception.py | 122 +- mlir/test/python/ir/insertion_point.py | 209 +- mlir/test/python/ir/integer_set.py | 227 +- mlir/test/python/ir/location.py | 203 +- mlir/test/python/ir/module.py | 183 +- mlir/test/python/ir/operation.py | 1405 ++++---- mlir/test/python/ir/symbol_table.py | 234 +- mlir/test/python/ir/value.py | 358 +- mlir/test/python/lit.local.cfg | 6 +- mlir/test/python/pass_manager.py | 188 +- mlir/test/tblgen-lsp-server/lit.local.cfg | 2 +- mlir/utils/gdb-scripts/prettyprinters.py | 327 +- mlir/utils/generate-test-checks.py | 421 ++- mlir/utils/jupyter/mlir_opt_kernel/__main__.py | 1 + mlir/utils/jupyter/mlir_opt_kernel/install.py | 26 +- mlir/utils/jupyter/mlir_opt_kernel/kernel.py | 81 +- mlir/utils/lldb-scripts/mlirDataFormatters.py | 9 +- mlir/utils/mbr/mbr/__init__.py | 1 + mlir/utils/mbr/mbr/discovery.py | 19 +- mlir/utils/mbr/mbr/main.py | 18 +- mlir/utils/mbr/mbr/stats.py | 8 +- mlir/utils/spirv/gen_spirv_dialect.py | 2048 +++++------ 163 files changed, 17322 insertions(+), 16063 deletions(-) diff --git a/mlir/benchmark/python/benchmark_sparse.py b/mlir/benchmark/python/benchmark_sparse.py index 6d7a396..72b3ef1 100644 --- a/mlir/benchmark/python/benchmark_sparse.py +++ b/mlir/benchmark/python/benchmark_sparse.py @@ -25,7 +25,7 @@ from common import setup_passes def matmul_dsl( A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K), B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N), - C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True) + C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True), ): """Helper function for mlir sparse matrix multiplication benchmark.""" C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n] @@ -43,6 +43,7 @@ def benchmark_sparse_mlir_multiplication(): param2_type = ir.RankedTensorType.get([1500, 2000], f64) result_type = ir.RankedTensorType.get([1000, 2000], f64) with ir.InsertionPoint(module.body): + @func.FuncOp.from_py_func(param1_type, param2_type, result_type) def sparse_kernel(x, y, z): return matmul_dsl(x, y, outs=[z]) @@ -51,37 +52,34 @@ def benchmark_sparse_mlir_multiplication(): with ir.Context(), ir.Location.unknown(): kernel_func = get_kernel_func_from_module(module) timer_func = emit_timer_func() - wrapped_func = emit_benchmark_wrapped_main_func( - kernel_func, - timer_func - ) + wrapped_func = emit_benchmark_wrapped_main_func(kernel_func, timer_func) main_module_with_benchmark = ir.Module.parse( str(timer_func) + str(wrapped_func) + str(kernel_func) ) setup_passes(main_module_with_benchmark) c_runner_utils = os.getenv("MLIR_C_RUNNER_UTILS", "") - assert os.path.exists(c_runner_utils),\ - f"{c_runner_utils} does 
not exist." \ - f" Please pass a valid value for" \ + assert os.path.exists(c_runner_utils), ( + f"{c_runner_utils} does not exist." + f" Please pass a valid value for" f" MLIR_C_RUNNER_UTILS environment variable." + ) runner_utils = os.getenv("MLIR_RUNNER_UTILS", "") - assert os.path.exists(runner_utils),\ - f"{runner_utils} does not exist." \ - f" Please pass a valid value for MLIR_RUNNER_UTILS" \ + assert os.path.exists(runner_utils), ( + f"{runner_utils} does not exist." + f" Please pass a valid value for MLIR_RUNNER_UTILS" f" environment variable." + ) engine = ExecutionEngine( main_module_with_benchmark, 3, - shared_libs=[c_runner_utils, runner_utils] + shared_libs=[c_runner_utils, runner_utils], ) return engine.invoke def runner(engine_invoke): compiled_program_args = [] - for argument_type in [ - result_type, param1_type, param2_type, result_type - ]: + for argument_type in [result_type, param1_type, param2_type, result_type]: argument_type_str = str(argument_type) dimensions_str = re.sub("<|>|tensor", "", argument_type_str) dimensions = [int(dim) for dim in dimensions_str.split("x")[:-1]] @@ -111,6 +109,7 @@ def benchmark_np_matrix_multiplication(): benchmark, we don't have any `compiler` function returned. We just return the `runner` function. """ + def runner(): argument1 = np.random.uniform(low=0.0, high=100.0, size=(1000, 1500)) argument2 = np.random.uniform(low=0.0, high=100.0, size=(1500, 2000)) diff --git a/mlir/benchmark/python/common.py b/mlir/benchmark/python/common.py index 3634641..c605726 100644 --- a/mlir/benchmark/python/common.py +++ b/mlir/benchmark/python/common.py @@ -10,8 +10,7 @@ from mlir.passmanager import PassManager def setup_passes(mlir_module): - """Setup pass pipeline parameters for benchmark functions. - """ + """Setup pass pipeline parameters for benchmark functions.""" opt = ( "parallelization-strategy=none" " vectorization-strategy=none vl=1 enable-simd-index32=False" @@ -43,12 +42,15 @@ def get_kernel_func_from_module(module: ir.Module) -> func.FuncOp: This function only works for a module with one region, one block, and one operation. """ - assert len(module.operation.regions) == 1, \ - "Expected kernel module to have only one region" - assert len(module.operation.regions[0].blocks) == 1, \ - "Expected kernel module to have only one block" - assert len(module.operation.regions[0].blocks[0].operations) == 1, \ - "Expected kernel module to have only one operation" + assert ( + len(module.operation.regions) == 1 + ), "Expected kernel module to have only one region" + assert ( + len(module.operation.regions[0].blocks) == 1 + ), "Expected kernel module to have only one block" + assert ( + len(module.operation.regions[0].blocks[0].operations) == 1 + ), "Expected kernel module to have only one operation" return module.operation.regions[0].blocks[0].operations[0] @@ -57,8 +59,7 @@ def emit_timer_func() -> func.FuncOp: used, the `MLIR_RUNNER_UTILS` and `MLIR_C_RUNNER_UTILS` must be included. """ i64_type = ir.IntegerType.get_signless(64) - nanoTime = func.FuncOp( - "nanoTime", ([], [i64_type]), visibility="private") + nanoTime = func.FuncOp("nanoTime", ([], [i64_type]), visibility="private") nanoTime.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() return nanoTime @@ -76,9 +77,8 @@ def emit_benchmark_wrapped_main_func(kernel_func, timer_func): wrapped_func = func.FuncOp( # Same signature and an extra buffer of indices to save timings. 
"main", - (kernel_func.arguments.types + [memref_of_i64_type], - kernel_func.type.results), - visibility="public" + (kernel_func.arguments.types + [memref_of_i64_type], kernel_func.type.results), + visibility="public", ) wrapped_func.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() @@ -88,13 +88,13 @@ def emit_benchmark_wrapped_main_func(kernel_func, timer_func): zero = arith.ConstantOp.create_index(0) n_iterations = memref.DimOp(ir.IndexType.get(), timer_buffer, zero) one = arith.ConstantOp.create_index(1) - iter_args = list(wrapped_func.arguments[-num_results - 1:-1]) + iter_args = list(wrapped_func.arguments[-num_results - 1 : -1]) loop = scf.ForOp(zero, n_iterations, one, iter_args) with ir.InsertionPoint(loop.body): start = func.CallOp(timer_func, []) call = func.CallOp( kernel_func, - wrapped_func.arguments[:-num_results - 1] + loop.inner_iter_args + wrapped_func.arguments[: -num_results - 1] + loop.inner_iter_args, ) end = func.CallOp(timer_func, []) time_taken = arith.SubIOp(end, start) diff --git a/mlir/examples/standalone/test/CAPI/lit.local.cfg b/mlir/examples/standalone/test/CAPI/lit.local.cfg index f08a0de..bb0c17c 100644 --- a/mlir/examples/standalone/test/CAPI/lit.local.cfg +++ b/mlir/examples/standalone/test/CAPI/lit.local.cfg @@ -1 +1 @@ -config.suffixes.add('.c') +config.suffixes.add(".c") diff --git a/mlir/examples/standalone/test/lit.cfg.py b/mlir/examples/standalone/test/lit.cfg.py index 601ac8f..e27dddd 100644 --- a/mlir/examples/standalone/test/lit.cfg.py +++ b/mlir/examples/standalone/test/lit.cfg.py @@ -16,52 +16,55 @@ from lit.llvm.subst import FindTool # Configuration file for the 'lit' test runner. # name: The name of this test suite. -config.name = 'STANDALONE' +config.name = "STANDALONE" config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) # suffixes: A list of file extensions to treat as test files. -config.suffixes = ['.mlir'] +config.suffixes = [".mlir"] # test_source_root: The root path where tests are located. config.test_source_root = os.path.dirname(__file__) # test_exec_root: The root path where tests should be run. -config.test_exec_root = os.path.join(config.standalone_obj_root, 'test') +config.test_exec_root = os.path.join(config.standalone_obj_root, "test") -config.substitutions.append(('%PATH%', config.environment['PATH'])) -config.substitutions.append(('%shlibext', config.llvm_shlib_ext)) +config.substitutions.append(("%PATH%", config.environment["PATH"])) +config.substitutions.append(("%shlibext", config.llvm_shlib_ext)) -llvm_config.with_system_environment( - ['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP']) +llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"]) llvm_config.use_default_substitutions() # excludes: A list of directories to exclude from the testsuite. The 'Inputs' # subdirectories contain auxiliary inputs for various tests in their parent # directories. -config.excludes = ['Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt'] +config.excludes = ["Inputs", "Examples", "CMakeLists.txt", "README.txt", "LICENSE.txt"] # test_exec_root: The root path where tests should be run. 
-config.test_exec_root = os.path.join(config.standalone_obj_root, 'test') -config.standalone_tools_dir = os.path.join(config.standalone_obj_root, 'bin') -config.standalone_libs_dir = os.path.join(config.standalone_obj_root, 'lib') +config.test_exec_root = os.path.join(config.standalone_obj_root, "test") +config.standalone_tools_dir = os.path.join(config.standalone_obj_root, "bin") +config.standalone_libs_dir = os.path.join(config.standalone_obj_root, "lib") -config.substitutions.append(('%standalone_libs', config.standalone_libs_dir)) +config.substitutions.append(("%standalone_libs", config.standalone_libs_dir)) # Tweak the PATH to include the tools dir. -llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) +llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True) tool_dirs = [config.standalone_tools_dir, config.llvm_tools_dir] tools = [ - 'mlir-opt', - 'standalone-capi-test', - 'standalone-opt', - 'standalone-translate', + "mlir-opt", + "standalone-capi-test", + "standalone-opt", + "standalone-translate", ] llvm_config.add_tool_substitutions(tools, tool_dirs) -llvm_config.with_environment('PYTHONPATH', [ - os.path.join(config.mlir_obj_dir, 'python_packages', 'standalone'), -], append_path=True) +llvm_config.with_environment( + "PYTHONPATH", + [ + os.path.join(config.mlir_obj_dir, "python_packages", "standalone"), + ], + append_path=True, +) diff --git a/mlir/examples/standalone/test/python/lit.local.cfg b/mlir/examples/standalone/test/python/lit.local.cfg index b70b9d7..3394f18 100644 --- a/mlir/examples/standalone/test/python/lit.local.cfg +++ b/mlir/examples/standalone/test/python/lit.local.cfg @@ -1,4 +1,4 @@ -config.suffixes.add('.py') +config.suffixes.add(".py") if not config.enable_bindings_python: - config.unsupported = True + config.unsupported = True diff --git a/mlir/examples/standalone/test/python/smoketest.py b/mlir/examples/standalone/test/python/smoketest.py index 0d8f41c..08e08cb 100644 --- a/mlir/examples/standalone/test/python/smoketest.py +++ b/mlir/examples/standalone/test/python/smoketest.py @@ -1,17 +1,16 @@ # RUN: %python %s | FileCheck %s from mlir_standalone.ir import * -from mlir_standalone.dialects import ( - builtin as builtin_d, - standalone as standalone_d -) +from mlir_standalone.dialects import builtin as builtin_d, standalone as standalone_d with Context(): - standalone_d.register_dialect() - module = Module.parse(""" + standalone_d.register_dialect() + module = Module.parse( + """ %0 = arith.constant 2 : i32 %1 = standalone.foo %0 : i32 - """) - # CHECK: %[[C:.*]] = arith.constant 2 : i32 - # CHECK: standalone.foo %[[C]] : i32 - print(str(module)) + """ + ) + # CHECK: %[[C:.*]] = arith.constant 2 : i32 + # CHECK: standalone.foo %[[C]] : i32 + print(str(module)) diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index 7d3d1f6..03fcb10 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -10,26 +10,26 @@ _this_dir = os.path.dirname(__file__) def get_lib_dirs() -> Sequence[str]: - """Gets the lib directory for linking to shared libraries. + """Gets the lib directory for linking to shared libraries. - On some platforms, the package may need to be built specially to export - development libraries. - """ - return [_this_dir] + On some platforms, the package may need to be built specially to export + development libraries. 
+ """ + return [_this_dir] def get_include_dirs() -> Sequence[str]: - """Gets the include directory for compiling against exported C libraries. + """Gets the include directory for compiling against exported C libraries. - Depending on how the package was build, development C libraries may or may - not be present. - """ - return [os.path.join(_this_dir, "include")] + Depending on how the package was build, development C libraries may or may + not be present. + """ + return [os.path.join(_this_dir, "include")] # Perform Python level site initialization. This involves: # 1. Attempting to load initializer modules, specific to the distribution. -# 2. Defining the concrete mlir.ir.Context that does site specific +# 2. Defining the concrete mlir.ir.Context that does site specific # initialization. # # Aside from just being far more convenient to do this at the Python level, @@ -38,91 +38,106 @@ def get_include_dirs() -> Sequence[str]: # in the scope of the base class __init__). # # For #1, we: -# a. Probe for modules named '_mlirRegisterEverything' and -# '_site_initialize_{i}', where 'i' is a number starting at zero and +# a. Probe for modules named '_mlirRegisterEverything' and +# '_site_initialize_{i}', where 'i' is a number starting at zero and # proceeding so long as a module with the name is found. # b. If the module has a 'register_dialects' attribute, it will be called # immediately with a DialectRegistry to populate. # c. If the module has a 'context_init_hook', it will be added to a list -# of callbacks that are invoked as the last step of Context +# of callbacks that are invoked as the last step of Context # initialization (and passed the Context under construction). # # This facility allows downstreams to customize Context creation to their # needs. def _site_initialize(): - import importlib - import itertools - import logging - from ._mlir import ir - logger = logging.getLogger(__name__) - registry = ir.DialectRegistry() - post_init_hooks = [] - - def process_initializer_module(module_name): - try: - m = importlib.import_module(f".{module_name}", __name__) - except ModuleNotFoundError: - return False - except ImportError: - message = (f"Error importing mlir initializer {module_name}. This may " - "happen in unclean incremental builds but is likely a real bug if " - "encountered otherwise and the MLIR Python API may not function.") - logger.warning(message, exc_info=True) - - logger.debug("Initializing MLIR with module: %s", module_name) - if hasattr(m, "register_dialects"): - logger.debug("Registering dialects from initializer %r", m) - m.register_dialects(registry) - if hasattr(m, "context_init_hook"): - logger.debug("Adding context init hook from %r", m) - post_init_hooks.append(m.context_init_hook) - return True - - - # If _mlirRegisterEverything is built, then include it as an initializer - # module. - process_initializer_module("_mlirRegisterEverything") - - # Load all _site_initialize_{i} modules, where 'i' is a number starting - # at 0. - for i in itertools.count(): - module_name = f"_site_initialize_{i}" - if not process_initializer_module(module_name): - break - - class Context(ir._BaseContext): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.append_dialect_registry(registry) - for hook in post_init_hooks: - hook(self) - # TODO: There is some debate about whether we should eagerly load - # all dialects. It is being done here in order to preserve existing - # behavior. 
See: https://github.com/llvm/llvm-project/issues/56037 - self.load_all_available_dialects() - ir.Context = Context - - class MLIRError(Exception): - """ - An exception with diagnostic information. Has the following fields: - message: str - error_diagnostics: List[ir.DiagnosticInfo] - """ - def __init__(self, message, error_diagnostics): - self.message = message - self.error_diagnostics = error_diagnostics - super().__init__(message, error_diagnostics) - - def __str__(self): - s = self.message - if self.error_diagnostics: - s += ':' - for diag in self.error_diagnostics: - s += "\nerror: " + str(diag.location)[4:-1] + ": " + diag.message.replace('\n', '\n ') - for note in diag.notes: - s += "\n note: " + str(note.location)[4:-1] + ": " + note.message.replace('\n', '\n ') - return s - ir.MLIRError = MLIRError + import importlib + import itertools + import logging + from ._mlir import ir + + logger = logging.getLogger(__name__) + registry = ir.DialectRegistry() + post_init_hooks = [] + + def process_initializer_module(module_name): + try: + m = importlib.import_module(f".{module_name}", __name__) + except ModuleNotFoundError: + return False + except ImportError: + message = ( + f"Error importing mlir initializer {module_name}. This may " + "happen in unclean incremental builds but is likely a real bug if " + "encountered otherwise and the MLIR Python API may not function." + ) + logger.warning(message, exc_info=True) + + logger.debug("Initializing MLIR with module: %s", module_name) + if hasattr(m, "register_dialects"): + logger.debug("Registering dialects from initializer %r", m) + m.register_dialects(registry) + if hasattr(m, "context_init_hook"): + logger.debug("Adding context init hook from %r", m) + post_init_hooks.append(m.context_init_hook) + return True + + # If _mlirRegisterEverything is built, then include it as an initializer + # module. + process_initializer_module("_mlirRegisterEverything") + + # Load all _site_initialize_{i} modules, where 'i' is a number starting + # at 0. + for i in itertools.count(): + module_name = f"_site_initialize_{i}" + if not process_initializer_module(module_name): + break + + class Context(ir._BaseContext): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.append_dialect_registry(registry) + for hook in post_init_hooks: + hook(self) + # TODO: There is some debate about whether we should eagerly load + # all dialects. It is being done here in order to preserve existing + # behavior. See: https://github.com/llvm/llvm-project/issues/56037 + self.load_all_available_dialects() + + ir.Context = Context + + class MLIRError(Exception): + """ + An exception with diagnostic information. 
Has the following fields: + message: str + error_diagnostics: List[ir.DiagnosticInfo] + """ + + def __init__(self, message, error_diagnostics): + self.message = message + self.error_diagnostics = error_diagnostics + super().__init__(message, error_diagnostics) + + def __str__(self): + s = self.message + if self.error_diagnostics: + s += ":" + for diag in self.error_diagnostics: + s += ( + "\nerror: " + + str(diag.location)[4:-1] + + ": " + + diag.message.replace("\n", "\n ") + ) + for note in diag.notes: + s += ( + "\n note: " + + str(note.location)[4:-1] + + ": " + + note.message.replace("\n", "\n ") + ) + return s + + ir.MLIRError = MLIRError _site_initialize() diff --git a/mlir/python/mlir/dialects/_arith_ops_ext.py b/mlir/python/mlir/dialects/_arith_ops_ext.py index 2408593..df38f87 100644 --- a/mlir/python/mlir/dialects/_arith_ops_ext.py +++ b/mlir/python/mlir/dialects/_arith_ops_ext.py @@ -3,72 +3,67 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from ..ir import * - from ._ods_common import get_default_loc_context as _get_default_loc_context + from ..ir import * + from ._ods_common import get_default_loc_context as _get_default_loc_context - from typing import Any, List, Union + from typing import Any, List, Union except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e def _isa(obj: Any, cls: type): - try: - cls(obj) - except ValueError: - return False - return True + try: + cls(obj) + except ValueError: + return False + return True def _is_any_of(obj: Any, classes: List[type]): - return any(_isa(obj, cls) for cls in classes) + return any(_isa(obj, cls) for cls in classes) def _is_integer_like_type(type: Type): - return _is_any_of(type, [IntegerType, IndexType]) + return _is_any_of(type, [IntegerType, IndexType]) def _is_float_type(type: Type): - return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type]) + return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type]) class ConstantOp: - """Specialization for the constant op class.""" - - def __init__(self, - result: Type, - value: Union[int, float, Attribute], - *, - loc=None, - ip=None): - if isinstance(value, int): - super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip) - elif isinstance(value, float): - super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip) - else: - super().__init__(value, loc=loc, ip=ip) - - @classmethod - def create_index(cls, value: int, *, loc=None, ip=None): - """Create an index-typed constant.""" - return cls( - IndexType.get(context=_get_default_loc_context(loc)), - value, - loc=loc, - ip=ip) - - @property - def type(self): - return self.results[0].type - - @property - def value(self): - return Attribute(self.operation.attributes["value"]) - - @property - def literal_value(self) -> Union[int, float]: - if _is_integer_like_type(self.type): - return IntegerAttr(self.value).value - elif _is_float_type(self.type): - return FloatAttr(self.value).value - else: - raise ValueError("only integer and float constants have literal values") + """Specialization for the constant op class.""" + + def __init__( + self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None + ): + if isinstance(value, int): + super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip) + elif isinstance(value, float): + super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip) + else: + super().__init__(value, loc=loc, ip=ip) + + @classmethod + def 
create_index(cls, value: int, *, loc=None, ip=None): + """Create an index-typed constant.""" + return cls( + IndexType.get(context=_get_default_loc_context(loc)), value, loc=loc, ip=ip + ) + + @property + def type(self): + return self.results[0].type + + @property + def value(self): + return Attribute(self.operation.attributes["value"]) + + @property + def literal_value(self) -> Union[int, float]: + if _is_integer_like_type(self.type): + return IntegerAttr(self.value).value + elif _is_float_type(self.type): + return FloatAttr(self.value).value + else: + raise ValueError("only integer and float constants have literal values") diff --git a/mlir/python/mlir/dialects/_bufferization_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_ops_ext.py index 6ed35f4..1066cb4 100644 --- a/mlir/python/mlir/dialects/_bufferization_ops_ext.py +++ b/mlir/python/mlir/dialects/_bufferization_ops_ext.py @@ -3,36 +3,39 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from typing import Sequence, Union - from ..ir import * - from ._ods_common import get_default_loc_context + from typing import Sequence, Union + from ..ir import * + from ._ods_common import get_default_loc_context - from typing import Any, List, Union + from typing import Any, List, Union except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e class AllocTensorOp: - """Extends the bufferization.alloc_tensor op.""" + """Extends the bufferization.alloc_tensor op.""" - def __init__(self, - tensor_type: Type, - dynamic_sizes: Sequence[Value], - copy: Value, - size_hint: Value, - escape: BoolAttr, - *, - loc=None, - ip=None): - """Constructs an `alloc_tensor` with static and/or dynamic sizes.""" - context = get_default_loc_context(loc) - attributes = {} - if escape: - attributes["escape"] = escape - op = self.build_generic( - results=[tensor_type], - operands=[dynamic_sizes, copy, size_hint], - attributes=attributes, - loc=loc, - ip=ip) - OpView.__init__(self, op) + def __init__( + self, + tensor_type: Type, + dynamic_sizes: Sequence[Value], + copy: Value, + size_hint: Value, + escape: BoolAttr, + *, + loc=None, + ip=None + ): + """Constructs an `alloc_tensor` with static and/or dynamic sizes.""" + context = get_default_loc_context(loc) + attributes = {} + if escape: + attributes["escape"] = escape + op = self.build_generic( + results=[tensor_type], + operands=[dynamic_sizes, copy, size_hint], + attributes=attributes, + loc=loc, + ip=ip, + ) + OpView.__init__(self, op) diff --git a/mlir/python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/_builtin_ops_ext.py index b69163f..27a6012 100644 --- a/mlir/python/mlir/dialects/_builtin_ops_ext.py +++ b/mlir/python/mlir/dialects/_builtin_ops_ext.py @@ -3,18 +3,18 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from ..ir import * + from ..ir import * except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e + class ModuleOp: - """Specialization for the module op class.""" + """Specialization for the module op class.""" - def __init__(self, *, loc=None, ip=None): - super().__init__(self.build_generic(results=[], operands=[], loc=loc, - ip=ip)) - body = self.regions[0].blocks.append() + def __init__(self, *, loc=None, ip=None): + super().__init__(self.build_generic(results=[], operands=[], loc=loc, ip=ip)) + body = 
self.regions[0].blocks.append() - @property - def body(self): - return self.regions[0].blocks[0] + @property + def body(self): + return self.regions[0].blocks[0] diff --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py index 56df423..6d264c3 100644 --- a/mlir/python/mlir/dialects/_func_ops_ext.py +++ b/mlir/python/mlir/dialects/_func_ops_ext.py @@ -3,298 +3,317 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from ..ir import * - from ._ods_common import get_default_loc_context as _get_default_loc_context + from ..ir import * + from ._ods_common import get_default_loc_context as _get_default_loc_context - import inspect + import inspect - from typing import Any, List, Optional, Sequence, Union + from typing import Any, List, Optional, Sequence, Union except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e ARGUMENT_ATTRIBUTE_NAME = "arg_attrs" RESULT_ATTRIBUTE_NAME = "res_attrs" + class ConstantOp: - """Specialization for the constant op class.""" + """Specialization for the constant op class.""" - def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None): - super().__init__(result, value, loc=loc, ip=ip) + def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None): + super().__init__(result, value, loc=loc, ip=ip) - @property - def type(self): - return self.results[0].type + @property + def type(self): + return self.results[0].type class FuncOp: - """Specialization for the func op class.""" - - def __init__(self, - name, - type, - *, - visibility=None, - body_builder=None, - loc=None, - ip=None): - """ - Create a FuncOp with the provided `name`, `type`, and `visibility`. - - `name` is a string representing the function name. - - `type` is either a FunctionType or a pair of list describing inputs and - results. - - `visibility` is a string matching `public`, `private`, or `nested`. None - implies private visibility. - - `body_builder` is an optional callback, when provided a new entry block - is created and the callback is invoked with the new op as argument within - an InsertionPoint context already set for the block. The callback is - expected to insert a terminator in the block. - """ - sym_name = StringAttr.get(str(name)) - - # If the type is passed as a tuple, build a FunctionType on the fly. 
- if isinstance(type, tuple): - type = FunctionType.get(inputs=type[0], results=type[1]) - - type = TypeAttr.get(type) - sym_visibility = StringAttr.get( - str(visibility)) if visibility is not None else None - super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip) - if body_builder: - entry_block = self.add_entry_block() - with InsertionPoint(entry_block): - body_builder(self) - - @property - def is_external(self): - return len(self.regions[0].blocks) == 0 - - @property - def body(self): - return self.regions[0] - - @property - def type(self): - return FunctionType(TypeAttr(self.attributes["function_type"]).value) - - @property - def visibility(self): - return self.attributes["sym_visibility"] - - @property - def name(self) -> StringAttr: - return StringAttr(self.attributes["sym_name"]) - - @property - def entry_block(self): - if self.is_external: - raise IndexError('External function does not have a body') - return self.regions[0].blocks[0] - - def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None): - """ - Add an entry block to the function body using the function signature to - infer block arguments. - Returns the newly created block - """ - if not self.is_external: - raise IndexError('The function already has an entry block!') - self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs) - return self.body.blocks[0] - - @property - def arg_attrs(self): - return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) - - @arg_attrs.setter - def arg_attrs(self, attribute: Union[ArrayAttr, list]): - if isinstance(attribute, ArrayAttr): - self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute - else: - self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( - attribute, context=self.context) - - @property - def arguments(self): - return self.entry_block.arguments - - @property - def result_attrs(self): - return self.attributes[RESULT_ATTRIBUTE_NAME] - - @result_attrs.setter - def result_attrs(self, attribute: ArrayAttr): - self.attributes[RESULT_ATTRIBUTE_NAME] = attribute - - @classmethod - def from_py_func(FuncOp, - *inputs: Type, - results: Optional[Sequence[Type]] = None, - name: Optional[str] = None): - """Decorator to define an MLIR FuncOp specified as a python function. - - Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are - active for the current thread (i.e. established in a `with` block). - - When applied as a decorator to a Python function, an entry block will - be constructed for the FuncOp with types as specified in `*inputs`. The - block arguments will be passed positionally to the Python function. In - addition, if the Python function accepts keyword arguments generally or - has a corresponding keyword argument, the following will be passed: - * `func_op`: The `func` op being defined. - - By default, the function name will be the Python function `__name__`. This - can be overriden by passing the `name` argument to the decorator. - - If `results` is not specified, then the decorator will implicitly - insert a `ReturnOp` with the `Value`'s returned from the decorated - function. It will also set the `FuncOp` type with the actual return - value types. If `results` is specified, then the decorated function - must return `None` and no implicit `ReturnOp` is added (nor are the result - types updated). The implicit behavior is intended for simple, single-block - cases, and users should specify result types explicitly for any complicated - cases. 
- - The decorated function can further be called from Python and will insert - a `CallOp` at the then-current insertion point, returning either None ( - if no return values), a unary Value (for one result), or a list of Values). - This mechanism cannot be used to emit recursive calls (by construction). - """ - - def decorator(f): - from . import func - # Introspect the callable for optional features. - sig = inspect.signature(f) - has_arg_func_op = False - for param in sig.parameters.values(): - if param.kind == param.VAR_KEYWORD: - has_arg_func_op = True - if param.name == "func_op" and (param.kind - == param.POSITIONAL_OR_KEYWORD or - param.kind == param.KEYWORD_ONLY): - has_arg_func_op = True - - # Emit the FuncOp. - implicit_return = results is None - symbol_name = name or f.__name__ - function_type = FunctionType.get( - inputs=inputs, results=[] if implicit_return else results) - func_op = FuncOp(name=symbol_name, type=function_type) - with InsertionPoint(func_op.add_entry_block()): - func_args = func_op.entry_block.arguments - func_kwargs = {} - if has_arg_func_op: - func_kwargs["func_op"] = func_op - return_values = f(*func_args, **func_kwargs) - if not implicit_return: - return_types = list(results) - assert return_values is None, ( - "Capturing a python function with explicit `results=` " - "requires that the wrapped function returns None.") - else: - # Coerce return values, add ReturnOp and rewrite func type. - if return_values is None: - return_values = [] - elif isinstance(return_values, tuple): - return_values = list(return_values) - elif isinstance(return_values, Value): - # Returning a single value is fine, coerce it into a list. - return_values = [return_values] - elif isinstance(return_values, OpView): - # Returning a single operation is fine, coerce its results a list. - return_values = return_values.operation.results - elif isinstance(return_values, Operation): - # Returning a single operation is fine, coerce its results a list. - return_values = return_values.results - else: - return_values = list(return_values) - func.ReturnOp(return_values) - # Recompute the function type. - return_types = [v.type for v in return_values] - function_type = FunctionType.get(inputs=inputs, results=return_types) - func_op.attributes["function_type"] = TypeAttr.get(function_type) - - def emit_call_op(*call_args): - call_op = func.CallOp(return_types, FlatSymbolRefAttr.get(symbol_name), - call_args) - if return_types is None: - return None - elif len(return_types) == 1: - return call_op.result + """Specialization for the func op class.""" + + def __init__( + self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None + ): + """ + Create a FuncOp with the provided `name`, `type`, and `visibility`. + - `name` is a string representing the function name. + - `type` is either a FunctionType or a pair of list describing inputs and + results. + - `visibility` is a string matching `public`, `private`, or `nested`. None + implies private visibility. + - `body_builder` is an optional callback, when provided a new entry block + is created and the callback is invoked with the new op as argument within + an InsertionPoint context already set for the block. The callback is + expected to insert a terminator in the block. + """ + sym_name = StringAttr.get(str(name)) + + # If the type is passed as a tuple, build a FunctionType on the fly. 
+ if isinstance(type, tuple): + type = FunctionType.get(inputs=type[0], results=type[1]) + + type = TypeAttr.get(type) + sym_visibility = ( + StringAttr.get(str(visibility)) if visibility is not None else None + ) + super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip) + if body_builder: + entry_block = self.add_entry_block() + with InsertionPoint(entry_block): + body_builder(self) + + @property + def is_external(self): + return len(self.regions[0].blocks) == 0 + + @property + def body(self): + return self.regions[0] + + @property + def type(self): + return FunctionType(TypeAttr(self.attributes["function_type"]).value) + + @property + def visibility(self): + return self.attributes["sym_visibility"] + + @property + def name(self) -> StringAttr: + return StringAttr(self.attributes["sym_name"]) + + @property + def entry_block(self): + if self.is_external: + raise IndexError("External function does not have a body") + return self.regions[0].blocks[0] + + def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None): + """ + Add an entry block to the function body using the function signature to + infer block arguments. + Returns the newly created block + """ + if not self.is_external: + raise IndexError("The function already has an entry block!") + self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs) + return self.body.blocks[0] + + @property + def arg_attrs(self): + return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) + + @arg_attrs.setter + def arg_attrs(self, attribute: Union[ArrayAttr, list]): + if isinstance(attribute, ArrayAttr): + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute else: - return call_op.results - - wrapped = emit_call_op - wrapped.__name__ = f.__name__ - wrapped.func_op = func_op - return wrapped + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( + attribute, context=self.context + ) + + @property + def arguments(self): + return self.entry_block.arguments + + @property + def result_attrs(self): + return self.attributes[RESULT_ATTRIBUTE_NAME] + + @result_attrs.setter + def result_attrs(self, attribute: ArrayAttr): + self.attributes[RESULT_ATTRIBUTE_NAME] = attribute + + @classmethod + def from_py_func( + FuncOp, + *inputs: Type, + results: Optional[Sequence[Type]] = None, + name: Optional[str] = None, + ): + """Decorator to define an MLIR FuncOp specified as a python function. + + Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are + active for the current thread (i.e. established in a `with` block). + + When applied as a decorator to a Python function, an entry block will + be constructed for the FuncOp with types as specified in `*inputs`. The + block arguments will be passed positionally to the Python function. In + addition, if the Python function accepts keyword arguments generally or + has a corresponding keyword argument, the following will be passed: + * `func_op`: The `func` op being defined. + + By default, the function name will be the Python function `__name__`. This + can be overriden by passing the `name` argument to the decorator. + + If `results` is not specified, then the decorator will implicitly + insert a `ReturnOp` with the `Value`'s returned from the decorated + function. It will also set the `FuncOp` type with the actual return + value types. If `results` is specified, then the decorated function + must return `None` and no implicit `ReturnOp` is added (nor are the result + types updated). 
The implicit behavior is intended for simple, single-block + cases, and users should specify result types explicitly for any complicated + cases. + + The decorated function can further be called from Python and will insert + a `CallOp` at the then-current insertion point, returning either None ( + if no return values), a unary Value (for one result), or a list of Values). + This mechanism cannot be used to emit recursive calls (by construction). + """ + + def decorator(f): + from . import func + + # Introspect the callable for optional features. + sig = inspect.signature(f) + has_arg_func_op = False + for param in sig.parameters.values(): + if param.kind == param.VAR_KEYWORD: + has_arg_func_op = True + if param.name == "func_op" and ( + param.kind == param.POSITIONAL_OR_KEYWORD + or param.kind == param.KEYWORD_ONLY + ): + has_arg_func_op = True + + # Emit the FuncOp. + implicit_return = results is None + symbol_name = name or f.__name__ + function_type = FunctionType.get( + inputs=inputs, results=[] if implicit_return else results + ) + func_op = FuncOp(name=symbol_name, type=function_type) + with InsertionPoint(func_op.add_entry_block()): + func_args = func_op.entry_block.arguments + func_kwargs = {} + if has_arg_func_op: + func_kwargs["func_op"] = func_op + return_values = f(*func_args, **func_kwargs) + if not implicit_return: + return_types = list(results) + assert return_values is None, ( + "Capturing a python function with explicit `results=` " + "requires that the wrapped function returns None." + ) + else: + # Coerce return values, add ReturnOp and rewrite func type. + if return_values is None: + return_values = [] + elif isinstance(return_values, tuple): + return_values = list(return_values) + elif isinstance(return_values, Value): + # Returning a single value is fine, coerce it into a list. + return_values = [return_values] + elif isinstance(return_values, OpView): + # Returning a single operation is fine, coerce its results a list. + return_values = return_values.operation.results + elif isinstance(return_values, Operation): + # Returning a single operation is fine, coerce its results a list. + return_values = return_values.results + else: + return_values = list(return_values) + func.ReturnOp(return_values) + # Recompute the function type. + return_types = [v.type for v in return_values] + function_type = FunctionType.get( + inputs=inputs, results=return_types + ) + func_op.attributes["function_type"] = TypeAttr.get(function_type) + + def emit_call_op(*call_args): + call_op = func.CallOp( + return_types, FlatSymbolRefAttr.get(symbol_name), call_args + ) + if return_types is None: + return None + elif len(return_types) == 1: + return call_op.result + else: + return call_op.results + + wrapped = emit_call_op + wrapped.__name__ = f.__name__ + wrapped.func_op = func_op + return wrapped + + return decorator - return decorator class CallOp: - """Specialization for the call op class.""" - - def __init__(self, - calleeOrResults: Union[FuncOp, List[Type]], - argumentsOrCallee: Union[List, FlatSymbolRefAttr, str], - arguments: Optional[List] = None, - *, - loc=None, - ip=None): - """Creates an call operation. - - The constructor accepts three different forms: - - 1. A function op to be called followed by a list of arguments. - 2. A list of result types, followed by the name of the function to be - called as string, following by a list of arguments. - 3. 
A list of result types, followed by the name of the function to be - called as symbol reference attribute, followed by a list of arguments. - - For example - - f = func.FuncOp("foo", ...) - func.CallOp(f, [args]) - func.CallOp([result_types], "foo", [args]) - - In all cases, the location and insertion point may be specified as keyword - arguments if not provided by the surrounding context managers. - """ - - # TODO: consider supporting constructor "overloads", e.g., through a custom - # or pybind-provided metaclass. - if isinstance(calleeOrResults, FuncOp): - if not isinstance(argumentsOrCallee, list): - raise ValueError( - "when constructing a call to a function, expected " + - "the second argument to be a list of call arguments, " + - f"got {type(argumentsOrCallee)}") - if arguments is not None: - raise ValueError("unexpected third argument when constructing a call" + - "to a function") - - super().__init__( - calleeOrResults.type.results, - FlatSymbolRefAttr.get( - calleeOrResults.name.value, - context=_get_default_loc_context(loc)), - argumentsOrCallee, - loc=loc, - ip=ip) - return - - if isinstance(argumentsOrCallee, list): - raise ValueError("when constructing a call to a function by name, " + - "expected the second argument to be a string or a " + - f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}") - - if isinstance(argumentsOrCallee, FlatSymbolRefAttr): - super().__init__( - calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip) - elif isinstance(argumentsOrCallee, str): - super().__init__( - calleeOrResults, - FlatSymbolRefAttr.get( - argumentsOrCallee, context=_get_default_loc_context(loc)), - arguments, - loc=loc, - ip=ip) + """Specialization for the call op class.""" + + def __init__( + self, + calleeOrResults: Union[FuncOp, List[Type]], + argumentsOrCallee: Union[List, FlatSymbolRefAttr, str], + arguments: Optional[List] = None, + *, + loc=None, + ip=None, + ): + """Creates an call operation. + + The constructor accepts three different forms: + + 1. A function op to be called followed by a list of arguments. + 2. A list of result types, followed by the name of the function to be + called as string, following by a list of arguments. + 3. A list of result types, followed by the name of the function to be + called as symbol reference attribute, followed by a list of arguments. + + For example + + f = func.FuncOp("foo", ...) + func.CallOp(f, [args]) + func.CallOp([result_types], "foo", [args]) + + In all cases, the location and insertion point may be specified as keyword + arguments if not provided by the surrounding context managers. + """ + + # TODO: consider supporting constructor "overloads", e.g., through a custom + # or pybind-provided metaclass. 
+ if isinstance(calleeOrResults, FuncOp): + if not isinstance(argumentsOrCallee, list): + raise ValueError( + "when constructing a call to a function, expected " + + "the second argument to be a list of call arguments, " + + f"got {type(argumentsOrCallee)}" + ) + if arguments is not None: + raise ValueError( + "unexpected third argument when constructing a call" + + "to a function" + ) + + super().__init__( + calleeOrResults.type.results, + FlatSymbolRefAttr.get( + calleeOrResults.name.value, context=_get_default_loc_context(loc) + ), + argumentsOrCallee, + loc=loc, + ip=ip, + ) + return + + if isinstance(argumentsOrCallee, list): + raise ValueError( + "when constructing a call to a function by name, " + + "expected the second argument to be a string or a " + + f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}" + ) + + if isinstance(argumentsOrCallee, FlatSymbolRefAttr): + super().__init__( + calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip + ) + elif isinstance(argumentsOrCallee, str): + super().__init__( + calleeOrResults, + FlatSymbolRefAttr.get( + argumentsOrCallee, context=_get_default_loc_context(loc) + ), + arguments, + loc=loc, + ip=ip, + ) diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py index eb9e969..3f6d854 100644 --- a/mlir/python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py @@ -3,39 +3,45 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from typing import Optional, Sequence, Union - from ..ir import * - from ._ods_common import get_default_loc_context - from .._mlir_libs._mlirDialectsLinalg import fill_builtin_region + from typing import Optional, Sequence, Union + from ..ir import * + from ._ods_common import get_default_loc_context + from .._mlir_libs._mlirDialectsLinalg import fill_builtin_region except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e from ._ods_common import get_op_result_or_value as _get_op_result_or_value + def isa(cls: Type, ty: Type): - try: - cls(ty) - return True - except ValueError: - return False + try: + cls(ty) + return True + except ValueError: + return False class StructuredOpMixin: - """All structured ops use the same mixin class.""" + """All structured ops use the same mixin class.""" - def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None): - super().__init__( - self.build_generic(results=list(results), - operands=[list(inputs), list(outputs)], - loc=loc, - ip=ip)) + def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None): + super().__init__( + self.build_generic( + results=list(results), + operands=[list(inputs), list(outputs)], + loc=loc, + ip=ip, + ) + ) def select_opview_mixin(parent_opview_cls): - # TODO: This shouldn't be a heuristic: we should have a way to annotate - # the OpView to note that it is a structured op. - if ("__init__" not in parent_opview_cls.__dict__ and - hasattr(parent_opview_cls, "inputs") and - hasattr(parent_opview_cls, "outputs") and - hasattr(parent_opview_cls, "result_tensors")): - return StructuredOpMixin + # TODO: This shouldn't be a heuristic: we should have a way to annotate + # the OpView to note that it is a structured op. 
+ if ( + "__init__" not in parent_opview_cls.__dict__ + and hasattr(parent_opview_cls, "inputs") + and hasattr(parent_opview_cls, "outputs") + and hasattr(parent_opview_cls, "result_tensors") + ): + return StructuredOpMixin diff --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py index 10079d3..3536d45 100644 --- a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py @@ -3,125 +3,130 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from ..ir import * - from ._ods_common import get_op_result_or_value as _get_op_result_or_value + from ..ir import * + from ._ods_common import get_op_result_or_value as _get_op_result_or_value except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e from typing import Optional, Union class GetParentForOp: - """Extension for GetParentForOp.""" - - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - num_loops: Optional[int] = None, - ip=None, - loc=None, - ): - if num_loops is None: - num_loops = 1 - super().__init__( - result_type, - _get_op_result_or_value(target), - num_loops=num_loops, - ip=ip, - loc=loc, - ) + """Extension for GetParentForOp.""" + + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + num_loops: Optional[int] = None, + ip=None, + loc=None, + ): + if num_loops is None: + num_loops = 1 + super().__init__( + result_type, + _get_op_result_or_value(target), + num_loops=num_loops, + ip=ip, + loc=loc, + ) class LoopOutlineOp: - """Extension for LoopOutlineOp.""" - - def __init__( - self, - function_type: Type, - call_type: Type, - target: Union[Operation, Value], - *, - func_name: Union[str, StringAttr], - ip=None, - loc=None, - ): - super().__init__( - function_type, - call_type, - _get_op_result_or_value(target), - func_name=(func_name if isinstance(func_name, StringAttr) else - StringAttr.get(func_name)), - ip=ip, - loc=loc, - ) + """Extension for LoopOutlineOp.""" + + def __init__( + self, + function_type: Type, + call_type: Type, + target: Union[Operation, Value], + *, + func_name: Union[str, StringAttr], + ip=None, + loc=None, + ): + super().__init__( + function_type, + call_type, + _get_op_result_or_value(target), + func_name=( + func_name + if isinstance(func_name, StringAttr) + else StringAttr.get(func_name) + ), + ip=ip, + loc=loc, + ) class LoopPeelOp: - """Extension for LoopPeelOp.""" - - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - fail_if_already_divisible: Union[bool, BoolAttr] = False, - ip=None, - loc=None, - ): - super().__init__( - result_type, - _get_op_result_or_value(target), - fail_if_already_divisible=(fail_if_already_divisible if isinstance( - fail_if_already_divisible, BoolAttr) else - BoolAttr.get(fail_if_already_divisible)), - ip=ip, - loc=loc, - ) + """Extension for LoopPeelOp.""" + + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + fail_if_already_divisible: Union[bool, BoolAttr] = False, + ip=None, + loc=None, + ): + super().__init__( + result_type, + _get_op_result_or_value(target), + fail_if_already_divisible=( + fail_if_already_divisible + if isinstance(fail_if_already_divisible, BoolAttr) + else BoolAttr.get(fail_if_already_divisible) + ), + ip=ip, + loc=loc, + ) class LoopPipelineOp: - """Extension for 
LoopPipelineOp.""" - - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - iteration_interval: Optional[Union[int, IntegerAttr]] = None, - read_latency: Optional[Union[int, IntegerAttr]] = None, - ip=None, - loc=None, - ): - if iteration_interval is None: - iteration_interval = 1 - if read_latency is None: - read_latency = 10 - super().__init__( - result_type, - _get_op_result_or_value(target), - iteration_interval=iteration_interval, - read_latency=read_latency, - ip=ip, - loc=loc, - ) + """Extension for LoopPipelineOp.""" + + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + iteration_interval: Optional[Union[int, IntegerAttr]] = None, + read_latency: Optional[Union[int, IntegerAttr]] = None, + ip=None, + loc=None, + ): + if iteration_interval is None: + iteration_interval = 1 + if read_latency is None: + read_latency = 10 + super().__init__( + result_type, + _get_op_result_or_value(target), + iteration_interval=iteration_interval, + read_latency=read_latency, + ip=ip, + loc=loc, + ) class LoopUnrollOp: - """Extension for LoopUnrollOp.""" - - def __init__( - self, - target: Union[Operation, Value], - *, - factor: Union[int, IntegerAttr], - ip=None, - loc=None, - ): - super().__init__( - _get_op_result_or_value(target), - factor=factor, - ip=ip, - loc=loc, - ) + """Extension for LoopUnrollOp.""" + + def __init__( + self, + target: Union[Operation, Value], + *, + factor: Union[int, IntegerAttr], + ip=None, + loc=None, + ): + super().__init__( + _get_op_result_or_value(target), + factor=factor, + ip=ip, + loc=loc, + ) diff --git a/mlir/python/mlir/dialects/_memref_ops_ext.py b/mlir/python/mlir/dialects/_memref_ops_ext.py index a00a087..825f1a0 100644 --- a/mlir/python/mlir/dialects/_memref_ops_ext.py +++ b/mlir/python/mlir/dialects/_memref_ops_ext.py @@ -3,34 +3,34 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from ..ir import * - from ._ods_common import get_op_result_or_value as _get_op_result_or_value - from ._ods_common import get_op_results_or_values as _get_op_results_or_values + from ..ir import * + from ._ods_common import get_op_result_or_value as _get_op_result_or_value + from ._ods_common import get_op_results_or_values as _get_op_results_or_values except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e from typing import Optional, Sequence, Union class LoadOp: - """Specialization for the MemRef load operation.""" + """Specialization for the MemRef load operation.""" - def __init__(self, - memref: Union[Operation, OpView, Value], - indices: Optional[Union[Operation, OpView, - Sequence[Value]]] = None, - *, - loc=None, - ip=None): - """Creates a memref load operation. + def __init__( + self, + memref: Union[Operation, OpView, Value], + indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None, + *, + loc=None, + ip=None + ): + """Creates a memref load operation. - Args: - memref: the buffer to load from. - indices: the list of subscripts, may be empty for zero-dimensional - buffers. - loc: user-visible location of the operation. - ip: insertion point. - """ - indices_resolved = [] if indices is None else _get_op_results_or_values( - indices) - super().__init__(memref, indices_resolved, loc=loc, ip=ip) + Args: + memref: the buffer to load from. + indices: the list of subscripts, may be empty for zero-dimensional + buffers. 
+ loc: user-visible location of the operation. + ip: insertion point. + """ + indices_resolved = [] if indices is None else _get_op_results_or_values(indices) + super().__init__(memref, indices_resolved, loc=loc, ip=ip) diff --git a/mlir/python/mlir/dialects/_ml_program_ops_ext.py b/mlir/python/mlir/dialects/_ml_program_ops_ext.py index 8db82cf..c84d23c 100644 --- a/mlir/python/mlir/dialects/_ml_program_ops_ext.py +++ b/mlir/python/mlir/dialects/_ml_program_ops_ext.py @@ -3,11 +3,11 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from typing import Union - from ..ir import * - from ._ods_common import get_default_loc_context as _get_default_loc_context + from typing import Union + from ..ir import * + from ._ods_common import get_default_loc_context as _get_default_loc_context except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e from ._ml_program_ops_gen import * @@ -17,100 +17,97 @@ RESULT_ATTRIBUTE_NAME = "res_attrs" class FuncOp: - """Specialization for the func op class.""" - - def __init__(self, - name, - type, - *, - visibility=None, - body_builder=None, - loc=None, - ip=None): - """ - Create a FuncOp with the provided `name`, `type`, and `visibility`. - - `name` is a string representing the function name. - - `type` is either a FunctionType or a pair of list describing inputs and - results. - - `visibility` is a string matching `public`, `private`, or `nested`. None - implies private visibility. - - `body_builder` is an optional callback, when provided a new entry block - is created and the callback is invoked with the new op as argument within - an InsertionPoint context already set for the block. The callback is - expected to insert a terminator in the block. - """ - sym_name = StringAttr.get(str(name)) - - # If the type is passed as a tuple, build a FunctionType on the fly. - if isinstance(type, tuple): - type = FunctionType.get(inputs=type[0], results=type[1]) - - type = TypeAttr.get(type) - sym_visibility = StringAttr.get( - str(visibility)) if visibility is not None else None - super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip) - if body_builder: - entry_block = self.add_entry_block() - with InsertionPoint(entry_block): - body_builder(self) - - @property - def is_external(self): - return len(self.regions[0].blocks) == 0 - - @property - def body(self): - return self.regions[0] - - @property - def type(self): - return FunctionType(TypeAttr(self.attributes["function_type"]).value) - - @property - def visibility(self): - return self.attributes["sym_visibility"] - - @property - def name(self) -> StringAttr: - return StringAttr(self.attributes["sym_name"]) - - @property - def entry_block(self): - if self.is_external: - raise IndexError('External function does not have a body') - return self.regions[0].blocks[0] - - def add_entry_block(self): - """ - Add an entry block to the function body using the function signature to - infer block arguments. 
- Returns the newly created block - """ - if not self.is_external: - raise IndexError('The function already has an entry block!') - self.body.blocks.append(*self.type.inputs) - return self.body.blocks[0] - - @property - def arg_attrs(self): - return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) - - @arg_attrs.setter - def arg_attrs(self, attribute: Union[ArrayAttr, list]): - if isinstance(attribute, ArrayAttr): - self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute - else: - self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( - attribute, context=self.context) - - @property - def arguments(self): - return self.entry_block.arguments - - @property - def result_attrs(self): - return self.attributes[RESULT_ATTRIBUTE_NAME] - - @result_attrs.setter - def result_attrs(self, attribute: ArrayAttr): - self.attributes[RESULT_ATTRIBUTE_NAME] = attribute + """Specialization for the func op class.""" + + def __init__( + self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None + ): + """ + Create a FuncOp with the provided `name`, `type`, and `visibility`. + - `name` is a string representing the function name. + - `type` is either a FunctionType or a pair of list describing inputs and + results. + - `visibility` is a string matching `public`, `private`, or `nested`. None + implies private visibility. + - `body_builder` is an optional callback, when provided a new entry block + is created and the callback is invoked with the new op as argument within + an InsertionPoint context already set for the block. The callback is + expected to insert a terminator in the block. + """ + sym_name = StringAttr.get(str(name)) + + # If the type is passed as a tuple, build a FunctionType on the fly. + if isinstance(type, tuple): + type = FunctionType.get(inputs=type[0], results=type[1]) + + type = TypeAttr.get(type) + sym_visibility = ( + StringAttr.get(str(visibility)) if visibility is not None else None + ) + super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip) + if body_builder: + entry_block = self.add_entry_block() + with InsertionPoint(entry_block): + body_builder(self) + + @property + def is_external(self): + return len(self.regions[0].blocks) == 0 + + @property + def body(self): + return self.regions[0] + + @property + def type(self): + return FunctionType(TypeAttr(self.attributes["function_type"]).value) + + @property + def visibility(self): + return self.attributes["sym_visibility"] + + @property + def name(self) -> StringAttr: + return StringAttr(self.attributes["sym_name"]) + + @property + def entry_block(self): + if self.is_external: + raise IndexError("External function does not have a body") + return self.regions[0].blocks[0] + + def add_entry_block(self): + """ + Add an entry block to the function body using the function signature to + infer block arguments. 
+ Returns the newly created block + """ + if not self.is_external: + raise IndexError("The function already has an entry block!") + self.body.blocks.append(*self.type.inputs) + return self.body.blocks[0] + + @property + def arg_attrs(self): + return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) + + @arg_attrs.setter + def arg_attrs(self, attribute: Union[ArrayAttr, list]): + if isinstance(attribute, ArrayAttr): + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute + else: + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( + attribute, context=self.context + ) + + @property + def arguments(self): + return self.entry_block.arguments + + @property + def result_attrs(self): + return self.attributes[RESULT_ATTRIBUTE_NAME] + + @result_attrs.setter + def result_attrs(self, attribute: ArrayAttr): + self.attributes[RESULT_ATTRIBUTE_NAME] = attribute diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index 51b9008..7655629 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -18,144 +18,152 @@ __all__ = [ def extend_opview_class(ext_module): - """Decorator to extend an OpView class from an extension module. - - Extension modules can expose various entry-points: - Stand-alone class with the same name as a parent OpView class (i.e. - "ReturnOp"). A name-based match is attempted first before falling back - to a below mechanism. - - def select_opview_mixin(parent_opview_cls): - If defined, allows an appropriate mixin class to be selected dynamically - based on the parent OpView class. Should return NotImplemented if a - decision is not made. - - Args: - ext_module: A module from which to locate extensions. Can be None if not - available. - - Returns: - A decorator that takes an OpView subclass and further extends it as - needed. - """ - - def class_decorator(parent_opview_cls: type): - if ext_module is None: - return parent_opview_cls - mixin_cls = NotImplemented - # First try to resolve by name. - try: - mixin_cls = getattr(ext_module, parent_opview_cls.__name__) - except AttributeError: - # Fall back to a select_opview_mixin hook. - try: - select_mixin = getattr(ext_module, "select_opview_mixin") - except AttributeError: - pass - else: - mixin_cls = select_mixin(parent_opview_cls) - - if mixin_cls is NotImplemented or mixin_cls is None: - return parent_opview_cls - - # Have a mixin_cls. Create an appropriate subclass. - try: - - class LocalOpView(mixin_cls, parent_opview_cls): - pass - except TypeError as e: - raise TypeError( - f"Could not mixin {mixin_cls} into {parent_opview_cls}") from e - LocalOpView.__name__ = parent_opview_cls.__name__ - LocalOpView.__qualname__ = parent_opview_cls.__qualname__ - return LocalOpView - - return class_decorator + """Decorator to extend an OpView class from an extension module. + + Extension modules can expose various entry-points: + Stand-alone class with the same name as a parent OpView class (i.e. + "ReturnOp"). A name-based match is attempted first before falling back + to a below mechanism. + + def select_opview_mixin(parent_opview_cls): + If defined, allows an appropriate mixin class to be selected dynamically + based on the parent OpView class. Should return NotImplemented if a + decision is not made. + + Args: + ext_module: A module from which to locate extensions. Can be None if not + available. + + Returns: + A decorator that takes an OpView subclass and further extends it as + needed. 
+ """ + + def class_decorator(parent_opview_cls: type): + if ext_module is None: + return parent_opview_cls + mixin_cls = NotImplemented + # First try to resolve by name. + try: + mixin_cls = getattr(ext_module, parent_opview_cls.__name__) + except AttributeError: + # Fall back to a select_opview_mixin hook. + try: + select_mixin = getattr(ext_module, "select_opview_mixin") + except AttributeError: + pass + else: + mixin_cls = select_mixin(parent_opview_cls) + + if mixin_cls is NotImplemented or mixin_cls is None: + return parent_opview_cls + + # Have a mixin_cls. Create an appropriate subclass. + try: + + class LocalOpView(mixin_cls, parent_opview_cls): + pass + + except TypeError as e: + raise TypeError( + f"Could not mixin {mixin_cls} into {parent_opview_cls}" + ) from e + LocalOpView.__name__ = parent_opview_cls.__name__ + LocalOpView.__qualname__ = parent_opview_cls.__qualname__ + return LocalOpView + + return class_decorator def segmented_accessor(elements, raw_segments, idx): - """ - Returns a slice of elements corresponding to the idx-th segment. - - elements: a sliceable container (operands or results). - raw_segments: an mlir.ir.Attribute, of DenseI32Array subclass containing - sizes of the segments. - idx: index of the segment. - """ - segments = _cext.ir.DenseI32ArrayAttr(raw_segments) - start = sum(segments[i] for i in range(idx)) - end = start + segments[idx] - return elements[start:end] - - -def equally_sized_accessor(elements, n_variadic, n_preceding_simple, - n_preceding_variadic): - """ - Returns a starting position and a number of elements per variadic group - assuming equally-sized groups and the given numbers of preceding groups. - - elements: a sequential container. - n_variadic: the number of variadic groups in the container. - n_preceding_simple: the number of non-variadic groups preceding the current - group. - n_preceding_variadic: the number of variadic groups preceding the current - group. - """ - - total_variadic_length = len(elements) - n_variadic + 1 - # This should be enforced by the C++-side trait verifier. - assert total_variadic_length % n_variadic == 0 - - elements_per_group = total_variadic_length // n_variadic - start = n_preceding_simple + n_preceding_variadic * elements_per_group - return start, elements_per_group + """ + Returns a slice of elements corresponding to the idx-th segment. + + elements: a sliceable container (operands or results). + raw_segments: an mlir.ir.Attribute, of DenseI32Array subclass containing + sizes of the segments. + idx: index of the segment. + """ + segments = _cext.ir.DenseI32ArrayAttr(raw_segments) + start = sum(segments[i] for i in range(idx)) + end = start + segments[idx] + return elements[start:end] + + +def equally_sized_accessor( + elements, n_variadic, n_preceding_simple, n_preceding_variadic +): + """ + Returns a starting position and a number of elements per variadic group + assuming equally-sized groups and the given numbers of preceding groups. + + elements: a sequential container. + n_variadic: the number of variadic groups in the container. + n_preceding_simple: the number of non-variadic groups preceding the current + group. + n_preceding_variadic: the number of variadic groups preceding the current + group. + """ + + total_variadic_length = len(elements) - n_variadic + 1 + # This should be enforced by the C++-side trait verifier. 
+ assert total_variadic_length % n_variadic == 0 + + elements_per_group = total_variadic_length // n_variadic + start = n_preceding_simple + n_preceding_variadic * elements_per_group + return start, elements_per_group def get_default_loc_context(location=None): - """ - Returns a context in which the defaulted location is created. If the location - is None, takes the current location from the stack, raises ValueError if there - is no location on the stack. - """ - if location is None: - # Location.current raises ValueError if there is no current location. - return _cext.ir.Location.current.context - return location.context + """ + Returns a context in which the defaulted location is created. If the location + is None, takes the current location from the stack, raises ValueError if there + is no location on the stack. + """ + if location is None: + # Location.current raises ValueError if there is no current location. + return _cext.ir.Location.current.context + return location.context def get_op_result_or_value( - arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList] + arg: _Union[ + _cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList + ] ) -> _cext.ir.Value: - """Returns the given value or the single result of the given op. - - This is useful to implement op constructors so that they can take other ops as - arguments instead of requiring the caller to extract results for every op. - Raises ValueError if provided with an op that doesn't have a single result. - """ - if isinstance(arg, _cext.ir.OpView): - return arg.operation.result - elif isinstance(arg, _cext.ir.Operation): - return arg.result - elif isinstance(arg, _cext.ir.OpResultList): - return arg[0] - else: - assert isinstance(arg, _cext.ir.Value) - return arg + """Returns the given value or the single result of the given op. + + This is useful to implement op constructors so that they can take other ops as + arguments instead of requiring the caller to extract results for every op. + Raises ValueError if provided with an op that doesn't have a single result. + """ + if isinstance(arg, _cext.ir.OpView): + return arg.operation.result + elif isinstance(arg, _cext.ir.Operation): + return arg.result + elif isinstance(arg, _cext.ir.OpResultList): + return arg[0] + else: + assert isinstance(arg, _cext.ir.Value) + return arg def get_op_results_or_values( - arg: _Union[_cext.ir.OpView, _cext.ir.Operation, - _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]]] + arg: _Union[ + _cext.ir.OpView, + _cext.ir.Operation, + _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]], + ] ) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]: - """Returns the given sequence of values or the results of the given op. - - This is useful to implement op constructors so that they can take other ops as - lists of arguments instead of requiring the caller to extract results for - every op. - """ - if isinstance(arg, _cext.ir.OpView): - return arg.operation.results - elif isinstance(arg, _cext.ir.Operation): - return arg.results - else: - return [get_op_result_or_value(element) for element in arg] + """Returns the given sequence of values or the results of the given op. + + This is useful to implement op constructors so that they can take other ops as + lists of arguments instead of requiring the caller to extract results for + every op. 
+ """ + if isinstance(arg, _cext.ir.OpView): + return arg.operation.results + elif isinstance(arg, _cext.ir.Operation): + return arg.results + else: + return [get_op_result_or_value(element) for element in arg] diff --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py index 40ccbef..fc9de0b 100644 --- a/mlir/python/mlir/dialects/_pdl_ops_ext.py +++ b/mlir/python/mlir/dialects/_pdl_ops_ext.py @@ -3,10 +3,10 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from ..ir import * - from ..dialects import pdl + from ..ir import * + from ..dialects import pdl except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e from typing import Union, Optional, Sequence, Mapping from ._ods_common import ( @@ -16,264 +16,256 @@ from ._ods_common import ( class ApplyNativeConstraintOp: - """Specialization for PDL apply native constraint op class.""" - - def __init__( - self, - name: Union[str, StringAttr], - args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - *, - loc=None, - ip=None, - ): - if args is None: - args = [] - args = _get_values(args) - super().__init__(name, args, loc=loc, ip=ip) + """Specialization for PDL apply native constraint op class.""" + + def __init__( + self, + name: Union[str, StringAttr], + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if args is None: + args = [] + args = _get_values(args) + super().__init__(name, args, loc=loc, ip=ip) class ApplyNativeRewriteOp: - """Specialization for PDL apply native rewrite op class.""" - - def __init__( - self, - results: Sequence[Type], - name: Union[str, StringAttr], - args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - *, - loc=None, - ip=None, - ): - if args is None: - args = [] - args = _get_values(args) - super().__init__(results, name, args, loc=loc, ip=ip) + """Specialization for PDL apply native rewrite op class.""" + + def __init__( + self, + results: Sequence[Type], + name: Union[str, StringAttr], + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if args is None: + args = [] + args = _get_values(args) + super().__init__(results, name, args, loc=loc, ip=ip) class AttributeOp: - """Specialization for PDL attribute op class.""" + """Specialization for PDL attribute op class.""" - def __init__( - self, - valueType: Optional[Union[OpView, Operation, Value]] = None, - value: Optional[Attribute] = None, - *, - loc=None, - ip=None, - ): - valueType = valueType if valueType is None else _get_value(valueType) - result = pdl.AttributeType.get() - super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip) + def __init__( + self, + valueType: Optional[Union[OpView, Operation, Value]] = None, + value: Optional[Attribute] = None, + *, + loc=None, + ip=None, + ): + valueType = valueType if valueType is None else _get_value(valueType) + result = pdl.AttributeType.get() + super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip) class EraseOp: - """Specialization for PDL erase op class.""" + """Specialization for PDL erase op class.""" - def __init__( - self, - operation: Optional[Union[OpView, Operation, Value]] = None, - *, - loc=None, - ip=None, - ): - operation = _get_value(operation) - super().__init__(operation, loc=loc, ip=ip) + def __init__( + self, + operation: Optional[Union[OpView, Operation, 
Value]] = None, + *, + loc=None, + ip=None, + ): + operation = _get_value(operation) + super().__init__(operation, loc=loc, ip=ip) class OperandOp: - """Specialization for PDL operand op class.""" + """Specialization for PDL operand op class.""" - def __init__( - self, - type: Optional[Union[OpView, Operation, Value]] = None, - *, - loc=None, - ip=None, - ): - type = type if type is None else _get_value(type) - result = pdl.ValueType.get() - super().__init__(result, valueType=type, loc=loc, ip=ip) + def __init__( + self, + type: Optional[Union[OpView, Operation, Value]] = None, + *, + loc=None, + ip=None, + ): + type = type if type is None else _get_value(type) + result = pdl.ValueType.get() + super().__init__(result, valueType=type, loc=loc, ip=ip) class OperandsOp: - """Specialization for PDL operands op class.""" + """Specialization for PDL operands op class.""" - def __init__( - self, - types: Optional[Union[OpView, Operation, Value]] = None, - *, - loc=None, - ip=None, - ): - types = types if types is None else _get_value(types) - result = pdl.RangeType.get(pdl.ValueType.get()) - super().__init__(result, valueType=types, loc=loc, ip=ip) + def __init__( + self, + types: Optional[Union[OpView, Operation, Value]] = None, + *, + loc=None, + ip=None, + ): + types = types if types is None else _get_value(types) + result = pdl.RangeType.get(pdl.ValueType.get()) + super().__init__(result, valueType=types, loc=loc, ip=ip) class OperationOp: - """Specialization for PDL operand op class.""" - - def __init__( - self, - name: Optional[Union[str, StringAttr]] = None, - args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - attributes: Optional[Mapping[str, Union[OpView, Operation, - Value]]] = None, - types: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - *, - loc=None, - ip=None, - ): - if types is None: - types = [] - if attributes is None: - attributes = {} - if args is None: - args = [] - args = _get_values(args) - attrNames = [] - attrValues = [] - for attrName, attrValue in attributes.items(): - attrNames.append(StringAttr.get(attrName)) - attrValues.append(_get_value(attrValue)) - attrNames = ArrayAttr.get(attrNames) - types = _get_values(types) - result = pdl.OperationType.get() - super().__init__(result, - args, - attrValues, - attrNames, - types, - opName=name, - loc=loc, - ip=ip) + """Specialization for PDL operand op class.""" + + def __init__( + self, + name: Optional[Union[str, StringAttr]] = None, + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + attributes: Optional[Mapping[str, Union[OpView, Operation, Value]]] = None, + types: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if types is None: + types = [] + if attributes is None: + attributes = {} + if args is None: + args = [] + args = _get_values(args) + attrNames = [] + attrValues = [] + for attrName, attrValue in attributes.items(): + attrNames.append(StringAttr.get(attrName)) + attrValues.append(_get_value(attrValue)) + attrNames = ArrayAttr.get(attrNames) + types = _get_values(types) + result = pdl.OperationType.get() + super().__init__( + result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip + ) class PatternOp: - """Specialization for PDL pattern op class.""" - - def __init__( - self, - benefit: Union[IntegerAttr, int], - name: Optional[Union[StringAttr, str]] = None, - *, - loc=None, - ip=None, - ): - """Creates an PDL `pattern` operation.""" - super().__init__(benefit, sym_name=name, loc=loc, ip=ip) - 
self.regions[0].blocks.append() - - @property - def body(self): - """Return the body (block) of the pattern.""" - return self.regions[0].blocks[0] + """Specialization for PDL pattern op class.""" + + def __init__( + self, + benefit: Union[IntegerAttr, int], + name: Optional[Union[StringAttr, str]] = None, + *, + loc=None, + ip=None, + ): + """Creates an PDL `pattern` operation.""" + super().__init__(benefit, sym_name=name, loc=loc, ip=ip) + self.regions[0].blocks.append() + + @property + def body(self): + """Return the body (block) of the pattern.""" + return self.regions[0].blocks[0] class ReplaceOp: - """Specialization for PDL replace op class.""" - - def __init__( - self, - op: Union[OpView, Operation, Value], - *, - with_op: Optional[Union[OpView, Operation, Value]] = None, - with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - loc=None, - ip=None, - ): - if with_values is None: - with_values = [] - op = _get_value(op) - with_op = with_op if with_op is None else _get_value(with_op) - with_values = _get_values(with_values) - super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip) + """Specialization for PDL replace op class.""" + + def __init__( + self, + op: Union[OpView, Operation, Value], + *, + with_op: Optional[Union[OpView, Operation, Value]] = None, + with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + loc=None, + ip=None, + ): + if with_values is None: + with_values = [] + op = _get_value(op) + with_op = with_op if with_op is None else _get_value(with_op) + with_values = _get_values(with_values) + super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip) class ResultOp: - """Specialization for PDL result op class.""" + """Specialization for PDL result op class.""" - def __init__( - self, - parent: Union[OpView, Operation, Value], - index: Union[IntegerAttr, int], - *, - loc=None, - ip=None, - ): - parent = _get_value(parent) - result = pdl.ValueType.get() - super().__init__(result, parent, index, loc=loc, ip=ip) + def __init__( + self, + parent: Union[OpView, Operation, Value], + index: Union[IntegerAttr, int], + *, + loc=None, + ip=None, + ): + parent = _get_value(parent) + result = pdl.ValueType.get() + super().__init__(result, parent, index, loc=loc, ip=ip) class ResultsOp: - """Specialization for PDL results op class.""" + """Specialization for PDL results op class.""" - def __init__( - self, - result: Type, - parent: Union[OpView, Operation, Value], - index: Optional[Union[IntegerAttr, int]] = None, - *, - loc=None, - ip=None, - ): - parent = _get_value(parent) - super().__init__(result, parent, index=index, loc=loc, ip=ip) + def __init__( + self, + result: Type, + parent: Union[OpView, Operation, Value], + index: Optional[Union[IntegerAttr, int]] = None, + *, + loc=None, + ip=None, + ): + parent = _get_value(parent) + super().__init__(result, parent, index=index, loc=loc, ip=ip) class RewriteOp: - """Specialization for PDL rewrite op class.""" - - def __init__( - self, - root: Optional[Union[OpView, Operation, Value]] = None, - name: Optional[Union[StringAttr, str]] = None, - args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - *, - loc=None, - ip=None, - ): - if args is None: - args = [] - root = root if root is None else _get_value(root) - args = _get_values(args) - super().__init__(args, root=root, name=name, loc=loc, ip=ip) - - def add_body(self): - """Add body (block) to the rewrite.""" - self.regions[0].blocks.append() - return self.body - - @property - def body(self): - 
"""Return the body (block) of the rewrite.""" - return self.regions[0].blocks[0] + """Specialization for PDL rewrite op class.""" + + def __init__( + self, + root: Optional[Union[OpView, Operation, Value]] = None, + name: Optional[Union[StringAttr, str]] = None, + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if args is None: + args = [] + root = root if root is None else _get_value(root) + args = _get_values(args) + super().__init__(args, root=root, name=name, loc=loc, ip=ip) + + def add_body(self): + """Add body (block) to the rewrite.""" + self.regions[0].blocks.append() + return self.body + + @property + def body(self): + """Return the body (block) of the rewrite.""" + return self.regions[0].blocks[0] class TypeOp: - """Specialization for PDL type op class.""" + """Specialization for PDL type op class.""" - def __init__(self, - constantType: Optional[Union[TypeAttr, Type]] = None, - *, - loc=None, - ip=None): - result = pdl.TypeType.get() - super().__init__(result, constantType=constantType, loc=loc, ip=ip) + def __init__( + self, constantType: Optional[Union[TypeAttr, Type]] = None, *, loc=None, ip=None + ): + result = pdl.TypeType.get() + super().__init__(result, constantType=constantType, loc=loc, ip=ip) class TypesOp: - """Specialization for PDL types op class.""" - - def __init__( - self, - constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None, - *, - loc=None, - ip=None, - ): - if constantTypes is None: - constantTypes = [] - result = pdl.RangeType.get(pdl.TypeType.get()) - super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip) + """Specialization for PDL types op class.""" + + def __init__( + self, + constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None, + *, + loc=None, + ip=None, + ): + if constantTypes is None: + constantTypes = [] + result = pdl.RangeType.get(pdl.TypeType.get()) + super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip) diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py index 3c3e673..4b2519e 100644 --- a/mlir/python/mlir/dialects/_scf_ops_ext.py +++ b/mlir/python/mlir/dialects/_scf_ops_ext.py @@ -3,105 +3,104 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from ..ir import * + from ..ir import * except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e from typing import Any, Optional, Sequence, Union -from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values +from ._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, +) + class ForOp: - """Specialization for the SCF for op class.""" - - def __init__(self, - lower_bound, - upper_bound, - step, - iter_args: Optional[Union[Operation, OpView, - Sequence[Value]]] = None, - *, - loc=None, - ip=None): - """Creates an SCF `for` operation. - - - `lower_bound` is the value to use as lower bound of the loop. - - `upper_bound` is the value to use as upper bound of the loop. - - `step` is the value to use as loop step. - - `iter_args` is a list of additional loop-carried arguments or an operation - producing them as results. 
- """ - if iter_args is None: - iter_args = [] - iter_args = _get_op_results_or_values(iter_args) - - results = [arg.type for arg in iter_args] - super().__init__( - self.build_generic( - regions=1, - results=results, - operands=[ - _get_op_result_or_value(o) - for o in [lower_bound, upper_bound, step] - ] + list(iter_args), - loc=loc, - ip=ip)) - self.regions[0].blocks.append(IndexType.get(), *results) - - @property - def body(self): - """Returns the body (block) of the loop.""" - return self.regions[0].blocks[0] - - @property - def induction_variable(self): - """Returns the induction variable of the loop.""" - return self.body.arguments[0] - - @property - def inner_iter_args(self): - """Returns the loop-carried arguments usable within the loop. - - To obtain the loop-carried operands, use `iter_args`. - """ - return self.body.arguments[1:] + """Specialization for the SCF for op class.""" + + def __init__( + self, + lower_bound, + upper_bound, + step, + iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None, + *, + loc=None, + ip=None + ): + """Creates an SCF `for` operation. + + - `lower_bound` is the value to use as lower bound of the loop. + - `upper_bound` is the value to use as upper bound of the loop. + - `step` is the value to use as loop step. + - `iter_args` is a list of additional loop-carried arguments or an operation + producing them as results. + """ + if iter_args is None: + iter_args = [] + iter_args = _get_op_results_or_values(iter_args) + + results = [arg.type for arg in iter_args] + super().__init__( + self.build_generic( + regions=1, + results=results, + operands=[ + _get_op_result_or_value(o) for o in [lower_bound, upper_bound, step] + ] + + list(iter_args), + loc=loc, + ip=ip, + ) + ) + self.regions[0].blocks.append(IndexType.get(), *results) + + @property + def body(self): + """Returns the body (block) of the loop.""" + return self.regions[0].blocks[0] + + @property + def induction_variable(self): + """Returns the induction variable of the loop.""" + return self.body.arguments[0] + + @property + def inner_iter_args(self): + """Returns the loop-carried arguments usable within the loop. + + To obtain the loop-carried operands, use `iter_args`. + """ + return self.body.arguments[1:] class IfOp: - """Specialization for the SCF if op class.""" - - def __init__(self, - cond, - results_=[], - *, - hasElse=False, - loc=None, - ip=None): - """Creates an SCF `if` operation. - - - `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed. - - `hasElse` determines whether the if operation has the else branch. - """ - operands = [] - operands.append(cond) - results = [] - results.extend(results_) - super().__init__( - self.build_generic( - regions=2, - results=results, - operands=operands, - loc=loc, - ip=ip)) - self.regions[0].blocks.append(*[]) - if hasElse: - self.regions[1].blocks.append(*[]) - - @property - def then_block(self): - """Returns the then block of the if operation.""" - return self.regions[0].blocks[0] - - @property - def else_block(self): - """Returns the else block of the if operation.""" - return self.regions[1].blocks[0] + """Specialization for the SCF if op class.""" + + def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None): + """Creates an SCF `if` operation. + + - `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed. + - `hasElse` determines whether the if operation has the else branch. 
+ """ + operands = [] + operands.append(cond) + results = [] + results.extend(results_) + super().__init__( + self.build_generic( + regions=2, results=results, operands=operands, loc=loc, ip=ip + ) + ) + self.regions[0].blocks.append(*[]) + if hasElse: + self.regions[1].blocks.append(*[]) + + @property + def then_block(self): + """Returns the then block of the if operation.""" + return self.regions[0].blocks[0] + + @property + def else_block(self): + """Returns the else block of the if operation.""" + return self.regions[1].blocks[0] diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 9c051cd..30dafff 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -3,11 +3,11 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from ..ir import * - from ._ods_common import get_op_result_or_value as _get_op_result_or_value - from ..dialects import pdl, transform + from ..ir import * + from ._ods_common import get_op_result_or_value as _get_op_result_or_value + from ..dialects import pdl, transform except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e from typing import List, Optional, Sequence, Union, overload @@ -16,312 +16,315 @@ OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]] def _get_int_int_array_attr( - values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, - IntOrAttrList]]]] + values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]] ) -> ArrayAttr: - """Creates an array attribute containing array attributes of integers. + """Creates an array attribute containing array attributes of integers. If the operand is already an array attribute, forwards it. Otherwise treats the operand as a list of attributes or integers, potentially interpserced, to create a new array-of-array attribute. Expects the thread-local MLIR context to have been set by the context manager. 
""" - if values is None: - return ArrayAttr.get([]) - if isinstance(values, ArrayAttr): - return values - if isinstance(values, list): - values = [ - ArrayAttr.get( - [IntegerAttr.get(IntegerType.get_signless(64), v) - for v in value]) - for value in values - ] + if values is None: + return ArrayAttr.get([]) + if isinstance(values, ArrayAttr): + return values + if isinstance(values, list): + values = [ + ArrayAttr.get( + [IntegerAttr.get(IntegerType.get_signless(64), v) for v in value] + ) + for value in values + ] - return ArrayAttr.get(values) + return ArrayAttr.get(values) class DecomposeOp: - """Specialization for DecomposeOp class.""" + """Specialization for DecomposeOp class.""" - def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): - super().__init__(pdl.OperationType.get(), - _get_op_result_or_value(target), - loc=loc, - ip=ip) + def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): + super().__init__( + pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip + ) class GeneralizeOp: - """Specialization for GeneralizeOp class.""" + """Specialization for GeneralizeOp class.""" - def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): - super().__init__(pdl.OperationType.get(), - _get_op_result_or_value(target), - loc=loc, - ip=ip) + def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): + super().__init__( + pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip + ) class InterchangeOp: - """Specialization for InterchangeOp class.""" - - def __init__( - self, - target: Union[Operation, Value], - *, - iterator_interchange: OptionalIntList = None, - loc=None, - ip=None, - ): - pdl_operation_type = pdl.OperationType.get() - super().__init__( - pdl_operation_type, - _get_op_result_or_value(target), - iterator_interchange=iterator_interchange, - loc=loc, - ip=ip, - ) + """Specialization for InterchangeOp class.""" + + def __init__( + self, + target: Union[Operation, Value], + *, + iterator_interchange: OptionalIntList = None, + loc=None, + ip=None, + ): + pdl_operation_type = pdl.OperationType.get() + super().__init__( + pdl_operation_type, + _get_op_result_or_value(target), + iterator_interchange=iterator_interchange, + loc=loc, + ip=ip, + ) class MatchOp: - """Specialization for MatchOp class.""" - - @classmethod - def match_op_names( - MatchOp, - target: Union[Operation, Value], - names: Sequence[str], - loc=None, - ip=None, - ): - pdl_operation_type = pdl.OperationType.get() - return MatchOp( - pdl_operation_type, - _get_op_result_or_value(target), - ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))), - loc=loc, - ip=ip, - ) + """Specialization for MatchOp class.""" + + @classmethod + def match_op_names( + MatchOp, + target: Union[Operation, Value], + names: Sequence[str], + loc=None, + ip=None, + ): + pdl_operation_type = pdl.OperationType.get() + return MatchOp( + pdl_operation_type, + _get_op_result_or_value(target), + ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))), + loc=loc, + ip=ip, + ) class MultiTileSizesOp: - """Specialization for MultitileSizesOp class.""" - - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - dimension: Union[int, IntegerAttr], - target_size: Union[int, IntegerAttr], - divisor: Optional[Optional[Union[int, IntegerAttr]]] = None, - loc=None, - ip=None, - ): - if divisor is None: - divisor = 1 - super().__init__( - result_type, - result_type, - result_type, - 
_get_op_result_or_value(target), - dimension=dimension, - target_size=target_size, - divisor=divisor, - loc=loc, - ip=ip, - ) + """Specialization for MultitileSizesOp class.""" + + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + dimension: Union[int, IntegerAttr], + target_size: Union[int, IntegerAttr], + divisor: Optional[Optional[Union[int, IntegerAttr]]] = None, + loc=None, + ip=None, + ): + if divisor is None: + divisor = 1 + super().__init__( + result_type, + result_type, + result_type, + _get_op_result_or_value(target), + dimension=dimension, + target_size=target_size, + divisor=divisor, + loc=loc, + ip=ip, + ) class PadOp: - """Specialization for PadOp class.""" - - def __init__( - self, - target: Union[Operation, Value], - *, - padding_values: Optional[Optional[Union[ArrayAttr, - Sequence[Attribute]]]] = None, - padding_dimensions: OptionalIntList = None, - pack_paddings: OptionalIntList = None, - transpose_paddings: Optional[Union[ArrayAttr, Sequence[Union[ - ArrayAttr, IntOrAttrList]]]] = None, - loc=None, - ip=None, - ): - if transpose_paddings is None: - transpose_paddings = [] - if pack_paddings is None: - pack_paddings = [] - if padding_dimensions is None: - padding_dimensions = [] - if padding_values is None: - padding_values = [] - pdl_operation_type = pdl.OperationType.get() - transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings) - super().__init__( - pdl_operation_type, - _get_op_result_or_value(target), - padding_values=padding_values, - padding_dimensions=padding_dimensions, - pack_paddings=pack_paddings, - transpose_paddings=transpose_paddings_attr, - loc=loc, - ip=ip, - ) + """Specialization for PadOp class.""" + + def __init__( + self, + target: Union[Operation, Value], + *, + padding_values: Optional[ + Optional[Union[ArrayAttr, Sequence[Attribute]]] + ] = None, + padding_dimensions: OptionalIntList = None, + pack_paddings: OptionalIntList = None, + transpose_paddings: Optional[ + Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]] + ] = None, + loc=None, + ip=None, + ): + if transpose_paddings is None: + transpose_paddings = [] + if pack_paddings is None: + pack_paddings = [] + if padding_dimensions is None: + padding_dimensions = [] + if padding_values is None: + padding_values = [] + pdl_operation_type = pdl.OperationType.get() + transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings) + super().__init__( + pdl_operation_type, + _get_op_result_or_value(target), + padding_values=padding_values, + padding_dimensions=padding_dimensions, + pack_paddings=pack_paddings, + transpose_paddings=transpose_paddings_attr, + loc=loc, + ip=ip, + ) class ScalarizeOp: - """Specialization for ScalarizeOp class.""" + """Specialization for ScalarizeOp class.""" - def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): - pdl_operation_type = pdl.OperationType.get() - super().__init__(pdl_operation_type, - _get_op_result_or_value(target), - loc=loc, - ip=ip) + def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): + pdl_operation_type = pdl.OperationType.get() + super().__init__( + pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip + ) class SplitOp: - """Specialization for SplitOp class.""" - - def __init__( - self, - target: Union[Operation, Value], - dimension: Union[int, Attribute], - split_point: Union[int, Operation, Value, Attribute], - *, - loc=None, - ip=None, - ): - if isinstance(split_point, int): - static_split_point = split_point - 
dynamic_split_point = None - else: - static_split_point = ShapedType.get_dynamic_size() - dynamic_split_point = _get_op_result_or_value(split_point) - - target = _get_op_result_or_value(target) - - super().__init__( - target.type, - target.type, - target, - dimension=dimension, - static_split_point=static_split_point, - dynamic_split_point=dynamic_split_point, - loc=loc, - ip=ip, - ) + """Specialization for SplitOp class.""" + + def __init__( + self, + target: Union[Operation, Value], + dimension: Union[int, Attribute], + split_point: Union[int, Operation, Value, Attribute], + *, + loc=None, + ip=None, + ): + if isinstance(split_point, int): + static_split_point = split_point + dynamic_split_point = None + else: + static_split_point = ShapedType.get_dynamic_size() + dynamic_split_point = _get_op_result_or_value(split_point) + + target = _get_op_result_or_value(target) + + super().__init__( + target.type, + target.type, + target, + dimension=dimension, + static_split_point=static_split_point, + dynamic_split_point=dynamic_split_point, + loc=loc, + ip=ip, + ) class TileOp: - """Specialization for TileOp class.""" - - @overload - def __init__( - self, - loop_types: Union[Type, List[Type]], - target: Union[Operation, Value], - *, - sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]], - ArrayAttr]] = None, - interchange: OptionalIntList = None, - loc=None, - ip=None, - ): - ... - - @overload - def __init__( - self, - target: Union[Operation, Value, OpView], - *, - sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]], - ArrayAttr]] = None, - interchange: OptionalIntList = None, - loc=None, - ip=None, - ): - ... - - def __init__( - self, - loop_types_or_target: Union[Type, List[Type], Operation, Value], - target_or_none: Optional[Union[Operation, Value, OpView]] = None, - *, - sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]], - ArrayAttr]] = None, - interchange: OptionalIntList = None, - loc=None, - ip=None, - ): - if interchange is None: - interchange = [] - if sizes is None: - sizes = [] - - static_sizes = [] - dynamic_sizes = [] - if isinstance(sizes, ArrayAttr): - sizes_attr = sizes - else: - for size in sizes: - if isinstance(size, int): - static_sizes.append(size) + """Specialization for TileOp class.""" + + @overload + def __init__( + self, + loop_types: Union[Type, List[Type]], + target: Union[Operation, Value], + *, + sizes: Optional[ + Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr] + ] = None, + interchange: OptionalIntList = None, + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, Value, OpView], + *, + sizes: Optional[ + Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr] + ] = None, + interchange: OptionalIntList = None, + loc=None, + ip=None, + ): + ... 
+ + def __init__( + self, + loop_types_or_target: Union[Type, List[Type], Operation, Value], + target_or_none: Optional[Union[Operation, Value, OpView]] = None, + *, + sizes: Optional[ + Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr] + ] = None, + interchange: OptionalIntList = None, + loc=None, + ip=None, + ): + if interchange is None: + interchange = [] + if sizes is None: + sizes = [] + + static_sizes = [] + dynamic_sizes = [] + if isinstance(sizes, ArrayAttr): + sizes_attr = sizes + else: + for size in sizes: + if isinstance(size, int): + static_sizes.append(size) + else: + static_sizes.append(ShapedType.get_dynamic_size()) + dynamic_sizes.append(_get_op_result_or_value(size)) + sizes_attr = DenseI64ArrayAttr.get(static_sizes) + + num_loops = sum(v if v == 0 else 1 for v in self.__extract_values(sizes_attr)) + + if isinstance(loop_types_or_target, (Operation, Value, OpView)): + loop_types = [transform.AnyOpType.get()] * num_loops + target = loop_types_or_target + assert target_or_none is None, "Cannot construct TileOp with two targets." else: - static_sizes.append(ShapedType.get_dynamic_size()) - dynamic_sizes.append(_get_op_result_or_value(size)) - sizes_attr = DenseI64ArrayAttr.get(static_sizes) - - num_loops = sum( - v if v == 0 else 1 for v in self.__extract_values(sizes_attr)) - - if isinstance(loop_types_or_target, (Operation, Value, OpView)): - loop_types = [transform.AnyOpType.get()] * num_loops - target = loop_types_or_target - assert target_or_none is None, "Cannot construct TileOp with two targets." - else: - loop_types = (([loop_types_or_target] * num_loops) if isinstance( - loop_types_or_target, Type) else loop_types_or_target) - target = target_or_none - - target = _get_op_result_or_value(target) - - super().__init__( - target.type, - loop_types, - target, - dynamic_sizes=dynamic_sizes, - static_sizes=sizes_attr, - interchange=interchange, - loc=loc, - ip=ip, - ) - - def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]: - if not attr: - return [] - return [element for element in attr] + loop_types = ( + ([loop_types_or_target] * num_loops) + if isinstance(loop_types_or_target, Type) + else loop_types_or_target + ) + target = target_or_none + + target = _get_op_result_or_value(target) + + super().__init__( + target.type, + loop_types, + target, + dynamic_sizes=dynamic_sizes, + static_sizes=sizes_attr, + interchange=interchange, + loc=loc, + ip=ip, + ) + + def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]: + if not attr: + return [] + return [element for element in attr] class VectorizeOp: - """Specialization for VectorizeOp class.""" - - def __init__( - self, - target: Union[Operation, Value], - *, - vectorize_padding: Union[bool, BoolAttr] = False, - loc=None, - ip=None, - ): - pdl_operation_type = pdl.OperationType.get() - if isinstance(vectorize_padding, bool): - vectorize_padding = UnitAttr.get() - super().__init__( - pdl_operation_type, - _get_op_result_or_value(target), - vectorize_padding=vectorize_padding, - loc=loc, - ip=ip, - ) + """Specialization for VectorizeOp class.""" + + def __init__( + self, + target: Union[Operation, Value], + *, + vectorize_padding: Union[bool, BoolAttr] = False, + loc=None, + ip=None, + ): + pdl_operation_type = pdl.OperationType.get() + if isinstance(vectorize_padding, bool): + vectorize_padding = UnitAttr.get() + super().__init__( + pdl_operation_type, + _get_op_result_or_value(target), + vectorize_padding=vectorize_padding, + loc=loc, + ip=ip, + ) diff --git 
a/mlir/python/mlir/dialects/_tensor_ops_ext.py b/mlir/python/mlir/dialects/_tensor_ops_ext.py index 51d998b..09b9ec6 100644 --- a/mlir/python/mlir/dialects/_tensor_ops_ext.py +++ b/mlir/python/mlir/dialects/_tensor_ops_ext.py @@ -3,40 +3,42 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from ..ir import * + from ..ir import * except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e from typing import Any, Optional, Sequence, Union -from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values +from ._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, +) class EmptyOp: - """Extends the tensor.empty op.""" + """Extends the tensor.empty op.""" - def __init__(self, - sizes: Sequence[Union[int, Value]], - element_type: Type, - *, - loc=None, - ip=None): - """Constructs an `empty` with mixed static/dynamic sizes.""" - # TODO: Refactor the EmptyOp to take an element type attribute and - # then use normal result type inference, unifying the Python and C++ side - # with a standard mechanism (versus stashing that in builders). - dynamic_sizes = [] - static_sizes = [] - for s in sizes: - if isinstance(s, int): - static_sizes.append(s) - else: - static_sizes.append(ShapedType.get_dynamic_size()) - dynamic_sizes.append(s) - result_type = RankedTensorType.get(static_sizes, element_type) - op = self.build_generic( - results=[result_type], - operands=dynamic_sizes, - attributes={}, - loc=loc, - ip=ip) - OpView.__init__(self, op) + def __init__( + self, + sizes: Sequence[Union[int, Value]], + element_type: Type, + *, + loc=None, + ip=None + ): + """Constructs an `empty` with mixed static/dynamic sizes.""" + # TODO: Refactor the EmptyOp to take an element type attribute and + # then use normal result type inference, unifying the Python and C++ side + # with a standard mechanism (versus stashing that in builders). 
+ dynamic_sizes = [] + static_sizes = [] + for s in sizes: + if isinstance(s, int): + static_sizes.append(s) + else: + static_sizes.append(ShapedType.get_dynamic_size()) + dynamic_sizes.append(s) + result_type = RankedTensorType.get(static_sizes, element_type) + op = self.build_generic( + results=[result_type], operands=dynamic_sizes, attributes={}, loc=loc, ip=ip + ) + OpView.__init__(self, op) diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py index cc4428e..425ec65 100644 --- a/mlir/python/mlir/dialects/_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_transform_ops_ext.py @@ -3,144 +3,131 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from ..ir import * - from ._ods_common import ( - get_op_result_or_value as _get_op_result_or_value, - get_op_results_or_values as _get_op_results_or_values, - ) + from ..ir import * + from ._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + ) except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e from typing import Optional, Sequence, Union class CastOp: - - def __init__(self, - result_type: Type, - target: Union[Operation, Value], - *, - loc=None, - ip=None): - super().__init__(result_type, - _get_op_result_or_value(target), - loc=loc, - ip=ip) + def __init__( + self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None + ): + super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip) class GetClosestIsolatedParentOp: - - def __init__(self, - result_type: Type, - target: Union[Operation, Value], - *, - loc=None, - ip=None): - super().__init__(result_type, - _get_op_result_or_value(target), - loc=loc, - ip=ip) + def __init__( + self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None + ): + super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip) class MergeHandlesOp: - - def __init__( - self, - handles: Sequence[Union[Operation, Value]], - *, - deduplicate: bool = False, - loc=None, - ip=None, - ): - super().__init__( - [_get_op_result_or_value(h) for h in handles], - deduplicate=deduplicate, - loc=loc, - ip=ip, - ) + def __init__( + self, + handles: Sequence[Union[Operation, Value]], + *, + deduplicate: bool = False, + loc=None, + ip=None, + ): + super().__init__( + [_get_op_result_or_value(h) for h in handles], + deduplicate=deduplicate, + loc=loc, + ip=ip, + ) class ReplicateOp: - - def __init__( - self, - pattern: Union[Operation, Value], - handles: Sequence[Union[Operation, Value]], - *, - loc=None, - ip=None, - ): - super().__init__( - [_get_op_result_or_value(h).type for h in handles], - _get_op_result_or_value(pattern), - [_get_op_result_or_value(h) for h in handles], - loc=loc, - ip=ip, - ) + def __init__( + self, + pattern: Union[Operation, Value], + handles: Sequence[Union[Operation, Value]], + *, + loc=None, + ip=None, + ): + super().__init__( + [_get_op_result_or_value(h).type for h in handles], + _get_op_result_or_value(pattern), + [_get_op_result_or_value(h) for h in handles], + loc=loc, + ip=ip, + ) class SequenceOp: - - def __init__( - self, - failure_propagation_mode, - results: Sequence[Type], - target: Union[Operation, Value, Type], - extra_bindings: Optional[Union[Sequence[Value], Sequence[Type], Operation, - OpView]] = None, - ): - root = 
(_get_op_result_or_value(target) if isinstance( - target, (Operation, Value)) else None) - root_type = root.type if not isinstance(target, Type) else target - if not isinstance(failure_propagation_mode, Attribute): - failure_propagation_mode_attr = IntegerAttr.get( - IntegerType.get_signless(32), failure_propagation_mode._as_int()) - else: - failure_propagation_mode_attr = failure_propagation_mode - - if extra_bindings is None: - extra_bindings = [] - if isinstance(extra_bindings, (Operation, OpView)): - extra_bindings = _get_op_results_or_values(extra_bindings) - - extra_binding_types = [] - if len(extra_bindings) != 0: - if isinstance(extra_bindings[0], Type): - extra_binding_types = extra_bindings - extra_bindings = [] - else: - extra_binding_types = [v.type for v in extra_bindings] - - super().__init__( - results_=results, - failure_propagation_mode=failure_propagation_mode_attr, - root=root, - extra_bindings=extra_bindings, - ) - self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types)) - - @property - def body(self) -> Block: - return self.regions[0].blocks[0] - - @property - def bodyTarget(self) -> Value: - return self.body.arguments[0] - - @property - def bodyExtraArgs(self) -> BlockArgumentList: - return self.body.arguments[1:] + def __init__( + self, + failure_propagation_mode, + results: Sequence[Type], + target: Union[Operation, Value, Type], + extra_bindings: Optional[ + Union[Sequence[Value], Sequence[Type], Operation, OpView] + ] = None, + ): + root = ( + _get_op_result_or_value(target) + if isinstance(target, (Operation, Value)) + else None + ) + root_type = root.type if not isinstance(target, Type) else target + if not isinstance(failure_propagation_mode, Attribute): + failure_propagation_mode_attr = IntegerAttr.get( + IntegerType.get_signless(32), failure_propagation_mode._as_int() + ) + else: + failure_propagation_mode_attr = failure_propagation_mode + + if extra_bindings is None: + extra_bindings = [] + if isinstance(extra_bindings, (Operation, OpView)): + extra_bindings = _get_op_results_or_values(extra_bindings) + + extra_binding_types = [] + if len(extra_bindings) != 0: + if isinstance(extra_bindings[0], Type): + extra_binding_types = extra_bindings + extra_bindings = [] + else: + extra_binding_types = [v.type for v in extra_bindings] + + super().__init__( + results_=results, + failure_propagation_mode=failure_propagation_mode_attr, + root=root, + extra_bindings=extra_bindings, + ) + self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types)) + + @property + def body(self) -> Block: + return self.regions[0].blocks[0] + + @property + def bodyTarget(self) -> Value: + return self.body.arguments[0] + + @property + def bodyExtraArgs(self) -> BlockArgumentList: + return self.body.arguments[1:] class YieldOp: - - def __init__( - self, - operands: Optional[Union[Operation, Sequence[Value]]] = None, - *, - loc=None, - ip=None, - ): - if operands is None: - operands = [] - super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip) + def __init__( + self, + operands: Optional[Union[Operation, Sequence[Value]]] = None, + *, + loc=None, + ip=None, + ): + if operands is None: + operands = [] + super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py b/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py index 5a695d6..2f65131 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py @@ -31,61 +31,60 
@@ from .lang.yaml_helper import * def create_arg_parser() -> argparse.ArgumentParser: - p = argparse.ArgumentParser(description="Dump an oplib in various formats") - p.add_argument("modules", - metavar="M", - type=str, - nargs="*", - help="Op module to dump") - p.add_argument("--file", - metavar="F", - type=str, - nargs="*", - help="Python op file to dump") - p.add_argument("--format", - type=str, - dest="format", - default="yaml", - choices=("yaml", "repr"), - help="Format in which to dump") - return p + p = argparse.ArgumentParser(description="Dump an oplib in various formats") + p.add_argument( + "modules", metavar="M", type=str, nargs="*", help="Op module to dump" + ) + p.add_argument( + "--file", metavar="F", type=str, nargs="*", help="Python op file to dump" + ) + p.add_argument( + "--format", + type=str, + dest="format", + default="yaml", + choices=("yaml", "repr"), + help="Format in which to dump", + ) + return p def load_module_from_file(module_name, file_path): - spec = importlib.util.spec_from_file_location(module_name, file_path) - m = importlib.util.module_from_spec(spec) - spec.loader.exec_module(m) - return m + spec = importlib.util.spec_from_file_location(module_name, file_path) + m = importlib.util.module_from_spec(spec) + spec.loader.exec_module(m) + return m def main(args): - # Load all configs. - configs = [] - modules = [] - for module_name in args.modules: - modules.append( - importlib.import_module(module_name, - package="mlir.dialects.linalg.opdsl")) - for i, file_path in enumerate(args.file or []): - modules.append(load_module_from_file(f"_mlir_eval_oplib{i}", file_path)) - for m in modules: - for attr_name, value in m.__dict__.items(): - # TODO: This class layering is awkward. - if isinstance(value, DefinedOpCallable): - try: - linalg_config = LinalgOpConfig.from_linalg_op_def(value.op_def) - except Exception as e: - raise ValueError( - f"Could not create LinalgOpConfig from {value.op_def}") from e - configs.extend(linalg_config) - - # Print. - if args.format == "yaml": - print(yaml_dump_all(configs)) - elif args.format == "repr": - for config in configs: - print(repr(config)) + # Load all configs. + configs = [] + modules = [] + for module_name in args.modules: + modules.append( + importlib.import_module(module_name, package="mlir.dialects.linalg.opdsl") + ) + for i, file_path in enumerate(args.file or []): + modules.append(load_module_from_file(f"_mlir_eval_oplib{i}", file_path)) + for m in modules: + for attr_name, value in m.__dict__.items(): + # TODO: This class layering is awkward. + if isinstance(value, DefinedOpCallable): + try: + linalg_config = LinalgOpConfig.from_linalg_op_def(value.op_def) + except Exception as e: + raise ValueError( + f"Could not create LinalgOpConfig from {value.op_def}" + ) from e + configs.extend(linalg_config) + + # Print. + if args.format == "yaml": + print(yaml_dump_all(configs)) + elif args.format == "repr": + for config in configs: + print(repr(config)) if __name__ == "__main__": - main(create_arg_parser().parse_args()) + main(create_arg_parser().parse_args()) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py index 038f068..9fa626d 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py @@ -66,201 +66,201 @@ __all__ = [ class AffineBuildState: - """Internal state for the AffineExprDef._create impls. 
- - Note that a "local" AffineBuildState can be created relative to a "global" - AffineBuildState. In that case, any affine expressions built will inherit - symbol and dim bindings from the global state and will update both as new - ones are discovered. This allows for building expressions across contexts - which share a common symbol and dim space. - """ - - def __init__(self, - *, - global_state: "AffineBuildState" = None, - allow_new_symbols: bool = True, - allow_new_dims: bool = True): - if not global_state: - self.all_symbols = dict() # type: Dict[str, int] - self.all_dims = dict() # type: Dict[str, int] - else: - # Alias the global dict. - self.all_symbols = global_state.all_symbols - self.all_dims = global_state.all_dims - - # Map of symbols and dims in the current build. - self.local_symbols = dict() # type: Dict[str, int] - self.local_dims = dict() # type: Dict[str, int] - self.allow_new_symbols = allow_new_symbols - self.allow_new_dims = allow_new_dims - - def get_dim(self, dimname: str) -> int: - """Gets the dim position given a name.""" - pos = self.all_dims.get(dimname) - if pos is None: - if not self.allow_new_dims: - raise ValueError( - f"New dimensions not allowed in the current affine expression: " - f"Requested '{dimname}', Availble: {self.all_dims}") - pos = len(self.all_dims) - self.all_dims[dimname] = pos - self.local_dims[dimname] = pos - return pos - - def get_symbol(self, symname: str) -> int: - """Geta a symbol position given a name.""" - pos = self.all_symbols.get(symname) - if pos is None: - if not self.allow_new_symbols: - raise ValueError( - f"New symbols not allowed in the current affine expression: " - f"Requested '{symname}', Availble: {self.all_symbols}") - pos = len(self.all_symbols) - self.all_symbols[symname] = pos - self.local_symbols[symname] = pos - return pos - - @property - def local_dim_count(self) -> int: - return len(self.local_dims) - - @property - def local_symbol_count(self) -> int: - return len(self.local_symbols) - - @property - def dim_count(self) -> int: - return len(self.all_dims) - - @property - def symbol_count(self) -> int: - return len(self.all_symbols) - - def __repr__(self): - lines = [f"AffineBuildState<"] - lines.append(f" symbols={self.local_symbols}") - lines.append(f" dims={self.local_dims}>") - return "\n".join(lines) + """Internal state for the AffineExprDef._create impls. + + Note that a "local" AffineBuildState can be created relative to a "global" + AffineBuildState. In that case, any affine expressions built will inherit + symbol and dim bindings from the global state and will update both as new + ones are discovered. This allows for building expressions across contexts + which share a common symbol and dim space. + """ + + def __init__( + self, + *, + global_state: "AffineBuildState" = None, + allow_new_symbols: bool = True, + allow_new_dims: bool = True, + ): + if not global_state: + self.all_symbols = dict() # type: Dict[str, int] + self.all_dims = dict() # type: Dict[str, int] + else: + # Alias the global dict. + self.all_symbols = global_state.all_symbols + self.all_dims = global_state.all_dims + + # Map of symbols and dims in the current build. 
+ self.local_symbols = dict() # type: Dict[str, int] + self.local_dims = dict() # type: Dict[str, int] + self.allow_new_symbols = allow_new_symbols + self.allow_new_dims = allow_new_dims + + def get_dim(self, dimname: str) -> int: + """Gets the dim position given a name.""" + pos = self.all_dims.get(dimname) + if pos is None: + if not self.allow_new_dims: + raise ValueError( + f"New dimensions not allowed in the current affine expression: " + f"Requested '{dimname}', Availble: {self.all_dims}" + ) + pos = len(self.all_dims) + self.all_dims[dimname] = pos + self.local_dims[dimname] = pos + return pos + + def get_symbol(self, symname: str) -> int: + """Geta a symbol position given a name.""" + pos = self.all_symbols.get(symname) + if pos is None: + if not self.allow_new_symbols: + raise ValueError( + f"New symbols not allowed in the current affine expression: " + f"Requested '{symname}', Availble: {self.all_symbols}" + ) + pos = len(self.all_symbols) + self.all_symbols[symname] = pos + self.local_symbols[symname] = pos + return pos + + @property + def local_dim_count(self) -> int: + return len(self.local_dims) + + @property + def local_symbol_count(self) -> int: + return len(self.local_symbols) + + @property + def dim_count(self) -> int: + return len(self.all_dims) + + @property + def symbol_count(self) -> int: + return len(self.all_symbols) + + def __repr__(self): + lines = [f"AffineBuildState<"] + lines.append(f" symbols={self.local_symbols}") + lines.append(f" dims={self.local_dims}>") + return "\n".join(lines) class AffineExprDef: - """Base class for an affine expression being defined.""" + """Base class for an affine expression being defined.""" - def build(self, state: Optional[AffineBuildState] = None) -> _ir.AffineExpr: - """Builds the corresponding _ir.AffineExpr from the definitions. 
- """ - state = AffineBuildState() if state is None else state - expr = self._create(state) - return expr + def build(self, state: Optional[AffineBuildState] = None) -> _ir.AffineExpr: + """Builds the corresponding _ir.AffineExpr from the definitions.""" + state = AffineBuildState() if state is None else state + expr = self._create(state) + return expr - def _create(self, state: AffineBuildState) -> _ir.AffineExpr: - raise NotImplementedError() + def _create(self, state: AffineBuildState) -> _ir.AffineExpr: + raise NotImplementedError() - @staticmethod - def coerce_from(py_value): - if isinstance(py_value, int): - return AffineConstantExpr(py_value) - assert isinstance(py_value, AffineExprDef) - return py_value + @staticmethod + def coerce_from(py_value): + if isinstance(py_value, int): + return AffineConstantExpr(py_value) + assert isinstance(py_value, AffineExprDef) + return py_value - def visit_affine_exprs(self, callback): - """Visits all AffineExprDefs including self.""" - callback(self) + def visit_affine_exprs(self, callback): + """Visits all AffineExprDefs including self.""" + callback(self) - def __add__(lhs, rhs): - rhs = AffineExprDef.coerce_from(rhs) - return AffineBinaryExprDef(_ir.AffineAddExpr, lhs, rhs) + def __add__(lhs, rhs): + rhs = AffineExprDef.coerce_from(rhs) + return AffineBinaryExprDef(_ir.AffineAddExpr, lhs, rhs) - def __mul__(lhs, rhs): - rhs = AffineExprDef.coerce_from(rhs) - return AffineBinaryExprDef(_ir.AffineMulExpr, lhs, rhs) + def __mul__(lhs, rhs): + rhs = AffineExprDef.coerce_from(rhs) + return AffineBinaryExprDef(_ir.AffineMulExpr, lhs, rhs) - def __mod__(lhs, rhs): - rhs = AffineExprDef.coerce_from(rhs) - return AffineBinaryExprDef(_ir.AffineModExpr, lhs, rhs) + def __mod__(lhs, rhs): + rhs = AffineExprDef.coerce_from(rhs) + return AffineBinaryExprDef(_ir.AffineModExpr, lhs, rhs) - def __floordiv__(lhs, rhs): - rhs = AffineExprDef.coerce_from(rhs) - return AffineBinaryExprDef(_ir.AffineFloorDivExpr, lhs, rhs) + def __floordiv__(lhs, rhs): + rhs = AffineExprDef.coerce_from(rhs) + return AffineBinaryExprDef(_ir.AffineFloorDivExpr, lhs, rhs) - def __truediv__(lhs, rhs): - # TODO: Not really a ceil div - taking liberties for the DSL. - rhs = AffineExprDef.coerce_from(rhs) - return AffineBinaryExprDef(_ir.AffineCeilDivExpr, lhs, rhs) + def __truediv__(lhs, rhs): + # TODO: Not really a ceil div - taking liberties for the DSL. 
+ rhs = AffineExprDef.coerce_from(rhs) + return AffineBinaryExprDef(_ir.AffineCeilDivExpr, lhs, rhs) class AffineConstantExpr(AffineExprDef): - """An affine constant being defined.""" + """An affine constant being defined.""" - def __init__(self, value: int): - assert isinstance(value, int) - self.value = value + def __init__(self, value: int): + assert isinstance(value, int) + self.value = value - def _create(self, state: AffineBuildState) -> _ir.AffineExpr: - return _ir.AffineConstantExpr.get(self.value) + def _create(self, state: AffineBuildState) -> _ir.AffineExpr: + return _ir.AffineConstantExpr.get(self.value) - def __repr__(self): - return f"Const({self.value})" + def __repr__(self): + return f"Const({self.value})" class AffineBinaryExprDef(AffineExprDef): - """An affine binary expression being defined.""" + """An affine binary expression being defined.""" - def __init__(self, ir_ctor, lhs: AffineExprDef, rhs: AffineExprDef): - self.ir_ctor = ir_ctor - self.lhs = lhs - self.rhs = rhs + def __init__(self, ir_ctor, lhs: AffineExprDef, rhs: AffineExprDef): + self.ir_ctor = ir_ctor + self.lhs = lhs + self.rhs = rhs - def _create(self, state: AffineBuildState) -> _ir.AffineExpr: - return self.ir_ctor.get(self.lhs._create(state), self.rhs._create(state)) + def _create(self, state: AffineBuildState) -> _ir.AffineExpr: + return self.ir_ctor.get(self.lhs._create(state), self.rhs._create(state)) - def visit_affine_exprs(self, callback): - """Visits all AffineExprDefs including self.""" - super().visit_affine_exprs(callback) - self.lhs.visit_affine_exprs(callback) - self.rhs.visit_affine_exprs(callback) + def visit_affine_exprs(self, callback): + """Visits all AffineExprDefs including self.""" + super().visit_affine_exprs(callback) + self.lhs.visit_affine_exprs(callback) + self.rhs.visit_affine_exprs(callback) - def __repr__(self): - return f"{self.ir_ctor.__name__}({repr(self.lhs)}, {repr(self.rhs)})" + def __repr__(self): + return f"{self.ir_ctor.__name__}({repr(self.lhs)}, {repr(self.rhs)})" class DimDef(AffineExprDef): - """Represents a named dimension. - - """ - ALL_DIMS = dict() # type: Dict[str, "DimDef"] - - def __new__(cls, dimname: str): - existing = cls.ALL_DIMS.get(dimname) - if existing is not None: - return existing - new = super().__new__(cls) - new.dimname = dimname - cls.ALL_DIMS[dimname] = new - return new - - def __repr__(self): - return f"Dim({self.dimname})" - - def _create(self, state: AffineBuildState) -> _ir.AffineExpr: - pos = state.get_dim(self.dimname) - return _ir.AffineDimExpr.get(position=pos) - - @classmethod - def create_expando(cls): - """Create an expando class that creates unique symbols based on attr access. 
- """ + """Represents a named dimension.""" + + ALL_DIMS = dict() # type: Dict[str, "DimDef"] + + def __new__(cls, dimname: str): + existing = cls.ALL_DIMS.get(dimname) + if existing is not None: + return existing + new = super().__new__(cls) + new.dimname = dimname + cls.ALL_DIMS[dimname] = new + return new - class ExpandoDims: + def __repr__(self): + return f"Dim({self.dimname})" - def __getattr__(self, n): - return cls(n) + def _create(self, state: AffineBuildState) -> _ir.AffineExpr: + pos = state.get_dim(self.dimname) + return _ir.AffineDimExpr.get(position=pos) - return ExpandoDims() + @classmethod + def create_expando(cls): + """Create an expando class that creates unique symbols based on attr access.""" + + class ExpandoDims: + def __getattr__(self, n): + return cls(n) + + return ExpandoDims() class SymbolDef(AffineExprDef): - """Represents a named symbol. + """Represents a named symbol. >>> s1 = SymbolDef("s1") >>> s1 @@ -270,36 +270,35 @@ class SymbolDef(AffineExprDef): False >>> s1 is SymbolDef("s1") True - """ - ALL_SYMBOLS = dict() # type: Dict[str, "SymbolDef"] - - def __new__(cls, symname: str): - existing = cls.ALL_SYMBOLS.get(symname) - if existing is not None: - return existing - new = super().__new__(cls) - new.symname = symname - cls.ALL_SYMBOLS[symname] = new - return new - - def __repr__(self): - return f"Symbol({self.symname})" - - def _create(self, state: AffineBuildState) -> _ir.AffineExpr: - pos = state.get_symbol(self.symname) - return _ir.AffineSymbolExpr.get(position=pos) - - @classmethod - def create_expando(cls): - """Create an expando class that creates unique symbols based on attr access. """ - class ExpandoSymbols: + ALL_SYMBOLS = dict() # type: Dict[str, "SymbolDef"] + + def __new__(cls, symname: str): + existing = cls.ALL_SYMBOLS.get(symname) + if existing is not None: + return existing + new = super().__new__(cls) + new.symname = symname + cls.ALL_SYMBOLS[symname] = new + return new + + def __repr__(self): + return f"Symbol({self.symname})" + + def _create(self, state: AffineBuildState) -> _ir.AffineExpr: + pos = state.get_symbol(self.symname) + return _ir.AffineSymbolExpr.get(position=pos) + + @classmethod + def create_expando(cls): + """Create an expando class that creates unique symbols based on attr access.""" - def __getattr__(self, n): - return cls(n) + class ExpandoSymbols: + def __getattr__(self, n): + return cls(n) - return ExpandoSymbols() + return ExpandoSymbols() # Global accessor for on-demand dims and symbols. 
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index 135f55e..5d5866fd 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -23,223 +23,232 @@ from .yaml_helper import * class TensorExpression: - """An expression that can appear on the RHS of a comprehension.""" + """An expression that can appear on the RHS of a comprehension.""" - def to_scalar_expression(self) -> ScalarExpression: - raise NotImplementedError() + def to_scalar_expression(self) -> ScalarExpression: + raise NotImplementedError() - def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): - """Visits all tensor expression reachable by the expression.""" - callback(self) + def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): + """Visits all tensor expression reachable by the expression.""" + callback(self) - def collect_dim_uses(self, uses: Set["DimDef"]): - """Collects all DimDefs reachable through this expression.""" + def collect_dim_uses(self, uses: Set["DimDef"]): + """Collects all DimDefs reachable through this expression.""" - def visit_dim_def(dim_def: AffineExprDef): - if isinstance(dim_def, DimDef): - uses.add(dim_def) + def visit_dim_def(dim_def: AffineExprDef): + if isinstance(dim_def, DimDef): + uses.add(dim_def) - def visit_affine_exprs(expr: "TensorExpression"): - if isinstance(expr, TensorUse): - for ind in expr.indices: - ind.visit_affine_exprs(visit_dim_def) - if isinstance(expr, TensorReduceFn): - for ind in expr.reduce_fn.reduce_dims: - ind.visit_affine_exprs(visit_dim_def) + def visit_affine_exprs(expr: "TensorExpression"): + if isinstance(expr, TensorUse): + for ind in expr.indices: + ind.visit_affine_exprs(visit_dim_def) + if isinstance(expr, TensorReduceFn): + for ind in expr.reduce_fn.reduce_dims: + ind.visit_affine_exprs(visit_dim_def) - self.visit_tensor_exprs(visit_affine_exprs) + self.visit_tensor_exprs(visit_affine_exprs) - def collect_tensor_uses(self, uses: Set["TensorUse"]): - """Collects all TensorUses reachable through this expression.""" + def collect_tensor_uses(self, uses: Set["TensorUse"]): + """Collects all TensorUses reachable through this expression.""" - def visit_tensor_use(expr: "TensorExpression"): - if isinstance(expr, TensorUse): - uses.add(expr) + def visit_tensor_use(expr: "TensorExpression"): + if isinstance(expr, TensorUse): + uses.add(expr) - self.visit_tensor_exprs(visit_tensor_use) + self.visit_tensor_exprs(visit_tensor_use) - def collect_indices(self, indices: Set["index"]): - """Collects all index accesses reachable through this expression.""" + def collect_indices(self, indices: Set["index"]): + """Collects all index accesses reachable through this expression.""" - def visit_index(expr: "TensorExpression"): - if isinstance(expr, index): - indices.add(expr) + def visit_index(expr: "TensorExpression"): + if isinstance(expr, index): + indices.add(expr) - self.visit_tensor_exprs(visit_index) + self.visit_tensor_exprs(visit_index) - def collect_scalar_uses(self, uses: Set["ScalarDef"]): - """Collects all ScalarDefs reachable through this expression.""" + def collect_scalar_uses(self, uses: Set["ScalarDef"]): + """Collects all ScalarDefs reachable through this expression.""" - def visit_scalar_def(expr: "TensorExpression"): - if isinstance(expr, ScalarDef): - uses.add(expr) + def visit_scalar_def(expr: "TensorExpression"): + if isinstance(expr, 
ScalarDef): + uses.add(expr) - self.visit_tensor_exprs(visit_scalar_def) + self.visit_tensor_exprs(visit_scalar_def) - def __add__(self, rhs: "TensorExpression") -> "TensorExpression": - return BinaryFn.add(self, rhs) + def __add__(self, rhs: "TensorExpression") -> "TensorExpression": + return BinaryFn.add(self, rhs) - def __mul__(self, rhs) -> "TensorExpression": - return BinaryFn.mul(self, rhs) + def __mul__(self, rhs) -> "TensorExpression": + return BinaryFn.mul(self, rhs) - def __sub__(self, rhs) -> "TensorExpression": - return BinaryFn.sub(self, rhs) + def __sub__(self, rhs) -> "TensorExpression": + return BinaryFn.sub(self, rhs) - def __hash__(self): - return hash(id(self)) + def __hash__(self): + return hash(id(self)) class TensorUse(TensorExpression): - """A used tensor represented by its (tensor_name, indices). - - Note that forming a comprehension via direct assignment is performed through - __setitem__ on the TensorDef level. However, performing a reduction with - compound ops (+=, *=, etc) is done by doing a: - TensorDef.__getitem__ - TensorUse.__iadd__ - TensorDef.__setitem__ - """ - - def __init__(self, operand_def: "OperandDef", - indices: Sequence[AffineExprDef]): - self.operand_def = operand_def - self.indices = tuple(indices) - - def to_scalar_expression(self) -> ScalarExpression: - return ScalarArg(self.tensor_name).expr() - - @property - def tensor_name(self) -> str: - name = self.operand_def.name - assert name is not None, "TensorDef not registered with an op" - return name - - def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]: - # Computes the reduction dims for implicit reductions. Assumes that the rhs - # is the expression being reduced and self is being reduced into. Any - # indices referenced on the rhs and not in self are considered reduction - # dims and will be ordered as encountered on the rhs. - rhs_dims = set() - lhs_dims = set() - rhs.collect_dim_uses(rhs_dims) - self.collect_dim_uses(lhs_dims) - return rhs_dims - lhs_dims - - def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn": - return ReduceFnUse(BinaryFn.add, None, *self._compute_reduce_dims(rhs))(rhs) - - def __repr__(self): - return (f"{self.operand_def.name}" - f"[{', '.join([repr(i) for i in self.indices])}]") + """A used tensor represented by its (tensor_name, indices). + + Note that forming a comprehension via direct assignment is performed through + __setitem__ on the TensorDef level. However, performing a reduction with + compound ops (+=, *=, etc) is done by doing a: + TensorDef.__getitem__ + TensorUse.__iadd__ + TensorDef.__setitem__ + """ + + def __init__(self, operand_def: "OperandDef", indices: Sequence[AffineExprDef]): + self.operand_def = operand_def + self.indices = tuple(indices) + + def to_scalar_expression(self) -> ScalarExpression: + return ScalarArg(self.tensor_name).expr() + + @property + def tensor_name(self) -> str: + name = self.operand_def.name + assert name is not None, "TensorDef not registered with an op" + return name + + def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]: + # Computes the reduction dims for implicit reductions. Assumes that the rhs + # is the expression being reduced and self is being reduced into. Any + # indices referenced on the rhs and not in self are considered reduction + # dims and will be ordered as encountered on the rhs. 
+ rhs_dims = set() + lhs_dims = set() + rhs.collect_dim_uses(rhs_dims) + self.collect_dim_uses(lhs_dims) + return rhs_dims - lhs_dims + + def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn": + return ReduceFnUse(BinaryFn.add, None, *self._compute_reduce_dims(rhs))(rhs) + + def __repr__(self): + return ( + f"{self.operand_def.name}" f"[{', '.join([repr(i) for i in self.indices])}]" + ) class TensorFn(TensorExpression): - """Application of a tensor function.""" - - def __init__(self, kind: "FunctionKind", name: Optional[str], - operand_def: Optional["OperandDef"], type_var: Optional[TypeVar], - args: Sequence[TensorExpression]): - if bool(name) + bool(operand_def) != 1: - raise ValueError("One of 'name', 'operand_def' must be specified") - self.name = name - self.kind = kind - self.operand_def = operand_def - self.type_var = type_var - self.args = args - - def to_scalar_expression(self) -> ScalarExpression: - if self.operand_def: - assert self.operand_def.name, "TensorFn not registered with an op" - attr_name = self.operand_def.name if self.operand_def else None - args = [arg.to_scalar_expression() for arg in self.args] - return ScalarFn(self.kind, self.name, attr_name, self.type_var, args).expr() - - def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): - super().visit_tensor_exprs(callback) - for arg in self.args: - arg.visit_tensor_exprs(callback) - - def __repr__(self): - name = self.operand_def.name if self.operand_def else self.name - return (f"{self.kind.name}.{name}(type_var={self.type_var}, " - f"args={', '.join(repr(a) for a in self.args)})") + """Application of a tensor function.""" + + def __init__( + self, + kind: "FunctionKind", + name: Optional[str], + operand_def: Optional["OperandDef"], + type_var: Optional[TypeVar], + args: Sequence[TensorExpression], + ): + if bool(name) + bool(operand_def) != 1: + raise ValueError("One of 'name', 'operand_def' must be specified") + self.name = name + self.kind = kind + self.operand_def = operand_def + self.type_var = type_var + self.args = args + + def to_scalar_expression(self) -> ScalarExpression: + if self.operand_def: + assert self.operand_def.name, "TensorFn not registered with an op" + attr_name = self.operand_def.name if self.operand_def else None + args = [arg.to_scalar_expression() for arg in self.args] + return ScalarFn(self.kind, self.name, attr_name, self.type_var, args).expr() + + def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): + super().visit_tensor_exprs(callback) + for arg in self.args: + arg.visit_tensor_exprs(callback) + + def __repr__(self): + name = self.operand_def.name if self.operand_def else self.name + return ( + f"{self.kind.name}.{name}(type_var={self.type_var}, " + f"args={', '.join(repr(a) for a in self.args)})" + ) class TensorReduceFn(TensorExpression): - """Application of a reduction function. - - This captures the lhs (initial value) separately from the rhs. 
- """ - - def __init__(self, reduce_use: "ReduceFnUse", - args: Sequence[TensorExpression]): - self.reduce_use = reduce_use - self.lhs = None # type: Optional[TensorUse] - self.args = args - - def to_scalar_expression(self) -> ScalarExpression: - if self.lhs is None: - raise ValueError(f"Cannot scalarize a TensorReduceFn that has not been " - f"bound to its lhs: {self}") - full_args = [self.lhs.to_scalar_expression() - ] + [arg.to_scalar_expression() for arg in self.args] - fn_name = None - attr_name = None - if self.reduce_use.binary_fn: - fn_name = self.reduce_use.binary_fn.fn_name - if self.reduce_use.binary_attr: - attr_name = self.reduce_use.binary_attr.operand_def.name - return ScalarFn(FunctionKind.BINARY, fn_name, attr_name, None, - full_args).expr() - - def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): - for arg in self.args: - arg.visit_tensor_exprs(callback) - - def __repr__(self): - return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})" + """Application of a reduction function. + + This captures the lhs (initial value) separately from the rhs. + """ + + def __init__(self, reduce_use: "ReduceFnUse", args: Sequence[TensorExpression]): + self.reduce_use = reduce_use + self.lhs = None # type: Optional[TensorUse] + self.args = args + + def to_scalar_expression(self) -> ScalarExpression: + if self.lhs is None: + raise ValueError( + f"Cannot scalarize a TensorReduceFn that has not been " + f"bound to its lhs: {self}" + ) + full_args = [self.lhs.to_scalar_expression()] + [ + arg.to_scalar_expression() for arg in self.args + ] + fn_name = None + attr_name = None + if self.reduce_use.binary_fn: + fn_name = self.reduce_use.binary_fn.fn_name + if self.reduce_use.binary_attr: + attr_name = self.reduce_use.binary_attr.operand_def.name + return ScalarFn(FunctionKind.BINARY, fn_name, attr_name, None, full_args).expr() + + def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): + for arg in self.args: + arg.visit_tensor_exprs(callback) + + def __repr__(self): + return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})" class const(TensorExpression): - """Returns the given constant floating point or integer value.""" + """Returns the given constant floating point or integer value.""" - def __init__(self, value: Any): - with _ir.Context(): - if isinstance(value, float): - self.value = str(_ir.FloatAttr.get_f64(float(value))) - elif isinstance(value, int): - self.value = str( - _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value))) - else: - raise ValueError(f"const requires int or float but got {type(value)}") + def __init__(self, value: Any): + with _ir.Context(): + if isinstance(value, float): + self.value = str(_ir.FloatAttr.get_f64(float(value))) + elif isinstance(value, int): + self.value = str( + _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value)) + ) + else: + raise ValueError(f"const requires int or float but got {type(value)}") - def to_scalar_expression(self) -> ScalarExpression: - return ScalarConst(self.value).expr() + def to_scalar_expression(self) -> ScalarExpression: + return ScalarConst(self.value).expr() - def __repr__(self): - return f"const({self.value})" + def __repr__(self): + return f"const({self.value})" class index(TensorExpression): - """Returns the iteration index for a given dimension name. + """Returns the iteration index for a given dimension name. - Resolves the given dimension name to obtain its position in the iteration - domain of the operation. 
- """ + Resolves the given dimension name to obtain its position in the iteration + domain of the operation. + """ - def __init__(self, dim: DimDef): - self.dim_def = dim - self.dim = -1 + def __init__(self, dim: DimDef): + self.dim_def = dim + self.dim = -1 - def resolve_dimension_name(self, affine_state: AffineBuildState): - self.dim = affine_state.get_dim(self.dim_def.dimname) + def resolve_dimension_name(self, affine_state: AffineBuildState): + self.dim = affine_state.get_dim(self.dim_def.dimname) - def to_scalar_expression(self) -> ScalarExpression: - assert self.dim != -1, "Dimension name not resolved" - return ScalarIndex(self.dim).expr() + def to_scalar_expression(self) -> ScalarExpression: + assert self.dim != -1, "Dimension name not resolved" + return ScalarIndex(self.dim).expr() - def __repr__(self): - return f"index({repr(self.dim)})" + def __repr__(self): + return f"index({repr(self.dim)})" ############################################################################### @@ -248,155 +257,160 @@ class index(TensorExpression): class FunctionKind(Enum): - UNARY = 0 - BINARY = 1 - TYPE = 2 + UNARY = 0 + BINARY = 1 + TYPE = 2 class UnaryFnType: - """Unary function. + """Unary function. - A unary function takes one tensor expression and returns the - function evaluation result. - """ + A unary function takes one tensor expression and returns the + function evaluation result. + """ - def __init__(self, fn_name: str): - self.fn_name = fn_name + def __init__(self, fn_name: str): + self.fn_name = fn_name - def __call__(self, arg: TensorExpression) -> "TensorFn": - return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [arg]) + def __call__(self, arg: TensorExpression) -> "TensorFn": + return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [arg]) - def __repr__(self): - return f"{self.fn_name}" + def __repr__(self): + return f"{self.fn_name}" class UnaryFn: - """Unary function namespace.""" - exp = UnaryFnType("exp") - log = UnaryFnType("log") - abs = UnaryFnType("abs") - ceil = UnaryFnType("ceil") - floor = UnaryFnType("floor") - negf = UnaryFnType("negf") + """Unary function namespace.""" + + exp = UnaryFnType("exp") + log = UnaryFnType("log") + abs = UnaryFnType("abs") + ceil = UnaryFnType("ceil") + floor = UnaryFnType("floor") + negf = UnaryFnType("negf") class BinaryFnType: - """Binary function. + """Binary function. - A binary function takes two tensor expressions and returns the - function evaluation result. - """ + A binary function takes two tensor expressions and returns the + function evaluation result. + """ - def __init__(self, fn_name: str): - self.fn_name = fn_name + def __init__(self, fn_name: str): + self.fn_name = fn_name - def __call__(self, arg0: TensorExpression, - arg1: TensorExpression) -> "TensorFn": - return TensorFn(FunctionKind.BINARY, self.fn_name, None, None, [arg0, arg1]) + def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> "TensorFn": + return TensorFn(FunctionKind.BINARY, self.fn_name, None, None, [arg0, arg1]) - def __repr__(self): - return f"{self.fn_name}" + def __repr__(self): + return f"{self.fn_name}" class BinaryFn: - """Binary function namespace. + """Binary function namespace. - As the integer types are signless, signedness is implement by different - functions that treat integers as signed or unsigned values. + As the integer types are signless, signedness is implement by different + functions that treat integers as signed or unsigned values. 
+ + Examples: + - max -> `arith.MaxSIOp` + - max_unsinged -> `arith.MaxUIOp` + """ - Examples: - - max -> `arith.MaxSIOp` - - max_unsinged -> `arith.MaxUIOp` - """ - add = BinaryFnType("add") - sub = BinaryFnType("sub") - mul = BinaryFnType("mul") - max_signed = BinaryFnType("max_signed") - min_signed = BinaryFnType("min_signed") - max_unsigned = BinaryFnType("max_unsigned") - min_unsigned = BinaryFnType("min_unsigned") + add = BinaryFnType("add") + sub = BinaryFnType("sub") + mul = BinaryFnType("mul") + max_signed = BinaryFnType("max_signed") + min_signed = BinaryFnType("min_signed") + max_unsigned = BinaryFnType("max_unsigned") + min_unsigned = BinaryFnType("min_unsigned") class TypeFnType: - """Type conversion function. + """Type conversion function. - A type conversion function takes a target type and a tensor expression and - returns the casted tensor expression. - """ + A type conversion function takes a target type and a tensor expression and + returns the casted tensor expression. + """ - def __init__(self, fn_name: str): - self.fn_name = fn_name + def __init__(self, fn_name: str): + self.fn_name = fn_name - def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn": - return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg]) + def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn": + return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg]) - def __repr__(self): - return f"{self.fn_name}" + def __repr__(self): + return f"{self.fn_name}" class TypeFn: - """Type conversion function namespace. + """Type conversion function namespace. + + As the integer types are signless, signedness is implement by different cast + functions that treat integers as signed (`cast_signed`) or unsigned + (`cast_unsigned`) values. - As the integer types are signless, signedness is implement by different cast - functions that treat integers as signed (`cast_signed`) or unsigned - (`cast_unsigned`) values. + Examples: + - cast_signed(I32 -> I64) -> `arith.ExtSIOp` + - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp` + """ - Examples: - - cast_signed(I32 -> I64) -> `arith.ExtSIOp` - - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp` - """ - cast_signed = TypeFnType("cast_signed") - cast_unsigned = TypeFnType("cast_unsigned") + cast_signed = TypeFnType("cast_signed") + cast_unsigned = TypeFnType("cast_unsigned") class ReduceFnUse: - """Reduction function use. + """Reduction function use. - A reduction use specifies the reduction function and dimensions. - """ + A reduction use specifies the reduction function and dimensions. 
+ """ - def __init__(self, binary_fn: Optional[BinaryFnType], - binary_attr: Optional["BinaryFnAttrDef"], *reduce_dims: DimDef): - if bool(binary_fn) + bool(binary_attr) != 1: - raise ValueError("One of 'binary_fn', 'binary_attr' must be specified") - self.binary_fn = binary_fn - self.binary_attr = binary_attr - self.reduce_dims = reduce_dims + def __init__( + self, + binary_fn: Optional[BinaryFnType], + binary_attr: Optional["BinaryFnAttrDef"], + *reduce_dims: DimDef, + ): + if bool(binary_fn) + bool(binary_attr) != 1: + raise ValueError("One of 'binary_fn', 'binary_attr' must be specified") + self.binary_fn = binary_fn + self.binary_attr = binary_attr + self.reduce_dims = reduce_dims - def __call__(self, *args: TensorExpression) -> "TensorReduceFn": - return TensorReduceFn(self, args) + def __call__(self, *args: TensorExpression) -> "TensorReduceFn": + return TensorReduceFn(self, args) - def __repr__(self): - fn = self.binary_fn if self.binary_fn else self.binary_attr - return ( - f"reduce_{repr(fn)}({', '.join(repr(d) for d in self.reduce_dims)})") + def __repr__(self): + fn = self.binary_fn if self.binary_fn else self.binary_attr + return f"reduce_{repr(fn)}({', '.join(repr(d) for d in self.reduce_dims)})" class ReduceFnType: - """Reduction function. + """Reduction function. - A binary function that reduces its RHS into its LHS. - """ + A binary function that reduces its RHS into its LHS. + """ - def __init__(self, binary_fn: BinaryFnType): - if not isinstance(binary_fn, BinaryFnType): - raise ValueError(f"Reduce expected a BinaryFnType but got {binary_fn}") - self.binary_fn = binary_fn + def __init__(self, binary_fn: BinaryFnType): + if not isinstance(binary_fn, BinaryFnType): + raise ValueError(f"Reduce expected a BinaryFnType but got {binary_fn}") + self.binary_fn = binary_fn - def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: - return ReduceFnUse(self.binary_fn, None, *reduce_dims) + def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: + return ReduceFnUse(self.binary_fn, None, *reduce_dims) - def __repr__(self): - return f"reduce_{repr(self.binary_fn)}" + def __repr__(self): + return f"reduce_{repr(self.binary_fn)}" class ReduceFn: - add = ReduceFnType(BinaryFn.add) - mul = ReduceFnType(BinaryFn.mul) - max_signed = ReduceFnType(BinaryFn.max_signed) - min_signed = ReduceFnType(BinaryFn.min_signed) - max_unsigned = ReduceFnType(BinaryFn.max_unsigned) - min_unsigned = ReduceFnType(BinaryFn.min_unsigned) + add = ReduceFnType(BinaryFn.add) + mul = ReduceFnType(BinaryFn.mul) + max_signed = ReduceFnType(BinaryFn.max_signed) + min_signed = ReduceFnType(BinaryFn.min_signed) + max_unsigned = ReduceFnType(BinaryFn.max_unsigned) + min_unsigned = ReduceFnType(BinaryFn.min_unsigned) ############################################################################### @@ -405,237 +419,265 @@ class ReduceFn: class OperandKind(Enum): - INPUT_TENSOR = 0 - SCALAR = 1 - OUTPUT_TENSOR = 2 - INDEX_ATTR = 3 - UNARY_FN_ATTR = 4 - BINARY_FN_ATTR = 5 - TYPE_FN_ATTR = 6 + INPUT_TENSOR = 0 + SCALAR = 1 + OUTPUT_TENSOR = 2 + INDEX_ATTR = 3 + UNARY_FN_ATTR = 4 + BINARY_FN_ATTR = 5 + TYPE_FN_ATTR = 6 class OperandDef: - """Definition of an operand passed to an operation. - - Keep the meta information of Tensor, Scalar, and Attribute operands and - provide the shared registration functionality. 
- """ - - def __init__(self, - kind: OperandKind, - type_var: Optional[TypeVar] = None, - size_exprs: Optional[Sequence[AffineExprDef]] = None, - index_dims: Optional[Sequence[DimDef]] = None, - default_indices: Optional[Sequence[int]] = None, - default_fn: Optional[str] = None): - if type_var and not isinstance(type_var, TypeVar): - raise ValueError( - f"OperandDef requires a TypeVar but got {repr(type_var)}") - self.owner = None # type: Optional["LinalgOpDef"] - self.type_var = type_var - self.size_exprs = size_exprs - self.index_dims = index_dims - self.default_indices = default_indices - self.default_fn = default_fn - self.kind = kind - self.name = None # type: Optional[str] - self.registered_index = -1 # type: int - - def attach(self, index: int, name: str, owner: "LinalgOpDef"): - if self.owner: - raise ValueError(f"OperandDef already registered with an op: {self}") - self.registered_index = index - self.name = name - self.owner = owner - - def is_input(self) -> bool: - return (self.kind == OperandKind.SCALAR or - self.kind == OperandKind.INPUT_TENSOR) - - def is_tensor(self) -> bool: - return (self.kind == OperandKind.INPUT_TENSOR or - self.kind == OperandKind.OUTPUT_TENSOR) - - def is_attribute(self) -> bool: - return (self.kind == OperandKind.INDEX_ATTR or - self.kind == OperandKind.UNARY_FN_ATTR or - self.kind == OperandKind.BINARY_FN_ATTR or - self.kind == OperandKind.TYPE_FN_ATTR) - - def __hash__(self): - return hash(id(self)) - - def __repr__(self): - return (f"{self.name}:OperandDef(kind={self.kind.name}, " + """Definition of an operand passed to an operation. + + Keep the meta information of Tensor, Scalar, and Attribute operands and + provide the shared registration functionality. + """ + + def __init__( + self, + kind: OperandKind, + type_var: Optional[TypeVar] = None, + size_exprs: Optional[Sequence[AffineExprDef]] = None, + index_dims: Optional[Sequence[DimDef]] = None, + default_indices: Optional[Sequence[int]] = None, + default_fn: Optional[str] = None, + ): + if type_var and not isinstance(type_var, TypeVar): + raise ValueError(f"OperandDef requires a TypeVar but got {repr(type_var)}") + self.owner = None # type: Optional["LinalgOpDef"] + self.type_var = type_var + self.size_exprs = size_exprs + self.index_dims = index_dims + self.default_indices = default_indices + self.default_fn = default_fn + self.kind = kind + self.name = None # type: Optional[str] + self.registered_index = -1 # type: int + + def attach(self, index: int, name: str, owner: "LinalgOpDef"): + if self.owner: + raise ValueError(f"OperandDef already registered with an op: {self}") + self.registered_index = index + self.name = name + self.owner = owner + + def is_input(self) -> bool: + return self.kind == OperandKind.SCALAR or self.kind == OperandKind.INPUT_TENSOR + + def is_tensor(self) -> bool: + return ( + self.kind == OperandKind.INPUT_TENSOR + or self.kind == OperandKind.OUTPUT_TENSOR + ) + + def is_attribute(self) -> bool: + return ( + self.kind == OperandKind.INDEX_ATTR + or self.kind == OperandKind.UNARY_FN_ATTR + or self.kind == OperandKind.BINARY_FN_ATTR + or self.kind == OperandKind.TYPE_FN_ATTR + ) + + def __hash__(self): + return hash(id(self)) + + def __repr__(self): + return ( + f"{self.name}:OperandDef(kind={self.kind.name}, " f"type={repr(self.type_var)}, size_exprs={self.size_exprs}, " f"index_dims={self.index_dims}, " f"default_indices={self.default_indices}, " - f"default_fn={self.default_fn})") + f"default_fn={self.default_fn})" + ) class TensorDef: - """Tensor operand definition. 
- - Tensor operands are indexed using the associated indexing_map when forwarded - to the body of the structured op. A unique name identifies the tensor operands - and an index determines their position in the operation's parameter list. A - tensor definition takes type, a shape, and an optional flag to mark output - tensors. Additionally, a tuple of index dimensions may be used to map the - tensor to the loop dimensions of the operation. This mapping is needed to - compute the indexing map of shape-only tensors that have no uses. - """ - - def __init__(self, - type_var: TypeVar, - *shape: AffineExprDef, - index_dims: Optional[Sequence[DimDef]] = None, - output: bool = False): - if index_dims and len(shape) != len(index_dims): - raise ValueError(f"Expected the shape rank {len(shape)} to match the " - f"number of index_dims {len(index_dims)}") - if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims): - raise ValueError(f"TensorDef requires index dims of type DimDef but " - f"got {index_dims}") - kind = OperandKind.OUTPUT_TENSOR if output else OperandKind.INPUT_TENSOR - self.operand_def = OperandDef( - kind, type_var=type_var, size_exprs=shape, index_dims=index_dims) - - def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse: - assert self.operand_def.owner, "TensorDef is not registered with an op" - state = AffineBuildState( - global_state=self.operand_def.owner._affine_state, - allow_new_symbols=False) - if not isinstance(dims, tuple): - dims = (dims,) # Handle single subscript case. - # Special case: (None) is a 0d-scalar use. - if dims == (None,): - dims = () - - exprs = [] - for expr_def in dims: - if not isinstance(expr_def, AffineExprDef): - raise KeyError( - "A TensorDef can only be subscripted by a tuple of affine dims") - exprs.append(expr_def) - return TensorUse(self.operand_def, exprs) - - def __setitem__(self, dims: Sequence[AffineExprDef], value: TensorExpression): - """Creates a new 1:1 comprehension by binding this tensor to an expression. - - Note that due to the way assignment works in Python, we have to capture - direct assignment as a setitem on the TensorDef. + """Tensor operand definition. + + Tensor operands are indexed using the associated indexing_map when forwarded + to the body of the structured op. A unique name identifies the tensor operands + and an index determines their position in the operation's parameter list. A + tensor definition takes type, a shape, and an optional flag to mark output + tensors. Additionally, a tuple of index dimensions may be used to map the + tensor to the loop dimensions of the operation. This mapping is needed to + compute the indexing map of shape-only tensors that have no uses. """ - if not isinstance(value, TensorExpression): - raise ValueError(f"Only TensorExpressions can be assigned to TensorDefs. 
" - f"Got: {repr(value)}") - use = self[dims] - comp = Comprehension((use, value)) - self.operand_def.owner.comprehensions.append(comp) + + def __init__( + self, + type_var: TypeVar, + *shape: AffineExprDef, + index_dims: Optional[Sequence[DimDef]] = None, + output: bool = False, + ): + if index_dims and len(shape) != len(index_dims): + raise ValueError( + f"Expected the shape rank {len(shape)} to match the " + f"number of index_dims {len(index_dims)}" + ) + if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims): + raise ValueError( + f"TensorDef requires index dims of type DimDef but " f"got {index_dims}" + ) + kind = OperandKind.OUTPUT_TENSOR if output else OperandKind.INPUT_TENSOR + self.operand_def = OperandDef( + kind, type_var=type_var, size_exprs=shape, index_dims=index_dims + ) + + def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse: + assert self.operand_def.owner, "TensorDef is not registered with an op" + state = AffineBuildState( + global_state=self.operand_def.owner._affine_state, allow_new_symbols=False + ) + if not isinstance(dims, tuple): + dims = (dims,) # Handle single subscript case. + # Special case: (None) is a 0d-scalar use. + if dims == (None,): + dims = () + + exprs = [] + for expr_def in dims: + if not isinstance(expr_def, AffineExprDef): + raise KeyError( + "A TensorDef can only be subscripted by a tuple of affine dims" + ) + exprs.append(expr_def) + return TensorUse(self.operand_def, exprs) + + def __setitem__(self, dims: Sequence[AffineExprDef], value: TensorExpression): + """Creates a new 1:1 comprehension by binding this tensor to an expression. + + Note that due to the way assignment works in Python, we have to capture + direct assignment as a setitem on the TensorDef. + """ + if not isinstance(value, TensorExpression): + raise ValueError( + f"Only TensorExpressions can be assigned to TensorDefs. " + f"Got: {repr(value)}" + ) + use = self[dims] + comp = Comprehension((use, value)) + self.operand_def.owner.comprehensions.append(comp) class ScalarDef(TensorExpression): - """Scalar operand definition. + """Scalar operand definition. - Scalar operands are forwarded to the body of the structured op as they are. - A unique name identifies the scalars and an index determines their position in - the operation's parameter list. - """ + Scalar operands are forwarded to the body of the structured op as they are. + A unique name identifies the scalars and an index determines their position in + the operation's parameter list. + """ - def __init__(self, type_var: TypeVar): - self.operand_def = OperandDef(OperandKind.SCALAR, type_var=type_var) + def __init__(self, type_var: TypeVar): + self.operand_def = OperandDef(OperandKind.SCALAR, type_var=type_var) - @property - def scalar_name(self) -> str: - name = self.operand_def.name - assert name is not None, "ScalarDef not registered with an op" - return name + @property + def scalar_name(self) -> str: + name = self.operand_def.name + assert name is not None, "ScalarDef not registered with an op" + return name - def to_scalar_expression(self) -> ScalarExpression: - return ScalarArg(self.scalar_name).expr() + def to_scalar_expression(self) -> ScalarExpression: + return ScalarArg(self.scalar_name).expr() class IndexAttrDef: - """Index attribute definition. - - Index attributes provide a way to define and set symbols that can be used in - indexing expressions. Every attribute specifies a tuple of symbols that at - compile-time are replaced by integer values as well as their default values. 
- """ - - def __init__(self, *sizes: SymbolDef, default: Sequence[int]): - if any(not isinstance(size, SymbolDef) for size in sizes): - raise ValueError(f"IndexAttrDef requires sizes of type SymbolDef " - f"but got {sizes}") - if any(not isinstance(default_val, int) for default_val in default): - raise ValueError(f"IndexAttrDef requires default values of type int " - f"but got {default}") - if len(sizes) != len(default): - raise ValueError(f"IndexAttrDef expects {len(sizes)} default values " - f"but got {len(default)}") - self.operand_def = OperandDef( - OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default) + """Index attribute definition. + + Index attributes provide a way to define and set symbols that can be used in + indexing expressions. Every attribute specifies a tuple of symbols that at + compile-time are replaced by integer values as well as their default values. + """ + + def __init__(self, *sizes: SymbolDef, default: Sequence[int]): + if any(not isinstance(size, SymbolDef) for size in sizes): + raise ValueError( + f"IndexAttrDef requires sizes of type SymbolDef " f"but got {sizes}" + ) + if any(not isinstance(default_val, int) for default_val in default): + raise ValueError( + f"IndexAttrDef requires default values of type int " + f"but got {default}" + ) + if len(sizes) != len(default): + raise ValueError( + f"IndexAttrDef expects {len(sizes)} default values " + f"but got {len(default)}" + ) + self.operand_def = OperandDef( + OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default + ) class UnaryFnAttrDef: - """Unary function attribute definition. + """Unary function attribute definition. - Unary function attributes provide a way to make the arithmetic computation - parametrizable. Every attribute specifies a default unary function - that may be overwritten at operation instantiation time. - """ + Unary function attributes provide a way to make the arithmetic computation + parametrizable. Every attribute specifies a default unary function + that may be overwritten at operation instantiation time. + """ - def __init__(self, default: "UnaryFnType"): - if not isinstance(default, UnaryFnType): - raise ValueError(f"UnaryFnAttrDef requires default of type UnaryFnType " - f"but got {default}") - self.operand_def = OperandDef( - OperandKind.UNARY_FN_ATTR, default_fn=default.fn_name) + def __init__(self, default: "UnaryFnType"): + if not isinstance(default, UnaryFnType): + raise ValueError( + f"UnaryFnAttrDef requires default of type UnaryFnType " + f"but got {default}" + ) + self.operand_def = OperandDef( + OperandKind.UNARY_FN_ATTR, default_fn=default.fn_name + ) - def __call__(self, arg: TensorExpression) -> TensorFn: - return TensorFn(FunctionKind.UNARY, None, self.operand_def, None, [arg]) + def __call__(self, arg: TensorExpression) -> TensorFn: + return TensorFn(FunctionKind.UNARY, None, self.operand_def, None, [arg]) class BinaryFnAttrDef: - """Binary function attribute definition. + """Binary function attribute definition. - Binary function attributes provide a way to make the arithmetic computation - parametrizable. Every attribute specifies a default binary function - that may be overwritten at operation instantiation time. - """ + Binary function attributes provide a way to make the arithmetic computation + parametrizable. Every attribute specifies a default binary function + that may be overwritten at operation instantiation time. 
+ """ - def __init__(self, default: "BinaryFnType"): - if not isinstance(default, BinaryFnType): - raise ValueError(f"BinaryFnAttrDef requires default of type BinaryFnType " - f"but got {default}") - self.operand_def = OperandDef( - OperandKind.BINARY_FN_ATTR, default_fn=default.fn_name) + def __init__(self, default: "BinaryFnType"): + if not isinstance(default, BinaryFnType): + raise ValueError( + f"BinaryFnAttrDef requires default of type BinaryFnType " + f"but got {default}" + ) + self.operand_def = OperandDef( + OperandKind.BINARY_FN_ATTR, default_fn=default.fn_name + ) - def __call__(self, arg0: TensorExpression, - arg1: TensorExpression) -> TensorFn: - return TensorFn(FunctionKind.BINARY, None, self.operand_def, None, - [arg0, arg1]) + def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> TensorFn: + return TensorFn(FunctionKind.BINARY, None, self.operand_def, None, [arg0, arg1]) - def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: - return ReduceFnUse(None, self, *reduce_dims) + def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: + return ReduceFnUse(None, self, *reduce_dims) class TypeFnAttrDef: - """Type conversion function attribute definition. + """Type conversion function attribute definition. - Type conversion function attributes provide a way to make type conversions - parameterizable. Every attribute specifies a default type conversion function - that may be overwritten at operation instantiation time. - """ + Type conversion function attributes provide a way to make type conversions + parameterizable. Every attribute specifies a default type conversion function + that may be overwritten at operation instantiation time. + """ - def __init__(self, default: "TypeFnType"): - if not isinstance(default, TypeFnType): - raise ValueError(f"TypeFnAttrDef requires default of type TypeFnType " - f"but got {default}") - self.operand_def = OperandDef( - OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name) + def __init__(self, default: "TypeFnType"): + if not isinstance(default, TypeFnType): + raise ValueError( + f"TypeFnAttrDef requires default of type TypeFnType " + f"but got {default}" + ) + self.operand_def = OperandDef( + OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name + ) - def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn: - return TensorFn(FunctionKind.TYPE, None, self.operand_def, type_var, [arg]) + def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn: + return TensorFn(FunctionKind.TYPE, None, self.operand_def, type_var, [arg]) ############################################################################### @@ -644,48 +686,48 @@ class TypeFnAttrDef: class Comprehension: - """Represents a single comprehension.""" - - def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]): - self.definitions = list() # List[TensorUse] - self.values = list() # List[TensorExpression] - - # Find the lhs to reduction rhs. 
- for assign, value in bindings: - if isinstance(value, TensorReduceFn): - if value.lhs: - raise ValueError(f"Reduction expression already assigns: {value}") - value.lhs = assign - self.definitions.append(assign) - self.values.append(value) - - @property - def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]: - """Gets the reduction dims for the comprehension or None.""" - result = set() - for use in self.values: - if isinstance(use, TensorReduceFn): - result.add(use.reduce_use.reduce_dims) - else: - result.add(tuple()) - return result - - def __repr__(self): - if len(self.definitions) > 1: - defs_repr = f"({', '.join(repr(d) for d in self.definitions)})" - values_repr = f"({', '.join(repr(v) for v in self.values)})" - else: - defs_repr = f"{repr(self.definitions[0])}" - values_repr = f"{repr(self.values[0])}" - - return f"{defs_repr} = {values_repr}" + """Represents a single comprehension.""" + + def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]): + self.definitions = list() # List[TensorUse] + self.values = list() # List[TensorExpression] + + # Find the lhs to reduction rhs. + for assign, value in bindings: + if isinstance(value, TensorReduceFn): + if value.lhs: + raise ValueError(f"Reduction expression already assigns: {value}") + value.lhs = assign + self.definitions.append(assign) + self.values.append(value) + + @property + def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]: + """Gets the reduction dims for the comprehension or None.""" + result = set() + for use in self.values: + if isinstance(use, TensorReduceFn): + result.add(use.reduce_use.reduce_dims) + else: + result.add(tuple()) + return result + + def __repr__(self): + if len(self.definitions) > 1: + defs_repr = f"({', '.join(repr(d) for d in self.definitions)})" + values_repr = f"({', '.join(repr(v) for v in self.values)})" + else: + defs_repr = f"{repr(self.definitions[0])}" + values_repr = f"{repr(self.values[0])}" + + return f"{defs_repr} = {values_repr}" class OpInterfaceDef: - """An interface that an op implements.""" + """An interface that an op implements.""" - def __init__(self, cpp_name: str): - self.cpp_name = cpp_name + def __init__(self, cpp_name: str): + self.cpp_name = cpp_name ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface") @@ -694,86 +736,94 @@ FillOpInterface = OpInterfaceDef("LinalgFillOpInterface") class OpDefinitionDef: - """A method that an op implements.""" + """A method that an op implements.""" - def __init__(self, def_name: str): - self.def_name = def_name + def __init__(self, def_name: str): + self.def_name = def_name Canonicalizer = OpDefinitionDef("hasCanonicalizer") class OpMetadataDef(YAMLObject): - """Metadata about the op (generally not behavior impacting).""" - yaml_tag = "!LinalgOpMetadata" - - def __init__(self, name: str, cpp_class_name: Optional[str], - doc: Optional[str]): - self.name = name - self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name - self.doc = doc - self.implements = [] # type: List[OpInterfaceDef] - self.defines = [] # type: List[OpDefinitionsDef] - - def to_yaml_custom_dict(self): - d = dict( - name=self.name, - cpp_class_name=self.cpp_class_name, - doc=self.doc, - ) - if self.implements: - d["implements"] = [intr.cpp_name for intr in self.implements] - if self.defines: - d["defines"] = [defi.def_name for defi in self.defines] - return d + """Metadata about the op (generally not behavior impacting).""" + + yaml_tag = "!LinalgOpMetadata" + + def __init__(self, name: str, cpp_class_name: Optional[str], 
doc: Optional[str]): + self.name = name + self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name + self.doc = doc + self.implements = [] # type: List[OpInterfaceDef] + self.defines = [] # type: List[OpDefinitionsDef] + + def to_yaml_custom_dict(self): + d = dict( + name=self.name, + cpp_class_name=self.cpp_class_name, + doc=self.doc, + ) + if self.implements: + d["implements"] = [intr.cpp_name for intr in self.implements] + if self.defines: + d["defines"] = [defi.def_name for defi in self.defines] + return d class LinalgOpDef: - """Definition of a linalg op.""" - - def __init__(self, - name: str, - cpp_class_name: Optional[str] = None, - doc: Optional[str] = None): - self.metadata = OpMetadataDef( - name=name, cpp_class_name=cpp_class_name, doc=doc) - self.registered_operands = dict() # type: Dict[str, OperandDef] - self.domain = list() # type: List[DimDef] - self.comprehensions = list() # type: List[Comprehension] - self._affine_state = AffineBuildState() - - def add_operand(self, name: str, operand: OperandDef): - """Registers an operand.""" - if name in self.registered_operands: - raise ValueError(f"The operand {name} is already registered " - f"to {self.registered_operands['name']}") - structured_op_methods = [ - "inputs", "outputs", "result_tensors", "region", "iterator_types", - "indexing_maps", "getRegionBuilder", "getLibraryCallName" - ] - if operand.is_attribute() and name in structured_op_methods: - raise ValueError(f"The attribute name {name} conflicts with a structured " - f"op method name") - # Ensure output tensors are registered after input tensors and scalars and - # attributes are registered after all other operand types. - if operand.is_input() and any( - not op_def.is_input() for op_def in self.registered_operands.values()): - raise ValueError(f"Input {name} registered after an output or attribute") - if operand.kind == OperandKind.OUTPUT_TENSOR and any( - op_def.is_attribute() for op_def in self.registered_operands.values()): - raise ValueError(f"Output {name} registered after an attribute") - operand.attach(len(self.registered_operands), name, self) - self.registered_operands[name] = operand - - def __repr__(self): - lines = [ - f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name}," - ] - for name, operand in self.registered_operands.items(): - lines.append(f" {operand}") - if self.comprehensions: - lines[-1] += " {" - for comprehension in self.comprehensions: - lines.append(f" {comprehension}") - lines.append("}") - return "\n".join(lines) + """Definition of a linalg op.""" + + def __init__( + self, name: str, cpp_class_name: Optional[str] = None, doc: Optional[str] = None + ): + self.metadata = OpMetadataDef(name=name, cpp_class_name=cpp_class_name, doc=doc) + self.registered_operands = dict() # type: Dict[str, OperandDef] + self.domain = list() # type: List[DimDef] + self.comprehensions = list() # type: List[Comprehension] + self._affine_state = AffineBuildState() + + def add_operand(self, name: str, operand: OperandDef): + """Registers an operand.""" + if name in self.registered_operands: + raise ValueError( + f"The operand {name} is already registered " + f"to {self.registered_operands['name']}" + ) + structured_op_methods = [ + "inputs", + "outputs", + "result_tensors", + "region", + "iterator_types", + "indexing_maps", + "getRegionBuilder", + "getLibraryCallName", + ] + if operand.is_attribute() and name in structured_op_methods: + raise ValueError( + f"The attribute name {name} conflicts with a structured " + f"op method 
name" + ) + # Ensure output tensors are registered after input tensors and scalars and + # attributes are registered after all other operand types. + if operand.is_input() and any( + not op_def.is_input() for op_def in self.registered_operands.values() + ): + raise ValueError(f"Input {name} registered after an output or attribute") + if operand.kind == OperandKind.OUTPUT_TENSOR and any( + op_def.is_attribute() for op_def in self.registered_operands.values() + ): + raise ValueError(f"Output {name} registered after an attribute") + operand.attach(len(self.registered_operands), name, self) + self.registered_operands[name] = operand + + def __repr__(self): + lines = [f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name},"] + for name, operand in self.registered_operands.items(): + lines.append(f" {operand}") + if self.comprehensions: + lines[-1] += " {" + for comprehension in self.comprehensions: + lines.append(f" {comprehension}") + lines.append("}") + return "\n".join(lines) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py index 2a0da68..d522d57 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py @@ -21,422 +21,468 @@ __all__ = ["LinalgStructuredOpConfig", "LinalgOpConfig", "OperandDefConfig"] def _serialize_affine_map(affine_map: _ir.AffineMap) -> str: - with affine_map.context: - # Affine map printing/parsing is via an AffineMap attr. - attr = _ir.AffineMapAttr.get(affine_map) - return str(attr) + with affine_map.context: + # Affine map printing/parsing is via an AffineMap attr. + attr = _ir.AffineMapAttr.get(affine_map) + return str(attr) class TensorUseConfig: - """Wrapper around a TensorUse with additional context-bound state.""" + """Wrapper around a TensorUse with additional context-bound state.""" - def __init__(self, tensor_use: TensorUse, indexing_map: _ir.AffineMap): - self.tensor_use = tensor_use - self.indexing_map = indexing_map + def __init__(self, tensor_use: TensorUse, indexing_map: _ir.AffineMap): + self.tensor_use = tensor_use + self.indexing_map = indexing_map - def __repr__(self): - return f"Use({self.tensor_use}, indexing_map={self.indexing_map})" + def __repr__(self): + return f"Use({self.tensor_use}, indexing_map={self.indexing_map})" class OperandDefConfig(YAMLObject): - """Wrapper containing an operand definition with additional state.""" - yaml_tag = "!LinalgOperandDefConfig" - - def __init__(self, - operand_def: OperandDef, - shape_map: Optional[_ir.AffineMap] = None, - index_attr_map: Optional[_ir.AffineMap] = None): - self.operand_def = operand_def - self.shape_map = shape_map # type: Optional[_ir.AffineMap] - self.index_attr_map = index_attr_map # type: Optional[_ir.AffineMap] - self.indexing_map = None # type: Optional[_ir.AffineMap] - - @property - def name(self) -> str: - return self.operand_def.name - - @property - def kind(self) -> OperandKind: - return self.operand_def.kind - - @property - def type_var(self) -> TypeVar: - return self.operand_def.type_var - - def to_yaml_custom_dict(self): - self_dict = dict(name=self.name, kind=self.operand_def.kind.name.lower()) - if self.type_var: - self_dict["type_var"] = self.type_var.name - if self.shape_map: - self_dict["shape_map"] = _serialize_affine_map(self.shape_map) - if self.index_attr_map: - self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map) - if self.operand_def.default_indices: - self_dict["default_indices"] = 
self.operand_def.default_indices - if self.operand_def.default_fn: - self_dict["default_fn"] = self.operand_def.default_fn - return self_dict - - def __repr__(self): - return (f"OperandDefConfig({self.operand_def}, " + """Wrapper containing an operand definition with additional state.""" + + yaml_tag = "!LinalgOperandDefConfig" + + def __init__( + self, + operand_def: OperandDef, + shape_map: Optional[_ir.AffineMap] = None, + index_attr_map: Optional[_ir.AffineMap] = None, + ): + self.operand_def = operand_def + self.shape_map = shape_map # type: Optional[_ir.AffineMap] + self.index_attr_map = index_attr_map # type: Optional[_ir.AffineMap] + self.indexing_map = None # type: Optional[_ir.AffineMap] + + @property + def name(self) -> str: + return self.operand_def.name + + @property + def kind(self) -> OperandKind: + return self.operand_def.kind + + @property + def type_var(self) -> TypeVar: + return self.operand_def.type_var + + def to_yaml_custom_dict(self): + self_dict = dict(name=self.name, kind=self.operand_def.kind.name.lower()) + if self.type_var: + self_dict["type_var"] = self.type_var.name + if self.shape_map: + self_dict["shape_map"] = _serialize_affine_map(self.shape_map) + if self.index_attr_map: + self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map) + if self.operand_def.default_indices: + self_dict["default_indices"] = self.operand_def.default_indices + if self.operand_def.default_fn: + self_dict["default_fn"] = self.operand_def.default_fn + return self_dict + + def __repr__(self): + return ( + f"OperandDefConfig({self.operand_def}, " f"shape_map={self.shape_map}, " f"index_attr_map={self.index_attr_map}, " - f"indexing_map={self.indexing_map})") + f"indexing_map={self.indexing_map})" + ) class LinalgIndexingMapsConfig(YAMLObject): - """Abstracts the style of indexing maps that the op exports. - - Presently only static (tied to the op name) indexing maps are supported. In - the future, it is expected that we will have additional variants: - - Dynamic based on attributes - - Dynamic based on operands - Each is expected to require a different variant of specification. - """ - yaml_tag = "!LinalgIndexingMapsConfig" - - def __init__(self, - static_indexing_maps: Optional[Sequence[_ir.AffineMap]] = None): - self.static_indexing_maps = static_indexing_maps - - def to_yaml_custom_dict(self): - if self.static_indexing_maps is not None: - return dict(static_indexing_maps=[ - _serialize_affine_map(m) for m in self.static_indexing_maps - ]) - raise ValueError( - f"LinalgIndexingMapsConfig must have one type of indexing map" - f"(got none)") + """Abstracts the style of indexing maps that the op exports. + Presently only static (tied to the op name) indexing maps are supported. In + the future, it is expected that we will have additional variants: + - Dynamic based on attributes + - Dynamic based on operands + Each is expected to require a different variant of specification. 
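As a small aside, the string produced by _serialize_affine_map above is the textual AffineMapAttr syntax that ends up in the generated YAML. A quick sketch with the mlir.ir bindings; the output shown in the comment is illustrative:

from mlir import ir

with ir.Context():
    m = ir.AffineMap.get_permutation([1, 0])
    # Round-tripping through AffineMapAttr yields the serialized form,
    # e.g. affine_map<(d0, d1) -> (d1, d0)>.
    print(ir.AffineMapAttr.get(m))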
+ """ -class LinalgStructuredOpConfig(YAMLObject): - """Configuration for metadata sufficient to construct a linalg named op.""" - - yaml_tag = "!LinalgStructuredOpConfig" - - def __init__(self, - comprehension: Comprehension, - domain: Sequence[DimDef], - registered_operands: Sequence[OperandDef], - context: Optional[_ir.Context] = None): - self.context = context if context is not None else _ir.Context() - self.affine_state = AffineBuildState() - self.writes = list() # type: List[Tuple[TensorUse, TensorExpression]] - self.operands = dict() # type: Dict[OperandDef, OperandDefConfig] - self.uses = dict() # type: Dict[TensorUse, TensorUseConfig] - - # Compute the ordered set of writes and collect the tensor, capture, dims, - # and index uses. - collected_tensor_uses = set() - collected_scalar_uses = set() - collected_dim_uses = set() - collected_indices = set() - for write_use, read_use in zip(comprehension.definitions, - comprehension.values): - self.writes.append((write_use, read_use)) - - for write_use, read_use in self.writes: - collected_tensor_uses.add(write_use) - read_use.collect_tensor_uses(collected_tensor_uses) - read_use.collect_scalar_uses(collected_scalar_uses) - read_use.collect_dim_uses(collected_dim_uses) - write_use.collect_dim_uses(collected_dim_uses) - read_use.collect_indices(collected_indices) - - # Set domain to the sorted list of uses if no domain annotation is given. - if not domain: - domain = sorted(collected_dim_uses, key=lambda dim: dim.dimname) - - # Verify the domain dimensions match the used dimensions. - if (len(domain) != len(collected_dim_uses) or - any(dim not in collected_dim_uses for dim in domain)): - raise ValueError(f"Expected the annotated domain dimensions {domain} to " - f"match the set of dimension used by the tensor " - f"comprehension {collected_dim_uses}") - - # Instantiate the dimensions in the given order. - with self.context: - local_state = AffineBuildState( - global_state=self.affine_state, allow_new_symbols=False) - for dim in domain: - dim.build(state=local_state) - - # Collect all attribute definitions. - collected_attr_defs = list() - for operand in registered_operands: - if operand.is_attribute(): - collected_attr_defs.append(operand) - - # Collect all tensors with manual indexing annotation. - collected_index_defs = list() - for operand in registered_operands: - if operand.index_dims: - if any(dim not in collected_dim_uses for dim in operand.index_dims): - raise ValueError(f"Expected all index dims {operand.index_dims} of " - f"operand {operand.name} to have uses.") - collected_index_defs.append(operand) - - # Collect the operand definitions of all tensor/scalar uses, attributes, and - # shape-only tensors. - all_operand_defs = list() - for use in collected_tensor_uses: - all_operand_defs.append(use.operand_def) - for use in collected_scalar_uses: - all_operand_defs.append(use.operand_def) - for definition in collected_attr_defs: - all_operand_defs.append(definition) - for definition in collected_index_defs: - all_operand_defs.append(definition) - - # Add all operands in registration order to ensure the symbols are - # registered in the order they appear. - all_operand_defs = sorted( - all_operand_defs, key=lambda operand_def: operand_def.registered_index) - for operand_def in all_operand_defs: - self.add_operand(operand_def) - - # Add all shape-only tensor index_dim annotations and all tensor uses. 
- for definition in collected_index_defs: - self.add_indexed_operand(definition) - for use in collected_tensor_uses: - self.add_tensor_use(use) - - # Normalize all shape and indexing maps now that full count of dims and - # symbols are known. - for cuse in self.uses.values(): - cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map) - for definition in collected_index_defs: - self.operands[definition].indexing_map = self._normalize_affine_map( - self.operands[definition].indexing_map) - for operand_config in self.operands.values(): - if operand_config.shape_map: - operand_config.shape_map = self._normalize_affine_map( - operand_config.shape_map, with_dims=False) - if operand_config.index_attr_map: - operand_config.index_attr_map = self._normalize_affine_map( - operand_config.index_attr_map, with_dims=False) - - # Now for each write use, propagate the indexing maps from the use to the - # tensor, ensuring that there are not conflicts. - for write_use, _ in self.writes: - write_tensor_config = self.operands[write_use.operand_def] - if write_tensor_config.indexing_map: - raise ValueError( - f"Unexpected multi-write to a single tensor: {write_tensor_config}") - write_tensor_config.indexing_map = self.uses[write_use].indexing_map - - # For each read use, propagate the indexing maps from the use to the - # tensor, ensuring that there are not conflicts. - for _, read_expr in self.writes: - read_uses = set() # type: Set[TensorUse] - read_expr.collect_tensor_uses(read_uses) - for read_use in read_uses: - read_operand_config = self.operands[read_use.operand_def] - if (read_operand_config.indexing_map and - read_operand_config.indexing_map != - self.uses[read_use].indexing_map): - raise ValueError( - f"Unexpected multi-read of a tensor with different accesses:" - f"{read_operand_config} vs {read_use}") - read_operand_config.indexing_map = self.uses[read_use].indexing_map - - # Set the indexing map of all scalar uses to the empty map. - for operand_config in self.operands.values(): - if operand_config.operand_def.kind == OperandKind.SCALAR: - operand_config.indexing_map = self._get_scalar_map() - - # Check all registered tensor and scalar operands have an indexing map. - for operand in registered_operands: - if operand.is_attribute(): - continue - if not (operand in self.operands and self.operands[operand].indexing_map): - raise ValueError(f"Failed to compute an indexing map for operand " - f"{operand.name}") - - # Collect reduction dims and ensure all the same. - all_reduction_dims = set(comprehension.all_reduction_dims) - if len(all_reduction_dims) != 1: - raise ValueError( - f"All writes within a generic must have the same reduction " - f"dims. Got: {all_reduction_dims}") - self.reduction_dims = next(iter(all_reduction_dims)) - - # Check the index dimension exists and resolve. - for index in collected_indices: - if index.dim_def.dimname not in self.affine_state.all_dims: + yaml_tag = "!LinalgIndexingMapsConfig" + + def __init__(self, static_indexing_maps: Optional[Sequence[_ir.AffineMap]] = None): + self.static_indexing_maps = static_indexing_maps + + def to_yaml_custom_dict(self): + if self.static_indexing_maps is not None: + return dict( + static_indexing_maps=[ + _serialize_affine_map(m) for m in self.static_indexing_maps + ] + ) raise ValueError( - f"The dimension {index.dim_def.dimname} is not part of the " - f"iteration domain {self.affine_state.all_dims}") - index.resolve_dimension_name(self.affine_state) - - # Generate the scalar assignments (used to build a body). 
- self.assignments = [ - ScalarAssign(write_use.tensor_name, read_expr.to_scalar_expression()) - for write_use, read_expr in self.writes - ] - - @property - def ordered_operands(self) -> Sequence[OperandDefConfig]: - return sorted( - self.operands.values(), - key=lambda operand: operand.operand_def.registered_index) - - @property - def ordered_dims(self) -> Sequence[Tuple[str, int]]: - """Gets the ordered list of dim bindings (symbolic name, position). - - TODO: The original parser relies on parse ordering to arrive at the - iterator types, but that ordering is not defined on the Python side, so - this may be ambiguous. - """ - return list(self.affine_state.all_dims.items()) - - @property - def indexing_maps(self) -> Sequence[_ir.AffineMap]: - return [o.indexing_map for o in self.ordered_operands if o.indexing_map] - - @property - def iterator_types(self) -> Sequence[str]: - - def get_type(symbolic_name, position): - for reduction_dim_expr in self.reduction_dims: - if reduction_dim_expr.dimname == symbolic_name: - return "reduction" - return "parallel" - - return [get_type(*dim) for dim in self.ordered_dims] - - def add_operand(self, operand_def: OperandDef): - if operand_def in self.operands: - return - if not (operand_def.is_tensor() or - operand_def.kind == OperandKind.INDEX_ATTR): - self.operands[operand_def] = OperandDefConfig(operand_def) - return - with self.context: - local_state = AffineBuildState( - global_state=self.affine_state, allow_new_dims=False) - exprs = [] - for expr in operand_def.size_exprs: - exprs.append(expr.build(state=local_state)) - assert local_state.local_dim_count == 0 - affine_map = _ir.AffineMap.get( - dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs) - if operand_def.kind == OperandKind.INDEX_ATTR: - self.operands[operand_def] = OperandDefConfig( - operand_def, index_attr_map=affine_map) - else: - self.operands[operand_def] = OperandDefConfig( - operand_def, shape_map=affine_map) - - def add_indexed_operand(self, operand_def: OperandDef): - with self.context: - local_state = AffineBuildState( - global_state=self.affine_state, allow_new_symbols=False) - exprs = [] - for expr in operand_def.index_dims: - exprs.append(expr.build(state=local_state)) - self.operands[operand_def].indexing_map = _ir.AffineMap.get( - dim_count=local_state.dim_count, - symbol_count=local_state.symbol_count, - exprs=exprs) - - def add_tensor_use(self, tensor_use: TensorUse): - if tensor_use in self.uses: - return - with self.context: - local_state = AffineBuildState( - global_state=self.affine_state, allow_new_symbols=False) - exprs = [] - for expr in tensor_use.indices: - exprs.append(expr.build(state=local_state)) - indexing_map = _ir.AffineMap.get( - dim_count=local_state.dim_count, - symbol_count=local_state.symbol_count, - exprs=exprs) - - use_config = TensorUseConfig(tensor_use, indexing_map) - self.uses[tensor_use] = use_config - - def _get_scalar_map(self) -> _ir.AffineMap: - """Create an empty affine map used to index a scalar.""" - with self.context: - return _ir.AffineMap.get( - dim_count=self.affine_state.dim_count, - symbol_count=self.affine_state.symbol_count, - exprs=list()) - - def _normalize_affine_map(self, - affine_map: _ir.AffineMap, - with_dims: bool = True) -> _ir.AffineMap: - """Normalizes an indexing map to have the max known symbols and dims.""" - with self.context: - return _ir.AffineMap.get( - dim_count=self.affine_state.dim_count if with_dims else 0, - symbol_count=self.affine_state.symbol_count, - exprs=list(affine_map.results)) - - def 
to_yaml_custom_dict(self): - self_dict = dict(args=self.ordered_operands) - # TODO: Refactor the hierarchy internally when supporting more - # than static (preserving this serialized form). - self_dict["indexing_maps"] = LinalgIndexingMapsConfig( - static_indexing_maps=self.indexing_maps) - self_dict["iterator_types"] = self.iterator_types - self_dict["assignments"] = self.assignments - return self_dict - - def __repr__(self): - lines = [f"LinalgGenericOpConfig(reduction_dims={self.reduction_dims},"] - lines.append("operands=[") - for def_config in self.ordered_operands: - lines.append(f" {repr(def_config)}") - lines.append("], indexing_maps=[") - for m in self.indexing_maps: - lines.append(f" {repr(m)}") - lines.append(f"], iterator_types=[") - for t in self.iterator_types: - lines.append(f" {t}") - lines.append("])") - return "\n".join(lines) + f"LinalgIndexingMapsConfig must have one type of indexing map" f"(got none)" + ) + + +class LinalgStructuredOpConfig(YAMLObject): + """Configuration for metadata sufficient to construct a linalg named op.""" + + yaml_tag = "!LinalgStructuredOpConfig" + + def __init__( + self, + comprehension: Comprehension, + domain: Sequence[DimDef], + registered_operands: Sequence[OperandDef], + context: Optional[_ir.Context] = None, + ): + self.context = context if context is not None else _ir.Context() + self.affine_state = AffineBuildState() + self.writes = list() # type: List[Tuple[TensorUse, TensorExpression]] + self.operands = dict() # type: Dict[OperandDef, OperandDefConfig] + self.uses = dict() # type: Dict[TensorUse, TensorUseConfig] + + # Compute the ordered set of writes and collect the tensor, capture, dims, + # and index uses. + collected_tensor_uses = set() + collected_scalar_uses = set() + collected_dim_uses = set() + collected_indices = set() + for write_use, read_use in zip(comprehension.definitions, comprehension.values): + self.writes.append((write_use, read_use)) + + for write_use, read_use in self.writes: + collected_tensor_uses.add(write_use) + read_use.collect_tensor_uses(collected_tensor_uses) + read_use.collect_scalar_uses(collected_scalar_uses) + read_use.collect_dim_uses(collected_dim_uses) + write_use.collect_dim_uses(collected_dim_uses) + read_use.collect_indices(collected_indices) + + # Set domain to the sorted list of uses if no domain annotation is given. + if not domain: + domain = sorted(collected_dim_uses, key=lambda dim: dim.dimname) + + # Verify the domain dimensions match the used dimensions. + if len(domain) != len(collected_dim_uses) or any( + dim not in collected_dim_uses for dim in domain + ): + raise ValueError( + f"Expected the annotated domain dimensions {domain} to " + f"match the set of dimension used by the tensor " + f"comprehension {collected_dim_uses}" + ) + + # Instantiate the dimensions in the given order. + with self.context: + local_state = AffineBuildState( + global_state=self.affine_state, allow_new_symbols=False + ) + for dim in domain: + dim.build(state=local_state) + + # Collect all attribute definitions. + collected_attr_defs = list() + for operand in registered_operands: + if operand.is_attribute(): + collected_attr_defs.append(operand) + + # Collect all tensors with manual indexing annotation. + collected_index_defs = list() + for operand in registered_operands: + if operand.index_dims: + if any(dim not in collected_dim_uses for dim in operand.index_dims): + raise ValueError( + f"Expected all index dims {operand.index_dims} of " + f"operand {operand.name} to have uses." 
+ ) + collected_index_defs.append(operand) + + # Collect the operand definitions of all tensor/scalar uses, attributes, and + # shape-only tensors. + all_operand_defs = list() + for use in collected_tensor_uses: + all_operand_defs.append(use.operand_def) + for use in collected_scalar_uses: + all_operand_defs.append(use.operand_def) + for definition in collected_attr_defs: + all_operand_defs.append(definition) + for definition in collected_index_defs: + all_operand_defs.append(definition) + + # Add all operands in registration order to ensure the symbols are + # registered in the order they appear. + all_operand_defs = sorted( + all_operand_defs, key=lambda operand_def: operand_def.registered_index + ) + for operand_def in all_operand_defs: + self.add_operand(operand_def) + + # Add all shape-only tensor index_dim annotations and all tensor uses. + for definition in collected_index_defs: + self.add_indexed_operand(definition) + for use in collected_tensor_uses: + self.add_tensor_use(use) + + # Normalize all shape and indexing maps now that full count of dims and + # symbols are known. + for cuse in self.uses.values(): + cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map) + for definition in collected_index_defs: + self.operands[definition].indexing_map = self._normalize_affine_map( + self.operands[definition].indexing_map + ) + for operand_config in self.operands.values(): + if operand_config.shape_map: + operand_config.shape_map = self._normalize_affine_map( + operand_config.shape_map, with_dims=False + ) + if operand_config.index_attr_map: + operand_config.index_attr_map = self._normalize_affine_map( + operand_config.index_attr_map, with_dims=False + ) + + # Now for each write use, propagate the indexing maps from the use to the + # tensor, ensuring that there are not conflicts. + for write_use, _ in self.writes: + write_tensor_config = self.operands[write_use.operand_def] + if write_tensor_config.indexing_map: + raise ValueError( + f"Unexpected multi-write to a single tensor: {write_tensor_config}" + ) + write_tensor_config.indexing_map = self.uses[write_use].indexing_map + + # For each read use, propagate the indexing maps from the use to the + # tensor, ensuring that there are not conflicts. + for _, read_expr in self.writes: + read_uses = set() # type: Set[TensorUse] + read_expr.collect_tensor_uses(read_uses) + for read_use in read_uses: + read_operand_config = self.operands[read_use.operand_def] + if ( + read_operand_config.indexing_map + and read_operand_config.indexing_map + != self.uses[read_use].indexing_map + ): + raise ValueError( + f"Unexpected multi-read of a tensor with different accesses:" + f"{read_operand_config} vs {read_use}" + ) + read_operand_config.indexing_map = self.uses[read_use].indexing_map + + # Set the indexing map of all scalar uses to the empty map. + for operand_config in self.operands.values(): + if operand_config.operand_def.kind == OperandKind.SCALAR: + operand_config.indexing_map = self._get_scalar_map() + + # Check all registered tensor and scalar operands have an indexing map. + for operand in registered_operands: + if operand.is_attribute(): + continue + if not (operand in self.operands and self.operands[operand].indexing_map): + raise ValueError( + f"Failed to compute an indexing map for operand " f"{operand.name}" + ) + + # Collect reduction dims and ensure all the same. 
+ all_reduction_dims = set(comprehension.all_reduction_dims) + if len(all_reduction_dims) != 1: + raise ValueError( + f"All writes within a generic must have the same reduction " + f"dims. Got: {all_reduction_dims}" + ) + self.reduction_dims = next(iter(all_reduction_dims)) + + # Check the index dimension exists and resolve. + for index in collected_indices: + if index.dim_def.dimname not in self.affine_state.all_dims: + raise ValueError( + f"The dimension {index.dim_def.dimname} is not part of the " + f"iteration domain {self.affine_state.all_dims}" + ) + index.resolve_dimension_name(self.affine_state) + + # Generate the scalar assignments (used to build a body). + self.assignments = [ + ScalarAssign(write_use.tensor_name, read_expr.to_scalar_expression()) + for write_use, read_expr in self.writes + ] + + @property + def ordered_operands(self) -> Sequence[OperandDefConfig]: + return sorted( + self.operands.values(), + key=lambda operand: operand.operand_def.registered_index, + ) + + @property + def ordered_dims(self) -> Sequence[Tuple[str, int]]: + """Gets the ordered list of dim bindings (symbolic name, position). + + TODO: The original parser relies on parse ordering to arrive at the + iterator types, but that ordering is not defined on the Python side, so + this may be ambiguous. + """ + return list(self.affine_state.all_dims.items()) + + @property + def indexing_maps(self) -> Sequence[_ir.AffineMap]: + return [o.indexing_map for o in self.ordered_operands if o.indexing_map] + + @property + def iterator_types(self) -> Sequence[str]: + def get_type(symbolic_name, position): + for reduction_dim_expr in self.reduction_dims: + if reduction_dim_expr.dimname == symbolic_name: + return "reduction" + return "parallel" + + return [get_type(*dim) for dim in self.ordered_dims] + + def add_operand(self, operand_def: OperandDef): + if operand_def in self.operands: + return + if not (operand_def.is_tensor() or operand_def.kind == OperandKind.INDEX_ATTR): + self.operands[operand_def] = OperandDefConfig(operand_def) + return + with self.context: + local_state = AffineBuildState( + global_state=self.affine_state, allow_new_dims=False + ) + exprs = [] + for expr in operand_def.size_exprs: + exprs.append(expr.build(state=local_state)) + assert local_state.local_dim_count == 0 + affine_map = _ir.AffineMap.get( + dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs + ) + if operand_def.kind == OperandKind.INDEX_ATTR: + self.operands[operand_def] = OperandDefConfig( + operand_def, index_attr_map=affine_map + ) + else: + self.operands[operand_def] = OperandDefConfig( + operand_def, shape_map=affine_map + ) + + def add_indexed_operand(self, operand_def: OperandDef): + with self.context: + local_state = AffineBuildState( + global_state=self.affine_state, allow_new_symbols=False + ) + exprs = [] + for expr in operand_def.index_dims: + exprs.append(expr.build(state=local_state)) + self.operands[operand_def].indexing_map = _ir.AffineMap.get( + dim_count=local_state.dim_count, + symbol_count=local_state.symbol_count, + exprs=exprs, + ) + + def add_tensor_use(self, tensor_use: TensorUse): + if tensor_use in self.uses: + return + with self.context: + local_state = AffineBuildState( + global_state=self.affine_state, allow_new_symbols=False + ) + exprs = [] + for expr in tensor_use.indices: + exprs.append(expr.build(state=local_state)) + indexing_map = _ir.AffineMap.get( + dim_count=local_state.dim_count, + symbol_count=local_state.symbol_count, + exprs=exprs, + ) + + use_config = 
TensorUseConfig(tensor_use, indexing_map) + self.uses[tensor_use] = use_config + + def _get_scalar_map(self) -> _ir.AffineMap: + """Create an empty affine map used to index a scalar.""" + with self.context: + return _ir.AffineMap.get( + dim_count=self.affine_state.dim_count, + symbol_count=self.affine_state.symbol_count, + exprs=list(), + ) + + def _normalize_affine_map( + self, affine_map: _ir.AffineMap, with_dims: bool = True + ) -> _ir.AffineMap: + """Normalizes an indexing map to have the max known symbols and dims.""" + with self.context: + return _ir.AffineMap.get( + dim_count=self.affine_state.dim_count if with_dims else 0, + symbol_count=self.affine_state.symbol_count, + exprs=list(affine_map.results), + ) + + def to_yaml_custom_dict(self): + self_dict = dict(args=self.ordered_operands) + # TODO: Refactor the hierarchy internally when supporting more + # than static (preserving this serialized form). + self_dict["indexing_maps"] = LinalgIndexingMapsConfig( + static_indexing_maps=self.indexing_maps + ) + self_dict["iterator_types"] = self.iterator_types + self_dict["assignments"] = self.assignments + return self_dict + + def __repr__(self): + lines = [f"LinalgGenericOpConfig(reduction_dims={self.reduction_dims},"] + lines.append("operands=[") + for def_config in self.ordered_operands: + lines.append(f" {repr(def_config)}") + lines.append("], indexing_maps=[") + for m in self.indexing_maps: + lines.append(f" {repr(m)}") + lines.append(f"], iterator_types=[") + for t in self.iterator_types: + lines.append(f" {t}") + lines.append("])") + return "\n".join(lines) class LinalgOpConfig(YAMLObject): - """Container for any supported linalg op type. - - This includes the concrete type by name for ease of parsing by systems - that ignore tags. - """ - yaml_tag = "!LinalgOpConfig" - - def __init__(self, - metadata: OpMetadataDef, - *, - structured_op: Optional[LinalgStructuredOpConfig] = None): - self.metadata = metadata - self.structured_op = structured_op - - def to_yaml_custom_dict(self): - self_dict = dict(metadata=self.metadata,) - if self.structured_op: - self_dict["structured_op"] = self.structured_op - return self_dict - - @staticmethod - def from_linalg_op_def( - op_def: LinalgOpDef, - context: Optional[_ir.Context] = None) -> Sequence["LinalgOpConfig"]: - """Expands a LinalgOpDef into corresponding Linalg configured ops.""" - # TODO: Many LinalgOpDef patterns need to expand to multiple generics. - assert len(op_def.comprehensions) == 1, "Only one comprehension supported" - return [ - LinalgOpConfig( - op_def.metadata, - structured_op=LinalgStructuredOpConfig( - op_def.comprehensions[0], op_def.domain, - op_def.registered_operands.values(), context)), - ] - - def __repr__(self): - return (f"LinalgOpConfig(metadata={self.metadata},\n" - f"structured_op={self.structured_op})") + """Container for any supported linalg op type. + + This includes the concrete type by name for ease of parsing by systems + that ignore tags. 
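For context, the configuration classes can be exercised directly from a DSL definition. A rough sketch, reusing the hypothetical matmul_like op and the lang imports from the earlier example; the printed values in the comments are what one would roughly expect, not verified output:

from mlir import ir

# LinalgOpConfig.from_linalg_op_def expands the op definition attached to the
# callable returned by @linalg_structured_op.
configs = LinalgOpConfig.from_linalg_op_def(matmul_like.op_def, context=ir.Context())
structured = configs[0].structured_op
print(structured.iterator_types)  # roughly: ['parallel', 'parallel', 'reduction']
for m in structured.indexing_maps:
    # roughly: (d0, d1, d2) -> (d0, d2), (d2, d1), (d0, d1)
    print(m)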
+ """ + + yaml_tag = "!LinalgOpConfig" + + def __init__( + self, + metadata: OpMetadataDef, + *, + structured_op: Optional[LinalgStructuredOpConfig] = None, + ): + self.metadata = metadata + self.structured_op = structured_op + + def to_yaml_custom_dict(self): + self_dict = dict( + metadata=self.metadata, + ) + if self.structured_op: + self_dict["structured_op"] = self.structured_op + return self_dict + + @staticmethod + def from_linalg_op_def( + op_def: LinalgOpDef, context: Optional[_ir.Context] = None + ) -> Sequence["LinalgOpConfig"]: + """Expands a LinalgOpDef into corresponding Linalg configured ops.""" + # TODO: Many LinalgOpDef patterns need to expand to multiple generics. + assert len(op_def.comprehensions) == 1, "Only one comprehension supported" + return [ + LinalgOpConfig( + op_def.metadata, + structured_op=LinalgStructuredOpConfig( + op_def.comprehensions[0], + op_def.domain, + op_def.registered_operands.values(), + context, + ), + ), + ] + + def __repr__(self): + return ( + f"LinalgOpConfig(metadata={self.metadata},\n" + f"structured_op={self.structured_op})" + ) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py index 45b8d5c..8b8726f 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -10,160 +10,192 @@ import inspect import threading from ..... import ir -from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values +from ...._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, +) from .comprehension import * from .config import * from .emitter import * _CONTEXT = threading.local() -StructuredOpOuts = Union[ir.Operation, ir.OpView, ir.OpResultList, - Sequence[Union[ir.Value, ir.Operation, ir.OpView]]] +StructuredOpOuts = Union[ + ir.Operation, + ir.OpView, + ir.OpResultList, + Sequence[Union[ir.Value, ir.Operation, ir.OpView]], +] @contextmanager def bind_op_def(op_def: LinalgOpDef): - if hasattr(_CONTEXT, "current_op_def"): - raise ValueError("Cannot recursively define an operation") - _CONTEXT.current_op_def = op_def - try: - yield op_def - finally: - del _CONTEXT.current_op_def + if hasattr(_CONTEXT, "current_op_def"): + raise ValueError("Cannot recursively define an operation") + _CONTEXT.current_op_def = op_def + try: + yield op_def + finally: + del _CONTEXT.current_op_def def current_op_def() -> LinalgOpDef: - try: - return _CONTEXT.current_op_def - except AttributeError: - raise ValueError( - "Attempt to access the current op definition being defined " - "but none is set. Did you mean to call this in an op definition?") + try: + return _CONTEXT.current_op_def + except AttributeError: + raise ValueError( + "Attempt to access the current op definition being defined " + "but none is set. Did you mean to call this in an op definition?" 
+ ) def _prepare_structured_op_outs(outs: StructuredOpOuts) -> ValueList: - if isinstance(outs, (ir.Operation, ir.OpView)): - return _get_op_results_or_values(outs) - elif isinstance(outs, ir.OpResultList): - return outs + if isinstance(outs, (ir.Operation, ir.OpView)): + return _get_op_results_or_values(outs) + elif isinstance(outs, ir.OpResultList): + return outs - return [_get_op_result_or_value(o) for o in outs] + return [_get_op_result_or_value(o) for o in outs] class DefinedOpCallable: - """Callable that wraps any defined op function.""" - - def __init__(self, op_name: str, op_def: LinalgOpDef): - self.op_name = op_name - self.op_def = op_def - - def __call__(self, *ins: Union[ir.Operation, ir.OpView, ir.Value], - outs: StructuredOpOuts, **kwargs): - """Emits the corresponding op definition as IR. - - Most arguments are passed through to the underlying emitter. The following - keyword argument is interpreted here: - emit_generic: Emits a generic form as appropriate (default True). If - False, a named form is emitted (which must have been built in to the - compiler). - """ - emit_generic = kwargs.pop("emit_generic", False) - if not isinstance(emit_generic, bool): - raise ValueError(f"The named argument 'emit_generic' needs to be " - f" of type bool but got {type(emit_generic)}") - - op_configs = LinalgOpConfig.from_linalg_op_def( - self.op_def, context=ir.Context.current) - - if len(op_configs) != 1: - # TODO: Support composite ops. - raise NotImplementedError( - f"Emission of composite linalg ops not supported: {op_configs}") - - ctx = ir.Context.current - linalgDialect = ctx.get_dialect_descriptor("linalg") - fully_qualified_name = "linalg." + self.op_name - emit_generic = ( - emit_generic or not ctx.is_registered_operation(fully_qualified_name)) - - op_config = op_configs[0] - out_values = _prepare_structured_op_outs(outs) - in_values = [_get_op_result_or_value(i) for i in ins] - if op_config.structured_op: - if emit_generic: - return emit_generic_structured_op( - op_config.structured_op, *in_values, outs=out_values, **kwargs) - else: - return emit_named_structured_op( - op_config.structured_op, - self.op_name, - self.op_def.metadata.cpp_class_name, - *in_values, - outs=out_values, - **kwargs) - - raise NotImplementedError( - f"Emission of linalg op type not supported: {op_config}") - - -def linalg_structured_op(dsl_func=None, - *, - op_name=None, - op_class_name=None) -> DefinedOpCallable: - if dsl_func is None: - # Curry the keyword args in for delayed application. - return functools.partial( - linalg_structured_op, op_name=op_name, op_class_name=op_class_name) - # Determine default names by introspecting the function. - if op_name is None: - op_name = dsl_func.__name__ - if op_class_name is None: - # Camel case it. - op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op" - - op_def = LinalgOpDef( - name=op_name, cpp_class_name=op_class_name, doc=inspect.getdoc(dsl_func)) - - # Extract arguments and TensorDefs from the signature. 
- dsl_func_args = list() - sig = inspect.signature(dsl_func) - for param_name, param in sig.parameters.items(): - param_default = param.default - if isinstance(param_default, - (TensorDef, ScalarDef, IndexAttrDef, UnaryFnAttrDef, - BinaryFnAttrDef, TypeFnAttrDef)): - op_def.add_operand(param_name, param_default.operand_def) - else: - raise ValueError( - f"@linalg_structured_op function parameters must be defaulted as " - f"TensorDef(...), ScalarDef(...), or IndexAttrDef(...): " - f"Found {param_name}: {param_default}") - dsl_func_args.append(param_default) - - # Invoke the DSL func to finish populating the op definition. - with bind_op_def(op_def): - dsl_func(*dsl_func_args) - - # TODO: The returned callable should be an IR emitter but that is not - # upstreamed yet. - return DefinedOpCallable(op_name, op_def) + """Callable that wraps any defined op function.""" + + def __init__(self, op_name: str, op_def: LinalgOpDef): + self.op_name = op_name + self.op_def = op_def + + def __call__( + self, + *ins: Union[ir.Operation, ir.OpView, ir.Value], + outs: StructuredOpOuts, + **kwargs, + ): + """Emits the corresponding op definition as IR. + + Most arguments are passed through to the underlying emitter. The following + keyword argument is interpreted here: + emit_generic: Emits a generic form as appropriate (default True). If + False, a named form is emitted (which must have been built in to the + compiler). + """ + emit_generic = kwargs.pop("emit_generic", False) + if not isinstance(emit_generic, bool): + raise ValueError( + f"The named argument 'emit_generic' needs to be " + f" of type bool but got {type(emit_generic)}" + ) + + op_configs = LinalgOpConfig.from_linalg_op_def( + self.op_def, context=ir.Context.current + ) + + if len(op_configs) != 1: + # TODO: Support composite ops. + raise NotImplementedError( + f"Emission of composite linalg ops not supported: {op_configs}" + ) + + ctx = ir.Context.current + linalgDialect = ctx.get_dialect_descriptor("linalg") + fully_qualified_name = "linalg." + self.op_name + emit_generic = emit_generic or not ctx.is_registered_operation( + fully_qualified_name + ) + + op_config = op_configs[0] + out_values = _prepare_structured_op_outs(outs) + in_values = [_get_op_result_or_value(i) for i in ins] + if op_config.structured_op: + if emit_generic: + return emit_generic_structured_op( + op_config.structured_op, *in_values, outs=out_values, **kwargs + ) + else: + return emit_named_structured_op( + op_config.structured_op, + self.op_name, + self.op_def.metadata.cpp_class_name, + *in_values, + outs=out_values, + **kwargs, + ) + + raise NotImplementedError( + f"Emission of linalg op type not supported: {op_config}" + ) + + +def linalg_structured_op( + dsl_func=None, *, op_name=None, op_class_name=None +) -> DefinedOpCallable: + if dsl_func is None: + # Curry the keyword args in for delayed application. + return functools.partial( + linalg_structured_op, op_name=op_name, op_class_name=op_class_name + ) + # Determine default names by introspecting the function. + if op_name is None: + op_name = dsl_func.__name__ + if op_class_name is None: + # Camel case it. + op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op" + + op_def = LinalgOpDef( + name=op_name, cpp_class_name=op_class_name, doc=inspect.getdoc(dsl_func) + ) + + # Extract arguments and TensorDefs from the signature. 
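To round out the picture, a defined op callable is invoked like any other builder under an insertion point. A sketch of emitting the generic form, carrying over the illustrative matmul_like definition and lang imports from the earlier example; the function scaffolding and tensor shapes are assumptions for the sake of a runnable snippet:

from mlir.ir import *
from mlir.dialects import func

with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):

        @func.FuncOp.from_py_func(
            RankedTensorType.get((4, 8), f32),
            RankedTensorType.get((8, 16), f32),
            RankedTensorType.get((4, 16), f32),
        )
        def emit(lhs, rhs, init):
            # emit_generic=True forces the linalg.generic form; index and
            # function attributes (e.g. cast=TypeFn.cast_unsigned, or
            # strides=[2, 2] for the convolutions) are passed as keyword
            # arguments and override the declared defaults.
            return matmul_like(lhs, rhs, outs=[init], emit_generic=True)

    print(module)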
+ dsl_func_args = list() + sig = inspect.signature(dsl_func) + for param_name, param in sig.parameters.items(): + param_default = param.default + if isinstance( + param_default, + ( + TensorDef, + ScalarDef, + IndexAttrDef, + UnaryFnAttrDef, + BinaryFnAttrDef, + TypeFnAttrDef, + ), + ): + op_def.add_operand(param_name, param_default.operand_def) + else: + raise ValueError( + f"@linalg_structured_op function parameters must be defaulted as " + f"TensorDef(...), ScalarDef(...), or IndexAttrDef(...): " + f"Found {param_name}: {param_default}" + ) + dsl_func_args.append(param_default) + + # Invoke the DSL func to finish populating the op definition. + with bind_op_def(op_def): + dsl_func(*dsl_func_args) + + # TODO: The returned callable should be an IR emitter but that is not + # upstreamed yet. + return DefinedOpCallable(op_name, op_def) def domain(*dimensions: DimDef): - if any(not isinstance(d, DimDef) for d in dimensions): - raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}") - current_op_def().domain.extend(dimensions) + if any(not isinstance(d, DimDef) for d in dimensions): + raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}") + current_op_def().domain.extend(dimensions) def implements(*interfaces: OpInterfaceDef): - if any(not isinstance(intr, OpInterfaceDef) for intr in interfaces): - raise ValueError( - f"Expected interfaces of type OpInterfaceDef but got {interfaces}") - current_op_def().metadata.implements.extend(interfaces) + if any(not isinstance(intr, OpInterfaceDef) for intr in interfaces): + raise ValueError( + f"Expected interfaces of type OpInterfaceDef but got {interfaces}" + ) + current_op_def().metadata.implements.extend(interfaces) def defines(*definitions: OpDefinitionDef): - if any(not isinstance(defi, OpDefinitionDef) for defi in definitions): - raise ValueError( - f"Expected definitions of type OpDefinitionDef but got {definitions}") - current_op_def().metadata.defines.extend(definitions) + if any(not isinstance(defi, OpDefinitionDef) for defi in definitions): + raise ValueError( + f"Expected definitions of type OpDefinitionDef but got {definitions}" + ) + current_op_def().metadata.defines.extend(definitions) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index b63cb40..62730d9 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -11,7 +11,10 @@ from .... import linalg from .... import math from .... import arith from .... 
import complex -from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values +from ...._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, +) from .scalar_expr import * from .config import * @@ -29,529 +32,618 @@ ValueList = Union[Sequence[Value], OpResultList] def isa(cls: Type, ty: Type): - try: - cls(ty) - return True - except ValueError: - return False - - -def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, - *ins: Value, outs: ValueList, - **attrs: Union[Sequence[int], TypeFnType]): - all_arg_defs = op_config.ordered_operands - in_arg_defs = [ - d for d in all_arg_defs - if d.kind in [OperandKind.SCALAR, OperandKind.INPUT_TENSOR] - ] - out_arg_defs = [ - d for d in all_arg_defs if d.kind == OperandKind.OUTPUT_TENSOR - ] - index_attr_arg_defs = [ - d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR - ] - fn_attr_arg_defs = [ - d for d in all_arg_defs if d.kind in [ - OperandKind.UNARY_FN_ATTR, OperandKind.BINARY_FN_ATTR, - OperandKind.TYPE_FN_ATTR - ] - ] - - # Verify outs is a sequence or a list of results. - if not isinstance(outs, (Sequence, OpResultList)): - raise ValueError(f"Expected named argument outs to have type Sequence or " - f"OpResultLis but got {type(outs)}") - - # Arity validation. - if len(ins) != len(in_arg_defs): - raise ValueError(f"Expected {len(in_arg_defs)} inputs but got " - f"{len(ins)} for {op_config}") - if outs and len(outs) != len(out_arg_defs): - raise ValueError(f"Expected {len(out_arg_defs)} outputs but got " - f"{len(outs)} for {op_config}") - - # Compute a replacement list for all index attribute symbols. - expressions = [] # type: Sequence[AffineExpr] - replacements = [] # type: Sequence[AffineExpr] - for index_attr in index_attr_arg_defs: - index_attr_vals = index_attr.operand_def.default_indices - if index_attr.name in attrs: - index_attr_vals = attrs.get(index_attr.name) - assert index_attr_vals, "Index attribute has no value" - if not all(isinstance(value, int) for value in index_attr_vals): - raise ValueError(f"Attribute {index_attr.name} needs to be of type " - f"Sequence[int] but got {type(index_attr_vals)}") - results = index_attr.index_attr_map.results # type: AffineExprList - if len(index_attr_vals) != len(results): - raise ValueError(f"Attribute {index_attr.name} has length {len(results)} " - f"but got {len(index_attr_vals)} values") - for expr, value in zip(results, index_attr_vals): - expressions.append(expr) - replacements.append(AffineConstantExpr.get(value)) - - # Replace all index attribute symbols by their value. - # TODO: Add support for shape symbols. - indexing_maps = [] # type: Sequence[AffineMap] - for curr in op_config.indexing_maps: - for expression, replacement in zip(expressions, replacements): - curr = curr.replace(expression, replacement, curr.n_dims, curr.n_symbols) - indexing_maps.append(curr) - - # TODO: Linalg verification does not currently allow symbols. - # Compress them for now and verify none are left. 
- indexing_maps = AffineMap.compress_unused_symbols(indexing_maps, - Context.current) - if any(indexing_map.n_symbols != 0 for indexing_map in indexing_maps): - raise ValueError(f"Expected indexing_maps to use no symbols after " - f"replacement and compression but got {indexing_maps}") - - outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins, - out_arg_defs, outs) - - result_types = [t for t in out_types if isa(RankedTensorType, t)] - - # Initialize the type dictionary with the predefined types. - type_mapping = dict() # type: Dict[str, Type] - type_mapping["F32"] = F32Type.get() - type_mapping["F64"] = F64Type.get() - type_mapping["I32"] = IntegerType.get_signless(32) - type_mapping["I64"] = IntegerType.get_signless(64) - - # Extract type vars for input/output based types. - block_arg_types = list() # type: List[Type] - for arg_def, arg_element_type in zip(in_arg_defs + out_arg_defs, - _get_types_from_values(*ins, *outs)): - _add_type_mapping(arg_def, arg_element_type, type_mapping, block_arg_types) - - # Emit the generic op. - # TODO: Support emission of pure memref form. - indexing_maps_attr = ArrayAttr.get( - [AffineMapAttr.get(am) for am in indexing_maps]) - iterator_types_attr = ArrayAttr.get([ - Attribute.parse(f"#linalg.iterator_type<{s}>") - for s in op_config.iterator_types - ]) - - # Compute the index attributes used when emitting a named structured op. - index_attrs = {} # type: Dict[str, DenseElementAttr] - for index_attr in index_attr_arg_defs: - index_attr_vals = attrs.get(index_attr.name) - # Only forward attributes set to a non-default value. - if index_attr_vals: - array = np.array(index_attr_vals, dtype=np.int64) - index_attrs[index_attr.name] = DenseElementsAttr.get(array) - - # Compute the function attribute mapping. - fn_attr_mapping = {} - for fn_attr in fn_attr_arg_defs: - attr_val = fn_attr.operand_def.default_fn - attr_kind = fn_attr.kind - if fn_attr.name in attrs: - fn = attrs.get(fn_attr.name) - if attr_kind == OperandKind.UNARY_FN_ATTR: - if not isinstance(fn, UnaryFnType): - raise ValueError(f"Attribute {fn_attr.name} needs to be of type " - f"UnaryFnType but got {type(attr_val)}") - elif attr_kind == OperandKind.BINARY_FN_ATTR: - if not isinstance(fn, BinaryFnType): - raise ValueError(f"Attribute {fn_attr.name} needs to be of type " - f"BinaryFnType but got {type(attr_val)}") - else: - if not isinstance(fn, TypeFnType): - raise ValueError(f"Attribute {fn_attr.name} needs to be of type " - f"TypeFnType but got {type(attr_val)}") - attr_val = fn.fn_name - assert attr_val, "Function attribute has no value" - fn_attr_mapping[fn_attr.name] = (attr_val, attr_kind) - - return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, - type_mapping, indexing_maps_attr, iterator_types_attr, index_attrs, - fn_attr_mapping, block_arg_types) - - -def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, - outs: ValueList, **attrs: Sequence[int]): - all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - indexing_maps_attr, iterator_types_attr, index_attrs, fn_attr_mapping, \ - block_arg_types = \ - prepare_common_structured_op(op_config, *ins, outs = outs, **attrs) - - # An operation that accesses only scalars and scalar/rank zero tensors is - # rank polymorhpic. We implement rank polymorphism by generating different - # indexing maps and iterators that match the rank of the first output tensor. - # An operation is rank polymorphic if the iteration domain has rank zero. 
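The rank-polymorphic path described in the comment above synthesizes its own maps instead of using the configured ones. A tiny sketch of the two shapes of map it builds for a rank-2 output, using only calls that appear in the surrounding code; the printed forms in the comments are illustrative:

from mlir import ir

with ir.Context():
    rank = 2
    scalar_map = ir.AffineMap.get(rank, 0, [])    # (d0, d1) -> ()
    tensor_map = ir.AffineMap.get_identity(rank)  # (d0, d1) -> (d0, d1)
    print(scalar_map, tensor_map)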
- if not iterator_types_attr: - rank = ShapedType(outs[0].type).rank - iterator_types_attr = ArrayAttr.get( - [Attribute.parse("#linalg.iterator_type")] * rank) - scalar_map = AffineMap.get(rank, 0, []) - tensor_map = AffineMap.get_identity(rank) - indexing_maps = [] - for arg_def in all_arg_defs: - if arg_def.operand_def.kind == OperandKind.SCALAR: - indexing_maps.append(scalar_map) - if arg_def.operand_def.is_tensor(): - idx = arg_def.operand_def.registered_index - if idx < len(ins) and ShapedType(ins[idx].type).rank == 0: - indexing_maps.append(scalar_map) - else: - indexing_maps.append(tensor_map) - indexing_maps_attr = ArrayAttr.get( - [AffineMapAttr.get(am) for am in indexing_maps]) - - generic_op = linalg.GenericOp( - result_tensors=result_types, - inputs=ins, - outputs=outs, - indexing_maps=indexing_maps_attr, - iterator_types=iterator_types_attr, - doc=None, # TODO: Make optional. - library_call=None) # TODO: Make optional. - - # Construct the body. - block_arg_names = _get_operand_def_names(*in_arg_defs, *out_arg_defs) - block = generic_op.regions[0].blocks.append(*block_arg_types) - block_arg_mapping = dict(zip(block_arg_names, block.arguments)) - with InsertionPoint(block): - body_builder = _BodyBuilder(type_mapping, block_arg_mapping, - fn_attr_mapping) - for assignment in op_config.assignments: - body_builder.assign(assignment) - body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs)) - - if len(result_types) == 1: - return generic_op.result - else: - return generic_op.results - - -def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str, - op_class_name: str, *ins: Value, outs: ValueList, - **attrs: Sequence[int]): - all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - indexing_maps_attr, iterator_types_attr, index_attrs, fn_attr_mapping, \ - block_arg_types = \ - prepare_common_structured_op(op_config, *ins, outs = outs, **attrs) - - # If we get here, there must exist a builtin class `op_class_name`. - ctx = Context.current - fully_qualified_name = "linalg." + op_name - if (not ctx.is_registered_operation(fully_qualified_name) or - not op_class_name in linalg.__dict__.keys()): - raise NotImplementedError( - f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}") - - # Set the index attributes used to compute the indexing maps. - named_op = getattr(linalg, op_class_name)(ins, outs, result_types) - for name, value in index_attrs.items(): - named_op.operation.attributes[name] = value - - # Compute the function attributes by combining operand kind and function name. 
- for name, (fn_name, kind) in fn_attr_mapping.items(): - assert kind.name.lower().endswith("_attr") - enum_name = kind.name.lower()[:-5] - named_op.operation.attributes[name] = Attribute.parse( - f"#linalg.{enum_name}<{fn_name}>") + try: + cls(ty) + return True + except ValueError: + return False - linalg.fill_builtin_region(named_op.operation) - if len(result_types) == 1: - return named_op.result - else: - return named_op.results +def prepare_common_structured_op( + op_config: LinalgStructuredOpConfig, + *ins: Value, + outs: ValueList, + **attrs: Union[Sequence[int], TypeFnType], +): + all_arg_defs = op_config.ordered_operands + in_arg_defs = [ + d + for d in all_arg_defs + if d.kind in [OperandKind.SCALAR, OperandKind.INPUT_TENSOR] + ] + out_arg_defs = [d for d in all_arg_defs if d.kind == OperandKind.OUTPUT_TENSOR] + index_attr_arg_defs = [d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR] + fn_attr_arg_defs = [ + d + for d in all_arg_defs + if d.kind + in [ + OperandKind.UNARY_FN_ATTR, + OperandKind.BINARY_FN_ATTR, + OperandKind.TYPE_FN_ATTR, + ] + ] + + # Verify outs is a sequence or a list of results. + if not isinstance(outs, (Sequence, OpResultList)): + raise ValueError( + f"Expected named argument outs to have type Sequence or " + f"OpResultLis but got {type(outs)}" + ) + + # Arity validation. + if len(ins) != len(in_arg_defs): + raise ValueError( + f"Expected {len(in_arg_defs)} inputs but got " f"{len(ins)} for {op_config}" + ) + if outs and len(outs) != len(out_arg_defs): + raise ValueError( + f"Expected {len(out_arg_defs)} outputs but got " + f"{len(outs)} for {op_config}" + ) + + # Compute a replacement list for all index attribute symbols. + expressions = [] # type: Sequence[AffineExpr] + replacements = [] # type: Sequence[AffineExpr] + for index_attr in index_attr_arg_defs: + index_attr_vals = index_attr.operand_def.default_indices + if index_attr.name in attrs: + index_attr_vals = attrs.get(index_attr.name) + assert index_attr_vals, "Index attribute has no value" + if not all(isinstance(value, int) for value in index_attr_vals): + raise ValueError( + f"Attribute {index_attr.name} needs to be of type " + f"Sequence[int] but got {type(index_attr_vals)}" + ) + results = index_attr.index_attr_map.results # type: AffineExprList + if len(index_attr_vals) != len(results): + raise ValueError( + f"Attribute {index_attr.name} has length {len(results)} " + f"but got {len(index_attr_vals)} values" + ) + for expr, value in zip(results, index_attr_vals): + expressions.append(expr) + replacements.append(AffineConstantExpr.get(value)) + + # Replace all index attribute symbols by their value. + # TODO: Add support for shape symbols. + indexing_maps = [] # type: Sequence[AffineMap] + for curr in op_config.indexing_maps: + for expression, replacement in zip(expressions, replacements): + curr = curr.replace(expression, replacement, curr.n_dims, curr.n_symbols) + indexing_maps.append(curr) + + # TODO: Linalg verification does not currently allow symbols. + # Compress them for now and verify none are left. 
+ indexing_maps = AffineMap.compress_unused_symbols(indexing_maps, Context.current) + if any(indexing_map.n_symbols != 0 for indexing_map in indexing_maps): + raise ValueError( + f"Expected indexing_maps to use no symbols after " + f"replacement and compression but got {indexing_maps}" + ) + + outs, out_types = _infer_structured_outs( + op_config, in_arg_defs, ins, out_arg_defs, outs + ) + + result_types = [t for t in out_types if isa(RankedTensorType, t)] + + # Initialize the type dictionary with the predefined types. + type_mapping = dict() # type: Dict[str, Type] + type_mapping["F32"] = F32Type.get() + type_mapping["F64"] = F64Type.get() + type_mapping["I32"] = IntegerType.get_signless(32) + type_mapping["I64"] = IntegerType.get_signless(64) + + # Extract type vars for input/output based types. + block_arg_types = list() # type: List[Type] + for arg_def, arg_element_type in zip( + in_arg_defs + out_arg_defs, _get_types_from_values(*ins, *outs) + ): + _add_type_mapping(arg_def, arg_element_type, type_mapping, block_arg_types) + + # Emit the generic op. + # TODO: Support emission of pure memref form. + indexing_maps_attr = ArrayAttr.get([AffineMapAttr.get(am) for am in indexing_maps]) + iterator_types_attr = ArrayAttr.get( + [ + Attribute.parse(f"#linalg.iterator_type<{s}>") + for s in op_config.iterator_types + ] + ) + + # Compute the index attributes used when emitting a named structured op. + index_attrs = {} # type: Dict[str, DenseElementAttr] + for index_attr in index_attr_arg_defs: + index_attr_vals = attrs.get(index_attr.name) + # Only forward attributes set to a non-default value. + if index_attr_vals: + array = np.array(index_attr_vals, dtype=np.int64) + index_attrs[index_attr.name] = DenseElementsAttr.get(array) + + # Compute the function attribute mapping. + fn_attr_mapping = {} + for fn_attr in fn_attr_arg_defs: + attr_val = fn_attr.operand_def.default_fn + attr_kind = fn_attr.kind + if fn_attr.name in attrs: + fn = attrs.get(fn_attr.name) + if attr_kind == OperandKind.UNARY_FN_ATTR: + if not isinstance(fn, UnaryFnType): + raise ValueError( + f"Attribute {fn_attr.name} needs to be of type " + f"UnaryFnType but got {type(attr_val)}" + ) + elif attr_kind == OperandKind.BINARY_FN_ATTR: + if not isinstance(fn, BinaryFnType): + raise ValueError( + f"Attribute {fn_attr.name} needs to be of type " + f"BinaryFnType but got {type(attr_val)}" + ) + else: + if not isinstance(fn, TypeFnType): + raise ValueError( + f"Attribute {fn_attr.name} needs to be of type " + f"TypeFnType but got {type(attr_val)}" + ) + attr_val = fn.fn_name + assert attr_val, "Function attribute has no value" + fn_attr_mapping[fn_attr.name] = (attr_val, attr_kind) + + return ( + all_arg_defs, + in_arg_defs, + out_arg_defs, + outs, + result_types, + type_mapping, + indexing_maps_attr, + iterator_types_attr, + index_attrs, + fn_attr_mapping, + block_arg_types, + ) + + +def emit_generic_structured_op( + op_config: LinalgStructuredOpConfig, + *ins: Value, + outs: ValueList, + **attrs: Sequence[int], +): + ( + all_arg_defs, + in_arg_defs, + out_arg_defs, + outs, + result_types, + type_mapping, + indexing_maps_attr, + iterator_types_attr, + index_attrs, + fn_attr_mapping, + block_arg_types, + ) = prepare_common_structured_op(op_config, *ins, outs=outs, **attrs) + + # An operation that accesses only scalars and scalar/rank zero tensors is + # rank polymorhpic. We implement rank polymorphism by generating different + # indexing maps and iterators that match the rank of the first output tensor. 
+ # An operation is rank polymorphic if the iteration domain has rank zero. + if not iterator_types_attr: + rank = ShapedType(outs[0].type).rank + iterator_types_attr = ArrayAttr.get( + [Attribute.parse("#linalg.iterator_type")] * rank + ) + scalar_map = AffineMap.get(rank, 0, []) + tensor_map = AffineMap.get_identity(rank) + indexing_maps = [] + for arg_def in all_arg_defs: + if arg_def.operand_def.kind == OperandKind.SCALAR: + indexing_maps.append(scalar_map) + if arg_def.operand_def.is_tensor(): + idx = arg_def.operand_def.registered_index + if idx < len(ins) and ShapedType(ins[idx].type).rank == 0: + indexing_maps.append(scalar_map) + else: + indexing_maps.append(tensor_map) + indexing_maps_attr = ArrayAttr.get( + [AffineMapAttr.get(am) for am in indexing_maps] + ) + + generic_op = linalg.GenericOp( + result_tensors=result_types, + inputs=ins, + outputs=outs, + indexing_maps=indexing_maps_attr, + iterator_types=iterator_types_attr, + doc=None, # TODO: Make optional. + library_call=None, + ) # TODO: Make optional. + + # Construct the body. + block_arg_names = _get_operand_def_names(*in_arg_defs, *out_arg_defs) + block = generic_op.regions[0].blocks.append(*block_arg_types) + block_arg_mapping = dict(zip(block_arg_names, block.arguments)) + with InsertionPoint(block): + body_builder = _BodyBuilder(type_mapping, block_arg_mapping, fn_attr_mapping) + for assignment in op_config.assignments: + body_builder.assign(assignment) + body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs)) + + if len(result_types) == 1: + return generic_op.result + else: + return generic_op.results + + +def emit_named_structured_op( + op_config: LinalgStructuredOpConfig, + op_name: str, + op_class_name: str, + *ins: Value, + outs: ValueList, + **attrs: Sequence[int], +): + ( + all_arg_defs, + in_arg_defs, + out_arg_defs, + outs, + result_types, + type_mapping, + indexing_maps_attr, + iterator_types_attr, + index_attrs, + fn_attr_mapping, + block_arg_types, + ) = prepare_common_structured_op(op_config, *ins, outs=outs, **attrs) + + # If we get here, there must exist a builtin class `op_class_name`. + ctx = Context.current + fully_qualified_name = "linalg." + op_name + if ( + not ctx.is_registered_operation(fully_qualified_name) + or not op_class_name in linalg.__dict__.keys() + ): + raise NotImplementedError( + f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}" + ) + + # Set the index attributes used to compute the indexing maps. + named_op = getattr(linalg, op_class_name)(ins, outs, result_types) + for name, value in index_attrs.items(): + named_op.operation.attributes[name] = value + + # Compute the function attributes by combining operand kind and function name. 
+ for name, (fn_name, kind) in fn_attr_mapping.items(): + assert kind.name.lower().endswith("_attr") + enum_name = kind.name.lower()[:-5] + named_op.operation.attributes[name] = Attribute.parse( + f"#linalg.{enum_name}<{fn_name}>" + ) + + linalg.fill_builtin_region(named_op.operation) + + if len(result_types) == 1: + return named_op.result + else: + return named_op.results class _BodyBuilder: - """Constructs a structured op body by evaluating assignments.""" - - def __init__(self, type_mapping: Dict[str, Type], - block_arg_mapping: Dict[str, Value], fn_attr_mapping: Dict[str, - str]): - self.type_mapping = type_mapping - self.block_arg_mapping = block_arg_mapping - self.fn_attr_mapping = fn_attr_mapping - self.yield_mapping = dict() # type: Dict[str, Value] - - def assign(self, assignment: ScalarAssign): - if assignment.arg in self.yield_mapping: - raise ValueError( - f"Multiple assignments to the same argument are forbidden: " - f"{assignment}") - self.yield_mapping[assignment.arg] = self.expression(assignment.value) - - def expression(self, expr: ScalarExpression) -> Value: - if expr.scalar_arg: - try: - return self.block_arg_mapping[expr.scalar_arg.arg] - except KeyError: - raise ValueError(f"Argument {expr.scalar_arg.arg} is not bound for " - f"this structured op.") - elif expr.scalar_const: - value_attr = Attribute.parse(expr.scalar_const.value) - return arith.ConstantOp(value_attr.type, value_attr).result - elif expr.scalar_index: - dim_attr = IntegerAttr.get( - IntegerType.get_signless(64), expr.scalar_index.dim) - return linalg.IndexOp(dim_attr).result - elif expr.scalar_fn: - kind = expr.scalar_fn.kind.name.lower() - fn_name = expr.scalar_fn.fn_name - if expr.scalar_fn.attr_name: - fn_name, _ = self.fn_attr_mapping[expr.scalar_fn.attr_name] - fn = self._get_function(f"_{kind}_{fn_name}") - operand_values = [ - self.expression(operand) for operand in expr.scalar_fn.operands - ] - if expr.scalar_fn.kind == FunctionKind.TYPE: - operand_values = [expr.scalar_fn.type_var.name] + operand_values - return fn(*operand_values) - raise NotImplementedError(f"Unimplemented scalar body expression: {expr}") - - def yield_outputs(self, *output_names: str): - output_values = [] - for n in output_names: - try: - output_values.append(self.yield_mapping[n]) - except KeyError: - raise ValueError(f"Body assignments do not assign all outputs: " - f"missing '{n}'") - linalg.YieldOp(output_values) - - def _get_function(self, fn_name: str) -> Callable: - try: - fn = getattr(self, f"{fn_name}") - except AttributeError: - raise ValueError(f"Function '{fn_name}' is not a known function") - return fn - - def _cast(self, - type_var_name: str, - operand: Value, - is_unsigned_cast: bool = False) -> Value: - try: - to_type = self.type_mapping[type_var_name] - except KeyError: - raise ValueError(f"Unbound type variable '{type_var_name}' (" - f"expected one of {self.type_mapping.keys()}") - if operand.type == to_type: - return operand - if _is_integer_type(to_type): - return self._cast_to_integer(to_type, operand, is_unsigned_cast) - elif _is_floating_point_type(to_type): - return self._cast_to_floating_point(to_type, operand, is_unsigned_cast) - - def _cast_to_integer(self, to_type: Type, operand: Value, - is_unsigned_cast: bool) -> Value: - to_width = IntegerType(to_type).width - operand_type = operand.type - if _is_floating_point_type(operand_type): - if is_unsigned_cast: - return arith.FPToUIOp(to_type, operand).result - return arith.FPToSIOp(to_type, operand).result - if _is_index_type(operand_type): - return 
arith.IndexCastOp(to_type, operand).result - # Assume integer. - from_width = IntegerType(operand_type).width - if to_width > from_width: - if is_unsigned_cast: - return arith.ExtUIOp(to_type, operand).result - return arith.ExtSIOp(to_type, operand).result - elif to_width < from_width: - return arith.TruncIOp(to_type, operand).result - raise ValueError(f"Unable to cast body expression from {operand_type} to " - f"{to_type}") - - def _cast_to_floating_point(self, to_type: Type, operand: Value, - is_unsigned_cast: bool) -> Value: - operand_type = operand.type - if _is_integer_type(operand_type): - if is_unsigned_cast: - return arith.UIToFPOp(to_type, operand).result - return arith.SIToFPOp(to_type, operand).result - # Assume FloatType. - to_width = _get_floating_point_width(to_type) - from_width = _get_floating_point_width(operand_type) - if to_width > from_width: - return arith.ExtFOp(to_type, operand).result - elif to_width < from_width: - return arith.TruncFOp(to_type, operand).result - raise ValueError(f"Unable to cast body expression from {operand_type} to " - f"{to_type}") - - def _type_cast_signed(self, type_var_name: str, operand: Value) -> Value: - return self._cast(type_var_name, operand, False) - - def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value: - return self._cast(type_var_name, operand, True) - - def _unary_exp(self, x: Value) -> Value: - if _is_floating_point_type(x.type): - return math.ExpOp(x).result - raise NotImplementedError("Unsupported 'exp' operand: {x}") - - def _unary_log(self, x: Value) -> Value: - if _is_floating_point_type(x.type): - return math.LogOp(x).result - raise NotImplementedError("Unsupported 'log' operand: {x}") - - def _unary_abs(self, x: Value) -> Value: - if _is_floating_point_type(x.type): - return math.AbsFOp(x).result - raise NotImplementedError("Unsupported 'abs' operand: {x}") - - def _unary_ceil(self, x: Value) -> Value: - if _is_floating_point_type(x.type): - return math.CeilOp(x).result - raise NotImplementedError("Unsupported 'ceil' operand: {x}") - - def _unary_floor(self, x: Value) -> Value: - if _is_floating_point_type(x.type): - return math.FloorOp(x).result - raise NotImplementedError("Unsupported 'floor' operand: {x}") - - def _unary_negf(self, x: Value) -> Value: - if _is_floating_point_type(x.type): - return arith.NegFOp(x).result - if _is_complex_type(x.type): - return complex.NegOp(x).result - raise NotImplementedError("Unsupported 'negf' operand: {x}") - - def _binary_add(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): - return arith.AddFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return arith.AddIOp(lhs, rhs).result - if _is_complex_type(lhs.type): - return complex.AddOp(lhs, rhs).result - raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}") - - def _binary_sub(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): - return arith.SubFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return arith.SubIOp(lhs, rhs).result - if _is_complex_type(lhs.type): - return complex.SubOp(lhs, rhs).result - raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}") - - def _binary_mul(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): - return arith.MulFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return arith.MulIOp(lhs, rhs).result - if _is_complex_type(lhs.type): - return complex.MulOp(lhs, 
rhs).result - raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}") - - def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): - return arith.MaxFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return arith.MaxSIOp(lhs, rhs).result - raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}") - - def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): - return arith.MaxFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return arith.MaxUIOp(lhs, rhs).result - raise NotImplementedError( - "Unsupported 'max_unsigned' operands: {lhs}, {rhs}") - - def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): - return arith.MinFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return arith.MinSIOp(lhs, rhs).result - raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}") - - def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): - return arith.MinFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return arith.MinUIOp(lhs, rhs).result - raise NotImplementedError( - "Unsupported 'min_unsigned' operands: {lhs}, {rhs}") + """Constructs a structured op body by evaluating assignments.""" + + def __init__( + self, + type_mapping: Dict[str, Type], + block_arg_mapping: Dict[str, Value], + fn_attr_mapping: Dict[str, str], + ): + self.type_mapping = type_mapping + self.block_arg_mapping = block_arg_mapping + self.fn_attr_mapping = fn_attr_mapping + self.yield_mapping = dict() # type: Dict[str, Value] + + def assign(self, assignment: ScalarAssign): + if assignment.arg in self.yield_mapping: + raise ValueError( + f"Multiple assignments to the same argument are forbidden: " + f"{assignment}" + ) + self.yield_mapping[assignment.arg] = self.expression(assignment.value) + + def expression(self, expr: ScalarExpression) -> Value: + if expr.scalar_arg: + try: + return self.block_arg_mapping[expr.scalar_arg.arg] + except KeyError: + raise ValueError( + f"Argument {expr.scalar_arg.arg} is not bound for " + f"this structured op." 
+ ) + elif expr.scalar_const: + value_attr = Attribute.parse(expr.scalar_const.value) + return arith.ConstantOp(value_attr.type, value_attr).result + elif expr.scalar_index: + dim_attr = IntegerAttr.get( + IntegerType.get_signless(64), expr.scalar_index.dim + ) + return linalg.IndexOp(dim_attr).result + elif expr.scalar_fn: + kind = expr.scalar_fn.kind.name.lower() + fn_name = expr.scalar_fn.fn_name + if expr.scalar_fn.attr_name: + fn_name, _ = self.fn_attr_mapping[expr.scalar_fn.attr_name] + fn = self._get_function(f"_{kind}_{fn_name}") + operand_values = [ + self.expression(operand) for operand in expr.scalar_fn.operands + ] + if expr.scalar_fn.kind == FunctionKind.TYPE: + operand_values = [expr.scalar_fn.type_var.name] + operand_values + return fn(*operand_values) + raise NotImplementedError(f"Unimplemented scalar body expression: {expr}") + + def yield_outputs(self, *output_names: str): + output_values = [] + for n in output_names: + try: + output_values.append(self.yield_mapping[n]) + except KeyError: + raise ValueError( + f"Body assignments do not assign all outputs: " f"missing '{n}'" + ) + linalg.YieldOp(output_values) + + def _get_function(self, fn_name: str) -> Callable: + try: + fn = getattr(self, f"{fn_name}") + except AttributeError: + raise ValueError(f"Function '{fn_name}' is not a known function") + return fn + + def _cast( + self, type_var_name: str, operand: Value, is_unsigned_cast: bool = False + ) -> Value: + try: + to_type = self.type_mapping[type_var_name] + except KeyError: + raise ValueError( + f"Unbound type variable '{type_var_name}' (" + f"expected one of {self.type_mapping.keys()}" + ) + if operand.type == to_type: + return operand + if _is_integer_type(to_type): + return self._cast_to_integer(to_type, operand, is_unsigned_cast) + elif _is_floating_point_type(to_type): + return self._cast_to_floating_point(to_type, operand, is_unsigned_cast) + + def _cast_to_integer( + self, to_type: Type, operand: Value, is_unsigned_cast: bool + ) -> Value: + to_width = IntegerType(to_type).width + operand_type = operand.type + if _is_floating_point_type(operand_type): + if is_unsigned_cast: + return arith.FPToUIOp(to_type, operand).result + return arith.FPToSIOp(to_type, operand).result + if _is_index_type(operand_type): + return arith.IndexCastOp(to_type, operand).result + # Assume integer. + from_width = IntegerType(operand_type).width + if to_width > from_width: + if is_unsigned_cast: + return arith.ExtUIOp(to_type, operand).result + return arith.ExtSIOp(to_type, operand).result + elif to_width < from_width: + return arith.TruncIOp(to_type, operand).result + raise ValueError( + f"Unable to cast body expression from {operand_type} to " f"{to_type}" + ) + + def _cast_to_floating_point( + self, to_type: Type, operand: Value, is_unsigned_cast: bool + ) -> Value: + operand_type = operand.type + if _is_integer_type(operand_type): + if is_unsigned_cast: + return arith.UIToFPOp(to_type, operand).result + return arith.SIToFPOp(to_type, operand).result + # Assume FloatType. 
+ to_width = _get_floating_point_width(to_type) + from_width = _get_floating_point_width(operand_type) + if to_width > from_width: + return arith.ExtFOp(to_type, operand).result + elif to_width < from_width: + return arith.TruncFOp(to_type, operand).result + raise ValueError( + f"Unable to cast body expression from {operand_type} to " f"{to_type}" + ) + + def _type_cast_signed(self, type_var_name: str, operand: Value) -> Value: + return self._cast(type_var_name, operand, False) + + def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value: + return self._cast(type_var_name, operand, True) + + def _unary_exp(self, x: Value) -> Value: + if _is_floating_point_type(x.type): + return math.ExpOp(x).result + raise NotImplementedError("Unsupported 'exp' operand: {x}") + + def _unary_log(self, x: Value) -> Value: + if _is_floating_point_type(x.type): + return math.LogOp(x).result + raise NotImplementedError("Unsupported 'log' operand: {x}") + + def _unary_abs(self, x: Value) -> Value: + if _is_floating_point_type(x.type): + return math.AbsFOp(x).result + raise NotImplementedError("Unsupported 'abs' operand: {x}") + + def _unary_ceil(self, x: Value) -> Value: + if _is_floating_point_type(x.type): + return math.CeilOp(x).result + raise NotImplementedError("Unsupported 'ceil' operand: {x}") + + def _unary_floor(self, x: Value) -> Value: + if _is_floating_point_type(x.type): + return math.FloorOp(x).result + raise NotImplementedError("Unsupported 'floor' operand: {x}") + + def _unary_negf(self, x: Value) -> Value: + if _is_floating_point_type(x.type): + return arith.NegFOp(x).result + if _is_complex_type(x.type): + return complex.NegOp(x).result + raise NotImplementedError("Unsupported 'negf' operand: {x}") + + def _binary_add(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return arith.AddFOp(lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return arith.AddIOp(lhs, rhs).result + if _is_complex_type(lhs.type): + return complex.AddOp(lhs, rhs).result + raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}") + + def _binary_sub(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return arith.SubFOp(lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return arith.SubIOp(lhs, rhs).result + if _is_complex_type(lhs.type): + return complex.SubOp(lhs, rhs).result + raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}") + + def _binary_mul(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return arith.MulFOp(lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return arith.MulIOp(lhs, rhs).result + if _is_complex_type(lhs.type): + return complex.MulOp(lhs, rhs).result + raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}") + + def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return arith.MaxFOp(lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return arith.MaxSIOp(lhs, rhs).result + raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}") + + def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return arith.MaxFOp(lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return arith.MaxUIOp(lhs, rhs).result + raise NotImplementedError("Unsupported 'max_unsigned' operands: {lhs}, {rhs}") + + def 
_binary_min_signed(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return arith.MinFOp(lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return arith.MinSIOp(lhs, rhs).result + raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}") + + def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return arith.MinFOp(lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return arith.MinUIOp(lhs, rhs).result + raise NotImplementedError("Unsupported 'min_unsigned' operands: {lhs}, {rhs}") def _infer_structured_outs( op_config: LinalgStructuredOpConfig, - in_arg_defs: Sequence[OperandDefConfig], ins: Sequence[Value], + in_arg_defs: Sequence[OperandDefConfig], + ins: Sequence[Value], out_arg_defs: Sequence[OperandDefConfig], - outs: Union[Sequence[Value], OpResultList]) -> Tuple[ValueList, List[Type]]: - """Infers implicit outs and output types. + outs: Union[Sequence[Value], OpResultList], +) -> Tuple[ValueList, List[Type]]: + """Infers implicit outs and output types. - Respects existing contents of outs if not empty. + Respects existing contents of outs if not empty. - Returns: - normalized outs, output types - """ - # If outs were explicitly provided, we accept them verbatim. - if outs: - return outs, [out.type for out in outs] + Returns: + normalized outs, output types + """ + # If outs were explicitly provided, we accept them verbatim. + if outs: + return outs, [out.type for out in outs] - raise NotImplementedError(f"Output tensor inference not yet supported for " - "structured ops") + raise NotImplementedError( + f"Output tensor inference not yet supported for " "structured ops" + ) def _get_types_from_values(*values: Value) -> Sequence[Type]: - types = [] - for v in values: - types.append(v.type) - return types + types = [] + for v in values: + types.append(v.type) + return types def _get_operand_def_names(*operand_configs: OperandDefConfig) -> Sequence[str]: - return [odc.operand_def.name for odc in operand_configs] - - -def _add_type_mapping(operand_config: OperandDefConfig, operand_type: Type, - type_mapping: Dict[str, Type], - block_arg_types: Sequence[Type]): - element_or_self_type = operand_type - # Get the element type for tensor operands and the type itself for scalars. - if operand_config.shape_map: - try: - element_or_self_type = ShapedType(operand_type).element_type - except Exception as e: - raise ValueError(f"Expected ShapedType but got {operand_type}") from e - name = operand_config.type_var.name - if name in type_mapping: - if type_mapping[name] != element_or_self_type: - raise ValueError(f"Cannot overwrite type mapping {name} = " - f"{type_mapping[name]} by type {element_or_self_type}") - type_mapping[name] = element_or_self_type - block_arg_types.append(element_or_self_type) + return [odc.operand_def.name for odc in operand_configs] + + +def _add_type_mapping( + operand_config: OperandDefConfig, + operand_type: Type, + type_mapping: Dict[str, Type], + block_arg_types: Sequence[Type], +): + element_or_self_type = operand_type + # Get the element type for tensor operands and the type itself for scalars. 
+ if operand_config.shape_map: + try: + element_or_self_type = ShapedType(operand_type).element_type + except Exception as e: + raise ValueError(f"Expected ShapedType but got {operand_type}") from e + name = operand_config.type_var.name + if name in type_mapping: + if type_mapping[name] != element_or_self_type: + raise ValueError( + f"Cannot overwrite type mapping {name} = " + f"{type_mapping[name]} by type {element_or_self_type}" + ) + type_mapping[name] = element_or_self_type + block_arg_types.append(element_or_self_type) def _is_complex_type(t: Type) -> bool: - return ComplexType.isinstance(t) + return ComplexType.isinstance(t) def _is_floating_point_type(t: Type) -> bool: - # TODO: Create a FloatType in the Python API and implement the switch - # there. - return (F64Type.isinstance(t) or F32Type.isinstance(t) or - F16Type.isinstance(t) or BF16Type.isinstance(t)) + # TODO: Create a FloatType in the Python API and implement the switch + # there. + return ( + F64Type.isinstance(t) + or F32Type.isinstance(t) + or F16Type.isinstance(t) + or BF16Type.isinstance(t) + ) def _is_integer_type(t: Type) -> bool: - return IntegerType.isinstance(t) + return IntegerType.isinstance(t) def _is_index_type(t: Type) -> bool: - return IndexType.isinstance(t) + return IndexType.isinstance(t) def _get_floating_point_width(t: Type) -> int: - # TODO: Create a FloatType in the Python API and implement the switch - # there. - if F64Type.isinstance(t): - return 64 - if F32Type.isinstance(t): - return 32 - if F16Type.isinstance(t): - return 16 - if BF16Type.isinstance(t): - return 16 - raise NotImplementedError(f"Unhandled floating point type switch {t}") + # TODO: Create a FloatType in the Python API and implement the switch + # there. + if F64Type.isinstance(t): + return 64 + if F32Type.isinstance(t): + return 32 + if F16Type.isinstance(t): + return 16 + if BF16Type.isinstance(t): + return 16 + raise NotImplementedError(f"Unhandled floating point type switch {t}") diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py index aa894dc..8685399 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py @@ -30,123 +30,137 @@ __all__ = [ class ScalarFn: - """A type of ScalarExpression that applies a function.""" - - def __init__(self, kind: "FunctionKind", fn_name: Optional[str], - attr_name: Optional[str], type_var: Optional["TypeVar"], - operands: Sequence["ScalarExpression"]): - if bool(fn_name) + bool(attr_name) != 1: - raise ValueError("One of 'fn_name', 'attr_name' must be specified") - self.kind = kind - self.fn_name = fn_name - self.attr_name = attr_name - self.type_var = type_var - self.operands = operands - - def expr(self) -> "ScalarExpression": - return ScalarExpression(scalar_fn=self) - - def __repr__(self): - name = self.fn_name if self.fn_name else self.attr_name - return (f"ScalarFn<{self.kind.name}.{name}>(type_var={self.type_var}, " - f"operands=[{', '.join(self.operands)}])") + """A type of ScalarExpression that applies a function.""" + + def __init__( + self, + kind: "FunctionKind", + fn_name: Optional[str], + attr_name: Optional[str], + type_var: Optional["TypeVar"], + operands: Sequence["ScalarExpression"], + ): + if bool(fn_name) + bool(attr_name) != 1: + raise ValueError("One of 'fn_name', 'attr_name' must be specified") + self.kind = kind + self.fn_name = fn_name + self.attr_name = attr_name + self.type_var = type_var + self.operands = 
operands + + def expr(self) -> "ScalarExpression": + return ScalarExpression(scalar_fn=self) + + def __repr__(self): + name = self.fn_name if self.fn_name else self.attr_name + return ( + f"ScalarFn<{self.kind.name}.{name}>(type_var={self.type_var}, " + f"operands=[{', '.join(self.operands)}])" + ) class ScalarArg: - """A type of ScalarExpression that references a named argument.""" + """A type of ScalarExpression that references a named argument.""" - def __init__(self, arg: str): - self.arg = arg + def __init__(self, arg: str): + self.arg = arg - def expr(self) -> "ScalarExpression": - return ScalarExpression(scalar_arg=self) + def expr(self) -> "ScalarExpression": + return ScalarExpression(scalar_arg=self) - def __repr__(self): - return f"(ScalarArg({self.arg})" + def __repr__(self): + return f"(ScalarArg({self.arg})" class ScalarConst: - """A type of ScalarExpression representing a constant.""" + """A type of ScalarExpression representing a constant.""" - def __init__(self, value: str): - self.value = value + def __init__(self, value: str): + self.value = value - def expr(self) -> "ScalarExpression": - return ScalarExpression(scalar_const=self) + def expr(self) -> "ScalarExpression": + return ScalarExpression(scalar_const=self) - def __repr__(self): - return f"(ScalarConst({self.value})" + def __repr__(self): + return f"(ScalarConst({self.value})" class ScalarIndex: - """A type of ScalarExpression accessing an iteration index.""" + """A type of ScalarExpression accessing an iteration index.""" - def __init__(self, dim: int): - self.dim = dim + def __init__(self, dim: int): + self.dim = dim - def expr(self) -> "ScalarExpression": - return ScalarExpression(scalar_index=self) + def expr(self) -> "ScalarExpression": + return ScalarExpression(scalar_index=self) - def __repr__(self): - return f"(ScalarIndex({self.dim})" + def __repr__(self): + return f"(ScalarIndex({self.dim})" class ScalarExpression(YAMLObject): - """An expression on scalar values. - - Can be one of: - - ScalarFn - - ScalarArg - - ScalarConst - - ScalarIndex - """ - yaml_tag = "!ScalarExpression" - - def __init__(self, - scalar_fn: Optional[ScalarFn] = None, - scalar_arg: Optional[ScalarArg] = None, - scalar_const: Optional[ScalarConst] = None, - scalar_index: Optional[ScalarIndex] = None): - if (bool(scalar_fn) + bool(scalar_arg) + bool(scalar_const) + - bool(scalar_index)) != 1: - raise ValueError("One of 'scalar_fn', 'scalar_arg', 'scalar_const', or " - "'scalar_index' must be specified") - self.scalar_fn = scalar_fn - self.scalar_arg = scalar_arg - self.scalar_const = scalar_const - self.scalar_index = scalar_index - - def to_yaml_custom_dict(self): - if self.scalar_fn: - scalar_fn_dict = dict(kind=self.scalar_fn.kind.name.lower()) - if self.scalar_fn.fn_name: - scalar_fn_dict["fn_name"] = self.scalar_fn.fn_name - if self.scalar_fn.attr_name: - scalar_fn_dict["attr_name"] = self.scalar_fn.attr_name - if self.scalar_fn.type_var: - scalar_fn_dict["type_var"] = self.scalar_fn.type_var.name - scalar_fn_dict["operands"] = list(self.scalar_fn.operands) - return dict(scalar_fn=scalar_fn_dict) - elif self.scalar_arg: - return dict(scalar_arg=self.scalar_arg.arg) - elif self.scalar_const: - return dict(scalar_const=self.scalar_const.value) - elif self.scalar_index: - return dict(scalar_index=self.scalar_index.dim) - else: - raise ValueError(f"Unexpected ScalarExpression type: {self}") + """An expression on scalar values. 
+ + Can be one of: + - ScalarFn + - ScalarArg + - ScalarConst + - ScalarIndex + """ + + yaml_tag = "!ScalarExpression" + + def __init__( + self, + scalar_fn: Optional[ScalarFn] = None, + scalar_arg: Optional[ScalarArg] = None, + scalar_const: Optional[ScalarConst] = None, + scalar_index: Optional[ScalarIndex] = None, + ): + if ( + bool(scalar_fn) + bool(scalar_arg) + bool(scalar_const) + bool(scalar_index) + ) != 1: + raise ValueError( + "One of 'scalar_fn', 'scalar_arg', 'scalar_const', or " + "'scalar_index' must be specified" + ) + self.scalar_fn = scalar_fn + self.scalar_arg = scalar_arg + self.scalar_const = scalar_const + self.scalar_index = scalar_index + + def to_yaml_custom_dict(self): + if self.scalar_fn: + scalar_fn_dict = dict(kind=self.scalar_fn.kind.name.lower()) + if self.scalar_fn.fn_name: + scalar_fn_dict["fn_name"] = self.scalar_fn.fn_name + if self.scalar_fn.attr_name: + scalar_fn_dict["attr_name"] = self.scalar_fn.attr_name + if self.scalar_fn.type_var: + scalar_fn_dict["type_var"] = self.scalar_fn.type_var.name + scalar_fn_dict["operands"] = list(self.scalar_fn.operands) + return dict(scalar_fn=scalar_fn_dict) + elif self.scalar_arg: + return dict(scalar_arg=self.scalar_arg.arg) + elif self.scalar_const: + return dict(scalar_const=self.scalar_const.value) + elif self.scalar_index: + return dict(scalar_index=self.scalar_index.dim) + else: + raise ValueError(f"Unexpected ScalarExpression type: {self}") class ScalarAssign(YAMLObject): - """An assignment to a named argument (LHS of a comprehension).""" - yaml_tag = "!ScalarAssign" + """An assignment to a named argument (LHS of a comprehension).""" + + yaml_tag = "!ScalarAssign" - def __init__(self, arg: str, value: ScalarExpression): - self.arg = arg - self.value = value + def __init__(self, arg: str, value: ScalarExpression): + self.arg = arg + self.value = value - def to_yaml_custom_dict(self): - return dict(arg=self.arg, value=self.value) + def to_yaml_custom_dict(self): + return dict(arg=self.arg, value=self.value) - def __repr__(self): - return f"ScalarAssign({self.arg}, {self.value})" + def __repr__(self): + return f"ScalarAssign({self.arg}, {self.value})" diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/types.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/types.py index ddac872..4f36029 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/types.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/types.py @@ -21,13 +21,11 @@ from typing import Dict __all__ = [ "TypeVar", "TV", - # Predefined types. "I32", "I64", "F32", "F64", - # TypeVar aliases. "T", "U", @@ -36,34 +34,34 @@ __all__ = [ class TypeVar: - """A replaceable type variable. + """A replaceable type variable. - Type variables are uniqued by name. - """ - ALL_TYPEVARS = dict() # type: Dict[str, "TypeVar"] + Type variables are uniqued by name. 
+ """ - def __new__(cls, name: str): - existing = cls.ALL_TYPEVARS.get(name) - if existing is not None: - return existing - new = super().__new__(cls) - new.name = name - cls.ALL_TYPEVARS[name] = new - return new + ALL_TYPEVARS = dict() # type: Dict[str, "TypeVar"] - def __repr__(self): - return f"TypeVar({self.name})" + def __new__(cls, name: str): + existing = cls.ALL_TYPEVARS.get(name) + if existing is not None: + return existing + new = super().__new__(cls) + new.name = name + cls.ALL_TYPEVARS[name] = new + return new - @classmethod - def create_expando(cls): - """Create an expando class that creates unique type vars on attr access.""" + def __repr__(self): + return f"TypeVar({self.name})" - class ExpandoTypeVars: + @classmethod + def create_expando(cls): + """Create an expando class that creates unique type vars on attr access.""" - def __getattr__(self, n): - return cls(n) + class ExpandoTypeVars: + def __getattr__(self, n): + return cls(n) - return ExpandoTypeVars() + return ExpandoTypeVars() # Expando access via TV.foo diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py index 1945eea..1672656 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py @@ -6,11 +6,12 @@ import sys try: - import yaml + import yaml except ModuleNotFoundError as e: - raise ModuleNotFoundError( - f"This tool requires PyYAML but it was not installed. " - f"Recommend: {sys.executable} -m pip install PyYAML") from e + raise ModuleNotFoundError( + f"This tool requires PyYAML but it was not installed. " + f"Recommend: {sys.executable} -m pip install PyYAML" + ) from e __all__ = [ "yaml_dump", @@ -20,35 +21,33 @@ __all__ = [ class YAMLObject(yaml.YAMLObject): + @classmethod + def to_yaml(cls, dumper, self): + """Default to a custom dictionary mapping.""" + return dumper.represent_mapping(cls.yaml_tag, self.to_yaml_custom_dict()) - @classmethod - def to_yaml(cls, dumper, self): - """Default to a custom dictionary mapping.""" - return dumper.represent_mapping(cls.yaml_tag, self.to_yaml_custom_dict()) + def to_yaml_custom_dict(self): + raise NotImplementedError() - def to_yaml_custom_dict(self): - raise NotImplementedError() - - def as_linalg_yaml(self): - return yaml_dump(self) + def as_linalg_yaml(self): + return yaml_dump(self) def multiline_str_representer(dumper, data): - if len(data.splitlines()) > 1: - return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|') - else: - return dumper.represent_scalar('tag:yaml.org,2002:str', data) + if len(data.splitlines()) > 1: + return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + else: + return dumper.represent_scalar("tag:yaml.org,2002:str", data) yaml.add_representer(str, multiline_str_representer) def yaml_dump(data, sort_keys=False, **kwargs): - return yaml.dump(data, sort_keys=sort_keys, **kwargs) + return yaml.dump(data, sort_keys=sort_keys, **kwargs) def yaml_dump_all(data, sort_keys=False, explicit_start=True, **kwargs): - return yaml.dump_all(data, - sort_keys=sort_keys, - explicit_start=explicit_start, - **kwargs) + return yaml.dump_all( + data, sort_keys=sort_keys, explicit_start=explicit_start, **kwargs + ) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 9c96868..bac22a2 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ 
b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -7,99 +7,113 @@ Batch = S.Batch @linalg_structured_op -def copy(I=TensorDef(T1), - O=TensorDef(U, output=True), - cast=TypeFnAttrDef(default=TypeFn.cast_signed)): - """Copies the tensor elementwise. +def copy( + I=TensorDef(T1), + O=TensorDef(U, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast_signed), +): + """Copies the tensor elementwise. - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - O[None] = cast(U, I[None]) + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + O[None] = cast(U, I[None]) @linalg_structured_op -def elemwise_unary(I=TensorDef(T1), - O=TensorDef(U, output=True), - fun=UnaryFnAttrDef(default=UnaryFn.exp), - cast=TypeFnAttrDef(default=TypeFn.cast_signed)): - """Applies the unary function fun elementwise. +def elemwise_unary( + I=TensorDef(T1), + O=TensorDef(U, output=True), + fun=UnaryFnAttrDef(default=UnaryFn.exp), + cast=TypeFnAttrDef(default=TypeFn.cast_signed), +): + """Applies the unary function fun elementwise. - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - O[None] = fun(cast(U, I[None])) + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + O[None] = fun(cast(U, I[None])) @linalg_structured_op -def elemwise_binary(lhs=TensorDef(T1), - rhs=TensorDef(T2), - O=TensorDef(U, output=True), - fun=BinaryFnAttrDef(default=BinaryFn.add), - cast=TypeFnAttrDef(default=TypeFn.cast_signed)): - """Applies the binary function fun elementwise. +def elemwise_binary( + lhs=TensorDef(T1), + rhs=TensorDef(T2), + O=TensorDef(U, output=True), + fun=BinaryFnAttrDef(default=BinaryFn.add), + cast=TypeFnAttrDef(default=TypeFn.cast_signed), +): + """Applies the binary function fun elementwise. - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None])) + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None])) @linalg_structured_op -def matmul(A=TensorDef(T1, S.M, S.K), - B=TensorDef(T2, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True), - cast=TypeFnAttrDef(default=TypeFn.cast_signed)): - """Performs a matrix multiplication of two 2D inputs. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) +def matmul( + A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast_signed), +): + """Performs a matrix multiplication of two 2D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) @linalg_structured_op -def matmul_unsigned(A=TensorDef(T1, S.M, S.K), - B=TensorDef(T2, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True)): - """Performs an unsigned matrix multiplication of two 2D inputs. 
- - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.m, D.n] += TypeFn.cast_unsigned(U, A[D.m, D.k]) * TypeFn.cast_unsigned( - U, B[D.k, D.n]) +def matmul_unsigned( + A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True), +): + """Performs an unsigned matrix multiplication of two 2D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.m, D.n] += TypeFn.cast_unsigned(U, A[D.m, D.k]) * TypeFn.cast_unsigned( + U, B[D.k, D.n] + ) @linalg_structured_op -def quantized_matmul(A=TensorDef(T1, S.M, S.K), - B=TensorDef(T2, S.K, S.N), - AZp=ScalarDef(I32), - BZp=ScalarDef(I32), - C=TensorDef(U, S.M, S.N, output=True)): - """Performs a matrix multiplication of two 2D inputs. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. The quantized variant - includes zero-point adjustments for the left and right operands of the - matmul. - """ - domain(D.m, D.n, D.k) - C[D.m, - D.n] += (TypeFn.cast_signed(U, A[D.m, D.k]) - - TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed(U, B[D.k, D.n]) - - TypeFn.cast_signed(U, BZp)) +def quantized_matmul( + A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.K, S.N), + AZp=ScalarDef(I32), + BZp=ScalarDef(I32), + C=TensorDef(U, S.M, S.N, output=True), +): + """Performs a matrix multiplication of two 2D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. The quantized variant + includes zero-point adjustments for the left and right operands of the + matmul. + """ + domain(D.m, D.n, D.k) + C[D.m, D.n] += (TypeFn.cast_signed(U, A[D.m, D.k]) - TypeFn.cast_signed(U, AZp)) * ( + TypeFn.cast_signed(U, B[D.k, D.n]) - TypeFn.cast_signed(U, BZp) + ) @linalg_structured_op -def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), - rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0), - accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, output=True)): - """Performs a matrix-matrix-transpose multiplication of two 4D inputs. +def mmt4d( + lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), + rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0), + accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, output=True), +): + """Performs a matrix-matrix-transpose multiplication of two 4D inputs. Differences from linalg.matmul: * The right hand side is transposed, whence the 't' in 'mmt'. @@ -108,1132 +122,1201 @@ def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), whence the 2+2=4 dimensions. The inner tile dimensions are identified with '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads as: MxK tiles, each of shape M0xK0. 
- """ - domain(D.m, D.n, D.k, D.m0, D.n0, D.k0) - implements(ContractionOpInterface) - accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast_signed( - TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * TypeFn.cast_signed( - TV.AccumType, rhs[D.n, D.k, D.n0, D.k0]) + """ + domain(D.m, D.n, D.k, D.m0, D.n0, D.k0) + implements(ContractionOpInterface) + accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast_signed( + TV.AccumType, lhs[D.m, D.k, D.m0, D.k0] + ) * TypeFn.cast_signed(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0]) @linalg_structured_op -def batch_matmul(A=TensorDef(T1, Batch, S.M, S.K), - B=TensorDef(T2, Batch, S.K, S.N), - C=TensorDef(U, Batch, S.M, S.N, output=True)): - """Performs a batched matrix multiplication of two 3D inputs. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.b, D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.b, D.m, - D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( - U, B[D.b, D.k, D.n]) +def batch_matmul( + A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.K, S.N), + C=TensorDef(U, Batch, S.M, S.N, output=True), +): + """Performs a batched matrix multiplication of two 3D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.b, D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( + U, B[D.b, D.k, D.n] + ) @linalg_structured_op -def quantized_batch_matmul(A=TensorDef(T1, Batch, S.M, S.K), - B=TensorDef(T2, Batch, S.K, S.N), - AZp=ScalarDef(I32), - BZp=ScalarDef(I32), - C=TensorDef(U, Batch, S.M, S.N, output=True)): - """Performs a batched matrix multiplication of two 3D inputs. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. The quantized variant - includes zero-point adjustments for the left and right operands of the - matmul. - """ - domain(D.b, D.m, D.n, D.k) - C[D.b, D.m, D.n] += (TypeFn.cast_signed(U, A[D.b, D.m, D.k]) - - TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed( - U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp)) +def quantized_batch_matmul( + A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.K, S.N), + AZp=ScalarDef(I32), + BZp=ScalarDef(I32), + C=TensorDef(U, Batch, S.M, S.N, output=True), +): + """Performs a batched matrix multiplication of two 3D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. The quantized variant + includes zero-point adjustments for the left and right operands of the + matmul. + """ + domain(D.b, D.m, D.n, D.k) + C[D.b, D.m, D.n] += ( + TypeFn.cast_signed(U, A[D.b, D.m, D.k]) - TypeFn.cast_signed(U, AZp) + ) * (TypeFn.cast_signed(U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp)) @linalg_structured_op -def batch_reduce_matmul(A=TensorDef(T1, Batch, S.M, S.K), - B=TensorDef(T2, Batch, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True)): - """Performs a batch-reduce matrix multiplication of two 3D inputs. - The partial multiplication results are reduced into a 2D output. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. 
- """ - domain(D.b, D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.m, D.n] += TypeFn.cast_signed( - U, A[D.b, D.m, D.k] * TypeFn.cast_signed(U, B[D.b, D.k, D.n])) +def batch_reduce_matmul( + A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True), +): + """Performs a batch-reduce matrix multiplication of two 3D inputs. + The partial multiplication results are reduced into a 2D output. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.b, D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.m, D.n] += TypeFn.cast_signed( + U, A[D.b, D.m, D.k] * TypeFn.cast_signed(U, B[D.b, D.k, D.n]) + ) @linalg_structured_op -def matvec(A=TensorDef(T1, S.M, S.N), - y=TensorDef(T2, S.N), - x=TensorDef(U, S.M, output=True)): - """Performs a matrix-vector multiplication. +def matvec( + A=TensorDef(T1, S.M, S.N), y=TensorDef(T2, S.N), x=TensorDef(U, S.M, output=True) +): + """Performs a matrix-vector multiplication. - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.m, D.n) - implements(ContractionOpInterface) - x[D.m] += TypeFn.cast_signed(U, A[D.m, D.n]) * TypeFn.cast_signed(U, y[D.n]) + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.m, D.n) + implements(ContractionOpInterface) + x[D.m] += TypeFn.cast_signed(U, A[D.m, D.n]) * TypeFn.cast_signed(U, y[D.n]) @linalg_structured_op -def vecmat(y=TensorDef(T1, S.M), - A=TensorDef(T2, S.M, S.N), - x=TensorDef(U, S.N, output=True)): - """Performs a vector-matrix multiplication. +def vecmat( + y=TensorDef(T1, S.M), A=TensorDef(T2, S.M, S.N), x=TensorDef(U, S.N, output=True) +): + """Performs a vector-matrix multiplication. - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.n, D.m) - implements(ContractionOpInterface) - x[D.n] += TypeFn.cast_signed(U, y[D.m]) * TypeFn.cast_signed(U, A[D.m, D.n]) + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.n, D.m) + implements(ContractionOpInterface) + x[D.n] += TypeFn.cast_signed(U, y[D.m]) * TypeFn.cast_signed(U, A[D.m, D.n]) @linalg_structured_op -def batch_matvec(A=TensorDef(T1, Batch, S.M, S.K), - B=TensorDef(T2, Batch, S.K), - C=TensorDef(U, Batch, S.M, output=True)): - """Performs a batched matrix-vector multiplication. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.b, D.m, D.k) - implements(ContractionOpInterface) - C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( - U, B[D.b, D.k]) +def batch_matvec( + A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.K), + C=TensorDef(U, Batch, S.M, output=True), +): + """Performs a batched matrix-vector multiplication. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + domain(D.b, D.m, D.k) + implements(ContractionOpInterface) + C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( + U, B[D.b, D.k] + ) @linalg_structured_op -def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, - output=True)): - """Performs a dot product of two vectors to a scalar result. +def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)): + """Performs a dot product of two vectors to a scalar result. - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ContractionOpInterface) - C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m]) + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ContractionOpInterface) + C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m]) @linalg_structured_op -def conv_1d(I=TensorDef(T1, S.OW + S.KW), - K=TensorDef(T2, S.KW), - O=TensorDef(U, S.OW, output=True)): - """Performs 1-D convolution with no channels. +def conv_1d( + I=TensorDef(T1, S.OW + S.KW), + K=TensorDef(T2, S.KW), + O=TensorDef(U, S.OW, output=True), +): + """Performs 1-D convolution with no channels. - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.ow, D.kw) - O[D.ow] += TypeFn.cast_signed(U, I[D.ow + D.kw]) * TypeFn.cast_signed( - U, K[D.kw]) + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.ow, D.kw) + O[D.ow] += TypeFn.cast_signed(U, I[D.ow + D.kw]) * TypeFn.cast_signed(U, K[D.kw]) @linalg_structured_op -def conv_2d(I=TensorDef(T1, S.OH + S.KH, S.OW + S.KW), - K=TensorDef(T2, S.KH, S.KW), - O=TensorDef(U, S.OH, S.OW, output=True)): - """Performs 2-D convolution with no channels. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.oh, D.ow, D.kh, D.kw) - O[D.oh, D.ow] += TypeFn.cast_signed( - U, I[D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast_signed(U, K[D.kh, D.kw]) +def conv_2d( + I=TensorDef(T1, S.OH + S.KH, S.OW + S.KW), + K=TensorDef(T2, S.KH, S.KW), + O=TensorDef(U, S.OH, S.OW, output=True), +): + """Performs 2-D convolution with no channels. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.oh, D.ow, D.kh, D.kw) + O[D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.oh + D.kh, D.ow + D.kw] + ) * TypeFn.cast_signed(U, K[D.kh, D.kw]) @linalg_structured_op -def conv_3d(I=TensorDef(T1, S.OD + S.KD, S.OH + S.KH, S.OW + S.KW), - K=TensorDef(T2, S.KD, S.KH, S.KW), - O=TensorDef(U, S.OD, S.OH, S.OW, output=True)): - """Performs 3-D convolution with no channels. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. 
- """ - implements(ConvolutionOpInterface) - domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw) - O[D.od, D.oh, D.ow] += TypeFn.cast_signed( - U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast_signed( - U, K[D.kd, D.kh, D.kw]) +def conv_3d( + I=TensorDef(T1, S.OD + S.KD, S.OH + S.KH, S.OW + S.KW), + K=TensorDef(T2, S.KD, S.KH, S.KW), + O=TensorDef(U, S.OD, S.OH, S.OW, output=True), +): + """Performs 3-D convolution with no channels. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw) + O[D.od, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw] + ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw]) @linalg_structured_op -def conv_1d_nwc_wcf(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KW, S.C, S.F), - O=TensorDef(U, S.N, S.OW, S.F, output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): - """Performs 1-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.ow, D.f, D.kw, D.c) - O[D.n, D.ow, D.f] += TypeFn.cast_signed( - U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast_signed( - U, K[D.kw, D.c, D.f]) +def conv_1d_nwc_wcf( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, S.C, S.F), + O=TensorDef(U, S.N, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs 1-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.f, D.kw, D.c) + O[D.n, D.ow, D.f] += TypeFn.cast_signed( + U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c] + ) * TypeFn.cast_signed(U, K[D.kw, D.c, D.f]) @linalg_structured_op -def conv_1d_ncw_fcw(I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.F, S.C, S.KW), - O=TensorDef(U, S.N, S.F, S.OW, output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): - """Performs 1-D convolution. - - Layout: - * Input: NCW. - * Kernel: FCW. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.f, D.ow, D.c, D.kw) - O[D.n, D.f, D.ow] += TypeFn.cast_signed( - U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed( - U, K[D.f, D.c, D.kw]) +def conv_1d_ncw_fcw( + I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.F, S.C, S.KW), + O=TensorDef(U, S.N, S.F, S.OW, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs 1-D convolution. + + Layout: + * Input: NCW. + * Kernel: FCW. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
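To make the stride/dilation indexing in conv_1d_nwc_wcf and conv_1d_ncw_fcw concrete, here is a minimal loop-level sketch of the NWC/WCF variant. It is illustrative only and not part of the patch; the helper name and keyword arguments are editorial, the arrays are assumed to be NumPy-style, and the input is assumed wide enough for every ow * stride + kw * dilation access.

def conv_1d_nwc_wcf_ref(I, K, O, stride=1, dilation=1):
    # O[n, ow, f] += I[n, ow * stride + kw * dilation, c] * K[kw, c, f]
    # Dtype promotion to the accumulator type is omitted for brevity.
    N, OW, F = O.shape
    KW, C, _ = K.shape
    for n in range(N):
        for ow in range(OW):
            for f in range(F):
                for kw in range(KW):
                    for c in range(C):
                        O[n, ow, f] += (
                            I[n, ow * stride + kw * dilation, c] * K[kw, c, f]
                        )
    return O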
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.f, D.ow, D.c, D.kw) + O[D.n, D.f, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW] + ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kw]) @linalg_structured_op -def conv_2d_nhwc_hwcf(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KH, S.KW, S.C, S.F), - O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): - """Performs 2-D convolution. - - Layout: - * Input: NHWC. - * Kernel: HWCF. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.c]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) +def conv_2d_nhwc_hwcf( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, S.C, S.F), + O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D convolution. + + Layout: + * Input: NHWC. + * Kernel: HWCF. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) @linalg_structured_op -def conv_2d_nhwc_fhwc(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.F, S.KH, S.KW, S.C), - O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): - """Performs 2-D convolution. - - Layout: - * Input: NHWC. - * Kernel: FHWC. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.c]) * TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c]) +def conv_2d_nhwc_fhwc( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.F, S.KH, S.KW, S.C), + O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D convolution. + + Layout: + * Input: NHWC. + * Kernel: FHWC. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) * TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c]) @linalg_structured_op -def conv_2d_nhwc_hwcf_q(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KH, S.KW, S.C, S.F), - IZp=ScalarDef(I32), - KZp=ScalarDef(I32), - O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): - """Performs 2-D convolution with zero point offsets. - - Layout: - * Input: NHWC. - * Kernel: HWCF. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. This includes the zero - point offsets common to quantized operations. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, - D.f] += (TypeFn.cast_signed( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) - - TypeFn.cast_signed(U, IZp)) * (TypeFn.cast_signed( - U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast_signed(U, KZp)) +def conv_2d_nhwc_hwcf_q( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, S.C, S.F), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D convolution with zero point offsets. + + Layout: + * Input: NHWC. + * Kernel: HWCF. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. This includes the zero + point offsets common to quantized operations. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.f] += ( + TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) + - TypeFn.cast_signed(U, IZp) + ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast_signed(U, KZp)) @linalg_structured_op -def conv_2d_nchw_fchw(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.F, S.C, S.KH, S.KW), - O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): - """Performs 2-D convolution. - - Layout: - * Input: NCHW. - * Kernel: FCHW. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.f, D.oh, D.ow] += TypeFn.cast_signed( - U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + - D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) +def conv_2d_nchw_fchw( + I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.F, S.C, S.KH, S.KW), + O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D convolution. + + Layout: + * Input: NCHW. + * Kernel: FCHW. 
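The _q variants above differ from the plain convolutions only in the inner multiply-accumulate: both operands are shifted by their zero points before the product is accumulated. A scalar-level sketch of that single step follows; it is illustrative only, not part of the patch, and assumes an i32 accumulator, which is the common case for quantized convolutions.

import numpy as np


def quantized_mac(acc, x, w, x_zero_point, w_zero_point):
    # acc += (x - x_zp) * (w - w_zp), with every term promoted to the
    # accumulator type before the multiply.
    x = np.int32(x) - np.int32(x_zero_point)
    w = np.int32(w) - np.int32(w_zero_point)
    return acc + x * w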
+ + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.f, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW] + ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) @linalg_structured_op -def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C, - S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW), - O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): - """Performs 2-D grouped convolution. - - Layout: - * Input: NGCHW. - * Kernel: FGCHW. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed( - U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + - D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw]) +def conv_2d_ngchw_fgchw( + I=TensorDef( + T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW + ), + K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW), + O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D grouped convolution. + + Layout: + * Input: NGCHW. + * Kernel: FGCHW. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW] + ) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw]) @linalg_structured_op -def conv_3d_ndhwc_dhwcf(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, - S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F), - O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True), - strides=IndexAttrDef(S.SD, - S.SH, - S.SW, - default=[1, 1, 1]), - dilations=IndexAttrDef(S.DD, - S.DH, - S.DW, - default=[1, 1, 1])): - """Performs 3-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) - O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast_signed( - U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast_signed( - U, K[D.kd, D.kh, D.kw, D.c, D.f]) +def conv_3d_ndhwc_dhwcf( + I=TensorDef( + T1, + S.N, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + S.C, + ), + K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs 3-D convolution. 
+ + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) + O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast_signed( + U, + I[ + D.n, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.c, + ], + ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f]) @linalg_structured_op -def conv_3d_ndhwc_dhwcf_q(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, - S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F), - IZp=ScalarDef(I32), - KZp=ScalarDef(I32), - O=TensorDef(U, - S.N, - S.OD, - S.OH, - S.OW, - S.F, - output=True), - strides=IndexAttrDef(S.SD, - S.SH, - S.SW, - default=[1, 1, 1]), - dilations=IndexAttrDef(S.DD, - S.DH, - S.DW, - default=[1, 1, 1])): - """Performs 3-D convolution with zero point offsets. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. This includes the zero - point offsets common to quantized operations. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) - O[D.n, D.od, D.oh, D.ow, D.f] += (TypeFn.cast_signed( - U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW, D.c]) - TypeFn.cast_signed(U, IZp)) * ( - TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f]) - - TypeFn.cast_signed(U, KZp)) +def conv_3d_ndhwc_dhwcf_q( + I=TensorDef( + T1, + S.N, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + S.C, + ), + K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs 3-D convolution with zero point offsets. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. This includes the zero + point offsets common to quantized operations. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) + O[D.n, D.od, D.oh, D.ow, D.f] += ( + TypeFn.cast_signed( + U, + I[ + D.n, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.c, + ], + ) + - TypeFn.cast_signed(U, IZp) + ) * ( + TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f]) + - TypeFn.cast_signed(U, KZp) + ) @linalg_structured_op -def conv_3d_ncdhw_fcdhw(I=TensorDef(T1, S.N, S.C, S.OD * S.SD + S.KD * S.DD, - S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.F, S.C, S.KD, S.KH, S.KW), - O=TensorDef(U, S.N, S.F, S.OD, S.OH, S.OW, output=True), - strides=IndexAttrDef(S.SD, - S.SH, - S.SW, - default=[1, 1, 1]), - dilations=IndexAttrDef(S.DD, - S.DH, - S.DW, - default=[1, 1, 1])): - """Performs 3-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. 
- """ - implements(ConvolutionOpInterface) - domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) - O[D.n, D.f, D.od, D.oh, D.ow] += TypeFn.cast_signed( - U, I[D.n, D.c, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed( - U, K[D.f, D.c, D.kd, D.kh, D.kw]) +def conv_3d_ncdhw_fcdhw( + I=TensorDef( + T1, + S.N, + S.C, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + ), + K=TensorDef(T2, S.F, S.C, S.KD, S.KH, S.KW), + O=TensorDef(U, S.N, S.F, S.OD, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs 3-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) + O[D.n, D.f, D.od, D.oh, D.ow] += TypeFn.cast_signed( + U, + I[ + D.n, + D.c, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + ], + ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kd, D.kh, D.kw]) @linalg_structured_op -def depthwise_conv_1d_nwc_wc(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, - S.IC), - K=TensorDef(T2, S.KW, S.IC), - O=TensorDef(U, S.N, S.OW, S.IC, output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): - """Performs depth-wise 1-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. Multiplier is set to 1 - which is a special case for most depthwise convolutions. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.ow, D.ic, D.kw) - O[D.n, D.ow, D.ic] += \ - TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \ - TypeFn.cast_signed(U, K[D.kw, D.ic]) +def depthwise_conv_1d_nwc_wc( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KW, S.IC), + O=TensorDef(U, S.N, S.OW, S.IC, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs depth-wise 1-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most depthwise convolutions. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.ic, D.kw) + O[D.n, D.ow, D.ic] += TypeFn.cast_signed( + U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic] + ) * TypeFn.cast_signed(U, K[D.kw, D.ic]) @linalg_structured_op -def depthwise_conv_1d_ncw_cw(I=TensorDef(T1, S.N, S.IC, - S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.IC, S.KW), - O=TensorDef(U, S.N, S.IC, S.OW, output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): - """Performs depth-wise 1-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. Multiplier is set to 1 - which is a special case for most depthwise convolutions. 
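The depth-wise ops above keep the channel multiplier at 1, so each input channel is convolved with its own filter and written back to the same output channel; the *_wcm/*_hwcm variants instead add an explicit channel-multiplier dimension. A loop-level sketch of the 1-D NWC/WC case, illustrative only and not part of the patch (the name and signature are editorial, arrays are assumed NumPy-style):

def depthwise_conv_1d_nwc_wc_ref(I, K, O, stride=1, dilation=1):
    # O[n, ow, c] += I[n, ow * stride + kw * dilation, c] * K[kw, c]
    # Dtype promotion to the accumulator type is omitted for brevity.
    N, OW, C = O.shape
    KW, _ = K.shape
    for n in range(N):
        for ow in range(OW):
            for c in range(C):
                for kw in range(KW):
                    O[n, ow, c] += I[n, ow * stride + kw * dilation, c] * K[kw, c]
    return O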
- """ - implements(ConvolutionOpInterface) - domain(D.n, D.ow, D.ic, D.kw) - O[D.n, D.ic, D.ow] += \ - TypeFn.cast_signed(U, I[D.n, D.ic, D.ow * S.SW + D.kw * S.DW]) * \ - TypeFn.cast_signed(U, K[D.ic, D.kw]) +def depthwise_conv_1d_ncw_cw( + I=TensorDef(T1, S.N, S.IC, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.IC, S.KW), + O=TensorDef(U, S.N, S.IC, S.OW, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs depth-wise 1-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most depthwise convolutions. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.ic, D.kw) + O[D.n, D.ic, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.ic, D.ow * S.SW + D.kw * S.DW] + ) * TypeFn.cast_signed(U, K[D.ic, D.kw]) @linalg_structured_op -def depthwise_conv_1d_nwc_wcm(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, - S.IC), - K=TensorDef(T2, S.KW, S.IC, S.CM), - O=TensorDef(U, S.N, S.OW, S.IC, S.CM, - output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): - """Performs depth-wise 1-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.ow, D.ic, D.cm, D.kw) - O[D.n, D.ow, D.ic, D.cm] += \ - TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \ - TypeFn.cast_signed(U, K[D.kw, D.ic, D.cm]) +def depthwise_conv_1d_nwc_wcm( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KW, S.IC, S.CM), + O=TensorDef(U, S.N, S.OW, S.IC, S.CM, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs depth-wise 1-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.ic, D.cm, D.kw) + O[D.n, D.ow, D.ic, D.cm] += TypeFn.cast_signed( + U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic] + ) * TypeFn.cast_signed(U, K[D.kw, D.ic, D.cm]) @linalg_structured_op -def depthwise_conv_2d_nhwc_hwc(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.IC), - K=TensorDef(T2, S.KH, S.KW, S.IC), - O=TensorDef(U, - S.N, - S.OH, - S.OW, - S.IC, - output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, - S.DW, - default=[1, 1])): - """Performs depth-wise 2-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. Multiplier is set to 1 - which is a special case for most depthwise convolutions. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast_signed( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.ic]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) +def depthwise_conv_2d_nhwc_hwc( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KH, S.KW, S.IC), + O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs depth-wise 2-D convolution. 
+ + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most depthwise convolutions. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic] + ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) @linalg_structured_op -def depthwise_conv_2d_nchw_chw(I=TensorDef(T1, S.N, S.IC, - S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.IC, S.KH, S.KW), - O=TensorDef(U, - S.N, - S.IC, - S.OH, - S.OW, - output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, - S.DW, - default=[1, 1])): - """Performs depth-wise 2-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. Multiplier is set to 1 - which is a special case for most depthwise convolutions. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) - O[D.n, D.ic, D.oh, D.ow] += TypeFn.cast_signed( - U, I[D.n, D.ic, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + - D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.ic, D.kh, D.kw]) +def depthwise_conv_2d_nchw_chw( + I=TensorDef(T1, S.N, S.IC, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.IC, S.KH, S.KW), + O=TensorDef(U, S.N, S.IC, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs depth-wise 2-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most depthwise convolutions. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) + O[D.n, D.ic, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.ic, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW] + ) * TypeFn.cast_signed(U, K[D.ic, D.kh, D.kw]) @linalg_structured_op -def depthwise_conv_2d_nhwc_hwc_q(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.IC), - K=TensorDef(T2, S.KH, S.KW, S.IC), - IZp=ScalarDef(I32), - KZp=ScalarDef(I32), - O=TensorDef(U, - S.N, - S.OH, - S.OW, - S.IC, - output=True), - strides=IndexAttrDef(S.SH, - S.SW, - default=[1, 1]), - dilations=IndexAttrDef(S.DH, - S.DW, - default=[1, 1])): - """Performs depth-wise 2-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.ic] += ((TypeFn.cast_signed( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) - - TypeFn.cast_signed(U, IZp)) * - (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) - - TypeFn.cast_signed(U, KZp))) +def depthwise_conv_2d_nhwc_hwc_q( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KH, S.KW, S.IC), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs depth-wise 2-D convolution. 
+ + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.ic] += ( + TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic] + ) + - TypeFn.cast_signed(U, IZp) + ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) - TypeFn.cast_signed(U, KZp)) @linalg_structured_op -def depthwise_conv_2d_nhwc_hwcm(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.IC), - K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), - O=TensorDef(U, - S.N, - S.OH, - S.OW, - S.IC, - S.CM, - output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, - 1]), - dilations=IndexAttrDef(S.DH, - S.DW, - default=[1, 1])): - """Performs depth-wise 2-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.ic]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) +def depthwise_conv_2d_nhwc_hwcm( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), + O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs depth-wise 2-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic] + ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) @linalg_structured_op -def depthwise_conv_2d_nhwc_hwcm_q(I=TensorDef(T1, S.N, - S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.IC), - K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), - IZp=ScalarDef(I32), - KZp=ScalarDef(I32), - O=TensorDef(U, - S.N, - S.OH, - S.OW, - S.IC, - S.CM, - output=True), - strides=IndexAttrDef(S.SH, - S.SW, - default=[1, 1]), - dilations=IndexAttrDef(S.DH, - S.DW, - default=[1, 1])): - """Performs depth-wise 2-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.ic, - D.cm] += ((TypeFn.cast_signed( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) - - TypeFn.cast_signed(U, IZp)) * - (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) - - TypeFn.cast_signed(U, KZp))) +def depthwise_conv_2d_nhwc_hwcm_q( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs depth-wise 2-D convolution. 
+ + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.ic, D.cm] += ( + TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic] + ) + - TypeFn.cast_signed(U, IZp) + ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) - TypeFn.cast_signed(U, KZp)) @linalg_structured_op -def depthwise_conv_3d_ndhwc_dhwc(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, - S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.IC), - K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC), - O=TensorDef(U, - S.N, - S.OD, - S.OH, - S.OW, - output=True), - strides=IndexAttrDef(S.SD, - S.SH, - S.SW, - default=[1, 1, 1]), - dilations=IndexAttrDef(S.DD, - S.DH, - S.DW, - default=[1, 1, 1])): - """Performs depth-wise 3-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. Multiplier is set to 1 - which is a special case for most depthwise convolutions. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic) - O[D.n, D.od, D.oh, D.ow, D.ic] += TypeFn.cast_signed( - U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW, D.ic]) * TypeFn.cast_signed( - U, K[D.kd, D.kh, D.kw, D.ic]) +def depthwise_conv_3d_ndhwc_dhwc( + I=TensorDef( + T1, + S.N, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + S.IC, + ), + K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs depth-wise 3-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most depthwise convolutions. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic) + O[D.n, D.od, D.oh, D.ow, D.ic] += TypeFn.cast_signed( + U, + I[ + D.n, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.ic, + ], + ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic]) @linalg_structured_op -def depthwise_conv_3d_ncdhw_cdhw(I=TensorDef(T1, S.N, S.IC, - S.OD * S.SD + S.KD * S.DD, - S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.IC, S.KD, S.KH, S.KW), - O=TensorDef(U, - S.N, - S.IC, - S.OD, - S.OH, - S.OW, - output=True), - strides=IndexAttrDef(S.SD, - S.SH, - S.SW, - default=[1, 1, 1]), - dilations=IndexAttrDef(S.DD, - S.DH, - S.DW, - default=[1, 1, 1])): - """Performs depth-wise 3-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. Multiplier is set to 1 - which is a special case for most depthwise convolutions. 
- """ - implements(ConvolutionOpInterface) - domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic) - O[D.n, D.ic, D.od, D.oh, D.ow] += TypeFn.cast_signed( - U, I[D.n, D.ic, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed( - U, K[D.ic, D.kd, D.kh, D.kw]) +def depthwise_conv_3d_ncdhw_cdhw( + I=TensorDef( + T1, + S.N, + S.IC, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + ), + K=TensorDef(T2, S.IC, S.KD, S.KH, S.KW), + O=TensorDef(U, S.N, S.IC, S.OD, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs depth-wise 3-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most depthwise convolutions. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic) + O[D.n, D.ic, D.od, D.oh, D.ow] += TypeFn.cast_signed( + U, + I[ + D.n, + D.ic, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + ], + ) * TypeFn.cast_signed(U, K[D.ic, D.kd, D.kh, D.kw]) @linalg_structured_op -def depthwise_conv_3d_ndhwc_dhwcm(I=TensorDef(T1, S.N, - S.OD * S.SD + S.KD * S.DD, - S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.IC), - K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC, S.CM), - O=TensorDef(U, - S.N, - S.OD, - S.OH, - S.OW, - S.CM, - output=True), - strides=IndexAttrDef(S.SD, - S.SH, - S.SW, - default=[1, 1, 1]), - dilations=IndexAttrDef(S.DD, - S.DH, - S.DW, - default=[1, 1, 1])): - """Performs depth-wise 3-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.od, D.oh, D.ow, D.cm, D.kd, D.kh, D.kw, D.ic) - O[D.n, D.od, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed( - U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW, D.ic]) * TypeFn.cast_signed( - U, K[D.kd, D.kh, D.kw, D.ic, D.cm]) +def depthwise_conv_3d_ndhwc_dhwcm( + I=TensorDef( + T1, + S.N, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + S.IC, + ), + K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC, S.CM), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.CM, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs depth-wise 3-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.cm, D.kd, D.kh, D.kw, D.ic) + O[D.n, D.od, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed( + U, + I[ + D.n, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.ic, + ], + ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic, D.cm]) @linalg_structured_op -def pooling_nhwc_sum(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), - O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): - """Performs sum pooling. 
- - Layout: - * Input: NHWC. - * Kernel: HW. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) +def pooling_nhwc_sum( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs sum pooling. + + Layout: + * Input: NHWC. + * Kernel: HW. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) @linalg_structured_op -def pooling_nchw_sum(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), - O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): - """Performs sum pooling. - - Layout: - * Input: NCHW. - * Kernel: HW. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw) - O[D.n, D.c, D.oh, D.ow] += TypeFn.cast_signed( - U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]) +def pooling_nchw_sum( + I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs sum pooling. + + Layout: + * Input: NCHW. + * Kernel: HW. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw) + O[D.n, D.c, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW] + ) @linalg_structured_op -def pooling_nhwc_max(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), - O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): - """Performs max pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. 
- """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kh, D.kw](TypeFn.cast_signed( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) +def pooling_nhwc_max( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kh, D.kw]( + TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) + ) @linalg_structured_op -def pooling_nhwc_max_unsigned(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, - S.KH, - S.KW, - index_dims=[D.kh, D.kw]), - O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, - 1])): - """Performs unsigned max pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.oh, D.ow, - D.c] = ReduceFn.max_unsigned[D.kh, D.kw](TypeFn.cast_unsigned( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) +def pooling_nhwc_max_unsigned( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs unsigned max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw]( + TypeFn.cast_unsigned( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) + ) @linalg_structured_op -def pooling_nchw_max(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), - O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): - """Performs max pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw) - O[D.n, D.c, D.oh, D.ow] = ReduceFn.max_signed[D.kh, D.kw](TypeFn.cast_signed( - U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,])) +def pooling_nchw_max( + I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs max pooling. 
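For the pooling ops, the K operand contributes only its shape through index_dims and is never read; sum pooling accumulates into the output while the max/min variants keep a running extremum over the window. The reference sketch below covers the NHWC sum and signed-max cases; it is illustrative only and not part of the patch, `window` stands in for the shape-only kernel operand, and for "max" the output is assumed to be pre-filled with the smallest value of its type.

def pooling_nhwc_ref(I, window, O, stride=(1, 1), dilation=(1, 1), reduce="sum"):
    # I and O are NumPy-style NHWC arrays; `window` is (KH, KW).
    N, OH, OW, C = O.shape
    KH, KW = window
    for n in range(N):
        for oh in range(OH):
            for ow in range(OW):
                for c in range(C):
                    for kh in range(KH):
                        for kw in range(KW):
                            v = I[
                                n,
                                oh * stride[0] + kh * dilation[0],
                                ow * stride[1] + kw * dilation[1],
                                c,
                            ]
                            if reduce == "sum":
                                O[n, oh, ow, c] += v
                            else:  # "max"
                                O[n, oh, ow, c] = max(O[n, oh, ow, c], v)
    return O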
+ + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw) + O[D.n, D.c, D.oh, D.ow] = ReduceFn.max_signed[D.kh, D.kw]( + TypeFn.cast_signed( + U, + I[ + D.n, + D.c, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + ], + ) + ) @linalg_structured_op -def pooling_nhwc_min(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), - O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): - """Performs min pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kh, D.kw](TypeFn.cast_signed( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) +def pooling_nhwc_min( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kh, D.kw]( + TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) + ) @linalg_structured_op -def pooling_nhwc_min_unsigned(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, - S.KH, - S.KW, - index_dims=[D.kh, D.kw]), - O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, - 1])): - """Performs unsigned min pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.oh, D.ow, - D.c] = ReduceFn.min_unsigned[D.kh, D.kw](TypeFn.cast_unsigned( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) +def pooling_nhwc_min_unsigned( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs unsigned min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw]( + TypeFn.cast_unsigned( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) + ) + @linalg_structured_op -def pooling_nwc_sum(I=TensorDef(T1, S.N, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KW, index_dims=[D.kw]), - O=TensorDef(U, S.N, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): - """Performs sum pooling. - - Layout: - * Input: NWC. - * Kernel: W. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.ow, D.c, D.kw) - O[D.n, D.ow, D.c] += TypeFn.cast_signed( - U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) +def pooling_nwc_sum( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs sum pooling. + + Layout: + * Input: NWC. + * Kernel: W. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, D.c] += TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) @linalg_structured_op -def pooling_ncw_sum(I=TensorDef(T1, S.N, S.C, - S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.KW, index_dims=[D.kw]), - O=TensorDef(U, S.N, S.C, S.OW, output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): - """Performs sum pooling. - - Layout: - * Input: NCW. - * Kernel: W. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.c, D.ow, D.kw) - O[D.n, D.c, D.ow] += TypeFn.cast_signed( - U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW]) +def pooling_ncw_sum( + I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.C, S.OW, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs sum pooling. + + Layout: + * Input: NCW. + * Kernel: W. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.c, D.ow, D.kw) + O[D.n, D.c, D.ow] += TypeFn.cast_signed(U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW]) @linalg_structured_op -def pooling_nwc_max(I=TensorDef(T1, S.N, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KW, index_dims=[D.kw]), - O=TensorDef(U, S.N, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): - """Performs max pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. 
- """ - implements(ConvolutionOpInterface) - domain(D.n, D.ow, D.c, D.kw) - O[D.n, D.ow, D.c] = ReduceFn.max_signed[[D.kw]](TypeFn.cast_signed( - U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])) +def pooling_nwc_max( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, D.c] = ReduceFn.max_signed[[D.kw]]( + TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) + ) @linalg_structured_op -def pooling_nwc_max_unsigned(I=TensorDef(T1, S.N, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, - S.KW, - index_dims=[D.kw]), - O=TensorDef(U, S.N, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): - """Performs unsigned max pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.ow, D.c, D.kw) - O[D.n, D.ow, - D.c] = ReduceFn.max_unsigned[[D.kw]](TypeFn.cast_unsigned( - U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])) +def pooling_nwc_max_unsigned( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs unsigned max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, D.c] = ReduceFn.max_unsigned[[D.kw]]( + TypeFn.cast_unsigned(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) + ) @linalg_structured_op -def pooling_ncw_max(I=TensorDef(T1, S.N, S.C, - S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.KW, index_dims=[D.kw]), - O=TensorDef(U, S.N, S.C, S.OW, output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): - """Performs max pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.c, D.ow, D.kw) - O[D.n, D.c, D.ow] = ReduceFn.max_signed[[D.kw]](TypeFn.cast_signed( - U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW,])) +def pooling_ncw_max( + I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.C, S.OW, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.c, D.ow, D.kw) + O[D.n, D.c, D.ow] = ReduceFn.max_signed[[D.kw]]( + TypeFn.cast_signed( + U, + I[ + D.n, + D.c, + D.ow * S.SW + D.kw * S.DW, + ], + ) + ) @linalg_structured_op -def pooling_nwc_min(I=TensorDef(T1, S.N, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KW, index_dims=[D.kw]), - O=TensorDef(U, S.N, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): - """Performs min pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.ow, D.c, D.kw) - O[D.n, D.ow, D.c] = ReduceFn.min_signed[[D.kw]](TypeFn.cast_signed( - U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])) +def pooling_nwc_min( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, D.c] = ReduceFn.min_signed[[D.kw]]( + TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) + ) @linalg_structured_op -def pooling_nwc_min_unsigned(I=TensorDef(T1, S.N, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, - S.KW, - index_dims=[D.kw]), - O=TensorDef(U, S.N, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): - """Performs unsigned min pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.ow, D.c, D.kw) - O[D.n, D.ow, - D.c] = ReduceFn.min_unsigned[[D.kw]](TypeFn.cast_unsigned( - U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])) - +def pooling_nwc_min_unsigned( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs unsigned min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, D.c] = ReduceFn.min_unsigned[[D.kw]]( + TypeFn.cast_unsigned(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) + ) @linalg_structured_op -def pooling_ndhwc_sum(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, - S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, - S.KD, - S.KH, - S.KW, - index_dims=[D.kd, D.kh, D.kw]), - O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), - dilations=IndexAttrDef(S.DD, - S.DH, - S.DW, - default=[1, 1, 1])): - """Performs 3D sum pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. 
- """ - implements(ConvolutionOpInterface) - domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) - O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast_signed( - U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW, D.c]) +def pooling_ndhwc_sum( + I=TensorDef( + T1, + S.N, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + S.C, + ), + K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs 3D sum pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) + O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast_signed( + U, + I[ + D.n, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.c, + ], + ) @linalg_structured_op -def pooling_ndhwc_max(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, - S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, - S.KD, - S.KH, - S.KW, - index_dims=[D.kd, D.kh, D.kw]), - O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), - dilations=IndexAttrDef(S.DD, - S.DH, - S.DW, - default=[1, 1, 1])): - """Performs 3D max pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) - O[D.n, D.od, D.oh, D.ow, - D.c] = ReduceFn.max_signed[D.kd, D.kh, D.kw](TypeFn.cast_signed( - U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW, D.c])) +def pooling_ndhwc_max( + I=TensorDef( + T1, + S.N, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + S.C, + ), + K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs 3D max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) + O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kd, D.kh, D.kw]( + TypeFn.cast_signed( + U, + I[ + D.n, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.c, + ], + ) + ) @linalg_structured_op -def pooling_ndhwc_min(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, - S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, - S.KD, - S.KH, - S.KW, - index_dims=[D.kd, D.kh, D.kw]), - O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), - dilations=IndexAttrDef(S.DD, - S.DH, - S.DW, - default=[1, 1, 1])): - """Performs 3D min pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. 
- """ - implements(ConvolutionOpInterface) - domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) - O[D.n, D.od, D.oh, D.ow, - D.c] = ReduceFn.min_signed[D.kd, D.kh, D.kw](TypeFn.cast_signed( - U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW, D.c])) +def pooling_ndhwc_min( + I=TensorDef( + T1, + S.N, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + S.C, + ), + K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs 3D min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) + O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kd, D.kh, D.kw]( + TypeFn.cast_signed( + U, + I[ + D.n, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.c, + ], + ) + ) @linalg_structured_op def fill(value=ScalarDef(T1), O=TensorDef(U, output=True)): - """Fills the output tensor with the given value. + """Fills the output tensor with the given value. - Works for arbitrary ranked output tensors since the operation performs scalar - accesses only and is thus rank polymorphic. Numeric casting is performed on - the value operand, promoting it to the same data type as the output. - """ - implements(FillOpInterface) - defines(Canonicalizer) - O[None] = TypeFn.cast_signed(U, value) + Works for arbitrary ranked output tensors since the operation performs scalar + accesses only and is thus rank polymorphic. Numeric casting is performed on + the value operand, promoting it to the same data type as the output. + """ + implements(FillOpInterface) + defines(Canonicalizer) + O[None] = TypeFn.cast_signed(U, value) @linalg_structured_op -def fill_rng_2d(min=ScalarDef(F64), - max=ScalarDef(F64), - seed=ScalarDef(I32), - O=TensorDef(T, S.M, S.N, output=True)): - """Fills the output tensor with pseudo random numbers. - - The operation generations pseudo random numbers using a linear congruential - generator. It provides no guarantees regarding the distribution of the - generated random numbers. Instead of generating the random numbers - sequentially, it instantiates one random number generator per data element - and runs them in parallel. The seed operand and the indices of the data - element seed the random number generation. The min and max operands limit - the range of the generated random numbers. - """ - domain(D.m, D.n) - multiplier = TypeFn.cast_signed(I32, const(1103515245)) - increment = TypeFn.cast_signed(I32, const(12345)) - rand1 = (TypeFn.cast_signed(I32, index(D.m)) + seed) * multiplier + increment - rand2 = (TypeFn.cast_signed(I32, index(D.n)) + rand1) * multiplier + increment - inv_range = TypeFn.cast_signed(F64, const(2.3283064e-10)) - offset = TypeFn.cast_signed(F64, const(2147483647)) - scaling = (max - min) * inv_range - O[D.m, D.n] = TypeFn.cast_signed( - T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min) +def fill_rng_2d( + min=ScalarDef(F64), + max=ScalarDef(F64), + seed=ScalarDef(I32), + O=TensorDef(T, S.M, S.N, output=True), +): + """Fills the output tensor with pseudo random numbers. + + The operation generations pseudo random numbers using a linear congruential + generator. 
It provides no guarantees regarding the distribution of the + generated random numbers. Instead of generating the random numbers + sequentially, it instantiates one random number generator per data element + and runs them in parallel. The seed operand and the indices of the data + element seed the random number generation. The min and max operands limit + the range of the generated random numbers. + """ + domain(D.m, D.n) + multiplier = TypeFn.cast_signed(I32, const(1103515245)) + increment = TypeFn.cast_signed(I32, const(12345)) + rand1 = (TypeFn.cast_signed(I32, index(D.m)) + seed) * multiplier + increment + rand2 = (TypeFn.cast_signed(I32, index(D.n)) + rand1) * multiplier + increment + inv_range = TypeFn.cast_signed(F64, const(2.3283064e-10)) + offset = TypeFn.cast_signed(F64, const(2147483647)) + scaling = (max - min) * inv_range + O[D.m, D.n] = TypeFn.cast_signed( + T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min + ) diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py index ca0d479..980f237 100644 --- a/mlir/python/mlir/dialects/python_test.py +++ b/mlir/python/mlir/dialects/python_test.py @@ -5,6 +5,8 @@ from ._python_test_ops_gen import * from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestTensorType + def register_python_test_dialect(context, load=True): - from .._mlir_libs import _mlirPythonTest - _mlirPythonTest.register_python_test_dialect(context, load) + from .._mlir_libs import _mlirPythonTest + + _mlirPythonTest.register_python_test_dialect(context, load) diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index 78956c4..b505a49 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -6,16 +6,18 @@ from enum import Enum class FailurePropagationMode(Enum): - """Propagation mode for silenceable errors.""" - PROPAGATE = 1 - SUPPRESS = 2 + """Propagation mode for silenceable errors.""" - def _as_int(self): - if self is FailurePropagationMode.PROPAGATE: - return 1 + PROPAGATE = 1 + SUPPRESS = 2 + + def _as_int(self): + if self is FailurePropagationMode.PROPAGATE: + return 1 + + assert self is FailurePropagationMode.SUPPRESS + return 2 - assert self is FailurePropagationMode.SUPPRESS - return 2 from .._transform_ops_gen import * from ..._mlir_libs._mlirDialectsTransform import * diff --git a/mlir/python/mlir/execution_engine.py b/mlir/python/mlir/execution_engine.py index 262545b..4739231 100644 --- a/mlir/python/mlir/execution_engine.py +++ b/mlir/python/mlir/execution_engine.py @@ -7,37 +7,37 @@ from ._mlir_libs import _mlirExecutionEngine as _execution_engine import ctypes __all__ = [ - "ExecutionEngine", + "ExecutionEngine", ] -class ExecutionEngine(_execution_engine.ExecutionEngine): - def lookup(self, name): - """Lookup a function emitted with the `llvm.emit_c_interface` - attribute and returns a ctype callable. - Raise a RuntimeError if the function isn't found. - """ - func = self.raw_lookup("_mlir_ciface_" + name) - if not func: - raise RuntimeError("Unknown function " + name) - prototype = ctypes.CFUNCTYPE(None, ctypes.c_void_p) - return prototype(func) +class ExecutionEngine(_execution_engine.ExecutionEngine): + def lookup(self, name): + """Lookup a function emitted with the `llvm.emit_c_interface` + attribute and returns a ctype callable. + Raise a RuntimeError if the function isn't found. 
+ """ + func = self.raw_lookup("_mlir_ciface_" + name) + if not func: + raise RuntimeError("Unknown function " + name) + prototype = ctypes.CFUNCTYPE(None, ctypes.c_void_p) + return prototype(func) - def invoke(self, name, *ctypes_args): - """Invoke a function with the list of ctypes arguments. - All arguments must be pointers. - Raise a RuntimeError if the function isn't found. - """ - func = self.lookup(name) - packed_args = (ctypes.c_void_p * len(ctypes_args))() - for argNum in range(len(ctypes_args)): - packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p) - func(packed_args) + def invoke(self, name, *ctypes_args): + """Invoke a function with the list of ctypes arguments. + All arguments must be pointers. + Raise a RuntimeError if the function isn't found. + """ + func = self.lookup(name) + packed_args = (ctypes.c_void_p * len(ctypes_args))() + for argNum in range(len(ctypes_args)): + packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p) + func(packed_args) - def register_runtime(self, name, ctypes_callback): - """Register a runtime function available to the jitted code - under the provided `name`. The `ctypes_callback` must be a - `CFuncType` that outlives the execution engine. - """ - callback = ctypes.cast(ctypes_callback, ctypes.c_void_p) - self.raw_register_runtime("_mlir_ciface_" + name, callback) + def register_runtime(self, name, ctypes_callback): + """Register a runtime function available to the jitted code + under the provided `name`. The `ctypes_callback` must be a + `CFuncType` that outlives the execution engine. + """ + callback = ctypes.cast(ctypes_callback, ctypes.c_void_p) + self.raw_register_runtime("_mlir_ciface_" + name, callback) diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index be065d4..99c21ff 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -8,124 +8,123 @@ from ._mlir_libs._mlir.ir import _GlobalDebug # Convenience decorator for registering user-friendly Attribute builders. 
def register_attribute_builder(kind): + def decorator_builder(func): + AttrBuilder.insert(kind, func) + return func - def decorator_builder(func): - AttrBuilder.insert(kind, func) - return func - - return decorator_builder + return decorator_builder @register_attribute_builder("BoolAttr") def _boolAttr(x, context): - return BoolAttr.get(x, context=context) + return BoolAttr.get(x, context=context) @register_attribute_builder("IndexAttr") def _indexAttr(x, context): - return IntegerAttr.get(IndexType.get(context=context), x) + return IntegerAttr.get(IndexType.get(context=context), x) @register_attribute_builder("I16Attr") def _i16Attr(x, context): - return IntegerAttr.get(IntegerType.get_signless(16, context=context), x) + return IntegerAttr.get(IntegerType.get_signless(16, context=context), x) @register_attribute_builder("I32Attr") def _i32Attr(x, context): - return IntegerAttr.get(IntegerType.get_signless(32, context=context), x) + return IntegerAttr.get(IntegerType.get_signless(32, context=context), x) @register_attribute_builder("I64Attr") def _i64Attr(x, context): - return IntegerAttr.get(IntegerType.get_signless(64, context=context), x) + return IntegerAttr.get(IntegerType.get_signless(64, context=context), x) @register_attribute_builder("SI16Attr") def _si16Attr(x, context): - return IntegerAttr.get(IntegerType.get_signed(16, context=context), x) + return IntegerAttr.get(IntegerType.get_signed(16, context=context), x) @register_attribute_builder("SI32Attr") def _si32Attr(x, context): - return IntegerAttr.get(IntegerType.get_signed(32, context=context), x) + return IntegerAttr.get(IntegerType.get_signed(32, context=context), x) @register_attribute_builder("F32Attr") def _f32Attr(x, context): - return FloatAttr.get_f32(x, context=context) + return FloatAttr.get_f32(x, context=context) @register_attribute_builder("F64Attr") def _f64Attr(x, context): - return FloatAttr.get_f64(x, context=context) + return FloatAttr.get_f64(x, context=context) @register_attribute_builder("StrAttr") def _stringAttr(x, context): - return StringAttr.get(x, context=context) + return StringAttr.get(x, context=context) @register_attribute_builder("SymbolNameAttr") def _symbolNameAttr(x, context): - return StringAttr.get(x, context=context) + return StringAttr.get(x, context=context) @register_attribute_builder("SymbolRefAttr") def _symbolRefAttr(x, context): - return FlatSymbolRefAttr.get(x, context=context) + return FlatSymbolRefAttr.get(x, context=context) @register_attribute_builder("ArrayAttr") def _arrayAttr(x, context): - return ArrayAttr.get(x, context=context) + return ArrayAttr.get(x, context=context) @register_attribute_builder("I32ArrayAttr") def _i32ArrayAttr(x, context): - return ArrayAttr.get([_i32Attr(v, context) for v in x]) + return ArrayAttr.get([_i32Attr(v, context) for v in x]) @register_attribute_builder("I64ArrayAttr") def _i64ArrayAttr(x, context): - return ArrayAttr.get([_i64Attr(v, context) for v in x]) + return ArrayAttr.get([_i64Attr(v, context) for v in x]) @register_attribute_builder("F32ArrayAttr") def _f32ArrayAttr(x, context): - return ArrayAttr.get([_f32Attr(v, context) for v in x]) + return ArrayAttr.get([_f32Attr(v, context) for v in x]) @register_attribute_builder("F64ArrayAttr") def _f64ArrayAttr(x, context): - return ArrayAttr.get([_f64Attr(v, context) for v in x]) + return ArrayAttr.get([_f64Attr(v, context) for v in x]) @register_attribute_builder("DenseI64ArrayAttr") def _denseI64ArrayAttr(x, context): - return DenseI64ArrayAttr.get(x, context=context) + return 
DenseI64ArrayAttr.get(x, context=context) @register_attribute_builder("TypeAttr") def _typeAttr(x, context): - return TypeAttr.get(x, context=context) + return TypeAttr.get(x, context=context) @register_attribute_builder("TypeArrayAttr") def _typeArrayAttr(x, context): - return _arrayAttr([TypeAttr.get(t, context=context) for t in x], context) + return _arrayAttr([TypeAttr.get(t, context=context) for t in x], context) try: - import numpy as np + import numpy as np - @register_attribute_builder("IndexElementsAttr") - def _indexElementsAttr(x, context): - return DenseElementsAttr.get( - np.array(x, dtype=np.int64), - type=IndexType.get(context=context), - context=context, - ) + @register_attribute_builder("IndexElementsAttr") + def _indexElementsAttr(x, context): + return DenseElementsAttr.get( + np.array(x, dtype=np.int64), + type=IndexType.get(context=context), + context=context, + ) except ImportError: - pass + pass diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py index d709679..51433d7 100644 --- a/mlir/python/mlir/runtime/np_to_memref.py +++ b/mlir/python/mlir/runtime/np_to_memref.py @@ -9,131 +9,134 @@ import ctypes class C128(ctypes.Structure): - """A ctype representation for MLIR's Double Complex.""" - _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)] + """A ctype representation for MLIR's Double Complex.""" + + _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)] class C64(ctypes.Structure): - """A ctype representation for MLIR's Float Complex.""" - _fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)] + """A ctype representation for MLIR's Float Complex.""" + + _fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)] class F16(ctypes.Structure): - """A ctype representation for MLIR's Float16.""" - _fields_ = [("f16", ctypes.c_int16)] + """A ctype representation for MLIR's Float16.""" + + _fields_ = [("f16", ctypes.c_int16)] # https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype def as_ctype(dtp): - """Converts dtype to ctype.""" - if dtp == np.dtype(np.complex128): - return C128 - if dtp == np.dtype(np.complex64): - return C64 - if dtp == np.dtype(np.float16): - return F16 - return np.ctypeslib.as_ctypes_type(dtp) + """Converts dtype to ctype.""" + if dtp == np.dtype(np.complex128): + return C128 + if dtp == np.dtype(np.complex64): + return C64 + if dtp == np.dtype(np.float16): + return F16 + return np.ctypeslib.as_ctypes_type(dtp) def to_numpy(array): - """Converts ctypes array back to numpy dtype array.""" - if array.dtype == C128: - return array.view("complex128") - if array.dtype == C64: - return array.view("complex64") - if array.dtype == F16: - return array.view("float16") - return array + """Converts ctypes array back to numpy dtype array.""" + if array.dtype == C128: + return array.view("complex128") + if array.dtype == C64: + return array.view("complex64") + if array.dtype == F16: + return array.view("float16") + return array def make_nd_memref_descriptor(rank, dtype): + class MemRefDescriptor(ctypes.Structure): + """Builds an empty descriptor for the given rank/dtype, where rank>0.""" - class MemRefDescriptor(ctypes.Structure): - """Builds an empty descriptor for the given rank/dtype, where rank>0.""" + _fields_ = [ + ("allocated", ctypes.c_longlong), + ("aligned", ctypes.POINTER(dtype)), + ("offset", ctypes.c_longlong), + ("shape", ctypes.c_longlong * rank), + ("strides", ctypes.c_longlong * rank), + ] - _fields_ = [ - ("allocated", 
ctypes.c_longlong), - ("aligned", ctypes.POINTER(dtype)), - ("offset", ctypes.c_longlong), - ("shape", ctypes.c_longlong * rank), - ("strides", ctypes.c_longlong * rank), - ] - - return MemRefDescriptor + return MemRefDescriptor def make_zero_d_memref_descriptor(dtype): + class MemRefDescriptor(ctypes.Structure): + """Builds an empty descriptor for the given dtype, where rank=0.""" - class MemRefDescriptor(ctypes.Structure): - """Builds an empty descriptor for the given dtype, where rank=0.""" - - _fields_ = [ - ("allocated", ctypes.c_longlong), - ("aligned", ctypes.POINTER(dtype)), - ("offset", ctypes.c_longlong), - ] + _fields_ = [ + ("allocated", ctypes.c_longlong), + ("aligned", ctypes.POINTER(dtype)), + ("offset", ctypes.c_longlong), + ] - return MemRefDescriptor + return MemRefDescriptor class UnrankedMemRefDescriptor(ctypes.Structure): - """Creates a ctype struct for memref descriptor""" - _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)] + """Creates a ctype struct for memref descriptor""" + + _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)] def get_ranked_memref_descriptor(nparray): - """Returns a ranked memref descriptor for the given numpy array.""" - ctp = as_ctype(nparray.dtype) - if nparray.ndim == 0: - x = make_zero_d_memref_descriptor(ctp)() + """Returns a ranked memref descriptor for the given numpy array.""" + ctp = as_ctype(nparray.dtype) + if nparray.ndim == 0: + x = make_zero_d_memref_descriptor(ctp)() + x.allocated = nparray.ctypes.data + x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp)) + x.offset = ctypes.c_longlong(0) + return x + + x = make_nd_memref_descriptor(nparray.ndim, ctp)() x.allocated = nparray.ctypes.data x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp)) x.offset = ctypes.c_longlong(0) - return x + x.shape = nparray.ctypes.shape - x = make_nd_memref_descriptor(nparray.ndim, ctp)() - x.allocated = nparray.ctypes.data - x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp)) - x.offset = ctypes.c_longlong(0) - x.shape = nparray.ctypes.shape - - # Numpy uses byte quantities to express strides, MLIR OTOH uses the - # torch abstraction which specifies strides in terms of elements. - strides_ctype_t = ctypes.c_longlong * nparray.ndim - x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides]) - return x + # Numpy uses byte quantities to express strides, MLIR OTOH uses the + # torch abstraction which specifies strides in terms of elements. 
+ strides_ctype_t = ctypes.c_longlong * nparray.ndim + x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides]) + return x def get_unranked_memref_descriptor(nparray): - """Returns a generic/unranked memref descriptor for the given numpy array.""" - d = UnrankedMemRefDescriptor() - d.rank = nparray.ndim - x = get_ranked_memref_descriptor(nparray) - d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p) - return d + """Returns a generic/unranked memref descriptor for the given numpy array.""" + d = UnrankedMemRefDescriptor() + d.rank = nparray.ndim + x = get_ranked_memref_descriptor(nparray) + d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p) + return d def unranked_memref_to_numpy(unranked_memref, np_dtype): - """Converts unranked memrefs to numpy arrays.""" - ctp = as_ctype(np_dtype) - descriptor = make_nd_memref_descriptor(unranked_memref[0].rank, ctp) - val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor)) - np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape) - strided_arr = np.lib.stride_tricks.as_strided( - np_arr, - np.ctypeslib.as_array(val[0].shape), - np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize, - ) - return to_numpy(strided_arr) + """Converts unranked memrefs to numpy arrays.""" + ctp = as_ctype(np_dtype) + descriptor = make_nd_memref_descriptor(unranked_memref[0].rank, ctp) + val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor)) + np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape) + strided_arr = np.lib.stride_tricks.as_strided( + np_arr, + np.ctypeslib.as_array(val[0].shape), + np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize, + ) + return to_numpy(strided_arr) def ranked_memref_to_numpy(ranked_memref): - """Converts ranked memrefs to numpy arrays.""" - np_arr = np.ctypeslib.as_array( - ranked_memref[0].aligned, shape=ranked_memref[0].shape) - strided_arr = np.lib.stride_tricks.as_strided( - np_arr, - np.ctypeslib.as_array(ranked_memref[0].shape), - np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize, - ) - return to_numpy(strided_arr) + """Converts ranked memrefs to numpy arrays.""" + np_arr = np.ctypeslib.as_array( + ranked_memref[0].aligned, shape=ranked_memref[0].shape + ) + strided_arr = np.lib.stride_tricks.as_strided( + np_arr, + np.ctypeslib.as_array(ranked_memref[0].shape), + np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize, + ) + return to_numpy(strided_arr) diff --git a/mlir/test/CAPI/lit.local.cfg b/mlir/test/CAPI/lit.local.cfg index f08a0de..bb0c17c 100644 --- a/mlir/test/CAPI/lit.local.cfg +++ b/mlir/test/CAPI/lit.local.cfg @@ -1 +1 @@ -config.suffixes.add('.c') +config.suffixes.add(".c") diff --git a/mlir/test/Conversion/GPUToCUDA/lit.local.cfg b/mlir/test/Conversion/GPUToCUDA/lit.local.cfg index 847c3ef..bc470cc 100644 --- a/mlir/test/Conversion/GPUToCUDA/lit.local.cfg +++ b/mlir/test/Conversion/GPUToCUDA/lit.local.cfg @@ -1,2 +1,2 @@ if not config.run_cuda_tests: - config.unsupported = True \ No newline at end of file + config.unsupported = True diff --git a/mlir/test/Conversion/GPUToROCm/lit.local.cfg b/mlir/test/Conversion/GPUToROCm/lit.local.cfg index 6eb5617..2f5cc9f 100644 --- a/mlir/test/Conversion/GPUToROCm/lit.local.cfg +++ b/mlir/test/Conversion/GPUToROCm/lit.local.cfg @@ -1,2 +1,2 @@ if not config.run_rocm_tests: - config.unsupported = True + config.unsupported = True diff --git a/mlir/test/Examples/Toy/Ch6/lit.local.cfg b/mlir/test/Examples/Toy/Ch6/lit.local.cfg index 
c5aeb13..0d9aa10 100644 --- a/mlir/test/Examples/Toy/Ch6/lit.local.cfg +++ b/mlir/test/Examples/Toy/Ch6/lit.local.cfg @@ -1,5 +1,3 @@ # Requires native execution. -if 'host-supports-jit' not in config.available_features: +if "host-supports-jit" not in config.available_features: config.unsupported = True - - diff --git a/mlir/test/Examples/Toy/Ch7/lit.local.cfg b/mlir/test/Examples/Toy/Ch7/lit.local.cfg index c5aeb13..0d9aa10 100644 --- a/mlir/test/Examples/Toy/Ch7/lit.local.cfg +++ b/mlir/test/Examples/Toy/Ch7/lit.local.cfg @@ -1,5 +1,3 @@ # Requires native execution. -if 'host-supports-jit' not in config.available_features: +if "host-supports-jit" not in config.available_features: config.unsupported = True - - diff --git a/mlir/test/Examples/lit.local.cfg b/mlir/test/Examples/lit.local.cfg index 97db322..1a51296 100644 --- a/mlir/test/Examples/lit.local.cfg +++ b/mlir/test/Examples/lit.local.cfg @@ -1,2 +1,2 @@ if not config.build_examples: - config.unsupported = True + config.unsupported = True diff --git a/mlir/test/Examples/standalone/lit.local.cfg b/mlir/test/Examples/standalone/lit.local.cfg index cf7c8ff..fe8397c 100644 --- a/mlir/test/Examples/standalone/lit.local.cfg +++ b/mlir/test/Examples/standalone/lit.local.cfg @@ -1,13 +1,12 @@ # Disable with sanitizers for now, this require some more setup apparently. -for san in ['asan', 'msan', 'ubsan']: - if (san in config.available_features): - config.unsupported = True +for san in ["asan", "msan", "ubsan"]: + if san in config.available_features: + config.unsupported = True config.substitutions.append(("%cmake_exe", config.host_cmake)) config.substitutions.append(("%cmake_generator", config.host_cmake_generator)) config.substitutions.append(("%host_cxx", config.host_cxx)) config.substitutions.append(("%host_cc", config.host_cc)) config.substitutions.append(("%enable_libcxx", config.enable_libcxx)) -config.substitutions.append( - ("%mlir_cmake_dir", config.mlir_cmake_dir)) +config.substitutions.append(("%mlir_cmake_dir", config.mlir_cmake_dir)) config.substitutions.append(("%llvm_use_linker", config.llvm_use_linker)) diff --git a/mlir/test/Integration/Dialect/Async/CPU/lit.local.cfg b/mlir/test/Integration/Dialect/Async/CPU/lit.local.cfg index 7215eda..073f637 100644 --- a/mlir/test/Integration/Dialect/Async/CPU/lit.local.cfg +++ b/mlir/test/Integration/Dialect/Async/CPU/lit.local.cfg @@ -1,5 +1,5 @@ import sys # Windows does not have aligned_alloc -if sys.platform == 'win32': +if sys.platform == "win32": config.unsupported = True diff --git a/mlir/test/Integration/Dialect/LLVMIR/CPU/X86/lit.local.cfg b/mlir/test/Integration/Dialect/LLVMIR/CPU/X86/lit.local.cfg index 263c8f8..071a13c 100644 --- a/mlir/test/Integration/Dialect/LLVMIR/CPU/X86/lit.local.cfg +++ b/mlir/test/Integration/Dialect/LLVMIR/CPU/X86/lit.local.cfg @@ -1,4 +1,4 @@ import platform -if platform.machine() != 'x86_64': +if platform.machine() != "x86_64": config.unsupported = True diff --git a/mlir/test/Integration/Dialect/LLVMIR/CPU/lit.local.cfg b/mlir/test/Integration/Dialect/LLVMIR/CPU/lit.local.cfg index 7d1e494..3214a11 100644 --- a/mlir/test/Integration/Dialect/LLVMIR/CPU/lit.local.cfg +++ b/mlir/test/Integration/Dialect/LLVMIR/CPU/lit.local.cfg @@ -1,18 +1,22 @@ import sys -lli_cmd = 'lli' +lli_cmd = "lli" if config.riscv_emulator_lli_executable: lli_cmd = config.riscv_emulator_lli_executable -config.substitutions.append(('%mlir_native_utils_lib_dir', - config.riscv_emulator_utils_lib_dir or config.mlir_lib_dir)) +config.substitutions.append( + ( + 
"%mlir_native_utils_lib_dir", + config.riscv_emulator_utils_lib_dir or config.mlir_lib_dir, + ) +) if config.riscv_vector_emulator_executable: # Run test in qemu emulator. emulation_cmd = config.riscv_vector_emulator_executable if config.riscv_vector_emulator_options: - emulation_cmd = emulation_cmd + ' ' + config.riscv_vector_emulator_options - emulation_cmd = emulation_cmd + ' ' + lli_cmd + ' --march=riscv64 -mattr=+v ' - config.substitutions.append(('%lli', emulation_cmd)) + emulation_cmd = emulation_cmd + " " + config.riscv_vector_emulator_options + emulation_cmd = emulation_cmd + " " + lli_cmd + " --march=riscv64 -mattr=+v " + config.substitutions.append(("%lli", emulation_cmd)) else: - config.substitutions.append(('%lli', lli_cmd)) + config.substitutions.append(("%lli", lli_cmd)) diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/lit.local.cfg b/mlir/test/Integration/Dialect/SparseTensor/CPU/lit.local.cfg index 9bf49cc..6e07eb8 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/lit.local.cfg +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/lit.local.cfg @@ -2,14 +2,16 @@ import sys from lit.llvm import llvm_config # FIXME: %mlir_native_utils_lib_dir is set incorrectly on Windows -if sys.platform == 'win32': +if sys.platform == "win32": config.unsupported = True # ArmSVE tests must be enabled via build flag. if config.mlir_run_arm_sve_tests: - config.substitutions.append(('%ENABLE_VLA', 'true')) - config.substitutions.append(('%VLA_ARCH_ATTR_OPTIONS', '--march=aarch64 --mattr="+sve"')) + config.substitutions.append(("%ENABLE_VLA", "true")) + config.substitutions.append( + ("%VLA_ARCH_ATTR_OPTIONS", '--march=aarch64 --mattr="+sve"') + ) else: - config.substitutions.append(('%ENABLE_VLA', 'false')) - config.substitutions.append(('%VLA_ARCH_ATTR_OPTIONS', '')) - config.substitutions.append(('%mlir_native_utils_lib_dir', config.mlir_lib_dir)) + config.substitutions.append(("%ENABLE_VLA", "false")) + config.substitutions.append(("%VLA_ARCH_ATTR_OPTIONS", "")) + config.substitutions.append(("%mlir_native_utils_lib_dir", config.mlir_lib_dir)) diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/lit.local.cfg b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/lit.local.cfg index c586aae..6788cce 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/lit.local.cfg +++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/lit.local.cfg @@ -1,2 +1,2 @@ if not config.enable_cuda_runner or not config.mlir_run_cuda_sm80_tests: - config.unsupported = True + config.unsupported = True diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/lit.local.cfg b/mlir/test/Integration/Dialect/SparseTensor/python/lit.local.cfg index cf04454..361b657 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/python/lit.local.cfg +++ b/mlir/test/Integration/Dialect/SparseTensor/python/lit.local.cfg @@ -1,5 +1,5 @@ # Disable ASAN's leak detection for python OpsDSL tests. -config.environment['ASAN_OPTIONS'] = 'detect_leaks=0' +config.environment["ASAN_OPTIONS"] = "detect_leaks=0" # Only run when python bindings are enabled. 
if not config.enable_bindings_python: - config.unsupported = True + config.unsupported = True diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py index 958aa86..1f9b636 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py +++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py @@ -18,42 +18,45 @@ _SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__)) sys.path.append(_SCRIPT_PATH) from tools import sparse_compiler + @dsl.linalg_structured_op def sddmm_dsl( A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K), B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N), S=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N), - C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)): - C[dsl.D.m, - dsl.D.n] += S[dsl.D.m, dsl.D.n] * A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n] + C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True), +): + C[dsl.D.m, dsl.D.n] += ( + S[dsl.D.m, dsl.D.n] * A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n] + ) def build_SDDMM(attr: st.EncodingAttr): - """Build SDDMM kernel. + """Build SDDMM kernel. - This method generates a linalg op with for matrix multiplication using - just the Python API. Effectively, a generic linalg op is constructed - that computes C(i,j) += S(i,j) SUM_k A(i,k) B(k,j) for sparse S. - """ - module = ir.Module.create() - f64 = ir.F64Type.get() - a = ir.RankedTensorType.get([8, 8], f64) - b = ir.RankedTensorType.get([8, 8], f64) - c = ir.RankedTensorType.get([8, 8], f64) - s = ir.RankedTensorType.get([8, 8], f64, attr) - arguments = [a, b, s, c] - with ir.InsertionPoint(module.body): + This method generates a linalg op with for matrix multiplication using + just the Python API. Effectively, a generic linalg op is constructed + that computes C(i,j) += S(i,j) SUM_k A(i,k) B(k,j) for sparse S. + """ + module = ir.Module.create() + f64 = ir.F64Type.get() + a = ir.RankedTensorType.get([8, 8], f64) + b = ir.RankedTensorType.get([8, 8], f64) + c = ir.RankedTensorType.get([8, 8], f64) + s = ir.RankedTensorType.get([8, 8], f64, attr) + arguments = [a, b, s, c] + with ir.InsertionPoint(module.body): - @func.FuncOp.from_py_func(*arguments) - def sddmm(*args): - return sddmm_dsl(args[0], args[1], args[2], outs=[args[3]]) + @func.FuncOp.from_py_func(*arguments) + def sddmm(*args): + return sddmm_dsl(args[0], args[1], args[2], outs=[args[3]]) - return module + return module def boilerplate(attr: st.EncodingAttr): - """Returns boilerplate code for main driver.""" - return f""" + """Returns boilerplate code for main driver.""" + return f""" func.func @main(%a: tensor<8x8xf64>, %b: tensor<8x8xf64>, %c: tensor<8x8xf64>) -> tensor<8x8xf64> attributes {{ llvm.emit_c_interface }} {{ @@ -69,92 +72,100 @@ func.func @main(%a: tensor<8x8xf64>, def build_compile_and_run_SDDMMM(attr: st.EncodingAttr, compiler): - # Build. - module = build_SDDMM(attr) - func = str(module.operation.regions[0].blocks[0].operations[0].operation) - module = ir.Module.parse(func + boilerplate(attr)) - - # Compile. - engine = compiler.compile_and_jit(module) - - # Set up numpy input and buffer for output. 
- a = np.array([[1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1], - [1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2], - [1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3], - [1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4], - [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5], - [1.6, 2.6, 3.6, 4.6, 5.6, 6.6, 7.6, 8.6], - [1.7, 2.7, 3.7, 4.7, 5.7, 6.7, 7.7, 8.7], - [1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8]], np.float64) - b = np.ones((8, 8), np.float64) - c = np.zeros((8, 8), np.float64) - - mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a))) - mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b))) - mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c))) - - # Allocate a MemRefDescriptor to receive the output tensor. - # The buffer itself is allocated inside the MLIR code generation. - ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)() - mem_out = ctypes.pointer(ctypes.pointer(ref_out)) - - # Invoke the kernel and get numpy output. - # Built-in bufferization uses in-out buffers. - # TODO: replace with inplace comprehensive bufferization. - engine.invoke('main', mem_out, mem_a, mem_b, mem_c) - - # Sanity check on computed result. Only a few elements - # are sampled from the full dense matrix multiplication. - full_matmul = np.matmul(a, b) - expected = np.zeros((8, 8), np.float64) - expected[0, 0] = 1.0 * full_matmul[0, 0] - expected[0, 2] = 2.0 * full_matmul[0, 2] - expected[4, 1] = 3.0 * full_matmul[4, 1] - c = rt.ranked_memref_to_numpy(mem_out[0]) - if np.allclose(c, expected): - pass - else: - quit(f'FAILURE') + # Build. + module = build_SDDMM(attr) + func = str(module.operation.regions[0].blocks[0].operations[0].operation) + module = ir.Module.parse(func + boilerplate(attr)) + + # Compile. + engine = compiler.compile_and_jit(module) + + # Set up numpy input and buffer for output. + a = np.array( + [ + [1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1], + [1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2], + [1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3], + [1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4], + [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5], + [1.6, 2.6, 3.6, 4.6, 5.6, 6.6, 7.6, 8.6], + [1.7, 2.7, 3.7, 4.7, 5.7, 6.7, 7.7, 8.7], + [1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8], + ], + np.float64, + ) + b = np.ones((8, 8), np.float64) + c = np.zeros((8, 8), np.float64) + + mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a))) + mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b))) + mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c))) + + # Allocate a MemRefDescriptor to receive the output tensor. + # The buffer itself is allocated inside the MLIR code generation. + ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)() + mem_out = ctypes.pointer(ctypes.pointer(ref_out)) + + # Invoke the kernel and get numpy output. + # Built-in bufferization uses in-out buffers. + # TODO: replace with inplace comprehensive bufferization. + engine.invoke("main", mem_out, mem_a, mem_b, mem_c) + + # Sanity check on computed result. Only a few elements + # are sampled from the full dense matrix multiplication. 
+ full_matmul = np.matmul(a, b) + expected = np.zeros((8, 8), np.float64) + expected[0, 0] = 1.0 * full_matmul[0, 0] + expected[0, 2] = 2.0 * full_matmul[0, 2] + expected[4, 1] = 3.0 * full_matmul[4, 1] + c = rt.ranked_memref_to_numpy(mem_out[0]) + if np.allclose(c, expected): + pass + else: + quit(f"FAILURE") def main(): - support_lib = os.getenv('SUPPORT_LIB') - assert support_lib is not None, 'SUPPORT_LIB is undefined' - if not os.path.exists(support_lib): - raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), - support_lib) - - # CHECK-LABEL: TEST: testSDDMMM - print('\nTEST: testSDDMMM') - with ir.Context() as ctx, ir.Location.unknown(): - count = 0 - # Loop over various ways to compile and annotate the SDDMM kernel with - # a *single* sparse tensor. Note that we deliberate do not exhaustively - # search the full state space to reduce runtime of the test. It is - # straightforward to adapt the code below to explore more combinations. - levels = [[st.DimLevelType.dense, st.DimLevelType.dense], - [st.DimLevelType.dense, st.DimLevelType.compressed], - [st.DimLevelType.compressed, st.DimLevelType.dense], - [st.DimLevelType.compressed, st.DimLevelType.compressed]] - orderings = [ - ir.AffineMap.get_permutation([0, 1]), - ir.AffineMap.get_permutation([1, 0]) - ] - for level in levels: - for ordering in orderings: - for pwidth in [32]: - for iwidth in [32]: - for e in [True]: - attr = st.EncodingAttr.get(level, ordering, None, pwidth, - iwidth) - opt = (f'parallelization-strategy=none') - compiler = sparse_compiler.SparseCompiler( - options=opt, opt_level=0, shared_libs=[support_lib]) - build_compile_and_run_SDDMMM(attr, compiler) - count = count + 1 - # CHECK: Passed 8 tests - print('Passed ', count, 'tests') - - -if __name__ == '__main__': - main() + support_lib = os.getenv("SUPPORT_LIB") + assert support_lib is not None, "SUPPORT_LIB is undefined" + if not os.path.exists(support_lib): + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib) + + # CHECK-LABEL: TEST: testSDDMMM + print("\nTEST: testSDDMMM") + with ir.Context() as ctx, ir.Location.unknown(): + count = 0 + # Loop over various ways to compile and annotate the SDDMM kernel with + # a *single* sparse tensor. Note that we deliberate do not exhaustively + # search the full state space to reduce runtime of the test. It is + # straightforward to adapt the code below to explore more combinations. 
+ levels = [ + [st.DimLevelType.dense, st.DimLevelType.dense], + [st.DimLevelType.dense, st.DimLevelType.compressed], + [st.DimLevelType.compressed, st.DimLevelType.dense], + [st.DimLevelType.compressed, st.DimLevelType.compressed], + ] + orderings = [ + ir.AffineMap.get_permutation([0, 1]), + ir.AffineMap.get_permutation([1, 0]), + ] + for level in levels: + for ordering in orderings: + for pwidth in [32]: + for iwidth in [32]: + for e in [True]: + attr = st.EncodingAttr.get( + level, ordering, None, pwidth, iwidth + ) + opt = f"parallelization-strategy=none" + compiler = sparse_compiler.SparseCompiler( + options=opt, opt_level=0, shared_libs=[support_lib] + ) + build_compile_and_run_SDDMMM(attr, compiler) + count = count + 1 + # CHECK: Passed 8 tests + print("Passed ", count, "tests") + + +if __name__ == "__main__": + main() diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py index 97954ce..69f6cdc 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py +++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py @@ -18,45 +18,47 @@ _SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__)) sys.path.append(_SCRIPT_PATH) from tools import sparse_compiler + @dsl.linalg_structured_op def matmul_dsl( A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K), B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N), - C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)): - C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n] + C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True), +): + C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n] def build_SpMM(attr: st.EncodingAttr): - """Build SpMM kernel. + """Build SpMM kernel. - This method generates a linalg op with for matrix multiplication using - just the Python API. Effectively, a generic linalg op is constructed - that computes C(i,j) += A(i,k) * B(k,j) for annotated matrix A. - """ - module = ir.Module.create() - f64 = ir.F64Type.get() - a = ir.RankedTensorType.get([3, 4], f64, attr) - b = ir.RankedTensorType.get([4, 2], f64) - c = ir.RankedTensorType.get([3, 2], f64) - arguments = [a, b, c] - with ir.InsertionPoint(module.body): + This method generates a linalg op with for matrix multiplication using + just the Python API. Effectively, a generic linalg op is constructed + that computes C(i,j) += A(i,k) * B(k,j) for annotated matrix A. + """ + module = ir.Module.create() + f64 = ir.F64Type.get() + a = ir.RankedTensorType.get([3, 4], f64, attr) + b = ir.RankedTensorType.get([4, 2], f64) + c = ir.RankedTensorType.get([3, 2], f64) + arguments = [a, b, c] + with ir.InsertionPoint(module.body): - @func.FuncOp.from_py_func(*arguments) - def spMxM(*args): - return matmul_dsl(args[0], args[1], outs=[args[2]]) + @func.FuncOp.from_py_func(*arguments) + def spMxM(*args): + return matmul_dsl(args[0], args[1], outs=[args[2]]) - return module + return module def boilerplate(attr: st.EncodingAttr): - """Returns boilerplate main method. - - This method sets up a boilerplate main method that takes three tensors - (a, b, c), converts the first tensor a into s sparse tensor, and then - calls the sparse kernel for matrix multiplication. For convenience, - this part is purely done as string input. - """ - return f""" + """Returns boilerplate main method. + + This method sets up a boilerplate main method that takes three tensors + (a, b, c), converts the first tensor a into s sparse tensor, and then + calls the sparse kernel for matrix multiplication. 
For convenience, + this part is purely done as string input. + """ + return f""" func.func @main(%ad: tensor<3x4xf64>, %b: tensor<4x2xf64>, %c: tensor<3x2xf64>) -> tensor<3x2xf64> attributes {{ llvm.emit_c_interface }} {{ %a = sparse_tensor.convert %ad : tensor<3x4xf64> to tensor<3x4xf64, {attr}> @@ -69,82 +71,87 @@ func.func @main(%ad: tensor<3x4xf64>, %b: tensor<4x2xf64>, %c: tensor<3x2xf64>) def build_compile_and_run_SpMM(attr: st.EncodingAttr, compiler): - # Build. - module = build_SpMM(attr) - func = str(module.operation.regions[0].blocks[0].operations[0].operation) - module = ir.Module.parse(func + boilerplate(attr)) - - # Compile. - engine = compiler.compile_and_jit(module) - - # Set up numpy input and buffer for output. - a = np.array( - [[1.1, 0.0, 0.0, 1.4], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 3.3, 0.0]], - np.float64) - b = np.array([[1.0, 2.0], [4.0, 3.0], [5.0, 6.0], [8.0, 7.0]], np.float64) - c = np.zeros((3, 2), np.float64) - - mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a))) - mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b))) - mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c))) - # Allocate a MemRefDescriptor to receive the output tensor. - # The buffer itself is allocated inside the MLIR code generation. - ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)() - mem_out = ctypes.pointer(ctypes.pointer(ref_out)) - - # Invoke the kernel and get numpy output. - # Built-in bufferization uses in-out buffers. - # TODO: replace with inplace comprehensive bufferization. - engine.invoke('main', mem_out, mem_a, mem_b, mem_c) - - # Sanity check on computed result. - expected = np.matmul(a, b); - c = rt.ranked_memref_to_numpy(mem_out[0]) - if np.allclose(c, expected): - pass - else: - quit(f'FAILURE') + # Build. + module = build_SpMM(attr) + func = str(module.operation.regions[0].blocks[0].operations[0].operation) + module = ir.Module.parse(func + boilerplate(attr)) + + # Compile. + engine = compiler.compile_and_jit(module) + + # Set up numpy input and buffer for output. + a = np.array( + [[1.1, 0.0, 0.0, 1.4], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 3.3, 0.0]], np.float64 + ) + b = np.array([[1.0, 2.0], [4.0, 3.0], [5.0, 6.0], [8.0, 7.0]], np.float64) + c = np.zeros((3, 2), np.float64) + + mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a))) + mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b))) + mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c))) + # Allocate a MemRefDescriptor to receive the output tensor. + # The buffer itself is allocated inside the MLIR code generation. + ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)() + mem_out = ctypes.pointer(ctypes.pointer(ref_out)) + + # Invoke the kernel and get numpy output. + # Built-in bufferization uses in-out buffers. + # TODO: replace with inplace comprehensive bufferization. + engine.invoke("main", mem_out, mem_a, mem_b, mem_c) + + # Sanity check on computed result. 
+ expected = np.matmul(a, b) + c = rt.ranked_memref_to_numpy(mem_out[0]) + if np.allclose(c, expected): + pass + else: + quit(f"FAILURE") def main(): - support_lib = os.getenv('SUPPORT_LIB') - assert support_lib is not None, 'SUPPORT_LIB is undefined' - if not os.path.exists(support_lib): - raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib) - - # CHECK-LABEL: TEST: testSpMM - print('\nTEST: testSpMM') - with ir.Context() as ctx, ir.Location.unknown(): - count = 0 - # Loop over various ways to compile and annotate the SpMM kernel with - # a *single* sparse tensor. Note that we deliberate do not exhaustively - # search the full state space to reduce runtime of the test. It is - # straightforward to adapt the code below to explore more combinations. - - vl = 1 - e = False - opt = (f'parallelization-strategy=none') - levels = [[st.DimLevelType.dense, st.DimLevelType.dense], - [st.DimLevelType.dense, st.DimLevelType.compressed], - [st.DimLevelType.compressed, st.DimLevelType.dense], - [st.DimLevelType.compressed, st.DimLevelType.compressed]] - orderings = [ - ir.AffineMap.get_permutation([0, 1]), - ir.AffineMap.get_permutation([1, 0]) - ] - bitwidths = [0] - compiler = sparse_compiler.SparseCompiler( - options=opt, opt_level=0, shared_libs=[support_lib]) - for level in levels: - for ordering in orderings: - for pwidth in bitwidths: - for iwidth in bitwidths: - attr = st.EncodingAttr.get(level, ordering, None, pwidth, iwidth) - build_compile_and_run_SpMM(attr, compiler) - count = count + 1 - # CHECK: Passed 8 tests - print('Passed ', count, 'tests') - - -if __name__ == '__main__': - main() + support_lib = os.getenv("SUPPORT_LIB") + assert support_lib is not None, "SUPPORT_LIB is undefined" + if not os.path.exists(support_lib): + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib) + + # CHECK-LABEL: TEST: testSpMM + print("\nTEST: testSpMM") + with ir.Context() as ctx, ir.Location.unknown(): + count = 0 + # Loop over various ways to compile and annotate the SpMM kernel with + # a *single* sparse tensor. Note that we deliberate do not exhaustively + # search the full state space to reduce runtime of the test. It is + # straightforward to adapt the code below to explore more combinations. 
+ + vl = 1 + e = False + opt = f"parallelization-strategy=none" + levels = [ + [st.DimLevelType.dense, st.DimLevelType.dense], + [st.DimLevelType.dense, st.DimLevelType.compressed], + [st.DimLevelType.compressed, st.DimLevelType.dense], + [st.DimLevelType.compressed, st.DimLevelType.compressed], + ] + orderings = [ + ir.AffineMap.get_permutation([0, 1]), + ir.AffineMap.get_permutation([1, 0]), + ] + bitwidths = [0] + compiler = sparse_compiler.SparseCompiler( + options=opt, opt_level=0, shared_libs=[support_lib] + ) + for level in levels: + for ordering in orderings: + for pwidth in bitwidths: + for iwidth in bitwidths: + attr = st.EncodingAttr.get( + level, ordering, None, pwidth, iwidth + ) + build_compile_and_run_SpMM(attr, compiler) + count = count + 1 + # CHECK: Passed 8 tests + print("Passed ", count, "tests") + + +if __name__ == "__main__": + main() diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py index b29b029..a41bde1 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py +++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py @@ -57,49 +57,52 @@ func.func @main(%ad: tensor<3x4xf64>, %bd: tensor<3x4xf64>) -> tensor<3x4xf64, # def _run_test(support_lib, kernel): - """Compiles, runs and checks results.""" - compiler = sparse_compiler.SparseCompiler( - options='', opt_level=2, shared_libs=[support_lib]) - module = ir.Module.parse(kernel) - engine = compiler.compile_and_jit(module) - - # Set up numpy inputs and buffer for output. - a = np.array( - [[1.1, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 6.6, 0.0]], - np.float64) - b = np.array( - [[1.1, 0.0, 0.0, 2.8], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - np.float64) - - mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a))) - mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b))) - - # The sparse tensor output is a pointer to pointer of char. - out = ctypes.c_char(0) - mem_out = ctypes.pointer(ctypes.pointer(out)) - - # Invoke the kernel. - engine.invoke('main', mem_a, mem_b, mem_out) - - # Retrieve and check the result. - rank, nse, shape, values, indices = test_tools.sparse_tensor_to_coo_tensor( - support_lib, mem_out[0], np.float64) - - # CHECK: PASSED - if np.allclose(values, [2.2, 2.8, 6.6]) and np.allclose( - indices, [[0, 0], [0, 3], [2, 2]]): - print('PASSED') - else: - quit('FAILURE') + """Compiles, runs and checks results.""" + compiler = sparse_compiler.SparseCompiler( + options="", opt_level=2, shared_libs=[support_lib] + ) + module = ir.Module.parse(kernel) + engine = compiler.compile_and_jit(module) + + # Set up numpy inputs and buffer for output. + a = np.array( + [[1.1, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 6.6, 0.0]], np.float64 + ) + b = np.array( + [[1.1, 0.0, 0.0, 2.8], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], np.float64 + ) + + mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a))) + mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b))) + + # The sparse tensor output is a pointer to pointer of char. + out = ctypes.c_char(0) + mem_out = ctypes.pointer(ctypes.pointer(out)) + + # Invoke the kernel. + engine.invoke("main", mem_a, mem_b, mem_out) + + # Retrieve and check the result. 
+ rank, nse, shape, values, indices = test_tools.sparse_tensor_to_coo_tensor( + support_lib, mem_out[0], np.float64 + ) + + # CHECK: PASSED + if np.allclose(values, [2.2, 2.8, 6.6]) and np.allclose( + indices, [[0, 0], [0, 3], [2, 2]] + ): + print("PASSED") + else: + quit("FAILURE") def test_elementwise_add(): - # Obtain path to runtime support library. - support_lib = os.getenv('SUPPORT_LIB') - assert support_lib is not None, 'SUPPORT_LIB is undefined' - assert os.path.exists(support_lib), f'{support_lib} does not exist' - with ir.Context() as ctx, ir.Location.unknown(): - _run_test(support_lib, _KERNEL_STR) + # Obtain path to runtime support library. + support_lib = os.getenv("SUPPORT_LIB") + assert support_lib is not None, "SUPPORT_LIB is undefined" + assert os.path.exists(support_lib), f"{support_lib} does not exist" + with ir.Context() as ctx, ir.Location.unknown(): + _run_test(support_lib, _KERNEL_STR) test_elementwise_add() diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py index 7d57b1c..7d77490 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py +++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_output.py @@ -18,8 +18,8 @@ from tools import sparse_compiler # TODO: move more into actual IR building. def boilerplate(attr: st.EncodingAttr): - """Returns boilerplate main method.""" - return f""" + """Returns boilerplate main method.""" + return f""" func.func @main(%p : !llvm.ptr) -> () attributes {{ llvm.emit_c_interface }} {{ %d = arith.constant sparse<[[0, 0], [1, 1], [0, 9], [9, 0], [4, 4]], [1.0, 2.0, 3.0, 4.0, 5.0]> : tensor<10x10xf64> @@ -31,13 +31,13 @@ func.func @main(%p : !llvm.ptr) -> () attributes {{ llvm.emit_c_interface }} def expected(): - """Returns expected contents of output. + """Returns expected contents of output. - Regardless of the dimension ordering, compression, and bitwidths that are - used in the sparse tensor, the output is always lexicographically sorted - by natural index order. - """ - return f"""; extended FROSTT format + Regardless of the dimension ordering, compression, and bitwidths that are + used in the sparse tensor, the output is always lexicographically sorted + by natural index order. + """ + return f"""; extended FROSTT format 2 5 10 10 1 1 1 @@ -49,53 +49,55 @@ def expected(): def build_compile_and_run_output(attr: st.EncodingAttr, compiler): - # Build and Compile. - module = ir.Module.parse(boilerplate(attr)) - engine = compiler.compile_and_jit(module) + # Build and Compile. + module = ir.Module.parse(boilerplate(attr)) + engine = compiler.compile_and_jit(module) - # Invoke the kernel and compare output. - with tempfile.TemporaryDirectory() as test_dir: - out = os.path.join(test_dir, 'out.tns') - buf = out.encode('utf-8') - mem_a = ctypes.pointer(ctypes.pointer(ctypes.create_string_buffer(buf))) - engine.invoke('main', mem_a) + # Invoke the kernel and compare output. 
+ with tempfile.TemporaryDirectory() as test_dir: + out = os.path.join(test_dir, "out.tns") + buf = out.encode("utf-8") + mem_a = ctypes.pointer(ctypes.pointer(ctypes.create_string_buffer(buf))) + engine.invoke("main", mem_a) - actual = open(out).read() - if actual != expected(): - quit('FAILURE') + actual = open(out).read() + if actual != expected(): + quit("FAILURE") def main(): - support_lib = os.getenv('SUPPORT_LIB') - assert support_lib is not None, 'SUPPORT_LIB is undefined' - if not os.path.exists(support_lib): - raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), - support_lib) - - # CHECK-LABEL: TEST: test_output - print('\nTEST: test_output') - count = 0 - with ir.Context() as ctx, ir.Location.unknown(): - # Loop over various sparse types: CSR, DCSR, CSC, DCSC. - levels = [[st.DimLevelType.dense, st.DimLevelType.compressed], - [st.DimLevelType.compressed, st.DimLevelType.compressed]] - orderings = [ - ir.AffineMap.get_permutation([0, 1]), - ir.AffineMap.get_permutation([1, 0]) - ] - bitwidths = [8, 16, 32, 64] - compiler = sparse_compiler.SparseCompiler( - options='', opt_level=2, shared_libs=[support_lib]) - for level in levels: - for ordering in orderings: - for bwidth in bitwidths: - attr = st.EncodingAttr.get(level, ordering, None, bwidth, bwidth) - build_compile_and_run_output(attr, compiler) - count = count + 1 - - # CHECK: Passed 16 tests - print('Passed', count, 'tests') - - -if __name__ == '__main__': - main() + support_lib = os.getenv("SUPPORT_LIB") + assert support_lib is not None, "SUPPORT_LIB is undefined" + if not os.path.exists(support_lib): + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib) + + # CHECK-LABEL: TEST: test_output + print("\nTEST: test_output") + count = 0 + with ir.Context() as ctx, ir.Location.unknown(): + # Loop over various sparse types: CSR, DCSR, CSC, DCSC. + levels = [ + [st.DimLevelType.dense, st.DimLevelType.compressed], + [st.DimLevelType.compressed, st.DimLevelType.compressed], + ] + orderings = [ + ir.AffineMap.get_permutation([0, 1]), + ir.AffineMap.get_permutation([1, 0]), + ] + bitwidths = [8, 16, 32, 64] + compiler = sparse_compiler.SparseCompiler( + options="", opt_level=2, shared_libs=[support_lib] + ) + for level in levels: + for ordering in orderings: + for bwidth in bitwidths: + attr = st.EncodingAttr.get(level, ordering, None, bwidth, bwidth) + build_compile_and_run_output(attr, compiler) + count = count + 1 + + # CHECK: Passed 16 tests + print("Passed", count, "tests") + + +if __name__ == "__main__": + main() diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py index 3a04e5b..373f745 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py +++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py @@ -28,216 +28,241 @@ from tools import sparse_compiler # TODO: move this boilerplate to its own module, so it can be used by # other tests and programs. class TypeConverter: - """Converter between NumPy types and MLIR types.""" - - def __init__(self, context: ir.Context): - # Note 1: these are numpy "scalar types" (i.e., the values of - # np.sctypeDict) not numpy "dtypes" (i.e., the np.dtype class). - # - # Note 2: we must construct the MLIR types in the same context as the - # types that'll be passed to irtype_to_sctype() or irtype_to_dtype(); - # otherwise, those methods will raise a KeyError. 
- types_list = [ - (np.float64, ir.F64Type.get(context=context)), - (np.float32, ir.F32Type.get(context=context)), - (np.int64, ir.IntegerType.get_signless(64, context=context)), - (np.int32, ir.IntegerType.get_signless(32, context=context)), - (np.int16, ir.IntegerType.get_signless(16, context=context)), - (np.int8, ir.IntegerType.get_signless(8, context=context)), - ] - self._sc2ir = dict(types_list) - self._ir2sc = dict(( (ir,sc) for sc,ir in types_list )) - - def dtype_to_irtype(self, dtype: np.dtype) -> ir.Type: - """Returns the MLIR equivalent of a NumPy dtype.""" - try: - return self.sctype_to_irtype(dtype.type) - except KeyError as e: - raise KeyError(f'Unknown dtype: {dtype}') from e - - def sctype_to_irtype(self, sctype) -> ir.Type: - """Returns the MLIR equivalent of a NumPy scalar type.""" - if sctype in self._sc2ir: - return self._sc2ir[sctype] - else: - raise KeyError(f'Unknown sctype: {sctype}') - - def irtype_to_dtype(self, tp: ir.Type) -> np.dtype: - """Returns the NumPy dtype equivalent of an MLIR type.""" - return np.dtype(self.irtype_to_sctype(tp)) - - def irtype_to_sctype(self, tp: ir.Type): - """Returns the NumPy scalar-type equivalent of an MLIR type.""" - if tp in self._ir2sc: - return self._ir2sc[tp] - else: - raise KeyError(f'Unknown ir.Type: {tp}') - - def get_RankedTensorType_of_nparray(self, nparray: np.ndarray) -> ir.RankedTensorType: - """Returns the ir.RankedTensorType of a NumPy array. Note that NumPy - arrays can only be converted to/from dense tensors, not sparse tensors.""" - # TODO: handle strides as well? - return ir.RankedTensorType.get(nparray.shape, - self.dtype_to_irtype(nparray.dtype)) + """Converter between NumPy types and MLIR types.""" + + def __init__(self, context: ir.Context): + # Note 1: these are numpy "scalar types" (i.e., the values of + # np.sctypeDict) not numpy "dtypes" (i.e., the np.dtype class). + # + # Note 2: we must construct the MLIR types in the same context as the + # types that'll be passed to irtype_to_sctype() or irtype_to_dtype(); + # otherwise, those methods will raise a KeyError. + types_list = [ + (np.float64, ir.F64Type.get(context=context)), + (np.float32, ir.F32Type.get(context=context)), + (np.int64, ir.IntegerType.get_signless(64, context=context)), + (np.int32, ir.IntegerType.get_signless(32, context=context)), + (np.int16, ir.IntegerType.get_signless(16, context=context)), + (np.int8, ir.IntegerType.get_signless(8, context=context)), + ] + self._sc2ir = dict(types_list) + self._ir2sc = dict(((ir, sc) for sc, ir in types_list)) + + def dtype_to_irtype(self, dtype: np.dtype) -> ir.Type: + """Returns the MLIR equivalent of a NumPy dtype.""" + try: + return self.sctype_to_irtype(dtype.type) + except KeyError as e: + raise KeyError(f"Unknown dtype: {dtype}") from e + + def sctype_to_irtype(self, sctype) -> ir.Type: + """Returns the MLIR equivalent of a NumPy scalar type.""" + if sctype in self._sc2ir: + return self._sc2ir[sctype] + else: + raise KeyError(f"Unknown sctype: {sctype}") + + def irtype_to_dtype(self, tp: ir.Type) -> np.dtype: + """Returns the NumPy dtype equivalent of an MLIR type.""" + return np.dtype(self.irtype_to_sctype(tp)) + + def irtype_to_sctype(self, tp: ir.Type): + """Returns the NumPy scalar-type equivalent of an MLIR type.""" + if tp in self._ir2sc: + return self._ir2sc[tp] + else: + raise KeyError(f"Unknown ir.Type: {tp}") + + def get_RankedTensorType_of_nparray( + self, nparray: np.ndarray + ) -> ir.RankedTensorType: + """Returns the ir.RankedTensorType of a NumPy array. 
Note that NumPy + arrays can only be converted to/from dense tensors, not sparse tensors.""" + # TODO: handle strides as well? + return ir.RankedTensorType.get( + nparray.shape, self.dtype_to_irtype(nparray.dtype) + ) + # ===----------------------------------------------------------------------=== # + class StressTest: - def __init__(self, tyconv: TypeConverter): - self._tyconv = tyconv - self._roundtripTp = None - self._module = None - self._engine = None - - def _assertEqualsRoundtripTp(self, tp: ir.RankedTensorType): - assert self._roundtripTp is not None, \ - 'StressTest: uninitialized roundtrip type' - if tp != self._roundtripTp: - raise AssertionError( - f"Type is not equal to the roundtrip type.\n" - f"\tExpected: {self._roundtripTp}\n" - f"\tFound: {tp}\n") - - def build(self, types: List[ir.Type]): - """Builds the ir.Module. The module has only the @main function, - which will convert the input through the list of types and then back - to the initial type. The roundtrip type must be a dense tensor.""" - assert self._module is None, 'StressTest: must not call build() repeatedly' - self._module = ir.Module.create() - with ir.InsertionPoint(self._module.body): - tp0 = types.pop(0) - self._roundtripTp = tp0 - # TODO: assert dense? assert element type is recognised by the TypeConverter? - types.append(tp0) - funcTp = ir.FunctionType.get(inputs=[tp0], results=[tp0]) - funcOp = func.FuncOp(name='main', type=funcTp) - funcOp.attributes['llvm.emit_c_interface'] = ir.UnitAttr.get() - with ir.InsertionPoint(funcOp.add_entry_block()): - arg0 = funcOp.entry_block.arguments[0] - self._assertEqualsRoundtripTp(arg0.type) - v = st.ConvertOp(types.pop(0), arg0) - for tp in types: - w = st.ConvertOp(tp, v) - # Release intermediate tensors before they fall out of scope. - bufferization.DeallocTensorOp(v.result) - v = w - self._assertEqualsRoundtripTp(v.result.type) - func.ReturnOp(v) - return self - - def writeTo(self, filename): - """Write the ir.Module to the given file. If the file already exists, - then raises an error. If the filename is None, then is a no-op.""" - assert self._module is not None, \ - 'StressTest: must call build() before writeTo()' - if filename is None: - # Silent no-op, for convenience. 
- return self - if os.path.exists(filename): - raise FileExistsError(errno.EEXIST, os.strerror(errno.EEXIST), filename) - with open(filename, 'w') as f: - f.write(str(self._module)) - return self - - def compile(self, compiler): - """Compile the ir.Module.""" - assert self._module is not None, \ - 'StressTest: must call build() before compile()' - assert self._engine is None, \ - 'StressTest: must not call compile() repeatedly' - self._engine = compiler.compile_and_jit(self._module) - return self - - def run(self, np_arg0: np.ndarray) -> np.ndarray: - """Runs the test on the given numpy array, and returns the resulting - numpy array.""" - assert self._engine is not None, \ - 'StressTest: must call compile() before run()' - self._assertEqualsRoundtripTp( - self._tyconv.get_RankedTensorType_of_nparray(np_arg0)) - np_out = np.zeros(np_arg0.shape, dtype=np_arg0.dtype) - self._assertEqualsRoundtripTp( - self._tyconv.get_RankedTensorType_of_nparray(np_out)) - mem_arg0 = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_arg0))) - mem_out = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_out))) - self._engine.invoke('main', mem_out, mem_arg0) - return rt.ranked_memref_to_numpy(mem_out[0]) + def __init__(self, tyconv: TypeConverter): + self._tyconv = tyconv + self._roundtripTp = None + self._module = None + self._engine = None + + def _assertEqualsRoundtripTp(self, tp: ir.RankedTensorType): + assert self._roundtripTp is not None, "StressTest: uninitialized roundtrip type" + if tp != self._roundtripTp: + raise AssertionError( + f"Type is not equal to the roundtrip type.\n" + f"\tExpected: {self._roundtripTp}\n" + f"\tFound: {tp}\n" + ) + + def build(self, types: List[ir.Type]): + """Builds the ir.Module. The module has only the @main function, + which will convert the input through the list of types and then back + to the initial type. The roundtrip type must be a dense tensor.""" + assert self._module is None, "StressTest: must not call build() repeatedly" + self._module = ir.Module.create() + with ir.InsertionPoint(self._module.body): + tp0 = types.pop(0) + self._roundtripTp = tp0 + # TODO: assert dense? assert element type is recognised by the TypeConverter? + types.append(tp0) + funcTp = ir.FunctionType.get(inputs=[tp0], results=[tp0]) + funcOp = func.FuncOp(name="main", type=funcTp) + funcOp.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + with ir.InsertionPoint(funcOp.add_entry_block()): + arg0 = funcOp.entry_block.arguments[0] + self._assertEqualsRoundtripTp(arg0.type) + v = st.ConvertOp(types.pop(0), arg0) + for tp in types: + w = st.ConvertOp(tp, v) + # Release intermediate tensors before they fall out of scope. + bufferization.DeallocTensorOp(v.result) + v = w + self._assertEqualsRoundtripTp(v.result.type) + func.ReturnOp(v) + return self + + def writeTo(self, filename): + """Write the ir.Module to the given file. If the file already exists, + then raises an error. If the filename is None, then is a no-op.""" + assert ( + self._module is not None + ), "StressTest: must call build() before writeTo()" + if filename is None: + # Silent no-op, for convenience. 
+ return self + if os.path.exists(filename): + raise FileExistsError(errno.EEXIST, os.strerror(errno.EEXIST), filename) + with open(filename, "w") as f: + f.write(str(self._module)) + return self + + def compile(self, compiler): + """Compile the ir.Module.""" + assert ( + self._module is not None + ), "StressTest: must call build() before compile()" + assert self._engine is None, "StressTest: must not call compile() repeatedly" + self._engine = compiler.compile_and_jit(self._module) + return self + + def run(self, np_arg0: np.ndarray) -> np.ndarray: + """Runs the test on the given numpy array, and returns the resulting + numpy array.""" + assert self._engine is not None, "StressTest: must call compile() before run()" + self._assertEqualsRoundtripTp( + self._tyconv.get_RankedTensorType_of_nparray(np_arg0) + ) + np_out = np.zeros(np_arg0.shape, dtype=np_arg0.dtype) + self._assertEqualsRoundtripTp( + self._tyconv.get_RankedTensorType_of_nparray(np_out) + ) + mem_arg0 = ctypes.pointer( + ctypes.pointer(rt.get_ranked_memref_descriptor(np_arg0)) + ) + mem_out = ctypes.pointer( + ctypes.pointer(rt.get_ranked_memref_descriptor(np_out)) + ) + self._engine.invoke("main", mem_out, mem_arg0) + return rt.ranked_memref_to_numpy(mem_out[0]) + # ===----------------------------------------------------------------------=== # + def main(): - """ - USAGE: python3 test_stress.py [raw_module.mlir [compiled_module.mlir]] - - The environment variable SUPPORT_LIB must be set to point to the - libmlir_c_runner_utils shared library. There are two optional - arguments, for debugging purposes. The first argument specifies where - to write out the raw/generated ir.Module. The second argument specifies - where to write out the compiled version of that ir.Module. - """ - support_lib = os.getenv('SUPPORT_LIB') - assert support_lib is not None, 'SUPPORT_LIB is undefined' - if not os.path.exists(support_lib): - raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib) - - # CHECK-LABEL: TEST: test_stress - print("\nTEST: test_stress") - with ir.Context() as ctx, ir.Location.unknown(): - # Disable direct sparse2sparse conversion, because it doubles the time! - # TODO: While direct s2s is far too slow for per-commit testing, - # we should have some framework ensure that we run this test with - # `s2s=0` on a regular basis, to ensure that it does continue to work. - # TODO: be sure to test s2s=0 together with singletons. - s2s = 1 - sparsification_options = ( - f'parallelization-strategy=none ' - f's2s-strategy={s2s}') - compiler = sparse_compiler.SparseCompiler( - options=sparsification_options, opt_level=0, shared_libs=[support_lib]) - f64 = ir.F64Type.get() - # Be careful about increasing this because - # len(types) = 1 + len(level_choices)^rank * rank! * len(bitwidths)^2 - shape = range(2, 3) - rank = len(shape) - # All combinations. - # TODO: add singleton here too; which requires updating how `np_arg0` - # is initialized below. - levels = list(itertools.product(*itertools.repeat( - [st.DimLevelType.dense, st.DimLevelType.compressed], rank))) - # All permutations. - orderings = list(map(ir.AffineMap.get_permutation, - itertools.permutations(range(rank)))) - bitwidths = [0] - # The first type must be a dense tensor for numpy conversion to work. 
- types = [ir.RankedTensorType.get(shape, f64)] - for level in levels: - for ordering in orderings: - for pwidth in bitwidths: - for iwidth in bitwidths: - attr = st.EncodingAttr.get(level, ordering, None, pwidth, iwidth) - types.append(ir.RankedTensorType.get(shape, f64, attr)) - # - # For exhaustiveness we should have one or more StressTest, such - # that their paths cover all 2*n*(n-1) directed pairwise combinations - # of the `types` set. However, since n is already superexponential, - # such exhaustiveness would be prohibitive for a test that runs on - # every commit. So for now we'll just pick one particular path that - # at least hits all n elements of the `types` set. - # - tyconv = TypeConverter(ctx) - size = 1 - for d in shape: - size *= d - np_arg0 = np.arange(size, dtype=tyconv.irtype_to_dtype(f64)).reshape(*shape) - np_out = ( - StressTest(tyconv).build(types).writeTo( - sys.argv[1] if len(sys.argv) > 1 else None).compile(compiler) - .writeTo(sys.argv[2] if len(sys.argv) > 2 else None).run(np_arg0)) - # CHECK: Passed - if np.allclose(np_out, np_arg0): - print('Passed') - else: - sys.exit('FAILURE') - -if __name__ == '__main__': - main() + """ + USAGE: python3 test_stress.py [raw_module.mlir [compiled_module.mlir]] + + The environment variable SUPPORT_LIB must be set to point to the + libmlir_c_runner_utils shared library. There are two optional + arguments, for debugging purposes. The first argument specifies where + to write out the raw/generated ir.Module. The second argument specifies + where to write out the compiled version of that ir.Module. + """ + support_lib = os.getenv("SUPPORT_LIB") + assert support_lib is not None, "SUPPORT_LIB is undefined" + if not os.path.exists(support_lib): + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib) + + # CHECK-LABEL: TEST: test_stress + print("\nTEST: test_stress") + with ir.Context() as ctx, ir.Location.unknown(): + # Disable direct sparse2sparse conversion, because it doubles the time! + # TODO: While direct s2s is far too slow for per-commit testing, + # we should have some framework ensure that we run this test with + # `s2s=0` on a regular basis, to ensure that it does continue to work. + # TODO: be sure to test s2s=0 together with singletons. + s2s = 1 + sparsification_options = f"parallelization-strategy=none " f"s2s-strategy={s2s}" + compiler = sparse_compiler.SparseCompiler( + options=sparsification_options, opt_level=0, shared_libs=[support_lib] + ) + f64 = ir.F64Type.get() + # Be careful about increasing this because + # len(types) = 1 + len(level_choices)^rank * rank! * len(bitwidths)^2 + shape = range(2, 3) + rank = len(shape) + # All combinations. + # TODO: add singleton here too; which requires updating how `np_arg0` + # is initialized below. + levels = list( + itertools.product( + *itertools.repeat( + [st.DimLevelType.dense, st.DimLevelType.compressed], rank + ) + ) + ) + # All permutations. + orderings = list( + map(ir.AffineMap.get_permutation, itertools.permutations(range(rank))) + ) + bitwidths = [0] + # The first type must be a dense tensor for numpy conversion to work. 
+ types = [ir.RankedTensorType.get(shape, f64)] + for level in levels: + for ordering in orderings: + for pwidth in bitwidths: + for iwidth in bitwidths: + attr = st.EncodingAttr.get( + level, ordering, None, pwidth, iwidth + ) + types.append(ir.RankedTensorType.get(shape, f64, attr)) + # + # For exhaustiveness we should have one or more StressTest, such + # that their paths cover all 2*n*(n-1) directed pairwise combinations + # of the `types` set. However, since n is already superexponential, + # such exhaustiveness would be prohibitive for a test that runs on + # every commit. So for now we'll just pick one particular path that + # at least hits all n elements of the `types` set. + # + tyconv = TypeConverter(ctx) + size = 1 + for d in shape: + size *= d + np_arg0 = np.arange(size, dtype=tyconv.irtype_to_dtype(f64)).reshape(*shape) + np_out = ( + StressTest(tyconv) + .build(types) + .writeTo(sys.argv[1] if len(sys.argv) > 1 else None) + .compile(compiler) + .writeTo(sys.argv[2] if len(sys.argv) > 2 else None) + .run(np_arg0) + ) + # CHECK: Passed + if np.allclose(np_out, np_arg0): + print("Passed") + else: + sys.exit("FAILURE") + + +if __name__ == "__main__": + main() diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/tools/np_to_sparse_tensor.py b/mlir/test/Integration/Dialect/SparseTensor/python/tools/np_to_sparse_tensor.py index f5b0ab6..785d42c 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/python/tools/np_to_sparse_tensor.py +++ b/mlir/test/Integration/Dialect/SparseTensor/python/tools/np_to_sparse_tensor.py @@ -11,65 +11,71 @@ import numpy as np @functools.lru_cache() def _get_c_shared_lib(lib_name: str): - """Loads and returns the requested C shared library. + """Loads and returns the requested C shared library. - Args: - lib_name: A string representing the C shared library. + Args: + lib_name: A string representing the C shared library. - Returns: - The C shared library. + Returns: + The C shared library. - Raises: - OSError: If there is any problem in loading the shared library. - ValueError: If the shared library doesn't contain the needed routine. - """ - # This raises OSError exception if there is any problem in loading the shared - # library. - c_lib = ctypes.CDLL(lib_name) + Raises: + OSError: If there is any problem in loading the shared library. + ValueError: If the shared library doesn't contain the needed routine. + """ + # This raises OSError exception if there is any problem in loading the shared + # library. + c_lib = ctypes.CDLL(lib_name) - try: - c_lib.convertFromMLIRSparseTensorF64.restype = ctypes.c_void_p - except Exception as e: - raise ValueError('Missing function convertFromMLIRSparseTensorF64 from ' - f'the C shared library: {e} ') from e + try: + c_lib.convertFromMLIRSparseTensorF64.restype = ctypes.c_void_p + except Exception as e: + raise ValueError( + "Missing function convertFromMLIRSparseTensorF64 from " + f"the C shared library: {e} " + ) from e - return c_lib + return c_lib def sparse_tensor_to_coo_tensor(support_lib, sparse, dtype): - """Converts a sparse tensor to COO-flavored format. + """Converts a sparse tensor to COO-flavored format. - Args: - support_lib: A string for the supporting C shared library. - sparse: A ctypes.pointer to the sparse tensor descriptor. - dtype: The numpy data type for the tensor elements. + Args: + support_lib: A string for the supporting C shared library. + sparse: A ctypes.pointer to the sparse tensor descriptor. + dtype: The numpy data type for the tensor elements. 
- Returns: - A tuple that contains the following values: - rank: An integer for the rank of the tensor. - nse: An integer for the number of non-zero values in the tensor. - shape: A 1D numpy array of integers, for the shape of the tensor. - values: A 1D numpy array, for the non-zero values in the tensor. - indices: A 2D numpy array of integers, representing the indices for the - non-zero values in the tensor. + Returns: + A tuple that contains the following values: + rank: An integer for the rank of the tensor. + nse: An integer for the number of non-zero values in the tensor. + shape: A 1D numpy array of integers, for the shape of the tensor. + values: A 1D numpy array, for the non-zero values in the tensor. + indices: A 2D numpy array of integers, representing the indices for the + non-zero values in the tensor. - Raises: - OSError: If there is any problem in loading the shared library. - ValueError: If the shared library doesn't contain the needed routine. - """ - c_lib = _get_c_shared_lib(support_lib) + Raises: + OSError: If there is any problem in loading the shared library. + ValueError: If the shared library doesn't contain the needed routine. + """ + c_lib = _get_c_shared_lib(support_lib) - rank = ctypes.c_ulonglong(0) - nse = ctypes.c_ulonglong(0) - shape = ctypes.POINTER(ctypes.c_ulonglong)() - values = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))() - indices = ctypes.POINTER(ctypes.c_ulonglong)() - c_lib.convertFromMLIRSparseTensorF64(sparse, ctypes.byref(rank), - ctypes.byref(nse), ctypes.byref(shape), - ctypes.byref(values), - ctypes.byref(indices)) - # Convert the returned values to the corresponding numpy types. - shape = np.ctypeslib.as_array(shape, shape=[rank.value]) - values = np.ctypeslib.as_array(values, shape=[nse.value]) - indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value]) - return rank, nse, shape, values, indices + rank = ctypes.c_ulonglong(0) + nse = ctypes.c_ulonglong(0) + shape = ctypes.POINTER(ctypes.c_ulonglong)() + values = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))() + indices = ctypes.POINTER(ctypes.c_ulonglong)() + c_lib.convertFromMLIRSparseTensorF64( + sparse, + ctypes.byref(rank), + ctypes.byref(nse), + ctypes.byref(shape), + ctypes.byref(values), + ctypes.byref(indices), + ) + # Convert the returned values to the corresponding numpy types. 
+ shape = np.ctypeslib.as_array(shape, shape=[rank.value]) + values = np.ctypeslib.as_array(values, shape=[nse.value]) + indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value]) + return rank, nse, shape, values, indices diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/tools/sparse_compiler.py b/mlir/test/Integration/Dialect/SparseTensor/python/tools/sparse_compiler.py index 25004f9..d549a9a 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/python/tools/sparse_compiler.py +++ b/mlir/test/Integration/Dialect/SparseTensor/python/tools/sparse_compiler.py @@ -9,30 +9,31 @@ from mlir import ir from mlir import passmanager from typing import Sequence + class SparseCompiler: - """Sparse compiler class for compiling and building MLIR modules.""" - - def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]): - pipeline = f'builtin.module(sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})' - self.pipeline = pipeline - self.opt_level = opt_level - self.shared_libs = shared_libs - - def __call__(self, module: ir.Module): - """Convenience application method.""" - self.compile(module) - - def compile(self, module: ir.Module): - """Compiles the module by invoking the sparse copmiler pipeline.""" - passmanager.PassManager.parse(self.pipeline).run(module.operation) - - def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine: - """Wraps the module in a JIT execution engine.""" - return execution_engine.ExecutionEngine( - module, opt_level=self.opt_level, shared_libs=self.shared_libs) - - def compile_and_jit(self, - module: ir.Module) -> execution_engine.ExecutionEngine: - """Compiles and jits the module.""" - self.compile(module) - return self.jit(module) + """Sparse compiler class for compiling and building MLIR modules.""" + + def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]): + pipeline = f"builtin.module(sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})" + self.pipeline = pipeline + self.opt_level = opt_level + self.shared_libs = shared_libs + + def __call__(self, module: ir.Module): + """Convenience application method.""" + self.compile(module) + + def compile(self, module: ir.Module): + """Compiles the module by invoking the sparse copmiler pipeline.""" + passmanager.PassManager.parse(self.pipeline).run(module.operation) + + def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine: + """Wraps the module in a JIT execution engine.""" + return execution_engine.ExecutionEngine( + module, opt_level=self.opt_level, shared_libs=self.shared_libs + ) + + def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine: + """Compiles and jits the module.""" + self.compile(module) + return self.jit(module) diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/lit.local.cfg b/mlir/test/Integration/Dialect/SparseTensor/taco/lit.local.cfg index 7137d0f..f1bbcf4 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/taco/lit.local.cfg +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/lit.local.cfg @@ -1,5 +1,5 @@ # Disable ASAN's leak detection for python taco tests. -config.environment['ASAN_OPTIONS'] = 'detect_leaks=0' +config.environment["ASAN_OPTIONS"] = "detect_leaks=0" # Only run when python bindings are enabled. 
if not config.enable_bindings_python: - config.unsupported = True + config.unsupported = True diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_MTTKRP.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_MTTKRP.py index 88b13ae..2d558f8 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_MTTKRP.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_MTTKRP.py @@ -46,10 +46,10 @@ A[i, j] = B[i, k, l] * D[l, j] * C[k, j] # Perform the MTTKRP computation and write the result to file. with tempfile.TemporaryDirectory() as test_dir: - golden_file = os.path.join(_SCRIPT_PATH, "data/gold_A.tns") - out_file = os.path.join(test_dir, "A.tns") - pt.write(out_file, A) - # - # CHECK: Compare result True - # - print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}") + golden_file = os.path.join(_SCRIPT_PATH, "data/gold_A.tns") + out_file = os.path.join(test_dir, "A.tns") + pt.write(out_file, A) + # + # CHECK: Compare result True + # + print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}") diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py index ba4ea9c..ef94ea9 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py @@ -46,13 +46,13 @@ expected = """; extended FROSTT format # Force evaluation of the kernels by writing out X and Y. with tempfile.TemporaryDirectory() as test_dir: - x_file = os.path.join(test_dir, "X.tns") - y_file = os.path.join(test_dir, "Y.tns") - pt.write(x_file, X) - pt.write(y_file, Y) - # - # CHECK: Compare result True True - # - x_data = utils.file_as_string(x_file) - y_data = utils.file_as_string(y_file) - print(f"Compare result {x_data == expected} {y_data == expected}") + x_file = os.path.join(test_dir, "X.tns") + y_file = os.path.join(test_dir, "Y.tns") + pt.write(x_file, X) + pt.write(y_file, Y) + # + # CHECK: Compare result True True + # + x_data = utils.file_as_string(x_file) + y_data = utils.file_as_string(y_file) + print(f"Compare result {x_data == expected} {y_data == expected}") diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMM.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMM.py index 10309cb..02bbbc0 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMM.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMM.py @@ -26,10 +26,10 @@ C[i, j] = A[i, k] * B[k, j] # Force evaluation of the kernel by writing out C. 
with tempfile.TemporaryDirectory() as test_dir: - golden_file = os.path.join(_SCRIPT_PATH, "data/gold_C.tns") - out_file = os.path.join(test_dir, "C.tns") - pt.write(out_file, C) - # - # CHECK: Compare result True - # - print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}") + golden_file = os.path.join(_SCRIPT_PATH, "data/gold_C.tns") + out_file = os.path.join(test_dir, "C.tns") + pt.write(out_file, C) + # + # CHECK: Compare result True + # + print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}") diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMV.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMV.py index de150ea..2038a47 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMV.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMV.py @@ -47,10 +47,10 @@ y[i] = A[i, j] * x[j] + z[i] # Perform the SpMV computation and write the result to file with tempfile.TemporaryDirectory() as test_dir: - golden_file = os.path.join(_SCRIPT_PATH, "data/gold_y.tns") - out_file = os.path.join(test_dir, "y.tns") - pt.write(out_file, y) - # - # CHECK: Compare result True - # - print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}") + golden_file = os.path.join(_SCRIPT_PATH, "data/gold_y.tns") + out_file = os.path.join(test_dir, "y.tns") + pt.write(out_file, y) + # + # CHECK: Compare result True + # + print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}") diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_Tensor.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_Tensor.py index c1e6c87..cd24e0d 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_Tensor.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_Tensor.py @@ -18,11 +18,10 @@ i, j, k, l, m = pt.get_index_vars(5) alpha = pt.tensor(42.0) # Set up some sparse tensors with different dim annotations and ordering. -S = pt.tensor([8, 8, 8], - pt.format([pt.compressed, pt.dense, pt.compressed], [1, 0, 2])) -X = pt.tensor([8, 8, 8], - pt.format([pt.compressed, pt.compressed, pt.compressed], - [1, 0, 2])) +S = pt.tensor([8, 8, 8], pt.format([pt.compressed, pt.dense, pt.compressed], [1, 0, 2])) +X = pt.tensor( + [8, 8, 8], pt.format([pt.compressed, pt.compressed, pt.compressed], [1, 0, 2]) +) S.insert([0, 0, 0], 2.0) S.insert([1, 1, 1], 3.0) S.insert([4, 4, 4], 4.0) @@ -32,16 +31,14 @@ X[i, j, k] = alpha[0] * S[i, j, k] # Set up tensors with a dense last dimension. This results in a full # enveloping storage of all last "rows" with one or more nonzeros. -T = pt.tensor([1, 2, 3, 4, 5], - pt.format([ - pt.compressed, pt.compressed, pt.compressed, pt.compressed, - pt.dense - ])) -Y = pt.tensor([1, 2, 3, 4, 5], - pt.format([ - pt.compressed, pt.compressed, pt.compressed, pt.compressed, - pt.dense - ])) +T = pt.tensor( + [1, 2, 3, 4, 5], + pt.format([pt.compressed, pt.compressed, pt.compressed, pt.compressed, pt.dense]), +) +Y = pt.tensor( + [1, 2, 3, 4, 5], + pt.format([pt.compressed, pt.compressed, pt.compressed, pt.compressed, pt.dense]), +) T.insert([0, 1, 2, 3, 4], -2.0) Y[i, j, k, l, m] = alpha[0] * T[i, j, k, l, m] @@ -85,18 +82,18 @@ z_expected = """; extended FROSTT format # Force evaluation of the kernel by writing out X. 
with tempfile.TemporaryDirectory() as test_dir: - x_file = os.path.join(test_dir, 'X.tns') - pt.write(x_file, X) - y_file = os.path.join(test_dir, 'Y.tns') - pt.write(y_file, Y) - z_file = os.path.join(test_dir, 'Z.tns') - pt.write(z_file, Z) - # - # CHECK: Compare result True True True - # - x_data = utils.file_as_string(x_file) - y_data = utils.file_as_string(y_file) - z_data = utils.file_as_string(z_file) - print( - f'Compare result {x_data == x_expected} {y_data == y_expected} {z_data == z_expected}' - ) + x_file = os.path.join(test_dir, "X.tns") + pt.write(x_file, X) + y_file = os.path.join(test_dir, "Y.tns") + pt.write(y_file, Y) + z_file = os.path.join(test_dir, "Z.tns") + pt.write(z_file, Z) + # + # CHECK: Compare result True True True + # + x_data = utils.file_as_string(x_file) + y_data = utils.file_as_string(y_file) + z_data = utils.file_as_string(z_file) + print( + f"Compare result {x_data == x_expected} {y_data == y_expected} {z_data == z_expected}" + ) diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py index 60b91de4..206ffa9 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_scalar_tensor_algebra.py @@ -12,7 +12,7 @@ compressed = pt.compressed i, j = pt.get_index_vars(2) A = pt.tensor([2, 3]) -S = pt.tensor(3) # S is a scalar tensor. +S = pt.tensor(3) # S is a scalar tensor. B = pt.tensor([2, 3], compressed) A.insert([0, 1], 10) A.insert([1, 2], 40) @@ -26,11 +26,11 @@ passed += np.array_equal(values, [30.0, 120.0]) # Sum all the values in A. S[0] = A[i, j] -passed += (S.get_scalar_value() == 50.0) +passed += S.get_scalar_value() == 50.0 indices, values = S.get_coordinates_and_values() -passed += (len(indices)==0) -passed += (values == 50.0) +passed += len(indices) == 0 +passed += values == 50.0 # CHECK: Number of passed: 5 print("Number of passed:", passed) diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_complex.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_complex.py index 8fd545b..b0fed50 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_complex.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_complex.py @@ -12,20 +12,20 @@ compressed = pt.compressed passed = 0 all_types = [pt.complex64, pt.complex128] for t in all_types: - i, j = pt.get_index_vars(2) - A = pt.tensor([2, 3], dtype=t) - B = pt.tensor([2, 3], dtype=t) - C = pt.tensor([2, 3], compressed, dtype=t) - A.insert([0, 1], 10 + 20j) - A.insert([1, 2], 40 + 0.5j) - B.insert([0, 0], 20) - B.insert([1, 2], 30 + 15j) - C[i, j] = A[i, j] + B[i, j] + i, j = pt.get_index_vars(2) + A = pt.tensor([2, 3], dtype=t) + B = pt.tensor([2, 3], dtype=t) + C = pt.tensor([2, 3], compressed, dtype=t) + A.insert([0, 1], 10 + 20j) + A.insert([1, 2], 40 + 0.5j) + B.insert([0, 0], 20) + B.insert([1, 2], 30 + 15j) + C[i, j] = A[i, j] + B[i, j] - indices, values = C.get_coordinates_and_values() - passed += isinstance(values[0], t.value) - passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]]) - passed += np.allclose(values, [20, 10 + 20j, 70 + 15.5j]) + indices, values = C.get_coordinates_and_values() + passed += isinstance(values[0], t.value) + passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]]) + passed += np.allclose(values, [20, 10 + 20j, 70 + 15.5j]) # CHECK: Number of passed: 6 print("Number of passed:", passed) diff --git 
a/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py index cec687f..4ba2836 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_tensor_types.py @@ -12,24 +12,22 @@ compressed = pt.compressed dense = pt.dense passed = 0 -all_types = [ - pt.int8, pt.int16, pt.int32, pt.int64, pt.float16, pt.float32, pt.float64 -] +all_types = [pt.int8, pt.int16, pt.int32, pt.int64, pt.float16, pt.float32, pt.float64] for t in all_types: - i, j = pt.get_index_vars(2) - A = pt.tensor([2, 3], dtype=t) - B = pt.tensor([2, 3], dtype=t) - C = pt.tensor([2, 3], compressed, dtype=t) - A.insert([0, 1], 10) - A.insert([1, 2], 40) - B.insert([0, 0], 20) - B.insert([1, 2], 30) - C[i, j] = A[i, j] + B[i, j] + i, j = pt.get_index_vars(2) + A = pt.tensor([2, 3], dtype=t) + B = pt.tensor([2, 3], dtype=t) + C = pt.tensor([2, 3], compressed, dtype=t) + A.insert([0, 1], 10) + A.insert([1, 2], 40) + B.insert([0, 0], 20) + B.insert([1, 2], 30) + C[i, j] = A[i, j] + B[i, j] - indices, values = C.get_coordinates_and_values() - passed += isinstance(values[0], t.value) - passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]]) - passed += np.allclose(values, [20.0, 10.0, 70.0]) + indices, values = C.get_coordinates_and_values() + passed += isinstance(values[0], t.value) + passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]]) + passed += np.allclose(values, [20.0, 10.0, 70.0]) # CHECK: Number of passed: 21 print("Number of passed:", passed) diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_true_dense_tensor_algebra.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_true_dense_tensor_algebra.py index a138678..78bce34 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_true_dense_tensor_algebra.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_true_dense_tensor_algebra.py @@ -10,8 +10,8 @@ from tools import mlir_pytaco_api as pt i, j = pt.get_index_vars(2) # Both tensors are true dense tensors. -A = pt.from_array(np.full([2,3], 1, dtype=np.float64)) -B = pt.from_array(np.full([2,3], 2, dtype=np.float64)) +A = pt.from_array(np.full([2, 3], 1, dtype=np.float64)) +B = pt.from_array(np.full([2, 3], 2, dtype=np.float64)) # Define the result tensor as a true dense tensor. The parameter is_dense=True # is an MLIR-PyTACO extension. C = pt.tensor([2, 3], dtype=pt.float64, is_dense=True) diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py index 44d28b0..b3194f7 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py @@ -65,19 +65,20 @@ _SubtreeLeafChecker = Optional[Callable[..., bool]] class Type(enum.Enum): - """The data types supported by TACO. + """The data types supported by TACO. - We use numpy data types to implement the enum data types. - """ - INT8 = np.int8 - INT16 = np.int16 - INT32 = np.int32 - INT64 = np.int64 - FLOAT16 = np.float16 - FLOAT32 = np.float32 - FLOAT64 = np.float64 - COMPLEX64 = np.complex64 - COMPLEX128 = np.complex128 + We use numpy data types to implement the enum data types. 
+ """ + + INT8 = np.int8 + INT16 = np.int16 + INT32 = np.int32 + INT64 = np.int64 + FLOAT16 = np.float16 + FLOAT32 = np.float32 + FLOAT64 = np.float64 + COMPLEX64 = np.complex64 + COMPLEX128 = np.complex128 # All floating point type enums. @@ -88,1732 +89,1810 @@ _INT_TYPES = (Type.INT8, Type.INT16, Type.INT32, Type.INT64) _COMPLEX_TYPES = (Type.COMPLEX64, Type.COMPLEX128) # Type alias for any numpy type used to implement the runtime support for the # enum data types. -_AnyRuntimeType = Union[np.int8, np.int16, np.int32, np.int64, np.float16, - np.float32, np.float64, np.complex64, np.complex128] +_AnyRuntimeType = Union[ + np.int8, + np.int16, + np.int32, + np.int64, + np.float16, + np.float32, + np.float64, + np.complex64, + np.complex128, +] @dataclasses.dataclass(frozen=True) class DType: - """The data type class. + """The data type class. - We support the TACO API dtype class with an alias of this class. + We support the TACO API dtype class with an alias of this class. - The following methods are defined by the TACO API: - is_float: Returns whether the data type represents a floating point value. - is_int: Returns whether the data type represents an integral value. + The following methods are defined by the TACO API: + is_float: Returns whether the data type represents a floating point value. + is_int: Returns whether the data type represents an integral value. - Attributes: - kind: A Type enum representing the data type. - value: The numpy data type for the TACO data type. - """ - kind: Type = Type.FLOAT32 + Attributes: + kind: A Type enum representing the data type. + value: The numpy data type for the TACO data type. + """ - def is_float(self) -> bool: - """Returns whether the data type represents a floating point value.""" - return self.kind in _FLOAT_TYPES + kind: Type = Type.FLOAT32 - def is_int(self) -> bool: - """Returns whether the data type represents an integral value.""" - return self.kind in _INT_TYPES + def is_float(self) -> bool: + """Returns whether the data type represents a floating point value.""" + return self.kind in _FLOAT_TYPES - def is_complex(self) -> bool: - """Returns whether the data type represents a complex value.""" - return self.kind in _COMPLEX_TYPES + def is_int(self) -> bool: + """Returns whether the data type represents an integral value.""" + return self.kind in _INT_TYPES - @property - def value(self) -> _AnyRuntimeType: - """Returns the numpy dtype for the data type.""" - return self.kind.value + def is_complex(self) -> bool: + """Returns whether the data type represents a complex value.""" + return self.kind in _COMPLEX_TYPES + + @property + def value(self) -> _AnyRuntimeType: + """Returns the numpy dtype for the data type.""" + return self.kind.value def _dtype_to_mlir_str(dtype: DType) -> str: - """Returns the MLIR string for the given dtype.""" - dtype_to_str = { - Type.INT16: "i8", - Type.INT16: "i16", - Type.INT32: "i32", - Type.INT64: "i64", - Type.FLOAT16: "f16", - Type.FLOAT32: "f32", - Type.FLOAT64: "f64", - Type.COMPLEX64: "complex", - Type.COMPLEX128: "complex" - } - return dtype_to_str[dtype.kind] + """Returns the MLIR string for the given dtype.""" + dtype_to_str = { + Type.INT16: "i8", + Type.INT16: "i16", + Type.INT32: "i32", + Type.INT64: "i64", + Type.FLOAT16: "f16", + Type.FLOAT32: "f32", + Type.FLOAT64: "f64", + Type.COMPLEX64: "complex", + Type.COMPLEX128: "complex", + } + return dtype_to_str[dtype.kind] def _nptype_to_taco_type(ty: np.dtype) -> DType: - """Returns the TACO type for the given numpy type.""" - nptype_to_dtype 
= { - np.int8: Type.INT8, - np.int16: Type.INT16, - np.int32: Type.INT32, - np.int64: Type.INT64, - np.float16: Type.FLOAT16, - np.float32: Type.FLOAT32, - np.float64: Type.FLOAT64, - np.complex64: Type.COMPLEX64, - np.complex128: Type.COMPLEX128 - } - return DType(nptype_to_dtype[ty]) + """Returns the TACO type for the given numpy type.""" + nptype_to_dtype = { + np.int8: Type.INT8, + np.int16: Type.INT16, + np.int32: Type.INT32, + np.int64: Type.INT64, + np.float16: Type.FLOAT16, + np.float32: Type.FLOAT32, + np.float64: Type.FLOAT64, + np.complex64: Type.COMPLEX64, + np.complex128: Type.COMPLEX128, + } + return DType(nptype_to_dtype[ty]) def _mlir_type_from_taco_type(dtype: DType) -> ir.Type: - """Returns the MLIR type corresponding to the given TACO type.""" - dtype_to_irtype = { - Type.INT8: ir.IntegerType.get_signless(8), - Type.INT16: ir.IntegerType.get_signless(16), - Type.INT32: ir.IntegerType.get_signless(32), - Type.INT64: ir.IntegerType.get_signless(64), - Type.FLOAT16: ir.F16Type.get(), - Type.FLOAT32: ir.F32Type.get(), - Type.FLOAT64: ir.F64Type.get(), - Type.COMPLEX64: ir.ComplexType.get(ir.F32Type.get()), - Type.COMPLEX128: ir.ComplexType.get(ir.F64Type.get()) - } - return dtype_to_irtype[dtype.kind] + """Returns the MLIR type corresponding to the given TACO type.""" + dtype_to_irtype = { + Type.INT8: ir.IntegerType.get_signless(8), + Type.INT16: ir.IntegerType.get_signless(16), + Type.INT32: ir.IntegerType.get_signless(32), + Type.INT64: ir.IntegerType.get_signless(64), + Type.FLOAT16: ir.F16Type.get(), + Type.FLOAT32: ir.F32Type.get(), + Type.FLOAT64: ir.F64Type.get(), + Type.COMPLEX64: ir.ComplexType.get(ir.F32Type.get()), + Type.COMPLEX128: ir.ComplexType.get(ir.F64Type.get()), + } + return dtype_to_irtype[dtype.kind] + def _ctype_pointer_from_array(array: np.ndarray) -> ctypes.pointer: - """Returns the ctype pointer for the given numpy array.""" - return ctypes.pointer( - ctypes.pointer(runtime.get_ranked_memref_descriptor(array))) + """Returns the ctype pointer for the given numpy array.""" + return ctypes.pointer(ctypes.pointer(runtime.get_ranked_memref_descriptor(array))) class ModeFormat(enum.Enum): - """The tensor dimension storage format class. + """The tensor dimension storage format class. - We support the TACO API mode_format class with an alias of this class. + We support the TACO API mode_format class with an alias of this class. + + In TACO, a tensor dimension is called a mode and the storage format for a + tensor dimension is called a mode format. + """ - In TACO, a tensor dimension is called a mode and the storage format for a - tensor dimension is called a mode format. - """ - DENSE = sparse_tensor.DimLevelType.dense - COMPRESSED = sparse_tensor.DimLevelType.compressed + DENSE = sparse_tensor.DimLevelType.dense + COMPRESSED = sparse_tensor.DimLevelType.compressed -def _mode_format_operation(a: ModeFormat, b: ModeFormat, - op: _LogicalOp) -> ModeFormat: - """Implements the given operator on ModeFormat.""" - return (ModeFormat.COMPRESSED - if op(a == ModeFormat.COMPRESSED, b == ModeFormat.COMPRESSED) else - ModeFormat.DENSE) +def _mode_format_operation(a: ModeFormat, b: ModeFormat, op: _LogicalOp) -> ModeFormat: + """Implements the given operator on ModeFormat.""" + return ( + ModeFormat.COMPRESSED + if op(a == ModeFormat.COMPRESSED, b == ModeFormat.COMPRESSED) + else ModeFormat.DENSE + ) def _mode_format_estimator(op: _BinaryOp) -> _ModeFormatOp: - """Produces a ModeFormat operator for the given binary operator. 
+ """Produces a ModeFormat operator for the given binary operator. - The ModeFormat operator is used as a heuristic to derive the destination - dimension sparsity from the source dimension sparsity. In particular, if the - binary operator produces a disjunction of the zero values from its source - operands, such as the MUL operator, we return a ModeFormat operator that - uses operator.or_. That is, we estimate that a dimension for the MUL - operation result to be sparse if either of its source operands is sparse. + The ModeFormat operator is used as a heuristic to derive the destination + dimension sparsity from the source dimension sparsity. In particular, if the + binary operator produces a disjunction of the zero values from its source + operands, such as the MUL operator, we return a ModeFormat operator that + uses operator.or_. That is, we estimate that a dimension for the MUL + operation result to be sparse if either of its source operands is sparse. - On the other hand, if the binary operator produces a conjunction of the - zero values from its source operands, such as the ADD operator, we return - a ModeFormat operator that uses operator.and_. In this case, we estimate - that a dimension for the ADD operation result to be sparse if both of its - source operands are sparse. + On the other hand, if the binary operator produces a conjunction of the + zero values from its source operands, such as the ADD operator, we return + a ModeFormat operator that uses operator.and_. In this case, we estimate + that a dimension for the ADD operation result to be sparse if both of its + source operands are sparse. - Args: - op: A _BinaryOp object representing a supporting operator on tensors. + Args: + op: A _BinaryOp object representing a supporting operator on tensors. - Returns: - A ModeFormatOp for estimating the destination dimension sparsity from - the source dimension sparsity. - """ - conjunction = functools.partial(_mode_format_operation, op=operator.and_) - disjunction = functools.partial(_mode_format_operation, op=operator.or_) - return conjunction if op(0, 1) != 0 else disjunction + Returns: + A ModeFormatOp for estimating the destination dimension sparsity from + the source dimension sparsity. + """ + conjunction = functools.partial(_mode_format_operation, op=operator.and_) + disjunction = functools.partial(_mode_format_operation, op=operator.or_) + return conjunction if op(0, 1) != 0 else disjunction def _all_instance_of(collection: Iterable, cls: Any) -> bool: - """Returns true if all elements of the iterable is an instance of cls.""" - return all(isinstance(e, cls) for e in collection) + """Returns true if all elements of the iterable is an instance of cls.""" + return all(isinstance(e, cls) for e in collection) def _identity_ordering(rank: int) -> List[int]: - """Returns the identity ordering for tensor of given rank.""" - return list(range(rank)) + """Returns the identity ordering for tensor of given rank.""" + return list(range(rank)) @dataclasses.dataclass(frozen=True) class ModeOrdering: - """The tensor dimension ordering class. - - We support the TACO API mode_ordering class with an alias of this class. + """The tensor dimension ordering class. - Attributes: - ordering: A list of integers representing the ordering of the tensor - dimensions. - """ - ordering: List[int] + We support the TACO API mode_ordering class with an alias of this class. - def __post_init__(self) -> None: - """Verifies the value in ordering. - - Raises: - ValueError: If ordering is not a list of integers. 
+ Attributes: + ordering: A list of integers representing the ordering of the tensor + dimensions. """ - if (not isinstance(self.ordering, list) or - not _all_instance_of(self.ordering, int)): - raise ValueError("Ordering must be a list of integers: " - f"{self.ordering}") - # Check that ordering is a permutation of the dimension numbers. - if sorted(self.ordering) != _identity_ordering(self.rank()): - raise ValueError(f"Invalid ordering: {self.ordering} != " - f"permutation{_identity_ordering(self.rank())}.") - - def rank(self) -> int: - """Returns the number of dimensions represented by the ordering.""" - return len(self.ordering) + ordering: List[int] -@dataclasses.dataclass(frozen=True) -class ModeFormatPack: - """The tensor dimension format class. + def __post_init__(self) -> None: + """Verifies the value in ordering. - We support the TACO API mode_format_pack class with an alias of this class. + Raises: + ValueError: If ordering is not a list of integers. + """ + if not isinstance(self.ordering, list) or not _all_instance_of( + self.ordering, int + ): + raise ValueError("Ordering must be a list of integers: " f"{self.ordering}") + # Check that ordering is a permutation of the dimension numbers. + if sorted(self.ordering) != _identity_ordering(self.rank()): + raise ValueError( + f"Invalid ordering: {self.ordering} != " + f"permutation{_identity_ordering(self.rank())}." + ) - The storage format of a tensor contains one mode_format for each tensor - dimension. + def rank(self) -> int: + """Returns the number of dimensions represented by the ordering.""" + return len(self.ordering) - Attributes: - formats: A list of ModeFormat representing the storage format for each of - the tensor dimension. - """ - formats: List[ModeFormat] - def __post_init__(self) -> None: - """Verifies the value in formats. - - Raises: - ValueError: If formats is not a list of ModeFormats. - """ - if (not isinstance(self.formats, list) or - not _all_instance_of(self.formats, ModeFormat)): - raise ValueError("Formats must be a list of ModeFormat: " - f"{self.formats}") - - def rank(self) -> int: - """Returns the number of dimensions represented by the format pack.""" - return len(self.formats) +@dataclasses.dataclass(frozen=True) +class ModeFormatPack: + """The tensor dimension format class. + We support the TACO API mode_format_pack class with an alias of this class. -@dataclasses.dataclass -class Format: - """The tensor format class defined by the TACO API. - - Attributes: - format_pack: A ModeFormatPack representing the storage format for the tensor - dimensions. - ordering: A ModeOrdering representing the tensor dimension ordering in the - storage. - """ - format_pack: ModeFormatPack - ordering: Optional[ModeOrdering] = None - - def __post_init__(self) -> None: - """Verifies and fixes up the values in format_pack and ordering. - - Verifies and fixes up the values in format_pack and ordering to supports the - initializer syntax defined by the TACO API. If format_pack is a list of - ModeFormat, replaces it with ModeFormatPack constructed from the list. If - ordering is not provided, set ordering to the natural ordering for the rank - corresponding to format_pack. + The storage format of a tensor contains one mode_format for each tensor + dimension. - Raises: - ValueError: If format_pack is not an instance of ModeFormatPack or if - ordering is not an instance of ModeOrdering. + Attributes: + formats: A list of ModeFormat representing the storage format for each of + the tensor dimension. 
""" - if isinstance(self.format_pack, list): - if not _all_instance_of(self.format_pack, ModeFormat): - raise ValueError(f"Expected a list of ModeFormat: {self.format_pack}") - self.format_pack = ModeFormatPack(self.format_pack) - if not isinstance(self.format_pack, ModeFormatPack): - raise ValueError(f"Expected ModeFormatpack: {self.format_pack}") - - if self.ordering is None: - self.ordering = ModeOrdering(list(range(self.rank()))) - if isinstance(self.ordering, list): - if not _all_instance_of(self.ordering, int): - raise ValueError(f"Expected a list of integer: {self.ordering}") - self.ordering = ModeOrdering(self.ordering) - if not isinstance(self.ordering, ModeOrdering): - raise ValueError(f"Expected ModeOrdering: {self.ordering}") - - if self.format_pack.rank() != self.ordering.rank(): - raise ValueError("Inconsistent ModeFormatPack and ModeOrdering: " - f"len({self.format_pack}) != " - f"len({self.ordering})") - - def rank(self) -> int: - """Returns the number of dimensions represented by the format.""" - return self.format_pack.rank() - - def get_permutation_and_sparsity(self) -> Tuple[np.ndarray, np.ndarray]: - """Constructs the numpy arrays for the permutation and sparsity.""" - perm = np.array(self.ordering.ordering, dtype=np.ulonglong) - a = [f.value for f in self.format_pack.formats] - sparse = np.array(a, dtype=np.uint8) - return (perm, sparse) - - def mlir_tensor_attr(self) -> Optional[sparse_tensor.EncodingAttr]: - """Constructs the MLIR attributes for the tensor format.""" - order = ( - range(self.rank()) if - (self.ordering is None) else self.ordering.ordering) - mlir_storage_format = [f.value for f in self.format_pack.formats] - return sparse_tensor.EncodingAttr.get(mlir_storage_format, - ir.AffineMap.get_permutation(order), - None, _POS_WIDTH, _CRD_WIDTH) - - -def _make_format(formats: List[ModeFormat], - ordering: Optional[List[int]] = None) -> Format: - """Constructs a format from a list of ModeFormat and an optional ordering. - - Args: - formats: A list of ModeFormat, one for each dimension of a tensor. - ordering: An optional list of integer, for the ordering of the tensor - dimensions. When an ordering is not given, the identity ordering is used. - - Returns: - A tensor format object. - - Raises: - ValueError: If formats is not a list of ModeFormat or the length of formats - is not consistent with the len of ordering. - """ - ordering = ordering or _identity_ordering(len(formats)) - return Format(ModeFormatPack(formats), ModeOrdering(ordering)) + formats: List[ModeFormat] -class IndexExpr(abc.ABC): - """The index notation base class. + def __post_init__(self) -> None: + """Verifies the value in formats. - We support the TACO API index_expression class with an alias of this class. - """ + Raises: + ValueError: If formats is not a list of ModeFormats. + """ + if not isinstance(self.formats, list) or not _all_instance_of( + self.formats, ModeFormat + ): + raise ValueError("Formats must be a list of ModeFormat: " f"{self.formats}") - def _verify_operand_and_build_expr(self, rhs, op: _BinaryOp) -> "_BinaryExpr": - """Verifies the RHS operand and returns a binary expression. + def rank(self) -> int: + """Returns the number of dimensions represented by the format pack.""" + return len(self.formats) - Args: - rhs: The RHS of the binary operation, which could be any Python object - from user inputs. - op: A _BinaryOp object representing the binary operator. - Raises: - ValueError: If rhs is not an IndexExpr. 
- """ - if not isinstance(rhs, IndexExpr): - raise ValueError(f"Expected IndexExpr: {rhs}") - return _BinaryExpr(op, self, rhs) - - def _build_unary_expr(self, op: _UnaryOp) -> "_UnaryExpr": - """Build a unary expression. +@dataclasses.dataclass +class Format: + """The tensor format class defined by the TACO API. - Args: - op: A _UnaryOp object representing the unary operation. + Attributes: + format_pack: A ModeFormatPack representing the storage format for the tensor + dimensions. + ordering: A ModeOrdering representing the tensor dimension ordering in the + storage. """ - return _UnaryExpr(op, self) - def __add__(self, rhs) -> "_BinaryExpr": - """Defines the operator +. + format_pack: ModeFormatPack + ordering: Optional[ModeOrdering] = None + + def __post_init__(self) -> None: + """Verifies and fixes up the values in format_pack and ordering. + + Verifies and fixes up the values in format_pack and ordering to supports the + initializer syntax defined by the TACO API. If format_pack is a list of + ModeFormat, replaces it with ModeFormatPack constructed from the list. If + ordering is not provided, set ordering to the natural ordering for the rank + corresponding to format_pack. + + Raises: + ValueError: If format_pack is not an instance of ModeFormatPack or if + ordering is not an instance of ModeOrdering. + """ + if isinstance(self.format_pack, list): + if not _all_instance_of(self.format_pack, ModeFormat): + raise ValueError(f"Expected a list of ModeFormat: {self.format_pack}") + self.format_pack = ModeFormatPack(self.format_pack) + if not isinstance(self.format_pack, ModeFormatPack): + raise ValueError(f"Expected ModeFormatpack: {self.format_pack}") + + if self.ordering is None: + self.ordering = ModeOrdering(list(range(self.rank()))) + if isinstance(self.ordering, list): + if not _all_instance_of(self.ordering, int): + raise ValueError(f"Expected a list of integer: {self.ordering}") + self.ordering = ModeOrdering(self.ordering) + if not isinstance(self.ordering, ModeOrdering): + raise ValueError(f"Expected ModeOrdering: {self.ordering}") + + if self.format_pack.rank() != self.ordering.rank(): + raise ValueError( + "Inconsistent ModeFormatPack and ModeOrdering: " + f"len({self.format_pack}) != " + f"len({self.ordering})" + ) + + def rank(self) -> int: + """Returns the number of dimensions represented by the format.""" + return self.format_pack.rank() + + def get_permutation_and_sparsity(self) -> Tuple[np.ndarray, np.ndarray]: + """Constructs the numpy arrays for the permutation and sparsity.""" + perm = np.array(self.ordering.ordering, dtype=np.ulonglong) + a = [f.value for f in self.format_pack.formats] + sparse = np.array(a, dtype=np.uint8) + return (perm, sparse) + + def mlir_tensor_attr(self) -> Optional[sparse_tensor.EncodingAttr]: + """Constructs the MLIR attributes for the tensor format.""" + order = ( + range(self.rank()) if (self.ordering is None) else self.ordering.ordering + ) + mlir_storage_format = [f.value for f in self.format_pack.formats] + return sparse_tensor.EncodingAttr.get( + mlir_storage_format, + ir.AffineMap.get_permutation(order), + None, + _POS_WIDTH, + _CRD_WIDTH, + ) + + +def _make_format( + formats: List[ModeFormat], ordering: Optional[List[int]] = None +) -> Format: + """Constructs a format from a list of ModeFormat and an optional ordering. Args: - rhs: The value being added, which could be any Python object from user - inputs. + formats: A list of ModeFormat, one for each dimension of a tensor. 
+ ordering: An optional list of integer, for the ordering of the tensor + dimensions. When an ordering is not given, the identity ordering is used. Returns: - A _BinaryExpr object representing the operation. + A tensor format object. Raises: - ValueError: If rhs is not an IndexExpr. + ValueError: If formats is not a list of ModeFormat or the length of formats + is not consistent with the len of ordering. """ - return self._verify_operand_and_build_expr(rhs, operator.add) + ordering = ordering or _identity_ordering(len(formats)) + return Format(ModeFormatPack(formats), ModeOrdering(ordering)) - def __mul__(self, rhs) -> "_BinaryExpr": - """Defines the operator *. - - Args: - rhs: The value being multiplied, which could be any Python object from - user inputs. - Returns: - A _BinaryExpr object representing the operation. +class IndexExpr(abc.ABC): + """The index notation base class. - Raises: - ValueError: If rhs is not an IndexExpr. + We support the TACO API index_expression class with an alias of this class. """ - return self._verify_operand_and_build_expr(rhs, operator.mul) - def __abs__(self) -> "_UnaryExpr": - """Defines the operator abs. - - Returns: - A _UnaryExpr object representing the operation. - """ - return self._build_unary_expr(operator.abs) + def _verify_operand_and_build_expr(self, rhs, op: _BinaryOp) -> "_BinaryExpr": + """Verifies the RHS operand and returns a binary expression. + + Args: + rhs: The RHS of the binary operation, which could be any Python object + from user inputs. + op: A _BinaryOp object representing the binary operator. + + Raises: + ValueError: If rhs is not an IndexExpr. + """ + if not isinstance(rhs, IndexExpr): + raise ValueError(f"Expected IndexExpr: {rhs}") + return _BinaryExpr(op, self, rhs) + + def _build_unary_expr(self, op: _UnaryOp) -> "_UnaryExpr": + """Build a unary expression. + + Args: + op: A _UnaryOp object representing the unary operation. + """ + return _UnaryExpr(op, self) + + def __add__(self, rhs) -> "_BinaryExpr": + """Defines the operator +. + + Args: + rhs: The value being added, which could be any Python object from user + inputs. + + Returns: + A _BinaryExpr object representing the operation. + + Raises: + ValueError: If rhs is not an IndexExpr. + """ + return self._verify_operand_and_build_expr(rhs, operator.add) + + def __mul__(self, rhs) -> "_BinaryExpr": + """Defines the operator *. + + Args: + rhs: The value being multiplied, which could be any Python object from + user inputs. + + Returns: + A _BinaryExpr object representing the operation. + + Raises: + ValueError: If rhs is not an IndexExpr. + """ + return self._verify_operand_and_build_expr(rhs, operator.mul) + + def __abs__(self) -> "_UnaryExpr": + """Defines the operator abs. + + Returns: + A _UnaryExpr object representing the operation. + """ + return self._build_unary_expr(operator.abs) + + def __neg__(self) -> "_UnaryExpr": + """Defines the operator neg. + + Returns: + A _UnaryExpr object representing the operation. + """ + return self._build_unary_expr(operator.neg) + + def __sub__(self, rhs) -> "_BinaryExpr": + """Defines the operator -. + + Args: + rhs: The value being subtracted, which could be any Python object from + user inputs. + + Returns: + A _BinaryExpr object representing the operation. + + Raises: + ValueError: If rhs is not an IndexExpr. 
+ """ + return self._verify_operand_and_build_expr(rhs, operator.sub) + + @abc.abstractmethod + def _visit( + self, func: _ExprVisitor, args, *, leaf_checker: _SubtreeLeafChecker = None + ) -> None: + """A post-order visitor. + + Args: + func: A callable applied to each node in the expression tree. + args: The variable-length arguments passed to the callable. These + arguments are grouped as an iterable and will be unpacked before passing + to the callable. This is to enable the keyword argument only syntax + after this argument. + leaf_checker: A callable object to identify nodes that should be treated + as leaf nodes to support partial tree visiting. + """ + pass + + @abc.abstractmethod + def _emit_expression( + self, + expr_to_opnd: Dict["IndexExpr", lang.OperandDef], + expr_to_info: _ExprInfoDict, + ) -> lang.ScalarExpression: + """Emits MLIR for the expression tree. + + Args: + expr_to_opnd: A dictionary for looking up structured op input operands for + the input nodes of the structured op. + expr_to_info: A dictionary for looking up code generation information for + expressions. + + Returns: + A linalg dialect ScalarExpression for the expression. + """ + pass + + @abc.abstractmethod + def dtype(self) -> DType: + """Returns the data type for the result of the expression.""" + pass + + def _emit_structured_op(self, expr_to_info: _ExprInfoDict) -> None: + """Emits a structured op in the linalg dialect for the expression tree. + + We define a DefineOpcallable in the domain specific language for the linalg + dialect and execute the callable to generate the structured op. Self is the + root of the expression tree for the structured op. + + Args: + expr_to_info: A dictionary for looking up code generation information for + expressions. + """ + op_info = expr_to_info[self].structop_info + op_name = op_info.dst_name + op_def = lang.LinalgOpDef(name=op_name) + op_callable = lang.DefinedOpCallable(op_name, op_def) + + # Collect the input expression nodes for the structured op. + expr_inputs = [] + self._visit( + _gather_structured_op_input, + (self, expr_to_info, expr_inputs), + leaf_checker=_is_structured_op_leaf, + ) + + # Create a linalg structured op operand for each input expression node and + # build a dictionary for looking up the information. + expr_to_input_opnd = { + e: _emit_structured_op_input(e, expr_to_info, op_def) for e in expr_inputs + } + + # Emit the expression tree, which produces the value assigned to the + # destination tensor. + value = self._emit_expression(expr_to_input_opnd, expr_to_info) + # Emit the structured op representation for the destination tensor. + dst_opnd = _emit_operand( + op_def, + op_info.dst_indices, + op_info.dst_name, + lang.OperandKind.OUTPUT_TENSOR, + ) + dst_dim_syms = _mlir_dimensions_from_index_vars(op_info.dst_indices) + dst_use = lang.TensorUse(dst_opnd, dst_dim_syms) + + expr_info = expr_to_info[self] + # If the structured op reduces some indices, explicitly represent the + # reduction. This is done by generating a ReduceFn for the dimensions being + # reduced in the linalg dialect and calling the function with the value + # being reduced. We only support add reduction currently. + if expr_info.reduce_indices: + reduce_dims = _mlir_dimensions_from_index_vars(expr_info.reduce_indices) + value = lang.ReduceFn.add[reduce_dims](value) + + # Emit the assignment as a comprehension in the linalg dialect. 
+ comp = lang.Comprehension((dst_use, value)) + op_def.comprehensions.append(comp) + + # The structured op in the linalg dialect requires an explicit + # initialization for the destination tensor. Emit MLIR to initialize the + # destination tensor. + init = op_info.emit_tensor_init() + + # Collect MLIR values for the linalg input operands, with the assumption + # that dictionary preserves the insertion order. + args = [ + expr_to_info[expr].mlir_value for expr, opnd in expr_to_input_opnd.items() + ] + # Execute the DefineOpcallable object for the linalg dialect operation to + # emit MLIR for the linalg structured op. + expr_info.mlir_value = op_callable(*args, outs=[init]) + + def _identify_structured_ops( + self, + expr_to_info: _ExprInfoDict, + dst: "Tensor", + dst_indices: Tuple["IndexVar", ...], + ) -> List["IndexExpr"]: + """Returns expression nodes for the roots of the identified structured ops. + + A structured op in the linalg dialect only supports reduction performed on + the whole expression. If the expression tree contains reduction that are + performed on part of the expression tree, the expression tree needs to be + implemented with multiple structured ops. This routine identifies all the + expression nodes that contain reduction as the root of structured ops in the + linalg dialect. + + Args: + expr_to_info: A dictionary for looking up code generation information for + expressions. + dst: A destination Tensor that accepts the value of the expression tree. + dst_indices: The indices used by the destination index expression. + + Returns: + An ordered list of IndexExpr for the root expressions of the structured + ops, where child expressions go before parent expressions that use their + results. + """ + reduce_indices = tuple(set(expr_to_info[self].src_indices) - set(dst_indices)) + for reduce_index in reduce_indices: + _mark_structured_op_root(self, reduce_index, expr_to_info) + + self._visit(_accumulate_reduce_indices, (expr_to_info,)) + structop_roots = [] + self._visit(_gather_structured_op, (expr_to_info, structop_roots)) + + # Handle the root of the top level expression. + if not structop_roots or structop_roots[-1] != self: + # The top level expression is not a reduction. Add the top level + # expression as a structured op root. + structop_roots.append(self) + + # Use user specified information for the destination tensor to build an + # _StructOpInfo for the top level expression. + expr_to_info[self].structop_info = _StructOpInfo( + dst_indices, tuple(dst.shape), dst.dtype, dst.name, dst.format + ) + + return structop_roots + + def _validate_and_collect_expr_info( + self, + dst: "Tensor", + dst_indices: Tuple["IndexVar", ...], + ) -> _ExprInfoDict: + """Propagates expression information for validation. + + Propagates the indices used by child expression nodes to parent expression + nodes. Also collects and validates the sizes for the dimensions + corresponding to the indices. + + Args: + dst: A destination Tensor that accepts the value of the expression tree. + dst_indices: The indices used by the destination index expression. + + Raises: + ValueError if there is any inconsistency in indices or dimensional + values. + + Returns: + A dictionary of (IndexExpr, _ExprInfo). + """ + expr_to_info = {} + # Validate the expression tree and construct expression information. + self._visit(_validate_and_collect_expr_info, (expr_to_info,)) + + # Validate the destination dimension information. 
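# Worked illustration of the reduction handling above (hypothetical names, not
# part of the reformatted file): for an assignment y[i] = A[i, j] * x[j], the
# destination uses only i, so j is a reduce index. _identify_structured_ops
# marks the multiply as a structured-op root, and _emit_structured_op wraps its
# emitted value in lang.ReduceFn.add over j's dimension; add is currently the
# only supported reduction.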
+ info = expr_to_info[self] + index_to_dim_info = {i: d for i, d in zip(info.src_indices, info.dim_infos)} + for ( + i, + d, + ) in zip(dst_indices, dst.shape): + if i not in index_to_dim_info: + raise ValueError( + "Destination IndexVar not used in the " f"source expression: {i}" + ) + else: + if d != index_to_dim_info[i].dim and index_to_dim_info[i].dim != -1: + raise ValueError( + f"Inconsistent destination dimension for {i}: " + f"{d} vs {index_to_dim_info[i].dim}" + ) + + return expr_to_info + + def _emit_assignment( + self, + module: ir.Module, + dst: "Tensor", + dst_indices: Tuple["IndexVar", ...], + expr_to_info: _ExprInfoDict, + input_accesses: List["Access"], + ) -> None: + """Emits an MLIR function for assigning the expression to a tensor.""" + input_types = [a.tensor.mlir_tensor_type() for a in input_accesses] + + # Build the kernel for the operations. + with ir.InsertionPoint(module.body): + + @func.FuncOp.from_py_func(*input_types, name=_ENTRY_NAME) + def linalg_funcop(*args): + # Set up the mapping from the Access nodes to their MLIR values. + for e, mlir in zip(input_accesses, args): + expr_to_info[e].mlir_value = mlir + + # Emit structured ops in the linalg dialect to implement the assignment. + for structop_root in self._identify_structured_ops( + expr_to_info, dst, dst_indices + ): + structop_root._emit_structured_op(expr_to_info) + dst._record_stats(expr_to_info[structop_root].structop_info) + + # The function returns the MLIR value of the root expression. + return expr_to_info[self].mlir_value + + linalg_funcop.func_op.attributes[ + "llvm.emit_c_interface" + ] = ir.UnitAttr.get() + + def get_input_accesses(self) -> List["Access"]: + """Compute the list of input accesses for the expression.""" + input_accesses = [] + self._visit(_gather_input_accesses_index_vars, (input_accesses,)) + return input_accesses + + def compile( + self, + dst: "Tensor", + dst_indices: Tuple["IndexVar", ...], + ) -> execution_engine.ExecutionEngine: + """Compiles the tensor assignment dst[dst_indices] = expression. + + Args: + dst: The destination tensor. + dst_indices: The tuple of IndexVar used to access the destination tensor. + + Returns: + The execution engine for the tensor assignment. + + Raises: + ValueError: If the expression is not proper or not supported. + """ + expr_to_info = self._validate_and_collect_expr_info(dst, dst_indices) + input_accesses = self.get_input_accesses() + + # Build and compile the module to produce the execution engine. + with ir.Context(), ir.Location.unknown(): + module = ir.Module.create() + self._emit_assignment( + module, dst, dst_indices, expr_to_info, input_accesses + ) + engine = utils.compile_and_build_engine(module) + + return engine - def __neg__(self) -> "_UnaryExpr": - """Defines the operator neg. - Returns: - A _UnaryExpr object representing the operation. - """ - return self._build_unary_expr(operator.neg) +class _AtomicCounter: + """An atomic counter.""" - def __sub__(self, rhs) -> "_BinaryExpr": - """Defines the operator -. + def __init__(self): + self._counter = 0 + self._counter_lock = threading.Lock() - Args: - rhs: The value being subtracted, which could be any Python object from - user inputs. + def increment(self) -> int: + """Increments the counter by one and returns the old value.""" + old_value = self._counter + with self._counter_lock: + self._counter = self._counter + 1 + return old_value - Returns: - A _BinaryExpr object representing the operation. - Raises: - ValueError: If rhs is not an IndexExpr. 
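# Note on _AtomicCounter.increment above: the old value is read before the lock
# is acquired, so two racing callers can observe the same value. A sketch of a
# variant that performs the read and the update under the lock (an assumption
# about the intended semantics, not part of the reformatted file):
import threading


class _FullyLockedCounter:
    """Counter whose read-increment happens entirely under the lock."""

    def __init__(self):
        self._counter = 0
        self._lock = threading.Lock()

    def increment(self) -> int:
        """Increments the counter by one and returns the old value."""
        with self._lock:
            old_value = self._counter
            self._counter += 1
        return old_value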
- """ - return self._verify_operand_and_build_expr(rhs, operator.sub) - - @abc.abstractmethod - def _visit(self, - func: _ExprVisitor, - args, - *, - leaf_checker: _SubtreeLeafChecker = None) -> None: - """A post-order visitor. - - Args: - func: A callable applied to each node in the expression tree. - args: The variable-length arguments passed to the callable. These - arguments are grouped as an iterable and will be unpacked before passing - to the callable. This is to enable the keyword argument only syntax - after this argument. - leaf_checker: A callable object to identify nodes that should be treated - as leaf nodes to support partial tree visiting. - """ - pass +class IndexVar(IndexExpr): + """The tensor index class. - @abc.abstractmethod - def _emit_expression( - self, - expr_to_opnd: Dict["IndexExpr", lang.OperandDef], - expr_to_info: _ExprInfoDict, - ) -> lang.ScalarExpression: - """Emits MLIR for the expression tree. + We support the TACO API index_var class with an alias of this class. - Args: - expr_to_opnd: A dictionary for looking up structured op input operands for - the input nodes of the structured op. - expr_to_info: A dictionary for looking up code generation information for - expressions. + An IndexVar object represents an index variable in tensor index notation. - Returns: - A linalg dialect ScalarExpression for the expression. + Attributes: + name: A unique string name of the IndexVar. """ - pass - @abc.abstractmethod - def dtype(self) -> DType: - """Returns the data type for the result of the expression.""" - pass + _counter = _AtomicCounter() - def _emit_structured_op(self, expr_to_info: _ExprInfoDict) -> None: - """Emits a structured op in the linalg dialect for the expression tree. + def __init__(self): + id = self._counter.increment() + self._name = f"{_TACO_INDEX_PREFIX}{id}" - We define a DefineOpcallable in the domain specific language for the linalg - dialect and execute the callable to generate the structured op. Self is the - root of the expression tree for the structured op. + def __repr__(self) -> str: + return f"IndexVar(name={repr(self._name)})" - Args: - expr_to_info: A dictionary for looking up code generation information for - expressions. - """ - op_info = expr_to_info[self].structop_info - op_name = op_info.dst_name - op_def = lang.LinalgOpDef(name=op_name) - op_callable = lang.DefinedOpCallable(op_name, op_def) - - # Collect the input expression nodes for the structured op. - expr_inputs = [] - self._visit( - _gather_structured_op_input, - (self, expr_to_info, expr_inputs), - leaf_checker=_is_structured_op_leaf, - ) + @property + def name(self) -> str: + """Returns the name of the IndexVar.""" + return self._name - # Create a linalg structured op operand for each input expression node and - # build a dictionary for looking up the information. - expr_to_input_opnd = { - e: _emit_structured_op_input(e, expr_to_info, op_def) - for e in expr_inputs - } + def _visit( + self, func: _ExprVisitor, args, *, leaf_checker: _SubtreeLeafChecker = None + ) -> None: + """A post-order visitor.""" + if leaf_checker: + assert leaf_checker(self, *args) + func(self, *args) - # Emit the expression tree, which produces the value assigned to the - # destination tensor. - value = self._emit_expression(expr_to_input_opnd, expr_to_info) - # Emit the structured op representation for the destination tensor. 
- dst_opnd = _emit_operand(op_def, op_info.dst_indices, op_info.dst_name, - lang.OperandKind.OUTPUT_TENSOR) - dst_dim_syms = _mlir_dimensions_from_index_vars(op_info.dst_indices) - dst_use = lang.TensorUse(dst_opnd, dst_dim_syms) - - expr_info = expr_to_info[self] - # If the structured op reduces some indices, explicitly represent the - # reduction. This is done by generating a ReduceFn for the dimensions being - # reduced in the linalg dialect and calling the function with the value - # being reduced. We only support add reduction currently. - if expr_info.reduce_indices: - reduce_dims = _mlir_dimensions_from_index_vars(expr_info.reduce_indices) - value = lang.ReduceFn.add[reduce_dims](value) - - # Emit the assignment as a comprehension in the linalg dialect. - comp = lang.Comprehension((dst_use, value)) - op_def.comprehensions.append(comp) - - # The structured op in the linalg dialect requires an explicit - # initialization for the destination tensor. Emit MLIR to initialize the - # destination tensor. - init = op_info.emit_tensor_init() - - # Collect MLIR values for the linalg input operands, with the assumption - # that dictionary preserves the insertion order. - args = [ - expr_to_info[expr].mlir_value - for expr, opnd in expr_to_input_opnd.items() - ] - # Execute the DefineOpcallable object for the linalg dialect operation to - # emit MLIR for the linalg structured op. - expr_info.mlir_value = op_callable(*args, outs=[init]) - - def _identify_structured_ops( - self, - expr_to_info: _ExprInfoDict, - dst: "Tensor", - dst_indices: Tuple["IndexVar", ...], - ) -> List["IndexExpr"]: - """Returns expression nodes for the roots of the identified structured ops. - - A structured op in the linalg dialect only supports reduction performed on - the whole expression. If the expression tree contains reduction that are - performed on part of the expression tree, the expression tree needs to be - implemented with multiple structured ops. This routine identifies all the - expression nodes that contain reduction as the root of structured ops in the - linalg dialect. + def _emit_expression( + self, + expr_to_opnd: Dict[IndexExpr, lang.OperandDef], + expr_to_info: _ExprInfoDict, + ) -> lang.ScalarExpression: + """Emits a index value casted to the data type of the tensor expression.""" + dim = getattr(lang.D, self.name) + index = lang.index(dim) + int_value = lang.TypeFn.cast_unsigned(lang.TV.I64, index) + return lang.TypeFn.cast_unsigned(lang.T, int_value) - Args: - expr_to_info: A dictionary for looking up code generation information for - expressions. - dst: A destination Tensor that accepts the value of the expression tree. - dst_indices: The indices used by the destination index expression. + def dtype(self) -> DType: + """Returns the data type for the index value. - Returns: - An ordered list of IndexExpr for the root expressions of the structured - ops, where child expressions go before parent expressions that use their - results. - """ - reduce_indices = tuple( - set(expr_to_info[self].src_indices) - set(dst_indices)) - for reduce_index in reduce_indices: - _mark_structured_op_root(self, reduce_index, expr_to_info) - - self._visit(_accumulate_reduce_indices, (expr_to_info,)) - structop_roots = [] - self._visit(_gather_structured_op, (expr_to_info, structop_roots)) - - # Handle the root of the top level expression. - if not structop_roots or structop_roots[-1] != self: - # The top level expression is not a reduction. Add the top level - # expression as a structured op root. 
- structop_roots.append(self) - - # Use user specified information for the destination tensor to build an - # _StructOpInfo for the top level expression. - expr_to_info[self].structop_info = _StructOpInfo(dst_indices, - tuple(dst.shape), - dst.dtype, dst.name, - dst.format) - - return structop_roots - - def _validate_and_collect_expr_info( - self, - dst: "Tensor", - dst_indices: Tuple["IndexVar", ...], - ) -> _ExprInfoDict: - """Propagates expression information for validation. - - Propagates the indices used by child expression nodes to parent expression - nodes. Also collects and validates the sizes for the dimensions - corresponding to the indices. + This is unreachable for IndexVar. + """ + assert 0 - Args: - dst: A destination Tensor that accepts the value of the expression tree. - dst_indices: The indices used by the destination index expression. - Raises: - ValueError if there is any inconsistency in indices or dimensional - values. +def get_index_vars(n: int) -> List[IndexVar]: + """Returns a list of n IndexVar. - Returns: - A dictionary of (IndexExpr, _ExprInfo). - """ - expr_to_info = {} - # Validate the expression tree and construct expression information. - self._visit(_validate_and_collect_expr_info, (expr_to_info,)) - - # Validate the destination dimension information. - info = expr_to_info[self] - index_to_dim_info = {i: d for i, d in zip(info.src_indices, info.dim_infos)} - for i, d, in zip(dst_indices, dst.shape): - if i not in index_to_dim_info: - raise ValueError("Destination IndexVar not used in the " - f"source expression: {i}") - else: - if d != index_to_dim_info[i].dim and index_to_dim_info[i].dim != -1: - raise ValueError(f"Inconsistent destination dimension for {i}: " - f"{d} vs {index_to_dim_info[i].dim}") - - return expr_to_info - - def _emit_assignment( - self, - module: ir.Module, - dst: "Tensor", - dst_indices: Tuple["IndexVar", ...], - expr_to_info: _ExprInfoDict, - input_accesses: List["Access"], - ) -> None: - """Emits an MLIR function for assigning the expression to a tensor.""" - input_types = [a.tensor.mlir_tensor_type() for a in input_accesses] - - # Build the kernel for the operations. - with ir.InsertionPoint(module.body): - - @func.FuncOp.from_py_func(*input_types, name=_ENTRY_NAME) - def linalg_funcop(*args): - # Set up the mapping from the Access nodes to their MLIR values. - for e, mlir in zip(input_accesses, args): - expr_to_info[e].mlir_value = mlir - - # Emit structured ops in the linalg dialect to implement the assignment. - for structop_root in self._identify_structured_ops( - expr_to_info, dst, dst_indices): - structop_root._emit_structured_op(expr_to_info) - dst._record_stats(expr_to_info[structop_root].structop_info) - - # The function returns the MLIR value of the root expression. - return expr_to_info[self].mlir_value - - linalg_funcop.func_op.attributes[ - "llvm.emit_c_interface"] = ir.UnitAttr.get() - - def get_input_accesses(self) -> List["Access"]: - """Compute the list of input accesses for the expression.""" - input_accesses = [] - self._visit(_gather_input_accesses_index_vars, (input_accesses,)) - return input_accesses - - def compile( - self, - dst: "Tensor", - dst_indices: Tuple["IndexVar", ...], - ) -> execution_engine.ExecutionEngine: - """Compiles the tensor assignment dst[dst_indices] = expression. + This routine is defined by the TACO API. Args: - dst: The destination tensor. - dst_indices: The tuple of IndexVar used to access the destination tensor. + n: An integer representing the number of IndexVar to get. 
Returns: - The execution engine for the tensor assignment. + A list of IndexVar. Raises: - ValueError: If the expression is not proper or not supported. + ValueError: if n is not a positive integer. """ - expr_to_info = self._validate_and_collect_expr_info(dst, dst_indices) - input_accesses = self.get_input_accesses() - - # Build and compile the module to produce the execution engine. - with ir.Context(), ir.Location.unknown(): - module = ir.Module.create() - self._emit_assignment(module, dst, dst_indices, expr_to_info, - input_accesses) - engine = utils.compile_and_build_engine(module) - - return engine - - -class _AtomicCounter: - """An atomic counter.""" - - def __init__(self): - self._counter = 0 - self._counter_lock = threading.Lock() - - def increment(self) -> int: - """Increments the counter by one and returns the old value.""" - old_value = self._counter - with self._counter_lock: - self._counter = self._counter + 1 - return old_value - - -class IndexVar(IndexExpr): - """The tensor index class. - - We support the TACO API index_var class with an alias of this class. - - An IndexVar object represents an index variable in tensor index notation. - - Attributes: - name: A unique string name of the IndexVar. - """ - _counter = _AtomicCounter() - - def __init__(self): - id = self._counter.increment() - self._name = f"{_TACO_INDEX_PREFIX}{id}" - - def __repr__(self) -> str: - return f"IndexVar(name={repr(self._name)})" - - @property - def name(self) -> str: - """Returns the name of the IndexVar.""" - return self._name - - def _visit(self, - func: _ExprVisitor, - args, - *, - leaf_checker: _SubtreeLeafChecker = None) -> None: - """A post-order visitor.""" - if leaf_checker: - assert leaf_checker(self, *args) - func(self, *args) - - def _emit_expression( - self, - expr_to_opnd: Dict[IndexExpr, lang.OperandDef], - expr_to_info: _ExprInfoDict, - ) -> lang.ScalarExpression: - """Emits a index value casted to the data type of the tensor expression.""" - dim = getattr(lang.D, self.name) - index = lang.index(dim) - int_value = lang.TypeFn.cast_unsigned(lang.TV.I64, index) - return lang.TypeFn.cast_unsigned(lang.T, int_value) - - def dtype(self) -> DType: - """Returns the data type for the index value. - - This is unreachable for IndexVar. - """ - assert 0 - - -def get_index_vars(n: int) -> List[IndexVar]: - """Returns a list of n IndexVar. - - This routine is defined by the TACO API. - - Args: - n: An integer representing the number of IndexVar to get. - - Returns: - A list of IndexVar. - - Raises: - ValueError: if n is not a positive integer. - """ - if not isinstance(n, int) or n <= 0: - raise ValueError(f"Expected an integer: {n}.") - # If lock contention ever becomes an issue, we could implement a bulk getter - # that returns a range by only claiming the lock once. - return [IndexVar() for i in range(n)] + if not isinstance(n, int) or n <= 0: + raise ValueError(f"Expected an integer: {n}.") + # If lock contention ever becomes an issue, we could implement a bulk getter + # that returns a range by only claiming the lock once. + return [IndexVar() for i in range(n)] def _mlir_symbols_from_index_vars( - index_vars: Tuple[IndexVar, ...]) -> Tuple[lang.SymbolDef, ...]: - """Returns a tuple of MLIR symbols for the given tuple of index_var.""" - return tuple(getattr(lang.S, i.name) for i in index_vars) + index_vars: Tuple[IndexVar, ...] 
+) -> Tuple[lang.SymbolDef, ...]: + """Returns a tuple of MLIR symbols for the given tuple of index_var.""" + return tuple(getattr(lang.S, i.name) for i in index_vars) def _mlir_dimensions_from_index_vars( - index_vars: Tuple[IndexVar, ...]) -> Tuple[lang.DimDef, ...]: - """Returns a tuple of MLIR dimensions for the given tuple of index_var.""" - return tuple(getattr(lang.D, i.name) for i in index_vars) + index_vars: Tuple[IndexVar, ...] +) -> Tuple[lang.DimDef, ...]: + """Returns a tuple of MLIR dimensions for the given tuple of index_var.""" + return tuple(getattr(lang.D, i.name) for i in index_vars) def _mlir_tensor_type( - dtype: DType, shape: Tuple[int, ...], - attr: Optional[sparse_tensor.EncodingAttr]) -> ir.RankedTensorType: - """Returns an MLIR tensor type. - - Args: - dtype: An DType object for the element data type of the tensor. - shape: A tuple of integer for the shape of the tensor. - attr: An optional MLIR sparse tensor attribute, only provided if the tensor - is a sparse tensor. - - Returns: - An MLIR ranked tensor type. - """ - ir_type = _mlir_type_from_taco_type(dtype) - return ir.RankedTensorType.get(shape, ir_type, attr) - - -@dataclasses.dataclass(frozen=True) -class _StructOpInfo: - """Information for generating a structured op in the linalg dialect. - - This information is associated with an expression node that serves as the - root for an expression subtree implemented with a structured op. - - Attributes: - dst_indices: A tuple of IndexVar, representing the result dimensions of the - structured op. This is used to construct the temporary variable for the - tensor to hold the structured op result. - dst_dims: A tuple of int, representing the result shape of the structured - op. - dst_dtype: A DType representing the data type of the structured op result. - dst_name: A string representing the name of the structured op result. - dst_format: An optional Format object representing the destination tensor - format. None represents a true dense tensor. - """ - dst_indices: Tuple[IndexVar, ...] - dst_dims: Tuple[int, ...] - dst_dtype: DType - dst_name: str - dst_format: Optional[Format] - - def __post_init__(self) -> None: - """Verifies the integrity of the attribute values.""" - assert len(self.dst_indices) == len(self.dst_dims) - - def emit_tensor_init(self) -> ir.RankedTensorType: - """Returns an initialization for the destination tensor.""" - if self.dst_format is None or self.dst_format.rank() == 0: - # Initialize the dense tensor. - ir_type = _mlir_type_from_taco_type(self.dst_dtype) - empty = tensor.EmptyOp(self.dst_dims, ir_type).result - zero = arith.ConstantOp(ir_type, 0.0) - return linalg.fill(zero, outs=[empty]) - - # Initialize the sparse tensor. - mlir_type = _mlir_tensor_type(self.dst_dtype, self.dst_dims, - self.dst_format.mlir_tensor_attr()) - index_type = ir.IndexType.get() - return bufferization.AllocTensorOp(mlir_type, [], None, None, None) - - -class _Stats: - """Information to describe how a tensor expression is implemented. - - Currently, we only record the temporary tensors introduced for splitting the - original expression. 
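# Usage sketch for the index-variable helpers above (illustrative only):
i, j, k = get_index_vars(3)        # three fresh, uniquely named IndexVar objects
assert i.name != j.name and j.name != k.name
# get_index_vars(0) and get_index_vars("2") both raise ValueError.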
- """ - - def __init__(self): - self._temps = [] - - def __repr__(self) -> str: - return f"_Stats({repr(self._temps)})" - - def add_element(self, structop: _StructOpInfo): - """Adds a temporary tensor.""" - self._temps.append(structop) - - def get_total(self) -> int: - """Gets the total number of temporary tensors.""" - return len(self._temps) - - def _get_element(self, idx: int) -> _StructOpInfo: - """Gets the ith temporary tensor.""" - assert idx < self.get_total() - return self._temps[idx] - - def get_dimensions(self, idx: int) -> Tuple[int]: - """Gets the dimensions for the ith temporary tensor.""" - return self._get_element(idx).dst_dims - - def get_formats(self, idx: int) -> Tuple[ModeFormat]: - """Gets the ModeFormats for the ith temporary tensor.""" - return tuple(self._get_element(idx).dst_format.format_pack.formats) - - -class _SparseValueInfo(enum.Enum): - """Describes how a sparse tensor value is stored. - _UNPACKED: The sparse tensor value is stored as (coordnates, values) in - Python. - _PACKED: The sparse tensor value is stored as a C pointer to a packed MLIR - sparse tensor. - """ - _UNPACKED = 0 - _PACKED = 1 - - -@dataclasses.dataclass(frozen=True) -class _Assignment: - """Records an assignment to a tensor T as T[indices] = expression.""" - indices: Tuple["IndexVar", ...] - expression: "IndexExpr" - - -class Tensor: - """The tensor class. - - We support the TACO API tensor class with an alias of this class. - - This class is part of the TACO API with the following methods: - insert: Inserts a value to the given coordinate in the tensor. - to_array: Returns a numpy ndarray for the tensor. - - TACO API also defines the following arrtibutes for the class: - dtype: A dtype object representing the data type of the tensor. - format: A format object representing the storage format of the tensor. - name: A string object representing the name of the tensor. - order: An integral rank of the tensor. - shape: A list of integers representing the shape of the tensor. - - We currently ignore the tensor dimension ordering for dense tensor. - """ - _counter = _AtomicCounter() - - def _get_unique_name(self) -> str: - """Returns a unique name for creating a new Tensor.""" - return f"{_TACO_TENSOR_PREFIX}{self._counter.increment()}" - - def _init_format(self, fmt: Union[ModeFormat, List[ModeFormat], - Format]) -> None: - """Process the fmt argument for the Tensor constructor. - - Args: - fmt: This argument can be a ModeFormat, List[ModeFormat], or format. If - this argument is a ModeFormat, uses this ModeFormat for all the tensor - dimensions. If this argument is a list of ModeFormat, the len of the - list should equal to the rank of the tensor. If this argument is a - format, uses it for the format of the tensor. - - Raises: - ValueError: If fmt is not one of the expected type or is inconsistent - with the rank of the tensor. This is because fmt could be an users - input. 
- """ - if isinstance(fmt, ModeFormat): - self._format = _make_format([fmt] * self.order) - elif isinstance(fmt, list): - if len(fmt) == self.order and isinstance(fmt[0], ModeFormat): - self._format = _make_format(fmt) - else: - raise ValueError("Inconsistent shape and format: " - f"{self._shape}, {fmt}.") - elif isinstance(fmt, Format): - if fmt.rank() != self.order: - raise ValueError("Inconsistent shape and format: " - f"{self._shape}, {fmt}.") - else: - self._format = fmt - else: - raise ValueError(f"Invalid format argument: {fmt}.") - - def __init__(self, - value_or_shape: Optional[Union[List[int], Tuple[int, ...], - complex, float, int]] = None, - fmt: Optional[Union[ModeFormat, List[ModeFormat], - Format]] = None, - dtype: Optional[DType] = None, - name: Optional[str] = None, - is_dense: bool = False): - """The tensor constructor interface defined by TACO API. - - Args: - value_or_shape: This argument is optional and can be int, float, - List[int], or Tuple[int, ...]. If this argument is an int or float, - creates a scalar tensor and initializes it with the value. If this - argument is a list or tuple of int, uses it as the shape to create a - tensor. - fmt: This argument can be a ModeFormat, List[ModeFormat], or format. If - this argument is a ModeFormat, uses this ModeFormat for all the tensor - dimensions. If this argument is a list of ModeFormat, the len of the - list should equal to the rank of the tensor. If this argument is a - format, uses it for the format of the tensor. - dtype: An object of dtype, representing the data type of the tensor. - name: A string name of the tensor. If a name is not given, creates a - unique name for the tensor. - is_dense: A boolean variable to indicate whether the tensor is a dense - tensor without any sparsity annotation. - - Raises: - ValueError: If there is any inconsistency among the input arguments. - """ - # Take care of the argument default values common to both sparse tensors - # and dense tensors. - dtype = dtype or DType(Type.FLOAT32) - self._name = name or self._get_unique_name() - self._assignment = None - self._engine = None - self._sparse_value_location = _SparseValueInfo._UNPACKED - self._dense_storage = None - self._dtype = dtype - - if is_dense: - assert (fmt is None) - assert (isinstance(value_or_shape, tuple) or isinstance( - value_or_shape, list)) and _all_instance_of(value_or_shape, int) - self._shape = value_or_shape - self._format = None - return - - fmt = fmt or ModeFormat.COMPRESSED - # We currently use _coords and _values to host the sparse tensor value with - # COO format, and _dense_storage to host the dense tensor value. We don't - # support the conversion between the two storages. - self._coords = [] - self._values = [] - self._stats = _Stats() - if value_or_shape is None or isinstance(value_or_shape, int) or isinstance( - value_or_shape, float) or isinstance(value_or_shape, complex): - # Create a scalar tensor and ignore the fmt parameter. - self._shape = [] - self._format = _make_format([], []) - if value_or_shape is not None: - self._dense_storage = np.array(value_or_shape, dtype=self._dtype.value) - elif (isinstance(value_or_shape, tuple) or isinstance( - value_or_shape, list)) and _all_instance_of(value_or_shape, int): - # Create a tensor with the specified shape and format. - self._shape = list(value_or_shape) - self._init_format(fmt) - else: - raise ValueError("Invalid first argument. 
" - "Must be a tuple or list for a shape or a single value" - f"if initializing a scalar tensor: {value_or_shape}.") - - def _set_packed_sparse_tensor(self, pointer: ctypes.c_void_p) -> None: - """Records the MLIR sparse tensor pointer.""" - self._sparse_value_location = _SparseValueInfo._PACKED - self._packed_sparse_value = pointer - - def is_unpacked(self) -> bool: - """Returns true if the tensor value is not packed as MLIR sparse tensor.""" - return (self._sparse_value_location == _SparseValueInfo._UNPACKED) - - def unpack(self) -> None: - """Unpacks the MLIR sparse tensor representation.""" - if self.is_dense() or self.is_unpacked(): - return - - # Use the output MLIR sparse tensor pointer to retrieve the COO-flavored - # values and verify the values. - rank, nse, shape, values, indices = utils.sparse_tensor_to_coo_tensor( - self._packed_sparse_value, self._dtype.value) - assert rank == self.order - assert np.array_equal(self.shape, shape) - assert nse == len(values) - self._coords = indices - self._values = values - self._sparse_value_location = _SparseValueInfo._UNPACKED - - def __repr__(self) -> str: - self._sync_value() - self.unpack() - value_str = (f"{repr(self._dense_storage)})" if self.is_dense() else - f"{repr(self._coords)} {repr(self._values)})") - return (f"Tensor(_name={repr(self._name)} " - f"_dtype={repr(self._dtype)} : ") + value_str - - def insert(self, coords: List[int], val: Union[complex, float, int]) -> None: - """Inserts a value to the given coordinate. + dtype: DType, shape: Tuple[int, ...], attr: Optional[sparse_tensor.EncodingAttr] +) -> ir.RankedTensorType: + """Returns an MLIR tensor type. Args: - coords: A list of integer coordinates. The length of the list must be the - same as the rank of the tensor. - val: A value being inserted. It is either an integral or a floating point - value. This value will be converted to the data type of the tensor. - - Raises: - ValueError: When there is any problem in the parameters. - """ - if self.is_dense(): - raise ValueError("Insert method is not supported for dense tensors.") - if self._assignment != None or not self.is_unpacked(): - raise ValueError( - "Can't use Insert method for a tensor constructed from a file.") - if not isinstance(coords, list): - raise ValueError(f"Non list coordinate detected: {coords}.") - if not _all_instance_of(coords, int): - raise ValueError(f"Non integer coordinate detected: {coords}.") - if (len(coords) != self.order or - any([c < 0 or c >= self._shape[i] for i, c in enumerate(coords)])): - raise ValueError("Invalid coordinate for rank: " - f"{self.order}, {coords}.") - - if not isinstance(val, int) and not isinstance( - val, float) and not isinstance(val, complex): - raise ValueError(f"Value is neither int nor float: {val}.") - - self._coords.append(tuple(coords)) - self._values.append(self._dtype.value(val)) - - def is_dense(self) -> bool: - """Returns true if the tensor doesn't have sparsity annotation.""" - return self.order == 0 or self._format is None - - def to_array(self) -> np.ndarray: - """Returns the numpy array for the Tensor. - - This is currenly only implemented for dense Tensor. - """ - if not self.is_dense(): - raise ValueError("Conversion from non-dense Tensor " - "to numpy array not supported yet.") - - self._sync_value() - - return self._dense_storage - - @staticmethod - def from_array(array: np.ndarray) -> "Tensor": - """Returns a dense tensor with the value copied from the input array. 
- - We currently only support the conversion of float32 and float64 numpy arrays - to Tensor. - - Args: - array: The numpy array that provides the data type, shape and value for - the tensor. + dtype: An DType object for the element data type of the tensor. + shape: A tuple of integer for the shape of the tensor. + attr: An optional MLIR sparse tensor attribute, only provided if the tensor + is a sparse tensor. Returns: - A Tensor object. - - Raises: - ValueError if the data type of the numpy array is not supported. + An MLIR ranked tensor type. """ - if array.dtype != np.float32 and array.dtype != np.float64: - raise ValueError(f"Expected floating point value type: {array.dtype}.") - t = Tensor( - array.shape, - dtype=_nptype_to_taco_type(array.dtype.type), - is_dense=True) - t._dense_storage = np.copy(array) - return t - - @staticmethod - def from_coo( - coordinates: List[Tuple[int, ...]], - values: List[_AnyRuntimeType], - fmt: Format, - dtype: DType, - ) -> "Tensor": - """Converts coordinates and values to a sparse tensor representation. + ir_type = _mlir_type_from_taco_type(dtype) + return ir.RankedTensorType.get(shape, ir_type, attr) - Args: - coordinates: A list of coordinates with non-zero values. - values: The non-zero values. - fmt: The tensor storage format. - dtype: The tensor element data type. - Returns: - A tensor with the given non-zero values and storage format. The shape of - the tensor has the minimum size for each dimension to make the given - coordinates valid. - """ - assert (isinstance(coordinates, List) and - _all_instance_of(coordinates, Tuple)) - assert (isinstance(values, List) and _all_instance_of(values, dtype.value)) - assert isinstance(fmt, Format) - - rank = fmt.rank() - assert all(len(c) == rank and _all_instance_of(c, int) for c in coordinates) - - # Find the maximum coordinate value for each dimension. - max_coordinate = list(map(max, zip(*coordinates))) - # The size of each dimension is one more that such a maximum coordinate - # value. - shape = [c + 1 for c in max_coordinate] - t = Tensor(shape, fmt, dtype=dtype) - t._coords = coordinates - t._values = values - - return tensor - - @staticmethod - def from_file( - filename: str, - fmt: Format, - dtype: DType, - ) -> "Tensor": - """Constructs a sparse tensor using the COO-flavored values from a file. - - Args: - filename: A string for the name of the file that contains the sparse - tensor data. - fmt: The tensor storage format. - dtype: The tensor element data type. - - Returns: - A tensor with the given non-zero values and storage format. The tensor - value is stored as an MLIR sparse tensor. +@dataclasses.dataclass(frozen=True) +class _StructOpInfo: + """Information for generating a structured op in the linalg dialect. + + This information is associated with an expression node that serves as the + root for an expression subtree implemented with a structured op. + + Attributes: + dst_indices: A tuple of IndexVar, representing the result dimensions of the + structured op. This is used to construct the temporary variable for the + tensor to hold the structured op result. + dst_dims: A tuple of int, representing the result shape of the structured + op. + dst_dtype: A DType representing the data type of the structured op result. + dst_name: A string representing the name of the structured op result. + dst_format: An optional Format object representing the destination tensor + format. None represents a true dense tensor. 
""" - sparse_tensor, shape = utils.create_sparse_tensor(filename, - fmt.format_pack.formats, - _dtype_to_mlir_str(dtype)) - t = Tensor(shape.tolist(), fmt, dtype=dtype) - t._set_packed_sparse_tensor(sparse_tensor) - - return t - def to_file(self, filename: str) -> None: - """Output the tensor value to a file. + dst_indices: Tuple[IndexVar, ...] + dst_dims: Tuple[int, ...] + dst_dtype: DType + dst_name: str + dst_format: Optional[Format] + + def __post_init__(self) -> None: + """Verifies the integrity of the attribute values.""" + assert len(self.dst_indices) == len(self.dst_dims) + + def emit_tensor_init(self) -> ir.RankedTensorType: + """Returns an initialization for the destination tensor.""" + if self.dst_format is None or self.dst_format.rank() == 0: + # Initialize the dense tensor. + ir_type = _mlir_type_from_taco_type(self.dst_dtype) + empty = tensor.EmptyOp(self.dst_dims, ir_type).result + zero = arith.ConstantOp(ir_type, 0.0) + return linalg.fill(zero, outs=[empty]) + + # Initialize the sparse tensor. + mlir_type = _mlir_tensor_type( + self.dst_dtype, self.dst_dims, self.dst_format.mlir_tensor_attr() + ) + index_type = ir.IndexType.get() + return bufferization.AllocTensorOp(mlir_type, [], None, None, None) - This method evaluates any pending assignment to the tensor and outputs the - tensor value. - Args: - filename: A string file name. +class _Stats: + """Information to describe how a tensor expression is implemented. - Raises: - ValueError: If the tensor is dense, or an unpacked sparse tensor. + Currently, we only record the temporary tensors introduced for splitting the + original expression. """ - self._sync_value() - - if self.is_dense(): - raise ValueError("Writing dense tensors without sparsity annotation to " - "file is not supported.") - if self.is_unpacked(): - raise ValueError("Writing unpacked sparse tensors to file is not " - "supported.") + def __init__(self): + self._temps = [] - utils.output_sparse_tensor(self._packed_sparse_value, filename, - self._format.format_pack.formats, - _dtype_to_mlir_str(self._dtype)) + def __repr__(self) -> str: + return f"_Stats({repr(self._temps)})" - @property - def dtype(self) -> DType: - """Returns the data type for the Tensor.""" - return self._dtype + def add_element(self, structop: _StructOpInfo): + """Adds a temporary tensor.""" + self._temps.append(structop) - @property - def format(self) -> Format: - """Returns the storage format for the Tensor.""" - return self._format + def get_total(self) -> int: + """Gets the total number of temporary tensors.""" + return len(self._temps) - @property - def name(self) -> str: - """Returns the name for the Tensor.""" - return self._name + def _get_element(self, idx: int) -> _StructOpInfo: + """Gets the ith temporary tensor.""" + assert idx < self.get_total() + return self._temps[idx] - @property - def order(self) -> int: - """Returns the rank of the Tensor.""" - return len(self._shape) + def get_dimensions(self, idx: int) -> Tuple[int]: + """Gets the dimensions for the ith temporary tensor.""" + return self._get_element(idx).dst_dims - @property - def shape(self) -> List[int]: - """Returns the shape of the Tensor.""" - return self._shape + def get_formats(self, idx: int) -> Tuple[ModeFormat]: + """Gets the ModeFormats for the ith temporary tensor.""" + return tuple(self._get_element(idx).dst_format.format_pack.formats) - def _verify_and_normalize_indices(self, indices) -> Tuple[IndexVar, ...]: - """Verifies and normalizes the indices to access the tensor. 
- Args: - indices: The index expression used to access a tensor, which could be any - Python object from user inputs. - - Returns: - A tuple of IndexVar. - - Raises: - ValueError: If indices is not 0 for scalar tensors, or not an IndexVar or - a tuple of IndexVar for other tensors. +class _SparseValueInfo(enum.Enum): + """Describes how a sparse tensor value is stored. + _UNPACKED: The sparse tensor value is stored as (coordnates, values) in + Python. + _PACKED: The sparse tensor value is stored as a C pointer to a packed MLIR + sparse tensor. """ - if self.order == 0: - if not isinstance(indices, int) or indices != 0: - raise ValueError(f"Expected 0 to index scalar tensors: {indices}") - return () - - if isinstance(indices, IndexVar): - return (indices,) - elif isinstance(indices, tuple) and _all_instance_of(indices, IndexVar): - return indices - raise ValueError(f"Expected IndexVars: {indices}") + _UNPACKED = 0 + _PACKED = 1 - def __getitem__(self, key) -> "Access": - """Verifies and processes a tensor access. - In the tensor index notation, a tensor access T[i, j] is represented as - retrieving a value with key (i, j) from the tensor object T in Python. This - routine verifies the key for the tensor access and returns a tensor access - object. - - Args: - key: The key used to access the tensor, which could be any Python object - from user inputs. +@dataclasses.dataclass(frozen=True) +class _Assignment: + """Records an assignment to a tensor T as T[indices] = expression.""" - Returns: - The corresponding tensor access object. + indices: Tuple["IndexVar", ...] + expression: "IndexExpr" - Raises: - ValueError: If key is not an IndexVar or a tuple of IndexVar. - """ - indices = self._verify_and_normalize_indices(key) - return Access(self, indices) - def __setitem__(self, key, value) -> None: - """Verifies and processes a tensor assignment. +class Tensor: + """The tensor class. - In the tensor index notation, a tensor assignment "T[i, j] = ..." is - represented as setting a value for a tensor object T via key (i, j) in - Python. This routine verifies the key, evaluates the value, and assigns the - value to the tensor. + We support the TACO API tensor class with an alias of this class. - We only support assignment of dense tensor currently. + This class is part of the TACO API with the following methods: + insert: Inserts a value to the given coordinate in the tensor. + to_array: Returns a numpy ndarray for the tensor. - Args: - key: The key used to access the tensor, which could be any Python object - from user inputs. - value: The value assigned to the tensor, which could be any Python object - from user inputs. + TACO API also defines the following arrtibutes for the class: + dtype: A dtype object representing the data type of the tensor. + format: A format object representing the storage format of the tensor. + name: A string object representing the name of the tensor. + order: An integral rank of the tensor. + shape: A list of integers representing the shape of the tensor. - Raises: - ValueError: If tensor is not a dense tensor, or the key is not an IndexVar - or a tuple of IndexVar, or the length of the indices is not the same as - the rank of the tensor. + We currently ignore the tensor dimension ordering for dense tensor. 
""" - indices = self._verify_and_normalize_indices(key) - if len(indices) != self.order: - raise ValueError("Mismatch between indices and tensor rank: " - f"len({indices}) != {self.order}.") - self._assignment = _Assignment(indices, value) - self._engine = None - - def compile(self, force_recompile: bool = False) -> None: - """Compiles the tensor assignment to an execution engine. - - Calling compile the second time does not do anything unless - force_recompile is True. + _counter = _AtomicCounter() + + def _get_unique_name(self) -> str: + """Returns a unique name for creating a new Tensor.""" + return f"{_TACO_TENSOR_PREFIX}{self._counter.increment()}" + + def _init_format(self, fmt: Union[ModeFormat, List[ModeFormat], Format]) -> None: + """Process the fmt argument for the Tensor constructor. + + Args: + fmt: This argument can be a ModeFormat, List[ModeFormat], or format. If + this argument is a ModeFormat, uses this ModeFormat for all the tensor + dimensions. If this argument is a list of ModeFormat, the len of the + list should equal to the rank of the tensor. If this argument is a + format, uses it for the format of the tensor. + + Raises: + ValueError: If fmt is not one of the expected type or is inconsistent + with the rank of the tensor. This is because fmt could be an users + input. + """ + if isinstance(fmt, ModeFormat): + self._format = _make_format([fmt] * self.order) + elif isinstance(fmt, list): + if len(fmt) == self.order and isinstance(fmt[0], ModeFormat): + self._format = _make_format(fmt) + else: + raise ValueError( + "Inconsistent shape and format: " f"{self._shape}, {fmt}." + ) + elif isinstance(fmt, Format): + if fmt.rank() != self.order: + raise ValueError( + "Inconsistent shape and format: " f"{self._shape}, {fmt}." + ) + else: + self._format = fmt + else: + raise ValueError(f"Invalid format argument: {fmt}.") + + def __init__( + self, + value_or_shape: Optional[ + Union[List[int], Tuple[int, ...], complex, float, int] + ] = None, + fmt: Optional[Union[ModeFormat, List[ModeFormat], Format]] = None, + dtype: Optional[DType] = None, + name: Optional[str] = None, + is_dense: bool = False, + ): + """The tensor constructor interface defined by TACO API. + + Args: + value_or_shape: This argument is optional and can be int, float, + List[int], or Tuple[int, ...]. If this argument is an int or float, + creates a scalar tensor and initializes it with the value. If this + argument is a list or tuple of int, uses it as the shape to create a + tensor. + fmt: This argument can be a ModeFormat, List[ModeFormat], or format. If + this argument is a ModeFormat, uses this ModeFormat for all the tensor + dimensions. If this argument is a list of ModeFormat, the len of the + list should equal to the rank of the tensor. If this argument is a + format, uses it for the format of the tensor. + dtype: An object of dtype, representing the data type of the tensor. + name: A string name of the tensor. If a name is not given, creates a + unique name for the tensor. + is_dense: A boolean variable to indicate whether the tensor is a dense + tensor without any sparsity annotation. + + Raises: + ValueError: If there is any inconsistency among the input arguments. + """ + # Take care of the argument default values common to both sparse tensors + # and dense tensors. 
+ dtype = dtype or DType(Type.FLOAT32) + self._name = name or self._get_unique_name() + self._assignment = None + self._engine = None + self._sparse_value_location = _SparseValueInfo._UNPACKED + self._dense_storage = None + self._dtype = dtype + + if is_dense: + assert fmt is None + assert ( + isinstance(value_or_shape, tuple) or isinstance(value_or_shape, list) + ) and _all_instance_of(value_or_shape, int) + self._shape = value_or_shape + self._format = None + return + + fmt = fmt or ModeFormat.COMPRESSED + # We currently use _coords and _values to host the sparse tensor value with + # COO format, and _dense_storage to host the dense tensor value. We don't + # support the conversion between the two storages. + self._coords = [] + self._values = [] + self._stats = _Stats() + if ( + value_or_shape is None + or isinstance(value_or_shape, int) + or isinstance(value_or_shape, float) + or isinstance(value_or_shape, complex) + ): + # Create a scalar tensor and ignore the fmt parameter. + self._shape = [] + self._format = _make_format([], []) + if value_or_shape is not None: + self._dense_storage = np.array(value_or_shape, dtype=self._dtype.value) + elif ( + isinstance(value_or_shape, tuple) or isinstance(value_or_shape, list) + ) and _all_instance_of(value_or_shape, int): + # Create a tensor with the specified shape and format. + self._shape = list(value_or_shape) + self._init_format(fmt) + else: + raise ValueError( + "Invalid first argument. " + "Must be a tuple or list for a shape or a single value" + f"if initializing a scalar tensor: {value_or_shape}." + ) + + def _set_packed_sparse_tensor(self, pointer: ctypes.c_void_p) -> None: + """Records the MLIR sparse tensor pointer.""" + self._sparse_value_location = _SparseValueInfo._PACKED + self._packed_sparse_value = pointer + + def is_unpacked(self) -> bool: + """Returns true if the tensor value is not packed as MLIR sparse tensor.""" + return self._sparse_value_location == _SparseValueInfo._UNPACKED + + def unpack(self) -> None: + """Unpacks the MLIR sparse tensor representation.""" + if self.is_dense() or self.is_unpacked(): + return + + # Use the output MLIR sparse tensor pointer to retrieve the COO-flavored + # values and verify the values. + rank, nse, shape, values, indices = utils.sparse_tensor_to_coo_tensor( + self._packed_sparse_value, self._dtype.value + ) + assert rank == self.order + assert np.array_equal(self.shape, shape) + assert nse == len(values) + self._coords = indices + self._values = values + self._sparse_value_location = _SparseValueInfo._UNPACKED + + def __repr__(self) -> str: + self._sync_value() + self.unpack() + value_str = ( + f"{repr(self._dense_storage)})" + if self.is_dense() + else f"{repr(self._coords)} {repr(self._values)})" + ) + return ( + f"Tensor(_name={repr(self._name)} " f"_dtype={repr(self._dtype)} : " + ) + value_str + + def insert(self, coords: List[int], val: Union[complex, float, int]) -> None: + """Inserts a value to the given coordinate. + + Args: + coords: A list of integer coordinates. The length of the list must be the + same as the rank of the tensor. + val: A value being inserted. It is either an integral or a floating point + value. This value will be converted to the data type of the tensor. + + Raises: + ValueError: When there is any problem in the parameters. 
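# Constructor usage sketch for the Tensor class above (illustrative only;
# ModeFormat.DENSE is assumed to exist, and `csr_like` refers to a Format such
# as the one sketched earlier in this file):
s = Tensor(2.5)                                   # scalar (order-0) tensor
a = Tensor([8, 8])                                # all dimensions COMPRESSED
b = Tensor([8, 8], [ModeFormat.DENSE, ModeFormat.COMPRESSED])
c = Tensor([8, 8], csr_like)                      # explicit Format object
d = Tensor((8, 8), is_dense=True)                 # dense, no sparsity annotation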
+ """ + if self.is_dense(): + raise ValueError("Insert method is not supported for dense tensors.") + if self._assignment != None or not self.is_unpacked(): + raise ValueError( + "Can't use Insert method for a tensor constructed from a file." + ) + if not isinstance(coords, list): + raise ValueError(f"Non list coordinate detected: {coords}.") + if not _all_instance_of(coords, int): + raise ValueError(f"Non integer coordinate detected: {coords}.") + if len(coords) != self.order or any( + [c < 0 or c >= self._shape[i] for i, c in enumerate(coords)] + ): + raise ValueError("Invalid coordinate for rank: " f"{self.order}, {coords}.") + + if ( + not isinstance(val, int) + and not isinstance(val, float) + and not isinstance(val, complex) + ): + raise ValueError(f"Value is neither int nor float: {val}.") + + self._coords.append(tuple(coords)) + self._values.append(self._dtype.value(val)) + + def is_dense(self) -> bool: + """Returns true if the tensor doesn't have sparsity annotation.""" + return self.order == 0 or self._format is None + + def to_array(self) -> np.ndarray: + """Returns the numpy array for the Tensor. + + This is currenly only implemented for dense Tensor. + """ + if not self.is_dense(): + raise ValueError( + "Conversion from non-dense Tensor " "to numpy array not supported yet." + ) + + self._sync_value() + + return self._dense_storage + + @staticmethod + def from_array(array: np.ndarray) -> "Tensor": + """Returns a dense tensor with the value copied from the input array. + + We currently only support the conversion of float32 and float64 numpy arrays + to Tensor. + + Args: + array: The numpy array that provides the data type, shape and value for + the tensor. + + Returns: + A Tensor object. + + Raises: + ValueError if the data type of the numpy array is not supported. + """ + if array.dtype != np.float32 and array.dtype != np.float64: + raise ValueError(f"Expected floating point value type: {array.dtype}.") + t = Tensor( + array.shape, dtype=_nptype_to_taco_type(array.dtype.type), is_dense=True + ) + t._dense_storage = np.copy(array) + return t + + @staticmethod + def from_coo( + coordinates: List[Tuple[int, ...]], + values: List[_AnyRuntimeType], + fmt: Format, + dtype: DType, + ) -> "Tensor": + """Converts coordinates and values to a sparse tensor representation. + + Args: + coordinates: A list of coordinates with non-zero values. + values: The non-zero values. + fmt: The tensor storage format. + dtype: The tensor element data type. + + Returns: + A tensor with the given non-zero values and storage format. The shape of + the tensor has the minimum size for each dimension to make the given + coordinates valid. + """ + assert isinstance(coordinates, List) and _all_instance_of(coordinates, Tuple) + assert isinstance(values, List) and _all_instance_of(values, dtype.value) + assert isinstance(fmt, Format) + + rank = fmt.rank() + assert all(len(c) == rank and _all_instance_of(c, int) for c in coordinates) + + # Find the maximum coordinate value for each dimension. + max_coordinate = list(map(max, zip(*coordinates))) + # The size of each dimension is one more that such a maximum coordinate + # value. + shape = [c + 1 for c in max_coordinate] + t = Tensor(shape, fmt, dtype=dtype) + t._coords = coordinates + t._values = values + + return tensor + + @staticmethod + def from_file( + filename: str, + fmt: Format, + dtype: DType, + ) -> "Tensor": + """Constructs a sparse tensor using the COO-flavored values from a file. 
+ + Args: + filename: A string for the name of the file that contains the sparse + tensor data. + fmt: The tensor storage format. + dtype: The tensor element data type. + + Returns: + A tensor with the given non-zero values and storage format. The tensor + value is stored as an MLIR sparse tensor. + """ + sparse_tensor, shape = utils.create_sparse_tensor( + filename, fmt.format_pack.formats, _dtype_to_mlir_str(dtype) + ) + t = Tensor(shape.tolist(), fmt, dtype=dtype) + t._set_packed_sparse_tensor(sparse_tensor) + + return t + + def to_file(self, filename: str) -> None: + """Output the tensor value to a file. + + This method evaluates any pending assignment to the tensor and outputs the + tensor value. + + Args: + filename: A string file name. + + Raises: + ValueError: If the tensor is dense, or an unpacked sparse tensor. + """ + self._sync_value() + + if self.is_dense(): + raise ValueError( + "Writing dense tensors without sparsity annotation to " + "file is not supported." + ) + + if self.is_unpacked(): + raise ValueError( + "Writing unpacked sparse tensors to file is not " "supported." + ) + + utils.output_sparse_tensor( + self._packed_sparse_value, + filename, + self._format.format_pack.formats, + _dtype_to_mlir_str(self._dtype), + ) + + @property + def dtype(self) -> DType: + """Returns the data type for the Tensor.""" + return self._dtype + + @property + def format(self) -> Format: + """Returns the storage format for the Tensor.""" + return self._format + + @property + def name(self) -> str: + """Returns the name for the Tensor.""" + return self._name + + @property + def order(self) -> int: + """Returns the rank of the Tensor.""" + return len(self._shape) + + @property + def shape(self) -> List[int]: + """Returns the shape of the Tensor.""" + return self._shape + + def _verify_and_normalize_indices(self, indices) -> Tuple[IndexVar, ...]: + """Verifies and normalizes the indices to access the tensor. + + Args: + indices: The index expression used to access a tensor, which could be any + Python object from user inputs. + + Returns: + A tuple of IndexVar. + + Raises: + ValueError: If indices is not 0 for scalar tensors, or not an IndexVar or + a tuple of IndexVar for other tensors. + """ + if self.order == 0: + if not isinstance(indices, int) or indices != 0: + raise ValueError(f"Expected 0 to index scalar tensors: {indices}") + return () + + if isinstance(indices, IndexVar): + return (indices,) + elif isinstance(indices, tuple) and _all_instance_of(indices, IndexVar): + return indices + + raise ValueError(f"Expected IndexVars: {indices}") + + def __getitem__(self, key) -> "Access": + """Verifies and processes a tensor access. + + In the tensor index notation, a tensor access T[i, j] is represented as + retrieving a value with key (i, j) from the tensor object T in Python. This + routine verifies the key for the tensor access and returns a tensor access + object. + + Args: + key: The key used to access the tensor, which could be any Python object + from user inputs. + + Returns: + The corresponding tensor access object. + + Raises: + ValueError: If key is not an IndexVar or a tuple of IndexVar. + """ + indices = self._verify_and_normalize_indices(key) + return Access(self, indices) + + def __setitem__(self, key, value) -> None: + """Verifies and processes a tensor assignment. + + In the tensor index notation, a tensor assignment "T[i, j] = ..." is + represented as setting a value for a tensor object T via key (i, j) in + Python. 
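To make the __getitem__/__setitem__ protocol concrete, a sketch of the index notation; get_index_vars and the arithmetic operator overloads on index expressions are defined elsewhere in this file and are assumed here:

    import numpy as np
    from tools import mlir_pytaco as pytaco

    i, j = pytaco.get_index_vars(2)  # assumed helper returning IndexVar objects
    A = pytaco.Tensor.from_array(np.ones((2, 3)))
    B = pytaco.Tensor.from_array(np.full((2, 3), 2.0))
    C = pytaco.Tensor.from_array(np.zeros((2, 3)))

    # __setitem__ only records the assignment; evaluation is deferred.
    C[i, j] = A[i, j] + B[i, j]
    print(C.to_array())  # forces evaluation of the pending assignment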
This routine verifies the key, evaluates the value, and assigns the + value to the tensor. + + We only support assignment of dense tensor currently. + + Args: + key: The key used to access the tensor, which could be any Python object + from user inputs. + value: The value assigned to the tensor, which could be any Python object + from user inputs. + + Raises: + ValueError: If tensor is not a dense tensor, or the key is not an IndexVar + or a tuple of IndexVar, or the length of the indices is not the same as + the rank of the tensor. + """ + indices = self._verify_and_normalize_indices(key) + if len(indices) != self.order: + raise ValueError( + "Mismatch between indices and tensor rank: " + f"len({indices}) != {self.order}." + ) + + self._assignment = _Assignment(indices, value) + self._engine = None + + def compile(self, force_recompile: bool = False) -> None: + """Compiles the tensor assignment to an execution engine. + + Calling compile the second time does not do anything unless + force_recompile is True. + + Args: + force_recompile: A boolean value to enable recompilation, such as for the + purpose of timing. + + Raises: + ValueError: If the assignment is not proper or not supported. + """ + if self._assignment is None or ( + self._engine is not None and not force_recompile + ): + return + + self._engine = self._assignment.expression.compile( + self, self._assignment.indices + ) + + def compute(self) -> None: + """Executes the engine for the tensor assignment. + + Raises: + ValueError: If the assignment hasn't been compiled yet. + """ + if self._assignment is None: + return + + if self._engine is None: + raise ValueError("Need to invoke compile() before invoking compute().") + + input_accesses = self._assignment.expression.get_input_accesses() + # Gather the pointers for the input buffers. + input_pointers = [a.tensor.ctype_pointer() for a in input_accesses] + if self.is_dense(): + # The pointer to receive dense output is the first argument to the + # execution engine. + arg_pointers = [self.dense_dst_ctype_pointer()] + input_pointers + else: + # The pointer to receive the sparse tensor output is the last argument + # to the execution engine and is a pointer to pointer of char. + arg_pointers = input_pointers + [ + ctypes.pointer(ctypes.pointer(ctypes.c_char(0))) + ] + + # Invoke the execution engine to run the module. + self._engine.invoke(_ENTRY_NAME, *arg_pointers) + + # Retrieve the result. + if self.is_dense(): + result = runtime.ranked_memref_to_numpy(arg_pointers[0][0]) + assert isinstance(result, np.ndarray) + self._dense_storage = result + else: + self._set_packed_sparse_tensor(arg_pointers[-1][0]) + + self._assignment = None + self._engine = None + + def evaluate(self) -> None: + """Evaluates the tensor assignment.""" + self.compile() + self.compute() + + def _sync_value(self) -> None: + """Updates the tensor value by evaluating the pending assignment.""" + if self._assignment is not None: + self.evaluate() + + def mlir_tensor_type(self) -> ir.RankedTensorType: + """Returns the MLIR type for the tensor.""" + mlir_attr = ( + None + if (self._format is None or self.order == 0) + else self._format.mlir_tensor_attr() + ) + return _mlir_tensor_type(self._dtype, tuple(self._shape), mlir_attr) + + def dense_dst_ctype_pointer(self) -> ctypes.pointer: + """Returns the ctypes pointer for the pointer to an MemRefDescriptor. + + For a dense tensor output, the MLIR compiler allocates the storage for + the tensor. 
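Continuing the hypothetical setup from the previous sketch, the compile/compute/evaluate methods above behave roughly as follows:

    C[i, j] = A[i, j] + B[i, j]
    C.compile()                      # builds and caches the execution engine
    C.compile()                      # cached engine: does nothing
    C.compile(force_recompile=True)  # rebuilds the engine, e.g. for timing
    C.compute()                      # runs the engine and clears the assignment
    # C.evaluate() is simply compile() followed by compute().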
This routine returns the pointer to an MLIR MemRefDescriptor for + receiving the tensor. + """ + assert self.is_dense() + mem_ref_desc = runtime.make_nd_memref_descriptor( + self.order, np.ctypeslib.as_ctypes_type(self.dtype.value) + )() + return ctypes.pointer(ctypes.pointer(mem_ref_desc)) + + def ctype_pointer(self) -> ctypes.pointer: + """Returns the ctypes pointer for the pointer to the input tensor.""" + if self.is_dense(): + if self._dense_storage is None: + self._dense_storage = np.zeros(self._shape, self._dtype.value) + return _ctype_pointer_from_array(self._dense_storage) + + if self.is_unpacked(): + shape = np.array(self._shape, np.int64) + indices = np.array(self._coords, np.int64) + values = np.array(self._values, self._dtype.value) + perm, sparse = self.format.get_permutation_and_sparsity() + ptr = utils.coo_tensor_to_sparse_tensor( + shape, values, indices, perm, sparse + ) + else: + ptr = self._packed_sparse_value + + return ctypes.pointer(ctypes.cast(ptr, ctypes.c_void_p)) + + def get_scalar_value(self) -> _AnyRuntimeType: + """Returns the value for the scalar tensor. + + This method also evaluates the assignment to the tensor. + + Raises: + ValueError: If the tensor is not a scalar. + """ + if self.order != 0: + raise ValueError(f"Expected a scalar tensor, got: rank={self.order}") + + self._sync_value() + return self._dense_storage + + def get_coordinates_and_values( + self, + ) -> Tuple[List[Tuple[int, ...]], List[_AnyRuntimeType]]: + """Returns the coordinates and values for the non-zero elements. + + This method also evaluates the assignment to the tensor and unpack the + sparse tensor. + """ + self._sync_value() + + if not self.is_dense(): + self.unpack() + return (self._coords, self._values) + + if self.order == 0: + return ([], self._dense_storage) + + # Coordinates for non-zero elements, grouped by dimensions. + coords_by_dims = self._dense_storage.nonzero() + # Coordinates for non-zero elements, grouped by elements. + coords = np.transpose(coords_by_dims) + values = self._dense_storage[coords_by_dims] + return (coords, values) + + def _record_stats(self, structop: "_StructOpInfo"): + """Collects information for temporary tensors.""" + # Exclude user specified destination tensors. + if structop.dst_name == self.name: + return + + self._stats.add_element(structop) + + +def _emit_operand( + op_def: lang.LinalgOpDef, + indices: Tuple[IndexVar, ...], + name: str, + kind: lang.OperandKind, +) -> lang.OperandDef: + """Emits an operand for a tensor access in the current linalg operation. Args: - force_recompile: A boolean value to enable recompilation, such as for the - purpose of timing. - - Raises: - ValueError: If the assignment is not proper or not supported. - """ - if self._assignment is None or (self._engine is not None and - not force_recompile): - return - - self._engine = self._assignment.expression.compile(self, - self._assignment.indices) + op_def: A LinalgOpDef representing the current linalg dialect operation. + indices: A tuple of IndexVar used to access the tensor. + name: A unique string name of the tensor. + kind: An OperandKind for the operand. - def compute(self) -> None: - """Executes the engine for the tensor assignment. - - Raises: - ValueError: If the assignment hasn't been compiled yet. - """ - if self._assignment is None: - return - - if self._engine is None: - raise ValueError("Need to invoke compile() before invoking compute().") - - input_accesses = self._assignment.expression.get_input_accesses() - # Gather the pointers for the input buffers. 
- input_pointers = [a.tensor.ctype_pointer() for a in input_accesses] - if self.is_dense(): - # The pointer to receive dense output is the first argument to the - # execution engine. - arg_pointers = [self.dense_dst_ctype_pointer()] + input_pointers - else: - # The pointer to receive the sparse tensor output is the last argument - # to the execution engine and is a pointer to pointer of char. - arg_pointers = input_pointers + [ - ctypes.pointer(ctypes.pointer(ctypes.c_char(0))) - ] - - # Invoke the execution engine to run the module. - self._engine.invoke(_ENTRY_NAME, *arg_pointers) - - # Retrieve the result. - if self.is_dense(): - result = runtime.ranked_memref_to_numpy(arg_pointers[0][0]) - assert isinstance(result, np.ndarray) - self._dense_storage = result - else: - self._set_packed_sparse_tensor(arg_pointers[-1][0]) - - self._assignment = None - self._engine = None - - def evaluate(self) -> None: - """Evaluates the tensor assignment.""" - self.compile() - self.compute() - - def _sync_value(self) -> None: - """Updates the tensor value by evaluating the pending assignment.""" - if self._assignment is not None: - self.evaluate() - - def mlir_tensor_type(self) -> ir.RankedTensorType: - """Returns the MLIR type for the tensor.""" - mlir_attr = (None if (self._format is None or self.order == 0) else - self._format.mlir_tensor_attr()) - return _mlir_tensor_type(self._dtype, tuple(self._shape), mlir_attr) - - def dense_dst_ctype_pointer(self) -> ctypes.pointer: - """Returns the ctypes pointer for the pointer to an MemRefDescriptor. - - For a dense tensor output, the MLIR compiler allocates the storage for - the tensor. This routine returns the pointer to an MLIR MemRefDescriptor for - receiving the tensor. - """ - assert self.is_dense() - mem_ref_desc = runtime.make_nd_memref_descriptor( - self.order, np.ctypeslib.as_ctypes_type(self.dtype.value))() - return ctypes.pointer(ctypes.pointer(mem_ref_desc)) - - def ctype_pointer(self) -> ctypes.pointer: - """Returns the ctypes pointer for the pointer to the input tensor.""" - if self.is_dense(): - if self._dense_storage is None: - self._dense_storage = np.zeros(self._shape, self._dtype.value) - return _ctype_pointer_from_array(self._dense_storage) - - if self.is_unpacked(): - shape = np.array(self._shape, np.int64) - indices = np.array(self._coords, np.int64) - values = np.array(self._values, self._dtype.value) - perm, sparse = self.format.get_permutation_and_sparsity() - ptr = utils.coo_tensor_to_sparse_tensor(shape, values, indices, perm, - sparse) - else: - ptr = self._packed_sparse_value - - return ctypes.pointer(ctypes.cast(ptr, ctypes.c_void_p)) - - def get_scalar_value(self) -> _AnyRuntimeType: - """Returns the value for the scalar tensor. - - This method also evaluates the assignment to the tensor. - - Raises: - ValueError: If the tensor is not a scalar. - """ - if self.order != 0: - raise ValueError(f"Expected a scalar tensor, got: rank={self.order}") - - self._sync_value() - return self._dense_storage - - - def get_coordinates_and_values( - self) -> Tuple[List[Tuple[int, ...]], List[_AnyRuntimeType]]: - """Returns the coordinates and values for the non-zero elements. - - This method also evaluates the assignment to the tensor and unpack the - sparse tensor. + Returns: + An OperandDef representing the operand. """ - self._sync_value() - - if not self.is_dense(): - self.unpack() - return (self._coords, self._values) - - if self.order == 0: - return ([], self._dense_storage) - - # Coordinates for non-zero elements, grouped by dimensions. 
- coords_by_dims = self._dense_storage.nonzero() - # Coordinates for non-zero elements, grouped by elements. - coords = np.transpose(coords_by_dims) - values = self._dense_storage[coords_by_dims] - return (coords, values) - - def _record_stats(self, structop: "_StructOpInfo"): - """Collects information for temporary tensors.""" - # Exclude user specified destination tensors. - if structop.dst_name == self.name: - return - - self._stats.add_element(structop) - - -def _emit_operand(op_def: lang.LinalgOpDef, indices: Tuple[IndexVar, ...], - name: str, kind: lang.OperandKind) -> lang.OperandDef: - """Emits an operand for a tensor access in the current linalg operation. - - Args: - op_def: A LinalgOpDef representing the current linalg dialect operation. - indices: A tuple of IndexVar used to access the tensor. - name: A unique string name of the tensor. - kind: An OperandKind for the operand. - - Returns: - An OperandDef representing the operand. - """ - dim_sym = _mlir_symbols_from_index_vars(indices) - opnd = lang.OperandDef(kind, lang.T, dim_sym) - op_def.add_operand(name, opnd) - return opnd + dim_sym = _mlir_symbols_from_index_vars(indices) + opnd = lang.OperandDef(kind, lang.T, dim_sym) + op_def.add_operand(name, opnd) + return opnd @dataclasses.dataclass(frozen=True) class _DimInfo: - """Information for an operand dimension. + """Information for an operand dimension. - Attributes: - dim: An integer for the size of the dimension. - mode_format: A ModeFormat for the dimension sparsity. - """ - dim: int - mode_format: ModeFormat + Attributes: + dim: An integer for the size of the dimension. + mode_format: A ModeFormat for the dimension sparsity. + """ + + dim: int + mode_format: ModeFormat def _get_dummy_dim_info() -> _DimInfo: - """Constructs the _DimInfo for an index used in tensor expressions.""" - return _DimInfo(-1, ModeFormat.DENSE) + """Constructs the _DimInfo for an index used in tensor expressions.""" + return _DimInfo(-1, ModeFormat.DENSE) @dataclasses.dataclass() class _ExprInfo: - """Expression information for validation and code generation. - - Attributes: - src_indices: A tuple of IndexVar for the indices used by the tensors in the - expression tree. - dim_infos: A tuple of _DimInfo, representing the dimension information - corresponding to the src_indices. - reduce_indices: A set of IndexVar for the indices reduced by the expression. - acc_reduce_indices: An accumulated set of IndexVar for the indices reduced - by the expression and its children. - structop_info: Information to support the code generation for a structured - op in the linalg dialect, if the corresponding expression node is the root - of a subtree for a structured op. - mlir_value: The MLIR value generated for the structured op. - """ - src_indices: Tuple[IndexVar, ...] - dim_infos: Tuple[_DimInfo, ...] - reduce_indices: Optional[Set[IndexVar]] = None - acc_reduce_indices: Optional[Set[IndexVar]] = None - structop_info: Optional[_StructOpInfo] = None - mlir_value: Optional[ir.Value] = None - - def __post_init__(self) -> None: - """Verifies and fix up attribute values. - - Verifies the consistency of the attributes and modifies the default values - to support convenient initializer syntax. + """Expression information for validation and code generation. + + Attributes: + src_indices: A tuple of IndexVar for the indices used by the tensors in the + expression tree. + dim_infos: A tuple of _DimInfo, representing the dimension information + corresponding to the src_indices. 
+ reduce_indices: A set of IndexVar for the indices reduced by the expression. + acc_reduce_indices: An accumulated set of IndexVar for the indices reduced + by the expression and its children. + structop_info: Information to support the code generation for a structured + op in the linalg dialect, if the corresponding expression node is the root + of a subtree for a structured op. + mlir_value: The MLIR value generated for the structured op. """ - assert len(self.src_indices) == len(self.dim_infos) - self.reduce_indices = self.reduce_indices or set() - self.acc_reduce_indices = self.acc_reduce_indices or set() + src_indices: Tuple[IndexVar, ...] + dim_infos: Tuple[_DimInfo, ...] + reduce_indices: Optional[Set[IndexVar]] = None + acc_reduce_indices: Optional[Set[IndexVar]] = None + structop_info: Optional[_StructOpInfo] = None + mlir_value: Optional[ir.Value] = None -@dataclasses.dataclass(frozen=True) -class Access(IndexExpr): - """The tensor access class. + def __post_init__(self) -> None: + """Verifies and fix up attribute values. + + Verifies the consistency of the attributes and modifies the default values + to support convenient initializer syntax. + """ + assert len(self.src_indices) == len(self.dim_infos) + self.reduce_indices = self.reduce_indices or set() + self.acc_reduce_indices = self.acc_reduce_indices or set() - We support the TACO API access class with an alias of this class. - Attributes: - tensor: A Tensor being accessed. - indices: A tuple of IndexVar, representing the indices used to access the - Tensor. - """ - tensor: Tensor - indices: Tuple[IndexVar, ...] +@dataclasses.dataclass(frozen=True) +class Access(IndexExpr): + """The tensor access class. - def __post_init__(self) -> None: - """Verifies the tensor and indices for a tensor access. + We support the TACO API access class with an alias of this class. - Raises: - ValueError: If indices is not a list of IndexVar or the len of indices - doesn't equal to the rank of the tensor. + Attributes: + tensor: A Tensor being accessed. + indices: A tuple of IndexVar, representing the indices used to access the + Tensor. """ - if (not isinstance(self.indices, tuple) or - not _all_instance_of(self.indices, IndexVar)): - raise ValueError(f"Indices contain non IndexVar: {str(self.indices)}.") - if self.tensor.order != len(self.indices): - raise ValueError("Invalid indices for rank: " - f"str{self.tensor.order} != len({str(self.indices)}).") - - def __repr__(self) -> str: - # The Tensor __repr__ method evaluates the pending assignment to the tensor. - # We want to define the __repr__ method here to avoid such evaluation of the - # tensor assignment. - indices_str = ", ".join(map(lambda i: i.name, self.indices)) - return (f"Tensor({self.tensor.name}) " f"Indices({indices_str})") - - def _emit_expression( - self, - expr_to_opnd: Dict[IndexExpr, lang.OperandDef], - expr_to_info: _ExprInfoDict, - ) -> lang.ScalarExpression: - """Emits a linalg dialect TensorUse expression for the tensor access.""" - assert self in expr_to_opnd - dims = _mlir_dimensions_from_index_vars(self.indices) - return lang.TensorUse(expr_to_opnd[self], dims) - - def _visit(self, - func: _ExprVisitor, - args, - *, - leaf_checker: _SubtreeLeafChecker = None) -> None: - if leaf_checker: - assert leaf_checker(self, *args) - func(self, *args) - - def dtype(self) -> DType: - return self.tensor.dtype + + tensor: Tensor + indices: Tuple[IndexVar, ...] + + def __post_init__(self) -> None: + """Verifies the tensor and indices for a tensor access. 
+ + Raises: + ValueError: If indices is not a list of IndexVar or the len of indices + doesn't equal to the rank of the tensor. + """ + if not isinstance(self.indices, tuple) or not _all_instance_of( + self.indices, IndexVar + ): + raise ValueError(f"Indices contain non IndexVar: {str(self.indices)}.") + if self.tensor.order != len(self.indices): + raise ValueError( + "Invalid indices for rank: " + f"str{self.tensor.order} != len({str(self.indices)})." + ) + + def __repr__(self) -> str: + # The Tensor __repr__ method evaluates the pending assignment to the tensor. + # We want to define the __repr__ method here to avoid such evaluation of the + # tensor assignment. + indices_str = ", ".join(map(lambda i: i.name, self.indices)) + return f"Tensor({self.tensor.name}) " f"Indices({indices_str})" + + def _emit_expression( + self, + expr_to_opnd: Dict[IndexExpr, lang.OperandDef], + expr_to_info: _ExprInfoDict, + ) -> lang.ScalarExpression: + """Emits a linalg dialect TensorUse expression for the tensor access.""" + assert self in expr_to_opnd + dims = _mlir_dimensions_from_index_vars(self.indices) + return lang.TensorUse(expr_to_opnd[self], dims) + + def _visit( + self, func: _ExprVisitor, args, *, leaf_checker: _SubtreeLeafChecker = None + ) -> None: + if leaf_checker: + assert leaf_checker(self, *args) + func(self, *args) + + def dtype(self) -> DType: + return self.tensor.dtype def _gather_input_accesses_index_vars( expr: IndexExpr, input_accesses: List[Access], ) -> None: - """Collects Access nodes.""" - if isinstance(expr, Access) and expr not in input_accesses: - input_accesses.append(expr) + """Collects Access nodes.""" + if isinstance(expr, Access) and expr not in input_accesses: + input_accesses.append(expr) def _op_ceil(__a: Any) -> Any: - """A _UnaryOp object for operation ceil.""" - pass + """A _UnaryOp object for operation ceil.""" + pass def _op_floor(__a: Any) -> Any: - """A _UnaryOp object for operation floor.""" - pass + """A _UnaryOp object for operation floor.""" + pass def _op_unary_to_callable(op: _UnaryOp) -> lang.UnaryFnType: - """Returns the linalg dialect function object for the given operation.""" - op_to_callable = { - operator.abs: lang.UnaryFn.abs, - operator.neg: lang.UnaryFn.negf, - _op_ceil: lang.UnaryFn.ceil, - _op_floor: lang.UnaryFn.floor, - } - return op_to_callable[op] + """Returns the linalg dialect function object for the given operation.""" + op_to_callable = { + operator.abs: lang.UnaryFn.abs, + operator.neg: lang.UnaryFn.negf, + _op_ceil: lang.UnaryFn.ceil, + _op_floor: lang.UnaryFn.floor, + } + return op_to_callable[op] @dataclasses.dataclass(frozen=True) class _UnaryExpr(IndexExpr): - """The representation for a Unary operation. - - Attributes: - op: A _UnaryOp representing the operation. - a: An IndexExpr representing the operand for the operation. - """ - op: _BinaryOp - a: IndexExpr - - def __post_init__(self) -> None: - """Verifies that the operand being added is an IndexExpr.""" - assert isinstance(self.a, IndexExpr) - - def _emit_expression( - self, - expr_to_opnd: Dict[IndexExpr, lang.OperandDef], - expr_to_info: _ExprInfoDict, - ) -> lang.ScalarExpression: - """Emits the expression tree and returns the expression.""" - # The current expression node is an internal node of the structured op. - if self not in expr_to_opnd: - a = self.a._emit_expression(expr_to_opnd, expr_to_info) - return _op_unary_to_callable(self.op)(a) - - # The current expression is a leaf node of the structured op. 
That is, it is - # a temporary tensor generated by its child structured op. - op_info = expr_to_info[self].structop_info - assert op_info is not None - dims = _mlir_dimensions_from_index_vars(op_info.dst_indices) - return lang.TensorUse(expr_to_opnd[self], dims) - - def _visit(self, - func: _ExprVisitor, - args, - *, - leaf_checker: _SubtreeLeafChecker = None) -> None: - """A post-order visitor.""" - if leaf_checker is None or not leaf_checker(self, *args): - self.a._visit(func, args, leaf_checker=leaf_checker) - func(self, *args) - - def dtype(self) -> DType: - """Returns the data type of the operation.""" - return self.a.dtype() + """The representation for a Unary operation. + + Attributes: + op: A _UnaryOp representing the operation. + a: An IndexExpr representing the operand for the operation. + """ + + op: _BinaryOp + a: IndexExpr + + def __post_init__(self) -> None: + """Verifies that the operand being added is an IndexExpr.""" + assert isinstance(self.a, IndexExpr) + + def _emit_expression( + self, + expr_to_opnd: Dict[IndexExpr, lang.OperandDef], + expr_to_info: _ExprInfoDict, + ) -> lang.ScalarExpression: + """Emits the expression tree and returns the expression.""" + # The current expression node is an internal node of the structured op. + if self not in expr_to_opnd: + a = self.a._emit_expression(expr_to_opnd, expr_to_info) + return _op_unary_to_callable(self.op)(a) + + # The current expression is a leaf node of the structured op. That is, it is + # a temporary tensor generated by its child structured op. + op_info = expr_to_info[self].structop_info + assert op_info is not None + dims = _mlir_dimensions_from_index_vars(op_info.dst_indices) + return lang.TensorUse(expr_to_opnd[self], dims) + + def _visit( + self, func: _ExprVisitor, args, *, leaf_checker: _SubtreeLeafChecker = None + ) -> None: + """A post-order visitor.""" + if leaf_checker is None or not leaf_checker(self, *args): + self.a._visit(func, args, leaf_checker=leaf_checker) + func(self, *args) + + def dtype(self) -> DType: + """Returns the data type of the operation.""" + return self.a.dtype() def _op_to_callable(op: _BinaryOp) -> lang.BinaryFnType: - """Returns the linalg dialect function object for the given operation.""" - op_to_callable = { - operator.add: lang.BinaryFn.add, - operator.sub: lang.BinaryFn.sub, - operator.mul: lang.BinaryFn.mul, - } - return op_to_callable[op] + """Returns the linalg dialect function object for the given operation.""" + op_to_callable = { + operator.add: lang.BinaryFn.add, + operator.sub: lang.BinaryFn.sub, + operator.mul: lang.BinaryFn.mul, + } + return op_to_callable[op] + @dataclasses.dataclass(frozen=True) class _BinaryExpr(IndexExpr): - """The representation for a binary operation. - - Attributes: - op: A _BinaryOp representing the binary operation. - a: An IndexExpr representing the first operand of the operation. - b: An IndexExpr representing the second operand of the operation. - """ - op: _BinaryOp - a: IndexExpr - b: IndexExpr - - def __post_init__(self) -> None: - """Verifies that the operands being added are IndexExpr.""" - assert isinstance(self.a, IndexExpr) and isinstance(self.b, IndexExpr) - - def _emit_expression( - self, - expr_to_opnd: Dict[IndexExpr, lang.OperandDef], - expr_to_info: _ExprInfoDict, - ) -> lang.ScalarExpression: - """Emits the expression tree and returns the expression.""" - # The current expression node is an internal node of the structured op. 
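For orientation, the arithmetic operators on index expressions (overloads defined elsewhere in this file) are what build this expression tree, and _op_to_callable / _op_unary_to_callable later map their ops to linalg functions; a rough, assumed equivalence using the tensors from the earlier sketches:

    import operator

    expr = A[i, j] * B[i, j] + C[i, j]
    # Approximately the explicit tree
    #   _BinaryExpr(operator.add,
    #               _BinaryExpr(operator.mul, A[i, j], B[i, j]),
    #               C[i, j])
    # with operator.mul / operator.add resolved to lang.BinaryFn.mul / lang.BinaryFn.add.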
- if self not in expr_to_opnd: - a = self.a._emit_expression(expr_to_opnd, expr_to_info) - b = self.b._emit_expression(expr_to_opnd, expr_to_info) - return _op_to_callable(self.op)(a, b) - - # The current expression is a leaf node of the structured op. That is, it is - # a temporary tensor generated by its child structured op. - op_info = expr_to_info[self].structop_info - assert op_info is not None - dims = _mlir_dimensions_from_index_vars(op_info.dst_indices) - return lang.TensorUse(expr_to_opnd[self], dims) - - def _visit(self, - func: _ExprVisitor, - args, - *, - leaf_checker: _SubtreeLeafChecker = None) -> None: - """A post-order visitor.""" - if leaf_checker is None or not leaf_checker(self, *args): - self.a._visit(func, args, leaf_checker=leaf_checker) - self.b._visit(func, args, leaf_checker=leaf_checker) - func(self, *args) - - def dtype(self) -> DType: - """Returns the data type of the binary operation.""" - return self.a.dtype() + """The representation for a binary operation. + + Attributes: + op: A _BinaryOp representing the binary operation. + a: An IndexExpr representing the first operand of the operation. + b: An IndexExpr representing the second operand of the operation. + """ + + op: _BinaryOp + a: IndexExpr + b: IndexExpr + + def __post_init__(self) -> None: + """Verifies that the operands being added are IndexExpr.""" + assert isinstance(self.a, IndexExpr) and isinstance(self.b, IndexExpr) + + def _emit_expression( + self, + expr_to_opnd: Dict[IndexExpr, lang.OperandDef], + expr_to_info: _ExprInfoDict, + ) -> lang.ScalarExpression: + """Emits the expression tree and returns the expression.""" + # The current expression node is an internal node of the structured op. + if self not in expr_to_opnd: + a = self.a._emit_expression(expr_to_opnd, expr_to_info) + b = self.b._emit_expression(expr_to_opnd, expr_to_info) + return _op_to_callable(self.op)(a, b) + + # The current expression is a leaf node of the structured op. That is, it is + # a temporary tensor generated by its child structured op. + op_info = expr_to_info[self].structop_info + assert op_info is not None + dims = _mlir_dimensions_from_index_vars(op_info.dst_indices) + return lang.TensorUse(expr_to_opnd[self], dims) + + def _visit( + self, func: _ExprVisitor, args, *, leaf_checker: _SubtreeLeafChecker = None + ) -> None: + """A post-order visitor.""" + if leaf_checker is None or not leaf_checker(self, *args): + self.a._visit(func, args, leaf_checker=leaf_checker) + self.b._visit(func, args, leaf_checker=leaf_checker) + func(self, *args) + + def dtype(self) -> DType: + """Returns the data type of the binary operation.""" + return self.a.dtype() def _validate_and_collect_dim_info( @@ -1822,105 +1901,104 @@ def _validate_and_collect_dim_info( dim_infos: Tuple[_DimInfo, ...], expr: _BinaryExpr, ) -> None: - """Validates and collects the dimension information for an index notation. - - Validates (indices, dim_infos) against the information collected from other - source operands and is represented by index_to_dim_info. In particular, we - ensure that each IndexVar corresponds to only one dimension size. We also - aggregate the new information represented in (indices, dim_infos) to - index_to_dim_info. - - Args: - index_to_dim: A dictionary of (IndexVar, _DimInfo) collected from the - previous operands. - indices: The IndexVars to be validated. - dim_infos: The dimension information for the IndexVars to be validated. - expr: The binary expression where (indices, dim_infos) is used. 
- - Raises: - ValueError if there is any problem in the IndexVars or dimensional values. - """ - assert len(indices) == len(dim_infos) - for i, d in zip(indices, dim_infos): - if i not in index_to_dim_info: - index_to_dim_info[i] = d - else: - dim = index_to_dim_info[i].dim - if dim == -1 or d.dim == -1: - dim = dim if dim != -1 else d.dim - elif dim != d.dim: - raise ValueError(f"Inconsistent source dimension for {i}: " - f"{d.dim} vs {dim}") - mode_format = _mode_format_estimator(expr.op)( - index_to_dim_info[i].mode_format, d.mode_format) - index_to_dim_info[i] = _DimInfo(d.dim, mode_format) + """Validates and collects the dimension information for an index notation. + + Validates (indices, dim_infos) against the information collected from other + source operands and is represented by index_to_dim_info. In particular, we + ensure that each IndexVar corresponds to only one dimension size. We also + aggregate the new information represented in (indices, dim_infos) to + index_to_dim_info. + + Args: + index_to_dim: A dictionary of (IndexVar, _DimInfo) collected from the + previous operands. + indices: The IndexVars to be validated. + dim_infos: The dimension information for the IndexVars to be validated. + expr: The binary expression where (indices, dim_infos) is used. + + Raises: + ValueError if there is any problem in the IndexVars or dimensional values. + """ + assert len(indices) == len(dim_infos) + for i, d in zip(indices, dim_infos): + if i not in index_to_dim_info: + index_to_dim_info[i] = d + else: + dim = index_to_dim_info[i].dim + if dim == -1 or d.dim == -1: + dim = dim if dim != -1 else d.dim + elif dim != d.dim: + raise ValueError( + f"Inconsistent source dimension for {i}: " f"{d.dim} vs {dim}" + ) + mode_format = _mode_format_estimator(expr.op)( + index_to_dim_info[i].mode_format, d.mode_format + ) + index_to_dim_info[i] = _DimInfo(d.dim, mode_format) def _validate_and_collect_expr_info( expr: IndexExpr, expr_to_info: _ExprInfoDict, ) -> None: - """Validates dimension information and constructs _ExprInfo. - - Validates that dimensional values for the same IndexVar are the same. Collects - a list of IndexVar used by the expression and their corresponding dimensional - values. Constructs an _ExprInfo object to record the information for the - IndexExpr. - - This routine is passed to the post-order visitor as an _ExprVisitor object. - - Args: - expr: The IndexExpr being validated. - expr_to_info: The dictionary of (IndexExpr, _ExprInfo) for recording the - expression information. - - Raises: - ValueError if there is any problem in the IndexVars or dimensional values. - """ - # Objects of class Access can be shared by different expressions. Avoid - # processing Access objects multiple times by skipping the processing if expr - # is already in the dictionary. - if expr in expr_to_info: - return - - if isinstance(expr, IndexVar): - src_indices = expr, # A tuple with one element. - dim_infos = _get_dummy_dim_info(), # A tuple with one element. - elif isinstance(expr, Access): - src_indices = expr.indices - src_dims = tuple(expr.tensor.shape) - if expr.tensor.format is None: - # Treat each dimension of a dense tensor as DENSE for the purpose of - # calculating temporary tensor storage format. - mode_formats = tuple([ModeFormat.DENSE] * len(src_dims)) + """Validates dimension information and constructs _ExprInfo. + + Validates that dimensional values for the same IndexVar are the same. Collects + a list of IndexVar used by the expression and their corresponding dimensional + values. 
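A hypothetical failure case that the dimension validation above is meant to reject, reusing the assumed import and helpers from the earlier sketches:

    import numpy as np
    from tools import mlir_pytaco as pytaco

    i, j = pytaco.get_index_vars(2)
    A = pytaco.Tensor.from_array(np.ones((2, 3)))
    B = pytaco.Tensor.from_array(np.ones((2, 4)))  # j would need to be both 3 and 4
    C = pytaco.Tensor.from_array(np.zeros((2, 3)))
    C[i, j] = A[i, j] + B[i, j]
    C.evaluate()  # expected to raise ValueError about an inconsistent dimension for j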
Constructs an _ExprInfo object to record the information for the + IndexExpr. + + This routine is passed to the post-order visitor as an _ExprVisitor object. + + Args: + expr: The IndexExpr being validated. + expr_to_info: The dictionary of (IndexExpr, _ExprInfo) for recording the + expression information. + + Raises: + ValueError if there is any problem in the IndexVars or dimensional values. + """ + # Objects of class Access can be shared by different expressions. Avoid + # processing Access objects multiple times by skipping the processing if expr + # is already in the dictionary. + if expr in expr_to_info: + return + + if isinstance(expr, IndexVar): + src_indices = (expr,) # A tuple with one element. + dim_infos = (_get_dummy_dim_info(),) # A tuple with one element. + elif isinstance(expr, Access): + src_indices = expr.indices + src_dims = tuple(expr.tensor.shape) + if expr.tensor.format is None: + # Treat each dimension of a dense tensor as DENSE for the purpose of + # calculating temporary tensor storage format. + mode_formats = tuple([ModeFormat.DENSE] * len(src_dims)) + else: + mode_formats = tuple(expr.tensor.format.format_pack.formats) + assert len(src_dims) == len(mode_formats) + dim_infos = tuple([_DimInfo(d, m) for d, m in zip(src_dims, mode_formats)]) + elif isinstance(expr, _UnaryExpr): + a_info = expr_to_info[expr.a] + index_to_dim_info = {i: d for i, d in zip(a_info.src_indices, a_info.dim_infos)} + # Here we rely on the fact that dictionaries keep the insertion order for + # keys and values. + src_indices = tuple(index_to_dim_info.keys()) + dim_infos = tuple(index_to_dim_info.values()) else: - mode_formats = tuple(expr.tensor.format.format_pack.formats) - assert len(src_dims) == len(mode_formats) - dim_infos = tuple([_DimInfo(d, m) for d, m in zip(src_dims, mode_formats)]) - elif isinstance(expr, _UnaryExpr): - a_info = expr_to_info[expr.a] - index_to_dim_info = { - i: d for i, d in zip(a_info.src_indices, a_info.dim_infos) - } - # Here we rely on the fact that dictionaries keep the insertion order for - # keys and values. - src_indices = tuple(index_to_dim_info.keys()) - dim_infos = tuple(index_to_dim_info.values()) - else: - assert isinstance(expr, _BinaryExpr) - a_info = expr_to_info[expr.a] - index_to_dim_info = { - i: d for i, d in zip(a_info.src_indices, a_info.dim_infos) - } - b_info = expr_to_info[expr.b] - _validate_and_collect_dim_info(index_to_dim_info, b_info.src_indices, - b_info.dim_infos, expr) - # Here we rely on the fact that dictionaries keep the insertion order for - # keys and values. - src_indices = tuple(index_to_dim_info.keys()) - dim_infos = tuple(index_to_dim_info.values()) + assert isinstance(expr, _BinaryExpr) + a_info = expr_to_info[expr.a] + index_to_dim_info = {i: d for i, d in zip(a_info.src_indices, a_info.dim_infos)} + b_info = expr_to_info[expr.b] + _validate_and_collect_dim_info( + index_to_dim_info, b_info.src_indices, b_info.dim_infos, expr + ) + # Here we rely on the fact that dictionaries keep the insertion order for + # keys and values. + src_indices = tuple(index_to_dim_info.keys()) + dim_infos = tuple(index_to_dim_info.values()) - expr_to_info[expr] = _ExprInfo(src_indices, dim_infos) + expr_to_info[expr] = _ExprInfo(src_indices, dim_infos) def _mark_structured_op_root( @@ -1928,90 +2006,92 @@ def _mark_structured_op_root( reduce_index: IndexVar, expr_to_info: _ExprInfoDict, ) -> None: - """Identifies the root expression for a structured op in the linalg dialect. 
- - An linalg structured op can only perform reduction on the whole expression. - For a TACO tensor algebra expression, the reduction on an IndexVar is done at - the smallest expression that contains all the uses of the IndexVar. If such an - expression is only part of the whole expression, we need to split this - sub-expression tree out from its parent and implement the sub-expression as a - structured op. - - This routine identifies the root expression node for performing a reduction on - the given IndexVar. If the reduction of the given IndexVar should be performed - on expression X, then the IndexVar is added to expr_to_info[X].reduce_indices - - Args: - expr: The root IndexExpr for the tensor algebra expression. - reduce_index: The IndexVar which we want to find out the proper expression - to perform a reduction. - expr_to_info: The dictionary to look up _ExprInfo for IndexExpr. - - Raises: - ValueError: If the expression is not proper or not supported. - """ - expr_info = expr_to_info[expr] - if isinstance(expr, Access): - # Handle simple reduction expression in the format of A[i] = B[i, j]. - if reduce_index in expr_info.src_indices: - expr_info.reduce_indices.add(reduce_index) - return - elif isinstance(expr, IndexVar): - # A[i] = B[i] + j is not allowed. - raise ValueError(f"IndexVar is not part of the iteration domain: {expr}.") - - assert (isinstance(expr, _BinaryExpr)) - a_info = expr_to_info[expr.a] - b_info = expr_to_info[expr.b] - - if reduce_index in a_info.src_indices and reduce_index in b_info.src_indices: - expr_info.reduce_indices.add(reduce_index) - return - - if reduce_index in a_info.src_indices: - _mark_structured_op_root(expr.a, reduce_index, expr_to_info) - elif reduce_index in b_info.src_indices: - _mark_structured_op_root(expr.b, reduce_index, expr_to_info) - else: - assert False, "Unreachable path" + """Identifies the root expression for a structured op in the linalg dialect. + An linalg structured op can only perform reduction on the whole expression. + For a TACO tensor algebra expression, the reduction on an IndexVar is done at + the smallest expression that contains all the uses of the IndexVar. If such an + expression is only part of the whole expression, we need to split this + sub-expression tree out from its parent and implement the sub-expression as a + structured op. -def _accumulate_reduce_indices( - expr: IndexExpr, - expr_to_info: _ExprInfoDict, -) -> None: - """Propagates reduction indices from child expressions to parent expressions. + This routine identifies the root expression node for performing a reduction on + the given IndexVar. If the reduction of the given IndexVar should be performed + on expression X, then the IndexVar is added to expr_to_info[X].reduce_indices - This routine is passed to the post-order visitor as an _ExprVisitor object. + Args: + expr: The root IndexExpr for the tensor algebra expression. + reduce_index: The IndexVar which we want to find out the proper expression + to perform a reduction. + expr_to_info: The dictionary to look up _ExprInfo for IndexExpr. - Args: - expr: The IndexExpr being visited. - expr_to_info: The dictionary of (IndexExpr, _ExprInfo) for recording the - expression information. - """ - assert expr in expr_to_info - expr_info = expr_to_info[expr] + Raises: + ValueError: If the expression is not proper or not supported. + """ + expr_info = expr_to_info[expr] + if isinstance(expr, Access): + # Handle simple reduction expression in the format of A[i] = B[i, j]. 
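A sketch of the reduction placement this routine performs, under the same assumed imports:

    import numpy as np
    from tools import mlir_pytaco as pytaco

    i, j = pytaco.get_index_vars(2)
    A = pytaco.Tensor.from_array(np.ones((2, 3)))
    x = pytaco.Tensor.from_array(np.ones((3,)))
    y = pytaco.Tensor.from_array(np.zeros((2,)))

    # j appears only on the right-hand side, so it is a reduction index. The
    # smallest subtree containing every use of j is A[i, j] * x[j], so the
    # reduction over j is rooted at that product.
    y[i] = A[i, j] * x[j]
    print(y.to_array())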
+ if reduce_index in expr_info.src_indices: + expr_info.reduce_indices.add(reduce_index) + return + elif isinstance(expr, IndexVar): + # A[i] = B[i] + j is not allowed. + raise ValueError(f"IndexVar is not part of the iteration domain: {expr}.") - if isinstance(expr, _BinaryExpr): + assert isinstance(expr, _BinaryExpr) a_info = expr_to_info[expr.a] b_info = expr_to_info[expr.b] - expr_info.acc_reduce_indices = ( - a_info.acc_reduce_indices | b_info.acc_reduce_indices - | expr_info.reduce_indices) - elif isinstance(expr, _UnaryExpr): - a_info = expr_to_info[expr.a] - expr_info.acc_reduce_indices = ( - a_info.acc_reduce_indices | expr_info.reduce_indices) - elif isinstance(expr, IndexVar): - # If an IndexVar is reducing itself, it means the IndexVar is outside the - # iteration domain. This usage is now allowed and we should emit an error - # before reaching here. - assert not expr_info.reduce_indices - else: - assert isinstance(expr, Access) - # Handle simple reduction expression in the format of A[i] = B[i, j]. - expr_info.acc_reduce_indices = expr_info.reduce_indices + if reduce_index in a_info.src_indices and reduce_index in b_info.src_indices: + expr_info.reduce_indices.add(reduce_index) + return + + if reduce_index in a_info.src_indices: + _mark_structured_op_root(expr.a, reduce_index, expr_to_info) + elif reduce_index in b_info.src_indices: + _mark_structured_op_root(expr.b, reduce_index, expr_to_info) + else: + assert False, "Unreachable path" + + +def _accumulate_reduce_indices( + expr: IndexExpr, + expr_to_info: _ExprInfoDict, +) -> None: + """Propagates reduction indices from child expressions to parent expressions. + + This routine is passed to the post-order visitor as an _ExprVisitor object. + + Args: + expr: The IndexExpr being visited. + expr_to_info: The dictionary of (IndexExpr, _ExprInfo) for recording the + expression information. + """ + assert expr in expr_to_info + expr_info = expr_to_info[expr] + + if isinstance(expr, _BinaryExpr): + a_info = expr_to_info[expr.a] + b_info = expr_to_info[expr.b] + expr_info.acc_reduce_indices = ( + a_info.acc_reduce_indices + | b_info.acc_reduce_indices + | expr_info.reduce_indices + ) + elif isinstance(expr, _UnaryExpr): + a_info = expr_to_info[expr.a] + expr_info.acc_reduce_indices = ( + a_info.acc_reduce_indices | expr_info.reduce_indices + ) + elif isinstance(expr, IndexVar): + # If an IndexVar is reducing itself, it means the IndexVar is outside the + # iteration domain. This usage is now allowed and we should emit an error + # before reaching here. + assert not expr_info.reduce_indices + else: + assert isinstance(expr, Access) + # Handle simple reduction expression in the format of A[i] = B[i, j]. + expr_info.acc_reduce_indices = expr_info.reduce_indices def _gather_structured_op( @@ -2019,42 +2099,42 @@ def _gather_structured_op( expr_to_info: _ExprInfoDict, structop_roots: List[IndexExpr], ) -> None: - """Adds structured op root expression information to structop_roots. - - This routine is passed to the post-order visitor as an _ExprVisitor object. - - Args: - expr: The IndexExpr being visited. - expr_to_info: The dictionary to look up _ExprInfo for IndexExpr. - structop_roots: The resulting list of IndexExpr that are the roots for - linalg structured ops. - """ - if not expr_to_info[expr].reduce_indices: - return - - # If the expression is the root for reducing some indices, collect the indices - # and dimensions for the reduction result. 
- dst_indices = [] - dst_dims = [] - mode_fmts = [] - for i, d in zip(expr_to_info[expr].src_indices, expr_to_info[expr].dim_infos): - if i not in expr_to_info[expr].acc_reduce_indices: - dst_indices.append(i) - dst_dims.append(d.dim) - mode_fmts.append(d.mode_format) - - # Add the information to the dictionary. - op_info = _StructOpInfo( - tuple(dst_indices), - tuple(dst_dims), - expr.dtype(), - f"temp{len(structop_roots)}", - _make_format(mode_fmts), - ) - expr_to_info[expr].structop_info = op_info - - # Add the expression to the list of structured op roots. - structop_roots.append(expr) + """Adds structured op root expression information to structop_roots. + + This routine is passed to the post-order visitor as an _ExprVisitor object. + + Args: + expr: The IndexExpr being visited. + expr_to_info: The dictionary to look up _ExprInfo for IndexExpr. + structop_roots: The resulting list of IndexExpr that are the roots for + linalg structured ops. + """ + if not expr_to_info[expr].reduce_indices: + return + + # If the expression is the root for reducing some indices, collect the indices + # and dimensions for the reduction result. + dst_indices = [] + dst_dims = [] + mode_fmts = [] + for i, d in zip(expr_to_info[expr].src_indices, expr_to_info[expr].dim_infos): + if i not in expr_to_info[expr].acc_reduce_indices: + dst_indices.append(i) + dst_dims.append(d.dim) + mode_fmts.append(d.mode_format) + + # Add the information to the dictionary. + op_info = _StructOpInfo( + tuple(dst_indices), + tuple(dst_dims), + expr.dtype(), + f"temp{len(structop_roots)}", + _make_format(mode_fmts), + ) + expr_to_info[expr].structop_info = op_info + + # Add the expression to the list of structured op roots. + structop_roots.append(expr) def _is_structured_op_leaf( @@ -2063,29 +2143,31 @@ def _is_structured_op_leaf( expr_to_info: _ExprInfoDict, *unused_args, ) -> bool: - """Returns true iff the expression is a leaf node for a structured op. + """Returns true iff the expression is a leaf node for a structured op. - The root of a structured op is a leaf of its parent structured op that uses - its result. An expression node is a leaf node for the current structured op if - it is an Access node or the root for a structured op that is not the current - structured op. + The root of a structured op is a leaf of its parent structured op that uses + its result. An expression node is a leaf node for the current structured op if + it is an Access node or the root for a structured op that is not the current + structured op. - This routine is passed to the post-order visitor as a _SubtreeLeafChecker - object. Because the post-order visitor pass the same parameters to both - _SubtreeLeafChecker and _ExprVisitor, this routine may received unused - parameters. + This routine is passed to the post-order visitor as a _SubtreeLeafChecker + object. Because the post-order visitor pass the same parameters to both + _SubtreeLeafChecker and _ExprVisitor, this routine may received unused + parameters. - Args: - expr: The IndexExpr being visited. - root: The root of the current structured op. - expr_to_info: The dictionary to look up _ExprInfo for IndexExpr. + Args: + expr: The IndexExpr being visited. + root: The root of the current structured op. + expr_to_info: The dictionary to look up _ExprInfo for IndexExpr. - Returns: - True if the current IndexExpr is a leaf for the current structured op. 
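When a reduction roots at a proper subtree, that subtree is materialized as a temporary tensor (named temp0, temp1, ... as above) and consumed as a leaf by the enclosing expression. A hypothetical example of such splitting, under the same assumed setup:

    import numpy as np
    from tools import mlir_pytaco as pytaco

    i, j = pytaco.get_index_vars(2)
    A = pytaco.Tensor.from_array(np.ones((2, 3)))
    x = pytaco.Tensor.from_array(np.ones((3,)))
    b = pytaco.Tensor.from_array(np.ones((2,)))
    y = pytaco.Tensor.from_array(np.zeros((2,)))

    # The reduction over j roots at A[i, j] * x[j]; that product is expected to be
    # emitted as its own structured op producing a rank-1 temporary, which the
    # outer addition with b[i] then reads as an input.
    y[i] = A[i, j] * x[j] + b[i]
    y.evaluate()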
- """ - return (expr != root and - expr_to_info[expr].structop_info is not None) or isinstance( - expr, Access) or isinstance(expr, IndexVar) + Returns: + True if the current IndexExpr is a leaf for the current structured op. + """ + return ( + (expr != root and expr_to_info[expr].structop_info is not None) + or isinstance(expr, Access) + or isinstance(expr, IndexVar) + ) def _gather_structured_op_input( @@ -2094,26 +2176,28 @@ def _gather_structured_op_input( expr_to_info: _ExprInfoDict, structop_inputs: List[IndexExpr], ) -> None: - """Adds the IndexExpr to structop_inputs if it is an input. + """Adds the IndexExpr to structop_inputs if it is an input. - If the current IndexExpr is an input for the current structured op, adds it to - structop_inputs. The current IndexExpr is an input if it is an Access node or - if it is the root for a structured op that is not the current structured op. + If the current IndexExpr is an input for the current structured op, adds it to + structop_inputs. The current IndexExpr is an input if it is an Access node or + if it is the root for a structured op that is not the current structured op. - This routine is passed to the post-order visitor as an _ExprVisitor object. + This routine is passed to the post-order visitor as an _ExprVisitor object. - Args: - expr: The IndexExpr being visited. - root: The root of the current structured op. - expr_to_info: The dictionary to look up _ExprInfo for IndexExpr. - structop_inputs: The resulting list of IndexExpr that provide input to the - current structured op. - """ - if ((expr != root or isinstance(expr, Access)) and - expr not in structop_inputs) and (isinstance(expr, Access) or - (expr in expr_to_info and - expr_to_info[expr].structop_info)): - structop_inputs.append(expr) + Args: + expr: The IndexExpr being visited. + root: The root of the current structured op. + expr_to_info: The dictionary to look up _ExprInfo for IndexExpr. + structop_inputs: The resulting list of IndexExpr that provide input to the + current structured op. + """ + if ( + (expr != root or isinstance(expr, Access)) and expr not in structop_inputs + ) and ( + isinstance(expr, Access) + or (expr in expr_to_info and expr_to_info[expr].structop_info) + ): + structop_inputs.append(expr) def _emit_structured_op_input( @@ -2121,35 +2205,35 @@ def _emit_structured_op_input( expr_to_info: _ExprInfoDict, op_def: lang.LinalgOpDef, ) -> lang.OperandDef: - """Emits OperandDef in the linalg dialect for the input IndexExpr. - - Args: - expr: The input IndexExpr for the current structured op. - expr_to_info: The dictionary to look up _ExprInfo for IndexExpr. - op_def: The linalg operation for the current structured op. - - Returns: - An OperandDef in the linalg dialect for the input IndexExpr. - """ - op_info = expr_to_info[expr].structop_info - if op_info and not isinstance(expr, Access): - # The input is a temporary tensor produced by another structured op. - indices = op_info.dst_indices - name = op_info.dst_name - else: - # The input is a user provided tensor. - assert isinstance(expr, Access) - indices = expr.indices - name = expr.tensor.name - - dim_sym = _mlir_symbols_from_index_vars(indices) - opnd = lang.OperandDef(lang.OperandKind.INPUT_TENSOR, lang.T, dim_sym) - op_def.add_operand(name, opnd) - return opnd + """Emits OperandDef in the linalg dialect for the input IndexExpr. + + Args: + expr: The input IndexExpr for the current structured op. + expr_to_info: The dictionary to look up _ExprInfo for IndexExpr. 
+ op_def: The linalg operation for the current structured op. + + Returns: + An OperandDef in the linalg dialect for the input IndexExpr. + """ + op_info = expr_to_info[expr].structop_info + if op_info and not isinstance(expr, Access): + # The input is a temporary tensor produced by another structured op. + indices = op_info.dst_indices + name = op_info.dst_name + else: + # The input is a user provided tensor. + assert isinstance(expr, Access) + indices = expr.indices + name = expr.tensor.name + + dim_sym = _mlir_symbols_from_index_vars(indices) + opnd = lang.OperandDef(lang.OperandKind.INPUT_TENSOR, lang.T, dim_sym) + op_def.add_operand(name, opnd) + return opnd def _check_and_build_unary(a: Access, op: _UnaryOp) -> "_UnaryExpr": - """Build a unary operation ceil. + """Build a unary operation ceil. Args: a: The operand, which could be any Python object from user inputs. @@ -2161,13 +2245,13 @@ def _check_and_build_unary(a: Access, op: _UnaryOp) -> "_UnaryExpr": Raises: ValueError: If a is not an IndexExpr. """ - if not isinstance(a, Access): - raise ValueError(f"Expected an Access Operand: {a}") - return a._build_unary_expr(op) + if not isinstance(a, Access): + raise ValueError(f"Expected an Access Operand: {a}") + return a._build_unary_expr(op) def ceil(a: Access) -> "_UnaryExpr": - """Defines the operation ceil. + """Defines the operation ceil. Args: a: The operand, which could be any Python object from user inputs. @@ -2178,11 +2262,11 @@ def ceil(a: Access) -> "_UnaryExpr": Raises: ValueError: If a is not an IndexExpr. """ - return _check_and_build_unary(a, _op_ceil) + return _check_and_build_unary(a, _op_ceil) def floor(a: Access) -> "_UnaryExpr": - """Defines the operation floor. + """Defines the operation floor. Args: a: The operand, which could be any Python object from user inputs. @@ -2193,4 +2277,4 @@ def floor(a: Access) -> "_UnaryExpr": Raises: ValueError: If a is not an IndexExpr. """ - return _check_and_build_unary(a, _op_floor) + return _check_and_build_unary(a, _op_floor) diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_io.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_io.py index e6a7d8e..785401c 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_io.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_io.py @@ -31,50 +31,52 @@ _MTX_FILENAME_SUFFIX = ".mtx" _TNS_FILENAME_SUFFIX = ".tns" -def read(filename: str, fmt: Format, - dtype: DType = DType(Type.FLOAT32)) -> Tensor: - """Inputs a tensor from a given file. - - The name suffix of the file specifies the format of the input tensor. We - currently only support .mtx format for support sparse tensors. - - Args: - filename: A string input filename. - fmt: The storage format of the tensor. - dtype: The data type, default to float32. - - Raises: - ValueError: If filename doesn't end with .mtx or .tns, or fmt is not an - instance of Format or fmt is not a sparse tensor. - """ - if (not isinstance(filename, str) or - (not filename.endswith(_MTX_FILENAME_SUFFIX) and - not filename.endswith(_TNS_FILENAME_SUFFIX))): - raise ValueError("Expected string filename ends with " - f"{_MTX_FILENAME_SUFFIX} or {_TNS_FILENAME_SUFFIX}: " - f"{filename}.") - - return Tensor.from_file(filename, fmt, dtype) +def read(filename: str, fmt: Format, dtype: DType = DType(Type.FLOAT32)) -> Tensor: + """Inputs a tensor from a given file. + + The name suffix of the file specifies the format of the input tensor. 
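The ceil and floor helpers defined above are used in the index notation like any other element-wise operation; a short assumed sketch:

    import numpy as np
    from tools import mlir_pytaco as pytaco

    i, j = pytaco.get_index_vars(2)
    A = pytaco.Tensor.from_array(np.array([[0.3, 1.7], [2.2, 3.9]]))
    B = pytaco.Tensor.from_array(np.zeros((2, 2)))
    B[i, j] = pytaco.ceil(A[i, j])  # element-wise ceiling via the unary helper above
    print(B.to_array())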
We + currently only support .mtx format for support sparse tensors. + + Args: + filename: A string input filename. + fmt: The storage format of the tensor. + dtype: The data type, default to float32. + + Raises: + ValueError: If filename doesn't end with .mtx or .tns, or fmt is not an + instance of Format or fmt is not a sparse tensor. + """ + if not isinstance(filename, str) or ( + not filename.endswith(_MTX_FILENAME_SUFFIX) + and not filename.endswith(_TNS_FILENAME_SUFFIX) + ): + raise ValueError( + "Expected string filename ends with " + f"{_MTX_FILENAME_SUFFIX} or {_TNS_FILENAME_SUFFIX}: " + f"{filename}." + ) + + return Tensor.from_file(filename, fmt, dtype) def write(filename: str, tensor: Tensor) -> None: - """Outputs a tensor to a given file. - - The name suffix of the file specifies the format of the output. We currently - only support .tns format. - - Args: - filename: A string output filename. - tensor: The tensor to output. - - Raises: - ValueError: If filename doesn't end with .tns or tensor is not a Tensor. - """ - if (not isinstance(filename, str) or - not filename.endswith(_TNS_FILENAME_SUFFIX)): - raise ValueError("Expected string filename ends with" - f" {_TNS_FILENAME_SUFFIX}: {filename}.") - if not isinstance(tensor, Tensor): - raise ValueError(f"Expected a Tensor object: {tensor}.") - - tensor.to_file(filename) + """Outputs a tensor to a given file. + + The name suffix of the file specifies the format of the output. We currently + only support .tns format. + + Args: + filename: A string output filename. + tensor: The tensor to output. + + Raises: + ValueError: If filename doesn't end with .tns or tensor is not a Tensor. + """ + if not isinstance(filename, str) or not filename.endswith(_TNS_FILENAME_SUFFIX): + raise ValueError( + "Expected string filename ends with" f" {_TNS_FILENAME_SUFFIX}: {filename}." 
+ ) + if not isinstance(tensor, Tensor): + raise ValueError(f"Expected a Tensor object: {tensor}.") + + tensor.to_file(filename) diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py index 988c57b..1e1061b 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco_utils.py @@ -36,190 +36,234 @@ _ENTRY_NAME = "main" @functools.lru_cache() def _get_support_lib_name() -> str: - """Gets the string name for the supporting C shared library.""" - return os.getenv(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB) + """Gets the string name for the supporting C shared library.""" + return os.getenv(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB) @functools.lru_cache() def _get_sparse_compiler() -> mlir_sparse_compiler.SparseCompiler: - """Gets the MLIR sparse compiler with default setting.""" - return mlir_sparse_compiler.SparseCompiler( - options="", opt_level=_OPT_LEVEL, shared_libs=[_get_support_lib_name()]) + """Gets the MLIR sparse compiler with default setting.""" + return mlir_sparse_compiler.SparseCompiler( + options="", opt_level=_OPT_LEVEL, shared_libs=[_get_support_lib_name()] + ) def _record_support_funcs( - ty: np.dtype, to_func: _SupportFunc, from_func: _SupportFunc, - ty_to_funcs: Dict[np.dtype, Tuple[_SupportFunc, _SupportFunc]]) -> None: - """Records the two supporting functions for a given data type.""" - to_func.restype = ctypes.c_void_p - from_func.restype = ctypes.c_void_p - ty_to_funcs[ty] = (to_func, from_func) + ty: np.dtype, + to_func: _SupportFunc, + from_func: _SupportFunc, + ty_to_funcs: Dict[np.dtype, Tuple[_SupportFunc, _SupportFunc]], +) -> None: + """Records the two supporting functions for a given data type.""" + to_func.restype = ctypes.c_void_p + from_func.restype = ctypes.c_void_p + ty_to_funcs[ty] = (to_func, from_func) @functools.lru_cache() def _get_support_func_locator() -> _SupportFuncLocator: - """Constructs a function to locate the supporting functions for a data type. - - Loads the supporting C shared library with the needed routines. Constructs a - dictionary from the supported data types to the routines for the data types, - and then a function to look up the dictionary for a given data type. - - The name of the supporting C shared library is either provided by an - an environment variable or a default value. - - Returns: - The function to look up the supporting functions for a given data type. - - Raises: - OSError: If there is any problem in loading the shared library. - ValueError: If the shared library doesn't contain the needed routines. - """ - # This raises OSError exception if there is any problem in loading the shared - # library. 
- c_lib = ctypes.CDLL(_get_support_lib_name()) - - type_to_funcs = {} - try: - support_types = [(np.int8, c_lib.convertToMLIRSparseTensorI8, - c_lib.convertFromMLIRSparseTensorI8), - (np.int16, c_lib.convertToMLIRSparseTensorI16, - c_lib.convertFromMLIRSparseTensorI16), - (np.int32, c_lib.convertToMLIRSparseTensorI32, - c_lib.convertFromMLIRSparseTensorI32), - (np.int64, c_lib.convertToMLIRSparseTensorI64, - c_lib.convertFromMLIRSparseTensorI64), - (np.float16, c_lib.convertToMLIRSparseTensorF16, - c_lib.convertFromMLIRSparseTensorF16), - (np.float32, c_lib.convertToMLIRSparseTensorF32, - c_lib.convertFromMLIRSparseTensorF32), - (np.float64, c_lib.convertToMLIRSparseTensorF64, - c_lib.convertFromMLIRSparseTensorF64), - (np.complex64, c_lib.convertToMLIRSparseTensorC32, - c_lib.convertFromMLIRSparseTensorC32), - (np.complex128, c_lib.convertToMLIRSparseTensorC64, - c_lib.convertFromMLIRSparseTensorC64)] - except Exception as e: - raise ValueError(f"Missing supporting function: {e}") from e - for i, info in enumerate(support_types): - _record_support_funcs(info[0], info[1], info[2], type_to_funcs) - - def get_support_funcs(ty: np.dtype): - funcs = type_to_funcs[ty] - assert funcs is not None - return funcs - - return get_support_funcs + """Constructs a function to locate the supporting functions for a data type. + + Loads the supporting C shared library with the needed routines. Constructs a + dictionary from the supported data types to the routines for the data types, + and then a function to look up the dictionary for a given data type. + + The name of the supporting C shared library is either provided by an + an environment variable or a default value. + + Returns: + The function to look up the supporting functions for a given data type. + + Raises: + OSError: If there is any problem in loading the shared library. + ValueError: If the shared library doesn't contain the needed routines. + """ + # This raises OSError exception if there is any problem in loading the shared + # library. 
+ c_lib = ctypes.CDLL(_get_support_lib_name()) + + type_to_funcs = {} + try: + support_types = [ + ( + np.int8, + c_lib.convertToMLIRSparseTensorI8, + c_lib.convertFromMLIRSparseTensorI8, + ), + ( + np.int16, + c_lib.convertToMLIRSparseTensorI16, + c_lib.convertFromMLIRSparseTensorI16, + ), + ( + np.int32, + c_lib.convertToMLIRSparseTensorI32, + c_lib.convertFromMLIRSparseTensorI32, + ), + ( + np.int64, + c_lib.convertToMLIRSparseTensorI64, + c_lib.convertFromMLIRSparseTensorI64, + ), + ( + np.float16, + c_lib.convertToMLIRSparseTensorF16, + c_lib.convertFromMLIRSparseTensorF16, + ), + ( + np.float32, + c_lib.convertToMLIRSparseTensorF32, + c_lib.convertFromMLIRSparseTensorF32, + ), + ( + np.float64, + c_lib.convertToMLIRSparseTensorF64, + c_lib.convertFromMLIRSparseTensorF64, + ), + ( + np.complex64, + c_lib.convertToMLIRSparseTensorC32, + c_lib.convertFromMLIRSparseTensorC32, + ), + ( + np.complex128, + c_lib.convertToMLIRSparseTensorC64, + c_lib.convertFromMLIRSparseTensorC64, + ), + ] + except Exception as e: + raise ValueError(f"Missing supporting function: {e}") from e + for i, info in enumerate(support_types): + _record_support_funcs(info[0], info[1], info[2], type_to_funcs) + + def get_support_funcs(ty: np.dtype): + funcs = type_to_funcs[ty] + assert funcs is not None + return funcs + + return get_support_funcs def sparse_tensor_to_coo_tensor( sparse_tensor: ctypes.c_void_p, dtype: np.dtype, ) -> Tuple[int, int, np.ndarray, np.ndarray, np.ndarray]: - """Converts an MLIR sparse tensor to a COO-flavored format tensor. - - Args: - sparse_tensor: A ctypes.c_void_p to the MLIR sparse tensor descriptor. - dtype: The numpy data type for the tensor elements. - - Returns: - A tuple that contains the following values for the COO-flavored format - tensor: - rank: An integer for the rank of the tensor. - nse: An integer for the number of non-zero values in the tensor. - shape: A 1D numpy array of integers, for the shape of the tensor. - values: A 1D numpy array, for the non-zero values in the tensor. - indices: A 2D numpy array of integers, representing the indices for the - non-zero values in the tensor. - - Raises: - OSError: If there is any problem in loading the shared library. - ValueError: If the shared library doesn't contain the needed routines. - """ - convert_from = _get_support_func_locator()(dtype)[1] - rank = ctypes.c_ulonglong(0) - nse = ctypes.c_ulonglong(0) - shape = ctypes.POINTER(ctypes.c_ulonglong)() - - values = ctypes.POINTER(runtime.as_ctype(np.dtype(dtype)))() - indices = ctypes.POINTER(ctypes.c_ulonglong)() - convert_from(sparse_tensor, ctypes.byref(rank), ctypes.byref(nse), - ctypes.byref(shape), ctypes.byref(values), ctypes.byref(indices)) - - # Convert the returned values to the corresponding numpy types. - shape = np.ctypeslib.as_array(shape, shape=[rank.value]) - values = runtime.to_numpy(np.ctypeslib.as_array(values, shape=[nse.value])) - indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value]) - return rank.value, nse.value, shape, values, indices - - -def coo_tensor_to_sparse_tensor(np_shape: np.ndarray, np_values: np.ndarray, - np_indices: np.ndarray, np_perm: np.ndarray, - np_sparse: np.ndarray) -> int: - """Converts a COO-flavored format sparse tensor to an MLIR sparse tensor. - - Args: - np_shape: A 1D numpy array of integers, for the shape of the tensor. - np_values: A 1D numpy array, for the non-zero values in the tensor. - np_indices: A 2D numpy array of integers, representing the indices for the - non-zero values in the tensor. 
- np_perm: A 1D numpy array of integers, representing the storage ordering - for the dimensions. - np_sparse: A 1D numpy array of uint8, representing the sparsity values - for the dimensions. - - Returns: - An integer for the non-null ctypes.c_void_p to the MLIR sparse tensor - descriptor. - - Raises: - OSError: If there is any problem in loading the shared library. - ValueError: If the shared library doesn't contain the needed routines. - """ - - r = len(np_shape) - rank = ctypes.c_ulonglong(r) - nse = ctypes.c_ulonglong(len(np_values)) - shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong)) - values = np_values.ctypes.data_as( - ctypes.POINTER(runtime.as_ctype(np.dtype(np_values.dtype)))) - indices = np_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong)) - - perm = np_perm.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong)) - sparse = np_sparse.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)) - - convert_to = _get_support_func_locator()(np_values.dtype.type)[0] - ptr = convert_to(rank, nse, shape, values, indices, perm, sparse) - assert ptr is not None, "Problem with calling convertToMLIRSparseTensorF64" - return ptr - - -def compile_and_build_engine( - module: ir.Module) -> execution_engine.ExecutionEngine: - """Compiles an MLIR module and builds a JIT execution engine. - - Args: - module: The MLIR module. - - Returns: - A JIT execution engine for the MLIR module. - - """ - return _get_sparse_compiler().compile_and_jit(module) + """Converts an MLIR sparse tensor to a COO-flavored format tensor. + + Args: + sparse_tensor: A ctypes.c_void_p to the MLIR sparse tensor descriptor. + dtype: The numpy data type for the tensor elements. + + Returns: + A tuple that contains the following values for the COO-flavored format + tensor: + rank: An integer for the rank of the tensor. + nse: An integer for the number of non-zero values in the tensor. + shape: A 1D numpy array of integers, for the shape of the tensor. + values: A 1D numpy array, for the non-zero values in the tensor. + indices: A 2D numpy array of integers, representing the indices for the + non-zero values in the tensor. + + Raises: + OSError: If there is any problem in loading the shared library. + ValueError: If the shared library doesn't contain the needed routines. + """ + convert_from = _get_support_func_locator()(dtype)[1] + rank = ctypes.c_ulonglong(0) + nse = ctypes.c_ulonglong(0) + shape = ctypes.POINTER(ctypes.c_ulonglong)() + + values = ctypes.POINTER(runtime.as_ctype(np.dtype(dtype)))() + indices = ctypes.POINTER(ctypes.c_ulonglong)() + convert_from( + sparse_tensor, + ctypes.byref(rank), + ctypes.byref(nse), + ctypes.byref(shape), + ctypes.byref(values), + ctypes.byref(indices), + ) + + # Convert the returned values to the corresponding numpy types. + shape = np.ctypeslib.as_array(shape, shape=[rank.value]) + values = runtime.to_numpy(np.ctypeslib.as_array(values, shape=[nse.value])) + indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value]) + return rank.value, nse.value, shape, values, indices + + +def coo_tensor_to_sparse_tensor( + np_shape: np.ndarray, + np_values: np.ndarray, + np_indices: np.ndarray, + np_perm: np.ndarray, + np_sparse: np.ndarray, +) -> int: + """Converts a COO-flavored format sparse tensor to an MLIR sparse tensor. + + Args: + np_shape: A 1D numpy array of integers, for the shape of the tensor. + np_values: A 1D numpy array, for the non-zero values in the tensor. 
+ np_indices: A 2D numpy array of integers, representing the indices for the + non-zero values in the tensor. + np_perm: A 1D numpy array of integers, representing the storage ordering + for the dimensions. + np_sparse: A 1D numpy array of uint8, representing the sparsity values + for the dimensions. + + Returns: + An integer for the non-null ctypes.c_void_p to the MLIR sparse tensor + descriptor. + + Raises: + OSError: If there is any problem in loading the shared library. + ValueError: If the shared library doesn't contain the needed routines. + """ + + r = len(np_shape) + rank = ctypes.c_ulonglong(r) + nse = ctypes.c_ulonglong(len(np_values)) + shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong)) + values = np_values.ctypes.data_as( + ctypes.POINTER(runtime.as_ctype(np.dtype(np_values.dtype))) + ) + indices = np_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong)) + + perm = np_perm.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong)) + sparse = np_sparse.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)) + + convert_to = _get_support_func_locator()(np_values.dtype.type)[0] + ptr = convert_to(rank, nse, shape, values, indices, perm, sparse) + assert ptr is not None, "Problem with calling convertToMLIRSparseTensorF64" + return ptr + + +def compile_and_build_engine(module: ir.Module) -> execution_engine.ExecutionEngine: + """Compiles an MLIR module and builds a JIT execution engine. + + Args: + module: The MLIR module. + + Returns: + A JIT execution engine for the MLIR module. + + """ + return _get_sparse_compiler().compile_and_jit(module) class _SparseTensorDescriptor(ctypes.Structure): - """A C structure for an MLIR sparse tensor.""" - _fields_ = [ - # A pointer for the MLIR sparse tensor storage. - ("storage", ctypes.POINTER(ctypes.c_ulonglong)), - # An MLIR MemRef descriptor for the shape of the sparse tensor. - ("shape", runtime.make_nd_memref_descriptor(1, ctypes.c_ulonglong)), - ] + """A C structure for an MLIR sparse tensor.""" + + _fields_ = [ + # A pointer for the MLIR sparse tensor storage. + ("storage", ctypes.POINTER(ctypes.c_ulonglong)), + # An MLIR MemRef descriptor for the shape of the sparse tensor. + ("shape", runtime.make_nd_memref_descriptor(1, ctypes.c_ulonglong)), + ] def _output_one_dim(dim: int, rank: int, shape: str, type: str) -> str: - """Produces the MLIR text code to output the size for the given dimension.""" - return f""" + """Produces the MLIR text code to output the size for the given dimension.""" + return f""" %c{dim} = arith.constant {dim} : index %d{dim} = tensor.dim %t, %c{dim} : tensor<{shape}x{type}, #enc> memref.store %d{dim}, %b[%c{dim}] : memref<{rank}xindex> @@ -233,26 +277,29 @@ def _output_one_dim(dim: int, rank: int, shape: str, type: str) -> str: # (2) Use scf.for instead of an unrolled loop to write out the dimension sizes # when tensor.dim supports non-constant dimension value. def _get_create_sparse_tensor_kernel( - sparsity_codes: Sequence[sparse_tensor.DimLevelType], type: str) -> str: - """Creates an MLIR text kernel to contruct a sparse tensor from a file. + sparsity_codes: Sequence[sparse_tensor.DimLevelType], type: str +) -> str: + """Creates an MLIR text kernel to contruct a sparse tensor from a file. - The kernel returns a _SparseTensorDescriptor structure. - """ - rank = len(sparsity_codes) + The kernel returns a _SparseTensorDescriptor structure. + """ + rank = len(sparsity_codes) - # Use ? to represent a dimension in the dynamic shape string representation. 
- shape = "x".join(map(lambda d: "?", range(rank))) + # Use ? to represent a dimension in the dynamic shape string representation. + shape = "x".join(map(lambda d: "?", range(rank))) - # Convert the encoded sparsity values to a string representation. - sparsity = ", ".join( - map(lambda s: '"compressed"' if s.value else '"dense"', sparsity_codes)) + # Convert the encoded sparsity values to a string representation. + sparsity = ", ".join( + map(lambda s: '"compressed"' if s.value else '"dense"', sparsity_codes) + ) - # Get the MLIR text code to write the dimension sizes to the output buffer. - output_dims = "\n".join( - map(lambda d: _output_one_dim(d, rank, shape, type), range(rank))) + # Get the MLIR text code to write the dimension sizes to the output buffer. + output_dims = "\n".join( + map(lambda d: _output_one_dim(d, rank, shape, type), range(rank)) + ) - # Return the MLIR text kernel. - return f""" + # Return the MLIR text kernel. + return f""" !Ptr = !llvm.ptr #enc = #sparse_tensor.encoding<{{ lvlTypes = [ {sparsity} ] @@ -266,69 +313,69 @@ attributes {{ llvm.emit_c_interface }} {{ }}""" -def create_sparse_tensor(filename: str, - sparsity: Sequence[sparse_tensor.DimLevelType], - type: str) -> Tuple[ctypes.c_void_p, np.ndarray]: - """Creates an MLIR sparse tensor from the input file. +def create_sparse_tensor( + filename: str, sparsity: Sequence[sparse_tensor.DimLevelType], type: str +) -> Tuple[ctypes.c_void_p, np.ndarray]: + """Creates an MLIR sparse tensor from the input file. - Args: - filename: A string for the name of the file that contains the tensor data in - a COO-flavored format. - sparsity: A sequence of DimLevelType values, one for each dimension of the - tensor. + Args: + filename: A string for the name of the file that contains the tensor data in + a COO-flavored format. + sparsity: A sequence of DimLevelType values, one for each dimension of the + tensor. - Returns: - A Tuple containing the following values: - storage: A ctypes.c_void_p for the MLIR sparse tensor storage. - shape: A 1D numpy array of integers, for the shape of the tensor. + Returns: + A Tuple containing the following values: + storage: A ctypes.c_void_p for the MLIR sparse tensor storage. + shape: A 1D numpy array of integers, for the shape of the tensor. - Raises: - OSError: If there is any problem in loading the supporting C shared library. - ValueError: If the shared library doesn't contain the needed routine. - """ - with ir.Context() as ctx, ir.Location.unknown(): - module = _get_create_sparse_tensor_kernel(sparsity, type) - module = ir.Module.parse(module) - engine = compile_and_build_engine(module) + Raises: + OSError: If there is any problem in loading the supporting C shared library. + ValueError: If the shared library doesn't contain the needed routine. + """ + with ir.Context() as ctx, ir.Location.unknown(): + module = _get_create_sparse_tensor_kernel(sparsity, type) + module = ir.Module.parse(module) + engine = compile_and_build_engine(module) - # A sparse tensor descriptor to receive the kernel result. - c_tensor_desc = _SparseTensorDescriptor() - # Convert the filename to a byte stream. - c_filename = ctypes.c_char_p(bytes(filename, "utf-8")) + # A sparse tensor descriptor to receive the kernel result. + c_tensor_desc = _SparseTensorDescriptor() + # Convert the filename to a byte stream. 
+ c_filename = ctypes.c_char_p(bytes(filename, "utf-8")) - arg_pointers = [ - ctypes.byref(ctypes.pointer(c_tensor_desc)), - ctypes.byref(c_filename) - ] + arg_pointers = [ + ctypes.byref(ctypes.pointer(c_tensor_desc)), + ctypes.byref(c_filename), + ] - # Invoke the execution engine to run the module and return the result. - engine.invoke(_ENTRY_NAME, *arg_pointers) - shape = runtime.ranked_memref_to_numpy(ctypes.pointer(c_tensor_desc.shape)) - return c_tensor_desc.storage, shape + # Invoke the execution engine to run the module and return the result. + engine.invoke(_ENTRY_NAME, *arg_pointers) + shape = runtime.ranked_memref_to_numpy(ctypes.pointer(c_tensor_desc.shape)) + return c_tensor_desc.storage, shape # TODO: With better support from MLIR, we may improve the current implementation # by using Python code to generate the kernel instead of doing MLIR text code # stitching. def _get_output_sparse_tensor_kernel( - sparsity_codes: Sequence[sparse_tensor.DimLevelType], - type: str) -> str: - """Creates an MLIR text kernel to output a sparse tensor to a file. + sparsity_codes: Sequence[sparse_tensor.DimLevelType], type: str +) -> str: + """Creates an MLIR text kernel to output a sparse tensor to a file. - The kernel returns void. - """ - rank = len(sparsity_codes) + The kernel returns void. + """ + rank = len(sparsity_codes) - # Use ? to represent a dimension in the dynamic shape string representation. - shape = "x".join(map(lambda d: "?", range(rank))) + # Use ? to represent a dimension in the dynamic shape string representation. + shape = "x".join(map(lambda d: "?", range(rank))) - # Convert the encoded sparsity values to a string representation. - sparsity = ", ".join( - map(lambda s: '"compressed"' - if s.value else '"dense"', sparsity_codes)) + # Convert the encoded sparsity values to a string representation. + sparsity = ", ".join( + map(lambda s: '"compressed"' if s.value else '"dense"', sparsity_codes) + ) - # Return the MLIR text kernel. - return f""" + # Return the MLIR text kernel. + return f""" !Ptr = !llvm.ptr #enc = #sparse_tensor.encoding<{{ lvlTypes = [ {sparsity} ] @@ -340,35 +387,38 @@ attributes {{ llvm.emit_c_interface }} {{ }}""" -def output_sparse_tensor(tensor: ctypes.c_void_p, filename: str, - sparsity: Sequence[sparse_tensor.DimLevelType], - type: str) -> None: - """Outputs an MLIR sparse tensor to the given file. - - Args: - tensor: A C pointer to the MLIR sparse tensor. - filename: A string for the name of the file that contains the tensor data in - a COO-flavored format. - sparsity: A sequence of DimLevelType values, one for each dimension of the - tensor. - type: The MLIR string for the data type. - - Raises: - OSError: If there is any problem in loading the supporting C shared library. - ValueError: If the shared library doesn't contain the needed routine. - """ - with ir.Context() as ctx, ir.Location.unknown(): - module = _get_output_sparse_tensor_kernel(sparsity, type) - module = ir.Module.parse(module) - engine = compile_and_build_engine(module) - - # Convert the filename to a byte stream. - c_filename = ctypes.c_char_p(bytes(filename, "utf-8")) - - arg_pointers = [ - ctypes.byref(ctypes.cast(tensor, ctypes.c_void_p)), - ctypes.byref(c_filename) - ] - - # Invoke the execution engine to run the module and return the result. 
- engine.invoke(_ENTRY_NAME, *arg_pointers) +def output_sparse_tensor( + tensor: ctypes.c_void_p, + filename: str, + sparsity: Sequence[sparse_tensor.DimLevelType], + type: str, +) -> None: + """Outputs an MLIR sparse tensor to the given file. + + Args: + tensor: A C pointer to the MLIR sparse tensor. + filename: A string for the name of the file that contains the tensor data in + a COO-flavored format. + sparsity: A sequence of DimLevelType values, one for each dimension of the + tensor. + type: The MLIR string for the data type. + + Raises: + OSError: If there is any problem in loading the supporting C shared library. + ValueError: If the shared library doesn't contain the needed routine. + """ + with ir.Context() as ctx, ir.Location.unknown(): + module = _get_output_sparse_tensor_kernel(sparsity, type) + module = ir.Module.parse(module) + engine = compile_and_build_engine(module) + + # Convert the filename to a byte stream. + c_filename = ctypes.c_char_p(bytes(filename, "utf-8")) + + arg_pointers = [ + ctypes.byref(ctypes.cast(tensor, ctypes.c_void_p)), + ctypes.byref(c_filename), + ] + + # Invoke the execution engine to run the module and return the result. + engine.invoke(_ENTRY_NAME, *arg_pointers) diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_sparse_compiler.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_sparse_compiler.py index 69db28d..8f193b8 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_sparse_compiler.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_sparse_compiler.py @@ -13,29 +13,29 @@ from typing import Sequence class SparseCompiler: - """Sparse compiler class for compiling and building MLIR modules.""" - - def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]): - pipeline = f'builtin.module(sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})' - self.pipeline = pipeline - self.opt_level = opt_level - self.shared_libs = shared_libs - - def __call__(self, module: ir.Module): - """Convenience application method.""" - self.compile(module) - - def compile(self, module: ir.Module): - """Compiles the module by invoking the sparse copmiler pipeline.""" - passmanager.PassManager.parse(self.pipeline).run(module.operation) - - def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine: - """Wraps the module in a JIT execution engine.""" - return execution_engine.ExecutionEngine( - module, opt_level=self.opt_level, shared_libs=self.shared_libs) - - def compile_and_jit(self, - module: ir.Module) -> execution_engine.ExecutionEngine: - """Compiles and jits the module.""" - self.compile(module) - return self.jit(module) + """Sparse compiler class for compiling and building MLIR modules.""" + + def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]): + pipeline = f"builtin.module(sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})" + self.pipeline = pipeline + self.opt_level = opt_level + self.shared_libs = shared_libs + + def __call__(self, module: ir.Module): + """Convenience application method.""" + self.compile(module) + + def compile(self, module: ir.Module): + """Compiles the module by invoking the sparse copmiler pipeline.""" + passmanager.PassManager.parse(self.pipeline).run(module.operation) + + def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine: + """Wraps the module in a JIT execution engine.""" + return execution_engine.ExecutionEngine( + module, 
opt_level=self.opt_level, shared_libs=self.shared_libs + ) + + def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine: + """Compiles and jits the module.""" + self.compile(module) + return self.jit(module) diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/testing_utils.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/testing_utils.py index 466c9df..1be88fa 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/testing_utils.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/testing_utils.py @@ -8,38 +8,40 @@ import numpy as np def compare_sparse_tns(expected: str, actual: str, rtol: float = 0.0001) -> bool: - """Compares sparse tensor actual output file with expected output file. + """Compares sparse tensor actual output file with expected output file. - This routine assumes the input files are in FROSTT format. See - http://frostt.io/tensors/file-formats.html for FROSTT (.tns) format. + This routine assumes the input files are in FROSTT format. See + http://frostt.io/tensors/file-formats.html for FROSTT (.tns) format. - It also assumes the first line in the output file is a comment line. + It also assumes the first line in the output file is a comment line. - """ - with open(actual, "r") as actual_f: - with open(expected, "r") as expected_f: - # Skip the first comment line. - _ = actual_f.readline() - _ = expected_f.readline() + """ + with open(actual, "r") as actual_f: + with open(expected, "r") as expected_f: + # Skip the first comment line. + _ = actual_f.readline() + _ = expected_f.readline() - # Compare the two lines of meta data - if (actual_f.readline() != expected_f.readline() or - actual_f.readline() != expected_f.readline()): - return FALSE + # Compare the two lines of meta data + if ( + actual_f.readline() != expected_f.readline() + or actual_f.readline() != expected_f.readline() + ): + return FALSE - actual_data = np.loadtxt(actual, np.float64, skiprows=3) - expected_data = np.loadtxt(expected, np.float64, skiprows=3) - return np.allclose(actual_data, expected_data, rtol=rtol) + actual_data = np.loadtxt(actual, np.float64, skiprows=3) + expected_data = np.loadtxt(expected, np.float64, skiprows=3) + return np.allclose(actual_data, expected_data, rtol=rtol) def file_as_string(file: str) -> str: - """Returns contents of file as string.""" - with open(file, "r") as f: - return f.read() + """Returns contents of file as string.""" + with open(file, "r") as f: + return f.read() def run_test(f): - """Prints the test name and runs the test.""" - print(f.__name__) - f() - return f + """Prints the test name and runs the test.""" + print(f.__name__) + f() + return f diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py index 5b7e648..45ce446 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_core.py @@ -18,509 +18,630 @@ _DENSE = mlir_pytaco.ModeFormat.DENSE def _init_3d(T, I, J, K): - for i in range(I): - for j in range(J): - for k in range(K): - T.insert([i, j, k], i + j + k + 1) + for i in range(I): + for j in range(J): + for k in range(K): + T.insert([i, j, k], i + j + k + 1) def _init_2d(T, I, J): - for i in range(I): - for j in range(J): - T.insert([i, j], i + j + 1) + for i in range(I): + for j in range(J): + T.insert([i, j], i + j + 1) def _init_1d_with_value(T, I, v): - for i in range(I): - 
T.insert([i], v) + for i in range(I): + T.insert([i], v) def test_expect_error(name, code, error): - """Executes the code then verifies the expected error message.""" - try: - exec(code) - except ValueError as e: - passed = "passed" if (str(e).startswith(error)) else "failed" - print(f"test_{name}: {passed}") + """Executes the code then verifies the expected error message.""" + try: + exec(code) + except ValueError as e: + passed = "passed" if (str(e).startswith(error)) else "failed" + print(f"test_{name}: {passed}") # CHECK-LABEL: test_tensor_dtype @testing_utils.run_test def test_tensor_dtype(): - passed = mlir_pytaco.DType(mlir_pytaco.Type.INT16).is_int() - passed += mlir_pytaco.DType(mlir_pytaco.Type.INT32).is_int() - passed += mlir_pytaco.DType(mlir_pytaco.Type.INT64).is_int() - passed += mlir_pytaco.DType(mlir_pytaco.Type.FLOAT32).is_float() - passed += mlir_pytaco.DType(mlir_pytaco.Type.FLOAT64).is_float() - # CHECK: Number of passed: 5 - print("Number of passed:", passed) + passed = mlir_pytaco.DType(mlir_pytaco.Type.INT16).is_int() + passed += mlir_pytaco.DType(mlir_pytaco.Type.INT32).is_int() + passed += mlir_pytaco.DType(mlir_pytaco.Type.INT64).is_int() + passed += mlir_pytaco.DType(mlir_pytaco.Type.FLOAT32).is_float() + passed += mlir_pytaco.DType(mlir_pytaco.Type.FLOAT64).is_float() + # CHECK: Number of passed: 5 + print("Number of passed:", passed) # CHECK: test_mode_ordering_not_int: passed -test_expect_error("mode_ordering_not_int", - "m = mlir_pytaco.ModeOrdering(['x'])", - "Ordering must be a list of integers") +test_expect_error( + "mode_ordering_not_int", + "m = mlir_pytaco.ModeOrdering(['x'])", + "Ordering must be a list of integers", +) # CHECK: test_mode_ordering_not_permutation: passed -test_expect_error("mode_ordering_not_permutation", - "m = mlir_pytaco.ModeOrdering([2, 1])", "Invalid ordering") +test_expect_error( + "mode_ordering_not_permutation", + "m = mlir_pytaco.ModeOrdering([2, 1])", + "Invalid ordering", +) # CHECK: test_mode_format_invalid: passed -test_expect_error("mode_format_invalid", - "m = mlir_pytaco.ModeFormatPack(['y'])", - "Formats must be a list of ModeFormat") +test_expect_error( + "mode_format_invalid", + "m = mlir_pytaco.ModeFormatPack(['y'])", + "Formats must be a list of ModeFormat", +) # CHECK: test_expect_mode_format_pack: passed -test_expect_error("expect_mode_format_pack", (""" +test_expect_error( + "expect_mode_format_pack", + ( + """ mode_ordering = mlir_pytaco.ModeOrdering([0, 1, 2]) f = mlir_pytaco.Format(["x"], mode_ordering) - """), "Expected a list of ModeFormat") + """ + ), + "Expected a list of ModeFormat", +) # CHECK: test_expect_mode_ordering: passed -test_expect_error("expect_mode_ordering", (""" +test_expect_error( + "expect_mode_ordering", + ( + """ mode_format_pack = mlir_pytaco.ModeFormatPack([_COMPRESSED, _COMPRESSED]) f = mlir_pytaco.Format(mode_format_pack, "x") - """), "Expected ModeOrdering") + """ + ), + "Expected ModeOrdering", +) # CHECK: test_inconsistent_mode_format_pack_and_mode_ordering: passed -test_expect_error("inconsistent_mode_format_pack_and_mode_ordering", (""" +test_expect_error( + "inconsistent_mode_format_pack_and_mode_ordering", + ( + """ mode_format_pack = mlir_pytaco.ModeFormatPack([_COMPRESSED, _COMPRESSED]) mode_ordering = mlir_pytaco.ModeOrdering([0, 1, 2]) f = mlir_pytaco.Format(mode_format_pack, mode_ordering) - """), "Inconsistent ModeFormatPack and ModeOrdering") + """ + ), + "Inconsistent ModeFormatPack and ModeOrdering", +) # CHECK-LABEL: test_format_default_ordering 
@testing_utils.run_test def test_format_default_ordering(): - f = mlir_pytaco.Format([_COMPRESSED, _COMPRESSED]) - passed = 0 - passed += np.array_equal(f.ordering.ordering, [0, 1]) - # CHECK: Number of passed: 1 - print("Number of passed:", passed) + f = mlir_pytaco.Format([_COMPRESSED, _COMPRESSED]) + passed = 0 + passed += np.array_equal(f.ordering.ordering, [0, 1]) + # CHECK: Number of passed: 1 + print("Number of passed:", passed) # CHECK-LABEL: test_format_explicit_ordering @testing_utils.run_test def test_format_explicit_ordering(): - f = mlir_pytaco.Format([_COMPRESSED, _DENSE], [1, 0]) - passed = 0 - passed += np.array_equal(f.ordering.ordering, [1, 0]) - # CHECK: Number of passed: 1 - print("Number of passed:", passed) + f = mlir_pytaco.Format([_COMPRESSED, _DENSE], [1, 0]) + passed = 0 + passed += np.array_equal(f.ordering.ordering, [1, 0]) + # CHECK: Number of passed: 1 + print("Number of passed:", passed) # CHECK-LABEL: test_index_var @testing_utils.run_test def test_index_var(): - i = mlir_pytaco.IndexVar() - j = mlir_pytaco.IndexVar() - passed = (i.name != j.name) + i = mlir_pytaco.IndexVar() + j = mlir_pytaco.IndexVar() + passed = i.name != j.name - vars = mlir_pytaco.get_index_vars(10) - passed += (len(vars) == 10) - passed += (all([isinstance(e, mlir_pytaco.IndexVar) for e in vars])) + vars = mlir_pytaco.get_index_vars(10) + passed += len(vars) == 10 + passed += all([isinstance(e, mlir_pytaco.IndexVar) for e in vars]) - # CHECK: Number of passed: 3 - print("Number of passed:", passed) + # CHECK: Number of passed: 3 + print("Number of passed:", passed) # CHECK: test_tensor_invalid_first_argument: passed -test_expect_error("tensor_invalid_first_argument", - "t = mlir_pytaco.Tensor('f')", "Invalid first argument") +test_expect_error( + "tensor_invalid_first_argument", + "t = mlir_pytaco.Tensor('f')", + "Invalid first argument", +) # CHECK: test_tensor_inconsistent_shape_and_format: passed -test_expect_error("tensor_inconsistent_shape_and_format", (""" +test_expect_error( + "tensor_inconsistent_shape_and_format", + ( + """ mode_format_pack = mlir_pytaco.ModeFormatPack([_COMPRESSED, _COMPRESSED]) mode_ordering = mlir_pytaco.ModeOrdering([0, 1]) f = mlir_pytaco.Format(mode_format_pack, mode_ordering) t = mlir_pytaco.Tensor([3], f) - """), "Inconsistent shape and format") + """ + ), + "Inconsistent shape and format", +) # CHECK: test_tensor_invalid_format: passed -test_expect_error("tensor_invalid_format", "t = mlir_pytaco.Tensor([3], 'f')", - "Invalid format argument") +test_expect_error( + "tensor_invalid_format", + "t = mlir_pytaco.Tensor([3], 'f')", + "Invalid format argument", +) # CHECK: test_tensor_insert_nonlist_coordinate: passed -test_expect_error("tensor_insert_nonlist_coordinate", (""" +test_expect_error( + "tensor_insert_nonlist_coordinate", + ( + """ t = mlir_pytaco.Tensor([3]) t.insert(1, 0) - """), "Non list coordinate detected") + """ + ), + "Non list coordinate detected", +) # CHECK: test_tensor_insert_too_much_coordinate: passed -test_expect_error("tensor_insert_too_much_coordinate", (""" +test_expect_error( + "tensor_insert_too_much_coordinate", + ( + """ t = mlir_pytaco.Tensor([3]) t.insert([0, 0], 0) - """), "Invalid coordinate") + """ + ), + "Invalid coordinate", +) # CHECK: test_tensor_insert_coordinate_outof_range: passed -test_expect_error("tensor_insert_coordinate_outof_range", (""" +test_expect_error( + "tensor_insert_coordinate_outof_range", + ( + """ t = mlir_pytaco.Tensor([1, 1]) t.insert([1, 0], 0) - """), "Invalid coordinate") + """ + ), + 
"Invalid coordinate", +) # CHECK: test_tensor_insert_coordinate_nonint: passed -test_expect_error("tensor_insert_coordinate_nonint", (""" +test_expect_error( + "tensor_insert_coordinate_nonint", + ( + """ t = mlir_pytaco.Tensor([1, 1]) t.insert([0, "xy"], 0) - """), "Non integer coordinate detected") + """ + ), + "Non integer coordinate detected", +) # CHECK: test_tensor_insert_invalid_value: passed -test_expect_error("tensor_insert_invalid_value", (""" +test_expect_error( + "tensor_insert_invalid_value", + ( + """ t = mlir_pytaco.Tensor([1, 1]) t.insert([0, 0], "x") - """), "Value is neither int nor float") + """ + ), + "Value is neither int nor float", +) # CHECK: test_access_non_index_var_index: passed -test_expect_error("access_non_index_var_index", (""" +test_expect_error( + "access_non_index_var_index", + ( + """ t = mlir_pytaco.Tensor([5, 6]) i = mlir_pytaco.IndexVar() a = mlir_pytaco.Access(t, (i, "j")) - """), "Indices contain non IndexVar") + """ + ), + "Indices contain non IndexVar", +) # CHECK: test_access_inconsistent_rank_indices: passed -test_expect_error("access_inconsistent_rank_indices", (""" +test_expect_error( + "access_inconsistent_rank_indices", + ( + """ t = mlir_pytaco.Tensor([5, 6]) i = mlir_pytaco.IndexVar() a = mlir_pytaco.Access(t, (i,)) - """), "Invalid indices for rank") + """ + ), + "Invalid indices for rank", +) # CHECK: test_access_invalid_indices_for_rank: passed -test_expect_error("access_invalid_indices_for_rank", (""" +test_expect_error( + "access_invalid_indices_for_rank", + ( + """ t = mlir_pytaco.Tensor([5, 6]) i, j, k = mlir_pytaco.get_index_vars(3) a = mlir_pytaco.Access(t, (i,j, k)) - """), "Invalid indices for rank") + """ + ), + "Invalid indices for rank", +) # CHECK: test_invalid_indices: passed -test_expect_error("invalid_indices", (""" +test_expect_error( + "invalid_indices", + ( + """ i, j = mlir_pytaco.get_index_vars(2) A = mlir_pytaco.Tensor([2, 3]) B = mlir_pytaco.Tensor([2, 3]) C = mlir_pytaco.Tensor([2, 3], _DENSE) C[i, j] = A[1, j] + B[i, j] - """), "Expected IndexVars") + """ + ), + "Expected IndexVars", +) # CHECK: test_inconsistent_rank_indices: passed -test_expect_error("inconsistent_rank_indices", (""" +test_expect_error( + "inconsistent_rank_indices", + ( + """ i, j = mlir_pytaco.get_index_vars(2) A = mlir_pytaco.Tensor([2, 3]) C = mlir_pytaco.Tensor([2, 3], _DENSE) C[i, j] = A[i] - """), "Invalid indices for rank") + """ + ), + "Invalid indices for rank", +) # CHECK: test_destination_index_not_used_in_source: passed -test_expect_error("destination_index_not_used_in_source", (""" +test_expect_error( + "destination_index_not_used_in_source", + ( + """ i, j = mlir_pytaco.get_index_vars(2) A = mlir_pytaco.Tensor([3]) C = mlir_pytaco.Tensor([3], _DENSE) C[j] = A[i] C.evaluate() - """), "Destination IndexVar not used in the source expression") + """ + ), + "Destination IndexVar not used in the source expression", +) # CHECK: test_destination_dim_not_consistent_with_source: passed -test_expect_error("destination_dim_not_consistent_with_source", (""" +test_expect_error( + "destination_dim_not_consistent_with_source", + ( + """ i = mlir_pytaco.IndexVar() A = mlir_pytaco.Tensor([3]) C = mlir_pytaco.Tensor([5], _DENSE) C[i] = A[i] C.evaluate() - """), "Inconsistent destination dimension for IndexVar") + """ + ), + "Inconsistent destination dimension for IndexVar", +) # CHECK: test_inconsistent_source_dim: passed -test_expect_error("inconsistent_source_dim", (""" +test_expect_error( + "inconsistent_source_dim", + ( + """ i = 
mlir_pytaco.IndexVar() A = mlir_pytaco.Tensor([3]) B = mlir_pytaco.Tensor([5]) C = mlir_pytaco.Tensor([3], _DENSE) C[i] = A[i] + B[i] C.evaluate() - """), "Inconsistent source dimension for IndexVar") + """ + ), + "Inconsistent source dimension for IndexVar", +) # CHECK: test_index_var_outside_domain: passed -test_expect_error("index_var_outside_domain", (""" +test_expect_error( + "index_var_outside_domain", + ( + """ i, j = mlir_pytaco.get_index_vars(2) A = mlir_pytaco.Tensor([3]) B = mlir_pytaco.Tensor([3]) B[i] = A[i] + j B.evaluate() - """), "IndexVar is not part of the iteration domain") + """ + ), + "IndexVar is not part of the iteration domain", +) # CHECK-LABEL: test_tensor_all_dense_sparse @testing_utils.run_test def test_tensor_all_dense_sparse(): - a = mlir_pytaco.Tensor([4], [_DENSE]) - passed = (not a.is_dense()) - passed += (a.order == 1) - passed += (a.shape[0] == 4) - # CHECK: Number of passed: 3 - print("Number of passed:", passed) + a = mlir_pytaco.Tensor([4], [_DENSE]) + passed = not a.is_dense() + passed += a.order == 1 + passed += a.shape[0] == 4 + # CHECK: Number of passed: 3 + print("Number of passed:", passed) # CHECK-LABEL: test_tensor_true_dense @testing_utils.run_test def test_tensor_true_dense(): - a = mlir_pytaco.Tensor.from_array(np.random.uniform(size=5)) - passed = a.is_dense() - passed += (a.order == 1) - passed += (a.shape[0] == 5) - # CHECK: Number of passed: 3 - print("Number of passed:", passed) + a = mlir_pytaco.Tensor.from_array(np.random.uniform(size=5)) + passed = a.is_dense() + passed += a.order == 1 + passed += a.shape[0] == 5 + # CHECK: Number of passed: 3 + print("Number of passed:", passed) # CHECK-LABEL: test_tensor_copy @testing_utils.run_test def test_tensor_copy(): - i, j = mlir_pytaco.get_index_vars(2) - I = 2 - J = 3 - A = mlir_pytaco.Tensor([I, J]) - A.insert([0, 1], 5.0) - A.insert([1, 2], 6.0) - B = mlir_pytaco.Tensor([I, J]) - B[i, j] = A[i, j] - passed = (B._assignment is not None) - passed += (B._engine is None) - try: + i, j = mlir_pytaco.get_index_vars(2) + I = 2 + J = 3 + A = mlir_pytaco.Tensor([I, J]) + A.insert([0, 1], 5.0) + A.insert([1, 2], 6.0) + B = mlir_pytaco.Tensor([I, J]) + B[i, j] = A[i, j] + passed = B._assignment is not None + passed += B._engine is None + try: + B.compute() + except ValueError as e: + passed += str(e).startswith("Need to invoke compile") + B.compile() + passed += B._engine is not None B.compute() - except ValueError as e: - passed += (str(e).startswith("Need to invoke compile")) - B.compile() - passed += (B._engine is not None) - B.compute() - passed += (B._assignment is None) - passed += (B._engine is None) - indices, values = B.get_coordinates_and_values() - passed += np.array_equal(indices, [[0, 1], [1, 2]]) - passed += np.allclose(values, [5.0, 6.0]) - # No temporary tensor is used. - passed += (B._stats.get_total() == 0) - # CHECK: Number of passed: 9 - print("Number of passed:", passed) + passed += B._assignment is None + passed += B._engine is None + indices, values = B.get_coordinates_and_values() + passed += np.array_equal(indices, [[0, 1], [1, 2]]) + passed += np.allclose(values, [5.0, 6.0]) + # No temporary tensor is used. 
+ passed += B._stats.get_total() == 0 + # CHECK: Number of passed: 9 + print("Number of passed:", passed) # CHECK-LABEL: test_tensor_trivial_reduction @testing_utils.run_test def test_tensor_trivial_reduction(): - i, j = mlir_pytaco.get_index_vars(2) - I = 2 - J = 3 - A = mlir_pytaco.Tensor([I, J]) - A.insert([0, 1], 5.0) - A.insert([0, 2], 3.0) - A.insert([1, 2], 6.0) - B = mlir_pytaco.Tensor([I]) - B[i] = A[i, j] - indices, values = B.get_coordinates_and_values() - passed = np.array_equal(indices, [[0], [1]]) - passed += np.allclose(values, [8.0, 6.0]) - # No temporary tensor is used. - passed += (B._stats.get_total() == 0) - - # CHECK: Number of passed: 3 - print("Number of passed:", passed) + i, j = mlir_pytaco.get_index_vars(2) + I = 2 + J = 3 + A = mlir_pytaco.Tensor([I, J]) + A.insert([0, 1], 5.0) + A.insert([0, 2], 3.0) + A.insert([1, 2], 6.0) + B = mlir_pytaco.Tensor([I]) + B[i] = A[i, j] + indices, values = B.get_coordinates_and_values() + passed = np.array_equal(indices, [[0], [1]]) + passed += np.allclose(values, [8.0, 6.0]) + # No temporary tensor is used. + passed += B._stats.get_total() == 0 + + # CHECK: Number of passed: 3 + print("Number of passed:", passed) # CHECK-LABEL: test_binary_add @testing_utils.run_test def test_binary_add(): - i = mlir_pytaco.IndexVar() - A = mlir_pytaco.Tensor([4]) - B = mlir_pytaco.Tensor([4]) - C = mlir_pytaco.Tensor([4]) - A.insert([1], 10) - A.insert([2], 1) - B.insert([3], 20) - B.insert([2], 2) - C[i] = A[i] + B[i] - indices, values = C.get_coordinates_and_values() - passed = np.array_equal(indices, [[1], [2], [3]]) - passed += np.array_equal(values, [10., 3., 20.]) - # No temporary tensor is used. - passed += (C._stats.get_total() == 0) - # CHECK: Number of passed: 3 - print("Number of passed:", passed) + i = mlir_pytaco.IndexVar() + A = mlir_pytaco.Tensor([4]) + B = mlir_pytaco.Tensor([4]) + C = mlir_pytaco.Tensor([4]) + A.insert([1], 10) + A.insert([2], 1) + B.insert([3], 20) + B.insert([2], 2) + C[i] = A[i] + B[i] + indices, values = C.get_coordinates_and_values() + passed = np.array_equal(indices, [[1], [2], [3]]) + passed += np.array_equal(values, [10.0, 3.0, 20.0]) + # No temporary tensor is used. + passed += C._stats.get_total() == 0 + # CHECK: Number of passed: 3 + print("Number of passed:", passed) # CHECK-LABEL: test_binary_add_sub @testing_utils.run_test def test_binary_add_sub(): - i = mlir_pytaco.IndexVar() - j = mlir_pytaco.IndexVar() - A = mlir_pytaco.Tensor([2, 3]) - B = mlir_pytaco.Tensor([2, 3]) - C = mlir_pytaco.Tensor([2, 3]) - D = mlir_pytaco.Tensor([2, 3]) - A.insert([0, 1], 10) - A.insert([1, 2], 40) - B.insert([0, 0], 20) - B.insert([1, 2], 30) - C.insert([0, 1], 5) - C.insert([1, 2], 7) - D[i, j] = A[i, j] + B[i, j] - C[i, j] - indices, values = D.get_coordinates_and_values() - passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]]) - passed += np.array_equal(values, [20., 5., 63.]) - # No temporary tensor is used. 
- passed += (D._stats.get_total() == 0) - # CHECK: Number of passed: 3 - print("Number of passed:", passed) + i = mlir_pytaco.IndexVar() + j = mlir_pytaco.IndexVar() + A = mlir_pytaco.Tensor([2, 3]) + B = mlir_pytaco.Tensor([2, 3]) + C = mlir_pytaco.Tensor([2, 3]) + D = mlir_pytaco.Tensor([2, 3]) + A.insert([0, 1], 10) + A.insert([1, 2], 40) + B.insert([0, 0], 20) + B.insert([1, 2], 30) + C.insert([0, 1], 5) + C.insert([1, 2], 7) + D[i, j] = A[i, j] + B[i, j] - C[i, j] + indices, values = D.get_coordinates_and_values() + passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]]) + passed += np.array_equal(values, [20.0, 5.0, 63.0]) + # No temporary tensor is used. + passed += D._stats.get_total() == 0 + # CHECK: Number of passed: 3 + print("Number of passed:", passed) # CHECK-LABEL: test_binary_mul_add @testing_utils.run_test def test_binary_mul_add(): - i = mlir_pytaco.IndexVar() - j = mlir_pytaco.IndexVar() - A = mlir_pytaco.Tensor([2, 3]) - B = mlir_pytaco.Tensor([2, 3]) - C = mlir_pytaco.Tensor([2, 3]) - D = mlir_pytaco.Tensor([2, 3]) - A.insert([0, 1], 10) - A.insert([1, 2], 40) - B.insert([0, 0], 20) - B.insert([1, 2], 30) - C.insert([0, 1], 5) - C.insert([1, 2], 7) - D[i, j] = A[i, j] * C[i, j] + B[i, j] - indices, values = D.get_coordinates_and_values() - passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]]) - passed += np.array_equal(values, [20., 50., 310.]) - # No temporary tensor is used. - passed += (D._stats.get_total() == 0) - # CHECK: Number of passed: 3 - print("Number of passed:", passed) + i = mlir_pytaco.IndexVar() + j = mlir_pytaco.IndexVar() + A = mlir_pytaco.Tensor([2, 3]) + B = mlir_pytaco.Tensor([2, 3]) + C = mlir_pytaco.Tensor([2, 3]) + D = mlir_pytaco.Tensor([2, 3]) + A.insert([0, 1], 10) + A.insert([1, 2], 40) + B.insert([0, 0], 20) + B.insert([1, 2], 30) + C.insert([0, 1], 5) + C.insert([1, 2], 7) + D[i, j] = A[i, j] * C[i, j] + B[i, j] + indices, values = D.get_coordinates_and_values() + passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]]) + passed += np.array_equal(values, [20.0, 50.0, 310.0]) + # No temporary tensor is used. + passed += D._stats.get_total() == 0 + # CHECK: Number of passed: 3 + print("Number of passed:", passed) # CHECK-LABEL: test_binary_add_reduce_at_root @testing_utils.run_test def test_binary_add_reduce_at_root(): - i = mlir_pytaco.IndexVar() - j = mlir_pytaco.IndexVar() - A = mlir_pytaco.Tensor([2, 3]) - B = mlir_pytaco.Tensor([2, 3]) - C = mlir_pytaco.Tensor([2], _DENSE) - A.insert([0, 1], 10) - A.insert([1, 2], 40) - B.insert([0, 0], 20) - B.insert([1, 2], 30) - C[i] = A[i, j] + B[i, j] - indices, values = C.get_coordinates_and_values() - passed = np.array_equal(indices, [[0], [1]]) - passed += np.array_equal(values, [30., 70.]) - # No temporary tensor is used. - passed += (C._stats.get_total() == 0) - # CHECK: Number of passed: 3 - print("Number of passed:", passed) + i = mlir_pytaco.IndexVar() + j = mlir_pytaco.IndexVar() + A = mlir_pytaco.Tensor([2, 3]) + B = mlir_pytaco.Tensor([2, 3]) + C = mlir_pytaco.Tensor([2], _DENSE) + A.insert([0, 1], 10) + A.insert([1, 2], 40) + B.insert([0, 0], 20) + B.insert([1, 2], 30) + C[i] = A[i, j] + B[i, j] + indices, values = C.get_coordinates_and_values() + passed = np.array_equal(indices, [[0], [1]]) + passed += np.array_equal(values, [30.0, 70.0]) + # No temporary tensor is used. 
+ passed += C._stats.get_total() == 0 + # CHECK: Number of passed: 3 + print("Number of passed:", passed) # CHECK-LABEL: test_binary_add_reduce_at_child @testing_utils.run_test def test_binary_add_reduce_at_child(): - i = mlir_pytaco.IndexVar() - j = mlir_pytaco.IndexVar() - I = 2 - J = 3 - A = mlir_pytaco.Tensor([I, J]) - B = mlir_pytaco.Tensor([J]) - C = mlir_pytaco.Tensor([I]) - D = mlir_pytaco.Tensor([I], _DENSE) - - _init_2d(A, I, J) - _init_1d_with_value(C, I, 2) - _init_1d_with_value(B, J, 1) - - D[i] = A[i, j] * B[j] + C[i] - indices, values = D.get_coordinates_and_values() - passed = np.array_equal(indices, [[0], [1]]) - passed += np.array_equal(values, [8., 11.]) - - # The expression is implemented as: - # temp0[i] = A[i, j] * B[i] - # D[i] = temp0[i] + C[i] - # Check the temporary tensor introduced by the implementation. - stats = D._stats - passed += (stats.get_total() == 1) - passed += (stats.get_formats(0) == (_COMPRESSED,)) - passed += (stats.get_dimensions(0) == (I,)) - # CHECK: Number of passed: 5 - print("Number of passed:", passed) + i = mlir_pytaco.IndexVar() + j = mlir_pytaco.IndexVar() + I = 2 + J = 3 + A = mlir_pytaco.Tensor([I, J]) + B = mlir_pytaco.Tensor([J]) + C = mlir_pytaco.Tensor([I]) + D = mlir_pytaco.Tensor([I], _DENSE) + + _init_2d(A, I, J) + _init_1d_with_value(C, I, 2) + _init_1d_with_value(B, J, 1) + + D[i] = A[i, j] * B[j] + C[i] + indices, values = D.get_coordinates_and_values() + passed = np.array_equal(indices, [[0], [1]]) + passed += np.array_equal(values, [8.0, 11.0]) + + # The expression is implemented as: + # temp0[i] = A[i, j] * B[i] + # D[i] = temp0[i] + C[i] + # Check the temporary tensor introduced by the implementation. + stats = D._stats + passed += stats.get_total() == 1 + passed += stats.get_formats(0) == (_COMPRESSED,) + passed += stats.get_dimensions(0) == (I,) + # CHECK: Number of passed: 5 + print("Number of passed:", passed) # CHECK-LABEL: test_binary_add_reduce_3d_1 @testing_utils.run_test def test_binary_add_reduce_3d_1(): - i, j, k, l = mlir_pytaco.get_index_vars(4) - I = 2 - J = 3 - K = 4 - L = 5 - A = mlir_pytaco.Tensor([I, J, K]) - B = mlir_pytaco.Tensor([I, J, L]) - C = mlir_pytaco.Tensor([K]) - D = mlir_pytaco.Tensor([L]) - E = mlir_pytaco.Tensor([I], _DENSE) - - _init_3d(A, I, J, K) - _init_3d(B, I, J, L) - _init_1d_with_value(C, K, 1) - _init_1d_with_value(D, L, 2) - - E[i] = A[i, j, k] * C[k] + B[i, j, l] * D[l] - indices, values = E.get_coordinates_and_values() - passed = np.array_equal(indices, [[0], [1]]) - passed += np.array_equal(values, [162., 204.]) - - # The expression is implemented as: - # temp0[i, j] = A[i, j, k] * C[k] - # temp1[i, j] = B[i, j, l] * D[l] - # E[i] = temp0[i, j] + temp1[i, j] - # Check the two temporary tensors introduced by the implementation. 
- stats = E._stats - passed += (stats.get_total() == 2) - passed += (stats.get_formats(0) == (_COMPRESSED, _COMPRESSED)) - passed += (stats.get_dimensions(0) == (I, J)) - passed += (stats.get_formats(1) == (_COMPRESSED, _COMPRESSED)) - passed += (stats.get_dimensions(1) == (I, J)) - # CHECK: Number of passed: 7 - print("Number of passed:", passed) + i, j, k, l = mlir_pytaco.get_index_vars(4) + I = 2 + J = 3 + K = 4 + L = 5 + A = mlir_pytaco.Tensor([I, J, K]) + B = mlir_pytaco.Tensor([I, J, L]) + C = mlir_pytaco.Tensor([K]) + D = mlir_pytaco.Tensor([L]) + E = mlir_pytaco.Tensor([I], _DENSE) + + _init_3d(A, I, J, K) + _init_3d(B, I, J, L) + _init_1d_with_value(C, K, 1) + _init_1d_with_value(D, L, 2) + + E[i] = A[i, j, k] * C[k] + B[i, j, l] * D[l] + indices, values = E.get_coordinates_and_values() + passed = np.array_equal(indices, [[0], [1]]) + passed += np.array_equal(values, [162.0, 204.0]) + + # The expression is implemented as: + # temp0[i, j] = A[i, j, k] * C[k] + # temp1[i, j] = B[i, j, l] * D[l] + # E[i] = temp0[i, j] + temp1[i, j] + # Check the two temporary tensors introduced by the implementation. + stats = E._stats + passed += stats.get_total() == 2 + passed += stats.get_formats(0) == (_COMPRESSED, _COMPRESSED) + passed += stats.get_dimensions(0) == (I, J) + passed += stats.get_formats(1) == (_COMPRESSED, _COMPRESSED) + passed += stats.get_dimensions(1) == (I, J) + # CHECK: Number of passed: 7 + print("Number of passed:", passed) # CHECK-LABEL: test_binary_add_reduce_3d_2 @testing_utils.run_test def test_binary_add_reduce_3d_2(): - i, j, k, l = mlir_pytaco.get_index_vars(4) - I = 2 - J = 3 - K = 4 - L = 5 - A = mlir_pytaco.Tensor([I, J, K], [_COMPRESSED, _COMPRESSED, _DENSE]) - B = mlir_pytaco.Tensor([I, L, K], [_DENSE, _COMPRESSED, _COMPRESSED]) - C = mlir_pytaco.Tensor([J, K], [_COMPRESSED, _COMPRESSED]) - D = mlir_pytaco.Tensor([L]) - E = mlir_pytaco.Tensor([I], _DENSE) - - _init_3d(A, I, J, K) - _init_3d(B, I, L, K) - _init_2d(C, J, K) - _init_1d_with_value(D, L, 2) - - E[i] = A[i, j, k] + C[j, k] + B[i, l, k] * D[l] - indices, values = E.get_coordinates_and_values() - passed = np.array_equal(indices, [[0], [1]]) - passed += np.array_equal(values, [264., 316.]) - - # The expression is implemented as: - # temp0[i, k] = A[i, j, k] + C[j, k] - # temp1[i, k] = B[i, l, k] * D[l] - # E[i] = temp0[i, k] + temp1[i, k] - # Check the two temporary tensors introduced by the implementation. 
- stats = E._stats - passed += (stats.get_total() == 2) - passed += (stats.get_formats(0) == (_COMPRESSED, _DENSE)) - passed += (stats.get_dimensions(0) == (I, K)) - passed += (stats.get_formats(1) == (_DENSE, _COMPRESSED)) - passed += (stats.get_dimensions(1) == (I, K)) - # CHECK: Number of passed: 7 - print("Number of passed:", passed) + i, j, k, l = mlir_pytaco.get_index_vars(4) + I = 2 + J = 3 + K = 4 + L = 5 + A = mlir_pytaco.Tensor([I, J, K], [_COMPRESSED, _COMPRESSED, _DENSE]) + B = mlir_pytaco.Tensor([I, L, K], [_DENSE, _COMPRESSED, _COMPRESSED]) + C = mlir_pytaco.Tensor([J, K], [_COMPRESSED, _COMPRESSED]) + D = mlir_pytaco.Tensor([L]) + E = mlir_pytaco.Tensor([I], _DENSE) + + _init_3d(A, I, J, K) + _init_3d(B, I, L, K) + _init_2d(C, J, K) + _init_1d_with_value(D, L, 2) + + E[i] = A[i, j, k] + C[j, k] + B[i, l, k] * D[l] + indices, values = E.get_coordinates_and_values() + passed = np.array_equal(indices, [[0], [1]]) + passed += np.array_equal(values, [264.0, 316.0]) + + # The expression is implemented as: + # temp0[i, k] = A[i, j, k] + C[j, k] + # temp1[i, k] = B[i, l, k] * D[l] + # E[i] = temp0[i, k] + temp1[i, k] + # Check the two temporary tensors introduced by the implementation. + stats = E._stats + passed += stats.get_total() == 2 + passed += stats.get_formats(0) == (_COMPRESSED, _DENSE) + passed += stats.get_dimensions(0) == (I, K) + passed += stats.get_formats(1) == (_DENSE, _COMPRESSED) + passed += stats.get_dimensions(1) == (I, K) + # CHECK: Number of passed: 7 + print("Number of passed:", passed) diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_io.py b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_io.py index cce97d6..1d52747 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_io.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_io.py @@ -32,21 +32,21 @@ _MTX_DATA = """%%MatrixMarket matrix coordinate real general # CHECK-LABEL: test_read_mtx_matrix_general @testing_utils.run_test def test_read_mtx_matrix_general(): - with tempfile.TemporaryDirectory() as test_dir: - file_name = os.path.join(test_dir, "data.mtx") - with open(file_name, "w") as file: - file.write(_MTX_DATA) - a = mlir_pytaco_io.read(file_name, _FORMAT) - passed = 0 - # The value of a is stored as an MLIR sparse tensor. - passed += (not a.is_unpacked()) - a.unpack() - passed += (a.is_unpacked()) - coords, values = a.get_coordinates_and_values() - passed += np.array_equal(coords, [[0, 1], [2, 0], [2, 1]]) - passed += np.allclose(values, [2.0, 3.0, 4.0]) - # CHECK: 4 - print(passed) + with tempfile.TemporaryDirectory() as test_dir: + file_name = os.path.join(test_dir, "data.mtx") + with open(file_name, "w") as file: + file.write(_MTX_DATA) + a = mlir_pytaco_io.read(file_name, _FORMAT) + passed = 0 + # The value of a is stored as an MLIR sparse tensor. + passed += not a.is_unpacked() + a.unpack() + passed += a.is_unpacked() + coords, values = a.get_coordinates_and_values() + passed += np.array_equal(coords, [[0, 1], [2, 0], [2, 1]]) + passed += np.allclose(values, [2.0, 3.0, 4.0]) + # CHECK: 4 + print(passed) _TNS_DATA = """2 3 @@ -60,57 +60,57 @@ _TNS_DATA = """2 3 # CHECK-LABEL: test_read_tns @testing_utils.run_test def test_read_tns(): - with tempfile.TemporaryDirectory() as test_dir: - file_name = os.path.join(test_dir, "data.tns") - with open(file_name, "w") as file: - file.write(_TNS_DATA) - a = mlir_pytaco_io.read(file_name, _FORMAT) - passed = 0 - # The value of a is stored as an MLIR sparse tensor. 
- passed += (not a.is_unpacked()) - a.unpack() - passed += (a.is_unpacked()) - coords, values = a.get_coordinates_and_values() - passed += np.array_equal(coords, [[0, 1], [2, 0], [2, 1]]) - passed += np.allclose(values, [2.0, 3.0, 4.0]) - # CHECK: 4 - print(passed) + with tempfile.TemporaryDirectory() as test_dir: + file_name = os.path.join(test_dir, "data.tns") + with open(file_name, "w") as file: + file.write(_TNS_DATA) + a = mlir_pytaco_io.read(file_name, _FORMAT) + passed = 0 + # The value of a is stored as an MLIR sparse tensor. + passed += not a.is_unpacked() + a.unpack() + passed += a.is_unpacked() + coords, values = a.get_coordinates_and_values() + passed += np.array_equal(coords, [[0, 1], [2, 0], [2, 1]]) + passed += np.allclose(values, [2.0, 3.0, 4.0]) + # CHECK: 4 + print(passed) # CHECK-LABEL: test_write_unpacked_tns @testing_utils.run_test def test_write_unpacked_tns(): - a = mlir_pytaco.Tensor([2, 3]) - a.insert([0, 1], 10) - a.insert([1, 2], 40) - a.insert([0, 0], 20) - with tempfile.TemporaryDirectory() as test_dir: - file_name = os.path.join(test_dir, "data.tns") - try: - mlir_pytaco_io.write(file_name, a) - except ValueError as e: - # CHECK: Writing unpacked sparse tensors to file is not supported - print(e) + a = mlir_pytaco.Tensor([2, 3]) + a.insert([0, 1], 10) + a.insert([1, 2], 40) + a.insert([0, 0], 20) + with tempfile.TemporaryDirectory() as test_dir: + file_name = os.path.join(test_dir, "data.tns") + try: + mlir_pytaco_io.write(file_name, a) + except ValueError as e: + # CHECK: Writing unpacked sparse tensors to file is not supported + print(e) # CHECK-LABEL: test_write_packed_tns @testing_utils.run_test def test_write_packed_tns(): - a = mlir_pytaco.Tensor([2, 3]) - a.insert([0, 1], 10) - a.insert([1, 2], 40) - a.insert([0, 0], 20) - b = mlir_pytaco.Tensor([2, 3]) - i, j = mlir_pytaco.get_index_vars(2) - b[i, j] = a[i, j] + a[i, j] - with tempfile.TemporaryDirectory() as test_dir: - file_name = os.path.join(test_dir, "data.tns") - mlir_pytaco_io.write(file_name, b) - with open(file_name, "r") as file: - lines = file.readlines() - passed = 0 - # Skip the comment line in the output. - if lines[1:] == ["2 3\n", "2 3\n", "1 1 40\n", "1 2 20\n", "2 3 80\n"]: - passed = 1 - # CHECK: 1 - print(passed) + a = mlir_pytaco.Tensor([2, 3]) + a.insert([0, 1], 10) + a.insert([1, 2], 40) + a.insert([0, 0], 20) + b = mlir_pytaco.Tensor([2, 3]) + i, j = mlir_pytaco.get_index_vars(2) + b[i, j] = a[i, j] + a[i, j] + with tempfile.TemporaryDirectory() as test_dir: + file_name = os.path.join(test_dir, "data.tns") + mlir_pytaco_io.write(file_name, b) + with open(file_name, "r") as file: + lines = file.readlines() + passed = 0 + # Skip the comment line in the output. 
+ if lines[1:] == ["2 3\n", "2 3\n", "1 1 40\n", "1 2 20\n", "2 3 80\n"]: + passed = 1 + # CHECK: 1 + print(passed) diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_utils.py b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_utils.py index 1325969..1344f4a 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_utils.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/unit_test_tensor_utils.py @@ -20,79 +20,93 @@ _DENSE = mlir_pytaco.ModeFormat.DENSE def _to_string(s: Sequence[int]) -> str: - """Converts a sequence of integer to a space separated value string.""" - return " ".join(map(lambda e: str(e), s)) + """Converts a sequence of integer to a space separated value string.""" + return " ".join(map(lambda e: str(e), s)) def _add_one(s: Sequence[int]) -> Sequence[int]: - """Adds one to each element in the sequence of integer.""" - return [i + 1 for i in s] + """Adds one to each element in the sequence of integer.""" + return [i + 1 for i in s] @dataclasses.dataclass(frozen=True) class _SparseTensorCOO: - """Values for a COO-flavored format sparse tensor. - - Attributes: - rank: An integer rank for the tensor. - nse: An integer for the number of non-zero values. - shape: A sequence of integer for the dimension size. - values: A sequence of float for the non-zero values of the tensor. - indices: A sequence of coordinate, each coordinate is a sequence of integer. - """ - rank: int - nse: int - shape: Sequence[int] - values: Sequence[float] - indices: Sequence[Sequence[int]] + """Values for a COO-flavored format sparse tensor. + + Attributes: + rank: An integer rank for the tensor. + nse: An integer for the number of non-zero values. + shape: A sequence of integer for the dimension size. + values: A sequence of float for the non-zero values of the tensor. + indices: A sequence of coordinate, each coordinate is a sequence of integer. + """ + + rank: int + nse: int + shape: Sequence[int] + values: Sequence[float] + indices: Sequence[Sequence[int]] def _coo_values_to_tns_format(t: _SparseTensorCOO) -> str: - """Converts a sparse tensor COO-flavored values to TNS text format.""" - # The coo_value_str contains one line for each (coordinate value) pair. - # Indices are 1-based in TNS text format but 0-based in MLIR. - coo_value_str = "\n".join( - map(lambda i: _to_string(_add_one(t.indices[i])) + " " + str(t.values[i]), - range(t.nse))) - - # Returns the TNS text format representation for the tensor. - return f"""{t.rank} {t.nse} + """Converts a sparse tensor COO-flavored values to TNS text format.""" + # The coo_value_str contains one line for each (coordinate value) pair. + # Indices are 1-based in TNS text format but 0-based in MLIR. + coo_value_str = "\n".join( + map( + lambda i: _to_string(_add_one(t.indices[i])) + " " + str(t.values[i]), + range(t.nse), + ) + ) + + # Returns the TNS text format representation for the tensor. + return f"""{t.rank} {t.nse} {_to_string(t.shape)} {coo_value_str} """ def _implement_read_tns_test( - t: _SparseTensorCOO, - sparsity_codes: Sequence[sparse_tensor.DimLevelType]) -> int: - tns_data = _coo_values_to_tns_format(t) - - # Write sparse tensor data to a file. - with tempfile.TemporaryDirectory() as test_dir: - file_name = os.path.join(test_dir, "data.tns") - with open(file_name, "w") as file: - file.write(tns_data) - - # Read the data from the file and construct an MLIR sparse tensor. 
- sparse_tensor, o_shape = pytaco_utils.create_sparse_tensor( - file_name, sparsity_codes, "f64") - - passed = 0 - - # Verify the output shape for the tensor. - if np.array_equal(o_shape, t.shape): - passed += 1 - - # Use the output MLIR sparse tensor pointer to retrieve the COO-flavored - # values and verify the values. - o_rank, o_nse, o_shape, o_values, o_indices = ( - pytaco_utils.sparse_tensor_to_coo_tensor(sparse_tensor, np.float64)) - if o_rank == t.rank and o_nse == t.nse and np.array_equal( - o_shape, t.shape) and np.allclose(o_values, t.values) and np.array_equal( - o_indices, t.indices): - passed += 1 - - return passed + t: _SparseTensorCOO, sparsity_codes: Sequence[sparse_tensor.DimLevelType] +) -> int: + tns_data = _coo_values_to_tns_format(t) + + # Write sparse tensor data to a file. + with tempfile.TemporaryDirectory() as test_dir: + file_name = os.path.join(test_dir, "data.tns") + with open(file_name, "w") as file: + file.write(tns_data) + + # Read the data from the file and construct an MLIR sparse tensor. + sparse_tensor, o_shape = pytaco_utils.create_sparse_tensor( + file_name, sparsity_codes, "f64" + ) + + passed = 0 + + # Verify the output shape for the tensor. + if np.array_equal(o_shape, t.shape): + passed += 1 + + # Use the output MLIR sparse tensor pointer to retrieve the COO-flavored + # values and verify the values. + ( + o_rank, + o_nse, + o_shape, + o_values, + o_indices, + ) = pytaco_utils.sparse_tensor_to_coo_tensor(sparse_tensor, np.float64) + if ( + o_rank == t.rank + and o_nse == t.nse + and np.array_equal(o_shape, t.shape) + and np.allclose(o_values, t.values) + and np.array_equal(o_indices, t.indices) + ): + passed += 1 + + return passed # A 2D sparse tensor data in COO-flavored format. diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg b/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg index 12c97db..70b4b66 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg +++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg @@ -5,11 +5,11 @@ if not config.mlir_run_amx_tests: config.unsupported = True # No JIT on win32. -if sys.platform == 'win32': +if sys.platform == "win32": config.unsupported = True if config.intel_sde_executable: # Run test in emulator (Intel SDE): AMX needs Sapphire Rapids CPU. - config.substitutions.append(('%lli', config.intel_sde_executable + ' -spr -- lli')) + config.substitutions.append(("%lli", config.intel_sde_executable + " -spr -- lli")) else: - config.substitutions.append(('%lli', 'lli')) + config.substitutions.append(("%lli", "lli")) diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/lit.local.cfg b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/lit.local.cfg index 0423fc0..296b441 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/lit.local.cfg +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/lit.local.cfg @@ -5,5 +5,5 @@ if not config.mlir_run_arm_sme_tests: config.unsupported = True # No JIT on win32. -if sys.platform == 'win32': +if sys.platform == "win32": config.unsupported = True diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/lit.local.cfg b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/lit.local.cfg index 8a0d884..37d3a74 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/lit.local.cfg +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/lit.local.cfg @@ -5,5 +5,5 @@ if not config.mlir_run_arm_sve_tests: config.unsupported = True # No JIT on win32. 
-if sys.platform == 'win32': +if sys.platform == "win32": config.unsupported = True diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/lit.local.cfg b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/lit.local.cfg index 0e22874..bde8156 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/lit.local.cfg +++ b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/lit.local.cfg @@ -5,11 +5,11 @@ if not config.mlir_run_x86vector_tests: config.unsupported = True # No JIT on win32. -if sys.platform == 'win32': +if sys.platform == "win32": config.unsupported = True if config.intel_sde_executable: # Run test in emulator (Intel SDE). - config.substitutions.append(('%lli', config.intel_sde_executable + ' -tgl -- lli')) + config.substitutions.append(("%lli", config.intel_sde_executable + " -tgl -- lli")) else: - config.substitutions.append(('%lli', 'lli')) + config.substitutions.append(("%lli", "lli")) diff --git a/mlir/test/Integration/Dialect/Vector/GPU/CUDA/lit.local.cfg b/mlir/test/Integration/Dialect/Vector/GPU/CUDA/lit.local.cfg index 0bdebfe..acb8dd4 100644 --- a/mlir/test/Integration/Dialect/Vector/GPU/CUDA/lit.local.cfg +++ b/mlir/test/Integration/Dialect/Vector/GPU/CUDA/lit.local.cfg @@ -1,2 +1,2 @@ if not config.enable_cuda_runner: - config.unsupported = True + config.unsupported = True diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/TensorCore/lit.local.cfg index 451b9fc..3bd7024 100644 --- a/mlir/test/Integration/GPU/CUDA/TensorCore/lit.local.cfg +++ b/mlir/test/Integration/GPU/CUDA/TensorCore/lit.local.cfg @@ -2,4 +2,4 @@ import sys # TensorCore tests must be enabled via build flag. if not config.mlir_run_cuda_tensor_core_tests: - config.unsupported = True + config.unsupported = True diff --git a/mlir/test/Integration/GPU/CUDA/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/lit.local.cfg index 0bdebfe..acb8dd4 100644 --- a/mlir/test/Integration/GPU/CUDA/lit.local.cfg +++ b/mlir/test/Integration/GPU/CUDA/lit.local.cfg @@ -1,2 +1,2 @@ if not config.enable_cuda_runner: - config.unsupported = True + config.unsupported = True diff --git a/mlir/test/Integration/GPU/ROCM/lit.local.cfg b/mlir/test/Integration/GPU/ROCM/lit.local.cfg index b0d086f..e1f8648 100644 --- a/mlir/test/Integration/GPU/ROCM/lit.local.cfg +++ b/mlir/test/Integration/GPU/ROCM/lit.local.cfg @@ -1,4 +1,4 @@ if not config.enable_rocm_runner or not config.rocm_test_chipset: - config.unsupported = True + config.unsupported = True -config.substitutions.append(('%chip', config.rocm_test_chipset)) +config.substitutions.append(("%chip", config.rocm_test_chipset)) diff --git a/mlir/test/Integration/lit.local.cfg b/mlir/test/Integration/lit.local.cfg index 80a862a..1b4a323 100644 --- a/mlir/test/Integration/lit.local.cfg +++ b/mlir/test/Integration/lit.local.cfg @@ -3,8 +3,9 @@ from lit.llvm import llvm_config if not config.mlir_include_integration_tests: config.unsupported = True + def configure_aarch64_lli_cmd(): - lli_cmd = 'lli' + lli_cmd = "lli" # NOTE: If the SVE tests are disabled and the SME tests are enabled to run # under emulation, the SVE specific RUN lines in the SparseTensor tests @@ -12,8 +13,12 @@ def configure_aarch64_lli_cmd(): if not (config.mlir_run_arm_sve_tests or config.mlir_run_arm_sme_tests): return lli_cmd - config.substitutions.append(('%mlir_native_utils_lib_dir', - config.arm_emulator_utils_lib_dir or config.mlir_lib_dir)) + config.substitutions.append( + ( + "%mlir_native_utils_lib_dir", + config.arm_emulator_utils_lib_dir or 
config.mlir_lib_dir, + ) + ) if config.arm_emulator_executable: if config.arm_emulator_lli_executable: @@ -23,16 +28,22 @@ def configure_aarch64_lli_cmd(): # when running under an emulator. If the user didn't specify an lli # executable, use absolute path %llvm_tools_dir/lli. lli_cmd = llvm_config.use_llvm_tool( - 'lli', search_env='LLI', required=True, - search_paths=[config.llvm_tools_dir], use_installed=False + "lli", + search_env="LLI", + required=True, + search_paths=[config.llvm_tools_dir], + use_installed=False, ) # Run test in emulator (qemu or armie) - emulation_cmd = f'{config.arm_emulator_executable} {config.arm_emulator_options}' - lli_cmd = f'{emulation_cmd} {lli_cmd}' + emulation_cmd = ( + f"{config.arm_emulator_executable} {config.arm_emulator_options}" + ) + lli_cmd = f"{emulation_cmd} {lli_cmd}" return lli_cmd + aarch64_lli_cmd = configure_aarch64_lli_cmd() # Configure the following AArch64 substitutions: @@ -52,5 +63,5 @@ aarch64_lli_cmd = configure_aarch64_lli_cmd() # could be used in the SparseTensor tests where necessary, but the meaning # conveyed by the substitution name would be a misnomer if the host target # is not AArch64 and MLIR_RUN_ARM_SVE_TESTS=OFF. -config.substitutions.append(('%lli_aarch64_cmd', aarch64_lli_cmd)) -config.substitutions.append(('%lli_host_or_aarch64_cmd', aarch64_lli_cmd)) +config.substitutions.append(("%lli_aarch64_cmd", aarch64_lli_cmd)) +config.substitutions.append(("%lli_host_or_aarch64_cmd", aarch64_lli_cmd)) diff --git a/mlir/test/Unit/lit.cfg.py b/mlir/test/Unit/lit.cfg.py index 5b66517..1898b72 100644 --- a/mlir/test/Unit/lit.cfg.py +++ b/mlir/test/Unit/lit.cfg.py @@ -8,43 +8,43 @@ import subprocess import lit.formats # name: The name of this test suite. -config.name = 'MLIR-Unit' +config.name = "MLIR-Unit" # suffixes: A list of file extensions to treat as test files. config.suffixes = [] # test_source_root: The root path where tests are located. # test_exec_root: The root path where tests should be run. -config.test_exec_root = os.path.join(config.mlir_obj_root, 'unittests') +config.test_exec_root = os.path.join(config.mlir_obj_root, "unittests") config.test_source_root = config.test_exec_root # testFormat: The test format to use to interpret tests. -config.test_format = lit.formats.GoogleTest(config.llvm_build_mode, 'Tests') +config.test_format = lit.formats.GoogleTest(config.llvm_build_mode, "Tests") # Propagate the temp directory. Windows requires this because it uses \Windows\ # if none of these are present. -if 'TMP' in os.environ: - config.environment['TMP'] = os.environ['TMP'] -if 'TEMP' in os.environ: - config.environment['TEMP'] = os.environ['TEMP'] +if "TMP" in os.environ: + config.environment["TMP"] = os.environ["TMP"] +if "TEMP" in os.environ: + config.environment["TEMP"] = os.environ["TEMP"] # Propagate HOME as it can be used to override incorrect homedir in passwd # that causes the tests to fail. -if 'HOME' in os.environ: - config.environment['HOME'] = os.environ['HOME'] +if "HOME" in os.environ: + config.environment["HOME"] = os.environ["HOME"] # Propagate sanitizer options. 
for var in [ - 'ASAN_SYMBOLIZER_PATH', - 'HWASAN_SYMBOLIZER_PATH', - 'MSAN_SYMBOLIZER_PATH', - 'TSAN_SYMBOLIZER_PATH', - 'UBSAN_SYMBOLIZER_PATH', - 'ASAN_OPTIONS', - 'HWASAN_OPTIONS', - 'MSAN_OPTIONS', - 'TSAN_OPTIONS', - 'UBSAN_OPTIONS', + "ASAN_SYMBOLIZER_PATH", + "HWASAN_SYMBOLIZER_PATH", + "MSAN_SYMBOLIZER_PATH", + "TSAN_SYMBOLIZER_PATH", + "UBSAN_SYMBOLIZER_PATH", + "ASAN_OPTIONS", + "HWASAN_OPTIONS", + "MSAN_OPTIONS", + "TSAN_OPTIONS", + "UBSAN_OPTIONS", ]: if var in os.environ: config.environment[var] = os.environ[var] diff --git a/mlir/test/lib/Dialect/Test/lit.local.cfg b/mlir/test/lib/Dialect/Test/lit.local.cfg index edb5b44..65a7f20 100644 --- a/mlir/test/lib/Dialect/Test/lit.local.cfg +++ b/mlir/test/lib/Dialect/Test/lit.local.cfg @@ -1 +1 @@ -config.suffixes.remove('.td') \ No newline at end of file +config.suffixes.remove(".td") diff --git a/mlir/test/lib/Dialect/Transform/lit.local.cfg b/mlir/test/lib/Dialect/Transform/lit.local.cfg index edb5b44..65a7f20 100644 --- a/mlir/test/lib/Dialect/Transform/lit.local.cfg +++ b/mlir/test/lib/Dialect/Transform/lit.local.cfg @@ -1 +1 @@ -config.suffixes.remove('.td') \ No newline at end of file +config.suffixes.remove(".td") diff --git a/mlir/test/lib/Tools/PDLL/lit.local.cfg b/mlir/test/lib/Tools/PDLL/lit.local.cfg index 8cfe5cd..8ffccee 100644 --- a/mlir/test/lib/Tools/PDLL/lit.local.cfg +++ b/mlir/test/lib/Tools/PDLL/lit.local.cfg @@ -1 +1 @@ -config.suffixes.remove('.pdll') +config.suffixes.remove(".pdll") diff --git a/mlir/test/lib/Transforms/lit.local.cfg b/mlir/test/lib/Transforms/lit.local.cfg index 8cfe5cd..8ffccee 100644 --- a/mlir/test/lib/Transforms/lit.local.cfg +++ b/mlir/test/lib/Transforms/lit.local.cfg @@ -1 +1 @@ -config.suffixes.remove('.pdll') +config.suffixes.remove(".pdll") diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py index 1fc2e31..ad0b0d5 100644 --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -16,21 +16,32 @@ from lit.llvm.subst import FindTool # Configuration file for the 'lit' test runner. # name: The name of this test suite. -config.name = 'MLIR' +config.name = "MLIR" config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) # suffixes: A list of file extensions to treat as test files. -config.suffixes = ['.td', '.mlir', '.toy', '.ll', '.tc', '.py', '.yaml', '.test', '.pdll', '.c'] +config.suffixes = [ + ".td", + ".mlir", + ".toy", + ".ll", + ".tc", + ".py", + ".yaml", + ".test", + ".pdll", + ".c", +] # test_source_root: The root path where tests are located. config.test_source_root = os.path.dirname(__file__) # test_exec_root: The root path where tests should be run. -config.test_exec_root = os.path.join(config.mlir_obj_root, 'test') +config.test_exec_root = os.path.join(config.mlir_obj_root, "test") -config.substitutions.append(('%PATH%', config.environment['PATH'])) -config.substitutions.append(('%shlibext', config.llvm_shlib_ext)) +config.substitutions.append(("%PATH%", config.environment["PATH"])) +config.substitutions.append(("%shlibext", config.llvm_shlib_ext)) config.substitutions.append(("%mlir_src_root", config.mlir_src_root)) config.substitutions.append(("%host_cxx", config.host_cxx)) config.substitutions.append(("%host_cc", config.host_cc)) @@ -40,94 +51,109 @@ config.substitutions.append(("%host_cc", config.host_cc)) # substitution of the same name and the found path. # Correctly handles the platforms shared library directory and naming conventions. 
def add_runtime(name): - path = '' - for prefix in ['', 'lib']: - path = os.path.join(config.llvm_shlib_dir, f'{prefix}{name}{config.llvm_shlib_ext}') + path = "" + for prefix in ["", "lib"]: + path = os.path.join( + config.llvm_shlib_dir, f"{prefix}{name}{config.llvm_shlib_ext}" + ) if os.path.isfile(path): break - return ToolSubst(f'%{name}', path) + return ToolSubst(f"%{name}", path) -llvm_config.with_system_environment( - ['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP']) +llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"]) llvm_config.use_default_substitutions() # excludes: A list of directories to exclude from the testsuite. The 'Inputs' # subdirectories contain auxiliary inputs for various tests in their parent # directories. -config.excludes = ['Inputs', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt', - 'lit.cfg.py', 'lit.site.cfg.py'] +config.excludes = [ + "Inputs", + "CMakeLists.txt", + "README.txt", + "LICENSE.txt", + "lit.cfg.py", + "lit.site.cfg.py", +] # Tweak the PATH to include the tools dir. -llvm_config.with_environment('PATH', config.mlir_tools_dir, append_path=True) -llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) +llvm_config.with_environment("PATH", config.mlir_tools_dir, append_path=True) +llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True) tool_dirs = [config.mlir_tools_dir, config.llvm_tools_dir] tools = [ - 'mlir-tblgen', - 'mlir-translate', - 'mlir-lsp-server', - 'mlir-capi-execution-engine-test', - 'mlir-capi-ir-test', - 'mlir-capi-llvm-test', - 'mlir-capi-pass-test', - 'mlir-capi-pdl-test', - 'mlir-capi-quant-test', - 'mlir-capi-sparse-tensor-test', - 'mlir-capi-transform-test', - 'mlir-cpu-runner', - add_runtime('mlir_runner_utils'), - add_runtime('mlir_c_runner_utils'), - add_runtime('mlir_async_runtime'), - 'mlir-linalg-ods-yaml-gen', - 'mlir-reduce', - 'mlir-pdll', - 'not', + "mlir-tblgen", + "mlir-translate", + "mlir-lsp-server", + "mlir-capi-execution-engine-test", + "mlir-capi-ir-test", + "mlir-capi-llvm-test", + "mlir-capi-pass-test", + "mlir-capi-pdl-test", + "mlir-capi-quant-test", + "mlir-capi-sparse-tensor-test", + "mlir-capi-transform-test", + "mlir-cpu-runner", + add_runtime("mlir_runner_utils"), + add_runtime("mlir_c_runner_utils"), + add_runtime("mlir_async_runtime"), + "mlir-linalg-ods-yaml-gen", + "mlir-reduce", + "mlir-pdll", + "not", ] if config.enable_spirv_cpu_runner: - tools.extend(['mlir-spirv-cpu-runner', add_runtime('mlir_test_spirv_cpu_runner_c_wrappers')]) + tools.extend( + ["mlir-spirv-cpu-runner", add_runtime("mlir_test_spirv_cpu_runner_c_wrappers")] + ) if config.enable_vulkan_runner: - tools.extend([add_runtime('vulkan-runtime-wrappers')]) + tools.extend([add_runtime("vulkan-runtime-wrappers")]) if config.enable_rocm_runner: - tools.extend([add_runtime('mlir_rocm_runtime')]) + tools.extend([add_runtime("mlir_rocm_runtime")]) if config.enable_cuda_runner: - tools.extend([add_runtime('mlir_cuda_runtime')]) + tools.extend([add_runtime("mlir_cuda_runtime")]) # The following tools are optional -tools.extend([ - ToolSubst('toyc-ch1', unresolved='ignore'), - ToolSubst('toyc-ch2', unresolved='ignore'), - ToolSubst('toyc-ch3', unresolved='ignore'), - ToolSubst('toyc-ch4', unresolved='ignore'), - ToolSubst('toyc-ch5', unresolved='ignore'), - ToolSubst('toyc-ch6', unresolved='ignore'), - ToolSubst('toyc-ch7', unresolved='ignore'), - ToolSubst('%mlir_lib_dir', config.mlir_lib_dir, unresolved='ignore'), - ToolSubst('%mlir_src_dir', config.mlir_src_root, 
unresolved='ignore'), -]) +tools.extend( + [ + ToolSubst("toyc-ch1", unresolved="ignore"), + ToolSubst("toyc-ch2", unresolved="ignore"), + ToolSubst("toyc-ch3", unresolved="ignore"), + ToolSubst("toyc-ch4", unresolved="ignore"), + ToolSubst("toyc-ch5", unresolved="ignore"), + ToolSubst("toyc-ch6", unresolved="ignore"), + ToolSubst("toyc-ch7", unresolved="ignore"), + ToolSubst("%mlir_lib_dir", config.mlir_lib_dir, unresolved="ignore"), + ToolSubst("%mlir_src_dir", config.mlir_src_root, unresolved="ignore"), + ] +) python_executable = config.python_executable # Python configuration with sanitizer requires some magic preloading. This will only work on clang/linux. # TODO: detect Darwin/Windows situation (or mark these tests as unsupported on these platforms). if "asan" in config.available_features and "Linux" in config.host_os: - python_executable = f"LD_PRELOAD=$({config.host_cxx} -print-file-name=libclang_rt.asan-{config.host_arch}.so) {config.python_executable}" + python_executable = f"LD_PRELOAD=$({config.host_cxx} -print-file-name=libclang_rt.asan-{config.host_arch}.so) {config.python_executable}" # On Windows the path to python could contains spaces in which case it needs to be provided in quotes. # This is the equivalent of how %python is setup in llvm/utils/lit/lit/llvm/config.py. elif "Windows" in config.host_os: - python_executable = '"%s"' % (python_executable) -tools.extend([ - ToolSubst('%PYTHON', python_executable, unresolved='ignore'), -]) + python_executable = '"%s"' % (python_executable) +tools.extend( + [ + ToolSubst("%PYTHON", python_executable, unresolved="ignore"), + ] +) if "MLIR_OPT_CHECK_IR_ROUNDTRIP" in os.environ: - tools.extend([ - ToolSubst('mlir-opt', 'mlir-opt --verify-roundtrip', unresolved='fatal'), - ]) + tools.extend( + [ + ToolSubst("mlir-opt", "mlir-opt --verify-roundtrip", unresolved="fatal"), + ] + ) llvm_config.add_tool_substitutions(tools, tool_dirs) @@ -135,40 +161,48 @@ llvm_config.add_tool_substitutions(tools, tool_dirs) # FileCheck -enable-var-scope is enabled by default in MLIR test # This option avoids to accidentally reuse variable across -LABEL match, # it can be explicitly opted-in by prefixing the variable name with $ -config.environment['FILECHECK_OPTS'] = "-enable-var-scope --allow-unused-prefixes=false" +config.environment["FILECHECK_OPTS"] = "-enable-var-scope --allow-unused-prefixes=false" # Add the python path for both the source and binary tree. # Note that presently, the python sources come from the source tree and the # binaries come from the build tree. This should be unified to the build tree # by copying/linking sources to build. 
if config.enable_bindings_python: - llvm_config.with_environment('PYTHONPATH', [ - os.path.join(config.mlir_obj_root, 'python_packages', 'mlir_core'), - os.path.join(config.mlir_obj_root, 'python_packages', 'mlir_test'), - ], append_path=True) + llvm_config.with_environment( + "PYTHONPATH", + [ + os.path.join(config.mlir_obj_root, "python_packages", "mlir_core"), + os.path.join(config.mlir_obj_root, "python_packages", "mlir_test"), + ], + append_path=True, + ) if config.enable_assertions: - config.available_features.add('asserts') + config.available_features.add("asserts") else: - config.available_features.add('noasserts') + config.available_features.add("noasserts") + def have_host_jit_feature_support(feature_name): - mlir_cpu_runner_exe = lit.util.which('mlir-cpu-runner', config.mlir_tools_dir) + mlir_cpu_runner_exe = lit.util.which("mlir-cpu-runner", config.mlir_tools_dir) + + if not mlir_cpu_runner_exe: + return False - if not mlir_cpu_runner_exe: - return False + try: + mlir_cpu_runner_cmd = subprocess.Popen( + [mlir_cpu_runner_exe, "--host-supports-" + feature_name], + stdout=subprocess.PIPE, + ) + except OSError: + print("could not exec mlir-cpu-runner") + return False - try: - mlir_cpu_runner_cmd = subprocess.Popen( - [mlir_cpu_runner_exe, '--host-supports-' + feature_name], stdout=subprocess.PIPE) - except OSError: - print('could not exec mlir-cpu-runner') - return False + mlir_cpu_runner_out = mlir_cpu_runner_cmd.stdout.read().decode("ascii") + mlir_cpu_runner_cmd.wait() - mlir_cpu_runner_out = mlir_cpu_runner_cmd.stdout.read().decode('ascii') - mlir_cpu_runner_cmd.wait() + return "true" in mlir_cpu_runner_out - return 'true' in mlir_cpu_runner_out -if have_host_jit_feature_support('jit'): - config.available_features.add('host-supports-jit') +if have_host_jit_feature_support("jit"): + config.available_features.add("host-supports-jit") diff --git a/mlir/test/mlir-cpu-runner/lit.local.cfg b/mlir/test/mlir-cpu-runner/lit.local.cfg index 3f59ff1..3c20d20 100644 --- a/mlir/test/mlir-cpu-runner/lit.local.cfg +++ b/mlir/test/mlir-cpu-runner/lit.local.cfg @@ -1,12 +1,11 @@ import sys # MSAN does not work with JIT. -if 'msan' in config.available_features: - config.unsupported = True +if "msan" in config.available_features: + config.unsupported = True # Requires native execution. 
-if 'host-supports-jit' not in config.available_features: +if "host-supports-jit" not in config.available_features: config.unsupported = True -config.available_features.add( - config.root.native_target.lower() + '-native-target') +config.available_features.add(config.root.native_target.lower() + "-native-target") diff --git a/mlir/test/mlir-pdll-lsp-server/lit.local.cfg b/mlir/test/mlir-pdll-lsp-server/lit.local.cfg index 25d08c7..aa35dbf 100644 --- a/mlir/test/mlir-pdll-lsp-server/lit.local.cfg +++ b/mlir/test/mlir-pdll-lsp-server/lit.local.cfg @@ -1 +1 @@ -config.excludes = ['include'] +config.excludes = ["include"] diff --git a/mlir/test/mlir-pdll/lit.local.cfg b/mlir/test/mlir-pdll/lit.local.cfg index c438027..4cb5622 100644 --- a/mlir/test/mlir-pdll/lit.local.cfg +++ b/mlir/test/mlir-pdll/lit.local.cfg @@ -1,2 +1,2 @@ -config.suffixes = ['.pdll', '.mlir'] -config.excludes = ['include'] +config.suffixes = [".pdll", ".mlir"] +config.excludes = ["include"] diff --git a/mlir/test/mlir-spirv-cpu-runner/lit.local.cfg b/mlir/test/mlir-spirv-cpu-runner/lit.local.cfg index 286bea4..8717dd0 100644 --- a/mlir/test/mlir-spirv-cpu-runner/lit.local.cfg +++ b/mlir/test/mlir-spirv-cpu-runner/lit.local.cfg @@ -1,4 +1,4 @@ import sys if not config.enable_spirv_cpu_runner: - config.unsupported = True + config.unsupported = True diff --git a/mlir/test/mlir-vulkan-runner/lit.local.cfg b/mlir/test/mlir-vulkan-runner/lit.local.cfg index f99be2a..6da7fcd 100644 --- a/mlir/test/mlir-vulkan-runner/lit.local.cfg +++ b/mlir/test/mlir-vulkan-runner/lit.local.cfg @@ -1,2 +1,2 @@ if not config.enable_vulkan_runner: - config.unsupported = True + config.unsupported = True diff --git a/mlir/test/python/develoment_files.py b/mlir/test/python/develoment_files.py index ea0a911..4dc3a0b 100644 --- a/mlir/test/python/develoment_files.py +++ b/mlir/test/python/develoment_files.py @@ -14,5 +14,6 @@ expected_lib_name = "MLIRPythonCAPI" all_libs = os.listdir(get_lib_dirs()[0]) found_lib = False for file_name in all_libs: - if expected_lib_name in file_name: found_lib = True + if expected_lib_name in file_name: + found_lib = True assert found_lib, f"Did not find '{expected_lib_name}' lib in {all_libs}" diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py index acae9b6..8e9613d 100644 --- a/mlir/test/python/dialects/arith_dialect.py +++ b/mlir/test/python/dialects/arith_dialect.py @@ -4,16 +4,18 @@ from mlir.ir import * import mlir.dialects.func as func import mlir.dialects.arith as arith + def run(f): - print("\nTEST:", f.__name__) - f() + print("\nTEST:", f.__name__) + f() + # CHECK-LABEL: TEST: testConstantOp @run def testConstantOps(): - with Context() as ctx, Location.unknown(): - module = Module.create() - with InsertionPoint(module.body): - arith.ConstantOp(value=42.42, result=F32Type.get()) - # CHECK: %cst = arith.constant 4.242000e+01 : f32 - print(module) + with Context() as ctx, Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + arith.ConstantOp(value=42.42, result=F32Type.get()) + # CHECK: %cst = arith.constant 4.242000e+01 : f32 + print(module) diff --git a/mlir/test/python/dialects/async_dialect.py b/mlir/test/python/dialects/async_dialect.py index da3103c..f6181cc 100644 --- a/mlir/test/python/dialects/async_dialect.py +++ b/mlir/test/python/dialects/async_dialect.py @@ -5,14 +5,17 @@ import mlir.dialects.async_dialect import mlir.dialects.async_dialect.passes from mlir.passmanager import * + def run(f): - print("\nTEST:", 
f.__name__) - f() + print("\nTEST:", f.__name__) + f() + def testAsyncPass(): - with Context() as context: - PassManager.parse('any(async-to-async-runtime)') - print('SUCCESS') + with Context() as context: + PassManager.parse("any(async-to-async-runtime)") + print("SUCCESS") + # CHECK-LABEL: testAsyncPass # CHECK: SUCCESS diff --git a/mlir/test/python/dialects/builtin.py b/mlir/test/python/dialects/builtin.py index eab24b5..18ebba6 100644 --- a/mlir/test/python/dialects/builtin.py +++ b/mlir/test/python/dialects/builtin.py @@ -7,232 +7,242 @@ import numpy as np def run(f): - print("\nTEST:", f.__name__) - f() - return f + print("\nTEST:", f.__name__) + f() + return f # CHECK-LABEL: TEST: testFromPyFunc @run def testFromPyFunc(): - with Context() as ctx, Location.unknown() as loc: - ctx.allow_unregistered_dialects = True - m = builtin.ModuleOp() - f32 = F32Type.get() - f64 = F64Type.get() - with InsertionPoint(m.body): - # CHECK-LABEL: func @unary_return(%arg0: f64) -> f64 - # CHECK: return %arg0 : f64 - @func.FuncOp.from_py_func(f64) - def unary_return(a): - return a - - # CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64) - # CHECK: return %arg0, %arg1 : f32, f64 - @func.FuncOp.from_py_func(f32, f64) - def binary_return(a, b): - return a, b - - # CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64) - # CHECK: return - @func.FuncOp.from_py_func(f32, f64) - def none_return(a, b): - pass - - # CHECK-LABEL: func @call_unary - # CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64 - # CHECK: return %0 : f64 - @func.FuncOp.from_py_func(f64) - def call_unary(a): - return unary_return(a) - - # CHECK-LABEL: func @call_binary - # CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64) - # CHECK: return %0#0, %0#1 : f32, f64 - @func.FuncOp.from_py_func(f32, f64) - def call_binary(a, b): - return binary_return(a, b) - - # We expect coercion of a single result operation to a returned value. - # CHECK-LABEL: func @single_result_op - # CHECK: %0 = "custom.op1"() : () -> f32 - # CHECK: return %0 : f32 - @func.FuncOp.from_py_func() - def single_result_op(): - return Operation.create("custom.op1", results=[f32]) - - # CHECK-LABEL: func @call_none - # CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> () - # CHECK: return - @func.FuncOp.from_py_func(f32, f64) - def call_none(a, b): - return none_return(a, b) - - ## Variants and optional feature tests. 
- # CHECK-LABEL: func @from_name_arg - @func.FuncOp.from_py_func(f32, f64, name="from_name_arg") - def explicit_name(a, b): - return b - - @func.FuncOp.from_py_func(f32, f64) - def positional_func_op(a, b, func_op): - assert isinstance(func_op, func.FuncOp) - return b - - @func.FuncOp.from_py_func(f32, f64) - def kw_func_op(a, b=None, func_op=None): - assert isinstance(func_op, func.FuncOp) - return b - - @func.FuncOp.from_py_func(f32, f64) - def kwargs_func_op(a, b=None, **kwargs): - assert isinstance(kwargs["func_op"], func.FuncOp) - return b - - # CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64 - # CHECK: return %arg1 : f64 - @func.FuncOp.from_py_func(f32, f64, results=[f64]) - def explicit_results(a, b): - func.ReturnOp([b]) - - print(m) + with Context() as ctx, Location.unknown() as loc: + ctx.allow_unregistered_dialects = True + m = builtin.ModuleOp() + f32 = F32Type.get() + f64 = F64Type.get() + with InsertionPoint(m.body): + # CHECK-LABEL: func @unary_return(%arg0: f64) -> f64 + # CHECK: return %arg0 : f64 + @func.FuncOp.from_py_func(f64) + def unary_return(a): + return a + + # CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64) + # CHECK: return %arg0, %arg1 : f32, f64 + @func.FuncOp.from_py_func(f32, f64) + def binary_return(a, b): + return a, b + + # CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64) + # CHECK: return + @func.FuncOp.from_py_func(f32, f64) + def none_return(a, b): + pass + + # CHECK-LABEL: func @call_unary + # CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64 + # CHECK: return %0 : f64 + @func.FuncOp.from_py_func(f64) + def call_unary(a): + return unary_return(a) + + # CHECK-LABEL: func @call_binary + # CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64) + # CHECK: return %0#0, %0#1 : f32, f64 + @func.FuncOp.from_py_func(f32, f64) + def call_binary(a, b): + return binary_return(a, b) + + # We expect coercion of a single result operation to a returned value. + # CHECK-LABEL: func @single_result_op + # CHECK: %0 = "custom.op1"() : () -> f32 + # CHECK: return %0 : f32 + @func.FuncOp.from_py_func() + def single_result_op(): + return Operation.create("custom.op1", results=[f32]) + + # CHECK-LABEL: func @call_none + # CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> () + # CHECK: return + @func.FuncOp.from_py_func(f32, f64) + def call_none(a, b): + return none_return(a, b) + + ## Variants and optional feature tests. 
+ # CHECK-LABEL: func @from_name_arg + @func.FuncOp.from_py_func(f32, f64, name="from_name_arg") + def explicit_name(a, b): + return b + + @func.FuncOp.from_py_func(f32, f64) + def positional_func_op(a, b, func_op): + assert isinstance(func_op, func.FuncOp) + return b + + @func.FuncOp.from_py_func(f32, f64) + def kw_func_op(a, b=None, func_op=None): + assert isinstance(func_op, func.FuncOp) + return b + + @func.FuncOp.from_py_func(f32, f64) + def kwargs_func_op(a, b=None, **kwargs): + assert isinstance(kwargs["func_op"], func.FuncOp) + return b + + # CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64 + # CHECK: return %arg1 : f64 + @func.FuncOp.from_py_func(f32, f64, results=[f64]) + def explicit_results(a, b): + func.ReturnOp([b]) + + print(m) # CHECK-LABEL: TEST: testFromPyFuncErrors @run def testFromPyFuncErrors(): - with Context() as ctx, Location.unknown() as loc: - m = builtin.ModuleOp() - f32 = F32Type.get() - f64 = F64Type.get() - with InsertionPoint(m.body): - try: + with Context() as ctx, Location.unknown() as loc: + m = builtin.ModuleOp() + f32 = F32Type.get() + f64 = F64Type.get() + with InsertionPoint(m.body): + try: - @func.FuncOp.from_py_func(f64, results=[f64]) - def unary_return(a): - return a - except AssertionError as e: - # CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None. - print(e) + @func.FuncOp.from_py_func(f64, results=[f64]) + def unary_return(a): + return a + + except AssertionError as e: + # CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None. + print(e) # CHECK-LABEL: TEST: testBuildFuncOp @run def testBuildFuncOp(): - ctx = Context() - with Location.unknown(ctx) as loc: - m = builtin.ModuleOp() - - f32 = F32Type.get() - tensor_type = RankedTensorType.get((2, 3, 4), f32) - with InsertionPoint.at_block_begin(m.body): - f = func.FuncOp(name="some_func", - type=FunctionType.get( - inputs=[tensor_type, tensor_type], - results=[tensor_type]), - visibility="nested") - # CHECK: Name is: "some_func" - print("Name is: ", f.name) - - # CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32> - print("Type is: ", f.type) - - # CHECK: Visibility is: "nested" - print("Visibility is: ", f.visibility) - - try: - entry_block = f.entry_block - except IndexError as e: - # CHECK: External function does not have a body - print(e) - - with InsertionPoint(f.add_entry_block()): - func.ReturnOp([f.entry_block.arguments[0]]) - pass - - try: - f.add_entry_block() - except IndexError as e: - # CHECK: The function already has an entry block! - print(e) - - # Try the callback builder and passing type as tuple. 
- f = func.FuncOp(name="some_other_func", - type=([tensor_type, tensor_type], [tensor_type]), - visibility="nested", - body_builder=lambda f: func.ReturnOp( - [f.entry_block.arguments[0]])) - - # CHECK: module { - # CHECK: func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { - # CHECK: return %arg0 : tensor<2x3x4xf32> - # CHECK: } - # CHECK: func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { - # CHECK: return %arg0 : tensor<2x3x4xf32> - # CHECK: } - print(m) + ctx = Context() + with Location.unknown(ctx) as loc: + m = builtin.ModuleOp() + + f32 = F32Type.get() + tensor_type = RankedTensorType.get((2, 3, 4), f32) + with InsertionPoint.at_block_begin(m.body): + f = func.FuncOp( + name="some_func", + type=FunctionType.get( + inputs=[tensor_type, tensor_type], results=[tensor_type] + ), + visibility="nested", + ) + # CHECK: Name is: "some_func" + print("Name is: ", f.name) + + # CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32> + print("Type is: ", f.type) + + # CHECK: Visibility is: "nested" + print("Visibility is: ", f.visibility) + + try: + entry_block = f.entry_block + except IndexError as e: + # CHECK: External function does not have a body + print(e) + + with InsertionPoint(f.add_entry_block()): + func.ReturnOp([f.entry_block.arguments[0]]) + pass + + try: + f.add_entry_block() + except IndexError as e: + # CHECK: The function already has an entry block! + print(e) + + # Try the callback builder and passing type as tuple. + f = func.FuncOp( + name="some_other_func", + type=([tensor_type, tensor_type], [tensor_type]), + visibility="nested", + body_builder=lambda f: func.ReturnOp([f.entry_block.arguments[0]]), + ) + + # CHECK: module { + # CHECK: func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { + # CHECK: return %arg0 : tensor<2x3x4xf32> + # CHECK: } + # CHECK: func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { + # CHECK: return %arg0 : tensor<2x3x4xf32> + # CHECK: } + print(m) # CHECK-LABEL: TEST: testFuncArgumentAccess @run def testFuncArgumentAccess(): - with Context() as ctx, Location.unknown(): - ctx.allow_unregistered_dialects = True - module = Module.create() - f32 = F32Type.get() - f64 = F64Type.get() - with InsertionPoint(module.body): - f = func.FuncOp("some_func", ([f32, f32], [f32, f32])) - with InsertionPoint(f.add_entry_block()): - func.ReturnOp(f.arguments) - f.arg_attrs = ArrayAttr.get([ - DictAttr.get({ - "custom_dialect.foo": StringAttr.get("bar"), - "custom_dialect.baz": UnitAttr.get() - }), - DictAttr.get({"custom_dialect.qux": ArrayAttr.get([])}) - ]) - f.result_attrs = ArrayAttr.get([ - DictAttr.get({"custom_dialect.res1": FloatAttr.get(f32, 42.0)}), - DictAttr.get({"custom_dialect.res2": FloatAttr.get(f64, 256.0)}) - ]) - - other = func.FuncOp("other_func", ([f32, f32], [])) - with InsertionPoint(other.add_entry_block()): - func.ReturnOp([]) - other.arg_attrs = [ - DictAttr.get({"custom_dialect.foo": StringAttr.get("qux")}), - DictAttr.get() - ] - - # CHECK: [{custom_dialect.baz, custom_dialect.foo = "bar"}, {custom_dialect.qux = []}] - print(f.arg_attrs) - - # CHECK: [{custom_dialect.res1 = 4.200000e+01 : f32}, {custom_dialect.res2 = 2.560000e+02 : f64}] - print(f.result_attrs) - - # CHECK: func @some_func( - # CHECK: %[[ARG0:.*]]: f32 {custom_dialect.baz, custom_dialect.foo = "bar"}, - # CHECK: %[[ARG1:.*]]: f32 {custom_dialect.qux = []}) -> - 
# CHECK: f32 {custom_dialect.res1 = 4.200000e+01 : f32}, - # CHECK: f32 {custom_dialect.res2 = 2.560000e+02 : f64}) - # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32 - # - # CHECK: func @other_func( - # CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"}, - # CHECK: %{{.*}}: f32) - print(module) + with Context() as ctx, Location.unknown(): + ctx.allow_unregistered_dialects = True + module = Module.create() + f32 = F32Type.get() + f64 = F64Type.get() + with InsertionPoint(module.body): + f = func.FuncOp("some_func", ([f32, f32], [f32, f32])) + with InsertionPoint(f.add_entry_block()): + func.ReturnOp(f.arguments) + f.arg_attrs = ArrayAttr.get( + [ + DictAttr.get( + { + "custom_dialect.foo": StringAttr.get("bar"), + "custom_dialect.baz": UnitAttr.get(), + } + ), + DictAttr.get({"custom_dialect.qux": ArrayAttr.get([])}), + ] + ) + f.result_attrs = ArrayAttr.get( + [ + DictAttr.get({"custom_dialect.res1": FloatAttr.get(f32, 42.0)}), + DictAttr.get({"custom_dialect.res2": FloatAttr.get(f64, 256.0)}), + ] + ) + + other = func.FuncOp("other_func", ([f32, f32], [])) + with InsertionPoint(other.add_entry_block()): + func.ReturnOp([]) + other.arg_attrs = [ + DictAttr.get({"custom_dialect.foo": StringAttr.get("qux")}), + DictAttr.get(), + ] + + # CHECK: [{custom_dialect.baz, custom_dialect.foo = "bar"}, {custom_dialect.qux = []}] + print(f.arg_attrs) + + # CHECK: [{custom_dialect.res1 = 4.200000e+01 : f32}, {custom_dialect.res2 = 2.560000e+02 : f64}] + print(f.result_attrs) + + # CHECK: func @some_func( + # CHECK: %[[ARG0:.*]]: f32 {custom_dialect.baz, custom_dialect.foo = "bar"}, + # CHECK: %[[ARG1:.*]]: f32 {custom_dialect.qux = []}) -> + # CHECK: f32 {custom_dialect.res1 = 4.200000e+01 : f32}, + # CHECK: f32 {custom_dialect.res2 = 2.560000e+02 : f64}) + # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32 + # + # CHECK: func @other_func( + # CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"}, + # CHECK: %{{.*}}: f32) + print(module) # CHECK-LABEL: testDenseElementsAttr @run def testDenseElementsAttr(): - with Context(), Location.unknown(): - values = np.arange(4, dtype=np.int32) - i32 = IntegerType.get_signless(32) - print(DenseElementsAttr.get(values, type=i32)) - # CHECK{LITERAL}: dense<[0, 1, 2, 3]> : tensor<4xi32> - print(DenseElementsAttr.get(values, type=i32, shape=(2, 2))) - # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : tensor<2x2xi32> - print(DenseElementsAttr.get(values, type=VectorType.get((2, 2), i32))) - # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : vector<2x2xi32> + with Context(), Location.unknown(): + values = np.arange(4, dtype=np.int32) + i32 = IntegerType.get_signless(32) + print(DenseElementsAttr.get(values, type=i32)) + # CHECK{LITERAL}: dense<[0, 1, 2, 3]> : tensor<4xi32> + print(DenseElementsAttr.get(values, type=i32, shape=(2, 2))) + # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : tensor<2x2xi32> + print(DenseElementsAttr.get(values, type=VectorType.get((2, 2), i32))) + # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : vector<2x2xi32> diff --git a/mlir/test/python/dialects/complex_dialect.py b/mlir/test/python/dialects/complex_dialect.py index e724575..afad217 100644 --- a/mlir/test/python/dialects/complex_dialect.py +++ b/mlir/test/python/dialects/complex_dialect.py @@ -9,24 +9,24 @@ import mlir.dialects.complex as mlir_complex def run(f): - print("\nTEST:", f.__name__) - f() + print("\nTEST:", f.__name__) + f() # CHECK-LABEL: TEST: testComplexOps @run def testComplexOps(): - with Context() as ctx, Location.unknown(): - module = Module.create() - with InsertionPoint(module.body): - - 
@func.FuncOp.from_py_func(ComplexType.get(F32Type.get())) - def emit_add(arg): - return mlir_complex.AddOp(arg, arg) - - # CHECK-LABEL: func @emit_add( - # CHECK-SAME: %[[ARG:.*]]: complex) -> complex { - # CHECK: %[[RES:.*]] = complex.add %[[ARG]], %[[ARG]] : complex - # CHECK: return %[[RES]] : complex - # CHECK: } - print(module) + with Context() as ctx, Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func(ComplexType.get(F32Type.get())) + def emit_add(arg): + return mlir_complex.AddOp(arg, arg) + + # CHECK-LABEL: func @emit_add( + # CHECK-SAME: %[[ARG:.*]]: complex) -> complex { + # CHECK: %[[RES:.*]] = complex.add %[[ARG]], %[[ARG]] : complex + # CHECK: return %[[RES]] : complex + # CHECK: } + print(module) diff --git a/mlir/test/python/dialects/func.py b/mlir/test/python/dialects/func.py index 3be9cac..161a12d 100644 --- a/mlir/test/python/dialects/func.py +++ b/mlir/test/python/dialects/func.py @@ -7,13 +7,13 @@ from mlir.dialects import func def constructAndPrintInModule(f): - print("\nTEST:", f.__name__) - with Context(), Location.unknown(): - module = Module.create() - with InsertionPoint(module.body): - f() - print(module) - return f + print("\nTEST:", f.__name__) + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + f() + print(module) + return f # CHECK-LABEL: TEST: testConstantOp @@ -21,21 +21,21 @@ def constructAndPrintInModule(f): @constructAndPrintInModule def testConstantOp(): - c1 = arith.ConstantOp(IntegerType.get_signless(32), 42) - c2 = arith.ConstantOp(IntegerType.get_signless(64), 100) - c3 = arith.ConstantOp(F32Type.get(), 3.14) - c4 = arith.ConstantOp(F64Type.get(), 1.23) - # CHECK: 42 - print(c1.literal_value) + c1 = arith.ConstantOp(IntegerType.get_signless(32), 42) + c2 = arith.ConstantOp(IntegerType.get_signless(64), 100) + c3 = arith.ConstantOp(F32Type.get(), 3.14) + c4 = arith.ConstantOp(F64Type.get(), 1.23) + # CHECK: 42 + print(c1.literal_value) - # CHECK: 100 - print(c2.literal_value) + # CHECK: 100 + print(c2.literal_value) - # CHECK: 3.140000104904175 - print(c3.literal_value) + # CHECK: 3.140000104904175 + print(c3.literal_value) - # CHECK: 1.23 - print(c4.literal_value) + # CHECK: 1.23 + print(c4.literal_value) # CHECK: = arith.constant 42 : i32 @@ -47,17 +47,17 @@ def testConstantOp(): # CHECK-LABEL: TEST: testVectorConstantOp @constructAndPrintInModule def testVectorConstantOp(): - int_type = IntegerType.get_signless(32) - vec_type = VectorType.get([2, 2], int_type) - c1 = arith.ConstantOp( - vec_type, - DenseElementsAttr.get_splat(vec_type, IntegerAttr.get(int_type, 42))) - try: - print(c1.literal_value) - except ValueError as e: - assert "only integer and float constants have literal values" in str(e) - else: - assert False + int_type = IntegerType.get_signless(32) + vec_type = VectorType.get([2, 2], int_type) + c1 = arith.ConstantOp( + vec_type, DenseElementsAttr.get_splat(vec_type, IntegerAttr.get(int_type, 42)) + ) + try: + print(c1.literal_value) + except ValueError as e: + assert "only integer and float constants have literal values" in str(e) + else: + assert False # CHECK: = arith.constant dense<42> : vector<2x2xi32> @@ -66,9 +66,9 @@ def testVectorConstantOp(): # CHECK-LABEL: TEST: testConstantIndexOp @constructAndPrintInModule def testConstantIndexOp(): - c1 = arith.ConstantOp.create_index(10) - # CHECK: 10 - print(c1.literal_value) + c1 = arith.ConstantOp.create_index(10) + # CHECK: 10 + print(c1.literal_value) # CHECK: = 
arith.constant 10 : index @@ -77,18 +77,18 @@ def testConstantIndexOp(): # CHECK-LABEL: TEST: testFunctionCalls @constructAndPrintInModule def testFunctionCalls(): - foo = func.FuncOp("foo", ([], [])) - foo.sym_visibility = StringAttr.get("private") - bar = func.FuncOp("bar", ([], [IndexType.get()])) - bar.sym_visibility = StringAttr.get("private") - qux = func.FuncOp("qux", ([], [F32Type.get()])) - qux.sym_visibility = StringAttr.get("private") - - with InsertionPoint(func.FuncOp("caller", ([], [])).add_entry_block()): - func.CallOp(foo, []) - func.CallOp([IndexType.get()], "bar", []) - func.CallOp([F32Type.get()], FlatSymbolRefAttr.get("qux"), []) - func.ReturnOp([]) + foo = func.FuncOp("foo", ([], [])) + foo.sym_visibility = StringAttr.get("private") + bar = func.FuncOp("bar", ([], [IndexType.get()])) + bar.sym_visibility = StringAttr.get("private") + qux = func.FuncOp("qux", ([], [F32Type.get()])) + qux.sym_visibility = StringAttr.get("private") + + with InsertionPoint(func.FuncOp("caller", ([], [])).add_entry_block()): + func.CallOp(foo, []) + func.CallOp([IndexType.get()], "bar", []) + func.CallOp([F32Type.get()], FlatSymbolRefAttr.get("qux"), []) + func.ReturnOp([]) # CHECK: func private @foo() diff --git a/mlir/test/python/dialects/gpu.py b/mlir/test/python/dialects/gpu.py index 38bf038..7eefaed 100644 --- a/mlir/test/python/dialects/gpu.py +++ b/mlir/test/python/dialects/gpu.py @@ -5,14 +5,17 @@ import mlir.dialects.gpu import mlir.dialects.gpu.passes from mlir.passmanager import * + def run(f): - print("\nTEST:", f.__name__) - f() + print("\nTEST:", f.__name__) + f() + def testGPUPass(): - with Context() as context: - PassManager.parse('any(gpu-kernel-outlining)') - print('SUCCESS') + with Context() as context: + PassManager.parse("any(gpu-kernel-outlining)") + print("SUCCESS") + # CHECK-LABEL: testGPUPass # CHECK: SUCCESS diff --git a/mlir/test/python/dialects/linalg/opdsl/arguments.py b/mlir/test/python/dialects/linalg/opdsl/arguments.py index d787c5f..7892d02 100644 --- a/mlir/test/python/dialects/linalg/opdsl/arguments.py +++ b/mlir/test/python/dialects/linalg/opdsl/arguments.py @@ -34,8 +34,9 @@ def matmul( C=TensorDef(U, S.M, S.N, output=True), bfn=BinaryFnAttrDef(default=BinaryFn.mul), ufn=UnaryFnAttrDef(default=UnaryFn.exp), - cast=TypeFnAttrDef(default=TypeFn.cast_signed)): - C[D.m, D.n] += bfn(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n])) + cast=TypeFnAttrDef(default=TypeFn.cast_signed), +): + C[D.m, D.n] += bfn(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n])) # CHECK: --- @@ -47,7 +48,7 @@ def matmul( # CHECK: type_var: T @linalg_structured_op def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)): - O[D.m, D.n] = value + O[D.m, D.n] = value # CHECK: --- @@ -71,5 +72,6 @@ def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)): def strided_copy( I=TensorDef(T, S.IH, S.IW), O=TensorDef(T, S.OH, S.OW, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 2])): - O[D.oh, D.ow] = I[D.oh * S.SH, D.ow * S.SW] + strides=IndexAttrDef(S.SH, S.SW, default=[1, 2]), +): + O[D.oh, D.ow] = I[D.oh * S.SH, D.ow * S.SW] diff --git a/mlir/test/python/dialects/linalg/opdsl/assignments.py b/mlir/test/python/dialects/linalg/opdsl/assignments.py index eacf435..ad0a3ea 100644 --- a/mlir/test/python/dialects/linalg/opdsl/assignments.py +++ b/mlir/test/python/dialects/linalg/opdsl/assignments.py @@ -35,8 +35,9 @@ def matmul( B=TensorDef(T, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True), mul=BinaryFnAttrDef(default=BinaryFn.mul), - 
cast=TypeFnAttrDef(default=TypeFn.cast_signed)): - C[D.m, D.n] += mul(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n])) + cast=TypeFnAttrDef(default=TypeFn.cast_signed), +): + C[D.m, D.n] += mul(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n])) # CHECK: --- @@ -79,12 +80,12 @@ def matmul( # CHECK: scalar_const: '1.{{[0]*}}e+03 : f64' @linalg_structured_op def constants( - O=TensorDef(T, S.M, S.K, output=True), - exp=UnaryFnAttrDef(default=UnaryFn.exp)): - pi = TypeFn.cast_signed(T, const(3.1415926535897931)) - cst42 = TypeFn.cast_signed(T, const(42)) - cst1000 = TypeFn.cast_signed(T, exp(const(1e+3))) - O[D.m, D.n] = UnaryFn.exp(pi) + cst42 - cst1000 + O=TensorDef(T, S.M, S.K, output=True), exp=UnaryFnAttrDef(default=UnaryFn.exp) +): + pi = TypeFn.cast_signed(T, const(3.1415926535897931)) + cst42 = TypeFn.cast_signed(T, const(42)) + cst1000 = TypeFn.cast_signed(T, exp(const(1e3))) + O[D.m, D.n] = UnaryFn.exp(pi) + cst42 - cst1000 # CHECK: --- @@ -100,7 +101,7 @@ def constants( # CHECK: scalar_index: 0 @linalg_structured_op def indices(O=TensorDef(T, S.M, S.K, output=True)): - O[D.m, D.n] = index(D.n) + index(D.m) + O[D.m, D.n] = index(D.n) + index(D.m) # CHECK: --- @@ -111,4 +112,4 @@ def indices(O=TensorDef(T, S.M, S.K, output=True)): # CHECK: scalar_arg: value @linalg_structured_op def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)): - O[D.m, D.n] = value + O[D.m, D.n] = value diff --git a/mlir/test/python/dialects/linalg/opdsl/doctests.py b/mlir/test/python/dialects/linalg/opdsl/doctests.py index 4aae768..d2f9cec 100644 --- a/mlir/test/python/dialects/linalg/opdsl/doctests.py +++ b/mlir/test/python/dialects/linalg/opdsl/doctests.py @@ -3,10 +3,11 @@ import doctest import importlib + def test_module(module_name): - print(f"--- Testing module: {module_name}") - m = importlib.import_module(module_name) - doctest.testmod(m, verbose=True, raise_on_error=True, report=True) + print(f"--- Testing module: {module_name}") + m = importlib.import_module(module_name) + doctest.testmod(m, verbose=True, raise_on_error=True, report=True) test_module("mlir.dialects.linalg.opdsl.lang.affine") diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py b/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py index ebe2c0f..d666d31 100644 --- a/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py @@ -17,43 +17,44 @@ def conv_poly( K=TensorDef(T2, S.KH, S.KW, S.C), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 2])): - domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.c]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c]) + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 2]), +): + domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c]) with Context() as ctx, Location.unknown(): - module = Module.create() - f32 = F32Type.get() - i32 = IntegerType.get_signless(32) - with InsertionPoint(module.body): - - # Convolution indexing maps. 
- # CHECK: #[[$CONV_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d3, d2 * 4 + d4 * 2, d5)> - # CHECK: #[[$CONV_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> - # CHECK: #[[$CONV_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)> - - # CHECK-LABEL: @test_f32i32_conv - # CHECK: linalg.generic - # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$CONV_MAP_K]], #[[$CONV_MAP_O]]] - # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"] - # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[FILTER:.+]]: f32, %[[OUT:.+]]: i32) - # CHECK-NEXT: %[[IN_CAST:.+]] = arith.fptosi %[[IN:.+]] : f32 to i32 - # CHECK-NEXT: %[[FILTER_CAST:.+]] = arith.fptosi %[[FILTER:.+]] : f32 to i32 - # CHECK-NEXT: %[[PROD:.+]] = arith.muli %[[IN_CAST]], %[[FILTER_CAST]] : i32 - # CHECK-NEXT: %[[SUM:.+]] = arith.addi %[[OUT]], %[[PROD]] : i32 - # CHECK-NEXT: linalg.yield %[[SUM]] : i32 - # CHECK-NEXT: -> tensor<1x2x4x1xi32> - @func.FuncOp.from_py_func( - RankedTensorType.get((1, 4, 16, 1), f32), - RankedTensorType.get((2, 2, 1), f32), - RankedTensorType.get((1, 2, 4, 1), i32)) - def test_f32i32_conv(input, filter, init_result): - # Use default dilations and set non-default strides. - return conv_poly( - input, filter, outs=[init_result], strides=[2, 4]) + module = Module.create() + f32 = F32Type.get() + i32 = IntegerType.get_signless(32) + with InsertionPoint(module.body): + + # Convolution indexing maps. + # CHECK: #[[$CONV_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d3, d2 * 4 + d4 * 2, d5)> + # CHECK: #[[$CONV_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> + # CHECK: #[[$CONV_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)> + + # CHECK-LABEL: @test_f32i32_conv + # CHECK: linalg.generic + # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$CONV_MAP_K]], #[[$CONV_MAP_O]]] + # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"] + # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[FILTER:.+]]: f32, %[[OUT:.+]]: i32) + # CHECK-NEXT: %[[IN_CAST:.+]] = arith.fptosi %[[IN:.+]] : f32 to i32 + # CHECK-NEXT: %[[FILTER_CAST:.+]] = arith.fptosi %[[FILTER:.+]] : f32 to i32 + # CHECK-NEXT: %[[PROD:.+]] = arith.muli %[[IN_CAST]], %[[FILTER_CAST]] : i32 + # CHECK-NEXT: %[[SUM:.+]] = arith.addi %[[OUT]], %[[PROD]] : i32 + # CHECK-NEXT: linalg.yield %[[SUM]] : i32 + # CHECK-NEXT: -> tensor<1x2x4x1xi32> + @func.FuncOp.from_py_func( + RankedTensorType.get((1, 4, 16, 1), f32), + RankedTensorType.get((2, 2, 1), f32), + RankedTensorType.get((1, 2, 4, 1), i32), + ) + def test_f32i32_conv(input, filter, init_result): + # Use default dilations and set non-default strides. 
+ return conv_poly(input, filter, outs=[init_result], strides=[2, 4]) print(module) diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_fill.py b/mlir/test/python/dialects/linalg/opdsl/emit_fill.py index 1f840b0..ffef737 100644 --- a/mlir/test/python/dialects/linalg/opdsl/emit_fill.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_fill.py @@ -13,47 +13,51 @@ T2 = TV.T2 @linalg_structured_op def fill_poly(value=ScalarDef(T1), O=TensorDef(U, output=True)): - O[None] = TypeFn.cast_signed(U, value) + O[None] = TypeFn.cast_signed(U, value) + @linalg_structured_op def fill_rank_zero_poly(I=TensorDef(T1), O=TensorDef(U, output=True)): - O[None] = TypeFn.cast_signed(U, I[None]) + O[None] = TypeFn.cast_signed(U, I[None]) + with Context() as ctx, Location.unknown(): - module = Module.create() - f32 = F32Type.get() - with InsertionPoint(module.body): - - # Fill indexing maps. - # CHECK-DAG: #[[$MAP0:.+]] = affine_map<() -> ()> - # CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()> - # CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)> - # CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2) -> ()> - # CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> - - # CHECK-LABEL: @test_fill_0d - # CHECK: linalg.generic - # CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]] - # CHECK-SAME: iterator_types = [] - @func.FuncOp.from_py_func(f32, RankedTensorType.get([], f32)) - def test_fill_0d(value, init_result): - return fill_poly(value, outs=[init_result]) - - # CHECK-LABEL: @test_fill_2d - # CHECK: linalg.generic - # CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]]] - # CHECK-SAME: iterator_types = ["parallel", "parallel"] - @func.FuncOp.from_py_func(f32, RankedTensorType.get([4, 16], f32)) - def test_fill_2d(value, init_result): - return fill_poly(value, outs=[init_result]) - - # CHECK-LABEL: @test_fill_rank_zero_3d - # CHECK: linalg.generic - # CHECK-SAME: indexing_maps = [#[[$MAP3]], #[[$MAP4]]] - # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] - @func.FuncOp.from_py_func( - RankedTensorType.get([], f32), RankedTensorType.get([4, 8, 16], f32)) - def test_fill_rank_zero_3d(input, init_result): - return fill_rank_zero_poly(input, outs=[init_result]) + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + + # Fill indexing maps. 
+ # CHECK-DAG: #[[$MAP0:.+]] = affine_map<() -> ()> + # CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()> + # CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)> + # CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2) -> ()> + # CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> + + # CHECK-LABEL: @test_fill_0d + # CHECK: linalg.generic + # CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]] + # CHECK-SAME: iterator_types = [] + @func.FuncOp.from_py_func(f32, RankedTensorType.get([], f32)) + def test_fill_0d(value, init_result): + return fill_poly(value, outs=[init_result]) + + # CHECK-LABEL: @test_fill_2d + # CHECK: linalg.generic + # CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]]] + # CHECK-SAME: iterator_types = ["parallel", "parallel"] + @func.FuncOp.from_py_func(f32, RankedTensorType.get([4, 16], f32)) + def test_fill_2d(value, init_result): + return fill_poly(value, outs=[init_result]) + + # CHECK-LABEL: @test_fill_rank_zero_3d + # CHECK: linalg.generic + # CHECK-SAME: indexing_maps = [#[[$MAP3]], #[[$MAP4]]] + # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] + @func.FuncOp.from_py_func( + RankedTensorType.get([], f32), RankedTensorType.get([4, 8, 16], f32) + ) + def test_fill_rank_zero_3d(input, init_result): + return fill_rank_zero_poly(input, outs=[init_result]) + print(module) diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py index 6dff754..18c237c 100644 --- a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py @@ -16,9 +16,10 @@ T2 = TV.T2 def matmul_mono( A=TensorDef(T, S.M, S.K), B=TensorDef(T, S.K, S.N), - C=TensorDef(T, S.M, S.N, output=True)): - domain(D.m, D.n, D.k) - C[D.m, D.n] += A[D.m, D.k] * B[D.k, D.n] + C=TensorDef(T, S.M, S.N, output=True), +): + domain(D.m, D.n, D.k) + C[D.m, D.n] += A[D.m, D.k] * B[D.k, D.n] @linalg_structured_op @@ -26,146 +27,162 @@ def matmul_poly( A=TensorDef(T1, S.M, S.K), B=TensorDef(T2, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True), - cast=TypeFnAttrDef(default=TypeFn.cast_signed)): - domain(D.m, D.n, D.k) - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) + cast=TypeFnAttrDef(default=TypeFn.cast_signed), +): + domain(D.m, D.n, D.k) + C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) with Context() as ctx, Location.unknown(): - module = Module.create() - f16 = F16Type.get() - f32 = F32Type.get() - f64 = F64Type.get() - i8 = IntegerType.get_signless(8) - i16 = IntegerType.get_signless(16) - i32 = IntegerType.get_signless(32) - with InsertionPoint(module.body): - - # Multiplication indexing maps. We verify only the indexing maps of the - # first multiplication and then do additional tests on casting and body - # generation behavior. 
- # CHECK: #[[$MUL_MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> - # CHECK: #[[$MUL_MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)> - # CHECK: #[[$MUL_MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> - - # CHECK-LABEL: func @test_matmul_mono - # CHECK-SAME: %[[A:.+]]: tensor<4x16xf32> - # CHECK-SAME: %[[B:.+]]: tensor<16x8xf32> - # CHECK: %[[INITC:.+]] = tensor.empty() : tensor<4x8xf32> - # CHECK: linalg.generic - # CHECK-SAME: indexing_maps = [#[[$MUL_MAP_A]], #[[$MUL_MAP_B]], #[[$MUL_MAP_C]]] - # CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] - # CHECK-SAME: ins(%[[A]], %[[B]] - # CHECK-SAME: outs(%[[INITC]] - @func.FuncOp.from_py_func( - RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)) - def test_matmul_mono(lhs, rhs): - init_result = tensor.EmptyOp([4, 8], f32) - return matmul_mono(lhs, rhs, outs=[init_result.result]) - - # CHECK-LABEL: @test_i8i8i32_matmul - # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: i32) - # CHECK-NEXT: %[[A_CAST:.+]] = arith.extsi %[[A_ARG]] : i8 to i32 - # CHECK-NEXT: %[[B_CAST:.+]] = arith.extsi %[[B_ARG]] : i8 to i32 - # CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i32 - # CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i32 - # CHECK-NEXT: linalg.yield %[[ADD]] : i32 - # CHECK-NEXT: -> tensor<4x8xi32> - @func.FuncOp.from_py_func( - RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8), - RankedTensorType.get((4, 8), i32)) - def test_i8i8i32_matmul(lhs, rhs, init_result): - return matmul_poly(lhs, rhs, outs=[init_result]) - - # CHECK-LABEL: @test_i8i8i32_matmul_unsigned - # CHECK: = arith.extui - # CHECK: = arith.extui - @func.FuncOp.from_py_func( - RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8), - RankedTensorType.get((4, 8), i32)) - def test_i8i8i32_matmul_unsigned(lhs, rhs, init_result): - return matmul_poly( - lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned) - - # CHECK-LABEL: @test_i8i16i32_matmul - # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32) - # CHECK-NEXT: %[[A_CAST:.+]] = arith.extsi %[[A_ARG]] : i8 to i32 - # CHECK-NEXT: %[[B_CAST:.+]] = arith.extsi %[[B_ARG]] : i16 to i32 - # CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i32 - # CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i32 - # CHECK-NEXT: linalg.yield %[[ADD]] : i32 - # CHECK-NEXT: -> tensor<4x8xi32> - @func.FuncOp.from_py_func( - RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i16), - RankedTensorType.get((4, 8), i32)) - def test_i8i16i32_matmul(lhs, rhs, init_result): - return matmul_poly(lhs, rhs, outs=[init_result]) - - # CHECK-LABEL: @test_i32i32i16_matmul - # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i16) - # CHECK-NEXT: %[[A_CAST:.+]] = arith.trunci %[[A_ARG]] : i32 to i16 - # CHECK-NEXT: %[[B_CAST:.+]] = arith.trunci %[[B_ARG]] : i32 to i16 - # CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i16 - # CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i16 - # CHECK-NEXT: linalg.yield %[[ADD]] : i16 - # CHECK-NEXT: -> tensor<4x8xi16> - @func.FuncOp.from_py_func( - RankedTensorType.get((4, 16), i32), RankedTensorType.get((16, 8), i32), - RankedTensorType.get((4, 8), i16)) - def test_i32i32i16_matmul(lhs, rhs, init_result): - return matmul_poly(lhs, rhs, outs=[init_result]) - - # CHECK-LABEL: @test_i8i8f32_matmul - # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: f32) - # CHECK-NEXT: 
%[[A_CAST:.+]] = arith.sitofp %[[A_ARG]] : i8 to f32 - # CHECK-NEXT: %[[B_CAST:.+]] = arith.sitofp %[[B_ARG]] : i8 to f32 - # CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32 - # CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32 - # CHECK-NEXT: linalg.yield %[[ADD]] : f32 - # CHECK-NEXT: -> tensor<4x8xf32> - @func.FuncOp.from_py_func( - RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8), - RankedTensorType.get((4, 8), f32)) - def test_i8i8f32_matmul(lhs, rhs, init_result): - return matmul_poly(lhs, rhs, outs=[init_result]) - - # CHECK-LABEL: @test_i8i8f32_matmul_unsigned - # CHECK: = arith.uitofp - # CHECK: = arith.uitofp - @func.FuncOp.from_py_func( - RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8), - RankedTensorType.get((4, 8), f32)) - def test_i8i8f32_matmul_unsigned(lhs, rhs, init_result): - return matmul_poly( - lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned) - - # CHECK-LABEL: @test_f16f16f32_matmul - # CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32) - # CHECK-NEXT: %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32 - # CHECK-NEXT: %[[B_CAST:.+]] = arith.extf %[[B_ARG]] : f16 to f32 - # CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32 - # CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32 - # CHECK-NEXT: linalg.yield %[[ADD]] : f32 - # CHECK-NEXT: -> tensor<4x8xf32> - @func.FuncOp.from_py_func( - RankedTensorType.get((4, 16), f16), RankedTensorType.get((16, 8), f16), - RankedTensorType.get((4, 8), f32)) - def test_f16f16f32_matmul(lhs, rhs, init_result): - return matmul_poly(lhs, rhs, outs=[init_result]) - - # CHECK-LABEL: @test_f64f64f32_matmul - # CHECK: ^{{.*}}(%[[A_ARG:.+]]: f64, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32) - # CHECK-NEXT: %[[A_CAST:.+]] = arith.truncf %[[A_ARG]] : f64 to f32 - # CHECK-NEXT: %[[B_CAST:.+]] = arith.truncf %[[B_ARG]] : f64 to f32 - # CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32 - # CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32 - # CHECK-NEXT: linalg.yield %[[ADD]] : f32 - # CHECK-NEXT: -> tensor<4x8xf32> - @func.FuncOp.from_py_func( - RankedTensorType.get((4, 16), f64), RankedTensorType.get((16, 8), f64), - RankedTensorType.get((4, 8), f32)) - def test_f64f64f32_matmul(lhs, rhs, init_result): - return matmul_poly(lhs, rhs, outs=[init_result]) + module = Module.create() + f16 = F16Type.get() + f32 = F32Type.get() + f64 = F64Type.get() + i8 = IntegerType.get_signless(8) + i16 = IntegerType.get_signless(16) + i32 = IntegerType.get_signless(32) + with InsertionPoint(module.body): + + # Multiplication indexing maps. We verify only the indexing maps of the + # first multiplication and then do additional tests on casting and body + # generation behavior. 
+ # CHECK: #[[$MUL_MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> + # CHECK: #[[$MUL_MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)> + # CHECK: #[[$MUL_MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> + + # CHECK-LABEL: func @test_matmul_mono + # CHECK-SAME: %[[A:.+]]: tensor<4x16xf32> + # CHECK-SAME: %[[B:.+]]: tensor<16x8xf32> + # CHECK: %[[INITC:.+]] = tensor.empty() : tensor<4x8xf32> + # CHECK: linalg.generic + # CHECK-SAME: indexing_maps = [#[[$MUL_MAP_A]], #[[$MUL_MAP_B]], #[[$MUL_MAP_C]]] + # CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] + # CHECK-SAME: ins(%[[A]], %[[B]] + # CHECK-SAME: outs(%[[INITC]] + @func.FuncOp.from_py_func( + RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32) + ) + def test_matmul_mono(lhs, rhs): + init_result = tensor.EmptyOp([4, 8], f32) + return matmul_mono(lhs, rhs, outs=[init_result.result]) + + # CHECK-LABEL: @test_i8i8i32_matmul + # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: i32) + # CHECK-NEXT: %[[A_CAST:.+]] = arith.extsi %[[A_ARG]] : i8 to i32 + # CHECK-NEXT: %[[B_CAST:.+]] = arith.extsi %[[B_ARG]] : i8 to i32 + # CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i32 + # CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i32 + # CHECK-NEXT: linalg.yield %[[ADD]] : i32 + # CHECK-NEXT: -> tensor<4x8xi32> + @func.FuncOp.from_py_func( + RankedTensorType.get((4, 16), i8), + RankedTensorType.get((16, 8), i8), + RankedTensorType.get((4, 8), i32), + ) + def test_i8i8i32_matmul(lhs, rhs, init_result): + return matmul_poly(lhs, rhs, outs=[init_result]) + + # CHECK-LABEL: @test_i8i8i32_matmul_unsigned + # CHECK: = arith.extui + # CHECK: = arith.extui + @func.FuncOp.from_py_func( + RankedTensorType.get((4, 16), i8), + RankedTensorType.get((16, 8), i8), + RankedTensorType.get((4, 8), i32), + ) + def test_i8i8i32_matmul_unsigned(lhs, rhs, init_result): + return matmul_poly(lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned) + + # CHECK-LABEL: @test_i8i16i32_matmul + # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32) + # CHECK-NEXT: %[[A_CAST:.+]] = arith.extsi %[[A_ARG]] : i8 to i32 + # CHECK-NEXT: %[[B_CAST:.+]] = arith.extsi %[[B_ARG]] : i16 to i32 + # CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i32 + # CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i32 + # CHECK-NEXT: linalg.yield %[[ADD]] : i32 + # CHECK-NEXT: -> tensor<4x8xi32> + @func.FuncOp.from_py_func( + RankedTensorType.get((4, 16), i8), + RankedTensorType.get((16, 8), i16), + RankedTensorType.get((4, 8), i32), + ) + def test_i8i16i32_matmul(lhs, rhs, init_result): + return matmul_poly(lhs, rhs, outs=[init_result]) + + # CHECK-LABEL: @test_i32i32i16_matmul + # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i16) + # CHECK-NEXT: %[[A_CAST:.+]] = arith.trunci %[[A_ARG]] : i32 to i16 + # CHECK-NEXT: %[[B_CAST:.+]] = arith.trunci %[[B_ARG]] : i32 to i16 + # CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[A_CAST]], %[[B_CAST]] : i16 + # CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[C_ARG]], %[[MUL]] : i16 + # CHECK-NEXT: linalg.yield %[[ADD]] : i16 + # CHECK-NEXT: -> tensor<4x8xi16> + @func.FuncOp.from_py_func( + RankedTensorType.get((4, 16), i32), + RankedTensorType.get((16, 8), i32), + RankedTensorType.get((4, 8), i16), + ) + def test_i32i32i16_matmul(lhs, rhs, init_result): + return matmul_poly(lhs, rhs, outs=[init_result]) + + # CHECK-LABEL: @test_i8i8f32_matmul + # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: f32) + 
# CHECK-NEXT: %[[A_CAST:.+]] = arith.sitofp %[[A_ARG]] : i8 to f32 + # CHECK-NEXT: %[[B_CAST:.+]] = arith.sitofp %[[B_ARG]] : i8 to f32 + # CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32 + # CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32 + # CHECK-NEXT: linalg.yield %[[ADD]] : f32 + # CHECK-NEXT: -> tensor<4x8xf32> + @func.FuncOp.from_py_func( + RankedTensorType.get((4, 16), i8), + RankedTensorType.get((16, 8), i8), + RankedTensorType.get((4, 8), f32), + ) + def test_i8i8f32_matmul(lhs, rhs, init_result): + return matmul_poly(lhs, rhs, outs=[init_result]) + + # CHECK-LABEL: @test_i8i8f32_matmul_unsigned + # CHECK: = arith.uitofp + # CHECK: = arith.uitofp + @func.FuncOp.from_py_func( + RankedTensorType.get((4, 16), i8), + RankedTensorType.get((16, 8), i8), + RankedTensorType.get((4, 8), f32), + ) + def test_i8i8f32_matmul_unsigned(lhs, rhs, init_result): + return matmul_poly(lhs, rhs, outs=[init_result], cast=TypeFn.cast_unsigned) + + # CHECK-LABEL: @test_f16f16f32_matmul + # CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32) + # CHECK-NEXT: %[[A_CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32 + # CHECK-NEXT: %[[B_CAST:.+]] = arith.extf %[[B_ARG]] : f16 to f32 + # CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32 + # CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32 + # CHECK-NEXT: linalg.yield %[[ADD]] : f32 + # CHECK-NEXT: -> tensor<4x8xf32> + @func.FuncOp.from_py_func( + RankedTensorType.get((4, 16), f16), + RankedTensorType.get((16, 8), f16), + RankedTensorType.get((4, 8), f32), + ) + def test_f16f16f32_matmul(lhs, rhs, init_result): + return matmul_poly(lhs, rhs, outs=[init_result]) + + # CHECK-LABEL: @test_f64f64f32_matmul + # CHECK: ^{{.*}}(%[[A_ARG:.+]]: f64, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32) + # CHECK-NEXT: %[[A_CAST:.+]] = arith.truncf %[[A_ARG]] : f64 to f32 + # CHECK-NEXT: %[[B_CAST:.+]] = arith.truncf %[[B_ARG]] : f64 to f32 + # CHECK-NEXT: %[[MUL:.+]] = arith.mulf %[[A_CAST]], %[[B_CAST]] : f32 + # CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[C_ARG]], %[[MUL]] : f32 + # CHECK-NEXT: linalg.yield %[[ADD]] : f32 + # CHECK-NEXT: -> tensor<4x8xf32> + @func.FuncOp.from_py_func( + RankedTensorType.get((4, 16), f64), + RankedTensorType.get((16, 8), f64), + RankedTensorType.get((4, 8), f32), + ) + def test_f64f64f32_matmul(lhs, rhs, init_result): + return matmul_poly(lhs, rhs, outs=[init_result]) print(module) diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py index aad7149..f8e034f 100644 --- a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py @@ -17,14 +17,16 @@ from mlir.dialects.linalg.opdsl.lang import * @linalg_structured_op def test_const(O=TensorDef(F32, S.M, S.N, output=True)): - O[D.m, D.n] = TypeFn.cast_unsigned(F32, const(42)) + TypeFn.cast_unsigned( - F32, const(2.3283064e-10)) + O[D.m, D.n] = TypeFn.cast_unsigned(F32, const(42)) + TypeFn.cast_unsigned( + F32, const(2.3283064e-10) + ) @linalg_structured_op def test_index(O=TensorDef(I32, S.M, S.N, output=True)): - O[D.m, D.n] = TypeFn.cast_signed(I32, index(D.m)) + TypeFn.cast_signed( - I32, index(D.n)) + O[D.m, D.n] = TypeFn.cast_signed(I32, index(D.m)) + TypeFn.cast_signed( + I32, index(D.n) + ) @linalg_structured_op @@ -32,120 +34,129 @@ def elemwise_unary_poly( I=TensorDef(T), O=TensorDef(U, output=True), fun=UnaryFnAttrDef(default=UnaryFn.exp), - cast=TypeFnAttrDef(default=TypeFn.cast_signed)): - 
O[None] = fun(cast(U, I[None])) + cast=TypeFnAttrDef(default=TypeFn.cast_signed), +): + O[None] = fun(cast(U, I[None])) @linalg_structured_op(op_name="custom_op_name") def non_default_op_name(I=TensorDef(T, S.N), O=TensorDef(T, S.N, output=True)): - O[D.n] = I[D.n] + O[D.n] = I[D.n] with Context() as ctx, Location.unknown(): - module = Module.create() - f32 = F32Type.get() - c32 = ComplexType.get(f32) - i32 = IntegerType.get_signless(32) - with InsertionPoint(module.body): - - # CHECK-LABEL: @test_f32_const - # CHECK-DAG: %[[CST0:.+]] = arith.constant 42 : i64 - # CHECK-DAG: %[[CST0_CAST:.+]] = arith.uitofp %[[CST0]] : i64 to f32 - # CHECK-DAG: %[[CST1:.+]] = arith.constant 2.3283063999999999E-10 : f64 - # CHECK-DAG: %[[CST1_CAST:.+]] = arith.truncf %[[CST1]] : f64 to f32 - # CHECK-DAG: %[[SUM:.+]] = arith.addf %[[CST0_CAST]], %[[CST1_CAST]] : f32 - # CHECK-NEXT: linalg.yield %[[SUM]] : f32 - @func.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32)) - def test_f32_const(init_result): - return test_const(outs=[init_result]) - - # CHECK-LABEL: @test_i32_index - # CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index - # CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index - # CHECK-DAG: %[[IDX0_CAST:.+]] = arith.index_cast %[[IDX0]] : index to i32 - # CHECK-DAG: %[[IDX1_CAST:.+]] = arith.index_cast %[[IDX1]] : index to i32 - # CHECK-DAG: %[[SUM:.+]] = arith.addi %[[IDX0_CAST]], %[[IDX1_CAST]] : i32 - # CHECK-NEXT: linalg.yield %[[SUM]] : i32 - @func.FuncOp.from_py_func(RankedTensorType.get((4, 16), i32)) - def test_i32_index(init_result): - return test_index(outs=[init_result]) - - # CHECK-LABEL: @test_f32_elemwise_exp - # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32) - # CHECK-NEXT: %[[EXP:.+]] = math.exp %[[IN]] : f32 - # CHECK-NEXT: linalg.yield %[[EXP]] : f32 - # CHECK-NEXT: -> tensor<4x16xf32> - @func.FuncOp.from_py_func( - RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)) - def test_f32_elemwise_exp(input, init_result): - return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.exp) - - # CHECK-LABEL: @test_f32_elemwise_log - # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32) - # CHECK-NEXT: %[[LOG:.+]] = math.log %[[IN]] : f32 - # CHECK-NEXT: linalg.yield %[[LOG]] : f32 - # CHECK-NEXT: -> tensor<4x16xf32> - @func.FuncOp.from_py_func( - RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)) - def test_f32_elemwise_log(input, init_result): - return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log) - - # CHECK-LABEL: @test_f32_elemwise_abs - # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32) - # CHECK-NEXT: %[[EXP:.+]] = math.absf %[[IN]] : f32 - # CHECK-NEXT: linalg.yield %[[EXP]] : f32 - # CHECK-NEXT: -> tensor<4x16xf32> - @func.FuncOp.from_py_func( - RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)) - def test_f32_elemwise_abs(input, init_result): - return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs) - - # CHECK-LABEL: @test_f32_elemwise_ceil - # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32) - # CHECK-NEXT: %[[EXP:.+]] = math.ceil %[[IN]] : f32 - # CHECK-NEXT: linalg.yield %[[EXP]] : f32 - # CHECK-NEXT: -> tensor<4x16xf32> - @func.FuncOp.from_py_func( - RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)) - def test_f32_elemwise_ceil(input, init_result): - return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.ceil) - - # CHECK-LABEL: @test_f32_elemwise_floor - # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32) - # CHECK-NEXT: %[[EXP:.+]] = 
math.floor %[[IN]] : f32 - # CHECK-NEXT: linalg.yield %[[EXP]] : f32 - # CHECK-NEXT: -> tensor<4x16xf32> - @func.FuncOp.from_py_func( - RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)) - def test_f32_elemwise_floor(input, init_result): - return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.floor) - - # CHECK-LABEL: @test_f32_elemwise_neg - # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32) - # CHECK-NEXT: %[[EXP:.+]] = arith.negf %[[IN]] : f32 - # CHECK-NEXT: linalg.yield %[[EXP]] : f32 - # CHECK-NEXT: -> tensor<4x16xf32> - @func.FuncOp.from_py_func( - RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)) - def test_f32_elemwise_neg(input, init_result): - return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf) - - # CHECK-LABEL: @test_c32_elemwise_neg - # CHECK: ^{{.*}}(%[[IN:.+]]: complex, %[[OUT:.+]]: complex) - # CHECK-NEXT: %[[EXP:.+]] = complex.neg %[[IN]] : complex - # CHECK-NEXT: linalg.yield %[[EXP]] : complex - # CHECK-NEXT: -> tensor<4x16xcomplex> - @func.FuncOp.from_py_func( - RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)) - def test_c32_elemwise_neg(input, init_result): - return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf) - - # Just check that we don't assert out on name mismatch. - # CHECK-LABEL: @test_non_default_op_name - @func.FuncOp.from_py_func( - RankedTensorType.get((42,), f32), RankedTensorType.get((42,), f32)) - def test_non_default_op_name(input, init_result): - return non_default_op_name(input, outs=[init_result]) + module = Module.create() + f32 = F32Type.get() + c32 = ComplexType.get(f32) + i32 = IntegerType.get_signless(32) + with InsertionPoint(module.body): + + # CHECK-LABEL: @test_f32_const + # CHECK-DAG: %[[CST0:.+]] = arith.constant 42 : i64 + # CHECK-DAG: %[[CST0_CAST:.+]] = arith.uitofp %[[CST0]] : i64 to f32 + # CHECK-DAG: %[[CST1:.+]] = arith.constant 2.3283063999999999E-10 : f64 + # CHECK-DAG: %[[CST1_CAST:.+]] = arith.truncf %[[CST1]] : f64 to f32 + # CHECK-DAG: %[[SUM:.+]] = arith.addf %[[CST0_CAST]], %[[CST1_CAST]] : f32 + # CHECK-NEXT: linalg.yield %[[SUM]] : f32 + @func.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32)) + def test_f32_const(init_result): + return test_const(outs=[init_result]) + + # CHECK-LABEL: @test_i32_index + # CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index + # CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index + # CHECK-DAG: %[[IDX0_CAST:.+]] = arith.index_cast %[[IDX0]] : index to i32 + # CHECK-DAG: %[[IDX1_CAST:.+]] = arith.index_cast %[[IDX1]] : index to i32 + # CHECK-DAG: %[[SUM:.+]] = arith.addi %[[IDX0_CAST]], %[[IDX1_CAST]] : i32 + # CHECK-NEXT: linalg.yield %[[SUM]] : i32 + @func.FuncOp.from_py_func(RankedTensorType.get((4, 16), i32)) + def test_i32_index(init_result): + return test_index(outs=[init_result]) + + # CHECK-LABEL: @test_f32_elemwise_exp + # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32) + # CHECK-NEXT: %[[EXP:.+]] = math.exp %[[IN]] : f32 + # CHECK-NEXT: linalg.yield %[[EXP]] : f32 + # CHECK-NEXT: -> tensor<4x16xf32> + @func.FuncOp.from_py_func( + RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32) + ) + def test_f32_elemwise_exp(input, init_result): + return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.exp) + + # CHECK-LABEL: @test_f32_elemwise_log + # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32) + # CHECK-NEXT: %[[LOG:.+]] = math.log %[[IN]] : f32 + # CHECK-NEXT: linalg.yield %[[LOG]] : f32 + # CHECK-NEXT: -> tensor<4x16xf32> + 
@func.FuncOp.from_py_func( + RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32) + ) + def test_f32_elemwise_log(input, init_result): + return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log) + + # CHECK-LABEL: @test_f32_elemwise_abs + # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32) + # CHECK-NEXT: %[[EXP:.+]] = math.absf %[[IN]] : f32 + # CHECK-NEXT: linalg.yield %[[EXP]] : f32 + # CHECK-NEXT: -> tensor<4x16xf32> + @func.FuncOp.from_py_func( + RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32) + ) + def test_f32_elemwise_abs(input, init_result): + return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs) + + # CHECK-LABEL: @test_f32_elemwise_ceil + # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32) + # CHECK-NEXT: %[[EXP:.+]] = math.ceil %[[IN]] : f32 + # CHECK-NEXT: linalg.yield %[[EXP]] : f32 + # CHECK-NEXT: -> tensor<4x16xf32> + @func.FuncOp.from_py_func( + RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32) + ) + def test_f32_elemwise_ceil(input, init_result): + return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.ceil) + + # CHECK-LABEL: @test_f32_elemwise_floor + # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32) + # CHECK-NEXT: %[[EXP:.+]] = math.floor %[[IN]] : f32 + # CHECK-NEXT: linalg.yield %[[EXP]] : f32 + # CHECK-NEXT: -> tensor<4x16xf32> + @func.FuncOp.from_py_func( + RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32) + ) + def test_f32_elemwise_floor(input, init_result): + return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.floor) + + # CHECK-LABEL: @test_f32_elemwise_neg + # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32) + # CHECK-NEXT: %[[EXP:.+]] = arith.negf %[[IN]] : f32 + # CHECK-NEXT: linalg.yield %[[EXP]] : f32 + # CHECK-NEXT: -> tensor<4x16xf32> + @func.FuncOp.from_py_func( + RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32) + ) + def test_f32_elemwise_neg(input, init_result): + return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf) + + # CHECK-LABEL: @test_c32_elemwise_neg + # CHECK: ^{{.*}}(%[[IN:.+]]: complex, %[[OUT:.+]]: complex) + # CHECK-NEXT: %[[EXP:.+]] = complex.neg %[[IN]] : complex + # CHECK-NEXT: linalg.yield %[[EXP]] : complex + # CHECK-NEXT: -> tensor<4x16xcomplex> + @func.FuncOp.from_py_func( + RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32) + ) + def test_c32_elemwise_neg(input, init_result): + return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf) + + # Just check that we don't assert out on name mismatch. 
+ # CHECK-LABEL: @test_non_default_op_name + @func.FuncOp.from_py_func( + RankedTensorType.get((42,), f32), RankedTensorType.get((42,), f32) + ) + def test_non_default_op_name(input, init_result): + return non_default_op_name(input, outs=[init_result]) print(module) diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py index 2fd6338..ab049d3 100644 --- a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py @@ -19,121 +19,134 @@ def pooling_poly( reduce=BinaryFnAttrDef(default=BinaryFn.max_signed), cast=TypeFnAttrDef(default=TypeFn.cast_signed), strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): - domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, D.c] = reduce[D.kh, D.kw]( - cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.c])) + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.c] = reduce[D.kh, D.kw]( + cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) + ) with Context() as ctx, Location.unknown(): - module = Module.create() - f32 = F32Type.get() - i32 = IntegerType.get_signless(32) - with InsertionPoint(module.body): - - # Pooling indexing maps. - # CHECK: #[[$POOL_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d3, d2 * 4 + d4 * 2, d5)> - # CHECK: #[[$POOL_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)> - # CHECK: #[[$POOL_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)> - - # CHECK-LABEL: @test_f32i32_max_pooling - # CHECK: linalg.generic - # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], #[[$POOL_MAP_O]]] - # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"] - # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32) - # CHECK-NEXT: %[[IN_CAST:.+]] = arith.fptosi %[[IN:.+]] : f32 to i32 - # CHECK-NEXT: %[[MAX:.+]] = arith.maxsi %[[OUT]], %[[IN_CAST:.+]] : i32 - # CHECK-NEXT: linalg.yield %[[MAX]] : i32 - # CHECK-NEXT: -> tensor<1x2x4x1xi32> - @func.FuncOp.from_py_func( - RankedTensorType.get((1, 4, 16, 1), f32), - RankedTensorType.get((2, 2), f32), - RankedTensorType.get((1, 2, 4, 1), i32)) - def test_f32i32_max_pooling(input, shape, init_result): - return pooling_poly( - input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2]) - - # CHECK-LABEL: @test_f32i32_max_unsigned_pooling - # CHECK: = arith.fptoui - # CHECK: = arith.maxui - @func.FuncOp.from_py_func( - RankedTensorType.get((1, 4, 16, 1), f32), - RankedTensorType.get((2, 2), f32), - RankedTensorType.get((1, 2, 4, 1), i32)) - def test_f32i32_max_unsigned_pooling(input, shape, init_result): - return pooling_poly( - input, - shape, - outs=[init_result], - reduce=BinaryFn.max_unsigned, - cast=TypeFn.cast_unsigned, - strides=[2, 4], - dilations=[1, 2]) - - # CHECK-LABEL: @test_f32f32_max_pooling - # CHECK: linalg.generic - # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], #[[$POOL_MAP_O]]] - # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"] - # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: f32) - # CHECK-NEXT: %[[MAX:.+]] = arith.maxf %[[OUT]], %[[IN:.+]] : f32 - # CHECK-NEXT: linalg.yield %[[MAX]] : f32 - # CHECK-NEXT: -> tensor<1x2x4x1xf32> - @func.FuncOp.from_py_func( - 
RankedTensorType.get((1, 4, 16, 1), f32), - RankedTensorType.get((2, 2), f32), - RankedTensorType.get((1, 2, 4, 1), f32)) - def test_f32f32_max_pooling(input, shape, init_result): - return pooling_poly( - input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2]) - - # CHECK-LABEL: @test_f32i32_min_pooling - # CHECK: = arith.fptosi - # CHECK: = arith.minsi - @func.FuncOp.from_py_func( - RankedTensorType.get((1, 4, 16, 1), f32), - RankedTensorType.get((2, 2), f32), - RankedTensorType.get((1, 2, 4, 1), i32)) - def test_f32i32_min_pooling(input, shape, init_result): - return pooling_poly( - input, - shape, - outs=[init_result], - reduce=BinaryFn.min_signed, - strides=[2, 4], - dilations=[1, 2]) - - # CHECK-LABEL: @test_f32i32_min_unsigned_pooling - # CHECK: = arith.fptoui - # CHECK: = arith.minui - @func.FuncOp.from_py_func( - RankedTensorType.get((1, 4, 16, 1), f32), - RankedTensorType.get((2, 2), f32), - RankedTensorType.get((1, 2, 4, 1), i32)) - def test_f32i32_min_unsigned_pooling(input, shape, init_result): - return pooling_poly( - input, - shape, - outs=[init_result], - reduce=BinaryFn.min_unsigned, - cast=TypeFn.cast_unsigned, - strides=[2, 4], - dilations=[1, 2]) - - # CHECK-LABEL: @test_f32f32_min_pooling - # CHECK: = arith.minf - @func.FuncOp.from_py_func( - RankedTensorType.get((1, 4, 16, 1), f32), - RankedTensorType.get((2, 2), f32), - RankedTensorType.get((1, 2, 4, 1), f32)) - def test_f32f32_min_pooling(input, shape, init_result): - return pooling_poly( - input, - shape, - outs=[init_result], - reduce=BinaryFn.min_signed, - strides=[2, 4], - dilations=[1, 2]) + module = Module.create() + f32 = F32Type.get() + i32 = IntegerType.get_signless(32) + with InsertionPoint(module.body): + + # Pooling indexing maps. + # CHECK: #[[$POOL_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d3, d2 * 4 + d4 * 2, d5)> + # CHECK: #[[$POOL_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)> + # CHECK: #[[$POOL_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)> + + # CHECK-LABEL: @test_f32i32_max_pooling + # CHECK: linalg.generic + # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], #[[$POOL_MAP_O]]] + # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"] + # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32) + # CHECK-NEXT: %[[IN_CAST:.+]] = arith.fptosi %[[IN:.+]] : f32 to i32 + # CHECK-NEXT: %[[MAX:.+]] = arith.maxsi %[[OUT]], %[[IN_CAST:.+]] : i32 + # CHECK-NEXT: linalg.yield %[[MAX]] : i32 + # CHECK-NEXT: -> tensor<1x2x4x1xi32> + @func.FuncOp.from_py_func( + RankedTensorType.get((1, 4, 16, 1), f32), + RankedTensorType.get((2, 2), f32), + RankedTensorType.get((1, 2, 4, 1), i32), + ) + def test_f32i32_max_pooling(input, shape, init_result): + return pooling_poly( + input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2] + ) + + # CHECK-LABEL: @test_f32i32_max_unsigned_pooling + # CHECK: = arith.fptoui + # CHECK: = arith.maxui + @func.FuncOp.from_py_func( + RankedTensorType.get((1, 4, 16, 1), f32), + RankedTensorType.get((2, 2), f32), + RankedTensorType.get((1, 2, 4, 1), i32), + ) + def test_f32i32_max_unsigned_pooling(input, shape, init_result): + return pooling_poly( + input, + shape, + outs=[init_result], + reduce=BinaryFn.max_unsigned, + cast=TypeFn.cast_unsigned, + strides=[2, 4], + dilations=[1, 2], + ) + + # CHECK-LABEL: @test_f32f32_max_pooling + # CHECK: linalg.generic + # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], 
#[[$POOL_MAP_O]]] + # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"] + # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: f32) + # CHECK-NEXT: %[[MAX:.+]] = arith.maxf %[[OUT]], %[[IN:.+]] : f32 + # CHECK-NEXT: linalg.yield %[[MAX]] : f32 + # CHECK-NEXT: -> tensor<1x2x4x1xf32> + @func.FuncOp.from_py_func( + RankedTensorType.get((1, 4, 16, 1), f32), + RankedTensorType.get((2, 2), f32), + RankedTensorType.get((1, 2, 4, 1), f32), + ) + def test_f32f32_max_pooling(input, shape, init_result): + return pooling_poly( + input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2] + ) + + # CHECK-LABEL: @test_f32i32_min_pooling + # CHECK: = arith.fptosi + # CHECK: = arith.minsi + @func.FuncOp.from_py_func( + RankedTensorType.get((1, 4, 16, 1), f32), + RankedTensorType.get((2, 2), f32), + RankedTensorType.get((1, 2, 4, 1), i32), + ) + def test_f32i32_min_pooling(input, shape, init_result): + return pooling_poly( + input, + shape, + outs=[init_result], + reduce=BinaryFn.min_signed, + strides=[2, 4], + dilations=[1, 2], + ) + + # CHECK-LABEL: @test_f32i32_min_unsigned_pooling + # CHECK: = arith.fptoui + # CHECK: = arith.minui + @func.FuncOp.from_py_func( + RankedTensorType.get((1, 4, 16, 1), f32), + RankedTensorType.get((2, 2), f32), + RankedTensorType.get((1, 2, 4, 1), i32), + ) + def test_f32i32_min_unsigned_pooling(input, shape, init_result): + return pooling_poly( + input, + shape, + outs=[init_result], + reduce=BinaryFn.min_unsigned, + cast=TypeFn.cast_unsigned, + strides=[2, 4], + dilations=[1, 2], + ) + + # CHECK-LABEL: @test_f32f32_min_pooling + # CHECK: = arith.minf + @func.FuncOp.from_py_func( + RankedTensorType.get((1, 4, 16, 1), f32), + RankedTensorType.get((2, 2), f32), + RankedTensorType.get((1, 2, 4, 1), f32), + ) + def test_f32f32_min_pooling(input, shape, init_result): + return pooling_poly( + input, + shape, + outs=[init_result], + reduce=BinaryFn.min_signed, + strides=[2, 4], + dilations=[1, 2], + ) print(module) diff --git a/mlir/test/python/dialects/linalg/opdsl/lit.local.cfg b/mlir/test/python/dialects/linalg/opdsl/lit.local.cfg index cead85f..18d2d45 100644 --- a/mlir/test/python/dialects/linalg/opdsl/lit.local.cfg +++ b/mlir/test/python/dialects/linalg/opdsl/lit.local.cfg @@ -4,6 +4,6 @@ # Since both lit and the python bindings use the same python interpreter, # we can just check whether yaml can be imported here and exclude if not. 
try: - import yaml + import yaml except ModuleNotFoundError: - config.unsupported = True + config.unsupported = True diff --git a/mlir/test/python/dialects/linalg/opdsl/metadata.py b/mlir/test/python/dialects/linalg/opdsl/metadata.py index a7502e9..9c940e1 100644 --- a/mlir/test/python/dialects/linalg/opdsl/metadata.py +++ b/mlir/test/python/dialects/linalg/opdsl/metadata.py @@ -13,8 +13,10 @@ from mlir.dialects.linalg.opdsl.lang import * def matmul( A=TensorDef(T, S.M, S.K), B=TensorDef(T, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True)): - implements(ContractionOpInterface) - defines(Canonicalizer) - C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed( - U, B[D.k, D.n]) + C=TensorDef(U, S.M, S.N, output=True), +): + implements(ContractionOpInterface) + defines(Canonicalizer) + C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed( + U, B[D.k, D.n] + ) diff --git a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py index 871341c..4f3569b 100644 --- a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py +++ b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py @@ -22,10 +22,12 @@ from mlir.dialects.linalg.opdsl.lang import * def matmul( A=TensorDef(T, S.M, S.K), B=TensorDef(T, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True)): - domain(D.m, D.n, D.k) - C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed( - U, B[D.k, D.n]) + C=TensorDef(U, S.M, S.N, output=True), +): + domain(D.m, D.n, D.k) + C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed( + U, B[D.k, D.n] + ) # Verifies that assignment to a scalar (represented as [None]) is represented @@ -43,7 +45,7 @@ def matmul( # CHECK-NEXT: - reduction @linalg_structured_op def dot(A=TensorDef(T, S.M), B=TensorDef(T, S.M), C=TensorDef(U, output=True)): - C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m]) + C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m]) # Verifies that the index_dims of shape-only operands translate to correct @@ -64,6 +66,7 @@ def dot(A=TensorDef(T, S.M), B=TensorDef(T, S.M), C=TensorDef(U, output=True)): def pool( I=TensorDef(T, S.I), K=TensorDef(T, S.K, index_dims=[D.k]), - O=TensorDef(U, S.O, output=True)): - domain(D.o, D.k) - O[D.o] += TypeFn.cast_signed(U, I[D.o * 2 + D.k]) + O=TensorDef(U, S.O, output=True), +): + domain(D.o, D.k) + O[D.o] += TypeFn.cast_signed(U, I[D.o * 2 + D.k]) diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py index 1167abf..5e8414a 100644 --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -6,145 +6,154 @@ from mlir.ir import * def run(f): - print("\nTEST:", f.__name__) - f() - return f + print("\nTEST:", f.__name__) + f() + return f # CHECK-LABEL: TEST: testFill @run def testFill(): - with Context() as ctx, Location.unknown(): - module = Module.create() - f32 = F32Type.get() - with InsertionPoint(module.body): - # CHECK-LABEL: func @fill_tensor - # CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<12x?xf32> - # CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32 - # CHECK-NEXT: %[[RES:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : tensor<12x?xf32>) -> tensor<12x?xf32> - # CHECK-NEXT: return %[[RES]] : tensor<12x?xf32> - @func.FuncOp.from_py_func( - RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32)) - def fill_tensor(out): - zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), 
result=f32).result - return linalg.fill(zero, outs=[out]) - - # CHECK-LABEL: func @fill_buffer - # CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref<12x?xf32> - # CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32 - # CHECK-NEXT: linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : memref<12x?xf32>) - # CHECK-NEXT: return - @func.FuncOp.from_py_func( - MemRefType.get((12, ShapedType.get_dynamic_size()), f32)) - def fill_buffer(out): - zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result - linalg.fill(zero, outs=[out]) - - print(module) + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + # CHECK-LABEL: func @fill_tensor + # CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<12x?xf32> + # CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32 + # CHECK-NEXT: %[[RES:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : tensor<12x?xf32>) -> tensor<12x?xf32> + # CHECK-NEXT: return %[[RES]] : tensor<12x?xf32> + @func.FuncOp.from_py_func( + RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32) + ) + def fill_tensor(out): + zero = arith.ConstantOp( + value=FloatAttr.get(f32, 0.0), result=f32 + ).result + return linalg.fill(zero, outs=[out]) + + # CHECK-LABEL: func @fill_buffer + # CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref<12x?xf32> + # CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32 + # CHECK-NEXT: linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : memref<12x?xf32>) + # CHECK-NEXT: return + @func.FuncOp.from_py_func( + MemRefType.get((12, ShapedType.get_dynamic_size()), f32) + ) + def fill_buffer(out): + zero = arith.ConstantOp( + value=FloatAttr.get(f32, 0.0), result=f32 + ).result + linalg.fill(zero, outs=[out]) + + print(module) # CHECK-LABEL: TEST: testNamedStructuredOpCustomForm @run def testNamedStructuredOpCustomForm(): - with Context() as ctx, Location.unknown(): - module = Module.create() - f32 = F32Type.get() - with InsertionPoint(module.body): - - @func.FuncOp.from_py_func( - RankedTensorType.get((4, 8), f32), RankedTensorType.get((4, 8), f32)) - def named_form(lhs, rhs): - init_result = tensor.EmptyOp([4, 8], f32) - # Check for the named form with custom format - # CHECK: linalg.elemwise_unary - # CHECK-SAME: cast = #linalg.type_fn - # CHECK-SAME: fun = #linalg.unary_fn - # CHECK-SAME: ins(%{{.*}} : tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>) - unary_result = linalg.elemwise_unary(lhs, outs=[init_result.result]) - # CHECK: linalg.elemwise_binary - # CHECK-SAME: cast = #linalg.type_fn - # CHECK-SAME: fun = #linalg.binary_fn - # CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<4x8xf32>, tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>) - # CHECK: return - binary_result = linalg.elemwise_binary( - lhs, - rhs, - outs=[init_result.result], - fun=BinaryFn.mul, - cast=TypeFn.cast_unsigned) - return unary_result, binary_result - - print(module) + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func( + RankedTensorType.get((4, 8), f32), RankedTensorType.get((4, 8), f32) + ) + def named_form(lhs, rhs): + init_result = tensor.EmptyOp([4, 8], f32) + # Check for the named form with custom format + # CHECK: linalg.elemwise_unary + # CHECK-SAME: cast = #linalg.type_fn + # CHECK-SAME: fun = #linalg.unary_fn + # CHECK-SAME: ins(%{{.*}} : tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>) + unary_result = linalg.elemwise_unary(lhs, outs=[init_result.result]) + # CHECK: linalg.elemwise_binary + # CHECK-SAME: cast = 
#linalg.type_fn + # CHECK-SAME: fun = #linalg.binary_fn + # CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<4x8xf32>, tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>) + # CHECK: return + binary_result = linalg.elemwise_binary( + lhs, + rhs, + outs=[init_result.result], + fun=BinaryFn.mul, + cast=TypeFn.cast_unsigned, + ) + return unary_result, binary_result + + print(module) # CHECK-LABEL: TEST: testNamedStructuredOpGenericForm @run def testNamedStructuredOpGenericForm(): - with Context() as ctx, Location.unknown(): - module = Module.create() - f32 = F32Type.get() - with InsertionPoint(module.body): - - @func.FuncOp.from_py_func( - RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), - f32)) - def named_form(lhs, rhs): - init_result = tensor.EmptyOp([4, 8], f32) - # CHECK: "linalg.matmul"(%{{.*}}) - # CHECK-SAME: cast = #linalg.type_fn - # CHECK-SAME: operand_segment_sizes = array - # CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32): - # CHECK-NEXT: arith.mulf{{.*}} (f32, f32) -> f32 - # CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32 - # CHECK-NEXT: linalg.yield{{.*}} (f32) -> () - # CHECK-NEXT: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> - return linalg.matmul(lhs, rhs, outs=[init_result.result]) - - module.operation.print(print_generic_op_form=True) + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func( + RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32) + ) + def named_form(lhs, rhs): + init_result = tensor.EmptyOp([4, 8], f32) + # CHECK: "linalg.matmul"(%{{.*}}) + # CHECK-SAME: cast = #linalg.type_fn + # CHECK-SAME: operand_segment_sizes = array + # CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32): + # CHECK-NEXT: arith.mulf{{.*}} (f32, f32) -> f32 + # CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32 + # CHECK-NEXT: linalg.yield{{.*}} (f32) -> () + # CHECK-NEXT: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> + return linalg.matmul(lhs, rhs, outs=[init_result.result]) + + module.operation.print(print_generic_op_form=True) # CHECK-LABEL: TEST: testNamedStructuredAsGenericOp @run def testNamedStructuredAsGenericOp(): - with Context() as ctx, Location.unknown(): - module = Module.create() - f32 = F32Type.get() - with InsertionPoint(module.body): + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): - @func.FuncOp.from_py_func( - RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), - f32)) - def generic_form(lhs, rhs): - init_result = tensor.EmptyOp([4, 8], f32) - # CHECK: linalg.generic - return linalg.matmul( - lhs, rhs, outs=[init_result.result], emit_generic=True) + @func.FuncOp.from_py_func( + RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32) + ) + def generic_form(lhs, rhs): + init_result = tensor.EmptyOp([4, 8], f32) + # CHECK: linalg.generic + return linalg.matmul( + lhs, rhs, outs=[init_result.result], emit_generic=True + ) - print(module) + print(module) # CHECK-LABEL: TEST: testOpResultFromOtherOp @run def testOpResultFromOtherOp(): - with Context(), Location.unknown(): - module = Module.create() - f32 = F32Type.get() - with InsertionPoint(module.body): - - @func.FuncOp.from_py_func( - RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), - f32)) - def pass_an_op_directly(arg0, arg1): - one = arith.ConstantOp(F32Type.get(), 1.0) - # CHECK: %[[LHS:.*]] = 
linalg.fill - lhs = linalg.fill(one, outs=[arg0]) - # CHECK: %[[RHS:.*]] = linalg.fill - rhs = linalg.fill(one, outs=[arg1]) - # CHECK: %[[INIT:.*]] = tensor.empty - init = tensor.EmptyOp([4, 8], f32) - # CHECK: linalg.matmul - # CHECK: ins(%[[LHS]], %[[RHS]] - # CHECK: outs(%[[INIT]] - return linalg.matmul(lhs, rhs, outs=init) - - print(module) + with Context(), Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func( + RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32) + ) + def pass_an_op_directly(arg0, arg1): + one = arith.ConstantOp(F32Type.get(), 1.0) + # CHECK: %[[LHS:.*]] = linalg.fill + lhs = linalg.fill(one, outs=[arg0]) + # CHECK: %[[RHS:.*]] = linalg.fill + rhs = linalg.fill(one, outs=[arg1]) + # CHECK: %[[INIT:.*]] = tensor.empty + init = tensor.EmptyOp([4, 8], f32) + # CHECK: linalg.matmul + # CHECK: ins(%[[LHS]], %[[RHS]] + # CHECK: outs(%[[INIT]] + return linalg.matmul(lhs, rhs, outs=init) + + print(module) diff --git a/mlir/test/python/dialects/math_dialect.py b/mlir/test/python/dialects/math_dialect.py index 04b6d84..3d402c5 100644 --- a/mlir/test/python/dialects/math_dialect.py +++ b/mlir/test/python/dialects/math_dialect.py @@ -7,23 +7,26 @@ from mlir.ir import * import mlir.dialects.func as func import mlir.dialects.math as mlir_math + def run(f): - print("\nTEST:", f.__name__) - f() + print("\nTEST:", f.__name__) + f() + # CHECK-LABEL: TEST: testMathOps @run def testMathOps(): - with Context() as ctx, Location.unknown(): - module = Module.create() - with InsertionPoint(module.body): - @func.FuncOp.from_py_func(F32Type.get()) - def emit_sqrt(arg): - return mlir_math.SqrtOp(arg) - - # CHECK-LABEL: func @emit_sqrt( - # CHECK-SAME: %[[ARG:.*]]: f32) -> f32 { - # CHECK: math.sqrt %[[ARG]] : f32 - # CHECK: return - # CHECK: } - print(module) + with Context() as ctx, Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func(F32Type.get()) + def emit_sqrt(arg): + return mlir_math.SqrtOp(arg) + + # CHECK-LABEL: func @emit_sqrt( + # CHECK-SAME: %[[ARG:.*]]: f32) -> f32 { + # CHECK: math.sqrt %[[ARG]] : f32 + # CHECK: return + # CHECK: } + print(module) diff --git a/mlir/test/python/dialects/memref.py b/mlir/test/python/dialects/memref.py index 59092fe..2e3cae6 100644 --- a/mlir/test/python/dialects/memref.py +++ b/mlir/test/python/dialects/memref.py @@ -6,17 +6,17 @@ import mlir.dialects.memref as memref def run(f): - print("\nTEST:", f.__name__) - f() - return f + print("\nTEST:", f.__name__) + f() + return f # CHECK-LABEL: TEST: testSubViewAccessors @run def testSubViewAccessors(): - ctx = Context() - module = Module.parse( - r""" + ctx = Context() + module = Module.parse( + r""" func.func @f1(%arg0: memref) { %0 = arith.constant 0 : index %1 = arith.constant 1 : index @@ -27,48 +27,52 @@ def testSubViewAccessors(): memref.subview %arg0[%0, %1][%2, %3][%4, %5] : memref to memref> return } - """, ctx) - func_body = module.body.operations[0].regions[0].blocks[0] - subview = func_body.operations[6] + """, + ctx, + ) + func_body = module.body.operations[0].regions[0].blocks[0] + subview = func_body.operations[6] - assert subview.source == subview.operands[0] - assert len(subview.offsets) == 2 - assert len(subview.sizes) == 2 - assert len(subview.strides) == 2 - assert subview.result == subview.results[0] + assert subview.source == subview.operands[0] + assert len(subview.offsets) == 2 + assert len(subview.sizes) == 2 + assert 
len(subview.strides) == 2 + assert subview.result == subview.results[0] - # CHECK: SubViewOp - print(type(subview).__name__) + # CHECK: SubViewOp + print(type(subview).__name__) - # CHECK: constant 0 - print(subview.offsets[0]) - # CHECK: constant 1 - print(subview.offsets[1]) - # CHECK: constant 2 - print(subview.sizes[0]) - # CHECK: constant 3 - print(subview.sizes[1]) - # CHECK: constant 4 - print(subview.strides[0]) - # CHECK: constant 5 - print(subview.strides[1]) + # CHECK: constant 0 + print(subview.offsets[0]) + # CHECK: constant 1 + print(subview.offsets[1]) + # CHECK: constant 2 + print(subview.sizes[0]) + # CHECK: constant 3 + print(subview.sizes[1]) + # CHECK: constant 4 + print(subview.strides[0]) + # CHECK: constant 5 + print(subview.strides[1]) # CHECK-LABEL: TEST: testCustomBuidlers @run def testCustomBuidlers(): - with Context() as ctx, Location.unknown(ctx): - module = Module.parse(r""" + with Context() as ctx, Location.unknown(ctx): + module = Module.parse( + r""" func.func @f1(%arg0: memref, %arg1: index, %arg2: index) { return } - """) - f = module.body.operations[0] - func_body = f.regions[0].blocks[0] - with InsertionPoint.at_block_terminator(func_body): - memref.LoadOp(f.arguments[0], f.arguments[1:]) + """ + ) + f = module.body.operations[0] + func_body = f.regions[0].blocks[0] + with InsertionPoint.at_block_terminator(func_body): + memref.LoadOp(f.arguments[0], f.arguments[1:]) - # CHECK: func @f1(%[[ARG0:.*]]: memref, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) - # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] - print(module) - assert module.operation.verify() + # CHECK: func @f1(%[[ARG0:.*]]: memref, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) + # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] + print(module) + assert module.operation.verify() diff --git a/mlir/test/python/dialects/ml_program.py b/mlir/test/python/dialects/ml_program.py index 4d9804f..f16de2a 100644 --- a/mlir/test/python/dialects/ml_program.py +++ b/mlir/test/python/dialects/ml_program.py @@ -6,23 +6,23 @@ from mlir.dialects import ml_program def constructAndPrintInModule(f): - print("\nTEST:", f.__name__) - with Context(), Location.unknown(): - module = Module.create() - with InsertionPoint(module.body): - f() - print(module) - return f + print("\nTEST:", f.__name__) + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + f() + print(module) + return f # CHECK-LABEL: testFuncOp @constructAndPrintInModule def testFuncOp(): - # CHECK: ml_program.func @foobar(%arg0: si32) -> si32 - f = ml_program.FuncOp( - name="foobar", - type=([IntegerType.get_signed(32)], [IntegerType.get_signed(32)])) - block = f.add_entry_block() - with InsertionPoint(block): - # CHECK: ml_program.return - ml_program.ReturnOp([block.arguments[0]]) + # CHECK: ml_program.func @foobar(%arg0: si32) -> si32 + f = ml_program.FuncOp( + name="foobar", type=([IntegerType.get_signed(32)], [IntegerType.get_signed(32)]) + ) + block = f.add_entry_block() + with InsertionPoint(block): + # CHECK: ml_program.return + ml_program.ReturnOp([block.arguments[0]]) diff --git a/mlir/test/python/dialects/ods_helpers.py b/mlir/test/python/dialects/ods_helpers.py index 802a1f2..71879bd 100644 --- a/mlir/test/python/dialects/ods_helpers.py +++ b/mlir/test/python/dialects/ods_helpers.py @@ -6,207 +6,205 @@ from mlir.ir import * def run(f): - print("\nTEST:", f.__name__) - f() - gc.collect() - assert Context._get_live_count() == 0 + print("\nTEST:", f.__name__) + f() + gc.collect() + assert 
Context._get_live_count() == 0 def add_dummy_value(): - return Operation.create( - "custom.value", - results=[IntegerType.get_signless(32)]).result + return Operation.create( + "custom.value", results=[IntegerType.get_signless(32)] + ).result def testOdsBuildDefaultImplicitRegions(): - - class TestFixedRegionsOp(OpView): - OPERATION_NAME = "custom.test_op" - _ODS_REGIONS = (2, True) - - class TestVariadicRegionsOp(OpView): - OPERATION_NAME = "custom.test_any_regions_op" - _ODS_REGIONS = (2, False) - - with Context() as ctx, Location.unknown(): - ctx.allow_unregistered_dialects = True - m = Module.create() - with InsertionPoint(m.body): - op = TestFixedRegionsOp.build_generic(results=[], operands=[]) - # CHECK: NUM_REGIONS: 2 - print(f"NUM_REGIONS: {len(op.regions)}") - # Including a regions= that matches should be fine. - op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=2) - print(f"NUM_REGIONS: {len(op.regions)}") - # Reject greater than. - try: - op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=3) - except ValueError as e: - # CHECK: ERROR:Operation "custom.test_op" requires a maximum of 2 regions but was built with regions=3 - print(f"ERROR:{e}") - # Reject less than. - try: - op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=1) - except ValueError as e: - # CHECK: ERROR:Operation "custom.test_op" requires a minimum of 2 regions but was built with regions=1 - print(f"ERROR:{e}") - - # If no regions specified for a variadic region op, build the minimum. - op = TestVariadicRegionsOp.build_generic(results=[], operands=[]) - # CHECK: DEFAULT_NUM_REGIONS: 2 - print(f"DEFAULT_NUM_REGIONS: {len(op.regions)}") - # Should also accept an explicit regions= that matches the minimum. - op = TestVariadicRegionsOp.build_generic( - results=[], operands=[], regions=2) - # CHECK: EQ_NUM_REGIONS: 2 - print(f"EQ_NUM_REGIONS: {len(op.regions)}") - # And accept greater than minimum. - # Should also accept an explicit regions= that matches the minimum. - op = TestVariadicRegionsOp.build_generic( - results=[], operands=[], regions=3) - # CHECK: GT_NUM_REGIONS: 3 - print(f"GT_NUM_REGIONS: {len(op.regions)}") - # Should reject less than minimum. - try: - op = TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=1) - except ValueError as e: - # CHECK: ERROR:Operation "custom.test_any_regions_op" requires a minimum of 2 regions but was built with regions=1 - print(f"ERROR:{e}") - + class TestFixedRegionsOp(OpView): + OPERATION_NAME = "custom.test_op" + _ODS_REGIONS = (2, True) + + class TestVariadicRegionsOp(OpView): + OPERATION_NAME = "custom.test_any_regions_op" + _ODS_REGIONS = (2, False) + + with Context() as ctx, Location.unknown(): + ctx.allow_unregistered_dialects = True + m = Module.create() + with InsertionPoint(m.body): + op = TestFixedRegionsOp.build_generic(results=[], operands=[]) + # CHECK: NUM_REGIONS: 2 + print(f"NUM_REGIONS: {len(op.regions)}") + # Including a regions= that matches should be fine. + op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=2) + print(f"NUM_REGIONS: {len(op.regions)}") + # Reject greater than. + try: + op = TestFixedRegionsOp.build_generic( + results=[], operands=[], regions=3 + ) + except ValueError as e: + # CHECK: ERROR:Operation "custom.test_op" requires a maximum of 2 regions but was built with regions=3 + print(f"ERROR:{e}") + # Reject less than. 
+ try: + op = TestFixedRegionsOp.build_generic( + results=[], operands=[], regions=1 + ) + except ValueError as e: + # CHECK: ERROR:Operation "custom.test_op" requires a minimum of 2 regions but was built with regions=1 + print(f"ERROR:{e}") + + # If no regions specified for a variadic region op, build the minimum. + op = TestVariadicRegionsOp.build_generic(results=[], operands=[]) + # CHECK: DEFAULT_NUM_REGIONS: 2 + print(f"DEFAULT_NUM_REGIONS: {len(op.regions)}") + # Should also accept an explicit regions= that matches the minimum. + op = TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=2) + # CHECK: EQ_NUM_REGIONS: 2 + print(f"EQ_NUM_REGIONS: {len(op.regions)}") + # And accept greater than minimum. + # Should also accept an explicit regions= that matches the minimum. + op = TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=3) + # CHECK: GT_NUM_REGIONS: 3 + print(f"GT_NUM_REGIONS: {len(op.regions)}") + # Should reject less than minimum. + try: + op = TestVariadicRegionsOp.build_generic( + results=[], operands=[], regions=1 + ) + except ValueError as e: + # CHECK: ERROR:Operation "custom.test_any_regions_op" requires a minimum of 2 regions but was built with regions=1 + print(f"ERROR:{e}") run(testOdsBuildDefaultImplicitRegions) def testOdsBuildDefaultNonVariadic(): + class TestOp(OpView): + OPERATION_NAME = "custom.test_op" + + with Context() as ctx, Location.unknown(): + ctx.allow_unregistered_dialects = True + m = Module.create() + with InsertionPoint(m.body): + v0 = add_dummy_value() + v1 = add_dummy_value() + t0 = IntegerType.get_signless(8) + t1 = IntegerType.get_signless(16) + op = TestOp.build_generic(results=[t0, t1], operands=[v0, v1]) + # CHECK: %[[V0:.+]] = "custom.value" + # CHECK: %[[V1:.+]] = "custom.value" + # CHECK: "custom.test_op"(%[[V0]], %[[V1]]) + # CHECK-NOT: operand_segment_sizes + # CHECK-NOT: result_segment_sizes + # CHECK-SAME: : (i32, i32) -> (i8, i16) + print(m) - class TestOp(OpView): - OPERATION_NAME = "custom.test_op" - - with Context() as ctx, Location.unknown(): - ctx.allow_unregistered_dialects = True - m = Module.create() - with InsertionPoint(m.body): - v0 = add_dummy_value() - v1 = add_dummy_value() - t0 = IntegerType.get_signless(8) - t1 = IntegerType.get_signless(16) - op = TestOp.build_generic(results=[t0, t1], operands=[v0, v1]) - # CHECK: %[[V0:.+]] = "custom.value" - # CHECK: %[[V1:.+]] = "custom.value" - # CHECK: "custom.test_op"(%[[V0]], %[[V1]]) - # CHECK-NOT: operand_segment_sizes - # CHECK-NOT: result_segment_sizes - # CHECK-SAME: : (i32, i32) -> (i8, i16) - print(m) run(testOdsBuildDefaultNonVariadic) def testOdsBuildDefaultSizedVariadic(): + class TestOp(OpView): + OPERATION_NAME = "custom.test_op" + _ODS_OPERAND_SEGMENTS = [1, -1, 0] + _ODS_RESULT_SEGMENTS = [-1, 0, 1] + + with Context() as ctx, Location.unknown(): + ctx.allow_unregistered_dialects = True + m = Module.create() + with InsertionPoint(m.body): + v0 = add_dummy_value() + v1 = add_dummy_value() + v2 = add_dummy_value() + v3 = add_dummy_value() + t0 = IntegerType.get_signless(8) + t1 = IntegerType.get_signless(16) + t2 = IntegerType.get_signless(32) + t3 = IntegerType.get_signless(64) + # CHECK: %[[V0:.+]] = "custom.value" + # CHECK: %[[V1:.+]] = "custom.value" + # CHECK: %[[V2:.+]] = "custom.value" + # CHECK: %[[V3:.+]] = "custom.value" + # CHECK: "custom.test_op"(%[[V0]], %[[V1]], %[[V2]], %[[V3]]) + # CHECK-SAME: operand_segment_sizes = array + # CHECK-SAME: result_segment_sizes = array + # CHECK-SAME: : (i32, i32, i32, i32) -> (i8, 
i16, i32, i64) + op = TestOp.build_generic( + results=[[t0, t1], t2, t3], operands=[v0, [v1, v2], v3] + ) + + # Now test with optional omitted. + # CHECK: "custom.test_op"(%[[V0]]) + # CHECK-SAME: operand_segment_sizes = array + # CHECK-SAME: result_segment_sizes = array + # CHECK-SAME: (i32) -> i64 + op = TestOp.build_generic( + results=[None, None, t3], operands=[v0, None, None] + ) + print(m) + + # And verify that errors are raised for None in a required operand. + try: + op = TestOp.build_generic( + results=[None, None, t3], operands=[None, None, None] + ) + except ValueError as e: + # CHECK: OPERAND_CAST_ERROR:Operand 0 of operation "custom.test_op" must be a Value (was None and operand is not optional) + print(f"OPERAND_CAST_ERROR:{e}") + + # And verify that errors are raised for None in a required result. + try: + op = TestOp.build_generic( + results=[None, None, None], operands=[v0, None, None] + ) + except ValueError as e: + # CHECK: RESULT_CAST_ERROR:Result 2 of operation "custom.test_op" must be a Type (was None and result is not optional) + print(f"RESULT_CAST_ERROR:{e}") + + # Variadic lists with None elements should reject. + try: + op = TestOp.build_generic( + results=[None, None, t3], operands=[v0, [None], None] + ) + except ValueError as e: + # CHECK: OPERAND_LIST_CAST_ERROR:Operand 1 of operation "custom.test_op" must be a Sequence of Values (contained a None item) + print(f"OPERAND_LIST_CAST_ERROR:{e}") + try: + op = TestOp.build_generic( + results=[[None], None, t3], operands=[v0, None, None] + ) + except ValueError as e: + # CHECK: RESULT_LIST_CAST_ERROR:Result 0 of operation "custom.test_op" must be a Sequence of Types (contained a None item) + print(f"RESULT_LIST_CAST_ERROR:{e}") - class TestOp(OpView): - OPERATION_NAME = "custom.test_op" - _ODS_OPERAND_SEGMENTS = [1, -1, 0] - _ODS_RESULT_SEGMENTS = [-1, 0, 1] - - with Context() as ctx, Location.unknown(): - ctx.allow_unregistered_dialects = True - m = Module.create() - with InsertionPoint(m.body): - v0 = add_dummy_value() - v1 = add_dummy_value() - v2 = add_dummy_value() - v3 = add_dummy_value() - t0 = IntegerType.get_signless(8) - t1 = IntegerType.get_signless(16) - t2 = IntegerType.get_signless(32) - t3 = IntegerType.get_signless(64) - # CHECK: %[[V0:.+]] = "custom.value" - # CHECK: %[[V1:.+]] = "custom.value" - # CHECK: %[[V2:.+]] = "custom.value" - # CHECK: %[[V3:.+]] = "custom.value" - # CHECK: "custom.test_op"(%[[V0]], %[[V1]], %[[V2]], %[[V3]]) - # CHECK-SAME: operand_segment_sizes = array - # CHECK-SAME: result_segment_sizes = array - # CHECK-SAME: : (i32, i32, i32, i32) -> (i8, i16, i32, i64) - op = TestOp.build_generic( - results=[[t0, t1], t2, t3], - operands=[v0, [v1, v2], v3]) - - # Now test with optional omitted. - # CHECK: "custom.test_op"(%[[V0]]) - # CHECK-SAME: operand_segment_sizes = array - # CHECK-SAME: result_segment_sizes = array - # CHECK-SAME: (i32) -> i64 - op = TestOp.build_generic( - results=[None, None, t3], - operands=[v0, None, None]) - print(m) - - # And verify that errors are raised for None in a required operand. - try: - op = TestOp.build_generic( - results=[None, None, t3], - operands=[None, None, None]) - except ValueError as e: - # CHECK: OPERAND_CAST_ERROR:Operand 0 of operation "custom.test_op" must be a Value (was None and operand is not optional) - print(f"OPERAND_CAST_ERROR:{e}") - - # And verify that errors are raised for None in a required result. 
- try: - op = TestOp.build_generic( - results=[None, None, None], - operands=[v0, None, None]) - except ValueError as e: - # CHECK: RESULT_CAST_ERROR:Result 2 of operation "custom.test_op" must be a Type (was None and result is not optional) - print(f"RESULT_CAST_ERROR:{e}") - - # Variadic lists with None elements should reject. - try: - op = TestOp.build_generic( - results=[None, None, t3], - operands=[v0, [None], None]) - except ValueError as e: - # CHECK: OPERAND_LIST_CAST_ERROR:Operand 1 of operation "custom.test_op" must be a Sequence of Values (contained a None item) - print(f"OPERAND_LIST_CAST_ERROR:{e}") - try: - op = TestOp.build_generic( - results=[[None], None, t3], - operands=[v0, None, None]) - except ValueError as e: - # CHECK: RESULT_LIST_CAST_ERROR:Result 0 of operation "custom.test_op" must be a Sequence of Types (contained a None item) - print(f"RESULT_LIST_CAST_ERROR:{e}") run(testOdsBuildDefaultSizedVariadic) def testOdsBuildDefaultCastError(): + class TestOp(OpView): + OPERATION_NAME = "custom.test_op" + + with Context() as ctx, Location.unknown(): + ctx.allow_unregistered_dialects = True + m = Module.create() + with InsertionPoint(m.body): + v0 = add_dummy_value() + v1 = add_dummy_value() + t0 = IntegerType.get_signless(8) + t1 = IntegerType.get_signless(16) + try: + op = TestOp.build_generic(results=[t0, t1], operands=[None, v1]) + except ValueError as e: + # CHECK: ERROR: Operand 0 of operation "custom.test_op" must be a Value + print(f"ERROR: {e}") + try: + op = TestOp.build_generic(results=[t0, None], operands=[v0, v1]) + except ValueError as e: + # CHECK: Result 1 of operation "custom.test_op" must be a Type + print(f"ERROR: {e}") - class TestOp(OpView): - OPERATION_NAME = "custom.test_op" - - with Context() as ctx, Location.unknown(): - ctx.allow_unregistered_dialects = True - m = Module.create() - with InsertionPoint(m.body): - v0 = add_dummy_value() - v1 = add_dummy_value() - t0 = IntegerType.get_signless(8) - t1 = IntegerType.get_signless(16) - try: - op = TestOp.build_generic( - results=[t0, t1], - operands=[None, v1]) - except ValueError as e: - # CHECK: ERROR: Operand 0 of operation "custom.test_op" must be a Value - print(f"ERROR: {e}") - try: - op = TestOp.build_generic( - results=[t0, None], - operands=[v0, v1]) - except ValueError as e: - # CHECK: Result 1 of operation "custom.test_op" must be a Type - print(f"ERROR: {e}") run(testOdsBuildDefaultCastError) diff --git a/mlir/test/python/dialects/pdl_ops.py b/mlir/test/python/dialects/pdl_ops.py index 3d9cd19..0d364f9 100644 --- a/mlir/test/python/dialects/pdl_ops.py +++ b/mlir/test/python/dialects/pdl_ops.py @@ -5,13 +5,13 @@ from mlir.dialects.pdl import * def constructAndPrintInModule(f): - print("\nTEST:", f.__name__) - with Context(), Location.unknown(): - module = Module.create() - with InsertionPoint(module.body): - f() - print(module) - return f + print("\nTEST:", f.__name__) + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + f() + print(module) + return f # CHECK: module { @@ -27,15 +27,15 @@ def constructAndPrintInModule(f): # CHECK: } @constructAndPrintInModule def test_operations(): - pattern = PatternOp(1, "operations") - with InsertionPoint(pattern.body): - attr = AttributeOp() - ty = TypeOp() - op0 = OperationOp(attributes={"attr": attr}, types=[ty]) - op0_result = ResultOp(op0, 0) - input = OperandOp() - root = OperationOp(args=[op0_result, input]) - RewriteOp(root, "rewriter") + pattern = PatternOp(1, "operations") + with 
InsertionPoint(pattern.body): + attr = AttributeOp() + ty = TypeOp() + op0 = OperationOp(attributes={"attr": attr}, types=[ty]) + op0_result = ResultOp(op0, 0) + input = OperandOp() + root = OperationOp(args=[op0_result, input]) + RewriteOp(root, "rewriter") # CHECK: module { @@ -47,11 +47,12 @@ def test_operations(): # CHECK: } @constructAndPrintInModule def test_rewrite_with_args(): - pattern = PatternOp(1, "rewrite_with_args") - with InsertionPoint(pattern.body): - input = OperandOp() - root = OperationOp(args=[input]) - RewriteOp(root, "rewriter", args=[input]) + pattern = PatternOp(1, "rewrite_with_args") + with InsertionPoint(pattern.body): + input = OperandOp() + root = OperationOp(args=[input]) + RewriteOp(root, "rewriter", args=[input]) + # CHECK: module { # CHECK: pdl.pattern @rewrite_multi_root_optimal : benefit(1) { @@ -69,18 +70,19 @@ def test_rewrite_with_args(): # CHECK: } @constructAndPrintInModule def test_rewrite_multi_root_optimal(): - pattern = PatternOp(1, "rewrite_multi_root_optimal") - with InsertionPoint(pattern.body): - input1 = OperandOp() - input2 = OperandOp() - ty = TypeOp() - op1 = OperationOp(args=[input1], types=[ty]) - val1 = ResultOp(op1, 0) - root1 = OperationOp(args=[val1]) - op2 = OperationOp(args=[input2], types=[ty]) - val2 = ResultOp(op2, 0) - root2 = OperationOp(args=[val1, val2]) - RewriteOp(name="rewriter", args=[root1, root2]) + pattern = PatternOp(1, "rewrite_multi_root_optimal") + with InsertionPoint(pattern.body): + input1 = OperandOp() + input2 = OperandOp() + ty = TypeOp() + op1 = OperationOp(args=[input1], types=[ty]) + val1 = ResultOp(op1, 0) + root1 = OperationOp(args=[val1]) + op2 = OperationOp(args=[input2], types=[ty]) + val2 = ResultOp(op2, 0) + root2 = OperationOp(args=[val1, val2]) + RewriteOp(name="rewriter", args=[root1, root2]) + # CHECK: module { # CHECK: pdl.pattern @rewrite_multi_root_forced : benefit(1) { @@ -98,18 +100,19 @@ def test_rewrite_multi_root_optimal(): # CHECK: } @constructAndPrintInModule def test_rewrite_multi_root_forced(): - pattern = PatternOp(1, "rewrite_multi_root_forced") - with InsertionPoint(pattern.body): - input1 = OperandOp() - input2 = OperandOp() - ty = TypeOp() - op1 = OperationOp(args=[input1], types=[ty]) - val1 = ResultOp(op1, 0) - root1 = OperationOp(args=[val1]) - op2 = OperationOp(args=[input2], types=[ty]) - val2 = ResultOp(op2, 0) - root2 = OperationOp(args=[val1, val2]) - RewriteOp(root1, name="rewriter", args=[root2]) + pattern = PatternOp(1, "rewrite_multi_root_forced") + with InsertionPoint(pattern.body): + input1 = OperandOp() + input2 = OperandOp() + ty = TypeOp() + op1 = OperationOp(args=[input1], types=[ty]) + val1 = ResultOp(op1, 0) + root1 = OperationOp(args=[val1]) + op2 = OperationOp(args=[input2], types=[ty]) + val2 = ResultOp(op2, 0) + root2 = OperationOp(args=[val1, val2]) + RewriteOp(root1, name="rewriter", args=[root2]) + # CHECK: module { # CHECK: pdl.pattern @rewrite_add_body : benefit(1) { @@ -125,16 +128,17 @@ def test_rewrite_multi_root_forced(): # CHECK: } @constructAndPrintInModule def test_rewrite_add_body(): - pattern = PatternOp(1, "rewrite_add_body") - with InsertionPoint(pattern.body): - ty1 = TypeOp(IntegerType.get_signless(32)) - ty2 = TypeOp() - root = OperationOp(types=[ty1, ty2]) - rewrite = RewriteOp(root) - with InsertionPoint(rewrite.add_body()): - ty3 = TypeOp() - newOp = OperationOp(name="foo.op", types=[ty1, ty3]) - ReplaceOp(root, with_op=newOp) + pattern = PatternOp(1, "rewrite_add_body") + with InsertionPoint(pattern.body): + ty1 = 
TypeOp(IntegerType.get_signless(32)) + ty2 = TypeOp() + root = OperationOp(types=[ty1, ty2]) + rewrite = RewriteOp(root) + with InsertionPoint(rewrite.add_body()): + ty3 = TypeOp() + newOp = OperationOp(name="foo.op", types=[ty1, ty3]) + ReplaceOp(root, with_op=newOp) + # CHECK: module { # CHECK: pdl.pattern @rewrite_type : benefit(1) { @@ -148,14 +152,15 @@ def test_rewrite_add_body(): # CHECK: } @constructAndPrintInModule def test_rewrite_type(): - pattern = PatternOp(1, "rewrite_type") - with InsertionPoint(pattern.body): - ty1 = TypeOp(IntegerType.get_signless(32)) - ty2 = TypeOp() - root = OperationOp(types=[ty1, ty2]) - rewrite = RewriteOp(root) - with InsertionPoint(rewrite.add_body()): - newOp = OperationOp(name="foo.op", types=[ty1, ty2]) + pattern = PatternOp(1, "rewrite_type") + with InsertionPoint(pattern.body): + ty1 = TypeOp(IntegerType.get_signless(32)) + ty2 = TypeOp() + root = OperationOp(types=[ty1, ty2]) + rewrite = RewriteOp(root) + with InsertionPoint(rewrite.add_body()): + newOp = OperationOp(name="foo.op", types=[ty1, ty2]) + # CHECK: module { # CHECK: pdl.pattern @rewrite_types : benefit(1) { @@ -169,14 +174,17 @@ def test_rewrite_type(): # CHECK: } @constructAndPrintInModule def test_rewrite_types(): - pattern = PatternOp(1, "rewrite_types") - with InsertionPoint(pattern.body): - types = TypesOp() - root = OperationOp(types=[types]) - rewrite = RewriteOp(root) - with InsertionPoint(rewrite.add_body()): - otherTypes = TypesOp([IntegerType.get_signless(32), IntegerType.get_signless(64)]) - newOp = OperationOp(name="foo.op", types=[types, otherTypes]) + pattern = PatternOp(1, "rewrite_types") + with InsertionPoint(pattern.body): + types = TypesOp() + root = OperationOp(types=[types]) + rewrite = RewriteOp(root) + with InsertionPoint(rewrite.add_body()): + otherTypes = TypesOp( + [IntegerType.get_signless(32), IntegerType.get_signless(64)] + ) + newOp = OperationOp(name="foo.op", types=[types, otherTypes]) + # CHECK: module { # CHECK: pdl.pattern @rewrite_operands : benefit(1) { @@ -190,14 +198,15 @@ def test_rewrite_types(): # CHECK: } @constructAndPrintInModule def test_rewrite_operands(): - pattern = PatternOp(1, "rewrite_operands") - with InsertionPoint(pattern.body): - types = TypesOp() - operands = OperandsOp(types) - root = OperationOp(args=[operands]) - rewrite = RewriteOp(root) - with InsertionPoint(rewrite.add_body()): - newOp = OperationOp(name="foo.op", types=[types]) + pattern = PatternOp(1, "rewrite_operands") + with InsertionPoint(pattern.body): + types = TypesOp() + operands = OperandsOp(types) + root = OperationOp(args=[operands]) + rewrite = RewriteOp(root) + with InsertionPoint(rewrite.add_body()): + newOp = OperationOp(name="foo.op", types=[types]) + # CHECK: module { # CHECK: pdl.pattern @native_rewrite : benefit(1) { @@ -209,12 +218,13 @@ def test_rewrite_operands(): # CHECK: } @constructAndPrintInModule def test_native_rewrite(): - pattern = PatternOp(1, "native_rewrite") - with InsertionPoint(pattern.body): - root = OperationOp() - rewrite = RewriteOp(root) - with InsertionPoint(rewrite.add_body()): - ApplyNativeRewriteOp([], "NativeRewrite", args=[root]) + pattern = PatternOp(1, "native_rewrite") + with InsertionPoint(pattern.body): + root = OperationOp() + rewrite = RewriteOp(root) + with InsertionPoint(rewrite.add_body()): + ApplyNativeRewriteOp([], "NativeRewrite", args=[root]) + # CHECK: module { # CHECK: pdl.pattern @attribute_with_value : benefit(1) { @@ -227,13 +237,14 @@ def test_native_rewrite(): # CHECK: } @constructAndPrintInModule 
def test_attribute_with_value(): - pattern = PatternOp(1, "attribute_with_value") - with InsertionPoint(pattern.body): - root = OperationOp() - rewrite = RewriteOp(root) - with InsertionPoint(rewrite.add_body()): - attr = AttributeOp(value=Attribute.parse('"value"')) - ApplyNativeRewriteOp([], "NativeRewrite", args=[attr]) + pattern = PatternOp(1, "attribute_with_value") + with InsertionPoint(pattern.body): + root = OperationOp() + rewrite = RewriteOp(root) + with InsertionPoint(rewrite.add_body()): + attr = AttributeOp(value=Attribute.parse('"value"')) + ApplyNativeRewriteOp([], "NativeRewrite", args=[attr]) + # CHECK: module { # CHECK: pdl.pattern @erase : benefit(1) { @@ -245,12 +256,13 @@ def test_attribute_with_value(): # CHECK: } @constructAndPrintInModule def test_erase(): - pattern = PatternOp(1, "erase") - with InsertionPoint(pattern.body): - root = OperationOp() - rewrite = RewriteOp(root) - with InsertionPoint(rewrite.add_body()): - EraseOp(root) + pattern = PatternOp(1, "erase") + with InsertionPoint(pattern.body): + root = OperationOp() + rewrite = RewriteOp(root) + with InsertionPoint(rewrite.add_body()): + EraseOp(root) + # CHECK: module { # CHECK: pdl.pattern @operation_results : benefit(1) { @@ -263,14 +275,15 @@ def test_erase(): # CHECK: } @constructAndPrintInModule def test_operation_results(): - valueRange = RangeType.get(ValueType.get()) - pattern = PatternOp(1, "operation_results") - with InsertionPoint(pattern.body): - types = TypesOp() - inputOp = OperationOp(types=[types]) - results = ResultsOp(valueRange, inputOp) - root = OperationOp(args=[results]) - RewriteOp(root, name="rewriter") + valueRange = RangeType.get(ValueType.get()) + pattern = PatternOp(1, "operation_results") + with InsertionPoint(pattern.body): + types = TypesOp() + inputOp = OperationOp(types=[types]) + results = ResultsOp(valueRange, inputOp) + root = OperationOp(args=[results]) + RewriteOp(root, name="rewriter") + # CHECK: module { # CHECK: pdl.pattern : benefit(1) { @@ -282,9 +295,9 @@ def test_operation_results(): # CHECK: } @constructAndPrintInModule def test_apply_native_constraint(): - pattern = PatternOp(1) - with InsertionPoint(pattern.body): - resultType = TypeOp() - ApplyNativeConstraintOp("typeConstraint", args=[resultType]) - root = OperationOp(types=[resultType]) - RewriteOp(root, name="rewrite") + pattern = PatternOp(1) + with InsertionPoint(pattern.body): + resultType = TypeOp() + ApplyNativeConstraintOp("typeConstraint", args=[resultType]) + root = OperationOp(types=[resultType]) + RewriteOp(root, name="rewrite") diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py index 2ca79b2..72a765c 100644 --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -5,367 +5,373 @@ import mlir.dialects.func as func import mlir.dialects.python_test as test import mlir.dialects.tensor as tensor + def run(f): - print("\nTEST:", f.__name__) - f() - return f + print("\nTEST:", f.__name__) + f() + return f + # CHECK-LABEL: TEST: testAttributes @run def testAttributes(): - with Context() as ctx, Location.unknown(): - ctx.allow_unregistered_dialects = True - - # - # Check op construction with attributes. 
- # - - i32 = IntegerType.get_signless(32) - one = IntegerAttr.get(i32, 1) - two = IntegerAttr.get(i32, 2) - unit = UnitAttr.get() - - # CHECK: "python_test.attributed_op"() { - # CHECK-DAG: mandatory_i32 = 1 : i32 - # CHECK-DAG: optional_i32 = 2 : i32 - # CHECK-DAG: unit - # CHECK: } - op = test.AttributedOp(one, optional_i32=two, unit=unit) - print(f"{op}") - - # CHECK: "python_test.attributed_op"() { - # CHECK: mandatory_i32 = 2 : i32 - # CHECK: } - op2 = test.AttributedOp(two) - print(f"{op2}") - - # - # Check generic "attributes" access and mutation. - # - - assert "additional" not in op.attributes - - # CHECK: "python_test.attributed_op"() { - # CHECK-DAG: additional = 1 : i32 - # CHECK-DAG: mandatory_i32 = 2 : i32 - # CHECK: } - op2.attributes["additional"] = one - print(f"{op2}") - - # CHECK: "python_test.attributed_op"() { - # CHECK-DAG: additional = 2 : i32 - # CHECK-DAG: mandatory_i32 = 2 : i32 - # CHECK: } - op2.attributes["additional"] = two - print(f"{op2}") - - # CHECK: "python_test.attributed_op"() { - # CHECK-NOT: additional = 2 : i32 - # CHECK: mandatory_i32 = 2 : i32 - # CHECK: } - del op2.attributes["additional"] - print(f"{op2}") - - try: - print(op.attributes["additional"]) - except KeyError: - pass - else: - assert False, "expected KeyError on unknown attribute key" - - # - # Check accessors to defined attributes. - # - - # CHECK: Mandatory: 1 - # CHECK: Optional: 2 - # CHECK: Unit: True - print(f"Mandatory: {op.mandatory_i32.value}") - print(f"Optional: {op.optional_i32.value}") - print(f"Unit: {op.unit}") - - # CHECK: Mandatory: 2 - # CHECK: Optional: None - # CHECK: Unit: False - print(f"Mandatory: {op2.mandatory_i32.value}") - print(f"Optional: {op2.optional_i32}") - print(f"Unit: {op2.unit}") - - # CHECK: Mandatory: 2 - # CHECK: Optional: None - # CHECK: Unit: False - op.mandatory_i32 = two - op.optional_i32 = None - op.unit = False - print(f"Mandatory: {op.mandatory_i32.value}") - print(f"Optional: {op.optional_i32}") - print(f"Unit: {op.unit}") - assert "optional_i32" not in op.attributes - assert "unit" not in op.attributes - - try: - op.mandatory_i32 = None - except ValueError: - pass - else: - assert False, "expected ValueError on setting a mandatory attribute to None" - - # CHECK: Optional: 2 - op.optional_i32 = two - print(f"Optional: {op.optional_i32.value}") - - # CHECK: Optional: None - del op.optional_i32 - print(f"Optional: {op.optional_i32}") - - # CHECK: Unit: False - op.unit = None - print(f"Unit: {op.unit}") - assert "unit" not in op.attributes - - # CHECK: Unit: True - op.unit = True - print(f"Unit: {op.unit}") - - # CHECK: Unit: False - del op.unit - print(f"Unit: {op.unit}") + with Context() as ctx, Location.unknown(): + ctx.allow_unregistered_dialects = True + + # + # Check op construction with attributes. + # + + i32 = IntegerType.get_signless(32) + one = IntegerAttr.get(i32, 1) + two = IntegerAttr.get(i32, 2) + unit = UnitAttr.get() + + # CHECK: "python_test.attributed_op"() { + # CHECK-DAG: mandatory_i32 = 1 : i32 + # CHECK-DAG: optional_i32 = 2 : i32 + # CHECK-DAG: unit + # CHECK: } + op = test.AttributedOp(one, optional_i32=two, unit=unit) + print(f"{op}") + + # CHECK: "python_test.attributed_op"() { + # CHECK: mandatory_i32 = 2 : i32 + # CHECK: } + op2 = test.AttributedOp(two) + print(f"{op2}") + + # + # Check generic "attributes" access and mutation. 
+ # + + assert "additional" not in op.attributes + + # CHECK: "python_test.attributed_op"() { + # CHECK-DAG: additional = 1 : i32 + # CHECK-DAG: mandatory_i32 = 2 : i32 + # CHECK: } + op2.attributes["additional"] = one + print(f"{op2}") + + # CHECK: "python_test.attributed_op"() { + # CHECK-DAG: additional = 2 : i32 + # CHECK-DAG: mandatory_i32 = 2 : i32 + # CHECK: } + op2.attributes["additional"] = two + print(f"{op2}") + + # CHECK: "python_test.attributed_op"() { + # CHECK-NOT: additional = 2 : i32 + # CHECK: mandatory_i32 = 2 : i32 + # CHECK: } + del op2.attributes["additional"] + print(f"{op2}") + + try: + print(op.attributes["additional"]) + except KeyError: + pass + else: + assert False, "expected KeyError on unknown attribute key" + + # + # Check accessors to defined attributes. + # + + # CHECK: Mandatory: 1 + # CHECK: Optional: 2 + # CHECK: Unit: True + print(f"Mandatory: {op.mandatory_i32.value}") + print(f"Optional: {op.optional_i32.value}") + print(f"Unit: {op.unit}") + + # CHECK: Mandatory: 2 + # CHECK: Optional: None + # CHECK: Unit: False + print(f"Mandatory: {op2.mandatory_i32.value}") + print(f"Optional: {op2.optional_i32}") + print(f"Unit: {op2.unit}") + + # CHECK: Mandatory: 2 + # CHECK: Optional: None + # CHECK: Unit: False + op.mandatory_i32 = two + op.optional_i32 = None + op.unit = False + print(f"Mandatory: {op.mandatory_i32.value}") + print(f"Optional: {op.optional_i32}") + print(f"Unit: {op.unit}") + assert "optional_i32" not in op.attributes + assert "unit" not in op.attributes + + try: + op.mandatory_i32 = None + except ValueError: + pass + else: + assert False, "expected ValueError on setting a mandatory attribute to None" + + # CHECK: Optional: 2 + op.optional_i32 = two + print(f"Optional: {op.optional_i32.value}") + + # CHECK: Optional: None + del op.optional_i32 + print(f"Optional: {op.optional_i32}") + + # CHECK: Unit: False + op.unit = None + print(f"Unit: {op.unit}") + assert "unit" not in op.attributes + + # CHECK: Unit: True + op.unit = True + print(f"Unit: {op.unit}") + + # CHECK: Unit: False + del op.unit + print(f"Unit: {op.unit}") + # CHECK-LABEL: TEST: attrBuilder @run def attrBuilder(): - with Context() as ctx, Location.unknown(): - ctx.allow_unregistered_dialects = True - op = test.AttributesOp(x_bool=True, - x_i16=1, - x_i32=2, - x_i64=3, - x_si16=-1, - x_si32=-2, - x_f32=1.5, - x_f64=2.5, - x_str='x_str', - x_i32_array=[1, 2, 3], - x_i64_array=[4, 5, 6], - x_f32_array=[1.5, -2.5, 3.5], - x_f64_array=[4.5, 5.5, -6.5], - x_i64_dense=[1, 2, 3, 4, 5, 6]) - print(op) + with Context() as ctx, Location.unknown(): + ctx.allow_unregistered_dialects = True + op = test.AttributesOp( + x_bool=True, + x_i16=1, + x_i32=2, + x_i64=3, + x_si16=-1, + x_si32=-2, + x_f32=1.5, + x_f64=2.5, + x_str="x_str", + x_i32_array=[1, 2, 3], + x_i64_array=[4, 5, 6], + x_f32_array=[1.5, -2.5, 3.5], + x_f64_array=[4.5, 5.5, -6.5], + x_i64_dense=[1, 2, 3, 4, 5, 6], + ) + print(op) # CHECK-LABEL: TEST: inferReturnTypes @run def inferReturnTypes(): - with Context() as ctx, Location.unknown(ctx): - test.register_python_test_dialect(ctx) - module = Module.create() - with InsertionPoint(module.body): - op = test.InferResultsOp() - dummy = test.DummyOp() - - # CHECK: [Type(i32), Type(i64)] - iface = InferTypeOpInterface(op) - print(iface.inferReturnTypes()) - - # CHECK: [Type(i32), Type(i64)] - iface_static = InferTypeOpInterface(test.InferResultsOp) - print(iface.inferReturnTypes()) - - assert isinstance(iface.opview, test.InferResultsOp) - assert iface.opview == 
iface.operation.opview - - try: - iface_static.opview - except TypeError: - pass - else: - assert False, ("not expected to be able to obtain an opview from a static" - " interface") - - try: - InferTypeOpInterface(dummy) - except ValueError: - pass - else: - assert False, "not expected dummy op to implement the interface" - - try: - InferTypeOpInterface(test.DummyOp) - except ValueError: - pass - else: - assert False, "not expected dummy op class to implement the interface" + with Context() as ctx, Location.unknown(ctx): + test.register_python_test_dialect(ctx) + module = Module.create() + with InsertionPoint(module.body): + op = test.InferResultsOp() + dummy = test.DummyOp() + + # CHECK: [Type(i32), Type(i64)] + iface = InferTypeOpInterface(op) + print(iface.inferReturnTypes()) + + # CHECK: [Type(i32), Type(i64)] + iface_static = InferTypeOpInterface(test.InferResultsOp) + print(iface.inferReturnTypes()) + + assert isinstance(iface.opview, test.InferResultsOp) + assert iface.opview == iface.operation.opview + + try: + iface_static.opview + except TypeError: + pass + else: + assert False, ( + "not expected to be able to obtain an opview from a static" " interface" + ) + + try: + InferTypeOpInterface(dummy) + except ValueError: + pass + else: + assert False, "not expected dummy op to implement the interface" + + try: + InferTypeOpInterface(test.DummyOp) + except ValueError: + pass + else: + assert False, "not expected dummy op class to implement the interface" # CHECK-LABEL: TEST: resultTypesDefinedByTraits @run def resultTypesDefinedByTraits(): - with Context() as ctx, Location.unknown(ctx): - test.register_python_test_dialect(ctx) - module = Module.create() - with InsertionPoint(module.body): - inferred = test.InferResultsOp() - same = test.SameOperandAndResultTypeOp([inferred.results[0]]) - # CHECK-COUNT-2: i32 - print(same.one.type) - print(same.two.type) - - first_type_attr = test.FirstAttrDeriveTypeAttrOp( - inferred.results[1], TypeAttr.get(IndexType.get())) - # CHECK-COUNT-2: index - print(first_type_attr.one.type) - print(first_type_attr.two.type) - - first_attr = test.FirstAttrDeriveAttrOp( - FloatAttr.get(F32Type.get(), 3.14)) - # CHECK-COUNT-3: f32 - print(first_attr.one.type) - print(first_attr.two.type) - print(first_attr.three.type) - - implied = test.InferResultsImpliedOp() - # CHECK: i32 - print(implied.integer.type) - # CHECK: f64 - print(implied.flt.type) - # CHECK: index - print(implied.index.type) + with Context() as ctx, Location.unknown(ctx): + test.register_python_test_dialect(ctx) + module = Module.create() + with InsertionPoint(module.body): + inferred = test.InferResultsOp() + same = test.SameOperandAndResultTypeOp([inferred.results[0]]) + # CHECK-COUNT-2: i32 + print(same.one.type) + print(same.two.type) + + first_type_attr = test.FirstAttrDeriveTypeAttrOp( + inferred.results[1], TypeAttr.get(IndexType.get()) + ) + # CHECK-COUNT-2: index + print(first_type_attr.one.type) + print(first_type_attr.two.type) + + first_attr = test.FirstAttrDeriveAttrOp(FloatAttr.get(F32Type.get(), 3.14)) + # CHECK-COUNT-3: f32 + print(first_attr.one.type) + print(first_attr.two.type) + print(first_attr.three.type) + + implied = test.InferResultsImpliedOp() + # CHECK: i32 + print(implied.integer.type) + # CHECK: f64 + print(implied.flt.type) + # CHECK: index + print(implied.index.type) # CHECK-LABEL: TEST: testOptionalOperandOp @run def testOptionalOperandOp(): - with Context() as ctx, Location.unknown(): - test.register_python_test_dialect(ctx) + with Context() as ctx, 
Location.unknown(): + test.register_python_test_dialect(ctx) - module = Module.create() - with InsertionPoint(module.body): + module = Module.create() + with InsertionPoint(module.body): - op1 = test.OptionalOperandOp() - # CHECK: op1.input is None: True - print(f"op1.input is None: {op1.input is None}") + op1 = test.OptionalOperandOp() + # CHECK: op1.input is None: True + print(f"op1.input is None: {op1.input is None}") - op2 = test.OptionalOperandOp(input=op1) - # CHECK: op2.input is None: False - print(f"op2.input is None: {op2.input is None}") + op2 = test.OptionalOperandOp(input=op1) + # CHECK: op2.input is None: False + print(f"op2.input is None: {op2.input is None}") # CHECK-LABEL: TEST: testCustomAttribute @run def testCustomAttribute(): - with Context() as ctx: - test.register_python_test_dialect(ctx) - a = test.TestAttr.get() - # CHECK: #python_test.test_attr - print(a) - - # The following cast must not assert. - b = test.TestAttr(a) - - unit = UnitAttr.get() - try: - test.TestAttr(unit) - except ValueError as e: - assert "Cannot cast attribute to TestAttr" in str(e) - else: - raise - - # The following must trigger a TypeError from our adaptors and must not - # crash. - try: - test.TestAttr(42) - except TypeError as e: - assert "Expected an MLIR object" in str(e) - else: - raise - - # The following must trigger a TypeError from pybind (therefore, not - # checking its message) and must not crash. - try: - test.TestAttr(42, 56) - except TypeError: - pass - else: - raise + with Context() as ctx: + test.register_python_test_dialect(ctx) + a = test.TestAttr.get() + # CHECK: #python_test.test_attr + print(a) + + # The following cast must not assert. + b = test.TestAttr(a) + + unit = UnitAttr.get() + try: + test.TestAttr(unit) + except ValueError as e: + assert "Cannot cast attribute to TestAttr" in str(e) + else: + raise + + # The following must trigger a TypeError from our adaptors and must not + # crash. + try: + test.TestAttr(42) + except TypeError as e: + assert "Expected an MLIR object" in str(e) + else: + raise + + # The following must trigger a TypeError from pybind (therefore, not + # checking its message) and must not crash. + try: + test.TestAttr(42, 56) + except TypeError: + pass + else: + raise @run def testCustomType(): - with Context() as ctx: - test.register_python_test_dialect(ctx) - a = test.TestType.get() - # CHECK: !python_test.test_type - print(a) - - # The following cast must not assert. - b = test.TestType(a) - # Instance custom types should have typeids - assert isinstance(b.typeid, TypeID) - # Subclasses of ir.Type should not have a static_typeid - # CHECK: 'TestType' object has no attribute 'static_typeid' - try: - b.static_typeid - except AttributeError as e: - print(e) - - i8 = IntegerType.get_signless(8) - try: - test.TestType(i8) - except ValueError as e: - assert "Cannot cast type to TestType" in str(e) - else: - raise - - # The following must trigger a TypeError from our adaptors and must not - # crash. - try: - test.TestType(42) - except TypeError as e: - assert "Expected an MLIR object" in str(e) - else: - raise - - # The following must trigger a TypeError from pybind (therefore, not - # checking its message) and must not crash. - try: - test.TestType(42, 56) - except TypeError: - pass - else: - raise + with Context() as ctx: + test.register_python_test_dialect(ctx) + a = test.TestType.get() + # CHECK: !python_test.test_type + print(a) + + # The following cast must not assert. 
+ b = test.TestType(a) + # Instance custom types should have typeids + assert isinstance(b.typeid, TypeID) + # Subclasses of ir.Type should not have a static_typeid + # CHECK: 'TestType' object has no attribute 'static_typeid' + try: + b.static_typeid + except AttributeError as e: + print(e) + + i8 = IntegerType.get_signless(8) + try: + test.TestType(i8) + except ValueError as e: + assert "Cannot cast type to TestType" in str(e) + else: + raise + + # The following must trigger a TypeError from our adaptors and must not + # crash. + try: + test.TestType(42) + except TypeError as e: + assert "Expected an MLIR object" in str(e) + else: + raise + + # The following must trigger a TypeError from pybind (therefore, not + # checking its message) and must not crash. + try: + test.TestType(42, 56) + except TypeError: + pass + else: + raise @run # CHECK-LABEL: TEST: testTensorValue def testTensorValue(): - with Context() as ctx, Location.unknown(): - test.register_python_test_dialect(ctx) + with Context() as ctx, Location.unknown(): + test.register_python_test_dialect(ctx) - i8 = IntegerType.get_signless(8) + i8 = IntegerType.get_signless(8) - class Tensor(test.TestTensorValue): - def __str__(self): - return super().__str__().replace("Value", "Tensor") + class Tensor(test.TestTensorValue): + def __str__(self): + return super().__str__().replace("Value", "Tensor") - module = Module.create() - with InsertionPoint(module.body): - t = tensor.EmptyOp([10, 10], i8).result + module = Module.create() + with InsertionPoint(module.body): + t = tensor.EmptyOp([10, 10], i8).result - # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>) - print(Value(t)) + # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>) + print(Value(t)) - tt = Tensor(t) - # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>) - print(tt) + tt = Tensor(t) + # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>) + print(tt) - # CHECK: False - print(tt.is_null()) + # CHECK: False + print(tt.is_null()) - # Classes of custom types that inherit from concrete types should have - # static_typeid - assert isinstance(test.TestTensorType.static_typeid, TypeID) - # And it should be equal to the in-tree concrete type - assert test.TestTensorType.static_typeid == t.type.typeid + # Classes of custom types that inherit from concrete types should have + # static_typeid + assert isinstance(test.TestTensorType.static_typeid, TypeID) + # And it should be equal to the in-tree concrete type + assert test.TestTensorType.static_typeid == t.type.typeid # CHECK-LABEL: TEST: inferReturnTypeComponents @@ -412,7 +418,7 @@ def inferReturnTypeComponents(): # CHECK: shape: None iface = InferShapedTypeOpInterface(unranked_op) shaped_type_components = iface.inferReturnTypeComponents( - operands=[unranked_op.operand] + operands=[unranked_op.operand] )[0] print("has rank:", shaped_type_components.has_rank) print("rank:", shaped_type_components.rank) diff --git a/mlir/test/python/dialects/quant.py b/mlir/test/python/dialects/quant.py index 32614be..0ee3327 100644 --- a/mlir/test/python/dialects/quant.py +++ b/mlir/test/python/dialects/quant.py @@ -5,127 +5,133 @@ from mlir.dialects import quant def run(f): - print("\nTEST:", f.__name__) - f() - return f + print("\nTEST:", f.__name__) + f() + return f # CHECK-LABEL: TEST: test_type_hierarchy @run def test_type_hierarchy(): - with Context(): - i8 = IntegerType.get_signless(8) - any = Type.parse("!quant.any:f32>") - uniform = Type.parse("!quant.uniform:f32, 0.99872:127>") - per_axis = 
Type.parse("!quant.uniform") - calibrated = Type.parse("!quant.calibrated>") + with Context(): + i8 = IntegerType.get_signless(8) + any = Type.parse("!quant.any:f32>") + uniform = Type.parse("!quant.uniform:f32, 0.99872:127>") + per_axis = Type.parse("!quant.uniform") + calibrated = Type.parse("!quant.calibrated>") - assert not quant.QuantizedType.isinstance(i8) - assert quant.QuantizedType.isinstance(any) - assert quant.QuantizedType.isinstance(uniform) - assert quant.QuantizedType.isinstance(per_axis) - assert quant.QuantizedType.isinstance(calibrated) + assert not quant.QuantizedType.isinstance(i8) + assert quant.QuantizedType.isinstance(any) + assert quant.QuantizedType.isinstance(uniform) + assert quant.QuantizedType.isinstance(per_axis) + assert quant.QuantizedType.isinstance(calibrated) - assert quant.AnyQuantizedType.isinstance(any) - assert quant.UniformQuantizedType.isinstance(uniform) - assert quant.UniformQuantizedPerAxisType.isinstance(per_axis) - assert quant.CalibratedQuantizedType.isinstance(calibrated) + assert quant.AnyQuantizedType.isinstance(any) + assert quant.UniformQuantizedType.isinstance(uniform) + assert quant.UniformQuantizedPerAxisType.isinstance(per_axis) + assert quant.CalibratedQuantizedType.isinstance(calibrated) - assert not quant.AnyQuantizedType.isinstance(uniform) - assert not quant.UniformQuantizedType.isinstance(per_axis) + assert not quant.AnyQuantizedType.isinstance(uniform) + assert not quant.UniformQuantizedType.isinstance(per_axis) # CHECK-LABEL: TEST: test_any_quantized_type @run def test_any_quantized_type(): - with Context(): - i8 = IntegerType.get_signless(8) - f32 = F32Type.get() - any = quant.AnyQuantizedType.get(quant.QuantizedType.FLAG_SIGNED, i8, f32, - -8, 7) - - # CHECK: flags: 1 - print(f"flags: {any.flags}") - # CHECK: signed: True - print(f"signed: {any.is_signed}") - # CHECK: storage type: i8 - print(f"storage type: {any.storage_type}") - # CHECK: expressed type: f32 - print(f"expressed type: {any.expressed_type}") - # CHECK: storage min: -8 - print(f"storage min: {any.storage_type_min}") - # CHECK: storage max: 7 - print(f"storage max: {any.storage_type_max}") - # CHECK: storage width: 8 - print(f"storage width: {any.storage_type_integral_width}") - # CHECK: quantized element type: !quant.any:f32> - print(f"quantized element type: {any.quantized_element_type}") - # CHECK: !quant.any:f32> - print(any) - assert any == Type.parse("!quant.any:f32>") + with Context(): + i8 = IntegerType.get_signless(8) + f32 = F32Type.get() + any = quant.AnyQuantizedType.get( + quant.QuantizedType.FLAG_SIGNED, i8, f32, -8, 7 + ) + + # CHECK: flags: 1 + print(f"flags: {any.flags}") + # CHECK: signed: True + print(f"signed: {any.is_signed}") + # CHECK: storage type: i8 + print(f"storage type: {any.storage_type}") + # CHECK: expressed type: f32 + print(f"expressed type: {any.expressed_type}") + # CHECK: storage min: -8 + print(f"storage min: {any.storage_type_min}") + # CHECK: storage max: 7 + print(f"storage max: {any.storage_type_max}") + # CHECK: storage width: 8 + print(f"storage width: {any.storage_type_integral_width}") + # CHECK: quantized element type: !quant.any:f32> + print(f"quantized element type: {any.quantized_element_type}") + # CHECK: !quant.any:f32> + print(any) + assert any == Type.parse("!quant.any:f32>") # CHECK-LABEL: TEST: test_uniform_type @run def test_uniform_type(): - with Context(): - i8 = IntegerType.get_signless(8) - f32 = F32Type.get() - uniform = quant.UniformQuantizedType.get( - quant.UniformQuantizedType.FLAG_SIGNED, i8, 
f32, 0.99872, 127, -8, 7) - - # CHECK: scale: 0.99872 - print(f"scale: {uniform.scale}") - # CHECK: zero point: 127 - print(f"zero point: {uniform.zero_point}") - # CHECK: fixed point: False - print(f"fixed point: {uniform.is_fixed_point}") - # CHECK: !quant.uniform:f32, 9.987200e-01:127> - print(uniform) - assert uniform == Type.parse("!quant.uniform:f32, 0.99872:127>") + with Context(): + i8 = IntegerType.get_signless(8) + f32 = F32Type.get() + uniform = quant.UniformQuantizedType.get( + quant.UniformQuantizedType.FLAG_SIGNED, i8, f32, 0.99872, 127, -8, 7 + ) + + # CHECK: scale: 0.99872 + print(f"scale: {uniform.scale}") + # CHECK: zero point: 127 + print(f"zero point: {uniform.zero_point}") + # CHECK: fixed point: False + print(f"fixed point: {uniform.is_fixed_point}") + # CHECK: !quant.uniform:f32, 9.987200e-01:127> + print(uniform) + assert uniform == Type.parse("!quant.uniform:f32, 0.99872:127>") # CHECK-LABEL: TEST: test_uniform_per_axis_type @run def test_uniform_per_axis_type(): - with Context(): - i8 = IntegerType.get_signless(8) - f32 = F32Type.get() - per_axis = quant.UniformQuantizedPerAxisType.get( - quant.QuantizedType.FLAG_SIGNED, - i8, - f32, [200, 0.99872], [0, 120], - quantized_dimension=1, - storage_type_min=quant.QuantizedType.default_minimum_for_integer( - is_signed=True, integral_width=8), - storage_type_max=quant.QuantizedType.default_maximum_for_integer( - is_signed=True, integral_width=8)) - - # CHECK: scales: None - print(f"scales: {per_axis.scales}") - # CHECK: zero_points: None - print(f"zero_points: {per_axis.zero_points}") - # CHECK: quantized dim: 1 - print(f"quantized dim: {per_axis.quantized_dimension}") - # CHECK: fixed point: False - print(f"fixed point: {per_axis.is_fixed_point}") - # CHECK: !quant.uniform - print(per_axis) - assert per_axis == Type.parse( - "!quant.uniform") + with Context(): + i8 = IntegerType.get_signless(8) + f32 = F32Type.get() + per_axis = quant.UniformQuantizedPerAxisType.get( + quant.QuantizedType.FLAG_SIGNED, + i8, + f32, + [200, 0.99872], + [0, 120], + quantized_dimension=1, + storage_type_min=quant.QuantizedType.default_minimum_for_integer( + is_signed=True, integral_width=8 + ), + storage_type_max=quant.QuantizedType.default_maximum_for_integer( + is_signed=True, integral_width=8 + ), + ) + + # CHECK: scales: None + print(f"scales: {per_axis.scales}") + # CHECK: zero_points: None + print(f"zero_points: {per_axis.zero_points}") + # CHECK: quantized dim: 1 + print(f"quantized dim: {per_axis.quantized_dimension}") + # CHECK: fixed point: False + print(f"fixed point: {per_axis.is_fixed_point}") + # CHECK: !quant.uniform + print(per_axis) + assert per_axis == Type.parse("!quant.uniform") # CHECK-LABEL: TEST: test_calibrated_type @run def test_calibrated_type(): - with Context(): - f32 = F32Type.get() - calibrated = quant.CalibratedQuantizedType.get(f32, -0.998, 1.2321) - - # CHECK: min: -0.998 - print(f"min: {calibrated.min}") - # CHECK: max: 1.2321 - print(f"max: {calibrated.max}") - # CHECK: !quant.calibrated> - print(calibrated) - assert calibrated == Type.parse("!quant.calibrated>") + with Context(): + f32 = F32Type.get() + calibrated = quant.CalibratedQuantizedType.get(f32, -0.998, 1.2321) + + # CHECK: min: -0.998 + print(f"min: {calibrated.min}") + # CHECK: max: 1.2321 + print(f"max: {calibrated.max}") + # CHECK: !quant.calibrated> + print(calibrated) + assert calibrated == Type.parse("!quant.calibrated>") diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py index 4a618ff4..8cb55fd 100644 --- 
a/mlir/test/python/dialects/scf.py +++ b/mlir/test/python/dialects/scf.py @@ -8,26 +8,26 @@ from mlir.dialects import builtin def constructAndPrintInModule(f): - print("\nTEST:", f.__name__) - with Context(), Location.unknown(): - module = Module.create() - with InsertionPoint(module.body): - f() - print(module) - return f + print("\nTEST:", f.__name__) + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + f() + print(module) + return f # CHECK-LABEL: TEST: testSimpleLoop @constructAndPrintInModule def testSimpleLoop(): - index_type = IndexType.get() + index_type = IndexType.get() - @func.FuncOp.from_py_func(index_type, index_type, index_type) - def simple_loop(lb, ub, step): - loop = scf.ForOp(lb, ub, step, [lb, lb]) - with InsertionPoint(loop.body): - scf.YieldOp(loop.inner_iter_args) - return + @func.FuncOp.from_py_func(index_type, index_type, index_type) + def simple_loop(lb, ub, step): + loop = scf.ForOp(lb, ub, step, [lb, lb]) + with InsertionPoint(loop.body): + scf.YieldOp(loop.inner_iter_args) + return # CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) @@ -39,14 +39,14 @@ def testSimpleLoop(): # CHECK-LABEL: TEST: testInductionVar @constructAndPrintInModule def testInductionVar(): - index_type = IndexType.get() + index_type = IndexType.get() - @func.FuncOp.from_py_func(index_type, index_type, index_type) - def induction_var(lb, ub, step): - loop = scf.ForOp(lb, ub, step, [lb]) - with InsertionPoint(loop.body): - scf.YieldOp([loop.induction_variable]) - return + @func.FuncOp.from_py_func(index_type, index_type, index_type) + def induction_var(lb, ub, step): + loop = scf.ForOp(lb, ub, step, [lb]) + with InsertionPoint(loop.body): + scf.YieldOp([loop.induction_variable]) + return # CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) @@ -56,19 +56,18 @@ def testInductionVar(): @constructAndPrintInModule def testOpsAsArguments(): - index_type = IndexType.get() - callee = func.FuncOp( - "callee", ([], [index_type, index_type]), visibility="private") - f = func.FuncOp("ops_as_arguments", ([], [])) - with InsertionPoint(f.add_entry_block()): - lb = arith.ConstantOp.create_index(0) - ub = arith.ConstantOp.create_index(42) - step = arith.ConstantOp.create_index(2) - iter_args = func.CallOp(callee, []) - loop = scf.ForOp(lb, ub, step, iter_args) - with InsertionPoint(loop.body): - scf.YieldOp(loop.inner_iter_args) - func.ReturnOp([]) + index_type = IndexType.get() + callee = func.FuncOp("callee", ([], [index_type, index_type]), visibility="private") + f = func.FuncOp("ops_as_arguments", ([], [])) + with InsertionPoint(f.add_entry_block()): + lb = arith.ConstantOp.create_index(0) + ub = arith.ConstantOp.create_index(42) + step = arith.ConstantOp.create_index(2) + iter_args = func.CallOp(callee, []) + loop = scf.ForOp(lb, ub, step, iter_args) + with InsertionPoint(loop.body): + scf.YieldOp(loop.inner_iter_args) + func.ReturnOp([]) # CHECK-LABEL: TEST: testOpsAsArguments @@ -86,17 +85,17 @@ def testOpsAsArguments(): @constructAndPrintInModule def testIfWithoutElse(): - bool = IntegerType.get_signless(1) - i32 = IntegerType.get_signless(32) + bool = IntegerType.get_signless(1) + i32 = IntegerType.get_signless(32) - @func.FuncOp.from_py_func(bool) - def simple_if(cond): - if_op = scf.IfOp(cond) - with InsertionPoint(if_op.then_block): - one = arith.ConstantOp(i32, 1) - add = arith.AddIOp(one, one) - scf.YieldOp([]) - return + @func.FuncOp.from_py_func(bool) + def 
simple_if(cond): + if_op = scf.IfOp(cond) + with InsertionPoint(if_op.then_block): + one = arith.ConstantOp(i32, 1) + add = arith.AddIOp(one, one) + scf.YieldOp([]) + return # CHECK: func @simple_if(%[[ARG0:.*]]: i1) @@ -108,22 +107,22 @@ def testIfWithoutElse(): @constructAndPrintInModule def testIfWithElse(): - bool = IntegerType.get_signless(1) - i32 = IntegerType.get_signless(32) - - @func.FuncOp.from_py_func(bool) - def simple_if_else(cond): - if_op = scf.IfOp(cond, [i32, i32], hasElse=True) - with InsertionPoint(if_op.then_block): - x_true = arith.ConstantOp(i32, 0) - y_true = arith.ConstantOp(i32, 1) - scf.YieldOp([x_true, y_true]) - with InsertionPoint(if_op.else_block): - x_false = arith.ConstantOp(i32, 2) - y_false = arith.ConstantOp(i32, 3) - scf.YieldOp([x_false, y_false]) - add = arith.AddIOp(if_op.results[0], if_op.results[1]) - return + bool = IntegerType.get_signless(1) + i32 = IntegerType.get_signless(32) + + @func.FuncOp.from_py_func(bool) + def simple_if_else(cond): + if_op = scf.IfOp(cond, [i32, i32], hasElse=True) + with InsertionPoint(if_op.then_block): + x_true = arith.ConstantOp(i32, 0) + y_true = arith.ConstantOp(i32, 1) + scf.YieldOp([x_true, y_true]) + with InsertionPoint(if_op.else_block): + x_false = arith.ConstantOp(i32, 2) + y_false = arith.ConstantOp(i32, 3) + scf.YieldOp([x_false, y_false]) + add = arith.AddIOp(if_op.results[0], if_op.results[1]) + return # CHECK: func @simple_if_else(%[[ARG0:.*]]: i1) diff --git a/mlir/test/python/dialects/shape.py b/mlir/test/python/dialects/shape.py index 3e7a8b2..ad75585 100644 --- a/mlir/test/python/dialects/shape.py +++ b/mlir/test/python/dialects/shape.py @@ -7,36 +7,38 @@ import mlir.dialects.shape as shape def run(f): - print("\nTEST:", f.__name__) - f() - return f + print("\nTEST:", f.__name__) + f() + return f # CHECK-LABEL: TEST: testConstShape @run def testConstShape(): - with Context() as ctx, Location.unknown(): - module = Module.create() - f32 = F32Type.get() - with InsertionPoint(module.body): - @func.FuncOp.from_py_func( - RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32)) - def const_shape_tensor(arg): - shape.ConstWitnessOp(False) - shape.ConstSizeOp(30) - shape.ConstSizeOp(IntegerAttr.get(IndexType.get(), 40)) - x = shape.ConstShapeOp([1, 2]) - shape.MeetOp(x, x, error="impossible") - return shape.ConstShapeOp( - DenseElementsAttr.get( - np.array([3, 4], dtype=np.int64), type=IndexType.get())) - - - - # CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>) - # CHECK-DAG: shape.const_witness false - # CHECK-DAG: shape.const_size 30 - # CHECK-DAG: shape.const_size 40 - # CHECK-DAG: shape.const_shape [1, 2] : tensor<2xindex> - # CHECK-DAG: shape.const_shape [3, 4] : tensor<2xindex> - print(module) + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func( + RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32) + ) + def const_shape_tensor(arg): + shape.ConstWitnessOp(False) + shape.ConstSizeOp(30) + shape.ConstSizeOp(IntegerAttr.get(IndexType.get(), 40)) + x = shape.ConstShapeOp([1, 2]) + shape.MeetOp(x, x, error="impossible") + return shape.ConstShapeOp( + DenseElementsAttr.get( + np.array([3, 4], dtype=np.int64), type=IndexType.get() + ) + ) + + # CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>) + # CHECK-DAG: shape.const_witness false + # CHECK-DAG: shape.const_size 30 + # CHECK-DAG: shape.const_size 40 + # CHECK-DAG: shape.const_shape [1, 2] : 
tensor<2xindex> + # CHECK-DAG: shape.const_shape [3, 4] : tensor<2xindex> + print(module) diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py index 6190beb..b7a0606 100644 --- a/mlir/test/python/dialects/sparse_tensor/dialect.py +++ b/mlir/test/python/dialects/sparse_tensor/dialect.py @@ -3,97 +3,106 @@ from mlir.ir import * from mlir.dialects import sparse_tensor as st + def run(f): - print("\nTEST:", f.__name__) - f() - return f + print("\nTEST:", f.__name__) + f() + return f # CHECK-LABEL: TEST: testEncodingAttr1D @run def testEncodingAttr1D(): - with Context() as ctx: - parsed = Attribute.parse('#sparse_tensor.encoding<{' - ' lvlTypes = [ "compressed" ],' - ' posWidth = 16,' - ' crdWidth = 32' - '}>') - # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 16, crdWidth = 32 }> - print(parsed) - - casted = st.EncodingAttr(parsed) - # CHECK: equal: True - print(f"equal: {casted == parsed}") - - # CHECK: lvl_types: [] - print(f"lvl_types: {casted.lvl_types}") - # CHECK: dim_ordering: None - print(f"dim_ordering: {casted.dim_ordering}") - # CHECK: pos_width: 16 - print(f"pos_width: {casted.pos_width}") - # CHECK: crd_width: 32 - print(f"crd_width: {casted.crd_width}") - - created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0) - # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }> - print(created) - # CHECK: created_equal: False - print(f"created_equal: {created == casted}") - - # Verify that the factory creates an instance of the proper type. - # CHECK: is_proper_instance: True - print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}") - # CHECK: created_pos_width: 0 - print(f"created_pos_width: {created.pos_width}") + with Context() as ctx: + parsed = Attribute.parse( + "#sparse_tensor.encoding<{" + ' lvlTypes = [ "compressed" ],' + " posWidth = 16," + " crdWidth = 32" + "}>" + ) + # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 16, crdWidth = 32 }> + print(parsed) + + casted = st.EncodingAttr(parsed) + # CHECK: equal: True + print(f"equal: {casted == parsed}") + + # CHECK: lvl_types: [] + print(f"lvl_types: {casted.lvl_types}") + # CHECK: dim_ordering: None + print(f"dim_ordering: {casted.dim_ordering}") + # CHECK: pos_width: 16 + print(f"pos_width: {casted.pos_width}") + # CHECK: crd_width: 32 + print(f"crd_width: {casted.crd_width}") + + created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0) + # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }> + print(created) + # CHECK: created_equal: False + print(f"created_equal: {created == casted}") + + # Verify that the factory creates an instance of the proper type. 
+ # CHECK: is_proper_instance: True + print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}") + # CHECK: created_pos_width: 0 + print(f"created_pos_width: {created.pos_width}") # CHECK-LABEL: TEST: testEncodingAttr2D @run def testEncodingAttr2D(): - with Context() as ctx: - parsed = Attribute.parse('#sparse_tensor.encoding<{' - ' lvlTypes = [ "dense", "compressed" ],' - ' dimOrdering = affine_map<(d0, d1) -> (d1, d0)>,' - ' posWidth = 8,' - ' crdWidth = 32' - '}>') - # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, posWidth = 8, crdWidth = 32 }> - print(parsed) - - casted = st.EncodingAttr(parsed) - # CHECK: equal: True - print(f"equal: {casted == parsed}") - - # CHECK: lvl_types: [, ] - print(f"lvl_types: {casted.lvl_types}") - # CHECK: dim_ordering: (d0, d1) -> (d1, d0) - print(f"dim_ordering: {casted.dim_ordering}") - # CHECK: pos_width: 8 - print(f"pos_width: {casted.pos_width}") - # CHECK: crd_width: 32 - print(f"crd_width: {casted.crd_width}") - - created = st.EncodingAttr.get(casted.lvl_types, casted.dim_ordering, - casted.higher_ordering, 8, 32) - # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, posWidth = 8, crdWidth = 32 }> - print(created) - # CHECK: created_equal: True - print(f"created_equal: {created == casted}") + with Context() as ctx: + parsed = Attribute.parse( + "#sparse_tensor.encoding<{" + ' lvlTypes = [ "dense", "compressed" ],' + " dimOrdering = affine_map<(d0, d1) -> (d1, d0)>," + " posWidth = 8," + " crdWidth = 32" + "}>" + ) + # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, posWidth = 8, crdWidth = 32 }> + print(parsed) + + casted = st.EncodingAttr(parsed) + # CHECK: equal: True + print(f"equal: {casted == parsed}") + + # CHECK: lvl_types: [, ] + print(f"lvl_types: {casted.lvl_types}") + # CHECK: dim_ordering: (d0, d1) -> (d1, d0) + print(f"dim_ordering: {casted.dim_ordering}") + # CHECK: pos_width: 8 + print(f"pos_width: {casted.pos_width}") + # CHECK: crd_width: 32 + print(f"crd_width: {casted.crd_width}") + + created = st.EncodingAttr.get( + casted.lvl_types, casted.dim_ordering, casted.higher_ordering, 8, 32 + ) + # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, posWidth = 8, crdWidth = 32 }> + print(created) + # CHECK: created_equal: True + print(f"created_equal: {created == casted}") # CHECK-LABEL: TEST: testEncodingAttrOnTensorType @run def testEncodingAttrOnTensorType(): - with Context() as ctx, Location.unknown(): - encoding = st.EncodingAttr( - Attribute.parse('#sparse_tensor.encoding<{' - ' lvlTypes = [ "compressed" ], ' - ' posWidth = 64,' - ' crdWidth = 32' - '}>')) - tt = RankedTensorType.get((1024,), F32Type.get(), encoding=encoding) - # CHECK: tensor<1024xf32, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 64, crdWidth = 32 }>> - print(tt) - # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 64, crdWidth = 32 }> - print(tt.encoding) - assert tt.encoding == encoding + with Context() as ctx, Location.unknown(): + encoding = st.EncodingAttr( + Attribute.parse( + "#sparse_tensor.encoding<{" + ' lvlTypes = [ "compressed" ], ' + " posWidth = 64," + " crdWidth = 32" + "}>" + ) + ) + tt = RankedTensorType.get((1024,), F32Type.get(), encoding=encoding) + # CHECK: tensor<1024xf32, #sparse_tensor.encoding<{ lvlTypes = [ 
"compressed" ], posWidth = 64, crdWidth = 32 }>> + print(tt) + # CHECK: #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 64, crdWidth = 32 }> + print(tt.encoding) + assert tt.encoding == encoding diff --git a/mlir/test/python/dialects/sparse_tensor/passes.py b/mlir/test/python/dialects/sparse_tensor/passes.py index 9319e16..c37c520 100644 --- a/mlir/test/python/dialects/sparse_tensor/passes.py +++ b/mlir/test/python/dialects/sparse_tensor/passes.py @@ -7,16 +7,16 @@ from mlir.dialects import sparse_tensor as st def run(f): - print('\nTEST:', f.__name__) - f() - return f + print("\nTEST:", f.__name__) + f() + return f # CHECK-LABEL: TEST: testSparseTensorPass @run def testSparseTensorPass(): - with Context() as context: - PassManager.parse('any(sparsification)') - PassManager.parse('any(sparse-tensor-conversion)') - # CHECK: SUCCESS - print('SUCCESS') + with Context() as context: + PassManager.parse("any(sparsification)") + PassManager.parse("any(sparse-tensor-conversion)") + # CHECK: SUCCESS + print("SUCCESS") diff --git a/mlir/test/python/dialects/tensor.py b/mlir/test/python/dialects/tensor.py index b0ad4b4..b690c93 100644 --- a/mlir/test/python/dialects/tensor.py +++ b/mlir/test/python/dialects/tensor.py @@ -7,125 +7,135 @@ import mlir.dialects.tensor as tensor def run(f): - print("\nTEST:", f.__name__) - f() - return f + print("\nTEST:", f.__name__) + f() + return f # CHECK-LABEL: TEST: testDimOp @run def testDimOp(): - with Context() as ctx, Location.unknown(): - module = Module.create() - f32Type = F32Type.get() - indexType = IndexType.get() - with InsertionPoint(module.body): - - @func.FuncOp.from_py_func( - RankedTensorType.get( - (ShapedType.get_dynamic_size(), ShapedType.get_dynamic_size()), - f32Type)) - # CHECK: func @tensor_static_dim - # CHECK-SAME: %[[ARG0:.+]]: tensor - # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index - # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index - # CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] - # CHECK: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] - # CHECK: return %[[D0]], %[[D1]] - def tensor_static_dim(t): - c0 = arith.ConstantOp(indexType, 0) - c1 = arith.ConstantOp(indexType, 1) - d0 = tensor.DimOp(t, c0) - d1 = tensor.DimOp(t, c1) - return [d0.result, d1.result] - - print(module) + with Context() as ctx, Location.unknown(): + module = Module.create() + f32Type = F32Type.get() + indexType = IndexType.get() + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func( + RankedTensorType.get( + (ShapedType.get_dynamic_size(), ShapedType.get_dynamic_size()), + f32Type, + ) + ) + # CHECK: func @tensor_static_dim + # CHECK-SAME: %[[ARG0:.+]]: tensor + # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + # CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] + # CHECK: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] + # CHECK: return %[[D0]], %[[D1]] + def tensor_static_dim(t): + c0 = arith.ConstantOp(indexType, 0) + c1 = arith.ConstantOp(indexType, 1) + d0 = tensor.DimOp(t, c0) + d1 = tensor.DimOp(t, c1) + return [d0.result, d1.result] + + print(module) # CHECK-LABEL: TEST: testEmptyOp @run def testEmptyOp(): - with Context() as ctx, Location.unknown(): - module = Module.create() - f32 = F32Type.get() - with InsertionPoint(module.body): - # CHECK-LABEL: func @static_sizes - # CHECK: %0 = tensor.empty() : tensor<3x4xf32> - @func.FuncOp.from_py_func() - def static_sizes(): - return tensor.EmptyOp([3, 4], f32) - - # CHECK-LABEL: func @dynamic_sizes - # CHECK: %0 = 
tensor.empty(%arg0, %arg1) : tensor - @func.FuncOp.from_py_func(IndexType.get(), IndexType.get()) - def dynamic_sizes(d0, d1): - return tensor.EmptyOp([d0, d1], f32) - - # CHECK-LABEL: func @mixed_static_dynamic_sizes - # CHECK: %0 = tensor.empty(%arg0) : tensor - @func.FuncOp.from_py_func(IndexType.get()) - def mixed_static_dynamic_sizes(d0): - return tensor.EmptyOp([d0, 4], f32) - - # CHECK-LABEL: func @zero_d - # CHECK: %0 = tensor.empty() : tensor - @func.FuncOp.from_py_func() - def zero_d(): - return tensor.EmptyOp([], f32) - - print(module) + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + # CHECK-LABEL: func @static_sizes + # CHECK: %0 = tensor.empty() : tensor<3x4xf32> + @func.FuncOp.from_py_func() + def static_sizes(): + return tensor.EmptyOp([3, 4], f32) + + # CHECK-LABEL: func @dynamic_sizes + # CHECK: %0 = tensor.empty(%arg0, %arg1) : tensor + @func.FuncOp.from_py_func(IndexType.get(), IndexType.get()) + def dynamic_sizes(d0, d1): + return tensor.EmptyOp([d0, d1], f32) + + # CHECK-LABEL: func @mixed_static_dynamic_sizes + # CHECK: %0 = tensor.empty(%arg0) : tensor + @func.FuncOp.from_py_func(IndexType.get()) + def mixed_static_dynamic_sizes(d0): + return tensor.EmptyOp([d0, 4], f32) + + # CHECK-LABEL: func @zero_d + # CHECK: %0 = tensor.empty() : tensor + @func.FuncOp.from_py_func() + def zero_d(): + return tensor.EmptyOp([], f32) + + print(module) # CHECK-LABEL: TEST: testInferTypesInsertSlice @run def testInferTypesInsertSlice(): - with Context() as ctx, Location.unknown(): - module = Module.create() - f32Type = F32Type.get() - with InsertionPoint(module.body): - - @func.FuncOp.from_py_func( - RankedTensorType.get((1, 1), f32Type), - RankedTensorType.get((1, 1), f32Type)) - # CHECK: func @f - # CHECK: tensor.insert_slice %arg0 into %arg1[0, 0] [1, 1] [0, 0] : - # CHECK-SAME: tensor<1x1xf32> into tensor<1x1xf32> - def f(source, dest): - d0 = tensor.InsertSliceOp(source, dest, [], [], [], - DenseI64ArrayAttr.get([0, 0]), - DenseI64ArrayAttr.get([1, 1]), - DenseI64ArrayAttr.get([0, 0])) - return [d0.result] - - print(module) + with Context() as ctx, Location.unknown(): + module = Module.create() + f32Type = F32Type.get() + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func( + RankedTensorType.get((1, 1), f32Type), + RankedTensorType.get((1, 1), f32Type), + ) + # CHECK: func @f + # CHECK: tensor.insert_slice %arg0 into %arg1[0, 0] [1, 1] [0, 0] : + # CHECK-SAME: tensor<1x1xf32> into tensor<1x1xf32> + def f(source, dest): + d0 = tensor.InsertSliceOp( + source, + dest, + [], + [], + [], + DenseI64ArrayAttr.get([0, 0]), + DenseI64ArrayAttr.get([1, 1]), + DenseI64ArrayAttr.get([0, 0]), + ) + return [d0.result] + + print(module) # CHECK-LABEL: TEST: testFromElementsOp @run def testFromElementsOp(): - with Context() as ctx, Location.unknown(): - module = Module.create() - f32 = F32Type.get() - with InsertionPoint(module.body): - @func.FuncOp.from_py_func() - def default_builder(): - c0 = arith.ConstantOp(f32, 0.0) - # CHECK: %[[C0:.*]] = "arith.constant - # CHECK-SAME: value = 0.000000e+00 : f32 - print(c0) - c1 = arith.ConstantOp(f32, 1.0) - # CHECK: %[[C1:.*]] = "arith.constant - # CHECK-SAME: value = 1.000000e+00 : f32 - print(c1) - - t = tensor.FromElementsOp(RankedTensorType.get((2,), f32), [c0, c1]) - # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<2xf32> - print(t) - - t = tensor.FromElementsOp(RankedTensorType.get((2, 1), f32), [c0, c1]) - # CHECK: 
%{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<2x1xf32> - print(t) - - t = tensor.FromElementsOp(RankedTensorType.get((1, 2), f32), [c0, c1]) - # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<1x2xf32> - print(t) + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func() + def default_builder(): + c0 = arith.ConstantOp(f32, 0.0) + # CHECK: %[[C0:.*]] = "arith.constant + # CHECK-SAME: value = 0.000000e+00 : f32 + print(c0) + c1 = arith.ConstantOp(f32, 1.0) + # CHECK: %[[C1:.*]] = "arith.constant + # CHECK-SAME: value = 1.000000e+00 : f32 + print(c1) + + t = tensor.FromElementsOp(RankedTensorType.get((2,), f32), [c0, c1]) + # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<2xf32> + print(t) + + t = tensor.FromElementsOp(RankedTensorType.get((2, 1), f32), [c0, c1]) + # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<2x1xf32> + print(t) + + t = tensor.FromElementsOp(RankedTensorType.get((1, 2), f32), [c0, c1]) + # CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<1x2xf32> + print(t) diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py index 6b36c02..ca6499b 100644 --- a/mlir/test/python/dialects/transform.py +++ b/mlir/test/python/dialects/transform.py @@ -6,158 +6,188 @@ from mlir.dialects.transform import pdl as transform_pdl def run(f): - with Context(), Location.unknown(): - module = Module.create() - with InsertionPoint(module.body): - print("\nTEST:", f.__name__) - f() - print(module) - return f + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + print("\nTEST:", f.__name__) + f() + print(module) + return f @run def testTypes(): - # CHECK-LABEL: TEST: testTypes - # CHECK: !transform.any_op - any_op = transform.AnyOpType.get() - print(any_op) + # CHECK-LABEL: TEST: testTypes + # CHECK: !transform.any_op + any_op = transform.AnyOpType.get() + print(any_op) - # CHECK: !transform.op<"foo.bar"> - # CHECK: foo.bar - concrete_op = transform.OperationType.get("foo.bar") - print(concrete_op) - print(concrete_op.operation_name) + # CHECK: !transform.op<"foo.bar"> + # CHECK: foo.bar + concrete_op = transform.OperationType.get("foo.bar") + print(concrete_op) + print(concrete_op.operation_name) @run def testSequenceOp(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, - [transform.AnyOpType.get()], - transform.AnyOpType.get()) - with InsertionPoint(sequence.body): - transform.YieldOp([sequence.bodyTarget]) - # CHECK-LABEL: TEST: testSequenceOp - # CHECK: = transform.sequence -> !transform.any_op failures(propagate) { - # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op): - # CHECK: yield %[[ARG0]] : !transform.any_op - # CHECK: } + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, + [transform.AnyOpType.get()], + transform.AnyOpType.get(), + ) + with InsertionPoint(sequence.body): + transform.YieldOp([sequence.bodyTarget]) + # CHECK-LABEL: TEST: testSequenceOp + # CHECK: = transform.sequence -> !transform.any_op failures(propagate) { + # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op): + # CHECK: yield %[[ARG0]] : !transform.any_op + # CHECK: } @run def testNestedSequenceOp(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()) - with 
InsertionPoint(sequence.body): - nested = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], sequence.bodyTarget) - with InsertionPoint(nested.body): - doubly_nested = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, - [transform.AnyOpType.get()], nested.bodyTarget) - with InsertionPoint(doubly_nested.body): - transform.YieldOp([doubly_nested.bodyTarget]) - transform.YieldOp() - transform.YieldOp() - # CHECK-LABEL: TEST: testNestedSequenceOp - # CHECK: transform.sequence failures(propagate) { - # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op): - # CHECK: sequence %[[ARG0]] : !transform.any_op failures(propagate) { - # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): - # CHECK: = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) { - # CHECK: ^{{.*}}(%[[ARG2:.+]]: !transform.any_op): - # CHECK: yield %[[ARG2]] : !transform.any_op - # CHECK: } - # CHECK: } - # CHECK: } + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + ) + with InsertionPoint(sequence.body): + nested = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], sequence.bodyTarget + ) + with InsertionPoint(nested.body): + doubly_nested = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, + [transform.AnyOpType.get()], + nested.bodyTarget, + ) + with InsertionPoint(doubly_nested.body): + transform.YieldOp([doubly_nested.bodyTarget]) + transform.YieldOp() + transform.YieldOp() + # CHECK-LABEL: TEST: testNestedSequenceOp + # CHECK: transform.sequence failures(propagate) { + # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op): + # CHECK: sequence %[[ARG0]] : !transform.any_op failures(propagate) { + # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): + # CHECK: = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) { + # CHECK: ^{{.*}}(%[[ARG2:.+]]: !transform.any_op): + # CHECK: yield %[[ARG2]] : !transform.any_op + # CHECK: } + # CHECK: } + # CHECK: } @run def testSequenceOpWithExtras(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get(), - [transform.AnyOpType.get(), - transform.OperationType.get("foo.bar")]) - with InsertionPoint(sequence.body): - transform.YieldOp() - # CHECK-LABEL: TEST: testSequenceOpWithExtras - # CHECK: transform.sequence failures(propagate) - # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, + [], + transform.AnyOpType.get(), + [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")], + ) + with InsertionPoint(sequence.body): + transform.YieldOp() + # CHECK-LABEL: TEST: testSequenceOpWithExtras + # CHECK: transform.sequence failures(propagate) + # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">): @run def testNestedSequenceOpWithExtras(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get(), - [transform.AnyOpType.get(), - transform.OperationType.get("foo.bar")]) - with InsertionPoint(sequence.body): - nested = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, - [], sequence.bodyTarget, - sequence.bodyExtraArgs) - with InsertionPoint(nested.body): - transform.YieldOp() - transform.YieldOp() - # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras - # CHECK: transform.sequence 
failures(propagate) - # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">): - # CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">) + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, + [], + transform.AnyOpType.get(), + [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")], + ) + with InsertionPoint(sequence.body): + nested = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, + [], + sequence.bodyTarget, + sequence.bodyExtraArgs, + ) + with InsertionPoint(nested.body): + transform.YieldOp() + transform.YieldOp() + # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras + # CHECK: transform.sequence failures(propagate) + # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">): + # CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">) @run def testTransformPDLOps(): - withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get()) - with InsertionPoint(withPdl.body): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, - [transform.AnyOpType.get()], - withPdl.bodyTarget) - with InsertionPoint(sequence.body): - match = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher") - transform.YieldOp(match) - # CHECK-LABEL: TEST: testTransformPDLOps - # CHECK: transform.with_pdl_patterns { - # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op): - # CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) { - # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): - # CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]] - # CHECK: yield %[[RES]] : !transform.any_op - # CHECK: } - # CHECK: } + withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get()) + with InsertionPoint(withPdl.body): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, + [transform.AnyOpType.get()], + withPdl.bodyTarget, + ) + with InsertionPoint(sequence.body): + match = transform_pdl.PDLMatchOp( + transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher" + ) + transform.YieldOp(match) + # CHECK-LABEL: TEST: testTransformPDLOps + # CHECK: transform.with_pdl_patterns { + # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op): + # CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) { + # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): + # CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]] + # CHECK: yield %[[RES]] : !transform.any_op + # CHECK: } + # CHECK: } @run def testGetClosestIsolatedParentOp(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()) - with InsertionPoint(sequence.body): - transform.GetClosestIsolatedParentOp(transform.AnyOpType.get(), sequence.bodyTarget) - transform.YieldOp() - # CHECK-LABEL: TEST: testGetClosestIsolatedParentOp - # CHECK: transform.sequence - # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): - # CHECK: = get_closest_isolated_parent %[[ARG1]] + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + ) + with InsertionPoint(sequence.body): + transform.GetClosestIsolatedParentOp( + transform.AnyOpType.get(), sequence.bodyTarget + ) + transform.YieldOp() + # CHECK-LABEL: TEST: testGetClosestIsolatedParentOp 
+ # CHECK: transform.sequence + # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): + # CHECK: = get_closest_isolated_parent %[[ARG1]] @run def testMergeHandlesOp(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()) - with InsertionPoint(sequence.body): - transform.MergeHandlesOp([sequence.bodyTarget]) - transform.YieldOp() - # CHECK-LABEL: TEST: testMergeHandlesOp - # CHECK: transform.sequence - # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): - # CHECK: = merge_handles %[[ARG1]] + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + ) + with InsertionPoint(sequence.body): + transform.MergeHandlesOp([sequence.bodyTarget]) + transform.YieldOp() + # CHECK-LABEL: TEST: testMergeHandlesOp + # CHECK: transform.sequence + # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): + # CHECK: = merge_handles %[[ARG1]] @run def testReplicateOp(): - with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get()) - with InsertionPoint(with_pdl.body): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget) - with InsertionPoint(sequence.body): - m1 = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "first") - m2 = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "second") - transform.ReplicateOp(m1, [m2]) - transform.YieldOp() - # CHECK-LABEL: TEST: testReplicateOp - # CHECK: %[[FIRST:.+]] = pdl_match - # CHECK: %[[SECOND:.+]] = pdl_match - # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]] + with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get()) + with InsertionPoint(with_pdl.body): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget + ) + with InsertionPoint(sequence.body): + m1 = transform_pdl.PDLMatchOp( + transform.AnyOpType.get(), sequence.bodyTarget, "first" + ) + m2 = transform_pdl.PDLMatchOp( + transform.AnyOpType.get(), sequence.bodyTarget, "second" + ) + transform.ReplicateOp(m1, [m2]) + transform.YieldOp() + # CHECK-LABEL: TEST: testReplicateOp + # CHECK: %[[FIRST:.+]] = pdl_match + # CHECK: %[[SECOND:.+]] = pdl_match + # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]] diff --git a/mlir/test/python/dialects/transform_loop_ext.py b/mlir/test/python/dialects/transform_loop_ext.py index 067a8b6..28a022a 100644 --- a/mlir/test/python/dialects/transform_loop_ext.py +++ b/mlir/test/python/dialects/transform_loop_ext.py @@ -7,70 +7,92 @@ from mlir.dialects.transform import loop def run(f): - with Context(), Location.unknown(): - module = Module.create() - with InsertionPoint(module.body): - print("\nTEST:", f.__name__) - f() - print(module) - return f + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + print("\nTEST:", f.__name__) + f() + print(module) + return f @run def getParentLoop(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, - [], pdl.OperationType.get()) - with InsertionPoint(sequence.body): - loop.GetParentForOp(transform.OperationType.get("scf.for"), sequence.bodyTarget, num_loops=2) - transform.YieldOp() - # CHECK-LABEL: TEST: getParentLoop - # CHECK: = transform.loop.get_parent_for % - # CHECK: num_loops = 2 + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get() + ) + with InsertionPoint(sequence.body): + loop.GetParentForOp( + 
transform.OperationType.get("scf.for"), sequence.bodyTarget, num_loops=2 + ) + transform.YieldOp() + # CHECK-LABEL: TEST: getParentLoop + # CHECK: = transform.loop.get_parent_for % + # CHECK: num_loops = 2 @run def loopOutline(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, - [], transform.OperationType.get("scf.for")) - with InsertionPoint(sequence.body): - loop.LoopOutlineOp(transform.AnyOpType.get(), transform.AnyOpType.get(), sequence.bodyTarget, func_name="foo") - transform.YieldOp() - # CHECK-LABEL: TEST: loopOutline - # CHECK: = transform.loop.outline % - # CHECK: func_name = "foo" + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, + [], + transform.OperationType.get("scf.for"), + ) + with InsertionPoint(sequence.body): + loop.LoopOutlineOp( + transform.AnyOpType.get(), + transform.AnyOpType.get(), + sequence.bodyTarget, + func_name="foo", + ) + transform.YieldOp() + # CHECK-LABEL: TEST: loopOutline + # CHECK: = transform.loop.outline % + # CHECK: func_name = "foo" @run def loopPeel(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, - [], transform.OperationType.get("scf.for")) - with InsertionPoint(sequence.body): - loop.LoopPeelOp(pdl.OperationType.get(), sequence.bodyTarget) - transform.YieldOp() - # CHECK-LABEL: TEST: loopPeel - # CHECK: = transform.loop.peel % + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, + [], + transform.OperationType.get("scf.for"), + ) + with InsertionPoint(sequence.body): + loop.LoopPeelOp(pdl.OperationType.get(), sequence.bodyTarget) + transform.YieldOp() + # CHECK-LABEL: TEST: loopPeel + # CHECK: = transform.loop.peel % @run def loopPipeline(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, - [], transform.OperationType.get("scf.for")) - with InsertionPoint(sequence.body): - loop.LoopPipelineOp(pdl.OperationType.get(), sequence.bodyTarget, iteration_interval=3) - transform.YieldOp() - # CHECK-LABEL: TEST: loopPipeline - # CHECK: = transform.loop.pipeline % - # CHECK-DAG: iteration_interval = 3 - # (read_latency has default value and is not printed) + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, + [], + transform.OperationType.get("scf.for"), + ) + with InsertionPoint(sequence.body): + loop.LoopPipelineOp( + pdl.OperationType.get(), sequence.bodyTarget, iteration_interval=3 + ) + transform.YieldOp() + # CHECK-LABEL: TEST: loopPipeline + # CHECK: = transform.loop.pipeline % + # CHECK-DAG: iteration_interval = 3 + # (read_latency has default value and is not printed) @run def loopUnroll(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, - [], transform.OperationType.get("scf.for")) - with InsertionPoint(sequence.body): - loop.LoopUnrollOp(sequence.bodyTarget, factor=42) - transform.YieldOp() - # CHECK-LABEL: TEST: loopUnroll - # CHECK: transform.loop.unroll % - # CHECK: factor = 42 + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, + [], + transform.OperationType.get("scf.for"), + ) + with InsertionPoint(sequence.body): + loop.LoopUnrollOp(sequence.bodyTarget, factor=42) + transform.YieldOp() + # CHECK-LABEL: TEST: loopUnroll + # CHECK: transform.loop.unroll % + # CHECK: factor = 42 diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py index d2a82b8..2dfae47 100644 --- a/mlir/test/python/dialects/transform_structured_ext.py +++ 
b/mlir/test/python/dialects/transform_structured_ext.py @@ -8,204 +8,230 @@ from mlir.dialects.transform import pdl as transform_pdl def run(f): - with Context(), Location.unknown(): - module = Module.create() - with InsertionPoint(module.body): - print("\nTEST:", f.__name__) - f() - print(module) - return f + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + print("\nTEST:", f.__name__) + f() + print(module) + return f @run def testDecompose(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) - with InsertionPoint(sequence.body): - structured.DecomposeOp(sequence.bodyTarget) - transform.YieldOp() - # CHECK-LABEL: TEST: testDecompose - # CHECK: transform.sequence - # CHECK: transform.structured.decompose + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get() + ) + with InsertionPoint(sequence.body): + structured.DecomposeOp(sequence.bodyTarget) + transform.YieldOp() + # CHECK-LABEL: TEST: testDecompose + # CHECK: transform.sequence + # CHECK: transform.structured.decompose @run def testGeneralize(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) - with InsertionPoint(sequence.body): - structured.GeneralizeOp(sequence.bodyTarget) - transform.YieldOp() - # CHECK-LABEL: TEST: testGeneralize - # CHECK: transform.sequence - # CHECK: transform.structured.generalize + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get() + ) + with InsertionPoint(sequence.body): + structured.GeneralizeOp(sequence.bodyTarget) + transform.YieldOp() + # CHECK-LABEL: TEST: testGeneralize + # CHECK: transform.sequence + # CHECK: transform.structured.generalize @run def testInterchange(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) - with InsertionPoint(sequence.body): - structured.InterchangeOp( - sequence.bodyTarget, - iterator_interchange=[1, 0]) - transform.YieldOp() - # CHECK-LABEL: TEST: testInterchange - # CHECK: transform.sequence - # CHECK: transform.structured.interchange - # CHECK: iterator_interchange = [1, 0] + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get() + ) + with InsertionPoint(sequence.body): + structured.InterchangeOp(sequence.bodyTarget, iterator_interchange=[1, 0]) + transform.YieldOp() + # CHECK-LABEL: TEST: testInterchange + # CHECK: transform.sequence + # CHECK: transform.structured.interchange + # CHECK: iterator_interchange = [1, 0] @run def testMultitileSizes(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) - with InsertionPoint(sequence.body): - structured.MultiTileSizesOp(pdl.OperationType.get(), - sequence.bodyTarget, - dimension=1, - target_size=42) - transform.YieldOp() - # CHECK-LABEL: TEST: testMultitileSizes - # CHECK: transform.sequence - # CHECK: transform.structured.multitile_sizes - # CHECK-DAG: dimension = 1 - # CHECK-DAG: target_size = 42 + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get() + ) + with InsertionPoint(sequence.body): + structured.MultiTileSizesOp( + pdl.OperationType.get(), sequence.bodyTarget, dimension=1, target_size=42 + ) + transform.YieldOp() + # CHECK-LABEL: TEST: testMultitileSizes + # CHECK: transform.sequence + # CHECK: 
transform.structured.multitile_sizes + # CHECK-DAG: dimension = 1 + # CHECK-DAG: target_size = 42 @run def testPad(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) - with InsertionPoint(sequence.body): - structured.PadOp( - sequence.bodyTarget, - padding_values=[FloatAttr.get_f32(42.0)], - padding_dimensions=[1], - transpose_paddings=[[1, 0]]) - transform.YieldOp() - # CHECK-LABEL: TEST: testPad - # CHECK: transform.sequence - # CHECK: transform.structured.pad - # CHECK-DAG: padding_values = [4.200000e+01 : f32] - # CHECK-DAG: padding_dimensions = [1] - # CHECK-DAG: transpose_paddings = {{\[}}[1, 0]] - # (pack_paddings has default values) + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get() + ) + with InsertionPoint(sequence.body): + structured.PadOp( + sequence.bodyTarget, + padding_values=[FloatAttr.get_f32(42.0)], + padding_dimensions=[1], + transpose_paddings=[[1, 0]], + ) + transform.YieldOp() + # CHECK-LABEL: TEST: testPad + # CHECK: transform.sequence + # CHECK: transform.structured.pad + # CHECK-DAG: padding_values = [4.200000e+01 : f32] + # CHECK-DAG: padding_dimensions = [1] + # CHECK-DAG: transpose_paddings = {{\[}}[1, 0]] + # (pack_paddings has default values) + @run def testScalarize(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) - with InsertionPoint(sequence.body): - structured.ScalarizeOp(sequence.bodyTarget) - transform.YieldOp() - # CHECK-LABEL: TEST: testScalarize - # CHECK: transform.structured.scalarize + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get() + ) + with InsertionPoint(sequence.body): + structured.ScalarizeOp(sequence.bodyTarget) + transform.YieldOp() + # CHECK-LABEL: TEST: testScalarize + # CHECK: transform.structured.scalarize @run def testSplit(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) - with InsertionPoint(sequence.body): - split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42) - structured.SplitOp( - split.results[0], dimension=3, split_point=split.results[1]) - transform.YieldOp() - # CHECK-LABEL: TEST: testSplit - # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1 - # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3 + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get() + ) + with InsertionPoint(sequence.body): + split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42) + structured.SplitOp(split.results[0], dimension=3, split_point=split.results[1]) + transform.YieldOp() + # CHECK-LABEL: TEST: testSplit + # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1 + # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3 + @run def testTileCompact(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) - with InsertionPoint(sequence.body): - structured.TileOp(sequence.bodyTarget, - sizes=[4, 8], - interchange=[0, 1]) - transform.YieldOp() - # CHECK-LABEL: TEST: testTileCompact - # CHECK: transform.sequence - # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8] - # CHECK: interchange = [0, 1] + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], 
pdl.OperationType.get() + ) + with InsertionPoint(sequence.body): + structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1]) + transform.YieldOp() + # CHECK-LABEL: TEST: testTileCompact + # CHECK: transform.sequence + # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8] + # CHECK: interchange = [0, 1] + @run def testTileAttributes(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) - attr = DenseI64ArrayAttr.get([4, 8]) - ichange = DenseI64ArrayAttr.get([0, 1]) - with InsertionPoint(sequence.body): - structured.TileOp(sequence.bodyTarget, - sizes=attr, - interchange=ichange) - transform.YieldOp() - # CHECK-LABEL: TEST: testTileAttributes - # CHECK: transform.sequence - # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8] - # CHECK: interchange = [0, 1] + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get() + ) + attr = DenseI64ArrayAttr.get([4, 8]) + ichange = DenseI64ArrayAttr.get([0, 1]) + with InsertionPoint(sequence.body): + structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange) + transform.YieldOp() + # CHECK-LABEL: TEST: testTileAttributes + # CHECK: transform.sequence + # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8] + # CHECK: interchange = [0, 1] + @run def testTileZero(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) - with InsertionPoint(sequence.body): - structured.TileOp(sequence.bodyTarget, - sizes=[4, 0, 2, 0], - interchange=[0, 1, 2, 3]) - transform.YieldOp() - # CHECK-LABEL: TEST: testTileZero - # CHECK: transform.sequence - # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 0, 2, 0] - # CHECK: interchange = [0, 1, 2, 3] + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get() + ) + with InsertionPoint(sequence.body): + structured.TileOp( + sequence.bodyTarget, sizes=[4, 0, 2, 0], interchange=[0, 1, 2, 3] + ) + transform.YieldOp() + # CHECK-LABEL: TEST: testTileZero + # CHECK: transform.sequence + # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 0, 2, 0] + # CHECK: interchange = [0, 1, 2, 3] + @run def testTileDynamic(): - with_pdl = transform_pdl.WithPDLPatternsOp(pdl.OperationType.get()) - with InsertionPoint(with_pdl.body): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], - with_pdl.bodyTarget) - with InsertionPoint(sequence.body): - m1 = transform_pdl.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "first") - m2 = transform_pdl.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "second") - structured.TileOp(sequence.bodyTarget, - sizes=[m1, 3, m2, 0]) - transform.YieldOp() - # CHECK-LABEL: TEST: testTileDynamic - # CHECK: %[[FIRST:.+]] = pdl_match - # CHECK: %[[SECOND:.+]] = pdl_match - # CHECK: %{{.+}}, %{{.+}}:3 = transform.structured.tile %{{.*}}[%[[FIRST]], 3, %[[SECOND]], 0] + with_pdl = transform_pdl.WithPDLPatternsOp(pdl.OperationType.get()) + with InsertionPoint(with_pdl.body): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget + ) + with InsertionPoint(sequence.body): + m1 = transform_pdl.PDLMatchOp( + pdl.OperationType.get(), sequence.bodyTarget, "first" + ) + m2 = transform_pdl.PDLMatchOp( + pdl.OperationType.get(), sequence.bodyTarget, "second" + ) + structured.TileOp(sequence.bodyTarget, sizes=[m1, 3, m2, 0]) + 
transform.YieldOp() + # CHECK-LABEL: TEST: testTileDynamic + # CHECK: %[[FIRST:.+]] = pdl_match + # CHECK: %[[SECOND:.+]] = pdl_match + # CHECK: %{{.+}}, %{{.+}}:3 = transform.structured.tile %{{.*}}[%[[FIRST]], 3, %[[SECOND]], 0] @run def testTileExplicitLoopTypeSingle(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, - [], transform.AnyOpType.get()) - with InsertionPoint(sequence.body): - structured.TileOp(transform.OperationType.get("scf.for"), - sequence.bodyTarget, - sizes=[2, 3, 4]) - transform.YieldOp() - # CHECK-LABEL: TEST: testTileExplicitLoopTypeSingle - # CHECK: = transform.structured.tile %{{.*}} : (!{{.*}}) -> - # CHECK-COUNT-3: !transform.op<"scf.for"> - + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + ) + with InsertionPoint(sequence.body): + structured.TileOp( + transform.OperationType.get("scf.for"), sequence.bodyTarget, sizes=[2, 3, 4] + ) + transform.YieldOp() + # CHECK-LABEL: TEST: testTileExplicitLoopTypeSingle + # CHECK: = transform.structured.tile %{{.*}} : (!{{.*}}) -> + # CHECK-COUNT-3: !transform.op<"scf.for"> @run def testTileExplicitLoopTypeAll(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, - [], transform.AnyOpType.get()) - types = [ - transform.OperationType.get(x) - for x in ["scf.for", "scf.parallel", "scf.forall"] - ] - with InsertionPoint(sequence.body): - structured.TileOp(types, sequence.bodyTarget, sizes=[2, 3, 4]) - transform.YieldOp() - # CHECK-LABEL: TEST: testTileExplicitLoopTypeAll - # CHECK: = transform.structured.tile - # CHECK-SAME : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, - # CHECK-SAME: !transform.op<"scf.parallel">, !transform.op<"scf.forall"> + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + ) + types = [ + transform.OperationType.get(x) + for x in ["scf.for", "scf.parallel", "scf.forall"] + ] + with InsertionPoint(sequence.body): + structured.TileOp(types, sequence.bodyTarget, sizes=[2, 3, 4]) + transform.YieldOp() + # CHECK-LABEL: TEST: testTileExplicitLoopTypeAll + # CHECK: = transform.structured.tile + # CHECK-SAME : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, + # CHECK-SAME: !transform.op<"scf.parallel">, !transform.op<"scf.forall"> + @run def testVectorize(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) - with InsertionPoint(sequence.body): - structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True) - transform.YieldOp() - # CHECK-LABEL: TEST: testVectorize - # CHECK: transform.sequence - # CHECK: = transform.structured.vectorize - # CHECK: {vectorize_padding} + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get() + ) + with InsertionPoint(sequence.body): + structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True) + transform.YieldOp() + # CHECK-LABEL: TEST: testVectorize + # CHECK: transform.sequence + # CHECK: = transform.structured.vectorize + # CHECK: {vectorize_padding} diff --git a/mlir/test/python/dialects/vector.py b/mlir/test/python/dialects/vector.py index 83c0961..2347abb 100644 --- a/mlir/test/python/dialects/vector.py +++ b/mlir/test/python/dialects/vector.py @@ -5,57 +5,62 @@ import mlir.dialects.builtin as builtin import mlir.dialects.func as func import mlir.dialects.vector as vector + def run(f): - print("\nTEST:", f.__name__) - with 
Context(), Location.unknown(): - f() - return f + print("\nTEST:", f.__name__) + with Context(), Location.unknown(): + f() + return f + # CHECK-LABEL: TEST: testPrintOp @run def testPrintOp(): - module = Module.create() - with InsertionPoint(module.body): + module = Module.create() + with InsertionPoint(module.body): - @func.FuncOp.from_py_func(VectorType.get((12, 5), F32Type.get())) - def print_vector(arg): - return vector.PrintOp(arg) + @func.FuncOp.from_py_func(VectorType.get((12, 5), F32Type.get())) + def print_vector(arg): + return vector.PrintOp(arg) - # CHECK-LABEL: func @print_vector( - # CHECK-SAME: %[[ARG:.*]]: vector<12x5xf32>) { - # CHECK: vector.print %[[ARG]] : vector<12x5xf32> - # CHECK: return - # CHECK: } - print(module) + # CHECK-LABEL: func @print_vector( + # CHECK-SAME: %[[ARG:.*]]: vector<12x5xf32>) { + # CHECK: vector.print %[[ARG]] : vector<12x5xf32> + # CHECK: return + # CHECK: } + print(module) # CHECK-LABEL: TEST: testTransferReadOp @run def testTransferReadOp(): - module = Module.create() - with InsertionPoint(module.body): - vector_type = VectorType.get([2, 3], F32Type.get()) - memref_type = MemRefType.get( - [ShapedType.get_dynamic_size(), - ShapedType.get_dynamic_size()], F32Type.get()) - index_type = IndexType.get() - mask_type = VectorType.get(vector_type.shape, IntegerType.get_signless(1)) - identity_map = AffineMap.get_identity(vector_type.rank) - identity_map_attr = AffineMapAttr.get(identity_map) - f = func.FuncOp("transfer_read", - ([memref_type, index_type, - F32Type.get(), mask_type], [])) - with InsertionPoint(f.add_entry_block()): - A, zero, padding, mask = f.arguments - vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr, - padding, mask=mask) - vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr, - padding) - func.ReturnOp([]) - - # CHECK: @transfer_read(%[[MEM:.*]]: memref, %[[IDX:.*]]: index, - # CHECK: %[[PAD:.*]]: f32, %[[MASK:.*]]: vector<2x3xi1>) - # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]], %[[MASK]] - # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]] - # CHECK-NOT: %[[MASK]] - print(module) + module = Module.create() + with InsertionPoint(module.body): + vector_type = VectorType.get([2, 3], F32Type.get()) + memref_type = MemRefType.get( + [ShapedType.get_dynamic_size(), ShapedType.get_dynamic_size()], + F32Type.get(), + ) + index_type = IndexType.get() + mask_type = VectorType.get(vector_type.shape, IntegerType.get_signless(1)) + identity_map = AffineMap.get_identity(vector_type.rank) + identity_map_attr = AffineMapAttr.get(identity_map) + f = func.FuncOp( + "transfer_read", ([memref_type, index_type, F32Type.get(), mask_type], []) + ) + with InsertionPoint(f.add_entry_block()): + A, zero, padding, mask = f.arguments + vector.TransferReadOp( + vector_type, A, [zero, zero], identity_map_attr, padding, mask=mask + ) + vector.TransferReadOp( + vector_type, A, [zero, zero], identity_map_attr, padding + ) + func.ReturnOp([]) + + # CHECK: @transfer_read(%[[MEM:.*]]: memref, %[[IDX:.*]]: index, + # CHECK: %[[PAD:.*]]: f32, %[[MASK:.*]]: vector<2x3xi1>) + # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]], %[[MASK]] + # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]] + # CHECK-NOT: %[[MASK]] + print(module) diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py index 973810d..50d6e82 100644 --- a/mlir/test/python/execution_engine.py +++ b/mlir/test/python/execution_engine.py @@ -10,34 +10,36 @@ 
from mlir.runtime import * # Log everything to stderr and flush so that we have a unified stream to match # errors/info emitted by MLIR to stderr. def log(*args): - print(*args, file=sys.stderr) - sys.stderr.flush() + print(*args, file=sys.stderr) + sys.stderr.flush() def run(f): - log("\nTEST:", f.__name__) - f() - gc.collect() - assert Context._get_live_count() == 0 + log("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 # Verify capsule interop. # CHECK-LABEL: TEST: testCapsule def testCapsule(): - with Context(): - module = Module.parse(r""" + with Context(): + module = Module.parse( + r""" llvm.func @none() { llvm.return } - """) - execution_engine = ExecutionEngine(module) - execution_engine_capsule = execution_engine._CAPIPtr - # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr - log(repr(execution_engine_capsule)) - execution_engine._testing_release() - execution_engine1 = ExecutionEngine._CAPICreate(execution_engine_capsule) - # CHECK: _mlirExecutionEngine.ExecutionEngine - log(repr(execution_engine1)) + """ + ) + execution_engine = ExecutionEngine(module) + execution_engine_capsule = execution_engine._CAPIPtr + # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr + log(repr(execution_engine_capsule)) + execution_engine._testing_release() + execution_engine1 = ExecutionEngine._CAPICreate(execution_engine_capsule) + # CHECK: _mlirExecutionEngine.ExecutionEngine + log(repr(execution_engine1)) run(testCapsule) @@ -46,40 +48,45 @@ run(testCapsule) # Test invalid ExecutionEngine creation # CHECK-LABEL: TEST: testInvalidModule def testInvalidModule(): - with Context(): - # Builtin function - module = Module.parse(r""" + with Context(): + # Builtin function + module = Module.parse( + r""" func.func @foo() { return } - """) - # CHECK: Got RuntimeError: Failure while creating the ExecutionEngine. - try: - execution_engine = ExecutionEngine(module) - except RuntimeError as e: - log("Got RuntimeError: ", e) + """ + ) + # CHECK: Got RuntimeError: Failure while creating the ExecutionEngine. + try: + execution_engine = ExecutionEngine(module) + except RuntimeError as e: + log("Got RuntimeError: ", e) run(testInvalidModule) def lowerToLLVM(module): - pm = PassManager.parse( - "builtin.module(convert-complex-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts)") - pm.run(module.operation) - return module + pm = PassManager.parse( + "builtin.module(convert-complex-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts)" + ) + pm.run(module.operation) + return module # Test simple ExecutionEngine execution # CHECK-LABEL: TEST: testInvokeVoid def testInvokeVoid(): - with Context(): - module = Module.parse(r""" + with Context(): + module = Module.parse( + r""" func.func @void() attributes { llvm.emit_c_interface } { return } - """) - execution_engine = ExecutionEngine(lowerToLLVM(module)) - # Nothing to check other than no exception thrown here. - execution_engine.invoke("void") + """ + ) + execution_engine = ExecutionEngine(lowerToLLVM(module)) + # Nothing to check other than no exception thrown here. + execution_engine.invoke("void") run(testInvokeVoid) @@ -88,23 +95,25 @@ run(testInvokeVoid) # Test argument passing and result with a simple float addition. 
# CHECK-LABEL: TEST: testInvokeFloatAdd def testInvokeFloatAdd(): - with Context(): - module = Module.parse(r""" + with Context(): + module = Module.parse( + r""" func.func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } { %add = arith.addf %arg0, %arg1 : f32 return %add : f32 } - """) - execution_engine = ExecutionEngine(lowerToLLVM(module)) - # Prepare arguments: two input floats and one result. - # Arguments must be passed as pointers. - c_float_p = ctypes.c_float * 1 - arg0 = c_float_p(42.) - arg1 = c_float_p(2.) - res = c_float_p(-1.) - execution_engine.invoke("add", arg0, arg1, res) - # CHECK: 42.0 + 2.0 = 44.0 - log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0])) + """ + ) + execution_engine = ExecutionEngine(lowerToLLVM(module)) + # Prepare arguments: two input floats and one result. + # Arguments must be passed as pointers. + c_float_p = ctypes.c_float * 1 + arg0 = c_float_p(42.0) + arg1 = c_float_p(2.0) + res = c_float_p(-1.0) + execution_engine.invoke("add", arg0, arg1, res) + # CHECK: 42.0 + 2.0 = 44.0 + log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0])) run(testInvokeFloatAdd) @@ -113,33 +122,35 @@ run(testInvokeFloatAdd) # Test callback # CHECK-LABEL: TEST: testBasicCallback def testBasicCallback(): - # Define a callback function that takes a float and an integer and returns a float. - @ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float, ctypes.c_int) - def callback(a, b): - return a / 2 + b / 2 - - with Context(): - # The module just forwards to a runtime function known as "some_callback_into_python". - module = Module.parse(r""" + # Define a callback function that takes a float and an integer and returns a float. + @ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float, ctypes.c_int) + def callback(a, b): + return a / 2 + b / 2 + + with Context(): + # The module just forwards to a runtime function known as "some_callback_into_python". + module = Module.parse( + r""" func.func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } { %resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32) return %resf : f32 } func.func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface } - """) - execution_engine = ExecutionEngine(lowerToLLVM(module)) - execution_engine.register_runtime("some_callback_into_python", callback) - - # Prepare arguments: two input floats and one result. - # Arguments must be passed as pointers. - c_float_p = ctypes.c_float * 1 - c_int_p = ctypes.c_int * 1 - arg0 = c_float_p(42.) - arg1 = c_int_p(2) - res = c_float_p(-1.) - execution_engine.invoke("add", arg0, arg1, res) - # CHECK: 42.0 + 2 = 44.0 - log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0] * 2)) + """ + ) + execution_engine = ExecutionEngine(lowerToLLVM(module)) + execution_engine.register_runtime("some_callback_into_python", callback) + + # Prepare arguments: two input floats and one result. + # Arguments must be passed as pointers. + c_float_p = ctypes.c_float * 1 + c_int_p = ctypes.c_int * 1 + arg0 = c_float_p(42.0) + arg1 = c_int_p(2) + res = c_float_p(-1.0) + execution_engine.invoke("add", arg0, arg1, res) + # CHECK: 42.0 + 2 = 44.0 + log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0] * 2)) run(testBasicCallback) @@ -148,44 +159,46 @@ run(testBasicCallback) # Test callback with an unranked memref # CHECK-LABEL: TEST: testUnrankedMemRefCallback def testUnrankedMemRefCallback(): - # Define a callback function that takes an unranked memref, converts it to a numpy array and prints it. 
- @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor)) - def callback(a): - arr = unranked_memref_to_numpy(a, np.float32) - log("Inside callback: ") - log(arr) - - with Context(): - # The module just forwards to a runtime function known as "some_callback_into_python". - module = Module.parse(r""" + # Define a callback function that takes an unranked memref, converts it to a numpy array and prints it. + @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor)) + def callback(a): + arr = unranked_memref_to_numpy(a, np.float32) + log("Inside callback: ") + log(arr) + + with Context(): + # The module just forwards to a runtime function known as "some_callback_into_python". + module = Module.parse( + r""" func.func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } { call @some_callback_into_python(%arg0) : (memref<*xf32>) -> () return } func.func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface } -""") - execution_engine = ExecutionEngine(lowerToLLVM(module)) - execution_engine.register_runtime("some_callback_into_python", callback) - inp_arr = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32) - # CHECK: Inside callback: - # CHECK{LITERAL}: [[1. 2.] - # CHECK{LITERAL}: [3. 4.]] - execution_engine.invoke( - "callback_memref", - ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(inp_arr))), - ) - inp_arr_1 = np.array([5, 6, 7], dtype=np.float32) - strided_arr = np.lib.stride_tricks.as_strided( - inp_arr_1, strides=(4, 0), shape=(3, 4)) - # CHECK: Inside callback: - # CHECK{LITERAL}: [[5. 5. 5. 5.] - # CHECK{LITERAL}: [6. 6. 6. 6.] - # CHECK{LITERAL}: [7. 7. 7. 7.]] - execution_engine.invoke( - "callback_memref", - ctypes.pointer( - ctypes.pointer(get_unranked_memref_descriptor(strided_arr))), - ) +""" + ) + execution_engine = ExecutionEngine(lowerToLLVM(module)) + execution_engine.register_runtime("some_callback_into_python", callback) + inp_arr = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32) + # CHECK: Inside callback: + # CHECK{LITERAL}: [[1. 2.] + # CHECK{LITERAL}: [3. 4.]] + execution_engine.invoke( + "callback_memref", + ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(inp_arr))), + ) + inp_arr_1 = np.array([5, 6, 7], dtype=np.float32) + strided_arr = np.lib.stride_tricks.as_strided( + inp_arr_1, strides=(4, 0), shape=(3, 4) + ) + # CHECK: Inside callback: + # CHECK{LITERAL}: [[5. 5. 5. 5.] + # CHECK{LITERAL}: [6. 6. 6. 6.] + # CHECK{LITERAL}: [7. 7. 7. 7.]] + execution_engine.invoke( + "callback_memref", + ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(strided_arr))), + ) run(testUnrankedMemRefCallback) @@ -194,36 +207,39 @@ run(testUnrankedMemRefCallback) # Test callback with a ranked memref. # CHECK-LABEL: TEST: testRankedMemRefCallback def testRankedMemRefCallback(): - # Define a callback function that takes a ranked memref, converts it to a numpy array and prints it. - @ctypes.CFUNCTYPE( - None, - ctypes.POINTER( - make_nd_memref_descriptor(2, - np.ctypeslib.as_ctypes_type(np.float32))), - ) - def callback(a): - arr = ranked_memref_to_numpy(a) - log("Inside Callback: ") - log(arr) - - with Context(): - # The module just forwards to a runtime function known as "some_callback_into_python". - module = Module.parse(r""" + # Define a callback function that takes a ranked memref, converts it to a numpy array and prints it. 
+ @ctypes.CFUNCTYPE( + None, + ctypes.POINTER( + make_nd_memref_descriptor(2, np.ctypeslib.as_ctypes_type(np.float32)) + ), + ) + def callback(a): + arr = ranked_memref_to_numpy(a) + log("Inside Callback: ") + log(arr) + + with Context(): + # The module just forwards to a runtime function known as "some_callback_into_python". + module = Module.parse( + r""" func.func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } { call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> () return } func.func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface } -""") - execution_engine = ExecutionEngine(lowerToLLVM(module)) - execution_engine.register_runtime("some_callback_into_python", callback) - inp_arr = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32) - # CHECK: Inside Callback: - # CHECK{LITERAL}: [[1. 5.] - # CHECK{LITERAL}: [6. 7.]] - execution_engine.invoke( - "callback_memref", - ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr)))) +""" + ) + execution_engine = ExecutionEngine(lowerToLLVM(module)) + execution_engine.register_runtime("some_callback_into_python", callback) + inp_arr = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32) + # CHECK: Inside Callback: + # CHECK{LITERAL}: [[1. 5.] + # CHECK{LITERAL}: [6. 7.]] + execution_engine.invoke( + "callback_memref", + ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr))), + ) run(testRankedMemRefCallback) @@ -232,8 +248,9 @@ run(testRankedMemRefCallback) # Test addition of two memrefs. # CHECK-LABEL: TEST: testMemrefAdd def testMemrefAdd(): - with Context(): - module = Module.parse(""" + with Context(): + module = Module.parse( + """ module { func.func @main(%arg0: memref<1xf32>, %arg1: memref, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } { %0 = arith.constant 0 : index @@ -243,23 +260,28 @@ def testMemrefAdd(): memref.store %3, %arg2[%0] : memref<1xf32> return } - } """) - arg1 = np.array([32.5]).astype(np.float32) - arg2 = np.array(6).astype(np.float32) - res = np.array([0]).astype(np.float32) - - arg1_memref_ptr = ctypes.pointer( - ctypes.pointer(get_ranked_memref_descriptor(arg1))) - arg2_memref_ptr = ctypes.pointer( - ctypes.pointer(get_ranked_memref_descriptor(arg2))) - res_memref_ptr = ctypes.pointer( - ctypes.pointer(get_ranked_memref_descriptor(res))) - - execution_engine = ExecutionEngine(lowerToLLVM(module)) - execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr, - res_memref_ptr) - # CHECK: [32.5] + 6.0 = [38.5] - log("{0} + {1} = {2}".format(arg1, arg2, res)) + } """ + ) + arg1 = np.array([32.5]).astype(np.float32) + arg2 = np.array(6).astype(np.float32) + res = np.array([0]).astype(np.float32) + + arg1_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(arg1)) + ) + arg2_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(arg2)) + ) + res_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(res)) + ) + + execution_engine = ExecutionEngine(lowerToLLVM(module)) + execution_engine.invoke( + "main", arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr + ) + # CHECK: [32.5] + 6.0 = [38.5] + log("{0} + {1} = {2}".format(arg1, arg2, res)) run(testMemrefAdd) @@ -268,8 +290,9 @@ run(testMemrefAdd) # Test addition of two f16 memrefs # CHECK-LABEL: TEST: testF16MemrefAdd def testF16MemrefAdd(): - with Context(): - module = Module.parse(""" + with Context(): + module = Module.parse( + """ module { func.func @main(%arg0: memref<1xf16>, %arg1: 
memref<1xf16>, @@ -281,29 +304,34 @@ def testF16MemrefAdd(): memref.store %3, %arg2[%0] : memref<1xf16> return } - } """) - - arg1 = np.array([11.]).astype(np.float16) - arg2 = np.array([22.]).astype(np.float16) - arg3 = np.array([0.]).astype(np.float16) - - arg1_memref_ptr = ctypes.pointer( - ctypes.pointer(get_ranked_memref_descriptor(arg1))) - arg2_memref_ptr = ctypes.pointer( - ctypes.pointer(get_ranked_memref_descriptor(arg2))) - arg3_memref_ptr = ctypes.pointer( - ctypes.pointer(get_ranked_memref_descriptor(arg3))) - - execution_engine = ExecutionEngine(lowerToLLVM(module)) - execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr, - arg3_memref_ptr) - # CHECK: [11.] + [22.] = [33.] - log("{0} + {1} = {2}".format(arg1, arg2, arg3)) - - # test to-numpy utility - # CHECK: [33.] - npout = ranked_memref_to_numpy(arg3_memref_ptr[0]) - log(npout) + } """ + ) + + arg1 = np.array([11.0]).astype(np.float16) + arg2 = np.array([22.0]).astype(np.float16) + arg3 = np.array([0.0]).astype(np.float16) + + arg1_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(arg1)) + ) + arg2_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(arg2)) + ) + arg3_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(arg3)) + ) + + execution_engine = ExecutionEngine(lowerToLLVM(module)) + execution_engine.invoke( + "main", arg1_memref_ptr, arg2_memref_ptr, arg3_memref_ptr + ) + # CHECK: [11.] + [22.] = [33.] + log("{0} + {1} = {2}".format(arg1, arg2, arg3)) + + # test to-numpy utility + # CHECK: [33.] + npout = ranked_memref_to_numpy(arg3_memref_ptr[0]) + log(npout) run(testF16MemrefAdd) @@ -312,8 +340,9 @@ run(testF16MemrefAdd) # Test addition of two complex memrefs # CHECK-LABEL: TEST: testComplexMemrefAdd def testComplexMemrefAdd(): - with Context(): - module = Module.parse(""" + with Context(): + module = Module.parse( + """ module { func.func @main(%arg0: memref<1xcomplex>, %arg1: memref<1xcomplex>, @@ -325,31 +354,34 @@ def testComplexMemrefAdd(): memref.store %3, %arg2[%0] : memref<1xcomplex> return } - } """) - - arg1 = np.array([1.+2.j]).astype(np.complex128) - arg2 = np.array([3.+4.j]).astype(np.complex128) - arg3 = np.array([0.+0.j]).astype(np.complex128) - - arg1_memref_ptr = ctypes.pointer( - ctypes.pointer(get_ranked_memref_descriptor(arg1))) - arg2_memref_ptr = ctypes.pointer( - ctypes.pointer(get_ranked_memref_descriptor(arg2))) - arg3_memref_ptr = ctypes.pointer( - ctypes.pointer(get_ranked_memref_descriptor(arg3))) - - execution_engine = ExecutionEngine(lowerToLLVM(module)) - execution_engine.invoke("main", - arg1_memref_ptr, - arg2_memref_ptr, - arg3_memref_ptr) - # CHECK: [1.+2.j] + [3.+4.j] = [4.+6.j] - log("{0} + {1} = {2}".format(arg1, arg2, arg3)) - - # test to-numpy utility - # CHECK: [4.+6.j] - npout = ranked_memref_to_numpy(arg3_memref_ptr[0]) - log(npout) + } """ + ) + + arg1 = np.array([1.0 + 2.0j]).astype(np.complex128) + arg2 = np.array([3.0 + 4.0j]).astype(np.complex128) + arg3 = np.array([0.0 + 0.0j]).astype(np.complex128) + + arg1_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(arg1)) + ) + arg2_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(arg2)) + ) + arg3_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(arg3)) + ) + + execution_engine = ExecutionEngine(lowerToLLVM(module)) + execution_engine.invoke( + "main", arg1_memref_ptr, arg2_memref_ptr, arg3_memref_ptr + ) + # CHECK: [1.+2.j] + [3.+4.j] = [4.+6.j] + log("{0} + {1} = 
{2}".format(arg1, arg2, arg3)) + + # test to-numpy utility + # CHECK: [4.+6.j] + npout = ranked_memref_to_numpy(arg3_memref_ptr[0]) + log(npout) run(testComplexMemrefAdd) @@ -358,8 +390,9 @@ run(testComplexMemrefAdd) # Test addition of two complex unranked memrefs # CHECK-LABEL: TEST: testComplexUnrankedMemrefAdd def testComplexUnrankedMemrefAdd(): - with Context(): - module = Module.parse(""" + with Context(): + module = Module.parse( + """ module { func.func @main(%arg0: memref<*xcomplex<f32>>, %arg1: memref<*xcomplex<f32>>, @@ -374,32 +407,34 @@ def testComplexUnrankedMemrefAdd(): memref.store %3, %C[%0] : memref<1xcomplex<f32>> return } - } """) - - arg1 = np.array([5.+6.j]).astype(np.complex64) - arg2 = np.array([7.+8.j]).astype(np.complex64) - arg3 = np.array([0.+0.j]).astype(np.complex64) - - arg1_memref_ptr = ctypes.pointer( - ctypes.pointer(get_unranked_memref_descriptor(arg1))) - arg2_memref_ptr = ctypes.pointer( - ctypes.pointer(get_unranked_memref_descriptor(arg2))) - arg3_memref_ptr = ctypes.pointer( - ctypes.pointer(get_unranked_memref_descriptor(arg3))) - - execution_engine = ExecutionEngine(lowerToLLVM(module)) - execution_engine.invoke("main", - arg1_memref_ptr, - arg2_memref_ptr, - arg3_memref_ptr) - # CHECK: [5.+6.j] + [7.+8.j] = [12.+14.j] - log("{0} + {1} = {2}".format(arg1, arg2, arg3)) - - # test to-numpy utility - # CHECK: [12.+14.j] - npout = unranked_memref_to_numpy(arg3_memref_ptr[0], - np.dtype(np.complex64)) - log(npout) + } """ + ) + + arg1 = np.array([5.0 + 6.0j]).astype(np.complex64) + arg2 = np.array([7.0 + 8.0j]).astype(np.complex64) + arg3 = np.array([0.0 + 0.0j]).astype(np.complex64) + + arg1_memref_ptr = ctypes.pointer( + ctypes.pointer(get_unranked_memref_descriptor(arg1)) + ) + arg2_memref_ptr = ctypes.pointer( + ctypes.pointer(get_unranked_memref_descriptor(arg2)) + ) + arg3_memref_ptr = ctypes.pointer( + ctypes.pointer(get_unranked_memref_descriptor(arg3)) + ) + + execution_engine = ExecutionEngine(lowerToLLVM(module)) + execution_engine.invoke( + "main", arg1_memref_ptr, arg2_memref_ptr, arg3_memref_ptr + ) + # CHECK: [5.+6.j] + [7.+8.j] = [12.+14.j] + log("{0} + {1} = {2}".format(arg1, arg2, arg3)) + + # test to-numpy utility + # CHECK: [12.+14.j] + npout = unranked_memref_to_numpy(arg3_memref_ptr[0], np.dtype(np.complex64)) + log(npout) run(testComplexUnrankedMemrefAdd) @@ -408,8 +443,9 @@ run(testComplexUnrankedMemrefAdd) # Test addition of two 2d_memref # CHECK-LABEL: TEST: testDynamicMemrefAdd2D def testDynamicMemrefAdd2D(): - with Context(): - module = Module.parse(""" + with Context(): + module = Module.parse( + """ module { func.func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index @@ -441,23 +477,28 @@ def testDynamicMemrefAdd2D(): return } } - """) - arg1 = np.random.randn(2, 2).astype(np.float32) - arg2 = np.random.randn(2, 2).astype(np.float32) - res = np.random.randn(2, 2).astype(np.float32) - - arg1_memref_ptr = ctypes.pointer( - ctypes.pointer(get_ranked_memref_descriptor(arg1))) - arg2_memref_ptr = ctypes.pointer( - ctypes.pointer(get_ranked_memref_descriptor(arg2))) - res_memref_ptr = ctypes.pointer( - ctypes.pointer(get_ranked_memref_descriptor(res))) - - execution_engine = ExecutionEngine(lowerToLLVM(module)) - execution_engine.invoke("memref_add_2d", arg1_memref_ptr, arg2_memref_ptr, - res_memref_ptr) - # CHECK: True - log(np.allclose(arg1 + arg2, res)) + """ + ) + arg1 = np.random.randn(2, 2).astype(np.float32) + arg2 = np.random.randn(2, 
2).astype(np.float32) + res = np.random.randn(2, 2).astype(np.float32) + + arg1_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(arg1)) + ) + arg2_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(arg2)) + ) + res_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(res)) + ) + + execution_engine = ExecutionEngine(lowerToLLVM(module)) + execution_engine.invoke( + "memref_add_2d", arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr + ) + # CHECK: True + log(np.allclose(arg1 + arg2, res)) run(testDynamicMemrefAdd2D) @@ -466,8 +507,9 @@ run(testDynamicMemrefAdd2D) # Test loading of shared libraries. # CHECK-LABEL: TEST: testSharedLibLoad def testSharedLibLoad(): - with Context(): - module = Module.parse(""" + with Context(): + module = Module.parse( + """ module { func.func @main(%arg0: memref<1xf32>) attributes { llvm.emit_c_interface } { %c0 = arith.constant 0 : index @@ -478,35 +520,36 @@ def testSharedLibLoad(): return } func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface } - } """) - arg0 = np.array([0.0]).astype(np.float32) - - arg0_memref_ptr = ctypes.pointer( - ctypes.pointer(get_ranked_memref_descriptor(arg0))) - - if sys.platform == 'win32': - shared_libs = [ - "../../../../bin/mlir_runner_utils.dll", - "../../../../bin/mlir_c_runner_utils.dll" - ] - elif sys.platform == 'darwin': - shared_libs = [ - "../../../../lib/libmlir_runner_utils.dylib", - "../../../../lib/libmlir_c_runner_utils.dylib" - ] - else: - shared_libs = [ - "../../../../lib/libmlir_runner_utils.so", - "../../../../lib/libmlir_c_runner_utils.so" - ] - - execution_engine = ExecutionEngine( - lowerToLLVM(module), - opt_level=3, - shared_libs=shared_libs) - execution_engine.invoke("main", arg0_memref_ptr) - # CHECK: Unranked Memref - # CHECK-NEXT: [42] + } """ + ) + arg0 = np.array([0.0]).astype(np.float32) + + arg0_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(arg0)) + ) + + if sys.platform == "win32": + shared_libs = [ + "../../../../bin/mlir_runner_utils.dll", + "../../../../bin/mlir_c_runner_utils.dll", + ] + elif sys.platform == "darwin": + shared_libs = [ + "../../../../lib/libmlir_runner_utils.dylib", + "../../../../lib/libmlir_c_runner_utils.dylib", + ] + else: + shared_libs = [ + "../../../../lib/libmlir_runner_utils.so", + "../../../../lib/libmlir_c_runner_utils.so", + ] + + execution_engine = ExecutionEngine( + lowerToLLVM(module), opt_level=3, shared_libs=shared_libs + ) + execution_engine.invoke("main", arg0_memref_ptr) + # CHECK: Unranked Memref + # CHECK-NEXT: [42] run(testSharedLibLoad) @@ -515,8 +558,9 @@ run(testSharedLibLoad) # Test that nano time clock is available. 
# CHECK-LABEL: TEST: testNanoTime def testNanoTime(): - with Context(): - module = Module.parse(""" + with Context(): + module = Module.parse( + """ module { func.func @main() attributes { llvm.emit_c_interface } { %now = call @nanoTime() : () -> i64 @@ -529,26 +573,26 @@ def testNanoTime(): } func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface } func.func private @printMemrefI64(memref<*xi64>) attributes { llvm.emit_c_interface } - }""") - - if sys.platform == 'win32': - shared_libs = [ - "../../../../bin/mlir_runner_utils.dll", - "../../../../bin/mlir_c_runner_utils.dll" - ] - else: - shared_libs = [ - "../../../../lib/libmlir_runner_utils.so", - "../../../../lib/libmlir_c_runner_utils.so" - ] - - execution_engine = ExecutionEngine( - lowerToLLVM(module), - opt_level=3, - shared_libs=shared_libs) - execution_engine.invoke("main") - # CHECK: Unranked Memref - # CHECK: [{{.*}}] + }""" + ) + + if sys.platform == "win32": + shared_libs = [ + "../../../../bin/mlir_runner_utils.dll", + "../../../../bin/mlir_c_runner_utils.dll", + ] + else: + shared_libs = [ + "../../../../lib/libmlir_runner_utils.so", + "../../../../lib/libmlir_c_runner_utils.so", + ] + + execution_engine = ExecutionEngine( + lowerToLLVM(module), opt_level=3, shared_libs=shared_libs + ) + execution_engine.invoke("main") + # CHECK: Unranked Memref + # CHECK: [{{.*}}] run(testNanoTime) @@ -557,36 +601,36 @@ run(testNanoTime) # Test that nano time clock is available. # CHECK-LABEL: TEST: testDumpToObjectFile def testDumpToObjectFile(): - fd, object_path = tempfile.mkstemp(suffix=".o") + fd, object_path = tempfile.mkstemp(suffix=".o") - try: - with Context(): - module = Module.parse(""" + try: + with Context(): + module = Module.parse( + """ module { func.func @main() attributes { llvm.emit_c_interface } { return } - }""") + }""" + ) - execution_engine = ExecutionEngine( - lowerToLLVM(module), - opt_level=3) + execution_engine = ExecutionEngine(lowerToLLVM(module), opt_level=3) - # CHECK: Object file exists: True - print(f"Object file exists: {os.path.exists(object_path)}") - # CHECK: Object file is empty: True - print(f"Object file is empty: {os.path.getsize(object_path) == 0}") + # CHECK: Object file exists: True + print(f"Object file exists: {os.path.exists(object_path)}") + # CHECK: Object file is empty: True + print(f"Object file is empty: {os.path.getsize(object_path) == 0}") - execution_engine.dump_to_object_file(object_path) + execution_engine.dump_to_object_file(object_path) - # CHECK: Object file exists: True - print(f"Object file exists: {os.path.exists(object_path)}") - # CHECK: Object file is empty: False - print(f"Object file is empty: {os.path.getsize(object_path) == 0}") + # CHECK: Object file exists: True + print(f"Object file exists: {os.path.exists(object_path)}") + # CHECK: Object file is empty: False + print(f"Object file is empty: {os.path.getsize(object_path) == 0}") - finally: - os.close(fd) - os.remove(object_path) + finally: + os.close(fd) + os.remove(object_path) run(testDumpToObjectFile) diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py index 2cba577..f6519fb 100644 --- a/mlir/test/python/integration/dialects/linalg/opsrun.py +++ b/mlir/test/python/integration/dialects/linalg/opsrun.py @@ -15,8 +15,8 @@ from mlir.dialects.linalg.opdsl.lang import * # Log everything to stderr and flush so that we have a unified stream to match # errors/info emitted by MLIR to stderr. 
def log(*args): - print(*args, file=sys.stderr) - sys.stderr.flush() + print(*args, file=sys.stderr) + sys.stderr.flush() elemwise_boiler = """ @@ -186,428 +186,458 @@ func.func @main() -> i32 attributes {llvm.emit_c_interface} { def transform(module, boilerplate): - # TODO: Allow cloning functions from one module to another. - # Atm we have to resort to string concatenation. - ops = module.operation.regions[0].blocks[0].operations - mod = Module.parse("\n".join([str(op) for op in ops]) + boilerplate) - - pm = PassManager('builtin.module') - pm.add("func.func(convert-linalg-to-loops)") - pm.add("func.func(lower-affine)") - pm.add("func.func(convert-math-to-llvm)") - pm.add("func.func(convert-scf-to-cf)") - pm.add("func.func(arith-expand)") - pm.add("func.func(memref-expand)") - pm.add("convert-vector-to-llvm") - pm.add("finalize-memref-to-llvm") - pm.add("convert-func-to-llvm") - pm.add("reconcile-unrealized-casts") - pm.run(mod.operation) - return mod + # TODO: Allow cloning functions from one module to another. + # Atm we have to resort to string concatenation. + ops = module.operation.regions[0].blocks[0].operations + mod = Module.parse("\n".join([str(op) for op in ops]) + boilerplate) + + pm = PassManager("builtin.module") + pm.add("func.func(convert-linalg-to-loops)") + pm.add("func.func(lower-affine)") + pm.add("func.func(convert-math-to-llvm)") + pm.add("func.func(convert-scf-to-cf)") + pm.add("func.func(arith-expand)") + pm.add("func.func(memref-expand)") + pm.add("convert-vector-to-llvm") + pm.add("finalize-memref-to-llvm") + pm.add("convert-func-to-llvm") + pm.add("reconcile-unrealized-casts") + pm.run(mod.operation) + return mod def test_elemwise_builtin(): - with Context() as ctx, Location.unknown(): - module = Module.create() - f32 = F32Type.get() - i8 = IntegerType.get_signless(8) - with InsertionPoint(module.body): - - @func.FuncOp.from_py_func( - MemRefType.get((), f32), MemRefType.get((4, 8), f32), - MemRefType.get((4, 8), f32)) - def elemwise_exp_add_on_buffers(lhs, rhs, out): - linalg.elemwise_unary(lhs, outs=[out]) - linalg.elemwise_binary(out, rhs, outs=[out]) - - @func.FuncOp.from_py_func( - MemRefType.get((), f32), MemRefType.get((4, 8), f32), - MemRefType.get((4, 8), f32)) - def elemwise_log_mul_on_buffers(lhs, rhs, out): - linalg.elemwise_unary(lhs, outs=[out], fun=UnaryFn.log) - linalg.elemwise_binary(out, rhs, outs=[out], fun=BinaryFn.mul) - - execution_engine = ExecutionEngine(transform(module, elemwise_boiler)) - - # TODO: FFI-based solution to allow testing and printing with python code. - # Prepare arguments: one result f32. - # Arguments must be passed as pointers. - c_float_p = ctypes.c_float * 1 - res = c_float_p(-1.) 
- execution_engine.invoke("main", res) - - log("RESULT: ", res[0]) - # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846 - # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0 - # CHECK: RESULT: 4.71828 + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + i8 = IntegerType.get_signless(8) + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func( + MemRefType.get((), f32), + MemRefType.get((4, 8), f32), + MemRefType.get((4, 8), f32), + ) + def elemwise_exp_add_on_buffers(lhs, rhs, out): + linalg.elemwise_unary(lhs, outs=[out]) + linalg.elemwise_binary(out, rhs, outs=[out]) + + @func.FuncOp.from_py_func( + MemRefType.get((), f32), + MemRefType.get((4, 8), f32), + MemRefType.get((4, 8), f32), + ) + def elemwise_log_mul_on_buffers(lhs, rhs, out): + linalg.elemwise_unary(lhs, outs=[out], fun=UnaryFn.log) + linalg.elemwise_binary(out, rhs, outs=[out], fun=BinaryFn.mul) + + execution_engine = ExecutionEngine(transform(module, elemwise_boiler)) + + # TODO: FFI-based solution to allow testing and printing with python code. + # Prepare arguments: one result f32. + # Arguments must be passed as pointers. + c_float_p = ctypes.c_float * 1 + res = c_float_p(-1.0) + execution_engine.invoke("main", res) + + log("RESULT: ", res[0]) + # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846 + # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0 + # CHECK: RESULT: 4.71828 test_elemwise_builtin() def test_elemwise_generic(): - with Context() as ctx, Location.unknown(): - module = Module.create() - f32 = F32Type.get() - i8 = IntegerType.get_signless(8) - with InsertionPoint(module.body): - - @func.FuncOp.from_py_func( - MemRefType.get((), f32), MemRefType.get((4, 8), f32), - MemRefType.get((4, 8), f32)) - def elemwise_exp_add_on_buffers(lhs, rhs, out): - linalg.elemwise_unary(lhs, outs=[out], emit_generic=True) - linalg.elemwise_binary(out, rhs, outs=[out], emit_generic=True) - - @func.FuncOp.from_py_func( - MemRefType.get((), f32), MemRefType.get((4, 8), f32), - MemRefType.get((4, 8), f32)) - def elemwise_log_mul_on_buffers(lhs, rhs, out): - linalg.elemwise_unary( - lhs, outs=[out], fun=UnaryFn.log, emit_generic=True) - linalg.elemwise_binary( - out, rhs, outs=[out], fun=BinaryFn.mul, emit_generic=True) - - execution_engine = ExecutionEngine(transform(module, elemwise_boiler)) - - # TODO: FFI-based solution to allow testing and printing with python code. - # Prepare arguments: one result f32. - # Arguments must be passed as pointers. - c_float_p = ctypes.c_float * 1 - res = c_float_p(-1.) 
- execution_engine.invoke("main", res) - - log("RESULT: ", res[0]) - # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846 - # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0 - # CHECK: RESULT: 4.71828 + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + i8 = IntegerType.get_signless(8) + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func( + MemRefType.get((), f32), + MemRefType.get((4, 8), f32), + MemRefType.get((4, 8), f32), + ) + def elemwise_exp_add_on_buffers(lhs, rhs, out): + linalg.elemwise_unary(lhs, outs=[out], emit_generic=True) + linalg.elemwise_binary(out, rhs, outs=[out], emit_generic=True) + + @func.FuncOp.from_py_func( + MemRefType.get((), f32), + MemRefType.get((4, 8), f32), + MemRefType.get((4, 8), f32), + ) + def elemwise_log_mul_on_buffers(lhs, rhs, out): + linalg.elemwise_unary( + lhs, outs=[out], fun=UnaryFn.log, emit_generic=True + ) + linalg.elemwise_binary( + out, rhs, outs=[out], fun=BinaryFn.mul, emit_generic=True + ) + + execution_engine = ExecutionEngine(transform(module, elemwise_boiler)) + + # TODO: FFI-based solution to allow testing and printing with python code. + # Prepare arguments: one result f32. + # Arguments must be passed as pointers. + c_float_p = ctypes.c_float * 1 + res = c_float_p(-1.0) + execution_engine.invoke("main", res) + + log("RESULT: ", res[0]) + # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846 + # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0 + # CHECK: RESULT: 4.71828 test_elemwise_generic() def test_matmul_builtin(): - with Context() as ctx, Location.unknown(): - module = Module.create() - f32 = F32Type.get() - i8 = IntegerType.get_signless(8) - with InsertionPoint(module.body): - - @func.FuncOp.from_py_func( - MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32), - MemRefType.get((4, 8), f32)) - def matmul_signed_on_buffers(lhs, rhs, out): - linalg.matmul(lhs, rhs, outs=[out]) - - @func.FuncOp.from_py_func( - MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32), - MemRefType.get((4, 8), f32)) - def matmul_unsigned_on_buffers(lhs, rhs, out): - linalg.matmul(lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned) - - execution_engine = ExecutionEngine(transform(module, matmul_boiler)) - - # TODO: FFI-based solution to allow testing and printing with python code. - # Prepare arguments: one result f32. - # Arguments must be passed as pointers. - c_float_p = ctypes.c_float * 1 - res = c_float_p(-1.) - execution_engine.invoke("main", res) - - log("RESULT: ", res[0]) - # matmul_signed_on_buffers: -1 * 2.0 * 16 = -32 - # matmul_unsigned_on_buffers: (2^8-1) * 2.0 * 16 = 8160 - # CHECK: RESULT: 8128 + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + i8 = IntegerType.get_signless(8) + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func( + MemRefType.get((4, 16), i8), + MemRefType.get((16, 8), f32), + MemRefType.get((4, 8), f32), + ) + def matmul_signed_on_buffers(lhs, rhs, out): + linalg.matmul(lhs, rhs, outs=[out]) + + @func.FuncOp.from_py_func( + MemRefType.get((4, 16), i8), + MemRefType.get((16, 8), f32), + MemRefType.get((4, 8), f32), + ) + def matmul_unsigned_on_buffers(lhs, rhs, out): + linalg.matmul(lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned) + + execution_engine = ExecutionEngine(transform(module, matmul_boiler)) + + # TODO: FFI-based solution to allow testing and printing with python code. + # Prepare arguments: one result f32. + # Arguments must be passed as pointers. 
+ c_float_p = ctypes.c_float * 1 + res = c_float_p(-1.0) + execution_engine.invoke("main", res) + + log("RESULT: ", res[0]) + # matmul_signed_on_buffers: -1 * 2.0 * 16 = -32 + # matmul_unsigned_on_buffers: (2^8-1) * 2.0 * 16 = 8160 + # CHECK: RESULT: 8128 test_matmul_builtin() def test_matmul_generic(): - with Context() as ctx, Location.unknown(): - module = Module.create() - f32 = F32Type.get() - i8 = IntegerType.get_signless(8) - with InsertionPoint(module.body): - - @func.FuncOp.from_py_func( - MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32), - MemRefType.get((4, 8), f32)) - def matmul_signed_on_buffers(lhs, rhs, out): - linalg.matmul(lhs, rhs, outs=[out], emit_generic=True) - - @func.FuncOp.from_py_func( - MemRefType.get((4, 16), i8), MemRefType.get((16, 8), f32), - MemRefType.get((4, 8), f32)) - def matmul_unsigned_on_buffers(lhs, rhs, out): - linalg.matmul( - lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned, emit_generic=True) - - execution_engine = ExecutionEngine(transform(module, matmul_boiler)) - - # TODO: FFI-based solution to allow testing and printing with python code. - # Prepare arguments: one result f32. - # Arguments must be passed as pointers. - c_float_p = ctypes.c_float * 1 - res = c_float_p(-1.) - execution_engine.invoke("main", res) - - log("RESULT: ", res[0]) - # matmul_signed_on_buffers = -1 * 2.0 * 16 = -32 - # matmul_unsigned_on_buffers = (2^8-1) * 2.0 * 16 = 8160 - # CHECK: RESULT: 8128 + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + i8 = IntegerType.get_signless(8) + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func( + MemRefType.get((4, 16), i8), + MemRefType.get((16, 8), f32), + MemRefType.get((4, 8), f32), + ) + def matmul_signed_on_buffers(lhs, rhs, out): + linalg.matmul(lhs, rhs, outs=[out], emit_generic=True) + + @func.FuncOp.from_py_func( + MemRefType.get((4, 16), i8), + MemRefType.get((16, 8), f32), + MemRefType.get((4, 8), f32), + ) + def matmul_unsigned_on_buffers(lhs, rhs, out): + linalg.matmul( + lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned, emit_generic=True + ) + + execution_engine = ExecutionEngine(transform(module, matmul_boiler)) + + # TODO: FFI-based solution to allow testing and printing with python code. + # Prepare arguments: one result f32. + # Arguments must be passed as pointers. 
+ c_float_p = ctypes.c_float * 1 + res = c_float_p(-1.0) + execution_engine.invoke("main", res) + + log("RESULT: ", res[0]) + # matmul_signed_on_buffers = -1 * 2.0 * 16 = -32 + # matmul_unsigned_on_buffers = (2^8-1) * 2.0 * 16 = 8160 + # CHECK: RESULT: 8128 test_matmul_generic() def test_fill_builtin(): - with Context() as ctx, Location.unknown(): - module = Module.create() - f32 = F32Type.get() - i32 = IntegerType.get_signless(32) - with InsertionPoint(module.body): + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + i32 = IntegerType.get_signless(32) + with InsertionPoint(module.body): - @func.FuncOp.from_py_func(f32, MemRefType.get([], i32)) - def fill_0d_on_buffers(value, out): - linalg.fill(value, outs=[out]) + @func.FuncOp.from_py_func(f32, MemRefType.get([], i32)) + def fill_0d_on_buffers(value, out): + linalg.fill(value, outs=[out]) - @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32)) - def fill_1d_on_buffers(value, out): - linalg.fill(value, outs=[out]) + @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32)) + def fill_1d_on_buffers(value, out): + linalg.fill(value, outs=[out]) - @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32)) - def fill_2d_on_buffers(value, out): - linalg.fill(value, outs=[out]) + @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32)) + def fill_2d_on_buffers(value, out): + linalg.fill(value, outs=[out]) - execution_engine = ExecutionEngine(transform(module, fill_boiler)) + execution_engine = ExecutionEngine(transform(module, fill_boiler)) - # TODO: FFI-based solution to allow testing and printing with python code. - # Prepare arguments: one result i32. - # Arguments must be passed as pointers. - c_int_p = ctypes.c_int * 1 - res = c_int_p(-1) - execution_engine.invoke("main", res) + # TODO: FFI-based solution to allow testing and printing with python code. + # Prepare arguments: one result i32. + # Arguments must be passed as pointers. 
+ c_int_p = ctypes.c_int * 1 + res = c_int_p(-1) + execution_engine.invoke("main", res) - log("RESULT: ", res[0]) - # CHECK: RESULT: 6 + log("RESULT: ", res[0]) + # CHECK: RESULT: 6 test_fill_builtin() def test_fill_generic(): - with Context() as ctx, Location.unknown(): - module = Module.create() - f32 = F32Type.get() - i32 = IntegerType.get_signless(32) - with InsertionPoint(module.body): + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + i32 = IntegerType.get_signless(32) + with InsertionPoint(module.body): - @func.FuncOp.from_py_func(f32, MemRefType.get([], i32)) - def fill_0d_on_buffers(value, out): - linalg.fill(value, outs=[out], emit_generic=True) + @func.FuncOp.from_py_func(f32, MemRefType.get([], i32)) + def fill_0d_on_buffers(value, out): + linalg.fill(value, outs=[out], emit_generic=True) - @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32)) - def fill_1d_on_buffers(value, out): - linalg.fill(value, outs=[out], emit_generic=True) + @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32)) + def fill_1d_on_buffers(value, out): + linalg.fill(value, outs=[out], emit_generic=True) - @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32)) - def fill_2d_on_buffers(value, out): - linalg.fill(value, outs=[out], emit_generic=True) + @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32)) + def fill_2d_on_buffers(value, out): + linalg.fill(value, outs=[out], emit_generic=True) - execution_engine = ExecutionEngine(transform(module, fill_boiler)) + execution_engine = ExecutionEngine(transform(module, fill_boiler)) - # TODO: FFI-based solution to allow testing and printing with python code. - # Prepare arguments: one result i32. - # Arguments must be passed as pointers. - c_int_p = ctypes.c_int * 1 - res = c_int_p(-1) - execution_engine.invoke("main", res) + # TODO: FFI-based solution to allow testing and printing with python code. + # Prepare arguments: one result i32. + # Arguments must be passed as pointers. + c_int_p = ctypes.c_int * 1 + res = c_int_p(-1) + execution_engine.invoke("main", res) - log("RESULT: ", res[0]) - # CHECK: RESULT: 6 + log("RESULT: ", res[0]) + # CHECK: RESULT: 6 test_fill_generic() def test_fill_rng_builtin(): - with Context() as ctx, Location.unknown(): - module = Module.create() - f64 = F64Type.get() - i32 = IntegerType.get_signless(32) - with InsertionPoint(module.body): + with Context() as ctx, Location.unknown(): + module = Module.create() + f64 = F64Type.get() + i32 = IntegerType.get_signless(32) + with InsertionPoint(module.body): - @func.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32)) - def fill_rng_on_buffers(min, max, seed, out): - linalg.fill_rng_2d(min, max, seed, outs=[out]) + @func.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32)) + def fill_rng_on_buffers(min, max, seed, out): + linalg.fill_rng_2d(min, max, seed, outs=[out]) - execution_engine = ExecutionEngine(transform(module, fill_rng_boiler)) + execution_engine = ExecutionEngine(transform(module, fill_rng_boiler)) - # TODO: FFI-based solution to allow testing and printing with python code. - # Prepare arguments: one result i32. - # Arguments must be passed as pointers. - c_int_p = ctypes.c_int * 1 - res = c_int_p(-1) - execution_engine.invoke("main", res) + # TODO: FFI-based solution to allow testing and printing with python code. + # Prepare arguments: one result i32. + # Arguments must be passed as pointers. 
+ c_int_p = ctypes.c_int * 1 + res = c_int_p(-1) + execution_engine.invoke("main", res) - log("RESULT: ", res[0]) - # CHECK: RESULT: -480 + log("RESULT: ", res[0]) + # CHECK: RESULT: -480 test_fill_rng_builtin() def test_fill_rng_generic(): - with Context() as ctx, Location.unknown(): - module = Module.create() - f64 = F64Type.get() - i32 = IntegerType.get_signless(32) - with InsertionPoint(module.body): + with Context() as ctx, Location.unknown(): + module = Module.create() + f64 = F64Type.get() + i32 = IntegerType.get_signless(32) + with InsertionPoint(module.body): - @func.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32)) - def fill_rng_on_buffers(min, max, seed, out): - linalg.fill_rng_2d(min, max, seed, outs=[out], emit_generic=True) + @func.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32)) + def fill_rng_on_buffers(min, max, seed, out): + linalg.fill_rng_2d(min, max, seed, outs=[out], emit_generic=True) - execution_engine = ExecutionEngine(transform(module, fill_rng_boiler)) + execution_engine = ExecutionEngine(transform(module, fill_rng_boiler)) - # TODO: FFI-based solution to allow testing and printing with python code. - # Prepare arguments: one result i32. - # Arguments must be passed as pointers. - c_int_p = ctypes.c_int * 1 - res = c_int_p(-1) - execution_engine.invoke("main", res) + # TODO: FFI-based solution to allow testing and printing with python code. + # Prepare arguments: one result i32. + # Arguments must be passed as pointers. + c_int_p = ctypes.c_int * 1 + res = c_int_p(-1) + execution_engine.invoke("main", res) - log("RESULT: ", res[0]) - # CHECK: RESULT: -480 + log("RESULT: ", res[0]) + # CHECK: RESULT: -480 test_fill_rng_generic() def test_max_pooling_builtin(): - with Context() as ctx, Location.unknown(): - module = Module.create() - f64 = F64Type.get() - i32 = IntegerType.get_signless(32) - with InsertionPoint(module.body): - - @func.FuncOp.from_py_func( - MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64), - MemRefType.get((1, 2, 4, 1), i32)) - def pooling_on_buffers(input, shape, output): - linalg.pooling_nhwc_max( - input, shape, outs=[output], strides=[2, 4], dilations=[1, 2]) - - execution_engine = ExecutionEngine(transform(module, pooling_boiler)) - - # TODO: FFI-based solution to allow testing and printing with python code. - # Prepare arguments: one result i32. - # Arguments must be passed as pointers. - c_int_p = ctypes.c_int * 1 - res = c_int_p(-1) - execution_engine.invoke("main", res) - - log("RESULT: ", res[0]) - # 77 is not selected due to the dilation 2 in the second dimension. - # CHECK: RESULT: 42 + with Context() as ctx, Location.unknown(): + module = Module.create() + f64 = F64Type.get() + i32 = IntegerType.get_signless(32) + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func( + MemRefType.get((1, 4, 16, 1), f64), + MemRefType.get((2, 2), f64), + MemRefType.get((1, 2, 4, 1), i32), + ) + def pooling_on_buffers(input, shape, output): + linalg.pooling_nhwc_max( + input, shape, outs=[output], strides=[2, 4], dilations=[1, 2] + ) + + execution_engine = ExecutionEngine(transform(module, pooling_boiler)) + + # TODO: FFI-based solution to allow testing and printing with python code. + # Prepare arguments: one result i32. + # Arguments must be passed as pointers. + c_int_p = ctypes.c_int * 1 + res = c_int_p(-1) + execution_engine.invoke("main", res) + + log("RESULT: ", res[0]) + # 77 is not selected due to the dilation 2 in the second dimension. 
+ # CHECK: RESULT: 42 test_max_pooling_builtin() def test_max_pooling_generic(): - with Context() as ctx, Location.unknown(): - module = Module.create() - f64 = F64Type.get() - i32 = IntegerType.get_signless(32) - with InsertionPoint(module.body): - - @func.FuncOp.from_py_func( - MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64), - MemRefType.get((1, 2, 4, 1), i32)) - def pooling_on_buffers(input, shape, output): - linalg.pooling_nhwc_max( - input, - shape, - outs=[output], - strides=[2, 4], - dilations=[1, 2], - emit_generic=True) - - execution_engine = ExecutionEngine(transform(module, pooling_boiler)) - - # TODO: FFI-based solution to allow testing and printing with python code. - # Prepare arguments: one result i32. - # Arguments must be passed as pointers. - c_int_p = ctypes.c_int * 1 - res = c_int_p(-1) - execution_engine.invoke("main", res) - - log("RESULT: ", res[0]) - # 77 is not selected due to the dilation 2 in the second dimension. - # CHECK: RESULT: 42 + with Context() as ctx, Location.unknown(): + module = Module.create() + f64 = F64Type.get() + i32 = IntegerType.get_signless(32) + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func( + MemRefType.get((1, 4, 16, 1), f64), + MemRefType.get((2, 2), f64), + MemRefType.get((1, 2, 4, 1), i32), + ) + def pooling_on_buffers(input, shape, output): + linalg.pooling_nhwc_max( + input, + shape, + outs=[output], + strides=[2, 4], + dilations=[1, 2], + emit_generic=True, + ) + + execution_engine = ExecutionEngine(transform(module, pooling_boiler)) + + # TODO: FFI-based solution to allow testing and printing with python code. + # Prepare arguments: one result i32. + # Arguments must be passed as pointers. + c_int_p = ctypes.c_int * 1 + res = c_int_p(-1) + execution_engine.invoke("main", res) + + log("RESULT: ", res[0]) + # 77 is not selected due to the dilation 2 in the second dimension. + # CHECK: RESULT: 42 test_max_pooling_generic() def test_min_pooling_builtin(): - with Context() as ctx, Location.unknown(): - module = Module.create() - f64 = F64Type.get() - i32 = IntegerType.get_signless(32) - with InsertionPoint(module.body): - - @func.FuncOp.from_py_func( - MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64), - MemRefType.get((1, 2, 4, 1), i32)) - # Set the strides and use the default dilations. - def pooling_on_buffers(input, shape, output): - linalg.pooling_nhwc_min(input, shape, outs=[output], strides=[2, 4]) - - execution_engine = ExecutionEngine(transform(module, pooling_boiler)) - - # TODO: FFI-based solution to allow testing and printing with python code. - # Prepare arguments: one result i32. - # Arguments must be passed as pointers. - c_int_p = ctypes.c_int * 1 - res = c_int_p(-1) - execution_engine.invoke("main", res) - - log("RESULT: ", res[0]) - # CHECK: RESULT: -13 + with Context() as ctx, Location.unknown(): + module = Module.create() + f64 = F64Type.get() + i32 = IntegerType.get_signless(32) + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func( + MemRefType.get((1, 4, 16, 1), f64), + MemRefType.get((2, 2), f64), + MemRefType.get((1, 2, 4, 1), i32), + ) + # Set the strides and use the default dilations. + def pooling_on_buffers(input, shape, output): + linalg.pooling_nhwc_min(input, shape, outs=[output], strides=[2, 4]) + + execution_engine = ExecutionEngine(transform(module, pooling_boiler)) + + # TODO: FFI-based solution to allow testing and printing with python code. + # Prepare arguments: one result i32. + # Arguments must be passed as pointers. 
+ c_int_p = ctypes.c_int * 1 + res = c_int_p(-1) + execution_engine.invoke("main", res) + + log("RESULT: ", res[0]) + # CHECK: RESULT: -13 test_min_pooling_builtin() def test_min_pooling_generic(): - with Context() as ctx, Location.unknown(): - module = Module.create() - f64 = F64Type.get() - i32 = IntegerType.get_signless(32) - with InsertionPoint(module.body): - - @func.FuncOp.from_py_func( - MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64), - MemRefType.get((1, 2, 4, 1), i32)) - # Set the strides and use the default dilations. - def pooling_on_buffers(input, shape, output): - linalg.pooling_nhwc_min( - input, shape, outs=[output], strides=[2, 4], emit_generic=True) - - execution_engine = ExecutionEngine(transform(module, pooling_boiler)) - - # TODO: FFI-based solution to allow testing and printing with python code. - # Prepare arguments: one result i32. - # Arguments must be passed as pointers. - c_int_p = ctypes.c_int * 1 - res = c_int_p(-1) - execution_engine.invoke("main", res) - - log("RESULT: ", res[0]) - # CHECK: RESULT: -13 + with Context() as ctx, Location.unknown(): + module = Module.create() + f64 = F64Type.get() + i32 = IntegerType.get_signless(32) + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func( + MemRefType.get((1, 4, 16, 1), f64), + MemRefType.get((2, 2), f64), + MemRefType.get((1, 2, 4, 1), i32), + ) + # Set the strides and use the default dilations. + def pooling_on_buffers(input, shape, output): + linalg.pooling_nhwc_min( + input, shape, outs=[output], strides=[2, 4], emit_generic=True + ) + + execution_engine = ExecutionEngine(transform(module, pooling_boiler)) + + # TODO: FFI-based solution to allow testing and printing with python code. + # Prepare arguments: one result i32. + # Arguments must be passed as pointers. 
+ c_int_p = ctypes.c_int * 1 + res = c_int_p(-1) + execution_engine.invoke("main", res) + + log("RESULT: ", res[0]) + # CHECK: RESULT: -13 test_min_pooling_generic() diff --git a/mlir/test/python/ir/affine_expr.py b/mlir/test/python/ir/affine_expr.py index 6a3a6fc..6356430 100644 --- a/mlir/test/python/ir/affine_expr.py +++ b/mlir/test/python/ir/affine_expr.py @@ -3,59 +3,61 @@ import gc from mlir.ir import * + def run(f): - print("\nTEST:", f.__name__) - f() - gc.collect() - assert Context._get_live_count() == 0 - return f + print("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + return f # CHECK-LABEL: TEST: testAffineExprCapsule @run def testAffineExprCapsule(): - with Context() as ctx: - affine_expr = AffineExpr.get_constant(42) + with Context() as ctx: + affine_expr = AffineExpr.get_constant(42) - affine_expr_capsule = affine_expr._CAPIPtr - # CHECK: capsule object - # CHECK: mlir.ir.AffineExpr._CAPIPtr - print(affine_expr_capsule) + affine_expr_capsule = affine_expr._CAPIPtr + # CHECK: capsule object + # CHECK: mlir.ir.AffineExpr._CAPIPtr + print(affine_expr_capsule) - affine_expr_2 = AffineExpr._CAPICreate(affine_expr_capsule) - assert affine_expr == affine_expr_2 - assert affine_expr_2.context == ctx + affine_expr_2 = AffineExpr._CAPICreate(affine_expr_capsule) + assert affine_expr == affine_expr_2 + assert affine_expr_2.context == ctx # CHECK-LABEL: TEST: testAffineExprEq @run def testAffineExprEq(): - with Context(): - a1 = AffineExpr.get_constant(42) - a2 = AffineExpr.get_constant(42) - a3 = AffineExpr.get_constant(43) - # CHECK: True - print(a1 == a1) - # CHECK: True - print(a1 == a2) - # CHECK: False - print(a1 == a3) - # CHECK: False - print(a1 == None) - # CHECK: False - print(a1 == "foo") + with Context(): + a1 = AffineExpr.get_constant(42) + a2 = AffineExpr.get_constant(42) + a3 = AffineExpr.get_constant(43) + # CHECK: True + print(a1 == a1) + # CHECK: True + print(a1 == a2) + # CHECK: False + print(a1 == a3) + # CHECK: False + print(a1 == None) + # CHECK: False + print(a1 == "foo") # CHECK-LABEL: TEST: testAffineExprContext @run def testAffineExprContext(): - with Context(): - a1 = AffineExpr.get_constant(42) - with Context(): - a2 = AffineExpr.get_constant(42) + with Context(): + a1 = AffineExpr.get_constant(42) + with Context(): + a2 = AffineExpr.get_constant(42) + + # CHECK: False + print(a1 == a2) - # CHECK: False - print(a1 == a2) run(testAffineExprContext) @@ -63,340 +65,343 @@ run(testAffineExprContext) # CHECK-LABEL: TEST: testAffineExprConstant @run def testAffineExprConstant(): - with Context(): - a1 = AffineExpr.get_constant(42) - # CHECK: 42 - print(a1.value) - # CHECK: 42 - print(a1) + with Context(): + a1 = AffineExpr.get_constant(42) + # CHECK: 42 + print(a1.value) + # CHECK: 42 + print(a1) - a2 = AffineConstantExpr.get(42) - # CHECK: 42 - print(a2.value) - # CHECK: 42 - print(a2) + a2 = AffineConstantExpr.get(42) + # CHECK: 42 + print(a2.value) + # CHECK: 42 + print(a2) - assert a1 == a2 + assert a1 == a2 # CHECK-LABEL: TEST: testAffineExprDim @run def testAffineExprDim(): - with Context(): - d1 = AffineExpr.get_dim(1) - d11 = AffineDimExpr.get(1) - d2 = AffineDimExpr.get(2) + with Context(): + d1 = AffineExpr.get_dim(1) + d11 = AffineDimExpr.get(1) + d2 = AffineDimExpr.get(2) - # CHECK: 1 - print(d1.position) - # CHECK: d1 - print(d1) + # CHECK: 1 + print(d1.position) + # CHECK: d1 + print(d1) - # CHECK: 2 - print(d2.position) - # CHECK: d2 - print(d2) + # CHECK: 2 + print(d2.position) + # CHECK: d2 + print(d2) - assert 
d1 == d11 - assert d1 != d2 + assert d1 == d11 + assert d1 != d2 # CHECK-LABEL: TEST: testAffineExprSymbol @run def testAffineExprSymbol(): - with Context(): - s1 = AffineExpr.get_symbol(1) - s11 = AffineSymbolExpr.get(1) - s2 = AffineSymbolExpr.get(2) + with Context(): + s1 = AffineExpr.get_symbol(1) + s11 = AffineSymbolExpr.get(1) + s2 = AffineSymbolExpr.get(2) - # CHECK: 1 - print(s1.position) - # CHECK: s1 - print(s1) + # CHECK: 1 + print(s1.position) + # CHECK: s1 + print(s1) - # CHECK: 2 - print(s2.position) - # CHECK: s2 - print(s2) + # CHECK: 2 + print(s2.position) + # CHECK: s2 + print(s2) - assert s1 == s11 - assert s1 != s2 + assert s1 == s11 + assert s1 != s2 # CHECK-LABEL: TEST: testAffineAddExpr @run def testAffineAddExpr(): - with Context(): - d1 = AffineDimExpr.get(1) - d2 = AffineDimExpr.get(2) - d12 = AffineExpr.get_add(d1, d2) - # CHECK: d1 + d2 - print(d12) + with Context(): + d1 = AffineDimExpr.get(1) + d2 = AffineDimExpr.get(2) + d12 = AffineExpr.get_add(d1, d2) + # CHECK: d1 + d2 + print(d12) - d12op = d1 + d2 - # CHECK: d1 + d2 - print(d12op) + d12op = d1 + d2 + # CHECK: d1 + d2 + print(d12op) - d1cst_op = d1 + 2 - # CHECK: d1 + 2 - print(d1cst_op) + d1cst_op = d1 + 2 + # CHECK: d1 + 2 + print(d1cst_op) - d1cst_op2 = 2 + d1 - # CHECK: d1 + 2 - print(d1cst_op2) + d1cst_op2 = 2 + d1 + # CHECK: d1 + 2 + print(d1cst_op2) - assert d12 == d12op - assert d12.lhs == d1 - assert d12.rhs == d2 + assert d12 == d12op + assert d12.lhs == d1 + assert d12.rhs == d2 # CHECK-LABEL: TEST: testAffineMulExpr @run def testAffineMulExpr(): - with Context(): - d1 = AffineDimExpr.get(1) - c2 = AffineConstantExpr.get(2) - expr = AffineExpr.get_mul(d1, c2) - # CHECK: d1 * 2 - print(expr) + with Context(): + d1 = AffineDimExpr.get(1) + c2 = AffineConstantExpr.get(2) + expr = AffineExpr.get_mul(d1, c2) + # CHECK: d1 * 2 + print(expr) - # CHECK: d1 * 2 - op = d1 * c2 - print(op) + # CHECK: d1 * 2 + op = d1 * c2 + print(op) - # CHECK: d1 * 2 - op_cst = d1 * 2 - print(op_cst) + # CHECK: d1 * 2 + op_cst = d1 * 2 + print(op_cst) - # CHECK: d1 * 2 - op_cst2 = 2 * d1 - print(op_cst2) + # CHECK: d1 * 2 + op_cst2 = 2 * d1 + print(op_cst2) - assert expr == op - assert expr == op_cst - assert expr.lhs == d1 - assert expr.rhs == c2 + assert expr == op + assert expr == op_cst + assert expr.lhs == d1 + assert expr.rhs == c2 # CHECK-LABEL: TEST: testAffineModExpr @run def testAffineModExpr(): - with Context(): - d1 = AffineDimExpr.get(1) - c2 = AffineConstantExpr.get(2) - expr = AffineExpr.get_mod(d1, c2) - # CHECK: d1 mod 2 - print(expr) + with Context(): + d1 = AffineDimExpr.get(1) + c2 = AffineConstantExpr.get(2) + expr = AffineExpr.get_mod(d1, c2) + # CHECK: d1 mod 2 + print(expr) - # CHECK: d1 mod 2 - op = d1 % c2 - print(op) + # CHECK: d1 mod 2 + op = d1 % c2 + print(op) - # CHECK: d1 mod 2 - op_cst = d1 % 2 - print(op_cst) + # CHECK: d1 mod 2 + op_cst = d1 % 2 + print(op_cst) - # CHECK: 2 mod d1 - print(2 % d1) + # CHECK: 2 mod d1 + print(2 % d1) - assert expr == op - assert expr == op_cst - assert expr.lhs == d1 - assert expr.rhs == c2 + assert expr == op + assert expr == op_cst + assert expr.lhs == d1 + assert expr.rhs == c2 - expr2 = AffineExpr.get_mod(c2, d1) - expr3 = AffineExpr.get_mod(2, d1) - expr4 = AffineExpr.get_mod(d1, 2) + expr2 = AffineExpr.get_mod(c2, d1) + expr3 = AffineExpr.get_mod(2, d1) + expr4 = AffineExpr.get_mod(d1, 2) - # CHECK: 2 mod d1 - print(expr2) - # CHECK: 2 mod d1 - print(expr3) - # CHECK: d1 mod 2 - print(expr4) + # CHECK: 2 mod d1 + print(expr2) + # CHECK: 2 mod d1 + 
print(expr3) + # CHECK: d1 mod 2 + print(expr4) - assert expr2 == expr3 - assert expr4 == expr + assert expr2 == expr3 + assert expr4 == expr # CHECK-LABEL: TEST: testAffineFloorDivExpr @run def testAffineFloorDivExpr(): - with Context(): - d1 = AffineDimExpr.get(1) - c2 = AffineConstantExpr.get(2) - expr = AffineExpr.get_floor_div(d1, c2) - # CHECK: d1 floordiv 2 - print(expr) + with Context(): + d1 = AffineDimExpr.get(1) + c2 = AffineConstantExpr.get(2) + expr = AffineExpr.get_floor_div(d1, c2) + # CHECK: d1 floordiv 2 + print(expr) - assert expr.lhs == d1 - assert expr.rhs == c2 + assert expr.lhs == d1 + assert expr.rhs == c2 - expr2 = AffineExpr.get_floor_div(c2, d1) - expr3 = AffineExpr.get_floor_div(2, d1) - expr4 = AffineExpr.get_floor_div(d1, 2) + expr2 = AffineExpr.get_floor_div(c2, d1) + expr3 = AffineExpr.get_floor_div(2, d1) + expr4 = AffineExpr.get_floor_div(d1, 2) - # CHECK: 2 floordiv d1 - print(expr2) - # CHECK: 2 floordiv d1 - print(expr3) - # CHECK: d1 floordiv 2 - print(expr4) + # CHECK: 2 floordiv d1 + print(expr2) + # CHECK: 2 floordiv d1 + print(expr3) + # CHECK: d1 floordiv 2 + print(expr4) - assert expr2 == expr3 - assert expr4 == expr + assert expr2 == expr3 + assert expr4 == expr # CHECK-LABEL: TEST: testAffineCeilDivExpr @run def testAffineCeilDivExpr(): - with Context(): - d1 = AffineDimExpr.get(1) - c2 = AffineConstantExpr.get(2) - expr = AffineExpr.get_ceil_div(d1, c2) - # CHECK: d1 ceildiv 2 - print(expr) + with Context(): + d1 = AffineDimExpr.get(1) + c2 = AffineConstantExpr.get(2) + expr = AffineExpr.get_ceil_div(d1, c2) + # CHECK: d1 ceildiv 2 + print(expr) - assert expr.lhs == d1 - assert expr.rhs == c2 + assert expr.lhs == d1 + assert expr.rhs == c2 - expr2 = AffineExpr.get_ceil_div(c2, d1) - expr3 = AffineExpr.get_ceil_div(2, d1) - expr4 = AffineExpr.get_ceil_div(d1, 2) + expr2 = AffineExpr.get_ceil_div(c2, d1) + expr3 = AffineExpr.get_ceil_div(2, d1) + expr4 = AffineExpr.get_ceil_div(d1, 2) - # CHECK: 2 ceildiv d1 - print(expr2) - # CHECK: 2 ceildiv d1 - print(expr3) - # CHECK: d1 ceildiv 2 - print(expr4) + # CHECK: 2 ceildiv d1 + print(expr2) + # CHECK: 2 ceildiv d1 + print(expr3) + # CHECK: d1 ceildiv 2 + print(expr4) - assert expr2 == expr3 - assert expr4 == expr + assert expr2 == expr3 + assert expr4 == expr # CHECK-LABEL: TEST: testAffineExprSub @run def testAffineExprSub(): - with Context(): - d1 = AffineDimExpr.get(1) - d2 = AffineDimExpr.get(2) - expr = d1 - d2 - # CHECK: d1 - d2 - print(expr) - - assert expr.lhs == d1 - rhs = AffineMulExpr(expr.rhs) - # CHECK: d2 - print(rhs.lhs) - # CHECK: -1 - print(rhs.rhs) - - # CHECK: d1 - 42 - print(d1 - 42) - # CHECK: -d1 + 42 - print(42 - d1) - - c42 = AffineConstantExpr.get(42) - assert d1 - 42 == d1 - c42 - assert 42 - d1 == c42 - d1 + with Context(): + d1 = AffineDimExpr.get(1) + d2 = AffineDimExpr.get(2) + expr = d1 - d2 + # CHECK: d1 - d2 + print(expr) + + assert expr.lhs == d1 + rhs = AffineMulExpr(expr.rhs) + # CHECK: d2 + print(rhs.lhs) + # CHECK: -1 + print(rhs.rhs) + + # CHECK: d1 - 42 + print(d1 - 42) + # CHECK: -d1 + 42 + print(42 - d1) + + c42 = AffineConstantExpr.get(42) + assert d1 - 42 == d1 - c42 + assert 42 - d1 == c42 - d1 + # CHECK-LABEL: TEST: testClassHierarchy @run def testClassHierarchy(): - with Context(): - d1 = AffineDimExpr.get(1) - c2 = AffineConstantExpr.get(2) - add = AffineAddExpr.get(d1, c2) - mul = AffineMulExpr.get(d1, c2) - mod = AffineModExpr.get(d1, c2) - floor_div = AffineFloorDivExpr.get(d1, c2) - ceil_div = AffineCeilDivExpr.get(d1, c2) + with Context(): + d1 = 
AffineDimExpr.get(1) + c2 = AffineConstantExpr.get(2) + add = AffineAddExpr.get(d1, c2) + mul = AffineMulExpr.get(d1, c2) + mod = AffineModExpr.get(d1, c2) + floor_div = AffineFloorDivExpr.get(d1, c2) + ceil_div = AffineCeilDivExpr.get(d1, c2) + + # CHECK: False + print(isinstance(d1, AffineBinaryExpr)) + # CHECK: False + print(isinstance(c2, AffineBinaryExpr)) + # CHECK: True + print(isinstance(add, AffineBinaryExpr)) + # CHECK: True + print(isinstance(mul, AffineBinaryExpr)) + # CHECK: True + print(isinstance(mod, AffineBinaryExpr)) + # CHECK: True + print(isinstance(floor_div, AffineBinaryExpr)) + # CHECK: True + print(isinstance(ceil_div, AffineBinaryExpr)) + + try: + AffineBinaryExpr(d1) + except ValueError as e: + # CHECK: Cannot cast affine expression to AffineBinaryExpr + print(e) + + try: + AffineBinaryExpr(c2) + except ValueError as e: + # CHECK: Cannot cast affine expression to AffineBinaryExpr + print(e) - # CHECK: False - print(isinstance(d1, AffineBinaryExpr)) - # CHECK: False - print(isinstance(c2, AffineBinaryExpr)) - # CHECK: True - print(isinstance(add, AffineBinaryExpr)) - # CHECK: True - print(isinstance(mul, AffineBinaryExpr)) - # CHECK: True - print(isinstance(mod, AffineBinaryExpr)) - # CHECK: True - print(isinstance(floor_div, AffineBinaryExpr)) - # CHECK: True - print(isinstance(ceil_div, AffineBinaryExpr)) - - try: - AffineBinaryExpr(d1) - except ValueError as e: - # CHECK: Cannot cast affine expression to AffineBinaryExpr - print(e) - - try: - AffineBinaryExpr(c2) - except ValueError as e: - # CHECK: Cannot cast affine expression to AffineBinaryExpr - print(e) # CHECK-LABEL: TEST: testIsInstance @run def testIsInstance(): - with Context(): - d1 = AffineDimExpr.get(1) - c2 = AffineConstantExpr.get(2) - add = AffineAddExpr.get(d1, c2) - mul = AffineMulExpr.get(d1, c2) - - # CHECK: True - print(AffineDimExpr.isinstance(d1)) - # CHECK: False - print(AffineConstantExpr.isinstance(d1)) - # CHECK: True - print(AffineConstantExpr.isinstance(c2)) - # CHECK: False - print(AffineMulExpr.isinstance(c2)) - # CHECK: True - print(AffineAddExpr.isinstance(add)) - # CHECK: False - print(AffineMulExpr.isinstance(add)) - # CHECK: True - print(AffineMulExpr.isinstance(mul)) - # CHECK: False - print(AffineAddExpr.isinstance(mul)) + with Context(): + d1 = AffineDimExpr.get(1) + c2 = AffineConstantExpr.get(2) + add = AffineAddExpr.get(d1, c2) + mul = AffineMulExpr.get(d1, c2) + + # CHECK: True + print(AffineDimExpr.isinstance(d1)) + # CHECK: False + print(AffineConstantExpr.isinstance(d1)) + # CHECK: True + print(AffineConstantExpr.isinstance(c2)) + # CHECK: False + print(AffineMulExpr.isinstance(c2)) + # CHECK: True + print(AffineAddExpr.isinstance(add)) + # CHECK: False + print(AffineMulExpr.isinstance(add)) + # CHECK: True + print(AffineMulExpr.isinstance(mul)) + # CHECK: False + print(AffineAddExpr.isinstance(mul)) # CHECK-LABEL: TEST: testCompose @run def testCompose(): - with Context(): - # d0 + d2. - expr = AffineAddExpr.get(AffineDimExpr.get(0), AffineDimExpr.get(2)) + with Context(): + # d0 + d2. 
+ expr = AffineAddExpr.get(AffineDimExpr.get(0), AffineDimExpr.get(2)) - # (d0, d1, d2)[s0, s1] -> (d0 + s1, d1 + s0, d0 + d1 + d2) - map1 = AffineAddExpr.get(AffineDimExpr.get(0), AffineSymbolExpr.get(1)) - map2 = AffineAddExpr.get(AffineDimExpr.get(1), AffineSymbolExpr.get(0)) - map3 = AffineAddExpr.get( - AffineAddExpr.get(AffineDimExpr.get(0), AffineDimExpr.get(1)), - AffineDimExpr.get(2)) - map = AffineMap.get(3, 2, [map1, map2, map3]) + # (d0, d1, d2)[s0, s1] -> (d0 + s1, d1 + s0, d0 + d1 + d2) + map1 = AffineAddExpr.get(AffineDimExpr.get(0), AffineSymbolExpr.get(1)) + map2 = AffineAddExpr.get(AffineDimExpr.get(1), AffineSymbolExpr.get(0)) + map3 = AffineAddExpr.get( + AffineAddExpr.get(AffineDimExpr.get(0), AffineDimExpr.get(1)), + AffineDimExpr.get(2), + ) + map = AffineMap.get(3, 2, [map1, map2, map3]) - # CHECK: d0 + s1 + d0 + d1 + d2 - print(expr.compose(map)) + # CHECK: d0 + s1 + d0 + d1 + d2 + print(expr.compose(map)) # CHECK-LABEL: TEST: testHash @run def testHash(): - with Context(): - d0 = AffineDimExpr.get(0) - s1 = AffineSymbolExpr.get(1) - assert hash(d0) == hash(AffineDimExpr.get(0)) - assert hash(d0 + s1) == hash(AffineAddExpr.get(d0, s1)) - - dictionary = dict() - dictionary[d0] = 0 - dictionary[s1] = 1 - assert d0 in dictionary - assert s1 in dictionary + with Context(): + d0 = AffineDimExpr.get(0) + s1 = AffineSymbolExpr.get(1) + assert hash(d0) == hash(AffineDimExpr.get(0)) + assert hash(d0 + s1) == hash(AffineAddExpr.get(d0, s1)) + + dictionary = dict() + dictionary[d0] = 0 + dictionary[s1] = 1 + assert d0 in dictionary + assert s1 in dictionary diff --git a/mlir/test/python/ir/affine_map.py b/mlir/test/python/ir/affine_map.py index 52c7261..672335e 100644 --- a/mlir/test/python/ir/affine_map.py +++ b/mlir/test/python/ir/affine_map.py @@ -5,237 +5,241 @@ from mlir.ir import * def run(f): - print("\nTEST:", f.__name__) - f() - gc.collect() - assert Context._get_live_count() == 0 - return f + print("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + return f # CHECK-LABEL: TEST: testAffineMapCapsule @run def testAffineMapCapsule(): - with Context() as ctx: - am1 = AffineMap.get_empty(ctx) - # CHECK: mlir.ir.AffineMap._CAPIPtr - affine_map_capsule = am1._CAPIPtr - print(affine_map_capsule) - am2 = AffineMap._CAPICreate(affine_map_capsule) - assert am2 == am1 - assert am2.context is ctx + with Context() as ctx: + am1 = AffineMap.get_empty(ctx) + # CHECK: mlir.ir.AffineMap._CAPIPtr + affine_map_capsule = am1._CAPIPtr + print(affine_map_capsule) + am2 = AffineMap._CAPICreate(affine_map_capsule) + assert am2 == am1 + assert am2.context is ctx # CHECK-LABEL: TEST: testAffineMapGet @run def testAffineMapGet(): - with Context() as ctx: - d0 = AffineDimExpr.get(0) - d1 = AffineDimExpr.get(1) - c2 = AffineConstantExpr.get(2) - - # CHECK: (d0, d1)[s0, s1, s2] -> () - map0 = AffineMap.get(2, 3, []) - print(map0) - - # CHECK: (d0, d1)[s0, s1, s2] -> (d1, 2) - map1 = AffineMap.get(2, 3, [d1, c2]) - print(map1) - - # CHECK: () -> (2) - map2 = AffineMap.get(0, 0, [c2]) - print(map2) - - # CHECK: (d0, d1) -> (d0, d1) - map3 = AffineMap.get(2, 0, [d0, d1]) - print(map3) - - # CHECK: (d0, d1) -> (d1) - map4 = AffineMap.get(2, 0, [d1]) - print(map4) - - # CHECK: (d0, d1, d2) -> (d2, d0, d1) - map5 = AffineMap.get_permutation([2, 0, 1]) - print(map5) - - assert map1 == AffineMap.get(2, 3, [d1, c2]) - assert AffineMap.get(0, 0, []) == AffineMap.get_empty() - assert map2 == AffineMap.get_constant(2) - assert map3 == AffineMap.get_identity(2) - assert map4 
== AffineMap.get_minor_identity(2, 1) - - try: - AffineMap.get(1, 1, [1]) - except RuntimeError as e: - # CHECK: Invalid expression when attempting to create an AffineMap - print(e) - - try: - AffineMap.get(1, 1, [None]) - except RuntimeError as e: - # CHECK: Invalid expression (None?) when attempting to create an AffineMap - print(e) - - try: - AffineMap.get_permutation([1, 0, 1]) - except RuntimeError as e: - # CHECK: Invalid permutation when attempting to create an AffineMap - print(e) - - try: - map3.get_submap([42]) - except ValueError as e: - # CHECK: result position out of bounds - print(e) - - try: - map3.get_minor_submap(42) - except ValueError as e: - # CHECK: number of results out of bounds - print(e) - - try: - map3.get_major_submap(42) - except ValueError as e: - # CHECK: number of results out of bounds - print(e) + with Context() as ctx: + d0 = AffineDimExpr.get(0) + d1 = AffineDimExpr.get(1) + c2 = AffineConstantExpr.get(2) + + # CHECK: (d0, d1)[s0, s1, s2] -> () + map0 = AffineMap.get(2, 3, []) + print(map0) + + # CHECK: (d0, d1)[s0, s1, s2] -> (d1, 2) + map1 = AffineMap.get(2, 3, [d1, c2]) + print(map1) + + # CHECK: () -> (2) + map2 = AffineMap.get(0, 0, [c2]) + print(map2) + + # CHECK: (d0, d1) -> (d0, d1) + map3 = AffineMap.get(2, 0, [d0, d1]) + print(map3) + + # CHECK: (d0, d1) -> (d1) + map4 = AffineMap.get(2, 0, [d1]) + print(map4) + + # CHECK: (d0, d1, d2) -> (d2, d0, d1) + map5 = AffineMap.get_permutation([2, 0, 1]) + print(map5) + + assert map1 == AffineMap.get(2, 3, [d1, c2]) + assert AffineMap.get(0, 0, []) == AffineMap.get_empty() + assert map2 == AffineMap.get_constant(2) + assert map3 == AffineMap.get_identity(2) + assert map4 == AffineMap.get_minor_identity(2, 1) + + try: + AffineMap.get(1, 1, [1]) + except RuntimeError as e: + # CHECK: Invalid expression when attempting to create an AffineMap + print(e) + + try: + AffineMap.get(1, 1, [None]) + except RuntimeError as e: + # CHECK: Invalid expression (None?) 
when attempting to create an AffineMap + print(e) + + try: + AffineMap.get_permutation([1, 0, 1]) + except RuntimeError as e: + # CHECK: Invalid permutation when attempting to create an AffineMap + print(e) + + try: + map3.get_submap([42]) + except ValueError as e: + # CHECK: result position out of bounds + print(e) + + try: + map3.get_minor_submap(42) + except ValueError as e: + # CHECK: number of results out of bounds + print(e) + + try: + map3.get_major_submap(42) + except ValueError as e: + # CHECK: number of results out of bounds + print(e) # CHECK-LABEL: TEST: testAffineMapDerive @run def testAffineMapDerive(): - with Context() as ctx: - map5 = AffineMap.get_identity(5) + with Context() as ctx: + map5 = AffineMap.get_identity(5) - # CHECK: (d0, d1, d2, d3, d4) -> (d1, d2, d3) - map123 = map5.get_submap([1, 2, 3]) - print(map123) + # CHECK: (d0, d1, d2, d3, d4) -> (d1, d2, d3) + map123 = map5.get_submap([1, 2, 3]) + print(map123) - # CHECK: (d0, d1, d2, d3, d4) -> (d0, d1) - map01 = map5.get_major_submap(2) - print(map01) + # CHECK: (d0, d1, d2, d3, d4) -> (d0, d1) + map01 = map5.get_major_submap(2) + print(map01) - # CHECK: (d0, d1, d2, d3, d4) -> (d3, d4) - map34 = map5.get_minor_submap(2) - print(map34) + # CHECK: (d0, d1, d2, d3, d4) -> (d3, d4) + map34 = map5.get_minor_submap(2) + print(map34) # CHECK-LABEL: TEST: testAffineMapProperties @run def testAffineMapProperties(): - with Context(): - d0 = AffineDimExpr.get(0) - d1 = AffineDimExpr.get(1) - d2 = AffineDimExpr.get(2) - map1 = AffineMap.get(3, 0, [d2, d0]) - map2 = AffineMap.get(3, 0, [d2, d0, d1]) - map3 = AffineMap.get(3, 1, [d2, d0, d1]) - # CHECK: False - print(map1.is_permutation) - # CHECK: True - print(map1.is_projected_permutation) - # CHECK: True - print(map2.is_permutation) - # CHECK: True - print(map2.is_projected_permutation) - # CHECK: False - print(map3.is_permutation) - # CHECK: False - print(map3.is_projected_permutation) + with Context(): + d0 = AffineDimExpr.get(0) + d1 = AffineDimExpr.get(1) + d2 = AffineDimExpr.get(2) + map1 = AffineMap.get(3, 0, [d2, d0]) + map2 = AffineMap.get(3, 0, [d2, d0, d1]) + map3 = AffineMap.get(3, 1, [d2, d0, d1]) + # CHECK: False + print(map1.is_permutation) + # CHECK: True + print(map1.is_projected_permutation) + # CHECK: True + print(map2.is_permutation) + # CHECK: True + print(map2.is_projected_permutation) + # CHECK: False + print(map3.is_permutation) + # CHECK: False + print(map3.is_projected_permutation) # CHECK-LABEL: TEST: testAffineMapExprs @run def testAffineMapExprs(): - with Context(): - d0 = AffineDimExpr.get(0) - d1 = AffineDimExpr.get(1) - d2 = AffineDimExpr.get(2) - map3 = AffineMap.get(3, 1, [d2, d0, d1]) - - # CHECK: 3 - print(map3.n_dims) - # CHECK: 4 - print(map3.n_inputs) - # CHECK: 1 - print(map3.n_symbols) - assert map3.n_inputs == map3.n_dims + map3.n_symbols - - # CHECK: 3 - print(len(map3.results)) - for expr in map3.results: - # CHECK: d2 - # CHECK: d0 - # CHECK: d1 - print(expr) - for expr in map3.results[-1:-4:-1]: - # CHECK: d1 - # CHECK: d0 - # CHECK: d2 - print(expr) - assert list(map3.results) == [d2, d0, d1] + with Context(): + d0 = AffineDimExpr.get(0) + d1 = AffineDimExpr.get(1) + d2 = AffineDimExpr.get(2) + map3 = AffineMap.get(3, 1, [d2, d0, d1]) + + # CHECK: 3 + print(map3.n_dims) + # CHECK: 4 + print(map3.n_inputs) + # CHECK: 1 + print(map3.n_symbols) + assert map3.n_inputs == map3.n_dims + map3.n_symbols + + # CHECK: 3 + print(len(map3.results)) + for expr in map3.results: + # CHECK: d2 + # CHECK: d0 + # CHECK: d1 + print(expr) + for 
expr in map3.results[-1:-4:-1]: + # CHECK: d1 + # CHECK: d0 + # CHECK: d2 + print(expr) + assert list(map3.results) == [d2, d0, d1] # CHECK-LABEL: TEST: testCompressUnusedSymbols @run def testCompressUnusedSymbols(): - with Context() as ctx: - d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1), - AffineDimExpr.get(2)) - s0, s1, s2 = (AffineSymbolExpr.get(0), AffineSymbolExpr.get(1), - AffineSymbolExpr.get(2)) - maps = [ - AffineMap.get(3, 3, [d2, d0, d1]), - AffineMap.get(3, 3, [d2, d0 + s2, d1]), - AffineMap.get(3, 3, [d1, d2, d0]) - ] - - compressed_maps = AffineMap.compress_unused_symbols(maps, ctx) - - # CHECK: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0, d1)) - # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2, d1)) - # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d1, d2, d0)) - print(maps) - - # CHECK: AffineMap((d0, d1, d2)[s0] -> (d2, d0, d1)) - # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d2, d0 + s0, d1)) - # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d1, d2, d0)) - print(compressed_maps) + with Context() as ctx: + d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1), AffineDimExpr.get(2)) + s0, s1, s2 = ( + AffineSymbolExpr.get(0), + AffineSymbolExpr.get(1), + AffineSymbolExpr.get(2), + ) + maps = [ + AffineMap.get(3, 3, [d2, d0, d1]), + AffineMap.get(3, 3, [d2, d0 + s2, d1]), + AffineMap.get(3, 3, [d1, d2, d0]), + ] + + compressed_maps = AffineMap.compress_unused_symbols(maps, ctx) + + # CHECK: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0, d1)) + # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2, d1)) + # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d1, d2, d0)) + print(maps) + + # CHECK: AffineMap((d0, d1, d2)[s0] -> (d2, d0, d1)) + # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d2, d0 + s0, d1)) + # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d1, d2, d0)) + print(compressed_maps) # CHECK-LABEL: TEST: testReplace @run def testReplace(): - with Context() as ctx: - d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1), - AffineDimExpr.get(2)) - s0, s1, s2 = (AffineSymbolExpr.get(0), AffineSymbolExpr.get(1), - AffineSymbolExpr.get(2)) - map1 = AffineMap.get(3, 3, [d2, d0 + s1 + s2, d1 + s0]) + with Context() as ctx: + d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1), AffineDimExpr.get(2)) + s0, s1, s2 = ( + AffineSymbolExpr.get(0), + AffineSymbolExpr.get(1), + AffineSymbolExpr.get(2), + ) + map1 = AffineMap.get(3, 3, [d2, d0 + s1 + s2, d1 + s0]) - replace0 = map1.replace(s0, AffineConstantExpr.get(42), 3, 3) - replace1 = map1.replace(s1, AffineConstantExpr.get(42), 3, 3) - replace3 = map1.replace(s2, AffineConstantExpr.get(42), 3, 2) + replace0 = map1.replace(s0, AffineConstantExpr.get(42), 3, 3) + replace1 = map1.replace(s1, AffineConstantExpr.get(42), 3, 3) + replace3 = map1.replace(s2, AffineConstantExpr.get(42), 3, 2) - # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s1 + s2, d1 + 42) - print(replace0) + # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s1 + s2, d1 + 42) + print(replace0) - # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2 + 42, d1 + s0) - print(replace1) + # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2 + 42, d1 + s0) + print(replace1) - # CHECK: (d0, d1, d2)[s0, s1] -> (d2, d0 + s1 + 42, d1 + s0) - print(replace3) + # CHECK: (d0, d1, d2)[s0, s1] -> (d2, d0 + s1 + 42, d1 + s0) + print(replace3) # CHECK-LABEL: TEST: testHash @run def testHash(): - with Context(): - d0, d1 = AffineDimExpr.get(0), AffineDimExpr.get(1) - m1 = AffineMap.get(2, 0, [d0, d1]) - m2 = AffineMap.get(2, 0, [d1, d0]) - assert 
hash(m1) == hash(AffineMap.get(2, 0, [d0, d1])) - - dictionary = dict() - dictionary[m1] = 1 - dictionary[m2] = 2 - assert m1 in dictionary + with Context(): + d0, d1 = AffineDimExpr.get(0), AffineDimExpr.get(1) + m1 = AffineMap.get(2, 0, [d0, d1]) + m2 = AffineMap.get(2, 0, [d1, d0]) + assert hash(m1) == hash(AffineMap.get(2, 0, [d0, d1])) + + dictionary = dict() + dictionary[m1] = 1 + dictionary[m2] = 2 + assert m1 in dictionary diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py index 3de4edb..5ce8bc6 100644 --- a/mlir/test/python/ir/array_attributes.py +++ b/mlir/test/python/ir/array_attributes.py @@ -6,26 +6,30 @@ import gc from mlir.ir import * import numpy as np + def run(f): - print("\nTEST:", f.__name__) - f() - gc.collect() - assert Context._get_live_count() == 0 - return f + print("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + return f + ################################################################################ # Tests of the array/buffer .get() factory method on unsupported dtype. ################################################################################ + @run def testGetDenseElementsUnsupported(): - with Context(): - array = np.array([["hello", "goodbye"]]) - try: - attr = DenseElementsAttr.get(array) - except ValueError as e: - # CHECK: unimplemented array format conversion from format: - print(e) + with Context(): + array = np.array([["hello", "goodbye"]]) + try: + attr = DenseElementsAttr.get(array) + except ValueError as e: + # CHECK: unimplemented array format conversion from format: + print(e) + ################################################################################ # Splats. @@ -34,85 +38,85 @@ def testGetDenseElementsUnsupported(): # CHECK-LABEL: TEST: testGetDenseElementsSplatInt @run def testGetDenseElementsSplatInt(): - with Context(), Location.unknown(): - t = IntegerType.get_signless(32) - element = IntegerAttr.get(t, 555) - shaped_type = RankedTensorType.get((2, 3, 4), t) - attr = DenseElementsAttr.get_splat(shaped_type, element) - # CHECK: dense<555> : tensor<2x3x4xi32> - print(attr) - # CHECK: is_splat: True - print("is_splat:", attr.is_splat) - assert attr.get_splat_value() == element + with Context(), Location.unknown(): + t = IntegerType.get_signless(32) + element = IntegerAttr.get(t, 555) + shaped_type = RankedTensorType.get((2, 3, 4), t) + attr = DenseElementsAttr.get_splat(shaped_type, element) + # CHECK: dense<555> : tensor<2x3x4xi32> + print(attr) + # CHECK: is_splat: True + print("is_splat:", attr.is_splat) + assert attr.get_splat_value() == element # CHECK-LABEL: TEST: testGetDenseElementsSplatFloat @run def testGetDenseElementsSplatFloat(): - with Context(), Location.unknown(): - t = F32Type.get() - element = FloatAttr.get(t, 1.2) - shaped_type = RankedTensorType.get((2, 3, 4), t) - attr = DenseElementsAttr.get_splat(shaped_type, element) - # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32> - print(attr) - assert attr.get_splat_value() == element + with Context(), Location.unknown(): + t = F32Type.get() + element = FloatAttr.get(t, 1.2) + shaped_type = RankedTensorType.get((2, 3, 4), t) + attr = DenseElementsAttr.get_splat(shaped_type, element) + # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32> + print(attr) + assert attr.get_splat_value() == element # CHECK-LABEL: TEST: testGetDenseElementsSplatErrors @run def testGetDenseElementsSplatErrors(): - with Context(), Location.unknown(): - t = F32Type.get() - other_t = F64Type.get() - element = 
FloatAttr.get(t, 1.2) - other_element = FloatAttr.get(other_t, 1.2) - shaped_type = RankedTensorType.get((2, 3, 4), t) - dynamic_shaped_type = UnrankedTensorType.get(t) - non_shaped_type = t - - try: - attr = DenseElementsAttr.get_splat(non_shaped_type, element) - except ValueError as e: - # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(f32) - print(e) - - try: - attr = DenseElementsAttr.get_splat(dynamic_shaped_type, element) - except ValueError as e: - # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(tensor<*xf32>) - print(e) - - try: - attr = DenseElementsAttr.get_splat(shaped_type, other_element) - except ValueError as e: - # CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64) - print(e) + with Context(), Location.unknown(): + t = F32Type.get() + other_t = F64Type.get() + element = FloatAttr.get(t, 1.2) + other_element = FloatAttr.get(other_t, 1.2) + shaped_type = RankedTensorType.get((2, 3, 4), t) + dynamic_shaped_type = UnrankedTensorType.get(t) + non_shaped_type = t + + try: + attr = DenseElementsAttr.get_splat(non_shaped_type, element) + except ValueError as e: + # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(f32) + print(e) + + try: + attr = DenseElementsAttr.get_splat(dynamic_shaped_type, element) + except ValueError as e: + # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(tensor<*xf32>) + print(e) + + try: + attr = DenseElementsAttr.get_splat(shaped_type, other_element) + except ValueError as e: + # CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64) + print(e) # CHECK-LABEL: TEST: testRepeatedValuesSplat @run def testRepeatedValuesSplat(): - with Context(): - array = np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], dtype=np.float32) - attr = DenseElementsAttr.get(array) - # CHECK: dense<1.000000e+00> : tensor<2x3xf32> - print(attr) - # CHECK: is_splat: True - print("is_splat:", attr.is_splat) - # CHECK{LITERAL}: [[1. 1. 1.] - # CHECK{LITERAL}: [1. 1. 1.]] - print(np.array(attr)) + with Context(): + array = np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], dtype=np.float32) + attr = DenseElementsAttr.get(array) + # CHECK: dense<1.000000e+00> : tensor<2x3xf32> + print(attr) + # CHECK: is_splat: True + print("is_splat:", attr.is_splat) + # CHECK{LITERAL}: [[1. 1. 1.] + # CHECK{LITERAL}: [1. 1. 1.]] + print(np.array(attr)) # CHECK-LABEL: TEST: testNonSplat @run def testNonSplat(): - with Context(): - array = np.array([2.0, 1.0, 1.0], dtype=np.float32) - attr = DenseElementsAttr.get(array) - # CHECK: is_splat: False - print("is_splat:", attr.is_splat) + with Context(): + array = np.array([2.0, 1.0, 1.0], dtype=np.float32) + attr = DenseElementsAttr.get(array) + # CHECK: is_splat: False + print("is_splat:", attr.is_splat) ################################################################################ @@ -121,50 +125,59 @@ def testNonSplat(): ### explicitly provided types + @run def testGetDenseElementsBF16(): - with Context(): - array = np.array([[2, 4, 8], [16, 32, 64]], dtype=np.uint16) - attr = DenseElementsAttr.get(array, type=BF16Type.get()) - # Note: These values don't mean much since just bit-casting. But they - # shouldn't change. 
- # CHECK: dense<{{\[}}[1.836710e-40, 3.673420e-40, 7.346840e-40], [1.469370e-39, 2.938740e-39, 5.877470e-39]]> : tensor<2x3xbf16> - print(attr) + with Context(): + array = np.array([[2, 4, 8], [16, 32, 64]], dtype=np.uint16) + attr = DenseElementsAttr.get(array, type=BF16Type.get()) + # Note: These values don't mean much since just bit-casting. But they + # shouldn't change. + # CHECK: dense<{{\[}}[1.836710e-40, 3.673420e-40, 7.346840e-40], [1.469370e-39, 2.938740e-39, 5.877470e-39]]> : tensor<2x3xbf16> + print(attr) + @run def testGetDenseElementsInteger4(): - with Context(): - array = np.array([[2, 4, 7], [-2, -4, -8]], dtype=np.uint8) - attr = DenseElementsAttr.get(array, type=IntegerType.get_signless(4)) - # Note: These values don't mean much since just bit-casting. But they - # shouldn't change. - # CHECK: dense<{{\[}}[2, 4, 7], [-2, -4, -8]]> : tensor<2x3xi4> - print(attr) + with Context(): + array = np.array([[2, 4, 7], [-2, -4, -8]], dtype=np.uint8) + attr = DenseElementsAttr.get(array, type=IntegerType.get_signless(4)) + # Note: These values don't mean much since just bit-casting. But they + # shouldn't change. + # CHECK: dense<{{\[}}[2, 4, 7], [-2, -4, -8]]> : tensor<2x3xi4> + print(attr) @run def testGetDenseElementsBool(): - with Context(): - bool_array = np.array([[1, 0, 1], [0, 1, 0]], dtype=np.bool_) - array = np.packbits(bool_array, axis=None, bitorder="little") - attr = DenseElementsAttr.get( - array, type=IntegerType.get_signless(1), shape=bool_array.shape) - # CHECK: dense<{{\[}}[true, false, true], [false, true, false]]> : tensor<2x3xi1> - print(attr) + with Context(): + bool_array = np.array([[1, 0, 1], [0, 1, 0]], dtype=np.bool_) + array = np.packbits(bool_array, axis=None, bitorder="little") + attr = DenseElementsAttr.get( + array, type=IntegerType.get_signless(1), shape=bool_array.shape + ) + # CHECK: dense<{{\[}}[true, false, true], [false, true, false]]> : tensor<2x3xi1> + print(attr) @run def testGetDenseElementsBoolSplat(): - with Context(): - zero = np.array(0, dtype=np.uint8) - one = np.array(255, dtype=np.uint8) - print(one) - # CHECK: dense : tensor<4x2x5xi1> - print(DenseElementsAttr.get( - zero, type=IntegerType.get_signless(1), shape=(4, 2, 5))) - # CHECK: dense : tensor<4x2x5xi1> - print(DenseElementsAttr.get( - one, type=IntegerType.get_signless(1), shape=(4, 2, 5))) + with Context(): + zero = np.array(0, dtype=np.uint8) + one = np.array(255, dtype=np.uint8) + print(one) + # CHECK: dense : tensor<4x2x5xi1> + print( + DenseElementsAttr.get( + zero, type=IntegerType.get_signless(1), shape=(4, 2, 5) + ) + ) + # CHECK: dense : tensor<4x2x5xi1> + print( + DenseElementsAttr.get( + one, type=IntegerType.get_signless(1), shape=(4, 2, 5) + ) + ) ### float and double arrays. @@ -172,213 +185,213 @@ def testGetDenseElementsBoolSplat(): # CHECK-LABEL: TEST: testGetDenseElementsF16 @run def testGetDenseElementsF16(): - with Context(): - array = np.array([[2.0, 4.0, 8.0], [16.0, 32.0, 64.0]], dtype=np.float16) - attr = DenseElementsAttr.get(array) - # CHECK: dense<{{\[}}[2.000000e+00, 4.000000e+00, 8.000000e+00], [1.600000e+01, 3.200000e+01, 6.400000e+01]]> : tensor<2x3xf16> - print(attr) - # CHECK: {{\[}}[ 2. 4. 8.] - # CHECK: {{\[}}16. 32. 64.]] - print(np.array(attr)) + with Context(): + array = np.array([[2.0, 4.0, 8.0], [16.0, 32.0, 64.0]], dtype=np.float16) + attr = DenseElementsAttr.get(array) + # CHECK: dense<{{\[}}[2.000000e+00, 4.000000e+00, 8.000000e+00], [1.600000e+01, 3.200000e+01, 6.400000e+01]]> : tensor<2x3xf16> + print(attr) + # CHECK: {{\[}}[ 2. 4. 
8.] + # CHECK: {{\[}}16. 32. 64.]] + print(np.array(attr)) # CHECK-LABEL: TEST: testGetDenseElementsF32 @run def testGetDenseElementsF32(): - with Context(): - array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32) - attr = DenseElementsAttr.get(array) - # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32> - print(attr) - # CHECK: {{\[}}[1.1 2.2 3.3] - # CHECK: {{\[}}4.4 5.5 6.6]] - print(np.array(attr)) + with Context(): + array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32) + attr = DenseElementsAttr.get(array) + # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32> + print(attr) + # CHECK: {{\[}}[1.1 2.2 3.3] + # CHECK: {{\[}}4.4 5.5 6.6]] + print(np.array(attr)) # CHECK-LABEL: TEST: testGetDenseElementsF64 @run def testGetDenseElementsF64(): - with Context(): - array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64) - attr = DenseElementsAttr.get(array) - # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64> - print(attr) - # CHECK: {{\[}}[1.1 2.2 3.3] - # CHECK: {{\[}}4.4 5.5 6.6]] - print(np.array(attr)) + with Context(): + array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64) + attr = DenseElementsAttr.get(array) + # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64> + print(attr) + # CHECK: {{\[}}[1.1 2.2 3.3] + # CHECK: {{\[}}4.4 5.5 6.6]] + print(np.array(attr)) ### 16 bit integer arrays # CHECK-LABEL: TEST: testGetDenseElementsI16Signless @run def testGetDenseElementsI16Signless(): - with Context(): - array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16) - attr = DenseElementsAttr.get(array) - # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16> - print(attr) - # CHECK: {{\[}}[1 2 3] - # CHECK: {{\[}}4 5 6]] - print(np.array(attr)) + with Context(): + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16) + attr = DenseElementsAttr.get(array) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16> + print(attr) + # CHECK: {{\[}}[1 2 3] + # CHECK: {{\[}}4 5 6]] + print(np.array(attr)) # CHECK-LABEL: TEST: testGetDenseElementsUI16Signless @run def testGetDenseElementsUI16Signless(): - with Context(): - array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16) - attr = DenseElementsAttr.get(array) - # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16> - print(attr) - # CHECK: {{\[}}[1 2 3] - # CHECK: {{\[}}4 5 6]] - print(np.array(attr)) + with Context(): + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16) + attr = DenseElementsAttr.get(array) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16> + print(attr) + # CHECK: {{\[}}[1 2 3] + # CHECK: {{\[}}4 5 6]] + print(np.array(attr)) # CHECK-LABEL: TEST: testGetDenseElementsI16 @run def testGetDenseElementsI16(): - with Context(): - array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16) - attr = DenseElementsAttr.get(array, signless=False) - # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi16> - print(attr) - # CHECK: {{\[}}[1 2 3] - # CHECK: {{\[}}4 5 6]] - print(np.array(attr)) + with Context(): + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16) + attr = DenseElementsAttr.get(array, signless=False) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi16> + print(attr) + # CHECK: {{\[}}[1 2 3] + # 
CHECK: {{\[}}4 5 6]] + print(np.array(attr)) # CHECK-LABEL: TEST: testGetDenseElementsUI16 @run def testGetDenseElementsUI16(): - with Context(): - array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16) - attr = DenseElementsAttr.get(array, signless=False) - # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui16> - print(attr) - # CHECK: {{\[}}[1 2 3] - # CHECK: {{\[}}4 5 6]] - print(np.array(attr)) + with Context(): + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16) + attr = DenseElementsAttr.get(array, signless=False) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui16> + print(attr) + # CHECK: {{\[}}[1 2 3] + # CHECK: {{\[}}4 5 6]] + print(np.array(attr)) + ### 32 bit integer arrays # CHECK-LABEL: TEST: testGetDenseElementsI32Signless @run def testGetDenseElementsI32Signless(): - with Context(): - array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) - attr = DenseElementsAttr.get(array) - # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> - print(attr) - # CHECK: {{\[}}[1 2 3] - # CHECK: {{\[}}4 5 6]] - print(np.array(attr)) + with Context(): + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) + attr = DenseElementsAttr.get(array) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> + print(attr) + # CHECK: {{\[}}[1 2 3] + # CHECK: {{\[}}4 5 6]] + print(np.array(attr)) # CHECK-LABEL: TEST: testGetDenseElementsUI32Signless @run def testGetDenseElementsUI32Signless(): - with Context(): - array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32) - attr = DenseElementsAttr.get(array) - # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> - print(attr) - # CHECK: {{\[}}[1 2 3] - # CHECK: {{\[}}4 5 6]] - print(np.array(attr)) + with Context(): + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32) + attr = DenseElementsAttr.get(array) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> + print(attr) + # CHECK: {{\[}}[1 2 3] + # CHECK: {{\[}}4 5 6]] + print(np.array(attr)) # CHECK-LABEL: TEST: testGetDenseElementsI32 @run def testGetDenseElementsI32(): - with Context(): - array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) - attr = DenseElementsAttr.get(array, signless=False) - # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32> - print(attr) - # CHECK: {{\[}}[1 2 3] - # CHECK: {{\[}}4 5 6]] - print(np.array(attr)) + with Context(): + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) + attr = DenseElementsAttr.get(array, signless=False) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32> + print(attr) + # CHECK: {{\[}}[1 2 3] + # CHECK: {{\[}}4 5 6]] + print(np.array(attr)) # CHECK-LABEL: TEST: testGetDenseElementsUI32 @run def testGetDenseElementsUI32(): - with Context(): - array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32) - attr = DenseElementsAttr.get(array, signless=False) - # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32> - print(attr) - # CHECK: {{\[}}[1 2 3] - # CHECK: {{\[}}4 5 6]] - print(np.array(attr)) + with Context(): + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32) + attr = DenseElementsAttr.get(array, signless=False) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32> + print(attr) + # CHECK: {{\[}}[1 2 3] + # CHECK: {{\[}}4 5 6]] + print(np.array(attr)) ## 64bit integer arrays # CHECK-LABEL: TEST: testGetDenseElementsI64Signless @run def testGetDenseElementsI64Signless(): - with Context(): - array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64) - attr = DenseElementsAttr.get(array) - # CHECK: 
dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64> - print(attr) - # CHECK: {{\[}}[1 2 3] - # CHECK: {{\[}}4 5 6]] - print(np.array(attr)) + with Context(): + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64) + attr = DenseElementsAttr.get(array) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64> + print(attr) + # CHECK: {{\[}}[1 2 3] + # CHECK: {{\[}}4 5 6]] + print(np.array(attr)) # CHECK-LABEL: TEST: testGetDenseElementsUI64Signless @run def testGetDenseElementsUI64Signless(): - with Context(): - array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64) - attr = DenseElementsAttr.get(array) - # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64> - print(attr) - # CHECK: {{\[}}[1 2 3] - # CHECK: {{\[}}4 5 6]] - print(np.array(attr)) + with Context(): + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64) + attr = DenseElementsAttr.get(array) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64> + print(attr) + # CHECK: {{\[}}[1 2 3] + # CHECK: {{\[}}4 5 6]] + print(np.array(attr)) # CHECK-LABEL: TEST: testGetDenseElementsI64 @run def testGetDenseElementsI64(): - with Context(): - array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64) - attr = DenseElementsAttr.get(array, signless=False) - # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64> - print(attr) - # CHECK: {{\[}}[1 2 3] - # CHECK: {{\[}}4 5 6]] - print(np.array(attr)) + with Context(): + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64) + attr = DenseElementsAttr.get(array, signless=False) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64> + print(attr) + # CHECK: {{\[}}[1 2 3] + # CHECK: {{\[}}4 5 6]] + print(np.array(attr)) # CHECK-LABEL: TEST: testGetDenseElementsUI64 @run def testGetDenseElementsUI64(): - with Context(): - array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64) - attr = DenseElementsAttr.get(array, signless=False) - # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64> - print(attr) - # CHECK: {{\[}}[1 2 3] - # CHECK: {{\[}}4 5 6]] - print(np.array(attr)) + with Context(): + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64) + attr = DenseElementsAttr.get(array, signless=False) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64> + print(attr) + # CHECK: {{\[}}[1 2 3] + # CHECK: {{\[}}4 5 6]] + print(np.array(attr)) # CHECK-LABEL: TEST: testGetDenseElementsIndex @run def testGetDenseElementsIndex(): - with Context(), Location.unknown(): - idx_type = IndexType.get() - array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64) - attr = DenseElementsAttr.get(array, type=idx_type) - # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xindex> - print(attr) - arr = np.array(attr) - # CHECK: {{\[}}[1 2 3] - # CHECK: {{\[}}4 5 6]] - print(arr) - # CHECK: True - print(arr.dtype == np.int64) - + with Context(), Location.unknown(): + idx_type = IndexType.get() + array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64) + attr = DenseElementsAttr.get(array, type=idx_type) + # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xindex> + print(attr) + arr = np.array(attr) + # CHECK: {{\[}}[1 2 3] + # CHECK: {{\[}}4 5 6]] + print(arr) + # CHECK: True + print(arr.dtype == np.int64) diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py index 6aad943..2907405 100644 --- a/mlir/test/python/ir/attributes.py +++ b/mlir/test/python/ir/attributes.py @@ -6,554 +6,550 @@ from mlir.ir import * def run(f): - print("\nTEST:", f.__name__) - f() - gc.collect() - assert Context._get_live_count() == 0 
- return f + print("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + return f # CHECK-LABEL: TEST: testParsePrint @run def testParsePrint(): - with Context() as ctx: - t = Attribute.parse('"hello"') - assert t.context is ctx - ctx = None - gc.collect() - # CHECK: "hello" - print(str(t)) - # CHECK: Attribute("hello") - print(repr(t)) + with Context() as ctx: + t = Attribute.parse('"hello"') + assert t.context is ctx + ctx = None + gc.collect() + # CHECK: "hello" + print(str(t)) + # CHECK: Attribute("hello") + print(repr(t)) # CHECK-LABEL: TEST: testParseError @run def testParseError(): - with Context(): - try: - t = Attribute.parse("BAD_ATTR_DOES_NOT_EXIST") - except MLIRError as e: - # CHECK: testParseError: < - # CHECK: Unable to parse attribute: - # CHECK: error: "BAD_ATTR_DOES_NOT_EXIST":1:1: expected attribute value - # CHECK: > - print(f"testParseError: <{e}>") - else: - print("Exception not produced") + with Context(): + try: + t = Attribute.parse("BAD_ATTR_DOES_NOT_EXIST") + except MLIRError as e: + # CHECK: testParseError: < + # CHECK: Unable to parse attribute: + # CHECK: error: "BAD_ATTR_DOES_NOT_EXIST":1:1: expected attribute value + # CHECK: > + print(f"testParseError: <{e}>") + else: + print("Exception not produced") # CHECK-LABEL: TEST: testAttrEq @run def testAttrEq(): - with Context(): - a1 = Attribute.parse('"attr1"') - a2 = Attribute.parse('"attr2"') - a3 = Attribute.parse('"attr1"') - # CHECK: a1 == a1: True - print("a1 == a1:", a1 == a1) - # CHECK: a1 == a2: False - print("a1 == a2:", a1 == a2) - # CHECK: a1 == a3: True - print("a1 == a3:", a1 == a3) - # CHECK: a1 == None: False - print("a1 == None:", a1 == None) + with Context(): + a1 = Attribute.parse('"attr1"') + a2 = Attribute.parse('"attr2"') + a3 = Attribute.parse('"attr1"') + # CHECK: a1 == a1: True + print("a1 == a1:", a1 == a1) + # CHECK: a1 == a2: False + print("a1 == a2:", a1 == a2) + # CHECK: a1 == a3: True + print("a1 == a3:", a1 == a3) + # CHECK: a1 == None: False + print("a1 == None:", a1 == None) # CHECK-LABEL: TEST: testAttrHash @run def testAttrHash(): - with Context(): - a1 = Attribute.parse('"attr1"') - a2 = Attribute.parse('"attr2"') - a3 = Attribute.parse('"attr1"') - # CHECK: hash(a1) == hash(a3): True - print("hash(a1) == hash(a3):", a1.__hash__() == a3.__hash__()) + with Context(): + a1 = Attribute.parse('"attr1"') + a2 = Attribute.parse('"attr2"') + a3 = Attribute.parse('"attr1"') + # CHECK: hash(a1) == hash(a3): True + print("hash(a1) == hash(a3):", a1.__hash__() == a3.__hash__()) - s = set() - s.add(a1) - s.add(a2) - s.add(a3) - # CHECK: len(s): 2 - print("len(s): ", len(s)) + s = set() + s.add(a1) + s.add(a2) + s.add(a3) + # CHECK: len(s): 2 + print("len(s): ", len(s)) # CHECK-LABEL: TEST: testAttrCast @run def testAttrCast(): - with Context(): - a1 = Attribute.parse('"attr1"') - a2 = Attribute(a1) - # CHECK: a1 == a2: True - print("a1 == a2:", a1 == a2) + with Context(): + a1 = Attribute.parse('"attr1"') + a2 = Attribute(a1) + # CHECK: a1 == a2: True + print("a1 == a2:", a1 == a2) # CHECK-LABEL: TEST: testAttrIsInstance @run def testAttrIsInstance(): - with Context(): - a1 = Attribute.parse("42") - a2 = Attribute.parse("[42]") - assert IntegerAttr.isinstance(a1) - assert not IntegerAttr.isinstance(a2) - assert not ArrayAttr.isinstance(a1) - assert ArrayAttr.isinstance(a2) + with Context(): + a1 = Attribute.parse("42") + a2 = Attribute.parse("[42]") + assert IntegerAttr.isinstance(a1) + assert not IntegerAttr.isinstance(a2) + assert not 
ArrayAttr.isinstance(a1) + assert ArrayAttr.isinstance(a2) # CHECK-LABEL: TEST: testAttrEqDoesNotRaise @run def testAttrEqDoesNotRaise(): - with Context(): - a1 = Attribute.parse('"attr1"') - not_an_attr = "foo" - # CHECK: False - print(a1 == not_an_attr) - # CHECK: False - print(a1 == None) - # CHECK: True - print(a1 != None) + with Context(): + a1 = Attribute.parse('"attr1"') + not_an_attr = "foo" + # CHECK: False + print(a1 == not_an_attr) + # CHECK: False + print(a1 == None) + # CHECK: True + print(a1 != None) # CHECK-LABEL: TEST: testAttrCapsule @run def testAttrCapsule(): - with Context() as ctx: - a1 = Attribute.parse('"attr1"') - # CHECK: mlir.ir.Attribute._CAPIPtr - attr_capsule = a1._CAPIPtr - print(attr_capsule) - a2 = Attribute._CAPICreate(attr_capsule) - assert a2 == a1 - assert a2.context is ctx + with Context() as ctx: + a1 = Attribute.parse('"attr1"') + # CHECK: mlir.ir.Attribute._CAPIPtr + attr_capsule = a1._CAPIPtr + print(attr_capsule) + a2 = Attribute._CAPICreate(attr_capsule) + assert a2 == a1 + assert a2.context is ctx # CHECK-LABEL: TEST: testStandardAttrCasts @run def testStandardAttrCasts(): - with Context(): - a1 = Attribute.parse('"attr1"') - astr = StringAttr(a1) - aself = StringAttr(astr) - # CHECK: Attribute("attr1") - print(repr(astr)) - try: - tillegal = StringAttr(Attribute.parse("1.0")) - except ValueError as e: - # CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64)) - print("ValueError:", e) - else: - print("Exception not produced") + with Context(): + a1 = Attribute.parse('"attr1"') + astr = StringAttr(a1) + aself = StringAttr(astr) + # CHECK: Attribute("attr1") + print(repr(astr)) + try: + tillegal = StringAttr(Attribute.parse("1.0")) + except ValueError as e: + # CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64)) + print("ValueError:", e) + else: + print("Exception not produced") # CHECK-LABEL: TEST: testAffineMapAttr @run def testAffineMapAttr(): - with Context() as ctx: - d0 = AffineDimExpr.get(0) - d1 = AffineDimExpr.get(1) - c2 = AffineConstantExpr.get(2) - map0 = AffineMap.get(2, 3, []) + with Context() as ctx: + d0 = AffineDimExpr.get(0) + d1 = AffineDimExpr.get(1) + c2 = AffineConstantExpr.get(2) + map0 = AffineMap.get(2, 3, []) - # CHECK: affine_map<(d0, d1)[s0, s1, s2] -> ()> - attr_built = AffineMapAttr.get(map0) - print(str(attr_built)) + # CHECK: affine_map<(d0, d1)[s0, s1, s2] -> ()> + attr_built = AffineMapAttr.get(map0) + print(str(attr_built)) - attr_parsed = Attribute.parse(str(attr_built)) - assert attr_built == attr_parsed + attr_parsed = Attribute.parse(str(attr_built)) + assert attr_built == attr_parsed # CHECK-LABEL: TEST: testFloatAttr @run def testFloatAttr(): - with Context(), Location.unknown(): - fattr = FloatAttr(Attribute.parse("42.0 : f32")) - # CHECK: fattr value: 42.0 - print("fattr value:", fattr.value) - - # Test factory methods. 
- # CHECK: default_get: 4.200000e+01 : f32 - print("default_get:", FloatAttr.get( - F32Type.get(), 42.0)) - # CHECK: f32_get: 4.200000e+01 : f32 - print("f32_get:", FloatAttr.get_f32(42.0)) - # CHECK: f64_get: 4.200000e+01 : f64 - print("f64_get:", FloatAttr.get_f64(42.0)) - try: - fattr_invalid = FloatAttr.get( - IntegerType.get_signless(32), 42) - except MLIRError as e: - # CHECK: Invalid attribute: - # CHECK: error: unknown: expected floating point type - print(e) - else: - print("Exception not produced") + with Context(), Location.unknown(): + fattr = FloatAttr(Attribute.parse("42.0 : f32")) + # CHECK: fattr value: 42.0 + print("fattr value:", fattr.value) + + # Test factory methods. + # CHECK: default_get: 4.200000e+01 : f32 + print("default_get:", FloatAttr.get(F32Type.get(), 42.0)) + # CHECK: f32_get: 4.200000e+01 : f32 + print("f32_get:", FloatAttr.get_f32(42.0)) + # CHECK: f64_get: 4.200000e+01 : f64 + print("f64_get:", FloatAttr.get_f64(42.0)) + try: + fattr_invalid = FloatAttr.get(IntegerType.get_signless(32), 42) + except MLIRError as e: + # CHECK: Invalid attribute: + # CHECK: error: unknown: expected floating point type + print(e) + else: + print("Exception not produced") # CHECK-LABEL: TEST: testIntegerAttr @run def testIntegerAttr(): - with Context() as ctx: - i_attr = IntegerAttr(Attribute.parse("42")) - # CHECK: i_attr value: 42 - print("i_attr value:", i_attr.value) - # CHECK: i_attr type: i64 - print("i_attr type:", i_attr.type) - si_attr = IntegerAttr(Attribute.parse("-1 : si8")) - # CHECK: si_attr value: -1 - print("si_attr value:", si_attr.value) - ui_attr = IntegerAttr(Attribute.parse("255 : ui8")) - # CHECK: ui_attr value: 255 - print("ui_attr value:", ui_attr.value) - idx_attr = IntegerAttr(Attribute.parse("-1 : index")) - # CHECK: idx_attr value: -1 - print("idx_attr value:", idx_attr.value) - - # Test factory methods. - # CHECK: default_get: 42 : i32 - print("default_get:", IntegerAttr.get( - IntegerType.get_signless(32), 42)) + with Context() as ctx: + i_attr = IntegerAttr(Attribute.parse("42")) + # CHECK: i_attr value: 42 + print("i_attr value:", i_attr.value) + # CHECK: i_attr type: i64 + print("i_attr type:", i_attr.type) + si_attr = IntegerAttr(Attribute.parse("-1 : si8")) + # CHECK: si_attr value: -1 + print("si_attr value:", si_attr.value) + ui_attr = IntegerAttr(Attribute.parse("255 : ui8")) + # CHECK: ui_attr value: 255 + print("ui_attr value:", ui_attr.value) + idx_attr = IntegerAttr(Attribute.parse("-1 : index")) + # CHECK: idx_attr value: -1 + print("idx_attr value:", idx_attr.value) + + # Test factory methods. + # CHECK: default_get: 42 : i32 + print("default_get:", IntegerAttr.get(IntegerType.get_signless(32), 42)) # CHECK-LABEL: TEST: testBoolAttr @run def testBoolAttr(): - with Context() as ctx: - battr = BoolAttr(Attribute.parse("true")) - # CHECK: iattr value: True - print("iattr value:", battr.value) + with Context() as ctx: + battr = BoolAttr(Attribute.parse("true")) + # CHECK: iattr value: True + print("iattr value:", battr.value) - # Test factory methods. - # CHECK: default_get: true - print("default_get:", BoolAttr.get(True)) + # Test factory methods. 
+ # CHECK: default_get: true + print("default_get:", BoolAttr.get(True)) # CHECK-LABEL: TEST: testFlatSymbolRefAttr @run def testFlatSymbolRefAttr(): - with Context() as ctx: - sattr = FlatSymbolRefAttr(Attribute.parse('@symbol')) - # CHECK: symattr value: symbol - print("symattr value:", sattr.value) + with Context() as ctx: + sattr = FlatSymbolRefAttr(Attribute.parse("@symbol")) + # CHECK: symattr value: symbol + print("symattr value:", sattr.value) - # Test factory methods. - # CHECK: default_get: @foobar - print("default_get:", FlatSymbolRefAttr.get("foobar")) + # Test factory methods. + # CHECK: default_get: @foobar + print("default_get:", FlatSymbolRefAttr.get("foobar")) # CHECK-LABEL: TEST: testOpaqueAttr @run def testOpaqueAttr(): - with Context() as ctx: - ctx.allow_unregistered_dialects = True - oattr = OpaqueAttr(Attribute.parse("#pytest_dummy.dummyattr<>")) - # CHECK: oattr value: pytest_dummy - print("oattr value:", oattr.dialect_namespace) - # CHECK: oattr value: b'dummyattr<>' - print("oattr value:", oattr.data) - - # Test factory methods. - # CHECK: default_get: #foobar<123> - print( - "default_get:", - OpaqueAttr.get("foobar", bytes("123", "utf-8"), NoneType.get())) + with Context() as ctx: + ctx.allow_unregistered_dialects = True + oattr = OpaqueAttr(Attribute.parse("#pytest_dummy.dummyattr<>")) + # CHECK: oattr value: pytest_dummy + print("oattr value:", oattr.dialect_namespace) + # CHECK: oattr value: b'dummyattr<>' + print("oattr value:", oattr.data) + + # Test factory methods. + # CHECK: default_get: #foobar<123> + print( + "default_get:", + OpaqueAttr.get("foobar", bytes("123", "utf-8"), NoneType.get()), + ) # CHECK-LABEL: TEST: testStringAttr @run def testStringAttr(): - with Context() as ctx: - sattr = StringAttr(Attribute.parse('"stringattr"')) - # CHECK: sattr value: stringattr - print("sattr value:", sattr.value) - # CHECK: sattr value: b'stringattr' - print("sattr value:", sattr.value_bytes) + with Context() as ctx: + sattr = StringAttr(Attribute.parse('"stringattr"')) + # CHECK: sattr value: stringattr + print("sattr value:", sattr.value) + # CHECK: sattr value: b'stringattr' + print("sattr value:", sattr.value_bytes) - # Test factory methods. - # CHECK: default_get: "foobar" - print("default_get:", StringAttr.get("foobar")) - # CHECK: typed_get: "12345" : i32 - print("typed_get:", StringAttr.get_typed( - IntegerType.get_signless(32), "12345")) + # Test factory methods. 
+ # CHECK: default_get: "foobar" + print("default_get:", StringAttr.get("foobar")) + # CHECK: typed_get: "12345" : i32 + print("typed_get:", StringAttr.get_typed(IntegerType.get_signless(32), "12345")) # CHECK-LABEL: TEST: testNamedAttr @run def testNamedAttr(): - with Context(): - a = Attribute.parse('"stringattr"') - named = a.get_named("foobar") # Note: under the small object threshold - # CHECK: attr: "stringattr" - print("attr:", named.attr) - # CHECK: name: foobar - print("name:", named.name) - # CHECK: named: NamedAttribute(foobar="stringattr") - print("named:", named) + with Context(): + a = Attribute.parse('"stringattr"') + named = a.get_named("foobar") # Note: under the small object threshold + # CHECK: attr: "stringattr" + print("attr:", named.attr) + # CHECK: name: foobar + print("name:", named.name) + # CHECK: named: NamedAttribute(foobar="stringattr") + print("named:", named) # CHECK-LABEL: TEST: testDenseIntAttr @run def testDenseIntAttr(): - with Context(): - raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>") - # CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]> - print("attr:", raw) + with Context(): + raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>") + # CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]> + print("attr:", raw) - a = DenseIntElementsAttr(raw) - assert len(a) == 6 + a = DenseIntElementsAttr(raw) + assert len(a) == 6 - # CHECK: 0 1 2 3 4 5 - for value in a: - print(value, end=" ") - print() + # CHECK: 0 1 2 3 4 5 + for value in a: + print(value, end=" ") + print() - # CHECK: i32 - print(ShapedType(a.type).element_type) + # CHECK: i32 + print(ShapedType(a.type).element_type) - raw = Attribute.parse("dense<[true,false,true,false]> : vector<4xi1>") - # CHECK: attr: dense<[true, false, true, false]> - print("attr:", raw) + raw = Attribute.parse("dense<[true,false,true,false]> : vector<4xi1>") + # CHECK: attr: dense<[true, false, true, false]> + print("attr:", raw) - a = DenseIntElementsAttr(raw) - assert len(a) == 4 + a = DenseIntElementsAttr(raw) + assert len(a) == 4 - # CHECK: 1 0 1 0 - for value in a: - print(value, end=" ") - print() + # CHECK: 1 0 1 0 + for value in a: + print(value, end=" ") + print() - # CHECK: i1 - print(ShapedType(a.type).element_type) + # CHECK: i1 + print(ShapedType(a.type).element_type) @run def testDenseArrayGetItem(): - def print_item(AttrClass, attr_asm): - attr = AttrClass(Attribute.parse(attr_asm)) - print(f"{len(attr)}: {attr[0]}, {attr[1]}") - - with Context(): - # CHECK: 2: 0, 1 - print_item(DenseBoolArrayAttr, "array") - # CHECK: 2: 2, 3 - print_item(DenseI8ArrayAttr, "array") - # CHECK: 2: 4, 5 - print_item(DenseI16ArrayAttr, "array") - # CHECK: 2: 6, 7 - print_item(DenseI32ArrayAttr, "array") - # CHECK: 2: 8, 9 - print_item(DenseI64ArrayAttr, "array") - # CHECK: 2: 1.{{0+}}, 2.{{0+}} - print_item(DenseF32ArrayAttr, "array") - # CHECK: 2: 3.{{0+}}, 4.{{0+}} - print_item(DenseF64ArrayAttr, "array") + def print_item(AttrClass, attr_asm): + attr = AttrClass(Attribute.parse(attr_asm)) + print(f"{len(attr)}: {attr[0]}, {attr[1]}") + + with Context(): + # CHECK: 2: 0, 1 + print_item(DenseBoolArrayAttr, "array") + # CHECK: 2: 2, 3 + print_item(DenseI8ArrayAttr, "array") + # CHECK: 2: 4, 5 + print_item(DenseI16ArrayAttr, "array") + # CHECK: 2: 6, 7 + print_item(DenseI32ArrayAttr, "array") + # CHECK: 2: 8, 9 + print_item(DenseI64ArrayAttr, "array") + # CHECK: 2: 1.{{0+}}, 2.{{0+}} + print_item(DenseF32ArrayAttr, "array") + # CHECK: 2: 3.{{0+}}, 4.{{0+}} + print_item(DenseF64ArrayAttr, "array") # 
CHECK-LABEL: TEST: testDenseIntAttrGetItem @run def testDenseIntAttrGetItem(): - def print_item(attr_asm): - attr = DenseIntElementsAttr(Attribute.parse(attr_asm)) - dtype = ShapedType(attr.type).element_type - try: - item = attr[0] - print(f"{dtype}:", item) - except TypeError as e: - print(f"{dtype}:", e) - - with Context(): - # CHECK: i1: 1 - print_item("dense : tensor") - # CHECK: i8: 123 - print_item("dense<123> : tensor") - # CHECK: i16: 123 - print_item("dense<123> : tensor") - # CHECK: i32: 123 - print_item("dense<123> : tensor") - # CHECK: i64: 123 - print_item("dense<123> : tensor") - # CHECK: ui8: 123 - print_item("dense<123> : tensor") - # CHECK: ui16: 123 - print_item("dense<123> : tensor") - # CHECK: ui32: 123 - print_item("dense<123> : tensor") - # CHECK: ui64: 123 - print_item("dense<123> : tensor") - # CHECK: si8: -123 - print_item("dense<-123> : tensor") - # CHECK: si16: -123 - print_item("dense<-123> : tensor") - # CHECK: si32: -123 - print_item("dense<-123> : tensor") - # CHECK: si64: -123 - print_item("dense<-123> : tensor") - - # CHECK: i7: Unsupported integer type - print_item("dense<123> : tensor") + def print_item(attr_asm): + attr = DenseIntElementsAttr(Attribute.parse(attr_asm)) + dtype = ShapedType(attr.type).element_type + try: + item = attr[0] + print(f"{dtype}:", item) + except TypeError as e: + print(f"{dtype}:", e) + + with Context(): + # CHECK: i1: 1 + print_item("dense : tensor") + # CHECK: i8: 123 + print_item("dense<123> : tensor") + # CHECK: i16: 123 + print_item("dense<123> : tensor") + # CHECK: i32: 123 + print_item("dense<123> : tensor") + # CHECK: i64: 123 + print_item("dense<123> : tensor") + # CHECK: ui8: 123 + print_item("dense<123> : tensor") + # CHECK: ui16: 123 + print_item("dense<123> : tensor") + # CHECK: ui32: 123 + print_item("dense<123> : tensor") + # CHECK: ui64: 123 + print_item("dense<123> : tensor") + # CHECK: si8: -123 + print_item("dense<-123> : tensor") + # CHECK: si16: -123 + print_item("dense<-123> : tensor") + # CHECK: si32: -123 + print_item("dense<-123> : tensor") + # CHECK: si64: -123 + print_item("dense<-123> : tensor") + + # CHECK: i7: Unsupported integer type + print_item("dense<123> : tensor") # CHECK-LABEL: TEST: testDenseFPAttr @run def testDenseFPAttr(): - with Context(): - raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>") - # CHECK: attr: dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> + with Context(): + raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>") + # CHECK: attr: dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> - print("attr:", raw) + print("attr:", raw) - a = DenseFPElementsAttr(raw) - assert len(a) == 4 + a = DenseFPElementsAttr(raw) + assert len(a) == 4 - # CHECK: 0.0 1.0 2.0 3.0 - for value in a: - print(value, end=" ") - print() + # CHECK: 0.0 1.0 2.0 3.0 + for value in a: + print(value, end=" ") + print() - # CHECK: f32 - print(ShapedType(a.type).element_type) + # CHECK: f32 + print(ShapedType(a.type).element_type) # CHECK-LABEL: TEST: testDictAttr @run def testDictAttr(): - with Context(): - dict_attr = { - 'stringattr': StringAttr.get('string'), - 'integerattr' : IntegerAttr.get( - IntegerType.get_signless(32), 42) - } + with Context(): + dict_attr = { + "stringattr": StringAttr.get("string"), + "integerattr": IntegerAttr.get(IntegerType.get_signless(32), 42), + } - a = DictAttr.get(dict_attr) + a = DictAttr.get(dict_attr) - # CHECK attr: {integerattr = 42 : i32, stringattr = "string"} - print("attr:", a) + # CHECK attr: {integerattr 
= 42 : i32, stringattr = "string"} + print("attr:", a) - assert len(a) == 2 + assert len(a) == 2 - # CHECK: 42 : i32 - print(a['integerattr']) + # CHECK: 42 : i32 + print(a["integerattr"]) - # CHECK: "string" - print(a['stringattr']) + # CHECK: "string" + print(a["stringattr"]) - # CHECK: True - print('stringattr' in a) + # CHECK: True + print("stringattr" in a) - # CHECK: False - print('not_in_dict' in a) + # CHECK: False + print("not_in_dict" in a) - # Check that exceptions are raised as expected. - try: - _ = a['does_not_exist'] - except KeyError: - pass - else: - assert False, "Exception not produced" + # Check that exceptions are raised as expected. + try: + _ = a["does_not_exist"] + except KeyError: + pass + else: + assert False, "Exception not produced" - try: - _ = a[42] - except IndexError: - pass - else: - assert False, "expected IndexError on accessing an out-of-bounds attribute" + try: + _ = a[42] + except IndexError: + pass + else: + assert False, "expected IndexError on accessing an out-of-bounds attribute" - # CHECK "empty: {}" - print("empty: ", DictAttr.get()) + # CHECK "empty: {}" + print("empty: ", DictAttr.get()) # CHECK-LABEL: TEST: testTypeAttr @run def testTypeAttr(): - with Context(): - raw = Attribute.parse("vector<4xf32>") - # CHECK: attr: vector<4xf32> - print("attr:", raw) - type_attr = TypeAttr(raw) - # CHECK: f32 - print(ShapedType(type_attr.value).element_type) + with Context(): + raw = Attribute.parse("vector<4xf32>") + # CHECK: attr: vector<4xf32> + print("attr:", raw) + type_attr = TypeAttr(raw) + # CHECK: f32 + print(ShapedType(type_attr.value).element_type) # CHECK-LABEL: TEST: testArrayAttr @run def testArrayAttr(): - with Context(): - raw = Attribute.parse("[42, true, vector<4xf32>]") - # CHECK: attr: [42, true, vector<4xf32>] - print("raw attr:", raw) - # CHECK: - 42 - # CHECK: - true - # CHECK: - vector<4xf32> - for attr in ArrayAttr(raw): - print("- ", attr) - - with Context(): - intAttr = Attribute.parse("42") - vecAttr = Attribute.parse("vector<4xf32>") - boolAttr = BoolAttr.get(True) - raw = ArrayAttr.get([vecAttr, boolAttr, intAttr]) - # CHECK: attr: [vector<4xf32>, true, 42] - print("raw attr:", raw) - # CHECK: - vector<4xf32> - # CHECK: - true - # CHECK: - 42 - arr = ArrayAttr(raw) - for attr in arr: - print("- ", attr) - # CHECK: attr[0]: vector<4xf32> - print("attr[0]:", arr[0]) - # CHECK: attr[1]: true - print("attr[1]:", arr[1]) - # CHECK: attr[2]: 42 - print("attr[2]:", arr[2]) - try: - print("attr[3]:", arr[3]) - except IndexError as e: - # CHECK: Error: ArrayAttribute index out of range - print("Error: ", e) - with Context(): + with Context(): + raw = Attribute.parse("[42, true, vector<4xf32>]") + # CHECK: attr: [42, true, vector<4xf32>] + print("raw attr:", raw) + # CHECK: - 42 + # CHECK: - true + # CHECK: - vector<4xf32> + for attr in ArrayAttr(raw): + print("- ", attr) + + with Context(): + intAttr = Attribute.parse("42") + vecAttr = Attribute.parse("vector<4xf32>") + boolAttr = BoolAttr.get(True) + raw = ArrayAttr.get([vecAttr, boolAttr, intAttr]) + # CHECK: attr: [vector<4xf32>, true, 42] + print("raw attr:", raw) + # CHECK: - vector<4xf32> + # CHECK: - true + # CHECK: - 42 + arr = ArrayAttr(raw) + for attr in arr: + print("- ", attr) + # CHECK: attr[0]: vector<4xf32> + print("attr[0]:", arr[0]) + # CHECK: attr[1]: true + print("attr[1]:", arr[1]) + # CHECK: attr[2]: 42 + print("attr[2]:", arr[2]) try: - ArrayAttr.get([None]) - except RuntimeError as e: - # CHECK: Error: Invalid attribute (None?) 
when attempting to create an ArrayAttribute - print("Error: ", e) - try: - ArrayAttr.get([42]) - except RuntimeError as e: - # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute - print("Error: ", e) - - with Context(): - array = ArrayAttr.get([StringAttr.get("a"), StringAttr.get("b")]) - array = array + [StringAttr.get("c")] - # CHECK: concat: ["a", "b", "c"] - print("concat: ", array) + print("attr[3]:", arr[3]) + except IndexError as e: + # CHECK: Error: ArrayAttribute index out of range + print("Error: ", e) + with Context(): + try: + ArrayAttr.get([None]) + except RuntimeError as e: + # CHECK: Error: Invalid attribute (None?) when attempting to create an ArrayAttribute + print("Error: ", e) + try: + ArrayAttr.get([42]) + except RuntimeError as e: + # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute + print("Error: ", e) + + with Context(): + array = ArrayAttr.get([StringAttr.get("a"), StringAttr.get("b")]) + array = array + [StringAttr.get("c")] + # CHECK: concat: ["a", "b", "c"] + print("concat: ", array) # CHECK-LABEL: TEST: testStridedLayoutAttr @run def testStridedLayoutAttr(): - with Context(): - attr = StridedLayoutAttr.get(42, [5, 7, 13]) - # CHECK: strided<[5, 7, 13], offset: 42> - print(attr) - # CHECK: 42 - print(attr.offset) - # CHECK: 3 - print(len(attr.strides)) - # CHECK: 5 - print(attr.strides[0]) - # CHECK: 7 - print(attr.strides[1]) - # CHECK: 13 - print(attr.strides[2]) - - attr = StridedLayoutAttr.get_fully_dynamic(3) - dynamic = ShapedType.get_dynamic_stride_or_offset() - # CHECK: strided<[?, ?, ?], offset: ?> - print(attr) - # CHECK: offset is dynamic: True - print(f"offset is dynamic: {attr.offset == dynamic}") - # CHECK: rank: 3 - print(f"rank: {len(attr.strides)}") - # CHECK: strides are dynamic: [True, True, True] - print(f"strides are dynamic: {[s == dynamic for s in attr.strides]}") + with Context(): + attr = StridedLayoutAttr.get(42, [5, 7, 13]) + # CHECK: strided<[5, 7, 13], offset: 42> + print(attr) + # CHECK: 42 + print(attr.offset) + # CHECK: 3 + print(len(attr.strides)) + # CHECK: 5 + print(attr.strides[0]) + # CHECK: 7 + print(attr.strides[1]) + # CHECK: 13 + print(attr.strides[2]) + + attr = StridedLayoutAttr.get_fully_dynamic(3) + dynamic = ShapedType.get_dynamic_stride_or_offset() + # CHECK: strided<[?, ?, ?], offset: ?> + print(attr) + # CHECK: offset is dynamic: True + print(f"offset is dynamic: {attr.offset == dynamic}") + # CHECK: rank: 3 + print(f"rank: {len(attr.strides)}") + # CHECK: strides are dynamic: [True, True, True] + print(f"strides are dynamic: {[s == dynamic for s in attr.strides]}") diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py index e929d79..8b4d946 100644 --- a/mlir/test/python/ir/blocks.py +++ b/mlir/test/python/ir/blocks.py @@ -10,11 +10,11 @@ from mlir.dialects import func def run(f): - print("\nTEST:", f.__name__) - f() - gc.collect() - assert Context._get_live_count() == 0 - return f + print("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + return f # CHECK-LABEL: TEST: testBlockCreation @@ -26,60 +26,66 @@ def run(f): # CHECK: return @run def testBlockCreation(): - with Context() as ctx, Location.unknown(): - module = builtin.ModuleOp() - with InsertionPoint(module.body): - f_type = FunctionType.get( - [IntegerType.get_signless(32), - IntegerType.get_signless(16)], []) - f_op = func.FuncOp("test", f_type) - entry_block = f_op.add_entry_block([Location.name("arg0"), Location.name("arg1")]) - i32_arg, i16_arg 
= entry_block.arguments - successor_block = entry_block.create_after(i32_arg.type, arg_locs=[Location.name("successor")]) - with InsertionPoint(successor_block) as successor_ip: - assert successor_ip.block == successor_block - func.ReturnOp([]) - middle_block = successor_block.create_before(i16_arg.type, arg_locs=[Location.name("middle")]) - - with InsertionPoint(entry_block) as entry_ip: - assert entry_ip.block == entry_block - cf.BranchOp([i16_arg], dest=middle_block) - - with InsertionPoint(middle_block) as middle_ip: - assert middle_ip.block == middle_block - cf.BranchOp([i32_arg], dest=successor_block) - module.print(enable_debug_info=True) - # Ensure region back references are coherent. - assert entry_block.region == middle_block.region == successor_block.region + with Context() as ctx, Location.unknown(): + module = builtin.ModuleOp() + with InsertionPoint(module.body): + f_type = FunctionType.get( + [IntegerType.get_signless(32), IntegerType.get_signless(16)], [] + ) + f_op = func.FuncOp("test", f_type) + entry_block = f_op.add_entry_block( + [Location.name("arg0"), Location.name("arg1")] + ) + i32_arg, i16_arg = entry_block.arguments + successor_block = entry_block.create_after( + i32_arg.type, arg_locs=[Location.name("successor")] + ) + with InsertionPoint(successor_block) as successor_ip: + assert successor_ip.block == successor_block + func.ReturnOp([]) + middle_block = successor_block.create_before( + i16_arg.type, arg_locs=[Location.name("middle")] + ) + + with InsertionPoint(entry_block) as entry_ip: + assert entry_ip.block == entry_block + cf.BranchOp([i16_arg], dest=middle_block) + + with InsertionPoint(middle_block) as middle_ip: + assert middle_ip.block == middle_block + cf.BranchOp([i32_arg], dest=successor_block) + module.print(enable_debug_info=True) + # Ensure region back references are coherent. 
+ assert entry_block.region == middle_block.region == successor_block.region # CHECK-LABEL: TEST: testBlockCreationArgLocs @run def testBlockCreationArgLocs(): - with Context() as ctx: - ctx.allow_unregistered_dialects = True - f32 = F32Type.get() - op = Operation.create("test", regions=1, loc=Location.unknown()) - blocks = op.regions[0].blocks - - with Location.name("default_loc"): - blocks.append(f32) - blocks.append() - # CHECK: ^bb0(%{{.+}}: f32 loc("default_loc")): - # CHECK-NEXT: ^bb1: - op.print(enable_debug_info=True) - - try: - blocks.append(f32) - except RuntimeError as err: - # CHECK: Missing loc: An MLIR function requires a Location but none was provided - print("Missing loc:", err) - - try: - blocks.append(f32, f32, arg_locs=[Location.unknown()]) - except ValueError as err: - # CHECK: Wrong loc count: Expected 2 locations, got: 1 - print("Wrong loc count:", err) + with Context() as ctx: + ctx.allow_unregistered_dialects = True + f32 = F32Type.get() + op = Operation.create("test", regions=1, loc=Location.unknown()) + blocks = op.regions[0].blocks + + with Location.name("default_loc"): + blocks.append(f32) + blocks.append() + # CHECK: ^bb0(%{{.+}}: f32 loc("default_loc")): + # CHECK-NEXT: ^bb1: + op.print(enable_debug_info=True) + + try: + blocks.append(f32) + except RuntimeError as err: + # CHECK: Missing loc: An MLIR function requires a Location but none was provided + print("Missing loc:", err) + + try: + blocks.append(f32, f32, arg_locs=[Location.unknown()]) + except ValueError as err: + # CHECK: Wrong loc count: Expected 2 locations, got: 1 + print("Wrong loc count:", err) # CHECK-LABEL: TEST: testFirstBlockCreation @@ -87,19 +93,20 @@ def testBlockCreationArgLocs(): # CHECK: return @run def testFirstBlockCreation(): - with Context() as ctx, Location.unknown(): - module = builtin.ModuleOp() - f32 = F32Type.get() - with InsertionPoint(module.body): - f = func.FuncOp("test", ([f32], [])) - entry_block = Block.create_at_start(f.operation.regions[0], - [f32], [Location.name("arg_loc")]) - with InsertionPoint(entry_block): - func.ReturnOp([]) - - module.print(enable_debug_info=True) - assert module.verify() - assert f.body.blocks[0] == entry_block + with Context() as ctx, Location.unknown(): + module = builtin.ModuleOp() + f32 = F32Type.get() + with InsertionPoint(module.body): + f = func.FuncOp("test", ([f32], [])) + entry_block = Block.create_at_start( + f.operation.regions[0], [f32], [Location.name("arg_loc")] + ) + with InsertionPoint(entry_block): + func.ReturnOp([]) + + module.print(enable_debug_info=True) + assert module.verify() + assert f.body.blocks[0] == entry_block # CHECK-LABEL: TEST: testBlockMove @@ -109,32 +116,32 @@ def testFirstBlockCreation(): # CHECK: }) : () -> f32 @run def testBlockMove(): - with Context() as ctx, Location.unknown(): - ctx.allow_unregistered_dialects = True - module = Module.create() - f32 = F32Type.get() - with InsertionPoint(module.body): - dummy = Operation.create("dummy", regions=1) - block = Block.create_at_start(dummy.operation.regions[0], [f32]) - with InsertionPoint(block): - ret_op = Operation.create("ret", operands=[block.arguments[0]]) - realop = Operation.create("realop", - results=[r.type for r in ret_op.operands], - regions=1) - block.append_to(realop.operation.regions[0]) - dummy.operation.erase() - print(module) + with Context() as ctx, Location.unknown(): + ctx.allow_unregistered_dialects = True + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + dummy = Operation.create("dummy", 
regions=1) + block = Block.create_at_start(dummy.operation.regions[0], [f32]) + with InsertionPoint(block): + ret_op = Operation.create("ret", operands=[block.arguments[0]]) + realop = Operation.create( + "realop", results=[r.type for r in ret_op.operands], regions=1 + ) + block.append_to(realop.operation.regions[0]) + dummy.operation.erase() + print(module) # CHECK-LABEL: TEST: testBlockHash @run def testBlockHash(): - with Context() as ctx, Location.unknown(): - ctx.allow_unregistered_dialects = True - module = Module.create() - f32 = F32Type.get() - with InsertionPoint(module.body): - dummy = Operation.create("dummy", regions=1) - block1 = Block.create_at_start(dummy.operation.regions[0], [f32]) - block2 = Block.create_at_start(dummy.operation.regions[0], [f32]) - assert hash(block1) != hash(block2) + with Context() as ctx, Location.unknown(): + ctx.allow_unregistered_dialects = True + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + dummy = Operation.create("dummy", regions=1) + block1 = Block.create_at_start(dummy.operation.regions[0], [f32]) + block2 = Block.create_at_start(dummy.operation.regions[0], [f32]) + assert hash(block1) != hash(block2) diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py index 19e21ff..fc484a5 100644 --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py @@ -5,246 +5,246 @@ from mlir.ir import * def run(f): - print("\nTEST:", f.__name__) - f() - gc.collect() - assert Context._get_live_count() == 0 - return f + print("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + return f # CHECK-LABEL: TEST: testParsePrint @run def testParsePrint(): - ctx = Context() - t = Type.parse("i32", ctx) - assert t.context is ctx - ctx = None - gc.collect() - # CHECK: i32 - print(str(t)) - # CHECK: Type(i32) - print(repr(t)) + ctx = Context() + t = Type.parse("i32", ctx) + assert t.context is ctx + ctx = None + gc.collect() + # CHECK: i32 + print(str(t)) + # CHECK: Type(i32) + print(repr(t)) # CHECK-LABEL: TEST: testParseError @run def testParseError(): - ctx = Context() - try: - t = Type.parse("BAD_TYPE_DOES_NOT_EXIST", ctx) - except MLIRError as e: - # CHECK: testParseError: < - # CHECK: Unable to parse type: - # CHECK: error: "BAD_TYPE_DOES_NOT_EXIST":1:1: expected non-function type - # CHECK: > - print(f"testParseError: <{e}>") - else: - print("Exception not produced") + ctx = Context() + try: + t = Type.parse("BAD_TYPE_DOES_NOT_EXIST", ctx) + except MLIRError as e: + # CHECK: testParseError: < + # CHECK: Unable to parse type: + # CHECK: error: "BAD_TYPE_DOES_NOT_EXIST":1:1: expected non-function type + # CHECK: > + print(f"testParseError: <{e}>") + else: + print("Exception not produced") # CHECK-LABEL: TEST: testTypeEq @run def testTypeEq(): - ctx = Context() - t1 = Type.parse("i32", ctx) - t2 = Type.parse("f32", ctx) - t3 = Type.parse("i32", ctx) - # CHECK: t1 == t1: True - print("t1 == t1:", t1 == t1) - # CHECK: t1 == t2: False - print("t1 == t2:", t1 == t2) - # CHECK: t1 == t3: True - print("t1 == t3:", t1 == t3) - # CHECK: t1 == None: False - print("t1 == None:", t1 == None) + ctx = Context() + t1 = Type.parse("i32", ctx) + t2 = Type.parse("f32", ctx) + t3 = Type.parse("i32", ctx) + # CHECK: t1 == t1: True + print("t1 == t1:", t1 == t1) + # CHECK: t1 == t2: False + print("t1 == t2:", t1 == t2) + # CHECK: t1 == t3: True + print("t1 == t3:", t1 == t3) + # CHECK: t1 == None: False + print("t1 == None:", t1 == None) # CHECK-LABEL: TEST: 
testTypeHash @run def testTypeHash(): - ctx = Context() - t1 = Type.parse("i32", ctx) - t2 = Type.parse("f32", ctx) - t3 = Type.parse("i32", ctx) + ctx = Context() + t1 = Type.parse("i32", ctx) + t2 = Type.parse("f32", ctx) + t3 = Type.parse("i32", ctx) - # CHECK: hash(t1) == hash(t3): True - print("hash(t1) == hash(t3):", t1.__hash__() == t3.__hash__()) + # CHECK: hash(t1) == hash(t3): True + print("hash(t1) == hash(t3):", t1.__hash__() == t3.__hash__()) - s = set() - s.add(t1) - s.add(t2) - s.add(t3) - # CHECK: len(s): 2 - print("len(s): ", len(s)) + s = set() + s.add(t1) + s.add(t2) + s.add(t3) + # CHECK: len(s): 2 + print("len(s): ", len(s)) # CHECK-LABEL: TEST: testTypeCast @run def testTypeCast(): - ctx = Context() - t1 = Type.parse("i32", ctx) - t2 = Type(t1) - # CHECK: t1 == t2: True - print("t1 == t2:", t1 == t2) + ctx = Context() + t1 = Type.parse("i32", ctx) + t2 = Type(t1) + # CHECK: t1 == t2: True + print("t1 == t2:", t1 == t2) # CHECK-LABEL: TEST: testTypeIsInstance @run def testTypeIsInstance(): - ctx = Context() - t1 = Type.parse("i32", ctx) - t2 = Type.parse("f32", ctx) - # CHECK: True - print(IntegerType.isinstance(t1)) - # CHECK: False - print(F32Type.isinstance(t1)) - # CHECK: True - print(F32Type.isinstance(t2)) + ctx = Context() + t1 = Type.parse("i32", ctx) + t2 = Type.parse("f32", ctx) + # CHECK: True + print(IntegerType.isinstance(t1)) + # CHECK: False + print(F32Type.isinstance(t1)) + # CHECK: True + print(F32Type.isinstance(t2)) # CHECK-LABEL: TEST: testTypeEqDoesNotRaise @run def testTypeEqDoesNotRaise(): - ctx = Context() - t1 = Type.parse("i32", ctx) - not_a_type = "foo" - # CHECK: False - print(t1 == not_a_type) - # CHECK: False - print(t1 == None) - # CHECK: True - print(t1 != None) + ctx = Context() + t1 = Type.parse("i32", ctx) + not_a_type = "foo" + # CHECK: False + print(t1 == not_a_type) + # CHECK: False + print(t1 == None) + # CHECK: True + print(t1 != None) # CHECK-LABEL: TEST: testTypeCapsule @run def testTypeCapsule(): - with Context() as ctx: - t1 = Type.parse("i32", ctx) - # CHECK: mlir.ir.Type._CAPIPtr - type_capsule = t1._CAPIPtr - print(type_capsule) - t2 = Type._CAPICreate(type_capsule) - assert t2 == t1 - assert t2.context is ctx + with Context() as ctx: + t1 = Type.parse("i32", ctx) + # CHECK: mlir.ir.Type._CAPIPtr + type_capsule = t1._CAPIPtr + print(type_capsule) + t2 = Type._CAPICreate(type_capsule) + assert t2 == t1 + assert t2.context is ctx # CHECK-LABEL: TEST: testStandardTypeCasts @run def testStandardTypeCasts(): - ctx = Context() - t1 = Type.parse("i32", ctx) - tint = IntegerType(t1) - tself = IntegerType(tint) - # CHECK: Type(i32) - print(repr(tint)) - try: - tillegal = IntegerType(Type.parse("f32", ctx)) - except ValueError as e: - # CHECK: ValueError: Cannot cast type to IntegerType (from Type(f32)) - print("ValueError:", e) - else: - print("Exception not produced") + ctx = Context() + t1 = Type.parse("i32", ctx) + tint = IntegerType(t1) + tself = IntegerType(tint) + # CHECK: Type(i32) + print(repr(tint)) + try: + tillegal = IntegerType(Type.parse("f32", ctx)) + except ValueError as e: + # CHECK: ValueError: Cannot cast type to IntegerType (from Type(f32)) + print("ValueError:", e) + else: + print("Exception not produced") # CHECK-LABEL: TEST: testIntegerType @run def testIntegerType(): - with Context() as ctx: - i32 = IntegerType(Type.parse("i32")) - # CHECK: i32 width: 32 - print("i32 width:", i32.width) - # CHECK: i32 signless: True - print("i32 signless:", i32.is_signless) - # CHECK: i32 signed: False - print("i32 signed:", 
i32.is_signed) - # CHECK: i32 unsigned: False - print("i32 unsigned:", i32.is_unsigned) - - s32 = IntegerType(Type.parse("si32")) - # CHECK: s32 signless: False - print("s32 signless:", s32.is_signless) - # CHECK: s32 signed: True - print("s32 signed:", s32.is_signed) - # CHECK: s32 unsigned: False - print("s32 unsigned:", s32.is_unsigned) - - u32 = IntegerType(Type.parse("ui32")) - # CHECK: u32 signless: False - print("u32 signless:", u32.is_signless) - # CHECK: u32 signed: False - print("u32 signed:", u32.is_signed) - # CHECK: u32 unsigned: True - print("u32 unsigned:", u32.is_unsigned) - - # CHECK: signless: i16 - print("signless:", IntegerType.get_signless(16)) - # CHECK: signed: si8 - print("signed:", IntegerType.get_signed(8)) - # CHECK: unsigned: ui64 - print("unsigned:", IntegerType.get_unsigned(64)) + with Context() as ctx: + i32 = IntegerType(Type.parse("i32")) + # CHECK: i32 width: 32 + print("i32 width:", i32.width) + # CHECK: i32 signless: True + print("i32 signless:", i32.is_signless) + # CHECK: i32 signed: False + print("i32 signed:", i32.is_signed) + # CHECK: i32 unsigned: False + print("i32 unsigned:", i32.is_unsigned) + + s32 = IntegerType(Type.parse("si32")) + # CHECK: s32 signless: False + print("s32 signless:", s32.is_signless) + # CHECK: s32 signed: True + print("s32 signed:", s32.is_signed) + # CHECK: s32 unsigned: False + print("s32 unsigned:", s32.is_unsigned) + + u32 = IntegerType(Type.parse("ui32")) + # CHECK: u32 signless: False + print("u32 signless:", u32.is_signless) + # CHECK: u32 signed: False + print("u32 signed:", u32.is_signed) + # CHECK: u32 unsigned: True + print("u32 unsigned:", u32.is_unsigned) + + # CHECK: signless: i16 + print("signless:", IntegerType.get_signless(16)) + # CHECK: signed: si8 + print("signed:", IntegerType.get_signed(8)) + # CHECK: unsigned: ui64 + print("unsigned:", IntegerType.get_unsigned(64)) # CHECK-LABEL: TEST: testIndexType @run def testIndexType(): - with Context() as ctx: - # CHECK: index type: index - print("index type:", IndexType.get()) + with Context() as ctx: + # CHECK: index type: index + print("index type:", IndexType.get()) # CHECK-LABEL: TEST: testFloatType @run def testFloatType(): - with Context(): - # CHECK: float: f8E4M3FN - print("float:", Float8E4M3FNType.get()) - # CHECK: float: f8E5M2 - print("float:", Float8E5M2Type.get()) - # CHECK: float: f8E5M2FNUZ - print("float:", Float8E5M2FNUZType.get()) - # CHECK: float: f8E4M3FNUZ - print("float:", Float8E4M3FNUZType.get()) - # CHECK: float: f8E4M3B11FNUZ - print("float:", Float8E4M3B11FNUZType.get()) - # CHECK: float: bf16 - print("float:", BF16Type.get()) - # CHECK: float: f16 - print("float:", F16Type.get()) - # CHECK: float: f32 - print("float:", F32Type.get()) - # CHECK: float: f64 - print("float:", F64Type.get()) + with Context(): + # CHECK: float: f8E4M3FN + print("float:", Float8E4M3FNType.get()) + # CHECK: float: f8E5M2 + print("float:", Float8E5M2Type.get()) + # CHECK: float: f8E5M2FNUZ + print("float:", Float8E5M2FNUZType.get()) + # CHECK: float: f8E4M3FNUZ + print("float:", Float8E4M3FNUZType.get()) + # CHECK: float: f8E4M3B11FNUZ + print("float:", Float8E4M3B11FNUZType.get()) + # CHECK: float: bf16 + print("float:", BF16Type.get()) + # CHECK: float: f16 + print("float:", F16Type.get()) + # CHECK: float: f32 + print("float:", F32Type.get()) + # CHECK: float: f64 + print("float:", F64Type.get()) # CHECK-LABEL: TEST: testNoneType @run def testNoneType(): - with Context(): - # CHECK: none type: none - print("none type:", NoneType.get()) + with Context(): 
+ # CHECK: none type: none + print("none type:", NoneType.get()) # CHECK-LABEL: TEST: testComplexType @run def testComplexType(): - with Context() as ctx: - complex_i32 = ComplexType(Type.parse("complex")) - # CHECK: complex type element: i32 - print("complex type element:", complex_i32.element_type) + with Context() as ctx: + complex_i32 = ComplexType(Type.parse("complex")) + # CHECK: complex type element: i32 + print("complex type element:", complex_i32.element_type) - f32 = F32Type.get() - # CHECK: complex type: complex - print("complex type:", ComplexType.get(f32)) + f32 = F32Type.get() + # CHECK: complex type: complex + print("complex type:", ComplexType.get(f32)) - index = IndexType.get() - try: - complex_invalid = ComplexType.get(index) - except ValueError as e: - # CHECK: invalid 'Type(index)' and expected floating point or integer type. - print(e) - else: - print("Exception not produced") + index = IndexType.get() + try: + complex_invalid = ComplexType.get(index) + except ValueError as e: + # CHECK: invalid 'Type(index)' and expected floating point or integer type. + print(e) + else: + print("Exception not produced") # CHECK-LABEL: TEST: testConcreteShapedType @@ -253,27 +253,26 @@ def testComplexType(): # shaped type. The class hierarchy is preserved on the python side. @run def testConcreteShapedType(): - with Context() as ctx: - vector = VectorType(Type.parse("vector<2x3xf32>")) - # CHECK: element type: f32 - print("element type:", vector.element_type) - # CHECK: whether the given shaped type is ranked: True - print("whether the given shaped type is ranked:", vector.has_rank) - # CHECK: rank: 2 - print("rank:", vector.rank) - # CHECK: whether the shaped type has a static shape: True - print("whether the shaped type has a static shape:", - vector.has_static_shape) - # CHECK: whether the dim-th dimension is dynamic: False - print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0)) - # CHECK: dim size: 3 - print("dim size:", vector.get_dim_size(1)) - # CHECK: is_dynamic_size: False - print("is_dynamic_size:", vector.is_dynamic_size(3)) - # CHECK: is_dynamic_stride_or_offset: False - print("is_dynamic_stride_or_offset:", vector.is_dynamic_stride_or_offset(1)) - # CHECK: isinstance(ShapedType): True - print("isinstance(ShapedType):", isinstance(vector, ShapedType)) + with Context() as ctx: + vector = VectorType(Type.parse("vector<2x3xf32>")) + # CHECK: element type: f32 + print("element type:", vector.element_type) + # CHECK: whether the given shaped type is ranked: True + print("whether the given shaped type is ranked:", vector.has_rank) + # CHECK: rank: 2 + print("rank:", vector.rank) + # CHECK: whether the shaped type has a static shape: True + print("whether the shaped type has a static shape:", vector.has_static_shape) + # CHECK: whether the dim-th dimension is dynamic: False + print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0)) + # CHECK: dim size: 3 + print("dim size:", vector.get_dim_size(1)) + # CHECK: is_dynamic_size: False + print("is_dynamic_size:", vector.is_dynamic_size(3)) + # CHECK: is_dynamic_stride_or_offset: False + print("is_dynamic_stride_or_offset:", vector.is_dynamic_stride_or_offset(1)) + # CHECK: isinstance(ShapedType): True + print("isinstance(ShapedType):", isinstance(vector, ShapedType)) # CHECK-LABEL: TEST: testAbstractShapedType @@ -281,321 +280,322 @@ def testConcreteShapedType(): # shaped type (using vector as an example). 
@run def testAbstractShapedType(): - ctx = Context() - vector = ShapedType(Type.parse("vector<2x3xf32>", ctx)) - # CHECK: element type: f32 - print("element type:", vector.element_type) + ctx = Context() + vector = ShapedType(Type.parse("vector<2x3xf32>", ctx)) + # CHECK: element type: f32 + print("element type:", vector.element_type) # CHECK-LABEL: TEST: testVectorType @run def testVectorType(): - with Context(), Location.unknown(): - f32 = F32Type.get() - shape = [2, 3] - # CHECK: vector type: vector<2x3xf32> - print("vector type:", VectorType.get(shape, f32)) - - none = NoneType.get() - try: - vector_invalid = VectorType.get(shape, none) - except MLIRError as e: - # CHECK: Invalid type: - # CHECK: error: unknown: vector elements must be int/index/float type but got 'none' - print(e) - else: - print("Exception not produced") + with Context(), Location.unknown(): + f32 = F32Type.get() + shape = [2, 3] + # CHECK: vector type: vector<2x3xf32> + print("vector type:", VectorType.get(shape, f32)) + + none = NoneType.get() + try: + vector_invalid = VectorType.get(shape, none) + except MLIRError as e: + # CHECK: Invalid type: + # CHECK: error: unknown: vector elements must be int/index/float type but got 'none' + print(e) + else: + print("Exception not produced") # CHECK-LABEL: TEST: testRankedTensorType @run def testRankedTensorType(): - with Context(), Location.unknown(): - f32 = F32Type.get() - shape = [2, 3] - loc = Location.unknown() - # CHECK: ranked tensor type: tensor<2x3xf32> - print("ranked tensor type:", RankedTensorType.get(shape, f32)) - - none = NoneType.get() - try: - tensor_invalid = RankedTensorType.get(shape, none) - except MLIRError as e: - # CHECK: Invalid type: - # CHECK: error: unknown: invalid tensor element type: 'none' - print(e) - else: - print("Exception not produced") - - # Encoding should be None. - assert RankedTensorType.get(shape, f32).encoding is None - - tensor = RankedTensorType.get(shape, f32) - assert tensor.shape == shape + with Context(), Location.unknown(): + f32 = F32Type.get() + shape = [2, 3] + loc = Location.unknown() + # CHECK: ranked tensor type: tensor<2x3xf32> + print("ranked tensor type:", RankedTensorType.get(shape, f32)) + + none = NoneType.get() + try: + tensor_invalid = RankedTensorType.get(shape, none) + except MLIRError as e: + # CHECK: Invalid type: + # CHECK: error: unknown: invalid tensor element type: 'none' + print(e) + else: + print("Exception not produced") + + # Encoding should be None. + assert RankedTensorType.get(shape, f32).encoding is None + + tensor = RankedTensorType.get(shape, f32) + assert tensor.shape == shape # CHECK-LABEL: TEST: testUnrankedTensorType @run def testUnrankedTensorType(): - with Context(), Location.unknown(): - f32 = F32Type.get() - loc = Location.unknown() - unranked_tensor = UnrankedTensorType.get(f32) - # CHECK: unranked tensor type: tensor<*xf32> - print("unranked tensor type:", unranked_tensor) - try: - invalid_rank = unranked_tensor.rank - except ValueError as e: - # CHECK: calling this method requires that the type has a rank. - print(e) - else: - print("Exception not produced") - try: - invalid_is_dynamic_dim = unranked_tensor.is_dynamic_dim(0) - except ValueError as e: - # CHECK: calling this method requires that the type has a rank. - print(e) - else: - print("Exception not produced") - try: - invalid_get_dim_size = unranked_tensor.get_dim_size(1) - except ValueError as e: - # CHECK: calling this method requires that the type has a rank. 
- print(e) - else: - print("Exception not produced") - - none = NoneType.get() - try: - tensor_invalid = UnrankedTensorType.get(none) - except MLIRError as e: - # CHECK: Invalid type: - # CHECK: error: unknown: invalid tensor element type: 'none' - print(e) - else: - print("Exception not produced") + with Context(), Location.unknown(): + f32 = F32Type.get() + loc = Location.unknown() + unranked_tensor = UnrankedTensorType.get(f32) + # CHECK: unranked tensor type: tensor<*xf32> + print("unranked tensor type:", unranked_tensor) + try: + invalid_rank = unranked_tensor.rank + except ValueError as e: + # CHECK: calling this method requires that the type has a rank. + print(e) + else: + print("Exception not produced") + try: + invalid_is_dynamic_dim = unranked_tensor.is_dynamic_dim(0) + except ValueError as e: + # CHECK: calling this method requires that the type has a rank. + print(e) + else: + print("Exception not produced") + try: + invalid_get_dim_size = unranked_tensor.get_dim_size(1) + except ValueError as e: + # CHECK: calling this method requires that the type has a rank. + print(e) + else: + print("Exception not produced") + + none = NoneType.get() + try: + tensor_invalid = UnrankedTensorType.get(none) + except MLIRError as e: + # CHECK: Invalid type: + # CHECK: error: unknown: invalid tensor element type: 'none' + print(e) + else: + print("Exception not produced") # CHECK-LABEL: TEST: testMemRefType @run def testMemRefType(): - with Context(), Location.unknown(): - f32 = F32Type.get() - shape = [2, 3] - loc = Location.unknown() - memref = MemRefType.get(shape, f32, memory_space=Attribute.parse("2")) - # CHECK: memref type: memref<2x3xf32, 2> - print("memref type:", memref) - # CHECK: memref layout: affine_map<(d0, d1) -> (d0, d1)> - print("memref layout:", memref.layout) - # CHECK: memref affine map: (d0, d1) -> (d0, d1) - print("memref affine map:", memref.affine_map) - # CHECK: memory space: 2 - print("memory space:", memref.memory_space) - - layout = AffineMapAttr.get(AffineMap.get_permutation([1, 0])) - memref_layout = MemRefType.get(shape, f32, layout=layout) - # CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>> - print("memref type:", memref_layout) - # CHECK: memref layout: affine_map<(d0, d1) -> (d1, d0)> - print("memref layout:", memref_layout.layout) - # CHECK: memref affine map: (d0, d1) -> (d1, d0) - print("memref affine map:", memref_layout.affine_map) - # CHECK: memory space: <> - print("memory space:", memref_layout.memory_space) - - none = NoneType.get() - try: - memref_invalid = MemRefType.get(shape, none) - except MLIRError as e: - # CHECK: Invalid type: - # CHECK: error: unknown: invalid memref element type - print(e) - else: - print("Exception not produced") - - assert memref.shape == shape + with Context(), Location.unknown(): + f32 = F32Type.get() + shape = [2, 3] + loc = Location.unknown() + memref = MemRefType.get(shape, f32, memory_space=Attribute.parse("2")) + # CHECK: memref type: memref<2x3xf32, 2> + print("memref type:", memref) + # CHECK: memref layout: affine_map<(d0, d1) -> (d0, d1)> + print("memref layout:", memref.layout) + # CHECK: memref affine map: (d0, d1) -> (d0, d1) + print("memref affine map:", memref.affine_map) + # CHECK: memory space: 2 + print("memory space:", memref.memory_space) + + layout = AffineMapAttr.get(AffineMap.get_permutation([1, 0])) + memref_layout = MemRefType.get(shape, f32, layout=layout) + # CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>> + print("memref type:", memref_layout) + # 
CHECK: memref layout: affine_map<(d0, d1) -> (d1, d0)> + print("memref layout:", memref_layout.layout) + # CHECK: memref affine map: (d0, d1) -> (d1, d0) + print("memref affine map:", memref_layout.affine_map) + # CHECK: memory space: <> + print("memory space:", memref_layout.memory_space) + + none = NoneType.get() + try: + memref_invalid = MemRefType.get(shape, none) + except MLIRError as e: + # CHECK: Invalid type: + # CHECK: error: unknown: invalid memref element type + print(e) + else: + print("Exception not produced") + + assert memref.shape == shape # CHECK-LABEL: TEST: testUnrankedMemRefType @run def testUnrankedMemRefType(): - with Context(), Location.unknown(): - f32 = F32Type.get() - loc = Location.unknown() - unranked_memref = UnrankedMemRefType.get(f32, Attribute.parse("2")) - # CHECK: unranked memref type: memref<*xf32, 2> - print("unranked memref type:", unranked_memref) - try: - invalid_rank = unranked_memref.rank - except ValueError as e: - # CHECK: calling this method requires that the type has a rank. - print(e) - else: - print("Exception not produced") - try: - invalid_is_dynamic_dim = unranked_memref.is_dynamic_dim(0) - except ValueError as e: - # CHECK: calling this method requires that the type has a rank. - print(e) - else: - print("Exception not produced") - try: - invalid_get_dim_size = unranked_memref.get_dim_size(1) - except ValueError as e: - # CHECK: calling this method requires that the type has a rank. - print(e) - else: - print("Exception not produced") - - none = NoneType.get() - try: - memref_invalid = UnrankedMemRefType.get(none, Attribute.parse("2")) - except MLIRError as e: - # CHECK: Invalid type: - # CHECK: error: unknown: invalid memref element type - print(e) - else: - print("Exception not produced") + with Context(), Location.unknown(): + f32 = F32Type.get() + loc = Location.unknown() + unranked_memref = UnrankedMemRefType.get(f32, Attribute.parse("2")) + # CHECK: unranked memref type: memref<*xf32, 2> + print("unranked memref type:", unranked_memref) + try: + invalid_rank = unranked_memref.rank + except ValueError as e: + # CHECK: calling this method requires that the type has a rank. + print(e) + else: + print("Exception not produced") + try: + invalid_is_dynamic_dim = unranked_memref.is_dynamic_dim(0) + except ValueError as e: + # CHECK: calling this method requires that the type has a rank. + print(e) + else: + print("Exception not produced") + try: + invalid_get_dim_size = unranked_memref.get_dim_size(1) + except ValueError as e: + # CHECK: calling this method requires that the type has a rank. 
+ print(e) + else: + print("Exception not produced") + + none = NoneType.get() + try: + memref_invalid = UnrankedMemRefType.get(none, Attribute.parse("2")) + except MLIRError as e: + # CHECK: Invalid type: + # CHECK: error: unknown: invalid memref element type + print(e) + else: + print("Exception not produced") # CHECK-LABEL: TEST: testTupleType @run def testTupleType(): - with Context() as ctx: - i32 = IntegerType(Type.parse("i32")) - f32 = F32Type.get() - vector = VectorType(Type.parse("vector<2x3xf32>")) - l = [i32, f32, vector] - tuple_type = TupleType.get_tuple(l) - # CHECK: tuple type: tuple> - print("tuple type:", tuple_type) - # CHECK: number of types: 3 - print("number of types:", tuple_type.num_types) - # CHECK: pos-th type in the tuple type: f32 - print("pos-th type in the tuple type:", tuple_type.get_type(1)) + with Context() as ctx: + i32 = IntegerType(Type.parse("i32")) + f32 = F32Type.get() + vector = VectorType(Type.parse("vector<2x3xf32>")) + l = [i32, f32, vector] + tuple_type = TupleType.get_tuple(l) + # CHECK: tuple type: tuple> + print("tuple type:", tuple_type) + # CHECK: number of types: 3 + print("number of types:", tuple_type.num_types) + # CHECK: pos-th type in the tuple type: f32 + print("pos-th type in the tuple type:", tuple_type.get_type(1)) # CHECK-LABEL: TEST: testFunctionType @run def testFunctionType(): - with Context() as ctx: - input_types = [IntegerType.get_signless(32), IntegerType.get_signless(16)] - result_types = [IndexType.get()] - func = FunctionType.get(input_types, result_types) - # CHECK: INPUTS: [Type(i32), Type(i16)] - print("INPUTS:", func.inputs) - # CHECK: RESULTS: [Type(index)] - print("RESULTS:", func.results) + with Context() as ctx: + input_types = [IntegerType.get_signless(32), IntegerType.get_signless(16)] + result_types = [IndexType.get()] + func = FunctionType.get(input_types, result_types) + # CHECK: INPUTS: [Type(i32), Type(i16)] + print("INPUTS:", func.inputs) + # CHECK: RESULTS: [Type(index)] + print("RESULTS:", func.results) # CHECK-LABEL: TEST: testOpaqueType @run def testOpaqueType(): - with Context() as ctx: - ctx.allow_unregistered_dialects = True - opaque = OpaqueType.get("dialect", "type") - # CHECK: opaque type: !dialect.type - print("opaque type:", opaque) - # CHECK: dialect namespace: dialect - print("dialect namespace:", opaque.dialect_namespace) - # CHECK: data: type - print("data:", opaque.data) + with Context() as ctx: + ctx.allow_unregistered_dialects = True + opaque = OpaqueType.get("dialect", "type") + # CHECK: opaque type: !dialect.type + print("opaque type:", opaque) + # CHECK: dialect namespace: dialect + print("dialect namespace:", opaque.dialect_namespace) + # CHECK: data: type + print("data:", opaque.data) # CHECK-LABEL: TEST: testShapedTypeConstants # Tests that ShapedType exposes magic value constants. 
@run def testShapedTypeConstants(): - # CHECK: - print(type(ShapedType.get_dynamic_size())) - # CHECK: - print(type(ShapedType.get_dynamic_stride_or_offset())) + # CHECK: + print(type(ShapedType.get_dynamic_size())) + # CHECK: + print(type(ShapedType.get_dynamic_stride_or_offset())) # CHECK-LABEL: TEST: testTypeIDs @run def testTypeIDs(): - with Context(), Location.unknown(): - f32 = F32Type.get() - - types = [ - (IntegerType, IntegerType.get_signless(16)), - (IndexType, IndexType.get()), - (Float8E4M3FNType, Float8E4M3FNType.get()), - (Float8E5M2Type, Float8E5M2Type.get()), - (Float8E4M3FNUZType, Float8E4M3FNUZType.get()), - (Float8E4M3B11FNUZType, Float8E4M3B11FNUZType.get()), - (Float8E5M2FNUZType, Float8E5M2FNUZType.get()), - (BF16Type, BF16Type.get()), - (F16Type, F16Type.get()), - (F32Type, F32Type.get()), - (F64Type, F64Type.get()), - (NoneType, NoneType.get()), - (ComplexType, ComplexType.get(f32)), - (VectorType, VectorType.get([2, 3], f32)), - (RankedTensorType, RankedTensorType.get([2, 3], f32)), - (UnrankedTensorType, UnrankedTensorType.get(f32)), - (MemRefType, MemRefType.get([2, 3], f32)), - (UnrankedMemRefType, UnrankedMemRefType.get(f32, Attribute.parse("2"))), - (TupleType, TupleType.get_tuple([f32])), - (FunctionType, FunctionType.get([], [])), - (OpaqueType, OpaqueType.get("tensor", "bob")), - ] - - # CHECK: IntegerType(i16) - # CHECK: IndexType(index) - # CHECK: Float8E4M3FNType(f8E4M3FN) - # CHECK: Float8E5M2Type(f8E5M2) - # CHECK: Float8E4M3FNUZType(f8E4M3FNUZ) - # CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ) - # CHECK: Float8E5M2FNUZType(f8E5M2FNUZ) - # CHECK: BF16Type(bf16) - # CHECK: F16Type(f16) - # CHECK: F32Type(f32) - # CHECK: F64Type(f64) - # CHECK: NoneType(none) - # CHECK: ComplexType(complex) - # CHECK: VectorType(vector<2x3xf32>) - # CHECK: RankedTensorType(tensor<2x3xf32>) - # CHECK: UnrankedTensorType(tensor<*xf32>) - # CHECK: MemRefType(memref<2x3xf32>) - # CHECK: UnrankedMemRefType(memref<*xf32, 2>) - # CHECK: TupleType(tuple) - # CHECK: FunctionType(() -> ()) - # CHECK: OpaqueType(!tensor.bob) - for _, t in types: - print(repr(t)) - - # Test getTypeIdFunction agrees with - # mlirTypeGetTypeID(self) for an instance. - # CHECK: all equal - for t1, t2 in types: - tid1, tid2 = t1.static_typeid, Type(t2).typeid - assert tid1 == tid2 and hash(tid1) == hash( - tid2), f"expected hash and value equality {t1} {t2}" - else: - print("all equal") - - # Test that storing PyTypeID in python dicts - # works as expected. - typeid_dict = dict(types) - assert len(typeid_dict) - - # CHECK: all equal - for t1, t2 in typeid_dict.items(): - assert t1.static_typeid == t2.typeid and hash( - t1.static_typeid) == hash( - t2.typeid), f"expected hash and value equality {t1} {t2}" - else: - print("all equal") - - # CHECK: ShapedType has no typeid. 
- try: - print(ShapedType.static_typeid) - except AttributeError as e: - print(e) - - vector_type = Type.parse("vector<2x3xf32>") - # CHECK: True - print(ShapedType(vector_type).typeid == vector_type.typeid) + with Context(), Location.unknown(): + f32 = F32Type.get() + + types = [ + (IntegerType, IntegerType.get_signless(16)), + (IndexType, IndexType.get()), + (Float8E4M3FNType, Float8E4M3FNType.get()), + (Float8E5M2Type, Float8E5M2Type.get()), + (Float8E4M3FNUZType, Float8E4M3FNUZType.get()), + (Float8E4M3B11FNUZType, Float8E4M3B11FNUZType.get()), + (Float8E5M2FNUZType, Float8E5M2FNUZType.get()), + (BF16Type, BF16Type.get()), + (F16Type, F16Type.get()), + (F32Type, F32Type.get()), + (F64Type, F64Type.get()), + (NoneType, NoneType.get()), + (ComplexType, ComplexType.get(f32)), + (VectorType, VectorType.get([2, 3], f32)), + (RankedTensorType, RankedTensorType.get([2, 3], f32)), + (UnrankedTensorType, UnrankedTensorType.get(f32)), + (MemRefType, MemRefType.get([2, 3], f32)), + (UnrankedMemRefType, UnrankedMemRefType.get(f32, Attribute.parse("2"))), + (TupleType, TupleType.get_tuple([f32])), + (FunctionType, FunctionType.get([], [])), + (OpaqueType, OpaqueType.get("tensor", "bob")), + ] + + # CHECK: IntegerType(i16) + # CHECK: IndexType(index) + # CHECK: Float8E4M3FNType(f8E4M3FN) + # CHECK: Float8E5M2Type(f8E5M2) + # CHECK: Float8E4M3FNUZType(f8E4M3FNUZ) + # CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ) + # CHECK: Float8E5M2FNUZType(f8E5M2FNUZ) + # CHECK: BF16Type(bf16) + # CHECK: F16Type(f16) + # CHECK: F32Type(f32) + # CHECK: F64Type(f64) + # CHECK: NoneType(none) + # CHECK: ComplexType(complex) + # CHECK: VectorType(vector<2x3xf32>) + # CHECK: RankedTensorType(tensor<2x3xf32>) + # CHECK: UnrankedTensorType(tensor<*xf32>) + # CHECK: MemRefType(memref<2x3xf32>) + # CHECK: UnrankedMemRefType(memref<*xf32, 2>) + # CHECK: TupleType(tuple) + # CHECK: FunctionType(() -> ()) + # CHECK: OpaqueType(!tensor.bob) + for _, t in types: + print(repr(t)) + + # Test getTypeIdFunction agrees with + # mlirTypeGetTypeID(self) for an instance. + # CHECK: all equal + for t1, t2 in types: + tid1, tid2 = t1.static_typeid, Type(t2).typeid + assert tid1 == tid2 and hash(tid1) == hash( + tid2 + ), f"expected hash and value equality {t1} {t2}" + else: + print("all equal") + + # Test that storing PyTypeID in python dicts + # works as expected. + typeid_dict = dict(types) + assert len(typeid_dict) + + # CHECK: all equal + for t1, t2 in typeid_dict.items(): + assert t1.static_typeid == t2.typeid and hash(t1.static_typeid) == hash( + t2.typeid + ), f"expected hash and value equality {t1} {t2}" + else: + print("all equal") + + # CHECK: ShapedType has no typeid. 
+ try: + print(ShapedType.static_typeid) + except AttributeError as e: + print(e) + + vector_type = Type.parse("vector<2x3xf32>") + # CHECK: True + print(ShapedType(vector_type).typeid == vector_type.typeid) diff --git a/mlir/test/python/ir/context_managers.py b/mlir/test/python/ir/context_managers.py index b93fcf7..48d9e35 100644 --- a/mlir/test/python/ir/context_managers.py +++ b/mlir/test/python/ir/context_managers.py @@ -3,97 +3,110 @@ import gc from mlir.ir import * + def run(f): - print("\nTEST:", f.__name__) - f() - gc.collect() - assert Context._get_live_count() == 0 + print("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 # CHECK-LABEL: TEST: testContextEnterExit def testContextEnterExit(): - with Context() as ctx: - assert Context.current is ctx - try: - _ = Context.current - except ValueError as e: - # CHECK: No current Context - print(e) - else: assert False, "Expected exception" + with Context() as ctx: + assert Context.current is ctx + try: + _ = Context.current + except ValueError as e: + # CHECK: No current Context + print(e) + else: + assert False, "Expected exception" + run(testContextEnterExit) # CHECK-LABEL: TEST: testLocationEnterExit def testLocationEnterExit(): - ctx1 = Context() - with Location.unknown(ctx1) as loc1: - assert Context.current is ctx1 - assert Location.current is loc1 - - # Re-asserting the same context should not change the location. - with ctx1: - assert Context.current is ctx1 - assert Location.current is loc1 - # Asserting a different context should clear it. - with Context() as ctx2: - assert Context.current is ctx2 - try: - _ = Location.current - except ValueError: pass - else: assert False, "Expected exception" - - # And should restore. - assert Context.current is ctx1 - assert Location.current is loc1 - - # All should clear. - try: - _ = Location.current - except ValueError as e: - # CHECK: No current Location - print(e) - else: assert False, "Expected exception" + ctx1 = Context() + with Location.unknown(ctx1) as loc1: + assert Context.current is ctx1 + assert Location.current is loc1 + + # Re-asserting the same context should not change the location. + with ctx1: + assert Context.current is ctx1 + assert Location.current is loc1 + # Asserting a different context should clear it. + with Context() as ctx2: + assert Context.current is ctx2 + try: + _ = Location.current + except ValueError: + pass + else: + assert False, "Expected exception" + + # And should restore. + assert Context.current is ctx1 + assert Location.current is loc1 + + # All should clear. + try: + _ = Location.current + except ValueError as e: + # CHECK: No current Location + print(e) + else: + assert False, "Expected exception" + run(testLocationEnterExit) # CHECK-LABEL: TEST: testInsertionPointEnterExit def testInsertionPointEnterExit(): - ctx1 = Context() - m = Module.create(Location.unknown(ctx1)) - ip = InsertionPoint(m.body) - - with ip: - assert InsertionPoint.current is ip - # Asserting a location from the same context should preserve. - with Location.unknown(ctx1) as loc1: - assert InsertionPoint.current is ip - assert Location.current is loc1 - # Location should clear. + ctx1 = Context() + m = Module.create(Location.unknown(ctx1)) + ip = InsertionPoint(m.body) + + with ip: + assert InsertionPoint.current is ip + # Asserting a location from the same context should preserve. + with Location.unknown(ctx1) as loc1: + assert InsertionPoint.current is ip + assert Location.current is loc1 + # Location should clear. 
+ try: + _ = Location.current + except ValueError: + pass + else: + assert False, "Expected exception" + + # Asserting the same Context should preserve. + with ctx1: + assert InsertionPoint.current is ip + + # Asserting a different context should clear it. + with Context() as ctx2: + assert Context.current is ctx2 + try: + _ = InsertionPoint.current + except ValueError: + pass + else: + assert False, "Expected exception" + + # All should clear. try: - _ = Location.current - except ValueError: pass - else: assert False, "Expected exception" - - # Asserting the same Context should preserve. - with ctx1: - assert InsertionPoint.current is ip - - # Asserting a different context should clear it. - with Context() as ctx2: - assert Context.current is ctx2 - try: _ = InsertionPoint.current - except ValueError: pass - else: assert False, "Expected exception" - - # All should clear. - try: - _ = InsertionPoint.current - except ValueError as e: - # CHECK: No current InsertionPoint - print(e) - else: assert False, "Expected exception" + except ValueError as e: + # CHECK: No current InsertionPoint + print(e) + else: + assert False, "Expected exception" + run(testInsertionPointEnterExit) diff --git a/mlir/test/python/ir/debug.py b/mlir/test/python/ir/debug.py index 3268d9f..629a710 100644 --- a/mlir/test/python/ir/debug.py +++ b/mlir/test/python/ir/debug.py @@ -2,38 +2,40 @@ from mlir.ir import * + def run(f): - print("\nTEST:", f.__name__) - f() + print("\nTEST:", f.__name__) + f() # CHECK-LABEL: TEST: testNameIsPrivate def testNameIsPrivate(): - # `import *` ignores private names starting with an understore, so the debug - # flag shouldn't be visible unless explicitly imported. - try: - _GlobalDebug.flag = True - except NameError: - pass - else: - assert False, "_GlobalDebug must not be available by default" + # `import *` ignores private names starting with an understore, so the debug + # flag shouldn't be visible unless explicitly imported. + try: + _GlobalDebug.flag = True + except NameError: + pass + else: + assert False, "_GlobalDebug must not be available by default" + run(testNameIsPrivate) # CHECK-LABEL: TEST: testDebugDlag def testDebugDlag(): - # Private names must be imported expilcitly. - from mlir.ir import _GlobalDebug - - # CHECK: False - print(_GlobalDebug.flag) - _GlobalDebug.flag = True - # CHECK: True - print(_GlobalDebug.flag) - _GlobalDebug.flag = False - # CHECK: False - print(_GlobalDebug.flag) + # Private names must be imported expilcitly. + from mlir.ir import _GlobalDebug + + # CHECK: False + print(_GlobalDebug.flag) + _GlobalDebug.flag = True + # CHECK: True + print(_GlobalDebug.flag) + _GlobalDebug.flag = False + # CHECK: False + print(_GlobalDebug.flag) -run(testDebugDlag) +run(testDebugDlag) diff --git a/mlir/test/python/ir/diagnostic_handler.py b/mlir/test/python/ir/diagnostic_handler.py index cc07f6e..2f4300d 100644 --- a/mlir/test/python/ir/diagnostic_handler.py +++ b/mlir/test/python/ir/diagnostic_handler.py @@ -3,191 +3,222 @@ import gc from mlir.ir import * + def run(f): - print("\nTEST:", f.__name__) - f() - gc.collect() - assert Context._get_live_count() == 0 - return f + print("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + return f @run def testLifecycleContextDestroy(): - ctx = Context() - def callback(foo): ... - handler = ctx.attach_diagnostic_handler(callback) - assert handler.attached - # If context is destroyed before the handler, it should auto-detach. 
- ctx = None - gc.collect() - assert not handler.attached + ctx = Context() + + def callback(foo): + ... + + handler = ctx.attach_diagnostic_handler(callback) + assert handler.attached + # If context is destroyed before the handler, it should auto-detach. + ctx = None + gc.collect() + assert not handler.attached - # And finally collecting the handler should be fine. - handler = None - gc.collect() + # And finally collecting the handler should be fine. + handler = None + gc.collect() @run def testLifecycleExplicitDetach(): - ctx = Context() - def callback(foo): ... - handler = ctx.attach_diagnostic_handler(callback) - assert handler.attached - handler.detach() - assert not handler.attached + ctx = Context() + + def callback(foo): + ... + + handler = ctx.attach_diagnostic_handler(callback) + assert handler.attached + handler.detach() + assert not handler.attached @run def testLifecycleWith(): - ctx = Context() - def callback(foo): ... - with ctx.attach_diagnostic_handler(callback) as handler: - assert handler.attached - assert not handler.attached + ctx = Context() + + def callback(foo): + ... + + with ctx.attach_diagnostic_handler(callback) as handler: + assert handler.attached + assert not handler.attached @run def testLifecycleWithAndExplicitDetach(): - ctx = Context() - def callback(foo): ... - with ctx.attach_diagnostic_handler(callback) as handler: - assert handler.attached - handler.detach() - assert not handler.attached + ctx = Context() + + def callback(foo): + ... + + with ctx.attach_diagnostic_handler(callback) as handler: + assert handler.attached + handler.detach() + assert not handler.attached # CHECK-LABEL: TEST: testDiagnosticCallback @run def testDiagnosticCallback(): - ctx = Context() - def callback(d): - # CHECK: DIAGNOSTIC: message='foobar', severity=DiagnosticSeverity.ERROR, loc=loc(unknown) - print(f"DIAGNOSTIC: message='{d.message}', severity={d.severity}, loc={d.location}") - return True - handler = ctx.attach_diagnostic_handler(callback) - loc = Location.unknown(ctx) - loc.emit_error("foobar") - assert not handler.had_error + ctx = Context() + + def callback(d): + # CHECK: DIAGNOSTIC: message='foobar', severity=DiagnosticSeverity.ERROR, loc=loc(unknown) + print( + f"DIAGNOSTIC: message='{d.message}', severity={d.severity}, loc={d.location}" + ) + return True + + handler = ctx.attach_diagnostic_handler(callback) + loc = Location.unknown(ctx) + loc.emit_error("foobar") + assert not handler.had_error # CHECK-LABEL: TEST: testDiagnosticEmptyNotes # TODO: Come up with a way to inject a diagnostic with notes from this API. 
@run def testDiagnosticEmptyNotes(): - ctx = Context() - def callback(d): - # CHECK: DIAGNOSTIC: notes=() - print(f"DIAGNOSTIC: notes={d.notes}") - return True - handler = ctx.attach_diagnostic_handler(callback) - loc = Location.unknown(ctx) - loc.emit_error("foobar") - assert not handler.had_error + ctx = Context() + + def callback(d): + # CHECK: DIAGNOSTIC: notes=() + print(f"DIAGNOSTIC: notes={d.notes}") + return True + + handler = ctx.attach_diagnostic_handler(callback) + loc = Location.unknown(ctx) + loc.emit_error("foobar") + assert not handler.had_error # CHECK-LABEL: TEST: testDiagnosticNonEmptyNotes @run def testDiagnosticNonEmptyNotes(): - ctx = Context() - ctx.emit_error_diagnostics = True - def callback(d): - # CHECK: DIAGNOSTIC: - # CHECK: message='arith.addi' op requires one result - # CHECK: notes=['see current operation: "arith.addi"() : () -> ()'] - print(f"DIAGNOSTIC:") - print(f" message={d.message}") - print(f" notes={list(map(str, d.notes))}") - return True - handler = ctx.attach_diagnostic_handler(callback) - loc = Location.unknown(ctx) - try: - Operation.create('arith.addi', loc=loc).verify() - except MLIRError: - pass - assert not handler.had_error + ctx = Context() + ctx.emit_error_diagnostics = True + + def callback(d): + # CHECK: DIAGNOSTIC: + # CHECK: message='arith.addi' op requires one result + # CHECK: notes=['see current operation: "arith.addi"() : () -> ()'] + print(f"DIAGNOSTIC:") + print(f" message={d.message}") + print(f" notes={list(map(str, d.notes))}") + return True + + handler = ctx.attach_diagnostic_handler(callback) + loc = Location.unknown(ctx) + try: + Operation.create("arith.addi", loc=loc).verify() + except MLIRError: + pass + assert not handler.had_error + # CHECK-LABEL: TEST: testDiagnosticCallbackException @run def testDiagnosticCallbackException(): - ctx = Context() - def callback(d): - raise ValueError("Error in handler") - handler = ctx.attach_diagnostic_handler(callback) - loc = Location.unknown(ctx) - loc.emit_error("foobar") - assert handler.had_error + ctx = Context() + + def callback(d): + raise ValueError("Error in handler") + + handler = ctx.attach_diagnostic_handler(callback) + loc = Location.unknown(ctx) + loc.emit_error("foobar") + assert handler.had_error # CHECK-LABEL: TEST: testEscapingDiagnostic @run def testEscapingDiagnostic(): - ctx = Context() - diags = [] - def callback(d): - diags.append(d) - return True - handler = ctx.attach_diagnostic_handler(callback) - loc = Location.unknown(ctx) - loc.emit_error("foobar") - assert not handler.had_error - - # CHECK: DIAGNOSTIC: - print(f"DIAGNOSTIC: {str(diags[0])}") - try: - diags[0].severity - raise RuntimeError("expected exception") - except ValueError: - pass - try: - diags[0].location - raise RuntimeError("expected exception") - except ValueError: - pass - try: - diags[0].message - raise RuntimeError("expected exception") - except ValueError: - pass - try: - diags[0].notes - raise RuntimeError("expected exception") - except ValueError: - pass - + ctx = Context() + diags = [] + + def callback(d): + diags.append(d) + return True + + handler = ctx.attach_diagnostic_handler(callback) + loc = Location.unknown(ctx) + loc.emit_error("foobar") + assert not handler.had_error + + # CHECK: DIAGNOSTIC: + print(f"DIAGNOSTIC: {str(diags[0])}") + try: + diags[0].severity + raise RuntimeError("expected exception") + except ValueError: + pass + try: + diags[0].location + raise RuntimeError("expected exception") + except ValueError: + pass + try: + diags[0].message + raise 
RuntimeError("expected exception") + except ValueError: + pass + try: + diags[0].notes + raise RuntimeError("expected exception") + except ValueError: + pass # CHECK-LABEL: TEST: testDiagnosticReturnTrueHandles @run def testDiagnosticReturnTrueHandles(): - ctx = Context() - def callback1(d): - print(f"CALLBACK1: {d}") - return True - def callback2(d): - print(f"CALLBACK2: {d}") - return True - ctx.attach_diagnostic_handler(callback1) - ctx.attach_diagnostic_handler(callback2) - loc = Location.unknown(ctx) - # CHECK-NOT: CALLBACK1 - # CHECK: CALLBACK2: foobar - # CHECK-NOT: CALLBACK1 - loc.emit_error("foobar") + ctx = Context() + + def callback1(d): + print(f"CALLBACK1: {d}") + return True + + def callback2(d): + print(f"CALLBACK2: {d}") + return True + + ctx.attach_diagnostic_handler(callback1) + ctx.attach_diagnostic_handler(callback2) + loc = Location.unknown(ctx) + # CHECK-NOT: CALLBACK1 + # CHECK: CALLBACK2: foobar + # CHECK-NOT: CALLBACK1 + loc.emit_error("foobar") # CHECK-LABEL: TEST: testDiagnosticReturnFalseDoesNotHandle @run def testDiagnosticReturnFalseDoesNotHandle(): - ctx = Context() - def callback1(d): - print(f"CALLBACK1: {d}") - return True - def callback2(d): - print(f"CALLBACK2: {d}") - return False - ctx.attach_diagnostic_handler(callback1) - ctx.attach_diagnostic_handler(callback2) - loc = Location.unknown(ctx) - # CHECK: CALLBACK2: foobar - # CHECK: CALLBACK1: foobar - loc.emit_error("foobar") + ctx = Context() + + def callback1(d): + print(f"CALLBACK1: {d}") + return True + + def callback2(d): + print(f"CALLBACK2: {d}") + return False + + ctx.attach_diagnostic_handler(callback1) + ctx.attach_diagnostic_handler(callback2) + loc = Location.unknown(ctx) + # CHECK: CALLBACK2: foobar + # CHECK: CALLBACK1: foobar + loc.emit_error("foobar") diff --git a/mlir/test/python/ir/dialects.py b/mlir/test/python/ir/dialects.py index 65e81e8..eebf7c3 100644 --- a/mlir/test/python/ir/dialects.py +++ b/mlir/test/python/ir/dialects.py @@ -5,60 +5,60 @@ from mlir.ir import * def run(f): - print("\nTEST:", f.__name__) - f() - gc.collect() - assert Context._get_live_count() == 0 - return f + print("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + return f # CHECK-LABEL: TEST: testDialectDescriptor @run def testDialectDescriptor(): - ctx = Context() - d = ctx.get_dialect_descriptor("func") - # CHECK: - print(d) - # CHECK: func - print(d.namespace) - try: - _ = ctx.get_dialect_descriptor("not_existing") - except ValueError: - pass - else: - assert False, "Expected exception" + ctx = Context() + d = ctx.get_dialect_descriptor("func") + # CHECK: + print(d) + # CHECK: func + print(d.namespace) + try: + _ = ctx.get_dialect_descriptor("not_existing") + except ValueError: + pass + else: + assert False, "Expected exception" # CHECK-LABEL: TEST: testUserDialectClass @run def testUserDialectClass(): - ctx = Context() - # Access using attribute. - d = ctx.dialects.func - # CHECK: - print(d) - try: - _ = ctx.dialects.not_existing - except AttributeError: - pass - else: - assert False, "Expected exception" - - # Access using index. - d = ctx.dialects["func"] - # CHECK: - print(d) - try: - _ = ctx.dialects["not_existing"] - except IndexError: - pass - else: - assert False, "Expected exception" - - # Using the 'd' alias. - d = ctx.d["func"] - # CHECK: - print(d) + ctx = Context() + # Access using attribute. 
+ d = ctx.dialects.func + # CHECK: + print(d) + try: + _ = ctx.dialects.not_existing + except AttributeError: + pass + else: + assert False, "Expected exception" + + # Access using index. + d = ctx.dialects["func"] + # CHECK: + print(d) + try: + _ = ctx.dialects["not_existing"] + except IndexError: + pass + else: + assert False, "Expected exception" + + # Using the 'd' alias. + d = ctx.d["func"] + # CHECK: + print(d) # CHECK-LABEL: TEST: testCustomOpView @@ -67,40 +67,40 @@ def testUserDialectClass(): # additional capabilities come online. @run def testCustomOpView(): + def createInput(): + op = Operation.create("pytest_dummy.intinput", results=[f32]) + # TODO: Auto result cast from operation + return op.results[0] - def createInput(): - op = Operation.create("pytest_dummy.intinput", results=[f32]) - # TODO: Auto result cast from operation - return op.results[0] + with Context() as ctx, Location.unknown(): + ctx.allow_unregistered_dialects = True + m = Module.create() - with Context() as ctx, Location.unknown(): - ctx.allow_unregistered_dialects = True - m = Module.create() + with InsertionPoint(m.body): + f32 = F32Type.get() + # Create via dialects context collection. + input1 = createInput() + input2 = createInput() + op1 = ctx.dialects.arith.AddFOp(input1, input2) - with InsertionPoint(m.body): - f32 = F32Type.get() - # Create via dialects context collection. - input1 = createInput() - input2 = createInput() - op1 = ctx.dialects.arith.AddFOp(input1, input2) + # Create via an import + from mlir.dialects.arith import AddFOp - # Create via an import - from mlir.dialects.arith import AddFOp - AddFOp(input1, op1.result) + AddFOp(input1, op1.result) - # CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput" - # CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput" - # CHECK: %[[R0:.*]] = arith.addf %[[INPUT0]], %[[INPUT1]] : f32 - # CHECK: %[[R1:.*]] = arith.addf %[[INPUT0]], %[[R0]] : f32 - m.operation.print() + # CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput" + # CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput" + # CHECK: %[[R0:.*]] = arith.addf %[[INPUT0]], %[[INPUT1]] : f32 + # CHECK: %[[R1:.*]] = arith.addf %[[INPUT0]], %[[R0]] : f32 + m.operation.print() # CHECK-LABEL: TEST: testIsRegisteredOperation @run def testIsRegisteredOperation(): - ctx = Context() + ctx = Context() - # CHECK: cf.cond_br: True - print(f"cf.cond_br: {ctx.is_registered_operation('cf.cond_br')}") - # CHECK: func.not_existing: False - print(f"func.not_existing: {ctx.is_registered_operation('func.not_existing')}") + # CHECK: cf.cond_br: True + print(f"cf.cond_br: {ctx.is_registered_operation('cf.cond_br')}") + # CHECK: func.not_existing: False + print(f"func.not_existing: {ctx.is_registered_operation('func.not_existing')}") diff --git a/mlir/test/python/ir/exception.py b/mlir/test/python/ir/exception.py index 6cb2375..74085cd 100644 --- a/mlir/test/python/ir/exception.py +++ b/mlir/test/python/ir/exception.py @@ -3,75 +3,93 @@ import gc from mlir.ir import * + def run(f): - print("\nTEST:", f.__name__) - f() - gc.collect() - assert Context._get_live_count() == 0 - return f + print("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + return f # CHECK-LABEL: TEST: test_exception @run def test_exception(): - ctx = Context() - ctx.allow_unregistered_dialects = True - try: - Operation.parse(""" + ctx = Context() + ctx.allow_unregistered_dialects = True + try: + Operation.parse( + """ func.func @foo() { "test.use"(%0) : (i64) -> () loc("use") %0 = "test.def"() : () -> i64 loc("def") return } - """, 
context=ctx) - except MLIRError as e: - # CHECK: Exception: < - # CHECK: Unable to parse operation assembly: - # CHECK: error: "use": operand #0 does not dominate this use - # CHECK: note: "use": see current operation: "test.use"(%0) : (i64) -> () - # CHECK: note: "def": operand defined here (op in the same block) - # CHECK: > - print(f"Exception: <{e}>") + """, + context=ctx, + ) + except MLIRError as e: + # CHECK: Exception: < + # CHECK: Unable to parse operation assembly: + # CHECK: error: "use": operand #0 does not dominate this use + # CHECK: note: "use": see current operation: "test.use"(%0) : (i64) -> () + # CHECK: note: "def": operand defined here (op in the same block) + # CHECK: > + print(f"Exception: <{e}>") - # CHECK: message: Unable to parse operation assembly - print(f"message: {e.message}") + # CHECK: message: Unable to parse operation assembly + print(f"message: {e.message}") - # CHECK: error_diagnostics[0]: loc("use") operand #0 does not dominate this use - # CHECK: error_diagnostics[0].notes[0]: loc("use") see current operation: "test.use"(%0) : (i64) -> () - # CHECK: error_diagnostics[0].notes[1]: loc("def") operand defined here (op in the same block) - print("error_diagnostics[0]: ", e.error_diagnostics[0].location, e.error_diagnostics[0].message) - print("error_diagnostics[0].notes[0]: ", e.error_diagnostics[0].notes[0].location, e.error_diagnostics[0].notes[0].message) - print("error_diagnostics[0].notes[1]: ", e.error_diagnostics[0].notes[1].location, e.error_diagnostics[0].notes[1].message) + # CHECK: error_diagnostics[0]: loc("use") operand #0 does not dominate this use + # CHECK: error_diagnostics[0].notes[0]: loc("use") see current operation: "test.use"(%0) : (i64) -> () + # CHECK: error_diagnostics[0].notes[1]: loc("def") operand defined here (op in the same block) + print( + "error_diagnostics[0]: ", + e.error_diagnostics[0].location, + e.error_diagnostics[0].message, + ) + print( + "error_diagnostics[0].notes[0]: ", + e.error_diagnostics[0].notes[0].location, + e.error_diagnostics[0].notes[0].message, + ) + print( + "error_diagnostics[0].notes[1]: ", + e.error_diagnostics[0].notes[1].location, + e.error_diagnostics[0].notes[1].message, + ) # CHECK-LABEL: test_emit_error_diagnostics @run def test_emit_error_diagnostics(): - ctx = Context() - loc = Location.unknown(ctx) - handler_diags = [] - def handler(d): - handler_diags.append(str(d)) - return True - ctx.attach_diagnostic_handler(handler) + ctx = Context() + loc = Location.unknown(ctx) + handler_diags = [] + + def handler(d): + handler_diags.append(str(d)) + return True + + ctx.attach_diagnostic_handler(handler) - try: - Attribute.parse("not an attr", ctx) - except MLIRError as e: - # CHECK: emit_error_diagnostics=False: - # CHECK: e.error_diagnostics: ['expected attribute value'] - # CHECK: handler_diags: [] - print(f"emit_error_diagnostics=False:") - print(f"e.error_diagnostics: {[str(diag) for diag in e.error_diagnostics]}") - print(f"handler_diags: {handler_diags}") + try: + Attribute.parse("not an attr", ctx) + except MLIRError as e: + # CHECK: emit_error_diagnostics=False: + # CHECK: e.error_diagnostics: ['expected attribute value'] + # CHECK: handler_diags: [] + print(f"emit_error_diagnostics=False:") + print(f"e.error_diagnostics: {[str(diag) for diag in e.error_diagnostics]}") + print(f"handler_diags: {handler_diags}") - ctx.emit_error_diagnostics = True - try: - Attribute.parse("not an attr", ctx) - except MLIRError as e: - # CHECK: emit_error_diagnostics=True: - # CHECK: e.error_diagnostics: [] - 
# CHECK: handler_diags: ['expected attribute value'] - print(f"emit_error_diagnostics=True:") - print(f"e.error_diagnostics: {[str(diag) for diag in e.error_diagnostics]}") - print(f"handler_diags: {handler_diags}") + ctx.emit_error_diagnostics = True + try: + Attribute.parse("not an attr", ctx) + except MLIRError as e: + # CHECK: emit_error_diagnostics=True: + # CHECK: e.error_diagnostics: [] + # CHECK: handler_diags: ['expected attribute value'] + print(f"emit_error_diagnostics=True:") + print(f"e.error_diagnostics: {[str(diag) for diag in e.error_diagnostics]}") + print(f"handler_diags: {handler_diags}") diff --git a/mlir/test/python/ir/insertion_point.py b/mlir/test/python/ir/insertion_point.py index 81a6ec2..0dc7d75 100644 --- a/mlir/test/python/ir/insertion_point.py +++ b/mlir/test/python/ir/insertion_point.py @@ -5,168 +5,191 @@ import io import itertools from mlir.ir import * + def run(f): - print("\nTEST:", f.__name__) - f() - gc.collect() - assert Context._get_live_count() == 0 + print("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 # CHECK-LABEL: TEST: test_insert_at_block_end def test_insert_at_block_end(): - ctx = Context() - ctx.allow_unregistered_dialects = True - with Location.unknown(ctx): - module = Module.parse(r""" + ctx = Context() + ctx.allow_unregistered_dialects = True + with Location.unknown(ctx): + module = Module.parse( + r""" func.func @foo() -> () { "custom.op1"() : () -> () } - """) - entry_block = module.body.operations[0].regions[0].blocks[0] - ip = InsertionPoint(entry_block) - ip.insert(Operation.create("custom.op2")) - # CHECK: "custom.op1" - # CHECK: "custom.op2" - module.operation.print() + """ + ) + entry_block = module.body.operations[0].regions[0].blocks[0] + ip = InsertionPoint(entry_block) + ip.insert(Operation.create("custom.op2")) + # CHECK: "custom.op1" + # CHECK: "custom.op2" + module.operation.print() + run(test_insert_at_block_end) # CHECK-LABEL: TEST: test_insert_before_operation def test_insert_before_operation(): - ctx = Context() - ctx.allow_unregistered_dialects = True - with Location.unknown(ctx): - module = Module.parse(r""" + ctx = Context() + ctx.allow_unregistered_dialects = True + with Location.unknown(ctx): + module = Module.parse( + r""" func.func @foo() -> () { "custom.op1"() : () -> () "custom.op2"() : () -> () } - """) - entry_block = module.body.operations[0].regions[0].blocks[0] - ip = InsertionPoint(entry_block.operations[1]) - ip.insert(Operation.create("custom.op3")) - # CHECK: "custom.op1" - # CHECK: "custom.op3" - # CHECK: "custom.op2" - module.operation.print() + """ + ) + entry_block = module.body.operations[0].regions[0].blocks[0] + ip = InsertionPoint(entry_block.operations[1]) + ip.insert(Operation.create("custom.op3")) + # CHECK: "custom.op1" + # CHECK: "custom.op3" + # CHECK: "custom.op2" + module.operation.print() + run(test_insert_before_operation) # CHECK-LABEL: TEST: test_insert_at_block_begin def test_insert_at_block_begin(): - ctx = Context() - ctx.allow_unregistered_dialects = True - with Location.unknown(ctx): - module = Module.parse(r""" + ctx = Context() + ctx.allow_unregistered_dialects = True + with Location.unknown(ctx): + module = Module.parse( + r""" func.func @foo() -> () { "custom.op2"() : () -> () } - """) - entry_block = module.body.operations[0].regions[0].blocks[0] - ip = InsertionPoint.at_block_begin(entry_block) - ip.insert(Operation.create("custom.op1")) - # CHECK: "custom.op1" - # CHECK: "custom.op2" - module.operation.print() + """ + ) + entry_block = 
module.body.operations[0].regions[0].blocks[0] + ip = InsertionPoint.at_block_begin(entry_block) + ip.insert(Operation.create("custom.op1")) + # CHECK: "custom.op1" + # CHECK: "custom.op2" + module.operation.print() + run(test_insert_at_block_begin) # CHECK-LABEL: TEST: test_insert_at_block_begin_empty def test_insert_at_block_begin_empty(): - # TODO: Write this test case when we can create such a situation. - pass + # TODO: Write this test case when we can create such a situation. + pass + run(test_insert_at_block_begin_empty) # CHECK-LABEL: TEST: test_insert_at_terminator def test_insert_at_terminator(): - ctx = Context() - ctx.allow_unregistered_dialects = True - with Location.unknown(ctx): - module = Module.parse(r""" + ctx = Context() + ctx.allow_unregistered_dialects = True + with Location.unknown(ctx): + module = Module.parse( + r""" func.func @foo() -> () { "custom.op1"() : () -> () return } - """) - entry_block = module.body.operations[0].regions[0].blocks[0] - ip = InsertionPoint.at_block_terminator(entry_block) - ip.insert(Operation.create("custom.op2")) - # CHECK: "custom.op1" - # CHECK: "custom.op2" - module.operation.print() + """ + ) + entry_block = module.body.operations[0].regions[0].blocks[0] + ip = InsertionPoint.at_block_terminator(entry_block) + ip.insert(Operation.create("custom.op2")) + # CHECK: "custom.op1" + # CHECK: "custom.op2" + module.operation.print() + run(test_insert_at_terminator) # CHECK-LABEL: TEST: test_insert_at_block_terminator_missing def test_insert_at_block_terminator_missing(): - ctx = Context() - ctx.allow_unregistered_dialects = True - with ctx: - module = Module.parse(r""" + ctx = Context() + ctx.allow_unregistered_dialects = True + with ctx: + module = Module.parse( + r""" func.func @foo() -> () { "custom.op1"() : () -> () } - """) - entry_block = module.body.operations[0].regions[0].blocks[0] - try: - ip = InsertionPoint.at_block_terminator(entry_block) - except ValueError as e: - # CHECK: Block has no terminator - print(e) - else: - assert False, "Expected exception" + """ + ) + entry_block = module.body.operations[0].regions[0].blocks[0] + try: + ip = InsertionPoint.at_block_terminator(entry_block) + except ValueError as e: + # CHECK: Block has no terminator + print(e) + else: + assert False, "Expected exception" + run(test_insert_at_block_terminator_missing) # CHECK-LABEL: TEST: test_insert_at_end_with_terminator_errors def test_insert_at_end_with_terminator_errors(): - with Context() as ctx, Location.unknown(): - ctx.allow_unregistered_dialects = True - module = Module.parse(r""" + with Context() as ctx, Location.unknown(): + ctx.allow_unregistered_dialects = True + module = Module.parse( + r""" func.func @foo() -> () { return } - """) - entry_block = module.body.operations[0].regions[0].blocks[0] - with InsertionPoint(entry_block): - try: - Operation.create("custom.op1", results=[], operands=[]) - except IndexError as e: - # CHECK: ERROR: Cannot insert operation at the end of a block that already has a terminator. - print(f"ERROR: {e}") + """ + ) + entry_block = module.body.operations[0].regions[0].blocks[0] + with InsertionPoint(entry_block): + try: + Operation.create("custom.op1", results=[], operands=[]) + except IndexError as e: + # CHECK: ERROR: Cannot insert operation at the end of a block that already has a terminator. 
+ print(f"ERROR: {e}") + run(test_insert_at_end_with_terminator_errors) # CHECK-LABEL: TEST: test_insertion_point_context def test_insertion_point_context(): - ctx = Context() - ctx.allow_unregistered_dialects = True - with Location.unknown(ctx): - module = Module.parse(r""" + ctx = Context() + ctx.allow_unregistered_dialects = True + with Location.unknown(ctx): + module = Module.parse( + r""" func.func @foo() -> () { "custom.op1"() : () -> () } - """) - entry_block = module.body.operations[0].regions[0].blocks[0] - with InsertionPoint(entry_block): - Operation.create("custom.op2") - with InsertionPoint.at_block_begin(entry_block): - Operation.create("custom.opa") - Operation.create("custom.opb") - Operation.create("custom.op3") - # CHECK: "custom.opa" - # CHECK: "custom.opb" - # CHECK: "custom.op1" - # CHECK: "custom.op2" - # CHECK: "custom.op3" - module.operation.print() + """ + ) + entry_block = module.body.operations[0].regions[0].blocks[0] + with InsertionPoint(entry_block): + Operation.create("custom.op2") + with InsertionPoint.at_block_begin(entry_block): + Operation.create("custom.opa") + Operation.create("custom.opb") + Operation.create("custom.op3") + # CHECK: "custom.opa" + # CHECK: "custom.opb" + # CHECK: "custom.op1" + # CHECK: "custom.op2" + # CHECK: "custom.op3" + module.operation.print() + run(test_insertion_point_context) diff --git a/mlir/test/python/ir/integer_set.py b/mlir/test/python/ir/integer_set.py index d9f158c..9fe0480 100644 --- a/mlir/test/python/ir/integer_set.py +++ b/mlir/test/python/ir/integer_set.py @@ -3,139 +3,140 @@ import gc from mlir.ir import * + def run(f): - print("\nTEST:", f.__name__) - f() - gc.collect() - assert Context._get_live_count() == 0 - return f + print("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + return f # CHECK-LABEL: TEST: testIntegerSetCapsule @run def testIntegerSetCapsule(): - with Context() as ctx: - is1 = IntegerSet.get_empty(1, 1, ctx) - capsule = is1._CAPIPtr - # CHECK: mlir.ir.IntegerSet._CAPIPtr - print(capsule) - is2 = IntegerSet._CAPICreate(capsule) - assert is1 == is2 - assert is2.context is ctx + with Context() as ctx: + is1 = IntegerSet.get_empty(1, 1, ctx) + capsule = is1._CAPIPtr + # CHECK: mlir.ir.IntegerSet._CAPIPtr + print(capsule) + is2 = IntegerSet._CAPICreate(capsule) + assert is1 == is2 + assert is2.context is ctx # CHECK-LABEL: TEST: testIntegerSetGet @run def testIntegerSetGet(): - with Context(): - d0 = AffineDimExpr.get(0) - d1 = AffineDimExpr.get(1) - s0 = AffineSymbolExpr.get(0) - c42 = AffineConstantExpr.get(42) - - # CHECK: (d0, d1)[s0] : (d0 - d1 == 0, s0 - 42 >= 0) - set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42], [True, False]) - print(set0) - - # CHECK: (d0)[s0] : (1 == 0) - set1 = IntegerSet.get_empty(1, 1) - print(set1) - - # CHECK: (d0)[s0, s1] : (d0 - s1 == 0, s0 - 42 >= 0) - set2 = set0.get_replaced([d0, AffineSymbolExpr.get(1)], [s0], 1, 2) - print(set2) - - try: - IntegerSet.get(2, 1, [], []) - except ValueError as e: - # CHECK: Expected non-empty list of constraints - print(e) - - try: - IntegerSet.get(2, 1, [d0 - d1], [True, False]) - except ValueError as e: - # CHECK: Expected the number of constraints to match that of equality flags - print(e) - - try: - IntegerSet.get(2, 1, [0], [True]) - except RuntimeError as e: - # CHECK: Invalid expression when attempting to create an IntegerSet - print(e) - - try: - IntegerSet.get(2, 1, [None], [True]) - except RuntimeError as e: - # CHECK: Invalid expression (None?) 
when attempting to create an IntegerSet - print(e) - - try: - set0.get_replaced([d0], [s0], 1, 1) - except ValueError as e: - # CHECK: Expected the number of dimension replacement expressions to match that of dimensions - print(e) - - try: - set0.get_replaced([d0, d1], [s0, s0], 1, 1) - except ValueError as e: - # CHECK: Expected the number of symbol replacement expressions to match that of symbols - print(e) - - try: - set0.get_replaced([d0, 1], [s0], 1, 1) - except RuntimeError as e: - # CHECK: Invalid expression when attempting to create an IntegerSet by replacing dimensions - print(e) - - try: - set0.get_replaced([d0, d1], [None], 1, 1) - except RuntimeError as e: - # CHECK: Invalid expression (None?) when attempting to create an IntegerSet by replacing symbols - print(e) + with Context(): + d0 = AffineDimExpr.get(0) + d1 = AffineDimExpr.get(1) + s0 = AffineSymbolExpr.get(0) + c42 = AffineConstantExpr.get(42) + + # CHECK: (d0, d1)[s0] : (d0 - d1 == 0, s0 - 42 >= 0) + set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42], [True, False]) + print(set0) + + # CHECK: (d0)[s0] : (1 == 0) + set1 = IntegerSet.get_empty(1, 1) + print(set1) + + # CHECK: (d0)[s0, s1] : (d0 - s1 == 0, s0 - 42 >= 0) + set2 = set0.get_replaced([d0, AffineSymbolExpr.get(1)], [s0], 1, 2) + print(set2) + + try: + IntegerSet.get(2, 1, [], []) + except ValueError as e: + # CHECK: Expected non-empty list of constraints + print(e) + + try: + IntegerSet.get(2, 1, [d0 - d1], [True, False]) + except ValueError as e: + # CHECK: Expected the number of constraints to match that of equality flags + print(e) + + try: + IntegerSet.get(2, 1, [0], [True]) + except RuntimeError as e: + # CHECK: Invalid expression when attempting to create an IntegerSet + print(e) + + try: + IntegerSet.get(2, 1, [None], [True]) + except RuntimeError as e: + # CHECK: Invalid expression (None?) when attempting to create an IntegerSet + print(e) + + try: + set0.get_replaced([d0], [s0], 1, 1) + except ValueError as e: + # CHECK: Expected the number of dimension replacement expressions to match that of dimensions + print(e) + + try: + set0.get_replaced([d0, d1], [s0, s0], 1, 1) + except ValueError as e: + # CHECK: Expected the number of symbol replacement expressions to match that of symbols + print(e) + + try: + set0.get_replaced([d0, 1], [s0], 1, 1) + except RuntimeError as e: + # CHECK: Invalid expression when attempting to create an IntegerSet by replacing dimensions + print(e) + + try: + set0.get_replaced([d0, d1], [None], 1, 1) + except RuntimeError as e: + # CHECK: Invalid expression (None?) 
when attempting to create an IntegerSet by replacing symbols + print(e) # CHECK-LABEL: TEST: testIntegerSetProperties @run def testIntegerSetProperties(): - with Context(): - d0 = AffineDimExpr.get(0) - d1 = AffineDimExpr.get(1) - s0 = AffineSymbolExpr.get(0) - c42 = AffineConstantExpr.get(42) - - set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42, s0 - d0], [True, False, False]) - # CHECK: 2 - print(set0.n_dims) - # CHECK: 1 - print(set0.n_symbols) - # CHECK: 3 - print(set0.n_inputs) - # CHECK: 1 - print(set0.n_equalities) - # CHECK: 2 - print(set0.n_inequalities) - - # CHECK: 3 - print(len(set0.constraints)) - - # CHECK-DAG: d0 - d1 == 0 - # CHECK-DAG: s0 - 42 >= 0 - # CHECK-DAG: -d0 + s0 >= 0 - for cstr in set0.constraints: - print(cstr.expr, end='') - print(" == 0" if cstr.is_eq else " >= 0") + with Context(): + d0 = AffineDimExpr.get(0) + d1 = AffineDimExpr.get(1) + s0 = AffineSymbolExpr.get(0) + c42 = AffineConstantExpr.get(42) + + set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42, s0 - d0], [True, False, False]) + # CHECK: 2 + print(set0.n_dims) + # CHECK: 1 + print(set0.n_symbols) + # CHECK: 3 + print(set0.n_inputs) + # CHECK: 1 + print(set0.n_equalities) + # CHECK: 2 + print(set0.n_inequalities) + + # CHECK: 3 + print(len(set0.constraints)) + + # CHECK-DAG: d0 - d1 == 0 + # CHECK-DAG: s0 - 42 >= 0 + # CHECK-DAG: -d0 + s0 >= 0 + for cstr in set0.constraints: + print(cstr.expr, end="") + print(" == 0" if cstr.is_eq else " >= 0") # TODO-LABEL: TEST: testHash @run def testHash(): - with Context(): - d0 = AffineDimExpr.get(0) - d1 = AffineDimExpr.get(1) - set = IntegerSet.get(2, 0, [d0 + d1], [True]) + with Context(): + d0 = AffineDimExpr.get(0) + d1 = AffineDimExpr.get(1) + set = IntegerSet.get(2, 0, [d0 + d1], [True]) - assert hash(set) == hash(IntegerSet.get(2, 0, [d0 + d1], [True])) + assert hash(set) == hash(IntegerSet.get(2, 0, [d0 + d1], [True])) - dictionary = dict() - dictionary[set] = 42 - assert set in dictionary + dictionary = dict() + dictionary[set] = 42 + assert set in dictionary diff --git a/mlir/test/python/ir/location.py b/mlir/test/python/ir/location.py index 6a30a1d..f66d6c5 100644 --- a/mlir/test/python/ir/location.py +++ b/mlir/test/python/ir/location.py @@ -3,143 +3,150 @@ import gc from mlir.ir import * + def run(f): - print("\nTEST:", f.__name__) - f() - gc.collect() - assert Context._get_live_count() == 0 + print("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 # CHECK-LABEL: TEST: testUnknown def testUnknown(): - with Context() as ctx: - loc = Location.unknown() - assert loc.context is ctx - ctx = None - gc.collect() - # CHECK: unknown str: loc(unknown) - print("unknown str:", str(loc)) - # CHECK: unknown repr: loc(unknown) - print("unknown repr:", repr(loc)) + with Context() as ctx: + loc = Location.unknown() + assert loc.context is ctx + ctx = None + gc.collect() + # CHECK: unknown str: loc(unknown) + print("unknown str:", str(loc)) + # CHECK: unknown repr: loc(unknown) + print("unknown repr:", repr(loc)) + run(testUnknown) # CHECK-LABEL: TEST: testLocationAttr def testLocationAttr(): - with Context() as ctxt: - loc = Location.unknown() - attr = loc.attr - clone = Location.from_attr(attr) - gc.collect() - # CHECK: loc: loc(unknown) - print("loc:", str(loc)) - # CHECK: clone: loc(unknown) - print("clone:", str(clone)) - assert loc == clone + with Context() as ctxt: + loc = Location.unknown() + attr = loc.attr + clone = Location.from_attr(attr) + gc.collect() + # CHECK: loc: loc(unknown) + print("loc:", str(loc)) + # CHECK: clone: 
loc(unknown) + print("clone:", str(clone)) + assert loc == clone + run(testLocationAttr) # CHECK-LABEL: TEST: testFileLineCol def testFileLineCol(): - with Context() as ctx: - loc = Location.file("foo.txt", 123, 56) - ctx = None - gc.collect() - # CHECK: file str: loc("foo.txt":123:56) - print("file str:", str(loc)) - # CHECK: file repr: loc("foo.txt":123:56) - print("file repr:", repr(loc)) + with Context() as ctx: + loc = Location.file("foo.txt", 123, 56) + ctx = None + gc.collect() + # CHECK: file str: loc("foo.txt":123:56) + print("file str:", str(loc)) + # CHECK: file repr: loc("foo.txt":123:56) + print("file repr:", repr(loc)) + run(testFileLineCol) # CHECK-LABEL: TEST: testName def testName(): - with Context() as ctx: - loc = Location.name("nombre") - locWithChildLoc = Location.name("naam", loc) - ctx = None - gc.collect() - # CHECK: file str: loc("nombre") - print("file str:", str(loc)) - # CHECK: file repr: loc("nombre") - print("file repr:", repr(loc)) - # CHECK: file str: loc("naam"("nombre")) - print("file str:", str(locWithChildLoc)) - # CHECK: file repr: loc("naam"("nombre")) - print("file repr:", repr(locWithChildLoc)) + with Context() as ctx: + loc = Location.name("nombre") + locWithChildLoc = Location.name("naam", loc) + ctx = None + gc.collect() + # CHECK: file str: loc("nombre") + print("file str:", str(loc)) + # CHECK: file repr: loc("nombre") + print("file repr:", repr(loc)) + # CHECK: file str: loc("naam"("nombre")) + print("file str:", str(locWithChildLoc)) + # CHECK: file repr: loc("naam"("nombre")) + print("file repr:", repr(locWithChildLoc)) + run(testName) # CHECK-LABEL: TEST: testCallSite def testCallSite(): - with Context() as ctx: - loc = Location.callsite( - Location.file("foo.text", 123, 45), [ - Location.file("util.foo", 379, 21), - Location.file("main.foo", 100, 63) - ]) - ctx = None - # CHECK: file str: loc(callsite("foo.text":123:45 at callsite("util.foo":379:21 at "main.foo":100:63)) - print("file str:", str(loc)) - # CHECK: file repr: loc(callsite("foo.text":123:45 at callsite("util.foo":379:21 at "main.foo":100:63)) - print("file repr:", repr(loc)) + with Context() as ctx: + loc = Location.callsite( + Location.file("foo.text", 123, 45), + [Location.file("util.foo", 379, 21), Location.file("main.foo", 100, 63)], + ) + ctx = None + # CHECK: file str: loc(callsite("foo.text":123:45 at callsite("util.foo":379:21 at "main.foo":100:63)) + print("file str:", str(loc)) + # CHECK: file repr: loc(callsite("foo.text":123:45 at callsite("util.foo":379:21 at "main.foo":100:63)) + print("file repr:", repr(loc)) + run(testCallSite) # CHECK-LABEL: TEST: testFused def testFused(): - with Context() as ctx: - loc_single = Location.fused([Location.name("apple")]) - loc = Location.fused( - [Location.name("apple"), Location.name("banana")]) - attr = Attribute.parse('"sauteed"') - loc_attr = Location.fused([Location.name("carrot"), - Location.name("potatoes")], attr) - loc_empty = Location.fused([]) - loc_empty_attr = Location.fused([], attr) - loc_single_attr = Location.fused([Location.name("apple")], attr) - ctx = None - # CHECK: file str: loc("apple") - print("file str:", str(loc_single)) - # CHECK: file repr: loc("apple") - print("file repr:", repr(loc_single)) - # CHECK: file str: loc(fused["apple", "banana"]) - print("file str:", str(loc)) - # CHECK: file repr: loc(fused["apple", "banana"]) - print("file repr:", repr(loc)) - # CHECK: file str: loc(fused<"sauteed">["carrot", "potatoes"]) - print("file str:", str(loc_attr)) - # CHECK: file repr: 
loc(fused<"sauteed">["carrot", "potatoes"]) - print("file repr:", repr(loc_attr)) - # CHECK: file str: loc(unknown) - print("file str:", str(loc_empty)) - # CHECK: file repr: loc(unknown) - print("file repr:", repr(loc_empty)) - # CHECK: file str: loc(fused<"sauteed">[unknown]) - print("file str:", str(loc_empty_attr)) - # CHECK: file repr: loc(fused<"sauteed">[unknown]) - print("file repr:", repr(loc_empty_attr)) - # CHECK: file str: loc(fused<"sauteed">["apple"]) - print("file str:", str(loc_single_attr)) - # CHECK: file repr: loc(fused<"sauteed">["apple"]) - print("file repr:", repr(loc_single_attr)) + with Context() as ctx: + loc_single = Location.fused([Location.name("apple")]) + loc = Location.fused([Location.name("apple"), Location.name("banana")]) + attr = Attribute.parse('"sauteed"') + loc_attr = Location.fused( + [Location.name("carrot"), Location.name("potatoes")], attr + ) + loc_empty = Location.fused([]) + loc_empty_attr = Location.fused([], attr) + loc_single_attr = Location.fused([Location.name("apple")], attr) + ctx = None + # CHECK: file str: loc("apple") + print("file str:", str(loc_single)) + # CHECK: file repr: loc("apple") + print("file repr:", repr(loc_single)) + # CHECK: file str: loc(fused["apple", "banana"]) + print("file str:", str(loc)) + # CHECK: file repr: loc(fused["apple", "banana"]) + print("file repr:", repr(loc)) + # CHECK: file str: loc(fused<"sauteed">["carrot", "potatoes"]) + print("file str:", str(loc_attr)) + # CHECK: file repr: loc(fused<"sauteed">["carrot", "potatoes"]) + print("file repr:", repr(loc_attr)) + # CHECK: file str: loc(unknown) + print("file str:", str(loc_empty)) + # CHECK: file repr: loc(unknown) + print("file repr:", repr(loc_empty)) + # CHECK: file str: loc(fused<"sauteed">[unknown]) + print("file str:", str(loc_empty_attr)) + # CHECK: file repr: loc(fused<"sauteed">[unknown]) + print("file repr:", repr(loc_empty_attr)) + # CHECK: file str: loc(fused<"sauteed">["apple"]) + print("file str:", str(loc_single_attr)) + # CHECK: file repr: loc(fused<"sauteed">["apple"]) + print("file repr:", repr(loc_single_attr)) + run(testFused) # CHECK-LABEL: TEST: testLocationCapsule def testLocationCapsule(): - with Context() as ctx: - loc1 = Location.file("foo.txt", 123, 56) - # CHECK: mlir.ir.Location._CAPIPtr - loc_capsule = loc1._CAPIPtr - print(loc_capsule) - loc2 = Location._CAPICreate(loc_capsule) - assert loc2 == loc1 - assert loc2.context is ctx + with Context() as ctx: + loc1 = Location.file("foo.txt", 123, 56) + # CHECK: mlir.ir.Location._CAPIPtr + loc_capsule = loc1._CAPIPtr + print(loc_capsule) + loc2 = Location._CAPICreate(loc_capsule) + assert loc2 == loc1 + assert loc2.context is ctx + run(testLocationCapsule) diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py index 2d00923..a5c38a6 100644 --- a/mlir/test/python/ir/module.py +++ b/mlir/test/python/ir/module.py @@ -3,12 +3,13 @@ import gc from mlir.ir import * + def run(f): - print("\nTEST:", f.__name__) - f() - gc.collect() - assert Context._get_live_count() == 0 - return f + print("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + return f # Verify successful parse. @@ -16,14 +17,14 @@ def run(f): # CHECK: module @successfulParse @run def testParseSuccess(): - ctx = Context() - module = Module.parse(r"""module @successfulParse {}""", ctx) - assert module.context is ctx - print("CLEAR CONTEXT") - ctx = None # Ensure that module captures the context. - gc.collect() - module.dump() # Just outputs to stderr. 
Verifies that it functions. - print(str(module)) + ctx = Context() + module = Module.parse(r"""module @successfulParse {}""", ctx) + assert module.context is ctx + print("CLEAR CONTEXT") + ctx = None # Ensure that module captures the context. + gc.collect() + module.dump() # Just outputs to stderr. Verifies that it functions. + print(str(module)) # Verify parse error. @@ -34,13 +35,13 @@ def testParseSuccess(): # CHECK: > @run def testParseError(): - ctx = Context() - try: - module = Module.parse(r"""}SYNTAX ERROR{""", ctx) - except MLIRError as e: - print(f"testParseError: <{e}>") - else: - print("Exception not produced") + ctx = Context() + try: + module = Module.parse(r"""}SYNTAX ERROR{""", ctx) + except MLIRError as e: + print(f"testParseError: <{e}>") + else: + print("Exception not produced") # Verify successful parse. @@ -48,13 +49,13 @@ def testParseError(): # CHECK: module { @run def testCreateEmpty(): - ctx = Context() - loc = Location.unknown(ctx) - module = Module.create(loc) - print("CLEAR CONTEXT") - ctx = None # Ensure that module captures the context. - gc.collect() - print(str(module)) + ctx = Context() + loc = Location.unknown(ctx) + module = Module.create(loc) + print("CLEAR CONTEXT") + ctx = None # Ensure that module captures the context. + gc.collect() + print(str(module)) # Verify round-trip of ASM that contains unicode. @@ -65,11 +66,14 @@ def testCreateEmpty(): # CHECK: foo = "\F0\9F\98\8A" @run def testRoundtripUnicode(): - ctx = Context() - module = Module.parse(r""" + ctx = Context() + module = Module.parse( + r""" func.func private @roundtripUnicode() attributes { foo = "😊" } - """, ctx) - print(str(module)) + """, + ctx, + ) + print(str(module)) # Verify round-trip of ASM that contains unicode. @@ -80,73 +84,74 @@ def testRoundtripUnicode(): # CHECK: foo = "\F0\9F\98\8A" @run def testRoundtripBinary(): - with Context(): - module = Module.parse(r""" + with Context(): + module = Module.parse( + r""" func.func private @roundtripUnicode() attributes { foo = "😊" } - """) - binary_asm = module.operation.get_asm(binary=True) - assert isinstance(binary_asm, bytes) - module = Module.parse(binary_asm) - print(module) + """ + ) + binary_asm = module.operation.get_asm(binary=True) + assert isinstance(binary_asm, bytes) + module = Module.parse(binary_asm) + print(module) # Tests that module.operation works and correctly interns instances. # CHECK-LABEL: TEST: testModuleOperation @run def testModuleOperation(): - ctx = Context() - module = Module.parse(r"""module @successfulParse {}""", ctx) - assert ctx._get_live_module_count() == 1 - op1 = module.operation - assert ctx._get_live_operation_count() == 1 - # CHECK: module @successfulParse - print(op1) - - # Ensure that operations are the same on multiple calls. - op2 = module.operation - assert ctx._get_live_operation_count() == 1 - assert op1 is op2 - - # Test live operation clearing. - op1 = module.operation - assert ctx._get_live_operation_count() == 1 - num_invalidated = ctx._clear_live_operations() - assert num_invalidated == 1 - assert ctx._get_live_operation_count() == 0 - op1 = None - gc.collect() - op1 = module.operation - - # Ensure that if module is de-referenced, the operations are still valid. - module = None - gc.collect() - print(op1) - - # Collect and verify lifetime. 
- op1 = None - op2 = None - gc.collect() - print("LIVE OPERATIONS:", ctx._get_live_operation_count()) - assert ctx._get_live_operation_count() == 0 - assert ctx._get_live_module_count() == 0 + ctx = Context() + module = Module.parse(r"""module @successfulParse {}""", ctx) + assert ctx._get_live_module_count() == 1 + op1 = module.operation + assert ctx._get_live_operation_count() == 1 + # CHECK: module @successfulParse + print(op1) + + # Ensure that operations are the same on multiple calls. + op2 = module.operation + assert ctx._get_live_operation_count() == 1 + assert op1 is op2 + + # Test live operation clearing. + op1 = module.operation + assert ctx._get_live_operation_count() == 1 + num_invalidated = ctx._clear_live_operations() + assert num_invalidated == 1 + assert ctx._get_live_operation_count() == 0 + op1 = None + gc.collect() + op1 = module.operation + + # Ensure that if module is de-referenced, the operations are still valid. + module = None + gc.collect() + print(op1) + + # Collect and verify lifetime. + op1 = None + op2 = None + gc.collect() + print("LIVE OPERATIONS:", ctx._get_live_operation_count()) + assert ctx._get_live_operation_count() == 0 + assert ctx._get_live_module_count() == 0 # CHECK-LABEL: TEST: testModuleCapsule @run def testModuleCapsule(): - ctx = Context() - module = Module.parse(r"""module @successfulParse {}""", ctx) - assert ctx._get_live_module_count() == 1 - # CHECK: "mlir.ir.Module._CAPIPtr" - module_capsule = module._CAPIPtr - print(module_capsule) - module_dup = Module._CAPICreate(module_capsule) - assert module is module_dup - assert module_dup.context is ctx - # Gc and verify destructed. - module = None - module_capsule = None - module_dup = None - gc.collect() - assert ctx._get_live_module_count() == 0 - + ctx = Context() + module = Module.parse(r"""module @successfulParse {}""", ctx) + assert ctx._get_live_module_count() == 1 + # CHECK: "mlir.ir.Module._CAPIPtr" + module_capsule = module._CAPIPtr + print(module_capsule) + module_dup = Module._CAPICreate(module_capsule) + assert module is module_dup + assert module_dup.context is ctx + # Gc and verify destructed. + module = None + module_capsule = None + module_dup = None + gc.collect() + assert ctx._get_live_module_count() == 0 diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index 22a8089..639f8ff 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -8,232 +8,242 @@ from mlir.dialects.builtin import ModuleOp def run(f): - print("\nTEST:", f.__name__) - f() - gc.collect() - assert Context._get_live_count() == 0 - return f + print("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + return f def expect_index_error(callback): - try: - _ = callback() - raise RuntimeError("Expected IndexError") - except IndexError: - pass + try: + _ = callback() + raise RuntimeError("Expected IndexError") + except IndexError: + pass # Verify iterator based traversal of the op/region/block hierarchy. # CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators @run def testTraverseOpRegionBlockIterators(): - ctx = Context() - ctx.allow_unregistered_dialects = True - module = Module.parse( - r""" + ctx = Context() + ctx.allow_unregistered_dialects = True + module = Module.parse( + r""" func.func @f1(%arg0: i32) -> i32 { %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32 return %1 : i32 } - """, ctx) - op = module.operation - assert op.context is ctx - # Get the block using iterators off of the named collections. 
- regions = list(op.regions) - blocks = list(regions[0].blocks) - # CHECK: MODULE REGIONS=1 BLOCKS=1 - print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}") - - # Should verify. - # CHECK: .verify = True - print(f".verify = {module.operation.verify()}") - - # Get the blocks from the default collection. - default_blocks = list(regions[0]) - # They should compare equal regardless of how obtained. - assert default_blocks == blocks - - # Should be able to get the operations from either the named collection - # or the block. - operations = list(blocks[0].operations) - default_operations = list(blocks[0]) - assert default_operations == operations - - def walk_operations(indent, op): - for i, region in enumerate(op.regions): - print(f"{indent}REGION {i}:") - for j, block in enumerate(region): - print(f"{indent} BLOCK {j}:") - for k, child_op in enumerate(block): - print(f"{indent} OP {k}: {child_op}") - walk_operations(indent + " ", child_op) - - # CHECK: REGION 0: - # CHECK: BLOCK 0: - # CHECK: OP 0: func - # CHECK: REGION 0: - # CHECK: BLOCK 0: - # CHECK: OP 0: %0 = "custom.addi" - # CHECK: OP 1: func.return - walk_operations("", op) - - # CHECK: Region iter: i32 { %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32 return %1 : i32 } - """, ctx) - - def walk_operations(indent, op): - for i in range(len(op.regions)): - region = op.regions[i] - print(f"{indent}REGION {i}:") - for j in range(len(region.blocks)): - block = region.blocks[j] - print(f"{indent} BLOCK {j}:") - for k in range(len(block.operations)): - child_op = block.operations[k] - print(f"{indent} OP {k}: {child_op}") - print(f"{indent} OP {k}: parent {child_op.operation.parent.name}") - walk_operations(indent + " ", child_op) - - # CHECK: REGION 0: - # CHECK: BLOCK 0: - # CHECK: OP 0: func - # CHECK: OP 0: parent builtin.module - # CHECK: REGION 0: - # CHECK: BLOCK 0: - # CHECK: OP 0: %0 = "custom.addi" - # CHECK: OP 0: parent func.func - # CHECK: OP 1: func.return - # CHECK: OP 1: parent func.func - walk_operations("", module.operation) + """, + ctx, + ) + + def walk_operations(indent, op): + for i in range(len(op.regions)): + region = op.regions[i] + print(f"{indent}REGION {i}:") + for j in range(len(region.blocks)): + block = region.blocks[j] + print(f"{indent} BLOCK {j}:") + for k in range(len(block.operations)): + child_op = block.operations[k] + print(f"{indent} OP {k}: {child_op}") + print( + f"{indent} OP {k}: parent {child_op.operation.parent.name}" + ) + walk_operations(indent + " ", child_op) + + # CHECK: REGION 0: + # CHECK: BLOCK 0: + # CHECK: OP 0: func + # CHECK: OP 0: parent builtin.module + # CHECK: REGION 0: + # CHECK: BLOCK 0: + # CHECK: OP 0: %0 = "custom.addi" + # CHECK: OP 0: parent func.func + # CHECK: OP 1: func.return + # CHECK: OP 1: parent func.func + walk_operations("", module.operation) # CHECK-LABEL: TEST: testBlockAndRegionOwners @run def testBlockAndRegionOwners(): - ctx = Context() - ctx.allow_unregistered_dialects = True - module = Module.parse( - r""" + ctx = Context() + ctx.allow_unregistered_dialects = True + module = Module.parse( + r""" builtin.module { func.func @f() { func.return } } - """, ctx) + """, + ctx, + ) - assert module.operation.regions[0].owner == module.operation - assert module.operation.regions[0].blocks[0].owner == module.operation + assert module.operation.regions[0].owner == module.operation + assert module.operation.regions[0].blocks[0].owner == module.operation - func = module.body.operations[0] - assert func.operation.regions[0].owner == func - assert 
func.operation.regions[0].blocks[0].owner == func + func = module.body.operations[0] + assert func.operation.regions[0].owner == func + assert func.operation.regions[0].blocks[0].owner == func # CHECK-LABEL: TEST: testBlockArgumentList @run def testBlockArgumentList(): - with Context() as ctx: - module = Module.parse( - r""" + with Context() as ctx: + module = Module.parse( + r""" func.func @f1(%arg0: i32, %arg1: f64, %arg2: index) { return } - """, ctx) - func = module.body.operations[0] - entry_block = func.regions[0].blocks[0] - assert len(entry_block.arguments) == 3 - # CHECK: Argument 0, type i32 - # CHECK: Argument 1, type f64 - # CHECK: Argument 2, type index - for arg in entry_block.arguments: - print(f"Argument {arg.arg_number}, type {arg.type}") - new_type = IntegerType.get_signless(8 * (arg.arg_number + 1)) - arg.set_type(new_type) - - # CHECK: Argument 0, type i8 - # CHECK: Argument 1, type i16 - # CHECK: Argument 2, type i24 - for arg in entry_block.arguments: - print(f"Argument {arg.arg_number}, type {arg.type}") - - # Check that slicing works for block argument lists. - # CHECK: Argument 1, type i16 - # CHECK: Argument 2, type i24 - for arg in entry_block.arguments[1:]: - print(f"Argument {arg.arg_number}, type {arg.type}") - - # Check that we can concatenate slices of argument lists. - # CHECK: Length: 4 - print("Length: ", - len(entry_block.arguments[:2] + entry_block.arguments[1:])) - - # CHECK: Type: i8 - # CHECK: Type: i16 - # CHECK: Type: i24 - for t in entry_block.arguments.types: - print("Type: ", t) - - # Check that slicing and type access compose. - # CHECK: Sliced type: i16 - # CHECK: Sliced type: i24 - for t in entry_block.arguments[1:].types: - print("Sliced type: ", t) - - # Check that slice addition works as expected. - # CHECK: Argument 2, type i24 - # CHECK: Argument 0, type i8 - restructured = entry_block.arguments[-1:] + entry_block.arguments[:1] - for arg in restructured: - print(f"Argument {arg.arg_number}, type {arg.type}") + """, + ctx, + ) + func = module.body.operations[0] + entry_block = func.regions[0].blocks[0] + assert len(entry_block.arguments) == 3 + # CHECK: Argument 0, type i32 + # CHECK: Argument 1, type f64 + # CHECK: Argument 2, type index + for arg in entry_block.arguments: + print(f"Argument {arg.arg_number}, type {arg.type}") + new_type = IntegerType.get_signless(8 * (arg.arg_number + 1)) + arg.set_type(new_type) + + # CHECK: Argument 0, type i8 + # CHECK: Argument 1, type i16 + # CHECK: Argument 2, type i24 + for arg in entry_block.arguments: + print(f"Argument {arg.arg_number}, type {arg.type}") + + # Check that slicing works for block argument lists. + # CHECK: Argument 1, type i16 + # CHECK: Argument 2, type i24 + for arg in entry_block.arguments[1:]: + print(f"Argument {arg.arg_number}, type {arg.type}") + + # Check that we can concatenate slices of argument lists. + # CHECK: Length: 4 + print("Length: ", len(entry_block.arguments[:2] + entry_block.arguments[1:])) + + # CHECK: Type: i8 + # CHECK: Type: i16 + # CHECK: Type: i24 + for t in entry_block.arguments.types: + print("Type: ", t) + + # Check that slicing and type access compose. + # CHECK: Sliced type: i16 + # CHECK: Sliced type: i24 + for t in entry_block.arguments[1:].types: + print("Sliced type: ", t) + + # Check that slice addition works as expected. 
+ # CHECK: Argument 2, type i24 + # CHECK: Argument 0, type i8 + restructured = entry_block.arguments[-1:] + entry_block.arguments[:1] + for arg in restructured: + print(f"Argument {arg.arg_number}, type {arg.type}") # CHECK-LABEL: TEST: testOperationOperands @run def testOperationOperands(): - with Context() as ctx: - ctx.allow_unregistered_dialects = True - module = Module.parse(r""" + with Context() as ctx: + ctx.allow_unregistered_dialects = True + module = Module.parse( + r""" func.func @f1(%arg0: i32) { %0 = "test.producer"() : () -> i64 "test.consumer"(%arg0, %0) : (i32, i64) -> () return - }""") - func = module.body.operations[0] - entry_block = func.regions[0].blocks[0] - consumer = entry_block.operations[1] - assert len(consumer.operands) == 2 - # CHECK: Operand 0, type i32 - # CHECK: Operand 1, type i64 - for i, operand in enumerate(consumer.operands): - print(f"Operand {i}, type {operand.type}") - - + }""" + ) + func = module.body.operations[0] + entry_block = func.regions[0].blocks[0] + consumer = entry_block.operations[1] + assert len(consumer.operands) == 2 + # CHECK: Operand 0, type i32 + # CHECK: Operand 1, type i64 + for i, operand in enumerate(consumer.operands): + print(f"Operand {i}, type {operand.type}") # CHECK-LABEL: TEST: testOperationOperandsSlice @run def testOperationOperandsSlice(): - with Context() as ctx: - ctx.allow_unregistered_dialects = True - module = Module.parse(r""" + with Context() as ctx: + ctx.allow_unregistered_dialects = True + module = Module.parse( + r""" func.func @f1() { %0 = "test.producer0"() : () -> i64 %1 = "test.producer1"() : () -> i64 @@ -242,708 +252,727 @@ def testOperationOperandsSlice(): %4 = "test.producer4"() : () -> i64 "test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> () return - }""") - func = module.body.operations[0] - entry_block = func.regions[0].blocks[0] - consumer = entry_block.operations[5] - assert len(consumer.operands) == 5 - for left, right in zip(consumer.operands, consumer.operands[::-1][::-1]): - assert left == right - - # CHECK: test.producer0 - # CHECK: test.producer1 - # CHECK: test.producer2 - # CHECK: test.producer3 - # CHECK: test.producer4 - full_slice = consumer.operands[:] - for operand in full_slice: - print(operand) - - # CHECK: test.producer0 - # CHECK: test.producer1 - first_two = consumer.operands[0:2] - for operand in first_two: - print(operand) - - # CHECK: test.producer3 - # CHECK: test.producer4 - last_two = consumer.operands[3:] - for operand in last_two: - print(operand) - - # CHECK: test.producer0 - # CHECK: test.producer2 - # CHECK: test.producer4 - even = consumer.operands[::2] - for operand in even: - print(operand) - - # CHECK: test.producer2 - fourth = consumer.operands[::2][1::2] - for operand in fourth: - print(operand) - - + }""" + ) + func = module.body.operations[0] + entry_block = func.regions[0].blocks[0] + consumer = entry_block.operations[5] + assert len(consumer.operands) == 5 + for left, right in zip(consumer.operands, consumer.operands[::-1][::-1]): + assert left == right + + # CHECK: test.producer0 + # CHECK: test.producer1 + # CHECK: test.producer2 + # CHECK: test.producer3 + # CHECK: test.producer4 + full_slice = consumer.operands[:] + for operand in full_slice: + print(operand) + + # CHECK: test.producer0 + # CHECK: test.producer1 + first_two = consumer.operands[0:2] + for operand in first_two: + print(operand) + + # CHECK: test.producer3 + # CHECK: test.producer4 + last_two = consumer.operands[3:] + for operand in last_two: + print(operand) + + # CHECK: 
test.producer0 + # CHECK: test.producer2 + # CHECK: test.producer4 + even = consumer.operands[::2] + for operand in even: + print(operand) + + # CHECK: test.producer2 + fourth = consumer.operands[::2][1::2] + for operand in fourth: + print(operand) # CHECK-LABEL: TEST: testOperationOperandsSet @run def testOperationOperandsSet(): - with Context() as ctx, Location.unknown(ctx): - ctx.allow_unregistered_dialects = True - module = Module.parse(r""" + with Context() as ctx, Location.unknown(ctx): + ctx.allow_unregistered_dialects = True + module = Module.parse( + r""" func.func @f1() { %0 = "test.producer0"() : () -> i64 %1 = "test.producer1"() : () -> i64 %2 = "test.producer2"() : () -> i64 "test.consumer"(%0) : (i64) -> () return - }""") - func = module.body.operations[0] - entry_block = func.regions[0].blocks[0] - producer1 = entry_block.operations[1] - producer2 = entry_block.operations[2] - consumer = entry_block.operations[3] - assert len(consumer.operands) == 1 - type = consumer.operands[0].type - - # CHECK: test.producer1 - consumer.operands[0] = producer1.result - print(consumer.operands[0]) - - # CHECK: test.producer2 - consumer.operands[-1] = producer2.result - print(consumer.operands[0]) + }""" + ) + func = module.body.operations[0] + entry_block = func.regions[0].blocks[0] + producer1 = entry_block.operations[1] + producer2 = entry_block.operations[2] + consumer = entry_block.operations[3] + assert len(consumer.operands) == 1 + type = consumer.operands[0].type + # CHECK: test.producer1 + consumer.operands[0] = producer1.result + print(consumer.operands[0]) + # CHECK: test.producer2 + consumer.operands[-1] = producer2.result + print(consumer.operands[0]) # CHECK-LABEL: TEST: testDetachedOperation @run def testDetachedOperation(): - ctx = Context() - ctx.allow_unregistered_dialects = True - with Location.unknown(ctx): - i32 = IntegerType.get_signed(32) - op1 = Operation.create( - "custom.op1", - results=[i32, i32], - regions=1, - attributes={ - "foo": StringAttr.get("foo_value"), - "bar": StringAttr.get("bar_value"), - }) - # CHECK: %0:2 = "custom.op1"() ({ - # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32) - print(op1) - - # TODO: Check successors once enough infra exists to do it properly. + ctx = Context() + ctx.allow_unregistered_dialects = True + with Location.unknown(ctx): + i32 = IntegerType.get_signed(32) + op1 = Operation.create( + "custom.op1", + results=[i32, i32], + regions=1, + attributes={ + "foo": StringAttr.get("foo_value"), + "bar": StringAttr.get("bar_value"), + }, + ) + # CHECK: %0:2 = "custom.op1"() ({ + # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32) + print(op1) + + # TODO: Check successors once enough infra exists to do it properly. # CHECK-LABEL: TEST: testOperationInsertionPoint @run def testOperationInsertionPoint(): - ctx = Context() - ctx.allow_unregistered_dialects = True - module = Module.parse( - r""" + ctx = Context() + ctx.allow_unregistered_dialects = True + module = Module.parse( + r""" func.func @f1(%arg0: i32) -> i32 { %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32 return %1 : i32 } - """, ctx) - - # Create test op. 
- with Location.unknown(ctx): - op1 = Operation.create("custom.op1") - op2 = Operation.create("custom.op2") - - func = module.body.operations[0] - entry_block = func.regions[0].blocks[0] - ip = InsertionPoint.at_block_begin(entry_block) - ip.insert(op1) - ip.insert(op2) - # CHECK: func @f1 - # CHECK: "custom.op1"() - # CHECK: "custom.op2"() - # CHECK: %0 = "custom.addi" - print(module) - - # Trying to add a previously added op should raise. - try: - ip.insert(op1) - except ValueError: - pass - else: - assert False, "expected insert of attached op to raise" + """, + ctx, + ) + + # Create test op. + with Location.unknown(ctx): + op1 = Operation.create("custom.op1") + op2 = Operation.create("custom.op2") + + func = module.body.operations[0] + entry_block = func.regions[0].blocks[0] + ip = InsertionPoint.at_block_begin(entry_block) + ip.insert(op1) + ip.insert(op2) + # CHECK: func @f1 + # CHECK: "custom.op1"() + # CHECK: "custom.op2"() + # CHECK: %0 = "custom.addi" + print(module) + + # Trying to add a previously added op should raise. + try: + ip.insert(op1) + except ValueError: + pass + else: + assert False, "expected insert of attached op to raise" # CHECK-LABEL: TEST: testOperationWithRegion @run def testOperationWithRegion(): - ctx = Context() - ctx.allow_unregistered_dialects = True - with Location.unknown(ctx): - i32 = IntegerType.get_signed(32) - op1 = Operation.create("custom.op1", regions=1) - block = op1.regions[0].blocks.append(i32, i32) - # CHECK: "custom.op1"() ({ - # CHECK: ^bb0(%arg0: si32, %arg1: si32): - # CHECK: "custom.terminator"() : () -> () - # CHECK: }) : () -> () - terminator = Operation.create("custom.terminator") - ip = InsertionPoint(block) - ip.insert(terminator) - print(op1) - - # Now add the whole operation to another op. - # TODO: Verify lifetime hazard by nulling out the new owning module and - # accessing op1. - # TODO: Also verify accessing the terminator once both parents are nulled - # out. - module = Module.parse(r""" + ctx = Context() + ctx.allow_unregistered_dialects = True + with Location.unknown(ctx): + i32 = IntegerType.get_signed(32) + op1 = Operation.create("custom.op1", regions=1) + block = op1.regions[0].blocks.append(i32, i32) + # CHECK: "custom.op1"() ({ + # CHECK: ^bb0(%arg0: si32, %arg1: si32): + # CHECK: "custom.terminator"() : () -> () + # CHECK: }) : () -> () + terminator = Operation.create("custom.terminator") + ip = InsertionPoint(block) + ip.insert(terminator) + print(op1) + + # Now add the whole operation to another op. + # TODO: Verify lifetime hazard by nulling out the new owning module and + # accessing op1. + # TODO: Also verify accessing the terminator once both parents are nulled + # out. 
+ module = Module.parse( + r""" func.func @f1(%arg0: i32) -> i32 { %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32 return %1 : i32 } - """) - func = module.body.operations[0] - entry_block = func.regions[0].blocks[0] - ip = InsertionPoint.at_block_begin(entry_block) - ip.insert(op1) - # CHECK: func @f1 - # CHECK: "custom.op1"() - # CHECK: "custom.terminator" - # CHECK: %0 = "custom.addi" - print(module) + """ + ) + func = module.body.operations[0] + entry_block = func.regions[0].blocks[0] + ip = InsertionPoint.at_block_begin(entry_block) + ip.insert(op1) + # CHECK: func @f1 + # CHECK: "custom.op1"() + # CHECK: "custom.terminator" + # CHECK: %0 = "custom.addi" + print(module) # CHECK-LABEL: TEST: testOperationResultList @run def testOperationResultList(): - ctx = Context() - module = Module.parse( - r""" + ctx = Context() + module = Module.parse( + r""" func.func @f1() { %0:3 = call @f2() : () -> (i32, f64, index) return } func.func private @f2() -> (i32, f64, index) - """, ctx) - caller = module.body.operations[0] - call = caller.regions[0].blocks[0].operations[0] - assert len(call.results) == 3 - # CHECK: Result 0, type i32 - # CHECK: Result 1, type f64 - # CHECK: Result 2, type index - for res in call.results: - print(f"Result {res.result_number}, type {res.type}") - - # CHECK: Result type i32 - # CHECK: Result type f64 - # CHECK: Result type index - for t in call.results.types: - print(f"Result type {t}") - - # Out of range - expect_index_error(lambda: call.results[3]) - expect_index_error(lambda: call.results[-4]) + """, + ctx, + ) + caller = module.body.operations[0] + call = caller.regions[0].blocks[0].operations[0] + assert len(call.results) == 3 + # CHECK: Result 0, type i32 + # CHECK: Result 1, type f64 + # CHECK: Result 2, type index + for res in call.results: + print(f"Result {res.result_number}, type {res.type}") + + # CHECK: Result type i32 + # CHECK: Result type f64 + # CHECK: Result type index + for t in call.results.types: + print(f"Result type {t}") + + # Out of range + expect_index_error(lambda: call.results[3]) + expect_index_error(lambda: call.results[-4]) # CHECK-LABEL: TEST: testOperationResultListSlice @run def testOperationResultListSlice(): - with Context() as ctx: - ctx.allow_unregistered_dialects = True - module = Module.parse(r""" + with Context() as ctx: + ctx.allow_unregistered_dialects = True + module = Module.parse( + r""" func.func @f1() { "some.op"() : () -> (i1, i2, i3, i4, i5) return } - """) - func = module.body.operations[0] - entry_block = func.regions[0].blocks[0] - producer = entry_block.operations[0] - - assert len(producer.results) == 5 - for left, right in zip(producer.results, producer.results[::-1][::-1]): - assert left == right - assert left.result_number == right.result_number - - # CHECK: Result 0, type i1 - # CHECK: Result 1, type i2 - # CHECK: Result 2, type i3 - # CHECK: Result 3, type i4 - # CHECK: Result 4, type i5 - full_slice = producer.results[:] - for res in full_slice: - print(f"Result {res.result_number}, type {res.type}") - - # CHECK: Result 1, type i2 - # CHECK: Result 2, type i3 - # CHECK: Result 3, type i4 - middle = producer.results[1:4] - for res in middle: - print(f"Result {res.result_number}, type {res.type}") - - # CHECK: Result 1, type i2 - # CHECK: Result 3, type i4 - odd = producer.results[1::2] - for res in odd: - print(f"Result {res.result_number}, type {res.type}") - - # CHECK: Result 3, type i4 - # CHECK: Result 1, type i2 - inverted_middle = producer.results[-2:0:-2] - for res in inverted_middle: - 
print(f"Result {res.result_number}, type {res.type}") + """ + ) + func = module.body.operations[0] + entry_block = func.regions[0].blocks[0] + producer = entry_block.operations[0] + + assert len(producer.results) == 5 + for left, right in zip(producer.results, producer.results[::-1][::-1]): + assert left == right + assert left.result_number == right.result_number + + # CHECK: Result 0, type i1 + # CHECK: Result 1, type i2 + # CHECK: Result 2, type i3 + # CHECK: Result 3, type i4 + # CHECK: Result 4, type i5 + full_slice = producer.results[:] + for res in full_slice: + print(f"Result {res.result_number}, type {res.type}") + + # CHECK: Result 1, type i2 + # CHECK: Result 2, type i3 + # CHECK: Result 3, type i4 + middle = producer.results[1:4] + for res in middle: + print(f"Result {res.result_number}, type {res.type}") + + # CHECK: Result 1, type i2 + # CHECK: Result 3, type i4 + odd = producer.results[1::2] + for res in odd: + print(f"Result {res.result_number}, type {res.type}") + + # CHECK: Result 3, type i4 + # CHECK: Result 1, type i2 + inverted_middle = producer.results[-2:0:-2] + for res in inverted_middle: + print(f"Result {res.result_number}, type {res.type}") # CHECK-LABEL: TEST: testOperationAttributes @run def testOperationAttributes(): - ctx = Context() - ctx.allow_unregistered_dialects = True - module = Module.parse( - r""" + ctx = Context() + ctx.allow_unregistered_dialects = True + module = Module.parse( + r""" "some.op"() { some.attribute = 1 : i8, other.attribute = 3.0, dependent = "text" } : () -> () - """, ctx) - op = module.body.operations[0] - assert len(op.attributes) == 3 - iattr = IntegerAttr(op.attributes["some.attribute"]) - fattr = FloatAttr(op.attributes["other.attribute"]) - sattr = StringAttr(op.attributes["dependent"]) - # CHECK: Attribute type i8, value 1 - print(f"Attribute type {iattr.type}, value {iattr.value}") - # CHECK: Attribute type f64, value 3.0 - print(f"Attribute type {fattr.type}, value {fattr.value}") - # CHECK: Attribute value text - print(f"Attribute value {sattr.value}") - # CHECK: Attribute value b'text' - print(f"Attribute value {sattr.value_bytes}") - - # We don't know in which order the attributes are stored. - # CHECK-DAG: NamedAttribute(dependent="text") - # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64) - # CHECK-DAG: NamedAttribute(some.attribute=1 : i8) - for attr in op.attributes: - print(str(attr)) - - # Check that exceptions are raised as expected. - try: - op.attributes["does_not_exist"] - except KeyError: - pass - else: - assert False, "expected KeyError on accessing a non-existent attribute" - - try: - op.attributes[42] - except IndexError: - pass - else: - assert False, "expected IndexError on accessing an out-of-bounds attribute" - + """, + ctx, + ) + op = module.body.operations[0] + assert len(op.attributes) == 3 + iattr = IntegerAttr(op.attributes["some.attribute"]) + fattr = FloatAttr(op.attributes["other.attribute"]) + sattr = StringAttr(op.attributes["dependent"]) + # CHECK: Attribute type i8, value 1 + print(f"Attribute type {iattr.type}, value {iattr.value}") + # CHECK: Attribute type f64, value 3.0 + print(f"Attribute type {fattr.type}, value {fattr.value}") + # CHECK: Attribute value text + print(f"Attribute value {sattr.value}") + # CHECK: Attribute value b'text' + print(f"Attribute value {sattr.value_bytes}") + + # We don't know in which order the attributes are stored. 
+ # CHECK-DAG: NamedAttribute(dependent="text") + # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64) + # CHECK-DAG: NamedAttribute(some.attribute=1 : i8) + for attr in op.attributes: + print(str(attr)) + + # Check that exceptions are raised as expected. + try: + op.attributes["does_not_exist"] + except KeyError: + pass + else: + assert False, "expected KeyError on accessing a non-existent attribute" + try: + op.attributes[42] + except IndexError: + pass + else: + assert False, "expected IndexError on accessing an out-of-bounds attribute" # CHECK-LABEL: TEST: testOperationPrint @run def testOperationPrint(): - ctx = Context() - module = Module.parse( - r""" + ctx = Context() + module = Module.parse( + r""" func.func @f1(%arg0: i32) -> i32 { %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom") return %arg0 : i32 } - """, ctx) - - # Test print to stdout. - # CHECK: return %arg0 : i32 - module.operation.print() - - # Test print to text file. - f = io.StringIO() - # CHECK: - # CHECK: return %arg0 : i32 - module.operation.print(file=f) - str_value = f.getvalue() - print(str_value.__class__) - print(f.getvalue()) - - # Test roundtrip to bytecode. - bytecode_stream = io.BytesIO() - module.operation.write_bytecode(bytecode_stream, desired_version=1) - bytecode = bytecode_stream.getvalue() - assert bytecode.startswith(b'ML\xefR'), "Expected bytecode to start with MLïR" - module_roundtrip = Module.parse(bytecode, ctx) - f = io.StringIO() - module_roundtrip.operation.print(file=f) - roundtrip_value = f.getvalue() - assert str_value == roundtrip_value, "Mismatch after roundtrip bytecode" - - - # Test print to binary file. - f = io.BytesIO() - # CHECK: - # CHECK: return %arg0 : i32 - module.operation.print(file=f, binary=True) - bytes_value = f.getvalue() - print(bytes_value.__class__) - print(bytes_value) - - # Test get_asm local_scope. - # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom") - module.operation.print(enable_debug_info=True, use_local_scope=True) - - # Test get_asm with options. - # CHECK: value = dense_resource<__elided__> : tensor<4xi32> - # CHECK: "func.return"(%arg0) : (i32) -> () -:4:7 - module.operation.print( - large_elements_limit=2, - enable_debug_info=True, - pretty_debug_info=True, - print_generic_op_form=True, - use_local_scope=True) - - + """, + ctx, + ) + + # Test print to stdout. + # CHECK: return %arg0 : i32 + module.operation.print() + + # Test print to text file. + f = io.StringIO() + # CHECK: + # CHECK: return %arg0 : i32 + module.operation.print(file=f) + str_value = f.getvalue() + print(str_value.__class__) + print(f.getvalue()) + + # Test roundtrip to bytecode. + bytecode_stream = io.BytesIO() + module.operation.write_bytecode(bytecode_stream, desired_version=1) + bytecode = bytecode_stream.getvalue() + assert bytecode.startswith(b"ML\xefR"), "Expected bytecode to start with MLïR" + module_roundtrip = Module.parse(bytecode, ctx) + f = io.StringIO() + module_roundtrip.operation.print(file=f) + roundtrip_value = f.getvalue() + assert str_value == roundtrip_value, "Mismatch after roundtrip bytecode" + + # Test print to binary file. + f = io.BytesIO() + # CHECK: + # CHECK: return %arg0 : i32 + module.operation.print(file=f, binary=True) + bytes_value = f.getvalue() + print(bytes_value.__class__) + print(bytes_value) + + # Test get_asm local_scope. + # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom") + module.operation.print(enable_debug_info=True, use_local_scope=True) + + # Test get_asm with options. 
+ # CHECK: value = dense_resource<__elided__> : tensor<4xi32> + # CHECK: "func.return"(%arg0) : (i32) -> () -:4:7 + module.operation.print( + large_elements_limit=2, + enable_debug_info=True, + pretty_debug_info=True, + print_generic_op_form=True, + use_local_scope=True, + ) # CHECK-LABEL: TEST: testKnownOpView @run def testKnownOpView(): - with Context(), Location.unknown(): - Context.current.allow_unregistered_dialects = True - module = Module.parse(r""" + with Context(), Location.unknown(): + Context.current.allow_unregistered_dialects = True + module = Module.parse( + r""" %1 = "custom.f32"() : () -> f32 %2 = "custom.f32"() : () -> f32 %3 = arith.addf %1, %2 : f32 - """) - print(module) + """ + ) + print(module) - # addf should map to a known OpView class in the arithmetic dialect. - # We know the OpView for it defines an 'lhs' attribute. - addf = module.body.operations[2] - # CHECK: () %0:2 = "custom.two_result"() : () -> (f32, f32) %1 = "custom.one_result"() : () -> f32 - """) - print(module) + """ + ) + print(module) - try: - module.body.operations[0].result - except ValueError as e: - # CHECK: Cannot call .result on operation custom.no_result which has 0 results - print(e) - else: - assert False, "Expected exception" + try: + module.body.operations[0].result + except ValueError as e: + # CHECK: Cannot call .result on operation custom.no_result which has 0 results + print(e) + else: + assert False, "Expected exception" - try: - module.body.operations[1].result - except ValueError as e: - # CHECK: Cannot call .result on operation custom.two_result which has 2 results - print(e) - else: - assert False, "Expected exception" + try: + module.body.operations[1].result + except ValueError as e: + # CHECK: Cannot call .result on operation custom.two_result which has 2 results + print(e) + else: + assert False, "Expected exception" - # CHECK: %1 = "custom.one_result"() : () -> f32 - print(module.body.operations[2]) + # CHECK: %1 = "custom.one_result"() : () -> f32 + print(module.body.operations[2]) def create_invalid_operation(): - # This module has two region and is invalid verify that we fallback - # to the generic printer for safety. - op = Operation.create("builtin.module", regions=2) - op.regions[0].blocks.append() - return op + # This module has two region and is invalid verify that we fallback + # to the generic printer for safety. + op = Operation.create("builtin.module", regions=2) + op.regions[0].blocks.append() + return op + # CHECK-LABEL: TEST: testInvalidOperationStrSoftFails @run def testInvalidOperationStrSoftFails(): - ctx = Context() - with Location.unknown(ctx): - invalid_op = create_invalid_operation() - # Verify that we fallback to the generic printer for safety. - # CHECK: "builtin.module"() ({ - # CHECK: }) : () -> () - print(invalid_op) - try: - invalid_op.verify() - except MLIRError as e: - # CHECK: Exception: < - # CHECK: Verification failed: - # CHECK: error: unknown: 'builtin.module' op requires one region - # CHECK: note: unknown: see current operation: - # CHECK: "builtin.module"() ({ - # CHECK: ^bb0: - # CHECK: }, { - # CHECK: }) : () -> () - # CHECK: > - print(f"Exception: <{e}>") + ctx = Context() + with Location.unknown(ctx): + invalid_op = create_invalid_operation() + # Verify that we fallback to the generic printer for safety. 
+ # CHECK: "builtin.module"() ({ + # CHECK: }) : () -> () + print(invalid_op) + try: + invalid_op.verify() + except MLIRError as e: + # CHECK: Exception: < + # CHECK: Verification failed: + # CHECK: error: unknown: 'builtin.module' op requires one region + # CHECK: note: unknown: see current operation: + # CHECK: "builtin.module"() ({ + # CHECK: ^bb0: + # CHECK: }, { + # CHECK: }) : () -> () + # CHECK: > + print(f"Exception: <{e}>") # CHECK-LABEL: TEST: testInvalidModuleStrSoftFails @run def testInvalidModuleStrSoftFails(): - ctx = Context() - with Location.unknown(ctx): - module = Module.create() - with InsertionPoint(module.body): - invalid_op = create_invalid_operation() - # Verify that we fallback to the generic printer for safety. - # CHECK: "builtin.module"() ({ - # CHECK: }) : () -> () - print(module) + ctx = Context() + with Location.unknown(ctx): + module = Module.create() + with InsertionPoint(module.body): + invalid_op = create_invalid_operation() + # Verify that we fallback to the generic printer for safety. + # CHECK: "builtin.module"() ({ + # CHECK: }) : () -> () + print(module) # CHECK-LABEL: TEST: testInvalidOperationGetAsmBinarySoftFails @run def testInvalidOperationGetAsmBinarySoftFails(): - ctx = Context() - with Location.unknown(ctx): - invalid_op = create_invalid_operation() - # Verify that we fallback to the generic printer for safety. - # CHECK: b'"builtin.module"() ({\n^bb0:\n}, {\n}) : () -> ()\n' - print(invalid_op.get_asm(binary=True)) + ctx = Context() + with Location.unknown(ctx): + invalid_op = create_invalid_operation() + # Verify that we fallback to the generic printer for safety. + # CHECK: b'"builtin.module"() ({\n^bb0:\n}, {\n}) : () -> ()\n' + print(invalid_op.get_asm(binary=True)) # CHECK-LABEL: TEST: testCreateWithInvalidAttributes @run def testCreateWithInvalidAttributes(): - ctx = Context() - with Location.unknown(ctx): - try: - Operation.create( - "builtin.module", attributes={None: StringAttr.get("name")}) - except Exception as e: - # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module" - print(e) - try: - Operation.create( - "builtin.module", attributes={42: StringAttr.get("name")}) - except Exception as e: - # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module" - print(e) - try: - Operation.create("builtin.module", attributes={"some_key": ctx}) - except Exception as e: - # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module" - print(e) - try: - Operation.create("builtin.module", attributes={"some_key": None}) - except Exception as e: - # CHECK: Found an invalid (`None`?) 
attribute value for the key "some_key" when attempting to create the operation "builtin.module" - print(e) + ctx = Context() + with Location.unknown(ctx): + try: + Operation.create( + "builtin.module", attributes={None: StringAttr.get("name")} + ) + except Exception as e: + # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module" + print(e) + try: + Operation.create("builtin.module", attributes={42: StringAttr.get("name")}) + except Exception as e: + # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module" + print(e) + try: + Operation.create("builtin.module", attributes={"some_key": ctx}) + except Exception as e: + # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module" + print(e) + try: + Operation.create("builtin.module", attributes={"some_key": None}) + except Exception as e: + # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module" + print(e) # CHECK-LABEL: TEST: testOperationName @run def testOperationName(): - ctx = Context() - ctx.allow_unregistered_dialects = True - module = Module.parse( - r""" + ctx = Context() + ctx.allow_unregistered_dialects = True + module = Module.parse( + r""" %0 = "custom.op1"() : () -> f32 %1 = "custom.op2"() : () -> i32 %2 = "custom.op1"() : () -> f32 - """, ctx) + """, + ctx, + ) - # CHECK: custom.op1 - # CHECK: custom.op2 - # CHECK: custom.op1 - for op in module.body.operations: - print(op.operation.name) + # CHECK: custom.op1 + # CHECK: custom.op2 + # CHECK: custom.op1 + for op in module.body.operations: + print(op.operation.name) # CHECK-LABEL: TEST: testCapsuleConversions @run def testCapsuleConversions(): - ctx = Context() - ctx.allow_unregistered_dialects = True - with Location.unknown(ctx): - m = Operation.create("custom.op1").operation - m_capsule = m._CAPIPtr - assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule) - m2 = Operation._CAPICreate(m_capsule) - assert m2 is m + ctx = Context() + ctx.allow_unregistered_dialects = True + with Location.unknown(ctx): + m = Operation.create("custom.op1").operation + m_capsule = m._CAPIPtr + assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule) + m2 = Operation._CAPICreate(m_capsule) + assert m2 is m # CHECK-LABEL: TEST: testOperationErase @run def testOperationErase(): - ctx = Context() - ctx.allow_unregistered_dialects = True - with Location.unknown(ctx): - m = Module.create() - with InsertionPoint(m.body): - op = Operation.create("custom.op1") + ctx = Context() + ctx.allow_unregistered_dialects = True + with Location.unknown(ctx): + m = Module.create() + with InsertionPoint(m.body): + op = Operation.create("custom.op1") - # CHECK: "custom.op1" - print(m) + # CHECK: "custom.op1" + print(m) - op.operation.erase() + op.operation.erase() - # CHECK-NOT: "custom.op1" - print(m) + # CHECK-NOT: "custom.op1" + print(m) - # Ensure we can create another operation - Operation.create("custom.op2") + # Ensure we can create another operation + Operation.create("custom.op2") # CHECK-LABEL: TEST: testOperationClone @run def testOperationClone(): - ctx = Context() - ctx.allow_unregistered_dialects = True - with Location.unknown(ctx): - m = Module.create() - with InsertionPoint(m.body): - op = Operation.create("custom.op1") + ctx = Context() + ctx.allow_unregistered_dialects = True + with Location.unknown(ctx): + m = Module.create() + with InsertionPoint(m.body): + op = 
Operation.create("custom.op1") - # CHECK: "custom.op1" - print(m) + # CHECK: "custom.op1" + print(m) - clone = op.operation.clone() - op.operation.erase() + clone = op.operation.clone() + op.operation.erase() - # CHECK: "custom.op1" - print(m) + # CHECK: "custom.op1" + print(m) # CHECK-LABEL: TEST: testOperationLoc @run def testOperationLoc(): - ctx = Context() - ctx.allow_unregistered_dialects = True - with ctx: - loc = Location.name("loc") - op = Operation.create("custom.op", loc=loc) - assert op.location == loc - assert op.operation.location == loc + ctx = Context() + ctx.allow_unregistered_dialects = True + with ctx: + loc = Location.name("loc") + op = Operation.create("custom.op", loc=loc) + assert op.location == loc + assert op.operation.location == loc # CHECK-LABEL: TEST: testModuleMerge @run def testModuleMerge(): - with Context(): - m1 = Module.parse("func.func private @foo()") - m2 = Module.parse(""" + with Context(): + m1 = Module.parse("func.func private @foo()") + m2 = Module.parse( + """ func.func private @bar() func.func private @qux() - """) - foo = m1.body.operations[0] - bar = m2.body.operations[0] - qux = m2.body.operations[1] - bar.move_before(foo) - qux.move_after(foo) + """ + ) + foo = m1.body.operations[0] + bar = m2.body.operations[0] + qux = m2.body.operations[1] + bar.move_before(foo) + qux.move_after(foo) - # CHECK: module - # CHECK: func private @bar - # CHECK: func private @foo - # CHECK: func private @qux - print(m1) + # CHECK: module + # CHECK: func private @bar + # CHECK: func private @foo + # CHECK: func private @qux + print(m1) - # CHECK: module { - # CHECK-NEXT: } - print(m2) + # CHECK: module { + # CHECK-NEXT: } + print(m2) # CHECK-LABEL: TEST: testAppendMoveFromAnotherBlock @run def testAppendMoveFromAnotherBlock(): - with Context(): - m1 = Module.parse("func.func private @foo()") - m2 = Module.parse("func.func private @bar()") - func = m1.body.operations[0] - m2.body.append(func) + with Context(): + m1 = Module.parse("func.func private @foo()") + m2 = Module.parse("func.func private @bar()") + func = m1.body.operations[0] + m2.body.append(func) - # CHECK: module - # CHECK: func private @bar - # CHECK: func private @foo + # CHECK: module + # CHECK: func private @bar + # CHECK: func private @foo - print(m2) - # CHECK: module { - # CHECK-NEXT: } - print(m1) + print(m2) + # CHECK: module { + # CHECK-NEXT: } + print(m1) # CHECK-LABEL: TEST: testDetachFromParent @run def testDetachFromParent(): - with Context(): - m1 = Module.parse("func.func private @foo()") - func = m1.body.operations[0].detach_from_parent() + with Context(): + m1 = Module.parse("func.func private @foo()") + func = m1.body.operations[0].detach_from_parent() - try: - func.detach_from_parent() - except ValueError as e: - if "has no parent" not in str(e): - raise - else: - assert False, "expected ValueError when detaching a detached operation" + try: + func.detach_from_parent() + except ValueError as e: + if "has no parent" not in str(e): + raise + else: + assert False, "expected ValueError when detaching a detached operation" - print(m1) - # CHECK-NOT: func private @foo + print(m1) + # CHECK-NOT: func private @foo # CHECK-LABEL: TEST: testOperationHash @run def testOperationHash(): - ctx = Context() - ctx.allow_unregistered_dialects = True - with ctx, Location.unknown(): - op = Operation.create("custom.op1") - assert hash(op) == hash(op.operation) + ctx = Context() + ctx.allow_unregistered_dialects = True + with ctx, Location.unknown(): + op = Operation.create("custom.op1") + assert 
hash(op) == hash(op.operation) # CHECK-LABEL: TEST: testOperationParse @run def testOperationParse(): - with Context() as ctx: - ctx.allow_unregistered_dialects = True - - # Generic operation parsing. - m = Operation.parse('module {}') - o = Operation.parse('"test.foo"() : () -> ()') - assert isinstance(m, ModuleOp) - assert type(o) is OpView - - # Parsing specific operation. - m = ModuleOp.parse('module {}') - assert isinstance(m, ModuleOp) - try: - ModuleOp.parse('"test.foo"() : () -> ()') - except MLIRError as e: - # CHECK: error: Expected a 'builtin.module' op, got: 'test.foo' - print(f"error: {e}") - else: - assert False, "expected error" - - o = Operation.parse('"test.foo"() : () -> ()', source_name="my-source-string") - # CHECK: op_with_source_name: "test.foo"() : () -> () loc("my-source-string":1:1) - print(f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}") + with Context() as ctx: + ctx.allow_unregistered_dialects = True + + # Generic operation parsing. + m = Operation.parse("module {}") + o = Operation.parse('"test.foo"() : () -> ()') + assert isinstance(m, ModuleOp) + assert type(o) is OpView + + # Parsing specific operation. + m = ModuleOp.parse("module {}") + assert isinstance(m, ModuleOp) + try: + ModuleOp.parse('"test.foo"() : () -> ()') + except MLIRError as e: + # CHECK: error: Expected a 'builtin.module' op, got: 'test.foo' + print(f"error: {e}") + else: + assert False, "expected error" + + o = Operation.parse('"test.foo"() : () -> ()', source_name="my-source-string") + # CHECK: op_with_source_name: "test.foo"() : () -> () loc("my-source-string":1:1) + print( + f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}" + ) diff --git a/mlir/test/python/ir/symbol_table.py b/mlir/test/python/ir/symbol_table.py index 9ce8959..17f3e35 100644 --- a/mlir/test/python/ir/symbol_table.py +++ b/mlir/test/python/ir/symbol_table.py @@ -7,150 +7,162 @@ from mlir.ir import * def run(f): - print("\nTEST:", f.__name__) - f() - gc.collect() - assert Context._get_live_count() == 0 - return f + print("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + return f # CHECK-LABEL: TEST: testSymbolTableInsert @run def testSymbolTableInsert(): - with Context() as ctx: - ctx.allow_unregistered_dialects = True - m1 = Module.parse(""" + with Context() as ctx: + ctx.allow_unregistered_dialects = True + m1 = Module.parse( + """ func.func private @foo() - func.func private @bar()""") - m2 = Module.parse(""" + func.func private @bar()""" + ) + m2 = Module.parse( + """ func.func private @qux() func.func private @foo() - "foo.bar"() : () -> ()""") - - symbol_table = SymbolTable(m1.operation) - - # CHECK: func private @foo - # CHECK: func private @bar - assert "foo" in symbol_table - print(symbol_table["foo"]) - assert "bar" in symbol_table - bar = symbol_table["bar"] - print(symbol_table["bar"]) - - assert "qux" not in symbol_table - - del symbol_table["bar"] - try: - symbol_table.erase(symbol_table["bar"]) - except KeyError: - pass - else: - assert False, "expected KeyError" - - # CHECK: module - # CHECK: func private @foo() - print(m1) - assert "bar" not in symbol_table - - try: - print(bar) - except RuntimeError as e: - if "the operation has been invalidated" not in str(e): - raise - else: - assert False, "expected RuntimeError due to invalidated operation" - - qux = m2.body.operations[0] - m1.body.append(qux) - symbol_table.insert(qux) - assert "qux" in symbol_table - - # Check that insertion actually renames this 
symbol in the symbol table. - foo2 = m2.body.operations[0] - m1.body.append(foo2) - updated_name = symbol_table.insert(foo2) - assert foo2.name.value != "foo" - assert foo2.name == updated_name - - # CHECK: module - # CHECK: func private @foo() - # CHECK: func private @qux() - # CHECK: func private @foo{{.*}} - print(m1) - - try: - symbol_table.insert(m2.body.operations[0]) - except ValueError as e: - if "Expected operation to have a symbol name" not in str(e): - raise - else: - assert False, "exepcted ValueError when adding a non-symbol" + "foo.bar"() : () -> ()""" + ) + + symbol_table = SymbolTable(m1.operation) + + # CHECK: func private @foo + # CHECK: func private @bar + assert "foo" in symbol_table + print(symbol_table["foo"]) + assert "bar" in symbol_table + bar = symbol_table["bar"] + print(symbol_table["bar"]) + + assert "qux" not in symbol_table + + del symbol_table["bar"] + try: + symbol_table.erase(symbol_table["bar"]) + except KeyError: + pass + else: + assert False, "expected KeyError" + + # CHECK: module + # CHECK: func private @foo() + print(m1) + assert "bar" not in symbol_table + + try: + print(bar) + except RuntimeError as e: + if "the operation has been invalidated" not in str(e): + raise + else: + assert False, "expected RuntimeError due to invalidated operation" + + qux = m2.body.operations[0] + m1.body.append(qux) + symbol_table.insert(qux) + assert "qux" in symbol_table + + # Check that insertion actually renames this symbol in the symbol table. + foo2 = m2.body.operations[0] + m1.body.append(foo2) + updated_name = symbol_table.insert(foo2) + assert foo2.name.value != "foo" + assert foo2.name == updated_name + + # CHECK: module + # CHECK: func private @foo() + # CHECK: func private @qux() + # CHECK: func private @foo{{.*}} + print(m1) + + try: + symbol_table.insert(m2.body.operations[0]) + except ValueError as e: + if "Expected operation to have a symbol name" not in str(e): + raise + else: + assert False, "exepcted ValueError when adding a non-symbol" # CHECK-LABEL: testSymbolTableRAUW @run def testSymbolTableRAUW(): - with Context() as ctx: - m = Module.parse(""" + with Context() as ctx: + m = Module.parse( + """ func.func private @foo() { call @bar() : () -> () return } func.func private @bar() - """) - foo, bar = list(m.operation.regions[0].blocks[0].operations)[0:2] - SymbolTable.set_symbol_name(bar, "bam") - # Note that module.operation counts as a "nested symbol table" which won't - # be traversed into, so it is necessary to traverse its children. - SymbolTable.replace_all_symbol_uses("bar", "bam", foo) - # CHECK: call @bam() - # CHECK: func private @bam - print(m) - # CHECK: Foo symbol: "foo" - # CHECK: Bar symbol: "bam" - print(f"Foo symbol: {SymbolTable.get_symbol_name(foo)}") - print(f"Bar symbol: {SymbolTable.get_symbol_name(bar)}") + """ + ) + foo, bar = list(m.operation.regions[0].blocks[0].operations)[0:2] + SymbolTable.set_symbol_name(bar, "bam") + # Note that module.operation counts as a "nested symbol table" which won't + # be traversed into, so it is necessary to traverse its children. 
+ SymbolTable.replace_all_symbol_uses("bar", "bam", foo) + # CHECK: call @bam() + # CHECK: func private @bam + print(m) + # CHECK: Foo symbol: "foo" + # CHECK: Bar symbol: "bam" + print(f"Foo symbol: {SymbolTable.get_symbol_name(foo)}") + print(f"Bar symbol: {SymbolTable.get_symbol_name(bar)}") # CHECK-LABEL: testSymbolTableVisibility @run def testSymbolTableVisibility(): - with Context() as ctx: - m = Module.parse(""" + with Context() as ctx: + m = Module.parse( + """ func.func private @foo() { return } - """) - foo = m.operation.regions[0].blocks[0].operations[0] - # CHECK: Existing visibility: "private" - print(f"Existing visibility: {SymbolTable.get_visibility(foo)}") - SymbolTable.set_visibility(foo, "public") - # CHECK: func public @foo - print(m) + """ + ) + foo = m.operation.regions[0].blocks[0].operations[0] + # CHECK: Existing visibility: "private" + print(f"Existing visibility: {SymbolTable.get_visibility(foo)}") + SymbolTable.set_visibility(foo, "public") + # CHECK: func public @foo + print(m) # CHECK: testWalkSymbolTables @run def testWalkSymbolTables(): - with Context() as ctx: - m = Module.parse(""" + with Context() as ctx: + m = Module.parse( + """ module @outer { module @inner{ } } - """) - def callback(symbol_table_op, uses_visible): - print(f"SYMBOL TABLE: {uses_visible}: {symbol_table_op}") - # CHECK: SYMBOL TABLE: True: module @inner - # CHECK: SYMBOL TABLE: True: module @outer - SymbolTable.walk_symbol_tables(m.operation, True, callback) - - # Make sure exceptions in the callback are handled. - def error_callback(symbol_table_op, uses_visible): - assert False, "Raised from python" - try: - SymbolTable.walk_symbol_tables(m.operation, True, error_callback) - except RuntimeError as e: - # CHECK: GOT EXCEPTION: Exception raised in callback: AssertionError: Raised from python - print(f"GOT EXCEPTION: {e}") + """ + ) + def callback(symbol_table_op, uses_visible): + print(f"SYMBOL TABLE: {uses_visible}: {symbol_table_op}") + + # CHECK: SYMBOL TABLE: True: module @inner + # CHECK: SYMBOL TABLE: True: module @outer + SymbolTable.walk_symbol_tables(m.operation, True, callback) + + # Make sure exceptions in the callback are handled. 
+ def error_callback(symbol_table_op, uses_visible): + assert False, "Raised from python" + + try: + SymbolTable.walk_symbol_tables(m.operation, True, error_callback) + except RuntimeError as e: + # CHECK: GOT EXCEPTION: Exception raised in callback: AssertionError: Raised from python + print(f"GOT EXCEPTION: {e}") diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py index 66568c4..8a2ada1 100644 --- a/mlir/test/python/ir/value.py +++ b/mlir/test/python/ir/value.py @@ -6,229 +6,235 @@ from mlir.dialects import func def run(f): - print("\nTEST:", f.__name__) - f() - gc.collect() - assert Context._get_live_count() == 0 - return f + print("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + return f # CHECK-LABEL: TEST: testCapsuleConversions @run def testCapsuleConversions(): - ctx = Context() - ctx.allow_unregistered_dialects = True - with Location.unknown(ctx): - i32 = IntegerType.get_signless(32) - value = Operation.create("custom.op1", results=[i32]).result - value_capsule = value._CAPIPtr - assert '"mlir.ir.Value._CAPIPtr"' in repr(value_capsule) - value2 = Value._CAPICreate(value_capsule) - assert value2 == value + ctx = Context() + ctx.allow_unregistered_dialects = True + with Location.unknown(ctx): + i32 = IntegerType.get_signless(32) + value = Operation.create("custom.op1", results=[i32]).result + value_capsule = value._CAPIPtr + assert '"mlir.ir.Value._CAPIPtr"' in repr(value_capsule) + value2 = Value._CAPICreate(value_capsule) + assert value2 == value # CHECK-LABEL: TEST: testOpResultOwner @run def testOpResultOwner(): - ctx = Context() - ctx.allow_unregistered_dialects = True - with Location.unknown(ctx): - i32 = IntegerType.get_signless(32) - op = Operation.create("custom.op1", results=[i32]) - assert op.result.owner == op + ctx = Context() + ctx.allow_unregistered_dialects = True + with Location.unknown(ctx): + i32 = IntegerType.get_signless(32) + op = Operation.create("custom.op1", results=[i32]) + assert op.result.owner == op # CHECK-LABEL: TEST: testBlockArgOwner @run def testBlockArgOwner(): - ctx = Context() - ctx.allow_unregistered_dialects = True - module = Module.parse( - r""" + ctx = Context() + ctx.allow_unregistered_dialects = True + module = Module.parse( + r""" func.func @foo(%arg0: f32) { return - }""", ctx) - func = module.body.operations[0] - block = func.regions[0].blocks[0] - assert block.arguments[0].owner == block + }""", + ctx, + ) + func = module.body.operations[0] + block = func.regions[0].blocks[0] + assert block.arguments[0].owner == block # CHECK-LABEL: TEST: testValueIsInstance @run def testValueIsInstance(): - ctx = Context() - ctx.allow_unregistered_dialects = True - module = Module.parse( - r""" + ctx = Context() + ctx.allow_unregistered_dialects = True + module = Module.parse( + r""" func.func @foo(%arg0: f32) { %0 = "some_dialect.some_op"() : () -> f64 return - }""", ctx) - func = module.body.operations[0] - assert BlockArgument.isinstance(func.regions[0].blocks[0].arguments[0]) - assert not OpResult.isinstance(func.regions[0].blocks[0].arguments[0]) + }""", + ctx, + ) + func = module.body.operations[0] + assert BlockArgument.isinstance(func.regions[0].blocks[0].arguments[0]) + assert not OpResult.isinstance(func.regions[0].blocks[0].arguments[0]) - op = func.regions[0].blocks[0].operations[0] - assert not BlockArgument.isinstance(op.results[0]) - assert OpResult.isinstance(op.results[0]) + op = func.regions[0].blocks[0].operations[0] + assert not BlockArgument.isinstance(op.results[0]) + 
assert OpResult.isinstance(op.results[0]) # CHECK-LABEL: TEST: testValueHash @run def testValueHash(): - ctx = Context() - ctx.allow_unregistered_dialects = True - module = Module.parse( - r""" + ctx = Context() + ctx.allow_unregistered_dialects = True + module = Module.parse( + r""" func.func @foo(%arg0: f32) -> f32 { %0 = "some_dialect.some_op"(%arg0) : (f32) -> f32 return %0 : f32 - }""", ctx) + }""", + ctx, + ) - [func] = module.body.operations - block = func.entry_block - op, ret = block.operations - assert hash(block.arguments[0]) == hash(op.operands[0]) - assert hash(op.result) == hash(ret.operands[0]) + [func] = module.body.operations + block = func.entry_block + op, ret = block.operations + assert hash(block.arguments[0]) == hash(op.operands[0]) + assert hash(op.result) == hash(ret.operands[0]) # CHECK-LABEL: TEST: testValueUses @run def testValueUses(): - ctx = Context() - ctx.allow_unregistered_dialects = True - with Location.unknown(ctx): - i32 = IntegerType.get_signless(32) - module = Module.create() - with InsertionPoint(module.body): - value = Operation.create("custom.op1", results=[i32]).results[0] - op1 = Operation.create("custom.op2", operands=[value]) - op2 = Operation.create("custom.op2", operands=[value]) - - # CHECK: Use owner: "custom.op2" - # CHECK: Use operand_number: 0 - # CHECK: Use owner: "custom.op2" - # CHECK: Use operand_number: 0 - for use in value.uses: - assert use.owner in [op1, op2] - print(f"Use owner: {use.owner}") - print(f"Use operand_number: {use.operand_number}") + ctx = Context() + ctx.allow_unregistered_dialects = True + with Location.unknown(ctx): + i32 = IntegerType.get_signless(32) + module = Module.create() + with InsertionPoint(module.body): + value = Operation.create("custom.op1", results=[i32]).results[0] + op1 = Operation.create("custom.op2", operands=[value]) + op2 = Operation.create("custom.op2", operands=[value]) + + # CHECK: Use owner: "custom.op2" + # CHECK: Use operand_number: 0 + # CHECK: Use owner: "custom.op2" + # CHECK: Use operand_number: 0 + for use in value.uses: + assert use.owner in [op1, op2] + print(f"Use owner: {use.owner}") + print(f"Use operand_number: {use.operand_number}") # CHECK-LABEL: TEST: testValueReplaceAllUsesWith @run def testValueReplaceAllUsesWith(): - ctx = Context() - ctx.allow_unregistered_dialects = True - with Location.unknown(ctx): - i32 = IntegerType.get_signless(32) - module = Module.create() - with InsertionPoint(module.body): - value = Operation.create("custom.op1", results=[i32]).results[0] - op1 = Operation.create("custom.op2", operands=[value]) - op2 = Operation.create("custom.op2", operands=[value]) - value2 = Operation.create("custom.op3", results=[i32]).results[0] - value.replace_all_uses_with(value2) - - assert len(list(value.uses)) == 0 - - # CHECK: Use owner: "custom.op2" - # CHECK: Use operand_number: 0 - # CHECK: Use owner: "custom.op2" - # CHECK: Use operand_number: 0 - for use in value2.uses: - assert use.owner in [op1, op2] - print(f"Use owner: {use.owner}") - print(f"Use operand_number: {use.operand_number}") + ctx = Context() + ctx.allow_unregistered_dialects = True + with Location.unknown(ctx): + i32 = IntegerType.get_signless(32) + module = Module.create() + with InsertionPoint(module.body): + value = Operation.create("custom.op1", results=[i32]).results[0] + op1 = Operation.create("custom.op2", operands=[value]) + op2 = Operation.create("custom.op2", operands=[value]) + value2 = Operation.create("custom.op3", results=[i32]).results[0] + value.replace_all_uses_with(value2) + + 
assert len(list(value.uses)) == 0 + + # CHECK: Use owner: "custom.op2" + # CHECK: Use operand_number: 0 + # CHECK: Use owner: "custom.op2" + # CHECK: Use operand_number: 0 + for use in value2.uses: + assert use.owner in [op1, op2] + print(f"Use owner: {use.owner}") + print(f"Use operand_number: {use.operand_number}") # CHECK-LABEL: TEST: testValuePrintAsOperand @run def testValuePrintAsOperand(): - ctx = Context() - ctx.allow_unregistered_dialects = True - with Location.unknown(ctx): - i32 = IntegerType.get_signless(32) - module = Module.create() - with InsertionPoint(module.body): - value = Operation.create("custom.op1", results=[i32]).results[0] - # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32) - print(value) - - value2 = Operation.create("custom.op2", results=[i32]).results[0] - # CHECK: Value(%[[VAL2:.*]] = "custom.op2"() : () -> i32) - print(value2) - - f = func.FuncOp("test", ([i32, i32], [])) - entry_block1 = Block.create_at_start(f.operation.regions[0], [i32, i32]) - - with InsertionPoint(entry_block1): - value3 = Operation.create("custom.op3", results=[i32]).results[0] - # CHECK: Value(%[[VAL3:.*]] = "custom.op3"() : () -> i32) - print(value3) - value4 = Operation.create("custom.op4", results=[i32]).results[0] - # CHECK: Value(%[[VAL4:.*]] = "custom.op4"() : () -> i32) - print(value4) - - f = func.FuncOp("test", ([i32, i32], [])) - entry_block2 = Block.create_at_start(f.operation.regions[0], [i32, i32]) - with InsertionPoint(entry_block2): - value5 = Operation.create("custom.op5", results=[i32]).results[0] - # CHECK: Value(%[[VAL5:.*]] = "custom.op5"() : () -> i32) - print(value5) - value6 = Operation.create("custom.op6", results=[i32]).results[0] - # CHECK: Value(%[[VAL6:.*]] = "custom.op6"() : () -> i32) - print(value6) - - func.ReturnOp([]) - - func.ReturnOp([]) - - # CHECK: %[[VAL1]] - print(value.get_name()) - # CHECK: %[[VAL2]] - print(value2.get_name()) - # CHECK: %[[VAL3]] - print(value3.get_name()) - # CHECK: %[[VAL4]] - print(value4.get_name()) - - # CHECK: %0 - print(value3.get_name(use_local_scope=True)) - # CHECK: %1 - print(value4.get_name(use_local_scope=True)) - - # CHECK: %[[VAL5]] - print(value5.get_name()) - # CHECK: %[[VAL6]] - print(value6.get_name()) - - # CHECK: %[[ARG0:.*]] - print(entry_block1.arguments[0].get_name()) - # CHECK: %[[ARG1:.*]] - print(entry_block1.arguments[1].get_name()) - - # CHECK: %[[ARG2:.*]] - print(entry_block2.arguments[0].get_name()) - # CHECK: %[[ARG3:.*]] - print(entry_block2.arguments[1].get_name()) - - # CHECK: module { - # CHECK: %[[VAL1]] = "custom.op1"() : () -> i32 - # CHECK: %[[VAL2]] = "custom.op2"() : () -> i32 - # CHECK: func.func @test(%[[ARG0]]: i32, %[[ARG1]]: i32) { - # CHECK: %[[VAL3]] = "custom.op3"() : () -> i32 - # CHECK: %[[VAL4]] = "custom.op4"() : () -> i32 - # CHECK: func @test(%[[ARG2]]: i32, %[[ARG3]]: i32) { - # CHECK: %[[VAL5]] = "custom.op5"() : () -> i32 - # CHECK: %[[VAL6]] = "custom.op6"() : () -> i32 - # CHECK: return - # CHECK: } - # CHECK: return - # CHECK: } - # CHECK: } - print(module) - - value2.owner.detach_from_parent() - # CHECK: %0 - print(value2.get_name()) + ctx = Context() + ctx.allow_unregistered_dialects = True + with Location.unknown(ctx): + i32 = IntegerType.get_signless(32) + module = Module.create() + with InsertionPoint(module.body): + value = Operation.create("custom.op1", results=[i32]).results[0] + # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32) + print(value) + + value2 = Operation.create("custom.op2", results=[i32]).results[0] + # CHECK: Value(%[[VAL2:.*]] 
= "custom.op2"() : () -> i32) + print(value2) + + f = func.FuncOp("test", ([i32, i32], [])) + entry_block1 = Block.create_at_start(f.operation.regions[0], [i32, i32]) + + with InsertionPoint(entry_block1): + value3 = Operation.create("custom.op3", results=[i32]).results[0] + # CHECK: Value(%[[VAL3:.*]] = "custom.op3"() : () -> i32) + print(value3) + value4 = Operation.create("custom.op4", results=[i32]).results[0] + # CHECK: Value(%[[VAL4:.*]] = "custom.op4"() : () -> i32) + print(value4) + + f = func.FuncOp("test", ([i32, i32], [])) + entry_block2 = Block.create_at_start(f.operation.regions[0], [i32, i32]) + with InsertionPoint(entry_block2): + value5 = Operation.create("custom.op5", results=[i32]).results[0] + # CHECK: Value(%[[VAL5:.*]] = "custom.op5"() : () -> i32) + print(value5) + value6 = Operation.create("custom.op6", results=[i32]).results[0] + # CHECK: Value(%[[VAL6:.*]] = "custom.op6"() : () -> i32) + print(value6) + + func.ReturnOp([]) + + func.ReturnOp([]) + + # CHECK: %[[VAL1]] + print(value.get_name()) + # CHECK: %[[VAL2]] + print(value2.get_name()) + # CHECK: %[[VAL3]] + print(value3.get_name()) + # CHECK: %[[VAL4]] + print(value4.get_name()) + + # CHECK: %0 + print(value3.get_name(use_local_scope=True)) + # CHECK: %1 + print(value4.get_name(use_local_scope=True)) + + # CHECK: %[[VAL5]] + print(value5.get_name()) + # CHECK: %[[VAL6]] + print(value6.get_name()) + + # CHECK: %[[ARG0:.*]] + print(entry_block1.arguments[0].get_name()) + # CHECK: %[[ARG1:.*]] + print(entry_block1.arguments[1].get_name()) + + # CHECK: %[[ARG2:.*]] + print(entry_block2.arguments[0].get_name()) + # CHECK: %[[ARG3:.*]] + print(entry_block2.arguments[1].get_name()) + + # CHECK: module { + # CHECK: %[[VAL1]] = "custom.op1"() : () -> i32 + # CHECK: %[[VAL2]] = "custom.op2"() : () -> i32 + # CHECK: func.func @test(%[[ARG0]]: i32, %[[ARG1]]: i32) { + # CHECK: %[[VAL3]] = "custom.op3"() : () -> i32 + # CHECK: %[[VAL4]] = "custom.op4"() : () -> i32 + # CHECK: func @test(%[[ARG2]]: i32, %[[ARG3]]: i32) { + # CHECK: %[[VAL5]] = "custom.op5"() : () -> i32 + # CHECK: %[[VAL6]] = "custom.op6"() : () -> i32 + # CHECK: return + # CHECK: } + # CHECK: return + # CHECK: } + # CHECK: } + print(module) + + value2.owner.detach_from_parent() + # CHECK: %0 + print(value2.get_name()) diff --git a/mlir/test/python/lit.local.cfg b/mlir/test/python/lit.local.cfg index 8a98474..12d6e1f 100644 --- a/mlir/test/python/lit.local.cfg +++ b/mlir/test/python/lit.local.cfg @@ -1,4 +1,4 @@ -config.environment['ASAN_OPTIONS'] = 'detect_leaks=0' +config.environment["ASAN_OPTIONS"] = "detect_leaks=0" if not config.enable_bindings_python: - config.unsupported = True -config.excludes.add('python_test_ops.td') + config.unsupported = True +config.excludes.add("python_test_ops.td") diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py index 8b27653..4b3a02a 100644 --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -8,122 +8,140 @@ from mlir.dialects.func import FuncOp # Log everything to stderr and flush so that we have a unified stream to match # errors/info emitted by MLIR to stderr. def log(*args): - print(*args, file=sys.stderr) - sys.stderr.flush() + print(*args, file=sys.stderr) + sys.stderr.flush() + def run(f): - log("\nTEST:", f.__name__) - f() - gc.collect() - assert Context._get_live_count() == 0 + log("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + # Verify capsule interop. 
# CHECK-LABEL: TEST: testCapsule def testCapsule(): - with Context(): - pm = PassManager() - pm_capsule = pm._CAPIPtr - assert '"mlir.passmanager.PassManager._CAPIPtr"' in repr(pm_capsule) - pm._testing_release() - pm1 = PassManager._CAPICreate(pm_capsule) - assert pm1 is not None # And does not crash. + with Context(): + pm = PassManager() + pm_capsule = pm._CAPIPtr + assert '"mlir.passmanager.PassManager._CAPIPtr"' in repr(pm_capsule) + pm._testing_release() + pm1 = PassManager._CAPICreate(pm_capsule) + assert pm1 is not None # And does not crash. + + run(testCapsule) # CHECK-LABEL: TEST: testConstruct @run def testConstruct(): - with Context(): - # CHECK: pm1: 'any()' - # CHECK: pm2: 'builtin.module()' - pm1 = PassManager() - pm2 = PassManager("builtin.module") - log(f"pm1: '{pm1}'") - log(f"pm2: '{pm2}'") + with Context(): + # CHECK: pm1: 'any()' + # CHECK: pm2: 'builtin.module()' + pm1 = PassManager() + pm2 = PassManager("builtin.module") + log(f"pm1: '{pm1}'") + log(f"pm2: '{pm2}'") # Verify successful round-trip. # CHECK-LABEL: TEST: testParseSuccess def testParseSuccess(): - with Context(): - # An unregistered pass should not parse. - try: - pm = PassManager.parse("builtin.module(func.func(not-existing-pass{json=false}))") - except ValueError as e: - # CHECK: ValueError exception: {{.+}} 'not-existing-pass' does not refer to a registered pass - log("ValueError exception:", e) - else: - log("Exception not produced") - - # A registered pass should parse successfully. - pm = PassManager.parse("builtin.module(func.func(print-op-stats{json=false}))") - # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false})) - log("Roundtrip: ", pm) + with Context(): + # An unregistered pass should not parse. + try: + pm = PassManager.parse( + "builtin.module(func.func(not-existing-pass{json=false}))" + ) + except ValueError as e: + # CHECK: ValueError exception: {{.+}} 'not-existing-pass' does not refer to a registered pass + log("ValueError exception:", e) + else: + log("Exception not produced") + + # A registered pass should parse successfully. + pm = PassManager.parse("builtin.module(func.func(print-op-stats{json=false}))") + # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false})) + log("Roundtrip: ", pm) + + run(testParseSuccess) # Verify successful round-trip. # CHECK-LABEL: TEST: testParseSpacedPipeline def testParseSpacedPipeline(): - with Context(): - # A registered pass should parse successfully even if has extras spaces for readability - pm = PassManager.parse("""builtin.module( + with Context(): + # A registered pass should parse successfully even if has extras spaces for readability + pm = PassManager.parse( + """builtin.module( func.func( print-op-stats{ json=false } ) - )""") - # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false})) - log("Roundtrip: ", pm) + )""" + ) + # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false})) + log("Roundtrip: ", pm) + + run(testParseSpacedPipeline) # Verify failure on unregistered pass. 
# CHECK-LABEL: TEST: testParseFail def testParseFail(): - with Context(): - try: - pm = PassManager.parse("any(unknown-pass)") - except ValueError as e: - # CHECK: ValueError exception: MLIR Textual PassPipeline Parser:1:1: error: - # CHECK-SAME: 'unknown-pass' does not refer to a registered pass or pass pipeline - # CHECK: unknown-pass - # CHECK: ^ - log("ValueError exception:", e) - else: - log("Exception not produced") + with Context(): + try: + pm = PassManager.parse("any(unknown-pass)") + except ValueError as e: + # CHECK: ValueError exception: MLIR Textual PassPipeline Parser:1:1: error: + # CHECK-SAME: 'unknown-pass' does not refer to a registered pass or pass pipeline + # CHECK: unknown-pass + # CHECK: ^ + log("ValueError exception:", e) + else: + log("Exception not produced") + + run(testParseFail) # Check that adding to a pass manager works # CHECK-LABEL: TEST: testAdd @run def testAdd(): - pm = PassManager("any", Context()) - # CHECK: pm: 'any()' - log(f"pm: '{pm}'") - # CHECK: pm: 'any(cse)' - pm.add("cse") - log(f"pm: '{pm}'") - # CHECK: pm: 'any(cse,cse)' - pm.add("cse") - log(f"pm: '{pm}'") + pm = PassManager("any", Context()) + # CHECK: pm: 'any()' + log(f"pm: '{pm}'") + # CHECK: pm: 'any(cse)' + pm.add("cse") + log(f"pm: '{pm}'") + # CHECK: pm: 'any(cse,cse)' + pm.add("cse") + log(f"pm: '{pm}'") # Verify failure on incorrect level of nesting. # CHECK-LABEL: TEST: testInvalidNesting def testInvalidNesting(): - with Context(): - try: - pm = PassManager.parse("func.func(normalize-memrefs)") - except ValueError as e: - # CHECK: ValueError exception: Can't add pass 'NormalizeMemRefs' restricted to 'builtin.module' on a PassManager intended to run on 'func.func', did you intend to nest? - log("ValueError exception:", e) - else: - log("Exception not produced") + with Context(): + try: + pm = PassManager.parse("func.func(normalize-memrefs)") + except ValueError as e: + # CHECK: ValueError exception: Can't add pass 'NormalizeMemRefs' restricted to 'builtin.module' on a PassManager intended to run on 'func.func', did you intend to nest? 
+ log("ValueError exception:", e) + else: + log("Exception not produced") + + run(testInvalidNesting) # Verify that a pass manager can execute on IR # CHECK-LABEL: TEST: testRunPipeline def testRunPipeline(): - with Context(): - pm = PassManager.parse("any(print-op-stats{json=false})") - func = FuncOp.parse(r"""func.func @successfulParse() { return }""") - pm.run(func) + with Context(): + pm = PassManager.parse("any(print-op-stats{json=false})") + func = FuncOp.parse(r"""func.func @successfulParse() { return }""") + pm.run(func) + + # CHECK: Operations encountered: # CHECK: func.func , 1 # CHECK: func.return , 1 @@ -132,16 +150,16 @@ run(testRunPipeline) # CHECK-LABEL: TEST: testRunPipelineError @run def testRunPipelineError(): - with Context() as ctx: - ctx.allow_unregistered_dialects = True - op = Operation.parse('"test.op"() : () -> ()') - pm = PassManager.parse("any(cse)") - try: - pm.run(op) - except MLIRError as e: - # CHECK: Exception: < - # CHECK: Failure while executing pass pipeline: - # CHECK: error: "-":1:1: 'test.op' op trying to schedule a pass on an unregistered operation - # CHECK: note: "-":1:1: see current operation: "test.op"() : () -> () - # CHECK: > - print(f"Exception: <{e}>") + with Context() as ctx: + ctx.allow_unregistered_dialects = True + op = Operation.parse('"test.op"() : () -> ()') + pm = PassManager.parse("any(cse)") + try: + pm.run(op) + except MLIRError as e: + # CHECK: Exception: < + # CHECK: Failure while executing pass pipeline: + # CHECK: error: "-":1:1: 'test.op' op trying to schedule a pass on an unregistered operation + # CHECK: note: "-":1:1: see current operation: "test.op"() : () -> () + # CHECK: > + print(f"Exception: <{e}>") diff --git a/mlir/test/tblgen-lsp-server/lit.local.cfg b/mlir/test/tblgen-lsp-server/lit.local.cfg index 25d08c7..aa35dbf 100644 --- a/mlir/test/tblgen-lsp-server/lit.local.cfg +++ b/mlir/test/tblgen-lsp-server/lit.local.cfg @@ -1 +1 @@ -config.excludes = ['include'] +config.excludes = ["include"] diff --git a/mlir/utils/gdb-scripts/prettyprinters.py b/mlir/utils/gdb-scripts/prettyprinters.py index 85a1a14..9ea8bdb 100644 --- a/mlir/utils/gdb-scripts/prettyprinters.py +++ b/mlir/utils/gdb-scripts/prettyprinters.py @@ -4,216 +4,223 @@ import gdb.printing class StoragePrinter: - """Prints bases of a struct and its fields.""" + """Prints bases of a struct and its fields.""" - def __init__(self, val): - self.val = val + def __init__(self, val): + self.val = val - def children(self): - for field in self.val.type.fields(): - if field.is_base_class: - yield '<%s>' % field.name, self.val.cast(field.type) - else: - yield field.name, self.val[field.name] + def children(self): + for field in self.val.type.fields(): + if field.is_base_class: + yield "<%s>" % field.name, self.val.cast(field.type) + else: + yield field.name, self.val[field.name] + + def to_string(self): + return "mlir::Storage" - def to_string(self): - return 'mlir::Storage' class TupleTypeStoragePrinter(StoragePrinter): + def children(self): + for child in StoragePrinter.children(self): + yield child + pointer_type = gdb.lookup_type("mlir::Type").pointer() + elements = (self.val.address + 1).cast(pointer_type) + for i in range(self.val["numElements"]): + yield "elements[%u]" % i, elements[i] - def children(self): - for child in StoragePrinter.children(self): - yield child - pointer_type = gdb.lookup_type('mlir::Type').pointer() - elements = (self.val.address + 1).cast(pointer_type) - for i in range(self.val['numElements']): - yield 'elements[%u]' % i, elements[i] + def 
to_string(self): + return "mlir::TupleTypeStorage of %u elements" % self.val["numElements"] - def to_string(self): - return 'mlir::TupleTypeStorage of %u elements' % self.val['numElements'] class FusedLocationStoragePrinter(StoragePrinter): + def children(self): + for child in StoragePrinter.children(self): + yield child + pointer_type = gdb.lookup_type("mlir::Location").pointer() + elements = (self.val.address + 1).cast(pointer_type) + for i in range(self.val["numLocs"]): + yield "locs[%u]" % i, elements[i] - def children(self): - for child in StoragePrinter.children(self): - yield child - pointer_type = gdb.lookup_type('mlir::Location').pointer() - elements = (self.val.address + 1).cast(pointer_type) - for i in range(self.val['numLocs']): - yield 'locs[%u]' % i, elements[i] - - def to_string(self): - return 'mlir::FusedLocationStorage of %u locs' % self.val['numLocs'] + def to_string(self): + return "mlir::FusedLocationStorage of %u locs" % self.val["numLocs"] class StorageTypeMap: - """Maps a TypeID to the corresponding concrete type. - - Types need to be registered by name before the first lookup. - """ - - def __init__(self): - self.map = None - self.type_names = [] - - def register_type(self, type_name): - assert not self.map, 'register_type called after __getitem__' - self.type_names += [type_name] - - def _init_map(self): - """Lazy initialization of self.map.""" - if self.map: - return - self.map = {} - for type_name in self.type_names: - concrete_type = gdb.lookup_type(type_name) - try: - storage = gdb.parse_and_eval( - "&'mlir::detail::TypeIDExported::get<%s>()::instance'" % type_name) - except gdb.error: - # Skip when TypeID instance cannot be found in current context. - continue - if concrete_type and storage: - self.map[int(storage)] = concrete_type - - def __getitem__(self, type_id): - self._init_map() - return self.map.get(int(type_id['storage'])) + """Maps a TypeID to the corresponding concrete type. + + Types need to be registered by name before the first lookup. + """ + + def __init__(self): + self.map = None + self.type_names = [] + + def register_type(self, type_name): + assert not self.map, "register_type called after __getitem__" + self.type_names += [type_name] + + def _init_map(self): + """Lazy initialization of self.map.""" + if self.map: + return + self.map = {} + for type_name in self.type_names: + concrete_type = gdb.lookup_type(type_name) + try: + storage = gdb.parse_and_eval( + "&'mlir::detail::TypeIDExported::get<%s>()::instance'" % type_name + ) + except gdb.error: + # Skip when TypeID instance cannot be found in current context. 
+ continue + if concrete_type and storage: + self.map[int(storage)] = concrete_type + + def __getitem__(self, type_id): + self._init_map() + return self.map.get(int(type_id["storage"])) storage_type_map = StorageTypeMap() def get_type_id_printer(val): - """Returns a printer of the name of a mlir::TypeID.""" - - class TypeIdPrinter: + """Returns a printer of the name of a mlir::TypeID.""" - def __init__(self, string): - self.string = string + class TypeIdPrinter: + def __init__(self, string): + self.string = string - def to_string(self): - return self.string + def to_string(self): + return self.string - concrete_type = storage_type_map[val] - if not concrete_type: - return None - return TypeIdPrinter('mlir::TypeID::get<%s>()' % concrete_type) + concrete_type = storage_type_map[val] + if not concrete_type: + return None + return TypeIdPrinter("mlir::TypeID::get<%s>()" % concrete_type) def get_attr_or_type_printer(val, get_type_id): - """Returns a printer for mlir::Attribute or mlir::Type.""" - - class AttrOrTypePrinter: - - def __init__(self, type_id, impl): - self.type_id = type_id - self.impl = impl - - def children(self): - yield 'typeID', self.type_id - yield 'impl', self.impl - - def to_string(self): - return 'cast<%s>' % self.impl.type - - if not val['impl']: - return None - impl = val['impl'].dereference() - type_id = get_type_id(impl) - concrete_type = storage_type_map[type_id] - if not concrete_type: - return None - # 3rd template argument of StorageUserBase is the storage type. - storage_type = concrete_type.fields()[0].type.template_argument(2) - if not storage_type: - return None - return AttrOrTypePrinter(type_id, impl.cast(storage_type)) + """Returns a printer for mlir::Attribute or mlir::Type.""" + + class AttrOrTypePrinter: + def __init__(self, type_id, impl): + self.type_id = type_id + self.impl = impl + + def children(self): + yield "typeID", self.type_id + yield "impl", self.impl + + def to_string(self): + return "cast<%s>" % self.impl.type + + if not val["impl"]: + return None + impl = val["impl"].dereference() + type_id = get_type_id(impl) + concrete_type = storage_type_map[type_id] + if not concrete_type: + return None + # 3rd template argument of StorageUserBase is the storage type. + storage_type = concrete_type.fields()[0].type.template_argument(2) + if not storage_type: + return None + return AttrOrTypePrinter(type_id, impl.cast(storage_type)) class ImplPrinter: - """Printer for an instance with a single 'impl' member pointer.""" + """Printer for an instance with a single 'impl' member pointer.""" - def __init__(self, val): - self.val = val - self.impl = val['impl'] + def __init__(self, val): + self.val = val + self.impl = val["impl"] - def children(self): - if self.impl: - yield 'impl', self.impl.dereference() + def children(self): + if self.impl: + yield "impl", self.impl.dereference() - def to_string(self): - return self.val.type.name + def to_string(self): + return self.val.type.name # Printers of types deriving from Attribute::AttrBase or Type::TypeBase. 
for name in [ # mlir/IR/Attributes.h - 'ArrayAttr', - 'DictionaryAttr', - 'FloatAttr', - 'IntegerAttr', - 'IntegerSetAttr', - 'OpaqueAttr', - 'StringAttr', - 'SymbolRefAttr', - 'TypeAttr', - 'UnitAttr', - 'DenseStringElementsAttr', - 'DenseIntOrFPElementsAttr', - 'SparseElementsAttr', + "ArrayAttr", + "DictionaryAttr", + "FloatAttr", + "IntegerAttr", + "IntegerSetAttr", + "OpaqueAttr", + "StringAttr", + "SymbolRefAttr", + "TypeAttr", + "UnitAttr", + "DenseStringElementsAttr", + "DenseIntOrFPElementsAttr", + "SparseElementsAttr", # mlir/IR/BuiltinTypes.h - 'ComplexType', - 'IndexType', - 'IntegerType', - 'Float16Type', - 'Float32Type', - 'Float64Type', - 'Float80Type', - 'Float128Type', - 'NoneType', - 'VectorType', - 'RankedTensorType', - 'UnrankedTensorType', - 'MemRefType', - 'UnrankedMemRefType', - 'TupleType', + "ComplexType", + "IndexType", + "IntegerType", + "Float16Type", + "Float32Type", + "Float64Type", + "Float80Type", + "Float128Type", + "NoneType", + "VectorType", + "RankedTensorType", + "UnrankedTensorType", + "MemRefType", + "UnrankedMemRefType", + "TupleType", # mlir/IR/Location.h - 'CallSiteLoc', - 'FileLineColLoc', - 'FusedLoc', - 'NameLoc', - 'OpaqueLoc', - 'UnknownLoc' + "CallSiteLoc", + "FileLineColLoc", + "FusedLoc", + "NameLoc", + "OpaqueLoc", + "UnknownLoc", ]: - storage_type_map.register_type('mlir::%s' % name) # Register for upcasting. -storage_type_map.register_type('void') # Register default. + storage_type_map.register_type("mlir::%s" % name) # Register for upcasting. +storage_type_map.register_type("void") # Register default. -pp = gdb.printing.RegexpCollectionPrettyPrinter('MLIRSupport') +pp = gdb.printing.RegexpCollectionPrettyPrinter("MLIRSupport") -pp.add_printer('mlir::OperationName', '^mlir::OperationName$', ImplPrinter) -pp.add_printer('mlir::Value', '^mlir::Value$', ImplPrinter) +pp.add_printer("mlir::OperationName", "^mlir::OperationName$", ImplPrinter) +pp.add_printer("mlir::Value", "^mlir::Value$", ImplPrinter) # Printers for types deriving from AttributeStorage or TypeStorage. -pp.add_printer('mlir::detail::FusedLocationStorage', - '^mlir::detail::FusedLocationStorage', - FusedLocationStoragePrinter) -pp.add_printer('mlir::detail::TupleTypeStorage', - '^mlir::detail::TupleTypeStorage$', TupleTypeStoragePrinter) +pp.add_printer( + "mlir::detail::FusedLocationStorage", + "^mlir::detail::FusedLocationStorage", + FusedLocationStoragePrinter, +) +pp.add_printer( + "mlir::detail::TupleTypeStorage", + "^mlir::detail::TupleTypeStorage$", + TupleTypeStoragePrinter, +) -pp.add_printer('mlir::TypeID', '^mlir::TypeID$', get_type_id_printer) +pp.add_printer("mlir::TypeID", "^mlir::TypeID$", get_type_id_printer) def add_attr_or_type_printers(name): - """Adds printers for mlir::Attribute or mlir::Type and their Storage type.""" - get_type_id = lambda val: val['abstract%s' % name]['typeID'] - pp.add_printer('mlir::%s' % name, '^mlir::%s$' % name, - lambda val: get_attr_or_type_printer(val, get_type_id)) + """Adds printers for mlir::Attribute or mlir::Type and their Storage type.""" + get_type_id = lambda val: val["abstract%s" % name]["typeID"] + pp.add_printer( + "mlir::%s" % name, + "^mlir::%s$" % name, + lambda val: get_attr_or_type_printer(val, get_type_id), + ) # Upcasting printers of mlir::Attribute and mlir::Type. 
-for name in ['Attribute', 'Type']: - add_attr_or_type_printers(name) +for name in ["Attribute", "Type"]: + add_attr_or_type_printers(name) gdb.printing.register_pretty_printer(gdb.current_objfile(), pp) diff --git a/mlir/utils/generate-test-checks.py b/mlir/utils/generate-test-checks.py index 474f812..0210d7a 100755 --- a/mlir/utils/generate-test-checks.py +++ b/mlir/utils/generate-test-checks.py @@ -32,7 +32,7 @@ import os # Used to advertise this file's name ("autogenerated_note"). import re import sys -ADVERT_BEGIN = '// NOTE: Assertions have been autogenerated by ' +ADVERT_BEGIN = "// NOTE: Assertions have been autogenerated by " ADVERT_END = """ // The script is designed to make adding checks to // a test case fast, it is *not* designed to be authoritative @@ -42,250 +42,249 @@ ADVERT_END = """ # Regex command to match an SSA identifier. -SSA_RE_STR = '[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*' +SSA_RE_STR = "[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*" SSA_RE = re.compile(SSA_RE_STR) # Class used to generate and manage string substitution blocks for SSA value # names. class SSAVariableNamer: + def __init__(self): + self.scopes = [] + self.name_counter = 0 - def __init__(self): - self.scopes = [] - self.name_counter = 0 + # Generate a substitution name for the given ssa value name. + def generate_name(self, ssa_name): + variable = "VAL_" + str(self.name_counter) + self.name_counter += 1 + self.scopes[-1][ssa_name] = variable + return variable - # Generate a substitution name for the given ssa value name. - def generate_name(self, ssa_name): - variable = 'VAL_' + str(self.name_counter) - self.name_counter += 1 - self.scopes[-1][ssa_name] = variable - return variable + # Push a new variable name scope. + def push_name_scope(self): + self.scopes.append({}) - # Push a new variable name scope. - def push_name_scope(self): - self.scopes.append({}) + # Pop the last variable name scope. + def pop_name_scope(self): + self.scopes.pop() - # Pop the last variable name scope. - def pop_name_scope(self): - self.scopes.pop() + # Return the level of nesting (number of pushed scopes). + def num_scopes(self): + return len(self.scopes) - # Return the level of nesting (number of pushed scopes). - def num_scopes(self): - return len(self.scopes) - - # Reset the counter. - def clear_counter(self): - self.name_counter = 0 + # Reset the counter. + def clear_counter(self): + self.name_counter = 0 # Process a line of input that has been split at each SSA identifier '%'. def process_line(line_chunks, variable_namer): - output_line = '' - - # Process the rest that contained an SSA value name. - for chunk in line_chunks: - m = SSA_RE.match(chunk) - ssa_name = m.group(0) - - # Check if an existing variable exists for this name. - variable = None - for scope in variable_namer.scopes: - variable = scope.get(ssa_name) - if variable is not None: - break - - # If one exists, then output the existing name. - if variable is not None: - output_line += '%[[' + variable + ']]' - else: - # Otherwise, generate a new variable. - variable = variable_namer.generate_name(ssa_name) - output_line += '%[[' + variable + ':.*]]' + output_line = "" + + # Process the rest that contained an SSA value name. + for chunk in line_chunks: + m = SSA_RE.match(chunk) + ssa_name = m.group(0) + + # Check if an existing variable exists for this name. + variable = None + for scope in variable_namer.scopes: + variable = scope.get(ssa_name) + if variable is not None: + break - # Append the non named group. 
- output_line += chunk[len(ssa_name):] + # If one exists, then output the existing name. + if variable is not None: + output_line += "%[[" + variable + "]]" + else: + # Otherwise, generate a new variable. + variable = variable_namer.generate_name(ssa_name) + output_line += "%[[" + variable + ":.*]]" - return output_line.rstrip() + '\n' + # Append the non named group. + output_line += chunk[len(ssa_name) :] + + return output_line.rstrip() + "\n" # Process the source file lines. The source file doesn't have to be .mlir. def process_source_lines(source_lines, note, args): - source_split_re = re.compile(args.source_delim_regex) + source_split_re = re.compile(args.source_delim_regex) - source_segments = [[]] - for line in source_lines: - # Remove previous note. - if line == note: - continue - # Remove previous CHECK lines. - if line.find(args.check_prefix) != -1: - continue - # Segment the file based on --source_delim_regex. - if source_split_re.search(line): - source_segments.append([]) + source_segments = [[]] + for line in source_lines: + # Remove previous note. + if line == note: + continue + # Remove previous CHECK lines. + if line.find(args.check_prefix) != -1: + continue + # Segment the file based on --source_delim_regex. + if source_split_re.search(line): + source_segments.append([]) - source_segments[-1].append(line + '\n') - return source_segments + source_segments[-1].append(line + "\n") + return source_segments # Pre-process a line of input to remove any character sequences that will be # problematic with FileCheck. def preprocess_line(line): - # Replace any double brackets, '[[' with escaped replacements. '[[' - # corresponds to variable names in FileCheck. - output_line = line.replace('[[', '{{\\[\\[}}') + # Replace any double brackets, '[[' with escaped replacements. '[[' + # corresponds to variable names in FileCheck. + output_line = line.replace("[[", "{{\\[\\[}}") - # Replace any single brackets that are followed by an SSA identifier, the - # identifier will be replace by a variable; Creating the same situation as - # above. - output_line = output_line.replace('[%', '{{\\[}}%') + # Replace any single brackets that are followed by an SSA identifier, the + # identifier will be replace by a variable; Creating the same situation as + # above. + output_line = output_line.replace("[%", "{{\\[}}%") - return output_line + return output_line def main(): - parser = argparse.ArgumentParser( - description=__doc__, formatter_class=argparse.RawTextHelpFormatter) - parser.add_argument( - '--check-prefix', default='CHECK', help='Prefix to use from check file.') - parser.add_argument( - '-o', - '--output', - nargs='?', - type=argparse.FileType('w'), - default=None) - parser.add_argument( - 'input', - nargs='?', - type=argparse.FileType('r'), - default=sys.stdin) - parser.add_argument( - '--source', type=str, - help='Print each CHECK chunk before each delimeter line in the source' - 'file, respectively. The delimeter lines are identified by ' - '--source_delim_regex.') - parser.add_argument('--source_delim_regex', type=str, default='func @') - parser.add_argument( - '--starts_from_scope', type=int, default=1, - help='Omit the top specified level of content. For example, by default ' - 'it omits "module {"') - parser.add_argument('-i', '--inplace', action='store_true', default=False) - - args = parser.parse_args() - - # Open the given input file. - input_lines = [l.rstrip() for l in args.input] - args.input.close() - - # Generate a note used for the generated check file. 
- script_name = os.path.basename(__file__) - autogenerated_note = (ADVERT_BEGIN + 'utils/' + script_name + "\n" + ADVERT_END) - - source_segments = None - if args.source: - source_segments = process_source_lines( - [l.rstrip() for l in open(args.source, 'r')], - autogenerated_note, - args + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawTextHelpFormatter + ) + parser.add_argument( + "--check-prefix", default="CHECK", help="Prefix to use from check file." + ) + parser.add_argument( + "-o", "--output", nargs="?", type=argparse.FileType("w"), default=None + ) + parser.add_argument( + "input", nargs="?", type=argparse.FileType("r"), default=sys.stdin ) + parser.add_argument( + "--source", + type=str, + help="Print each CHECK chunk before each delimeter line in the source" + "file, respectively. The delimeter lines are identified by " + "--source_delim_regex.", + ) + parser.add_argument("--source_delim_regex", type=str, default="func @") + parser.add_argument( + "--starts_from_scope", + type=int, + default=1, + help="Omit the top specified level of content. For example, by default " + 'it omits "module {"', + ) + parser.add_argument("-i", "--inplace", action="store_true", default=False) + + args = parser.parse_args() + + # Open the given input file. + input_lines = [l.rstrip() for l in args.input] + args.input.close() - if args.inplace: - assert args.output is None - output = open(args.source, 'w') - elif args.output is None: - output = sys.stdout - else: - output = args.output - - output_segments = [[]] - # A map containing data used for naming SSA value names. - variable_namer = SSAVariableNamer() - for input_line in input_lines: - if not input_line: - continue - lstripped_input_line = input_line.lstrip() - - # Lines with blocks begin with a ^. These lines have a trailing comment - # that needs to be stripped. - is_block = lstripped_input_line[0] == '^' - if is_block: - input_line = input_line.rsplit('//', 1)[0].rstrip() - - cur_level = variable_namer.num_scopes() - - # If the line starts with a '}', pop the last name scope. - if lstripped_input_line[0] == '}': - variable_namer.pop_name_scope() - cur_level = variable_namer.num_scopes() - - # If the line ends with a '{', push a new name scope. - if input_line[-1] == '{': - variable_namer.push_name_scope() - if cur_level == args.starts_from_scope: - output_segments.append([]) - - # Omit lines at the near top level e.g. "module {". - if cur_level < args.starts_from_scope: - continue - - if len(output_segments[-1]) == 0: - variable_namer.clear_counter() - - # Preprocess the input to remove any sequences that may be problematic with - # FileCheck. - input_line = preprocess_line(input_line) - - # Split the line at the each SSA value name. - ssa_split = input_line.split('%') - - # If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'. - if len(output_segments[-1]) != 0 or not ssa_split[0]: - output_line = '// ' + args.check_prefix + ': ' - # Pad to align with the 'LABEL' statements. - output_line += (' ' * len('-LABEL')) - - # Output the first line chunk that does not contain an SSA name. - output_line += ssa_split[0] - - # Process the rest of the input line. - output_line += process_line(ssa_split[1:], variable_namer) + # Generate a note used for the generated check file. 
+ script_name = os.path.basename(__file__) + autogenerated_note = ADVERT_BEGIN + "utils/" + script_name + "\n" + ADVERT_END + source_segments = None + if args.source: + source_segments = process_source_lines( + [l.rstrip() for l in open(args.source, "r")], autogenerated_note, args + ) + + if args.inplace: + assert args.output is None + output = open(args.source, "w") + elif args.output is None: + output = sys.stdout + else: + output = args.output + + output_segments = [[]] + # A map containing data used for naming SSA value names. + variable_namer = SSAVariableNamer() + for input_line in input_lines: + if not input_line: + continue + lstripped_input_line = input_line.lstrip() + + # Lines with blocks begin with a ^. These lines have a trailing comment + # that needs to be stripped. + is_block = lstripped_input_line[0] == "^" + if is_block: + input_line = input_line.rsplit("//", 1)[0].rstrip() + + cur_level = variable_namer.num_scopes() + + # If the line starts with a '}', pop the last name scope. + if lstripped_input_line[0] == "}": + variable_namer.pop_name_scope() + cur_level = variable_namer.num_scopes() + + # If the line ends with a '{', push a new name scope. + if input_line[-1] == "{": + variable_namer.push_name_scope() + if cur_level == args.starts_from_scope: + output_segments.append([]) + + # Omit lines at the near top level e.g. "module {". + if cur_level < args.starts_from_scope: + continue + + if len(output_segments[-1]) == 0: + variable_namer.clear_counter() + + # Preprocess the input to remove any sequences that may be problematic with + # FileCheck. + input_line = preprocess_line(input_line) + + # Split the line at the each SSA value name. + ssa_split = input_line.split("%") + + # If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'. + if len(output_segments[-1]) != 0 or not ssa_split[0]: + output_line = "// " + args.check_prefix + ": " + # Pad to align with the 'LABEL' statements. + output_line += " " * len("-LABEL") + + # Output the first line chunk that does not contain an SSA name. + output_line += ssa_split[0] + + # Process the rest of the input line. + output_line += process_line(ssa_split[1:], variable_namer) + + else: + # Output the first line chunk that does not contain an SSA name for the + # label. + output_line = "// " + args.check_prefix + "-LABEL: " + ssa_split[0] + "\n" + + # Process the rest of the input line on separate check lines. + for argument in ssa_split[1:]: + output_line += "// " + args.check_prefix + "-SAME: " + + # Pad to align with the original position in the line. + output_line += " " * len(ssa_split[0]) + + # Process the rest of the line. + output_line += process_line([argument], variable_namer) + + # Append the output line. + output_segments[-1].append(output_line) + + output.write(autogenerated_note + "\n") + + # Write the output. + if source_segments: + assert len(output_segments) == len(source_segments) + for check_segment, source_segment in zip(output_segments, source_segments): + for line in check_segment: + output.write(line) + for line in source_segment: + output.write(line) else: - # Output the first line chunk that does not contain an SSA name for the - # label. - output_line = '// ' + args.check_prefix + '-LABEL: ' + ssa_split[0] + '\n' - - # Process the rest of the input line on separate check lines. - for argument in ssa_split[1:]: - output_line += '// ' + args.check_prefix + '-SAME: ' - - # Pad to align with the original position in the line. 
- output_line += ' ' * len(ssa_split[0]) - - # Process the rest of the line. - output_line += process_line([argument], variable_namer) - - # Append the output line. - output_segments[-1].append(output_line) - - output.write(autogenerated_note + '\n') - - # Write the output. - if source_segments: - assert len(output_segments) == len(source_segments) - for check_segment, source_segment in zip(output_segments, source_segments): - for line in check_segment: - output.write(line) - for line in source_segment: - output.write(line) - else: - for segment in output_segments: - output.write('\n') - for output_line in segment: - output.write(output_line) - output.write('\n') - output.close() - - -if __name__ == '__main__': - main() + for segment in output_segments: + output.write("\n") + for output_line in segment: + output.write(output_line) + output.write("\n") + output.close() + + +if __name__ == "__main__": + main() diff --git a/mlir/utils/jupyter/mlir_opt_kernel/__main__.py b/mlir/utils/jupyter/mlir_opt_kernel/__main__.py index 02582f9..21994ff 100644 --- a/mlir/utils/jupyter/mlir_opt_kernel/__main__.py +++ b/mlir/utils/jupyter/mlir_opt_kernel/__main__.py @@ -4,4 +4,5 @@ from ipykernel.kernelapp import IPKernelApp from .kernel import MlirOptKernel + IPKernelApp.launch_instance(kernel_class=MlirOptKernel) diff --git a/mlir/utils/jupyter/mlir_opt_kernel/install.py b/mlir/utils/jupyter/mlir_opt_kernel/install.py index ddb37c8..bd7b1d1 100644 --- a/mlir/utils/jupyter/mlir_opt_kernel/install.py +++ b/mlir/utils/jupyter/mlir_opt_kernel/install.py @@ -10,12 +10,11 @@ from jupyter_client.kernelspec import KernelSpecManager def install_my_kernel_spec(user=True, prefix=None): """Install the kernel spec for user in given prefix.""" - print('Installing mlir-opt IPython kernel spec') + print("Installing mlir-opt IPython kernel spec") pkgroot = os.path.dirname(__file__) - KernelSpecManager().install_kernel_spec(os.path.join(pkgroot, 'assets'), - 'mlir', - user=user, - prefix=prefix) + KernelSpecManager().install_kernel_spec( + os.path.join(pkgroot, "assets"), "mlir", user=user, prefix=prefix + ) def _is_root(): @@ -29,15 +28,16 @@ def _is_root(): def main(argv=None): parser = argparse.ArgumentParser( - description='Install KernelSpec for MlirOpt Kernel') + description="Install KernelSpec for MlirOpt Kernel" + ) prefix_locations = parser.add_mutually_exclusive_group() - prefix_locations.add_argument('--user', - help='Install in user home directory', - action='store_true') - prefix_locations.add_argument('--prefix', - help='Install directory prefix', - default=None) + prefix_locations.add_argument( + "--user", help="Install in user home directory", action="store_true" + ) + prefix_locations.add_argument( + "--prefix", help="Install directory prefix", default=None + ) args = parser.parse_args(argv) @@ -47,5 +47,5 @@ def main(argv=None): install_my_kernel_spec(user=user, prefix=prefix) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mlir/utils/jupyter/mlir_opt_kernel/kernel.py b/mlir/utils/jupyter/mlir_opt_kernel/kernel.py index 85462da..c0e4fc1 100644 --- a/mlir/utils/jupyter/mlir_opt_kernel/kernel.py +++ b/mlir/utils/jupyter/mlir_opt_kernel/kernel.py @@ -9,7 +9,7 @@ import tempfile import traceback from ipykernel.kernelbase import Kernel -__version__ = '0.0.1' +__version__ = "0.0.1" def _get_executable(): @@ -19,7 +19,7 @@ def _get_executable(): """Returns whether executable file.""" return os.path.isfile(fpath) and os.access(fpath, os.X_OK) - program = 
os.environ.get('MLIR_OPT_EXECUTABLE', 'mlir-opt') + program = os.environ.get("MLIR_OPT_EXECUTABLE", "mlir-opt") path, name = os.path.split(program) # Attempt to get the executable if path: @@ -30,7 +30,7 @@ def _get_executable(): file = os.path.join(path, name) if is_exe(file): return file - raise OSError('mlir-opt not found, please see README') + raise OSError("mlir-opt not found, please see README") class MlirOptKernel(Kernel): @@ -51,19 +51,17 @@ class MlirOptKernel(Kernel): ``` """ - implementation = 'mlir' + implementation = "mlir" implementation_version = __version__ language_version = __version__ language = "mlir" language_info = { "name": "mlir", - "codemirror_mode": { - "name": "mlir" - }, + "codemirror_mode": {"name": "mlir"}, "mimetype": "text/x-mlir", "file_extension": ".mlir", - "pygments_lexer": "text" + "pygments_lexer": "text", } @property @@ -88,31 +86,28 @@ class MlirOptKernel(Kernel): """Reports regular command output.""" if not self.silent: # Send standard output - stream_content = {'name': 'stdout', 'text': output} - self.send_response(self.iopub_socket, 'stream', stream_content) + stream_content = {"name": "stdout", "text": output} + self.send_response(self.iopub_socket, "stream", stream_content) def process_error(self, output): """Reports error response.""" if not self.silent: # Send standard error - stream_content = {'name': 'stderr', 'text': output} - self.send_response(self.iopub_socket, 'stream', stream_content) - - def do_execute(self, - code, - silent, - store_history=True, - user_expressions=None, - allow_stdin=False): + stream_content = {"name": "stderr", "text": output} + self.send_response(self.iopub_socket, "stream", stream_content) + + def do_execute( + self, code, silent, store_history=True, user_expressions=None, allow_stdin=False + ): """Execute user code using mlir-opt binary.""" def ok_status(): """Returns OK status.""" return { - 'status': 'ok', - 'execution_count': self.execution_count, - 'payload': [], - 'user_expressions': {} + "status": "ok", + "execution_count": self.execution_count, + "payload": [], + "user_expressions": {}, } def run(code): @@ -123,29 +118,27 @@ class MlirOptKernel(Kernel): # Specify input and output file to error out if also # set as arg. self.get_executable(), - '--color', + "--color", inputmlir.name, - '-o', - '-' + "-o", + "-", ] # Simple handling of repeating last line. - if code.endswith('\n_'): + if code.endswith("\n_"): if not self._: - raise NameError('No previous result set') + raise NameError("No previous result set") code = code[:-1] + self._ inputmlir.write(code.encode("utf-8")) inputmlir.close() - pipe = Popen(command, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + pipe = Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) output, errors = pipe.communicate() exitcode = pipe.returncode finally: os.unlink(inputmlir.name) -# Replace temporary filename with placeholder. This takes the very -# remote chance where the full input filename (generated above) -# overlaps with something in the dump unrelated to the file. + # Replace temporary filename with placeholder. This takes the very + # remote chance where the full input filename (generated above) + # overlaps with something in the dump unrelated to the file. 
fname = inputmlir.name.encode("utf-8") output = output.replace(fname, b"<>") errors = errors.replace(fname, b"<>") @@ -163,7 +156,7 @@ class MlirOptKernel(Kernel): else: self._ = output.decode("utf-8") except KeyboardInterrupt: - return {'status': 'abort', 'execution_count': self.execution_count} + return {"status": "abort", "execution_count": self.execution_count} except Exception as error: # Print traceback for local debugging. traceback.print_exc() @@ -172,24 +165,24 @@ class MlirOptKernel(Kernel): errors = repr(error).encode("utf-8") if exitcode: - content = {'ename': '', 'evalue': str(exitcode), 'traceback': []} + content = {"ename": "", "evalue": str(exitcode), "traceback": []} - self.send_response(self.iopub_socket, 'error', content) + self.send_response(self.iopub_socket, "error", content) self.process_error(errors.decode("utf-8")) - content['execution_count'] = self.execution_count - content['status'] = 'error' + content["execution_count"] = self.execution_count + content["status"] = "error" return content if not silent: data = {} - data['text/x-mlir'] = self._ + data["text/x-mlir"] = self._ content = { - 'execution_count': self.execution_count, - 'data': data, - 'metadata': {} + "execution_count": self.execution_count, + "data": data, + "metadata": {}, } - self.send_response(self.iopub_socket, 'execute_result', content) + self.send_response(self.iopub_socket, "execute_result", content) self.process_output(self._) self.process_error(errors.decode("utf-8")) return ok_status() diff --git a/mlir/utils/lldb-scripts/mlirDataFormatters.py b/mlir/utils/lldb-scripts/mlirDataFormatters.py index bfd76a7..5d06b40 100644 --- a/mlir/utils/lldb-scripts/mlirDataFormatters.py +++ b/mlir/utils/lldb-scripts/mlirDataFormatters.py @@ -521,8 +521,7 @@ class InDirectRangeSynthProvider: class IPListRangeSynthProvider: - """Define an LLDB synthetic children provider for an IPList. - """ + """Define an LLDB synthetic children provider for an IPList.""" def __init__(self, valobj, internal_dict): self.valobj = valobj @@ -575,8 +574,7 @@ class IPListRangeSynthProvider: class ValueSynthProvider: - """Define an LLDB synthetic children provider for Values. - """ + """Define an LLDB synthetic children provider for Values.""" def __init__(self, valobj, internal_dict): self.valobj = valobj @@ -677,8 +675,7 @@ class ValueSynthProvider: def ValueSummaryProvider(valobj: lldb.SBValue, internal_dict): - """Define an LLDB summary provider for Values. - """ + """Define an LLDB summary provider for Values.""" index = valobj.GetChildMemberWithName("index").GetValueAsUnsigned() # Check if this is a block argument or not (block arguments have locations). diff --git a/mlir/utils/mbr/mbr/__init__.py b/mlir/utils/mbr/mbr/__init__.py index 3e47ec8..d01befd 100644 --- a/mlir/utils/mbr/mbr/__init__.py +++ b/mlir/utils/mbr/mbr/__init__.py @@ -9,5 +9,6 @@ class BenchmarkRunConfig: class. The `compiler` attribute is optional, for example for python benchmarks. 
""" + runner: typing.Callable compiler: typing.Optional[typing.Callable] = None diff --git a/mlir/utils/mbr/mbr/discovery.py b/mlir/utils/mbr/mbr/discovery.py index 37cc458..6c9803e 100644 --- a/mlir/utils/mbr/mbr/discovery.py +++ b/mlir/utils/mbr/mbr/discovery.py @@ -16,21 +16,17 @@ def discover_benchmark_modules(top_level_path): defaults to "benchmark_" """ config = configparser.ConfigParser() - config.read( - os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.ini") - ) + config.read(os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.ini")) if "discovery" in config.sections(): filename_prefix = config["discovery"]["filename_prefix"] else: filename_prefix = "benchmark_" - if re.search(fr"{filename_prefix}.*.py$", top_level_path): + if re.search(rf"{filename_prefix}.*.py$", top_level_path): # A specific python file so just include that. benchmark_files = [top_level_path] else: # A directory so recursively search for all python files. - benchmark_files = pathlib.Path( - top_level_path - ).rglob(f"{filename_prefix}*.py") + benchmark_files = pathlib.Path(top_level_path).rglob(f"{filename_prefix}*.py") for benchmark_filename in benchmark_files: benchmark_abs_dir = os.path.abspath(os.path.dirname(benchmark_filename)) sys.path.append(benchmark_abs_dir) @@ -46,9 +42,7 @@ def get_benchmark_functions(module, benchmark_function_name=None): a specific prefix, which defaults to "benchmark_". """ config = configparser.ConfigParser() - config.read( - os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.ini") - ) + config.read(os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.ini")) if "discovery" in config.sections(): function_prefix = config["discovery"].get("function_prefix") else: @@ -57,9 +51,8 @@ def get_benchmark_functions(module, benchmark_function_name=None): module_functions = [] for attribute_name in dir(module): attribute = getattr(module, attribute_name) - if ( - isinstance(attribute, types.FunctionType) - and attribute_name.startswith(function_prefix) + if isinstance(attribute, types.FunctionType) and attribute_name.startswith( + function_prefix ): module_functions.append(attribute) diff --git a/mlir/utils/mbr/mbr/main.py b/mlir/utils/mbr/mbr/main.py index 0f67454..5d301ab 100644 --- a/mlir/utils/mbr/mbr/main.py +++ b/mlir/utils/mbr/mbr/main.py @@ -12,8 +12,7 @@ from stats import has_enough_measurements def main(top_level_path, stop_on_error): - """Top level function called when the CLI is invoked. 
- """ + """Top level function called when the CLI is invoked.""" if "::" in top_level_path: if top_level_path.count("::") > 1: raise AssertionError(f"Invalid path {top_level_path}") @@ -22,16 +21,14 @@ def main(top_level_path, stop_on_error): benchmark_function_name = None if not os.path.exists(top_level_path): - raise AssertionError( - f"The top-level path {top_level_path} doesn't exist" - ) + raise AssertionError(f"The top-level path {top_level_path} doesn't exist") modules = [module for module in discover_benchmark_modules(top_level_path)] benchmark_dicts = [] for module in modules: benchmark_functions = [ - function for function in - get_benchmark_functions(module, benchmark_function_name) + function + for function in get_benchmark_functions(module, benchmark_function_name) ] for benchmark_function in benchmark_functions: try: @@ -96,10 +93,9 @@ def main(top_level_path, stop_on_error): if len(measurements_ns) > 0: measurements_s = [t * 1e-9 for t in measurements_ns] - benchmark_identifier = ":".join([ - module.__name__, - benchmark_function.__name__ - ]) + benchmark_identifier = ":".join( + [module.__name__, benchmark_function.__name__] + ) benchmark_dicts.append( { "name": benchmark_identifier, diff --git a/mlir/utils/mbr/mbr/stats.py b/mlir/utils/mbr/mbr/stats.py index 3288021..9b7a3dc 100644 --- a/mlir/utils/mbr/mbr/stats.py +++ b/mlir/utils/mbr/mbr/stats.py @@ -16,9 +16,7 @@ def has_enough_measurements(measurements): If 1. is true, 2. doesn't need to be true. """ config = configparser.ConfigParser() - config.read( - os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.cfg") - ) + config.read(os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.cfg")) if "stats" in config: stats_dict = { "max_number_of_measurements": int( @@ -34,6 +32,6 @@ def has_enough_measurements(measurements): "max_time_for_a_benchmark_ns": 1e9, } return ( - np.sum(measurements) >= stats_dict["max_time_for_a_benchmark_ns"] or - np.size(measurements) >= stats_dict["max_number_of_measurements"] + np.sum(measurements) >= stats_dict["max_time_for_a_benchmark_ns"] + or np.size(measurements) >= stats_dict["max_number_of_measurements"] ) diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py index aeb1827..426bfca 100755 --- a/mlir/utils/spirv/gen_spirv_dialect.py +++ b/mlir/utils/spirv/gen_spirv_dialect.py @@ -23,1088 +23,1164 @@ import requests import textwrap import yaml -SPIRV_HTML_SPEC_URL = 'https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html' -SPIRV_JSON_SPEC_URL = 'https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/spirv.core.grammar.json' +SPIRV_HTML_SPEC_URL = ( + "https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html" +) +SPIRV_JSON_SPEC_URL = "https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/spirv.core.grammar.json" -SPIRV_CL_EXT_HTML_SPEC_URL = 'https://www.khronos.org/registry/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html' -SPIRV_CL_EXT_JSON_SPEC_URL = 'https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/extinst.opencl.std.100.grammar.json' +SPIRV_CL_EXT_HTML_SPEC_URL = "https://www.khronos.org/registry/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html" +SPIRV_CL_EXT_JSON_SPEC_URL = "https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/extinst.opencl.std.100.grammar.json" -AUTOGEN_OP_DEF_SEPARATOR = '\n// -----\n\n' 
-AUTOGEN_ENUM_SECTION_MARKER = 'enum section. Generated from SPIR-V spec; DO NOT MODIFY!' +AUTOGEN_OP_DEF_SEPARATOR = "\n// -----\n\n" +AUTOGEN_ENUM_SECTION_MARKER = "enum section. Generated from SPIR-V spec; DO NOT MODIFY!" AUTOGEN_OPCODE_SECTION_MARKER = ( - 'opcode section. Generated from SPIR-V spec; DO NOT MODIFY!') + "opcode section. Generated from SPIR-V spec; DO NOT MODIFY!" +) + def get_spirv_doc_from_html_spec(url, settings): - """Extracts instruction documentation from SPIR-V HTML spec. - - Returns: - - A dict mapping from instruction opcode to documentation. - """ - if url is None: - url = SPIRV_HTML_SPEC_URL - - response = requests.get(url) - spec = response.content - - from bs4 import BeautifulSoup - spirv = BeautifulSoup(spec, 'html.parser') - - doc = {} - - if settings.gen_cl_ops: - section_anchor = spirv.find('h2', {'id': '_binary_form'}) - for section in section_anchor.parent.find_all('div', {'class': 'sect2'}): - for table in section.find_all('table'): - inst_html = table.tbody.tr.td - opname = inst_html.a['id'] - # Ignore the first line, which is just the opname. - doc[opname] = inst_html.text.split('\n', 1)[1].strip() - else: - section_anchor = spirv.find('h3', {'id': '_instructions_3'}) - for section in section_anchor.parent.find_all('div', {'class': 'sect3'}): - for table in section.find_all('table'): - inst_html = table.tbody.tr.td.p - opname = inst_html.a['id'] - # Ignore the first line, which is just the opname. - doc[opname] = inst_html.text.split('\n', 1)[1].strip() - - return doc + """Extracts instruction documentation from SPIR-V HTML spec. + + Returns: + - A dict mapping from instruction opcode to documentation. + """ + if url is None: + url = SPIRV_HTML_SPEC_URL + + response = requests.get(url) + spec = response.content + + from bs4 import BeautifulSoup + + spirv = BeautifulSoup(spec, "html.parser") + + doc = {} + + if settings.gen_cl_ops: + section_anchor = spirv.find("h2", {"id": "_binary_form"}) + for section in section_anchor.parent.find_all("div", {"class": "sect2"}): + for table in section.find_all("table"): + inst_html = table.tbody.tr.td + opname = inst_html.a["id"] + # Ignore the first line, which is just the opname. + doc[opname] = inst_html.text.split("\n", 1)[1].strip() + else: + section_anchor = spirv.find("h3", {"id": "_instructions_3"}) + for section in section_anchor.parent.find_all("div", {"class": "sect3"}): + for table in section.find_all("table"): + inst_html = table.tbody.tr.td.p + opname = inst_html.a["id"] + # Ignore the first line, which is just the opname. + doc[opname] = inst_html.text.split("\n", 1)[1].strip() + + return doc def get_spirv_grammar_from_json_spec(url): - """Extracts operand kind and instruction grammar from SPIR-V JSON spec. + """Extracts operand kind and instruction grammar from SPIR-V JSON spec. 
- Returns: - - A list containing all operand kinds' grammar - - A list containing all instructions' grammar - """ - response = requests.get(SPIRV_JSON_SPEC_URL) - spec = response.content + Returns: + - A list containing all operand kinds' grammar + - A list containing all instructions' grammar + """ + response = requests.get(SPIRV_JSON_SPEC_URL) + spec = response.content - import json - spirv = json.loads(spec) + import json - if url is None: - return spirv['operand_kinds'], spirv['instructions'] + spirv = json.loads(spec) - response_ext = requests.get(url) - spec_ext = response_ext.content - spirv_ext = json.loads(spec_ext) + if url is None: + return spirv["operand_kinds"], spirv["instructions"] - return spirv['operand_kinds'], spirv_ext['instructions'] + response_ext = requests.get(url) + spec_ext = response_ext.content + spirv_ext = json.loads(spec_ext) + + return spirv["operand_kinds"], spirv_ext["instructions"] def split_list_into_sublists(items): - """Split the list of items into multiple sublists. + """Split the list of items into multiple sublists. - This is to make sure the string composed from each sublist won't exceed - 80 characters. + This is to make sure the string composed from each sublist won't exceed + 80 characters. - Arguments: - - items: a list of strings - """ - chuncks = [] - chunk = [] - chunk_len = 0 + Arguments: + - items: a list of strings + """ + chuncks = [] + chunk = [] + chunk_len = 0 - for item in items: - chunk_len += len(item) + 2 - if chunk_len > 80: - chuncks.append(chunk) - chunk = [] - chunk_len = len(item) + 2 - chunk.append(item) + for item in items: + chunk_len += len(item) + 2 + if chunk_len > 80: + chuncks.append(chunk) + chunk = [] + chunk_len = len(item) + 2 + chunk.append(item) - if len(chunk) != 0: - chuncks.append(chunk) + if len(chunk) != 0: + chuncks.append(chunk) - return chuncks + return chuncks def uniquify_enum_cases(lst): - """Prunes duplicate enum cases from the list. - - Arguments: - - lst: List whose elements are to be uniqued. Assumes each element is a - (symbol, value) pair and elements already sorted according to value. - - Returns: - - A list with all duplicates removed. The elements are sorted according to - value and, for each value, uniqued according to symbol. - original list, - - A map from deduplicated cases to the uniqued case. - """ - cases = lst - uniqued_cases = [] - duplicated_cases = {} - - # First sort according to the value - cases.sort(key=lambda x: x[1]) - - # Then group them according to the value - for _, groups in itertools.groupby(cases, key=lambda x: x[1]): - # For each value, sort according to the enumerant symbol. - sorted_group = sorted(groups, key=lambda x: x[0]) - # Keep the "smallest" case, which is typically the symbol without extension - # suffix. But we have special cases that we want to fix. - case = sorted_group[0] - for i in range(1, len(sorted_group)): - duplicated_cases[sorted_group[i][0]] = case[0] - if case[0] == 'HlslSemanticGOOGLE': - assert len(sorted_group) == 2, 'unexpected new variant for HlslSemantic' - case = sorted_group[1] - duplicated_cases[sorted_group[0][0]] = case[0] - uniqued_cases.append(case) - - return uniqued_cases, duplicated_cases + """Prunes duplicate enum cases from the list. + + Arguments: + - lst: List whose elements are to be uniqued. Assumes each element is a + (symbol, value) pair and elements already sorted according to value. + + Returns: + - A list with all duplicates removed. 
The elements are sorted according to + value and, for each value, uniqued according to symbol. + original list, + - A map from deduplicated cases to the uniqued case. + """ + cases = lst + uniqued_cases = [] + duplicated_cases = {} + + # First sort according to the value + cases.sort(key=lambda x: x[1]) + + # Then group them according to the value + for _, groups in itertools.groupby(cases, key=lambda x: x[1]): + # For each value, sort according to the enumerant symbol. + sorted_group = sorted(groups, key=lambda x: x[0]) + # Keep the "smallest" case, which is typically the symbol without extension + # suffix. But we have special cases that we want to fix. + case = sorted_group[0] + for i in range(1, len(sorted_group)): + duplicated_cases[sorted_group[i][0]] = case[0] + if case[0] == "HlslSemanticGOOGLE": + assert len(sorted_group) == 2, "unexpected new variant for HlslSemantic" + case = sorted_group[1] + duplicated_cases[sorted_group[0][0]] = case[0] + uniqued_cases.append(case) + + return uniqued_cases, duplicated_cases def toposort(dag, sort_fn): - """Topologically sorts the given dag. + """Topologically sorts the given dag. - Arguments: - - dag: a dict mapping from a node to its incoming nodes. - - sort_fn: a function for sorting nodes in the same batch. + Arguments: + - dag: a dict mapping from a node to its incoming nodes. + - sort_fn: a function for sorting nodes in the same batch. - Returns: - A list containing topologically sorted nodes. - """ + Returns: + A list containing topologically sorted nodes. + """ - # Returns the next batch of nodes without incoming edges - def get_next_batch(dag): - while True: - no_prev_nodes = set(node for node, prev in dag.items() if not prev) - if not no_prev_nodes: - break - yield sorted(no_prev_nodes, key=sort_fn) - dag = { - node: (prev - no_prev_nodes) - for node, prev in dag.items() - if node not in no_prev_nodes - } - assert not dag, 'found cyclic dependency' + # Returns the next batch of nodes without incoming edges + def get_next_batch(dag): + while True: + no_prev_nodes = set(node for node, prev in dag.items() if not prev) + if not no_prev_nodes: + break + yield sorted(no_prev_nodes, key=sort_fn) + dag = { + node: (prev - no_prev_nodes) + for node, prev in dag.items() + if node not in no_prev_nodes + } + assert not dag, "found cyclic dependency" - sorted_nodes = [] - for batch in get_next_batch(dag): - sorted_nodes.extend(batch) + sorted_nodes = [] + for batch in get_next_batch(dag): + sorted_nodes.extend(batch) - return sorted_nodes + return sorted_nodes def toposort_capabilities(all_cases, capability_mapping): - """Returns topologically sorted capability (symbol, value) pairs. - - Arguments: - - all_cases: all capability cases (containing symbol, value, and implied - capabilities). - - capability_mapping: mapping from duplicated capability symbols to the - canonicalized symbol chosen for SPIRVBase.td. - - Returns: - A list containing topologically sorted capability (symbol, value) pairs. - """ - dag = {} - name_to_value = {} - for case in all_cases: - # Get the current capability. - cur = case['enumerant'] - name_to_value[cur] = case['value'] - # Ignore duplicated symbols. - if cur in capability_mapping: - continue - - # Get capabilities implied by the current capability. - prev = case.get('capabilities', []) - uniqued_prev = set([capability_mapping.get(c, c) for c in prev]) - dag[cur] = uniqued_prev - - sorted_caps = toposort(dag, lambda x: name_to_value[x]) - # Attach the capability's value as the second component of the pair. 
- return [(c, name_to_value[c]) for c in sorted_caps] + """Returns topologically sorted capability (symbol, value) pairs. + + Arguments: + - all_cases: all capability cases (containing symbol, value, and implied + capabilities). + - capability_mapping: mapping from duplicated capability symbols to the + canonicalized symbol chosen for SPIRVBase.td. + + Returns: + A list containing topologically sorted capability (symbol, value) pairs. + """ + dag = {} + name_to_value = {} + for case in all_cases: + # Get the current capability. + cur = case["enumerant"] + name_to_value[cur] = case["value"] + # Ignore duplicated symbols. + if cur in capability_mapping: + continue + + # Get capabilities implied by the current capability. + prev = case.get("capabilities", []) + uniqued_prev = set([capability_mapping.get(c, c) for c in prev]) + dag[cur] = uniqued_prev + + sorted_caps = toposort(dag, lambda x: name_to_value[x]) + # Attach the capability's value as the second component of the pair. + return [(c, name_to_value[c]) for c in sorted_caps] def get_capability_mapping(operand_kinds): - """Returns the capability mapping from duplicated cases to canonicalized ones. + """Returns the capability mapping from duplicated cases to canonicalized ones. - Arguments: - - operand_kinds: all operand kinds' grammar spec + Arguments: + - operand_kinds: all operand kinds' grammar spec - Returns: - - A map mapping from duplicated capability symbols to the canonicalized - symbol chosen for SPIRVBase.td. - """ - # Find the operand kind for capability - cap_kind = {} - for kind in operand_kinds: - if kind['kind'] == 'Capability': - cap_kind = kind + Returns: + - A map mapping from duplicated capability symbols to the canonicalized + symbol chosen for SPIRVBase.td. + """ + # Find the operand kind for capability + cap_kind = {} + for kind in operand_kinds: + if kind["kind"] == "Capability": + cap_kind = kind - kind_cases = [ - (case['enumerant'], case['value']) for case in cap_kind['enumerants'] - ] - _, capability_mapping = uniquify_enum_cases(kind_cases) + kind_cases = [(case["enumerant"], case["value"]) for case in cap_kind["enumerants"]] + _, capability_mapping = uniquify_enum_cases(kind_cases) - return capability_mapping + return capability_mapping def get_availability_spec(enum_case, capability_mapping, for_op, for_cap): - """Returns the availability specification string for the given enum case. - - Arguments: - - enum_case: the enum case to generate availability spec for. It may contain - 'version', 'lastVersion', 'extensions', or 'capabilities'. - - capability_mapping: mapping from duplicated capability symbols to the - canonicalized symbol chosen for SPIRVBase.td. - - for_op: bool value indicating whether this is the availability spec for an - op itself. - - for_cap: bool value indicating whether this is the availability spec for - capabilities themselves. - - Returns: - - A `let availability = [...];` string if with availability spec or - empty string if without availability spec - """ - assert not (for_op and for_cap), 'cannot set both for_op and for_cap' - - DEFAULT_MIN_VERSION = 'MinVersion' - DEFAULT_MAX_VERSION = 'MaxVersion' - DEFAULT_CAP = 'Capability<[]>' - DEFAULT_EXT = 'Extension<[]>' - - min_version = enum_case.get('version', '') - if min_version == 'None': - min_version = '' - elif min_version: - min_version = 'MinVersion'.format(min_version.replace('.', '_')) - # TODO: delete this once ODS can support dialect-specific content - # and we can use omission to mean no requirements. 
- if for_op and not min_version: - min_version = DEFAULT_MIN_VERSION - - max_version = enum_case.get('lastVersion', '') - if max_version: - max_version = 'MaxVersion'.format(max_version.replace('.', '_')) - # TODO: delete this once ODS can support dialect-specific content - # and we can use omission to mean no requirements. - if for_op and not max_version: - max_version = DEFAULT_MAX_VERSION - - exts = enum_case.get('extensions', []) - if exts: - exts = 'Extension<[{}]>'.format(', '.join(sorted(set(exts)))) - # We need to strip the minimal version requirement if this symbol is - # available via an extension, which means *any* SPIR-V version can support - # it as long as the extension is provided. The grammar's 'version' field - # under such case should be interpreted as this symbol is introduced as - # a core symbol since the given version, rather than a minimal version - # requirement. - min_version = DEFAULT_MIN_VERSION if for_op else '' - # TODO: delete this once ODS can support dialect-specific content - # and we can use omission to mean no requirements. - if for_op and not exts: - exts = DEFAULT_EXT - - caps = enum_case.get('capabilities', []) - implies = '' - if caps: - canonicalized_caps = [] - for c in caps: - if c in capability_mapping: - canonicalized_caps.append(capability_mapping[c]) - else: - canonicalized_caps.append(c) - prefixed_caps = [ - 'SPIRV_C_{}'.format(c) for c in sorted(set(canonicalized_caps)) - ] - if for_cap: - # If this is generating the availability for capabilities, we need to - # put the capability "requirements" in implies field because now - # the "capabilities" field in the source grammar means so. - caps = '' - implies = 'list implies = [{}];'.format( - ', '.join(prefixed_caps)) - else: - caps = 'Capability<[{}]>'.format(', '.join(prefixed_caps)) - implies = '' - # TODO: delete this once ODS can support dialect-specific content - # and we can use omission to mean no requirements. - if for_op and not caps: - caps = DEFAULT_CAP - - avail = '' - # Compose availability spec if any of the requirements is not empty. - # For ops, because we have a default in SPIRV_Op class, omit if the spec - # is the same. - if (min_version or max_version or caps or exts) and not ( - for_op and min_version == DEFAULT_MIN_VERSION and - max_version == DEFAULT_MAX_VERSION and caps == DEFAULT_CAP and - exts == DEFAULT_EXT): - joined_spec = ',\n '.join( - [e for e in [min_version, max_version, exts, caps] if e]) - avail = '{} availability = [\n {}\n ];'.format( - 'let' if for_op else 'list', joined_spec) - - return '{}{}{}'.format(implies, '\n ' if implies and avail else '', avail) + """Returns the availability specification string for the given enum case. + + Arguments: + - enum_case: the enum case to generate availability spec for. It may contain + 'version', 'lastVersion', 'extensions', or 'capabilities'. + - capability_mapping: mapping from duplicated capability symbols to the + canonicalized symbol chosen for SPIRVBase.td. + - for_op: bool value indicating whether this is the availability spec for an + op itself. + - for_cap: bool value indicating whether this is the availability spec for + capabilities themselves. 
+ + Returns: + - A `let availability = [...];` string if with availability spec or + empty string if without availability spec + """ + assert not (for_op and for_cap), "cannot set both for_op and for_cap" + + DEFAULT_MIN_VERSION = "MinVersion" + DEFAULT_MAX_VERSION = "MaxVersion" + DEFAULT_CAP = "Capability<[]>" + DEFAULT_EXT = "Extension<[]>" + + min_version = enum_case.get("version", "") + if min_version == "None": + min_version = "" + elif min_version: + min_version = "MinVersion".format(min_version.replace(".", "_")) + # TODO: delete this once ODS can support dialect-specific content + # and we can use omission to mean no requirements. + if for_op and not min_version: + min_version = DEFAULT_MIN_VERSION + + max_version = enum_case.get("lastVersion", "") + if max_version: + max_version = "MaxVersion".format(max_version.replace(".", "_")) + # TODO: delete this once ODS can support dialect-specific content + # and we can use omission to mean no requirements. + if for_op and not max_version: + max_version = DEFAULT_MAX_VERSION + + exts = enum_case.get("extensions", []) + if exts: + exts = "Extension<[{}]>".format(", ".join(sorted(set(exts)))) + # We need to strip the minimal version requirement if this symbol is + # available via an extension, which means *any* SPIR-V version can support + # it as long as the extension is provided. The grammar's 'version' field + # under such case should be interpreted as this symbol is introduced as + # a core symbol since the given version, rather than a minimal version + # requirement. + min_version = DEFAULT_MIN_VERSION if for_op else "" + # TODO: delete this once ODS can support dialect-specific content + # and we can use omission to mean no requirements. + if for_op and not exts: + exts = DEFAULT_EXT + + caps = enum_case.get("capabilities", []) + implies = "" + if caps: + canonicalized_caps = [] + for c in caps: + if c in capability_mapping: + canonicalized_caps.append(capability_mapping[c]) + else: + canonicalized_caps.append(c) + prefixed_caps = [ + "SPIRV_C_{}".format(c) for c in sorted(set(canonicalized_caps)) + ] + if for_cap: + # If this is generating the availability for capabilities, we need to + # put the capability "requirements" in implies field because now + # the "capabilities" field in the source grammar means so. + caps = "" + implies = "list implies = [{}];".format( + ", ".join(prefixed_caps) + ) + else: + caps = "Capability<[{}]>".format(", ".join(prefixed_caps)) + implies = "" + # TODO: delete this once ODS can support dialect-specific content + # and we can use omission to mean no requirements. + if for_op and not caps: + caps = DEFAULT_CAP + + avail = "" + # Compose availability spec if any of the requirements is not empty. + # For ops, because we have a default in SPIRV_Op class, omit if the spec + # is the same. + if (min_version or max_version or caps or exts) and not ( + for_op + and min_version == DEFAULT_MIN_VERSION + and max_version == DEFAULT_MAX_VERSION + and caps == DEFAULT_CAP + and exts == DEFAULT_EXT + ): + joined_spec = ",\n ".join( + [e for e in [min_version, max_version, exts, caps] if e] + ) + avail = "{} availability = [\n {}\n ];".format( + "let" if for_op else "list", joined_spec + ) + + return "{}{}{}".format(implies, "\n " if implies and avail else "", avail) def gen_operand_kind_enum_attr(operand_kind, capability_mapping): - """Generates the TableGen EnumAttr definition for the given operand kind. 
- - Returns: - - The operand kind's name - - A string containing the TableGen EnumAttr definition - """ - if 'enumerants' not in operand_kind: - return '', '' - - # Returns a symbol for the given case in the given kind. This function - # handles Dim specially to avoid having numbers as the start of symbols, - # which does not play well with C++ and the MLIR parser. - def get_case_symbol(kind_name, case_name): - if kind_name == 'Dim': - if case_name == '1D' or case_name == '2D' or case_name == '3D': - return 'Dim{}'.format(case_name) - return case_name - - kind_name = operand_kind['kind'] - is_bit_enum = operand_kind['category'] == 'BitEnum' - kind_acronym = ''.join([c for c in kind_name if c >= 'A' and c <= 'Z']) - - name_to_case_dict = {} - for case in operand_kind['enumerants']: - name_to_case_dict[case['enumerant']] = case - - if kind_name == 'Capability': - # Special treatment for capability cases: we need to sort them topologically - # because a capability can refer to another via the 'implies' field. - kind_cases = toposort_capabilities(operand_kind['enumerants'], - capability_mapping) - else: - kind_cases = [(case['enumerant'], case['value']) - for case in operand_kind['enumerants']] - kind_cases, _ = uniquify_enum_cases(kind_cases) - max_len = max([len(symbol) for (symbol, _) in kind_cases]) - - # Generate the definition for each enum case - case_category = 'I32Bit' if is_bit_enum else 'I32' - fmt_str = 'def SPIRV_{acronym}_{case_name} {colon:>{offset}} '\ - '{category}EnumAttrCase{suffix}<"{symbol}"{case_value_part}>{avail}' - case_defs = [] - for case_pair in kind_cases: - name = case_pair[0] - if is_bit_enum: - value = int(case_pair[1], base=16) + """Generates the TableGen EnumAttr definition for the given operand kind. + + Returns: + - The operand kind's name + - A string containing the TableGen EnumAttr definition + """ + if "enumerants" not in operand_kind: + return "", "" + + # Returns a symbol for the given case in the given kind. This function + # handles Dim specially to avoid having numbers as the start of symbols, + # which does not play well with C++ and the MLIR parser. + def get_case_symbol(kind_name, case_name): + if kind_name == "Dim": + if case_name == "1D" or case_name == "2D" or case_name == "3D": + return "Dim{}".format(case_name) + return case_name + + kind_name = operand_kind["kind"] + is_bit_enum = operand_kind["category"] == "BitEnum" + kind_acronym = "".join([c for c in kind_name if c >= "A" and c <= "Z"]) + + name_to_case_dict = {} + for case in operand_kind["enumerants"]: + name_to_case_dict[case["enumerant"]] = case + + if kind_name == "Capability": + # Special treatment for capability cases: we need to sort them topologically + # because a capability can refer to another via the 'implies' field. 
+ kind_cases = toposort_capabilities( + operand_kind["enumerants"], capability_mapping + ) else: - value = int(case_pair[1]) - avail = get_availability_spec(name_to_case_dict[name], - capability_mapping, - False, kind_name == 'Capability') - if is_bit_enum: - if value == 0: - suffix = 'None' - value = '' - else: - suffix = "Bit" - value = ', {}'.format(int(math.log2(value))) - else: - suffix = '' - value = ', {}'.format(value) - - case_def = fmt_str.format( - category=case_category, - suffix=suffix, - acronym=kind_acronym, - case_name=name, - symbol=get_case_symbol(kind_name, name), - case_value_part=value, - avail=' {{\n {}\n}}'.format(avail) if avail else ';', - colon=':', - offset=(max_len + 1 - len(name))) - case_defs.append(case_def) - case_defs = '\n'.join(case_defs) - - # Generate the list of enum case names - fmt_str = 'SPIRV_{acronym}_{symbol}'; - case_names = [fmt_str.format(acronym=kind_acronym,symbol=case[0]) - for case in kind_cases] - - # Split them into sublists and concatenate into multiple lines - case_names = split_list_into_sublists(case_names) - case_names = ['{:6}'.format('') + ', '.join(sublist) - for sublist in case_names] - case_names = ',\n'.join(case_names) - - # Generate the enum attribute definition - kind_category = 'Bit' if is_bit_enum else 'I32' - enum_attr = '''def SPIRV_{name}Attr : + kind_cases = [ + (case["enumerant"], case["value"]) for case in operand_kind["enumerants"] + ] + kind_cases, _ = uniquify_enum_cases(kind_cases) + max_len = max([len(symbol) for (symbol, _) in kind_cases]) + + # Generate the definition for each enum case + case_category = "I32Bit" if is_bit_enum else "I32" + fmt_str = ( + "def SPIRV_{acronym}_{case_name} {colon:>{offset}} " + '{category}EnumAttrCase{suffix}<"{symbol}"{case_value_part}>{avail}' + ) + case_defs = [] + for case_pair in kind_cases: + name = case_pair[0] + if is_bit_enum: + value = int(case_pair[1], base=16) + else: + value = int(case_pair[1]) + avail = get_availability_spec( + name_to_case_dict[name], + capability_mapping, + False, + kind_name == "Capability", + ) + if is_bit_enum: + if value == 0: + suffix = "None" + value = "" + else: + suffix = "Bit" + value = ", {}".format(int(math.log2(value))) + else: + suffix = "" + value = ", {}".format(value) + + case_def = fmt_str.format( + category=case_category, + suffix=suffix, + acronym=kind_acronym, + case_name=name, + symbol=get_case_symbol(kind_name, name), + case_value_part=value, + avail=" {{\n {}\n}}".format(avail) if avail else ";", + colon=":", + offset=(max_len + 1 - len(name)), + ) + case_defs.append(case_def) + case_defs = "\n".join(case_defs) + + # Generate the list of enum case names + fmt_str = "SPIRV_{acronym}_{symbol}" + case_names = [ + fmt_str.format(acronym=kind_acronym, symbol=case[0]) for case in kind_cases + ] + + # Split them into sublists and concatenate into multiple lines + case_names = split_list_into_sublists(case_names) + case_names = ["{:6}".format("") + ", ".join(sublist) for sublist in case_names] + case_names = ",\n".join(case_names) + + # Generate the enum attribute definition + kind_category = "Bit" if is_bit_enum else "I32" + enum_attr = """def SPIRV_{name}Attr : SPIRV_{category}EnumAttr<"{name}", "valid SPIR-V {name}", "{snake_name}", [ {cases} - ]>;'''.format( - name=kind_name, - snake_name=snake_casify(kind_name), - category=kind_category, - cases=case_names) - return kind_name, case_defs + '\n\n' + enum_attr + ]>;""".format( + name=kind_name, + snake_name=snake_casify(kind_name), + category=kind_category, + cases=case_names, + 
) + return kind_name, case_defs + "\n\n" + enum_attr def gen_opcode(instructions): - """ Generates the TableGen definition to map opname to opcode - - Returns: - - A string containing the TableGen SPIRV_OpCode definition - """ - - max_len = max([len(inst['opname']) for inst in instructions]) - def_fmt_str = 'def SPIRV_OC_{name} {colon:>{offset}} '\ - 'I32EnumAttrCase<"{name}", {value}>;' - opcode_defs = [ - def_fmt_str.format( - name=inst['opname'], - value=inst['opcode'], - colon=':', - offset=(max_len + 1 - len(inst['opname']))) for inst in instructions - ] - opcode_str = '\n'.join(opcode_defs) - - decl_fmt_str = 'SPIRV_OC_{name}' - opcode_list = [ - decl_fmt_str.format(name=inst['opname']) for inst in instructions - ] - opcode_list = split_list_into_sublists(opcode_list) - opcode_list = [ - '{:6}'.format('') + ', '.join(sublist) for sublist in opcode_list - ] - opcode_list = ',\n'.join(opcode_list) - enum_attr = 'def SPIRV_OpcodeAttr :\n'\ - ' SPIRV_I32EnumAttr<"{name}", "valid SPIR-V instructions", '\ - '"opcode", [\n'\ - '{lst}\n'\ - ' ]>;'.format(name='Opcode', lst=opcode_list) - return opcode_str + '\n\n' + enum_attr + """Generates the TableGen definition to map opname to opcode + + Returns: + - A string containing the TableGen SPIRV_OpCode definition + """ + + max_len = max([len(inst["opname"]) for inst in instructions]) + def_fmt_str = ( + "def SPIRV_OC_{name} {colon:>{offset}} " 'I32EnumAttrCase<"{name}", {value}>;' + ) + opcode_defs = [ + def_fmt_str.format( + name=inst["opname"], + value=inst["opcode"], + colon=":", + offset=(max_len + 1 - len(inst["opname"])), + ) + for inst in instructions + ] + opcode_str = "\n".join(opcode_defs) + + decl_fmt_str = "SPIRV_OC_{name}" + opcode_list = [decl_fmt_str.format(name=inst["opname"]) for inst in instructions] + opcode_list = split_list_into_sublists(opcode_list) + opcode_list = ["{:6}".format("") + ", ".join(sublist) for sublist in opcode_list] + opcode_list = ",\n".join(opcode_list) + enum_attr = ( + "def SPIRV_OpcodeAttr :\n" + ' SPIRV_I32EnumAttr<"{name}", "valid SPIR-V instructions", ' + '"opcode", [\n' + "{lst}\n" + " ]>;".format(name="Opcode", lst=opcode_list) + ) + return opcode_str + "\n\n" + enum_attr + def map_cap_to_opnames(instructions): - """Maps capabilities to instructions enabled by those capabilities + """Maps capabilities to instructions enabled by those capabilities - Arguments: - - instructions: a list containing a subset of SPIR-V instructions' grammar - Returns: - - A map with keys representing capabilities and values of lists of - instructions enabled by the corresponding key - """ - cap_to_inst = {} + Arguments: + - instructions: a list containing a subset of SPIR-V instructions' grammar + Returns: + - A map with keys representing capabilities and values of lists of + instructions enabled by the corresponding key + """ + cap_to_inst = {} - for inst in instructions: - caps = inst['capabilities'] if 'capabilities' in inst else ['0_core_0'] - for cap in caps: - if cap not in cap_to_inst: - cap_to_inst[cap] = [] - cap_to_inst[cap].append(inst['opname']) + for inst in instructions: + caps = inst["capabilities"] if "capabilities" in inst else ["0_core_0"] + for cap in caps: + if cap not in cap_to_inst: + cap_to_inst[cap] = [] + cap_to_inst[cap].append(inst["opname"]) + + return cap_to_inst - return cap_to_inst def gen_instr_coverage_report(path, instructions): - """Dumps to standard output a YAML report of current instruction coverage + """Dumps to standard output a YAML report of current instruction coverage - 
Arguments: - - path: the path to SPIRBase.td - - instructions: a list containing all SPIR-V instructions' grammar - """ - with open(path, 'r') as f: - content = f.read() + Arguments: + - path: the path to SPIRBase.td + - instructions: a list containing all SPIR-V instructions' grammar + """ + with open(path, "r") as f: + content = f.read() - content = content.split(AUTOGEN_OPCODE_SECTION_MARKER) + content = content.split(AUTOGEN_OPCODE_SECTION_MARKER) - existing_opcodes = [k[11:] for k in re.findall('def SPIRV_OC_\w+', content[1])] - existing_instructions = list( - filter(lambda inst: (inst['opname'] in existing_opcodes), - instructions)) + existing_opcodes = [k[11:] for k in re.findall("def SPIRV_OC_\w+", content[1])] + existing_instructions = list( + filter(lambda inst: (inst["opname"] in existing_opcodes), instructions) + ) - instructions_opnames = [inst['opname'] for inst in instructions] + instructions_opnames = [inst["opname"] for inst in instructions] - remaining_opcodes = list(set(instructions_opnames) - set(existing_opcodes)) - remaining_instructions = list( - filter(lambda inst: (inst['opname'] in remaining_opcodes), - instructions)) + remaining_opcodes = list(set(instructions_opnames) - set(existing_opcodes)) + remaining_instructions = list( + filter(lambda inst: (inst["opname"] in remaining_opcodes), instructions) + ) - rem_cap_to_instr = map_cap_to_opnames(remaining_instructions) - ex_cap_to_instr = map_cap_to_opnames(existing_instructions) + rem_cap_to_instr = map_cap_to_opnames(remaining_instructions) + ex_cap_to_instr = map_cap_to_opnames(existing_instructions) - rem_cap_to_cov = {} + rem_cap_to_cov = {} - # Calculate coverage for each capability - for cap in rem_cap_to_instr: - if cap not in ex_cap_to_instr: - rem_cap_to_cov[cap] = 0.0 - else: - rem_cap_to_cov[cap] = \ - (len(ex_cap_to_instr[cap]) / (len(ex_cap_to_instr[cap]) \ - + len(rem_cap_to_instr[cap]))) + # Calculate coverage for each capability + for cap in rem_cap_to_instr: + if cap not in ex_cap_to_instr: + rem_cap_to_cov[cap] = 0.0 + else: + rem_cap_to_cov[cap] = len(ex_cap_to_instr[cap]) / ( + len(ex_cap_to_instr[cap]) + len(rem_cap_to_instr[cap]) + ) - report = {} + report = {} - # Merge the 3 maps into one report - for cap in rem_cap_to_instr: - report[cap] = {} - report[cap]['Supported Instructions'] = \ + # Merge the 3 maps into one report + for cap in rem_cap_to_instr: + report[cap] = {} + report[cap]["Supported Instructions"] = ( ex_cap_to_instr[cap] if cap in ex_cap_to_instr else [] - report[cap]['Unsupported Instructions'] = rem_cap_to_instr[cap] - report[cap]['Coverage'] = '{}%'.format(int(rem_cap_to_cov[cap] * 100)) + ) + report[cap]["Unsupported Instructions"] = rem_cap_to_instr[cap] + report[cap]["Coverage"] = "{}%".format(int(rem_cap_to_cov[cap] * 100)) + + print(yaml.dump(report)) - print(yaml.dump(report)) def update_td_opcodes(path, instructions, filter_list): - """Updates SPIRBase.td with new generated opcode cases. 
- - Arguments: - - path: the path to SPIRBase.td - - instructions: a list containing all SPIR-V instructions' grammar - - filter_list: a list containing new opnames to add - """ - - with open(path, 'r') as f: - content = f.read() - - content = content.split(AUTOGEN_OPCODE_SECTION_MARKER) - assert len(content) == 3 - - # Extend opcode list with existing list - prefix = 'def SPIRV_OC_' - existing_opcodes = [k[len(prefix):] for k in re.findall(prefix + '\w+', content[1])] - filter_list.extend(existing_opcodes) - filter_list = list(set(filter_list)) - - # Generate the opcode for all instructions in SPIR-V - filter_instrs = list( - filter(lambda inst: (inst['opname'] in filter_list), instructions)) - # Sort instruction based on opcode - filter_instrs.sort(key=lambda inst: inst['opcode']) - opcode = gen_opcode(filter_instrs) - - # Substitute the opcode - content = content[0] + AUTOGEN_OPCODE_SECTION_MARKER + '\n\n' + \ - opcode + '\n\n// End ' + AUTOGEN_OPCODE_SECTION_MARKER \ + """Updates SPIRBase.td with new generated opcode cases. + + Arguments: + - path: the path to SPIRBase.td + - instructions: a list containing all SPIR-V instructions' grammar + - filter_list: a list containing new opnames to add + """ + + with open(path, "r") as f: + content = f.read() + + content = content.split(AUTOGEN_OPCODE_SECTION_MARKER) + assert len(content) == 3 + + # Extend opcode list with existing list + prefix = "def SPIRV_OC_" + existing_opcodes = [ + k[len(prefix) :] for k in re.findall(prefix + "\w+", content[1]) + ] + filter_list.extend(existing_opcodes) + filter_list = list(set(filter_list)) + + # Generate the opcode for all instructions in SPIR-V + filter_instrs = list( + filter(lambda inst: (inst["opname"] in filter_list), instructions) + ) + # Sort instruction based on opcode + filter_instrs.sort(key=lambda inst: inst["opcode"]) + opcode = gen_opcode(filter_instrs) + + # Substitute the opcode + content = ( + content[0] + + AUTOGEN_OPCODE_SECTION_MARKER + + "\n\n" + + opcode + + "\n\n// End " + + AUTOGEN_OPCODE_SECTION_MARKER + content[2] + ) - with open(path, 'w') as f: - f.write(content) + with open(path, "w") as f: + f.write(content) def update_td_enum_attrs(path, operand_kinds, filter_list): - """Updates SPIRBase.td with new generated enum definitions. 
- - Arguments: - - path: the path to SPIRBase.td - - operand_kinds: a list containing all operand kinds' grammar - - filter_list: a list containing new enums to add - """ - with open(path, 'r') as f: - content = f.read() - - content = content.split(AUTOGEN_ENUM_SECTION_MARKER) - assert len(content) == 3 - - # Extend filter list with existing enum definitions - existing_kinds = [ - k[8:-4] for k in re.findall('def SPIRV_\w+Attr', content[1])] - filter_list.extend(existing_kinds) - - capability_mapping = get_capability_mapping(operand_kinds) - - # Generate definitions for all enums in filter list - defs = [ - gen_operand_kind_enum_attr(kind, capability_mapping) - for kind in operand_kinds - if kind['kind'] in filter_list - ] - # Sort alphabetically according to enum name - defs.sort(key=lambda enum : enum[0]) - # Only keep the definitions from now on - # Put Capability's definition at the very beginning because capability cases - # will be referenced later - defs = [enum[1] for enum in defs if enum[0] == 'Capability' - ] + [enum[1] for enum in defs if enum[0] != 'Capability'] - - # Substitute the old section - content = content[0] + AUTOGEN_ENUM_SECTION_MARKER + '\n\n' + \ - '\n\n'.join(defs) + "\n\n// End " + AUTOGEN_ENUM_SECTION_MARKER \ - + content[2]; - - with open(path, 'w') as f: - f.write(content) + """Updates SPIRBase.td with new generated enum definitions. + + Arguments: + - path: the path to SPIRBase.td + - operand_kinds: a list containing all operand kinds' grammar + - filter_list: a list containing new enums to add + """ + with open(path, "r") as f: + content = f.read() + + content = content.split(AUTOGEN_ENUM_SECTION_MARKER) + assert len(content) == 3 + + # Extend filter list with existing enum definitions + existing_kinds = [k[8:-4] for k in re.findall("def SPIRV_\w+Attr", content[1])] + filter_list.extend(existing_kinds) + + capability_mapping = get_capability_mapping(operand_kinds) + + # Generate definitions for all enums in filter list + defs = [ + gen_operand_kind_enum_attr(kind, capability_mapping) + for kind in operand_kinds + if kind["kind"] in filter_list + ] + # Sort alphabetically according to enum name + defs.sort(key=lambda enum: enum[0]) + # Only keep the definitions from now on + # Put Capability's definition at the very beginning because capability cases + # will be referenced later + defs = [enum[1] for enum in defs if enum[0] == "Capability"] + [ + enum[1] for enum in defs if enum[0] != "Capability" + ] + + # Substitute the old section + content = ( + content[0] + + AUTOGEN_ENUM_SECTION_MARKER + + "\n\n" + + "\n\n".join(defs) + + "\n\n// End " + + AUTOGEN_ENUM_SECTION_MARKER + + content[2] + ) + + with open(path, "w") as f: + f.write(content) def snake_casify(name): - """Turns the given name to follow snake_case convention.""" - return re.sub(r'(?' - else: - arg_type = 'Variadic' - elif kind == 'IdMemorySemantics' or kind == 'IdScope': - # TODO: Need to further constrain 'IdMemorySemantics' - # and 'IdScope' given that they should be generated from OpConstant. - assert quantifier == '', ('unexpected to have optional/variadic memory ' - 'semantics or scope ') - arg_type = 'SPIRV_' + kind[2:] + 'Attr' - elif kind == 'LiteralInteger': - if quantifier == '': - arg_type = 'I32Attr' - elif quantifier == '?': - arg_type = 'OptionalAttr' + """Maps an operand in SPIR-V JSON spec to an op argument in ODS. 
+ + Arguments: + - A dict containing the operand's kind, quantifier, and name + + Returns: + - A string containing both the type and name for the argument + """ + kind = operand["kind"] + quantifier = operand.get("quantifier", "") + + # These instruction "operands" are for encoding the results; they should + # not be handled here. + assert kind != "IdResultType", 'unexpected to handle "IdResultType" kind' + assert kind != "IdResult", 'unexpected to handle "IdResult" kind' + + if kind == "IdRef": + if quantifier == "": + arg_type = "SPIRV_Type" + elif quantifier == "?": + arg_type = "Optional" + else: + arg_type = "Variadic" + elif kind == "IdMemorySemantics" or kind == "IdScope": + # TODO: Need to further constrain 'IdMemorySemantics' + # and 'IdScope' given that they should be generated from OpConstant. + assert quantifier == "", ( + "unexpected to have optional/variadic memory " "semantics or scope " + ) + arg_type = "SPIRV_" + kind[2:] + "Attr" + elif kind == "LiteralInteger": + if quantifier == "": + arg_type = "I32Attr" + elif quantifier == "?": + arg_type = "OptionalAttr" + else: + arg_type = "OptionalAttr" + elif ( + kind == "LiteralString" + or kind == "LiteralContextDependentNumber" + or kind == "LiteralExtInstInteger" + or kind == "LiteralSpecConstantOpInteger" + or kind == "PairLiteralIntegerIdRef" + or kind == "PairIdRefLiteralInteger" + or kind == "PairIdRefIdRef" + ): + assert False, '"{}" kind unimplemented'.format(kind) else: - arg_type = 'OptionalAttr' - elif kind == 'LiteralString' or \ - kind == 'LiteralContextDependentNumber' or \ - kind == 'LiteralExtInstInteger' or \ - kind == 'LiteralSpecConstantOpInteger' or \ - kind == 'PairLiteralIntegerIdRef' or \ - kind == 'PairIdRefLiteralInteger' or \ - kind == 'PairIdRefIdRef': - assert False, '"{}" kind unimplemented'.format(kind) - else: - # The rest are all enum operands that we represent with op attributes. - assert quantifier != '*', 'unexpected to have variadic enum attribute' - arg_type = 'SPIRV_{}Attr'.format(kind) - if quantifier == '?': - arg_type = 'OptionalAttr<{}>'.format(arg_type) - - name = operand.get('name', '') - name = snake_casify(name) if name else kind.lower() - - return '{}:${}'.format(arg_type, name) + # The rest are all enum operands that we represent with op attributes. + assert quantifier != "*", "unexpected to have variadic enum attribute" + arg_type = "SPIRV_{}Attr".format(kind) + if quantifier == "?": + arg_type = "OptionalAttr<{}>".format(arg_type) + + name = operand.get("name", "") + name = snake_casify(name) if name else kind.lower() + + return "{}:${}".format(arg_type, name) def get_description(text, appendix): - """Generates the description for the given SPIR-V instruction. - - Arguments: - - text: Textual description of the operation as string. - - appendix: Additional contents to attach in description as string, - includking IR examples, and others. - - Returns: - - A string that corresponds to the description of the Tablegen op. - """ - fmt_str = '{text}\n\n \n{appendix}\n ' - return fmt_str.format(text=text, appendix=appendix) - - -def get_op_definition(instruction, opname, doc, existing_info, capability_mapping, settings): - """Generates the TableGen op definition for the given SPIR-V instruction. 
- - Arguments: - - instruction: the instruction's SPIR-V JSON grammar - - doc: the instruction's SPIR-V HTML doc - - existing_info: a dict containing potential manually specified sections for - this instruction - - capability_mapping: mapping from duplicated capability symbols to the - canonicalized symbol chosen for SPIRVBase.td - - Returns: - - A string containing the TableGen op definition - """ - if settings.gen_cl_ops: - fmt_str = ('def SPIRV_{opname}Op : ' - 'SPIRV_{inst_category}<"{opname_src}", {opcode}, <> > ' - '{{\n let summary = {summary};\n\n let description = ' - '[{{\n{description}}}];{availability}\n') - else: - fmt_str = ('def SPIRV_{vendor_name}{opname_src}Op : ' - 'SPIRV_{inst_category}<"{opname_src}"{category_args}, [{traits}]> ' - '{{\n let summary = {summary};\n\n let description = ' - '[{{\n{description}}}];{availability}\n') - - vendor_name = '' - inst_category = existing_info.get('inst_category', 'Op') - if inst_category == 'Op': - fmt_str +='\n let arguments = (ins{args});\n\n'\ - ' let results = (outs{results});\n' - elif inst_category.endswith('VendorOp'): - vendor_name = inst_category.split('VendorOp')[0].upper() - assert len(vendor_name) != 0, 'Invalid instruction category' - - fmt_str +='{extras}'\ - '}}\n' - - opname_src = instruction['opname'] - if opname.startswith('Op'): - opname_src = opname_src[2:] - if len(vendor_name) > 0: - assert opname_src.endswith(vendor_name), "op name does not match the instruction category" - opname_src = opname_src[:-len(vendor_name)] - - category_args = existing_info.get('category_args', '') - - if '\n' in doc: - summary, text = doc.split('\n', 1) - else: - summary = doc - text = '' - wrapper = textwrap.TextWrapper( - width=76, initial_indent=' ', subsequent_indent=' ') - - # Format summary. If the summary can fit in the same line, we print it out - # as a "-quoted string; otherwise, wrap the lines using "[{...}]". 
-  summary = summary.strip()
-  if len(summary) + len(' let summary = "";') <= 80:
-    summary = '"{}"'.format(summary)
-  else:
-    summary = '[{{\n{}\n }}]'.format(wrapper.fill(summary))
-
-  # Wrap text
-  text = text.split('\n')
-  text = [wrapper.fill(line) for line in text if line]
-  text = '\n\n'.join(text)
-
-  operands = instruction.get('operands', [])
-
-  # Op availability
-  avail = get_availability_spec(instruction, capability_mapping, True, False)
-  if avail:
-    avail = '\n\n {0}'.format(avail)
-
-  # Set op's result
-  results = ''
-  if len(operands) > 0 and operands[0]['kind'] == 'IdResultType':
-    results = '\n SPIRV_Type:$result\n '
-    operands = operands[1:]
-  if 'results' in existing_info:
-    results = existing_info['results']
-
-  # Ignore the operand standing for the result <id>
-  if len(operands) > 0 and operands[0]['kind'] == 'IdResult':
-    operands = operands[1:]
-
-  # Set op' argument
-  arguments = existing_info.get('arguments', None)
-  if arguments is None:
-    arguments = [map_spec_operand_to_ods_argument(o) for o in operands]
-    arguments = ',\n '.join(arguments)
-  if arguments:
-    # Prepend and append whitespace for formatting
-    arguments = '\n {}\n '.format(arguments)
-
-  description = existing_info.get('description', None)
-  if description is None:
-    assembly = '\n ```\n'\
-               ' [TODO]\n'\
-               ' ```\n\n'\
-               ' #### Example:\n\n'\
-               ' ```mlir\n'\
-               ' [TODO]\n' \
-               ' ```'
-    description = get_description(text, assembly)
-
-  return fmt_str.format(
-      opname=opname,
-      opname_src=opname_src,
-      opcode=instruction['opcode'],
-      category_args=category_args,
-      inst_category=inst_category,
-      vendor_name=vendor_name,
-      traits=existing_info.get('traits', ''),
-      summary=summary,
-      description=description,
-      availability=avail,
-      args=arguments,
-      results=results,
-      extras=existing_info.get('extras', ''))
+    """Generates the description for the given SPIR-V instruction.
+
+    Arguments:
+    - text: Textual description of the operation as string.
+    - appendix: Additional contents to attach in description as string,
+      including IR examples, and others.
+
+    Returns:
+    - A string that corresponds to the description of the TableGen op.
+    """
+    fmt_str = "{text}\n\n \n{appendix}\n "
+    return fmt_str.format(text=text, appendix=appendix)
+
+
+def get_op_definition(
+    instruction, opname, doc, existing_info, capability_mapping, settings
+):
+    """Generates the TableGen op definition for the given SPIR-V instruction.
+ + Arguments: + - instruction: the instruction's SPIR-V JSON grammar + - doc: the instruction's SPIR-V HTML doc + - existing_info: a dict containing potential manually specified sections for + this instruction + - capability_mapping: mapping from duplicated capability symbols to the + canonicalized symbol chosen for SPIRVBase.td + + Returns: + - A string containing the TableGen op definition + """ + if settings.gen_cl_ops: + fmt_str = ( + "def SPIRV_{opname}Op : " + 'SPIRV_{inst_category}<"{opname_src}", {opcode}, <> > ' + "{{\n let summary = {summary};\n\n let description = " + "[{{\n{description}}}];{availability}\n" + ) + else: + fmt_str = ( + "def SPIRV_{vendor_name}{opname_src}Op : " + 'SPIRV_{inst_category}<"{opname_src}"{category_args}, [{traits}]> ' + "{{\n let summary = {summary};\n\n let description = " + "[{{\n{description}}}];{availability}\n" + ) + + vendor_name = "" + inst_category = existing_info.get("inst_category", "Op") + if inst_category == "Op": + fmt_str += ( + "\n let arguments = (ins{args});\n\n" " let results = (outs{results});\n" + ) + elif inst_category.endswith("VendorOp"): + vendor_name = inst_category.split("VendorOp")[0].upper() + assert len(vendor_name) != 0, "Invalid instruction category" + + fmt_str += "{extras}" "}}\n" + + opname_src = instruction["opname"] + if opname.startswith("Op"): + opname_src = opname_src[2:] + if len(vendor_name) > 0: + assert opname_src.endswith( + vendor_name + ), "op name does not match the instruction category" + opname_src = opname_src[: -len(vendor_name)] + + category_args = existing_info.get("category_args", "") + + if "\n" in doc: + summary, text = doc.split("\n", 1) + else: + summary = doc + text = "" + wrapper = textwrap.TextWrapper( + width=76, initial_indent=" ", subsequent_indent=" " + ) + + # Format summary. If the summary can fit in the same line, we print it out + # as a "-quoted string; otherwise, wrap the lines using "[{...}]". 
+    summary = summary.strip()
+    if len(summary) + len(' let summary = "";') <= 80:
+        summary = '"{}"'.format(summary)
+    else:
+        summary = "[{{\n{}\n }}]".format(wrapper.fill(summary))
+
+    # Wrap text
+    text = text.split("\n")
+    text = [wrapper.fill(line) for line in text if line]
+    text = "\n\n".join(text)
+
+    operands = instruction.get("operands", [])
+
+    # Op availability
+    avail = get_availability_spec(instruction, capability_mapping, True, False)
+    if avail:
+        avail = "\n\n {0}".format(avail)
+
+    # Set op's result
+    results = ""
+    if len(operands) > 0 and operands[0]["kind"] == "IdResultType":
+        results = "\n SPIRV_Type:$result\n "
+        operands = operands[1:]
+    if "results" in existing_info:
+        results = existing_info["results"]
+
+    # Ignore the operand standing for the result <id>
+    if len(operands) > 0 and operands[0]["kind"] == "IdResult":
+        operands = operands[1:]
+
+    # Set op's argument
+    arguments = existing_info.get("arguments", None)
+    if arguments is None:
+        arguments = [map_spec_operand_to_ods_argument(o) for o in operands]
+        arguments = ",\n ".join(arguments)
+    if arguments:
+        # Prepend and append whitespace for formatting
+        arguments = "\n {}\n ".format(arguments)
+
+    description = existing_info.get("description", None)
+    if description is None:
+        assembly = (
+            "\n ```\n"
+            " [TODO]\n"
+            " ```\n\n"
+            " #### Example:\n\n"
+            " ```mlir\n"
+            " [TODO]\n"
+            " ```"
+        )
+        description = get_description(text, assembly)
+
+    return fmt_str.format(
+        opname=opname,
+        opname_src=opname_src,
+        opcode=instruction["opcode"],
+        category_args=category_args,
+        inst_category=inst_category,
+        vendor_name=vendor_name,
+        traits=existing_info.get("traits", ""),
+        summary=summary,
+        description=description,
+        availability=avail,
+        args=arguments,
+        results=results,
+        extras=existing_info.get("extras", ""),
+    )


 def get_string_between(base, start, end):
-  """Extracts a substring with a specified start and end from a string.
-
-  Arguments:
-  - base: string to extract from.
-  - start: string to use as the start of the substring.
-  - end: string to use as the end of the substring.
-
-  Returns:
-  - The substring if found
-  - The part of the base after end of the substring. Is the base string itself
-    if the substring wasnt found.
-  """
-  split = base.split(start, 1)
-  if len(split) == 2:
-    rest = split[1].split(end, 1)
-    assert len(rest) == 2, \
-        'cannot find end "{end}" while extracting substring '\
-        'starting with {start}'.format(start=start, end=end)
-    return rest[0].rstrip(end), rest[1]
-  return '', split[0]
+    """Extracts a substring with a specified start and end from a string.
+
+    Arguments:
+    - base: string to extract from.
+    - start: string to use as the start of the substring.
+    - end: string to use as the end of the substring.
+
+    Returns:
+    - The substring if found
+    - The part of the base after end of the substring. Is the base string itself
+      if the substring wasn't found.
+    """
+    split = base.split(start, 1)
+    if len(split) == 2:
+        rest = split[1].split(end, 1)
+        assert len(rest) == 2, (
+            'cannot find end "{end}" while extracting substring '
+            "starting with {start}".format(start=start, end=end)
+        )
+        return rest[0].rstrip(end), rest[1]
+    return "", split[0]


 def get_string_between_nested(base, start, end):
-  """Extracts a substring with a nested start and end from a string.
-
-  Arguments:
-  - base: string to extract from.
-  - start: string to use as the start of the substring.
-  - end: string to use as the end of the substring.
- - Returns: - - The substring if found - - The part of the base after end of the substring. Is the base string itself - if the substring wasn't found. - """ - split = base.split(start, 1) - if len(split) == 2: - # Handle nesting delimiters - rest = split[1] - unmatched_start = 1 - index = 0 - while unmatched_start > 0 and index < len(rest): - if rest[index:].startswith(end): - unmatched_start -= 1 - if unmatched_start == 0: - break - index += len(end) - elif rest[index:].startswith(start): - unmatched_start += 1 - index += len(start) - else: - index += 1 - - assert index < len(rest), \ - 'cannot find end "{end}" while extracting substring '\ - 'starting with "{start}"'.format(start=start, end=end) - return rest[:index], rest[index + len(end):] - return '', split[0] + """Extracts a substring with a nested start and end from a string. + + Arguments: + - base: string to extract from. + - start: string to use as the start of the substring. + - end: string to use as the end of the substring. + + Returns: + - The substring if found + - The part of the base after end of the substring. Is the base string itself + if the substring wasn't found. + """ + split = base.split(start, 1) + if len(split) == 2: + # Handle nesting delimiters + rest = split[1] + unmatched_start = 1 + index = 0 + while unmatched_start > 0 and index < len(rest): + if rest[index:].startswith(end): + unmatched_start -= 1 + if unmatched_start == 0: + break + index += len(end) + elif rest[index:].startswith(start): + unmatched_start += 1 + index += len(start) + else: + index += 1 + + assert index < len(rest), ( + 'cannot find end "{end}" while extracting substring ' + 'starting with "{start}"'.format(start=start, end=end) + ) + return rest[:index], rest[index + len(end) :] + return "", split[0] def extract_td_op_info(op_def): - """Extracts potentially manually specified sections in op's definition. - - Arguments: - A string containing the op's TableGen definition - - Returns: - - A dict containing potential manually specified sections - """ - # Get opname - opname = [o[8:-2] for o in re.findall('def SPIRV_\w+Op', op_def)] - assert len(opname) == 1, 'more than one ops in the same section!' - opname = opname[0] - - # Get instruction category - inst_category = [ - o[4:] for o in re.findall('SPIRV_\w+Op', - op_def.split(':', 1)[1]) - ] - assert len(inst_category) <= 1, 'more than one ops in the same section!' 
- inst_category = inst_category[0] if len(inst_category) == 1 else 'Op' - - # Get category_args - op_tmpl_params, _ = get_string_between_nested(op_def, '<', '>') - opstringname, rest = get_string_between(op_tmpl_params, '"', '"') - category_args = rest.split('[', 1)[0] - - # Get traits - traits, _ = get_string_between_nested(rest, '[', ']') - - # Get description - description, rest = get_string_between(op_def, 'let description = [{\n', - '}];\n') - - # Get arguments - args, rest = get_string_between(rest, ' let arguments = (ins', ');\n') - - # Get results - results, rest = get_string_between(rest, ' let results = (outs', ');\n') - - extras = rest.strip(' }\n') - if extras: - extras = '\n {}\n'.format(extras) - - return { - # Prefix with 'Op' to make it consistent with SPIR-V spec - 'opname': 'Op{}'.format(opname), - 'inst_category': inst_category, - 'category_args': category_args, - 'traits': traits, - 'description': description, - 'arguments': args, - 'results': results, - 'extras': extras - } - - -def update_td_op_definitions(path, instructions, docs, filter_list, - inst_category, capability_mapping, settings): - """Updates SPIRVOps.td with newly generated op definition. - - Arguments: - - path: path to SPIRVOps.td - - instructions: SPIR-V JSON grammar for all instructions - - docs: SPIR-V HTML doc for all instructions - - filter_list: a list containing new opnames to include - - capability_mapping: mapping from duplicated capability symbols to the - canonicalized symbol chosen for SPIRVBase.td. - - Returns: - - A string containing all the TableGen op definitions - """ - with open(path, 'r') as f: - content = f.read() - - # Split the file into chunks, each containing one op. - ops = content.split(AUTOGEN_OP_DEF_SEPARATOR) - header = ops[0] - footer = ops[-1] - ops = ops[1:-1] - - # For each existing op, extract the manually-written sections out to retain - # them when re-generating the ops. Also append the existing ops to filter - # list. - name_op_map = {} # Map from opname to its existing ODS definition - op_info_dict = {} - for op in ops: - info_dict = extract_td_op_info(op) - opname = info_dict['opname'] - name_op_map[opname] = op - op_info_dict[opname] = info_dict - filter_list.append(opname) - filter_list = sorted(list(set(filter_list))) - - op_defs = [] - - if settings.gen_cl_ops: - fix_opname = lambda src: src.replace('CL','').lower() - else: - fix_opname = lambda src: src - - for opname in filter_list: - # Find the grammar spec for this op - try: - fixed_opname = fix_opname(opname) - instruction = next( - inst for inst in instructions if inst['opname'] == fixed_opname) - - op_defs.append( - get_op_definition( - instruction, opname, docs[fixed_opname], - op_info_dict.get(opname, {'inst_category': inst_category}), - capability_mapping, settings)) - except StopIteration: - # This is an op added by us; use the existing ODS definition. 
- op_defs.append(name_op_map[opname]) - - # Substitute the old op definitions - op_defs = [header] + op_defs + [footer] - content = AUTOGEN_OP_DEF_SEPARATOR.join(op_defs) - - with open(path, 'w') as f: - f.write(content) - - -if __name__ == '__main__': - import argparse - - cli_parser = argparse.ArgumentParser( - description='Update SPIR-V dialect definitions using SPIR-V spec') - - cli_parser.add_argument( - '--base-td-path', - dest='base_td_path', - type=str, - default=None, - help='Path to SPIRVBase.td') - cli_parser.add_argument( - '--op-td-path', - dest='op_td_path', - type=str, - default=None, - help='Path to SPIRVOps.td') - - cli_parser.add_argument( - '--new-enum', - dest='new_enum', - type=str, - default=None, - help='SPIR-V enum to be added to SPIRVBase.td') - cli_parser.add_argument( - '--new-opcodes', - dest='new_opcodes', - type=str, - default=None, - nargs='*', - help='update SPIR-V opcodes in SPIRVBase.td') - cli_parser.add_argument( - '--new-inst', - dest='new_inst', - type=str, - default=None, - nargs='*', - help='SPIR-V instruction to be added to ops file') - cli_parser.add_argument( - '--inst-category', - dest='inst_category', - type=str, - default='Op', - help='SPIR-V instruction category used for choosing '\ - 'the TableGen base class to define this op') - cli_parser.add_argument( - '--gen-cl-ops', - dest='gen_cl_ops', - help='Generate OpenCL Extended Instruction Set op', - action='store_true') - cli_parser.set_defaults(gen_cl_ops=False) - cli_parser.add_argument('--gen-inst-coverage', dest='gen_inst_coverage', action='store_true') - cli_parser.set_defaults(gen_inst_coverage=False) - - args = cli_parser.parse_args() - - if args.gen_cl_ops: - ext_html_url = SPIRV_CL_EXT_HTML_SPEC_URL - ext_json_url = SPIRV_CL_EXT_JSON_SPEC_URL - else: - ext_html_url = None - ext_json_url = None - - operand_kinds, instructions = get_spirv_grammar_from_json_spec(ext_json_url) - - # Define new enum attr - if args.new_enum is not None: - assert args.base_td_path is not None - filter_list = [args.new_enum] if args.new_enum else [] - update_td_enum_attrs(args.base_td_path, operand_kinds, filter_list) - - # Define new opcode - if args.new_opcodes is not None: - assert args.base_td_path is not None - update_td_opcodes(args.base_td_path, instructions, args.new_opcodes) - - # Define new op - if args.new_inst is not None: - assert args.op_td_path is not None - docs = get_spirv_doc_from_html_spec(ext_html_url, args) - capability_mapping = get_capability_mapping(operand_kinds) - update_td_op_definitions(args.op_td_path, instructions, docs, args.new_inst, - args.inst_category, capability_mapping, args) - print('Done. Note that this script just generates a template; ', end='') - print('please read the spec and update traits, arguments, and ', end='') - print('results accordingly.') - - if args.gen_inst_coverage: - gen_instr_coverage_report(args.base_td_path, instructions) + """Extracts potentially manually specified sections in op's definition. + + Arguments: - A string containing the op's TableGen definition + + Returns: + - A dict containing potential manually specified sections + """ + # Get opname + opname = [o[8:-2] for o in re.findall("def SPIRV_\w+Op", op_def)] + assert len(opname) == 1, "more than one ops in the same section!" + opname = opname[0] + + # Get instruction category + inst_category = [o[4:] for o in re.findall("SPIRV_\w+Op", op_def.split(":", 1)[1])] + assert len(inst_category) <= 1, "more than one ops in the same section!" 
+ inst_category = inst_category[0] if len(inst_category) == 1 else "Op" + + # Get category_args + op_tmpl_params, _ = get_string_between_nested(op_def, "<", ">") + opstringname, rest = get_string_between(op_tmpl_params, '"', '"') + category_args = rest.split("[", 1)[0] + + # Get traits + traits, _ = get_string_between_nested(rest, "[", "]") + + # Get description + description, rest = get_string_between(op_def, "let description = [{\n", "}];\n") + + # Get arguments + args, rest = get_string_between(rest, " let arguments = (ins", ");\n") + + # Get results + results, rest = get_string_between(rest, " let results = (outs", ");\n") + + extras = rest.strip(" }\n") + if extras: + extras = "\n {}\n".format(extras) + + return { + # Prefix with 'Op' to make it consistent with SPIR-V spec + "opname": "Op{}".format(opname), + "inst_category": inst_category, + "category_args": category_args, + "traits": traits, + "description": description, + "arguments": args, + "results": results, + "extras": extras, + } + + +def update_td_op_definitions( + path, instructions, docs, filter_list, inst_category, capability_mapping, settings +): + """Updates SPIRVOps.td with newly generated op definition. + + Arguments: + - path: path to SPIRVOps.td + - instructions: SPIR-V JSON grammar for all instructions + - docs: SPIR-V HTML doc for all instructions + - filter_list: a list containing new opnames to include + - capability_mapping: mapping from duplicated capability symbols to the + canonicalized symbol chosen for SPIRVBase.td. + + Returns: + - A string containing all the TableGen op definitions + """ + with open(path, "r") as f: + content = f.read() + + # Split the file into chunks, each containing one op. + ops = content.split(AUTOGEN_OP_DEF_SEPARATOR) + header = ops[0] + footer = ops[-1] + ops = ops[1:-1] + + # For each existing op, extract the manually-written sections out to retain + # them when re-generating the ops. Also append the existing ops to filter + # list. + name_op_map = {} # Map from opname to its existing ODS definition + op_info_dict = {} + for op in ops: + info_dict = extract_td_op_info(op) + opname = info_dict["opname"] + name_op_map[opname] = op + op_info_dict[opname] = info_dict + filter_list.append(opname) + filter_list = sorted(list(set(filter_list))) + + op_defs = [] + + if settings.gen_cl_ops: + fix_opname = lambda src: src.replace("CL", "").lower() + else: + fix_opname = lambda src: src + + for opname in filter_list: + # Find the grammar spec for this op + try: + fixed_opname = fix_opname(opname) + instruction = next( + inst for inst in instructions if inst["opname"] == fixed_opname + ) + + op_defs.append( + get_op_definition( + instruction, + opname, + docs[fixed_opname], + op_info_dict.get(opname, {"inst_category": inst_category}), + capability_mapping, + settings, + ) + ) + except StopIteration: + # This is an op added by us; use the existing ODS definition. 
+ op_defs.append(name_op_map[opname]) + + # Substitute the old op definitions + op_defs = [header] + op_defs + [footer] + content = AUTOGEN_OP_DEF_SEPARATOR.join(op_defs) + + with open(path, "w") as f: + f.write(content) + + +if __name__ == "__main__": + import argparse + + cli_parser = argparse.ArgumentParser( + description="Update SPIR-V dialect definitions using SPIR-V spec" + ) + + cli_parser.add_argument( + "--base-td-path", + dest="base_td_path", + type=str, + default=None, + help="Path to SPIRVBase.td", + ) + cli_parser.add_argument( + "--op-td-path", + dest="op_td_path", + type=str, + default=None, + help="Path to SPIRVOps.td", + ) + + cli_parser.add_argument( + "--new-enum", + dest="new_enum", + type=str, + default=None, + help="SPIR-V enum to be added to SPIRVBase.td", + ) + cli_parser.add_argument( + "--new-opcodes", + dest="new_opcodes", + type=str, + default=None, + nargs="*", + help="update SPIR-V opcodes in SPIRVBase.td", + ) + cli_parser.add_argument( + "--new-inst", + dest="new_inst", + type=str, + default=None, + nargs="*", + help="SPIR-V instruction to be added to ops file", + ) + cli_parser.add_argument( + "--inst-category", + dest="inst_category", + type=str, + default="Op", + help="SPIR-V instruction category used for choosing " + "the TableGen base class to define this op", + ) + cli_parser.add_argument( + "--gen-cl-ops", + dest="gen_cl_ops", + help="Generate OpenCL Extended Instruction Set op", + action="store_true", + ) + cli_parser.set_defaults(gen_cl_ops=False) + cli_parser.add_argument( + "--gen-inst-coverage", dest="gen_inst_coverage", action="store_true" + ) + cli_parser.set_defaults(gen_inst_coverage=False) + + args = cli_parser.parse_args() + + if args.gen_cl_ops: + ext_html_url = SPIRV_CL_EXT_HTML_SPEC_URL + ext_json_url = SPIRV_CL_EXT_JSON_SPEC_URL + else: + ext_html_url = None + ext_json_url = None + + operand_kinds, instructions = get_spirv_grammar_from_json_spec(ext_json_url) + + # Define new enum attr + if args.new_enum is not None: + assert args.base_td_path is not None + filter_list = [args.new_enum] if args.new_enum else [] + update_td_enum_attrs(args.base_td_path, operand_kinds, filter_list) + + # Define new opcode + if args.new_opcodes is not None: + assert args.base_td_path is not None + update_td_opcodes(args.base_td_path, instructions, args.new_opcodes) + + # Define new op + if args.new_inst is not None: + assert args.op_td_path is not None + docs = get_spirv_doc_from_html_spec(ext_html_url, args) + capability_mapping = get_capability_mapping(operand_kinds) + update_td_op_definitions( + args.op_td_path, + instructions, + docs, + args.new_inst, + args.inst_category, + capability_mapping, + args, + ) + print("Done. Note that this script just generates a template; ", end="") + print("please read the spec and update traits, arguments, and ", end="") + print("results accordingly.") + + if args.gen_inst_coverage: + gen_instr_coverage_report(args.base_td_path, instructions) -- 2.7.4
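
As a quick sanity check of the string helpers reformatted above, the following self-contained sketch shows how they behave. It is not part of the patch: the helper bodies are copied from the new code so the snippet runs standalone, and the sample inputs are invented purely for illustration.

import re


def snake_casify(name):
    """Turns the given name to follow snake_case convention."""
    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()


def get_string_between_nested(base, start, end):
    """Extracts a substring between nested start/end delimiters."""
    split = base.split(start, 1)
    if len(split) == 2:
        # Walk the remainder, tracking how many opening delimiters are
        # still unmatched, so inner '<...>' pairs stay intact.
        rest = split[1]
        unmatched_start = 1
        index = 0
        while unmatched_start > 0 and index < len(rest):
            if rest[index:].startswith(end):
                unmatched_start -= 1
                if unmatched_start == 0:
                    break
                index += len(end)
            elif rest[index:].startswith(start):
                unmatched_start += 1
                index += len(start)
            else:
                index += 1
        assert index < len(rest), (
            'cannot find end "{end}" while extracting substring '
            'starting with "{start}"'.format(start=start, end=end)
        )
        return rest[:index], rest[index + len(end):]
    return "", split[0]


# 'MemoryAccess' becomes 'memory_access', matching how operand names are
# turned into ODS argument names.
print(snake_casify("MemoryAccess"))

# For a made-up TableGen-like header, the nested-delimiter helper returns the
# template parameters and the remainder, keeping the inner '<...>' balanced.
params, remainder = get_string_between_nested(
    'SPIRV_SomeOp<"some.op", [SomeTrait<42>]> {', "<", ">"
)
print(params)     # '"some.op", [SomeTrait<42>]'
print(remainder)  # ' {'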