--- /dev/null
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+try:
+ from ..ir import *
+ from .builtin import FuncOp
+ from ._ods_common import get_default_loc_context as _get_default_loc_context
+
+ from typing import Any, List, Optional, Union
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+
+def _isa(obj: Any, cls: type):
+ try:
+ cls(obj)
+ except ValueError:
+ return False
+ return True
+
+
+def _is_any_of(obj: Any, classes: List[type]):
+ return any(_isa(obj, cls) for cls in classes)
+
+
+def _is_integer_like_type(type: Type):
+ return _is_any_of(type, [IntegerType, IndexType])
+
+
+def _is_float_type(type: Type):
+ return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
+
+
+class ConstantOp:
+ """Specialization for the constant op class."""
+
+ def __init__(self,
+ result: Type,
+ value: Union[int, float, Attribute],
+ *,
+ loc=None,
+ ip=None):
+ if isinstance(value, int):
+ super().__init__(result, IntegerAttr.get(result, value), loc=loc, ip=ip)
+ elif isinstance(value, float):
+ super().__init__(result, FloatAttr.get(result, value), loc=loc, ip=ip)
+ else:
+ super().__init__(result, value, loc=loc, ip=ip)
+
+ @classmethod
+ def create_index(cls, value: int, *, loc=None, ip=None):
+ """Create an index-typed constant."""
+ return cls(
+ IndexType.get(context=_get_default_loc_context(loc)),
+ value,
+ loc=loc,
+ ip=ip)
+
+ @property
+ def type(self):
+ return self.results[0].type
+
+ @property
+ def literal_value(self) -> Union[int, float]:
+ if _is_integer_like_type(self.type):
+ return IntegerAttr(self.value).value
+ elif _is_float_type(self.type):
+ return FloatAttr(self.value).value
+ else:
+ raise ValueError("only integer and float constants have literal values")
--- /dev/null
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import std
+
+
+def constructAndPrintInModule(f):
+ print("\nTEST:", f.__name__)
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ f()
+ print(module)
+ return f
+
+# CHECK-LABEL: TEST: testConstantOp
+
+@constructAndPrintInModule
+def testConstantOp():
+ c1 = std.ConstantOp(IntegerType.get_signless(32), 42)
+ c2 = std.ConstantOp(IntegerType.get_signless(64), 100)
+ c3 = std.ConstantOp(F32Type.get(), 3.14)
+ c4 = std.ConstantOp(F64Type.get(), 1.23)
+ # CHECK: 42
+ print(c1.literal_value)
+
+ # CHECK: 100
+ print(c2.literal_value)
+
+ # CHECK: 3.140000104904175
+ print(c3.literal_value)
+
+ # CHECK: 1.23
+ print(c4.literal_value)
+
+# CHECK: = constant 42 : i32
+# CHECK: = constant 100 : i64
+# CHECK: = constant 3.140000e+00 : f32
+# CHECK: = constant 1.230000e+00 : f64
+
+# CHECK-LABEL: TEST: testVectorConstantOp
+@constructAndPrintInModule
+def testVectorConstantOp():
+ int_type = IntegerType.get_signless(32)
+ vec_type = VectorType.get([2, 2], int_type)
+ c1 = std.ConstantOp(vec_type,
+ DenseElementsAttr.get_splat(vec_type, IntegerAttr.get(int_type, 42)))
+ try:
+ print(c1.literal_value)
+ except ValueError as e:
+ assert "only integer and float constants have literal values" in str(e)
+ else:
+ assert False
+
+# CHECK: = constant dense<42> : vector<2x2xi32>
+
+# CHECK-LABEL: TEST: testConstantIndexOp
+@constructAndPrintInModule
+def testConstantIndexOp():
+ c1 = std.ConstantOp.create_index(10)
+ # CHECK: 10
+ print(c1.literal_value)
+
+# CHECK: = constant 10 : index