Update concepts branch to revision 131834
[official-gcc.git] / libstdc++-v3 / include / parallel / multiway_mergesort.h
blob9d9733ad05f5f7c3c8bc41492ab5575f5cdadd3b
1 // -*- C++ -*-
3 // Copyright (C) 2007, 2008 Free Software Foundation, Inc.
4 //
5 // This file is part of the GNU ISO C++ Library. This library is free
6 // software; you can redistribute it and/or modify it under the terms
7 // of the GNU General Public License as published by the Free Software
8 // Foundation; either version 2, or (at your option) any later
9 // version.
11 // This library is distributed in the hope that it will be useful, but
12 // WITHOUT ANY WARRANTY; without even the implied warranty of
13 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 // General Public License for more details.
16 // You should have received a copy of the GNU General Public License
17 // along with this library; see the file COPYING. If not, write to
18 // the Free Software Foundation, 59 Temple Place - Suite 330, Boston,
19 // MA 02111-1307, USA.
21 // As a special exception, you may use this file as part of a free
22 // software library without restriction. Specifically, if other files
23 // instantiate templates or use macros or inline functions from this
24 // file, or you compile this file and link it with other files to
25 // produce an executable, this file does not by itself cause the
26 // resulting executable to be covered by the GNU General Public
27 // License. This exception does not however invalidate any other
28 // reasons why the executable file might be covered by the GNU General
29 // Public License.
/** @file parallel/multiway_mergesort.h
 *  @brief Parallel multiway merge sort.
 *  This file is a GNU parallel extension to the Standard C++ Library.
 */
36 // Written by Johannes Singler.
38 #ifndef _GLIBCXX_PARALLEL_MERGESORT_H
39 #define _GLIBCXX_PARALLEL_MERGESORT_H 1
41 #include <vector>
43 #include <parallel/basic_iterator.h>
44 #include <bits/stl_algo.h>
45 #include <parallel/parallel.h>
46 #include <parallel/multiway_merge.h>
48 namespace __gnu_parallel
/** @brief Subsequence description.
 *
 *  Holds the two positions delimiting a subsequence, both expressed in
 *  the difference type @c _DifferenceTp. */
template<typename _DifferenceTp>
  struct Piece
  {
    typedef _DifferenceTp difference_type;

    /** @brief Begin of subsequence. */
    difference_type begin;

    /** @brief End of subsequence. */
    difference_type end;
  };
64 /** @brief Data accessed by all threads.
66 * PMWMS = parallel multiway mergesort */
67 template<typename RandomAccessIterator>
68 struct PMWMSSortingData
70 typedef std::iterator_traits<RandomAccessIterator> traits_type;
71 typedef typename traits_type::value_type value_type;
72 typedef typename traits_type::difference_type difference_type;
74 /** @brief Number of threads involved. */
75 thread_index_t num_threads;
77 /** @brief Input begin. */
78 RandomAccessIterator source;
80 /** @brief Start indices, per thread. */
81 difference_type* starts;
83 /** @brief Storage in which to sort. */
84 value_type** temporary;
86 /** @brief Samples. */
87 value_type* samples;
89 /** @brief Offsets to add to the found positions. */
90 difference_type* offsets;
92 /** @brief Pieces of data to merge @c [thread][sequence] */
93 std::vector<Piece<difference_type> >* pieces;
96 /**
97 * @brief Select samples from a sequence.
98 * @param sd Pointer to algorithm data. Result will be placed in
99 * @c sd->samples.
100 * @param num_samples Number of samples to select.
102 template<typename RandomAccessIterator, typename _DifferenceTp>
103 void
104 determine_samples(PMWMSSortingData<RandomAccessIterator>* sd,
105 _DifferenceTp num_samples)
107 typedef std::iterator_traits<RandomAccessIterator> traits_type;
108 typedef typename traits_type::value_type value_type;
109 typedef _DifferenceTp difference_type;
111 thread_index_t iam = omp_get_thread_num();
113 difference_type* es = new difference_type[num_samples + 2];
115 equally_split(sd->starts[iam + 1] - sd->starts[iam],
116 num_samples + 1, es);
118 for (difference_type i = 0; i < num_samples; ++i)
119 ::new(&(sd->samples[iam * num_samples + i]))
120 value_type(sd->source[sd->starts[iam] + es[i + 1]]);
122 delete[] es;
/** @brief Split consistently.
 *
 *  Primary template, intentionally empty: the work is done by the
 *  partial specializations selected through the @c exact flag. */
template<bool exact, typename RandomAccessIterator,
         typename Comparator, typename SortingPlacesIterator>
  struct split_consistently
  {
  };
132 /** @brief Split by exact splitting. */
133 template<typename RandomAccessIterator, typename Comparator,
134 typename SortingPlacesIterator>
135 struct split_consistently
136 <true, RandomAccessIterator, Comparator, SortingPlacesIterator>
138 void operator()(
139 const thread_index_t iam,
140 PMWMSSortingData<RandomAccessIterator>* sd,
141 Comparator& comp,
142 const typename
143 std::iterator_traits<RandomAccessIterator>::difference_type
144 num_samples)
145 const
147 # pragma omp barrier
149 std::vector<std::pair<SortingPlacesIterator, SortingPlacesIterator> >
150 seqs(sd->num_threads);
151 for (thread_index_t s = 0; s < sd->num_threads; s++)
152 seqs[s] = std::make_pair(sd->temporary[s],
153 sd->temporary[s]
154 + (sd->starts[s + 1] - sd->starts[s]));
156 std::vector<SortingPlacesIterator> offsets(sd->num_threads);
158 // if not last thread
159 if (iam < sd->num_threads - 1)
160 multiseq_partition(seqs.begin(), seqs.end(),
161 sd->starts[iam + 1], offsets.begin(), comp);
163 for (int seq = 0; seq < sd->num_threads; seq++)
165 // for each sequence
166 if (iam < (sd->num_threads - 1))
167 sd->pieces[iam][seq].end = offsets[seq] - seqs[seq].first;
168 else
169 // very end of this sequence
170 sd->pieces[iam][seq].end =
171 sd->starts[seq + 1] - sd->starts[seq];
174 # pragma omp barrier
176 for (thread_index_t seq = 0; seq < sd->num_threads; seq++)
178 // For each sequence.
179 if (iam > 0)
180 sd->pieces[iam][seq].begin = sd->pieces[iam - 1][seq].end;
181 else
182 // Absolute beginning.
183 sd->pieces[iam][seq].begin = 0;
188 /** @brief Split by sampling. */
189 template<typename RandomAccessIterator, typename Comparator,
190 typename SortingPlacesIterator>
191 struct split_consistently<false, RandomAccessIterator, Comparator,
192 SortingPlacesIterator>
194 void operator()(
195 const thread_index_t iam,
196 PMWMSSortingData<RandomAccessIterator>* sd,
197 Comparator& comp,
198 const typename
199 std::iterator_traits<RandomAccessIterator>::difference_type
200 num_samples)
201 const
203 typedef std::iterator_traits<RandomAccessIterator> traits_type;
204 typedef typename traits_type::value_type value_type;
205 typedef typename traits_type::difference_type difference_type;
207 determine_samples(sd, num_samples);
209 # pragma omp barrier
211 # pragma omp single
212 __gnu_sequential::sort(sd->samples,
213 sd->samples + (num_samples * sd->num_threads),
214 comp);
216 # pragma omp barrier
218 for (thread_index_t s = 0; s < sd->num_threads; ++s)
220 // For each sequence.
221 if (num_samples * iam > 0)
222 sd->pieces[iam][s].begin =
223 std::lower_bound(sd->temporary[s],
224 sd->temporary[s]
225 + (sd->starts[s + 1] - sd->starts[s]),
226 sd->samples[num_samples * iam],
227 comp)
228 - sd->temporary[s];
229 else
230 // Absolute beginning.
231 sd->pieces[iam][s].begin = 0;
233 if ((num_samples * (iam + 1)) < (num_samples * sd->num_threads))
234 sd->pieces[iam][s].end =
235 std::lower_bound(sd->temporary[s],
236 sd->temporary[s]
237 + (sd->starts[s + 1] - sd->starts[s]),
238 sd->samples[num_samples * (iam + 1)],
239 comp)
240 - sd->temporary[s];
241 else
242 // Absolute end.
243 sd->pieces[iam][s].end = sd->starts[s + 1] - sd->starts[s];
/** @brief Sort a range, stably or not depending on @c stable.
 *
 *  Primary template, intentionally empty: the work is done by the
 *  partial specializations for @c stable == true and @c false. */
template<bool stable, typename RandomAccessIterator, typename Comparator>
  struct possibly_stable_sort
  {
  };
253 template<typename RandomAccessIterator, typename Comparator>
254 struct possibly_stable_sort<true, RandomAccessIterator, Comparator>
256 void operator()(const RandomAccessIterator& begin,
257 const RandomAccessIterator& end, Comparator& comp) const
259 __gnu_sequential::stable_sort(begin, end, comp);
263 template<typename RandomAccessIterator, typename Comparator>
264 struct possibly_stable_sort<false, RandomAccessIterator, Comparator>
266 void operator()(const RandomAccessIterator begin,
267 const RandomAccessIterator end, Comparator& comp) const
269 __gnu_sequential::sort(begin, end, comp);
/** @brief Multiway merge, stable or not depending on @c stable.
 *
 *  Primary template, intentionally empty: the work is done by the
 *  partial specializations for @c stable == true and @c false. */
template<bool stable, typename SeqRandomAccessIterator,
         typename RandomAccessIterator, typename Comparator,
         typename DiffType>
  struct possibly_stable_multiway_merge
  {
  };
280 template<typename SeqRandomAccessIterator, typename RandomAccessIterator,
281 typename Comparator, typename DiffType>
282 struct possibly_stable_multiway_merge
283 <true, SeqRandomAccessIterator, RandomAccessIterator, Comparator,
284 DiffType>
286 void operator()(const SeqRandomAccessIterator& seqs_begin,
287 const SeqRandomAccessIterator& seqs_end,
288 const RandomAccessIterator& target,
289 Comparator& comp,
290 DiffType length_am) const
292 stable_multiway_merge(seqs_begin, seqs_end, target, length_am, comp,
293 sequential_tag());
297 template<typename SeqRandomAccessIterator, typename RandomAccessIterator,
298 typename Comparator, typename DiffType>
299 struct possibly_stable_multiway_merge
300 <false, SeqRandomAccessIterator, RandomAccessIterator, Comparator,
301 DiffType>
303 void operator()(const SeqRandomAccessIterator& seqs_begin,
304 const SeqRandomAccessIterator& seqs_end,
305 const RandomAccessIterator& target,
306 Comparator& comp,
307 DiffType length_am) const
309 multiway_merge(seqs_begin, seqs_end, target, length_am, comp,
310 sequential_tag());
314 /** @brief PMWMS code executed by each thread.
315 * @param sd Pointer to algorithm data.
316 * @param comp Comparator.
318 template<bool stable, bool exact, typename RandomAccessIterator,
319 typename Comparator>
320 void
321 parallel_sort_mwms_pu(PMWMSSortingData<RandomAccessIterator>* sd,
322 Comparator& comp)
324 typedef std::iterator_traits<RandomAccessIterator> traits_type;
325 typedef typename traits_type::value_type value_type;
326 typedef typename traits_type::difference_type difference_type;
328 thread_index_t iam = omp_get_thread_num();
330 // Length of this thread's chunk, before merging.
331 difference_type length_local = sd->starts[iam + 1] - sd->starts[iam];
333 // Sort in temporary storage, leave space for sentinel.
335 typedef value_type* SortingPlacesIterator;
337 sd->temporary[iam] =
338 static_cast<value_type*>(
339 ::operator new(sizeof(value_type) * (length_local + 1)));
341 // Copy there.
342 std::uninitialized_copy(sd->source + sd->starts[iam],
343 sd->source + sd->starts[iam] + length_local,
344 sd->temporary[iam]);
346 possibly_stable_sort<stable, SortingPlacesIterator, Comparator>()
347 (sd->temporary[iam], sd->temporary[iam] + length_local, comp);
349 // Invariant: locally sorted subsequence in sd->temporary[iam],
350 // sd->temporary[iam] + length_local.
352 // No barrier here: Synchronization is done by the splitting routine.
354 difference_type num_samples =
355 _Settings::get().sort_mwms_oversampling * sd->num_threads - 1;
356 split_consistently
357 <exact, RandomAccessIterator, Comparator, SortingPlacesIterator>()
358 (iam, sd, comp, num_samples);
360 // Offset from target begin, length after merging.
361 difference_type offset = 0, length_am = 0;
362 for (thread_index_t s = 0; s < sd->num_threads; s++)
364 length_am += sd->pieces[iam][s].end - sd->pieces[iam][s].begin;
365 offset += sd->pieces[iam][s].begin;
368 typedef std::vector<
369 std::pair<SortingPlacesIterator, SortingPlacesIterator> >
370 seq_vector_type;
371 seq_vector_type seqs(sd->num_threads);
373 for (int s = 0; s < sd->num_threads; ++s)
375 seqs[s] =
376 std::make_pair(sd->temporary[s] + sd->pieces[iam][s].begin,
377 sd->temporary[s] + sd->pieces[iam][s].end);
380 possibly_stable_multiway_merge<
381 stable,
382 typename seq_vector_type::iterator,
383 RandomAccessIterator,
384 Comparator, difference_type>()
385 (seqs.begin(), seqs.end(),
386 sd->source + offset, comp,
387 length_am);
389 # pragma omp barrier
391 ::operator delete(sd->temporary[iam]);
394 /** @brief PMWMS main call.
395 * @param begin Begin iterator of sequence.
396 * @param end End iterator of sequence.
397 * @param comp Comparator.
398 * @param n Length of sequence.
399 * @param num_threads Number of threads to use.
401 template<bool stable, bool exact, typename RandomAccessIterator,
402 typename Comparator>
403 void
404 parallel_sort_mwms(RandomAccessIterator begin, RandomAccessIterator end,
405 Comparator comp,
406 thread_index_t num_threads)
408 _GLIBCXX_CALL(end - begin)
410 typedef std::iterator_traits<RandomAccessIterator> traits_type;
411 typedef typename traits_type::value_type value_type;
412 typedef typename traits_type::difference_type difference_type;
414 difference_type n = end - begin;
416 if (n <= 1)
417 return;
419 // at least one element per thread
420 if (num_threads > n)
421 num_threads = static_cast<thread_index_t>(n);
423 // shared variables
424 PMWMSSortingData<RandomAccessIterator> sd;
425 difference_type* starts;
427 # pragma omp parallel num_threads(num_threads)
429 num_threads = omp_get_num_threads(); //no more threads than requested
431 # pragma omp single
433 sd.num_threads = num_threads;
434 sd.source = begin;
436 sd.temporary = new value_type*[num_threads];
438 if (!exact)
440 difference_type size =
441 (_Settings::get().sort_mwms_oversampling * num_threads - 1)
442 * num_threads;
443 sd.samples = static_cast<value_type*>(
444 ::operator new(size * sizeof(value_type)));
446 else
447 sd.samples = NULL;
449 sd.offsets = new difference_type[num_threads - 1];
450 sd.pieces = new std::vector<Piece<difference_type> >[num_threads];
451 for (int s = 0; s < num_threads; ++s)
452 sd.pieces[s].resize(num_threads);
453 starts = sd.starts = new difference_type[num_threads + 1];
455 difference_type chunk_length = n / num_threads;
456 difference_type split = n % num_threads;
457 difference_type pos = 0;
458 for (int i = 0; i < num_threads; ++i)
460 starts[i] = pos;
461 pos += (i < split) ? (chunk_length + 1) : chunk_length;
463 starts[num_threads] = pos;
464 } //single
466 // Now sort in parallel.
467 parallel_sort_mwms_pu<stable, exact>(&sd, comp);
468 } //parallel
470 delete[] starts;
471 delete[] sd.temporary;
473 if (!exact)
474 ::operator delete(sd.samples);
476 delete[] sd.offsets;
477 delete[] sd.pieces;
479 } //namespace __gnu_parallel
481 #endif