From 77c333ca62650874dc3b81996cb75bc916a18e46 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 10 May 2019 15:26:23 -0700 Subject: [PATCH] Move the definitions of BranchOp, DimOp, and ExtractElementOp to Op Definition Generator. -- PiperOrigin-RevId: 247686212 --- mlir/include/mlir/IR/OpBase.td | 2 + mlir/include/mlir/StandardOps/Ops.h | 99 ------------------------------- mlir/include/mlir/StandardOps/Ops.td | 111 +++++++++++++++++++++++++++++++++++ mlir/lib/StandardOps/Ops.cpp | 99 ++++++++++++------------------- mlir/test/IR/invalid-ops.mlir | 6 +- 5 files changed, 155 insertions(+), 162 deletions(-) diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 7e90f4a..b4b159e 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -340,6 +340,8 @@ class Vector dims> : ContainerType dimensions = dims; } +def VectorOrTensor : Type; + // Tensor type. // This represents a generic tensor without constraints on elemental type, diff --git a/mlir/include/mlir/StandardOps/Ops.h b/mlir/include/mlir/StandardOps/Ops.h index 81d3614..7f3e8ab 100644 --- a/mlir/include/mlir/StandardOps/Ops.h +++ b/mlir/include/mlir/StandardOps/Ops.h @@ -46,39 +46,6 @@ public: #define GET_OP_CLASSES #include "mlir/StandardOps/Ops.h.inc" -/// The "br" operation represents a branch operation in a function. -/// The operation takes variable number of operands and produces no results. -/// The operand number and types for each successor must match the -/// arguments of the block successor. For example: -/// -/// ^bb2: -/// %2 = call @someFn() -/// br ^bb3(%2 : tensor<*xf32>) -/// ^bb3(%3: tensor<*xf32>): -/// -class BranchOp : public Op { -public: - friend Operation; - using Op::Op; - - static StringRef getOperationName() { return "std.br"; } - - static void build(Builder *builder, OperationState *result, Block *dest, - ArrayRef operands = {}); - - // Hooks to customize behavior of this op. - static ParseResult parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p); - - /// Return the block this branch jumps to. - Block *getDest(); - void setDest(Block *block); - - /// Erase the operand at 'index' from the operand list. - void eraseOperand(unsigned index); -}; - /// The "call" operation represents a direct call to a function. The operands /// and result types of the call must match the specified function type. The /// callee is encoded as a function attribute named "callee". @@ -457,36 +424,6 @@ public: static bool isClassFor(Operation *op); }; -/// The "dim" operation takes a memref or tensor operand and returns an -/// "index". It requires a single integer attribute named "index". It -/// returns the size of the specified dimension. For example: -/// -/// %1 = dim %0, 2 : tensor -/// -class DimOp : public Op { -public: - friend Operation; - using Op::Op; - - static void build(Builder *builder, OperationState *result, - Value *memrefOrTensor, unsigned index); - - Attribute constantFold(ArrayRef operands, MLIRContext *context); - - /// This returns the dimension number that the 'dim' is inspecting. - unsigned getIndex() { - return getAttrOfType("index").getValue().getZExtValue(); - } - - static StringRef getOperationName() { return "std.dim"; } - - // Hooks to customize behavior of this op. - LogicalResult verify(); - static ParseResult parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p); -}; - // DmaStartOp starts a non-blocking DMA operation that transfers data from a // source memref to a destination memref. The source and destination memref need // not be of the same dimensionality, but need to have the same elemental type. @@ -684,42 +621,6 @@ public: MLIRContext *context); }; -/// The "extract_element" op reads a tensor or vector and returns one element -/// from it specified by an index list. The output of extract is a new value -/// with the same type as the elements of the tensor or vector. The arity of -/// indices matches the rank of the accessed value (i.e., if a tensor is of rank -/// 3, then 3 indices are required for the extract). The indices should all be -/// of affine_int type. -/// -/// For example: -/// -/// %3 = extract_element %0[%1, %2] : vector<4x4xi32> -/// -class ExtractElementOp - : public Op { -public: - friend Operation; - using Op::Op; - - static void build(Builder *builder, OperationState *result, Value *aggregate, - ArrayRef indices = {}); - - Value *getAggregate() { return getOperand(0); } - - operand_range getIndices() { - return {getOperation()->operand_begin() + 1, getOperation()->operand_end()}; - } - - static StringRef getOperationName() { return "std.extract_element"; } - - // Hooks to customize behavior of this op. - LogicalResult verify(); - static ParseResult parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p); - Attribute constantFold(ArrayRef operands, MLIRContext *context); -}; - /// The "load" op reads an element from a memref specified by an index list. The /// output of load is a new value with the same type as the elements of the /// memref. The arity of indices is the rank of the memref (i.e., if the memref diff --git a/mlir/include/mlir/StandardOps/Ops.td b/mlir/include/mlir/StandardOps/Ops.td index 237730c..cfdbf1d 100644 --- a/mlir/include/mlir/StandardOps/Ops.td +++ b/mlir/include/mlir/StandardOps/Ops.td @@ -134,6 +134,40 @@ def AndOp : IntArithmeticOp<"and", [Commutative]> { let hasFolder = 1; } +def BranchOp : Op { + let summary = "branch operation"; + let description = [{ + The "br" operation represents a branch operation in a function. + The operation takes variable number of operands and produces no results. + The operand number and types for each successor must match the arguments of + the block successor. For example: + + ^bb2: + %2 = call @someFn() + br ^bb3(%2 : tensor<*xf32>) + ^bb3(%3: tensor<*xf32>): + }]; + + let arguments = (ins Variadic:$operands); + + let parser = [{ return parseBranchOp(parser, result); }]; + let printer = [{ return printBranchOp(p, *this); }]; + + let builders = [OpBuilder< + "Builder *, OperationState *result, Block *dest," + "ArrayRef operands = {}", [{ + result->addSuccessor(dest, operands); + }]>]; + + let extraClassDeclaration = [{ + Block *getDest(); + void setDest(Block *block); + + /// Erase the operand at 'index' from the operand list. + void eraseOperand(unsigned index); + }]; +} + def ConstantOp : Op { let summary = "constant"; @@ -177,6 +211,42 @@ def DeallocOp : Op { let hasCanonicalizer = 0b1; } +def DimOp : Op { + let summary = "dimension index operation"; + let description = [{ + The "dim" operation takes a memref or tensor operand and returns an "index". + It requires a single integer attribute named "index". It returns the size + of the specified dimension. For example: + + %1 = dim %0, 2 : tensor + }]; + + let arguments = (ins AnyTypeOf<[MemRef, Tensor], + "any tensor or memref type">:$memrefOrTensor, + APIntAttr:$index); + let results = (outs Index); + + let parser = [{ return parseDimOp(parser, result); }]; + let printer = [{ return printDimOp(p, *this); }]; + let verifier = [{ return ::verify(*this); }]; + + let builders = [OpBuilder< + "Builder *builder, OperationState *result, Value *memrefOrTensor," + "unsigned index", [{ + auto indexType = builder->getIndexType(); + auto indexAttr = builder->getIntegerAttr(indexType, index); + build(builder, result, indexType, memrefOrTensor, indexAttr); + }]>]; + + let extraClassDeclaration = [{ + unsigned getIndex() { + return getAttrOfType("index").getValue().getZExtValue(); + } + }]; + + let hasConstantFolder = 0b1; +} + def DivFOp : FloatArithmeticOp<"divf"> { let summary = "floating point division operation"; } @@ -191,6 +261,47 @@ def DivIUOp : IntArithmeticOp<"diviu"> { let hasConstantFolder = 0b1; } +def ExtractElementOp : Op { + let summary = "element extract operation"; + let description = [{ + The "extract_element" op reads a tensor or vector and returns one element + from it specified by an index list. The output of extract is a new value + with the same type as the elements of the tensor or vector. The arity of + indices matches the rank of the accessed value (i.e., if a tensor is of rank + 3, then 3 indices are required for the extract). The indices should all be + of affine_int type. For example: + + %0 = extract_element %0[%1, %2] : vector<4x4xi32> + }]; + + let arguments = (ins VectorOrTensor:$aggregate, + Variadic:$indices); + let results = (outs AnyType); + + let parser = [{ return parseExtractElementOp(parser, result); }]; + let printer = [{ return printExtractElementOp(p, *this); }]; + let verifier = [{ return ::verify(*this); }]; + + let builders = [OpBuilder< + "Builder *builder, OperationState *result, Value *aggregate," + "ArrayRef indices = {}", [{ + auto resType = aggregate->getType().cast() + .getElementType(); + build(builder, result, resType, aggregate, indices); + }]>]; + + let extraClassDeclaration = [{ + Value *getAggregate() { return getOperand(0); } + + operand_range getIndices() { + return {getOperation()->operand_begin() + 1, + getOperation()->operand_end()}; + } + }]; + + let hasConstantFolder = 0b1; +} + def MulFOp : FloatArithmeticOp<"mulf"> { let summary = "foating point multiplication operation"; let hasConstantFolder = 0b1; diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index f9b13ce..9ad37fd 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -61,9 +61,9 @@ void detail::printStandardBinaryOp(Operation *op, OpAsmPrinter *p) { StandardOpsDialect::StandardOpsDialect(MLIRContext *context) : Dialect(/*name=*/"std", context) { - addOperations(); @@ -374,12 +374,7 @@ void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, // BranchOp //===----------------------------------------------------------------------===// -void BranchOp::build(Builder *builder, OperationState *result, Block *dest, - ArrayRef operands) { - result->addSuccessor(dest, operands); -} - -ParseResult BranchOp::parse(OpAsmParser *parser, OperationState *result) { +static ParseResult parseBranchOp(OpAsmParser *parser, OperationState *result) { Block *dest; SmallVector destOperands; if (parser->parseSuccessorAndUseList(dest, destOperands)) @@ -388,9 +383,9 @@ ParseResult BranchOp::parse(OpAsmParser *parser, OperationState *result) { return success(); } -void BranchOp::print(OpAsmPrinter *p) { +static void printBranchOp(OpAsmPrinter *p, BranchOp op) { *p << "br "; - p->printSuccessorAndUseList(getOperation(), 0); + p->printSuccessorAndUseList(op.getOperation(), 0); } Block *BranchOp::getDest() { return getOperation()->getSuccessor(0); } @@ -1297,21 +1292,13 @@ void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, // DimOp //===----------------------------------------------------------------------===// -void DimOp::build(Builder *builder, OperationState *result, - Value *memrefOrTensor, unsigned index) { - result->addOperands(memrefOrTensor); - auto type = builder->getIndexType(); - result->addAttribute("index", builder->getIntegerAttr(type, index)); - result->types.push_back(type); +static void printDimOp(OpAsmPrinter *p, DimOp op) { + *p << "dim " << *op.getOperand() << ", " << op.getIndex(); + p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"}); + *p << " : " << op.getOperand()->getType(); } -void DimOp::print(OpAsmPrinter *p) { - *p << "dim " << *getOperand() << ", " << getIndex(); - p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"index"}); - *p << " : " << getOperand()->getType(); -} - -ParseResult DimOp::parse(OpAsmParser *parser, OperationState *result) { +static ParseResult parseDimOp(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType operandInfo; IntegerAttr indexAttr; Type type; @@ -1326,25 +1313,25 @@ ParseResult DimOp::parse(OpAsmParser *parser, OperationState *result) { parser->addTypeToList(indexType, result->types)); } -LogicalResult DimOp::verify() { +static LogicalResult verify(DimOp op) { // Check that we have an integer index operand. - auto indexAttr = getAttrOfType("index"); + auto indexAttr = op.getAttrOfType("index"); if (!indexAttr) - return emitOpError("requires an integer attribute named 'index'"); + return op.emitOpError("requires an integer attribute named 'index'"); uint64_t index = indexAttr.getValue().getZExtValue(); - auto type = getOperand()->getType(); + auto type = op.getOperand()->getType(); if (auto tensorType = type.dyn_cast()) { if (index >= static_cast(tensorType.getRank())) - return emitOpError("index is out of range"); + return op.emitOpError("index is out of range"); } else if (auto memrefType = type.dyn_cast()) { if (index >= memrefType.getRank()) - return emitOpError("index is out of range"); + return op.emitOpError("index is out of range"); } else if (type.isa()) { // ok, assumed to be in-range. } else { - return emitOpError("requires an operand with tensor or memref type"); + return op.emitOpError("requires an operand with tensor or memref type"); } return success(); @@ -1355,11 +1342,10 @@ Attribute DimOp::constantFold(ArrayRef operands, // Constant fold dim when the size along the index referred to is a constant. auto opType = getOperand()->getType(); int64_t indexSize = -1; - if (auto tensorType = opType.dyn_cast()) { + if (auto tensorType = opType.dyn_cast()) indexSize = tensorType.getShape()[getIndex()]; - } else if (auto memrefType = opType.dyn_cast()) { + else if (auto memrefType = opType.dyn_cast()) indexSize = memrefType.getShape()[getIndex()]; - } if (indexSize >= 0) return IntegerAttr::get(IndexType::get(context), indexSize); @@ -1641,24 +1627,16 @@ void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results, // ExtractElementOp //===----------------------------------------------------------------------===// -void ExtractElementOp::build(Builder *builder, OperationState *result, - Value *aggregate, ArrayRef indices) { - auto aggregateType = aggregate->getType().cast(); - result->addOperands(aggregate); - result->addOperands(indices); - result->types.push_back(aggregateType.getElementType()); -} - -void ExtractElementOp::print(OpAsmPrinter *p) { - *p << "extract_element " << *getAggregate() << '['; - p->printOperands(getIndices()); +static void printExtractElementOp(OpAsmPrinter *p, ExtractElementOp op) { + *p << "extract_element " << *op.getAggregate() << '['; + p->printOperands(op.getIndices()); *p << ']'; - p->printOptionalAttrDict(getAttrs()); - *p << " : " << getAggregate()->getType(); + p->printOptionalAttrDict(op.getAttrs()); + *p << " : " << op.getAggregate()->getType(); } -ParseResult ExtractElementOp::parse(OpAsmParser *parser, - OperationState *result) { +static ParseResult parseExtractElementOp(OpAsmParser *parser, + OperationState *result) { OpAsmParser::OperandType aggregateInfo; SmallVector indexInfo; VectorOrTensorType type; @@ -1674,25 +1652,26 @@ ParseResult ExtractElementOp::parse(OpAsmParser *parser, parser->addTypeToList(type.getElementType(), result->types)); } -LogicalResult ExtractElementOp::verify() { - if (getNumOperands() == 0) - return emitOpError("expected an aggregate to index into"); +static LogicalResult verify(ExtractElementOp op) { + if (op.getNumOperands() == 0) + return op.emitOpError("expected an aggregate to index into"); - auto aggregateType = getAggregate()->getType().dyn_cast(); + auto aggregateType = + op.getAggregate()->getType().dyn_cast(); if (!aggregateType) - return emitOpError("first operand must be a vector or tensor"); + return op.emitOpError("first operand must be a vector or tensor"); - if (getType() != aggregateType.getElementType()) - return emitOpError("result type must match element type of aggregate"); + if (op.getType() != aggregateType.getElementType()) + return op.emitOpError("result type must match element type of aggregate"); - for (auto *idx : getIndices()) + for (auto *idx : op.getIndices()) if (!idx->getType().isIndex()) - return emitOpError("index to extract_element must have 'index' type"); + return op.emitOpError("index to extract_element must have 'index' type"); // Verify the # indices match if we have a ranked type. auto aggregateRank = aggregateType.getRank(); - if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1) - return emitOpError("incorrect number of indices for extract_element"); + if (aggregateRank != -1 && aggregateRank != op.getNumOperands() - 1) + return op.emitOpError("incorrect number of indices for extract_element"); return success(); } diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index 645df7f..ea25b5b 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -2,7 +2,7 @@ func @dim(tensor<1xf32>) { ^bb(%0: tensor<1xf32>): - "std.dim"(%0){index: "xyz"} : (tensor<1xf32>)->i32 // expected-error {{'std.dim' op requires an integer attribute named 'index'}} + "std.dim"(%0){index: "xyz"} : (tensor<1xf32>)->index // expected-error {{attribute 'index' failed to satisfy constraint: arbitrary integer attribute}} return } @@ -10,7 +10,7 @@ func @dim(tensor<1xf32>) { func @dim2(tensor<1xf32>) { ^bb(%0: tensor<1xf32>): - "std.dim"(){index: "xyz"} : ()->i32 // expected-error {{'std.dim' op requires a single operand}} + "std.dim"(){index: "xyz"} : ()->index // expected-error {{'std.dim' op requires a single operand}} return } @@ -18,7 +18,7 @@ func @dim2(tensor<1xf32>) { func @dim3(tensor<1xf32>) { ^bb(%0: tensor<1xf32>): - "std.dim"(%0){index: 1} : (tensor<1xf32>)->i32 // expected-error {{'std.dim' op index is out of range}} + "std.dim"(%0){index: 1} : (tensor<1xf32>)->index // expected-error {{'std.dim' op index is out of range}} return } -- 2.7.4