add_subdirectory(SPIRV)
add_subdirectory(Tensor)
add_subdirectory(Tosa)
+add_subdirectory(Transform)
add_subdirectory(Vector)
add_subdirectory(X86Vector)
--- /dev/null
+add_subdirectory(IR)
--- /dev/null
+# The dialect does not have its own ops, so just generate the dialect files.
+set(LLVM_TARGET_DEFINITIONS TransformDialect.td)
+mlir_tablegen(TransformDialect.h.inc -gen-dialect-decls -dialect=transform)
+mlir_tablegen(TransformDialect.cpp.inc -gen-dialect-defs -dialect=transform)
+add_public_tablegen_target(MLIRTransformDialectIncGen)
+add_dependencies(mlir-headers MLIRTransformDialectIncGen)
+
+add_mlir_interface(TransformInterfaces)
--- /dev/null
+//===- TransformDialect.h - Transform Dialect Definition --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
+#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/Support/LLVM.h"
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h.inc"
+
+namespace mlir {
+namespace transform {
+
+#ifndef NDEBUG
+namespace detail {
+/// Asserts that the operations provided as template arguments implement the
+/// TransformOpInterface. This must be a dynamic assertion since interface
+/// implementations may be registered at runtime.
+template <typename OpTy>
+static inline void checkImplementsTransformInterface(MLIRContext *context) {
+ // Since the operation is being inserted into the Transform dialect and the
+ // dialect does not implement the interface fallback, only check for the op
+ // itself having the interface implementation.
+ RegisteredOperationName opName =
+ *RegisteredOperationName::lookup(OpTy::getOperationName(), context);
+ assert(opName.hasInterface<TransformOpInterface>() &&
+ "ops injected into the transform dialect must implement "
+ "TransformOpInterface");
+}
+} // namespace detail
+#endif // NDEBUG
+
+/// Base class for extensions of the Transform dialect that supports injecting
+/// operations into the Transform dialect at load time. Concrete extensions are
+/// expected to derive this class and register operations in the constructor.
+/// They can be registered with the DialectRegistry and automatically applied
+/// to the Transform dialect when it is loaded.
+template <typename DerivedTy, typename... ExtraDialects>
+class TransformDialectExtension
+ : public DialectExtension<DerivedTy, TransformDialect, ExtraDialects...> {
+ using Initializer = std::function<void(TransformDialect *)>;
+ using DialectLoader = std::function<void(MLIRContext *)>;
+
+public:
+ /// Extension application hook. Actually loads the dependent dialects and
+ /// registers the additional operations. Not expected to be called directly.
+ void apply(MLIRContext *context, TransformDialect *transformDialect,
+ ExtraDialects *...) const final {
+ for (const DialectLoader &loader : dialectLoaders)
+ loader(context);
+ for (const Initializer &init : opInitializers)
+ init(transformDialect);
+ }
+
+protected:
+ /// Injects the operations into the Transform dialect. The operations must
+ /// implement the TransformOpInterface and the implementation must be already
+ /// available when the operation is injected.
+ template <typename... OpTys>
+ void registerTransformOps() {
+ opInitializers.push_back([](TransformDialect *transformDialect) {
+ transformDialect->addOperations<OpTys...>();
+
+#ifndef NDEBUG
+ std::initializer_list<int>{
+ (detail::checkImplementsTransformInterface<OpTys>(
+ transformDialect->getContext()),
+ 0)...};
+#endif // NDEBUG
+ });
+ }
+
+ /// Declares that this Transform dialect extension depends on the dialect
+ /// provided as template parameter. When the Transform dialect is loaded,
+ /// dependent dialects will be loaded as well. This is intended for dialects
+ /// that contain attributes and types used in creation and canonicalization of
+ /// the injected operations.
+ template <typename DialectTy>
+ void declareDependentDialect() {
+ dialectLoaders.push_back(
+ [](MLIRContext *context) { context->loadDialect<DialectTy>(); });
+ }
+
+private:
+ SmallVector<Initializer> opInitializers;
+ SmallVector<DialectLoader> dialectLoaders;
+};
+
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
--- /dev/null
+//===- TransformDialect.td - Transform dialect definition --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT
+#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT
+
+include "mlir/IR/OpBase.td"
+
+def Transform_Dialect : Dialect {
+ let summary = "Fine-grain transformation control dialect";
+ let description = [{
+ ## Disclaimer
+
+ ** Proceed with care: not ready for general use. **
+
+ This dialect is evolving rapidly and may change on a very short notice. To
+ decrease the maintenance burden and churn, only a few in-tree use cases are
+ currently supported in the main tree:
+
+ - high-level transformations on "structured ops" (i.e. ops that operate on
+ chunks of data in a way that can be decomposed into operations on
+ smaller chunks of data and control flow) in Linalg, Tensor and Vector
+ dialects.
+
+ *Please post a description of the intended use case on the MLIR forum and
+ wait for confirmation.*
+
+ ## Overview
+
+ This dialect provides operations that can be used to control transformation
+ of the IR using a different portion of the IR. It refers to the IR being
+ transformed as payload IR, and to the IR guiding the transformation as
+ transform IR.
+
+ The main use case for this dialect is orchestrating fine-grain
+ transformations on individual operations or sets thereof. For example, it
+ may involve finding loop-like operations with specific properties (e.g.,
+ large size) in the payload IR, applying loop tiling to those and only those
+ operations, and then applying loop unrolling to the inner loops produced
+ by the previous transformations. As such, it is not intended as a
+ replacement for the pass infrastructure, nor for the pattern rewriting
+ infrastructure. In the most common case, the transform IR will be processed
+ and applied to the payload IR by a pass. Transformations expressed by the
+ transform dialect may be implemented using the pattern infrastructure or any
+ other relevant MLIR component.
+
+ The following IR gives a rough idea of what the operations in this dialect
+ may look like:
+
+ ```mlir
+ %0 = transform.loop.find { size > 42 }
+ %1:2 = transform.loop.tile { tile_sizes = [2,3,4] }
+ transform.loop.unroll %1#1
+ ```
+
+ The values defined by operations in this dialect correspond to (groups of)
+ operations in the payload IR. In the example above, `%0` corresponds to the
+ set of loops found in the payload IR that satisfy the condition, and `%1`
+ correspond to groups of outer and inner loops, respectively, produced by
+ the tiling transformation.
+
+ This dialect is designed to be extensible, that is, clients of this dialect
+ are allowed to inject additional operations into this dialect using the
+ `TransformDialectExtension` mechanism. This allows the dialect to avoid a
+ dependency on the implementation of the transformation as well as to avoid
+ introducing dialect-specific transform dialects. In the example above,
+ the operations may have been injected by a notional `loop` dialect rather
+ than defined in this dialect, hence the common prefix.
+
+ It is recommended to prefix injected operations with one or several
+ dot-separated words that indicate which extension adds them. For
+ dialect-specific transformations, the prefix is naturally the name of the
+ dialect, e.g., `transform.affine.reschedule`. For dialect-agnostic
+ transformations (typically implemented using interfaces), the prefix may
+ be derived from the interface name or from a common concept, e.g.,
+ `transform.loop.tile` may apply to any loop-like operation that implements
+ `TileableOpInterface`. The C++ classes for the dialect extension should
+ include the prefix in their name, e.g., `AffineTransformDialectExtension` or
+ `LoopTransformDialectExtension` in the cases above. Unprefixed operation
+ names are reserved for ops defined directly in the Transform dialect.
+
+ ## Intended Use and Integrations
+
+ The transformation control infrastructure provided by this dialect is
+ positioned roughly between rewrite patterns and passes. A transformation
+ that is executed by a transform operation is likely to be sufficiently
+ complex to require at least a set of patterns to be implemented. It is also
+ expected to be more focused than a pass: a pass typically applies identical
+ transformations everywhere in the IR, a transform dialect-controlled
+ transformation would apply to a small subset of operations selected, e.g.,
+ by a pattern-matching operation or generated by a previous transformation.
+ It is discouraged, although technically possible, to run a pass pipeline as
+ part of the transform op implementation.
+
+ One of the main scenarios for using this dialect is fine-grain chaining of
+ transformations. For example, a loop-like operation may see its iteration
+ domain split into two parts, implemented as separate loops (transformation
+ known as index-set splitting), each of which is then transformed differently
+ (e.g., the first loop is tiled and the second unrolled) with the necessary
+ enabling and cleanup patterns around the main transformation:
+
+ ```mlir
+ // <generate %loop, e.g., by pattern-matching>
+ // ...
+ %parts:2 = transform.loop.split %loop { upper_bound_divisible_by = 8 }
+ transform.loop.tile %parts#0 { tile_sizes = [8] }
+ transform.loop.unroll %parts#1 { full }
+ ```
+
+ This composition would have been difficult to implement as separate passes
+ since the hypothetical "tiling" and "unrolling" pass would need to somehow
+ differentiate between the parts of the loop produced by the previous pass
+ (both are the same operation, and it is likely undesirable to pollute the
+ operation with pass-specific information). Implementing passes that run the
+ combined transfomration would have run into the combinatorial explosion
+ issue due to multiple possible transform compositions or into the need for
+ deep pass parameterization, the ultimate form of which is an ad-hoc dialect
+ to specify which transformations the pass should run. The transform dialect
+ provides a uniform, extensible mechanism for controlling transformations in
+ such cases.
+
+ The transform dialect is supposed to be consumed by an "interpreter" pass
+ that drives the application of transformations. To ensure extensibility and
+ composability, this pass is not expected to actually perform the
+ transformations specified by the ops. Instead, the transformations are
+ implemented by the transform ops themselves via `TransformOpInterface`. The
+ pass serves as the entry point, handles the flow of transform operations and
+ takes care of bookkeeping. As such, the transform dialect does not provide
+ the interpreter pass. Instead, it provides a set of utilities that can be
+ used by clients to define their own interpreter passes or as part of a more
+ complex pass. For example, the mapping between values in the tranfsorm IR
+ and operations in the payload IR, or the function that applies the
+ transformations specified by ops in the given block sequentially. Note that
+ a transform op may have regions with further transform ops in them, with
+ the op itself guiding how to dispatch the transformation control flow to
+ those regions. This approach allows clients to decide on the relative
+ location of the transform IR in their input (e.g., nested modules, separate
+ modules, optional regions to certain operations, etc.), register additional
+ transform operations and perform client-specific bookkeeping.
+
+ ## Effects on the Infrastructure
+
+ Although scoped to a single dialect, this functionality conceptually belongs
+ to the MLIR infrastructure. It aims to be minimally intrusive and opt-in.
+
+ Some infrastructural components may grow extra functionality to support the
+ transform dialect. In particular, the pattern infrastructure may add extra
+ hooks to identify the "main results" of a transformation or to notify
+ external observers about changes made to certain operations. These are not
+ expected to affect the existing uses of the infrastructure.
+
+ For the sake of reusability, transformations should be implemented as
+ utility functions that are called from the interface methods of transform
+ ops rather than having the methods directly act on the payload IR.
+ }];
+
+ let name = "transform";
+ let cppNamespace = "::mlir::transform";
+
+ let extraClassDeclaration = [{
+ // Make addOperations available to the TransformDialectExtension class.
+ private:
+ using ::mlir::Dialect::addOperations;
+
+ template <typename, typename...>
+ friend class TransformDialectExtension;
+ }];
+}
+
+#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT
--- /dev/null
+//===- TransformInterfaces.h - Transform Dialect Interfaces -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H
+#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace transform {
+
+class TransformOpInterface;
+
+/// The state maintained across applications of various ops implementing the
+/// TransformOpInterface. The operations implementing this interface and the
+/// surrounding structure are referred to as transform IR. The operations to
+/// which transformations apply are referred to as payload IR. The state thus
+/// contains the mapping between values defined in the transform IR ops and
+/// payload IR ops. It assumes that each value in the transform IR can be used
+/// at most once (since transformations are likely to change the payload IR ops
+/// the value corresponds to). Checks that transform IR values correspond to
+/// disjoint sets of payload IR ops throughout the transformation.
+///
+/// A reference to this class is passed as an argument to "apply" methods of the
+/// transform op interface. Thus the "apply" method can call
+/// `state.getPayloadOps( getSomeOperand() )` to obtain the list of operations
+/// associated with its operand and subject to transformation. The method is
+/// expected to populate the `TransformResults` class instance in order to
+/// update the mapping. The `applyTransform` method takes care of propagating
+/// the state of `TransformResults` into the instance of this class.
+class TransformState {
+ /// Mapping between a Value in the transform IR and the corresponding set of
+ /// operations in the payload IR.
+ using TransformOpMapping = DenseMap<Value, SmallVector<Operation *>>;
+
+ /// Mapping between a payload IR operation and the transform IR value it is
+ /// currently associated with.
+ using TransformOpReverseMapping = DenseMap<Operation *, Value>;
+
+public:
+ /// Creates a state for the transformation rooted at the given op.
+ explicit TransformState(Operation *root);
+
+ /// Returns the op at which the transformation state is rooted. This is
+ /// typically helpful for transformations that apply globally.
+ Operation *getTopLevel() const;
+
+ /// Returns the list of ops that the given transform IR value corresponds to.
+ /// This is helpful for transformations that apply to a particular handle.
+ ArrayRef<Operation *> getPayloadOps(Value value) const;
+
+ /// Applies the transformation specified by the given transform op and updates
+ /// the state accordingly.
+ LogicalResult applyTransform(TransformOpInterface transform);
+
+private:
+ /// Identifier for storing top-level value in the `operations` mapping.
+ static constexpr Value kTopLevelValue = Value();
+
+ /// Sets the payload IR ops associated with the given transform IR value.
+ /// Fails if this would result in multiple transform IR values with uses
+ /// corresponding to the same payload IR ops. For example, a hypothetical
+ /// "find function by name" transform op would (indirectly) call this
+ /// function for its result. Having two such calls in a row with for different
+ /// values, e.g. coming from different ops:
+ ///
+ /// %0 = transform.find_func_by_name { name = "myfunc" }
+ /// %1 = transform.find_func_by_name { name = "myfunc" }
+ ///
+ /// would lead to both values pointing to the same operation. The second call
+ /// to setPayloadOps will fail, unless the association with the %0 value is
+ /// removed first by calling update/removePayloadOps.
+ LogicalResult setPayloadOps(Value value, ArrayRef<Operation *> targets);
+
+ /// Forgets the payload IR ops associated with the given transform IR value.
+ void removePayloadOps(Value value);
+
+ /// Updates the payload IR ops associated with the given transform IR value.
+ /// The callback function is called once per associated operation and is
+ /// expected to return the modified operation or nullptr. In the latter case,
+ /// the corresponding operation is no longer associated with the transform IR
+ /// value.
+ void updatePayloadOps(Value value,
+ function_ref<Operation *(Operation *)> callback);
+
+ /// The mapping between payload IR values and transform IR ops.
+ TransformOpMapping operationMapping;
+ TransformOpReverseMapping reverseMapping;
+};
+
+/// Local mapping between values defined by a specific op implementing the
+/// TransformOpInterface and the payload IR ops they correspond to.
+class TransformResults {
+ friend class TransformState;
+
+public:
+ /// Indicates that the result of the transform IR op at the given position
+ /// corresponds to the given list of payload IR ops. Each result must be set
+ /// by the transformation exactly once.
+ void set(OpResult value, ArrayRef<Operation *> ops);
+
+private:
+ /// Creates an instance of TransformResults that expects mappings for
+ /// `numSegments` values.
+ explicit TransformResults(unsigned numSegments);
+
+ /// Gets the list of operations associated with the result identified by its
+ /// number in the list of operation results.
+ ArrayRef<Operation *> get(unsigned resultNumber) const;
+
+ /// Storage for pointers to payload IR ops that are associated with results of
+ /// a transform IR op. `segments` contains as many entries as the transform IR
+ /// op has results. Each entry is a reference to a contiguous segment in
+ /// the `operations` list that contains the pointers to operations. This
+ /// allows for operations to be stored contiguously without nested vectors and
+ /// for different segments to be set in any order.
+ SmallVector<ArrayRef<Operation *>, 2> segments;
+ SmallVector<Operation *> operations;
+};
+
+} // namespace transform
+} // namespace mlir
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h.inc"
+
+#endif // DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H
--- /dev/null
+//===- TransformInterfaces.td - Transform Op interfaces ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the interfaces for transformation-related-ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD
+#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD
+
+include "mlir/IR/OpBase.td"
+
+def TransformOpInterface : OpInterface<"TransformOpInterface"> {
+ let description = [{
+ This interface is to be implemented by operations that identify
+ transformations to be performed on other operations. The former are referred
+ to as transform IR operations. The latter are referred to as payload IR
+ operations. Such transform IR operations provide a fine-grain control
+ mechanism over how transformations are applied by using and defining
+ transform IR values, referred to as handles, that correspond to sets of
+ operations in the payload IR. Transformations are applied starting from the
+ operations identified by handles, but may affect other operations as well.
+ Further restrictions may be imposed by flows that rely on transform IR
+ operations to control transformations.
+ }];
+
+ let cppNamespace = "::mlir::transform";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Applies the transformation represented by the current operation. This
+ accepts as arguments the object that must be populated with results of
+ the current transformation and a transformation state object that can be
+ used for queries, e.g., to obtain the list of operations on which the
+ transformation represented by the current op is targeted.
+ }],
+ /*returnType=*/"::mlir::LogicalResult",
+ /*name=*/"apply",
+ /*arguments=*/(ins
+ "::mlir::transform::TransformResults &":$transformResults,
+ "::mlir::transform::TransformState &":$state
+ )>,
+ ];
+}
+
+#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
shape::ShapeDialect,
sparse_tensor::SparseTensorDialect,
tensor::TensorDialect,
+ transform::TransformDialect,
tosa::TosaDialect,
x86vector::X86VectorDialect>();
// clang-format on
add_subdirectory(SPIRV)
add_subdirectory(Tensor)
add_subdirectory(Tosa)
+add_subdirectory(Transform)
add_subdirectory(Utils)
add_subdirectory(Vector)
add_subdirectory(X86Vector)
--- /dev/null
+add_subdirectory(IR)
--- /dev/null
+add_mlir_dialect_library(MLIRTransformDialect
+ TransformDialect.cpp
+ TransformInterfaces.cpp
+
+ DEPENDS
+ MLIRTransformDialectIncGen
+ MLIRTransformInterfacesIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ )
--- /dev/null
+//===- TransformDialect.cpp - Transform Dialect Definition ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
+
+using namespace mlir;
+
+void transform::TransformDialect::initialize() {}
--- /dev/null
+//===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/ADT/SmallPtrSet.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// TransformState
+//===----------------------------------------------------------------------===//
+
+constexpr const Value transform::TransformState::kTopLevelValue;
+
+transform::TransformState::TransformState(Operation *root) {
+ operationMapping[kTopLevelValue].push_back(root);
+}
+
+Operation *transform::TransformState::getTopLevel() const {
+ return operationMapping.lookup(kTopLevelValue).front();
+}
+
+ArrayRef<Operation *>
+transform::TransformState::getPayloadOps(Value value) const {
+ auto iter = operationMapping.find(value);
+ assert(iter != operationMapping.end() && "unknown handle");
+ return iter->getSecond();
+}
+
+LogicalResult
+transform::TransformState::setPayloadOps(Value value,
+ ArrayRef<Operation *> targets) {
+ assert(value != kTopLevelValue &&
+ "attempting to reset the transformation root");
+
+ if (value.use_empty())
+ return success();
+
+ // Setting new payload for the value without cleaning it first is a misuse of
+ // the API, assert here.
+ SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
+ bool inserted =
+ operationMapping.insert({value, std::move(storedTargets)}).second;
+ assert(inserted && "value is already associated with another list");
+ (void)inserted;
+
+ // Having multiple handles to the same operation is an error in the transform
+ // expressed using the dialect and may be constructed by valid API calls from
+ // valid IR. Emit an error here.
+ for (Operation *op : targets) {
+ auto insertionResult = reverseMapping.insert({op, value});
+ if (!insertionResult.second) {
+ InFlightDiagnostic diag = op->emitError()
+ << "operation tracked by two handles";
+ diag.attachNote(value.getLoc()) << "handle";
+ diag.attachNote(insertionResult.first->second.getLoc()) << "handle";
+ return diag;
+ }
+ }
+
+ return success();
+}
+
+void transform::TransformState::removePayloadOps(Value value) {
+ for (Operation *op : operationMapping[value])
+ reverseMapping.erase(op);
+ operationMapping.erase(value);
+}
+
+void transform::TransformState::updatePayloadOps(
+ Value value, function_ref<Operation *(Operation *)> callback) {
+ auto it = operationMapping.find(value);
+ assert(it != operationMapping.end() && "unknown handle");
+ SmallVector<Operation *> &association = it->getSecond();
+ SmallVector<Operation *> updated;
+ updated.reserve(association.size());
+
+ for (Operation *op : association)
+ if (Operation *updatedOp = callback(op))
+ updated.push_back(updatedOp);
+
+ std::swap(association, updated);
+}
+
+LogicalResult
+transform::TransformState::applyTransform(TransformOpInterface transform) {
+ transform::TransformResults results(transform->getNumResults());
+ if (failed(transform.apply(results, *this)))
+ return failure();
+
+ for (Value target : transform->getOperands())
+ removePayloadOps(target);
+
+ for (auto &en : llvm::enumerate(transform->getResults()))
+ if (failed(setPayloadOps(en.value(), results.get(en.index()))))
+ return failure();
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TransformResults
+//===----------------------------------------------------------------------===//
+
+transform::TransformResults::TransformResults(unsigned numSegments) {
+ segments.resize(numSegments,
+ ArrayRef<Operation *>(nullptr, static_cast<size_t>(0)));
+}
+
+void transform::TransformResults::set(OpResult value,
+ ArrayRef<Operation *> ops) {
+ unsigned position = value.getResultNumber();
+ assert(position < segments.size() &&
+ "setting results for a non-existent handle");
+ assert(segments[position].data() == nullptr && "results already set");
+ unsigned start = operations.size();
+ llvm::append_range(operations, ops);
+ segments[position] = makeArrayRef(operations).drop_front(start);
+}
+
+ArrayRef<Operation *>
+transform::TransformResults::get(unsigned resultNumber) const {
+ assert(resultNumber < segments.size() &&
+ "querying results for a non-existent handle");
+ assert(segments[resultNumber].data() != nullptr && "querying unset results");
+ return segments[resultNumber];
+}
+
+//===----------------------------------------------------------------------===//
+// Generated interface implementation.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc"
--- /dev/null
+// RUN: mlir-opt %s | FileCheck %s
+
+// These ops are defined by a test extension but should be okay to roundtrip.
+
+// CHECK: transform.test_transform_op
+transform.test_transform_op
+
+// CHECK: = transform.test_produce_param_or_forward_operand 42 {foo = "bar"}
+%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
+
+// CHECK: transform.test_consume_operand_if_matches_param_or_fail %{{.*}}[42]
+transform.test_consume_operand_if_matches_param_or_fail %0[42]
--- /dev/null
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file --verify-diagnostics
+
+// expected-remark @below {{applying transformation}}
+transform.test_transform_op
+
+// -----
+
+%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
+// expected-remark @below {{succeeded}}
+transform.test_consume_operand_if_matches_param_or_fail %0[42]
+
+// -----
+
+%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
+// expected-error @below {{expected the operand to be associated with 21 got 42}}
+transform.test_consume_operand_if_matches_param_or_fail %0[21]
+
+// -----
+
+// expected-error @below {{operation tracked by two handles}}
+%0 = transform.test_produce_param_or_forward_operand 42
+// expected-note @below {{handle}}
+%1 = transform.test_produce_param_or_forward_operand from %0
+// expected-note @below {{handle}}
+%2 = transform.test_produce_param_or_forward_operand from %0
+transform.test_consume_operand_if_matches_param_or_fail %1[42]
+transform.test_consume_operand_if_matches_param_or_fail %2[42]
add_subdirectory(Tensor)
add_subdirectory(Test)
add_subdirectory(Tosa)
+add_subdirectory(Transform)
add_subdirectory(Vector)
--- /dev/null
+set(LLVM_TARGET_DEFINITIONS TestTransformDialectExtension.td)
+mlir_tablegen(TestTransformDialectExtension.h.inc -gen-op-decls)
+mlir_tablegen(TestTransformDialectExtension.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRTestTransformDialectExtensionIncGen)
+
+add_mlir_library(MLIRTestTransformDialect
+ TestTransformDialectExtension.cpp
+ TestTransformDialectInterpreter.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+
+ DEPENDS
+ MLIRTestTransformDialectExtensionIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRPass
+ MLIRPDL
+ MLIRTransformDialect
+)
--- /dev/null
+//===- TestTransformDialectExtension.cpp ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines an extension of the MLIR Transform dialect for testing
+// purposes.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestTransformDialectExtension.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+
+using namespace mlir;
+
+namespace {
+/// Simple transform op defined outside of the dialect. Just emits a remark when
+/// applied.
+class TestTransformOp
+ : public Op<TestTransformOp, transform::TransformOpInterface::Trait> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp)
+
+ using Op::Op;
+
+ static ArrayRef<StringRef> getAttributeNames() { return {}; }
+
+ static constexpr llvm::StringLiteral getOperationName() {
+ return llvm::StringLiteral("transform.test_transform_op");
+ }
+
+ LogicalResult apply(transform::TransformResults &results,
+ transform::TransformState &state) {
+ emitRemark() << "applying transformation";
+ return success();
+ }
+
+ static ParseResult parse(OpAsmParser &parser, OperationState &state) {
+ return success();
+ }
+
+ void print(OpAsmPrinter &printer) {}
+};
+} // namespace
+
+LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::apply(
+ transform::TransformResults &results, transform::TransformState &state) {
+ if (getOperation()->getNumOperands() != 0) {
+ results.set(getResult().cast<OpResult>(), getOperand(0).getDefiningOp());
+ } else {
+ results.set(getResult().cast<OpResult>(),
+ reinterpret_cast<Operation *>(*parameter()));
+ }
+ return success();
+}
+
+LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() {
+ if (parameter().hasValue() ^ (getNumOperands() != 1))
+ return emitOpError() << "expects either a parameter or an operand";
+ return success();
+}
+
+LogicalResult mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply(
+ transform::TransformResults &results, transform::TransformState &state) {
+ ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
+ assert(payload.size() == 1 && "expected a single target op");
+ auto value = reinterpret_cast<intptr_t>(payload[0]);
+ if (value != parameter()) {
+ return emitOpError() << "expected the operand to be associated with "
+ << parameter() << " got " << value;
+ }
+
+ emitRemark() << "succeeded";
+ return success();
+}
+
+namespace {
+/// Test extension of the Transform dialect. Registers additional ops and
+/// declares PDL as dependent dialect since the additional ops are using PDL
+/// types for operands and results.
+class TestTransformDialectExtension
+ : public transform::TransformDialectExtension<
+ TestTransformDialectExtension> {
+public:
+ TestTransformDialectExtension() {
+ declareDependentDialect<pdl::PDLDialect>();
+ registerTransformOps<TestTransformOp,
+#define GET_OP_LIST
+#include "TestTransformDialectExtension.cpp.inc"
+ >();
+ }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "TestTransformDialectExtension.cpp.inc"
+
+void ::test::registerTestTransformDialectExtension(DialectRegistry ®istry) {
+ registry.addExtensions<TestTransformDialectExtension>();
+}
--- /dev/null
+//===- TestTransformDialectExtension.h --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines an extension of the MLIR Transform dialect for testing
+// purposes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TESTTRANSFORMDIALECTEXTENSION_H
+#define MLIR_TESTTRANSFORMDIALECTEXTENSION_H
+
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+class DialectRegistry;
+} // namespace mlir
+
+#define GET_OP_CLASSES
+#include "TestTransformDialectExtension.h.inc"
+
+namespace test {
+/// Registers the test extension to the Transform dialect.
+void registerTestTransformDialectExtension(::mlir::DialectRegistry ®istry);
+} // namespace test
+
+#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_H
--- /dev/null
+//===- TestTransformDialectExtension.td --------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the operations that are injected into the Transform
+// dialect through the extension mechanism, as a test.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
+#define MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
+
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/PDL/IR/PDLTypes.td"
+
+def TestProduceParamOrForwardOperandOp
+ : Op<Transform_Dialect, "test_produce_param_or_forward_operand",
+ [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let arguments = (ins Optional<PDL_Operation>:$operand,
+ OptionalAttr<I64Attr>:$parameter);
+ let results = (outs PDL_Operation:$res);
+ let assemblyFormat = "(`from` $operand^)? ($parameter^)? attr-dict";
+ let cppNamespace = "::mlir::test";
+ let hasVerifier = 1;
+}
+
+def TestConsumeOperandIfMatchesParamOrFail
+ : Op<Transform_Dialect, "test_consume_operand_if_matches_param_or_fail",
+ [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let arguments = (ins PDL_Operation:$operand, I64Attr:$parameter);
+ let assemblyFormat = "$operand `[` $parameter `]` attr-dict";
+ let cppNamespace = "::mlir::test";
+}
+
+#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
--- /dev/null
+//===- TestTransformDialectInterpreter.cpp --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a test pass that interprets Transform dialect operations in
+// the module.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Simple pass that applies transform dialect ops directly contained in a
+/// module.
+class TestTransformDialectInterpreterPass
+ : public PassWrapper<TestTransformDialectInterpreterPass,
+ OperationPass<ModuleOp>> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestTransformDialectInterpreterPass)
+
+ StringRef getArgument() const override {
+ return "test-transform-dialect-interpreter";
+ }
+
+ StringRef getDescription() const override {
+ return "apply transform dialect operations one by one";
+ }
+
+ void runOnOperation() override {
+ ModuleOp module = getOperation();
+ transform::TransformState state(module);
+ for (auto op :
+ module.getBody()->getOps<transform::TransformOpInterface>()) {
+ if (failed(state.applyTransform(op)))
+ return signalPassFailure();
+ }
+ }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+/// Registers the test pass for applying transform dialect ops.
+void registerTestTransformDialectInterpreterPass() {
+ PassRegistration<TestTransformDialectInterpreterPass> reg;
+}
+} // namespace test
+} // namespace mlir
--- /dev/null
+config.suffixes.remove('.td')
\ No newline at end of file
// CHECK-NEXT: tensor
// CHECK-NEXT: test
// CHECK-NEXT: tosa
+// CHECK-NEXT: transform
// CHECK-NEXT: vector
// CHECK-NEXT: x86vector
MLIRTestPass
MLIRTestReducer
MLIRTestRewrite
+ MLIRTestTransformDialect
MLIRTestTransforms
MLIRVectorTestPasses
)
void registerTestSCFUtilsPass();
void registerTestSliceAnalysisPass();
void registerTestTensorTransforms();
+void registerTestTransformDialectInterpreterPass();
void registerTestVectorLowerings();
} // namespace test
} // namespace mlir
namespace test {
void registerTestDialect(DialectRegistry &);
+void registerTestTransformDialectExtension(DialectRegistry &);
} // namespace test
#ifdef MLIR_INCLUDE_TESTS
mlir::test::registerTestSCFUtilsPass();
mlir::test::registerTestSliceAnalysisPass();
mlir::test::registerTestTensorTransforms();
+ mlir::test::registerTestTransformDialectInterpreterPass();
mlir::test::registerTestVectorLowerings();
}
#endif
registerAllDialects(registry);
#ifdef MLIR_INCLUDE_TESTS
::test::registerTestDialect(registry);
+ ::test::registerTestTransformDialectExtension(registry);
#endif
return mlir::asMainReturnCode(
mlir::MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry,
":TensorTransforms",
":TosaDialect",
":TosaToLinalg",
+ ":TransformDialect",
":Transforms",
":TransformsPassIncGen",
":VectorOps",
"//mlir/test:TestShapeDialect",
"//mlir/test:TestTensor",
"//mlir/test:TestTosaDialect",
+ "//mlir/test:TestTransformDialect",
"//mlir/test:TestTransforms",
"//mlir/test:TestTypeDialect",
"//mlir/test:TestVector",
],
)
+td_library(
+ name = "TransformDialectTdFiles",
+ srcs = glob(["include/mlir/Dialect/Transform/IR/*.td"]),
+ deps = [
+ ":OpBaseTdFiles",
+ ],
+)
+
+gentbl_cc_library(
+ name = "TransformDialectInterfacesIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ [
+ "-gen-op-interface-decls",
+ ],
+ "include/mlir/Dialect/Transform/IR/TransformInterfaces.h.inc",
+ ),
+ (
+ [
+ "-gen-op-interface-defs",
+ ],
+ "include/mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc",
+ ),
+ ],
+ tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/Dialect/Transform/IR/TransformInterfaces.td",
+ deps = [":TransformDialectTdFiles"],
+)
+
+gentbl_cc_library(
+ name = "TransformDialectIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ [
+ "-gen-dialect-decls",
+ ],
+ "include/mlir/Dialect/Transform/IR/TransformDialect.h.inc",
+ ),
+ (
+ [
+ "-gen-dialect-defs",
+ ],
+ "include/mlir/Dialect/Transform/IR/TransformDialect.cpp.inc",
+ ),
+ ],
+ tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/Dialect/Transform/IR/TransformDialect.td",
+ deps = [":TransformDialectTdFiles"],
+)
+
+cc_library(
+ name = "TransformDialect",
+ srcs = glob(["lib/Dialect/Transform/IR/*.cpp"]),
+ hdrs = glob(["include/mlir/Dialect/Transform/IR/*.h"]),
+ deps = [
+ ":IR",
+ ":Support",
+ ":TransformDialectIncGen",
+ ":TransformDialectInterfacesIncGen",
+ "//llvm:Support",
+ ],
+)
+
td_library(
name = "ComplexOpsTdFiles",
srcs = [
],
)
+td_library(
+ name = "TransformDialectTdFiles",
+ srcs = glob(["lib/Dialect/Transform/*.td"]),
+ deps = [
+ "//mlir:OpBaseTdFiles",
+ ],
+)
+
+gentbl_cc_library(
+ name = "TestTransformDialectExtensionIncGen",
+ strip_include_prefix = "lib/Dialect/Transform",
+ tbl_outs = [
+ (
+ ["-gen-op-decls"],
+ "lib/Dialect/Transform/TestTransformDialectExtension.h.inc",
+ ),
+ (
+ ["-gen-op-defs"],
+ "lib/Dialect/Transform/TestTransformDialectExtension.cpp.inc",
+ ),
+ ],
+ tblgen = "//mlir:mlir-tblgen",
+ td_file = "lib/Dialect/Transform/TestTransformDialectExtension.td",
+ test = True,
+ deps = [
+ ":TransformDialectTdFiles",
+ "//mlir:PDLDialectTdFiles",
+ "//mlir:TransformDialectTdFiles",
+ ],
+)
+
+cc_library(
+ name = "TestTransformDialect",
+ srcs = glob(["lib/Dialect/Transform/*.cpp"]),
+ hdrs = glob(["lib/Dialect/Transform/*.h"]),
+ includes = ["lib/Dialect/Transform"],
+ deps = [
+ ":TestTransformDialectExtensionIncGen",
+ "//mlir:IR",
+ "//mlir:PDLDialect",
+ "//mlir:Pass",
+ "//mlir:TransformDialect",
+ ],
+)
+
cc_library(
name = "TestDialect",
srcs = glob(["lib/Dialect/Test/*.cpp"]),