This enables constructs like the following in the DSL, producing a named op that is polymorphic over the numeric type variables `T1`, `T2`, and `U`, with the correct arithmetic casts generated at op construction time:
```
@tc_def_op
def polymorphic_matmul(A=TensorDef(T1, S.M, S.K),
                       B=TensorDef(T2, S.K, S.N),
                       C=TensorDef(U, S.M, S.N, output=True)):
  implements(ContractionOpInterface)
  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
```
Presently, a type variable can only be bound to the element type of one of the op's arguments (in the example above, `T1`, `T2`, and `U` are bound to the element types of `A`, `B`, and `C` respectively). A further extension that allows binding a type variable to an attribute would add expressiveness and may be useful for some formulations; that is left to a future patch. In addition, this patch does not yet materialize the verifier support that ensures type variables are bound correctly: for simple examples like this, binding them incorrectly yields IR that fails verification, it just won't yet fail with a precise error.
Note that the full grid of extension/truncation/int<->float conversions is supported, but many of them are lossy; higher-level code needs to be mindful of numerics (that is not the job of this level).
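To make the grid concrete, these are the directions exercised by the tests added in this patch and the scalar cast each one lowers to (a summary of the `cast` helper below, not an exhaustive spec; all integer casts are signed, and an unsupported pairing emits a warning and passes the operand through unchanged):
```
// f32 -> i16 : fptosi      i8  -> f32 : sitofp
// i8  -> i32 : sexti       f16 -> f32 : fpext
// i32 -> i16 : trunci      f64 -> f32 : fptrunc
```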
As-is, this should be sufficient for most integer matmul scenarios we work with in typical quantization schemes.
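As an illustrative example (the `%lhs`, `%rhs`, and `%init` names are just placeholders), an i8 x i8 -> i32 accumulation of the kind used in quantized models can be written directly with the generated op, mirroring the `@generalize_matmul_tensor_i8_i8_i32` test below:
```
%acc = linalg.polymorphic_matmul ins(%lhs, %rhs : tensor<16x8xi8>, tensor<8x32xi8>)
                                 outs(%init : tensor<16x32xi32>) -> tensor<16x32xi32>
```
When generalized, the body sign-extends both inputs to i32 before the `muli`/`addi`, as the CHECK lines for that test verify.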
Differential Revision: https://reviews.llvm.org/D97603
    name: A
    usage: input
    shape: affine_map<()[s0, s1, s2] -> (s0, s2)>
+   element_type_var: T1
  - !<LinalgTensorDef>
    name: B
    usage: input
    shape: affine_map<()[s0, s1, s2] -> (s2, s1)>
+   element_type_var: T2
  - !<LinalgTensorDef>
    name: C
    usage: output
    shape: affine_map<()[s0, s1, s2] -> (s0, s1)>
+   element_type_var: U
  indexing_maps: !LinalgIndexingMapsConfig
    static_indexing_maps:
    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
      fn_name: mul
      operands:
      - !ScalarExpression
-       scalar_arg: A
+       symbolic_cast:
+         type_var: U
+         operands:
+         - !ScalarExpression
+           scalar_arg: A
      - !ScalarExpression
-       scalar_arg: B
+       symbolic_cast:
+         type_var: U
+         operands:
+         - !ScalarExpression
+           scalar_arg: B
public:
  RegionBuilderHelper(Block &block) : block(block) {}
+  // Generates operations to cast the given operand to a specified type.
+  // If the cast cannot be performed, a warning will be issued and the
+  // operand returned as-is (which will presumably yield a verification
+  // issue downstream).
+  Value cast(Type toType, Value operand) {
+    OpBuilder builder = getBuilder(operand);
+    auto loc = operand.getLoc();
+
+    if (operand.getType() == toType)
+      return operand;
+    if (auto toIntType = toType.dyn_cast<IntegerType>()) {
+      // If operand is floating point, cast directly to the int type.
+      if (operand.getType().isa<FloatType>())
+        return builder.create<FPToSIOp>(loc, toType, operand);
+      if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
+        // Either sign extend or truncate.
+        if (toIntType.getWidth() > fromIntType.getWidth())
+          return builder.create<SignExtendIOp>(loc, toType, operand);
+        else if (toIntType.getWidth() < fromIntType.getWidth())
+          return builder.create<TruncateIOp>(loc, toType, operand);
+      }
+    } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
+      // If operand is integer, cast directly to the float type.
+      // Note that it is unclear how to cast from BF16<->FP16.
+      if (operand.getType().isa<IntegerType>())
+        return builder.create<SIToFPOp>(loc, toFloatType, operand);
+      if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
+        if (toFloatType.getWidth() > fromFloatType.getWidth())
+          return builder.create<FPExtOp>(loc, toFloatType, operand);
+        else if (toFloatType.getWidth() < fromFloatType.getWidth())
+          return builder.create<FPTruncOp>(loc, toFloatType, operand);
+      }
+    }
+
+    emitWarning(operand.getLoc()) << "could not cast operand of type "
+                                  << operand.getType() << " to " << toType;
+    return operand;
+  }
+
  Value applyfn__add(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder(lhs);
    if (isFloatingPoint(lhs))
// CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
// CHECK-NEXT: linalg.yield %[[ADD]] : i32
// CHECK-NEXT: -> tensor<16x32xi32>
+
+// -----
+// Verifies floating point to integer cast.
+func @generalize_matmul_tensor_f32_f32_i16(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
+  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
+                                 outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16>
+  return %0: tensor<16x32xi16>
+}
+
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: i16)
+// CHECK-NEXT: %[[A_CAST:.+]] = fptosi %[[A_ARG]] : f32 to i16
+// CHECK-NEXT: %[[B_CAST:.+]] = fptosi %[[B_ARG]] : f32 to i16
+// CHECK-NEXT: %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i16
+// CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16
+// CHECK-NEXT: linalg.yield %[[ADD]] : i16
+// CHECK-NEXT: -> tensor<16x32xi16>
+
+// -----
+// Verifies sign extension cast.
+func @generalize_matmul_tensor_i8_i8_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
+                                 outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
+  return %0: tensor<16x32xi32>
+}
+
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: i32)
+// CHECK-NEXT: %[[A_CAST:.+]] = sexti %[[A_ARG]] : i8 to i32
+// CHECK-NEXT: %[[B_CAST:.+]] = sexti %[[B_ARG]] : i8 to i32
+// CHECK-NEXT: %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i32
+// CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
+// CHECK-NEXT: linalg.yield %[[ADD]] : i32
+// CHECK-NEXT: -> tensor<16x32xi32>
+
+// -----
+// Verifies that mixing input element types is legal.
+func @generalize_matmul_tensor_i8_i16_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi16>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi16>)
+                                 outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
+  return %0: tensor<16x32xi32>
+}
+
+// -----
+// Somewhat nonsensical, but checks the integer truncation cast.
+func @generalize_matmul_tensor_i32_i32_i16(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
+  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
+                                 outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16>
+  return %0: tensor<16x32xi16>
+}
+
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i16)
+// CHECK-NEXT: %[[A_CAST:.+]] = trunci %[[A_ARG]] : i32 to i16
+// CHECK-NEXT: %[[B_CAST:.+]] = trunci %[[B_ARG]] : i32 to i16
+// CHECK-NEXT: %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i16
+// CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16
+// CHECK-NEXT: linalg.yield %[[ADD]] : i16
+// CHECK-NEXT: -> tensor<16x32xi16>
+
+// -----
+// Verifies integer to floating point cast.
+func @generalize_matmul_tensor_i8_i8_f32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
+                                 outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0: tensor<16x32xf32>
+}
+
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: f32)
+// CHECK-NEXT: %[[A_CAST:.+]] = sitofp %[[A_ARG]] : i8 to f32
+// CHECK-NEXT: %[[B_CAST:.+]] = sitofp %[[B_ARG]] : i8 to f32
+// CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
+// CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT: linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<16x32xf32>
+
+// -----
+// Verifies floating point extension cast.
+func @generalize_matmul_tensor_f16_f16_f32(%A : tensor<16x8xf16>, %B: tensor<8x32xf16>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf16>)
+                                 outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0: tensor<16x32xf32>
+}
+
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
+// CHECK-NEXT: %[[A_CAST:.+]] = fpext %[[A_ARG]] : f16 to f32
+// CHECK-NEXT: %[[B_CAST:.+]] = fpext %[[B_ARG]] : f16 to f32
+// CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
+// CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT: linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<16x32xf32>
+
+// -----
+// Verifies floating point truncation.
+func @generalize_matmul_tensor_f64_f64_f32(%A : tensor<16x8xf64>, %B: tensor<8x32xf64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf64>, tensor<8x32xf64>)
+                                 outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0: tensor<16x32xf32>
+}
+
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f64, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
+// CHECK-NEXT: %[[A_CAST:.+]] = fptrunc %[[A_ARG]] : f64 to f32
+// CHECK-NEXT: %[[B_CAST:.+]] = fptrunc %[[B_ARG]] : f64 to f32
+// CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
+// CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT: linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<16x32xf32>
  std::string name;
  LinalgTensorUsageDef usage;
  SerializedAffineMap shape;
+  std::string elementTypeVar;
};
enum class LinalgIteratorTypeDef {
  std::vector<ScalarExpression> operands;
};
+struct ScalarSymbolicCast {
+  std::string typeVar;
+  // NOTE: This must be of arity 1, but to break the self-referential cycle,
+  // we use a heap allocated vector.
+  std::vector<ScalarExpression> operands;
+};
+
struct ScalarExpression {
-  Optional<std::string> scalarArg;
-  Optional<ScalarApply> scalarApply;
+  Optional<std::string> arg;
+  Optional<ScalarApply> apply;
+  Optional<ScalarSymbolicCast> symbolicCast;
};
struct ScalarAssign {
/// - `shape`: An AffineMap from all op symbols to the specific shape
/// of this argument. Each shape must be normalized over the same list of
/// symbols and have no dimension inputs.
+/// - `element_type_var`: The symbolic type variable that binds to the scalar
+/// element type of this TensorDef.
template <>
struct MappingTraits<LinalgTensorDef> {
  static void mapping(IO &io, LinalgTensorDef &info) {
    io.mapRequired("name", info.name);
    io.mapRequired("usage", info.usage);
    io.mapRequired("shape", info.shape);
+    io.mapRequired("element_type_var", info.elementTypeVar);
  }
};
/// - `scalar_arg`: Name of an argument to the op.
/// - `scalar_apply`: Result of evaluating a named function (see
/// `ScalarApply`).
+/// - `symbolic_cast`: Cast to a symbolic TypeVar bound elsewhere.
template <>
struct MappingTraits<ScalarExpression> {
  static void mapping(IO &io, ScalarExpression &info) {
-    io.mapOptional("scalar_arg", info.scalarArg);
-    io.mapOptional("scalar_apply", info.scalarApply);
+    io.mapOptional("scalar_arg", info.arg);
+    io.mapOptional("scalar_apply", info.apply);
+    io.mapOptional("symbolic_cast", info.symbolicCast);
  }
};
}
};
+template <>
+struct MappingTraits<ScalarSymbolicCast> {
+  static void mapping(IO &io, ScalarSymbolicCast &info) {
+    io.mapRequired("type_var", info.typeVar);
+    io.mapRequired("operands", info.operands);
+  }
+};
+
/// Helper mapping which accesses an AffineMapAttr as a serialized string of
/// the same.
template <>
  return None;
}
+static Optional<int>
+findTypeVarArgIndex(StringRef typeVar, SmallVectorImpl<LinalgTensorDef> &args) {
+  for (auto it : llvm::enumerate(args)) {
+    if (it.value().elementTypeVar == typeVar)
+      return it.index();
+  }
+  return None;
+}
+
static ScalarAssign *
findAssignment(StringRef name, SmallVectorImpl<ScalarAssign> &assignments) {
for (auto &assign : assignments) {
  std::function<Optional<std::string>(ScalarExpression &)>
      generateExpression =
          [&](ScalarExpression &expression) -> Optional<std::string> {
-    if (expression.scalarArg) {
-      Optional<int> argIndex =
-          findTensorDefArgIndex(*expression.scalarArg, args);
+    if (expression.arg) {
+      // Argument reference.
+      Optional<int> argIndex = findTensorDefArgIndex(*expression.arg, args);
      if (!argIndex) {
        emitError(genContext.getLoc())
            << "scalar argument not defined on the op: " << arg.name;
      }
      return std::string(
          llvm::formatv("block.getArgument({0})", *argIndex));
-    } else if (expression.scalarApply) {
+    } else if (expression.apply) {
+      // Apply function.
      // Recursively generate operands.
      SmallVector<std::string> operandCppValues;
-      for (ScalarExpression &operand : expression.scalarApply->operands) {
+      for (ScalarExpression &operand : expression.apply->operands) {
        auto operandCppValue = generateExpression(operand);
        if (!operandCppValue)
          return None;
      std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
      stmts.push_back(
          llvm::formatv("Value {0} = helper.applyfn__{1}({2});", cppIdent,
-                       expression.scalarApply->fnName,
+                       expression.apply->fnName,
                        interleaveToString(operandCppValues, ", ")));
      return cppIdent;
+    } else if (expression.symbolicCast) {
+      // Symbolic cast.
+      // Operands must be arity 1.
+      if (expression.symbolicCast->operands.size() != 1) {
+        emitError(genContext.getLoc())
+            << "symbolic_cast operand arity must be 1";
+        return None;
+      }
+      Optional<std::string> operandCppValue =
+          generateExpression(expression.symbolicCast->operands[0]);
+      if (!operandCppValue)
+        return None;
+
+      // Try to map the TypeVar to an arg index (which maps to a block arg
+      // index), since we can just get that type directly.
+      // TODO: Handle free type variables which do not map to an argument.
+      Optional<int> typeArgIndex =
+          findTypeVarArgIndex(expression.symbolicCast->typeVar, args);
+      if (!typeArgIndex) {
+        emitError(genContext.getLoc())
+            << "type variable " << expression.symbolicCast->typeVar
+            << ", used in a symbolic cast, must map to an argument but it "
+            << "does not";
+        return None;
+      }
+      std::string typeCppValue =
+          llvm::formatv("block.getArgument({0}).getType()", *typeArgIndex);
+      std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
+      stmts.push_back(llvm::formatv("Value {0} = helper.cast({1}, {2});",
+                                    cppIdent, typeCppValue,
+                                    *operandCppValue));
+      return cppIdent;
    } else {
      emitError(genContext.getLoc()) << "unknown ScalarExpression type";
      return None;