# %%
import os
import numpy as np
import datetime as dt
import geopandas as gpd
import pyproj
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.tri as tri
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes, mark_inset
from mpl_toolkits.axes_grid1 import AxesGrid
from pyproj import transformer
import abin
import param
import param_private as pp


# %%
def dt2ts(d):
    """Convert a datetime to an integer POSIX timestamp, reading it as UTC.

    The ``tzinfo`` of *d* is replaced by UTC without shifting the wall-clock
    fields, and the fractional part of the timestamp is dropped by ``int``.
    """
    stamped_utc = d.replace(tzinfo=dt.timezone.utc)
    epoch_seconds = stamped_utc.timestamp()
    return int(epoch_seconds)


# %% ----- Parameters ----------------------------------------------------------
print("Loading parameters")
# Mesh configuration to load; the coarse alternative is
# "gbr_styx_coarse_with_rivers".
mesh_setup = "gbr_styx_with_rivers"
p = param.parameters(mesh_setup=mesh_setup)
base_path = p.local_base_dir
simu_name = ""

# Plot switches.
quiver_flag = False  # overlay velocity arrows
stream_flag = True  # overlay streamlines
eta_field_flag = False  # field = eta if True, else uv
inset_flag = True  # add a zoomed inset of `region_inset` to every panel
mean_flag = False

# Named regions, one row each:
# (minlon, maxlon, minlat, maxlat, vmax[, zoom_inset]).
# Rows without a sixth value get no "zoom_inset" key.
_region_fields = ("minlon", "maxlon", "minlat", "maxlat", "vmax", "zoom_inset")
_region_rows = {
    "full": (148.5, 153.1, -23, -19.25, 0.7),
    # Alternative "maxlat" for "bs": -21.75.
    "bs": (149.45, 150.55, -22.7, -21.3, 2),
    "styx": (149.6235, 149.8075, -22.54, -22.3475, None, 3),
    "styx_mouth": (149.6235, 149.7, -22.54, -22.47, None, 7),
}
region = {name: dict(zip(_region_fields, row)) for name, row in _region_rows.items()}

# Colour map, regions shown in the panels / inset, and inset frame styling.
cmap = "YlGnBu"
region_base = "bs"
region_inset = "styx"
loc_inset = "upper right"
inset_ec = pp.colors["rose"]
inset_lw = 2

# Colour limits, applied to both the panels and the inset.
vmin = 0
vmax = region[region_base]["vmax"]

figsize = (20, 6)

fmt = "%Y-%m-%d %H:%M:%S"
initial_time_as_datetime = dt.datetime.strptime(p.initial_time, fmt)
# Snapshots to plot, one panel per date.
dates = [
    dt.datetime(2021, 1, 10, 22),
    dt.datetime(2021, 1, 11, 1),
    dt.datetime(2021, 1, 11, 4),
]

grid_res = 500  # points per axis of the grid used for the field and streamlines
vec_res = 50  # points per axis of the grid used for quiver arrows

coast_shp_file = p.land_shapefile
df_coast = gpd.read_file(coast_shp_file)

# Suffix of the output file name, recording which overlays are enabled.
fig_name = ""
if quiver_flag:
    fig_name += "_quiver"
if stream_flag:
    fig_name += "_stream"
if inset_flag:
    fig_name += f"_inset_{region_inset}"

# %% ----- Domain definition ---------------------------------------------------
print("Initializing")
projutm = pyproj.CRS.from_string(p.mesh_proj)
projll = pyproj.CRS.from_string("+ellps=WGS84 +proj=latlong")


def _regular_lonlat_grid(lon_lo, lon_hi, lat_lo, lat_hi, n_points):
    """Return the (lon, lat) meshgrid of an n_points x n_points regular grid."""
    lon_axis = np.linspace(lon_lo, lon_hi, n_points)
    lat_axis = np.linspace(lat_lo, lat_hi, n_points)
    return np.meshgrid(lon_axis, lat_axis)


_base_box = region[region_base]
minlon_base, maxlon_base = _base_box["minlon"], _base_box["maxlon"]
minlat_base, maxlat_base = _base_box["minlat"], _base_box["maxlat"]
_base_bounds = (minlon_base, maxlon_base, minlat_base, maxlat_base)

# Fine grid (grid_res) and coarse grid (vec_res) over the base region.
long, latg = _regular_lonlat_grid(*_base_bounds, grid_res)
lonq, latq = _regular_lonlat_grid(*_base_bounds, vec_res)

if inset_flag:
    _inset_box = region[region_inset]
    minlon_inset, maxlon_inset = _inset_box["minlon"], _inset_box["maxlon"]
    minlat_inset, maxlat_inset = _inset_box["minlat"], _inset_box["maxlat"]
    _inset_bounds = (minlon_inset, maxlon_inset, minlat_inset, maxlat_inset)

    # Same two resolutions, restricted to the inset region.
    long_inset, latg_inset = _regular_lonlat_grid(*_inset_bounds, grid_res)
    lonq_inset, latq_inset = _regular_lonlat_grid(*_inset_bounds, vec_res)

# %%
# Read the mesh and the simulation fields stored under the "slim" output folder.
simdir = p.output_directory + "/slim/"

# Topology is viewed as rows of four values; the first column is dropped and
# the remaining three are used as the node indices of each triangle.
topo = abin.openread(f"{simdir}/mesh/topology")[:].reshape(-1, 4)[:, 1:]
geo = abin.openread(f"{simdir}/mesh/geometry")[0]

# eta keeps component 0 of its last axis; uv is split into its two components.
eta = abin.openread(f"{simdir}/data/eta")[:, :, 0]
uv = abin.openread(f"{simdir}/data/uv")
u, v = uv[:, :, 0], uv[:, :, 1]
t = abin.openread(f"{simdir}/time")

# Node coordinates converted from the mesh projection to longitude/latitude.
# NOTE(review): `transformer` rebinds the name imported from pyproj at the top
# of the file, and the output order relies on the CRS axis order (no always_xy).
transformer = pyproj.Transformer.from_crs(projutm, projll)
lon, lat = transformer.transform(geo[:, 0], geo[:, 1])

# %%
print("Selecting useful elements for interpolation")
# Nodes inside the base region enlarged by a 0.2 margin on every side.
margin = 0.2
lon_bool = (lon > minlon_base - margin) & (lon < maxlon_base + margin)
lat_bool = (lat > minlat_base - margin) & (lat < maxlat_base + margin)
coord_bool = lon_bool & lat_bool

# A triangle is kept as soon as one of its nodes is selected.
tri_to_keep = coord_bool[topo].any(axis=1)
topo_to_keep = topo[tri_to_keep]

# Nodes that appear at least once in the kept triangles (sorted indices).
n_nodes = topo.max() + 1
count_nodes = np.bincount(topo_to_keep.ravel(), minlength=n_nodes)
included_nodes = np.flatnonzero(count_nodes)

lon, lat = lon[included_nodes], lat[included_nodes]

# Old -> new node numbering; nodes that were dropped are masked.
corresp = np.ma.zeros(n_nodes, dtype=int)
corresp[included_nodes] = np.arange(included_nodes.size)
corresp = np.ma.masked_where(count_nodes == 0, corresp)
topo_renumbered = corresp[topo_to_keep]

# %%
print("Initializing plot...")
out_dir = f"{p.output_directory}/fig/"
os.makedirs(out_dir, exist_ok=True)
# Join the snapshot times into one compact tag. Formatting the list itself
# would put its repr (brackets, quotes, spaces, colons) into the file name.
time_for_name = "".join(f"{d:%Y%m%dT%H%M}_" for d in dates)
out_fig_file = f"{out_dir}hydro_{time_for_name}{region_base}{fig_name}"

# The mesh is the same for every snapshot, so it is triangulated only once.
triangles = tri.Triangulation(lon, lat, topo_renumbered)

# Polygon coastlines are drawn filled, anything else as black lines. The test
# looks at the last column of the first row and is identical for all panels.
coast_is_polygon = "polygon" in str(type(df_coast.iloc[0, -1]))

base_extent = [minlon_base, maxlon_base, minlat_base, maxlat_base]
if inset_flag:
    inset_extent = [minlon_inset, maxlon_inset, minlat_inset, maxlat_inset]


def _interpolate(interp_eta, interp_u, interp_v, x, y):
    """Return eta, u, v and the speed sqrt(u**2 + v**2) evaluated at (x, y)."""
    eta_i = interp_eta(x, y)
    u_i = interp_u(x, y)
    v_i = interp_v(x, y)
    return eta_i, u_i, v_i, np.sqrt(u_i * u_i + v_i * v_i)


def _draw_field(target, field, extent):
    """Show `field` as an image on `target` with the shared colour settings."""
    image = target.imshow(field, origin="lower", extent=extent, zorder=1)
    image.set_cmap(cmap)
    image.set_clim(vmin=vmin, vmax=vmax)
    return image


def _draw_coast(target):
    """Draw the coastline layer on the single axes `target`."""
    if coast_is_polygon:
        df_coast.plot(
            ax=target,
            linewidth=0.5,
            edgecolors="k",
            color=pp.colors["gris"],
            zorder=2,
        )
    else:
        df_coast.plot(ax=target, linewidth=0.5, color="k", zorder=4)


fig = plt.figure(figsize=figsize)

# One panel per date, sharing axes and a single colour bar.
ax = AxesGrid(
    fig,
    111,
    nrows_ncols=(1, len(dates)),
    axes_pad=0.0,
    share_all=True,
    label_mode="L",
    cbar_mode="single",
)

for i, date_unique in enumerate(dates):
    print(f"Select timestep data for {date_unique}")
    # Record index = whole hours elapsed since the initial time.
    # NOTE(review): assumes one output record per hour starting at
    # p.initial_time - confirm against the `time` array.
    ts_unique = int((date_unique - initial_time_as_datetime).total_seconds() / 3600)
    if ts_unique < 0:
        # A negative index would silently wrap around to the end of the record.
        raise ValueError(
            f"{date_unique} is earlier than the initial time "
            f"{initial_time_as_datetime}"
        )
    # POSIX time of the snapshot (not read again in this cell).
    posix_unique = dt2ts(date_unique)

    # Values of this snapshot, restricted to the nodes kept in the sub-mesh.
    eta_sel = eta[ts_unique, :][included_nodes]
    u_sel = u[ts_unique, :][included_nodes]
    v_sel = v[ts_unique, :][included_nodes]

    print("Interpolating...")
    interp_eta = tri.LinearTriInterpolator(triangles, eta_sel.ravel())
    interp_u = tri.LinearTriInterpolator(triangles, u_sel.ravel())
    interp_v = tri.LinearTriInterpolator(triangles, v_sel.ravel())

    # Coarse grid for the arrows, fine grid for the image and streamlines.
    etaq, uq, vq, uvq = _interpolate(interp_eta, interp_u, interp_v, lonq, latq)
    etag, ug, vg, uvg = _interpolate(interp_eta, interp_u, interp_v, long, latg)
    myfield = etag if eta_field_flag else uvg

    if inset_flag:
        etaq_inset, uq_inset, vq_inset, uvq_inset = _interpolate(
            interp_eta, interp_u, interp_v, lonq_inset, latq_inset
        )
        etag_inset, ug_inset, vg_inset, uvg_inset = _interpolate(
            interp_eta, interp_u, interp_v, long_inset, latg_inset
        )
        myfield_inset = etag_inset if eta_field_flag else uvg_inset

    print("Plotting...")

    panel = ax[i]
    panel.set_aspect("equal")
    # The image of the main panel is the one handed to the colour bar below;
    # the inset image uses the same colour map and limits.
    im = _draw_field(panel, myfield, base_extent)

    if quiver_flag:
        panel.quiver(lonq, latq, uq, vq, uvq, cmap="gray", zorder=3)
    if stream_flag:
        # ug / vg already have the (grid_res, grid_res) shape of the grid.
        panel.streamplot(
            long,
            latg,
            ug,
            vg,
            linewidth=1,
            color="k",
            density=2,
            zorder=4,
        )
    # Draw on the current panel only (not on the whole AxesGrid).
    _draw_coast(panel)
    panel.set_xlim(minlon_base, maxlon_base)
    panel.set_ylim(minlat_base, maxlat_base)

    if inset_flag:
        axins = zoomed_inset_axes(
            panel, region[region_inset]["zoom_inset"], loc=loc_inset
        )

        axins.set_aspect("equal")
        _draw_field(axins, myfield_inset, inset_extent)
        if quiver_flag:
            axins.quiver(
                lonq_inset,
                latq_inset,
                uq_inset,
                vq_inset,
                uvq_inset,
                cmap="gray",
                zorder=3,
            )
        if stream_flag:
            axins.streamplot(
                long_inset,
                latg_inset,
                ug_inset,
                vg_inset,
                color="k",
                linewidth=1,
                arrowsize=0.8,
                density=1.5,
                zorder=4,
            )
        _draw_coast(axins)

        axins.set_xlim(minlon_inset, maxlon_inset)
        axins.set_ylim(minlat_inset, maxlat_inset)

        # Remove tick labels
        axins.axes.xaxis.set_ticklabels([])
        axins.axes.yaxis.set_ticklabels([])

        # Frame the inset with the same colour as the connector lines.
        for spine in axins.spines.values():
            spine.set_color(inset_ec)
            spine.set_linewidth(inset_lw)
            spine.set_zorder(5)
        mark_inset(
            panel, axins, loc1=2, loc2=4, fc="none", ec=inset_ec, lw=inset_lw, zorder=5
        )


# Colorbar
ax.cbar_axes[0].colorbar(im)

plt.savefig(f"{out_fig_file}.png", bbox_inches="tight")
plt.savefig(f"{out_fig_file}.jpeg", bbox_inches="tight", dpi=1000)
plt.savefig(f"{out_fig_file}.pdf", bbox_inches="tight")

plt.show()
plt.close("all")

# %%