From f349f90243165482f0d4cd660319ae3cdb5df240 Mon Sep 17 00:00:00 2001 From: Magnus Lundborg Date: Thu, 7 Nov 2019 16:47:06 +0100 Subject: [PATCH] More SIMD preparations in the FE calculations Prepare the launch of the FE kernel using SIMD. So far the kernel is not modified to actually use SIMD. Change-Id: Iaad24fc37549b5deaa892655be1d2a7317f65955 --- src/gromacs/gmxlib/nonbonded/nb_free_energy.cpp | 237 ++++++++++++++++-------- 1 file changed, 157 insertions(+), 80 deletions(-) diff --git a/src/gromacs/gmxlib/nonbonded/nb_free_energy.cpp b/src/gromacs/gmxlib/nonbonded/nb_free_energy.cpp index ecd6aa7917..7785ab657a 100644 --- a/src/gromacs/gmxlib/nonbonded/nb_free_energy.cpp +++ b/src/gromacs/gmxlib/nonbonded/nb_free_energy.cpp @@ -51,66 +51,104 @@ #include "gromacs/mdtypes/forceoutput.h" #include "gromacs/mdtypes/forcerec.h" #include "gromacs/mdtypes/md_enums.h" +#include "gromacs/simd/simd.h" #include "gromacs/utility/fatalerror.h" +//! Scalar (non-SIMD) data types. +struct ScalarDataTypes +{ + using RealType = real; //!< The data type to use as real. + using IntType = int; //!< The data type to use as int. + static constexpr int simdRealWidth = 1; //!< The width of the RealType. + static constexpr int simdIntWidth = 1; //!< The width of the IntType. +}; + +#if GMX_SIMD_HAVE_REAL && GMX_SIMD_HAVE_INT32_ARITHMETICS +//! SIMD data types. +struct SimdDataTypes +{ + using RealType = gmx::SimdReal; //!< The data type to use as real. + using IntType = gmx::SimdInt32; //!< The data type to use as int. + static constexpr int simdRealWidth = GMX_SIMD_REAL_WIDTH; //!< The width of the RealType. + static constexpr int simdIntWidth = GMX_SIMD_FINT32_WIDTH; //!< The width of the IntType. +}; +#endif + //! Computes r^(1/p) and 1/r^(1/p) for the standard p=6 -static inline void pthRoot(const real r, real* pthRoot, real* invPthRoot) +template +static inline void pthRoot(const RealType r, RealType* pthRoot, RealType* invPthRoot) { *invPthRoot = gmx::invsqrt(std::cbrt(r)); *pthRoot = 1 / (*invPthRoot); } -static inline real calculateRinv6(const real rinvV) +template +static inline RealType calculateRinv6(const RealType rinvV) { - real rinv6 = rinvV * rinvV; + RealType rinv6 = rinvV * rinvV; return (rinv6 * rinv6 * rinv6); } -static inline real calculateVdw6(const real c6, const real rinv6) +template +static inline RealType calculateVdw6(const RealType c6, const RealType rinv6) { return (c6 * rinv6); } -static inline real calculateVdw12(const real c12, const real rinv6) +template +static inline RealType calculateVdw12(const RealType c12, const RealType rinv6) { return (c12 * rinv6 * rinv6); } /* reaction-field electrostatics */ -static inline real -reactionFieldScalarForce(const real qq, const real rinv, const real r, const real krf, const real two) +template +static inline RealType reactionFieldScalarForce(const RealType qq, + const RealType rinv, + const RealType r, + const real krf, + const real two) { return (qq * (rinv - two * krf * r * r)); } -static inline real reactionFieldPotential(const real qq, const real rinv, const real r, const real krf, const real potentialShift) +template +static inline RealType reactionFieldPotential(const RealType qq, + const RealType rinv, + const RealType r, + const real krf, + const real potentialShift) { return (qq * (rinv + krf * r * r - potentialShift)); } /* Ewald electrostatics */ -static inline real ewaldScalarForce(const real coulomb, const real rinv) +template +static inline RealType ewaldScalarForce(const RealType coulomb, const RealType rinv) { return (coulomb * rinv); } -static inline real ewaldPotential(const real coulomb, const real rinv, const real potentialShift) +template +static inline RealType ewaldPotential(const RealType coulomb, const RealType rinv, const real potentialShift) { return (coulomb * (rinv - potentialShift)); } /* cutoff LJ */ -static inline real lennardJonesScalarForce(const real v6, const real v12) +template +static inline RealType lennardJonesScalarForce(const RealType v6, const RealType v12) { return (v12 - v6); } -static inline real lennardJonesPotential(const real v6, - const real v12, - const real c6, - const real c12, - const real repulsionShift, - const real dispersionShift, - const real onesixth, - const real onetwelfth) +template +static inline RealType lennardJonesPotential(const RealType v6, + const RealType v12, + const RealType c6, + const RealType c12, + const real repulsionShift, + const real dispersionShift, + const real onesixth, + const real onetwelfth) { return ((v12 + c12 * repulsionShift) * onetwelfth - (v6 + c6 * dispersionShift) * onesixth); } @@ -122,13 +160,14 @@ static inline real ewaldLennardJonesGridSubtract(const real c6grid, const real p } /* LJ Potential switch */ -static inline real potSwitchScalarForceMod(const real fScalarInp, - const real potential, - const real sw, - const real r, - const real rVdw, - const real dsw, - const real zero) +template +static inline RealType potSwitchScalarForceMod(const RealType fScalarInp, + const RealType potential, + const RealType sw, + const RealType r, + const RealType rVdw, + const RealType dsw, + const real zero) { if (r < rVdw) { @@ -137,8 +176,12 @@ static inline real potSwitchScalarForceMod(const real fScalarInp, } return (zero); } -static inline real -potSwitchPotentialMod(const real potentialInp, const real sw, const real r, const real rVdw, const real zero) +template +static inline RealType potSwitchPotentialMod(const RealType potentialInp, + const RealType sw, + const RealType r, + const RealType rVdw, + const real zero) { if (r < rVdw) { @@ -150,7 +193,7 @@ potSwitchPotentialMod(const real potentialInp, const real sw, const real r, cons //! Templated free-energy non-bonded kernel -template +template static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist, rvec* gmx_restrict xx, gmx::ForceWithShiftForces* forceWithShiftForces, @@ -163,6 +206,10 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist, #define STATE_B 1 #define NSTATES 2 + using RealType = typename DataTypes::RealType; + using IntType = typename DataTypes::IntType; + + /* FIXME: How should these be handled with SIMD? */ constexpr real onetwelfth = 1.0 / 12.0; constexpr real onesixth = 1.0 / 6.0; constexpr real zero = 0.0; @@ -339,17 +386,17 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist, for (int k = nj0; k < nj1; k++) { - int tj[NSTATES]; - const int jnr = jjnr[k]; - const int j3 = 3 * jnr; - real c6[NSTATES], c12[NSTATES], qq[NSTATES], Vcoul[NSTATES], Vvdw[NSTATES]; - real r, rinv, rp, rpm2; - real alpha_vdw_eff, alpha_coul_eff, sigma6[NSTATES]; - const real dx = ix - x[j3]; - const real dy = iy - x[j3 + 1]; - const real dz = iz - x[j3 + 2]; - const real rsq = dx * dx + dy * dy + dz * dz; - real FscalC[NSTATES], FscalV[NSTATES]; + int tj[NSTATES]; + const int jnr = jjnr[k]; + const int j3 = 3 * jnr; + RealType c6[NSTATES], c12[NSTATES], qq[NSTATES], Vcoul[NSTATES], Vvdw[NSTATES]; + RealType r, rinv, rp, rpm2; + RealType alpha_vdw_eff, alpha_coul_eff, sigma6[NSTATES]; + const RealType dx = ix - x[j3]; + const RealType dy = iy - x[j3 + 1]; + const RealType dz = iz - x[j3 + 2]; + const RealType rsq = dx * dx + dy * dy + dz * dz; + RealType FscalC[NSTATES], FscalV[NSTATES]; if (rsq >= rcutoff_max2) { @@ -398,7 +445,7 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist, rp = 1; } - real Fscal = 0; + RealType Fscal = 0; qq[STATE_A] = iqA * chargeA[jnr]; qq[STATE_B] = iqB * chargeB[jnr]; @@ -454,7 +501,7 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist, Vcoul[i] = 0; Vvdw[i] = 0; - real rinvC, rinvV, rC, rV, rpinvC, rpinvV; + RealType rinvC, rinvV, rC, rV, rpinvC, rpinvV; /* Only spend time on A or B state if it is non-zero */ if ((qq[i] != 0) || (c6[i] != 0) || (c12[i] != 0)) @@ -518,7 +565,7 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist, || (!vdwInteractionTypeIsEwald && rV < rvdw); if ((c6[i] != 0 || c12[i] != 0) && computeVdwInteraction) { - real rinv6; + RealType rinv6; if (useSoftCore) { rinv6 = rpinvV; @@ -527,8 +574,8 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist, { rinv6 = calculateRinv6(rinvV); } - real Vvdw6 = calculateVdw6(c6[i], rinv6); - real Vvdw12 = calculateVdw12(c12[i], rinv6); + RealType Vvdw6 = calculateVdw6(c6[i], rinv6); + RealType Vvdw12 = calculateVdw12(c12[i], rinv6); Vvdw[i] = lennardJonesPotential(Vvdw6, Vvdw12, c6[i], c12[i], repulsionShift, dispersionShift, onesixth, onetwelfth); @@ -543,11 +590,12 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist, if (vdwModifierIsPotSwitch) { - real d = rV - ic->rvdw_switch; - d = (d > zero) ? d : zero; - const real d2 = d * d; - const real sw = one + d2 * d * (vdw_swV3 + d * (vdw_swV4 + d * vdw_swV5)); - const real dsw = d2 * (vdw_swF2 + d * (vdw_swF3 + d * vdw_swF4)); + RealType d = rV - ic->rvdw_switch; + d = (d > zero) ? d : zero; + const RealType d2 = d * d; + const RealType sw = + one + d2 * d * (vdw_swV3 + d * (vdw_swV4 + d * vdw_swV5)); + const RealType dsw = d2 * (vdw_swF2 + d * (vdw_swF3 + d * vdw_swF4)); FscalV[i] = potSwitchScalarForceMod(FscalV[i], Vvdw[i], sw, rV, rvdw, dsw, zero); @@ -595,7 +643,7 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist, * As there is no singularity, there is no need for soft-core. */ const real FF = -two * krf; - real VV = krf * rsq - crf; + RealType VV = krf * rsq - crf; if (ii == jnr) { @@ -622,11 +670,11 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist, */ real v_lr, f_lr; - const real ewrt = r * ewtabscale; - int ewitab = static_cast(ewrt); - const real eweps = ewrt - ewitab; - ewitab = 4 * ewitab; - f_lr = ewtab[ewitab] + eweps * ewtab[ewitab + 1]; + const RealType ewrt = r * ewtabscale; + IntType ewitab = static_cast(ewrt); + const RealType eweps = ewrt - ewitab; + ewitab = 4 * ewitab; + f_lr = ewtab[ewitab] + eweps * ewtab[ewitab + 1]; v_lr = (ewtab[ewitab + 2] - ewtabhalfspace * eweps * (ewtab[ewitab] + f_lr)); f_lr *= rinv; @@ -667,15 +715,16 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist, * r close to 0 for non-interacting pairs. */ - const real rs = rsq * rinv * ewtabscale; - const int ri = static_cast(rs); - const real frac = rs - ri; - const real f_lr = (1 - frac) * tab_ewald_F_lj[ri] + frac * tab_ewald_F_lj[ri + 1]; + const RealType rs = rsq * rinv * ewtabscale; + const IntType ri = static_cast(rs); + const RealType frac = rs - ri; + const RealType f_lr = (1 - frac) * tab_ewald_F_lj[ri] + frac * tab_ewald_F_lj[ri + 1]; /* TODO: Currently the Ewald LJ table does not contain * the factor 1/6, we should add this. */ - const real FF = f_lr * rinv / six; - real VV = (tab_ewald_V_lj[ri] - ewtabhalfspace * frac * (tab_ewald_F_lj[ri] + f_lr)) / six; + const RealType FF = f_lr * rinv / six; + RealType VV = + (tab_ewald_V_lj[ri] - ewtabhalfspace * frac * (tab_ewald_F_lj[ri] + f_lr)) / six; if (ii == jnr) { @@ -777,51 +826,74 @@ typedef void (*KernelFunction)(const t_nblist* gmx_restrict nlist, nb_kernel_data_t* gmx_restrict kernel_data, t_nrnb* gmx_restrict nrnb); +template +static KernelFunction dispatchKernelOnUseSimd(const bool useSimd) +{ + if (useSimd) + { +#if GMX_SIMD_HAVE_REAL && GMX_SIMD_HAVE_INT32_ARITHMETICS + /* FIXME: Here SimdDataTypes should be used to enable SIMD. So far, the code in nb_free_energy_kernel is not adapted to SIMD */ + return (nb_free_energy_kernel); +#else + return (nb_free_energy_kernel); +#endif + } + else + { + return (nb_free_energy_kernel); + } +} + template -static KernelFunction dispatchKernelOnVdwModifier(const bool vdwModifierIsPotSwitch) +static KernelFunction dispatchKernelOnVdwModifier(const bool vdwModifierIsPotSwitch, const bool useSimd) { if (vdwModifierIsPotSwitch) { - return (nb_free_energy_kernel); + return (dispatchKernelOnUseSimd(useSimd)); } else { - return (nb_free_energy_kernel); + return (dispatchKernelOnUseSimd(useSimd)); } } template static KernelFunction dispatchKernelOnElecInteractionType(const bool elecInteractionTypeIsEwald, - const bool vdwModifierIsPotSwitch) + const bool vdwModifierIsPotSwitch, + const bool useSimd) { if (elecInteractionTypeIsEwald) { return (dispatchKernelOnVdwModifier( - vdwModifierIsPotSwitch)); + vdwModifierIsPotSwitch, useSimd)); } else { return (dispatchKernelOnVdwModifier( - vdwModifierIsPotSwitch)); + vdwModifierIsPotSwitch, useSimd)); } } template static KernelFunction dispatchKernelOnVdwInteractionType(const bool vdwInteractionTypeIsEwald, const bool elecInteractionTypeIsEwald, - const bool vdwModifierIsPotSwitch) + const bool vdwModifierIsPotSwitch, + const bool useSimd) { if (vdwInteractionTypeIsEwald) { return (dispatchKernelOnElecInteractionType( - elecInteractionTypeIsEwald, vdwModifierIsPotSwitch)); + elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd)); } else { return (dispatchKernelOnElecInteractionType( - elecInteractionTypeIsEwald, vdwModifierIsPotSwitch)); + elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd)); } } @@ -829,17 +901,18 @@ template static KernelFunction dispatchKernelOnScLambdasOrAlphasDifference(const bool scLambdasOrAlphasDiffer, const bool vdwInteractionTypeIsEwald, const bool elecInteractionTypeIsEwald, - const bool vdwModifierIsPotSwitch) + const bool vdwModifierIsPotSwitch, + const bool useSimd) { if (scLambdasOrAlphasDiffer) { return (dispatchKernelOnVdwInteractionType( - vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch)); + vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd)); } else { return (dispatchKernelOnVdwInteractionType( - vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch)); + vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd)); } } @@ -847,19 +920,20 @@ static KernelFunction dispatchKernel(const bool scLambdasOrAlphasDiffer, const bool vdwInteractionTypeIsEwald, const bool elecInteractionTypeIsEwald, const bool vdwModifierIsPotSwitch, + const bool useSimd, const t_forcerec* fr) { if (fr->sc_alphacoul == 0 && fr->sc_alphavdw == 0) { return (dispatchKernelOnScLambdasOrAlphasDifference( scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, - vdwModifierIsPotSwitch)); + vdwModifierIsPotSwitch, useSimd)); } else { return (dispatchKernelOnScLambdasOrAlphasDifference( scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, - vdwModifierIsPotSwitch)); + vdwModifierIsPotSwitch, useSimd)); } } @@ -879,6 +953,7 @@ void gmx_nb_free_energy_kernel(const t_nblist* nlist, const bool elecInteractionTypeIsEwald = (EEL_PME_EWALD(fr->ic->eeltype)); const bool vdwModifierIsPotSwitch = (fr->ic->vdw_modifier == eintmodPOTSWITCH); bool scLambdasOrAlphasDiffer = true; + const bool useSimd = fr->use_simd_kernels; if (fr->sc_alphacoul == 0 && fr->sc_alphavdw == 0) { @@ -895,7 +970,9 @@ void gmx_nb_free_energy_kernel(const t_nblist* nlist, { GMX_RELEASE_ASSERT(false, "Unsupported soft-core r-power"); } - KernelFunction kernelFunc = dispatchKernel(scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, - elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, fr); + + KernelFunction kernelFunc; + kernelFunc = dispatchKernel(scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, + elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd, fr); kernelFunc(nlist, xx, ff, fr, mdatoms, kernel_data, nrnb); } -- 2.11.4.GIT