From b9ff67099ad6da931976e66f1510c5af2558a86e Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Thu, 18 Feb 2021 12:56:42 -0800 Subject: [PATCH] [MLIR] Make structured op tests permutation invariant Extracts the relevant dimensions from the map under test to build up the maps to test against in a permutation-invariant way. This also includes a fix to the indexing maps used by isColumnMajorMatmul. The maps as currently written do not describe a column-major matmul. The linalg named op column_major_matmul has the correct maps (and notably fails the current test). If `C = matmul(A, B)` we want an operation that given A in column major format and B in column major format produces C in column major format. Given that for a matrix, faux column major is just transpose. `column_major_matmul(transpose(A), transpose(B)) = transpose(C)`. If `A` is `NxK` and `B` is `KxM`, then `C` is `NxM`, so `transpose(A)` is `KxN`, `transpose(B)` is `MxK` and `transpose(C)` is `MxN`, not `NxM` as these maps currently have. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D96984 --- .../mlir/Dialect/Utils/StructuredOpsUtils.h | 38 ++- mlir/lib/Dialect/CMakeLists.txt | 1 + mlir/lib/Dialect/Utils/CMakeLists.txt | 6 + mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp | 92 ++++++++ mlir/lib/Dialect/Vector/CMakeLists.txt | 1 + mlir/unittests/Dialect/Utils/CMakeLists.txt | 6 + .../Dialect/Utils/StructuredOpsUtilsTest.cpp | 256 +++++++++++++++++++++ 7 files changed, 379 insertions(+), 21 deletions(-) create mode 100644 mlir/lib/Dialect/Utils/CMakeLists.txt create mode 100644 mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp create mode 100644 mlir/unittests/Dialect/Utils/CMakeLists.txt create mode 100644 mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h index b903c09..c7d2476 100644 --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -24,27 +24,23 @@ namespace mlir { -inline bool isRowMajorMatmul(ArrayAttr indexingMaps) { - auto context = indexingMaps.getContext(); - AffineExpr m, n, k; - bindDims(context, m, n, k); - auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context)); - auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context)); - auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context)); - auto maps = ArrayAttr::get(context, {mapA, mapB, mapC}); - return indexingMaps == maps; -} - -inline bool isColumnMajorMatmul(ArrayAttr indexingMaps) { - auto context = indexingMaps.getContext(); - AffineExpr m, n, k; - bindDims(context, m, n, k); - auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context)); - auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context)); - auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, context)); - auto maps = ArrayAttr::get(context, {mapA, mapB, mapC}); - return indexingMaps == maps; -} +/// Tests whether the given maps describe a row major matmul. The test is +/// permutation-invariant. Note that this only checks the affine maps from an +/// operation, so does not perform any checks on the math being performed within +/// the reduction. +bool isRowMajorMatmul(ArrayAttr indexingMaps); + +/// Tests whether the given maps describe a column major matmul. The test is +/// permutation-invariant. Note that this only checks the affine maps from an +/// operation, so does not perform any checks on the math being performed within +/// the reduction. +bool isColumnMajorMatmul(ArrayAttr indexingMaps); + +/// Tests whether the given maps describe a row major batch matmul. The test is +/// permutation-invariant. Note that this only checks the affine maps from an +/// operation, so does not perform any checks on the math being performed within +/// the reduction. +bool isRowMajorBatchMatmul(ArrayAttr indexingMaps); /// Attribute name for the AffineArrayAttr which encodes the relationship /// between a structured op iterators' and its operands. diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt index ddeaa2e..251ea05 100644 --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -20,6 +20,7 @@ add_subdirectory(SPIRV) add_subdirectory(StandardOps) add_subdirectory(Tensor) add_subdirectory(Tosa) +add_subdirectory(Utils) add_subdirectory(Vector) set(LLVM_OPTIONAL_SOURCES diff --git a/mlir/lib/Dialect/Utils/CMakeLists.txt b/mlir/lib/Dialect/Utils/CMakeLists.txt new file mode 100644 index 0000000..9213e6f --- /dev/null +++ b/mlir/lib/Dialect/Utils/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_library(MLIRDialectUtils + StructuredOpsUtils.cpp + + DEPENDS + MLIRIR +) diff --git a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp new file mode 100644 index 0000000..27b5701 --- /dev/null +++ b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp @@ -0,0 +1,92 @@ +//===- StructuredOpsUtils.cpp - Utilities used by structured ops ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinAttributes.h" + +using namespace mlir; + +bool mlir::isRowMajorMatmul(ArrayAttr indexingMaps) { + if (indexingMaps.size() != 3) + return false; + + auto map0 = indexingMaps[0].cast().getValue(); + auto map1 = indexingMaps[1].cast().getValue(); + auto map2 = indexingMaps[2].cast().getValue(); + + if (map0.getNumResults() != 2 || map1.getNumResults() != 2 || + map2.getNumResults() != 2 || map0.getNumInputs() != 3 || + map1.getNumInputs() != 3 || map2.getNumInputs() != 3) { + return false; + } + + // Extract dimensions for MxK * KxN -> MxN + AffineExpr m = map2.getResult(0); + AffineExpr n = map2.getResult(1); + AffineExpr k = map0.getResult(1); + auto *context = indexingMaps.getContext(); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context)); + auto maps = ArrayAttr::get(context, {mapA, mapB, mapC}); + return indexingMaps == maps; +} + +bool mlir::isColumnMajorMatmul(ArrayAttr indexingMaps) { + if (indexingMaps.size() != 3) + return false; + + auto map0 = indexingMaps[0].cast().getValue(); + auto map1 = indexingMaps[1].cast().getValue(); + auto map2 = indexingMaps[2].cast().getValue(); + + if (map0.getNumResults() != 2 || map1.getNumResults() != 2 || + map2.getNumResults() != 2 || map0.getNumInputs() != 3 || + map1.getNumInputs() != 3 || map2.getNumInputs() != 3) { + return false; + } + + // Extract dimensions for KxM * NxK -> NxM + AffineExpr n = map2.getResult(0); + AffineExpr m = map2.getResult(1); + AffineExpr k = map0.getResult(0); + auto *context = indexingMaps.getContext(); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, m}, context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {n, k}, context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, context)); + auto maps = ArrayAttr::get(context, {mapA, mapB, mapC}); + return indexingMaps == maps; +} + +bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) { + if (indexingMaps.size() != 3) + return false; + + auto map0 = indexingMaps[0].cast().getValue(); + auto map1 = indexingMaps[1].cast().getValue(); + auto map2 = indexingMaps[2].cast().getValue(); + + if (map0.getNumResults() != 3 || map1.getNumResults() != 3 || + map2.getNumResults() != 3 || map0.getNumInputs() != 4 || + map1.getNumInputs() != 4 || map2.getNumInputs() != 4) { + return false; + } + + // Extract dimensions for BxMxK * BxKxN -> BxMxN + AffineExpr b = map2.getResult(0); + AffineExpr m = map2.getResult(1); + AffineExpr n = map2.getResult(2); + AffineExpr k = map0.getResult(2); + auto *context = indexingMaps.getContext(); + auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {b, m, k}, context)); + auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {b, k, n}, context)); + auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {b, m, n}, context)); + auto maps = ArrayAttr::get(context, {mapA, mapB, mapC}); + return indexingMaps == maps; +} diff --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt index 957ea66..f3e9af2 100644 --- a/mlir/lib/Dialect/Vector/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/CMakeLists.txt @@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRVector LINK_LIBS PUBLIC MLIRAffineEDSC MLIREDSC + MLIRDialectUtils MLIRIR MLIRStandard MLIRAffine diff --git a/mlir/unittests/Dialect/Utils/CMakeLists.txt b/mlir/unittests/Dialect/Utils/CMakeLists.txt new file mode 100644 index 0000000..d75b693 --- /dev/null +++ b/mlir/unittests/Dialect/Utils/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_unittest(MLIRDialectUtilsTests + StructuredOpsUtilsTest.cpp +) +target_link_libraries(MLIRDialectUtilsTests + PRIVATE + MLIRDialectUtils) diff --git a/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp new file mode 100644 index 0000000..bb95402 --- /dev/null +++ b/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp @@ -0,0 +1,256 @@ +//===- StructuredOpsUtilsTest.cpp - StructuredOpsUtils unit tests ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using namespace mlir; +using testing::Not; +using testing::Truly; + +namespace { + +TEST(isRowMajorMatmul, Simple) { + MLIRContext context; + + AffineExpr m, n, k; + bindDims(&context, m, n, k); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Truly(isRowMajorMatmul)); +} + +TEST(isRowMajorMatmul, BindingShifted) { + MLIRContext context; + + AffineExpr m, n, k; + bindDims(&context, k, m, n); // bind in different order + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Truly(isRowMajorMatmul)); +} + +TEST(isRowMajorMatmul, BindingSwapped) { + MLIRContext context; + + AffineExpr m, n, k; + bindDims(&context, k, n, m); // bind in different order + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Truly(isRowMajorMatmul)); +} + +TEST(isRowMajorMatmul, ColumnMajor) { + MLIRContext context; + + AffineExpr m, n, k; + bindDims(&context, m, n, k); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul))); +} + +TEST(isRowMajorMatmul, FirstInputSwapped) { + MLIRContext context; + + AffineExpr m, n, k; + bindDims(&context, m, n, k); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, m}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul))); +} + +TEST(isRowMajorMatmul, TooFewMaps) { + MLIRContext context; + + AffineExpr m, n, k; + bindDims(&context, m, n, k); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB}); + + EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul))); +} + +TEST(isRowMajorMatmul, TooManyMaps) { + MLIRContext context; + + AffineExpr m, n, k; + bindDims(&context, m, n, k); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); + auto mapD = AffineMapAttr::get(AffineMap::get(3, 0, {k, m}, &context)); + + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC, mapD}); + + EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul))); +} + +TEST(isRowMajorMatmul, TooFewDims) { + MLIRContext context; + + AffineExpr m, n, k; + bindDims(&context, m, n, k); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul))); +} + +TEST(isRowMajorMatmul, TooFewOutputs) { + MLIRContext context; + + AffineExpr m, n, k; + bindDims(&context, m, n, k); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul))); +} + +TEST(isColumnMajorMatmul, Simple) { + MLIRContext context; + + AffineExpr m, n, k; + bindDims(&context, m, n, k); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Truly(isColumnMajorMatmul)); +} + +TEST(isColumnMajorMatmul, BindingShifted) { + MLIRContext context; + + AffineExpr m, n, k; + bindDims(&context, k, m, n); // bind in different order + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Truly(isColumnMajorMatmul)); +} + +TEST(isColumnMajorMatmul, BindingSwapped) { + MLIRContext context; + + AffineExpr m, n, k; + bindDims(&context, k, n, m); // bind in different order + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Truly(isColumnMajorMatmul)); +} + +TEST(isColumnMajorMatmul, RowMajor) { + MLIRContext context; + + AffineExpr m, n, k; + bindDims(&context, m, n, k); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Not(Truly(isColumnMajorMatmul))); +} + +TEST(isColumnMajorMatmul, FirstInputSwapped) { + MLIRContext context; + + AffineExpr m, n, k; + bindDims(&context, m, n, k); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {n, k}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Not(Truly(isColumnMajorMatmul))); +} + +TEST(isRowMajorBatchMatmul, Simple) { + MLIRContext context; + + AffineExpr batch, m, n, k; + bindDims(&context, batch, m, n, k); + auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, k}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Truly(isRowMajorBatchMatmul)); +} + +TEST(isRowMajorBatchMatmul, BindingShifted) { + MLIRContext context; + + AffineExpr batch, m, n, k; + bindDims(&context, k, batch, m, n); // bind in different order + auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, k}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Truly(isRowMajorBatchMatmul)); +} + +TEST(isRowMajorBatchMatmul, BindingSwapped) { + MLIRContext context; + + AffineExpr batch, m, n, k; + bindDims(&context, batch, k, n, m); // bind in different order + auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, k}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Truly(isRowMajorBatchMatmul)); +} + +TEST(isRowMajorBatchMatmul, FirstInputSwapped) { + MLIRContext context; + + AffineExpr batch, m, n, k; + bindDims(&context, batch, m, n, k); + auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, m}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Not(Truly(isRowMajorBatchMatmul))); +} + +} // namespace -- 2.7.4