isl_union_set_compute_schedule: separate components by default
[platform/upstream/isl.git] / isl_schedule.c
index 0807fce..8a804e8 100644 (file)
@@ -124,6 +124,12 @@ struct isl_sched_edge {
        int end;
 };
 
+enum isl_edge_type {
+       isl_edge_validity = 0,
+       isl_edge_proximity,
+       isl_edge_last = isl_edge_proximity
+};
+
 /* Internal information about the dependence graph used during
  * the construction of the schedule.
  *
@@ -148,10 +154,13 @@ struct isl_sched_edge {
  *
  * n_edge is the number of edges
  * edge is the list of edges
+ * max_edge contains the maximal number of edges of each type;
+ *     in particular, it contains the number of edges in the inital graph.
  * edge_table contains pointers into the edge array, hashed on the source
- *     and sink spaces; the table only contains edges that represent
- *     validity constraints (and that may or may not also represent proximity
- *     constraints)
+ *     and sink spaces; there is one such table for each type;
+ *     a given edge may be referenced from more than one table
+ *     if the corresponding relation appears in more than of the
+ *     sets of dependences
  *
  * node_table contains pointers into the node array, hashed on the space
  *
@@ -187,7 +196,8 @@ struct isl_sched_graph {
 
        struct isl_sched_edge *edge;
        int n_edge;
-       struct isl_hash_table *edge_table;
+       int max_edge[isl_edge_last + 1];
+       struct isl_hash_table *edge_table[isl_edge_last + 1];
 
        struct isl_hash_table *node_table;
        struct isl_region *region;
@@ -254,60 +264,95 @@ static int edge_has_src_and_dst(const void *entry, const void *val)
        return edge->src == temp->src && edge->dst == temp->dst;
 }
 
-/* Initialize edge_table based on the list of edges.
- * Only edges with validity set are added to the table.
+/* Add the given edge to graph->edge_table[type].
  */
-static int graph_init_edge_table(isl_ctx *ctx, struct isl_sched_graph *graph)
+static int graph_edge_table_add(isl_ctx *ctx, struct isl_sched_graph *graph,
+       enum isl_edge_type type, struct isl_sched_edge *edge)
 {
-       int i;
+       struct isl_hash_table_entry *entry;
+       uint32_t hash;
 
-       graph->edge_table = isl_hash_table_alloc(ctx, graph->n_edge);
-       if (!graph->edge_table)
+       hash = isl_hash_init();
+       hash = isl_hash_builtin(hash, edge->src);
+       hash = isl_hash_builtin(hash, edge->dst);
+       entry = isl_hash_table_find(ctx, graph->edge_table[type], hash,
+                                   &edge_has_src_and_dst, edge, 1);
+       if (!entry)
                return -1;
+       entry->data = edge;
 
-       for (i = 0; i < graph->n_edge; ++i) {
-               struct isl_hash_table_entry *entry;
-               uint32_t hash;
+       return 0;
+}
 
-               if (!graph->edge[i].validity)
-                       continue;
+/* Allocate the edge_tables based on the maximal number of edges of
+ * each type.
+ */
+static int graph_init_edge_tables(isl_ctx *ctx, struct isl_sched_graph *graph)
+{
+       int i;
 
-               hash = isl_hash_init();
-               hash = isl_hash_builtin(hash, graph->edge[i].src);
-               hash = isl_hash_builtin(hash, graph->edge[i].dst);
-               entry = isl_hash_table_find(ctx, graph->edge_table, hash,
-                                           &edge_has_src_and_dst,
-                                           &graph->edge[i], 1);
-               if (!entry)
+       for (i = 0; i <= isl_edge_last; ++i) {
+               graph->edge_table[i] = isl_hash_table_alloc(ctx,
+                                                           graph->max_edge[i]);
+               if (!graph->edge_table[i])
                        return -1;
-               entry->data = &graph->edge[i];
        }
 
        return 0;
 }
 
-/* Check whether the dependence graph has a (validity) edge
- * between the given two nodes.
+/* If graph->edge_table[type] contains an edge from the given source
+ * to the given destination, then return the hash table entry of this edge.
+ * Otherwise, return NULL.
  */
-static int graph_has_edge(struct isl_sched_graph *graph,
+static struct isl_hash_table_entry *graph_find_edge_entry(
+       struct isl_sched_graph *graph,
+       enum isl_edge_type type,
        struct isl_sched_node *src, struct isl_sched_node *dst)
 {
        isl_ctx *ctx = isl_space_get_ctx(src->dim);
-       struct isl_hash_table_entry *entry;
        uint32_t hash;
        struct isl_sched_edge temp = { .src = src, .dst = dst };
-       struct isl_sched_edge *edge;
-       int empty;
 
        hash = isl_hash_init();
        hash = isl_hash_builtin(hash, temp.src);
        hash = isl_hash_builtin(hash, temp.dst);
-       entry = isl_hash_table_find(ctx, graph->edge_table, hash,
+       return isl_hash_table_find(ctx, graph->edge_table[type], hash,
                                    &edge_has_src_and_dst, &temp, 0);
+}
+
+
+/* If graph->edge_table[type] contains an edge from the given source
+ * to the given destination, then return this edge.
+ * Otherwise, return NULL.
+ */
+static struct isl_sched_edge *graph_find_edge(struct isl_sched_graph *graph,
+       enum isl_edge_type type,
+       struct isl_sched_node *src, struct isl_sched_node *dst)
+{
+       struct isl_hash_table_entry *entry;
+
+       entry = graph_find_edge_entry(graph, type, src, dst);
        if (!entry)
+               return NULL;
+
+       return entry->data;
+}
+
+/* Check whether the dependence graph has an edge of the give type
+ * between the given two nodes.
+ */
+static int graph_has_edge(struct isl_sched_graph *graph,
+       enum isl_edge_type type,
+       struct isl_sched_node *src, struct isl_sched_node *dst)
+{
+       struct isl_sched_edge *edge;
+       int empty;
+
+       edge = graph_find_edge(graph, type, src, dst);
+       if (!edge)
                return 0;
 
-       edge = entry->data;
        empty = isl_map_plain_is_empty(edge->map);
        if (empty < 0)
                return -1;
@@ -315,6 +360,72 @@ static int graph_has_edge(struct isl_sched_graph *graph,
        return !empty;
 }
 
+/* If there is an edge from the given source to the given destination
+ * of any type then return this edge.
+ * Otherwise, return NULL.
+ */
+static struct isl_sched_edge *graph_find_any_edge(struct isl_sched_graph *graph,
+       struct isl_sched_node *src, struct isl_sched_node *dst)
+{
+       int i;
+       struct isl_sched_edge *edge;
+
+       for (i = 0; i <= isl_edge_last; ++i) {
+               edge = graph_find_edge(graph, i, src, dst);
+               if (edge)
+                       return edge;
+       }
+
+       return NULL;
+}
+
+/* Remove the given edge from all the edge_tables that refer to it.
+ */
+static void graph_remove_edge(struct isl_sched_graph *graph,
+       struct isl_sched_edge *edge)
+{
+       isl_ctx *ctx = isl_map_get_ctx(edge->map);
+       int i;
+
+       for (i = 0; i <= isl_edge_last; ++i) {
+               struct isl_hash_table_entry *entry;
+
+               entry = graph_find_edge_entry(graph, i, edge->src, edge->dst);
+               if (!entry)
+                       continue;
+               if (entry->data != edge)
+                       continue;
+               isl_hash_table_remove(ctx, graph->edge_table[i], entry);
+       }
+}
+
+/* Check whether the dependence graph has any edge
+ * between the given two nodes.
+ */
+static int graph_has_any_edge(struct isl_sched_graph *graph,
+       struct isl_sched_node *src, struct isl_sched_node *dst)
+{
+       int i;
+       int r;
+
+       for (i = 0; i <= isl_edge_last; ++i) {
+               r = graph_has_edge(graph, i, src, dst);
+               if (r < 0 || r)
+                       return r;
+       }
+
+       return r;
+}
+
+/* Check whether the dependence graph has a validity edge
+ * between the given two nodes.
+ */
+static int graph_has_validity_edge(struct isl_sched_graph *graph,
+       struct isl_sched_node *src, struct isl_sched_node *dst)
+{
+       return graph_has_edge(graph, isl_edge_validity, src, dst);
+}
+
 static int graph_alloc(isl_ctx *ctx, struct isl_sched_graph *graph,
        int n_node, int n_edge)
 {
@@ -367,7 +478,8 @@ static void graph_free(isl_ctx *ctx, struct isl_sched_graph *graph)
        free(graph->edge);
        free(graph->region);
        free(graph->stack);
-       isl_hash_table_free(ctx, graph->edge_table);
+       for (i = 0; i <= isl_edge_last; ++i)
+               isl_hash_table_free(ctx, graph->edge_table[i]);
        isl_hash_table_free(ctx, graph->node_table);
        isl_basic_set_free(graph->lp);
 }
@@ -410,21 +522,27 @@ static int extract_node(__isl_take isl_set *set, void *user)
        return 0;
 }
 
-/* Add a new edge to the graph based on the given map.
- * Edges are first extracted from the validity dependences,
- * from which the edge_table is constructed.
- * Afterwards, the proximity dependences are added.  If a proximity
- * dependence relation happens to be identical to one of the
- * validity dependence relations added before, then we don't create
- * a new edge, but instead mark the original edge as also representing
- * a proximity dependence.
+struct isl_extract_edge_data {
+       enum isl_edge_type type;
+       struct isl_sched_graph *graph;
+};
+
+/* Add a new edge to the graph based on the given map
+ * and add it to data->graph->edge_table[data->type].
+ * If a dependence relation of a given type happens to be identical
+ * to one of the dependence relations of a type that was added before,
+ * then we don't create a new edge, but instead mark the original edge
+ * as also representing a dependence of the current type.
  */
 static int extract_edge(__isl_take isl_map *map, void *user)
 {
        isl_ctx *ctx = isl_map_get_ctx(map);
-       struct isl_sched_graph *graph = user;
+       struct isl_extract_edge_data *data = user;
+       struct isl_sched_graph *graph = data->graph;
        struct isl_sched_node *src, *dst;
        isl_space *dim;
+       struct isl_sched_edge *edge;
+       int is_equal;
 
        dim = isl_space_domain(isl_map_get_space(map));
        src = graph_find_node(ctx, graph, dim);
@@ -441,54 +559,55 @@ static int extract_edge(__isl_take isl_map *map, void *user)
        graph->edge[graph->n_edge].src = src;
        graph->edge[graph->n_edge].dst = dst;
        graph->edge[graph->n_edge].map = map;
-       graph->edge[graph->n_edge].validity = !graph->edge_table;
-       graph->edge[graph->n_edge].proximity = !!graph->edge_table;
+       if (data->type == isl_edge_validity) {
+               graph->edge[graph->n_edge].validity = 1;
+               graph->edge[graph->n_edge].proximity = 0;
+       }
+       if (data->type == isl_edge_proximity) {
+               graph->edge[graph->n_edge].validity = 0;
+               graph->edge[graph->n_edge].proximity = 1;
+       }
        graph->n_edge++;
 
-       if (graph->edge_table) {
-               uint32_t hash;
-               struct isl_hash_table_entry *entry;
-               struct isl_sched_edge *edge;
-               int is_equal;
-
-               hash = isl_hash_init();
-               hash = isl_hash_builtin(hash, src);
-               hash = isl_hash_builtin(hash, dst);
-               entry = isl_hash_table_find(ctx, graph->edge_table, hash,
-                                           &edge_has_src_and_dst,
-                                           &graph->edge[graph->n_edge - 1], 0);
-               if (!entry)
-                       return 0;
-               edge = entry->data;
-               is_equal = isl_map_plain_is_equal(map, edge->map);
-               if (is_equal < 0)
-                       return -1;
-               if (!is_equal)
-                       return 0;
+       edge = graph_find_any_edge(graph, src, dst);
+       if (!edge)
+               return graph_edge_table_add(ctx, graph, data->type,
+                                   &graph->edge[graph->n_edge - 1]);
+       is_equal = isl_map_plain_is_equal(map, edge->map);
+       if (is_equal < 0)
+               return -1;
+       if (!is_equal)
+               return graph_edge_table_add(ctx, graph, data->type,
+                                   &graph->edge[graph->n_edge - 1]);
 
-               graph->n_edge--;
-               edge->proximity = 1;
-               isl_map_free(map);
-       }
+       graph->n_edge--;
+       edge->validity |= graph->edge[graph->n_edge].validity;
+       edge->proximity |= graph->edge[graph->n_edge].proximity;
+       isl_map_free(map);
 
-       return 0;
+       return graph_edge_table_add(ctx, graph, data->type, edge);
 }
 
 /* Check whether there is a validity dependence from src to dst,
- * forcing dst to follow src.
+ * forcing dst to follow src (if weak is not set).
+ * If weak is set, then check if there is any dependence from src to dst.
  */
 static int node_follows(struct isl_sched_graph *graph, 
-       struct isl_sched_node *dst, struct isl_sched_node *src)
+       struct isl_sched_node *dst, struct isl_sched_node *src, int weak)
 {
-       return graph_has_edge(graph, src, dst);
+       if (weak)
+               return graph_has_any_edge(graph, src, dst);
+       else
+               return graph_has_validity_edge(graph, src, dst);
 }
 
 /* Perform Tarjan's algorithm for computing the strongly connected components
  * in the dependence graph (only validity edges).
- * If directed is not set, we consider the graph to be undirected and
+ * If weak is set, we consider the graph to be undirected and
  * we effectively compute the (weakly) connected components.
+ * Additionally, we also consider other edges when weak is set.
  */
-static int detect_sccs_tarjan(struct isl_sched_graph *g, int i, int directed)
+static int detect_sccs_tarjan(struct isl_sched_graph *g, int i, int weak)
 {
        int j;
 
@@ -508,18 +627,18 @@ static int detect_sccs_tarjan(struct isl_sched_graph *g, int i, int directed)
                         g->node[j].index > g->node[i].min_index))
                        continue;
                
-               f = node_follows(g, &g->node[i], &g->node[j]);
+               f = node_follows(g, &g->node[i], &g->node[j], weak);
                if (f < 0)
                        return -1;
-               if (!f && !directed) {
-                       f = node_follows(g, &g->node[j], &g->node[i]);
+               if (!f && weak) {
+                       f = node_follows(g, &g->node[j], &g->node[i], weak);
                        if (f < 0)
                                return -1;
                }
                if (!f)
                        continue;
                if (g->node[j].index < 0) {
-                       detect_sccs_tarjan(g, j, directed);
+                       detect_sccs_tarjan(g, j, weak);
                        if (g->node[j].min_index < g->node[i].min_index)
                                g->node[i].min_index = g->node[j].min_index;
                } else if (g->node[j].index < g->node[i].min_index)
@@ -539,7 +658,7 @@ static int detect_sccs_tarjan(struct isl_sched_graph *g, int i, int directed)
        return 0;
 }
 
-static int detect_ccs(struct isl_sched_graph *graph, int directed)
+static int detect_ccs(struct isl_sched_graph *graph, int weak)
 {
        int i;
 
@@ -552,7 +671,7 @@ static int detect_ccs(struct isl_sched_graph *graph, int directed)
        for (i = graph->n - 1; i >= 0; --i) {
                if (graph->node[i].index >= 0)
                        continue;
-               if (detect_sccs_tarjan(graph, i, directed) < 0)
+               if (detect_sccs_tarjan(graph, i, weak) < 0)
                        return -1;
        }
 
@@ -564,7 +683,7 @@ static int detect_ccs(struct isl_sched_graph *graph, int directed)
  */
 static int detect_sccs(struct isl_sched_graph *graph)
 {
-       return detect_ccs(graph, 1);
+       return detect_ccs(graph, 0);
 }
 
 /* Apply Tarjan's algorithm to detect the (weakly) connected components
@@ -572,7 +691,7 @@ static int detect_sccs(struct isl_sched_graph *graph)
  */
 static int detect_wccs(struct isl_sched_graph *graph)
 {
-       return detect_ccs(graph, 0);
+       return detect_ccs(graph, 1);
 }
 
 static int cmp_scc(const void *a, const void *b, void *data)
@@ -1519,12 +1638,12 @@ static __isl_give isl_map *specialize(__isl_take isl_map *map,
 
 /* Update the dependence relations of all edges based on the current schedule.
  * If a dependence is carried completely by the current schedule, then
- * it is removed and edge_table is updated accordingly.
+ * it is removed from the edge_tables.  It is kept in the list of edges
+ * as otherwise all edge_tables would have to be recomputed.
  */
 static int update_edges(isl_ctx *ctx, struct isl_sched_graph *graph)
 {
        int i;
-       int reset_table = 0;
 
        for (i = graph->n_edge - 1; i >= 0; --i) {
                struct isl_sched_edge *edge = &graph->edge[i];
@@ -1532,19 +1651,8 @@ static int update_edges(isl_ctx *ctx, struct isl_sched_graph *graph)
                if (!edge->map)
                        return -1;
 
-               if (isl_map_plain_is_empty(edge->map)) {
-                       reset_table = 1;
-                       isl_map_free(edge->map);
-                       if (i != graph->n_edge - 1)
-                               graph->edge[i] = graph->edge[graph->n_edge - 1];
-                       graph->n_edge--;
-               }
-       }
-
-       if (reset_table) {
-               isl_hash_table_free(ctx, graph->edge_table);
-               graph->edge_table = NULL;
-               return graph_init_edge_table(ctx, graph);
+               if (isl_map_plain_is_empty(edge->map))
+                       graph_remove_edge(graph, edge);
        }
 
        return 0;
@@ -1701,6 +1809,7 @@ static int copy_edges(isl_ctx *ctx, struct isl_sched_graph *dst,
        int (*edge_pred)(struct isl_sched_edge *edge, int data), int data)
 {
        int i;
+       int t;
 
        dst->n_edge = 0;
        for (i = 0; i < src->n_edge; ++i) {
@@ -1731,6 +1840,15 @@ static int copy_edges(isl_ctx *ctx, struct isl_sched_graph *dst,
                dst->edge[dst->n_edge].validity = edge->validity;
                dst->edge[dst->n_edge].proximity = edge->proximity;
                dst->n_edge++;
+
+               for (t = 0; t <= isl_edge_last; ++t) {
+                       if (edge !=
+                           graph_find_edge(src, t, edge->src, edge->dst))
+                               continue;
+                       if (graph_edge_table_add(ctx, dst, t,
+                                           &dst->edge[dst->n_edge - 1]) < 0)
+                               return -1;
+               }
        }
 
        return 0;
@@ -1810,6 +1928,7 @@ static int compute_sub_schedule(isl_ctx *ctx,
        int data, int wcc)
 {
        struct isl_sched_graph split = { 0 };
+       int t;
 
        if (graph_alloc(ctx, &split, n, n_edge) < 0)
                goto error;
@@ -1817,9 +1936,11 @@ static int compute_sub_schedule(isl_ctx *ctx,
                goto error;
        if (graph_init_table(ctx, &split) < 0)
                goto error;
-       if (copy_edges(ctx, &split, graph, edge_pred, data) < 0)
+       for (t = 0; t <= isl_edge_last; ++t)
+               split.max_edge[t] = graph->max_edge[t];
+       if (graph_init_edge_tables(ctx, &split) < 0)
                goto error;
-       if (graph_init_edge_table(ctx, &split) < 0)
+       if (copy_edges(ctx, &split, graph, edge_pred, data) < 0)
                goto error;
        split.n_row = graph->n_row;
        split.n_total_row = graph->n_total_row;
@@ -2403,15 +2524,23 @@ static int carry_dependences(isl_ctx *ctx, struct isl_sched_graph *graph)
        return compute_next_band(ctx, graph);
 }
 
-/* Are there any validity edges in the graph?
+/* Are there any (non-empty) validity edges in the graph?
  */
 static int has_validity_edges(struct isl_sched_graph *graph)
 {
        int i;
 
-       for (i = 0; i < graph->n_edge; ++i)
+       for (i = 0; i < graph->n_edge; ++i) {
+               int empty;
+
+               empty = isl_map_plain_is_empty(graph->edge[i].map);
+               if (empty < 0)
+                       return -1;
+               if (empty)
+                       continue;
                if (graph->edge[i].validity)
                        return 1;
+       }
 
        return 0;
 }
@@ -2561,7 +2690,8 @@ static int compute_component_schedule(isl_ctx *ctx,
        int n_total_row, orig_total_row;
        int n_band, orig_band;
 
-       if (ctx->opt->schedule_fuse == ISL_SCHEDULE_FUSE_MIN)
+       if (ctx->opt->schedule_fuse == ISL_SCHEDULE_FUSE_MIN ||
+           ctx->opt->schedule_separate_components)
                split_on_scc(graph);
 
        n_total_row = 0;
@@ -2639,6 +2769,7 @@ __isl_give isl_schedule *isl_union_set_compute_schedule(
        isl_space *dim;
        struct isl_sched_graph graph = { 0 };
        isl_schedule *sched;
+       struct isl_extract_edge_data data;
 
        domain = isl_union_set_align_params(domain,
                                            isl_union_map_get_space(validity));
@@ -2663,12 +2794,17 @@ __isl_give isl_schedule *isl_union_set_compute_schedule(
                goto error;
        if (graph_init_table(ctx, &graph) < 0)
                goto error;
-       graph.n_edge = 0;
-       if (isl_union_map_foreach_map(validity, &extract_edge, &graph) < 0)
+       graph.max_edge[isl_edge_validity] = isl_union_map_n_map(validity);
+       graph.max_edge[isl_edge_proximity] = isl_union_map_n_map(proximity);
+       if (graph_init_edge_tables(ctx, &graph) < 0)
                goto error;
-       if (graph_init_edge_table(ctx, &graph) < 0)
+       graph.n_edge = 0;
+       data.graph = &graph;
+       data.type = isl_edge_validity;
+       if (isl_union_map_foreach_map(validity, &extract_edge, &data) < 0)
                goto error;
-       if (isl_union_map_foreach_map(proximity, &extract_edge, &graph) < 0)
+       data.type = isl_edge_proximity;
+       if (isl_union_map_foreach_map(proximity, &extract_edge, &data) < 0)
                goto error;
 
        if (compute_schedule(ctx, &graph) < 0)