# import gdal
import numpy as np

# import pyproj
import seamsh
import seamsh.geometry
from osgeo import osr
from seamsh.geometry import CurveType

import param

# --------- OPTIONS  -----------------------------------------------------------
# Project-wide settings (projection string, input/output paths) are read from
# the local `param` module.
p = param.parameters()
# Point inside the domain, expressed in lon/lat: it is passed to
# coarsen_boundaries together with lonlat_proj further down.
interior_point = (150, -21.25)
# When True, the bathymetry raster is loaded and also bounds the element size.
use_bath = True

# Mesh-size ramp parameters (see threshold()). Units presumably follow the
# mesh projection p.mesh_proj, e.g. metres — TODO confirm.
l0_coast = 500  # Res min along coastlines
l0_reefs = 1000  # Res min on reefs borders
l0_zoi = 100  # Res min on ROI
l1 = 25000  # Res max
# NOTE(review): d0 is not read by mesh_size, which passes each l0_* value as
# its own plateau length instead — confirm which behaviour is intended.
d0 = l0_coast  # Length of plateau (with res min)
d1 = 100000  # Distance where res = res max
# NOTE(review): bath_max is not used anywhere in the visible script.
bath_max = 200  # max bathymetry

print("Start creating mesh with:")
print(f"-> highest resolution: {min(l0_coast, l0_zoi, l0_reefs)}")
print(f"-> with bathy: {use_bath}")
# ------------------------------------------------------------------------------

# 1. Define projections and domain
print("Defining projections and domain")
# Projected CRS in which the mesh is built (PROJ.4 string from param).
domain_srs = osr.SpatialReference()
domain_srs.ImportFromProj4(p.mesh_proj)
# Geographic lon/lat CRS on the WGS84 ellipsoid; used to interpret
# interior_point when coarsening the boundaries.
lonlat_proj = osr.SpatialReference()
lonlat_proj.ImportFromProj4("+proj=latlong +ellps=WGS84 +unit=degrees")

# Two separate domains: `domain` carries the coastline that is actually
# meshed, while `domain_reefs` only carries reef outlines and is used solely
# to build a distance field (it is never passed to the mesher).
domain = seamsh.geometry.Domain(domain_srs)
domain_reefs = seamsh.geometry.Domain(domain_srs)

# 2. Add coastlines/features to domain
print("Reading coastlines and reefs")
# The "type" attribute presumably provides the curve tags; the tags "coast"
# and "coast_zoi" are selected by the distance fields below.
domain.add_boundary_curves_shp(p.shape_file, "type", CurveType.POLYLINE)
# Reef outlines are tagged from "FEAT_NAME" ("Reef" is selected below).
domain_reefs.add_interior_curves_shp(
    p.reef_file_simplified, "FEAT_NAME", CurveType.POLYLINE
)

# 3. Define distances to features for mesh resolutions
# One distance field per family of tagged curves, later fed to threshold().
# NOTE(review): the second argument (100) is presumably the sampling step
# used to discretize the curves — confirm against seamsh.field.Distance.
dist_coast = seamsh.field.Distance(domain, 100, ["coast"])
dist_coast_zoi = seamsh.field.Distance(domain, 100, ["coast_zoi"])
dist_reefs = seamsh.field.Distance(domain_reefs, 100, ["Reef"])

# 4. Read bathymetry file if needed (no gradient is computed here)
print("Reading bathymetry")
path_to_bathy = p.slimGBR_data_dir + "source/bathymetry/gbr100_corrected/"

if use_bath:
    # bath_field only exists when use_bath is True; mesh_size reads it under
    # the same condition.
    bath_field = seamsh.field.Raster(
        path_to_bathy + "GBR100m_smoothed_HolesFixedG3.tif"
    )

# 5. Define mesh element size function
print("Defining mesh element size function")


def threshold(d, l0, l1, d0, d1):
    """Map distances to mesh sizes with a clamped linear ramp.

    - ``d < d0``: mesh size ``l0`` (plateau at the fine resolution);
    - ``d0 <= d < d1``: size linearly interpolated from ``l0`` to ``l1``;
    - ``d >= d1``: mesh size ``l1``.

    Args:
        d: distance(s), scalar or array-like.
        l0: mesh size on the plateau (small distances).
        l1: mesh size far from the feature.
        d0: distance where the ramp starts.
        d1: distance where the ramp reaches ``l1``.

    Returns:
        Float array with the same shape as ``d``. NaN distances get ``l1``.
        If ``d1 <= d0`` there is no ramp: ``l0`` for ``d < d1``, else ``l1``.
    """
    # Work in float so interpolated sizes are never truncated when the
    # caller passes integer distances.
    d = np.asarray(d, dtype=float)
    if d0 < d1:
        ramp = np.interp(d, (d0, d1), (l0, l1))
    else:
        # Degenerate interval: np.interp requires increasing abscissae, and
        # every d < d1 is then also below d0, i.e. on the plateau.
        ramp = np.full_like(d, l0)
    # `d < d1` is False for NaN, so undefined distances fall back to l1.
    return np.where(d < d1, ramp, float(l1))


def blending(d, d0, d1):
    """Return a smooth 0 -> 1 blending weight as a function of distance.

    - ``d < d0``: 0;
    - ``d0 <= d < d1``: smoothstep ``3 t**2 - 2 t**3`` with
      ``t = (d - d0) / (d1 - d0)``;
    - ``d >= d1``: 1.

    Args:
        d: distance(s), scalar or array-like.
        d0: distance where the weight starts rising from 0.
        d1: distance where the weight reaches 1.

    Returns:
        Float array with the same shape as ``d``. NaN distances get 1.
        If ``d1 <= d0`` the weight is a step: 0 for ``d < d1``, else 1.
    """
    # Float arithmetic so the smoothstep is not truncated to 0/1 when the
    # caller passes integer distances.
    d = np.asarray(d, dtype=float)
    if d0 < d1:
        t = np.clip((d - d0) / (d1 - d0), 0.0, 1.0)
        ramp = t * t * (3.0 - 2.0 * t)
    else:
        # Degenerate interval: avoid dividing by zero; every d < d1 is then
        # below d0 as well, where the weight is 0.
        ramp = np.zeros_like(d)
    # `d < d1` is False for NaN, so undefined distances get weight 1.
    return np.where(d < d1, ramp, 1.0)


def mesh_size(x, projection):
    """Return the target element size at every point of ``x``.

    The size is the smallest of the coastline, zone-of-interest and reef
    distance-based sizes (see ``threshold``) and, when ``use_bath`` is set,
    of a bathymetry-based size that grows with the square root of depth.

    Args:
        x: points at which the size is evaluated; one size is produced per
            entry along the first axis (presumably shape (n_points, dim) —
            TODO confirm against seamsh).
        projection: spatial reference of ``x``, forwarded to every field.

    Returns:
        1-D array of ``x.shape[0]`` element sizes.
    """
    size = np.full(x.shape[0], 1e9)

    # (distance field, minimal size). Each minimal size is also used as the
    # length of the plateau over which that size applies.
    refinements = (
        (dist_coast, l0_coast),
        (dist_coast_zoi, l0_zoi),
        (dist_reefs, l0_reefs),
    )
    for field, l_min in refinements:
        dist = field(x, projection)
        size = np.minimum(size, threshold(dist, l_min, l1, l_min, d1))

    if not use_bath:
        return size

    # Depth is the negated raster value. Land (negative depth) and shallow
    # water both end up at the lower clip bound of 100.
    depth = np.clip(-bath_field(x, projection), 100, 4000)
    return np.minimum(size, np.sqrt(depth) * 750)


# 6. Treat mesh boundaries
# Coarsen the boundary curves of `domain` using mesh_size as the target size.
# interior_point is given in lon/lat, hence lonlat_proj; presumably it selects
# the region of the domain that is kept — confirm against the seamsh docs.
coarse = seamsh.geometry.coarsen_boundaries(
    domain, interior_point, lonlat_proj, mesh_size
)
# 7. Start meshing
# Mesh the coarsened domain with the same size function and write the result
# under p.local_base_dir (version=2 presumably selects the MSH 2 file format).
seamsh.gmsh.mesh(coarse, f"{p.local_base_dir}mesh/gbr_styx.msh", mesh_size, version=2)