//===----------------------------------------------------------------------===//
-// Reduce op.
+// Map op.
//===----------------------------------------------------------------------===//
def TensorOrMemref :
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
+def MapOp : LinalgStructuredBase_Op<"map", [
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
+ SingleBlockImplicitTerminator<"YieldOp">]> {
+ let summary = "Elementwise operations";
+ let description = [{
+ Models elementwise operations on tensors in terms of arithmetic operations
+ on the corresponding elements.
+
+ Example:
+ ```
+ %add = linalg.map
+ ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
+ outs(%init: tensor<64xf32>)
+ (%lhs_elem: f32, %rhs_elem: f32) {
+ %0 = arith.addf %lhs_elem, %rhs_elem: f32
+ linalg.yield %0: f32
+ }
+ ```
+ }];
+
+ let arguments = (ins
+ // Input args
+ Variadic<TensorOrMemref>:$inputs,
+
+ // Output arg
+ TensorOrMemref:$init
+ );
+ let results = (outs Variadic<AnyTensor>:$result);
+ let regions = (region SizedRegion<1>:$mapper);
+
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+ // Implement functions necessary for LinalgStructuredInterface.
+ ArrayAttr getIteratorTypes();
+ ArrayAttr getIndexingMaps();
+ std::string getLibraryCallName() {
+ return "op_has_no_registered_library_name";
+ }
+
+ // Implement functions necessary for DestinationStyleOpInterface.
+ unsigned getNumInputs() {
+ return this->getOperation()->getNumOperands() - getNumOutputs();
+ };
+ unsigned getNumOutputs() { return 1; };
+ mlir::ValueRange getOutputs() { return getOperands().take_back(1); }
+ linalg::OpOperandVector getOpOperandsMatchingBBargs() {
+ return getInputOperands();
+ }
+
+ static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+ mlir::ArrayRef<mlir::NamedAttribute>)>
+ getRegionBuilder() {
+ return nullptr;
+ }
+ }];
+
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+}
+
+
+//===----------------------------------------------------------------------===//
+// Reduce op.
+//===----------------------------------------------------------------------===//
+
def ReduceOp : LinalgStructuredBase_Op<"reduce", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
ConfinedAttr<DenseI64ArrayAttr,
[DenseArrayStrictlySorted<DenseI64ArrayAttr>]>:$dimensions
);
- let results = (outs Variadic<TensorOrMemref>);
+ let results = (outs Variadic<AnyTensor>);
let regions = (region SizedRegion<1>:$combiner);
let extraClassDeclaration = structuredOpsBaseDecls # [{
}
//===----------------------------------------------------------------------===//
+// MapOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseDstStyleOp(
+ OpAsmParser &parser, OperationState &result,
+ function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
+ nullptr) {
+ // Parse `ins` and `outs`.
+ SmallVector<Type, 4> inputTypes, outputTypes;
+ if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
+ /*addOperandSegmentSizes=*/false))
+ return failure();
+
+ // Add result types.
+ for (Type outputType : outputTypes) {
+ if (outputType.isa<RankedTensorType>())
+ result.addTypes(outputType);
+ }
+
+ // Parse required attributes.
+ if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
+ return failure();
+
+ // Parse optional attributes.
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+ return success();
+}
+
+void MapOp::getAsmBlockArgumentNames(Region ®ion,
+ OpAsmSetValueNameFn setNameFn) {
+ for (Value v : getRegionInputArgs())
+ setNameFn(v, "in");
+}
+
+void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
+ if (!getResults().empty())
+ setNameFn(getResults().front(), "mapped");
+}
+
+ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
+ if (parseDstStyleOp(parser, result))
+ return failure();
+
+ SmallVector<OpAsmParser::Argument> regionArgs;
+ if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
+ /*allowType=*/true, /*allowAttrs=*/true)) {
+ return failure();
+ }
+
+ Region *body = result.addRegion();
+ if (parser.parseRegion(*body, regionArgs))
+ return failure();
+
+ return success();
+}
+
+void MapOp::print(OpAsmPrinter &p) {
+ printCommonStructuredOpParts(p, getInputs(), getOutputs());
+ p.printOptionalAttrDict((*this)->getAttrs());
+
+ p << "(";
+ llvm::interleaveComma(getMapper().getArguments(), p,
+ [&](auto arg) { p.printRegionArgument(arg); });
+ p << ") ";
+
+ p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
+}
+
+LogicalResult MapOp::verify() {
+ auto *bodyBlock = getBody();
+ auto blockArgs = bodyBlock->getArguments();
+
+ // Checks if the number of `inputs` match the arity of the `mapper` region.
+ if (getInputs().size() != blockArgs.size())
+ return emitOpError() << "expects number of operands to match the arity of "
+ "mapper, but got: "
+ << getInputs().size() << " and " << blockArgs.size();
+
+ // The parameters of mapper should all match the element type // of inputs.
+ for (const auto &[bbArgType, inputArg] :
+ llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
+ auto inputElemType = inputArg.getType().cast<ShapedType>().getElementType();
+ if (bbArgType != inputElemType) {
+ return emitOpError() << "expected element type of input " << inputElemType
+ << " to match bbArg type " << bbArgType;
+ }
+ }
+
+ // The shape of each input must match the shape of the output.
+ auto outputShape =
+ getOutputs().front().getType().cast<ShapedType>().getShape();
+ for (Type inputArgType : TypeRange{getInputs()}) {
+ auto inputElemShape = inputArgType.cast<ShapedType>().getShape();
+ if (inputElemShape != outputShape) {
+ return emitOpError() << "expected shape of input (" << inputElemShape
+ << ") to match shape of output (" << outputShape
+ << ")";
+ }
+ }
+
+ return success();
+}
+
+ArrayAttr MapOp::getIteratorTypes() {
+ int64_t rank = getInit().getType().getRank();
+ return Builder(getContext())
+ .getStrArrayAttr(
+ SmallVector<StringRef>(rank, getParallelIteratorTypeName()));
+}
+
+ArrayAttr MapOp::getIndexingMaps() {
+ Builder builder(getContext());
+ int64_t rank = getInit().getType().getRank();
+ int64_t numIndexingMaps = getOperands().size();
+ return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
+ numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
+}
+
+void MapOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ SmallVector<Value> inputBuffers = getInputBufferOperands();
+ SmallVector<Value> outputBuffers = getOutputBufferOperands();
+ getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
+ outputBuffers);
+}
+
+//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//
void ReduceOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
- setNameFn(getResults().front(), "reduced");
+ if (!getResults().empty())
+ setNameFn(getResults().front(), "reduced");
}
ArrayAttr ReduceOp::getIteratorTypes() {
outputBuffers);
}
-static ParseResult parseDstStyleOp(
- OpAsmParser &parser, OperationState &result,
- function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
- nullptr) {
- // Parse `ins` and `outs`.
- SmallVector<Type, 4> inputTypes, outputTypes;
- if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
- /*addOperandSegmentSizes=*/false))
- return failure();
-
- // Add result types.
- for (Type outputType : outputTypes) {
- if (!outputType.isa<RankedTensorType>())
- return failure();
- result.addTypes(outputType);
- }
-
- // Parse required attributes.
- if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
- return failure();
-
- // Parse optional attributes.
- if (parser.parseOptionalAttrDict(result.attributes))
- return failure();
- return success();
-}
-
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
NamedAttrList &attributes,
StringRef attributeName) {
// -----
+func.func @map_binary_wrong_yield_operands(
+ %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
+ -> tensor<64xf32> {
+ %add = linalg.map
+ ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
+ outs(%init:tensor<64xf32>)
+ (%lhs_elem: f32, %rhs_elem: f32) {
+ %0 = arith.addf %lhs_elem, %rhs_elem: f32
+ // expected-error @+1{{'linalg.yield' op expected number of yield values (1) to match the number of operands of the enclosing LinalgOp (2)}}
+ linalg.yield %0, %0: f32, f32
+ }
+ func.return %add : tensor<64xf32>
+}
+
+// -----
+
+func.func @map_input_mapper_arity_mismatch(
+ %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
+ -> tensor<64xf32> {
+ // expected-error@+1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 2 and 3}}
+ %add = linalg.map
+ ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
+ outs(%init:tensor<64xf32>)
+ (%lhs_elem: f32, %rhs_elem: f32, %extra_elem: f32) {
+ %0 = arith.addf %lhs_elem, %rhs_elem: f32
+ linalg.yield %0: f32
+ }
+ func.return %add : tensor<64xf32>
+}
+
+// -----
+
+func.func @map_input_mapper_type_mismatch(
+ %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
+ -> tensor<64xf32> {
+ // expected-error@+1{{'linalg.map' op expected element type of input 'f32' to match bbArg type 'f64'}}
+ %add = linalg.map
+ ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
+ outs(%init:tensor<64xf32>)
+ (%lhs_elem: f64, %rhs_elem: f64) {
+ %0 = arith.addf %lhs_elem, %rhs_elem: f64
+ linalg.yield %0: f64
+ }
+ func.return %add : tensor<64xf32>
+}
+
+// -----
+
+func.func @map_input_output_shape_mismatch(
+ %lhs: tensor<64x64xf32>, %rhs: tensor<64x64xf32>, %init: tensor<32xf32>)
+ -> tensor<32xf32> {
+ // expected-error@+1{{'linalg.map' op expected shape of input (64, 64) to match shape of output (32)}}
+ %add = linalg.map
+ ins(%lhs, %rhs : tensor<64x64xf32>, tensor<64x64xf32>)
+ outs(%init:tensor<32xf32>)
+ (%lhs_elem: f32, %rhs_elem: f32) {
+ %0 = arith.addf %lhs_elem, %rhs_elem: f32
+ linalg.yield %0: f32
+ }
+ func.return %add : tensor<32xf32>
+}
+
+// -----
+
func.func @reduce_input_vs_init_dimension_mismatch(
%input: tensor<16x32x64xf32>,
%init: tensor<16x64xf32>) -> tensor<16x64xf32> {
// -----
+func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
+ %init: tensor<64xf32>) -> tensor<64xf32> {
+ %add = linalg.map
+ ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
+ outs(%init:tensor<64xf32>)
+ (%lhs_elem: f32, %rhs_elem: f32) {
+ %0 = arith.addf %lhs_elem, %rhs_elem: f32
+ linalg.yield %0: f32
+ }
+ func.return %add : tensor<64xf32>
+}
+// CHECK-LABEL: func @map_binary
+// CHECK: linalg.map
+
+// -----
+
+func.func @map_binary_memref(%lhs: memref<64xf32>, %rhs: memref<64xf32>,
+ %init: memref<64xf32>) {
+ linalg.map
+ ins(%lhs, %rhs: memref<64xf32>, memref<64xf32>)
+ outs(%init:memref<64xf32>)
+ (%lhs_elem: f32, %rhs_elem: f32) {
+ %0 = arith.addf %lhs_elem, %rhs_elem: f32
+ linalg.yield %0: f32
+ }
+ func.return
+}
+// CHECK-LABEL: func @map_binary_memref
+// CHECK: linalg.map
+
+// -----
+
+func.func @map_unary(%input: tensor<64xf32>, %init: tensor<64xf32>) -> tensor<64xf32> {
+ %abs = linalg.map
+ ins(%input:tensor<64xf32>)
+ outs(%init:tensor<64xf32>)
+ (%input_elem: f32) {
+ %0 = math.absf %input_elem: f32
+ linalg.yield %0: f32
+ }
+ func.return %abs : tensor<64xf32>
+}
+// CHECK-LABEL: func @map_unary
+// CHECK: linalg.map
+
+// -----
+
+func.func @map_unary_memref(%input: memref<64xf32>, %init: memref<64xf32>) {
+ linalg.map
+ ins(%input:memref<64xf32>)
+ outs(%init:memref<64xf32>)
+ (%input_elem: f32) {
+ %0 = math.absf %input_elem: f32
+ linalg.yield %0: f32
+ }
+ func.return
+}
+// CHECK-LABEL: func @map_unary_memref
+// CHECK: linalg.map
+
+// -----
+
func.func @reduce(%input: tensor<16x32x64xf32>,
- %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
+ %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
%reduce = linalg.reduce
ins(%input:tensor<16x32x64xf32>)
outs(%init:tensor<16x64xf32>)
// -----
+func.func @reduce_memref(%input: memref<16x32x64xf32>,
+ %init: memref<16x64xf32>) {
+ linalg.reduce
+ ins(%input:memref<16x32x64xf32>)
+ outs(%init:memref<16x64xf32>)
+ dimensions = [1]
+ (%in: f32, %out: f32) {
+ %0 = arith.addf %in, %out: f32
+ linalg.yield %0: f32
+ }
+ func.return
+}
+// CHECK-LABEL: func @reduce_memref
+// CHECK: linalg.reduce
+
+// -----
+
func.func @variadic_reduce(%input1: tensor<16x32x64xf32>,
%init1: tensor<16x64xf32>, %input2: tensor<16x32x64xi64>,
%init2: tensor<16x64xi64>) -> (tensor<16x64xf32>, tensor<16x64xi64>) {
}
// CHECK-LABEL: func @variadic_reduce
// CHECK: linalg.reduce
+
+// -----
+
+func.func @variadic_reduce_memref(%input1: memref<16x32x64xf32>,
+ %init1: memref<16x64xf32>, %input2: memref<16x32x64xi64>,
+ %init2: memref<16x64xi64>) {
+ linalg.reduce
+ ins(%input1, %input2 : memref<16x32x64xf32>, memref<16x32x64xi64>)
+ outs(%init1, %init2 : memref<16x64xf32>, memref<16x64xi64>)
+ dimensions = [1]
+ (%in1: f32, %in2: i64, %out1: f32, %out2: i64) {
+ %0 = arith.addf %in1, %out1: f32
+ %1 = arith.addi %in2, %out2: i64
+ linalg.yield %0, %1: f32, i64
+ }
+ func.return
+}
+// CHECK-LABEL: func @variadic_reduce_memref
+// CHECK: linalg.reduce