#include "IRModule.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Interfaces.h"
+#include "llvm/ADT/STLExtras.h"
namespace py = pybind11;
}
/// Given the arguments required to build an operation, attempts to infer its
- /// return types. Throws value_error on faliure.
+ /// return types. Throws value_error on failure.
std::vector<PyType>
- inferReturnTypes(std::optional<std::vector<PyValue>> operands,
+ inferReturnTypes(std::optional<py::list> operandList,
std::optional<PyAttribute> attributes,
std::optional<std::vector<PyRegion>> regions,
DefaultingPyMlirContext context,
llvm::SmallVector<MlirValue> mlirOperands;
llvm::SmallVector<MlirRegion> mlirRegions;
- if (operands) {
- mlirOperands.reserve(operands->size());
- for (PyValue &value : *operands) {
- mlirOperands.push_back(value);
+ if (operandList && !operandList->empty()) {
+      // Note: since the list may contain nested sequences of values, this may
+      // not be the final size.
+ mlirOperands.reserve(operandList->size());
+ for (const auto& it : llvm::enumerate(*operandList)) {
+ PyValue* val;
+ try {
+ val = py::cast<PyValue *>(it.value());
+ if (!val)
+ throw py::cast_error();
+ mlirOperands.push_back(val->get());
+ continue;
+ } catch (py::cast_error &err) {
+ }
+
+ try {
+ auto vals = py::cast<py::sequence>(it.value());
+ for (py::object v : vals) {
+ try {
+ val = py::cast<PyValue *>(v);
+ if (!val)
+ throw py::cast_error();
+ mlirOperands.push_back(val->get());
+ } catch (py::cast_error &err) {
+ throw py::value_error(
+ (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+ " must be a Value or Sequence of Values (" + err.what() +
+ ")")
+ .str());
+ }
+ }
+ continue;
+ } catch (py::cast_error &err) {
+ throw py::value_error(
+ (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+ " must be a Value or Sequence of Values (" + err.what() + ")")
+ .str());
+ }
+
+ throw py::cast_error();
}
}
ValueRange values) {
// Check static and dynamic offsets/sizes/strides does not overflow type.
if (staticVals.size() != numElements)
- return op->emitError("expected ")
- << numElements << " " << name << " values";
+ return op->emitError("expected ") << numElements << " " << name
+ << " values, got " << staticVals.size();
unsigned expectedNumDynamicEntries =
llvm::count_if(staticVals, [&](int64_t staticVal) {
return ShapedType::isDynamic(staticVal);
class InferTypeOpInterface:
def __init__(self, object: object, context: Optional[Context] = None) -> None: ...
- def inferReturnTypes(self, operands: Optional[List[Value]] = None, attributes: Optional[Attribute] = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[Type]: ...
+ def inferReturnTypes(self, operands: Optional[List] = None, attributes: Optional[Attribute] = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[Type]: ...
@property
def operation(self) -> Operation: ...
@property
return tensor.EmptyOp([], f32)
print(module)
+
+
+# CHECK-LABEL: TEST: testInferTypesInsertSlice
+@run
+def testInferTypesInsertSlice():
+  # Builds a tensor.insert_slice op WITHOUT an explicit result type and checks
+  # (via the FileCheck lines below) that the result type is inferred from the
+  # operands — exercising the InferTypeOpInterface path through the bindings.
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32Type = F32Type.get()
+    indexType = IndexType.get()
+    with InsertionPoint(module.body):
+
+      @func.FuncOp.from_py_func(
+          RankedTensorType.get((1, 1), f32Type),
+          RankedTensorType.get((1, 1), f32Type))
+      # CHECK: func @f
+      # CHECK: tensor.insert_slice %arg0 into %arg1[0, 0] [1, 1] [0, 0] :
+      # CHECK-SAME: tensor<1x1xf32> into tensor<1x1xf32>
+      def f(source, dest):
+        # NOTE(review): c0 and c1 appear unused — the offsets/sizes/strides
+        # below are passed as static attributes, not SSA values. Consider
+        # removing them (or using them as dynamic operands) — TODO confirm.
+        c0 = arith.ConstantOp(indexType, 0)
+        c1 = arith.ConstantOp(indexType, 1)
+        # No result type argument: per the CHECK lines above, the printed op
+        # carries the inferred tensor<1x1xf32> result. The three attrs are the
+        # static offsets [0, 0], sizes [1, 1], and strides [0, 0], matching the
+        # bracketed groups in the CHECK output.
+        d0 = tensor.InsertSliceOp(source, dest, [], [], [],
+                                  DenseI64ArrayAttr.get([0, 0]),
+                                  DenseI64ArrayAttr.get([1, 1]),
+                                  DenseI64ArrayAttr.get([0, 0]))
+        return [d0.result]
+
+    print(module)