matplotlib
Create publication-quality 2-D plots with matplotlib. Covers pyplot basics, subplots, savefig, common chart types, and the show-vs-save pitfall.
matplotlib — Plotting
What it is
matplotlib is the foundational Python plotting library. It produces static, animated, or interactive figures in PNG, PDF, SVG, and more. Higher-level libraries such as seaborn and pandas' .plot() are built directly on top of it.
Install
pip install matplotlib
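Output: pip's usual download and install log, ending with a "Successfully installed ..." line that lists matplotlib and its dependencies; the command exits 0 on success.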
Quick example
import matplotlib.pyplot as plt
x = [1, 2, 3, 4, 5]
y = [2, 4, 3, 5, 4]
plt.plot(x, y, marker="o", color="steelblue")
plt.title("Simple Line Plot")
plt.xlabel("x")
plt.ylabel("y")
plt.tight_layout()
plt.savefig("line.png", dpi=150)
print("Saved line.png")
Output:
Saved line.png
The plot itself is a PNG file. The image shows a connected line with circular markers at each data point.
When / why to use it
- Generating charts for reports, notebooks, or CI artifacts.
- Fine-grained control over every visual element (fonts, tick marks, annotations).
- When you need reproducible figures saved to disk rather than interactive browser charts.
Common pitfalls
plt.show() vs plt.savefig(): in non-interactive scripts, calling plt.show() blocks execution until you close the window. In headless environments (CI, Docker) it may crash, or, with the Agg backend, only emit a warning and display nothing. Use plt.savefig() and avoid plt.show() in scripts.
Figure accumulation: calling plt.plot() without plt.figure() or plt.clf() accumulates lines on the same axes across loop iterations. Call plt.figure() at the start of each chart (and plt.close() when you are done with it), or use the OO interface.
The object-oriented interface (fig, ax = plt.subplots()) is clearer for complex figures and avoids the global-state issues of the plt.* convenience functions.
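A minimal sketch of the loop pattern, assuming you want one PNG per series; the series names, values, and file names are illustrative. Each iteration gets its own Figure from plt.subplots(), and plt.close(fig) releases it so nothing piles up on shared axes.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend: no window, safe in CI
import matplotlib.pyplot as plt

# Illustrative data: one chart per series
series = {
    "alpha": [1, 3, 2, 4],
    "beta": [2, 2, 3, 1],
    "gamma": [4, 1, 3, 2],
}

saved = []
for name, values in series.items():
    fig, ax = plt.subplots()        # fresh Figure and Axes on every iteration
    ax.plot(values, marker="o")
    ax.set_title(name)
    path = f"series_{name}.png"
    fig.savefig(path)
    plt.close(fig)                  # release the figure so lines never accumulate
    saved.append(path)

print(saved)
print(plt.get_fignums())  # [] : no figures left open
```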
Richer example — subplots with OO interface
import matplotlib.pyplot as plt
import numpy as np
rng = np.random.default_rng(42)
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
# Left: trig functions
x = np.linspace(0, 2 * np.pi, 200)
axes[0].plot(x, np.sin(x), label="sin", linewidth=2)
axes[0].plot(x, np.cos(x), label="cos", linestyle="--", linewidth=2)
axes[0].set_title("Trig functions")
axes[0].set_xlabel("radians")
axes[0].legend()
axes[0].grid(True, alpha=0.3)
# Right: scatter
axes[1].scatter(rng.standard_normal(50), rng.standard_normal(50),
alpha=0.6, c="coral", edgecolors="k", linewidths=0.5)
axes[1].set_title("Random scatter (n=50)")
axes[1].set_xlabel("x")
axes[1].set_ylabel("y")
plt.tight_layout()
plt.savefig("subplots.png", dpi=150, bbox_inches="tight")
print("Saved subplots.png")
Output:
Saved subplots.png
The output PNG shows two side-by-side panels: a smooth sine/cosine wave plot on the left and a scatter cloud on the right.
Common chart types
import matplotlib.pyplot as plt
import numpy as np
fig, axs = plt.subplots(2, 2, figsize=(10, 8))
# Bar chart
categories = ["A", "B", "C", "D"]
values = [23, 45, 12, 67]
axs[0, 0].bar(categories, values, color="steelblue")
axs[0, 0].set_title("Bar chart")
# Histogram
data = np.random.default_rng(0).standard_normal(500)
axs[0, 1].hist(data, bins=30, edgecolor="white")
axs[0, 1].set_title("Histogram")
# Pie chart
axs[1, 0].pie([40, 30, 20, 10], labels=categories, autopct="%1.0f%%")
axs[1, 0].set_title("Pie chart")
# Heatmap
matrix = np.random.default_rng(1).random((5, 5))
im = axs[1, 1].imshow(matrix, cmap="viridis")
fig.colorbar(im, ax=axs[1, 1])
axs[1, 1].set_title("Heatmap")
plt.tight_layout()
plt.savefig("chart_types.png", dpi=120)
print("Saved chart_types.png")
Output:
Saved chart_types.png
pyplot vs object-oriented API
matplotlib exposes two parallel interfaces. The pyplot (stateful, plt.*) API mirrors MATLAB: implicit "current figure" and "current axes" that every call mutates. The object-oriented (OO) API explicitly returns Figure and Axes objects you operate on: fig, ax = plt.subplots(); ax.plot(...). Use pyplot for quick one-off charts in a notebook; use the OO API for anything reusable, multi-panel, or scripted — it scales much better and is the recommended style.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
# Pyplot style
plt.figure()
plt.plot(x, np.sin(x))
plt.title("pyplot style")
plt.savefig("pyplot_style.png")
plt.close()
# OO style
fig, ax = plt.subplots()
ax.plot(x, np.sin(x))
ax.set_title("OO style")
ax.set_xlabel("x")
ax.set_ylabel("sin(x)")
fig.savefig("oo_style.png")
plt.close(fig)
Figures, axes, and artists
The matplotlib object hierarchy: a Figure is the whole canvas; a Figure contains one or more Axes (each Axes is a single plot area with its own coordinate system); and everything drawn on the canvas is an Artist, including lines, points, text, ticks, and legends, and even the Figure and Axes themselves. Holding references to these objects lets you modify any property later, animate them, or remove them on demand.
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(6, 4))
line, = ax.plot([1, 2, 3], [4, 1, 5])
text = ax.text(2, 4, "marker")
# Modify properties after creation
line.set_color("crimson")
line.set_linewidth(3)
text.set_fontsize(14)
# Iterate every artist on the axes
for artist in ax.get_children():
    print(type(artist).__name__)
Output (truncated):
Line2D
Text
Spine
Spine
Spine
Spine
XAxis
YAxis
...
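Removing an artist works through the same handles. A small sketch, with illustrative data and labels: one line is restyled through its Line2D handle, the other is deleted with Artist.remove(), after which it no longer appears in ax.lines.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
(raw,) = ax.plot([1, 2, 3], [4, 1, 5], label="raw")
(smooth,) = ax.plot([1, 2, 3], [3, 3, 3], label="smooth")

# Restyle one line through its handle, then drop the other one entirely
smooth.set_linestyle("--")
raw.remove()

print(len(ax.lines))            # 1
print(ax.lines[0].get_label())  # smooth
fig.savefig("artist_remove.png")
plt.close(fig)
```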
subplots and GridSpec
plt.subplots(rows, cols) returns a regular grid of axes. gridspec.GridSpec is the flexible alternative: it lets you span multiple cells, give different ratios to rows/columns, and mix subplot sizes within one figure. Use GridSpec when a dashboard needs, for example, one wide chart on top and several smaller ones below.
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
fig = plt.figure(figsize=(10, 6), constrained_layout=True)
gs = gridspec.GridSpec(2, 3, figure=fig, height_ratios=[2, 1])
ax_main = fig.add_subplot(gs[0, :]) # top row, full width
ax_main.plot(np.arange(20), np.sin(np.arange(20) / 2))
ax_main.set_title("main")
for i in range(3):
    ax = fig.add_subplot(gs[1, i])
    ax.bar(["a", "b", "c"], np.random.default_rng(i).random(3))
    ax.set_title(f"panel {i}")
fig.savefig("gridspec.png", dpi=120)
plt.close(fig)
Pass constrained_layout=True to plt.subplots or plt.figure for automatic spacing. It is the modern replacement for tight_layout() and handles colorbars and suptitles correctly.
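The same "wide chart on top, small panels below" layout can also be described with plt.subplot_mosaic (matplotlib 3.3+), which names each panel instead of indexing a GridSpec. This is an alternative technique, not what the GridSpec example uses; a minimal sketch with illustrative panel names:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

# Each inner list is a row; repeating a name makes that Axes span those cells
fig, axd = plt.subplot_mosaic(
    [["main", "main", "main"],
     ["a", "b", "c"]],
    figsize=(10, 6),
    gridspec_kw={"height_ratios": [2, 1]},  # top row twice as tall
    constrained_layout=True,
)

x = np.arange(20)
axd["main"].plot(x, np.sin(x / 2))
axd["main"].set_title("main")
for i, name in enumerate(["a", "b", "c"]):
    axd[name].bar(["a", "b", "c"], np.random.default_rng(i).random(3))
    axd[name].set_title(f"panel {name}")

fig.savefig("mosaic.png", dpi=120)
print(sorted(axd))  # ['a', 'b', 'c', 'main']
plt.close(fig)
```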
Styling — rcParams, style sheets, and custom themes
plt.rcParams is a global, dict-like object holding every default visual setting (font, line width, colors, DPI). Set it once at the top of a script for a script-wide style. plt.style.use("seaborn-v0_8-whitegrid") applies a packaged stylesheet. For a project-wide style, write a mystyle.mplstyle file with key: value lines and call plt.style.use("mystyle.mplstyle").
import matplotlib.pyplot as plt
import numpy as np
# Pick a built-in style
print(plt.style.available[:8])
plt.style.use("ggplot")
# Or set rcParams individually
plt.rcParams.update({
"figure.dpi": 120,
"font.size": 11,
"axes.titleweight": "bold",
"axes.spines.top": False,
"axes.spines.right": False,
})
fig, ax = plt.subplots(figsize=(6, 3))
ax.plot(np.linspace(0, 10, 100), np.sin(np.linspace(0, 10, 100)))
ax.set_title("Custom style")
fig.savefig("styled.png")
plt.close(fig)
Output:
['Solarize_Light2', '_classic_test_patch', '_mpl-gallery', '_mpl-gallery-nogrid', 'bmh', 'classic', 'dark_background', 'fast']
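A minimal sketch of the .mplstyle route. The file name and the three settings are illustrative. plt.style.context applies the sheet only inside the with block and restores the previous rcParams afterwards, which avoids the global side effects of plt.style.use.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from pathlib import Path

# A tiny project stylesheet: one "key: value" rcParams entry per line
style_path = Path("mystyle.mplstyle")
style_path.write_text(
    "lines.linewidth: 2.5\n"
    "axes.grid: True\n"
    "font.size: 12\n"
)

default_lw = plt.rcParams["lines.linewidth"]

# The style is active only inside the block; rcParams are restored on exit
with plt.style.context(str(style_path)):
    inside_lw = plt.rcParams["lines.linewidth"]
    fig, ax = plt.subplots(figsize=(5, 3))
    ax.plot([1, 2, 3], [2, 1, 3])
    fig.savefig("styled_context.png")
    plt.close(fig)

print(inside_lw)                                      # 2.5
print(plt.rcParams["lines.linewidth"] == default_lw)  # True
```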
Colormaps
A cmap maps numbers (or category indices) to colors. Pick from one of three families based on your data:
- Sequential (viridis, magma, plasma, cividis): for data with a meaningful order from low to high. viridis is the modern default and is perceptually uniform.
- Diverging (RdBu, coolwarm, PuOr): for data with a meaningful midpoint (e.g. positive vs negative deviations).
- Qualitative (tab10, Set2, Pastel1): for unordered categories.
Avoid jet and rainbow: they introduce false visual gradients that mislead the eye. viridis is colorblind-safe and prints well in grayscale.
import matplotlib.pyplot as plt
import numpy as np
data = np.random.default_rng(0).standard_normal((10, 10))
fig, axs = plt.subplots(1, 3, figsize=(12, 3))
for ax, cmap in zip(axs, ["viridis", "RdBu_r", "tab10"]):
    im = ax.imshow(data, cmap=cmap)
    ax.set_title(cmap)
    fig.colorbar(im, ax=ax, fraction=0.046)
fig.savefig("colormaps.png", dpi=120, bbox_inches="tight")
plt.close(fig)
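In the loop above each image picks its own color limits, so the neutral middle of RdBu_r does not necessarily land on zero. A minimal sketch of centering a diverging map by passing symmetric vmin/vmax (matplotlib.colors.CenteredNorm, available since matplotlib 3.4, is an alternative):

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

data = np.random.default_rng(0).standard_normal((10, 10))

# Same distance from zero to both ends, so 0 maps to the middle color
lim = np.abs(data).max()

fig, ax = plt.subplots(figsize=(4, 3))
im = ax.imshow(data, cmap="RdBu_r", vmin=-lim, vmax=lim)
fig.colorbar(im, ax=ax)
fig.savefig("diverging_centered.png", dpi=120, bbox_inches="tight")
plt.close(fig)

print(im.get_clim())  # symmetric limits (-lim, lim)
```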
Annotations, legends, and text
Annotations call attention to specific data points; legends explain which line/marker is which series. Use ax.annotate(text, xy=(x, y), xytext=(x2, y2), arrowprops=...) to draw an arrow from a callout to a point, and ax.text(x, y, s) for free-floating labels.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
y = np.sin(x)
fig, ax = plt.subplots(figsize=(7, 3))
ax.plot(x, y, label="sin(x)")
ax.scatter([np.pi / 2], [1], color="red", zorder=5, label="peak")
ax.annotate(
"peak at π/2",
xy=(np.pi / 2, 1),
xytext=(4, 0.8),
arrowprops={"arrowstyle": "->", "color": "gray"},
)
ax.legend(loc="lower left", frameon=False)
ax.set_xlabel("x")
ax.set_ylabel("sin(x)")
fig.savefig("annotation.png", dpi=120)
plt.close(fig)
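ax.text positions are in data coordinates by default, so a label drifts if the axis limits change. For a label pinned to a corner of the plot area, pass transform=ax.transAxes, where (0, 0) is the bottom-left and (1, 1) the top-right of the Axes. A short sketch with illustrative label text:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots(figsize=(7, 3))
ax.plot(x, np.sin(x))

# Data coordinates (the default): anchored to the point (pi, 0)
ax.text(np.pi, 0.0, "x = π", ha="center", va="bottom")

# Axes coordinates: always the top-right corner, whatever the data limits
corner = ax.text(0.98, 0.95, "n = 100", transform=ax.transAxes,
                 ha="right", va="top")

fig.savefig("text_coords.png", dpi=120)
plt.close(fig)
```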
Savefig — file formats and DPI
fig.savefig(path) infers the format from the file extension (.png, .pdf, .svg, .jpg, .webp). Vector formats (.pdf, .svg) scale to any size without pixelation, which makes them the better choice for print and slides. Raster formats (.png, .jpg) are rendered at the figure's DPI (100 by default) unless you pass dpi=; as a rough guide, 150 suits screens, 300 print, and 600 or more publication.
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot([1, 2, 3], [1, 4, 2])
ax.set_title("Save examples")
# Different output formats
fig.savefig("chart.png", dpi=150, bbox_inches="tight") # screen
fig.savefig("chart.pdf", bbox_inches="tight") # vector for print
fig.savefig("chart.svg") # vector for web
fig.savefig("chart_hi.png", dpi=600) # publication
plt.close(fig)
Pass bbox_inches="tight" to crop the whitespace around the figure; note that this changes the saved image's size. For PNG output, transparent=True makes the background see-through, which is useful when embedding in slides or web pages with non-white backgrounds.
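For raster output, the pixel size is figsize (in inches) multiplied by dpi, as long as bbox_inches="tight" is not cropping the canvas. A small check that saves to an in-memory buffer instead of a file: a 4 by 3 inch figure at dpi=150 should come out 600 by 450 pixels.

```python
import io
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(4, 3))
ax.plot([1, 2, 3], [1, 4, 2])

buf = io.BytesIO()
fig.savefig(buf, format="png", dpi=150)  # no bbox_inches: the full canvas is kept
plt.close(fig)

buf.seek(0)
img = plt.imread(buf, format="png")      # array of shape (rows, cols, channels)
height, width = img.shape[:2]
print(width, height)  # 600 450 (4 in * 150 dpi, 3 in * 150 dpi)
```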
Animations
matplotlib.animation.FuncAnimation repeatedly calls an update function and grabs each frame. The result is saved as MP4 (requires ffmpeg), GIF (requires pillow), or HTML5 video for notebooks.
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
fig, ax = plt.subplots(figsize=(6, 3))
x = np.linspace(0, 2 * np.pi, 200)
line, = ax.plot(x, np.sin(x))
ax.set_ylim(-1.2, 1.2)
def update(frame):
    line.set_ydata(np.sin(x + frame / 10))
    return (line,)
anim = animation.FuncAnimation(fig, update, frames=60, interval=50, blit=True)
anim.save("sine.gif", writer="pillow", fps=20)
plt.close(fig)
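blit=True speeds up on-screen playback because only the artists that changed are redrawn between frames; anim.save() draws every frame in full regardless. For blitting to work, the update function must return an iterable (e.g. a tuple) of the artists it modified.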
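For notebooks without ffmpeg (which anim.to_html5_video() also relies on), FuncAnimation.to_jshtml() returns a self-contained HTML/JavaScript player with the frames embedded as images. A minimal sketch reusing the sine animation with fewer frames; the output file name is illustrative. In Jupyter you would typically display the string with IPython.display.HTML instead of writing it to disk.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np

fig, ax = plt.subplots(figsize=(6, 3))
x = np.linspace(0, 2 * np.pi, 200)
line, = ax.plot(x, np.sin(x))
ax.set_ylim(-1.2, 1.2)

def update(frame):
    line.set_ydata(np.sin(x + frame / 10))
    return (line,)

anim = animation.FuncAnimation(fig, update, frames=20, interval=50, blit=True)

# Render all frames into one HTML string with a JavaScript player
html = anim.to_jshtml()
plt.close(fig)

with open("sine_player.html", "w") as f:
    f.write(html)
```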
3-D plots
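Passing projection="3d" to fig.add_subplot (or subplot_kw={"projection": "3d"} to plt.subplots) creates 3-D axes; the explicit from mpl_toolkits.mplot3d import Axes3D is only required on matplotlib versions older than 3.2. plot_surface, scatter, and quiver are the common methods. For interactive 3-D in a browser, prefer plotly: matplotlib's 3-D output is static when saved, can only be rotated inside a GUI window, and gets slow on large data.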
import matplotlib.pyplot as plt
import numpy as np
fig = plt.figure(figsize=(6, 4))
ax = fig.add_subplot(projection="3d")
X = np.linspace(-3, 3, 60)
Y = np.linspace(-3, 3, 60)
X, Y = np.meshgrid(X, Y)
Z = np.sin(np.sqrt(X**2 + Y**2))
ax.plot_surface(X, Y, Z, cmap="viridis", linewidth=0)
ax.set_title("3-D surface")
fig.savefig("surface.png", dpi=120)
plt.close(fig)
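The subplot_kw route looks like this for a 3-D scatter. A minimal sketch with random illustrative points; view_init sets the camera elevation and azimuth in degrees.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
pts = rng.standard_normal((200, 3))  # illustrative (x, y, z) points

# subplot_kw forwards projection="3d" to the Axes that plt.subplots creates
fig, ax = plt.subplots(figsize=(6, 4), subplot_kw={"projection": "3d"})
ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], c=pts[:, 2], cmap="viridis", s=10)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
ax.view_init(elev=25, azim=45)
fig.savefig("scatter3d.png", dpi=120)
plt.close(fig)

print(ax.name)  # 3d
```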
Backends — interactive vs headless
A backend is the rendering engine that turns matplotlib primitives into pixels. The default depends on your environment; check with plt.get_backend(). For headless CI / Docker, set matplotlib.use("Agg") before importing pyplot — this disables any GUI dependency.
import matplotlib
matplotlib.use("Agg")  # set before importing pyplot (the safest order)
import matplotlib.pyplot as plt
print(plt.get_backend()) # "agg"
fig, ax = plt.subplots()
ax.plot([1, 2, 3], [1, 4, 2])
fig.savefig("ci_safe.png")
plt.close(fig)
Never call plt.show() in a CI script or Docker container. Under Agg it only warns that the figure cannot be shown; with a GUI backend it blocks on a window that may not exist and can crash with RuntimeError: main thread is not in main loop on macOS.
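If you cannot edit the script, the MPLBACKEND environment variable selects the backend instead; matplotlib reads it when it is first imported. Usually you would export MPLBACKEND=Agg in the CI job itself; the sketch below sets it from Python only to keep the example self-contained.

```python
import os

# Must be set before the first `import matplotlib` in this process
os.environ["MPLBACKEND"] = "Agg"

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([1, 2, 3], [1, 4, 2])
fig.savefig("env_backend.png")
plt.close(fig)
print(plt.get_backend())  # expected to report the Agg backend
```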
Comparison with seaborn, plotly, and altair
matplotlib is the substrate for the entire Python plotting ecosystem. Pick the right wrapper for the task:
| Library | Built on | Best for |
|---|---|---|
| matplotlib | itself | Fine-grained control, publication figures, animations, anything custom |
| seaborn | matplotlib | Statistical plots (heatmaps, pairplots, violin plots) with shorter code |
| pandas.plot | matplotlib | Quick exploratory charts directly from a DataFrame |
| plotly | own JS lib | Interactive browser charts, 3-D, dashboards |
| altair | Vega-Lite | Declarative grammar-of-graphics for browser charts |
| bokeh | own JS lib | Interactive plots with linked brushing, large datasets |
# seaborn — built on matplotlib, great defaults
import seaborn as sns
import matplotlib.pyplot as plt
tips = sns.load_dataset("tips")  # downloads the sample dataset on first use
sns.boxplot(data=tips, x="day", y="total_bill", hue="sex")
plt.savefig("seaborn_box.png", dpi=120, bbox_inches="tight")
plt.close()
# plotly — interactive
import plotly.express as px
fig = px.scatter(tips, x="total_bill", y="tip", color="sex", trendline="ols")  # trendline="ols" needs statsmodels
fig.write_html("scatter.html")
Common pitfalls (reference)
Memory leaks from un-closed figures: long-running scripts that create many figures should call plt.close(fig) (or plt.close("all")) after each one. pyplot keeps a reference to every figure it creates until you close it, and warns once more than 20 are open (rcParams["figure.max_open_warning"]).
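For long-running processes such as web servers, another option is to skip pyplot entirely, as matplotlib's own web-application example does: a matplotlib.figure.Figure created directly is never registered with pyplot, so it is freed like any other Python object once unreferenced. A minimal sketch, assuming matplotlib 3.1+ (where Figure.savefig works without attaching a canvas); render_png is a hypothetical helper.

```python
import io
from matplotlib.figure import Figure  # no pyplot import: nothing is tracked globally

def render_png(values):
    """Hypothetical helper: draw a line chart and return it as PNG bytes."""
    fig = Figure(figsize=(4, 3))
    ax = fig.subplots()
    ax.plot(values)
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    return buf.getvalue()  # fig becomes unreachable and can be garbage-collected

png = render_png([1, 4, 2, 3])
print(png[:8] == b"\x89PNG\r\n\x1a\n")  # True: bytes start with the PNG signature
```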
plt.show() clears the figure: if you call show() and then savefig(), the saved file may be empty. Always call savefig first, then show (or skip show entirely in scripts).
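Text rendering inconsistencies: matplotlib keeps its own font cache and bundles DejaVu Sans, but if rcParams request a font family that is not installed (common on minimal Docker images), it falls back with a findfont warning, and characters the fallback font lacks (for example CJK text) render as empty boxes. Install the font package you need (e.g. fonts-dejavu) or include a TTF file and register it explicitly.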
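Registering a TTF explicitly can be sketched with font_manager.fontManager.addfont (matplotlib 3.2+). To keep the example runnable anywhere it uses the DejaVu Sans file bundled with matplotlib as a stand-in; in practice you would point it at the font file you ship with your project.

```python
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib import font_manager

# Stand-in TTF: the DejaVu Sans file bundled in matplotlib's data directory
ttf = Path(matplotlib.get_data_path()) / "fonts" / "ttf" / "DejaVuSans.ttf"

font_manager.fontManager.addfont(str(ttf))  # register the file with matplotlib
family = font_manager.FontProperties(fname=str(ttf)).get_name()
plt.rcParams["font.family"] = family        # make it the default family

fig, ax = plt.subplots(figsize=(4, 2))
ax.set_title(f"Rendered with {family}")
fig.savefig("custom_font.png", dpi=120)
plt.close(fig)
print(family)  # DejaVu Sans
```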
LaTeX math rendering is slow: usetex=True shells out to a real LaTeX installation and can take seconds per render. For simple math, the default mathtext (r"$\alpha + \beta$") is fast and ships with matplotlib.
Real-world recipes
Dual-axis chart (revenue and orders)
Two metrics on different scales share an x-axis via ax.twinx().
import matplotlib.pyplot as plt
import numpy as np
months = np.arange(1, 13)
revenue = 10 + np.cumsum(np.random.default_rng(0).normal(2, 1, 12))
orders = 100 + np.cumsum(np.random.default_rng(1).normal(15, 8, 12))
fig, ax1 = plt.subplots(figsize=(8, 4))
ax1.plot(months, revenue, "b-o", label="revenue")
ax1.set_xlabel("month")
ax1.set_ylabel("revenue (k$)", color="b")
ax1.tick_params(axis="y", labelcolor="b")
ax2 = ax1.twinx()
ax2.plot(months, orders, "r--s", label="orders")
ax2.set_ylabel("orders", color="r")
ax2.tick_params(axis="y", labelcolor="r")
fig.tight_layout()
fig.savefig("dual_axis.png", dpi=120)
plt.close(fig)
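The recipe above labels both lines but never draws a legend, and a plain ax1.legend() would only list ax1's line because each Axes knows only its own artists. A minimal sketch that gathers handles from both Axes and draws one combined legend on ax2, which is rendered on top of ax1:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

months = np.arange(1, 13)
revenue = 10 + np.cumsum(np.random.default_rng(0).normal(2, 1, 12))
orders = 100 + np.cumsum(np.random.default_rng(1).normal(15, 8, 12))

fig, ax1 = plt.subplots(figsize=(8, 4))
ax1.plot(months, revenue, "b-o", label="revenue")
ax2 = ax1.twinx()
ax2.plot(months, orders, "r--s", label="orders")

# Collect handles and labels from both Axes, then build a single legend
h1, l1 = ax1.get_legend_handles_labels()
h2, l2 = ax2.get_legend_handles_labels()
leg = ax2.legend(h1 + h2, l1 + l2, loc="upper left")

fig.tight_layout()
fig.savefig("dual_axis_legend.png", dpi=120)
plt.close(fig)
print([t.get_text() for t in leg.get_texts()])  # ['revenue', 'orders']
```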
Correlation heatmap from a DataFrame
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
df = pd.DataFrame(np.random.default_rng(0).standard_normal((100, 5)),
columns=list("ABCDE"))
corr = df.corr()
fig, ax = plt.subplots(figsize=(5, 4))
im = ax.imshow(corr, cmap="RdBu_r", vmin=-1, vmax=1)
ax.set_xticks(range(len(corr)), corr.columns)
ax.set_yticks(range(len(corr)), corr.columns)
for i in range(len(corr)):
    for j in range(len(corr)):
        ax.text(j, i, f"{corr.iloc[i, j]:.2f}", ha="center", va="center", fontsize=9)
fig.colorbar(im, ax=ax)
fig.savefig("corr_heatmap.png", dpi=120, bbox_inches="tight")
plt.close(fig)
Grouped bar chart
import matplotlib.pyplot as plt
import numpy as np
regions = ["East", "West", "North", "South"]
q1 = [120, 95, 80, 110]
q2 = [135, 100, 92, 115]
x = np.arange(len(regions))
w = 0.35
fig, ax = plt.subplots(figsize=(6, 4))
ax.bar(x - w/2, q1, w, label="Q1")
ax.bar(x + w/2, q2, w, label="Q2")
ax.set_xticks(x, regions)
ax.legend()
ax.set_ylabel("sales (k$)")
ax.set_title("Sales by quarter and region")
fig.savefig("grouped_bar.png", dpi=120, bbox_inches="tight")
plt.close(fig)
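To print each value on its bar, Axes.bar_label (matplotlib 3.4+) takes the container returned by ax.bar. A sketch on the same data; padding=2 (points between bar and label) is an illustrative choice.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

regions = ["East", "West", "North", "South"]
q1 = [120, 95, 80, 110]
q2 = [135, 100, 92, 115]
x = np.arange(len(regions))
w = 0.35

fig, ax = plt.subplots(figsize=(6, 4))
bars_q1 = ax.bar(x - w/2, q1, w, label="Q1")
bars_q2 = ax.bar(x + w/2, q2, w, label="Q2")

# Write each bar's height just above it
labels_q1 = ax.bar_label(bars_q1, padding=2)
ax.bar_label(bars_q2, padding=2)

ax.set_xticks(x, regions)
ax.legend()
ax.set_ylabel("sales (k$)")
fig.savefig("grouped_bar_labels.png", dpi=120, bbox_inches="tight")
plt.close(fig)
print([t.get_text() for t in labels_q1])  # ['120', '95', '80', '110']
```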
Show a figure in a Streamlit dashboard
Streamlit accepts a Figure object directly, so there is no need to save it to disk.
import streamlit as st
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot([1, 2, 3], [4, 1, 5])
st.pyplot(fig)
plt.close(fig)  # Streamlit reruns the script on every interaction, so release the figure
Plot a regression with confidence band
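The line is fitted with np.polyfit (scipy's curve_fit plays the same role for nonlinear models), and a ±2σ band is drawn with fill_between, where σ is the standard deviation of the residuals. Strictly speaking, this band describes the scatter of individual points (closer to a prediction band) rather than a confidence band for the fitted line itself.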
import matplotlib.pyplot as plt
import numpy as np
rng = np.random.default_rng(0)
x = np.linspace(0, 10, 50)
y = 2 * x + 1 + rng.normal(scale=2, size=x.size)
# Fit a line: y = ax + b
a, b = np.polyfit(x, y, 1)
y_hat = a * x + b
resid = y - y_hat
sigma = resid.std(ddof=2)
fig, ax = plt.subplots(figsize=(7, 4))
ax.scatter(x, y, alpha=0.6, label="data")
ax.plot(x, y_hat, color="crimson", label=f"y = {a:.2f}x + {b:.2f}")
ax.fill_between(x, y_hat - 2*sigma, y_hat + 2*sigma, color="crimson", alpha=0.15, label="±2σ")
ax.legend()
ax.set_xlabel("x")
ax.set_ylabel("y")
fig.tight_layout()
fig.savefig("regression_band.png", dpi=120)
plt.close(fig)
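The same band around a nonlinear fit can be sketched with scipy.optimize.curve_fit. This assumes scipy is installed; the exponential-decay model, its true parameters, and the starting guess p0 are invented for illustration.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit

def decay(x, amp, rate):
    """Illustrative model: amp * exp(-rate * x)."""
    return amp * np.exp(-rate * x)

rng = np.random.default_rng(0)
x = np.linspace(0, 4, 40)
y = decay(x, 3.0, 0.8) + rng.normal(scale=0.05, size=x.size)

# p0 is the starting guess for (amp, rate)
popt, pcov = curve_fit(decay, x, y, p0=(1.0, 0.5))
y_hat = decay(x, *popt)
sigma = (y - y_hat).std(ddof=len(popt))  # residual spread, 2 fitted parameters

fig, ax = plt.subplots(figsize=(7, 4))
ax.scatter(x, y, alpha=0.6, label="data")
ax.plot(x, y_hat, color="crimson", label=f"{popt[0]:.2f}·exp(-{popt[1]:.2f}x)")
ax.fill_between(x, y_hat - 2*sigma, y_hat + 2*sigma, color="crimson", alpha=0.15, label="±2σ")
ax.legend()
fig.tight_layout()
fig.savefig("curve_fit_band.png", dpi=120)
plt.close(fig)
```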
Stacked area chart
Stacked areas highlight totals and composition simultaneously — useful for "how is each segment contributing over time".
import matplotlib.pyplot as plt
import numpy as np
months = np.arange(1, 13)
a = 20 + np.cumsum(np.random.default_rng(0).normal(2, 1, 12))
b = 15 + np.cumsum(np.random.default_rng(1).normal(1, 1, 12))
c = 10 + np.cumsum(np.random.default_rng(2).normal(1.5, 1, 12))
fig, ax = plt.subplots(figsize=(8, 4))
ax.stackplot(months, a, b, c, labels=["Product A", "Product B", "Product C"], alpha=0.85)
ax.legend(loc="upper left")
ax.set_xlabel("month")
ax.set_ylabel("revenue")
ax.set_title("Stacked revenue by product")
fig.savefig("stacked_area.png", dpi=120, bbox_inches="tight")
plt.close(fig)
Programmatic small multiples
Faceted charts ("small multiples") are easy to build with subplots plus a loop — Edward Tufte's preferred alternative to a busy single chart.
import matplotlib.pyplot as plt
import numpy as np
rng = np.random.default_rng(0)
regions = ["East", "West", "North", "South"]
fig, axs = plt.subplots(2, 2, figsize=(8, 5), sharex=True, sharey=True)
for ax, region in zip(axs.flat, regions):
    months = np.arange(12)
    sales = 100 + np.cumsum(rng.normal(5, 8, 12))
    ax.plot(months, sales, marker="o")
    ax.set_title(region)
    ax.grid(True, alpha=0.3)
fig.suptitle("Sales by region — small multiples")
fig.supxlabel("month")
fig.supylabel("sales")
fig.tight_layout()
fig.savefig("small_multiples.png", dpi=120)
plt.close(fig)
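When the number of facets does not fill the grid (say five regions in a 2 by 3 layout), the leftover Axes still draw empty frames. A minimal sketch that switches them off; the fifth region, "Central", and the random data are illustrative. With sharex=True, the panel above a hidden slot has its x tick labels suppressed, so they are turned back on explicitly.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
regions = ["East", "West", "North", "South", "Central"]  # "Central" is illustrative

fig, axs = plt.subplots(2, 3, figsize=(9, 5), sharex=True, sharey=True)
for ax, region in zip(axs.flat, regions):
    months = np.arange(12)
    sales = 100 + np.cumsum(rng.normal(5, 8, 12))
    ax.plot(months, sales, marker="o")
    ax.set_title(region)

# Hide every panel that zip() did not reach
for ax in axs.ravel()[len(regions):]:
    ax.set_axis_off()
axs[0, 2].tick_params(labelbottom=True)  # re-show x tick labels above the hidden slot

fig.suptitle("Sales by region, five facets")
fig.tight_layout()
fig.savefig("small_multiples_partial.png", dpi=120)
plt.close(fig)
print([ax.axison for ax in axs.ravel()])  # [True, True, True, True, True, False]
```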