Skip to content
Extraits de code Groupes Projets
plot_sediments_footprints.py 10,6 ko
Newer Older
  • Learn to ignore specific revisions
  • # %%
    import os
    
    import geopandas as gpd
    import matplotlib
    import matplotlib.pyplot as plt
    import matplotlib.ticker as plticker
    from mpl_toolkits.axes_grid1 import AxesGrid
    import numpy as np
    from osgeo import gdal
    
    import param
    import param_private as pp
    
    # Default font for every figure produced by this script: 8 pt Arial.
    font = dict(family="Arial", size=8)
    matplotlib.rc("font", **font)
    
    # %%
    p = param.parameters()
    # Coastline polygons, drawn as the land layer on every map below.
    coast_shp_file = p.land_shapefile
    gdf_coast = gpd.read_file(coast_shp_file)
    # Copy of the coastline in WGS 84 / UTM zone 56S (EPSG:32756).
    # NOTE(review): this projected copy is not referenced anywhere else in this script.
    gdf_coast_utm = gdf_coast.to_crs(epsg=32756)
    
    
    # %%
    # Simulation identifiers, combined below into the run's output directory.
    mesh_name = "gbr_styx_with_rivers"
    simu_name = "one_month_release_constant_seed_by_site"
    nu = "10-6"
    number_particles = 400_000

    # Particle categories and location labels that appear in raster / figure names.
    im_type = ["sedimented", "floating"]
    loc = ["inside", "outside"]

    # Size-class edges in µm; each consecutive pair of values is one size class.
    sed_sizes = [1, 2, 4, 8, 16, 32, 64, 125, 250, 500, 1000]

    # Time origin in seconds (presumably Unix time, 2021-01-05 00:00 UTC); the weekly
    # snapshots are read at t_start + week * 7 days for week = 1..max_week.
    t_start = 1_609_804_800
    max_week = 12

    # Cumulative-value thresholds for which arrival-time maps are produced.
    thresholds = [100, 1000, 4000, 5000, 10000]

    base_path = (
        "/Users/asaintamand/Documents/thesis/scripts/styx_river_coal_mine/data/output/"
        f"{mesh_name}/LPT_sediments/3D/"
        f"{simu_name}_nu{nu}_{number_particles}/"
    )
    
    
    # %%
    class Tiff_array:
        """One cumulative GeoTIFF, warped to EPSG:4326 and loaded as a NumPy array.

        The file read is
        ``{base_path}{min_size}_{max_size}/raster/{i}_{l}_cumul_{min_size}_{max_size}_{t}.tiff``.

        Attributes set by ``__init__``:
            i, l, t, min_size, max_size -- the constructor arguments
            path -- full path of the GeoTIFF (see ``get_tiff_path``)
            reprojected -- in-memory GDAL dataset warped to EPSG:4326
            arr -- band 1 of ``reprojected`` as a NumPy array
        """

        def __init__(self, i, l, t, min_size, max_size):
            # Callers in this script pass: i = particle category (e.g. "floating"),
            # l = "inside"/"outside", t = snapshot time as a string (file-name suffix),
            # min_size / max_size = size-class bounds used in the directory and name.
            self.i = i
            self.l = l
            self.t = t
            self.min_size = min_size
            self.max_size = max_size
            self.get_tiff_path()
            self.get_tiff_as_array()

        def get_tiff_path(self):
            """Build ``self.path`` from the constructor arguments and ``base_path``."""
            name = f"{self.i}_{self.l}_cumul_{self.min_size}_{self.max_size}_{self.t}"
            self.path = f"{base_path}{self.min_size}_{self.max_size}/raster/{name}.tiff"

        def get_tiff_as_array(self):
            """Open ``self.path``, warp it to EPSG:4326 in memory and read band 1.

            Sets ``self.reprojected`` (GDAL dataset) and ``self.arr`` (NumPy array).

            Raises
            ------
            FileNotFoundError
                If GDAL cannot open ``self.path``.
            RuntimeError
                If the reprojection fails.
            """
            dataset = gdal.Open(self.path)
            if dataset is None:
                # Without gdal.UseExceptions(), GDAL returns None instead of raising.
                raise FileNotFoundError(f"GDAL could not open raster: {self.path}")

            # Warp into an anonymous MEM dataset. The previous fixed
            # "/vsimem/in_memory_output.tif" name was shared by every instance, so each
            # new warp replaced the file behind the previous instance's dataset.
            self.reprojected = gdal.Warp("", dataset, format="MEM", dstSRS="EPSG:4326")
            dataset = None  # release the source file handle
            if self.reprojected is None:
                raise RuntimeError(f"GDAL could not reproject raster: {self.path}")

            # GetRasterBand() numbers bands from 1, not 0.
            self.arr = self.reprojected.GetRasterBand(1).ReadAsArray()

        def get_tiff_extent(self):
            """Return ``(minx, maxx, miny, maxy)`` of ``self.arr`` in EPSG:4326.

            The rotation terms of the geotransform are ignored, so this assumes a
            north-up raster. For such rasters ``dy`` is negative, which puts ``miny``
            below ``maxy``.
            """
            minx, dx, _, maxy, _, dy = self.reprojected.GetGeoTransform()
            nrows, ncols = self.arr.shape

            maxx = minx + ncols * dx
            miny = maxy + nrows * dy
            return minx, maxx, miny, maxy
    
    
    # %%
    def _or_module_global(value, name):
        """Return *value*, or the module-level variable *name* when *value* is None."""
        return globals()[name] if value is None else value


    def get_plot(
        arr_weeks,
        minx,
        maxx,
        miny,
        maxy,
        loc,
        min_size=None,
        max_size=None,
        im_t=None,
        threshold=None,
    ):
        """Draw one arrival-time map and save it as a PNG.

        Parameters
        ----------
        arr_weeks : numpy masked array
            Arrival week per pixel (1..max_week), masked where no arrival was recorded.
        minx, maxx, miny, maxy : float
            Extent of ``arr_weeks``, used for the image and as the axis limits
            (EPSG:4326 when obtained from ``Tiff_array.get_tiff_extent``).
        loc : str
            Location label ("inside", "outside" or "sum"), used in the output path.
        min_size, max_size, im_t, threshold : optional
            Size-class bounds, particle category and threshold used for the panel
            label and the output file name. When omitted, they fall back to the
            module-level variables of the same name, which was the old implicit
            behaviour.

        The figure is written to
        ``{base_path}fig/{loc}/{im_t}/threshold_{threshold}/`` and then closed.
        """
        # These values used to be read silently from module-level loop variables;
        # callers should pass them explicitly, and the fallback keeps old calls working.
        min_size = _or_module_global(min_size, "min_size")
        max_size = _or_module_global(max_size, "max_size")
        im_t = _or_module_global(im_t, "im_t")
        threshold = _or_module_global(threshold, "threshold")

        extent = (minx, maxx, miny, maxy)
        # One discrete colour per week. The limits are shifted by half a week so each
        # integer week sits in the middle of its colour bin.
        cmap = plt.get_cmap("afmhot_r", max_week)

        fig = plt.figure(dpi=600)
        grid = AxesGrid(
            fig,
            111,
            nrows_ncols=(1, 1),
            axes_pad=0.0,
            share_all=True,
            label_mode="L",
            cbar_mode="single",
            cbar_pad=0.05,
            cbar_size="5%",
        )
        ax = grid[0]

        ax.set_aspect("equal")
        ax.set_facecolor("skyblue")
        im = ax.imshow(
            arr_weeks,
            extent=extent,
            vmin=0.5,
            vmax=max_week + 0.5,
            cmap=cmap,
        )
        # Contour line at the week-2 level of the arrival map.
        ax.contour(
            arr_weeks,
            levels=[2],
            colors="crimson",
            origin="upper",
            extent=extent,
            linewidths=(0.8,),
        )

        gdf_coast.plot(ax=ax, linewidth=0.5, edgecolors="k", color=pp.colors["gris"])

        ax.set_xlim(minx, maxx)
        ax.set_ylim(miny, maxy)
        ax.tick_params(
            top=False,
            bottom=False,
            left=False,
            right=False,
            labelleft=False,
            labelbottom=False,
        )
        ax.xaxis.set_major_locator(plticker.MultipleLocator(base=0.5))
        ax.yaxis.set_major_locator(plticker.MultipleLocator(base=0.5))

        ax.text(
            maxx - 0.075,
            maxy - 0.115,
            f"{min_size} µm - {max_size} µm",
            fontfamily="arial",
            weight="bold",
            fontsize="large",
            ha="right",
            c="white",
            backgroundcolor=(0, 0.25, 0.33),
        )

        cbar = grid.cbar_axes[0].colorbar(im, ticks=np.arange(1, max_week + 1))
        cbar.ax.set_ylabel("Arrival time (weeks)", fontsize="large")

        fig_dir = f"{base_path}fig/{loc}/{im_t}/threshold_{threshold}/"
        os.makedirs(fig_dir, exist_ok=True)
        fig.savefig(
            f"{fig_dir}{loc}_{im_t}_{threshold}_{min_size}_{max_size}.png",
            bbox_inches="tight",
        )
        plt.close(fig)


    def launch_all_plots(threshold, min_size, max_size, im_t):
        """Build arrival-week maps for one size class and save three figures.

        For each week ``max_week, ..., 1``, the "inside" and "outside" rasters of that
        snapshot are read. Each pixel ends up holding the smallest week at which its
        value is ``>= threshold`` (0, then masked, if never). One figure is saved
        for "inside", one for "outside" and one for their sum.

        Raises
        ------
        ValueError
            If ``max_week`` is smaller than 1, because no raster would be read.
        """
        if max_week < 1:
            raise ValueError(f"max_week must be >= 1, got {max_week}")

        week_maps = {}
        # Iterate from the last week to the first. Each hit overwrites the previous
        # value, so the surviving value in a pixel is the earliest qualifying week.
        for week in range(max_week, 0, -1):
            t = t_start + week * 86400 * 7
            tiff_inside = Tiff_array(im_t, "inside", str(t), min_size, max_size)
            tiff_outside = Tiff_array(im_t, "outside", str(t), min_size, max_size)
            layers = {
                "inside": tiff_inside.arr,
                "outside": tiff_outside.arr,
                "sum": tiff_inside.arr + tiff_outside.arr,
            }
            for name, values in layers.items():
                previous = week_maps.get(name)
                if previous is None:
                    previous = np.zeros(np.shape(values))
                week_maps[name] = np.where(values >= threshold, week, previous)

        minx, maxx, miny, maxy = tiff_inside.get_tiff_extent()

        # Looking arrays up by location name gives each figure its own data. The old
        # if/elif compared against the misspelt "outisde", so "outside" plotted the sum.
        for loc, weeks in week_maps.items():
            masked = np.ma.masked_where(weeks == 0, weeks)
            get_plot(
                masked,
                minx,
                maxx,
                miny,
                maxy,
                loc,
                min_size=min_size,
                max_size=max_size,
                im_t=im_t,
                threshold=threshold,
            )
    
    
    # %% Compute one plot for each threshold/loc/min-max sizes
    # For every threshold, every size class (consecutive pair of sed_sizes) and every
    # particle category in im_type, launch_all_plots reads the weekly rasters and
    # saves the arrival-time figures.
    # NOTE: threshold, min_size, max_size and im_t are module-level names here.
    # get_plot can read these four names from module scope, so keep the loop
    # variable names unchanged.
    for threshold in thresholds:
        for min_size, max_size in zip(sed_sizes[:-1], sed_sizes[1:]):
            for im_t in im_type:
                launch_all_plots(threshold, min_size, max_size, im_t)
    
    
    # %% Compute a subplots figure 2X2
    # 2x2 panel of the summed (inside + outside) arrival-time maps for four size
    # classes, sharing a single colour bar.
    im_t = "floating"
    threshold = 4000
    sizes = [(4, 8), (8, 16), (16, 32), (32, 64)]
    sizes_for_name = "_".join([f"{s[0]}-{s[1]}" for s in sizes])
    max_week = 12
    minx, miny, maxx, maxy = 149.3, -22.7, 150.25, -21.51
    fig_dir = f"{base_path}fig/sum/{im_t}/threshold_{threshold}/"
    os.makedirs(fig_dir, exist_ok=True)

    # Discrete weekly colour scale (one colour per week, limits shifted by half a
    # week), consistent with the single-panel maps. The former continuous 0..max_week
    # scale gave this figure different colours for the same arrival week.
    cmap = plt.get_cmap("afmhot_r", max_week)

    fig = plt.figure(dpi=600, figsize=(8, 8))

    ax = AxesGrid(
        fig,
        111,
        nrows_ncols=(2, 2),
        axes_pad=0.0,
        share_all=True,
        label_mode="L",
        cbar_mode="single",
    )

    # Panel labels.
    letters = "ABCDEF"

    for a, (min_size, max_size) in enumerate(sizes):
        # Earliest week (1..max_week) at which the summed raster reaches the
        # threshold; 0 where it never does (masked below).
        for i in range(max_week, 0, -1):
            t = t_start + i * 86400 * 7

            arr_inside = Tiff_array(im_t, "inside", str(t), min_size, max_size)
            arr_outside = Tiff_array(im_t, "outside", str(t), min_size, max_size)
            arr_sum = arr_inside.arr + arr_outside.arr

            if i == max_week:
                arr_weeks_sum = np.zeros(np.shape(arr_sum))

            arr_weeks_sum = np.where(arr_sum >= threshold, i, arr_weeks_sum)

        arr_weeks_sum = np.ma.masked_where(arr_weeks_sum == 0, arr_weeks_sum)
        arr_extent = arr_inside.get_tiff_extent()

        panel = ax[a]
        panel.set_aspect("equal")
        panel.set_facecolor("skyblue")
        im = panel.imshow(
            arr_weeks_sum,
            extent=arr_extent,
            vmin=0.5,
            vmax=max_week + 0.5,
            cmap=cmap,
        )
        panel.contour(
            arr_weeks_sum,
            levels=[2],
            colors="crimson",
            origin="upper",
            extent=arr_extent,
            linewidths=(0.8,),
        )

        gdf_coast.plot(ax=panel, linewidth=0.5, edgecolors="k", color=pp.colors["gris"])

        panel.set_xlim(minx, maxx)
        panel.set_ylim(miny, maxy)
        panel.tick_params(axis="both", which="both", direction="in", right=True, top=True)
        # A Locator must not be shared between Axis objects, so create one per axis.
        panel.xaxis.set_major_locator(plticker.MultipleLocator(base=0.5))
        panel.yaxis.set_major_locator(plticker.MultipleLocator(base=0.5))

        panel.text(
            minx + 0.01,
            maxy - 0.01,
            letters[a],
            c="white",
            ha="left",
            va="top",
            fontfamily="arial",
            bbox=dict(facecolor="k", edgecolor="none", pad=1.2),
            clip_on=True,
        )

        panel.text(
            minx + 0.075,
            miny + 0.075,
            f"{min_size} µm - {max_size} µm",
            fontfamily="arial",
            weight="bold",
        )

    cbar = ax.cbar_axes[0].colorbar(im, ticks=np.arange(1, max_week + 1))
    cbar.ax.set_ylabel("Arrival time (weeks)", fontsize="large")

    fig.savefig(
        f"{fig_dir}quatro_sum_{im_t}_{threshold}_{sizes_for_name}.jpg",
        bbox_inches="tight",
        dpi=1000,
    )
    plt.show()
    
    # %% Compute a subplots figure 3X2
    # 3x2 panel of the summed (inside + outside) arrival-time maps for six size
    # classes, sharing a single labelled colour bar.
    im_t = "floating"
    threshold = 4000
    sizes = [(2, 4), (4, 8), (8, 16), (16, 32), (32, 64), (64, 125)]
    sizes_for_name = "_".join([f"{s[0]}-{s[1]}" for s in sizes])
    max_week = 12
    minx, miny, maxx, maxy = 149.3, -22.7, 150.25, -21.51
    fig_dir = f"{base_path}fig/sum/{im_t}/threshold_{threshold}/"
    os.makedirs(fig_dir, exist_ok=True)

    # Built only after max_week is set, so the number of discrete colours always
    # matches the weeks plotted in this cell.
    cmap = plt.get_cmap("afmhot_r", max_week)
    # Panel labels, defined here so this cell also runs on its own.
    letters = "ABCDEF"

    fig = plt.figure(dpi=600, figsize=(12, 8))

    ax = AxesGrid(
        fig,
        111,
        nrows_ncols=(2, 3),
        axes_pad=0.0,
        share_all=True,
        label_mode="L",
        cbar_mode="single",
        cbar_pad=0.1,
        cbar_size="3%",
    )

    for a, (min_size, max_size) in enumerate(sizes):
        # Earliest week (1..max_week) at which the summed raster reaches the
        # threshold; 0 where it never does (masked below).
        for i in range(max_week, 0, -1):
            t = t_start + i * 86400 * 7

            arr_inside = Tiff_array(im_t, "inside", str(t), min_size, max_size)
            arr_outside = Tiff_array(im_t, "outside", str(t), min_size, max_size)
            arr_sum = arr_inside.arr + arr_outside.arr

            if i == max_week:
                arr_weeks_sum = np.zeros(np.shape(arr_sum))

            arr_weeks_sum = np.where(arr_sum >= threshold, i, arr_weeks_sum)

        arr_weeks_sum = np.ma.masked_where(arr_weeks_sum == 0, arr_weeks_sum)
        arr_extent = arr_inside.get_tiff_extent()

        panel = ax[a]
        panel.set_aspect("equal")
        panel.set_facecolor("skyblue")
        # Limits shifted by half a week so each integer week is centred in its bin.
        im = panel.imshow(
            arr_weeks_sum,
            extent=arr_extent,
            vmin=0.5,
            vmax=max_week + 0.5,
            cmap=cmap,
        )
        panel.contour(
            arr_weeks_sum,
            levels=[2],
            colors="crimson",
            origin="upper",
            extent=arr_extent,
            linewidths=(0.8,),
        )

        gdf_coast.plot(ax=panel, linewidth=0.5, edgecolors="k", color=pp.colors["gris"])

        panel.set_xlim(minx, maxx)
        panel.set_ylim(miny, maxy)
        panel.tick_params(axis="both", which="both", direction="in", right=True, top=True)
        # A Locator must not be shared between Axis objects, so create one per axis.
        panel.xaxis.set_major_locator(plticker.MultipleLocator(base=0.5))
        panel.yaxis.set_major_locator(plticker.MultipleLocator(base=0.5))

        panel.text(
            minx + 0.01,
            maxy - 0.01,
            letters[a],
            c="white",
            ha="left",
            va="top",
            fontfamily="arial",
            bbox=dict(facecolor="k", edgecolor="none", pad=1.2),
            clip_on=True,
        )

        panel.text(
            minx + 0.075,
            miny + 0.075,
            f"{min_size} µm - {max_size} µm",
            fontfamily="arial",
            weight="bold",
        )

    cbar = ax.cbar_axes[0].colorbar(im, ticks=np.arange(1, max_week + 1))
    cbar.ax.set_ylabel("Arrival time (weeks)", fontsize="large")

    fig.savefig(
        f"{fig_dir}six_sum_{im_t}_{threshold}_{sizes_for_name}.jpg",
        bbox_inches="tight",
        dpi=1000,
    )
    plt.show()
    
    # %%