isl_stream_read_map: accept min and max expressions in constraints
authorSven Verdoolaege <skimo@kotnet.org>
Fri, 18 Mar 2011 10:19:48 +0000 (11:19 +0100)
committerSven Verdoolaege <skimo@kotnet.org>
Fri, 18 Mar 2011 11:21:03 +0000 (12:21 +0100)
Signed-off-by: Sven Verdoolaege <skimo@kotnet.org>
include/isl/stream.h
isl_input.c
isl_stream.c
isl_test.c

index 4a9fae8..6695157 100644 (file)
@@ -27,7 +27,7 @@ enum isl_token_type { ISL_TOKEN_ERROR = -1,
                        ISL_TOKEN_TO, ISL_TOKEN_AND,
                        ISL_TOKEN_OR, ISL_TOKEN_EXISTS, ISL_TOKEN_NOT,
                        ISL_TOKEN_DEF, ISL_TOKEN_INFTY, ISL_TOKEN_NAN,
-                       ISL_TOKEN_MAX, ISL_TOKEN_RAT,
+                       ISL_TOKEN_MIN, ISL_TOKEN_MAX, ISL_TOKEN_RAT,
                        ISL_TOKEN_TRUE, ISL_TOKEN_FALSE,
                        ISL_TOKEN_STRING,
                        ISL_TOKEN_LAST };
index 607732e..319937c 100644 (file)
@@ -29,6 +29,9 @@ struct variable {
        char                    *name;
        int                      pos;
        isl_vec                 *def;
+       /* non-zero if variable represents a min (-1) or a max (1) */
+       int                      sign;
+       isl_mat                 *list;
        struct variable         *next;
 };
 
@@ -54,6 +57,7 @@ static void variable_free(struct variable *var)
 {
        while (var) {
                struct variable *next = var->next;
+               isl_mat_free(var->list);
                isl_vec_free(var->def);
                free(var->name);
                free(var);
@@ -81,6 +85,7 @@ static void vars_drop(struct vars *v, int n)
        var = v->v;
        while (--n >= 0) {
                struct variable *next = var->next;
+               isl_mat_free(var->list);
                isl_vec_free(var->def);
                free(var->name);
                free(var);
@@ -93,7 +98,7 @@ static struct variable *variable_new(struct vars *v, const char *name, int len,
                                int pos)
 {
        struct variable *var;
-       var = isl_alloc_type(v->ctx, struct variable);
+       var = isl_calloc_type(v->ctx, struct variable);
        if (!var)
                goto error;
        var->name = strdup(name);
@@ -263,6 +268,7 @@ error:
 
 static struct isl_vec *accept_affine(struct isl_stream *s, struct vars *v);
 static int read_div_definition(struct isl_stream *s, struct vars *v);
+static int read_minmax_definition(struct isl_stream *s, struct vars *v);
 
 static __isl_give isl_vec *accept_affine_factor(struct isl_stream *s,
        struct vars *v)
@@ -324,6 +330,19 @@ static __isl_give isl_vec *accept_affine_factor(struct isl_stream *s,
                if (read_div_definition(s, v) < 0)
                        goto error;
                aff = isl_vec_zero_extend(aff, 1 + v->n);
+       } else if (tok->type == ISL_TOKEN_MIN || tok->type == ISL_TOKEN_MAX) {
+               if (vars_add_anon(v) < 0)
+                       goto error;
+               aff = isl_vec_alloc(v->ctx, 1 + v->n);
+               if (!aff)
+                       goto error;
+               isl_seq_clr(aff->el, aff->size);
+               isl_int_set_si(aff->el[1 + v->n - 1], 1);
+               isl_stream_push_token(s, tok);
+               tok = NULL;
+               if (read_minmax_definition(s, v) < 0)
+                       goto error;
+               aff = isl_vec_zero_extend(aff, 1 + v->n);
        } else {
                isl_stream_error(s, tok, "expecting factor");
                goto error;
@@ -373,6 +392,7 @@ static struct isl_vec *accept_affine(struct isl_stream *s, struct vars *v)
                        continue;
                }
                if (tok->type == '(' || tok->type == '[' ||
+                   tok->type == ISL_TOKEN_MIN || tok->type == ISL_TOKEN_MAX ||
                    tok->type == ISL_TOKEN_IDENT) {
                        isl_vec *aff2;
                        isl_stream_push_token(s, tok);
@@ -605,6 +625,32 @@ error:
        return NULL;
 }
 
+static int read_minmax_definition(struct isl_stream *s, struct vars *v)
+{
+       struct isl_token *tok;
+       struct variable *var;
+
+       var = v->v;
+
+       tok = isl_stream_next_token(s);
+       if (!tok)
+               return -1;
+       var->sign = tok->type == ISL_TOKEN_MIN ? -1 : 1;
+       isl_token_free(tok);
+
+       if (isl_stream_eat(s, '('))
+               return -1;
+
+       var->list = accept_affine_list(s, v);
+       if (!var->list)
+               return -1;
+
+       if (isl_stream_eat(s, ')'))
+               return -1;
+
+       return 0;
+}
+
 static int read_div_definition(struct isl_stream *s, struct vars *v)
 {
        struct isl_token *tok;
@@ -918,6 +964,8 @@ static __isl_give isl_basic_map *add_lifted_divs(__isl_take isl_basic_map *bmap,
                                        0, 0, 2 * extra);
 
        for (i = 0, var = v->v; i < extra; ++i, var = var->next) {
+               if (!var->def)
+                       continue;
                var->def = isl_vec_zero_extend(var->def, 2 + v->n);
                if (!var->def)
                        goto error;
@@ -990,10 +1038,201 @@ error:
        return NULL;
 }
 
+/* Return first variable, starting at n, representing a min or max,
+ * or NULL if there is no such variable.
+ */
+static struct variable *first_minmax(struct vars *v, int n)
+{
+       struct variable *first = NULL;
+       struct variable *var;
+
+       for (var = v->v; var && var->pos >= n; var = var->next)
+               if (var->list)
+                       first = var;
+
+       return first;
+}
+
+/* Check whether the variable at the given position only occurs in
+ * inequalities and only with the given sign.
+ */
+static int all_coefficients_of_sign(__isl_keep isl_map *map, int pos, int sign)
+{
+       int i, j;
+
+       if (!map)
+               return -1;
+
+       for (i = 0; i < map->n; ++i) {
+               isl_basic_map *bmap = map->p[i];
+
+               for (j = 0; j < bmap->n_eq; ++j)
+                       if (!isl_int_is_zero(bmap->eq[j][1 + pos]))
+                               return 0;
+               for (j = 0; j < bmap->n_ineq; ++j) {
+                       int s = isl_int_sgn(bmap->ineq[j][1 + pos]);
+                       if (s == 0)
+                               continue;
+                       if (s != sign)
+                               return 0;
+               }
+       }
+
+       return 1;
+}
+
+/* Given a variable m which represents a min or a max of n expressions
+ * b_i, add the constraints
+ *
+ *     m <= b_i
+ *
+ * in case of a min (var->sign < 0) and m >= b_i in case of a max.
+ */
+static __isl_give isl_map *bound_minmax(__isl_take isl_map *map,
+       struct variable *var)
+{
+       int i, k;
+       isl_basic_map *bound;
+       int total;
+
+       total = isl_map_dim(map, isl_dim_all);
+       bound = isl_basic_map_alloc_dim(isl_map_get_dim(map),
+                                       0, 0, var->list->n_row);
+
+       for (i = 0; i < var->list->n_row; ++i) {
+               k = isl_basic_map_alloc_inequality(bound);
+               if (k < 0)
+                       goto error;
+               if (var->sign < 0)
+                       isl_seq_cpy(bound->ineq[k], var->list->row[i],
+                                   var->list->n_col);
+               else
+                       isl_seq_neg(bound->ineq[k], var->list->row[i],
+                                   var->list->n_col);
+               isl_int_set_si(bound->ineq[k][1 + var->pos], var->sign);
+               isl_seq_clr(bound->ineq[k] + var->list->n_col,
+                           1 + total - var->list->n_col);
+       }
+
+       map = isl_map_intersect(map, isl_map_from_basic_map(bound));
+
+       return map;
+error:
+       isl_basic_map_free(bound);
+       isl_map_free(map);
+       return NULL;
+}
+
+/* Given a variable m which represents a min (or max) of n expressions
+ * b_i, add constraints that assigns the minimal upper bound to m, i.e.,
+ * divide the space into cells where one
+ * of the upper bounds is smaller than all the others and assign
+ * this upper bound to m.
+ *
+ * In particular, if there are n bounds b_i, then the input map
+ * is split into n pieces, each with the extra constraints
+ *
+ *     m = b_i
+ *     b_i <= b_j      for j > i
+ *     b_i <  b_j      for j < i
+ *
+ * in case of a min (var->sign < 0) and similarly in case of a max.
+ *
+ * Note: this function is very similar to set_minimum in isl_tab_pip.c
+ * Perhaps we should try to merge the two.
+ */
+static __isl_give isl_map *set_minmax(__isl_take isl_map *map,
+       struct variable *var)
+{
+       int i, j, k;
+       isl_basic_map *bmap = NULL;
+       isl_ctx *ctx;
+       isl_map *split = NULL;
+       int total;
+
+       ctx = isl_map_get_ctx(map);
+       total = isl_map_dim(map, isl_dim_all);
+       split = isl_map_alloc_dim(isl_map_get_dim(map),
+                               var->list->n_row, ISL_SET_DISJOINT);
+
+       for (i = 0; i < var->list->n_row; ++i) {
+               bmap = isl_basic_map_alloc_dim(isl_map_get_dim(map), 0,
+                                              1, var->list->n_row - 1);
+               k = isl_basic_map_alloc_equality(bmap);
+               if (k < 0)
+                       goto error;
+               isl_seq_cpy(bmap->eq[k], var->list->row[i], var->list->n_col);
+               isl_int_set_si(bmap->eq[k][1 + var->pos], -1);
+               for (j = 0; j < var->list->n_row; ++j) {
+                       if (j == i)
+                               continue;
+                       k = isl_basic_map_alloc_inequality(bmap);
+                       if (k < 0)
+                               goto error;
+                       if (var->sign < 0)
+                               isl_seq_combine(bmap->ineq[k],
+                                               ctx->one, var->list->row[j],
+                                               ctx->negone, var->list->row[i],
+                                               var->list->n_col);
+                       else
+                               isl_seq_combine(bmap->ineq[k],
+                                               ctx->negone, var->list->row[j],
+                                               ctx->one, var->list->row[i],
+                                               var->list->n_col);
+                       isl_seq_clr(bmap->ineq[k] + var->list->n_col,
+                                   1 + total - var->list->n_col);
+                       if (j < i)
+                               isl_int_sub_ui(bmap->ineq[k][0],
+                                              bmap->ineq[k][0], 1);
+               }
+               bmap = isl_basic_map_finalize(bmap);
+               split = isl_map_add_basic_map(split, bmap);
+       }
+
+       map = isl_map_intersect(map, split);
+
+       return map;
+error:
+       isl_basic_map_free(bmap);
+       isl_map_free(split);
+       isl_map_free(map);
+       return NULL;
+}
+
+/* Plug in the definitions of all min and max expressions.
+ * If a min expression only appears in inequalities and only
+ * with a positive coefficient, then we can simply bound
+ * the variable representing the min by its defining terms
+ * and similarly for a max expression.
+ * Otherwise, we have to assign the different terms to the
+ * variable under the condition that the assigned term is smaller
+ * than the other terms.
+ */
+static __isl_give isl_map *add_minmax(__isl_take isl_map *map,
+       struct vars *v, int n)
+{
+       int i;
+       struct variable *var;
+
+       while (n < v->n) {
+               var = first_minmax(v, n);
+               if (!var)
+                       break;
+               if (all_coefficients_of_sign(map, var->pos, -var->sign))
+                       map = bound_minmax(map, var);
+               else
+                       map = set_minmax(map, var);
+               n = var->pos + 1;
+       }
+
+       return map;
+}
+
 static isl_map *read_constraint(struct isl_stream *s,
        struct vars *v, __isl_take isl_basic_map *bmap)
 {
        int n = v->n;
+       isl_map *map;
 
        if (!bmap)
                return NULL;
@@ -1004,11 +1243,15 @@ static isl_map *read_constraint(struct isl_stream *s,
        bmap = isl_basic_map_simplify(bmap);
        bmap = isl_basic_map_finalize(bmap);
 
-       bmap = isl_basic_set_unwrap(isl_basic_map_domain(bmap));
+       map = isl_map_from_basic_map(bmap);
+
+       map = add_minmax(map, v, n);
+
+       map = isl_set_unwrap(isl_map_domain(map));
 
        vars_drop(v, v->n - n);
 
-       return isl_map_from_basic_map(bmap);
+       return map;
 }
 
 static struct isl_map *read_disjuncts(struct isl_stream *s,
index 6b5e7f7..2cb6e0b 100644 (file)
@@ -250,6 +250,8 @@ static enum isl_token_type check_keywords(struct isl_stream *s)
                return ISL_TOKEN_INFTY;
        if (!strcasecmp(s->buffer, "NaN"))
                return ISL_TOKEN_NAN;
+       if (!strcasecmp(s->buffer, "min"))
+               return ISL_TOKEN_MIN;
        if (!strcasecmp(s->buffer, "max"))
                return ISL_TOKEN_MAX;
        if (!strcasecmp(s->buffer, "rat"))
index 46177d6..30879d9 100644 (file)
@@ -73,6 +73,11 @@ void test_parse(struct isl_ctx *ctx)
        str2 = "{ [x, y] : 2y >= 6 - x }";
        test_parse_map_equal(ctx, str, str2);
 
+       test_parse_map_equal(ctx, "{ [x,y] : x <= min(y, 2*y+3) }",
+                                 "{ [x,y] : x <= y, 2*y + 3 }");
+       str = "{ [x, y] : (y <= x and y >= -3) or (2y <= -3 + x and y <= -4) }";
+       test_parse_map_equal(ctx, "{ [x,y] : x >= min(y, 2*y+3) }", str);
+
        str = "{[new,old] -> [new+1-2*[(new+1)/2],old+1-2*[(old+1)/2]]}";
        map = isl_map_read_from_str(ctx, str, -1);
        str = "{ [new, old] -> [o0, o1] : "