From: Florian Hahn Date: Sun, 7 Jun 2020 10:11:27 +0000 (+0100) Subject: [Matrix] Implement * binary operator for MatrixType. X-Git-Tag: llvmorg-12-init~3851 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=4affc444b499ba2a12fee7f9ef0bb6ef33789c12;p=platform%2Fupstream%2Fllvm.git [Matrix] Implement * binary operator for MatrixType. This patch implements the * binary operator for values of MatrixType. It adds support for matrix * matrix, scalar * matrix and matrix * scalar. For the matrix, matrix case, the number of columns of the first operand must match the number of rows of the second. For the scalar,matrix variants, the element type of the matrix must match the scalar type. Reviewers: rjmccall, anemet, Bigcheese, rsmith, martong Reviewed By: rjmccall Differential Revision: https://reviews.llvm.org/D76794 --- diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h index cef25fc..8f914ec 100644 --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -11218,6 +11218,8 @@ public: QualType CheckMatrixElementwiseOperands(ExprResult &LHS, ExprResult &RHS, SourceLocation Loc, bool IsCompAssign); + QualType CheckMatrixMultiplyOperands(ExprResult &LHS, ExprResult &RHS, + SourceLocation Loc, bool IsCompAssign); bool areLaxCompatibleVectorTypes(QualType srcType, QualType destType); bool isLaxVectorConversion(QualType srcType, QualType destType); diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp index b169462..cabfe0c 100644 --- a/clang/lib/CodeGen/CGExprScalar.cpp +++ b/clang/lib/CodeGen/CGExprScalar.cpp @@ -741,6 +741,22 @@ public: } } + if (Ops.Ty->isConstantMatrixType()) { + llvm::MatrixBuilder MB(Builder); + // We need to check the types of the operands of the operator to get the + // correct matrix dimensions. + auto *BO = cast(Ops.E); + auto *LHSMatTy = dyn_cast( + BO->getLHS()->getType().getCanonicalType()); + auto *RHSMatTy = dyn_cast( + BO->getRHS()->getType().getCanonicalType()); + if (LHSMatTy && RHSMatTy) + return MB.CreateMatrixMultiply(Ops.LHS, Ops.RHS, LHSMatTy->getNumRows(), + LHSMatTy->getNumColumns(), + RHSMatTy->getNumColumns()); + return MB.CreateScalarMultiply(Ops.LHS, Ops.RHS); + } + if (Ops.Ty->isUnsignedIntegerType() && CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) && !CanElideOverflowCheck(CGF.getContext(), Ops)) diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp index d7a39cf..45c1ace 100644 --- a/clang/lib/Sema/SemaExpr.cpp +++ b/clang/lib/Sema/SemaExpr.cpp @@ -10058,6 +10058,9 @@ QualType Sema::CheckMultiplyDivideOperands(ExprResult &LHS, ExprResult &RHS, return CheckVectorOperands(LHS, RHS, Loc, IsCompAssign, /*AllowBothBool*/getLangOpts().AltiVec, /*AllowBoolConversions*/false); + if (!IsDiv && (LHS.get()->getType()->isConstantMatrixType() || + RHS.get()->getType()->isConstantMatrixType())) + return CheckMatrixMultiplyOperands(LHS, RHS, Loc, IsCompAssign); QualType compType = UsualArithmeticConversions( LHS, RHS, Loc, IsCompAssign ? ACK_CompAssign : ACK_Arithmetic); @@ -12120,6 +12123,37 @@ QualType Sema::CheckMatrixElementwiseOperands(ExprResult &LHS, ExprResult &RHS, return InvalidOperands(Loc, LHS, RHS); } +QualType Sema::CheckMatrixMultiplyOperands(ExprResult &LHS, ExprResult &RHS, + SourceLocation Loc, + bool IsCompAssign) { + if (!IsCompAssign) { + LHS = DefaultFunctionArrayLvalueConversion(LHS.get()); + if (LHS.isInvalid()) + return QualType(); + } + RHS = DefaultFunctionArrayLvalueConversion(RHS.get()); + if (RHS.isInvalid()) + return QualType(); + + auto *LHSMatType = LHS.get()->getType()->getAs(); + auto *RHSMatType = RHS.get()->getType()->getAs(); + assert((LHSMatType || RHSMatType) && "At least one operand must be a matrix"); + + if (LHSMatType && RHSMatType) { + if (LHSMatType->getNumColumns() != RHSMatType->getNumRows()) + return InvalidOperands(Loc, LHS, RHS); + + if (!Context.hasSameType(LHSMatType->getElementType(), + RHSMatType->getElementType())) + return InvalidOperands(Loc, LHS, RHS); + + return Context.getConstantMatrixType(LHSMatType->getElementType(), + LHSMatType->getNumRows(), + RHSMatType->getNumColumns()); + } + return CheckMatrixElementwiseOperands(LHS, RHS, Loc, IsCompAssign); +} + inline QualType Sema::CheckBitwiseOperands(ExprResult &LHS, ExprResult &RHS, SourceLocation Loc, BinaryOperatorKind Opc) { diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp index d7fc966..5e5f539 100644 --- a/clang/lib/Sema/SemaOverload.cpp +++ b/clang/lib/Sema/SemaOverload.cpp @@ -9167,8 +9167,10 @@ void Sema::AddBuiltinOperatorCandidates(OverloadedOperatorKind Op, case OO_Star: // '*' is either unary or binary if (Args.size() == 1) OpBuilder.addUnaryStarPointerOverloads(); - else + else { OpBuilder.addGenericBinaryArithmeticOverloads(); + OpBuilder.addMatrixBinaryArithmeticOverloads(); + } break; case OO_Slash: diff --git a/clang/test/CodeGen/matrix-type-operators.c b/clang/test/CodeGen/matrix-type-operators.c index c649922..f44ecd7 100644 --- a/clang/test/CodeGen/matrix-type-operators.c +++ b/clang/test/CodeGen/matrix-type-operators.c @@ -173,6 +173,134 @@ void add_matrix_scalar_long_long_int_unsigned_long_long(ullx4x2_t b, unsigned lo b = vulli + b; } +// Tests for matrix multiplication. + +void multiply_matrix_matrix_double(dx5x5_t b, dx5x5_t c) { + // CHECK-LABEL: @multiply_matrix_matrix_double( + // CHECK: [[B:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8 + // CHECK-NEXT: [[C:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8 + // CHECK-NEXT: [[RES:%.*]] = call <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double> [[B]], <25 x double> [[C]], i32 5, i32 5, i32 5) + // CHECK-NEXT: [[A_ADDR:%.*]] = bitcast [25 x double]* %a to <25 x double>* + // CHECK-NEXT: store <25 x double> [[RES]], <25 x double>* [[A_ADDR]], align 8 + // CHECK-NEXT: ret void + // + + dx5x5_t a; + a = b * c; +} + +typedef int ix3x9_t __attribute__((matrix_type(3, 9))); +typedef int ix9x9_t __attribute__((matrix_type(9, 9))); +// CHECK-LABEL: @multiply_matrix_matrix_int( +// CHECK: [[B:%.*]] = load <27 x i32>, <27 x i32>* {{.*}}, align 4 +// CHECK-NEXT: [[C:%.*]] = load <27 x i32>, <27 x i32>* {{.*}}, align 4 +// CHECK-NEXT: [[RES:%.*]] = call <81 x i32> @llvm.matrix.multiply.v81i32.v27i32.v27i32(<27 x i32> [[B]], <27 x i32> [[C]], i32 9, i32 3, i32 9) +// CHECK-NEXT: [[A_ADDR:%.*]] = bitcast [81 x i32]* %a to <81 x i32>* +// CHECK-NEXT: store <81 x i32> [[RES]], <81 x i32>* [[A_ADDR]], align 4 +// CHECK-NEXT: ret void +// +void multiply_matrix_matrix_int(ix9x3_t b, ix3x9_t c) { + ix9x9_t a; + a = b * c; +} + +// CHECK-LABEL: @multiply_double_matrix_scalar_float( +// CHECK: [[A:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8 +// CHECK-NEXT: [[S:%.*]] = load float, float* %s.addr, align 4 +// CHECK-NEXT: [[S_EXT:%.*]] = fpext float [[S]] to double +// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <25 x double> undef, double [[S_EXT]], i32 0 +// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <25 x double> [[VECINSERT]], <25 x double> undef, <25 x i32> zeroinitializer +// CHECK-NEXT: [[RES:%.*]] = fmul <25 x double> [[A]], [[VECSPLAT]] +// CHECK-NEXT: store <25 x double> [[RES]], <25 x double>* {{.*}}, align 8 +// CHECK-NEXT: ret void +// +void multiply_double_matrix_scalar_float(dx5x5_t a, float s) { + a = a * s; +} + +// CHECK-LABEL: @multiply_double_matrix_scalar_double( +// CHECK: [[A:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8 +// CHECK-NEXT: [[S:%.*]] = load double, double* %s.addr, align 8 +// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <25 x double> undef, double [[S]], i32 0 +// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <25 x double> [[VECINSERT]], <25 x double> undef, <25 x i32> zeroinitializer +// CHECK-NEXT: [[RES:%.*]] = fmul <25 x double> [[A]], [[VECSPLAT]] +// CHECK-NEXT: store <25 x double> [[RES]], <25 x double>* {{.*}}, align 8 +// CHECK-NEXT: ret void +// +void multiply_double_matrix_scalar_double(dx5x5_t a, double s) { + a = a * s; +} + +// CHECK-LABEL: @multiply_float_matrix_scalar_double( +// CHECK: [[S:%.*]] = load double, double* %s.addr, align 8 +// CHECK-NEXT: [[S_TRUNC:%.*]] = fptrunc double [[S]] to float +// CHECK-NEXT: [[MAT:%.*]] = load <6 x float>, <6 x float>* [[MAT_ADDR:%.*]], align 4 +// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <6 x float> undef, float [[S_TRUNC]], i32 0 +// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <6 x float> [[VECINSERT]], <6 x float> undef, <6 x i32> zeroinitializer +// CHECK-NEXT: [[RES:%.*]] = fmul <6 x float> [[VECSPLAT]], [[MAT]] +// CHECK-NEXT: store <6 x float> [[RES]], <6 x float>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: ret void +// +void multiply_float_matrix_scalar_double(fx2x3_t b, double s) { + b = s * b; +} + +// CHECK-LABEL: @multiply_int_matrix_scalar_short( +// CHECK: [[S:%.*]] = load i16, i16* %s.addr, align 2 +// CHECK-NEXT: [[S_EXT:%.*]] = sext i16 [[S]] to i32 +// CHECK-NEXT: [[MAT:%.*]] = load <27 x i32>, <27 x i32>* [[MAT_ADDR:%.*]], align 4 +// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <27 x i32> undef, i32 [[S_EXT]], i32 0 +// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <27 x i32> [[VECINSERT]], <27 x i32> undef, <27 x i32> zeroinitializer +// CHECK-NEXT: [[RES:%.*]] = mul <27 x i32> [[VECSPLAT]], [[MAT]] +// CHECK-NEXT: store <27 x i32> [[RES]], <27 x i32>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: ret void +// +void multiply_int_matrix_scalar_short(ix9x3_t b, short s) { + b = s * b; +} + +// CHECK-LABEL: @multiply_int_matrix_scalar_ull( +// CHECK: [[MAT:%.*]] = load <27 x i32>, <27 x i32>* [[MAT_ADDR:%.*]], align 4 +// CHECK-NEXT: [[S:%.*]] = load i64, i64* %s.addr, align 8 +// CHECK-NEXT: [[S_TRUNC:%.*]] = trunc i64 [[S]] to i32 +// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <27 x i32> undef, i32 [[S_TRUNC]], i32 0 +// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <27 x i32> [[VECINSERT]], <27 x i32> undef, <27 x i32> zeroinitializer +// CHECK-NEXT: [[RES:%.*]] = mul <27 x i32> [[MAT]], [[VECSPLAT]] +// CHECK-NEXT: store <27 x i32> [[RES]], <27 x i32>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: ret void +// +void multiply_int_matrix_scalar_ull(ix9x3_t b, unsigned long long s) { + b = b * s; +} + +// CHECK-LABEL: @multiply_float_matrix_constant( +// CHECK-NEXT: entry: +// CHECK-NEXT: [[A_ADDR:%.*]] = alloca [6 x float], align 4 +// CHECK-NEXT: [[MAT_ADDR:%.*]] = bitcast [6 x float]* [[A_ADDR]] to <6 x float>* +// CHECK-NEXT: store <6 x float> [[A:%.*]], <6 x float>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: [[MAT:%.*]] = load <6 x float>, <6 x float>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: [[RES:%.*]] = fmul <6 x float> [[MAT]], +// CHECK-NEXT: store <6 x float> [[RES]], <6 x float>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: ret void +// +void multiply_float_matrix_constant(fx2x3_t a) { + a = a * 2.5; +} + +// CHECK-LABEL: @multiply_int_matrix_constant( +// CHECK-NEXT: entry: +// CHECK-NEXT: [[A_ADDR:%.*]] = alloca [27 x i32], align 4 +// CHECK-NEXT: [[MAT_ADDR:%.*]] = bitcast [27 x i32]* [[A_ADDR]] to <27 x i32>* +// CHECK-NEXT: store <27 x i32> [[A:%.*]], <27 x i32>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: [[MAT:%.*]] = load <27 x i32>, <27 x i32>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: [[RES:%.*]] = mul <27 x i32> , [[MAT]] +// CHECK-NEXT: store <27 x i32> [[RES]], <27 x i32>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: ret void +// +void multiply_int_matrix_constant(ix9x3_t a) { + a = 5 * a; +} + // Tests for the matrix type operators. typedef double dx5x5_t __attribute__((matrix_type(5, 5))); diff --git a/clang/test/CodeGenCXX/matrix-type-operators.cpp b/clang/test/CodeGenCXX/matrix-type-operators.cpp index a9eec96..5ca55d2 100644 --- a/clang/test/CodeGenCXX/matrix-type-operators.cpp +++ b/clang/test/CodeGenCXX/matrix-type-operators.cpp @@ -157,6 +157,45 @@ void test_IntWrapper_Sub(MyMatrix &m) { m.value = w3 - m.value; } +template +typename MyMatrix::matrix_t multiply(MyMatrix &A, MyMatrix &B) { + return A.value * B.value; +} + +MyMatrix test_multiply_template(MyMatrix Mat1, + MyMatrix Mat2) { + // CHECK-LABEL: define void @_Z22test_multiply_template8MyMatrixIfLj2ELj5EES_IfLj5ELj2EE( + // CHECK-NEXT: entry: + // CHECK-NEXT: [[RES:%.*]] = call <4 x float> @_Z8multiplyIfLj2ELj5ELj2EEN8MyMatrixIT_XT0_EXT2_EE8matrix_tERS0_IS1_XT0_EXT1_EERS0_IS1_XT1_EXT2_EE(%struct.MyMatrix* nonnull align 4 dereferenceable(40) %Mat1, %struct.MyMatrix.2* nonnull align 4 dereferenceable(40) %Mat2) + // CHECK-NEXT: %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %agg.result, i32 0, i32 0 + // CHECK-NEXT: [[VALUE_ADDR:%.*]] = bitcast [4 x float]* %value to <4 x float>* + // CHECK-NEXT: store <4 x float> [[RES]], <4 x float>* [[VALUE_ADDR]], align 4 + // CHECK-NEXT: ret void + // + // CHECK-LABEL: define linkonce_odr <4 x float> @_Z8multiplyIfLj2ELj5ELj2EEN8MyMatrixIT_XT0_EXT2_EE8matrix_tERS0_IS1_XT0_EXT1_EERS0_IS1_XT1_EXT2_EE( + // CHECK: [[MAT1:%.*]] = load <10 x float>, <10 x float>* {{.*}}, align 4 + // CHECK: [[MAT2:%.*]] = load <10 x float>, <10 x float>* {{.*}}, align 4 + // CHECK-NEXT: [[RES:%.*]] = call <4 x float> @llvm.matrix.multiply.v4f32.v10f32.v10f32(<10 x float> [[MAT1]], <10 x float> [[MAT2]], i32 2, i32 5, i32 2) + // CHECK-NEXT: ret <4 x float> [[RES]] + + MyMatrix Res; + Res.value = multiply(Mat1, Mat2); + return Res; +} + +void test_IntWrapper_Multiply(MyMatrix &m, IntWrapper &w3) { + // CHECK-LABEL: define void @_Z24test_IntWrapper_MultiplyR8MyMatrixIdLj10ELj9EER10IntWrapper( + // CHECK: [[SCALAR:%.*]] = call i32 @_ZN10IntWrappercviEv(%struct.IntWrapper* {{.*}}) + // CHECK-NEXT: [[SCALAR_FP:%.*]] = sitofp i32 %call to double + // CHECK: [[MATRIX:%.*]] = load <90 x double>, <90 x double>* {{.*}}, align 8 + // CHECK-NEXT: [[SCALAR_EMBED:%.*]] = insertelement <90 x double> undef, double [[SCALAR_FP]], i32 0 + // CHECK-NEXT: [[SCALAR_EMBED1:%.*]] = shufflevector <90 x double> [[SCALAR_EMBED]], <90 x double> undef, <90 x i32> zeroinitializer + // CHECK-NEXT: [[RES:%.*]] = fmul <90 x double> [[SCALAR_EMBED1]], [[MATRIX]] + // CHECK: store <90 x double> [[RES]], <90 x double>* {{.*}}, align 8 + // CHECK: ret void + m.value = w3 * m.value; +} + template void insert(MyMatrix &Mat, EltTy e, unsigned i, unsigned j) { Mat.value[i][j] = e; @@ -164,11 +203,11 @@ void insert(MyMatrix &Mat, EltTy e, unsigned i, unsigned j void test_insert_template1(MyMatrix &Mat, unsigned e, unsigned i, unsigned j) { // CHECK-LABEL: @_Z21test_insert_template1R8MyMatrixIjLj2ELj2EEjjj( - // CHECK: [[MAT_ADDR:%.*]] = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %Mat.addr, align 8 + // CHECK: [[MAT_ADDR:%.*]] = load %struct.MyMatrix.3*, %struct.MyMatrix.3** %Mat.addr, align 8 // CHECK-NEXT: [[E:%.*]] = load i32, i32* %e.addr, align 4 // CHECK-NEXT: [[I:%.*]] = load i32, i32* %i.addr, align 4 // CHECK-NEXT: [[J:%.*]] = load i32, i32* %j.addr, align 4 - // CHECK-NEXT: call void @_Z6insertIjLj2ELj2EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(%struct.MyMatrix.1* nonnull align 4 dereferenceable(16) [[MAT_ADDR]], i32 [[E]], i32 [[I]], i32 [[J]]) + // CHECK-NEXT: call void @_Z6insertIjLj2ELj2EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(%struct.MyMatrix.3* nonnull align 4 dereferenceable(16) [[MAT_ADDR]], i32 [[E]], i32 [[I]], i32 [[J]]) // CHECK-NEXT: ret void // // CHECK-LABEL: define linkonce_odr void @_Z6insertIjLj2ELj2EEvR8MyMatrixIT_XT0_EXT1_EES1_jj( @@ -190,9 +229,9 @@ void test_insert_template1(MyMatrix &Mat, unsigned e, unsigned i void test_insert_template2(MyMatrix &Mat, float e) { // CHECK-LABEL: @_Z21test_insert_template2R8MyMatrixIfLj3ELj8EEf( - // CHECK: [[MAT_ADDR:%.*]] = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %Mat.addr, align 8 + // CHECK: [[MAT_ADDR:%.*]] = load %struct.MyMatrix.4*, %struct.MyMatrix.4** %Mat.addr, align 8 // CHECK-NEXT: [[E:%.*]] = load float, float* %e.addr, align 4 - // CHECK-NEXT: call void @_Z6insertIfLj3ELj8EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(%struct.MyMatrix.2* nonnull align 4 dereferenceable(96) [[MAT_ADDR]], float [[E]], i32 2, i32 5) + // CHECK-NEXT: call void @_Z6insertIfLj3ELj8EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(%struct.MyMatrix.4* nonnull align 4 dereferenceable(96) [[MAT_ADDR]], float [[E]], i32 2, i32 5) // CHECK-NEXT: ret void // // CHECK-LABEL: define linkonce_odr void @_Z6insertIfLj3ELj8EEvR8MyMatrixIT_XT0_EXT1_EES1_jj( @@ -220,7 +259,7 @@ EltTy extract(MyMatrix &Mat) { int test_extract_template(MyMatrix Mat1) { // CHECK-LABEL: @_Z21test_extract_template8MyMatrixIiLj2ELj2EE( // CHECK-NEXT: entry: - // CHECK-NEXT: [[CALL:%.*]] = call i32 @_Z7extractIiLj2ELj2EET_R8MyMatrixIS0_XT0_EXT1_EE(%struct.MyMatrix.3* nonnull align 4 dereferenceable(16) [[MAT1:%.*]]) + // CHECK-NEXT: [[CALL:%.*]] = call i32 @_Z7extractIiLj2ELj2EET_R8MyMatrixIS0_XT0_EXT1_EE(%struct.MyMatrix.5* nonnull align 4 dereferenceable(16) [[MAT1:%.*]]) // CHECK-NEXT: ret i32 [[CALL]] // // CHECK-LABEL: define linkonce_odr i32 @_Z7extractIiLj2ELj2EET_R8MyMatrixIS0_XT0_EXT1_EE( @@ -301,7 +340,7 @@ struct identmatrix_t { constexpr identmatrix_t identmatrix; void test_constexpr1(matrix_type &m) { - // CHECK-LABEL: define void @_Z15test_constexpr1RU11matrix_typeLm4ELm4Ef([16 x float]* nonnull align 4 dereferenceable(64) %m) #3 { + // CHECK-LABEL: define void @_Z15test_constexpr1RU11matrix_typeLm4ELm4Ef( // CHECK: [[MAT:%.*]] = load <16 x float>, <16 x float>* {{.*}}, align 4 // CHECK-NEXT: [[IM:%.*]] = call <16 x float> @_ZNK13identmatrix_tcvU11matrix_typeXT0_EXT0_ET_IfLj4EEEv(%struct.identmatrix_t* @_ZL11identmatrix) // CHECK-NEXT: [[ADD:%.*]] = fadd <16 x float> [[MAT]], [[IM]] @@ -327,7 +366,7 @@ void test_constexpr1(matrix_type &m) { } void test_constexpr2(matrix_type &m) { - // CHECK-LABEL: define void @_Z15test_constexpr2RU11matrix_typeLm5ELm5Ei([25 x i32]* nonnull align 4 dereferenceable(100) %m) #4 { + // CHECK-LABEL: define void @_Z15test_constexpr2RU11matrix_typeLm5ELm5Ei( // CHECK: [[IM:%.*]] = call <25 x i32> @_ZNK13identmatrix_tcvU11matrix_typeXT0_EXT0_ET_IiLj5EEEv(%struct.identmatrix_t* @_ZL11identmatrix) // CHECK: [[MAT:%.*]] = load <25 x i32>, <25 x i32>* {{.*}}, align 4 // CHECK-NEXT: [[SUB:%.*]] = sub <25 x i32> [[IM]], [[MAT]] diff --git a/clang/test/Sema/matrix-type-operators.c b/clang/test/Sema/matrix-type-operators.c index 0de0dda..397fc35 100644 --- a/clang/test/Sema/matrix-type-operators.c +++ b/clang/test/Sema/matrix-type-operators.c @@ -32,6 +32,44 @@ void sub(sx10x10_t a, sx5x10_t b, sx10x5_t c) { // expected-error@-2 {{casting 'sx10x5_t *' (aka 'float __attribute__((matrix_type(10, 5)))*') to incompatible type 'float'}} } +typedef int ix10x5_t __attribute__((matrix_type(10, 5))); +typedef int ix10x10_t __attribute__((matrix_type(10, 10))); + +void matrix_matrix_multiply(sx10x10_t a, sx5x10_t b, ix10x5_t c, ix10x10_t d, float sf, char *p) { + // Check dimension mismatches. + a = a * b; + // expected-error@-1 {{invalid operands to binary expression ('sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') and 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))'))}} + b = a * a; + // expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'float __attribute__((matrix_type(10, 10)))'}} + + // Check element type mismatches. + a = b * c; + // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'ix10x5_t' (aka 'int __attribute__((matrix_type(10, 5)))'))}} + d = a * a; + // expected-error@-1 {{assigning to 'ix10x10_t' (aka 'int __attribute__((matrix_type(10, 10)))') from incompatible type 'float __attribute__((matrix_type(10, 10)))'}} + + p = a * a; + // expected-error@-1 {{assigning to 'char *' from incompatible type 'float __attribute__((matrix_type(10, 10)))'}} +} + +void mat_scalar_multiply(sx10x10_t a, sx5x10_t b, float sf, char *p) { + // Shape of multiplication result does not match the type of b. + b = a * sf; + // expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))')}} + b = sf * a; + // expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))')}} + + a = a * p; + // expected-error@-1 {{casting 'char *' to incompatible type 'float'}} + // expected-error@-2 {{invalid operands to binary expression ('sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') and 'char *')}} + a = p * a; + // expected-error@-1 {{casting 'char *' to incompatible type 'float'}} + // expected-error@-2 {{invalid operands to binary expression ('char *' and 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))'))}} + + sf = a * sf; + // expected-error@-1 {{assigning to 'float' from incompatible type 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))')}} +} + sx5x10_t get_matrix(); void insert(sx5x10_t a, float f) { diff --git a/clang/test/SemaCXX/matrix-type-operators.cpp b/clang/test/SemaCXX/matrix-type-operators.cpp index 337d44f..808e474 100644 --- a/clang/test/SemaCXX/matrix-type-operators.cpp +++ b/clang/test/SemaCXX/matrix-type-operators.cpp @@ -65,6 +65,45 @@ void test_subtract_template(unsigned *Ptr1, float *Ptr2) { // expected-note@-1 {{in instantiation of function template specialization 'subtract' requested here}} } +template +typename MyMatrix::matrix_t multiply(MyMatrix &A, MyMatrix &B) { + char *v1 = A.value * B.value; + // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'unsigned int __attribute__((matrix_type(2, 2)))'}} + // expected-error@-2 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 2)))') and 'MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}} + // expected-error@-3 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))') and 'MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))'))}} + + MyMatrix m; + B.value = m.value * A.value; + // expected-error@-1 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'int __attribute__((matrix_type(5, 6)))') and 'MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))'))}} + // expected-error@-2 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'int __attribute__((matrix_type(5, 6)))') and 'MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 2)))'))}} + // expected-error@-3 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'int __attribute__((matrix_type(5, 6)))') and 'MyMatrix::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))'))}} + + return A.value * B.value; + // expected-error@-1 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 2)))') and 'MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}} + // expected-error@-2 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))') and 'MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))'))}} +} + +void test_multiply_template(unsigned *Ptr1, float *Ptr2) { + MyMatrix Mat1; + MyMatrix Mat2; + MyMatrix Mat3; + Mat1.value = *((decltype(Mat1)::matrix_t *)Ptr1); + unsigned v1 = multiply(Mat1, Mat1); + // expected-note@-1 {{in instantiation of function template specialization 'multiply' requested here}} + // expected-error@-2 {{cannot initialize a variable of type 'unsigned int' with an rvalue of type 'typename MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}} + + MyMatrix Mat4; + Mat1.value = multiply(Mat4, Mat2); + // expected-note@-1 {{in instantiation of function template specialization 'multiply' requested here}} + + Mat1.value = multiply(Mat3, Mat1); + // expected-note@-1 {{in instantiation of function template specialization 'multiply' requested here}} + + Mat4.value = Mat4.value * Mat1; + // expected-error@-1 {{no viable conversion from 'MyMatrix' to 'unsigned int'}} + // expected-error@-2 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 2)))') and 'MyMatrix')}} +} + struct UserT {}; struct StructWithC { diff --git a/llvm/include/llvm/IR/MatrixBuilder.h b/llvm/include/llvm/IR/MatrixBuilder.h index 20aa47e..4833ed2 100644 --- a/llvm/include/llvm/IR/MatrixBuilder.h +++ b/llvm/include/llvm/IR/MatrixBuilder.h @@ -33,6 +33,21 @@ template class MatrixBuilder { IRBuilderTy &B; Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); } + std::pair splatScalarOperandIfNeeded(Value *LHS, + Value *RHS) { + assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) && + "One of the operands must be a matrix (embedded in a vector)"); + if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) + RHS = B.CreateVectorSplat( + cast(LHS->getType())->getNumElements(), RHS, + "scalar.splat"); + else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) + LHS = B.CreateVectorSplat( + cast(RHS->getType())->getNumElements(), LHS, + "scalar.splat"); + return {LHS, RHS}; + } + public: MatrixBuilder(IRBuilderTy &Builder) : B(Builder) {} @@ -164,15 +179,13 @@ public: : B.CreateSub(LHS, RHS); } - /// Multiply matrix \p LHS with scalar \p RHS. + /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p + /// RHS. Value *CreateScalarMultiply(Value *LHS, Value *RHS) { - Value *ScalarVector = - B.CreateVectorSplat(cast(LHS->getType())->getNumElements(), - RHS, "scalar.splat"); - if (RHS->getType()->isFloatingPointTy()) - return B.CreateFMul(LHS, ScalarVector); - - return B.CreateMul(LHS, ScalarVector); + std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS); + if (LHS->getType()->getScalarType()->isFloatingPointTy()) + return B.CreateFMul(LHS, RHS); + return B.CreateMul(LHS, RHS); } /// Extracts the element at (\p RowIdx, \p ColumnIdx) from \p Matrix.