import os
import sqlite3
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import colorcet
import cycler
import skyproj
import healpy as hp
import scipy.stats 

import rubin_sim.maf as maf
from rubin_sim.data import get_baseline

from rubin_scheduler.scheduler.utils import get_current_footprint

from astropy.time import Time
from astropy.coordinates import get_sun
from astropy.coordinates import EarthLocation

from rubin_scheduler.utils import SURVEY_START_MJD

RUBIN_LOC = EarthLocation.of_site("Cerro Pachon")

MJD_2024 = Time(
    "2024-01-01T00:00:00.00",
    format="isot",
    scale="utc",
    location=RUBIN_LOC,
).utc.mjd

YEAR = 365.25
BASELINE_SURVEY_START_MJD = SURVEY_START_MJD  # 60796 = May 1, 2025


def get_sun_ra_at_mjd(mjd):
    t = Time(mjd, format='mjd', location=RUBIN_LOC)
    return get_sun(t).ra.deg


def get_phase_for_ra_in_mjd(ra, start_time):
    sun_ra_start = get_sun_ra_at_mjd(start_time)
    return (
        ((sun_ra_start - ra) / 180 * np.pi + np.pi / 2) % (2.0 * np.pi)
        * (YEAR / 2.0 / np.pi)
    )


def get_season(mjd, ra, start_time):
    phase = get_phase_for_ra_in_mjd(ra, start_time)
    return (mjd - start_time + phase) / YEAR
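
# Note: get_season gives a fractional season index -- elapsed time since start_time
# in years, offset by an RA-dependent phase (computed from the sun's RA at start_time)
# so that the season counter for a field advances at a fixed point in its yearly
# visibility cycle. This is the "alternate" season bookkeeping used in the FBS,
# compared against calc_season in the metric below.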

from rubin_sim.maf import BaseMetric
from rubin_sim.maf import find_season_edges
from rubin_scheduler.utils import calc_season

if '__main__.NVisPerSeasonMetric' in BaseMetric.registry:
    del BaseMetric.registry['__main__.NVisPerSeasonMetric']

if '__main__.SeasonValMetric' in BaseMetric.registry:
    del BaseMetric.registry['__main__.SeasonValMetric']


class SeasonValMetric(BaseMetric):
    """Calculate a (single) value summarizing the season values.

    Parameters
    ----------
    min_exp_time : `float`, optional
        Minimum visit exposure time to count for a 'visit', in seconds.
        Default 16, to skip near-sun twilight visits.
    reduce_func : func-like, optional
        Any function, to turn the array of season values into a single number.
        Default np.median.
    mjd_start : `float` or None, optional
        The survey start date, in mjd.
        Default None uses `SURVEY_START_MJD`.
    mjd_col : `str`, optional
        The name of the MJD column in the data.
    exp_time_col : `str`, optional
        The name of the visit exposure time column.
    """

    def __init__(self, min_exp_time=16, reduce_func=np.median, mjd_start=None, mjd_col='observationStartMJD', 
                 exp_time_col='visitExposureTime', **kwargs):
        units = ""
        if mjd_start is None:
            mjd_start = SURVEY_START_MJD
        self.mjd_start = mjd_start
        self.mjd_col = mjd_col
        self.exp_time_col = exp_time_col
        self.min_exp_time = min_exp_time
        self.reduce_func = reduce_func
        if 'metric_name' not in kwargs:
            try:
                metric_name = f'Season_{reduce_func.__name__}'
            except AttributeError:
                metric_name = 'SeasonVal'
        else:
            metric_name = kwargs['metric_name']
            del kwargs['metric_name']
        super().__init__(
            col=[self.mjd_col, self.exp_time_col], units=units, 
            metric_name=metric_name, **kwargs
        )
        
    def run(self, data_slice, slice_point):
        # Order data Slice/times and exclude visits which are too short.
        long = np.where(data_slice[self.exp_time_col] > self.min_exp_time)
        if len(long[0]) == 0:
            return self.badval
        data = np.sort(data_slice[long], order=self.mjd_col)
        # SlicePoints ra/dec are always in radians -
        # convert to degrees to calculate season
        seasons = calc_season(np.degrees(slice_point["ra"]), data[self.mjd_col], self.mjd_start)
        result = self.reduce_func(seasons)
        return result
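
# Illustrative usage of SeasonValMetric (a sketch, not run here):
#   season_max = maf.MetricBundle(SeasonValMetric(reduce_func=np.max),
#                                 maf.HealpixSlicer(nside=64), "")
# would map the maximum season value reached at each healpix point.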

class NVisPerSeasonMetric(BaseMetric):
    """Calculate some per-season numbers: the number of visits per season,
    the maximum and median night gaps within each season, the season length,
    and the season id number.

    Parameters
    ----------
    min_exp_time : `float`, optional
        Minimum visit exposure time to count for a 'visit', in seconds.
        Default 16, to skip near-sun twilight visits.
    mjd_start : `float` or None, optional
        The survey start date, in mjd.
        Default None uses `SURVEY_START_MJD`.
    mjd_col : `str`, optional
        The name of the MJD column in the data.
    exp_time_col : `str`, optional
        The name of the visit exposure time column.
    night_col : `str`, optional
        The name of the night column in the data.
    metric_name : `str`, optional
        The name of the metric.
    """

    def __init__(
        self,
        min_exp_time=16,
        mjd_start=None,
        mjd_col="observationStartMJD",
        exp_time_col="visitExposureTime",
        night_col="night",
        metric_name="NVisPerSeason",
        **kwargs,
    ):
        units = "#"
        if mjd_start is None:
            mjd_start = SURVEY_START_MJD
        self.mjd_start = mjd_start
        self.mjd_col = mjd_col
        self.exp_time_col = exp_time_col
        self.night_col = night_col
        self.min_exp_time = min_exp_time
        super().__init__(
            col=[self.mjd_col, self.exp_time_col, self.night_col], units=units, 
            metric_name=metric_name, metric_dtype='object', **kwargs
        )

    def run(self, data_slice, slice_point):
        # Order data Slice/times and exclude visits which are too short.
        long = np.where(data_slice[self.exp_time_col] > self.min_exp_time)
        if len(long[0]) == 0:
            return self.badval
        # Sort remaining visits in order of MJD
        data = np.sort(data_slice[long], order=self.mjd_col)
        # Calculate the season value for each of these visits
        seasons = (calc_season(np.degrees(slice_point["ra"]), data[self.mjd_col], self.mjd_start))
        int_seasons = np.floor(seasons)
        # Count up X per season
        n_per_season = []
        max_nightgap_per_season = []
        median_nightgap_per_season = []
        for si in np.unique(int_seasons):
            match = np.where(int_seasons == si)[0]
            # number of visits per season
            n_per_season.append(len(match))
            night_gaps = np.diff(np.unique(data[self.night_col][match]))
            if len(night_gaps) == 0:
                maxgap = self.badval
                mediangap = self.badval
            else:
                maxgap = night_gaps.max()
                mediangap = np.median(night_gaps)
            max_nightgap_per_season.append(maxgap)
            median_nightgap_per_season.append(mediangap)
        n_per_season = np.array(n_per_season)
        max_nightgap_per_season = np.array(max_nightgap_per_season)
        median_nightgap_per_season = np.array(median_nightgap_per_season)
        # Count up length of seasons using season utils
        first_of_season, last_of_season = find_season_edges(seasons)
        seasonlengths = data[self.mjd_col][last_of_season] - data[self.mjd_col][first_of_season]
        # Count seasons in the alternate way used in the FBS
        alt_season = get_season(data[self.mjd_col], np.degrees(slice_point["ra"]), self.mjd_start)
        int_alt = np.floor(alt_season)
        alt_n_per_season = []
        for si in np.unique(int_alt):
            alt_n_per_season.append(len(np.where(int_alt == si)[0]))
        alt_n_per_season = np.array(alt_n_per_season)
        return {'n_per_season' : n_per_season, 
                'max_gap_per_season': max_nightgap_per_season,
                'median_gap_per_season': median_nightgap_per_season,
                'alt_n_per_season' : alt_n_per_season,
                'season_length' : seasonlengths, 
                'season_id' : np.unique(int_seasons)}
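
    # MAF automatically discovers methods named `reduce_*` on a metric and turns
    # each one into a separate scalar output (e.g. `NVisPerSeason_min` below).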

    def reduce_min(self, metric_value):
        # Don't count first or last seasons, as they could be short
        return np.min(metric_value['n_per_season'][1:-1])

    def reduce_max(self, metric_value):
        return np.max(metric_value['n_per_season'])

    def _reduce_season(self, metric_value, season_id):
        # Shared helper for the per-season reduce methods below.
        idx = np.where(metric_value['season_id'] == season_id)[0]
        if len(idx) == 0:
            return self.badval
        return metric_value['n_per_season'][idx]

    def reduce_season_m1(self, metric_value):
        return self._reduce_season(metric_value, -1)

    def reduce_season_0(self, metric_value):
        return self._reduce_season(metric_value, 0)

    def reduce_season_1(self, metric_value):
        return self._reduce_season(metric_value, 1)

    def reduce_season_2(self, metric_value):
        return self._reduce_season(metric_value, 2)

    def reduce_season_3(self, metric_value):
        return self._reduce_season(metric_value, 3)

    def reduce_season_4(self, metric_value):
        return self._reduce_season(metric_value, 4)

    def reduce_season_5(self, metric_value):
        return self._reduce_season(metric_value, 5)

    def reduce_season_6(self, metric_value):
        return self._reduce_season(metric_value, 6)

    def reduce_season_7(self, metric_value):
        return self._reduce_season(metric_value, 7)

    def reduce_season_8(self, metric_value):
        return self._reduce_season(metric_value, 8)

    def reduce_season_9(self, metric_value):
        return self._reduce_season(metric_value, 9)

opsim_fname = 'baseline_v4.3.2_10yrs.db'
#opsim_fname = get_baseline()
print(opsim_fname)

outdir = "season_figs"

runName = os.path.split(opsim_fname)[-1].replace('.db', '')
print(runName)


# tried out a test point .. but then went to exgal-wfd footprint
test_ra = 30.0
test_dec = -20.0
test_slicer = maf.UserPointsSlicer(test_ra, test_dec)


footprints, labels = get_current_footprint(nside=64)
wfd_labels = ['lowdust', 'euclid_overlap', 'LMC_SMC', 'virgo']
wfdhpid = np.where(np.isin(labels, wfd_labels))[0]
print(np.unique(labels))
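# Restrict the metric evaluation to the WFD-like healpixels selected above.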
wfd_slicer = maf.HealpixSubsetSlicer(nside=64, hpid=wfdhpid)

#constraint = "filter == 'g' or filter == 'r' or filter == 'i'"
#label = "gri bands"
constraint = ""
label = ""

mymetric = NVisPerSeasonMetric()

summary_metrics = [maf.MedianMetric()]
bundle = maf.MetricBundle(mymetric, wfd_slicer, constraint, run_name=runName, summary_metrics=summary_metrics)
g = maf.MetricBundleGroup({'nvals': bundle}, opsim_fname, out_dir=outdir, results_db=None, verbose=True)
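# MetricBundleGroup queries the opsim database and runs each bundle's metric over its slicer.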
# And calculate the metric
g.run_all()
# just check the reduce functions showed up (and what they're named)
for b in g.bundle_dict:
    print(b, g.bundle_dict[b].summary_values)
# hackety hack to make a quick histogram of the non-first/last season nvisits values
vals = []
for i in bundle.metric_values.compressed():
    vals.append(i['n_per_season'][1:-1])
vals = np.hstack(vals)
_ = plt.hist(vals, bins=np.arange(0, 200, 2))
#_ = plt.hist(vals, bins=np.arange(0, 120, 1))
plt.grid(True, alpha=0.3)
#plt.axvline(14, color='k')
#plt.axvline(65, color='k')
plt.title(f"{runName} {label}")
plt.xlabel("Number of visits per season")
#plt.savefig("rolling_cadence_nvisits_per_season.png")
# Can we try to identify on/off seasons automatically?
# Here, < 45 nvisits per season counts as "off", > 105 as "on", and anything in between as "average".
low = 45
high = 105
season_maps = {}
for sid in np.arange(-1, 10, 1):
    vals = []
    for i in bundle.metric_values.filled(0):
        if i == 0:
            vals.append(hp.UNSEEN)
        else:
            idx = np.where(i['season_id'] == sid)[0]
            if len(idx) == 0:
                vals.append(hp.UNSEEN)
            else:
                n_in_season = i['n_per_season'][idx][0]
                if n_in_season < low:
                    vals.append(-1)
                elif n_in_season > high:
                    vals.append(1)
                else:
                    vals.append(0)
    vals = np.array(vals)
    season_maps[sid] = vals
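# season_maps[sid]: -1 = low ("off") season, +1 = high ("on") season, 0 = in between,
# hp.UNSEEN where the point has no data for that season.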
hp.mollview(season_maps[5])
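# Count, for each healpix point, how many seasons were "on" (> high) or "off" (< low).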
n_on = np.zeros(len(bundle.slicer))
n_off = np.zeros(len(bundle.slicer))
for idx, i in enumerate(bundle.metric_values.filled(0)):
    if i == 0:
        n_on[idx] = 0
        n_off[idx] = 0
    else:
        n_on[idx] = len(np.where(i['n_per_season'] > high)[0])
        n_off[idx] = len(np.where(i['n_per_season'] < low)[0])
hp.mollview(n_on, min=0, max=4)
# Quick histogram of the season lengths, split by whether each season was "on" (high nvisits), "off" (low), or in between
vals_on = []
vals_mid = []
vals_off = []
k = 'season_length'
bins = np.arange(0, 200, 10)
for i in bundle.metric_values.compressed():
    on = np.where(i['n_per_season'] > high)
    vals_on.append(i[k][on])
    off = np.where(i['n_per_season'] < low)
    vals_off.append(i[k][off])
    mid = np.where((i['n_per_season'] >= low) & (i['n_per_season'] <= high))
    vals_mid.append(i[k][mid])
vals_on = np.hstack(vals_on)
vals_mid = np.hstack(vals_mid)
vals_off = np.hstack(vals_off)
_ = plt.hist(vals_on, bins=bins, alpha=0.3, label=f"High - median={np.median(vals_on)}")
_ = plt.hist(vals_off, bins=bins, alpha=0.3, label=f"Low - median={np.median(vals_off)}")
_ = plt.hist(vals_mid, bins=bins, alpha=0.3, label=f"Mid - median={np.median(vals_mid)}")
plt.grid(True, alpha=0.3)
plt.legend()
#plt.axvline(14, color='k')
#plt.axvline(65, color='k')
plt.title(f"{runName} {label}")
plt.xlabel("Median Gap Between Nights Per Season")
plt.savefig("rolling_cadence_median_nightgap.png")
print(np.median(vals_on), np.median(vals_mid), np.median(vals_off))
# This is just a standard metric bundle - minimum number of visits per season
k = 'NVisPerSeason_min'
g.bundle_dict[k].set_plot_dict({'color_min': 0, 'color_max': 50, 'x_min': 0, 'x_max':200})
g.bundle_dict[k].set_plot_funcs([maf.HealpixSkyMap()])
g.bundle_dict[k].plot()
#print(g.bundle_dict[k].summary_values)
# I'm curious about the season length over the sky, say in the middle of the survey.
# It's good that the rolling (active/inactive) doesn't strongly impact season length.
# I'm not so sure why the boundaries between regions have such long seasons, though.
vals = []
sid = 3
for i in bundle.metric_values.filled(0):
    if i == 0:
        vals.append(hp.UNSEEN)
    else:
        idx = np.where(i['season_id'] == sid)[0]
        if len(idx) == 0:
            vals.append(hp.UNSEEN)
        else:
            vals.append(i['n_per_season'][idx][0])
vals = np.array(vals)
hp.mollview(vals, min=40, max=110)
# Plot the visits per season 
for i in ['m1', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9']:
    k = f'NVisPerSeason_season_{i}'
    g.bundle_dict[k].set_plot_dict({'color_min': 0, 'color_max': 200, 'x_min': 0, 'x_max':200})
    g.bundle_dict[k].set_plot_funcs([maf.HealpixSkyMap()])
    g.bundle_dict[k].plot(savefig=True)