#pragma once

#include <cstddef>
#include <cstdint>
#include <immintrin.h>

namespace simd {

/// Per-lane predicate matching Pack<T>. Only Mask<double> (8 lanes) and
/// Mask<float> (16 lanes) are defined below.
template <typename T>
class Mask;

/// 8-lane mask used with Pack<double>: bit i corresponds to lane i.
template <>
class Mask<double> {
private:
    __mmask8 _reg;

public:
    /// Word type of a packed bit array holding one mask per word.
    using storage_t = std::uint8_t;
    /// Number of lanes (bits) in one storage_t word.
    static constexpr unsigned storage_size = 8;

    /// Packs `storage_size` flags into one word; bit j is set iff value[j].
    /// @param value pointer to at least `storage_size` booleans.
    /// @return the packed word.
    static storage_t encode(const bool* value) {
        storage_t mask = 0;
        for (unsigned j = 0; j < storage_size; ++j) {
            mask = static_cast<storage_t>(mask | (static_cast<unsigned>(value[j]) << j));
        }
        return mask;
    }

    /// Builds a mask lane by lane; argument bN becomes bit N.
    Mask(bool b0, bool b1, bool b2, bool b3, bool b4, bool b5, bool b6, bool b7)
        : _reg(static_cast<__mmask8>(b0 | b1 << 1 | b2 << 2 | b3 << 3 |
                                     b4 << 4 | b5 << 5 | b6 << 6 | b7 << 7)) {}

    /// Loads one 8-bit mask word from *ptr.
    // Exactly one byte is read: a 16-bit mask load here would touch the byte
    // after the word (out of bounds for a lone byte or the last array word).
    Mask(const std::uint8_t* ptr) : _reg(*ptr) {}

    /// Loads the word of the packed bit array starting at `ptr` that holds bit `iCell`.
    Mask(const storage_t* ptr, std::size_t iCell) : Mask(ptr + iCell / storage_size) {}

    /// Wraps an existing mask value.
    Mask(__mmask8 reg) : _reg(reg) {}

    /// Raw mask value, e.g. for passing to `_mm512_mask_*` intrinsics.
    operator __mmask8() const { return _reg; }

    /// Returns the lane-wise complement; *this is not modified.
    // Plain integer NOT: same value as _knot_mask8 without requiring AVX-512DQ.
    __mmask8 neg() const { return static_cast<__mmask8>(~_reg); }

    /// True if at least one lane is set.
    operator bool() const { return _reg != 0; }
};

/// 16-lane mask used with Pack<float>: bit i corresponds to lane i.
template <>
class Mask<float> {
private:
    __mmask16 _reg;

public:
    /// Word type of a packed bit array holding one mask per word.
    using storage_t = std::uint16_t;
    /// Number of lanes (bits) in one storage_t word.
    static constexpr unsigned storage_size = 16;

    /// Packs `storage_size` flags into one word; bit j is set iff value[j].
    /// @param value pointer to at least `storage_size` booleans.
    /// @return the packed word.
    static storage_t encode(const bool* value) {
        storage_t mask = 0;
        for (unsigned j = 0; j < storage_size; ++j) {
            mask = static_cast<storage_t>(mask | (static_cast<unsigned>(value[j]) << j));
        }
        return mask;
    }

    /// Loads one 16-bit mask word from *ptr.
    Mask(const std::uint16_t* ptr) : _reg(*ptr) {}

    /// Loads the word of the packed bit array starting at `ptr` that holds bit `iCell`.
    Mask(const storage_t* ptr, std::size_t iCell) : Mask(ptr + iCell / storage_size) {}

    /// Wraps an existing mask value.
    Mask(__mmask16 reg) : _reg(reg) {}

    /// Raw mask value, e.g. for passing to `_mm512_mask_*` intrinsics.
    operator __mmask16() const { return _reg; }

    /// Returns the lane-wise complement; *this is not modified.
    __mmask16 neg() const { return static_cast<__mmask16>(~_reg); }

    /// True if at least one lane is set.
    operator bool() const { return _reg != 0; }
};

/// A full 512-bit register of T. Only Pack<double> and Pack<float> are defined.
template <typename T>
class Pack;

/// Eight doubles held in one `__m512d` register.
template <>
class Pack<double> {
private:
    __m512d _reg;

public:
    using mask_t = Mask<double>;
    /// Element type of the index arrays used by gather/scatter.
    using index_t = std::uint32_t;
    /// Number of lanes.
    static constexpr std::size_t size = 8;

    /// Leaves the register uninitialized.
    Pack() = default;
    Pack(__m512d reg) : _reg(reg) {}
    /// Broadcasts `val` to every lane.
    Pack(double val) : Pack(_mm512_set1_pd(val)) {}
    /// Sets lanes in order: `a` is lane 0, `h` is lane 7.
    Pack(double a, double b, double c, double d, double e, double f, double g, double h)
        : Pack(_mm512_set_pd(h, g, f, e, d, c, b, a)) {}
    /// Loads `size` consecutive doubles from `ptr` (no alignment requirement).
    Pack(const double* ptr) : Pack(_mm512_loadu_pd(ptr)) {}
    /// Gathers lane i from ptr[idx[i]]; `idx` points to `size` indices.
    Pack(const double* ptr, const index_t* idx)
        : Pack(_mm512_i32gather_pd(_mm256_loadu_si256(reinterpret_cast<const __m256i*>(idx)),
                                   ptr, sizeof(double))) {}

    operator __m512d() const { return _reg; }

    Pack operator+(Pack rhs) const { return Pack(_mm512_add_pd(_reg, rhs)); }
    Pack& operator+=(Pack rhs) { _reg = _mm512_add_pd(_reg, rhs); return *this; }
    Pack operator-(Pack rhs) const { return Pack(_mm512_sub_pd(_reg, rhs)); }
    Pack& operator-=(Pack rhs) { _reg = _mm512_sub_pd(_reg, rhs); return *this; }
    Pack operator*(Pack rhs) const { return Pack(_mm512_mul_pd(_reg, rhs)); }
    Pack& operator*=(Pack rhs) { _reg = _mm512_mul_pd(_reg, rhs); return *this; }
    Pack operator/(Pack rhs) const { return Pack(_mm512_div_pd(_reg, rhs)); }
    Pack& operator/=(Pack rhs) { _reg = _mm512_div_pd(_reg, rhs); return *this; }

    /// Lane-wise negation, computed as multiplication by -1.
    Pack operator-() const { return *this * Pack(-1.0); }

    /// Lane-wise square root (returns the raw register).
    __m512d sqrt() const { return _mm512_sqrt_pd(_reg); }
};

/// Sixteen floats held in one `__m512` register.
template <>
class Pack<float> {
private:
    __m512 _reg;

public:
    using mask_t = Mask<float>;
    /// Element type of the index arrays used by gather/scatter.
    using index_t = std::uint32_t;
    /// Number of lanes.
    static constexpr std::size_t size = 16;

    /// Leaves the register uninitialized.
    Pack() = default;
    Pack(__m512 reg) : _reg(reg) {}
    /// Broadcasts `val` to every lane.
    Pack(float val) : Pack(_mm512_set1_ps(val)) {}
    /// Sets lanes in order: `a` is lane 0, `p` is lane 15.
    Pack(float a, float b, float c, float d, float e, float f, float g, float h,
         float i, float j, float k, float l, float m, float n, float o, float p)
        : Pack(_mm512_set_ps(p, o, n, m, l, k, j, i, h, g, f, e, d, c, b, a)) {}
    /// Loads `size` consecutive floats from `ptr` (no alignment requirement).
    Pack(const float* ptr) : Pack(_mm512_loadu_ps(ptr)) {}
    /// Gathers lane i from ptr[idx[i]]; `idx` points to `size` indices.
    Pack(const float* ptr, const index_t* idx)
        : Pack(_mm512_i32gather_ps(_mm512_loadu_si512(reinterpret_cast<const __m512i*>(idx)),
                                   ptr, sizeof(float))) {}

    operator __m512() const { return _reg; }

    Pack operator+(Pack rhs) const { return Pack(_mm512_add_ps(_reg, rhs)); }
    Pack& operator+=(Pack rhs) { _reg = _mm512_add_ps(_reg, rhs); return *this; }
    Pack operator-(Pack rhs) const { return Pack(_mm512_sub_ps(_reg, rhs)); }
    Pack& operator-=(Pack rhs) { _reg = _mm512_sub_ps(_reg, rhs); return *this; }
    Pack operator*(Pack rhs) const { return Pack(_mm512_mul_ps(_reg, rhs)); }
    Pack& operator*=(Pack rhs) { _reg = _mm512_mul_ps(_reg, rhs); return *this; }
    Pack operator/(Pack rhs) const { return Pack(_mm512_div_ps(_reg, rhs)); }
    Pack& operator/=(Pack rhs) { _reg = _mm512_div_ps(_reg, rhs); return *this; }

    /// Lane-wise negation, computed as multiplication by -1.
    Pack operator-() const { return *this * Pack(-1.0f); }

    /// Lane-wise square root (returns the raw register).
    __m512 sqrt() const { return _mm512_sqrt_ps(_reg); }
};

// Mixed scalar/pack arithmetic: the scalar is broadcast to every lane via
// Pack<T>(scalar). T is deduced from both operands, so the scalar type must
// match the pack's element type exactly.
template <typename T>
Pack<T> operator+(T lhs, Pack<T> rhs) { return Pack<T>(lhs) + rhs; }
template <typename T>
Pack<T> operator+(Pack<T> lhs, T rhs) { return lhs + Pack<T>(rhs); }
template <typename T>
Pack<T> operator-(T lhs, Pack<T> rhs) { return Pack<T>(lhs) - rhs; }
template <typename T>
Pack<T> operator-(Pack<T> lhs, T rhs) { return lhs - Pack<T>(rhs); }
template <typename T>
Pack<T> operator*(Pack<T> lhs, T rhs) { return lhs * Pack<T>(rhs); }
template <typename T>
Pack<T> operator*(T lhs, Pack<T> rhs) { return Pack<T>(lhs) * rhs; }
template <typename T>
Pack<T> operator/(Pack<T> lhs, T rhs) { return lhs / Pack<T>(rhs); }
template <typename T>
Pack<T> operator/(T lhs, Pack<T> rhs) { return Pack<T>(lhs) / rhs; }

// NOTE: explicit specializations of function templates are NOT implicitly
// inline; without `inline` every TU including this header would emit its own
// definition and the link would fail with multiple-definition errors.

/// Stores only the lanes of `value` whose bit is set in `mask` to target[i]
/// (unaligned store; unselected elements of `target` are left untouched).
template <typename T>
void maskstore(T* target, Mask<T> mask, Pack<T> value);

template <>
inline void maskstore<double>(double* target, Mask<double> mask, Pack<double> value) {
    _mm512_mask_storeu_pd(target, mask, value);
}

template <>
inline void maskstore<float>(float* target, Mask<float> mask, Pack<float> value) {
    _mm512_mask_storeu_ps(target, mask, value);
}

/// Stores all lanes of `value` contiguously at `target` (no alignment requirement).
template <typename T>
void store(T* target, Pack<T> value);

template <>
inline void store<double>(double* target, Pack<double> value) {
    _mm512_storeu_pd(target, value);
}

template <>
inline void store<float>(float* target, Pack<float> value) {
    _mm512_storeu_ps(target, value);
}

/// Scatters lane i of `value` to target[indices[i]]; `indices` points to
/// Pack<T>::size indices.
template <typename T>
void store(T* target, Pack<T> value, const typename Pack<T>::index_t* indices);

template <>
inline void store<double>(double* target, Pack<double> value,
                          const Pack<double>::index_t* indices) {
    _mm512_i32scatter_pd(target, _mm256_loadu_si256(reinterpret_cast<const __m256i*>(indices)),
                         value, sizeof(double));
}

template <>
inline void store<float>(float* target, Pack<float> value,
                         const Pack<float>::index_t* indices) {
    _mm512_i32scatter_ps(target, _mm512_loadu_si512(reinterpret_cast<const __m512i*>(indices)),
                         value, sizeof(float));
}

}  // namespace simd