// a. Get the first max ranked shape.
VectorType firstMaxRankedType;
for (Value operand : op->getOperands()) {
- auto vecType = dyn_cast<VectorType>(bvm.lookup(operand).getType());
+ auto vecOperand = bvm.lookup(operand);
+ assert(vecOperand && "Vector operand couldn't be found");
+
+ auto vecType = dyn_cast<VectorType>(vecOperand.getType());
if (vecType && (!firstMaxRankedType ||
firstMaxRankedType.getRank() < vecType.getRank()))
firstMaxRankedType = vecType;
}
// b. Broadcast each op if needed.
- SmallVector<Value> vectorizedOperands;
+ SmallVector<Value> vecOperands;
for (Value scalarOperand : op->getOperands()) {
- Value vectorizedOperand = bvm.lookup(scalarOperand);
- auto vecType =
- VectorType::get(firstMaxRankedType.getShape(),
- getElementTypeOrSelf(vectorizedOperand.getType()),
- firstMaxRankedType.getNumScalableDims());
- vectorizedOperands.push_back(
- !firstMaxRankedType
- ? vectorizedOperand
- : broadcastIfNeeded(rewriter, vectorizedOperand, vecType));
+ Value vecOperand = bvm.lookup(scalarOperand);
+ assert(vecOperand && "Vector operand couldn't be found");
+
+ if (firstMaxRankedType) {
+ auto vecType = VectorType::get(firstMaxRankedType.getShape(),
+ getElementTypeOrSelf(vecOperand.getType()),
+ firstMaxRankedType.getNumScalableDims());
+ vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
+ } else {
+ vecOperands.push_back(vecOperand);
+ }
}
// c. For elementwise ops, the result is a vector with the firstMaxRankedType shape.
SmallVector<Type> resultTypes;
for (Type resultType : op->getResultTypes()) {
resultTypes.push_back(
- !firstMaxRankedType
- ? resultType
- : VectorType::get(firstMaxRankedType.getShape(), resultType,
- firstMaxRankedType.getNumScalableDims()));
+ firstMaxRankedType
+ ? VectorType::get(firstMaxRankedType.getShape(), resultType,
+ firstMaxRankedType.getNumScalableDims())
+ : resultType);
}
// d. Build and return the new op.
return VectorizationResult{
VectorizationStatus::NewOp,
- rewriter.create(op->getLoc(), op->getName().getIdentifier(),
- vectorizedOperands, resultTypes, op->getAttrs())};
+ rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
+ resultTypes, op->getAttrs())};
}
/// Generic vectorization function that rewrites the body of a `linalgOp` into
%1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize %1 { vectorize_padding } : (!transform.any_op) -> !transform.any_op
}
+
+// -----
+
+// Elementwise addition on 0-D (rank-0) tensors. Exercises the vectorizer
+// when every operand is scalar-shaped: all indexing maps are `() -> ()` and
+// there are no iterators, so no broadcast to a higher-ranked vector occurs.
+func.func @zero_dim_tensor(%input: tensor<f32>, %output: tensor<f32>) -> tensor<f32>
+{
+  %0 = linalg.generic { indexing_maps = [ affine_map<() -> ()>, affine_map<() -> ()> ],
+                        iterator_types = [] }
+        ins(%input : tensor<f32>)
+        outs(%output : tensor<f32>) {
+    ^bb0(%arg0: f32, %arg1: f32):
+      %2 = arith.addf %arg0, %arg1 : f32
+      linalg.yield %2 : f32
+    } -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// Driver: match the linalg.generic above, walk up to its isolated-from-above
+// parent (the func), and vectorize it; failures propagate so a vectorization
+// error fails the test rather than being silently ignored.
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op
+  %5 = transform.structured.vectorize %4 : (!transform.any_op) -> !transform.any_op
+}
+
+// CHECK-LABEL: func @zero_dim_tensor
+// CHECK: vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
+// CHECK: vector.extractelement
+// CHECK: vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
+// CHECK: vector.extractelement
+// CHECK: arith.addf {{.*}} : f32
+// CHECK: vector.broadcast %{{.*}} : f32 to vector<f32>
+// CHECK: vector.transfer_write {{.*}} : vector<f32>, tensor<f32>
+