BSGS 和 exBSGS 学习笔记

BSGS 是一个用来求高次同余方程的暴力算法。

先看一道例题:

给定 a,b,pa,b,p,保证 a,b,p2a,b,p\ge2gcd(a,p)=1\operatorname{gcd}(a,p)=1,求 axb(modp)a^x\equiv b\pmod p 的最小自然数解 xx

显然,我们可以暴力从小到大枚举 xx,判断每个 xx 是否合法。根据欧拉的神奇定理,aφ(p)1(modp)a^{\varphi(p)}\equiv 1\pmod p,我们只需要枚举到 φ(p)\varphi(p) 即可。时间复杂度 O(p)O(p)

这样显然不够好,很多时候 pp 都十分巨大。考虑优化,我们可以xx 拆开,拆成 x=tABx=tA-B,其中 BtB\le t。那么有:

atABb(modp)a^{tA-B}\equiv b\pmod p

atAbaB(modp)a^{tA}\equiv ba^B\pmod p

这样我们就可以用一个哈希表存每个 baBmodpba^B\operatorname{mod} p 对应的最大的 BB,然后枚举 AA,快速找到 atAa^{tA} 对应的 BB,答案即为 x=tABx=tA-B

显然,此时 ttp\left\lceil\sqrt p\right\rceil 是最优的。此时时间复杂度为 O(p)O(\left\lceil\sqrt p\right\rceil)。如果用 map 来哈希的话时间复杂度会多只 log\log

代码如下:

#include <iostream>
#include <cstdio>
#include <cmath>
#include <map>

using namespace std;

inline int BSGS(int a,int b,int p)
{
	map<int,int> mp;
	int val=1,t=sqrt(p)+1;
	for(int B=1;B<=t;B++)
	{
		val=1ll*val*a%p;
		mp[1ll*b*val%p]=B;
	}
	int cur=val;
	for(int A=1;A<=t;A++)
	{
		if(mp.find(val)!=mp.end())
		{
			return A*t-mp[val];
		}
		val=1ll*val*cur%p;
	}
	return -1;
}

int main()
{
	int p,a,b;
	scanf("%d%d%d",&p,&a,&b);
	int x=BSGS(a,b,p);
	if(x==-1)
	{
		puts("no solution");
		return 0;
	}
	printf("%d\n",x);
	return 0;
}

但是有些时候,gcd(a,p)=1\gcd(a,p)\not=1,用不了 BSGS。这时候我们就需要 exBSGS 了。

例如这道题:P4195 【模板】扩展 BSGS/exBSGS

gcd(a,p)=1\gcd(a,p)\not=1 时,我们令 d=gcd(a,p)d=\gcd(a,p),那么有:

axb(modp)a^x\equiv b\pmod p

adax1bd(modpd)\dfrac{a}{d}a^{x-1}\equiv \dfrac{b}{d}\pmod{\dfrac{p}{d}}

如果 gcd(ad,pd)\gcd(\dfrac{a}{d},\dfrac{p}{d}) 仍然不为 11,那么继续做下去,直到 gcd(aD,pD)=1\gcd(\dfrac{a}{D},\dfrac{p}{D})=1 为止DD 为所有 dd 的乘积)。

这时我们的方程变成了这样(cntcnt 是操作的次数):

acntDaxcntbD(modpD)\dfrac{a^{cnt}}{D}a^{x-cnt}\equiv \dfrac{b}{D}\pmod{\dfrac{p}{D}}

就可以用 BSGS 计算答案了。

注意特判 x=0x=0x=cntx=cnt 的特殊情况。

代码如下:

#include <iostream>
#include <cstdio>
#include <map>
#include <cmath>

using namespace std;

inline int gcd(int a,int b)
{
	int t=a%b;
	while(t>0)
	{
		a=b;
		b=t;
		t=a%b;
	}
	return b;
}

inline int exBSGS(int a,int b,int p)
{
	a%=p;
	b%=p;
	if(b==1||p==1)
	{
		return 0;
	}
	int cnt=0,val=1;
	while(1)
	{
		int d=gcd(a,p);
		if(d==1)
		{
			break;
		}
		if(b%d!=0)
		{
			return -1;
		}
		cnt++;
		p/=d;
		b/=d;
		val=1ll*val*(a/d)%p;
		if(val==b)
		{
			return cnt;
		}
	}
	map<int,int> mp;
	int val2=1,t=sqrt(p)+1;
	for(int B=1;B<=t;B++)
	{
		val2=1ll*val2*a%p;
		mp[1ll*b*val2%p]=B;
	}
	int cur=1ll*val2*val%p;
	for(int A=1;A<=t;A++)
	{
		if(mp.find(cur)!=mp.end())
		{
			return A*t-mp[cur]+cnt;
		}
		cur=1ll*cur*val2%p;
	}
	return -1;
}

int main()
{
	int a,b,p;
	while(1)
	{
		scanf("%d%d%d",&a,&p,&b);
		if(a==0&&b==0&&p==0)
		{
			break;
		}
		int res=exBSGS(a,b,p);
		if(res<0)
		{
			puts("No Solution");
		}
		else
		{
			printf("%d\n",res);
		}
	}
	return 0;
}

练习题目