def SPV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>;
def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
def SPV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
+def SPV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
def SPV_OC_OpLogicalEqual : I32EnumAttrCase<"OpLogicalEqual", 164>;
def SPV_OC_OpLogicalNotEqual : I32EnumAttrCase<"OpLogicalNotEqual", 165>;
def SPV_OC_OpLogicalOr : I32EnumAttrCase<"OpLogicalOr", 166>;
SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
- SPV_OC_OpMatrixTimesScalar, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
- SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
- SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
- SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
- SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
- SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
- SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
- SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
- SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
- SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
- SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
- SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor,
- SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert,
- SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse,
- SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier,
- SPV_OC_OpAtomicCompareExchangeWeak, SPV_OC_OpAtomicIIncrement,
- SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, SPV_OC_OpAtomicISub,
- SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, SPV_OC_OpAtomicSMax,
- SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor,
- SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
- SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
- SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpNoLine,
- SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
- SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpGroupNonUniformIAdd,
- SPV_OC_OpGroupNonUniformFAdd, SPV_OC_OpGroupNonUniformIMul,
- SPV_OC_OpGroupNonUniformFMul, SPV_OC_OpGroupNonUniformSMin,
- SPV_OC_OpGroupNonUniformUMin, SPV_OC_OpGroupNonUniformFMin,
- SPV_OC_OpGroupNonUniformSMax, SPV_OC_OpGroupNonUniformUMax,
- SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR,
- SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV,
- SPV_OC_OpCooperativeMatrixStoreNV, SPV_OC_OpCooperativeMatrixMulAddNV,
- SPV_OC_OpCooperativeMatrixLengthNV
+ SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpLogicalEqual,
+ SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd,
+ SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpINotEqual,
+ SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual,
+ SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan,
+ SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual,
+ SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual,
+ SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan,
+ SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual,
+ SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual,
+ SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpShiftRightLogical,
+ SPV_OC_OpShiftRightArithmetic, SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr,
+ SPV_OC_OpBitwiseXor, SPV_OC_OpBitwiseAnd, SPV_OC_OpNot,
+ SPV_OC_OpBitFieldInsert, SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract,
+ SPV_OC_OpBitReverse, SPV_OC_OpBitCount, SPV_OC_OpControlBarrier,
+ SPV_OC_OpMemoryBarrier, SPV_OC_OpAtomicCompareExchangeWeak,
+ SPV_OC_OpAtomicIIncrement, SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd,
+ SPV_OC_OpAtomicISub, SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin,
+ SPV_OC_OpAtomicSMax, SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd,
+ SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, SPV_OC_OpPhi, SPV_OC_OpLoopMerge,
+ SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch,
+ SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue,
+ SPV_OC_OpUnreachable, SPV_OC_OpNoLine, SPV_OC_OpModuleProcessed,
+ SPV_OC_OpGroupNonUniformElect, SPV_OC_OpGroupNonUniformBallot,
+ SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpGroupNonUniformFAdd,
+ SPV_OC_OpGroupNonUniformIMul, SPV_OC_OpGroupNonUniformFMul,
+ SPV_OC_OpGroupNonUniformSMin, SPV_OC_OpGroupNonUniformUMin,
+ SPV_OC_OpGroupNonUniformFMin, SPV_OC_OpGroupNonUniformSMax,
+ SPV_OC_OpGroupNonUniformUMax, SPV_OC_OpGroupNonUniformFMax,
+ SPV_OC_OpSubgroupBallotKHR, SPV_OC_OpTypeCooperativeMatrixNV,
+ SPV_OC_OpCooperativeMatrixLoadNV, SPV_OC_OpCooperativeMatrixStoreNV,
+ SPV_OC_OpCooperativeMatrixMulAddNV, SPV_OC_OpCooperativeMatrixLengthNV
]>;
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
#ifndef SPIRV_MATRIX_OPS
#define SPIRV_MATRIX_OPS
+include "mlir/Interfaces/SideEffectInterfaces.td"
// -----
-def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", []> {
+def SPV_MatrixTimesMatrixOp : SPV_Op<"MatrixTimesMatrix", [NoSideEffect]> {
+ let summary = "Linear-algebraic multiply of LeftMatrix X RightMatrix.";
+
+ let description = [{
+ Result Type must be an OpTypeMatrix whose Column Type is a vector of
+ floating-point type.
+
+ LeftMatrix must be a matrix whose Column Type is the same as the Column
+ Type in Result Type.
+
+ RightMatrix must be a matrix with the same Component Type as the
+ Component Type in Result Type. Its number of columns must equal the
+ number of columns in Result Type. Its columns must have the same number
+ of components as the number of columns in LeftMatrix.
+
+ <!-- End of AutoGen section -->
+
+ ```
+ matrix-times-matrix-op ::= ssa-id `=` `spv.MatrixTimesMatrix` ssa-use,
+ ssa-use `:` matrix-type `,` matrix-type `->` matrix-type
+ ```mlir
+
+ #### Example:
+
+ ```
+ %0 = spv.MatrixTimesMatrix %matrix_1, %matrix_2 :
+ !spv.matrix<4 x vector<3xf32>>, !spv.matrix<3 x vector<4xf32>> ->
+ !spv.matrix<4 x vector<4xf32>>
+ ```
+ }];
+
+ let availability = [
+ MinVersion<SPV_V_1_0>,
+ MaxVersion<SPV_V_1_5>,
+ Extension<[]>,
+ Capability<[SPV_C_Matrix]>
+ ];
+
+ let arguments = (ins
+ SPV_AnyMatrix:$leftmatrix,
+ SPV_AnyMatrix:$rightmatrix
+ );
+
+ let results = (outs
+ SPV_AnyMatrix:$result
+ );
+ let assemblyFormat = [{
+ operands attr-dict `:` type($leftmatrix) `,` type($rightmatrix) `->` type($result)
+ }];
+ let verifier = [{ return verifyMatrixTimesMatrix(*this); }];
+}
+
+// -----
+
+def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", [NoSideEffect]> {
let summary = "Scale a floating-point matrix.";
let description = [{
// -----
-def SPV_TransposeOp : SPV_Op<"Transpose", []> {
+def SPV_TransposeOp : SPV_Op<"Transpose", [NoSideEffect]> {
let summary = "Transpose a matrix.";
let description = [{
Type columnType,
uint32_t columnCount);
- /// Returns true if the matrix elements are vectors of float elements
+ /// Returns true if the matrix elements are vectors of float elements.
static bool isValidColumnType(Type columnType);
- Type getElementType() const;
+ Type getColumnType() const;
+
+ /// Returns the number of rows.
+ unsigned getNumRows() const;
+
+ /// Returns the number of columns.
+ unsigned getNumColumns() const;
+ /// Returns total number of elements (rows*columns).
unsigned getNumElements() const;
+ /// Returns the elements' type (i.e, single element type).
+ Type getElementType() const;
+
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<spirv::StorageClass> storage = llvm::None);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
}
static void print(MatrixType type, DialectAsmPrinter &os) {
- os << "matrix<" << type.getNumElements() << " x " << type.getElementType();
+ os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
os << ">";
}
// auto-generated verify method.
auto inputMatrix = op.matrix().getType().cast<spirv::MatrixType>();
- // Check that the scalar type is the same as the matrix components type.
- if (auto inputMatrixColumns =
- inputMatrix.getElementType().dyn_cast<VectorType>()) {
- if (op.scalar().getType() != inputMatrixColumns.getElementType())
- return op.emitError("input matrix components' type and scaling "
- "value must have the same type");
-
- // Note that the next three checks could be done using the AllTypesMatch
- // trait in the Op definition file but it generates a vague error message.
-
- // Check that the input and result matrices have the same size
- auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
- if (inputMatrix.getNumElements() != resultMatrix.getNumElements())
- return op.emitError("input and result matrices must have "
- "the same number of columns");
-
- if (auto resultMatrixColumns =
- resultMatrix.getElementType().dyn_cast<VectorType>()) {
- // Check that the input and result matrices' columns have the same type
- if (inputMatrixColumns.getElementType() !=
- resultMatrixColumns.getElementType())
- return op.emitError("input and result matrices' columns must "
- "have the same component type");
-
- // Check that the input and result matrices' columns have the same size
- if (inputMatrixColumns.getNumElements() !=
- resultMatrixColumns.getNumElements())
- return op.emitError("input and result matrices' columns must "
- "have the same size");
- }
- }
+ auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
+
+ // Check that the scalar type is the same as the matrix element type.
+ if (op.scalar().getType() != inputMatrix.getElementType())
+ return op.emitError("input matrix components' type and scaling value must "
+ "have the same type");
+
+ // Note that the next three checks could be done using the AllTypesMatch
+ // trait in the Op definition file but it generates a vague error message.
+
+ // Check that the input and result matrices have the same columns' count
+ if (inputMatrix.getNumColumns() != resultMatrix.getNumColumns())
+ return op.emitError("input and result matrices must have the same "
+ "number of columns");
+
+ // Check that the input and result matrices' have the same rows count
+ if (inputMatrix.getNumRows() != resultMatrix.getNumRows())
+ return op.emitError("input and result matrices' columns must have "
+ "the same size");
+
+ // Check that the input and result matrices' have the same component type
+ if (inputMatrix.getElementType() != resultMatrix.getElementType())
+ return op.emitError("input and result matrices' columns must have "
+ "the same component type");
return success();
}
auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
// Verify that the input and output matrices have correct shapes.
- if (auto inputMatrixColumns =
- inputMatrix.getElementType().dyn_cast<VectorType>()) {
- if (inputMatrixColumns.getNumElements() != resultMatrix.getNumElements())
- return op.emitError("input matrix rows count must be equal to "
- "output matrix columns count");
- if (auto resultMatrixColumns =
- resultMatrix.getElementType().dyn_cast<VectorType>()) {
- if (resultMatrixColumns.getNumElements() != inputMatrix.getNumElements())
- return op.emitError("input matrix columns count must be equal "
- "to output matrix rows count");
-
- // Verify that the input and output matrices have the same component type
- if (inputMatrixColumns.getElementType() !=
- resultMatrixColumns.getElementType())
- return op.emitError("input and output matrices must have the "
- "same component type");
- }
- }
+ if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
+ return op.emitError("input matrix rows count must be equal to "
+ "output matrix columns count");
+
+ if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
+ return op.emitError("input matrix columns count must be equal to "
+ "output matrix rows count");
+
+ // Verify that the input and output matrices have the same component type
+ if (inputMatrix.getElementType() != resultMatrix.getElementType())
+ return op.emitError("input and output matrices must have the same "
+ "component type");
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spv.MatrixTimesMatrix
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyMatrixTimesMatrix(spirv::MatrixTimesMatrixOp op) {
+ auto leftMatrix = op.leftmatrix().getType().cast<spirv::MatrixType>();
+ auto rightMatrix = op.rightmatrix().getType().cast<spirv::MatrixType>();
+ auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
+
+ // left matrix columns' count and right matrix rows' count must be equal
+ if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
+ return op.emitError("left matrix columns' count must be equal to "
+ "the right matrix rows' count");
+
+ // right and result matrices columns' count must be the same
+ if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
+ return op.emitError(
+ "right and result matrices must have equal columns' count");
+
+ // right and result matrices component type must be the same
+ if (rightMatrix.getElementType() != resultMatrix.getElementType())
+ return op.emitError("right and result matrices' component type must"
+ " be the same");
+
+ // left and result matrices component type must be the same
+ if (leftMatrix.getElementType() != resultMatrix.getElementType())
+ return op.emitError("left and result matrices' component type"
+ " must be the same");
+
+ // left and result matrices rows count must be the same
+ if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
+ return op.emitError("left and result matrices must have equal rows'"
+ " count");
+
return success();
}
case spirv::TypeKind::CooperativeMatrix:
return cast<CooperativeMatrixNVType>().getElementType();
case spirv::TypeKind::Matrix:
- return cast<MatrixType>().getElementType();
+ return cast<MatrixType>().getColumnType();
case spirv::TypeKind::RuntimeArray:
return cast<RuntimeArrayType>().getElementType();
case spirv::TypeKind::Struct:
llvm_unreachable(
"invalid to query number of elements of spirv::CooperativeMatrix type");
case spirv::TypeKind::Matrix:
- return cast<MatrixType>().getNumElements();
+ return cast<MatrixType>().getNumColumns();
case spirv::TypeKind::RuntimeArray:
llvm_unreachable(
"invalid to query number of elements of spirv::RuntimeArray type");
return false;
}
-Type MatrixType::getElementType() const { return getImpl()->columnType; }
+Type MatrixType::getColumnType() const { return getImpl()->columnType; }
-unsigned MatrixType::getNumElements() const { return getImpl()->columnCount; }
+Type MatrixType::getElementType() const {
+ return getImpl()->columnType.cast<VectorType>().getElementType();
+}
+
+unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
+
+unsigned MatrixType::getNumRows() const {
+ return getImpl()->columnType.cast<VectorType>().getShape()[0];
+}
+
+unsigned MatrixType::getNumElements() const {
+ return (getImpl()->columnCount) * getNumRows();
+}
void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage) {
- getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
+ getColumnType().cast<SPIRVType>().getExtensions(extensions, storage);
}
void MatrixType::getCapabilities(
capabilities.push_back(ref);
}
// Add any capabilities associated with the underlying vectors (i.e., columns)
- getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
+ getColumnType().cast<SPIRVType>().getCapabilities(capabilities, storage);
}
if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) {
uint32_t elementTypeID = 0;
- if (failed(processType(loc, matrixType.getElementType(), elementTypeID))) {
+ if (failed(processType(loc, matrixType.getColumnType(), elementTypeID))) {
return failure();
}
typeEnum = spirv::Opcode::OpTypeMatrix;
operands.push_back(elementTypeID);
- operands.push_back(matrixType.getNumElements());
+ operands.push_back(matrixType.getNumColumns());
return success();
}
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>>
spv.ReturnValue %result : !spv.matrix<2 x vector<3xf32>>
}
+
+ // CHECK-LABEL: @matrix_times_matrix_1
+ spv.func @matrix_times_matrix_1(%arg0: !spv.matrix<3 x vector<3xf32>>, %arg1: !spv.matrix<3 x vector<3xf32>>) -> !spv.matrix<3 x vector<3xf32>> "None"{
+ // CHECK: {{%.*}} = spv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
+ %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
+ spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
+ }
+
+ // CHECK-LABEL: @matrix_times_matrix_2
+ spv.func @matrix_times_matrix_2(%arg0: !spv.matrix<3 x vector<2xf32>>, %arg1: !spv.matrix<2 x vector<3xf32>>) -> !spv.matrix<2 x vector<2xf32>> "None"{
+ // CHECK: {{%.*}} = spv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<2xf32>>
+ %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<2xf32>>
+ spv.ReturnValue %result : !spv.matrix<2 x vector<2xf32>>
+ }
}
// -----
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
}
+
+ // CHECK-LABEL: @matrix_times_matrix_1
+ spv.func @matrix_times_matrix_1(%arg0: !spv.matrix<3 x vector<3xf32>>, %arg1: !spv.matrix<3 x vector<3xf32>>) -> !spv.matrix<3 x vector<3xf32>> "None"{
+ // CHECK: {{%.*}} = spv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
+ %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
+ spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
+ }
+
+ // CHECK-LABEL: @matrix_times_matrix_2
+ spv.func @matrix_times_matrix_2(%arg0: !spv.matrix<3 x vector<2xf32>>, %arg1: !spv.matrix<2 x vector<3xf32>>) -> !spv.matrix<2 x vector<2xf32>> "None"{
+ // CHECK: {{%.*}} = spv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<2xf32>>
+ %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<2xf32>>
+ spv.ReturnValue %result : !spv.matrix<2 x vector<2xf32>>
+ }
}
// -----
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<4xf32>> -> !spv.matrix<4 x vector<3xf16>>
spv.Return
}
+
+// -----
+
+func @matrix_times_matrix_invalid_input_shape_1(%arg0 : !spv.matrix<3 x vector<2xf32>>, %arg1 : !spv.matrix<2 x vector<3xf32>>){
+ // expected-error @+1 {{right and result matrices must have equal columns' count}}
+ %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<3 x vector<2xf32>>
+}
+
+// -----
+
+func @matrix_times_matrix_invalid_input_shape_2(%arg0 : !spv.matrix<3 x vector<2xf32>>, %arg1 : !spv.matrix<2 x vector<3xf32>>){
+ // expected-error @+1 {{left and result matrices must have equal rows' count}}
+ %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<3xf32>>
+}
+
+// -----
+
+func @matrix_times_matrix_inputs_shape_mismatch(%arg0 : !spv.matrix<3 x vector<2xf32>>, %arg1 : !spv.matrix<2 x vector<2xf32>>){
+ // expected-error @+1 {{left matrix columns' count must be equal to the right matrix rows' count}}
+ %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<2xf32>> -> !spv.matrix<2 x vector<2xf32>>
+}
+
+// -----
+
+func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : !spv.matrix<3x vector<3xf32>>){
+ // expected-error @+1 {{right and result matrices' component type must be the same}}
+ %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf64>>
+}
+
+
+// -----
+
+func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spv.matrix<3 x vector<3xf64>>, %arg1 : !spv.matrix<3x vector<3xf32>>){
+ // expected-error @+1 {{left and result matrices' component type must be the same}}
+ %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<3xf64>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
+}