From b76f523be6ea606d9cf494e247546cec1cd7f209 Mon Sep 17 00:00:00 2001 From: zhanghb97 Date: Mon, 14 Sep 2020 22:52:22 +0800 Subject: [PATCH] [mlir] expose affine map to C API This patch provides C API for MLIR affine map. - Implement C API for AffineMap class. - Add Utils.h to include/mlir/CAPI/, and move the definition of the CallbackOstream to Utils.h to make sure mlirAffineMapPrint work correct. - Add TODO for exposing the C API related to AffineExpr and mutable affine map. Differential Revision: https://reviews.llvm.org/D87617 --- mlir/include/mlir-c/AffineMap.h | 110 +++++++++++++++++++++++++++++++++ mlir/include/mlir/CAPI/Utils.h | 48 +++++++++++++++ mlir/lib/CAPI/IR/AffineMap.cpp | 116 ++++++++++++++++++++++++++++++++++- mlir/lib/CAPI/IR/IR.cpp | 41 +++---------- mlir/test/CAPI/ir.c | 132 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 411 insertions(+), 36 deletions(-) create mode 100644 mlir/include/mlir/CAPI/Utils.h diff --git a/mlir/include/mlir-c/AffineMap.h b/mlir/include/mlir-c/AffineMap.h index bef13fd..a5d9918 100644 --- a/mlir/include/mlir-c/AffineMap.h +++ b/mlir/include/mlir-c/AffineMap.h @@ -18,6 +18,116 @@ extern "C" { DEFINE_C_API_STRUCT(MlirAffineMap, const void); +/** Gets the context that the given affine map was created with*/ +MlirContext mlirAffineMapGetContext(MlirAffineMap affineMap); + +/** Checks whether an affine map is null. */ +inline int mlirAffineMapIsNull(MlirAffineMap affineMap) { + return !affineMap.ptr; +} + +/** Checks if two affine maps are equal. */ +int mlirAffineMapEqual(MlirAffineMap a1, MlirAffineMap a2); + +/** Prints an affine map by sending chunks of the string representation and + * forwarding `userData to `callback`. Note that the callback may be called + * several times with consecutive chunks of the string. */ +void mlirAffineMapPrint(MlirAffineMap affineMap, MlirStringCallback callback, + void *userData); + +/** Prints the affine map to the standard error stream. */ +void mlirAffineMapDump(MlirAffineMap affineMap); + +/** Creates a zero result affine map with no dimensions or symbols in the + * context. The affine map is owned by the context. */ +MlirAffineMap mlirAffineMapEmptyGet(MlirContext ctx); + +/** Creates a zero result affine map of the given dimensions and symbols in the + * context. The affine map is owned by the context. */ +MlirAffineMap mlirAffineMapGet(MlirContext ctx, intptr_t dimCount, + intptr_t symbolCount); + +/** Creates a single constant result affine map in the context. The affine map + * is owned by the context. */ +MlirAffineMap mlirAffineMapConstantGet(MlirContext ctx, int64_t val); + +/** Creates an affine map with 'numDims' identity in the context. The affine map + * is owned by the context. */ +MlirAffineMap mlirAffineMapMultiDimIdentityGet(MlirContext ctx, + intptr_t numDims); + +/** Creates an identity affine map on the most minor dimensions in the context. + * The affine map is owned by the context. The function asserts that the number + * of dimensions is greater or equal to the number of results. */ +MlirAffineMap mlirAffineMapMinorIdentityGet(MlirContext ctx, intptr_t dims, + intptr_t results); + +/** Creates an affine map with a permutation expression and its size in the + * context. The permutation expression is a non-empty vector of integers. + * The elements of the permutation vector must be continuous from 0 and cannot + * be repeated (i.e. `[1,2,0]` is a valid permutation. `[2,0]` or `[1,1,2]` is + * an invalid invalid permutation.) The affine map is owned by the context. */ +MlirAffineMap mlirAffineMapPermutationGet(MlirContext ctx, intptr_t size, + unsigned *permutation); + +/** Checks whether the given affine map is an identity affine map. The function + * asserts that the number of dimensions is greater or equal to the number of + * results. */ +int mlirAffineMapIsIdentity(MlirAffineMap affineMap); + +/** Checks whether the given affine map is a minor identity affine map. */ +int mlirAffineMapIsMinorIdentity(MlirAffineMap affineMap); + +/** Checks whether the given affine map is an empty affine map. */ +int mlirAffineMapIsEmpty(MlirAffineMap affineMap); + +/** Checks whether the given affine map is a single result constant affine + * map. */ +int mlirAffineMapIsSingleConstant(MlirAffineMap affineMap); + +/** Returns the constant result of the given affine map. The function asserts + * that the map has a single constant result. */ +int64_t mlirAffineMapGetSingleConstantResult(MlirAffineMap affineMap); + +/** Returns the number of dimensions of the given affine map. */ +intptr_t mlirAffineMapGetNumDims(MlirAffineMap affineMap); + +/** Returns the number of symbols of the given affine map. */ +intptr_t mlirAffineMapGetNumSymbols(MlirAffineMap affineMap); + +/** Returns the number of results of the given affine map. */ +intptr_t mlirAffineMapGetNumResults(MlirAffineMap affineMap); + +/** Returns the number of inputs (dimensions + symbols) of the given affine + * map. */ +intptr_t mlirAffineMapGetNumInputs(MlirAffineMap affineMap); + +/** Checks whether the given affine map represents a subset of a symbol-less + * permutation map. */ +int mlirAffineMapIsProjectedPermutation(MlirAffineMap affineMap); + +/** Checks whether the given affine map represents a symbol-less permutation + * map. */ +int mlirAffineMapIsPermutation(MlirAffineMap affineMap); + +/** Returns the affine map consisting of the `resultPos` subset. */ +MlirAffineMap mlirAffineMapGetSubMap(MlirAffineMap affineMap, intptr_t size, + intptr_t *resultPos); + +/** Returns the affine map consisting of the most major `numResults` results. + * Returns the null AffineMap if the `numResults` is equal to zero. + * Returns the `affineMap` if `numResults` is greater or equals to number of + * results of the given affine map. */ +MlirAffineMap mlirAffineMapGetMajorSubMap(MlirAffineMap affineMap, + intptr_t numResults); + +/** Returns the affine map consisting of the most minor `numResults` results. + * Returns the null AffineMap if the `numResults` is equal to zero. + * Returns the `affineMap` if `numResults` is greater or equals to number of + * results of the given affine map. */ +MlirAffineMap mlirAffineMapGetMinorSubMap(MlirAffineMap affineMap, + intptr_t numResults); + #ifdef __cplusplus } #endif diff --git a/mlir/include/mlir/CAPI/Utils.h b/mlir/include/mlir/CAPI/Utils.h new file mode 100644 index 0000000..022f09d --- /dev/null +++ b/mlir/include/mlir/CAPI/Utils.h @@ -0,0 +1,48 @@ +//===- Utils.h - C API General Utilities ------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines general utilities for C API. This file should not be +// included from C++ code other than C API implementation nor from C code. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CAPI_UTILS_H +#define MLIR_CAPI_UTILS_H + +#include "llvm/Support/raw_ostream.h" + +/* ========================================================================== */ +/* Printing helper. */ +/* ========================================================================== */ + +namespace mlir { +namespace detail { +/// A simple raw ostream subclass that forwards write_impl calls to the +/// user-supplied callback together with opaque user-supplied data. +class CallbackOstream : public llvm::raw_ostream { +public: + CallbackOstream(std::function callback, + void *opaqueData) + : callback(callback), opaqueData(opaqueData), pos(0u) {} + + void write_impl(const char *ptr, size_t size) override { + callback(ptr, size, opaqueData); + pos += size; + } + + uint64_t current_pos() const override { return pos; } + +private: + std::function callback; + void *opaqueData; + uint64_t pos; +}; +} // end namespace detail +} // end namespace mlir + +#endif // MLIR_CAPI_UTILS_H diff --git a/mlir/lib/CAPI/IR/AffineMap.cpp b/mlir/lib/CAPI/IR/AffineMap.cpp index d80d9e2..6a87c26 100644 --- a/mlir/lib/CAPI/IR/AffineMap.cpp +++ b/mlir/lib/CAPI/IR/AffineMap.cpp @@ -9,7 +9,119 @@ #include "mlir-c/AffineMap.h" #include "mlir-c/IR.h" #include "mlir/CAPI/AffineMap.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Utils.h" #include "mlir/IR/AffineMap.h" -// This is a placeholder for affine map bindings. The file is here to serve as a -// compilation unit that includes the headers. +// TODO: expose the C API related to `AffineExpr` and mutable affine map. + +using namespace mlir; + +MlirContext mlirAffineMapGetContext(MlirAffineMap affineMap) { + return wrap(unwrap(affineMap).getContext()); +} + +int mlirAffineMapEqual(MlirAffineMap a1, MlirAffineMap a2) { + return unwrap(a1) == unwrap(a2); +} + +void mlirAffineMapPrint(MlirAffineMap affineMap, MlirStringCallback callback, + void *userData) { + mlir::detail::CallbackOstream stream(callback, userData); + unwrap(affineMap).print(stream); + stream.flush(); +} + +void mlirAffineMapDump(MlirAffineMap affineMap) { unwrap(affineMap).dump(); } + +MlirAffineMap mlirAffineMapEmptyGet(MlirContext ctx) { + return wrap(AffineMap::get(unwrap(ctx))); +} + +MlirAffineMap mlirAffineMapGet(MlirContext ctx, intptr_t dimCount, + intptr_t symbolCount) { + return wrap(AffineMap::get(dimCount, symbolCount, unwrap(ctx))); +} + +MlirAffineMap mlirAffineMapConstantGet(MlirContext ctx, int64_t val) { + return wrap(AffineMap::getConstantMap(val, unwrap(ctx))); +} + +MlirAffineMap mlirAffineMapMultiDimIdentityGet(MlirContext ctx, + intptr_t numDims) { + return wrap(AffineMap::getMultiDimIdentityMap(numDims, unwrap(ctx))); +} + +MlirAffineMap mlirAffineMapMinorIdentityGet(MlirContext ctx, intptr_t dims, + intptr_t results) { + return wrap(AffineMap::getMinorIdentityMap(dims, results, unwrap(ctx))); +} + +MlirAffineMap mlirAffineMapPermutationGet(MlirContext ctx, intptr_t size, + unsigned *permutation) { + return wrap(AffineMap::getPermutationMap( + llvm::makeArrayRef(permutation, static_cast(size)), unwrap(ctx))); +} + +int mlirAffineMapIsIdentity(MlirAffineMap affineMap) { + return unwrap(affineMap).isIdentity(); +} + +int mlirAffineMapIsMinorIdentity(MlirAffineMap affineMap) { + return unwrap(affineMap).isMinorIdentity(); +} + +int mlirAffineMapIsEmpty(MlirAffineMap affineMap) { + return unwrap(affineMap).isEmpty(); +} + +int mlirAffineMapIsSingleConstant(MlirAffineMap affineMap) { + return unwrap(affineMap).isSingleConstant(); +} + +int64_t mlirAffineMapGetSingleConstantResult(MlirAffineMap affineMap) { + return unwrap(affineMap).getSingleConstantResult(); +} + +intptr_t mlirAffineMapGetNumDims(MlirAffineMap affineMap) { + return unwrap(affineMap).getNumDims(); +} + +intptr_t mlirAffineMapGetNumSymbols(MlirAffineMap affineMap) { + return unwrap(affineMap).getNumSymbols(); +} + +intptr_t mlirAffineMapGetNumResults(MlirAffineMap affineMap) { + return unwrap(affineMap).getNumResults(); +} + +intptr_t mlirAffineMapGetNumInputs(MlirAffineMap affineMap) { + return unwrap(affineMap).getNumInputs(); +} + +int mlirAffineMapIsProjectedPermutation(MlirAffineMap affineMap) { + return unwrap(affineMap).isProjectedPermutation(); +} + +int mlirAffineMapIsPermutation(MlirAffineMap affineMap) { + return unwrap(affineMap).isPermutation(); +} + +MlirAffineMap mlirAffineMapGetSubMap(MlirAffineMap affineMap, intptr_t size, + intptr_t *resultPos) { + SmallVector pos; + pos.reserve(size); + for (intptr_t i = 0; i < size; ++i) + pos.push_back(static_cast(resultPos[i])); + return wrap(unwrap(affineMap).getSubMap(pos)); +} + +MlirAffineMap mlirAffineMapGetMajorSubMap(MlirAffineMap affineMap, + intptr_t numResults) { + return wrap(unwrap(affineMap).getMajorSubMap(numResults)); +} + +MlirAffineMap mlirAffineMapGetMinorSubMap(MlirAffineMap affineMap, + intptr_t numResults) { + return wrap(unwrap(affineMap).getMinorSubMap(numResults)); +} diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 2a008a2..8611d65 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -9,44 +9,17 @@ #include "mlir-c/IR.h" #include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Utils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Module.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Types.h" #include "mlir/Parser.h" -#include "llvm/Support/raw_ostream.h" using namespace mlir; /* ========================================================================== */ -/* Printing helper. */ -/* ========================================================================== */ - -namespace { -/// A simple raw ostream subclass that forwards write_impl calls to the -/// user-supplied callback together with opaque user-supplied data. -class CallbackOstream : public llvm::raw_ostream { -public: - CallbackOstream(std::function callback, - void *opaqueData) - : callback(callback), opaqueData(opaqueData), pos(0u) {} - - void write_impl(const char *ptr, size_t size) override { - callback(ptr, size, opaqueData); - pos += size; - } - - uint64_t current_pos() const override { return pos; } - -private: - std::function callback; - void *opaqueData; - uint64_t pos; -}; -} // end namespace - -/* ========================================================================== */ /* Context API. */ /* ========================================================================== */ @@ -77,7 +50,7 @@ MlirLocation mlirLocationUnknownGet(MlirContext context) { void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, void *userData) { - CallbackOstream stream(callback, userData); + detail::CallbackOstream stream(callback, userData); unwrap(location).print(stream); stream.flush(); } @@ -244,7 +217,7 @@ MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, void mlirOperationPrint(MlirOperation op, MlirStringCallback callback, void *userData) { - CallbackOstream stream(callback, userData); + detail::CallbackOstream stream(callback, userData); unwrap(op)->print(stream); stream.flush(); } @@ -326,7 +299,7 @@ MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData) { - CallbackOstream stream(callback, userData); + detail::CallbackOstream stream(callback, userData); unwrap(block)->print(stream); stream.flush(); } @@ -341,7 +314,7 @@ MlirType mlirValueGetType(MlirValue value) { void mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData) { - CallbackOstream stream(callback, userData); + detail::CallbackOstream stream(callback, userData); unwrap(value).print(stream); stream.flush(); } @@ -361,7 +334,7 @@ MlirContext mlirTypeGetContext(MlirType type) { int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); } void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) { - CallbackOstream stream(callback, userData); + detail::CallbackOstream stream(callback, userData); unwrap(type).print(stream); stream.flush(); } @@ -382,7 +355,7 @@ int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, void *userData) { - CallbackOstream stream(callback, userData); + detail::CallbackOstream stream(callback, userData); unwrap(attr).print(stream); stream.flush(); } diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c index ceb19ef..fa63c72 100644 --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -10,6 +10,7 @@ /* RUN: mlir-capi-ir-test 2>&1 | FileCheck %s */ +#include "mlir-c/AffineMap.h" #include "mlir-c/IR.h" #include "mlir-c/Registration.h" #include "mlir-c/StandardAttributes.h" @@ -593,6 +594,121 @@ int printStandardAttributes(MlirContext ctx) { return 0; } +int printAffineMap(MlirContext ctx) { + MlirAffineMap emptyAffineMap = mlirAffineMapEmptyGet(ctx); + MlirAffineMap affineMap = mlirAffineMapGet(ctx, 3, 2); + MlirAffineMap constAffineMap = mlirAffineMapConstantGet(ctx, 2); + MlirAffineMap multiDimIdentityAffineMap = + mlirAffineMapMultiDimIdentityGet(ctx, 3); + MlirAffineMap minorIdentityAffineMap = + mlirAffineMapMinorIdentityGet(ctx, 3, 2); + unsigned permutation[] = {1, 2, 0}; + MlirAffineMap permutationAffineMap = mlirAffineMapPermutationGet( + ctx, sizeof(permutation) / sizeof(unsigned), permutation); + + mlirAffineMapDump(emptyAffineMap); + mlirAffineMapDump(affineMap); + mlirAffineMapDump(constAffineMap); + mlirAffineMapDump(multiDimIdentityAffineMap); + mlirAffineMapDump(minorIdentityAffineMap); + mlirAffineMapDump(permutationAffineMap); + + if (!mlirAffineMapIsIdentity(emptyAffineMap) || + mlirAffineMapIsIdentity(affineMap) || + mlirAffineMapIsIdentity(constAffineMap) || + !mlirAffineMapIsIdentity(multiDimIdentityAffineMap) || + mlirAffineMapIsIdentity(minorIdentityAffineMap) || + mlirAffineMapIsIdentity(permutationAffineMap)) + return 1; + + if (!mlirAffineMapIsMinorIdentity(emptyAffineMap) || + mlirAffineMapIsMinorIdentity(affineMap) || + !mlirAffineMapIsMinorIdentity(multiDimIdentityAffineMap) || + !mlirAffineMapIsMinorIdentity(minorIdentityAffineMap) || + mlirAffineMapIsMinorIdentity(permutationAffineMap)) + return 2; + + if (!mlirAffineMapIsEmpty(emptyAffineMap) || + mlirAffineMapIsEmpty(affineMap) || + mlirAffineMapIsEmpty(constAffineMap) || + mlirAffineMapIsEmpty(multiDimIdentityAffineMap) || + mlirAffineMapIsEmpty(minorIdentityAffineMap) || + mlirAffineMapIsEmpty(permutationAffineMap)) + return 3; + + if (mlirAffineMapIsSingleConstant(emptyAffineMap) || + mlirAffineMapIsSingleConstant(affineMap) || + !mlirAffineMapIsSingleConstant(constAffineMap) || + mlirAffineMapIsSingleConstant(multiDimIdentityAffineMap) || + mlirAffineMapIsSingleConstant(minorIdentityAffineMap) || + mlirAffineMapIsSingleConstant(permutationAffineMap)) + return 4; + + if (mlirAffineMapGetSingleConstantResult(constAffineMap) != 2) + return 5; + + if (mlirAffineMapGetNumDims(emptyAffineMap) != 0 || + mlirAffineMapGetNumDims(affineMap) != 3 || + mlirAffineMapGetNumDims(constAffineMap) != 0 || + mlirAffineMapGetNumDims(multiDimIdentityAffineMap) != 3 || + mlirAffineMapGetNumDims(minorIdentityAffineMap) != 3 || + mlirAffineMapGetNumDims(permutationAffineMap) != 3) + return 6; + + if (mlirAffineMapGetNumSymbols(emptyAffineMap) != 0 || + mlirAffineMapGetNumSymbols(affineMap) != 2 || + mlirAffineMapGetNumSymbols(constAffineMap) != 0 || + mlirAffineMapGetNumSymbols(multiDimIdentityAffineMap) != 0 || + mlirAffineMapGetNumSymbols(minorIdentityAffineMap) != 0 || + mlirAffineMapGetNumSymbols(permutationAffineMap) != 0) + return 7; + + if (mlirAffineMapGetNumResults(emptyAffineMap) != 0 || + mlirAffineMapGetNumResults(affineMap) != 0 || + mlirAffineMapGetNumResults(constAffineMap) != 1 || + mlirAffineMapGetNumResults(multiDimIdentityAffineMap) != 3 || + mlirAffineMapGetNumResults(minorIdentityAffineMap) != 2 || + mlirAffineMapGetNumResults(permutationAffineMap) != 3) + return 8; + + if (mlirAffineMapGetNumInputs(emptyAffineMap) != 0 || + mlirAffineMapGetNumInputs(affineMap) != 5 || + mlirAffineMapGetNumInputs(constAffineMap) != 0 || + mlirAffineMapGetNumInputs(multiDimIdentityAffineMap) != 3 || + mlirAffineMapGetNumInputs(minorIdentityAffineMap) != 3 || + mlirAffineMapGetNumInputs(permutationAffineMap) != 3) + return 9; + + if (!mlirAffineMapIsProjectedPermutation(emptyAffineMap) || + !mlirAffineMapIsPermutation(emptyAffineMap) || + mlirAffineMapIsProjectedPermutation(affineMap) || + mlirAffineMapIsPermutation(affineMap) || + mlirAffineMapIsProjectedPermutation(constAffineMap) || + mlirAffineMapIsPermutation(constAffineMap) || + !mlirAffineMapIsProjectedPermutation(multiDimIdentityAffineMap) || + !mlirAffineMapIsPermutation(multiDimIdentityAffineMap) || + !mlirAffineMapIsProjectedPermutation(minorIdentityAffineMap) || + mlirAffineMapIsPermutation(minorIdentityAffineMap) || + !mlirAffineMapIsProjectedPermutation(permutationAffineMap) || + !mlirAffineMapIsPermutation(permutationAffineMap)) + return 10; + + intptr_t sub[] = {1}; + + MlirAffineMap subMap = mlirAffineMapGetSubMap( + multiDimIdentityAffineMap, sizeof(sub) / sizeof(intptr_t), sub); + MlirAffineMap majorSubMap = + mlirAffineMapGetMajorSubMap(multiDimIdentityAffineMap, 1); + MlirAffineMap minorSubMap = + mlirAffineMapGetMinorSubMap(multiDimIdentityAffineMap, 1); + + mlirAffineMapDump(subMap); + mlirAffineMapDump(majorSubMap); + mlirAffineMapDump(minorSubMap); + + return 0; +} + int main() { MlirContext ctx = mlirContextCreate(); mlirRegisterAllDialects(ctx); @@ -704,6 +820,22 @@ int main() { errcode = printStandardAttributes(ctx); fprintf(stderr, "%d\n", errcode); + // clang-format off + // CHECK-LABEL: @affineMap + // CHECK: () -> () + // CHECK: (d0, d1, d2)[s0, s1] -> () + // CHECK: () -> (2) + // CHECK: (d0, d1, d2) -> (d0, d1, d2) + // CHECK: (d0, d1, d2) -> (d1, d2) + // CHECK: (d0, d1, d2) -> (d1, d2, d0) + // CHECK: (d0, d1, d2) -> (d1) + // CHECK: (d0, d1, d2) -> (d0) + // CHECK: (d0, d1, d2) -> (d2) + // CHECK: 0 + fprintf(stderr, "@affineMap\n"); + errcode = printAffineMap(ctx); + fprintf(stderr, "%d\n", errcode); + mlirContextDestroy(ctx); return 0; -- 2.7.4