Wednesday, October 29, 2008

Markov chain code

As a follow-up to my Markov chain posts, here's the code for a generic Markov chain in C#:

A markov chain:

/// <summary>
/// Chain implements a Markov chain for a type T.
/// It allows the generation of sequences based on
/// a sample set of T items.
/// </summary>
/// <typeparam name="T">the type of elements</typeparam>
public class Chain<T>
{
    // root of the link tree; its own data value is never emitted
    readonly Link<T> root = new Link<T>(default(T));
    // window size used both when learning and when generating
    readonly int length;

    /// <summary>
    /// creates a new chain from a sample set
    /// </summary>
    /// <param name="input">Sample set</param>
    /// <param name="length">window size for sequences, at least 1</param>
    /// <exception cref="System.ArgumentNullException">input is null</exception>
    /// <exception cref="System.ArgumentOutOfRangeException">length is less than 1</exception>
    public Chain(IEnumerable<T> input, int length)
    {
        // fail early with a descriptive exception instead of a
        // NullReferenceException / queue error deep inside Link<T>
        if (input == null)
            throw new System.ArgumentNullException("input");
        if (length < 1)
            throw new System.ArgumentOutOfRangeException("length", length, "window size must be at least 1");

        this.length = length;
        root.Process(input, length);
    }

    /// <summary>
    /// generate a new sequence that starts with an item picked at
    /// random (weighted by occurrence count) from the followers of
    /// the root link
    /// </summary>
    /// <param name="max">maximum size of result</param>
    /// <returns>the generated items; empty when the sample set was empty</returns>
    public IEnumerable<T> Generate(int max)
    {
        Link<T> first = root.SelectRandomLink();

        // an empty sample set leaves the root without followers;
        // there is nothing to start from, so produce nothing
        if (first == null)
            yield break;

        foreach (Link<T> next in root.Generate(first.Data, length, max))
            yield return next.Data;
    }

    /// <summary>
    /// generate a new sequence based on the sample
    /// </summary>
    /// <param name="start">the item to start with</param>
    /// <param name="max">maximum size of result</param>
    /// <returns>the generated items, beginning with start when it is known to the chain</returns>
    public IEnumerable<T> Generate(T start, int max)
    {
        foreach (Link<T> next in root.Generate(start, length, max))
            yield return next.Data;
    }
}


consists of links:

/// <summary>
/// One link of a Markov chain. Each link holds a value, the number
/// of times it was seen at its position, and the links that
/// followed it in the sample data.
/// </summary>
/// <typeparam name="T">link type</typeparam>
internal class Link<T>
{
    // random source shared by all links of the same closed type
    static readonly Random rand = new Random();

    readonly T data;
    int count;
    // following links, keyed by the follower's value
    readonly Dictionary<T, Link<T>> links;

    /// <summary>
    /// create a new link
    /// </summary>
    /// <param name="data">value of the item in sequence</param>
    internal Link(T data)
    {
        this.data = data;
        this.count = 0;

        links = new Dictionary<T, Link<T>>();
    }

    /// <summary>
    /// process the input in window sized chunks
    /// </summary>
    /// <param name="input">the sample set</param>
    /// <param name="length">size of sequence window, at least 1</param>
    /// <exception cref="ArgumentNullException">input is null</exception>
    /// <exception cref="ArgumentOutOfRangeException">length is less than 1</exception>
    public void Process(IEnumerable<T> input, int length)
    {
        if (input == null)
            throw new ArgumentNullException("input");
        if (length < 1)
            throw new ArgumentOutOfRangeException("length", length, "window size must be at least 1");

        // holds the current window
        Queue<T> window = new Queue<T>(length);

        // process the input, a window at a time (overlapping);
        // the windows at the very start of the input are shorter
        // than length and are processed as well
        foreach (T part in input)
        {
            if (window.Count == length)
                window.Dequeue();
            window.Enqueue(part);

            ProcessWindow(window);
        }
    }

    /// <summary>
    /// process the window to construct the chain: walks down from
    /// this link, one level per window item
    /// </summary>
    /// <param name="window">the items of the current window, oldest first</param>
    private void ProcessWindow(Queue<T> window)
    {
        Link<T> link = this;

        foreach (T part in window)
            link = link.Process(part);
    }

    /// <summary>
    /// process an item following us
    /// keep track of how many times
    /// we are followed by each item
    /// </summary>
    /// <param name="part">the item that follows this link</param>
    /// <returns>the (possibly newly created) link for the follower</returns>
    internal Link<T> Process(T part)
    {
        Link<T> link = Find(part);

        // not been followed by this
        // item before
        if (link == null)
        {
            link = new Link<T>(part);
            links.Add(part, link);
        }

        link.Seen();

        return link;
    }

    // record one more occurrence of this link
    private void Seen()
    {
        count++;
    }

    /// <summary>
    /// the value held by this link
    /// </summary>
    public T Data
    {
        get
        {
            return data;
        }
    }

    /// <summary>
    /// number of times this link was seen
    /// </summary>
    public int Occurances
    {
        get
        {
            return count;
        }
    }

    /// <summary>
    /// Total number of incidences after this link
    /// </summary>
    public int ChildOccurances
    {
        get
        {
            // sum all followers occurrences
            return links.Sum(link => link.Value.Occurances);
        }
    }

    public override string ToString()
    {
        return String.Format("{0} ({1})", data, count);
    }

    /// <summary>
    /// find a follower of this link
    /// </summary>
    /// <param name="follower">item to be found</param>
    /// <returns>the follower's link, or null if it never followed this link</returns>
    internal Link<T> Find(T follower)
    {
        // single lookup; TryGetValue leaves link null on a miss
        Link<T> link;
        links.TryGetValue(follower, out link);

        return link;
    }

    /// <summary>
    /// select a random follower weighted
    /// towards followers that followed us
    /// more often in the sample set
    /// </summary>
    /// <returns>the selected follower, or null if this link has no followers</returns>
    public Link<T> SelectRandomLink()
    {
        // nothing to choose from
        if (links.Count == 0)
            return null;

        int universe = this.ChildOccurances;

        // select a random probability in [1, universe]
        int rnd = rand.Next(1, universe + 1);

        // match the probability by treating
        // the followers as bands of probability
        int total = 0;
        foreach (Link<T> child in links.Values)
        {
            total += child.Occurances;

            if (total >= rnd)
                return child;
        }

        return null;
    }

    /// <summary>
    /// find a window of followers that
    /// are after this link; returns the
    /// last link if found, or null if
    /// this window never occurred after this link
    /// </summary>
    /// <param name="window">the sequence to look for</param>
    /// <returns>the link reached by the last window item, or null</returns>
    private Link<T> Find(Queue<T> window)
    {
        Link<T> link = this;

        foreach (T part in window)
        {
            link = link.Find(part);

            if (link == null)
                break;
        }

        return link;
    }

    /// <summary>
    /// a generated set of followers based
    /// on the likelihood of sequence steps
    /// seen in the sample data
    /// </summary>
    /// <param name="start">a seed value to start the sequence with</param>
    /// <param name="length">how big a window to use for sequence steps</param>
    /// <param name="max">maximum size of the set produced; the loop only stops on this limit when it reaches exactly 0</param>
    /// <returns>the links visited, beginning with the link for start; empty if start is not a follower of this link</returns>
    internal IEnumerable<Link<T>> Generate(T start, int length, int max)
    {
        var window = new Queue<T>(length);

        window.Enqueue(start);

        for (Link<T> link = Find(window); link != null && max != 0; link = Find(window), max--)
        {
            Link<T> next = link.SelectRandomLink();

            yield return link;

            // dead end: nothing was ever seen after this context.
            // Stop here; looking the unchanged (or shortened) window
            // up again would yield the same item over and over
            if (next == null)
                yield break;

            // keep at most length-1 items of context
            if (window.Count == length - 1)
                window.Dequeue();
            window.Enqueue(next.Data);
        }
    }
}


which can be called:

static void Main(string[] args)
{
    // the nursery rhyme used as the training sample
    string sample = Tidy(@"Twinkle, twinkle, little star,
How I wonder what you are!
Up above the world so high,
Like a diamond in the sky!

When the blazing sun is gone,
When he nothing shines upon,
Then you show your little light,
Twinkle, twinkle, all the night.

Then the traveller in the dark,
Thanks you for your tiny spark,
He could not see which way to go,
If you did not twinkle so.

In the dark blue sky you keep,
And often through my curtains peep,
For you never shut your eye,
Till the sun is in the sky.

As your bright and tiny spark,
Lights the traveller in the dark,—
Though I know not what you are,
Twinkle, twinkle, little star.
");

    // lower-case the sample and break it into tokens
    IEnumerable<string> tokens = Split(sample.ToLower());
    List<string> words = new List<string>(tokens);

    // build the chain using a window of 4 tokens
    Chain<string> markov = new Chain<string>(words, 4);

    // produce up to 2000 tokens, beginning with the given word
    List<string> output = new List<string>(markov.Generate("twinkle", 2000));

    // write every generated token to the console
    foreach (string token in output)
    {
        Console.Write("{0}", token);
    }
}

// Breaks a string into tokens at every run of non-word characters.
// The pattern wraps the separator in a capturing group, so the
// separators themselves are returned as tokens between the words.
private static IEnumerable<string> Split(string subject)
{
    string[] pieces = Regex.Split(subject, @"(\W+)");

    return new List<string>(pieces);
}


Giving output such as:
twinkle, little star,
how i know not twinkle so.

in the sky.

as your tiny spark,
lights the sky!

when the night.

then the traveller in the dark blue sky you never shut your eye,
till the sky.

as your eye,
till the sun is gone,
when he nothing shines upon,
then you are!
up above the dark,
thanks you did not what you are,
twinkle, twinkle, all the sky!


One of the interesting side effects of tokenizing with a simple regex is that, if the input stream is HTML, the Markov chain will treat the HTML markup as words as well. The generated text therefore not only looks like the input but is also formatted like the input. The same applies to punctuation.

6 comments:

Unknown said...

I was wondering where the "Sum" method came from in line 114 of the Link internal class. Thanks.

fe said...

@Dervin Sum is an extension method, see http://msdn.microsoft.com/en-us/library/system.linq.enumerable.sum.aspx

Colin Meier said...

Thanks dude! I've been noodling around with the idea for a Markov chain to use some art projects i'm doing. This is *SO* much simpler than my attempts.

You win +3 internetz :)

Anonymous said...

Where´s the Tidy method? I cant test it :(

fe said...

@Kalaziel:
// Normalises the whitespace of a string: every tab becomes a space,
// then runs of spaces are collapsed into a single space.
// Line breaks are left untouched.
private static string Tidy(string p)
{
    string result = p.Replace('\t', ' ');
    string compress = result;

    // each pass halves the runs of spaces; repeat until a pass
    // changes nothing, i.e. no two adjacent spaces remain
    do
    {
        result = compress;
        compress = result.Replace("  ", " ");
    }
    while (result != compress);

    return result;
}

Anonymous said...

Thank you very much, I'm gonna test it now :)