Quick Note on Prefix Sums in Python
This mostly serves as a quick note-to-self.
One of my favorite things about Python is how easily you can step in and out of more declarative/functional styles of programming as you like. Python’s list/set/dict comprehensions and generator expressions were inspired by Haskell’s list comprehensions (I couldn’t find my original source for this). They’re arguably more flexible and useful than their biggest influence.
(GHC, a.k.a. the Glasgow Haskell Compiler, offers a compiler extension that generalizes Haskell’s list comprehensions to any type that’s an instance of Monad, but I haven’t used it and don’t know much about its use in production Haskell libraries and applications.)
Python comprehensions are often quite fast, since they’re essentially “dropping down” into C. That assumption bit me this morning, though, while I was working on a LeetCode problem that required prefix sums.
For those who don’t know, a prefix sum is essentially a running total of a list: the i-th prefix sum is the sum of the first i + 1 elements. Here’s a naive way to compute them:
def prefix_sums(xs: list[int]) -> list[int]:
    # Recompute each prefix from scratch with a slice per element.
    sums = [sum(xs[:i + 1]) for i in range(len(xs))]
    return sums
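For a concrete picture (these values are my own toy example, not from the problem):

>>> prefix_sums([3, 1, 4, 1, 5])
[3, 4, 8, 9, 14]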
My initial, and incorrect, line of thinking here was: “list comprehensions are quick. I need to reuse these sums in the problem, so I’ll use a list comprehension rather than a generator expression. That should be sufficient.” This code, unsurprisingly, timed out.
There’s a glaring problem with that snippet above: we’re constantly recomputing sums we already know. Computing sums[1_000_000] re-adds the first million elements from scratch, even though we just computed sums[999_999], so building the whole list is O(n²). Meanwhile, there is an obvious definition that reuses the intermediate prefix sums we already have, which is:
def prefix_sums(xs: list[int]) -> list[int]:
    if not xs:  # guard against an empty input
        return []
    sums = [0] * len(xs)  # or [0 for _ in range(len(xs))]
    sums[0] = xs[0]
    for i, x in enumerate(xs[1:], start=1):
        sums[i] = sums[i - 1] + x
    return sums
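Incidentally, if you don’t want to write this by hand, the standard library’s itertools.accumulate computes exactly these running totals:

from itertools import accumulate

# accumulate yields running totals lazily; list() materializes them
assert list(accumulate([3, 1, 4, 1, 5])) == [3, 4, 8, 9, 14]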
As an aside, I assumed that the comprehension version of creating that list would be quicker, but a quick %timeit in IPython showed me I was wrong, at least in this context:
In [61]: %timeit [0 for _ in range(10_000_000)]
257 ms ± 6.36 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [62]: %timeit [0] * 10_000_000
38.2 ms ± 139 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Remember: if you’re running into unexpected behavior from your code, question the assumptions you made when writing it, starting with whatever you believe “should” be faster. Consider your hardware, then iterate through those assumptions and test each of them. I would avoid relying on this sort of quick-and-dirty heuristic for code that’s meant to hit prod.
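If you’re not in IPython, the standard timeit module works just as well; here’s a minimal sketch of the same comparison (the numbers above came from %timeit, not from this exact snippet):

import timeit

# Compare two ways of building a list of 10 million zeros.
n = 10_000_000
comp = timeit.timeit(lambda: [0 for _ in range(n)], number=5)
mult = timeit.timeit(lambda: [0] * n, number=5)
print(f"comprehension:  {comp / 5:.3f} s per run")
print(f"multiplication: {mult / 5:.3f} s per run")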
So, what are the lessons we can learn here?
- You’re never really in a “when all you have is a hammer” situation in software engineering. There’s always some assumption you can question that might make your code run faster or with less memory, handle your edge cases more gracefully, or ease whatever your biggest constraint is. In this situation, my assumption was that, outside of a restricted computation domain like NumPy that expects you to favor vectorizable operations, list comprehensions are More Efficient (r) (tm).*
- Always test things out!
- Think out loud!
*and in NumPy, you could just do arr.cumsum().
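For completeness, that looks something like this (assuming NumPy is installed):

import numpy as np

arr = np.array([3, 1, 4, 1, 5])
print(arr.cumsum())  # [ 3  4  8  9 14]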