P6800 【模板】Chirp-Z Transform

(6 mins to read)

给出一个n项多项式P,要求出在$c^0,c^1,…,c^{m-1}$处多项式的值$n\leq 10^6,m\leq 10^6$

做法

直接多点求值会t$\sum \limits_{i=0}^{n-1}a_ic^{mi}P(cm)$考虑把$mi$换一下$mi = \binom{m+i}{2} - \binom{m}{2} -\binom{i}{2}$$P(c^m) = c^{-\binom{m}{2}}\sum\limits_{i=0}^{n-1}(a_ic^{-\binom{i}{2}})(c^{\binom{m+i}{2}})$减法卷积,把前一个翻转后NTT即可注意指数上的取模要用扩展欧拉定理如果把c换成$\omega_n^1$​即$complex(cos(2\pi/n),sin(2\pi/n))$,就是bluestein算法了,可以解决任意长度的卷积,就是常数很大,因为一次DFT就相当于一次多项式乘法,即三次DFT小优化:由于我们只需要得到n到n+m项的值,所以n次多项式和n+m次多项式我们只需要求到n+m即可,而不用求到2*n+m

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include <bits/stdc++.h>
using namespace std;
const int N = 6e6 + 5;
const int mod = 998244353, G = 3;
int n, c, m, a[N], b[N];
int up, w[N], rev[N];
int fpw(int a, int b)
{
int ans = 1;
while(b)
{
if(b&1) ans = 1ll*ans*a%mod;
a = 1ll*a*a%mod;
b >>= 1;
}
return ans;
}
void init(int n)
{
up = 1; int l = 0;
while(up<=n) up <<= 1, l++;
for(int i=0; i<up; i++) rev[i] = (rev[i>>1]>>1)|((i&1)<<(l-1));
int wn = fpw(G, mod>>l); w[up>>1] = 1;
for(int i=(up>>1)+1; i<up; i++) w[i] = 1ll*w[i-1]*wn%mod;
for(int i=(up>>1)-1; i>=1; i--) w[i] = w[i<<1];
}
int getlen(int n) { return 1<<(32-__builtin_clz(n)); }
inline void mul(int *a, int n, int x, int *b) { while(n--) *b++ = 1ll**a++*x%mod; }
inline void dot(int *a, int *b, int n, int *c) { while(n--) *c++ = 1ll**a++**b++%mod; }
void DFT(int *a, int l)
{
static unsigned long long tmp[N];
int u = __builtin_ctz(up/l), t;
for(int i=0; i<l; i++) tmp[i] = a[rev[i]>>u];
for(int i=1; i^l; i<<=1)
for(int j=0, d=i<<1; j^l; j+=d)
for(int k=0; k<i; k++)
t = tmp[i|j|k]*w[i|k]%mod, tmp[i|j|k] = tmp[j|k]+mod-t, tmp[j|k] += t;
for(int i=0; i<l; i++) a[i] = tmp[i]%mod;
}
void IDFT(int *a, int l)
{
reverse(a+1, a+l); DFT(a, l);
mul(a, l, mod-mod/l, a);
}
int main()
{
scanf("%d%d%d", &n, &c, &m);
--n, --m;
for(int i=0; i<=n; i++)
{
scanf("%d", a+i);
a[i] = 1ll*a[i]*fpw(c, (mod-1)-1ll*i*(i-1)/2%(mod-1))%mod;
}
for(int i=0; i<=n+m; i++) b[i] = fpw(c, 1ll*i*(i-1)/2%(mod-1));
init(2*n+m);
reverse(a, a+n+1);
int l = getlen(2*n+m);
DFT(a, l); DFT(b, l); dot(a, b, l, a); IDFT(a, l);
for(int i=n; i<=n+m; i++) printf("%d%c", 1ll*a[i]*fpw(c, (mod-1)-1ll*(i-n)*(i-n-1)/2%(mod-1))%mod, " \n"[i==n+m]);
return 0;
}