void printCastOp(Operation *op, OpAsmPrinter *p);
Value *foldCastOp(Operation *op);
} // namespace impl
-
-/// This template is used for operations that are cast operations, that have a
-/// single operand and single results, whose source and destination types are
-/// different.
-///
-/// From this structure, subclasses get a standard builder, parser and printer.
-///
-template <typename ConcreteType, template <typename T> class... Traits>
-class CastOp : public Op<ConcreteType, OpTrait::OneOperand, OpTrait::OneResult,
- OpTrait::HasNoSideEffect, Traits...> {
-public:
- using Op<ConcreteType, OpTrait::OneOperand, OpTrait::OneResult,
- OpTrait::HasNoSideEffect, Traits...>::Op;
-
- static void build(Builder *builder, OperationState *result, Value *source,
- Type destType) {
- impl::buildCastOp(builder, result, source, destType);
- }
- static ParseResult parse(OpAsmParser *parser, OperationState *result) {
- return impl::parseCastOp(parser, result);
- }
- void print(OpAsmPrinter *p) {
- return impl::printCastOp(this->getOperation(), p);
- }
-
- Value *fold() { return impl::foldCastOp(this->getOperation()); }
-};
-
} // end namespace mlir
#endif
class AffineMap;
class Builder;
-namespace detail {
-/// A custom binary operation printer that omits the "std." prefix from the
-/// operation names.
-void printStandardBinaryOp(Operation *op, OpAsmPrinter *p);
-} // namespace detail
-
class StandardOpsDialect : public Dialect {
public:
StandardOpsDialect(MLIRContext *context);
MLIRContext *context);
};
-/// The "memref_cast" operation converts a memref from one type to an equivalent
-/// type with a compatible shape. The source and destination types are
-/// when both are memref types with the same element type, affine mappings,
-/// address space, and rank but where the individual dimensions may add or
-/// remove constant dimensions from the memref type.
-///
-/// If the cast converts any dimensions from an unknown to a known size, then it
-/// acts as an assertion that fails at runtime of the dynamic dimensions
-/// disagree with resultant destination size.
-///
-/// Assert that the input dynamic shape matches the destination static shape.
-/// %2 = memref_cast %1 : memref<?x?xf32> to memref<4x4xf32>
-/// Erase static shape information, replacing it with dynamic information.
-/// %3 = memref_cast %1 : memref<4xf32> to memref<?xf32>
-///
-class MemRefCastOp : public CastOp<MemRefCastOp> {
-public:
- using CastOp::CastOp;
- static StringRef getOperationName() { return "std.memref_cast"; }
-
- /// Return true if `a` and `b` are valid operand and result pairs for
- /// the operation.
- static bool areCastCompatible(Type a, Type b);
-
- /// The result of a memref_cast is always a memref.
- MemRefType getType() { return getResult()->getType().cast<MemRefType>(); }
-
- void print(OpAsmPrinter *p);
-
- LogicalResult verify();
-};
-
/// The "select" operation chooses one value based on a binary condition
/// supplied as its first operand. If the value of the first operand is 1, the
/// second operand is chosen, otherwise the third operand is chosen. The second
MLIRContext *context);
};
-/// The "tensor_cast" operation converts a tensor from one type to an equivalent
-/// type without changing any data elements. The source and destination types
-/// must both be tensor types with the same element type. If both are ranked
-/// then the rank should be the same and static dimensions should match. The
-/// operation is invalid if converting to a mismatching constant dimension.
-///
-/// Convert from unknown rank to rank 2 with unknown dimension sizes.
-/// %2 = tensor_cast %1 : tensor<??f32> to tensor<?x?xf32>
-///
-class TensorCastOp : public CastOp<TensorCastOp> {
-public:
- using CastOp::CastOp;
-
- static StringRef getOperationName() { return "std.tensor_cast"; }
-
- /// Return true if `a` and `b` are valid operand and result pairs for
- /// the operation.
- static bool areCastCompatible(Type a, Type b);
-
- /// The result of a tensor_cast is always a tensor.
- TensorType getType() { return getResult()->getType().cast<TensorType>(); }
-
- void print(OpAsmPrinter *p);
-
- LogicalResult verify();
-};
-
/// Prints dimension and symbol list.
void printDimAndSymbolList(Operation::operand_iterator begin,
Operation::operand_iterator end, unsigned numDims,
let name = "std";
}
+// Base class for standard cast operations. Requires single operand and result,
+// but does not constrain them to specific types.
+class CastOp<string mnemonic, list<OpTrait> traits = []> :
+ Op<Standard_Dialect, mnemonic, !listconcat(traits, [NoSideEffect])> {
+
+ let results = (outs AnyType);
+
+ let builders = [OpBuilder<
+ "Builder *builder, OperationState *result, Value *source, Type destType", [{
+ impl::buildCastOp(builder, result, source, destType);
+ }]>];
+
+ let parser = [{
+ return impl::parseCastOp(parser, result);
+ }];
+ let printer = [{
+ return printStandardCastOp(this->getOperation(), p);
+ }];
+ let verifier = [{ return ::verifyCastOp(*this); }];
+
+ let hasFolder = 1;
+}
+
// Base class for standard arithmetic operations. Requires operands and
// results to be of the same type, but does not constrain them to specific
// types. Individual classes will have `lhs` and `rhs` accessor to operands.
}];
let printer = [{
- return detail::printStandardBinaryOp(this->getOperation(), p);
+ return printStandardBinaryOp(this->getOperation(), p);
}];
}
let hasConstantFolder = 0b1;
}
+def MemRefCastOp : CastOp<"memref_cast"> {
+ let summary = "memref cast operation";
+ let description = [{
+ The "memref_cast" operation converts a memref from one type to an equivalent
+ type with a compatible shape. The source and destination types are
+ when both are memref types with the same element type, affine mappings,
+ address space, and rank but where the individual dimensions may add or
+ remove constant dimensions from the memref type.
+
+ If the cast converts any dimensions from an unknown to a known size, then it
+ acts as an assertion that fails at runtime of the dynamic dimensions
+ disagree with resultant destination size.
+
+ Assert that the input dynamic shape matches the destination static shape.
+ %2 = memref_cast %1 : memref<?x?xf32> to memref<4x4xf32>
+ Erase static shape information, replacing it with dynamic information.
+ %3 = memref_cast %1 : memref<4xf32> to memref<?xf32>
+ }];
+
+ let arguments = (ins MemRef<AnyType>);
+ let results = (outs MemRef<AnyType>);
+
+ let extraClassDeclaration = [{
+ /// Return true if `a` and `b` are valid operand and result pairs for
+ /// the operation.
+ static bool areCastCompatible(Type a, Type b);
+
+ /// The result of a memref_cast is always a memref.
+ MemRefType getType() { return getResult()->getType().cast<MemRefType>(); }
+ }];
+}
+
def MulFOp : FloatArithmeticOp<"mulf"> {
let summary = "foating point multiplication operation";
let hasConstantFolder = 0b1;
let hasCanonicalizer = 0b1;
}
+def TensorCastOp : CastOp<"tensor_cast"> {
+ let summary = "tensor cast operation";
+ let description = [{
+ The "tensor_cast" operation converts a tensor from one type to an equivalent
+ type without changing any data elements. The source and destination types
+ must both be tensor types with the same element type. If both are ranked
+ then the rank should be the same and static dimensions should match. The
+ operation is invalid if converting to a mismatching constant dimension.
+
+ Convert from unknown rank to rank 2 with unknown dimension sizes.
+ %2 = tensor_cast %1 : tensor<??f32> to tensor<?x?xf32>
+ }];
+
+ let arguments = (ins Tensor);
+ let results = (outs Tensor);
+
+ let extraClassDeclaration = [{
+ /// Return true if `a` and `b` are valid operand and result pairs for
+ /// the operation.
+ static bool areCastCompatible(Type a, Type b);
+
+ /// The result of a tensor_cast is always a tensor.
+ TensorType getType() { return getResult()->getType().cast<TensorType>(); }
+ }];
+}
+
def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
let summary = "integer binary xor";
let hasConstantFolder = 0b1;
/// A custom binary operation printer that omits the "std." prefix from the
/// operation names.
-void detail::printStandardBinaryOp(Operation *op, OpAsmPrinter *p) {
+static void printStandardBinaryOp(Operation *op, OpAsmPrinter *p) {
assert(op->getNumOperands() == 2 && "binary op should have two operands");
assert(op->getNumResults() == 1 && "binary op should have one result");
*p << " : " << op->getResult(0)->getType();
}
+/// A custom cast operation printer that omits the "std." prefix from the
+/// operation names.
+static void printStandardCastOp(Operation *op, OpAsmPrinter *p) {
+ *p << op->getName().getStringRef().drop_front(strlen("std.")) << ' '
+ << *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to "
+ << op->getResult(0)->getType();
+}
+
+/// A custom cast operation verifier.
+template <typename T> static LogicalResult verifyCastOp(T op) {
+ auto opType = op.getOperand()->getType();
+ auto resType = op.getType();
+ if (!T::areCastCompatible(opType, resType))
+ return op.emitError("operand type ") << opType << " and result type "
+ << resType << " are cast incompatible";
+
+ return success();
+}
+
StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
: Dialect(/*name=*/"std", context) {
addOperations<CmpFOp, CmpIOp, CondBranchOp, DmaStartOp, DmaWaitOp, LoadOp,
- MemRefCastOp, SelectOp, StoreOp, TensorCastOp,
+ SelectOp, StoreOp,
#define GET_OP_LIST
#include "mlir/StandardOps/Ops.cpp.inc"
>();
return true;
}
-void MemRefCastOp::print(OpAsmPrinter *p) {
- *p << "memref_cast " << *getOperand() << " : " << getOperand()->getType()
- << " to " << getType();
-}
-
-LogicalResult MemRefCastOp::verify() {
- auto opType = getOperand()->getType();
- auto resType = getType();
- if (!areCastCompatible(opType, resType))
- return emitError(llvm::formatv(
- "operand type {0} and result type {1} are cast incompatible", opType,
- resType));
-
- return success();
-}
+Value *MemRefCastOp::fold() { return impl::foldCastOp(*this); }
//===----------------------------------------------------------------------===//
// MulFOp
return true;
}
-void TensorCastOp::print(OpAsmPrinter *p) {
- *p << "tensor_cast " << *getOperand() << " : " << getOperand()->getType()
- << " to " << getType();
-}
-
-LogicalResult TensorCastOp::verify() {
- auto opType = getOperand()->getType();
- auto resType = getType();
- if (!areCastCompatible(opType, resType))
- return emitError(llvm::formatv(
- "operand type {0} and result type {1} are cast incompatible", opType,
- resType));
-
- return success();
-}
+Value *TensorCastOp::fold() { return impl::foldCastOp(*this); }
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions