#! /usr/bin/env python3

"""\
%(prog)s <datafile1> [<datafile2> ...]

List the contents of YODA-readable data files (sorted by path name).
"""

from __future__ import print_function

import yoda, sys, argparse, signal

# Restore the *default* SIGPIPE action (terminate quietly) so that piping the
# listing into e.g. `head` does not end in a BrokenPipeError traceback.
# SIGPIPE only exists on POSIX platforms, so guard the lookup for e.g. Windows.
if hasattr(signal, "SIGPIPE"):
    signal.signal(signal.SIGPIPE, signal.SIG_DFL)

# Command-line interface. VERBOSITY levels used below:
#   0 (-q)      : no per-file "Data objects in ..." headers / file separators
#   1 (default) : one summary entry per data object
#   2 (-v)      : additionally N / sumW (and under/overflow where available)
#   3 (-vv)     : additionally a per-bin/point table (not in -l mode)
parser = argparse.ArgumentParser(usage=__doc__)
parser.add_argument("ARGS", nargs="+", help="YODA-readable data file(s) to list")
parser.add_argument('-v', '--verbose', action="count", default=1, dest='VERBOSITY',
                    help="print extra histogram details; repeat (-vv) to also print "
                         "per-bin/point tables (not with -l)")
parser.add_argument('-q', '--quiet', action="store_const", const=0, default=1, dest='VERBOSITY',
                    help="just print histogram details, without the per-file "
                         "'Data objects in ...' headers and the blank lines between files")
parser.add_argument("-m", "--match", dest="MATCH", metavar="PATT", default=None,
                    help="only list histograms whose path matches this regex")
parser.add_argument("-M", "--unmatch", dest="UNMATCH", metavar="PATT", default=None,
                    help="exclude histograms whose path matches this regex")
parser.add_argument("--max-typelen", dest="MAX_TYPELEN", type=int, default=10,
                    help="max length of type column")
parser.add_argument("-l", dest="LONGROW", default=False, action="store_true",
                    help="print details in a unix-friendly parseable way")
args = parser.parse_args()

filenames = args.ARGS
# Defensive only: nargs="+" already makes argparse reject an empty file list.
# sys.exit(<str>) writes the message to stderr and exits with status 1.
if not filenames:
    sys.exit("ERROR! Please supply at least one data file for listing")

# Sort object paths "naturally" (numeric runs compared as numbers) when the
# optional natsort package is installed; otherwise fall back to plain sorting.
# Only a missing/unimportable package triggers the fallback.
try:
    import natsort
    ysorted = natsort.natsorted
except ImportError:
    ysorted = sorted

def _axis_config(ao):
    """Return the axis-configuration code of *ao*.

    Accepts ``axisConfig`` exposed either as a plain attribute or as a method.
    A bound method never compares equal to a string, so comparing the raw
    attribute with 'd' would silently disable the under/overflow report.
    """
    cfg = ao.axisConfig
    return cfg() if callable(cfg) else cfg


def _extra_info(ao):
    """Return the verbose-mode suffix (" N=... sumW=... uflow=... oflow=...") for *ao*.

    Each field is only added when *ao* provides the corresponding accessor.
    """
    info = ""
    if hasattr(ao, "numEntries"):
        info += " N={sumw:.3g}".format(sumw=ao.numEntries())
    if hasattr(ao, "sumW"):
        info += " sumW={sumw:.3g}".format(sumw=ao.sumW())
    if hasattr(ao, "numBins"):
        n = ao.numBins()
        # NOTE(review): the n == 1 restriction is carried over unchanged; confirm intent.
        # The short-circuit keeps axisConfig from being accessed when n != 1.
        if n == 1 and _axis_config(ao) == 'd':
            # presumably bin(0)/bin(n+1) are the under-/overflow bins -- TODO confirm
            first = ao.bin(0)
            if hasattr(first, "sumW"):
                info += " uflow={sumw:.3g}".format(sumw=first.sumW())
                info += " oflow={sumw:.3g}".format(sumw=ao.bin(n + 1).sumW())
            elif hasattr(first, "val"):
                info += " uflow={sumw:.3g}".format(sumw=first.val())
                info += " oflow={sumw:.3g}".format(sumw=ao.bin(n + 1).val())
    return info


def _count_str(ao):
    """Return the 4-wide bin/point count column for *ao*, or "   -" if it has neither."""
    if hasattr(ao, "numBins"):
        return "{n:4d}".format(n=ao.numBins(False))
    if hasattr(ao, "numPoints"):
        return "{n:4d}".format(n=ao.numPoints())
    return "   -"


def _print_table(ao):
    """Print a column header plus one row per point of ``ao.mkScatter()``.

    Objects with ``binDim`` get (low, high) columns for every dimension but the
    last and (val, err-, err+) for the last one; all other objects get
    (val, err-, err+) columns for every dimension.
    """
    ndim = ao.dim()
    val_fmt = '{0:.3g}  {1:.3g}  {2:.3g}'
    edge_fmt = '{0:.3g}  {1:.3g}'
    if hasattr(ao, "binDim"):
        print('# ' + '  '.join(['Low{dim:d}  High{dim:d}'.format(dim=d + 1)
                                for d in range(ndim - 1)]) + '  Val  Err-  Err+')
        rows = ['  '.join([edge_fmt.format(pt.min(d), pt.max(d)) for d in range(ndim - 1)]
                          + [val_fmt.format(pt.val(ndim - 1), *pt.errs(ndim - 1))])
                for pt in ao.mkScatter().points()]
    else:
        print('# ' + '  '.join(['Val{dim:d}  Err{dim:d}-  Err{dim:d}+'.format(dim=d + 1)
                                for d in range(ndim)]))
        rows = ['  '.join([val_fmt.format(pt.val(d), *pt.errs(d)) for d in range(ndim)])
                for pt in ao.mkScatter().points()]
    print('\n'.join(rows))
    sys.stdout.flush()


# Main listing: for each file, print one summary entry per data object,
# sorted by path, with extra detail controlled by args.VERBOSITY / args.LONGROW.
for i, f in enumerate(filenames):
    if args.VERBOSITY >= 1:
        if i > 0:
            print()
        print("Data objects in %s:" % f)
    aodict = yoda.read(f, patterns=args.MATCH, unpatterns=args.UNMATCH)
    for p, ao in ysorted(aodict.items()):
        extrainfo = _extra_info(ao) if args.VERBOSITY >= 2 else ""
        nobjstr = _count_str(ao)
        # -l keeps everything on one padded line; otherwise the path gets its own line.
        aoinfo = f"{p:<50} " if args.LONGROW else f"{p}\n"
        aoinfo += f"{ao.type().ljust(args.MAX_TYPELEN)}"
        aoinfo += f" {nobjstr} bins/pts" + extrainfo
        print(aoinfo)
        sys.stdout.flush()
        if not args.LONGROW:
            if args.VERBOSITY >= 3:
                _print_table(ao)
            print()
