import pandas as pd
import datetime as dt
from osgeo import ogr
import numpy as np
from styx_river_coal_mine.code import param_private


def get_utm_zone(sf):
    """Return the bounding box of a GeoPackage layer and a matching UTM PROJ string.

    The envelopes of all features in the first layer of ``sf`` are merged into
    one bounding box. The centre of that box selects the UTM zone and the
    hemisphere flag used to build the PROJ string.

    Coordinates are treated as longitude/latitude in degrees (x = longitude,
    y = latitude); the function does not check or reproject the layer's CRS
    -- TODO confirm the input files are always geographic.

    Parameters
    ----------
    sf : str
        Path to a GeoPackage (``.gpkg``) file, opened read-only.

    Returns
    -------
    tuple
        ``(lon_min, lon_max, lat_min, lat_max, mesh_proj)`` where the first
        four entries are the merged envelope and ``mesh_proj`` is a string of
        the form ``"+proj=utm +ellps=WGS84 +zone=<N> +south|+north"``.

    Raises
    ------
    RuntimeError
        If the file (or its first layer) cannot be opened with the GPKG driver.
    ValueError
        If the layer holds no feature with a non-empty geometry.
    """
    driver = ogr.GetDriverByName("GPKG")
    # Without ogr.UseExceptions(), OGR reports failures by returning None.
    data = driver.Open(sf, 0) if driver is not None else None
    if data is None:
        raise RuntimeError(f"Could not open '{sf}' with the OGR GPKG driver")
    layer = data.GetLayer()
    if layer is None:
        raise RuntimeError(f"No layer found in '{sf}'")

    lon_min = lon_max = lat_min = lat_max = None
    for f in layer:
        geom = f.geometry()
        # Features without geometry carry no extent; an empty geometry would
        # report a meaningless envelope, so both are skipped.
        if geom is None or geom.IsEmpty():
            continue
        xmin, xmax, ymin, ymax = geom.GetEnvelope()
        lon_min = xmin if lon_min is None else min(lon_min, xmin)
        lon_max = xmax if lon_max is None else max(lon_max, xmax)
        lat_min = ymin if lat_min is None else min(lat_min, ymin)
        lat_max = ymax if lat_max is None else max(lat_max, ymax)

    if lon_min is None:
        raise ValueError(f"'{sf}' contains no feature with a geometry")

    mean_lon = 0.5 * (lon_min + lon_max)
    mean_lat = 0.5 * (lat_min + lat_max)

    # Standard UTM numbering: zone 1 starts at 180W and each zone is 6 degrees
    # wide with its western edge inclusive. The modulo keeps the result in
    # 1..60 for lon = +/-180 and for longitudes given in a 0..360 convention.
    utm_zone = int(np.floor((mean_lon + 180.0) / 6.0)) % 60 + 1

    # NOTE(review): PROJ documents only "+south" for UTM; "+north" is kept
    # unchanged so the returned string stays identical for existing callers.
    hemisphere = " +south" if mean_lat < 0.0 else " +north"
    mesh_proj = f"+proj=utm +ellps=WGS84 +zone={utm_zone}{hemisphere}"
    return lon_min, lon_max, lat_min, lat_max, mesh_proj


class parameters:
    """Settings, file paths and domain extent for one simulation setup.

    All attributes are filled in by ``__init__``. The domain extent and the
    mesh projection are read from the domain GeoPackage via ``get_utm_zone``,
    so constructing an instance requires that file to exist.

    Parameters
    ----------
    mesh_setup : str
        Name of the mesh setup; used to build the mesh file name and the
        per-setup output / preprocessing directories.
    initial_time, final_time
        Start and end of the simulation; default to the values defined in
        ``param_private``.
    """

    def __init__(
        self,
        mesh_setup="gbr_styx",
        initial_time=param_private.initial_time,
        final_time=param_private.final_time,
    ):

        # --- simu parameters --- #
        self.mesh_setup = mesh_setup
        self.initial_time = initial_time
        self.final_time = final_time
        self.dt = 900  # time step; presumably seconds -- TODO confirm
        self.dt_export = 60 * 60  # export interval; presumably seconds (1 h)
        self.checkpoint_ratio = 24  # NOTE(review): meaning defined by the consumer
        self.ratio_full_export = int(86400 / self.dt_export)  # every day
        self.min_depth = 5  # NOTE(review): units not stated here

        # --- Paths --- #
        # Both base directories currently point at the same location.
        self.local_base_dir = "styx_river_coal_mine/data/"
        self.scratch_base_dir = "styx_river_coal_mine/data/"

        self.output_directory = f"{self.scratch_base_dir}output/{self.mesh_setup}"
        self.full_output_directory = (
            f"{self.scratch_base_dir}full_export/{self.mesh_setup}"
        )
        self.nc_data_dir = f"{self.scratch_base_dir}nc"
        self.data_dir = f"{self.scratch_base_dir}prepro"
        # Drop a single trailing slash so joined paths do not contain "//".
        self.prepro_dir_base = (
            self.data_dir[:-1] if self.data_dir.endswith("/") else self.data_dir
        )
        self.prepro_dir = f"{self.prepro_dir_base}/{self.mesh_setup}"
        self.nc_dir = (
            self.nc_data_dir[:-1]
            if self.nc_data_dir.endswith("/")
            else self.nc_data_dir
        )

        # --- Mesh --- #
        # shape file containing coast lines
        self.shape_file = f"{self.local_base_dir}spatial/domain.gpkg"
        self.reef_file = f"{self.local_base_dir}spatial/reefs.gpkg"
        self.land_shapefile = f"{self.local_base_dir}spatial/QLD_mainland_islands.gpkg"
        self.reef_file_simplified = (
            f"{self.local_base_dir}spatial/reefs_simplified.gpkg"
        )
        self.mesh_file = f"{self.local_base_dir}mesh/{self.mesh_setup}.msh"

        # --- region --- #
        # Bounding box and UTM projection string derived from the domain file.
        (
            self.lon_min,
            self.lon_max,
            self.lat_min,
            self.lat_max,
            self.mesh_proj,
        ) = get_utm_zone(self.shape_file)

        # Boundary tags, grouped into open and closed boundaries.
        self.open_tags = ["open_shelf", "open_north", "open_south"]
        self.closed_tags = ["coast", "coast_zoi"]

        # --- misc --- #
        self.forcing_names = param_private.forcing_names
        self.vec_prepro2D = param_private.vec_prepro2D

    def print_info(self):
        """Print a human-readable summary of the main settings to stdout."""
        # ``machine`` is not set by __init__; fall back to a placeholder so
        # the summary still prints unless a caller assigned it beforehand.
        machine = getattr(self, "machine", "unknown")
        print("========================= PARAM INFO =========================")
        print(f"Date: {dt.datetime.now()}")
        print(f"Run on: {machine}")
        print(f"Mesh: {self.mesh_file}")
        print(f"Extent: {self.lon_min, self.lon_max, self.lat_min, self.lat_max}")
        print(f"Simu start: {self.initial_time}")
        print(f"Simu end: {self.final_time}")
        print(f"Proj: {self.mesh_proj}")
        print(f"Local dir: {self.local_base_dir}")
        print(f"Prepro dir: {self.prepro_dir}")
        print(f"NetCDF dir: {self.nc_dir}")
        print(f"Output dir: {self.output_directory}")
        print("==============================================================")