isl_basic_set_opt: avoid invalid access on error path
[platform/upstream/isl.git] / isl_test.c
index fa018e3..fb6a5f4 100644 (file)
@@ -1185,6 +1185,19 @@ int test_coalesce(struct isl_ctx *ctx)
        if (test_coalesce_set(ctx, "{ [i,j] : exists a,b : i = 2a and j = 3b; "
                                     "[i,j] : exists a : j = 3a }", 1) < 0)
                return -1;
+       if (test_coalesce_set(ctx,
+               "{ [a, b, c] : (c <= 7 - b and b <= 1 and b >= 0 and "
+                       "c >= 3 + b and b <= 3 + 8a and b >= -26 + 8a and "
+                       "a >= 3) or "
+                   "(b <= 1 and c <= 7 and b >= 0 and c >= 4 + b and "
+                       "b <= 3 + 8a and b >= -26 + 8a and a >= 3) }", 1) < 0)
+               return -1;
+       if (test_coalesce_set(ctx,
+               "{ [a, 0, c] : c >= 1 and c <= 29 and c >= -1 + 8a and "
+                               "c <= 6 + 8a and a >= 3; "
+                   "[a, -1, c] : c >= 1 and c <= 30 and c >= 8a and "
+                               "c <= 7 + 8a and a >= 3 and a <= 4 }", 1) < 0)
+               return -1;
        return 0;
 }
 
@@ -1521,6 +1534,20 @@ void test_lexmin(struct isl_ctx *ctx)
        assert(isl_map_is_equal(map, map2));
        isl_map_free(map);
        isl_map_free(map2);
+
+       /* Check that empty pieces are properly combined. */
+       str = "[K, N] -> { [x, y] -> [a, b] : K+2<=N<=K+4 and x>=4 and "
+               "2N-6<=x<K+N and N-1<=a<=K+N-1 and N+b-6<=a<=2N-4 and "
+               "b<=2N-3K+a and 3b<=4N-K+1 and b>=N and a>=x+1 }";
+       map = isl_map_read_from_str(ctx, str);
+       map = isl_map_lexmin(map);
+       str = "[K, N] -> { [x, y] -> [1 + x, N] : x >= -6 + 2N and "
+               "x <= -5 + 2N and x >= -1 + 3K - N and x <= -2 + K + N and "
+               "x >= 4 }";
+       map2 = isl_map_read_from_str(ctx, str);
+       assert(isl_map_is_equal(map, map2));
+       isl_map_free(map);
+       isl_map_free(map2);
 }
 
 struct must_may {
@@ -2278,23 +2305,48 @@ int test_one_schedule(isl_ctx *ctx, const char *d, const char *w,
        return 0;
 }
 
-int test_special_schedule(isl_ctx *ctx, const char *domain,
-       const char *validity, const char *proximity, const char *expected_sched)
+static __isl_give isl_union_map *compute_schedule(isl_ctx *ctx,
+       const char *domain, const char *validity, const char *proximity)
 {
        isl_union_set *dom;
        isl_union_map *dep;
        isl_union_map *prox;
-       isl_union_map *sched1, *sched2;
        isl_schedule *schedule;
-       int equal;
+       isl_union_map *sched;
 
        dom = isl_union_set_read_from_str(ctx, domain);
        dep = isl_union_map_read_from_str(ctx, validity);
        prox = isl_union_map_read_from_str(ctx, proximity);
        schedule = isl_union_set_compute_schedule(dom, dep, prox);
-       sched1 = isl_schedule_get_map(schedule);
+       sched = isl_schedule_get_map(schedule);
        isl_schedule_free(schedule);
 
+       return sched;
+}
+
+/* Check that a schedule can be constructed on the given domain
+ * with the given validity and proximity constraints.
+ */
+static int test_has_schedule(isl_ctx *ctx, const char *domain,
+       const char *validity, const char *proximity)
+{
+       isl_union_map *sched;
+
+       sched = compute_schedule(ctx, domain, validity, proximity);
+       if (!sched)
+               return -1;
+
+       isl_union_map_free(sched);
+       return 0;
+}
+
+int test_special_schedule(isl_ctx *ctx, const char *domain,
+       const char *validity, const char *proximity, const char *expected_sched)
+{
+       isl_union_map *sched1, *sched2;
+       int equal;
+
+       sched1 = compute_schedule(ctx, domain, validity, proximity);
        sched2 = isl_union_map_read_from_str(ctx, expected_sched);
 
        equal = isl_union_map_is_equal(sched1, sched2);
@@ -2310,6 +2362,43 @@ int test_special_schedule(isl_ctx *ctx, const char *domain,
        return 0;
 }
 
+/* Check that the schedule map is properly padded, even after being
+ * reconstructed from the band forest.
+ */
+static int test_padded_schedule(isl_ctx *ctx)
+{
+       const char *str;
+       isl_union_set *D;
+       isl_union_map *validity, *proximity;
+       isl_schedule *sched;
+       isl_union_map *map1, *map2;
+       isl_band_list *list;
+       int equal;
+
+       str = "[N] -> { S0[i] : 0 <= i <= N; S1[i, j] : 0 <= i, j <= N }";
+       D = isl_union_set_read_from_str(ctx, str);
+       validity = isl_union_map_empty(isl_union_set_get_space(D));
+       proximity = isl_union_map_copy(validity);
+       sched = isl_union_set_compute_schedule(D, validity, proximity);
+       map1 = isl_schedule_get_map(sched);
+       list = isl_schedule_get_band_forest(sched);
+       isl_band_list_free(list);
+       map2 = isl_schedule_get_map(sched);
+       isl_schedule_free(sched);
+       equal = isl_union_map_is_equal(map1, map2);
+       isl_union_map_free(map1);
+       isl_union_map_free(map2);
+
+       if (equal < 0)
+               return -1;
+       if (!equal)
+               isl_die(ctx, isl_error_unknown,
+                       "reconstructed schedule map not the same as original",
+                       return -1);
+
+       return 0;
+}
+
 int test_schedule(isl_ctx *ctx)
 {
        const char *D, *W, *R, *V, *P, *S;
@@ -2561,7 +2650,33 @@ int test_schedule(isl_ctx *ctx)
        if (test_special_schedule(ctx, D, V, P, S) < 0)
                return -1;
        ctx->opt->schedule_algorithm = ISL_SCHEDULE_ALGORITHM_ISL;
-       return test_special_schedule(ctx, D, V, P, S);
+       if (test_special_schedule(ctx, D, V, P, S) < 0)
+               return -1;
+       
+       D = "{ A[a]; B[] }";
+       V = "{}";
+       P = "{ A[a] -> B[] }";
+       if (test_has_schedule(ctx, D, V, P) < 0)
+               return -1;
+
+       if (test_padded_schedule(ctx) < 0)
+               return -1;
+
+       /* Check that check for progress is not confused by rational
+        * solution.
+        */
+       D = "[N] -> { S0[i, j] : i >= 0 and i <= N and j >= 0 and j <= N }";
+       V = "[N] -> { S0[i0, -1 + N] -> S0[2 + i0, 0] : i0 >= 0 and "
+                                                       "i0 <= -2 + N; "
+                       "S0[i0, i1] -> S0[i0, 1 + i1] : i0 >= 0 and "
+                               "i0 <= N and i1 >= 0 and i1 <= -1 + N }";
+       P = "{}";
+       ctx->opt->schedule_algorithm = ISL_SCHEDULE_ALGORITHM_FEAUTRIER;
+       if (test_has_schedule(ctx, D, V, P) < 0)
+               return -1;
+       ctx->opt->schedule_algorithm = ISL_SCHEDULE_ALGORITHM_ISL;
+
+       return 0;
 }
 
 int test_plain_injective(isl_ctx *ctx, const char *str, int injective)
@@ -2768,6 +2883,24 @@ int test_dim_max(isl_ctx *ctx)
        if (!equal)
                isl_die(ctx, isl_error_unknown, "unexpected result", return -1);
 
+       /* Check that solutions are properly merged. */
+       str = "[n] -> { [a, b, c] : c >= -4a - 2b and "
+                               "c <= -1 + n - 4a - 2b and c >= -2b and "
+                               "4a >= -4 + n and c >= 0 }";
+       set = isl_set_read_from_str(ctx, str);
+       pwaff = isl_set_dim_min(set, 2);
+       set1 = isl_set_from_pw_aff(pwaff);
+       str = "[n] -> { [(0)] : n >= 1 }";
+       set2 = isl_set_read_from_str(ctx, str);
+       equal = isl_set_is_equal(set1, set2);
+       isl_set_free(set1);
+       isl_set_free(set2);
+
+       if (equal < 0)
+               return -1;
+       if (!equal)
+               isl_die(ctx, isl_error_unknown, "unexpected result", return -1);
+
        return 0;
 }
 
@@ -2935,7 +3068,10 @@ int test_output(isl_ctx *ctx)
        s = isl_printer_get_str(p);
        isl_printer_free(p);
        isl_pw_aff_free(pa);
-       equal = !strcmp(s, "(2 - x + 4*floord(x, 4) >= 0) ? (1) : 2");
+       if (!s)
+               equal = -1;
+       else
+               equal = !strcmp(s, "(2 - x + 4*floord(x, 4) >= 0) ? (1) : 2");
        free(s);
        if (equal < 0)
                return -1;
@@ -2983,6 +3119,8 @@ int test_sample(isl_ctx *ctx)
        subset = isl_basic_set_is_subset(bset2, bset1);
        isl_basic_set_free(bset1);
        isl_basic_set_free(bset2);
+       if (empty < 0 || subset < 0)
+               return -1;
        if (empty)
                isl_die(ctx, isl_error_unknown, "point not found", return -1);
        if (!subset)
@@ -3264,6 +3402,31 @@ static int test_conversion(isl_ctx *ctx)
        return 0;
 }
 
+/* Check that isl_basic_map_curry does not modify input.
+ */
+static int test_curry(isl_ctx *ctx)
+{
+       const char *str;
+       isl_basic_map *bmap1, *bmap2;
+       int equal;
+
+       str = "{ [A[] -> B[]] -> C[] }";
+       bmap1 = isl_basic_map_read_from_str(ctx, str);
+       bmap2 = isl_basic_map_curry(isl_basic_map_copy(bmap1));
+       equal = isl_basic_map_is_equal(bmap1, bmap2);
+       isl_basic_map_free(bmap1);
+       isl_basic_map_free(bmap2);
+
+       if (equal < 0)
+               return -1;
+       if (equal)
+               isl_die(ctx, isl_error_unknown,
+                       "curried map should not be equal to original",
+                       return -1);
+
+       return 0;
+}
+
 struct {
        const char *set;
        const char *ma;
@@ -3379,10 +3542,12 @@ static int test_ast(isl_ctx *ctx)
        expr = isl_ast_expr_add(expr1, expr2);
        expr = isl_ast_expr_neg(expr);
        str = isl_ast_expr_to_str(expr);
-       ok = !strcmp(str, "-(A + B)");
+       ok = str ? !strcmp(str, "-(A + B)") : -1;
        free(str);
        isl_ast_expr_free(expr);
 
+       if (ok < 0)
+               return -1;
        if (!ok)
                isl_die(ctx, isl_error_unknown,
                        "isl_ast_expr printed incorrectly", return -1);
@@ -3393,10 +3558,12 @@ static int test_ast(isl_ctx *ctx)
        expr3 = isl_ast_expr_from_id(isl_id_alloc(ctx, "C", NULL));
        expr = isl_ast_expr_sub(expr3, expr);
        str = isl_ast_expr_to_str(expr);
-       ok = !strcmp(str, "C - (A + B)");
+       ok = str ? !strcmp(str, "C - (A + B)") : -1;
        free(str);
        isl_ast_expr_free(expr);
 
+       if (ok < 0)
+               return -1;
        if (!ok)
                isl_die(ctx, isl_error_unknown,
                        "isl_ast_expr printed incorrectly", return -1);
@@ -3404,6 +3571,160 @@ static int test_ast(isl_ctx *ctx)
        return 0;
 }
 
+/* Internal data structure for before_for and after_for callbacks.
+ *
+ * depth is the current depth
+ * before is the number of times before_for has been called
+ * after is the number of times after_for has been called
+ */
+struct isl_test_codegen_data {
+       int depth;
+       int before;
+       int after;
+};
+
+/* This function is called before each for loop in the AST generated
+ * from test_ast_gen1.
+ *
+ * Increment the number of calls and the depth.
+ * Check that the space returned by isl_ast_build_get_schedule_space
+ * matches the target space of the schedule returned by
+ * isl_ast_build_get_schedule.
+ * Return an isl_id that is checked by the corresponding call
+ * to after_for.
+ */
+static __isl_give isl_id *before_for(__isl_keep isl_ast_build *build,
+       void *user)
+{
+       struct isl_test_codegen_data *data = user;
+       isl_ctx *ctx;
+       isl_space *space;
+       isl_union_map *schedule;
+       isl_union_set *uset;
+       isl_set *set;
+       int empty;
+       char name[] = "d0";
+
+       ctx = isl_ast_build_get_ctx(build);
+
+       if (data->before >= 3)
+               isl_die(ctx, isl_error_unknown,
+                       "unexpected number of for nodes", return NULL);
+       if (data->depth >= 2)
+               isl_die(ctx, isl_error_unknown,
+                       "unexpected depth", return NULL);
+
+       snprintf(name, sizeof(name), "d%d", data->depth);
+       data->before++;
+       data->depth++;
+
+       schedule = isl_ast_build_get_schedule(build);
+       uset = isl_union_map_range(schedule);
+       if (!uset)
+               return NULL;
+       if (isl_union_set_n_set(uset) != 1) {
+               isl_union_set_free(uset);
+               isl_die(ctx, isl_error_unknown,
+                       "expecting single range space", return NULL);
+       }
+
+       space = isl_ast_build_get_schedule_space(build);
+       set = isl_union_set_extract_set(uset, space);
+       isl_union_set_free(uset);
+       empty = isl_set_is_empty(set);
+       isl_set_free(set);
+
+       if (empty < 0)
+               return NULL;
+       if (empty)
+               isl_die(ctx, isl_error_unknown,
+                       "spaces don't match", return NULL);
+
+       return isl_id_alloc(ctx, name, NULL);
+}
+
+/* This function is called after each for loop in the AST generated
+ * from test_ast_gen1.
+ *
+ * Increment the number of calls and decrement the depth.
+ * Check that the annotation attached to the node matches
+ * the isl_id returned by the corresponding call to before_for.
+ */
+static __isl_give isl_ast_node *after_for(__isl_take isl_ast_node *node,
+       __isl_keep isl_ast_build *build, void *user)
+{
+       struct isl_test_codegen_data *data = user;
+       isl_id *id;
+       const char *name;
+       int valid;
+
+       data->after++;
+       data->depth--;
+
+       if (data->after > data->before)
+               isl_die(isl_ast_node_get_ctx(node), isl_error_unknown,
+                       "mismatch in number of for nodes",
+                       return isl_ast_node_free(node));
+
+       id = isl_ast_node_get_annotation(node);
+       if (!id)
+               isl_die(isl_ast_node_get_ctx(node), isl_error_unknown,
+                       "missing annotation", return isl_ast_node_free(node));
+
+       name = isl_id_get_name(id);
+       valid = name && atoi(name + 1) == data->depth;
+       isl_id_free(id);
+
+       if (!valid)
+               isl_die(isl_ast_node_get_ctx(node), isl_error_unknown,
+                       "wrong annotation", return isl_ast_node_free(node));
+
+       return node;
+}
+
+/* Check that the before_each_for and after_each_for callbacks
+ * are called for each for loop in the generated code,
+ * that they are called in the right order and that the isl_id
+ * returned from the before_each_for callback is attached to
+ * the isl_ast_node passed to the corresponding after_each_for call.
+ */
+static int test_ast_gen1(isl_ctx *ctx)
+{
+       const char *str;
+       isl_set *set;
+       isl_union_map *schedule;
+       isl_ast_build *build;
+       isl_ast_node *tree;
+       struct isl_test_codegen_data data;
+
+       str = "[N] -> { : N >= 10 }";
+       set = isl_set_read_from_str(ctx, str);
+       str = "[N] -> { A[i,j] -> S[8,i,3,j] : 0 <= i,j <= N; "
+                   "B[i,j] -> S[8,j,9,i] : 0 <= i,j <= N }";
+       schedule = isl_union_map_read_from_str(ctx, str);
+
+       data.before = 0;
+       data.after = 0;
+       data.depth = 0;
+       build = isl_ast_build_from_context(set);
+       build = isl_ast_build_set_before_each_for(build,
+                       &before_for, &data);
+       build = isl_ast_build_set_after_each_for(build,
+                       &after_for, &data);
+       tree = isl_ast_build_ast_from_schedule(build, schedule);
+       isl_ast_build_free(build);
+       if (!tree)
+               return -1;
+
+       isl_ast_node_free(tree);
+
+       if (data.before != 3 || data.after != 3)
+               isl_die(ctx, isl_error_unknown,
+                       "unexpected number of for nodes", return -1);
+
+       return 0;
+}
+
 /* Check that the AST generator handles domains that are integrally disjoint
  * but not ratinoally disjoint.
  */
@@ -3542,6 +3863,8 @@ static int test_ast_gen4(isl_ctx *ctx)
 
 static int test_ast_gen(isl_ctx *ctx)
 {
+       if (test_ast_gen1(ctx) < 0)
+               return -1;
        if (test_ast_gen2(ctx) < 0)
                return -1;
        if (test_ast_gen3(ctx) < 0)
@@ -3551,10 +3874,100 @@ static int test_ast_gen(isl_ctx *ctx)
        return 0;
 }
 
+/* Check if dropping output dimensions from an isl_pw_multi_aff
+ * works properly.
+ */
+static int test_pw_multi_aff(isl_ctx *ctx)
+{
+       const char *str;
+       isl_pw_multi_aff *pma1, *pma2;
+       int equal;
+
+       str = "{ [i,j] -> [i+j, 4i-j] }";
+       pma1 = isl_pw_multi_aff_read_from_str(ctx, str);
+       str = "{ [i,j] -> [4i-j] }";
+       pma2 = isl_pw_multi_aff_read_from_str(ctx, str);
+
+       pma1 = isl_pw_multi_aff_drop_dims(pma1, isl_dim_out, 0, 1);
+
+       equal = isl_pw_multi_aff_plain_is_equal(pma1, pma2);
+
+       isl_pw_multi_aff_free(pma1);
+       isl_pw_multi_aff_free(pma2);
+       if (equal < 0)
+               return -1;
+       if (!equal)
+               isl_die(ctx, isl_error_unknown,
+                       "expressions not equal", return -1);
+
+       return 0;
+}
+
+/* This is a regression test for a bug where isl_basic_map_simplify
+ * would end up in an infinite loop.  In particular, we construct
+ * an empty basic set that is not obviously empty.
+ * isl_basic_set_is_empty marks the basic set as empty.
+ * After projecting out i3, the variable can be dropped completely,
+ * but isl_basic_map_simplify refrains from doing so if the basic set
+ * is empty and would end up in an infinite loop if it didn't test
+ * explicitly for empty basic maps in the outer loop.
+ */
+static int test_simplify(isl_ctx *ctx)
+{
+       const char *str;
+       isl_basic_set *bset;
+       int empty;
+
+       str = "{ [i0, i1, i2, i3] : i0 >= -2 and 6i2 <= 4 + i0 + 5i1 and "
+               "i2 <= 22 and 75i2 <= 111 + 13i0 + 60i1 and "
+               "25i2 >= 38 + 6i0 + 20i1 and i0 <= -1 and i2 >= 20 and "
+               "i3 >= i2 }";
+       bset = isl_basic_set_read_from_str(ctx, str);
+       empty = isl_basic_set_is_empty(bset);
+       bset = isl_basic_set_project_out(bset, isl_dim_set, 3, 1);
+       isl_basic_set_free(bset);
+       if (!bset)
+               return -1;
+       if (!empty)
+               isl_die(ctx, isl_error_unknown,
+                       "basic set should be empty", return -1);
+
+       return 0;
+}
+
+/* This is a regression test for a bug where isl_tab_basic_map_partial_lexopt
+ * with gbr context would fail to disable the use of the shifted tableau
+ * when transferring equalities for the input to the context, resulting
+ * in invalid sample values.
+ */
+static int test_partial_lexmin(isl_ctx *ctx)
+{
+       const char *str;
+       isl_basic_set *bset;
+       isl_basic_map *bmap;
+       isl_map *map;
+
+       str = "{ [1, b, c, 1 - c] -> [e] : 2e <= -c and 2e >= -3 + c }";
+       bmap = isl_basic_map_read_from_str(ctx, str);
+       str = "{ [a, b, c, d] : c <= 1 and 2d >= 6 - 4b - c }";
+       bset = isl_basic_set_read_from_str(ctx, str);
+       map = isl_basic_map_partial_lexmin(bmap, bset, NULL);
+       isl_map_free(map);
+
+       if (!map)
+               return -1;
+
+       return 0;
+}
+
 struct {
        const char *name;
        int (*fn)(isl_ctx *ctx);
 } tests [] = {
+       { "partial lexmin", &test_partial_lexmin },
+       { "simplify", &test_simplify },
+       { "curry", &test_curry },
+       { "piecewise multi affine expressions", &test_pw_multi_aff },
        { "conversion", &test_conversion },
        { "list", &test_list },
        { "align parameters", &test_align_parameters },
@@ -3563,7 +3976,7 @@ struct {
        { "AST", &test_ast },
        { "AST generation", &test_ast_gen },
        { "eliminate", &test_eliminate },
-       { "reisdue class", &test_residue_class },
+       { "residue class", &test_residue_class },
        { "div", &test_div },
        { "slice", &test_slice },
        { "fixed power", &test_fixed_power },