[Matrix] Implement * binary operator for MatrixType.
authorFlorian Hahn <flo@fhahn.com>
Sun, 7 Jun 2020 10:11:27 +0000 (11:11 +0100)
committerFlorian Hahn <flo@fhahn.com>
Sun, 7 Jun 2020 10:11:27 +0000 (11:11 +0100)
This patch implements the * binary operator for values of
MatrixType. It adds support for matrix * matrix, scalar * matrix and
matrix * scalar.

For the matrix, matrix case, the number of columns of the first operand
must match the number of rows of the second. For the scalar,matrix variants,
the element type of the matrix must match the scalar type.

Reviewers: rjmccall, anemet, Bigcheese, rsmith, martong

Reviewed By: rjmccall

Differential Revision: https://reviews.llvm.org/D76794

clang/include/clang/Sema/Sema.h
clang/lib/CodeGen/CGExprScalar.cpp
clang/lib/Sema/SemaExpr.cpp
clang/lib/Sema/SemaOverload.cpp
clang/test/CodeGen/matrix-type-operators.c
clang/test/CodeGenCXX/matrix-type-operators.cpp
clang/test/Sema/matrix-type-operators.c
clang/test/SemaCXX/matrix-type-operators.cpp
llvm/include/llvm/IR/MatrixBuilder.h

index cef25fc..8f914ec 100644 (file)
@@ -11218,6 +11218,8 @@ public:
   QualType CheckMatrixElementwiseOperands(ExprResult &LHS, ExprResult &RHS,
                                           SourceLocation Loc,
                                           bool IsCompAssign);
+  QualType CheckMatrixMultiplyOperands(ExprResult &LHS, ExprResult &RHS,
+                                       SourceLocation Loc, bool IsCompAssign);
 
   bool areLaxCompatibleVectorTypes(QualType srcType, QualType destType);
   bool isLaxVectorConversion(QualType srcType, QualType destType);
index b169462..cabfe0c 100644 (file)
@@ -741,6 +741,22 @@ public:
       }
     }
 
+    if (Ops.Ty->isConstantMatrixType()) {
+      llvm::MatrixBuilder<CGBuilderTy> MB(Builder);
+      // We need to check the types of the operands of the operator to get the
+      // correct matrix dimensions.
+      auto *BO = cast<BinaryOperator>(Ops.E);
+      auto *LHSMatTy = dyn_cast<ConstantMatrixType>(
+          BO->getLHS()->getType().getCanonicalType());
+      auto *RHSMatTy = dyn_cast<ConstantMatrixType>(
+          BO->getRHS()->getType().getCanonicalType());
+      if (LHSMatTy && RHSMatTy)
+        return MB.CreateMatrixMultiply(Ops.LHS, Ops.RHS, LHSMatTy->getNumRows(),
+                                       LHSMatTy->getNumColumns(),
+                                       RHSMatTy->getNumColumns());
+      return MB.CreateScalarMultiply(Ops.LHS, Ops.RHS);
+    }
+
     if (Ops.Ty->isUnsignedIntegerType() &&
         CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) &&
         !CanElideOverflowCheck(CGF.getContext(), Ops))
index d7a39cf..45c1ace 100644 (file)
@@ -10058,6 +10058,9 @@ QualType Sema::CheckMultiplyDivideOperands(ExprResult &LHS, ExprResult &RHS,
     return CheckVectorOperands(LHS, RHS, Loc, IsCompAssign,
                                /*AllowBothBool*/getLangOpts().AltiVec,
                                /*AllowBoolConversions*/false);
+  if (!IsDiv && (LHS.get()->getType()->isConstantMatrixType() ||
+                 RHS.get()->getType()->isConstantMatrixType()))
+    return CheckMatrixMultiplyOperands(LHS, RHS, Loc, IsCompAssign);
 
   QualType compType = UsualArithmeticConversions(
       LHS, RHS, Loc, IsCompAssign ? ACK_CompAssign : ACK_Arithmetic);
@@ -12120,6 +12123,37 @@ QualType Sema::CheckMatrixElementwiseOperands(ExprResult &LHS, ExprResult &RHS,
   return InvalidOperands(Loc, LHS, RHS);
 }
 
+QualType Sema::CheckMatrixMultiplyOperands(ExprResult &LHS, ExprResult &RHS,
+                                           SourceLocation Loc,
+                                           bool IsCompAssign) {
+  if (!IsCompAssign) {
+    LHS = DefaultFunctionArrayLvalueConversion(LHS.get());
+    if (LHS.isInvalid())
+      return QualType();
+  }
+  RHS = DefaultFunctionArrayLvalueConversion(RHS.get());
+  if (RHS.isInvalid())
+    return QualType();
+
+  auto *LHSMatType = LHS.get()->getType()->getAs<ConstantMatrixType>();
+  auto *RHSMatType = RHS.get()->getType()->getAs<ConstantMatrixType>();
+  assert((LHSMatType || RHSMatType) && "At least one operand must be a matrix");
+
+  if (LHSMatType && RHSMatType) {
+    if (LHSMatType->getNumColumns() != RHSMatType->getNumRows())
+      return InvalidOperands(Loc, LHS, RHS);
+
+    if (!Context.hasSameType(LHSMatType->getElementType(),
+                             RHSMatType->getElementType()))
+      return InvalidOperands(Loc, LHS, RHS);
+
+    return Context.getConstantMatrixType(LHSMatType->getElementType(),
+                                         LHSMatType->getNumRows(),
+                                         RHSMatType->getNumColumns());
+  }
+  return CheckMatrixElementwiseOperands(LHS, RHS, Loc, IsCompAssign);
+}
+
 inline QualType Sema::CheckBitwiseOperands(ExprResult &LHS, ExprResult &RHS,
                                            SourceLocation Loc,
                                            BinaryOperatorKind Opc) {
index d7fc966..5e5f539 100644 (file)
@@ -9167,8 +9167,10 @@ void Sema::AddBuiltinOperatorCandidates(OverloadedOperatorKind Op,
   case OO_Star: // '*' is either unary or binary
     if (Args.size() == 1)
       OpBuilder.addUnaryStarPointerOverloads();
-    else
+    else {
       OpBuilder.addGenericBinaryArithmeticOverloads();
+      OpBuilder.addMatrixBinaryArithmeticOverloads();
+    }
     break;
 
   case OO_Slash:
index c649922..f44ecd7 100644 (file)
@@ -173,6 +173,134 @@ void add_matrix_scalar_long_long_int_unsigned_long_long(ullx4x2_t b, unsigned lo
   b = vulli + b;
 }
 
+// Tests for matrix multiplication.
+
+void multiply_matrix_matrix_double(dx5x5_t b, dx5x5_t c) {
+  // CHECK-LABEL: @multiply_matrix_matrix_double(
+  // CHECK:         [[B:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8
+  // CHECK-NEXT:    [[C:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8
+  // CHECK-NEXT:    [[RES:%.*]] = call <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double> [[B]], <25 x double> [[C]], i32 5, i32 5, i32 5)
+  // CHECK-NEXT:    [[A_ADDR:%.*]] = bitcast [25 x double]* %a to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> [[RES]], <25 x double>* [[A_ADDR]], align 8
+  // CHECK-NEXT:    ret void
+  //
+
+  dx5x5_t a;
+  a = b * c;
+}
+
+typedef int ix3x9_t __attribute__((matrix_type(3, 9)));
+typedef int ix9x9_t __attribute__((matrix_type(9, 9)));
+// CHECK-LABEL: @multiply_matrix_matrix_int(
+// CHECK:         [[B:%.*]] = load <27 x i32>, <27 x i32>* {{.*}}, align 4
+// CHECK-NEXT:    [[C:%.*]] = load <27 x i32>, <27 x i32>* {{.*}}, align 4
+// CHECK-NEXT:    [[RES:%.*]] = call <81 x i32> @llvm.matrix.multiply.v81i32.v27i32.v27i32(<27 x i32> [[B]], <27 x i32> [[C]], i32 9, i32 3, i32 9)
+// CHECK-NEXT:    [[A_ADDR:%.*]] = bitcast [81 x i32]* %a to <81 x i32>*
+// CHECK-NEXT:    store <81 x i32> [[RES]], <81 x i32>* [[A_ADDR]], align 4
+// CHECK-NEXT:    ret void
+//
+void multiply_matrix_matrix_int(ix9x3_t b, ix3x9_t c) {
+  ix9x9_t a;
+  a = b * c;
+}
+
+// CHECK-LABEL: @multiply_double_matrix_scalar_float(
+// CHECK:         [[A:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8
+// CHECK-NEXT:    [[S:%.*]] = load float, float* %s.addr, align 4
+// CHECK-NEXT:    [[S_EXT:%.*]] = fpext float [[S]] to double
+// CHECK-NEXT:    [[VECINSERT:%.*]] = insertelement <25 x double> undef, double [[S_EXT]], i32 0
+// CHECK-NEXT:    [[VECSPLAT:%.*]] = shufflevector <25 x double> [[VECINSERT]], <25 x double> undef, <25 x i32> zeroinitializer
+// CHECK-NEXT:    [[RES:%.*]] = fmul <25 x double> [[A]], [[VECSPLAT]]
+// CHECK-NEXT:    store <25 x double> [[RES]], <25 x double>* {{.*}}, align 8
+// CHECK-NEXT:    ret void
+//
+void multiply_double_matrix_scalar_float(dx5x5_t a, float s) {
+  a = a * s;
+}
+
+// CHECK-LABEL: @multiply_double_matrix_scalar_double(
+// CHECK:         [[A:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8
+// CHECK-NEXT:    [[S:%.*]] = load double, double* %s.addr, align 8
+// CHECK-NEXT:    [[VECINSERT:%.*]] = insertelement <25 x double> undef, double [[S]], i32 0
+// CHECK-NEXT:    [[VECSPLAT:%.*]] = shufflevector <25 x double> [[VECINSERT]], <25 x double> undef, <25 x i32> zeroinitializer
+// CHECK-NEXT:    [[RES:%.*]] = fmul <25 x double> [[A]], [[VECSPLAT]]
+// CHECK-NEXT:    store <25 x double> [[RES]], <25 x double>* {{.*}}, align 8
+// CHECK-NEXT:    ret void
+//
+void multiply_double_matrix_scalar_double(dx5x5_t a, double s) {
+  a = a * s;
+}
+
+// CHECK-LABEL: @multiply_float_matrix_scalar_double(
+// CHECK:         [[S:%.*]] = load double, double* %s.addr, align 8
+// CHECK-NEXT:    [[S_TRUNC:%.*]] = fptrunc double [[S]] to float
+// CHECK-NEXT:    [[MAT:%.*]] = load <6 x float>, <6 x float>* [[MAT_ADDR:%.*]], align 4
+// CHECK-NEXT:    [[VECINSERT:%.*]] = insertelement <6 x float> undef, float [[S_TRUNC]], i32 0
+// CHECK-NEXT:    [[VECSPLAT:%.*]] = shufflevector <6 x float> [[VECINSERT]], <6 x float> undef, <6 x i32> zeroinitializer
+// CHECK-NEXT:    [[RES:%.*]] = fmul <6 x float> [[VECSPLAT]], [[MAT]]
+// CHECK-NEXT:    store <6 x float> [[RES]], <6 x float>* [[MAT_ADDR]], align 4
+// CHECK-NEXT:    ret void
+//
+void multiply_float_matrix_scalar_double(fx2x3_t b, double s) {
+  b = s * b;
+}
+
+// CHECK-LABEL: @multiply_int_matrix_scalar_short(
+// CHECK:         [[S:%.*]] = load i16, i16* %s.addr, align 2
+// CHECK-NEXT:    [[S_EXT:%.*]] = sext i16 [[S]] to i32
+// CHECK-NEXT:    [[MAT:%.*]] = load <27 x i32>, <27 x i32>* [[MAT_ADDR:%.*]], align 4
+// CHECK-NEXT:    [[VECINSERT:%.*]] = insertelement <27 x i32> undef, i32 [[S_EXT]], i32 0
+// CHECK-NEXT:    [[VECSPLAT:%.*]] = shufflevector <27 x i32> [[VECINSERT]], <27 x i32> undef, <27 x i32> zeroinitializer
+// CHECK-NEXT:    [[RES:%.*]] = mul <27 x i32> [[VECSPLAT]], [[MAT]]
+// CHECK-NEXT:    store <27 x i32> [[RES]], <27 x i32>* [[MAT_ADDR]], align 4
+// CHECK-NEXT:    ret void
+//
+void multiply_int_matrix_scalar_short(ix9x3_t b, short s) {
+  b = s * b;
+}
+
+// CHECK-LABEL: @multiply_int_matrix_scalar_ull(
+// CHECK:         [[MAT:%.*]] = load <27 x i32>, <27 x i32>* [[MAT_ADDR:%.*]], align 4
+// CHECK-NEXT:    [[S:%.*]] = load i64, i64* %s.addr, align 8
+// CHECK-NEXT:    [[S_TRUNC:%.*]] = trunc i64 [[S]] to i32
+// CHECK-NEXT:    [[VECINSERT:%.*]] = insertelement <27 x i32> undef, i32 [[S_TRUNC]], i32 0
+// CHECK-NEXT:    [[VECSPLAT:%.*]] = shufflevector <27 x i32> [[VECINSERT]], <27 x i32> undef, <27 x i32> zeroinitializer
+// CHECK-NEXT:    [[RES:%.*]] = mul <27 x i32> [[MAT]], [[VECSPLAT]]
+// CHECK-NEXT:    store <27 x i32> [[RES]], <27 x i32>* [[MAT_ADDR]], align 4
+// CHECK-NEXT:    ret void
+//
+void multiply_int_matrix_scalar_ull(ix9x3_t b, unsigned long long s) {
+  b = b * s;
+}
+
+// CHECK-LABEL: @multiply_float_matrix_constant(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[A_ADDR:%.*]] = alloca [6 x float], align 4
+// CHECK-NEXT:    [[MAT_ADDR:%.*]] = bitcast [6 x float]* [[A_ADDR]] to <6 x float>*
+// CHECK-NEXT:    store <6 x float> [[A:%.*]], <6 x float>* [[MAT_ADDR]], align 4
+// CHECK-NEXT:    [[MAT:%.*]] = load <6 x float>, <6 x float>* [[MAT_ADDR]], align 4
+// CHECK-NEXT:    [[RES:%.*]] = fmul <6 x float> [[MAT]], <float 2.500000e+00, float 2.500000e+00, float 2.500000e+00, float 2.500000e+00, float 2.500000e+00, float 2.500000e+00>
+// CHECK-NEXT:    store <6 x float> [[RES]], <6 x float>* [[MAT_ADDR]], align 4
+// CHECK-NEXT:    ret void
+//
+void multiply_float_matrix_constant(fx2x3_t a) {
+  a = a * 2.5;
+}
+
+// CHECK-LABEL: @multiply_int_matrix_constant(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[A_ADDR:%.*]] = alloca [27 x i32], align 4
+// CHECK-NEXT:    [[MAT_ADDR:%.*]] = bitcast [27 x i32]* [[A_ADDR]] to <27 x i32>*
+// CHECK-NEXT:    store <27 x i32> [[A:%.*]], <27 x i32>* [[MAT_ADDR]], align 4
+// CHECK-NEXT:    [[MAT:%.*]] = load <27 x i32>, <27 x i32>* [[MAT_ADDR]], align 4
+// CHECK-NEXT:    [[RES:%.*]] = mul <27 x i32> <i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5, i32 5>, [[MAT]]
+// CHECK-NEXT:    store <27 x i32> [[RES]], <27 x i32>* [[MAT_ADDR]], align 4
+// CHECK-NEXT:    ret void
+//
+void multiply_int_matrix_constant(ix9x3_t a) {
+  a = 5 * a;
+}
+
 // Tests for the matrix type operators.
 
 typedef double dx5x5_t __attribute__((matrix_type(5, 5)));
index a9eec96..5ca55d2 100644 (file)
@@ -157,6 +157,45 @@ void test_IntWrapper_Sub(MyMatrix<double, 10, 9> &m) {
   m.value = w3 - m.value;
 }
 
+template <typename EltTy0, unsigned R0, unsigned C0, unsigned C1>
+typename MyMatrix<EltTy0, R0, C1>::matrix_t multiply(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy0, C0, C1> &B) {
+  return A.value * B.value;
+}
+
+MyMatrix<float, 2, 2> test_multiply_template(MyMatrix<float, 2, 5> Mat1,
+                                             MyMatrix<float, 5, 2> Mat2) {
+  // CHECK-LABEL: define void @_Z22test_multiply_template8MyMatrixIfLj2ELj5EES_IfLj5ELj2EE(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    [[RES:%.*]] = call <4 x float> @_Z8multiplyIfLj2ELj5ELj2EEN8MyMatrixIT_XT0_EXT2_EE8matrix_tERS0_IS1_XT0_EXT1_EERS0_IS1_XT1_EXT2_EE(%struct.MyMatrix* nonnull align 4 dereferenceable(40) %Mat1, %struct.MyMatrix.2* nonnull align 4 dereferenceable(40) %Mat2)
+  // CHECK-NEXT:    %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %agg.result, i32 0, i32 0
+  // CHECK-NEXT:    [[VALUE_ADDR:%.*]] = bitcast [4 x float]* %value to <4 x float>*
+  // CHECK-NEXT:    store <4 x float> [[RES]], <4 x float>* [[VALUE_ADDR]], align 4
+  // CHECK-NEXT:    ret void
+  //
+  // CHECK-LABEL:  define linkonce_odr <4 x float> @_Z8multiplyIfLj2ELj5ELj2EEN8MyMatrixIT_XT0_EXT2_EE8matrix_tERS0_IS1_XT0_EXT1_EERS0_IS1_XT1_EXT2_EE(
+  // CHECK:         [[MAT1:%.*]] = load <10 x float>, <10 x float>* {{.*}}, align 4
+  // CHECK:         [[MAT2:%.*]] = load <10 x float>, <10 x float>* {{.*}}, align 4
+  // CHECK-NEXT:    [[RES:%.*]] = call <4 x float> @llvm.matrix.multiply.v4f32.v10f32.v10f32(<10 x float> [[MAT1]], <10 x float> [[MAT2]], i32 2, i32 5, i32 2)
+  // CHECK-NEXT:    ret <4 x float> [[RES]]
+
+  MyMatrix<float, 2, 2> Res;
+  Res.value = multiply(Mat1, Mat2);
+  return Res;
+}
+
+void test_IntWrapper_Multiply(MyMatrix<double, 10, 9> &m, IntWrapper &w3) {
+  // CHECK-LABEL: define void @_Z24test_IntWrapper_MultiplyR8MyMatrixIdLj10ELj9EER10IntWrapper(
+  // CHECK:       [[SCALAR:%.*]] = call i32 @_ZN10IntWrappercviEv(%struct.IntWrapper* {{.*}})
+  // CHECK-NEXT:  [[SCALAR_FP:%.*]] = sitofp i32 %call to double
+  // CHECK:       [[MATRIX:%.*]] = load <90 x double>, <90 x double>* {{.*}}, align 8
+  // CHECK-NEXT:  [[SCALAR_EMBED:%.*]] = insertelement <90 x double> undef, double [[SCALAR_FP]], i32 0
+  // CHECK-NEXT:  [[SCALAR_EMBED1:%.*]] = shufflevector <90 x double> [[SCALAR_EMBED]], <90 x double> undef, <90 x i32> zeroinitializer
+  // CHECK-NEXT:  [[RES:%.*]] = fmul <90 x double> [[SCALAR_EMBED1]], [[MATRIX]]
+  // CHECK:       store <90 x double> [[RES]], <90 x double>* {{.*}}, align 8
+  // CHECK:       ret void
+  m.value = w3 * m.value;
+}
+
 template <typename EltTy, unsigned Rows, unsigned Columns>
 void insert(MyMatrix<EltTy, Rows, Columns> &Mat, EltTy e, unsigned i, unsigned j) {
   Mat.value[i][j] = e;
@@ -164,11 +203,11 @@ void insert(MyMatrix<EltTy, Rows, Columns> &Mat, EltTy e, unsigned i, unsigned j
 
 void test_insert_template1(MyMatrix<unsigned, 2, 2> &Mat, unsigned e, unsigned i, unsigned j) {
   // CHECK-LABEL: @_Z21test_insert_template1R8MyMatrixIjLj2ELj2EEjjj(
-  // CHECK:         [[MAT_ADDR:%.*]] = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %Mat.addr, align 8
+  // CHECK:         [[MAT_ADDR:%.*]] = load %struct.MyMatrix.3*, %struct.MyMatrix.3** %Mat.addr, align 8
   // CHECK-NEXT:    [[E:%.*]] = load i32, i32* %e.addr, align 4
   // CHECK-NEXT:    [[I:%.*]] = load i32, i32* %i.addr, align 4
   // CHECK-NEXT:    [[J:%.*]] = load i32, i32* %j.addr, align 4
-  // CHECK-NEXT:    call void @_Z6insertIjLj2ELj2EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(%struct.MyMatrix.1* nonnull align 4 dereferenceable(16) [[MAT_ADDR]], i32 [[E]], i32 [[I]], i32 [[J]])
+  // CHECK-NEXT:    call void @_Z6insertIjLj2ELj2EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(%struct.MyMatrix.3* nonnull align 4 dereferenceable(16) [[MAT_ADDR]], i32 [[E]], i32 [[I]], i32 [[J]])
   // CHECK-NEXT:    ret void
   //
   // CHECK-LABEL: define linkonce_odr void @_Z6insertIjLj2ELj2EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(
@@ -190,9 +229,9 @@ void test_insert_template1(MyMatrix<unsigned, 2, 2> &Mat, unsigned e, unsigned i
 
 void test_insert_template2(MyMatrix<float, 3, 8> &Mat, float e) {
   // CHECK-LABEL: @_Z21test_insert_template2R8MyMatrixIfLj3ELj8EEf(
-  // CHECK:         [[MAT_ADDR:%.*]] = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %Mat.addr, align 8
+  // CHECK:         [[MAT_ADDR:%.*]] = load %struct.MyMatrix.4*, %struct.MyMatrix.4** %Mat.addr, align 8
   // CHECK-NEXT:    [[E:%.*]] = load float, float* %e.addr, align 4
-  // CHECK-NEXT:    call void @_Z6insertIfLj3ELj8EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(%struct.MyMatrix.2* nonnull align 4 dereferenceable(96) [[MAT_ADDR]], float [[E]], i32 2, i32 5)
+  // CHECK-NEXT:    call void @_Z6insertIfLj3ELj8EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(%struct.MyMatrix.4* nonnull align 4 dereferenceable(96) [[MAT_ADDR]], float [[E]], i32 2, i32 5)
   // CHECK-NEXT:    ret void
   //
   // CHECK-LABEL: define linkonce_odr void @_Z6insertIfLj3ELj8EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(
@@ -220,7 +259,7 @@ EltTy extract(MyMatrix<EltTy, Rows, Columns> &Mat) {
 int test_extract_template(MyMatrix<int, 2, 2> Mat1) {
   // CHECK-LABEL: @_Z21test_extract_template8MyMatrixIiLj2ELj2EE(
   // CHECK-NEXT:  entry:
-  // CHECK-NEXT:    [[CALL:%.*]] = call i32 @_Z7extractIiLj2ELj2EET_R8MyMatrixIS0_XT0_EXT1_EE(%struct.MyMatrix.3* nonnull align 4 dereferenceable(16) [[MAT1:%.*]])
+  // CHECK-NEXT:    [[CALL:%.*]] = call i32 @_Z7extractIiLj2ELj2EET_R8MyMatrixIS0_XT0_EXT1_EE(%struct.MyMatrix.5* nonnull align 4 dereferenceable(16) [[MAT1:%.*]])
   // CHECK-NEXT:    ret i32 [[CALL]]
   //
   // CHECK-LABEL: define linkonce_odr i32 @_Z7extractIiLj2ELj2EET_R8MyMatrixIS0_XT0_EXT1_EE(
@@ -301,7 +340,7 @@ struct identmatrix_t {
 constexpr identmatrix_t identmatrix;
 
 void test_constexpr1(matrix_type<float, 4, 4> &m) {
-  // CHECK-LABEL: define void @_Z15test_constexpr1RU11matrix_typeLm4ELm4Ef([16 x float]* nonnull align 4 dereferenceable(64) %m) #3 {
+  // CHECK-LABEL: define void @_Z15test_constexpr1RU11matrix_typeLm4ELm4Ef(
   // CHECK:         [[MAT:%.*]] = load <16 x float>, <16 x float>* {{.*}}, align 4
   // CHECK-NEXT:    [[IM:%.*]] = call <16 x float> @_ZNK13identmatrix_tcvU11matrix_typeXT0_EXT0_ET_IfLj4EEEv(%struct.identmatrix_t* @_ZL11identmatrix)
   // CHECK-NEXT:    [[ADD:%.*]] = fadd <16 x float> [[MAT]], [[IM]]
@@ -327,7 +366,7 @@ void test_constexpr1(matrix_type<float, 4, 4> &m) {
 }
 
 void test_constexpr2(matrix_type<int, 5, 5> &m) {
-  // CHECK-LABEL: define void @_Z15test_constexpr2RU11matrix_typeLm5ELm5Ei([25 x i32]* nonnull align 4 dereferenceable(100) %m) #4 {
+  // CHECK-LABEL: define void @_Z15test_constexpr2RU11matrix_typeLm5ELm5Ei(
   // CHECK:         [[IM:%.*]] = call <25 x i32> @_ZNK13identmatrix_tcvU11matrix_typeXT0_EXT0_ET_IiLj5EEEv(%struct.identmatrix_t* @_ZL11identmatrix)
   // CHECK:         [[MAT:%.*]] = load <25 x i32>, <25 x i32>* {{.*}}, align 4
   // CHECK-NEXT:    [[SUB:%.*]] = sub <25 x i32> [[IM]], [[MAT]]
index 0de0dda..397fc35 100644 (file)
@@ -32,6 +32,44 @@ void sub(sx10x10_t a, sx5x10_t b, sx10x5_t c) {
   // expected-error@-2 {{casting 'sx10x5_t *' (aka 'float  __attribute__((matrix_type(10, 5)))*') to incompatible type 'float'}}
 }
 
+typedef int ix10x5_t __attribute__((matrix_type(10, 5)));
+typedef int ix10x10_t __attribute__((matrix_type(10, 10)));
+
+void matrix_matrix_multiply(sx10x10_t a, sx5x10_t b, ix10x5_t c, ix10x10_t d, float sf, char *p) {
+  // Check dimension mismatches.
+  a = a * b;
+  // expected-error@-1 {{invalid operands to binary expression ('sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') and 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))'))}}
+  b = a * a;
+  // expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'float __attribute__((matrix_type(10, 10)))'}}
+
+  // Check element type mismatches.
+  a = b * c;
+  // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'ix10x5_t' (aka 'int __attribute__((matrix_type(10, 5)))'))}}
+  d = a * a;
+  // expected-error@-1 {{assigning to 'ix10x10_t' (aka 'int __attribute__((matrix_type(10, 10)))') from incompatible type 'float __attribute__((matrix_type(10, 10)))'}}
+
+  p = a * a;
+  // expected-error@-1 {{assigning to 'char *' from incompatible type 'float __attribute__((matrix_type(10, 10)))'}}
+}
+
+void mat_scalar_multiply(sx10x10_t a, sx5x10_t b, float sf, char *p) {
+  // Shape of multiplication result does not match the type of b.
+  b = a * sf;
+  // expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))')}}
+  b = sf * a;
+  // expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))')}}
+
+  a = a * p;
+  // expected-error@-1 {{casting 'char *' to incompatible type 'float'}}
+  // expected-error@-2 {{invalid operands to binary expression ('sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') and 'char *')}}
+  a = p * a;
+  // expected-error@-1 {{casting 'char *' to incompatible type 'float'}}
+  // expected-error@-2 {{invalid operands to binary expression ('char *' and 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))'))}}
+
+  sf = a * sf;
+  // expected-error@-1 {{assigning to 'float' from incompatible type 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))')}}
+}
+
 sx5x10_t get_matrix();
 
 void insert(sx5x10_t a, float f) {
index 337d44f..808e474 100644 (file)
@@ -65,6 +65,45 @@ void test_subtract_template(unsigned *Ptr1, float *Ptr2) {
   // expected-note@-1 {{in instantiation of function template specialization 'subtract<unsigned int, 3, 3, float, 2, 2, unsigned int, 2, 2>' requested here}}
 }
 
+template <typename EltTy0, unsigned R0, unsigned C0, typename EltTy1, unsigned R1, unsigned C1, typename EltTy2, unsigned R2, unsigned C2>
+typename MyMatrix<EltTy2, R2, C2>::matrix_t multiply(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy1, R1, C1> &B) {
+  char *v1 = A.value * B.value;
+  // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'unsigned int __attribute__((matrix_type(2, 2)))'}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 3, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}}
+  // expected-error@-3 {{invalid operands to binary expression ('MyMatrix<float, 2, 2>::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))'))}}
+
+  MyMatrix<int, 5, 6> m;
+  B.value = m.value * A.value;
+  // expected-error@-1 {{invalid operands to binary expression ('MyMatrix<int, 5, 6>::matrix_t' (aka 'int __attribute__((matrix_type(5, 6)))') and 'MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))'))}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<int, 5, 6>::matrix_t' (aka 'int __attribute__((matrix_type(5, 6)))') and 'MyMatrix<unsigned int, 3, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 2)))'))}}
+  // expected-error@-3 {{invalid operands to binary expression ('MyMatrix<int, 5, 6>::matrix_t' (aka 'int __attribute__((matrix_type(5, 6)))') and 'MyMatrix<float, 2, 2>::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))'))}}
+
+  return A.value * B.value;
+  // expected-error@-1 {{invalid operands to binary expression ('MyMatrix<unsigned int, 3, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<float, 2, 2>::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))'))}}
+}
+
+void test_multiply_template(unsigned *Ptr1, float *Ptr2) {
+  MyMatrix<unsigned, 2, 2> Mat1;
+  MyMatrix<unsigned, 3, 3> Mat2;
+  MyMatrix<float, 2, 2> Mat3;
+  Mat1.value = *((decltype(Mat1)::matrix_t *)Ptr1);
+  unsigned v1 = multiply<unsigned, 2, 2, unsigned, 2, 2, unsigned, 2, 2>(Mat1, Mat1);
+  // expected-note@-1 {{in instantiation of function template specialization 'multiply<unsigned int, 2, 2, unsigned int, 2, 2, unsigned int, 2, 2>' requested here}}
+  // expected-error@-2 {{cannot initialize a variable of type 'unsigned int' with an rvalue of type 'typename MyMatrix<unsigned int, 2U, 2U>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}}
+
+  MyMatrix<unsigned, 3, 2> Mat4;
+  Mat1.value = multiply<unsigned, 3, 2, unsigned, 3, 3, unsigned, 2, 2>(Mat4, Mat2);
+  // expected-note@-1 {{in instantiation of function template specialization 'multiply<unsigned int, 3, 2, unsigned int, 3, 3, unsigned int, 2, 2>' requested here}}
+
+  Mat1.value = multiply<float, 2, 2, unsigned, 2, 2, unsigned, 2, 2>(Mat3, Mat1);
+  // expected-note@-1 {{in instantiation of function template specialization 'multiply<float, 2, 2, unsigned int, 2, 2, unsigned int, 2, 2>' requested here}}
+
+  Mat4.value = Mat4.value * Mat1;
+  // expected-error@-1 {{no viable conversion from 'MyMatrix<unsigned int, 2, 2>' to 'unsigned int'}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 3, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 2)))') and 'MyMatrix<unsigned int, 2, 2>')}}
+}
+
 struct UserT {};
 
 struct StructWithC {
index 20aa47e..4833ed2 100644 (file)
@@ -33,6 +33,21 @@ template <class IRBuilderTy> class MatrixBuilder {
   IRBuilderTy &B;
   Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
 
+  std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS,
+                                                         Value *RHS) {
+    assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) &&
+           "One of the operands must be a matrix (embedded in a vector)");
+    if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy())
+      RHS = B.CreateVectorSplat(
+          cast<VectorType>(LHS->getType())->getNumElements(), RHS,
+          "scalar.splat");
+    else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy())
+      LHS = B.CreateVectorSplat(
+          cast<VectorType>(RHS->getType())->getNumElements(), LHS,
+          "scalar.splat");
+    return {LHS, RHS};
+  }
+
 public:
   MatrixBuilder(IRBuilderTy &Builder) : B(Builder) {}
 
@@ -164,15 +179,13 @@ public:
                : B.CreateSub(LHS, RHS);
   }
 
-  /// Multiply matrix \p LHS with scalar \p RHS.
+  /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p
+  /// RHS.
   Value *CreateScalarMultiply(Value *LHS, Value *RHS) {
-    Value *ScalarVector =
-        B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getNumElements(),
-                            RHS, "scalar.splat");
-    if (RHS->getType()->isFloatingPointTy())
-      return B.CreateFMul(LHS, ScalarVector);
-
-    return B.CreateMul(LHS, ScalarVector);
+    std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
+    if (LHS->getType()->getScalarType()->isFloatingPointTy())
+      return B.CreateFMul(LHS, RHS);
+    return B.CreateMul(LHS, RHS);
   }
 
   /// Extracts the element at (\p RowIdx, \p ColumnIdx) from \p Matrix.