From 3a3a09f65412dc38aba6b7370b93f9d2c7fd1c30 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 4 Oct 2021 11:38:53 +0200 Subject: [PATCH] [mlir][python] Provide more convenient wrappers for std.ConstantOp Constructing a ConstantOp using the default-generated API is verbose and requires to specify the constant type twice: for the result type of the operation and for the type of the attribute. It also requires to explicitly construct the attribute. Provide custom constructors that take the type once and accept a raw value instead of the attribute. This requires dynamic dispatch based on type in the constructor. Also provide the corresponding accessors to raw values. In addition, provide a "refinement" class ConstantIndexOp similar to what exists in C++. Unlike other "op view" Python classes, operations cannot be automatically downcasted to this class since it does not correspond to a specific operation name. It only exists to simplify construction of the operation. Depends On D110946 Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D110947 --- mlir/python/CMakeLists.txt | 4 +- mlir/python/mlir/dialects/_std_ops_ext.py | 71 +++++++++++++++++++++++++++++++ mlir/test/python/dialects/std.py | 64 ++++++++++++++++++++++++++++ 3 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 mlir/python/mlir/dialects/_std_ops_ext.py create mode 100644 mlir/test/python/dialects/std.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index eb7e1e4..4f0d154 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -136,7 +136,9 @@ declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/StandardOps.td - SOURCES dialects/std.py + SOURCES + dialects/std.py + dialects/_std_ops_ext.py DIALECT_NAME std) declare_mlir_dialect_python_bindings( diff --git a/mlir/python/mlir/dialects/_std_ops_ext.py b/mlir/python/mlir/dialects/_std_ops_ext.py new file mode 100644 index 0000000..bb67fe4 --- /dev/null +++ b/mlir/python/mlir/dialects/_std_ops_ext.py @@ -0,0 +1,71 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +try: + from ..ir import * + from .builtin import FuncOp + from ._ods_common import get_default_loc_context as _get_default_loc_context + + from typing import Any, List, Optional, Union +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + + +def _isa(obj: Any, cls: type): + try: + cls(obj) + except ValueError: + return False + return True + + +def _is_any_of(obj: Any, classes: List[type]): + return any(_isa(obj, cls) for cls in classes) + + +def _is_integer_like_type(type: Type): + return _is_any_of(type, [IntegerType, IndexType]) + + +def _is_float_type(type: Type): + return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type]) + + +class ConstantOp: + """Specialization for the constant op class.""" + + def __init__(self, + result: Type, + value: Union[int, float, Attribute], + *, + loc=None, + ip=None): + if isinstance(value, int): + super().__init__(result, IntegerAttr.get(result, value), loc=loc, ip=ip) + elif isinstance(value, float): + super().__init__(result, FloatAttr.get(result, value), loc=loc, ip=ip) + else: + super().__init__(result, value, loc=loc, ip=ip) + + @classmethod + def create_index(cls, value: int, *, loc=None, ip=None): + """Create an index-typed constant.""" + return cls( + IndexType.get(context=_get_default_loc_context(loc)), + value, + loc=loc, + ip=ip) + + @property + def type(self): + return self.results[0].type + + @property + def literal_value(self) -> Union[int, float]: + if _is_integer_like_type(self.type): + return IntegerAttr(self.value).value + elif _is_float_type(self.type): + return FloatAttr(self.value).value + else: + raise ValueError("only integer and float constants have literal values") diff --git a/mlir/test/python/dialects/std.py b/mlir/test/python/dialects/std.py new file mode 100644 index 0000000..ed507d6 --- /dev/null +++ b/mlir/test/python/dialects/std.py @@ -0,0 +1,64 @@ +# RUN: %PYTHON %s | FileCheck %s + +from mlir.ir import * +from mlir.dialects import std + + +def constructAndPrintInModule(f): + print("\nTEST:", f.__name__) + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + f() + print(module) + return f + +# CHECK-LABEL: TEST: testConstantOp + +@constructAndPrintInModule +def testConstantOp(): + c1 = std.ConstantOp(IntegerType.get_signless(32), 42) + c2 = std.ConstantOp(IntegerType.get_signless(64), 100) + c3 = std.ConstantOp(F32Type.get(), 3.14) + c4 = std.ConstantOp(F64Type.get(), 1.23) + # CHECK: 42 + print(c1.literal_value) + + # CHECK: 100 + print(c2.literal_value) + + # CHECK: 3.140000104904175 + print(c3.literal_value) + + # CHECK: 1.23 + print(c4.literal_value) + +# CHECK: = constant 42 : i32 +# CHECK: = constant 100 : i64 +# CHECK: = constant 3.140000e+00 : f32 +# CHECK: = constant 1.230000e+00 : f64 + +# CHECK-LABEL: TEST: testVectorConstantOp +@constructAndPrintInModule +def testVectorConstantOp(): + int_type = IntegerType.get_signless(32) + vec_type = VectorType.get([2, 2], int_type) + c1 = std.ConstantOp(vec_type, + DenseElementsAttr.get_splat(vec_type, IntegerAttr.get(int_type, 42))) + try: + print(c1.literal_value) + except ValueError as e: + assert "only integer and float constants have literal values" in str(e) + else: + assert False + +# CHECK: = constant dense<42> : vector<2x2xi32> + +# CHECK-LABEL: TEST: testConstantIndexOp +@constructAndPrintInModule +def testConstantIndexOp(): + c1 = std.ConstantOp.create_index(10) + # CHECK: 10 + print(c1.literal_value) + +# CHECK: = constant 10 : index -- 2.7.4