Skip to content
Extraits de code Groupes Projets
plot_subplots_abin_hydro.py 9,9 ko
Newer Older
  • Learn to ignore specific revisions
  • # %%
    import os
    import numpy as np
    import datetime as dt
    import geopandas as gpd
    import pyproj
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    import matplotlib.tri as tri
    from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes, mark_inset
    from mpl_toolkits.axes_grid1 import AxesGrid
    from pyproj import transformer
    import abin
    import param
    import param_private as pp
    
    
    # %%
    def dt2ts(d):
        """Return the POSIX timestamp of datetime *d* in whole seconds.

        A naive datetime is interpreted as UTC. An aware datetime keeps its own
        UTC offset; it is converted rather than having its tzinfo overwritten,
        which would shift the result by that offset.

        Parameters
        ----------
        d : datetime.datetime
            Naive (assumed UTC) or timezone-aware datetime.

        Returns
        -------
        int
            Seconds since 1970-01-01T00:00:00Z, truncated toward zero.
        """
        if d.utcoffset() is None:  # naive: treat as UTC
            d = d.replace(tzinfo=dt.timezone.utc)
        return int(d.timestamp())
    
    
    # %% ----- Parameters ----------------------------------------------------------
    print("Loading parameters")
    # Alternative mesh setup: "gbr_styx_coarse_with_rivers"
    mesh_setup = "gbr_styx_with_rivers"
    p = param.parameters(mesh_setup=mesh_setup)
    base_path = p.local_base_dir
    simu_name = ""

    # Plot switches
    quiver_flag = False  # overlay velocity arrows
    stream_flag = True  # overlay streamlines
    eta_field_flag = False  # background field: eta if True, else |uv|
    inset_flag = True  # add a zoomed inset of `region_inset` to every panel
    mean_flag = False

    # Lon/lat boxes (degrees), colour-scale maximum ("vmax") and, for regions
    # usable as an inset, the zoom factor handed to zoomed_inset_axes
    region = dict(
        full=dict(
            minlon=148.5,
            maxlon=153.1,
            minlat=-23,
            maxlat=-19.25,
            vmax=0.7,
        ),
        bs=dict(
            minlon=149.45,
            maxlon=150.55,
            minlat=-22.7,
            maxlat=-21.3,  # alternative: -21.75
            vmax=2,
        ),
        styx=dict(
            minlon=149.6235,
            maxlon=149.8075,
            minlat=-22.54,
            maxlat=-22.3475,
            vmax=None,
            zoom_inset=3,
        ),
        styx_mouth=dict(
            minlon=149.6235,
            maxlon=149.7,
            minlat=-22.54,
            maxlat=-22.47,
            vmax=None,
            zoom_inset=7,
        ),
    )

    cmap = "YlGnBu"
    region_base, region_inset = "bs", "styx"
    loc_inset = "upper right"
    inset_ec = pp.colors["rose"]
    inset_lw = 2

    # Colour limits shared by every panel and inset
    vmin, vmax = 0, region[region_base]["vmax"]

    figsize = (20, 6)

    fmt = "%Y-%m-%d %H:%M:%S"
    initial_time_as_datetime = dt.datetime.strptime(p.initial_time, fmt)
    # One subplot per snapshot
    dates = [
        dt.datetime(2021, 1, 10, 22),
        dt.datetime(2021, 1, 11, 1),
        dt.datetime(2021, 1, 11, 4),
    ]

    grid_res = 500  # points per axis of the image / streamline grid
    vec_res = 50  # points per axis of the quiver grid

    coast_shp_file = p.land_shapefile
    df_coast = gpd.read_file(coast_shp_file)
    
    # Output-file suffix encoding the plotting options, e.g. "_stream_inset_styx".
    fig_name = ""
    if eta_field_flag:
        # Without this marker, eta and |uv| figures would share a file name
        # and overwrite each other.
        fig_name += "_eta"
    if quiver_flag:
        fig_name += "_quiver"
    if stream_flag:
        fig_name += "_stream"
    if inset_flag:
        fig_name += f"_inset_{region_inset}"
    
    # %% ----- Domain definition ---------------------------------------------------
    print("Initializing")
    projutm = pyproj.CRS.from_string(p.mesh_proj)
    projll = pyproj.CRS.from_string("+ellps=WGS84 +proj=latlong")


    def _regular_lonlat_grid(bounds, n):
        """Return the (lon, lat) meshgrid of n x n points spanning the box *bounds*."""
        lon_axis = np.linspace(bounds["minlon"], bounds["maxlon"], n)
        lat_axis = np.linspace(bounds["minlat"], bounds["maxlat"], n)
        return np.meshgrid(lon_axis, lat_axis)


    base_box = region[region_base]
    minlon_base, maxlon_base = base_box["minlon"], base_box["maxlon"]
    minlat_base, maxlat_base = base_box["minlat"], base_box["maxlat"]

    long, latg = _regular_lonlat_grid(base_box, grid_res)  # image / streamlines
    lonq, latq = _regular_lonlat_grid(base_box, vec_res)  # quiver arrows

    if inset_flag:
        inset_box = region[region_inset]
        minlon_inset, maxlon_inset = inset_box["minlon"], inset_box["maxlon"]
        minlat_inset, maxlat_inset = inset_box["minlat"], inset_box["maxlat"]

        long_inset, latg_inset = _regular_lonlat_grid(inset_box, grid_res)
        lonq_inset, latq_inset = _regular_lonlat_grid(inset_box, vec_res)
    
    # %% ----- Load SLIM mesh and fields ------------------------------------------
    simdir = f"{p.output_directory}/slim"
    topo = abin.openread(f"{simdir}/mesh/topology")
    geo = abin.openread(f"{simdir}/mesh/geometry")[0]
    # eta and uv are indexed as [timestep, node] in the plotting loop
    eta = abin.openread(f"{simdir}/data/eta")[:, :, 0]
    uv = abin.openread(f"{simdir}/data/uv")
    u = uv[:, :, 0]
    v = uv[:, :, 1]
    t = abin.openread(f"{simdir}/time")
    # Topology is stored as records of 4 values; the last 3 are kept and used
    # as node indices for each triangle.
    topo = topo[:].reshape(-1, 4)[:, 1:]

    # always_xy=True fixes the axis order to (x, y) -> (lon, lat) independently
    # of how each CRS declares its axes, so the unpacking below cannot silently
    # swap lon and lat. NOTE(review): assumes geo columns are (x, y) — confirm.
    transformer = pyproj.Transformer.from_crs(projutm, projll, always_xy=True)
    lon, lat = transformer.transform(geo[:, 0], geo[:, 1])
    
    # %%
    print("Selecting useful elements for interpolation")
    # Nodes lying inside the base region padded by `pad` degrees on each side
    pad = 0.2
    node_in_box = (
        (lon > minlon_base - pad)
        & (lon < maxlon_base + pad)
        & (lat > minlat_base - pad)
        & (lat < maxlat_base + pad)
    )

    # A triangle is kept as soon as one of its vertices is in the padded box
    tri_to_keep = node_in_box[topo].any(axis=1)
    topo_to_keep = topo[tri_to_keep]

    # Nodes referenced at least once by a kept triangle
    n_nodes = topo.max() + 1
    count_nodes = np.bincount(topo_to_keep.ravel(), minlength=n_nodes)
    included_nodes = np.flatnonzero(count_nodes)

    lon = lon[included_nodes]
    lat = lat[included_nodes]

    # Map old node numbers to compact 0..len(included_nodes)-1 numbers;
    # dropped nodes stay masked
    corresp = np.ma.zeros(n_nodes, dtype=int)
    corresp[included_nodes] = np.arange(included_nodes.size)
    corresp = np.ma.masked_where(count_nodes == 0, corresp)
    topo_renumbered = corresp[topo_to_keep]
    
    # %%
    print("Initializing plot...")
    out_dir = f"{p.output_directory}/fig/"
    os.makedirs(out_dir, exist_ok=True)
    # Encode every plotted date in the file name, e.g. "2021-01-10_2200_".
    # Formatting the list itself would embed brackets, quotes, spaces and
    # colons (the latter are not allowed in Windows file names).
    time_for_name = "".join(f"{d:%Y-%m-%d_%H%M}_" for d in dates)
    out_fig_file = f"{out_dir}hydro_{time_for_name}{region_base}{fig_name}"

    fig = plt.figure(figsize=figsize)

    # One panel per date in a single row, sharing axes and a single colorbar
    ax = AxesGrid(
        fig,
        111,
        nrows_ncols=(1, len(dates)),
        axes_pad=0.0,
        share_all=True,
        label_mode="L",
        cbar_mode="single",
    )
    
    # Loop-invariant setup, done once rather than once per date:
    # - one Triangulation lets every LinearTriInterpolator reuse the TriFinder
    #   that matplotlib caches on it;
    # - the coastline geometry type decides how the shapefile is drawn.
    triangles = tri.Triangulation(lon, lat, topo_renumbered)
    coast_geom_type = df_coast.geometry.geom_type.iloc[0]
    coast_is_polygon = coast_geom_type in ("Polygon", "MultiPolygon")
    n_steps = len(eta)

    for i, date_unique in enumerate(dates):
        print(f"Select timestep data for {date_unique}")
        # NOTE(review): assumes one stored output per hour starting exactly at
        # p.initial_time; the `t` array read from disk is not consulted — confirm.
        ts_unique = (date_unique - initial_time_as_datetime) // dt.timedelta(hours=1)
        if not 0 <= ts_unique < n_steps:
            # A negative index would otherwise silently wrap to the last outputs.
            raise ValueError(
                f"{date_unique} maps to output index {ts_unique}, "
                f"outside the {n_steps} stored timesteps"
            )
        posix_unique = dt2ts(date_unique)  # not used by the plotting below

        eta_ts = eta[ts_unique, :]
        u_ts = u[ts_unique, :]
        v_ts = v[ts_unique, :]

        # Restrict nodal values to the nodes kept in the triangulation
        eta_sel = eta_ts[included_nodes]
        u_sel = u_ts[included_nodes]
        v_sel = v_ts[included_nodes]

        print("Interpolating...")
        interp_eta = tri.LinearTriInterpolator(triangles, eta_sel.ravel())
        interp_u = tri.LinearTriInterpolator(triangles, u_sel.ravel())
        interp_v = tri.LinearTriInterpolator(triangles, v_sel.ravel())

        # Quiver grid
        etaq = interp_eta(lonq, latq)
        uq = interp_u(lonq, latq)
        vq = interp_v(lonq, latq)
        uvq = np.sqrt(uq * uq + vq * vq)

        # Image / streamline grid
        etag = interp_eta(long, latg)
        ug = interp_u(long, latg)
        vg = interp_v(long, latg)
        uvg = np.sqrt(ug * ug + vg * vg)

        myfield = etag if eta_field_flag else uvg

        if inset_flag:
            etaq_inset = interp_eta(lonq_inset, latq_inset)
            uq_inset = interp_u(lonq_inset, latq_inset)
            vq_inset = interp_v(lonq_inset, latq_inset)
            uvq_inset = np.sqrt(uq_inset * uq_inset + vq_inset * vq_inset)

            etag_inset = interp_eta(long_inset, latg_inset)
            ug_inset = interp_u(long_inset, latg_inset)
            vg_inset = interp_v(long_inset, latg_inset)
            uvg_inset = np.sqrt(ug_inset * ug_inset + vg_inset * vg_inset)

            myfield_inset = etag_inset if eta_field_flag else uvg_inset

        print("Plotting...")

        ax[i].set_aspect("equal")
        im = ax[i].imshow(
            myfield,
            origin="lower",
            extent=[minlon_base, maxlon_base, minlat_base, maxlat_base],
            zorder=1,
        )
        im.set_cmap(cmap)
        im.set_clim(vmin=vmin, vmax=vmax)

        if quiver_flag:
            ax[i].quiver(lonq, latq, uq, vq, uvq, cmap="gray", zorder=3)
        if stream_flag:
            ax[i].streamplot(
                long,
                latg,
                np.reshape(ug, (grid_res, grid_res)),
                np.reshape(vg, (grid_res, grid_res)),
                linewidth=1,
                color="k",
                density=2,
                zorder=4,
            )
        if coast_is_polygon:
            df_coast.plot(
                ax=ax[i], linewidth=0.5, edgecolors="k", color=pp.colors["gris"], zorder=2
            )
        else:
            # Draw on this panel (previously passed the whole AxesGrid).
            df_coast.plot(ax=ax[i], linewidth=0.5, color="k", zorder=4)
        ax[i].set_xlim(minlon_base, maxlon_base)
        ax[i].set_ylim(minlat_base, maxlat_base)

        if inset_flag:
            axins = zoomed_inset_axes(
                ax[i], region[region_inset]["zoom_inset"], loc=loc_inset
            )

            axins.set_aspect("equal")
            im = axins.imshow(
                myfield_inset,
                origin="lower",
                extent=[minlon_inset, maxlon_inset, minlat_inset, maxlat_inset],
                zorder=1,
            )
            im.set_cmap(cmap)
            im.set_clim(vmin=vmin, vmax=vmax)
            if quiver_flag:
                axins.quiver(
                    lonq_inset,
                    latq_inset,
                    uq_inset,
                    vq_inset,
                    uvq_inset,
                    cmap="gray",
                    zorder=3,
                )
            if stream_flag:
                axins.streamplot(
                    long_inset,
                    latg_inset,
                    np.reshape(ug_inset, (grid_res, grid_res)),
                    np.reshape(vg_inset, (grid_res, grid_res)),
                    color="k",
                    linewidth=1,
                    arrowsize=0.8,
                    density=1.5,
                    zorder=4,
                )
            if coast_is_polygon:
                df_coast.plot(
                    ax=axins,
                    linewidth=0.5,
                    edgecolors="k",
                    color=pp.colors["gris"],
                    zorder=2,
                )
            else:
                df_coast.plot(ax=axins, linewidth=0.5, color="k", zorder=4)

            axins.set_xlim(minlon_inset, maxlon_inset)
            axins.set_ylim(minlat_inset, maxlat_inset)

            # Hide tick labels; style the frame like the inset marker
            axins.xaxis.set_ticklabels([])
            axins.yaxis.set_ticklabels([])
            for spine in axins.spines.values():
                spine.set_color(inset_ec)
                spine.set_linewidth(inset_lw)
                spine.set_zorder(5)
            # Box on the panel plus connector lines to the inset
            mark_inset(
                ax[i], axins, loc1=2, loc2=4, fc="none", ec=inset_ec, lw=inset_lw, zorder=5
            )
    
    
    # Colorbar: use the first panel's base image as the mappable. When
    # inset_flag is set, the loop leaves `im` pointing at the last *inset*
    # image, whose colour limits are autoscaled from different data if vmax
    # is None.
    ax.cbar_axes[0].colorbar(ax[0].get_images()[0])

    plt.savefig(f"{out_fig_file}.png", bbox_inches="tight")
    # NOTE: dpi=1000 yields 1000 px per inch of figure size; files get large.
    plt.savefig(f"{out_fig_file}.jpeg", bbox_inches="tight", dpi=1000)
    plt.savefig(f"{out_fig_file}.pdf", bbox_inches="tight")

    plt.show()
    plt.close("all")
    
    # %%