这里是板子，虽然看懂了原理，但是代码还是好难理解哦。
/*
 * In-place recursive radix-2 FFT over a strided view of `buffer`.
 *
 * The view is the n elements buffer[offset + i*step], i = 0..n-1. The
 * even-indexed and odd-indexed elements of the view are transformed
 * recursively (with the stride doubled) and then merged by the butterfly
 * step below.
 *
 *   n        number of elements in the view; the halving recursion only
 *            covers every element when n is a power of two
 *   buffer   data array; the view is overwritten with its transform
 *   offset   index of the first element of the view
 *   step     distance between consecutive elements of the view
 *   epsilon  twiddle table; entry k*step is used as the k-th twiddle factor
 *            at this level (presumably one of the tables filled by
 *            init_epsilon - confirm against the caller)
 *
 * Uses the file-level scratch array `temp` (declared outside this block),
 * which must be able to hold at least n elements.
 */
void fft(int n,complex<double>*buffer,int offset,int step,complex<double>* epsilon)
{
    // Nothing to do for a single element. Using <= also stops the recursion
    // for n == 0 or negative n, which an `n == 1` test would never reach.
    if(n<=1) return;

    const int half=n>>1;
    const int stride=step<<1;

    fft(half,buffer,offset,stride,epsilon);        // even-indexed elements
    fft(half,buffer,offset+step,stride,epsilon);   // odd-indexed elements

    for(int k=0;k!=half;++k)
    {
        const int even_pos=offset+stride*k;
        const complex<double> even=buffer[even_pos];
        // Twiddled odd term, computed once and used for both outputs.
        const complex<double> odd=epsilon[k*step]*buffer[even_pos+step];
        temp[k]=even+odd;
        temp[k+half]=even-odd;
    }

    // Copy the merged result back into the strided view.
    for(int i=0;i!=n;++i)
        buffer[offset+i*step]=temp[i];
}
/*
 * Fills the file-level tables `epsilon` and `arti_epsilon` (declared outside
 * this block) for indices 0..n-1:
 *   epsilon[i]      = (cos(2*pi*i/n), sin(2*pi*i/n))
 *   arti_epsilon[i] = conj(epsilon[i])
 */
void init_epsilon(int n)
{
    const double kPi=acos(-1);
    int idx=0;
    while(idx!=n)
    {
        const double angle=2.0*kPi*idx/n;
        epsilon[idx]=complex<double>(cos(angle),sin(angle));
        // The conjugate table mirrors the forward one entry by entry.
        arti_epsilon[idx]=conj(epsilon[idx]);
        ++idx;
    }
}
/*
 * Returns x after toggling bits from the mask 1<<bit_length downward.
 * Each step XORs the current mask into x; the walk stops as soon as the
 * result is not smaller than the mask, otherwise it moves to the next lower
 * bit. This is the same carry walk that bit_reverse uses to advance its
 * bit-reversed counter.
 *
 * `bit_length` is a file-level value declared outside this block.
 * NOTE(review): bit_reverse starts its mask at n>>1; whether 1<<bit_length
 * is the matching top bit depends on how bit_length is defined - confirm.
 */
int reverse_add(int x)
{
    int mask=1<<bit_length;
    while(true)
    {
        x^=mask;
        if(x>=mask) break;   // no further carry
        mask>>=1;            // carry into the next lower bit
    }
    return x;
}
/*
 * Reorders x[0..n-1] into bit-reversed index order, in place.
 * At this point n has already been padded up to a power of two.
 *
 * `rev` tracks the bit-reversed counterpart of i; each pair is swapped once
 * (only when i > rev), and `rev` is then advanced by a bit-reversed
 * increment: toggle bits from n>>1 downward until a toggle sets a bit.
 */
void bit_reverse(int n, complex_t *x)
{
    int rev=0;
    for(int i=0;i!=n;++i)
    {
        if(rev<i) swap(x[i],x[rev]);

        int mask=n>>1;
        while(true)
        {
            rev^=mask;
            if(rev>=mask) break;   // bit was clear and is now set: done
            mask>>=1;              // bit was set and is now clear: carry on
        }
    }
}
void transform(int n,complex_t *x,complex_t *w)
{
bit_reverse(n, x);
for(int i=2;i<=n;i<<=1)
{
int m=i>>1;
for(int j=0;j<n;j+=i)
{
for(int k=0;k!=m;++k)
{
complex_t z=x[j+m+k]*w[n/i*k];
x[j+m+k]=x[j+k]-z;
x[j+k]+=z;
}
}
}
}
这个是带注释的版本：
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<vector>
using namespace std;
// Capacity of the global coefficient arrays. A typed constant is used instead
// of a single-letter macro so the name obeys normal scoping rules and cannot
// silently rewrite an unrelated identifier called N.
const int N=301000;
// pi, obtained as acos(-1).
const double pi=std::acos(-1.0);
// Minimal complex number: the arithmetic operators in this file treat x as
// the real part and y as the imaginary part.
struct node
{
    double x;
    double y;
    // Default-constructs to 0 + 0i.
    node():x(0),y(0){}
    node(double re,double im):x(re),y(im){}
};
// Global working arrays of N elements each; main reads the two input
// polynomials into the .x fields of a and b.
node a[N];
node b[N];
// Complex addition: adds real and imaginary parts separately.
node operator + (node x,node y)
{
    const double re=x.x+y.x;
    const double im=x.y+y.y;
    return node(re,im);
}
// Complex subtraction: subtracts real and imaginary parts separately.
node operator - (node x,node y)
{
    const double re=x.x-y.x;
    const double im=x.y-y.y;
    return node(re,im);
}
// Complex multiplication: (x.x + i*x.y) * (y.x + i*y.y).
node operator * (node x,node y)
{
    const double re=x.x*y.x-x.y*y.y;
    const double im=x.x*y.y+x.y*y.x;
    return node(re,im);
}
/*
 * Recursive radix-2 FFT of s[0..n-1], in place.
 *
 *   s  data, overwritten with the transform
 *   n  number of elements; the even/odd split only covers every element
 *      when n is a power of two
 *   t  sign applied to the sine term of the twiddle factor; main passes 1
 *      for the forward transform and -1 for the inverse
 *
 * No normalisation is done here: after the t = -1 pass, main divides each
 * value by the transform length itself.
 */
void fft(node *s,int n,int t)
{
    // Single element: already transformed. Using <= also stops the recursion
    // for n == 0 or negative n.
    if (n<=1) return;

    const int half=n>>1;

    // Heap-backed scratch buffers. The previous variable-length arrays were a
    // compiler extension and put O(n) data on the stack at every level.
    std::vector<node> even(half),odd(half);

    // Split into even- and odd-indexed samples. The bound is `half`, so only
    // s[0..n-1] is read and only even[0..half-1] / odd[0..half-1] are
    // written (the old `i<=n` loop went one pair past the end).
    for (int i=0;i<half;i++)
    {
        even[i]=s[2*i];
        odd[i]=s[2*i+1];
    }

    fft(&even[0],half,t);
    fft(&odd[0],half,t);

    // wn is the primitive n-th root of unity (conjugated when t == -1);
    // w walks through its successive powers.
    const node wn(cos(2*pi/n),t*sin(2*pi/n));
    node w(1,0);
    for (int i=0;i<half;i++)
    {
        const node z=w*odd[i];
        s[i]=even[i]+z;
        s[i+half]=even[i]-z;
        w=w*wn;
    }
}
int main()
{
int n,m,fn,i;
scanf("%d%d",&n,&m);
for (i=0;i<=n;i++) scanf("%lf",&a[i].x);
for (i=0;i<=m;i++) scanf("%lf",&b[i].x);
fn=1;while (fn<=n+m) fn<<=1;
fft(a,fn,1);fft(b,fn,1);
for (i=0;i<=fn;i++) a[i]=a[i]*b[i];
fft(a,fn,-1);
for (i=0;i<=n+m;i++) printf("%d ",(int)(a[i].x/fn+0.5));
printf("\n");
return 0;
}
搜索
复制