#! /usr/bin/env python3

"""\
Merge together YODA files produced by Rivet

Examples:

  %(prog)s [options] <yodafile> [<yodafile2> ...]

ENVIRONMENT:
 * RIVET_ANALYSIS_PATH: list of paths to be searched for analysis plugin libraries
 * RIVET_DATA_PATH: list of paths to be searched for data files
"""

import os, sys

## Load the rivet module
try:
    import rivet
except Exception as err:
    ## The import failed: no attempt is made to repair sys.path here. Instead,
    ## tell the user how to diagnose the problem, show the underlying error,
    ## and exit with status 5.
    complain = sys.stderr.write
    complain("The rivet Python module could not be loaded: is your PYTHONPATH set correctly?\n")
    complain("Try running 'python -c \"import rivet\"' at the command line (or \"import rivet\" in a Python REPL)\n")
    complain("Full error message:\n{}\n".format(err))
    sys.exit(5)

## Set the process name through rivet's utility helper
rivet.util.set_process_name("rivet-merge")

import time, datetime, logging, signal
import multiprocessing

## Number of CPU cores reported by the platform; used further down as an upper
## bound on the number of worker processes. Fall back to a single core when
## the count cannot be determined.
try:
    numcores = multiprocessing.cpu_count()
except NotImplementedError:
    ## cpu_count() signals "unknown" by raising NotImplementedError; anything
    ## else (e.g. KeyboardInterrupt) should propagate rather than be swallowed
    numcores = 1

## Parse command line options
import argparse
## The module docstring doubles as the --help description. The raw-description
## formatter keeps its line breaks (usage example, ENVIRONMENT list) intact
## instead of reflowing everything into a single paragraph.
parser = argparse.ArgumentParser(description=__doc__,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)

## Options steering what is merged, and how
extragroup = parser.add_argument_group("Run settings")
extragroup.add_argument("YODAFILES", nargs="+", help="data files to merge. Weight suffixes can be added after a colon, e.g. foo.yoda:=1.23 or bar.yoda:x4.5 to set the cross-section to 1.23 pb, or to scale by 4.5 respectively")
extragroup.add_argument("-o", "--output-file", dest="OUTPUTFILE",
                        default="Rivet.yoda", help="specify the output histo file path (default = %(default)s)")
extragroup.add_argument("-e", "--equiv", dest="EQUIV", action="store_true", default=False,
                        help="assume that the yoda files are equivalent but statistically independent (default= assume that different files contains different processes)")
# TODO: could the merge- and add-options settings be specified in the info file in a MergeSettings heading?
extragroup.add_argument("-O", "--merge-option", dest="MERGEOPTIONS", action="append",
                        default=[], help="specify an analysis option name where different options should be merged into the default analysis.")
extragroup.add_argument("-a", "--add-option", dest="ADDOPTIONS", action="append", default=[],
                        help="specify an analysis option in the format ANALYSIS:NAME=VALUE that should be added in the merging. Use :NAME=VALUE if the option is to be applied to all analyses.")
extragroup.add_argument("--pwd", dest="ANALYSIS_PATH_PWD", action="store_true", default=False,
                        help="append the current directory (pwd) to the analysis/data search paths (cf. $RIVET_ANALYSIS_PATH).")
extragroup.add_argument("-p", "--preload-file", dest="PRELOADFILE",
                        default=None, help="specify an old yoda file to initialize or load calibrations, etc. (default = %(default)s)")
extragroup.add_argument("-m", "--match", dest="MATCH", action="append", default=[],
                        help="only write out histograms whose path matches this regex")
extragroup.add_argument("-M", "--unmatch", dest="UNMATCH", action="append", default=[],
                        help="exclude histograms whose path matches this regex")
extragroup.add_argument("-x", "--cross-section", dest="XSEC", default=None,
                        help="specify the signal process cross-section (in pb) to overwrite the values found in the files")
extragroup.add_argument("--assume-reentrant", dest="ASSUME_REENTRANT", action="store_true", default=False,
                        help="assume all routines are reentrant-safe, even if they are not explicitly marked as such")
## The merge is parallelised with a process pool whose size is capped at the
## detected core count; the default remains a single job. The help text is built
## with str.format so that argparse's own %(default)s placeholder survives.
extragroup.add_argument("-j", "--jobs", dest="JOBS", type=int, default=1,
                        help="max number of worker processes to be used (default = %(default)s; "
                             "never more than the {} available cores)".format(numcores))

## Options steering how chatty the script and the Rivet library are.
## -v and -q share one destination, so the last one given wins.
verbgroup = parser.add_argument_group("Verbosity control")
verbgroup.add_argument("-l", dest="NATIVE_LOG_STRS", action="append",
                       default=[], help="set a log level in the Rivet library")
verbgroup.add_argument("-v", "--verbose", action="store_const", const=logging.DEBUG, dest="LOGLEVEL",
                       default=logging.INFO, help="print debug (very verbose) messages")
verbgroup.add_argument("-q", "--quiet", action="store_const", const=logging.WARNING, dest="LOGLEVEL",
                       default=logging.INFO, help="be very quiet")


def subMerge(yodafiles, mergeoptions, addoptions, match, unmatch, equiv, reentrantOnly, preload = None, procnum = 1):
    """Merge one batch of YODA files inside a private analysis handler.

    A fresh rivet.AnalysisHandler named "AH<procnum>" is created. If a
    preload file is given it is read into the handler first; afterwards
    yodafiles, mergeoptions, addoptions, match, unmatch, equiv and
    reentrantOnly are passed on unchanged to the handler's mergeYODAs().

    Returns a 2-tuple: the handler's raw AO paths followed by its
    serialized content, in that order.
    """
    handler = rivet.AnalysisHandler("AH%d" % procnum)

    ## Optional preload step, done before any merging takes place
    if preload is not None:
        handler.readData(preload)

    handler.mergeYODAs(yodafiles, mergeoptions, addoptions,
                       match, unmatch, equiv, reentrantOnly)

    ## Collect both pieces the caller needs to rebuild this handler's state
    aopaths = handler.getRawAOpaths()
    aodata = handler.serializeContent()
    return aopaths, aodata


if __name__ == '__main__':

    ## Script entry point: parse the command line, configure logging and signal
    ## handling, merge the input YODA files (optionally in parallel worker
    ## processes), rerun finalize and write the merged output file.

    args = parser.parse_args()


    ## Configure logging
    logging.basicConfig(level=args.LOGLEVEL, format="%(message)s")
    if not args.NATIVE_LOG_STRS:
        ## No explicit -l given: mirror the Python log level in the Rivet library.
        ## Find the key in rivet.LEVELS corresponding to the value of args.LOGLEVEL
        lvlname = list(rivet.LEVELS.keys())[list(rivet.LEVELS.values()).index(args.LOGLEVEL)]
        args.NATIVE_LOG_STRS.append("Rivet=%s" % lvlname)
    for logstr in args.NATIVE_LOG_STRS:
        ## Accepted forms are "NAME=LEVEL" or a bare "LEVEL"
        name, level = None, None
        try:
            name, level = logstr.split("=")
        except ValueError:
            ## Not exactly one "=": treat the whole string as a level for "Rivet"
            name = "Rivet"
            level = logstr
        ## Fix name
        if name != "Rivet" and not name.startswith("Rivet."):
            name = "Rivet." + name
        try:
            ## Get right error type
            level = rivet.LEVELS.get(level.upper(), None)
            ## The eager %d formatting is deliberate: an unknown level name yields
            ## None here, the formatting raises, and the warning below is issued
            logging.debug("Setting log level: %s %d" % (name, level))
            rivet.setLogLevel(name, level)
        except Exception:
            logging.warning("Couldn't process logging string '%s'" % logstr)


    ## Set up signal handling
    ## NOTE(review): RECVD_KILL_SIGNAL is set by the handler but never read in
    ## this script, so the first signal is only logged and a second one is
    ## needed to trigger the default action -- confirm this is intended.
    RECVD_KILL_SIGNAL = None
    def handleKillSignal(signum, frame):
        "Declare us as having been signalled, and return to default handling behaviour"
        global RECVD_KILL_SIGNAL
        logging.critical("Signal handler called with signal " + str(signum))
        RECVD_KILL_SIGNAL = signum
        signal.signal(signum, signal.SIG_DFL)
    ## Signals to handle
    for signum in (signal.SIGTERM, signal.SIGHUP, signal.SIGINT,
                   signal.SIGUSR1, signal.SIGUSR2):
        signal.signal(signum, handleKillSignal)
    try:
        ## Best effort only: SIGXCPU may be missing or not installable
        signal.signal(signal.SIGXCPU, handleKillSignal)
    except (AttributeError, ValueError, OSError):
        pass


    if args.ANALYSIS_PATH_PWD:
        ## --pwd: add the current directory to both the analysis-library and
        ## the analysis-data search paths
        cwd = os.path.abspath(".")
        rivet.addAnalysisLibPath(cwd)
        rivet.addAnalysisDataPath(cwd)

    ## Set up analysis handler
    ah = rivet.AnalysisHandler("Merging")

    ## Check if it's worth splitting the input files: only when more than one
    ## job was requested and there are more than two input files per job
    splitFiles = args.JOBS > 1 and len(args.YODAFILES) > 2*args.JOBS

    # If the user specified a global cross-section, attach the value
    # like ":=1.23" to each file name in order to force overwriting.
    # File names that already contain a colon are left untouched.
    xs = '' if args.XSEC is None else ':=' + args.XSEC
    inputFiles = [ y if ':' in y else y + xs for y in args.YODAFILES ]

    if splitFiles:

        ## Never start more workers than there are cores
        nprocs = min(args.JOBS, numcores)
        # Call multiprocessing pool to merge subsets of files
        pool = multiprocessing.Pool(processes=nprocs)
        ## Each job gets a slice of nprocs consecutive input files and a running
        ## job index, which subMerge uses to name its analysis handler.
        ## NOTE(review): the slice length equals nprocs, so the number of jobs is
        ## ceil(len(inputFiles)/nprocs) rather than nprocs -- confirm intended.
        jobs = [pool.apply_async(func=subMerge,
                                 args=(inputFiles[idx:idx+nprocs],
                                       args.MERGEOPTIONS, args.ADDOPTIONS,
                                       args.MATCH, args.UNMATCH,
                                       args.EQUIV, not args.ASSUME_REENTRANT,
                                       args.PRELOADFILE, procID,))
                for procID, idx in enumerate(range(0, len(inputFiles), nprocs))]
        ## Block until every job has delivered its (AO paths, serialized data) pair
        results = [ job.get() for job in jobs ]
        ## All work is done: let the workers exit and wait for them, so that
        ## they do not linger while the partial results are combined
        pool.close()
        pool.join()
        # Merge the individual substeps: the first result seeds the main
        # handler, every further one is loaded into a temporary handler
        # and merged in
        ah.loadAOs(*results[0])
        for aopaths, aodata in results[1:]:
            ahtemp = rivet.AnalysisHandler("TempAH")
            ahtemp.loadAOs(aopaths, aodata)
            ah.merge(ahtemp)

    else:

        ## Single-process path: optional preload, then merge everything at once
        if args.PRELOADFILE is not None:
            ah.readData(args.PRELOADFILE)

        ah.mergeYODAs(inputFiles, args.MERGEOPTIONS, args.ADDOPTIONS,
                      args.MATCH, args.UNMATCH, args.EQUIV, not args.ASSUME_REENTRANT)

    logging.info("Rerunning finalize ...")
    ah.finalize()
    ah.writeData(args.OUTPUTFILE)
