matplotlib

Create publication-quality 2-D plots with matplotlib. Covers pyplot basics, subplots, savefig, common chart types, and the show-vs-save pitfall.

matplotlib — Plotting

What it is

matplotlib is the foundational Python plotting library. It produces static, animated, or interactive figures in PNG, PDF, SVG, and more. Several higher-level tools, such as seaborn and pandas .plot(), are built directly on it; others, such as plotly, have their own rendering engines.

Install

bash
pip install matplotlib

Output: pip's install log; the command exits 0 on success.

Quick example

python
import matplotlib.pyplot as plt

x = [1, 2, 3, 4, 5]
y = [2, 4, 3, 5, 4]

plt.plot(x, y, marker="o", color="steelblue")
plt.title("Simple Line Plot")
plt.xlabel("x")
plt.ylabel("y")
plt.tight_layout()
plt.savefig("line.png", dpi=150)
print("Saved line.png")

Output:

text
Saved line.png

The chart is written to line.png and nothing is displayed on screen. The image shows a connected line with a circular marker at each data point.

When / why to use it

  • Generating charts for reports, notebooks, or CI artifacts.
  • Fine-grained control over every visual element (fonts, tick marks, annotations).
  • When you need reproducible figures saved to disk rather than interactive browser charts.

Common pitfalls

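plt.show() vs plt.savefig() — with an interactive (GUI) backend, plt.show() in a script blocks execution until you close the window. In headless environments (CI, Docker) there is no display: matplotlib normally falls back to the non-interactive Agg backend, where show() does nothing except emit a warning, so nothing reaches disk unless you call savefig(). In scripts, call plt.savefig() and skip plt.show().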

Figure accumulation — calling plt.plot() without plt.figure() or plt.clf() accumulates lines on the same axes across loop iterations. Call plt.figure() at the start of each chart or use the OO interface.

The object-oriented interface (fig, ax = plt.subplots()) is clearer for complex figures and avoids the global state issues of the plt.* convenience functions.
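
A minimal sketch of the accumulation pitfall next to the OO fix, assuming a headless run (hence the Agg backend). The series dictionary and the series_*.png filenames are hypothetical placeholders.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt

# Hypothetical data: three series that should each get their own chart
series = {"a": [1, 3, 2], "b": [2, 2, 4], "c": [5, 1, 3]}

# Pitfall: every plt.plot() call reuses pyplot's implicit current axes
for name, ys in series.items():
    plt.plot(ys)
accumulated = len(plt.gca().get_lines())  # all three lines ended up together
plt.close("all")

# Fix: an explicit Figure/Axes per chart, closed right after saving
line_counts = []
for name, ys in series.items():
    fig, ax = plt.subplots()
    ax.plot(ys)
    ax.set_title(name)
    line_counts.append(len(ax.get_lines()))
    fig.savefig(f"series_{name}.png")
    plt.close(fig)

print(accumulated, line_counts)  # 3 [1, 1, 1]
```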

Richer example — subplots with OO interface

python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(42)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))

# Left: trig functions
x = np.linspace(0, 2 * np.pi, 200)
axes[0].plot(x, np.sin(x), label="sin", linewidth=2)
axes[0].plot(x, np.cos(x), label="cos", linestyle="--", linewidth=2)
axes[0].set_title("Trig functions")
axes[0].set_xlabel("radians")
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Right: scatter
axes[1].scatter(rng.standard_normal(50), rng.standard_normal(50),
                alpha=0.6, c="coral", edgecolors="k", linewidths=0.5)
axes[1].set_title("Random scatter (n=50)")
axes[1].set_xlabel("x")
axes[1].set_ylabel("y")

plt.tight_layout()
plt.savefig("subplots.png", dpi=150, bbox_inches="tight")
print("Saved subplots.png")

Output:

text
Saved subplots.png

The output PNG shows two side-by-side panels: a smooth sine/cosine wave plot on the left and a scatter cloud on the right.

Common chart types

python
import matplotlib.pyplot as plt
import numpy as np

fig, axs = plt.subplots(2, 2, figsize=(10, 8))

# Bar chart
categories = ["A", "B", "C", "D"]
values = [23, 45, 12, 67]
axs[0, 0].bar(categories, values, color="steelblue")
axs[0, 0].set_title("Bar chart")

# Histogram
data = np.random.default_rng(0).standard_normal(500)
axs[0, 1].hist(data, bins=30, edgecolor="white")
axs[0, 1].set_title("Histogram")

# Pie chart
axs[1, 0].pie([40, 30, 20, 10], labels=categories, autopct="%1.0f%%")
axs[1, 0].set_title("Pie chart")

# Heatmap
matrix = np.random.default_rng(1).random((5, 5))
im = axs[1, 1].imshow(matrix, cmap="viridis")
fig.colorbar(im, ax=axs[1, 1])
axs[1, 1].set_title("Heatmap")

plt.tight_layout()
plt.savefig("chart_types.png", dpi=120)
print("Saved chart_types.png")

Output:

text
Saved chart_types.png

pyplot vs object-oriented API

matplotlib exposes two parallel interfaces. The pyplot (stateful, plt.*) API mirrors MATLAB: implicit "current figure" and "current axes" that every call mutates. The object-oriented (OO) API explicitly returns Figure and Axes objects you operate on: fig, ax = plt.subplots(); ax.plot(...). Use pyplot for quick one-off charts in a notebook; use the OO API for anything reusable, multi-panel, or scripted — it scales much better and is the recommended style.

python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

# Pyplot style
plt.figure()
plt.plot(x, np.sin(x))
plt.title("pyplot style")
plt.savefig("pyplot_style.png")
plt.close()

# OO style
fig, ax = plt.subplots()
ax.plot(x, np.sin(x))
ax.set_title("OO style")
ax.set_xlabel("x")
ax.set_ylabel("sin(x)")
fig.savefig("oo_style.png")
plt.close(fig)
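
The two styles can be mixed: pyplot's implicit state is just a current Figure and a current Axes, and plt.gcf() and plt.gca() hand them back as ordinary objects. A small sketch, with a hypothetical output filename:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

plt.figure()
plt.plot([0, 1, 2], [0, 1, 4])   # pyplot draws on the implicit current axes
fig = plt.gcf()                   # the current Figure
ax = plt.gca()                    # the current Axes

# From here on, continue in OO style on the very same objects
ax.set_title("started in pyplot, finished in OO")
ax.set_xlabel("x")
same_axes = ax is fig.axes[0]
n_lines = len(ax.get_lines())
fig.savefig("mixed_style.png")
plt.close(fig)
print(same_axes, n_lines)  # True 1
```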

Figures, axes, and artists

The matplotlib object hierarchy: a Figure is the whole canvas; a Figure contains one or more Axes (each Axes is a single plot area with its own coordinate system); every visible element on the canvas (lines, points, text, ticks, legends) is an Artist. Holding references to these objects lets you modify any property later, animate them, or remove them on demand.

python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(6, 4))
line, = ax.plot([1, 2, 3], [4, 1, 5])
text = ax.text(2, 4, "marker")

# Modify properties after creation
line.set_color("crimson")
line.set_linewidth(3)
text.set_fontsize(14)

# Iterate every artist on the axes
for artist in ax.get_children():
    print(type(artist).__name__)

Output (truncated):

text
Line2D
Text
Spine
Spine
Spine
Spine
XAxis
YAxis
...
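
Because plotting calls return live artists, those references can also hide or delete elements later. A sketch, assuming the Agg backend and an illustrative filename:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
keep, = ax.plot([1, 2, 3], [4, 1, 5], label="keep")
drop, = ax.plot([1, 2, 3], [2, 3, 1], label="drop")

drop.remove()               # detach the artist from its Axes entirely
keep.set_visible(False)     # hide without removing; can be toggled back
keep.set_visible(True)

print(len(ax.lines), ax.lines[0].get_label())  # 1 keep
fig.savefig("artists.png")
plt.close(fig)
```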

subplots and GridSpec

plt.subplots(rows, cols) returns a regular grid of axes. gridspec.GridSpec is the flexible alternative: it lets one subplot span multiple cells, gives rows and columns different size ratios, and mixes subplot sizes within one figure. Use GridSpec when a dashboard needs, for example, one wide chart on top and several small ones below.

python
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

fig = plt.figure(figsize=(10, 6), constrained_layout=True)
gs = gridspec.GridSpec(2, 3, figure=fig, height_ratios=[2, 1])

ax_main = fig.add_subplot(gs[0, :])        # top row, full width
ax_main.plot(np.arange(20), np.sin(np.arange(20) / 2))
ax_main.set_title("main")

for i in range(3):
    ax = fig.add_subplot(gs[1, i])
    ax.bar(["a", "b", "c"], np.random.default_rng(i).random(3))
    ax.set_title(f"panel {i}")

fig.savefig("gridspec.png", dpi=120)
plt.close(fig)

Pass constrained_layout=True (or, on matplotlib 3.5+, layout="constrained") to plt.subplots / plt.figure for automatic spacing. It is the more flexible successor to tight_layout() and handles colorbars and suptitles correctly.
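
A minimal sketch of constrained layout with one colorbar shared by two panels plus a suptitle, assuming matplotlib 3.5+ for the layout= keyword (constrained_layout=True is the older spelling). The random data and the filename are placeholders.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

# layout="constrained" is equivalent to constrained_layout=True
fig, axs = plt.subplots(1, 2, figsize=(8, 3), layout="constrained")
rng = np.random.default_rng(0)
for ax, title in zip(axs, ["left", "right"]):
    im = ax.imshow(rng.random((8, 8)), cmap="viridis")
    ax.set_title(title)

# One colorbar for both panels; the layout engine reserves room for it
fig.colorbar(im, ax=axs, shrink=0.8)
fig.suptitle("Constrained layout: colorbar and suptitle included")
fig.savefig("constrained.png", dpi=120)
plt.close(fig)
```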

Styling — rcParams, style sheets, and custom themes

plt.rcParams is a global dict-like object holding every default visual setting (font, line width, colors, DPI). Set it once at the top of a script for script-wide style. plt.style.use("seaborn-v0_8-whitegrid") applies a packaged stylesheet (the seaborn-v0_8-* names apply to matplotlib 3.6+). For a project-wide style, write a mystyle.mplstyle file with one key: value setting per line and call plt.style.use("mystyle.mplstyle").

python
import matplotlib.pyplot as plt
import numpy as np

# Pick a built-in style
print(plt.style.available[:8])

plt.style.use("ggplot")

# Or set rcParams individually
plt.rcParams.update({
    "figure.dpi": 120,
    "font.size": 11,
    "axes.titleweight": "bold",
    "axes.spines.top": False,
    "axes.spines.right": False,
})

fig, ax = plt.subplots(figsize=(6, 3))
ax.plot(np.linspace(0, 10, 100), np.sin(np.linspace(0, 10, 100)))
ax.set_title("Custom style")
fig.savefig("styled.png")
plt.close(fig)

Output:

text
['Solarize_Light2', '_classic_test_patch', '_mpl-gallery', '_mpl-gallery-nogrid', 'bmh', 'classic', 'dark_background', 'fast']
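
To try a project style without permanently changing global state, plt.style.context() applies a style sheet inside a with block and restores rcParams afterwards. The sketch below writes a hypothetical mystyle.mplstyle into a temporary directory; the specific settings are only illustrative.

```python
import tempfile
from pathlib import Path

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# Hypothetical project style: one "key: value" setting per line
style_text = """
lines.linewidth: 3
axes.grid: True
axes.spines.top: False
axes.spines.right: False
"""
style_path = Path(tempfile.mkdtemp()) / "mystyle.mplstyle"
style_path.write_text(style_text)

before = plt.rcParams["lines.linewidth"]
# style.context applies the sheet temporarily and restores rcParams on exit
with plt.style.context(str(style_path)):
    inside = plt.rcParams["lines.linewidth"]
    fig, ax = plt.subplots(figsize=(5, 3))
    ax.plot([0, 1, 2], [0, 1, 0])
    fig.savefig("mystyle_demo.png")
    plt.close(fig)
after = plt.rcParams["lines.linewidth"]
print(inside, after == before)  # 3.0 True
```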

Colormaps

A cmap maps numbers (or category indices) to colors. Pick from one of three families based on your data:

  • Sequential (viridis, magma, plasma, cividis) — for data with a meaningful order from low to high. viridis is the modern default and is perceptually uniform.
  • Diverging (RdBu, coolwarm, PuOr) — for data with a meaningful midpoint (e.g. positive vs negative deviations).
  • Qualitative (tab10, Set2, Pastel1) — for unordered categories.

Avoid jet and rainbow — they introduce false visual gradients that mislead the eye. viridis is colorblind-safe and prints well in grayscale.

python
import matplotlib.pyplot as plt
import numpy as np

data = np.random.default_rng(0).standard_normal((10, 10))

fig, axs = plt.subplots(1, 3, figsize=(12, 3))
for ax, cmap in zip(axs, ["viridis", "RdBu_r", "tab10"]):
    im = ax.imshow(data, cmap=cmap)
    ax.set_title(cmap)
    fig.colorbar(im, ax=ax, fraction=0.046)
fig.savefig("colormaps.png", dpi=120, bbox_inches="tight")
plt.close(fig)
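
The diverging example above lets imshow autoscale, so zero is not guaranteed to land on the neutral midpoint color. One way to pin it, sketched here on made-up skewed data, is matplotlib.colors.TwoSlopeNorm. The sketch also pulls discrete colors for categories from tab10 through the matplotlib.colormaps registry (matplotlib 3.5+); the filename is a placeholder.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
import numpy as np

# Made-up skewed data: roughly -1 to +3, with 0 as the meaningful midpoint
data = np.random.default_rng(0).uniform(-1, 3, (10, 10))

# TwoSlopeNorm maps vcenter to the middle of the colormap (0.5), so zero
# gets the neutral color even though the data range is asymmetric
norm = TwoSlopeNorm(vcenter=0.0, vmin=data.min(), vmax=data.max())

# For unordered categories, take discrete colors from a qualitative map
tab10 = matplotlib.colormaps["tab10"]
category_colors = list(tab10.colors)[:3]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 3))
im = ax1.imshow(data, cmap="RdBu_r", norm=norm)
fig.colorbar(im, ax=ax1)
ax1.set_title("diverging, centered at 0")
ax2.bar(["x", "y", "z"], [3, 5, 2], color=category_colors)
ax2.set_title("qualitative colors")
fig.savefig("colormap_norms.png", dpi=120, bbox_inches="tight")
plt.close(fig)
```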

Annotations, legends, and text

Annotations call attention to specific data points; legends explain which line/marker is which series. Use ax.annotate(text, xy=(x, y), xytext=(x2, y2), arrowprops=...) to draw an arrow from a callout to a point, and ax.text(x, y, s) for free-floating labels.

python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y = np.sin(x)

fig, ax = plt.subplots(figsize=(7, 3))
ax.plot(x, y, label="sin(x)")
ax.scatter([np.pi / 2], [1], color="red", zorder=5, label="peak")

ax.annotate(
    "peak at π/2",
    xy=(np.pi / 2, 1),
    xytext=(4, 0.8),
    arrowprops={"arrowstyle": "->", "color": "gray"},
)
ax.legend(loc="lower left", frameon=False)
ax.set_xlabel("x")
ax.set_ylabel("sin(x)")
fig.savefig("annotation.png", dpi=120)
plt.close(fig)
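
Two variations that often help in practice, shown as a sketch with illustrative data and filename: textcoords="offset points" makes xytext an offset (in typographic points) from the annotated point, and bbox_to_anchor moves the legend outside the plotting area.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots(figsize=(7, 3))
ax.plot(x, np.sin(x), label="sin(x)")
ax.plot(x, np.cos(x), label="cos(x)")

# xytext is now 20 points right and 20 points up from xy, so the callout
# keeps the same distance from the point even if the axis limits change
ann = ax.annotate("trough", xy=(3 * np.pi / 2, -1), xytext=(20, 20),
                  textcoords="offset points",
                  arrowprops={"arrowstyle": "->"})

# (1.02, 1) in axes coordinates is just outside the top-right corner;
# bbox_inches="tight" keeps the outside legend inside the saved image
ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1), frameon=False)
fig.savefig("annotation_offset.png", dpi=120, bbox_inches="tight")
plt.close(fig)
print(ann.get_text())  # trough
```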

Savefig — file formats and DPI

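fig.savefig(path) infers the format from the file extension (.png, .pdf, .svg, .jpg, .webp). Vector formats (.pdf, .svg) scale to any size without pixelation, which makes them the usual choice for print and slides. Raster formats (.png, .jpg) are resolution-bound: dpi= sets pixels per inch, and without it the figure's own dpi is used (100 unless changed). As rough guides, 150 suits screens, 300 suits print, and 600 or more is often requested for publication line art.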

python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(4, 3))
ax.plot([1, 2, 3], [1, 4, 2])
ax.set_title("Save examples")

# Different output formats
fig.savefig("chart.png", dpi=150, bbox_inches="tight")   # screen
fig.savefig("chart.pdf", bbox_inches="tight")             # vector for print
fig.savefig("chart.svg")                                  # vector for web
fig.savefig("chart_hi.png", dpi=600)                      # publication

plt.close(fig)

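Pass bbox_inches="tight" to crop surplus whitespace around the figure; because this changes the saved image's pixel dimensions, leave it out when you need an exact size. transparent=True makes the figure and axes backgrounds see-through in formats that support it, such as PNG, which is useful when embedding in slides or web pages with non-white backgrounds.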
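
Raster size follows directly from figure size and DPI: pixel width = figsize width in inches × dpi. The sketch below checks that by saving to an in-memory buffer (io.BytesIO, handy for web responses) instead of a file; a 4 × 3 inch figure at 100 dpi should read back as 400 × 300 pixels.

```python
import io

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(4, 3))
ax.plot([1, 2, 3], [1, 4, 2])

buf = io.BytesIO()
# No bbox_inches="tight": cropping would change the pixel dimensions
fig.savefig(buf, format="png", dpi=100)
plt.close(fig)

png_bytes = buf.getvalue()   # raw PNG bytes, e.g. for an HTTP response
buf.seek(0)
img = plt.imread(buf)        # RGBA array of shape (height, width, 4)
print(img.shape)             # (300, 400, 4)
```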

Animations

matplotlib.animation.FuncAnimation repeatedly calls an update function and grabs each frame. The result can be saved as MP4 (requires ffmpeg) or GIF (requires pillow), or embedded in a notebook as HTML5 video via anim.to_html5_video(), which also needs ffmpeg.

python
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np

fig, ax = plt.subplots(figsize=(6, 3))
x = np.linspace(0, 2 * np.pi, 200)
line, = ax.plot(x, np.sin(x))
ax.set_ylim(-1.2, 1.2)

def update(frame):
    line.set_ydata(np.sin(x + frame / 10))
    return (line,)

anim = animation.FuncAnimation(fig, update, frames=60, interval=50, blit=True)
anim.save("sine.gif", writer="pillow", fps=20)
plt.close(fig)

blit=True is usually much faster: matplotlib redraws only the artists that change between frames. For that to work, the update function must return those artists as an iterable, typically a tuple such as (line,).
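
A sketch of the blitting pattern with an explicit init_func, rendered through anim.to_jshtml(), which embeds the frames in an HTML/JavaScript player and so needs no ffmpeg. It assumes you display the resulting string in a notebook with IPython.display.HTML; the frame count and speed are arbitrary.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np

fig, ax = plt.subplots(figsize=(5, 3))
x = np.linspace(0, 2 * np.pi, 100)
line, = ax.plot([], [])
ax.set_xlim(0, 2 * np.pi)
ax.set_ylim(-1.2, 1.2)

def init():
    # Blank starting frame; blitting uses it as the clean background
    line.set_data([], [])
    return (line,)

def update(frame):
    # Only `line` changes, so it is the only artist returned for blitting
    line.set_data(x, np.sin(x + frame / 5))
    return (line,)

anim = animation.FuncAnimation(fig, update, frames=20, init_func=init,
                               interval=50, blit=True)
# HTML string with an embedded JS player (frames stored as images)
html = anim.to_jshtml(fps=20)
plt.close(fig)
```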

3-D plots

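Passing projection="3d" to fig.add_subplot (or subplot_kw={"projection": "3d"} to plt.subplots) creates 3-D axes; on matplotlib 3.2 and later the older from mpl_toolkits.mplot3d import Axes3D import is no longer required. plot_surface, scatter, and quiver are the common methods. For interactive 3-D in a browser, prefer plotly: matplotlib renders 3-D by projecting onto a 2-D canvas without a real 3-D engine, so overlapping objects can draw in the wrong order, and it slows down on large data.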

python
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure(figsize=(6, 4))
ax = fig.add_subplot(projection="3d")

X = np.linspace(-3, 3, 60)
Y = np.linspace(-3, 3, 60)
X, Y = np.meshgrid(X, Y)
Z = np.sin(np.sqrt(X**2 + Y**2))

ax.plot_surface(X, Y, Z, cmap="viridis", linewidth=0)
ax.set_title("3-D surface")
fig.savefig("surface.png", dpi=120)
plt.close(fig)
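
The same setup through plt.subplots, sketched on a hypothetical random point cloud: subplot_kw forwards projection="3d" to the created Axes, and view_init fixes the camera angle for a static export.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
pts = rng.standard_normal((200, 3))   # hypothetical 3-D point cloud

# subplot_kw is passed to every Axes that subplots() creates
fig, ax = plt.subplots(figsize=(5, 4), subplot_kw={"projection": "3d"})
ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], c=pts[:, 2], cmap="viridis", s=10)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
ax.view_init(elev=25, azim=45)        # fixed camera angle for the PNG
fig.savefig("scatter3d.png", dpi=120)
plt.close(fig)
print(ax.name)  # 3d
```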

Backends — interactive vs headless

A backend is the rendering engine that turns matplotlib primitives into pixels. The default depends on your environment; check with plt.get_backend(). For headless CI / Docker, set matplotlib.use("Agg") before importing pyplot — this disables any GUI dependency.

python
import matplotlib
matplotlib.use("Agg")     # must be set before pyplot import
import matplotlib.pyplot as plt

print(plt.get_backend())  # "agg"

fig, ax = plt.subplots()
ax.plot([1, 2, 3], [1, 4, 2])
fig.savefig("ci_safe.png")
plt.close(fig)

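Avoid plt.show() in CI scripts and Docker containers: under Agg it does nothing but emit a warning, and a GUI backend would try to open a window that may not exist or that nobody can close. GUI backends can also raise errors in such environments, for example Tk's RuntimeError: main thread is not in main loop when used off the main thread.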
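
Setting the backend in code is not the only option: matplotlib also reads the MPLBACKEND environment variable at import time, so a CI job or Dockerfile (e.g. ENV MPLBACKEND=Agg) can enforce a headless backend without touching any script. The sketch below simulates that by starting a fresh interpreter with the variable set; it assumes matplotlib is installed for the same sys.executable.

```python
import os
import subprocess
import sys

# MPLBACKEND is only consulted when matplotlib is first imported, so we
# launch a new interpreter rather than changing the current process
env = dict(os.environ, MPLBACKEND="Agg")
result = subprocess.run(
    [sys.executable, "-c", "import matplotlib; print(matplotlib.get_backend())"],
    env=env, capture_output=True, text=True, check=True,
)
backend = result.stdout.strip()
print(backend)
```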

Comparison with seaborn, plotly, and altair

matplotlib is the substrate for the entire Python plotting ecosystem. Pick the right wrapper for the task:

Library | Built on | Best for
matplotlib | (itself) | Fine-grained control, publication figures, animations, anything custom
seaborn | matplotlib | Statistical plots (heatmaps, pairplots, violin plots) with shorter code
pandas .plot | matplotlib | Quick exploratory charts directly from a DataFrame
plotly | own JS library (plotly.js) | Interactive browser charts, 3-D, dashboards
altair | Vega-Lite | Declarative grammar-of-graphics for browser charts
bokeh | own JS library (BokehJS) | Interactive plots with linked brushing, large datasets

python
# seaborn — built on matplotlib, great defaults
import seaborn as sns
import matplotlib.pyplot as plt

tips = sns.load_dataset("tips")
sns.boxplot(data=tips, x="day", y="total_bill", hue="sex")
plt.savefig("seaborn_box.png", dpi=120, bbox_inches="tight")
plt.close()
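
python
# plotly — interactive
# trendline="ols" fits an ordinary least-squares line and needs statsmodels
import plotly.express as px

tips = px.data.tips()  # plotly ships its own copy of the tips dataset
fig = px.scatter(tips, x="total_bill", y="tip", color="sex", trendline="ols")
fig.write_html("scatter.html")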

Common pitfalls (reference)

Memory leaks from un-closed figures — pyplot keeps a reference to every figure it creates until you close it, so long-running scripts that create many figures should call plt.close(fig) (or plt.close("all")) after each one.

plt.show() discards the figure — once the window opened by show() is closed, pyplot forgets that figure, so a later plt.savefig() call writes a new, blank figure. Always savefig first, then show (or skip show entirely in scripts).

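Text rendering inconsistencies — matplotlib ships with DejaVu fonts and locates other families through its own font cache. If your rcParams request a family that isn't installed (common on minimal Docker images), it falls back with a "findfont" warning, and characters the active font lacks (e.g. CJK text) render as empty boxes. Install a font package that covers your text, or register a TTF file explicitly with matplotlib.font_manager.fontManager.addfont(path).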

LaTeX math rendering is slow: usetex=True shells out to a real LaTeX installation and can take seconds per render. For simple math, the default mathtext (r"$\alpha + \beta$") is fast and ships with matplotlib.
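
A quick way to spot the figure leak described above is plt.get_fignums(), which lists the figures pyplot is still tracking; pyplot also warns once more than rcParams["figure.max_open_warning"] (20 by default) figures are open. A sketch with placeholder filenames:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

plt.close("all")
# Leaky loop: every figure stays registered with pyplot
for i in range(5):
    fig, ax = plt.subplots()
    ax.plot([0, 1], [0, i])
    fig.savefig(f"leaky_{i}.png")
leaked = len(plt.get_fignums())   # 5 figures still alive

plt.close("all")
# Fixed loop: close each figure as soon as it has been saved
for i in range(5):
    fig, ax = plt.subplots()
    ax.plot([0, 1], [0, i])
    fig.savefig(f"clean_{i}.png")
    plt.close(fig)
print(leaked, len(plt.get_fignums()))  # 5 0
```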

Real-world recipes

Dual-axis chart (revenue and orders)

Two metrics on different scales share an x-axis via ax.twinx().

python
import matplotlib.pyplot as plt
import numpy as np

months = np.arange(1, 13)
revenue = 10 + np.cumsum(np.random.default_rng(0).normal(2, 1, 12))
orders = 100 + np.cumsum(np.random.default_rng(1).normal(15, 8, 12))

fig, ax1 = plt.subplots(figsize=(8, 4))
ax1.plot(months, revenue, "b-o", label="revenue")
ax1.set_xlabel("month")
ax1.set_ylabel("revenue (k$)", color="b")
ax1.tick_params(axis="y", labelcolor="b")

ax2 = ax1.twinx()
ax2.plot(months, orders, "r--s", label="orders")
ax2.set_ylabel("orders", color="r")
ax2.tick_params(axis="y", labelcolor="r")

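# Each twin axes keeps its own legend entries; merge them into one legend
h1, l1 = ax1.get_legend_handles_labels()
h2, l2 = ax2.get_legend_handles_labels()
ax1.legend(h1 + h2, l1 + l2, loc="upper left")
fig.tight_layout()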
fig.savefig("dual_axis.png", dpi=120)
plt.close(fig)

Correlation heatmap from a DataFrame

python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

df = pd.DataFrame(np.random.default_rng(0).standard_normal((100, 5)),
                  columns=list("ABCDE"))
corr = df.corr()

fig, ax = plt.subplots(figsize=(5, 4))
im = ax.imshow(corr, cmap="RdBu_r", vmin=-1, vmax=1)
ax.set_xticks(range(len(corr)), corr.columns)
ax.set_yticks(range(len(corr)), corr.columns)
for i in range(len(corr)):
    for j in range(len(corr)):
        ax.text(j, i, f"{corr.iloc[i, j]:.2f}", ha="center", va="center", fontsize=9)
fig.colorbar(im, ax=ax)
fig.savefig("corr_heatmap.png", dpi=120, bbox_inches="tight")
plt.close(fig)

Grouped bar chart

python
import matplotlib.pyplot as plt
import numpy as np

regions = ["East", "West", "North", "South"]
q1 = [120, 95, 80, 110]
q2 = [135, 100, 92, 115]

x = np.arange(len(regions))
w = 0.35

fig, ax = plt.subplots(figsize=(6, 4))
ax.bar(x - w/2, q1, w, label="Q1")
ax.bar(x + w/2, q2, w, label="Q2")
ax.set_xticks(x, regions)
ax.legend()
ax.set_ylabel("sales (k$)")
ax.set_title("Sales by quarter and region")
fig.savefig("grouped_bar.png", dpi=120, bbox_inches="tight")
plt.close(fig)

Render a figure in a Streamlit dashboard

Streamlit accepts a Figure object directly through st.pyplot(fig), so there is no need to save to disk. Launch the script with streamlit run <script>.py.

python
import streamlit as st
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([1, 2, 3], [4, 1, 5])
st.pyplot(fig)

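Plot a regression with a ±2σ band

Fit a straight line with np.polyfit and shade ±2σ of the residuals with fill_between; the same plotting pattern works with parameters from scipy.optimize.curve_fit. The band shows how individual points scatter around the fit (roughly a 95% prediction band), not the uncertainty of the fitted line itself.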

python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(0, 10, 50)
y = 2 * x + 1 + rng.normal(scale=2, size=x.size)

# Fit a line: y = ax + b
a, b = np.polyfit(x, y, 1)
y_hat = a * x + b
resid = y - y_hat
sigma = resid.std(ddof=2)

fig, ax = plt.subplots(figsize=(7, 4))
ax.scatter(x, y, alpha=0.6, label="data")
ax.plot(x, y_hat, color="crimson", label=f"y = {a:.2f}x + {b:.2f}")
ax.fill_between(x, y_hat - 2*sigma, y_hat + 2*sigma, color="crimson", alpha=0.15, label="±2σ")
ax.legend()
ax.set_xlabel("x")
ax.set_ylabel("y")
fig.tight_layout()
fig.savefig("regression_band.png", dpi=120)
plt.close(fig)
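
The residual band above answers "where will new points fall". If the question is instead "how certain is the fitted line", the textbook OLS formulas give a narrower confidence band that pinches near the mean of x. A sketch on the same synthetic data, using 2 as an approximate 95% multiplier rather than an exact t quantile; the filename is a placeholder.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(0, 10, 50)
y = 2 * x + 1 + rng.normal(scale=2, size=x.size)

a, b = np.polyfit(x, y, 1)
y_hat = a * x + b
n = x.size
sigma = (y - y_hat).std(ddof=2)        # residual standard deviation
sxx = np.sum((x - x.mean()) ** 2)

# Standard error of the fitted mean line (smallest near x.mean())
se_fit = sigma * np.sqrt(1 / n + (x - x.mean()) ** 2 / sxx)
# A new observation adds the residual variance on top
se_pred = sigma * np.sqrt(1 + 1 / n + (x - x.mean()) ** 2 / sxx)

z = 2.0  # rough 95% multiplier; a t quantile with n - 2 dof is more exact
fig, ax = plt.subplots(figsize=(7, 4))
ax.scatter(x, y, alpha=0.6, label="data")
ax.plot(x, y_hat, color="crimson", label="fit")
ax.fill_between(x, y_hat - z * se_pred, y_hat + z * se_pred,
                color="crimson", alpha=0.1, label="prediction band")
ax.fill_between(x, y_hat - z * se_fit, y_hat + z * se_fit,
                color="crimson", alpha=0.3, label="confidence band")
ax.legend()
fig.savefig("regression_bands.png", dpi=120)
plt.close(fig)
```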

Stacked area chart

Stacked areas highlight totals and composition simultaneously — useful for "how is each segment contributing over time".

python
import matplotlib.pyplot as plt
import numpy as np

months = np.arange(1, 13)
a = 20 + np.cumsum(np.random.default_rng(0).normal(2, 1, 12))
b = 15 + np.cumsum(np.random.default_rng(1).normal(1, 1, 12))
c = 10 + np.cumsum(np.random.default_rng(2).normal(1.5, 1, 12))

fig, ax = plt.subplots(figsize=(8, 4))
ax.stackplot(months, a, b, c, labels=["Product A", "Product B", "Product C"], alpha=0.85)
ax.legend(loc="upper left")
ax.set_xlabel("month")
ax.set_ylabel("revenue")
ax.set_title("Stacked revenue by product")
fig.savefig("stacked_area.png", dpi=120, bbox_inches="tight")
plt.close(fig)
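
If only the composition matters, normalizing each month to 100% turns the same data into a percent-stacked area chart. A sketch that regenerates the synthetic series above so it runs on its own; the filename is a placeholder.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

months = np.arange(1, 13)
a = 20 + np.cumsum(np.random.default_rng(0).normal(2, 1, 12))
b = 15 + np.cumsum(np.random.default_rng(1).normal(1, 1, 12))
c = 10 + np.cumsum(np.random.default_rng(2).normal(1.5, 1, 12))

# Rows are products, columns are months; divide by each month's total
stack = np.vstack([a, b, c])
shares = 100 * stack / stack.sum(axis=0)

fig, ax = plt.subplots(figsize=(8, 4))
# stackplot accepts a 2-D array with one row per series
ax.stackplot(months, shares, labels=["Product A", "Product B", "Product C"], alpha=0.85)
ax.set_ylim(0, 100)
ax.legend(loc="upper left")
ax.set_xlabel("month")
ax.set_ylabel("share of revenue (%)")
ax.set_title("Revenue mix by product")
fig.savefig("stacked_share.png", dpi=120, bbox_inches="tight")
plt.close(fig)
```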

Programmatic small multiples

Faceted charts ("small multiples"), popularized by Edward Tufte as an alternative to one busy chart, are easy to build with subplots plus a loop.

python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
regions = ["East", "West", "North", "South"]
fig, axs = plt.subplots(2, 2, figsize=(8, 5), sharex=True, sharey=True)

for ax, region in zip(axs.flat, regions):
    months = np.arange(12)
    sales = 100 + np.cumsum(rng.normal(5, 8, 12))
    ax.plot(months, sales, marker="o")
    ax.set_title(region)
    ax.grid(True, alpha=0.3)

fig.suptitle("Sales by region — small multiples")
fig.supxlabel("month")
fig.supylabel("sales")
fig.tight_layout()
fig.savefig("small_multiples.png", dpi=120)
plt.close(fig)

See also

  • numpy — most matplotlib inputs are NumPy arrays.
  • pandas: df.plot() is a thin matplotlib wrapper.
  • jupyter — inline figures via %matplotlib inline.
  • streamlit — render matplotlib figures in interactive web apps.
  • scipy — pair with matplotlib to visualise curve fits and signal spectra.