import os
import warnings
import copy
import numpy as np
import pandas as pd
import sqlite3
import healpy as hp
import matplotlib.pyplot as plt
import colorcet as cc
import skyproj
from IPython.display import display, HTML
from tabulate import tabulate

from astropy.time import Time, TimeDelta
from astropy.coordinates import SkyCoord
import astropy.units as u

# I have imported more below than I actually needed .. 
from rubin_scheduler.scheduler import sim_runner
from rubin_scheduler.scheduler.model_observatory import ModelObservatory
from rubin_scheduler.scheduler.schedulers import SimpleBandSched, DateSwapBandScheduler, CoreScheduler
from rubin_scheduler.scheduler.features import Conditions
from rubin_scheduler.scheduler.utils import SchemaConverter, run_info_table
from rubin_scheduler.site_models import Almanac
from rubin_scheduler.utils import ddf_locations, angular_separation, approx_ra_dec2_alt_az, Site
from rubin_scheduler.scheduler.utils import ObservationArray, generate_all_sky

import rubin_sim.maf as maf
from rubin_sim.data import get_baseline
import schedview.compute as schedview_compute

from rubin_nights import connections
import rubin_nights.dayobs_utils as rn_dayobs
import rubin_nights.plot_utils as rn_plots
import rubin_nights.augment_visits as augment_visits
import rubin_nights.observatory_status as observatory_status
import rubin_nights.rubin_scheduler_addons as rn_sch
import rubin_nights.rubin_sim_addons as rn_sim
from rubin_nights.targets_and_visits import targets_and_visits

import importlib

# Band colors for plots (TODO: replace with lsst_utils, now that lsst_utils is in conda).
band_colors = rn_plots.PlotStyles.band_colors
# All figures, tables and the opsim-format database produced below are written here.
out_dir = "dp2"
# exist_ok: re-running the notebook is fine; still errors if "dp2" exists but is not a directory.
os.makedirs(out_dir, exist_ok=True)
# The DP2 butler visit list was produced once with:
#   import lsst.daf.butler as dafButler
#   collections = ["LSSTCam/runs/DRP/DP2/v30_0_6/DM-53881/stage4"]
#   butler = dafButler.Butler("dp2_prep", collections=collections)
#   res = butler.query_datasets("visit_table")
#   visit_table = butler.get(res[0], storageClass='DataFrame')
#   visit_table.to_hdf("dp2_visit_table.hdf", key="visits")
dp2_visit_table = pd.read_hdf("dp2_visit_table.hdf")
# Range of visit ids covered by DP2.
print(*dp2_visit_table.visitId.agg(['min', 'max']))
dp2_visit_table.head()
# Service clients (consdb, EFD, exposure log, ...).
endpoints = connections.get_clients()

# Science programs (scheduler BLOCKs) whose visits make up the science sample.
programs = [
    "BLOCK-365", "BLOCK-407", "BLOCK-408", "BLOCK-416",
    "BLOCK-417", "BLOCK-419", "BLOCK-421",
]

# True: re-query consdb; False: reload the cached 'v_now.h5'.
refresh_visits = True
# Which quicklook table to join: 'visit' or 'detector'.
quicklook = 'visit'

# Inclusive day_obs range to fetch.
day_obs_min = 20250424
day_obs_max = 20260106

one_day = TimeDelta(1, format='jd')
time_day_obs_start = rn_dayobs.day_obs_to_time(day_obs_min)
# One day past day_obs_to_time(day_obs_max), so the final day_obs is covered.
time_day_obs_end = rn_dayobs.day_obs_to_time(day_obs_max) + one_day
# One Time per day; the +0.5 on the stop value keeps the day at time_day_obs_end in the range.
n_day_steps = np.arange(0, (time_day_obs_end - time_day_obs_start).jd + 0.5)
days = time_day_obs_start + one_day * n_day_steps
day_obs_list = [rn_dayobs.day_obs_str_to_int(day.iso[:10]) for day in days]

# Fetch visits from consdb (refresh_visits=True) or reload the cached copy in 'v_now.h5'.
if refresh_visits:
    # Calibration image types excluded from the query.
    skip_imgtypes = ["bias", "flat", "dark"]
    img_type_clause = " and ".join(f"v.img_type != '{img_type}'" for img_type in skip_imgtypes)
    day_obs_clause = f"v.day_obs >= {day_obs_min} and v.day_obs <= {day_obs_max}"
    if quicklook == 'visit':
        # Visit-level table joined with the visit-level quicklook values.
        query = (
            "select v.*, q.* from cdb_lsstcam.visit1 as v left join cdb_lsstcam.visit1_quicklook as q on v.visit_id = q.visit_id "
            f"where {day_obs_clause} and {img_type_clause}"
        )
        visits = endpoints['consdb_tap'].query(query)
        visits = augment_visits.augment_visits(visits, "lsstcam", cols_from="visit")
    elif quicklook == 'detector':
        # Detector-level quicklook values for a single detector (94) per visit.
        query = (
            "select v.*, "
            "c.detector, cq.eff_time_zero_point_scale, cq.psf_sigma, cq.n_psf_star, cq.psf_area, cq.psf_ixx, cq.psf_iyy, cq.psf_ixy, "
            "cq.zero_point, cq.stats_mag_lim, cq.pixel_scale, cq.sky_bg, cq.sky_noise "
            "from cdb_lsstcam.visit1 as v join cdb_lsstcam.ccdvisit1 as c "
            "on v.visit_id = c.visit_id  "
            "left join cdb_lsstcam.ccdvisit1_quicklook as cq "
            "on c.ccdvisit_id = cq.ccdvisit_id "
            "where c.detector = 94 "
            f"AND {day_obs_clause} and {img_type_clause}"
        )
        visits = endpoints['consdb'].query(query)
        visits = augment_visits.augment_visits(visits, "lsstcam", cols_from="ccd")
    else:
        # Fail here rather than with a NameError on `visits` further down.
        raise ValueError(f"quicklook must be 'visit' or 'detector', not {quicklook!r}")
    # Fresh 0..n-1 index; later code mixes positional and label access.
    visits = visits.reset_index(drop=True)
    # Cache for refresh_visits=False, swallowing any warnings raised while writing the HDF5 file.
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        visits.to_hdf('v_now.h5', key='visits')
else:
    visits = pd.read_hdf('v_now.h5')


# Now add extra information: modeled overheads, fault/idle time, program and
# filter changes, and bad-visit flags.

#dome_open = observatory_status.get_dome_open_close(rn_dayobs.day_obs_to_time(day_obs_min), rn_dayobs.day_obs_to_time(day_obs_max) + TimeDelta(1, format='jd'), endpoints['efd'])

# Slew-model inputs: wait-before-slew + settle are passed together as the model settle time.
wait_before_slew = 1.6
settle = 1.5
# Allowance added to the modeled slew before a gap counts as fault/idle
# (presumably seconds, like visit_gap -- TODO confirm).
max_scatter = 10

if len(visits) > 0:

    cols = ["overhead", "fault_idle", "program_change", "filter_change", "bad_flag"]
    new_df = pd.DataFrame(np.zeros((len(visits), len(cols))), columns=cols, index=visits.index)
    visits = visits.merge(new_df, right_index=True, left_index=True)

    # Flag science_program changes (the first visit always counts as a change).
    program_change = np.where((visits.science_program[:-1].values != visits.science_program[1:].values))[0]
    program_change = program_change + 1
    pmask = np.zeros(len(visits))
    pmask[0] = 1
    pmask[program_change] = 1
    visits["program_change"] = pmask

    # Flag filter changes between consecutive visits on the same day_obs.
    filter_change = np.where((visits.band[:-1].values != visits.band[1:].values)
                             & (visits.day_obs[:-1].values == visits.day_obs[1:].values))[0]
    filter_change = filter_change + 1
    fmask = np.zeros(len(visits))
    fmask[filter_change] = 1
    visits["filter_change"] = fmask

    # Calculate slew times and the expected overhead: modeled slew (NaN -> 0)
    # plus max_scatter, capped at the actual gap between visits.
    visits, slewing = rn_sch.add_model_slew_times(visits, endpoints['efd'], model_settle=wait_before_slew + settle, dome_crawl=True, slew_while_changing_filter=False)
    valid_overhead = np.min([np.where(np.isnan(visits.slew_model.values), 0, visits.slew_model.values) + max_scatter, visits.visit_gap.values], axis=0)
    visits["overhead"] = valid_overhead

    # Need to remove faults for the first visit of the night or where there was a different program we didn't fetch
    # (could skip *some* of this by fetching all visits, but might still have some missing due to flats?)
    # These are *positions*: the first visit, plus any visit following a jump in visit_id.
    skipped_visits = np.concatenate([np.array([0]), np.where(visits.visit_id[:-1].values + 1 != visits.visit_id[1:].values)[0] + 1])

    fault = visits.visit_gap - valid_overhead
    # Assign positionally so this stays correct even if the index is not 0..n-1.
    fault.iloc[skipped_visits] = np.nan
    visits["fault_idle"] = fault

    visits.loc[visits.index[skipped_visits], 'model_gap'] = np.nan

    # Pull lsst-dm excluded visit list to flag bad visits
    bad_visit_ids = augment_visits.fetch_excluded_visits("lsstcam")
    visits['bad_flag'] = np.zeros(len(visits), int)
    idx = visits.query("visit_id in @bad_visit_ids").index
    visits.loc[idx, 'bad_flag'] = 1
    # Also flag visits marked 'junk' in the exposure log, over the whole day_obs
    # range (time_day_obs_end already extends one day past day_obs_max).
    ee = endpoints['exposure_log'].query_log(time_day_obs_start, time_day_obs_end)
    if len(ee) > 0:
        def make_visit_id(x):
            """Return the integer visit_id: day_obs followed by the 5-digit zero-padded seq_num."""
            # Must be an int: visits.visit_id is numeric, so string ids would never match.
            return int(x.day_obs) * 100000 + int(x.seq_num)
        junk = ee.query("exposure_flag == 'junk'")
        if len(junk) > 0:
            exp_log_bad_visit_ids = junk.apply(make_visit_id, axis=1).values
            idx = visits.query("visit_id in @exp_log_bad_visit_ids").index
            visits.loc[idx, 'bad_flag'] = 1

    # And flag close_loop visits as 'bad' for purposes of efficiency
    # (na=False: a missing observation_reason is not close_loop, instead of breaking the mask).
    is_close_loop = visits.observation_reason.str.contains('close_loop', na=False)
    visits.loc[is_close_loop.to_numpy(), 'bad_flag'] = 1

    print("all visits:", len(visits), "science visits:", len(visits.query("science_program in @programs")))

else:
    print("Found no visits")
    print("The remainder of this notebook requires visits.")

## TODO: time associated with bad visits should also be counted as fault somehow.

add_more_dimm = True
if add_more_dimm:
    def add_dimm(x, dimm):
        """Return the mean DIMM ``fwhm`` within 3 minutes either side of visit ``x``.

        ``x`` is a visits row (reads obs_start_mjd / obs_end_mjd); ``dimm`` needs
        'mjd' and 'fwhm' columns. A window with no measurements yields NaN.
        """
        pad = 3 / 60 / 24  # 3 minutes, in days
        in_window = ((dimm['mjd'] >= x.obs_start_mjd - pad)
                     & (dimm['mjd'] <= x.obs_end_mjd + pad)).to_numpy()
        # The mean of an empty selection is NaN, which covers the "no measurements" case.
        return pd.Series({"dimm_fwhm": dimm.loc[in_window, 'fwhm'].mean(skipna=True)})

    def add_ringss(x, ringss):
        """Return mean RINGSS fwhm values within 1 minute either side of visit ``x``."""
        pad = 1 / 60 / 24  # 1 minute, in days
        in_window = ((ringss['mjd'] >= x.obs_start_mjd - pad)
                     & (ringss['mjd'] <= x.obs_end_mjd + pad)).to_numpy()
        window = ringss.loc[in_window]
        return pd.Series({"ringss_fwhmFree": window.fwhmFree.mean(),
                          "ringss_fwhmSector": window.fwhmSector.mean(),
                          "ringss_fwhmScintillation": window.fwhmScintillation.mean()})

    ringss = endpoints['efd'].select_time_series("lsst.sal.ESS.logevent_ringssMeasurement", ['fwhmFree', 'fwhmSector', 'fwhmScintillation'], time_day_obs_start, time_day_obs_end)
    if len(ringss) > 0:
        ringss['mjd'] = Time(ringss.index.values, scale='utc').tai.mjd
        ringss_df = visits.apply(add_ringss, args=[ringss], axis=1)
    else:
        # Broadcast NaN into all three columns (a flat list of NaNs only fills one column).
        ringss_df = pd.DataFrame(np.nan, index=visits.index,
                                 columns=['ringss_fwhmFree', 'ringss_fwhmSector', 'ringss_fwhmScintillation'])

    def _dimm_per_visit(sal_index, fwhm_col):
        """Fetch DIMM ``sal_index`` measurements and average them per visit into ``fwhm_col``.

        Returns (raw DIMM measurements, per-visit DataFrame indexed like visits).
        """
        dimm = endpoints['efd'].select_time_series("lsst.sal.DIMM.logevent_dimmMeasurement", ['fwhm'], time_day_obs_start, time_day_obs_end, index=sal_index)
        if len(dimm) == 0:
            return dimm, pd.DataFrame(np.nan, index=visits.index, columns=[fwhm_col])
        dimm['mjd'] = Time(dimm.index.values, scale='utc').tai.mjd
        per_visit = visits.apply(add_dimm, args=[dimm], axis=1).rename({"dimm_fwhm": fwhm_col}, axis=1)
        return dimm, per_visit

    dimm2, dimm2_df = _dimm_per_visit(2, 'dimm2_fwhm')
    dimm1, dimm1_df = _dimm_per_visit(1, 'dimm1_fwhm')

    dd = pd.merge(ringss_df, dimm2_df, left_index=True, right_index=True)
    dd = pd.merge(dd, dimm1_df, left_index=True, right_index=True)
    visits = pd.merge(visits, dd, left_index=True, right_index=True)

# Science visits from the DP2 programs, excluding observation_reason 'block-t548'.
sci = visits.query("science_program in @programs and img_type == 'science' and observation_reason != 'block-t548'")
good_sci = sci.query("bad_flag == 0")
print(len(good_sci))
print(len(dp2_visit_table))
# How many of those science visits made it into the DP2 visit table.
print(len(set(sci.visit_id.values) & set(dp2_visit_table.visitId.values)))
# To inspect science visits missing from DP2:
# ll = sorted(set(sci.visit_id.values) - set(dp2_visit_table.visitId.values))
# q = sci.query("visit_id in @ll")[['science_program', 'observation_reason', 'visit_id', 'day_obs', 'band', 'fwhm_geom', 'clouds', 'can_see_sky']]
# Write the DP2 visits in opsim format to <out_dir>/dp2_visits.db ("observations" table).
dp2_visits = visits.query("visit_id in @dp2_visit_table.visitId")
opsim = rn_sim.consdb_to_opsim(dp2_visits)
filename = os.path.join(out_dir, 'dp2_visits.db')
con = sqlite3.connect(filename)
try:
    opsim.to_sql("observations", con, index=False, if_exists="replace")
finally:
    # Close even if to_sql fails (sqlite3's own context manager does not close the connection).
    con.close()
print(len(opsim))
# Stacked per-night (integer JD) histogram of DP2 visits, one color per band.
svisits = dp2_visits.copy()
print(len(svisits))
t = Time("2025-04-24T12:00:00", scale='tai')
# Offset between MJD and JD, so JD = MJD - mjd_to_jd.
mjd_to_jd = t.mjd - t.jd
svisits.loc[:, 'jd'] = np.floor(svisits.obs_start_mjd - mjd_to_jd)
svisits.loc[:, 'jd'] = svisits.jd.astype(int)
jd_first, jd_last = svisits.jd.min(), svisits.jd.max()
jds = np.arange(jd_first, jd_last + 1, 1)
jdsbins = np.arange(jd_first, jd_last + 2, 1)
# Calendar date of each integer JD, used for the tick labels.
days = [isot.split('T')[0] for isot in Time(jds, format='jd', scale='tai').isot]

bar_bottom = np.zeros(len(jds))
plt.figure(figsize=(10, 6))
for b in 'ugrizy':
    heights, _ = np.histogram(svisits.query("band == @b ").jd, bins=jdsbins)
    plt.bar(jds, heights, bottom=bar_bottom, width=1, color=band_colors[b], alpha=0.8, label=b)
    # Stack the next band on top of this one.
    bar_bottom += heights
plt.legend()
# Label one night per week.
_ = plt.xticks(jds[::7], labels=days[::7], rotation=90)
plt.grid(alpha=0.2)
plt.ylabel("Number of visits", fontsize='large')
plt.title("DP2 Science Visits")
plt.savefig(os.path.join(out_dir, "dp2_visits_timing.png"), bbox_inches='tight')
# Visit counts by purpose and by date.
print("total visits", len(svisits))
print("visits for small field survey science", len(svisits.query("observation_reason == 'field_survey_science'")))
# Everything that is not small-field survey science counts as SV.
total_sv = len(svisits.query("observation_reason != 'field_survey_science'"))
print("visits for SV", total_sv)
print("visits for wide SV", len(svisits.query("observation_reason != 'field_survey_science' "\
                                              "and observation_reason != 'template_area_singles_i' "\
                                             "and observation_reason != 'too' and not observation_reason.str.contains('ddf')")))
total_ddf = len(svisits.query("observation_reason.str.contains('ddf')"))
print("visits for ddf", total_ddf, total_ddf/total_sv)
total_too = len(svisits.query("observation_reason.str.contains('too')"))
print("visits for too", total_too, total_too/total_sv)
print("visits prior to 20251024", len(svisits.query("day_obs < 20251024")))
print("visits post 20251024", len(svisits.query("day_obs >= 20251024")))
print("visits post 20260106", len(svisits.query("day_obs > 20260106")))
# Small-field survey visit counts per target (rows) and band (columns), for targets with > 50 visits.
ss = (svisits.query("observation_reason == 'field_survey_science'")
      .groupby(["target_name", "band"])
      .agg({'seq_num': 'count'})
      .rename({"seq_num": "count"}, axis=1))
ss = ss.reset_index('band').pivot(columns=["band"]).droplevel(0, axis=1)
ss = ss[list('ugrizy')]
ss['all'] = ss.sum(axis=1)
ss = ss.query("all > 50").sort_values('all')
# Escape underscores in target names for the RST/LaTeX output.
ss.index = [name.replace("_", "\\_") for name in ss.index]
rst_table = tabulate(pd.DataFrame(ss.round(0)), headers='keys', tablefmt='rst').replace("nan", "---")
print(rst_table)
# Cumulative per-band fwhm_eff distributions: small-field survey, then DP2 wide survey.
bins = np.arange(0.3, 2.5, 0.005)
fwhm_plots = [
    ("observation_reason == 'field_survey_science'",
     "LSSTCam Small Field Survey Science", "smallfield_fwhm.png"),
    ("observation_reason != 'field_survey_science' and not observation_reason.str.contains('ddf')",
     "LSSTCam DP2 (SV+PreLSST) Wide Survey", "dp2_wide_fwhm.png"),
]
for selection, plot_title, png_name in fwhm_plots:
    q = svisits.query(selection)
    plt.figure(figsize=(10, 6))
    for b in 'ugrizy':
        plt.hist(q.query("band == @b").fwhm_eff, bins=bins, cumulative=1, density=True,
                 histtype='step', color=band_colors[b], label=b)
    plt.legend()
    plt.xlim(0.5, 2.4)
    plt.xlabel("FWHM (arcseconds)", fontsize='large')
    plt.ylabel("Cumulative fraction of visits", fontsize='large')
    plt.title(plot_title)
    plt.grid(alpha=0.3)
    plt.savefig(os.path.join(out_dir, png_name))
# Compare median conditions of the baseline v5.0.1 simulation (excluding DD and ToO visits)
# with the DP2 wide-survey visits (excluding small-field survey and DDF visits).
baseline = '/Users/lynnej/opsim/fbs_5.0/baseline_v5.0.1_10yrs.db'
conn = sqlite3.connect(baseline)
try:
    baseline_visits = pd.read_sql("select * from observations", conn)
finally:
    conn.close()
qq = baseline_visits.query("not scheduler_note.str.contains('DD') and not scheduler_note.str.contains('ToO')")
# Cloud is averaged (not medianed) to match its "mean cloud (mag)" label and the DP2 'clouds' mean below.
ss = qq.groupby("band").agg({'fiveSigmaDepth': 'median', 'seeingFwhmEff': 'median', 'airmass': 'median', 'cloud': 'mean'})
ss = ss.rename({"fiveSigmaDepth": "median m5", "seeingFwhmEff": "median fwhm (arcsec)", 'airmass': "median airmass", 'cloud': "mean cloud (mag)"}, axis=1)
ss = ss.loc[['u', 'g', 'r', 'i', 'z', 'y']]
ss = ss.rename({b: "baseline " + b for b in 'ugrizy'}, axis=0)

qq = svisits.query("observation_reason != 'field_survey_science' and not observation_reason.str.contains('ddf')")
ss2 = qq.groupby("band").agg({'cat_m5': 'median', 'fwhm_eff': 'median', 'airmass': 'median', 'clouds': 'mean'})
ss2 = ss2.rename({"cat_m5": "median m5", "fwhm_eff": "median fwhm (arcsec)", "airmass": "median airmass", "clouds": "mean cloud (mag)"}, axis=1)
ss2 = ss2.loc[['u', 'g', 'r', 'i', 'z', 'y']]
remap = {b: "DP2 " + b for b in 'ugrizy'}
ss2 = ss2.rename(remap, axis=0)
# Interleave baseline and DP2 rows band by band.
s = pd.concat([ss, ss2]).loc[["baseline u", "DP2 u", "baseline g", "DP2 g",
                          "baseline r", "DP2 r", "baseline i", "DP2 i",
                          "baseline z", "DP2 z", "baseline y", "DP2 y"],
    ["median airmass", "median fwhm (arcsec)", "mean cloud (mag)", "median m5"]]
rst_table = tabulate(s.round(2), headers='keys', tablefmt='rst')
print(rst_table)
# Per-band histograms of estimated single-visit m5 for the DP2 wide survey.
q = svisits.query("observation_reason != 'field_survey_science' and not observation_reason.str.contains('ddf')")
bins = np.arange(18, 25.8, 0.1)
# (m5 column, output file or None). DON'T save the stats_mag_lim_median version: the stats mag lim
# was not available for much of SV and has not been backfilled. Use cat_m5 as it's pretty close anyway.
for m5_col, png_name in [("cat_m5", "dp2_wide_m5.png"), ("stats_mag_lim_median", None)]:
    plt.figure(figsize=(10, 6))
    for b in 'ugrizy':
        band_m5 = q.query("band == @b")[m5_col]
        # Filled (translucent) histogram plus an outline in the same color.
        plt.hist(band_m5, bins=bins, cumulative=0, density=False, alpha=0.3, color=band_colors[b], label=b)
        plt.hist(band_m5, bins=bins, cumulative=0, density=False, histtype='step', color=band_colors[b])
    plt.legend()
    plt.xlabel("Estimated m5 depth", fontsize='large')
    plt.ylabel("Number of visits", fontsize='large')
    plt.title("LSSTCam DP2 (SV+PreLSST) Wide Survey")
    plt.grid(alpha=0.3)
    if png_name is not None:
        plt.savefig(os.path.join(out_dir, png_name), bbox_inches='tight')
# Per-target summary of the small-field survey (targets with > 50 visits).
column_names = {"seq_num": "nvisits",
                "obs_start_mjd": "timespan (days)",
                "fwhm_eff": "median fwhm (arcsec)",
                "clouds": "mean cloud (mag)"}
vv = (svisits.query("observation_reason == 'field_survey_science'")
      .groupby(['target_name'])
      .agg({'seq_num': 'count', 'fwhm_eff': 'median', 'clouds': 'mean', 'obs_start_mjd': np.ptp})
      .rename(column_names, axis=1))
vv = vv.query("nvisits > 50").sort_values("nvisits")
# Truncate the span to whole days and add one, so a single-night target counts as 1 day.
vv['timespan (days)'] = vv['timespan (days)'].astype(int) + 1
# Escape underscores in target names for the RST/LaTeX output.
vv.index = [name.replace("_", "\\_") for name in vv.index]
rst_table = tabulate(pd.DataFrame(vv.round(2)), headers='keys', tablefmt='rst').replace("nan", "---")
print(rst_table)
# vv = gvisits.query("observation_reason == 'field_survey_science'").groupby(['target_name', 'band']).agg({'seq_num': 'count', 
#                                                                                                         'fwhm_eff': 'median', 
#                                                                                                         'clouds': 'mean', 
#                                                                                                        'obs_start_mjd': np.ptp})
# vv = vv.reset_index('band')
# vv.rename({"seq_num": "nvisits", 
#            "obs_start_mjd": "timespan (jd)", 
#            "fwhm_eff": "median fwhm (arcsec)", 
#            "clouds": "mean cloud (mag)"}, axis=1, inplace=True)
# nvisits_all = gvisits.query("observation_reason == 'field_survey_science'").groupby("target_name").agg({'obs_start': 'count', 
#                                                                                                        'fwhm_eff': 'median', 
#                                                                                                        'clouds': 'mean',
#                                                                                                       'obs_start_mjd': np.ptp})
# nvisits_all = nvisits_all.rename({'obs_start': 'nvisits', 
#                                   "obs_start_mjd": "timespan (jd)", 
#                                   "fwhm_eff": "median fwhm (arcsec)", 
#                                   "clouds": "mean cloud (mag)"}, axis=1)
# nvisits_all['band'] = 'all'
# vv = pd.concat([nvisits_all, vv])
# vv = vv.pivot(columns=['band'])
# vv['nvisits'] = vv['nvisits'].fillna(0)
# vv['nvisits'] = vv['nvisits'].astype(int)
# #columns = ['_'.join(tt) for tt in vv.columns.to_flat_index()]
# #vv.columns = columns
# vv.sort_values(('nvisits', 'all'), inplace=True)
# vv = vv[vv[('nvisits', 'all')] > 50]
# display(HTML(vv.T.round(2).to_html()))
## Dither pattern (visit count map) around COSMOS, from the opsim-format visits.
ops_visits = opsim

ddf_nvisits = {}
ddf_coadd = {}

m_nvis = maf.CountMetric(col='observationStartMJD', metric_name="Nvisits")
m_coadd = maf.Coaddm5Metric(m5_col='fiveSigmaDepth')

ddfs = ddf_locations(skycoords=True)
cosmos = ddfs['COSMOS']
s, plot_dict = maf.make_circle_subset_slicer(cosmos.ra.deg, cosmos.dec.deg, radius=2.8, nside=512*2, use_cache=False)

for b in ['u', 'g', 'r', 'i', 'z', 'y', 'all']:
    constraint = f"ddf cosmos {b}"
    selected = ops_visits if b == 'all' else ops_visits.query("band == @b")
    opsvis = selected.to_records()
    ddf_nvisits[b] = maf.MetricBundle(m_nvis, s, constraint)
    ddf_coadd[b] = maf.MetricBundle(m_coadd, s, constraint)

    bundles = {f'nvisits {b}': ddf_nvisits[b], f'coadd {b}': ddf_coadd[b]}
    g = maf.MetricBundleGroup(bundles, None)
    g.run_current(constraint, opsvis)

plot_dict['xsize'] = 250
ph = maf.PlotHandler(thumbnail=False, fig_format='png')
ph.set_metric_bundles([ddf_nvisits['all']])
fig = ph.plot(maf.HealpixSkyMap(), plot_dict)
from rubin_scheduler.scheduler.utils import get_current_footprint
from lsst_survey_sim import plot as ss_plot
# Full-sky per-band (and all-band) visit-count and coadded-m5 maps of all DP2 LSSTCam visits.
run_calc = True
nside = 64 * 4
if run_calc:
    nvisits = {}
    coadd = {}
    m_nvis = maf.CountMetric(col='obs_start_mjd', metric_name="Nvisits")
    m_coadd = maf.Coaddm5Metric(m5_col='cat_m5')
    s = maf.HealpixSlicer(nside=nside, lat_col='s_dec', lon_col='s_ra', rot_sky_pos_col_name='sky_rotation')
    for b in ['u', 'g', 'r', 'i', 'z', 'y', 'all']:
        constraint = f"{b}"
        selected = svisits if b == 'all' else svisits.query("band == @b")
        opsvis = selected.to_records()
        # Replace missing depths (NaN) with the sentinel -666 before running the coadd metric.
        opsvis['cat_m5'] = np.where(np.isnan(opsvis['cat_m5']), -666, opsvis['cat_m5'])
        nvisits[b] = maf.MetricBundle(m_nvis, s, constraint)
        coadd[b] = maf.MetricBundle(m_coadd, s, constraint)
        bundles = {f'nvisits {b}': nvisits[b], f'coadd {b}': coadd[b]}
        g = maf.MetricBundleGroup(bundles, None)
        g.run_current(constraint, opsvis)

# Per-band visit-count maps on a 2x3 grid, then the all-band map on its own.
bg = ss_plot.get_background(nside)
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(16, 10),)
axdict = dict(zip("ugrizy", ax.flatten()))
axdict["all"] = None
for b in ["u", "g", "r", "i", "z", "y"]:
    band_values = nvisits[b].metric_values.compressed()
    # Clip the color scale at the 95th percentile when there is enough data.
    vmax = np.percentile(band_values, 95) if len(band_values) > 1 else None
    # Only the left-hand column gets declination labels.
    label_dec = b in ('u', 'i')
    fig = ss_plot.make_plot(nvisits[b], proj='mcbryde', vmax=vmax, ax=axdict[b], title=f"DP2 band {b}", background=bg, label_dec=label_dec)
fig.tight_layout()
fig.savefig(os.path.join(out_dir, "dp2_nvisits_band.png"), bbox_inches='tight')

vmax = np.percentile(nvisits['all'].metric_values.compressed(), 95)
fig = ss_plot.make_plot(nvisits['all'], proj='mcbryde', vmin=None, vmax=vmax, ax=None, background=bg, title="DP2 visits")
fig.savefig(os.path.join(out_dir, "dp2_nvisits.png"), bbox_inches='tight')
# Per-band coadded-m5 maps on a 2x3 grid.
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(16, 10),)
axdict = {"u": ax[0][0], "g": ax[0][1], "r": ax[0][2],
          "i": ax[1][0], "z": ax[1][1], "y": ax[1][2], "all": None}
for b in ["u", "g", "r", "i", "z", "y"]:
    band_values = coadd[b].metric_values.compressed()
    if len(band_values) > 1:
        # Stretch the color scale over the 5th-95th percentile of the coadd depth.
        vmin = np.nanpercentile(band_values, 5)
        vmax = np.nanpercentile(band_values, 95)
    else:
        # Reset both limits; otherwise vmin leaks from the previous band (or is undefined).
        vmin = None
        vmax = None
    # Only the left-hand column gets declination labels.
    label_dec = b in ('u', 'i')
    fig = ss_plot.make_plot(coadd[b], proj='mcbryde', vmin=vmin, vmax=vmax, ax=axdict[b], title=f"DP2 band {b}", background=bg, label_dec=label_dec)
fig.tight_layout()
fig.savefig(os.path.join(out_dir, "dp2_coadd_band.png"), bbox_inches='tight')
# High-resolution g-band maps in a 5 degree radius around (ra=310, dec=-15).
run_calc = True
if run_calc:
    zoom_nvisits = {}
    zoom_coadd = {}
    m_nvis = maf.CountMetric(col='obs_start_mjd', metric_name="Nvisits")
    m_coadd = maf.Coaddm5Metric(m5_col='cat_m5')
    s, plot_dict = maf.make_circle_subset_slicer(ra_cen=310, dec_cen=-15, radius=5.0, nside=2*4*512, use_cache=True)
    s.lon_col = 's_ra'
    s.lat_col = 's_dec'
    # rubin_sim spatial slicers take the rotation column as `rot_sky_pos_col_name` (the keyword
    # used for the all-sky HealpixSlicer); a `rotSkyPosColName` attribute would presumably be ignored.
    s.rot_sky_pos_col_name = 'sky_rotation'
    # NOTE(review): confirm the slicer actually reads `camera_radius`.
    s.camera_radius = 1.94
    for b in ['g']:  # ['u', 'g', 'r', 'i', 'z', 'y', 'all']
        constraint = f"{b}"
        selected = svisits if b == 'all' else svisits.query("band == @b")
        opsvis = selected.to_records()
        # Replace missing depths (NaN) with the sentinel -666 before running the coadd metric.
        opsvis['cat_m5'] = np.where(np.isnan(opsvis['cat_m5']), -666, opsvis['cat_m5'])
        zoom_nvisits[b] = maf.MetricBundle(m_nvis, s, constraint, plot_dict=plot_dict)
        zoom_coadd[b] = maf.MetricBundle(m_coadd, s, constraint, plot_dict=plot_dict)
        g = maf.MetricBundleGroup({f'nvisits {b}': zoom_nvisits[b], f'coadd {b}': zoom_coadd[b]}, None)
        g.run_current(constraint, opsvis)
# q = svisits.query("band == 'r' and s_ra > 305 and s_ra < 315 and s_dec > -15  and s_dec < -5")
# q = q.query("day_obs == 20250701")

# theta = np.arange(0, np.pi*2, 0.001)
# rad = 1.94
# for i, pt in q.iterrows():
#     x = rad * np.cos(theta) + pt.s_ra
#     y = rad * np.sin(theta) + pt.s_dec
#     plt.plot(x, y, 'k-', alpha=0.1)
# rad = 1.75
# for i, pt in q.iterrows():
#     x = rad * np.cos(theta) + pt.s_ra
#     y = rad * np.sin(theta) + pt.s_dec
#     plt.plot(x, y, 'k-', alpha=0.1)
# plt.scatter(q.s_ra, q.s_dec, c=q.day_obs,  marker='o')
# plt.colorbar(label="dayobs")
# plt.xlabel("RA")
# plt.ylabel("Dec")
# plt.xlim(305, 315)
# plt.ylim(-15, -5)
# print('number of distinct dayobs', len(q.day_obs.unique()), 'number of visits', len(q))
# q.query("day_obs == 20250701")[['day_obs', 's_ra', 's_dec', 'seq_num', 'scheduler_note']]
# Zoomed g-band visit-count map (alternative color range used earlier: 10-13).
plot_dict.update({
    'figsize': (8, 8),
    'xsize': 350,
    'color_min': 0,
    'color_max': 10,
    'n_ticks': 7,
    'title': "dp2 g band fov_edge = 1.94deg",
})
ph = maf.PlotHandler(thumbnail=False, fig_format='png')
ph.set_metric_bundles([zoom_nvisits['g']])
fig = ph.plot(maf.HealpixSkyMap(), plot_dict)
# Median all-sky g-band visit count over observed pixels (displayed in the notebook).
np.nanmedian(nvisits['g'].metric_values.compressed())
# All-sky r-band coadded m5, drawn with the zoom plot_dict and a depth color range.
plot_dict['xsize'] = 700
plot_dict['color_min'] = 24.0
plot_dict['color_max'] = 26.0
plot_dict['n_ticks'] = 7
# Overwrite the title left over from the g-band visit-count plot.
plot_dict['title'] = "dp2 r band coadd m5"
ph = maf.PlotHandler(thumbnail=False, fig_format='png')
ph.set_metric_bundles([coadd['r']])
fig = ph.plot(maf.HealpixSkyMap(), plot_dict)
# Candidate DP2 sky regions as healpix masks (1 inside, NaN outside), at the nvisits map nside.
nside = nvisits['r'].slicer.nside
allsky = generate_all_sky(nside=nside)
eclip_lat = allsky["eclip_lat"]
eclip_lon = allsky["eclip_lon"]

# v1: |eclip_lat| < 10 and eclip_lon > 240 or < 40.
mask_v1 = np.where((np.abs(eclip_lat) < 10) & ((eclip_lon > 240) | (eclip_lon < 40)), 1, np.nan)
# v2: |eclip_lat| < 5 and eclip_lon > 285.
mask_v2 = np.where((np.abs(eclip_lat) < 5) & (eclip_lon > 285), 1, np.nan)
# v3: 300 < ra < 324, -26 < dec < -10 and eclip_lat < 5.
# NOTE(review): v3 uses the signed eclip_lat (not its absolute value) unlike v1/v2 -- confirm intended.
in_v3_box = (allsky["ra"] > 300) & (allsky["ra"] < 324) & (allsky["dec"] > -26) & (allsky["dec"] < -10)
mask_v3 = np.where(in_v3_box & (eclip_lat < 5), 1, np.nan)
# v4: wherever the ss_plot background map is not NaN.
bg = ss_plot.get_background(nside)
mask_v4 = np.where(np.isnan(bg), np.nan, 1)

pix_area = hp.nside2pixarea(nside, degrees=True)
for region_name, region_mask in (("v1", mask_v1), ("v2", mask_v2), ("v3", mask_v3), ("v4", mask_v4)):
    print(f"{region_name} area", np.count_nonzero(~np.isnan(region_mask)) * pix_area)
# Median Nvisits and coadded m5 per band inside each region. Masked (unobserved) pixels are
# filled with 0 before masking, so unobserved pixels inside a region count as 0 in the medians.
region_masks = {'LSST': mask_v4, '3k': mask_v1, '750': mask_v2, '300': mask_v3}
summary = {}
for region, region_mask in region_masks.items():
    nvisits_row = [np.nanmedian(nvisits[b].metric_values.filled(0) * region_mask) for b in nvisits]
    coadd_row = [np.nanmedian(coadd[b].metric_values.filled(0) * region_mask) for b in coadd]
    summary[region] = pd.DataFrame([nvisits_row, coadd_row],
                                   columns=list(nvisits.keys()), index=['Nvisits', 'CoaddM5'])
summary = pd.concat(summary)
display(summary.round(2))
rst_table = tabulate(pd.DataFrame(summary.round(1)), headers='keys', tablefmt='rst')
print(rst_table)
# Per-DDF visit-count and coadded-m5 maps (per band and all bands) from the DP2 visits.
ddfs = ddf_locations()
ddf_nvisits = {}
ddf_coadd = {}
for ddf in ddfs:
    # In ops, we'd need observation_reason probably.
    # NOTE(review): 'and' binds tighter than 'or', so the BLOCK-365 restriction applies only to the
    # exact-case name match; lowercase matches come from any program. The parentheses make the
    # existing behavior explicit -- confirm it is intended.
    dd_vis = svisits.query("(science_program == 'BLOCK-365' and target_name.str.contains(@ddf)) or target_name.str.contains(@ddf.lower())").copy()
    dd_vis = dd_vis.dropna(subset=['cat_m5', 's_ra'], axis=0)
    print(ddf, len(dd_vis))
    ddf_nvisits[ddf] = {}
    ddf_coadd[ddf] = {}
    m_nvis = maf.CountMetric(col='obs_start_mjd', metric_name="Nvisits")
    m_coadd = maf.Coaddm5Metric(m5_col='cat_m5')
    s, plot_dict = maf.make_circle_subset_slicer(ddfs[ddf][0], ddfs[ddf][1], radius=2.0, nside=512, use_cache=True)
    s.lon_col = 's_ra'
    s.lat_col = 's_dec'
    # rubin_sim spatial slicers take the rotation column as `rot_sky_pos_col_name` (the keyword
    # used for the all-sky HealpixSlicer); a `rotSkyPosColName` attribute would presumably be ignored.
    s.rot_sky_pos_col_name = 'sky_rotation'
    for b in ['u', 'g', 'r', 'i', 'z', 'y', 'all']:
        constraint = f"{b}"
        if b == 'all':
            opsvis = dd_vis.to_records()
        else:
            opsvis = dd_vis.query("band == @b").to_records()
        ddf_nvisits[ddf][b] = maf.MetricBundle(m_nvis, s, constraint, plot_dict=plot_dict, info_label=f"{ddf} {b} band")
        ddf_coadd[ddf][b] = maf.MetricBundle(m_coadd, s, constraint, plot_dict=plot_dict, info_label=f"{ddf} {b} band")
        g = maf.MetricBundleGroup({f'nvisits {b}': ddf_nvisits[ddf][b], f'coadd {b}': ddf_coadd[ddf][b]}, None)
        # Bands with no visits are left un-run (metric_values stays None).
        if len(opsvis) > 0:
            g.run_current(constraint, opsvis)
for ddf in ddfs:
    ddf_coadd[ddf]['r'].plot()

# Median per-pixel value of each DDF map: one nvisits row and one coadd row per DDF,
# with '-' for bands whose bundle was never run.
summary_bands = ['u', 'g', 'r', 'i', 'z', 'y', 'all']
summary = []
for ddf in ddfs:
    nval = ['-' if ddf_nvisits[ddf][band].metric_values is None
            else round(np.nanmedian(ddf_nvisits[ddf][band].metric_values.compressed()), 0)
            for band in summary_bands]
    cval = ['-' if ddf_coadd[ddf][band].metric_values is None
            else round(np.nanmedian(ddf_coadd[ddf][band].metric_values.compressed()), 2)
            for band in summary_bands]
    summary.append(pd.DataFrame([nval, cval], columns=summary_bands, index=[f"{ddf} nvisits", f"{ddf} coadd"]))
summary = pd.concat(summary)
# Blank the all-band coadd entries (presumably not meaningful across bands).
coaddrows = [row for row in summary.index.values if 'coadd' in row]
summary.loc[coaddrows, 'all'] = np.nan
display(summary.round(1))
print(summary.to_latex())
rst_table = tabulate(summary, headers="keys", tablefmt='rst')
print(rst_table)
# Per-DDF visit statistics: select every DP2 visit whose target_name mentions 'ddf'.
q = svisits.query("target_name.str.contains('ddf')").copy()
def build_ddf_name(x: pd.Series) -> str:
    """Return the DDF label for a single visit row.

    ``x.target_name`` is split on commas. The first piece that contains
    ``'ddf'`` is returned with all spaces removed. A visit that overlaps two
    DDF fields (ddf_edfs_a can overlap edfs_b, rarely) is assigned to the
    first one listed. Raises IndexError if no piece contains ``'ddf'``.
    """
    ddf_pieces = [piece.replace(" ", "") for piece in x.target_name.split(',') if 'ddf' in piece]
    first_piece = ddf_pieces[0]
    return first_piece
q['ddf_label'] = q.apply(build_ddf_name, axis=1)
# One row per DDF: visit count, median fwhm_eff, mean clouds, and the
# peak-to-peak span of obs_start_mjd.
vv = q.groupby('ddf_label').agg(**{
    "nvisits": ("seq_num", "count"),
    "median fwhm (arcsec)": ("fwhm_eff", "median"),
    "mean cloud (mag)": ("clouds", "mean"),
    "timespan (days)": ("obs_start_mjd", np.ptp),
})
# Truncate the span to whole days and add one so the first day is counted.
vv['timespan (days)'] = vv['timespan (days)'].astype(int) + 1
vv = vv.sort_values('nvisits')
display(vv.round(2))
# Backslash-escape underscores in the DDF labels for the markup output.
vv.index = [label.replace("_", "\\_") for label in vv.index]
# Missing values print as "nan"; show them as '---' in the RST table.
rst_table = tabulate(vv.round(2), headers='keys', tablefmt='rst').replace("nan", "---")
print(rst_table)
# Sky area reaching at least `t` visits, per band, as a function of the
# visit-count threshold `t`.
# Assumptions (defined earlier in the file, not visible here -- TODO confirm):
# nvisits[b].metric_values is a masked per-healpix visit-count map, and
# `scale` converts a pixel count into square degrees (per the y-axis label).
threshold = np.arange(1, 40, 1)
fig, ax = plt.subplots(figsize=(8, 6))
for b in 'ugrizy':
    # Fill masked (unobserved) pixels with 0 visits once per band, instead of
    # copying the whole map again for every threshold.
    nvis_map = nvisits[b].metric_values.filled(0)
    area = np.array([np.count_nonzero(nvis_map >= t) for t in threshold]) * scale
    ax.plot(threshold, area, label=b, color=band_colors[b])
# Horizontal reference lines at 3000 and 750 sq deg.
ax.axhline(3000, color='gray', linestyle=':')
ax.axhline(750, color='gray', linestyle=':')
ax.legend()
ax.set_xlabel("Number of visits", fontsize='x-large')
ax.set_ylabel("Area (sq degrees)", fontsize='x-large')
ax.set_yscale('log')
ax.set_xlim(0, None)
ax.grid(alpha=0.3)
plt.savefig(os.path.join(out_dir, "dp2_area_per_band.png"), bbox_inches='tight')
# Same area-vs-threshold curves, but with the x axis normalized per band by
# the visit count in y1nvis (the "Y1" reference named in the axis label).
# Assumes nvisits[b].metric_values is a masked per-healpix visit-count map and
# `scale` converts pixel counts to square degrees -- TODO confirm.
y1nvis = {'u': 6, 'g': 8, 'r': 18, 'i': 18, 'z': 16, 'y': 16}
threshold = np.arange(1, 30, 1)
fig, ax = plt.subplots(figsize=(8, 6))
for b in 'ugrizy':
    # Fill masked pixels with 0 once per band rather than once per threshold.
    nvis_map = nvisits[b].metric_values.filled(0)
    area = np.array([np.count_nonzero(nvis_map >= t) for t in threshold]) * scale
    ax.plot(threshold / y1nvis[b], area, label=b, color=band_colors[b])
# Horizontal reference lines at 3000 and 750 sq deg.
ax.axhline(3000, color='gray', linestyle=':')
ax.axhline(750, color='gray', linestyle=':')
ax.legend()
ax.set_xlabel("Number of visits / Y1 (per band)", fontsize='x-large')
ax.set_ylabel("Area (sq degrees)", fontsize='x-large')
ax.set_yscale('log')
ax.set_xlim(0, 1.5)
ax.grid(alpha=0.3)
plt.savefig(os.path.join(out_dir, "dp2_area_per_band_scaled.png"), bbox_inches='tight')
# plt.figure(figsize=(8, 6))

# gq = sv_visits.copy()
# gq.loc[:, 'night'] = np.floor(gq.obs_start_mjd - 0.5) - Time("2025-06-20T12:00:00", scale='tai').mjd + 1

# night_cut = gq.night.max() 
# night_cut = 100
# nights = np.arange(0, night_cut+1) 

# _ = plt.hist(sv_orig_visits.night, bins=nights, histtype='step', cumulative=True, label='sv_sim_1.0')
# _ = plt.hist(sv_osim_visits.night, bins=nights, histtype='step', cumulative=True, label='sv_20250620')
# _ = plt.hist(gq.night, bins=nights, histtype='step', cumulative=True, label='consdb', linewidth=2)

# plt.legend()

# plt.xlim(0, gq.night.max())
# plt.grid(alpha=0.5)
# plt.xlabel("Night of SV survey", fontsize='x-large')
# plt.ylabel("Cumulative number of visits", fontsize='x-large')
# plt.savefig(os.path.join(out_dir, "sv_cumulative.png"), bbox_inches='tight')
# Distribution of visit_gap for visits with a filter change.
q = svisits.query("filter_change == 1")
gaps = q.visit_gap
_ = plt.hist(gaps, bins=np.arange(50, 400, 5))
plt.xlabel("Visit gap in filter change (seconds)", fontsize='large')
min_gap, median_gap = gaps.min(), gaps.median()
plt.figtext(0.6, 0.8, f"Minimum visit_gap {min_gap:.1f}s")
plt.figtext(0.6, 0.75, f"Median visit_gap {median_gap:.1f}s")
# Vertical reference line at 120 s.
plt.axvline(120, color='black')
plt.title("DP2 survey visits with filter change")
# fault_idle statistics for visits with a filter change. The values are
# divided by 60 for minutes, so fault_idle is presumably in seconds.
# NaN-aware reductions are used so that a missing fault_idle does not turn
# every statistic into NaN.
q = svisits.query("filter_change == 1")
idle_s = q.fault_idle.values
# 90th percentile, i.e. the boundary of the longest 10% of idle times.
tenpercentile = np.nanpercentile(idle_s, 90)/60
print("tenpercentile value (min)", tenpercentile)
print('80%ile (s)', np.nanpercentile(idle_s, 80), '90%ile (min)', np.nanpercentile(idle_s, 90)/60, 'max minutes', np.nanmax(idle_s)/60)
# Total fault_idle (hours) summed over visits with fault_idle above 40 s.
print(q.query("fault_idle > 40").fault_idle.sum()/60/60)
# Fraction of these visits with fault_idle above 4 min, 15 min and 1 hour.
print(len(q.query("fault_idle > 240")) / len(q))
print('15 minute', len(q.query("fault_idle > 15*60")) / len(q))
print('1 hour', len(q.query("fault_idle > 60*60")) / len(q))
# day_obs values that contain a fault_idle above 20 minutes.
print(q.query("fault_idle > 20*60").day_obs.unique())
# Histogram of the total gap (overhead + fault_idle) for all survey visits.
q = svisits
total_gap = q.overhead + q.fault_idle
# (alternative binning previously tried: np.arange(0, 60*60, 40))
_ = plt.hist(total_gap, bins=100)
plt.yscale('log')
plt.ylabel("Number of visits", fontsize='x-large')
plt.xlabel("Visit gap all visits (seconds)", fontsize='large')
plt.figtext(0.5, 0.80, f"Median visit gap {total_gap.median():.1f}s")
plt.figtext(0.5, 0.75, f"Mean visit gap {total_gap.mean():.1f}s")
plt.title("All DP2 survey visits")
# Histogram of the estimated slew overhead for all survey visits.
q = svisits
slew_overhead = q.overhead
# (alternative binning previously tried: np.arange(0, 60, 2))
_ = plt.hist(slew_overhead, bins=100)
plt.yscale('log')
plt.ylabel("Number of visits", fontsize='x-large')
plt.xlabel("Estimated slew overhead all visits (seconds)", fontsize='large')
median_slew, mean_slew = slew_overhead.median(), slew_overhead.mean()
plt.figtext(0.5, 0.8, f"Median slew overhead {median_slew:.1f}s")
plt.figtext(0.5, 0.75, f"Mean slew overhead {mean_slew:.1f}s")
plt.title("All DP2 survey visits")
# fault_idle statistics for visits without a filter change. The values are
# divided by 60 for minutes, so fault_idle is presumably in seconds.
# NaN-aware reductions ignore visits with a missing fault_idle.
q = svisits.query("filter_change == 0")
idle_s = q.fault_idle.values
# 99th percentile of fault_idle, in minutes.
perc = np.nanpercentile(idle_s, 99)/60
print("99th percentile value (min)", perc)
print('80%ile (s)', np.nanpercentile(idle_s, 80), '90%ile (min)', np.nanpercentile(idle_s, 90)/60, 'max minutes', np.nanmax(idle_s)/60)
# Total fault_idle (hours) summed over visits with any fault_idle.
print(q.query("fault_idle > 0").fault_idle.sum()/60/60)
# Fraction of these visits with fault_idle above 4 min, 15 min and 1 hour.
print(len(q.query("fault_idle > 240")) / len(q))
print('15 minute', len(q.query("fault_idle > 15*60")) / len(q))
print('1 hour', len(q.query("fault_idle > 60*60")) / len(q))
# day_obs values that contain a fault_idle above 20 minutes.
print(q.query("fault_idle > 20*60").day_obs.unique())
# Histogram of the estimated slew overhead for visits without a filter change.
q = svisits.query("filter_change == 0")
# (alternative binning previously tried: np.arange(0, 60, 2))
_ = plt.hist(q.overhead, bins=100)
plt.yscale('log')
plt.ylabel("Number of visits", fontsize='x-large')
# The selection excludes filter-change visits, so the label must not say
# "all visits".
plt.xlabel("Estimated slew overhead, non-filter-change visits (seconds)", fontsize='large')
plt.figtext(0.5, 0.8, f"Median slew overhead {q.overhead.median():.1f}s")
plt.figtext(0.5, 0.75, f"Mean slew overhead {q.overhead.mean():.1f}s")
plt.title("All non-filter-change DP2 survey visits")