Source code for pinefarm.cli.merge

"""Merge multiple PineAPPL grids into a single one."""

import itertools
import pathlib
import re
import shutil

import click
import pineappl
import rich

from .. import tools
from ._base import command


@command.command("merge")
@click.argument("grids", nargs=-1)  # variadic: collects all positional arguments
def subcommand(grids):
    """Merge multiple PineAPPL grids into a single one."""
    # NOTE: click uses the docstring above as the command's help text, so
    # rewording it changes what users see.
    # Thin CLI wrapper: the actual work, including the check that at least
    # two grids were supplied, is delegated to ``main``.
    main(grids)


def main(grids):
    """Merge multiple PineAPPL grids into a single one.

    The first grid is loaded as the starting point and every other grid is
    merged into it from disk.  The result is written to a relative path named
    after the common part of the input file names (suffix ``.pineappl``), its
    ``results`` metadata entry is rebuilt from the tables of all inputs, and
    the file is finally passed to ``tools.compress`` before the uncompressed
    copy is deleted.

    Parameters
    ----------
    grids : sequence of str or path-like
        Paths of the grids to merge; at least two are required.

    Raises
    ------
    ValueError
        If fewer than two grids are given.

    """
    if len(grids) < 2:
        raise ValueError("At least 2 grids needed for a merge.")

    grid_paths = [pathlib.Path(grid) for grid in grids]
    grids = [pineappl.grid.Grid.read(str(grid)) for grid in grids]

    common = tools.common_substring(*(path.name for path in grid_paths)).strip("_")
    mgrid_path = pathlib.Path(common).with_suffix(".pineappl")
    rich.print(f"Merging into -> '{mgrid_path}'")

    # merge all grids in a single one
    # the first one has already been read from file
    rich.print(f"Init to '{grid_paths[0]}'")
    mgrid = grids[0]
    # subsequently merge all the others (disk -> memory)
    for path in grid_paths[1:]:
        rich.print(f"Merging '{path}'...")
        mgrid.merge_from_file(str(path))
    mgrid.write(str(mgrid_path))

    # Read each grid's metadata exactly once (after the merge, as before)
    # instead of calling ``key_values()`` again inside every loop below.
    metadata = [grid.key_values() for grid in grids]
    mkeys = metadata[0]

    # get all keys, possibly ones that are exclusive to a single grid
    keys = set().union(*metadata)

    # set the concatenated results metadata in the new grid
    merged_results = _concatenate_results(
        [values.get("results") for values in metadata]
    )
    mgridtmp = mgrid_path.parent / (mgrid_path.name + ".tmp")
    tools.update_grid_metadata(
        mgrid_path, mgridtmp, entries={"results": merged_results}
    )
    shutil.move(str(mgridtmp), str(mgrid_path))

    # Report every metadata entry that is not shared with the first grid.
    # Keys are sorted so the report is reproducible between runs.
    for key in sorted(keys):
        if key == "results":
            continue
        mvalue = mkeys.get(key)
        for path, values in zip(grid_paths[1:], metadata[1:]):
            if mvalue != values.get(key):
                # TODO: what do we do in this case?
                # Name the file: the Grid object itself does not identify it.
                rich.print(f"'{key}' differs [gray]for '{path}'[/]")

    cpath = tools.compress(mgrid_path)
    mgrid_path.unlink()
    rich.print(f"Grid merged and compressed, stored in '{cpath}'.")


# A data row of a ``results`` table starts with a digit; anything else is header.
_DATA_ROW = re.compile(r"\d.*")
# A data row whose entries are all zero.  ``.`` and ``+`` are escaped so that
# the pattern matches the literal text ``0.000000e+00``.
_EMPTY_ROW = re.compile(
    r"\d+ +0\.000000e\+00 0\.000000e\+00 0\.000 0\.000 0\.0000 0\.0000 0\.0000"
)


def _concatenate_results(tables):
    """Combine the ``results`` tables of several grids into a single table.

    Header rows (lines that are not data rows) are taken from the first table
    only.  Data rows are collected from every table, in order, skipping the
    rows matched by ``_EMPTY_ROW``.  A missing table (``None``) contributes no
    rows.
    """
    split_tables = [(table or "").splitlines() for table in tables]
    if not split_tables:
        return ""

    rows = [line for line in split_tables[0] if not _DATA_ROW.fullmatch(line)]
    for lines in split_tables:
        rows.extend(
            line
            for line in lines
            if _DATA_ROW.fullmatch(line) and not _EMPTY_ROW.fullmatch(line)
        )
    return "\n".join(rows)