//===----------------------------------------------------------------------===//
OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 1 && "spv.CompositeExtract expects one operand");
+ if (auto insertOp = composite().getDefiningOp<spirv::CompositeInsertOp>()) {
+ if (indices() == insertOp.indices())
+ return insertOp.object();
+ }
+
+ if (auto constructOp =
+ composite().getDefiningOp<spirv::CompositeConstructOp>()) {
+ auto type = constructOp.getType().cast<spirv::CompositeType>();
+ if (indices().size() == 1 &&
+ constructOp.constituents().size() == type.getNumElements()) {
+ auto i = indices().begin()->cast<IntegerAttr>();
+ return constructOp.constituents()[i.getValue().getSExtValue()];
+ }
+ }
+
auto indexVector =
llvm::to_vector<8>(llvm::map_range(indices(), [](Attribute attr) {
return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());
// -----
+// CHECK-LABEL: extract_insert
+// CHECK-SAME: (%[[COMP:.+]]: !spv.array<1 x vector<2xf32>>, %[[VAL:.+]]: f32)
+func.func @extract_insert(%composite: !spv.array<1xvector<2xf32>>, %val: f32) -> (f32, f32) {
+ // CHECK: %[[INSERT:.+]] = spv.CompositeInsert %[[VAL]], %[[COMP]]
+ %insert = spv.CompositeInsert %val, %composite[0 : i32, 1 : i32] : f32 into !spv.array<1xvector<2xf32>>
+ %1 = spv.CompositeExtract %insert[0 : i32, 0 : i32] : !spv.array<1xvector<2xf32>>
+ // CHECK: %[[S:.+]] = spv.CompositeExtract %[[INSERT]][0 : i32, 0 : i32]
+ %2 = spv.CompositeExtract %insert[0 : i32, 1 : i32] : !spv.array<1xvector<2xf32>>
+ // CHECK: return %[[S]], %[[VAL]]
+ return %1, %2 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: extract_construct
+// CHECK-SAME: (%[[VAL1:.+]]: vector<2xf32>, %[[VAL2:.+]]: vector<2xf32>)
+func.func @extract_construct(%val1: vector<2xf32>, %val2: vector<2xf32>) -> (vector<2xf32>, vector<2xf32>) {
+ %construct = spv.CompositeConstruct %val1, %val2 : (vector<2xf32>, vector<2xf32>) -> !spv.array<2xvector<2xf32>>
+ %1 = spv.CompositeExtract %construct[0 : i32] : !spv.array<2xvector<2xf32>>
+ %2 = spv.CompositeExtract %construct[1 : i32] : !spv.array<2xvector<2xf32>>
+ // CHECK: return %[[VAL1]], %[[VAL2]]
+ return %1, %2 : vector<2xf32>, vector<2xf32>
+}
+
+// -----
+
+// Not yet implemented case
+
+// CHECK-LABEL: extract_construct
+func.func @extract_construct(%val1: vector<3xf32>, %val2: f32) -> (f32, f32) {
+ // CHECK: spv.CompositeConstruct
+ %construct = spv.CompositeConstruct %val1, %val2 : (vector<3xf32>, f32) -> vector<4xf32>
+ // CHECK: spv.CompositeExtract
+ %1 = spv.CompositeExtract %construct[0 : i32] : vector<4xf32>
+ // CHECK: spv.CompositeExtract
+ %2 = spv.CompositeExtract %construct[1 : i32] : vector<4xf32>
+ return %1, %2 : f32, f32
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spv.Constant
//===----------------------------------------------------------------------===//