拉格朗日插值 & 线性递推

拉格朗日插值 & 线性递推

拉格朗日插值

动态拉格朗日插值

插值公式($n$ 个点 $x_i,y_i$ 确定一个 $n-1$ 次多项式,求出点 $x$ 处的取值):

$$
f(x) = \sum_{i=1}^n y_i \prod_{j \ne i} \frac{x-x_j}{x_i-x_j}
$$

这个插值公式显然是 $O(n^2)$ 的,但是它可以动态向集合中加入点,动态询问当前多项式的点值。

也就是用前后缀积优化一下就可以了,此处给出代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include<bits/stdc++.h>
#pragma GCC optimize("Ofast")
#define ll long long
#define mod 998244353
#define N 3005
using namespace std;
ll n,opt,x,y,tx[N],ty[N],opts[N],q1[N],q2[N],tot;
inline ll qmi(ll a,ll b,ll p){
ll res = 1%p,t = a;
while(b){
if(b&1) res=res*t%p;
t=t*t%p;
b>>=1;
}
return res;
}
inline void insert(ll x,ll y){
tot++,tx[tot]=x,ty[tot]=y,opts[tot]=1;
for(ll i=1;i<tot;i++) opts[tot]=opts[tot]*(x-tx[i]+mod)%mod;
for(ll i=1;i<tot;i++) opts[i]=opts[i]*(tx[i]-x+mod)%mod;
}
inline ll query(ll x){
ll ans = 0;
q1[0] = 1;
for(ll i=1;i<=tot;i++) q1[i]=q1[i-1]*(x-tx[i]+mod)%mod;
q2[tot+1] = 1;
for(ll i=tot;i>=1;i--) q2[i]=q2[i+1]*(x-tx[i]+mod)%mod;
for(ll i=1;i<=tot;i++){
ll res = q1[i-1]*q2[i+1]%mod;
res = res*qmi(opts[i],mod-2,mod)%mod;
ans = (ans+res*ty[i])%mod;
}
return ans;
}
int main(){
ios::sync_with_stdio(false);
cin>>n;
while(n--){
cin>>opt;
if(opt==1){
cin>>x>>y;
insert(x,y);
}
else{
cin>>x;
cout<<query(x)<<endl;
}
}
return 0;
}

快速拉格朗日插值

这个方法只适用于静态的 $n$ 个点,不支持加点,然后询问多个点的点值,并且或者让你求出系数。

我们这里只讲求出系数,因为询问多个点的点值也很简单,把 $x$ 展开之后用多项式多点求值即可,或者求出系数之后多项式多点求值。

如何求出系数?我们回到公式:

$$
f(x) = \sum_{i=1}^n y_i \prod_{j \ne i} \frac{x-x_j}{x_i-x_j}
$$

首先求出常数项:

$$
f(x) = \sum_{i=1}^n y_i \prod_{j \ne i} \frac{1}{x_i-x_j}
$$

这个怎么做?相当于对于多项式 $A=\prod_{i=1}^n \frac 1{x-x_i}$ 求出 $x$ 在等于 $x_{1 \sim n}$ 处的点值然后乘上 $x_i-x_i=0$。

因为这两个值都是 $0$,考虑运用洛必达法则,也就是 $x$ 在等于 $x_{1 \sim n}$ 处的点值然后乘上 $x_i-x_i=0$ 就等于 $x$ 在多项式 $A=\prod_{i=1}^n \frac 1{x-x_i}$ 求导之后的点值。

这一部分用多项式多点求值即可。

设只用 $x_{l \sim r}$ 算出来 $f_{l,r}$(是一个多项式),就是 $\sum_{i=l}^r y_iA’(x_i) \prod_{l \le j \le r,j \ne i} (x-x_j)$,然后考虑分治合并。

首先 $f_{l,l}$ 可以很简单地得到是 $y_iA’(x_i)$,然后我们考虑合并 $f_{l,mid}$ 和 $f_{mid+1,r}$:

$$
\begin{aligned}
f_{l,r}&= \sum_{i=l}^r y_iA’(x_i) \prod_{l \le j \le r,j \ne i} (x-x_j) \\
&= \sum_{i=l}^{mid} y_iA’(x_i) \prod_{l \le j \le r,j \ne i} (x-x_j) +\sum_{i=mid+1}^{r} y_iA’(x_i) \prod_{l \le j \le r,j \ne i} (x-x_j) \\
&= \sum_{i=l}^{mid} y_iA’(x_i) \prod_{l \le j \le mid,j \ne i} (x-x_j) \prod_{mid+1 \le j \le r}(x-x_j) +\sum_{i=mid+1}^{r} y_iA’(x_i) \prod_{mid+1 \le j \le r,j \ne i} (x-x_j) \prod_{l \le j \le mid}(x-x_j) \\
&= f_{l,mid}g_{mid+1,r} + f_{mid+1,r}g_{l,mid} \\
\end{aligned}
$$

其中 $g_{l,r} = \prod_{i=l}^r(x-x_i)$,然后用分治下去的 NTT 计算就可以了。

最后的答案就是 $f_{1,n}$,代码实现如下,时间复杂度 $O(n \log^2 n)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
#include<bits/stdc++.h>
#define ll long long
#define N 500005
#define mod 998244353
using namespace std;
namespace IO{
inline char nc(){
static char buf[1000000],*p=buf,*q=buf;
return p==q&&(q=(p=buf)+fread(buf,1,1000000,stdin),p==q)?EOF:*p++;
}
inline ll read(){
ll res = 0,w = 1;
char c = nc();
while(c<'0'||c>'9')w=(c=='-'?-1:w),c=nc();
while(c<='9'&&c>='0')res=res*10+c-'0',c=nc();
return res*w;
}
char obuf[1<<21],*p3=obuf;
inline void pc(char c){
p3-obuf<=(1<<20)?(*p3++=c):(fwrite(obuf,p3-obuf,1,stdout),p3=obuf,*p3++=c);
}
inline void write(ll x){
if(x<0) pc('-'),x=-x;
if(x>9) write(x/10);
pc(x%10+'0');
}
}

using namespace IO;

inline ll qmi(ll a,ll b,ll p){
ll res = 1,t = a;
while(b){
if(b&1) res=res*t%p;
t=t*t%p;
b>>=1;
}
return res;
}
namespace poly{
ll n,m,len,res[N],G,invG,invn,nn,qm[N],qm2[N],temp[N],tempa[N],tempb[N],temp2[N],c[N];
inline void init(ll n){
for(ll i=0;i<n;i++){
res[i]=((res[i>>1])>>1);
if(i&1) res[i]+=(n>>1);
}
for(ll i=0;i<18;i++){
ll w1=qmi(G,(mod-1)/(1<<(i+1)),mod);
ll w2=qmi(invG,(mod-1)/(1<<(i+1)),mod);
for(ll j=0,now=1,now2=1;j<(1<<i);j++,now=now*w1%mod,now2=now2*w2%mod) qm[(1<<i)+j]=now,qm2[(1<<i)+j]=now2;
}
}
inline void ntt(ll *f,ll n,ll type){
static unsigned long long tmp[N];
ll u = __builtin_ctz((1ll<<18)/n);
for(ll i=0;i<n;i++) tmp[i]=f[res[i]>>u];
for(ll i=1;i<n;i<<=1){
for(ll r=(i<<1),j=0;j<n;j+=r){
for(ll k=0;k<i;k++){
ll y = (type==1?qm[i|k]:qm2[i|k])*tmp[j|k|i]%mod;
tmp[j|k|i] = (tmp[j|k]+mod-y),tmp[j|k]+=y;
}
}
}
for(ll i=0;i<n;i++) f[i]=tmp[i]%mod;
}
inline void times(ll *a,ll *b,ll *cp,ll n,ll m,ll tim){
ll nn = n,mm = m,op = 1;
while(op<=n+m) op*=2;
n=op;
for(ll i=0;i<n;i++) tempa[i]=(i>=nn?0:a[i]);
for(ll i=0;i<n;i++) tempb[i]=(i>=mm?0:b[i]);
ntt(tempa,n,1),ntt(tempb,n,1);
for(ll i=0;i<n;i++){
if(tim==1) c[i]=tempa[i]*tempb[i]%mod;
else c[i]=tempa[i]*tempb[i]%mod*tempb[i]%mod;
}
ntt(c,n,-1);
ll invn = qmi(n,mod-2,mod);
for(ll i=0;i<nn+mm;i++) cp[i]=c[i]*invn%mod;
}
inline void solve(ll x,ll *val,ll *val2){
if(x==1){
val2[0] = qmi(val[0],mod-2,mod);
return ;
}
ll len = (x+1)/2;
solve(len,val,val2);
for(ll i=len;i<x;i++) val2[i]=0;
times(val,val2,temp2,x,x,2);
for(ll i=0;i<x;i++) val2[i]=((2*val2[i]-temp2[i])%mod+mod)%mod;
}
ll *f[N],*g[N],bin[N<<6],*pos(bin),t1[N],t2[N],t3[N];
inline void get_val_init(ll *a,ll l,ll r,ll p){
g[p]=pos,pos+=(r-l+2),f[p]=pos,pos+=(r-l+2);
if(l==r){
g[p][0]=1,g[p][1]=(mod-a[l])%mod;
return ;
}
ll mid = (l+r)/2;
get_val_init(a,l,mid,2*p),get_val_init(a,mid+1,r,2*p+1);
times(g[2*p],g[2*p+1],g[p],mid-l+2,r-mid+1,1);
}
inline void get_val_solve(ll *a,ll *b,ll *c,ll l,ll r,ll p){
if(l==r){
c[l]=f[p][0];
return ;
}
ll mid = (l+r)/2,len1 = mid-l+2,len2 = r-mid+1;
reverse(g[2*p+1],g[2*p+1]+len2);
times(f[p],g[2*p+1],t3,r-l+1,len2,1);
for(ll i=0;i<len1;i++) f[2*p][i]=t3[i+len2-1];
get_val_solve(a,b,c,l,mid,2*p);

reverse(g[2*p],g[2*p]+len1);
times(f[p],g[2*p],t3,r-l+1,len1,1);
for(ll i=0;i<len2;i++) f[2*p+1][i]=t3[i+len1-1];
get_val_solve(a,b,c,mid+1,r,2*p+1);
}
inline void get_val(ll *a,ll *b,ll *c,ll n,ll m){
get_val_init(b,1,m,1);
solve(m+1,g[1],t1);
reverse(t1,t1+m+1),times(a,t1,t2,n+1,m+1,1);
for(ll i=0;i<=m;i++) f[1][i]=t2[n+i];
get_val_solve(a,b,c,1,m,1);
for(ll i=1;i<=m;i++) c[i]=(c[i]*b[i]+a[0])%mod;
}
}
ll n,m,a[N],b[N],i,c[N],d[N],k,*g[N],*f[N],bin[N<<5],*pos(bin),temp[N];
inline void binary_ntt(ll *a,ll l,ll r,ll p){
g[p]=pos,pos+=(r-l+2),f[p]=pos,pos+=(r-l+2);
if(l==r){
g[p][0]=(mod-a[l])%mod,g[p][1]=1;
return ;
}
ll mid = (l+r)/2;
binary_ntt(a,l,mid,2*p),binary_ntt(a,mid+1,r,2*p+1);
poly::times(g[2*p],g[2*p+1],g[p],mid-l+2,r-mid+1,1);
}
inline void binary_solve(ll l,ll r,ll p){
if(l==r){
f[p][0]=d[l];
return ;
}
ll mid = (l+r)/2;
binary_solve(l,mid,2*p),binary_solve(mid+1,r,2*p+1);
poly::times(g[2*p+1],f[2*p],temp,r-mid+1,mid-l+2,1);
for(ll i=0;i<=r-l+1;i++) f[p][i]=(f[p][i]+temp[i])%mod;
poly::times(g[2*p],f[2*p+1],temp,mid-l+2,r-mid+1,1);
for(ll i=0;i<=r-l+1;i++) f[p][i]=(f[p][i]+temp[i])%mod;
}
int main(){
poly::G=3,poly::invG=qmi(poly::G,mod-2,mod),poly::init(1<<18);
n=read();
for(i=1;i<=n;i++) a[i]=read(),b[i]=read();
binary_ntt(a,1,n,1);
for(i=0;i<=n;i++) c[i]=g[1][i];
for(i=1;i<=n;i++) c[i-1]=c[i]*i%mod;
c[n]=0;
poly::get_val(c,a,d,n+1,n);
for(i=1;i<=n;i++) d[i]=qmi(d[i],mod-2,mod)*b[i]%mod;
binary_solve(1,n,1);
for(i=0;i<n;i++) write(f[1][i]),pc(' ');
fwrite(obuf,p3-obuf,1,stdout);
return 0;
}

常系数齐次线性递推

我们要解决的是 $a_i=\sum_{j=1}^k a_{i-j}b_j$,在 $O(k \log k \log n)$ 的时间复杂度内求出 $a_n$ 的值。

我们以 Fib 数列为例,$f_1=1,f_2=1,f_3=2$,因为 $f_i=f_{i-1}+f_{i-2}$,我们把 $f_1$ 和 $f_2$ 看作未知数,那么有:

$$
\begin{aligned}
f_3 &= f_1+f_2 \\
f_4 &= f_3+f_2=f_1+2f_2 \\
f_5 &= f_3+f_4=2f_1+3f_2 \\
\end{aligned}
$$

如果令 $f_1=x^1,f_2=x^2$,那么就有:

$$
\begin{aligned}
f_3 &= x^1+x^2 \\
f_4 &= f_3+f_2=x+2x^2 \\
f_5 &= f_3+f_4=2x^1+3x^2 \\
\end{aligned}
$$

如果我们让项数从 $0$ 开始,那么 $x^2=x^0+x^1$,所以 $f_5=3x^0+5x^1$,容易发现这就是 $x^5$ 对 $x^2-x^0-x^1$ 取模之后的结果。

并且 $x^2-x^0-x^1=0$ 是这个斐波那契数列的特征公式,于是问题转化为了求 $x^n$ 对这个递推式的特征公式取模后的结果,设其为 $A=a_0x^0+a_1x^1+\dots+a_{k-1}x^{k-1}$,因为我们知道 $x^{0 \sim k-1}$ 的值,所以直接带入计算即可。

考虑如何求出取模后的值,我们用倍增快速幂的方法即可,也就是先求出 $x^i \bmod T$ 的值,然后平方一下再取模就是 $x^{2i} \bmod T$ 的值,最后按照倍增快速幂合并出 $x^n \bmod T$ 就可以了。($T$ 是特征多项式)

时间复杂度分析:每次乘法取模次数都是 $O(k)$ 级别的,一共乘+模 $O(\log n)$ 次,所以时间复杂度就是 $O(k \log k \log n)$。

代码实现(P4723 【模板】常系数齐次线性递推):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#include<bits/stdc++.h>
#define ll long long
#define N 2000005
#define mod 998244353
using namespace std;
namespace IO{
inline char nc(){
static char buf[1000000],*p=buf,*q=buf;
return p==q&&(q=(p=buf)+fread(buf,1,1000000,stdin),p==q)?EOF:*p++;
}
inline ll read(){
ll res = 0,w = 1;
char c = nc();
while(c<'0'||c>'9')w=(c=='-'?-1:w),c=nc();
while(c<='9'&&c>='0')res=res*10+c-'0',c=nc();
return res*w;
}
char obuf[1<<21],*p3=obuf;
inline void pc(char c){
p3-obuf<=(1<<20)?(*p3++=c):(fwrite(obuf,p3-obuf,1,stdout),p3=obuf,*p3++=c);
}
inline void write(ll x){
if(x<0) pc('-'),x=-x;
if(x>9) write(x/10);
pc(x%10+'0');
}
}

using namespace IO;

inline ll qmi(ll a,ll b,ll p){
ll res = 1,t = a;
while(b){
if(b&1) res=res*t%p;
t=t*t%p;
b>>=1;
}
return res;
}
namespace poly{
ll n,m,len,res[N],G,invG,invn,nn,qm[N],qm2[N],c[N],aval[N],bval[N],cval[N],dval[N],val[N],val2[N],val3[N],temp[N],tempp[N],tempa[N],tempb[N],temp2[N];
inline void init(ll n){
for(ll i=0;i<n;i++){
res[i]=((res[i>>1])>>1);
if(i&1) res[i]+=(n>>1);
}
for(ll i=0;i<18;i++){
ll w1=qmi(G,(mod-1)/(1<<(i+1)),mod);
ll w2=qmi(invG,(mod-1)/(1<<(i+1)),mod);
for(ll j=0,now=1,now2=1;j<(1<<i);j++,now=now*w1%mod,now2=now2*w2%mod) qm[(1<<i)+j]=now,qm2[(1<<i)+j]=now2;
}
}
inline void ntt(ll *f,ll n,ll type){
static unsigned ll tmp[N];
ll u = __builtin_ctz((1ll<<18)/n);
for(ll i=0;i<n;i++) tmp[i]=f[res[i]>>u];
for(ll i=1;i<n;i<<=1){
for(ll r=(i<<1),j=0;j<n;j+=r){
for(ll k=0;k<i;k++){
ll y = (type==1?qm[i|k]:qm2[i|k])*tmp[j|k|i]%mod;
tmp[j|k|i] = (tmp[j|k]+mod-y),tmp[j|k]+=y;
}
}
}
for(ll i=0;i<n;i++) f[i]=tmp[i]%mod;
}
inline void times(ll *a,ll *b,ll *cp,ll n,ll m,ll tim){
ll nn = n,mm = m,op = 1;
while(op<=n+m) op*=2;
n=op;
for(ll i=0;i<n;i++) tempa[i]=(i>=nn?0:a[i]);
for(ll i=0;i<n;i++) tempb[i]=(i>=mm?0:b[i]);
ntt(tempa,n,1),ntt(tempb,n,1);
for(ll i=0;i<n;i++){
if(tim==1) c[i]=tempa[i]*tempb[i]%mod;
else c[i]=tempa[i]*tempb[i]%mod*tempb[i]%mod;
}
ntt(c,n,-1);
ll invn = qmi(n,mod-2,mod);
for(ll i=0;i<nn+mm;i++) cp[i]=c[i]*invn%mod;
}
inline void solve(ll x){
if(x==1){
val2[0] = qmi(val[0],mod-2,mod);
return ;
}
ll len = (x+1)/2;
solve(len);
for(ll i=len;i<x;i++) val2[i]=0;
times(val,val2,temp2,x,x,2);
for(ll i=0;i<x;i++) val2[i]=((2*val2[i]-temp2[i])%mod+mod)%mod;
}
inline void modd(ll *aa,ll *bb,ll n,ll m,ll *f,ll &lenn,ll type){
lenn=0;
for(ll i=0;i<n;i++) aval[i]=aa[i],cval[i]=aval[i];
len=0;
reverse(aval,aval+n);
if(type==0){
for(ll i=0;i<m;i++) val[i]=bb[i],bval[i]=val[i];
reverse(val,val+m);
solve(m);
}
nn=1;
while(nn<n*2) nn*=2;
times(aval,val2,aval,n,n,1);
for(ll i=n-m;i>=0;i--) val3[len++]=aval[i];
times(bval,val3,dval,m,len,1);
for(ll i=0;i<m-1;i++) f[lenn++]=(cval[i]-dval[i]+mod)%mod;
}
}

ll n,m,a[N],b[N],i,c[N],k,p1,p2,g[N],g2[N],ans,temp,opt;

int main(){
poly::G=3,poly::invG=qmi(poly::G,mod-2,mod),poly::init(1<<18);
n=read(),k=read();
for(i=0;i<k;i++) b[i]=read(),b[i]=(b[i]%mod+mod)%mod;
reverse(b,b+k);
c[k]=1;
for(i=0;i<k;i++) c[i]=(mod-b[i])%mod;
for(i=0;i<k;i++) a[i]=read(),a[i]=(a[i]%mod+mod)%mod;
g[0]=1,g2[1]=1;
while(n){
if(n&1){
poly::times(g,g2,g,k,k,1);
poly::modd(g,c,2*k-1,k+1,g,temp,opt),opt=1;
}
poly::times(g2,g2,g2,k,k,1),poly::modd(g2,c,2*k-1,k+1,g2,temp,opt),opt=1;
n>>=1;
}
for(i=0;i<k;i++) ans=(ans+g[i]*a[i])%mod;
write(ans);
fwrite(obuf,p3-obuf,1,stdout);
return 0;
}