给两个包含
A
,B
,?
的串,?
可以填A
或B
,求所有情况下下面这个东西的和,对 取模:
- 统计有多少对长度 的 串 使得把所有
A
换成 ,B
换成 后两个串相等;两个串的长度 。
为了方便,设 ,原来的 设为 。
首先有:

那么 和 一定是由一个小串 不断重复得到的,其中 。
先不管问号,设 中的 A
比 中的 A
多 个, 中的 B
比 中的 B
多 个,那么显然有限制 。显然若 和 异号则无解,否则可以让 和 都取绝对值。接下来分类讨论:
-
:
此时 和 可以取任意值,那么答案为:
直接整除分块,时间复杂度 。
-
且 或 且 :方程无解。
-
:
若 和 不互质则可以同时除掉 ,方程显然仍然成立,所以下面默认 和 互质。
因为 和 互质,所以有 , 的取值范围是 ,答案即为 。
现在考虑有问号的情况,先设 的答案为 ,设 原来有 个 A
, 个 B
, 个 ?
, 原来有 个 A
, 个 B
, 个 ?
那么有:
单独算一下,其他的 均可 算出,那么时间复杂度 。
代码如下:
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
const int S=2000005,p=1000000007;
int n,m,k;
char a[S],b[S];
int aa,ab,aq,ba,bb,bq;
bool nop[S];
int tot,prime[S];
int mu[S],sum[S];
int fra[S],inv[S];
int _2sum[S];
inline int qpow(int x,int y)
{
int res=1;
for(;y>0;y>>=1,x=1ll*x*x%p) res=y&1?1ll*res*x%p:res;
return res;
}
inline int C(int n,int m)
{
if(n<0||m<0||n<m) return 0;
return 1ll*fra[n]*inv[n-m]%p*inv[m]%p;
}
inline int gcd(int x,int y)
{
int t=x%y;
while(t!=0) x=y,y=t,t=x%y;
return y;
}
inline int g(int n)
{
int res=0;
for(int l=1;l<=n;l++)
{
int r=min(n,n/(n/l));
int val=((long long)sum[r]-sum[l-1]+p)%p;
res=(res+1ll*val*(n/l)%p*(n/l)%p)%p;
l=r;
}
return res;
}
inline int f(int x,int y)
{
if(x==0&&y==0)
{
int res=0;
for(int l=1;l<=k;l++)
{
int r=min(k,k/(k/l));
int val=((long long)_2sum[r]-_2sum[l-1]+p)%p;
res=(res+1ll*val*g(k/l)%p)%p;
l=r;
}
return res;
}
else
{
int g=gcd(x,y);
x/=g,y/=g;
return ((long long)_2sum[k/max(x,y)]-_2sum[0]+p)%p;
}
}
int main()
{
scanf("%s%s%d",a+1,b+1,&k);
n=strlen(a+1),m=strlen(b+1);
for(int i=1;i<=n;i++)
{
aa+=a[i]=='A';
ab+=a[i]=='B';
aq+=a[i]=='?';
}
for(int i=1;i<=m;i++)
{
ba+=b[i]=='A';
bb+=b[i]=='B';
bq+=b[i]=='?';
}
nop[0]=nop[1]=true;
mu[1]=1;
for(int i=2;i<=S-3;i++)
{
if(!nop[i])
{
prime[++tot]=i;
mu[i]=-1;
}
for(int j=1;j<=tot;j++)
{
if(i*prime[j]>S-3) break;
nop[i*prime[j]]=true;
mu[i*prime[j]]=mu[i]*mu[prime[j]];
if(i%prime[j]==0)
{
mu[i*prime[j]]=0;
break;
}
}
}
for(int i=1;i<=S-3;i++) sum[i]=((long long)sum[i-1]+mu[i]+p)%p;
fra[0]=1;
for(int i=1;i<=S-3;i++) fra[i]=1ll*fra[i-1]*i%p;
inv[S-3]=qpow(fra[S-3],p-2);
for(int i=S-3;i>=1;i--) inv[i-1]=1ll*inv[i]*i%p;
_2sum[0]=1;
for(int i=1,tmp=1;i<=S-3;i++)
{
tmp=2ll*tmp%p;
_2sum[i]=(_2sum[i-1]+tmp)%p;
}
int ans=0;
for(int i=-bq;i<=aq;i++)
{
int x=aa-ba+i,y=m-ba-n+aa+i;
if((x<0&&y>0)||(x>0&&y<0)) continue;
x=abs(x),y=abs(y);
if((x==0||y==0)&&x+y>0) continue;
// printf("%d %d %d\n",x,y,f(x,y));
ans=(ans+1ll*f(x,y)*C(bq+aq,aq-i)%p)%p;
}
bool fl=n==m;
for(int i=1;i<=n;i++) fl&=a[i]==b[i]||a[i]=='?'||b[i]=='?';
if(fl)
{
int cnt=0;
for(int i=1;i<=n;i++) cnt+=a[i]=='?'&&b[i]=='?';
ans=((long long)ans-1ll*f(0,0)*qpow(2,cnt)%p+p)%p;
int val=((long long)_2sum[k]-_2sum[0]+p)%p;
ans=(ans+1ll*val*val%p*qpow(2,cnt)%p)%p;
}
printf("%d\n",ans);
return 0;
}