def SPV_OC_OpCompositeConstruct : I32EnumAttrCase<"OpCompositeConstruct", 80>;
def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
def SPV_OC_OpCompositeInsert : I32EnumAttrCase<"OpCompositeInsert", 82>;
+def SPV_OC_OpTranspose : I32EnumAttrCase<"OpTranspose", 84>;
def SPV_OC_OpConvertFToU : I32EnumAttrCase<"OpConvertFToU", 109>;
def SPV_OC_OpConvertFToS : I32EnumAttrCase<"OpConvertFToS", 110>;
def SPV_OC_OpConvertSToF : I32EnumAttrCase<"OpConvertSToF", 111>;
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall,
SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain,
SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct,
- SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpConvertFToU,
- SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
- SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,
- SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
- SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
- SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
- SPV_OC_OpMatrixTimesScalar, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
- SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
- SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
- SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
- SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
- SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
- SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
- SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
+ SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose,
+ SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF,
+ SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert,
+ SPV_OC_OpBitcast, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd,
+ SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
+ SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
+ SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar,
+ SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr,
+ SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual,
+ SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
+ SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
+ SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
+ SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
+ SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
+ SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
return success();
}
+//===----------------------------------------------------------------------===//
+// spv.Transpose
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyTranspose(spirv::TransposeOp op) {
+ auto inputMatrix = op.matrix().getType().cast<spirv::MatrixType>();
+ auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
+
+ // Verify that the input and output matrices have correct shapes.
+ if (auto inputMatrixColumns =
+ inputMatrix.getElementType().dyn_cast<VectorType>()) {
+ if (inputMatrixColumns.getNumElements() != resultMatrix.getNumElements())
+ return op.emitError("input matrix rows count must be equal to "
+ "output matrix columns count");
+ if (auto resultMatrixColumns =
+ resultMatrix.getElementType().dyn_cast<VectorType>()) {
+ if (resultMatrixColumns.getNumElements() != inputMatrix.getNumElements())
+ return op.emitError("input matrix columns count must be equal "
+ "to output matrix rows count");
+
+ // Verify that the input and output matrices have the same component type
+ if (inputMatrixColumns.getElementType() !=
+ resultMatrixColumns.getElementType())
+ return op.emitError("input and output matrices must have the "
+ "same component type");
+ }
+ }
+ return success();
+}
+
namespace mlir {
namespace spirv {
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
// CHECK-LABEL: @matrix_times_scalar
- spv.func @matrix_times_scalar_1(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spv.matrix<3 x vector<3xf32>> "None" {
+ spv.func @matrix_times_scalar(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spv.matrix<3 x vector<3xf32>> "None" {
// CHECK: {{%.*}} = spv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>>
%result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>>
spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
}
+
+ // CHECK-LABEL: @matrix_transpose_1
+ spv.func @matrix_transpose_1(%arg0 : !spv.matrix<3 x vector<2xf32>>) -> !spv.matrix<2 x vector<3xf32>> "None" {
+ // CHECK: {{%.*}} = spv.Transpose {{%.*}} : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>>
+ %result = spv.Transpose %arg0 : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>>
+ spv.ReturnValue %result : !spv.matrix<2 x vector<3xf32>>
+ }
+
+ // CHECK-LABEL: @matrix_transpose_2
+ spv.func @matrix_transpose_2(%arg0 : !spv.matrix<3 x vector<3xf32>>) -> !spv.matrix<3 x vector<3xf32>> "None" {
+ // CHECK: {{%.*}} = spv.Transpose {{%.*}} : !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
+ %result = spv.Transpose %arg0 : !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
+ spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
+ }
}
// -----
%result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<4 x vector<3xf32>>
}
+// -----
+
+func @transpose_op_shape_mismatch_1(%arg0 : !spv.matrix<3 x vector<4xf32>>) -> () {
+ // expected-error @+1 {{input matrix rows count must be equal to output matrix columns count}}
+ %result = spv.Transpose %arg0 : !spv.matrix<3 x vector<4xf32>> -> !spv.matrix<3 x vector<3xf32>>
+ spv.Return
+}
+
+// -----
+func @transpose_op_shape_mismatch_2(%arg0 : !spv.matrix<3 x vector<4xf32>>) -> () {
+ // expected-error @+1 {{input matrix rows count must be equal to output matrix columns count}}
+ %result = spv.Transpose %arg0 : !spv.matrix<3 x vector<4xf32>> -> !spv.matrix<2 x vector<4xf32>>
+ spv.Return
+}
+// -----
+
+func @transpose_op_type_mismatch(%arg0 : !spv.matrix<3 x vector<4xf32>>) -> () {
+ // expected-error @+1 {{input and output matrices must have the same component type}}
+ %result = spv.Transpose %arg0 : !spv.matrix<3 x vector<4xf32>> -> !spv.matrix<4 x vector<3xf16>>
+ spv.Return
+}