From: wren romano <2998727+wrengr@users.noreply.github.com> Date: Wed, 9 Nov 2022 00:43:44 +0000 (-0800) Subject: [mlir][sparse] Factoring out NewCallParams class in SparseTensorConversion.cpp X-Git-Tag: upstream/17.0.6~28165 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f5ce99afa72fd74d57a3b9fba658f48626b3aef5;p=platform%2Fupstream%2Fllvm.git [mlir][sparse] Factoring out NewCallParams class in SparseTensorConversion.cpp The new class helps encapsulate the arguments to `_mlir_ciface_newSparseTensor` so that client code doesn't depend on the details of the API. (This makes way for the next differential which significantly alters the API.) Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D137680 --- diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp index f41a579..6ca6cfc 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -70,16 +70,6 @@ static Value genDimSizeCall(OpBuilder &builder, Location loc, .getResult(0); } -/// Generates a call into the "swiss army knife" method of the sparse runtime -/// support library for materializing sparse tensors into the computation. -static Value genNewCall(OpBuilder &builder, Location loc, - ArrayRef params) { - StringRef name = "newSparseTensor"; - Type pTp = getOpaquePointerType(builder); - return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On) - .getResult(0); -} - /// Compute the size from type (for static sizes) or from an already-converted /// opaque pointer source (for dynamic sizes) at the given dimension. 
static Value sizeFromPtrAtDim(OpBuilder &builder, Location loc, @@ -168,41 +158,132 @@ static Value genBuffer(OpBuilder &builder, Location loc, ValueRange values) { return buffer; } -/// Populates parameters required to call the "swiss army knife" method of the -/// sparse runtime support library for materializing sparse tensors into the -/// computation. -static void newParams(OpBuilder &builder, SmallVector &params, - Location loc, ShapedType stp, - SparseTensorEncodingAttr &enc, Action action, - ValueRange szs, Value ptr = Value()) { - ArrayRef dlt = enc.getDimLevelType(); - unsigned sz = dlt.size(); +/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`: +/// the "swiss army knife" method of the sparse runtime support library +/// for materializing sparse tensors into the computation. This abstraction +/// reduces the need to make modifications to client code whenever that +/// API changes. +class NewCallParams final { +public: + /// Allocates the `ValueRange` for the `func::CallOp` parameters, + /// but does not initialize them. + NewCallParams(OpBuilder &builder, Location loc) + : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {} + + /// Initializes all static parameters (i.e., those which indicate + /// type-level information such as the encoding and sizes), generating + /// MLIR buffers as needed, and returning `this` for method chaining. + /// This method does not set the action and pointer arguments, since + /// those are handled by `genNewCall` instead. + NewCallParams &genBuffers(SparseTensorEncodingAttr enc, ValueRange sizes, + ShapedType stp); + + /// (Re)sets the C++ template type parameters, and returns `this` + /// for method chaining. This is already done as part of `genBuffers`, + /// but is factored out so that it can also be called independently + /// whenever subsequent `genNewCall` calls want to reuse the same + /// buffers but different type parameters. 
+ // + // TODO: This is only ever used by sparse2sparse-viaCOO `ConvertOp`; + // is there a better way to handle that than this one-off setter method? + NewCallParams &setTemplateTypes(SparseTensorEncodingAttr enc, + ShapedType stp) { + params[kParamPtrTp] = constantPointerTypeEncoding(builder, loc, enc); + params[kParamIndTp] = constantIndexTypeEncoding(builder, loc, enc); + params[kParamValTp] = + constantPrimaryTypeEncoding(builder, loc, stp.getElementType()); + return *this; + } + + /// Checks whether all the static parameters have been initialized. + bool isInitialized() const { + for (unsigned i = 0; i < kNumStaticParams; ++i) + if (!params[i]) + return false; + return true; + } + + /// Gets the dimension-to-level mapping. + // + // TODO: This is only ever used for passing into `genAddEltCall`; + // is there a better way to encapsulate that pattern (both to avoid + // this one-off getter, and to avoid potential mixups)? + Value getDim2LvlMap() const { + assert(isInitialized() && "Must initialize before getDim2LvlMap"); + return params[kParamDim2Lvl]; + } + + /// Generates a function call, with the current static parameters + /// and the given dynamic arguments. + Value genNewCall(Action action, Value ptr = Value()) { + assert(isInitialized() && "Must initialize before genNewCall"); + StringRef name = "newSparseTensor"; + params[kParamAction] = constantAction(builder, loc, action); + params[kParamPtr] = ptr ? 
ptr : builder.create(loc, pTp); + return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On) + .getResult(0); + } + +private: + static constexpr unsigned kNumStaticParams = 6; + static constexpr unsigned kNumDynamicParams = 2; + static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams; + static constexpr unsigned kParamLvlTypes = 0; + static constexpr unsigned kParamDimSizes = 1; + static constexpr unsigned kParamDim2Lvl = 2; + static constexpr unsigned kParamPtrTp = 3; + static constexpr unsigned kParamIndTp = 4; + static constexpr unsigned kParamValTp = 5; + static constexpr unsigned kParamAction = 6; + static constexpr unsigned kParamPtr = 7; + + OpBuilder &builder; + Location loc; + Type pTp; + Value params[kNumParams]; +}; + +// TODO: see the note at `_mlir_ciface_newSparseTensor` about how +// the meaning of the various arguments (e.g., "sizes" vs "shapes") +// is inconsistent between the different actions. +NewCallParams &NewCallParams::genBuffers(SparseTensorEncodingAttr enc, + ValueRange dimSizes, ShapedType stp) { + const unsigned lvlRank = enc.getDimLevelType().size(); + const unsigned dimRank = stp.getRank(); // Sparsity annotations. - SmallVector attrs; - for (unsigned i = 0; i < sz; i++) - attrs.push_back(constantDimLevelTypeEncoding(builder, loc, dlt[i])); - params.push_back(genBuffer(builder, loc, attrs)); - // Dimension sizes array of the enveloping tensor. Useful for either + SmallVector lvlTypes; + for (auto dlt : enc.getDimLevelType()) + lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt)); + assert(lvlTypes.size() == lvlRank && "Level-rank mismatch"); + params[kParamLvlTypes] = genBuffer(builder, loc, lvlTypes); + // Dimension-sizes array of the enveloping tensor. Useful for either // verification of external data, or for construction of internal data. - params.push_back(genBuffer(builder, loc, szs)); - // Dimension order permutation array. 
This is the "identity" permutation by - // default, or otherwise the "reverse" permutation of a given ordering, so - // that indices can be mapped quickly to the right position. - SmallVector rev(sz); - for (unsigned i = 0; i < sz; i++) - rev[toOrigDim(enc, i)] = constantIndex(builder, loc, i); - params.push_back(genBuffer(builder, loc, rev)); + assert(dimSizes.size() == dimRank && "Dimension-rank mismatch"); + params[kParamDimSizes] = genBuffer(builder, loc, dimSizes); + // The dimension-to-level mapping. We must preinitialize `dim2lvl` + // so that the true branch below can perform random-access `operator[]` + // assignment. + SmallVector dim2lvl(dimRank); + auto dimOrder = enc.getDimOrdering(); + if (dimOrder) { + assert(dimOrder.isPermutation()); + for (unsigned l = 0; l < lvlRank; l++) { + // The `d`th source variable occurs in the `l`th result position. + uint64_t d = dimOrder.getDimPosition(l); + dim2lvl[d] = constantIndex(builder, loc, l); + } + } else { + assert(dimRank == lvlRank && "Rank mismatch"); + for (unsigned i = 0; i < lvlRank; i++) + dim2lvl[i] = constantIndex(builder, loc, i); + } + params[kParamDim2Lvl] = genBuffer(builder, loc, dim2lvl); // Secondary and primary types encoding. - Type elemTp = stp.getElementType(); - params.push_back(constantPointerTypeEncoding(builder, loc, enc)); - params.push_back(constantIndexTypeEncoding(builder, loc, enc)); - params.push_back(constantPrimaryTypeEncoding(builder, loc, elemTp)); - // User action. - params.push_back(constantAction(builder, loc, action)); - // Payload pointer. - if (!ptr) - ptr = builder.create(loc, getOpaquePointerType(builder)); - params.push_back(ptr); + setTemplateTypes(enc, stp); + // Finally, make note that initialization is complete. + assert(isInitialized() && "Initialization failed"); + // And return `this` for method chaining. + return *this; } /// Generates a call to obtain the values array. 
@@ -387,14 +468,12 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor, op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth()); SmallVector srcSizes; - SmallVector params; sizesFromPtr(rewriter, srcSizes, loc, encSrc, srcTp, adaptor.getSrc()); - newParams(rewriter, params, loc, srcTp, noPerm, Action::kToIterator, srcSizes, - adaptor.getSrc()); - Value iter = genNewCall(rewriter, loc, params); + NewCallParams params(rewriter, loc); + Value iter = params.genBuffers(noPerm, srcSizes, srcTp) + .genNewCall(Action::kToIterator, adaptor.getSrc()); // Start a new COO for the destination tensor. SmallVector dstSizes; - params.clear(); if (dstTp.hasStaticShape()) { sizesFromType(rewriter, dstSizes, loc, dstTp); } else { @@ -402,9 +481,9 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor, genReshapeDstShape(loc, rewriter, dstSizes, srcSizes, dstShape, op.getReassociationIndices()); } - newParams(rewriter, params, loc, dstTp, encDst, Action::kEmptyCOO, dstSizes); - Value coo = genNewCall(rewriter, loc, params); - Value dstPerm = params[2]; + Value coo = + params.genBuffers(encDst, dstSizes, dstTp).genNewCall(Action::kEmptyCOO); + Value dstPerm = params.getDim2LvlMap(); // Construct a while loop over the iterator. Value srcIdx = genAlloca(rewriter, loc, srcRank, rewriter.getIndexType()); Value dstIdx = genAlloca(rewriter, loc, dstRank, rewriter.getIndexType()); @@ -426,9 +505,7 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor, rewriter.create(loc); // Final call to construct sparse tensor storage and free temporary resources. 
rewriter.setInsertionPointAfter(whileOp); - params[6] = constantAction(rewriter, loc, Action::kFromCOO); - params[7] = coo; - Value dst = genNewCall(rewriter, loc, params); + Value dst = params.genNewCall(Action::kFromCOO, coo); genDelCOOCall(rewriter, loc, elemTp, coo); genDelIteratorCall(rewriter, loc, elemTp, iter); rewriter.replaceOp(op, dst); @@ -458,11 +535,10 @@ static void genSparseCOOIterationLoop( rewriter.getContext(), enc.getDimLevelType(), AffineMap(), AffineMap(), enc.getPointerBitWidth(), enc.getIndexBitWidth()); SmallVector sizes; - SmallVector params; sizesFromPtr(rewriter, sizes, loc, noPerm, tensorTp, t); - newParams(rewriter, params, loc, tensorTp, noPerm, Action::kToIterator, sizes, - t); - Value iter = genNewCall(rewriter, loc, params); + Value iter = NewCallParams(rewriter, loc) + .genBuffers(noPerm, sizes, tensorTp) + .genNewCall(Action::kToIterator, t); // Construct a while loop over the iterator. Value srcIdx = genAlloca(rewriter, loc, rank, rewriter.getIndexType()); @@ -611,12 +687,12 @@ public: // Generate the call to construct tensor from ptr. The sizes are // inferred from the result type of the new operator. SmallVector sizes; - SmallVector params; ShapedType stp = resType.cast(); sizesFromType(rewriter, sizes, loc, stp); Value ptr = adaptor.getOperands()[0]; - newParams(rewriter, params, loc, stp, enc, Action::kFromFile, sizes, ptr); - rewriter.replaceOp(op, genNewCall(rewriter, loc, params)); + rewriter.replaceOp(op, NewCallParams(rewriter, loc) + .genBuffers(enc, sizes, stp) + .genNewCall(Action::kFromFile, ptr)); return success(); } }; @@ -650,10 +726,10 @@ public: } // Generate the call to construct empty tensor. The sizes are // explicitly defined by the arguments to the alloc operator. 
- SmallVector params; - ShapedType stp = resType.cast(); - newParams(rewriter, params, loc, stp, enc, Action::kEmpty, sizes); - rewriter.replaceOp(op, genNewCall(rewriter, loc, params)); + rewriter.replaceOp(op, + NewCallParams(rewriter, loc) + .genBuffers(enc, sizes, resType.cast()) + .genNewCall(Action::kEmpty)); return success(); } }; @@ -690,7 +766,7 @@ public: return success(); } SmallVector sizes; - SmallVector params; + NewCallParams params(rewriter, loc); ShapedType stp = srcType.cast(); sizesFromPtr(rewriter, sizes, loc, encSrc, stp, src); bool useDirectConversion; @@ -708,9 +784,8 @@ public: break; } if (useDirectConversion) { - newParams(rewriter, params, loc, stp, encDst, Action::kSparseToSparse, - sizes, src); - rewriter.replaceOp(op, genNewCall(rewriter, loc, params)); + rewriter.replaceOp(op, params.genBuffers(encDst, sizes, stp) + .genNewCall(Action::kSparseToSparse, src)); } else { // use via-COO conversion. // Set up encoding with right mix of src and dst so that the two // method calls can share most parameters, while still providing @@ -719,13 +794,13 @@ public: op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(), encDst.getHigherOrdering(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth()); - newParams(rewriter, params, loc, stp, enc, Action::kToCOO, sizes, src); - Value coo = genNewCall(rewriter, loc, params); - params[3] = constantPointerTypeEncoding(rewriter, loc, encDst); - params[4] = constantIndexTypeEncoding(rewriter, loc, encDst); - params[6] = constantAction(rewriter, loc, Action::kFromCOO); - params[7] = coo; - Value dst = genNewCall(rewriter, loc, params); + // TODO: This is the only place where `kToCOO` (or `kToIterator`) + // is called with a non-identity permutation. Is there any clean + // way to push the permutation over to the `kFromCOO` side instead? 
+ Value coo = + params.genBuffers(enc, sizes, stp).genNewCall(Action::kToCOO, src); + Value dst = params.setTemplateTypes(encDst, stp) + .genNewCall(Action::kFromCOO, coo); genDelCOOCall(rewriter, loc, stp.getElementType(), coo); rewriter.replaceOp(op, dst); } @@ -743,7 +818,7 @@ public: RankedTensorType srcTensorTp = srcType.cast(); unsigned rank = dstTensorTp.getRank(); Type elemTp = dstTensorTp.getElementType(); - // Fabricate a no-permutation encoding for newParams(). + // Fabricate a no-permutation encoding for NewCallParams // The pointer/index types must be those of `src`. // The dimLevelTypes aren't actually used by Action::kToIterator. encDst = SparseTensorEncodingAttr::get( @@ -751,11 +826,10 @@ public: SmallVector(rank, DimLevelType::Dense), AffineMap(), AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth()); SmallVector sizes; - SmallVector params; sizesFromPtr(rewriter, sizes, loc, encSrc, srcTensorTp, src); - newParams(rewriter, params, loc, dstTensorTp, encDst, Action::kToIterator, - sizes, src); - Value iter = genNewCall(rewriter, loc, params); + Value iter = NewCallParams(rewriter, loc) + .genBuffers(encDst, sizes, dstTensorTp) + .genNewCall(Action::kToIterator, src); Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType()); Value elemPtr = genAllocaScalar(rewriter, loc, elemTp); Block *insertionBlock = rewriter.getInsertionBlock(); @@ -817,12 +891,12 @@ public: ShapedType stp = resType.cast(); unsigned rank = stp.getRank(); SmallVector sizes; - SmallVector params; sizesFromSrc(rewriter, sizes, loc, src); - newParams(rewriter, params, loc, stp, encDst, Action::kEmptyCOO, sizes); - Value coo = genNewCall(rewriter, loc, params); + NewCallParams params(rewriter, loc); + Value coo = + params.genBuffers(encDst, sizes, stp).genNewCall(Action::kEmptyCOO); Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType()); - Value perm = params[2]; + Value perm = params.getDim2LvlMap(); Type eltType = stp.getElementType(); Value 
elemPtr = genAllocaScalar(rewriter, loc, eltType); genDenseTensorOrSparseConstantIterLoop( @@ -836,9 +910,7 @@ public: genAddEltCall(builder, loc, eltType, coo, elemPtr, ind, perm); }); // Final call to construct sparse tensor storage. - params[6] = constantAction(rewriter, loc, Action::kFromCOO); - params[7] = coo; - Value dst = genNewCall(rewriter, loc, params); + Value dst = params.genNewCall(Action::kFromCOO, coo); genDelCOOCall(rewriter, loc, eltType, coo); rewriter.replaceOp(op, dst); return success(); @@ -1117,15 +1189,15 @@ public: Value offset = constantIndex(rewriter, loc, 0); SmallVector sizes; - SmallVector params; + NewCallParams params(rewriter, loc); concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), concatDim); if (encDst) { // Start a new COO for the destination tensor. - newParams(rewriter, params, loc, dstTp, encDst, Action::kEmptyCOO, sizes); - dst = genNewCall(rewriter, loc, params); - dstPerm = params[2]; + dst = + params.genBuffers(encDst, sizes, dstTp).genNewCall(Action::kEmptyCOO); + dstPerm = params.getDim2LvlMap(); elemPtr = genAllocaScalar(rewriter, loc, elemTp); dstIdx = genAlloca(rewriter, loc, rank, rewriter.getIndexType()); } else { @@ -1188,11 +1260,9 @@ public: offset = rewriter.create(loc, offset, curDim); } if (encDst) { - params[6] = constantAction(rewriter, loc, Action::kFromCOO); // In sparse output case, the destination holds the COO. Value coo = dst; - params[7] = coo; - dst = genNewCall(rewriter, loc, params); + dst = params.genNewCall(Action::kFromCOO, coo); // Release resources. 
genDelCOOCall(rewriter, loc, elemTp, coo); rewriter.replaceOp(op, dst); @@ -1216,27 +1286,25 @@ public: Value src = adaptor.getOperands()[0]; auto encSrc = getSparseTensorEncoding(srcType); SmallVector sizes; - SmallVector params; sizesFromPtr(rewriter, sizes, loc, encSrc, srcType, src); auto enc = SparseTensorEncodingAttr::get( op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth()); - newParams(rewriter, params, loc, srcType, enc, Action::kToCOO, sizes, src); - Value coo = genNewCall(rewriter, loc, params); + Value coo = NewCallParams(rewriter, loc) + .genBuffers(enc, sizes, srcType) + .genNewCall(Action::kToCOO, src); // Then output the tensor to external file with indices in the externally // visible lexicographic index order. A sort is required if the source was // not in that order yet (note that the sort can be dropped altogether if // external format does not care about the order at all, but here we assume // it does). - bool sort = - encSrc.getDimOrdering() && !encSrc.getDimOrdering().isIdentity(); - params.clear(); - params.push_back(coo); - params.push_back(adaptor.getOperands()[1]); - params.push_back(constantI1(rewriter, loc, sort)); + Value sort = constantI1(rewriter, loc, + encSrc.getDimOrdering() && + !encSrc.getDimOrdering().isIdentity()); + SmallVector outParams{coo, adaptor.getOperands()[1], sort}; Type eltType = srcType.getElementType(); SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(eltType)}; - createFuncCall(rewriter, loc, name, {}, params, EmitCInterface::Off); + createFuncCall(rewriter, loc, name, {}, outParams, EmitCInterface::Off); genDelCOOCall(rewriter, loc, eltType, coo); rewriter.eraseOp(op); return success();