Advent of Code 2020 - Day 6


Day 6 of AoC 2020 (Custom Customs) is a very easy problem to solve using Python sets. Today’s solution is an exercise in code golf. As always, spoilers ahead.

The problem describes groups of one or more people, each of whom has answered a series of 26 yes-or-no questions, labeled a through z. For each person, the input lists the questions to which that person answered yes, with each person on a separate line. Groups are separated by a blank line.

For part 1, we need to find, in each group, all the questions to which anyone answered yes. This is a straightforward set union operation.
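To make the union concrete, here is a minimal sketch for a single hypothetical two-person group, where each person's answers are stored as a set of letters:

```python
# A hypothetical group of two people: the first answered yes to a and b,
# the second answered yes to a and c.
person1 = set("ab")
person2 = set("ac")

# Questions to which *anyone* in the group answered yes.
anyone = person1 | person2
print(sorted(anyone))  # ['a', 'b', 'c']
print(len(anyone))     # 3
```

Question a appears in both sets but is counted only once, which is exactly why a set is the right tool here.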

To parse the input, as we did on Day 4, we can use split('\n\n') to separate the individual groups, then split each group on whitespace and convert each person's line into a set, as shown in the snippet below.

with open('input') as datafile:
    # Split on '\n\n' to separate the individual records
    data = datafile.read().split('\n\n')
    groups = [list(map(set, group.split())) for group in data]
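Because the parsing depends only on the file's text, it can be tried without an input file. The sketch below assumes the puzzle's small example input is embedded as a string; SAMPLE and parse_groups are hypothetical names used only for illustration:

```python
# The small example input, embedded as a string instead of read from a file.
SAMPLE = """abc

a
b
c

ab
ac

a
a
a
a

b
"""


def parse_groups(text):
    """Split raw puzzle text into groups, each a list of per-person answer sets."""
    return [list(map(set, group.split())) for group in text.split("\n\n")]


groups = parse_groups(SAMPLE)
print(len(groups))                           # 5
print(groups[2] == [set("ab"), set("ac")])   # True
```

The trailing newline on the last group is harmless, since str.split() with no arguments discards surrounding whitespace.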

We only need the counts, but we need them on a per group basis. The following one-liner will give us the solution for part 1.

from functools import reduce
print(sum(map(len, (reduce(set.union, group) for group in groups))))

Let’s break that one-liner down, shall we? First, let’s reformat the above one-liner into multiple lines, so that it is easier to read and understand.

from functools import reduce
print(
    sum(
        map(
            len,
            (reduce(set.union, group)
                for group in groups)
        )
    )
)

Lines 6 and 7 above, the parenthesized reduce(...) expression, form a generator expression that yields one set per group. Each element it iterates over is one group, represented as the list of per-person answer sets for that group. For each group, the generator calls functools.reduce with set.union, which applies the union cumulatively from left to right and returns a single set: the union of all the sets in that group, i.e. every question to which at least one person in the group answered yes.
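Concretely, reduce folds the two-argument set.union over the list from left to right. The sketch below uses a made-up three-person group to check that the fold matches the nested calls written out by hand:

```python
from functools import reduce

# A hypothetical group of three people.
group = [set("abc"), set("bd"), set("e")]

# reduce folds set.union over the list from left to right...
folded = reduce(set.union, group)

# ...which is equivalent to nesting the calls by hand.
nested = set.union(set.union(group[0], group[1]), group[2])

print(folded == nested)  # True
print(sorted(folded))    # ['a', 'b', 'c', 'd', 'e']
```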

Line 4 uses the built-in map to apply the built-in len function to each set the generator yields, giving the number of yes-answered questions for each group in the input. Finally, sum on line 3 totals those counts and returns the result.
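To see each stage of the pipeline, the intermediate values can be materialized one step at a time. A minimal sketch, using the puzzle's small example hard-coded as already-parsed sets so the snippet runs on its own:

```python
from functools import reduce

# The small example input, already parsed: one list of answer sets per group.
groups = [
    [set("abc")],
    [set("a"), set("b"), set("c")],
    [set("ab"), set("ac")],
    [set("a"), set("a"), set("a"), set("a")],
    [set("b")],
]

unions = (reduce(set.union, group) for group in groups)  # one set per group
lengths = list(map(len, unions))                         # size of each set
print(lengths)       # [3, 3, 3, 1, 1]
print(sum(lengths))  # 11
```

Calling list() here is only for display; the one-liner passes the lazy map object straight to sum, so no intermediate list is built.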

In essence, that one-liner can be rewritten in long form as shown below:

total = 0
for group in groups:
    s = set(group[0])  # Copy, so the parsed data is not modified in place
    for person in group:
        s |= person  # In-place union; same result as s = s.union(person)

    total += len(s)

print(total)

Part 2 asks for the questions to which everyone in a group answered yes, which is the set intersection instead, so we can change the one-liner to the following:

print(sum(map(len, (reduce(set.intersection, group) for group in groups))))
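The puzzle's small example makes a quick sanity check for part 2. The sketch below hard-codes it as parsed sets and computes the answer both with the one-liner and with a loop analogous to the part 1 long form. Here the copy of group[0] matters: starting from an empty set would make every intersection empty.

```python
from functools import reduce

# The small example input, already parsed: one list of answer sets per group.
groups = [
    [set("abc")],
    [set("a"), set("b"), set("c")],
    [set("ab"), set("ac")],
    [set("a"), set("a"), set("a"), set("a")],
    [set("b")],
]

# One-liner: intersection instead of union.
part2 = sum(map(len, (reduce(set.intersection, group) for group in groups)))

# Long-form equivalent, mirroring the part 1 loop.
total = 0
for group in groups:
    s = set(group[0])  # copy of the first person's answers
    for person in group[1:]:
        s &= person  # keep only questions everyone so far answered yes to
    total += len(s)

print(part2, total)  # 6 6
```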