From 12d79b1514b8036a6079495a45091c9bfb8db1a6 Mon Sep 17 00:00:00 2001 From: Matt Arsenault Date: Sat, 9 Apr 2022 10:45:31 -0400 Subject: [PATCH] GlobalISel: Add LLT helper to multiply vector sizes --- llvm/include/llvm/Support/LowLevelTypeImpl.h | 12 ++++++++ llvm/include/llvm/Support/TypeSize.h | 5 ++++ llvm/unittests/CodeGen/LowLevelTypeTest.cpp | 45 ++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+) diff --git a/llvm/include/llvm/Support/LowLevelTypeImpl.h b/llvm/include/llvm/Support/LowLevelTypeImpl.h index dd286f5..186a7e5 100644 --- a/llvm/include/llvm/Support/LowLevelTypeImpl.h +++ b/llvm/include/llvm/Support/LowLevelTypeImpl.h @@ -207,6 +207,18 @@ public: return scalar(getScalarSizeInBits() / Factor); } + /// Produce a vector type that is \p Factor times bigger, preserving the + /// element type. For a scalar or pointer, this will produce a new vector with + /// \p Factor elements. + LLT multiplyElements(int Factor) const { + if (isVector()) { + return scalarOrVector(getElementCount().multiplyCoefficientBy(Factor), + getElementType()); + } + + return fixed_vector(Factor, *this); + } + bool isByteSized() const { return getSizeInBits().isKnownMultipleOf(8); } unsigned getScalarSizeInBits() const { diff --git a/llvm/include/llvm/Support/TypeSize.h b/llvm/include/llvm/Support/TypeSize.h index 6bddb60..96f33c1 100644 --- a/llvm/include/llvm/Support/TypeSize.h +++ b/llvm/include/llvm/Support/TypeSize.h @@ -362,6 +362,11 @@ public: LinearPolySize::get(getKnownMinValue() / RHS, isScalable())); } + LeafTy multiplyCoefficientBy(ScalarTy RHS) const { + return static_cast( + LinearPolySize::get(getKnownMinValue() * RHS, isScalable())); + } + LeafTy coefficientNextPowerOf2() const { return static_cast(LinearPolySize::get( static_cast(llvm::NextPowerOf2(getKnownMinValue())), diff --git a/llvm/unittests/CodeGen/LowLevelTypeTest.cpp b/llvm/unittests/CodeGen/LowLevelTypeTest.cpp index bf629c5..ef0a439 100644 --- a/llvm/unittests/CodeGen/LowLevelTypeTest.cpp +++ b/llvm/unittests/CodeGen/LowLevelTypeTest.cpp @@ -319,4 +319,49 @@ TEST(LowLevelTypeTest, Divide) { LLT::fixed_vector(4, LLT::pointer(1, 64)).divide(2)); } +TEST(LowLevelTypeTest, MultiplyElements) { + // Basic scalar->vector cases + EXPECT_EQ(LLT::fixed_vector(2, 16), LLT::scalar(16).multiplyElements(2)); + EXPECT_EQ(LLT::fixed_vector(3, 16), LLT::scalar(16).multiplyElements(3)); + EXPECT_EQ(LLT::fixed_vector(4, 32), LLT::scalar(32).multiplyElements(4)); + EXPECT_EQ(LLT::fixed_vector(4, 7), LLT::scalar(7).multiplyElements(4)); + + // Basic vector to vector cases + EXPECT_EQ(LLT::fixed_vector(4, 32), + LLT::fixed_vector(2, 32).multiplyElements(2)); + EXPECT_EQ(LLT::fixed_vector(9, 32), + LLT::fixed_vector(3, 32).multiplyElements(3)); + + // Pointer to vector of pointers + EXPECT_EQ(LLT::fixed_vector(2, LLT::pointer(0, 32)), + LLT::pointer(0, 32).multiplyElements(2)); + EXPECT_EQ(LLT::fixed_vector(3, LLT::pointer(1, 32)), + LLT::pointer(1, 32).multiplyElements(3)); + EXPECT_EQ(LLT::fixed_vector(4, LLT::pointer(1, 64)), + LLT::pointer(1, 64).multiplyElements(4)); + + // Vector of pointers to vector of pointers + EXPECT_EQ(LLT::fixed_vector(8, LLT::pointer(1, 64)), + LLT::fixed_vector(2, LLT::pointer(1, 64)).multiplyElements(4)); + EXPECT_EQ(LLT::fixed_vector(9, LLT::pointer(1, 32)), + LLT::fixed_vector(3, LLT::pointer(1, 32)).multiplyElements(3)); + + // Scalable vectors + EXPECT_EQ(LLT::scalable_vector(4, 16), + LLT::scalable_vector(2, 16).multiplyElements(2)); + EXPECT_EQ(LLT::scalable_vector(6, 16), + LLT::scalable_vector(2, 16).multiplyElements(3)); + EXPECT_EQ(LLT::scalable_vector(9, 16), + LLT::scalable_vector(3, 16).multiplyElements(3)); + EXPECT_EQ(LLT::scalable_vector(4, 32), + LLT::scalable_vector(2, 32).multiplyElements(2)); + EXPECT_EQ(LLT::scalable_vector(256, 32), + LLT::scalable_vector(8, 32).multiplyElements(32)); + + // Scalable vectors of pointers + EXPECT_EQ(LLT::scalable_vector(4, LLT::pointer(0, 32)), + LLT::scalable_vector(2, LLT::pointer(0, 32)).multiplyElements(2)); + EXPECT_EQ(LLT::scalable_vector(32, LLT::pointer(1, 64)), + LLT::scalable_vector(8, LLT::pointer(1, 64)).multiplyElements(4)); +} } -- 2.7.4