scripts: Adopted Attr rework in plot.py/plotmpl.py

Unifying these complicated attr-assigning flags across all the scripts
is the main benefit of the new internal Attr system.

The only tricky bit is we need to somehow keep track of all input fields
in case % modifiers reference fields, when we could previously discard
non-data fields.

Tricky but doable.

Updated flags:

- -L/--label -> -L/--add-label
- --colors -> -C/--add-color
- --formats -> -F/--add-format
- --chars -> -*/--add-char/--chars
- --line-chars -> -_/--add-line-char/--line-chars

I've also tweaked Attr to accept glob matches when figuring out group
assignments. This is useful for matching slightly different, but
similarly named results in our benchmark scripts.

There's probably a clever way to do this by injecting new by fields with
csv.py, but just adding globbing is simpler and makes attr assignment
even more flexible.
This commit is contained in:
Christopher Haster
2025-02-20 05:08:55 -06:00
parent 8b04e35ea5
commit 86f3bad2a4
6 changed files with 802 additions and 397 deletions

View File

@ -15,6 +15,7 @@ if __name__ == "__main__":
import collections as co
import csv
import fnmatch
import io
import itertools as it
import logging
@ -131,31 +132,6 @@ def si2(x):
s = s.rstrip('.')
return '%s%s%s' % ('-' if x < 0 else '', s, SI2_PREFIXES[p])
# parse %-escaped strings
def unescape(s):
pattern = re.compile(
'%[%=,abfnrtv0]'
'|' '%x..'
'|' '%u....'
'|' '%U........')
def unescape(m):
if m.group()[1] == '%': return '%'
elif m.group()[1] == '=': return '='
elif m.group()[1] == ',': return ','
elif m.group()[1] == 'a': return '\a'
elif m.group()[1] == 'b': return '\b'
elif m.group()[1] == 'f': return '\f'
elif m.group()[1] == 'n': return '\n'
elif m.group()[1] == 'r': return '\r'
elif m.group()[1] == 't': return '\t'
elif m.group()[1] == 'v': return '\v'
elif m.group()[1] == '0': return '\0'
elif m.group()[1] == 'x': return chr(int(m.group()[2:], 16))
elif m.group()[1] == 'u': return chr(int(m.group()[2:], 16))
elif m.group()[1] == 'U': return chr(int(m.group()[2:], 16))
else: assert False
return re.sub(pattern, unescape, s)
# we want to use MaxNLocator, but since MaxNLocator forces multiples of 10
# to be an option, we can't really...
class AutoMultipleLocator(mpl.ticker.MultipleLocator):
@ -217,6 +193,12 @@ def dat(x):
# else give up
raise ValueError("invalid dat %r" % x)
def try_dat(x):
try:
return dat(x)
except ValueError:
return None
def collect(csv_paths, defines=[]):
# collect results from CSV files
fields = []
@ -239,7 +221,7 @@ def collect(csv_paths, defines=[]):
return fields, results
def fold(results, by=None, x=None, y=None, defines=[], labels=None):
def fold(results, by=None, x=None, y=None, defines=[]):
# filter by matching defines
if defines:
results_ = []
@ -257,11 +239,13 @@ def fold(results, by=None, x=None, y=None, defines=[], labels=None):
# collect all datasets
datasets = co.OrderedDict()
dataattrs = co.OrderedDict()
for key in (keys if by else [()]):
for x_ in (x if x else [None]):
for y_ in y:
# organize by 'by', x, and y
dataset = []
dataattr = {}
i = 0
for r in results:
# filter by 'by'
@ -298,6 +282,10 @@ def fold(results, by=None, x=None, y=None, defines=[], labels=None):
# incorrect and misleading results
dataset.append((x__, y__))
# include all fields in dataattrs in case we use
# them for % modifiers
dataattr.update(r)
# hide x/y if there is only one field
key_ = key
if len(x or []) > 1:
@ -305,20 +293,144 @@ def fold(results, by=None, x=None, y=None, defines=[], labels=None):
if len(y or []) > 1 or not key_:
key_ += (y_,)
datasets[key_] = dataset
dataattrs[key_] = dataattr
# order by labels
if labels:
datasets_ = co.OrderedDict()
for _, key in labels:
if key in datasets:
datasets_[key] = datasets[key]
# include unlabeled data to help with debugging
for key, dataset in datasets.items():
if key not in datasets_:
datasets_[key] = datasets[key]
datasets = datasets_
return datasets, dataattrs
return datasets
# a representation of optionally key-mapped attrs
class Attr:
def __init__(self, attrs, *,
defaults=None):
# include defaults?
if (defaults is not None
and not any(
not isinstance(attr, tuple)
or attr[0] in {None, (), ('*',)}
for attr in (attrs or []))):
attrs = defaults + (attrs or [])
# normalize
self.attrs = []
self.keyed = co.OrderedDict()
for attr in (attrs or []):
if not isinstance(attr, tuple):
attr = ((), attr)
elif attr[0] in {None, (), ('*',)}:
attr = ((), attr[1])
self.attrs.append(attr)
if attr[0] not in self.keyed:
self.keyed[attr[0]] = []
self.keyed[attr[0]].append(attr[1])
def __repr__(self):
return 'Attr(%r)' % [
(','.join(attr[0]), attr[1])
for attr in self.attrs]
def __iter__(self):
return it.cycle(self.keyed[()])
def __bool__(self):
return bool(self.attrs)
def __getitem__(self, key):
if isinstance(key, tuple):
if len(key) > 0 and not isinstance(key[0], str):
i, key = key
else:
i, key = 0, key
else:
i, key = key, ()
# try to lookup by key
best = None
for ks, vs in self.keyed.items():
prefix = []
for j, k in enumerate(ks):
if j < len(key) and fnmatch.fnmatchcase(key[j], k):
prefix.append(k)
else:
prefix = None
break
if prefix is not None and (
best is None or len(prefix) >= len(best[0])):
best = (prefix, vs)
if best is not None:
# cycle based on index
return best[1][i % len(best[1])]
return None
def __contains__(self, key):
return self.__getitem__(key) is not None
# a key function for sorting by key order
def key(self, key):
# allow key to be a tuple to make sorting dicts easier
if (isinstance(key, tuple)
and len(key) >= 1
and isinstance(key[0], tuple)):
key = key[0]
best = None
for i, ks in enumerate(self.keyed.keys()):
prefix = []
for j, k in enumerate(ks):
if j < len(key) and (not k or key[j] == k):
prefix.append(k)
else:
prefix = None
break
if prefix is not None and (
best is None or len(prefix) >= len(best[0])):
best = (prefix, i)
if best is not None:
return best[1]
return len(self.keyed)
# parse %-escaped strings
def punescape(s, attrs=None):
if attrs is None:
attrs = {}
if isinstance(attrs, dict):
attrs_ = attrs
attrs = lambda k: attrs_[k]
pattern = re.compile(
'%[%n]'
'|' '%x..'
'|' '%u....'
'|' '%U........'
'|' '%\((?P<field>[^)]*)\)'
'(?P<format>[+\- #0-9\.]*[scdboxXfFeEgG])')
def unescape(m):
if m.group()[1] == '%': return '%'
elif m.group()[1] == 'n': return '\n'
elif m.group()[1] == 'x': return chr(int(m.group()[2:], 16))
elif m.group()[1] == 'u': return chr(int(m.group()[2:], 16))
elif m.group()[1] == 'U': return chr(int(m.group()[2:], 16))
elif m.group()[1] == '(':
try:
v = attrs(m.group('field'))
except KeyError:
return m.group()
if m.group('format')[-1] in 'dboxXfFeEgG':
if isinstance(v, str):
v = try_dat(v) or 0
else:
if not isinstance(v, str):
v = str(v)
# note we need Python's new format syntax for binary
f = '{:%s}' % m.group('format')
return f.format(v)
else: assert False
return re.sub(pattern, unescape, s)
# some classes for organizing subplots into a grid
@ -593,11 +705,11 @@ def main(csv_paths, output, *,
x=None,
y=None,
define=[],
label=None,
labels=[],
colors=[],
formats=[],
points=False,
points_and_lines=False,
colors=None,
formats=None,
width=WIDTH,
height=HEIGHT,
xlim=(None,None),
@ -637,21 +749,14 @@ def main(csv_paths, output, *,
svg = True
# what colors/alphas/formats to use?
if colors is not None:
colors_ = colors
elif dark:
colors_ = COLORS_DARK
else:
colors_ = COLORS
colors_ = Attr(colors, defaults=COLORS_DARK if dark else COLORS)
if formats is not None:
formats_ = [unescape(f) for f in formats]
elif points_and_lines:
formats_ = FORMATS_POINTS_AND_LINES
elif points:
formats_ = FORMATS_POINTS
else:
formats_ = FORMATS
formats_ = Attr(formats, defaults=(
FORMATS_POINTS_AND_LINES if points_and_lines
else FORMATS_POINTS if points
else FORMATS))
labels_ = Attr(labels)
if font_color is not None:
font_color_ = font_color
@ -750,9 +855,6 @@ def main(csv_paths, output, *,
subplots_get('define', **subplot, subplots=subplots)):
all_defines[k] |= vs
all_defines = sorted(all_defines.items())
all_labels = [(unescape(k), vs) for k, vs in (
(label or [])
+ subplots_get('label', **subplot, subplots=subplots))]
if not all_by and not all_y:
print("error: needs --by or -y to figure out fields",
@ -760,7 +862,7 @@ def main(csv_paths, output, *,
sys.exit(-1)
# first collect results from CSV files
fields_, results = collect(csv_paths, all_defines)
fields_, results = collect(csv_paths)
# if y not specified, guess it's anything not in by/defines/x
if not all_y:
@ -771,16 +873,27 @@ def main(csv_paths, output, *,
# then extract the requested datasets
#
# note we don't need to filter by defines again
datasets_ = fold(results, all_by, all_x, all_y, None, all_labels)
datasets_, dataattrs_ = fold(results, all_by, all_x, all_y)
# order by labels
datasets_ = co.OrderedDict(sorted(
datasets_.items(),
key=labels_.key))
# and merge dataattrs
mergedattrs_ = {k: v
for dataattr in dataattrs_.values()
for k, v in dataattr.items()}
# figure out formats/colors here so that subplot defines don't change
# them later, that'd be bad
dataformats_ = {
name: formats_[i % len(formats_)]
for i, name in enumerate(datasets_.keys())}
datacolors_ = {
name: colors_[i % len(colors_)]
for i, name in enumerate(datasets_.keys())}
dataformats_ = {name: formats_[i, name]
for i, name in enumerate(datasets_.keys())}
datacolors_ = {name: colors_[i, name]
for i, name in enumerate(datasets_.keys())}
datalabels_ = {name: punescape(labels_[i, name], dataattrs_[name])
for i, name in enumerate(datasets_.keys())
if (i, name) in labels_}
# create a grid of subplots
grid = Grid.fromargs(**subplot, subplots=subplots)
@ -846,13 +959,27 @@ def main(csv_paths, output, *,
# data can be constrained by subplot-specific defines,
# so re-extract for each plot
subdatasets = fold(results, all_by, all_x, all_y, define_, all_labels)
subdatasets, subdataattrs = fold(
results, all_by, all_x, all_y, define_)
# order by labels
subdatasets = co.OrderedDict(sorted(
subdatasets.items(),
key=labels_.key))
# filter by subplot x/y
subdatasets = co.OrderedDict([(name, dataset)
for name, dataset in subdatasets.items()
if len(all_x) <= 1 or name[-(1 if len(all_y) <= 1 else 2)] in x_
if len(all_y) <= 1 or name[-1] in y_])
subdataattrs = co.OrderedDict([(name, dataattr)
for name, dataattr in subdataattrs.items()
if len(all_x) <= 1 or name[-(1 if len(all_y) <= 1 else 2)] in x_
if len(all_y) <= 1 or name[-1] in y_])
# and merge dataattrs
submergedattrs = {k: v
for dataattr in subdataattrs.values()
for k, v in dataattr.items()}
# plot!
ax = s.ax
@ -898,7 +1025,8 @@ def main(csv_paths, output, *,
ax.xaxis.set_major_formatter(lambda x, pos:
si2(x)+(xunits_ if xunits_ else ''))
if xticklabels_ is not None:
ax.xaxis.set_ticklabels([unescape(l) for l in xticklabels_])
ax.xaxis.set_ticklabels([punescape(l, submergedattrs)
for l in xticklabels_])
if xticks_ is None:
ax.xaxis.set_major_locator(AutoMultipleLocator(2))
elif isinstance(xticks_, list):
@ -911,7 +1039,8 @@ def main(csv_paths, output, *,
ax.xaxis.set_major_formatter(lambda x, pos:
si(x)+(xunits_ if xunits_ else ''))
if xticklabels_ is not None:
ax.xaxis.set_ticklabels([unescape(l) for l in xticklabels_])
ax.xaxis.set_ticklabels([punescape(l, submergedattrs)
for l in xticklabels_])
if xticks_ is None:
ax.xaxis.set_major_locator(mpl.ticker.AutoLocator())
elif isinstance(xticks_, list):
@ -924,7 +1053,8 @@ def main(csv_paths, output, *,
ax.yaxis.set_major_formatter(lambda x, pos:
si2(x)+(yunits_ if yunits_ else ''))
if yticklabels_ is not None:
ax.yaxis.set_ticklabels([unescape(l) for l in yticklabels_])
ax.xaxis.set_ticklabels([punescape(l, submergedattrs)
for l in yticklabels_])
if yticks_ is None:
ax.yaxis.set_major_locator(AutoMultipleLocator(2))
elif isinstance(yticks_, list):
@ -937,7 +1067,8 @@ def main(csv_paths, output, *,
ax.yaxis.set_major_formatter(lambda x, pos:
si(x)+(yunits_ if yunits_ else ''))
if yticklabels_ is not None:
ax.yaxis.set_ticklabels([unescape(l) for l in yticklabels_])
ax.xaxis.set_ticklabels([punescape(l, submergedattrs)
for l in yticklabels_])
if yticks_ is None:
ax.yaxis.set_major_locator(mpl.ticker.AutoLocator())
elif isinstance(yticks_, list):
@ -951,11 +1082,11 @@ def main(csv_paths, output, *,
# axes subplot labels
if xsublabel is not None:
ax.set_xlabel(unescape(xsublabel))
ax.set_xlabel(punescape(xsublabel, submergedattrs))
if ysublabel is not None:
ax.set_ylabel(unescape(ysublabel))
ax.set_ylabel(punescape(ysublabel, submergedattrs))
if subtitle is not None:
ax.set_title(unescape(subtitle))
ax.set_title(punescape(subtitle, submergedattrs))
# add a legend? a bit tricky with matplotlib
#
@ -968,16 +1099,14 @@ def main(csv_paths, output, *,
for s in grid:
for h, l in zip(*s.ax.get_legend_handles_labels()):
legend[l] = h
if all_labels:
all_labels_ = {key: l for l, key in all_labels}
# sort in dataset order
legend_ = []
for name in datasets_.keys():
for i, name in enumerate(datasets_.keys()):
name_ = ','.join(name)
if name_ in legend:
if all_labels and name in all_labels_:
if all_labels_[name]:
legend_.append((all_labels_[name], legend[name_]))
if name in datalabels_:
if datalabels_[name]:
legend_.append((datalabels_[name], legend[name_]))
else:
legend_.append((name_, legend[name_]))
legend = legend_
@ -1026,7 +1155,7 @@ def main(csv_paths, output, *,
# big hack to get xlabel above the legend! but hey this
# works really well actually
if xlabel:
ax.set_title(unescape(xlabel),
ax.set_title(punescape(xlabel, mergedattrs_),
size=plt.rcParams['axes.labelsize'],
weight=plt.rcParams['axes.labelweight'])
@ -1056,11 +1185,11 @@ def main(csv_paths, output, *,
# axes labels, NOTE we reposition these below
if xlabel is not None and not legend_below:
fig.supxlabel(unescape(xlabel))
fig.supxlabel(punescape(xlabel, mergedattrs_))
if ylabel is not None:
fig.supylabel(unescape(ylabel))
fig.supylabel(punescape(ylabel, mergedattrs_))
if title is not None:
fig.suptitle(unescape(title))
fig.suptitle(punescape(title, mergedattrs_))
# precompute constrained layout and find midpoints to adjust things
# that should be centered so they are actually centered
@ -1069,11 +1198,11 @@ def main(csv_paths, output, *,
ymid = (grid[0,0].ax.get_position().y0 + grid[0,-1].ax.get_position().y1)/2
if xlabel is not None and not legend_below:
fig.supxlabel(unescape(xlabel), x=xmid)
fig.supxlabel(punescape(xlabel, mergedattrs_), x=xmid)
if ylabel is not None:
fig.supylabel(unescape(ylabel), y=ymid)
fig.supylabel(punescape(ylabel, mergedattrs_), y=ymid)
if title is not None:
fig.suptitle(unescape(title), x=xmid)
fig.suptitle(punescape(title, mergedattrs_), x=xmid)
# write the figure!
@ -1137,17 +1266,43 @@ if __name__ == "__main__":
help="Only include results where this field is this value. May "
"include comma-separated options.")
parser.add_argument(
'-L', '--label',
'-L', '--add-label',
dest='labels',
action='append',
type=lambda x: (
lambda k, vs: (
k.strip(),
tuple(v.strip() for v in vs.split(',')))
)(*re.split(r'(?<!%)=', x, 1)),
help="Use this label for a given group, where a group is roughly "
"the comma-separated values in the -b/--by, -x, and -y "
"fields. Also provides an ordering. Accepts %= and other "
"%-escaped codes.")
lambda ks, v: (
tuple(k.strip() for k in ks.split(',')),
v.strip())
)(*x.split('=', 1))
if '=' in x else x.strip(),
help="Add a label to use. Can be assigned to a specific group "
"where a group is the comma-separated 'by' fields. Accepts %% "
"modifiers. Also provides an ordering.")
parser.add_argument(
'-C', '--add-color',
dest='colors',
action='append',
type=lambda x: (
lambda ks, v: (
tuple(k.strip() for k in ks.split(',')),
v.strip())
)(*x.split('=', 1))
if '=' in x else x.strip(),
help="Add a color to use. Can be assigned to a specific group "
"where a group is the comma-separated 'by' fields.")
parser.add_argument(
'-F', '--add-format',
dest='formats',
action='append',
type=lambda x: (
lambda ks, v: (
tuple(k.strip() for k in ks.split(',')),
v.strip())
)(*x.split('=', 1))
if '=' in x else x.strip(),
help="Add a matplotlib format to use. Can be assigned to a "
"specific group where a group is the comma-separated 'by' "
"fields.")
parser.add_argument(
'-.', '--points',
action='store_true',
@ -1156,15 +1311,6 @@ if __name__ == "__main__":
'-!', '--points-and-lines',
action='store_true',
help="Draw data points and lines.")
parser.add_argument(
'--colors',
type=lambda x: [x.strip() for x in x.split(',')],
help="Comma-separated hex colors to use.")
parser.add_argument(
'--formats',
type=lambda x: [x.strip() for x in re.split(r'(?<!%),', x)],
help="Comma-separated matplotlib formats to use. Accepts %, and "
"other %-escaped codes.")
parser.add_argument(
'-W', '--width',
type=lambda x: int(x, 0),
@ -1221,25 +1367,25 @@ if __name__ == "__main__":
help="Units for the y-axis.")
parser.add_argument(
'--xlabel',
help="Add a label to the x-axis. Accepts %-escaped codes.")
help="Add a label to the x-axis. Accepts %% modifiers.")
parser.add_argument(
'--ylabel',
help="Add a label to the y-axis. Accepts %-escaped codes.")
help="Add a label to the y-axis. Accepts %% modifiers.")
parser.add_argument(
'--xticklabels',
type=lambda x: [x.strip() for x in re.split(r'(?<!%),', x)]
if x.strip() else [],
help="Comma separated xticklabels. Accepts %, and other "
"%-escaped codes.")
help="Comma separated xticklabels. Accepts %%, and other "
"%%-escaped codes.")
parser.add_argument(
'--yticklabels',
type=lambda x: [x.strip() for x in re.split(r'(?<!%),', x)]
if x.strip() else [],
help="Comma separated yticklabels. Accepts %, and other "
"%-escaped codes.")
help="Comma separated yticklabels. Accepts %%, and other "
"%%-escaped codes.")
parser.add_argument(
'--title',
help="Add a title. Accepts %-escaped codes.")
help="Add a title. Accepts %% modifiers.")
parser.add_argument(
'-l', '--legend', '--legend-right',
dest='legend_right',