Problem 14 Part 1

2023/03/19

I recently saw a post on Gwern’s blog on an interesting problem in probability and expected value:

A shuffled deck of cards has an equal number of red and black cards. When you draw a red card from the deck, you get \$1, and when you draw a black card you lose \$1. You can stop whenever you want. If the deck has $N$ cards, what is the expected value of playing this game? When should you stop?

The first part of this post has been written without reading past the question, though I did slip and see that Gwern used dynamic programming (DP) to solve this. Fortunately, it’s pretty obvious that DP is a good option for this (the $N=2$ case is a trivially easy starting point). I’ll start by giving a few thoughts on the problem, maybe try to do some probability math to solve it outright, and then give whatever solutions I think of. After that, I’ll analyze Gwern’s ideas and see what I think!

Part 1: Thoughts

def n_pos_hlpr(r, b, value):
	# if we've already added too many black cards, return
	if value <= -r:
		return 0
	# base cases: if we've run out of one type, all we can do
	# is add all cards of the remaining type
	if r == 0:
		# can't necessarily add all black cards
		return min(value-1, b)
	if b == 0:
		# can always add more red cards
		return r
	# recursive case: we can either add a red or black card to the prefix
	n = 1 + n_pos_hlpr(r-1, b, value+1) + n_pos_hlpr(r, b-1, value-1)
	return n

def n_pos(r, b):
	return n_pos_hlpr(r, b, 0) - 1

note: we subtract 1 in n_pos because n_pos_hlpr counts the initial/zero-length prefix

Part 2: Debugging, Math

Woah, I didn’t mean to give a solution in this part! It sort of fell out of thinking about the combinatorics. Oh well. Let’s run this naive/non-memoized version of the code and see what happens:

>>> n_pos(2,2)
5

Time to make sure that’s right. The valid prefixes I can think of are: {R, RR, BRR, RBR, RRB}. Seems legit! How high can we go with this?

Not very far! I tried to do n_pos(100, 100) and it got stuck for more than a minute, so I killed it. Let’s do some memoization with a little hash table (yes, fine, this is cursed, but it avoids the problem of having a sparse matrix in memory; we’ll fix it later):

def n_pos_hlpr(r, b, value, tbl):
	# if we've already added too many black cards, return
	if value <= -r:
		return 0
	# base cases: if we've run out of one type, all we can do
	# is add all cards of the remaining type
	if r == 0:
		# can't necessarily add all black cards
		return min(value-1, b)
	if b == 0:
		# can always add more red cards
		return r
	if (r, b) in tbl:
		return tbl[(r, b)]
	# recursive case: we can either add a red or black card to the prefix
	n = 1 + n_pos_hlpr(r-1, b, value+1, tbl) + n_pos_hlpr(r, b-1, value-1, tbl)
	tbl[(r, b)] = n
	return n

def n_pos(r, b):
	tbl = {}
	result = n_pos_hlpr(r, b, 0, tbl)
	return result

Woah, n_pos(100, 100) now returns almost instantly. But, um, yikes:

>>> n_pos(100, 100)
134926252037064790251419095546152145179336046991581259352659

I didn’t expect to be overflowing 64-bit ints here, but this is like a 200-bit unsigned integer or something (cursed tip for estimating the number of bits in an integer: take the number of digits and multiply by 3.3). In fact, hang on:

>>> import math
>>> math.comb(200, 100)
90548514656103281165404177077484163874504589675413336841320

Apparently, we have more prefixes than possible arrangements of cards.
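Python will also report an integer’s size in bits directly through int.bit_length(), which makes the digits-times-3.3 rule of thumb easy to check (the 3.3 is roughly $\log_2 10 \approx 3.32$). A quick sketch using the math.comb(200, 100) value from above:

```python
import math

n = math.comb(200, 100)
digits = len(str(n))                # number of decimal digits
estimate = digits * math.log2(10)   # the "multiply by 3.3" tip
print(digits, round(estimate), n.bit_length())  # 59 196 196
```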

At this point, I wanted to find a manageable number where my method failed, so I went to $r=3, b=3$:

3 red, 3 black:

- 1 red: 1 (0 black)
- 2 red: 1 (0 black) + 3 (1 black) = 4 (RR, RRB, RBR, BRR)
- 3 red: 1 (0 black) + 4 (1 black) + 10 (2 black) = 15
- Total = 20

OK, most of these numbers are hopefully intuitive, but how did I get 10 as the number of arrangements of 3 red and 2 black cards? Well, you can imagine that there are 4 spaces around the 3 red cards:

s R s R s R s

We want to draw from these spaces in an unordered manner with replacement (both black cards can land in the same space), which we can do with this function, drawing $k$ times from $n$ options:

$$uwr(n, k) = {n + k - 1 \choose k}$$
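Here is a small sketch (my addition, not code from the original post) that checks the "spaces" picture for the 3 red, 2 black case. It drops 2 black cards into the 4 gaps using itertools.combinations_with_replacement, rebuilds each string, and compares the result against both the uwr formula and a direct enumeration of orderings. place_blacks is a hypothetical helper name:

```python
import math
from itertools import combinations_with_replacement, permutations

def uwr(n, k):
    # unordered draws with replacement: k picks from n options, repeats allowed
    return math.comb(n + k - 1, k)

def place_blacks(n_red, gap_choice):
    # gap_choice says which of the n_red + 1 gaps ("s R s R s R s") each black card goes in
    out = ""
    for gap in range(n_red + 1):
        out += "B" * gap_choice.count(gap)
        if gap < n_red:
            out += "R"
    return out

placements = list(combinations_with_replacement(range(4), 2))  # 2 blacks into 4 gaps
arrangements = {place_blacks(3, p) for p in placements}
direct = {"".join(p) for p in permutations("RRRBB")}  # every ordering of RRRBB
print(len(placements), uwr(4, 2), len(arrangements), arrangements == direct)  # 10 10 10 True
```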

So I used $uwr(4, 2) = {5 \choose 2} = 10$. In general, we can write a function for the number of prefixes with up to $r$ red and $b$ black cards:

import math
def n_pos(r, b):
	total = 0
	for r_p in range(1, r+1):
		total += 1 # zero black
		for b_p in range(1, min(r_p, b+1)):
			total += math.comb(r_p + b_p, b_p)
	return total

note: since there are always n+1 spaces around n objects, the +1 and -1 cancel and we get r_p + b_p
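For an independent check of the hand counts (the five prefixes listed earlier for $r = b = 2$, and the 20 above for $r = b = 3$), here is a brute-force sketch. It is my addition, not code from the post: positive_prefixes is a hypothetical helper, and it reads "valid prefix" as "a prefix whose running total (reds minus blacks) is positive", which is what those five strings have in common. It enumerates every ordering of the deck, so it’s only usable for tiny decks:

```python
from itertools import permutations

def positive_prefixes(r, b):
    """All distinct non-empty prefixes of an r-red, b-black deck whose
    running total (#R - #B) is positive. Brute force: tiny decks only."""
    deck = "R" * r + "B" * b
    found = set()
    for arrangement in set(permutations(deck)):  # distinct orderings of the deck
        s = "".join(arrangement)
        for i in range(1, len(s) + 1):
            prefix = s[:i]
            if prefix.count("R") > prefix.count("B"):
                found.add(prefix)
    return found

print(sorted(positive_prefixes(2, 2)))  # ['BRR', 'R', 'RBR', 'RR', 'RRB']
print(len(positive_prefixes(3, 3)))     # 20
```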

We can confirm that n_pos(3, 3) = 20. Let’s see n_pos(100, 100):

>>> n_pos(100, 100)
119733200176664861055409214554645754349825484924704582796270

This is still about 30% higher than math.comb(200, 100) (roughly 1.2e59 vs 9.1e58)! Clearly, something is still wrong.

I think the problem is that we’re double-counting items that are prefixes of each other. For instance, if we call n_pos(2,1) it returns 5, correctly counting prefixes {‘R’, ‘RR’, ‘RRB’, ‘RBR’, ‘BRR’}. But there are only 3 ways 2 red cards and 1 black card can be arranged! If we instead removed items which were prefixes of other items, we’d find that only the longest strings remained: {‘RRB’, ‘RBR’, ‘BRR’}. This lets us remove an outer for loop!
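The pruning step (throw away any string that is a prefix of another string in the set) can be spelled out directly. This is a minimal sketch: drop_prefixes is a hypothetical name, and the input is the set of five prefixes that n_pos(2,1) counts:

```python
def drop_prefixes(strings):
    # keep only the strings that are not a proper prefix of another string in the set
    return {s for s in strings
            if not any(t != s and t.startswith(s) for t in strings)}

prefixes = {"R", "RR", "RRB", "RBR", "BRR"}  # the five prefixes n_pos(2, 1) counts
print(sorted(drop_prefixes(prefixes)))  # ['BRR', 'RBR', 'RRB']
```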

import math
def n_pos(r, b):
	if r == 0:
		return 0
	total = 1 # bp = 0 case
	for b_p in range(1, min(r, b+1)):
		total += math.comb(r + b_p, b_p)
	return total

Now we have n_pos(3,3) = 15. Since we can shuffle 3 red and 3 black cards ${6 \choose 3} = 20$ ways, we expect the running total to go positive at some point in 15/20 = 75% of shuffles! Let’s check this against a brute-force count:

def get_shuffles_hlpr(accum, curr, r, b):
	# append every ordering of r red and b black cards to accum, after curr
	if r == 0:
		accum.append(curr + 'B'*b)
		return
	if b == 0:
		accum.append(curr + 'R'*r)
		return
	get_shuffles_hlpr(accum, curr+'R', r-1, b)
	get_shuffles_hlpr(accum, curr+'B', r, b-1)

def get_shuffles(r, b):
	# here, r and b are the number of red and
	# black cards in the deck
	accum = []
	get_shuffles_hlpr(accum, '', r, b)
	return accum

def p_win(r, b):
	# fraction of shuffles with at least one prefix that has more red than black
	total_pos = 0
	shuffles = get_shuffles(r, b)
	for s in shuffles:
		# check every prefix, including the full deck (only matters when r > b)
		for i in range(1, len(s) + 1):
			n_red = s[:i].count('R')
			if n_red > i - n_red:
				total_pos += 1
				# string has a pos prefix, break
				break
	return total_pos/len(shuffles)

Now we can do:

>>> p_win(3,3)
0.75

Woo! I also verified this on a few more toy examples, e.g.:

>>> p_win(10, 10)
0.9090909090909091
>>> n_pos(10,10)/math.comb(20, 10)
0.9090909090909091
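Not part of the original approach, but a useful cross-check: when $r = b = n$, the sum inside n_pos is a hockey-stick sum, $\sum_{k=0}^{n-1} {n+k \choose k} = {2n \choose n-1}$, so the ratio is exactly ${2n \choose n-1} / {2n \choose n} = n/(n+1)$. That’s 3/4 for $n = 3$ and 10/11 ≈ 0.909 for $n = 10$, matching the outputs above. (Equivalently, the shuffles whose running total never goes positive are counted by the Catalan number ${2n \choose n}/(n+1)$.) The sketch below copies the post’s final n_pos and checks the identity with exact integers:

```python
import math

def n_pos(r, b):
    # the post's final counting function, copied unchanged
    if r == 0:
        return 0
    total = 1  # b_p = 0 case
    for b_p in range(1, min(r, b + 1)):
        total += math.comb(r + b_p, b_p)
    return total

# Hockey-stick identity: sum_{k=0}^{n-1} C(n + k, k) == C(2n, n - 1)
for n in range(1, 60):
    assert n_pos(n, n) == math.comb(2 * n, n - 1)

# So n_pos(n, n) / C(2n, n) == n / (n + 1); compare by cross-multiplying
print(n_pos(10, 10) * 11 == math.comb(20, 10) * 10)  # True
```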

Part 3: Optimization

Awesome! But do we beat Gwern? If we assume that math.comb has a constant-time implementation, yes, in a sense: we would have an $O(n)$ solution, while his solution is $O(n^2)$. But math.comb probably isn’t constant-time, since the integers it returns here run to hundreds of digits.

Gwern says he got an answer for 200,000 cards in like 5 seconds. Whether we can beat him really depends on how efficient math.comb is and how efficient computations on large numbers are in Python generally. It takes about 2 seconds to compute n_pos(1000,1000) on my old laptop, so I’m about 2 orders of magnitude too slow with my mathy solution. Let’s see if I can do better!

n_pos(1000, 1000) is:

2046105521468021692642519982997827217179245642339057975844538099572176010191891863964968026156453752449015750569428595097318163634370154637380666882886375203359653243390929717431080443509007504772912973142253209352126946839844796747697638537600100637918819326569730982083021538057087711176285777909275869648636874856805956580057673173655666887003493944650164153396910927037406301799052584663611016897272893305532116292143271037140718751625839812072682464343153792956281748582435751481498598087586998603921577523657477775758899987954012641033870640665444651660246024318184109046864244732001962029120000

This is not an efficient number to work with, and we only care about the first significant figure to decide when to stop playing (since we stop playing when $p_{win} < 0.5$). One approach is to try to find an approximation of the ${n \choose k}$ function in a faster language that returns a 64-bit float rather than an arbitrarily large (but exact) integer. I expect that this will do a lot for my performance - 100x probably isn’t an unreasonable ask, and that’s before I try any fancy multiprocessing (very easy to do for this method, since there’s no DP or shared memory). However, I’m going to explore another idea. As a reminder, the unoptimized code I’m working with is:

def n_pos(r, b):
	if r == 0:
		return 0
	total = 1 # bp = 0 case
	for b_p in range(1, min(r, b+1)):
		total += math.comb(r + b_p, b_p)
	return total

Which essentially evaluates:

$$1 + \sum\limits_{b_{p}= 1}^{\min(r-1, b)} {r + b_{p} \choose b_{p}}$$.

This looks like:

$$1 + {r + 1 \choose 1} + {r + 2 \choose 2} + \dots$$

We can make the observation that:

$${n + 1 \choose k + 1} = \dfrac{(n+1)!}{(k+1)!\ (n-k)!} = \dfrac{n + 1}{k + 1} \dfrac{n!}{k! (n-k)!} = \dfrac{n + 1}{k + 1} {n \choose k}$$.
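A quick exact-integer check of this identity (a sketch; the loop ranges are arbitrary). It is cross-multiplied so no division is involved, and it also checks the first step the identity gets used for below, with $r = 3$:

```python
import math

# C(n + 1, k + 1) == (n + 1) / (k + 1) * C(n, k), cross-multiplied so the
# check uses exact integers only
for n in range(60):
    for k in range(n + 1):
        assert math.comb(n + 1, k + 1) * (k + 1) == (n + 1) * math.comb(n, k)

# the first step used with r = 3: C(5, 2) == (5 / 2) * C(4, 1)
print(math.comb(5, 2), 5 * math.comb(4, 1) // 2)  # 10 10
```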

This means that we can find a much more efficient way to evaluate

$$\sum\limits_{b_{p}= 1}^{\min(r-1, b)} {r + b_{p} \choose b_{p}}$$

by simply finding ${r + 1 \choose 1}$, and then noting that ${r + 2 \choose 2} = \frac{r+2}{2} {r + 1 \choose 1}$. Also, ${r + 3 \choose 3} = \frac{r+3}{3} {r + 2 \choose 2} = \frac{r+3}{3} \frac{r+2}{2} {r + 1 \choose 1}$. So we end up with:

$${r + 1 \choose 1}\left(1 + \frac{r+2}{2}\left(1 + \frac{r+3}{3}\left(1 + \dots \right)\right)\right)$$
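As a small worked check of the nested form (my own arithmetic), take $r = b = 3$. The sum then only runs over $b_p = 1, 2$, and the total comes out to the familiar 15:

$$1 + {4 \choose 1} + {5 \choose 2} = 1 + {4 \choose 1}\left(1 + \frac{5}{2}\right) = 1 + 4 \cdot \frac{7}{2} = 15$$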

This is super cool. Now instead of evaluating math.comb, which seems to be pretty slow, a ton of times, we evaluate it once (and ${r + 1 \choose 1}$ is just $r + 1$ anyway) and leave the rest as (much faster) floating-point multiplication, division, and addition!!

import math

def n_pos_opt(r, b):
	if r == 0:
		return 0
	total = 1 # b_p = 0 case
	top = min(r - 1, b) # largest b_p in the sum
	if top < 1:
		return total # no b_p >= 1 terms, so no C(r+1, 1) factor
	nested_frac = 1
	# Horner-style: b_p runs from top down to 2 (range excludes its stop value)
	for b_p in range(top, 1, -1):
		nested_frac *= (r + b_p)/b_p
		nested_frac += 1
	b_1 = math.comb(r + 1, 1)
	total += b_1 * nested_frac
	return total

When we do n_pos_opt(3,3), we get our familiar 15, but when we do:

>>> n_pos_opt(5,5)
209.99999999999997
>>> n_pos(5,5)
210

This is to be expected! We’re doing a little bit of approximation here, but it’s much better than what I was expecting - the error is just floating-point inaccuracy. But when we do:

>>> n_pos_opt(1000,1000)
inf

We overflowed Python floats: the largest double is about 1.8e308 (sys.float_info.max), while n_pos(1000, 1000) is around 2e600. OK.
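Since all we really need is the ratio n_pos / ${r + b \choose r}$, which is at most 1, the earlier idea of settling for 64-bit floats can be tried without leaving Python by working in log space with math.lgamma. This is a sketch of that swap, not the route this post takes next; log_comb and p_win_float are hypothetical names, and the answers are float approximations:

```python
import math

def log_comb(n, k):
    # natural log of C(n, k) as a float; lgamma(m + 1) == log(m!)
    return math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)

def p_win_float(r, b):
    # Approximates n_pos(r, b) / C(r + b, r) with 64-bit floats only. Each term
    # C(r + b_p, b_p) is divided by C(r + b, r) in log space before exp(), so
    # no intermediate value overflows.
    if r == 0:
        return 0.0
    log_shuffles = log_comb(r + b, r)
    p = 0.0
    for b_p in range(0, min(r - 1, b) + 1):  # b_p = 0 is the lone "1" term
        p += math.exp(log_comb(r + b_p, b_p) - log_shuffles)
    return p

print(p_win_float(3, 3))  # about 0.75
print(p_win_float(100_000, 100_000))
```

Dividing before exponentiating is the whole trick: the individual binomials would overflow a double, but their ratios to ${r + b \choose r}$ never exceed 1.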

Our problem is now that we have a mathematically correct approach that Python doesn’t like. Still fixable, with Python’s fractions library! That should store the numerator and denominator of our nested fraction as integers, which we’ve seen Python will let us make arbitrarily large.

import math
from fractions import Fraction

def n_pos_opt(r, b):
	if r == 0:
		return 0
	total = 1 # b_p = 0 case
	top = min(r - 1, b) # largest b_p in the sum
	if top < 1:
		return total # no b_p >= 1 terms
	nested_frac = Fraction(1,1)
	for b_p in range(top, 1, -1):
		nested_frac *= Fraction(r + b_p, b_p)
		nested_frac += 1
	b_1 = math.comb(r + 1, 1)
	total += b_1 * nested_frac
	return total

Excellent, this runs even faster than the float version and doesn’t overflow! But it’s still too slow. We can do better by leveraging Fraction’s limit_denominator() function, which generates approximations of fractions.

def n_pos_opt(r, b):
	if r == 0:
		return 0
	total = 1 # b_p = 0 case
	top = min(r - 1, b) # largest b_p in the sum
	if top < 1:
		return total # no b_p >= 1 terms
	nested_frac = Fraction(1,1)
	for b_p in range(top, 1, -1):
		nested_frac *= Fraction(r + b_p, b_p).limit_denominator(10_000_000)
		nested_frac += 1
	b_1 = math.comb(r + 1, 1)
	total += b_1 * nested_frac
	return total

After several seconds, I get a very interesting error:

>>> n_pos_opt(100000,100000)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/lib/python3.10/fractions.py", line 266, in __repr__
    return '%s(%s, %s)' % (self.__class__.__name__,
ValueError: Exceeds the limit (4300) for integer string conversion; use sys.set_int_max_str_digits() to increase the limit

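The traceback points at __repr__, so the sum itself actually finished: it’s printing the result that fails, because its numerator and denominator run past Python’s default 4300-digit limit on int-to-string conversion. (limit_denominator(10_000_000) can’t have trimmed much, either: each factor (r + b_p)/b_p already has a denominator of at most 100,000, so it comes back unchanged and the digits pile up in nested_frac.) Let’s try one other method with mpmath before we give up on computing with high precision: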

import mpmath
mpmath.mp.dps = 15
mpmath.mp.pretty = True

def n_pos_opt(r, b):
	if r == 0:
		return 0
	total = mpmath.mpf(1) # b_p = 0 case
	top = min(r - 1, b) # largest b_p in the sum
	if top < 1:
		return total # no b_p >= 1 terms
	nested_frac = mpmath.mpf(1)
	for b_p in range(top, 1, -1):
		nested_frac *= mpmath.fdiv(r + b_p, b_p)
		nested_frac += mpmath.mpf(1)
	b_1 = mpmath.binomial(r + 1, 1)
	total += b_1 * nested_frac
	return total

It works! With mp.dps = 15, mpmath carries about 15 significant decimal digits, and we find:

>>> n_pos_opt(1000, 1000)
2.04610552146803e+600
>>> n_pos_opt(100000, 100000)
1.78054508182311e+60203
>>> n_pos_opt(100000, 100000)/mpmath.binomial(200000, 100000)
0.999990000100024

It takes about 3 seconds to run n_pos_opt for (100000, 100000), meaning that we’ve (barely) beaten Gwern’s time, if the problem is to decide whether or not to quit at a certain point. I actually still haven’t read his post, except for searching in it for “seconds”, pulling out the quickest time I saw, and reading enough of the surrounding text to make sure his (100000, 100000) was the same as mine. In other words, I didn’t see what problem he was actually solving.

Adjusting the mpmath precision doesn’t seem to improve the time, which suggests to me that nearly all the time is spent in the overhead of calling into mpmath. So the best way to improve the speed at this point would probably be to use something that isn’t Python, with an mpmath-esque arbitrary-precision float library. I could also possibly get a more or less linear speedup by parallelizing this. The for loop looks very hard to parallelize, but you can actually do it without any shared memory, at the expense of extra calls to mpmath.binomial, by giving each worker its own starting point instead of just ${r+1 \choose 1}$. This is, obviously, complicated by the GIL, but I’m not sure how. However, I think I’ve pretty much exhausted my interest in this problem for now. Stay tuned!
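The chunking idea in the last paragraph can be sketched with the standard library. This swaps mpmath for exact integers from math.comb and uses concurrent.futures.ProcessPoolExecutor (separate processes, so the GIL doesn’t serialize the work). chunk_sum, n_pos_chunked and p_win_parallel are hypothetical names, and this is a sketch rather than a tuned implementation:

```python
import math
from concurrent.futures import ProcessPoolExecutor

def chunk_sum(task):
    # Exact sum of C(r + k, k) for k in [lo, hi). The chunk pays for one
    # math.comb call as its own starting point, then walks the recurrence
    # C(r + k, k) = C(r + k - 1, k - 1) * (r + k) / k, which divides exactly.
    r, lo, hi = task
    term = math.comb(r + lo, lo)
    total = term
    for k in range(lo + 1, hi):
        term = term * (r + k) // k
        total += term
    return total

def n_pos_chunked(r, b, n_chunks=4, mapper=map):
    # Same count as n_pos(r, b): sum of C(r + b_p, b_p) for b_p = 0 .. min(r - 1, b),
    # split into independent chunks. mapper can be map or an executor's map.
    if r == 0:
        return 0
    stop = min(r - 1, b) + 1
    bounds = [stop * i // n_chunks for i in range(n_chunks + 1)]
    tasks = [(r, lo, hi) for lo, hi in zip(bounds, bounds[1:]) if lo < hi]
    return sum(mapper(chunk_sum, tasks))

def p_win_parallel(r, b, n_workers=4):
    # Worker processes rather than threads, so the GIL doesn't serialize the loops.
    with ProcessPoolExecutor(max_workers=n_workers) as pool:
        total = n_pos_chunked(r, b, n_chunks=n_workers, mapper=pool.map)
    # int / int gives a correctly rounded float even for very large integers
    return total / math.comb(r + b, r)

print(n_pos_chunked(3, 3))  # 15
```

Each chunk pays for one extra math.comb call to get its starting point, which is the trade-off described above. On platforms that start worker processes with spawn (Windows, and macOS by default), call p_win_parallel from under an if __name__ == "__main__": guard.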