```mlir
%a = fir.undefined !fir.array<10x10xf32>
%c = arith.constant 3.0 : f32
- %1 = fir.insert_on_range %a, %c, [0 : index, 7 : index, 0 : index, 2 : index] : (!fir.array<10x10xf32>, f32) -> !fir.array<10x10xf32>
+ %1 = fir.insert_on_range %a, %c from (0, 0) to (7, 2) : (!fir.array<10x10xf32>, f32) -> !fir.array<10x10xf32>
```
The first 28 elements of %1, with coordinates from (0,0) to (7,2), have
the value 3.0.
}];
- let arguments = (ins fir_SequenceType:$seq, AnyType:$val, ArrayAttr:$coor);
+ let arguments = (ins fir_SequenceType:$seq, AnyType:$val, IndexElementsAttr:$coor);
let results = (outs fir_SequenceType);
let assemblyFormat = [{
- $seq `,` $val `,` $coor attr-dict `:` functional-type(operands, results)
+ $seq `,` $val custom<CustomRangeSubscript>($coor) attr-dict `:` functional-type(operands, results)
}];
let verifier = "return ::verify(*this);";
return success();
}
- bool isFullRange(mlir::ArrayAttr indexes, fir::SequenceType seqTy) const {
+ bool isFullRange(mlir::DenseIntElementsAttr indexes,
+ fir::SequenceType seqTy) const {
auto extents = seqTy.getShape();
- if (indexes.size() / 2 != extents.size())
+ if (indexes.size() / 2 != static_cast<int64_t>(extents.size()))
return false;
+ auto cur_index = indexes.value_begin<int64_t>();
for (unsigned i = 0; i < indexes.size(); i += 2) {
- if (indexes[i].cast<IntegerAttr>().getInt() != 0)
+ if (*(cur_index++) != 0)
return false;
- if (indexes[i + 1].cast<IntegerAttr>().getInt() != extents[i / 2] - 1)
+ if (*(cur_index++) != extents[i / 2] - 1)
return false;
}
return true;
SmallVector<uint64_t> lBounds;
SmallVector<uint64_t> uBounds;
- // Extract integer value from the attribute
- SmallVector<int64_t> coordinates = llvm::to_vector<4>(
- llvm::map_range(range.coor(), [](Attribute a) -> int64_t {
- return a.cast<IntegerAttr>().getInt();
- }));
-
// Unzip the upper and lower bound and convert to a row major format.
- for (auto i = coordinates.rbegin(), e = coordinates.rend(); i != e; ++i) {
+ mlir::DenseIntElementsAttr coor = range.coor();
+ auto reversedCoor = llvm::reverse(coor.getValues<int64_t>());
+ for (auto i = reversedCoor.begin(), e = reversedCoor.end(); i != e; ++i) {
uBounds.push_back(*i++);
lBounds.push_back(*i);
}
#include "flang/Optimizer/Support/Utils.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
// InsertOnRangeOp
//===----------------------------------------------------------------------===//
+static ParseResult
+parseCustomRangeSubscript(mlir::OpAsmParser &parser,
+ mlir::DenseIntElementsAttr &coord) {
+ llvm::SmallVector<int64_t> lbounds;
+ llvm::SmallVector<int64_t> ubounds;
+ if (parser.parseKeyword("from") ||
+ parser.parseCommaSeparatedList(
+ AsmParser::Delimiter::Paren,
+ [&] { return parser.parseInteger(lbounds.emplace_back(0)); }) ||
+ parser.parseKeyword("to") ||
+ parser.parseCommaSeparatedList(AsmParser::Delimiter::Paren, [&] {
+ return parser.parseInteger(ubounds.emplace_back(0));
+ }))
+ return failure();
+ llvm::SmallVector<int64_t> zippedBounds;
+ for (auto zip : llvm::zip(lbounds, ubounds)) {
+ zippedBounds.push_back(std::get<0>(zip));
+ zippedBounds.push_back(std::get<1>(zip));
+ }
+ coord = mlir::Builder(parser.getContext()).getIndexTensorAttr(zippedBounds);
+ return success();
+}
+
+void printCustomRangeSubscript(mlir::OpAsmPrinter &printer, InsertOnRangeOp op,
+ mlir::DenseIntElementsAttr coord) {
+ printer << "from (";
+ auto enumerate = llvm::enumerate(coord.getValues<int64_t>());
+ // Even entries are the lower bounds.
+ llvm::interleaveComma(
+ make_filter_range(
+ enumerate,
+ [](auto indexed_value) { return indexed_value.index() % 2 == 0; }),
+ printer, [&](auto indexed_value) { printer << indexed_value.value(); });
+ printer << ") to (";
+ // Odd entries are the upper bounds.
+ llvm::interleaveComma(
+ make_filter_range(
+ enumerate,
+ [](auto indexed_value) { return indexed_value.index() % 2 != 0; }),
+ printer, [&](auto indexed_value) { printer << indexed_value.value(); });
+ printer << ")";
+}
+
/// Range bounds must be nonnegative, and the range must not be empty.
static mlir::LogicalResult verify(fir::InsertOnRangeOp op) {
if (fir::hasDynamicSize(op.seq().getType()))
return op.emitOpError("must have constant shape and size");
- if (op.coor().size() < 2 || op.coor().size() % 2 != 0)
+ mlir::DenseIntElementsAttr coor = op.coor();
+ if (coor.size() < 2 || coor.size() % 2 != 0)
return op.emitOpError("has uneven number of values in ranges");
bool rangeIsKnownToBeNonempty = false;
- for (auto i = op.coor().end(), b = op.coor().begin(); i != b;) {
- int64_t ub = (*--i).cast<IntegerAttr>().getInt();
- int64_t lb = (*--i).cast<IntegerAttr>().getInt();
+ for (auto i = coor.getValues<int64_t>().end(),
+ b = coor.getValues<int64_t>().begin();
+ i != b;) {
+ int64_t ub = (*--i);
+ int64_t lb = (*--i);
if (lb < 0 || ub < 0)
return op.emitOpError("negative range bound");
if (rangeIsKnownToBeNonempty)
fir.global internal @_QEmultiarray : !fir.array<32x32xi32> {
%c0_i32 = arith.constant 1 : i32
%0 = fir.undefined !fir.array<32x32xi32>
- %2 = fir.insert_on_range %0, %c0_i32, [0 : index, 31 : index, 0 : index, 31 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
+ %2 = fir.insert_on_range %0, %c0_i32 from (0, 0) to (31, 31) : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
fir.has_value %2 : !fir.array<32x32xi32>
}
fir.global internal @_QEmultiarray : !fir.array<32xi32> {
%c0_i32 = arith.constant 1 : i32
%0 = fir.undefined !fir.array<32xi32>
- %2 = fir.insert_on_range %0, %c0_i32, [5 : index, 31 : index] : (!fir.array<32xi32>, i32) -> !fir.array<32xi32>
+ %2 = fir.insert_on_range %0, %c0_i32 from (5) to (31) : (!fir.array<32xi32>, i32) -> !fir.array<32xi32>
fir.has_value %2 : !fir.array<32xi32>
}
%c1_i32 = arith.constant 9 : i32
// CHECK: [[ARR2:%.*]] = fir.zero_bits !fir.array<10xi32>
- // CHECK: [[ARR3:%.*]] = fir.insert_on_range [[ARR2]], [[C1_I32]], [2 : index, 9 : index] : (!fir.array<10xi32>, i32) -> !fir.array<10xi32>
+ // CHECK: [[ARR3:%.*]] = fir.insert_on_range [[ARR2]], [[C1_I32]] from (2) to (9) : (!fir.array<10xi32>, i32) -> !fir.array<10xi32>
// CHECK: fir.call @noret1([[ARR3]]) : (!fir.array<10xi32>) -> ()
%arr2 = fir.zero_bits !fir.array<10xi32>
- %arr3 = fir.insert_on_range %arr2, %c1_i32, [2 : index, 9 : index] : (!fir.array<10xi32>, i32) -> !fir.array<10xi32>
+ %arr3 = fir.insert_on_range %arr2, %c1_i32 from (2) to (9) : (!fir.array<10xi32>, i32) -> !fir.array<10xi32>
fir.call @noret1(%arr3) : (!fir.array<10xi32>) -> ()
// CHECK: [[SHAPE:%.*]] = fir.shape_shift [[INDXM:%.*]], [[INDXN:%.*]], [[INDXO:%.*]], [[INDXP:%.*]] : (index, index, index, index) -> !fir.shapeshift<2>
return
}
+// CHECK-LABEL: @insert_on_range_multi_dim
+// CHECK-SAME: %[[ARR:.*]]: !fir.array<10x20xi32>, %[[CST:.*]]: i32
+func @insert_on_range_multi_dim(%arr : !fir.array<10x20xi32>, %cst : i32) {
+ // CHECK: fir.insert_on_range %[[ARR]], %[[CST]] from (2, 3) to (5, 6) : (!fir.array<10x20xi32>, i32) -> !fir.array<10x20xi32>
+ %arr3 = fir.insert_on_range %arr, %cst from (2, 3) to (5, 6) : (!fir.array<10x20xi32>, i32) -> !fir.array<10x20xi32>
+ return
+}
+
// CHECK-LABEL: @test_shift
func @test_shift(%arg0: !fir.box<!fir.array<?xf32>>) -> !fir.ref<f32> {
%c4 = arith.constant 4 : index
%c0_i32 = arith.constant 1 : i32
%0 = fir.undefined !fir.array<32x32xi32>
// expected-error@+1 {{'fir.insert_on_range' op has uneven number of values in ranges}}
- %2 = fir.insert_on_range %0, %c0_i32, [0 : index, 31 : index, 0 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
+ %2 = "fir.insert_on_range"(%0, %c0_i32) { coor = dense<[0, 31, 0]> : tensor<3xindex> } : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
fir.has_value %2 : !fir.array<32x32xi32>
}
%c0_i32 = arith.constant 1 : i32
%0 = fir.undefined !fir.array<32x32xi32>
// expected-error@+1 {{'fir.insert_on_range' op has uneven number of values in ranges}}
- %2 = fir.insert_on_range %0, %c0_i32, [0 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
+ %2 = "fir.insert_on_range"(%0, %c0_i32) { coor = dense<[0]> : tensor<1xindex> } : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
fir.has_value %2 : !fir.array<32x32xi32>
}
%c0_i32 = arith.constant 1 : i32
%0 = fir.undefined !fir.array<32x32xi32>
// expected-error@+1 {{'fir.insert_on_range' op negative range bound}}
- %2 = fir.insert_on_range %0, %c0_i32, [-1 : index, 0 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
+ %2 = fir.insert_on_range %0, %c0_i32 from (-1) to (0) : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
fir.has_value %2 : !fir.array<32x32xi32>
}
%c0_i32 = arith.constant 1 : i32
%0 = fir.undefined !fir.array<32x32xi32>
// expected-error@+1 {{'fir.insert_on_range' op empty range}}
- %2 = fir.insert_on_range %0, %c0_i32, [10 : index, 9 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
+ %2 = fir.insert_on_range %0, %c0_i32 from (10) to (9) : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
fir.has_value %2 : !fir.array<32x32xi32>
}
%c0_i32 = arith.constant 1 : i32
%0 = fir.undefined !fir.array<?xi32>
// expected-error@+1 {{'fir.insert_on_range' op must have constant shape and size}}
- %2 = fir.insert_on_range %0, %c0_i32, [0 : index, 10 : index] : (!fir.array<?xi32>, i32) -> !fir.array<?xi32>
+ %2 = fir.insert_on_range %0, %c0_i32 from (0) to (10) : (!fir.array<?xi32>, i32) -> !fir.array<?xi32>
fir.has_value %2 : !fir.array<?xi32>
}
%c0_i32 = arith.constant 1 : i32
%0 = fir.undefined !fir.array<*:i32>
// expected-error@+1 {{'fir.insert_on_range' op must have constant shape and size}}
- %2 = fir.insert_on_range %0, %c0_i32, [0 : index, 10 : index] : (!fir.array<*:i32>, i32) -> !fir.array<*:i32>
+ %2 = fir.insert_on_range %0, %c0_i32 from (0) to (10) : (!fir.array<*:i32>, i32) -> !fir.array<*:i32>
fir.has_value %2 : !fir.array<*:i32>
}