extended_radix2_domain.hpp
Go to the documentation of this file.
1 //---------------------------------------------------------------------------//
2 // Copyright (c) 2020-2021 Mikhail Komarov <nemo@nil.foundation>
3 // Copyright (c) 2020-2021 Nikita Kaskov <nbering@nil.foundation>
4 //
5 // MIT License
6 //
7 // Permission is hereby granted, free of charge, to any person obtaining a copy
8 // of this software and associated documentation files (the "Software"), to deal
9 // in the Software without restriction, including without limitation the rights
10 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 // copies of the Software, and to permit persons to whom the Software is
12 // furnished to do so, subject to the following conditions:
13 //
14 // The above copyright notice and this permission notice shall be included in all
15 // copies or substantial portions of the Software.
16 //
17 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 // SOFTWARE.
24 //---------------------------------------------------------------------------//
25 
26 #ifndef CRYPTO3_MATH_EXTENDED_RADIX2_DOMAIN_HPP
27 #define CRYPTO3_MATH_EXTENDED_RADIX2_DOMAIN_HPP
28 
29 #include <vector>
30 
34 
35 namespace nil {
36  namespace crypto3 {
37  namespace math {
38 
// NOTE(review): this using-directive sits at namespace scope inside a header, so every translation
// unit that includes this file also sees the nil::crypto3::algebra names from within
// nil::crypto3::math. The class below relies on it for unqualified names such as `fields::`.
39  using namespace nil::crypto3::algebra;
40 
// Forward declaration of the base class template; extended_radix2_domain derives from it below.
41  template<typename FieldType>
42  class evaluation_domain;
43 
44  template<typename FieldType>
45  class extended_radix2_domain : public evaluation_domain<FieldType> {
46  typedef typename FieldType::value_type value_type;
47 
48  public:
49  typedef FieldType field_type;
50 
51  std::size_t small_m;
52  value_type omega;
53  value_type shift;
54 
55  extended_radix2_domain(const std::size_t m) : evaluation_domain<FieldType>(m) {
56  if (m <= 1)
57  throw std::invalid_argument("extended_radix2(): expected m > 1");
58 
59  if (!std::is_same<value_type, std::complex<double>>::value) {
60  const std::size_t logm = static_cast<std::size_t>(std::ceil(std::log2(m)));
62  throw std::invalid_argument(
63  "extended_radix2(): expected logm == fields::arithmetic_params<FieldType>::s + 1");
64  }
65 
66  small_m = m / 2;
67 
68  omega = unity_root<FieldType>(small_m);
69 
70  shift = detail::coset_shift<FieldType>();
71  }
72 
73  void fft(std::vector<value_type> &a) {
74  if (a.size() != this->m) {
75  if (a.size() < this->m) {
76  a.resize(this->m, value_type(0));
77  } else {
78  throw std::invalid_argument("extended_radix2: expected a.size() == this->m");
79  }
80  }
81 
82  std::vector<value_type> a0(small_m, value_type::zero());
83  std::vector<value_type> a1(small_m, value_type::zero());
84 
85  const value_type shift_to_small_m = shift.pow(small_m);
86 
87  value_type shift_i = value_type::one();
88  for (std::size_t i = 0; i < small_m; ++i) {
89  a0[i] = a[i] + a[small_m + i];
90  a1[i] = shift_i * (a[i] + shift_to_small_m * a[small_m + i]);
91 
92  shift_i *= shift;
93  }
94 
95  _basic_radix2_fft<FieldType>(a0, omega);
96  _basic_radix2_fft<FieldType>(a1, omega);
97 
98  for (std::size_t i = 0; i < small_m; ++i) {
99  a[i] = a0[i];
100  a[i + small_m] = a1[i];
101  }
102  }
103 
104  void inverse_fft(std::vector<value_type> &a) {
105  if (a.size() != this->m) {
106  if (a.size() < this->m) {
107  a.resize(this->m, value_type(0));
108  } else {
109  throw std::invalid_argument("extended_radix2: expected a.size() == this->m");
110  }
111  }
112 
113  // note: this is not in-place
114  std::vector<value_type> a0(a.begin(), a.begin() + small_m);
115  std::vector<value_type> a1(a.begin() + small_m, a.end());
116 
117  const value_type omega_inverse = omega.inversed();
118  _basic_radix2_fft<FieldType>(a0, omega_inverse);
119  _basic_radix2_fft<FieldType>(a1, omega_inverse);
120 
121  const value_type shift_to_small_m = shift.pow(small_m);
122  const value_type sconst = (value_type(small_m) * (value_type::one() - shift_to_small_m)).inversed();
123 
124  const value_type shift_inverse = shift.inversed();
125  value_type shift_inverse_i = value_type::one();
126 
127  for (std::size_t i = 0; i < small_m; ++i) {
128  a[i] = sconst * (-shift_to_small_m * a0[i] + shift_inverse_i * a1[i]);
129  a[i + small_m] = sconst * (a0[i] - shift_inverse_i * a1[i]);
130 
131  shift_inverse_i *= shift_inverse;
132  }
133  }
134 
135  std::vector<value_type> evaluate_all_lagrange_polynomials(const value_type &t) {
136  const std::vector<value_type> T0 =
137  detail::basic_radix2_evaluate_all_lagrange_polynomials<FieldType>(small_m, t);
138  const std::vector<value_type> T1 =
139  detail::basic_radix2_evaluate_all_lagrange_polynomials<FieldType>(small_m,
140  t * shift.inversed());
141 
142  std::vector<value_type> result(this->m, value_type::zero());
143 
144  const value_type t_to_small_m = t.pow(small_m);
145  const value_type shift_to_small_m = shift.pow(small_m);
146  const value_type one_over_denom = (shift_to_small_m - value_type::one()).inversed();
147  const value_type T0_coeff = (t_to_small_m - shift_to_small_m) * (-one_over_denom);
148  const value_type T1_coeff = (t_to_small_m - value_type::one()) * one_over_denom;
149  for (std::size_t i = 0; i < small_m; ++i) {
150  result[i] = T0[i] * T0_coeff;
151  result[i + small_m] = T1[i] * T1_coeff;
152  }
153 
154  return result;
155  }
156 
157  value_type get_domain_element(const std::size_t idx) {
158  if (idx < small_m) {
159  return omega.pow(idx);
160  } else {
161  return shift * (omega.pow(idx - small_m));
162  }
163  }
164 
165  value_type compute_vanishing_polynomial(const value_type &t) {
166  return (t.pow(small_m) - value_type::one()) * (t.pow(small_m) - shift.pow(small_m));
167  }
168 
169  void add_poly_z(const value_type &coeff, std::vector<value_type> &H) {
170  // if (H.size() != this->m + 1)
171  // throw std::invalid_argument("extended_radix2: expected H.size() == this->m+1");
172 
173  const value_type shift_to_small_m = shift.pow(small_m);
174 
175  H[this->m] += coeff;
176  H[small_m] -= coeff * (shift_to_small_m + value_type::one());
177  H[0] += coeff * shift_to_small_m;
178  }
179 
180  void divide_by_z_on_coset(std::vector<value_type> &P) {
182 
183  const value_type coset_to_small_m = coset.pow(small_m);
184  const value_type shift_to_small_m = shift.pow(small_m);
185 
186  const value_type Z0 =
187  (coset_to_small_m - value_type::one()) * (coset_to_small_m - shift_to_small_m);
188  const value_type Z1 = (coset_to_small_m * shift_to_small_m - value_type::one()) *
189  (coset_to_small_m * shift_to_small_m - shift_to_small_m);
190 
191  const value_type Z0_inverse = Z0.inversed();
192  const value_type Z1_inverse = Z1.inversed();
193 
194  for (std::size_t i = 0; i < small_m; ++i) {
195  P[i] *= Z0_inverse;
196  P[i + small_m] *= Z1_inverse;
197  }
198  }
199  };
200  } // namespace math
201  } // namespace crypto3
202 } // namespace nil
203 
204 #endif // CRYPTO3_MATH_EXTENDED_RADIX2_DOMAIN_HPP
Definition: evaluation_domain.hpp:41
Definition: extended_radix2_domain.hpp:45
value_type omega
Definition: extended_radix2_domain.hpp:52
value_type shift
Definition: extended_radix2_domain.hpp:53
value_type get_domain_element(const std::size_t idx)
Definition: extended_radix2_domain.hpp:157
extended_radix2_domain(const std::size_t m)
Definition: extended_radix2_domain.hpp:55
void divide_by_z_on_coset(std::vector< value_type > &P)
Definition: extended_radix2_domain.hpp:180
void add_poly_z(const value_type &coeff, std::vector< value_type > &H)
Definition: extended_radix2_domain.hpp:169
value_type compute_vanishing_polynomial(const value_type &t)
Definition: extended_radix2_domain.hpp:165
std::size_t small_m
Definition: extended_radix2_domain.hpp:51
std::vector< value_type > evaluate_all_lagrange_polynomials(const value_type &t)
Definition: extended_radix2_domain.hpp:135
void fft(std::vector< value_type > &a)
Definition: extended_radix2_domain.hpp:73
FieldType field_type
Definition: extended_radix2_domain.hpp:49
void inverse_fft(std::vector< value_type > &a)
Definition: extended_radix2_domain.hpp:104
Definition: pair.hpp:33
Definition: pair.hpp:31
Definition: fields/params.hpp:58