牛客挑战赛41 D 买糖果（折半 + 多点求值）

(11 mins to read)

给定 $n$ 个数，求所有非空子集的元素和的乘积，结果对 $998244353$ 取模，其中 $n \leq 32$。

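考虑折半：把数组分成前 $L$ 个数和后 $R$ 个数两部分，记 $c_i$ 为前半部分第 $i$ 个非空子集的和、$d_j$ 为后半部分第 $j$ 个非空子集的和。两部分内部各自的贡献 $\prod c_i$ 与 $\prod d_j$ 很容易用状压求出；两部分共同的贡献为 $\prod\limits_{i=1}^{2^L-1} \prod\limits_{j=1}^{2^R-1} (c_i + d_j)$。令 $f(x) = \prod\limits_{j=1}^{2^R-1}(x+d_j)$，显然可以用分治 NTT 求出该多项式，然后只要对每个 $c_i$ 求出 $f(c_i)$ 再全部乘起来即可，这一步直接套多点求值的板子。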

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#include <bits/stdc++.h>
using namespace std;
// Modulus for all arithmetic: 998244353 = 119 * 2^23 + 1, so it supports power-of-two NTT lengths up to 2^23.
const int mod = 998244353;
// Root used to derive the NTT roots of unity (3 is a primitive root of 998244353).
const int G = 3;
// Table capacity: large enough for transforms of length up to 2^17.
const int N = (1 << 17) + 5;
// up: transform length chosen by poly::init.
int up;
// w: roots-of-unity table; rev: bit-reversal permutation; inv: modular inverses (all filled by poly::init).
int w[N];
int rev[N];
int inv[N];

// Returns base^exp modulo `mod` using square-and-multiply.
int fpw(int base, int exp)
{
    long long result = 1;
    long long cur = base;
    for(; exp != 0; exp >>= 1, cur = cur * cur % mod)
    {
        if(exp & 1)
            result = result * cur % mod;
    }
    return static_cast<int>(result);
}
// Polynomial toolkit over Z/mod: NTT, power-series inverse and multipoint evaluation.
// Coefficient arrays are stored lowest degree first; a polynomial "of degree n" occupies indices 0..n.
namespace poly
{
// Prepares the shared tables for transforms of length up to `up`, the smallest power of two > n:
// inv[0..n] = modular inverses (not read elsewhere in the visible code), rev = bit-reversal permutation
// of length up, and w[h + k] = (primitive (2h)-th root of unity)^k for every power of two h < up, 0 <= k < h.
void init(int n)
{
inv[0] = inv[1] = 1;
// Linear-time inverses: i^{-1} = -(mod / i) * (mod % i)^{-1}  (mod mod).
for(int i=2; i<=n; i++) inv[i] = 1ll*(mod-mod/i)*inv[mod%i]%mod;
up = 1; int l = 0;
while(up<=n) up <<= 1, l++;
for(int i=0; i<up; i++) rev[i] = (rev[i>>1]>>1)|((i&1)<<(l-1));
// up divides mod-1, so mod>>l == (mod-1)/up and wn is a primitive up-th root of unity.
int wn = fpw(G, mod>>l); w[up>>1] = 1;
for(int i=(up>>1)+1; i<up; i++) w[i] = 1ll*w[i-1]*wn%mod;
// Each smaller level takes every other root of the level above it: w[i] = w[2i].
for(int i=(up>>1)-1; i>=1; i--) w[i] = w[i<<1];
}
// Zeroes n ints starting at a (n<<2 bytes, i.e. assumes a 4-byte int).
void clear(int *a, int n) { memset(a, 0, n<<2); }
// Smallest power of two strictly greater than n; n must be positive (__builtin_clz(0) is undefined).
int getlen(int n) { return 1<<(32-__builtin_clz(n)); }
// b[i] = a[i] * x mod `mod` for i in [0, n); b may be the same array as a.
inline void mul(int *a, int n, int x, int *b) { while(n--) *b++ = 1ll**a++*x%mod; }
// c[i] = a[i] * b[i] mod `mod` for i in [0, n); c may be the same array as a.
inline void dot(int *a, int *b, int n, int *c) { while(n--) *c++ = 1ll**a++**b++%mod; }
// In-place forward NTT of length l (a power of two, l <= up) using the tables built by init().
// Butterflies accumulate in unsigned long long and values are fully reduced only at the end.
// NOTE(review): this lazy reduction assumes inputs in [0, mod) and lengths small enough (about 2^17 here)
// that the 64-bit products cannot overflow.
void DFT(int *a, int l)
{
static unsigned long long tmp[N];
// rev was built for length up; shifting right by log2(up/l) yields the bit reversal for length l.
int u = __builtin_ctz(up/l), t;
for(int i=0; i<l; i++) tmp[i] = a[rev[i]>>u];
// i = half-size of the current butterfly; j steps over blocks of size 2i; since j is a multiple of 2i
// and k < i, i|j|k == i+j+k and w[i|k] == w[i+k] is the k-th power of the (2i)-th root of unity.
for(int i=1; i^l; i<<=1)
for(int j=0, d=i<<1; j^l; j+=d)
for(int k=0; k<i; k++)
t = tmp[i|j|k]*w[i|k]%mod, tmp[i|j|k] = tmp[j|k]+mod-t, tmp[j|k] += t;
for(int i=0; i<l; i++) a[i] = tmp[i]%mod;
}
// Inverse NTT of length l. Reversing a[1..l-1] turns the forward transform into the inverse up to a
// factor of l; mod - mod/l equals l^{-1} (mod mod) whenever l divides mod - 1.
void IDFT(int *a, int l)
{
reverse(a+1, a+l); DFT(a, l);
mul(a, l, mod-mod/l, a);
}
// Cyclic convolution of length l: a <- a * b (mod x^l - 1). b is left holding its own transform.
inline void conv(int *a, int *b, int l) { DFT(a, l); DFT(b, l); dot(a, b, l, a); IDFT(a, l); }
// Power-series inverse: b = a^{-1} mod x^{n+1}, computed by Newton iteration b <- b * (2 - a*b).
// Requires a[0] != 0. The base case writes only b[0] and later steps transform b over a longer length,
// so entries of b above the current precision must be zero on entry; each step re-zeroes b[n+1 .. l-1].
void Inv(const int *a, int *b, int n)
{
static int c[N], l;
if(n==0) { b[0] = fpw(a[0], mod-2); return; }
// l > 2n, so the degree-2n product c*b*b below does not wrap around.
Inv(a, b, n>>1); l = getlen(n<<1);
for(int i=0; i<=n; i++) c[i] = a[i];
for(int i=n+1; i<l; i++) c[i] = 0;
DFT(c, l); DFT(b, l);
for(int i=0; i<l; i++) b[i] = (2ll-1ll*c[i]*b[i]%mod+mod)%mod*b[i]%mod;
IDFT(b, l);
for(int i=n+1; i<l; i++) b[i] = 0;
}
// Subproduct tree for multipoint evaluation, numbered like a segment tree (children 2p and 2p+1).
// g[p] = prod (x - point) over the node's range; f[p] = the node's intermediate coefficients.
// Both are carved out of buf by bumping np.
int *f[N], *g[N], buf[N<<5], *np(buf);
// Multiplies a (degree n) by b (degree m) with a cyclic convolution of length getlen(deg) and copies
// coefficients st..deg of the result into c[0..deg-st].
// With deg >= n+m this is the ordinary product. With deg < n+m it acts as a middle product: terms past
// the cyclic length wrap to indices below n+m-deg, so the copied range is exact when st >= n+m-deg
// (true at the calls in eval_work). Requires n, m < getlen(deg).
// The static scratch arrays A and B are zeroed again before returning.
void mul(int *a, int n, int *b, int m, int *c, int deg, int st)
{
static int A[N], B[N], l;
l = getlen(deg), copy(a, a+n+1, A), copy(b, b+m+1, B);
conv(A, B, l); copy(A+st, A+deg+1, c);
clear(A, l), clear(B, l);
}
// Builds the subproduct tree over points a[l..r]: g[p] = prod_{i=l}^{r} (x - a[i]) (degree r-l+1).
// Also reserves r-l+2 ints for f[p]. (mod - a[l]) % mod assumes every point lies in [0, mod].
void eval_init(int p, int l, int r, int *a)
{
g[p] = np, np += r-l+2, f[p] = np, np += r-l+2;
if(l==r) { g[p][0] = (mod-a[l])%mod, g[p][1] = 1; return; }
int lc = p<<1, rc = lc|1, mid = (l+r)>>1, up1 = mid-l+1, up2 = r-mid;
eval_init(lc, l, mid, a); eval_init(rc, mid+1, r, a);
mul(g[lc], up1, g[rc], up2, g[p], up1+up2, 0);
}
// Pushes f[p] (r-l+1 coefficients) down the tree. Each child receives the middle product of f[p] with
// its sibling's g, keeping exactly as many coefficients as the child has points.
// At a leaf, the single remaining coefficient is stored in a[l].
void eval_work(int p, int l, int r, int *a)
{
if(l==r) { a[l] = f[p][0]; return; }
int lc = p<<1, rc = lc|1, mid = (l+r)>>1, up1 = mid-l+1, up2 = r-mid;
mul(f[p], r-l, g[rc], up2, f[lc], r-l, up2);
eval_work(lc, l, mid, a);
mul(f[p], r-l, g[lc], up1, f[rc], r-l, up1);
eval_work(rc, mid+1, r, a);
}
// Evaluates the polynomial a (degree n) at the points b[1..m] and writes the values to c[1..m].
// It appears to use the transposed (middle-product) multipoint-evaluation scheme over the subproduct tree.
// The leaf results are combined as c[i] = c[i]*b[i] + a[0]; presumably each leaf holds
// sum_{k>=1} a[k] * b[i]^(k-1) — NOTE(review): confirm.
// NOTE(review): main only calls this with n == m; other degree/point-count ratios are not exercised here.
// NOTE(review): invg and q are function statics, Inv expects invg to be zero above index 0 on entry,
// and np is never reset, so this appears safe to call only once per run.
void eval(int *a, int n, int *b, int m, int *c)
{
static int invg[N], q[N];
eval_init(1, 1, m, b);
// Reversing the root polynomial prod (x - b_i) in place gives prod (1 - b_i x); g[1] is not needed afterwards.
reverse(g[1], g[1]+m+1);
Inv(g[1], invg, m);
reverse(invg, invg+m+1);
mul(a, n, invg, m, q, n+m, 0);
// The root's f is the m coefficients n+1 .. n+m of a * reversed(invg).
copy(q+n+1, q+n+m+1, f[1]);
eval_work(1, 1, m, c);
for(int i=1; i<=m; i++) c[i] = (1ll*c[i]*b[i]%mod+a[0])%mod;
}
}
// Input: element count and the values themselves.
int n, a[35];
// c / d: subset sums of the left / right half indexed by bitmask (filled in main); v: receives f(c[i]).
int c[N], d[N], v[N];
// f[p]: coefficients (lowest degree first) of tree node p's polynomial, carved out of pool by bumping ptr.
int *f[N], pool[N<<5], *ptr = pool;

// Stores into f[p] the product prod_{j=l}^{r} (x + d[j]) for the node covering indices l..r
// (degree r-l+1, hence r-l+2 coefficients). Children are numbered 2p and 2p+1.
void solve(int p, int l, int r)
{
    int *out = ptr;
    f[p] = out;
    ptr += r - l + 2;
    if(l != r)
    {
        const int mid = (l + r) >> 1;
        const int first_child = 2 * p;
        const int second_child = 2 * p + 1;
        solve(first_child, l, mid);
        solve(second_child, mid + 1, r);
        poly::mul(f[first_child], mid - l + 1, f[second_child], r - mid, out, r - l + 1, 0);
        return;
    }
    // Leaf: the linear factor x + d[l].
    out[0] = d[l];
    out[1] = 1;
}
// Reads n and a[0..n-1], then prints the product of the sums of all non-empty subsets modulo mod.
// Meet in the middle: split into the first L and last R values, with c[i] / d[j] the subset sums of the
// two halves indexed by bitmask. Subsets lying entirely in one half contribute prod c[i] and prod d[j].
// Subsets touching both halves contribute prod_i f(c[i]), where f(x) = prod_j (x + d[j]) is built by
// solve() and evaluated at every c[i] with poly::eval.
// Assumes n <= 32: a has 35 slots, and the transform tables (size N) only fit R <= 16.
int main()
{
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    cin >> n;
    // With no elements there are no non-empty subsets, so the product is empty (= 1).
    // Without this guard, solve(1, 1, 0) below would recurse forever on the empty range.
    if(n <= 0) { cout << 1 << '\n'; return 0; }
    for(int i=0; i<n; i++) cin >> a[i];
    int L = n/2, R = n - L, ans = 1;
    poly::init((1<<(R+1))-1);

    // Sum of a[offset + j] over the bits j set in mask (j < bits), reduced into [0, mod).
    // Accumulating in long long avoids signed int overflow for large inputs. Reducing keeps the values
    // valid NTT coefficients and valid points for eval_init, whose (mod - point) % mod needs points in [0, mod].
    auto subset_sum = [](int offset, int mask, int bits) {
        long long s = 0;
        for(int j=0; j<bits; j++)
            if((mask>>j)&1) s += a[offset+j];
        s %= mod;
        if(s < 0) s += mod;
        return static_cast<int>(s);
    };

    for(int i=1; i<(1<<L); i++)
    {
        c[i] = subset_sum(0, i, L);
        ans = 1ll*ans*c[i]%mod;
    }
    for(int i=1; i<(1<<R); i++)
    {
        d[i] = subset_sum(L, i, R);
        ans = 1ll*ans*d[i]%mod;
    }
    solve(1, 1, (1<<R)-1);
    // eval is given as many points as f's degree, (1<<R)-1. Because L <= R, the points c[i] with
    // i >= (1<<L) are the zero-initialized tail of c; their values land in v but are never used.
    poly::eval(f[1], (1<<R)-1, c, (1<<R)-1, v);
    for(int i=1; i<(1<<L); i++) ans = 1ll*ans*v[i]%mod;
    cout << ans << '\n';
    return 0;
}