Consecutive Prime Sum
Context
I cut my teeth on Java by solving problems on Project Euler. At one point, the solution to Problem 50 eluded me for a long time: 7 years to be precise. On a random afternoon last week, I was browsing (and cringing at) some of my old code when I came across this problem once again.
I decided to give it one more try…
Problem
From the problem page:
The prime 41, can be written as the sum of six consecutive primes:
41 = 2 + 3 + 5 + 7 + 11 + 13
This is the longest sum of consecutive primes that adds to a prime below one-hundred.
The longest sum of consecutive primes below one-thousand that adds to a prime, contains 21 terms, and is equal to 953.
Which prime, below one-million, can be written as the sum of the most consecutive primes?
Solution
The first approach that comes to mind is: “for each prime under a million, find a representation as a sum of consecutive primes”.
A simpler solution, most likely, is to invert this: enumerate the sums of consecutive primes, and keep track of the longest chain whose sum is itself a prime.
import bisect
import dataclasses
import itertools
import timeit
import math
from typing import Generator, Callable, Iterator
Solution Type
Let’s first define what an “answer” will look like. What we want is:
- The prime itself
- The consecutive primes that add up to it
- The length of the chain (the number of consecutive primes in the sum; the chain need not start at 2)
@dataclasses.dataclass
class ConsecutivePrimeSum:
    prime: int
    consecutive_primes: list

    @property
    def chain_length(self):
        return len(self.consecutive_primes)
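As a quick illustration of how this type behaves, here is a minimal, self-contained sketch that rebuilds the dataclass and fills it with the 41 example from the problem statement. The explicit type hints and the `example` variable are my additions for illustration.

```python
import dataclasses


@dataclasses.dataclass
class ConsecutivePrimeSum:
    prime: int
    consecutive_primes: list[int]

    @property
    def chain_length(self) -> int:
        # Derived from the list, so it can never drift out of sync with it
        return len(self.consecutive_primes)


# The six-term chain quoted in the problem statement
example = ConsecutivePrimeSum(prime=41, consecutive_primes=[2, 3, 5, 7, 11, 13])

print(example.chain_length)                              # 6
print(sum(example.consecutive_primes) == example.prime)  # True
```

Keeping `chain_length` as a property rather than a stored field means there is only one source of truth for the chain.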
Get All Primes Below 1 Million
My favorite mechanism for generating primes is the Sieve of Eratosthenes:
def eratosthenes_sieve(n: int) -> Callable[[int], bool]:
    """Returns a function that is an efficient implementation of is_prime

    Generates an Eratosthenes' sieve for primes, and uses it as a cache for a wrapped function
    """
    is_prime = [True] * (n + 1)
    is_prime[0] = is_prime[1] = False
    for i in range(2, int(math.sqrt(n)) + 1):
        # if i is marked as "not prime", its multiples must have been too
        # so nothing to do here
        if not is_prime[i]:
            continue
        # If i is not marked as "not prime", it MUST be prime
        # Mark all multiples of i as not prime
        multiples_of_i = range(i**2, n + 1, i)
        for j in multiples_of_i:
            is_prime[j] = False
    return lambda x: is_prime[x]


def primes_under(n: int) -> list[int]:
    # Returns a list (not a lazy iterator) of all primes strictly below n
    is_prime = eratosthenes_sieve(n)
    return [i for i in range(n) if is_prime(i)]
Let’s do a quick check:
primes_under_100 = primes_under(100)
print(primes_under_100)
[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97]
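Eyeballing the first 25 primes is reassuring, but a sieve is easy to get subtly wrong at the boundaries. One way to gain more confidence is to compare it against a slow but obviously correct trial-division check. The sketch below is self-contained: `sieve_primes_below` is a compact stand-in for the sieve above, and `is_prime_trial_division` is a hypothetical reference helper, not part of the solution.

```python
import math


def sieve_primes_below(n: int) -> list[int]:
    """Compact Sieve of Eratosthenes: all primes strictly below n."""
    if n < 2:
        return []
    is_prime = [True] * n
    is_prime[0] = is_prime[1] = False
    # The largest index in the table is n - 1, so sieving up to its square root suffices
    for i in range(2, math.isqrt(n - 1) + 1):
        if is_prime[i]:
            for j in range(i * i, n, i):
                is_prime[j] = False
    return [i for i, flag in enumerate(is_prime) if flag]


def is_prime_trial_division(x: int) -> bool:
    """Slow reference check: try every divisor up to sqrt(x)."""
    if x < 2:
        return False
    return all(x % d != 0 for d in range(2, math.isqrt(x) + 1))


limit = 10_000
expected = [x for x in range(limit) if is_prime_trial_division(x)]
assert sieve_primes_below(limit) == expected
print(len(expected))  # 1229 primes below 10,000
```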
Prefix Sums
What we now want is sums over “windows” of varying sizes over this array:
[Figure: windows of varying start and length over the array 2 | 3 | 5 | 7 | 11]
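The naive way would be to loop over every start index i and every end index j and add up the elements in between, an $O(n^3)$ operation (where $n$ here is the number of primes). A better approach is to precompute prefix sums, which turns each window sum into an $O(1)$ lookup and brings the total down to $O(n^2)$: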
prefix_sums = [0] + list(itertools.accumulate(primes_under_100))
print(list(prefix_sums))
[0, 2, 5, 10, 17, 28, 41, 58, 77, 100, 129, 160, 197, 238, 281, 328, 381, 440, 501, 568, 639, 712, 791, 874, 963, 1060]
Getting the sum of all primes between and including (say) 3 and 7 is simply
end = primes_under_100.index(7)
start = primes_under_100.index(3)
prefix_sums[end + 1] - prefix_sums[start]
15
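To make the indexing convention concrete, here is a small self-contained sketch that checks the prefix-sum formula against a direct sum for every possible window. The helper name `window_sum` is mine and is only for illustration.

```python
import itertools

primes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
# prefix_sums[k] is the sum of the first k primes, so prefix_sums[0] == 0
prefix_sums = [0] + list(itertools.accumulate(primes))


def window_sum(start: int, end: int) -> int:
    """Sum of primes[start..end], both ends included, in O(1)."""
    return prefix_sums[end + 1] - prefix_sums[start]


# Compare against the direct (slow) sum for every window
for start in range(len(primes)):
    for end in range(start, len(primes)):
        assert window_sum(start, end) == sum(primes[start:end + 1])

print(window_sum(1, 3))  # 3 + 5 + 7 = 15
```

The leading 0 is what makes `start = 0` work without a special case.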
The next step is, for each pair of indices, to find the sum of all primes between those two indices and, if that sum is prime, update our answer:
def consecutive_prime_sums(n, debug=False) -> Generator[ConsecutivePrimeSum, None, None]:
    is_prime = eratosthenes_sieve(n)
    primes_under_n = primes_under(n)
    # prepend a 0 for nicer indexing
    prefix_sums = [0] + list(itertools.accumulate(primes_under_n))
    num_primes = len(primes_under_n)
    max_length_until_now = 0
    # Consider all chains starting from the ith prime
    for i in range(num_primes):
        # We iterate backwards for the end, since we are looking for the longest chain
        # Once we find a chain from [i,j], there is no point in looking for smaller j
        j_start = num_primes - 1
        # chains with a length smaller than the longest one we have found so far
        # will not contribute to the answer, so we can and should skip them
        j_end = max_length_until_now + i - 1
        for j in range(j_start, j_end, -1):
            chain_length = j - i + 1
            # prime_sum includes both the ith and jth prime
            consecutive_prime_sum = prefix_sums[j + 1] - prefix_sums[i]
            # if this sum is greater than the highest acceptable prime for this problem
            # skip it
            if consecutive_prime_sum > n:
                continue
            if is_prime(consecutive_prime_sum):
                max_length_until_now = max(max_length_until_now, chain_length)
                yield ConsecutivePrimeSum(consecutive_prime_sum, primes_under_n[i:j + 1])
                # Smaller j's with the same i will never give a better answer
                break


def solve(n, solver):
    return max(solver(n), key=lambda x: x.chain_length)
Let’s test out the basic cases:
solve(100, solver=consecutive_prime_sums)
ConsecutivePrimeSum(prime=41, consecutive_primes=[2, 3, 5, 7, 11, 13])
solve(1000, solver=consecutive_prime_sums)
ConsecutivePrimeSum(prime=953, consecutive_primes=[7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89])
We have a working solution! :)
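Two matching test cases are encouraging, but the index arithmetic above is fiddly enough that an independent cross-check is worthwhile. The sketch below is a deliberately naive brute force that shares no code with the solver: it uses trial division instead of the sieve and running totals instead of prefix sums. The function names are hypothetical and exist only for this check.

```python
import math


def is_prime(x: int) -> bool:
    """Slow but simple trial-division primality test."""
    if x < 2:
        return False
    return all(x % d != 0 for d in range(2, math.isqrt(x) + 1))


def brute_force_longest_chain(n: int) -> tuple[int, list[int]]:
    """Try every window of consecutive primes below n; keep the longest prime sum below n."""
    primes = [x for x in range(n) if is_prime(x)]
    best_sum, best_chain = 0, []
    for start in range(len(primes)):
        total = 0
        for end in range(start, len(primes)):
            total += primes[end]
            if total >= n:
                break  # sums only grow as the window extends
            if is_prime(total) and end - start + 1 > len(best_chain):
                best_sum, best_chain = total, primes[start:end + 1]
    return best_sum, best_chain


print(brute_force_longest_chain(100))   # (41, [2, 3, 5, 7, 11, 13])
best_sum, best_chain = brute_force_longest_chain(1000)
print(best_sum, len(best_chain))        # 953 21
```

This is only practical for small n, which is exactly why it is useful as a reference and not as a solution.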
Performance
Before we continue with finding the solution for n = 1000000, we need to confirm if this is an acceptable solution, i.e. will it end within a minute?
from tabulate import tabulate
def test_performance(solver):
    # Simple geometric progression to see how the algorithm scales
    ns = [100 * 2**(n - 1) for n in range(1, 14)]
    data = []
    for n in ns:
        start = timeit.default_timer()
        ans = solve(n, solver)
        end = timeit.default_timer()
        data.append([n, round(end-start, 6), ans.prime])
    return data
data = test_performance(consecutive_prime_sums)
print(tabulate(data, headers=["n", "time (s)", "answer"], tablefmt="fancy_grid"))
╒════════╤════════════╤══════════╕
│ n │ time (s) │ answer │
╞════════╪════════════╪══════════╡
│ 100 │ 8.5e-05 │ 41 │
├────────┼────────────┼──────────┤
│ 200 │ 0.000209 │ 197 │
├────────┼────────────┼──────────┤
│ 400 │ 0.000445 │ 379 │
├────────┼────────────┼──────────┤
│ 800 │ 0.001347 │ 499 │
├────────┼────────────┼──────────┤
│ 1600 │ 0.005278 │ 1583 │
├────────┼────────────┼──────────┤
│ 3200 │ 0.017518 │ 2909 │
├────────┼────────────┼──────────┤
│ 6400 │ 0.068539 │ 6079 │
├────────┼────────────┼──────────┤
│ 12800 │ 0.198685 │ 12713 │
├────────┼────────────┼──────────┤
│ 25600 │ 0.682807 │ 25237 │
├────────┼────────────┼──────────┤
│ 51200 │ 2.41372 │ 49279 │
├────────┼────────────┼──────────┤
│ 102400 │ 8.6131 │ 102001 │
├────────┼────────────┼──────────┤
│ 204800 │ 30.0977 │ 203279 │
├────────┼────────────┼──────────┤
│ 409600 │ 106.413 │ 408479 │
╘════════╧════════════╧══════════╛
Hmm, it looks like this is definitely not fast enough. Each doubling of n multiplies the running time by roughly 3.5, which is consistent with a (nearly) $O(n^2)$ algorithm, so solving for 1 million will take roughly 480s, i.e. about 8 minutes.
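For a slightly more principled extrapolation, one can fit a power law to the last two rows of the table and project it to one million. This is a rough sketch under the assumption that the running time behaves like $c \cdot n^p$ in this range; the timings are copied from the table above and will differ on other machines.

```python
import math

# Last two measurements from the table above: (n, seconds)
n1, t1 = 204_800, 30.0977
n2, t2 = 409_600, 106.413

# If time ~ c * n**p, then p = log(t2 / t1) / log(n2 / n1)
exponent = math.log(t2 / t1) / math.log(n2 / n1)

# Extrapolate from the largest measured n to one million
target = 1_000_000
estimate = t2 * (target / n2) ** exponent

print(f"fitted exponent ~ {exponent:.2f}")
print(f"estimated time for n = {target}: ~ {estimate:.0f} s")
```

The fitted exponent comes out a little under 2 and the projected time is several hundred seconds, the same ballpark as the back-of-the-envelope figure and far beyond the one-minute budget.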
Optimization
Let’s take a quick look at the primes under a million:
primes_under_million = primes_under(1000_000)
len(primes_under_million)
78498
primes_under_million[-10:]
[999863,
999883,
999907,
999917,
999931,
999953,
999959,
999961,
999979,
999983]
Can we get an upper bound for the answer? We can! Let us find a k such that the sum of the first k primes (i.e. the smallest k primes) exceeds a million:
k = 0
running_sum = 0
while running_sum < 1000_000:
    running_sum += primes_under_million[k]
    k += 1
k
547
Now we can limit our search to chains of maximum length k = 547.
In the general case, if we can find a k such that the sum of the smallest k primes exceeds n, then we have an upper bound that can significantly speed up the solution.
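The same bound can be computed for any n without an explicit loop, because the prefix sums of the primes are strictly increasing and can therefore be binary-searched. The sketch below is self-contained; `primes_below` and `max_chain_length` are hypothetical helper names, and the function returns the tightest bound (the largest k whose sum is still below n), which is one less than the k found by the loop above.

```python
import bisect
import itertools
import math


def primes_below(n: int) -> list[int]:
    """Compact Sieve of Eratosthenes: all primes strictly below n."""
    if n < 2:
        return []
    is_prime = [True] * n
    is_prime[0] = is_prime[1] = False
    for i in range(2, math.isqrt(n - 1) + 1):
        if is_prime[i]:
            for j in range(i * i, n, i):
                is_prime[j] = False
    return [i for i, flag in enumerate(is_prime) if flag]


def max_chain_length(n: int) -> int:
    """Largest k such that the sum of the k smallest primes is still below n."""
    # prefix_sums[k] is the sum of the first k primes
    prefix_sums = [0] + list(itertools.accumulate(primes_below(n)))
    # The list is strictly increasing, so bisect_left finds the first k whose sum is >= n
    first_k_too_big = bisect.bisect_left(prefix_sums, n)
    return first_k_too_big - 1


print(max_chain_length(100))        # 8 (2 + 3 + ... + 19 = 77)
print(max_chain_length(1_000_000))  # 546
```

Any chain of k consecutive primes sums to at least as much as the k smallest primes, so no valid chain can be longer than this. Using a slightly looser bound is still safe; it only costs a few extra iterations.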
We can incorporate this into our solver (eliding most comments for succinctness):
def faster_consecutive_prime_sums(n, debug=False) -> Generator[ConsecutivePrimeSum, None, None]:
    is_prime = eratosthenes_sieve(n)
    primes_under_n = primes_under(n)
    prefix_sums = [0] + list(itertools.accumulate(primes_under_n))
    # Since prefix_sums is sorted by nature, use binary search to find n
    # (the + 1 keeps the bound slightly generous, which is safe)
    max_possible_chain_length = bisect.bisect_left(prefix_sums, n) + 1
    if debug:
        print(f"Max possible chain length for {n} is {max_possible_chain_length}")
    num_primes = len(primes_under_n)
    max_length_until_now = 0
    # Consider all chains starting from the ith prime
    for i in range(num_primes):
        j_start = min(i + max_possible_chain_length, num_primes - 1)
        # Same lower limit as the first solver; clamping it to 0 would skip
        # the single-prime chain [2] when i == 0
        j_end = max_length_until_now + i - 1
        for j in range(j_start, j_end, -1):
            chain_length = j - i + 1
            consecutive_prime_sum = prefix_sums[j + 1] - prefix_sums[i]
            if consecutive_prime_sum > n:
                continue
            if is_prime(consecutive_prime_sum):
                max_length_until_now = max(max_length_until_now, chain_length)
                yield ConsecutivePrimeSum(consecutive_prime_sum, primes_under_n[i:j + 1])
                break
data = test_performance(faster_consecutive_prime_sums)
print(tabulate(data, headers=["n", "time (s)", "answer"], tablefmt="fancy_grid"))
╒════════╤════════════╤══════════╕
│ n │ time (s) │ answer │
╞════════╪════════════╪══════════╡
│ 100 │ 0.000144 │ 41 │
├────────┼────────────┼──────────┤
│ 200 │ 0.000188 │ 197 │
├────────┼────────────┼──────────┤
│ 400 │ 0.000614 │ 379 │
├────────┼────────────┼──────────┤
│ 800 │ 0.00085 │ 499 │
├────────┼────────────┼──────────┤
│ 1600 │ 0.000797 │ 1583 │
├────────┼────────────┼──────────┤
│ 3200 │ 0.001656 │ 2909 │
├────────┼────────────┼──────────┤
│ 6400 │ 0.004762 │ 6079 │
├────────┼────────────┼──────────┤
│ 12800 │ 0.006628 │ 12713 │
├────────┼────────────┼──────────┤
│ 25600 │ 0.011502 │ 25237 │
├────────┼────────────┼──────────┤
│ 51200 │ 0.025432 │ 49279 │
├────────┼────────────┼──────────┤
│ 102400 │ 0.043048 │ 102001 │
├────────┼────────────┼──────────┤
│ 204800 │ 0.104071 │ 203279 │
├────────┼────────────┼──────────┤
│ 409600 │ 0.209027 │ 408479 │
╘════════╧════════════╧══════════╛
This appears to solve our performance problems: computing for a million should take roughly half a second:
%%timeit -n3 -r3
solve(1_000_000, solver=faster_consecutive_prime_sums)
491 ms ± 14.4 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)