Verify that affine.load/store/dma_start/dma_wait operands are valid dimension or...
authorAndy Davis <andydavis@google.com>
Fri, 26 Jul 2019 20:00:01 +0000 (13:00 -0700)
committerjpienaar <jpienaar@google.com>
Sat, 27 Jul 2019 15:20:38 +0000 (08:20 -0700)
PiperOrigin-RevId: 260197567

mlir/lib/AffineOps/AffineOps.cpp
mlir/test/AffineOps/load-store-invalid.mlir

index e730ba5..9a02623 100644 (file)
@@ -108,6 +108,13 @@ bool mlir::isValidSymbol(Value *value) {
   return isTopLevelSymbol(value);
 }
 
+// Returns true if 'value' is a valid index to an affine operation (e.g.
+// affine.load, affine.store, affine.dma_start, affine.dma_wait).
+// Returns false otherwise.
+static bool isValidAffineIndexOperand(Value *value) {
+  return isValidDim(value) || isValidSymbol(value);
+}
+
 /// Utility function to verify that a set of operands are valid dimension and
 /// symbol identifiers. The operands should be layed out such that the dimension
 /// operands are before the symbol operands. This function returns failure if
@@ -880,6 +887,25 @@ LogicalResult AffineDmaStartOp::verify() {
       getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
     return emitOpError("incorrect number of operands");
   }
+
+  for (auto *idx : getSrcIndices()) {
+    if (!idx->getType().isIndex())
+      return emitOpError("src index to dma_start must have 'index' type");
+    if (!isValidAffineIndexOperand(idx))
+      return emitOpError("src index must be a dimension or symbol identifier");
+  }
+  for (auto *idx : getDstIndices()) {
+    if (!idx->getType().isIndex())
+      return emitOpError("dst index to dma_start must have 'index' type");
+    if (!isValidAffineIndexOperand(idx))
+      return emitOpError("dst index must be a dimension or symbol identifier");
+  }
+  for (auto *idx : getTagIndices()) {
+    if (!idx->getType().isIndex())
+      return emitOpError("tag index to dma_start must have 'index' type");
+    if (!isValidAffineIndexOperand(idx))
+      return emitOpError("tag index must be a dimension or symbol identifier");
+  }
   return success();
 }
 
@@ -951,6 +977,12 @@ ParseResult AffineDmaWaitOp::parse(OpAsmParser *parser,
 LogicalResult AffineDmaWaitOp::verify() {
   if (!getOperand(0)->getType().isa<MemRefType>())
     return emitOpError("expected DMA tag to be of memref type");
+  for (auto *idx : getTagIndices()) {
+    if (!idx->getType().isIndex())
+      return emitOpError("index to dma_wait must have 'index' type");
+    if (!isValidAffineIndexOperand(idx))
+      return emitOpError("index must be a dimension or symbol identifier");
+  }
   return success();
 }
 
@@ -1549,7 +1581,6 @@ void AffineIfOp::setIntegerSet(IntegerSet newSet) {
 
 void AffineLoadOp::build(Builder *builder, OperationState *result,
                          AffineMap map, ArrayRef<Value *> operands) {
-  // TODO(b/133776335) Check that map operands are loop IVs or symbols.
   result->addOperands(operands);
   if (map)
     result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
@@ -1616,10 +1647,12 @@ LogicalResult AffineLoadOp::verify() {
           "expects the number of subscripts to be equal to memref rank");
   }
 
-  for (auto *idx : getIndices())
+  for (auto *idx : getIndices()) {
     if (!idx->getType().isIndex())
       return emitOpError("index to load must have 'index' type");
-  // TODO(b/133776335) Verify that map operands are loop IVs or symbols.
+    if (!isValidAffineIndexOperand(idx))
+      return emitOpError("index must be a dimension or symbol identifier");
+  }
   return success();
 }
 
@@ -1637,7 +1670,6 @@ void AffineLoadOp::getCanonicalizationPatterns(
 void AffineStoreOp::build(Builder *builder, OperationState *result,
                           Value *valueToStore, AffineMap map,
                           ArrayRef<Value *> operands) {
-  // TODO(b/133776335) Check that map operands are loop IVs or symbols.
   result->addOperands(valueToStore);
   result->addOperands(operands);
   if (map)
@@ -1708,10 +1740,12 @@ LogicalResult AffineStoreOp::verify() {
           "expects the number of subscripts to be equal to memref rank");
   }
 
-  for (auto *idx : getIndices())
+  for (auto *idx : getIndices()) {
     if (!idx->getType().isIndex())
-      return emitOpError("index to load must have 'index' type");
-  // TODO(b/133776335) Verify that map operands are loop IVs or symbols.
+      return emitOpError("index to store must have 'index' type");
+    if (!isValidAffineIndexOperand(idx))
+      return emitOpError("index must be a dimension or symbol identifier");
+  }
   return success();
 }
 
index 6211023..b1aebc3 100644 (file)
@@ -59,3 +59,95 @@ func @store_too_few_subscripts_map(%arg0: memref<?x?xf32>, %arg1: index, %val: f
   "affine.store"(%val, %arg0, %arg1)
     {map = (i, j) -> (i, j) } : (f32, memref<?x?xf32>, index) -> ()
 }
+
+// -----
+
+func @load_non_affine_index(%arg0 : index) {
+  %0 = alloc() : memref<10xf32>
+  affine.for %i0 = 0 to 10 {
+    %1 = muli %i0, %arg0 : index
+    // expected-error@+1 {{op index must be a dimension or symbol identifier}}
+    %v = affine.load %0[%1] : memref<10xf32>
+  }
+  return
+}
+
+// -----
+
+func @store_non_affine_index(%arg0 : index) {
+  %0 = alloc() : memref<10xf32>
+  %1 = constant 11.0 : f32
+  affine.for %i0 = 0 to 10 {
+    %2 = muli %i0, %arg0 : index
+    // expected-error@+1 {{op index must be a dimension or symbol identifier}}
+    affine.store %1, %0[%2] : memref<10xf32>
+  }
+  return
+}
+
+// -----
+
+func @dma_start_non_affine_src_index(%arg0 : index) {
+  %0 = alloc() : memref<100xf32>
+  %1 = alloc() : memref<100xf32, 2>
+  %2 = alloc() : memref<1xi32, 4>
+  %c0 = constant 0 : index
+  %c64 = constant 64 : index
+  affine.for %i0 = 0 to 10 {
+    %3 = muli %i0, %arg0 : index
+    // expected-error@+1 {{op src index must be a dimension or symbol identifier}}
+    affine.dma_start %0[%3], %1[%i0], %2[%c0], %c64
+        : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
+  }
+  return
+}
+
+// -----
+
+func @dma_start_non_affine_dst_index(%arg0 : index) {
+  %0 = alloc() : memref<100xf32>
+  %1 = alloc() : memref<100xf32, 2>
+  %2 = alloc() : memref<1xi32, 4>
+  %c0 = constant 0 : index
+  %c64 = constant 64 : index
+  affine.for %i0 = 0 to 10 {
+    %3 = muli %i0, %arg0 : index
+    // expected-error@+1 {{op dst index must be a dimension or symbol identifier}}
+    affine.dma_start %0[%i0], %1[%3], %2[%c0], %c64
+        : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
+  }
+  return
+}
+
+// -----
+
+func @dma_start_non_affine_tag_index(%arg0 : index) {
+  %0 = alloc() : memref<100xf32>
+  %1 = alloc() : memref<100xf32, 2>
+  %2 = alloc() : memref<1xi32, 4>
+  %c0 = constant 0 : index
+  %c64 = constant 64 : index
+  affine.for %i0 = 0 to 10 {
+    %3 = muli %i0, %arg0 : index
+    // expected-error@+1 {{op tag index must be a dimension or symbol identifier}}
+    affine.dma_start %0[%i0], %1[%arg0], %2[%3], %c64
+        : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
+  }
+  return
+}
+
+// -----
+
+func @dma_wait_non_affine_tag_index(%arg0 : index) {
+  %0 = alloc() : memref<100xf32>
+  %1 = alloc() : memref<100xf32, 2>
+  %2 = alloc() : memref<1xi32, 4>
+  %c0 = constant 0 : index
+  %c64 = constant 64 : index
+  affine.for %i0 = 0 to 10 {
+    %3 = muli %i0, %arg0 : index
+    // expected-error@+1 {{op index must be a dimension or symbol identifier}}
+    affine.dma_wait %2[%3], %c64 : memref<1xi32, 4>
+  }
+  return
+}