align parameters of arguments to binary isl_map and isl_pw_aff functions
[platform/upstream/isl.git] / isl_map.c
index 12369e9..8ed3e44 100644 (file)
--- a/isl_map.c
+++ b/isl_map.c
@@ -811,6 +811,52 @@ static int room_for_con(struct isl_basic_map *bmap, unsigned n)
        return bmap->n_eq + bmap->n_ineq + n <= bmap->c_size;
 }
 
+__isl_give isl_map *isl_map_align_params_map_map_and(
+       __isl_take isl_map *map1, __isl_take isl_map *map2,
+       __isl_give isl_map *(*fn)(__isl_take isl_map *map1,
+                                   __isl_take isl_map *map2))
+{
+       if (!map1 || !map2)
+               goto error;
+       if (isl_dim_match(map1->dim, isl_dim_param, map2->dim, isl_dim_param))
+               return fn(map1, map2);
+       if (!isl_dim_has_named_params(map1->dim) ||
+           !isl_dim_has_named_params(map2->dim))
+               isl_die(map1->ctx, isl_error_invalid,
+                       "unaligned unnamed parameters", goto error);
+       map1 = isl_map_align_params(map1, isl_map_get_dim(map2));
+       map2 = isl_map_align_params(map2, isl_map_get_dim(map1));
+       return fn(map1, map2);
+error:
+       isl_map_free(map1);
+       isl_map_free(map2);
+       return NULL;
+}
+
+static int align_params_map_map_and_test(__isl_keep isl_map *map1,
+       __isl_keep isl_map *map2,
+       int (*fn)(__isl_keep isl_map *map1, __isl_keep isl_map *map2))
+{
+       int r;
+
+       if (!map1 || !map2)
+               return -1;
+       if (isl_dim_match(map1->dim, isl_dim_param, map2->dim, isl_dim_param))
+               return fn(map1, map2);
+       if (!isl_dim_has_named_params(map1->dim) ||
+           !isl_dim_has_named_params(map2->dim))
+               isl_die(map1->ctx, isl_error_invalid,
+                       "unaligned unnamed parameters", return -1);
+       map1 = isl_map_copy(map1);
+       map2 = isl_map_copy(map2);
+       map1 = isl_map_align_params(map1, isl_map_get_dim(map2));
+       map2 = isl_map_align_params(map2, isl_map_get_dim(map1));
+       r = fn(map1, map2);
+       isl_map_free(map1);
+       isl_map_free(map2);
+       return r;
+}
+
 int isl_basic_map_alloc_equality(struct isl_basic_map *bmap)
 {
        struct isl_ctx *ctx;
@@ -2356,7 +2402,8 @@ error:
        return NULL;
 }
 
-struct isl_map *isl_map_intersect(struct isl_map *map1, struct isl_map *map2)
+static __isl_give isl_map *map_intersect(__isl_take isl_map *map1,
+       __isl_take isl_map *map2)
 {
        unsigned flags = 0;
        struct isl_map *result;
@@ -2380,8 +2427,6 @@ struct isl_map *isl_map_intersect(struct isl_map *map1, struct isl_map *map2)
            (map1->p[0]->n_eq + map1->p[0]->n_ineq == 1 ||
             map2->p[0]->n_eq + map2->p[0]->n_ineq == 1))
                return map_intersect_add_constraint(map1, map2);
-       isl_assert(map1->ctx, isl_dim_match(map1->dim, isl_dim_param,
-                                        map2->dim, isl_dim_param), goto error);
        if (isl_dim_total(map1->dim) ==
                                isl_dim_size(map1->dim, isl_dim_param) &&
            isl_dim_total(map2->dim) != isl_dim_size(map2->dim, isl_dim_param))
@@ -2421,6 +2466,12 @@ error:
        return NULL;
 }
 
+__isl_give isl_map *isl_map_intersect(__isl_take isl_map *map1,
+       __isl_take isl_map *map2)
+{
+       return isl_map_align_params_map_map_and(map1, map2, &map_intersect);
+}
+
 struct isl_set *isl_set_intersect(struct isl_set *set1, struct isl_set *set2)
 {
        return (struct isl_set *)
@@ -2431,12 +2482,18 @@ struct isl_set *isl_set_intersect(struct isl_set *set1, struct isl_set *set2)
 /* The current implementation of isl_map_intersect accepts intersections
  * with parameter domains, so we can just call that for now.
  */
-__isl_give isl_map *isl_map_intersect_params(__isl_take isl_map *map,
+static __isl_give isl_map *map_intersect_params(__isl_take isl_map *map,
                __isl_take isl_set *params)
 {
        return isl_map_intersect(map, params);
 }
 
+__isl_give isl_map *isl_map_intersect_params(__isl_take isl_map *map1,
+       __isl_take isl_map *map2)
+{
+       return isl_map_align_params_map_map_and(map1, map2, &map_intersect_params);
+}
+
 __isl_give isl_set *isl_set_intersect_params(__isl_take isl_set *set,
                __isl_take isl_set *params)
 {
@@ -5633,8 +5690,8 @@ error:
        return NULL;
 }
 
-struct isl_map *isl_map_union_disjoint(
-                       struct isl_map *map1, struct isl_map *map2)
+static __isl_give isl_map *map_union_disjoint(__isl_take isl_map *map1,
+       __isl_take isl_map *map2)
 {
        int i;
        unsigned flags = 0;
@@ -5684,6 +5741,12 @@ error:
        return NULL;
 }
 
+__isl_give isl_map *isl_map_union_disjoint(__isl_take isl_map *map1,
+       __isl_take isl_map *map2)
+{
+       return isl_map_align_params_map_map_and(map1, map2, &map_union_disjoint);
+}
+
 struct isl_map *isl_map_union(struct isl_map *map1, struct isl_map *map2)
 {
        map1 = isl_map_union_disjoint(map1, map2);
@@ -5708,8 +5771,8 @@ struct isl_set *isl_set_union(struct isl_set *set1, struct isl_set *set2)
                isl_map_union((struct isl_map *)set1, (struct isl_map *)set2);
 }
 
-struct isl_map *isl_map_intersect_range(
-               struct isl_map *map, struct isl_set *set)
+static __isl_give isl_map *map_intersect_range(__isl_take isl_map *map,
+       __isl_take isl_set *set)
 {
        unsigned flags = 0;
        struct isl_map *result;
@@ -5758,6 +5821,12 @@ error:
        return NULL;
 }
 
+__isl_give isl_map *isl_map_intersect_range(__isl_take isl_map *map,
+       __isl_take isl_set *set)
+{
+       return isl_map_align_params_map_map_and(map, set, &map_intersect_range);
+}
+
 struct isl_map *isl_map_intersect_domain(
                struct isl_map *map, struct isl_set *set)
 {
@@ -5765,8 +5834,8 @@ struct isl_map *isl_map_intersect_domain(
                isl_map_intersect_range(isl_map_reverse(map), set));
 }
 
-struct isl_map *isl_map_apply_domain(
-               struct isl_map *map1, struct isl_map *map2)
+static __isl_give isl_map *map_apply_domain(__isl_take isl_map *map1,
+       __isl_take isl_map *map2)
 {
        if (!map1 || !map2)
                goto error;
@@ -5779,8 +5848,14 @@ error:
        return NULL;
 }
 
-struct isl_map *isl_map_apply_range(
-               struct isl_map *map1, struct isl_map *map2)
+__isl_give isl_map *isl_map_apply_domain(__isl_take isl_map *map1,
+       __isl_take isl_map *map2)
+{
+       return isl_map_align_params_map_map_and(map1, map2, &map_apply_domain);
+}
+
+static __isl_give isl_map *map_apply_range(__isl_take isl_map *map1,
+       __isl_take isl_map *map2)
 {
        struct isl_dim *dim_result;
        struct isl_map *result;
@@ -5815,6 +5890,12 @@ error:
        return NULL;
 }
 
+__isl_give isl_map *isl_map_apply_range(__isl_take isl_map *map1,
+       __isl_take isl_map *map2)
+{
+       return isl_map_align_params_map_map_and(map1, map2, &map_apply_range);
+}
+
 /*
  * returns range - domain
  */
@@ -6317,7 +6398,7 @@ int isl_set_has_equal_dim(__isl_keep isl_set *set1, __isl_keep isl_set *set2)
        return isl_dim_equal(set1->dim, set2->dim);
 }
 
-int isl_map_is_equal(struct isl_map *map1, struct isl_map *map2)
+static int map_is_equal(__isl_keep isl_map *map1, __isl_keep isl_map *map2)
 {
        int is_subset;
 
@@ -6330,6 +6411,11 @@ int isl_map_is_equal(struct isl_map *map1, struct isl_map *map2)
        return is_subset;
 }
 
+int isl_map_is_equal(__isl_keep isl_map *map1, __isl_keep isl_map *map2)
+{
+       return align_params_map_map_and_test(map1, map2, &map_is_equal);
+}
+
 int isl_basic_map_is_strict_subset(
                struct isl_basic_map *bmap1, struct isl_basic_map *bmap2)
 {
@@ -6705,7 +6791,8 @@ struct isl_set *isl_set_align_divs(struct isl_set *set)
        return (struct isl_set *)isl_map_align_divs((struct isl_map *)set);
 }
 
-struct isl_set *isl_set_apply(struct isl_set *set, struct isl_map *map)
+static __isl_give isl_set *set_apply( __isl_take isl_set *set,
+       __isl_take isl_map *map)
 {
        if (!set || !map)
                goto error;
@@ -6719,6 +6806,12 @@ error:
        return NULL;
 }
 
+__isl_give isl_set *isl_set_apply( __isl_take isl_set *set,
+       __isl_take isl_map *map)
+{
+       return isl_map_align_params_map_map_and(set, map, &set_apply);
+}
+
 /* There is no need to cow as removing empty parts doesn't change
  * the meaning of the set.
  */
@@ -7656,11 +7749,18 @@ error:
 
 /* Given two maps A -> B and C -> D, construct a map [A -> C] -> [B -> D]
  */
-struct isl_map *isl_map_product(struct isl_map *map1, struct isl_map *map2)
+static __isl_give isl_map *map_product_aligned(__isl_take isl_map *map1,
+       __isl_take isl_map *map2)
 {
        return map_product(map1, map2, &isl_dim_product, &isl_basic_map_product);
 }
 
+__isl_give isl_map *isl_map_product(__isl_take isl_map *map1,
+       __isl_take isl_map *map2)
+{
+       return isl_map_align_params_map_map_and(map1, map2, &map_product_aligned);
+}
+
 /* Given two maps A -> B and C -> D, construct a map (A, C) -> (B, D)
  */
 __isl_give isl_map *isl_map_flat_product(__isl_take isl_map *map1,
@@ -7689,13 +7789,20 @@ __isl_give isl_set *isl_set_flat_product(__isl_take isl_set *set1,
 
 /* Given two maps A -> B and C -> D, construct a map (A * C) -> [B -> D]
  */
-__isl_give isl_map *isl_map_range_product(__isl_take isl_map *map1,
+static __isl_give isl_map *map_range_product_aligned(__isl_take isl_map *map1,
        __isl_take isl_map *map2)
 {
        return map_product(map1, map2, &isl_dim_range_product,
                                &isl_basic_map_range_product);
 }
 
+__isl_give isl_map *isl_map_range_product(__isl_take isl_map *map1,
+       __isl_take isl_map *map2)
+{
+       return isl_map_align_params_map_map_and(map1, map2,
+                                               &map_range_product_aligned);
+}
+
 /* Given two maps A -> B and C -> D, construct a map (A * C) -> (B, D)
  */
 __isl_give isl_map *isl_map_flat_range_product(__isl_take isl_map *map1,