#pragma once

#include <immintrin.h>

#include <cstddef>  // std::size_t
#include <cstdint>  // std::uint32_t, std::uint64_t

/// Thin value-type wrappers around 256-bit AVX/AVX2 registers.
///
/// `Pack<T>` holds floating-point lanes, `Mask<T>` holds the matching
/// per-lane boolean mask. Only `T = double` and `T = float` are provided.
namespace simd {

/// Per-lane boolean mask matching `Pack<T>`; only specializations are defined.
template <typename T> class Mask;

/// Mask for four 64-bit lanes.
///
/// A lane is "true" when it holds `true_v` (only the most significant bit
/// set) and "false" when it holds `false_v` (zero).
template <>
class Mask<double> {
private:
  __m256i _reg;

public:
  /// Integer type used to store one mask lane in memory.
  using storage_t = std::uint64_t;
  static constexpr unsigned storage_size = 1;

  // Shift a storage_t rather than a `long` literal: `long` is only 32 bits
  // wide on LLP64 platforms, where `1l << 63` would be undefined.
  static constexpr storage_t true_v  = storage_t{1} << 63;
  static constexpr storage_t false_v = storage_t{0};

  /// Translates a bool into its lane encoding.
  static storage_t encode(bool value)
  {
    return value ? true_v : false_v;
  }

  /// Translates the bool pointed to by `value` into its lane encoding.
  static storage_t encode(bool* value)
  {
    return encode(*value);
  }

  /// Builds a mask from four bools; `a` is placed in lane 0 (lowest element).
  Mask(bool a, bool b, bool c, bool d):
    _reg(_mm256_set_epi64x(encode(d), encode(c), encode(b), encode(a))) { }

  /// Builds a mask from four raw lane values; `a` is placed in lane 0.
  Mask(std::uint64_t a, std::uint64_t b, std::uint64_t c, std::uint64_t d):
    _reg(_mm256_set_epi64x(d, c, b, a)) { }

  /// Loads four consecutive lane values from `ptr` (no alignment required).
  Mask(std::uint64_t* ptr):
    _reg(_mm256_loadu_si256(reinterpret_cast<__m256i*>(ptr))) { }

  /// Loads four consecutive lane values starting at `ptr[iCell]`.
  Mask(storage_t* ptr, std::size_t iCell):
    Mask(ptr + iCell) { }

  /// Wraps an existing register.
  Mask(__m256i reg):
    _reg(reg) { }

  /// Exposes the underlying register.
  operator __m256i() const
  {
    return _reg;
  }

  /// Returns `true_v - lane` for every lane, i.e. swaps `true_v` and
  /// `false_v` in lanes that hold one of the two canonical encodings.
  __m256i neg() const
  {
    return _mm256_sub_epi64(_mm256_set1_epi64x(static_cast<long long>(true_v)), _reg);
  }

  /// True if at least one lane is exactly equal to `true_v`.
  operator bool() const
  {
    // Lane-wise equality followed by an "any bit set" test; this avoids
    // reading the register through a punned integer pointer.
    const __m256i hits =
      _mm256_cmpeq_epi64(_reg, _mm256_set1_epi64x(static_cast<long long>(true_v)));
    return _mm256_testz_si256(hits, hits) == 0;
  }
};

/// Mask for eight 32-bit lanes.
///
/// A lane is "true" when it holds `true_v` (only the most significant bit
/// set) and "false" when it holds `false_v` (zero).
template <>
class Mask<float> {
private:
  __m256i _reg;

public:
  /// Integer type used to store one mask lane in memory.
  using storage_t = std::uint32_t;
  static constexpr unsigned storage_size = 1;

  // Shift an unsigned storage_t instead of shifting into the sign bit of int.
  static constexpr storage_t true_v  = storage_t{1} << 31;
  static constexpr storage_t false_v = storage_t{0};

  /// Translates a bool into its lane encoding.
  static storage_t encode(bool value)
  {
    return value ? true_v : false_v;
  }

  /// Translates the bool pointed to by `value` into its lane encoding.
  static storage_t encode(bool* value)
  {
    return encode(*value);
  }

  /// Loads eight consecutive lane values from `ptr` (no alignment required).
  Mask(storage_t* ptr):
    _reg(_mm256_loadu_si256(reinterpret_cast<__m256i*>(ptr))) { }

  /// Loads eight consecutive lane values starting at `ptr[iCell]`.
  Mask(storage_t* ptr, std::size_t iCell):
    Mask(ptr + iCell) { }

  /// Wraps an existing register.
  Mask(__m256i reg):
    _reg(reg) { }

  /// Exposes the underlying register.
  operator __m256i() const
  {
    return _reg;
  }

  /// Returns `true_v - lane` for every lane, i.e. swaps `true_v` and
  /// `false_v` in lanes that hold one of the two canonical encodings.
  __m256i neg() const
  {
    return _mm256_sub_epi32(_mm256_set1_epi32(static_cast<int>(true_v)), _reg);
  }

  /// True if at least one lane is exactly equal to `true_v`.
  operator bool() const
  {
    // Lane-wise equality followed by an "any bit set" test; this avoids
    // reading the register through a punned integer pointer.
    const __m256i hits =
      _mm256_cmpeq_epi32(_reg, _mm256_set1_epi32(static_cast<int>(true_v)));
    return _mm256_testz_si256(hits, hits) == 0;
  }
};

/// Packed floating-point lanes of type `T`; only specializations are defined.
template <typename T> class Pack;

/// Four double-precision lanes held in a `__m256d`.
template <>
class Pack<double> {
private:
  __m256d _reg;

public:
  using mask_t  = Mask<double>;
  using index_t = std::uint32_t;

  /// Number of lanes.
  static constexpr std::size_t size = 4;

  /// Leaves the register uninitialized.
  Pack() = default;

  /// Wraps an existing register.
  Pack(__m256d reg):
    _reg(reg) { }

  /// Broadcasts `val` to all lanes.
  Pack(double val):
    Pack(_mm256_set1_pd(val)) { }

  /// Sets the lanes individually; `a` is placed in lane 0 (lowest element).
  Pack(double a, double b, double c, double d):
    Pack(_mm256_set_pd(d, c, b, a)) { }

  /// Loads `size` consecutive values from `ptr` (no alignment required) and
  /// issues a T2 prefetch hint for the address directly after them.
  Pack(double* ptr):
    Pack(_mm256_loadu_pd(ptr))
  {
    // Some compilers declare the prefetch address as `const char*`.
    _mm_prefetch(reinterpret_cast<const char*>(ptr + size), _MM_HINT_T2);
  }

  /// Gathers `ptr[idx[0]] ... ptr[idx[3]]`.
  /// NOTE(review): the gather instruction interprets indices as signed
  /// 32-bit integers, so `idx` values above INT32_MAX are presumably not
  /// usable here -- confirm against callers.
  Pack(double* ptr, index_t* idx):
    Pack(_mm256_i32gather_pd(ptr, _mm_loadu_si128(reinterpret_cast<__m128i*>(idx)), sizeof(double))) { }

  /// Exposes the underlying register.
  operator __m256d() const
  {
    return _reg;
  }

  Pack operator+(Pack rhs) const
  {
    return Pack(_mm256_add_pd(_reg, rhs));
  }

  Pack& operator+=(Pack rhs)
  {
    _reg = _mm256_add_pd(_reg, rhs);
    return *this;
  }

  Pack operator-(Pack rhs) const
  {
    return Pack(_mm256_sub_pd(_reg, rhs));
  }

  Pack& operator-=(Pack rhs)
  {
    _reg = _mm256_sub_pd(_reg, rhs);
    return *this;
  }

  Pack operator*(Pack rhs) const
  {
    return Pack(_mm256_mul_pd(_reg, rhs));
  }

  Pack& operator*=(Pack rhs)
  {
    _reg = _mm256_mul_pd(_reg, rhs);
    return *this;
  }

  Pack operator/(Pack rhs) const
  {
    return Pack(_mm256_div_pd(_reg, rhs));
  }

  Pack& operator/=(Pack rhs)
  {
    _reg = _mm256_div_pd(_reg, rhs);
    return *this;
  }

  /// Lane-wise negation, computed as multiplication by -1.
  Pack operator-() const
  {
    return *this * Pack(-1.0);
  }

  /// Lane-wise square root.
  __m256d sqrt() const
  {
    return _mm256_sqrt_pd(_reg);
  }
};

/// Eight single-precision lanes held in a `__m256`.
template <>
class Pack<float> {
private:
  __m256 _reg;

public:
  using mask_t  = Mask<float>;
  using index_t = std::uint32_t;

  /// Number of lanes.
  static constexpr std::size_t size = 8;

  /// Leaves the register uninitialized.
  Pack() = default;

  /// Wraps an existing register.
  Pack(__m256 reg):
    _reg(reg) { }

  /// Broadcasts `val` to all lanes.
  Pack(float val):
    Pack(_mm256_set1_ps(val)) { }

  /// Sets the lanes individually; `a` is placed in lane 0 (lowest element).
  Pack(float a, float b, float c, float d, float e, float f, float g, float h):
    Pack(_mm256_set_ps(h, g, f, e, d, c, b, a)) { }

  /// Loads `size` consecutive values from `ptr` (no alignment required) and
  /// issues a T2 prefetch hint for the address directly after them.
  Pack(float* ptr):
    Pack(_mm256_loadu_ps(ptr))
  {
    // Some compilers declare the prefetch address as `const char*`.
    _mm_prefetch(reinterpret_cast<const char*>(ptr + size), _MM_HINT_T2);
  }

  /// Gathers `ptr[idx[0]] ... ptr[idx[7]]`.
  /// NOTE(review): the gather instruction interprets indices as signed
  /// 32-bit integers, so `idx` values above INT32_MAX are presumably not
  /// usable here -- confirm against callers.
  Pack(float* ptr, index_t* idx):
    Pack(_mm256_i32gather_ps(ptr, _mm256_loadu_si256(reinterpret_cast<__m256i*>(idx)), sizeof(float))) { }

  /// Exposes the underlying register.
  operator __m256() const
  {
    return _reg;
  }

  Pack operator+(Pack rhs) const
  {
    return Pack(_mm256_add_ps(_reg, rhs));
  }

  Pack& operator+=(Pack rhs)
  {
    _reg = _mm256_add_ps(_reg, rhs);
    return *this;
  }

  Pack operator-(Pack rhs) const
  {
    return Pack(_mm256_sub_ps(_reg, rhs));
  }

  Pack& operator-=(Pack rhs)
  {
    _reg = _mm256_sub_ps(_reg, rhs);
    return *this;
  }

  Pack operator*(Pack rhs) const
  {
    return Pack(_mm256_mul_ps(_reg, rhs));
  }

  Pack& operator*=(Pack rhs)
  {
    _reg = _mm256_mul_ps(_reg, rhs);
    return *this;
  }

  Pack operator/(Pack rhs) const
  {
    return Pack(_mm256_div_ps(_reg, rhs));
  }

  Pack& operator/=(Pack rhs)
  {
    _reg = _mm256_div_ps(_reg, rhs);
    return *this;
  }

  /// Lane-wise negation, computed as multiplication by -1.
  Pack operator-() const
  {
    return *this * Pack(-1.0f);
  }

  /// Lane-wise square root.
  __m256 sqrt() const
  {
    return _mm256_sqrt_ps(_reg);
  }
};

// Mixed scalar/pack arithmetic: the scalar operand is broadcast to a pack
// before the lane-wise operation is applied.

template <typename T>
Pack<T> operator+(T lhs, Pack<T> rhs)
{
  return Pack<T>(lhs) + rhs;
}

template <typename T>
Pack<T> operator+(Pack<T> lhs, T rhs)
{
  return lhs + Pack<T>(rhs);
}

template <typename T>
Pack<T> operator-(T lhs, Pack<T> rhs)
{
  return Pack<T>(lhs) - rhs;
}

template <typename T>
Pack<T> operator-(Pack<T> lhs, T rhs)
{
  return lhs - Pack<T>(rhs);
}

template <typename T>
Pack<T> operator*(Pack<T> lhs, T rhs)
{
  return lhs * Pack<T>(rhs);
}

template <typename T>
Pack<T> operator*(T lhs, Pack<T> rhs)
{
  return Pack<T>(lhs) * rhs;
}

template <typename T>
Pack<T> operator/(Pack<T> lhs, T rhs)
{
  return lhs / Pack<T>(rhs);
}

template <typename T>
Pack<T> operator/(T lhs, Pack<T> rhs)
{
  return Pack<T>(lhs) / rhs;
}

/// Lane-wise square root as a free function.
template <typename T>
Pack<T> sqrt(Pack<T> x)
{
  return x.sqrt();
}

// The explicit specializations below are marked `inline`: unlike templates,
// explicit specializations of function templates are ordinary functions, so
// defining them non-inline in a header yields multiple-definition errors as
// soon as the header is included from more than one translation unit.

/// Stores the lanes of `value` selected by `mask` to consecutive elements
/// starting at `target`, using the corresponding maskstore intrinsic.
template <typename T>
void maskstore(T* target, Mask<T> mask, Pack<T> value);

template <>
inline void maskstore<double>(double* target, Mask<double> mask, Pack<double> value)
{
  _mm256_maskstore_pd(target, mask, value);
}

template <>
inline void maskstore<float>(float* target, Mask<float> mask, Pack<float> value)
{
  _mm256_maskstore_ps(target, mask, value);
}

/// Stores all lanes of `value` to consecutive elements starting at `target`
/// (no alignment required).
template <typename T>
void store(T* target, Pack<T> value);

template <>
inline void store<double>(double* target, Pack<double> value)
{
  _mm256_storeu_pd(target, value);
}

template <>
inline void store<float>(float* target, Pack<float> value)
{
  _mm256_storeu_ps(target, value);
}

/// Scatters the lanes of `value`: lane `i` is written to `target[indices[i]]`.
template <typename T>
void store(T* target, Pack<T> value, typename Pack<T>::index_t* indices);

template <>
inline void store<double>(double* target, Pack<double> value, Pack<double>::index_t* indices)
{
  // The 256-bit scatter intrinsics require AVX512VL in addition to AVX512F.
#if defined(__AVX512F__) && defined(__AVX512VL__)
  _mm256_i32scatter_pd(target, _mm_loadu_si128(reinterpret_cast<__m128i*>(indices)), value, sizeof(double));
#else
  __m256d reg = value;
  #pragma GCC unroll 4
  for (unsigned i=0; i < simd::Pack<double>::size; ++i) {
    target[indices[i]] = reg[i];
  }
#endif
}

template <>
inline void store<float>(float* target, Pack<float> value, Pack<float>::index_t* indices)
{
  // The 256-bit scatter intrinsics require AVX512VL in addition to AVX512F.
#if defined(__AVX512F__) && defined(__AVX512VL__)
  _mm256_i32scatter_ps(target, _mm256_loadu_si256(reinterpret_cast<__m256i*>(indices)), value, sizeof(float));
#else
  __m256 reg = value;
  #pragma GCC unroll 8
  for (unsigned i=0; i < simd::Pack<float>::size; ++i) {
    target[indices[i]] = reg[i];
  }
#endif
}

}