auto type = this->lowering.convertType(op->getResult(i)->getType());
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
op->getLoc(), type, newOp.getOperation()->getResult(0),
- this->getIntegerArrayAttr(rewriter, i)));
+ rewriter.getIndexArrayAttr(i)));
}
rewriter.replaceOp(op, results);
return this->matchSuccess();
}
};
+// Express `linearIndex` in terms of coordinates of `basis`.
+// Returns the empty vector when linearIndex is out of the range [0, P] where
+// P is the product of all the basis coordinates.
+//
+// Prerequisites:
+// Basis is an array of nonnegative integers (signed type inherited from
+// vector shape type).
+static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
+ unsigned linearIndex) {
+ SmallVector<int64_t, 4> res;
+ res.reserve(basis.size());
+ for (unsigned basisElement : llvm::reverse(basis)) {
+ res.push_back(linearIndex % basisElement);
+ linearIndex = linearIndex / basisElement;
+ }
+ if (linearIndex > 0)
+ return {};
+ std::reverse(res.begin(), res.end());
+ return res;
+}
+
+// Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect
+// Ops for binary ops with one result. This supports higher-dimensional vector
+// types.
+template <typename SourceOp, typename TargetOp>
+struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
+ using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
+ using Super = BinaryOpLLVMOpLowering<SourceOp, TargetOp>;
+
+ // Convert the type of the result to an LLVM type, pass operands as is,
+ // preserve attributes.
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ static_assert(
+ std::is_base_of<OpTrait::NOperands<2>::Impl<SourceOp>, SourceOp>::value,
+ "expected binary op");
+ static_assert(
+ std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
+ "expected single result op");
+ static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
+ SourceOp>::value,
+ "expected single result op");
+
+ auto loc = op->getLoc();
+ auto llvmArrayTy = operands[0]->getType().cast<LLVM::LLVMType>();
+
+ if (!llvmArrayTy.isArrayTy()) {
+ auto newOp = rewriter.create<TargetOp>(
+ op->getLoc(), operands[0]->getType(), operands, op->getAttrs());
+ rewriter.replaceOp(op, newOp.getResult());
+ return this->matchSuccess();
+ }
+
+ // Unroll iterated array type until we hit a non-array type.
+ auto llvmTy = llvmArrayTy;
+ SmallVector<int64_t, 4> arraySizes;
+ while (llvmTy.isArrayTy()) {
+ arraySizes.push_back(llvmTy.getArrayNumElements());
+ llvmTy = llvmTy.getArrayElementType();
+ }
+ assert(llvmTy.isVectorTy() && "unexpected binary op over non-vector type");
+ auto llvmVectorTy = llvmTy;
+
+ // Iteratively extract a position coordinates with basis `arraySize` from a
+ // `linearIndex` that is incremented at each step. This terminates when
+ // `linearIndex` exceeds the range specified by `arraySize`.
+ // This has the effect of fully unrolling the dimensions of the n-D array
+ // type, getting to the underlying vector element.
+ Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
+ unsigned ub = 1;
+ for (auto s : arraySizes)
+ ub *= s;
+ for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
+ auto coords = getCoordinates(arraySizes, linearIndex);
+ // Linear index is out of bounds, we are done.
+ if (coords.empty())
+ break;
+
+ auto position = rewriter.getIndexArrayAttr(coords);
+
+ // For this unrolled `position` corresponding to the `linearIndex`^th
+ // element, extract operand vectors
+ Value *extractedLHS = rewriter.create<LLVM::ExtractValueOp>(
+ loc, llvmVectorTy, operands[0], position);
+ Value *extractedRHS = rewriter.create<LLVM::ExtractValueOp>(
+ loc, llvmVectorTy, operands[1], position);
+ Value *newVal = rewriter.create<TargetOp>(
+ loc, llvmVectorTy, ArrayRef<Value *>{extractedLHS, extractedRHS},
+ op->getAttrs());
+ desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc,
+ newVal, position);
+ }
+ rewriter.replaceOp(op, desc);
+ return this->matchSuccess();
+ }
+};
+
// Specific lowerings.
// FIXME: this should be tablegen'ed.
-struct AddIOpLowering : public OneToOneLLVMOpLowering<AddIOp, LLVM::AddOp> {
+struct AddIOpLowering : public BinaryOpLLVMOpLowering<AddIOp, LLVM::AddOp> {
using Super::Super;
};
-struct SubIOpLowering : public OneToOneLLVMOpLowering<SubIOp, LLVM::SubOp> {
+struct SubIOpLowering : public BinaryOpLLVMOpLowering<SubIOp, LLVM::SubOp> {
using Super::Super;
};
-struct MulIOpLowering : public OneToOneLLVMOpLowering<MulIOp, LLVM::MulOp> {
+struct MulIOpLowering : public BinaryOpLLVMOpLowering<MulIOp, LLVM::MulOp> {
using Super::Super;
};
-struct DivISOpLowering : public OneToOneLLVMOpLowering<DivISOp, LLVM::SDivOp> {
+struct DivISOpLowering : public BinaryOpLLVMOpLowering<DivISOp, LLVM::SDivOp> {
using Super::Super;
};
-struct DivIUOpLowering : public OneToOneLLVMOpLowering<DivIUOp, LLVM::UDivOp> {
+struct DivIUOpLowering : public BinaryOpLLVMOpLowering<DivIUOp, LLVM::UDivOp> {
using Super::Super;
};
-struct RemISOpLowering : public OneToOneLLVMOpLowering<RemISOp, LLVM::SRemOp> {
+struct RemISOpLowering : public BinaryOpLLVMOpLowering<RemISOp, LLVM::SRemOp> {
using Super::Super;
};
-struct RemIUOpLowering : public OneToOneLLVMOpLowering<RemIUOp, LLVM::URemOp> {
+struct RemIUOpLowering : public BinaryOpLLVMOpLowering<RemIUOp, LLVM::URemOp> {
using Super::Super;
};
-struct AndOpLowering : public OneToOneLLVMOpLowering<AndOp, LLVM::AndOp> {
+struct AndOpLowering : public BinaryOpLLVMOpLowering<AndOp, LLVM::AndOp> {
using Super::Super;
};
-struct OrOpLowering : public OneToOneLLVMOpLowering<OrOp, LLVM::OrOp> {
+struct OrOpLowering : public BinaryOpLLVMOpLowering<OrOp, LLVM::OrOp> {
using Super::Super;
};
-struct XOrOpLowering : public OneToOneLLVMOpLowering<XOrOp, LLVM::XOrOp> {
+struct XOrOpLowering : public BinaryOpLLVMOpLowering<XOrOp, LLVM::XOrOp> {
using Super::Super;
};
-struct AddFOpLowering : public OneToOneLLVMOpLowering<AddFOp, LLVM::FAddOp> {
+struct AddFOpLowering : public BinaryOpLLVMOpLowering<AddFOp, LLVM::FAddOp> {
using Super::Super;
};
-struct SubFOpLowering : public OneToOneLLVMOpLowering<SubFOp, LLVM::FSubOp> {
+struct SubFOpLowering : public BinaryOpLLVMOpLowering<SubFOp, LLVM::FSubOp> {
using Super::Super;
};
-struct MulFOpLowering : public OneToOneLLVMOpLowering<MulFOp, LLVM::FMulOp> {
+struct MulFOpLowering : public BinaryOpLLVMOpLowering<MulFOp, LLVM::FMulOp> {
using Super::Super;
};
-struct DivFOpLowering : public OneToOneLLVMOpLowering<DivFOp, LLVM::FDivOp> {
+struct DivFOpLowering : public BinaryOpLLVMOpLowering<DivFOp, LLVM::FDivOp> {
using Super::Super;
};
-struct RemFOpLowering : public OneToOneLLVMOpLowering<RemFOp, LLVM::FRemOp> {
+struct RemFOpLowering : public BinaryOpLLVMOpLowering<RemFOp, LLVM::FRemOp> {
using Super::Super;
};
struct SelectOpLowering
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, memRefDescriptor, allocated,
- getIntegerArrayAttr(rewriter, 0));
+ rewriter.getIndexArrayAttr(0));
// Store dynamically allocated sizes in the descriptor. Dynamic sizes are
// passed in as operands.
for (auto indexedSize : llvm::enumerate(operands)) {
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, memRefDescriptor, indexedSize.value(),
- getIntegerArrayAttr(rewriter, 1 + indexedSize.index()));
+ rewriter.getIndexArrayAttr(1 + indexedSize.index()));
}
// Return the final value of the descriptor.
}
auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
- auto hasStaticShape = type.getUnderlyingType()->isPointerTy();
+ auto hasStaticShape = type.isPointerTy();
Type elementPtrType = hasStaticShape ? type : type.getStructElementType(0);
Value *bufferPtr =
extractMemRefElementPtr(rewriter, op->getLoc(), transformed.memref(),
// Otherwise target type is dynamic memref, so create a proper descriptor.
newDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, newDescriptor, buffer,
- getIntegerArrayAttr(rewriter, 0));
+ rewriter.getIndexArrayAttr(0));
// Fill in the dynamic sizes of the new descriptor. If the size was
// dynamic, copy it from the old descriptor. If the size was static, insert
? rewriter.create<LLVM::ExtractValueOp>(
op->getLoc(), getIndexType(),
transformed.source(), // NB: dynamic memref
- getIntegerArrayAttr(rewriter, sourceDynamicDimIdx++))
+ rewriter.getIndexArrayAttr(sourceDynamicDimIdx++))
: createIndexConstant(rewriter, op->getLoc(), sourceSize);
newDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, newDescriptor, size,
- getIntegerArrayAttr(rewriter, targetDynamicDimIdx++));
+ rewriter.getIndexArrayAttr(targetDynamicDimIdx++));
}
assert(sourceDynamicDimIdx - 1 == sourceType.getNumDynamicDims() &&
"source dynamic dimensions were not processed");
}
rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
op, getIndexType(), transformed.memrefOrTensor(),
- getIntegerArrayAttr(rewriter, position));
+ rewriter.getIndexArrayAttr(position));
} else {
rewriter.replaceOp(
op, createIndexConstant(rewriter, op->getLoc(), shape[index]));
if (s == -1) {
Value *size = rewriter.create<LLVM::ExtractValueOp>(
loc, this->getIndexType(), memRefDescriptor,
- this->getIntegerArrayAttr(rewriter, dynamicSizeIdx++));
+ rewriter.getIndexArrayAttr(dynamicSizeIdx++));
sizes.push_back(size);
} else {
sizes.push_back(this->createIndexConstant(rewriter, loc, s));
Value *subscript = linearizeSubscripts(rewriter, loc, indices, sizes);
Value *dataPtr = rewriter.create<LLVM::ExtractValueOp>(
- loc, elementTypePtr, memRefDescriptor,
- this->getIntegerArrayAttr(rewriter, 0));
+ loc, elementTypePtr, memRefDescriptor, rewriter.getIndexArrayAttr(0));
return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr,
ArrayRef<Value *>{dataPtr, subscript},
ArrayRef<NamedAttribute>{});
for (unsigned i = 0; i < numArguments; ++i) {
packed = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), packedType, packed, operands[i],
- getIntegerArrayAttr(rewriter, i));
+ rewriter.getIndexArrayAttr(i));
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
op, llvm::makeArrayRef(packed), llvm::ArrayRef<Block *>(),