Data visualization refers to the process (and result) of representing data graphically.
For our purposes today, we'll be talking mostly about common methods of plotting data, including:
Florence Nightingale (1820-1910) was a social reformer, statistician, and founder of modern nursing.
In today's tutorial, we'll discuss how to visualize data using Python, with matplotlib or seaborn. Here, we load the core packages we'll be using:
pandas: a library for loading, reshaping, and joining DataFrame objects.
matplotlib: Python's core plotting/graphics library.
seaborn: a library built on top of matplotlib, which makes certain plotting functions easier.
We also add some lines of code that make sure our visualizations will plot "inline" with our code, and that they'll have nice, crisp quality.
import pandas as pd # conventionalized abbreviation
import matplotlib.pyplot as plt # conventionalized abbreviation
import seaborn as sns # conventionalized abbreviation
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
matplotlib.pyplot
In this section, we'll explore several basic plot types, first using matplotlib.pyplot. These will include histograms, scatterplots, and barplots.
Afterwards (in Part 2), we will discuss how to replicate these plot types with seaborn, an API built on top of pyplot.
To get started, we'll load a dataset from seaborn.
## Load taxis dataset
df_taxis = sns.load_dataset("taxis")
len(df_taxis)
6433
## What are the columns/variables in this dataset?
df_taxis.head(5)
pickup | dropoff | passengers | distance | fare | tip | tolls | total | color | payment | pickup_zone | dropoff_zone | pickup_borough | dropoff_borough | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 2019-03-23 20:21:09 | 2019-03-23 20:27:24 | 1 | 1.60 | 7.0 | 2.15 | 0.0 | 12.95 | yellow | credit card | Lenox Hill West | UN/Turtle Bay South | Manhattan | Manhattan |
1 | 2019-03-04 16:11:55 | 2019-03-04 16:19:00 | 1 | 0.79 | 5.0 | 0.00 | 0.0 | 9.30 | yellow | cash | Upper West Side South | Upper West Side South | Manhattan | Manhattan |
2 | 2019-03-27 17:53:01 | 2019-03-27 18:00:25 | 1 | 1.37 | 7.5 | 2.36 | 0.0 | 14.16 | yellow | credit card | Alphabet City | West Village | Manhattan | Manhattan |
3 | 2019-03-10 01:23:59 | 2019-03-10 01:49:51 | 1 | 7.70 | 27.0 | 6.15 | 0.0 | 36.95 | yellow | credit card | Hudson Sq | Yorkville West | Manhattan | Manhattan |
4 | 2019-03-30 13:27:42 | 2019-03-30 13:37:14 | 3 | 2.16 | 9.0 | 1.10 | 0.0 | 13.40 | yellow | credit card | Midtown East | Yorkville West | Manhattan | Manhattan |
Histograms are critical for visualizing how your data distribute.
They are typically used for continuous, quantitative variables. Examples could include: height, income, and temperature.
The pyplot package has a dedicated function for creating histograms, called hist (documentation here).
An example call would look like this:
plt.hist(
x = [ ... ],
bins = 10, # set number of bins
density = False, # indicate whether to draw density plot
alpha = .7 # indicate shading/transparency of plot
)
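To make the density argument more concrete, here is a minimal sketch on synthetic data (the values array below is made up for illustration; it is not the taxis data). With density = False, the bar heights are counts and add up to the number of observations; with density = True, the heights are rescaled so that the total area of the bars is 1.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so the sketch also runs outside a notebook
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
values = rng.exponential(scale=10, size=1000)  # synthetic, right-skewed stand-in for fares

# plt.hist returns the bar heights, the bin edges, and the drawn patches
counts, edges, _ = plt.hist(values, bins=10, density=False, alpha=.7)
plt.close()
dens, dens_edges, _ = plt.hist(values, bins=10, density=True, alpha=.7)
plt.close()

total_count = counts.sum()                        # heights are raw counts
total_area = (dens * np.diff(dens_edges)).sum()   # height * bin width, summed over bins
```

Here total_count equals the number of observations, while total_area is 1 (up to floating-point error).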
Now let's plot some of our variables from the taxis
dataset!
## Q: How would you describe the shape of this distribution?
plt.hist(
x=df_taxis['fare'], # this is the variable we're plotting
bins=20, # this is the number of bins we want
alpha = .5 # this is our alpha level
)
plt.show()
We can also add axis labels, so it's clear what we're plotting.
plt.hist(
x=df_taxis['fare'], # this is the variable we're plotting
bins=20, # this is the number of bins we want
alpha = .5 # this is our alpha level
)
plt.xlabel("Fare ($)")
plt.ylabel("Count")
plt.title("Distribution of Taxi Fares ($)")
plt.show()
plt.hist(
x=df_taxis['distance'], # this is the variable we're plotting
bins=20, # this is the number of bins we want
alpha = .5 # this is our alpha level
)
plt.xlabel("Distance (miles)")
plt.ylabel("Count")
plt.title("Distribution of Trip Distances")
plt.show()
plt.hist(
x=df_taxis['tip'], # this is the variable we're plotting
bins=20, # this is the number of bins we want
alpha = .5 # this is our alpha level
)
plt.xlabel("Tip ($)")
plt.ylabel("Count")
plt.title("Distribution of Tip Amounts ($)")
plt.show()
Presumably, these tip amounts correlate with the overall fare. We might also want to know about the amount in percentage-terms, e.g., how much a customer tipped as a percentage of how much the fare was.
## tip_pct = tip / fare
df_taxis['tip_pct'] = (df_taxis['tip'] / df_taxis['fare']) * 100
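One thing to keep in mind with a ratio like this: if a row ever had a fare of 0, the division would produce an infinite value, which can distort a histogram. We haven't checked whether that happens in taxis, so the sketch below uses a small made-up DataFrame (df_demo is hypothetical) to show the behavior and one possible way to handle it.

```python
import numpy as np
import pandas as pd

# Made-up rows; the last one has a fare of 0
df_demo = pd.DataFrame({"fare": [10.0, 20.0, 0.0],
                        "tip": [2.0, 0.0, 1.0]})

df_demo["tip_pct"] = (df_demo["tip"] / df_demo["fare"]) * 100

# Treat infinite percentages as missing, so plotting functions can skip them
df_demo["tip_pct_clean"] = df_demo["tip_pct"].replace([np.inf, -np.inf], np.nan)
```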
Now we can plot this new variable, to get a sense of the distribution of tip percentages.
Q: What conclusions might we draw about the distribution of tip percentages here?
plt.hist(
x=df_taxis['tip_pct'], # this is the variable we're plotting
bins=20, # this is the number of bins we want
alpha = .5 # this is our alpha level
)
plt.xlabel("Tip (%)")
plt.ylabel("Count")
plt.title("Distribution of Tip Percentages (%)")
plt.show()
What if we wanted to add more information to this plot, such as the mean tip percentage?
(Hint: use either plt.axvline or plt.text).
# First, calculate mean tip for ease of access
MEAN_TIP = df_taxis['tip_pct'].mean()
# Drawing a vertical line
plt.hist(
x=df_taxis['tip_pct'], # this is the variable we're plotting
bins=20, # this is the number of bins we want
alpha = .5 # this is our alpha level
)
plt.axvline(MEAN_TIP, linestyle = "dotted")
plt.xlabel("Tip (%)")
plt.ylabel("Count")
plt.title("Distribution of Tip Percentages (%)")
plt.show()
# Annotating with text
plt.hist(
x=df_taxis['tip_pct'], # this is the variable we're plotting
bins=20, # this is the number of bins we want
alpha = .5 # this is our alpha level
)
plt.text(MEAN_TIP, # where on x-axis to display
y = 1500, # where on y-axis to display
s = "Mean Tip: {mean}%".format(mean=str(round(MEAN_TIP, 2))))
plt.xlabel("Tip (%)")
plt.ylabel("Count")
plt.title("Distribution of Tip Percentages (%)")
plt.show()
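The two approaches can also be combined. Below is a rough sketch that draws the dotted line and the text label on the same axes, and places the label relative to the top of the y-axis rather than at a hard-coded height. The data are synthetic (tip_pct_demo is a made-up stand-in for tip_pct), so treat this as an illustration rather than a result.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so the sketch also runs outside a notebook
import matplotlib.pyplot as plt

rng = np.random.default_rng(1)
tip_pct_demo = rng.gamma(shape=2.0, scale=8.0, size=500)  # synthetic stand-in for tip_pct
mean_tip = tip_pct_demo.mean()

fig, ax = plt.subplots()
ax.hist(tip_pct_demo, bins=20, alpha=.5)
ax.axvline(mean_tip, linestyle="dotted")   # vertical line at the mean

y_top = ax.get_ylim()[1]                   # current upper limit of the y-axis
label = ax.text(mean_tip, 0.9 * y_top,     # place the label near the top of the plot
                f" Mean Tip: {mean_tip:.2f}%")
ax.set_xlabel("Tip (%)")
ax.set_ylabel("Count")
```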
Scatterplots can be used to look at how two different continuous variables relate to each other.
Among other things, scatterplots are critical for exploratory data analysis––for example, if you want to run a linear regression on your dataset, it is important to determine that the relationship you're investigating is linear (remember Anscombe's Quartet).
The pyplot package has a dedicated function for creating scatterplots as well, called scatter (documentation here).
An example call would look like this:
plt.scatter(
x = [ ... ], # array of values for x-axis
y = [ ... ], # array of values for y-axis
c = [ ... ], # array of values to indicate color of points
alpha = .7 # indicate shading/transparency of plot
)
Intuitively, we might want to know a few things:
# Q: Do longer distances cost more?
plt.scatter(
x = df_taxis['distance'],
y = df_taxis['fare'],
alpha = .6
)
plt.xlabel("Distance (miles)")
plt.ylabel("Fare ($)")
plt.title("Raw Fare by Distance Traveled")
plt.show()
plt.scatter(
x = df_taxis['distance'],
y = df_taxis['tip'],
alpha = .6
)
plt.xlabel("Distance (miles)")
plt.ylabel("Tip ($)")
plt.title("Tip Amount by Distance Traveled")
plt.show()
plt.scatter(
x = df_taxis['distance'],
y = df_taxis['tip_pct'],
alpha = .4
)
plt.xlabel("Distance (miles)")
plt.ylabel("Tip (%)")
plt.title("Tip Percentage by Distance Traveled")
plt.show()
## Create a new variable
df_taxis['crossed_borough'] = df_taxis['pickup_borough'] != df_taxis['dropoff_borough']
df_taxis.head(1)
pickup | dropoff | passengers | distance | fare | tip | tolls | total | color | payment | pickup_zone | dropoff_zone | pickup_borough | dropoff_borough | tip_pct | crossed_borough | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 2019-03-23 20:21:09 | 2019-03-23 20:27:24 | 1 | 1.6 | 7.0 | 2.15 | 0.0 | 12.95 | yellow | credit card | Lenox Hill West | UN/Turtle Bay South | Manhattan | Manhattan | 30.714286 | False |
Now, we want to add more information to our scatterplot: we have a categorical variable (crossed_borough), and we want the color (or perhaps shape) of each point to indicate the level of that variable, i.e., whether the trip crossed between boroughs.
cdict = {True: 'orange', False: 'blue'}
fig, ax = plt.subplots()
for group in set(df_taxis['crossed_borough']):
    df_tmp = df_taxis[df_taxis['crossed_borough']==group] # filter to appropriate group
    ax.scatter(
        x = df_tmp['distance'],
        y = df_tmp['fare'],
        label = group,
        c = cdict[group],
        alpha = .3
    )
plt.xlabel("Distance")
plt.ylabel("Fare ($)")
ax.legend(title = "Crossed Borough")
plt.show()
cdict = {True: 'orange', False: 'blue'}
fig, ax = plt.subplots()
for group in set(df_taxis['crossed_borough']):
    df_tmp = df_taxis[df_taxis['crossed_borough']==group] # filter to appropriate group
    ax.scatter(
        x = df_tmp['distance'],
        y = df_tmp['tip_pct'],
        label = group,
        c = cdict[group],
        alpha = .3
    )
plt.xlabel("Distance")
plt.ylabel("Tip (%)")
ax.legend(title = "Crossed Borough")
plt.show()
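As an aside, the loop isn't the only way to color points by group. One alternative (a sketch on a tiny made-up DataFrame, not the taxis data) is to map each row to a color up front and pass the whole list to a single scatter call. The trade-off is that a single call doesn't give us one labeled handle per group, so we lose the automatic legend that the loop provides.

```python
import pandas as pd
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so the sketch also runs outside a notebook
import matplotlib.pyplot as plt

# Made-up rows for illustration
df_demo = pd.DataFrame({
    "distance": [1.0, 2.5, 7.0, 12.0],
    "fare": [6.0, 10.5, 25.0, 40.0],
    "crossed_borough": [False, False, True, True],
})

cdict = {True: "orange", False: "blue"}
# One color name per row, looked up from the group value
point_colors = df_demo["crossed_borough"].map(cdict).tolist()

fig, ax = plt.subplots()
points = ax.scatter(x=df_demo["distance"], y=df_demo["fare"],
                    c=point_colors, alpha=.3)
```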
Another commonly found type of plot is a barplot.
Typically, barplots feature:
Barplots are effective for drawing attention to differences in the magnitude of some quantity across categories––though personally, I don't think they're always the most effective in terms of the amount of information conveyed per unit of space.
pyplot
The corresponding pyplot function is bar. We saw this earlier, when we were trying to plot the counts of a categorical variable.
The basic inputs are as follows:
plt.bar(
x = [ ... ], # array of categories
height = [ ... ] # height of each category
)
groupby
Once again, groupby is very helpful here. We'll want to:
Calculate the mean tip_pct for each level of the color variable.
Pass those means into plt.bar.
df_color_tips_mean = df_taxis[['tip_pct', 'color']].groupby('color').mean().reset_index()
df_color_tips_mean
color | tip_pct | |
---|---|---|
0 | green | 7.441110 |
1 | yellow | 18.627966 |
# Q: What's missing from this graph?
plt.bar(x = df_color_tips_mean['color'],
height = df_color_tips_mean['tip_pct'])
plt.xlabel("Taxi Color")
plt.ylabel("Tip (%)")
plt.title("Tip (%) by Taxi Color")
To add error bars, we can use plt.errorbar.
## First, calculate standard error of the mean
df_color_tips_sem = df_taxis[['tip_pct', 'color']].groupby('color').sem().reset_index()
df_color_tips_sem
color | tip_pct | |
---|---|---|
0 | green | 0.384112 |
1 | yellow | 0.194550 |
plt.bar(x = df_color_tips_mean['color'],
height = df_color_tips_mean['tip_pct'])
plt.errorbar(x = df_color_tips_mean['color'], ## as with bar, the x-axis is the *color*
y = df_color_tips_mean['tip_pct'], ## as with bar, the y-axis is the *mean*
yerr = df_color_tips_sem['tip_pct'] * 2, ## two standard errors, as typical
ls = 'none', ## toggle this to connect or not connect the lines
color = "black")
plt.xlabel("Taxi Color")
plt.ylabel("Tip (%)")
plt.title("Tip (%) by Taxi Color")
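For what it's worth, the bar function also accepts a yerr argument, so the mean and the error bars can be drawn in one call. The sketch below uses a small made-up DataFrame (df_demo and its values are hypothetical) and computes the mean and standard error together with a single groupby.

```python
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so the sketch also runs outside a notebook
import matplotlib.pyplot as plt

# Made-up tip percentages for two taxi colors
df_demo = pd.DataFrame({
    "color": ["green"] * 4 + ["yellow"] * 4,
    "tip_pct": [0.0, 10.0, 5.0, 15.0, 20.0, 18.0, 25.0, 17.0],
})

# Mean and standard error of the mean, per color, in one step
summary = df_demo.groupby("color")["tip_pct"].agg(["mean", "sem"]).reset_index()

fig, ax = plt.subplots()
bars = ax.bar(x=summary["color"],
              height=summary["mean"],
              yerr=2 * summary["sem"],  # two standard errors, as above
              capsize=4)                # small caps on the error bars
ax.set_xlabel("Taxi Color")
ax.set_ylabel("Tip (%)")
```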
pyplot affordances
This short tutorial really only scratched the surface of what you can do with pyplot.
Just to give a sense:
pyplot.text
seaborn
In this section, we'll replicate some of those same plots from above, but now using seaborn.
As you'll see, seaborn offers an API that's (in most cases) cleaner and easier to use than pyplot, though there's a trade-off in terms of flexibility.
Depending on your version of seaborn, you can plot a histogram using either:
seaborn.distplot: typically used for creating density plots, but can be adapted for histograms. Note that distplot has been deprecated since seaborn 0.11.
seaborn.histplot: targeted more specifically at histograms (available in seaborn 0.11 and later).
sns.distplot(df_taxis['distance'], ## vector to plot
kde = False, ## whether to use kernel density estimation to fit density curve
norm_hist = False ## whether to norm histogram to display probability vs. counts
)
plt.xlabel("Distance (miles)")
plt.ylabel("Count")
plt.title("Distribution of miles traveled")
Fortunately, seaborn makes it much easier to create scatterplots, especially if you'd like to change the hue or style of the points.
sns.scatterplot(
data = ..., # name of dataframe
x = ..., # name of x-axis variable
y = ..., # name of y-axis variable
hue = ..., # name of hue variable
style = ..., # name of style variable
size = ..., # name of variable to modulate point size
)
# Replicating our initial plot: Do longer distances cost more?
sns.scatterplot(data = df_taxis,
x = "distance",
y = "fare",
alpha = .7)
plt.xlabel("Distance (miles)")
plt.ylabel("Fare ($)")
plt.title("Raw Fare by Distance Traveled")
plt.show()
# Same plot, now with hue indicating whether the trip crossed boroughs
sns.scatterplot(data = df_taxis,
x = "distance",
y = "fare",
hue = "crossed_borough",
alpha = .7)
plt.xlabel("Distance (miles)")
plt.ylabel("Fare ($)")
plt.title("Raw Fare by Distance Traveled")
plt.show()
Finally, we can reconstruct some of the bar plots we created above, using the barplot function.
sns.barplot(
data = ..., # name of dataframe
x = ..., # name of x-axis variable (categorical)
y = ..., # name of y-axis variable (continuous)
hue = ..., # name of hue variable
ci = float # size of confidence intervals to draw (these are estimated using *bootstrapping*, not SEM)
)
sns.barplot(data = df_taxis,
x = "color",
y = "tip_pct",
ci = 95
)
plt.xlabel("Color")
plt.ylabel("Tip (%)")
plt.title("Tip (%) by Taxi Color")
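Since these confidence intervals come from bootstrapping rather than the standard error, it may help to see the general idea written out. The sketch below is a plain percentile bootstrap of the mean on synthetic data (tips_demo is made up); seaborn's own implementation may differ in its details, so take this as an illustration of the technique only.

```python
import numpy as np

rng = np.random.default_rng(3)
tips_demo = rng.gamma(shape=2.0, scale=8.0, size=200)  # synthetic tip percentages

n_boot = 2000
# Resample the data with replacement many times, recording the mean each time
boot_means = np.array([
    rng.choice(tips_demo, size=tips_demo.size, replace=True).mean()
    for _ in range(n_boot)
])

# The middle 95% of the resampled means gives the interval
ci_low, ci_high = np.percentile(boot_means, [2.5, 97.5])
sample_mean = tips_demo.mean()
```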
## Same as above, but with hues corresponding to whether the trip crossed boroughs
sns.barplot(data = df_taxis,
x = "color",
y = "tip_pct",
hue = "crossed_borough",
ci = 95
)
plt.xlabel("Color")
plt.ylabel("Tip (%)")
plt.legend(title = "Crossed Borough")
taxis
This dataset is quite rich, and you could ask a number of other questions.
Just to get you started:
Does distance traveled vary as a function of the hour of pickup?
Is tip_pct larger as a function of the hour of pickup?
How does distance relate to the amount of time between pickup and dropoff?
... distance)?
Data visualizations are increasingly common in blog posts, news articles, or even social media.
Reconstructing a visualization you find "in the wild" can reveal some of the hidden complexities of this process.
To that end, I reached out to Timothy Lee, author of a very interesting blog post found on Full Stack Economics, entitled "18 charts that explain the American economy". He was kind enough to give me the data he used to construct two of the charts included in that post: chart #7 and chart #14.
In this reconstruction, let's focus on Chart 14.
This chart depicts the relative magnitudes of different causes for missing work, over time. For me, there are a couple takeaways:
Let's try to reconstruct the chart from scratch!
Just for reference, the chart looks like this:
First, we load and inspect the data. Here, we have to actually read in a .csv file using pandas.read_csv.
Question: What do you all notice about the way this dataframe is structured?
df_missing = pd.read_csv("data/missing_work.csv")
len(df_missing)
110
# Inspecting the dataset
df_missing.head(5)
Year | Child care problems | Maternity or paternity leave | Other family or personal obligations | Illness or injury | Vacation | Month | |
---|---|---|---|---|---|---|---|
0 | 2012 | 18 | 313 | 246 | 899 | 1701 | 10 |
1 | 2012 | 35 | 278 | 230 | 880 | 1299 | 11 |
2 | 2012 | 13 | 245 | 246 | 944 | 1005 | 12 |
3 | 2013 | 14 | 257 | 250 | 1202 | 1552 | 1 |
4 | 2013 | 27 | 258 | 276 | 1079 | 1305 | 2 |
## Note: this is just one of the many approaches!
df_missing_long = pd.melt(df_missing, # dataframe
value_vars=['Child care problems', # which columns to stack into long format
'Maternity or paternity leave',# which columns to stack into long format
'Other family or personal obligations', # which columns to stack into long format
'Illness or injury', # which columns to stack into long format
'Vacation'], # which columns to stack into long format
id_vars=["Year", "Month"], # which columns to group/id them by
var_name = "Cause", # what to call the new grouping column
value_name = "Days" # what to call the new value column
)
df_missing_long.head(5)
Year | Month | Cause | Days | |
---|---|---|---|---|
0 | 2012 | 10 | Child care problems | 18 |
1 | 2012 | 11 | Child care problems | 35 |
2 | 2012 | 12 | Child care problems | 13 |
3 | 2013 | 1 | Child care problems | 14 |
4 | 2013 | 2 | Child care problems | 27 |
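If the wide-to-long reshaping feels abstract, here is a tiny sketch of the same melt call on just two rows and two of the cause columns (the numbers are copied from the first two rows of the table above; df_wide and df_long are names made up for this example). Each wide row turns into one long row per cause column.

```python
import pandas as pd

# Two rows and two cause columns from the wide-format table above
df_wide = pd.DataFrame({
    "Year": [2012, 2012],
    "Month": [10, 11],
    "Vacation": [1701, 1299],
    "Illness or injury": [899, 880],
})

df_long = pd.melt(df_wide,
                  id_vars=["Year", "Month"],                     # columns that identify a row
                  value_vars=["Vacation", "Illness or injury"],  # columns to stack
                  var_name="Cause",
                  value_name="Days")
```

With 2 rows and 2 stacked columns, df_long ends up with 2 x 2 = 4 rows.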
# Note: One other thing––the original graph aggregated data by the *millions*.
# This is already in the *thousands*, so we'll divide by another 1000 to match the original.
df_missing_long['Days_millions'] = df_missing_long['Days']/1000
Before we try to reconstruct the plot, let's first look at the distribution of missed days across years, to see if we can detect any patterns.
sns.lineplot(data = df_missing_long,
x = 'Month',
y = 'Days_millions',
alpha = .7)
plt.xlabel("Month")
plt.ylabel("Days (Millions)")
# This shows us that it's *vacation* specifically that has a seasonal pattern
sns.lineplot(data = df_missing_long,
x = 'Month',
y = 'Days_millions',
hue = 'Cause',
alpha = .7)
# This makes sure the legend doesn't cover up the plot itself
plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
plt.xlabel("Month")
plt.ylabel("Days (Millions)")
Now, rather than averaging across years to get a picture of the "average year", let's just look at this over time.
One issue is that we'll first need to convert each observation (a Month and a Year) into a datetime object, so we can plug this into pandas and seaborn native datetime operations.
## First, let's concatenate each month and year into a single string
df_missing_long['date'] = df_missing_long.apply(lambda row: str(row['Month']) + '-' + str(row['Year']), axis = 1)
df_missing_long.head(5)
Year | Month | Cause | Days | Days_millions | date | |
---|---|---|---|---|---|---|
0 | 2012 | 10 | Child care problems | 18 | 0.018 | 10-2012 |
1 | 2012 | 11 | Child care problems | 35 | 0.035 | 11-2012 |
2 | 2012 | 12 | Child care problems | 13 | 0.013 | 12-2012 |
3 | 2013 | 1 | Child care problems | 14 | 0.014 | 1-2013 |
4 | 2013 | 2 | Child care problems | 27 | 0.027 | 2-2013 |
## Now, let's create a new "datetime" column using the `pd.to_datetime` function
df_missing_long['datetime'] = pd.to_datetime(df_missing_long['date'], format = '%m-%Y') # state the month-year format explicitly
df_missing_long.head(5)
Year | Month | Cause | Days | Days_millions | date | datetime | |
---|---|---|---|---|---|---|---|
0 | 2012 | 10 | Child care problems | 18 | 0.018 | 10-2012 | 2012-10-01 |
1 | 2012 | 11 | Child care problems | 35 | 0.035 | 11-2012 | 2012-11-01 |
2 | 2012 | 12 | Child care problems | 13 | 0.013 | 12-2012 | 2012-12-01 |
3 | 2013 | 1 | Child care problems | 14 | 0.014 | 1-2013 | 2013-01-01 |
4 | 2013 | 2 | Child care problems | 27 | 0.027 | 2-2013 | 2013-02-01 |
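As a side note, the intermediate string column isn't strictly required. The to_datetime function can also assemble dates from a DataFrame with columns named year, month, and day. Here's a minimal sketch on a made-up three-row DataFrame (df_demo is hypothetical), where we set the day to 1 for every row, matching the output above.

```python
import pandas as pd

df_demo = pd.DataFrame({"Year": [2012, 2012, 2013],
                        "Month": [10, 11, 1]})

# to_datetime expects the component columns to be named year, month, and day
parts = (df_demo
         .rename(columns={"Year": "year", "Month": "month"})
         .assign(day=1))  # treat each observation as the first of the month

df_demo["datetime"] = pd.to_datetime(parts[["year", "month", "day"]])
```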
sns.lineplot(data = df_missing_long,
x = 'datetime',
y = 'Days_millions',
hue = 'Cause')
plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
plt.xlabel("Year")
plt.ylabel("Days (Millions)")
plt.title("Reasons for missing work")
If you look at the original plot carefully, you'll see that the observations have been smoothed using a rolling window average.
Fortunately, we can do the same thing using the rolling method in pandas.
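## Construct a 3-month rolling average to "smooth" the data,
## computed separately within each cause so that no window spans two causes
df_missing_long['Days_rolling_avg'] = df_missing_long.groupby('Cause')['Days_millions'].transform(lambda s: s.rolling(3).mean())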
sns.lineplot(data = df_missing_long,
x = 'datetime',
y = 'Days_rolling_avg',
hue = 'Cause')
plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
plt.xlabel("Year")
plt.ylabel("Days (Millions)")
plt.title("Reasons for missing work (smoothed)")
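One thing to watch out for when smoothing long-format data: the causes are stacked on top of each other, so a rolling window applied to the whole column will, at each boundary, average the last rows of one cause with the first rows of the next. The toy example below (df_demo and its two causes "A" and "B" are made up) contrasts a plain rolling mean with one computed within each cause.

```python
import numpy as np
import pandas as pd

# Two causes stacked in long format, three months each
df_demo = pd.DataFrame({
    "Cause": ["A"] * 3 + ["B"] * 3,
    "Days": [1.0, 2.0, 3.0, 10.0, 20.0, 30.0],
})

# Plain rolling mean: the window slides across the A/B boundary
df_demo["plain"] = df_demo["Days"].rolling(3).mean()

# Grouped rolling mean: the window restarts for each cause
df_demo["by_cause"] = (df_demo
                       .groupby("Cause")["Days"]
                       .transform(lambda s: s.rolling(3).mean()))
```

In the plain column, the first row of cause B averages (2 + 3 + 10) / 3 = 5, mixing the two causes; in the by_cause column that row is missing instead, because cause B doesn't yet have three observations.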
## Discussion: Should we limit y-axis/causes to investigate trends in other causes, like child care problems?
## Q: How could we do this?
sns.lineplot(data = df_missing_long,
x = 'datetime',
y = 'Days_millions',
hue = 'Cause',
alpha = .5)
# An easy way to do this is just limit the extent of the y-axis
plt.ylim(0, .5)
plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
plt.xlabel("Year")
plt.ylabel("Days (Millions)")
plt.title("Reasons for missing work")
As discussed in the slides, one key motivation for making data visualizations is exploratory data analysis (EDA). Among other things, EDA can be used to ensure that the assumptions of a statistical test are met (e.g., that a relationship in our data is linear, if one is running linear regression).
df_anscombe = sns.load_dataset("anscombe")
df_anscombe.head(5)
dataset | x | y | |
---|---|---|---|
0 | I | 10.0 | 8.04 |
1 | I | 8.0 | 6.95 |
2 | I | 13.0 | 7.58 |
3 | I | 9.0 | 8.81 |
4 | I | 11.0 | 8.33 |
Just to make sure the central claim is correct, let's calculate Pearson's r for each dataset. Is it roughly the same in each?
from scipy.stats import pearsonr
corrs = []
for dataset in set(df_anscombe['dataset']):
    df_tmp = df_anscombe[df_anscombe['dataset']==dataset]
    r = round(pearsonr(df_tmp['x'], df_tmp['y'])[0], 4)
    corrs.append({
        'r': r,
        'dataset': dataset
    })
df_corrs = pd.DataFrame(corrs)
df_corrs
r | dataset | |
---|---|---|
0 | 0.8163 | III |
1 | 0.8164 | I |
2 | 0.8162 | II |
3 | 0.8165 | IV |
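The same kind of check can be done without scipy, using the corr method in pandas. So that the sketch runs without downloading anything, it hard-codes datasets I and II; I transcribed these values by hand from the commonly published version of Anscombe's Quartet, so treat the numbers as an assumption and prefer sns.load_dataset when it's available.

```python
import pandas as pd

# Datasets I and II of Anscombe's Quartet (hand-transcribed; they share the same x values)
x = [10, 8, 13, 9, 11, 14, 6, 4, 12, 7, 5]
y1 = [8.04, 6.95, 7.58, 8.81, 8.33, 9.96, 7.24, 4.26, 10.84, 4.82, 5.68]
y2 = [9.14, 8.14, 8.74, 8.77, 9.26, 8.10, 6.13, 3.10, 9.13, 7.26, 4.74]

df_demo = pd.DataFrame({"dataset": ["I"] * 11 + ["II"] * 11,
                        "x": x + x,
                        "y": y1 + y2})

rows = []
for name, df_tmp in df_demo.groupby("dataset"):
    rows.append({"dataset": name,
                 "mean_y": df_tmp["y"].mean(),
                 "r": df_tmp["x"].corr(df_tmp["y"])})  # Pearson's r by default
summary = pd.DataFrame(rows).set_index("dataset")
```

Both the mean of y and Pearson's r come out nearly identical across the two datasets, even though the scatterplots look very different.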
First, I'm going to plot the data just using seaborn.scatterplot.
# Q: What's wrong with this approach?
sns.scatterplot(data = df_anscombe,
x = 'x',
y = 'y')
plt.xlabel("X")
plt.ylabel("Y")
FacetGrid
The above plot lumps all the datasets together into the same plot. But we want to visualize them all separately.
For this, we can rely on FacetGrid.
The FacetGrid object allows you to map a particular plot aesthetic (i.e., a particular x/y axis choice) onto multiple subplots, broken up by the levels of a categorical variable.
For example, you could create different subplots for payment type (cash vs. credit card), each showing the same relationship between distance and fare.
g = sns.FacetGrid(df_anscombe, col="dataset")
g.map(sns.scatterplot, "x", "y")
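To get a feel for what faceting does, here is a rough hand-rolled version of the payment example in plain matplotlib: one subplot per level of the categorical variable, with shared axes. The data are made up (df_demo is hypothetical), and this is only meant to illustrate the idea, not to describe how FacetGrid is implemented.

```python
import pandas as pd
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so the sketch also runs outside a notebook
import matplotlib.pyplot as plt

# Made-up trips, two per payment type
df_demo = pd.DataFrame({
    "payment": ["cash", "cash", "credit card", "credit card"],
    "distance": [1.0, 3.0, 2.0, 8.0],
    "fare": [6.0, 12.0, 9.0, 28.0],
})

levels = sorted(df_demo["payment"].unique())
# One column of subplots per level, sharing both axes so panels are comparable
fig, axes = plt.subplots(1, len(levels), sharex=True, sharey=True, squeeze=False)

for ax, level in zip(axes[0], levels):
    df_tmp = df_demo[df_demo["payment"] == level]  # filter to this level
    ax.scatter(df_tmp["distance"], df_tmp["fare"])
    ax.set_title(f"payment = {level}")
    ax.set_xlabel("distance")
axes[0][0].set_ylabel("fare")
```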
lmplot
Another approach that allows us to facet the data is to use lmplot.
The benefit of lmplot is that it will also plot a regression line over our data, which, in this case, clearly demonstrates that across each dataset, the estimated slope is roughly the same.
sns.lmplot(data = df_anscombe,
x = 'x',
y = 'y',
col = 'dataset')
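We can also check the "roughly the same slope" claim numerically by fitting a straight line to each dataset ourselves. The sketch below uses numpy's polyfit on datasets I and II, hard-coded so that it runs offline (the values are hand-transcribed from the commonly published version of the quartet, so treat them as an assumption). This is an ordinary least-squares line fit, which should be comparable to the line lmplot draws by default.

```python
import numpy as np

# Datasets I and II of Anscombe's Quartet (hand-transcribed; same x values)
x = np.array([10, 8, 13, 9, 11, 14, 6, 4, 12, 7, 5], dtype=float)
y1 = np.array([8.04, 6.95, 7.58, 8.81, 8.33, 9.96, 7.24, 4.26, 10.84, 4.82, 5.68])
y2 = np.array([9.14, 8.14, 8.74, 8.77, 9.26, 8.10, 6.13, 3.10, 9.13, 7.26, 4.74])

# A degree-1 polynomial fit returns [slope, intercept]
slope_1, intercept_1 = np.polyfit(x, y1, 1)
slope_2, intercept_2 = np.polyfit(x, y2, 1)
```

Both fits give a slope of about 0.5 and an intercept of about 3, despite dataset II being clearly curved.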
This concludes the tutorial! Hopefully, this has equipped you with some of the tools necessary to visualize data using matplotlib or seaborn.
If you're interested in R, a tutorial on data visualization in R can be found here.