Simplify compiling GPU code for tests
[gromacs.git] / src / gromacs / mdspan / mdspan.h
blob10499d655dcf55f840a81c92f2111818d17d0eba
1 /*
2 * This file is part of the GROMACS molecular simulation package.
4 * Copyright (c) 2018,2019, by the GROMACS development team, led by
5 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6 * and including many others, as listed in the AUTHORS file in the
7 * top-level source directory and at http://www.gromacs.org.
9 * GROMACS is free software; you can redistribute it and/or
10 * modify it under the terms of the GNU Lesser General Public License
11 * as published by the Free Software Foundation; either version 2.1
12 * of the License, or (at your option) any later version.
14 * GROMACS is distributed in the hope that it will be useful,
15 * but WITHOUT ANY WARRANTY; without even the implied warranty of
16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17 * Lesser General Public License for more details.
19 * You should have received a copy of the GNU Lesser General Public
20 * License along with GROMACS; if not, see
21 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24 * If you want to redistribute modifications to GROMACS, please
25 * consider that scientific software is very special. Version
26 * control is crucial - bugs must be traceable. We will be happy to
27 * consider code for inclusion in the official distribution, but
28 * derived work must not be called official GROMACS. Details are found
29 * in the README & COPYING files - if they are missing, get the
30 * official version at http://www.gromacs.org.
32 * To help us fund GROMACS development, we humbly ask that you cite
33 * the research papers on the package. Check out http://www.gromacs.org.
36 * This file is a modified version of original work of Sandia Corporation.
37 * In the spirit of the original code, this particular file can be distributed
38 * on the terms of Sandia Corporation.
41 * Kokkos v. 2.0
42 * Copyright (2014) Sandia Corporation
44 * Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
45 * the U.S. Government retains certain rights in this software.
47 * Kokkos is licensed under 3-clause BSD terms of use:
49 * Redistribution and use in source and binary forms, with or without
50 * modification, are permitted provided that the following conditions are
51 * met:
53 * 1. Redistributions of source code must retain the above copyright
54 * notice, this list of conditions and the following disclaimer.
56 * 2. Redistributions in binary form must reproduce the above copyright
57 * notice, this list of conditions and the following disclaimer in the
58 * documentation and/or other materials provided with the distribution.
60 * 3. Neither the name of the Corporation nor the names of the
61 * contributors may be used to endorse or promote products derived from
62 * this software without specific prior written permission.
64 * THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
65 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
66 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
67 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
68 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
69 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
70 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
71 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
72 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
73 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
74 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
76 * Questions? Contact Christian R. Trott (crtrott@sandia.gov)
78 /*! \libinternal \file
79 * \brief Declares gmx::mdspan
81 * \author Christian Trott <crtrott@sandia.gov>
82 * \author Ronan Keryell <ronan.keryell@xilinx.com>
83 * \author Carter Edwards <hedwards@nvidia.com>
84 * \author David Hollman <dshollm@sandia.gov>
85 * \author Christian Blau <cblau@gwdg.de>
86 * \inlibraryapi
87 * \ingroup mdspan
89 #ifndef MDSPAN_MDSPAN_H
90 #define MDSPAN_MDSPAN_H
92 #include <array>
93 #include <type_traits>
95 #include "accessor_policy.h"
96 #include "extents.h"
97 #include "layouts.h"
99 namespace gmx
102 /*! \libinternal \brief Multidimensional array indexing and memory access with flexible mapping and access model.
104 * \tparam ElementType Type of elemnt to be viewed
105 * \tparam Extents The dimensions of the multidimenisonal array to view.
106 * \tparam LayoutPolicy Describes is the memory layout of the multidimensional array; right by default.
107 * \tparam AccessorPolicy Describes memory access model.
109 template<class ElementType, class Extents, class LayoutPolicy = layout_right, class AccessorPolicy = accessor_basic<ElementType>>
110 class basic_mdspan
112 public:
113 //! Expose type used to define the extents of the data.
114 using extents_type = Extents;
115 //! Expose type used to define the layout of the data.
116 using layout_type = LayoutPolicy;
117 //! Expose type used to define the memory access model of the data.
118 using accessor_type = AccessorPolicy;
119 //! Expose type used to map multidimensional indices to one-dimensioal indices.
120 using mapping_type = typename layout_type::template mapping<extents_type>;
121 //! Exposes the type of stored element.
122 using element_type = typename accessor_type::element_type;
123 //! Expose the underlying type of the stored elements.
124 using value_type = std::remove_cv_t<element_type>;
125 //! Expose the type used for indexing.
126 using index_type = ptrdiff_t;
127 //! Expose type for index differences.
128 using difference_type = ptrdiff_t;
129 //! Expose underlying pointer to data type.
130 using pointer = typename accessor_type::pointer;
131 //! Expose reference to data type.
132 using reference = typename accessor_type::reference;
134 //! Trivial constructor
135 constexpr basic_mdspan() noexcept : acc_(), map_(), ptr_() {}
136 //! Move constructor
137 constexpr basic_mdspan(basic_mdspan&& other) noexcept = default;
138 //! copy constructor
139 constexpr basic_mdspan(const basic_mdspan& other) noexcept = default;
140 //! Copy assignment
141 basic_mdspan& operator=(const basic_mdspan& other) noexcept = default;
142 //! Move assignment
143 basic_mdspan& operator=(basic_mdspan&& other) noexcept = default;
145 //! Copy constructor
146 template<class OtherElementType, class OtherExtents, class OtherLayoutPolicy, class OtherAccessor>
147 constexpr basic_mdspan(
148 const basic_mdspan<OtherElementType, OtherExtents, OtherLayoutPolicy, OtherAccessor>& rhs) noexcept :
149 acc_(rhs.acc_),
150 map_(rhs.map_),
151 ptr_(rhs.ptr_)
154 //! Copy assignment constructor
155 template<class OtherElementType, class OtherExtents, class OtherLayoutPolicy, class OtherAccessor>
156 basic_mdspan&
157 operator=(const basic_mdspan<OtherElementType, OtherExtents, OtherLayoutPolicy, OtherAccessor>& rhs) noexcept
159 acc_ = rhs.acc_;
160 map_ = rhs.map_;
161 ptr_ = rhs.ptr_;
162 return *this;
165 /*!\brief Construct mdspan by setting the dynamic extents and pointer to data.
166 * \param[in] ptr Pointer to data to be accessed by this span
167 * \param[in] DynamicExtents
168 * \tparam IndexType index type to describe dynamic extents
170 template<class... IndexType>
171 explicit constexpr basic_mdspan(pointer ptr, IndexType... DynamicExtents) noexcept :
172 acc_(accessor_type()),
173 map_(extents_type(DynamicExtents...)),
174 ptr_(ptr)
177 /*! \brief Construct from array describing dynamic extents.
178 * \param[in] ptr Pointer to data to be accessed by this span
179 * \param[in] dynamic_extents Array the size of dynamic extents.
181 constexpr basic_mdspan(pointer ptr,
182 const std::array<ptrdiff_t, extents_type::rank_dynamic()>& dynamic_extents) :
183 acc_(accessor_type()),
184 map_(extents_type(dynamic_extents)),
185 ptr_(ptr)
188 /*! \brief Construct from pointer and mapping.
189 * \param[in] ptr Pointer to data to be accessed by this span
190 * \param[in] m Mapping from multidimenisonal indices to one-dimensional offset.
192 constexpr basic_mdspan(pointer ptr, const mapping_type& m) noexcept :
193 acc_(accessor_type()),
194 map_(m),
195 ptr_(ptr)
198 /*! \brief Construct with pointer, mapping and accessor.
199 * \param[in] ptr Pointer to data to be accessed by this span
200 * \param[in] m Mapping from multidimenisonal indices to one-dimensional offset.
201 * \param[in] a Accessor implementing memory access model.
203 constexpr basic_mdspan(pointer ptr, const mapping_type& m, const accessor_type& a) noexcept :
204 acc_(a),
205 map_(m),
206 ptr_(ptr)
209 /*! \brief Construct mdspan from multidimensional arrays implemented with mdspan
211 * Requires the container to have a view_type describing the mdspan, which is
212 * accessible through an asView() call
214 * This allows functions to declare mdspans as arguments, but take e.g. multidimensional
215 * arrays implicitly during the function call
216 * \tparam U container type
217 * \param[in] other mdspan-implementing container
219 template<typename U, typename = std::enable_if_t<std::is_same<typename std::remove_reference_t<U>::view_type::element_type, ElementType>::value>>
220 constexpr basic_mdspan(U&& other) : basic_mdspan(other.asView())
223 /*! \brief Construct mdspan of const Elements from multidimensional arrays implemented with mdspan
225 * Requires the container to have a const_view_type describing the mdspan, which is
226 * accessible through an asConstView() call
228 * This allows functions to declare mdspans as arguments, but take e.g. multidimensional
229 * arrays implicitly during the function call
230 * \tparam U container type
231 * \param[in] other mdspan-implementing container
233 template<typename U, typename = std::enable_if_t<std::is_same<typename std::remove_reference_t<U>::const_view_type::element_type, ElementType>::value>>
234 constexpr basic_mdspan(const U& other) : basic_mdspan(other.asConstView())
237 /*! \brief Brace operator to access multidimensional array element.
238 * \param[in] indices The multidimensional indices of the object.
239 * Requires rank() == sizeof...(IndexType). Slicing is implemented via sub_span.
240 * \returns reference to element at indices.
242 template<class... IndexType>
243 constexpr std::enable_if_t<sizeof...(IndexType) == extents_type::rank(), reference>
244 operator()(IndexType... indices) const noexcept
246 return acc_.access(ptr_, map_(indices...));
248 /*! \brief Canonical bracket operator for one-dimensional arrays.
249 * Allows mdspan to act like array in one-dimension.
250 * Enabled only when rank==1.
251 * \param[in] i one-dimensional index
252 * \returns reference to element stored at position i
254 template<class IndexType>
255 constexpr std::enable_if_t<std::is_integral<IndexType>::value && extents_type::rank() == 1, reference>
256 operator[](const IndexType& i) const noexcept
258 return acc_.access(ptr_, map_(i));
260 /*! \brief Bracket operator for multi-dimensional arrays.
262 * \note Prefer operator() for better compile-time and run-time performance
264 * Slices two- and higher-dimensional arrays along a given slice by
265 * returning a new basic_mdspan that drops the first extent and indexes
266 * the remaining extents
268 * \note Currently only implemented for layout_right
269 * \note For layout_right this implementation has significant
270 * performance benefits over implementing a more general slicing
271 * operator with a strided layout
272 * \note Enabled only when rank() > 1
274 * \tparam IndexType integral tyoe for the index that enables indexing
275 * with, e.g., int or size_t
276 * \param[in] index one-dimensional index of the slice to be indexed
278 * \returns basic_mdspan that is sliced at the given index
280 template<class IndexType,
281 typename sliced_mdspan_type = basic_mdspan<element_type, decltype(extents_type().sliced_extents()), LayoutPolicy, AccessorPolicy>>
282 constexpr std::enable_if_t<std::is_integral<IndexType>::value && (extents_type::rank() > 1)
283 && std::is_same<LayoutPolicy, layout_right>::value,
284 sliced_mdspan_type>
285 operator[](const IndexType index) const noexcept
287 return sliced_mdspan_type(ptr_ + index * stride(0), extents().sliced_extents());
289 //! Report the rank.
290 static constexpr int rank() noexcept { return extents_type::rank(); }
291 //! Report the dynamic rank.
292 static constexpr int rank_dynamic() noexcept { return extents_type::rank_dynamic(); }
293 /*! \brief Return the static extent.
294 * \param[in] k dimension to query for static extent
295 * \returns static extent along specified dimension
297 constexpr index_type static_extent(size_t k) const noexcept
299 return map_.extents().static_extent(k);
302 /*! \brief Return the extent.
303 * \param[in] k dimension to query for extent
304 * \returns extent along specified dimension
306 constexpr index_type extent(int k) const noexcept { return map_.extents().extent(k); }
308 //! Return all extents
309 constexpr const extents_type& extents() const noexcept { return map_.extents(); }
310 //! Report if mappings for this basic_span is always unique.
311 static constexpr bool is_always_unique() noexcept { return mapping_type::is_always_unique(); }
312 //! Report if mapping for this basic_span is always strided
313 static constexpr bool is_always_strided() noexcept { return mapping_type::is_always_strided(); }
314 //! Report if mapping for this basic_span is always is_contiguous
315 static constexpr bool is_always_contiguous() noexcept
317 return mapping_type::is_always_contiguous();
319 //! Report if the currently applied map is unique
320 constexpr bool is_unique() const noexcept { return map_.is_unique(); }
321 //! Report if the currently applied map is strided
322 constexpr bool is_strided() const noexcept { return map_.is_strided(); }
323 //! Report if the currently applied map is contiguous
324 constexpr bool is_contiguous() const noexcept { return map_.is_contiguous(); }
325 //! Report stride along a specific rank.
326 constexpr index_type stride(size_t r) const noexcept { return map_.stride(r); }
327 //! Return the currently applied mapping.
328 constexpr mapping_type mapping() const noexcept { return map_; }
329 //! Return the memory access model.
330 constexpr accessor_type accessor() const noexcept { return acc_; }
331 //! Return pointer to underlying data
332 constexpr pointer data() const noexcept { return ptr_; }
334 private:
335 //! The memory access model
336 accessor_type acc_;
337 //! The transformation from multidimenisonal index to memory offset.
338 mapping_type map_;
339 //! Memory location handle
340 pointer ptr_;
343 //! basic_mdspan with wrapped indices, basic_accessor policiy and right-aligned memory layout.
344 template<class T, ptrdiff_t... Indices>
345 using mdspan = basic_mdspan<T, extents<Indices...>, layout_right, accessor_basic<T>>;
347 } // namespace gmx
349 #endif /* end of include guard: MDSPAN_MDSPAN_H */