CF794G Replace All 做题记录

给两个包含 A,B,? 的串,? 可以填 AB,求所有情况下下面这个东西的和,对 109+710^9+7 取模:

  • 统计有多少对长度 n\le n0101(S,T)(S,T) 使得把所有 A 换成 SSB 换成 TT 后两个串相等;

两个串的长度 3×105\le 3\times 10^5
为了方便,设 a=n,b=m|a|=n,|b|=m,原来的 nn 设为 kk

首先有:

那么 sstt 一定是由一个小串 cc 不断重复得到的,其中 c=gcd(s,t)|c|=\gcd(|s|,|t|)

先不管问号,设 aa 中的 Abb 中的 Axx 个,bb 中的 Baa 中的 Byy 个,那么显然有限制 x×s=y×tx\times |s|=y\times |t|。显然若 xxyy 异号则无解,否则可以让 xxyy 都取绝对值。接下来分类讨论:

  • x=y=0x=y=0

    此时 s|s|t|t| 可以取任意值,那么答案为:

    i=1kj=1k2gcd(i,j)=d=1k2di=1kdj=1kd[gcd(i,j)=1]=d=1k2di=1kdμ(i)(kdi)2\begin{aligned} &\sum\limits_{i=1}^k\sum\limits_{j=1}^k2^{\gcd(i,j)}\\ &=\sum\limits_{d=1}^{k}2^d\sum\limits_{i=1}^{\lfloor\frac{k}{d}\rfloor}\sum\limits_{j=1}^{\lfloor\frac{k}{d}\rfloor}[\gcd(i,j)=1]\\ &=\sum\limits_{d=1}^{k}2^d\sum\limits_{i=1}^{\lfloor\frac{k}{d}\rfloor}\mu(i)\left(\left\lfloor\frac{k}{di}\right\rfloor\right)^2 \end{aligned}

    直接整除分块,时间复杂度 O(k)O(k)

  • x=0x=0y=0y\not=0x=0x\not=0y=0y=0:方程无解。

  • x>0,y>0x>0,y>0

    xxyy 不互质则可以同时除掉 gcd(x,y)\gcd(x,y),方程显然仍然成立,所以下面默认 xxyy 互质。
    因为 xxyy 互质,所以有 s=y×w,t=x×w|s|=y\times w,|t|=x\times www 的取值范围是 [1,ky][1,kx]=[1,kmax(x,y)][1,\lfloor\frac{k}{y}\rfloor]\cap[1,\lfloor\frac{k}{x}\rfloor]=[1,\lfloor\frac{k}{\max(x,y)}\rfloor],答案即为 i=1kmax(x,y)2i\sum\limits_{i=1}^{\lfloor\frac{k}{\max(x,y)}\rfloor}2^i

现在考虑有问号的情况,先设 x,yx,y 的答案为 f(x,y)f(x,y),设 aa 原来有 aaaaAababBaqaq?bb 原来有 babaAbbbbBbqbq? 那么有:

ans=i=0aqj=0bq(aqi)(bqj)f(aa+ibaj,bb+bqjabaq+i)=i=bqaqf(aaba+i,mban+aa+i)j=0bq(bqj)(aqi+j)=i=bqaqf(aaba+i,mban+aa+i)j=0bq(bqj)(aqaqij)=i=bqaqf(aaba+i,mban+aa+i)(bq+aqaqi)\begin{aligned} ans&=\sum\limits_{i=0}^{aq}\sum\limits_{j=0}^{bq}\binom{aq}{i}\binom{bq}{j}f(aa+i-ba-j,bb+bq-j-ab-aq+i)\\ &=\sum\limits_{i=-bq}^{aq}f(aa-ba+i,m-ba-n+aa+i)\sum\limits_{j=0}^{bq}\binom{bq}{j}\binom{aq}{i+j}\\ &=\sum\limits_{i=-bq}^{aq}f(aa-ba+i,m-ba-n+aa+i)\sum\limits_{j=0}^{bq}\binom{bq}{j}\binom{aq}{aq-i-j}\\ &=\sum\limits_{i=-bq}^{aq}f(aa-ba+i,m-ba-n+aa+i)\binom{bq+aq}{aq-i}\\ \end{aligned}

f(0,0)f(0,0) 单独算一下,其他的 ff 均可 O(1)O(1) 算出,那么时间复杂度 O(n)O(n)

代码如下:

#include <iostream>
#include <cstdio>
#include <cstring>

using namespace std;

const int S=2000005,p=1000000007;

int n,m,k;
char a[S],b[S];
int aa,ab,aq,ba,bb,bq;
bool nop[S];
int tot,prime[S];
int mu[S],sum[S];
int fra[S],inv[S];
int _2sum[S];

inline int qpow(int x,int y)
{
	int res=1;
	for(;y>0;y>>=1,x=1ll*x*x%p) res=y&1?1ll*res*x%p:res;
	return res;
}

inline int C(int n,int m)
{
	if(n<0||m<0||n<m) return 0;
	return 1ll*fra[n]*inv[n-m]%p*inv[m]%p;
}

inline int gcd(int x,int y)
{
	int t=x%y;
	while(t!=0) x=y,y=t,t=x%y;
	return y;
}

inline int g(int n)
{
	int res=0;
	for(int l=1;l<=n;l++)
	{
		int r=min(n,n/(n/l));
		int val=((long long)sum[r]-sum[l-1]+p)%p;
		res=(res+1ll*val*(n/l)%p*(n/l)%p)%p;
		l=r;
	}
	return res;
}

inline int f(int x,int y)
{
	if(x==0&&y==0)
	{
		int res=0;
		for(int l=1;l<=k;l++)
		{
			int r=min(k,k/(k/l));
			int val=((long long)_2sum[r]-_2sum[l-1]+p)%p;
			res=(res+1ll*val*g(k/l)%p)%p;
			l=r;
		}
		return res;
	}
	else
	{
		int g=gcd(x,y);
		x/=g,y/=g;
		return ((long long)_2sum[k/max(x,y)]-_2sum[0]+p)%p;
	}
}

int main()
{
	scanf("%s%s%d",a+1,b+1,&k);
	n=strlen(a+1),m=strlen(b+1);
	for(int i=1;i<=n;i++)
	{
		aa+=a[i]=='A';
		ab+=a[i]=='B';
		aq+=a[i]=='?';
	}
	for(int i=1;i<=m;i++)
	{
		ba+=b[i]=='A';
		bb+=b[i]=='B';
		bq+=b[i]=='?';
	}
	nop[0]=nop[1]=true;
	mu[1]=1;
	for(int i=2;i<=S-3;i++)
	{
		if(!nop[i])
		{
			prime[++tot]=i;
			mu[i]=-1;
		}
		for(int j=1;j<=tot;j++)
		{
			if(i*prime[j]>S-3) break;
			nop[i*prime[j]]=true;
			mu[i*prime[j]]=mu[i]*mu[prime[j]];
			if(i%prime[j]==0)
			{
				mu[i*prime[j]]=0;
				break;
			}
		}
	}
	for(int i=1;i<=S-3;i++) sum[i]=((long long)sum[i-1]+mu[i]+p)%p;
	fra[0]=1;
	for(int i=1;i<=S-3;i++) fra[i]=1ll*fra[i-1]*i%p;
	inv[S-3]=qpow(fra[S-3],p-2);
	for(int i=S-3;i>=1;i--) inv[i-1]=1ll*inv[i]*i%p;
	_2sum[0]=1;
	for(int i=1,tmp=1;i<=S-3;i++)
	{
		tmp=2ll*tmp%p;
		_2sum[i]=(_2sum[i-1]+tmp)%p;
	}
	int ans=0;
	for(int i=-bq;i<=aq;i++)
	{
		int x=aa-ba+i,y=m-ba-n+aa+i;
		if((x<0&&y>0)||(x>0&&y<0)) continue;
		x=abs(x),y=abs(y);
		if((x==0||y==0)&&x+y>0) continue;
		// printf("%d %d %d\n",x,y,f(x,y));
		ans=(ans+1ll*f(x,y)*C(bq+aq,aq-i)%p)%p;
	}
	bool fl=n==m;
	for(int i=1;i<=n;i++) fl&=a[i]==b[i]||a[i]=='?'||b[i]=='?';
	if(fl)
	{
		int cnt=0;
		for(int i=1;i<=n;i++) cnt+=a[i]=='?'&&b[i]=='?';
		ans=((long long)ans-1ll*f(0,0)*qpow(2,cnt)%p+p)%p;
		int val=((long long)_2sum[k]-_2sum[0]+p)%p;
		ans=(ans+1ll*val*val%p*qpow(2,cnt)%p)%p;
	}
	printf("%d\n",ans);
	return 0;
}