+/* Internal data structure for before_for and after_for callbacks.
+ *
+ * depth is the current depth
+ * before is the number of times before_for has been called
+ * after is the number of times after_for has been called
+ */
+struct isl_test_codegen_data {
+ int depth;
+ int before;
+ int after;
+};
+
+/* This function is called before each for loop in the AST generated
+ * from test_ast_gen1.
+ *
+ * Increment the number of calls and the depth.
+ * Check that the space returned by isl_ast_build_get_schedule_space
+ * matches the target space of the schedule returned by
+ * isl_ast_build_get_schedule.
+ * Return an isl_id that is checked by the corresponding call
+ * to after_for.
+ */
+static __isl_give isl_id *before_for(__isl_keep isl_ast_build *build,
+ void *user)
+{
+ struct isl_test_codegen_data *data = user;
+ isl_ctx *ctx;
+ isl_space *space;
+ isl_union_map *schedule;
+ isl_union_set *uset;
+ isl_set *set;
+ int empty;
+ char name[] = "d0";
+
+ ctx = isl_ast_build_get_ctx(build);
+
+ if (data->before >= 3)
+ isl_die(ctx, isl_error_unknown,
+ "unexpected number of for nodes", return NULL);
+ if (data->depth >= 2)
+ isl_die(ctx, isl_error_unknown,
+ "unexpected depth", return NULL);
+
+ snprintf(name, sizeof(name), "d%d", data->depth);
+ data->before++;
+ data->depth++;
+
+ schedule = isl_ast_build_get_schedule(build);
+ uset = isl_union_map_range(schedule);
+ if (!uset)
+ return NULL;
+ if (isl_union_set_n_set(uset) != 1) {
+ isl_union_set_free(uset);
+ isl_die(ctx, isl_error_unknown,
+ "expecting single range space", return NULL);
+ }
+
+ space = isl_ast_build_get_schedule_space(build);
+ set = isl_union_set_extract_set(uset, space);
+ isl_union_set_free(uset);
+ empty = isl_set_is_empty(set);
+ isl_set_free(set);
+
+ if (empty < 0)
+ return NULL;
+ if (empty)
+ isl_die(ctx, isl_error_unknown,
+ "spaces don't match", return NULL);
+
+ return isl_id_alloc(ctx, name, NULL);
+}
+
+/* This function is called after each for loop in the AST generated
+ * from test_ast_gen1.
+ *
+ * Increment the number of calls and decrement the depth.
+ * Check that the annotation attached to the node matches
+ * the isl_id returned by the corresponding call to before_for.
+ */
+static __isl_give isl_ast_node *after_for(__isl_take isl_ast_node *node,
+ __isl_keep isl_ast_build *build, void *user)
+{
+ struct isl_test_codegen_data *data = user;
+ isl_id *id;
+ const char *name;
+ int valid;
+
+ data->after++;
+ data->depth--;
+
+ if (data->after > data->before)
+ isl_die(isl_ast_node_get_ctx(node), isl_error_unknown,
+ "mismatch in number of for nodes",
+ return isl_ast_node_free(node));
+
+ id = isl_ast_node_get_annotation(node);
+ if (!id)
+ isl_die(isl_ast_node_get_ctx(node), isl_error_unknown,
+ "missing annotation", return isl_ast_node_free(node));
+
+ name = isl_id_get_name(id);
+ valid = name && atoi(name + 1) == data->depth;
+ isl_id_free(id);
+
+ if (!valid)
+ isl_die(isl_ast_node_get_ctx(node), isl_error_unknown,
+ "wrong annotation", return isl_ast_node_free(node));
+
+ return node;
+}
+
+/* Check that the before_each_for and after_each_for callbacks
+ * are called for each for loop in the generated code,
+ * that they are called in the right order and that the isl_id
+ * returned from the before_each_for callback is attached to
+ * the isl_ast_node passed to the corresponding after_each_for call.
+ */
+static int test_ast_gen1(isl_ctx *ctx)
+{
+ const char *str;
+ isl_set *set;
+ isl_union_map *schedule;
+ isl_ast_build *build;
+ isl_ast_node *tree;
+ struct isl_test_codegen_data data;
+
+ str = "[N] -> { : N >= 10 }";
+ set = isl_set_read_from_str(ctx, str);
+ str = "[N] -> { A[i,j] -> S[8,i,3,j] : 0 <= i,j <= N; "
+ "B[i,j] -> S[8,j,9,i] : 0 <= i,j <= N }";
+ schedule = isl_union_map_read_from_str(ctx, str);
+
+ data.before = 0;
+ data.after = 0;
+ data.depth = 0;
+ build = isl_ast_build_from_context(set);
+ build = isl_ast_build_set_before_each_for(build,
+ &before_for, &data);
+ build = isl_ast_build_set_after_each_for(build,
+ &after_for, &data);
+ tree = isl_ast_build_ast_from_schedule(build, schedule);
+ isl_ast_build_free(build);
+ if (!tree)
+ return -1;
+
+ isl_ast_node_free(tree);
+
+ if (data.before != 3 || data.after != 3)
+ isl_die(ctx, isl_error_unknown,
+ "unexpected number of for nodes", return -1);
+
+ return 0;
+}
+