From b5654d1311ffb2dc1f7f9803d36e4e503bfcc9dd Mon Sep 17 00:00:00 2001
From: Andy Davis
Date: Wed, 6 Nov 2019 11:25:16 -0800
Subject: [PATCH] Add ViewOp verification for dynamic strides, and address
 some comments from previous change.

PiperOrigin-RevId: 278903187
---
 mlir/include/mlir/Dialect/StandardOps/Ops.td |  3 ++-
 mlir/lib/Dialect/StandardOps/Ops.cpp         | 32 +++++++++++++++++-----------
 mlir/test/IR/core-ops.mlir                   | 20 +++++++++--------
 mlir/test/IR/invalid-ops.mlir                | 25 +++++++++++++++++++++-
 4 files changed, 57 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td
index 4dbdcc4..10f9438 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -1169,7 +1169,8 @@ def ViewOp : Std_Op<"view"> {
       (d0, d1, d2)[s0, s1] -> (d0 * s0 + d1 * 4 + d2 + s1)
   }];
 
-  let arguments = (ins AnyMemRef:$source, Variadic<Index>:$operands);
+  let arguments = (ins MemRefRankOf<[I8], [1]>:$source,
+                   Variadic<Index>:$operands);
   let results = (outs AnyMemRef);
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index e6b9903..5a452c5 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -2376,17 +2376,8 @@ static void print(OpAsmPrinter &p, ViewOp op) {
 }
 
 static LogicalResult verify(ViewOp op) {
-  auto baseType = op.getOperand(0)->getType().dyn_cast<MemRefType>();
-  auto viewType = op.getResult()->getType().dyn_cast<MemRefType>();
-
-  // Operand 0 type and ViewOp result type must be memref.
-  if (!baseType || !viewType)
-    return op.emitError("operand type ") << baseType << " and result type "
-                                         << viewType << " are must be memref";
-
-  // The base memref should be rank 1 with i8 element type.
-  if (baseType.getRank() != 1 || !baseType.getElementType().isInteger(8))
-    return op.emitError("unsupported shape for base memref type ") << baseType;
+  auto baseType = op.getOperand(0)->getType().cast<MemRefType>();
+  auto viewType = op.getResult()->getType().cast<MemRefType>();
 
   // The base memref should have identity layout map (or none).
   if (baseType.getAffineMaps().size() > 1 ||
@@ -2403,7 +2394,7 @@
   // Verify that the result memref type has a strided layout map. is strided
   int64_t offset;
   llvm::SmallVector<int64_t, 4> strides;
-  if (failed(mlir::getStridesAndOffset(viewType, strides, offset)))
+  if (failed(getStridesAndOffset(viewType, strides, offset)))
     return op.emitError("result type ") << viewType << " is not strided";
 
   // Verify that we have the correct number of operands for the result type.
@@ -2414,6 +2405,23 @@
   if (op.getNumOperands() !=
       memrefOperandCount + numDynamicDims + dynamicOffsetCount)
     return op.emitError("incorrect number of operands for type ") << viewType;
+
+  // Verify dynamic strides symbols were added to correct dimensions based
+  // on dynamic sizes.
+  ArrayRef<int64_t> viewShape = viewType.getShape();
+  unsigned viewRank = viewType.getRank();
+  assert(viewRank == strides.size());
+  bool dynamicStrides = false;
+  for (int i = viewRank - 2; i >= 0; --i) {
+    // If size at dim 'i + 1' is dynamic, set the 'dynamicStrides' flag.
+    if (ShapedType::isDynamic(viewShape[i + 1]))
+      dynamicStrides = true;
+    // If stride at dim 'i' is not dynamic, return error.
+    if (dynamicStrides && strides[i] != MemRefType::getDynamicStrideOrOffset())
+      return op.emitError("incorrect dynamic strides in view memref type ")
+             << viewType;
+  }
+
   return success();
 }
 
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 977ec66..bbabe60 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -10,8 +10,10 @@
 // CHECK-DAG: #[[map_proj_d0d1_d0:map[0-9]+]] = (d0, d1) -> (d0)
 // CHECK-DAG: #[[map_proj_d0d1_d1:map[0-9]+]] = (d0, d1) -> (d1)
 // CHECK-DAG: #[[map_proj_d0d1_d1d0:map[0-9]+]] = (d0, d1) -> (d1, d0)
-// CHECK-DAG: #[[VIEW_MAP0:map[0-9]+]] = (d0, d1)[s0] -> (d0 * 4 + d1 + s0)
+
 // CHECK-DAG: #[[VIEW_MAP1:map[0-9]+]] = (d0, d1) -> (d0 * 4 + d1)
+// CHECK-DAG: #[[VIEW_MAP2:map[0-9]+]] = (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)
+// CHECK-DAG: #[[VIEW_MAP3:map[0-9]+]] = (d0, d1)[s0] -> (d0 * s0 + d1)
 
 // CHECK-LABEL: func @func_with_ops(%arg0: f32) {
 func @func_with_ops(f32) {
@@ -478,24 +480,24 @@ func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref) {
 func @memref_view(%arg0 : index, %arg1 : index, %arg2 : index) {
   %0 = alloc() : memref<2048xi8>
   // Test two dynamic sizes and dynamic offset.
-  // CHECK: %{{.*}} = std.view %0[%arg0, %arg1][%arg2] : memref<2048xi8> to memref<?x?xf32, #[[VIEW_MAP0]]>
+  // CHECK: %{{.*}} = std.view %0[%arg0, %arg1][%arg2] : memref<2048xi8> to memref<?x?xf32, #[[VIEW_MAP2]]>
   %1 = view %0[%arg0, %arg1][%arg2]
-    : memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0)>
+    : memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)>
 
   // Test two dynamic sizes and static offset.
-  // CHECK: %{{.*}} = std.view %0[%arg0, %arg1][] : memref<2048xi8> to memref<?x?xf32, #[[VIEW_MAP1]]>
+  // CHECK: %{{.*}} = std.view %0[%arg0, %arg1][] : memref<2048xi8> to memref<?x?xf32, #[[VIEW_MAP3]]>
   %2 = view %0[%arg0, %arg1][]
-    : memref<2048xi8> to memref<?x?xf32, (d0, d1) -> (d0 * 4 + d1)>
+    : memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0] -> (d0 * s0 + d1)>
 
   // Test one dynamic size and dynamic offset.
-  // CHECK: %{{.*}} = std.view %0[%arg1][%arg2] : memref<2048xi8> to memref<4x?xf32, #[[VIEW_MAP0]]>
+  // CHECK: %{{.*}} = std.view %0[%arg1][%arg2] : memref<2048xi8> to memref<4x?xf32, #[[VIEW_MAP2]]>
   %3 = view %0[%arg1][%arg2]
-    : memref<2048xi8> to memref<4x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0)>
+    : memref<2048xi8> to memref<4x?xf32, (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)>
 
   // Test one dynamic size and static offset.
-  // CHECK: %{{.*}} = std.view %0[%arg0][] : memref<2048xi8> to memref
+  // CHECK: %{{.*}} = std.view %0[%arg0][] : memref<2048xi8> to memref
   %4 = view %0[%arg0][]
-    : memref<2048xi8> to memref (d0 * 4 + d1)>
+    : memref<2048xi8> to memref (d0 * 4 + d1)>
 
   // Test static sizes and static offset.
   // CHECK: %{{.*}} = std.view %0[][] : memref<2048xi8> to memref<64x4xf32, #[[VIEW_MAP1]]>
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 4d45d95..4d1d853 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -926,7 +926,7 @@ func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
 
 func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
   %0 = alloc() : memref<2048xf32>
-  // expected-error@+1 {{unsupported shape for base memref}}
+  // expected-error@+1 {{must be 1D memref of 8-bit integer values}}
   %1 = view %0[%arg0, %arg1][]
     : memref<2048xf32> to memref<?x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0)>
   return
@@ -953,3 +953,26 @@ func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
       memref<?x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0), 1>
   return
 }
+
+// -----
+
+func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
+  %0 = alloc() : memref<2048xi8>
+  // expected-error@+1 {{incorrect dynamic strides}}
+  %1 = view %0[%arg0, %arg1][]
+    : memref<2048xi8> to
+      memref (d0 * 777 + d1 * 4 + d2)>
+  return
+}
+
+// -----
+
+func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
+  %0 = alloc() : memref<2048xi8>
+  // expected-error@+1 {{incorrect dynamic strides}}
+  %1 = view %0[%arg0][]
+    : memref<2048xi8> to
+      memref<16x4x?xf32, (d0, d1, d2) -> (d0 * 777 + d1 * 4 + d2)>
+  return
+}
+
-- 
2.7.4
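As a companion to the verifier change above, here is a minimal standalone sketch of the stride rule it enforces, kept outside of MLIR so it compiles on its own: sizes and strides are passed as plain vectors, -1 stands in for MLIR's dynamic size and dynamic stride/offset markers, and the helper name checkViewStrides is made up for illustration. Walking from the second-innermost dimension outward, once any size to the right is dynamic, the stride of the current dimension must be dynamic as well.

// Illustrative only: -1 marks a dynamic size or stride, standing in for
// ShapedType::isDynamic / MemRefType::getDynamicStrideOrOffset() in MLIR.
#include <cassert>
#include <cstdio>
#include <vector>

constexpr long kDynamic = -1;

// Returns true if (shape, strides) obey the rule added by this patch: walking
// from dim rank-2 down to dim 0, once a dynamic size has been seen at any dim
// to the right, the stride of the current dim must be dynamic as well.
bool checkViewStrides(const std::vector<long> &shape,
                      const std::vector<long> &strides) {
  assert(shape.size() == strides.size());
  bool dynamicStrides = false;
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
    if (shape[i + 1] == kDynamic)
      dynamicStrides = true; // a size to the right of dim 'i' is dynamic
    if (dynamicStrides && strides[i] != kDynamic)
      return false; // stride at dim 'i' must be dynamic but is not
  }
  return true;
}

int main() {
  // ?x4 with strides [4, 1]: accepted, only the outermost size is dynamic.
  std::printf("%d\n", checkViewStrides({kDynamic, 4}, {4, 1}));
  // 4x? with strides [4, 1]: rejected, the stride of dim 0 must be dynamic.
  std::printf("%d\n", checkViewStrides({4, kDynamic}, {4, 1}));
  // 16x4x? with strides [777, 4, 1], as in the new invalid-ops.mlir test:
  // rejected for the same reason.
  std::printf("%d\n", checkViewStrides({16, 4, kDynamic}, {777, 4, 1}));
  return 0;
}

The last call mirrors the 16x4x?xf32 case from the new invalid-ops.mlir test, which the verifier now rejects with "incorrect dynamic strides in view memref type".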