Add implementation for tensor_load and tensor_store operations.
authorStephan Herhut <herhut@google.com>
Wed, 28 Aug 2019 18:25:19 +0000 (11:25 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 28 Aug 2019 18:25:52 +0000 (11:25 -0700)
This change adds definitions, parsing and verification for both ops.

PiperOrigin-RevId: 265954051

mlir/include/mlir/Dialect/StandardOps/Ops.td
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OpDefinition.h
mlir/lib/Dialect/StandardOps/Ops.cpp
mlir/lib/IR/Operation.cpp
mlir/test/IR/core-ops.mlir
mlir/test/IR/traits.mlir
mlir/test/lib/TestDialect/TestOps.td

index b6bf2cf..37f2ac7 100644 (file)
@@ -897,6 +897,60 @@ def TensorCastOp : CastOp<"tensor_cast"> {
   }];
 }
 
+def TensorLoadOp : Std_Op<"tensor_load",
+    [SameOperandsAndResultShape, SameOperandsAndResultElementType]> {
+  let summary = "tensor load operation";
+  let description = [{
+    The "tensor_load" operation creates a tensor from a memref, making an
+    independent copy of the element data. The result value is a tensor whose
+    shape and element type match the memref operand.
+
+    Produce a value of tensor<4x?xf32> type.
+       %12 = tensor_load %10 : memref<4x?xf32, #layout, memspace0>
+  }];
+
+  let arguments = (ins AnyMemRef);
+  let results = (outs AnyTensor);
+  // TensorLoadOp is fully verified by traits.
+  let verifier = ?;
+
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState *result, Value *memref", [{
+      auto memrefType = memref->getType().cast<MemRefType>();
+      auto resultType = builder->getTensorType(memrefType.getShape(),
+          memrefType.getElementType());
+      result->addOperands(memref);
+      result->addTypes(resultType);
+  }]>];
+
+
+  let extraClassDeclaration = [{
+    /// The result of a tensor_load is always a tensor.
+    TensorType getType() { return getResult()->getType().cast<TensorType>(); }
+  }];
+}
+
+def TensorStoreOp : Std_Op<"tensor_store",
+    [SameOperandsShape, SameOperandsElementType]> {
+  let summary = "tensor store operation";
+  let description = [{
+    The "tensor_store" operation stores the contents of a tensor into a memref.
+    The first operand is a value of tensor type, the second operand is a value
+    of memref type. The shapes and element types of these must match, and are
+    specified by the memref type.
+
+    Example:
+       %9 = dim %8, 1 : tensor<4x?xf32>
+       %10 = alloc(%9) : memref<4x?xf32, #layout, memspace0>
+       tensor_store %8, %10 : memref<4x?xf32, #layout, memspace0>
+  }];
+
+  let arguments = (ins AnyTensor:$tensor, AnyMemRef:$memref);
+  // TensorStoreOp is fully verified by traits.
+  let verifier = ?;
+}
+
+
 def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
   let summary = "integer binary xor";
   let hasFolder = 1;
index b88613f..dd9d4e2 100644 (file)
@@ -1073,21 +1073,25 @@ class PredOpTrait<string descr, Pred pred> : OpTrait {
 }
 
 // Op supports operand broadcast behavior.
-def Broadcastable    : NativeOpTrait<"BroadcastableTwoOperandsOneResult">;
+def Broadcastable  : NativeOpTrait<"BroadcastableTwoOperandsOneResult">;
 // X op Y == Y op X
-def Commutative      : NativeOpTrait<"IsCommutative">;
+def Commutative  : NativeOpTrait<"IsCommutative">;
 // Op is isolated from above.
 def IsolatedFromAbove : NativeOpTrait<"IsIsolatedFromAbove">;
 // Op results are float or vectors/tensors thereof.
 def ResultsAreFloatLike : NativeOpTrait<"ResultsAreFloatLike">;
 // Op has no side effect.
-def NoSideEffect     : NativeOpTrait<"HasNoSideEffect">;
+def NoSideEffect : NativeOpTrait<"HasNoSideEffect">;
 // Op has the same operand type.
-def SameTypeOperands  : NativeOpTrait<"SameTypeOperands">;
+def SameTypeOperands : NativeOpTrait<"SameTypeOperands">;
+// Op has same shape for all operands.
+def SameOperandsShape : NativeOpTrait<"SameOperandsShape">;
 // Op has same operand and result shape.
-def SameOperandsAndResultShape   : NativeOpTrait<"SameOperandsAndResultShape">;
+def SameOperandsAndResultShape : NativeOpTrait<"SameOperandsAndResultShape">;
 // Op has the same operand and result type.
-def SameOperandsAndResultType    : NativeOpTrait<"SameOperandsAndResultType">;
+def SameOperandsAndResultType : NativeOpTrait<"SameOperandsAndResultType">;
+// Op has the same element type for all operands.
+def SameOperandsElementType : NativeOpTrait<"SameOperandsElementType">;
 // Op has the same operand and result element type.
 def SameOperandsAndResultElementType :
   NativeOpTrait<"SameOperandsAndResultElementType">;
index fd35262..570990a 100644 (file)
@@ -357,7 +357,9 @@ LogicalResult verifyZeroResult(Operation *op);
 LogicalResult verifyOneResult(Operation *op);
 LogicalResult verifyNResults(Operation *op, unsigned numOperands);
 LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands);
+LogicalResult verifySameOperandsShape(Operation *op);
 LogicalResult verifySameOperandsAndResultShape(Operation *op);
+LogicalResult verifySameOperandsElementType(Operation *op);
 LogicalResult verifySameOperandsAndResultElementType(Operation *op);
 LogicalResult verifySameOperandsAndResultType(Operation *op);
 LogicalResult verifyResultsAreBoolLike(Operation *op);
@@ -626,6 +628,17 @@ class VariadicResults
     : public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {};
 
 /// This class provides verification for ops that are known to have the same
+/// operand shape: all operands are scalars, vectors/tensors of the same
+/// shape.
+template <typename ConcreteType>
+class SameOperandsShape : public TraitBase<ConcreteType, SameOperandsShape> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifySameOperandsShape(op);
+  }
+};
+
+/// This class provides verification for ops that are known to have the same
 /// operand and result shape: both are scalars, vectors/tensors of the same
 /// shape.
 template <typename ConcreteType>
@@ -638,6 +651,18 @@ public:
 };
 
 /// This class provides verification for ops that are known to have the same
+/// operand element type.
+///
+template <typename ConcreteType>
+class SameOperandsElementType
+    : public TraitBase<ConcreteType, SameOperandsElementType> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifySameOperandsElementType(op);
+  }
+};
+
+/// This class provides verification for ops that are known to have the same
 /// operand and result element type.
 ///
 template <typename ConcreteType>
index 120e45a..9f49053 100644 (file)
@@ -2132,6 +2132,63 @@ OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
 }
 
 //===----------------------------------------------------------------------===//
+// Helpers for Tensor[Load|Store]Op
+//===----------------------------------------------------------------------===//
+
+static Type getTensorTypeFromMemRefType(Builder &b, Type type) {
+  if (auto memref = type.dyn_cast<MemRefType>())
+    return b.getTensorType(memref.getShape(), memref.getElementType());
+  return b.getNoneType();
+}
+
+//===----------------------------------------------------------------------===//
+// TensorLoadOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter *p, TensorLoadOp op) {
+  *p << "tensor_load " << *op.getOperand();
+  p->printOptionalAttrDict(op.getAttrs());
+  *p << " : " << op.getOperand()->getType();
+}
+
+static ParseResult parseTensorLoadOp(OpAsmParser *parser,
+                                     OperationState *result) {
+  OpAsmParser::OperandType op;
+  Type type;
+  return failure(parser->parseOperand(op) ||
+                 parser->parseOptionalAttributeDict(result->attributes) ||
+                 parser->parseColonType(type) ||
+                 parser->resolveOperand(op, type, result->operands) ||
+                 parser->addTypeToList(
+                     getTensorTypeFromMemRefType(parser->getBuilder(), type),
+                     result->types));
+}
+
+//===----------------------------------------------------------------------===//
+// TensorStoreOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter *p, TensorStoreOp op) {
+  *p << "tensor_store " << *op.tensor() << ", " << *op.memref();
+  p->printOptionalAttrDict(op.getAttrs());
+  *p << " : " << op.memref()->getType();
+}
+
+static ParseResult parseTensorStoreOp(OpAsmParser *parser,
+                                      OperationState *result) {
+  SmallVector<OpAsmParser::OperandType, 2> ops;
+  Type type;
+  llvm::SMLoc loc = parser->getCurrentLocation();
+  return failure(
+      parser->parseOperandList(ops, /*requiredOperandCount=*/2) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(type) ||
+      parser->resolveOperands(
+          ops, {getTensorTypeFromMemRefType(parser->getBuilder(), type), type},
+          loc, result->operands));
+}
+
+//===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
 
index 205d561..a623e39 100644 (file)
@@ -770,6 +770,18 @@ static LogicalResult verifyShapeMatch(Type type1, Type type2) {
   return success(sType1.getShape() == sType2.getShape());
 }
 
+LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) {
+  if (op->getNumOperands() == 0)
+    return failure();
+
+  auto type = op->getOperand(0)->getType();
+  for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
+    if (failed(verifyShapeMatch(opType, type)))
+      return op->emitOpError() << "requires the same shape for all operands";
+  }
+  return success();
+}
+
 LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {
   if (op->getNumOperands() == 0 || op->getNumResults() == 0)
     return failure();
@@ -788,6 +800,26 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {
   return success();
 }
 
+LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) {
+  if (op->getNumOperands() == 0)
+    return failure();
+
+  auto type = op->getOperand(0)->getType().dyn_cast<ShapedType>();
+  if (!type)
+    return op->emitOpError("requires shaped type results");
+  auto elementType = type.getElementType();
+
+  for (auto operandType : llvm::drop_begin(op->getOperandTypes(), 1)) {
+    auto shapedType = operandType.dyn_cast<ShapedType>();
+    if (!shapedType)
+      return op->emitOpError("requires shaped type operands");
+    if (shapedType.getElementType() != elementType)
+      return op->emitOpError("requires the same element type for all operands");
+  }
+
+  return success();
+}
+
 LogicalResult
 OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) {
   if (op->getNumOperands() == 0 || op->getNumResults() == 0)
index a2ea19e..2b91d68 100644 (file)
@@ -452,3 +452,11 @@ func @test_vector.transfer_ops(%arg0: memref<?x?xf32>) {
   return
 }
 
+// CHECK-LABEL: func @tensor_load_store
+func @tensor_load_store(%0 : memref<4x4xi32>) {
+  // CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF:.*]] : memref<4x4xi32>
+  %1 = tensor_load %0 : memref<4x4xi32>
+  // CHECK: tensor_store %[[TENSOR]], %[[MEMREF]] : memref<4x4xi32>
+  tensor_store %1, %0 : memref<4x4xi32>
+  return
+}
index 4eed6bb..40a4e96 100644 (file)
@@ -1,5 +1,24 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
 
+// CHECK: succeededSameOperandsElementType
+func @succeededSameOperandsElementType(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>) {
+  %0 = "test.same_operand_type"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi32>
+  %1 = "test.same_operand_type"(%t1, %t10x10) : (tensor<1xf32>, tensor<10x10xf32>) -> tensor<1xi32>
+  %2 = "test.same_operand_type"(%t10x10, %v1) : (tensor<10x10xf32>, vector<1xf32>) -> tensor<1xi32>
+  %3 = "test.same_operand_type"(%v1, %t1) : (vector<1xf32>, tensor<1xf32>) -> tensor<1xi32>
+  %4 = "test.same_operand_type"(%v1, %t1) : (vector<1xf32>, tensor<1xf32>) -> tensor<121xi32>
+  return
+}
+
+// -----
+
+func @failedSameOperandElementType(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>) {
+  // expected-error@+1 {{requires the same element type for all operands}}
+  %0 = "test.same_operand_type"(%t1, %t1i) : (tensor<1xf32>, tensor<1xi32>) -> tensor<1xf32>
+}
+
+// -----
+
 // CHECK: succeededSameOperandAndResultElementType
 func @succeededSameOperandAndResultElementType(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>) {
   %0 = "test.same_operand_and_result_type"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
@@ -26,6 +45,23 @@ func @failedSameOperandAndResultElementType(%t10x10 : tensor<10x10xf32>, %t1: te
 
 // -----
 
+// CHECK: succeededSameOperandShape
+func @succeededSameOperandShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %tr: tensor<*xf32>) {
+  %0 = "test.same_operand_shape"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> (tensor<10x10xf32>)
+  %1 = "test.same_operand_shape"(%t10x10, %t10x10) : (tensor<10x10xf32>, tensor<10x10xf32>) -> (tensor<1xf32>)
+  %2 = "test.same_operand_shape"(%t1, %tr) : (tensor<1xf32>, tensor<*xf32>) -> (tensor<10x10xf32>)
+  return
+}
+
+// -----
+
+func @failedSameOperandShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>) {
+  // expected-error@+1 {{requires the same shape for all operands}}
+  %0 = "test.same_operand_shape"(%t1, %t10x10) : (tensor<1xf32>, tensor<10x10xf32>) -> (tensor<1xf32>)
+}
+
+// -----
+
 // CHECK: succeededSameOperandAndResultShape
 func @succeededSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %tr: tensor<*xf32>) {
   %0 = "test.same_operand_and_result_shape"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
@@ -36,7 +72,7 @@ func @succeededSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tenso
 
 // -----
 
-func @succeededSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %v1: vector<1xf32>) {
+func @failedSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>) {
   // expected-error@+1 {{requires the same shape for all operands and results}}
   %0 = "test.same_operand_and_result_shape"(%t1, %t10x10) : (tensor<1xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
 }
index e2fdf37..f2d7aef 100644 (file)
@@ -165,12 +165,23 @@ def SizedRegionOp : TEST_Op<"sized_region_op", []> {
 // Test Traits
 //===----------------------------------------------------------------------===//
 
+def SameOperandElementTypeOp : TEST_Op<"same_operand_type",
+    [SameOperandsElementType]> {
+  let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y);
+  let results = (outs AnyVectorOrTensor:$res);
+}
+
 def SameOperandAndResultElementTypeOp : TEST_Op<"same_operand_and_result_type",
     [SameOperandsAndResultElementType]> {
   let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y);
   let results = (outs AnyVectorOrTensor:$res);
 }
 
+def SameOperandShapeOp : TEST_Op<"same_operand_shape", [SameOperandsShape]> {
+  let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y);
+  let results = (outs AnyVectorOrTensor:$res);
+}
+
 def SameOperandAndResultShapeOp : TEST_Op<"same_operand_and_result_shape",
     [SameOperandsAndResultShape]> {
   let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y);