vectorを使ったData.List.sortより4倍速いsortアルゴリズムの実装

Data.List.sortがあまりに遅くてつらいので、vectorを使って書いてチューニングしたら約4倍速くなりましたという話をします。その過程でvectorのmonadic indexingとは何かという話をします。

仕様としてData.Listと互換性を持たせるため、次のようなインターフェースにしました。

{-# INLINE sort #-}
sort :: Ord a => [a] -> [a]
sort = sortBy compare
sortBy :: (a -> a -> Ordering) -> [a] -> [a]
sortBy = ...

次に実装を見てきます。

import qualified Data.Vector as V
...
sortBy cmp = V.toList . mergeSortAux . V.fromList
    where 
    mergeSortAux l = ...

入力はリストで与えられるので、まず最初にfromList :: [a] -> Vector aを使ってData.Vector.Vectorに変換し、mergeSortAux :: Vector a -> Vector aを使ってソートした後に、toList :: Vector a -> [a]を用いてリストに戻します。Unboxed Vectorを使うとさらに速くなりますが、ソートできるデータが制限されてしまうので今回はBoxed Vectorを使います。

mergeSortAuxの実装はただのマージソートです。今後ところどころunsafeなんとかを使っていますが、これは今回のアルゴリズムでは境界値外アクセスすることはないので配列の境界チェックを省くために行っているだけです。怖くないよ。

import qualified Data.Vector as V
...
sortBy cmp = ...
    where 
    mergeSortAux l
        | n <= 1 = l
        | otherwise = merge (n1, l1') (n2, l2')
             where
                n  = V.length l
                n1 = div n 2
                n2 = n - n1
                l1 = V.unsafeSlice 0  n1 l
                l2 = V.unsafeSlice n1 n2 l
                l1' = mergeSortAux l1
                l2' = mergeSortAux l2
    merge = ... 

さて、最後にmergeの実装を見てきましょう。

{-# LANGUAGE BangPatterns #-}
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV

sortBy cmp = ...
    where
    ....
    merge (!n1,!as) (!n2,!bs) = V.create $ do
        res <- MV.unsafeNew (n1+n2)
        let go i1 i2 = ...
        go 0 0

V.create :: (forall s. ST s (MVector s a)) -> Vector aを使ってMVector s aをVectorに変換します。
do構文の中ではマージ後の配列を格納するres :: MVector s aを確保し、go関数で先頭から順番に値を埋めていく感じです。

        let go i1 i2 | i1 >= n1 && i2 >= n2 = return res
                     | i1 >= n1 = tick2 i1 i2
                     | i2 >= n2 = tick1 i1 i2
                     | otherwise = do
                         v1 <- V.unsafeIndexM as i1
                         v2 <- V.unsafeIndexM bs i2
                         case cmp v1 v2 of
                            GT  -> do
                                 MV.unsafeWrite res (i1 + i2) v2
                                 go i1 (i2 + 1)
                            _   -> do
                                 MV.unsafeWrite res (i1 + i2) v1
                                 go (i1 + 1) i2
            tick1 i1 i2 = do
                V.unsafeIndexM as i1 >>= MV.unsafeWrite res (i1 + i2)
                go (i1 + 1) i2
            tick2 i1 i2 = do
                V.unsafeIndexM bs i2 >>= MV.unsafeWrite res (i1 + i2)
                go i1 (i2 + 1)

goの中でV.unsafeIndexMを使っていますが、これの意味について説明します。

indexMについて

例として以下のようなIndexTest.hsを考えます。

gist4eb6a528eed4ef665e9bcf1d91441094

$ ghc -O2 IndexTest.hs
$ ./IndexTest
test1: safe!
test2: safe!
IndexTest: Prelude.undefined

実行するとtest3でundefinedエラーが発生します。

コンパイルの中間コードをdumpして何が起こっているのかを読みます。

$ touch IndexTest.hs
$ ghc -ddump-simpl -O2 IndexTest.hs > IndexTest.log

ghc -ddump-simpl -O2 IndexTest.hs (GHC 7.10.2)

test1のコンパイル結果を見ましょう。

a1_r5Mg
  :: V.Vector Int
     -> GHC.Prim.State# GHC.Prim.RealWorld
     -> (# GHC.Prim.State# GHC.Prim.RealWorld, () #)
[GblId, Arity=2, Str=DmdType <S,1*H><L,U>]
a1_r5Mg =
  \ (v_a2pg :: V.Vector Int)
    (eta_B1 [OS=OneShot] :: GHC.Prim.State# GHC.Prim.RealWorld) ->
    case v_a2pg
    of _ [Occ=Dead]
    { Data.Vector.Vector ipv_s2CZ ipv1_s2Db ipv2_s2Dc ->
    ((check_r2i0
        lvl2_r5Mf
        (case GHC.Prim.indexArray# @ Int ipv2_s2Dc ipv_s2CZ
         of _ [Occ=Dead] { (# ipv3_a5vA #) ->
         ipv3_a5vA
         }))
     `cast` ...)
      eta_B1
    }

test1 [InlPrag=NOINLINE] :: V.Vector Int -> IO ()
[GblId, Arity=2, Str=DmdType <S,1*H><L,U>]
test1 =
  a1_r5Mg
  `cast` ...

ごちゃごちゃしてますが重要なことはcheck関数呼び出し時に

(check_r2i0
        lvl2_r5Mf
        (case GHC.Prim.indexArray# @ Int ipv2_s2Dc ipv_s2CZ
         of _ [Occ=Dead] { (# ipv3_a5vA #) ->
         ipv3_a5vA
         }))

と(case...)の部分が遅延評価されているということです。
この部分は元のソースコードでは(V.unsafeIndex v 0)の部分に対応します。
このため、余計なオーバーヘッドが生じてしまいます。

次にtest2のコンパイル結果を見ます。

a3_r5Mk
  :: V.Vector Int
     -> GHC.Prim.State# GHC.Prim.RealWorld
     -> (# GHC.Prim.State# GHC.Prim.RealWorld, () #)
[GblId, Arity=2, Str=DmdType <S,1*U(U,A,U)><L,U>]
a3_r5Mk =
  \ (v_a2ph :: V.Vector Int)
    (eta_B1 [OS=OneShot] :: GHC.Prim.State# GHC.Prim.RealWorld) ->
    case v_a2ph
    of _ [Occ=Dead]
    { Data.Vector.Vector ipv_s2Dk ipv1_s2Dl ipv2_s2Dm ->
    case GHC.Prim.indexArray# @ Int ipv2_s2Dm ipv_s2Dk
    of _ [Occ=Dead] { (# ipv3_a5vA #) ->
    ((check_r2i0 lvl4_r5Mj ipv3_a5vA)
     `cast` ...)
      eta_B1
    }
    }

test2 [InlPrag=NOINLINE] :: V.Vector Int -> IO ()
[GblId, Arity=2, Str=DmdType <S,1*U(U,A,U)><L,U>]
test2 =
  a3_r5Mk
  `cast` ...

今度はcheckの呼び出しの前にindexArray#が呼び出されています。
しかし、配列の要素自体は評価しないのでundefinedエラーは発生しません。

最後にtest3のコンパイル結果を見ます。

a2_r5Mi
  :: V.Vector Int
     -> GHC.Prim.State# GHC.Prim.RealWorld
     -> (# GHC.Prim.State# GHC.Prim.RealWorld, () #)
[GblId, Arity=2, Str=DmdType <S,1*U(U,A,U)><L,U>]
a2_r5Mi =
  \ (v_a2pj :: V.Vector Int)
    (eta_B1 [OS=OneShot] :: GHC.Prim.State# GHC.Prim.RealWorld) ->
    case v_a2pj
    of _ [Occ=Dead]
    { Data.Vector.Vector ipv_s2Df ipv1_s2Dg ipv2_s2Dh ->
    case GHC.Prim.indexArray# @ Int ipv2_s2Dh ipv_s2Df
    of _ [Occ=Dead] { (# ipv3_a5vA #) ->
    case ipv3_a5vA of vx_a2zI { GHC.Types.I# ipv4_s3c7 ->
    ((check_r2i0 lvl3_r5Mh vx_a2zI)
     `cast` ...)
      eta_B1
    }
    }
    }

test3 [InlPrag=NOINLINE] :: V.Vector Int -> IO ()
[GblId, Arity=2, Str=DmdType <S,1*U(U,A,U)><L,U>]
test3 =
  a2_r5Mi
  `cast` ...

この場合もcheckの呼び出し前にindexArray#を呼び出していて余計なオーバーヘッドは生じないのですが、さらに配列の要素も評価してしまうのでこの例ではundefinedエラーとなってしまいます。

まとめると

  • check "test1" (V.unsafeIndex v 0)だと配列アクセスのためのサンクが作って関数呼び出し。
  • check "test3" $! (V.unsafeIndex v 0)だと配列アクセスと要素の評価の後、関数呼び出し。
  • V.unsafeIndexM v 0 >>= check "test2"だと配列アクセスの後、関数呼び出し。

ちなみにindexM系がモナド内でしか使えないのはある種のHackなのでモナドの種類は関係ないはずです。

実験

今回書いたソートのコードはgistにまとめています。
gist1b847adefb53e04eedd1b7adabfcc330

次のようなテストプログラムを書いて実験を行いました。

gistecc5317d3f07cc2890f68da58bdec379

入力として30万要素のランダムな整数列を与えてみます。

$ ./Main < test300000.in > test300000.out
[2016-04-04 04:38:12.623993 UTC] Parsing: begin
[2016-04-04 04:38:12.709319 UTC] Parsing: end
[2016-04-04 04:38:12.709578 UTC] Parsing: 0.085326s
[2016-04-04 04:38:12.709794 UTC] Vector: begin
[2016-04-04 04:38:12.976383 UTC] Vector: end
[2016-04-04 04:38:12.976596 UTC] Vector: Sorting: 0.266589s
[2016-04-04 04:38:12.976818 UTC] List: begin
[2016-04-04 04:38:14.018594 UTC] List: end
[2016-04-04 04:38:14.018809 UTC] List: Sorting: 1.041776s
[2016-04-04 04:38:14.019109 UTC] result validation begin
[2016-04-04 04:38:14.022344 UTC] Correct!
[2016-04-04 04:38:14.126167 UTC] output done
[2016-04-04 04:38:14.126446 UTC] Elapsed Time: 1.502174s
[2016-04-04 04:38:14.126715 UTC] Speedup (Data.List.sort / VecSort.sort) = 391%

Data.List.sortは1.04秒かかっていたのに対してVecSort.sortは0.267秒でおよそ4倍高速になりました。もちろん、ソートされたリストの最初の幾つかの要素が欲しい場合などではData.List.sortの方が速いこともありますが、全体をソートしたい場合にはたくさんのメモリを消費するリストを用いたソートよりも、省メモリなVectorを用いたソートを用いたほうがよいと思います。