Project revived from Feb2017
[EroSomnia.git] / deps / boost_1_63_0 / boost / compute / algorithm / detail / scan_on_cpu.hpp
blobd81117c65f539b1ecda353f5bf9cd49ff5997d5f
1 //---------------------------------------------------------------------------//
2 // Copyright (c) 2016 Jakub Szuppe <j.szuppe@gmail.com>
3 //
4 // Distributed under the Boost Software License, Version 1.0
5 // See accompanying file LICENSE_1_0.txt or copy at
6 // http://www.boost.org/LICENSE_1_0.txt
7 //
8 // See http://boostorg.github.com/compute for more information.
9 //---------------------------------------------------------------------------//
11 #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_SCAN_ON_CPU_HPP
12 #define BOOST_COMPUTE_ALGORITHM_DETAIL_SCAN_ON_CPU_HPP
14 #include <iterator>
16 #include <boost/compute/device.hpp>
17 #include <boost/compute/kernel.hpp>
18 #include <boost/compute/command_queue.hpp>
19 #include <boost/compute/algorithm/detail/serial_scan.hpp>
20 #include <boost/compute/detail/meta_kernel.hpp>
21 #include <boost/compute/detail/iterator_range_size.hpp>
22 #include <boost/compute/detail/parameter_cache.hpp>
24 namespace boost {
25 namespace compute {
26 namespace detail {
28 template<class InputIterator, class OutputIterator, class T, class BinaryOperator>
29 inline OutputIterator scan_on_cpu(InputIterator first,
30 InputIterator last,
31 OutputIterator result,
32 bool exclusive,
33 T init,
34 BinaryOperator op,
35 command_queue &queue)
37 typedef typename
38 std::iterator_traits<InputIterator>::value_type input_type;
39 typedef typename
40 std::iterator_traits<OutputIterator>::value_type output_type;
42 const context &context = queue.get_context();
43 const device &device = queue.get_device();
44 const size_t compute_units = queue.get_device().compute_units();
46 boost::shared_ptr<parameter_cache> parameters =
47 detail::parameter_cache::get_global_cache(device);
49 std::string cache_key =
50 "__boost_scan_cpu_" + boost::lexical_cast<std::string>(sizeof(T));
52 // for inputs smaller than serial_scan_threshold
53 // serial_scan algorithm is used
54 uint_ serial_scan_threshold =
55 parameters->get(cache_key, "serial_scan_threshold", 16384 * sizeof(T));
56 serial_scan_threshold =
57 (std::max)(serial_scan_threshold, uint_(compute_units));
59 size_t count = detail::iterator_range_size(first, last);
60 if(count == 0){
61 return result;
63 else if(count < serial_scan_threshold) {
64 return serial_scan(first, last, result, exclusive, init, op, queue);
67 buffer block_partial_sums(context, sizeof(output_type) * compute_units );
69 // create scan kernel
70 meta_kernel k("scan_on_cpu_block_scan");
72 // Arguments
73 size_t count_arg = k.add_arg<uint_>("count");
74 size_t init_arg = k.add_arg<output_type>("initial_value");
75 size_t block_partial_sums_arg =
76 k.add_arg<output_type *>(memory_object::global_memory, "block_partial_sums");
78 k <<
79 "uint block = " <<
80 "(uint)ceil(((float)count)/(get_global_size(0) + 1));\n" <<
81 "uint index = get_global_id(0) * block;\n" <<
82 "uint end = min(count, index + block);\n";
84 if(!exclusive){
85 k <<
86 k.decl<output_type>("sum") << " = " <<
87 first[k.var<uint_>("index")] << ";\n" <<
88 result[k.var<uint_>("index")] << " = sum;\n" <<
89 "index++;\n";
91 else {
92 k <<
93 k.decl<output_type>("sum") << ";\n" <<
94 "if(index == 0){\n" <<
95 "sum = initial_value;\n" <<
96 "}\n" <<
97 "else {\n" <<
98 "sum = " << first[k.var<uint_>("index")] << ";\n" <<
99 "index++;\n" <<
100 "}\n";
103 k <<
104 "while(index < end){\n" <<
105 // load next value
106 k.decl<const input_type>("value") << " = "
107 << first[k.var<uint_>("index")] << ";\n";
109 if(exclusive){
110 k <<
111 "if(get_global_id(0) == 0){\n" <<
112 result[k.var<uint_>("index")] << " = sum;\n" <<
113 "}\n";
115 k <<
116 "sum = " << op(k.var<output_type>("sum"),
117 k.var<output_type>("value")) << ";\n";
119 if(!exclusive){
120 k <<
121 "if(get_global_id(0) == 0){\n" <<
122 result[k.var<uint_>("index")] << " = sum;\n" <<
123 "}\n";
126 k <<
127 "index++;\n" <<
128 "}\n" << // end while
129 "block_partial_sums[get_global_id(0)] = sum;\n";
131 // compile scan kernel
132 kernel block_scan_kernel = k.compile(context);
134 // setup kernel arguments
135 block_scan_kernel.set_arg(count_arg, static_cast<uint_>(count));
136 block_scan_kernel.set_arg(init_arg, static_cast<output_type>(init));
137 block_scan_kernel.set_arg(block_partial_sums_arg, block_partial_sums);
139 // execute the kernel
140 size_t global_work_size = compute_units;
141 queue.enqueue_1d_range_kernel(block_scan_kernel, 0, global_work_size, 0);
143 // scan is done
144 if(compute_units < 2) {
145 return result + count;
148 // final scan kernel
149 meta_kernel l("scan_on_cpu_final_scan");
151 // Arguments
152 count_arg = l.add_arg<uint_>("count");
153 block_partial_sums_arg =
154 l.add_arg<output_type *>(memory_object::global_memory, "block_partial_sums");
156 l <<
157 "uint block = " <<
158 "(uint)ceil(((float)count)/(get_global_size(0) + 1));\n" <<
159 "uint index = block + get_global_id(0) * block;\n" <<
160 "uint end = min(count, index + block);\n" <<
162 k.decl<output_type>("sum") << " = block_partial_sums[0];\n" <<
163 "for(uint i = 0; i < get_global_id(0); i++) {\n" <<
164 "sum = " << op(k.var<output_type>("sum"),
165 k.var<output_type>("block_partial_sums[i + 1]")) << ";\n" <<
166 "}\n" <<
168 "while(index < end){\n";
169 if(exclusive){
170 l <<
171 l.decl<output_type>("value") << " = "
172 << first[k.var<uint_>("index")] << ";\n" <<
173 result[k.var<uint_>("index")] << " = sum;\n" <<
174 "sum = " << op(k.var<output_type>("sum"),
175 k.var<output_type>("value")) << ";\n";
177 else {
178 l <<
179 "sum = " << op(k.var<output_type>("sum"),
180 first[k.var<uint_>("index")]) << ";\n" <<
181 result[k.var<uint_>("index")] << " = sum;\n";
183 l <<
184 "index++;\n" <<
185 "}\n";
188 // compile scan kernel
189 kernel final_scan_kernel = l.compile(context);
191 // setup kernel arguments
192 final_scan_kernel.set_arg(count_arg, static_cast<uint_>(count));
193 final_scan_kernel.set_arg(block_partial_sums_arg, block_partial_sums);
195 // execute the kernel
196 global_work_size = compute_units;
197 queue.enqueue_1d_range_kernel(final_scan_kernel, 0, global_work_size, 0);
199 // return iterator pointing to the end of the result range
200 return result + count;
203 } // end detail namespace
204 } // end compute namespace
205 } // end boost namespace
207 #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_SCAN_ON_CPU_HPP