--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file parallel_for.h
+ * \brief An implementation to run loop in parallel.
+ */
+#ifndef TVM_SUPPORT_PARALLEL_FOR_H_
+#define TVM_SUPPORT_PARALLEL_FOR_H_
+
+#include <tvm/runtime/c_runtime_api.h>
+
+#include <functional>
+#include <vector>
+
+namespace tvm {
+namespace support {
+
+using PartitionerFuncType = std::function<std::vector<std::vector<int>>(int, int, int, int)>;
+
+/*!
+ * \brief A partitioner to split the task to each thread in Round-robin manner.
+ *
+ * Loop index `begin` is given to thread 0, `begin + step` to thread 1, and so on, wrapping back
+ * to thread 0 after `num_threads` indexes. The signature matches PartitionerFuncType, so this
+ * function can be passed as the `partitioner` argument of parallel_for.
+ * \param begin The start index of this parallel loop(inclusive).
+ * \param end The end index of this parallel loop(exclusive).
+ * \param step The traversal step to the index.
+ * \param num_threads The number of threads(the number of tasks to be partitioned to).
+ * \return A list with at most `num_threads` elements (fewer when the loop has fewer indexes than
+ * threads), and each is a list of integers indicating the loop indexes for the corresponding
+ * thread to process.
+ */
+TVM_DLL std::vector<std::vector<int>> rr_partitioner(int begin, int end, int step, int num_threads);
+
+/*!
+ * \brief A runtime api provided to run the task function in parallel.
+ * e.g. A for loop:
+ *   for (int i = 0; i < 10; i++) {
+ *     a[i] = i;
+ *   }
+ * should work the same as:
+ *   parallel_for(0, 10, [&a](int i) {
+ *     a[i] = i;
+ *   });
+ * \param begin The start index of this parallel loop(inclusive).
+ * \param end The end index of this parallel loop(exclusive).
+ * \param f The task function to be executed. Expected to take an int index as input with no
+ * output.
+ * \param step The traversal step to the index.
+ * \param partitioner A partition function to split tasks to different threads. Use Round-robin
+ * partitioner by default.
+ * \note 1. Currently do not support nested parallel_for; 2. The order of execution in each thread
+ * is not guaranteed, the for loop task should be thread independent and thread safe; 3. The call
+ * returns only after every worker thread has been joined, and a std::exception raised by the task
+ * function is reported through LOG(FATAL) at that point.
+ */
+TVM_DLL void parallel_for(int begin, int end, const std::function<void(int)>& f, int step = 1,
+ const PartitionerFuncType partitioner = rr_partitioner);
+
+} // namespace support
+} // namespace tvm
+
+#endif // TVM_SUPPORT_PARALLEL_FOR_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file parallel_for.cc
+ * \brief An implementation to run loop in parallel.
+ */
+#include <dmlc/logging.h>
+#include <tvm/support/parallel_for.h>
+
+#include <future>
+#include <thread>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace support {
+
+/*!
+ * \brief Deal the loop indexes `begin, begin + step, ...` (all `< end`) out to threads in
+ * Round-robin order.
+ * \param begin The start index of the loop (inclusive).
+ * \param end The end index of the loop (exclusive).
+ * \param step The stride between consecutive indexes; must not be 0.
+ * \param num_threads The number of threads to partition for; must be positive.
+ * \return One index list per thread that received work, i.e. at most `num_threads` lists.
+ * \note Fails a CHECK when `step` is 0, when `num_threads` is not positive, or when the loop
+ * described by `begin`, `end` and `step` has no iteration to hand out.
+ */
+std::vector<std::vector<int>> rr_partitioner(int begin, int end, int step, int num_threads) {
+  // A zero step must be rejected before the division below, which it would otherwise crash.
+  CHECK_NE(step, 0) << "Infinite loop condition, check the input value of "
+                    << "`begin`, `end`, `step`.";
+  // `num_threads` is the modulus of the Round-robin counter, so it has to be positive.
+  CHECK_GT(num_threads, 0) << "`num_threads` should be positive, but got " << num_threads;
+  // Round up for an ascending range so that a trailing partial stride, e.g. (0, 1, 2), still
+  // counts as one task. Every other combination keeps the plain quotient.
+  const int total_task_count =
+      (step > 0 && begin < end) ? (end - begin - 1) / step + 1 : (end - begin) / step;
+  CHECK_GT(total_task_count, 0) << "Infinite loop condition, check the input value of "
+                                << "`begin`, `end`, `step`.";
+
+  const size_t thread_count = static_cast<size_t>(num_threads);
+  std::vector<std::vector<int>> ret;
+  ret.reserve(thread_count);
+  for (size_t thread = 0; begin < end; begin += step, thread = (thread + 1) % thread_count) {
+    // Per-thread lists are created lazily, so threads that never get an index are not part of
+    // the result.
+    if (thread >= ret.size()) {
+      ret.emplace_back();
+    }
+    ret[thread].push_back(begin);
+  }
+  return ret;
+}
+
+/*!
+ * \brief Run `f` once for every loop index in [begin, end) with stride `step`, spread over
+ * threads.
+ *
+ * The indexes are split by `partitioner`, one std::thread is started per returned partition, and
+ * the call returns only after all of them have been joined.
+ * \param begin The start index of the loop (inclusive).
+ * \param end The end index of the loop (exclusive).
+ * \param f The task function, invoked as `f(index)`. Each worker thread runs its own copy.
+ * \param step The stride between consecutive indexes.
+ * \param partitioner Splits the indexes into per-thread lists.
+ * \note A std::exception escaping `f` is captured by the worker's future and reported through
+ * LOG(FATAL) once every thread has been joined.
+ */
+void parallel_for(int begin, int end, const std::function<void(int)>& f, int step,
+                  const PartitionerFuncType partitioner) {
+  // hardware_concurrency() may return 0 when the core count cannot be determined. Fall back to a
+  // single thread instead of asking the partitioner to split the loop across zero threads.
+  const unsigned int hardware_threads = std::thread::hardware_concurrency();
+  const int default_num_threads = hardware_threads > 0 ? static_cast<int>(hardware_threads) : 1;
+  const std::vector<std::vector<int>> run_partitions =
+      partitioner(begin, end, step, default_num_threads);
+
+  std::vector<std::thread> threads;
+  threads.reserve(run_partitions.size());
+  std::vector<std::future<void>> res_vec;
+  res_vec.reserve(run_partitions.size());
+  for (const auto& run_partition : run_partitions) {
+    std::packaged_task<void(const std::vector<int>&, const std::function<void(int)>&)> task(
+        [](const std::vector<int>& indexes, const std::function<void(int)>& func) {
+          for (const int index : indexes) {
+            func(index);
+          }
+        });
+    res_vec.emplace_back(task.get_future());
+    // std::thread copies its arguments. The partition is read-only and outlives the worker (all
+    // threads are joined below), so hand it over by reference instead of duplicating it.
+    threads.emplace_back(std::move(task), std::cref(run_partition), f);
+  }
+
+  // Join everything first so that no worker is still running when a failure is reported.
+  for (auto& thread : threads) {
+    thread.join();
+  }
+  try {
+    // get() rethrows whatever the corresponding worker threw.
+    for (auto& result : res_vec) {
+      result.get();
+    }
+  } catch (const std::exception& e) {
+    LOG(FATAL) << "Parallel_for error with " << e.what();
+  }
+}
+
+} // namespace support
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <dmlc/logging.h>
+#include <gtest/gtest.h>
+#include <tvm/support/parallel_for.h>
+
+#include <vector>
+
+// Compares parallel_for against plain serial loops on a small range, a large range and a range
+// walked with a stride of 2.
+TEST(ParallelFor, Basic) {
+  using tvm::support::parallel_for;
+
+  // `expected` is only touched by serial loops, `actual` only by parallel_for.
+  int expected[1000], actual[1000];
+
+  // Asserts that the first `count` entries of both arrays agree.
+  const auto check_prefix = [&expected, &actual](int count) {
+    for (int k = 0; k < count; ++k) {
+      CHECK_EQ(expected[k], actual[k]);
+    }
+  };
+
+  // Small range.
+  for (int k = 0; k < 10; ++k) {
+    expected[k] = k;
+  }
+  parallel_for(0, 10, [&actual](int k) { actual[k] = k; });
+  check_prefix(10);
+
+  // Large range.
+  for (int k = 0; k < 1000; ++k) {
+    expected[k] = k;
+  }
+  parallel_for(0, 1000, [&actual](int k) { actual[k] = k; });
+  check_prefix(1000);
+
+  // Stride of 2: only the even slots are doubled, the odd ones must stay untouched.
+  for (int k = 0; k < 1000; k += 2) {
+    expected[k] *= 2;
+  }
+  parallel_for(
+      0, 1000, [&actual](int k) { actual[k] *= 2; }, 2);
+  check_prefix(1000);
+}
+
+// Checks parallel_for combined with an ordinary for loop, once with the parallel loop on the
+// outside and once with it on the inside.
+TEST(ParallelFor, NestedWithNormalForLoop) {
+  using tvm::support::parallel_for;
+
+  // static so the lambdas below can read it without capturing it.
+  static constexpr int kSize = 500;
+  // Three kSize x kSize int matrices are about 3 MB in total, which is more than the default
+  // main-thread stack on some platforms, so they are kept on the heap.
+  using Matrix = std::vector<std::vector<int>>;
+  Matrix a(kSize, std::vector<int>(kSize));
+  Matrix b(kSize, std::vector<int>(kSize));
+  Matrix c(kSize, std::vector<int>(kSize));
+
+  // Serial reference result.
+  for (int i = 0; i < kSize; i++) {
+    for (int j = 0; j < kSize; j++) {
+      a[i][j] = i * j;
+    }
+  }
+
+  // Parallel outer loop, serial inner loop.
+  parallel_for(0, kSize, [&b](int i) {
+    for (int j = 0; j < kSize; j++) {
+      b[i][j] = i * j;
+    }
+  });
+  for (int i = 0; i < kSize; i++) {
+    for (int j = 0; j < kSize; j++) {
+      CHECK_EQ(a[i][j], b[i][j]);
+    }
+  }
+
+  // Serial outer loop, parallel inner loop.
+  for (int i = 0; i < kSize; i++) {
+    parallel_for(0, kSize, [&c, &i](int j) { c[i][j] = i * j; });
+  }
+  for (int i = 0; i < kSize; i++) {
+    for (int j = 0; j < kSize; j++) {
+      CHECK_EQ(a[i][j], c[i][j]);
+    }
+  }
+}
+
+// Every task fails through LOG(FATAL); the test expects that failure to reach the caller of
+// parallel_for as a std::exception.
+TEST(ParallelFor, Exception) {
+  using tvm::support::parallel_for;
+
+  const auto always_fail = [](int i) { LOG(FATAL) << "error"; };
+
+  bool caught = false;
+  try {
+    parallel_for(0, 100, always_fail);
+  } catch (const std::exception& e) {
+    caught = true;
+  }
+  CHECK(caught);
+}
+
+// Test binary entry point: hands the command line to googletest, selects the "threadsafe" death
+// test style and runs every registered test.
+int main(int argc, char** argv) {
+  testing::InitGoogleTest(&argc, argv);
+  testing::FLAGS_gtest_death_test_style = "threadsafe";
+  const int exit_code = RUN_ALL_TESTS();
+  return exit_code;
+}