多项式
备注:多项式基本操作此处不再说明,都比较简略,此处只介绍重要结论以及算法。
Warning:这篇文章有的公式是无法显示的,建议在洛谷 “多项式科技” 中查看。
多项式 ln/exp
暴力解法
首先引入概念设 $A(x)=\ln(B(x)),B(x)=\exp(A(x))$,那么我们可以看做 $A_x$ 是 $x$ 个点的带标号无向连通图,$B_x$ 是 $x$ 个点的带标号无向普通图,于是我们根据组合意义得到下面的公式:
$$
n!B_n = \sum_{i=1}^n C_{n-1}^{i-1}i!A_i(n-i)!B_{n-i}
$$
可以化简成:
$$
B_n = \sum_{i=1}^n A_iB_{n-i}\frac in
$$
反过来,$A$ 也可以通过 $B$ 得出来:
$$
A_n = B_n-\sum_{i=1}^{n-1} A_iB_{n-i}\frac in
$$
当然,还有一种对于 $\ln(1+ax)$ 的化简公式:
$$
\ln(1+ax) = \sum_{i=1}^{\text{inf}}\frac 1i(-1)^{i-1}a^ix^i
$$
如果我们需要求出第 $n$ 项在 exp 之后的值,直接保留前 $n$ 项就可以了。
注意:根据 $\ln,\exp$ 的定义,可得 $A_0=0,B_0=1$。
复数
给原有的实数集 $R$ 添加一个元素,然后任意元素表出来的所有集合称作复数集 $C$。
添加的元素是 $i=\sqrt{-1}$,因此 $C$ 中的每个元素可以唯一使用 $a+bi(a \in R,b \in R)$ 表示,并且这个域中的加法乘法运算具有封闭性。
其中 $a$ 称作这个复数的实部,$b$ 称作这个复数的虚部。
运算
加/减法:$(a+bi)+(c+di)=(a+c)+(b+d)i$。
乘法:$(a+bi)(c+di)=(ac-bd)+(ad+bc)i$。(乘法分配律展开即可)
除法:$\frac {a+bi}{c+di}=\frac {(a+bi)(c-di)}{(c+di)(c-di)}=\frac{ac+bd+bci-adi}{c^2+d^2}=\frac{ac+bd}{c^2+d^2}+\frac{bc-ad}{c^2+d^2}i$。
平面几何
因为每个复数可以看作平面上的一个点,于是设平面的 $x$ 轴为实轴;$y$ 轴为虚轴,$a+bi$ 在平面上的坐标就是 $(a,b)$,其模长就是这个向量的长度。
所有模长等于 $1$ 的复数在平面直角坐标系上构成了单位圆,即一个半径为 $1$ 的圆,这些复数也称作单位复数。
单位复数与单位复数的乘积还是单位复数,也就是这些单位复数也构成一个具有运算封闭性的群。
这个平面直角坐标系有一个特殊的名字:复平面。
在极坐标的视角下,复数的乘除法变得很简单。复数乘法,模相乘,辐角相加。复数除法,模相除,辐角相减。
欧拉公式
对于任意实数 $x$,有:
$$
e^{ix}=\cos x+i \sin x
$$
这个公式我们此处不做证明,感兴趣的可以查一下资料了解。
单位根
设 $x^n=1$ 在复数意义下的解 $x$ 是 $n$ 次复根。根据单位复数的知识,这样的解一共有 $n$ 个,并且平分单位圆,这些解就称作单位根。
设 $\omega_{n}^{k} = \exp \frac {2\pi k}{n}$,则 $x^n=1$ 的解集就是 $\omega_{n}^{0 \sim n-1}$。(关于 $\frac {2\pi k}{n}$ 的角度表示可以参考 计算几何-极坐标)
并且因为 $\exp (i \theta)=\cos \theta + i \sin \theta$,所以我们可以快速得到所有 $\omega_n$。(这个公式的详细推导需要用到三角函数等知识,此处不再赘述)
根据单位圆的性质,容易发现:
$$
\omega_{n}^n =1,\omega_{2n}^{2k} = \omega_n^k,\omega_{2n}^{n+k} = -\omega_{2n}^k
$$
上面的三个公式会在 FFT(快速傅里叶变换)中得到应用。
并且如果 $k$ 等于 $1$,则所有 $\omega$ 都可以通过 $\omega_{n}^1$ 的次幂表示。($\omega_n^1$ 是从 $1$ 开始逆时针方向遇到的第一个单位根)
本源单位根
为什么说,上述 $n$ 个解都是 $n$ 次单位根,而平时说的 $n$ 次单位根一般特指第一个?
特指第一个,是为了在应用时方便。
在解方程的视角看来,满足 $\omega_n$ 性质的不止 $\omega_n$ 一个,对于 $\omega_n$ 的若干次幂也会满足性质。
称集合:${\omega_n^k\mid 0\le k<n, \gcd(n,k)=1}$
中的元素为本原单位根。任意一个本原单位根 $\omega$,与上述 $\omega_n$ 具有相同的性质:对于任意的 $0<k<n$,$\omega$ 的 $k$ 次幂不为 $1$。因此,借助任意一个本原单位根,都可以生成全体单位根。
全体 $n$ 次本原单位根共有 $\varphi(n)$ 个,很好证明。
FFT(快速傅里叶变换)
如果给了两个多项式 $A(x)$ 和 $B(x)$ 求出 $A \times B$ 的值,时间复杂度是 $O(n^2)$ 的。
如果我们给了两个多项式 $A(x),B(x)$ 在 $x = 1 \sim n+1$ 地方的取值,求出 $A \times B$ 在 $x = 1 \sim n+1$ 地方的取值,那么可以做到直接相乘 $O(n)$。
最后我们需要根据点值求出系数,又要用到拉格朗日插值,时间复杂度是 $O(n^2)$ 的。
这启发我们使用恰当的 $x$,算出点值之后求出系数,这就是 FFT 的精髓所在。
设 $A$ 的次数为 $n$,$B$ 的次数为 $m$。
FFT 使用了单位根来计算点值并且还原系数,我们设 $N$ 是第一个大于 $n+m$ 的二的次幂,则 FFT 选择的点就是 $\omega_{N}^{1 \sim N}$,并且在 $\log$ 时间范围内计算出点值。
于是我们考虑 $f(x)=a_0+a_1x+a_2x^2+a_3x^3+a_4x^4+a_5x^5+a_6x^6+a_7x^7$,我们可以按照奇偶分组:
$$
G(x)=a_0+a_2x+a_4x^2+a_6x^3 \\
H(x)=a_1+a_3x^1+a_5x^2+a_7x^3
$$
所以 $f(x)=G(x^2)+xH(x^2)$,因为 $\omega_{n}^k= - \omega_{n}^{k+\frac n2}$,所以 $G({\omega_{n}^k}^2)=G({(-\omega_{n}^{k+\frac n2}})^2)$,对于 $H(x^2)$ 同理。
所以我们可以直接进行计算:$f(\omega_{n}^k)=G({\omega_{n}^k}^2)+\omega_{n}^kH({\omega_{n}^k}^2),f(\omega_{n}^{k+\frac n2})=G({\omega_{n}^k}^2)-\omega_{n}^kH({\omega_{n}^k}^2)$。
于是根据 ${\omega_{n}^k}^2=\omega_{n}^{2k}=\omega_{\frac n2}^k$ 递归计算即可。
因为每次都必须是 $2$ 的倍数($n$),所以最开始我们令 $N=2^k$,然后就能得到一个常数很大的代码(递归版本)。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
| #include <cmath> #include <complex>
typedef std::complex<double> Comp;
const Comp I(0, 1); const int MAX_N = 1 << 20;
Comp tmp[MAX_N];
void DFT(Comp* f, int n, int rev) { if (n == 1) return; for (int i = 0; i < n; ++i) tmp[i] = f[i]; for (int i = 0; i < n; ++i) { if (i & 1) f[n / 2 + i / 2] = tmp[i]; else f[i / 2] = tmp[i]; } Comp *g = f, *h = f + n / 2; DFT(g, n / 2, rev), DFT(h, n / 2, rev); Comp cur(1, 0), step(cos(2 * M_PI / n), sin(2 * M_PI * rev / n)); for (int k = 0; k < n / 2; ++k) { tmp[k] = g[k] + cur * h[k]; tmp[k + n / 2] = g[k] - cur * h[k]; cur *= step; } for (int i = 0; i < n; ++i) f[i] = tmp[i]; }
|
常数优化
首先考虑最后我们每次按照奇偶分类最后的 $f$ 是什么样子,事实上就是它在 $0 \sim k-1$ 位翻转后的结果,每次递推计算这个翻转后的位置就好:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
|
void change(Complex y[], int len) { for (int i = 0; i < len; ++i) { rev[i] = rev[i >> 1] >> 1; if (i & 1) { rev[i] |= len >> 1; } } for (int i = 0; i < len; ++i) { if (i < rev[i]) { swap(y[i], y[rev[i]]); } } return; }
|
然后我们每次执行运算 $f(\omega_n) \to f(\omega_{2n})$ 只需要直接在原位置上进行计算就可以了。
点值得到系数
得到了最终多项式的点值之后我们要知道各个项数的系数是多少,于是我们还得学习快速傅里叶逆变换。
IDFT(傅里叶反变换)的作用,是把目标多项式的点值形式转换成系数形式。而 DFT 本身是个线性变换,可以理解为将目标多项式当作向量,左乘一个矩阵得到变换后的向量。
现在我们已经得到最左边的结果了,中间的 $x$ 值在目标多项式的点值表示中也是一一对应的,所以,根据矩阵的基础知识,我们只要在式子两边左乘中间那个大矩阵的逆矩阵就行了。
由于这个矩阵的元素非常特殊,它的逆矩阵也有特殊的性质,就是每一项取倒数,再除以变换的长度 $n$,就能得到它的逆矩阵。
为了使计算的结果为原来的倒数,根据欧拉公式,可以得到
$$\frac{1}{\omega_k}=\omega_k^{-1}=e^{-\frac{2\pi i}{k}}=\cos\left(\frac{2\pi}{k}\right)+i \sin\left(-\frac{2\pi}{k}\right)$$
因此我们可以尝试着把单位根 $\omega_k$ 取成 $e^{-\frac{2\pi \mathrm{i}}{k}}$,这样我们的计算结果就会变成原来的倒数,之后唯一多的操作就只有再除以它的长度 $n$,而其它的操作过程与 DFT 是完全相同的。我们可以定义一个函数,在里面加一个参数 $1$ 或者是 $-1$,然后把它乘到 $\pi$ 上。传入 $1$ 就是 DFT,传入 $-1$ 就是 IDFT。
最后我们以 P3803 【模板】多项式乘法(FFT) 为例子看一下代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
| #include<bits/stdc++.h> #define N 4000005 #define db double using namespace std; inline char nc(){ static char buf[1000000],*p=buf,*q=buf; return p==q&&(q=(p=buf)+fread(buf,1,1000000,stdin),p==q)?EOF:*p++; } inline int read(){ int res = 0; char c = nc(); while(c<'0'||c>'9')c=nc(); while(c<='9'&&c>='0')res=res*10+c-'0',c=nc(); return res; } char obuf[1<<21],*p3=obuf; inline void pc(char c){ p3-obuf<=(1<<20)?(*p3++=c):(fwrite(obuf,p3-obuf,1,stdout),p3=obuf,*p3++=c); } inline void write(int x){ if(x<0) pc('-'),x=-x; if(x>9) write(x/10); pc(x%10+'0'); } const double eps = 1e-8,pi = acos(-1.0); struct poly{db re,im;}a[N],b[N],c[N]; inline poly operator+(poly &a,poly &b){return (poly){a.re+b.re,a.im+b.im};} inline poly operator-(poly &a,poly &b){return (poly){a.re-b.re,a.im-b.im};} inline poly operator*(poly &a,poly &b){return (poly){a.re*b.re-a.im*b.im,a.re*b.im+a.im*b.re};} int n,m,i,res[N],las; inline void fft(poly *f,int n,int type){ for(int i=0;i<n;i++) if(i<res[i]) swap(f[i],f[res[i]]); for(int i=1;i<n;i*=2){ poly wn = (poly){cos(pi/i),sin(pi/i)*type}; for(int r=i*2,j=0;j<n;j+=r){ poly w = (poly){1,0}; for(int k=0;k<i;k++,w=w*wn){ poly x = f[j+k],y = w*f[j+k+i]; f[j+k] = x+y,f[j+k+i] = x-y; } } } } int main(){
n=read()+1,m=read()+1; las=n+m-1; for(i=0;i<n;i++) a[i].re=read(); for(i=0;i<m;i++) b[i].re=read(); while((n&(-n))!=n) n++; while((m&(-m))!=m) m++; while(n<m) n++; while(m<n) m++; n*=2; for(int i=0;i<n;i++){ res[i]=(res[i>>1]>>1); if(i&1) res[i]|=(n>>1); } fft(a,n,1),fft(b,n,1); for(i=0;i<n;i++) c[i]=a[i]*b[i]; fft(c,n,-1); for(i=0;i<las;i++) write(round(c[i].re/n)),pc(' '); pc('\n'); fwrite(obuf,p3-obuf,1,stdout); return 0; }
|
最后记得因为精度有误差,必须要四舍五入一下,并且 $\frac 1n$ 需要写在四舍五入函数的里面。
应用
FFT 用于快速求解:
$$
C_i = \sum_{j=0}^i A_jB_{i-j}
$$
的所有 $C$ 的值,与此相同,如果 $A,B$ 的系数是凸的,那么闵可夫斯基和求的就是:
$$
C_i = \max_{j=0}^i A_j+B_{i-j}
$$
同时我们还可以想到迪利克雷卷积(上面那个一般称作循环卷积):
$$
C_i = \sum_{j \mid i} A_jB_{\frac ij}
$$
上面这几个可以在 $O(n \log n),O(n),O(n \log n)$ 范围内解决,都很优秀,做题的时候偏重在于转化模型得到上面的答案,再应用工具输出最终结果。
其他
FFT 同时可以利用一些循环的性质来解决(证明用到了单位根的循环性质,自证不难)
如果我们前两遍 FFT 的时候长度为 $N(N>n,N>m)$ 但是 $N<n+m-1$,即第三遍合并的时候项数不够,那么最后的结果就是一个循环移位的结果,设标准结果为 $C$,此方法得到的结果为 $C’$,则有:
$$
D_i = \sum_{k \ge 0} C_{i+kN}
$$
于是我们可以通过这个方法来去掉不需要的项来减少常数消耗。
NTT (快速数论变换)
相比于 FFT,每个数值都需要对 $mod$ 取模,这个的处理方式也很简单。
$mod$ 如果是质数就可以找到一个原根,满足原根的 $\varphi$ 次方对 $mod$ 取模后是 $1$ 和一些其它的性质。
因此我们就可以仿照单位根那样进行计算,原根基本上具有单位根的所有性质,所以这样做是对的。
接下来我们需要考虑一下哪些模数可以作为 NTT 模数,因为原根的循环节是 $\varphi$ 长度的,所以 $N$ 一定要整数 $\varphi$ 即 $mod-1$,并且 $N=2^k$,所以一个模数可以作为 NTT 模数当且仅当它分解为 $2^t \times p+q$ 之后 $2^t$ 大于等于 $N$。
常见的 NTT 模数就是 $998244353$,下面提供了一份 NTT 模板,把复数类换成整数类,把单位根换成原根就可以了。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| inline void init(ll n){ for(ll i=0;i<n;i++){ res[i]=((res[i>>1])>>1); if(i&1) res[i]+=(n>>1); } } inline void clear(ll n){for(ll i=0;i<n;i++) a[i]=b[i]=c[i]=0;} inline void ntt(ll *f,ll n,ll type){ for(ll i=0;i<n;i++) if(i<res[i]) swap(f[i],f[res[i]]); for(ll i=1;i<n;i*=2){ ll wn = qmi((type==1?G:invG),(mod-1)/(i*2),mod); for(ll j=1;j<i;j++) qm[j]=qm[j-1]*wn%mod; for(ll r=i*2,j=0;j<n;j+=r){ for(ll k=0;k<i;k++){ ll x = f[j+k],y = qm[k]*f[j+k+i]%mod; f[j+k] = (x+y)%mod,f[j+k+i] = (x-y+mod)%mod; } } } }
|
补充:据说还有任意模数多项式乘法,但是并不常见,下面列举一种常见的解法:
- 找 $3$ 个质数,满足其 $\operatorname{lcm}$ 大于等于最大的数,然后用 CRT 合并起来对要求的数取模就可以了。
这样的写法需要用 long double
合并并且常数是 $9$ 倍 NTT,如果代码写得比较好看的话,对于 $n \le 10^5$ 完全足够了。
多项式相关操作
多项式乘法
取模:NTT/CRT+NTT。
不取模:FFT/CRT+NTT。
多项式逆元
设 $A(x) \cdot B(x) \equiv 1 \pmod {x^n}$(后面的 $\bmod \ x^n$ 表示保留 $0 \sim n-1$ 次项的系数)
我们考虑化一下公式:
$$
\begin{aligned}
A(x) \cdot B(x) \equiv 1 \pmod {x^n}
\end{aligned}
$$
设 $A(x) \cdot B’(x) \equiv 1 \pmod {x^{\lceil \frac n2 \rceil}}$,则有:
$$
A(x) \cdot (B(x)-B_0(x)) \equiv 0 \pmod {x^{\lceil \frac n2 \rceil}} \\
(B(x)-B_0(x)) \equiv 0 \pmod {x^{\lceil \frac n2 \rceil}} \\
(B(x)-B_0(x))^2 \equiv 0 \pmod {x^n} \\
B(x)^2+B_0(x)^2-2B(x)B_0(x) \equiv 0 \pmod {x^n} \\
B(x)+A(x)B_0(x)^2-2B_0(x) \equiv 0 \pmod {x^n} \\
B(x) \equiv 2B_0(x)-A(x)B_0(x)^2 \pmod {x^n} \\
$$
于是我们就得到了 $B(x)$ 的递归式,每一层递归式使用 NTT 求解即可,时间复杂度为 $T(n)=T(\frac n2)+O(n \log n)$,所以 $T(n)=O(n \log n)$。
边界条件就是常数项 $B(0)=A(0)^{-1}$。
以 P4238 【模板】多项式乘法逆 参考一下代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
| #include<bits/stdc++.h> #define ll long long #define N 500005 #define mod 998244353 using namespace std; ll n,G,invG,i,res[N],qm[N],val[N],val2[N],temp[N],tempa[N],tempb[N],sqval[N],valq[N]; inline ll qmi(ll a,ll b,ll p){ ll res = 1,t = a; while(b){ if(b&1) res=res*t%p; t=t*t%p; b>>=1; } return res; } inline void init(ll n){ for(ll i=0;i<n;i++){ res[i]=((res[i>>1])>>1); if(i&1) res[i]+=(n>>1); } } inline void ntt(ll *f,ll n,ll type){ for(ll i=0;i<n;i++) if(i<res[i]) swap(f[i],f[res[i]]); for(ll i=1;i<n;i*=2){ ll wn = qmi((type==1?G:invG),(mod-1)/(i*2),mod); for(ll j=1;j<i;j++) qm[j]=qm[j-1]*wn%mod; for(ll r=i*2,j=0;j<n;j+=r){ for(ll k=0;k<i;k++){ ll x = f[j+k],y = qm[k]*f[j+k+i]%mod; f[j+k] = (x+y)%mod,f[j+k+i] = (x-y+mod)%mod; } } } } inline char nc(){ static char buf[1000000],*p=buf,*q=buf; return p==q&&(q=(p=buf)+fread(buf,1,1000000,stdin),p==q)?EOF:*p++; } inline ll read(){ ll res = 0; char c = nc(); while(c<'0'||c>'9')c=nc(); while(c<='9'&&c>='0')res=res*10+c-'0',c=nc(); return res; } char obuf[1<<21],*p3=obuf; inline void pc(char c){ p3-obuf<=(1<<20)?(*p3++=c):(fwrite(obuf,p3-obuf,1,stdout),p3=obuf,*p3++=c); } inline void write(ll x){ if(x<0) pc('-'),x=-x; if(x>9) write(x/10); pc(x%10+'0'); } inline void times(ll *a,ll *b,ll *c,ll n,ll tim){ ll nn = n; while((n&(-n))!=n) n++; n*=2; for(ll i=0;i<n;i++) tempa[i]=(i>=nn?0:a[i]),tempb[i]=(i>=nn?0:b[i]); init(n),ntt(tempa,n,1),ntt(tempb,n,1); for(ll i=0;i<n;i++){ if(tim==1) c[i]=tempa[i]*tempb[i]%mod; else c[i]=tempa[i]*tempb[i]%mod*tempb[i]%mod; } ntt(c,n,-1); ll invn = qmi(n,mod-2,mod); for(ll i=0;i<n;i++) c[i]=c[i]*invn%mod; } inline void solve(ll x){ if(x==1){ val2[0] = qmi(val[0],mod-2,mod); return ; } ll len = (x+1)/2; solve(len); times(val,val2,temp,x,2); for(ll i=0;i<x;i++) val2[i]=((2*val2[i]-temp[i])%mod+mod)%mod; } int main(){ G=3,invG=qmi(G,mod-2,mod),qm[0]=1; n=read(); for(i=0;i<n;i++) val[i]=read(); solve(n); for(i=0;i<n;i++) write(val2[i]),pc(' '); fwrite(obuf,p3-obuf,1,stdout); return 0; }
|
多项式除法
我们要求出 $F(x)$ 除以 $G(x)$ 的商 $R(x)$ 和余数 $Q(x)$。
其中 $F(x)$ 的次数是 $n$,$G(x)$ 的次数是 $m$,$R(x)$ 的次数是 $n-m$,$Q(x)$ 的次数不超过 $m-1$,这里可以看作 $m-1$(高次项补零)。
设 $F_R(x)$ 表示将 $F(x)$ 的系数翻转(次数)后形成的函数。
那么就有下面的公式:
$$
\begin{aligned}
F(x)&=G(x)R(x)+Q(x) \\
F(\frac 1x) &= G(\frac 1x)R(\frac 1x)+Q(\frac 1x) \\
F(\frac 1x)x^n &= G(\frac 1x)x^mR(\frac 1x)x^{n-m}+Q(\frac 1x)x^{m-1}x^{n-m+1} \\
F(\frac 1x)x^n &= G(\frac 1x)x^mR(\frac 1x)x^{n-m}+Q(\frac 1x)x^{m-1}x^{n-m+1} \\
F_R(x) &= G_R(x)R_R(x)+Q_R(x)x^{n-m+1} \\
F_R(x) &\equiv G_R(x)R_R(x) \pmod {x^{n-m+1}}
\end{aligned}
$$
于是对 $G_R(x)$ 求逆之后 NTT 就可以求出 $R_R(x)$ 和 $R(x)$,顺便还可以求出 $Q(x)=F(x)-G(x)R(x)$。
因为 $R$ 的次数是 $n-m$ 所以我们这么求出来的 $R$ 是对的,又因为 $Q(x)x^{n-m+1}$ 不超过 $n$ 次项,所以 $Q(x)$ 最后的结果也是不超过 $m-1$ 次项的。
以 P4512 【模板】多项式除法 为例,参考如下代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
| #include<bits/stdc++.h> #define ll long long #define N 2000005 #define mod 998244353 using namespace std; ll n,m,nn,G,invG,invn,i,res[N],qm[N],aval[N],bval[N],cval[N],val[N],val2[N],temp[N],tempp[N],len; inline ll qmi(ll a,ll b,ll p){ ll res = 1,t = a; while(b){ if(b&1) res=res*t%p; t=t*t%p; b>>=1; } return res; } inline void init(ll n){ for(ll i=0;i<n;i++){ res[i]=((res[i>>1])>>1); if(i&1) res[i]+=(n>>1); } } inline void ntt(ll *f,ll n,ll type){ for(ll i=0;i<n;i++) if(i<res[i]) swap(f[i],f[res[i]]); for(ll i=1;i<n;i*=2){ ll wn = qmi((type==1?G:invG),(mod-1)/(i*2),mod); for(ll j=1;j<i;j++) qm[j]=qm[j-1]*wn%mod; for(ll r=i*2,j=0;j<n;j+=r){ for(ll k=0;k<i;k++){ ll x = f[j+k],y = qm[k]*f[j+k+i]%mod; f[j+k] = (x+y)%mod,f[j+k+i] = (x-y+mod)%mod; } } } } inline char nc(){ static char buf[1000000],*p=buf,*q=buf; return p==q&&(q=(p=buf)+fread(buf,1,1000000,stdin),p==q)?EOF:*p++; } inline ll read(){ ll res = 0,w = 1; char c = nc(); while(c<'0'||c>'9')w=(c=='-'?-1:w),c=nc(); while(c<='9'&&c>='0')res=res*10+c-'0',c=nc(); return res*w; } char obuf[1<<21],*p3=obuf; inline void pc(char c){ p3-obuf<=(1<<20)?(*p3++=c):(fwrite(obuf,p3-obuf,1,stdout),p3=obuf,*p3++=c); } inline void write(ll x){ if(x<0) pc('-'),x=-x; if(x>9) write(x/10); pc(x%10+'0'); } inline void solve(ll x){ if(x==1){ val2[0] = qmi(val[0],mod-2,mod); return ; } solve((x+1)/2); for(ll i=0;i<x;i++) temp[i]=val[i],tempp[i]=val2[i]; ll len = x; while((len&(-len))!=len) temp[len]=tempp[len]=0,len++; temp[len]=tempp[len]=0,len++; while((len&(-len))!=len) temp[len]=tempp[len]=0,len++; ll invn = qmi(len,mod-2,mod); init(len); ntt(temp,len,1),ntt(tempp,len,1); for(ll i=0;i<len;i++) temp[i]=temp[i]*tempp[i]%mod*tempp[i]%mod; ntt(temp,len,-1); for(ll i=0;i<len;i++) temp[i]=temp[i]*invn%mod; for(ll i=0;i<x;i++) val2[i]=((2*val2[i]-temp[i])%mod+mod)%mod; } int main(){ G=3,invG=qmi(G,mod-2,mod),qm[0]=1; n=read(),m=read(); n++,m++; for(i=0;i<n;i++) aval[i]=read(),cval[i]=aval[i]; for(i=0;i<m;i++) val[i]=read(),bval[i]=val[i]; reverse(aval,aval+n),reverse(val,val+m); solve(n); nn=1; while(nn<n*2) nn*=2; for(i=n;i<nn;i++) aval[i]=0,val2[i]=0; init(nn),ntt(aval,nn,1),ntt(val2,nn,1); for(i=0;i<nn;i++) aval[i]=aval[i]*val2[i]%mod; ntt(aval,nn,-1),invn=qmi(nn,mod-2,mod); for(i=0;i<nn;i++) aval[i]=aval[i]*invn%mod; for(i=n-m;i>=0;i--) write(aval[i]),pc(' '),val2[len++]=aval[i]; pc('\n'); for(i=len;i<nn;i++) val2[i]=0; ntt(bval,nn,1),ntt(val2,nn,1); for(i=0;i<nn;i++) bval[i]=bval[i]*val2[i]%mod; ntt(bval,nn,-1),invn=qmi(nn,mod-2,mod); for(i=0;i<nn;i++) bval[i]=bval[i]*invn%mod; for(i=0;i<m-1;i++) write((cval[i]-bval[i]+mod)%mod),pc(' '); pc('\n'); fwrite(obuf,p3-obuf,1,stdout); return 0; }
|
多项式开方
考虑像求逆元那样倍增,那么设 $H(x)^2 \equiv F(x) \pmod {x^{\lceil \frac n2 \rceil}}$,有:
$$
\begin{aligned}
G(x) &\equiv H(x) \pmod {x^{\lceil \frac n2 \rceil}} \\
G(x)-H(x) &\equiv 0 \pmod {x^{\lceil \frac n2 \rceil}} \\
(G(x)-H(x))^2 &\equiv 0\pmod {x^n} \\
G(x)^2+H(x)^2-2G(x)H(x) &\equiv 0\pmod {x^n} \\
H(x)^2+H(x)^2-2G(x)H(x) &\equiv 0\pmod {x^n} \\
F(x)^2+H(x)^2-2G(x)H(x) &\equiv 0\pmod {x^n} \\
G(x) &\equiv \frac {F(x)+H(x)^2}{2H(x)} \pmod {x^n}
\end{aligned}
$$
边界情况需要注意 $G(0)$ 是 $F(0)$ 的二次剩余!
因此我们直接 $O(n \log n)$ 求即可,下面以 P5205 【模板】多项式开根 欣赏一下代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
| #include<bits/stdc++.h> #define ll long long #define N 500005 #define mod 998244353 using namespace std; ll n,G,invG,inv2,i,res[N],qm[N],val[N],val2[N],temp[N],temp2[N],tempa[N],tempb[N],sqval[N],valq[N]; inline ll qmi(ll a,ll b,ll p){ ll res = 1,t = a; while(b){ if(b&1) res=res*t%p; t=t*t%p; b>>=1; } return res; } inline void init(ll n){ for(ll i=0;i<n;i++){ res[i]=((res[i>>1])>>1); if(i&1) res[i]+=(n>>1); } } inline void ntt(ll *f,ll n,ll type){ for(ll i=0;i<n;i++) if(i<res[i]) swap(f[i],f[res[i]]); for(ll i=1;i<n;i*=2){ ll wn = qmi((type==1?G:invG),(mod-1)/(i*2),mod); for(ll j=1;j<i;j++) qm[j]=qm[j-1]*wn%mod; for(ll r=i*2,j=0;j<n;j+=r){ for(ll k=0;k<i;k++){ ll x = f[j+k],y = qm[k]*f[j+k+i]%mod; f[j+k] = (x+y)%mod,f[j+k+i] = (x-y+mod)%mod; } } } } inline char nc(){ static char buf[1000000],*p=buf,*q=buf; return p==q&&(q=(p=buf)+fread(buf,1,1000000,stdin),p==q)?EOF:*p++; } inline ll read(){ ll res = 0; char c = nc(); while(c<'0'||c>'9')c=nc(); while(c<='9'&&c>='0')res=res*10+c-'0',c=nc(); return res; } char obuf[1<<21],*p3=obuf; inline void pc(char c){ p3-obuf<=(1<<20)?(*p3++=c):(fwrite(obuf,p3-obuf,1,stdout),p3=obuf,*p3++=c); } inline void write(ll x){ if(x<0) pc('-'),x=-x; if(x>9) write(x/10); pc(x%10+'0'); } inline void times(ll *a,ll *b,ll *c,ll n,ll tim){ ll nn = n; while((n&(-n))!=n) n++; n*=2; for(ll i=0;i<n;i++) tempa[i]=(i>=nn?0:a[i]),tempb[i]=(i>=nn?0:b[i]); init(n),ntt(tempa,n,1),ntt(tempb,n,1); for(ll i=0;i<n;i++){ if(tim==1) c[i]=tempa[i]*tempb[i]%mod; else c[i]=tempa[i]*tempb[i]%mod*tempb[i]%mod; } ntt(c,n,-1); ll invn = qmi(n,mod-2,mod); for(ll i=0;i<n;i++) c[i]=c[i]*invn%mod; } inline void solve(ll x){ if(x==1){ val2[0] = qmi(val[0],mod-2,mod); return ; } ll len = (x+1)/2; solve(len); times(val,val2,temp2,x,2); for(ll i=0;i<x;i++) val2[i]=((2*val2[i]-temp2[i])%mod+mod)%mod; } inline void sqr(ll x){ if(x==1){ sqval[0]=1; return ; } sqr((x+1)/2); times(sqval,sqval,temp,x,1); for(ll i=0;i<x;i++) temp[i]=(temp[i]+valq[i])%mod; for(ll i=0;i<x;i++) val[i]=sqval[i],val2[i]=0; solve(x); times(temp,val2,temp,x,1); for(ll i=0;i<x;i++) sqval[i]=temp[i]*inv2%mod; } int main(){ G=3,invG=qmi(G,mod-2,mod),qm[0]=1,inv2=qmi(2,mod-2,mod); n=read(); for(i=0;i<n;i++) valq[i]=read(); sqr(n); for(i=0;i<n;i++) write(sqval[i]),pc(' '); fwrite(obuf,p3-obuf,1,stdout); return 0; }
|
推荐多项式求逆写一个比较好看的函数,这样的话方便调试。
多项式对数函数 ln
首先设 $f’(x) = \lim_{x \to 0} \frac {\Delta f(x)}{\Delta x}$,说人话就是斜率。
那么多项式 ln 需要用到的公式就是:$(\ln x)’= \frac 1x$,$(x^a)’=ax^{a-1}$,$(f(g(x)))’=f’(g(x))g’(x)$。
那么就有 $G(x)=F(A(x)),F(x)=\ln(x)$,两边求导得 $G’(x)=F’(A(x))A’(x)=\frac {A’(x)}{A(x)}$,于是直接求导乘上逆元就可以了。
最后我们要通过 $G’(x)$ 还原出 $G(x)$,于是进行求导的逆运算——积分就可以了。
这里的积分就是 $\int x^adx=\frac{1}{a+1}x^{a+1}$ 这一个公式。
代码以 P4725 【模板】多项式对数函数(多项式 ln) 为例:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
| #include<bits/stdc++.h> #define ll long long #define N 500005 #define mod 998244353 using namespace std; ll n,m,x,G,invG,inv2,i,res[N],qm[N],val[N],val2[N],temp[N],temp2[N],tempas[N],tempa[N],tempb[N],sqval[N],valqs[N],valq[N],ans[N]; inline ll qmi(ll a,ll b,ll p){ ll res = 1,t = a; while(b){ if(b&1) res=res*t%p; t=t*t%p; b>>=1; } return res; } inline void init(ll n){ for(ll i=0;i<n;i++){ res[i]=((res[i>>1])>>1); if(i&1) res[i]+=(n>>1); } } inline void ntt(ll *f,ll n,ll type){ for(ll i=0;i<n;i++) if(i<res[i]) swap(f[i],f[res[i]]); for(ll i=1;i<n;i*=2){ ll wn = qmi((type==1?G:invG),(mod-1)/(i*2),mod); for(ll j=1;j<i;j++) qm[j]=qm[j-1]*wn%mod; for(ll r=i*2,j=0;j<n;j+=r){ for(ll k=0;k<i;k++){ ll x = f[j+k],y = qm[k]*f[j+k+i]%mod; f[j+k] = (x+y)%mod,f[j+k+i] = (x-y+mod)%mod; } } } } inline char nc(){ static char buf[1000000],*p=buf,*q=buf; return p==q&&(q=(p=buf)+fread(buf,1,1000000,stdin),p==q)?EOF:*p++; } inline ll read(){ ll res = 0; char c = nc(); while(c<'0'||c>'9')c=nc(); while(c<='9'&&c>='0')res=res*10+c-'0',c=nc(); return res; } char obuf[1<<21],*p3=obuf; inline void pc(char c){ p3-obuf<=(1<<20)?(*p3++=c):(fwrite(obuf,p3-obuf,1,stdout),p3=obuf,*p3++=c); } inline void write(ll x){ if(x<0) pc('-'),x=-x; if(x>9) write(x/10); pc(x%10+'0'); } inline void times(ll *a,ll *b,ll *c,ll n,ll tim){ ll nn = n; while((n&(-n))!=n) n++; n*=2; for(ll i=0;i<n;i++) tempa[i]=(i>=nn?0:a[i]),tempb[i]=(i>=nn?0:b[i]); init(n),ntt(tempa,n,1),ntt(tempb,n,1); for(ll i=0;i<n;i++){ if(tim==1) c[i]=tempa[i]*tempb[i]%mod; else c[i]=tempa[i]*tempb[i]%mod*tempb[i]%mod; } ntt(c,n,-1); ll invn = qmi(n,mod-2,mod); for(ll i=0;i<n;i++) c[i]=c[i]*invn%mod; } inline void solve(ll x){ if(x==1){ val2[0] = qmi(val[0],mod-2,mod); return ; } ll len = (x+1)/2; solve(len); times(val,val2,temp2,x,2); for(ll i=0;i<x;i++) val2[i]=((2*val2[i]-temp2[i])%mod+mod)%mod; } inline void sqr(ll x){ if(x==1){ sqval[0]=1; return ; } sqr((x+1)/2); times(sqval,sqval,temp,x,1); for(ll i=0;i<x;i++) temp[i]=(temp[i]+valq[i])%mod; for(ll i=0;i<x;i++) val[i]=sqval[i],val2[i]=0; solve(x); times(temp,val2,temp,x,1); for(ll i=0;i<x;i++) sqval[i]=temp[i]*inv2%mod; } inline void qd(ll *f,ll x){ for(ll i=1;i<x;i++) f[i-1]=f[i]*i; f[x-1]=0; } inline void jf(ll *f,ll x){ for(ll i=x-1;i>0;i--) f[i]=f[i-1]*qmi(i,mod-2,mod)%mod; f[0]=0; } inline void ln(ll *f,ll *ans,ll n){ for(ll i=0;i<n;i++) val[i]=f[i]; solve(n),qd(f,n); times(f,val2,ans,n,1); jf(ans,n); } int main(){ G=3,invG=qmi(G,mod-2,mod),qm[0]=1,inv2=qmi(2,mod-2,mod); n=read(); for(i=0;i<n;i++) tempas[i]=read(); ln(tempas,ans,n); for(i=0;i<n;i++) write(ans[i]),pc(' '); fwrite(obuf,p3-obuf,1,stdout); return 0; }
|
多项式指数函数 exp
牛顿迭代
牛顿迭代主要用来求解一个函数的零点的整数值(不用精确到小数),首先我们随机取一个点 $x$ 算出其函数值 $y=f(x)$,然后我们要得到过点 $(x,y)$ 并且切函数 $f(x)$ 的直线。
这个很好办,首先求出导数,即点上的斜率 $f’(x)$,那么直线表达式就是 $y=f’(x)(x_0-x)+f(x)$($x_0$ 是带进去的值)。
然后找到其与 $x$ 轴的交点,令交点为下一次随机选择的点处理即可,即 $y=0$,然后得到 $x_0=x-\frac{f(x)}{f’(x)}$,放到多项式上也是同理。
那么我们要找到一个函数 $G(x)$ 满足 $F(G(x))=0$,那么每次令 $G_0(x)=G(x)-\frac {F(G(x))}{(F’(G(x)))}$ 就可以了,事实上每次执行精度都会翻倍,即若 $F(G(x)) \equiv 0 \pmod {x^n}$ 那么 $F(G_0(x)) \equiv 0 \pmod {x^{2n}}$。
牛顿迭代 $\to$ 多项式 exp
题目给定 $A(x)$ 求 $B(x) \equiv e^{A(x)} \pmod {x^n}$,那么就有 $\ln B(x)-A(x) \equiv 0 \pmod {x^n}$,我们需要构造函数 $F(G(x))=\ln G(x)-A(x) \equiv 0 \pmod {x^n}$。
则 $F’(G(x))=(\ln(G(x))’-A(x)’)=\frac{1}{G(x)}$,所以 $G_0(x)=G(x)-\frac{\ln G(x)-A(x)}{\frac{1}{G(x)}}=G(x)-\frac{G(x)(\ln G(x)-A(x))}{1}=G(x)(1-\ln G(x)+A(x))$。
所以每次牛顿迭代求解即可,时间复杂度依然是大常数 $O(n \log n)$,边界情况就是 $B(0)=G(0)=1$。
特别地:只有当 $A(0)=0$ 的时候才能用 exp,当 $A(0)=1$ 的时候才能用 ln,两者的关系是相互的!
代码以 P4726 【模板】多项式指数函数(多项式 exp) 为例参考:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151
| #include<bits/stdc++.h> #define ll long long #define N 500005 #define mod 998244353 using namespace std; ll n,m,x,G,invG,inv2,i,res[N],qm[N],val[N],val2[N],temp[N],temp2[N],nexp[N],expx[N],tempa[N],tempb[N],sqval[N],valqs[N],valq[N],ans[N]; ll tempas[N],tempbs[N],tempcs[N]; inline ll qmi(ll a,ll b,ll p){ ll res = 1,t = a; while(b){ if(b&1) res=res*t%p; t=t*t%p; b>>=1; } return res; } inline void init(ll n){ for(ll i=0;i<n;i++){ res[i]=((res[i>>1])>>1); if(i&1) res[i]+=(n>>1); } } inline void ntt(ll *f,ll n,ll type){ for(ll i=0;i<n;i++) if(i<res[i]) swap(f[i],f[res[i]]); for(ll i=1;i<n;i*=2){ ll wn = qmi((type==1?G:invG),(mod-1)/(i*2),mod); for(ll j=1;j<i;j++) qm[j]=qm[j-1]*wn%mod; for(ll r=i*2,j=0;j<n;j+=r){ for(ll k=0;k<i;k++){ ll x = f[j+k],y = qm[k]*f[j+k+i]%mod; f[j+k] = (x+y)%mod,f[j+k+i] = (x-y+mod)%mod; } } } } inline char nc(){ static char buf[1000000],*p=buf,*q=buf; return p==q&&(q=(p=buf)+fread(buf,1,1000000,stdin),p==q)?EOF:*p++; } inline ll read(){ ll res = 0; char c = nc(); while(c<'0'||c>'9')c=nc(); while(c<='9'&&c>='0')res=res*10+c-'0',c=nc(); return res; } char obuf[1<<21],*p3=obuf; inline void pc(char c){ p3-obuf<=(1<<20)?(*p3++=c):(fwrite(obuf,p3-obuf,1,stdout),p3=obuf,*p3++=c); } inline void write(ll x){ if(x<0) pc('-'),x=-x; if(x>9) write(x/10); pc(x%10+'0'); } inline void times(ll *a,ll *b,ll *c,ll n,ll tim){ ll nn = n; while((n&(-n))!=n) n++; n*=2; for(ll i=0;i<n;i++) tempa[i]=(i>=nn?0:a[i]),tempb[i]=(i>=nn?0:b[i]); init(n),ntt(tempa,n,1),ntt(tempb,n,1); for(ll i=0;i<n;i++){ if(tim==1) c[i]=tempa[i]*tempb[i]%mod; else c[i]=tempa[i]*tempb[i]%mod*tempb[i]%mod; } ntt(c,n,-1); ll invn = qmi(n,mod-2,mod); for(ll i=0;i<n;i++) c[i]=c[i]*invn%mod; } inline void solve(ll x){ if(x==1){ val2[0] = qmi(val[0],mod-2,mod); return ; } ll len = (x+1)/2; solve(len); for(ll i=len;i<x;i++) val2[i]=0; times(val,val2,temp2,x,2); for(ll i=0;i<x;i++) val2[i]=((2*val2[i]-temp2[i])%mod+mod)%mod; } struct poly{ll a,b;}ress,t; ll nk,nks; poly operator*(poly a,poly b){return (poly){((a.a*b.a+a.b*b.b%mod*((nks*nks-nk)%mod))+mod)%mod,(a.b*b.a+a.a*b.b)%mod};} mt19937 rnd(time(0)); inline ll cipolla(ll n){ if(n==0) return 0; nk=n; ll x,y; while(1){ x=rnd()%mod; if(qmi((((x*x-n)%mod)+mod)%mod,(mod-1)/2,mod)==mod-1) break; } ress.a = 1,ress.b = 0,t.a = x,t.b = 1,nks = x; y = (mod+1)/2; while(y){ if(y&1) ress=ress*t; t=t*t; y>>=1; } return min(ress.a,mod-ress.a); } inline void sqr(ll x){ if(x==1){ sqval[0]=cipolla(valq[0]); return ; } sqr((x+1)/2); times(sqval,sqval,temp,x,1); for(ll i=0;i<x;i++) temp[i]=(temp[i]+valq[i])%mod; for(ll i=0;i<x;i++) val[i]=sqval[i]; solve(x); times(temp,val2,temp,x,1); for(ll i=0;i<x;i++) sqval[i]=temp[i]*inv2%mod; } inline void qd(ll *f,ll x){ for(ll i=1;i<x;i++) f[i-1]=f[i]*i; f[x-1]=0; } inline void jf(ll *f,ll x){ for(ll i=x-1;i>0;i--) f[i]=f[i-1]*qmi(i,mod-2,mod)%mod; f[0]=0; } inline void ln(ll *f,ll *ans,ll n){ for(ll i=0;i<n;i++) val[i]=f[i],val2[i]=0; solve(n),qd(f,n); times(f,val2,ans,n,1); jf(ans,n); } inline void init(){G=3,invG=qmi(G,mod-2,mod),qm[0]=1,inv2=qmi(2,mod-2,mod);} inline void newton(ll x){ if(x==1){ expx[0]=1; return ; } newton((x+1)/2); for(ll i=0;i<x;i++) tempcs[i]=expx[i]; ln(tempcs,tempas,x); for(ll i=0;i<x;i++) tempas[i]=(mod-tempas[i]+nexp[i])%mod; tempas[0]=(tempas[0]+1)%mod; times(tempas,expx,tempbs,x,1); for(ll i=0;i<x;i++) expx[i]=tempbs[i]; } int main(){ init(); n=read(); for(i=0;i<n;i++) nexp[i]=read(); newton(n); for(i=0;i<n;i++) write(expx[i]),pc(' '); fwrite(obuf,p3-obuf,1,stdout); return 0; }
|
多项式快速幂
求 $A(x)^k$ 的值,首先取 $\ln$,然后在 $\ln$ 下乘上 $k$,最后再 $\exp$ 回去就可以了。
代码以 P5273 【模板】多项式幂函数(加强版) 为例参考:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184
| #include<bits/stdc++.h> #define ll long long #define N 500005 #define mod 998244353 using namespace std; ll n,m,k,x,G,invG,inv2,i,j,valll[N],res[N],qm[N],val[N],val2[N],temp[N],temp2[N],nexp[N],expx[N],tempa[N],tempb[N],sqval[N],valqs[N],valq[N],ans[N]; ll tempas[N],tempbs[N],tempcs[N],vall[N],pos,qms,modphi,ks,ifdy; inline ll qmi(ll a,ll b,ll p){ ll res = 1,t = a; while(b){ if(b&1) res=res*t%p; t=t*t%p; b>>=1; } return res; } inline void init(ll n){ for(ll i=0;i<n;i++){ res[i]=((res[i>>1])>>1); if(i&1) res[i]+=(n>>1); } } inline void ntt(ll *f,ll n,ll type){ for(ll i=0;i<n;i++) if(i<res[i]) swap(f[i],f[res[i]]); for(ll i=1;i<n;i*=2){ ll wn = qmi((type==1?G:invG),(mod-1)/(i*2),mod); for(ll j=1;j<i;j++) qm[j]=qm[j-1]*wn%mod; for(ll r=i*2,j=0;j<n;j+=r){ for(ll k=0;k<i;k++){ ll x = f[j+k],y = qm[k]*f[j+k+i]%mod; f[j+k] = (x+y)%mod,f[j+k+i] = (x-y+mod)%mod; } } } } inline char nc(){ static char buf[1000000],*p=buf,*q=buf; return p==q&&(q=(p=buf)+fread(buf,1,1000000,stdin),p==q)?EOF:*p++; } inline ll read(){ ll res = 0; modphi = 0; ifdy = 0; char c = nc(); while(c<'0'||c>'9')c=nc(); while(c<='9'&&c>='0')res=res*10+c-'0',ifdy|=(res>=mod),modphi=modphi*10+c-'0',c=nc(),res%=mod,modphi%=(mod-1); return res; } char obuf[1<<21],*p3=obuf; inline void pc(char c){ p3-obuf<=(1<<20)?(*p3++=c):(fwrite(obuf,p3-obuf,1,stdout),p3=obuf,*p3++=c); } inline void write(ll x){ if(x<0) pc('-'),x=-x; if(x>9) write(x/10); pc(x%10+'0'); } inline void times(ll *a,ll *b,ll *c,ll n,ll tim){ ll nn = n; while((n&(-n))!=n) n++; n*=2; for(ll i=0;i<n;i++) tempa[i]=(i>=nn?0:a[i]),tempb[i]=(i>=nn?0:b[i]); init(n),ntt(tempa,n,1),ntt(tempb,n,1); for(ll i=0;i<n;i++){ if(tim==1) c[i]=tempa[i]*tempb[i]%mod; else c[i]=tempa[i]*tempb[i]%mod*tempb[i]%mod; } ntt(c,n,-1); ll invn = qmi(n,mod-2,mod); for(ll i=0;i<n;i++) c[i]=c[i]*invn%mod; } inline void solve(ll x){ if(x==1){ val2[0] = qmi(val[0],mod-2,mod); return ; } ll len = (x+1)/2; solve(len); for(ll i=len;i<x;i++) val2[i]=0; times(val,val2,temp2,x,2); for(ll i=0;i<x;i++) val2[i]=((2*val2[i]-temp2[i])%mod+mod)%mod; } struct poly{ll a,b;}ress,t; ll nk,nks,kifdy; poly operator*(poly a,poly b){return (poly){((a.a*b.a+a.b*b.b%mod*((nks*nks-nk)%mod))+mod)%mod,(a.b*b.a+a.a*b.b)%mod};} mt19937 rnd(time(0)); inline ll cipolla(ll n){ if(n==0) return 0; nk=n; ll x,y; while(1){ x=rnd()%mod; if(qmi((((x*x-n)%mod)+mod)%mod,(mod-1)/2,mod)==mod-1) break; } ress.a = 1,ress.b = 0,t.a = x,t.b = 1,nks = x; y = (mod+1)/2; while(y){ if(y&1) ress=ress*t; t=t*t; y>>=1; } return min(ress.a,mod-ress.a); } inline void sqr(ll x){ if(x==1){ sqval[0]=cipolla(valq[0]); return ; } sqr((x+1)/2); times(sqval,sqval,temp,x,1); for(ll i=0;i<x;i++) temp[i]=(temp[i]+valq[i])%mod; for(ll i=0;i<x;i++) val[i]=sqval[i]; solve(x); times(temp,val2,temp,x,1); for(ll i=0;i<x;i++) sqval[i]=temp[i]*inv2%mod; } inline void qd(ll *f,ll x){ for(ll i=1;i<x;i++) f[i-1]=f[i]*i; f[x-1]=0; } inline void jf(ll *f,ll x){ for(ll i=x-1;i>0;i--) f[i]=f[i-1]*qmi(i,mod-2,mod)%mod; f[0]=0; } inline void ln(ll *f,ll *ans,ll n){ for(ll i=0;i<n;i++) val[i]=f[i],val2[i]=0; solve(n),qd(f,n); times(f,val2,ans,n,1); jf(ans,n); } inline void init(){G=3,invG=qmi(G,mod-2,mod),qm[0]=1,inv2=qmi(2,mod-2,mod);} inline void newton(ll x){ if(x==1){ expx[0]=1; return ; } newton((x+1)/2); for(ll i=0;i<x;i++) tempcs[i]=expx[i]; ln(tempcs,tempas,x); for(ll i=0;i<x;i++) tempas[i]=(mod-tempas[i]+nexp[i])%mod; tempas[0]=(tempas[0]+1)%mod; times(tempas,expx,tempbs,x,1); for(ll i=0;i<x;i++) expx[i]=tempbs[i]; } int main(){ init(); n=read(); n++; for(i=0;i<n;i++) vall[i]=read(),valll[i]=vall[i]; k=read(),ks=modphi,kifdy=ifdy; m=read(); n=max(n,m); for(i=0;i<n;i++) if(vall[i]!=0) break; pos=i; for(j=i;j<n;j++) vall[j-i]=valll[j]; n-=pos,qms=valll[pos]; for(i=0;i<n;i++) vall[i]=vall[i]*qmi(qms,mod-2,mod)%mod; ln(vall,nexp,n); for(i=0;i<n;i++) nexp[i]=nexp[i]*k%mod; newton(n); if(pos>0&&kifdy){ for(i=0;i<n+pos;i++){ if(i>=m) break; write(0),pc(' '); } fwrite(obuf,p3-obuf,1,stdout); return 0; } for(i=0;i<min(n+pos,pos*k);i++){ if(i>=m) break; write(0),pc(' '); } for(i=pos*k;i<n+pos;i++){ if(i>=m) break; write(expx[i-pos*k]*qmi(qms,ks,mod)%mod),pc(' '); } fwrite(obuf,p3-obuf,1,stdout); return 0; }
|
代码中添加了加强版的数据判断,因为有可能不保证 $A(0)=1$,所以要找到一个最小的 $i$ 使得 $\forall 0 \le j<i,A(j)=0,A(i) \ne 0$,然后全部除以 $A(i)x^i$,这样的话 $A(0)$ 就是 $1$ 了。
最后答案需要乘上 $A(i)^k x^{ik}$,特别注意,输入的 $k$ 次方满足 $A(x)^k \equiv A(x)^{k \ \bmod \ p} \pmod{p}$,$p$ 足够大,这道题 $p=998244353$ 已经足够大了。
但是 $A(i)^k$ 却不满足上述公式,它满足的是欧拉公式,即 $A(i)^k \equiv A(i)^{k \ \bmod \ \varphi(p)} \pmod p$,所以还要新建一个变量表示 $k \bmod \varphi(p)$。(前提是 $\gcd(p,k)$ 互质,因为 $p=998244353$,所以也可以使用)
最后记得特判一下 $x^{ik} \ge x^n$ 的情况。