}];
}
+def TensorLoadOp : Std_Op<"tensor_load",
+ [SameOperandsAndResultShape, SameOperandsAndResultElementType]> {
+ let summary = "tensor load operation";
+ let description = [{
+ The "tensor_load" operation creates a tensor from a memref, making an
+ independent copy of the element data. The result value is a tensor whose
+ shape and element type match the memref operand.
+
+ Produce a value of tensor<4x?xf32> type.
+ %12 = tensor_load %10 : memref<4x?xf32, #layout, memspace0>
+ }];
+
+ let arguments = (ins AnyMemRef);
+ let results = (outs AnyTensor);
+ // TensorLoadOp is fully verified by traits.
+ let verifier = ?;
+
+ let builders = [OpBuilder<
+ "Builder *builder, OperationState *result, Value *memref", [{
+ auto memrefType = memref->getType().cast<MemRefType>();
+ auto resultType = builder->getTensorType(memrefType.getShape(),
+ memrefType.getElementType());
+ result->addOperands(memref);
+ result->addTypes(resultType);
+ }]>];
+
+
+ let extraClassDeclaration = [{
+ /// The result of a tensor_load is always a tensor.
+ TensorType getType() { return getResult()->getType().cast<TensorType>(); }
+ }];
+}
+
+def TensorStoreOp : Std_Op<"tensor_store",
+ [SameOperandsShape, SameOperandsElementType]> {
+ let summary = "tensor store operation";
+ let description = [{
+ The "tensor_store" operation stores the contents of a tensor into a memref.
+ The first operand is a value of tensor type, the second operand is a value
+ of memref type. The shapes and element types of these must match, and are
+ specified by the memref type.
+
+ Example:
+ %9 = dim %8, 1 : tensor<4x?xf32>
+ %10 = alloc(%9) : memref<4x?xf32, #layout, memspace0>
+ tensor_store %8, %10 : memref<4x?xf32, #layout, memspace0>
+ }];
+
+ let arguments = (ins AnyTensor:$tensor, AnyMemRef:$memref);
+ // TensorStoreOp is fully verified by traits.
+ let verifier = ?;
+}
+
+
def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
let summary = "integer binary xor";
let hasFolder = 1;
}
// Op supports operand broadcast behavior.
-def Broadcastable : NativeOpTrait<"BroadcastableTwoOperandsOneResult">;
+def Broadcastable : NativeOpTrait<"BroadcastableTwoOperandsOneResult">;
// X op Y == Y op X
-def Commutative : NativeOpTrait<"IsCommutative">;
+def Commutative : NativeOpTrait<"IsCommutative">;
// Op is isolated from above.
def IsolatedFromAbove : NativeOpTrait<"IsIsolatedFromAbove">;
// Op results are float or vectors/tensors thereof.
def ResultsAreFloatLike : NativeOpTrait<"ResultsAreFloatLike">;
// Op has no side effect.
-def NoSideEffect : NativeOpTrait<"HasNoSideEffect">;
+def NoSideEffect : NativeOpTrait<"HasNoSideEffect">;
// Op has the same operand type.
-def SameTypeOperands : NativeOpTrait<"SameTypeOperands">;
+def SameTypeOperands : NativeOpTrait<"SameTypeOperands">;
+// Op has same shape for all operands.
+def SameOperandsShape : NativeOpTrait<"SameOperandsShape">;
// Op has same operand and result shape.
-def SameOperandsAndResultShape : NativeOpTrait<"SameOperandsAndResultShape">;
+def SameOperandsAndResultShape : NativeOpTrait<"SameOperandsAndResultShape">;
// Op has the same operand and result type.
-def SameOperandsAndResultType : NativeOpTrait<"SameOperandsAndResultType">;
+def SameOperandsAndResultType : NativeOpTrait<"SameOperandsAndResultType">;
+// Op has the same element type for all operands.
+def SameOperandsElementType : NativeOpTrait<"SameOperandsElementType">;
// Op has the same operand and result element type.
def SameOperandsAndResultElementType :
NativeOpTrait<"SameOperandsAndResultElementType">;
LogicalResult verifyOneResult(Operation *op);
LogicalResult verifyNResults(Operation *op, unsigned numOperands);
LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands);
+LogicalResult verifySameOperandsShape(Operation *op);
LogicalResult verifySameOperandsAndResultShape(Operation *op);
+LogicalResult verifySameOperandsElementType(Operation *op);
LogicalResult verifySameOperandsAndResultElementType(Operation *op);
LogicalResult verifySameOperandsAndResultType(Operation *op);
LogicalResult verifyResultsAreBoolLike(Operation *op);
: public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {};
/// This class provides verification for ops that are known to have the same
+/// operand shape: all operands are scalars, vectors/tensors of the same
+/// shape.
+template <typename ConcreteType>
+class SameOperandsShape : public TraitBase<ConcreteType, SameOperandsShape> {
+public:
+ static LogicalResult verifyTrait(Operation *op) {
+ return impl::verifySameOperandsShape(op);
+ }
+};
+
+/// This class provides verification for ops that are known to have the same
/// operand and result shape: both are scalars, vectors/tensors of the same
/// shape.
template <typename ConcreteType>
};
/// This class provides verification for ops that are known to have the same
+/// operand element type.
+///
+template <typename ConcreteType>
+class SameOperandsElementType
+ : public TraitBase<ConcreteType, SameOperandsElementType> {
+public:
+ static LogicalResult verifyTrait(Operation *op) {
+ return impl::verifySameOperandsElementType(op);
+ }
+};
+
+/// This class provides verification for ops that are known to have the same
/// operand and result element type.
///
template <typename ConcreteType>
}
//===----------------------------------------------------------------------===//
+// Helpers for Tensor[Load|Store]Op
+//===----------------------------------------------------------------------===//
+
+static Type getTensorTypeFromMemRefType(Builder &b, Type type) {
+ if (auto memref = type.dyn_cast<MemRefType>())
+ return b.getTensorType(memref.getShape(), memref.getElementType());
+ return b.getNoneType();
+}
+
+//===----------------------------------------------------------------------===//
+// TensorLoadOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter *p, TensorLoadOp op) {
+ *p << "tensor_load " << *op.getOperand();
+ p->printOptionalAttrDict(op.getAttrs());
+ *p << " : " << op.getOperand()->getType();
+}
+
+static ParseResult parseTensorLoadOp(OpAsmParser *parser,
+ OperationState *result) {
+ OpAsmParser::OperandType op;
+ Type type;
+ return failure(parser->parseOperand(op) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(type) ||
+ parser->resolveOperand(op, type, result->operands) ||
+ parser->addTypeToList(
+ getTensorTypeFromMemRefType(parser->getBuilder(), type),
+ result->types));
+}
+
+//===----------------------------------------------------------------------===//
+// TensorStoreOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter *p, TensorStoreOp op) {
+ *p << "tensor_store " << *op.tensor() << ", " << *op.memref();
+ p->printOptionalAttrDict(op.getAttrs());
+ *p << " : " << op.memref()->getType();
+}
+
+static ParseResult parseTensorStoreOp(OpAsmParser *parser,
+ OperationState *result) {
+ SmallVector<OpAsmParser::OperandType, 2> ops;
+ Type type;
+ llvm::SMLoc loc = parser->getCurrentLocation();
+ return failure(
+ parser->parseOperandList(ops, /*requiredOperandCount=*/2) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(type) ||
+ parser->resolveOperands(
+ ops, {getTensorTypeFromMemRefType(parser->getBuilder(), type), type},
+ loc, result->operands));
+}
+
+//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
return success(sType1.getShape() == sType2.getShape());
}
+LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) {
+ if (op->getNumOperands() == 0)
+ return failure();
+
+ auto type = op->getOperand(0)->getType();
+ for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
+ if (failed(verifyShapeMatch(opType, type)))
+ return op->emitOpError() << "requires the same shape for all operands";
+ }
+ return success();
+}
+
LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {
if (op->getNumOperands() == 0 || op->getNumResults() == 0)
return failure();
return success();
}
+LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) {
+ if (op->getNumOperands() == 0)
+ return failure();
+
+ auto type = op->getOperand(0)->getType().dyn_cast<ShapedType>();
+ if (!type)
+ return op->emitOpError("requires shaped type results");
+ auto elementType = type.getElementType();
+
+ for (auto operandType : llvm::drop_begin(op->getOperandTypes(), 1)) {
+ auto shapedType = operandType.dyn_cast<ShapedType>();
+ if (!shapedType)
+ return op->emitOpError("requires shaped type operands");
+ if (shapedType.getElementType() != elementType)
+ return op->emitOpError("requires the same element type for all operands");
+ }
+
+ return success();
+}
+
LogicalResult
OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) {
if (op->getNumOperands() == 0 || op->getNumResults() == 0)
return
}
+// CHECK-LABEL: func @tensor_load_store
+func @tensor_load_store(%0 : memref<4x4xi32>) {
+ // CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF:.*]] : memref<4x4xi32>
+ %1 = tensor_load %0 : memref<4x4xi32>
+ // CHECK: tensor_store %[[TENSOR]], %[[MEMREF]] : memref<4x4xi32>
+ tensor_store %1, %0 : memref<4x4xi32>
+ return
+}
// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+// CHECK: succeededSameOperandsElementType
+func @succeededSameOperandsElementType(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>) {
+ %0 = "test.same_operand_type"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi32>
+ %1 = "test.same_operand_type"(%t1, %t10x10) : (tensor<1xf32>, tensor<10x10xf32>) -> tensor<1xi32>
+ %2 = "test.same_operand_type"(%t10x10, %v1) : (tensor<10x10xf32>, vector<1xf32>) -> tensor<1xi32>
+ %3 = "test.same_operand_type"(%v1, %t1) : (vector<1xf32>, tensor<1xf32>) -> tensor<1xi32>
+ %4 = "test.same_operand_type"(%v1, %t1) : (vector<1xf32>, tensor<1xf32>) -> tensor<121xi32>
+ return
+}
+
+// -----
+
+func @failedSameOperandElementType(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>) {
+ // expected-error@+1 {{requires the same element type for all operands}}
+ %0 = "test.same_operand_type"(%t1, %t1i) : (tensor<1xf32>, tensor<1xi32>) -> tensor<1xf32>
+}
+
+// -----
+
// CHECK: succeededSameOperandAndResultElementType
func @succeededSameOperandAndResultElementType(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>) {
%0 = "test.same_operand_and_result_type"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
// -----
+// CHECK: succeededSameOperandShape
+func @succeededSameOperandShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %tr: tensor<*xf32>) {
+ %0 = "test.same_operand_shape"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> (tensor<10x10xf32>)
+ %1 = "test.same_operand_shape"(%t10x10, %t10x10) : (tensor<10x10xf32>, tensor<10x10xf32>) -> (tensor<1xf32>)
+ %2 = "test.same_operand_shape"(%t1, %tr) : (tensor<1xf32>, tensor<*xf32>) -> (tensor<10x10xf32>)
+ return
+}
+
+// -----
+
+func @failedSameOperandShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>) {
+ // expected-error@+1 {{requires the same shape for all operands}}
+ %0 = "test.same_operand_shape"(%t1, %t10x10) : (tensor<1xf32>, tensor<10x10xf32>) -> (tensor<1xf32>)
+}
+
+// -----
+
// CHECK: succeededSameOperandAndResultShape
func @succeededSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %tr: tensor<*xf32>) {
%0 = "test.same_operand_and_result_shape"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
// -----
-func @succeededSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %v1: vector<1xf32>) {
+func @failedSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>) {
// expected-error@+1 {{requires the same shape for all operands and results}}
%0 = "test.same_operand_and_result_shape"(%t1, %t10x10) : (tensor<1xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
}
// Test Traits
//===----------------------------------------------------------------------===//
+def SameOperandElementTypeOp : TEST_Op<"same_operand_type",
+ [SameOperandsElementType]> {
+ let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y);
+ let results = (outs AnyVectorOrTensor:$res);
+}
+
def SameOperandAndResultElementTypeOp : TEST_Op<"same_operand_and_result_type",
[SameOperandsAndResultElementType]> {
let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y);
let results = (outs AnyVectorOrTensor:$res);
}
+def SameOperandShapeOp : TEST_Op<"same_operand_shape", [SameOperandsShape]> {
+ let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y);
+ let results = (outs AnyVectorOrTensor:$res);
+}
+
def SameOperandAndResultShapeOp : TEST_Op<"same_operand_and_result_shape",
[SameOperandsAndResultShape]> {
let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y);