From 7f1cb43d60a517660c579ef22351bc3ca413d52d Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Tue, 28 Sep 2021 22:48:32 -0700 Subject: [PATCH] [mlir][sparse] simplify negi code generation with subi The lack of negi details leaked from merger class into codegen part. Also, special case for vector code was not needed, the type can be used directly! Reviewed By: bixia Differential Revision: https://reviews.llvm.org/D110677 --- mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp | 8 -------- mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp | 11 +++++++---- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index 60272e2..373f3d1 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -790,14 +790,6 @@ static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, return genInvariantValue(merger, codegen, rewriter, exp); Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0); Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1); - if (merger.exp(exp).kind == Kind::kNegI) { - // TODO: no negi in std, need to make zero explicit. - Type tp = op.getOutputTensorTypes()[0].getElementType(); - v1 = v0; - v0 = rewriter.create(loc, tp, rewriter.getZeroAttr(tp)); - if (codegen.curVecLength > 1) - v0 = genVectorInvariantValue(codegen, rewriter, v0); - } return merger.buildExp(rewriter, loc, exp, v0, v1); } diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp index 4a18a0a..fd96151 100644 --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -568,7 +568,7 @@ Optional Merger::buildTensorExp(linalg::GenericOp op, Value v) { if (isa(def)) return addExp(kFloorF, e); if (isa(def)) - return addExp(kNegF, e); // TODO: no negi in std? + return addExp(kNegF, e); // no negi in std if (isa(def)) return addExp(kTruncF, e, v); if (isa(def)) @@ -651,9 +651,12 @@ Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e, return rewriter.create(loc, v0); case kNegF: return rewriter.create(loc, v0); - case kNegI: - assert(v1); // no negi in std - return rewriter.create(loc, v0, v1); + case kNegI: // no negi in std + return rewriter.create( + loc, + rewriter.create(loc, v0.getType(), + rewriter.getZeroAttr(v0.getType())), + v0); case kTruncF: return rewriter.create(loc, v0, inferType(e, v0)); case kExtF: -- 2.7.4