From ccf24225e3f2356ebf0e73bb114a831bf1721222 Mon Sep 17 00:00:00 2001
From: Florian Hahn
Date: Thu, 9 Jan 2020 10:23:34 +0000
Subject: [PATCH] [Matrix] Update shape propagation to iterate until done.

This patch updates the shape propagation to iterate until no new shape
information is discovered.

As the initial seed for the forward propagation, we use the matrix intrinsic
instructions. Both propagateShapeForward and propagateShapeBackward return
new work lists, with the instructions to be used for the next iteration.
When propagating forward, we record all instructions we added new shape
information for. When propagating backward, we record all users of
instructions we added new shape information for.

Reviewers: anemet, Gerolf, reames, hfinkel, andrew.w.kaylor

Reviewed By: anemet

Differential Revision: https://reviews.llvm.org/D70901
---
 .../Transforms/Scalar/LowerMatrixIntrinsics.cpp    | 105 ++++++++++++---------
 .../propagate-multiple-iterations.ll               |  84 +++++++++++++++++
 2 files changed, 146 insertions(+), 43 deletions(-)
 create mode 100644 llvm/test/Transforms/LowerMatrixIntrinsics/propagate-multiple-iterations.ll

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index afe1b4e..0ff6ee8 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -10,9 +10,6 @@
 //
 // TODO:
 //  * Implement multiply & add fusion
-//  * Implement shape propagation
-//  * Implement optimizations to reduce or eliminate shufflevector uses by using
-//    shape information.
 //  * Add remark, summarizing the available matrix optimization opportunities.
 //
 //===----------------------------------------------------------------------===//
@@ -321,32 +318,12 @@ public:
   }
 
   /// Propagate the shape information of instructions to their users.
-  void propagateShapeForward() {
-    // The work list contains instructions for which we can compute the shape,
-    // either based on the information provided by matrix intrinsics or known
-    // shapes of operands.
-    SmallVector<Instruction *, 8> WorkList;
-
-    // Initialize the work list with ops carrying shape information. Initially
-    // only the shape of matrix intrinsics is known.
-    for (BasicBlock &BB : Func)
-      for (Instruction &Inst : BB) {
-        IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
-        if (!II)
-          continue;
-
-        switch (II->getIntrinsicID()) {
-        case Intrinsic::matrix_multiply:
-        case Intrinsic::matrix_transpose:
-        case Intrinsic::matrix_columnwise_load:
-        case Intrinsic::matrix_columnwise_store:
-          WorkList.push_back(&Inst);
-          break;
-        default:
-          break;
-        }
-      }
-
+  /// The work list contains instructions for which we can compute the shape,
+  /// either based on the information provided by matrix intrinsics or known
+  /// shapes of operands.
+  SmallVector<Instruction *, 32>
+  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
+    SmallVector<Instruction *, 32> NewWorkList;
     // Pop an element for which we guaranteed to have at least one of the
     // operand shapes. Add the shape for this and then add users to the work
     // list.
@@ -395,20 +372,29 @@ public:
         }
       }
 
-      if (Propagate)
+      if (Propagate) {
+        NewWorkList.push_back(Inst);
         for (auto *User : Inst->users())
           if (ShapeMap.count(User) == 0)
             WorkList.push_back(cast<Instruction>(User));
+      }
     }
+
+    return NewWorkList;
   }
 
   /// Propagate the shape to operands of instructions with shape information.
-  void propagateShapeBackward() {
-    SmallVector<Value *, 8> WorkList;
-    // Worklist contains instruction for which we already know the shape.
-    for (auto &V : ShapeMap)
-      WorkList.push_back(V.first);
-
+  /// \p Worklist contains the instruction for which we already know the shape.
+  SmallVector<Instruction *, 32>
+  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
+    SmallVector<Instruction *, 32> NewWorkList;
+
+    auto pushInstruction = [](Value *V,
+                              SmallVectorImpl<Instruction *> &WorkList) {
+      Instruction *I = dyn_cast<Instruction>(V);
+      if (I)
+        WorkList.push_back(I);
+    };
     // Pop an element with known shape. Traverse the operands, if their shape
     // derives from the result shape and is unknown, add it and add them to the
     // worklist.
@@ -417,6 +403,7 @@ public:
       Value *V = WorkList.back();
      WorkList.pop_back();
 
+      size_t BeforeProcessingV = WorkList.size();
       if (!isa<Instruction>(V))
         continue;
 
@@ -429,21 +416,21 @@ public:
                        m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                        m_Value(N), m_Value(K)))) {
         if (setShapeInfo(MatrixA, {M, N}))
-          WorkList.push_back(MatrixA);
+          pushInstruction(MatrixA, WorkList);
 
         if (setShapeInfo(MatrixB, {N, K}))
-          WorkList.push_back(MatrixB);
+          pushInstruction(MatrixB, WorkList);
 
       } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                               m_Value(MatrixA), m_Value(M), m_Value(N)))) {
         // Flip dimensions.
         if (setShapeInfo(MatrixA, {M, N}))
-          WorkList.push_back(MatrixA);
+          pushInstruction(MatrixA, WorkList);
       } else if (match(V, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
                               m_Value(MatrixA), m_Value(), m_Value(),
                               m_Value(M), m_Value(N)))) {
         if (setShapeInfo(MatrixA, {M, N})) {
-          WorkList.push_back(MatrixA);
+          pushInstruction(MatrixA, WorkList);
         }
       } else if (isa<LoadInst>(V) ||
                  match(V, m_Intrinsic<Intrinsic::matrix_columnwise_load>())) {
@@ -456,16 +443,48 @@ public:
         ShapeInfo Shape = ShapeMap[V];
         for (Use &U : cast<Instruction>(V)->operands()) {
           if (setShapeInfo(U.get(), Shape))
-            WorkList.push_back(U.get());
+            pushInstruction(U.get(), WorkList);
         }
       }
+      // After we discovered new shape info for new instructions in the
+      // worklist, we use their users as seeds for the next round of forward
+      // propagation.
+      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
+        for (User *U : WorkList[I]->users())
+          if (isa<Instruction>(U) && V != U)
+            NewWorkList.push_back(cast<Instruction>(U));
     }
+    return NewWorkList;
   }
 
   bool Visit() {
     if (EnableShapePropagation) {
-      propagateShapeForward();
-      propagateShapeBackward();
+      SmallVector<Instruction *, 32> WorkList;
+
+      // Initially only the shape of matrix intrinsics is known.
+      // Initialize the work list with ops carrying shape information.
+      for (BasicBlock &BB : Func)
+        for (Instruction &Inst : BB) {
+          IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
+          if (!II)
+            continue;
+
+          switch (II->getIntrinsicID()) {
+          case Intrinsic::matrix_multiply:
+          case Intrinsic::matrix_transpose:
+          case Intrinsic::matrix_columnwise_load:
+          case Intrinsic::matrix_columnwise_store:
+            WorkList.push_back(&Inst);
+            break;
+          default:
+            break;
+          }
+        }
+      // Propagate shapes until nothing changes any longer.
+      while (!WorkList.empty()) {
+        WorkList = propagateShapeForward(WorkList);
+        WorkList = propagateShapeBackward(WorkList);
+      }
     }
 
     ReversePostOrderTraversal<Function *> RPOT(&Func);
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-multiple-iterations.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-multiple-iterations.ll
new file mode 100644
index 0000000..38200b3
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-multiple-iterations.ll
@@ -0,0 +1,84 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -lower-matrix-intrinsics -S < %s | FileCheck %s
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+
+; Make sure we propagate in multiple iterations. First, we back-propagate the
+; shape information from the transpose to %A, in the next iteration we
+; forward-propagate it to %Mul, and then back to %B.
+define <16 x double> @backpropagation_iterations(<16 x double>* %A.Ptr, <16 x double>* %B.Ptr) {
+; CHECK-LABEL: @backpropagation_iterations(
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <16 x double>* [[A_PTR:%.*]] to double*
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast double* [[TMP1]] to <4 x double>*
+; CHECK-NEXT:    [[TMP3:%.*]] = load <4 x double>, <4 x double>* [[TMP2]], align 8
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr double, double* [[TMP1]], i32 4
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast double* [[TMP5]] to <4 x double>*
+; CHECK-NEXT:    [[TMP7:%.*]] = load <4 x double>, <4 x double>* [[TMP6]], align 8
+; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr double, double* [[TMP1]], i32 8
+; CHECK-NEXT:    [[TMP10:%.*]] = bitcast double* [[TMP9]] to <4 x double>*
+; CHECK-NEXT:    [[TMP11:%.*]] = load <4 x double>, <4 x double>* [[TMP10]], align 8
+; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr double, double* [[TMP1]], i32 12
+; CHECK-NEXT:    [[TMP14:%.*]] = bitcast double* [[TMP13]] to <4 x double>*
+; CHECK-NEXT:    [[TMP15:%.*]] = load <4 x double>, <4 x double>* [[TMP14]], align 8
+; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <4 x double> [[TMP3]], i64 0
+; CHECK-NEXT:    [[TMP17:%.*]] = insertelement <4 x double> undef, double [[TMP16]], i64 0
+; CHECK-NEXT:    [[TMP18:%.*]] = extractelement <4 x double> [[TMP7]], i64 0
+; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <4 x double> [[TMP17]], double [[TMP18]], i64 1
+; CHECK-NEXT:    [[TMP20:%.*]] = extractelement <4 x double> [[TMP11]], i64 0
+; CHECK-NEXT:    [[TMP21:%.*]] = insertelement <4 x double> [[TMP19]], double [[TMP20]], i64 2
+; CHECK-NEXT:    [[TMP22:%.*]] = extractelement <4 x double> [[TMP15]], i64 0
+; CHECK-NEXT:    [[TMP23:%.*]] = insertelement <4 x double> [[TMP21]], double [[TMP22]], i64 3
+; CHECK-NEXT:    [[TMP24:%.*]] = extractelement <4 x double> [[TMP3]], i64 1
+; CHECK-NEXT:    [[TMP25:%.*]] = insertelement <4 x double> undef, double [[TMP24]], i64 0
+; CHECK-NEXT:    [[TMP26:%.*]] = extractelement <4 x double> [[TMP7]], i64 1
+; CHECK-NEXT:    [[TMP27:%.*]] = insertelement <4 x double> [[TMP25]], double [[TMP26]], i64 1
+; CHECK-NEXT:    [[TMP28:%.*]] = extractelement <4 x double> [[TMP11]], i64 1
+; CHECK-NEXT:    [[TMP29:%.*]] = insertelement <4 x double> [[TMP27]], double [[TMP28]], i64 2
+; CHECK-NEXT:    [[TMP30:%.*]] = extractelement <4 x double> [[TMP15]], i64 1
+; CHECK-NEXT:    [[TMP31:%.*]] = insertelement <4 x double> [[TMP29]], double [[TMP30]], i64 3
+; CHECK-NEXT:    [[TMP32:%.*]] = extractelement <4 x double> [[TMP3]], i64 2
+; CHECK-NEXT:    [[TMP33:%.*]] = insertelement <4 x double> undef, double [[TMP32]], i64 0
+; CHECK-NEXT:    [[TMP34:%.*]] = extractelement <4 x double> [[TMP7]], i64 2
+; CHECK-NEXT:    [[TMP35:%.*]] = insertelement <4 x double> [[TMP33]], double [[TMP34]], i64 1
+; CHECK-NEXT:    [[TMP36:%.*]] = extractelement <4 x double> [[TMP11]], i64 2
+; CHECK-NEXT:    [[TMP37:%.*]] = insertelement <4 x double> [[TMP35]], double [[TMP36]], i64 2
+; CHECK-NEXT:    [[TMP38:%.*]] = extractelement <4 x double> [[TMP15]], i64 2
+; CHECK-NEXT:    [[TMP39:%.*]] = insertelement <4 x double> [[TMP37]], double [[TMP38]], i64 3
+; CHECK-NEXT:    [[TMP40:%.*]] = extractelement <4 x double> [[TMP3]], i64 3
+; CHECK-NEXT:    [[TMP41:%.*]] = insertelement <4 x double> undef, double [[TMP40]], i64 0
+; CHECK-NEXT:    [[TMP42:%.*]] = extractelement <4 x double> [[TMP7]], i64 3
+; CHECK-NEXT:    [[TMP43:%.*]] = insertelement <4 x double> [[TMP41]], double [[TMP42]], i64 1
+; CHECK-NEXT:    [[TMP44:%.*]] = extractelement <4 x double> [[TMP11]], i64 3
+; CHECK-NEXT:    [[TMP45:%.*]] = insertelement <4 x double> [[TMP43]], double [[TMP44]], i64 2
+; CHECK-NEXT:    [[TMP46:%.*]] = extractelement <4 x double> [[TMP15]], i64 3
+; CHECK-NEXT:    [[TMP47:%.*]] = insertelement <4 x double> [[TMP45]], double [[TMP46]], i64 3
+; CHECK-NEXT:    [[TMP48:%.*]] = bitcast <16 x double>* [[B_PTR:%.*]] to double*
+; CHECK-NEXT:    [[TMP49:%.*]] = bitcast double* [[TMP48]] to <4 x double>*
+; CHECK-NEXT:    [[TMP50:%.*]] = load <4 x double>, <4 x double>* [[TMP49]], align 8
+; CHECK-NEXT:    [[TMP52:%.*]] = getelementptr double, double* [[TMP48]], i32 4
+; CHECK-NEXT:    [[TMP53:%.*]] = bitcast double* [[TMP52]] to <4 x double>*
+; CHECK-NEXT:    [[TMP54:%.*]] = load <4 x double>, <4 x double>* [[TMP53]], align 8
+; CHECK-NEXT:    [[TMP56:%.*]] = getelementptr double, double* [[TMP48]], i32 8
+; CHECK-NEXT:    [[TMP57:%.*]] = bitcast double* [[TMP56]] to <4 x double>*
+; CHECK-NEXT:    [[TMP58:%.*]] = load <4 x double>, <4 x double>* [[TMP57]], align 8
+; CHECK-NEXT:    [[TMP60:%.*]] = getelementptr double, double* [[TMP48]], i32 12
+; CHECK-NEXT:    [[TMP61:%.*]] = bitcast double* [[TMP60]] to <4 x double>*
+; CHECK-NEXT:    [[TMP62:%.*]] = load <4 x double>, <4 x double>* [[TMP61]], align 8
+; CHECK-NEXT:    [[TMP63:%.*]] = fmul <4 x double> [[TMP3]], [[TMP50]]
+; CHECK-NEXT:    [[TMP64:%.*]] = fmul <4 x double> [[TMP7]], [[TMP54]]
+; CHECK-NEXT:    [[TMP65:%.*]] = fmul <4 x double> [[TMP11]], [[TMP58]]
+; CHECK-NEXT:    [[TMP66:%.*]] = fmul <4 x double> [[TMP15]], [[TMP62]]
+; CHECK-NEXT:    [[TMP67:%.*]] = shufflevector <4 x double> [[TMP63]], <4 x double> [[TMP64]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP68:%.*]] = shufflevector <4 x double> [[TMP65]], <4 x double> [[TMP66]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP69:%.*]] = shufflevector <8 x double> [[TMP67]], <8 x double> [[TMP68]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+; CHECK-NEXT:    ret <16 x double> [[TMP69]]
+;
+  %A = load <16 x double>, <16 x double>* %A.Ptr
+  %A.trans = tail call <16 x double> @llvm.matrix.transpose.v16f64(<16 x double> %A, i32 4, i32 4)
+  %B = load <16 x double>, <16 x double>* %B.Ptr
+  %Mul = fmul <16 x double> %A, %B
+  ret <16 x double> %Mul
+}
+
+declare <16 x double> @llvm.matrix.multiply.v16f64.v16f64.v16f64(<16 x double>, <16 x double>, i32 immarg, i32 immarg, i32 immarg)
+declare <16 x double> @llvm.matrix.transpose.v16f64(<16 x double>, i32 immarg, i32 immarg)
-- 
2.7.4
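
As an illustration of the iteration order the patch relies on, the new test
reduces to the IR below; the comments sketch the round in which each value is
expected to receive its shape, following the commit message and the test
comment. This is an annotated copy of the test case, not additional code in
the patch.

  define <16 x double> @backpropagation_iterations(<16 x double>* %A.Ptr, <16 x double>* %B.Ptr) {
    ; The transpose intrinsic below is the only seed for forward propagation.
    %A = load <16 x double>, <16 x double>* %A.Ptr      ; 4x4 shape set by the backward step of round 1
    %A.trans = tail call <16 x double> @llvm.matrix.transpose.v16f64(<16 x double> %A, i32 4, i32 4)
    %B = load <16 x double>, <16 x double>* %B.Ptr      ; 4x4 shape set by the backward step of round 2
    %Mul = fmul <16 x double> %A, %B                    ; 4x4 shape set by the forward step of round 2
    ret <16 x double> %Mul
  }

  declare <16 x double> @llvm.matrix.transpose.v16f64(<16 x double>, i32 immarg, i32 immarg)

A single forward/backward pass only reaches %A; the work lists returned by
propagateShapeForward and propagateShapeBackward are re-fed into the loop in
Visit() until they are empty, which is what lets the shape reach %Mul and %B.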