Move MemRefCastOp and TensorCastOp to the Op Definition Generation framework.
authorRiver Riddle <riverriddle@google.com>
Mon, 13 May 2019 18:56:21 +0000 (11:56 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 20 May 2019 20:39:53 +0000 (13:39 -0700)
--

PiperOrigin-RevId: 247981385

mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/StandardOps/Ops.h
mlir/include/mlir/StandardOps/Ops.td
mlir/lib/StandardOps/Ops.cpp

index a50d4a3..71e187b 100644 (file)
@@ -886,34 +886,6 @@ ParseResult parseCastOp(OpAsmParser *parser, OperationState *result);
 void printCastOp(Operation *op, OpAsmPrinter *p);
 Value *foldCastOp(Operation *op);
 } // namespace impl
-
-/// This template is used for operations that are cast operations, that have a
-/// single operand and single results, whose source and destination types are
-/// different.
-///
-/// From this structure, subclasses get a standard builder, parser and printer.
-///
-template <typename ConcreteType, template <typename T> class... Traits>
-class CastOp : public Op<ConcreteType, OpTrait::OneOperand, OpTrait::OneResult,
-                         OpTrait::HasNoSideEffect, Traits...> {
-public:
-  using Op<ConcreteType, OpTrait::OneOperand, OpTrait::OneResult,
-           OpTrait::HasNoSideEffect, Traits...>::Op;
-
-  static void build(Builder *builder, OperationState *result, Value *source,
-                    Type destType) {
-    impl::buildCastOp(builder, result, source, destType);
-  }
-  static ParseResult parse(OpAsmParser *parser, OperationState *result) {
-    return impl::parseCastOp(parser, result);
-  }
-  void print(OpAsmPrinter *p) {
-    return impl::printCastOp(this->getOperation(), p);
-  }
-
-  Value *fold() { return impl::foldCastOp(this->getOperation()); }
-};
-
 } // end namespace mlir
 
 #endif
index 80e73d3..866012b 100644 (file)
@@ -32,12 +32,6 @@ namespace mlir {
 class AffineMap;
 class Builder;
 
-namespace detail {
-/// A custom binary operation printer that omits the "std." prefix from the
-/// operation names.
-void printStandardBinaryOp(Operation *op, OpAsmPrinter *p);
-} // namespace detail
-
 class StandardOpsDialect : public Dialect {
 public:
   StandardOpsDialect(MLIRContext *context);
@@ -579,38 +573,6 @@ public:
                                           MLIRContext *context);
 };
 
-/// The "memref_cast" operation converts a memref from one type to an equivalent
-/// type with a compatible shape.  The source and destination types are
-/// when both are memref types with the same element type, affine mappings,
-/// address space, and rank but where the individual dimensions may add or
-/// remove constant dimensions from the memref type.
-///
-/// If the cast converts any dimensions from an unknown to a known size, then it
-/// acts as an assertion that fails at runtime of the dynamic dimensions
-/// disagree with resultant destination size.
-///
-/// Assert that the input dynamic shape matches the destination static shape.
-///    %2 = memref_cast %1 : memref<?x?xf32> to memref<4x4xf32>
-/// Erase static shape information, replacing it with dynamic information.
-///    %3 = memref_cast %1 : memref<4xf32> to memref<?xf32>
-///
-class MemRefCastOp : public CastOp<MemRefCastOp> {
-public:
-  using CastOp::CastOp;
-  static StringRef getOperationName() { return "std.memref_cast"; }
-
-  /// Return true if `a` and `b` are valid operand and result pairs for
-  /// the operation.
-  static bool areCastCompatible(Type a, Type b);
-
-  /// The result of a memref_cast is always a memref.
-  MemRefType getType() { return getResult()->getType().cast<MemRefType>(); }
-
-  void print(OpAsmPrinter *p);
-
-  LogicalResult verify();
-};
-
 /// The "select" operation chooses one value based on a binary condition
 /// supplied as its first operand. If the value of the first operand is 1, the
 /// second operand is chosen, otherwise the third operand is chosen. The second
@@ -683,33 +645,6 @@ public:
                                           MLIRContext *context);
 };
 
-/// The "tensor_cast" operation converts a tensor from one type to an equivalent
-/// type without changing any data elements.  The source and destination types
-/// must both be tensor types with the same element type.  If both are ranked
-/// then the rank should be the same and static dimensions should match.  The
-/// operation is invalid if converting to a mismatching constant dimension.
-///
-/// Convert from unknown rank to rank 2 with unknown dimension sizes.
-///    %2 = tensor_cast %1 : tensor<??f32> to tensor<?x?xf32>
-///
-class TensorCastOp : public CastOp<TensorCastOp> {
-public:
-  using CastOp::CastOp;
-
-  static StringRef getOperationName() { return "std.tensor_cast"; }
-
-  /// Return true if `a` and `b` are valid operand and result pairs for
-  /// the operation.
-  static bool areCastCompatible(Type a, Type b);
-
-  /// The result of a tensor_cast is always a tensor.
-  TensorType getType() { return getResult()->getType().cast<TensorType>(); }
-
-  void print(OpAsmPrinter *p);
-
-  LogicalResult verify();
-};
-
 /// Prints dimension and symbol list.
 void printDimAndSymbolList(Operation::operand_iterator begin,
                            Operation::operand_iterator end, unsigned numDims,
index a82b1c4..13991d6 100644 (file)
@@ -32,6 +32,29 @@ def Standard_Dialect : Dialect {
   let name = "std";
 }
 
+// Base class for standard cast operations. Requires single operand and result,
+// but does not constrain them to specific types.
+class CastOp<string mnemonic, list<OpTrait> traits = []> :
+    Op<Standard_Dialect, mnemonic, !listconcat(traits, [NoSideEffect])> {
+
+  let results = (outs AnyType);
+
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, Value *source, Type destType", [{
+       impl::buildCastOp(builder, result, source, destType);
+  }]>];
+
+  let parser = [{
+    return impl::parseCastOp(parser, result);
+  }];
+  let printer = [{
+    return printStandardCastOp(this->getOperation(), p);
+  }];
+  let verifier = [{ return ::verifyCastOp(*this); }];
+
+  let hasFolder = 1;
+}
+
 // Base class for standard arithmetic operations.  Requires operands and
 // results to be of the same type, but does not constrain them to specific
 // types.  Individual classes will have `lhs` and `rhs` accessor to operands.
@@ -46,7 +69,7 @@ class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
   }];
 
   let printer = [{
-    return detail::printStandardBinaryOp(this->getOperation(), p);
+    return printStandardBinaryOp(this->getOperation(), p);
   }];
 }
 
@@ -383,6 +406,38 @@ def ExtractElementOp : Op<Standard_Dialect, "extract_element", [NoSideEffect]> {
   let hasConstantFolder = 0b1;
 }
 
+def MemRefCastOp : CastOp<"memref_cast"> {
+  let summary = "memref cast operation";
+  let description = [{
+    The "memref_cast" operation converts a memref from one type to an equivalent
+    type with a compatible shape. The source and destination types are
+    when both are memref types with the same element type, affine mappings,
+    address space, and rank but where the individual dimensions may add or
+    remove constant dimensions from the memref type.
+
+    If the cast converts any dimensions from an unknown to a known size, then it
+    acts as an assertion that fails at runtime of the dynamic dimensions
+    disagree with resultant destination size.
+
+    Assert that the input dynamic shape matches the destination static shape.
+       %2 = memref_cast %1 : memref<?x?xf32> to memref<4x4xf32>
+    Erase static shape information, replacing it with dynamic information.
+       %3 = memref_cast %1 : memref<4xf32> to memref<?xf32>
+  }];
+
+  let arguments = (ins MemRef<AnyType>);
+  let results = (outs MemRef<AnyType>);
+
+  let extraClassDeclaration = [{
+    /// Return true if `a` and `b` are valid operand and result pairs for
+    /// the operation.
+    static bool areCastCompatible(Type a, Type b);
+
+    /// The result of a memref_cast is always a memref.
+    MemRefType getType() { return getResult()->getType().cast<MemRefType>(); }
+  }];
+}
+
 def MulFOp : FloatArithmeticOp<"mulf"> {
   let summary = "foating point multiplication operation";
   let hasConstantFolder = 0b1;
@@ -453,6 +508,32 @@ def SubIOp : IntArithmeticOp<"subi"> {
   let hasCanonicalizer = 0b1;
 }
 
+def TensorCastOp : CastOp<"tensor_cast"> {
+  let summary = "tensor cast operation";
+  let description = [{
+    The "tensor_cast" operation converts a tensor from one type to an equivalent
+    type without changing any data elements.  The source and destination types
+    must both be tensor types with the same element type.  If both are ranked
+    then the rank should be the same and static dimensions should match.  The
+    operation is invalid if converting to a mismatching constant dimension.
+
+    Convert from unknown rank to rank 2 with unknown dimension sizes.
+       %2 = tensor_cast %1 : tensor<??f32> to tensor<?x?xf32>
+  }];
+
+  let arguments = (ins Tensor);
+  let results = (outs Tensor);
+
+  let extraClassDeclaration = [{
+    /// Return true if `a` and `b` are valid operand and result pairs for
+    /// the operation.
+    static bool areCastCompatible(Type a, Type b);
+
+    /// The result of a tensor_cast is always a tensor.
+    TensorType getType() { return getResult()->getType().cast<TensorType>(); }
+  }];
+}
+
 def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
   let summary = "integer binary xor";
   let hasConstantFolder = 0b1;
index 9fe8e21..c8ba1d0 100644 (file)
@@ -38,7 +38,7 @@ using namespace mlir;
 
 /// A custom binary operation printer that omits the "std." prefix from the
 /// operation names.
-void detail::printStandardBinaryOp(Operation *op, OpAsmPrinter *p) {
+static void printStandardBinaryOp(Operation *op, OpAsmPrinter *p) {
   assert(op->getNumOperands() == 2 && "binary op should have two operands");
   assert(op->getNumResults() == 1 && "binary op should have one result");
 
@@ -59,10 +59,29 @@ void detail::printStandardBinaryOp(Operation *op, OpAsmPrinter *p) {
   *p << " : " << op->getResult(0)->getType();
 }
 
+/// A custom cast operation printer that omits the "std." prefix from the
+/// operation names.
+static void printStandardCastOp(Operation *op, OpAsmPrinter *p) {
+  *p << op->getName().getStringRef().drop_front(strlen("std.")) << ' '
+     << *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to "
+     << op->getResult(0)->getType();
+}
+
+/// A custom cast operation verifier.
+template <typename T> static LogicalResult verifyCastOp(T op) {
+  auto opType = op.getOperand()->getType();
+  auto resType = op.getType();
+  if (!T::areCastCompatible(opType, resType))
+    return op.emitError("operand type ") << opType << " and result type "
+                                         << resType << " are cast incompatible";
+
+  return success();
+}
+
 StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
     : Dialect(/*name=*/"std", context) {
   addOperations<CmpFOp, CmpIOp, CondBranchOp, DmaStartOp, DmaWaitOp, LoadOp,
-                MemRefCastOp, SelectOp, StoreOp, TensorCastOp,
+                SelectOp, StoreOp,
 #define GET_OP_LIST
 #include "mlir/StandardOps/Ops.cpp.inc"
                 >();
@@ -1783,21 +1802,7 @@ bool MemRefCastOp::areCastCompatible(Type a, Type b) {
   return true;
 }
 
-void MemRefCastOp::print(OpAsmPrinter *p) {
-  *p << "memref_cast " << *getOperand() << " : " << getOperand()->getType()
-     << " to " << getType();
-}
-
-LogicalResult MemRefCastOp::verify() {
-  auto opType = getOperand()->getType();
-  auto resType = getType();
-  if (!areCastCompatible(opType, resType))
-    return emitError(llvm::formatv(
-        "operand type {0} and result type {1} are cast incompatible", opType,
-        resType));
-
-  return success();
-}
+Value *MemRefCastOp::fold() { return impl::foldCastOp(*this); }
 
 //===----------------------------------------------------------------------===//
 // MulFOp
@@ -2235,21 +2240,7 @@ bool TensorCastOp::areCastCompatible(Type a, Type b) {
   return true;
 }
 
-void TensorCastOp::print(OpAsmPrinter *p) {
-  *p << "tensor_cast " << *getOperand() << " : " << getOperand()->getType()
-     << " to " << getType();
-}
-
-LogicalResult TensorCastOp::verify() {
-  auto opType = getOperand()->getType();
-  auto resType = getType();
-  if (!areCastCompatible(opType, resType))
-    return emitError(llvm::formatv(
-        "operand type {0} and result type {1} are cast incompatible", opType,
-        resType));
-
-  return success();
-}
+Value *TensorCastOp::fold() { return impl::foldCastOp(*this); }
 
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions