summaryrefslogtreecommitdiff
path: root/src/simd/256.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/simd/256.h')
-rw-r--r--src/simd/256.h379
1 files changed, 379 insertions, 0 deletions
diff --git a/src/simd/256.h b/src/simd/256.h
new file mode 100644
index 0000000..a3f419d
--- /dev/null
+++ b/src/simd/256.h
@@ -0,0 +1,379 @@
+#pragma once
+
+#include <immintrin.h>
+
+#include <cstdint>
+#include <type_traits>
+
+namespace simd {
+
+template <std::floating_point T>
+class Mask;
+
+template <>
+class Mask<double> {
+private:
+ __m256i _reg;
+
+public:
+ using storage_t = std::uint64_t;
+ static constexpr unsigned storage_size = 1;
+
+ static constexpr storage_t true_v = 1l << 63;
+ static constexpr storage_t false_v = 0l;
+
+ static storage_t encode(bool value) {
+ return value ? true_v : false_v;
+ }
+
+ static storage_t encode(bool* value) {
+ return encode(*value);
+ }
+
+ Mask(bool a, bool b, bool c, bool d):
+ _reg(_mm256_set_epi64x(encode(d),encode(c),encode(b),encode(a))) { }
+
+ Mask(std::uint64_t a, std::uint64_t b, std::uint64_t c, std::uint64_t d):
+ _reg(_mm256_set_epi64x(d,c,b,a)) { }
+
+ Mask(std::uint64_t* ptr):
+ _reg(_mm256_loadu_si256(reinterpret_cast<__m256i*>(ptr))) { }
+
+ Mask(storage_t* ptr, std::size_t iCell):
+ Mask<double>(ptr + iCell) { }
+
+ Mask(__m256i reg):
+ _reg(reg) { }
+
+ operator __m256i() {
+ return _reg;
+ }
+
+ __m256i neg() const {
+ return _mm256_sub_epi64(_mm256_set1_epi64x(true_v), _reg);
+ }
+
+ operator bool() const {
+ const std::uint64_t* values = reinterpret_cast<const std::uint64_t*>(&_reg);
+ return values[0] == true_v
+ || values[1] == true_v
+ || values[2] == true_v
+ || values[3] == true_v;
+ }
+};
+
+template <>
+class Mask<float> {
+private:
+ __m256i _reg;
+
+public:
+ using storage_t = std::uint32_t;
+ static constexpr unsigned storage_size = 1;
+
+ static constexpr storage_t true_v = 1 << 31;
+ static constexpr storage_t false_v = 0;
+
+ static storage_t encode(bool value) {
+ return value ? true_v : false_v;
+ }
+
+ static storage_t encode(bool* value) {
+ return encode(*value);
+ }
+
+ Mask(storage_t* ptr):
+ _reg(_mm256_loadu_si256(reinterpret_cast<__m256i*>(ptr))) { }
+
+ Mask(storage_t* ptr, std::size_t iCell):
+ Mask<float>(ptr + iCell) { }
+
+ Mask(__m256i reg):
+ _reg(reg) { }
+
+ operator __m256i() {
+ return _reg;
+ }
+
+ __m256i neg() const {
+ return _mm256_sub_epi32(_mm256_set1_epi32(true_v), _reg);
+ }
+
+ operator bool() const {
+ const std::uint32_t* values = reinterpret_cast<const std::uint32_t*>(&_reg);
+ return values[0] == true_v
+ || values[1] == true_v
+ || values[2] == true_v
+ || values[3] == true_v
+ || values[4] == true_v
+ || values[5] == true_v
+ || values[6] == true_v
+ || values[7] == true_v;
+ }
+};
+
+
+template <std::floating_point T>
+class Pack;
+
+template <>
+class Pack<double> {
+private:
+ __m256d _reg;
+
+public:
+ using mask_t = Mask<double>;
+ using index_t = std::uint32_t;
+
+ static constexpr std::size_t size = 4;
+
+ Pack() = default;
+
+ Pack(__m256d reg):
+ _reg(reg) { }
+
+ Pack(double val):
+ Pack(_mm256_set1_pd(val)) { }
+
+ Pack(double a, double b, double c, double d):
+ Pack(_mm256_set_pd(d,c,b,a)) { }
+
+ Pack(double* ptr):
+ Pack(_mm256_loadu_pd(ptr)) {
+ _mm_prefetch(ptr + size, _MM_HINT_T2);
+ }
+
+ Pack(double* ptr, index_t* idx):
+ Pack(_mm256_i32gather_pd(ptr, _mm_loadu_si128(reinterpret_cast<__m128i*>(idx)), sizeof(double))) { }
+
+ operator __m256d() {
+ return _reg;
+ }
+
+ Pack operator+(Pack rhs) {
+ return Pack(_mm256_add_pd(_reg, rhs));
+ }
+
+ Pack& operator+=(Pack rhs) {
+ _reg = _mm256_add_pd(_reg, rhs);
+ return *this;
+ }
+
+ Pack operator-(Pack rhs) {
+ return Pack(_mm256_sub_pd(_reg, rhs));
+ }
+
+ Pack& operator-=(Pack rhs) {
+ _reg = _mm256_sub_pd(_reg, rhs);
+ return *this;
+ }
+
+ Pack operator*(Pack rhs) {
+ return Pack(_mm256_mul_pd(_reg, rhs));
+ }
+
+ Pack& operator*=(Pack rhs) {
+ _reg = _mm256_mul_pd(_reg, rhs);
+ return *this;
+ }
+
+ Pack operator/(Pack rhs) {
+ return Pack(_mm256_div_pd(_reg, rhs));
+ }
+
+ Pack& operator/=(Pack rhs) {
+ _reg = _mm256_div_pd(_reg, rhs);
+ return *this;
+ }
+
+ Pack operator-() {
+ return *this * Pack(-1.);
+ }
+
+ __m256d sqrt() {
+ return _mm256_sqrt_pd(_reg);
+ }
+};
+
+template <>
+class Pack<float> {
+private:
+ __m256 _reg;
+
+public:
+ using mask_t = Mask<float>;
+ using index_t = std::uint32_t;
+
+ static constexpr std::size_t size = 8;
+
+ Pack() = default;
+
+ Pack(__m256 reg):
+ _reg(reg) { }
+
+ Pack(float val):
+ Pack(_mm256_set1_ps(val)) { }
+
+ Pack(float a, float b, float c, float d, float e, float f, float g, float h):
+ Pack(_mm256_set_ps(h,g,f,e,d,c,b,a)) { }
+
+ Pack(float* ptr):
+ Pack(_mm256_loadu_ps(ptr)) {
+ _mm_prefetch(ptr + size, _MM_HINT_T2);
+ }
+
+ Pack(float* ptr, index_t* idx):
+ Pack(_mm256_i32gather_ps(ptr, _mm256_loadu_si256(reinterpret_cast<__m256i*>(idx)), sizeof(float))) { }
+
+ operator __m256() {
+ return _reg;
+ }
+
+ Pack operator+(Pack rhs) {
+ return Pack(_mm256_add_ps(_reg, rhs));
+ }
+
+ Pack& operator+=(Pack rhs) {
+ _reg = _mm256_add_ps(_reg, rhs);
+ return *this;
+ }
+
+ Pack operator-(Pack rhs) {
+ return Pack(_mm256_sub_ps(_reg, rhs));
+ }
+
+ Pack& operator-=(Pack rhs) {
+ _reg = _mm256_sub_ps(_reg, rhs);
+ return *this;
+ }
+
+ Pack operator*(Pack rhs) {
+ return Pack(_mm256_mul_ps(_reg, rhs));
+ }
+
+ Pack& operator*=(Pack rhs) {
+ _reg = _mm256_mul_ps(_reg, rhs);
+ return *this;
+ }
+
+ Pack operator/(Pack rhs) {
+ return Pack(_mm256_div_ps(_reg, rhs));
+ }
+
+ Pack& operator/=(Pack rhs) {
+ _reg = _mm256_div_ps(_reg, rhs);
+ return *this;
+ }
+
+ Pack operator-() {
+ return *this * Pack(-1.);
+ }
+
+ __m256 sqrt() {
+ return _mm256_sqrt_ps(_reg);
+ }
+};
+
+
+template <typename T>
+Pack<T> operator+(T lhs, Pack<T> rhs) {
+ return Pack<T>(lhs) + rhs;
+}
+
+template <typename T>
+Pack<T> operator+(Pack<T> lhs, T rhs) {
+ return lhs + Pack<T>(rhs);
+}
+
+template <typename T>
+Pack<T> operator-(T lhs, Pack<T> rhs) {
+ return Pack<T>(lhs) - rhs;
+}
+
+template <typename T>
+Pack<T> operator-(Pack<T> lhs, T rhs) {
+ return lhs - Pack<T>(rhs);
+}
+
+template <typename T>
+Pack<T> operator*(Pack<T> lhs, T rhs) {
+ return lhs * Pack<T>(rhs);
+}
+
+template <typename T>
+Pack<T> operator*(T lhs, Pack<T> rhs) {
+ return Pack<T>(lhs) * rhs;
+}
+
+template <typename T>
+Pack<T> operator/(Pack<T> lhs, T rhs) {
+ return lhs / Pack<T>(rhs);
+}
+
+template <typename T>
+Pack<T> operator/(T lhs, Pack<T> rhs) {
+ return Pack<T>(lhs) / rhs;
+}
+
+template <typename T>
+Pack<T> sqrt(Pack<T> x) {
+ return x.sqrt();
+}
+
+template <std::floating_point T>
+void maskstore(T* target, Mask<T> mask, Pack<T> value);
+
+template <>
+void maskstore<double>(double* target, Mask<double> mask, Pack<double> value) {
+ _mm256_maskstore_pd(target, mask, value);
+}
+
+template <>
+void maskstore<float>(float* target, Mask<float> mask, Pack<float> value) {
+ _mm256_maskstore_ps(target, mask, value);
+}
+
+
+template <std::floating_point T>
+void store(T* target, Pack<T> value);
+
+template <>
+void store<double>(double* target, Pack<double> value) {
+ _mm256_storeu_pd(target, value);
+}
+
+template <>
+void store<float>(float* target, Pack<float> value) {
+ _mm256_storeu_ps(target, value);
+}
+
+template <std::floating_point T>
+void store(T* target, Pack<T> value, typename Pack<T>::index_t* indices);
+
+template <>
+void store<double>(double* target, Pack<double> value, Pack<double>::index_t* indices) {
+#ifdef __AVX512F__
+ _mm256_i32scatter_pd(target, _mm_loadu_si128(reinterpret_cast<__m128i*>(indices)), value, sizeof(double));
+#else
+ __m256d reg = value;
+ #pragma GCC unroll 4
+ for (unsigned i=0; i < simd::Pack<double>::size; ++i) {
+ target[indices[i]] = reg[i];
+ }
+#endif
+}
+
+template <>
+void store<float>(float* target, Pack<float> value, Pack<float>::index_t* indices) {
+#ifdef __AVX512F__
+ _mm256_i32scatter_ps(target, _mm256_loadu_si256(reinterpret_cast<__m256i*>(indices)), value, sizeof(float));
+#else
+ __m256 reg = value;
+ #pragma GCC unroll 8
+ for (unsigned i=0; i < simd::Pack<float>::size; ++i) {
+ target[indices[i]] = reg[i];
+ }
+#endif
+}
+
+}