auto inputTy = getInput().getType().dyn_cast<RankedTensorType>();
auto outputTy = getType().dyn_cast<RankedTensorType>();
- if (!inputTy || !outputTy || inputTy != outputTy)
+ if (!inputTy || !outputTy)
return {};
- if (inputTy.hasStaticShape())
+
+ if (inputTy == outputTy && inputTy.hasStaticShape())
return getInput();
+ if (!operands[0])
+ return {};
+
+ auto operand = operands[0].cast<ElementsAttr>();
+ if (operand.isSplat() && outputTy.hasStaticShape()) {
+ return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
+ }
+
+ if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
+ outputTy.getNumElements() == 1) {
+ llvm::SmallVector<uint64_t> indices;
+ for (auto val : getStart()) {
+ indices.push_back(val.cast<IntegerAttr>().getInt());
+ }
+ auto value = operand.getValues<Attribute>()[indices];
+ return SplatElementsAttr::get(outputTy, value);
+ }
+
return {};
}
// CHECK: return %[[THREE]]
return %add : tensor<10xf32>
}
+
+// -----
+
+// CHECK-LABEL: @slice_splat
+func.func @slice_splat() -> tensor<1x1x1xi32> {
+ // CHECK: %[[SLICE:.+]] = "tosa.const"() {value = dense<42> : tensor<1x1x1xi32>}
+ %splat = "tosa.const"() {value = dense<42> : tensor<4x5x6xi32>} : () -> tensor<4x5x6xi32>
+ %slice = "tosa.slice"(%splat) { size = [1, 1, 1], start = [1, 2, 3] } : (tensor<4x5x6xi32>) -> tensor<1x1x1xi32>
+ // CHECK: return %[[SLICE]]
+ return %slice : tensor<1x1x1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_singleton
+func.func @slice_singleton() -> tensor<1x1xi32> {
+ %splat = "tosa.const"() {value = dense<[[0, 1, 2], [3, 4, 5], [6, 7 ,8]]> : tensor<3x3xi32>} : () -> tensor<3x3xi32>
+ // CHECK: %[[SLICE:.+]] = "tosa.const"() {value = dense<4> : tensor<1x1xi32>}
+ %slice = "tosa.slice"(%splat) { size = [1, 1], start = [1, 1] } : (tensor<3x3xi32>) -> tensor<1x1xi32>
+ // CHECK: return %[[SLICE]]
+ return %slice : tensor<1x1xi32>
+}