[mlir] simplify type constraints in AVX512 dialect
authorAlex Zinenko <zinenko@google.com>
Wed, 10 Mar 2021 10:45:53 +0000 (11:45 +0100)
committerAlex Zinenko <zinenko@google.com>
Wed, 10 Mar 2021 12:07:25 +0000 (13:07 +0100)
VectorOfLengthAndType accepts a cartesian product of given lengths and types
rather than types produced by co-indexed values in the corresponding lists.
Update the definitions accordingly. The type validity is already enforced by
op traits.

Reviewed By: nicolasvasilache, springerm

Differential Revision: https://reviews.llvm.org/D98327

mlir/include/mlir/Dialect/AVX512/AVX512.td

index c2487a0..391ce74 100644 (file)
@@ -54,14 +54,14 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [NoSideEffect,
   remaining elements from `src`.
   }];
   let verifier = [{ return ::verify(*this); }];
-  let arguments = (ins VectorOfLengthAndType<[16, 16, 8, 8],
-                                             [I1, I1, I1, I1]>:$k,
-                   VectorOfLengthAndType<[16, 16, 8, 8],
+  let arguments = (ins VectorOfLengthAndType<[16, 8],
+                                             [I1]>:$k,
+                   VectorOfLengthAndType<[16, 8],
                                          [F32, I32, F64, I64]>:$a,
-                   Optional<VectorOfLengthAndType<[16, 16, 8, 8],
+                   Optional<VectorOfLengthAndType<[16, 8],
                                                   [F32, I32, F64, I64]>>:$src,
                    OptionalAttr<ElementsAttr>:$constant_src);
-  let results = (outs VectorOfLengthAndType<[16, 16, 8, 8],
+  let results = (outs VectorOfLengthAndType<[16, 8],
                                             [F32, I32, F64, I64]>:$dst);
   let assemblyFormat = "$k `,` $a (`,` $src^)? attr-dict"
                        " `:` type($dst) (`,` type($src)^)?";
@@ -162,8 +162,8 @@ def Vp2IntersectOp : AVX512_Op<"vp2intersect", [NoSideEffect,
   let arguments = (ins VectorOfLengthAndType<[16, 8], [I32, I64]>:$a,
                    VectorOfLengthAndType<[16, 8], [I32, I64]>:$b
                    );
-  let results = (outs VectorOfLengthAndType<[16, 8], [I1, I1]>:$k1,
-                 VectorOfLengthAndType<[16, 8], [I1, I1]>:$k2
+  let results = (outs VectorOfLengthAndType<[16, 8], [I1]>:$k1,
+                 VectorOfLengthAndType<[16, 8], [I1]>:$k2
                  );
   let assemblyFormat =
     "$a `,` $b attr-dict `:` type($a)";