Disallow index types in memrefs.
authorAlex Zinenko <zinenko@google.com>
Thu, 3 Oct 2019 07:57:55 +0000 (00:57 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 3 Oct 2019 07:58:29 +0000 (00:58 -0700)
As specified in the MLIR language reference and rationale documents, `memref`
types should not be allowed to have `index` as element types. As observed in
https://groups.google.com/a/tensorflow.org/forum/#!msg/mlir/P49hVWqTMNc/nW89a4i_AgAJ
this restriction was lifted when canonicalization unit tests for affine
operations were introduced, without sufficient motivation to lift the
restriction itself.  The test in question can be trivially rewritten (return
the value from a function instead of storing it to prevent DCE from removing
the producer operation) and the restriction put back in place.

If `memref<...x index>` is relevant for some use cases, the relaxation of the
type system can be implemented separately with appropriate modifications to the
documentation.

PiperOrigin-RevId: 272607043

mlir/lib/IR/StandardTypes.cpp
mlir/test/AffineOps/canonicalize.mlir
mlir/test/IR/invalid.mlir

index 0a77bfa..ab80d83 100644 (file)
@@ -341,6 +341,13 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
                                Optional<Location> location) {
   auto *context = elementType.getContext();
 
+  // Check that memref is formed from allowed types.
+  if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>()) {
+    if (location)
+      emitError(*location, "invalid memref element type");
+    return nullptr;
+  }
+
   for (int64_t s : shape) {
     // Negative sizes are not allowed except for `-1` that means dynamic size.
     if (s < -1) {
index 6d84913..b8c00d9 100644 (file)
@@ -258,15 +258,12 @@ func @trivial_maps() {
 }
 
 // CHECK-LABEL: func @partial_fold_map
-func @partial_fold_map(%arg0: memref<index>, %arg1: index, %arg2: index) {
+func @partial_fold_map(%arg1: index, %arg2: index) -> index {
   // TODO: Constant fold one index into affine.apply
   %c42 = constant 42 : index
   %2 = affine.apply (d0, d1) -> (d0 - d1) (%arg1, %c42)
-  store %2, %arg0[] : memref<index>
   // CHECK: [[X:%[0-9]+]] = affine.apply [[MAP15]]()[%{{.*}}]
-  // CHECK-NEXT: store [[X]], %{{.*}}
-
-  return
+  return %2 : index
 }
 
 // CHECK-LABEL: func @symbolic_composition_a(%{{.*}}: index, %{{.*}}: index) -> index {
index 06e0b96..f62f433 100644 (file)
@@ -21,8 +21,7 @@ func @indexvector(vector<4 x index>) -> () // expected-error {{vector elements m
 
 // -----
 
-// Everything is valid in a memref.
-func @indexmemref(memref<? x index>) -> ()
+func @indexmemref(memref<? x index>) -> () // expected-error {{invalid memref element type}}
 
 // -----