26 #ifndef CRYPTO3_MATH_BASIC_RADIX2_DOMAIN_AUX_HPP
27 #define CRYPTO3_MATH_BASIC_RADIX2_DOMAIN_AUX_HPP
// Alias for the radix-2 FFT implementation used through the name
// `_basic_radix2_fft`. Two alternative definitions are present; the preprocessor
// conditional that selects between them is not visible in this excerpt
// (presumably a multicore / OpenMP build switch -- TODO confirm).
42 #define _basic_radix2_fft detail::basic_parallel_radix2_fft
// Alternative definition: the alias resolves to the serial implementation.
44 #define _basic_radix2_fft detail::basic_serial_radix2_fft
// basic_serial_radix2_fft: in-place radix-2 FFT over the range `a`, using `omega`
// as the root-of-unity argument. The signature line itself is not visible in this
// excerpt; per the generated cross-reference it is
//   void basic_serial_radix2_fft(Range &a, const typename FieldType::value_type &omega)
//
// Template parameters:
//   FieldType - field whose value_type must equal the element type of Range
//               (enforced by the static assertion below).
//   Range     - container type; the visible code uses a.size() and a[i] on it.
// Throws std::invalid_argument when a.size() is not equal to 1u << log2(a.size()).
56 template<
typename FieldType,
typename Range>
// Element type of Range, deduced from the iterator returned by std::begin.
58 typedef typename std::iterator_traits<decltype(std::begin(std::declval<Range>()))>::value_type
// Compile-time check: the range must hold exactly FieldType::value_type elements.
62 BOOST_STATIC_ASSERT(std::is_same<typename FieldType::value_type, value_type>::value);
// `log2` here is a helper whose definition is not shown in this excerpt
// (presumably an integer log2 -- TODO confirm its rounding behaviour).
64 const std::size_t n = a.size(), logn = log2(n);
// Reject sizes that do not round-trip through log2, i.e. non powers of two.
65 if (n != (1u << logn))
66 throw std::invalid_argument(
"expected n == (1u << logn)");
// Permutation pass: swaps a[k] with a[rk]. `rk` is defined on a line that is not
// visible here (presumably the bit-reversal of k over logn bits -- TODO confirm).
69 for (std::size_t k = 0; k < n; ++k) {
72 std::swap(a[k], a[rk]);
// One pass per stage s = 1..logn. `m` is declared/updated on lines not visible
// here (presumably the half-width of the butterflies at this stage -- TODO confirm).
76 for (std::size_t s = 1; s <= logn; ++s) {
// Per-stage twiddle step: omega raised to n / (2 * m).
78 const value_type w_m = omega.pow(n / (2 * m));
// Empty asm statement whose text is only an assembler comment; it emits no
// instructions (presumably a marker for locating the inner loop in compiler
// output -- TODO confirm intent).
80 asm volatile(
"/* pre-inner */");
// Process the array in chunks of 2 * m elements.
81 for (std::size_t k = 0; k < n; k += 2 * m) {
// Running twiddle factor for this chunk, starting at one.
82 value_type w = value_type::one();
83 for (std::size_t j = 0; j < m; ++j) {
// Butterfly: t is the twiddled upper-half element; the upper half receives
// (lower - t). The matching update of a[k + j] and the advance of w are on
// lines not visible in this excerpt.
84 const value_type t = w * a[k + j + m];
85 a[k + j + m] = a[k + j] - t;
// Second marker asm statement (assembler comment only, no instructions).
90 asm volatile(
"/* post-inner */");
// basic_parallel_radix2_fft_inner: radix-2 FFT over `a` that splits the work into
// 2^log_cpus sub-transforms, runs the serial FFT on each, and interleaves the
// results back into `a`. The first line of the signature is not visible in this
// excerpt; per the generated cross-reference it is
//   void basic_parallel_radix2_fft_inner(Range &a,
//                                        const typename FieldType::value_type &omega,
//                                        const std::size_t log_cpus)
//
// Throws std::invalid_argument when a.size() is not equal to 1ul << log2(a.size()).
95 template<
typename FieldType,
typename Range>
// Trailing parameters: the root-of-unity argument and log2 of the number of chunks.
97 const typename FieldType::value_type &omega,
98 const std::size_t log_cpus) {
// Element type of Range, deduced from the iterator returned by std::begin.
99 typedef typename std::iterator_traits<decltype(std::begin(std::declval<Range>()))>::value_type
// Compile-time check: the range must hold exactly FieldType::value_type elements.
103 BOOST_STATIC_ASSERT(std::is_same<typename FieldType::value_type, value_type>::value);
// Number of sub-transforms.
105 const std::size_t num_cpus = 1ul << log_cpus;
107 const std::size_t m = a.size();
// `log2` is a helper not shown in this excerpt (presumably an integer log2 -- TODO confirm).
108 const std::size_t log_m = log2(m);
// Reject sizes that do not round-trip through log2, i.e. non powers of two.
109 if (m != 1ul << log_m)
110 throw std::invalid_argument(
"expected m == 1ul<<log_m");
// Input smaller than the number of chunks: delegate to the serial FFT.
// (The rest of this branch, presumably an early return, is not visible here.)
112 if (log_m < log_cpus) {
113 basic_serial_radix2_fft<FieldType>(a, omega);
// Scratch storage: num_cpus vectors, each of length 2^(log_m - log_cpus), zero-filled.
117 std::vector<std::vector<value_type>> tmp(num_cpus);
118 for (std::size_t j = 0; j < num_cpus; ++j) {
119 tmp[j].resize(1ul << (log_m - log_cpus), value_type::zero());
// Stage 1: build the input of each sub-transform. Iteration j accumulates only
// into tmp[j].
123 #pragma omp parallel for
125 for (std::size_t j = 0; j < num_cpus; ++j) {
// Powers of omega used for chunk j; where they are applied to `elt` is on lines
// not visible in this excerpt.
126 const value_type omega_j = omega.pow(j);
127 const value_type omega_step = omega.pow(j << (log_m - log_cpus));
// Running multiplier, starting at one.
129 value_type elt = value_type::one();
130 for (std::size_t i = 0; i < 1ul << (log_m - log_cpus); ++i) {
131 for (std::size_t s = 0; s < num_cpus; ++s) {
// Index into `a`: offset i plus s strides of 2^(log_m - log_cpus), reduced
// modulo 2^log_m.
// NOTE(review): this uses `1u` while the other shifts in this function use `1ul`;
// where unsigned int is 32 bits, log_m >= 32 would make this shift invalid --
// confirm whether such sizes are reachable.
133 const std::size_t idx = (i + (s << (log_m - log_cpus))) % (1u << log_m);
134 tmp[j][i] += a[idx] * elt;
// Root-of-unity argument for the sub-transforms: omega raised to num_cpus.
141 const value_type omega_num_cpus = omega.pow(num_cpus);
// Stage 2: run the serial FFT on each sub-problem.
144 #pragma omp parallel for
146 for (std::size_t j = 0; j < num_cpus; ++j) {
147 basic_serial_radix2_fft<FieldType>(tmp[j], omega_num_cpus);
// Stage 3: interleave results back into `a`; element j of sub-result i lands at
// index (j << log_cpus) + i.
151 #pragma omp parallel for
153 for (std::size_t i = 0; i < num_cpus; ++i) {
154 for (std::size_t j = 0; j < 1ul << (log_m - log_cpus); ++j) {
158 a[(j << log_cpus) + i] = tmp[i][j];
// basic_parallel_radix2_fft: entry point that derives a chunk count from the
// available thread count and dispatches the FFT. The signature line is not visible
// in this excerpt; per the generated cross-reference it is
//   void basic_parallel_radix2_fft(Range &a, const typename FieldType::value_type &omega)
163 template<
typename FieldType,
typename Range>
// Two alternative definitions of num_cpus follow; the preprocessor conditional
// selecting between them is not visible here (presumably an OpenMP/multicore
// switch -- TODO confirm).
166 const std::size_t num_cpus = omp_get_max_threads();
168 const std::size_t num_cpus = 1;
// log_cpus: log2(num_cpus) when num_cpus is a power of two (the bit test
// num_cpus & (num_cpus - 1) is zero), otherwise log2(num_cpus) - 1.
// `log2` is a helper not shown in this excerpt.
170 const std::size_t log_cpus =
171 ((num_cpus & (num_cpus - 1)) == 0 ? log2(num_cpus) : log2(num_cpus) - 1);
// Serial path. The condition guarding this call and the alternative branch are on
// lines not visible here (presumably the parallel inner routine is used
// otherwise -- TODO confirm).
174 basic_serial_radix2_fft<FieldType>(a, omega);
// basic_radix2_evaluate_all_lagrange_polynomials: builds a vector of m field
// values from the point `t` and the root of unity returned by
// unity_root<FieldType>(m). The signature line is not visible in this excerpt; per
// the generated cross-reference it is
//   std::vector<typename FieldType::value_type>
//   basic_radix2_evaluate_all_lagrange_polynomials(const std::size_t m,
//                                                  const typename FieldType::value_type &t)
//
// Throws std::invalid_argument when m is not equal to
// 1u << ceil(log2(m)), i.e. when m is not a power of two.
184 template<
typename FieldType>
185 std::vector<typename FieldType::value_type>
// Trailing parameter: the evaluation point.
187 const typename FieldType::value_type &t) {
188 typedef typename FieldType::value_type value_type;
// Early return of the single-element vector {one}. The guarding condition is on a
// line not visible here (presumably m == 1 -- TODO confirm).
191 return std::vector<value_type>(1, value_type::one());
// Power-of-two check done via floating-point std::log2 / std::ceil.
194 if (m != (1u <<
static_cast<std::size_t
>(std::ceil(std::log2(m)))))
195 throw std::invalid_argument(
"expected m == (1u << log2(m))");
197 const value_type omega = unity_root<FieldType>(m);
// Result vector, m entries initialised to zero.
199 std::vector<value_type> u(m, value_type::zero());
// Special case: t^m == 1. The visible code walks i = 0..m-1 with omega_i starting
// at one and sets u[i] to one; the test selecting which i is written, the update
// of omega_i and the return are on lines not visible here (presumably the entry
// where omega^i == t -- TODO confirm).
206 if (t.pow(m) == value_type::one()) {
207 value_type omega_i = value_type::one();
208 for (std::size_t i = 0; i < m; ++i) {
211 u[i] = value_type::one();
// General case: Z = t^m - 1.
228 const value_type Z = (t.pow(m)) - value_type::one();
// l starts at Z / m (multiplication by the inverse of m as a field element).
229 value_type l = Z * value_type(m).inversed();
// r starts at one.
230 value_type r = value_type::one();
231 for (std::size_t i = 0; i < m; ++i) {
// u[i] = l / (t - r). The per-iteration updates of l and r are on lines not
// visible here (presumably both are multiplied by omega -- TODO confirm).
232 u[i] = l * (t - r).inversed();
void basic_parallel_radix2_fft(Range &a, const typename FieldType::value_type &omega)
Definition: basic_radix2_domain_aux.hpp:164
void basic_serial_radix2_fft(Range &a, const typename FieldType::value_type &omega)
Definition: basic_radix2_domain_aux.hpp:57
void basic_parallel_radix2_fft_inner(Range &a, const typename FieldType::value_type &omega, const std::size_t log_cpus)
Definition: basic_radix2_domain_aux.hpp:96
std::vector< typename FieldType::value_type > basic_radix2_evaluate_all_lagrange_polynomials(const std::size_t m, const typename FieldType::value_type &t)
Definition: basic_radix2_domain_aux.hpp:186
std::size_t bitreverse(std::size_t n, const std::size_t l)
Definition: field_utils.hpp:42
Definition: algebra/include/nil/crypto3/algebra/type_traits.hpp:95