an nwise generator

2013.Apr.28

The following is a very useful iterator/generator helper. I’ve found it essential to solving a number of different problems.

I was once given the following interview question:

Given a sorted list of numbers, return the contiguous ranges. E.g., for [1,2,3,7,8,10,12], return ['1-3', '7-8', '10', '12'].

A simple way to solve this is to keep track of the previous number as we iterate, and to start a new range whenever the current number is more than one greater than it.

Here’s a naïve encoding of the above approach:

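def contiguous(xs):
	'given a sequence of numbers, xs, return a list of contiguous sub-sequences'
	xs = iter(xs)
	rv = []
	temp = ()
	prev = None
	for x in xs:
		# test `prev is not None` rather than `prev`, so that a 0 in the input
		#   isn't mistaken for "no previous value"
		if prev is not None and x > prev + 1:
			rv.append(temp) # gap found: close off the current run
			temp = ()
		temp += (x,)
		prev = x
	if temp:
		rv.append(temp) # don't forget the final run
	return rv

# test it out
assert contiguous([1,2,3,7,8,10,12]) == [(1,2,3), (7,8), (10,), (12,)]
assert contiguous([0,5]) == [(0,), (5,)]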

Much of the code above exists only to track the state of the iteration: the previous value, the run being built, and the leftover run at the end. That usually suggests we can refactor some of the code into a generator.

In general, code managing the state of an iteration can be cleanly refactored into a generator. This lets the code using the generator focus only on the meat of the problem it's trying to solve.

It should be clear that we want a generator that gives us a pairwise view of an iterable. Instead of looking at each value in isolation on each pass through the loop, we want to look at each value together with the value immediately preceding it.

Here is my standard, one-line formulation:

from itertools import tee, islice, izip
nwise = lambda xs,n=2: izip(*(islice(xs,idx,None) for idx,xs in enumerate(tee(xs,n))))

assert list(nwise([1,2,3,4,5,6])) == [(1,2), (2,3), (3,4), (4,5), (5,6)]
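The one-liner above is Python 2: izip was removed in Python 3, where the built-in zip is already lazy. A minimal Python 3 sketch of the same formulation, keeping the name nwise and assuming nothing beyond the standard library:

```python
from itertools import islice, tee


def nwise(iterable, n=2):
    """Yield overlapping n-tuples (sliding windows) from iterable, lazily."""
    # tee makes n independent iterators; the idx-th one is advanced idx places,
    # and zip walks them in lockstep, stopping when the shortest runs out
    return zip(*(islice(it, idx, None) for idx, it in enumerate(tee(iterable, n))))


print(list(nwise([1, 2, 3, 4, 5, 6])))  # [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6)]
print(list(nwise("abcde", 3)))          # [('a', 'b', 'c'), ('b', 'c', 'd'), ('c', 'd', 'e')]
```

For the common n=2 case, Python 3.10 and later also ship itertools.pairwise, which yields the same pairs.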

Written out longhand, it works like this:

from itertools import tee, islice, izip

def nwise(iterable, n=2):
	iterables = tee(iterable, n) # make n independent copies of the original iterable
	temp = []
	for idx, it in enumerate(iterables): # loop over each one
		it = islice(it, idx, None) # advance the iterable by idx places
		temp.append(it)
	# we now have n copies of the iterable, where each one skips the first 0..n-1 values
	# for itertools.count() (which starts at 0) and n=3, this will look like:
	#   [0,1,2,3,4, ...]
	#   [1,2,3,4,5, ...]
	#   [2,3,4,5,6, ...]
	# therefore, if we zip them together, we'll get an n-wise window of values
	#   -> [(0,1,2), (1,2,3), (2,3,4), ...]
	# since we're using izip here, the zipping occurs lazily
	#   which means this approach will work on infinite-length iterables
	# also, since izip advances the tee-d iterators in lockstep and none is
	#   ever more than n-1 values ahead of another, tee never buffers more
	#   than O(n) values
	return izip(*temp)
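To check the laziness claim in the comments, here is a small Python 3 port of the longhand version (zip standing in for izip) applied to the infinite itertools.count(); only the first few windows are ever computed, so the call returns immediately:

```python
from itertools import count, islice, tee


def nwise(iterable, n=2):
    # Python 3 port of the longhand version: zip is lazy, like izip in Python 2
    iterables = tee(iterable, n)  # n independent copies of the input
    shifted = [islice(it, idx, None) for idx, it in enumerate(iterables)]  # copy idx skips idx values
    return zip(*shifted)


# count() never ends, so this only terminates because every step is lazy;
# the outer islice stops after the first four windows
windows = list(islice(nwise(count(), 3), 4))
print(windows)  # [(0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 5)]
```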

Applying to our original problem:

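from itertools import tee, islice, izip
nwise = lambda xs,n=2: izip(*(islice(xs,idx,None) for idx,xs in enumerate(tee(xs,n))))

def contiguous(xs):
	# nwise yields no pairs for fewer than two values, so hold on to the
	#   input to guarantee we can still reach its last value
	xs = list(xs)
	if not xs:
		return
	buf = ()
	for prev, curr in nwise(xs):
		buf += (prev,)
		if curr > prev + 1: # if we have a jump, yield the buffer
			yield buf
			buf = ()
	yield buf + (xs[-1],) # yield whatever is left, plus the final value

assert list(contiguous([1,2,3,7,8,10,12])) == [(1,2,3), (7,8), (10,), (12,)]
assert list(contiguous([5])) == [(5,)]
assert list(contiguous([])) == []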

This should be a bit shorter and a bit easier to read.

Looking at values pairwise, we build up a subsequence for as long as the numbers are contiguous. If there is a jump, we yield the current subsequence and start building a new one.
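The version above has to deal with the final value separately, after the loop. One alternative, sketched here in Python 3 and not the original code, is to append a sentinel so that the last real value also shows up as prev. That keeps the whole thing lazy, so it works on an endless stream, and it copes with inputs of fewer than two values. The _END sentinel is a hypothetical name of my own:

```python
from itertools import chain, count, islice, tee


def nwise(iterable, n=2):
    # Python 3 port of the one-liner (zip is lazy in Python 3)
    return zip(*(islice(it, idx, None) for idx, it in enumerate(tee(iterable, n))))


_END = object()  # hypothetical sentinel marking the end of the input


def contiguous(xs):
    """Lazily yield runs of consecutive integers from a sorted iterable."""
    buf = ()
    # the sentinel makes the last real value appear as `prev` exactly once
    for prev, curr in nwise(chain(xs, [_END])):
        buf += (prev,)
        if curr is _END or curr > prev + 1:  # end of input or a jump closes the run
            yield buf
            buf = ()


print(list(contiguous([1, 2, 3, 7, 8, 10, 12])))  # [(1, 2, 3), (7, 8), (10,), (12,)]

# an endless stream 0, 1, 3, 4, 6, 7, ... (every number n with n % 3 != 2)
stream = (n for n in count() if n % 3 != 2)
print(list(islice(contiguous(stream), 3)))  # [(0, 1), (3, 4), (6, 7)]
```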

Let’s clean this up to answer our question more precisely:

from itertools import tee, islice, izip
nwise = lambda xs,n=2: izip(*(islice(xs,idx,None) for idx,xs in enumerate(tee(xs,n))))

def contiguous(xs):
	def subsets(xs):
		# nwise yields no pairs for fewer than two values, so hold on to
		#   the input to guarantee we can still reach its last value
		xs = list(xs)
		if not xs:
			return
		buf = ()
		for prev, curr in nwise(xs):
			buf += (prev,)
			if curr > prev + 1:
				yield buf
				buf = ()
		yield buf + (xs[-1],)
	for subset in subsets(xs):
		# for a run of length one, subset[0] == subset[-1], so a single
		#   expression covers both cases
		yield xrange(subset[0], subset[-1]+1)

# xrange objects don't compare equal to each other, so compare their contents
assert map(list, contiguous([1,2,3,7,8,10,12])) == [[1,2,3], [7,8], [10], [12]]
#   prints [xrange(1, 4), xrange(7, 9), xrange(10, 11), xrange(12, 13)]
print list(contiguous([1,2,3,7,8,10,12]))
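The original question asked for strings like '1-3' rather than ranges. A Python 3 sketch of that last step, reusing the pairwise idea; the helper name format_ranges is my own, and it assumes a sorted list of integers:

```python
from itertools import islice, tee


def nwise(iterable, n=2):
    # Python 3 port of the one-liner (zip is lazy in Python 3)
    return zip(*(islice(it, idx, None) for idx, it in enumerate(tee(iterable, n))))


def format_ranges(xs):
    """Return runs of consecutive integers as strings, e.g. '1-3' or '10'."""
    xs = list(xs)
    if not xs:
        return []
    runs, start = [], xs[0]
    for prev, curr in nwise(xs):
        if curr > prev + 1:           # a gap ends the run that began at `start`
            runs.append((start, prev))
            start = curr
    runs.append((start, xs[-1]))      # the final run always ends at the last value
    return [str(a) if a == b else f"{a}-{b}" for a, b in runs]


print(format_ranges([1, 2, 3, 7, 8, 10, 12]))  # ['1-3', '7-8', '10', '12']
```

Tracking only the start of each run, rather than buffering every value, keeps the memory use constant apart from the list(xs) copy.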

It turns out that this generator can help us model a number of other problems.

Another such problem is:

Find the first substring of the following string in which there are three lowercase letters, followed by an uppercase letter, followed by three more lowercase letters: 'jfjheNeKdlwoqjasJjasjDfk'

A solution (please don’t use this code):

from itertools import tee, islice, izip
nwise = lambda xs,n=2: izip(*(islice(xs,idx,None) for idx,xs in enumerate(tee(xs,n))))

from itertools import imap
find = lambda s: next((''.join([x,y,z])
                  for x,y,z in imap(lambda x: map(''.join,(x[:3],x[3:4],x[4:])), nwise(s,7))
                            if x.islower() and y.isupper() and z.islower()), None)

s = 'jfjheNeKdlwoqjasJjasjDfk'
#                  ---=---
assert find(s) == 'jasJjas'

s = 'jjjjjjjjjjjjjjjjjjjjjjjj'
assert find(s) is None
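For comparison, a more readable Python 3 sketch of the same search, still built on 7-wide windows from nwise. The name find mirrors the code above; splitting each window into 3 + 1 + 3 characters is the only assumption about the pattern:

```python
from itertools import islice, tee


def nwise(iterable, n=2):
    # Python 3 port of the one-liner (zip is lazy in Python 3)
    return zip(*(islice(it, idx, None) for idx, it in enumerate(tee(iterable, n))))


def find(s):
    """Return the first 7-character window of s shaped lower*3, upper, lower*3, or None."""
    for window in nwise(s, 7):
        chunk = ''.join(window)
        left, middle, right = chunk[:3], chunk[3], chunk[4:]
        if left.islower() and middle.isupper() and right.islower():
            return chunk
    return None  # no window matched (or s is shorter than 7 characters)


print(find('jfjheNeKdlwoqjasJjasjDfk'))  # jasJjas
print(find('j' * 24))                     # None
```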

Additionally, let’s say that we have a sequence of objects that represent historical states of some data.

If we want to build diffs between those states, it should be obvious that we want to compare consecutive pairs of snapshots.

# our standard formulation
from itertools import tee, islice, izip
nwise = lambda xs,n=2: izip(*(islice(xs,idx,None) for idx,xs in enumerate(tee(xs,n))))

# for this example, I'll use simple dictionaries

# let's say we have the history of market close prices used in
#   some financial application for equity markets
# ref: http://finance.yahoo.com/q/hp?s=AAPL
#      http://finance.yahoo.com/q/hp?s=MSFT
#      http://finance.yahoo.com/q/hp?s=FB
from datetime import date
history = {
    date(2013,4,26): { 'AAPL': 409.81, 'MSFT': 31.90, 'FB': 26.60 },
    date(2013,4,25): { 'AAPL': 411.23, 'MSFT': 31.71, 'FB': 26.07 },
    date(2013,4,24): { 'AAPL': 393.54, 'MSFT': 30.62, 'FB': 25.93 },
    date(2013,4,23): { 'AAPL': 403.99, 'MSFT': 30.70, 'FB': 26.22 },
    date(2013,4,22): { 'AAPL': 392.64, 'MSFT': 30.30, 'FB': 25.81 }, }

def moves(history):
	for prev, curr in nwise(history[k] for k in sorted(history)):
		# give the price moves for tickers in both (&) snapshots 
		yield {k:curr[k] - prev[k] for k in set(prev) & set(curr)}

class nearest(dict):
	def __missing__(self, key):
		return self[min(self, key=lambda k: abs(k - key))]

# colour the little arrows by direction of price movement
colourscheme = nearest({-0.1: '#FF0000', 0: '#000000', 0.1: '#00FF00'})

def arrows(historical_moves):
	for moves in historical_moves:
		yield {ticker: colourscheme[chg] for ticker,chg in moves.iteritems()}

assert list(arrows(moves(history))) == \
         [{'AAPL': '#00FF00', 'FB': '#00FF00', 'MSFT': '#00FF00'},
          {'AAPL': '#FF0000', 'FB': '#FF0000', 'MSFT': '#FF0000'},
          {'AAPL': '#00FF00', 'FB': '#00FF00', 'MSFT': '#00FF00'},
          {'AAPL': '#FF0000', 'FB': '#00FF00', 'MSFT': '#00FF00'}, ]
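The same pattern extends to fuller diffs than price moves. A Python 3 sketch, assuming each snapshot is a plain dict; the function name snapshot_diffs and the sample data are hypothetical and only for illustration:

```python
from itertools import islice, tee


def nwise(iterable, n=2):
    # Python 3 port of the one-liner (zip is lazy in Python 3)
    return zip(*(islice(it, idx, None) for idx, it in enumerate(tee(iterable, n))))


def snapshot_diffs(snapshots):
    """For each consecutive pair of dicts, report added, removed and changed keys."""
    for prev, curr in nwise(snapshots):
        yield {
            'added': {k: curr[k] for k in curr.keys() - prev.keys()},
            'removed': {k: prev[k] for k in prev.keys() - curr.keys()},
            'changed': {k: (prev[k], curr[k])
                        for k in prev.keys() & curr.keys() if prev[k] != curr[k]},
        }


# hypothetical sample data: three successive versions of a config dict
versions = [
    {'host': 'a', 'port': 80},
    {'host': 'a', 'port': 8080, 'debug': True},
    {'host': 'b', 'debug': True},
]
for diff in snapshot_diffs(versions):
    print(diff)
```

As with moves above, n snapshots produce n-1 diffs, and because nwise is lazy the snapshots could just as well come from a generator reading them one at a time.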