// Provenance (cgit patch export):
//   From a92271176a19e06611099c0eccc4e6a6887f4915 Mon Sep 17 00:00:00 2001
//   From: Adrian Kummerlaender
//   Date: Mon, 17 May 2021 00:30:13 +0200
//   Subject: Extract public version of SweepLB
//   src/simd/256.h | 379 insertions(+), new file (index 0000000..a3f419d)
#pragma once

#include <immintrin.h>

#include <cstddef>
#include <cstdint>
#include <type_traits>

namespace simd {

/// Per-lane boolean mask matching the lane layout of Pack<T>.
template <typename T>
class Mask;

/// Mask of four 64-bit lanes held in one __m256i.
///
/// A lane is "true" when it holds true_v (only the most significant bit set)
/// and "false" when it holds false_v (zero).
template <>
class Mask<double> {
private:
  __m256i _reg;

public:
  using storage_t = std::uint64_t;
  static constexpr unsigned storage_size = 1;

  // Built from an unsigned one so the shift never touches a signed sign bit
  // (and never exceeds the width of `long` on platforms where it is 32-bit).
  static constexpr storage_t true_v  = storage_t{1} << 63;
  static constexpr storage_t false_v = storage_t{0};

  /// Encodes a bool as the lane value used by this mask.
  static storage_t encode(bool value) {
    return value ? true_v : false_v;
  }

  /// Encodes the bool that value points to.
  static storage_t encode(bool* value) {
    return encode(*value);
  }

  /// Builds a mask from four bools; a ends up in the lowest lane.
  Mask(bool a, bool b, bool c, bool d):
    _reg(_mm256_set_epi64x(encode(d),encode(c),encode(b),encode(a))) { }

  /// Builds a mask from four raw lane values; a ends up in the lowest lane.
  Mask(std::uint64_t a, std::uint64_t b, std::uint64_t c, std::uint64_t d):
    _reg(_mm256_set_epi64x(d,c,b,a)) { }

  /// Loads four raw lane values from ptr (unaligned load).
  Mask(std::uint64_t* ptr):
    _reg(_mm256_loadu_si256(reinterpret_cast<__m256i*>(ptr))) { }

  /// Loads four raw lane values starting at ptr[iCell].
  Mask(storage_t* ptr, std::size_t iCell):
    Mask(ptr + iCell) { }

  /// Wraps an existing register without modifying it.
  Mask(__m256i reg):
    _reg(reg) { }

  /// Exposes the underlying register.
  operator __m256i() const {
    return _reg;
  }

  /// Returns true_v - lane for every lane, i.e. true_v becomes 0 and
  /// 0 becomes true_v.
  __m256i neg() const {
    return _mm256_sub_epi64(_mm256_set1_epi64x(static_cast<long long>(true_v)), _reg);
  }

  /// True if at least one lane is exactly equal to true_v.
  operator bool() const {
    // Copy the lanes out instead of reading the register through a
    // differently typed pointer.
    storage_t lanes[4];
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(lanes), _reg);
    for (storage_t lane : lanes) {
      if (lane == true_v) {
        return true;
      }
    }
    return false;
  }
};

/// Mask of eight 32-bit lanes held in one __m256i.
///
/// A lane is "true" when it holds true_v (only the most significant bit set)
/// and "false" when it holds false_v (zero).
template <>
class Mask<float> {
private:
  __m256i _reg;

public:
  using storage_t = std::uint32_t;
  static constexpr unsigned storage_size = 1;

  // Unsigned one avoids shifting a signed int into its sign bit.
  static constexpr storage_t true_v  = storage_t{1} << 31;
  static constexpr storage_t false_v = storage_t{0};

  /// Encodes a bool as the lane value used by this mask.
  static storage_t encode(bool value) {
    return value ? true_v : false_v;
  }

  /// Encodes the bool that value points to.
  static storage_t encode(bool* value) {
    return encode(*value);
  }

  /// Loads eight raw lane values from ptr (unaligned load).
  Mask(storage_t* ptr):
    _reg(_mm256_loadu_si256(reinterpret_cast<__m256i*>(ptr))) { }

  /// Loads eight raw lane values starting at ptr[iCell].
  Mask(storage_t* ptr, std::size_t iCell):
    Mask(ptr + iCell) { }

  /// Wraps an existing register without modifying it.
  Mask(__m256i reg):
    _reg(reg) { }

  /// Exposes the underlying register.
  operator __m256i() const {
    return _reg;
  }

  /// Returns true_v - lane for every lane, i.e. true_v becomes 0 and
  /// 0 becomes true_v.
  __m256i neg() const {
    return _mm256_sub_epi32(_mm256_set1_epi32(static_cast<int>(true_v)), _reg);
  }

  /// True if at least one lane is exactly equal to true_v.
  operator bool() const {
    // Copy the lanes out instead of reading the register through a
    // differently typed pointer.
    storage_t lanes[8];
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(lanes), _reg);
    for (storage_t lane : lanes) {
      if (lane == true_v) {
        return true;
      }
    }
    return false;
  }
};


/// Fixed-width pack of T values held in one 256-bit register.
template <typename T>
class Pack;

/// Four doubles in one __m256d with element-wise arithmetic.
template <>
class Pack<double> {
private:
  __m256d _reg;

public:
  using mask_t = Mask<double>;
  using index_t = std::uint32_t;

  static constexpr std::size_t size = 4;

  /// Leaves the register uninitialized.
  Pack() = default;

  /// Wraps an existing register.
  Pack(__m256d reg):
    _reg(reg) { }

  /// Broadcasts val to all lanes.
  Pack(double val):
    Pack(_mm256_set1_pd(val)) { }

  /// Sets the lanes individually; a ends up in the lowest lane.
  Pack(double a, double b, double c, double d):
    Pack(_mm256_set_pd(d,c,b,a)) { }

  /// Loads `size` consecutive values from ptr (unaligned load) and issues a
  /// T2 prefetch hint for the address directly behind them.
  Pack(double* ptr):
    Pack(_mm256_loadu_pd(ptr)) {
    // Some toolchains declare the prefetch address as `const char*`.
    _mm_prefetch(reinterpret_cast<const char*>(ptr + size), _MM_HINT_T2);
  }

  /// Gathers ptr[idx[0]], ..., ptr[idx[3]]; idx is read with an unaligned load.
  Pack(double* ptr, index_t* idx):
    Pack(_mm256_i32gather_pd(ptr, _mm_loadu_si128(reinterpret_cast<__m128i*>(idx)), sizeof(double))) { }

  /// Exposes the underlying register.
  operator __m256d() const {
    return _reg;
  }

  Pack operator+(Pack rhs) const {
    return Pack(_mm256_add_pd(_reg, rhs));
  }

  Pack& operator+=(Pack rhs) {
    _reg = _mm256_add_pd(_reg, rhs);
    return *this;
  }

  Pack operator-(Pack rhs) const {
    return Pack(_mm256_sub_pd(_reg, rhs));
  }

  Pack& operator-=(Pack rhs) {
    _reg = _mm256_sub_pd(_reg, rhs);
    return *this;
  }

  Pack operator*(Pack rhs) const {
    return Pack(_mm256_mul_pd(_reg, rhs));
  }

  Pack& operator*=(Pack rhs) {
    _reg = _mm256_mul_pd(_reg, rhs);
    return *this;
  }

  Pack operator/(Pack rhs) const {
    return Pack(_mm256_div_pd(_reg, rhs));
  }

  Pack& operator/=(Pack rhs) {
    _reg = _mm256_div_pd(_reg, rhs);
    return *this;
  }

  /// Negation, computed as a multiplication by -1 in every lane.
  Pack operator-() const {
    return *this * Pack(-1.);
  }

  /// Element-wise square root, returned as a raw register.
  __m256d sqrt() const {
    return _mm256_sqrt_pd(_reg);
  }
};

/// Eight floats in one __m256 with element-wise arithmetic.
template <>
class Pack<float> {
private:
  __m256 _reg;

public:
  using mask_t = Mask<float>;
  using index_t = std::uint32_t;

  static constexpr std::size_t size = 8;

  /// Leaves the register uninitialized.
  Pack() = default;

  /// Wraps an existing register.
  Pack(__m256 reg):
    _reg(reg) { }

  /// Broadcasts val to all lanes.
  Pack(float val):
    Pack(_mm256_set1_ps(val)) { }

  /// Sets the lanes individually; a ends up in the lowest lane.
  Pack(float a, float b, float c, float d, float e, float f, float g, float h):
    Pack(_mm256_set_ps(h,g,f,e,d,c,b,a)) { }

  /// Loads `size` consecutive values from ptr (unaligned load) and issues a
  /// T2 prefetch hint for the address directly behind them.
  Pack(float* ptr):
    Pack(_mm256_loadu_ps(ptr)) {
    // Some toolchains declare the prefetch address as `const char*`.
    _mm_prefetch(reinterpret_cast<const char*>(ptr + size), _MM_HINT_T2);
  }

  /// Gathers ptr[idx[0]], ..., ptr[idx[7]]; idx is read with an unaligned load.
  Pack(float* ptr, index_t* idx):
    Pack(_mm256_i32gather_ps(ptr, _mm256_loadu_si256(reinterpret_cast<__m256i*>(idx)), sizeof(float))) { }

  /// Exposes the underlying register.
  operator __m256() const {
    return _reg;
  }

  Pack operator+(Pack rhs) const {
    return Pack(_mm256_add_ps(_reg, rhs));
  }

  Pack& operator+=(Pack rhs) {
    _reg = _mm256_add_ps(_reg, rhs);
    return *this;
  }

  Pack operator-(Pack rhs) const {
    return Pack(_mm256_sub_ps(_reg, rhs));
  }

  Pack& operator-=(Pack rhs) {
    _reg = _mm256_sub_ps(_reg, rhs);
    return *this;
  }

  Pack operator*(Pack rhs) const {
    return Pack(_mm256_mul_ps(_reg, rhs));
  }

  Pack& operator*=(Pack rhs) {
    _reg = _mm256_mul_ps(_reg, rhs);
    return *this;
  }

  Pack operator/(Pack rhs) const {
    return Pack(_mm256_div_ps(_reg, rhs));
  }

  Pack& operator/=(Pack rhs) {
    _reg = _mm256_div_ps(_reg, rhs);
    return *this;
  }

  /// Negation, computed as a multiplication by -1 in every lane.
  Pack operator-() const {
    return *this * Pack(-1.f);
  }

  /// Element-wise square root, returned as a raw register.
  __m256 sqrt() const {
    return _mm256_sqrt_ps(_reg);
  }
};


// Mixed scalar/pack arithmetic: the scalar operand is broadcast to a pack
// and the element-wise pack operator is applied.

template <typename T>
Pack<T> operator+(T lhs, Pack<T> rhs) {
  return Pack<T>(lhs) + rhs;
}

template <typename T>
Pack<T> operator+(Pack<T> lhs, T rhs) {
  return lhs + Pack<T>(rhs);
}

template <typename T>
Pack<T> operator-(T lhs, Pack<T> rhs) {
  return Pack<T>(lhs) - rhs;
}

template <typename T>
Pack<T> operator-(Pack<T> lhs, T rhs) {
  return lhs - Pack<T>(rhs);
}

template <typename T>
Pack<T> operator*(Pack<T> lhs, T rhs) {
  return lhs * Pack<T>(rhs);
}

template <typename T>
Pack<T> operator*(T lhs, Pack<T> rhs) {
  return Pack<T>(lhs) * rhs;
}

template <typename T>
Pack<T> operator/(Pack<T> lhs, T rhs) {
  return lhs / Pack<T>(rhs);
}

template <typename T>
Pack<T> operator/(T lhs, Pack<T> rhs) {
  return Pack<T>(lhs) / rhs;
}

/// Element-wise square root of a pack.
template <typename T>
Pack<T> sqrt(Pack<T> x) {
  return x.sqrt();
}

/// Writes the lanes of value selected by mask to target via the AVX
/// maskstore intrinsic.
template <typename T>
void maskstore(T* target, Mask<T> mask, Pack<T> value);

// The explicit specializations below are `inline`: a full specialization of a
// function template is an ordinary function, so defining it in a header
// without `inline` yields multiple definitions once two translation units
// include this file.

template <>
inline void maskstore<double>(double* target, Mask<double> mask, Pack<double> value) {
  _mm256_maskstore_pd(target, mask, value);
}

template <>
inline void maskstore<float>(float* target, Mask<float> mask, Pack<float> value) {
  _mm256_maskstore_ps(target, mask, value);
}


/// Writes all lanes of value to consecutive elements starting at target
/// (unaligned store).
template <typename T>
void store(T* target, Pack<T> value);

template <>
inline void store<double>(double* target, Pack<double> value) {
  _mm256_storeu_pd(target, value);
}

template <>
inline void store<float>(float* target, Pack<float> value) {
  _mm256_storeu_ps(target, value);
}

/// Scatters the lanes of value: lane i is written to target[indices[i]].
template <typename T>
void store(T* target, Pack<T> value, typename Pack<T>::index_t* indices);

template <>
inline void store<double>(double* target, Pack<double> value, Pack<double>::index_t* indices) {
  // The 256-bit scatter intrinsics need AVX512VL in addition to AVX512F.
#if defined(__AVX512F__) && defined(__AVX512VL__)
  _mm256_i32scatter_pd(target, _mm_loadu_si128(reinterpret_cast<__m128i*>(indices)), value, sizeof(double));
#else
  __m256d reg = value;
  #pragma GCC unroll 4
  for (std::size_t i=0; i < Pack<double>::size; ++i) {
    target[indices[i]] = reg[i];
  }
#endif
}

template <>
inline void store<float>(float* target, Pack<float> value, Pack<float>::index_t* indices) {
  // The 256-bit scatter intrinsics need AVX512VL in addition to AVX512F.
#if defined(__AVX512F__) && defined(__AVX512VL__)
  _mm256_i32scatter_ps(target, _mm256_loadu_si256(reinterpret_cast<__m256i*>(indices)), value, sizeof(float));
#else
  __m256 reg = value;
  #pragma GCC unroll 8
  for (std::size_t i=0; i < Pack<float>::size; ++i) {
    target[indices[i]] = reg[i];
  }
#endif
}

}
// -- cgit v1.2.3