[mlir:OpConversionPattern] Add overloads for taking an Adaptor instead of ArrayRef
authorRiver Riddle <riddleriver@gmail.com>
Fri, 24 Sep 2021 17:50:58 +0000 (17:50 +0000)
committerRiver Riddle <riddleriver@gmail.com>
Fri, 24 Sep 2021 17:51:41 +0000 (17:51 +0000)
This has been a TODO for a long time, and it brings about many advantages (namely nice accessors, and less fragile code). The existing overloads that accept ArrayRef are now treated as deprecated and will be removed in a followup (after a small grace period). Most of the upstream MLIR usages have been fixed by this commit, the rest will be handled in a followup.

Differential Revision: https://reviews.llvm.org/D110293

33 files changed:
mlir/docs/Bufferization.md
mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp
mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
mlir/lib/Transforms/Bufferize.cpp

index ac2e068..4a6db34 100644 (file)
@@ -139,10 +139,10 @@ class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(tensor::CastOp op, ArrayRef<Value> operands,
+  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto resultType = getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<MemRefCastOp>(op, resultType, operands[0]);
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(op, resultType, adaptor.source());
     return success();
   }
 };
index 21e2591..81358dc 100644 (file)
@@ -131,6 +131,8 @@ protected:
 template <typename SourceOp>
 class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
 public:
+  using OpAdaptor = typename SourceOp::Adaptor;
+
   explicit ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
                                   PatternBenefit benefit = 1)
       : ConvertToLLVMPattern(SourceOp::getOperationName(),
@@ -140,7 +142,8 @@ public:
   /// Wrappers around the RewritePattern methods that pass the derived op type.
   void rewrite(Operation *op, ArrayRef<Value> operands,
                ConversionPatternRewriter &rewriter) const final {
-    rewrite(cast<SourceOp>(op), operands, rewriter);
+    rewrite(cast<SourceOp>(op), OpAdaptor(operands, op->getAttrDictionary()),
+            rewriter);
   }
   LogicalResult match(Operation *op) const final {
     return match(cast<SourceOp>(op));
@@ -148,28 +151,53 @@ public:
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
+    return matchAndRewrite(cast<SourceOp>(op),
+                           OpAdaptor(operands, op->getAttrDictionary()),
+                           rewriter);
   }
 
   /// Rewrite and Match methods that operate on the SourceOp type. These must be
   /// overridden by the derived pattern class.
+  /// NOTICE: These methods are deprecated and will be removed. All new code
+  /// should use the adaptor methods below instead.
   virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
                        ConversionPatternRewriter &rewriter) const {
     llvm_unreachable("must override rewrite or matchAndRewrite");
   }
-  virtual LogicalResult match(SourceOp op) const {
-    llvm_unreachable("must override match or matchAndRewrite");
-  }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const {
     if (succeeded(match(op))) {
-      rewrite(op, operands, rewriter);
+      rewrite(op, OpAdaptor(operands, op->getAttrDictionary()), rewriter);
       return success();
     }
     return failure();
   }
 
+  /// Rewrite and Match methods that operate on the SourceOp type. These must be
+  /// overridden by the derived pattern class.
+  virtual LogicalResult match(SourceOp op) const {
+    llvm_unreachable("must override match or matchAndRewrite");
+  }
+  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
+                       ConversionPatternRewriter &rewriter) const {
+    ValueRange operands = adaptor.getOperands();
+    rewrite(op,
+            ArrayRef<Value>(operands.getBase().get<const Value *>(),
+                            operands.size()),
+            rewriter);
+  }
+  virtual LogicalResult
+  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const {
+    ValueRange operands = adaptor.getOperands();
+    return matchAndRewrite(
+        op,
+        ArrayRef<Value>(operands.getBase().get<const Value *>(),
+                        operands.size()),
+        rewriter);
+  }
+
 private:
   using ConvertToLLVMPattern::match;
   using ConvertToLLVMPattern::matchAndRewrite;
index 7354b55..9fe9690 100644 (file)
@@ -366,79 +366,121 @@ private:
   using RewritePattern::rewrite;
 };
 
-namespace detail {
-/// OpOrInterfaceConversionPatternBase is a wrapper around ConversionPattern
-/// that allows for matching and rewriting against an instance of a derived
-/// operation class or an Interface as opposed to a raw Operation.
+/// OpConversionPattern is a wrapper around ConversionPattern that allows for
+/// matching and rewriting against an instance of a derived operation class as
+/// opposed to a raw Operation.
 template <typename SourceOp>
-struct OpOrInterfaceConversionPatternBase : public ConversionPattern {
-  using ConversionPattern::ConversionPattern;
+class OpConversionPattern : public ConversionPattern {
+public:
+  using OpAdaptor = typename SourceOp::Adaptor;
+
+  OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
+      : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
+  OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context,
+                      PatternBenefit benefit = 1)
+      : ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit,
+                          context) {}
 
   /// Wrappers around the ConversionPattern methods that pass the derived op
   /// type.
   void rewrite(Operation *op, ArrayRef<Value> operands,
                ConversionPatternRewriter &rewriter) const final {
-    rewrite(cast<SourceOp>(op), operands, rewriter);
+    rewrite(cast<SourceOp>(op), OpAdaptor(operands, op->getAttrDictionary()),
+            rewriter);
   }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
+    return matchAndRewrite(cast<SourceOp>(op),
+                           OpAdaptor(operands, op->getAttrDictionary()),
+                           rewriter);
   }
 
-  // TODO: Use OperandAdaptor when it supports access to unnamed operands.
-
-  /// Rewrite and Match methods that operate on the SourceOp type. These must be
-  /// overridden by the derived pattern class.
+  /// Rewrite and Match methods that operate on the SourceOp type and accept the
+  /// raw operand values.
+  /// NOTICE: These methods are deprecated and will be removed. All new code
+  /// should use the adaptor methods below instead.
   virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
                        ConversionPatternRewriter &rewriter) const {
     llvm_unreachable("must override matchAndRewrite or a rewrite method");
   }
-
   virtual LogicalResult
   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const {
     if (failed(match(op)))
       return failure();
-    rewrite(op, operands, rewriter);
+    rewrite(op, OpAdaptor(operands, op->getAttrDictionary()), rewriter);
     return success();
   }
 
+  /// Rewrite and Match methods that operate on the SourceOp type. These must be
+  /// overridden by the derived pattern class.
+  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
+                       ConversionPatternRewriter &rewriter) const {
+    ValueRange operands = adaptor.getOperands();
+    rewrite(op,
+            ArrayRef<Value>(operands.getBase().get<const Value *>(),
+                            operands.size()),
+            rewriter);
+  }
+  virtual LogicalResult
+  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const {
+    ValueRange operands = adaptor.getOperands();
+    return matchAndRewrite(
+        op,
+        ArrayRef<Value>(operands.getBase().get<const Value *>(),
+                        operands.size()),
+        rewriter);
+  }
+
 private:
   using ConversionPattern::matchAndRewrite;
 };
-} // namespace detail
-
-/// OpConversionPattern is a wrapper around ConversionPattern that allows for
-/// matching and rewriting against an instance of a derived operation class as
-/// opposed to a raw Operation.
-template <typename SourceOp>
-struct OpConversionPattern
-    : public detail::OpOrInterfaceConversionPatternBase<SourceOp> {
-  OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
-      : detail::OpOrInterfaceConversionPatternBase<SourceOp>(
-            SourceOp::getOperationName(), benefit, context) {}
-  OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context,
-                      PatternBenefit benefit = 1)
-      : detail::OpOrInterfaceConversionPatternBase<SourceOp>(
-            typeConverter, SourceOp::getOperationName(), benefit, context) {}
-};
 
 /// OpInterfaceConversionPattern is a wrapper around ConversionPattern that
 /// allows for matching and rewriting against an instance of an OpInterface
 /// class as opposed to a raw Operation.
 template <typename SourceOp>
-struct OpInterfaceConversionPattern
-    : public detail::OpOrInterfaceConversionPatternBase<SourceOp> {
+class OpInterfaceConversionPattern : public ConversionPattern {
+public:
   OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
-      : detail::OpOrInterfaceConversionPatternBase<SourceOp>(
-            Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
-            benefit, context) {}
+      : ConversionPattern(Pattern::MatchInterfaceOpTypeTag(),
+                          SourceOp::getInterfaceID(), benefit, context) {}
   OpInterfaceConversionPattern(TypeConverter &typeConverter,
                                MLIRContext *context, PatternBenefit benefit = 1)
-      : detail::OpOrInterfaceConversionPatternBase<SourceOp>(
-            typeConverter, Pattern::MatchInterfaceOpTypeTag(),
-            SourceOp::getInterfaceID(), benefit, context) {}
+      : ConversionPattern(typeConverter, Pattern::MatchInterfaceOpTypeTag(),
+                          SourceOp::getInterfaceID(), benefit, context) {}
+
+  /// Wrappers around the ConversionPattern methods that pass the derived op
+  /// type.
+  void rewrite(Operation *op, ArrayRef<Value> operands,
+               ConversionPatternRewriter &rewriter) const final {
+    rewrite(cast<SourceOp>(op), operands, rewriter);
+  }
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
+  }
+
+  /// Rewrite and Match methods that operate on the SourceOp type. These must be
+  /// overridden by the derived pattern class.
+  virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
+                       ConversionPatternRewriter &rewriter) const {
+    llvm_unreachable("must override matchAndRewrite or a rewrite method");
+  }
+  virtual LogicalResult
+  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const {
+    if (failed(match(op)))
+      return failure();
+    rewrite(op, operands, rewriter);
+    return success();
+  }
+
+private:
+  using ConversionPattern::matchAndRewrite;
 };
 
 /// Add a pattern to the given pattern list to convert the signature of a
index 4f4dd0e..8e0c1a2 100644 (file)
@@ -326,7 +326,7 @@ public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(CoroIdOp op, ArrayRef<Value> operands,
+  matchAndRewrite(CoroIdOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto token = AsyncAPI::tokenType(op->getContext());
     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
@@ -356,7 +356,7 @@ public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(CoroBeginOp op, ArrayRef<Value> operands,
+  matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
     auto loc = op->getLoc();
@@ -371,7 +371,7 @@ public:
         ValueRange(coroSize.getResult()));
 
     // Begin a coroutine: @llvm.coro.begin.
-    auto coroId = CoroBeginOpAdaptor(operands).id();
+    auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).id();
     rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
         op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)}));
 
@@ -390,13 +390,14 @@ public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(CoroFreeOp op, ArrayRef<Value> operands,
+  matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
     auto loc = op->getLoc();
 
     // Get a pointer to the coroutine frame memory: @llvm.coro.free.
-    auto coroMem = rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, operands);
+    auto coroMem =
+        rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, adaptor.getOperands());
 
     // Free the memory.
     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
@@ -418,14 +419,14 @@ public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(CoroEndOp op, ArrayRef<Value> operands,
+  matchAndRewrite(CoroEndOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // We are not in the block that is part of the unwind sequence.
     auto constFalse = rewriter.create<LLVM::ConstantOp>(
         op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
 
     // Mark the end of a coroutine: @llvm.coro.end.
-    auto coroHdl = CoroEndOpAdaptor(operands).handle();
+    auto coroHdl = adaptor.handle();
     rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(),
                                      ValueRange({coroHdl, constFalse}));
     rewriter.eraseOp(op);
@@ -445,11 +446,11 @@ public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(CoroSaveOp op, ArrayRef<Value> operands,
+  matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Save the coroutine state: @llvm.coro.save
     rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
-        op, AsyncAPI::tokenType(op->getContext()), operands);
+        op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands());
 
     return success();
   }
@@ -491,7 +492,7 @@ public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(CoroSuspendOp op, ArrayRef<Value> operands,
+  matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto i8 = rewriter.getIntegerType(8);
     auto i32 = rewriter.getI32Type();
@@ -502,7 +503,7 @@ public:
         loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
 
     // Suspend a coroutine: @llvm.coro.suspend
-    auto coroState = CoroSuspendOpAdaptor(operands).state();
+    auto coroState = adaptor.state();
     auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
         loc, i8, ValueRange({coroState, constFalse}));
 
@@ -541,7 +542,7 @@ public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(RuntimeCreateOp op, ArrayRef<Value> operands,
+  matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     TypeConverter *converter = getTypeConverter();
     Type resultType = op->getResultTypes()[0];
@@ -595,13 +596,14 @@ public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(RuntimeCreateGroupOp op, ArrayRef<Value> operands,
+  matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     TypeConverter *converter = getTypeConverter();
     Type resultType = op.getResult().getType();
 
-    rewriter.replaceOpWithNewOp<CallOp>(
-        op, kCreateGroup, converter->convertType(resultType), operands);
+    rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup,
+                                        converter->convertType(resultType),
+                                        adaptor.getOperands());
     return success();
   }
 };
@@ -618,14 +620,15 @@ public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef<Value> operands,
+  matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     StringRef apiFuncName =
         TypeSwitch<Type, StringRef>(op.operand().getType())
             .Case<TokenType>([](Type) { return kEmplaceToken; })
             .Case<ValueType>([](Type) { return kEmplaceValue; });
 
-    rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), operands);
+    rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(),
+                                        adaptor.getOperands());
 
     return success();
   }
@@ -643,14 +646,15 @@ public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(RuntimeSetErrorOp op, ArrayRef<Value> operands,
+  matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     StringRef apiFuncName =
         TypeSwitch<Type, StringRef>(op.operand().getType())
             .Case<TokenType>([](Type) { return kSetTokenError; })
             .Case<ValueType>([](Type) { return kSetValueError; });
 
-    rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), operands);
+    rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(),
+                                        adaptor.getOperands());
 
     return success();
   }
@@ -667,7 +671,7 @@ public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(RuntimeIsErrorOp op, ArrayRef<Value> operands,
+  matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     StringRef apiFuncName =
         TypeSwitch<Type, StringRef>(op.operand().getType())
@@ -676,7 +680,7 @@ public:
             .Case<ValueType>([](Type) { return kIsValueError; });
 
     rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, rewriter.getI1Type(),
-                                        operands);
+                                        adaptor.getOperands());
     return success();
   }
 };
@@ -692,7 +696,7 @@ public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(RuntimeAwaitOp op, ArrayRef<Value> operands,
+  matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     StringRef apiFuncName =
         TypeSwitch<Type, StringRef>(op.operand().getType())
@@ -700,7 +704,8 @@ public:
             .Case<ValueType>([](Type) { return kAwaitValue; })
             .Case<GroupType>([](Type) { return kAwaitGroup; });
 
-    rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), operands);
+    rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(),
+                            adaptor.getOperands());
     rewriter.eraseOp(op);
 
     return success();
@@ -719,7 +724,7 @@ public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(RuntimeAwaitAndResumeOp op, ArrayRef<Value> operands,
+  matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     StringRef apiFuncName =
         TypeSwitch<Type, StringRef>(op.operand().getType())
@@ -727,8 +732,8 @@ public:
             .Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
             .Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
 
-    Value operand = RuntimeAwaitAndResumeOpAdaptor(operands).operand();
-    Value handle = RuntimeAwaitAndResumeOpAdaptor(operands).handle();
+    Value operand = adaptor.operand();
+    Value handle = adaptor.handle();
 
     // A pointer to coroutine resume intrinsic wrapper.
     addResumeFunction(op->getParentOfType<ModuleOp>());
@@ -755,7 +760,7 @@ public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(RuntimeResumeOp op, ArrayRef<Value> operands,
+  matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // A pointer to coroutine resume intrinsic wrapper.
     addResumeFunction(op->getParentOfType<ModuleOp>());
@@ -764,7 +769,7 @@ public:
         op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
 
     // Call async runtime API to execute a coroutine in the managed thread.
-    auto coroHdl = RuntimeResumeOpAdaptor(operands).handle();
+    auto coroHdl = adaptor.handle();
     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), kExecute,
                                         ValueRange({coroHdl, resumePtr.res()}));
 
@@ -783,13 +788,13 @@ public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(RuntimeStoreOp op, ArrayRef<Value> operands,
+  matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
 
     // Get a pointer to the async value storage from the runtime.
     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
-    auto storage = RuntimeStoreOpAdaptor(operands).storage();
+    auto storage = adaptor.storage();
     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
                                               TypeRange(i8Ptr), storage);
 
@@ -805,7 +810,7 @@ public:
         storagePtr.getResult(0));
 
     // Store the yielded value into the async value storage.
-    auto value = RuntimeStoreOpAdaptor(operands).value();
+    auto value = adaptor.value();
     rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult());
 
     // Erase the original runtime store operation.
@@ -826,13 +831,13 @@ public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(RuntimeLoadOp op, ArrayRef<Value> operands,
+  matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
 
     // Get a pointer to the async value storage from the runtime.
     auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
-    auto storage = RuntimeLoadOpAdaptor(operands).storage();
+    auto storage = adaptor.storage();
     auto storagePtr = rewriter.create<CallOp>(loc, kGetValueStorage,
                                               TypeRange(i8Ptr), storage);
 
@@ -866,15 +871,15 @@ public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(RuntimeAddToGroupOp op, ArrayRef<Value> operands,
+  matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Currently we can only add tokens to the group.
     if (!op.operand().getType().isa<TokenType>())
       return rewriter.notifyMatchFailure(op, "only token type is supported");
 
     // Replace with a runtime API function call.
-    rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup,
-                                        rewriter.getI64Type(), operands);
+    rewriter.replaceOpWithNewOp<CallOp>(
+        op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands());
 
     return success();
   }
@@ -896,13 +901,13 @@ public:
         apiFunctionName(apiFunctionName) {}
 
   LogicalResult
-  matchAndRewrite(RefCountingOp op, ArrayRef<Value> operands,
+  matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto count =
         rewriter.create<ConstantOp>(op->getLoc(), rewriter.getI32Type(),
                                     rewriter.getI32IntegerAttr(op.count()));
 
-    auto operand = typename RefCountingOp::Adaptor(operands).operand();
+    auto operand = adaptor.operand();
     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName,
                                         ValueRange({operand, count}));
 
@@ -937,9 +942,9 @@ public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
+    rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
     return success();
   }
 };
@@ -1032,7 +1037,7 @@ class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ExecuteOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ExecuteOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     ExecuteOp newOp =
         cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
@@ -1040,7 +1045,7 @@ public:
                                 newOp.getRegion().end());
 
     // Set operands and update block argument and result types.
-    newOp->setOperands(operands);
+    newOp->setOperands(adaptor.getOperands());
     if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
       return failure();
     for (auto result : newOp.getResults())
@@ -1056,9 +1061,9 @@ class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(AwaitOp op, ArrayRef<Value> operands,
+  matchAndRewrite(AwaitOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<AwaitOp>(op, operands.front());
+    rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front());
     return success();
   }
 };
@@ -1068,9 +1073,9 @@ class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands,
+  matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<async::YieldOp>(op, operands);
+    rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands());
     return success();
   }
 };
index f651eed..6ca60d0 100644 (file)
@@ -26,16 +26,13 @@ struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands,
+  matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    complex::AbsOp::Adaptor transformed(operands);
     auto loc = op.getLoc();
     auto type = op.getType();
 
-    Value real =
-        rewriter.create<complex::ReOp>(loc, type, transformed.complex());
-    Value imag =
-        rewriter.create<complex::ImOp>(loc, type, transformed.complex());
+    Value real = rewriter.create<complex::ReOp>(loc, type, adaptor.complex());
+    Value imag = rewriter.create<complex::ImOp>(loc, type, adaptor.complex());
     Value realSqr = rewriter.create<MulFOp>(loc, real, real);
     Value imagSqr = rewriter.create<MulFOp>(loc, imag, imag);
     Value sqNorm = rewriter.create<AddFOp>(loc, realSqr, imagSqr);
@@ -53,23 +50,16 @@ struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
                          AndOp, OrOp>;
 
   LogicalResult
-  matchAndRewrite(ComparisonOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    typename ComparisonOp::Adaptor transformed(operands);
     auto loc = op.getLoc();
-    auto type = transformed.lhs()
-                    .getType()
-                    .template cast<ComplexType>()
-                    .getElementType();
-
-    Value realLhs =
-        rewriter.create<complex::ReOp>(loc, type, transformed.lhs());
-    Value imagLhs =
-        rewriter.create<complex::ImOp>(loc, type, transformed.lhs());
-    Value realRhs =
-        rewriter.create<complex::ReOp>(loc, type, transformed.rhs());
-    Value imagRhs =
-        rewriter.create<complex::ImOp>(loc, type, transformed.rhs());
+    auto type =
+        adaptor.lhs().getType().template cast<ComplexType>().getElementType();
+
+    Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.lhs());
+    Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.lhs());
+    Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.rhs());
+    Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.rhs());
     Value realComparison = rewriter.create<CmpFOp>(loc, p, realLhs, realRhs);
     Value imagComparison = rewriter.create<CmpFOp>(loc, p, imagLhs, imagRhs);
 
@@ -87,19 +77,18 @@ struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
   using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(BinaryComplexOp op, ArrayRef<Value> operands,
+  matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    typename BinaryComplexOp::Adaptor transformed(operands);
-    auto type = transformed.lhs().getType().template cast<ComplexType>();
+    auto type = adaptor.lhs().getType().template cast<ComplexType>();
     auto elementType = type.getElementType().template cast<FloatType>();
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
-    Value realLhs = b.create<complex::ReOp>(elementType, transformed.lhs());
-    Value realRhs = b.create<complex::ReOp>(elementType, transformed.rhs());
+    Value realLhs = b.create<complex::ReOp>(elementType, adaptor.lhs());
+    Value realRhs = b.create<complex::ReOp>(elementType, adaptor.rhs());
     Value resultReal =
         b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
-    Value imagLhs = b.create<complex::ImOp>(elementType, transformed.lhs());
-    Value imagRhs = b.create<complex::ImOp>(elementType, transformed.rhs());
+    Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.lhs());
+    Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.rhs());
     Value resultImag =
         b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
@@ -112,21 +101,20 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
   using OpConversionPattern<complex::DivOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(complex::DivOp op, ArrayRef<Value> operands,
+  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    complex::DivOp::Adaptor transformed(operands);
     auto loc = op.getLoc();
-    auto type = transformed.lhs().getType().cast<ComplexType>();
+    auto type = adaptor.lhs().getType().cast<ComplexType>();
     auto elementType = type.getElementType().cast<FloatType>();
 
     Value lhsReal =
-        rewriter.create<complex::ReOp>(loc, elementType, transformed.lhs());
+        rewriter.create<complex::ReOp>(loc, elementType, adaptor.lhs());
     Value lhsImag =
-        rewriter.create<complex::ImOp>(loc, elementType, transformed.lhs());
+        rewriter.create<complex::ImOp>(loc, elementType, adaptor.lhs());
     Value rhsReal =
-        rewriter.create<complex::ReOp>(loc, elementType, transformed.rhs());
+        rewriter.create<complex::ReOp>(loc, elementType, adaptor.rhs());
     Value rhsImag =
-        rewriter.create<complex::ImOp>(loc, elementType, transformed.rhs());
+        rewriter.create<complex::ImOp>(loc, elementType, adaptor.rhs());
 
     // Smith's algorithm to divide complex numbers. It is just a bit smarter
     // way to compute the following formula:
@@ -321,17 +309,16 @@ struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
   using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(complex::ExpOp op, ArrayRef<Value> operands,
+  matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    complex::ExpOp::Adaptor transformed(operands);
     auto loc = op.getLoc();
-    auto type = transformed.complex().getType().cast<ComplexType>();
+    auto type = adaptor.complex().getType().cast<ComplexType>();
     auto elementType = type.getElementType().cast<FloatType>();
 
     Value real =
-        rewriter.create<complex::ReOp>(loc, elementType, transformed.complex());
+        rewriter.create<complex::ReOp>(loc, elementType, adaptor.complex());
     Value imag =
-        rewriter.create<complex::ImOp>(loc, elementType, transformed.complex());
+        rewriter.create<complex::ImOp>(loc, elementType, adaptor.complex());
     Value expReal = rewriter.create<math::ExpOp>(loc, real);
     Value cosImag = rewriter.create<math::CosOp>(loc, imag);
     Value resultReal = rewriter.create<MulFOp>(loc, expReal, cosImag);
@@ -348,17 +335,16 @@ struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
   using OpConversionPattern<complex::LogOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(complex::LogOp op, ArrayRef<Value> operands,
+  matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    complex::LogOp::Adaptor transformed(operands);
-    auto type = transformed.complex().getType().cast<ComplexType>();
+    auto type = adaptor.complex().getType().cast<ComplexType>();
     auto elementType = type.getElementType().cast<FloatType>();
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
-    Value abs = b.create<complex::AbsOp>(elementType, transformed.complex());
+    Value abs = b.create<complex::AbsOp>(elementType, adaptor.complex());
     Value resultReal = b.create<math::LogOp>(elementType, abs);
-    Value real = b.create<complex::ReOp>(elementType, transformed.complex());
-    Value imag = b.create<complex::ImOp>(elementType, transformed.complex());
+    Value real = b.create<complex::ReOp>(elementType, adaptor.complex());
+    Value imag = b.create<complex::ImOp>(elementType, adaptor.complex());
     Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                    resultImag);
@@ -370,15 +356,14 @@ struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
   using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(complex::Log1pOp op, ArrayRef<Value> operands,
+  matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    complex::Log1pOp::Adaptor transformed(operands);
-    auto type = transformed.complex().getType().cast<ComplexType>();
+    auto type = adaptor.complex().getType().cast<ComplexType>();
     auto elementType = type.getElementType().cast<FloatType>();
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
-    Value real = b.create<complex::ReOp>(elementType, transformed.complex());
-    Value imag = b.create<complex::ImOp>(elementType, transformed.complex());
+    Value real = b.create<complex::ReOp>(elementType, adaptor.complex());
+    Value imag = b.create<complex::ImOp>(elementType, adaptor.complex());
     Value one =
         b.create<ConstantOp>(elementType, b.getFloatAttr(elementType, 1));
     Value realPlusOne = b.create<AddFOp>(real, one);
@@ -392,20 +377,19 @@ struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
   using OpConversionPattern<complex::MulOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(complex::MulOp op, ArrayRef<Value> operands,
+  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    complex::MulOp::Adaptor transformed(operands);
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    auto type = transformed.lhs().getType().cast<ComplexType>();
+    auto type = adaptor.lhs().getType().cast<ComplexType>();
     auto elementType = type.getElementType().cast<FloatType>();
 
-    Value lhsReal = b.create<complex::ReOp>(elementType, transformed.lhs());
+    Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.lhs());
     Value lhsRealAbs = b.create<AbsFOp>(lhsReal);
-    Value lhsImag = b.create<complex::ImOp>(elementType, transformed.lhs());
+    Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.lhs());
     Value lhsImagAbs = b.create<AbsFOp>(lhsImag);
-    Value rhsReal = b.create<complex::ReOp>(elementType, transformed.rhs());
+    Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.rhs());
     Value rhsRealAbs = b.create<AbsFOp>(rhsReal);
-    Value rhsImag = b.create<complex::ImOp>(elementType, transformed.rhs());
+    Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.rhs());
     Value rhsImagAbs = b.create<AbsFOp>(rhsImag);
 
     Value lhsRealTimesRhsReal = b.create<MulFOp>(lhsReal, rhsReal);
@@ -530,17 +514,16 @@ struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
   using OpConversionPattern<complex::NegOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(complex::NegOp op, ArrayRef<Value> operands,
+  matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    complex::NegOp::Adaptor transformed(operands);
     auto loc = op.getLoc();
-    auto type = transformed.complex().getType().cast<ComplexType>();
+    auto type = adaptor.complex().getType().cast<ComplexType>();
     auto elementType = type.getElementType().cast<FloatType>();
 
     Value real =
-        rewriter.create<complex::ReOp>(loc, elementType, transformed.complex());
+        rewriter.create<complex::ReOp>(loc, elementType, adaptor.complex());
     Value imag =
-        rewriter.create<complex::ImOp>(loc, elementType, transformed.complex());
+        rewriter.create<complex::ImOp>(loc, elementType, adaptor.complex());
     Value negReal = rewriter.create<NegFOp>(loc, real);
     Value negImag = rewriter.create<NegFOp>(loc, imag);
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
@@ -552,25 +535,23 @@ struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
   using OpConversionPattern<complex::SignOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(complex::SignOp op, ArrayRef<Value> operands,
+  matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    complex::SignOp::Adaptor transformed(operands);
-    auto type = transformed.complex().getType().cast<ComplexType>();
+    auto type = adaptor.complex().getType().cast<ComplexType>();
     auto elementType = type.getElementType().cast<FloatType>();
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
-    Value real = b.create<complex::ReOp>(elementType, transformed.complex());
-    Value imag = b.create<complex::ImOp>(elementType, transformed.complex());
+    Value real = b.create<complex::ReOp>(elementType, adaptor.complex());
+    Value imag = b.create<complex::ImOp>(elementType, adaptor.complex());
     Value zero = b.create<ConstantOp>(elementType, b.getZeroAttr(elementType));
     Value realIsZero = b.create<CmpFOp>(CmpFPredicate::OEQ, real, zero);
     Value imagIsZero = b.create<CmpFOp>(CmpFPredicate::OEQ, imag, zero);
     Value isZero = b.create<AndOp>(realIsZero, imagIsZero);
-    auto abs = b.create<complex::AbsOp>(elementType, transformed.complex());
+    auto abs = b.create<complex::AbsOp>(elementType, adaptor.complex());
     Value realSign = b.create<DivFOp>(real, abs);
     Value imagSign = b.create<DivFOp>(imag, abs);
     Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
-    rewriter.replaceOpWithNewOp<SelectOp>(op, isZero, transformed.complex(),
-                                          sign);
+    rewriter.replaceOpWithNewOp<SelectOp>(op, isZero, adaptor.complex(), sign);
     return success();
   }
 };
index a303a87..88eca46 100644 (file)
@@ -33,7 +33,7 @@ public:
   using OpConversionPattern<SourceOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
+  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -45,7 +45,7 @@ public:
   using OpConversionPattern<SourceOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
+  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -58,7 +58,7 @@ public:
   using OpConversionPattern<gpu::BlockDimOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(gpu::BlockDimOp op, ArrayRef<Value> operands,
+  matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -68,7 +68,7 @@ public:
   using OpConversionPattern<gpu::GPUFuncOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(gpu::GPUFuncOp funcOp, ArrayRef<Value> operands,
+  matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 
 private:
@@ -81,7 +81,7 @@ public:
   using OpConversionPattern<gpu::GPUModuleOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(gpu::GPUModuleOp moduleOp, ArrayRef<Value> operands,
+  matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -91,7 +91,7 @@ public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(gpu::ModuleEndOp endOp, ArrayRef<Value> operands,
+  matchAndRewrite(gpu::ModuleEndOp endOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.eraseOp(endOp);
     return success();
@@ -105,7 +105,7 @@ public:
   using OpConversionPattern<gpu::ReturnOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(gpu::ReturnOp returnOp, ArrayRef<Value> operands,
+  matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -129,7 +129,7 @@ static Optional<int32_t> getLaunchConfigIndex(Operation *op) {
 
 template <typename SourceOp, spirv::BuiltIn builtin>
 LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
-    SourceOp op, ArrayRef<Value> operands,
+    SourceOp op, typename SourceOp::Adaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   auto index = getLaunchConfigIndex(op);
   if (!index)
@@ -150,7 +150,7 @@ LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
 template <typename SourceOp, spirv::BuiltIn builtin>
 LogicalResult
 SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
-    SourceOp op, ArrayRef<Value> operands,
+    SourceOp op, typename SourceOp::Adaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
   auto indexType = typeConverter->getIndexType();
@@ -162,7 +162,7 @@ SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
 }
 
 LogicalResult WorkGroupSizeConversion::matchAndRewrite(
-    gpu::BlockDimOp op, ArrayRef<Value> operands,
+    gpu::BlockDimOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   auto index = getLaunchConfigIndex(op);
   if (!index)
@@ -264,7 +264,7 @@ getDefaultABIAttrs(MLIRContext *context, gpu::GPUFuncOp funcOp,
 }
 
 LogicalResult GPUFuncOpConversion::matchAndRewrite(
-    gpu::GPUFuncOp funcOp, ArrayRef<Value> operands,
+    gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   if (!gpu::GPUDialect::isKernel(funcOp))
     return failure();
@@ -306,7 +306,7 @@ LogicalResult GPUFuncOpConversion::matchAndRewrite(
 //===----------------------------------------------------------------------===//
 
 LogicalResult GPUModuleConversion::matchAndRewrite(
-    gpu::GPUModuleOp moduleOp, ArrayRef<Value> operands,
+    gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnvOrDefault(moduleOp);
   spirv::AddressingModel addressingModel = spirv::getAddressingModel(targetEnv);
@@ -336,9 +336,9 @@ LogicalResult GPUModuleConversion::matchAndRewrite(
 //===----------------------------------------------------------------------===//
 
 LogicalResult GPUReturnOpConversion::matchAndRewrite(
-    gpu::ReturnOp returnOp, ArrayRef<Value> operands,
+    gpu::ReturnOp returnOp, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
-  if (!operands.empty())
+  if (!adaptor.getOperands().empty())
     return failure();
 
   rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
index bd1e4ad..a43ffbc 100644 (file)
@@ -55,7 +55,7 @@ struct SingleWorkgroupReduction final
   matchAsPerformingReduction(linalg::GenericOp genericOp);
 
   LogicalResult
-  matchAndRewrite(linalg::GenericOp genericOp, ArrayRef<Value> operands,
+  matchAndRewrite(linalg::GenericOp genericOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -109,7 +109,7 @@ SingleWorkgroupReduction::matchAsPerformingReduction(
 }
 
 LogicalResult SingleWorkgroupReduction::matchAndRewrite(
-    linalg::GenericOp genericOp, ArrayRef<Value> operands,
+    linalg::GenericOp genericOp, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Operation *op = genericOp.getOperation();
   auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
@@ -134,7 +134,8 @@ LogicalResult SingleWorkgroupReduction::matchAndRewrite(
   // TODO: Query the target environment to make sure the current
   // workload fits in a local workgroup.
 
-  Value convertedInput = operands[0], convertedOutput = operands[1];
+  Value convertedInput = adaptor.getOperands()[0];
+  Value convertedOutput = adaptor.getOperands()[1];
   Location loc = genericOp.getLoc();
 
   auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
index e30cbc0..04e8869 100644 (file)
@@ -37,9 +37,9 @@ public:
   using OpConversionPattern<StdOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
+  matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    assert(operands.size() <= 2);
+    assert(adaptor.getOperands().size() <= 2);
     auto dstType = this->getTypeConverter()->convertType(operation.getType());
     if (!dstType)
       return failure();
@@ -48,7 +48,8 @@ public:
       return operation.emitError(
           "bitwidth emulation is not implemented yet on unsigned op");
     }
-    rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, operands);
+    rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType,
+                                                  adaptor.getOperands());
     return success();
   }
 };
@@ -62,14 +63,15 @@ public:
   using OpConversionPattern<math::Log1pOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(math::Log1pOp operation, ArrayRef<Value> operands,
+  matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    assert(operands.size() == 1);
+    assert(adaptor.getOperands().size() == 1);
     Location loc = operation.getLoc();
     auto type =
         this->getTypeConverter()->convertType(operation.operand().getType());
     auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
-    auto onePlus = rewriter.create<spirv::FAddOp>(loc, one, operands[0]);
+    auto onePlus =
+        rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperands()[0]);
     rewriter.replaceOpWithNewOp<spirv::GLSLLogOp>(operation, type, onePlus);
     return success();
   }
index 9f9f115..7fb0f02 100644 (file)
@@ -158,7 +158,7 @@ public:
   using OpConversionPattern<memref::AllocOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(memref::AllocOp operation, ArrayRef<Value> operands,
+  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -169,7 +169,7 @@ public:
   using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(memref::DeallocOp operation, ArrayRef<Value> operands,
+  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -179,7 +179,7 @@ public:
   using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
+  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -189,7 +189,7 @@ public:
   using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
+  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -199,7 +199,7 @@ public:
   using OpConversionPattern<memref::StoreOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(memref::StoreOp storeOp, ArrayRef<Value> operands,
+  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -209,7 +209,7 @@ public:
   using OpConversionPattern<memref::StoreOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(memref::StoreOp storeOp, ArrayRef<Value> operands,
+  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -220,8 +220,7 @@ public:
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-AllocOpPattern::matchAndRewrite(memref::AllocOp operation,
-                                ArrayRef<Value> operands,
+AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
   MemRefType allocType = operation.getType();
   if (!isAllocationSupported(allocType))
@@ -260,7 +259,7 @@ AllocOpPattern::matchAndRewrite(memref::AllocOp operation,
 
 LogicalResult
 DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
-                                  ArrayRef<Value> operands,
+                                  OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
   MemRefType deallocType = operation.memref().getType().cast<MemRefType>();
   if (!isAllocationSupported(deallocType))
@@ -274,10 +273,8 @@ DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
-                                  ArrayRef<Value> operands,
+IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
-  memref::LoadOpAdaptor loadOperands(operands);
   auto loc = loadOp.getLoc();
   auto memrefType = loadOp.memref().getType().cast<MemRefType>();
   if (!memrefType.getElementType().isSignlessInteger())
@@ -285,8 +282,8 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
 
   auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
   spirv::AccessChainOp accessChainOp =
-      spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
-                           loadOperands.indices(), loc, rewriter);
+      spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(),
+                           adaptor.indices(), loc, rewriter);
 
   if (!accessChainOp)
     return failure();
@@ -372,15 +369,14 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
 }
 
 LogicalResult
-LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
+LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
-  memref::LoadOpAdaptor loadOperands(operands);
   auto memrefType = loadOp.memref().getType().cast<MemRefType>();
   if (memrefType.getElementType().isSignlessInteger())
     return failure();
   auto loadPtr = spirv::getElementPtr(
-      *getTypeConverter<SPIRVTypeConverter>(), memrefType,
-      loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter);
+      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.memref(),
+      adaptor.indices(), loadOp.getLoc(), rewriter);
 
   if (!loadPtr)
     return failure();
@@ -390,10 +386,8 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
 }
 
 LogicalResult
-IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
-                                   ArrayRef<Value> operands,
+IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
-  memref::StoreOpAdaptor storeOperands(operands);
   auto memrefType = storeOp.memref().getType().cast<MemRefType>();
   if (!memrefType.getElementType().isSignlessInteger())
     return failure();
@@ -401,8 +395,8 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
   auto loc = storeOp.getLoc();
   auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
   spirv::AccessChainOp accessChainOp =
-      spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
-                           storeOperands.indices(), loc, rewriter);
+      spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(),
+                           adaptor.indices(), loc, rewriter);
 
   if (!accessChainOp)
     return failure();
@@ -427,7 +421,7 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
   assert(dstBits % srcBits == 0);
 
   if (srcBits == dstBits) {
-    Value storeVal = storeOperands.value();
+    Value storeVal = adaptor.value();
     if (isBool)
       storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
     rewriter.replaceOpWithNewOp<spirv::StoreOp>(
@@ -458,7 +452,7 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
       rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
   clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
 
-  Value storeVal = storeOperands.value();
+  Value storeVal = adaptor.value();
   if (isBool)
     storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
   storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter);
@@ -487,23 +481,20 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
 }
 
 LogicalResult
-StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
-                                ArrayRef<Value> operands,
+StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
-  memref::StoreOpAdaptor storeOperands(operands);
   auto memrefType = storeOp.memref().getType().cast<MemRefType>();
   if (memrefType.getElementType().isSignlessInteger())
     return failure();
-  auto storePtr =
-      spirv::getElementPtr(*getTypeConverter<SPIRVTypeConverter>(), memrefType,
-                           storeOperands.memref(), storeOperands.indices(),
-                           storeOp.getLoc(), rewriter);
+  auto storePtr = spirv::getElementPtr(
+      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.memref(),
+      adaptor.indices(), storeOp.getLoc(), rewriter);
 
   if (!storePtr)
     return failure();
 
   rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
-                                              storeOperands.value());
+                                              adaptor.value());
   return success();
 }
 
index 08e3d3f..dc5bdc7 100644 (file)
@@ -84,7 +84,7 @@ public:
   using SCFToSPIRVPattern<scf::ForOp>::SCFToSPIRVPattern;
 
   LogicalResult
-  matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
+  matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -95,7 +95,7 @@ public:
   using SCFToSPIRVPattern<scf::IfOp>::SCFToSPIRVPattern;
 
   LogicalResult
-  matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
+  matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -104,7 +104,7 @@ public:
   using SCFToSPIRVPattern<scf::YieldOp>::SCFToSPIRVPattern;
 
   LogicalResult
-  matchAndRewrite(scf::YieldOp terminatorOp, ArrayRef<Value> operands,
+  matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 } // namespace
@@ -146,14 +146,13 @@ static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
+ForOpConversion::matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
   // scf::ForOp can be lowered to the structured control flow represented by
   // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
   // latch and the merge block the exit block. The resulting spirv::LoopOp has a
   // single back edge from the continue to header block, and a single exit from
   // header to merge.
-  scf::ForOpAdaptor forOperands(operands);
   auto loc = forOp.getLoc();
   auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
   loopOp.addEntryAndMergeBlock();
@@ -165,9 +164,8 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
   loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header);
 
   // Create the new induction variable to use.
-  BlockArgument newIndVar =
-      header->addArgument(forOperands.lowerBound().getType());
-  for (Value arg : forOperands.initArgs())
+  BlockArgument newIndVar = header->addArgument(adaptor.lowerBound().getType());
+  for (Value arg : adaptor.initArgs())
     header->addArgument(arg.getType());
   Block *body = forOp.getBody();
 
@@ -187,8 +185,8 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
   rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.body(),
                               std::next(loopOp.body().begin(), 2));
 
-  SmallVector<Value, 8> args(1, forOperands.lowerBound());
-  args.append(forOperands.initArgs().begin(), forOperands.initArgs().end());
+  SmallVector<Value, 8> args(1, adaptor.lowerBound());
+  args.append(adaptor.initArgs().begin(), adaptor.initArgs().end());
   // Branch into it from the entry.
   rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
   rewriter.create<spirv::BranchOp>(loc, header, args);
@@ -197,7 +195,7 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
   rewriter.setInsertionPointToEnd(header);
   auto *mergeBlock = loopOp.getMergeBlock();
   auto cmpOp = rewriter.create<spirv::SLessThanOp>(
-      loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound());
+      loc, rewriter.getI1Type(), newIndVar, adaptor.upperBound());
 
   rewriter.create<spirv::BranchConditionalOp>(
       loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
@@ -209,7 +207,7 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
 
   // Add the step to the induction variable and branch to the header.
   Value updatedIndVar = rewriter.create<spirv::IAddOp>(
-      loc, newIndVar.getType(), newIndVar, forOperands.step());
+      loc, newIndVar.getType(), newIndVar, adaptor.step());
   rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
 
   // Infer the return types from the init operands. Vector type may get
@@ -217,7 +215,7 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
   // extra logic to figure out the right type we just infer it from the Init
   // operands.
   SmallVector<Type, 8> initTypes;
-  for (auto arg : forOperands.initArgs())
+  for (auto arg : adaptor.initArgs())
     initTypes.push_back(arg.getType());
   replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext, initTypes);
   return success();
@@ -228,12 +226,11 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
+IfOpConversion::matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
   // When lowering `scf::IfOp` we explicitly create a selection header block
   // before the control flow diverges and a merge block where control flow
   // subsequently converges.
-  scf::IfOpAdaptor ifOperands(operands);
   auto loc = ifOp.getLoc();
 
   // Create `spv.selection` operation, selection header block and merge block.
@@ -267,7 +264,7 @@ IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
 
   // Create a `spv.BranchConditional` operation for selection header block.
   rewriter.setInsertionPointToEnd(selectionHeaderBlock);
-  rewriter.create<spirv::BranchConditionalOp>(loc, ifOperands.condition(),
+  rewriter.create<spirv::BranchConditionalOp>(loc, adaptor.condition(),
                                               thenBlock, ArrayRef<Value>(),
                                               elseBlock, ArrayRef<Value>());
 
@@ -289,8 +286,10 @@ IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
 /// parent region. For loops we also need to update the branch looping back to
 /// the header with the loop carried values.
 LogicalResult TerminatorOpConversion::matchAndRewrite(
-    scf::YieldOp terminatorOp, ArrayRef<Value> operands,
+    scf::YieldOp terminatorOp, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
+  ValueRange operands = adaptor.getOperands();
+
   // If the region is return values, store each value into the associated
   // VariableOp created during lowering of the parent region.
   if (!operands.empty()) {
index 8d957e0..348d8ad 100644 (file)
@@ -302,7 +302,7 @@ public:
   using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(spirv::AccessChainOp op, ArrayRef<Value> operands,
+  matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto dstType = typeConverter.convertType(op.component_ptr().getType());
     if (!dstType)
@@ -327,7 +327,7 @@ public:
   using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(spirv::AddressOfOp op, ArrayRef<Value> operands,
+  matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto dstType = typeConverter.convertType(op.pointer().getType());
     if (!dstType)
@@ -343,7 +343,7 @@ public:
   using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(spirv::BitFieldInsertOp op, ArrayRef<Value> operands,
+  matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto srcType = op.getType();
     auto dstType = typeConverter.convertType(srcType);
@@ -387,7 +387,7 @@ public:
   using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(spirv::ConstantOp constOp, ArrayRef<Value> operands,
+  matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto srcType = constOp.getType();
     if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat())
@@ -419,8 +419,8 @@ public:
       rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
       return success();
     }
-    rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, operands,
-                                                  constOp->getAttrs());
+    rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
+        constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
     return success();
   }
 };
@@ -431,7 +431,7 @@ public:
   using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(spirv::BitFieldSExtractOp op, ArrayRef<Value> operands,
+  matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto srcType = op.getType();
     auto dstType = typeConverter.convertType(srcType);
@@ -484,7 +484,7 @@ public:
   using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(spirv::BitFieldUExtractOp op, ArrayRef<Value> operands,
+  matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto srcType = op.getType();
     auto dstType = typeConverter.convertType(srcType);
@@ -518,9 +518,9 @@ public:
   using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(spirv::BranchOp branchOp, ArrayRef<Value> operands,
+  matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, operands,
+    rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
                                             branchOp.getTarget());
     return success();
   }
@@ -533,7 +533,7 @@ public:
       spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(spirv::BranchConditionalOp op, ArrayRef<Value> operands,
+  matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // If branch weights exist, map them to 32-bit integer vector.
     ElementsAttr branchWeights = nullptr;
@@ -560,7 +560,7 @@ public:
   using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(spirv::CompositeExtractOp op, ArrayRef<Value> operands,
+  matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto dstType = this->typeConverter.convertType(op.getType());
     if (!dstType)
@@ -590,7 +590,7 @@ public:
   using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(spirv::CompositeInsertOp op, ArrayRef<Value> operands,
+  matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto dstType = this->typeConverter.convertType(op.getType());
     if (!dstType)
@@ -619,13 +619,13 @@ public:
   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
+  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto dstType = this->typeConverter.convertType(operation.getType());
     if (!dstType)
       return failure();
-    rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType, operands,
-                                                 operation->getAttrs());
+    rewriter.template replaceOpWithNewOp<LLVMOp>(
+        operation, dstType, adaptor.getOperands(), operation->getAttrs());
     return success();
   }
 };
@@ -638,7 +638,7 @@ public:
   using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(spirv::ExecutionModeOp op, ArrayRef<Value> operands,
+  matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // First, create the global struct's name that would be associated with
     // this entry point's execution mode. We set it to be:
@@ -717,7 +717,7 @@ public:
   using SPIRVToLLVMConversion<spirv::GlobalVariableOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(spirv::GlobalVariableOp op, ArrayRef<Value> operands,
+  matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Currently, there is no support of initialization with a constant value in
     // SPIR-V dialect. Specialization constants are not considered as well.
@@ -767,7 +767,7 @@ public:
   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
+  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
     Type fromType = operation.operand().getType();
@@ -779,12 +779,12 @@ public:
 
     if (getBitWidth(fromType) < getBitWidth(toType)) {
       rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType,
-                                                      operands);
+                                                      adaptor.getOperands());
       return success();
     }
     if (getBitWidth(fromType) > getBitWidth(toType)) {
       rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType,
-                                                        operands);
+                                                        adaptor.getOperands());
       return success();
     }
     return failure();
@@ -797,18 +797,18 @@ public:
   using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(spirv::FunctionCallOp callOp, ArrayRef<Value> operands,
+  matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (callOp.getNumResults() == 0) {
-      rewriter.replaceOpWithNewOp<LLVM::CallOp>(callOp, llvm::None, operands,
-                                                callOp->getAttrs());
+      rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+          callOp, llvm::None, adaptor.getOperands(), callOp->getAttrs());
       return success();
     }
 
     // Function returns a single result.
     auto dstType = typeConverter.convertType(callOp.getType(0));
-    rewriter.replaceOpWithNewOp<LLVM::CallOp>(callOp, dstType, operands,
-                                              callOp->getAttrs());
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+        callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
     return success();
   }
 };
@@ -820,7 +820,7 @@ public:
   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
+  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
     auto dstType = this->typeConverter.convertType(operation.getType());
@@ -841,7 +841,7 @@ public:
   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
+  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
     auto dstType = this->typeConverter.convertType(operation.getType());
@@ -861,7 +861,7 @@ public:
   using SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(spirv::GLSLInverseSqrtOp op, ArrayRef<Value> operands,
+  matchAndRewrite(spirv::GLSLInverseSqrtOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto srcType = op.getType();
     auto dstType = typeConverter.convertType(srcType);
@@ -877,15 +877,14 @@ public:
 };
 
 /// Converts `spv.Load` and `spv.Store` to LLVM dialect.
-template <typename SPIRVop>
-class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVop> {
+template <typename SPIRVOp>
+class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
 public:
-  using SPIRVToLLVMConversion<SPIRVop>::SPIRVToLLVMConversion;
+  using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(SPIRVop op, ArrayRef<Value> operands,
+  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
     if (!op.memory_access().hasValue()) {
       return replaceWithLoadOrStore(
           op, rewriter, this->typeConverter, /*alignment=*/0,
@@ -918,9 +917,8 @@ public:
   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(SPIRVOp notOp, ArrayRef<Value> operands,
+  matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
     auto srcType = notOp.getType();
     auto dstType = this->typeConverter.convertType(srcType);
     if (!dstType)
@@ -947,7 +945,7 @@ public:
   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(SPIRVOp op, ArrayRef<Value> operands,
+  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.eraseOp(op);
     return success();
@@ -959,7 +957,7 @@ public:
   using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(spirv::ReturnOp returnOp, ArrayRef<Value> operands,
+  matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
                                                 ArrayRef<Value>());
@@ -972,10 +970,10 @@ public:
   using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(spirv::ReturnValueOp returnValueOp, ArrayRef<Value> operands,
+  matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
-                                                operands);
+                                                adaptor.getOperands());
     return success();
   }
 };
@@ -1033,7 +1031,7 @@ public:
   using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(spirv::LoopOp loopOp, ArrayRef<Value> operands,
+  matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // There is no support of loop control at the moment.
     if (loopOp.loop_control() != spirv::LoopControl::None)
@@ -1080,7 +1078,7 @@ public:
   using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(spirv::SelectionOp op, ArrayRef<Value> operands,
+  matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // There is no support for `Flatten` or `DontFlatten` selection control at
     // the moment. This are just compiler hints and can be performed during the
@@ -1149,7 +1147,7 @@ public:
   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
+  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
     auto dstType = this->typeConverter.convertType(operation.getType());
@@ -1161,7 +1159,7 @@ public:
 
     if (op1Type == op2Type) {
       rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
-                                                   operands);
+                                                   adaptor.getOperands());
       return success();
     }
 
@@ -1186,7 +1184,7 @@ public:
   using SPIRVToLLVMConversion<spirv::GLSLTanOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(spirv::GLSLTanOp tanOp, ArrayRef<Value> operands,
+  matchAndRewrite(spirv::GLSLTanOp tanOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto dstType = typeConverter.convertType(tanOp.getType());
     if (!dstType)
@@ -1211,7 +1209,7 @@ public:
   using SPIRVToLLVMConversion<spirv::GLSLTanhOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(spirv::GLSLTanhOp tanhOp, ArrayRef<Value> operands,
+  matchAndRewrite(spirv::GLSLTanhOp tanhOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto srcType = tanhOp.getType();
     auto dstType = typeConverter.convertType(srcType);
@@ -1239,7 +1237,7 @@ public:
   using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(spirv::VariableOp varOp, ArrayRef<Value> operands,
+  matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto srcType = varOp.getType();
     // Initialization is supported for scalars and vectors only.
@@ -1274,7 +1272,7 @@ public:
   using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(spirv::FuncOp funcOp, ArrayRef<Value> operands,
+  matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
     // Convert function signature. At the moment LLVMType converter is enough
@@ -1337,7 +1335,7 @@ public:
   using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(spirv::ModuleOp spvModuleOp, ArrayRef<Value> operands,
+  matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
     auto newModuleOp =
index b662b52..f622d5e 100644 (file)
@@ -29,19 +29,17 @@ public:
   using OpConversionPattern<AnyOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
+  matchAndRewrite(AnyOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 } // namespace
 
 LogicalResult
-AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
+AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
-  AnyOp::Adaptor transformed(operands);
-
   // Replace `any` with its first operand.
   // Any operand would be a valid substitution.
-  rewriter.replaceOp(op, {transformed.inputs().front()});
+  rewriter.replaceOp(op, {adaptor.inputs().front()});
   return success();
 }
 
@@ -52,16 +50,13 @@ public:
   using OpConversionPattern<SrcOpTy>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
+  matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    typename SrcOpTy::Adaptor transformed(operands);
-
     // For now, only error-free types are supported by this lowering.
     if (op.getType().template isa<SizeType>())
       return failure();
 
-    rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(),
-                                         transformed.rhs());
+    rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.lhs(), adaptor.rhs());
     return success();
   }
 };
@@ -72,7 +67,7 @@ struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
   using OpConversionPattern<BroadcastOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
+  matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -120,7 +115,7 @@ Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
 } // namespace
 
 LogicalResult BroadcastOpConverter::matchAndRewrite(
-    BroadcastOp op, ArrayRef<Value> operands,
+    BroadcastOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
   // on shapes.
@@ -129,7 +124,6 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
 
   auto loc = op.getLoc();
   ImplicitLocOpBuilder lb(loc, rewriter);
-  BroadcastOp::Adaptor transformed(operands);
 
   Value zero = lb.create<ConstantIndexOp>(0);
   Type indexTy = lb.getIndexType();
@@ -138,7 +132,7 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
   // representing the shape extents, the rank is the extent of the only
   // dimension in the tensor.
   SmallVector<Value> ranks, rankDiffs;
-  llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
+  llvm::append_range(ranks, llvm::map_range(adaptor.shapes(), [&](Value v) {
                        return lb.create<tensor::DimOp>(v, zero);
                      }));
 
@@ -157,9 +151,8 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
   Value replacement = lb.create<tensor::GenerateOp>(
       getExtentTensorType(lb.getContext()), ValueRange{maxRank},
       [&](OpBuilder &b, Location loc, ValueRange args) {
-        Value broadcastedDim =
-            getBroadcastedDim(ImplicitLocOpBuilder(loc, b),
-                              transformed.shapes(), rankDiffs, args[0]);
+        Value broadcastedDim = getBroadcastedDim(
+            ImplicitLocOpBuilder(loc, b), adaptor.shapes(), rankDiffs, args[0]);
 
         b.create<tensor::YieldOp>(loc, broadcastedDim);
       });
@@ -175,13 +168,13 @@ public:
   using OpConversionPattern<ConstShapeOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 } // namespace
 
 LogicalResult ConstShapeOpConverter::matchAndRewrite(
-    ConstShapeOp op, ArrayRef<Value> operands,
+    ConstShapeOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
 
   // For now, this lowering supports only extent tensors, not `shape.shape`
@@ -209,13 +202,13 @@ public:
   using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 } // namespace
 
 LogicalResult ConstSizeOpConversion::matchAndRewrite(
-    ConstSizeOp op, ArrayRef<Value> operands,
+    ConstSizeOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
   return success();
@@ -227,17 +220,16 @@ struct IsBroadcastableOpConverter
   using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(IsBroadcastableOp op, ArrayRef<Value> operands,
+  matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 } // namespace
 
 LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
-    IsBroadcastableOp op, ArrayRef<Value> operands,
+    IsBroadcastableOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
   // on shapes.
-  IsBroadcastableOp::Adaptor transformed(operands);
   if (!llvm::all_of(op.shapes(),
                     [](Value v) { return !v.getType().isa<ShapeType>(); }))
     return failure();
@@ -252,7 +244,7 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
   // representing the shape extents, the rank is the extent of the only
   // dimension in the tensor.
   SmallVector<Value> ranks, rankDiffs;
-  llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
+  llvm::append_range(ranks, llvm::map_range(adaptor.shapes(), [&](Value v) {
                        return lb.create<tensor::DimOp>(v, zero);
                      }));
 
@@ -279,10 +271,10 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
         // could reuse the Broadcast lowering entirely, but we redo the work
         // here to make optimizations easier between the two loops.
         Value broadcastedDim = getBroadcastedDim(
-            ImplicitLocOpBuilder(loc, b), transformed.shapes(), rankDiffs, iv);
+            ImplicitLocOpBuilder(loc, b), adaptor.shapes(), rankDiffs, iv);
 
         Value broadcastable = iterArgs[0];
-        for (auto tup : llvm::zip(transformed.shapes(), rankDiffs)) {
+        for (auto tup : llvm::zip(adaptor.shapes(), rankDiffs)) {
           Value shape, rankDiff;
           std::tie(shape, rankDiff) = tup;
           Value outOfBounds =
@@ -327,16 +319,14 @@ class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
   using OpConversionPattern<GetExtentOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
+  matchAndRewrite(GetExtentOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 } // namespace
 
 LogicalResult GetExtentOpConverter::matchAndRewrite(
-    GetExtentOp op, ArrayRef<Value> operands,
+    GetExtentOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
-  GetExtentOp::Adaptor transformed(operands);
-
   // For now, only error-free types are supported by this lowering.
   if (op.getType().isa<SizeType>())
     return failure();
@@ -346,14 +336,13 @@ LogicalResult GetExtentOpConverter::matchAndRewrite(
   if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
     if (shapeOfOp.arg().getType().isa<ShapedType>()) {
       rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.arg(),
-                                                 transformed.dim());
+                                                 adaptor.dim());
       return success();
     }
   }
 
-  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
-                                                 transformed.shape(),
-                                                 ValueRange{transformed.dim()});
+  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
+      op, rewriter.getIndexType(), adaptor.shape(), ValueRange{adaptor.dim()});
   return success();
 }
 
@@ -363,20 +352,19 @@ public:
   using OpConversionPattern<shape::RankOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
+  matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 } // namespace
 
 LogicalResult
-RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
+RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
   // For now, this lowering supports only error-free types.
   if (op.getType().isa<SizeType>())
     return failure();
 
-  shape::RankOp::Adaptor transformed(operands);
-  rewriter.replaceOpWithNewOp<tensor::DimOp>(op, transformed.shape(), 0);
+  rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.shape(), 0);
   return success();
 }
 
@@ -387,32 +375,30 @@ public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
+  matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final;
 };
 } // namespace
 
 LogicalResult
-ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
+ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
   // For now, this lowering is only defined on `tensor<?xindex>` operands.
   if (op.shape().getType().isa<ShapeType>())
     return failure();
 
   auto loc = op.getLoc();
-  shape::ReduceOp::Adaptor transformed(operands);
 
   Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
   Value one = rewriter.create<ConstantIndexOp>(loc, 1);
   Type indexTy = rewriter.getIndexType();
   Value rank =
-      rewriter.create<tensor::DimOp>(loc, indexTy, transformed.shape(), zero);
+      rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.shape(), zero);
 
   auto loop = rewriter.create<scf::ForOp>(
       loc, zero, rank, one, op.initVals(),
       [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
-        Value extent =
-            b.create<tensor::ExtractOp>(loc, transformed.shape(), iv);
+        Value extent = b.create<tensor::ExtractOp>(loc, adaptor.shape(), iv);
 
         SmallVector<Value, 2> mappedValues{iv, extent};
         mappedValues.append(args.begin(), args.end());
@@ -468,13 +454,13 @@ struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
   using OpConversionPattern<ShapeEqOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 } // namespace
 
 LogicalResult
-ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
+ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                                     ConversionPatternRewriter &rewriter) const {
   if (!llvm::all_of(op.shapes(),
                     [](Value v) { return !v.getType().isa<ShapeType>(); }))
@@ -487,16 +473,15 @@ ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
     return success();
   }
 
-  ShapeEqOp::Adaptor transformed(operands);
   auto loc = op.getLoc();
   Type indexTy = rewriter.getIndexType();
   Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
-  Value firstShape = transformed.shapes().front();
+  Value firstShape = adaptor.shapes().front();
   Value firstRank =
       rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero);
   Value result = nullptr;
   // Generate a linear sequence of compares, all with firstShape as lhs.
-  for (Value shape : transformed.shapes().drop_front(1)) {
+  for (Value shape : adaptor.shapes().drop_front(1)) {
     Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero);
     Value eqRank =
         rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, firstRank, rank);
@@ -536,13 +521,13 @@ public:
   using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 } // namespace
 
 LogicalResult ShapeOfOpConversion::matchAndRewrite(
-    ShapeOfOp op, ArrayRef<Value> operands,
+    ShapeOfOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
 
   // For now, only error-free types are supported by this lowering.
@@ -551,8 +536,7 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
 
   // For ranked tensor arguments, lower to `tensor.from_elements`.
   auto loc = op.getLoc();
-  ShapeOfOp::Adaptor transformed(operands);
-  Value tensor = transformed.arg();
+  Value tensor = adaptor.arg();
   Type tensorTy = tensor.getType();
   if (tensorTy.isa<RankedTensorType>()) {
 
@@ -599,13 +583,13 @@ public:
   using OpConversionPattern<SplitAtOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(SplitAtOp op, ArrayRef<Value> operands,
+  matchAndRewrite(SplitAtOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 } // namespace
 
 LogicalResult SplitAtOpConversion::matchAndRewrite(
-    SplitAtOp op, ArrayRef<Value> operands,
+    SplitAtOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   // Error conditions are not implemented, only lower if all operands and
   // results are extent tensors.
@@ -613,13 +597,12 @@ LogicalResult SplitAtOpConversion::matchAndRewrite(
                    [](Value v) { return v.getType().isa<ShapeType>(); }))
     return failure();
 
-  SplitAtOp::Adaptor transformed(op);
   ImplicitLocOpBuilder b(op.getLoc(), rewriter);
   Value zero = b.create<ConstantIndexOp>(0);
-  Value rank = b.create<tensor::DimOp>(transformed.operand(), zero);
+  Value rank = b.create<tensor::DimOp>(adaptor.operand(), zero);
 
   // index < 0 ? index + rank : index
-  Value originalIndex = transformed.index();
+  Value originalIndex = adaptor.index();
   Value add = b.create<AddIOp>(originalIndex, rank);
   Value indexIsNegative =
       b.create<CmpIOp>(CmpIPredicate::slt, originalIndex, zero);
@@ -627,10 +610,10 @@ LogicalResult SplitAtOpConversion::matchAndRewrite(
 
   Value one = b.create<ConstantIndexOp>(1);
   Value head =
-      b.create<tensor::ExtractSliceOp>(transformed.operand(), zero, index, one);
+      b.create<tensor::ExtractSliceOp>(adaptor.operand(), zero, index, one);
   Value tailSize = b.create<SubIOp>(rank, index);
-  Value tail = b.create<tensor::ExtractSliceOp>(transformed.operand(), index,
-                                                tailSize, one);
+  Value tail =
+      b.create<tensor::ExtractSliceOp>(adaptor.operand(), index, tailSize, one);
   rewriter.replaceOp(op, {head, tail});
   return success();
 }
@@ -642,10 +625,8 @@ public:
   using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    ToExtentTensorOpAdaptor adaptor(operands);
-
     if (!adaptor.input().getType().isa<RankedTensorType>())
       return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
 
index 7a59330..8327492 100644 (file)
@@ -292,7 +292,7 @@ struct FuncOpConversion : public FuncOpConversionBase {
       : FuncOpConversionBase(converter) {}
 
   LogicalResult
-  matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
+  matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
     if (!newFuncOp)
@@ -319,7 +319,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
   using FuncOpConversionBase::FuncOpConversionBase;
 
   LogicalResult
-  matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
+  matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
     // TODO: bare ptr conversion could be handled by argument materialization
@@ -442,10 +442,9 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<AssertOp> {
   using ConvertOpToLLVMPattern<AssertOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(AssertOp op, ArrayRef<Value> operands,
+  matchAndRewrite(AssertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
-    AssertOp::Adaptor transformed(operands);
 
     // Insert the `abort` declaration if necessary.
     auto module = op->getParentOfType<ModuleOp>();
@@ -471,7 +470,7 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<AssertOp> {
     // Generate assertion test.
     rewriter.setInsertionPointToEnd(opBlock);
     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
-        op, transformed.arg(), continuationBlock, failureBlock);
+        op, adaptor.arg(), continuationBlock, failureBlock);
 
     return success();
   }
@@ -481,7 +480,7 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
   using ConvertOpToLLVMPattern<ConstantOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ConstantOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // If constant refers to a function, convert it to "addressof".
     if (auto symbolRef = op.getValue().dyn_cast<FlatSymbolRefAttr>()) {
@@ -506,8 +505,8 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
           op, "referring to a symbol outside of the current module");
 
     return LLVM::detail::oneToOneRewrite(
-        op, LLVM::ConstantOp::getOperationName(), operands, *getTypeConverter(),
-        rewriter);
+        op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
+        *getTypeConverter(), rewriter);
   }
 };
 
@@ -520,10 +519,8 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
   using Base = ConvertOpToLLVMPattern<CallOpType>;
 
   LogicalResult
-  matchAndRewrite(CallOpType callOp, ArrayRef<Value> operands,
+  matchAndRewrite(CallOpType callOp, typename CallOpType::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    typename CallOpType::Adaptor transformed(operands);
-
     // Pack the result types into a struct.
     Type packedResult = nullptr;
     unsigned numResults = callOp.getNumResults();
@@ -536,8 +533,8 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
     }
 
     auto promoted = this->getTypeConverter()->promoteOperands(
-        callOp.getLoc(), /*opOperands=*/callOp->getOperands(), operands,
-        rewriter);
+        callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
+        adaptor.getOperands(), rewriter);
     auto newOp = rewriter.create<LLVM::CallOp>(
         callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
         promoted, callOp->getAttrs());
@@ -591,22 +588,21 @@ struct UnrealizedConversionCastOpLowering
       UnrealizedConversionCastOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(UnrealizedConversionCastOp op, ArrayRef<Value> operands,
+  matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    UnrealizedConversionCastOp::Adaptor transformed(operands);
     SmallVector<Type> convertedTypes;
     if (succeeded(typeConverter->convertTypes(op.outputs().getTypes(),
                                               convertedTypes)) &&
-        convertedTypes == transformed.inputs().getTypes()) {
-      rewriter.replaceOp(op, transformed.inputs());
+        convertedTypes == adaptor.inputs().getTypes()) {
+      rewriter.replaceOp(op, adaptor.inputs());
       return success();
     }
 
     convertedTypes.clear();
-    if (succeeded(typeConverter->convertTypes(transformed.inputs().getTypes(),
+    if (succeeded(typeConverter->convertTypes(adaptor.inputs().getTypes(),
                                               convertedTypes)) &&
         convertedTypes == op.outputs().getType()) {
-      rewriter.replaceOp(op, transformed.inputs());
+      rewriter.replaceOp(op, adaptor.inputs());
       return success();
     }
     return failure();
@@ -617,12 +613,12 @@ struct RankOpLowering : public ConvertOpToLLVMPattern<RankOp> {
   using ConvertOpToLLVMPattern<RankOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(RankOp op, ArrayRef<Value> operands,
+  matchAndRewrite(RankOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Type operandType = op.memrefOrTensor().getType();
     if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
-      UnrankedMemRefDescriptor desc(RankOp::Adaptor(operands).memrefOrTensor());
+      UnrankedMemRefDescriptor desc(adaptor.memrefOrTensor());
       rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
       return success();
     }
@@ -658,10 +654,8 @@ struct IndexCastOpLowering : public ConvertOpToLLVMPattern<IndexCastOp> {
   using ConvertOpToLLVMPattern<IndexCastOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(IndexCastOp indexCastOp, ArrayRef<Value> operands,
+  matchAndRewrite(IndexCastOp indexCastOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    IndexCastOpAdaptor transformed(operands);
-
     auto targetType =
         typeConverter->convertType(indexCastOp.getResult().getType());
     auto targetElementType =
@@ -669,18 +663,18 @@ struct IndexCastOpLowering : public ConvertOpToLLVMPattern<IndexCastOp> {
             ->convertType(getElementTypeOrSelf(indexCastOp.getResult()))
             .cast<IntegerType>();
     auto sourceElementType =
-        getElementTypeOrSelf(transformed.in()).cast<IntegerType>();
+        getElementTypeOrSelf(adaptor.in()).cast<IntegerType>();
     unsigned targetBits = targetElementType.getWidth();
     unsigned sourceBits = sourceElementType.getWidth();
 
     if (targetBits == sourceBits)
-      rewriter.replaceOp(indexCastOp, transformed.in());
+      rewriter.replaceOp(indexCastOp, adaptor.in());
     else if (targetBits < sourceBits)
       rewriter.replaceOpWithNewOp<LLVM::TruncOp>(indexCastOp, targetType,
-                                                 transformed.in());
+                                                 adaptor.in());
     else
       rewriter.replaceOpWithNewOp<LLVM::SExtOp>(indexCastOp, targetType,
-                                                transformed.in());
+                                                adaptor.in());
     return success();
   }
 };
@@ -696,10 +690,9 @@ struct CmpIOpLowering : public ConvertOpToLLVMPattern<CmpIOp> {
   using ConvertOpToLLVMPattern<CmpIOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(CmpIOp cmpiOp, ArrayRef<Value> operands,
+  matchAndRewrite(CmpIOp cmpiOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    CmpIOpAdaptor transformed(operands);
-    auto operandType = transformed.lhs().getType();
+    auto operandType = adaptor.lhs().getType();
     auto resultType = cmpiOp.getResult().getType();
 
     // Handle the scalar and 1D vector cases.
@@ -707,7 +700,7 @@ struct CmpIOpLowering : public ConvertOpToLLVMPattern<CmpIOp> {
       rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
           cmpiOp, typeConverter->convertType(resultType),
           convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()),
-          transformed.lhs(), transformed.rhs());
+          adaptor.lhs(), adaptor.rhs());
       return success();
     }
 
@@ -716,13 +709,13 @@ struct CmpIOpLowering : public ConvertOpToLLVMPattern<CmpIOp> {
       return rewriter.notifyMatchFailure(cmpiOp, "expected vector result type");
 
     return LLVM::detail::handleMultidimensionalVectors(
-        cmpiOp.getOperation(), operands, *getTypeConverter(),
+        cmpiOp.getOperation(), adaptor.getOperands(), *getTypeConverter(),
         [&](Type llvm1DVectorTy, ValueRange operands) {
-          CmpIOpAdaptor transformed(operands);
+          CmpIOpAdaptor adaptor(operands);
           return rewriter.create<LLVM::ICmpOp>(
               cmpiOp.getLoc(), llvm1DVectorTy,
               convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()),
-              transformed.lhs(), transformed.rhs());
+              adaptor.lhs(), adaptor.rhs());
         },
         rewriter);
 
@@ -734,10 +727,9 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
   using ConvertOpToLLVMPattern<CmpFOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(CmpFOp cmpfOp, ArrayRef<Value> operands,
+  matchAndRewrite(CmpFOp cmpfOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    CmpFOpAdaptor transformed(operands);
-    auto operandType = transformed.lhs().getType();
+    auto operandType = adaptor.lhs().getType();
     auto resultType = cmpfOp.getResult().getType();
 
     // Handle the scalar and 1D vector cases.
@@ -745,7 +737,7 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
       rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
           cmpfOp, typeConverter->convertType(resultType),
           convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()),
-          transformed.lhs(), transformed.rhs());
+          adaptor.lhs(), adaptor.rhs());
       return success();
     }
 
@@ -754,13 +746,13 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
       return rewriter.notifyMatchFailure(cmpfOp, "expected vector result type");
 
     return LLVM::detail::handleMultidimensionalVectors(
-        cmpfOp.getOperation(), operands, *getTypeConverter(),
+        cmpfOp.getOperation(), adaptor.getOperands(), *getTypeConverter(),
         [&](Type llvm1DVectorTy, ValueRange operands) {
-          CmpFOpAdaptor transformed(operands);
+          CmpFOpAdaptor adaptor(operands);
           return rewriter.create<LLVM::FCmpOp>(
               cmpfOp.getLoc(), llvm1DVectorTy,
               convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()),
-              transformed.lhs(), transformed.rhs());
+              adaptor.lhs(), adaptor.rhs());
         },
         rewriter);
   }
@@ -774,10 +766,10 @@ struct OneToOneLLVMTerminatorLowering
   using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
 
   LogicalResult
-  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
+  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<TargetOp>(op, operands, op->getSuccessors(),
-                                          op->getAttrs());
+    rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(),
+                                          op->getSuccessors(), op->getAttrs());
     return success();
   }
 };
@@ -792,7 +784,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
   using ConvertOpToLLVMPattern<ReturnOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     unsigned numArguments = op.getNumOperands();
@@ -801,7 +793,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
     if (getTypeConverter()->getOptions().useBarePtrCallConv) {
       // For the bare-ptr calling convention, extract the aligned pointer to
       // be returned from the memref descriptor.
-      for (auto it : llvm::zip(op->getOperands(), operands)) {
+      for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
         Type oldTy = std::get<0>(it).getType();
         Value newOperand = std::get<1>(it);
         if (oldTy.isa<MemRefType>()) {
@@ -815,7 +807,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
         updatedOperands.push_back(newOperand);
       }
     } else {
-      updatedOperands = llvm::to_vector<4>(operands);
+      updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
       (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
                                     updatedOperands,
                                     /*toDynamic=*/true);
@@ -870,14 +862,12 @@ struct SplatOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
   using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(SplatOp splatOp, ArrayRef<Value> operands,
+  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
     if (!resultType || resultType.getRank() != 1)
       return failure();
 
-    SplatOp::Adaptor adaptor(operands);
-
     // First insert it into an undef vector so we can shuffle it.
     auto vectorType = typeConverter->convertType(splatOp.getType());
     Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
@@ -907,9 +897,8 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
   using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(SplatOp splatOp, ArrayRef<Value> operands,
+  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    SplatOp::Adaptor adaptor(operands);
     VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
     if (!resultType || resultType.getRank() == 1)
       return failure();
@@ -984,14 +973,13 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
   using Base::Base;
 
   LogicalResult
-  matchAndRewrite(AtomicRMWOp atomicOp, ArrayRef<Value> operands,
+  matchAndRewrite(AtomicRMWOp atomicOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (failed(match(atomicOp)))
       return failure();
     auto maybeKind = matchSimpleAtomicOp(atomicOp);
     if (!maybeKind)
       return failure();
-    AtomicRMWOp::Adaptor adaptor(operands);
     auto resultType = adaptor.value().getType();
     auto memRefType = atomicOp.getMemRefType();
     auto dataPtr =
@@ -1036,11 +1024,10 @@ struct GenericAtomicRMWOpLowering
   using Base::Base;
 
   LogicalResult
-  matchAndRewrite(GenericAtomicRMWOp atomicOp, ArrayRef<Value> operands,
+  matchAndRewrite(GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
     auto loc = atomicOp.getLoc();
-    GenericAtomicRMWOp::Adaptor adaptor(operands);
     Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
 
     // Split the block into initial, loop, and ending parts.
index 0da2209..fe8b925 100644 (file)
@@ -144,9 +144,9 @@ public:
   using OpConversionPattern<StdOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
+  matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    assert(operands.size() <= 2);
+    assert(adaptor.getOperands().size() <= 2);
     auto dstType = this->getTypeConverter()->convertType(operation.getType());
     if (!dstType)
       return failure();
@@ -155,7 +155,8 @@ public:
       return operation.emitError(
           "bitwidth emulation is not implemented yet on unsigned op");
     }
-    rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, operands);
+    rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType,
+                                                  adaptor.getOperands());
     return success();
   }
 };
@@ -169,7 +170,7 @@ public:
   using OpConversionPattern<SignedRemIOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(SignedRemIOp remOp, ArrayRef<Value> operands,
+  matchAndRewrite(SignedRemIOp remOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -183,19 +184,19 @@ public:
   using OpConversionPattern<StdOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
+  matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    assert(operands.size() == 2);
+    assert(adaptor.getOperands().size() == 2);
     auto dstType =
         this->getTypeConverter()->convertType(operation.getResult().getType());
     if (!dstType)
       return failure();
-    if (isBoolScalarOrVector(operands.front().getType())) {
-      rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(operation, dstType,
-                                                           operands);
+    if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
+      rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
+          operation, dstType, adaptor.getOperands());
     } else {
-      rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(operation, dstType,
-                                                           operands);
+      rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
+          operation, dstType, adaptor.getOperands());
     }
     return success();
   }
@@ -208,7 +209,7 @@ public:
   using OpConversionPattern<ConstantOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ConstantOp constOp, ArrayRef<Value> operands,
+  matchAndRewrite(ConstantOp constOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -218,7 +219,7 @@ public:
   using OpConversionPattern<ConstantOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ConstantOp constOp, ArrayRef<Value> operands,
+  matchAndRewrite(ConstantOp constOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -228,7 +229,7 @@ public:
   using OpConversionPattern<CmpFOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
+  matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -239,7 +240,7 @@ public:
   using OpConversionPattern<CmpFOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
+  matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -250,7 +251,7 @@ public:
   using OpConversionPattern<CmpFOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
+  matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -260,7 +261,7 @@ public:
   using OpConversionPattern<CmpIOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
+  matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -270,7 +271,7 @@ public:
   using OpConversionPattern<CmpIOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
+  matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -280,7 +281,7 @@ public:
   using OpConversionPattern<ReturnOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
+  matchAndRewrite(ReturnOp returnOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -289,7 +290,7 @@ class SelectOpPattern final : public OpConversionPattern<SelectOp> {
 public:
   using OpConversionPattern<SelectOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
+  matchAndRewrite(SelectOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -299,7 +300,7 @@ public:
   using OpConversionPattern<SplatOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(SplatOp op, ArrayRef<Value> operands,
+  matchAndRewrite(SplatOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -310,9 +311,9 @@ public:
   using OpConversionPattern<ZeroExtendIOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ZeroExtendIOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ZeroExtendIOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto srcType = operands.front().getType();
+    auto srcType = adaptor.getOperands().front().getType();
     if (!isBoolScalarOrVector(srcType))
       return failure();
 
@@ -322,7 +323,7 @@ public:
     Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
     Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
     rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
-        op, dstType, operands.front(), one, zero);
+        op, dstType, adaptor.getOperands().front(), one, zero);
     return success();
   }
 };
@@ -338,7 +339,7 @@ public:
         byteCountThreshold(threshold) {}
 
   LogicalResult
-  matchAndRewrite(tensor::ExtractOp extractOp, ArrayRef<Value> operands,
+  matchAndRewrite(tensor::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     TensorType tensorType = extractOp.tensor().getType().cast<TensorType>();
 
@@ -351,7 +352,6 @@ public:
                                          "exceeding byte count threshold");
 
     Location loc = extractOp.getLoc();
-    tensor::ExtractOp::Adaptor adaptor(operands);
 
     int64_t rank = tensorType.getRank();
     SmallVector<int64_t, 4> strides(rank, 1);
@@ -396,7 +396,7 @@ public:
   using OpConversionPattern<TruncateIOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(TruncateIOp op, ArrayRef<Value> operands,
+  matchAndRewrite(TruncateIOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto dstType =
         this->getTypeConverter()->convertType(op.getResult().getType());
@@ -404,11 +404,11 @@ public:
       return failure();
 
     Location loc = op.getLoc();
-    auto srcType = operands.front().getType();
+    auto srcType = adaptor.getOperands().front().getType();
     // Check if (x & 1) == 1.
     Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
-    Value maskedSrc =
-        rewriter.create<spirv::BitwiseAndOp>(loc, srcType, operands[0], mask);
+    Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
+        loc, srcType, adaptor.getOperands()[0], mask);
     Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
 
     Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
@@ -425,9 +425,9 @@ public:
   using OpConversionPattern<UIToFPOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(UIToFPOp op, ArrayRef<Value> operands,
+  matchAndRewrite(UIToFPOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto srcType = operands.front().getType();
+    auto srcType = adaptor.getOperands().front().getType();
     if (!isBoolScalarOrVector(srcType))
       return failure();
 
@@ -437,7 +437,7 @@ public:
     Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
     Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
     rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
-        op, dstType, operands.front(), one, zero);
+        op, dstType, adaptor.getOperands().front(), one, zero);
     return success();
   }
 };
@@ -449,10 +449,10 @@ public:
   using OpConversionPattern<StdOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
+  matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    assert(operands.size() == 1);
-    auto srcType = operands.front().getType();
+    assert(adaptor.getOperands().size() == 1);
+    auto srcType = adaptor.getOperands().front().getType();
     auto dstType =
         this->getTypeConverter()->convertType(operation.getResult().getType());
     if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
@@ -460,10 +460,10 @@ public:
     if (dstType == srcType) {
       // Due to type conversion, we are seeing the same source and target type.
       // Then we can just erase this operation by forwarding its operand.
-      rewriter.replaceOp(operation, operands.front());
+      rewriter.replaceOp(operation, adaptor.getOperands().front());
     } else {
       rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType,
-                                                    operands);
+                                                    adaptor.getOperands());
     }
     return success();
   }
@@ -475,7 +475,7 @@ public:
   using OpConversionPattern<XOrOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
+  matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -486,7 +486,7 @@ public:
   using OpConversionPattern<XOrOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
+  matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -497,10 +497,11 @@ public:
 //===----------------------------------------------------------------------===//
 
 LogicalResult SignedRemIOpPattern::matchAndRewrite(
-    SignedRemIOp remOp, ArrayRef<Value> operands,
+    SignedRemIOp remOp, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
-  Value result = emulateSignedRemainder(remOp.getLoc(), operands[0],
-                                        operands[1], operands[0], rewriter);
+  Value result = emulateSignedRemainder(
+      remOp.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
+      adaptor.getOperands()[0], rewriter);
   rewriter.replaceOp(remOp, result);
 
   return success();
@@ -514,7 +515,7 @@ LogicalResult SignedRemIOpPattern::matchAndRewrite(
 // so that the tensor case can be moved to TensorToSPIRV conversion. But,
 // std.constant is for the standard dialect though.
 LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
-    ConstantOp constOp, ArrayRef<Value> operands,
+    ConstantOp constOp, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   auto srcType = constOp.getType().dyn_cast<ShapedType>();
   if (!srcType)
@@ -599,7 +600,7 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
 //===----------------------------------------------------------------------===//
 
 LogicalResult ConstantScalarOpPattern::matchAndRewrite(
-    ConstantOp constOp, ArrayRef<Value> operands,
+    ConstantOp constOp, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Type srcType = constOp.getType();
   if (!srcType.isIntOrIndexOrFloat())
@@ -653,16 +654,13 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
+CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
-  CmpFOpAdaptor cmpFOpOperands(operands);
-
   switch (cmpFOp.getPredicate()) {
 #define DISPATCH(cmpPredicate, spirvOp)                                        \
   case cmpPredicate:                                                           \
     rewriter.replaceOpWithNewOp<spirvOp>(cmpFOp, cmpFOp.getResult().getType(), \
-                                         cmpFOpOperands.lhs(),                 \
-                                         cmpFOpOperands.rhs());                \
+                                         adaptor.lhs(), adaptor.rhs());        \
     return success();
 
     // Ordered.
@@ -689,19 +687,17 @@ CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
 }
 
 LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
-    CmpFOp cmpFOp, ArrayRef<Value> operands,
+    CmpFOp cmpFOp, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
-  CmpFOpAdaptor cmpFOpOperands(operands);
-
   if (cmpFOp.getPredicate() == CmpFPredicate::ORD) {
-    rewriter.replaceOpWithNewOp<spirv::OrderedOp>(cmpFOp, cmpFOpOperands.lhs(),
-                                                  cmpFOpOperands.rhs());
+    rewriter.replaceOpWithNewOp<spirv::OrderedOp>(cmpFOp, adaptor.lhs(),
+                                                  adaptor.rhs());
     return success();
   }
 
   if (cmpFOp.getPredicate() == CmpFPredicate::UNO) {
-    rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(
-        cmpFOp, cmpFOpOperands.lhs(), cmpFOpOperands.rhs());
+    rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(cmpFOp, adaptor.lhs(),
+                                                    adaptor.rhs());
     return success();
   }
 
@@ -709,17 +705,16 @@ LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
 }
 
 LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
-    CmpFOp cmpFOp, ArrayRef<Value> operands,
+    CmpFOp cmpFOp, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   if (cmpFOp.getPredicate() != CmpFPredicate::ORD &&
       cmpFOp.getPredicate() != CmpFPredicate::UNO)
     return failure();
 
-  CmpFOpAdaptor cmpFOpOperands(operands);
   Location loc = cmpFOp.getLoc();
 
-  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, cmpFOpOperands.lhs());
-  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, cmpFOpOperands.rhs());
+  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.lhs());
+  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.rhs());
 
   Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
   if (cmpFOp.getPredicate() == CmpFPredicate::ORD)
@@ -734,10 +729,8 @@ LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
+BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
-  CmpIOpAdaptor cmpIOpOperands(operands);
-
   Type operandType = cmpIOp.lhs().getType();
   if (!isBoolScalarOrVector(operandType))
     return failure();
@@ -746,8 +739,7 @@ BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
 #define DISPATCH(cmpPredicate, spirvOp)                                        \
   case cmpPredicate:                                                           \
     rewriter.replaceOpWithNewOp<spirvOp>(cmpIOp, cmpIOp.getResult().getType(), \
-                                         cmpIOpOperands.lhs(),                 \
-                                         cmpIOpOperands.rhs());                \
+                                         adaptor.lhs(), adaptor.rhs());        \
     return success();
 
     DISPATCH(CmpIPredicate::eq, spirv::LogicalEqualOp);
@@ -760,10 +752,8 @@ BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
 }
 
 LogicalResult
-CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
+CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
-  CmpIOpAdaptor cmpIOpOperands(operands);
-
   Type operandType = cmpIOp.lhs().getType();
   if (isBoolScalarOrVector(operandType))
     return failure();
@@ -777,8 +767,7 @@ CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
           "bitwidth emulation is not implemented yet on unsigned op");         \
     }                                                                          \
     rewriter.replaceOpWithNewOp<spirvOp>(cmpIOp, cmpIOp.getResult().getType(), \
-                                         cmpIOpOperands.lhs(),                 \
-                                         cmpIOpOperands.rhs());                \
+                                         adaptor.lhs(), adaptor.rhs());        \
     return success();
 
     DISPATCH(CmpIPredicate::eq, spirv::IEqualOp);
@@ -802,13 +791,14 @@ CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
+ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
   if (returnOp.getNumOperands() > 1)
     return failure();
 
   if (returnOp.getNumOperands() == 1) {
-    rewriter.replaceOpWithNewOp<spirv::ReturnValueOp>(returnOp, operands[0]);
+    rewriter.replaceOpWithNewOp<spirv::ReturnValueOp>(returnOp,
+                                                      adaptor.getOperands()[0]);
   } else {
     rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
   }
@@ -820,12 +810,10 @@ ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-SelectOpPattern::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
+SelectOpPattern::matchAndRewrite(SelectOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
-  SelectOpAdaptor selectOperands(operands);
-  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, selectOperands.condition(),
-                                               selectOperands.true_value(),
-                                               selectOperands.false_value());
+  rewriter.replaceOpWithNewOp<spirv::SelectOp>(
+      op, adaptor.condition(), adaptor.true_value(), adaptor.false_value());
   return success();
 }
 
@@ -834,12 +822,11 @@ SelectOpPattern::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-SplatPattern::matchAndRewrite(SplatOp op, ArrayRef<Value> operands,
+SplatPattern::matchAndRewrite(SplatOp op, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
   auto dstVecType = op.getType().dyn_cast<VectorType>();
   if (!dstVecType || !spirv::CompositeType::isValid(dstVecType))
     return failure();
-  SplatOp::Adaptor adaptor(operands);
   SmallVector<Value, 4> source(dstVecType.getNumElements(), adaptor.input());
   rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstVecType,
                                                            source);
@@ -851,34 +838,35 @@ SplatPattern::matchAndRewrite(SplatOp op, ArrayRef<Value> operands,
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
+XOrOpPattern::matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
-  assert(operands.size() == 2);
+  assert(adaptor.getOperands().size() == 2);
 
-  if (isBoolScalarOrVector(operands.front().getType()))
+  if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
     return failure();
 
   auto dstType = getTypeConverter()->convertType(xorOp.getType());
   if (!dstType)
     return failure();
-  rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(xorOp, dstType, operands);
+  rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(xorOp, dstType,
+                                                   adaptor.getOperands());
 
   return success();
 }
 
 LogicalResult
-BoolXOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
+BoolXOrOpPattern::matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
-  assert(operands.size() == 2);
+  assert(adaptor.getOperands().size() == 2);
 
-  if (!isBoolScalarOrVector(operands.front().getType()))
+  if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
     return failure();
 
   auto dstType = getTypeConverter()->convertType(xorOp.getType());
   if (!dstType)
     return failure();
   rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(xorOp, dstType,
-                                                        operands);
+                                                        adaptor.getOperands());
   return success();
 }
 
index ed1b03c..8ee0c43 100644 (file)
@@ -947,7 +947,7 @@ class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
 public:
   using OpConversionPattern<tosa::Conv2DOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(tosa::Conv2DOp op, ArrayRef<Value> args,
+  matchAndRewrite(tosa::Conv2DOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     Location loc = op->getLoc();
     Value input = op->getOperand(0);
@@ -1111,7 +1111,7 @@ class DepthwiseConvConverter
 public:
   using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(tosa::DepthwiseConv2DOp op, ArrayRef<Value> args,
+  matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     Location loc = op->getLoc();
     Value input = op->getOperand(0);
@@ -1266,7 +1266,7 @@ class TransposeConvConverter
 public:
   using OpConversionPattern<tosa::TransposeConv2DOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(tosa::TransposeConv2DOp op, ArrayRef<Value> args,
+  matchAndRewrite(tosa::TransposeConv2DOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     Location loc = op->getLoc();
     Value input = op->getOperand(0);
@@ -1336,10 +1336,8 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
 public:
   using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(tosa::MatMulOp op, ArrayRef<Value> args,
+  matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
-    tosa::MatMulOp::Adaptor adaptor(args);
-
     Location loc = op.getLoc();
 
     auto outputTy = op.getType().cast<ShapedType>();
@@ -1377,7 +1375,7 @@ class FullyConnectedConverter
 public:
   using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(tosa::FullyConnectedOp op, ArrayRef<Value> args,
+  matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     Location loc = op.getLoc();
     auto outputTy = op.getType().cast<ShapedType>();
@@ -1486,15 +1484,13 @@ public:
   using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(tosa::ReshapeOp reshape, ArrayRef<Value> args,
+  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
-    typename tosa::ReshapeOp::Adaptor operands(args);
-
-    ShapedType operandTy = operands.input1().getType().cast<ShapedType>();
+    ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
     ShapedType resultTy = reshape.getType().template cast<ShapedType>();
 
     if (operandTy == resultTy) {
-      rewriter.replaceOp(reshape, args[0]);
+      rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
       return success();
     }
 
@@ -1575,19 +1571,20 @@ public:
 
       auto collapsedTy = RankedTensorType::get({totalElems}, elemTy);
       Value collapsedOp = rewriter.create<linalg::TensorCollapseShapeOp>(
-          loc, collapsedTy, args[0], collapsingMap);
+          loc, collapsedTy, adaptor.getOperands()[0], collapsingMap);
       rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
           reshape, resultTy, collapsedOp, expandingMap);
 
       return success();
     }
 
-    if (resultTy.getRank() < args[0].getType().cast<ShapedType>().getRank())
+    if (resultTy.getRank() <
+        adaptor.getOperands()[0].getType().cast<ShapedType>().getRank())
       rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
-          reshape, resultTy, args[0], reassociationMap);
+          reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
     else
       rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
-          reshape, resultTy, args[0], reassociationMap);
+          reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
 
     return success();
   }
@@ -2117,7 +2114,7 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
   using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(tosa::ConcatOp op, ArrayRef<Value> args,
+  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto resultType = op.getType().dyn_cast<RankedTensorType>();
     if (!resultType || !resultType.hasStaticShape()) {
@@ -2136,11 +2133,12 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
     offsets.resize(rank, rewriter.create<ConstantIndexOp>(loc, 0));
 
     for (int i = 0; i < rank; ++i) {
-      sizes.push_back(rewriter.create<tensor::DimOp>(loc, args[0], i));
+      sizes.push_back(
+          rewriter.create<tensor::DimOp>(loc, adaptor.getOperands()[0], i));
     }
 
     Value resultDimSize = sizes[axis];
-    for (auto arg : args.drop_front()) {
+    for (auto arg : adaptor.getOperands().drop_front()) {
       auto size = rewriter.create<tensor::DimOp>(loc, arg, axisValue);
       resultDimSize = rewriter.create<AddIOp>(loc, resultDimSize, size);
     }
@@ -2154,7 +2152,7 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
     Value result =
         rewriter.create<linalg::FillOp>(loc, zeroVal, init).getResult(0);
 
-    for (auto arg : args) {
+    for (auto arg : adaptor.getOperands()) {
       sizes[axis] = rewriter.create<tensor::DimOp>(loc, arg, axisValue);
       result = rewriter.create<tensor::InsertSliceOp>(loc, arg, result, offsets,
                                                       sizes, strides);
@@ -2230,7 +2228,7 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
   using OpConversionPattern<tosa::TileOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(tosa::TileOp op, ArrayRef<Value> args,
+  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
     auto input = op.input1();
@@ -2488,10 +2486,10 @@ class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
 public:
   using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(tosa::GatherOp op, ArrayRef<Value> args,
+  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
-    auto input = args[0];
-    auto indices = args[1];
+    auto input = adaptor.getOperands()[0];
+    auto indices = adaptor.getOperands()[1];
 
     auto inputTy = input.getType().cast<ShapedType>();
     auto indicesTy = indices.getType().cast<ShapedType>();
index de9dfa1..27037cb 100644 (file)
@@ -36,13 +36,12 @@ struct VectorBitcastConvert final
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::BitCastOp bitcastOp, ArrayRef<Value> operands,
+  matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
     if (!dstType)
       return failure();
 
-    vector::BitCastOp::Adaptor adaptor(operands);
     if (dstType == adaptor.source().getType())
       rewriter.replaceOp(bitcastOp, adaptor.source());
     else
@@ -58,12 +57,11 @@ struct VectorBroadcastConvert final
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef<Value> operands,
+  matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (broadcastOp.source().getType().isa<VectorType>() ||
         !spirv::CompositeType::isValid(broadcastOp.getVectorType()))
       return failure();
-    vector::BroadcastOp::Adaptor adaptor(operands);
     SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
                                  adaptor.source());
     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
@@ -77,7 +75,7 @@ struct VectorExtractOpConvert final
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
+  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Only support extracting a scalar value now.
     VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>();
@@ -88,7 +86,6 @@ struct VectorExtractOpConvert final
     if (!dstType)
       return failure();
 
-    vector::ExtractOp::Adaptor adaptor(operands);
     if (adaptor.vector().getType().isa<spirv::ScalarType>()) {
       rewriter.replaceOp(extractOp, adaptor.vector());
       return success();
@@ -106,8 +103,7 @@ struct VectorExtractStridedSliceOpConvert final
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
-                  ArrayRef<Value> operands,
+  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto dstType = getTypeConverter()->convertType(extractOp.getType());
     if (!dstType)
@@ -120,7 +116,7 @@ struct VectorExtractStridedSliceOpConvert final
     if (stride != 1)
       return failure();
 
-    Value srcVector = operands.front();
+    Value srcVector = adaptor.getOperands().front();
 
     // Extract vector<1xT> case.
     if (dstType.isa<spirv::ScalarType>()) {
@@ -144,11 +140,10 @@ struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
+  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (!spirv::CompositeType::isValid(fmaOp.getVectorType()))
       return failure();
-    vector::FMAOp::Adaptor adaptor(operands);
     rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>(
         fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc());
     return success();
@@ -160,12 +155,11 @@ struct VectorInsertOpConvert final
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
+  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (insertOp.getSourceType().isa<VectorType>() ||
         !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
       return failure();
-    vector::InsertOp::Adaptor adaptor(operands);
     int32_t id = getFirstIntValue(insertOp.position());
     rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
         insertOp, adaptor.source(), adaptor.dest(), id);
@@ -178,12 +172,10 @@ struct VectorExtractElementOpConvert final
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::ExtractElementOp extractElementOp,
-                  ArrayRef<Value> operands,
+  matchAndRewrite(vector::ExtractElementOp extractElementOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
       return failure();
-    vector::ExtractElementOp::Adaptor adaptor(operands);
     rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
         extractElementOp, extractElementOp.getType(), adaptor.vector(),
         extractElementOp.position());
@@ -196,12 +188,10 @@ struct VectorInsertElementOpConvert final
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::InsertElementOp insertElementOp,
-                  ArrayRef<Value> operands,
+  matchAndRewrite(vector::InsertElementOp insertElementOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
       return failure();
-    vector::InsertElementOp::Adaptor adaptor(operands);
     rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
         insertElementOp, insertElementOp.getType(), insertElementOp.dest(),
         adaptor.source(), insertElementOp.position());
@@ -214,11 +204,10 @@ struct VectorInsertStridedSliceOpConvert final
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::InsertStridedSliceOp insertOp,
-                  ArrayRef<Value> operands,
+  matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Value srcVector = operands.front();
-    Value dstVector = operands.back();
+    Value srcVector = adaptor.getOperands().front();
+    Value dstVector = adaptor.getOperands().back();
 
     // Insert scalar values not supported yet.
     if (srcVector.getType().isa<spirv::ScalarType>() ||
index 6a94299..5e60bed 100644 (file)
@@ -84,7 +84,7 @@ Value castPtr(ConversionPatternRewriter &rewriter, Location loc, Value ptr) {
 struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
   using ConvertOpToLLVMPattern<TileZeroOp>::ConvertOpToLLVMPattern;
   LogicalResult
-  matchAndRewrite(TileZeroOp op, ArrayRef<Value> operands,
+  matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     VectorType vType = op.getVectorType();
     // Determine m x n tile sizes.
@@ -102,9 +102,8 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
   using ConvertOpToLLVMPattern<TileLoadOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(TileLoadOp op, ArrayRef<Value> operands,
+  matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    TileLoadOp::Adaptor adaptor(operands);
     MemRefType mType = op.getMemRefType();
     VectorType vType = op.getVectorType();
     // Determine m x n tile sizes.
@@ -130,9 +129,8 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
   using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(TileStoreOp op, ArrayRef<Value> operands,
+  matchAndRewrite(TileStoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    TileStoreOp::Adaptor adaptor(operands);
     MemRefType mType = op.getMemRefType();
     VectorType vType = op.getVectorType();
     // Determine m x n tile sizes.
@@ -156,9 +154,8 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
 struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
   using ConvertOpToLLVMPattern<TileMulFOp>::ConvertOpToLLVMPattern;
   LogicalResult
-  matchAndRewrite(TileMulFOp op, ArrayRef<Value> operands,
+  matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    TileMulFOp::Adaptor adaptor(operands);
     VectorType aType = op.getLhsVectorType();
     VectorType bType = op.getRhsVectorType();
     VectorType cType = op.getVectorType();
@@ -179,9 +176,8 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
 struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
   using ConvertOpToLLVMPattern<TileMulIOp>::ConvertOpToLLVMPattern;
   LogicalResult
-  matchAndRewrite(TileMulIOp op, ArrayRef<Value> operands,
+  matchAndRewrite(TileMulIOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    TileMulIOp::Adaptor adaptor(operands);
     VectorType aType = op.getLhsVectorType();
     VectorType bType = op.getRhsVectorType();
     VectorType cType = op.getVectorType();
index ed50f45..c477019 100644 (file)
@@ -46,12 +46,13 @@ class ForwardOperands : public OpConversionPattern<OpTy> {
   using OpConversionPattern<OpTy>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(OpTy op, ArrayRef<Value> operands,
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
-    if (ValueRange(operands).getTypes() == op->getOperands().getTypes())
+    if (adaptor.getOperands().getTypes() == op->getOperands().getTypes())
       return rewriter.notifyMatchFailure(op, "operand types already match");
 
-    rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); });
+    rewriter.updateRootInPlace(
+        op, [&]() { op->setOperands(adaptor.getOperands()); });
     return success();
   }
 };
@@ -61,9 +62,10 @@ public:
   using OpConversionPattern<ReturnOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
-    rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); });
+    rewriter.updateRootInPlace(
+        op, [&]() { op->setOperands(adaptor.getOperands()); });
     return success();
   }
 };
@@ -118,13 +120,12 @@ struct ScalableLoadOpLowering : public ConvertOpToLLVMPattern<ScalableLoadOp> {
   using ConvertOpToLLVMPattern<ScalableLoadOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(ScalableLoadOp loadOp, ArrayRef<Value> operands,
+  matchAndRewrite(ScalableLoadOp loadOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto type = loadOp.getMemRefType();
     if (!isConvertibleAndHasIdentityMaps(type))
       return failure();
 
-    ScalableLoadOp::Adaptor transformed(operands);
     LLVMTypeConverter converter(loadOp.getContext());
 
     auto resultType = loadOp.result().getType();
@@ -138,9 +139,8 @@ struct ScalableLoadOpLowering : public ConvertOpToLLVMPattern<ScalableLoadOp> {
                                           converter)
               .getValue());
     }
-    Value dataPtr =
-        getStridedElementPtr(loadOp.getLoc(), type, transformed.base(),
-                             transformed.index(), rewriter);
+    Value dataPtr = getStridedElementPtr(loadOp.getLoc(), type, adaptor.base(),
+                                         adaptor.index(), rewriter);
     Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
         loadOp.getLoc(), llvmDataTypePtr, dataPtr);
     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, bitCastedPtr);
@@ -155,13 +155,12 @@ struct ScalableStoreOpLowering
   using ConvertOpToLLVMPattern<ScalableStoreOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(ScalableStoreOp storeOp, ArrayRef<Value> operands,
+  matchAndRewrite(ScalableStoreOp storeOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto type = storeOp.getMemRefType();
     if (!isConvertibleAndHasIdentityMaps(type))
       return failure();
 
-    ScalableStoreOp::Adaptor transformed(operands);
     LLVMTypeConverter converter(storeOp.getContext());
 
     auto resultType = storeOp.value().getType();
@@ -175,12 +174,11 @@ struct ScalableStoreOpLowering
                                           converter)
               .getValue());
     }
-    Value dataPtr =
-        getStridedElementPtr(storeOp.getLoc(), type, transformed.base(),
-                             transformed.index(), rewriter);
+    Value dataPtr = getStridedElementPtr(storeOp.getLoc(), type, adaptor.base(),
+                                         adaptor.index(), rewriter);
     Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
         storeOp.getLoc(), llvmDataTypePtr, dataPtr);
-    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, transformed.value(),
+    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.value(),
                                                bitCastedPtr);
     return success();
   }
index 2127d7d..2d0886c 100644 (file)
@@ -337,10 +337,10 @@ public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(CreateGroupOp op, ArrayRef<Value> operands,
+  matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
-        op, GroupType::get(op->getContext()), operands);
+        op, GroupType::get(op->getContext()), adaptor.getOperands());
     return success();
   }
 };
@@ -356,10 +356,10 @@ public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(AddToGroupOp op, ArrayRef<Value> operands,
+  matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
-        op, rewriter.getIndexType(), operands);
+        op, rewriter.getIndexType(), adaptor.getOperands());
     return success();
   }
 };
@@ -382,7 +382,7 @@ public:
         outlinedFunctions(outlinedFunctions) {}
 
   LogicalResult
-  matchAndRewrite(AwaitType op, ArrayRef<Value> operands,
+  matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // We can only await on one the `AwaitableType` (for `await` it can be
     // a `token` or a `value`, for `await_all` it must be a `group`).
@@ -395,7 +395,7 @@ public:
     const bool isInCoroutine = outlined != outlinedFunctions.end();
 
     Location loc = op->getLoc();
-    Value operand = AwaitAdaptor(operands).operand();
+    Value operand = adaptor.operand();
 
     Type i1 = rewriter.getI1Type();
 
@@ -520,7 +520,7 @@ public:
         outlinedFunctions(outlinedFunctions) {}
 
   LogicalResult
-  matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands,
+  matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Check if yield operation is inside the async coroutine function.
     auto func = op->template getParentOfType<FuncOp>();
@@ -534,7 +534,7 @@ public:
 
     // Store yielded values into the async values storage and switch async
     // values state to available.
-    for (auto tuple : llvm::zip(operands, coro.returnValues)) {
+    for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
       Value yieldValue = std::get<0>(tuple);
       Value asyncValue = std::get<1>(tuple);
       rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
@@ -563,7 +563,7 @@ public:
         outlinedFunctions(outlinedFunctions) {}
 
   LogicalResult
-  matchAndRewrite(AssertOp op, ArrayRef<Value> operands,
+  matchAndRewrite(AssertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Check if assert operation is inside the async coroutine function.
     auto func = op->template getParentOfType<FuncOp>();
@@ -577,7 +577,7 @@ public:
 
     Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
     rewriter.setInsertionPointToEnd(cont->getPrevNode());
-    rewriter.create<CondBranchOp>(loc, AssertOpAdaptor(operands).arg(),
+    rewriter.create<CondBranchOp>(loc, adaptor.arg(),
                                   /*trueDest=*/cont,
                                   /*trueArgs=*/ArrayRef<Value>(),
                                   /*falseDest=*/setupSetErrorBlock(coro),
index 79a111f..a8cf8c1 100644 (file)
@@ -104,9 +104,8 @@ public:
   using OpConversionPattern<InitTensorOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(InitTensorOp op, ArrayRef<Value> operands,
+  matchAndRewrite(InitTensorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
-    linalg::InitTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
     rewriter.replaceOpWithNewOp<memref::AllocOp>(
         op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
         adaptor.sizes());
@@ -126,9 +125,8 @@ public:
       memref::ExpandShapeOp, memref::CollapseShapeOp>;
 
   LogicalResult
-  matchAndRewrite(TensorReshapeOp op, ArrayRef<Value> operands,
+  matchAndRewrite(TensorReshapeOp op, Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
-    Adaptor adaptor(operands, op->getAttrDictionary());
     rewriter.replaceOpWithNewOp<ReshapeOp>(op,
                                            this->getTypeConverter()
                                                ->convertType(op.getType())
@@ -145,9 +143,8 @@ public:
   using OpConversionPattern<FillOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(FillOp op, ArrayRef<Value> operands,
+  matchAndRewrite(FillOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
-    linalg::FillOpAdaptor adaptor(operands, op->getAttrDictionary());
     if (!op.output().getType().isa<TensorType>())
       return rewriter.notifyMatchFailure(op,
                                          "operand must be of a tensor type");
@@ -208,9 +205,8 @@ public:
   using OpConversionPattern<tensor::ExtractSliceOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(tensor::ExtractSliceOp op, ArrayRef<Value> operands,
+  matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
-    tensor::ExtractSliceOpAdaptor adaptor(operands, op->getAttrDictionary());
     Value sourceMemref = adaptor.source();
     assert(sourceMemref.getType().isa<MemRefType>());
 
index 08278bc..15b7f9e 100644 (file)
@@ -60,7 +60,7 @@ class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(GenericOp op, ArrayRef<Value> operands,
+  matchAndRewrite(GenericOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Block *originalBlock = op->getBlock();
 
@@ -78,7 +78,7 @@ public:
     rewriter.replaceOp(op, yieldOp->getOperands());
 
     // No need for these intermediate blocks, merge them into 1.
-    rewriter.mergeBlocks(opEntryBlock, originalBlock, operands);
+    rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands());
     rewriter.mergeBlocks(newBlock, originalBlock, {});
 
     rewriter.eraseOp(&*Block::iterator(yieldOp));
index c34660b..b84a6cb 100644 (file)
@@ -21,7 +21,7 @@ class ConvertForOpTypes : public OpConversionPattern<ForOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ForOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ForOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     SmallVector<Type, 6> newResultTypes;
     for (auto type : op.getResultTypes()) {
@@ -63,7 +63,7 @@ public:
     }
     // Change the clone to use the updated operands. We could have cloned with
     // a BlockAndValueMapping, but this seems a bit more direct.
-    newOp->setOperands(operands);
+    newOp->setOperands(adaptor.getOperands());
     // Update the result types to the new converted types.
     for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
       std::get<0>(t).setType(std::get<1>(t));
@@ -79,7 +79,7 @@ class ConvertIfOpTypes : public OpConversionPattern<IfOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(IfOp op, ArrayRef<Value> operands,
+  matchAndRewrite(IfOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // TODO: Generalize this to any type conversion, not just 1:1.
     //
@@ -108,7 +108,7 @@ public:
                                 newOp.elseRegion().end());
 
     // Update the operands and types.
-    newOp->setOperands(operands);
+    newOp->setOperands(adaptor.getOperands());
     for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
       std::get<0>(t).setType(std::get<1>(t));
     rewriter.replaceOp(op, newOp.getResults());
@@ -125,9 +125,9 @@ class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(scf::YieldOp op, ArrayRef<Value> operands,
+  matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<scf::YieldOp>(op, operands);
+    rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getOperands());
     return success();
   }
 };
@@ -139,7 +139,7 @@ public:
   using OpConversionPattern<WhileOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(WhileOp op, ArrayRef<Value> operands,
+  matchAndRewrite(WhileOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto *converter = getTypeConverter();
     assert(converter);
@@ -147,7 +147,6 @@ public:
     if (failed(converter->convertTypes(op.getResultTypes(), newResultTypes)))
       return failure();
 
-    WhileOp::Adaptor adaptor(operands);
     auto newOp = rewriter.create<WhileOp>(op.getLoc(), newResultTypes,
                                           adaptor.getOperands());
     for (auto i : {0u, 1u}) {
@@ -167,9 +166,10 @@ class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
 public:
   using OpConversionPattern<ConditionOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ConditionOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ConditionOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); });
+    rewriter.updateRootInPlace(
+        op, [&]() { op->setOperands(adaptor.getOperands()); });
     return success();
   }
 };
index 20a793d..10a3ba6 100644 (file)
@@ -156,7 +156,7 @@ public:
   using OpConversionPattern<spirv::FuncOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(spirv::FuncOp funcOp, ArrayRef<Value> operands,
+  matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -168,7 +168,7 @@ class LowerABIAttributesPass final
 } // namespace
 
 LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
-    spirv::FuncOp funcOp, ArrayRef<Value> operands,
+    spirv::FuncOp funcOp, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   if (!funcOp->getAttrOfType<spirv::EntryPointABIAttr>(
           spirv::getEntryPointABIAttrName())) {
index 76abf15..e2981bf 100644 (file)
@@ -541,13 +541,13 @@ public:
   using OpConversionPattern<FuncOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
+  matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 } // namespace
 
 LogicalResult
-FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
+FuncOpConversion::matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
   auto fnType = funcOp.getType();
   if (fnType.getNumResults() > 1)
index b58fa4d..008164a 100644 (file)
@@ -20,7 +20,7 @@ public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(AssumingOp op, ArrayRef<Value> operands,
+  matchAndRewrite(AssumingOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     SmallVector<Type, 2> newResultTypes;
     newResultTypes.reserve(op.getNumResults());
@@ -48,9 +48,9 @@ public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(AssumingYieldOp op, ArrayRef<Value> operands,
+  matchAndRewrite(AssumingYieldOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
-    rewriter.replaceOpWithNewOp<AssumingYieldOp>(op, operands);
+    rewriter.replaceOpWithNewOp<AssumingYieldOp>(op, adaptor.getOperands());
     return success();
   }
 };
index 5df5477..328bf8e 100644 (file)
@@ -249,9 +249,9 @@ class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
+    rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
     return success();
   }
 };
@@ -262,7 +262,7 @@ class SparseTensorToDimSizeConverter
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands,
+  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type resType = op.getType();
     auto enc = getSparseTensorEncoding(op.source().getType());
@@ -278,7 +278,7 @@ public:
     // Generate the call.
     StringRef name = "sparseDimSize";
     SmallVector<Value, 2> params;
-    params.push_back(operands[0]);
+    params.push_back(adaptor.getOperands()[0]);
     params.push_back(
         rewriter.create<ConstantOp>(op.getLoc(), rewriter.getIndexAttr(idx)));
     rewriter.replaceOpWithNewOp<CallOp>(
@@ -291,14 +291,15 @@ public:
 class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(NewOp op, ArrayRef<Value> operands,
+  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type resType = op.getType();
     auto enc = getSparseTensorEncoding(resType);
     if (!enc)
       return failure();
     Value perm;
-    rewriter.replaceOp(op, genNewCall(rewriter, op, enc, 0, perm, operands[0]));
+    rewriter.replaceOp(
+        op, genNewCall(rewriter, op, enc, 0, perm, adaptor.getOperands()[0]));
     return success();
   }
 };
@@ -307,7 +308,7 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
 class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ConvertOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type resType = op.getType();
     auto encDst = getSparseTensorEncoding(resType);
@@ -320,7 +321,8 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
       // yield the fastest conversion but avoids the need for a full
       // O(N^2) conversion matrix.
       Value perm;
-      Value coo = genNewCall(rewriter, op, encDst, 3, perm, operands[0]);
+      Value coo =
+          genNewCall(rewriter, op, encDst, 3, perm, adaptor.getOperands()[0]);
       rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, coo));
       return success();
     }
@@ -349,7 +351,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
         MemRefType::get({ShapedType::kDynamicSize}, rewriter.getIndexType());
     Value perm;
     Value ptr = genNewCall(rewriter, op, encDst, 2, perm);
-    Value tensor = operands[0];
+    Value tensor = adaptor.getOperands()[0];
     Value arg = rewriter.create<ConstantOp>(
         loc, rewriter.getIndexAttr(shape.getRank()));
     Value ind = rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
@@ -381,7 +383,7 @@ class SparseTensorToPointersConverter
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ToPointersOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type resType = op.getType();
     Type eltType = resType.cast<ShapedType>().getElementType();
@@ -398,10 +400,11 @@ public:
       name = "sparsePointers8";
     else
       return failure();
-    rewriter.replaceOpWithNewOp<CallOp>(
-        op, resType,
-        getFunc(op, name, resType, operands, /*emitCInterface=*/true),
-        operands);
+    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
+                                        getFunc(op, name, resType,
+                                                adaptor.getOperands(),
+                                                /*emitCInterface=*/true),
+                                        adaptor.getOperands());
     return success();
   }
 };
@@ -411,7 +414,7 @@ class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ToIndicesOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type resType = op.getType();
     Type eltType = resType.cast<ShapedType>().getElementType();
@@ -428,10 +431,11 @@ public:
       name = "sparseIndices8";
     else
       return failure();
-    rewriter.replaceOpWithNewOp<CallOp>(
-        op, resType,
-        getFunc(op, name, resType, operands, /*emitCInterface=*/true),
-        operands);
+    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
+                                        getFunc(op, name, resType,
+                                                adaptor.getOperands(),
+                                                /*emitCInterface=*/true),
+                                        adaptor.getOperands());
     return success();
   }
 };
@@ -441,7 +445,7 @@ class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ToValuesOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type resType = op.getType();
     Type eltType = resType.cast<ShapedType>().getElementType();
@@ -460,10 +464,11 @@ public:
       name = "sparseValuesI8";
     else
       return failure();
-    rewriter.replaceOpWithNewOp<CallOp>(
-        op, resType,
-        getFunc(op, name, resType, operands, /*emitCInterface=*/true),
-        operands);
+    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
+                                        getFunc(op, name, resType,
+                                                adaptor.getOperands(),
+                                                /*emitCInterface=*/true),
+                                        adaptor.getOperands());
     return success();
   }
 };
@@ -474,12 +479,12 @@ public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
   // Simply fold the operator into the pointer to the sparse storage scheme.
-  matchAndRewrite(ToTensorOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ToTensorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Check that all arguments of the tensor reconstruction operators are calls
     // into the support library that query exactly the same opaque pointer.
     Value ptr;
-    for (Value op : operands) {
+    for (Value op : adaptor.getOperands()) {
       if (auto call = op.getDefiningOp<CallOp>()) {
         Value arg = call.getOperand(0);
         if (!arg.getType().isa<LLVM::LLVMPointerType>())
index 06f6c12..23b7019 100644 (file)
@@ -27,9 +27,8 @@ class BufferizeIndexCastOp : public OpConversionPattern<IndexCastOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(IndexCastOp op, ArrayRef<Value> operands,
+  matchAndRewrite(IndexCastOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    IndexCastOp::Adaptor adaptor(operands);
     auto tensorType = op.getType().cast<RankedTensorType>();
     rewriter.replaceOpWithNewOp<IndexCastOp>(
         op, adaptor.in(),
@@ -42,12 +41,11 @@ class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
+  matchAndRewrite(SelectOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (!op.condition().getType().isa<IntegerType>())
       return rewriter.notifyMatchFailure(op, "requires scalar condition");
 
-    SelectOp::Adaptor adaptor(operands);
     rewriter.replaceOpWithNewOp<SelectOp>(
         op, adaptor.condition(), adaptor.true_value(), adaptor.false_value());
     return success();
index 7636bc7..3686568 100644 (file)
@@ -61,7 +61,7 @@ struct DecomposeCallGraphTypesForFuncArgs
       DecomposeCallGraphTypesOpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(FuncOp op, ArrayRef<Value> operands,
+  matchAndRewrite(FuncOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     auto functionType = op.getType();
 
@@ -106,10 +106,10 @@ struct DecomposeCallGraphTypesForReturnOp
   using DecomposeCallGraphTypesOpConversionPattern::
       DecomposeCallGraphTypesOpConversionPattern;
   LogicalResult
-  matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     SmallVector<Value, 2> newOperands;
-    for (Value operand : operands)
+    for (Value operand : adaptor.getOperands())
       decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
                                 operand, newOperands);
     rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
@@ -131,12 +131,12 @@ struct DecomposeCallGraphTypesForCallOp
       DecomposeCallGraphTypesOpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(CallOp op, ArrayRef<Value> operands,
+  matchAndRewrite(CallOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
 
     // Create the operands list of the new `CallOp`.
     SmallVector<Value, 2> newOperands;
-    for (Value operand : operands)
+    for (Value operand : adaptor.getOperands())
       decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
                                 operand, newOperands);
 
index 49aaade..8756fcf 100644 (file)
@@ -20,7 +20,7 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
 
   /// Hook for derived classes to implement combined matching and rewriting.
   LogicalResult
-  matchAndRewrite(CallOp callOp, ArrayRef<Value> operands,
+  matchAndRewrite(CallOp callOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Convert the original function results.
     SmallVector<Type, 1> convertedResults;
@@ -30,8 +30,8 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
 
     // Substitute with the new result types from the corresponding FuncType
     // conversion.
-    rewriter.replaceOpWithNewOp<CallOp>(callOp, callOp.callee(),
-                                        convertedResults, operands);
+    rewriter.replaceOpWithNewOp<CallOp>(
+        callOp, callOp.callee(), convertedResults, adaptor.getOperands());
     return success();
   }
 };
@@ -96,13 +96,12 @@ public:
   using OpConversionPattern<ReturnOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     // For a return, all operands go to the results of the parent, so
     // rewrite them all.
-    Operation *operation = op.getOperation();
-    rewriter.updateRootInPlace(
-        op, [operands, operation]() { operation->setOperands(operands); });
+    rewriter.updateRootInPlace(op,
+                               [&] { op->setOperands(adaptor.getOperands()); });
     return success();
   }
 };
index c916a73..035251f 100644 (file)
@@ -66,7 +66,7 @@ public:
         globals(globals) {}
 
   LogicalResult
-  matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ConstantOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto type = op.getType().dyn_cast<RankedTensorType>();
     if (!type)
index e8c3865..f35d701 100644 (file)
@@ -26,10 +26,11 @@ class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(tensor::CastOp op, ArrayRef<Value> operands,
+  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto resultType = getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<memref::CastOp>(op, resultType, operands[0]);
+    rewriter.replaceOpWithNewOp<memref::CastOp>(op, resultType,
+                                                adaptor.getOperands()[0]);
     return success();
   }
 };
@@ -40,9 +41,8 @@ class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands,
+  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    tensor::DimOp::Adaptor adaptor(operands);
     rewriter.replaceOpWithNewOp<memref::DimOp>(op, adaptor.source(),
                                                adaptor.index());
     return success();
@@ -55,9 +55,8 @@ class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(tensor::ExtractOp op, ArrayRef<Value> operands,
+  matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    tensor::ExtractOp::Adaptor adaptor(operands);
     rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.tensor(),
                                                 adaptor.indices());
     return success();
@@ -71,7 +70,7 @@ class BufferizeFromElementsOp
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(tensor::FromElementsOp op, ArrayRef<Value> operands,
+  matchAndRewrite(tensor::FromElementsOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     int numberOfElements = op.elements().size();
     auto resultType = MemRefType::get(
@@ -95,16 +94,15 @@ public:
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(tensor::GenerateOp op, ArrayRef<Value> operands,
+  matchAndRewrite(tensor::GenerateOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     // Allocate memory.
     Location loc = op.getLoc();
-    tensor::GenerateOp::Adaptor transformed(operands);
     RankedTensorType tensorType = op.getType().cast<RankedTensorType>();
     MemRefType memrefType =
         MemRefType::get(tensorType.getShape(), tensorType.getElementType());
-    Value result = rewriter.create<memref::AllocOp>(
-        loc, memrefType, transformed.dynamicExtents());
+    Value result = rewriter.create<memref::AllocOp>(loc, memrefType,
+                                                    adaptor.dynamicExtents());
 
     // Collect loop bounds.
     int64_t rank = tensorType.getRank();
@@ -117,7 +115,7 @@ public:
     for (int i = 0; i < rank; i++) {
       Value upperBound =
           tensorType.isDynamicDim(i)
-              ? transformed.dynamicExtents()[nextDynamicIndex++]
+              ? adaptor.dynamicExtents()[nextDynamicIndex++]
               : rewriter.create<ConstantIndexOp>(loc, memrefType.getDimSize(i));
       upperBounds.push_back(upperBound);
     }
index c2a7a19..7b174e1 100644 (file)
@@ -46,18 +46,18 @@ struct LowerToIntrinsic : public OpConversionPattern<OpTy> {
   }
 
   LogicalResult
-  matchAndRewrite(OpTy op, ArrayRef<Value> operands,
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type elementType = getSrcVectorElementType<OpTy>(op);
     unsigned bitwidth = elementType.getIntOrFloatBitWidth();
     if (bitwidth == 32)
       return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(),
-                                           operands, getTypeConverter(),
-                                           rewriter);
+                                           adaptor.getOperands(),
+                                           getTypeConverter(), rewriter);
     if (bitwidth == 64)
       return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(),
-                                           operands, getTypeConverter(),
-                                           rewriter);
+                                           adaptor.getOperands(),
+                                           getTypeConverter(), rewriter);
     return rewriter.notifyMatchFailure(
         op, "expected 'src' to be either f32 or f64");
   }
@@ -68,9 +68,8 @@ struct MaskCompressOpConversion
   using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(MaskCompressOp op, ArrayRef<Value> operands,
+  matchAndRewrite(MaskCompressOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    MaskCompressOp::Adaptor adaptor(operands);
     auto opType = adaptor.a().getType();
 
     Value src;
@@ -95,10 +94,8 @@ struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
   using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(RsqrtOp op, ArrayRef<Value> operands,
+  matchAndRewrite(RsqrtOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    RsqrtOp::Adaptor adaptor(operands);
-
     auto opType = adaptor.a().getType();
     rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, opType, adaptor.a());
     return success();
@@ -109,9 +106,8 @@ struct DotOpConversion : public ConvertOpToLLVMPattern<DotOp> {
   using ConvertOpToLLVMPattern<DotOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(DotOp op, ArrayRef<Value> operands,
+  matchAndRewrite(DotOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    DotOp::Adaptor adaptor(operands);
     auto opType = adaptor.a().getType();
     Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8);
     // Dot product of all elements, broadcasted to all elements.
index 7ed7526..27d27c0 100644 (file)
@@ -58,9 +58,8 @@ class BufferizeTensorLoadOp : public OpConversionPattern<memref::TensorLoadOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(memref::TensorLoadOp op, ArrayRef<Value> operands,
+  matchAndRewrite(memref::TensorLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    memref::TensorLoadOp::Adaptor adaptor(operands);
     rewriter.replaceOp(op, adaptor.memref());
     return success();
   }
@@ -74,9 +73,8 @@ class BufferizeCastOp : public OpConversionPattern<memref::BufferCastOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(memref::BufferCastOp op, ArrayRef<Value> operands,
+  matchAndRewrite(memref::BufferCastOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    memref::BufferCastOp::Adaptor adaptor(operands);
     rewriter.replaceOp(op, adaptor.tensor());
     return success();
   }