【UNR #2】黎明前的巧克力

题目链接

难度:简单

考虑如果一个集合大小为 $|S|$ 的集合 $S$ 中的所有数字异或和为 $0$,那对答案的贡献应该是 $2^{|S|}$,前提是 $|S|\not =0$,如果 $|S|=0$,则贡献应该是 $0$,我们最后再去掉这个限制的影响。

由此我们可以写出巧克力的结合幂级数:

$$
1+2x^{a_i}
$$

一个暴力的想法是直接用 $FWT$,但是这样一共要做 $n$ 次 $FWT$,每一次都是 $O(v\log v)$ 的($v$ 表示值域),对于时间来说非常浪费,我们需要考虑利用这个集合幂级数的特殊性质。

根据我们之前博客讲过的 $FWT(A)_i$ 的式子,容易发现,每一项的数都是 $-1$ 或者是 $3$,所以我们可以考虑算出来最后的每个位置上有多少个集合幂级数是 $-1$,有多少个是 $3$。具体来说,把上面的集合幂级数累加,然后只做一遍 $FWT$,通过解方程,我们可以得到把上面的式子异或卷积之后的到的 $FWT$ 式子。然后 $IFWT$ 即可。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#include<bits/stdc++.h>
#define mset(a,b) memset((a),(b),sizeof((a)))
#define rep(i,l,r) for(int i=(l);i<=(r);i++)
#define dec(i,l,r) for(int i=(r);i>=(l);i--)
#define inc(a,b) (((a)+(b))>=mod?(a)+(b)-mod:(a)+(b))
#define sub(a,b) (((a)-(b))<0?(a)-(b)+mod:(a)-(b))
#define mul(a,b) 1ll*(a)*(b)%mod
#define sgn(a) (((a)&1)?(mod-1):1)
#define cmax(a,b) (((a)<(b))?(a=b):(a))
#define cmin(a,b) (((a)>(b))?(a=b):(a))
#define Next(k) for(int x=head[k];x;x=li[x].next)
#define vc vector
#define ar array
#define pi pair
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define N 2000100
#define M number
using namespace std;

typedef double dd;
typedef long double ld;
typedef long long ll;
typedef unsigned int uint;
typedef unsigned long long ull;
//#define int long long
typedef pair<int,int> P;
typedef vector<int> vi;

const int INF=0x3f3f3f3f;
const dd eps=1e-9;
const int mod=998244353;

template<typename T> inline void read(T &x) {
x=0; int f=1;
char c=getchar();
for(;!isdigit(c);c=getchar()) if(c == '-') f=-f;
for(;isdigit(c);c=getchar()) x=x*10+c-'0';
x*=f;
}

int n,a[N],b[N],inv2,c[N],inv4;

inline int ksm(int a,int b,int mod){int res=1;while(b){if(b&1)res=1ll*res*a%mod;a=1ll*a*a%mod;b>>=1;}return res;}
inline int inv(int a){return ksm(a,mod-2,mod);}
inline P Calc(int a,int b){
int x=1ll*((a+b)%mod)*inv4%mod;
int y=(a-x)%mod;
// x=(x+mod)%mod;y=(y+mod)%mod;
return mp(x,y);
}
inline void FWT(int *f,int n,int op){
for(int i=2;i<=n;i<<=1)
for(int l=0;l<n;l+=i)
rep(k,l,l+i/2-1){
// printf("k=%d\n",k);
int a=f[k],b=f[k+i/2];
if(op){
f[k]=(a+b)%mod;f[k+i/2]=(a-b)%mod;
}
else{
f[k]=1ll*((a+b)%mod)*inv2%mod;
f[k+i/2]=1ll*((a-b)%mod)*inv2%mod;
}
}
}

int main(){
// assert(freopen("my.in","r",stdin));
// assert(freopen("my.out","w",stdout));
inv2=ksm(2,mod-2,mod);inv4=inv(4);
read(n);rep(i,1,n) read(a[i]);
int nn=n;
rep(i,1,n) b[a[i]]+=2;b[0]+=n;
int m=1000000;n=1;while(n<m) n<<=1;
FWT(b,n,1);
// rep(i,0,n-1) printf("%d ",b[i]);puts("");
rep(i,0,n-1){
P now=Calc(nn,b[i]);
// printf("now.fi=%d now.se=%d\n",now.fi,now.se);
c[i]=1ll*ksm(3,now.fi,mod)*sgn(now.se)%mod;
}
FWT(c,n,0);
c[0]--;
printf("%d\n",(c[0]+mod)%mod);
return 0;
}

CF1119H

题目链接

难度:较难

可以转化成上面那道题。

第一个转化:容易发现这个题的集合幂级数是这个形式:

$$
ux^{a_i}+vx^{b_i}+wx^{c_i}
$$

如果直接考虑的话,情况个数是 $2^3=8$,我们可以用下面的集合幂级数代替上面的,以把情况数缩小 $\frac{1}{2}$:

$$
ux^{a_i\oplus c_i}+vx^{b_i\oplus c_i}+w
$$

这样做对答案有什么影响,不难发现,我们只需要让答案下标异或上 $\oplus_{i=1}^{n}c_i=xsum$,就可以修正这样做带来的影响。这种技巧在没有常数项的时候适用,如果有常数项的存在将无法缩小情况。

接下来我们考虑转化成上面那道题解方程式的模型。不考虑 $w$,一共有 $4$ 种情况:

$$
\begin{cases}
u+v\
u-v\
-u+v\
-u-v\
\end{cases}
$$

设 $4$ 种情况在第 $i$ 位的出现次数分别为 $A_i,B_i,C_i,D_i$。
那么这一位的 $FWT$ 值应该是 $(u+v)^{A_i}(u-v)^{B_i}(-u+v)^{C_i}(-u-v)^{D_i}$,考虑通过解方程得到 $A_i,B_i,C_i,D_i$ 是多少。

显然 $A_i+B_i+C_i+D_i=n$,通过计算 $FWT(\sum x^{a_i\oplus c_i})$ 可以得到 $A_i+B_i$ 是多少(通过解一个二元方程)。同理通过计算 $FWT(\sum x^{b_i\oplus c_i})$ 可以得到 $A_i+C_i$ 是多少。

非常人类智慧的一点是我们可以通过计算 $FWT(\sum x^{(a_i\oplus c_i)\oplus (b_i \oplus c_i)})$ 来得到 $A_i+D_i$,这是因为以下式子成立:

$$
FWT(x^{(a_i\oplus c_i)\oplus (b_i \oplus c_i)})=FWT(x^{a_i\oplus b_i})
$$

可知

$$
[x^k]=(-1)^{|k&(a_i\oplus b_i)|}=(-1)^{(k&a_i)\oplus (k&b_i)}=(-1)^{k&a_i}(-1)^{k&b_i}
$$

所以只有当两边同号的时候 $x^k$ 的系数才能是 $1$,由此通过解方程我们可以得到 $A_i+D_i$。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include<bits/stdc++.h>
#define mset(a,b) memset((a),(b),sizeof((a)))
#define rep(i,l,r) for(int i=(l);i<=(r);i++)
#define dec(i,l,r) for(int i=(r);i>=(l);i--)
#define inc(a,b) (((a)+(b))>=mod?(a)+(b)-mod:(a)+(b))
#define sub(a,b) (((a)-(b))<0?(a)-(b)+mod:(a)-(b))
#define mul(a,b) 1ll*(a)*(b)%mod
#define sgn(a) (((a)&1)?(mod-1):1)
#define cmax(a,b) (((a)<(b))?(a=b):(a))
#define cmin(a,b) (((a)>(b))?(a=b):(a))
#define mr(a) ((a)=((a)+mod)%mod)
#define Next(k) for(int x=head[k];x;x=li[x].next)
#define vc vector
#define ar array
#define pi pair
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define N 400010
#define M number
using namespace std;

typedef double dd;
typedef long double ld;
typedef long long ll;
typedef unsigned int uint;
typedef unsigned long long ull;
#define int long long
typedef pair<int,int> P;
typedef vector<int> vi;

const int INF=0x3f3f3f3f;
const dd eps=1e-9;
const int mod=998244353;

template<typename T> inline void read(T &x) {
x=0; int f=1;
char c=getchar();
for(;!isdigit(c);c=getchar()) if(c == '-') f=-f;
for(;isdigit(c);c=getchar()) x=x*10+c-'0';
x*=f;
}

int n,k,u,v,w,a[N],b[N],c[N],xsum,inv2;
int f[N],g[N],h[N],ans[N],Ans[N];

inline int ksm(int a,int b,int mod){mr(a);int res=1;while(b){if(b&1)res=1ll*res*a%mod;a=1ll*a*a%mod;b>>=1;}return res;}
inline int inv(int a){return ksm(a,mod-2,mod);}
inline void FWT(int *f,int n,int op){
for(int i=2;i<=n;i<<=1)
for(int l=0;l<n;l+=i)
for(int k=l;k<l+(i/2);k++){
ll A=f[k],B=f[k+(i/2)];
if(op==1) f[k]=(A+B),f[k+(i/2)]=(A-B);
else f[k]=1ll*(A+B)*inv2%mod,f[k+(i/2)]=1ll*(A-B)*inv2%mod;
}
}
inline void Calc(ll &a,ll &b,ll &c,ll &d,ll a1,ll a2,ll a3,ll a4){
ll sum=(a2+a3+a4)-a1;sum%=mod;a=1ll*sum*inv2%mod;
b=a2-a;c=a3-a;d=a4-a;mr(a);mr(b);mr(c);mr(d);
}
inline void Calc(ll &a,ll &b,ll a1,ll a2){
ll sum=(a1+a2);sum%=mod;a=1ll*sum*inv2%mod;
b=a1-a;mr(a);mr(b);
// printf("a1=%lld a2=%lld a=%lld b=%lld\n",a1,a2,a,b);
}

signed main(){
// assert(freopen("my.in","r",stdin));
// assert(freopen("my.out","w",stdout));
read(n);read(k);read(u);read(v);read(w);rep(i,1,n) read(a[i]),read(b[i]),read(c[i]);
u%=mod;v%=mod;w%=mod;
rep(i,1,n) a[i]^=c[i],b[i]^=c[i];rep(i,1,n) xsum^=c[i];
rep(i,1,n) f[a[i]]++;rep(i,1,n) g[b[i]]++;rep(i,1,n) h[a[i]^b[i]]++;
int ln=n;
int nn=(1<<k)-1;n=1;while(n<=nn) n<<=1;inv2=inv(2);FWT(f,n,1);FWT(h,n,1);FWT(g,n,1);
// printf("f: ");rep(i,0,n-1) printf("%d ",f[i]);puts("");
// printf("g: ");rep(i,0,n-1) printf("%d ",g[i]);puts("");
// printf("h: ");rep(i,0,n-1) printf("%d ",h[i]);puts("");
rep(i,0,n-1){
// printf("i=%d\n",i);
ll A,B,C,D,a1,a2,a3,a4;a1=ln;
ll aa,bb,aa1,aa2;
aa1=ln;aa2=f[i];Calc(aa,bb,aa1,aa2);
a2=aa;aa2=g[i];Calc(aa,bb,aa1,aa2);
a3=aa;aa2=h[i];Calc(aa,bb,aa1,aa2);
a4=aa;
Calc(A,B,C,D,a1,a2,a3,a4);
ans[i]=1ll*ksm(u+v+w,A,mod)*ksm(u-v+w,B,mod)%mod*ksm(-u+v+w,C,mod)%mod*ksm(-u-v+w,D,mod)%mod;
}
FWT(ans,n,0);
rep(i,0,n-1) Ans[i]=ans[i^xsum];
rep(i,0,n-1) printf("%lld ",mr(Ans[i]));
return 0;
}

出现的问题:取模混乱,如果取模太过于复杂一定要记得 ll;记得处理负数。

石家庄的工人阶级队伍比较坚强

题目链接

因为不会卡常,所以只得了 $85pts$,是一道 k 进制 FWT 模板题,只需要一定的转化。不难发现 $u$ 是 $x-y$ 三进制下 $1$ 的个数,而 $v$ 三进制下 $2$ 的个数,当然这里的 $-$ 是在三进制下的。所以整个 $b$ 就和 $x-y$ 相关了,卷积即可。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#include<bits/stdc++.h>
#define mset(a,b) memset((a),(b),sizeof((a)))
#define rep(i,l,r) for(int i=(l);i<=(r);i++)
#define dec(i,l,r) for(int i=(r);i>=(l);i--)
#define mai(a) ((a)<0?((a)+mod):(((a)>mod)?(a)-mod:(a)))
#define sgn(a) (((a)&1)?(mod-1):1)
#define cmax(a,b) (((a)<(b))?(a=b):(a))
#define cmin(a,b) (((a)>(b))?(a=b):(a))
#define Next(k) for(int x=head[k];x;x=li[x].next)
#define vc vector
#define ar array
#define pi pair
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define N 631441
#define M 15
using namespace std;

typedef double dd;
typedef long double ld;
typedef long long ll;
typedef unsigned int uint;
typedef unsigned long long ull;
// #define int long long
typedef pair<int,int> P;
typedef vector<int> vi;

const int INF=0x3f3f3f3f;
const dd eps=1e-9;

template<typename T> inline void read(T &x) {
x=0; int f=1;
char c=getchar();
for(;!isdigit(c);c=getchar()) if(c == '-') f=-f;
for(;isdigit(c);c=getchar()) x=x*10+c-'0';
x*=f;
}

int mod,inv3;
inline ll exgcd(ll a,ll b,ll &x,ll &y){
if(b==0){x=1;y=0;return a;}ll g=exgcd(b,a%b,x,y);ll tmp=x;x=y;y=tmp-a/b*y;return g;
}
inline int inv(int a){
ll x,y;int g=exgcd(a,mod,x,y);assert(g==1);
// printf("a=%lld mod=%lld x=%lld y=%lld\n",a,mod,x,y);
assert(mai(1ll*x*a%mod)==1);
return (x%mod+mod)%mod;
}
inline int ksm(int a,int b,int mod){
int res=1;while(b){if(b&1)res=1ll*res*a%mod;a=1ll*a*a%mod;b>>=1;}return res;
}

struct cp{
int x,y;
inline cp() {}
inline cp(int x,int y) : x(x),y(y) {}
inline cp operator + (const cp &b)const{return cp((x+b.x)%mod,(y+b.y)%mod);}
inline cp operator - (const cp &b)const{return cp(mai(x-b.x),mai(y-b.y));}
inline cp operator * (const cp &b)const{
return cp(mai((1ll*x*b.y%mod+1ll*b.x*y%mod)%mod-1ll*x*b.x%mod),mai(1ll*y*b.y%mod-1ll*x*b.x%mod)%mod);
}
inline cp operator * (const int b)const{return cp(1ll*x*b%mod,1ll*y*b%mod);}
inline void Print(){
printf("x=%d y=%d\n",x,y);
}
};
inline cp ksm(cp a,int b){
cp res;res.y=1;res.x=0;
while(b){if(b&1)res=res*a;a=a*a;b>>=1;}return res;
}
int m,t,n,B[M][M];
cp f[N],b[N];
cp w2,w1;

inline int bit1(int x){
int res=0;while(x){if((x%3)==1) res++;x/=3;}return res;
}
inline int bit2(int x){
int res=0;while(x){if((x%3)==2) res++;x/=3;}return res;
}
// inline int add(int x,int y){
// int ans=0;
// int p=1;while(x/p!=0||y/p!=0){
// int nowx=(x/p)%3,nowy=(y/p)%3;
// int now=(nowx+nowy)%3;ans+=p*now;p*=3;
// }
// return ans;
// }
// inline int del(int x,int y){
// int ans=0;
// int p=1;while(x/p!=0||y/p!=0){
// int nowx=(x/p)%3,nowy=(y/p)%3;
// int now=(nowx-nowy+3)%3;ans+=p*now;p*=3;
// }
// return ans;
// }
inline void FWT(cp *f,int n){
for(int i=1;i<n;i*=3)
for(int j=0;j<n;j+=i*3)
rep(k,0,i-1){
int k1=j+k,k2=j+k+i,k3=j+k+i+i;
cp a=f[k1],b=f[k2],c=f[k3];
f[k1]=a+b+c;
f[k2]=a+b*w1+c*w2;
f[k3]=a+b*w2+c*w1;
}
}

inline void IFWT(cp *f,int n){
for(int i=1;i<n;i*=3)
for(int j=0;j<n;j+=i*3)
rep(k,0,i-1){
int k1=j+k,k2=j+k+i,k3=j+k+i*2;
cp a=f[k1],b=f[k2],c=f[k3];
f[k1]=a+b+c;
f[k2]=a+b*w2+c*w1;
f[k3]=a+b*w1+c*w2;
// f[k1]=f[k1]*inv3;
// f[k2]=f[k2]*inv3;
// f[k3]=f[k3]*inv3;
}
int ninv=ksm(inv3,m,mod);
rep(i,0,n-1) f[i]=f[i]*ninv;
}

signed main(){
// assert(freopen("my.in","r",stdin));
// assert(freopen("my.out","w",stdout));
read(m);read(t);read(mod);inv3=inv(3);
n=1;rep(i,1,m) n=n*3;rep(i,0,n-1) read(f[i].y);
rep(i,1,m+1) rep(j,1,m+2-i) read(B[i-1][j-1]);
rep(i,0,n-1) b[i].y=B[bit1(i)][bit2(i)];w1.x=1;w2=(cp){mod-1,mod-1};
FWT(b,n);FWT(f,n);
rep(i,0,n-1) b[i]=ksm(b[i],t);
rep(i,0,n-1) f[i]=f[i]*b[i];
// printf("inv3=%d\n",inv3);
IFWT(f,n);
rep(i,0,n-1) printf("%lld\n",f[i].y);
return 0;
}