二进制[BJOI2018]

题目链接

loj

题意

一个长为 $n$ 的 $01$ 串,每次单调修改,或询问某个区间内,有多少个子区间,满足这个子区间存在一种方案,使得经过重排后为 $3$ 的倍数。

$n,m\leq 10^5$。

题解

线段树神奇操作qwq。

首先容斥,变成求有多少个子区间不满足,发现当且仅当区间中( $1$ 的个数只有 $1$ )或者( $1$ 的个数出现奇数次且 $0$ 的次数小于 $2$)时,这个区间不满足条件。

那么我们只需要上面这算两种情况就好了,注意两种情况的重合部分(即 $1,01,10$ 三种情况)。

对于这两种情况,在线段树中存强制选左/右端点,主体部分是 $0/1$,到这个端点里有 $0/1$ 个其他数的区间有多少种。

合并的时候大力分类讨论即可。

时间复杂度 $O(n\log n)$。

程序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <vector>
#include <cstring>
#include <queue>
using namespace std;
#define FO(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
#define fo(i,j,k) for(int i=(j),end_i=(k);i<=end_i;i++)
#define ff(i,j,k) for(int i=(j),end_i=(k);i< end_i;i++)
#define fd(i,j,k) for(int i=(j),end_i=(k);i>=end_i;i--)
#define DEBUG(x) cerr<<#x<<"="<<x<<endl
#define all(x) (x).begin(),(x).end()
#define cle(x) memset(x,0,sizeof(x))
#define lowbit(x) ((x)&-(x))
#define ll long long
#define ull unsigned ll
#define db double
#define lb long db
#define pb push_back
#define mp make_pair
#define fi first
#define se second
inline int read()
{
int x=0; char ch=getchar(); bool f=0;
for(;ch<'0'||ch>'9';ch=getchar()) if(ch=='-') f=1;
for(;ch>='0'&&ch<='9';ch=getchar()) x=(x<<3)+(x<<1)+(ch^48);
return f?-x:x;
}
#define CASET fo(___,1,read())

struct node{
int l0,l1,r0,r1,len;
int zl0,zl1,zr0,zr1;
ll sum;
inline void init(int x)
{
len=1,sum=x;
l0=r0=zl1=zr1=x;
l1=r1=zl0=zr0=!x;
}
inline void change()
{
sum=1-sum;
l0=r0=zl1=zr1=sum;
l1=r1=zl0=zr0=!sum;
}
friend inline node operator+(const node &A,const node &B)
{
node S;
S.len=A.len+B.len;
S.sum=A.sum+B.sum;
S.sum+=1ll*A.zr0*B.zl1+1ll*A.zr1*B.zl0;
ll le[2],ri[2];
le[0]=A.r0>>1; le[1]=A.r0-le[0];
ri[0]=B.l0>>1; ri[1]=B.l0-ri[0];
S.sum+=le[0]*ri[1]+le[1]*ri[0];
if(A.r1)
{
le[0]=(A.r1+1)>>1; le[1]=A.r1-le[0];
if(A.r0&1) swap(le[0],le[1]);
ri[0]=B.l0>>1; ri[1]=B.l0-ri[0];
S.sum+=le[0]*ri[1]+le[1]*ri[0];
if(!A.r0&&B.l0) S.sum--;
}
if(B.l1)
{
le[0]=A.r0>>1; le[1]=A.r0-le[0];
ri[0]=(B.l1+1)>>1; ri[1]=B.l1-ri[0];
if(B.l0&1) swap(ri[0],ri[1]);
S.sum+=le[0]*ri[1]+le[1]*ri[0];
if(!B.l0&&A.r0) S.sum--;
}

if(!A.l1) S.l0=A.l0+B.l0,S.l1=B.l1;
else S.l0=A.l0,S.l1=A.l1+((A.l1+A.l0==A.len)?B.l0:0);
if(!B.r1) S.r0=B.r0+A.r0,S.r1=A.r1;
else S.r0=B.r0,S.r1=B.r1+((B.r1+B.r0==B.len)?A.r0:0);

if(!A.zl1) S.zl0=A.zl0+B.zl0,S.zl1=B.zl1;
else S.zl0=A.zl0,S.zl1=A.zl1+((A.zl1+A.zl0==A.len)?B.zl0:0);
if(!B.zr1) S.zr0=B.zr0+A.zr0,S.zr1=A.zr1;
else S.zr0=B.zr0,S.zr1=B.zr1+((B.zr1+B.zr0==B.len)?A.zr0:0);

return S;
}
}tr[400010];
#define lc (u<<1)
#define rc (u<<1|1)
#define ls lc,l,mid
#define rs rc,mid+1,r
void build(int u,int l,int r)
{
if(l==r) return tr[u].init(read());
int mid=l+r>>1;
build(ls); build(rs);
tr[u]=tr[lc]+tr[rc];
}
void update(int u,int l,int r,int p)
{
if(l==r) return tr[u].change();
int mid=l+r>>1;
(p<=mid)?update(ls,p):update(rs,p);
tr[u]=tr[lc]+tr[rc];
}
node query(int u,int l,int r,int L,int R)
{
if(L<=l&&r<=R) return tr[u];
int mid=l+r>>1;
if(L>mid) return query(rs,L,R);
else if(R<=mid) return query(ls,L,R);
else return query(ls,L,R)+query(rs,L,R);
}
int n,l,r,len;
int main()
{
n=read();
build(1,1,n);
CASET
{
if(read()==1) update(1,1,n,read());
else
{
l=read(),r=read(); len=r-l+1;
printf("%lld\n",1ll*len*(len+1)/2-query(1,1,n,l,r).sum);
}
}
return 0;
}