プログラムdeタマゴ

nodamushiの著作物は、文章、画像、プログラムにかかわらず全てUnlicenseです

高速フーリエ変換(FFT)の解説。実装編

前回のソース2を再帰呼び出しからループに変換してみる。例えば前回のように2097152個のデータを突っ込むと21階層もの再帰呼び出しになる。したがって、これをループに変換するってのは重要なことだ。





まず、「何回のループに展開されるのか?」を考えよう。「これは要するに何階層の再帰呼び出しになるのか?」と同意で、2nのnである。2097152は2の21乗なので21回だ。下に示すソースコードでは変数nbitがそれに対応する。


 次に、「それぞれの呼び出されたメソッド(_fft)をどう展開すればいいのか?」を考える。この事で大切なのは配列のどこの場所にアクセス(読み込み書き込み)するかである。前回作ったプログラムを見ていただければわかると思うが、それぞれの_fftメソッドは他の処理に干渉しない。今、自分がどこを処理しているのかが重要なのだ。t回目の呼び出し(ループ)での各_fftが処理するデータの個数は2^{n-t}個になる。そしてデータは配列上に連続に並んでいるので、処理する配列の先頭番地がわかれば良い。(前回のプログラムでこれらを示しているのは、for(int m=0;m<nh;m++)の部分である。) 先頭番地の名前をindexとすると、呼び出し(ループ)がt回目のk番目のメソッドのindexはk\times2^{n-t}である。
 ところで、kの範囲はいくらだろうか?すなわち、t回目の再帰呼び出し時には何個のメソッドが呼び出されているだろうか?各kについて処理をするにはこの範囲もわかっていないといけない。前回作ったプログラムでは、一回の_fftメソッドは展開のために二回_fftメソッドを呼び出した。従って、1回呼び出す毎に2倍に増えていくので、2^t個になる。これをfsizeと名付けておく。



 準備は以上である。再帰呼び出しをループに展開する準備ができたので、実際にソースに変換する。



ソース3

public static void fft(double[] real,double[] imaginary){
    int N = real.length;
    double theta = -2*Math.PI/N;//(12)
    //n=2^nbitとなるnbitを求めています。
    //これはすなわちループ回数になります。
    int nbit =31- Integer.numberOfLeadingZeros(N);
    for(int t = 0,n=N,nh = N>>1;t <nbit;t++){
        int fsize = 1 << t;//2^tを計算しています
        for(int k = 0;k < fsize;k++){
            int index = k * n;//※nの意味が上の文章と違っています
            for(int m = 0;m < nh;m++){
                double
                    r,i,
                    Rm = real[m+index]-(r=real[m+nh+index]),
                    Im = imaginary[m+index]-(i=imaginary[m+nh+index]);
                real[m+index] +=r;
                imaginary[m+index] += i;
                double cos = Math.cos(m*theta);
                double sin = Math.sin(m*theta);
                real[m+nh+index] = Rm*cos-Im*sin;
                imaginary[m+nh+index] = Rm*sin+Im*cos;
            }
        }
        n =nh;
        nh =(nh >>1);
        theta *=2;
    }

    //並べ替え
    for (int j = 0,i=0,nh = N>>1,nh1 =nh+1; j < nh; j += 2) {
        double tmpr,tmpi;
        if (j < i) {
            tmpr = real[j];
            real[j] = real[i];
            real[i] = tmpr;
            tmpr = real[j + nh1];
            real[j + nh1] = real[i + nh1];
            real[i + nh1] = tmpr;
            tmpi = imaginary[j];
            imaginary[j] = imaginary[i];
            imaginary[i] = tmpi;
            tmpi = imaginary[j + nh1];
            imaginary[j + nh1] = imaginary[i + nh1];
            imaginary[i + nh1] = tmpi;
        }
        tmpr = real[j + nh];
        real[j + nh] = real[i + 1];
        real[i + 1] = tmpr;
        tmpi = imaginary[j + nh];
        imaginary[j + nh] = imaginary[i + 1];
        imaginary[i + 1] = tmpi;
        for (int k = nh >> 1; k > (i ^= k); k >>= 1);
    }
}

なお、これ以降並び替えの処理は常に同じなので記載を省略する




これでも良い感じなのだが、よく見て貰うと、

for(int k=0;k<fsize;k++){
    ………
    for(int m=0;m<nh;m++){
        ………
        //cosとsinはkの値に無関係な数
        double cos = Math.cos(m*theta);
        double sin = Math.sin(m*theta);

 cosおよびsinの同じ引数での呼び出しがkに無関係でありながら何度も計算していることが分かると思う。
 今時のPCが速いと言っても、やっぱり三角関数計算はテイラー展開しているので重い処理だ。できることならこれの呼び出し回数を減らしてやるべきだ。



 そのためにループの順序を入れ替える。ついでに、ループの条件判定を変えて必要な変数の削除も行った。ループの条件が変わっているのでがらっと変わったように見えるかもしれないが、やってることは大して変わっていない。



ソース4

public static void fft(double[] real,double[] imaginary){
    int N = real.length;
    double theta = -2*Math.PI/N;
    for(int n = N,nh;  (nh = n>>1) >=1;  n = nh){
        for(int m=0;m < nh;m++){
            //kのループ外に出たので無駄な呼び出しがなくなった
            double cos = Math.cos(m*theta);
            double sin = Math.sin(m*theta);

            for(int k = m;k < N;k+=n){
                double
                r,i,
                Rm = real[k]-(r=real[k+nh]),
                Im = imaginary[k]-(i=imaginary[k+nh]);
                real[k] +=r;
                imaginary[k] += i;
                real[k+nh] = Rm*cos-Im*sin;
                imaginary[k+nh] = Rm*sin+Im*cos;
            }
        }
        theta *=2;
    }
    //並び替え(略)
}

 このソースで前回と同じように2097152個のデータを突っ込み測定すると1.3秒となった。



 なお、さらに関数の呼び出しを減らす手段として、三角関数の加法定理を利用する事が考えられる。



ソース5

public static void fft(double[] real,double[] imaginary){
    final int N = real.length;
    double theta = -2*Math.PI/N;
    for(int n = N,nh;  (nh = n>>1) >=1;  n = nh){
        double cosTheta = Math.cos(theta);
        double sinTheta = Math.sin(theta);

        double cos = 1;//cos(mθ)はm=0から開始するので初期値は1
        double sin = 0;//同様に初期値は0
        for(int m=0;m < nh;m++){
            for(int k = m;k < N;k+=n){
                double
                r,i,
                Rm = real[k]-(r=real[k+nh]),
                Im = imaginary[k]-(i=imaginary[k+nh]);
                real[k] +=r;
                imaginary[k] += i;
                real[k+nh] = Rm*cos-Im*sin;
                imaginary[k+nh] = Rm*sin+Im*cos;
            }
            //cos,sinを加法定理を用いて更新します。
            double c = cos*cosTheta-sin*sinTheta;//cos(A+B)=cosAcosB-sinAsinB
            double s = sin*cosTheta+cos*sinTheta;//sin(A+B)=sinAcosB+cosAsinB
            cos = c;
            sin = s;
        }
        theta *=2;
    }

    //並べ替え
}

 これで測定したところ1.1秒となり確かに速度改善が見られた。
 だが、そもそもcosやsin関数の結果には計算打ち止め誤差が含まれており、それを用いてcos,sinの値を更新する度誤差が拡大していく事には注意したい。誤差が気になるようなら利用はお勧めしない。

 ここには載せないが、他の関数呼び出しを減らす方法としては、nがNでない場合-90度(マイナスであることに注意)次の角度は
cos(θ-90) = sinθ
sin(θ-90) = -cosθ
であることを利用する事も考えられる。これは誤差が拡大することもない。




 後で並び替えても良いと言うことは、先に並び替えても良いと言うことだ。どうやらそれをバタフライ演算というらしい。
 以下の様な感じになると思われるのだが、測定結果も1.3秒と全く変化が無く、わざわざ順序を反転する理由がなんなのかさっぱり分からない。
 誰か、このバタフライ演算の意味、およびメリットを教えてくれませんか。



ソース6

public static void fft(double[] real,double[] imaginary){
    final int N = real.length;
    double theta = -Math.PI;

    //並べ替え(略)


    for(int nh = 1,n;  (n = nh<<1) <=N;  nh = n){
        for(int m=0;m < nh;m++){
            double cos = Math.cos(m*theta);
            double sin = Math.sin(m*theta);
            for(int k = m;k < N;k+=n){
                double
                r=real[k],
                i=imaginary[k],
                r2 = real[k+nh],
                i2 = imaginary[k+nh],
                R = r2*cos-i2*sin,
                I=r2*sin+i2*cos;

                real[k] += R;
                imaginary[k] += I;
                real[k+nh] = r-R;
                imaginary[k+nh] = i-I;
            }
        }
        theta /=2;
    }
}





最後にFFTを並列処理にすることを考えてみる。
ソース4〜6の様な形は二重ループになっているので効率よく並列処理に持ち込むのは微妙に面倒くさい。


そこでソース2の形に戻ってみる。この場合再帰呼び出し+1重のループと並列処理に持ち込むには最高の形をしている。
CPUで計算する場合、並列度はCPUのコア数と同じにした方が良いから、どこかで並列化を止めた方が良い。そこで、ソース2+ソース4のハイブリッド型を考えてみる。なお、面倒くさかったので下のソースはコア数が2^nであることを前提とした。そうでない場合はそれを超えない最大の、もしくはそれより大きい最小の2^nとかにするといいんじゃね?


ここでは並列化するのにExecutorServiceを利用して並列処理をしたが、Java7以降ならFork/Joinフレームワークを利用した方が良いだろう。




実行結果は0.7秒とかなり速くなった。ちなみに、当方CPUはcore i5です。並列数は4です。まぁ、物理コアは2つなわけだし、2倍近い速度でてるからいいんじゃね?
最初の2秒から3倍近く速くなりました。もう、十分でしょう。


というわけで、かなり長くなったがこれでFFTの解説を終了します。



ソース7

    private static ExecutorService pool;

    public static void fft(double[] real,double[] imaginary){
        int n = real.length;
        int core = Runtime.getRuntime().availableProcessors();
        pool = Executors.newFixedThreadPool(core);
        double theta = -2*Math.PI/n;//(12)
        _fft(n,theta,real,imaginary,0,core);

        //並べ替え (略)

        pool.shutdownNow();
    }

    private static void _fft(
            final int n,final double theta,
            final double[] real,final double[] imaginary,final int arrayIndex,final int thread){
        if(n<=1)return;

        if(thread >1)
            __fft(n, theta, real, imaginary, arrayIndex, thread);
        else{
            _fft(n, theta, real, imaginary, arrayIndex);
        }
    }

    private static void _fft(
            final int N, double theta,
            double[] real,double[] imaginary,int arrayIndex){
        for(int n = N,nh;  (nh = n>>1) >=1;  n = nh){


            for(int m=0;m < nh;m++){
                double cos = Math.cos(m*theta);
                double sin = Math.sin(m*theta);
                for(int k = m;k < N;k+=n){
                    double
                    r,i,
                    Rm = real[k+arrayIndex]-(r=real[k+nh+arrayIndex]),
                    Im = imaginary[k+arrayIndex]-(i=imaginary[k+nh+arrayIndex]);
                    real[k+arrayIndex] +=r;
                    imaginary[k+arrayIndex] += i;
                    real[k+nh+arrayIndex] = Rm*cos-Im*sin;
                    imaginary[k+nh+arrayIndex] = Rm*sin+Im*cos;
                }

            }
            theta *=2;
        }
    }

    private static void __fft(
            final int n,final double theta,
            final double[] real,final double[] imaginary,final int arrayIndex,final int thread){
        final int nh = n/2;
        Future[] fu = new Future[thread-1];
        for(int t =0;t<thread-1;t++){
            final int sm = nh*t/thread,em = (t+1)*nh/thread;
            fu[t] = pool.submit(new Runnable(){

                public void run(){
                    for(int m=sm;m<em;m++){
                        double r1 = real[m+arrayIndex],r2 = real[m+nh+arrayIndex];
                        double i1 = imaginary[m+arrayIndex],i2 = imaginary[m+nh+arrayIndex];
                        real[m+arrayIndex] = r1+r2;
                        imaginary[m+arrayIndex] = i1+i2;

                        double Rm = r1 - r2;
                        double Im =  i1 - i2;
                        double cos = Math.cos(m*theta);
                        double sin = Math.sin(m*theta);
                        real[m+nh+arrayIndex] = Rm*cos-Im*sin;
                        imaginary[m+nh+arrayIndex] = Rm*sin+Im*cos;
                    }
                }
            });
        }
        for(int m=(thread-1)*nh/thread;m<nh;m++){
            double r1 = real[m+arrayIndex],r2 = real[m+nh+arrayIndex];
            double i1 = imaginary[m+arrayIndex],i2 = imaginary[m+nh+arrayIndex];
            real[m+arrayIndex] = r1+r2;
            imaginary[m+arrayIndex] = i1+i2;

            double Rm = r1 - r2;
            double Im =  i1 - i2;
            double cos = Math.cos(m*theta);
            double sin = Math.sin(m*theta);
            real[m+nh+arrayIndex] = Rm*cos-Im*sin;
            imaginary[m+nh+arrayIndex] = Rm*sin+Im*cos;
        }

        for(Future f:fu){
            try {
                f.get();
            } catch (InterruptedException e) {
                e.printStackTrace();
            } catch (ExecutionException e) {
                e.printStackTrace();
            }
        }

        Future<?> f=pool.submit(new Runnable(){
            public void run(){
                _fft(nh,2*theta,real,imaginary,arrayIndex,thread>>1);
            }
        });
        _fft(nh,2*theta,real,imaginary,arrayIndex+nh,thread>>1);
        try {
            f.get();
        } catch (InterruptedException e) {
            e.printStackTrace();
        } catch (ExecutionException e) {
            e.printStackTrace();
        }
    }