Module pycascadia.remove_restore

Implementation of the remove-restore algorithm from the GEBCO cookbook

Note that the remove-restore algorithm consists of only "step D" described in the cookbook. Preprocessing steps A-C are not included here.

Expand source code
#!/usr/bin/env python3

"""
Implementation of the remove-restore algorithm from the GEBCO cookbook

Note that the remove-restore algorithm consists of only "step D" described in
the cookbook. Preprocessing steps A-C are not included here.
"""

from pygmt import blockmedian, grdtrack, grdfilter
from pygmt.clib import Session
from pygmt.helpers import (
    GMTTempFile,
    build_arg_string,
    kwargs_to_strings,
    use_alias,
)

import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt
import argparse

from pycascadia.grid import Grid
from pycascadia.utility import (
    min_regions,
    is_region_valid,
    read_fnames,
    all_values_are_nodata,
)


@use_alias(
    I="spacing",
    R="region",
    V="verbose",
)
@kwargs_to_strings(R="sequence")
def nearneighbour(data_xyz, **kwargs):
    """Grid scattered x, y, z data by calling GMT's ``nearneighbor`` module.

    The call goes through pyGMT's low-level ``clib`` session, following the
    pattern of pyGMT's [blockmedian implementation](https://github.com/GenericMappingTools/pygmt/blob/c0ff7f1add9884305688c2fa15c5f13516b8b960/pygmt/src/blockm.py).

    Args:
        data_xyz: Tabular x, y, z data handed to GMT as a virtual file.
        **kwargs: GMT options. ``spacing`` (-I), ``region`` (-R) and
            ``verbose`` (-V) are accepted as aliases; the remaining keyword
            arguments are turned into GMT flags by pyGMT's
            ``build_arg_string``. ``G`` names an output grid file.

    Returns:
        The gridded result loaded as a DataArray when ``G`` was not given,
        otherwise ``None`` (GMT wrote the grid to the caller's file).
    """
    with GMTTempFile(suffix=".csv") as tmpfile:
        with Session() as lib:
            with lib.virtualfile_from_data(check_kind="vector", data=data_xyz) as infile:
                # Without a caller-supplied output grid, write to the temp file.
                outgrid = kwargs.setdefault("G", tmpfile.name)
                lib.call_module("nearneighbor", f"{infile} {build_arg_string(kwargs)}")

        if outgrid != tmpfile.name:
            # GMT wrote to the caller's file, so there is nothing to return.
            return None

        with xr.open_dataarray(outgrid) as dataarray:
            gridded = dataarray.load()
            _ = gridded.gmt  # populate the GMTDataArray accessor while the file is open
        return gridded


def create_interpolation_grid(
    diff_grid: xr.DataArray, nodata_val: int, window_width: int
) -> xr.DataArray:
    """Build a 0-to-1 weight grid that softens the edge of a difference grid.

    A mask is formed holding 1.0 where ``diff_grid`` has data and 0.0 where it
    equals ``nodata_val``. The mask is smoothed with a boxcar filter of full
    width ``2 * window_width`` and rescaled so that only the part lying on the
    data side of the boundary remains positive. Multiplying the difference
    grid by the result fades it out towards the data boundary instead of
    stopping abruptly.

    Args:
        diff_grid: Difference grid to derive the weights from.
        nodata_val: Value marking nodes with no data.
        window_width: Width of the smoothing region at the data boundary
            (presumably in the grid's coordinate units, since grdfilter is
            called with ``distance=0``).

    Returns:
        Grid which should be multiplied by the input grid in order to smooth it.
    """
    is_nodata = diff_grid == nodata_val
    # The DataArray.where method is used rather than xr.where; the original
    # author noted xr.where(diff_grid == nodata_val, 1.0, 0.0) did not work here.
    data_mask = diff_grid.where(is_nodata, 1.0).where(~is_nodata, 0.0)

    # Boxcar smoothing turns the hard 0/1 step into a ramp across the boundary.
    smoothed = grdfilter(data_mask, filter=f"b{2*window_width}", distance=0)

    # Shift and stretch so the ramp midpoint (0.5) maps to 0 and full data
    # coverage maps to 1, then clamp the nodata side to zero.
    weights = (smoothed - 0.5) * 2.0
    return weights.where(weights > 0.0, 0.0)


def calc_diff_grid(
    base_grid: Grid,
    update_grid: Grid,
    diff_threshold: float = 0.0,
    window_width: float = None,
) -> xr.DataArray:
    """Calculates difference grid for use in remove-restore.

    The update grid is block-median filtered at the coarser of the two grid
    spacings over the overlap of both regions. The base grid is then sampled
    (bilinear interpolation) at those points, and the differences
    ``update - base`` are formed. These are gridded back onto the base grid's
    region and spacing with GMT ``nearneighbor``. Nodes that receive no
    difference are set to 0.0, so adding the result leaves the base grid
    unchanged there.

    Args:
        base_grid: Base grid to be later updated using the calculated difference grid.
        update_grid: Differences will be calculated between this and the base grid.
        diff_threshold: Differences whose absolute value is below this
            threshold are set to zero. The default 0.0 keeps every difference.
        window_width: Width of optional smoothing window around update grid.
            Falsy values (``None`` or 0) disable the window.

    Returns:
        Difference grid for updating base grid, or ``None`` if the update grid
        lies outside the base grid's region or contains only nodata values.
    """
    print("Blockmedian update grid")
    max_spacing = max(update_grid.spacing, base_grid.spacing)
    minimal_region = min_regions(update_grid.region, base_grid.region)
    if not is_region_valid(minimal_region):
        print("Update grid is entirely outside region of interest. Skipping.")
        return None

    if all_values_are_nodata(update_grid.grid):
        print("Update grid consists entirely of no_data_values. Skipping.")
        return None

    bmd = blockmedian(update_grid.xyz, spacing=max_spacing, region=minimal_region)

    print("Find z in base grid")
    # Keywords keep this call valid across pyGMT releases, which changed the
    # positional order of ``points`` and ``grid``.
    base_pts = grdtrack(
        points=bmd, grid=base_grid.grid, newcolname="base_z", interpolation="l"
    )

    print("Create difference grid")
    diff = pd.DataFrame()
    diff["x"] = base_pts["x"]
    diff["y"] = base_pts["y"]
    diff["z"] = base_pts["z"] - base_pts["base_z"]

    # Filter out small differences. ``.loc`` writes into ``diff`` itself;
    # chained indexing (``diff[mask]["z"] = ...``) only modified a copy.
    # float() also accepts a threshold passed as a string (e.g. raw CLI input).
    diff.loc[diff["z"].abs() < float(diff_threshold), "z"] = 0.0

    # Marker for grid nodes with no nearby difference. NOTE(review): a genuine
    # difference of exactly 9999 would be indistinguishable from nodata.
    NODATA_VAL = 9999

    diff_grid = nearneighbour(
        diff,
        region=base_grid.region,
        spacing=base_grid.spacing,
        S=2 * max_spacing,
        N=4,
        E=NODATA_VAL,
        verbose=True,
    )

    # Build the edge taper while nodata nodes are still marked by NODATA_VAL.
    interp_grid = None
    if window_width:
        interp_grid = create_interpolation_grid(diff_grid, NODATA_VAL, window_width)

    # Nodes without data must not change the base grid.
    diff_grid = diff_grid.where(diff_grid != NODATA_VAL, 0.0)

    if interp_grid is not None:
        # Fade the differences out towards the edge of the update grid's coverage.
        diff_grid = diff_grid * interp_grid

    return diff_grid


def load_base_grid(fname: str, region: list = None, spacing: float = None) -> Grid:
    """Load the base grid from file, optionally cropping and resampling it.

    Args:
        fname: Filename of the input grid.
        region: Optional region to crop to, passed to ``Grid.crop``. The
            command line tool supplies it as ``[xmin, xmax, ymin, ymax]``.
            Falsy values skip cropping.
        spacing: Optional grid spacing to which the base grid will be
            resampled via ``Grid.resample``. Falsy values (``None``/0) skip
            resampling.

    Returns:
        Grid containing the base grid (not converted to xyz points).
    """
    base_grid = Grid(fname, convert_to_xyz=False)
    # Crop first so that any resampling only touches the region of interest.
    if region:
        base_grid.crop(region)
    if spacing:
        base_grid.resample(spacing)

    return base_grid


def main():
    """Main entry point for remove-restore command line tool.

    Parses the command line and loads (optionally cropping/resampling) the
    base grid. It then adds the difference grid computed against each update
    grid in turn and saves the result. When ``--plot`` is given, it then shows
    the initial and final grids plus the x/y derivatives of the final grid.
    """
    # Handle arguments
    parser = argparse.ArgumentParser(
        description="Combine multiple bathymetry sources into a single grid"
    )
    parser.add_argument(
        "filenames", nargs="*", help="sources to combine with the base grid"
    )
    parser.add_argument("--base", required=True, help="base grid")
    parser.add_argument("--input_txt", help="text file containing list of input grids")
    parser.add_argument("--spacing", type=float, help="output grid spacing")
    parser.add_argument(
        "--diff_threshold",
        # Without type=float a CLI value arrives as str and breaks the
        # numeric comparison in calc_diff_grid.
        type=float,
        default=0.0,
        help="value above which differences will be added to the base grid",
    )
    parser.add_argument(
        "--plot", action="store_true", help="plot final output after saving"
    )
    parser.add_argument("--output", required=True, help="filename of final output")
    parser.add_argument(
        "--window_width",
        required=False,
        type=float,
        help="Enable windowing of update grid and specify width of window in degrees",
    )
    parser.add_argument(
        "--region_of_interest",
        metavar=("xmin", "xmax", "ymin", "ymax"),
        required=False,
        nargs=4,
        type=float,
        help="output region. Defaults to the extent of the base grid.",
    )
    args = parser.parse_args()

    filenames = []
    if args.input_txt:
        # Read filenames from file
        filenames += read_fnames(args.input_txt)
    # Add filenames from command line
    filenames += args.filenames

    if not filenames:
        # parser.error prints usage and exits; unlike assert it survives -O.
        parser.error("No filenames given")

    base_fname = args.base
    diff_threshold = args.diff_threshold
    output_fname = args.output
    region_of_interest = args.region_of_interest
    window_width = args.window_width

    # Create base grid
    base_grid = load_base_grid(
        base_fname, region=region_of_interest, spacing=args.spacing
    )

    # Update base grid with each source in turn
    for fname in filenames:
        print("Loading update grid")
        update_grid = Grid(fname, convert_to_xyz=True)

        diff_grid = calc_diff_grid(
            base_grid,
            update_grid,
            diff_threshold=diff_threshold,
            window_width=window_width,
        )

        if diff_grid is not None:
            print("Update base grid")
            # Assumes diff_grid matches base_grid's shape: calc_diff_grid grids
            # onto base_grid's region and spacing.
            base_grid.grid.values += diff_grid.values

    base_grid.save_grid(output_fname)

    if args.plot:
        _, axes = plt.subplots(2, 2)
        initial_base_grid = load_base_grid(base_fname, region=region_of_interest)
        initial_base_grid.plot(ax=axes[0, 0])
        axes[0, 0].set_title("Initial Grid")
        base_grid.plot(ax=axes[0, 1])
        axes[0, 1].set_title("Final Grid")
        base_grid.grid.differentiate("x").plot(ax=axes[1, 0])
        axes[1, 0].set_title("x Derivative of Final Grid")
        base_grid.grid.differentiate("y").plot(ax=axes[1, 1])
        axes[1, 1].set_title("y Derivative of Final Grid")
        plt.show()


if __name__ == "__main__":
    # Run the command line tool only when executed as a script, not on import.
    main()

Functions

def calc_diff_grid(base_grid: Grid, update_grid: Grid, diff_threshold: float = 0.0, window_width: int = None) -> xarray.core.dataarray.DataArray

Calculates difference grid for use in remove-restore.

Args

base_grid
Base grid to be later updated using the calculated difference grid.
update_grid
Differences will be calculated between this and the base grid.
diff_threshold
Optional threshold: differences whose absolute value is below it are set to zero; larger differences are applied.
window_width
Width of optional smoothing window around update grid.

Returns

Difference grid for updating base grid.

Expand source code
def calc_diff_grid(
    base_grid: Grid,
    update_grid: Grid,
    diff_threshold: float = 0.0,
    window_width: float = None,
) -> xr.DataArray:
    """Calculates difference grid for use in remove-restore.

    The update grid is block-median filtered at the coarser of the two grid
    spacings over the overlap of both regions. The base grid is then sampled
    (bilinear interpolation) at those points, and the differences
    ``update - base`` are formed. These are gridded back onto the base grid's
    region and spacing with GMT ``nearneighbor``. Nodes that receive no
    difference are set to 0.0, so adding the result leaves the base grid
    unchanged there.

    Args:
        base_grid: Base grid to be later updated using the calculated difference grid.
        update_grid: Differences will be calculated between this and the base grid.
        diff_threshold: Differences whose absolute value is below this
            threshold are set to zero. The default 0.0 keeps every difference.
        window_width: Width of optional smoothing window around update grid.
            Falsy values (``None`` or 0) disable the window.

    Returns:
        Difference grid for updating base grid, or ``None`` if the update grid
        lies outside the base grid's region or contains only nodata values.
    """
    print("Blockmedian update grid")
    max_spacing = max(update_grid.spacing, base_grid.spacing)
    minimal_region = min_regions(update_grid.region, base_grid.region)
    if not is_region_valid(minimal_region):
        print("Update grid is entirely outside region of interest. Skipping.")
        return None

    if all_values_are_nodata(update_grid.grid):
        print("Update grid consists entirely of no_data_values. Skipping.")
        return None

    bmd = blockmedian(update_grid.xyz, spacing=max_spacing, region=minimal_region)

    print("Find z in base grid")
    # Keywords keep this call valid across pyGMT releases, which changed the
    # positional order of ``points`` and ``grid``.
    base_pts = grdtrack(
        points=bmd, grid=base_grid.grid, newcolname="base_z", interpolation="l"
    )

    print("Create difference grid")
    diff = pd.DataFrame()
    diff["x"] = base_pts["x"]
    diff["y"] = base_pts["y"]
    diff["z"] = base_pts["z"] - base_pts["base_z"]

    # Filter out small differences. ``.loc`` writes into ``diff`` itself;
    # chained indexing (``diff[mask]["z"] = ...``) only modified a copy.
    # float() also accepts a threshold passed as a string (e.g. raw CLI input).
    diff.loc[diff["z"].abs() < float(diff_threshold), "z"] = 0.0

    # Marker for grid nodes with no nearby difference. NOTE(review): a genuine
    # difference of exactly 9999 would be indistinguishable from nodata.
    NODATA_VAL = 9999

    diff_grid = nearneighbour(
        diff,
        region=base_grid.region,
        spacing=base_grid.spacing,
        S=2 * max_spacing,
        N=4,
        E=NODATA_VAL,
        verbose=True,
    )

    # Build the edge taper while nodata nodes are still marked by NODATA_VAL.
    interp_grid = None
    if window_width:
        interp_grid = create_interpolation_grid(diff_grid, NODATA_VAL, window_width)

    # Nodes without data must not change the base grid.
    diff_grid = diff_grid.where(diff_grid != NODATA_VAL, 0.0)

    if interp_grid is not None:
        # Fade the differences out towards the edge of the update grid's coverage.
        diff_grid = diff_grid * interp_grid

    return diff_grid
def create_interpolation_grid(diff_grid: xarray.core.dataarray.DataArray, nodata_val: int, window_width: int) -> xarray.core.dataarray.DataArray

Create filter to smooth hard edge of difference grid.

This works by creating a grid containing 1 where the input grid has data and 0 where it has none, then applying a filter to this grid to smooth the boundary. The result can be multiplied directly by the difference grid to smooth its edges.

Args

diff_grid
Difference grid to calculate smoothing filter from.
nodata_val
Value in grid representing a lack of data.
window_width
Width of smoothing region at data boundary.

Returns

Grid which should be multiplied by input grid in order to smooth.

Expand source code
def create_interpolation_grid(
    diff_grid: xr.DataArray, nodata_val: int, window_width: int
) -> xr.DataArray:
    """Build a 0-to-1 weight grid that softens the edge of a difference grid.

    A mask is formed holding 1.0 where ``diff_grid`` has data and 0.0 where it
    equals ``nodata_val``. The mask is smoothed with a boxcar filter of full
    width ``2 * window_width`` and rescaled so that only the part lying on the
    data side of the boundary remains positive. Multiplying the difference
    grid by the result fades it out towards the data boundary instead of
    stopping abruptly.

    Args:
        diff_grid: Difference grid to derive the weights from.
        nodata_val: Value marking nodes with no data.
        window_width: Width of the smoothing region at the data boundary
            (presumably in the grid's coordinate units, since grdfilter is
            called with ``distance=0``).

    Returns:
        Grid which should be multiplied by the input grid in order to smooth it.
    """
    is_nodata = diff_grid == nodata_val
    # The DataArray.where method is used rather than xr.where; the original
    # author noted xr.where(diff_grid == nodata_val, 1.0, 0.0) did not work here.
    data_mask = diff_grid.where(is_nodata, 1.0).where(~is_nodata, 0.0)

    # Boxcar smoothing turns the hard 0/1 step into a ramp across the boundary.
    smoothed = grdfilter(data_mask, filter=f"b{2*window_width}", distance=0)

    # Shift and stretch so the ramp midpoint (0.5) maps to 0 and full data
    # coverage maps to 1, then clamp the nodata side to zero.
    weights = (smoothed - 0.5) * 2.0
    return weights.where(weights > 0.0, 0.0)
def load_base_grid(fname: str, region: list = None, spacing: bool = None) -> Grid

Load base grid from file optionally cropping and resampling.

Args

fname
Filename of input grid.
region
Optional region to crop to.
spacing
Optional grid spacing to which the base grid will be resampled.

Returns

Grid containing base grid.

Expand source code
def load_base_grid(fname: str, region: list = None, spacing: float = None) -> Grid:
    """Load the base grid from file, optionally cropping and resampling it.

    Args:
        fname: Filename of the input grid.
        region: Optional region to crop to, passed to ``Grid.crop``. The
            command line tool supplies it as ``[xmin, xmax, ymin, ymax]``.
            Falsy values skip cropping.
        spacing: Optional grid spacing to which the base grid will be
            resampled via ``Grid.resample``. Falsy values (``None``/0) skip
            resampling.

    Returns:
        Grid containing the base grid (not converted to xyz points).
    """
    base_grid = Grid(fname, convert_to_xyz=False)
    # Crop first so that any resampling only touches the region of interest.
    if region:
        base_grid.crop(region)
    if spacing:
        base_grid.resample(spacing)

    return base_grid
def main()

Main entry point for remove-restore command line tool.

This handles arguments, applies the remove-restore algorithm and, optionally, plots the results.

Expand source code
def main():
    """Main entry point for remove-restore command line tool.

    Parses the command line and loads (optionally cropping/resampling) the
    base grid. It then adds the difference grid computed against each update
    grid in turn and saves the result. When ``--plot`` is given, it then shows
    the initial and final grids plus the x/y derivatives of the final grid.
    """
    # Handle arguments
    parser = argparse.ArgumentParser(
        description="Combine multiple bathymetry sources into a single grid"
    )
    parser.add_argument(
        "filenames", nargs="*", help="sources to combine with the base grid"
    )
    parser.add_argument("--base", required=True, help="base grid")
    parser.add_argument("--input_txt", help="text file containing list of input grids")
    parser.add_argument("--spacing", type=float, help="output grid spacing")
    parser.add_argument(
        "--diff_threshold",
        # Without type=float a CLI value arrives as str and breaks the
        # numeric comparison in calc_diff_grid.
        type=float,
        default=0.0,
        help="value above which differences will be added to the base grid",
    )
    parser.add_argument(
        "--plot", action="store_true", help="plot final output after saving"
    )
    parser.add_argument("--output", required=True, help="filename of final output")
    parser.add_argument(
        "--window_width",
        required=False,
        type=float,
        help="Enable windowing of update grid and specify width of window in degrees",
    )
    parser.add_argument(
        "--region_of_interest",
        metavar=("xmin", "xmax", "ymin", "ymax"),
        required=False,
        nargs=4,
        type=float,
        help="output region. Defaults to the extent of the base grid.",
    )
    args = parser.parse_args()

    filenames = []
    if args.input_txt:
        # Read filenames from file
        filenames += read_fnames(args.input_txt)
    # Add filenames from command line
    filenames += args.filenames

    if not filenames:
        # parser.error prints usage and exits; unlike assert it survives -O.
        parser.error("No filenames given")

    base_fname = args.base
    diff_threshold = args.diff_threshold
    output_fname = args.output
    region_of_interest = args.region_of_interest
    window_width = args.window_width

    # Create base grid
    base_grid = load_base_grid(
        base_fname, region=region_of_interest, spacing=args.spacing
    )

    # Update base grid with each source in turn
    for fname in filenames:
        print("Loading update grid")
        update_grid = Grid(fname, convert_to_xyz=True)

        diff_grid = calc_diff_grid(
            base_grid,
            update_grid,
            diff_threshold=diff_threshold,
            window_width=window_width,
        )

        if diff_grid is not None:
            print("Update base grid")
            # Assumes diff_grid matches base_grid's shape: calc_diff_grid grids
            # onto base_grid's region and spacing.
            base_grid.grid.values += diff_grid.values

    base_grid.save_grid(output_fname)

    if args.plot:
        _, axes = plt.subplots(2, 2)
        initial_base_grid = load_base_grid(base_fname, region=region_of_interest)
        initial_base_grid.plot(ax=axes[0, 0])
        axes[0, 0].set_title("Initial Grid")
        base_grid.plot(ax=axes[0, 1])
        axes[0, 1].set_title("Final Grid")
        base_grid.grid.differentiate("x").plot(ax=axes[1, 0])
        axes[1, 0].set_title("x Derivative of Final Grid")
        base_grid.grid.differentiate("y").plot(ax=axes[1, 1])
        axes[1, 1].set_title("y Derivative of Final Grid")
        plt.show()
def nearneighbour(data_xyz, **kwargs)

Uses pyGMT's clib to call GMT's nearneighbor module

Adapted from pyGMT's blockmedian implementation

Expand source code
@use_alias(
    I="spacing",
    R="region",
    V="verbose",
)
@kwargs_to_strings(R="sequence")
def nearneighbour(data_xyz, **kwargs):
    """Grid scattered x, y, z data by calling GMT's ``nearneighbor`` module.

    The call goes through pyGMT's low-level ``clib`` session, following the
    pattern of pyGMT's [blockmedian implementation](https://github.com/GenericMappingTools/pygmt/blob/c0ff7f1add9884305688c2fa15c5f13516b8b960/pygmt/src/blockm.py).

    Args:
        data_xyz: Tabular x, y, z data handed to GMT as a virtual file.
        **kwargs: GMT options. ``spacing`` (-I), ``region`` (-R) and
            ``verbose`` (-V) are accepted as aliases; the remaining keyword
            arguments are turned into GMT flags by pyGMT's
            ``build_arg_string``. ``G`` names an output grid file.

    Returns:
        The gridded result loaded as a DataArray when ``G`` was not given,
        otherwise ``None`` (GMT wrote the grid to the caller's file).
    """
    with GMTTempFile(suffix=".csv") as tmpfile:
        with Session() as lib:
            with lib.virtualfile_from_data(check_kind="vector", data=data_xyz) as infile:
                # Without a caller-supplied output grid, write to the temp file.
                outgrid = kwargs.setdefault("G", tmpfile.name)
                lib.call_module("nearneighbor", f"{infile} {build_arg_string(kwargs)}")

        if outgrid != tmpfile.name:
            # GMT wrote to the caller's file, so there is nothing to return.
            return None

        with xr.open_dataarray(outgrid) as dataarray:
            gridded = dataarray.load()
            _ = gridded.gmt  # populate the GMTDataArray accessor while the file is open
        return gridded