1 /*-------------------------------------------------------------------------
2 * drawElements C++ Base Library
3 * -----------------------------
5 * Copyright 2015 The Android Open Source Project
7 * Licensed under the Apache License, Version 2.0 (the "License");
8 * you may not use this file except in compliance with the License.
9 * You may obtain a copy of the License at
11 * http://www.apache.org/licenses/LICENSE-2.0
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an "AS IS" BASIS,
15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
21 * \brief Cross-thread barrier.
22 *//*--------------------------------------------------------------------*/
24 #include "deSpinBarrier.hpp"
25 #include "deThread.hpp"
26 #include "deRandom.hpp"
// Constructs a barrier for numThreads participating threads.
// The visible initializers store the result of deGetNumAvailableLogicalCores()
// in m_numCores and the requested participant count in m_numThreads.
// NOTE(review): initialization of the other counters used by sync()
// (m_numEntered, m_numLeaving, m_numRemoved) is not visible in this
// excerpt -- confirm they start at 0.
34 SpinBarrier::SpinBarrier (deInt32 numThreads)
35 : m_numCores (deGetNumAvailableLogicalCores())
36 , m_numThreads (numThreads)
// A barrier with zero or a negative number of participants is rejected.
41 DE_ASSERT(numThreads > 0);
// Destructor. Asserts that both m_numEntered and m_numLeaving are 0 when the
// barrier is destroyed.
44 SpinBarrier::~SpinBarrier (void)
46 DE_ASSERT(m_numEntered == 0 && m_numLeaving == 0);
// Re-configures the barrier for a new participant count: after the asserts
// below, numThreads is stored in m_numThreads.
// NOTE(review): any reset of m_numEntered / m_numLeaving / m_numRemoved is
// not visible in this excerpt -- confirm against the full file.
49 void SpinBarrier::reset (deUint32 numThreads)
51 // If the last threads were removed, m_numEntered > 0 && m_numRemoved > 0 may
// hold at this point; only m_numLeaving is asserted below.
52 DE_ASSERT(m_numLeaving == 0);
// numThreads is unsigned here (the constructor takes deInt32), so this
// assert only rejects 0.
53 DE_ASSERT(numThreads > 0);
54 m_numThreads = numThreads;
// Resolves a requested wait mode against the core / thread counts.
//   requested  - mode asked for by the caller.
//   numCores   - number of cores to compare against.
//   numThreads - number of participating threads (cast to unsigned for the
//                comparison).
// For WAIT_MODE_AUTO the result is WAIT_MODE_BUSY when numThreads <= numCores
// and WAIT_MODE_YIELD otherwise.
// NOTE(review): the return for non-AUTO requests is not visible in this
// excerpt; presumably 'requested' is returned unchanged -- confirm.
60 inline SpinBarrier::WaitMode getWaitMode (SpinBarrier::WaitMode requested, deUint32 numCores, deInt32 numThreads)
62 if (requested == SpinBarrier::WAIT_MODE_AUTO)
63 return ((deUint32)numThreads <= numCores) ? SpinBarrier::WAIT_MODE_BUSY : SpinBarrier::WAIT_MODE_YIELD;
// Performs the wait action for an already-resolved wait mode.
// WAIT_MODE_AUTO is not accepted here: the assert allows only YIELD or BUSY.
// NOTE(review): the statement guarded by the YIELD check is not visible in
// this excerpt; nothing is shown for the BUSY case.
68 inline void wait (SpinBarrier::WaitMode mode)
70 DE_ASSERT(mode == SpinBarrier::WAIT_MODE_YIELD || mode == SpinBarrier::WAIT_MODE_BUSY);
72 if (mode == SpinBarrier::WAIT_MODE_YIELD)
// Barrier entry point for one participating thread.
// Steps visible in this excerpt:
//  - snapshot m_numThreads and resolve the wait mode against m_numCores,
//  - test m_numLeaving == 0 before m_numEntered is touched,
//  - atomically increment m_numEntered; the thread whose increment result
//    equals the snapshot subtracts m_numRemoved from m_numThreads and stores
//    the new count in m_numLeaving,
//  - test m_numEntered == 0,
//  - atomically decrement m_numLeaving, followed by a memory fence.
// NOTE(review): the loops around the two tests, any waiting code and the
// resets of m_numEntered / m_numRemoved are not visible in this excerpt --
// confirm against the full file.
76 void SpinBarrier::sync (WaitMode requestedMode)
// Snapshot taken once; the comparison after the atomic increment below uses
// this value rather than re-reading m_numThreads.
78 const deInt32 cachedNumThreads = m_numThreads;
79 const WaitMode waitMode = getWaitMode(requestedMode, m_numCores, cachedNumThreads);
81 deMemoryReadWriteFence();
83 // m_numEntered must not be touched until all threads have had
84 // a chance to observe it being 0.
// Presumably the exit condition of a wait loop -- the loop itself is not
// visible here.
89 if (m_numLeaving == 0)
96 // If m_numRemoved > 0, m_numThreads will decrease. If m_numThreads is decreased
97 // just after atomicOp and before comparison, the branch could be taken by multiple
98 // threads. Since m_numThreads only changes if all threads are inside the spinbarrier,
99 // the value snapshotted at the beginning of the function will be equal for
101 if (deAtomicIncrement32(&m_numEntered) == cachedNumThreads)
103 // Release all waiting threads. Since this thread has not been removed, m_numLeaving will
104 // be >= 1 until m_numLeaving is decremented at the end of this function.
// Fold the pending removals into the participant count, then publish that
// count as the number of threads that still have to leave.
105 m_numThreads -= m_numRemoved;
106 m_numLeaving = m_numThreads;
109 deMemoryReadWriteFence();
// Presumably the exit condition of the wait loop used by the threads that did
// not take the branch above -- the loop itself is not visible here.
116 if (m_numEntered == 0)
// This thread is done with the current round.
123 deAtomicDecrement32(&m_numLeaving);
124 deMemoryReadWriteFence();
// Registers the calling thread as removed from the barrier.
// Steps visible in this excerpt:
//  - snapshot m_numThreads and resolve the wait mode against m_numCores,
//  - if m_numLeaving > 0, test for it reaching 0 (previous round),
//  - atomically increment m_numRemoved, then m_numEntered,
//  - the thread whose m_numEntered increment result equals the snapshot
//    subtracts m_numRemoved from m_numThreads and stores the new count in
//    m_numLeaving, followed by a memory fence.
// Unlike sync(), no decrement of m_numLeaving is visible here.
// NOTE(review): loop bodies and any resets of m_numEntered / m_numRemoved are
// not visible in this excerpt -- confirm against the full file.
127 void SpinBarrier::removeThread (WaitMode requestedMode)
129 const deInt32 cachedNumThreads = m_numThreads;
130 const WaitMode waitMode = getWaitMode(requestedMode, m_numCores, cachedNumThreads);
132 // Wait for other threads exiting previous barrier
133 if (m_numLeaving > 0)
// Presumably the exit condition of the wait loop -- loop not visible here.
137 if (m_numLeaving == 0)
144 // Ask for last thread entering barrier to adjust thread count
145 deAtomicIncrement32(&m_numRemoved);
147 // See sync() - use cached value
148 if (deAtomicIncrement32(&m_numEntered) == cachedNumThreads)
150 // Release all waiting threads.
// Same adjustment as in sync(): apply the pending removals (which include
// this thread's own increment above) and publish the remaining count.
151 m_numThreads -= m_numRemoved;
152 m_numLeaving = m_numThreads;
155 deMemoryReadWriteFence();
// Self-test helper for the single-participant case: creates a barrier for
// exactly one thread.
// NOTE(review): the calls that use 'mode' are not visible in this excerpt.
163 void singleThreadTest (SpinBarrier::WaitMode mode)
165 SpinBarrier barrier(1);
// Worker thread for the multi-threaded barrier test. Every worker shares one
// SpinBarrier and one counter (m_sharedVar); the barrier calls separate the
// counter updates from the assertions on the counter value.
172 class TestThread : public de::Thread
// barrier    - barrier shared by all workers.
// sharedVar  - counter shared by all workers.
// numThreads - total number of workers in the test.
// threadNdx  - index of this worker; used only for seeding below.
175 TestThread (SpinBarrier& barrier, volatile deInt32* sharedVar, int numThreads, int threadNdx)
176 : m_barrier (barrier)
177 , m_sharedVar (sharedVar)
178 , m_numThreads (numThreads)
179 , m_threadNdx (threadNdx)
// Busy waiting is offered as a choice only when the worker count does not
// exceed the available logical core count (see getWaitMode() below).
180 , m_busyOk ((deUint32)m_numThreads <= deGetNumAvailableLogicalCores())
// Thread body (the function signature is not visible in this excerpt).
186 const int numIters = 10000;
// Random source seeded from the worker count and this worker's index.
187 de::Random rnd (deInt32Hash(m_numThreads) ^ deInt32Hash(m_threadNdx));
189 for (int iterNdx = 0; iterNdx < numIters; iterNdx++)
// Count up: this worker adds 1 to the shared counter.
192 deAtomicIncrement32(m_sharedVar);
195 m_barrier.sync(getWaitMode(rnd));
// After the barrier the counter is expected to equal the worker count.
197 DE_TEST_ASSERT(*m_sharedVar == m_numThreads);
// Separates the check above from the count-down below.
199 m_barrier.sync(getWaitMode(rnd));
201 // Phase 2: count down
202 deAtomicDecrement32(m_sharedVar);
205 m_barrier.sync(getWaitMode(rnd));
// After the barrier the counter is expected to be back at 0.
207 DE_TEST_ASSERT(*m_sharedVar == 0);
// Separates the check above from whatever follows (presumably the next
// iteration's count-up -- loop braces are not visible here).
209 m_barrier.sync(getWaitMode(rnd));
214 SpinBarrier& m_barrier;
215 volatile deInt32* const m_sharedVar;
216 const int m_numThreads;
217 const int m_threadNdx;
// Picks a wait mode for one sync() call via rnd.choose() over s_allModes.
220 SpinBarrier::WaitMode getWaitMode (de::Random& rnd)
// WAIT_MODE_BUSY is deliberately last so it can be cut off by shortening the
// range below.
222 static const SpinBarrier::WaitMode s_allModes[] =
224 SpinBarrier::WAIT_MODE_YIELD,
225 SpinBarrier::WAIT_MODE_AUTO,
226 SpinBarrier::WAIT_MODE_BUSY,
// When m_busyOk is false the range is one entry shorter, excluding the last
// entry (WAIT_MODE_BUSY).
228 const int numModes = DE_LENGTH_OF_ARRAY(s_allModes) - (m_busyOk ? 0 : 1);
230 return rnd.choose<SpinBarrier::WaitMode>(DE_ARRAY_BEGIN(s_allModes), DE_ARRAY_BEGIN(s_allModes) + numModes);
// Runs numThreads TestThread workers against one shared barrier and one
// shared counter, waits for all of them, and checks that the counter ends
// at 0.
234 void multiThreadTest (int numThreads)
236 SpinBarrier barrier (numThreads);
237 volatile deInt32 sharedVar = 0;
238 std::vector<TestThread*> threads (numThreads, static_cast<TestThread*>(DE_NULL));
// Create and start one worker per index.
240 for (int ndx = 0; ndx < numThreads; ndx++)
242 threads[ndx] = new TestThread(barrier, &sharedVar, numThreads, ndx);
243 DE_TEST_ASSERT(threads[ndx]);
244 threads[ndx]->start();
// Wait for every worker to finish.
// NOTE(review): release of the objects allocated with 'new' above is not
// visible in this excerpt -- confirm they are deleted after join().
247 for (int ndx = 0; ndx < numThreads; ndx++)
249 threads[ndx]->join();
253 DE_TEST_ASSERT(sharedVar == 0);
// Exercises removeThread() from a single calling thread on a barrier created
// for 3 participants.
// NOTE(review): only the removeThread() calls are visible in this excerpt;
// any sync() / reset() calls between them are not shown -- confirm against
// the full file.
256 void singleThreadRemoveTest (SpinBarrier::WaitMode mode)
258 SpinBarrier barrier(3);
260 barrier.removeThread(mode);
261 barrier.removeThread(mode);
263 barrier.removeThread(mode);
269 barrier.removeThread(mode);
// Worker thread for the multi-threaded removal test. Each iteration either
// removes itself from the shared barrier (random choice) or synchronizes on
// it using the fixed wait mode given at construction.
273 class TestExitThread : public de::Thread
// barrier    - barrier shared by all workers.
// numThreads - total number of workers; used only for seeding below.
// threadNdx  - index of this worker; used only for seeding below.
// waitMode   - wait mode passed to every removeThread() / sync() call.
276 TestExitThread (SpinBarrier& barrier, int numThreads, int threadNdx, SpinBarrier::WaitMode waitMode)
277 : m_barrier (barrier)
278 , m_numThreads (numThreads)
279 , m_threadNdx (threadNdx)
280 , m_waitMode (waitMode)
// Thread body (the function signature is not visible in this excerpt).
286 const int numIters = 10000;
// Random source seeded from worker count, worker index and wait mode.
287 de::Random rnd (deInt32Hash(m_numThreads) ^ deInt32Hash(m_threadNdx) ^ deInt32Hash((deInt32)m_waitMode));
288 const int invExitProb = 1000;
290 for (int iterNdx = 0; iterNdx < numIters; iterNdx++)
// Random exit, taken when getInt(0, invExitProb) returns 0.
// NOTE(review): if getInt()'s upper bound is inclusive the chance is
// 1 in (invExitProb + 1) rather than 1 in invExitProb -- confirm.
292 if (rnd.getInt(0, invExitProb) == 0)
// NOTE(review): what follows removeThread() (presumably leaving the loop) is
// not visible in this excerpt.
294 m_barrier.removeThread(m_waitMode);
298 m_barrier.sync(m_waitMode);
303 SpinBarrier& m_barrier;
304 const int m_numThreads;
305 const int m_threadNdx;
306 const SpinBarrier::WaitMode m_waitMode;
// Runs numThreads TestExitThread workers against one shared barrier using
// the given wait mode and waits for all of them to finish.
309 void multiThreadRemoveTest (int numThreads, SpinBarrier::WaitMode waitMode)
311 SpinBarrier barrier (numThreads);
312 std::vector<TestExitThread*> threads (numThreads, static_cast<TestExitThread*>(DE_NULL));
// Create and start one worker per index.
314 for (int ndx = 0; ndx < numThreads; ndx++)
316 threads[ndx] = new TestExitThread(barrier, numThreads, ndx, waitMode);
317 DE_TEST_ASSERT(threads[ndx]);
318 threads[ndx]->start();
// Wait for every worker to finish.
// NOTE(review): release of the objects allocated with 'new' above is not
// visible in this excerpt -- confirm they are deleted after join().
321 for (int ndx = 0; ndx < numThreads; ndx++)
323 threads[ndx]->join();
// Entry point of the SpinBarrier self-test: runs the single-thread tests in
// all three wait modes, then the removal tests.
// NOTE(review): calls between the single-thread tests and the removal tests
// (presumably multiThreadTest()) are not visible in this excerpt.
330 void SpinBarrier_selfTest (void)
332 singleThreadTest(SpinBarrier::WAIT_MODE_YIELD);
333 singleThreadTest(SpinBarrier::WAIT_MODE_BUSY);
334 singleThreadTest(SpinBarrier::WAIT_MODE_AUTO);
341 singleThreadRemoveTest(SpinBarrier::WAIT_MODE_YIELD);
342 singleThreadRemoveTest(SpinBarrier::WAIT_MODE_BUSY);
343 singleThreadRemoveTest(SpinBarrier::WAIT_MODE_AUTO);
// Explicit BUSY mode is used only for the 1-thread case; the larger thread
// counts below use AUTO and then YIELD.
344 multiThreadRemoveTest(1, SpinBarrier::WAIT_MODE_BUSY);
345 multiThreadRemoveTest(2, SpinBarrier::WAIT_MODE_AUTO);
346 multiThreadRemoveTest(4, SpinBarrier::WAIT_MODE_AUTO);
347 multiThreadRemoveTest(8, SpinBarrier::WAIT_MODE_AUTO);
348 multiThreadRemoveTest(16, SpinBarrier::WAIT_MODE_AUTO);
349 multiThreadRemoveTest(1, SpinBarrier::WAIT_MODE_YIELD);
350 multiThreadRemoveTest(2, SpinBarrier::WAIT_MODE_YIELD);
351 multiThreadRemoveTest(4, SpinBarrier::WAIT_MODE_YIELD);
352 multiThreadRemoveTest(8, SpinBarrier::WAIT_MODE_YIELD);
353 multiThreadRemoveTest(16, SpinBarrier::WAIT_MODE_YIELD);