import os
import sqlite3
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import colorcet
import cycler
import skyproj
import healpy as hp
import scipy.stats
import rubin_sim.maf as maf
from rubin_sim.data import get_baseline
from rubin_scheduler.scheduler.utils import get_current_footprint
from astropy.time import Time
from astropy.coordinates import get_sun
from astropy.coordinates import EarthLocation
from rubin_scheduler.utils import SURVEY_START_MJD
# Rubin Observatory site; attached to the astropy Time objects built here.
RUBIN_LOC = EarthLocation.of_site("Cerro Pachon")
# MJD of 2024-01-01T00:00:00 UTC (not referenced in the code shown here).
MJD_2024 = Time(
    "2024-01-01T00:00:00.00", location=RUBIN_LOC, format="isot", scale="utc"
).utc.mjd
# Year length in days; converts elapsed days (plus phase) into seasons.
YEAR = 365.25
# Survey start MJD, taken as-is from rubin_scheduler.utils.
# NOTE(review): an earlier note here read "60796 = May 10, 2025", but
# MJD 60796 is 2025-05-01 -- confirm the value the installed
# rubin_scheduler actually provides.
BASELINE_SURVEY_START_MJD = SURVEY_START_MJD
def get_sun_ra_at_mjd(mjd):
    """Return the Sun's right ascension, in degrees, at the given MJD.

    Parameters
    ----------
    mjd : `float` or array-like
        Time(s) as Modified Julian Date; passed straight to astropy `Time`
        with the Rubin site location attached.

    Returns
    -------
    ra : `float` or `numpy.ndarray`
        Right ascension (degrees) of the coordinate returned by astropy's
        `get_sun`.
    """
    when = Time(mjd, format='mjd', location=RUBIN_LOC)
    sun_coord = get_sun(when)
    return sun_coord.ra.deg
def get_phase_for_ra_in_mjd(ra, start_time):
    """Return the season phase offset, in days, for a given RA.

    The angle (Sun RA at ``start_time`` minus ``ra``), converted to radians
    and shifted by pi/2, is wrapped into [0, 2*pi) and then scaled by
    ``YEAR / (2*pi)``.  The result therefore lies in [0, YEAR) days and is
    added to the elapsed time by `get_season`.

    Parameters
    ----------
    ra : `float` or array-like
        Right ascension in degrees.
    start_time : `float`
        Reference MJD (the survey start).
    """
    sun_ra_start = get_sun_ra_at_mjd(start_time)
    # Keep the exact arithmetic order of the original expression so the
    # floating-point result is unchanged.
    angle = (sun_ra_start - ra) / 180 * np.pi + np.pi / 2
    wrapped = angle % (2.0 * np.pi)
    days_per_radian = YEAR / 2.0 / np.pi
    return wrapped * days_per_radian
def get_season(mjd, ra, start_time):
    """Return the (fractional) season for each MJD at right ascension ``ra``.

    season = (mjd - start_time + phase(ra)) / YEAR, where the phase offset
    comes from `get_phase_for_ra_in_mjd`.  Taking the floor gives an
    integer season id.

    Parameters
    ----------
    mjd : `float` or array-like
        Visit times (MJD).
    ra : `float`
        Right ascension in degrees.
    start_time : `float`
        Reference MJD (the survey start).
    """
    phase_days = get_phase_for_ra_in_mjd(ra, start_time)
    elapsed = mjd - start_time
    return (elapsed + phase_days) / YEAR
from rubin_sim.maf import BaseMetric
from rubin_sim.maf import find_season_edges
from rubin_scheduler.utils import survey_start_mjd, calc_season
# Remove any earlier registrations of the two metric classes defined below
# from BaseMetric.registry.  Presumably this lets the file/cell be re-run
# without the registry rejecting a re-defined metric name -- confirm.
for _stale_metric in ('__main__.NVisPerSeasonMetric', '__main__.SeasonValMetric'):
    BaseMetric.registry.pop(_stale_metric, None)
class SeasonValMetric(BaseMetric):
    """Reduce the season values of the visits at a slice point to one number.

    Each visit whose exposure time is strictly greater than ``min_exp_time``
    gets a season value from ``calc_season``.  ``reduce_func`` then
    collapses that array to a single value.

    Parameters
    ----------
    min_exp_time : `float`, optional
        Visits must have exposure time (seconds) strictly greater than this
        to be counted.  Default 16, to skip near-sun twilight visits.
    reduce_func : callable, optional
        Turns the array of per-visit season values into a single number.
        Default np.median.
    mjd_start : `float` or None, optional
        Survey start MJD.  Default None uses ``SURVEY_START_MJD``.
    mjd_col : `str`, optional
        Name of the MJD column in the data.
    exp_time_col : `str`, optional
        Name of the visit exposure time column.
    **kwargs
        Passed to `BaseMetric`.  If ``metric_name`` is absent, the name
        defaults to ``Season_<reduce_func.__name__>`` (or ``SeasonVal``
        when the function has no ``__name__``).
    """

    def __init__(self, min_exp_time=16, reduce_func=np.median, mjd_start=None, mjd_col='observationStartMJD',
                 exp_time_col='visitExposureTime', **kwargs):
        self.mjd_start = SURVEY_START_MJD if mjd_start is None else mjd_start
        self.mjd_col = mjd_col
        self.exp_time_col = exp_time_col
        self.min_exp_time = min_exp_time
        self.reduce_func = reduce_func
        if 'metric_name' in kwargs:
            metric_name = kwargs.pop('metric_name')
        else:
            try:
                metric_name = f'Season_{reduce_func.__name__}'
            except AttributeError:
                # e.g. functools.partial objects carry no __name__
                metric_name = 'SeasonVal'
        super().__init__(
            col=[self.mjd_col, self.exp_time_col],
            units="",
            metric_name=metric_name,
            **kwargs,
        )

    def run(self, data_slice, slice_point):
        # Drop visits that are too short; nothing left means no value.
        keep = data_slice[self.exp_time_col] > self.min_exp_time
        if not keep.any():
            return self.badval
        visits = np.sort(data_slice[keep], order=self.mjd_col)
        # slice_point RA is in radians; calc_season wants degrees.
        ra_deg = np.degrees(slice_point["ra"])
        seasons = calc_season(ra_deg, visits[self.mjd_col], self.mjd_start)
        return self.reduce_func(seasons)
class NVisPerSeasonMetric(BaseMetric):
    """Calculate per-season statistics at each slice point.

    Only visits whose exposure time is strictly greater than
    ``min_exp_time`` are used.  Each visit gets a season from
    ``calc_season``, and the floor of that value is its integer season id.

    ``run`` returns a dict of arrays with one entry per season id, in
    increasing season order:

    - ``n_per_season``: number of visits.
    - ``max_gap_per_season`` / ``median_gap_per_season``: maximum / median
      difference between consecutive distinct nights with visits.  This is
      badval when a season has only one distinct night.
    - ``season_length``: MJD of the last minus the first visit, using the
      indices from ``find_season_edges``.
    - ``season_id``: the integer season ids themselves.

    It also returns ``alt_n_per_season``, the visits per season counted with
    the alternate ``get_season`` definition.  That array is ordered by its
    own season ids and may not align with ``season_id``.

    Parameters
    ----------
    min_exp_time : `float`, optional
        Visits must have exposure time (seconds) strictly greater than this
        to be counted.  Default 16.
    mjd_start : `float` or None, optional
        The MJD of the start of the survey (days).
        Default None uses ``SURVEY_START_MJD``.
    mjd_col : `str`, optional
        Name of the MJD column.
    exp_time_col : `str`, optional
        Name of the visit exposure time column.
    night_col : `str`, optional
        Name of the night column.
    metric_name : `str`, optional
        Metric name (default "NVisPerSeason").
    **kwargs
        Passed to `BaseMetric`.
    """

    def __init__(
        self,
        min_exp_time=16,
        mjd_start=None,
        mjd_col="observationStartMJD",
        exp_time_col="visitExposureTime",
        night_col="night",
        metric_name="NVisPerSeason",
        **kwargs,
    ):
        units = "#"
        if mjd_start is None:
            mjd_start = SURVEY_START_MJD
        self.mjd_start = mjd_start
        self.mjd_col = mjd_col
        self.exp_time_col = exp_time_col
        self.night_col = night_col
        self.min_exp_time = min_exp_time
        super().__init__(
            col=[self.mjd_col, self.exp_time_col, self.night_col], units=units,
            metric_name=metric_name, metric_dtype='object', **kwargs
        )

    def run(self, data_slice, slice_point):
        # Keep only visits longer than min_exp_time.
        long = np.where(data_slice[self.exp_time_col] > self.min_exp_time)
        if len(long[0]) == 0:
            return self.badval
        # Sort the remaining visits in order of MJD.
        data = np.sort(data_slice[long], order=self.mjd_col)
        mjds = data[self.mjd_col]
        # slice_point RA is in radians; both season helpers take degrees.
        ra_deg = np.degrees(slice_point["ra"])

        # Fractional season of every visit; its floor is the season id.
        seasons = calc_season(ra_deg, mjds, self.mjd_start)
        int_seasons = np.floor(seasons)
        # Sorted unique season ids and the number of visits in each.
        season_ids, n_per_season = np.unique(int_seasons, return_counts=True)

        # Max / median gap between consecutive distinct nights, per season.
        nights = data[self.night_col]
        max_nightgap_per_season = []
        median_nightgap_per_season = []
        for si in season_ids:
            night_gaps = np.diff(np.unique(nights[int_seasons == si]))
            if len(night_gaps) == 0:
                # A single distinct night: no gap is defined.
                max_nightgap_per_season.append(self.badval)
                median_nightgap_per_season.append(self.badval)
            else:
                max_nightgap_per_season.append(night_gaps.max())
                median_nightgap_per_season.append(np.median(night_gaps))
        max_nightgap_per_season = np.array(max_nightgap_per_season)
        median_nightgap_per_season = np.array(median_nightgap_per_season)

        # Season length from the first/last visit indices of each season.
        first_of_season, last_of_season = find_season_edges(seasons)
        seasonlengths = mjds[last_of_season] - mjds[first_of_season]

        # Count seasons in the alternate way used in the FBS (get_season).
        alt_season = get_season(mjds, ra_deg, self.mjd_start)
        _, alt_n_per_season = np.unique(np.floor(alt_season), return_counts=True)

        return {'n_per_season': n_per_season,
                'max_gap_per_season': max_nightgap_per_season,
                'median_gap_per_season': median_nightgap_per_season,
                'alt_n_per_season': alt_n_per_season,
                'season_length': seasonlengths,
                'season_id': season_ids}

    def _n_visits_in_season(self, metric_value, season_id):
        """Return the visit count for ``season_id``, or badval if absent.

        Deliberately not named ``reduce_*`` so it is not treated as a
        reduce function itself.
        """
        idx = np.where(metric_value['season_id'] == season_id)[0]
        if len(idx) == 0:
            return self.badval
        # season_id comes from np.unique, so there is at most one match.
        # Return a scalar rather than a 1-element array.
        return metric_value['n_per_season'][idx[0]]

    def reduce_min(self, metric_value):
        """Minimum visits per season, excluding the first and last seasons.

        Those seasons could be short.  Returns badval when fewer than three
        seasons exist; np.min on the empty slice would raise.
        """
        inner = metric_value['n_per_season'][1:-1]
        if len(inner) == 0:
            return self.badval
        return np.min(inner)

    def reduce_max(self, metric_value):
        """Maximum visits in any season (first and last seasons included)."""
        return np.max(metric_value['n_per_season'])

    def reduce_season_m1(self, metric_value):
        """Visits in season -1 (badval if absent)."""
        return self._n_visits_in_season(metric_value, -1)

    def reduce_season_0(self, metric_value):
        """Visits in season 0 (badval if absent)."""
        return self._n_visits_in_season(metric_value, 0)

    def reduce_season_1(self, metric_value):
        """Visits in season 1 (badval if absent)."""
        return self._n_visits_in_season(metric_value, 1)

    def reduce_season_2(self, metric_value):
        """Visits in season 2 (badval if absent)."""
        return self._n_visits_in_season(metric_value, 2)

    def reduce_season_3(self, metric_value):
        """Visits in season 3 (badval if absent)."""
        return self._n_visits_in_season(metric_value, 3)

    def reduce_season_4(self, metric_value):
        """Visits in season 4 (badval if absent)."""
        return self._n_visits_in_season(metric_value, 4)

    def reduce_season_5(self, metric_value):
        """Visits in season 5 (badval if absent)."""
        return self._n_visits_in_season(metric_value, 5)

    def reduce_season_6(self, metric_value):
        """Visits in season 6 (badval if absent)."""
        return self._n_visits_in_season(metric_value, 6)

    def reduce_season_7(self, metric_value):
        """Visits in season 7 (badval if absent)."""
        return self._n_visits_in_season(metric_value, 7)

    def reduce_season_8(self, metric_value):
        """Visits in season 8 (badval if absent)."""
        return self._n_visits_in_season(metric_value, 8)

    def reduce_season_9(self, metric_value):
        """Visits in season 9 (badval if absent)."""
        return self._n_visits_in_season(metric_value, 9)
from rubin_scheduler.scheduler.utils import get_current_footprint
# Input opsim database.  Alternative: opsim_fname = get_baseline()
opsim_fname = 'baseline_v4.3.2_10yrs.db'
print(opsim_fname)
outdir = "season_figs"
# Run name: the database file name with '.db' removed.
runName = os.path.basename(opsim_fname).replace('.db', '')
print(runName)
# A single test point (RA=30, Dec=-20 deg) was tried first; the analysis
# below uses the extragalactic-WFD footprint pixels instead.
test_ra, test_dec = 30.0, -20.0
test_slicer = maf.UserPointsSlicer(test_ra, test_dec)
footprints, labels = get_current_footprint(nside=64)
# Footprint labels that make up the extragalactic WFD area.
wfdhpid = np.where(
    np.isin(labels, ['lowdust', 'euclid_overlap', 'LMC_SMC', 'virgo'])
)[0]
print(np.unique(labels))
wfd_slicer = maf.HealpixSubsetSlicer(nside=64, hpid=wfdhpid)
# Alternative selection: gri visits only.
#constraint = "filter == 'g' or filter == 'r' or filter == 'i'"
#label = "gri bands"
constraint, label = "", ""
mymetric = NVisPerSeasonMetric()
summary_metrics = [maf.MedianMetric()]
bundle = maf.MetricBundle(
    mymetric,
    wfd_slicer,
    constraint,
    run_name=runName,
    summary_metrics=summary_metrics,
)
g = maf.MetricBundleGroup(
    {'nvals': bundle},
    opsim_fname,
    out_dir=outdir,
    results_db=None,
    verbose=True,
)
# Calculate the metric.
g.run_all()
# Print every bundle's summary values; the reduce-function bundles
# (e.g. NVisPerSeason_min) appear here too, which shows their names.
for bundle_name, metric_bundle in g.bundle_dict.items():
    print(bundle_name, metric_bundle.summary_values)

# Quick histogram of visits per season, pooled over all unmasked pixels,
# leaving out each pixel's first and last season.
vals = np.hstack(
    [mv['n_per_season'][1:-1] for mv in bundle.metric_values.compressed()]
)
_ = plt.hist(vals, bins=np.arange(0, 200, 2))
#_ = plt.hist(vals, bins=np.arange(0, 120, 1))
plt.grid(True, alpha=0.3)
#plt.axvline(14, color='k')
#plt.axvline(65, color='k')
plt.title(f"{runName} {label}")
plt.xlabel("Number of visits per season")
#plt.savefig("rolling_cadence_nvisits_per_season.png")
# Try to identify on/off seasons automatically: per pixel and season,
# fewer than `low` visits = "off" (-1), more than `high` = "on" (+1),
# anything in between = "average" (0).  Masked pixels and seasons with no
# visits get hp.UNSEEN.
low = 45
high = 105
season_maps = {}
for sid in np.arange(-1, 10, 1):
    vals = []
    for mv in bundle.metric_values.filled(0):
        if mv == 0:
            # Masked pixel (filled with 0).
            vals.append(hp.UNSEEN)
            continue
        match = np.where(mv['season_id'] == sid)[0]
        if len(match) == 0:
            vals.append(hp.UNSEEN)
            continue
        n_in_season = mv['n_per_season'][match][0]
        vals.append(-1 if n_in_season < low else (1 if n_in_season > high else 0))
    vals = np.array(vals)
    season_maps[sid] = vals
hp.mollview(season_maps[5])
# Per pixel, count the seasons that are "on" (> high visits) and "off"
# (< low visits).  Masked pixels keep their initial zero counts.
n_on = np.zeros(len(bundle.slicer))
n_off = np.zeros(len(bundle.slicer))
for pix, mv in enumerate(bundle.metric_values.filled(0)):
    if mv == 0:
        continue
    per_season = mv['n_per_season']
    n_on[pix] = np.count_nonzero(per_season > high)
    n_off[pix] = np.count_nonzero(per_season < low)
hp.mollview(n_on, min=0, max=4)
# Histogram one per-season quantity, selected by `k`, over every
# (pixel, season) pair.  The pairs are split into "on" seasons
# (> high visits), "off" seasons (< low visits) and the rest.
# First and last seasons are included here.
# Start a new figure.  When run as a plain script, the current figure at
# this point is the last hp.mollview map, and the histograms would
# otherwise be drawn on top of it.
plt.figure()
vals_on = []
vals_mid = []
vals_off = []
k = 'season_length'
bins = np.arange(0, 200, 10)
for i in bundle.metric_values.compressed():
    nvis = i['n_per_season']
    on = np.where(nvis > high)
    vals_on.append(i[k][on])
    off = np.where(nvis < low)
    vals_off.append(i[k][off])
    mid = np.where((nvis >= low) & (nvis <= high))
    vals_mid.append(i[k][mid])
vals_on = np.hstack(vals_on)
vals_mid = np.hstack(vals_mid)
vals_off = np.hstack(vals_off)
_ = plt.hist(vals_on, bins=bins, alpha=0.3, label=f"High - median={np.median(vals_on)}")
_ = plt.hist(vals_off, bins=bins, alpha=0.3, label=f"Low - median={np.median(vals_off)}")
_ = plt.hist(vals_mid, bins=bins, alpha=0.3, label=f"Mean - median={np.median(vals_mid)}")
plt.grid(True, alpha=0.3)
plt.legend()
#plt.axvline(14, color='k')
#plt.axvline(65, color='k')
plt.title(f"{runName} {label}")
# Axis label follows `k`.  It used to be hard-coded to the median-gap label
# even though season_length was being plotted.
_xlabel_for_key = {
    'season_length': "Season length (days)",
    'median_gap_per_season': "Median Gap Between Nights Per Season",
    'max_gap_per_season': "Max Gap Between Nights Per Season",
    'n_per_season': "Number of visits per season",
}
plt.xlabel(_xlabel_for_key.get(k, k))
# NOTE(review): this output file name dates from when k was the median
# night gap; it is kept unchanged so existing outputs are not renamed.
plt.savefig("rolling_cadence_median_nightgap.png")
print(np.median(vals_on), np.median(vals_mid), np.median(vals_off))
# Standard metric-bundle sky map of the "min" reduce function: the minimum
# number of visits per season, with first and last seasons excluded.
k = 'NVisPerSeason_min'
g.bundle_dict[k].set_plot_dict(dict(color_min=0, color_max=50, x_min=0, x_max=200))
g.bundle_dict[k].set_plot_funcs([maf.HealpixSkyMap()])
g.bundle_dict[k].plot()
#print(g.bundle_dict[k].summary_values)
# Map one mid-survey season (season id 3) across the sky.
# NOTE(review): the original notes here discuss season *length*, and
# wonder why region boundaries show long seasons.  However, this map
# plots `n_per_season` (visits); use i['season_length'] and adjust
# min/max if length was intended.
sid = 3
vals = []
for mv in bundle.metric_values.filled(0):
    value = hp.UNSEEN
    if mv != 0:
        match = np.where(mv['season_id'] == sid)[0]
        if len(match) > 0:
            value = mv['n_per_season'][match][0]
    vals.append(value)
vals = np.array(vals)
hp.mollview(vals, min=40, max=110)
# Sky maps of visits in each single season (reduce functions season_m1,
# season_0 ... season_9), saved to disk.
for season_tag in ['m1'] + [str(s) for s in range(10)]:
    k = f'NVisPerSeason_season_{season_tag}'
    season_bundle = g.bundle_dict[k]
    season_bundle.set_plot_dict(dict(color_min=0, color_max=200, x_min=0, x_max=200))
    season_bundle.set_plot_funcs([maf.HealpixSkyMap()])
    season_bundle.plot(savefig=True)