This revision starts making concrete use of 0-d vectors to extend the semantics of
ExtractElementOp.
In the process a new VectorOfAnyRank Tablegen OpBase.td is added to allow progressive transition to supporting 0-d vectors by gradually opting in.
Differential Revision: https://reviews.llvm.org/D114387
TypesMatchWith<"result type matches element type of vector operand",
"vector", "result",
"$_self.cast<ShapedType>().getElementType()">]>,
- Arguments<(ins AnyVector:$vector, AnySignlessIntegerOrIndex:$position)>,
+ Arguments<(ins AnyVectorOfAnyRank:$vector,
+ Optional<AnySignlessIntegerOrIndex>:$position)>,
Results<(outs AnyType:$result)> {
let summary = "extractelement operation";
let description = [{
- Takes an 1-D vector and a dynamic index position and extracts the
- scalar at that position. Note that this instruction resembles
- vector.extract, but is restricted to 1-D vectors and relaxed
- to dynamic indices. It is meant to be closer to LLVM's version:
+ Takes a 0-D or 1-D vector and a optional dynamic index position and
+ extracts the scalar at that position.
+
+ Note that this instruction resembles vector.extract, but is restricted to
+ 0-D and 1-D vectors and relaxed to dynamic indices.
+ If the vector is 0-D, the position must be llvm::None.
+
+
+ It is meant to be closer to LLVM's version:
https://llvm.org/docs/LangRef.html#extractelement-instruction
Example:
```mlir
%c = arith.constant 15 : i32
%1 = vector.extractelement %0[%c : i32]: vector<16xf32>
+ %2 = vector.extractelement %z[]: vector<f32>
```
}];
let assemblyFormat = [{
- $vector `[` $position `:` type($position) `]` attr-dict `:` type($vector)
+ $vector `[` ($position^ `:` type($position))? `]` attr-dict `:` type($vector)
}];
let builders = [
- OpBuilder<(ins "Value":$source, "Value":$position)>
+ // 0-D builder.
+ OpBuilder<(ins "Value":$source)>,
+ // 1-D + position builder.
+ OpBuilder<(ins "Value":$source, "Value":$position)>,
];
let extraClassDeclaration = [{
VectorType getVectorType() {
//===----------------------------------------------------------------------===//
// Whether a type is a VectorType.
-def IsVectorTypePred : CPred<"$_self.isa<::mlir::VectorType>()">;
+// Explicitly disallow 0-D vectors for now until we have good enough coverage.
+def IsVectorTypePred : And<[CPred<"$_self.isa<::mlir::VectorType>()">,
+ CPred<"$_self.cast<::mlir::VectorType>().getRank() > 0">]>;
+
+// Temporary vector type clone that allows gradual transition to 0-D vectors.
+def IsVectorOfAnyRankTypePred : CPred<"$_self.isa<::mlir::VectorType>()">;
// Whether a type is a TensorType.
def IsTensorTypePred : CPred<"$_self.isa<::mlir::TensorType>()">;
class VectorOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsVectorTypePred, "vector",
"::mlir::VectorType">;
+// Temporary vector type clone that allows gradual transition to 0-D vectors.
+class VectorOfAnyRankOf<list<Type> allowedTypes> :
+ ShapedContainerType<allowedTypes, IsVectorOfAnyRankTypePred, "vector",
+ "::mlir::VectorType">;
// Whether the number of elements of a vector is from the given
// `allowedRanks` list
"::mlir::VectorType">;
def AnyVector : VectorOf<[AnyType]>;
+// Temporary vector type clone that allows gradual transition to 0-D vectors.
+def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
// Shaped types.
return LLVM::LLVMPointerType::get(elementType, type.getMemorySpaceAsInt());
}
-/// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type
-/// when n > 1. For example, `vector<4 x f32>` remains as is while,
-/// `vector<4x8x16xf32>` converts to `!llvm.array<4xarray<8 x vector<16xf32>>>`.
+/// Convert an n-D vector type to an LLVM vector type:
+/// * 0-D `vector<T>` are converted to vector<1xT>
+/// * 1-D `vector<axT>` remains as is while,
+/// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
+/// `!llvm.array<ax...array<jxvector<kxT>>>`.
Type LLVMTypeConverter::convertVectorType(VectorType type) {
auto elementType = convertType(type.getElementType());
if (!elementType)
return {};
+ if (type.getShape().empty())
+ return VectorType::get({1}, elementType);
Type vectorType = VectorType::get(type.getShape().back(), elementType);
assert(LLVM::isCompatibleVectorType(vectorType) &&
"expected vector type compatible with the LLVM dialect");
LLVMTypeConverter &typeConverter, Location loc,
Value val1, Value val2, Type llvmType, int64_t rank,
int64_t pos) {
+ assert(rank > 0 && "0-D vector corner case should have been handled already");
if (rank == 1) {
auto idxType = rewriter.getIndexType();
auto constant = rewriter.create<LLVM::ConstantOp>(
static Value extractOne(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter, Location loc,
Value val, Type llvmType, int64_t rank, int64_t pos) {
+ assert(rank > 0 && "0-D vector corner case should have been handled already");
if (rank == 1) {
auto idxType = rewriter.getIndexType();
auto constant = rewriter.create<LLVM::ConstantOp>(
if (!llvmType)
return failure();
+ if (vectorType.getRank() == 0) {
+ Location loc = extractEltOp.getLoc();
+ auto idxType = rewriter.getIndexType();
+ auto zero = rewriter.create<LLVM::ConstantOp>(
+ loc, typeConverter->convertType(idxType),
+ rewriter.getIntegerAttr(idxType, 0));
+ rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
+ extractEltOp, llvmType, adaptor.vector(), zero);
+ return success();
+ }
+
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
extractEltOp, llvmType, adaptor.vector(), adaptor.position());
return success();
//===----------------------------------------------------------------------===//
void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
+ Value source) {
+ result.addOperands({source});
+ result.addTypes(source.getType().cast<VectorType>().getElementType());
+}
+
+void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
Value source, Value position) {
result.addOperands({source, position});
result.addTypes(source.getType().cast<VectorType>().getElementType());
static LogicalResult verify(vector::ExtractElementOp op) {
VectorType vectorType = op.getVectorType();
+ if (vectorType.getRank() == 0) {
+ if (op.position())
+ return op.emitOpError("expected position to be empty with 0-D vector");
+ return success();
+ }
if (vectorType.getRank() != 1)
- return op.emitOpError("expected 1-D vector");
+ return op.emitOpError("unexpected >1 vector rank");
+ if (!op.position())
+ return op.emitOpError("expected position for 1-D vector");
return success();
}
// -----
+// CHECK-LABEL: @extract_element_0d
+func @extract_element_0d(%a: vector<f32>) -> f32 {
+ // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+ // CHECK: llvm.extractelement %{{.*}}[%[[C0]] : {{.*}}] : vector<1xf32>
+ %1 = vector.extractelement %a[] : vector<f32>
+ return %1 : f32
+}
+
+// -----
+
func @extract_element(%arg0: vector<16xf32>) -> f32 {
%0 = arith.constant 15 : i32
%1 = vector.extractelement %arg0[%0 : i32]: vector<16xf32>
// -----
+func @extract_element(%arg0: vector<f32>) {
+ %c = arith.constant 3 : i32
+ // expected-error@+1 {{expected position to be empty with 0-D vector}}
+ %1 = vector.extractelement %arg0[%c : i32] : vector<f32>
+}
+
+// -----
+
+func @extract_element(%arg0: vector<4xf32>) {
+ %c = arith.constant 3 : i32
+ // expected-error@+1 {{expected position for 1-D vector}}
+ %1 = vector.extractelement %arg0[] : vector<4xf32>
+}
+
+// -----
+
func @extract_element(%arg0: vector<4x4xf32>) {
%c = arith.constant 3 : i32
- // expected-error@+1 {{'vector.extractelement' op expected 1-D vector}}
+ // expected-error@+1 {{unexpected >1 vector rank}}
%1 = vector.extractelement %arg0[%c : i32] : vector<4x4xf32>
}
return %1 : vector<3x4xf32>
}
+// CHECK-LABEL: @extract_element_0d
+func @extract_element_0d(%a: vector<f32>) -> f32 {
+ // CHECK-NEXT: vector.extractelement %{{.*}}[] : vector<f32>
+ %1 = vector.extractelement %a[] : vector<f32>
+ return %1 : f32
+}
+
// CHECK-LABEL: @extract_element
func @extract_element(%a: vector<16xf32>) -> f32 {
// CHECK: %[[C15:.*]] = arith.constant 15 : i32
--- /dev/null
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @extract_element_0d(%a: vector<f32>) {
+ %1 = vector.extractelement %a[] : vector<f32>
+ // CHECK: 42
+ vector.print %1: f32
+ return
+}
+
+func @entry() {
+ %1 = arith.constant dense<42.0> : vector<f32>
+ call @extract_element_0d(%1) : (vector<f32>) -> ()
+ return
+}