From 26eb2c6b42f7cd10b54399e2b7c69a25560e23a0 Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Tue, 18 Oct 2022 16:41:03 +0000 Subject: [PATCH] [mlir][sparse] remove vector support in sparsification Sparse compiler used to generate vectorized code for sparse tensors computation, but it should really be delegated to other vectorization passes for better progressive lowering. https://discourse.llvm.org/t/rfc-structured-codegen-beyond-rectangular-arrays/64707 Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D136183 --- .../mlir/Dialect/SparseTensor/Pipelines/Passes.h | 28 +- .../mlir/Dialect/SparseTensor/Transforms/Passes.h | 36 +- .../mlir/Dialect/SparseTensor/Transforms/Passes.td | 23 +- .../Pipelines/SparseTensorPipelines.cpp | 2 +- .../SparseTensor/Transforms/CodegenUtils.cpp | 1 - .../SparseTensor/Transforms/SparseTensorPasses.cpp | 22 +- .../SparseTensor/Transforms/Sparsification.cpp | 320 +----------- .../test/Dialect/SparseTensor/sparse_parallel.mlir | 17 +- mlir/test/Dialect/SparseTensor/sparse_vector.mlir | 540 ++++----------------- .../Dialect/SparseTensor/sparse_vector_chain.mlir | 128 ----- .../Dialect/SparseTensor/sparse_vector_index.mlir | 126 ----- .../Dialect/SparseTensor/sparse_vector_peeled.mlir | 63 --- .../Dialect/SparseTensor/CPU/sparse_cast.mlir | 8 - .../SparseTensor/CPU/sparse_filter_conv2d.mlir | 7 - .../Dialect/SparseTensor/CPU/sparse_flatten.mlir | 9 - .../SparseTensor/CPU/sparse_index_dense.mlir | 7 - .../Dialect/SparseTensor/CPU/sparse_matvec.mlir | 10 - .../Dialect/SparseTensor/CPU/sparse_mttkrp.mlir | 9 - .../SparseTensor/CPU/sparse_out_simple.mlir | 9 - .../SparseTensor/CPU/sparse_quantized_matmul.mlir | 7 - .../SparseTensor/CPU/sparse_reductions.mlir | 7 - .../SparseTensor/CPU/sparse_sampled_matmul.mlir | 11 - .../SparseTensor/CPU/sparse_sampled_mm_fusion.mlir | 7 - .../Dialect/SparseTensor/CPU/sparse_scale.mlir | 8 - .../Dialect/SparseTensor/CPU/sparse_spmm.mlir | 9 - .../Dialect/SparseTensor/CPU/sparse_sum.mlir | 9 - .../Dialect/SparseTensor/python/test_SDDMM.py | 23 +- .../Dialect/SparseTensor/python/test_SpMM.py | 4 +- .../Dialect/SparseTensor/python/test_stress.py | 5 - utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 1 - 30 files changed, 136 insertions(+), 1320 deletions(-) delete mode 100644 mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir delete mode 100644 mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir delete mode 100644 mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h index 466b0d2..97030f5 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h @@ -52,29 +52,7 @@ struct SparseCompilerOptions mlir::SparseParallelizationStrategy::kAnyStorageAnyLoop, "any-storage-any-loop", "Enable sparse parallelization for any storage and loop."))}; - PassOptions::Option vectorization{ - *this, "vectorization-strategy", - ::llvm::cl::desc("Set the vectorization strategy"), - ::llvm::cl::init(mlir::SparseVectorizationStrategy::kNone), - llvm::cl::values( - clEnumValN(mlir::SparseVectorizationStrategy::kNone, "none", - "Turn off sparse vectorization."), - clEnumValN(mlir::SparseVectorizationStrategy::kDenseInnerLoop, - "dense-inner-loop", - "Enable vectorization for dense inner loops."), - clEnumValN(mlir::SparseVectorizationStrategy::kAnyStorageInnerLoop, - "any-storage-inner-loop", - "Enable sparse vectorization for inner loops with any " - "storage."))}; - - PassOptions::Option vectorLength{ - *this, "vl", desc("Set the vector length"), init(1)}; - PassOptions::Option enableSIMDIndex32{ - *this, "enable-simd-index32", - desc("Enable i32 indexing into vectors (for efficiency)"), init(false)}; - PassOptions::Option enableVLAVectorization{ - *this, "enable-vla-vectorization", - desc("Enable vector length agnostic vectorization"), init(false)}; + PassOptions::Option enableRuntimeLibrary{ *this, "enable-runtime-library", desc("Enable runtime library for manipulating sparse tensors"), @@ -87,9 +65,7 @@ struct SparseCompilerOptions /// Projects out the options for `createSparsificationPass`. SparsificationOptions sparsificationOptions() const { - return SparsificationOptions(parallelization, vectorization, vectorLength, - enableSIMDIndex32, enableVLAVectorization, - enableRuntimeLibrary); + return SparsificationOptions(parallelization); } // These options must be kept in sync with `SparseTensorConversionBase`. diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h index 2230f43..5e301c4 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h @@ -44,39 +44,16 @@ enum class SparseParallelizationStrategy { // TODO: support reduction parallelization too? }; -/// Defines a vectorization strategy. Any inner loop is a candidate (full SIMD -/// for parallel loops and horizontal SIMD for reduction loops). A loop is -/// actually vectorized if (1) allowed by the strategy, and (2) the emitted -/// code is an actual for-loop (and not a co-iterating while-loop). -enum class SparseVectorizationStrategy { - kNone, - kDenseInnerLoop, - kAnyStorageInnerLoop -}; - #define GEN_PASS_DECL #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" /// Options for the Sparsification pass. struct SparsificationOptions { - SparsificationOptions(SparseParallelizationStrategy p, - SparseVectorizationStrategy v, unsigned vl, bool e, - bool vla, bool rt) - : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl), - enableSIMDIndex32(e), enableVLAVectorization(vla), - enableRuntimeLibrary(rt) {} + SparsificationOptions(SparseParallelizationStrategy p) + : parallelizationStrategy(p) {} SparsificationOptions() - : SparsificationOptions(SparseParallelizationStrategy::kNone, - SparseVectorizationStrategy::kNone, 1u, - /*enable SIMD Index32=*/false, - /*enable VLA Vectorization=*/false, - /*enable runtime library=*/true) {} + : SparsificationOptions(SparseParallelizationStrategy::kNone) {} SparseParallelizationStrategy parallelizationStrategy; - SparseVectorizationStrategy vectorizationStrategy; - unsigned vectorLength; - bool enableSIMDIndex32; - bool enableVLAVectorization; - bool enableRuntimeLibrary; }; /// Sets up sparsification rewriting rules with the given options. @@ -165,10 +142,9 @@ void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT, bool enableForeach, bool enableConvert); std::unique_ptr createSparseTensorRewritePass(); -std::unique_ptr -createSparseTensorRewritePass(const SparsificationOptions &options, - bool enableForeach = true, - bool enableConvert = true); +std::unique_ptr createSparseTensorRewritePass(bool enableRT, + bool enableForeach = true, + bool enableConvert = true); //===----------------------------------------------------------------------===// // Other rewriting rules and passes. diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td index eee33b0..09593c2 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -86,7 +86,6 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> { "memref::MemRefDialect", "scf::SCFDialect", "sparse_tensor::SparseTensorDialect", - "vector::VectorDialect", ]; // TODO(57514): These enum options are duplicated in Passes.h. let options = [ @@ -106,26 +105,7 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> { "Enable dense parallelization for any loop."), clEnumValN(mlir::SparseParallelizationStrategy::kAnyStorageAnyLoop, "any-storage-any-loop", - "Enable sparse parallelization for any storage and loop."))}]>, - Option<"vectorization", "vectorization-strategy", "mlir::SparseVectorizationStrategy", - "mlir::SparseVectorizationStrategy::kNone", - "Set the vectorization strategy", [{llvm::cl::values( - clEnumValN(mlir::SparseVectorizationStrategy::kNone, "none", - "Turn off sparse vectorization."), - clEnumValN(mlir::SparseVectorizationStrategy::kDenseInnerLoop, - "dense-inner-loop", - "Enable vectorization for dense inner loops."), - clEnumValN(mlir::SparseVectorizationStrategy::kAnyStorageInnerLoop, - "any-storage-inner-loop", - "Enable sparse vectorization for inner loops with any storage."))}]>, - Option<"vectorLength", "vl", "int32_t", "1", - "Set the vector length">, - Option<"enableSIMDIndex32", "enable-simd-index32", "bool", "false", - "Enable i32 indexing into vectors (for efficiency)">, - Option<"enableVLAVectorization", "enable-vla-vectorization", "bool", - "false", "Enable vector length agnostic vectorization">, - Option<"enableRuntimeLibrary", "enable-runtime-library", "bool", - "true", "Enable runtime library for manipulating sparse tensors"> + "Enable sparse parallelization for any storage and loop."))}]> ]; } @@ -167,7 +147,6 @@ def SparseTensorConversionPass : Pass<"sparse-tensor-conversion", "ModuleOp"> { "memref::MemRefDialect", "scf::SCFDialect", "sparse_tensor::SparseTensorDialect", - "vector::VectorDialect", ]; let options = [ Option<"sparseToSparse", "s2s-strategy", "int32_t", "0", diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp index 0cd1799..b3b7057 100644 --- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp +++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp @@ -58,7 +58,7 @@ void mlir::sparse_tensor::buildSparseCompiler( /*analysisOnly=*/options.testBufferizationAnalysisOnly))); if (options.testBufferizationAnalysisOnly) return; - pm.addPass(createSparseTensorRewritePass(options.sparsificationOptions())); + pm.addPass(createSparseTensorRewritePass(options.enableRuntimeLibrary)); pm.addPass(createSparsificationPass(options.sparsificationOptions())); if (options.enableRuntimeLibrary) pm.addPass(createSparseTensorConversionPass( diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp index 4e75999..e8c561b 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -13,7 +13,6 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp index b524ac1..4a35a7f 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -43,9 +43,8 @@ struct SparseTensorRewritePass SparseTensorRewritePass() = default; SparseTensorRewritePass(const SparseTensorRewritePass &pass) = default; - SparseTensorRewritePass(const SparsificationOptions &options, bool foreach, - bool convert) { - enableRuntimeLibrary = options.enableRuntimeLibrary; + SparseTensorRewritePass(bool enableRT, bool foreach, bool convert) { + enableRuntimeLibrary = enableRT; enableForeach = foreach; enableConvert = convert; } @@ -66,19 +65,12 @@ struct SparsificationPass SparsificationPass(const SparsificationPass &pass) = default; SparsificationPass(const SparsificationOptions &options) { parallelization = options.parallelizationStrategy; - vectorization = options.vectorizationStrategy; - vectorLength = options.vectorLength; - enableSIMDIndex32 = options.enableSIMDIndex32; - enableVLAVectorization = options.enableVLAVectorization; - enableRuntimeLibrary = options.enableRuntimeLibrary; } void runOnOperation() override { auto *ctx = &getContext(); // Translate strategy flags to strategy options. - SparsificationOptions options(parallelization, vectorization, vectorLength, - enableSIMDIndex32, enableVLAVectorization, - enableRuntimeLibrary); + SparsificationOptions options(parallelization); // Apply sparsification and vector cleanup rewriting. RewritePatternSet patterns(ctx); populateSparsificationPatterns(patterns, options); @@ -258,10 +250,10 @@ std::unique_ptr mlir::createSparseTensorRewritePass() { return std::make_unique(); } -std::unique_ptr -mlir::createSparseTensorRewritePass(const SparsificationOptions &options, - bool enableForeach, bool enableConvert) { - return std::make_unique(options, enableForeach, +std::unique_ptr mlir::createSparseTensorRewritePass(bool enableRT, + bool enableForeach, + bool enableConvert) { + return std::make_unique(enableRT, enableForeach, enableConvert); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index 45013c0..0f3251c 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -27,7 +27,6 @@ #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/Dialect/SparseTensor/Utils/Merger.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/TensorEncoding.h" #include "llvm/ADT/SmallBitVector.h" @@ -97,9 +96,6 @@ struct CodeGen { Value expFilled; Value expAdded; Value expCount; - // Current vector length and mask. - unsigned curVecLength = 1; - Value curVecMask; // Topsort (reference should remain in scope). std::vector &topSort; }; @@ -373,26 +369,6 @@ static bool isAdmissibleTensorExp(Merger &merger, linalg::GenericOp op, // Sparse compiler synthesis methods (reductions). //===----------------------------------------------------------------------===// -/// Maps reduction kind to vector::CombiningKind. -static vector::CombiningKind getCombiningKind(Reduction kind) { - switch (kind) { - case kNoReduc: - case kCustom: - break; - case kSum: - return vector::CombiningKind::ADD; - case kProduct: - return vector::CombiningKind::MUL; - case kAnd: - return vector::CombiningKind::AND; - case kOr: - return vector::CombiningKind::OR; - case kXor: - return vector::CombiningKind::XOR; - } - llvm_unreachable("unknown reduction kind"); -} - /// Maps operation to reduction. static Reduction getReduction(Kind kind) { switch (kind) { @@ -420,42 +396,6 @@ static Reduction getReduction(Kind kind) { } } -/// Generates an initial value for a vector reduction, following the scheme -/// given in Chapter 5 of "The Software Vectorization Handbook", where the -/// initial scalar value is correctly embedded in the vector reduction value, -/// and a straightforward horizontal reduction will complete the operation. -static Value genVectorReducInit(CodeGen &codegen, OpBuilder &builder, - Location loc, VectorType vtp) { - Value r = codegen.redVal; - switch (codegen.redKind) { - case kNoReduc: - case kCustom: - break; - case kSum: - case kXor: - // Initialize reduction vector to: | 0 | .. | 0 | r | - return builder.create( - loc, r, constantZero(builder, loc, vtp), - constantIndex(builder, loc, 0)); - case kProduct: - // Initialize reduction vector to: | 1 | .. | 1 | r | - return builder.create( - loc, r, constantOne(builder, loc, vtp), constantIndex(builder, loc, 0)); - case kAnd: - case kOr: - // Initialize reduction vector to: | r | .. | r | r | - return builder.create(loc, vtp, r); - } - llvm_unreachable("unknown reduction kind"); -} - -/// Generates final value for a vector reduction. -static Value genVectorReducEnd(CodeGen &codegen, OpBuilder &builder, - Location loc, VectorType vtp) { - vector::CombiningKind kind = getCombiningKind(codegen.redKind); - return builder.create(loc, kind, codegen.redVal); -} - /// Updates scalarized reduction value. static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) { assert(codegen.redKind != kNoReduc); @@ -573,89 +513,6 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder, } } -/// Constructs vector type. -static VectorType vectorType(CodeGen &codegen, Type etp) { - unsigned numScalableDims = codegen.options.enableVLAVectorization; - return VectorType::get(codegen.curVecLength, etp, numScalableDims); -} - -/// Constructs vector type from pointer. -static VectorType vectorType(CodeGen &codegen, Value ptr) { - return vectorType(codegen, ptr.getType().cast().getElementType()); -} - -/// Constructs vector iteration mask. -static Value genVectorMask(CodeGen &codegen, OpBuilder &builder, Value iv, - Value lo, Value hi, Value step) { - Location loc = iv.getLoc(); - VectorType mtp = vectorType(codegen, builder.getI1Type()); - // Special case if the vector length evenly divides the trip count (for - // example, "for i = 0, 128, 16"). A constant all-true mask is generated - // so that all subsequent masked memory operations are immediately folded - // into unconditional memory operations. - IntegerAttr loInt, hiInt, stepInt; - if (matchPattern(lo, m_Constant(&loInt)) && - matchPattern(hi, m_Constant(&hiInt)) && - matchPattern(step, m_Constant(&stepInt))) { - if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) - return builder.create( - loc, mtp, constantI1(builder, loc, true)); - } - // Otherwise, generate a vector mask that avoids overrunning the upperbound - // during vector execution. Here we rely on subsequent loop optimizations to - // avoid executing the mask in all iterations, for example, by splitting the - // loop into an unconditional vector loop and a scalar cleanup loop. - auto minMap = AffineMap::get( - /*dimCount=*/2, /*symbolCount=*/1, - {builder.getAffineSymbolExpr(0), - builder.getAffineDimExpr(0) - builder.getAffineDimExpr(1)}, - builder.getContext()); - Value end = - builder.createOrFold(loc, minMap, ValueRange{hi, iv, step}); - return builder.create(loc, mtp, end); -} - -/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi]. -static Value genVectorLoad(CodeGen &codegen, OpBuilder &builder, Value ptr, - ArrayRef args) { - Location loc = ptr.getLoc(); - VectorType vtp = vectorType(codegen, ptr); - Value pass = constantZero(builder, loc, vtp); - if (args.back().getType().isa()) { - SmallVector scalarArgs(args.begin(), args.end()); - Value indexVec = args.back(); - scalarArgs.back() = constantIndex(builder, loc, 0); - return builder.create(loc, vtp, ptr, scalarArgs, indexVec, - codegen.curVecMask, pass); - } - return builder.create(loc, vtp, ptr, args, - codegen.curVecMask, pass); -} - -/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs. -static void genVectorStore(CodeGen &codegen, OpBuilder &builder, Value rhs, - Value ptr, ArrayRef args) { - Location loc = ptr.getLoc(); - if (args.back().getType().isa()) { - SmallVector scalarArgs(args.begin(), args.end()); - Value indexVec = args.back(); - scalarArgs.back() = constantIndex(builder, loc, 0); - builder.create(loc, ptr, scalarArgs, indexVec, - codegen.curVecMask, rhs); - return; - } - builder.create(loc, ptr, args, codegen.curVecMask, - rhs); -} - -/// Generates a vectorized invariant. Here we rely on subsequent loop -/// optimizations to hoist the invariant broadcast out of the vector loop. -static Value genVectorInvariantValue(CodeGen &codegen, OpBuilder &builder, - Value val) { - VectorType vtp = vectorType(codegen, val.getType()); - return builder.create(val.getLoc(), vtp, val); -} - /// Generates an affine expression. // // TODO: generalize for sparse tensor subscripts @@ -808,11 +665,9 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen, OpBuilder &builder, linalg::GenericOp op, unsigned exp) { // Test if the load was hoisted to a higher loop nest. Value val = merger.exp(exp).val; - if (val) { - if (codegen.curVecLength > 1 && !val.getType().isa()) - return genVectorInvariantValue(codegen, builder, val); + if (val) return val; - } + // Load during insertion. OpOperand &t = op->getOpOperand(merger.exp(exp).tensor); if (&t == codegen.sparseOut) { @@ -823,8 +678,6 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen, OpBuilder &builder, // Actual load. SmallVector args; Value ptr = genSubscript(codegen, builder, op, &t, args); - if (codegen.curVecLength > 1) - return genVectorLoad(codegen, builder, ptr, args); return builder.create(op.getLoc(), ptr, args); } @@ -834,9 +687,6 @@ static void genTensorStore(Merger &merger, CodeGen &codegen, OpBuilder &builder, Location loc = op.getLoc(); // Test if this is a scalarized reduction. if (codegen.redVal) { - if (codegen.curVecLength > 1) - rhs = builder.create(loc, codegen.curVecMask, rhs, - codegen.redVal); updateReduc(merger, codegen, rhs); return; } @@ -864,10 +714,7 @@ static void genTensorStore(Merger &merger, CodeGen &codegen, OpBuilder &builder, // Actual store. SmallVector args; Value ptr = genSubscript(codegen, builder, op, t, args); - if (codegen.curVecLength > 1) - genVectorStore(codegen, builder, rhs, ptr, args); - else - builder.create(loc, rhs, ptr, args); + builder.create(loc, rhs, ptr, args); } /// Generates a pointer/index load from the sparse storage scheme. Narrower @@ -875,37 +722,8 @@ static void genTensorStore(Merger &merger, CodeGen &codegen, OpBuilder &builder, /// index type used for looping and indexing. static Value genLoad(CodeGen &codegen, OpBuilder &builder, Location loc, Value ptr, Value s) { - // See https://llvm.org/docs/GetElementPtr.html for some background on - // the complications described below. - if (codegen.curVecLength > 1) { - // Since the index vector is used in a subsequent gather/scatter operations, - // which effectively defines an unsigned pointer + signed index, we must - // zero extend the vector to an index width. For 8-bit and 16-bit values, - // an 32-bit index width suffices. For 32-bit values, zero extending the - // elements into 64-bit loses some performance since the 32-bit indexed - // gather/scatter is more efficient than the 64-bit index variant (if the - // negative 32-bit index space is unused, the enableSIMDIndex32 flag can - // preserve this performance). For 64-bit values, there is no good way - // to state that the indices are unsigned, with creates the potential of - // incorrect address calculations in the unlikely case we need such - // extremely large offsets. - Type etp = ptr.getType().cast().getElementType(); - Value vload = genVectorLoad(codegen, builder, ptr, {s}); - if (!etp.isa()) { - if (etp.getIntOrFloatBitWidth() < 32) - vload = builder.create( - loc, vectorType(codegen, builder.getI32Type()), vload); - else if (etp.getIntOrFloatBitWidth() < 64 && - !codegen.options.enableSIMDIndex32) - vload = builder.create( - loc, vectorType(codegen, builder.getI64Type()), vload); - } - return vload; - } - // For the scalar case, we simply zero extend narrower indices into 64-bit - // values before casting to index without a performance penalty. Here too, - // however, indices that already are 64-bit, in theory, cannot express the - // full range as explained above. + // Simply zero extends narrower indices into 64-bit values before casting to + // index without a performance penalty. Value load = builder.create(loc, ptr, s); if (!load.getType().isa()) { if (load.getType().getIntOrFloatBitWidth() < 64) @@ -920,8 +738,6 @@ static Value genLoad(CodeGen &codegen, OpBuilder &builder, Location loc, static Value genInvariantValue(Merger &merger, CodeGen &codegen, OpBuilder &builder, unsigned exp) { Value val = merger.exp(exp).val; - if (codegen.curVecLength > 1) - return genVectorInvariantValue(codegen, builder, val); return val; } @@ -929,11 +745,6 @@ static Value genInvariantValue(Merger &merger, CodeGen &codegen, static Value genAddress(CodeGen &codegen, OpBuilder &builder, Location loc, Value size, Value p, Value i) { Value mul = builder.create(loc, size, p); - if (auto vtp = i.getType().dyn_cast()) { - Value inv = - builder.create(loc, vtp.getElementType(), mul); - mul = genVectorInvariantValue(codegen, builder, inv); - } return builder.create(loc, mul, i); } @@ -941,31 +752,6 @@ static Value genAddress(CodeGen &codegen, OpBuilder &builder, Location loc, static Value genIndexValue(CodeGen &codegen, OpBuilder &builder, unsigned idx, unsigned ldx) { Value ival = codegen.loops[idx]; - Type itype = ival.getType(); - // During vectorization, we either encounter: - // (1) indices already in vector form, as in ... = ind[lo:hi], good to go, or - // (2) single index, as in ... = i, must convert to [i, i+1, ...] for inner i. - unsigned vl = codegen.curVecLength; - if (vl > 1 && !itype.isa()) { - Location loc = ival.getLoc(); - VectorType vtp = vectorType(codegen, itype); - ival = builder.create(loc, vtp, ival); - if (idx == ldx) { - Value incr; - if (vtp.isScalable()) { - Type stepvty = vectorType(codegen, builder.getI64Type()); - Value stepv = builder.create(loc, stepvty); - incr = builder.create(loc, vtp, stepv); - } else { - SmallVector integers; - for (unsigned i = 0; i < vl; i++) - integers.push_back(APInt(/*width=*/64, i)); - auto values = DenseElementsAttr::get(vtp, integers); - incr = builder.create(loc, vtp, values); - } - ival = builder.create(loc, ival, incr); - } - } return ival; } @@ -1207,31 +993,11 @@ static bool genInit(Merger &merger, CodeGen &codegen, OpBuilder &builder, return needsUniv; } -/// Returns vectorization strategy. Any implicit inner loop in the Linalg -/// operation is a candidate. Whether it is actually converted to SIMD code -/// depends on the requested strategy. -static bool isVectorFor(CodeGen &codegen, bool isInner, bool isReduction, - bool isSparse) { - // Reject vectorization of sparse output, unless innermost is reduction. - if (codegen.sparseOut && !isReduction) - return false; - // Inspect strategy. - switch (codegen.options.vectorizationStrategy) { - case SparseVectorizationStrategy::kNone: - return false; - case SparseVectorizationStrategy::kDenseInnerLoop: - return isInner && !isSparse; - case SparseVectorizationStrategy::kAnyStorageInnerLoop: - return isInner; - } - llvm_unreachable("unexpected vectorization strategy"); -} - /// Returns parallelization strategy. Any implicit loop in the Linalg operation /// that is marked "parallel" is a candidate. Whether it is actually converted /// to a parallel operation depends on the requested strategy. static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction, - bool isSparse, bool isVector) { + bool isSparse) { // Reject parallelization of sparse output. if (codegen.sparseOut) return false; @@ -1240,42 +1006,17 @@ static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction, case SparseParallelizationStrategy::kNone: return false; case SparseParallelizationStrategy::kDenseOuterLoop: - return isOuter && !isSparse && !isReduction && !isVector; + return isOuter && !isSparse && !isReduction; case SparseParallelizationStrategy::kAnyStorageOuterLoop: - return isOuter && !isReduction && !isVector; + return isOuter && !isReduction; case SparseParallelizationStrategy::kDenseAnyLoop: - return !isSparse && !isReduction && !isVector; + return !isSparse && !isReduction; case SparseParallelizationStrategy::kAnyStorageAnyLoop: - return !isReduction && !isVector; + return !isReduction; } llvm_unreachable("unexpected parallelization strategy"); } -/// Checks unit stride for dense tensors. The iteration graph may have ignored -/// dense access patterns in order to avoid cycles (sparse access patterns are -/// always placed innermost), but that means dense access has become strided. -/// This prevents effective vectorization. -static bool denseUnitStrides(Merger &merger, linalg::GenericOp op, - unsigned idx) { - for (OpOperand &t : op->getOpOperands()) { - if (!getSparseTensorEncoding(t.get().getType())) { - auto map = op.getMatchingIndexingMap(&t); - for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { - AffineExpr a = map.getResult(d); - // Report non-unit stride if innermost index appears at an outer - // dimension (true non-unit stride) or if the innermost index appears - // in a compound subscript in the innermost dimension. Even if the - // latter is unit stride, it does not play well with scatter/gather. - // TODO: accept unit stride affine innermost like a[i,j+k+1]? - if (a.isFunctionOfDim(idx) && - ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId))) - return false; - } - } - } - return true; -} - /// Generates a for-loop on a single index. static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder, linalg::GenericOp op, bool isOuter, bool isInner, @@ -1287,29 +1028,16 @@ static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder, bool isReduction = linalg::isReductionIterator(iteratorTypes[idx]); bool isSparse = isCompressedDLT(merger.getDimLevelType(fb)) || isSingletonDLT(merger.getDimLevelType(fb)); - bool isVector = isVectorFor(codegen, isInner, isReduction, isSparse) && - denseUnitStrides(merger, op, idx); - bool isParallel = - isParallelFor(codegen, isOuter, isReduction, isSparse, isVector); - - // Prepare vector length. - if (isVector) - codegen.curVecLength = codegen.options.vectorLength; + bool isParallel = isParallelFor(codegen, isOuter, isReduction, isSparse); // Loop bounds and increment. Location loc = op.getLoc(); Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx]; Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx]; - Value step = constantIndex(builder, loc, codegen.curVecLength); - if (isVector && codegen.options.enableVLAVectorization) { - Value vscale = builder.create( - loc, IndexType::get(builder.getContext())); - step = builder.create(loc, vscale, step); - } + Value step = constantIndex(builder, loc, 1); // Emit a parallel loop. if (isParallel) { - assert(!isVector); scf::ParallelOp parOp = builder.create(loc, lo, hi, step); if (isSparse) codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0]; @@ -1321,18 +1049,13 @@ static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder, // Emit a sequential or vector loop. SmallVector operands; - if (codegen.redVal) { - // In a vector loop, bring reduction into SIMD form, if not already. - if (isVector && !codegen.redVal.getType().isa()) { - VectorType vtp = vectorType(codegen, codegen.redVal.getType()); - Value vred = genVectorReducInit(codegen, builder, loc, vtp); - updateReduc(merger, codegen, vred); - } + if (codegen.redVal) operands.push_back(codegen.redVal); - } if (codegen.expValues) operands.push_back(codegen.expCount); + scf::ForOp forOp = builder.create(loc, lo, hi, step, operands); + if (codegen.redVal) updateReduc(merger, codegen, forOp.getRegionIterArgs().front()); if (codegen.expValues) @@ -1343,10 +1066,8 @@ static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder, codegen.pidxs[tensor][idx] = iv; else codegen.loops[idx] = iv; + builder.setInsertionPointToStart(forOp.getBody()); - // Share vector iteration mask between all subsequent loads/stores. - if (isVector) - codegen.curVecMask = genVectorMask(codegen, builder, iv, lo, hi, step); return forOp; } @@ -1659,7 +1380,6 @@ static void endIf(Merger &merger, CodeGen &codegen, OpBuilder &builder, static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder, linalg::GenericOp op, unsigned exp, unsigned at, unsigned idx, unsigned ldx, unsigned lts) { - assert(codegen.curVecLength == 1); assert(!codegen.loops[idx]); // Emit invariants at this loop sequence level. genInvariants(merger, codegen, builder, op, exp, ldx, /*atStart=*/true); @@ -1686,7 +1406,6 @@ static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder, static Operation *startLoop(Merger &merger, CodeGen &codegen, OpBuilder &builder, linalg::GenericOp op, unsigned at, unsigned li, bool needsUniv) { - assert(codegen.curVecLength == 1); // Emit the for/while-loop control. Operation *loop = genLoop(merger, codegen, builder, op, at, needsUniv, merger.lat(li).simple); @@ -1699,7 +1418,6 @@ static Operation *startLoop(Merger &merger, CodeGen &codegen, static bool endLoop(Merger &merger, CodeGen &codegen, OpBuilder &builder, linalg::GenericOp op, Operation *loop, unsigned idx, unsigned li, bool needsUniv) { - codegen.curVecLength = 1; // End a while-loop. if (auto whileOp = dyn_cast(loop)) { genWhileInduction(merger, codegen, builder, op, idx, needsUniv, @@ -1715,14 +1433,8 @@ static bool endLoop(Merger &merger, CodeGen &codegen, OpBuilder &builder, static void endLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder, linalg::GenericOp op, unsigned exp, unsigned at, unsigned idx, unsigned ldx) { - assert(codegen.curVecLength == 1); assert(codegen.loops[idx]); codegen.loops[idx] = Value(); - // Bring a pending reduction back from SIMD form when sequence ends. - if (codegen.redVal) - if (auto vtp = codegen.redVal.getType().dyn_cast()) - updateReduc(merger, codegen, - genVectorReducEnd(codegen, builder, op.getLoc(), vtp)); // Unmark bookkeeping of invariants and loop index. genInvariants(merger, codegen, builder, op, exp, ldx, /*atStart=*/false); // Finalize access pattern expansion for sparse tensor output. diff --git a/mlir/test/Dialect/SparseTensor/sparse_parallel.mlir b/mlir/test/Dialect/SparseTensor/sparse_parallel.mlir index 5e02681..38766b0 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_parallel.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_parallel.mlir @@ -1,13 +1,14 @@ // RUN: mlir-opt %s -sparsification="parallelization-strategy=none" | \ // RUN: FileCheck %s --check-prefix=CHECK-PAR0 -// RUN: mlir-opt %s -sparsification="parallelization-strategy=dense-outer-loop" | \ -// RUN: FileCheck %s --check-prefix=CHECK-PAR1 -// RUN: mlir-opt %s -sparsification="parallelization-strategy=any-storage-outer-loop" | \ -// RUN: FileCheck %s --check-prefix=CHECK-PAR2 -// RUN: mlir-opt %s -sparsification="parallelization-strategy=dense-any-loop" | \ -// RUN: FileCheck %s --check-prefix=CHECK-PAR3 -// RUN: mlir-opt %s -sparsification="parallelization-strategy=any-storage-any-loop" | \ -// RUN: FileCheck %s --check-prefix=CHECK-PAR4 +// FIXME: we do not support vectorization/parallel loops in loop emitter right now +// R_U_N: mlir-opt %s -sparsification="parallelization-strategy=dense-outer-loop" | \ +// R_U_N: FileCheck %s --check-prefix=CHECK-PAR1 +// R_U_N: mlir-opt %s -sparsification="parallelization-strategy=any-storage-outer-loop" | \ +// R_U_N: FileCheck %s --check-prefix=CHECK-PAR2 +// R_U_N: mlir-opt %s -sparsification="parallelization-strategy=dense-any-loop" | \ +// R_U_N: FileCheck %s --check-prefix=CHECK-PAR3 +// R_U_N: mlir-opt %s -sparsification="parallelization-strategy=any-storage-any-loop" | \ +// R_U_N: FileCheck %s --check-prefix=CHECK-PAR4 #DenseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense" ] diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir index a730e77..fca5a33 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir @@ -1,13 +1,5 @@ -// RUN: mlir-opt %s -sparsification="vectorization-strategy=none vl=16" -cse -split-input-file | \ -// RUN: FileCheck %s --check-prefix=CHECK-VEC0 -// RUN: mlir-opt %s -sparsification="vectorization-strategy=dense-inner-loop vl=16" -cse -split-input-file | \ -// RUN: FileCheck %s --check-prefix=CHECK-VEC1 -// RUN: mlir-opt %s -sparsification="vectorization-strategy=any-storage-inner-loop vl=16" -cse -split-input-file | \ -// RUN: FileCheck %s --check-prefix=CHECK-VEC2 -// RUN: mlir-opt %s -sparsification="vectorization-strategy=any-storage-inner-loop vl=16 enable-simd-index32=true" -cse -split-input-file | \ -// RUN: FileCheck %s --check-prefix=CHECK-VEC3 -// RUN: mlir-opt %s -sparsification="vectorization-strategy=any-storage-inner-loop vl=4 enable-vla-vectorization=true" -cse -split-input-file | \ -// RUN: FileCheck %s --check-prefix=CHECK-VEC4 +// RUN: mlir-opt %s -sparsification -cse -split-input-file | \ +// RUN: FileCheck %s #DenseVector = #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }> @@ -21,59 +13,18 @@ } // -// CHECK-VEC0-LABEL: func @scale_d -// CHECK-VEC0-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-VEC0-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK-VEC0-DAG: %[[c1024:.*]] = arith.constant 1024 : index -// CHECK-VEC0: scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] { -// CHECK-VEC0: %[[l:.*]] = memref.load %{{.*}}[%[[i]]] : memref -// CHECK-VEC0: %[[m:.*]] = arith.mulf %[[l]], %{{.*}} : f32 -// CHECK-VEC0: store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32> -// CHECK-VEC0: } -// CHECK-VEC0: return -// -// CHECK-VEC1-LABEL: func @scale_d -// CHECK-VEC1-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-VEC1-DAG: %[[c16:.*]] = arith.constant 16 : index -// CHECK-VEC1-DAG: %[[c1024:.*]] = arith.constant 1024 : index -// CHECK-VEC1: scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] { -// CHECK-VEC1: %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref, vector<16xf32> -// CHECK-VEC1: %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32> -// CHECK-VEC1: %[[m:.*]] = arith.mulf %[[r]], %[[b]] : vector<16xf32> -// CHECK-VEC1: vector.store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32> -// CHECK-VEC1: } -// CHECK-VEC1: return -// -// CHECK-VEC2-LABEL: func @scale_d -// CHECK-VEC2-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-VEC2-DAG: %[[c16:.*]] = arith.constant 16 : index -// CHECK-VEC2-DAG: %[[c1024:.*]] = arith.constant 1024 : index -// CHECK-VEC2: scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] { -// CHECK-VEC2: %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref, vector<16xf32> -// CHECK-VEC2: %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32> -// CHECK-VEC2: %[[m:.*]] = arith.mulf %[[r]], %[[b]] : vector<16xf32> -// CHECK-VEC2: vector.store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32> -// CHECK-VEC2: } -// CHECK-VEC2: return -// -// CHECK-VEC4: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1) -// CHECK-VEC4-LABEL: func @scale_d -// CHECK-VEC4-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-VEC4-DAG: %[[c4:.*]] = arith.constant 4 : index -// CHECK-VEC4-DAG: %[[c1024:.*]] = arith.constant 1024 : index -// CHECK-VEC4-DAG: %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32> -// CHECK-VEC4-DAG: %[[vscale:.*]] = vector.vscale -// CHECK-VEC4: %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index -// CHECK-VEC4: scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[step]] { -// CHECK-VEC4: %[[sub:.*]] = affine.min #[[$map]](%[[c1024]], %[[i]])[%[[step]]] -// CHECK-VEC4: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1> -// CHECK-VEC4: %[[val:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32> -// CHECK-VEC4: %[[scalev:.*]] = vector.broadcast %{{.*}} : f32 to vector<[4]xf32> -// CHECK-VEC4: %[[scaled:.*]] = arith.mulf %[[val]], %[[scalev]] : vector<[4]xf32> -// CHECK-VEC4: vector.maskedstore %{{.*}}[%[[i]]], %[[mask]], %[[scaled]] : memref<1024xf32>, vector<[4]xi1>, vector<[4]xf32> -// CHECK-VEC4: } -// CHECK-VEC4: return +// CHECK-LABEL: func @scale_d +// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[c1024:.*]] = arith.constant 1024 : index +// CHECK: scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] { +// CHECK: %[[l:.*]] = memref.load %{{.*}}[%[[i]]] : memref +// CHECK: %[[m:.*]] = arith.mulf %[[l]], %{{.*}} : f32 +// CHECK: store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32> +// CHECK: } +// CHECK: return // + func.func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor<1024xf32>) -> tensor<1024xf32> { %0 = linalg.generic #trait_scale_d ins(%arga: tensor<1024xf32, #DenseVector>) @@ -104,117 +55,25 @@ func.func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor } // -// CHECK-VEC0-LABEL: func @mul_s -// CHECK-VEC0-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-VEC0-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK-VEC0: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref -// CHECK-VEC0: %[[a:.*]] = arith.extui %[[p]] : i32 to i64 -// CHECK-VEC0: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index -// CHECK-VEC0: %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref -// CHECK-VEC0: %[[b:.*]] = arith.extui %[[r]] : i32 to i64 -// CHECK-VEC0: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index -// CHECK-VEC0: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c1]] { -// CHECK-VEC0: %[[li:.*]] = memref.load %{{.*}}[%[[i]]] : memref -// CHECK-VEC0: %[[zi:.*]] = arith.extui %[[li]] : i32 to i64 -// CHECK-VEC0: %[[ci:.*]] = arith.index_cast %[[zi]] : i64 to index -// CHECK-VEC0: %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref -// CHECK-VEC0: %[[lb:.*]] = memref.load %{{.*}}[%[[ci]]] : memref<1024xf32> -// CHECK-VEC0: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32 -// CHECK-VEC0: store %[[m]], %{{.*}}[%[[ci]]] : memref<1024xf32> -// CHECK-VEC0: } -// CHECK-VEC0: return -// -// CHECK-VEC1-LABEL: func @mul_s -// CHECK-VEC1-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-VEC1-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK-VEC1: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref -// CHECK-VEC1: %[[a:.*]] = arith.extui %[[p]] : i32 to i64 -// CHECK-VEC1: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index -// CHECK-VEC1: %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref -// CHECK-VEC1: %[[b:.*]] = arith.extui %[[r]] : i32 to i64 -// CHECK-VEC1: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index -// CHECK-VEC1: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c1]] { -// CHECK-VEC1: %[[li:.*]] = memref.load %{{.*}}[%[[i]]] : memref -// CHECK-VEC1: %[[zi:.*]] = arith.extui %[[li]] : i32 to i64 -// CHECK-VEC1: %[[ci:.*]] = arith.index_cast %[[zi]] : i64 to index -// CHECK-VEC1: %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref -// CHECK-VEC1: %[[lb:.*]] = memref.load %{{.*}}[%[[ci]]] : memref<1024xf32> -// CHECK-VEC1: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32 -// CHECK-VEC1: store %[[m]], %{{.*}}[%[[ci]]] : memref<1024xf32> -// CHECK-VEC1: } -// CHECK-VEC1: return -// -// CHECK-VEC2: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1) -// CHECK-VEC2-LABEL: func @mul_s -// CHECK-VEC2-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-VEC2-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK-VEC2-DAG: %[[c16:.*]] = arith.constant 16 : index -// CHECK-VEC2: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref -// CHECK-VEC2: %[[a:.*]] = arith.extui %[[p]] : i32 to i64 -// CHECK-VEC2: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index -// CHECK-VEC2: %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref -// CHECK-VEC2: %[[b:.*]] = arith.extui %[[r]] : i32 to i64 -// CHECK-VEC2: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index -// CHECK-VEC2: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] { -// CHECK-VEC2: %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[c16]]] -// CHECK-VEC2: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1> -// CHECK-VEC2: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xi32> into vector<16xi32> -// CHECK-VEC2: %[[zi:.*]] = arith.extui %[[li]] : vector<16xi32> to vector<16xi64> -// CHECK-VEC2: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> -// CHECK-VEC2: %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32> -// CHECK-VEC2: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32> -// CHECK-VEC2: vector.scatter %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> -// CHECK-VEC2: } -// CHECK-VEC2: return -// -// CHECK-VEC3: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1) -// CHECK-VEC3-LABEL: func @mul_s -// CHECK-VEC3-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-VEC3-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK-VEC3-DAG: %[[c16:.*]] = arith.constant 16 : index -// CHECK-VEC3: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref -// CHECK-VEC3: %[[a:.*]] = arith.extui %[[p]] : i32 to i64 -// CHECK-VEC3: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index -// CHECK-VEC3: %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref -// CHECK-VEC3: %[[b:.*]] = arith.extui %[[r]] : i32 to i64 -// CHECK-VEC3: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index -// CHECK-VEC3: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] { -// CHECK-VEC3: %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[c16]]] -// CHECK-VEC3: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1> -// CHECK-VEC3: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xi32> into vector<16xi32> -// CHECK-VEC3: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> -// CHECK-VEC3: %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> -// CHECK-VEC3: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32> -// CHECK-VEC3: vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -// CHECK-VEC3: } -// CHECK-VEC3: return -// -// CHECK-VEC4: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1) -// CHECK-VEC4-LABEL: func @mul_s -// CHECK-VEC4-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-VEC4-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK-VEC4-DAG: %[[c4:.*]] = arith.constant 4 : index -// CHECK-VEC4-DAG: %[[v0i:.*]] = arith.constant dense<0> : vector<[4]xi32> -// CHECK-VEC4-DAG: %[[v0f:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32> -// CHECK-VEC4: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref -// CHECK-VEC4: %[[a:.*]] = arith.extui %[[p]] : i32 to i64 -// CHECK-VEC4: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index -// CHECK-VEC4: %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref -// CHECK-VEC4: %[[b:.*]] = arith.extui %[[r]] : i32 to i64 -// CHECK-VEC4: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index -// CHECK-VEC4: %[[vscale:.*]] = vector.vscale -// CHECK-VEC4: %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index -// CHECK-VEC4: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[step]] { -// CHECK-VEC4: %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[step]]] -// CHECK-VEC4: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1> -// CHECK-VEC4: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0i]] : memref, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32> -// CHECK-VEC4: %[[lii64:.*]] = arith.extui %[[li]] : vector<[4]xi32> to vector<[4]xi64> -// CHECK-VEC4: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0f]] : memref, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32> -// CHECK-VEC4: %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[lii64]]], %[[mask]], %[[v0f]] : memref<1024xf32>, vector<[4]xi64>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32> -// CHECK-VEC4: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<[4]xf32> -// CHECK-VEC4: vector.scatter %{{.*}}[%[[c0]]] [%[[lii64]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<[4]xi64>, vector<[4]xi1>, vector<[4]xf32> -// CHECK-VEC4: } -// CHECK-VEC4: return +// CHECK-LABEL: func @mul_s +// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index +// CHECK: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref +// CHECK: %[[a:.*]] = arith.extui %[[p]] : i32 to i64 +// CHECK: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index +// CHECK: %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref +// CHECK: %[[b:.*]] = arith.extui %[[r]] : i32 to i64 +// CHECK: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index +// CHECK: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c1]] { +// CHECK: %[[li:.*]] = memref.load %{{.*}}[%[[i]]] : memref +// CHECK: %[[zi:.*]] = arith.extui %[[li]] : i32 to i64 +// CHECK: %[[ci:.*]] = arith.index_cast %[[zi]] : i64 to index +// CHECK: %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref +// CHECK: %[[lb:.*]] = memref.load %{{.*}}[%[[ci]]] : memref<1024xf32> +// CHECK: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32 +// CHECK: store %[[m]], %{{.*}}[%[[ci]]] : memref<1024xf32> +// CHECK: } +// CHECK: return // func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>, %argb: tensor<1024xf32>, %argx: tensor<1024xf32>) -> tensor<1024xf32> { %0 = linalg.generic #trait_mul_s @@ -242,75 +101,18 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>, %argb: tensor<1024xf32> } // -// CHECK-VEC0-LABEL: func @reduction_d -// CHECK-VEC0-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-VEC0-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK-VEC0-DAG: %[[c1024:.*]] = arith.constant 1024 : index -// CHECK-VEC0: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] iter_args(%[[red_in:.*]] = %{{.*}}) -> (f32) { -// CHECK-VEC0: %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref -// CHECK-VEC0: %[[lb:.*]] = memref.load %{{.*}}[%[[i]]] : memref<1024xf32> -// CHECK-VEC0: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32 -// CHECK-VEC0: %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : f32 -// CHECK-VEC0: scf.yield %[[a]] : f32 -// CHECK-VEC0: } -// CHECK-VEC0: return -// -// CHECK-VEC1-LABEL: func @reduction_d -// CHECK-VEC1-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-VEC1-DAG: %[[c16:.*]] = arith.constant 16 : index -// CHECK-VEC1-DAG: %[[c1024:.*]] = arith.constant 1024 : index -// CHECK-VEC1-DAG: %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32> -// CHECK-VEC1: %[[l:.*]] = memref.load %{{.*}}[] : memref -// CHECK-VEC1: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32> -// CHECK-VEC1: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) { -// CHECK-VEC1: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref, vector<16xf32> -// CHECK-VEC1: %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32> -// CHECK-VEC1: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32> -// CHECK-VEC1: %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : vector<16xf32> -// CHECK-VEC1: scf.yield %[[a]] : vector<16xf32> -// CHECK-VEC1: } -// CHECK-VEC1: %{{.*}} = vector.reduction , %[[red]] : vector<16xf32> into f32 -// CHECK-VEC1: return -// -// CHECK-VEC2-LABEL: func @reduction_d -// CHECK-VEC2-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-VEC2-DAG: %[[c16:.*]] = arith.constant 16 : index -// CHECK-VEC2-DAG: %[[c1024:.*]] = arith.constant 1024 : index -// CHECK-VEC2-DAG: %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32> -// CHECK-VEC2: %[[l:.*]] = memref.load %{{.*}}[] : memref -// CHECK-VEC2: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32> -// CHECK-VEC2: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) { -// CHECK-VEC2: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref, vector<16xf32> -// CHECK-VEC2: %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32> -// CHECK-VEC2: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32> -// CHECK-VEC2: %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : vector<16xf32> -// CHECK-VEC2: scf.yield %[[a]] : vector<16xf32> -// CHECK-VEC2: } -// CHECK-VEC2: %{{.*}} = vector.reduction , %[[red]] : vector<16xf32> into f32 -// CHECK-VEC2: return -// -// CHECK-VEC4: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1) -// CHECK-VEC4-LABEL: func @reduction_d -// CHECK-VEC4-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-VEC4-DAG: %[[c4:.*]] = arith.constant 4 : index -// CHECK-VEC4-DAG: %[[c1024:.*]] = arith.constant 1024 : index -// CHECK-VEC4-DAG: %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32> -// CHECK-VEC4: %[[l:.*]] = memref.load %{{.*}}[] : memref -// CHECK-VEC4: %[[vscale:.*]] = vector.vscale -// CHECK-VEC4: %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index -// CHECK-VEC4: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<[4]xf32> -// CHECK-VEC4: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[step]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<[4]xf32>) { -// CHECK-VEC4: %[[sub:.*]] = affine.min #[[$map]](%[[c1024]], %[[i]])[%[[step]]] -// CHECK-VEC4: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1> -// CHECK-VEC4: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32> -// CHECK-VEC4: %[[lb:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref<1024xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32> -// CHECK-VEC4: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<[4]xf32> -// CHECK-VEC4: %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : vector<[4]xf32> -// CHECK-VEC4: %[[sa:.*]] = arith.select %[[mask]], %[[a]], %[[red_in]] : vector<[4]xi1>, vector<[4]xf32> -// CHECK-VEC4: scf.yield %[[sa]] : vector<[4]xf32> -// CHECK-VEC4: } -// CHECK-VEC4: %{{.*}} = vector.reduction , %[[red]] : vector<[4]xf32> into f32 -// CHECK-VEC4: return +// CHECK-LABEL: func @reduction_d +// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[c1024:.*]] = arith.constant 1024 : index +// CHECK: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] iter_args(%[[red_in:.*]] = %{{.*}}) -> (f32) { +// CHECK: %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref +// CHECK: %[[lb:.*]] = memref.load %{{.*}}[%[[i]]] : memref<1024xf32> +// CHECK: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32 +// CHECK: %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : f32 +// CHECK: scf.yield %[[a]] : f32 +// CHECK: } +// CHECK: return // func.func @reduction_d(%arga: tensor<1024xf32, #DenseVector>, %argb: tensor<1024xf32>, %argx: tensor) -> tensor { %0 = linalg.generic #trait_reduction_d @@ -343,137 +145,29 @@ func.func @reduction_d(%arga: tensor<1024xf32, #DenseVector>, %argb: tensor<1024 } // -// CHECK-VEC0-LABEL: func @mul_ds -// CHECK-VEC0-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-VEC0-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK-VEC0-DAG: %[[c512:.*]] = arith.constant 512 : index -// CHECK-VEC0: scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] { -// CHECK-VEC0: %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref -// CHECK-VEC0: %[[a:.*]] = arith.extui %[[p]] : i32 to i64 -// CHECK-VEC0: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index -// CHECK-VEC0: %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index -// CHECK-VEC0: %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref -// CHECK-VEC0: %[[b:.*]] = arith.extui %[[r]] : i32 to i64 -// CHECK-VEC0: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index -// CHECK-VEC0: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c1]] { -// CHECK-VEC0: %[[lj:.*]] = memref.load %{{.*}}[%[[j]]] : memref -// CHECK-VEC0: %[[zj:.*]] = arith.extui %[[lj]] : i32 to i64 -// CHECK-VEC0: %[[cj:.*]] = arith.index_cast %[[zj]] : i64 to index -// CHECK-VEC0: %[[la:.*]] = memref.load %{{.*}}[%[[j]]] : memref -// CHECK-VEC0: %[[lb:.*]] = memref.load %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32> -// CHECK-VEC0: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32 -// CHECK-VEC0: store %[[m]], %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32> -// CHECK-VEC0: } -// CHECK-VEC0: } -// CHECK-VEC0: return -// -// CHECK-VEC1-LABEL: func @mul_ds -// CHECK-VEC1-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-VEC1-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK-VEC1-DAG: %[[c512:.*]] = arith.constant 512 : index -// CHECK-VEC1: scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] { -// CHECK-VEC1: %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref -// CHECK-VEC1: %[[a:.*]] = arith.extui %[[p]] : i32 to i64 -// CHECK-VEC1: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index -// CHECK-VEC1: %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index -// CHECK-VEC1: %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref -// CHECK-VEC1: %[[b:.*]] = arith.extui %[[r]] : i32 to i64 -// CHECK-VEC1: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index -// CHECK-VEC1: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c1]] { -// CHECK-VEC1: %[[lj:.*]] = memref.load %{{.*}}[%[[j]]] : memref -// CHECK-VEC1: %[[zj:.*]] = arith.extui %[[lj]] : i32 to i64 -// CHECK-VEC1: %[[cj:.*]] = arith.index_cast %[[zj]] : i64 to index -// CHECK-VEC1: %[[la:.*]] = memref.load %{{.*}}[%[[j]]] : memref -// CHECK-VEC1: %[[lb:.*]] = memref.load %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32> -// CHECK-VEC1: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32 -// CHECK-VEC1: store %[[m]], %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32> -// CHECK-VEC1: } -// CHECK-VEC1: } -// CHECK-VEC1: return -// -// CHECK-VEC2: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1) -// CHECK-VEC2-LABEL: func @mul_ds -// CHECK-VEC2-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-VEC2-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK-VEC2-DAG: %[[c16:.*]] = arith.constant 16 : index -// CHECK-VEC2-DAG: %[[c512:.*]] = arith.constant 512 : index -// CHECK-VEC2: scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] { -// CHECK-VEC2: %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref -// CHECK-VEC2: %[[a:.*]] = arith.extui %[[p]] : i32 to i64 -// CHECK-VEC2: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index -// CHECK-VEC2: %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index -// CHECK-VEC2: %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref -// CHECK-VEC2: %[[b:.*]] = arith.extui %[[r]] : i32 to i64 -// CHECK-VEC2: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index -// CHECK-VEC2: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] { -// CHECK-VEC2: %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[j]])[%[[c16]]] -// CHECK-VEC2: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1> -// CHECK-VEC2: %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xi32> into vector<16xi32> -// CHECK-VEC2: %[[zj:.*]] = arith.extui %[[lj]] : vector<16xi32> to vector<16xi64> -// CHECK-VEC2: %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> -// CHECK-VEC2: %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[zj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32> -// CHECK-VEC2: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32> -// CHECK-VEC2: vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[zj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> -// CHECK-VEC2: } -// CHECK-VEC2: } -// CHECK-VEC2: return -// -// CHECK-VEC3: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1) -// CHECK-VEC3-LABEL: func @mul_ds -// CHECK-VEC3-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-VEC3-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK-VEC3-DAG: %[[c16:.*]] = arith.constant 16 : index -// CHECK-VEC3-DAG: %[[c512:.*]] = arith.constant 512 : index -// CHECK-VEC3: scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] { -// CHECK-VEC3: %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref -// CHECK-VEC3: %[[a:.*]] = arith.extui %[[p]] : i32 to i64 -// CHECK-VEC3: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index -// CHECK-VEC3: %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index -// CHECK-VEC3: %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref -// CHECK-VEC3: %[[b:.*]] = arith.extui %[[r]] : i32 to i64 -// CHECK-VEC3: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index -// CHECK-VEC3: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] { -// CHECK-VEC3: %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[j]])[%[[c16]]] -// CHECK-VEC3: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1> -// CHECK-VEC3: %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xi32> into vector<16xi32> -// CHECK-VEC3: %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> -// CHECK-VEC3: %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> -// CHECK-VEC3: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32> -// CHECK-VEC3: vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -// CHECK-VEC3: } -// CHECK-VEC3: } -// CHECK-VEC3: return -// -// CHECK-VEC4: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1) -// CHECK-VEC4-LABEL: func @mul_ds -// CHECK-VEC4-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-VEC4-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK-VEC4-DAG: %[[c4:.*]] = arith.constant 4 : index -// CHECK-VEC4-DAG: %[[c512:.*]] = arith.constant 512 : index -// CHECK-VEC4-DAG: %[[v0i:.*]] = arith.constant dense<0> : vector<[4]xi32> -// CHECK-VEC4-DAG: %[[v0f:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32> -// CHECK-VEC4: scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] { -// CHECK-VEC4: %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref -// CHECK-VEC4: %[[a:.*]] = arith.extui %[[p]] : i32 to i64 -// CHECK-VEC4: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index -// CHECK-VEC4: %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index -// CHECK-VEC4: %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref -// CHECK-VEC4: %[[b:.*]] = arith.extui %[[r]] : i32 to i64 -// CHECK-VEC4: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index -// CHECK-VEC4: %[[vscale:.*]] = vector.vscale -// CHECK-VEC4: %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index -// CHECK-VEC4: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[step]] { -// CHECK-VEC4: %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[j]])[%[[step]]] -// CHECK-VEC4: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1> -// CHECK-VEC4: %[[lji32:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %[[v0i]] : memref, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32> -// CHECK-VEC4: %[[lj:.*]] = arith.extui %[[lji32]] : vector<[4]xi32> to vector<[4]xi64> -// CHECK-VEC4: %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %[[v0f]] : memref, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32> -// CHECK-VEC4: %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[v0f]] : memref<512x1024xf32>, vector<[4]xi64>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32> -// CHECK-VEC4: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<[4]xf32> -// CHECK-VEC4: vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<[4]xi64>, vector<[4]xi1>, vector<[4]xf32> -// CHECK-VEC4: } -// CHECK-VEC4: } -// CHECK-VEC4: return +// CHECK-LABEL: func @mul_ds +// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[c512:.*]] = arith.constant 512 : index +// CHECK: scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] { +// CHECK: %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref +// CHECK: %[[a:.*]] = arith.extui %[[p]] : i32 to i64 +// CHECK: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index +// CHECK: %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index +// CHECK: %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref +// CHECK: %[[b:.*]] = arith.extui %[[r]] : i32 to i64 +// CHECK: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index +// CHECK: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c1]] { +// CHECK: %[[lj:.*]] = memref.load %{{.*}}[%[[j]]] : memref +// CHECK: %[[zj:.*]] = arith.extui %[[lj]] : i32 to i64 +// CHECK: %[[cj:.*]] = arith.index_cast %[[zj]] : i64 to index +// CHECK: %[[la:.*]] = memref.load %{{.*}}[%[[j]]] : memref +// CHECK: %[[lb:.*]] = memref.load %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32> +// CHECK: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32 +// CHECK: store %[[m]], %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32> +// CHECK: } +// CHECK: } +// CHECK: return // func.func @mul_ds(%arga: tensor<512x1024xf32, #SparseMatrix>, %argb: tensor<512x1024xf32>, %argx: tensor<512x1024xf32>) -> tensor<512x1024xf32> { %0 = linalg.generic #trait_mul_ds @@ -500,89 +194,23 @@ func.func @mul_ds(%arga: tensor<512x1024xf32, #SparseMatrix>, %argb: tensor<512x } // -// CHECK-VEC0-LABEL: func @add_dense -// CHECK-VEC0-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-VEC0-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK-VEC0-DAG: %[[c32:.*]] = arith.constant 32 : index -// CHECK-VEC0: scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] { -// CHECK-VEC0: %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref -// CHECK-VEC0: %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index -// CHECK-VEC0: %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref -// CHECK-VEC0: scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[c1]] { -// CHECK-VEC0: %[[j:.*]] = memref.load %{{.*}}[%[[jj]]] : memref -// CHECK-VEC0: %[[x:.*]] = memref.load %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64> -// CHECK-VEC0: %[[a:.*]] = memref.load %{{.*}}[%[[jj]]] : memref -// CHECK-VEC0: %[[s:.*]] = arith.addf %[[x]], %[[a]] : f64 -// CHECK-VEC0: memref.store %[[s]], %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64> -// CHECK-VEC0: } -// CHECK-VEC0: } -// CHECK-VEC0: return -// -// CHECK-VEC1-LABEL: func @add_dense -// CHECK-VEC1-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-VEC1-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK-VEC1-DAG: %[[c32:.*]] = arith.constant 32 : index -// CHECK-VEC1: scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] { -// CHECK-VEC1: %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref -// CHECK-VEC1: %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index -// CHECK-VEC1: %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref -// CHECK-VEC1: scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[c1]] { -// CHECK-VEC1: %[[j:.*]] = memref.load %{{.*}}[%[[jj]]] : memref -// CHECK-VEC1: %[[x:.*]] = memref.load %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64> -// CHECK-VEC1: %[[a:.*]] = memref.load %{{.*}}[%[[jj]]] : memref -// CHECK-VEC1: %[[s:.*]] = arith.addf %[[x]], %[[a]] : f64 -// CHECK-VEC1: memref.store %[[s]], %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64> -// CHECK-VEC1: } -// CHECK-VEC1: } -// CHECK-VEC1: return -// -// CHECK-VEC2: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1) -// CHECK-VEC2-LABEL: func @add_dense -// CHECK-VEC2-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-VEC2-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK-VEC2-DAG: %[[c16:.*]] = arith.constant 16 : index -// CHECK-VEC2-DAG: %[[c32:.*]] = arith.constant 32 : index -// CHECK-VEC2: scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] { -// CHECK-VEC2: %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref -// CHECK-VEC2: %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index -// CHECK-VEC2: %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref -// CHECK-VEC2: scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[c16]] { -// CHECK-VEC2: %[[sub:.*]] = affine.min #[[$map]](%[[hi]], %[[jj]])[%[[c16]]] -// CHECK-VEC2: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1> -// CHECK-VEC2: %[[j:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %{{.*}} : memref -// CHECK-VEC2: %[[x:.*]] = vector.gather %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %{{.*}} : memref<33x64xf64> -// CHECK-VEC2: %[[a:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %{{.*}} : memref -// CHECK-VEC2: %[[s:.*]] = arith.addf %[[x]], %[[a]] : vector<16xf64> -// CHECK-VEC2: vector.scatter %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %[[s]] : memref<33x64xf64> -// CHECK-VEC2: } -// CHECK-VEC2: } -// CHECK-VEC2: return -// -// CHECK-VEC4: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (s0, d0 - d1) -// CHECK-VEC4-LABEL: func @add_dense -// CHECK-VEC4-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-VEC4-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK-VEC4-DAG: %[[c4:.*]] = arith.constant 4 : index -// CHECK-VEC4-DAG: %[[c32:.*]] = arith.constant 32 : index -// CHECK-VEC4-DAG: %[[v0idx:.*]] = arith.constant dense<0> : vector<[4]xindex> -// CHECK-VEC4-DAG: %[[v0f64:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf64> -// CHECK-VEC4: scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] { -// CHECK-VEC4: %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref -// CHECK-VEC4: %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index -// CHECK-VEC4: %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref -// CHECK-VEC4: %[[vscale:.*]] = vector.vscale -// CHECK-VEC4: %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index -// CHECK-VEC4: scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[step]] { -// CHECK-VEC4: %[[sub:.*]] = affine.min #[[$map]](%[[hi]], %[[jj]])[%[[step]]] -// CHECK-VEC4: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1> -// CHECK-VEC4: %[[j:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %[[v0idx]] : memref -// CHECK-VEC4: %[[x:.*]] = vector.gather %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %[[v0f64]] : memref<33x64xf64> -// CHECK-VEC4: %[[a:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %[[v0f64]] : memref -// CHECK-VEC4: %[[s:.*]] = arith.addf %[[x]], %[[a]] : vector<[4]xf64> -// CHECK-VEC4: vector.scatter %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %[[s]] : memref<33x64xf64> -// CHECK-VEC4: } -// CHECK-VEC4: } -// CHECK-VEC4: return +// CHECK-LABEL: func @add_dense +// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index +// CHECK: scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] { +// CHECK: %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref +// CHECK: %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index +// CHECK: %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref +// CHECK: scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[c1]] { +// CHECK: %[[j:.*]] = memref.load %{{.*}}[%[[jj]]] : memref +// CHECK: %[[x:.*]] = memref.load %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64> +// CHECK: %[[a:.*]] = memref.load %{{.*}}[%[[jj]]] : memref +// CHECK: %[[s:.*]] = arith.addf %[[x]], %[[a]] : f64 +// CHECK: memref.store %[[s]], %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64> +// CHECK: } +// CHECK: } +// CHECK: return // func.func @add_dense(%arga: tensor<32x64xf64, #SparseMatrix>, %argx: tensor<33x64xf64>) -> tensor<33x64xf64> { diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir deleted file mode 100644 index 6278421..0000000 --- a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir +++ /dev/null @@ -1,128 +0,0 @@ -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py - -// RUN: mlir-opt %s -sparsification="vectorization-strategy=any-storage-inner-loop vl=8" -canonicalize | \ -// RUN: FileCheck %s - -#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["dense","compressed"]}> - -#trait = { - indexing_maps = [ - affine_map<(i,j) -> (i,j)>, // a (in) - affine_map<(i,j) -> (i,j)>, // b (in) - affine_map<(i,j) -> ()> // x (out) - ], - iterator_types = ["reduction", "reduction"] -} - -// Verifies that the SIMD reductions in the two for-loops after the -// while-loop are chained before horizontally reducing these back to scalar. -// -// CHECK-LABEL: func @sparse_matrix_sum( -// CHECK-SAME: %[[VAL_0:.*]]: tensor, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<64x32xf64, #sparse_tensor.encoding<{{{.*}}}>>, -// CHECK-SAME: %[[VAL_2:.*]]: tensor<64x32xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor { -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<0.000000e+00> : vector<8xf64> -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 64 : index -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 1 : index -// CHECK: %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 1 : index} : tensor<64x32xf64, #sparse_tensor.encoding<{{{.*}}}>> -// CHECK: %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 1 : index} : tensor<64x32xf64, #sparse_tensor.encoding<{{{.*}}}>> -// CHECK: %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<64x32xf64, #sparse_tensor.encoding<{{{.*}}}>> -// CHECK: %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_2]] {dimension = 1 : index} : tensor<64x32xf64, #sparse_tensor.encoding<{{{.*}}}>> -// CHECK: %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_2]] {dimension = 1 : index} : tensor<64x32xf64, #sparse_tensor.encoding<{{{.*}}}>> -// CHECK: %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<64x32xf64, #sparse_tensor.encoding<{{{.*}}}>> -// CHECK: %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_0]] : memref -// CHECK: %[[VAL_16:.*]] = tensor.extract %[[VAL_0]][] : tensor -// CHECK: %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (f64) { -// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref -// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_18]], %[[VAL_8]] : index -// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref -// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_18]]] : memref -// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_18]], %[[VAL_8]] : index -// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_24]]] : memref -// CHECK: %[[VAL_26:.*]]:3 = scf.while (%[[VAL_27:.*]] = %[[VAL_20]], %[[VAL_28:.*]] = %[[VAL_23]], %[[VAL_29:.*]] = %[[VAL_19]]) : (index, index, f64) -> (index, index, f64) { -// CHECK: %[[VAL_30:.*]] = arith.cmpi ult, %[[VAL_27]], %[[VAL_22]] : index -// CHECK: %[[VAL_31:.*]] = arith.cmpi ult, %[[VAL_28]], %[[VAL_25]] : index -// CHECK: %[[VAL_32:.*]] = arith.andi %[[VAL_30]], %[[VAL_31]] : i1 -// CHECK: scf.condition(%[[VAL_32]]) %[[VAL_27]], %[[VAL_28]], %[[VAL_29]] : index, index, f64 -// CHECK: } do { -// CHECK: ^bb0(%[[VAL_33:.*]]: index, %[[VAL_34:.*]]: index, %[[VAL_35:.*]]: f64): -// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_33]]] : memref -// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_34]]] : memref -// CHECK: %[[VAL_38:.*]] = arith.cmpi ult, %[[VAL_37]], %[[VAL_36]] : index -// CHECK: %[[VAL_39:.*]] = arith.select %[[VAL_38]], %[[VAL_37]], %[[VAL_36]] : index -// CHECK: %[[VAL_40:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_39]] : index -// CHECK: %[[VAL_41:.*]] = arith.cmpi eq, %[[VAL_37]], %[[VAL_39]] : index -// CHECK: %[[VAL_42:.*]] = arith.andi %[[VAL_40]], %[[VAL_41]] : i1 -// CHECK: %[[VAL_43:.*]] = scf.if %[[VAL_42]] -> (f64) { -// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_33]]] : memref -// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_34]]] : memref -// CHECK: %[[VAL_46:.*]] = arith.addf %[[VAL_44]], %[[VAL_45]] : f64 -// CHECK: %[[VAL_47:.*]] = arith.addf %[[VAL_35]], %[[VAL_46]] : f64 -// CHECK: scf.yield %[[VAL_47]] : f64 -// CHECK: } else { -// CHECK: %[[VAL_48:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_39]] : index -// CHECK: %[[VAL_49:.*]] = scf.if %[[VAL_48]] -> (f64) { -// CHECK: %[[VAL_50:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_33]]] : memref -// CHECK: %[[VAL_51:.*]] = arith.addf %[[VAL_35]], %[[VAL_50]] : f64 -// CHECK: scf.yield %[[VAL_51]] : f64 -// CHECK: } else { -// CHECK: %[[VAL_52:.*]] = arith.cmpi eq, %[[VAL_37]], %[[VAL_39]] : index -// CHECK: %[[VAL_53:.*]] = scf.if %[[VAL_52]] -> (f64) { -// CHECK: %[[VAL_54:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_34]]] : memref -// CHECK: %[[VAL_55:.*]] = arith.addf %[[VAL_35]], %[[VAL_54]] : f64 -// CHECK: scf.yield %[[VAL_55]] : f64 -// CHECK: } else { -// CHECK: scf.yield %[[VAL_35]] : f64 -// CHECK: } -// CHECK: scf.yield %[[VAL_56:.*]] : f64 -// CHECK: } -// CHECK: scf.yield %[[VAL_57:.*]] : f64 -// CHECK: } -// CHECK: %[[VAL_58:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_39]] : index -// CHECK: %[[VAL_59:.*]] = arith.addi %[[VAL_33]], %[[VAL_8]] : index -// CHECK: %[[VAL_60:.*]] = arith.select %[[VAL_58]], %[[VAL_59]], %[[VAL_33]] : index -// CHECK: %[[VAL_61:.*]] = arith.cmpi eq, %[[VAL_37]], %[[VAL_39]] : index -// CHECK: %[[VAL_62:.*]] = arith.addi %[[VAL_34]], %[[VAL_8]] : index -// CHECK: %[[VAL_63:.*]] = arith.select %[[VAL_61]], %[[VAL_62]], %[[VAL_34]] : index -// CHECK: scf.yield %[[VAL_60]], %[[VAL_63]], %[[VAL_64:.*]] : index, index, f64 -// CHECK: } -// CHECK: %[[VAL_65:.*]] = vector.insertelement %[[VAL_66:.*]]#2, %[[VAL_3]]{{\[}}%[[VAL_6]] : index] : vector<8xf64> -// CHECK: %[[VAL_67:.*]] = scf.for %[[VAL_68:.*]] = %[[VAL_66]]#0 to %[[VAL_22]] step %[[VAL_4]] iter_args(%[[VAL_69:.*]] = %[[VAL_65]]) -> (vector<8xf64>) { -// CHECK: %[[VAL_70:.*]] = affine.min #map(%[[VAL_22]], %[[VAL_68]]) -// CHECK: %[[VAL_71:.*]] = vector.create_mask %[[VAL_70]] : vector<8xi1> -// CHECK: %[[VAL_72:.*]] = vector.maskedload %[[VAL_11]]{{\[}}%[[VAL_68]]], %[[VAL_71]], %[[VAL_3]] : memref, vector<8xi1>, vector<8xf64> into vector<8xf64> -// CHECK: %[[VAL_73:.*]] = arith.addf %[[VAL_69]], %[[VAL_72]] : vector<8xf64> -// CHECK: %[[VAL_74:.*]] = arith.select %[[VAL_71]], %[[VAL_73]], %[[VAL_69]] : vector<8xi1>, vector<8xf64> -// CHECK: scf.yield %[[VAL_74]] : vector<8xf64> -// CHECK: } -// CHECK: %[[VAL_75:.*]] = scf.for %[[VAL_76:.*]] = %[[VAL_66]]#1 to %[[VAL_25]] step %[[VAL_4]] iter_args(%[[VAL_77:.*]] = %[[VAL_78:.*]]) -> (vector<8xf64>) { -// CHECK: %[[VAL_79:.*]] = affine.min #map(%[[VAL_25]], %[[VAL_76]]) -// CHECK: %[[VAL_80:.*]] = vector.create_mask %[[VAL_79]] : vector<8xi1> -// CHECK: %[[VAL_81:.*]] = vector.maskedload %[[VAL_14]]{{\[}}%[[VAL_76]]], %[[VAL_80]], %[[VAL_3]] : memref, vector<8xi1>, vector<8xf64> into vector<8xf64> -// CHECK: %[[VAL_82:.*]] = arith.addf %[[VAL_77]], %[[VAL_81]] : vector<8xf64> -// CHECK: %[[VAL_83:.*]] = arith.select %[[VAL_80]], %[[VAL_82]], %[[VAL_77]] : vector<8xi1>, vector<8xf64> -// CHECK: scf.yield %[[VAL_83]] : vector<8xf64> -// CHECK: } -// CHECK: %[[VAL_84:.*]] = vector.reduction , %[[VAL_85:.*]] : vector<8xf64> into f64 -// CHECK: scf.yield %[[VAL_84]] : f64 -// CHECK: } -// CHECK: memref.store %[[VAL_86:.*]], %[[VAL_15]][] : memref -// CHECK: %[[VAL_87:.*]] = bufferization.to_tensor %[[VAL_15]] : memref -// CHECK: return %[[VAL_87]] : tensor -// CHECK: } -func.func @sparse_matrix_sum(%argx: tensor, - %arga: tensor<64x32xf64, #SparseMatrix>, - %argb: tensor<64x32xf64, #SparseMatrix>) -> tensor { - %0 = linalg.generic #trait - ins(%arga, %argb: tensor<64x32xf64, #SparseMatrix>, - tensor<64x32xf64, #SparseMatrix>) - outs(%argx: tensor) { - ^bb(%a: f64, %b: f64, %x: f64): - %m = arith.addf %a, %b : f64 - %t = arith.addf %x, %m : f64 - linalg.yield %t : f64 - } -> tensor - return %0 : tensor -} diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir deleted file mode 100644 index 70f54ee..0000000 --- a/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir +++ /dev/null @@ -1,126 +0,0 @@ -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py - -// The script is designed to make adding checks to -// a test case fast, it is *not* designed to be authoritative -// about what constitutes a good test! The CHECK should be -// minimized and named to reflect the test intent. - -// RUN: mlir-opt %s -sparsification="vectorization-strategy=any-storage-inner-loop vl=8" -canonicalize | \ -// RUN: FileCheck %s - -#SparseVector = #sparse_tensor.encoding<{ - dimLevelType = ["compressed"] -}> - -#trait_1d = { - indexing_maps = [ - affine_map<(i) -> (i)>, // a - affine_map<(i) -> (i)> // x (out) - ], - iterator_types = ["parallel"], - doc = "X(i) = a(i) op i" -} - -// CHECK-LABEL: func @sparse_index_1d_conj( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<8xi64> { -// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<0> : vector<8xi64> -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<0> : vector<8xindex> -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : i64 -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref -// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref -// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref -// CHECK-DAG: %[[VAL_10a:.*]] = tensor.empty() : tensor<8xi64> -// CHECK-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_10a]] : memref<8xi64> -// CHECK-DAG: linalg.fill ins(%[[VAL_5]] : i64) outs(%[[VAL_10]] : memref<8xi64>) -// CHECK-DAG: %[[VAL_11:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref -// CHECK-DAG: %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref -// CHECK: scf.for %[[VAL_13:.*]] = %[[VAL_11]] to %[[VAL_12]] step %[[VAL_3]] { -// CHECK: %[[VAL_14:.*]] = affine.min #map0(%[[VAL_13]]){{\[}}%[[VAL_12]]] -// CHECK: %[[VAL_15:.*]] = vector.create_mask %[[VAL_14]] : vector<8xi1> -// CHECK: %[[VAL_16:.*]] = vector.maskedload %[[VAL_8]]{{\[}}%[[VAL_13]]], %[[VAL_15]], %[[VAL_2]] : memref, vector<8xi1>, vector<8xindex> into vector<8xindex> -// CHECK: %[[VAL_17:.*]] = vector.maskedload %[[VAL_9]]{{\[}}%[[VAL_13]]], %[[VAL_15]], %[[VAL_1]] : memref, vector<8xi1>, vector<8xi64> into vector<8xi64> -// CHECK: %[[VAL_18:.*]] = arith.index_cast %[[VAL_16]] : vector<8xindex> to vector<8xi64> -// CHECK: %[[VAL_19:.*]] = arith.muli %[[VAL_17]], %[[VAL_18]] : vector<8xi64> -// CHECK: vector.scatter %[[VAL_10]]{{\[}}%[[VAL_6]]] {{\[}}%[[VAL_16]]], %[[VAL_15]], %[[VAL_19]] : memref<8xi64>, vector<8xindex>, vector<8xi1>, vector<8xi64> -// CHECK: } -// CHECK: %[[VAL_20:.*]] = bufferization.to_tensor %[[VAL_10]] : memref<8xi64> -// CHECK: return %[[VAL_20]] : tensor<8xi64> -// CHECK: } -func.func @sparse_index_1d_conj(%arga: tensor<8xi64, #SparseVector>) -> tensor<8xi64> { - %init = tensor.empty() : tensor<8xi64> - %r = linalg.generic #trait_1d - ins(%arga: tensor<8xi64, #SparseVector>) - outs(%init: tensor<8xi64>) { - ^bb(%a: i64, %x: i64): - %i = linalg.index 0 : index - %ii = arith.index_cast %i : index to i64 - %m1 = arith.muli %a, %ii : i64 - linalg.yield %m1 : i64 - } -> tensor<8xi64> - return %r : tensor<8xi64> -} - -// CHECK-LABEL: func @sparse_index_1d_disj( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<8xi64> { -// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex> -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : i64 -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref -// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref -// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref -// CHECK-DAG: %[[VAL_9a:.*]] = tensor.empty() : tensor<8xi64> -// CHECK-DAG: %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_9a]] : memref<8xi64> -// CHECK-DAG: linalg.fill ins(%[[VAL_3]] : i64) outs(%[[VAL_9]] : memref<8xi64>) -// CHECK-DAG: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref -// CHECK-DAG: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_2]]] : memref -// CHECK: %[[VAL_12:.*]]:2 = scf.while (%[[VAL_13:.*]] = %[[VAL_10]], %[[VAL_14:.*]] = %[[VAL_5]]) : (index, index) -> (index, index) { -// CHECK: %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_13]], %[[VAL_11]] : index -// CHECK: scf.condition(%[[VAL_15]]) %[[VAL_13]], %[[VAL_14]] : index, index -// CHECK: } do { -// CHECK: ^bb0(%[[VAL_16:.*]]: index, %[[VAL_17:.*]]: index): -// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref -// CHECK: %[[VAL_19:.*]] = arith.cmpi eq, %[[VAL_18]], %[[VAL_17]] : index -// CHECK: scf.if %[[VAL_19]] { -// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref -// CHECK: %[[VAL_21:.*]] = arith.index_cast %[[VAL_17]] : index to i64 -// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_20]], %[[VAL_21]] : i64 -// CHECK: memref.store %[[VAL_22]], %[[VAL_9]]{{\[}}%[[VAL_17]]] : memref<8xi64> -// CHECK: } else { -// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_17]] : index to i64 -// CHECK: memref.store %[[VAL_23]], %[[VAL_9]]{{\[}}%[[VAL_17]]] : memref<8xi64> -// CHECK: } -// CHECK: %[[VAL_24:.*]] = arith.cmpi eq, %[[VAL_18]], %[[VAL_17]] : index -// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_16]], %[[VAL_2]] : index -// CHECK: %[[VAL_26:.*]] = arith.select %[[VAL_24]], %[[VAL_25]], %[[VAL_16]] : index -// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_17]], %[[VAL_2]] : index -// CHECK: scf.yield %[[VAL_26]], %[[VAL_27]] : index, index -// CHECK: } -// CHECK: scf.for %[[VAL_28:.*]] = %[[VAL_29:.*]]#1 to %[[VAL_4]] step %[[VAL_4]] { -// CHECK: %[[VAL_30:.*]] = affine.min #map1(%[[VAL_28]]) -// CHECK: %[[VAL_31:.*]] = vector.create_mask %[[VAL_30]] : vector<8xi1> -// CHECK: %[[VAL_32:.*]] = vector.broadcast %[[VAL_28]] : index to vector<8xindex> -// CHECK: %[[VAL_33:.*]] = arith.addi %[[VAL_32]], %[[VAL_1]] : vector<8xindex> -// CHECK: %[[VAL_34:.*]] = arith.index_cast %[[VAL_33]] : vector<8xindex> to vector<8xi64> -// CHECK: vector.maskedstore %[[VAL_9]]{{\[}}%[[VAL_28]]], %[[VAL_31]], %[[VAL_34]] : memref<8xi64>, vector<8xi1>, vector<8xi64> -// CHECK: } -// CHECK: %[[VAL_35:.*]] = bufferization.to_tensor %[[VAL_9]] : memref<8xi64> -// CHECK: return %[[VAL_35]] : tensor<8xi64> -// CHECK: } -func.func @sparse_index_1d_disj(%arga: tensor<8xi64, #SparseVector>) -> tensor<8xi64> { - %init = tensor.empty() : tensor<8xi64> - %r = linalg.generic #trait_1d - ins(%arga: tensor<8xi64, #SparseVector>) - outs(%init: tensor<8xi64>) { - ^bb(%a: i64, %x: i64): - %i = linalg.index 0 : index - %ii = arith.index_cast %i : index to i64 - %m1 = arith.addi %a, %ii : i64 - linalg.yield %m1 : i64 - } -> tensor<8xi64> - return %r : tensor<8xi64> -} diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir deleted file mode 100644 index 276b8a9..0000000 --- a/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir +++ /dev/null @@ -1,63 +0,0 @@ -// RUN: mlir-opt %s -sparsification="vectorization-strategy=any-storage-inner-loop vl=16" -scf-for-loop-peeling -canonicalize | \ -// RUN: FileCheck %s - -#SparseVector = #sparse_tensor.encoding<{ - dimLevelType = [ "compressed" ], - pointerBitWidth = 32, - indexBitWidth = 32 -}> - -#trait_mul_s = { - indexing_maps = [ - affine_map<(i) -> (i)>, // a - affine_map<(i) -> (i)>, // b - affine_map<(i) -> (i)> // x (out) - ], - iterator_types = ["parallel"], - doc = "x(i) = a(i) * b(i)" -} - -// CHECK-DAG: #[[$map0:.*]] = affine_map<()[s0, s1] -> (s0 + ((-s0 + s1) floordiv 16) * 16)> -// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0)> -// CHECK-LABEL: func @mul_s -// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index -// CHECK: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref -// CHECK: %[[a:.*]] = arith.extui %[[p]] : i32 to i64 -// CHECK: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index -// CHECK: %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref -// CHECK: %[[b:.*]] = arith.extui %[[r]] : i32 to i64 -// CHECK: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index -// CHECK: %[[boundary:.*]] = affine.apply #[[$map0]]()[%[[q]], %[[s]]] -// CHECK: scf.for %[[i:.*]] = %[[q]] to %[[boundary]] step %[[c16]] { -// CHECK: %[[mask:.*]] = vector.constant_mask [16] : vector<16xi1> -// CHECK: %[[li:.*]] = vector.load %{{.*}}[%[[i]]] : memref, vector<16xi32> -// CHECK: %[[zi:.*]] = arith.extui %[[li]] : vector<16xi32> to vector<16xi64> -// CHECK: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref, vector<16xf32> -// CHECK: %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32> -// CHECK: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32> -// CHECK: vector.scatter %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> -// CHECK: } -// CHECK: scf.for %[[i2:.*]] = %[[boundary]] to %[[s]] step %[[c16]] { -// CHECK: %[[sub:.*]] = affine.apply #[[$map1]](%[[i2]])[%[[s]]] -// CHECK: %[[mask2:.*]] = vector.create_mask %[[sub]] : vector<16xi1> -// CHECK: %[[li2:.*]] = vector.maskedload %{{.*}}[%[[i2]]], %[[mask2]], %{{.*}} : memref, vector<16xi1>, vector<16xi32> into vector<16xi32> -// CHECK: %[[zi2:.*]] = arith.extui %[[li2]] : vector<16xi32> to vector<16xi64> -// CHECK: %[[la2:.*]] = vector.maskedload %{{.*}}[%[[i2]]], %[[mask2]], %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> -// CHECK: %[[lb2:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[zi2]]], %[[mask2]], %{{.*}} : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32> -// CHECK: %[[m2:.*]] = arith.mulf %[[la2]], %[[lb2]] : vector<16xf32> -// CHECK: vector.scatter %{{.*}}[%[[c0]]] [%[[zi2]]], %[[mask2]], %[[m2]] : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> -// CHECK: } -// CHECK: return -// -func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>, %argb: tensor<1024xf32>, %argx: tensor<1024xf32>) -> tensor<1024xf32> { - %0 = linalg.generic #trait_mul_s - ins(%arga, %argb: tensor<1024xf32, #SparseVector>, tensor<1024xf32>) - outs(%argx: tensor<1024xf32>) { - ^bb(%a: f32, %b: f32, %x: f32): - %0 = arith.mulf %a, %b : f32 - linalg.yield %0 : f32 - } -> tensor<1024xf32> - return %0 : tensor<1024xf32> -} diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_cast.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_cast.mlir index 3e701bc7..beb5fab 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_cast.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_cast.mlir @@ -3,14 +3,6 @@ // RUN: -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s -// -// Do the same run, but now with SIMDization as well. This should not change the outcome. -// -// RUN: mlir-opt %s --sparse-compiler="vectorization-strategy=any-storage-inner-loop vl=2" | \ -// RUN: mlir-cpu-runner \ -// RUN: -e entry -entry-point-result=void \ -// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ -// RUN: FileCheck %s #SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }> diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir index 38d809c..213b321 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir @@ -2,13 +2,6 @@ // RUN: mlir-cpu-runner -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s -// -// Do the same run, but now with SIMDization as well. This should not change the outcome. -// -// RUN: mlir-opt %s --sparse-compiler="vectorization-strategy=any-storage-inner-loop vl=2" | \ -// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ -// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ -// RUN: FileCheck %s #DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }> diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_flatten.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_flatten.mlir index 9ac9785..80e4680 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_flatten.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_flatten.mlir @@ -4,15 +4,6 @@ // RUN: -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s -// -// Do the same run, but now with SIMDization as well. This should not change the outcome. -// -// RUN: mlir-opt %s --sparse-compiler="vectorization-strategy=any-storage-inner-loop vl=4" | \ -// RUN: TENSOR0="%mlir_src_dir/test/Integration/data/test.tns" \ -// RUN: mlir-cpu-runner \ -// RUN: -e entry -entry-point-result=void \ -// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ -// RUN: FileCheck %s !Filename = !llvm.ptr diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index_dense.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index_dense.mlir index 9a79247..1155423 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index_dense.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index_dense.mlir @@ -2,13 +2,6 @@ // RUN: mlir-cpu-runner -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s -// -// Do the same run, but now with SIMDization as well. This should not change the outcome. -// -// RUN: mlir-opt %s --sparse-compiler="vectorization-strategy=any-storage-inner-loop vl=4" | \ -// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ -// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ -// RUN: FileCheck %s #SparseVector = #sparse_tensor.encoding<{ dimLevelType = ["compressed"] diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matvec.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matvec.mlir index abe4ab9..43a8eed 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matvec.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matvec.mlir @@ -4,16 +4,6 @@ // RUN: -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s -// -// Do the same run, but now with SIMDization as well. This should not change the outcome. -// -// RUN: mlir-opt %s \ -// RUN: --sparse-compiler="vectorization-strategy=any-storage-inner-loop vl=16 enable-simd-index32" | \ -// RUN: TENSOR0="%mlir_src_dir/test/Integration/data/wide.mtx" \ -// RUN: mlir-cpu-runner \ -// RUN: -e entry -entry-point-result=void \ -// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ -// RUN: FileCheck %s !Filename = !llvm.ptr diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_mttkrp.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_mttkrp.mlir index b8c2bb1..4ee5e1e 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_mttkrp.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_mttkrp.mlir @@ -4,15 +4,6 @@ // RUN: -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s -// -// Do the same run, but now with SIMDization as well. This should not change the outcome. -// -// RUN: mlir-opt %s --sparse-compiler="vectorization-strategy=any-storage-inner-loop vl=4" | \ -// RUN: TENSOR0="%mlir_src_dir/test/Integration/data/mttkrp_b.tns" \ -// RUN: mlir-cpu-runner \ -// RUN: -e entry -entry-point-result=void \ -// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ -// RUN: FileCheck %s !Filename = !llvm.ptr diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_simple.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_simple.mlir index 96f9b85..de25484 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_simple.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_simple.mlir @@ -4,15 +4,6 @@ // RUN: -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s -// -// Do the same run, but now with SIMDization as well. This should not change the outcome. -// -// RUN: mlir-opt %s --sparse-compiler="vectorization-strategy=any-storage-inner-loop vl=4" | \ -// RUN: TENSOR0="%mlir_src_dir/test/Integration/data/test.mtx" \ -// RUN: mlir-cpu-runner \ -// RUN: -e entry -entry-point-result=void \ -// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ -// RUN: FileCheck %s !Filename = !llvm.ptr diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_quantized_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_quantized_matmul.mlir index 6bd9cc6..dd5ee6f2 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_quantized_matmul.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_quantized_matmul.mlir @@ -2,13 +2,6 @@ // RUN: mlir-cpu-runner -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s -// -// Do the same run, but now with SIMDization as well. This should not change the outcome. -// -// RUN: mlir-opt %s --sparse-compiler="vectorization-strategy=any-storage-inner-loop vl=2" | \ -// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ -// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ -// RUN: FileCheck %s #DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }> diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions.mlir index 662d96e..dfed9c8 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions.mlir @@ -2,13 +2,6 @@ // RUN: mlir-cpu-runner -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s -// -// Do the same run, but now with SIMDization as well. This should not change the outcome. -// -// RUN: mlir-opt %s -sparse-compiler="vectorization-strategy=any-storage-inner-loop vl=8" | \ -// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ -// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ -// RUN: FileCheck %s #SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }> #DV = #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }> diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_matmul.mlir index ecb551b..b368670 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_matmul.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_matmul.mlir @@ -4,17 +4,6 @@ // RUN: -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s -// -// Do the same run, but now with SIMDization as well. This should not change the outcome. -// -// RUN: mlir-opt %s \ -// RUN: --sparse-compiler="vectorization-strategy=any-storage-inner-loop vl=4 enable-simd-index32" | \ -// RUN: TENSOR0="%mlir_src_dir/test/Integration/data/test.mtx" \ -// RUN: mlir-cpu-runner \ -// RUN: -e entry -entry-point-result=void \ -// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ -// RUN: FileCheck %s -// !Filename = !llvm.ptr diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir index a786780..164ea65 100755 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir @@ -2,13 +2,6 @@ // RUN: mlir-cpu-runner -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s -// -// Do the same run, but now with SIMDization as well. This should not change the outcome. -// -// RUN: mlir-opt %s -sparse-compiler="vl=8" | \ -// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ -// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ -// RUN: FileCheck %s #SM = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }> diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_scale.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_scale.mlir index b8fdabc..f9687e45 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_scale.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_scale.mlir @@ -3,14 +3,6 @@ // RUN: -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s -// -// Do the same run, but now with SIMDization as well. This should not change the outcome. -// -// RUN: mlir-opt %s --sparse-compiler="vectorization-strategy=any-storage-inner-loop vl=4" | \ -// RUN: mlir-cpu-runner \ -// RUN: -e entry -entry-point-result=void \ -// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ -// RUN: FileCheck %s #CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }> diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_spmm.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_spmm.mlir index 9ae063e..7e3aefd 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_spmm.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_spmm.mlir @@ -4,15 +4,6 @@ // RUN: -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s -// -// Do the same run, but now with SIMDization as well. This should not change the outcome. -// -// RUN: mlir-opt %s --sparse-compiler="vectorization-strategy=any-storage-inner-loop vl=2" | \ -// RUN: TENSOR0="%mlir_src_dir/test/Integration/data/wide.mtx" \ -// RUN: mlir-cpu-runner \ -// RUN: -e entry -entry-point-result=void \ -// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ -// RUN: FileCheck %s !Filename = !llvm.ptr diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir index 66961c1..c6b4e035 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir @@ -4,15 +4,6 @@ // RUN: -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s -// -// Do the same run, but now with SIMDization as well. This should not change the outcome. -// -// RUN: mlir-opt %s --sparse-compiler="vectorization-strategy=any-storage-inner-loop vl=2" | \ -// RUN: TENSOR0="%mlir_src_dir/test/Integration/data/test_symmetric.mtx" \ -// RUN: mlir-cpu-runner \ -// RUN: -e entry -entry-point-result=void \ -// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ -// RUN: FileCheck %s !Filename = !llvm.ptr diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py index 36a3da6..3fd5964 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py +++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py @@ -140,24 +140,19 @@ def main(): ir.AffineMap.get_permutation([0, 1]), ir.AffineMap.get_permutation([1, 0]) ] - vec_strategy = ['none', 'dense-inner-loop'] for level in levels: for ordering in orderings: for pwidth in [32]: for iwidth in [32]: - for vec in vec_strategy: - for e in [True]: - vl = 1 if vec == 0 else 16 - attr = st.EncodingAttr.get(level, ordering, None, pwidth, - iwidth) - opt = (f'parallelization-strategy=none ' - f'vectorization-strategy={vec} ' - f'vl={vl} enable-simd-index32={e}') - compiler = sparse_compiler.SparseCompiler( - options=opt, opt_level=0, shared_libs=[support_lib]) - build_compile_and_run_SDDMMM(attr, compiler) - count = count + 1 - # CHECK: Passed 16 tests + for e in [True]: + attr = st.EncodingAttr.get(level, ordering, None, pwidth, + iwidth) + opt = (f'parallelization-strategy=none') + compiler = sparse_compiler.SparseCompiler( + options=opt, opt_level=0, shared_libs=[support_lib]) + build_compile_and_run_SDDMMM(attr, compiler) + count = count + 1 + # CHECK: Passed 8 tests print('Passed ', count, 'tests') diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py index 7b51091..119de07 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py +++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py @@ -123,9 +123,7 @@ def main(): vl = 1 e = False - opt = (f'parallelization-strategy=none ' - f'vectorization-strategy=none ' - f'vl={vl} enable-simd-index32={e}') + opt = (f'parallelization-strategy=none') levels = [[st.DimLevelType.dense, st.DimLevelType.dense], [st.DimLevelType.dense, st.DimLevelType.compressed], [st.DimLevelType.compressed, st.DimLevelType.dense], diff --git a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py index d05cb40..13bba75 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py +++ b/mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py @@ -183,8 +183,6 @@ def main(): # CHECK-LABEL: TEST: test_stress print("\nTEST: test_stress") with ir.Context() as ctx, ir.Location.unknown(): - vl = 1 - e = False # Disable direct sparse2sparse conversion, because it doubles the time! # TODO: While direct s2s is far too slow for per-commit testing, # we should have some framework ensure that we run this test with @@ -193,9 +191,6 @@ def main(): s2s = 1 sparsification_options = ( f'parallelization-strategy=none ' - f'vectorization-strategy=none ' - f'vl={vl} ' - f'enable-simd-index32={e} ' f's2s-strategy={s2s}') compiler = sparse_compiler.SparseCompiler( options=sparsification_options, opt_level=0, shared_libs=[support_lib]) diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index af3805f..c2914ab 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2153,7 +2153,6 @@ cc_library( ":Support", ":TensorDialect", ":Transforms", - ":VectorDialect", "//llvm:Support", ], ) -- 2.7.4