Added --subplot* to plotmpl.py

Driven primarily by a want to compare measurements of different runtime
complexities (it's difficult to fit O(n) and O(log n) on the same plot),
this adds the ability to nest subplots in the same .svg which try to align
as much as possible. This turned out to be surprisingly complicated.

As a part of this, adopted matplotlib's relatively recent
constrained_layout, which behaves much more consistently.

Also dropped --legend-left, no one should really be using that.
This commit is contained in:
Christopher Haster
2022-12-13 15:26:41 -06:00
parent 2d2dd8b2eb
commit cfd4e6029a

View File

@ -14,9 +14,11 @@ import collections as co
import csv
import io
import itertools as it
import logging
import math as m
import numpy as np
import os
import shlex
import shutil
import time
@ -54,7 +56,7 @@ FORMATS = ['-']
FORMATS_POINTS = ['.']
FORMATS_POINTS_AND_LINES = ['.-']
WIDTH = 735
WIDTH = 750
HEIGHT = 350
FONT_SIZE = 11
@ -123,6 +125,10 @@ def si2(x):
s = s.rstrip('.')
return '%s%s%s' % ('-' if x < 0 else '', s, SI2_PREFIXES[p])
# parse escape strings
def escape(s):
return codecs.escape_decode(s.encode('utf8'))[0].decode('utf8')
# we want to use MaxNLocator, but since MaxNLocator forces multiples of 10
# to be an option, we can't really...
class AutoMultipleLocator(mpl.ticker.MultipleLocator):
@ -292,6 +298,266 @@ def datasets(results, by=None, x=None, y=None, define=[]):
return datasets
# some classes for organizing subplots into a grid
class Subplot:
def __init__(self, **args):
self.x = 0
self.y = 0
self.xspan = 1
self.yspan = 1
self.args = args
class Grid:
def __init__(self, subplot, width=1.0, height=1.0):
self.xweights = [width]
self.yweights = [height]
self.map = {(0,0): subplot}
self.subplots = [subplot]
def __repr__(self):
return 'Grid(%r, %r)' % (self.xweights, self.yweights)
@property
def width(self):
return len(self.xweights)
@property
def height(self):
return len(self.yweights)
def __iter__(self):
return iter(self.subplots)
def __getitem__(self, i):
x, y = i
if x < 0:
x += len(self.xweights)
if y < 0:
y += len(self.yweights)
return self.map[(x,y)]
def merge(self, other, dir):
if dir in ['above', 'below']:
# first scale the two grids so they line up
self_xweights = self.xweights
other_xweights = other.xweights
self_w = sum(self_xweights)
other_w = sum(other_xweights)
ratio = self_w / other_w
other_xweights = [s*ratio for s in other_xweights]
# now interleave xweights as needed
new_xweights = []
self_map = {}
other_map = {}
self_i = 0
other_i = 0
self_xweight = (self_xweights[self_i]
if self_i < len(self_xweights) else m.inf)
other_xweight = (other_xweights[other_i]
if other_i < len(other_xweights) else m.inf)
while self_i < len(self_xweights) and other_i < len(other_xweights):
if other_xweight - self_xweight > 0.0000001:
new_xweights.append(self_xweight)
other_xweight -= self_xweight
new_i = len(new_xweights)-1
for j in range(len(self.yweights)):
self_map[(new_i, j)] = self.map[(self_i, j)]
for j in range(len(other.yweights)):
other_map[(new_i, j)] = other.map[(other_i, j)]
for s in other.subplots:
if s.x+s.xspan-1 == new_i:
s.xspan += 1
elif s.x > new_i:
s.x += 1
self_i += 1
self_xweight = (self_xweights[self_i]
if self_i < len(self_xweights) else m.inf)
elif self_xweight - other_xweight > 0.0000001:
new_xweights.append(other_xweight)
self_xweight -= other_xweight
new_i = len(new_xweights)-1
for j in range(len(other.yweights)):
other_map[(new_i, j)] = other.map[(other_i, j)]
for j in range(len(self.yweights)):
self_map[(new_i, j)] = self.map[(self_i, j)]
for s in self.subplots:
if s.x+s.xspan-1 == new_i:
s.xspan += 1
elif s.x > new_i:
s.x += 1
other_i += 1
other_xweight = (other_xweights[other_i]
if other_i < len(other_xweights) else m.inf)
else:
new_xweights.append(self_xweight)
new_i = len(new_xweights)-1
for j in range(len(self.yweights)):
self_map[(new_i, j)] = self.map[(self_i, j)]
for j in range(len(other.yweights)):
other_map[(new_i, j)] = other.map[(other_i, j)]
self_i += 1
self_xweight = (self_xweights[self_i]
if self_i < len(self_xweights) else m.inf)
other_i += 1
other_xweight = (other_xweights[other_i]
if other_i < len(other_xweights) else m.inf)
# squish so ratios are preserved
self_h = sum(self.yweights)
other_h = sum(other.yweights)
ratio = (self_h-other_h) / self_h
self_yweights = [s*ratio for s in self.yweights]
# finally concatenate the two grids
if dir == 'above':
for s in other.subplots:
s.y += len(self_yweights)
self.subplots.extend(other.subplots)
self.xweights = new_xweights
self.yweights = self_yweights + other.yweights
self.map = self_map | {(x, y+len(self_yweights)): s
for (x, y), s in other_map.items()}
else:
for s in self.subplots:
s.y += len(other.yweights)
self.subplots.extend(other.subplots)
self.xweights = new_xweights
self.yweights = other.yweights + self_yweights
self.map = other_map | {(x, y+len(other.yweights)): s
for (x, y), s in self_map.items()}
if dir in ['right', 'left']:
# first scale the two grids so they line up
self_yweights = self.yweights
other_yweights = other.yweights
self_h = sum(self_yweights)
other_h = sum(other_yweights)
ratio = self_h / other_h
other_yweights = [s*ratio for s in other_yweights]
# now interleave yweights as needed
new_yweights = []
self_map = {}
other_map = {}
self_i = 0
other_i = 0
self_yweight = (self_yweights[self_i]
if self_i < len(self_yweights) else m.inf)
other_yweight = (other_yweights[other_i]
if other_i < len(other_yweights) else m.inf)
while self_i < len(self_yweights) and other_i < len(other_yweights):
if other_yweight - self_yweight > 0.0000001:
new_yweights.append(self_yweight)
other_yweight -= self_yweight
new_i = len(new_yweights)-1
for j in range(len(self.xweights)):
self_map[(j, new_i)] = self.map[(j, self_i)]
for j in range(len(other.xweights)):
other_map[(j, new_i)] = other.map[(j, other_i)]
for s in other.subplots:
if s.y+s.yspan-1 == new_i:
s.yspan += 1
elif s.y > new_i:
s.y += 1
self_i += 1
self_yweight = (self_yweights[self_i]
if self_i < len(self_yweights) else m.inf)
elif self_yweight - other_yweight > 0.0000001:
new_yweights.append(other_yweight)
self_yweight -= other_yweight
new_i = len(new_yweights)-1
for j in range(len(other.xweights)):
other_map[(j, new_i)] = other.map[(j, other_i)]
for j in range(len(self.xweights)):
self_map[(j, new_i)] = self.map[(j, self_i)]
for s in self.subplots:
if s.y+s.yspan-1 == new_i:
s.yspan += 1
elif s.y > new_i:
s.y += 1
other_i += 1
other_yweight = (other_yweights[other_i]
if other_i < len(other_yweights) else m.inf)
else:
new_yweights.append(self_yweight)
new_i = len(new_yweights)-1
for j in range(len(self.xweights)):
self_map[(j, new_i)] = self.map[(j, self_i)]
for j in range(len(other.xweights)):
other_map[(j, new_i)] = other.map[(j, other_i)]
self_i += 1
self_yweight = (self_yweights[self_i]
if self_i < len(self_yweights) else m.inf)
other_i += 1
other_yweight = (other_yweights[other_i]
if other_i < len(other_yweights) else m.inf)
# squish so ratios are preserved
self_w = sum(self.xweights)
other_w = sum(other.xweights)
ratio = (self_w-other_w) / self_w
self_xweights = [s*ratio for s in self.xweights]
# finally concatenate the two grids
if dir == 'right':
for s in other.subplots:
s.x += len(self_xweights)
self.subplots.extend(other.subplots)
self.xweights = self_xweights + other.xweights
self.yweights = new_yweights
self.map = self_map | {(x+len(self_xweights), y): s
for (x, y), s in other_map.items()}
else:
for s in self.subplots:
s.x += len(other.xweights)
self.subplots.extend(other.subplots)
self.xweights = other.xweights + self_xweights
self.yweights = new_yweights
self.map = other_map | {(x+len(other.xweights), y): s
for (x, y), s in self_map.items()}
def scale(self, width, height):
self.xweights = [s*width for s in self.xweights]
self.yweights = [s*height for s in self.yweights]
@classmethod
def fromargs(cls, width=1.0, height=1.0, *,
subplots=[],
**args):
grid = cls(Subplot(**args))
for dir, subargs in subplots:
subgrid = cls.fromargs(
width=subargs.pop('width',
0.5 if dir in ['right', 'left'] else width),
height=subargs.pop('height',
0.5 if dir in ['above', 'below'] else height),
**subargs)
grid.merge(subgrid, dir)
grid.scale(width, height)
return grid
def main(csv_paths, output, *,
svg=False,
png=False,
@ -321,7 +587,9 @@ def main(csv_paths, output, *,
xticklabels=None,
yticklabels=None,
title=None,
legend=None,
legend_right=False,
legend_above=False,
legend_below=False,
dark=False,
ggplot=False,
xkcd=False,
@ -330,7 +598,10 @@ def main(csv_paths, output, *,
font_size=FONT_SIZE,
font_color=None,
foreground=None,
background=None):
background=None,
subplot={},
subplots=[],
**args):
# guess the output format
if not png and not svg:
if output.endswith('.png'):
@ -338,23 +609,6 @@ def main(csv_paths, output, *,
else:
svg = True
# allow shortened ranges
if len(xlim) == 1:
xlim = (0, xlim[0])
if len(ylim) == 1:
ylim = (0, ylim[0])
# separate out renames
renames = list(it.chain.from_iterable(
((k, v) for v in vs)
for k, vs in it.chain(by or [], x or [], y or [])))
if by is not None:
by = [k for k, _ in by]
if x is not None:
x = [k for k, _ in x]
if y is not None:
y = [k for k, _ in y]
# some shortcuts for color schemes
if github:
ggplot = True
@ -412,22 +666,10 @@ def main(csv_paths, output, *,
else:
background_ = '#ffffff'
# allow escape codes in labels/titles
if title is not None:
title = codecs.escape_decode(title.encode('utf8'))[0].decode('utf8')
if xlabel is not None:
xlabel = codecs.escape_decode(xlabel.encode('utf8'))[0].decode('utf8')
if ylabel is not None:
ylabel = codecs.escape_decode(ylabel.encode('utf8'))[0].decode('utf8')
# first collect results from CSV files
results = collect(csv_paths, renames)
# then extract the requested datasets
datasets_ = datasets(results, by, x, y, define)
# configure some matplotlib settings
if xkcd:
# the font search here prints a bunch of unhelpful warnings
logging.getLogger('matplotlib.font_manager').setLevel(logging.ERROR)
plt.xkcd()
# turn off the white outline, this breaks some things
plt.rc('path', effects=[])
@ -454,9 +696,11 @@ def main(csv_paths, output, *,
plt.rc('font', family=font)
plt.rc('font', size=font_size)
plt.rc('text', color=font_color_)
plt.rc('figure', titlesize='medium')
plt.rc('axes',
plt.rc('figure',
titlesize='medium',
labelsize='small')
plt.rc('axes',
titlesize='small',
labelsize='small',
labelcolor=font_color_)
if not ggplot:
@ -475,245 +719,294 @@ def main(csv_paths, output, *,
if not ggplot:
plt.rc('axes', facecolor='#00000000')
# I think the svg backend just ignores DPI, but seems to use something
# equivalent to 96, maybe this is the default for SVG rendering?
plt.rc('figure', dpi=96)
# separate out renames
renames = list(it.chain.from_iterable(
((k, v) for v in vs)
for k, vs in it.chain(by or [], x or [], y or [])))
if by is not None:
by = [k for k, _ in by]
if x is not None:
x = [k for k, _ in x]
if y is not None:
y = [k for k, _ in y]
# first collect results from CSV files
results = collect(csv_paths, renames)
# then extract the requested datasets
datasets_ = datasets(results, by, x, y, define)
# figure out formats/colors here so that subplot defines
# don't change them later, that'd be bad
dataformats_ = {
name: formats_[i % len(formats_)]
for i, name in enumerate(datasets_.keys())}
datacolors_ = {
name: colors_[i % len(colors_)]
for i, name in enumerate(datasets_.keys())}
# create a grid of subplots
grid = Grid.fromargs(
subplots=subplots + subplot.pop('subplots', []),
**subplot)
# create a matplotlib plot
fig = plt.figure(figsize=(
width/plt.rcParams['figure.dpi'],
height/plt.rcParams['figure.dpi']),
layout='constrained',
# we need a linewidth to keep xkcd mode happy
linewidth=8 if xkcd else 0)
ax = fig.subplots()
for i, (name, dataset) in enumerate(datasets_.items()):
dats = sorted((x,y) for x,y in dataset.items())
ax.plot([x for x,_ in dats], [y for _,y in dats],
formats_[i % len(formats_)],
color=colors_[i % len(colors_)],
label=','.join(k for k in name if k))
gs = fig.add_gridspec(
grid.height
+ (1 if legend_above else 0)
+ (1 if legend_below else 0),
grid.width
+ (1 if legend_right else 0),
height_ratios=([0.001] if legend_above else [])
+ [max(s, 0.01) for s in reversed(grid.yweights)]
+ ([0.001] if legend_below else []),
width_ratios=[max(s, 0.01) for s in grid.xweights]
+ ([0.001] if legend_right else []))
# axes scaling
if xlog:
ax.set_xscale('symlog')
ax.xaxis.set_minor_locator(mpl.ticker.NullLocator())
if ylog:
ax.set_yscale('symlog')
ax.yaxis.set_minor_locator(mpl.ticker.NullLocator())
# axes limits
ax.set_xlim(
xlim[0] if xlim[0] is not None
else min(it.chain([0], (k
for r in datasets_.values()
for k, v in r.items()
if v is not None))),
xlim[1] if xlim[1] is not None
else max(it.chain([0], (k
for r in datasets_.values()
for k, v in r.items()
if v is not None))))
ax.set_ylim(
ylim[0] if ylim[0] is not None
else min(it.chain([0], (v
for r in datasets_.values()
for _, v in r.items()
if v is not None))),
ylim[1] if ylim[1] is not None
else max(it.chain([0], (v
for r in datasets_.values()
for _, v in r.items()
if v is not None))))
# axes ticks
if x2:
ax.xaxis.set_major_formatter(lambda x, pos:
si2(x)+(xunits if xunits else ''))
if xticklabels is not None:
ax.xaxis.set_ticklabels(xticklabels)
if xticks is None:
ax.xaxis.set_major_locator(AutoMultipleLocator(2))
elif isinstance(xticks, list):
ax.xaxis.set_major_locator(mpl.ticker.FixedLocator(xticks))
elif xticks != 0:
ax.xaxis.set_major_locator(AutoMultipleLocator(2, xticks-1))
# first create axes so that plots can interact with each other
for s in grid:
s.ax = fig.add_subplot(gs[
grid.height-(s.y+s.yspan) + (1 if legend_above else 0)
: grid.height-s.y + (1 if legend_above else 0),
s.x
: s.x+s.xspan])
# now plot each subplot
for s in grid:
# allow subplot params to override global params
define_ = define + s.args.get('define', [])
xlim_ = s.args.get('xlim', xlim)
ylim_ = s.args.get('ylim', ylim)
xlog_ = s.args.get('xlog', False) or xlog
ylog_ = s.args.get('ylog', False) or ylog
x2_ = s.args.get('x2', False) or x2
y2_ = s.args.get('y2', False) or y2
xticks_ = s.args.get('xticks', xticks)
yticks_ = s.args.get('yticks', yticks)
xunits_ = s.args.get('xunits', xunits)
yunits_ = s.args.get('yunits', yunits)
xticklabels_ = s.args.get('xticklabels', xticklabels)
yticklabels_ = s.args.get('yticklabels', yticklabels)
# label/titles are handled a bit differently in subplots
subtitle = s.args.get('title')
xsublabel = s.args.get('xlabel')
ysublabel = s.args.get('ylabel')
# allow shortened ranges
if len(xlim_) == 1:
xlim_ = (0, xlim_[0])
if len(ylim_) == 1:
ylim_ = (0, ylim_[0])
# data can be constrained by subplot-specific defines,
# so re-extract for each plot
subdatasets = datasets(results, by, x, y, define_)
# plot!
ax = s.ax
for name, dataset in subdatasets.items():
dats = sorted((x,y) for x,y in dataset.items())
ax.plot([x for x,_ in dats], [y for _,y in dats],
dataformats_[name],
color=datacolors_[name],
label=','.join(k for k in name if k))
# axes scaling
if xlog_:
ax.set_xscale('symlog')
ax.xaxis.set_minor_locator(mpl.ticker.NullLocator())
if ylog_:
ax.set_yscale('symlog')
ax.yaxis.set_minor_locator(mpl.ticker.NullLocator())
# axes limits
ax.set_xlim(
xlim_[0] if xlim_[0] is not None
else min(it.chain([0], (k
for r in subdatasets.values()
for k, v in r.items()
if v is not None))),
xlim_[1] if xlim_[1] is not None
else max(it.chain([0], (k
for r in subdatasets.values()
for k, v in r.items()
if v is not None))))
ax.set_ylim(
ylim_[0] if ylim_[0] is not None
else min(it.chain([0], (v
for r in subdatasets.values()
for _, v in r.items()
if v is not None))),
ylim_[1] if ylim_[1] is not None
else max(it.chain([0], (v
for r in subdatasets.values()
for _, v in r.items()
if v is not None))))
# axes ticks
if x2_:
ax.xaxis.set_major_formatter(lambda x, pos:
si2(x)+(xunits_ if xunits_ else ''))
if xticklabels_ is not None:
ax.xaxis.set_ticklabels(xticklabels_)
if xticks_ is None:
ax.xaxis.set_major_locator(AutoMultipleLocator(2))
elif isinstance(xticks_, list):
ax.xaxis.set_major_locator(mpl.ticker.FixedLocator(xticks_))
elif xticks_ != 0:
ax.xaxis.set_major_locator(AutoMultipleLocator(2, xticks_-1))
else:
ax.xaxis.set_major_locator(mpl.ticker.NullLocator())
else:
ax.xaxis.set_major_locator(mpl.ticker.NullLocator())
else:
ax.xaxis.set_major_formatter(lambda x, pos:
si(x)+(xunits if xunits else ''))
if xticklabels is not None:
ax.xaxis.set_ticklabels(xticklabels)
if xticks is None:
ax.xaxis.set_major_locator(mpl.ticker.AutoLocator())
elif isinstance(xticks, list):
ax.xaxis.set_major_locator(mpl.ticker.FixedLocator(xticks))
elif xticks != 0:
ax.xaxis.set_major_locator(mpl.ticker.MaxNLocator(xticks-1))
ax.xaxis.set_major_formatter(lambda x, pos:
si(x)+(xunits_ if xunits_ else ''))
if xticklabels_ is not None:
ax.xaxis.set_ticklabels(xticklabels_)
if xticks_ is None:
ax.xaxis.set_major_locator(mpl.ticker.AutoLocator())
elif isinstance(xticks_, list):
ax.xaxis.set_major_locator(mpl.ticker.FixedLocator(xticks_))
elif xticks_ != 0:
ax.xaxis.set_major_locator(mpl.ticker.MaxNLocator(xticks_-1))
else:
ax.xaxis.set_major_locator(mpl.ticker.NullLocator())
if y2_:
ax.yaxis.set_major_formatter(lambda x, pos:
si2(x)+(yunits_ if yunits_ else ''))
if yticklabels_ is not None:
ax.yaxis.set_ticklabels(yticklabels_)
if yticks_ is None:
ax.yaxis.set_major_locator(AutoMultipleLocator(2))
elif isinstance(yticks_, list):
ax.yaxis.set_major_locator(mpl.ticker.FixedLocator(yticks_))
elif yticks_ != 0:
ax.yaxis.set_major_locator(AutoMultipleLocator(2, yticks_-1))
else:
ax.yaxis.set_major_locator(mpl.ticker.NullLocator())
else:
ax.xaxis.set_major_locator(mpl.ticker.NullLocator())
if y2:
ax.yaxis.set_major_formatter(lambda x, pos:
si2(x)+(yunits if yunits else ''))
if yticklabels is not None:
ax.yaxis.set_ticklabels(yticklabels)
if yticks is None:
ax.yaxis.set_major_locator(AutoMultipleLocator(2))
elif isinstance(yticks, list):
ax.yaxis.set_major_locator(mpl.ticker.FixedLocator(yticks))
elif yticks != 0:
ax.yaxis.set_major_locator(AutoMultipleLocator(2, yticks-1))
else:
ax.yaxis.set_major_locator(mpl.ticker.NullLocator())
else:
ax.yaxis.set_major_formatter(lambda x, pos:
si(x)+(yunits if yunits else ''))
if yticklabels is not None:
ax.yaxis.set_ticklabels(yticklabels)
if yticks is None:
ax.yaxis.set_major_locator(mpl.ticker.AutoLocator())
elif isinstance(yticks, list):
ax.yaxis.set_major_locator(mpl.ticker.FixedLocator(yticks))
elif yticks != 0:
ax.yaxis.set_major_locator(mpl.ticker.MaxNLocator(yticks-1))
else:
ax.yaxis.set_major_locator(mpl.ticker.NullLocator())
# axes labels
if xlabel is not None:
ax.set_xlabel(xlabel)
if ylabel is not None:
ax.set_ylabel(ylabel)
if ggplot:
ax.grid(sketch_params=None)
ax.yaxis.set_major_formatter(lambda x, pos:
si(x)+(yunits_ if yunits_ else ''))
if yticklabels_ is not None:
ax.yaxis.set_ticklabels(yticklabels_)
if yticks_ is None:
ax.yaxis.set_major_locator(mpl.ticker.AutoLocator())
elif isinstance(yticks_, list):
ax.yaxis.set_major_locator(mpl.ticker.FixedLocator(yticks_))
elif yticks_ != 0:
ax.yaxis.set_major_locator(mpl.ticker.MaxNLocator(yticks_-1))
else:
ax.yaxis.set_major_locator(mpl.ticker.NullLocator())
if ggplot:
ax.grid(sketch_params=None)
if title is not None:
ax.set_title(title)
# axes subplot labels
if xsublabel is not None:
ax.set_xlabel(escape(xsublabel))
if ysublabel is not None:
ax.set_ylabel(escape(ysublabel))
if subtitle is not None:
ax.set_title(escape(subtitle))
# pre-render so we can derive some bboxes
fig.tight_layout()
# it's not clear how you're actually supposed to get the renderer if
# get_renderer isn't supported
try:
renderer = fig.canvas.get_renderer()
except AttributeError:
renderer = fig._cachedRenderer
# add a legend? a bit tricky with matplotlib
#
# the best solution I've found is a dedicated, invisible axes for the
# legend, hacky, but it works.
#
# note this was written before constrained_layout supported legend
# collisions, hopefully this is added in the future
labels = co.OrderedDict()
for s in grid:
for h, l in zip(*s.ax.get_legend_handles_labels()):
labels[l] = h
# add a legend? this actually ends up being _really_ complicated
if legend == 'right':
l_pad = fig.transFigure.inverted().transform((
mpl.font_manager.FontProperties('small')
.get_size_in_points()/2,
0))[0]
legend_ = ax.legend(
bbox_to_anchor=(1+l_pad, 1),
if legend_right:
ax = fig.add_subplot(gs[(1 if legend_above else 0):,-1])
ax.set_axis_off()
ax.legend(
labels.values(),
labels.keys(),
loc='upper left',
fancybox=False,
borderaxespad=0)
if ggplot:
legend_.get_frame().set_linewidth(0)
fig.tight_layout()
elif legend == 'left':
l_pad = fig.transFigure.inverted().transform((
mpl.font_manager.FontProperties('small')
.get_size_in_points()/2,
0))[0]
# place legend somewhere to get its bbox
legend_ = ax.legend(
bbox_to_anchor=(0, 1),
loc='upper right',
fancybox=False,
borderaxespad=0)
# first make space for legend without the legend in the figure
l_bbox = (legend_.get_tightbbox(renderer)
.transformed(fig.transFigure.inverted()))
legend_.remove()
fig.tight_layout(rect=(0, 0, 1-l_bbox.width-l_pad, 1))
# place legend after tight_layout computation
bbox = (ax.get_tightbbox(renderer)
.transformed(ax.transAxes.inverted()))
legend_ = ax.legend(
bbox_to_anchor=(bbox.x0-l_pad, 1),
loc='upper right',
fancybox=False,
borderaxespad=0)
if ggplot:
legend_.get_frame().set_linewidth(0)
elif legend == 'above':
l_pad = fig.transFigure.inverted().transform((
0,
mpl.font_manager.FontProperties('small')
.get_size_in_points()/2))[1]
if legend_above:
ax = fig.add_subplot(gs[0, :grid.width])
ax.set_axis_off()
# try different column counts until we fit in the axes
for ncol in reversed(range(1, len(datasets_)+1)):
for ncol in reversed(range(1, len(labels)+1)):
legend_ = ax.legend(
bbox_to_anchor=(0.5, 1+l_pad),
loc='lower center',
ncol=ncol,
fancybox=False,
borderaxespad=0)
if ggplot:
legend_.get_frame().set_linewidth(0)
l_bbox = (legend_.get_tightbbox(renderer)
.transformed(ax.transAxes.inverted()))
if l_bbox.x0 >= 0:
break
# fix the title
if title is not None:
t_bbox = (ax.title.get_tightbbox(renderer)
.transformed(ax.transAxes.inverted()))
ax.set_title(None)
fig.tight_layout(rect=(0, 0, 1, 1-t_bbox.height))
l_bbox = (legend_.get_tightbbox(renderer)
.transformed(ax.transAxes.inverted()))
ax.set_title(title, y=1+l_bbox.height+l_pad)
elif legend == 'below':
l_pad = fig.transFigure.inverted().transform((
0,
mpl.font_manager.FontProperties('small')
.get_size_in_points()/2))[1]
# try different column counts until we fit in the axes
for ncol in reversed(range(1, len(datasets_)+1)):
legend_ = ax.legend(
bbox_to_anchor=(0.5, 0),
labels.values(),
labels.keys(),
loc='upper center',
ncol=ncol,
fancybox=False,
borderaxespad=0)
l_bbox = (legend_.get_tightbbox(renderer)
.transformed(ax.transAxes.inverted()))
if l_bbox.x0 >= 0:
if (legend_.get_window_extent().width
<= ax.get_window_extent().width):
break
# first make space for legend without the legend in the figure
l_bbox = (legend_.get_tightbbox(renderer)
.transformed(fig.transFigure.inverted()))
legend_.remove()
fig.tight_layout(rect=(0, 0, 1, 1-l_bbox.height-l_pad))
if legend_below:
ax = fig.add_subplot(gs[-1, :grid.width])
ax.set_axis_off()
bbox = (ax.get_tightbbox(renderer)
.transformed(ax.transAxes.inverted()))
legend_ = ax.legend(
bbox_to_anchor=(0.5, bbox.y0-l_pad),
loc='upper center',
ncol=ncol,
fancybox=False,
borderaxespad=0)
if ggplot:
legend_.get_frame().set_linewidth(0)
# big hack to get xlabel above the legend! but hey this
# works really well actually
if xlabel:
ax.set_title(escape(xlabel),
size=plt.rcParams['axes.labelsize'],
weight=plt.rcParams['axes.labelweight'])
# compute another tight_layout for good measure, because this _does_
# fix some things... I don't really know why though
fig.tight_layout()
# try different column counts until we fit in the axes
for ncol in reversed(range(1, len(labels)+1)):
legend_ = ax.legend(
labels.values(),
labels.keys(),
loc='upper center',
ncol=ncol,
fancybox=False,
borderaxespad=0)
plt.savefig(output, format='png' if png else 'svg', bbox_inches='tight')
if (legend_.get_window_extent().width
<= ax.get_window_extent().width):
break
# axes labels, NOTE we reposition these below
if xlabel is not None and not legend_below:
fig.supxlabel(escape(xlabel))
if ylabel is not None:
fig.supylabel(escape(ylabel))
if title is not None:
fig.suptitle(escape(title))
# precompute constrained layout and find midpoints to adjust things
# that should be centered so they are actually centered
fig.canvas.draw()
xmid = (grid[0,0].ax.get_position().x0 + grid[-1,0].ax.get_position().x1)/2
ymid = (grid[0,0].ax.get_position().y0 + grid[0,-1].ax.get_position().y1)/2
if xlabel is not None and not legend_below:
fig.supxlabel(escape(xlabel), x=xmid)
if ylabel is not None:
fig.supylabel(escape(ylabel), y=ymid)
if title is not None:
fig.suptitle(escape(title), x=xmid)
# write the figure!
plt.savefig(output, format='png' if png else 'svg')
# some stats
if not quiet:
@ -733,7 +1026,7 @@ if __name__ == "__main__":
'csv_paths',
nargs='*',
help="Input *.csv files.")
parser.add_argument(
output_rule = parser.add_argument(
'-o', '--output',
required=True,
help="Output *.svg/*.png file.")
@ -867,11 +1160,17 @@ if __name__ == "__main__":
'-t', '--title',
help="Add a title.")
parser.add_argument(
'-l', '--legend',
nargs='?',
choices=['above', 'below', 'left', 'right'],
const='right',
help="Place a legend here.")
'-l', '--legend-right',
action='store_true',
help="Place a legend to the right.")
parser.add_argument(
'--legend-above',
action='store_true',
help="Place a legend above.")
parser.add_argument(
'--legend-below',
action='store_true',
help="Place a legend below.")
parser.add_argument(
'--dark',
action='store_true',
@ -904,6 +1203,56 @@ if __name__ == "__main__":
parser.add_argument(
'--background',
help="Background color to use.")
sys.exit(main(**{k: v
for k, v in vars(parser.parse_intermixed_args()).items()
if v is not None}))
class AppendSubplot(argparse.Action):
@staticmethod
def parse(value):
import copy
subparser = copy.deepcopy(parser)
next(a for a in subparser._actions
if '--output' in a.option_strings).required = False
next(a for a in subparser._actions
if '--width' in a.option_strings).type = float
next(a for a in subparser._actions
if '--height' in a.option_strings).type = float
return subparser.parse_intermixed_args(shlex.split(value or ""))
def __call__(self, parser, namespace, value, option):
if not hasattr(namespace, 'subplots'):
namespace.subplots = []
namespace.subplots.append((
option.split('-')[-1],
self.__class__.parse(value)))
parser.add_argument(
'--subplot-above',
action=AppendSubplot,
help="Add subplot above with the same dataset. Takes an arg string to "
"control the subplot which supports most (but not all) of the "
"parameters listed here. The relative dimensions of the subplot "
"can be controlled with -W/-H which now take a percentage.")
parser.add_argument(
'--subplot-below',
action=AppendSubplot,
help="Add subplot below with the same dataset.")
parser.add_argument(
'--subplot-left',
action=AppendSubplot,
help="Add subplot left with the same dataset.")
parser.add_argument(
'--subplot-right',
action=AppendSubplot,
help="Add subplot right with the same dataset.")
parser.add_argument(
'--subplot',
type=AppendSubplot.parse,
help="Add subplot-specific arguments to the main plot.")
def dictify(ns):
if hasattr(ns, 'subplots'):
ns.subplots = [(dir, dictify(subplot_ns))
for dir, subplot_ns in ns.subplots]
if ns.subplot is not None:
ns.subplot = dictify(ns.subplot)
return {k: v
for k, v in vars(ns).items()
if v is not None}
sys.exit(main(**dictify(parser.parse_intermixed_args())))