Introduction to Data Visualization in Python

Author: Sean Trott

The goal of this tutorial is to familiarize you with the basics of data visualization in Python.

Note that the slides version can be found here, and the repository itself can be found here.

What is data visualization?

Data visualization refers to the process (and result) of representing data graphically.

For our purposes today, we'll be talking mostly about common methods of plotting data, including:

  • Histograms
  • Scatterplots
  • Line plots
  • Bar plots

Why is data visualization important?

Exploratory data analysis

  • Checking assumptions: does my data look like how I expect it to look?
  • Generating hypotheses: what relationships can I discover in the data?

Communicating insights

  • Communicating information clearly, concisely, and accurately.

Impacting the world

  • Data visualizations can play (and have played) a big role in shaping policy, business decisions, and more.

Exploratory Data Analysis: Checking your assumptions

Anscombe's Quartet

[Figure: Anscombe's Quartet]

Communicating Insights

Reference: Full Stack Economics

[Figure: chart from Full Stack Economics]

Impacting the world

Florence Nightingale (1820-1910) was a social reformer, statistician, and founder of modern nursing.

[Figure: visualization by Florence Nightingale]

So how do we visualize data?

In today's tutorial, we'll discuss how to visualize data using Python.

Key learning outcomes

  • Describe the benefits and advantages of data visualization.
  • Propose a suitable data visualization for a particular dataset and question.
  • Implement a first draft of this visualization in Python, using either matplotlib or seaborn.
  • Evaluate this data visualization and make suggestions for improvements in future iterations.
  • Reconstruct a visualization found "in the wild" using the appropriate dataset.

Loading packages

Here, we load the core packages we'll be using:

  • pandas: a library for loading, reshaping, and joining DataFrame objects.
  • matplotlib: Python's core plotting/graphics library.
  • seaborn: a library built on top of matplotlib, which makes certain plotting functions easier.

We also add some lines of code that make sure our visualizations will plot "inline" with our code, and that they'll have nice, crisp quality.

In [1]:
import pandas as pd # conventionalized abbreviation
import matplotlib.pyplot as plt # conventionalized abbreviation
import seaborn as sns # conventionalized abbreviation
In [3]:
%matplotlib inline 
%config InlineBackend.figure_format = 'retina'
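
The two magics above only work inside IPython/Jupyter. If you run this code as a plain Python script instead, a rough equivalent is to raise the figure resolution through matplotlib's rcParams. This is only a sketch: the DPI values below are arbitrary choices, not settings taken from this tutorial.

```python
import matplotlib.pyplot as plt

# Assumed values: pick whatever resolution looks crisp on your display
plt.rcParams["figure.dpi"] = 150   # resolution of figures shown on screen
plt.rcParams["savefig.dpi"] = 300  # resolution of figures written with plt.savefig
```

Outside a notebook, remember that figures only appear once you call plt.show() or save them with plt.savefig.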

Part 1: Basic plot types with matplotlib.pyplot

In this section, we'll explore several basic plot types, first using matplotlib.pyplot. These will include:

  • Histograms
  • Scatterplots
  • Barplots

Afterwards (in Part 2), we will discuss how to replicate these plot types with seaborn, an API built on top of pyplot.

Load dataset

To get started, we'll load a dataset from seaborn.

In [4]:
## Load taxis dataset
df_taxis = sns.load_dataset("taxis")
len(df_taxis)
Out[4]:
6433
In [5]:
## What are the columns/variables in this dataset?
df_taxis.head(5)
Out[5]:
   pickup               dropoff              passengers  distance  fare  tip   tolls  total  color   payment      pickup_zone            dropoff_zone           pickup_borough  dropoff_borough
0  2019-03-23 20:21:09  2019-03-23 20:27:24  1           1.60      7.0   2.15  0.0    12.95  yellow  credit card  Lenox Hill West        UN/Turtle Bay South    Manhattan       Manhattan
1  2019-03-04 16:11:55  2019-03-04 16:19:00  1           0.79      5.0   0.00  0.0    9.30   yellow  cash         Upper West Side South  Upper West Side South  Manhattan       Manhattan
2  2019-03-27 17:53:01  2019-03-27 18:00:25  1           1.37      7.5   2.36  0.0    14.16  yellow  credit card  Alphabet City          West Village           Manhattan       Manhattan
3  2019-03-10 01:23:59  2019-03-10 01:49:51  1           7.70      27.0  6.15  0.0    36.95  yellow  credit card  Hudson Sq              Yorkville West         Manhattan       Manhattan
4  2019-03-30 13:27:42  2019-03-30 13:37:14  3           2.16      9.0   1.10  0.0    13.40  yellow  credit card  Midtown East           Yorkville West         Manhattan       Manhattan

Histograms: plotting frequency distributions

Histograms are critical for visualizing how your data are distributed.

They are typically used for continuous, quantitative variables. Examples could include: height, income, and temperature.

The pyplot package has a dedicated function for creating histograms, called hist (documentation here).

An example call would look like this:

plt.hist(
    x = [ ... ],
    bins = 10, # set number of bins
    density = False, # if True, normalize bar heights to a probability density instead of raw counts
    alpha = .7 # indicate shading/transparency of plot
)
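
To make the bins and density arguments more concrete, here is a small sketch on synthetic data (the exponential sample is made up for illustration and is not the taxis dataset). It draws the same values twice, once as counts and once as a density, and uses the arrays that hist returns to check what each setting means.

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(seed=0)
values = rng.exponential(scale=10, size=1000)  # synthetic, right-skewed data

fig, (ax_count, ax_density) = plt.subplots(1, 2, figsize=(8, 3))

# density=False: bar heights are raw counts per bin
counts, edges, _ = ax_count.hist(values, bins=10, density=False, alpha=.7)
ax_count.set_title("density=False (counts)")

# density=True: heights are rescaled so the total bar area equals 1
dens, dens_edges, _ = ax_density.hist(values, bins=10, density=True, alpha=.7)
ax_density.set_title("density=True (area = 1)")

total_count = counts.sum()                        # adds up to the sample size
total_area = (dens * np.diff(dens_edges)).sum()   # height x width, summed over bins
```

Note that 10 bins means 11 bin edges, and that the shape of the histogram is the same in both panels; only the y-axis scale changes.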

Now let's plot some of our variables from the taxis dataset!

In [6]:
## Q: How would you describe the shape of this distribution?

plt.hist(
    x=df_taxis['fare'], # this is the variable we're plotting
    bins=20, # this is the number of bins we want to use
    alpha = .5 # this is our alpha level
)
plt.show()

We can also add axis labels, so it's clear what we're plotting.

In [7]:
plt.hist(
    x=df_taxis['fare'], # this is the variable we're plotting
    bins=20, # this is the number of bins we want to use
    alpha = .5 # this is our alpha level
)
plt.xlabel("Fare ($)")
plt.ylabel("Count")
plt.title("Distribution of Taxi Fares ($)")
plt.show()

Q: How would you plot the distribution of trip distances?

In [8]:
plt.hist(
    x=df_taxis['distance'], # this is the variable we're plotting
    bins=20, # this is the number of bins we want to use
    alpha = .5 # this is our alpha level
)
plt.xlabel("Distance (miles)")
plt.ylabel("Count")
plt.title("Distribution of Trip Distances")
plt.show()

Q: What about the distribution of tip amounts?

In [9]:
plt.hist(
    x=df_taxis['tip'], # this is the variable we're plotting
    bins=20, # this is the number of bins we want to use
    alpha = .5 # this is our alpha level
)
plt.xlabel("Tip ($)")
plt.ylabel("Count")
plt.title("Distribution of Tip Amounts ($)")
plt.show()

Q: What if we wanted to plot the distribution of tip percentages?

Creating a new variable

Presumably, these tip amounts correlate with the overall fare. We might also want to know the amount in percentage terms, i.e., how much a customer tipped as a percentage of the fare.

In [10]:
## tip_pct = tip / fare
df_taxis['tip_pct'] = (df_taxis['tip'] / df_taxis['fare']) * 100
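
One thing to watch with ratio variables like this is a zero in the denominator, which would produce an infinite percentage. Whether the taxis data actually contains zero fares is not checked here; the sketch below uses a made-up three-row frame (df_toy is a hypothetical name) to show one way of guarding against it.

```python
import numpy as np
import pandas as pd

# Toy data: the third row has a zero fare on purpose
df_toy = pd.DataFrame({"fare": [10.0, 20.0, 0.0],
                       "tip": [2.5, 0.0, 1.0]})

# Replace zero fares with NaN so the ratio is missing rather than infinite
safe_fare = df_toy["fare"].replace(0, np.nan)
df_toy["tip_pct"] = (df_toy["tip"] / safe_fare) * 100
```

Missing values are skipped by .mean() by default and left out of the histogram, which is usually the behavior you want here.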

Now we can plot this new variable, to get a sense of the distribution of tip percentages.

Q: What conclusions might we draw about the distribution of tip percentages here?

In [11]:
plt.hist(
    x=df_taxis['tip_pct'], # this is the variable we're plotting
    bins=20, # this is the number of bins we want to use
    alpha = .5 # this is our alpha level
)
plt.xlabel("Tip (%)")
plt.ylabel("Count")
plt.title("Distribution of Tip Percentages (%)")
plt.show()

Q: Adding other markers to the plot?

What if we wanted to add more information to this plot, such as the mean tip percentage?

(Hint: use either plt.axvline or plt.text).

In [12]:
# First, calculate mean tip for ease of access
MEAN_TIP = df_taxis['tip_pct'].mean()
In [13]:
# Drawing a vertical line

plt.hist(
    x=df_taxis['tip_pct'], # this is the variable we're plotting
    bins=20, # this is the number of bins we want to use
    alpha = .5 # this is our alpha level
)

plt.axvline(MEAN_TIP, linestyle = "dotted")

plt.xlabel("Tip (%)")
plt.ylabel("Count")
plt.title("Distribution of Tip Percentages (%)")
plt.show()
In [14]:
# Annotating with text

plt.hist(
    x=df_taxis['tip_pct'], # this is the variable we're plotting
    bins=20, # this is the number of bins we want to use
    alpha = .5 # this is our alpha level
)

plt.text(MEAN_TIP, # where on x-axis to display
         y = 1500, # where on y-axis to display
         s = "Mean Tip: {mean}%".format(mean=str(round(MEAN_TIP, 2))))

plt.xlabel("Tip (%)")
plt.ylabel("Count")
plt.title("Distribution of Tip Percentages (%)")
plt.show()
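
The two approaches can also be combined: give the vertical line a label and let the legend display it, which avoids hand-picking a y position for the text. This is a sketch on synthetic data (the gamma sample is a made-up stand-in for tip_pct).

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(seed=1)
tip_pct_toy = rng.gamma(shape=2.0, scale=8.0, size=500)  # synthetic stand-in
mean_tip = tip_pct_toy.mean()

fig, ax = plt.subplots()
ax.hist(tip_pct_toy, bins=20, alpha=.5)

# The label is picked up by ax.legend(), so no manual text placement is needed
line = ax.axvline(mean_tip, linestyle="dotted", color="black",
                  label=f"Mean tip: {mean_tip:.2f}%")

ax.set_xlabel("Tip (%)")
ax.set_ylabel("Count")
ax.legend()
```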

Scatterplots: plotting multiple continuous variables

Scatterplots can be used to look at how two different continuous variables relate to each other.

Among other things, scatterplots are critical for exploratory data analysis––for example, if you want to run a linear regression on your dataset, it is important to determine that the relationship you're investigating is linear (remember Anscombe's Quartet).

The pyplot package has a dedicated function for creating scatterplots as well, called scatter (documentation here).

An example call would look like this:

plt.scatter(
    x = [ ... ], # array of values for x-axis
    y = [ ... ], # array of values for y-axis
    c = [ ... ], # array of values to indicate color of points
    alpha = .7 # indicate shading/transparency of plot
)

What to ask our data?

Intuitively, we might want to know a few things:

  1. Do longer distances cost more?
  2. Do people tip more (either in amount or percentage terms) for longer distances?
  3. Does it cost more to travel between boroughs, even holding the distance constant?
In [15]:
# Q: Do longer distances cost more?

plt.scatter(
    x = df_taxis['distance'],
    y = df_taxis['fare'],
    alpha = .6
)
plt.xlabel("Distance (miles)")
plt.ylabel("Fare ($)")
plt.title("Raw Fare by Distance Traveled")
plt.show()

Q: How could we ask about the relationship between distance and tip amount?

In [16]:
plt.scatter(
    x = df_taxis['distance'],
    y = df_taxis['tip'],
    alpha = .6
)
plt.xlabel("Distance (miles)")
plt.ylabel("Tip ($)")
plt.title("Tip Amount by Distance Traveled")
plt.show()

Q: What about distance and tip percentage?

In [17]:
plt.scatter(
    x = df_taxis['distance'],
    y = df_taxis['tip_pct'],
    alpha = .4
)
plt.xlabel("Distance (miles)")
plt.ylabel("Tip (%)")
plt.title("Tip Percentage by Distance Traveled")
plt.show()

Q: What about crossing between boroughs?

In [18]:
## Create a new variable
df_taxis['crossed_borough'] = df_taxis['pickup_borough'] != df_taxis['dropoff_borough']
df_taxis.head(1)
Out[18]:
   pickup               dropoff              passengers  distance  fare  tip   tolls  total  color   payment      pickup_zone      dropoff_zone         pickup_borough  dropoff_borough  tip_pct    crossed_borough
0  2019-03-23 20:21:09  2019-03-23 20:27:24  1           1.6       7.0   2.15  0.0    12.95  yellow  credit card  Lenox Hill West  UN/Turtle Bay South  Manhattan       Manhattan        30.714286  False

Plotting categorical and continuous data

Now, we want to add more information to our scatterplot––we have a categorical variable (crossed_borough), and we want the color (or perhaps shape) of each point to indicate the level of that variable, i.e., whether the trip crossed between boroughs.

In [20]:
cdict = {True: 'orange', False: 'blue'}
fig, ax = plt.subplots()
for group in set(df_taxis['crossed_borough']):
    df_tmp = df_taxis[df_taxis['crossed_borough']==group] # filter to appropriate group
    ax.scatter(
        x = df_tmp['distance'],
        y = df_tmp['fare'],
        label = group,
        c = cdict[group],
        alpha = .3
    )
plt.xlabel("Distance")
plt.ylabel("Fare ($)")
ax.legend(title = "Crossed Borough")
plt.show()
In [21]:
cdict = {True: 'orange', False: 'blue'}
fig, ax = plt.subplots()
for group in set(df_taxis['crossed_borough']):
    df_tmp = df_taxis[df_taxis['crossed_borough']==group] # filter to appropriate group
    ax.scatter(
        x = df_tmp['distance'],
        y = df_tmp['tip_pct'],
        label = group,
        c = cdict[group],
        alpha = .3
    )
plt.xlabel("Distance")
plt.ylabel("Tip (%)")
ax.legend(title = "Crossed Borough")
plt.show()
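
If you don't need a legend, a shorter alternative to the loop is to map each row's category to a color and pass the result to c in a single scatter call. This is a sketch on a made-up four-row frame (df_toy and point_colors are hypothetical names).

```python
import pandas as pd
import matplotlib.pyplot as plt

df_toy = pd.DataFrame({
    "distance": [1.0, 2.5, 7.0, 3.2],
    "fare": [6.0, 11.0, 27.0, 14.5],
    "crossed_borough": [False, False, True, True],
})

cdict = {True: "orange", False: "blue"}
# One color name per row, looked up from the boolean column
point_colors = df_toy["crossed_borough"].map(cdict).tolist()

fig, ax = plt.subplots()
ax.scatter(x=df_toy["distance"], y=df_toy["fare"], c=point_colors, alpha=.6)
ax.set_xlabel("Distance (miles)")
ax.set_ylabel("Fare ($)")
```

The trade-off is that a single scatter call produces a single artist, so matplotlib can't build a per-group legend automatically. The loop above is the way to go when you want one legend entry per group.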

Barplots

Another commonly found type of plot is a barplot.

Typically, barplots feature:

  • On the x-axis, a categorical (i.e., nominal) variable.
  • On the y-axis, a continuous variable, usually summarized in some way (e.g., the means, with standard errors).

Barplots are effective for drawing attention to differences in the magnitude of some quantity across categories––though personally, I don't think they're always the most effective in terms of the amount of information conveyed per unit of space.

Barplots in pyplot

The corresponding pyplot function is bar, which draws one bar per category at a height you supply.

The basic inputs are as follows:

plt.bar(
    x = [ ... ], # array of categories
    height = [ ... ] # height of each category
)

Q: What is the mean tip % for green vs. yellow taxis?

Creating a new dataframe using groupby

Here, the pandas groupby method is very helpful. We'll want to:

  1. Calculate the mean tip_pct for each level of the color variable.
  2. Plot this using plt.bar.
In [22]:
df_color_tips_mean = df_taxis[['tip_pct', 'color']].groupby('color').mean().reset_index()
df_color_tips_mean
Out[22]:
color tip_pct
0 green 7.441110
1 yellow 18.627966
In [23]:
# Q: What's missing from this graph?

plt.bar(x = df_color_tips_mean['color'],
       height = df_color_tips_mean['tip_pct'])
plt.xlabel("Taxi Color")
plt.ylabel("Tip (%)")
plt.title("Tip (%) by Taxi Color")
Out[23]:
Text(0.5, 1.0, 'Tip (%) by Taxi Color')

Adding error bars

To add error bars, we can use plt.errorbar.

In [24]:
## First, calculate standard error of the mean
df_color_tips_sem = df_taxis[['tip_pct', 'color']].groupby('color').sem().reset_index()
df_color_tips_sem
Out[24]:
color tip_pct
0 green 0.384112
1 yellow 0.194550
In [25]:
plt.bar(x = df_color_tips_mean['color'],
        height = df_color_tips_mean['tip_pct'])

plt.errorbar(x = df_color_tips_mean['color'], ## as with bar, the x-axis is the *color*
             y = df_color_tips_mean['tip_pct'], ## as with bar, the y-axis is the *mean*
             yerr = df_color_tips_sem['tip_pct'] * 2, ## two standard errors, as typical
             ls = 'none', ## toggle this to connect or not connect the lines
             color = "black")

plt.xlabel("Taxi Color")
plt.ylabel("Tip (%)")
plt.title("Tip (%) by Taxi Color")
Out[25]:
Text(0.5, 1.0, 'Tip (%) by Taxi Color')
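
As a more compact variant, the mean and the standard error can be computed in one groupby call, and plt.bar can draw the error bars itself through its yerr argument. This is a sketch on made-up numbers (df_toy and summary are hypothetical names), not the taxis data.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df_toy = pd.DataFrame({
    "color": ["green", "green", "green", "yellow", "yellow", "yellow"],
    "tip_pct": [5.0, 10.0, 15.0, 10.0, 20.0, 30.0],
})

# One row per color, with the mean and the standard error of the mean
summary = df_toy.groupby("color")["tip_pct"].agg(["mean", "sem"]).reset_index()

fig, ax = plt.subplots()
ax.bar(x=summary["color"],
       height=summary["mean"],
       yerr=summary["sem"] * 2,  # two standard errors, as above
       capsize=4)                # small caps on the error bars
ax.set_xlabel("Taxi Color")
ax.set_ylabel("Tip (%)")
```

The standard error here is the sample standard deviation divided by the square root of the group size, which is what the sem method computes.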

Other pyplot affordances

This short tutorial really only scratched the surface of what you can do with pyplot.

Part 2: Plotting with seaborn

In this section, we'll replicate some of those same plots from above, but now using seaborn.

As you'll see, seaborn offers an API that's (in most cases) cleaner and easier to use than pyplot, though there's a trade-off in terms of flexibility.

Histograms

Depending on your version of seaborn, you can plot a histogram using either:

  • seaborn.distplot: the older function, typically used for creating density plots but adaptable to histograms. It has been deprecated since seaborn 0.11.
  • seaborn.histplot: targeted more specifically at histograms, and the recommended choice in seaborn 0.11 and later.
In [29]:
sns.distplot(df_taxis['distance'], ## vector to plot
             kde = False, ## whether to use kernel density estimation to fit density curve
             norm_hist = False ## whether to norm histogram to display probability vs. counts
            )
plt.xlabel("Distance (miles)")
plt.ylabel("Count")
plt.title("Distribution of miles traveled")
Out[29]:
Text(0.5, 1.0, 'Distribution of miles traveled')

Q: Exercise for home––recreate the rest of the histograms from Part 1 using seaborn.

Scatterplots

Fortunately, seaborn makes it much easier to create scatterplots, especially if you'd like to change the hue or style of the points.

sns.scatterplot(
    data = ..., # name of dataframe
    x = ..., # name of x-axis variable
    y = ..., # name of y-axis variable
    hue = ..., # name of hue variable
    style = ..., # name of style variable
    size = ..., # name of variable to modulate point size
)
In [30]:
# Replicating our initial plot: Do longer distances cost more?

sns.scatterplot(data = df_taxis,
                x = "distance",
                y = "fare",
                alpha = .7)
plt.xlabel("Distance (miles)")
plt.ylabel("Fare ($)")
plt.title("Raw Fare by Distance Traveled")
plt.show()

Q: How would you create the same plot, but with the points colored to reflect whether a borough has been crossed?

In [31]:
# Same plot as before, with points colored by whether the trip crossed boroughs

sns.scatterplot(data = df_taxis,
                x = "distance",
                y = "fare",
                hue = "crossed_borough",
                alpha = .7)
plt.xlabel("Distance (miles)")
plt.ylabel("Fare ($)")
plt.title("Raw Fare by Distance Traveled")
plt.show()

Q: Exercise for home––recreate the other scatterplots, or new ones, using seaborn.

Bar plots

Finally, we can reconstruct some of the bar plots we created above, using the barplot function.

sns.barplot(
    data = ..., # name of dataframe
    x = ..., # name of x-axis variable (categorical)
    y = ..., # name of y-axis variable (continuous)
    hue = ..., # name of hue variable
    ci = float # size of confidence intervals to draw (estimated using *bootstrapping*, not SEM); seaborn 0.12+ replaces this with errorbar = ("ci", 95)
)

Q: How would we recreate the figure showing average tip % by taxi color?

In [32]:
sns.barplot(data = df_taxis,
            x = "color",
            y = "tip_pct",
            ci = 95
           )

plt.xlabel("Color")
plt.ylabel("Tip (%)")
plt.title("Tip (%) by Taxi Color")
Out[32]:
Text(0.5, 1.0, 'Tip (%) by Taxi Color')
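
A note on versions: the ci argument was deprecated in seaborn 0.12 in favor of errorbar. If your code needs to run on both older and newer installations, one option is to choose the keyword based on the installed version. This is a sketch on made-up data; ci_kwargs is a hypothetical helper name.

```python
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

df_toy = pd.DataFrame({
    "color": ["green"] * 4 + ["yellow"] * 4,
    "tip_pct": [5.0, 8.0, 10.0, 7.0, 15.0, 20.0, 18.0, 22.0],
})

# Pick the right keyword for a 95% bootstrapped confidence interval
major, minor = (int(part) for part in sns.__version__.split(".")[:2])
if (major, minor) >= (0, 12):
    ci_kwargs = {"errorbar": ("ci", 95)}
else:
    ci_kwargs = {"ci": 95}

fig, ax = plt.subplots()
sns.barplot(data=df_toy, x="color", y="tip_pct", ax=ax, **ci_kwargs)
ax.set_xlabel("Color")
ax.set_ylabel("Tip (%)")
```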

Q: What other variables would you be interested in looking at? Whether or not they've crossed a borough?

In [33]:
## Same as above, but with hues corresponding to whether the trip crossed boroughs
sns.barplot(data = df_taxis,
            x = "color",
            y = "tip_pct",
            hue = "crossed_borough",
            ci = 95
           )

plt.xlabel("Color")
plt.ylabel("Tip (%)")
plt.legend(title = "Crossed Borough")
Out[33]:
<matplotlib.legend.Legend at 0x7f94d0514908>

Additional exercises with taxis

This dataset is quite rich, and you could ask a number of other questions.

Just to get you started:

  1. Does distance traveled vary as a function of the hour of pickup?
  2. Does tip_pct vary as a function of the hour of pickup?
  3. How does distance relate to the amount of time between pickup and dropoff?
  4. Do customers tip more when the trip takes longer (controlling for distance)?
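
As a starting point for these questions, here is a sketch of how the hour of pickup and the trip duration could be derived. It uses a two-row toy frame built from the timestamps shown earlier; pickup_hour and duration_min are hypothetical column names. If your pickup and dropoff columns are stored as strings, convert them with pd.to_datetime first, as done here.

```python
import pandas as pd

df_toy = pd.DataFrame({
    "pickup": pd.to_datetime(["2019-03-23 20:21:09", "2019-03-04 16:11:55"]),
    "dropoff": pd.to_datetime(["2019-03-23 20:27:24", "2019-03-04 16:19:00"]),
})

# Hour of day (0-23) at which the trip started
df_toy["pickup_hour"] = df_toy["pickup"].dt.hour

# Trip duration in minutes: subtracting datetimes gives a timedelta
df_toy["duration_min"] = (df_toy["dropoff"] - df_toy["pickup"]).dt.total_seconds() / 60
```

With these two columns in place, the questions above become ordinary scatterplots, line plots, or barplots.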

Part 3: Reconstructing plots from "the wild"

Data visualizations are increasingly common in blog posts, news articles, or even social media.

Reconstructing a visualization you find "in the wild" can reveal some of the hidden complexities of this process.

To that end, I reached out to Timothy Lee, author of a very interesting blog post found on Full Stack Economics, entitled "18 charts that explain the American economy". He was kind enough to give me the data he used to construct two of the charts included in that post: chart #7 and chart #14.

In this reconstruction, let's focus on Chart 14.

Chart 14: Reasons for missing work

This chart depicts the relative magnitudes of different causes for missing work, over time. For me, there are a couple of takeaways:

  1. Vacation time is highly seasonal, peaking each summer and creating an almost oscillatory pattern.
  2. Starting in 2020, vacation time drops a bit (though it preserves the seasonal pattern), and we see a rise in missed work due to illness.

Let's try to reconstruct the chart from scratch!

Just for reference, the chart looks like this:

[Figure: Chart 14, reasons for missing work, from Full Stack Economics]

Load the data

First, we load and inspect the data. Here, we have to actually read in a .csv file using pandas.read_csv.

Question: What do you all notice about the way this dataframe is structured?

In [34]:
df_missing = pd.read_csv("data/missing_work.csv")
len(df_missing)
Out[34]:
110
In [35]:
# Inspecting the dataset
df_missing.head(5)
Out[35]:
   Year  Child care problems  Maternity or paternity leave  Other family or personal obligations  Illness or injury  Vacation  Month
0  2012  18                   313                           246                                   899                1701      10
1  2012  35                   278                           230                                   880                1299      11
2  2012  13                   245                           246                                   944                1005      12
3  2013  14                   257                           250                                   1202               1552      1
4  2013  27                   258                           276                                   1079               1305      2

Converting to long format

Q: Any ideas on how to do this?

In [36]:
## Note: this is just one of the many approaches!

df_missing_long = pd.melt(df_missing, # dataframe
                          value_vars=['Child care problems', # columns to stack into long format
                                      'Maternity or paternity leave',
                                      'Other family or personal obligations',
                                      'Illness or injury',
                                      'Vacation'],
                          id_vars=["Year", "Month"], # columns that identify each observation
                          var_name = "Cause", # what to call the new grouping column
                          value_name = "Days" # what to call the new value column
                         )
In [37]:
df_missing_long.head(5)
Out[37]:
Year Month Cause Days
0 2012 10 Child care problems 18
1 2012 11 Child care problems 35
2 2012 12 Child care problems 13
3 2013 1 Child care problems 14
4 2013 2 Child care problems 27
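
A quick way to convince yourself that melt did what you expect is to check the row count and to pivot the long frame back to wide. The sketch below does this on a two-row, two-cause excerpt of the table above (df_wide, df_long, and df_back are hypothetical names).

```python
import pandas as pd

df_wide = pd.DataFrame({
    "Year": [2012, 2012],
    "Month": [10, 11],
    "Vacation": [1701, 1299],
    "Illness or injury": [899, 880],
})

# Wide -> long: every non-id column is stacked into (Cause, Days) pairs
df_long = pd.melt(df_wide,
                  id_vars=["Year", "Month"],
                  var_name="Cause",
                  value_name="Days")

# Long -> wide again: each (Year, Month) has one value per Cause,
# so summing simply returns that value
df_back = df_long.pivot_table(index=["Year", "Month"],
                              columns="Cause",
                              values="Days",
                              aggfunc="sum").reset_index()
```

The long frame has one row per combination of observation and cause, i.e. the number of wide rows times the number of stacked columns.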
In [38]:
# Note: One other thing: the original graph reports days in *millions*.
# Our data are already in *thousands*, so we divide by another 1000 to match the original.
df_missing_long['Days_millions'] = df_missing_long['Days']/1000

Distribution over the year

Before we try to reconstruct the plot, let's first look at how missed days are distributed across the months of the year (averaging over years), to see if we can detect any patterns.

Q: What pattern can we see?

In [40]:
sns.lineplot(data = df_missing_long,
             x = 'Month',
             y = 'Days_millions',
             alpha = .7)

plt.xlabel("Month")
plt.ylabel("Days (Millions)")
Out[40]:
Text(0, 0.5, 'Days (Millions)')

Q (2): Why is there so much error in the summer months?

In [41]:
# This shows us that it's *vacation* specifically that has a seasonal pattern
sns.lineplot(data = df_missing_long,
             x = 'Month',
             y = 'Days_millions', # rescaled column, so values match the axis label
             hue = 'Cause',
             alpha = .7)

# This makes sure the legend doesn't cover up the plot itself
plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
plt.xlabel("Month")
plt.ylabel("Days (Millions)")
Out[41]:
Text(0, 0.5, 'Days (Millions)')

Reconstructing the original plot

Now, rather than averaging across years to get a picture of the "average year", let's just look at this over time.

One issue is that we'll first need to convert each observation (a Month and a Year) into a datetime object, so we can plug this into pandas and seaborn native datetime operations.

In [42]:
## First, let's concatenate each month and year into a single string
df_missing_long['date'] = df_missing_long.apply(lambda row: str(row['Month']) + '-' + str(row['Year']), axis = 1)
df_missing_long.head(5)
Out[42]:
Year Month Cause Days Days_millions date
0 2012 10 Child care problems 18 0.018 10-2012
1 2012 11 Child care problems 35 0.035 11-2012
2 2012 12 Child care problems 13 0.013 12-2012
3 2013 1 Child care problems 14 0.014 1-2013
4 2013 2 Child care problems 27 0.027 2-2013
In [43]:
## Now, let's create a new "datetime" column using the `pd.to_datetime` function
df_missing_long['datetime'] = pd.to_datetime(df_missing_long['date'], format = "%m-%Y") # explicit month-year format avoids ambiguous parsing
df_missing_long.head(5)
Out[43]:
Year Month Cause Days Days_millions date datetime
0 2012 10 Child care problems 18 0.018 10-2012 2012-10-01
1 2012 11 Child care problems 35 0.035 11-2012 2012-11-01
2 2012 12 Child care problems 13 0.013 12-2012 2012-12-01
3 2013 1 Child care problems 14 0.014 1-2013 2013-01-01
4 2013 2 Child care problems 27 0.027 2-2013 2013-02-01
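
As an alternative to building a string first, pd.to_datetime can also assemble dates from a frame whose columns are named year, month, and day. This is a sketch on a two-row toy frame; setting the day to 1 is an assumption that mirrors the first-of-month dates shown above.

```python
import pandas as pd

df_toy = pd.DataFrame({"Year": [2012, 2013], "Month": [10, 1]})

# to_datetime expects lowercase year/month/day column names
parts = (df_toy
         .rename(columns={"Year": "year", "Month": "month"})
         .assign(day=1))  # assumed: place each observation on the 1st of the month

df_toy["datetime"] = pd.to_datetime(parts[["year", "month", "day"]])
```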

Q: Now, how can we actually create the final plot?

In [44]:
sns.lineplot(data = df_missing_long,
             x = 'datetime',
             y = 'Days_millions',
             hue = 'Cause')

plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
plt.xlabel("Year")
plt.ylabel("Days (Millions)")
plt.title("Reasons for missing work")
Out[44]:
Text(0.5, 1.0, 'Reasons for missing work')

Bonus: Smoothing

If you look at the original plot carefully, you'll see that the observations have been smoothed using a rolling window average.

Fortunately, we can do the same thing using the pandas rolling method.

In [45]:
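## Construct a 3-month rolling average to "smooth" the data
## Note: the data are in long format, so we compute the window within each Cause;
## otherwise the first windows of one cause would include values from the previous cause
df_missing_long['Days_rolling_avg'] = df_missing_long.groupby('Cause')['Days_millions'].transform(lambda x: x.rolling(3).mean())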
In [46]:
sns.lineplot(data = df_missing_long,
             x = 'datetime',
             y = 'Days_rolling_avg',
             hue = 'Cause')

plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
plt.xlabel("Year")
plt.ylabel("Days (Millions)")
plt.title("Reasons for missing work (smoothed)")
Out[46]:
Text(0.5, 1.0, 'Reasons for missing work (smoothed)')
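
One subtlety with rolling windows on long-format data is that a window can straddle two groups. The sketch below uses a made-up frame with two causes ("A" and "B" are hypothetical labels) to compare a plain rolling mean with one computed separately within each cause.

```python
import pandas as pd

df_toy = pd.DataFrame({
    "Cause": ["A"] * 4 + ["B"] * 4,
    "Days": [1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0],
})

# Plain rolling mean: the first windows of "B" still contain values from "A"
df_toy["naive"] = df_toy["Days"].rolling(3).mean()

# Rolling mean within each cause: windows restart at every group boundary
df_toy["by_cause"] = (df_toy
                      .groupby("Cause")["Days"]
                      .transform(lambda s: s.rolling(3).mean()))
```

In both versions, the first two rows of a window of size 3 are NaN because there aren't yet three observations to average. In the grouped version this also holds for the first two rows of every cause.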
In [47]:
## Discussion: Should we limit y-axis/causes to investigate trends in other causes, like child care problems?
## Q: How could we do this?
In [48]:
sns.lineplot(data = df_missing_long,
            x = 'datetime',
            y = 'Days_millions',
            hue = 'Cause',
            alpha = .5)

# An easy way to do this is to limit the extent of the y-axis
plt.ylim(0, .5)

plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
plt.xlabel("Year")
plt.ylabel("Days (Millions)")
plt.title("Reasons for missing work")
Out[48]:
Text(0.5, 1.0, 'Reasons for missing work')
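
Another option, instead of clipping the y-axis, is to filter out the largest causes before plotting, so the axis rescales on its own. This is a sketch using the October 2012 values from the table above; big_causes and df_small are hypothetical names.

```python
import pandas as pd

df_toy = pd.DataFrame({
    "Cause": ["Vacation", "Illness or injury",
              "Child care problems", "Maternity or paternity leave"],
    "Days_millions": [1.701, 0.899, 0.018, 0.313],
})

# Drop the two dominant causes so the smaller ones become visible
big_causes = ["Vacation", "Illness or injury"]
df_small = df_toy[~df_toy["Cause"].isin(big_causes)]
```

The filtered frame can then be passed to sns.lineplot exactly as before. Unlike plt.ylim, this removes the large causes from the legend as well.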

Conclusion

This concludes the tutorial! Hopefully, this has equipped you with some of the tools necessary to:

  • Describe the benefits and advantages of data visualization.
  • Propose a suitable data visualization for a particular dataset and question.
  • Implement a first draft of this visualization in Python, using either matplotlib or seaborn.
  • Evaluate this data visualization and make suggestions for improvements in future iterations.
  • Reconstruct a visualization found "in the wild" using the appropriate dataset.

If you're interested in R, a tutorial on data visualization in R can be found here.

Part 4 (Bonus): Anscombe's Quartet

As discussed in the slides, one key motivation for making data visualizations is exploratory data analysis (EDA). Among other things, EDA can be used to ensure that the assumptions of a statistical test are met (e.g., that a relationship in our data is linear, if one is running linear regression).

Load data

In [49]:
df_anscombe = sns.load_dataset("anscombe")
df_anscombe.head(5)
Out[49]:
dataset x y
0 I 10.0 8.04
1 I 8.0 6.95
2 I 13.0 7.58
3 I 9.0 8.81
4 I 11.0 8.33

Calculate correlation coefficient for each dataset

Just to make sure the central claim is correct, let's calculate Pearson's r for each dataset. Is it roughly the same in each?

In [50]:
from scipy.stats import pearsonr

corrs = []
for dataset in set(df_anscombe['dataset']):
    df_tmp = df_anscombe[df_anscombe['dataset']==dataset]
    r = round(pearsonr(df_tmp['x'], df_tmp['y'])[0], 4)
    corrs.append({
        'r': r,
        'dataset': dataset
    })

df_corrs = pd.DataFrame(corrs)
df_corrs
Out[50]:
r dataset
0 0.8163 III
1 0.8164 I
2 0.8162 II
3 0.8165 IV
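
The correlation is not the only summary statistic the four datasets share. The sketch below types in the standard published values of Anscombe's quartet by hand (so it runs without downloading anything) and computes the means and Pearson's r with pandas. The names quartet and df_summary are hypothetical.

```python
import pandas as pd

x123 = [10, 8, 13, 9, 11, 14, 6, 4, 12, 7, 5]  # shared by datasets I-III
quartet = {
    "I":   (x123, [8.04, 6.95, 7.58, 8.81, 8.33, 9.96, 7.24, 4.26, 10.84, 4.82, 5.68]),
    "II":  (x123, [9.14, 8.14, 8.74, 8.77, 9.26, 8.10, 6.13, 3.10, 9.13, 7.26, 4.74]),
    "III": (x123, [7.46, 6.77, 12.74, 7.11, 7.81, 8.84, 6.08, 5.39, 8.15, 6.42, 5.73]),
    "IV":  ([8, 8, 8, 8, 8, 8, 8, 19, 8, 8, 8],
            [6.58, 5.76, 7.71, 8.84, 8.47, 7.04, 5.25, 12.50, 5.56, 7.91, 6.89]),
}

rows = []
for name, (xs, ys) in quartet.items():
    s_x = pd.Series(xs, dtype=float)
    s_y = pd.Series(ys, dtype=float)
    rows.append({
        "dataset": name,
        "mean_x": s_x.mean(),
        "mean_y": s_y.mean(),
        "r": s_x.corr(s_y),  # Pearson's r is the default method
    })

df_summary = pd.DataFrame(rows)
```

The means of x and y come out nearly identical across the four datasets too, which is exactly why summary statistics alone can be misleading.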

Plot the data: a first (and flawed) approach

First, I'm going to plot the data just using seaborn.scatterplot.

In [51]:
# Q: What's wrong with this approach?

sns.scatterplot(data = df_anscombe,
                x = 'x',
                y = 'y')

plt.xlabel("X")
plt.ylabel("Y")
Out[51]:
Text(0, 0.5, 'Y')

Plot the data using FacetGrid

The above plot lumps all the datasets together into the same plot. But we want to visualize them all separately.

For this, we can rely on FacetGrid.

Introducing FacetGrid

The FacetGrid object allows you to map a particular plot aesthetic (i.e., a particular x/y axis choice) onto multiple subplots, broken up by the levels of a categorical variable.

For example, you could create different subplots for payment type (cash vs. credit card), each showing the same relationship between distance and fare.

In [52]:
g = sns.FacetGrid(df_anscombe, col="dataset")
g.map(sns.scatterplot, "x", "y")
Out[52]:
<seaborn.axisgrid.FacetGrid at 0x7f94d2f4aa58>

Plot the data using lmplot

Another approach that allows us to facet the data is to use lmplot.

The benefit of lmplot is that it will also plot a regression line over our data, which, in this case, clearly demonstrates that across each dataset, the estimated slope is roughly the same.

In [53]:
sns.lmplot(data = df_anscombe,
           x = 'x',
           y = 'y',
           col = 'dataset')
Out[53]:
<seaborn.axisgrid.FacetGrid at 0x7f94d2fc61d0>
