introduce isl_dim structure for representing shared dimension information
[platform/upstream/isl.git] / isl_mat.c
index 07c695e..e663838 100644 (file)
--- a/isl_mat.c
+++ b/isl_mat.c
@@ -1,3 +1,4 @@
+#include "isl_dim.h"
 #include "isl_seq.h"
 #include "isl_mat.h"
 #include "isl_map_private.h"
@@ -687,15 +688,18 @@ struct isl_basic_set *isl_basic_set_preimage(struct isl_ctx *ctx,
        if (!bset)
                goto error;
 
-       isl_assert(ctx, bset->nparam == 0, goto error);
+       isl_assert(ctx, bset->dim->nparam == 0, goto error);
        isl_assert(ctx, bset->n_div == 0, goto error);
-       isl_assert(ctx, 1+bset->dim == mat->n_row, goto error);
+       isl_assert(ctx, 1+bset->dim->n_out == mat->n_row, goto error);
 
        if (mat->n_col > mat->n_row)
                bset = isl_basic_set_extend(bset, 0, mat->n_col-1, 0,
                                                0, 0);
-       else {
-               bset->dim -= mat->n_row - mat->n_col;
+       else if (mat->n_col < mat->n_row) {
+               bset->dim = isl_dim_cow(bset->dim);
+               if (!bset->dim)
+                       goto error;
+               bset->dim->n_out -= mat->n_row - mat->n_col;
                bset->extra += mat->n_row - mat->n_col;
        }
 
@@ -746,8 +750,13 @@ struct isl_set *isl_set_preimage(struct isl_ctx *ctx,
                if (!set->p[i])
                        goto error;
        }
-       set->dim += mat->n_col;
-       set->dim -= mat->n_row;
+       if (mat->n_col != mat->n_row) {
+               set->dim = isl_dim_cow(set->dim);
+               if (!set->dim)
+                       goto error;
+               set->dim->n_out += mat->n_col;
+               set->dim->n_out -= mat->n_row;
+       }
        isl_mat_free(ctx, mat);
        return set;
 error:
@@ -777,7 +786,7 @@ void isl_mat_dump(struct isl_ctx *ctx, struct isl_mat *mat,
                for (j = 0; j < mat->n_col; ++j) {
                        if (j)
                            fprintf(out, ",");
-                       isl_int_print(out, mat->row[i][j]);
+                       isl_int_print(out, mat->row[i][j], 0);
                }
                if (i == mat->n_row-1)
                        fprintf(out, "]]\n");