Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:24:07

0001 """
0002 Utilities for plotting ROOT histograms in matplotlib.
0003 """
0004 
0005 from builtins import range
0006 __license__ = '''\
0007 Copyright (c) 2009-2010 Jeff Klukas <klukas@wisc.edu>
0008 
0009 Permission is hereby granted, free of charge, to any person obtaining a copy
0010 of this software and associated documentation files (the "Software"), to deal
0011 in the Software without restriction, including without limitation the rights
0012 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
0013 copies of the Software, and to permit persons to whom the Software is
0014 furnished to do so, subject to the following conditions:
0015 
0016 The above copyright notice and this permission notice shall be included in
0017 all copies or substantial portions of the Software.
0018 
0019 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
0020 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
0021 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
0022 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
0023 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
0024 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
0025 THE SOFTWARE.
0026 '''
0027 
0028 ################ Import python libraries
0029 
0030 import math
0031 import ROOT
0032 import re
0033 import copy
0034 import array
0035 from rootplot import utilities
0036 import matplotlib as mpl
0037 import matplotlib.pyplot as plt
0038 import matplotlib.transforms as transforms
0039 import numpy as np
0040 
0041 ################ Define constants
0042 
# Source text of the "blank string" pattern.  It is meant for re.match, which
# anchors at the start of the string, so a match means the string contains
# nothing but whitespace (or is empty).
_WHITESPACE_ONLY_PATTERN = r'\s*$'
_all_whitespace_string = re.compile(_WHITESPACE_ONLY_PATTERN)
0044 
0045 
0046 ################ Define classes
0047 
class Hist2D(utilities.Hist2D):
    """A container to hold the parameters from a 2D ROOT histogram.

    An optional ``replacements`` keyword is stored on the instance and is
    removed before the remaining arguments reach :class:`utilities.Hist2D`.
    """
    def __init__(self, *args, **kwargs):
        self.replacements = kwargs.pop('replacements', None)
        utilities.Hist2D.__init__(self, *args, **kwargs)

    def contour(self, **kwargs):
        """
        Draw a contour plot using :func:`matplotlib.pyplot.contour`.

        Contour lines are labelled inline.  If bin labels are present they
        are placed at ``self.x`` / ``self.y`` -- the same coordinates the
        contour is drawn at -- so each label sits on its bin.

        Returns the contour set.
        """
        cs = plt.contour(self.x, self.y, self.content, **kwargs)
        plt.clabel(cs, inline=1, fontsize=10)
        # Tick positions are data coordinates; bin indices (0, 1, 2, ...)
        # are not data coordinates in general, so use the bin positions
        # the contour itself was drawn at.
        if self.binlabelsx is not None:
            plt.xticks(self.x, self.binlabelsx)
        if self.binlabelsy is not None:
            plt.yticks(self.y, self.binlabelsy)
        return cs

    def col(self, **kwargs):
        """
        Draw a colored box plot using :func:`matplotlib.pyplot.imshow`.

        The image spans the full x/y edge range with ``origin='lower'`` so
        row 0 of ``self.content`` is drawn at the bottom.  A ``cmap`` given
        by name is resolved with :func:`matplotlib.pyplot.get_cmap`.

        Returns the image object.
        """
        if 'cmap' in kwargs:
            kwargs['cmap'] = plt.get_cmap(kwargs['cmap'])
        extent = [self.xedges[0], self.xedges[-1],
                  self.yedges[0], self.yedges[-1]]
        plot = plt.imshow(self.content, interpolation='nearest',
                          extent=extent, aspect='auto', origin='lower',
                          **kwargs)
        return plot

    def colz(self, **kwargs):
        """
        Draw a colored box plot with a colorbar using
        :func:`matplotlib.pyplot.imshow`.

        Keyword arguments are handled as in :meth:`col`.  Returns the image.
        """
        plot = self.col(**kwargs)
        plt.colorbar(plot)
        return plot

    def box(self, maxsize=40, **kwargs):
        """
        Draw a box plot with size indicating content using
        :func:`matplotlib.pyplot.scatter`.

        The data are normalized so that the largest bin gets marker size
        ``maxsize``.  This is scatter's ``s`` argument, which matplotlib
        interprets as an area in points**2.  Returns the scatter collection.
        """
        # One point per bin in the same row-major order as self.content
        # (content[iy][ix]): x varies fastest, y is repeated per row.
        x = np.tile(self.x, self.nbinsy)
        y = np.repeat(self.y, self.nbinsx)
        maxvalue = np.max(self.content)
        if maxvalue == 0:
            maxvalue = 1  # avoid dividing by zero when the maximum is zero
        sizes = np.array(self.content).flatten() / maxvalue * maxsize
        plot = plt.scatter(x, y, sizes, marker='s', **kwargs)
        return plot

    def TH2F(self, name=""):
        """
        Return a ROOT.TH2F object with contents of this Hist2D.

        Bin edges, title and axis labels are copied; ``self.content[iy][ix]``
        fills ROOT bin ``(ix + 1, iy + 1)`` because ROOT bin 0 is underflow.
        """
        th2f = ROOT.TH2F(name, "",
                         self.nbinsx, array.array('f', self.xedges),
                         self.nbinsy, array.array('f', self.yedges))
        th2f.SetTitle("%s;%s;%s" % (self.title, self.xlabel, self.ylabel))
        for ix in range(self.nbinsx):
            for iy in range(self.nbinsy):
                th2f.SetBinContent(ix + 1, iy + 1, self.content[iy][ix])
        return th2f
0107 
class Hist(utilities.Hist):
    """A container to hold the parameters from a ROOT histogram.

    ``replacements`` (a list of ``(pattern, substitution)`` pairs, see
    :func:`replace`) may be given to the constructor and/or to each plotting
    method.  A per-call value takes precedence over the stored one, and it
    is consumed here rather than forwarded to matplotlib.
    """
    def __init__(self, *args, **kwargs):
        self.replacements = kwargs.pop('replacements', None)
        utilities.Hist.__init__(self, *args, **kwargs)

    def _pop_replacements(self, kwargs):
        """Remove ``replacements`` from *kwargs*; fall back to the stored value."""
        return kwargs.pop('replacements', None) or self.replacements

    def _prepare_xaxis(self, rotation=0, alignment='center'):
        """Apply bounds and text labels on x axis."""
        if self.binlabels is not None:
            plt.xticks(self.x, self.binlabels,
                       rotation=rotation, ha=alignment)
        plt.xlim(self.xedges[0], self.xedges[-1])

    def _prepare_yaxis(self, rotation=0, alignment='center'):
        """Apply bounds and text labels on y axis (for horizontal plots).

        The histogram's x values and edges are laid out along the y axis.
        """
        if self.binlabels is not None:
            plt.yticks(self.x, self.binlabels,
                       rotation=rotation, va=alignment)
        plt.ylim(self.xedges[0], self.xedges[-1])

    def show_titles(self, **kwargs):
        """Print the title and axis labels to the current figure."""
        replacements = kwargs.get('replacements', None) or self.replacements
        plt.title(replace(self.title, replacements))
        plt.xlabel(replace(self.xlabel, replacements))
        plt.ylabel(replace(self.ylabel, replacements))

    def hist(self, label_rotation=0, label_alignment='center', **kwargs):
        """
        Generate a matplotlib hist figure.

        ``fmt`` is dropped and ``replacements`` is consumed; all other
        keyword arguments will be passed to :func:`matplotlib.pyplot.hist`.
        Returns the result of that call.
        """
        kwargs.pop('fmt', None)
        replacements = self._pop_replacements(kwargs)
        weights = self.y
        # Kludge to avoid mpl bug when plotting all zeros.  any() also
        # works when self.y is a numpy array, where comparing with a list
        # would give an ambiguous element-wise result.
        if not any(weights):
            weights = [1.e-10] * len(weights)
        plot = plt.hist(self.x, weights=weights, bins=self.xedges,
                        label=replace(self.label, replacements), **kwargs)
        self._prepare_xaxis(label_rotation, label_alignment)
        return plot

    def errorbar(self, xerr=False, yerr=False, label_rotation=0,
                 label_alignment='center', **kwargs):
        """
        Generate a matplotlib errorbar figure.

        *xerr* / *yerr* switch on the stored per-bin errors.  Apart from the
        consumed ``replacements``, all additional keyword arguments will be
        passed to :func:`matplotlib.pyplot.errorbar`.
        """
        if xerr:
            kwargs['xerr'] = self.xerr
        if yerr:
            kwargs['yerr'] = self.yerr
        replacements = self._pop_replacements(kwargs)
        errorbar = plt.errorbar(self.x, self.y,
                                label=replace(self.label, replacements),
                                **kwargs)
        self._prepare_xaxis(label_rotation, label_alignment)
        return errorbar

    def errorbarh(self, xerr=False, yerr=False, label_rotation=0,
                  label_alignment='center', **kwargs):
        """
        Generate a horizontal matplotlib errorbar figure.

        The axes are swapped, so *xerr* draws the stored y errors and *yerr*
        the stored x errors.  Apart from the consumed ``replacements``, all
        additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.errorbar`.
        """
        if xerr:
            kwargs['xerr'] = self.yerr
        if yerr:
            kwargs['yerr'] = self.xerr
        replacements = self._pop_replacements(kwargs)
        errorbar = plt.errorbar(self.y, self.x,
                                label=replace(self.label, replacements),
                                **kwargs)
        self._prepare_yaxis(label_rotation, label_alignment)
        return errorbar

    def bar(self, xerr=False, yerr=False, xoffset=0., width=0.8,
            label_rotation=0, label_alignment='center', **kwargs):
        """
        Generate a matplotlib bar figure.

        Each bar starts at its bin's low edge shifted by *xoffset* bin widths
        and is *width* bin widths wide.  ``fmt`` is dropped, ``replacements``
        is consumed, and all other keyword arguments will be passed to
        :func:`matplotlib.pyplot.bar`.
        """
        kwargs.pop('fmt', None)
        if xerr:
            kwargs['xerr'] = self.av_xerr()
        if yerr:
            kwargs['yerr'] = self.av_yerr()
        replacements = self._pop_replacements(kwargs)
        # The positions below are left edges.  matplotlib >= 2.0 centres
        # bars on x by default, so request edge alignment unless the caller
        # chose otherwise.
        kwargs.setdefault('align', 'edge')
        lefts = [low + binwidth * xoffset
                 for low, binwidth in zip(self.xedges[:-1], self.width)]
        widths = [binwidth * width for binwidth in self.width]
        bar = plt.bar(lefts, self.y, widths,
                      label=replace(self.label, replacements), **kwargs)
        self._prepare_xaxis(label_rotation, label_alignment)
        return bar

    def barh(self, xerr=False, yerr=False, yoffset=0., width=0.8,
             label_rotation=0, label_alignment='center', **kwargs):
        """
        Generate a horizontal matplotlib bar figure.

        Each bar starts at its bin's low edge (on the y axis) shifted by
        *yoffset* bin widths and is *width* bin widths thick.  ``fmt`` is
        dropped, ``replacements`` is consumed, and all other keyword
        arguments will be passed to :func:`matplotlib.pyplot.barh`.
        """
        kwargs.pop('fmt', None)
        if xerr:
            kwargs['xerr'] = self.av_yerr()
        if yerr:
            kwargs['yerr'] = self.av_xerr()
        replacements = self._pop_replacements(kwargs)
        # Positions are bottom edges; see the alignment note in bar().
        kwargs.setdefault('align', 'edge')
        bottoms = [low + binwidth * yoffset
                   for low, binwidth in zip(self.xedges[:-1], self.width)]
        heights = [binwidth * width for binwidth in self.width]
        barh = plt.barh(bottoms, self.y, heights,
                        label=replace(self.label, replacements),
                        **kwargs)
        self._prepare_yaxis(label_rotation, label_alignment)
        return barh
0227 
class HistStack(utilities.HistStack):
    """
    A container to hold Hist objects for plotting together.

    When plotting, the title and the x and y labels of the last Hist added
    will be used unless specified otherwise in the constructor.
    """
    def __init__(self, *args, **kwargs):
        # Consume our own option before the base class sees kwargs; default
        # to None so the attribute always exists.
        self.replacements = kwargs.pop('replacements', None)
        utilities.HistStack.__init__(self, *args, **kwargs)

    def _apply_titles(self, hist, horizontal=False):
        """Copy the stack's title and axis labels (those not None) onto *hist*.

        For horizontal plots the stack's x label belongs on the Hist's y
        axis and vice versa, so the two labels are swapped.
        """
        if self.title is not None:
            hist.title = self.title
        xattr, yattr = ('ylabel', 'xlabel') if horizontal else ('xlabel', 'ylabel')
        if self.xlabel is not None:
            setattr(hist, xattr, self.xlabel)
        if self.ylabel is not None:
            setattr(hist, yattr, self.ylabel)

    def _merged_kwargs(self, index, kwargs):
        """Return a copy of *kwargs* updated with the options of Hist *index*."""
        all_kwargs = copy.copy(kwargs)
        all_kwargs.update(self.kwargs[index])
        return all_kwargs

    def show_titles(self, **kwargs):
        """
        Print the title and axis labels of the last Hist to the current figure.

        A ``replacements`` keyword takes precedence over the stack's own
        ``replacements``.  If neither is set, the Hist falls back to its own.
        """
        replacements = (kwargs.get('replacements', None) or
                        getattr(self, 'replacements', None))
        if replacements:
            self.hists[-1].show_titles(replacements=replacements)
        else:
            self.hists[-1].show_titles()

    def hist(self, label_rotation=0, **kwargs):
        """
        Make a matplotlib hist plot with one dataset per Hist.

        Any additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.hist`, which allows a vast array of
        possibilities.  Particularly, the *histtype* values such as
        ``'barstacked'`` and ``'stepfilled'`` give substantially different
        results.  You will probably want to include a transparency value
        (e.g. *alpha* = 0.5).

        If every per-Hist option dict has a ``'color'`` entry and no *color*
        keyword is given, those colors are used for the datasets.
        *label_rotation* is accepted for interface compatibility but is
        currently unused.  Returns the result of
        :func:`matplotlib.pyplot.hist`.
        """
        # x and weights must have the same 2D shape, (nbins, nhists):
        # one column per dataset.
        contents = np.column_stack([hist.y for hist in self.hists])
        x = np.column_stack([hist.x for hist in self.hists])
        xedges = self.hists[0].xedges
        labels = [hist.label for hist in self.hists]
        try:
            colors = [item['color'] for item in self.kwargs]
        except (KeyError, TypeError):
            colors = None  # not every Hist specifies a color
        if colors and 'color' not in kwargs:
            kwargs['color'] = colors
        plot = plt.hist(x, weights=contents, bins=xedges,
                        label=labels, **kwargs)
        return plot

    def bar3d(self, **kwargs):
        """
        Draw each Hist as a row of bars in a new 3D figure (experimental).

        Hist *i* is drawn at depth ``i`` along the y axis, whose ticks are
        labelled with the Hist labels.  Returns the list of objects returned
        by each ``ax.bar`` call.
        """
        #### Not yet ready for primetime
        # Importing Axes3D registers the '3d' projection on older matplotlib.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
        from matplotlib.ticker import FixedLocator
        fig = plt.figure()
        # Recent matplotlib no longer attaches Axes3D(fig) to the figure;
        # add_subplot with a 3d projection works across versions.
        ax = fig.add_subplot(111, projection='3d')
        plots = []
        labels = []
        for i, hist in enumerate(self.hists):
            self._apply_titles(hist)
            labels.append(hist.label)
            bar = ax.bar(hist.x, hist.y, zs=i, zdir='y', width=hist.width,
                         **self._merged_kwargs(i, kwargs))
            plots.append(bar)
        locator = FixedLocator(list(range(len(labels))))
        ax.yaxis.set_major_locator(locator)
        ax.yaxis.set_ticklabels(labels)
        ax.set_ylim3d(-1, len(labels))
        return plots

    def barstack(self, **kwargs):
        """
        Make a matplotlib bar plot, with each Hist stacked upon the last.

        Any additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.bar`.  Returns the list of bar results.
        """
        bottom = None  # if this is set to zeroes, it fails for log y
        plots = []
        for i, hist in enumerate(self.hists):
            self._apply_titles(hist)
            bar = hist.bar(bottom=bottom, **self._merged_kwargs(i, kwargs))
            plots.append(bar)
            if bottom is None:
                bottom = [0.] * self.hists[0].nbins
            bottom = [low + y for low, y in zip(bottom, hist.y)]
        return plots

    def histstack(self, **kwargs):
        """
        Make a matplotlib hist plot, with each Hist stacked upon the last.

        Each plotted curve is the running sum of the Hists so far.  Later
        (taller) sums get lower zorder, so earlier ones stay visible in
        front.  Any additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.hist`.  Returns the list of hist results.
        """
        plots = []
        cumhist = None
        nhists = len(self)
        for i, hist in enumerate(self.hists):
            cumhist = copy.copy(hist) if cumhist is None else hist + cumhist
            self._apply_titles(cumhist)
            zorder = 0 + float(nhists - i) / nhists  # plot in reverse order
            plot = cumhist.hist(zorder=zorder,
                                **self._merged_kwargs(i, kwargs))
            plots.append(plot)
        return plots

    def barcluster(self, width=0.8, **kwargs):
        """
        Make a clustered bar plot.

        *width* is the fraction of each bin covered by the whole cluster; it
        is centred in the bin and shared equally among the Hists.  Any
        additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.bar`.
        """
        plots = []
        spacer = (1. - width) / 2
        width = width / len(self.hists)
        for i, hist in enumerate(self.hists):
            self._apply_titles(hist)
            bar = hist.bar(xoffset=width*i + spacer, width=width,
                           **self._merged_kwargs(i, kwargs))
            plots.append(bar)
        return plots

    def barh(self, width=0.8, **kwargs):
        """
        Make a horizontal clustered matplotlib bar plot.

        *width* is handled as in :meth:`barcluster`.  Any additional keyword
        arguments will be passed to :func:`matplotlib.pyplot.barh`.
        """
        plots = []
        spacer = (1. - width) / 2
        width = width / len(self.hists)
        for i, hist in enumerate(self.hists):
            self._apply_titles(hist, horizontal=True)
            bar = hist.barh(yoffset=width*i + spacer, width=width,
                            **self._merged_kwargs(i, kwargs))
            plots.append(bar)
        return plots

    def bar(self, **kwargs):
        """
        Make a bar plot, with all Hists in the stack overlaid.

        Any additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.bar`.  You will probably want to set a
        transparency value (e.g. *alpha* = 0.5).
        """
        plots = []
        for i, hist in enumerate(self.hists):
            self._apply_titles(hist)
            plots.append(hist.bar(**self._merged_kwargs(i, kwargs)))
        return plots

    def errorbar(self, offset=False, **kwargs):
        """
        Make a matplotlib errorbar plot, with all Hists in the stack overlaid.

        Passing 'offset=True' will slightly offset each dataset so overlapping
        errorbars are still visible.  Any additional keyword arguments will
        be passed to :func:`matplotlib.pyplot.errorbar`.
        """
        plots = []
        # Centre the spread of offsets on the unshifted position.
        index_offset = (len(self.hists) - 1)/2.
        for i, hist in enumerate(self.hists):
            self._apply_titles(hist)
            all_kwargs = self._merged_kwargs(i, kwargs)
            transform = plt.gca().transData
            if offset:
                # dpi_scale_trans works in inches, so each step is 1/72 inch
                # (one typographic point).
                point_offset = 1./72 * (i - index_offset)
                shift = transforms.ScaledTranslation(
                    point_offset, 0, plt.gcf().dpi_scale_trans)
                transform = plt.gca().transData + shift
            errorbar = hist.errorbar(transform=transform, **all_kwargs)
            plots.append(errorbar)
        return plots

    def errorbarh(self, **kwargs):
        """
        Make a horizontal matplotlib errorbar plot, with all Hists in the
        stack overlaid.

        Any additional keyword arguments will be passed to
        :func:`matplotlib.pyplot.errorbar`.
        """
        plots = []
        for i, hist in enumerate(self.hists):
            self._apply_titles(hist, horizontal=True)
            plots.append(hist.errorbarh(**self._merged_kwargs(i, kwargs)))
        return plots
0430 
0431 ################ Define functions and classes for navigating within ROOT
0432 
class RootFile(utilities.RootFile):
    """A wrapper for TFiles, allowing easier access to methods."""
    def get(self, object_name, path=None):
        """
        Return *object_name* (optionally looked up under *path*) from the file.

        Delegates to :meth:`utilities.RootFile.get`, passing this module's
        :class:`Hist` and :class:`Hist2D` classes (presumably used there to
        wrap 1D/2D histograms -- confirm against utilities).  Any
        ``ReferenceError`` from the lookup propagates unchanged, keeping its
        original message and traceback.
        """
        return utilities.RootFile.get(self, object_name, path,
                                      Hist, Hist2D)
0441 
0442 ################ Define additional helping functions
0443 
def replace(string, replacements):
    """
    Apply a sequence of substitutions to *string*.

    *replacements* is a list of ``(pattern, substitution)`` pairs applied in
    order.  Plain ``str.replace`` is used unless one of the pairs has the
    pattern ``'use_regexp'``.  In that case the marker pair itself is skipped
    and every other pattern is applied with :func:`re.sub`.

    If *replacements* is empty or None, *string* is returned untouched.
    Otherwise a result consisting only of whitespace is collapsed to "".
    """
    if not replacements:
        return string
    patterns = [pattern for pattern, _ in replacements]
    if 'use_regexp' in patterns:
        for pattern, repl in replacements:
            if pattern != 'use_regexp':
                string = re.sub(pattern, repl, string)
    else:
        for pattern, repl in replacements:
            string = string.replace(pattern, repl)
    return "" if re.match(_all_whitespace_string, string) else string
0466