step_radix2_domain.hpp
Go to the documentation of this file.
1 //---------------------------------------------------------------------------//
2 // Copyright (c) 2020-2021 Mikhail Komarov <nemo@nil.foundation>
3 // Copyright (c) 2020-2021 Nikita Kaskov <nbering@nil.foundation>
4 //
5 // MIT License
6 //
7 // Permission is hereby granted, free of charge, to any person obtaining a copy
8 // of this software and associated documentation files (the "Software"), to deal
9 // in the Software without restriction, including without limitation the rights
10 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 // copies of the Software, and to permit persons to whom the Software is
12 // furnished to do so, subject to the following conditions:
13 //
14 // The above copyright notice and this permission notice shall be included in all
15 // copies or substantial portions of the Software.
16 //
17 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 // SOFTWARE.
24 //---------------------------------------------------------------------------//
25 
26 #ifndef CRYPTO3_MATH_STEP_RADIX2_DOMAIN_HPP
27 #define CRYPTO3_MATH_STEP_RADIX2_DOMAIN_HPP
28 
#include <cstddef>
#include <stdexcept>
#include <vector>
30 
34 
35 namespace nil {
36  namespace crypto3 {
37  namespace math {
38 
39  using namespace nil::crypto3::algebra;
40 
41  template<typename FieldType>
42  class evaluation_domain;
43 
44  template<typename FieldType>
45  class step_radix2_domain : public evaluation_domain<FieldType> {
46  typedef typename FieldType::value_type value_type;
47 
48  public:
49  typedef FieldType field_type;
50 
51  std::size_t big_m;
52  std::size_t small_m;
53  value_type omega;
54  value_type big_omega;
55  value_type small_omega;
56 
57  step_radix2_domain(const std::size_t m) : evaluation_domain<FieldType>(m) {
58  if (m <= 1)
59  throw std::invalid_argument("step_radix2(): expected m > 1");
60 
61  big_m = 1ul << (static_cast<std::size_t>(std::ceil(std::log2(m))) - 1);
62  small_m = m - big_m;
63 
64  if (small_m != 1ul << static_cast<std::size_t>(std::ceil(std::log2(small_m))))
65  throw std::invalid_argument("step_radix2(): expected small_m == 1ul<<log2(small_m)");
66 
67  omega = unity_root<FieldType>(1ul << static_cast<std::size_t>(std::ceil(std::log2(m))));
68 
69  big_omega = omega.squared();
70  small_omega = unity_root<FieldType>(small_m);
71  }
72 
73  void fft(std::vector<value_type> &a) {
74  if (a.size() != this->m) {
75  if (a.size() < this->m) {
76  a.resize(this->m, value_type(0));
77  } else {
78  throw std::invalid_argument("step_radix2: expected a.size() == this->m");
79  }
80  }
81 
82  std::vector<value_type> c(big_m, value_type::zero());
83  std::vector<value_type> d(big_m, value_type::zero());
84 
85  value_type omega_i = value_type::one();
86  for (std::size_t i = 0; i < big_m; ++i) {
87  c[i] = (i < small_m ? a[i] + a[i + big_m] : a[i]);
88  d[i] = omega_i * (i < small_m ? a[i] - a[i + big_m] : a[i]);
89  omega_i *= omega;
90  }
91 
92  std::vector<value_type> e(small_m, value_type::zero());
93  const std::size_t compr = 1ul << (static_cast<std::size_t>(std::ceil(std::log2(big_m))) -
94  static_cast<std::size_t>(std::ceil(std::log2(small_m))));
95  for (std::size_t i = 0; i < small_m; ++i) {
96  for (std::size_t j = 0; j < compr; ++j) {
97  e[i] += d[i + j * small_m];
98  }
99  }
100 
101  _basic_radix2_fft<FieldType>(c, omega.squared());
102  _basic_radix2_fft<FieldType>(e, unity_root<FieldType>(small_m));
103 
104  for (std::size_t i = 0; i < big_m; ++i) {
105  a[i] = c[i];
106  }
107 
108  for (std::size_t i = 0; i < small_m; ++i) {
109  a[i + big_m] = e[i];
110  }
111  }
112  void inverse_fft(std::vector<value_type> &a) {
113  if (a.size() != this->m)
114  throw std::invalid_argument("step_radix2: expected a.size() == this->m");
115 
116  std::vector<value_type> U0(a.begin(), a.begin() + big_m);
117  std::vector<value_type> U1(a.begin() + big_m, a.end());
118 
119  _basic_radix2_fft<FieldType>(U0, omega.squared().inversed());
120  _basic_radix2_fft<FieldType>(U1, unity_root<FieldType>(small_m).inversed());
121 
122  const value_type U0_size_inv = value_type(big_m).inversed();
123  for (std::size_t i = 0; i < big_m; ++i) {
124  U0[i] *= U0_size_inv;
125  }
126 
127  const value_type U1_size_inv = value_type(small_m).inversed();
128  for (std::size_t i = 0; i < small_m; ++i) {
129  U1[i] *= U1_size_inv;
130  }
131 
132  std::vector<value_type> tmp = U0;
133  value_type omega_i = value_type::one();
134  for (std::size_t i = 0; i < big_m; ++i) {
135  tmp[i] *= omega_i;
136  omega_i *= omega;
137  }
138 
139  // save A_suffix
140  for (std::size_t i = small_m; i < big_m; ++i) {
141  a[i] = U0[i];
142  }
143 
144  const std::size_t compr = 1ul << (static_cast<std::size_t>(std::ceil(std::log2(big_m))) -
145  static_cast<std::size_t>(std::ceil(std::log2(small_m))));
146  for (std::size_t i = 0; i < small_m; ++i) {
147  for (std::size_t j = 1; j < compr; ++j) {
148  U1[i] -= tmp[i + j * small_m];
149  }
150  }
151 
152  const value_type omega_inv = omega.inversed();
153  value_type omega_inv_i = value_type::one();
154  for (std::size_t i = 0; i < small_m; ++i) {
155  U1[i] *= omega_inv_i;
156  omega_inv_i *= omega_inv;
157  }
158 
159  // compute A_prefix
160  const value_type over_two = value_type(2).inversed();
161  for (std::size_t i = 0; i < small_m; ++i) {
162  a[i] = (U0[i] + U1[i]) * over_two;
163  }
164 
165  // compute B2
166  for (std::size_t i = 0; i < small_m; ++i) {
167  a[big_m + i] = (U0[i] - U1[i]) * over_two;
168  }
169  }
170 
171  std::vector<value_type> evaluate_all_lagrange_polynomials(const value_type &t) {
172  std::vector<value_type> inner_big =
173  detail::basic_radix2_evaluate_all_lagrange_polynomials<FieldType>(big_m, t);
174  std::vector<value_type> inner_small =
175  detail::basic_radix2_evaluate_all_lagrange_polynomials<FieldType>(small_m,
176  t * omega.inversed());
177 
178  std::vector<value_type> result(this->m, value_type::zero());
179 
180  const value_type L0 = t.pow(small_m) - omega.pow(small_m);
181  const value_type omega_to_small_m = omega.pow(small_m);
182  const value_type big_omega_to_small_m = big_omega.pow(small_m);
183  value_type elt = value_type::one();
184  for (std::size_t i = 0; i < big_m; ++i) {
185  result[i] = inner_big[i] * L0 * (elt - omega_to_small_m).inversed();
186  elt *= big_omega_to_small_m;
187  }
188 
189  const value_type L1 =
190  (t.pow(big_m) - value_type::one()) * (omega.pow(big_m) - value_type::one()).inversed();
191 
192  for (std::size_t i = 0; i < small_m; ++i) {
193  result[big_m + i] = L1 * inner_small[i];
194  }
195 
196  return result;
197  }
198 
199  value_type get_domain_element(const std::size_t idx) {
200  if (idx < big_m) {
201  return big_omega.pow(idx);
202  } else {
203  return omega * (small_omega.pow(idx - big_m));
204  }
205  }
206 
207  value_type compute_vanishing_polynomial(const value_type &t) {
208  return (t.pow(big_m) - value_type::one()) * (t.pow(small_m) - omega.pow(small_m));
209  }
210 
211  void add_poly_z(const value_type &coeff, std::vector<value_type> &H) {
212  // if (H.size() != this->m + 1)
213  // throw std::invalid_argument("step_radix2: expected H.size() == this->m+1");
214 
215  const value_type omega_to_small_m = omega.pow(small_m);
216 
217  H[this->m] += coeff;
218  H[big_m] -= coeff * omega_to_small_m;
219  H[small_m] -= coeff;
220  H[0] += coeff * omega_to_small_m;
221  }
222  void divide_by_z_on_coset(std::vector<value_type> &P) {
223  // (c^{2^k}-1) * (c^{2^r} * w^{2^{r+1}*i) - w^{2^r})
225 
226  const value_type Z0 = coset.pow(big_m) - value_type::one();
227  const value_type coset_to_small_m_times_Z0 = coset.pow(small_m) * Z0;
228  const value_type omega_to_small_m_times_Z0 = omega.pow(small_m) * Z0;
229  const value_type omega_to_2small_m = omega.pow(2 * small_m);
230  value_type elt = value_type::one();
231 
232  for (std::size_t i = 0; i < big_m; ++i) {
233  P[i] *= (coset_to_small_m_times_Z0 * elt - omega_to_small_m_times_Z0).inversed();
234  elt *= omega_to_2small_m;
235  }
236 
237  // (c^{2^k}*w^{2^k}-1) * (c^{2^k} * w^{2^r} - w^{2^r})
238 
239  const value_type Z1 = (((coset * omega).pow(big_m) - value_type::one()) *
240  ((coset * omega).pow(small_m) - omega.pow(small_m)));
241  const value_type Z1_inverse = Z1.inversed();
242 
243  for (std::size_t i = 0; i < small_m; ++i) {
244  P[big_m + i] *= Z1_inverse;
245  }
246  }
247  };
248  } // namespace math
249  } // namespace crypto3
250 } // namespace nil
251 
#endif // CRYPTO3_MATH_STEP_RADIX2_DOMAIN_HPP
Definition: evaluation_domain.hpp:41
Definition: step_radix2_domain.hpp:45
value_type compute_vanishing_polynomial(const value_type &t)
Definition: step_radix2_domain.hpp:207
value_type omega
Definition: step_radix2_domain.hpp:53
FieldType field_type
Definition: step_radix2_domain.hpp:49
void fft(std::vector< value_type > &a)
Definition: step_radix2_domain.hpp:73
std::vector< value_type > evaluate_all_lagrange_polynomials(const value_type &t)
Definition: step_radix2_domain.hpp:171
void inverse_fft(std::vector< value_type > &a)
Definition: step_radix2_domain.hpp:112
value_type big_omega
Definition: step_radix2_domain.hpp:54
step_radix2_domain(const std::size_t m)
Definition: step_radix2_domain.hpp:57
std::size_t big_m
Definition: step_radix2_domain.hpp:51
void add_poly_z(const value_type &coeff, std::vector< value_type > &H)
Definition: step_radix2_domain.hpp:211
std::size_t small_m
Definition: step_radix2_domain.hpp:52
value_type get_domain_element(const std::size_t idx)
Definition: step_radix2_domain.hpp:199
value_type small_omega
Definition: step_radix2_domain.hpp:55
void divide_by_z_on_coset(std::vector< value_type > &P)
Definition: step_radix2_domain.hpp:222
Definition: pair.hpp:33
constexpr T pow(T x, U n)
Definition: static_pow.hpp:32
Definition: pair.hpp:31
Definition: fields/params.hpp:58