26 #ifndef CRYPTO3_MATH_STEP_RADIX2_DOMAIN_HPP
27 #define CRYPTO3_MATH_STEP_RADIX2_DOMAIN_HPP
// Forward declaration of the evaluation_domain class template, parameterized
// on the field type. Its definition is not part of this excerpt.
41 template<
typename FieldType>
42 class evaluation_domain;
// Class template over FieldType; the class head itself (and the closing brace)
// are not part of this excerpt.
// value_type is the field element type; it is used by the members below for
// roots of unity, evaluation points and coefficient/evaluation vectors.
44 template<
typename FieldType>
46 typedef typename FieldType::value_type value_type;
// Constructor body fragment for step_radix2_domain(m). The signature and the
// condition guarding this throw are not visible in this excerpt; the message
// states that m > 1 is required.
59 throw std::invalid_argument(
"step_radix2(): expected m > 1");
// big_m = 2^(ceil(log2(m)) - 1): for m > 1 this is the largest power of two
// strictly smaller than m.
61 big_m = 1ul << (
static_cast<std::size_t
>(std::ceil(std::log2(m))) - 1);
// small_m is assigned on a line not visible here (presumably m - big_m --
// confirm). The domain is only accepted when small_m is itself a power of two.
// NOTE(review): the power-of-two logic relies on floating-point
// std::log2/std::ceil; exact for realistic sizes, but an integer bit-width
// computation would not depend on double rounding for very large values.
64 if (small_m != 1ul <<
static_cast<std::size_t
>(std::ceil(std::log2(small_m))))
65 throw std::invalid_argument(
"step_radix2(): expected small_m == 1ul<<log2(small_m)");
// omega: root of unity of order 2^ceil(log2(m)) (the smallest power of two
// >= m), as returned by unity_root -- presumably primitive; confirm.
67 omega = unity_root<FieldType>(1ul <<
static_cast<std::size_t
>(std::ceil(std::log2(m))));
// big_omega = omega^2; its powers give the first big_m domain elements
// (see get_domain_element).
69 big_omega = omega.squared();
// small_omega: root of unity of order small_m; its powers, shifted by omega,
// give the remaining domain elements.
70 small_omega = unity_root<FieldType>(small_m);
// Forward transform over the step-radix-2 domain, performed in place on `a`.
// Inputs shorter than m are zero-padded to m. The throw below presumably
// handles longer inputs (the else line between them is not visible here).
73 void fft(std::vector<value_type> &a) {
74 if (a.size() != this->m) {
75 if (a.size() < this->m) {
76 a.resize(this->m, value_type(0));
78 throw std::invalid_argument(
"step_radix2: expected a.size() == this->m");
// Split step over the first big_m positions:
//   c[i] = a[i] + a[i + big_m] for i < small_m, else a[i];
//   d[i] = omega_i * (a[i] - a[i + big_m]) for i < small_m, else omega_i * a[i].
// The omega_i update is on a line not visible here (presumably
// omega_i *= omega -- confirm).
82 std::vector<value_type> c(big_m, value_type::zero());
83 std::vector<value_type> d(big_m, value_type::zero());
85 value_type omega_i = value_type::one();
86 for (std::size_t i = 0; i < big_m; ++i) {
87 c[i] = (i < small_m ? a[i] + a[i + big_m] : a[i]);
88 d[i] = omega_i * (i < small_m ? a[i] - a[i + big_m] : a[i]);
// Fold d (length big_m) down to small_m entries: e[i] is the sum of d at
// indices i, i + small_m, i + 2*small_m, ...
// compr = 2^(ceil(log2 big_m) - ceil(log2 small_m)) is the number of terms per
// entry, i.e. big_m / small_m since both are powers of two.
92 std::vector<value_type> e(small_m, value_type::zero());
93 const std::size_t compr = 1ul << (
static_cast<std::size_t
>(std::ceil(std::log2(big_m))) -
94 static_cast<std::size_t
>(std::ceil(std::log2(small_m))));
95 for (std::size_t i = 0; i < small_m; ++i) {
96 for (std::size_t j = 0; j < compr; ++j) {
97 e[i] += d[i + j * small_m];
// Radix-2 transforms of both parts. omega.squared() is the same value the
// constructor stores in big_omega, and unity_root(small_m) is the same call
// that initialises small_omega.
101 _basic_radix2_fft<FieldType>(c, omega.squared());
102 _basic_radix2_fft<FieldType>(e, unity_root<FieldType>(small_m));
// Copy-back loops; their bodies are not visible in this excerpt (presumably c
// into a[0, big_m) and e into a[big_m, big_m + small_m) -- confirm).
104 for (std::size_t i = 0; i < big_m; ++i) {
108 for (std::size_t i = 0; i < small_m; ++i) {
// inverse_fft body fragment (signature not visible in this excerpt); operates
// in place on `a`.
// Unlike fft, no zero-padding is done: a must hold exactly m entries.
113 if (a.size() != this->m)
114 throw std::invalid_argument(
"step_radix2: expected a.size() == this->m");
// Split the input: U0 = first big_m values, U1 = all remaining values.
116 std::vector<value_type> U0(a.begin(), a.begin() + big_m);
117 std::vector<value_type> U1(a.begin() + big_m, a.end());
// Transform both parts with the inverses of the roots fft uses, then scale
// U0 by 1/big_m and U1 by 1/small_m.
119 _basic_radix2_fft<FieldType>(U0, omega.squared().inversed());
120 _basic_radix2_fft<FieldType>(U1, unity_root<FieldType>(small_m).inversed());
122 const value_type U0_size_inv = value_type(big_m).inversed();
123 for (std::size_t i = 0; i < big_m; ++i) {
124 U0[i] *= U0_size_inv;
127 const value_type U1_size_inv = value_type(small_m).inversed();
128 for (std::size_t i = 0; i < small_m; ++i) {
129 U1[i] *= U1_size_inv;
// tmp starts as a copy of U0. The bodies of the two loops that follow are not
// visible in this excerpt.
132 std::vector<value_type> tmp = U0;
133 value_type omega_i = value_type::one();
134 for (std::size_t i = 0; i < big_m; ++i) {
140 for (std::size_t i = small_m; i < big_m; ++i) {
// For each i < small_m, subtract from U1[i] the tmp entries at i + j*small_m
// for j = 1 .. compr-1 (compr computed the same way as in fft).
144 const std::size_t compr = 1ul << (
static_cast<std::size_t
>(std::ceil(std::log2(big_m))) -
145 static_cast<std::size_t
>(std::ceil(std::log2(small_m))));
146 for (std::size_t i = 0; i < small_m; ++i) {
147 for (std::size_t j = 1; j < compr; ++j) {
148 U1[i] -= tmp[i + j * small_m];
// Multiply U1[i] by omega^(-i).
152 const value_type omega_inv = omega.inversed();
153 value_type omega_inv_i = value_type::one();
154 for (std::size_t i = 0; i < small_m; ++i) {
155 U1[i] *= omega_inv_i;
156 omega_inv_i *= omega_inv;
// Recombine for i < small_m: a[i] = (U0[i] + U1[i]) / 2 and
// a[big_m + i] = (U0[i] - U1[i]) / 2. The loop between these two is not
// visible in this excerpt.
160 const value_type over_two = value_type(2).inversed();
161 for (std::size_t i = 0; i < small_m; ++i) {
162 a[i] = (U0[i] + U1[i]) * over_two;
166 for (std::size_t i = 0; i < small_m; ++i) {
167 a[big_m + i] = (U0[i] - U1[i]) * over_two;
// evaluate_all_lagrange_polynomials(t) body fragment (signature and return not
// visible in this excerpt). It fills `result` (m entries) from two radix-2
// Lagrange evaluations:
//   inner_big   -- size big_m, evaluated at t;
//   inner_small -- size small_m, evaluated at t * omega^(-1).
172 std::vector<value_type> inner_big =
173 detail::basic_radix2_evaluate_all_lagrange_polynomials<FieldType>(big_m, t);
174 std::vector<value_type> inner_small =
175 detail::basic_radix2_evaluate_all_lagrange_polynomials<FieldType>(small_m,
176 t * omega.inversed());
178 std::vector<value_type> result(this->m, value_type::zero());
// First big_m entries:
//   result[i] = inner_big[i] * (t^small_m - omega^small_m)
//               / (big_omega^(i*small_m) - omega^small_m),
// where elt tracks big_omega^(i*small_m) across iterations.
180 const value_type L0 = t.pow(small_m) - omega.pow(small_m);
181 const value_type omega_to_small_m = omega.pow(small_m);
182 const value_type big_omega_to_small_m = big_omega.pow(small_m);
183 value_type elt = value_type::one();
184 for (std::size_t i = 0; i < big_m; ++i) {
185 result[i] = inner_big[i] * L0 * (elt - omega_to_small_m).inversed();
186 elt *= big_omega_to_small_m;
// Last small_m entries: inner_small[i] scaled by the common factor
// L1 = (t^big_m - 1) / (omega^big_m - 1).
189 const value_type L1 =
190 (t.pow(big_m) - value_type::one()) * (omega.pow(big_m) - value_type::one()).inversed();
192 for (std::size_t i = 0; i < small_m; ++i) {
193 result[big_m + i] = L1 * inner_small[i];
// get_domain_element(idx) fragment; the branch condition is not visible here.
// One branch returns big_omega^idx, the other omega * small_omega^(idx - big_m)
// -- presumably for idx < big_m and idx >= big_m respectively; confirm.
201 return big_omega.pow(idx);
203 return omega * (small_omega.pow(idx - big_m));
// compute_vanishing_polynomial(t) fragment: returns
// Z(t) = (t^big_m - 1) * (t^small_m - omega^small_m).
208 return (t.pow(big_m) - value_type::one()) * (t.pow(small_m) - omega.pow(small_m));
// add_poly_z(coeff, H) fragment: updates the coefficient vector H in place.
// Presumably adds coeff * Z(X) with Z(X) = (X^big_m - 1)(X^small_m - omega^small_m)
// (the polynomial evaluated by compute_vanishing_polynomial) -- confirm.
// Only two updates are visible in this excerpt. They match the X^big_m
// coefficient (-omega^small_m) and the constant coefficient (+omega^small_m)
// of that product. The remaining updates are on lines not shown.
211 void add_poly_z(
const value_type &coeff, std::vector<value_type> &H) {
215 const value_type omega_to_small_m = omega.pow(small_m);
218 H[big_m] -= coeff * omega_to_small_m;
220 H[0] += coeff * omega_to_small_m;
// divide_by_z_on_coset(P) fragment; `coset` is defined on a line not visible
// in this excerpt. Each entry of P is multiplied by the inverse of a divisor,
// presumably Z evaluated at the coset-shifted domain points -- confirm.
// First big_m entries: divisor Z0 * (coset^small_m * elt - omega^small_m),
// with Z0 = coset^big_m - 1 and elt = omega^(2*small_m*i) at iteration i.
226 const value_type Z0 = coset.pow(big_m) - value_type::one();
227 const value_type coset_to_small_m_times_Z0 = coset.pow(small_m) * Z0;
228 const value_type omega_to_small_m_times_Z0 = omega.pow(small_m) * Z0;
229 const value_type omega_to_2small_m = omega.pow(2 * small_m);
230 value_type elt = value_type::one();
232 for (std::size_t i = 0; i < big_m; ++i) {
233 P[i] *= (coset_to_small_m_times_Z0 * elt - omega_to_small_m_times_Z0).inversed();
234 elt *= omega_to_2small_m;
// Remaining small_m entries share a single divisor
// Z1 = ((coset*omega)^big_m - 1) * ((coset*omega)^small_m - omega^small_m),
// so its inverse is computed once and reused.
239 const value_type Z1 = (((coset * omega).
pow(big_m) - value_type::one()) *
240 ((coset * omega).pow(small_m) - omega.pow(small_m)));
241 const value_type Z1_inverse = Z1.inversed();
243 for (std::size_t i = 0; i < small_m; ++i) {
244 P[big_m + i] *= Z1_inverse;
Definition: evaluation_domain.hpp:41
Definition: step_radix2_domain.hpp:45
value_type compute_vanishing_polynomial(const value_type &t)
Definition: step_radix2_domain.hpp:207
value_type omega
Definition: step_radix2_domain.hpp:53
FieldType field_type
Definition: step_radix2_domain.hpp:49
void fft(std::vector< value_type > &a)
Definition: step_radix2_domain.hpp:73
std::vector< value_type > evaluate_all_lagrange_polynomials(const value_type &t)
Definition: step_radix2_domain.hpp:171
void inverse_fft(std::vector< value_type > &a)
Definition: step_radix2_domain.hpp:112
value_type big_omega
Definition: step_radix2_domain.hpp:54
step_radix2_domain(const std::size_t m)
Definition: step_radix2_domain.hpp:57
std::size_t big_m
Definition: step_radix2_domain.hpp:51
void add_poly_z(const value_type &coeff, std::vector< value_type > &H)
Definition: step_radix2_domain.hpp:211
std::size_t small_m
Definition: step_radix2_domain.hpp:52
value_type get_domain_element(const std::size_t idx)
Definition: step_radix2_domain.hpp:199
value_type small_omega
Definition: step_radix2_domain.hpp:55
void divide_by_z_on_coset(std::vector< value_type > &P)
Definition: step_radix2_domain.hpp:222
constexpr T pow(T x, U n)
Definition: static_pow.hpp:32
Definition: fields/params.hpp:58