I recently saw a post on Gwern’s blog on an interesting problem in probability and expected value:
A shuffled deck of cards has an equal number of red and black cards. When you draw a red card from the deck, you get \$1, and when you draw a black card you lose \$1. You can stop whenever you want. If the deck has $N$ cards, what is the expected value of playing this game? When should you stop?
The first part of this post has been written without reading past the question, though I did slip and see that Gwern used DP to solve this. Fortunately, it’s pretty obvious that DP is a good option for this (the $N=2$ case is a trivially easy starting point). I’ll start by giving a few thoughts on the problem, maybe try to do some probability math to solve it outright, and then give whatever solutions I think of. After that, I’ll analyze Gwern’s ideas and see what I think!
Part 1: Thoughts
- First, this game has positive expected value. You’ll never stop drawing cards if you’re net negative, so the worst you can do is draw every card and break even. Since you can and probably should stop making draws at some point in most shuffles, expected value is positive.
- OK, but when should we stop? Let’s suppose we’re already ahead in this game. How do we determine when to cash out? Let’s model the deck as a random sequence $c_{1}, c_{2}, c_{3}, \dots$. We want to keep playing if we think there’s a high enough probability that some prefix of this sequence $c_{1..n}$ has more red cards than black.
- One obvious point is that if the deck as a whole has more red cards than black, we can skip any complicated calculation and just note that the expected value for $n=1$ is positive.
- Because we can see every card we draw and know how many cards of each type started in the deck, we know at every point in time how many cards of each type remain in the deck. Let’s denote those numbers $r$ and $b$. Trivially, no prefix of length $n \geq 2r$ can have positive value, so let’s focus on the case where $n < 2r$.
- The number of ways to arrange $r$ red cards and $b$ black cards is ${r+b \choose r}$
- The number of ways to have positive value in a prefix of length $n < 2r$ can be calculated naively by: $$\texttt{NPosFixed}(n, r) = \sum\limits_{r_{p} = \lceil{n/2}\rceil}^{\min(n, r)} {n \choose r_p}$$
- The number of ways to have positive value for any prefix length, given $r$ red and $b$ black cards, is: $$\texttt{NPos} (r, b) = \dfrac{\sum\limits_{n = 1}^{\min(2r, r+b)} \texttt{NPosFixed}(n, r)}{r + b \choose r}$$
- That was more math than we probably need to do! This nested sum could get computationally sort of nasty, too.
- NPos can be calculated constructively by noting that all prefixes of length $n+1$ are simply a prefix of length $n$ followed by either a red card or a black card:
def n_pos_hlpr(r, b, value):
    # if we've already added too many black cards, return
    if value <= -r:
        return 0
    # base cases: if we've run out of one type, all we can do
    # is add all cards of the remaining type
    if r == 0:
        # can't necessarily add all black cards
        return min(value-1, b)
    if b == 0:
        # can always add more red cards
        return r
    # recursive case: we can either add a red or black card to the prefix
    n = 1 + n_pos_hlpr(r-1, b, value+1) + n_pos_hlpr(r, b-1, value-1)
    return n

def n_pos(r, b):
    return n_pos_hlpr(r, b, 0) - 1
note: we subtract 1 in n_pos because n_pos_hlpr counts the initial/zero-length prefix
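The claim at the top of Part 1, that the game has positive expected value, can be checked directly for small decks. The following is a minimal sketch of a DP over the remaining card counts; it is not the approach the rest of this post takes, and the name `game_value` is made up for illustration. It assumes the player either stops (worth 0 more) or draws one card and continues optimally.

```python
from fractions import Fraction
from functools import lru_cache

@lru_cache(maxsize=None)
def game_value(r, b):
    """Expected further winnings under optimal stopping with r red, b black cards left."""
    if r == 0:
        return Fraction(0)   # only losing cards remain, so stop now
    if b == 0:
        return Fraction(r)   # only winning cards remain, so draw them all
    total = r + b
    keep_drawing = (Fraction(r, total) * (1 + game_value(r - 1, b))
                    + Fraction(b, total) * (-1 + game_value(r, b - 1)))
    # stopping is always available and is worth 0 further winnings
    return max(Fraction(0), keep_drawing)

print(game_value(1, 1), game_value(2, 2))
```

Using `Fraction` keeps the small cases exact, at the cost of speed for large decks.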
Part 2: Debugging, Math
Woah, I didn’t mean to give a solution in this part! It sort of fell out of thinking about the combinatorics. Oh well. Let’s run this naive/non-memoized version of the code and see what happens:
>>> n_pos(2,2)
5
Time to make sure that’s right. The valid prefixes I can think of are: {R, RR, BRR, RBR, RRB}. Seems legit! How high can we go with this?
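Rather than trusting a list written from memory, the set of positive-value prefixes can be enumerated by brute force for small decks. This is a sketch with a hypothetical helper name (`positive_prefixes`); it walks every arrangement of the deck and records each distinct prefix that has more red than black cards.

```python
from itertools import combinations

def positive_prefixes(r, b):
    """Set of distinct prefixes with more 'R' than 'B', over all arrangements."""
    n = r + b
    found = set()
    for red_positions in combinations(range(n), r):
        reds = set(red_positions)
        deck = ''.join('R' if i in reds else 'B' for i in range(n))
        value = 0
        for i, card in enumerate(deck, start=1):
            value += 1 if card == 'R' else -1
            if value > 0:
                found.add(deck[:i])
    return found

print(sorted(positive_prefixes(2, 2)))
```

This is exponential in the deck size, so it is only useful as a check on toy inputs.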
Not very far! I tried to do `n_pos(100, 100)` and it got stuck for more than a minute, so I killed it. Let’s do some memoization with a little hash table (yes, fine, this is cursed, but it avoids the problem of having a sparse matrix in memory, we’ll fix it later):
def n_pos_hlpr(r, b, value, tbl):
    # if we've already added too many black cards, return
    if value <= -r:
        return 0
    # base cases: if we've run out of one type, all we can do
    # is add all cards of the remaining type
    if r == 0:
        # can't necessarily add all black cards
        return min(value-1, b)
    if b == 0:
        # can always add more red cards
        return r
    if (r, b) in tbl:
        return tbl[(r, b)]
    # recursive case: we can either add a red or black card to the prefix
    n = 1 + n_pos_hlpr(r-1, b, value+1, tbl) + n_pos_hlpr(r, b-1, value-1, tbl)
    tbl[(r, b)] = n
    return n

def n_pos(r, b):
    tbl = {}
    result = n_pos_hlpr(r, b, 0, tbl)
    return result
Woah, `n_pos(100, 100)` now returns almost instantly. But, um, yikes:
>>> n_pos(100, 100)
134926252037064790251419095546152145179336046991581259352659
I didn’t expect to be overflowing 64-bit ints here, but this is like a 200-bit unsigned integer or something (cursed tip for estimating the number of bits in an integer: take the number of digits and multiply by 3.3). In fact, hang on:
>>> import math
>>> math.comb(200, 100)
90548514656103281165404177077484163874504589675413336841320
Apparently, we have more prefixes than possible arrangements of cards.
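The digits-times-3.3 rule of thumb mentioned above can be compared against Python's own `int.bit_length()`. This is just a quick sketch; `bits_estimate` is a name invented here, and it uses $\log_2 10 \approx 3.32$ rather than the rounded 3.3.

```python
import math

def bits_estimate(n):
    """Rule-of-thumb bit count: decimal digits times log2(10)."""
    return len(str(n)) * math.log2(10)

n = 134926252037064790251419095546152145179336046991581259352659
# exact bit length next to the estimate
print(n.bit_length(), round(bits_estimate(n)))
```

The estimate can overshoot by a few bits, since a number with $d$ digits may be as small as $10^{d-1}$.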
At this point, I wanted to find a manageable number where my method failed, so I went to $r=3, b=3$:
3 red, 3 black:
- 1 red: 1 (0 black)
- 2 red: 1 (0 black) + 3 (1 black) = 4 (RR, RRB, RBR, BRR)
- 3 red: 1 (0 black) + 4 (1 black) + 10 (2 black) = 15
- Total = 20

OK, most of these numbers are hopefully intuitive, but how did I get to 10 as the number of arrangements of 3 red and 2 black cards? Well, you can imagine that there are 4 spaces around the 3 cards:
s R s R s R s
We want to draw from these spaces in an unordered manner with replacement, which we can do with this function (drawing $k$ times from $n$ options)
$$uwr(n, k) = {n + k - 1 \choose k}$$
So I used $uwr(4, 2) = {5 \choose 2} = 10$. In general, we can write a function for number of prefixes with up to $r$ red and $b$ black cards:
import math

def n_pos(r, b):
    total = 0
    for r_p in range(1, r+1):
        total += 1 # zero black
        for b_p in range(1, min(r_p, b+1)):
            total += math.comb(r_p + b_p, b_p)
    return total
note: since there are always n+1 spaces around n objects, the +1 and -1 cancel and we get r_p + b_p
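The $uwr$ formula used above (unordered draws with replacement) can be sanity-checked against direct enumeration. A small sketch follows; `uwr_brute` is a hypothetical helper that simply counts what `itertools.combinations_with_replacement` produces.

```python
import math
from itertools import combinations_with_replacement

def uwr(n, k):
    """Number of unordered draws of k items from n options, with replacement."""
    return math.comb(n + k - 1, k)

def uwr_brute(n, k):
    """Same count, obtained by enumerating every multiset of size k."""
    return sum(1 for _ in combinations_with_replacement(range(n), k))

print(uwr(4, 2), uwr_brute(4, 2))
```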
We can confirm that `n_pos(3, 3) = 20`. Let’s see `n_pos(100, 100)`:
>>> n_pos(100, 100)
119733200176664861055409214554645754349825484924704582796270
This is still about 1.3x higher than `math.comb(200, 100)`! Clearly, something is still wrong.
I think the problem is that we’re double-counting items that are prefixes of each other. For instance, if we call `n_pos(2,1)` it returns 5, correctly counting prefixes {‘R’, ‘RR’, ‘RRB’, ‘RBR’, ‘BRR’}. But there are only 3 ways 2 red cards and 1 black card can be arranged! If we instead removed items which were prefixes of other items, we’d find that only the longest strings remained: {‘RRB’, ‘RBR’, ‘BRR’}. This lets us remove an outer for loop!
import math

def n_pos(r, b):
    if r == 0:
        return 0
    total = 1 # bp = 0 case
    for b_p in range(1, min(r, b+1)):
        total += math.comb(r + b_p, b_p)
    return total
Now we have `n_pos(3,3) = 15`. Since we can shuffle 3 red and 3 black cards ${6 \choose 3} = 20$ ways, we expect to reach positive value 15/20 = 75% of the time! Let’s see if this matches our intuition:
import itertools
def get_shuffles_hlpr(accum, curr, r, b):
    if r == 0:
        accum.append(curr + 'B'*b)
        return
    if b == 0:
        accum.append(curr + 'R'*r)
        return
    get_shuffles_hlpr(accum, curr+'R', r-1, b)
    get_shuffles_hlpr(accum, curr+'B', r, b-1)

def get_shuffles(r, b):
    # here, r and b are the number of red and
    # black cards in the hand
    accum = []
    get_shuffles_hlpr(accum, '', r, b)
    return accum

def p_win(r, b):
    total_pos = 0
    shuffles = get_shuffles(r, b)
    for s in shuffles:
        for i in range(1, len(s)):
            n_red = len(s[:i].replace('B', ''))
            if n_red > (i/2):
                total_pos += 1
                # string has a pos prefix, break
                break
    return total_pos/len(shuffles)
Now we can do:
>>> p_win(3,3)
0.75
Woo! I also verified this on a few more toy examples, eg:
>>> p_win(10, 10)
0.9090909090909091
>>> n_pos(10,10)/math.comb(20, 10)
0.9090909090909091
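Both ratios above happen to equal $r/(r+1)$ (3/4 and 10/11). That is what the hockey-stick identity would predict for equal decks, since $\sum_{b_p=0}^{r-1} {r + b_p \choose b_p} = {2r \choose r-1}$. This is an observation added as a cross-check, not something the rest of the post relies on, and `p_win_closed_form` is a name invented for this sketch. It only covers the $r = b$ case.

```python
import math
from fractions import Fraction

def n_pos(r, b):
    # same sum as the n_pos defined above, repeated so the sketch is self-contained
    if r == 0:
        return 0
    total = 1
    for b_p in range(1, min(r, b + 1)):
        total += math.comb(r + b_p, b_p)
    return total

def p_win_closed_form(r):
    """Shortcut for equal decks (r red, r black), assuming the identity above."""
    return Fraction(r, r + 1)

for r in (3, 10, 50):
    exact = Fraction(n_pos(r, r), math.comb(2 * r, r))
    print(r, exact == p_win_closed_form(r))
```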
Part 3: Optimization
Awesome! But do we beat Gwern? If we assume that `math.comb` has a constant-time implementation, yes, in a sense. In that case, we would have an $O(n)$ solution, while his solution is $O(n^2)$. But `math.comb` probably isn’t constant-time.
Gwern says he got an answer for 200,000 cards in like 5 seconds. Whether we can beat him really depends on how efficient `math.comb` is and how efficient computations on large numbers are in Python generally. It takes about 2 seconds to compute `n_pos(1000,1000)` on my old laptop, so I’m about 2 orders of magnitude too slow with my mathy solution. Let’s see if I can do better!
n_pos(1000, 1000) is:
2046105521468021692642519982997827217179245642339057975844538099572176010191891863964968026156453752449015750569428595097318163634370154637380666882886375203359653243390929717431080443509007504772912973142253209352126946839844796747697638537600100637918819326569730982083021538057087711176285777909275869648636874856805956580057673173655666887003493944650164153396910927037406301799052584663611016897272893305532116292143271037140718751625839812072682464343153792956281748582435751481498598087586998603921577523657477775758899987954012641033870640665444651660246024318184109046864244732001962029120000
This is not an efficient number to work with, and we only care about the first significant figure to decide when to stop playing (since we stop playing when $p_{win} < 0.5$). One approach is to try to find an approximation of the ${n \choose k}$ function in a faster language that returns a 64-bit float rather than an arbitrarily large (but exact) integer. I expect that this will do a lot for my performance - 100x probably isn’t an unreasonable ask, and that’s before I try any fancy multiprocessing (very easy to do for this method, since there’s no DP or shared memory). However, I’m going to explore another idea. As a reminder, the unoptimized code I’m working with is:
def n_pos(r, b):
    if r == 0:
        return 0
    total = 1 # bp = 0 case
    for b_p in range(1, min(r, b+1)):
        total += math.comb(r + b_p, b_p)
    return total
Which essentially evaluates:
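$$1 + \sum\limits_{b_{p}= 1}^{\min(r-1,\, b)} {r + b_{p} \choose b_{p}}.$$
(The upper limit is $\min(r, b+1) - 1 = \min(r-1, b)$ because Python’s `range` excludes its end point.)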
This looks like:
$$1 + {r + 1 \choose 1} + {r + 2 \choose 2} + \dots$$
We can make the observation that:
$${n + 1 \choose k + 1} = \dfrac{(n+1)!}{(k+1)!\ (n-k)!} = \dfrac{n + 1}{k + 1} \dfrac{n!}{k! (n-k)!} = \dfrac{n + 1}{k + 1} {n \choose k}$$.
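The identity above can be spot-checked numerically. This is a minimal sketch; `next_binomial` is a made-up helper name, and it uses integer arithmetic so the check is exact.

```python
import math

def next_binomial(n, k, c_nk):
    """Given c_nk = C(n, k), return C(n + 1, k + 1) with one multiply and one divide."""
    # the product is always divisible by (k + 1), so floor division is exact here
    return c_nk * (n + 1) // (k + 1)

print(next_binomial(4, 1, math.comb(4, 1)), math.comb(5, 2))
```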
This means that we can find a much more efficient way to evaluate
$$\sum\limits_{b_{p}= 1}^{\min(r-1,\, b)} {r + b_{p} \choose b_{p}}$$
By simply finding ${r + 1 \choose 1}$, and then noting that ${r + 2 \choose 2} = \frac{r+2}{2} {r + 1 \choose 1}$. Also, ${r + 3 \choose 3} = \frac{r+3}{3} {r + 2 \choose 2} = \frac{r+3}{3} \frac{r+2}{2} {r + 1 \choose 1}$. So we end up with:
$${r + 1 \choose 1}\left(1 + \frac{r+2}{2}\left(1 + \frac{r+3}{3}\left(1 + \dots \right)\right)\right)$$
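As a small worked instance of the nested form (my own arithmetic, taking $r = 3$ and $b = 3$, so the sum runs over $b_p = 1, 2$):

$${4 \choose 1} + {5 \choose 2} = {4 \choose 1}\left(1 + \frac{5}{2}\right) = 4 \cdot \frac{7}{2} = 14$$

Adding the leading 1 for the $b_p = 0$ case gives 15, which agrees with the value of `n_pos(3,3)` found earlier.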
This is super cool. Now instead of evaluating `math.comb`, which seems to be pretty slow, a ton of times, we can just evaluate it once and leave the rest of it as (much faster) floating-point addition and division!!
def n_pos_opt(r, b):
    if r == 0:
        return 0
    top = min(r-1, b) # largest b_p in the sum
    if top < 1:
        return 1 # only the bp = 0 term exists (e.g. r == 1 or b == 0)
    total = 1 # bp = 0 case
    nested_frac = 1
    # bounds on range change for python inclusive/exclusive rules
    for b_p in range(top, 1, -1):
        nested_frac *= (r + b_p)/b_p
        nested_frac += 1
    b_1 = math.comb(r + 1, 1)
    total += b_1 * nested_frac
    return total
When we do `n_pos_opt(3,3)`, we get our familiar 15, but when we do:
>>> n_pos_opt(5,5)
209.99999999999997
>>> n_pos(5,5)
210
This is to be expected! We’re doing a little bit of approximation here, but it’s much better than what I was expecting - the error is just floating-point inaccuracy. But when we do:
>>> n_pos_opt(1000,1000)
inf
We overflowed Python floats. OK.
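As an aside, one way to stay with ordinary floats would be to run the same nested recurrence on logarithms, so no intermediate value ever leaves float range. This is only a sketch of that alternative, not what the post goes on to do. The names `log_n_pos`, `log_comb`, and `p_win_approx` are invented here, and `math.lgamma` stands in for an exact binomial, so the result is approximate.

```python
import math

def log_n_pos(r, b):
    """Natural log of the n_pos sum, computed without forming huge numbers."""
    if r == 0:
        return -math.inf                 # log(0)
    top = min(r - 1, b)                  # largest b_p in the sum
    if top < 1:
        return 0.0                       # the sum is just 1
    log_nested = 0.0                     # log(1)
    for b_p in range(top, 1, -1):
        x = log_nested + math.log((r + b_p) / b_p)   # log(nested * factor)
        log_nested = x + math.log1p(math.exp(-x))    # log(nested * factor + 1)
    y = math.log(r + 1) + log_nested                 # log(C(r+1, 1) * nested)
    return y + math.log1p(math.exp(-y))              # log(1 + C(r+1, 1) * nested)

def log_comb(n, k):
    """Approximate log of C(n, k) via the log-gamma function."""
    return math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)

def p_win_approx(r, b):
    return math.exp(log_n_pos(r, b) - log_comb(r + b, r))

print(p_win_approx(3, 3), p_win_approx(1000, 1000))
```

Only the final ratio is exponentiated, which is safe because it is a probability between 0 and 1.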
Our problem is now that we have a mathematically correct approach that Python doesn’t like. Still fixable, with Python’s fractions library! That should store the numerator and denominator of our nested fraction as integers, which we’ve seen Python will let us make arbitrarily long.
from fractions import Fraction

def n_pos_opt(r, b):
    if r == 0:
        return 0
    top = min(r-1, b) # largest b_p in the sum
    if top < 1:
        return 1 # only the bp = 0 term exists
    total = 1 # bp = 0 case
    nested_frac = Fraction(1,1)
    for b_p in range(top, 1, -1):
        nested_frac *= Fraction(r + b_p, b_p)
        nested_frac += 1
    b_1 = math.comb(r + 1, 1)
    total += b_1 * nested_frac
    return total
Excellent, this runs even faster than the float version and doesn’t overflow! But it’s still too slow. We can do better by leveraging Fraction’s `limit_denominator()` method, which generates approximations of fractions.
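def n_pos_opt(r, b):
    if r == 0:
        return 0
    top = min(r-1, b) # largest b_p in the sum
    if top < 1:
        return 1 # only the bp = 0 term exists
    total = 1 # bp = 0 case
    nested_frac = Fraction(1,1)
    for b_p in range(top, 1, -1):
        # note: for b_p below the cap, this call returns the per-step factor
        # unchanged; the accumulated nested_frac itself is never approximated
        nested_frac *= Fraction(r + b_p, b_p).limit_denominator(10_000_000)
        nested_frac += 1
    b_1 = math.comb(r + 1, 1)
    total += b_1 * nested_frac
    return total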
After several seconds, I get a very interesting error:
>>> n_pos_opt(100000,100000)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/lib/python3.10/fractions.py", line 266, in __repr__
    return '%s(%s, %s)' % (self.__class__.__name__,
ValueError: Exceeds the limit (4300) for integer string conversion; use sys.set_int_max_str_digits() to increase the limit
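Judging by the `__repr__` frame in the traceback, the error is raised when the REPL tries to print the result, not during the arithmetic. Besides raising the limit as the message suggests, a huge value can be inspected without ever turning it into a decimal string. The sketch below uses two hypothetical helpers (`size_in_bits`, `as_float_ratio`) and a stand-in binomial rather than the actual `n_pos_opt` result.

```python
import math
from fractions import Fraction

def size_in_bits(x):
    """Bit lengths of a Fraction's numerator and denominator (no string conversion)."""
    x = Fraction(x)
    return x.numerator.bit_length(), x.denominator.bit_length()

def as_float_ratio(x, total):
    """x / total as a float, even when both are far too large to be floats themselves."""
    q = Fraction(x) / Fraction(total)
    return q.numerator / q.denominator

# stand-in values with thousands of digits (assumed for illustration only)
big = Fraction(math.comb(20000, 9999))
total = math.comb(20000, 10000)
print(size_in_bits(big)[0] > 4300, as_float_ratio(big, total))
```

Assigning the result to a variable instead of letting the REPL echo it avoids the error in the same way.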
Let’s try one other method with `mpmath` before we give up on computing with high precision:
import mpmath
mpmath.mp.dps = 15
mpmath.mp.pretty = True

def n_pos_opt(r, b):
    if r == 0:
        return 0
    top = min(r-1, b) # largest b_p in the sum
    if top < 1:
        return mpmath.mpf(1) # only the bp = 0 term exists
    total = mpmath.mpf(1) # bp = 0 case
    nested_frac = mpmath.mpf(1)
    for b_p in range(top, 1, -1):
        nested_frac *= mpmath.fdiv(r + b_p, b_p)
        nested_frac += mpmath.mpf(1)
    b_1 = mpmath.binomial(r + 1, 1)
    total += b_1 * nested_frac
    return total
It works! `mpmath` evaluates to 15 digits as an approximation, and we find:
>>> n_pos_opt(1000, 1000)
2.04610552146803e+600
>>> n_pos_opt(100000, 100000)
1.78054508182311e+60203
>>> n_pos_opt(100000, 100000)/mpmath.binomial(200000, 100000)
0.999990000100024
It takes about 3 seconds to run `n_pos_opt` for `(100000, 100000)`, meaning that we’ve (barely) beaten Gwern’s time, if the problem is to decide whether or not to quit at a certain point. I actually still haven’t read his post, except for searching in it for “seconds”, pulling out the quickest time I saw, and reading enough of the surrounding text to make sure his (100000, 100000) was the same as mine. In other words, I didn’t see what problem he was actually solving.
Adjusting the mpmath precision doesn’t seem to improve the time, which suggests to me that ~all the time is spent in the overhead of actually calling into mpmath. That suggests to me that the best way to improve the speed at this point would be to use something that isn’t Python, but with an mpmath-esque arbitrary-precision float library. Possibly I could do more or less linearly better if I parallelized this; though the `for` loop looks like it’s very hard to parallelize here, you actually can do it without any shared memory at the expense of multiple calls to `mpmath.binomial` by essentially having multiple starting points instead of just ${r+1 \choose 1}$. This, obviously, is complicated by the GIL, but I’m not sure how. However, I think I’ve pretty much exhausted my interest in this problem for now. Stay tuned!
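The "multiple starting points" idea can be sketched as follows. This version uses exact integers and `math.comb` instead of mpmath, and runs the chunks one after another; the names `chunk_sum` and `n_pos_chunked` and the `n_chunks` parameter are assumptions for illustration. Each chunk needs only its own seed binomial, so in principle the chunks could be handed to separate worker processes.

```python
import math

def chunk_sum(r, lo, hi):
    """Sum C(r + b_p, b_p) for b_p in [lo, hi), seeded by a single math.comb call."""
    term = math.comb(r + lo, lo)
    total = 0
    for b_p in range(lo, hi):
        total += term
        # step to the next term: C(r+b_p+1, b_p+1) = C(r+b_p, b_p) * (r+b_p+1) / (b_p+1)
        term = term * (r + b_p + 1) // (b_p + 1)
    return total

def n_pos_chunked(r, b, n_chunks=4):
    """Same value as n_pos(r, b), assembled from independent chunks."""
    if r == 0:
        return 0
    upper = min(r, b + 1)   # b_p ranges over [0, upper), including the bp = 0 term
    bounds = [upper * i // n_chunks for i in range(n_chunks + 1)]
    return sum(chunk_sum(r, lo, hi)
               for lo, hi in zip(bounds, bounds[1:]) if lo < hi)

print(n_pos_chunked(3, 3), n_pos_chunked(10, 10, n_chunks=3))
```

Whether this actually speeds things up depends on how expensive each seed binomial is relative to the loop, which I have not measured.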