summary refs log tree commit diff
path: root/src/simd/512.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/simd/512.h')
-rw-r--r-- src/simd/512.h 339
1 files changed, 339 insertions, 0 deletions
diff --git a/src/simd/512.h b/src/simd/512.h
new file mode 100644
index 0000000..2cc0a44
--- /dev/null
+++ b/src/simd/512.h
@@ -0,0 +1,339 @@
+#pragma once
+
+#include <immintrin.h>
+
+#include <concepts>
+#include <cstddef>
+#include <cstdint>
+#include <type_traits>
+
+namespace simd {
+
+
+/// Lane mask belonging to a SIMD pack of element type T.
+/// Only the explicit specializations below are defined.
+template <std::floating_point T>
+class Mask;
+
+/// Wrapper around an AVX-512 `__mmask8`: one bit per lane of an
+/// eight-lane double-precision pack.
+template <>
+class Mask<double> {
+private:
+    __mmask8 _reg;
+
+public:
+    /// Integer type holding one packed mask word (one bit per lane).
+    using storage_t = std::uint8_t;
+    /// Number of mask bits held by one storage_t.
+    static constexpr unsigned storage_size = 8;
+
+    /// Packs `storage_size` consecutive bools into one mask word.
+    /// value[j] becomes bit j of the result.
+    static storage_t encode(bool* value) {
+        storage_t mask = value[0];
+        for (unsigned j=1; j < storage_size; ++j) {
+            mask |= value[j] << j;
+        }
+        return mask;
+    }
+
+    /// Builds the mask from eight lane flags; b0 maps to bit 0, b7 to bit 7.
+    Mask(bool b0, bool b1, bool b2, bool b3, bool b4, bool b5, bool b6, bool b7):
+        _reg(static_cast<storage_t>(b0 | b1<<1 | b2<<2 | b3<<3 | b4<<4 | b5<<5 | b6<<6 | b7<<7)) { }
+
+    /// Loads one 8-bit mask word from ptr.
+    /// The 8-bit load is used on purpose: a 16-bit load through a
+    /// uint8_t pointer would read one byte beyond the addressed word.
+    Mask(std::uint8_t* ptr):
+        _reg(_load_mask8(reinterpret_cast<__mmask8*>(ptr))) { }
+
+    /// Loads the mask word at index `iCell / storage_size` of ptr.
+    Mask(storage_t* ptr, std::size_t iCell):
+        Mask(ptr + iCell / storage_size) { }
+
+    /// Wraps an existing mask register value.
+    Mask(__mmask8 reg):
+        _reg(reg) { }
+
+    /// Returns the raw mask value for use with intrinsics.
+    operator __mmask8() {
+        return _reg;
+    }
+
+    /// Returns the bitwise complement of the mask (`_knot_mask8`).
+    __mmask8 neg() const {
+        return _knot_mask8(_reg);
+    }
+
+    /// True if at least one mask bit is set.
+    operator bool() const {
+        return _reg != 0;
+    }
+};
+
+/// Wrapper around an AVX-512 `__mmask16`: one bit per lane of a
+/// sixteen-lane single-precision pack.
+template <>
+class Mask<float> {
+private:
+    __mmask16 _reg;
+
+public:
+    /// Integer type holding one packed mask word (one bit per lane).
+    using storage_t = std::uint16_t;
+    /// Number of mask bits held by one storage_t.
+    static constexpr unsigned storage_size = 16;
+
+    /// Packs `storage_size` consecutive bools into one mask word.
+    /// value[lane] becomes bit `lane` of the result.
+    static storage_t encode(bool* value) {
+        storage_t bits = 0;
+        for (unsigned lane = 0; lane < storage_size; ++lane) {
+            bits |= value[lane] << lane;
+        }
+        return bits;
+    }
+
+    /// Loads one 16-bit mask word from ptr.
+    Mask(std::uint16_t* ptr):
+        _reg(_load_mask16(ptr)) { }
+
+    /// Loads the mask word at index `iCell / storage_size` of ptr.
+    Mask(storage_t* ptr, std::size_t iCell):
+        Mask(&ptr[iCell / storage_size]) { }
+
+    /// Wraps an existing mask register value.
+    Mask(__mmask16 reg):
+        _reg(reg) { }
+
+    /// Returns the raw mask value for use with intrinsics.
+    operator __mmask16() {
+        return _reg;
+    }
+
+    /// Returns the bitwise complement of the mask (`_knot_mask16`).
+    __mmask16 neg() const {
+        return _knot_mask16(_reg);
+    }
+
+    /// True if at least one mask bit is set.
+    operator bool() const {
+        return _reg != 0;
+    }
+};
+
+
+/// Fixed-width SIMD vector of element type T.
+/// Only the explicit specializations below are defined.
+template <std::floating_point T>
+class Pack;
+
+/// Eight doubles held in one 512-bit AVX-512 register.
+///
+/// All non-mutating members are const-qualified so that const packs
+/// (and const references to packs) can take part in expressions.
+template <>
+class Pack<double> {
+private:
+    __m512d _reg;
+
+public:
+    using mask_t = Mask<double>;
+    using index_t = std::uint32_t;
+
+    /// Number of lanes.
+    static constexpr std::size_t size = 8;
+
+    /// Leaves the register uninitialized.
+    Pack() = default;
+
+    /// Wraps an existing register value.
+    Pack(__m512d reg):
+        _reg(reg) { }
+
+    /// Broadcasts val to all lanes.
+    Pack(double val):
+        Pack(_mm512_set1_pd(val)) { }
+
+    /// Sets the lanes individually; a goes to lane 0, h to lane 7.
+    Pack(double a, double b, double c, double d, double e, double f, double g, double h):
+        Pack(_mm512_set_pd(h,g,f,e,d,c,b,a)) { }
+
+    /// Loads eight consecutive doubles from ptr (no alignment required).
+    Pack(double* ptr):
+        Pack(_mm512_loadu_pd(ptr)) { }
+
+    /// Gathers eight doubles: lane i is read from ptr[idx[i]].
+    /// Eight 32-bit indices are read from idx; note that the gather
+    /// instruction treats them as signed 32-bit values.
+    Pack(double* ptr, index_t* idx):
+        Pack(_mm512_i32gather_pd(_mm256_loadu_si256(reinterpret_cast<__m256i*>(idx)), ptr, sizeof(double))) { }
+
+    /// Returns the raw register for use with intrinsics.
+    operator __m512d() const {
+        return _reg;
+    }
+
+    /// Lane-wise sum.
+    Pack operator+(Pack rhs) const {
+        return Pack(_mm512_add_pd(_reg, rhs));
+    }
+
+    Pack& operator+=(Pack rhs) {
+        _reg = _mm512_add_pd(_reg, rhs);
+        return *this;
+    }
+
+    /// Lane-wise difference.
+    Pack operator-(Pack rhs) const {
+        return Pack(_mm512_sub_pd(_reg, rhs));
+    }
+
+    Pack& operator-=(Pack rhs) {
+        _reg = _mm512_sub_pd(_reg, rhs);
+        return *this;
+    }
+
+    /// Lane-wise product.
+    Pack operator*(Pack rhs) const {
+        return Pack(_mm512_mul_pd(_reg, rhs));
+    }
+
+    Pack& operator*=(Pack rhs) {
+        _reg = _mm512_mul_pd(_reg, rhs);
+        return *this;
+    }
+
+    /// Lane-wise quotient.
+    Pack operator/(Pack rhs) const {
+        return Pack(_mm512_div_pd(_reg, rhs));
+    }
+
+    Pack& operator/=(Pack rhs) {
+        _reg = _mm512_div_pd(_reg, rhs);
+        return *this;
+    }
+
+    /// Negation, implemented as a multiplication by -1 in every lane.
+    Pack operator-() const {
+        return *this * Pack(-1.);
+    }
+
+    /// Lane-wise square root, returned as a raw register.
+    __m512d sqrt() const {
+        return _mm512_sqrt_pd(_reg);
+    }
+};
+
+/// Sixteen floats held in one 512-bit AVX-512 register.
+///
+/// All non-mutating members are const-qualified so that const packs
+/// (and const references to packs) can take part in expressions.
+template <>
+class Pack<float> {
+private:
+    __m512 _reg;
+
+public:
+    using mask_t = Mask<float>;
+    using index_t = std::uint32_t;
+
+    /// Number of lanes.
+    static constexpr std::size_t size = 16;
+
+    /// Leaves the register uninitialized.
+    Pack() = default;
+
+    /// Wraps an existing register value.
+    Pack(__m512 reg):
+        _reg(reg) { }
+
+    /// Broadcasts val to all lanes.
+    Pack(float val):
+        Pack(_mm512_set1_ps(val)) { }
+
+    /// Sets the lanes individually; a goes to lane 0, p to lane 15.
+    Pack(float a, float b, float c, float d, float e, float f, float g, float h, float i, float j, float k, float l, float m, float n, float o, float p):
+        Pack(_mm512_set_ps(p,o,n,m,l,k,j,i,h,g,f,e,d,c,b,a)) { }
+
+    /// Loads sixteen consecutive floats from ptr (no alignment required).
+    Pack(float* ptr):
+        Pack(_mm512_loadu_ps(ptr)) { }
+
+    /// Gathers sixteen floats: lane i is read from ptr[idx[i]].
+    /// Sixteen 32-bit indices are read from idx; note that the gather
+    /// instruction treats them as signed 32-bit values.
+    Pack(float* ptr, index_t* idx):
+        Pack(_mm512_i32gather_ps(_mm512_loadu_si512(reinterpret_cast<__m512i*>(idx)), ptr, sizeof(float))) { }
+
+    /// Returns the raw register for use with intrinsics.
+    operator __m512() const {
+        return _reg;
+    }
+
+    /// Lane-wise sum.
+    Pack operator+(Pack rhs) const {
+        return Pack(_mm512_add_ps(_reg, rhs));
+    }
+
+    Pack& operator+=(Pack rhs) {
+        _reg = _mm512_add_ps(_reg, rhs);
+        return *this;
+    }
+
+    /// Lane-wise difference.
+    Pack operator-(Pack rhs) const {
+        return Pack(_mm512_sub_ps(_reg, rhs));
+    }
+
+    Pack& operator-=(Pack rhs) {
+        _reg = _mm512_sub_ps(_reg, rhs);
+        return *this;
+    }
+
+    /// Lane-wise product.
+    Pack operator*(Pack rhs) const {
+        return Pack(_mm512_mul_ps(_reg, rhs));
+    }
+
+    Pack& operator*=(Pack rhs) {
+        _reg = _mm512_mul_ps(_reg, rhs);
+        return *this;
+    }
+
+    /// Lane-wise quotient.
+    Pack operator/(Pack rhs) const {
+        return Pack(_mm512_div_ps(_reg, rhs));
+    }
+
+    Pack& operator/=(Pack rhs) {
+        _reg = _mm512_div_ps(_reg, rhs);
+        return *this;
+    }
+
+    /// Negation, implemented as a multiplication by -1 in every lane.
+    Pack operator-() const {
+        return *this * Pack(-1.f);
+    }
+
+    /// Lane-wise square root, returned as a raw register.
+    __m512 sqrt() const {
+        return _mm512_sqrt_ps(_reg);
+    }
+};
+
+
+/// scalar + pack: the scalar is first turned into a Pack<T>, then the
+/// pack-wise operator+ is applied.
+template <typename T>
+Pack<T> operator+(T lhs, Pack<T> rhs) {
+    Pack<T> widened(lhs);
+    return widened + rhs;
+}
+
+/// pack + scalar: the scalar is first turned into a Pack<T>, then the
+/// pack-wise operator+ is applied.
+template <typename T>
+Pack<T> operator+(Pack<T> lhs, T rhs) {
+    Pack<T> widened(rhs);
+    return lhs + widened;
+}
+
+/// scalar - pack: the scalar is first turned into a Pack<T>, then the
+/// pack-wise operator- is applied.
+template <typename T>
+Pack<T> operator-(T lhs, Pack<T> rhs) {
+    Pack<T> widened(lhs);
+    return widened - rhs;
+}
+
+/// pack - scalar: the scalar is first turned into a Pack<T>, then the
+/// pack-wise operator- is applied.
+template <typename T>
+Pack<T> operator-(Pack<T> lhs, T rhs) {
+    Pack<T> widened(rhs);
+    return lhs - widened;
+}
+
+/// pack * scalar: the scalar is first turned into a Pack<T>, then the
+/// pack-wise operator* is applied.
+template <typename T>
+Pack<T> operator*(Pack<T> lhs, T rhs) {
+    Pack<T> widened(rhs);
+    return lhs * widened;
+}
+
+/// scalar * pack: the scalar is first turned into a Pack<T>, then the
+/// pack-wise operator* is applied.
+template <typename T>
+Pack<T> operator*(T lhs, Pack<T> rhs) {
+    Pack<T> widened(lhs);
+    return widened * rhs;
+}
+
+/// pack / scalar: the scalar is first turned into a Pack<T>, then the
+/// pack-wise operator/ is applied.
+template <typename T>
+Pack<T> operator/(Pack<T> lhs, T rhs) {
+    Pack<T> widened(rhs);
+    return lhs / widened;
+}
+
+/// scalar / pack: the scalar is first turned into a Pack<T>, then the
+/// pack-wise operator/ is applied.
+template <typename T>
+Pack<T> operator/(T lhs, Pack<T> rhs) {
+    Pack<T> widened(lhs);
+    return widened / rhs;
+}
+
+
+/// Writes the lanes of `value` whose bit in `mask` is set to the
+/// corresponding elements starting at `target` (unaligned masked store).
+template <typename T>
+void maskstore(T* target, Mask<T> mask, Pack<T> value);
+
+// Explicit specializations of function templates are not implicitly
+// inline. They are marked inline here so that this header can be
+// included from several translation units without duplicate symbols.
+template <>
+inline void maskstore<double>(double* target, Mask<double> mask, Pack<double> value) {
+    _mm512_mask_storeu_pd(target, mask, value);
+}
+
+template <>
+inline void maskstore<float>(float* target, Mask<float> mask, Pack<float> value) {
+    _mm512_mask_storeu_ps(target, mask, value);
+}
+
+
+/// Writes all lanes of `value` to consecutive elements starting at
+/// `target` (unaligned store).
+template <typename T>
+void store(T* target, Pack<T> value);
+
+// Explicit specializations of function templates are not implicitly
+// inline. They are marked inline here so that this header can be
+// included from several translation units without duplicate symbols.
+template <>
+inline void store<double>(double* target, Pack<double> value) {
+    _mm512_storeu_pd(target, value);
+}
+
+template <>
+inline void store<float>(float* target, Pack<float> value) {
+    _mm512_storeu_ps(target, value);
+}
+
+
+/// Scatter store: lane i of `value` is written to target[indices[i]].
+/// One 32-bit index per lane is read from `indices`.
+template <typename T>
+void store(T* target, Pack<T> value, typename Pack<T>::index_t* indices);
+
+// Explicit specializations of function templates are not implicitly
+// inline. They are marked inline here so that this header can be
+// included from several translation units without duplicate symbols.
+template <>
+inline void store<double>(double* target, Pack<double> value, Pack<double>::index_t* indices) {
+    _mm512_i32scatter_pd(target, _mm256_loadu_si256(reinterpret_cast<__m256i*>(indices)), value, sizeof(double));
+}
+
+template <>
+inline void store<float>(float* target, Pack<float> value, Pack<float>::index_t* indices) {
+    _mm512_i32scatter_ps(target, _mm512_loadu_si512(reinterpret_cast<__m512i*>(indices)), value, sizeof(float));
+}
+
+}