vectorを使ったData.List.sortより4倍速いsortアルゴリズムの実装
Data.List.sortがあまりに遅くてつらいので、vectorを使って書いてチューニングしたら約4倍速くなりましたという話をします。その過程でvectorのmonadic indexingとは何かという話をします。
仕様としてData.Listと互換性を持たせるため、次のようなインターフェースにしました。
{-# INLINE sort #-} sort :: Ord a => [a] -> [a] sort = sortBy compare sortBy :: (a -> a -> Ordering) -> [a] -> [a] sortBy = ...
次に実装を見てきます。
import qualified Data.Vector as V ... sortBy cmp = V.toList . mergeSortAux . V.fromList where mergeSortAux l = ...
入力はリストで与えられるので、まず最初にfromList :: [a] -> Vector aを使ってData.Vector.Vectorに変換し、mergeSortAux :: Vector a -> Vector aを使ってソートした後に、toList :: Vector a -> [a]を用いてリストに戻します。Unboxed Vectorを使うとさらに速くなりますが、ソートできるデータが制限されてしまうので今回はBoxed Vectorを使います。
mergeSortAuxの実装はただのマージソートです。今後ところどころunsafeなんとかを使っていますが、これは今回のアルゴリズムでは境界値外アクセスすることはないので配列の境界チェックを省くために行っているだけです。怖くないよ。
import qualified Data.Vector as V ... sortBy cmp = ... where mergeSortAux l | n <= 1 = l | otherwise = merge (n1, l1') (n2, l2') where n = V.length l n1 = div n 2 n2 = n - n1 l1 = V.unsafeSlice 0 n1 l l2 = V.unsafeSlice n1 n2 l l1' = mergeSortAux l1 l2' = mergeSortAux l2 merge = ...
さて、最後にmergeの実装を見てきましょう。
{-# LANGUAGE BangPatterns #-} import qualified Data.Vector as V import qualified Data.Vector.Mutable as MV sortBy cmp = ... where .... merge (!n1,!as) (!n2,!bs) = V.create $ do res <- MV.unsafeNew (n1+n2) let go i1 i2 = ... go 0 0
V.create :: (forall s. ST s (MVector s a)) -> Vector aを使ってMVector s aをVectorに変換します。
do構文の中ではマージ後の配列を格納するres :: MVector s aを確保し、go関数で先頭から順番に値を埋めていく感じです。
let go i1 i2 | i1 >= n1 && i2 >= n2 = return res | i1 >= n1 = tick2 i1 i2 | i2 >= n2 = tick1 i1 i2 | otherwise = do v1 <- V.unsafeIndexM as i1 v2 <- V.unsafeIndexM bs i2 case cmp v1 v2 of GT -> do MV.unsafeWrite res (i1 + i2) v2 go i1 (i2 + 1) _ -> do MV.unsafeWrite res (i1 + i2) v1 go (i1 + 1) i2 tick1 i1 i2 = do V.unsafeIndexM as i1 >>= MV.unsafeWrite res (i1 + i2) go (i1 + 1) i2 tick2 i1 i2 = do V.unsafeIndexM bs i2 >>= MV.unsafeWrite res (i1 + i2) go i1 (i2 + 1)
goの中でV.unsafeIndexMを使っていますが、これの意味について説明します。
indexMについて
例として以下のようなIndexTest.hsを考えます。
gist4eb6a528eed4ef665e9bcf1d91441094
$ ghc -O2 IndexTest.hs
$ ./IndexTest
test1: safe!
test2: safe!
IndexTest: Prelude.undefined
実行するとtest3でundefinedエラーが発生します。
コンパイルの中間コードをdumpして何が起こっているのかを読みます。
$ touch IndexTest.hs $ ghc -ddump-simpl -O2 IndexTest.hs > IndexTest.log
ghc -ddump-simpl -O2 IndexTest.hs (GHC 7.10.2)
test1のコンパイル結果を見ましょう。
a1_r5Mg :: V.Vector Int -> GHC.Prim.State# GHC.Prim.RealWorld -> (# GHC.Prim.State# GHC.Prim.RealWorld, () #) [GblId, Arity=2, Str=DmdType <S,1*H><L,U>] a1_r5Mg = \ (v_a2pg :: V.Vector Int) (eta_B1 [OS=OneShot] :: GHC.Prim.State# GHC.Prim.RealWorld) -> case v_a2pg of _ [Occ=Dead] { Data.Vector.Vector ipv_s2CZ ipv1_s2Db ipv2_s2Dc -> ((check_r2i0 lvl2_r5Mf (case GHC.Prim.indexArray# @ Int ipv2_s2Dc ipv_s2CZ of _ [Occ=Dead] { (# ipv3_a5vA #) -> ipv3_a5vA })) `cast` ...) eta_B1 } test1 [InlPrag=NOINLINE] :: V.Vector Int -> IO () [GblId, Arity=2, Str=DmdType <S,1*H><L,U>] test1 = a1_r5Mg `cast` ...
ごちゃごちゃしてますが重要なことはcheck関数呼び出し時に
(check_r2i0 lvl2_r5Mf (case GHC.Prim.indexArray# @ Int ipv2_s2Dc ipv_s2CZ of _ [Occ=Dead] { (# ipv3_a5vA #) -> ipv3_a5vA }))
と(case...)の部分が遅延評価されているということです。
この部分は元のソースコードでは(V.unsafeIndex v 0)の部分に対応します。
このため、余計なオーバーヘッドが生じてしまいます。
次にtest2のコンパイル結果を見ます。
a3_r5Mk :: V.Vector Int -> GHC.Prim.State# GHC.Prim.RealWorld -> (# GHC.Prim.State# GHC.Prim.RealWorld, () #) [GblId, Arity=2, Str=DmdType <S,1*U(U,A,U)><L,U>] a3_r5Mk = \ (v_a2ph :: V.Vector Int) (eta_B1 [OS=OneShot] :: GHC.Prim.State# GHC.Prim.RealWorld) -> case v_a2ph of _ [Occ=Dead] { Data.Vector.Vector ipv_s2Dk ipv1_s2Dl ipv2_s2Dm -> case GHC.Prim.indexArray# @ Int ipv2_s2Dm ipv_s2Dk of _ [Occ=Dead] { (# ipv3_a5vA #) -> ((check_r2i0 lvl4_r5Mj ipv3_a5vA) `cast` ...) eta_B1 } } test2 [InlPrag=NOINLINE] :: V.Vector Int -> IO () [GblId, Arity=2, Str=DmdType <S,1*U(U,A,U)><L,U>] test2 = a3_r5Mk `cast` ...
今度はcheckの呼び出しの前にindexArray#が呼び出されています。
しかし、配列の要素自体は評価しないのでundefinedエラーは発生しません。
最後にtest3のコンパイル結果を見ます。
a2_r5Mi :: V.Vector Int -> GHC.Prim.State# GHC.Prim.RealWorld -> (# GHC.Prim.State# GHC.Prim.RealWorld, () #) [GblId, Arity=2, Str=DmdType <S,1*U(U,A,U)><L,U>] a2_r5Mi = \ (v_a2pj :: V.Vector Int) (eta_B1 [OS=OneShot] :: GHC.Prim.State# GHC.Prim.RealWorld) -> case v_a2pj of _ [Occ=Dead] { Data.Vector.Vector ipv_s2Df ipv1_s2Dg ipv2_s2Dh -> case GHC.Prim.indexArray# @ Int ipv2_s2Dh ipv_s2Df of _ [Occ=Dead] { (# ipv3_a5vA #) -> case ipv3_a5vA of vx_a2zI { GHC.Types.I# ipv4_s3c7 -> ((check_r2i0 lvl3_r5Mh vx_a2zI) `cast` ...) eta_B1 } } } test3 [InlPrag=NOINLINE] :: V.Vector Int -> IO () [GblId, Arity=2, Str=DmdType <S,1*U(U,A,U)><L,U>] test3 = a2_r5Mi `cast` ...
この場合もcheckの呼び出し前にindexArray#を呼び出していて余計なオーバーヘッドは生じないのですが、さらに配列の要素も評価してしまうのでこの例ではundefinedエラーとなってしまいます。
まとめると
- check "test1" (V.unsafeIndex v 0)だと配列アクセスのためのサンクが作って関数呼び出し。
- check "test3" $! (V.unsafeIndex v 0)だと配列アクセスと要素の評価の後、関数呼び出し。
- V.unsafeIndexM v 0 >>= check "test2"だと配列アクセスの後、関数呼び出し。
実験
今回書いたソートのコードはgistにまとめています。
gist1b847adefb53e04eedd1b7adabfcc330
次のようなテストプログラムを書いて実験を行いました。
gistecc5317d3f07cc2890f68da58bdec379
入力として30万要素のランダムな整数列を与えてみます。
$ ./Main < test300000.in > test300000.out [2016-04-04 04:38:12.623993 UTC] Parsing: begin [2016-04-04 04:38:12.709319 UTC] Parsing: end [2016-04-04 04:38:12.709578 UTC] Parsing: 0.085326s [2016-04-04 04:38:12.709794 UTC] Vector: begin [2016-04-04 04:38:12.976383 UTC] Vector: end [2016-04-04 04:38:12.976596 UTC] Vector: Sorting: 0.266589s [2016-04-04 04:38:12.976818 UTC] List: begin [2016-04-04 04:38:14.018594 UTC] List: end [2016-04-04 04:38:14.018809 UTC] List: Sorting: 1.041776s [2016-04-04 04:38:14.019109 UTC] result validation begin [2016-04-04 04:38:14.022344 UTC] Correct! [2016-04-04 04:38:14.126167 UTC] output done [2016-04-04 04:38:14.126446 UTC] Elapsed Time: 1.502174s [2016-04-04 04:38:14.126715 UTC] Speedup (Data.List.sort / VecSort.sort) = 391%
Data.List.sortは1.04秒かかっていたのに対してVecSort.sortは0.267秒でおよそ4倍高速になりました。もちろん、ソートされたリストの最初の幾つかの要素が欲しい場合などではData.List.sortの方が速いこともありますが、全体をソートしたい場合にはたくさんのメモリを消費するリストを用いたソートよりも、省メモリなVectorを用いたソートを用いたほうがよいと思います。