// Index cast is applicable from index to integer and backwards.
bool IndexCastOp::areCastCompatible(Type a, Type b) {
+ if (a.isa<ShapedType>() && b.isa<ShapedType>()) {
+ auto aShaped = a.cast<ShapedType>();
+ auto bShaped = b.cast<ShapedType>();
+
+ return (aShaped.getShape() == bShaped.getShape()) &&
+ areCastCompatible(aShaped.getElementType(),
+ bShaped.getElementType());
+ }
+
return (a.isIndex() && b.isSignlessInteger()) ||
(a.isSignlessInteger() && b.isIndex());
}
--- /dev/null
+// RUN: mlir-opt -split-input-file %s -verify-diagnostics
+
+// CHECK-LABEL: test_index_cast_shape_error
+func @test_index_cast_shape_error(%arg0 : tensor<index>) -> tensor<2xi64> {
+ // expected-error @+1 {{operand type 'tensor<index>' and result type 'tensor<2xi64>' are cast incompatible}}
+ %0 = index_cast %arg0 : tensor<index> to tensor<2xi64>
+ return %0 : tensor<2xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_index_cast_tensor_error
+func @test_index_cast_tensor_error(%arg0 : tensor<index>) -> i64 {
+ // expected-error @+1 {{operand type 'tensor<index>' and result type 'i64' are cast incompatible}}
+ %0 = index_cast %arg0 : tensor<index> to i64
+ return %0 : i64
+}
--- /dev/null
+// RUN: mlir-opt -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: test_index_cast
+func @test_index_cast(%arg0 : index) -> i64 {
+ %0 = index_cast %arg0 : index to i64
+ return %0 : i64
+}
+
+// CHECK-LABEL: test_index_cast_tensor
+func @test_index_cast_tensor(%arg0 : tensor<index>) -> tensor<i64> {
+ %0 = index_cast %arg0 : tensor<index> to tensor<i64>
+ return %0 : tensor<i64>
+}
+
+// CHECK-LABEL: test_index_cast_tensor_reverse
+func @test_index_cast_tensor_reverse(%arg0 : tensor<i64>) -> tensor<index> {
+ %0 = index_cast %arg0 : tensor<i64> to tensor<index>
+ return %0 : tensor<index>
+}
+