[spirv] Add CompositeConstruct operation.
authorDenis Khalikov <khalikov.denis@huawei.com>
Mon, 9 Dec 2019 20:43:23 +0000 (12:43 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 9 Dec 2019 20:43:53 +0000 (12:43 -0800)
Closes tensorflow/mlir#308

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/308 from denis0x0D:sandbox/composite_construct 9ef7180f77f9374bcd05afc4f9e6c1d2d72d02b7
PiperOrigin-RevId: 284613617

mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir
mlir/test/Dialect/SPIRV/composite-ops.mlir

index 62095a5..8368a62 100644 (file)
@@ -1075,6 +1075,7 @@ def SPV_OC_OpStore                     : I32EnumAttrCase<"OpStore", 62>;
 def SPV_OC_OpAccessChain               : I32EnumAttrCase<"OpAccessChain", 65>;
 def SPV_OC_OpDecorate                  : I32EnumAttrCase<"OpDecorate", 71>;
 def SPV_OC_OpMemberDecorate            : I32EnumAttrCase<"OpMemberDecorate", 72>;
+def SPV_OC_OpCompositeConstruct        : I32EnumAttrCase<"OpCompositeConstruct", 80>;
 def SPV_OC_OpCompositeExtract          : I32EnumAttrCase<"OpCompositeExtract", 81>;
 def SPV_OC_OpCompositeInsert           : I32EnumAttrCase<"OpCompositeInsert", 82>;
 def SPV_OC_OpConvertFToU               : I32EnumAttrCase<"OpConvertFToU", 109>;
@@ -1171,20 +1172,21 @@ def SPV_OpcodeAttr :
       SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
       SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
       SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
-      SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert,
-      SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF,
-      SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert,
-      SPV_OC_OpBitcast, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd,
-      SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
-      SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
-      SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
-      SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
-      SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
-      SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
-      SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
-      SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
-      SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
-      SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
+      SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct,
+      SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpConvertFToU,
+      SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
+      SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,
+      SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
+      SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
+      SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
+      SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr,
+      SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual,
+      SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
+      SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
+      SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
+      SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
+      SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
+      SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
       SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
       SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
       SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
index 7165050..6392a1b 100644 (file)
 
 include "mlir/Dialect/SPIRV/SPIRVBase.td"
 
+// -----
+
+def SPV_CompositeConstructOp : SPV_Op<"CompositeConstruct", [NoSideEffect]> {
+  let summary = [{
+    Construct a new composite object from a set of constituent objects that
+    will fully form it.
+  }];
+
+  let description = [{
+    Result Type must be a composite type, whose top-level
+    members/elements/components/columns have the same type as the types of
+    the operands, with one exception. The exception is that for constructing
+    a vector, the operands may also be vectors with the same component type
+    as the Result Type component type. When constructing a vector, the total
+    number of components in all the operands must equal the number of
+    components in Result Type.
+
+    Constituents will become members of a structure, or elements of an
+    array, or components of a vector, or columns of a matrix. There must be
+    exactly one Constituent for each top-level
+    member/element/component/column of the result, with one exception. The
+    exception is that for constructing a vector, a contiguous subset of the
+    scalars consumed can be represented by a vector operand instead. The
+    Constituents must appear in the order needed by the definition of the
+    type of the result. When constructing a vector, there must be at least
+    two Constituent operands.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    composite-construct-op ::= ssa-id `=` `spv.CompositeConstruct`
+                               (ssa-use (`,` ssa-use)* )? `:` composite-type
+    ```
+
+    For example:
+
+    ```
+    %0 = spv.CompositeConstruct %1, %2, %3 : vector<3xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    Variadic<SPV_Type>:$constituents
+  );
+
+  let results = (outs
+    SPV_Composite:$result
+  );
+}
+
+// -----
+
 def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> {
   let summary = "Extract a part of a composite object.";
 
index 4009691..f1fc80b 100644 (file)
@@ -1070,6 +1070,73 @@ static LogicalResult verify(spirv::BranchConditionalOp branchOp) {
 }
 
 //===----------------------------------------------------------------------===//
+// spv.CompositeConstruct
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseCompositeConstructOp(OpAsmParser &parser,
+                                             OperationState &state) {
+  SmallVector<OpAsmParser::OperandType, 4> operands;
+  Type type;
+  auto loc = parser.getCurrentLocation();
+
+  if (parser.parseOperandList(operands) || parser.parseColonType(type)) {
+    return failure();
+  }
+  auto cType = type.dyn_cast<spirv::CompositeType>();
+  if (!cType) {
+    return parser.emitError(
+               loc, "result type must be a composite type, but provided ")
+           << type;
+  }
+
+  if (operands.size() != cType.getNumElements()) {
+    return parser.emitError(loc, "has incorrect number of operands: expected ")
+           << cType.getNumElements() << ", but provided " << operands.size();
+  }
+  // TODO: Add support for constructing a vector type from the vector operands.
+  // According to the spec: "for constructing a vector, the operands may
+  // also be vectors with the same component type as the Result Type component
+  // type".
+  SmallVector<Type, 4> elementTypes;
+  elementTypes.reserve(cType.getNumElements());
+  for (auto index : llvm::seq<uint32_t>(0, cType.getNumElements())) {
+    elementTypes.push_back(cType.getElementType(index));
+  }
+  state.addTypes(type);
+  return parser.resolveOperands(operands, elementTypes, loc, state.operands);
+}
+
+static void print(spirv::CompositeConstructOp compositeConstructOp,
+                  OpAsmPrinter &printer) {
+  printer << spirv::CompositeConstructOp::getOperationName() << " ";
+  printer.printOperands(compositeConstructOp.constituents());
+  printer << " : " << compositeConstructOp.getResult()->getType();
+}
+
+static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
+  auto cType = compositeConstructOp.getType().cast<spirv::CompositeType>();
+
+  SmallVector<Value *, 4> constituents(compositeConstructOp.constituents());
+  if (constituents.size() != cType.getNumElements()) {
+    return compositeConstructOp.emitError(
+               "has incorrect number of operands: expected ")
+           << cType.getNumElements() << ", but provided "
+           << constituents.size();
+  }
+
+  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
+    if (constituents[index]->getType() != cType.getElementType(index)) {
+      return compositeConstructOp.emitError(
+                 "operand type mismatch: expected operand type ")
+             << cType.getElementType(index) << ", but provided "
+             << constituents[index]->getType();
+    }
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // spv.CompositeExtractOp
 //===----------------------------------------------------------------------===//
 
index a3f74ca..ba01cc8 100644 (file)
@@ -2,8 +2,13 @@
 
 spv.module "Logical" "GLSL450" {
   func @composite_insert(%arg0 : !spv.struct<f32, !spv.struct<!spv.array<4xf32>, f32>>, %arg1: !spv.array<4xf32>) -> !spv.struct<f32, !spv.struct<!spv.array<4xf32>, f32>> {
-    // CHECK: {{%.*}} = spv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32, 0 : i32] : !spv.array<4 x f32> into !spv.struct<f32, !spv.struct<!spv.array<4 x f32>, f32>>
+    // CHECK: spv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32, 0 : i32] : !spv.array<4 x f32> into !spv.struct<f32, !spv.struct<!spv.array<4 x f32>, f32>>
     %0 = spv.CompositeInsert %arg1, %arg0[1 : i32, 0 : i32] : !spv.array<4xf32> into !spv.struct<f32, !spv.struct<!spv.array<4xf32>, f32>>
     spv.ReturnValue %0: !spv.struct<f32, !spv.struct<!spv.array<4xf32>, f32>>
   }
+  func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
+    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : vector<3xf32>
+    %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : vector<3xf32>
+    spv.ReturnValue %0: vector<3xf32>
+  }
 }
index 353080c..4ce8974 100644 (file)
@@ -1,6 +1,58 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
 
 //===----------------------------------------------------------------------===//
+// spv.CompositeConstruct
+//===----------------------------------------------------------------------===//
+
+func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
+  // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : vector<3xf32>
+  %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : vector<3xf32>
+  return %0: vector<3xf32>
+}
+
+// -----
+
+func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spv.array<4xf32>, %arg2 : !spv.struct<f32>) -> !spv.struct<vector<3xf32>, !spv.array<4xf32>, !spv.struct<f32>> {
+  // CHECK: spv.CompositeConstruct %arg0, %arg1, %arg2 : !spv.struct<vector<3xf32>, !spv.array<4 x f32>, !spv.struct<f32>>
+  %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : !spv.struct<vector<3xf32>, !spv.array<4xf32>, !spv.struct<f32>>
+  return %0: !spv.struct<vector<3xf32>, !spv.array<4xf32>, !spv.struct<f32>>
+}
+
+// -----
+
+func @composite_construct_empty_struct() -> !spv.struct<> {
+  // CHECK: spv.CompositeConstruct : !spv.struct<>
+  %0 = spv.CompositeConstruct : !spv.struct<>
+  return %0: !spv.struct<>
+}
+
+// -----
+
+func @composite_construct_invalid_num_of_elements(%arg0: f32) -> f32 {
+  // expected-error @+1 {{result type must be a composite type, but provided 'f32'}}
+  %0 = spv.CompositeConstruct %arg0 : f32
+  return %0: f32
+}
+
+// -----
+
+func @composite_construct_invalid_result_type(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
+  // expected-error @+1 {{has incorrect number of operands: expected 3, but provided 2}}
+  %0 = spv.CompositeConstruct %arg0, %arg2 : vector<3xf32>
+  return %0: vector<3xf32>
+}
+
+// -----
+
+func @composite_construct_invalid_operand_type(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xi32> {
+  // expected-error @+1 {{operand type mismatch: expected operand type 'i32', but provided 'f32'}}
+  %0 = "spv.CompositeConstruct" (%arg0, %arg1, %arg2) : (f32, f32, f32) -> vector<3xi32>
+  return %0: vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
 // spv.CompositeExtractOp
 //===----------------------------------------------------------------------===//