Add ViewOp verification for dynamic strides, and address some comments from previous...
authorAndy Davis <andydavis@google.com>
Wed, 6 Nov 2019 19:25:16 +0000 (11:25 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 6 Nov 2019 19:25:54 +0000 (11:25 -0800)
PiperOrigin-RevId: 278903187

mlir/include/mlir/Dialect/StandardOps/Ops.td
mlir/lib/Dialect/StandardOps/Ops.cpp
mlir/test/IR/core-ops.mlir
mlir/test/IR/invalid-ops.mlir

index 4dbdcc4..10f9438 100644 (file)
@@ -1169,7 +1169,8 @@ def ViewOp : Std_Op<"view"> {
         (d0, d1, d2)[s0, s1] -> (d0 * s0 + d1 * 4 + d2 + s1)
   }];
 
-  let arguments = (ins AnyMemRef:$source, Variadic<Index>:$operands);
+  let arguments = (ins MemRefRankOf<[I8], [1]>:$source,
+                       Variadic<Index>:$operands);
   let results = (outs AnyMemRef);
 
   let extraClassDeclaration = [{
index e6b9903..5a452c5 100644 (file)
@@ -2376,17 +2376,8 @@ static void print(OpAsmPrinter &p, ViewOp op) {
 }
 
 static LogicalResult verify(ViewOp op) {
-  auto baseType = op.getOperand(0)->getType().dyn_cast<MemRefType>();
-  auto viewType = op.getResult()->getType().dyn_cast<MemRefType>();
-
-  // Operand 0 type and ViewOp result type must be memref.
-  if (!baseType || !viewType)
-    return op.emitError("operand type ") << baseType << " and result type "
-                                         << viewType << " are must be memref";
-
-  // The base memref should be rank 1 with i8 element type.
-  if (baseType.getRank() != 1 || !baseType.getElementType().isInteger(8))
-    return op.emitError("unsupported shape for base memref type ") << baseType;
+  auto baseType = op.getOperand(0)->getType().cast<MemRefType>();
+  auto viewType = op.getResult()->getType().cast<MemRefType>();
 
   // The base memref should have identity layout map (or none).
   if (baseType.getAffineMaps().size() > 1 ||
@@ -2403,7 +2394,7 @@ static LogicalResult verify(ViewOp op) {
   // Verify that the result memref type has a strided layout map. is strided
   int64_t offset;
   llvm::SmallVector<int64_t, 4> strides;
-  if (failed(mlir::getStridesAndOffset(viewType, strides, offset)))
+  if (failed(getStridesAndOffset(viewType, strides, offset)))
     return op.emitError("result type ") << viewType << " is not strided";
 
   // Verify that we have the correct number of operands for the result type.
@@ -2414,6 +2405,23 @@ static LogicalResult verify(ViewOp op) {
   if (op.getNumOperands() !=
       memrefOperandCount + numDynamicDims + dynamicOffsetCount)
     return op.emitError("incorrect number of operands for type ") << viewType;
+
+  // Verify dynamic strides symbols were added to correct dimensions based
+  // on dynamic sizes.
+  ArrayRef<int64_t> viewShape = viewType.getShape();
+  unsigned viewRank = viewType.getRank();
+  assert(viewRank == strides.size());
+  bool dynamicStrides = false;
+  for (int i = viewRank - 2; i >= 0; --i) {
+    // If size at dim 'i + 1' is dynamic, set the 'dynamicStrides' flag.
+    if (ShapedType::isDynamic(viewShape[i + 1]))
+      dynamicStrides = true;
+    // If stride at dim 'i' is not dynamic, return error.
+    if (dynamicStrides && strides[i] != MemRefType::getDynamicStrideOrOffset())
+      return op.emitError("incorrect dynamic strides in view memref type ")
+             << viewType;
+  }
+
   return success();
 }
 
index 977ec66..bbabe60 100644 (file)
 // CHECK-DAG: #[[map_proj_d0d1_d0:map[0-9]+]] = (d0, d1) -> (d0)
 // CHECK-DAG: #[[map_proj_d0d1_d1:map[0-9]+]] = (d0, d1) -> (d1)
 // CHECK-DAG: #[[map_proj_d0d1_d1d0:map[0-9]+]] = (d0, d1) -> (d1, d0)
-// CHECK-DAG: #[[VIEW_MAP0:map[0-9]+]] = (d0, d1)[s0] -> (d0 * 4 + d1 + s0)
+
 // CHECK-DAG: #[[VIEW_MAP1:map[0-9]+]] = (d0, d1) -> (d0 * 4 + d1)
+// CHECK-DAG: #[[VIEW_MAP2:map[0-9]+]] = (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)
+// CHECK-DAG: #[[VIEW_MAP3:map[0-9]+]] = (d0, d1)[s0] -> (d0 * s0 + d1)
 
 // CHECK-LABEL: func @func_with_ops(%arg0: f32) {
 func @func_with_ops(f32) {
@@ -478,24 +480,24 @@ func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref<?xf32>) {
 func @memref_view(%arg0 : index, %arg1 : index, %arg2 : index) {
   %0 = alloc() : memref<2048xi8>
   // Test two dynamic sizes and dynamic offset.
-  // CHECK: %{{.*}} = std.view %0[%arg0, %arg1][%arg2] : memref<2048xi8> to memref<?x?xf32, #[[VIEW_MAP0]]>
+  // CHECK: %{{.*}} = std.view %0[%arg0, %arg1][%arg2] : memref<2048xi8> to memref<?x?xf32, #[[VIEW_MAP2]]>
   %1 = view %0[%arg0, %arg1][%arg2]
-    : memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0)>
+    : memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)>
 
   // Test two dynamic sizes and static offset.
-  // CHECK: %{{.*}} = std.view %0[%arg0, %arg1][] : memref<2048xi8> to memref<?x?xf32, #[[VIEW_MAP1]]>
+  // CHECK: %{{.*}} = std.view %0[%arg0, %arg1][] : memref<2048xi8> to memref<?x?xf32, #[[VIEW_MAP3]]>
   %2 = view %0[%arg0, %arg1][]
-    : memref<2048xi8> to memref<?x?xf32, (d0, d1) -> (d0 * 4 + d1)>
+    : memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0] -> (d0 * s0 + d1)>
 
   // Test one dynamic size and dynamic offset.
-  // CHECK: %{{.*}} = std.view %0[%arg1][%arg2] : memref<2048xi8> to memref<4x?xf32, #[[VIEW_MAP0]]>
+  // CHECK: %{{.*}} = std.view %0[%arg1][%arg2] : memref<2048xi8> to memref<4x?xf32, #[[VIEW_MAP2]]>
   %3 = view %0[%arg1][%arg2]
-    : memref<2048xi8> to memref<4x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0)>
+    : memref<2048xi8> to memref<4x?xf32, (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)>
 
   // Test one dynamic size and static offset.
-  // CHECK: %{{.*}} = std.view %0[%arg0][] : memref<2048xi8> to memref<?x16xf32, #[[VIEW_MAP1]]>
+  // CHECK: %{{.*}} = std.view %0[%arg0][] : memref<2048xi8> to memref<?x4xf32, #[[VIEW_MAP1]]>
   %4 = view %0[%arg0][]
-    : memref<2048xi8> to memref<?x16xf32, (d0, d1) -> (d0 * 4 + d1)>
+    : memref<2048xi8> to memref<?x4xf32, (d0, d1) -> (d0 * 4 + d1)>
 
   // Test static sizes and static offset.
   // CHECK: %{{.*}} = std.view %0[][] : memref<2048xi8> to memref<64x4xf32, #[[VIEW_MAP1]]>
index 4d45d95..4d1d853 100644 (file)
@@ -926,7 +926,7 @@ func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
 
 func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
   %0 = alloc() : memref<2048xf32>
-  // expected-error@+1 {{unsupported shape for base memref}}
+  // expected-error@+1 {{must be 1D memref of 8-bit integer values}}
   %1 = view %0[%arg0, %arg1][]
     : memref<2048xf32> to memref<?x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0)>
   return
@@ -953,3 +953,26 @@ func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
       memref<?x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0), 1>
   return
 }
+
+// -----
+
+func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
+  %0 = alloc() : memref<2048xi8>
+  // expected-error@+1 {{incorrect dynamic strides}}
+  %1 = view %0[%arg0, %arg1][]
+    : memref<2048xi8> to
+      memref<?x?x4xf32, (d0, d1, d2) -> (d0 * 777 + d1 * 4 + d2)>
+  return
+}
+
+// -----
+
+func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
+  %0 = alloc() : memref<2048xi8>
+  // expected-error@+1 {{incorrect dynamic strides}}
+  %1 = view %0[%arg0][]
+    : memref<2048xi8> to
+      memref<16x4x?xf32, (d0, d1, d2) -> (d0 * 777 + d1 * 4 + d2)>
+  return
+}
+