Move the definitions of BranchOp, DimOp, and ExtractElementOp to Op Definition...
authorRiver Riddle <riverriddle@google.com>
Fri, 10 May 2019 22:26:23 +0000 (15:26 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sat, 11 May 2019 02:30:31 +0000 (19:30 -0700)
--

PiperOrigin-RevId: 247686212

mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/StandardOps/Ops.h
mlir/include/mlir/StandardOps/Ops.td
mlir/lib/StandardOps/Ops.cpp
mlir/test/IR/invalid-ops.mlir

index 7e90f4a..b4b159e 100644 (file)
@@ -340,6 +340,8 @@ class Vector<Type t, list<int> dims> : ContainerType<t, AllOf<[
   list<int> dimensions = dims;
 }
 
+def VectorOrTensor : Type<IsVectorOrTensorTypePred, "vector or tensor">;
+
 // Tensor type.
 
 // This represents a generic tensor without constraints on elemental type,
index 81d3614..7f3e8ab 100644 (file)
@@ -46,39 +46,6 @@ public:
 #define GET_OP_CLASSES
 #include "mlir/StandardOps/Ops.h.inc"
 
-/// The "br" operation represents a branch operation in a function.
-/// The operation takes variable number of operands and produces no results.
-/// The operand number and types for each successor must match the
-/// arguments of the block successor. For example:
-///
-///   ^bb2:
-///      %2 = call @someFn()
-///      br ^bb3(%2 : tensor<*xf32>)
-///   ^bb3(%3: tensor<*xf32>):
-///
-class BranchOp : public Op<BranchOp, OpTrait::VariadicOperands,
-                           OpTrait::ZeroResult, OpTrait::IsTerminator> {
-public:
-  friend Operation;
-  using Op::Op;
-
-  static StringRef getOperationName() { return "std.br"; }
-
-  static void build(Builder *builder, OperationState *result, Block *dest,
-                    ArrayRef<Value *> operands = {});
-
-  // Hooks to customize behavior of this op.
-  static ParseResult parse(OpAsmParser *parser, OperationState *result);
-  void print(OpAsmPrinter *p);
-
-  /// Return the block this branch jumps to.
-  Block *getDest();
-  void setDest(Block *block);
-
-  /// Erase the operand at 'index' from the operand list.
-  void eraseOperand(unsigned index);
-};
-
 /// The "call" operation represents a direct call to a function.  The operands
 /// and result types of the call must match the specified function type.  The
 /// callee is encoded as a function attribute named "callee".
@@ -457,36 +424,6 @@ public:
   static bool isClassFor(Operation *op);
 };
 
-/// The "dim" operation takes a memref or tensor operand and returns an
-/// "index".  It requires a single integer attribute named "index".  It
-/// returns the size of the specified dimension.  For example:
-///
-///   %1 = dim %0, 2 : tensor<?x?x?xf32>
-///
-class DimOp : public Op<DimOp, OpTrait::OneOperand, OpTrait::OneResult,
-                        OpTrait::HasNoSideEffect> {
-public:
-  friend Operation;
-  using Op::Op;
-
-  static void build(Builder *builder, OperationState *result,
-                    Value *memrefOrTensor, unsigned index);
-
-  Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
-
-  /// This returns the dimension number that the 'dim' is inspecting.
-  unsigned getIndex() {
-    return getAttrOfType<IntegerAttr>("index").getValue().getZExtValue();
-  }
-
-  static StringRef getOperationName() { return "std.dim"; }
-
-  // Hooks to customize behavior of this op.
-  LogicalResult verify();
-  static ParseResult parse(OpAsmParser *parser, OperationState *result);
-  void print(OpAsmPrinter *p);
-};
-
 // DmaStartOp starts a non-blocking DMA operation that transfers data from a
 // source memref to a destination memref. The source and destination memref need
 // not be of the same dimensionality, but need to have the same elemental type.
@@ -684,42 +621,6 @@ public:
                                           MLIRContext *context);
 };
 
-/// The "extract_element" op reads a tensor or vector and returns one element
-/// from it specified by an index list. The output of extract is a new value
-/// with the same type as the elements of the tensor or vector. The arity of
-/// indices matches the rank of the accessed value (i.e., if a tensor is of rank
-/// 3, then 3 indices are required for the extract).  The indices should all be
-/// of affine_int type.
-///
-/// For example:
-///
-///   %3 = extract_element %0[%1, %2] : vector<4x4xi32>
-///
-class ExtractElementOp
-    : public Op<ExtractElementOp, OpTrait::VariadicOperands, OpTrait::OneResult,
-                OpTrait::HasNoSideEffect> {
-public:
-  friend Operation;
-  using Op::Op;
-
-  static void build(Builder *builder, OperationState *result, Value *aggregate,
-                    ArrayRef<Value *> indices = {});
-
-  Value *getAggregate() { return getOperand(0); }
-
-  operand_range getIndices() {
-    return {getOperation()->operand_begin() + 1, getOperation()->operand_end()};
-  }
-
-  static StringRef getOperationName() { return "std.extract_element"; }
-
-  // Hooks to customize behavior of this op.
-  LogicalResult verify();
-  static ParseResult parse(OpAsmParser *parser, OperationState *result);
-  void print(OpAsmPrinter *p);
-  Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
-};
-
 /// The "load" op reads an element from a memref specified by an index list. The
 /// output of load is a new value with the same type as the elements of the
 /// memref. The arity of indices is the rank of the memref (i.e., if the memref
index 237730c..cfdbf1d 100644 (file)
@@ -134,6 +134,40 @@ def AndOp : IntArithmeticOp<"and", [Commutative]> {
   let hasFolder = 1;
 }
 
+def BranchOp : Op<Standard_Dialect, "br", [Terminator]> {
+  let summary = "branch operation";
+  let description = [{
+    The "br" operation represents a branch operation in a function.
+    The operation takes variable number of operands and produces no results.
+    The operand number and types for each successor must match the arguments of
+    the block successor. For example:
+
+      ^bb2:
+        %2 = call @someFn()
+        br ^bb3(%2 : tensor<*xf32>)
+      ^bb3(%3: tensor<*xf32>):
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$operands);
+
+  let parser = [{ return parseBranchOp(parser, result); }];
+  let printer = [{ return printBranchOp(p, *this); }];
+
+  let builders = [OpBuilder<
+    "Builder *, OperationState *result, Block *dest,"
+    "ArrayRef<Value *> operands = {}", [{
+      result->addSuccessor(dest, operands);
+  }]>];
+
+  let extraClassDeclaration = [{
+    Block *getDest();
+    void setDest(Block *block);
+
+    /// Erase the operand at 'index' from the operand list.
+    void eraseOperand(unsigned index);
+  }];
+}
+
 def ConstantOp : Op<Standard_Dialect, "constant", [NoSideEffect]> {
   let summary = "constant";
 
@@ -177,6 +211,42 @@ def DeallocOp : Op<Standard_Dialect, "dealloc"> {
   let hasCanonicalizer = 0b1;
 }
 
+def DimOp : Op<Standard_Dialect, "dim", [NoSideEffect]> {
+  let summary = "dimension index operation";
+  let description = [{
+    The "dim" operation takes a memref or tensor operand and returns an "index".
+    It requires a single integer attribute named "index". It returns the size
+    of the specified dimension. For example:
+
+      %1 = dim %0, 2 : tensor<?x?x?xf32>
+  }];
+
+  let arguments = (ins AnyTypeOf<[MemRef<AnyType>, Tensor],
+                                 "any tensor or memref type">:$memrefOrTensor,
+                       APIntAttr:$index);
+  let results = (outs Index);
+
+  let parser = [{ return parseDimOp(parser, result); }];
+  let printer = [{ return printDimOp(p, *this); }];
+  let verifier = [{ return ::verify(*this); }];
+
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, Value *memrefOrTensor,"
+    "unsigned index", [{
+      auto indexType = builder->getIndexType();
+      auto indexAttr = builder->getIntegerAttr(indexType, index);
+      build(builder, result, indexType, memrefOrTensor, indexAttr);
+    }]>];
+
+  let extraClassDeclaration = [{
+    unsigned getIndex() {
+      return getAttrOfType<IntegerAttr>("index").getValue().getZExtValue();
+    }
+  }];
+
+  let hasConstantFolder = 0b1;
+}
+
 def DivFOp : FloatArithmeticOp<"divf"> {
   let summary = "floating point division operation";
 }
@@ -191,6 +261,47 @@ def DivIUOp : IntArithmeticOp<"diviu"> {
   let hasConstantFolder = 0b1;
 }
 
+def ExtractElementOp : Op<Standard_Dialect, "extract_element", [NoSideEffect]> {
+  let summary = "element extract operation";
+  let description = [{
+    The "extract_element" op reads a tensor or vector and returns one element
+    from it specified by an index list. The output of extract is a new value
+    with the same type as the elements of the tensor or vector. The arity of
+    indices matches the rank of the accessed value (i.e., if a tensor is of rank
+    3, then 3 indices are required for the extract).  The indices should all be
+    of affine_int type. For example:
+
+      %0 = extract_element %0[%1, %2] : vector<4x4xi32>
+  }];
+
+  let arguments = (ins VectorOrTensor:$aggregate,
+                       Variadic<Index>:$indices);
+  let results = (outs AnyType);
+
+  let parser = [{ return parseExtractElementOp(parser, result); }];
+  let printer = [{ return printExtractElementOp(p, *this); }];
+  let verifier = [{ return ::verify(*this); }];
+
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, Value *aggregate,"
+    "ArrayRef<Value *> indices = {}", [{
+      auto resType = aggregate->getType().cast<VectorOrTensorType>()
+                                         .getElementType();
+      build(builder, result, resType, aggregate, indices);
+    }]>];
+
+  let extraClassDeclaration = [{
+    Value *getAggregate() { return getOperand(0); }
+
+    operand_range getIndices() {
+      return {getOperation()->operand_begin() + 1,
+              getOperation()->operand_end()};
+    }
+  }];
+
+  let hasConstantFolder = 0b1;
+}
+
 def MulFOp : FloatArithmeticOp<"mulf"> {
   let summary = "foating point multiplication operation";
   let hasConstantFolder = 0b1;
index f9b13ce..9ad37fd 100644 (file)
@@ -61,9 +61,9 @@ void detail::printStandardBinaryOp(Operation *op, OpAsmPrinter *p) {
 
 StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
     : Dialect(/*name=*/"std", context) {
-  addOperations<BranchOp, CallOp, CallIndirectOp, CmpFOp, CmpIOp, CondBranchOp,
-                DimOp, DmaStartOp, DmaWaitOp, ExtractElementOp, LoadOp,
-                MemRefCastOp, ReturnOp, SelectOp, StoreOp, TensorCastOp,
+  addOperations<CallOp, CallIndirectOp, CmpFOp, CmpIOp, CondBranchOp,
+                DmaStartOp, DmaWaitOp, LoadOp, MemRefCastOp, ReturnOp, SelectOp,
+                StoreOp, TensorCastOp,
 #define GET_OP_LIST
 #include "mlir/StandardOps/Ops.cpp.inc"
                 >();
@@ -374,12 +374,7 @@ void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
 // BranchOp
 //===----------------------------------------------------------------------===//
 
-void BranchOp::build(Builder *builder, OperationState *result, Block *dest,
-                     ArrayRef<Value *> operands) {
-  result->addSuccessor(dest, operands);
-}
-
-ParseResult BranchOp::parse(OpAsmParser *parser, OperationState *result) {
+static ParseResult parseBranchOp(OpAsmParser *parser, OperationState *result) {
   Block *dest;
   SmallVector<Value *, 4> destOperands;
   if (parser->parseSuccessorAndUseList(dest, destOperands))
@@ -388,9 +383,9 @@ ParseResult BranchOp::parse(OpAsmParser *parser, OperationState *result) {
   return success();
 }
 
-void BranchOp::print(OpAsmPrinter *p) {
+static void printBranchOp(OpAsmPrinter *p, BranchOp op) {
   *p << "br ";
-  p->printSuccessorAndUseList(getOperation(), 0);
+  p->printSuccessorAndUseList(op.getOperation(), 0);
 }
 
 Block *BranchOp::getDest() { return getOperation()->getSuccessor(0); }
@@ -1297,21 +1292,13 @@ void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
 // DimOp
 //===----------------------------------------------------------------------===//
 
-void DimOp::build(Builder *builder, OperationState *result,
-                  Value *memrefOrTensor, unsigned index) {
-  result->addOperands(memrefOrTensor);
-  auto type = builder->getIndexType();
-  result->addAttribute("index", builder->getIntegerAttr(type, index));
-  result->types.push_back(type);
+static void printDimOp(OpAsmPrinter *p, DimOp op) {
+  *p << "dim " << *op.getOperand() << ", " << op.getIndex();
+  p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"});
+  *p << " : " << op.getOperand()->getType();
 }
 
-void DimOp::print(OpAsmPrinter *p) {
-  *p << "dim " << *getOperand() << ", " << getIndex();
-  p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"index"});
-  *p << " : " << getOperand()->getType();
-}
-
-ParseResult DimOp::parse(OpAsmParser *parser, OperationState *result) {
+static ParseResult parseDimOp(OpAsmParser *parser, OperationState *result) {
   OpAsmParser::OperandType operandInfo;
   IntegerAttr indexAttr;
   Type type;
@@ -1326,25 +1313,25 @@ ParseResult DimOp::parse(OpAsmParser *parser, OperationState *result) {
                  parser->addTypeToList(indexType, result->types));
 }
 
-LogicalResult DimOp::verify() {
+static LogicalResult verify(DimOp op) {
   // Check that we have an integer index operand.
-  auto indexAttr = getAttrOfType<IntegerAttr>("index");
+  auto indexAttr = op.getAttrOfType<IntegerAttr>("index");
   if (!indexAttr)
-    return emitOpError("requires an integer attribute named 'index'");
+    return op.emitOpError("requires an integer attribute named 'index'");
   uint64_t index = indexAttr.getValue().getZExtValue();
 
-  auto type = getOperand()->getType();
+  auto type = op.getOperand()->getType();
   if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
     if (index >= static_cast<uint64_t>(tensorType.getRank()))
-      return emitOpError("index is out of range");
+      return op.emitOpError("index is out of range");
   } else if (auto memrefType = type.dyn_cast<MemRefType>()) {
     if (index >= memrefType.getRank())
-      return emitOpError("index is out of range");
+      return op.emitOpError("index is out of range");
 
   } else if (type.isa<UnrankedTensorType>()) {
     // ok, assumed to be in-range.
   } else {
-    return emitOpError("requires an operand with tensor or memref type");
+    return op.emitOpError("requires an operand with tensor or memref type");
   }
 
   return success();
@@ -1355,11 +1342,10 @@ Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
   // Constant fold dim when the size along the index referred to is a constant.
   auto opType = getOperand()->getType();
   int64_t indexSize = -1;
-  if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
+  if (auto tensorType = opType.dyn_cast<RankedTensorType>())
     indexSize = tensorType.getShape()[getIndex()];
-  } else if (auto memrefType = opType.dyn_cast<MemRefType>()) {
+  else if (auto memrefType = opType.dyn_cast<MemRefType>())
     indexSize = memrefType.getShape()[getIndex()];
-  }
 
   if (indexSize >= 0)
     return IntegerAttr::get(IndexType::get(context), indexSize);
@@ -1641,24 +1627,16 @@ void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
 // ExtractElementOp
 //===----------------------------------------------------------------------===//
 
-void ExtractElementOp::build(Builder *builder, OperationState *result,
-                             Value *aggregate, ArrayRef<Value *> indices) {
-  auto aggregateType = aggregate->getType().cast<VectorOrTensorType>();
-  result->addOperands(aggregate);
-  result->addOperands(indices);
-  result->types.push_back(aggregateType.getElementType());
-}
-
-void ExtractElementOp::print(OpAsmPrinter *p) {
-  *p << "extract_element " << *getAggregate() << '[';
-  p->printOperands(getIndices());
+static void printExtractElementOp(OpAsmPrinter *p, ExtractElementOp op) {
+  *p << "extract_element " << *op.getAggregate() << '[';
+  p->printOperands(op.getIndices());
   *p << ']';
-  p->printOptionalAttrDict(getAttrs());
-  *p << " : " << getAggregate()->getType();
+  p->printOptionalAttrDict(op.getAttrs());
+  *p << " : " << op.getAggregate()->getType();
 }
 
-ParseResult ExtractElementOp::parse(OpAsmParser *parser,
-                                    OperationState *result) {
+static ParseResult parseExtractElementOp(OpAsmParser *parser,
+                                         OperationState *result) {
   OpAsmParser::OperandType aggregateInfo;
   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
   VectorOrTensorType type;
@@ -1674,25 +1652,26 @@ ParseResult ExtractElementOp::parse(OpAsmParser *parser,
       parser->addTypeToList(type.getElementType(), result->types));
 }
 
-LogicalResult ExtractElementOp::verify() {
-  if (getNumOperands() == 0)
-    return emitOpError("expected an aggregate to index into");
+static LogicalResult verify(ExtractElementOp op) {
+  if (op.getNumOperands() == 0)
+    return op.emitOpError("expected an aggregate to index into");
 
-  auto aggregateType = getAggregate()->getType().dyn_cast<VectorOrTensorType>();
+  auto aggregateType =
+      op.getAggregate()->getType().dyn_cast<VectorOrTensorType>();
   if (!aggregateType)
-    return emitOpError("first operand must be a vector or tensor");
+    return op.emitOpError("first operand must be a vector or tensor");
 
-  if (getType() != aggregateType.getElementType())
-    return emitOpError("result type must match element type of aggregate");
+  if (op.getType() != aggregateType.getElementType())
+    return op.emitOpError("result type must match element type of aggregate");
 
-  for (auto *idx : getIndices())
+  for (auto *idx : op.getIndices())
     if (!idx->getType().isIndex())
-      return emitOpError("index to extract_element must have 'index' type");
+      return op.emitOpError("index to extract_element must have 'index' type");
 
   // Verify the # indices match if we have a ranked type.
   auto aggregateRank = aggregateType.getRank();
-  if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
-    return emitOpError("incorrect number of indices for extract_element");
+  if (aggregateRank != -1 && aggregateRank != op.getNumOperands() - 1)
+    return op.emitOpError("incorrect number of indices for extract_element");
 
   return success();
 }
index 645df7f..ea25b5b 100644 (file)
@@ -2,7 +2,7 @@
 
 func @dim(tensor<1xf32>) {
 ^bb(%0: tensor<1xf32>):
-  "std.dim"(%0){index: "xyz"} : (tensor<1xf32>)->i32 // expected-error {{'std.dim' op requires an integer attribute named 'index'}}
+  "std.dim"(%0){index: "xyz"} : (tensor<1xf32>)->index // expected-error {{attribute 'index' failed to satisfy constraint: arbitrary integer attribute}}
   return
 }
 
@@ -10,7 +10,7 @@ func @dim(tensor<1xf32>) {
 
 func @dim2(tensor<1xf32>) {
 ^bb(%0: tensor<1xf32>):
-  "std.dim"(){index: "xyz"} : ()->i32 // expected-error {{'std.dim' op requires a single operand}}
+  "std.dim"(){index: "xyz"} : ()->index // expected-error {{'std.dim' op requires a single operand}}
   return
 }
 
@@ -18,7 +18,7 @@ func @dim2(tensor<1xf32>) {
 
 func @dim3(tensor<1xf32>) {
 ^bb(%0: tensor<1xf32>):
-  "std.dim"(%0){index: 1} : (tensor<1xf32>)->i32 // expected-error {{'std.dim' op index is out of range}}
+  "std.dim"(%0){index: 1} : (tensor<1xf32>)->index // expected-error {{'std.dim' op index is out of range}}
   return
 }