[mlir] provide Python bindings for the Transform dialect
authorAlex Zinenko <zinenko@google.com>
Mon, 30 May 2022 13:14:02 +0000 (15:14 +0200)
committerAlex Zinenko <zinenko@google.com>
Mon, 30 May 2022 15:37:52 +0000 (17:37 +0200)
Python bindings for extensions of the Transform dialect are defined in separate
Python source files that can be imported on-demand, i.e., that are not imported
with the "main" transform dialect. This requires a minor addition to the
ODS-based bindings generator. This approach is consistent with the current
model for downstream projects that are expected to bundle MLIR Python bindings:
such projects can include their custom extensions into the bundle similarly to
how they include their dialects.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D126208

12 files changed:
mlir/cmake/modules/AddMLIRPython.cmake
mlir/python/CMakeLists.txt
mlir/python/mlir/dialects/LinalgStructuredTransformOps.td [new file with mode: 0644]
mlir/python/mlir/dialects/TransformOps.td [new file with mode: 0644]
mlir/python/mlir/dialects/_structured_transform_ops_ext.py [new file with mode: 0644]
mlir/python/mlir/dialects/_transform_ops_ext.py [new file with mode: 0644]
mlir/python/mlir/dialects/transform/__init__.py [new file with mode: 0644]
mlir/python/mlir/dialects/transform/structured.py [new file with mode: 0644]
mlir/test/python/dialects/transform.py [new file with mode: 0644]
mlir/test/python/dialects/transform_structured_ext.py [new file with mode: 0644]
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel

index 8750b10..ff9edd1 100644 (file)
@@ -355,6 +355,61 @@ function(declare_mlir_dialect_python_bindings)
   endif()
 endfunction()
 
+# Function: declare_mlir_dialect_extension_python_bindings
+# Helper to generate source groups for dialect extensions, including both
+# static source files and a TD_FILE to generate wrappers.
+#
+# This will generate a source group named ${ADD_TO_PARENT}.${EXTENSION_NAME}.
+#
+# Arguments:
+#   ROOT_DIR: Same as for declare_mlir_python_sources().
+#   ADD_TO_PARENT: Same as for declare_mlir_python_sources(). Unique names
+#     for the subordinate source groups are derived from this.
+#   TD_FILE: Tablegen file to generate source for (relative to ROOT_DIR).
+#   DIALECT_NAME: Python name of the dialect.
+#   EXTENSION_NAME: Python name of the dialect extension.
+#   SOURCES: Same as declare_mlir_python_sources().
+#   SOURCES_GLOB: Same as declare_mlir_python_sources().
+#   DEPENDS: Additional dependency targets.
+function(declare_mlir_dialect_extension_python_bindings)
+  cmake_parse_arguments(ARG
+    ""
+    "ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME;EXTENSION_NAME"
+    "SOURCES;SOURCES_GLOB;DEPENDS"
+    ${ARGN})
+  # Source files.
+  set(_extension_target "${ARG_ADD_TO_PARENT}.${ARG_EXTENSION_NAME}")
+  declare_mlir_python_sources(${_extension_target}
+    ROOT_DIR "${ARG_ROOT_DIR}"
+    ADD_TO_PARENT "${ARG_ADD_TO_PARENT}"
+    SOURCES "${ARG_SOURCES}"
+    SOURCES_GLOB "${ARG_SOURCES_GLOB}"
+  )
+
+  # Tablegen
+  if(ARG_TD_FILE)
+    set(tblgen_target "${ARG_ADD_TO_PARENT}.${ARG_EXTENSION_NAME}.tablegen")
+    set(td_file "${ARG_ROOT_DIR}/${ARG_TD_FILE}")
+    get_filename_component(relative_td_directory "${ARG_TD_FILE}" DIRECTORY)
+    file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${relative_td_directory}")
+    set(output_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_ops_gen.py")
+    set(LLVM_TARGET_DEFINITIONS ${td_file})
+    mlir_tablegen("${output_filename}" -gen-python-op-bindings
+                  -bind-dialect=${ARG_DIALECT_NAME}
+                  -dialect-extension=${ARG_EXTENSION_NAME})
+    add_public_tablegen_target(${tblgen_target})
+    if(ARG_DEPENDS)
+      add_dependencies(${tblgen_target} ${ARG_DEPENDS})
+    endif()
+
+    declare_mlir_python_sources("${_extension_target}.ops_gen"
+      ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
+      ADD_TO_PARENT "${_extension_target}"
+      SOURCES "${output_filename}"
+    )
+  endif()
+endfunction()
+
 # Function: mlir_python_setup_extension_rpath
 # Sets RPATH properties on a target, assuming that it is being output to
 # an _mlir_libs directory with all other libraries. For static linkage,
index d280bf1..17048e8 100644 (file)
@@ -119,6 +119,25 @@ declare_mlir_dialect_python_bindings(
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/TransformOps.td
+  SOURCES
+    dialects/_transform_ops_ext.py
+    dialects/transform/__init__.py
+  DIALECT_NAME transform)
+
+declare_mlir_dialect_extension_python_bindings(
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/LinalgStructuredTransformOps.td
+  SOURCES
+    dialects/_structured_transform_ops_ext.py
+    dialects/transform/structured.py
+  DIALECT_NAME transform
+  EXTENSION_NAME structured_transform)
+
+declare_mlir_dialect_python_bindings(
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/MathOps.td
   SOURCES dialects/math.py
   DIALECT_NAME math)
diff --git a/mlir/python/mlir/dialects/LinalgStructuredTransformOps.td b/mlir/python/mlir/dialects/LinalgStructuredTransformOps.td
new file mode 100644 (file)
index 0000000..a9a53fe
--- /dev/null
@@ -0,0 +1,21 @@
+//===-- LinalgStructuredTransformOps.td --------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Entry point of the Python bindings generator for the structured transform ops
+// provided by Linalg (and other dialects).
+//
+//===----------------------------------------------------------------------===//
+
+
+#ifndef PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS
+#define PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS
+
+include "mlir/Bindings/Python/Attributes.td"
+include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td"
+
+#endif // PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS
diff --git a/mlir/python/mlir/dialects/TransformOps.td b/mlir/python/mlir/dialects/TransformOps.td
new file mode 100644 (file)
index 0000000..7f0d80e
--- /dev/null
@@ -0,0 +1,15 @@
+//===-- TransformOps.td - Transform ops bind entry point ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_TRANSFORM_OPS
+#define PYTHON_BINDINGS_TRANSFORM_OPS
+
+include "mlir/Bindings/Python/Attributes.td"
+include "mlir/Dialect/Transform/IR/TransformOps.td"
+
+#endif // PYTHON_BINDINGS_TRANSFORM_OPS
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
new file mode 100644 (file)
index 0000000..70e39be
--- /dev/null
@@ -0,0 +1,178 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+try:
+  from ..ir import *
+  from ._ods_common import get_op_result_or_value as _get_op_result_or_value
+  from ..dialects import pdl
+except ImportError as e:
+  raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import List, Optional, Sequence, Union
+
+IntOrAttrList = Sequence[Union[IntegerAttr, int]]
+OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
+
+
+def _get_array_attr(
+    values: Optional[Union[ArrayAttr, Sequence[Attribute]]]) -> ArrayAttr:
+  """Creates an array attribute from its operand."""
+  if values is None:
+    return ArrayAttr.get([])
+  if isinstance(values, ArrayAttr):
+    return values
+
+  return ArrayAttr.get(values)
+
+
+def _get_int_array_attr(
+    values: Optional[Union[ArrayAttr, Sequence[Union[IntegerAttr, int]]]]
+) -> ArrayAttr:
+  """Creates an integer array attribute from its operand.
+
+  If the operand is already an array attribute, forwards it. Otherwise treats
+  the operand as a list of attributes or integers, possibly intersperced, to
+  create a new array attribute containing integer attributes. Expects the
+  thread-local MLIR context to have been set by the context manager.
+  """
+  if values is None:
+    return ArrayAttr.get([])
+  if isinstance(values, ArrayAttr):
+    return values
+
+  attributes = []
+  for value in values:
+    if isinstance(value, IntegerAttr):
+      attributes.append(value)
+    else:
+      attributes.append(IntegerAttr.get(IntegerType.get_signless(64), value))
+  return ArrayAttr.get(attributes)
+
+
+def _get_int_int_array_attr(
+    values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr,
+                                                     IntOrAttrList]]]]
+) -> ArrayAttr:
+  """Creates an array attribute containing array attributes of integers.
+
+  If the operand is already an array attribute, forwards it. Otherwise treats
+  the operand as a list of attributes or integers, potentially interpserced, to
+  create a new array-of-array attribute. Expects the thread-local MLIR context
+  to have been set by the context manager.
+  """
+  if values is None:
+    return ArrayAttr.get([])
+  if isinstance(values, ArrayAttr):
+    return values
+
+  return ArrayAttr.get([_get_int_array_attr(value) for value in values])
+
+
+class InterchangeOp:
+  """Specialization for InterchangeOp class."""
+
+  def __init__(self,
+               target: Union[Operation, Value],
+               *,
+               iterator_interchange: OptionalIntList = None,
+               loc=None,
+               ip=None):
+    pdl_operation_type = pdl.OperationType.get()
+    interchange_attr = _get_int_array_attr(iterator_interchange)
+    super().__init__(
+        pdl_operation_type,
+        _get_op_result_or_value(target),
+        iterator_interchange=interchange_attr,
+        loc=loc,
+        ip=ip)
+
+
+class PadOp:
+  """Specialization for PadOp class."""
+
+  def __init__(self,
+               target: Union[Operation, Value],
+               *,
+               padding_values: Optional[Union[ArrayAttr,
+                                              Sequence[Attribute]]] = None,
+               padding_dimensions: OptionalIntList = None,
+               pack_paddings: OptionalIntList = None,
+               hoist_paddings: OptionalIntList = None,
+               transpose_paddings: Optional[Union[ArrayAttr, Sequence[Union[
+                   ArrayAttr, IntOrAttrList]]]] = None,
+               loc=None,
+               ip=None):
+    pdl_operation_type = pdl.OperationType.get()
+    padding_values_attr = _get_array_attr(padding_values)
+    padding_dimensions_attr = _get_int_array_attr(padding_dimensions)
+    pack_paddings_attr = _get_int_array_attr(pack_paddings)
+    hoist_paddings_attr = _get_int_array_attr(hoist_paddings)
+    transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings)
+    super().__init__(
+        pdl_operation_type,
+        _get_op_result_or_value(target),
+        padding_values=padding_values_attr,
+        padding_dimensions=padding_dimensions_attr,
+        pack_paddings=pack_paddings_attr,
+        hoist_paddings=hoist_paddings_attr,
+        transpose_paddings=transpose_paddings_attr,
+        loc=loc,
+        ip=ip)
+
+
+class ScalarizeOp:
+  """Specialization for ScalarizeOp class."""
+
+  def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+    pdl_operation_type = pdl.OperationType.get()
+    super().__init__(
+        pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip)
+
+
+class TileOp:
+  """Specialization for TileOp class."""
+
+  def __init__(self,
+               target: Union[Operation, Value],
+               *,
+               sizes: OptionalIntList = None,
+               interchange: OptionalIntList = None,
+               loc=None,
+               ip=None):
+    pdl_operation_type = pdl.OperationType.get()
+    sizes_attr = _get_int_array_attr(sizes)
+    num_loops = sum(
+        v if v == 0 else 1 for v in self.__extract_values(sizes_attr))
+    super().__init__(
+        pdl_operation_type, [pdl_operation_type] * num_loops,
+        _get_op_result_or_value(target),
+        sizes=sizes_attr,
+        interchange=_get_int_array_attr(interchange) if interchange else None,
+        loc=loc,
+        ip=ip)
+
+  def __extract_values(self, attr: Optional[ArrayAttr]) -> List[int]:
+    if not attr:
+      return []
+    return [IntegerAttr(element).value for element in attr]
+
+
+class VectorizeOp:
+  """Specialization for VectorizeOp class."""
+
+  def __init__(self,
+               target: Union[Operation, Value],
+               *,
+               vectorize_padding: Union[bool, BoolAttr] = False,
+               loc=None,
+               ip=None):
+    pdl_operation_type = pdl.OperationType.get()
+    if isinstance(vectorize_padding, bool):
+      vectorize_padding = BoolAttr.get(vectorize_padding)
+    super().__init__(
+        pdl_operation_type,
+        _get_op_result_or_value(target),
+        vectorize_padding=vectorize_padding,
+        loc=loc,
+        ip=ip)
diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py
new file mode 100644 (file)
index 0000000..138195d
--- /dev/null
@@ -0,0 +1,106 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+try:
+  from ..ir import *
+  from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+  from ..dialects import pdl
+except ImportError as e:
+  raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, overload, Sequence, Union
+
+
+def _get_symbol_ref_attr(value: Union[Attribute, str]):
+  if isinstance(value, Attribute):
+    return value
+  return FlatSymbolRefAttr.get(value)
+
+
+class GetClosestIsolatedParentOp:
+
+  def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+    super().__init__(
+        pdl.OperationType.get(),
+        _get_op_result_or_value(target),
+        loc=loc,
+        ip=ip)
+
+
+class PDLMatchOp:
+
+  def __init__(self,
+               target: Union[Operation, Value],
+               pattern_name: Union[Attribute, str],
+               *,
+               loc=None,
+               ip=None):
+    super().__init__(
+        pdl.OperationType.get(),
+        _get_op_result_or_value(target),
+        _get_symbol_ref_attr(pattern_name),
+        loc=loc,
+        ip=ip)
+
+
+class SequenceOp:
+
+  @overload
+  def __init__(self, resultsOrRoot: Sequence[Type],
+               optionalRoot: Optional[Union[Operation, Value]]):
+    ...
+
+  @overload
+  def __init__(self, resultsOrRoot: Optional[Union[Operation, Value]],
+               optionalRoot: NoneType):
+    ...
+
+  def __init__(self, resultsOrRoot=None, optionalRoot=None):
+    results = resultsOrRoot if isinstance(resultsOrRoot, Sequence) else []
+    root = (
+        resultsOrRoot
+        if not isinstance(resultsOrRoot, Sequence) else optionalRoot)
+    root = _get_op_result_or_value(root) if root else None
+    super().__init__(results_=results, root=root)
+    self.regions[0].blocks.append(pdl.OperationType.get())
+
+  @property
+  def body(self) -> Block:
+    return self.regions[0].blocks[0]
+
+  @property
+  def bodyTarget(self) -> Value:
+    return self.body.arguments[0]
+
+
+class WithPDLPatternsOp:
+
+  def __init__(self,
+               target: Optional[Union[Operation, Value]] = None,
+               *,
+               loc=None,
+               ip=None):
+    super().__init__(
+        root=_get_op_result_or_value(target) if target else None,
+        loc=loc,
+        ip=ip)
+    self.regions[0].blocks.append(pdl.OperationType.get())
+
+  @property
+  def body(self) -> Block:
+    return self.regions[0].blocks[0]
+
+  @property
+  def bodyTarget(self) -> Value:
+    return self.body.arguments[0]
+
+
+class YieldOp:
+
+  def __init__(self,
+               operands: Union[Operation, Sequence[Value]] = [],
+               *,
+               loc=None,
+               ip=None):
+    super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
new file mode 100644 (file)
index 0000000..ab4fa56
--- /dev/null
@@ -0,0 +1,5 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .._transform_ops_gen import *
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
new file mode 100644 (file)
index 0000000..b8ee48c
--- /dev/null
@@ -0,0 +1,5 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .._structured_transform_ops_gen import *
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
new file mode 100644 (file)
index 0000000..4722017
--- /dev/null
@@ -0,0 +1,84 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import transform
+from mlir.dialects import pdl
+
+
+def run(f):
+  with Context(), Location.unknown():
+    module = Module.create()
+    with InsertionPoint(module.body):
+      print("\nTEST:", f.__name__)
+      f()
+    print(module)
+  return f
+
+
+@run
+def testSequenceOp():
+  sequence = transform.SequenceOp([pdl.OperationType.get()])
+  with InsertionPoint(sequence.body):
+    transform.YieldOp([sequence.bodyTarget])
+  # CHECK-LABEL: TEST: testSequenceOp
+  # CHECK: = transform.sequence {
+  # CHECK: ^{{.*}}(%[[ARG0:.+]]: !pdl.operation):
+  # CHECK:   yield %[[ARG0]] : !pdl.operation
+  # CHECK: } : !pdl.operation
+
+
+@run
+def testNestedSequenceOp():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    nested = transform.SequenceOp(sequence.bodyTarget)
+    with InsertionPoint(nested.body):
+      doubly_nested = transform.SequenceOp([pdl.OperationType.get()],
+                                           nested.bodyTarget)
+      with InsertionPoint(doubly_nested.body):
+        transform.YieldOp([doubly_nested.bodyTarget])
+      transform.YieldOp()
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testNestedSequenceOp
+  # CHECK: transform.sequence {
+  # CHECK: ^{{.*}}(%[[ARG0:.+]]: !pdl.operation):
+  # CHECK:   sequence %[[ARG0]] {
+  # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
+  # CHECK:     = sequence %[[ARG1]] {
+  # CHECK:     ^{{.*}}(%[[ARG2:.+]]: !pdl.operation):
+  # CHECK:       yield %[[ARG2]] : !pdl.operation
+  # CHECK:     } : !pdl.operation
+  # CHECK:   }
+  # CHECK: }
+
+
+@run
+def testTransformPDLOps():
+  withPdl = transform.WithPDLPatternsOp()
+  with InsertionPoint(withPdl.body):
+    sequence = transform.SequenceOp([pdl.OperationType.get()],
+                                    withPdl.bodyTarget)
+    with InsertionPoint(sequence.body):
+      match = transform.PDLMatchOp(sequence.bodyTarget, "pdl_matcher")
+      transform.YieldOp(match)
+  # CHECK-LABEL: TEST: testTransformPDLOps
+  # CHECK: transform.with_pdl_patterns {
+  # CHECK: ^{{.*}}(%[[ARG0:.+]]: !pdl.operation):
+  # CHECK:   = sequence %[[ARG0]] {
+  # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
+  # CHECK:     %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
+  # CHECK:     yield %[[RES]] : !pdl.operation
+  # CHECK:   } : !pdl.operation
+  # CHECK: }
+
+
+@run
+def testGetClosestIsolatedParentOp():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    transform.GetClosestIsolatedParentOp(sequence.bodyTarget)
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testGetClosestIsolatedParentOp
+  # CHECK: transform.sequence
+  # CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
+  # CHECK:   = get_closest_isolated_parent %[[ARG1]]
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
new file mode 100644 (file)
index 0000000..463dec1
--- /dev/null
@@ -0,0 +1,118 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import transform
+from mlir.dialects import pdl
+from mlir.dialects.transform import structured
+
+
+def run(f):
+  with Context(), Location.unknown():
+    module = Module.create()
+    with InsertionPoint(module.body):
+      print("\nTEST:", f.__name__)
+      f()
+    print(module)
+  return f
+
+
+@run
+def testInterchange():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    structured.InterchangeOp(
+        sequence.bodyTarget,
+        iterator_interchange=[
+            IntegerAttr.get(IntegerType.get_signless(64), 1), 0
+        ])
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testInterchange
+  # CHECK: transform.sequence
+  # CHECK: transform.structured.interchange
+  # CHECK: iterator_interchange = [1, 0]
+
+
+@run
+def testPad():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    structured.PadOp(
+        sequence.bodyTarget,
+        padding_values=[FloatAttr.get_f32(42.0)],
+        padding_dimensions=[1],
+        transpose_paddings=[[1, 0]])
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testPad
+  # CHECK: transform.sequence
+  # CHECK: transform.structured.pad
+  # CHECK-DAG: padding_values = [4.200000e+01 : f32]
+  # CHECK-DAG: padding_dimensions = [1]
+  # CHECK-DAG: transpose_paddings = {{\[}}[1, 0]]
+  # CHECK-DAG: hoist_paddings = []
+  # CHECK-DAG: pack_paddings = []
+
+
+@run
+def testScalarize():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    structured.ScalarizeOp(sequence.bodyTarget)
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testScalarize
+  # CHECK: transform.structured.scalarize
+
+
+@run
+def testTileCompact():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1])
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testTileCompact
+  # CHECK: transform.sequence
+  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
+  # CHECK-DAG: interchange = [0, 1]
+  # CHECK-DAG: sizes = [4, 8]
+
+
+@run
+def testTileAttributes():
+  sequence = transform.SequenceOp()
+  attr = ArrayAttr.get(
+      [IntegerAttr.get(IntegerType.get_signless(64), x) for x in [4, 8]])
+  ichange = ArrayAttr.get(
+      [IntegerAttr.get(IntegerType.get_signless(64), x) for x in [0, 1]])
+  with InsertionPoint(sequence.body):
+    structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange)
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testTileAttributes
+  # CHECK: transform.sequence
+  # CHECK: structured.tile
+  # CHECK-DAG: interchange = [0, 1]
+  # CHECK-DAG: sizes = [4, 8]
+
+
+@run
+def testTileZero():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    structured.TileOp(
+        sequence.bodyTarget, sizes=[4, 0, 2, 0], interchange=[0, 1, 2, 3])
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testTileZero
+  # CHECK: transform.sequence
+  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
+  # CHECK-DAG: interchange = [0, 1, 2, 3]
+  # CHECK-DAG: sizes = [4, 0, 2, 0]
+
+
+@run
+def testVectorize():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True)
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testVectorize
+  # CHECK: transform.sequence
+  # CHECK: = transform.structured.vectorize
+  # CHECK: vectorize_padding = true
index 44c60c1..d6b3aa6 100644 (file)
@@ -50,6 +50,10 @@ class _Dialect(_ods_ir.Dialect):
 
 )Py";
 
+constexpr const char *dialectExtensionTemplate = R"Py(
+from ._{0}_ops_gen import _Dialect
+)Py";
+
 /// Template for operation class:
 ///   {0} is the Python class name;
 ///   {1} is the operation name.
@@ -270,6 +274,10 @@ static llvm::cl::opt<std::string>
                   llvm::cl::desc("The dialect to run the generator for"),
                   llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
 
+static llvm::cl::opt<std::string> clDialectExtensionName(
+    "dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
+    llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
+
 using AttributeClasses = DenseMap<StringRef, StringRef>;
 
 /// Checks whether `str` is a Python keyword.
@@ -1014,8 +1022,14 @@ static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
   AttributeClasses attributeClasses;
   constructAttributeMapping(records, attributeClasses);
 
-  os << llvm::formatv(fileHeader, clDialectName.getValue());
-  os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
+  bool isExtension = !clDialectExtensionName.empty();
+  os << llvm::formatv(fileHeader, isExtension
+                                      ? clDialectExtensionName.getValue()
+                                      : clDialectName.getValue());
+  if (isExtension)
+    os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue());
+  else
+    os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
 
   for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
     Operator op(rec);
index c3cba2f..c94bc5d 100644 (file)
@@ -826,6 +826,74 @@ filegroup(
 )
 
 ##---------------------------------------------------------------------------##
+# Transform dialect and extensions.
+##---------------------------------------------------------------------------##
+
+td_library(
+    name = "TransformOpsPyTdFiles",
+    srcs = [
+        "//mlir:include/mlir/Bindings/Python/Attributes.td",
+    ],
+    deps = [
+        "//mlir:OpBaseTdFiles",
+        "//mlir:TransformDialectTdFiles",
+    ],
+)
+
+gentbl_filegroup(
+    name = "TransformOpsPyGen",
+    tbl_outs = [
+        (
+            [
+                "-gen-python-op-bindings",
+                "-bind-dialect=transform",
+            ],
+            "mlir/dialects/_transform_ops_gen.py",
+        ),
+    ],
+    tblgen = "//mlir:mlir-tblgen",
+    td_file = "mlir/dialects/TransformOps.td",
+    deps = [
+        ":TransformOpsPyTdFiles",
+    ],
+)
+
+gentbl_filegroup(
+    name = "StructuredTransformOpsPyGen",
+    tbl_outs = [
+        (
+            [
+                "-gen-python-op-bindings",
+                "-bind-dialect=transform",
+                "-dialect-extension=structured_transform",
+            ],
+            "mlir/dialects/_structured_transform_ops_gen.py",
+        ),
+    ],
+    tblgen = "//mlir:mlir-tblgen",
+    td_file = "mlir/dialects/LinalgStructuredTransformOps.td",
+    deps = [
+        ":TransformOpsPyTdFiles",
+        "//mlir:LinalgTransformOpsTdFiles",
+    ],
+)
+
+filegroup(
+    name = "TransformOpsPyFiles",
+    srcs = [
+        "mlir/dialects/_structured_transform_ops_ext.py",
+        "mlir/dialects/_transform_ops_ext.py",
+        ":StructuredTransformOpsPyGen",
+        ":TransformOpsPyGen",
+    ],
+)
+
+filegroup(
+    name = "TransformOpsPackagePyFiles",
+    srcs = glob(["mlir/dialects/transform/*.py"]),
+)
+
+##---------------------------------------------------------------------------##
 # Vector dialect.
 ##---------------------------------------------------------------------------##