From 858be166708081a3d5927ea5ea187c6b8c3efdaf Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 21 Jun 2022 10:07:52 +0200 Subject: [PATCH] [mlir][memref] Fix layout map computation in inferRankReducedResultType Differential Revision: https://reviews.llvm.org/D128160 --- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 5 +- mlir/unittests/Dialect/CMakeLists.txt | 2 + mlir/unittests/Dialect/MemRef/CMakeLists.txt | 7 +++ mlir/unittests/Dialect/MemRef/InferShapeTest.cpp | 60 ++++++++++++++++++++++ .../mlir/unittests/BUILD.bazel | 14 +++++ 5 files changed, 85 insertions(+), 3 deletions(-) create mode 100644 mlir/unittests/Dialect/MemRef/CMakeLists.txt create mode 100644 mlir/unittests/Dialect/MemRef/InferShapeTest.cpp diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index ee1218a..014c5be 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2164,9 +2164,8 @@ Type SubViewOp::inferRankReducedResultType(unsigned resultRank, if (!dimsToProject.test(pos)) projectedShape.push_back(shape[pos]); - AffineMap map = inferredType.getLayout().getAffineMap(); - if (!map.isIdentity()) - map = getProjectedMap(map, dimsToProject); + AffineMap map = + getProjectedMap(inferredType.getLayout().getAffineMap(), dimsToProject); inferredType = MemRefType::get(projectedShape, inferredType.getElementType(), map, inferredType.getMemorySpace()); diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt index 8a0f418..ec89b14 100644 --- a/mlir/unittests/Dialect/CMakeLists.txt +++ b/mlir/unittests/Dialect/CMakeLists.txt @@ -7,6 +7,8 @@ target_link_libraries(MLIRDialectTests MLIRDialect) add_subdirectory(Affine) +add_subdirectory(MemRef) + add_subdirectory(Quant) add_subdirectory(SparseTensor) add_subdirectory(SPIRV) diff --git a/mlir/unittests/Dialect/MemRef/CMakeLists.txt b/mlir/unittests/Dialect/MemRef/CMakeLists.txt new file mode 100644 index 0000000..c3f349a --- /dev/null +++ b/mlir/unittests/Dialect/MemRef/CMakeLists.txt @@ -0,0 +1,7 @@ +add_mlir_unittest(MLIRMemRefTests + InferShapeTest.cpp +) +target_link_libraries(MLIRMemRefTests + PRIVATE + MLIRMemRefDialect + ) diff --git a/mlir/unittests/Dialect/MemRef/InferShapeTest.cpp b/mlir/unittests/Dialect/MemRef/InferShapeTest.cpp new file mode 100644 index 0000000..1899755 --- /dev/null +++ b/mlir/unittests/Dialect/MemRef/InferShapeTest.cpp @@ -0,0 +1,60 @@ +//===- InferShapeTest.cpp - unit tests for shape inference ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "gtest/gtest.h" + +using namespace mlir; +using namespace mlir::memref; + +// Source memref has identity layout. +TEST(InferShapeTest, inferRankReducedShapeIdentity) { + MLIRContext ctx; + OpBuilder b(&ctx); + auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType()); + auto reducedType = SubViewOp::inferRankReducedResultType( + /*resultRank=*/1, sourceMemref, {2, 3}, {1, 2}, {1, 1}); + AffineExpr dim0; + bindDims(&ctx, dim0); + auto expectedType = + MemRefType::get({2}, b.getIndexType(), AffineMap::get(1, 0, dim0 + 13)); + EXPECT_EQ(reducedType, expectedType); +} + +// Source memref has non-identity layout. +TEST(InferShapeTest, inferRankReducedShapeNonIdentity) { + MLIRContext ctx; + OpBuilder b(&ctx); + AffineExpr dim0, dim1; + bindDims(&ctx, dim0, dim1); + auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType(), + AffineMap::get(2, 0, 1000 * dim0 + dim1)); + auto reducedType = SubViewOp::inferRankReducedResultType( + /*resultRank=*/1, sourceMemref, {2, 3}, {1, 2}, {1, 1}); + auto expectedType = + MemRefType::get({2}, b.getIndexType(), AffineMap::get(1, 0, dim0 + 2003)); + EXPECT_EQ(reducedType, expectedType); +} + +TEST(InferShapeTest, inferRankReducedShapeToScalar) { + MLIRContext ctx; + OpBuilder b(&ctx); + AffineExpr dim0, dim1; + bindDims(&ctx, dim0, dim1); + auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType(), + AffineMap::get(2, 0, 1000 * dim0 + dim1)); + auto reducedType = SubViewOp::inferRankReducedResultType( + /*resultRank=*/0, sourceMemref, {2, 3}, {1, 1}, {1, 1}); + auto expectedType = + MemRefType::get({}, b.getIndexType(), + AffineMap::get(0, 0, b.getAffineConstantExpr(2003))); + EXPECT_EQ(reducedType, expectedType); +} diff --git a/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel index af784fa..d335e55 100644 --- a/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel @@ -106,6 +106,20 @@ cc_test( ) cc_test( + name = "memref_tests", + size = "small", + srcs = glob([ + "Dialect/MemRef/*.cpp", + "Dialect/MemRef/*.h", + ]), + deps = [ + "//llvm:TestingSupport", + "//llvm:gtest_main", + "//mlir:MemRefDialect", + ], +) + +cc_test( name = "quantops_tests", size = "small", srcs = glob([ -- 2.7.4