When to use aggregate/filter/transform with pandas

I've been teaching quite a lot of Pandas recently, and many of the recurring questions are about grouping. That's no surprise, as it's one of Pandas' most flexible features. However, that flexibility can also make it confusing.

I think that most of the confusion arises because the same grouping logic is used for (at least) three distinct operations in Pandas. In the order that we normally learn them, these are:

  • calculating some aggregate measurement for each group (size, mean, etc.)
  • filtering the rows on a property of the group they belong to
  • calculating a new value for each row based on a property of the group.

This commonly leads to situations where we know that we need to use groupby() - and may even be able to easily figure out what the arguments to groupby() should be - but are unsure about what to do next.

Here's a trick that I've found useful when teaching these ideas: think about the result you want, and work back from there. If you want to get a single value for each group, use aggregate() (or one of its shortcuts). If you want to get a subset of the original rows, use filter(). And if you want to get a new value for each original row, use transform().
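Here's a tiny preview of how that rule of thumb plays out, using a made-up two-group frame (the names are purely for illustration); we'll see the same pattern on real data below:

import pandas as pd

# a made-up frame with two groups, just to show the output shapes
toy = pd.DataFrame({'group': ['a', 'a', 'b', 'b'], 'value': [1, 2, 3, 4]})

toy.groupby('group')['value'].aggregate('mean')                   # one value per group (2 rows)
toy.groupby('group').filter(lambda g: g['value'].mean() > 2)      # a subset of the input rows
toy.groupby('group')['value'].transform(lambda g: g - g.mean())   # one new value per input row (4 rows)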

Here's a minimal example of the three different situations, all of which require exactly the same call to groupby() but which do different things with the result. We'll use the well-known tips dataset, which we can load directly from the web:

import pandas as pd
df = pd.read_csv(
   "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/tips.csv"
   )
pd.options.display.max_rows = 10
df
total_bill tip sex smoker day time size
0 16.99 1.01 Female No Sun Dinner 2
1 10.34 1.66 Male No Sun Dinner 3
2 21.01 3.50 Male No Sun Dinner 3
3 23.68 3.31 Male No Sun Dinner 2
4 24.59 3.61 Female No Sun Dinner 4
... ... ... ... ... ... ... ...
239 29.03 5.92 Male No Sat Dinner 3
240 27.18 2.00 Female Yes Sat Dinner 2
241 22.67 2.00 Male Yes Sat Dinner 2
242 17.82 1.75 Male No Sat Dinner 2
243 18.78 3.00 Female No Thur Dinner 2

244 rows × 7 columns

If you're not familiar with this dataset, all you need to know is that each row represents a meal at a restaurant, and the columns store the value of the total bill and the tip, plus some metadata about the customer - their sex, whether or not they were a smoker, what day and time they ate at, and the size of their party. Also, notice that we have 244 rows - this will be important later on.
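If you want to double-check that count, shape gives the number of rows and columns directly:

df.shape

(244, 7)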

What was the average total bill on each day?

To answer this, let's imagine that we have already figured out that we need to group by day:

df.groupby('day')

Now what's the next step? Use the trick I just described and start by imagining what we want the output to look like. We want a single value for each group, so we need to use aggregate():

df.groupby('day').aggregate('mean')
total_bill tip size
day
Fri 17.151579 2.734737 2.105263
Sat 20.441379 2.993103 2.517241
Sun 21.410000 3.255132 2.842105
Thur 17.682742 2.771452 2.451613

We're only interested in the total_bill column, so we can select it (either before or after we do the aggregation):

df.groupby('day')['total_bill'].aggregate('mean')

day
Fri     17.151579
Sat     20.441379
Sun     21.410000
Thur    17.682742
Name: total_bill, dtype: float64
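Selecting after the aggregation works too. On recent pandas versions we need to spell out the numeric columns ourselves (older versions silently dropped the non-numeric ones, as in the output above), so something like this should give the same series:

# select the column after aggregating; the explicit column list avoids errors
# on recent pandas versions, which no longer drop non-numeric columns silently
df.groupby('day')[['total_bill', 'tip', 'size']].aggregate('mean')['total_bill']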

Pandas has lots of shortcuts for the various ways to aggregate group values - we could use mean() here instead:

df.groupby('day')['total_bill'].mean()

day
Fri     17.151579
Sat     20.441379
Sun     21.410000
Thur    17.682742
Name: total_bill, dtype: float64
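These shortcuts also make it easy to ask for several aggregate measurements at once by passing a list of names, which should give a small dataframe with one column per measurement:

# several aggregations at once: one column per measurement, one row per day
df.groupby('day')['total_bill'].agg(['mean', 'count'])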

Which meals were eaten on days where the average bill was greater than 20?

For this question, think again about the output we want - our goal here is to get a subset of the original rows, so this is a job for filter(). The argument to filter() must be a function or lambda that will take a group and return True or False to determine whether rows belonging to that group should be included in the output. Here's how we might do it with a lambda:

df.groupby('day').filter(
    lambda x : x['total_bill'].mean() > 20
    )
total_bill tip sex smoker day time size
0 16.99 1.01 Female No Sun Dinner 2
1 10.34 1.66 Male No Sun Dinner 3
2 21.01 3.50 Male No Sun Dinner 3
3 23.68 3.31 Male No Sun Dinner 2
4 24.59 3.61 Female No Sun Dinner 4
... ... ... ... ... ... ... ...
238 35.83 4.67 Female No Sat Dinner 3
239 29.03 5.92 Male No Sat Dinner 3
240 27.18 2.00 Female Yes Sat Dinner 2
241 22.67 2.00 Male Yes Sat Dinner 2
242 17.82 1.75 Male No Sat Dinner 2

163 rows × 7 columns

Notice that our output dataframe has only 163 rows (compared to the 244 that we started with), and that the columns are exactly the same as the input.
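By the way, there's nothing special about the lambda here - a named function works just as well (the function name below is just for illustration):

# a named function does the same job as the lambda above
def has_high_average_bill(group):
    # each group is the dataframe of rows for one day
    return group['total_bill'].mean() > 20

df.groupby('day').filter(has_high_average_bill)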

Compared to our first example, it's a bit harder to see why this is useful - typically we'll do a filter like this and then follow it up with another operation. For example, we might want to compare the average party size on days where the average bill is high:

df.groupby('day').filter(
    lambda x : x['total_bill'].mean() > 20
)['size'].mean()

2.6687116564417179

with the average party size on days where the average bill is low:

df.groupby('day').filter(
    lambda x : x['total_bill'].mean() <= 20
)['size'].mean()

2.3703703703703702

Incidentally, a question that I'm often asked is what the type of the argument to the lambda is - what actually is the variable x in our examples above? We can find out by passing a lambda that just prints the type of its input:

df.groupby('day').filter(lambda x: print(type(x)))


<class 'pandas.core.frame.DataFrame'>
<class 'pandas.core.frame.DataFrame'>
<class 'pandas.core.frame.DataFrame'>
<class 'pandas.core.frame.DataFrame'>
total_bill tip sex smoker day time size

And we see that each group is passed to our lambda function as a Pandas DataFrame, so we already know how to work with it. (The filter itself comes back empty, since our print-only lambda never returns True for any group.)
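Another way to peek at the groups is to loop over the groupby object directly; each iteration should yield the group's key (here, the day) together with the dataframe of rows belonging to that group:

# each iteration gives the group key and that group's rows as a DataFrame
for day, group in df.groupby('day'):
    print(day, type(group), len(group))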

How did the cost of each meal compare to the average for the day?

This last example is the trickiest to understand, but remember our trick - start by thinking about the desired output. In this case we are trying to generate a new value for each input row - the total bill divided by the average total bill for each day. (If you have a scientific or maths background then you might think of this as a normalized or scaled total bill). To make a new value for each row, we use transform().

To start with, let's see what happens when we pass in a lambda to transform() that just gives us the mean of its input:

df.groupby('day').transform(lambda x : x.mean())
total_bill tip size
0 21.410000 3.255132 2.842105
1 21.410000 3.255132 2.842105
2 21.410000 3.255132 2.842105
3 21.410000 3.255132 2.842105
4 21.410000 3.255132 2.842105
... ... ... ...
239 20.441379 2.993103 2.517241
240 20.441379 2.993103 2.517241
241 20.441379 2.993103 2.517241
242 20.441379 2.993103 2.517241
243 17.682742 2.771452 2.451613

244 rows × 3 columns

Notice that we get the same number of output rows as input rows - Pandas has calculated the mean for each group, then used the results as the new values for each row. We're only interested in the total bill, so let's get rid of the other columns:

df.groupby('day')['total_bill'].transform(lambda x : x.mean())

0      21.410000
1      21.410000
2      21.410000
3      21.410000
4      21.410000
         ...
239    20.441379
240    20.441379
241    20.441379
242    20.441379
243    17.682742
Name: total_bill, Length: 244, dtype: float64

This gives us a series with the same number of rows as our input data. We could assign this to a new column in our dataframe:

df['day_average'] = df.groupby('day')['total_bill'].transform(
    lambda x : x.mean()
)
df
total_bill tip sex smoker day time size day_average
0 16.99 1.01 Female No Sun Dinner 2 21.410000
1 10.34 1.66 Male No Sun Dinner 3 21.410000
2 21.01 3.50 Male No Sun Dinner 3 21.410000
3 23.68 3.31 Male No Sun Dinner 2 21.410000
4 24.59 3.61 Female No Sun Dinner 4 21.410000
... ... ... ... ... ... ... ... ...
239 29.03 5.92 Male No Sat Dinner 3 20.441379
240 27.18 2.00 Female Yes Sat Dinner 2 20.441379
241 22.67 2.00 Male Yes Sat Dinner 2 20.441379
242 17.82 1.75 Male No Sat Dinner 2 20.441379
243 18.78 3.00 Female No Thur Dinner 2 17.682742

244 rows × 8 columns

Which would allow us to calculate the scaled total bills:

df['total_bill'] / df['day_average']

0      0.793554
1      0.482952
2      0.981317
3      1.106025
4      1.148529
         ...
239    1.420159
240    1.329656
241    1.109025
242    0.871761
243    1.062052
Length: 244, dtype: float64

But we could also calculate the scaled bill as part of the transform:

df['scaled bill'] = df.groupby('day')['total_bill'].transform(
    lambda x : x/x.mean()
)
df.head()
total_bill tip sex smoker day time size day_average scaled bill
0 16.99 1.01 Female No Sun Dinner 2 21.41 0.793554
1 10.34 1.66 Male No Sun Dinner 3 21.41 0.482952
2 21.01 3.50 Male No Sun Dinner 3 21.41 0.981317
3 23.68 3.31 Male No Sun Dinner 2 21.41 1.106025
4 24.59 3.61 Female No Sun Dinner 4 21.41 1.148529
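As a quick sanity check, the two routes should give the same numbers - and note that transform() also accepts the name of an aggregation as a string, so we could skip the lambda entirely (scaled_bill_2 is just an illustrative name):

import numpy as np

# the by-hand division and the transform-based column should match
np.allclose(df['total_bill'] / df['day_average'], df['scaled bill'])

# transform() also accepts a function name as a string
df['scaled_bill_2'] = df['total_bill'] / df.groupby('day')['total_bill'].transform('mean')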

In conclusion

All three of our examples began with exactly the same groupby() call:

df.groupby('day')['total_bill'].mean()
df.groupby('day').filter(lambda x : x['total_bill'].mean() > 20)
df.groupby('day')['total_bill'].transform(lambda x : x/x.mean())

but by doing different things with the resulting groups we get very different outputs. To reiterate (a quick check of the output sizes follows the list):

  • if we want to get a single value for each group -> use aggregate()
  • if we want to get a subset of the input rows -> use filter()
  • if we want to get a new value for each input row -> use transform()
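And checking the sizes of the three results against the numbers we saw above confirms the pattern:

# one value per day, a subset of the rows, and one value per row respectively
len(df.groupby('day')['total_bill'].mean())                             # 4
len(df.groupby('day').filter(lambda x: x['total_bill'].mean() > 20))    # 163
len(df.groupby('day')['total_bill'].transform(lambda x: x / x.mean()))  # 244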