FFT什么的

  这里只有公式&作法,没有复杂的证实(实际上是由于弱鸡yww不会)数组

  参考自国家集训队论文&各个博客ide

多项式

​  一个以\(x\)为变量的多项式定义在一个代数域\(F\)上,将函数\(A(x)\)表示为形式和:
\[ A(x)=\sum_{j=0}^{n-1}a_jx^j \]
咱们称\(a_0,a_1,\ldots,a_{n-1}\)为多项式的系数,全部系数都属于数域\(F\),典型的情形是负数集合\(C\)函数

  若是一个多项式的最高次的非零系数是\(a_k\),则称\(A(x)\)的次数是\(k\)。任何严格大于一个多项式次数的整数都是该多项式的次数界。所以,对于次数界为\(n\)的多项式\(C(x)\),其次数能够是\(0\)~\(n-1\)之间的任何整数,包括\(0\)\(n-1\)优化

​  咱们在多项式上能够定义不少不一样的运算。ui

多项式加法

​  若是\(A(x)\)\(B(x)\)是次数界为\(n\)的多项式,那么他们的和也是一个次数界为\(n\)的多项式\(C(x)\)。对于全部属于定义域的\(x\),都有\(C(x)=A(x)+B(x)\)。也就是说,若
\[ A(x)=\sum_{j=0}^{n-1}a_jx^j\\ B(x)=\sum_{j=0}^{n-1}b_jx^j \]

\[ C(x)=\sum_{j=0}^{n-1}c_jx^j\\ \]
其中
\[ c_j=a_j+b_j \]
​  例如,若是
\[ A(x)=6x^3+7x^2-10x+9,B(x)=-2x^3+4x-5 \]

\[ C(x)=4x^3+7x^2-6x+4 \]spa

多项式乘法

​  若是\(A(x)\)是次数界为\(n\)的多项式,\(B(x)\)是次数界为\(m\)的多项式,那么他们的乘积是一个次数界为\(n+m\)的多项式\(C(x)\)。其中
\[ c_j=\sum_{k=0}^ja_kb_{j-k} \]
​  例如,若是
\[ A(x)=6x^3+7x^2-10x+9,B(x)=-2x^3+4x-5 \]
​  则
\[ C(x)=-12x^6-14x^5+44x^4-20x^3-75x^2+86x-45 \].net

多项式的表示

系数表达

​  对一个次数界为\(n\)的多项式\(A(x)=\sum_{j=0}^{n-1}a_jx^j\)而言,其系数表达式一个由系数组成获得向量\(a=(a_0,a_1,\cdots,a_{n-1})\)code

​  咱们能够用秦久韶算法在\(O(n)\)的时间内求出多项式在给定点\(x_0\)的值,即求值运算:
\[ A(x_0)=a_0+x_0(a_1+a_0(a_2+\cdots+x_0(a_{n-1}+x_0(a_{n-1})\cdots)) \]
​  相似的,对于两个分别用系数向量\(a=(a_0,a_1,\cdots,a_{n-1}),b=(b_0,b_1,\cdots,b_{n-1})\)表示的多项式进行相加时,所需的时间是\(O(n)\)。咱们只用输出系数向量\(c=(c_0,c_1,\cdots,c_{n-1})\),其中\(c_i=a_i+b_i\)blog

​  如今来考虑两个用系数形式表达的次数界为\(n\)的多项式\(A(x),B(x)\)的乘法运算,所须要的时间是\(O(n^2)\)。系数向量\(c\)也称为输入向量\(a,b\)的卷积。\(c=a\otimes b\)

点值表达

​  一个次数界为\(n\)的多项式的点值表达就是一个有\(n\)个点值对所组成的集合。
\[ \{(x_0,y_0),(x_1,y_1),\cdots,(x_{n-1},y_{n-1})\} \]
使得对\(k=0,1,\cdots,n-1\),全部\(x_k\)各不相同且\(y_k=A(x_k)\)

​  一个多项式能够有不少不一样的点值表达,由于能够采用\(n\)个不一样的点构成的集合做为这种表示方法的基。

​  朴素的求值是\(O(n^2)\)的。

​  求值的逆称为插值。当插值多项式的次数界等于已知的点值对的数目时,插值才是明确的。

​  咱们能够在用高斯消元在\(O(n^3)\)内插值,也能够用拉格朗日插值\(O(n^2)\)内插值。

​  以上求值和插值能够将多项式的系数表达和点值表达进行相互转化,上面给出的算法的时间复杂度是\(O(n^2)\),但咱们能够巧妙地选取\(x_k\)来加速这一过程,使其运行时间变为\(O(nlogn)\)

​  对于许多多项式相关的操做,点值表达式很便利的。

​  对于加法,若是\(C(x)=A(x)+B(x)\)。给定\(A\)的点值表达
\[ \{(x_0,y_0),(x_1,y_1),\cdots,(x_{n-1},y_{n-1})\} \]
\(B\)的点值表达
\[ \{(x_0,y'_0),(x_1,y'_1),\cdots,(x_{n-1},y'_{n-1})\} \]
(注意,\(A\)\(B\)在相同的\(n\)个位置求值),则\(C\)的点值表达是
\[ \{(x_0,y_0+y'_0),(x_1,y_1+y'_1),\cdots,(x_{n-1},y_{n-1}+y'_{n-1})\} \]
所以,对两个点值形式表示的次数界为\(n\)的多项式相加,时间复杂度是\(O(n)\)

​  相似的,若是\(C(x)=A(x)B(x)\),咱们须要\(2n\)个点值对才能插出\(C\)。给定\(A\)的点值表达
\[ \{(x_0,y_0),(x_1,y_1),\cdots,(x_{2n-1},y_{2n-1})\} \]
\(B\)的点值表达
\[ \{(x_0,y'_0),(x_1,y'_1),\cdots,(x_{2n-1},y'_{2n-1})\} \]
(注意,\(A\)\(B\)在相同的\(2n\)个位置求值),则\(C\)的点值表达是
\[ \{(x_0,y_0y'_0),(x_1,y_1y'_1),\cdots,(x_{2n-1},y_{2n-1}y'_{2n-1})\} \]
所以,对两个点值形式表示的次数界为\(n\)的多项式相乘,时间复杂度是\(O(n)\)

​  最后,咱们考虑一个采用点值表达的多项式,如何求其在某个新点上的值。最简单的方法是把该多项式转成系数形式表达,而后在新点处求值。

系数形式表示的多项式的快速乘法

​  若是咱们选\(n\)次单位复数根做为求值点,咱们能够在\(O(nlogn)\)内求值和插值。咱们先在对这两个多项式\(A,B\)求值以前添加\(n\)\(0\),使其次数界加倍为\(2n\)。如今咱们采用“\(2n\)次单位复数根”做为求值点。

DFT&FFT&IDFT

单位复数根

​  \(n\)次单位复数根是知足\(w^n=1\)的复数\(w\)\(n\)次单位复数根刚好有\(n\)个,对于\(k=0,1,\cdots,n-1\),这些根是\(e^{\frac{2\pi ik}{n}}\)\(w_n=e^\frac{2\pi i}{n}\)称为主\(n\)次单位根,全部其余\(n\)次单位复数根都是\(w_n\)的幂次。这\(n\)\(n\)次单位复数根在乘法意义下造成了一个群,即\(w_n^jw_n^k=w_n^{(j+k)mod~n}\),并且这\(n\)\(n\)次单位复数根均匀分布在以复平面的原点为圆心的单位半径的圆周上。(图片from zjt)

  

​  消去引理:对任何整数\(n\geq 0,k\geq 0,d>0\)
\[ w_{dn}^{dk}=w_n^k \]

DFT

​  回顾一下,咱们但愿计算次数界为\(n\)的多项式\(A(x)\)\(w_n^0,w_n^1,\cdots,w_n^{n-1}\)处的值(即在\(n\)\(n\)次单位复数根处)。对于\(k=0,1,\cdots,n-1\),定义结果\(y_k\)
\[ y_k=A(w_n^k)=\sum_{j=0}^{n-1}a_jw_n^{kj} \]
向量\(y=(y_0,y_1,\cdots,y_{n-1})\)就是系数向量\(a\)的离散傅里叶变换(DFT),咱们也记为\(y=DFT_n(a)\)

FFT

​  利用单位复数根的特殊性质,咱们能够在\(O(nlogn)\)内计算出\(DFT_n(a)\)。这里假设\(n\)\(2\)的幂。

  FFT利用了分治策略。

  咱们令\(a=(a_0,a_1,\cdots,a_{n-1}),a_1=(a_0,a_2,\cdots,a_{n-2}),a_2=(a_1,a_3,\cdots,a_{n-1})\)

  对于\(k<\frac n2\)有:
\[ \begin{align} y_k&=A(w_n^k)\\ &=\sum_{j=0}^{n-1}a_jw_n^{kj}\\ &=\sum_{j=0}^{\frac n2-1}a_{2j}w_n^{2kj}+\sum_{j=0}^{\frac n2-1}a_{2j+1}w_n^{2kj+k}\\ &=\sum_{j=0}^{\frac n2-1}a_{2j}w_n^{2kj}+w_n^k\sum_{j=0}^{\frac n2-1}a_{2j+1}w_n^{2kj}\\ &=\sum_{j=0}^{\frac n2-1}{a_1}_{j}w_{\frac n2}^{kj}+w_n^k\sum_{j=0}^{\frac n2-1}{a_2}_{j}w_{\frac n2}^{kj}\\ &={y_1}_k+w_n^k{y_2}_k \end{align} \]
  对于\(k\geq \frac n2\)有:
\[ \begin{align} y_k&=A(w_n^k)\\ &=\sum_{j=0}^{n-1}a_jw_n^{kj}\\ &=\sum_{j=0}^{\frac n2-1}a_{2j}w_n^{2kj}+\sum_{j=0}^{\frac n2-1}a_{2j+1}w_n^{2kj+k}\\ &=\sum_{j=0}^{\frac n2-1}a_{2j}w_n^{2kj}+w_n^k\sum_{j=0}^{\frac n2-1}a_{2j+1}w_n^{2kj}\\ &=\sum_{j=0}^{\frac n2-1}{a_1}_{j}w_{\frac n2}^{kj}+w_n^k\sum_{j=0}^{\frac n2-1}{a_2}_{j}w_{\frac n2}^{kj}\\ &=\sum_{j=0}^{\frac n2-1}{a_1}_{j}w_{\frac n2}^{(k-\frac n2)j}+w_n^k\sum_{j=0}^{\frac n2-1}{a_2}_{j}w_{\frac n2}^{(k-\frac n2)j}\\ &={y_1}_{k-\frac n2}+w_n^k{y_2}_{k-\frac n2}\\ &={y_1}_{k-\frac n2}-w_n^{k-\frac n2}{y_2}_{k-\frac n2} \end{align} \]
  这样咱们把\(y_1,y_2\)合并为\(y\)的时间复杂度是\(O(n)\)。因此总的时间复杂度是
\[ T(n)=2T(\frac n2)+O(n)=O(n\log n) \]

IDFT

​  经过推导公式,咱们获得:
\[ a_k=\frac1n\sum_{j=0}^{n-1}y_jw_n^{-kj} \]
​  因此咱们能够用相似FFT的方法在\(O(n\log n)\)内求出\(IDFT_n(y)\)

多项式乘法

​  咱们能够在\(O(n)\)内补\(0\)\(O(n\log n)\)内求值,\(O(n)\)内点值乘法,\(O(n\log n)\)内插值。因此咱们能够在\(O(n\log n)\)内求出\(a\otimes b\)
\[ a\otimes b=IDFT_{2n}(DFT_{2n}(a)\cdot DFT_{2n}(b)) \]

蝶形运算

  咱们把由\({y_1}_k,{y_2}_k,w_n^k\)获得\(y_k,y_{k+\frac n2}\)的过程称为蝴蝶操做。

​  咱们发现,递归时\(a\)是长这样的:
\[ 0~~~1~~~2~~~3~~~4~~~5~~~6~~~7\\ 0~~~2~~~4~~~6~|~1~~~3~~~5~~~7\\ 0~~~4~|~2~~~6~|~1~~~5~|~3~~~7\\ 0~|~4~|~2~|~6~|~1~|~5~|~3~|~7 \]
  总的蝶形运算是长这样的:
  
  

​  能够发现,最后\(a_i\)是原来的\(a_{rev(i)}\)。因此咱们能够交换\(a_i,a_{rev(i)}\),而后一层层来作。这样能够减少常数。

NTT

​  在某些时候,咱们须要求模\(p\)意义下的卷积。

​  先求出\(p\)的原根\(g\),能够发现,\(g^{\frac{p-1}{n}}\)\(w_n\)的性质相似。因此咱们能够用\(g^{\frac{p-1}{n}}\)来代替\(w_n\)

时间上的优化

  当咱们要算两个多项式 \(A(x), B(x)\) 的乘积的时候,普通的作法是先把 \(a,b\) 两个序列 DFT,再点乘,再 IDFT 回去。

  可是咱们还有一种方法:

​  令\(t_j=(a_j+b_j)+(a_j-b_j)i,S=T\times T\)

​  \(s_j\)的实部为
\[ \begin{align} \sum_{k=0}^j(a_k+b_k)(a_{j-k}+b_{j-k})-(a_k-b_k)(a_{j-k}-b_{j-k})&=\sum_{k=0}^j4a_kb_{j-k}=4\sum_{k=0}^ja_kb_{j-k} \end{align} \]
  这样咱们就能够求出\(S=T\times T\),而后把\(s_j\)除以\(4\)

  这个方法能够把\(3\)次DFT改为\(2\)次DFT。

多项式求导

  给定\(A(x)=\sum_{i\geq 0}a_ix^i\),定义\(A(x)\)的形式导数为
\[ A'(x)=\sum_{i\geq 1}ia_ix^{i-1} \]

多项式积分

  给定\(A(x)=\sum_{i\geq 0}a_ix^i\),则
\[ \int A(x)=\sum_{i\geq 1}\frac{a_{i-1}}{i}x^i \]

多项式求逆

​  多项式\(A(x)\)存在乘法逆元的充要条件是\(A(x)\)的常数项存在乘法逆元。

​  下面介绍一个\(O(n~log~n)\)计算乘法逆元的算法,它的本质是牛顿迭代法

​  首先求出\(A(x)\)常数项的逆元\(b\),令\(B(x)\)的初始值为\(b\)

​  假设已求出知足
\[ A(x)B(x)\equiv1~(mod~x^n) \]
\(B(x)\),则
\[ \begin{align} A(x)B(x)-1&\equiv0~(mod~x^n)\\ {(A(x)B(x)-1)}^2&\equiv 0~(mod~x^{2n})\\ A(x)(2B(x)-B(x)^2A(x))&\equiv 1~(mod~x^{2n}) \end{align} \]
​  咱们能够用\(O(n~log~n)\)的时间计算出\(2B(x)-B(x)^2A(x)\),并将它赋值给\(B(x)\)进行下一次迭代。每迭代一次,\(B(x)\)的有效项数\(n\)都会增长一倍。因而该算法的时间复杂度为
\[ T(n)=T(n/2)+O(n\log n)=O(n\log n) \]

多项式开根

  已知\(A(x)\),求\(B(x)\)使得
\[ B(x)^2\equiv A(x)~(mod~x^n) \]

  先求出\(A(x)\)常数项的平方根\(b\)(能够用二次剩余的东西来算,但我只会暴力算),令\(B(x)\)的初始值为\(b\)

  假设已求出知足
\[ B(x)^2\equiv A(x)~(mod~x^n) \]
\(B(x)\),则
\[ \begin{align} B(x)^2-A(x)&\equiv 0~(mod~x^n)\\ {(B(x)^2-A(x))}^2&\equiv 0~(mod~x^{2n})\\ B(x)^4-2B(x)^2A(x)+A(x)^2&\equiv 0~(mod~x^{2n})\\ B(x)^4+2B(x)^2A(x)+A(x)^2&\equiv 4B(x)^2A(x)~(mod~x^{2n})\\ {(B(x)^2+A(x))}^2&\equiv {(2B(x))}^2A(x)~(mod~x^{2n})\\ {(\frac{B(x)^2+A(x)}{2B(x)})}^2&\equiv A(x)~(mod~x^{2n}) \end{align} \]
  咱们能够在\(O(n\log n)\)内算出\(\frac{B(x)^2+A(x)}{2B(x)}=\frac{B(x)}{2}+\frac{A(x)}{2B(x)}\),并把它赋值给\(B(x)\)

  时间复杂度:\(O(n\log n)\)

多项式ln

  给定形式幂级数\(A(x)=\sum_{i\geq 1}a_ix^i\),定义
\[ \ln(1-A(x))=-\sum_{i\geq 1}\frac{{A(x)}^i}{i} \]
  给定多项式\(A(x)=1+\sum_{i\geq 1}a_ix^i\),令
\[ B(x)=\ln(A(x)) \]

\[ B'(x)=\frac{A'(x)}{A(x)} \]
  只须要求出\(A(x)\)的乘法逆元,就能够求出\(\ln(A(x))\)

多项式exp

  给定形式幂级数\(A(x)=\sum_{i\geq 1}a_ix^i\),定义
\[ \exp(A(x))=\sum_{i\geq 0}\frac{{A(x)}^i}{i!} \]
  令\(f(x)=e^{A(x)}\),可获得一个关于\(f(x)\)的方程
\[ g(f(x))=\ln(f(x))-A(x)=0 \]
  考虑用牛顿迭代解这一方程。首先\(f(x)\)的常数项是容易肯定的(就是\(1\))。

  设以求得\(f(x)\)的前\(n\)\(f_0(x)\),即
\[ f(x)\equiv f_0(x)~~~(mod~~~x^n) \]
  做泰勒展开得
\[ \begin{align} 0&=g(f(x))\\ &=g(f_0(x))+g'(f_0(x))(f(x)-f_0(x))~~~~~(mod~~~x^{2n}) \end{align} \]

\[ f(x)\equiv f_0(x)-\frac{g(f_0(x))}{g'(f_0(x))}~~~~(mod~~~x^{2n}) \]
  把上面那个式子带入得
\[ \begin{align} f(x)&=f_0(x)-\frac{\ln(f_0(x))-A(x)}{\frac{1}{f_0(x)}}\\ &=f_0(x)(1-\ln(f_0(x))+A(x)) \end{align} \]
  时间复杂度:\(O(n\log n)\)
  

多项式求幂

  给你\(A(x),k\),求\(A^k(x)\)

  设\(A(x)\)中最低次数项是\(cx^d\),那么先把整个多项式除以\(cx^d\),再求\(\ln\),把整个多项式乘以\(k\),再求\(\exp\),再乘上\(c^kx^{kd}\)
\[ A^k(x)=\exp(k\ln\frac{A(x)}{cx^d}))c^kx^{kd} \]
  时间复杂度:\(O(n\log n)\)

多项式除法

​  给你\(A(x),B(x)\),求两个多项式\(D(x),R(x)\)知足
\[ A(x)=D(x)B(x)+R(x) \]
​  若\(A(x)\)是一个\(n\)阶多项式,则
\[ A^R(x)=x^nA(\frac1x) \]
  举个例子:好比说
\[ A(x)=x^3+2x^2+3x+4\\ A^R(x)=1+2x+3x^2+4x^3 \]
​  至关于把\(A(x)\)的系数反转。

  咱们设\(A(x)\)\(n\)阶多项式,\(B(x)\)\(m\)阶多项式,\(D(x)\)\(n-m\)阶多项式,\(R(x)\)\(m-1\)阶多项式。咱们把上个式子的\(x\)\(\frac1x\),而后所有乘上\(x^n\)
\[ x^nA(\frac1x)=x^{n-m}D(\frac1x)x^mB(\frac1x)+x^{n-m+1}x^{m-1}R(\frac1x)\\ A^R(x)=D^R(x)B^R(x)+x^{n-m+1}R^R(x) \]
  而后咱们把这个式子放在模\(x^{n-m+1}\)意义下,获得
\[ A^R(x)=D^R(x)B^R(x)~(mod~x^{n-m+1})\\ D^R(x)=A^R(x){(B^R(x))}^{-1}~(mod~x^{n-m+1}) \]
  由于\(D(x)\)的次数是\(n-m\),因此不会受模意义的影响。

  而后把\(D(x)\)带入到原来的式子中,就能够算出\(R(x)\)了。

  时间复杂度:\(O(n\log n)\)

多点求值

  给你一个多项式\(A(x)\)\(n\)个点\(x_0,x_1,\cdots,x_{n-1}\),求这个多项式在这\(n\)个点处的值,即求\(A(x_0),A(x_1),\cdots,A(x_{n-1})\)

  考虑一个简单的作法:构造\(B_i(x)=x-x_i,C_i(x)=A(x)~mod~B_i(x)\),那么\(B_i(x_i)=0\)。因此\(A(x_i)=C_i(x_i)\)。可是计算\(B_i(x)\)\(C_i(x)\)\(O(n)\)的,必须加速这个过程。

  设当前求值的点为\(X=\{x_0,x_1,\cdots,x_{n-1}\}\),咱们能够把这\(n\)个点分为两半:
\[ X_0=\{x_0,x_1,\cdots,x_{\frac n2-1}\}\\ X_1=\{x_{\frac n2},x_{\frac n2+1},\cdots,x_{n-1}\} \]
  构造多项式
\[ B_0=\prod_{i=0}^{\frac n2-1}(x-x_i)\\ B_1=\prod_{i=\frac n2}^{n-1}(x-x_i)\\ A_0=A~mod~B_0\\ A_1=A~mod~B_1 \]
  那么当\(x\in X_0\)\(A(x)=A_0(x)\),能够递归计算。当\(x\in X_1\)时同理。

  每一层计算\(B_0,B_1,A_0,A_1\)的时间复杂度都是\(O(n\log n)\)

  总的时间复杂度就是
\[ T(n)=2T(\frac n2)+O(n\log n)=O(n\log^2n) \]

快速插值

  考虑怎么求\(g_i=\prod_{j=0,j\neq i}^n (x_i-x_j)\),也就是分母。

\[ \begin{align} g_i&=\prod_{j=0,j\neq i}^n (x_i-x_j)\\ &=\lim_{x \to x_i}\frac{\prod_{j=0}^n (x-x_j)}{x-x_i}\\ &=(\prod_{j=0}^n (x-x_j))'|_{x=x_i} \end{align} \]

  能够分治求出\(\prod_{j=0}^n (x-x_j)\)再求导后在全部\(x_i\)处多点求值。

  分子直接分治求出。

  时间复杂度:\(O(n\log^2n)\)

小技巧1

  好比咱们要计算两个实数序列的卷积\(A\times B=C\),记\(D_i=(a_i+b_i)+(a_i-b_i)i\),那么\(C_i=\frac{1}{4}real({D^2}_i)\)
  
  这样就能够把三次DFT减小到两次DFT。
  
  固然,若是\(A=B\)那么这个优化是没有效果的。

任意模数FFT

模板

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<cmath>
#include<functional>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
void sort(int &a,int &b)
{
    if(a>b)
        swap(a,b);
}
void open(const char *s)
{
#ifndef ONLINE_JUDGE
    char str[100];
    sprintf(str,"%s.in",s);
    freopen(str,"r",stdin);
    sprintf(str,"%s.out",s);
    freopen(str,"w",stdout);
#endif
}
int rd()
{
    int s=0,c;
    while((c=getchar())<'0'||c>'9');
    do
    {
        s=s*10+c-'0';
    }
    while((c=getchar())>='0'&&c<='9');
    return s;
}
int upmin(int &a,int b)
{
    if(b<a)
    {
        a=b;
        return 1;
    }
    return 0;
}
int upmax(int &a,int b)
{
    if(b>a)
    {
        a=b;
        return 1;
    }
    return 0;
}
const ll p=998244353;
const ll g=3;
ll fp(ll a,ll b)
{
    ll s=1;
    while(b)
    {
        if(b&1)
            s=s*a%p;
        a=a*a%p;
        b>>=1;
    }
    return s;
}
const int maxn=600000;
ll inv[maxn];
namespace ntt
{
    ll w1[maxn];
    ll w2[maxn];
    int rev[maxn];
    int n;
    void init(int m)
    {
        n=1;
        while(n<m)
            n<<=1;
        int i;
        for(i=2;i<=n;i<<=1)
        {
            w1[i]=fp(g,(p-1)/i);
            w2[i]=fp(w1[i],p-2);
        }
        rev[0]=0;
        for(i=1;i<n;i++)
            rev[i]=(rev[i>>1]>>1)|((i&1)*(n>>1));
    }
    void ntt(ll *a,int t)
    {
        int i,j,k;
        ll u,v,w,wn;
        for(i=0;i<n;i++)
            if(rev[i]<i)
                swap(a[i],a[rev[i]]);
        for(i=2;i<=n;i<<=1)
        {
            wn=(t==1?w1[i]:w2[i]);
            for(j=0;j<n;j+=i)
            {
                w=1;
                for(k=j;k<j+i/2;k++)
                {
                    u=a[k];
                    v=a[k+i/2]*w%p;
                    a[k]=(u+v)%p;
                    a[k+i/2]=(u-v)%p;
                    w=w*wn%p;
                }
            }
        }
        if(t==-1)
        {
            u=fp(n,p-2);    
            for(i=0;i<n;i++)
                a[i]=a[i]*u%p;
        }
    }
    ll x[maxn];
    ll y[maxn];
    ll z[maxn];
    void copy_clear(ll *a,ll *b,int m)
    {
        int i;
        for(i=0;i<m;i++)
            a[i]=b[i];
        for(i=m;i<n;i++)
            a[i]=0;
    }
    void copy(ll *a,ll *b,int m)
    {
        int i;
        for(i=0;i<m;i++)
            a[i]=b[i];
    }
    void mul(ll *a,ll *b,ll *c,int m)
    {
        init(m<<1);
        copy_clear(x,a,m);
        copy_clear(y,b,m);
        ntt(x,1);
        ntt(y,1);
        int i;
        for(i=0;i<n;i++)
            x[i]=x[i]*y[i]%p;
        ntt(x,-1);
        copy(c,x,m);
    }
    void inverse(ll *a,ll *b,int m)
    {
        if(m==1)
        {
            b[0]=fp(a[0],p-2);
            return;
        }
        inverse(a,b,m>>1);
        init(m<<1);
        copy_clear(x,a,m);
        copy_clear(y,b,m>>1);
        ntt(x,1);
        ntt(y,1);
        int i;
        for(i=0;i<n;i++)
            x[i]=y[i]*(2-x[i]*y[i]%p)%p;
        ntt(x,-1);
        copy(b,x,m);
    }
    ll c[maxn],d[maxn],e[maxn],f[maxn];
    void sqrt(ll *a,ll *b,int m)
    {
        if(m==1)
        {
            if(a[0]==1)
                b[0]=1;
            else if(a[0]==0)
                b[0]=0;
            else
                //我也不会
                ;
            return;
        }
        sqrt(a,b,m>>1);
//      copy_clear(c,b,m>>1);
        int i;
        for(i=m;i<m<<1;i++)
            b[i]=0;
        inverse(b,d,m);
        init(m<<1);
        for(i=m;i<m<<1;i++)
            b[i]=d[i]=0;
        ll inv2=fp(2,p-2);
        copy_clear(x,a,m);
        ntt(x,1);
        ntt(d,1);
        for(i=0;i<n;i++)
            x[i]=x[i]*d[i]%p;
        ntt(x,-1);
        for(i=0;i<m;i++)
            b[i]=((b[i]+x[i])%p*inv2)%p;
    }
    void derivative(ll *a,ll *b,int m)
    {
        int i;
        for(i=0;i<m-1;i++)
            b[i]=(i+1)*a[i+1]%p;
        b[m-1]=0;
    }
    void differential(ll *a,ll *b,int m)
    {
        int i;
        for(i=m-1;i>=1;i--)
            b[i]=a[i-1]*inv[i]%p;
        b[0]=0;
    }
    void ln(ll *a,ll *b,int m)
    {
        static ll c[maxn],d[maxn];
        derivative(a,c,m);
        inverse(a,d,m);
        init(m<<1);
        int i;
        for(i=m;i<n;i++)
            c[i]=d[i]=0;
        ntt(c,1);
        ntt(d,1);
        for(i=0;i<n;i++)
            c[i]=c[i]*d[i]%p;
        ntt(c,-1);
        differential(c,b,m);
    }
    void exp(ll *a,ll *b,int m)
    {
        if(m==1)
        {
            b[0]=1;
            return;
        }
        exp(a,b,m>>1);
        int i;
        for(i=m>>1;i<m;i++)
            b[i]=0;
        ln(b,y,m);
        init(m<<1);
        copy_clear(x,a,m);
        x[0]++;
        for(i=0;i<m;i++)
            x[i]=(x[i]-y[i])%p;
        copy_clear(y,b,m);
        ntt(x,1);
        ntt(y,1);
        for(i=0;i<n;i++)
            x[i]=x[i]*y[i]%p;
        ntt(x,-1);
        copy(b,x,m);
    }
    void module(ll *a,ll *b,ll *c,int n1,int n2)
    {
        int k=1;
        while(k<=n1-n2+1)
            k<<=1;
        int i;
        for(i=0;i<=n1;i++)
            d[i]=a[i];
        for(i=0;i<=n2;i++)
            e[i]=b[i];
        reverse(d,d+n1+1);
        reverse(e,e+n2+1);
        for(i=n1-n2+1;i<k<<1;i++)
            d[i]=e[i]=0;
        inverse(e,f,k);
        for(i=n1-n2+1;i<k<<1;i++)
            f[i]=0;
        init(k<<1);
        ntt::ntt(d,1);
        ntt::ntt(f,1);
        for(i=0;i<n;i++)
            e[i]=d[i]*f[i]%p;
        ntt::ntt(e,-1);
        for(i=0;i<=n1-n2;i++)
            c[i]=e[i];
        reverse(c,c+n1-n2+1);
    }
};
ll b[maxn];
ll a[maxn];
ll c[maxn];
void get(ll *a,int n)
{
    int i;
    for(i=0;i<n;i++)
        a[i]=rand();
}
int main()
{
//  freopen("fft.txt","w",stdout);
//  srand(time(0));
//  int n=262144;
//  int bg,ed;
//  int i;
//  int times=100,j;
//  double s,s1;
//  inv[0]=inv[1]=1;
//  for(i=2;i<=n;i++)
//      inv[i]=-(p/i)*inv[p%i]%p;
//  s=0;
//  for(j=1;j<=times;j++)
//  {
//      get(a,n);
//      bg=clock();
//      ntt::init(n);
//      ntt::ntt(a,1);
//      ed=clock();
//      s+=double(ed-bg)/CLOCKS_PER_SEC;
//  }
//  printf("ntt :%.10lf\n",s/times);
//  s1=s;
//  s=0;
//  for(j=1;j<=times;j++)
//  {
//      get(a,n);
//      get(b,n);
//      bg=clock();
//      ntt::mul(a,b,c,n);
//      ed=clock();
//      s+=double(ed-bg)/CLOCKS_PER_SEC;
//  }
//  printf("mul :%.10lf %.10lf\n",s/times,s/s1);
//  s=0;
//  for(j=1;j<=times;j++)
//  {
//      get(a,n);
//      bg=clock();
//      ntt::inverse(a,b,n);
//      ed=clock();
//      s+=double(ed-bg)/CLOCKS_PER_SEC;
//  }
//  printf("inv :%.10lf %.10lf\n",s/times,s/s1);
//  s=0;
//  for(j=1;j<=times;j++)
//  {
//      get(a,n);
//      a[0]=1;
//      bg=clock();
//      ntt::sqrt(a,b,n);
//      ed=clock();
//      s+=double(ed-bg)/CLOCKS_PER_SEC;
//  }
//  printf("sqrt:%.10lf %.10lf\n",s/times,s/s1);
//  s=0;
//  for(j=1;j<=times;j++)
//  {
//      get(a,n);
//      a[0]=1;
//      bg=clock();
//      ntt::ln(a,b,n);
//      ed=clock();
//      s+=double(ed-bg)/CLOCKS_PER_SEC;
//  }
//  printf("ln  :%.10lf %.10lf\n",s/times,s/s1);
//  s=0;
//  for(j=1;j<=times;j++)
//  {
//      get(a,n);
//      bg=clock();
//      ntt::exp(a,b,n);
//      ed=clock();
//      s+=double(ed-bg)/CLOCKS_PER_SEC;
//  }
//  printf("exp :%.10lf %.10lf\n",s/times,s/s1);
//  return 0;
}

多点求值+快速插值

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const ll p=998244353;
const ll g=3;
const int maxw=131072;
const int maxn=150000;
ll fp(ll a,ll b)
{
    ll s=1;
    for(;b;b>>=1,a=a*a%p)
        if(b&1)
            s=s*a%p;
    return s;
}
int rt,cnt,ls[1000010],rs[1000010];
ll vx[100010],vy[100010],va[100010];
ll inv[maxn],w1[maxn],w2[maxn];
int rev[maxn];
void init()
{
    inv[0]=inv[1]=1;
    for(int i=2;i<=maxw;i++)
        inv[i]=-p/i*inv[p%i]%p;
    for(int i=2;i<=maxw;i<<=1)
    {
        w1[i]=fp(g,(p-1)/i);
        w2[i]=fp(w1[i],p-2);
    }
}
ll *f[1000010];
int len[maxn];
void clear(ll *a,int n)
{
    memset(a,0,(sizeof a[0])*n);
}
void ntt(ll *a,int n,int t)
{
    for(int i=1;i<n;i++)
    {
        rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
        if(i>rev[i])
            swap(a[i],a[rev[i]]);
    }
    for(int i=2;i<=n;i<<=1)
    {
        ll wn=(t==1?w1[i]:w2[i]);
        for(int j=0;j<n;j+=i)
        {
            ll w=1;
            for(int k=j;k<j+i/2;k++)
            {
                ll u=a[k];
                ll v=a[k+i/2]*w%p;
                a[k]=(u+v)%p;
                a[k+i/2]=(u-v)%p;
                w=w*wn%p;
            }
        }
    }
    if(t==-1)
    {
        ll inv=fp(n,p-2);
        for(int i=0;i<n;i++)
            a[i]=a[i]*inv%p;
    }
}
void mul(ll *a,ll *b,ll *c,int n,int m)
{
    int k=1;
    while(k<=n+m)
        k<<=1;
    static ll a1[maxn],a2[maxn];
    clear(a1,k);
    clear(a2,k);
    for(int i=0;i<=n;i++)
        a1[i]=a[i];
    for(int i=0;i<=m;i++)
        a2[i]=b[i];
    ntt(a1,k,1);
    ntt(a2,k,1);
    for(int i=0;i<k;i++)
        a1[i]=a1[i]*a2[i]%p;
    ntt(a1,k,-1);
    for(int i=0;i<=n+m;i++)
        c[i]=a1[i];
}
void getinv(ll *a,ll *b,int n)
{
    if(n==1)
    {
        b[0]=fp(a[0],p-2);
        return;
    }
    getinv(a,b,n>>1);
    static ll a1[maxn],a2[maxn];
    clear(a1,n<<1);
    clear(a2,n<<1);
    for(int i=0;i<n;i++)
        a1[i]=a[i];
    for(int i=0;i<n>>1;i++)
        a2[i]=b[i];
    ntt(a1,n<<1,1);
    ntt(a2,n<<1,1);
    for(int i=0;i<n<<1;i++)
        a1[i]=a2[i]*(2-a2[i]*a1[i]%p)%p;
    ntt(a1,n<<1,-1);
    for(int i=0;i<n;i++)
        b[i]=a1[i];
}
void div(ll *a,ll *b,ll *c,int n,int m)
{
    static ll a1[maxn],a2[maxn],a3[maxn];
    int k=1;
    while(k<=2*(n-m))
        k<<=1;
    for(int i=0;i<=n;i++)
        a1[i]=a[i];
    for(int i=0;i<=m;i++)
        a2[i]=b[i];
    reverse(a1,a1+n+1);
    reverse(a2,a2+m+1);
    clear(a1+n-m+1,k-(n-m+1));
    clear(a2+n-m+1,k-(n-m+1));
    getinv(a2,a3,k);
    clear(a3+n-m+1,k-(n-m+1));
    ntt(a1,k,1);
    ntt(a3,k,1);
    for(int i=0;i<k;i++)
        a1[i]=a1[i]*a3[i]%p;
    ntt(a1,k,-1);
    for(int i=0;i<=n-m;i++)
        c[i]=a1[i];
    reverse(c,c+n-m+1);
}
void getmod(ll *a,ll *b,ll *c,int n,int m)
{
    static ll a1[maxn],a2[maxn];
    int k=1;
    while(k<=n)
        k<<=1;
    clear(a1,k);
    clear(a2,k);
    for(int i=0;i<=m;i++)
        a1[i]=b[i];
    div(a,b,a2,n,m);
    ntt(a1,k,1);
    ntt(a2,k,1);
    for(int i=0;i<k;i++)
        a1[i]=a1[i]*a2[i]%p;
    ntt(a1,k,-1);
    for(int i=0;i<m;i++)
        c[i]=(a[i]-a1[i])%p;
}
void divide(int l,int r,int &now)
{
    now=++cnt;
    len[now]=r-l+1;
    f[now]=new ll[len[now]+1];
    if(l==r)
    {
        f[now][1]=1;
        f[now][0]=-vx[l];
        return;
    }
    int mid=(l+r)>>1;
    divide(l,mid,ls[now]);
    divide(mid+1,r,rs[now]);
    mul(f[ls[now]],f[rs[now]],f[now],len[ls[now]],len[rs[now]]);
}
void getv(ll *a,int n,int l,int r,int now)
{
    ll *a1=new ll[len[now]];
    getmod(a,f[now],a1,n,len[now]);
    if(l==r)
    {
        va[l]=a1[0];
        return;
    }
    int mid=(l+r)>>1;
    getv(a1,len[now]-1,l,mid,ls[now]);
    getv(a1,len[now]-1,mid+1,r,rs[now]);
}
ll *s[1000010];
void getpoly(int l,int r,int now)
{
    s[now]=new ll[len[now]];
    if(l==r)
    {
        s[now][0]=va[l];
        return;
    }
    int mid=(l+r)>>1;
    getpoly(l,mid,ls[now]);
    getpoly(mid+1,r,rs[now]);
    int k=1;
    while(k<=len[now])
        k<<=1;
    static ll a1[maxn],a2[maxn],a3[maxn],a4[maxn];
    clear(a1,k);
    clear(a2,k);
    clear(a3,k);
    clear(a4,k);
    for(int i=0;i<len[ls[now]];i++)
        a1[i]=s[ls[now]][i];
    for(int i=0;i<=len[rs[now]];i++)
        a2[i]=f[rs[now]][i];
    for(int i=0;i<len[rs[now]];i++)
        a3[i]=s[rs[now]][i];
    for(int i=0;i<=len[ls[now]];i++)
        a4[i]=f[ls[now]][i];
    ntt(a1,k,1);
    ntt(a2,k,1);
    ntt(a3,k,1);
    ntt(a4,k,1);
    for(int i=0;i<k;i++)
        a1[i]=(a1[i]*a2[i]+a3[i]*a4[i])%p;
    ntt(a1,k,-1);
    for(int i=0;i<len[now];i++)
        s[now][i]=a1[i];
}
int n;
ll a[maxn],b[maxn],c[maxn];
int main()
{
    init();
    scanf("%d",&n);
    for(int i=0;i<=n;i++)
        scanf("%lld%lld",&vx[i],&vy[i]);
    divide(0,n,rt);
    for(int i=0;i<=n;i++)
        a[i]=f[rt][i+1]*(i+1)%p;
    getv(a,n,0,n,rt);
//  for(int i=0;i<=n;i++)
//      printf("%lld ",(va[i]+p)%p);
//  printf("\n");
    for(int i=0;i<=n;i++)
        va[i]=fp(va[i],p-2)*vy[i]%p;
    getpoly(0,n,rt);
    for(int i=0;i<=n;i++)
        printf("%lld ",(s[rt][i]+p)%p);
    printf("\n");
    return 0;
}
相关文章
相关标签/搜索