DeclareOpInterfaceMethods<CastOpInterface>]>,
Arguments<(ins From:$in)>,
Results<(outs To:$out)> {
- let builders = [
- OpBuilder<(ins "Value":$source, "Type":$destType), [{
- impl::buildCastOp($_builder, $_state, source, destType);
- }]>
- ];
-
let assemblyFormat = "$in attr-dict `:` type($in) `to` type($out)";
}
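// The custom (source, destType) builder is dropped: the ODS-generated
// builder already covers this case, taking the result type before the operand.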
let arguments = (ins AnyRankedOrUnrankedMemRef:$source);
let results = (outs AnyRankedOrUnrankedMemRef:$dest);
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
- let builders = [
- OpBuilder<(ins "Value":$source, "Type":$destType), [{
- impl::buildCastOp($_builder, $_state, source, destType);
- }]>
- ];
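// Same cleanup as above. With only the generated builder left, call sites
// switch argument order, e.g.:
//   before: rewriter.create<arith::IndexCastOp>(loc, value, destType);
//   after:  rewriter.create<arith::IndexCastOp>(loc, destType, value);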
let extraClassDeclaration = [{
    /// Fold the given CastOp into its consumer op.
switch (conversion) {
case PrintConversion::ZeroExt64:
value = rewriter.create<arith::ExtUIOp>(
- loc, value, IntegerType::get(rewriter.getContext(), 64));
+ loc, IntegerType::get(rewriter.getContext(), 64), value);
break;
case PrintConversion::SignExt64:
value = rewriter.create<arith::ExtSIOp>(
- loc, value, IntegerType::get(rewriter.getContext(), 64));
+ loc, IntegerType::get(rewriter.getContext(), 64), value);
break;
case PrintConversion::None:
break;
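// Both extension ops are now created as (loc, resultType, value): the i64
// target type precedes the value being extended.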
getMemRefType(castOp.getType().cast<TensorType>(), state.getOptions(),
layout, sourceType.getMemorySpace());
- replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, source,
- resultType);
+ replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType,
+ source);
return success();
}
};
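// replaceOpWithNewBufferizedOp forwards its trailing arguments to create<>,
// so the computed result type again precedes the source value.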
scalingFactor);
}
Value numWorkersIndex =
- b.create<arith::IndexCastOp>(numWorkerThreadsVal, b.getI32Type());
+ b.create<arith::IndexCastOp>(b.getI32Type(), numWorkerThreadsVal);
Value numWorkersFloat =
- b.create<arith::SIToFPOp>(numWorkersIndex, b.getF32Type());
+ b.create<arith::SIToFPOp>(b.getF32Type(), numWorkersIndex);
Value scaledNumWorkers =
b.create<arith::MulFOp>(scalingFactor, numWorkersFloat);
Value scaledNumInt =
- b.create<arith::FPToSIOp>(scaledNumWorkers, b.getI32Type());
+ b.create<arith::FPToSIOp>(b.getI32Type(), scaledNumWorkers);
Value scaledWorkers =
- b.create<arith::IndexCastOp>(scaledNumInt, b.getIndexType());
+ b.create<arith::IndexCastOp>(b.getIndexType(), scaledNumInt);
Value maxComputeBlocks = b.create<arith::MaxSIOp>(
b.create<arith::ConstantIndexOp>(1), scaledWorkers);
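// The worker count round-trips index -> i32 -> f32, is scaled, then goes
// back f32 -> i32 -> index, clamped to at least one compute block; every
// cast in the chain now names its destination type first.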
auto i32Vec = broadcast(builder.getI32Type(), shape);
// exp2(k)
- Value k = builder.create<arith::FPToSIOp>(kF32, i32Vec);
+ Value k = builder.create<arith::FPToSIOp>(i32Vec, kF32);
Value exp2KValue = exp2I32(builder, k);
// exp(x) = exp(y) * exp2(k)
auto i32Vec = broadcast(builder.getI32Type(), shape);
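// Converts an f32 vector to its signed i32 counterpart (rounding toward
// zero); the vector result type leads.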
auto fPToSignedInteger = [&](Value a) -> Value {
- return builder.create<arith::FPToSIOp>(a, i32Vec);
+ return builder.create<arith::FPToSIOp>(i32Vec, a);
};
auto modulo4 = [&](Value a) -> Value {
alloc.alignmentAttr());
// Insert a cast so we have the same type as the old alloc.
auto resultCast =
- rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());
+ rewriter.create<CastOp>(alloc.getLoc(), alloc.getType(), newAlloc);
rewriter.replaceOp(alloc, {resultCast});
return success();
rewriter.replaceOp(subViewOp, subViewOp.source());
return success();
}
- rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.source(),
- subViewOp.getType());
+ rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
+ subViewOp.source());
return success();
}
};
/// A canonicalizer wrapper to replace SubViewOps.
struct SubViewCanonicalizer {
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
- rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
+ rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
}
};
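// replaceOpWithNewOp shares create's argument order, so these canonicalizers
// pass the original result type before the replacement value.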
viewOp.getOperand(0),
viewOp.byte_shift(), newOperands);
// Insert a cast so the result has the same type as the old memref.
- rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
+ rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
return success();
}
};
Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
size = rewriter.create<memref::LoadOp>(loc, op.shape(), index);
if (!size.getType().isa<IndexType>())
- size = rewriter.create<arith::IndexCastOp>(loc, size,
- rewriter.getIndexType());
+ size = rewriter.create<arith::IndexCastOp>(
+ loc, rewriter.getIndexType(), size);
sizes[i] = size;
} else {
sizes[i] = rewriter.getIndexAttr(op.getType().getDimSize(i));
Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
ValueRange{ivs[0], idx});
val =
- rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType());
+ rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), val);
rewriter.create<memref::StoreOp>(loc, val, ind, idx);
}
return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
if (!etp.isa<IndexType>()) {
if (etp.getIntOrFloatBitWidth() < 32)
vload = rewriter.create<arith::ExtUIOp>(
- loc, vload, vectorType(codegen, rewriter.getI32Type()));
+ loc, vectorType(codegen, rewriter.getI32Type()), vload);
else if (etp.getIntOrFloatBitWidth() < 64 &&
!codegen.options.enableSIMDIndex32)
vload = rewriter.create<arith::ExtUIOp>(
- loc, vload, vectorType(codegen, rewriter.getI64Type()));
+ loc, vectorType(codegen, rewriter.getI64Type()), vload);
}
return vload;
}
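// The vector path zero-extends sub-32-bit indices to i32, or to i64 for
// wider elements unless 32-bit SIMD indices are enabled; in both ExtUIOp
// calls the target vector type now leads.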
Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
if (!load.getType().isa<IndexType>()) {
if (load.getType().getIntOrFloatBitWidth() < 64)
- load = rewriter.create<arith::ExtUIOp>(loc, load, rewriter.getI64Type());
+ load = rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), load);
load =
- rewriter.create<arith::IndexCastOp>(loc, load, rewriter.getIndexType());
+ rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), load);
}
return load;
}
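// The scalar load mirrors the vector path: the destination type (i64, then
// index) is passed first.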
Value mul = rewriter.create<arith::MulIOp>(loc, size, p);
if (auto vtp = i.getType().dyn_cast<VectorType>()) {
Value inv =
- rewriter.create<arith::IndexCastOp>(loc, mul, vtp.getElementType());
+ rewriter.create<arith::IndexCastOp>(loc, vtp.getElementType(), mul);
mul = genVectorInvariantValue(codegen, rewriter, inv);
}
return rewriter.create<arith::AddIOp>(loc, mul, i);
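// When the address computation is vectorized, the scaled offset is cast to
// the vector's element type and rebroadcast as an invariant before the
// final add.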
rewriter.getZeroAttr(v0.getType())),
v0);
case kTruncF:
- return rewriter.create<arith::TruncFOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
case kExtF:
- return rewriter.create<arith::ExtFOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
case kCastFS:
- return rewriter.create<arith::FPToSIOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
case kCastFU:
- return rewriter.create<arith::FPToUIOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
case kCastSF:
- return rewriter.create<arith::SIToFPOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
case kCastUF:
- return rewriter.create<arith::UIToFPOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
case kCastS:
- return rewriter.create<arith::ExtSIOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
case kCastU:
- return rewriter.create<arith::ExtUIOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
case kTruncI:
- return rewriter.create<arith::TruncIOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
case kBitCast:
- return rewriter.create<arith::BitcastOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
// Binary ops.
case kMulF:
return rewriter.create<arith::MulFOp>(loc, v0, v1);
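// All unary cast cases now pass the inferred result type ahead of the
// operand; binary ops such as MulFOp keep their (lhs, rhs) form and are
// untouched by this change.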
[&](OpBuilder &b, Location loc) {
Value res = memref;
if (compatibleMemRefType != xferOp.getShapedType())
- res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
+ res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
scf::ValueVector viewAndIndices{res};
viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
xferOp.indices().end());
alloc);
b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
Value casted =
- b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
+ b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
scf::ValueVector viewAndIndices{casted};
viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
zero);
[&](OpBuilder &b, Location loc) {
Value res = memref;
if (compatibleMemRefType != xferOp.getShapedType())
- res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
+ res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
scf::ValueVector viewAndIndices{res};
viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
xferOp.indices().end());
loc, MemRefType::get({}, vector.getType()), alloc));
Value casted =
- b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
+ b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
scf::ValueVector viewAndIndices{casted};
viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
zero);
[&](OpBuilder &b, Location loc) {
Value res = memref;
if (compatibleMemRefType != xferOp.getShapedType())
- res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
+ res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
scf::ValueVector viewAndIndices{res};
viewAndIndices.insert(viewAndIndices.end(),
xferOp.indices().begin(),
},
[&](OpBuilder &b, Location loc) {
Value casted =
- b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
+ b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
scf::ValueVector viewAndIndices{casted};
viewAndIndices.insert(viewAndIndices.end(),
xferOp.getTransferRank(), zero);
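// In every transfer-split branch, the view (or its staging alloc) is cast to
// compatibleMemRefType before being packed with the indices; the destination
// type is now the first argument after the location.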