/// of the type and we assert everything is f32.
/// TODO(ntv): relax the assumptions on admissible element type once a
/// contract exists.
- MaterializationState() : hwVectorSize(clVectorSize.size(), 0) {
- std::copy(clVectorSize.begin(), clVectorSize.end(), hwVectorSize.begin());
- }
+ MaterializationState(SmallVector<int64_t, 8> sizes) : hwVectorSize(sizes) {}
+
SmallVector<int64_t, 8> hwVectorSize;
VectorType superVectorType;
VectorType hwVectorType;
DenseMap<Value *, Value *> *substitutionsMap;
};
+/// Base state for the vector materialization pass.
+/// Command line arguments are preempted by non-empty pass arguments.
struct MaterializeVectorsPass : public FunctionPass<MaterializeVectorsPass> {
+ MaterializeVectorsPass()
+ : hwVectorSize(clVectorSize.begin(), clVectorSize.end()) {}
+ MaterializeVectorsPass(ArrayRef<int64_t> hwVectorSize)
+ : MaterializeVectorsPass() {
+ if (!hwVectorSize.empty())
+ this->hwVectorSize.assign(hwVectorSize.begin(), hwVectorSize.end());
+ }
+
+ SmallVector<int64_t, 8> hwVectorSize;
void runOnFunction() override;
};
LLVM_DEBUG(dbgs() << "\nMaterializeVectors on Function\n");
LLVM_DEBUG(f->print(dbgs()));
- MaterializationState state;
+ MaterializationState state(hwVectorSize);
// Get the hardware vector type.
// TODO(ntv): get elemental type from super-vector type rather than force f32.
auto subVectorType =
- VectorType::get(state.hwVectorSize, FloatType::getF32(&getContext()));
+ VectorType::get(hwVectorSize, FloatType::getF32(&getContext()));
// Capture terminators; i.e. vector_transfer_write ops involving a strict
// super-vector of subVectorType.
signalPassFailure();
}
-FunctionPassBase *mlir::createMaterializeVectorsPass() {
- return new MaterializeVectorsPass();
+FunctionPassBase *
+mlir::createMaterializeVectorsPass(llvm::ArrayRef<int64_t> vectorSize) {
+ return new MaterializeVectorsPass(vectorSize);
}
static PassRegistration<MaterializeVectorsPass>
} // end anonymous namespace
-Vectorize::Vectorize() {
- this->vectorSizes.assign(clVirtualVectorSize.begin(),
- clVirtualVectorSize.end());
- this->fastestVaryingPattern.assign(clFastestVaryingPattern.begin(),
- clFastestVaryingPattern.end());
-}
+Vectorize::Vectorize()
+ : vectorSizes(clVirtualVectorSize.begin(), clVirtualVectorSize.end()),
+ fastestVaryingPattern(clFastestVaryingPattern.begin(),
+ clFastestVaryingPattern.end()) {}
Vectorize::Vectorize(ArrayRef<int64_t> virtualVectorSize) : Vectorize() {
if (!virtualVectorSize.empty()) {