Tuesday, October 26, 2010

Tour of a real toy Haskell program, part 2

This is part 2 of my commentary on detrospector. You can find part 1 here.

Performance claims

I make a lot of claims about performance below, not all of which are supported by evidence. I did perform a fair amount of profiling, but it's impossible to test every combination of data structures. How can I structure my code to make more of my performance hypotheses easily testable?

Representing the source document

When we process a source document, we'll need a data structure to represent large amounts of text.

  • Haskell's standard String type is simple, maximally lazy, and Unicode-correct. However, as a singly-linked list of individually boxed heap-allocated characters, it's woefully inefficient in space and time.

  • ByteString is a common alternative, but solves the wrong problem. We need a sequence of Unicode codepoints, not a sequence of bytes.

  • ByteString.Char8 is a hack and not remotely Unicode-correct.

  • The text library stores Unicode characters in a ByteString-like packed format. That's exactly what we need.

For more information on picking a string type, see ezyang's writeup.
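To make the Unicode point concrete, here is a small sketch comparing a round trip through Data.Text and Data.ByteString.Char8 for a single non-Latin-1 character. The names `lambda`, `viaText` and `viaChar8` are illustrative only. The Char8 behavior shown, keeping only the low 8 bits of each Char, is what the bytestring documentation describes for `pack`.

```haskell
-- Compare how Data.Text and Data.ByteString.Char8 treat a non-Latin-1 character.
import qualified Data.ByteString.Char8 as BC
import qualified Data.Text as T

-- U+03BB GREEK SMALL LETTER LAMDA, written as an escape to avoid source-encoding issues.
lambda :: String
lambda = "\955"

-- Text stores code points, so the character survives the round trip.
viaText :: String
viaText = T.unpack (T.pack lambda)

-- Char8.pack keeps only the low 8 bits of each Char, so U+03BB becomes U+00BB.
viaChar8 :: String
viaChar8 = BC.unpack (BC.pack lambda)
```

Loaded into GHCi, `viaText` gives back the original string, while `viaChar8` has silently turned λ into ».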

We choose the lazy flavor of text so that we can stream the input file in chunks without storing it all in memory. Lazy IO is semantically dubious, but I've decided that this project is enough of a toy that I don't care. Note that robustness to IO errors is not in the requirements list. ;)
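A lazy Text is essentially a list of strict chunks, which is what makes streaming possible. The sketch below uses hypothetical names (`chunkSizes`, `countSpaces`, `demoText`) and builds a two-chunk lazy Text by hand, standing in for what `Data.Text.Lazy.IO.getContents` would deliver, then consumes it with a strict left fold.

```haskell
import qualified Data.Text as T
import qualified Data.Text.Lazy as TL

-- A lazy Text is a list of strict Text chunks; this exposes their sizes.
chunkSizes :: TL.Text -> [Int]
chunkSizes = map T.length . TL.toChunks

-- A strict left fold over a lazy Text walks the chunks in order.
countSpaces :: TL.Text -> Int
countSpaces = TL.foldl' (\n c -> if c == ' ' then n + 1 else n) 0

-- Two hand-made chunks in place of what lazy IO would read from a file.
demoText :: TL.Text
demoText = TL.fromChunks [T.pack "hello ", T.pack "lazy world"]
```

A fold like `countSpaces` consumes one chunk at a time, so a chunk it has finished with can be garbage-collected, provided nothing else still refers to it.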

The enumerator package provides iteratees as a composable, well-founded alternative to the lazy IO hack. The enumerator API was complex enough to put me off using it for the first version of detrospector. After reading Michael Snoyman's enumerators tutorial, I have a slight idea how the library works, and I might use it in a future version.

Representing substrings

We also need to represent the k-character substrings, both when we analyze the source and when we generate text. The requirements here are different.

We expect that k will be small, certainly under 10 — otherwise, you'll just generate the source document! With many small strings, the penalty of boxing each individual Char is less severe.

And we need to support queue-like operations; we build a new substring by pushing a character onto the end, and dropping one from the beginning. For both String and Text, appending onto the end requires a full copy. So we'll use Data.Sequence, which provides sequences with bidirectional append, implemented as finger trees:

-- module Detrospector.Types
import qualified Data.Sequence as S
...
type Queue a = S.Seq a

...Actually, I just ran a quick test with Text as the queue type, and it seems not to matter much. Since k is small, the O(k) copy is insignificant. Profiling makes fools of us all, and asymptotic analysis is mostly misused. Anyway, I'm keeping the code the way it is, pending any insight from my clever readers. Perhaps one of the queues from Purely Functional Data Structures would be most appropriate.
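The push-on-the-right, drop-on-the-left step can be sketched with Data.Sequence as follows. `shiftQ` is a hypothetical stand-in for the post's `shift`, whose definition isn't shown here. It assumes the queue should hold at most k characters.

```haskell
import qualified Data.Foldable as F
import qualified Data.Sequence as S
import Data.Sequence ((|>))

type Queue a = S.Seq a

-- Hypothetical stand-in for `shift`: push x on the right and, if the
-- queue now holds more than k elements, drop the oldest from the left.
shiftQ :: Int -> a -> Queue a -> Queue a
shiftQ k x q
  | S.length q' > k = S.drop 1 q'
  | otherwise       = q'
  where
    q' = q |> x

-- Queue contents, oldest first.
contents :: Queue a -> [a]
contents = F.toList
```

For example, folding `shiftQ 2` over "hello" leaves a queue containing "lo". `S.length` is O(1) on a Seq, and touching either end of a finger tree is cheap, which is the property that motivated Data.Sequence in the first place.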

Representing character counts

We need a data structure for tabulating character counts. In the old days of C and Eurocentrism, we could use a flat, mutable int[256] indexed by ASCII values. But the range of Unicode characters is far too large for a flat array, and we need efficient functional updates without a full copy.

We could import Data.Map and use Map Char Int. This will build a balanced binary search tree using pairwise Ord comparisons.

But we can do better. We can use the bits of an integer key as a path in a tree, following a left or right child for a 0 or 1 bit, respectively. This sort of search tree (a trie) will typically outperform repeated pairwise comparisons. Data.IntMap implements this idea, with an API very close to Map Int. Our keys are Chars, but we can easily convert them to Ints using fromEnum (and back using toEnum).

-- module Detrospector.Types
import qualified Data.IntMap as IM
...
type FreqTable = IM.IntMap Int
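Here is a sketch of filling a FreqTable from a String, with each Char converted to its code point via `fromEnum`. `tally` and `charCounts` are illustrative helpers, not functions from detrospector.

```haskell
import qualified Data.IntMap as IM

type FreqTable = IM.IntMap Int

-- Tally characters into a FreqTable. IntMap keys are Ints, so each Char
-- is stored under its code point (fromEnum); toEnum converts back.
tally :: String -> FreqTable
tally s = IM.fromListWith (+) [(fromEnum c, 1) | c <- s]

-- Read the table back as (Char, count) pairs in ascending code-point order.
charCounts :: FreqTable -> [(Char, Int)]
charCounts t = [(toEnum k, n) | (k, n) <- IM.toList t]
```

For instance, `charCounts (tally "banana")` gives `[('a',3),('b',1),('n',2)]`.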

Representing distributions

So we have some frequency table like

IM.fromList [(fromEnum 'e', 267), (fromEnum 't', 253), (fromEnum 'a', 219), ...]

How can we efficiently pick a character from this distribution? We're mapping characters to individual counts, but we really want a map from cumulative counts to characters:

-- module Detrospector.Types
data PickTable = PickTable Int (IM.IntMap Char)

To sample a character from PickTable t im, we first pick a random k such that 0 ≤ k < t, using a uniform distribution. We then find the first key in im which is greater than k, and take its associated Char value. In code:

-- module Detrospector.Types
-- (requires {-# LANGUAGE ViewPatterns #-} at the top of the module)
import Control.Applicative ( (<$>) )
import qualified System.Random.MWC as RNG
...
sample :: PickTable -> RNG.GenIO -> IO Char
sample (PickTable t im) g = do
  k <- (`mod` t) <$> RNG.uniform g
  case IM.split k im of
    (_, IM.toList -> ((_,x):_)) -> return x
    _ -> error "impossible"

The largest cumulative sum is the total count t, so the largest key in im is t. We know k < t, so IM.split k im will never return an empty map on the right side.

Note the view pattern for pattern-matching an IntMap as if it were an ascending-order list.

The standard System.Random library in GHC Haskell is quite slow, a problem shared by most language implementations. We use the much faster mwc-random package. The only operation we need is picking a uniform Int as an IO action.
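Because only the random draw needs IO, the table lookup in `sample` can be exercised on its own. In the sketch below, `pickWith` (a hypothetical name) takes the random Int as an argument instead of a generator. It reuses the `IM.split` and view-pattern technique from `sample`, and `example` is the cumulative form of the counts 't' = 253, 'e' = 267, 'a' = 219.

```haskell
{-# LANGUAGE ViewPatterns #-}
import qualified Data.IntMap as IM

-- Total count t, plus a map from cumulative counts to characters.
data PickTable = PickTable Int (IM.IntMap Char)

-- Lookup half of `sample`, with the random draw passed in as r.
-- k = r `mod` t satisfies 0 <= k < t; IM.split k puts all keys strictly
-- greater than k on the right, and the smallest of those owns k.
pickWith :: PickTable -> Int -> Char
pickWith (PickTable t im) r =
  case IM.split (r `mod` t) im of
    (_, IM.toList -> ((_, x) : _)) -> x
    _ -> error "pickWith: empty or inconsistent PickTable"

-- Cumulative table for t = 253, e = 267, a = 219 (total 739).
example :: PickTable
example = PickTable 739 (IM.fromList [(253, 't'), (520, 'e'), (739, 'a')])
```

With this table, k in [0, 253) picks 't', k in [253, 520) picks 'e', and k in [520, 739) picks 'a', so each character is chosen for exactly as many values of k as its count. Any uniform source of Ints, such as the mwc-random generator used in `sample`, can supply `r`.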

We still need a way to build our PickTable:

-- module Detrospector.Types
import Data.List ( mapAccumR )
...
cumulate :: FreqTable -> PickTable
cumulate t = PickTable r $ IM.fromList ps where
  (r,ps) = mapAccumR f 0 $ IM.assocs t
  f ra (x,n) = let rb = ra+n in (rb, (rb, toEnum x))

This code is short, but kind of tricky. For reference:

mapAccumR :: (acc -> x -> (acc, y)) -> acc -> [x] -> (acc, [y])
f :: Int -> (Int, Int) -> (Int, (Int, Char))

f takes an assoc pair from the FreqTable, adds its count to the running sum, and produces an assoc pair for the PickTable. We start the traversal with a sum of 0, and get the final sum r along with our assoc pairs ps.
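As a worked example, here is `cumulate` applied to the three-character table from earlier, made self-contained with its own copies of the types. Since `mapAccumR` traverses from the right, the character with the largest code point is summed first: 't' gets 0 + 253 = 253, 'e' gets 253 + 267 = 520, and 'a' gets 520 + 219 = 739. The `Eq` and `Show` instances on PickTable and the `counts` value are added here only for illustration.

```haskell
import qualified Data.IntMap as IM
import Data.List (mapAccumR)

type FreqTable = IM.IntMap Int
data PickTable = PickTable Int (IM.IntMap Char) deriving (Eq, Show)

-- Same definition as above, with a local type signature for f.
cumulate :: FreqTable -> PickTable
cumulate t = PickTable r (IM.fromList ps)
  where
    (r, ps) = mapAccumR f 0 (IM.assocs t)
    -- Add this entry's count to the running sum; the new sum becomes its key.
    f :: Int -> (Int, Int) -> (Int, (Int, Char))
    f ra (x, n) = let rb = ra + n in (rb, (rb, toEnum x))

counts :: FreqTable
counts = IM.fromList [(fromEnum 'e', 267), (fromEnum 't', 253), (fromEnum 'a', 219)]
```

`cumulate counts` evaluates to `PickTable 739 (fromList [(253,'t'),(520,'e'),(739,'a')])`.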

Representing the Markov chain

So we can represent probability distributions for characters. Now we need a map from k-character substrings to distributions.

Data.Map is again an option, but pairwise, character-wise comparison of our Queue Char values will be slow. What we really want is another trie, with character-based fanout at each node. Hackage has bytestring-trie, which unfortunately works on bytes, not characters. And maybe I should have used TrieMap or list-tries. Instead I used the hashmap package:

-- module Detrospector.Types
import qualified Data.HashMap as H
...
data Chain = Chain Int (H.HashMap (Queue Char) PickTable)

A value Chain k hm maps from subsequences of up to k Chars to PickTables. A lookup of some Queue Char key requires one traversal to calculate an Int hash value, and then uses an IntMap to find a (hopefully small) Map of keys with that same hash value.
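The lookup scheme just described can be illustrated with a toy structure: an IntMap of hash buckets, where each bucket is an ordinary Map. This is only a sketch of the idea. `ToyHashMap`, `toyHash`, `insertH` and `lookupH` are hypothetical, and `toyHash` is a simple polynomial hash, not what the hashmap or hashable packages actually use.

```haskell
import qualified Data.IntMap as IM
import qualified Data.Map as M
import Data.List (foldl')

-- IntMap keyed by hash; each value is a small Map of all keys sharing that hash.
type ToyHashMap k v = IM.IntMap (M.Map k v)

-- Hypothetical hash for illustration only (not the hashable package's).
toyHash :: String -> Int
toyHash = foldl' (\h c -> 31 * h + fromEnum c) 7

-- Insert into the key's bucket; M.union is left-biased, so the new value wins.
insertH :: String -> v -> ToyHashMap String v -> ToyHashMap String v
insertH k v = IM.insertWith M.union (toyHash k) (M.singleton k v)

-- One pass over the key to hash it, one IntMap lookup, then an Ord-based
-- lookup inside the (usually tiny) bucket.
lookupH :: String -> ToyHashMap String v -> Maybe v
lookupH k hm = IM.lookup (toyHash k) hm >>= M.lookup k
```

With this particular hash, "Aa" and "BB" collide, so they land in the same bucket and are told apart by the inner Map's comparisons; lookups for both still succeed.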

There is a wrinkle: we need to specify how to hash a Queue, which is just a synonym for S.Seq. The resulting instance is an orphan, which we could avoid by newtype-wrapping S.Seq.

-- module Detrospector.Types
import qualified Data.Hashable as H
import qualified Data.Foldable as F
...
instance (H.Hashable a) => H.Hashable (S.Seq a) where
  {-# SPECIALIZE instance H.Hashable (S.Seq Char) #-}
  hash = F.foldl' (\acc h -> acc `H.combine` H.hash h) 0

This code is lifted almost verbatim from the [a] instance in Data.Hashable.
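Because the instance folds over the sequence from left to right, the same way a list instance walks a list, a Seq Char and the String with the same characters produce the same hash under this scheme. The sketch below shows this without depending on hashable. `combine` is a stand-in modeled on a multiply-by-16777619 (the 32-bit FNV prime) and xor combiner; it and `hashChars`/`hashQueue` are assumptions, not hashable's actual API.

```haskell
import qualified Data.Foldable as F
import qualified Data.Sequence as S
import Data.Bits (xor)

-- Stand-in combiner (FNV-style multiply, then xor); not hashable's definition.
combine :: Int -> Int -> Int
combine h1 h2 = (h1 * 16777619) `xor` h2

-- Hash any Foldable of Chars by folding left to right, the same shape
-- as the Seq instance above.
hashChars :: F.Foldable t => t Char -> Int
hashChars = F.foldl' (\acc c -> acc `combine` fromEnum c) 0

-- Specialized to the Queue representation.
hashQueue :: S.Seq Char -> Int
hashQueue = hashChars
```

In this sketch `hashQueue (S.fromList "abc")` equals `hashChars "abc"`, and order matters: "ab" and "ba" hash differently.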

Serialization

After training, we need to write the Chain to disk, for use in subsequent generation. I started out with derived Show and Read, which was simple but incredibly slow. We'll use binary with ByteString.Lazy — the dreaded lazy IO again!

We start by specifying how to serialize a few types. Here the tuple instances for Binary come in handy:

-- module Detrospector.Types
import qualified Data.Binary as Bin
...
-- another orphan instance
instance (Bin.Binary k, Bin.Binary v, H.Hashable k, Ord k)
    => Bin.Binary (H.HashMap k v) where
  put = Bin.put . H.assocs
  get = H.fromList <$> Bin.get

instance Bin.Binary PickTable where
  put (PickTable n t) = Bin.put (n,t)
  get = uncurry PickTable <$> Bin.get

instance Bin.Binary Chain where
  put (Chain n h) = Bin.put (n,h)
  get = uncurry Chain <$> Bin.get
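A quick round trip shows the tuple-based encoding at work. This sketch repeats the PickTable instance in a self-contained form and relies only on instances that binary itself provides (Int, Char, tuples and IntMap); the `Eq`/`Show` deriving and the `example` value are added for illustration.

```haskell
import qualified Data.Binary as Bin
import qualified Data.IntMap as IM

data PickTable = PickTable Int (IM.IntMap Char) deriving (Eq, Show)

-- Serialize via the (Int, IntMap Char) tuple, as in the instance above.
instance Bin.Binary PickTable where
  put (PickTable n t) = Bin.put (n, t)
  get = uncurry PickTable <$> Bin.get

example :: PickTable
example = PickTable 739 (IM.fromList [(253, 't'), (520, 'e'), (739, 'a')])
```

`Bin.decode (Bin.encode example)` gives back a value equal to `example`.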

The actual IO is easy. We use gzip compression, which fits right into the IO pipeline:

-- module Detrospector.Types
import qualified Data.ByteString.Lazy as BSL
import qualified Codec.Compression.GZip as Z
...
withChain :: FilePath -> (Chain -> RNG.GenIO -> IO a) -> IO a
withChain p f = do
  ch <- (Bin.decode . Z.decompress) <$> BSL.readFile p
  RNG.withSystemRandom $ f ch

writeChain :: FilePath -> Chain -> IO ()
writeChain out = BSL.writeFile out . Z.compress . Bin.encode

Training the chain

I won't present the whole implementation of the train subcommand, but here's a simplification:

-- module Detrospector.Modes.Train
-- (uses the BangPatterns and NamedFieldPuns extensions)
import qualified Data.Text.Lazy as Txt
import qualified Data.Text.Lazy.IO as Txt
import qualified Data.HashMap as H
import qualified Data.IntMap as IM
import qualified Data.Sequence as S
import qualified Data.Foldable as F
...
train Train{num,out} = do
  (_,h) <- Txt.foldl' roll (emptyQ, H.empty) <$> Txt.getContents
  writeChain out . Chain num $ H.map cumulate h
  where
    roll (!s,!h) x
      = (shift num x s, F.foldr (H.alter $ ins x) h $ S.tails s)

    ins x Nothing  = Just $! sing x
    ins x (Just v) = Just $! incr x v

    sing x = IM.singleton (fromEnum x) 1

    incr x = IM.alter f $ fromEnum x where
      f Nothing  = Just 1
      f (Just v) = Just $! (v+1)

Before generating PickTables, we build a HashMap of FreqTables. We fold over the input text, accumulating a pair of (last characters seen, map so far). Since foldl' only evaluates the accumulator to weak head normal form (WHNF), which for a pair means just the tuple constructor, we use bang patterns in the fold function roll to force both components as well. Real World Haskell (RWH) discusses the same issue.

shift (from Detrospector.Types) pushes a new character into the queue, and drops the oldest character if the size exceeds num. We add one count for the new character x, both for the whole history s and each of its suffixes.
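To see exactly which histories get counted, here is a pure sketch of the same fold over a plain String. Assumptions: Data.Map stands in for the hashmap package, `shiftQ` is a hypothetical version of `shift`, and `trainString` is an illustrative name. The counting logic (alter, `Just $!`, one count per suffix of the history) follows the simplified train above.

```haskell
{-# LANGUAGE BangPatterns #-}
import qualified Data.Foldable as F
import qualified Data.IntMap as IM
import qualified Data.Map as M
import qualified Data.Sequence as S
import Data.List (foldl')
import Data.Sequence ((|>))

type FreqTable = IM.IntMap Int

-- Hypothetical stand-in for `shift`: keep at most k most recent elements.
shiftQ :: Int -> a -> S.Seq a -> S.Seq a
shiftQ k x q = let q' = q |> x in if S.length q' > k then S.drop 1 q' else q'

trainString :: Int -> String -> M.Map (S.Seq Char) FreqTable
trainString k = snd . foldl' roll (S.empty, M.empty)
  where
    -- Count x after the whole history s and after each of its suffixes;
    -- S.tails includes the empty suffix, which collects plain frequencies.
    roll (!s, !m) x = (shiftQ k x s, F.foldr (M.alter (ins x)) m (S.tails s))
    ins x Nothing  = Just $! IM.singleton (fromEnum x) 1
    ins x (Just v) = Just $! IM.alter bump (fromEnum x) v
    bump Nothing  = Just 1
    bump (Just n) = Just $! n + 1
```

For `trainString 1 "abab"`, the empty history maps to a:2, b:2, the history "a" maps to b:2, and "b" maps to a:1, since the final 'b' is never followed by anything.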

We're using alter where perhaps a combination of lookup and insert would be more natural. This is a workaround for a subtle laziness-related space leak, which I found after much profiling and random mucking about. When you insert into a map like so:

let mm = insert k (v+1) m

there is nothing to force v+1 to WHNF, even if you force mm to WHNF. The leaves of our tree end up holding large thunks of the form ((((1+1)+1)+1)+...).

The workaround is to call alter with Just $! (v+1). We know that the implementation of alter will pattern-match on the Maybe constructor, which then triggers WHNF evaluation of v+1 because of ($!). This was tricky to figure out. Is there a better solution, or some different way I should approach this problem? It seems to me that Data.Map and friends are generally lacking in strictness building blocks.
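The two update styles can be compared side by side in this sketch. `bumpLazy`, `bumpStrict` and `countWith` are illustrative names. `bumpStrict` follows the post's `alter` plus `Just $!` pattern and, as explained above, relies on `alter` pattern-matching on the returned Maybe when the map is forced.

```haskell
import qualified Data.IntMap as IM
import Data.List (foldl')

-- Leaky: forcing the map to WHNF does not evaluate the stored value,
-- which stays a pending (findWithDefault ... + 1) thunk.
bumpLazy :: Int -> IM.IntMap Int -> IM.IntMap Int
bumpLazy k m = IM.insert k (IM.findWithDefault 0 k m + 1) m

-- Workaround from the post: Just $! forces the new count to WHNF as soon
-- as alter inspects the Maybe constructor.
bumpStrict :: Int -> IM.IntMap Int -> IM.IntMap Int
bumpStrict = IM.alter f
  where
    f Nothing  = Just 1
    f (Just v) = Just $! v + 1

-- Count occurrences with a given update function; foldl' forces each map.
countWith :: (Int -> IM.IntMap Int -> IM.IntMap Int) -> [Int] -> IM.IntMap Int
countWith bump = foldl' (flip bump) IM.empty
```

Both produce the same counts; the difference is when the additions happen. With `bumpLazy`, each stored count is a chain of pending additions, and each pending addition also keeps the previous map reachable. With `bumpStrict`, an evaluated Int is stored after every step.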

The end!

Thanks for reading! Here's an example of how not to write the same program:

module Main where{import Random;(0:y)%(R p _)=y%p;(1:f)%(R _ n)=f%n;[]%(J x)=x;b
[p,v,k]=(l k%).(l v%).(l p%);main=getContents>>=("eek"#).flip(.:"eek")(y.y.y.y$0
);h k=toEnum k;;(r:j).:[k,v,b]=(j.:[v,b,r]).((k?).(v?).(b?)$(r?)m);[].:_=id;(!)=
iterate;data R a=J a|R(R a)(R a);(&)i=fmap i;k b y v j=let{t=l b%v+y;d f|t>j=b;d
f=k(m b)t v j}in d q;y l=(!!8)(q R!J l);q(+)b=b+b;p(0:v)z(R f x)=R(p v z f)x;p[]
z(J x)=J(z x);p(1:v)z(R n k)=R n$p v z k;m = succ;l p=tail.(snd&).take 9$((<<2).
fst)!(fromEnum p,0);(?)=p.l;d@[q,p,r]#v=let{u=b d v;j=i u;(s,o)|j<1=((97,122),id
)|h 1=((0,j-1),(k 0 0 u))}in do{q<-(h.o)&randomRIO s;putChar q;[p,r,q]#v};i(J q)
=q;i(R n k)=i n+i k;(<<)=divMod} -- finite text on stdin © keegan oct 2010 BSD3
