26 #ifndef CRYPTO3_MATH_EXTENDED_RADIX2_DOMAIN_HPP
27 #define CRYPTO3_MATH_EXTENDED_RADIX2_DOMAIN_HPP
// Forward declaration of the evaluation_domain class template, parameterised
// on the field type. Only the declaration is visible in this listing.
41 template<
typename FieldType>
42 class evaluation_domain;
// Template header for the declaration that starts on listing line 45; that
// line itself is not visible in this excerpt.
44 template<
typename FieldType>
// Shorthand for the element type exposed by the field (FieldType::value_type);
// the member functions in this listing are written in terms of it.
46 typedef typename FieldType::value_type value_type;
// Constructor fragment: the signature and the guarding conditions are not
// visible in this excerpt.
// Raises std::invalid_argument for a domain size that violates the "m > 1"
// requirement named in the message (NOTE(review): presumably guarded by
// m <= 1 - confirm against the full file).
57 throw std::invalid_argument(
"extended_radix2(): expected m > 1");
// The logm consistency check below is entered only when value_type is not
// std::complex<double>.
59 if (!std::is_same<value_type, std::complex<double>>::value) {
// logm = ceil(log2(m)), converted to std::size_t.
60 const std::size_t logm =
static_cast<std::size_t
>(std::ceil(std::log2(m)));
// Raised when logm does not satisfy the relation spelled out in the message;
// the comparison itself is on a line not visible here.
62 throw std::invalid_argument(
63 "extended_radix2(): expected logm == fields::arithmetic_params<FieldType>::s + 1");
// omega is the value unity_root returns for small_m. The assignment of
// small_m is not visible here (NOTE(review): presumably m / 2, since fft
// splits an m-sized vector into two small_m-sized halves - confirm).
68 omega = unity_root<FieldType>(small_m);
// shift is the value supplied by detail::coset_shift for this field.
70 shift = detail::coset_shift<FieldType>();
// Forward transform, performed in place on `a`.
//
// a: coefficient vector; overwritten with the transform output.
// Throws std::invalid_argument when the size check below fails.
//
// The input is split into two small_m-sized vectors, each is passed through
// _basic_radix2_fft with omega, and the results are written back into `a`.
// Several lines of the body (branch/brace lines and some loop statements)
// are not visible in this excerpt.
73 void fft(std::vector<value_type> &a) {
// Size handling: a shorter input is zero-padded up to m.
// NOTE(review): the line between the resize and the throw is not visible;
// presumably the throw is the else-branch for a.size() > m - confirm.
74 if (a.size() != this->m) {
75 if (a.size() < this->m) {
76 a.resize(this->m, value_type(0));
78 throw std::invalid_argument(
"extended_radix2: expected a.size() == this->m");
// Working buffers for the two sub-transforms, zero-initialised.
82 std::vector<value_type> a0(small_m, value_type::zero());
83 std::vector<value_type> a1(small_m, value_type::zero());
// shift^small_m, hoisted out of the loop.
85 const value_type shift_to_small_m = shift.pow(small_m);
// shift_i starts at one. Its per-iteration update is not visible here
// (NOTE(review): presumably multiplied by shift each step, mirroring the
// shift_inverse_i update in inverse_fft - confirm).
87 value_type shift_i = value_type::one();
88 for (std::size_t i = 0; i < small_m; ++i) {
// a0 combines the lower and upper halves of `a` directly.
89 a0[i] = a[i] + a[small_m + i];
// a1 weights the upper half by shift^small_m and scales the sum by shift_i.
90 a1[i] = shift_i * (a[i] + shift_to_small_m * a[small_m + i]);
// Transform both buffers with the same root omega.
95 _basic_radix2_fft<FieldType>(a0, omega);
96 _basic_radix2_fft<FieldType>(a1, omega);
// Copy-back loop: a1 fills the upper half of `a`. The statement for the lower
// half is not visible here (NOTE(review): presumably a[i] = a0[i] - confirm).
98 for (std::size_t i = 0; i < small_m; ++i) {
100 a[i + small_m] = a1[i];
// Body of inverse_fft(std::vector<value_type> &a); the signature line is not
// visible in this excerpt. Works in place on `a`.
//
// Size handling: a shorter input is zero-padded up to m.
// NOTE(review): the line between the resize and the throw is not visible;
// presumably the throw is the else-branch for a.size() > m - confirm.
105 if (a.size() != this->m) {
106 if (a.size() < this->m) {
107 a.resize(this->m, value_type(0));
109 throw std::invalid_argument(
"extended_radix2: expected a.size() == this->m");
// Split `a` into its first small_m entries (a0) and the remainder (a1).
114 std::vector<value_type> a0(a.begin(), a.begin() + small_m);
115 std::vector<value_type> a1(a.begin() + small_m, a.end());
// Both halves go through _basic_radix2_fft with the inverse of omega.
117 const value_type omega_inverse = omega.inversed();
118 _basic_radix2_fft<FieldType>(a0, omega_inverse);
119 _basic_radix2_fft<FieldType>(a1, omega_inverse);
// sconst = 1 / (small_m * (1 - shift^small_m)); it scales both outputs in the
// recombination loop below.
121 const value_type shift_to_small_m = shift.pow(small_m);
122 const value_type sconst = (value_type(small_m) * (value_type::one() - shift_to_small_m)).inversed();
// shift_inverse_i holds shift^(-i) for the current loop index, starting at one.
124 const value_type shift_inverse = shift.inversed();
125 value_type shift_inverse_i = value_type::one();
// Recombine the two transformed halves into the lower and upper halves of `a`.
127 for (std::size_t i = 0; i < small_m; ++i) {
128 a[i] = sconst * (-shift_to_small_m * a0[i] + shift_inverse_i * a1[i]);
129 a[i + small_m] = sconst * (a0[i] - shift_inverse_i * a1[i]);
// Advance to shift^(-(i + 1)) for the next iteration.
131 shift_inverse_i *= shift_inverse;
// Body of evaluate_all_lagrange_polynomials(const value_type &t); the
// signature and the return statement are not visible in this excerpt.
//
// T0: result of the basic radix-2 helper for size small_m at the point t.
136 const std::vector<value_type> T0 =
137 detail::basic_radix2_evaluate_all_lagrange_polynomials<FieldType>(small_m, t);
// T1: the same helper called at the point t * shift^(-1).
138 const std::vector<value_type> T1 =
139 detail::basic_radix2_evaluate_all_lagrange_polynomials<FieldType>(small_m,
140 t * shift.inversed());
// Output vector with m entries, zero-initialised.
142 std::vector<value_type> result(this->m, value_type::zero());
// With D = shift^small_m - 1:
//   T0_coeff = -(t^small_m - shift^small_m) / D
//   T1_coeff =  (t^small_m - 1) / D
144 const value_type t_to_small_m = t.pow(small_m);
145 const value_type shift_to_small_m = shift.pow(small_m);
146 const value_type one_over_denom = (shift_to_small_m - value_type::one()).inversed();
147 const value_type T0_coeff = (t_to_small_m - shift_to_small_m) * (-one_over_denom);
148 const value_type T1_coeff = (t_to_small_m - value_type::one()) * one_over_denom;
// The first small_m entries are T0 scaled by T0_coeff; the next small_m
// entries are T1 scaled by T1_coeff.
149 for (std::size_t i = 0; i < small_m; ++i) {
150 result[i] = T0[i] * T0_coeff;
151 result[i + small_m] = T1[i] * T1_coeff;
// The two return statements of get_domain_element(const std::size_t idx); the
// signature and the condition selecting between them are not visible here
// (NOTE(review): presumably idx < small_m picks the first - confirm).
// First form: omega^idx.
159 return omega.pow(idx);
// Second form: shift * omega^(idx - small_m).
161 return shift * (omega.pow(idx - small_m));
// Return statement of compute_vanishing_polynomial(const value_type &t); the
// signature is not visible here.
// Evaluates (t^small_m - 1) * (t^small_m - shift^small_m).
166 return (t.pow(small_m) - value_type::one()) * (t.pow(small_m) - shift.pow(small_m));
// Updates entries of H in place using coeff and shift^small_m.
//
// coeff: scalar multiplier applied to each update.
// H:     vector modified in place; it is indexed at 0 and small_m here.
//
// The two visible updates match coeff times the t^small_m and constant
// coefficients of the expanded product
// (t^small_m - 1) * (t^small_m - shift^small_m).
// Other lines of the body are not visible in this excerpt.
169 void add_poly_z(
const value_type &coeff, std::vector<value_type> &H) {
173 const value_type shift_to_small_m = shift.pow(small_m);
// t^small_m term: -(shift^small_m + 1), scaled by coeff.
176 H[small_m] -= coeff * (shift_to_small_m + value_type::one());
// Constant term: shift^small_m, scaled by coeff.
177 H[0] += coeff * shift_to_small_m;
// Body of divide_by_z_on_coset(std::vector<value_type> &P); the signature and
// the definition of `coset` are not visible in this excerpt. P is modified in
// place.
183 const value_type coset_to_small_m = coset.pow(small_m);
184 const value_type shift_to_small_m = shift.pow(small_m);
// Z0 = (c - 1) * (c - s), where c = coset^small_m and s = shift^small_m.
186 const value_type Z0 =
187 (coset_to_small_m - value_type::one()) * (coset_to_small_m - shift_to_small_m);
// Z1 has the same form as Z0 with c replaced by c * s.
188 const value_type Z1 = (coset_to_small_m * shift_to_small_m - value_type::one()) *
189 (coset_to_small_m * shift_to_small_m - shift_to_small_m);
// Invert once so the loop multiplies instead of dividing.
191 const value_type Z0_inverse = Z0.inversed();
192 const value_type Z1_inverse = Z1.inversed();
// The upper half of P is scaled by 1/Z1. The statement using Z0_inverse is not
// visible here (NOTE(review): presumably P[i] *= Z0_inverse - confirm).
194 for (std::size_t i = 0; i < small_m; ++i) {
196 P[i + small_m] *= Z1_inverse;
Definition: evaluation_domain.hpp:41
Definition: extended_radix2_domain.hpp:45
value_type omega
Definition: extended_radix2_domain.hpp:52
value_type shift
Definition: extended_radix2_domain.hpp:53
value_type get_domain_element(const std::size_t idx)
Definition: extended_radix2_domain.hpp:157
extended_radix2_domain(const std::size_t m)
Definition: extended_radix2_domain.hpp:55
void divide_by_z_on_coset(std::vector< value_type > &P)
Definition: extended_radix2_domain.hpp:180
void add_poly_z(const value_type &coeff, std::vector< value_type > &H)
Definition: extended_radix2_domain.hpp:169
value_type compute_vanishing_polynomial(const value_type &t)
Definition: extended_radix2_domain.hpp:165
std::size_t small_m
Definition: extended_radix2_domain.hpp:51
std::vector< value_type > evaluate_all_lagrange_polynomials(const value_type &t)
Definition: extended_radix2_domain.hpp:135
void fft(std::vector< value_type > &a)
Definition: extended_radix2_domain.hpp:73
FieldType field_type
Definition: extended_radix2_domain.hpp:49
void inverse_fft(std::vector< value_type > &a)
Definition: extended_radix2_domain.hpp:104
Definition: fields/params.hpp:58