From: River Riddle Date: Mon, 13 May 2019 18:56:21 +0000 (-0700) Subject: Move MemRefCastOp and TensorCastOp to the Op Definition Generation framework. X-Git-Tag: llvmorg-11-init~1466^2~1730 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5d7546470ddb01c336e2f74d3f0a4fb8c1354545;p=platform%2Fupstream%2Fllvm.git Move MemRefCastOp and TensorCastOp to the Op Definition Generation framework. -- PiperOrigin-RevId: 247981385 --- diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index a50d4a3..71e187b 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -886,34 +886,6 @@ ParseResult parseCastOp(OpAsmParser *parser, OperationState *result); void printCastOp(Operation *op, OpAsmPrinter *p); Value *foldCastOp(Operation *op); } // namespace impl - -/// This template is used for operations that are cast operations, that have a -/// single operand and single results, whose source and destination types are -/// different. -/// -/// From this structure, subclasses get a standard builder, parser and printer. -/// -template class... Traits> -class CastOp : public Op { -public: - using Op::Op; - - static void build(Builder *builder, OperationState *result, Value *source, - Type destType) { - impl::buildCastOp(builder, result, source, destType); - } - static ParseResult parse(OpAsmParser *parser, OperationState *result) { - return impl::parseCastOp(parser, result); - } - void print(OpAsmPrinter *p) { - return impl::printCastOp(this->getOperation(), p); - } - - Value *fold() { return impl::foldCastOp(this->getOperation()); } -}; - } // end namespace mlir #endif diff --git a/mlir/include/mlir/StandardOps/Ops.h b/mlir/include/mlir/StandardOps/Ops.h index 80e73d3..866012b 100644 --- a/mlir/include/mlir/StandardOps/Ops.h +++ b/mlir/include/mlir/StandardOps/Ops.h @@ -32,12 +32,6 @@ namespace mlir { class AffineMap; class Builder; -namespace detail { -/// A custom binary operation printer that omits the "std." prefix from the -/// operation names. -void printStandardBinaryOp(Operation *op, OpAsmPrinter *p); -} // namespace detail - class StandardOpsDialect : public Dialect { public: StandardOpsDialect(MLIRContext *context); @@ -579,38 +573,6 @@ public: MLIRContext *context); }; -/// The "memref_cast" operation converts a memref from one type to an equivalent -/// type with a compatible shape. The source and destination types are -/// when both are memref types with the same element type, affine mappings, -/// address space, and rank but where the individual dimensions may add or -/// remove constant dimensions from the memref type. -/// -/// If the cast converts any dimensions from an unknown to a known size, then it -/// acts as an assertion that fails at runtime of the dynamic dimensions -/// disagree with resultant destination size. -/// -/// Assert that the input dynamic shape matches the destination static shape. -/// %2 = memref_cast %1 : memref to memref<4x4xf32> -/// Erase static shape information, replacing it with dynamic information. -/// %3 = memref_cast %1 : memref<4xf32> to memref -/// -class MemRefCastOp : public CastOp { -public: - using CastOp::CastOp; - static StringRef getOperationName() { return "std.memref_cast"; } - - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. - static bool areCastCompatible(Type a, Type b); - - /// The result of a memref_cast is always a memref. - MemRefType getType() { return getResult()->getType().cast(); } - - void print(OpAsmPrinter *p); - - LogicalResult verify(); -}; - /// The "select" operation chooses one value based on a binary condition /// supplied as its first operand. If the value of the first operand is 1, the /// second operand is chosen, otherwise the third operand is chosen. The second @@ -683,33 +645,6 @@ public: MLIRContext *context); }; -/// The "tensor_cast" operation converts a tensor from one type to an equivalent -/// type without changing any data elements. The source and destination types -/// must both be tensor types with the same element type. If both are ranked -/// then the rank should be the same and static dimensions should match. The -/// operation is invalid if converting to a mismatching constant dimension. -/// -/// Convert from unknown rank to rank 2 with unknown dimension sizes. -/// %2 = tensor_cast %1 : tensor to tensor -/// -class TensorCastOp : public CastOp { -public: - using CastOp::CastOp; - - static StringRef getOperationName() { return "std.tensor_cast"; } - - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. - static bool areCastCompatible(Type a, Type b); - - /// The result of a tensor_cast is always a tensor. - TensorType getType() { return getResult()->getType().cast(); } - - void print(OpAsmPrinter *p); - - LogicalResult verify(); -}; - /// Prints dimension and symbol list. void printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, diff --git a/mlir/include/mlir/StandardOps/Ops.td b/mlir/include/mlir/StandardOps/Ops.td index a82b1c4..13991d6 100644 --- a/mlir/include/mlir/StandardOps/Ops.td +++ b/mlir/include/mlir/StandardOps/Ops.td @@ -32,6 +32,29 @@ def Standard_Dialect : Dialect { let name = "std"; } +// Base class for standard cast operations. Requires single operand and result, +// but does not constrain them to specific types. +class CastOp traits = []> : + Op { + + let results = (outs AnyType); + + let builders = [OpBuilder< + "Builder *builder, OperationState *result, Value *source, Type destType", [{ + impl::buildCastOp(builder, result, source, destType); + }]>]; + + let parser = [{ + return impl::parseCastOp(parser, result); + }]; + let printer = [{ + return printStandardCastOp(this->getOperation(), p); + }]; + let verifier = [{ return ::verifyCastOp(*this); }]; + + let hasFolder = 1; +} + // Base class for standard arithmetic operations. Requires operands and // results to be of the same type, but does not constrain them to specific // types. Individual classes will have `lhs` and `rhs` accessor to operands. @@ -46,7 +69,7 @@ class ArithmeticOp traits = []> : }]; let printer = [{ - return detail::printStandardBinaryOp(this->getOperation(), p); + return printStandardBinaryOp(this->getOperation(), p); }]; } @@ -383,6 +406,38 @@ def ExtractElementOp : Op { let hasConstantFolder = 0b1; } +def MemRefCastOp : CastOp<"memref_cast"> { + let summary = "memref cast operation"; + let description = [{ + The "memref_cast" operation converts a memref from one type to an equivalent + type with a compatible shape. The source and destination types are + when both are memref types with the same element type, affine mappings, + address space, and rank but where the individual dimensions may add or + remove constant dimensions from the memref type. + + If the cast converts any dimensions from an unknown to a known size, then it + acts as an assertion that fails at runtime of the dynamic dimensions + disagree with resultant destination size. + + Assert that the input dynamic shape matches the destination static shape. + %2 = memref_cast %1 : memref to memref<4x4xf32> + Erase static shape information, replacing it with dynamic information. + %3 = memref_cast %1 : memref<4xf32> to memref + }]; + + let arguments = (ins MemRef); + let results = (outs MemRef); + + let extraClassDeclaration = [{ + /// Return true if `a` and `b` are valid operand and result pairs for + /// the operation. + static bool areCastCompatible(Type a, Type b); + + /// The result of a memref_cast is always a memref. + MemRefType getType() { return getResult()->getType().cast(); } + }]; +} + def MulFOp : FloatArithmeticOp<"mulf"> { let summary = "foating point multiplication operation"; let hasConstantFolder = 0b1; @@ -453,6 +508,32 @@ def SubIOp : IntArithmeticOp<"subi"> { let hasCanonicalizer = 0b1; } +def TensorCastOp : CastOp<"tensor_cast"> { + let summary = "tensor cast operation"; + let description = [{ + The "tensor_cast" operation converts a tensor from one type to an equivalent + type without changing any data elements. The source and destination types + must both be tensor types with the same element type. If both are ranked + then the rank should be the same and static dimensions should match. The + operation is invalid if converting to a mismatching constant dimension. + + Convert from unknown rank to rank 2 with unknown dimension sizes. + %2 = tensor_cast %1 : tensor to tensor + }]; + + let arguments = (ins Tensor); + let results = (outs Tensor); + + let extraClassDeclaration = [{ + /// Return true if `a` and `b` are valid operand and result pairs for + /// the operation. + static bool areCastCompatible(Type a, Type b); + + /// The result of a tensor_cast is always a tensor. + TensorType getType() { return getResult()->getType().cast(); } + }]; +} + def XOrOp : IntArithmeticOp<"xor", [Commutative]> { let summary = "integer binary xor"; let hasConstantFolder = 0b1; diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 9fe8e21..c8ba1d0 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -38,7 +38,7 @@ using namespace mlir; /// A custom binary operation printer that omits the "std." prefix from the /// operation names. -void detail::printStandardBinaryOp(Operation *op, OpAsmPrinter *p) { +static void printStandardBinaryOp(Operation *op, OpAsmPrinter *p) { assert(op->getNumOperands() == 2 && "binary op should have two operands"); assert(op->getNumResults() == 1 && "binary op should have one result"); @@ -59,10 +59,29 @@ void detail::printStandardBinaryOp(Operation *op, OpAsmPrinter *p) { *p << " : " << op->getResult(0)->getType(); } +/// A custom cast operation printer that omits the "std." prefix from the +/// operation names. +static void printStandardCastOp(Operation *op, OpAsmPrinter *p) { + *p << op->getName().getStringRef().drop_front(strlen("std.")) << ' ' + << *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to " + << op->getResult(0)->getType(); +} + +/// A custom cast operation verifier. +template static LogicalResult verifyCastOp(T op) { + auto opType = op.getOperand()->getType(); + auto resType = op.getType(); + if (!T::areCastCompatible(opType, resType)) + return op.emitError("operand type ") << opType << " and result type " + << resType << " are cast incompatible"; + + return success(); +} + StandardOpsDialect::StandardOpsDialect(MLIRContext *context) : Dialect(/*name=*/"std", context) { addOperations(); @@ -1783,21 +1802,7 @@ bool MemRefCastOp::areCastCompatible(Type a, Type b) { return true; } -void MemRefCastOp::print(OpAsmPrinter *p) { - *p << "memref_cast " << *getOperand() << " : " << getOperand()->getType() - << " to " << getType(); -} - -LogicalResult MemRefCastOp::verify() { - auto opType = getOperand()->getType(); - auto resType = getType(); - if (!areCastCompatible(opType, resType)) - return emitError(llvm::formatv( - "operand type {0} and result type {1} are cast incompatible", opType, - resType)); - - return success(); -} +Value *MemRefCastOp::fold() { return impl::foldCastOp(*this); } //===----------------------------------------------------------------------===// // MulFOp @@ -2235,21 +2240,7 @@ bool TensorCastOp::areCastCompatible(Type a, Type b) { return true; } -void TensorCastOp::print(OpAsmPrinter *p) { - *p << "tensor_cast " << *getOperand() << " : " << getOperand()->getType() - << " to " << getType(); -} - -LogicalResult TensorCastOp::verify() { - auto opType = getOperand()->getType(); - auto resType = getType(); - if (!areCastCompatible(opType, resType)) - return emitError(llvm::formatv( - "operand type {0} and result type {1} are cast incompatible", opType, - resType)); - - return success(); -} +Value *TensorCastOp::fold() { return impl::foldCastOp(*this); } //===----------------------------------------------------------------------===// // TableGen'd op method definitions