Add operand type iterators to Operation and cleanup usages of operand->getType...
authorRiver Riddle <riverriddle@google.com>
Fri, 24 May 2019 20:28:55 +0000 (13:28 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:00:43 +0000 (20:00 -0700)
--

PiperOrigin-RevId: 249889174

14 files changed:
mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
mlir/include/mlir/GPU/GPUDialect.h
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/Operation.h
mlir/lib/GPU/IR/GPUDialect.cpp
mlir/lib/GPU/Transforms/KernelOutlining.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/lib/StandardOps/Ops.cpp

index adfa1f7..c9d52de 100644 (file)
@@ -97,7 +97,7 @@ class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
   void rewrite(Operation *op, ArrayRef<Value *> operands,
                PatternRewriter &rewriter) const override {
     edsc::ScopedContext edscContext(rewriter, op->getLoc());
-    auto elementType = linalg::convertLinalgType(*op->getResultTypes().begin());
+    auto elementType = linalg::convertLinalgType(*op->result_type_begin());
     Value *viewDescriptor = operands[0];
     ArrayRef<Value *> indices = operands.drop_front();
     Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
index 907e3f1..cc0d7f9 100644 (file)
@@ -226,8 +226,8 @@ public:
       // Find the next operation ready for inference, that is an operation
       // with all operands already resolved (non-generic).
       auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) {
-        return llvm::all_of(op->getOperands(), [](mlir::Value *v) {
-          return !v->getType().cast<ToyArrayType>().isGeneric();
+        return llvm::all_of(op->getOperandTypes(), [](mlir::Type ty) {
+          return !ty.cast<ToyArrayType>().isGeneric();
         });
       });
       if (nextop == opWorklist.end())
@@ -308,9 +308,8 @@ public:
         if (!mangledCallee) {
           // Can't find the target, this is where we queue the request for the
           // callee and stop the inference for the current function now.
-          std::vector<mlir::Type> funcArgs;
-          for (auto operand : op->getOperands())
-            funcArgs.push_back(operand->getType());
+          std::vector<mlir::Type> funcArgs(op->operand_type_begin(),
+                                           op->operand_type_end());
           funcWorklist.push_back(
               {callee, std::move(mangledName), std::move(funcArgs)});
           return mlir::success();
index 5267586..3bd5d26 100644 (file)
@@ -226,8 +226,8 @@ public:
       // Find the next operation ready for inference, that is an operation
       // with all operands already resolved (non-generic).
       auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) {
-        return llvm::all_of(op->getOperands(), [](mlir::Value *v) {
-          return !v->getType().cast<ToyArrayType>().isGeneric();
+        return llvm::all_of(op->getOperandTypes(), [](mlir::Type ty) {
+          return !ty.cast<ToyArrayType>().isGeneric();
         });
       });
       if (nextop == opWorklist.end())
@@ -312,18 +312,15 @@ public:
         if (!mangledCallee) {
           // Can't find the target, this is where we queue the request for the
           // callee and stop the inference for the current function now.
-          std::vector<mlir::Type> funcArgs;
-          for (auto operand : op->getOperands())
-            funcArgs.push_back(operand->getType());
+          std::vector<mlir::Type> funcArgs(op->operand_type_begin(),
+                                           op->operand_type_end());
           funcWorklist.push_back(
               {callee, std::move(mangledName), std::move(funcArgs)});
           return mlir::success();
         }
         // Found a specialized callee! Let's turn this into a normal call
         // operation.
-        SmallVector<mlir::Value *, 8> operands;
-        for (mlir::Value *v : op->getOperands())
-          operands.push_back(v);
+        SmallVector<mlir::Value *, 8> operands(op->getOperands());
         mlir::FuncBuilder builder(f);
         builder.setInsertionPoint(op);
         auto newCall =
index 8ace6ff..003336b 100644 (file)
@@ -78,9 +78,9 @@ public:
   /// Get the SSA values corresponding to kernel block size.
   KernelDim3 getBlockSize();
   /// Get the operand values passed as kernel arguments.
-  Operation::operand_range getKernelOperandValues();
-  /// Append the operand types passed as kernel arguments to `out`.
-  void getKernelOperandTypes(SmallVectorImpl<Type> &out);
+  operand_range getKernelOperandValues();
+  /// Get the operand types passed as kernel arguments.
+  operand_type_range getKernelOperandTypes();
 
   /// Get the SSA values passed as operands to specify the grid size.
   KernelDim3 getGridSizeOperandValues();
index 472441e..58d19d7 100644 (file)
@@ -328,6 +328,8 @@ template <typename ConcreteType, template <typename> class TraitType>
 struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
   using operand_iterator = Operation::operand_iterator;
   using operand_range = Operation::operand_range;
+  using operand_type_iterator = Operation::operand_type_iterator;
+  using operand_type_range = Operation::operand_type_range;
 
   /// Return the number of operands.
   unsigned getNumOperands() { return this->getOperation()->getNumOperands(); }
@@ -346,6 +348,17 @@ struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
   }
   operand_iterator operand_end() { return this->getOperation()->operand_end(); }
   operand_range getOperands() { return this->getOperation()->getOperands(); }
+
+  /// Operand type access.
+  operand_type_iterator operand_type_begin() {
+    return this->getOperation()->operand_type_begin();
+  }
+  operand_type_iterator operand_type_end() {
+    return this->getOperation()->operand_type_end();
+  }
+  operand_type_range getOperandTypes() {
+    return this->getOperation()->getOperandTypes();
+  }
 };
 } // end namespace detail
 
@@ -447,6 +460,8 @@ template <typename ConcreteType, template <typename> class TraitType>
 struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
   using result_iterator = Operation::result_iterator;
   using result_range = Operation::result_range;
+  using result_type_iterator = Operation::result_type_iterator;
+  using result_type_range = Operation::result_type_range;
 
   /// Return the number of results.
   unsigned getNumResults() { return this->getOperation()->getNumResults(); }
@@ -468,6 +483,17 @@ struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
   }
   result_iterator result_end() { return this->getOperation()->result_end(); }
   result_range getResults() { return this->getOperation()->getResults(); }
+
+  /// Result type access.
+  result_type_iterator result_type_begin() {
+    return this->getOperation()->result_type_begin();
+  }
+  result_type_iterator result_type_end() {
+    return this->getOperation()->result_type_end();
+  }
+  result_type_range getResultTypes() {
+    return this->getOperation()->getResultTypes();
+  }
 };
 } // end namespace detail
 
@@ -477,7 +503,6 @@ template <typename ConcreteType>
 class OneResult : public TraitBase<ConcreteType, OneResult> {
 public:
   Value *getResult() { return this->getOperation()->getResult(0); }
-
   Type getType() { return getResult()->getType(); }
 
   /// Replace all uses of 'this' value with the new value, updating anything in
index 107bfb8..0a7a2aa 100644 (file)
@@ -34,6 +34,7 @@ class BlockAndValueMapping;
 class Location;
 class MLIRContext;
 class OperandIterator;
+class OperandTypeIterator;
 struct OperationState;
 class ResultIterator;
 class ResultTypeIterator;
@@ -198,6 +199,13 @@ public:
 
   OpOperand &getOpOperand(unsigned idx) { return getOpOperands()[idx]; }
 
+  // Support operand type iteration.
+  using operand_type_iterator = OperandTypeIterator;
+  using operand_type_range = llvm::iterator_range<operand_type_iterator>;
+  operand_type_iterator operand_type_begin();
+  operand_type_iterator operand_type_end();
+  operand_type_range getOperandTypes();
+
   //===--------------------------------------------------------------------===//
   // Results
   //===--------------------------------------------------------------------===//
@@ -226,9 +234,10 @@ public:
 
   // Support result type iteration.
   using result_type_iterator = ResultTypeIterator;
+  using result_type_range = llvm::iterator_range<result_type_iterator>;
   result_type_iterator result_type_begin();
   result_type_iterator result_type_end();
-  llvm::iterator_range<result_type_iterator> getResultTypes();
+  result_type_range getResultTypes();
 
   //===--------------------------------------------------------------------===//
   // Attributes
@@ -500,6 +509,19 @@ public:
   Value *operator*() const { return this->object->getOperand(this->index); }
 };
 
+/// This class implements the operand type iterators for the Operation
+/// class in terms of operand_iterator->getType().
+class OperandTypeIterator final
+    : public llvm::mapped_iterator<OperandIterator, Type (*)(Value *)> {
+  static Type unwrap(Value *value) { return value->getType(); }
+
+public:
+  /// Initializes the operand type iterator to the specified operand iterator.
+  OperandTypeIterator(OperandIterator it)
+      : llvm::mapped_iterator<OperandIterator, Type (*)(Value *)>(it, &unwrap) {
+  }
+};
+
 // Implement the inline operand iterator methods.
 inline auto Operation::operand_begin() -> operand_iterator {
   return operand_iterator(this, 0);
@@ -513,6 +535,18 @@ inline auto Operation::getOperands() -> operand_range {
   return {operand_begin(), operand_end()};
 }
 
+inline auto Operation::operand_type_begin() -> operand_type_iterator {
+  return operand_type_iterator(operand_begin());
+}
+
+inline auto Operation::operand_type_end() -> operand_type_iterator {
+  return operand_type_iterator(operand_end());
+}
+
+inline auto Operation::getOperandTypes() -> operand_type_range {
+  return {operand_type_begin(), operand_type_end()};
+}
+
 /// This class implements the result iterators for the Operation class
 /// in terms of getResult(idx).
 class ResultIterator final
@@ -559,8 +593,7 @@ inline auto Operation::result_type_end() -> result_type_iterator {
   return result_type_iterator(result_end());
 }
 
-inline auto Operation::getResultTypes()
-    -> llvm::iterator_range<result_type_iterator> {
+inline auto Operation::getResultTypes() -> result_type_range {
   return {result_type_begin(), result_type_end()};
 }
 
index 755a2c2..5c0539a 100644 (file)
@@ -99,16 +99,12 @@ KernelDim3 LaunchOp::getBlockSize() {
   return KernelDim3{args[9], args[10], args[11]};
 }
 
-Operation::operand_range LaunchOp::getKernelOperandValues() {
-  return {getOperation()->operand_begin() + kNumConfigOperands,
-          getOperation()->operand_end()};
+LaunchOp::operand_range LaunchOp::getKernelOperandValues() {
+  return llvm::drop_begin(getOperands(), kNumConfigOperands);
 }
 
-void LaunchOp::getKernelOperandTypes(SmallVectorImpl<Type> &out) {
-  out.reserve(getNumOperands() - kNumConfigOperands + out.size());
-  for (unsigned i = kNumConfigOperands; i < getNumOperands(); ++i) {
-    out.push_back(getOperand(i)->getType());
-  }
+LaunchOp::operand_type_range LaunchOp::getKernelOperandTypes() {
+  return llvm::drop_begin(getOperandTypes(), kNumConfigOperands);
 }
 
 KernelDim3 LaunchOp::getGridSizeOperandValues() {
index 006ba4f..163a7cf 100644 (file)
@@ -60,8 +60,7 @@ void injectGpuIndexOperations(Location loc, Function &kernelFunc) {
 // Outline the `gpu.launch` operation body into a kernel function.
 Function *outlineKernelFunc(Module &module, gpu::LaunchOp &launchOp) {
   Location loc = launchOp.getLoc();
-  SmallVector<Type, 4> kernelOperandTypes;
-  launchOp.getKernelOperandTypes(kernelOperandTypes);
+  SmallVector<Type, 4> kernelOperandTypes(launchOp.getKernelOperandTypes());
   FunctionType type =
       FunctionType::get(kernelOperandTypes, {}, module.getContext());
   std::string kernelFuncName =
index bc46b45..71744cf 100644 (file)
@@ -220,10 +220,10 @@ void ModuleState::visitAttribute(Attribute attr) {
 
 void ModuleState::visitOperation(Operation *op) {
   // Visit all the types used in the operation.
-  for (auto *operand : op->getOperands())
-    visitType(operand->getType());
-  for (auto *result : op->getResults())
-    visitType(result->getType());
+  for (auto type : op->getOperandTypes())
+    visitType(type);
+  for (auto type : op->getResultTypes())
+    visitType(type);
 
   // Visit each of the attributes.
   for (auto elt : op->getAttrs())
index 22463f1..582fb39 100644 (file)
@@ -593,11 +593,7 @@ Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper,
     }
   }
 
-  SmallVector<Type, 8> resultTypes;
-  resultTypes.reserve(getNumResults());
-  for (auto *result : getResults())
-    resultTypes.push_back(result->getType());
-
+  SmallVector<Type, 8> resultTypes(getResultTypes());
   unsigned numRegions = getNumRegions();
   auto *newOp = Operation::create(getLoc(), getName(), operands, resultTypes,
                                   attrs, successors, numRegions,
@@ -718,8 +714,8 @@ static Type getTensorOrVectorElementType(Type type) {
 }
 
 LogicalResult OpTrait::impl::verifyOperandsAreIntegerLike(Operation *op) {
-  for (auto *operand : op->getOperands()) {
-    auto type = getTensorOrVectorElementType(operand->getType());
+  for (auto opType : op->getOperandTypes()) {
+    auto type = getTensorOrVectorElementType(opType);
     if (!type.isIntOrIndex())
       return op->emitOpError() << "requires an integer or index type";
   }
@@ -727,8 +723,8 @@ LogicalResult OpTrait::impl::verifyOperandsAreIntegerLike(Operation *op) {
 }
 
 LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) {
-  for (auto *operand : op->getOperands()) {
-    auto type = getTensorOrVectorElementType(operand->getType());
+  for (auto opType : op->getOperandTypes()) {
+    auto type = getTensorOrVectorElementType(opType);
     if (!type.isa<FloatType>())
       return op->emitOpError("requires a float type");
   }
@@ -742,8 +738,8 @@ LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) {
     return success();
 
   auto type = op->getOperand(0)->getType();
-  for (unsigned i = 1; i < nOperands; ++i)
-    if (op->getOperand(i)->getType() != type)
+  for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1))
+    if (opType != type)
       return op->emitOpError() << "requires all operands to have the same type";
   return success();
 }
@@ -798,13 +794,13 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {
     return failure();
 
   auto type = op->getOperand(0)->getType();
-  for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) {
-    if (failed(verifyShapeMatch(op->getResult(i)->getType(), type)))
+  for (auto resultType : op->getResultTypes()) {
+    if (failed(verifyShapeMatch(resultType, type)))
       return op->emitOpError()
              << "requires the same shape for all operands and results";
   }
-  for (unsigned i = 1, e = op->getNumOperands(); i < e; ++i) {
-    if (failed(verifyShapeMatch(op->getOperand(i)->getType(), type)))
+  for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
+    if (failed(verifyShapeMatch(opType, type)))
       return op->emitOpError()
              << "requires the same shape for all operands and results";
   }
@@ -849,13 +845,13 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
     return failure();
 
   auto type = op->getResult(0)->getType();
-  for (unsigned i = 1, e = op->getNumResults(); i < e; ++i) {
-    if (op->getResult(i)->getType() != type)
+  for (auto resultType : llvm::drop_begin(op->getResultTypes(), 1)) {
+    if (resultType != type)
       return op->emitOpError()
              << "requires the same type for all operands and results";
   }
-  for (unsigned i = 0, e = op->getNumOperands(); i < e; ++i) {
-    if (op->getOperand(i)->getType() != type)
+  for (auto opType : op->getOperandTypes()) {
+    if (opType != type)
       return op->emitOpError()
              << "requires the same type for all operands and results";
   }
@@ -905,8 +901,8 @@ LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) {
 }
 
 LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) {
-  for (auto *result : op->getResults()) {
-    auto elementType = getTensorOrVectorElementType(result->getType());
+  for (auto resultType : op->getResultTypes()) {
+    auto elementType = getTensorOrVectorElementType(resultType);
     bool isBoolType = elementType.isInteger(1);
     if (!isBoolType)
       return op->emitOpError() << "requires a bool result type";
@@ -916,19 +912,17 @@ LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) {
 }
 
 LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) {
-  for (auto *result : op->getResults())
-    if (!getTensorOrVectorElementType(result->getType()).isa<FloatType>())
+  for (auto resultType : op->getResultTypes())
+    if (!getTensorOrVectorElementType(resultType).isa<FloatType>())
       return op->emitOpError() << "requires a floating point type";
 
   return success();
 }
 
 LogicalResult OpTrait::impl::verifyResultsAreIntegerLike(Operation *op) {
-  for (auto *result : op->getResults()) {
-    auto type = getTensorOrVectorElementType(result->getType());
-    if (!type.isIntOrIndex())
+  for (auto resultType : op->getResultTypes())
+    if (!getTensorOrVectorElementType(resultType).isIntOrIndex())
       return op->emitOpError() << "requires an integer or index type";
-  }
   return success();
 }
 
index 8b673f3..8490fda 100644 (file)
@@ -177,11 +177,8 @@ static ParseResult parseAllocaOp(OpAsmParser *parser, OperationState *result) {
 //===----------------------------------------------------------------------===//
 
 static void printGEPOp(OpAsmPrinter *p, GEPOp &op) {
-  SmallVector<Type, 8> types;
-  for (auto *operand : op.getOperands())
-    types.push_back(operand->getType());
-  auto funcTy =
-      FunctionType::get(types, op.getResult()->getType(), op.getContext());
+  SmallVector<Type, 8> types(op.getOperandTypes());
+  auto funcTy = FunctionType::get(types, op.getType(), op.getContext());
 
   *p << op.getOperationName() << ' ' << *op.base() << '[';
   p->printOperands(std::next(op.operand_begin()), op.operand_end());
@@ -326,11 +323,9 @@ static void printCallOp(OpAsmPrinter *p, CallOp &op) {
   p->printOptionalAttrDict(op.getAttrs(), {"callee"});
 
   // Reconstruct the function MLIR function type from operand and result types.
-  SmallVector<Type, 1> resultTypes(op.getOperation()->getResultTypes());
-  SmallVector<Type, 8> argTypes;
-  argTypes.reserve(op.getNumOperands());
-  for (auto *operand : llvm::drop_begin(op.getOperands(), isDirect ? 0 : 1))
-    argTypes.push_back(operand->getType());
+  SmallVector<Type, 1> resultTypes(op.getResultTypes());
+  SmallVector<Type, 8> argTypes(
+      llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1));
 
   *p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext());
 }
index acf0be3..4354f82 100644 (file)
@@ -262,17 +262,6 @@ protected:
   LLVM::LLVMDialect &dialect;
 };
 
-// Given a range of MLIR typed objects, return a list of their types.
-template <typename T>
-SmallVector<Type, 4> getTypes(llvm::iterator_range<T> range) {
-  SmallVector<Type, 4> types;
-  types.reserve(llvm::size(range));
-  for (auto operand : range) {
-    types.push_back(operand->getType());
-  }
-  return types;
-}
-
 // Basic lowering implementation for one-to-one rewriting from Standard Ops to
 // LLVM Dialect Ops.
 template <typename SourceOp, typename TargetOp>
@@ -288,8 +277,8 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
 
     Type packedType;
     if (numResults != 0) {
-      packedType =
-          this->lowering.packFunctionResults(getTypes(op->getResults()));
+      packedType = this->lowering.packFunctionResults(
+          llvm::to_vector<4>(op->getResultTypes()));
       assert(packedType && "type conversion failed, such operation should not "
                            "have been matched");
     }
@@ -832,7 +821,8 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
 
     // Otherwise, we need to pack the arguments into an LLVM struct type before
     // returning.
-    auto packedType = lowering.packFunctionResults(getTypes(op->getOperands()));
+    auto packedType =
+        lowering.packFunctionResults(llvm::to_vector<4>(op->getOperandTypes()));
 
     Value *packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType);
     for (unsigned i = 0; i < numArguments; ++i) {
index cf8a2cc..f94868d 100644 (file)
@@ -316,7 +316,7 @@ class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
   void rewrite(Operation *op, ArrayRef<Value *> operands,
                PatternRewriter &rewriter) const override {
     edsc::ScopedContext edscContext(rewriter, op->getLoc());
-    auto elementTy = lowering.convertType(*op->getResultTypes().begin());
+    auto elementTy = lowering.convertType(*op->result_type_begin());
     Value *viewDescriptor = operands[0];
     ArrayRef<Value *> indices = operands.drop_front();
     auto ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
index f05b0cf..fd9d57f 100644 (file)
@@ -284,8 +284,8 @@ static LogicalResult verify(AllocOp op) {
         "operand count does not equal dimension plus symbol operand count");
 
   // Verify that all operands are of type Index.
-  for (auto *operand : op.getOperands())
-    if (!operand->getType().isIndex())
+  for (auto operandType : op.getOperandTypes())
+    if (!operandType.isIndex())
       return op.emitOpError("requires operands to be of type Index");
   return success();
 }
@@ -475,11 +475,8 @@ static LogicalResult verify(CallOp op) {
 }
 
 FunctionType CallOp::getCalleeType() {
-  SmallVector<Type, 4> resultTypes(getOperation()->getResultTypes());
-  SmallVector<Type, 8> argTypes;
-  argTypes.reserve(getNumOperands());
-  for (auto *operand : getArgOperands())
-    argTypes.push_back(operand->getType());
+  SmallVector<Type, 4> resultTypes(getResultTypes());
+  SmallVector<Type, 8> argTypes(getOperandTypes());
   return FunctionType::get(argTypes, resultTypes, getContext());
 }