basic_radix2_domain_aux.hpp
Go to the documentation of this file.
1 //---------------------------------------------------------------------------//
2 // Copyright (c) 2020-2021 Mikhail Komarov <nemo@nil.foundation>
3 // Copyright (c) 2020-2021 Nikita Kaskov <nbering@nil.foundation>
4 //
5 // MIT License
6 //
7 // Permission is hereby granted, free of charge, to any person obtaining a copy
8 // of this software and associated documentation files (the "Software"), to deal
9 // in the Software without restriction, including without limitation the rights
10 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 // copies of the Software, and to permit persons to whom the Software is
12 // furnished to do so, subject to the following conditions:
13 //
14 // The above copyright notice and this permission notice shall be included in all
15 // copies or substantial portions of the Software.
16 //
17 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 // SOFTWARE.
24 //---------------------------------------------------------------------------//
25 
26 #ifndef CRYPTO3_MATH_BASIC_RADIX2_DOMAIN_AUX_HPP
27 #define CRYPTO3_MATH_BASIC_RADIX2_DOMAIN_AUX_HPP
28 
29 #include <algorithm>
30 #include <vector>
31 
32 #ifdef MULTICORE
33 #include <omp.h>
34 #endif
35 
37 
40 
// Compile-time selection of the FFT routine behind the name `_basic_radix2_fft`:
// the parallel driver when MULTICORE is defined, the serial routine otherwise.
// The expansion is qualified with `detail::` only, so it resolves relative to
// the scope in which the macro is used.
41 #ifdef MULTICORE
42 #define _basic_radix2_fft detail::basic_parallel_radix2_fft
43 #else
44 #define _basic_radix2_fft detail::basic_serial_radix2_fft
45 #endif
46 
47 namespace nil {
48  namespace crypto3 {
49  namespace math {
50  namespace detail {
51 
52  /*
53  * Below we make use of pseudocode from [CLRS 2n Ed, pp. 864].
54  * Also, note that it's the caller's responsibility to multiply by 1/N.
55  */
56  template<typename FieldType, typename Range>
57  void basic_serial_radix2_fft(Range &a, const typename FieldType::value_type &omega) {
58  typedef typename std::iterator_traits<decltype(std::begin(std::declval<Range>()))>::value_type
59  value_type;
60 
61  BOOST_STATIC_ASSERT(algebra::is_field<FieldType>::value);
62  BOOST_STATIC_ASSERT(std::is_same<typename FieldType::value_type, value_type>::value);
63 
64  const std::size_t n = a.size(), logn = log2(n);
65  if (n != (1u << logn))
66  throw std::invalid_argument("expected n == (1u << logn)");
67 
68  /* swapping in place (from Storer's book) */
69  for (std::size_t k = 0; k < n; ++k) {
70  const std::size_t rk = bitreverse(k, logn);
71  if (k < rk)
72  std::swap(a[k], a[rk]);
73  }
74 
75  std::size_t m = 1; // invariant: m = 2^{s-1}
76  for (std::size_t s = 1; s <= logn; ++s) {
77  // w_m is 2^s-th root of unity now
78  const value_type w_m = omega.pow(n / (2 * m));
79 
80  asm volatile("/* pre-inner */");
81  for (std::size_t k = 0; k < n; k += 2 * m) {
82  value_type w = value_type::one();
83  for (std::size_t j = 0; j < m; ++j) {
84  const value_type t = w * a[k + j + m];
85  a[k + j + m] = a[k + j] - t;
86  a[k + j] += t;
87  w *= w_m;
88  }
89  }
90  asm volatile("/* post-inner */");
91  m *= 2;
92  }
93  }
94 
95  template<typename FieldType, typename Range>
97  const typename FieldType::value_type &omega,
98  const std::size_t log_cpus) {
99  typedef typename std::iterator_traits<decltype(std::begin(std::declval<Range>()))>::value_type
100  value_type;
101 
102  BOOST_STATIC_ASSERT(algebra::is_field<FieldType>::value);
103  BOOST_STATIC_ASSERT(std::is_same<typename FieldType::value_type, value_type>::value);
104 
105  const std::size_t num_cpus = 1ul << log_cpus;
106 
107  const std::size_t m = a.size();
108  const std::size_t log_m = log2(m);
109  if (m != 1ul << log_m)
110  throw std::invalid_argument("expected m == 1ul<<log_m");
111 
112  if (log_m < log_cpus) {
113  basic_serial_radix2_fft<FieldType>(a, omega);
114  return;
115  }
116 
117  std::vector<std::vector<value_type>> tmp(num_cpus);
118  for (std::size_t j = 0; j < num_cpus; ++j) {
119  tmp[j].resize(1ul << (log_m - log_cpus), value_type::zero());
120  }
121 
122 #ifdef MULTICORE
123 #pragma omp parallel for
124 #endif
125  for (std::size_t j = 0; j < num_cpus; ++j) {
126  const value_type omega_j = omega.pow(j);
127  const value_type omega_step = omega.pow(j << (log_m - log_cpus));
128 
129  value_type elt = value_type::one();
130  for (std::size_t i = 0; i < 1ul << (log_m - log_cpus); ++i) {
131  for (std::size_t s = 0; s < num_cpus; ++s) {
132  // invariant: elt is omega^(j*idx)
133  const std::size_t idx = (i + (s << (log_m - log_cpus))) % (1u << log_m);
134  tmp[j][i] += a[idx] * elt;
135  elt *= omega_step;
136  }
137  elt *= omega_j;
138  }
139  }
140 
141  const value_type omega_num_cpus = omega.pow(num_cpus);
142 
143 #ifdef MULTICORE
144 #pragma omp parallel for
145 #endif
146  for (std::size_t j = 0; j < num_cpus; ++j) {
147  basic_serial_radix2_fft<FieldType>(tmp[j], omega_num_cpus);
148  }
149 
150 #ifdef MULTICORE
151 #pragma omp parallel for
152 #endif
153  for (std::size_t i = 0; i < num_cpus; ++i) {
154  for (std::size_t j = 0; j < 1ul << (log_m - log_cpus); ++j) {
155  // now: i = idx >> (log_m - log_cpus) and j = idx % (1u << (log_m - log_cpus)), for idx
156  // =
157  // ((i<<(log_m-log_cpus))+j) % (1u << log_m)
158  a[(j << log_cpus) + i] = tmp[i][j];
159  }
160  }
161  }
162 
163  template<typename FieldType, typename Range>
164  void basic_parallel_radix2_fft(Range &a, const typename FieldType::value_type &omega) {
165 #ifdef MULTICORE
166  const std::size_t num_cpus = omp_get_max_threads();
167 #else
168  const std::size_t num_cpus = 1;
169 #endif
170  const std::size_t log_cpus =
171  ((num_cpus & (num_cpus - 1)) == 0 ? log2(num_cpus) : log2(num_cpus) - 1);
172 
173  if (log_cpus == 0) {
174  basic_serial_radix2_fft<FieldType>(a, omega);
175  } else {
176  basic_parallel_radix2_fft_inner(a, omega, log_cpus);
177  }
178  }
179 
184  template<typename FieldType>
185  std::vector<typename FieldType::value_type>
187  const typename FieldType::value_type &t) {
188  typedef typename FieldType::value_type value_type;
189 
190  if (m == 1) {
191  return std::vector<value_type>(1, value_type::one());
192  }
193 
194  if (m != (1u << static_cast<std::size_t>(std::ceil(std::log2(m)))))
195  throw std::invalid_argument("expected m == (1u << log2(m))");
196 
197  const value_type omega = unity_root<FieldType>(m);
198 
199  std::vector<value_type> u(m, value_type::zero());
200 
201  /*
202  If t equals one of the roots of unity in S={omega^{0},...,omega^{m-1}}
203  then output 1 at the right place, and 0 elsewhere
204  */
205 
206  if (t.pow(m) == value_type::one()) {
207  value_type omega_i = value_type::one();
208  for (std::size_t i = 0; i < m; ++i) {
209  if (omega_i == t) // i.e., t equals omega^i
210  {
211  u[i] = value_type::one();
212  return u;
213  }
214 
215  omega_i *= omega;
216  }
217  }
218 
219  /*
220  Otherwise, if t does not equal any of the roots of unity in S,
221  then compute each L_{i,S}(t) as Z_{S}(t) * v_i / (t-\omega^i)
222  where:
223  - Z_{S}(t) = \prod_{j} (t-\omega^j) = (t^m-1), and
224  - v_{i} = 1 / \prod_{j \neq i} (\omega^i-\omega^j).
225  Below we use the fact that v_{0} = 1/m and v_{i+1} = \omega * v_{i}.
226  */
227 
228  const value_type Z = (t.pow(m)) - value_type::one();
229  value_type l = Z * value_type(m).inversed();
230  value_type r = value_type::one();
231  for (std::size_t i = 0; i < m; ++i) {
232  u[i] = l * (t - r).inversed();
233  l *= omega;
234  r *= omega;
235  }
236 
237  return u;
238  }
239  } // namespace detail
240  } // namespace math
241  } // namespace crypto3
242 } // namespace nil
243 
244 #endif // CRYPTO3_MATH_BASIC_RADIX2_DOMAIN_AUX_HPP
void basic_parallel_radix2_fft(Range &a, const typename FieldType::value_type &omega)
Definition: basic_radix2_domain_aux.hpp:164
void basic_serial_radix2_fft(Range &a, const typename FieldType::value_type &omega)
Definition: basic_radix2_domain_aux.hpp:57
void basic_parallel_radix2_fft_inner(Range &a, const typename FieldType::value_type &omega, const std::size_t log_cpus)
Definition: basic_radix2_domain_aux.hpp:96
std::vector< typename FieldType::value_type > basic_radix2_evaluate_all_lagrange_polynomials(const std::size_t m, const typename FieldType::value_type &t)
Definition: basic_radix2_domain_aux.hpp:186
std::size_t bitreverse(std::size_t n, const std::size_t l)
Definition: field_utils.hpp:42
Definition: pair.hpp:31
Definition: algebra/include/nil/crypto3/algebra/type_traits.hpp:95