# Is there a function to make scatterplot matrices in matplotlib?


### Question:

Is there a function to make scatterplot matrices in matplotlib?

(Image: an example of a scatterplot matrix.) Is there such a function in `matplotlib.pyplot`?

Generally speaking, matplotlib doesn't provide plotting functions that operate on more than one Axes object (a subplot, in this case). The expectation is that you'll write a simple function to combine things however you like.
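As a minimal sketch of such a function (a simplified alternative, not the code in this answer), assuming `data` holds one variable per row: `plt.subplots(..., sharex='col', sharey='row')` makes every cell in a column share its x-limits and every cell in a row share its y-limits, and `Axes.label_outer()` keeps tick labels only on the outer edge of the grid. The name `simple_scatter_matrix` is hypothetical.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

def simple_scatter_matrix(data, names):
    """Sketch: data has shape (numvars, numdata), one row per variable."""
    numvars = data.shape[0]
    # Share x within each column and y within each row so points line up.
    fig, axes = plt.subplots(numvars, numvars, sharex='col', sharey='row',
                             figsize=(6, 6), squeeze=False)
    for row in range(numvars):
        for col in range(numvars):
            ax = axes[row, col]
            if row == col:
                # Diagonal cells only carry the variable name.
                ax.annotate(names[row], (0.5, 0.5), xycoords='axes fraction',
                            ha='center', va='center')
            else:
                # Column variable on the x-axis, row variable on the y-axis.
                ax.plot(data[col], data[row], 'o', mfc='none', color='black')
            ax.label_outer()  # hide tick labels on inner cells
    return fig, axes
```

Because limits are shared per row and column, the diagonal cells pick up the right limits automatically.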

I'm not quite sure what your data looks like, but it's quite simple to build such a function from scratch. If you're always going to be working with structured or record arrays, you can simplify this a bit: there's always a name associated with each data series, so you don't need to pass the names separately.
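A minimal sketch of that simplification, assuming a 1-D structured (or record) array with numeric fields: the field names come from `arr.dtype.names`, and stacking the fields gives the one-row-per-variable layout that the `scatterplot_matrix` function in this answer expects. The helper `unpack_structured` is hypothetical.

```python
import numpy as np

def unpack_structured(arr):
    """Split a 1-D structured/record array into (data, names), where data
    has one row per field."""
    names = list(arr.dtype.names)
    data = np.vstack([arr[name] for name in names]).astype(float)
    return data, names

records = np.zeros(5, dtype=[('mpg', float), ('disp', float), ('wt', float)])
data, names = unpack_structured(records)
print(data.shape, names)  # (3, 5) ['mpg', 'disp', 'wt']
```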

As an example:

```python
import itertools
import numpy as np
import matplotlib.pyplot as plt

def main():
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
            linestyle='none', marker='o', color='black', mfc='none')
    fig.suptitle('Simple Scatterplot Matrix')
    plt.show()

def scatterplot_matrix(data, names, **kwargs):
    """Plots a scatterplot matrix of subplots.  Each row of "data" is plotted
    against other rows, resulting in a nrows by nrows grid of subplots with the
    diagonal subplots labeled with "names".  Additional keyword arguments are
    passed on to matplotlib's "plot" command. Returns the matplotlib figure
    object containing the subplot grid."""
    numvars, numdata = data.shape
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))

    for ax in axes.flat:
        # Hide all ticks and labels
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

        # Set up ticks only on one side for the "edge" subplots...
        # (Axes.is_first_col() and friends were deprecated in Matplotlib 3.4
        # and later removed; the SubplotSpec methods replace them.)
        spec = ax.get_subplotspec()
        if spec.is_first_col():
            ax.yaxis.set_ticks_position('left')
        if spec.is_last_col():
            ax.yaxis.set_ticks_position('right')
        if spec.is_first_row():
            ax.xaxis.set_ticks_position('top')
        if spec.is_last_row():
            ax.xaxis.set_ticks_position('bottom')

    # Plot the data.
    for i, j in zip(*np.triu_indices_from(axes, k=1)):
        for x, y in [(i,j), (j,i)]:
            axes[x,y].plot(data[x], data[y], **kwargs)

    # Label the diagonal subplots...
    for i, label in enumerate(names):
        axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                ha='center', va='center')

    # Turn on the proper x or y axes ticks.
    for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
        axes[j,i].xaxis.set_visible(True)
        axes[i,j].yaxis.set_visible(True)

    return fig

main()
```

For those who do not want to define their own functions, the Pandas data analysis library provides a `scatter_matrix()` function in `pandas.plotting`:

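```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pandas.plotting import scatter_matrix

df = pd.DataFrame(np.random.randn(1000, 4), columns=['a', 'b', 'c', 'd'])
# diagonal='kde' needs SciPy installed; diagonal='hist' does not.
scatter_matrix(df, alpha=0.2, figsize=(6, 6), diagonal='kde')
plt.show()
```

You can also use Seaborn's `pairplot` function: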

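```python
import seaborn as sns
import matplotlib.pyplot as plt

sns.set()
# hue="species" requires a "species" column; the iris example dataset
# (downloaded by load_dataset) has one.
df = sns.load_dataset("iris")
sns.pairplot(df, hue="species")
plt.show()
```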
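If you go the pandas route and still want to tweak individual panels, note that `scatter_matrix` returns a 2-D NumPy array of matplotlib Axes, one per (row, column) pair, which you can index like the `axes` grid in the hand-written function above. A small sketch, using `diagonal='hist'` so SciPy is not required; rotating the bottom-row labels is just an illustrative tweak:

```python
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
from pandas.plotting import scatter_matrix

df = pd.DataFrame(np.random.randn(100, 4), columns=['a', 'b', 'c', 'd'])
axes = scatter_matrix(df, alpha=0.2, figsize=(6, 6), diagonal='hist')
print(axes.shape)  # (4, 4)

# Example customization: tilt the x-axis labels on the bottom row.
for ax in axes[-1, :]:
    ax.xaxis.label.set_rotation(45)
```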

Thanks for sharing your code! You figured out all the hard stuff for us. As I was working with it, I noticed a few little things that didn’t look quite right.

1. [FIX #1] The axis ticks weren't lining up as I would expect. In your example above, you should be able to draw a vertical and a horizontal line through any point across all the plots, and those lines should pass through the corresponding point in the other plots; as the code stands, this doesn't happen.

2. [FIX #2] If you plot an odd number of variables, the bottom-right corner axes don't get the correct x or y ticks; they are left at the default 0..1 ticks.

3. Not a fix, but I made `names` optional: if it is omitted, the diagonal positions are labeled `xi` for variable i (`x0`, `x1`, ...).
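A quick way to confirm FIX #1 is to fill only the off-diagonal cells with both argument orders and check whether every cell in column `c` has variable `c` on its x-axis, which is what makes the points line up vertically. This is a sketch; `plot_offdiagonal` and `columns_share_x` are hypothetical helpers, and it reads the plotted data back with `Line2D.get_xdata()`.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so the check runs anywhere
import matplotlib.pyplot as plt

def plot_offdiagonal(data, fixed=True):
    """Fill the off-diagonal cells using the original or the FIX #1 order."""
    numvars = data.shape[0]
    fig, axes = plt.subplots(numvars, numvars)
    for i, j in zip(*np.triu_indices_from(axes, k=1)):
        for x, y in [(i, j), (j, i)]:
            if fixed:
                axes[x, y].plot(data[y], data[x], 'o')  # FIX #1 order
            else:
                axes[x, y].plot(data[x], data[y], 'o')  # original order
    return fig, axes

def columns_share_x(axes, data):
    """True if every off-diagonal cell in column c plots variable c on x."""
    numvars = data.shape[0]
    return all(
        np.array_equal(axes[row, col].lines[0].get_xdata(), data[col])
        for row in range(numvars) for col in range(numvars) if row != col
    )

rng = np.random.default_rng(1977)
data = 10 * rng.random((4, 10))
print(columns_share_x(plot_offdiagonal(data, fixed=False)[1], data))  # False
print(columns_share_x(plot_offdiagonal(data, fixed=True)[1], data))   # True
```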

Below you'll find an updated version of your code that makes the two fixes, while otherwise preserving the beauty of your code.

```python
import itertools
import numpy as np
import matplotlib.pyplot as plt

def scatterplot_matrix(data, names=None, **kwargs):
    """
    Plots a scatterplot matrix of subplots.  Each row of "data" is plotted
    against other rows, resulting in a nrows by nrows grid of subplots with the
    diagonal subplots labeled with "names".  Additional keyword arguments are
    passed on to matplotlib's "plot" command. Returns the matplotlib figure
    object containing the subplot grid.
    """
    numvars, numdata = data.shape
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))

    for ax in axes.flat:
        # Hide all ticks and labels
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

        # Set up ticks only on one side for the "edge" subplots...
        # (Axes.is_first_col() and friends were deprecated in Matplotlib 3.4
        # and later removed; the SubplotSpec methods replace them.)
        spec = ax.get_subplotspec()
        if spec.is_first_col():
            ax.yaxis.set_ticks_position('left')
        if spec.is_last_col():
            ax.yaxis.set_ticks_position('right')
        if spec.is_first_row():
            ax.xaxis.set_ticks_position('top')
        if spec.is_last_row():
            ax.xaxis.set_ticks_position('bottom')

    # Plot the data.
    for i, j in zip(*np.triu_indices_from(axes, k=1)):
        for x, y in [(i,j), (j,i)]:
            # FIX #1: this needed to be changed from ...(data[x], data[y],...)
            axes[x,y].plot(data[y], data[x], **kwargs)

    # Label the diagonal subplots (default names are x0, x1, ...)
    if not names:
        names = ['x'+str(i) for i in range(numvars)]

    for i, label in enumerate(names):
        axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                ha='center', va='center')

    # Turn on the proper x or y axes ticks.
    for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
        axes[j,i].xaxis.set_visible(True)
        axes[i,j].yaxis.set_visible(True)

    # FIX #2: if numvars is odd, the bottom right corner plot doesn't have the
    # correct axes limits, so we pull them from other axes
    if numvars%2:
        xlimits = axes[0,-1].get_xlim()
        ylimits = axes[-1,0].get_ylim()
        axes[-1,-1].set_xlim(xlimits)
        axes[-1,-1].set_ylim(ylimits)

    return fig

if __name__=='__main__':
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
            linestyle='none', marker='o', color='black', mfc='none')
    fig.suptitle('Simple Scatterplot Matrix')
    plt.show()
```

Thanks again for sharing this with us. I have used it many times! Oh, and I moved the `main()` code under an `if __name__ == '__main__':` guard, so it runs as an example when the file is executed as a script but is not called when the module is imported into other code.
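To see why FIX #2 is only needed when the number of variables is odd, here is a small pure-Python sketch that replays the tick-enabling loop `zip(range(numvars), itertools.cycle((-1, 0)))` and reports which diagonal cells end up with visible axes. Diagonal cells hold no data, so any that appear here keep the default 0..1 limits. The helper names are hypothetical.

```python
import itertools

def visible_tick_cells(numvars):
    """Replay the loop: return (row, col) cells whose x and y axes are turned
    on, with the -1 index converted to numvars - 1."""
    xcells, ycells = [], []
    for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
        edge = j % numvars             # -1 -> last row/column, 0 -> first
        xcells.append((edge, i))       # axes[j, i].xaxis
        ycells.append((i, edge))       # axes[i, j].yaxis
    return xcells, ycells

def diagonal_tick_cells(numvars):
    """Diagonal (empty) cells that get visible ticks."""
    xcells, ycells = visible_tick_cells(numvars)
    return sorted({(r, c) for r, c in xcells + ycells if r == c})

print(diagonal_tick_cells(4))  # []
print(diagonal_tick_cells(5))  # [(4, 4)]
```

For an odd count, the last index `i = numvars - 1` is even and pairs with `j = -1`, which lands on the bottom-right diagonal cell; that is exactly the cell FIX #2 patches.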

While reading the question I expected to see an answer using rpy. I think this is a nice option that takes advantage of two beautiful languages. So here it is (note that this uses the original `rpy` package, which has since been superseded by `rpy2`):

```python
import rpy
import numpy as np

def main():
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    mpg = data[0,:]
    disp = data[1,:]
    drat = data[2,:]
    wt = data[3,:]
    rpy.set_default_mode(rpy.NO_CONVERSION)

    R_data = rpy.r.data_frame(mpg=mpg,disp=disp,drat=drat,wt=wt)

    # Figure saved as eps
    rpy.r.postscript('pairsPlot.eps')
    rpy.r.pairs(R_data,
       main="Simple Scatterplot Matrix Via RPy")
    rpy.r.dev_off()

    # Figure saved as png
    rpy.r.png('pairsPlot.png')
    rpy.r.pairs(R_data,
       main="Simple Scatterplot Matrix Via RPy")
    rpy.r.dev_off()

    rpy.set_default_mode(rpy.BASIC_CONVERSION)

if __name__ == '__main__': main()
```

I can’t post an image to show the result 🙁 sorry!