Introduce OpOperandAdaptors and emit them from ODS
authorAlex Zinenko <zinenko@google.com>
Mon, 3 Jun 2019 15:03:20 +0000 (08:03 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 4 Jun 2019 02:26:12 +0000 (19:26 -0700)
When manipulating generic operations, such as in dialect conversion /
rewriting, it is often necessary to view a list of Values as operands to an
operation without creating the operation itself.  The absence of such view
makes dialect conversion patterns, among others, to use magic numbers to obtain
specific operands from a list of rewritten values when converting an operation.
Introduce XOpOperandAdaptor classes that wrap an ArrayRef<Value *> and provide
accessor functions identical to those available in XOp.  This makes it possible
for conversions to use these adaptors to address the operands with names rather
than rely on their position in the list.  The adaptors are generated from ODS
together with the actual operation definitions.

This is another step towards making dialect conversion patterns specific for a
given operation.

Illustrate the approach on conversion patterns in the standard to LLVM dialect
conversion.

PiperOrigin-RevId: 251232899

mlir/g3doc/OpDefinitions.md
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/StandardOps/Ops.td
mlir/include/mlir/TableGen/Operator.h
mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
mlir/lib/TableGen/Operator.cpp
mlir/test/mlir-tblgen/op-decl.td
mlir/test/mlir-tblgen/op-operand.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

index 4befbde..810dd34 100644 (file)
@@ -400,11 +400,13 @@ the other for definitions. The former is generated via the `-gen-op-decls`
 command-line option, while the latter is via the `-gen-op-defs` option.
 
 The definition file contains all the op method definitions, which can be
-included and enabled by defining `GET_OP_CLASSES`. Besides, it also
-contains a comma-separated list of all defined ops, which can be included
-and enabled by defining `GET_OP_LIST`.
+included and enabled by defining `GET_OP_CLASSES`. For each operation,
+OpDefinitionsGen generates an operation class and an
+[operand adaptor](#operand-adaptors) class. Besides, it also contains a
+comma-separated list of all defined ops, which can be included and enabled by
+defining `GET_OP_LIST`.
 
-### Class name and namespaces
+#### Class name and namespaces
 
 For each operation, its generated C++ class name is the symbol `def`ed with
 TableGen with dialect prefix removed. The first `_` serves as the delimiter.
@@ -423,6 +425,36 @@ match exactly with the operation name as explained in
 [Operation name](#operation-name). This is to allow flexible naming to satisfy
 coding style requirements.
 
+#### Operand adaptors
+
+For each operation, we automatically generate an _operand adaptor_. This class
+solves the problem of accessing operands provided as a list of `Value`s without
+using "magic" constants. The operand adaptor takes a reference to an array of
+`Value *` and provides methods with the same names as those in the operation
+class to access them. For example, for a binary arithmethic operation, it may
+provide `.lhs()` to access the first operand and `.rhs()` to access the second
+operand.
+
+The operand adaptor class lives in the same namespace as the operation class,
+and has the name of the operation followed by `OperandAdaptor`. A template
+declaration `OperandAdaptor<>` is provided to look up the operand adaptor for
+the given operation.
+
+Operand adaptors can be used in function templates that also process operations:
+
+```c++
+template <typename BinaryOpTy>
+std::pair<Value *, Value *> zip(BinaryOpTy &&op) {
+  return std::make_pair(op.lhs(), op.rhs());;
+}
+
+void process(AddOp op, ArrayRef<Value *> newOperands) {
+  zip(op);
+  zip(OperandAdaptor<AddOp>(newOperands));
+  /*...*/
+}
+```
+
 ## Constraints
 
 Constraint is a core concept in table-driven operation definition: operation
index 312cf14..71a5339 100644 (file)
@@ -49,6 +49,14 @@ class RewritePattern;
 class Type;
 class Value;
 
+/// This is an adaptor from a list of values to named operands of OpTy.  In a
+/// generic operation context, e.g., in dialect conversions, an ordered array of
+/// `Value`s is treated as operands of `OpTy`.  This adaptor takes a reference
+/// to the array and provides accessors with the same names as `OpTy` for
+/// operands.  This makes possible to create function templates that operate on
+/// either OpTy or OperandAdaptor<OpTy> seamlessly.
+template <typename OpTy> using OperandAdaptor = typename OpTy::OperandAdaptor;
+
 /// This is a vector that owns the patterns inside of it.
 using OwningPatternList = std::vector<std::unique_ptr<Pattern>>;
 using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
index e7b6087..82493a0 100644 (file)
@@ -663,7 +663,7 @@ def MemRefCastOp : CastOp<"memref_cast"> {
        %3 = memref_cast %1 : memref<4xf32> to memref<?xf32>
   }];
 
-  let arguments = (ins AnyMemRef);
+  let arguments = (ins AnyMemRef:$source);
   let results = (outs AnyMemRef);
 
   let extraClassDeclaration = [{
index de2818e..56033f9 100644 (file)
@@ -157,6 +157,12 @@ public:
   // Returns this op's extra class declaration code.
   StringRef getExtraClassDeclaration() const;
 
+  // Returns the Tablegen definition this operator was constructed from.
+  // TODO(antiagainst,zinenko): do not expose the TableGen record, this is a
+  // temporary solution to OpEmitter requiring a Record because Operator does
+  // not provide enough methods.
+  const llvm::Record &getDef() const;
+
 private:
   // Populates the vectors containing operands, attributes, results and traits.
   void populateOpStructure();
index 3a756a9..a9717e2 100644 (file)
@@ -503,6 +503,7 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
   void rewrite(Operation *op, ArrayRef<Value *> operands,
                PatternRewriter &rewriter) const override {
     assert(operands.size() == 1 && "dealloc takes one operand");
+    OperandAdaptor<DeallocOp> transformed(operands);
 
     // Insert the `free` declaration if it is not already present.
     Function *freeFunc =
@@ -513,11 +514,12 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
       op->getFunction()->getModule()->getFunctions().push_back(freeFunc);
     }
 
-    auto type = operands[0]->getType().cast<LLVM::LLVMType>();
+    auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
     auto hasStaticShape = type.getUnderlyingType()->isPointerTy();
     Type elementPtrType = hasStaticShape ? type : type.getStructElementType(0);
-    Value *bufferPtr = extractMemRefElementPtr(
-        rewriter, op->getLoc(), operands[0], elementPtrType, hasStaticShape);
+    Value *bufferPtr =
+        extractMemRefElementPtr(rewriter, op->getLoc(), transformed.memref(),
+                                elementPtrType, hasStaticShape);
     Value *casted = rewriter.create<LLVM::BitcastOp>(
         op->getLoc(), getVoidPtrType(), bufferPtr);
     rewriter.replaceOpWithNewOp<LLVM::CallOp>(
@@ -542,13 +544,14 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
   void rewrite(Operation *op, ArrayRef<Value *> operands,
                PatternRewriter &rewriter) const override {
     auto memRefCastOp = cast<MemRefCastOp>(op);
+    OperandAdaptor<MemRefCastOp> transformed(operands);
     auto targetType = memRefCastOp.getType();
     auto sourceType = memRefCastOp.getOperand()->getType().cast<MemRefType>();
 
     // Copy the data buffer pointer.
     auto elementTypePtr = getMemRefElementPtrType(targetType, lowering);
     Value *buffer =
-        extractMemRefElementPtr(rewriter, op->getLoc(), operands[0],
+        extractMemRefElementPtr(rewriter, op->getLoc(), transformed.source(),
                                 elementTypePtr, sourceType.hasStaticShape());
     // Account for static memrefs as target types
     if (targetType.hasStaticShape())
@@ -583,7 +586,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
           sourceSize == -1
               ? rewriter.create<LLVM::ExtractValueOp>(
                     op->getLoc(), getIndexType(),
-                    operands[0], // NB: dynamic memref
+                    transformed.source(), // NB: dynamic memref
                     getIntegerArrayAttr(rewriter, sourceDynamicDimIdx++))
               : createIndexConstant(rewriter, op->getLoc(), sourceSize);
       newDescriptor = rewriter.create<LLVM::InsertValueOp>(
@@ -612,8 +615,8 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
 
   void rewrite(Operation *op, ArrayRef<Value *> operands,
                PatternRewriter &rewriter) const override {
-    assert(operands.size() == 1 && "expected exactly one operand");
     auto dimOp = cast<DimOp>(op);
+    OperandAdaptor<DimOp> transformed(operands);
     MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
 
     auto shape = type.getShape();
@@ -630,7 +633,7 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
           ++position;
       }
       rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
-          op, getIndexType(), operands[0],
+          op, getIndexType(), transformed.memrefOrTensor(),
           getIntegerArrayAttr(rewriter, position));
     } else {
       rewriter.replaceOp(
@@ -759,10 +762,11 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
   void rewrite(Operation *op, ArrayRef<Value *> operands,
                PatternRewriter &rewriter) const override {
     auto loadOp = cast<LoadOp>(op);
+    OperandAdaptor<LoadOp> transformed(operands);
     auto type = loadOp.getMemRefType();
 
-    Value *dataPtr = getDataPtr(op->getLoc(), type, operands.front(),
-                                operands.drop_front(), rewriter, getModule());
+    Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
+                                transformed.indices(), rewriter, getModule());
     auto elementType = lowering.convertType(type.getElementType());
 
     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, elementType,
@@ -778,10 +782,12 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
   void rewrite(Operation *op, ArrayRef<Value *> operands,
                PatternRewriter &rewriter) const override {
     auto type = cast<StoreOp>(op).getMemRefType();
+    OperandAdaptor<StoreOp> transformed(operands);
 
-    Value *dataPtr = getDataPtr(op->getLoc(), type, operands[1],
-                                operands.drop_front(2), rewriter, getModule());
-    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, operands[0], dataPtr);
+    Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
+                                transformed.indices(), rewriter, getModule());
+    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
+                                               dataPtr);
   }
 };
 
index cd3537d..87494c2 100644 (file)
@@ -87,6 +87,8 @@ StringRef tblgen::Operator::getExtraClassDeclaration() const {
   return def.getValueAsString(attr);
 }
 
+const llvm::Record &tblgen::Operator::getDef() const { return def; }
+
 tblgen::TypeConstraint
 tblgen::Operator::getResultTypeConstraint(int index) const {
   DagInit *results = def.getValueAsDag("results");
index 4fb450d..336efc1 100644 (file)
@@ -39,9 +39,19 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect]> {
 
 // CHECK-LABEL: NS::AOp declarations
 
+// CHECK:  class AOpOperandAdaptor {
+// CHECK:  public:
+// CHECK:    AOpOperandAdaptor(ArrayRef<Value *> values);
+// CHECK:    Value *a();
+// CHECK:    ArrayRef<Value *> b();
+// CHECK:  private:
+// CHECK:    ArrayRef<Value *> tblgen_operands;
+// CHECK:  };
+
 // CHECK: class AOp : public Op<AOp, OpTrait::AtLeastNResults<1>::Impl, OpTrait::HasNoSideEffect, OpTrait::AtLeastNOperands<1>::Impl> {
 // CHECK: public:
 // CHECK:   using Op::Op;
+// CHECK:   using OperandAdaptor = AOpOperandAdaptor;
 // CHECK:   static StringRef getOperationName();
 // CHECK:   Value *a();
 // CHECK:   Operation::operand_range b();
index 633d4e7..506c018 100644 (file)
@@ -14,6 +14,12 @@ def OpA : NS_Op<"one_normal_operand_op", []> {
 
 // CHECK-LABEL: OpA definitions
 
+// CHECK:      OpAOperandAdaptor::OpAOperandAdaptor
+// CHECK-NEXT: tblgen_operands = values
+
+// CHECK:      OpAOperandAdaptor::input
+// CHECK-NEXT: return tblgen_operands[0]
+
 // CHECK:      void OpA::build
 // CHECK-SAME:   Value *input
 // CHECK:        tblgen_state->operands.push_back(input);
@@ -40,6 +46,16 @@ def OpC : NS_Op<"all_variadic_inputs_op", [SameVariadicOperandSize]> {
   let arguments = (ins Variadic<AnyTensor>:$input1, Variadic<AnyTensor>:$input2);
 }
 
+// CHECK-LABEL: OpCOperandAdaptor::input1
+// CHECK-NEXT:    variadicOperandSize = (tblgen_operands.size() - 0) / 2;
+// CHECK-NEXT:    offset = 0 + variadicOperandSize * 0;
+// CHECK-NEXT:    return {std::next(tblgen_operands.begin(), offset), std::next(tblgen_operands.begin(), offset + variadicOperandSize)};
+
+// CHECK-LABEL: OpCOperandAdaptor::input2
+// CHECK-NEXT:    variadicOperandSize = (tblgen_operands.size() - 0) / 2;
+// CHECK-NEXT:    offset = 0 + variadicOperandSize * 1;
+// CHECK-NEXT:    return {std::next(tblgen_operands.begin(), offset), std::next(tblgen_operands.begin(), offset + variadicOperandSize)};
+
 // CHECK-LABEL: Operation::operand_range OpC::input1()
 // CHECK-NEXT:    variadicOperandSize = (this->getNumOperands() - 0) / 2;
 // CHECK-NEXT:    offset = 0 + variadicOperandSize * 0;
@@ -58,6 +74,21 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]>
   let arguments = (ins Variadic<AnyTensor>:$input1, AnyTensor:$input2, Variadic<AnyTensor>:$input3);
 }
 
+// CHECK-LABEL: OpDOperandAdaptor::input1() {
+// CHECK-NEXT:    variadicOperandSize = (tblgen_operands.size() - 1) / 2;
+// CHECK-NEXT:    offset = 0 + variadicOperandSize * 0;
+// CHECK-NEXT:    return {std::next(tblgen_operands.begin(), offset), std::next(tblgen_operands.begin(), offset + variadicOperandSize)};
+
+// CHECK-LABEL: Value *OpDOperandAdaptor::input2() {
+// CHECK-NEXT:    variadicOperandSize = (tblgen_operands.size() - 1) / 2;
+// CHECK-NEXT:    offset = 0 + variadicOperandSize * 1;
+// CHECK-NEXT:    return tblgen_operands[offset];
+
+// CHECK-LABEL: OpDOperandAdaptor::input3() {
+// CHECK-NEXT:    variadicOperandSize = (tblgen_operands.size() - 1) / 2;
+// CHECK-NEXT:    offset = 1 + variadicOperandSize * 1;
+// CHECK-NEXT:    return {std::next(tblgen_operands.begin(), offset), std::next(tblgen_operands.begin(), offset + variadicOperandSize)};
+
 // CHECK-LABEL: Operation::operand_range OpD::input1()
 // CHECK-NEXT:    variadicOperandSize = (this->getNumOperands() - 1) / 2;
 // CHECK-NEXT:    offset = 0 + variadicOperandSize * 0;
@@ -66,7 +97,7 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]>
 // CHECK-LABEL: Value *OpD::input2()
 // CHECK-NEXT:    variadicOperandSize = (this->getNumOperands() - 1) / 2;
 // CHECK-NEXT:    offset = 0 + variadicOperandSize * 1;
-// CHECK-NEXT:    return this->getOperand(offset);
+// CHECK-NEXT:    return this->getOperation()->getOperand(offset);
 
 // CHECK-LABEL: Operation::operand_range OpD::input3()
 // CHECK-NEXT:    variadicOperandSize = (this->getNumOperands() - 1) / 2;
@@ -82,6 +113,21 @@ def OpE : NS_Op<"one_variadic_among_multi_normal_inputs_op", []> {
   let arguments = (ins AnyTensor:$input1, AnyTensor:$input2, Variadic<AnyTensor>:$input3, AnyTensor:$input4, AnyTensor:$input5);
 }
 
+// CHECK-LABEL:  Value *OpEOperandAdaptor::input1() {
+// CHECK-NEXT:    return tblgen_operands[0];
+
+// CHECK-LABEL:  Value *OpEOperandAdaptor::input2() {
+// CHECK-NEXT:    return tblgen_operands[1];
+
+// CHECK-LABEL:  OpEOperandAdaptor::input3() {
+// CHECK-NEXT:    return {std::next(tblgen_operands.begin(), 2), std::next(tblgen_operands.begin(), 2 + tblgen_operands.size() - 4)};
+
+// CHECK-LABEL:  Value *OpEOperandAdaptor::input4() {
+// CHECK-NEXT:    return tblgen_operands[tblgen_operands.size() - 2];
+
+// CHECK-LABEL:  Value *OpEOperandAdaptor::input5() {
+// CHECK-NEXT:    return tblgen_operands[tblgen_operands.size() - 1];
+
 // CHECK-LABEL: Value *OpE::input1()
 // CHECK-NEXT:    return this->getOperation()->getOperand(0);
 
index 61373ae..4261faf 100644 (file)
@@ -110,8 +110,8 @@ public:
   void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
 
 private:
-  // Returns true if the given C++ `type` ends with '&' or '*'.
-  static bool endsWithRefOrPtr(StringRef type);
+  // Returns true if the given C++ `type` ends with '&' or '*', or is empty.
+  static bool elideSpaceAfterType(StringRef type);
 
   std::string returnType;
   std::string methodName;
@@ -142,7 +142,8 @@ public:
   // querying properties.
   enum Property {
     MP_None = 0x0,
-    MP_Static = 0x1, // Static method
+    MP_Static = 0x1,      // Static method
+    MP_Constructor = 0x2, // Constructor
   };
 
   OpMethod(StringRef retType, StringRef name, StringRef params,
@@ -168,19 +169,22 @@ private:
   OpMethodBody methodBody;
 };
 
-// Class for holding an op for C++ code emission
-class OpClass {
+// A class used to emit C++ classes from Tablegen.  Contains a list of public
+// methods and a list of private fields to be emitted.
+class Class {
 public:
-  explicit OpClass(StringRef name, StringRef extraClassDeclaration = "");
-
-  // Adds an op trait.
-  void addTrait(Twine trait);
+  explicit Class(StringRef name);
 
-  // Creates a new method in this op's class.
+  // Creates a new method in this class.
   OpMethod &newMethod(StringRef retType, StringRef name, StringRef params = "",
                       OpMethod::Property = OpMethod::MP_None,
                       bool declOnly = false);
 
+  OpMethod &newConstructor(StringRef params = "", bool declOnly = false);
+
+  // Creates a new field in this class.
+  void newField(StringRef type, StringRef name, StringRef defaultValue = "");
+
   // Writes this op's class as a declaration to the given `os`.
   void writeDeclTo(raw_ostream &os) const;
   // Writes the method definitions in this op's class to the given `os`.
@@ -189,11 +193,27 @@ public:
   // Returns the C++ class name of the op.
   StringRef getClassName() const { return className; }
 
+protected:
+  std::string className;
+  SmallVector<OpMethod, 8> methods;
+  SmallVector<std::string, 4> fields;
+};
+
+// Class for holding an op for C++ code emission
+class OpClass : public Class {
+public:
+  explicit OpClass(StringRef name, StringRef extraClassDeclaration = "");
+
+  // Adds an op trait.
+  void addTrait(Twine trait);
+
+  // Writes this op's class as a declaration to the given `os`.  Redefines
+  // Class::writeDeclTo to also emit traits and extra class declarations.
+  void writeDeclTo(raw_ostream &os) const;
+
 private:
-  StringRef className;
   StringRef extraClassDeclaration;
   SmallVector<std::string, 4> traits;
-  SmallVector<OpMethod, 8> methods;
 };
 } // end anonymous namespace
 
@@ -202,7 +222,7 @@ OpMethodSignature::OpMethodSignature(StringRef retType, StringRef name,
     : returnType(retType), methodName(name), parameters(params) {}
 
 void OpMethodSignature::writeDeclTo(raw_ostream &os) const {
-  os << returnType << (endsWithRefOrPtr(returnType) ? "" : " ") << methodName
+  os << returnType << (elideSpaceAfterType(returnType) ? "" : " ") << methodName
      << "(" << parameters << ")";
 }
 
@@ -224,13 +244,13 @@ void OpMethodSignature::writeDefTo(raw_ostream &os,
     return result;
   };
 
-  os << returnType << (endsWithRefOrPtr(returnType) ? "" : " ") << namePrefix
+  os << returnType << (elideSpaceAfterType(returnType) ? "" : " ") << namePrefix
      << (namePrefix.empty() ? "" : "::") << methodName << "("
      << removeParamDefaultValue(parameters) << ")";
 }
 
-bool OpMethodSignature::endsWithRefOrPtr(StringRef type) {
-  return type.endswith("&") || type.endswith("*");
+bool OpMethodSignature::elideSpaceAfterType(StringRef type) {
+  return type.empty() || type.endswith("&") || type.endswith("*");
 }
 
 OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {}
@@ -287,41 +307,71 @@ void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
   os << "}";
 }
 
+Class::Class(StringRef name) : className(name) {}
+
+OpMethod &Class::newMethod(StringRef retType, StringRef name, StringRef params,
+                           OpMethod::Property property, bool declOnly) {
+  methods.emplace_back(retType, name, params, property, declOnly);
+  return methods.back();
+}
+
+OpMethod &Class::newConstructor(StringRef params, bool declOnly) {
+  return newMethod("", getClassName(), params, OpMethod::MP_Constructor,
+                   declOnly);
+}
+
+void Class::newField(StringRef type, StringRef name, StringRef defaultValue) {
+  std::string varName = formatv("{0} {1}", type, name).str();
+  std::string field = defaultValue.empty()
+                          ? varName
+                          : formatv("{0} = {1}", varName, defaultValue).str();
+  fields.push_back(std::move(field));
+}
+
+void Class::writeDeclTo(raw_ostream &os) const {
+  os << "class " << className << " {\n";
+  os << "public:\n";
+  for (const auto &method : methods) {
+    method.writeDeclTo(os);
+    os << '\n';
+  }
+  os << '\n';
+  os << "private:\n";
+  for (const auto &field : fields)
+    os.indent(2) << field << ";\n";
+  os << "};\n";
+}
+
+void Class::writeDefTo(raw_ostream &os) const {
+  for (const auto &method : methods) {
+    method.writeDefTo(os, className);
+    os << "\n\n";
+  }
+}
+
 OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
-    : className(name), extraClassDeclaration(extraClassDeclaration) {}
+    : Class(name), extraClassDeclaration(extraClassDeclaration) {}
 
 // Adds the given trait to this op. Prefixes "OpTrait::" to `trait` implicitly.
 void OpClass::addTrait(Twine trait) {
   traits.push_back(("OpTrait::" + trait).str());
 }
 
-OpMethod &OpClass::newMethod(StringRef retType, StringRef name,
-                             StringRef params, OpMethod::Property property,
-                             bool declOnly) {
-  methods.emplace_back(retType, name, params, property, declOnly);
-  return methods.back();
-}
-
 void OpClass::writeDeclTo(raw_ostream &os) const {
   os << "class " << className << " : public Op<" << className;
   for (const auto &trait : traits)
     os << ", " << trait;
   os << "> {\npublic:\n";
   os << "  using Op::Op;\n";
+  os << "  using OperandAdaptor = " << className << "OperandAdaptor;\n";
   for (const auto &method : methods) {
     method.writeDeclTo(os);
     os << "\n";
   }
   // TODO: Add line control markers to make errors easier to debug.
-  os << extraClassDeclaration << "\n";
-  os << "};";
-}
-
-void OpClass::writeDefTo(raw_ostream &os) const {
-  for (const auto &method : methods) {
-    method.writeDefTo(os, className);
-    os << "\n\n";
-  }
+  if (!extraClassDeclaration.empty())
+    os << extraClassDeclaration << "\n";
+  os << "};\n";
 }
 
 //===----------------------------------------------------------------------===//
@@ -332,11 +382,11 @@ namespace {
 // Helper class to emit a record into the given output stream.
 class OpEmitter {
 public:
-  static void emitDecl(const Record &def, raw_ostream &os);
-  static void emitDef(const Record &def, raw_ostream &os);
+  static void emitDecl(const Operator &op, raw_ostream &os);
+  static void emitDef(const Operator &op, raw_ostream &os);
 
 private:
-  OpEmitter(const Record &def);
+  OpEmitter(const Operator &op);
 
   void emitDecl(raw_ostream &os);
   void emitDef(raw_ostream &os);
@@ -385,6 +435,8 @@ private:
   void genOpNameGetter();
 
   // The TableGen record for this op.
+  // TODO(antiagainst,zinenko): OpEmitter should not have a Record directly,
+  // it should rather go through the Operator for better abstraction.
   const Record &def;
 
   // The wrapper operator class for querying information from this op.
@@ -398,8 +450,8 @@ private:
 };
 } // end anonymous namespace
 
-OpEmitter::OpEmitter(const Record &def)
-    : def(def), op(def),
+OpEmitter::OpEmitter(const Operator &op)
+    : def(op.getDef()), op(op),
       opClass(op.getCppClassName(), op.getExtraClassDeclaration()) {
   verifyCtx.withOp("(*this->getOperation())");
 
@@ -418,22 +470,19 @@ OpEmitter::OpEmitter(const Record &def)
   genFolderDecls();
 }
 
-void OpEmitter::emitDecl(const Record &def, raw_ostream &os) {
-  OpEmitter(def).emitDecl(os);
+void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) {
+  OpEmitter(op).emitDecl(os);
 }
 
-void OpEmitter::emitDef(const Record &def, raw_ostream &os) {
-  OpEmitter(def).emitDef(os);
+void OpEmitter::emitDef(const Operator &op, raw_ostream &os) {
+  OpEmitter(op).emitDef(os);
 }
 
 void OpEmitter::emitDecl(raw_ostream &os) {
-  os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
   opClass.writeDeclTo(os);
 }
 
 void OpEmitter::emitDef(raw_ostream &os) {
-  os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
-
   opClass.writeDefTo(os);
 }
 
@@ -480,7 +529,21 @@ void OpEmitter::genAttrGetters() {
   }
 }
 
-void OpEmitter::genNamedOperandGetters() {
+// Generates the named operand getter methods for the given Operator `op` and
+// puts them in `opClass`.  Uses `rangeType` as the return type of getters that
+// return a range of operands (individual operands are `Value *` and each
+// element in the range must also be `Value *`); use `rangeBeginCall` to get an
+// iterator to the beginning of the operand range; use `rangeSizeCall` to obtain
+// the number of operands. `getOperandCallPattern` contains the code necessary
+// to obtain a single operand whose position will be substituted instead of
+// "{0}" marker in the pattern.  Note that the pattern should work for any kind
+// of ops, in particular for one-operand ops that may not have the
+// `getOperand(unsigned)` method.
+static void generateNamedOperandGetters(const Operator &op, Class &opClass,
+                                        StringRef rangeType,
+                                        StringRef rangeBeginCall,
+                                        StringRef rangeSizeCall,
+                                        StringRef getOperandCallPattern) {
   const int numOperands = op.getNumOperands();
   const int numVariadicOperands = op.getNumVariadicOperands();
   const int numNormalOperands = numOperands - numVariadicOperands;
@@ -499,20 +562,22 @@ void OpEmitter::genNamedOperandGetters() {
         continue;
 
       if (operand.isVariadic()) {
-        auto &m = opClass.newMethod("Operation::operand_range", operand.name);
+        auto &m = opClass.newMethod(rangeType, operand.name);
         m.body() << formatv(
-            "  return {{std::next(operand_begin(), {0}), "
-            "std::next(operand_begin(), {0} + this->getNumOperands() - {1})};",
-            i, numNormalOperands);
+            "  return {{std::next({2}, {0}), std::next({2}, {0} + {3} - {1})};",
+            i, numNormalOperands, rangeBeginCall, rangeSizeCall);
         emittedVariadicOperand = true;
       } else {
         auto &m = opClass.newMethod("Value *", operand.name);
-        m.body() << "  return this->getOperation()->getOperand(";
-        if (emittedVariadicOperand)
-          m.body() << "this->getNumOperands() - " << numOperands - i;
-        else
-          m.body() << i;
-        m.body() << ");\n";
+
+        auto operandIndex =
+            emittedVariadicOperand
+                ? formatv("{0} - {1}", rangeSizeCall, numOperands - i).str()
+                : std::to_string(i);
+
+        m.body() << "  return "
+                 << formatv(getOperandCallPattern.data(), operandIndex)
+                 << ";\n";
       }
     }
     return;
@@ -535,27 +600,36 @@ void OpEmitter::genNamedOperandGetters() {
       continue;
 
     const char *code = R"(
-  int variadicOperandSize = (this->getNumOperands() - {0}) / {1};
+  int variadicOperandSize = ({4} - {0}) / {1};
   int offset = {2} + variadicOperandSize * {3};
   return )";
     auto sizeAndOffset =
         formatv(code, numNormalOperands, numVariadicOperands,
-                emittedNormalOperands, emittedVariadicOperands);
+                emittedNormalOperands, emittedVariadicOperands, rangeSizeCall);
 
     if (operand.isVariadic()) {
-      auto &m = opClass.newMethod("Operation::operand_range", operand.name);
-      m.body() << sizeAndOffset
-               << "{std::next(operand_begin(), offset), "
-                  "std::next(operand_begin(), offset + variadicOperandSize)};";
+      auto &m = opClass.newMethod(rangeType, operand.name);
+      m.body() << sizeAndOffset << "{std::next(" << rangeBeginCall
+               << ", offset), std::next(" << rangeBeginCall
+               << ", offset + variadicOperandSize)};";
       ++emittedVariadicOperands;
     } else {
       auto &m = opClass.newMethod("Value *", operand.name);
-      m.body() << sizeAndOffset << "this->getOperand(offset);";
+      m.body() << sizeAndOffset
+               << formatv(getOperandCallPattern.data(), "offset") << ";";
       ++emittedNormalOperands;
     }
   }
 }
 
+void OpEmitter::genNamedOperandGetters() {
+  generateNamedOperandGetters(
+      op, opClass, /*rangeType=*/"Operation::operand_range",
+      /*rangeBeginCall=*/"operand_begin()",
+      /*rangeSizeCall=*/"this->getNumOperands()",
+      /*getOperandCallPattern=*/"this->getOperation()->getOperand({0})");
+}
+
 void OpEmitter::genNamedResultGetters() {
   const int numResults = op.getNumResults();
   const int numVariadicResults = op.getNumVariadicResults();
@@ -1103,15 +1177,61 @@ void OpEmitter::genOpNameGetter() {
   method.body() << "  return \"" << op.getOperationName() << "\";\n";
 }
 
+//===----------------------------------------------------------------------===//
+// OpOperandAdaptor emitter
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Helper class to emit Op operand adaptors to an output stream.  Operand
+// adaptors are wrappers around ArrayRef<Value *> that provide named operand
+// getters identical to those defined in the Op.
+class OpOperandAdaptorEmitter {
+public:
+  static void emitDecl(const Operator &op, raw_ostream &os);
+  static void emitDef(const Operator &op, raw_ostream &os);
+
+private:
+  explicit OpOperandAdaptorEmitter(const Operator &op);
+
+  Class adapterClass;
+};
+} // end namespace
+
+OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
+    : adapterClass(op.getCppClassName().str() + "OperandAdaptor") {
+  adapterClass.newField("ArrayRef<Value *>", "tblgen_operands");
+  auto &constructor = adapterClass.newConstructor("ArrayRef<Value *> values");
+  constructor.body() << "  tblgen_operands = values;\n";
+
+  generateNamedOperandGetters(op, adapterClass,
+                              /*rangeType=*/"ArrayRef<Value *>",
+                              /*rangeBeginCall=*/"tblgen_operands.begin()",
+                              /*rangeSizeCall=*/"tblgen_operands.size()",
+                              /*getOperandCallPattern=*/"tblgen_operands[{0}]");
+}
+
+void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
+  OpOperandAdaptorEmitter(op).adapterClass.writeDeclTo(os);
+}
+
+void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) {
+  OpOperandAdaptorEmitter(op).adapterClass.writeDefTo(os);
+}
+
 // Emits the opcode enum and op classes.
 static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os,
                           bool emitDecl) {
   IfDefScope scope("GET_OP_CLASSES", os);
   for (auto *def : defs) {
+    Operator op(*def);
     if (emitDecl) {
-      OpEmitter::emitDecl(*def, os);
+      os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
+      OpOperandAdaptorEmitter::emitDecl(op, os);
+      OpEmitter::emitDecl(op, os);
     } else {
-      OpEmitter::emitDef(*def, os);
+      os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
+      OpOperandAdaptorEmitter::emitDef(op, os);
+      OpEmitter::emitDef(op, os);
     }
   }
 }