Merged with mainline at revision 128810.
[official-gcc.git] / libstdc++-v3 / include / parallel / partition.h
blob790685bf0a5698badafbec16f3d28a615ddb2117
1 // -*- C++ -*-
3 // Copyright (C) 2007 Free Software Foundation, Inc.
4 //
5 // This file is part of the GNU ISO C++ Library. This library is free
6 // software; you can redistribute it and/or modify it under the terms
7 // of the GNU General Public License as published by the Free Software
8 // Foundation; either version 2, or (at your option) any later
9 // version.
11 // This library is distributed in the hope that it will be useful, but
12 // WITHOUT ANY WARRANTY; without even the implied warranty of
13 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 // General Public License for more details.
16 // You should have received a copy of the GNU General Public License
17 // along with this library; see the file COPYING. If not, write to
18 // the Free Software Foundation, 59 Temple Place - Suite 330, Boston,
19 // MA 02111-1307, USA.
21 // As a special exception, you may use this file as part of a free
22 // software library without restriction. Specifically, if other files
23 // instantiate templates or use macros or inline functions from this
24 // file, or you compile this file and link it with other files to
25 // produce an executable, this file does not by itself cause the
26 // resulting executable to be covered by the GNU General Public
27 // License. This exception does not however invalidate any other
28 // reasons why the executable file might be covered by the GNU General
29 // Public License.
/** @file parallel/partition.h
 *  @brief Parallel implementation of std::partition(),
 *  std::nth_element(), and std::partial_sort().
 *  This file is a GNU parallel extension to the Standard C++ Library.
 */

// Written by Johannes Singler and Felix Putze.
39 #ifndef _GLIBCXX_PARALLEL_PARTITION_H
40 #define _GLIBCXX_PARALLEL_PARTITION_H 1
42 #include <parallel/basic_iterator.h>
43 #include <parallel/sort.h>
44 #include <bits/stl_algo.h>
45 #include <parallel/parallel.h>
/** @brief Decide whether to declare certain variables volatile in this file. */
48 #define _GLIBCXX_VOLATILE volatile
50 namespace __gnu_parallel
52 /** @brief Parallel implementation of std::partition.
53 * @param begin Begin iterator of input sequence to split.
54 * @param end End iterator of input sequence to split.
55 * @param pred Partition predicate, possibly including some kind of pivot.
56 * @param max_num_threads Maximum number of threads to use for this task.
57 * @return Number of elements not fulfilling the predicate. */
58 template<typename RandomAccessIterator, typename Predicate>
59 inline typename std::iterator_traits<RandomAccessIterator>::difference_type
60 parallel_partition(RandomAccessIterator begin, RandomAccessIterator end,
61 Predicate pred, thread_index_t max_num_threads)
63 typedef std::iterator_traits<RandomAccessIterator> traits_type;
64 typedef typename traits_type::value_type value_type;
65 typedef typename traits_type::difference_type difference_type;
67 difference_type n = end - begin;
69 _GLIBCXX_CALL(n)
71 // Shared.
72 _GLIBCXX_VOLATILE difference_type left = 0, right = n - 1;
73 _GLIBCXX_VOLATILE difference_type leftover_left, leftover_right, leftnew, rightnew;
74 bool* reserved_left, * reserved_right;
76 reserved_left = new bool[max_num_threads];
77 reserved_right = new bool[max_num_threads];
79 difference_type chunk_size;
80 if (Settings::partition_chunk_share > 0.0)
81 chunk_size = std::max((difference_type)Settings::partition_chunk_size, (difference_type)((double)n * Settings::partition_chunk_share / (double)max_num_threads));
82 else
83 chunk_size = Settings::partition_chunk_size;
85 // At least good for two processors.
86 while (right - left + 1 >= 2 * max_num_threads * chunk_size)
88 difference_type num_chunks = (right - left + 1) / chunk_size;
89 thread_index_t num_threads = (int)std::min((difference_type)max_num_threads, num_chunks / 2);
91 for (int r = 0; r < num_threads; r++)
93 reserved_left[r] = false;
94 reserved_right[r] = false;
96 leftover_left = 0;
97 leftover_right = 0;
99 #pragma omp parallel num_threads(num_threads)
101 // Private.
102 difference_type thread_left, thread_left_border, thread_right, thread_right_border;
103 thread_left = left + 1;
105 // Just to satisfy the condition below.
106 thread_left_border = thread_left - 1;
107 thread_right = n - 1;
108 thread_right_border = thread_right + 1;
110 bool iam_finished = false;
111 while (!iam_finished)
113 if (thread_left > thread_left_border)
114 #pragma omp critical
116 if (left + (chunk_size - 1) > right)
117 iam_finished = true;
118 else
120 thread_left = left;
121 thread_left_border = left + (chunk_size - 1);
122 left += chunk_size;
126 if (thread_right < thread_right_border)
127 #pragma omp critical
129 if (left > right - (chunk_size - 1))
130 iam_finished = true;
131 else
133 thread_right = right;
134 thread_right_border = right - (chunk_size - 1);
135 right -= chunk_size;
139 if (iam_finished)
140 break;
142 // Swap as usual.
143 while (thread_left < thread_right)
145 while (pred(begin[thread_left]) && thread_left <= thread_left_border)
146 thread_left++;
147 while (!pred(begin[thread_right]) && thread_right >= thread_right_border)
148 thread_right--;
150 if (thread_left > thread_left_border || thread_right < thread_right_border)
151 // Fetch new chunk(s).
152 break;
154 std::swap(begin[thread_left], begin[thread_right]);
155 thread_left++;
156 thread_right--;
160 // Now swap the leftover chunks to the right places.
161 if (thread_left <= thread_left_border)
162 #pragma omp atomic
163 leftover_left++;
164 if (thread_right >= thread_right_border)
165 #pragma omp atomic
166 leftover_right++;
168 #pragma omp barrier
170 #pragma omp single
172 leftnew = left - leftover_left * chunk_size;
173 rightnew = right + leftover_right * chunk_size;
176 #pragma omp barrier
178 // <=> thread_left_border + (chunk_size - 1) >= leftnew
179 if (thread_left <= thread_left_border
180 && thread_left_border >= leftnew)
182 // Chunk already in place, reserve spot.
183 reserved_left[(left - (thread_left_border + 1)) / chunk_size] = true;
186 // <=> thread_right_border - (chunk_size - 1) <= rightnew
187 if (thread_right >= thread_right_border
188 && thread_right_border <= rightnew)
190 // Chunk already in place, reserve spot.
191 reserved_right[((thread_right_border - 1) - right) / chunk_size] = true;
194 #pragma omp barrier
196 if (thread_left <= thread_left_border && thread_left_border < leftnew)
198 // Find spot and swap.
199 difference_type swapstart = -1;
200 #pragma omp critical
202 for (int r = 0; r < leftover_left; r++)
203 if (!reserved_left[r])
205 reserved_left[r] = true;
206 swapstart = left - (r + 1) * chunk_size;
207 break;
211 #if _GLIBCXX_ASSERTIONS
212 _GLIBCXX_PARALLEL_ASSERT(swapstart != -1);
213 #endif
215 std::swap_ranges(begin + thread_left_border - (chunk_size - 1), begin + thread_left_border + 1, begin + swapstart);
218 if (thread_right >= thread_right_border
219 && thread_right_border > rightnew)
221 // Find spot and swap
222 difference_type swapstart = -1;
223 #pragma omp critical
225 for (int r = 0; r < leftover_right; r++)
226 if (!reserved_right[r])
228 reserved_right[r] = true;
229 swapstart = right + r * chunk_size + 1;
230 break;
234 #if _GLIBCXX_ASSERTIONS
235 _GLIBCXX_PARALLEL_ASSERT(swapstart != -1);
236 #endif
238 std::swap_ranges(begin + thread_right_border, begin + thread_right_border + chunk_size, begin + swapstart);
240 #if _GLIBCXX_ASSERTIONS
241 #pragma omp barrier
243 #pragma omp single
245 for (int r = 0; r < leftover_left; r++)
246 _GLIBCXX_PARALLEL_ASSERT(reserved_left[r]);
247 for (int r = 0; r < leftover_right; r++)
248 _GLIBCXX_PARALLEL_ASSERT(reserved_right[r]);
251 #pragma omp barrier
252 #endif
254 #pragma omp barrier
255 left = leftnew;
256 right = rightnew;
258 } // end "recursion"
260 difference_type final_left = left, final_right = right;
262 while (final_left < final_right)
264 // Go right until key is geq than pivot.
265 while (pred(begin[final_left]) && final_left < final_right)
266 final_left++;
268 // Go left until key is less than pivot.
269 while (!pred(begin[final_right]) && final_left < final_right)
270 final_right--;
272 if (final_left == final_right)
273 break;
274 std::swap(begin[final_left], begin[final_right]);
275 final_left++;
276 final_right--;
279 // All elements on the left side are < piv, all elements on the
280 // right are >= piv
281 delete[] reserved_left;
282 delete[] reserved_right;
284 // Element "between" final_left and final_right might not have
285 // been regarded yet
286 if (final_left < n && !pred(begin[final_left]))
287 // Really swapped.
288 return final_left;
289 else
290 return final_left + 1;
293 /**
294 * @brief Parallel implementation of std::nth_element().
295 * @param begin Begin iterator of input sequence.
296 * @param nth Iterator of element that must be in position afterwards.
297 * @param end End iterator of input sequence.
298 * @param comp Comparator.
300 template<typename RandomAccessIterator, typename Comparator>
301 void
302 parallel_nth_element(RandomAccessIterator begin, RandomAccessIterator nth, RandomAccessIterator end, Comparator comp)
304 typedef std::iterator_traits<RandomAccessIterator> traits_type;
305 typedef typename traits_type::value_type value_type;
306 typedef typename traits_type::difference_type difference_type;
308 _GLIBCXX_CALL(end - begin)
310 RandomAccessIterator split;
311 value_type pivot;
312 random_number rng;
314 difference_type minimum_length = std::max<difference_type>(2, Settings::partition_minimal_n);
316 // Break if input range to small.
317 while (static_cast<sequence_index_t>(end - begin) >= minimum_length)
319 difference_type n = end - begin;
321 RandomAccessIterator pivot_pos = begin + rng(n);
323 // Swap pivot_pos value to end.
324 if (pivot_pos != (end - 1))
325 std::swap(*pivot_pos, *(end - 1));
326 pivot_pos = end - 1;
328 // XXX Comparator must have first_value_type, second_value_type, result_type
329 // Comparator == __gnu_parallel::lexicographic<S, int, __gnu_parallel::less<S, S> >
330 // pivot_pos == std::pair<S, int>*
331 // XXX binder2nd only for RandomAccessIterators??
332 __gnu_parallel::binder2nd<Comparator, value_type, value_type, bool> pred(comp, *pivot_pos);
334 // Divide, leave pivot unchanged in last place.
335 RandomAccessIterator split_pos1, split_pos2;
336 split_pos1 = begin + parallel_partition(begin, end - 1, pred, get_max_threads());
338 // Left side: < pivot_pos; right side: >= pivot_pos
340 // Swap pivot back to middle.
341 if (split_pos1 != pivot_pos)
342 std::swap(*split_pos1, *pivot_pos);
343 pivot_pos = split_pos1;
345 // In case all elements are equal, split_pos1 == 0
346 if ((split_pos1 + 1 - begin) < (n >> 7) || (end - split_pos1) < (n >> 7))
348 // Very unequal split, one part smaller than one 128th
349 // elements not stricly larger than the pivot.
350 __gnu_parallel::unary_negate<__gnu_parallel::binder1st<Comparator, value_type, value_type, bool>, value_type> pred(__gnu_parallel::binder1st<Comparator, value_type, value_type, bool>(comp, *pivot_pos));
352 // Find other end of pivot-equal range.
353 split_pos2 = __gnu_sequential::partition(split_pos1 + 1, end, pred);
355 else
356 // Only skip the pivot.
357 split_pos2 = split_pos1 + 1;
359 // Compare iterators.
360 if (split_pos2 <= nth)
361 begin = split_pos2;
362 else if (nth < split_pos1)
363 end = split_pos1;
364 else
365 break;
368 // Only at most Settings::partition_minimal_n elements left.
369 __gnu_sequential::sort(begin, end, comp);
/** @brief Parallel implementation of std::partial_sort().
 *
 *  First places the element for position @c middle with
 *  parallel_nth_element(), then sorts the front part
 *  [begin, middle).
 *
 *  @param begin Begin iterator of input sequence.
 *  @param middle Sort until this position.
 *  @param end End iterator of input sequence.
 *  @param comp Comparator. */
template<typename RandomAccessIterator, typename Comparator>
void
parallel_partial_sort(RandomAccessIterator begin, RandomAccessIterator middle,
                      RandomAccessIterator end, Comparator comp)
{
  parallel_nth_element(begin, middle, end, comp);
  std::sort(begin, middle, comp);
}
385 } //namespace __gnu_parallel
387 #undef _GLIBCXX_VOLATILE
389 #endif