diff options
author | Adrian Kummerlaender | 2021-05-17 00:30:13 +0200 |
---|---|---|
committer | Adrian Kummerlaender | 2021-05-17 00:30:13 +0200 |
commit | a92271176a19e06611099c0eccc4e6a6887f4915 (patch) | |
tree | 54067b334bfae7d99c79cfb00da5891334f9514c /src/simd/256.h | |
download | SweepLB-a92271176a19e06611099c0eccc4e6a6887f4915.tar SweepLB-a92271176a19e06611099c0eccc4e6a6887f4915.tar.gz SweepLB-a92271176a19e06611099c0eccc4e6a6887f4915.tar.bz2 SweepLB-a92271176a19e06611099c0eccc4e6a6887f4915.tar.lz SweepLB-a92271176a19e06611099c0eccc4e6a6887f4915.tar.xz SweepLB-a92271176a19e06611099c0eccc4e6a6887f4915.tar.zst SweepLB-a92271176a19e06611099c0eccc4e6a6887f4915.zip |
Extract public version of SweepLB
Diffstat (limited to 'src/simd/256.h')
-rw-r--r-- | src/simd/256.h | 379 |
1 files changed, 379 insertions, 0 deletions
diff --git a/src/simd/256.h b/src/simd/256.h new file mode 100644 index 0000000..a3f419d --- /dev/null +++ b/src/simd/256.h @@ -0,0 +1,379 @@ +#pragma once + +#include <immintrin.h> + +#include <cstdint> +#include <type_traits> + +namespace simd { + +template <std::floating_point T> +class Mask; + +template <> +class Mask<double> { +private: + __m256i _reg; + +public: + using storage_t = std::uint64_t; + static constexpr unsigned storage_size = 1; + + static constexpr storage_t true_v = 1l << 63; + static constexpr storage_t false_v = 0l; + + static storage_t encode(bool value) { + return value ? true_v : false_v; + } + + static storage_t encode(bool* value) { + return encode(*value); + } + + Mask(bool a, bool b, bool c, bool d): + _reg(_mm256_set_epi64x(encode(d),encode(c),encode(b),encode(a))) { } + + Mask(std::uint64_t a, std::uint64_t b, std::uint64_t c, std::uint64_t d): + _reg(_mm256_set_epi64x(d,c,b,a)) { } + + Mask(std::uint64_t* ptr): + _reg(_mm256_loadu_si256(reinterpret_cast<__m256i*>(ptr))) { } + + Mask(storage_t* ptr, std::size_t iCell): + Mask<double>(ptr + iCell) { } + + Mask(__m256i reg): + _reg(reg) { } + + operator __m256i() { + return _reg; + } + + __m256i neg() const { + return _mm256_sub_epi64(_mm256_set1_epi64x(true_v), _reg); + } + + operator bool() const { + const std::uint64_t* values = reinterpret_cast<const std::uint64_t*>(&_reg); + return values[0] == true_v + || values[1] == true_v + || values[2] == true_v + || values[3] == true_v; + } +}; + +template <> +class Mask<float> { +private: + __m256i _reg; + +public: + using storage_t = std::uint32_t; + static constexpr unsigned storage_size = 1; + + static constexpr storage_t true_v = 1 << 31; + static constexpr storage_t false_v = 0; + + static storage_t encode(bool value) { + return value ? true_v : false_v; + } + + static storage_t encode(bool* value) { + return encode(*value); + } + + Mask(storage_t* ptr): + _reg(_mm256_loadu_si256(reinterpret_cast<__m256i*>(ptr))) { } + + Mask(storage_t* ptr, std::size_t iCell): + Mask<float>(ptr + iCell) { } + + Mask(__m256i reg): + _reg(reg) { } + + operator __m256i() { + return _reg; + } + + __m256i neg() const { + return _mm256_sub_epi32(_mm256_set1_epi32(true_v), _reg); + } + + operator bool() const { + const std::uint32_t* values = reinterpret_cast<const std::uint32_t*>(&_reg); + return values[0] == true_v + || values[1] == true_v + || values[2] == true_v + || values[3] == true_v + || values[4] == true_v + || values[5] == true_v + || values[6] == true_v + || values[7] == true_v; + } +}; + + +template <std::floating_point T> +class Pack; + +template <> +class Pack<double> { +private: + __m256d _reg; + +public: + using mask_t = Mask<double>; + using index_t = std::uint32_t; + + static constexpr std::size_t size = 4; + + Pack() = default; + + Pack(__m256d reg): + _reg(reg) { } + + Pack(double val): + Pack(_mm256_set1_pd(val)) { } + + Pack(double a, double b, double c, double d): + Pack(_mm256_set_pd(d,c,b,a)) { } + + Pack(double* ptr): + Pack(_mm256_loadu_pd(ptr)) { + _mm_prefetch(ptr + size, _MM_HINT_T2); + } + + Pack(double* ptr, index_t* idx): + Pack(_mm256_i32gather_pd(ptr, _mm_loadu_si128(reinterpret_cast<__m128i*>(idx)), sizeof(double))) { } + + operator __m256d() { + return _reg; + } + + Pack operator+(Pack rhs) { + return Pack(_mm256_add_pd(_reg, rhs)); + } + + Pack& operator+=(Pack rhs) { + _reg = _mm256_add_pd(_reg, rhs); + return *this; + } + + Pack operator-(Pack rhs) { + return Pack(_mm256_sub_pd(_reg, rhs)); + } + + Pack& operator-=(Pack rhs) { + _reg = _mm256_sub_pd(_reg, rhs); + return *this; + } + + Pack operator*(Pack rhs) { + return Pack(_mm256_mul_pd(_reg, rhs)); + } + + Pack& operator*=(Pack rhs) { + _reg = _mm256_mul_pd(_reg, rhs); + return *this; + } + + Pack operator/(Pack rhs) { + return Pack(_mm256_div_pd(_reg, rhs)); + } + + Pack& operator/=(Pack rhs) { + _reg = _mm256_div_pd(_reg, rhs); + return *this; + } + + Pack operator-() { + return *this * Pack(-1.); + } + + __m256d sqrt() { + return _mm256_sqrt_pd(_reg); + } +}; + +template <> +class Pack<float> { +private: + __m256 _reg; + +public: + using mask_t = Mask<float>; + using index_t = std::uint32_t; + + static constexpr std::size_t size = 8; + + Pack() = default; + + Pack(__m256 reg): + _reg(reg) { } + + Pack(float val): + Pack(_mm256_set1_ps(val)) { } + + Pack(float a, float b, float c, float d, float e, float f, float g, float h): + Pack(_mm256_set_ps(h,g,f,e,d,c,b,a)) { } + + Pack(float* ptr): + Pack(_mm256_loadu_ps(ptr)) { + _mm_prefetch(ptr + size, _MM_HINT_T2); + } + + Pack(float* ptr, index_t* idx): + Pack(_mm256_i32gather_ps(ptr, _mm256_loadu_si256(reinterpret_cast<__m256i*>(idx)), sizeof(float))) { } + + operator __m256() { + return _reg; + } + + Pack operator+(Pack rhs) { + return Pack(_mm256_add_ps(_reg, rhs)); + } + + Pack& operator+=(Pack rhs) { + _reg = _mm256_add_ps(_reg, rhs); + return *this; + } + + Pack operator-(Pack rhs) { + return Pack(_mm256_sub_ps(_reg, rhs)); + } + + Pack& operator-=(Pack rhs) { + _reg = _mm256_sub_ps(_reg, rhs); + return *this; + } + + Pack operator*(Pack rhs) { + return Pack(_mm256_mul_ps(_reg, rhs)); + } + + Pack& operator*=(Pack rhs) { + _reg = _mm256_mul_ps(_reg, rhs); + return *this; + } + + Pack operator/(Pack rhs) { + return Pack(_mm256_div_ps(_reg, rhs)); + } + + Pack& operator/=(Pack rhs) { + _reg = _mm256_div_ps(_reg, rhs); + return *this; + } + + Pack operator-() { + return *this * Pack(-1.); + } + + __m256 sqrt() { + return _mm256_sqrt_ps(_reg); + } +}; + + +template <typename T> +Pack<T> operator+(T lhs, Pack<T> rhs) { + return Pack<T>(lhs) + rhs; +} + +template <typename T> +Pack<T> operator+(Pack<T> lhs, T rhs) { + return lhs + Pack<T>(rhs); +} + +template <typename T> +Pack<T> operator-(T lhs, Pack<T> rhs) { + return Pack<T>(lhs) - rhs; +} + +template <typename T> +Pack<T> operator-(Pack<T> lhs, T rhs) { + return lhs - Pack<T>(rhs); +} + +template <typename T> +Pack<T> operator*(Pack<T> lhs, T rhs) { + return lhs * Pack<T>(rhs); +} + +template <typename T> +Pack<T> operator*(T lhs, Pack<T> rhs) { + return Pack<T>(lhs) * rhs; +} + +template <typename T> +Pack<T> operator/(Pack<T> lhs, T rhs) { + return lhs / Pack<T>(rhs); +} + +template <typename T> +Pack<T> operator/(T lhs, Pack<T> rhs) { + return Pack<T>(lhs) / rhs; +} + +template <typename T> +Pack<T> sqrt(Pack<T> x) { + return x.sqrt(); +} + +template <std::floating_point T> +void maskstore(T* target, Mask<T> mask, Pack<T> value); + +template <> +void maskstore<double>(double* target, Mask<double> mask, Pack<double> value) { + _mm256_maskstore_pd(target, mask, value); +} + +template <> +void maskstore<float>(float* target, Mask<float> mask, Pack<float> value) { + _mm256_maskstore_ps(target, mask, value); +} + + +template <std::floating_point T> +void store(T* target, Pack<T> value); + +template <> +void store<double>(double* target, Pack<double> value) { + _mm256_storeu_pd(target, value); +} + +template <> +void store<float>(float* target, Pack<float> value) { + _mm256_storeu_ps(target, value); +} + +template <std::floating_point T> +void store(T* target, Pack<T> value, typename Pack<T>::index_t* indices); + +template <> +void store<double>(double* target, Pack<double> value, Pack<double>::index_t* indices) { +#ifdef __AVX512F__ + _mm256_i32scatter_pd(target, _mm_loadu_si128(reinterpret_cast<__m128i*>(indices)), value, sizeof(double)); +#else + __m256d reg = value; + #pragma GCC unroll 4 + for (unsigned i=0; i < simd::Pack<double>::size; ++i) { + target[indices[i]] = reg[i]; + } +#endif +} + +template <> +void store<float>(float* target, Pack<float> value, Pack<float>::index_t* indices) { +#ifdef __AVX512F__ + _mm256_i32scatter_ps(target, _mm256_loadu_si256(reinterpret_cast<__m256i*>(indices)), value, sizeof(float)); +#else + __m256 reg = value; + #pragma GCC unroll 8 + for (unsigned i=0; i < simd::Pack<float>::size; ++i) { + target[indices[i]] = reg[i]; + } +#endif +} + +} |