AGNBoost Basic Usage Tutorial
This notebook demonstrates the basic workflow for using AGNBoost to predict AGN fractions from photometric data. We'll walk through:
- Loading astronomical data with the Catalog class
- Exploring the dataset structure and properties
- Splitting data into training, validation, and test sets
- Cleaning the data by removing rows with missing values
- Creating engineered features for modeling
- Loading a pre-trained AGN fraction model
- Making predictions with uncertainty quantification
- Evaluating model performance
Let's start by importing the necessary libraries and loading our data.
%load_ext autoreload
%autoreload 2
# Set agnboost folder as root
import os
# Navigate to the repository root (parent directory of notebooks/)
os.chdir('..')
# Import necessary libraries
import numpy as np
import pandas as pd
from agnboost import dataset, model
# Set random seed for reproducibility
np.random.seed(123)
print("AGNBoost Basic Usage Tutorial")
print("=" * 40)
2025-08-08 00:37:29.402 | INFO | agnboost.config:<module>:11 - PROJ_ROOT path is: /home/kurt/Documents/agnboost
AGNBoost Basic Usage Tutorial
========================================
Loading the Data
We'll use the Catalog class to load our astronomical dataset. The cigale_mock_small.csv file contains a small set of mock NIRCam+MIRI CIGALE galaxies for demonstration purposes.
# Load the astronomical data using the Catalog class
catalog = dataset.Catalog(path="data/cigale_mock_small.csv", summarize=False)
Current working directory: /home/kurt/Documents/agnboost
Looking for bands file at: /home/kurt/Documents/agnboost/agnboost/allowed_bands.json
[INFO] Loaded bands file metadata: This file contains the allowed photometric bands for JWST
[INFO] Loaded 11 allowed bands from agnboost/allowed_bands.json
[INFO] Attempting to load file with delimiter: ','
[INFO] Successfully loaded data with 1000 rows.
[INFO] Found 11 valid band columns:
[INFO] - jwst.nircam.F115W (F115W): 1.154 μm
[INFO] - jwst.nircam.F150W (F150W): 1.501 μm
[INFO] - jwst.nircam.F200W (F200W): 1.988 μm
[INFO] - jwst.nircam.F277W (F277W): 2.776 μm
[INFO] - jwst.nircam.F356W (F356W): 3.565 μm
[INFO] - jwst.nircam.F410M (F410M): 4.083 μm
[INFO] - jwst.nircam.F444W (F444W): 4.402 μm
[INFO] - jwst.miri.F770W (F770W): 7.7 μm
[INFO] - jwst.miri.F1000W (F1000W): 10.0 μm
[INFO] - jwst.miri.F1500W (F1500W): 15.0 μm
[INFO] - jwst.miri.F2100W (F2100W): 21.0 μm
catalog.allowed_bands
{'jwst.nircam.F115W': {'shorthand': 'F115W', 'wavelength': 1.154}, 'jwst.nircam.F150W': {'shorthand': 'F150W', 'wavelength': 1.501}, 'jwst.nircam.F200W': {'shorthand': 'F200W', 'wavelength': 1.988}, 'jwst.nircam.F277W': {'shorthand': 'F277W', 'wavelength': 2.776}, 'jwst.nircam.F356W': {'shorthand': 'F356W', 'wavelength': 3.565}, 'jwst.nircam.F410M': {'shorthand': 'F410M', 'wavelength': 4.083}, 'jwst.nircam.F444W': {'shorthand': 'F444W', 'wavelength': 4.402}, 'jwst.miri.F770W': {'shorthand': 'F770W', 'wavelength': 7.7}, 'jwst.miri.F1000W': {'shorthand': 'F1000W', 'wavelength': 10.0}, 'jwst.miri.F1500W': {'shorthand': 'F1500W', 'wavelength': 15.0}, 'jwst.miri.F2100W': {'shorthand': 'F2100W', 'wavelength': 21.0}}
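The allowed_bands attribute is a plain dictionary mapping each column name to its shorthand and wavelength in μm. As a small illustration of working with that structure, the sketch below sorts bands by wavelength. It hard-codes a four-band excerpt of the dictionary above so it runs on its own; in the notebook you would pass catalog.allowed_bands instead. The helper name bands_by_wavelength is hypothetical and not part of AGNBoost.

```python
# Excerpt of the allowed_bands structure shown above:
# {column_name: {'shorthand': str, 'wavelength': float (μm)}}
allowed_bands = {
    'jwst.miri.F770W': {'shorthand': 'F770W', 'wavelength': 7.7},
    'jwst.nircam.F115W': {'shorthand': 'F115W', 'wavelength': 1.154},
    'jwst.nircam.F444W': {'shorthand': 'F444W', 'wavelength': 4.402},
    'jwst.miri.F2100W': {'shorthand': 'F2100W', 'wavelength': 21.0},
}


def bands_by_wavelength(bands):
    """Return (column, shorthand, wavelength) tuples sorted by wavelength (hypothetical helper)."""
    return sorted(
        ((col, info['shorthand'], info['wavelength']) for col, info in bands.items()),
        key=lambda item: item[2],
    )


for col, short, wl in bands_by_wavelength(allowed_bands):
    print(f"{short:>7s} {wl:7.3f} μm  ({col})")
```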
Exploring the Dataset
Let's examine the structure of our data to understand what photometric bands are available and get basic statistics about our dataset. The print_data_summary() method provides comprehensive information about:
- Dataset dimensions and memory usage
- Photometric band validation and metadata
- Column-by-column statistics including missing values
- Summary statistics for numerical columns
This information helps us understand data quality and identify any potential issues before modeling.
# Display comprehensive data summary
catalog.print_data_summary()
# Check which photometric bands were validated
valid_bands = catalog.get_valid_bands()
print(f"\nValid photometric bands found: {len(valid_bands)}")
for band_name, info in valid_bands.items():
    print(f" {band_name}: {info['shorthand']} ({info['wavelength']} μm)")
# Check if our target variable exists
target_column = 'agn.fracAGN'
if target_column in catalog.get_data().columns:
    print(f"\nTarget variable '{target_column}' found in dataset")
    target_stats = catalog.get_data()[target_column].describe()
    print("Target variable statistics:")
    print(target_stats)
else:
    print(f"Warning: Target variable '{target_column}' not found in dataset")
    print("Available columns:", list(catalog.get_data().columns))
================================================================================
DATA SUMMARY: cigale_mock_small.csv
================================================================================
Dimensions: 1000 rows × 26 columns
Memory usage: 0.20 MB
--------------------------------------------------------------------------------
Valid Band Columns:
--------------------------------------------------------------------------------
Column Name          Shorthand  Wavelength (μm)
--------------------------------------------------------------------------------
jwst.nircam.F115W    F115W      1.154
jwst.nircam.F150W    F150W      1.501
jwst.nircam.F200W    F200W      1.988
jwst.nircam.F277W    F277W      2.776
jwst.nircam.F356W    F356W      3.565
jwst.nircam.F410M    F410M      4.083
jwst.nircam.F444W    F444W      4.402
jwst.miri.F770W      F770W      7.700
jwst.miri.F1000W     F1000W     10.000
jwst.miri.F1500W     F1500W     15.000
jwst.miri.F2100W     F2100W     21.000
--------------------------------------------------------------------------------
Column Information:
--------------------------------------------------------------------------------
Column Name          Type      Non-Null    Null %
--------------------------------------------------------------------------------
IRAC1                float64   1000/1000   0.00%
IRAC2                float64   1000/1000   0.00%
IRAC3                float64   1000/1000   0.00%
IRAC4                float64   1000/1000   0.00%
hst.acs.wfc.F606W    float64   1000/1000   0.00%
hst.acs.wfc.F814W    float64   1000/1000   0.00%
hst.wfc3.ir.F125W    float64   1000/1000   0.00%
hst.wfc3.ir.F140W    float64   1000/1000   0.00%
hst.wfc3.ir.F160W    float64   1000/1000   0.00%
jwst.miri.F1000W     float64   1000/1000   0.00%
jwst.miri.F1280W     float64   1000/1000   0.00%
jwst.miri.F1500W     float64   1000/1000   0.00%
jwst.miri.F1800W     float64   1000/1000   0.00%
jwst.miri.F2100W     float64   1000/1000   0.00%
jwst.miri.F770W      float64   1000/1000   0.00%
jwst.nircam.F115W    float64   1000/1000   0.00%
jwst.nircam.F150W    float64   1000/1000   0.00%
jwst.nircam.F200W    float64   1000/1000   0.00%
jwst.nircam.F277W    float64   1000/1000   0.00%
jwst.nircam.F356W    float64   1000/1000   0.00%
jwst.nircam.F410M    float64   1000/1000   0.00%
jwst.nircam.F444W    float64   1000/1000   0.00%
sfh.sfr100Myrs       float64   1000/1000   0.00%
stellar.m_star       float64   1000/1000   0.00%
agn.fracAGN          float64   1000/1000   0.00%
universe.redshift    float64   1000/1000   0.00%
--------------------------------------------------------------------------------
Numeric Column Statistics:
--------------------------------------------------------------------------------
Column               Mean        Std         Min         Max
--------------------------------------------------------------------------------
IRAC1                57.9        1308        2.413e-06   4.098e+04
IRAC2                22.97       509.9       8.821e-07   1.596e+04
IRAC3                39.96       918         1.646e-06   2.879e+04
IRAC4                57.92       1309        2.413e-06   4.099e+04
hst.acs.wfc.F606W    0.311       5.52        0           169
hst.acs.wfc.F814W    0.3093      5.099       5.455e-13   155.8
hst.wfc3.ir.F125W    0.5148      6.576       1.614e-09   192.2
hst.wfc3.ir.F140W    0.6132      7.125       2.611e-09   196.7
hst.wfc3.ir.F160W    0.7412      7.991       4.119e-09   200.2
jwst.miri.F1000W     57.54       1356        3.049e-06   4.257e+04
jwst.miri.F1280W     71.4        1587        4.006e-06   4.97e+04
jwst.miri.F1500W     74.16       1638        4.475e-06   5.129e+04
jwst.miri.F1800W     82.2        1710        4.232e-06   5.339e+04
jwst.miri.F2100W     87.79       1773        4.001e-06   5.527e+04
jwst.miri.F770W      58.58       1315        2.288e-06   4.117e+04
jwst.nircam.F115W    0.461       6.317       1.192e-09   188.6
jwst.nircam.F150W    0.706       7.721       3.693e-09   198.7
jwst.nircam.F200W    1.482       15.31       1.959e-08   280.6
jwst.nircam.F277W    4.441       68.46       1.332e-07   2009
jwst.nircam.F356W    11.62       225.3       4.48e-07    6973
jwst.nircam.F410M    17.51       373         6.242e-07   1.164e+04
jwst.nircam.F444W    21.65       477         8.29e-07    1.492e+04
sfh.sfr100Myrs       4.765       4.403       4.765e-27   15.79
stellar.m_star       3.51e+09    2.551e+09   3.367e+07   7.388e+09
agn.fracAGN          0.4993      0.3164      0           0.99
universe.redshift    1.765       1.811       0.01        7.999
================================================================================

Valid photometric bands found: 11
 jwst.nircam.F115W: F115W (1.154 μm)
 jwst.nircam.F150W: F150W (1.501 μm)
 jwst.nircam.F200W: F200W (1.988 μm)
 jwst.nircam.F277W: F277W (2.776 μm)
 jwst.nircam.F356W: F356W (3.565 μm)
 jwst.nircam.F410M: F410M (4.083 μm)
 jwst.nircam.F444W: F444W (4.402 μm)
 jwst.miri.F770W: F770W (7.7 μm)
 jwst.miri.F1000W: F1000W (10.0 μm)
 jwst.miri.F1500W: F1500W (15.0 μm)
 jwst.miri.F2100W: F2100W (21.0 μm)

Target variable 'agn.fracAGN' found in dataset
Target variable statistics:
count    1000.000000
mean        0.499330
std         0.316352
min         0.000000
25%         0.200000
50%         0.500000
75%         0.800000
max         0.990000
Name: agn.fracAGN, dtype: float64
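The summary lists 26 columns, but only the 11 validated JWST bands are recognized as photometric bands; the HST, Spitzer/IRAC, MIRI F1280W and F1800W columns and the physical parameters are not. A stdlib sketch that lists the columns left out, with the names copied from the summary above so it runs standalone. In the notebook you could build the same lists from catalog.get_data().columns and catalog.allowed_bands; the variable names here are illustrative.

```python
# Column names copied from the data summary above, in catalog order.
columns = [
    'IRAC1', 'IRAC2', 'IRAC3', 'IRAC4',
    'hst.acs.wfc.F606W', 'hst.acs.wfc.F814W',
    'hst.wfc3.ir.F125W', 'hst.wfc3.ir.F140W', 'hst.wfc3.ir.F160W',
    'jwst.miri.F1000W', 'jwst.miri.F1280W', 'jwst.miri.F1500W',
    'jwst.miri.F1800W', 'jwst.miri.F2100W', 'jwst.miri.F770W',
    'jwst.nircam.F115W', 'jwst.nircam.F150W', 'jwst.nircam.F200W',
    'jwst.nircam.F277W', 'jwst.nircam.F356W', 'jwst.nircam.F410M',
    'jwst.nircam.F444W',
    'sfh.sfr100Myrs', 'stellar.m_star', 'agn.fracAGN', 'universe.redshift',
]
# The 11 allowed band columns reported by the Catalog.
allowed = {
    'jwst.nircam.F115W', 'jwst.nircam.F150W', 'jwst.nircam.F200W',
    'jwst.nircam.F277W', 'jwst.nircam.F356W', 'jwst.nircam.F410M',
    'jwst.nircam.F444W', 'jwst.miri.F770W', 'jwst.miri.F1000W',
    'jwst.miri.F1500W', 'jwst.miri.F2100W',
}

# Keep catalog order while filtering out the recognized bands.
unused = [c for c in columns if c not in allowed]
print(f"{len(columns) - len(unused)} band columns, {len(unused)} other columns:")
for c in unused:
    print(" ", c)
```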
Creating Train/Test/Validation Splits
Before any modeling, we need to split our data into separate training, validation, and test sets. AGNBoost's data splitting supports optional stratification, so that each split is representative of the full sample.
We'll use the default split ratios:
- 60% for training
- 20% for validation
- 20% for testing
The random_state argument makes the split reproducible. This step is optional: if the data have not been split explicitly, AGNBoost performs the split internally.
# Create train/validation/test splits
catalog.split_data(test_size=0.2, val_size=0.2, random_state=42)
# Get split information
print("Data split summary:")
print(f" Training: {catalog.get_train_len()}")
print(f" Validation: {catalog.get_val_len()}")
print(f" Testing: {catalog.get_test_len()}")
Data split summary:
 Training: 600
 Validation: 200
 Testing: 200
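To make the 60/20/20 arithmetic concrete, here is a minimal, unstratified sketch of a shuffled index split with numpy, where test_size and val_size are fractions of the full dataset. It is an illustration only: AGNBoost's internal splitting (including its stratification and seeding) is not shown in this notebook, so the indices will not match those chosen by split_data(). The helper split_indices is hypothetical.

```python
import numpy as np


def split_indices(n_rows, test_size=0.2, val_size=0.2, seed=42):
    """Shuffle row indices and cut them into train/val/test index arrays.

    test_size and val_size are fractions of the full dataset; the
    remainder goes to training. No stratification is applied.
    """
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n_rows)
    n_test = int(round(test_size * n_rows))
    n_val = int(round(val_size * n_rows))
    test_idx = idx[:n_test]
    val_idx = idx[n_test:n_test + n_val]
    train_idx = idx[n_test + n_val:]
    return train_idx, val_idx, test_idx


train_idx, val_idx, test_idx = split_indices(1000)
print(len(train_idx), len(val_idx), len(test_idx))  # 600 200 200
```

Slicing a single permutation guarantees the three sets are disjoint and together cover every row.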
Cleaning the Data
Real astronomical datasets often contain missing values due to various observational limitations. Before training or making predictions, we will remove rows that have NaN values in critical columns.
The drop_nan() method removes rows with missing values in the validated photometric band columns, ensuring our model receives complete data for all features.
# Drop rows with NaN values in the validated columns
catalog.drop_nan(inplace=True)
[INFO] No rows with NaN values found in the specified columns.
There are no rows with NaN values to remove, since the CIGALE mock data we loaded contains none, but your real data might.
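Conceptually, removing rows with NaNs in the band columns corresponds to pandas' DataFrame.dropna with a subset argument. The sketch below shows this on a toy DataFrame, assuming drop_nan behaves this way; its internals are not shown in this notebook. Only NaNs in the listed band columns cause a row to be dropped.

```python
import numpy as np
import pandas as pd

# Toy catalog: two band columns plus a target, with one missing flux.
df = pd.DataFrame({
    'jwst.nircam.F444W': [1.2, np.nan, 3.4],
    'jwst.miri.F770W': [0.5, 0.7, 0.9],
    'agn.fracAGN': [0.1, 0.5, 0.8],
})
band_columns = ['jwst.nircam.F444W', 'jwst.miri.F770W']

# Drop rows with a NaN in any band column; NaNs in other columns are ignored.
n_before = len(df)
df_clean = df.dropna(subset=band_columns)
print(f"Removed {n_before - len(df_clean)} of {n_before} rows")
```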
Creating Features¶
AGNBoost automatically engineers features from photometric data, including colors and transformations. Let's create the feature dataframe that will be used for modeling.
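By default, AGNBoost creates a feature set consisting of the photometric bands, the derived colors, and the squares of those colors. With 11 bands this gives 11 bands + 55 colors (one per pair of bands) + 55 squared colors = 121 features, matching the output below.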
# Create features for modeling
catalog.create_feature_dataframe()
# Get information about created features
features = catalog.get_features()
print("Feature engineering complete:")
print(f" Feature dataframe shape: {features.shape}")
[INFO] Created feature dataframe with 121 columns and 1000 rows.
[INFO] Created features are: ['jwst.nircam.F115W', 'jwst.nircam.F150W', 'jwst.nircam.F200W', 'jwst.nircam.F277W', 'jwst.nircam.F356W', 'jwst.nircam.F410M', 'jwst.nircam.F444W', 'jwst.miri.F770W', 'jwst.miri.F1000W', 'jwst.miri.F1500W', 'jwst.miri.F2100W', 'F2100W/F1500W', 'F2100W/F1000W', 'F2100W/F770W', 'F2100W/F444W', 'F2100W/F410M', 'F2100W/F356W', 'F2100W/F277W', 'F2100W/F200W', 'F2100W/F150W', 'F2100W/F115W', 'F1500W/F1000W', 'F1500W/F770W', 'F1500W/F444W', 'F1500W/F410M', 'F1500W/F356W', 'F1500W/F277W', 'F1500W/F200W', 'F1500W/F150W', 'F1500W/F115W', 'F1000W/F770W', 'F1000W/F444W', 'F1000W/F410M', 'F1000W/F356W', 'F1000W/F277W', 'F1000W/F200W', 'F1000W/F150W', 'F1000W/F115W', 'F770W/F444W', 'F770W/F410M', 'F770W/F356W', 'F770W/F277W', 'F770W/F200W', 'F770W/F150W', 'F770W/F115W', 'F444W/F410M', 'F444W/F356W', 'F444W/F277W', 'F444W/F200W', 'F444W/F150W', 'F444W/F115W', 'F410M/F356W', 'F410M/F277W', 'F410M/F200W', 'F410M/F150W', 'F410M/F115W', 'F356W/F277W', 'F356W/F200W', 'F356W/F150W', 'F356W/F115W', 'F277W/F200W', 'F277W/F150W', 'F277W/F115W', 'F200W/F150W', 'F200W/F115W', 'F150W/F115W', 'F2100W/F1500W^2', 'F2100W/F1000W^2', 'F2100W/F770W^2', 'F2100W/F444W^2', 'F2100W/F410M^2', 'F2100W/F356W^2', 'F2100W/F277W^2', 'F2100W/F200W^2', 'F2100W/F150W^2', 'F2100W/F115W^2', 'F1500W/F1000W^2', 'F1500W/F770W^2', 'F1500W/F444W^2', 'F1500W/F410M^2', 'F1500W/F356W^2', 'F1500W/F277W^2', 'F1500W/F200W^2', 'F1500W/F150W^2', 'F1500W/F115W^2', 'F1000W/F770W^2', 'F1000W/F444W^2', 'F1000W/F410M^2', 'F1000W/F356W^2', 'F1000W/F277W^2', 'F1000W/F200W^2', 'F1000W/F150W^2', 'F1000W/F115W^2', 'F770W/F444W^2', 'F770W/F410M^2', 'F770W/F356W^2', 'F770W/F277W^2', 'F770W/F200W^2', 'F770W/F150W^2', 'F770W/F115W^2', 'F444W/F410M^2', 'F444W/F356W^2', 'F444W/F277W^2', 'F444W/F200W^2', 'F444W/F150W^2', 'F444W/F115W^2', 'F410M/F356W^2', 'F410M/F277W^2', 'F410M/F200W^2', 'F410M/F150W^2', 'F410M/F115W^2', 'F356W/F277W^2', 'F356W/F200W^2', 'F356W/F150W^2', 'F356W/F115W^2', 'F277W/F200W^2', 'F277W/F150W^2', 'F277W/F115W^2', 'F200W/F150W^2', 'F200W/F115W^2', 'F150W/F115W^2']
Feature engineering complete:
 Feature dataframe shape: (1000, 121)
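The feature names printed above follow a regular pattern: the 11 band columns, then one color per pair of bands (longer-wavelength band first), then the square of each color. The sketch below rebuilds that naming layout with itertools.combinations to make the 11 + 55 + 55 = 121 count explicit. It reproduces only the names; the hypothetical helper feature_names is not AGNBoost's implementation, and how AGNBoost computes the color values numerically is not shown here.

```python
from itertools import combinations


def feature_names(band_columns, shorthands):
    """Rebuild the feature-name layout printed above (hypothetical helper).

    band_columns: full column names, ordered by increasing wavelength.
    shorthands:   matching short filter names, in the same order.
    """
    # combinations() on the reversed list pairs each band with every
    # shorter-wavelength band, starting from the longest band.
    colors = [f"{long}/{short}" for long, short in combinations(shorthands[::-1], 2)]
    squares = [f"{c}^2" for c in colors]
    return list(band_columns) + colors + squares


nircam = ['F115W', 'F150W', 'F200W', 'F277W', 'F356W', 'F410M', 'F444W']
miri = ['F770W', 'F1000W', 'F1500W', 'F2100W']
band_columns = [f"jwst.nircam.{b}" for b in nircam] + [f"jwst.miri.{b}" for b in miri]

names = feature_names(band_columns, nircam + miri)
print(len(names))             # 121
print(names[11], names[-1])   # F2100W/F1500W F150W/F115W^2
```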
Loading the Pre-trained Model
AGNBoost comes with pre-trained models for common astronomical tasks. We'll load the model specifically trained for AGN fraction estimation (agn.fracAGN).
The load_model() method automatically:
- Checks for compatible pre-trained models
- Validates feature compatibility between the model and our data
- Loads model metadata including training parameters and performance metrics
# Initialize an AGNBoost model. Each key of target_variables is the name of a
# target column, and its value is the distribution used to model that target.
agnboost_m = model.AGNBoost(
    feature_names=catalog.get_feature_names(),
    target_variables={'agn.fracAGN': 'ZABeta'},
)
# Load the pre-trained model. No file name is passed, so the most recently
# modified fracAGN model is used.
agnboost_m.load_model(model_name='fracAGN', overwrite=True)
if agnboost_m.models['agn.fracAGN'] is not None:
    print("✅ Pre-trained model loaded successfully!")
    # Display model information
    model_info = agnboost_m.model_info.get('agn.fracAGN', {})
    if model_info:
        print("\nModel information:")
        if 'training_timestamp' in model_info:
            print(f" Trained: {model_info['training_timestamp']}")
        if 'best_score' in model_info:
            print(f" Best validation score: {model_info['best_score']:.6f}")
        if 'features' in model_info:
            print(f" Number of features: {len(model_info['features'])}")
else:
    print("❌ No pre-trained models found!")
    print("You may need to train a new model or check the models directory.")
2025-08-08 00:53:47,070 - AGNBoost.AGNBoost - WARNING - No file_name passed. Using the most recently modified one instead: 2025_05_22-PM06_59_58_agn.fracAGN_model.pkl.gz.
✅ Pre-trained model loaded successfully!

Model information:
 Best validation score: -649218.125000
 Number of features: 121
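The model was created with target_variables={'agn.fracAGN': 'ZABeta'}. In XGBoostLSS, ZABeta is a zero-adjusted Beta distribution: a point mass at exactly 0 mixed with a Beta distribution on (0, 1). That shape suits a fraction like fracAGN, whose minimum in the summary statistics is exactly 0. The sketch below samples from a generic zero-adjusted Beta with zero probability p_zero and Beta shape parameters a and b, and checks the sample mean against the analytic mean (1 - p_zero) * a / (a + b). This parameterization and the helper name are illustrative assumptions, not necessarily what XGBoostLSS uses internally.

```python
import numpy as np


def sample_zero_adjusted_beta(p_zero, a, b, size, seed=0):
    """Draw from a generic zero-adjusted Beta: 0 with probability p_zero, else Beta(a, b)."""
    rng = np.random.default_rng(seed)
    is_zero = rng.random(size) < p_zero   # which draws come from the point mass
    draws = rng.beta(a, b, size)          # continuous part on (0, 1)
    return np.where(is_zero, 0.0, draws)


p_zero, a, b = 0.1, 2.0, 3.0
x = sample_zero_adjusted_beta(p_zero, a, b, size=200_000)
analytic_mean = (1 - p_zero) * a / (a + b)   # 0.9 * 0.4 = 0.36
print(f"sample mean {x.mean():.3f}, analytic mean {analytic_mean:.3f}")
```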
Making Predictions
Now we'll use our loaded model to predict AGN fractions for the test set. When predict() is given a Catalog object, it takes the features of the requested split (here split_use='test') directly from the catalog, so no manual conversion to the input format of the underlying XGBoost model is needed.
The prediction uses the engineered features (photometric bands, colors, and squared colors) that were created from our photometric band data in the previous step.
# Make predictions on the test set
preds = agnboost_m.predict(data=catalog, split_use='test', model_name='agn.fracAGN')
print(f" Mean: {np.mean(preds):.6f}")
print(f" Std: {np.std(preds):.6f}")
print(f" Min: {np.min(preds):.6f}")
print(f" Max: {np.max(preds):.6f}")
2025-08-08 00:57:49,965 - AGNBoost.AGNBoost - WARNING - Catalog object passsed. Taking the features and labels of the test set stored in the passed Catalog.
 Mean: 0.504962
 Std: 0.325251
 Min: 0.000308
 Max: 0.989859
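Summary statistics of the predictions alone do not measure accuracy; for that the predictions must be compared with the true test-set values. A minimal numpy sketch of common point-estimate metrics: bias, RMSE, and the normalized median absolute deviation (NMAD, a robust scatter estimate). The helper regression_metrics is hypothetical, and because retrieving the true labels from the Catalog is not covered in this section, toy arrays stand in for them. In the notebook, y_pred would be preds and y_true the matching 'agn.fracAGN' test labels.

```python
import numpy as np


def regression_metrics(y_true, y_pred):
    """Common point-estimate metrics for fracAGN predictions (hypothetical helper)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    resid = y_pred - y_true
    return {
        'bias': float(np.mean(resid)),
        'rmse': float(np.sqrt(np.mean(resid ** 2))),
        # NMAD: 1.4826 * median absolute deviation from the median residual.
        'nmad': float(1.4826 * np.median(np.abs(resid - np.median(resid)))),
    }


# Toy values for illustration only.
metrics = regression_metrics([0.0, 0.2, 0.5, 0.8], [0.05, 0.25, 0.45, 0.80])
print(metrics)
```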
Quantifying Prediction Uncertainty
One of AGNBoost's key advantages is its uncertainty estimates, which come from XGBoostLSS distributional modeling: rather than a single point estimate, the model predicts the parameters of a full probability distribution for each source.
The prediction_uncertainty() method returns uncertainty estimates that account for both model uncertainty and the inherent variability in the data. This is crucial for astronomical applications, where knowing how confident a prediction is matters for its scientific interpretation.
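Since the loaded data is a CIGALE mock catalog with no photometric uncertainties, we will only estimate the model (aleatoric + epistemic) uncertainty for each source. Because no split is specified in this call, the uncertainty is computed for all 1000 sources in the catalog rather than only the test set (the log reports the "None" set); in this run it took about 7 minutes.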
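prediction_uncertainty() reports a single combined model uncertainty. One common way to combine aleatoric and epistemic terms is the law of total variance over an ensemble of distributional predictions: the average of the members' predicted variances (aleatoric) plus the variance of their predicted means (epistemic). This is shown purely as an illustration of the concept; it is not necessarily how AGNBoost computes its value (the log refers to a "truncated" model uncertainty whose details are not shown here). The function name and input shapes are hypothetical.

```python
import numpy as np


def decompose_uncertainty(means, variances):
    """Law-of-total-variance split for an ensemble (hypothetical helper).

    means, variances: arrays of shape (n_members, n_sources) holding each
    ensemble member's predicted mean and variance for every source.
    """
    means = np.asarray(means, dtype=float)
    variances = np.asarray(variances, dtype=float)
    aleatoric = variances.mean(axis=0)   # average predicted spread
    epistemic = means.var(axis=0)        # disagreement between members
    total_std = np.sqrt(aleatoric + epistemic)
    return aleatoric, epistemic, total_std


# Two hypothetical ensemble members predicting two sources.
means = [[0.30, 0.70], [0.40, 0.70]]
variances = [[0.01, 0.02], [0.03, 0.02]]
alea, epi, total = decompose_uncertainty(means, variances)
print(alea, epi, total)
```

The second source shows no epistemic term because both members agree on its mean.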
model_uncertainty = agnboost_m.prediction_uncertainty( uncertainty_type = 'model', model_name = 'agn.fracAGN', catalog = catalog)
print(f"✅ Uncertainty estimates generated")
print(f"Uncertainty statistics:")
print(f" Mean uncertainty: {np.mean(model_uncertainty):.6f}")
print(f" Std uncertainty: {np.std(model_uncertainty):.6f}")
print(f" Min uncertainty: {np.min(model_uncertainty):.6f}")
print(f" Max uncertainty: {np.max(model_uncertainty):.6f}")
2025-08-08 00:57:58,750 - AGNBoost.AGNBoost - WARNING - Catalog object passsed. Taking the features and labels of the None set stored in the passed Catalog.
Processing truncated model uncertainty: 100%|█| 1000/1000 [06:58<00:00, 2.39it/
✅ Uncertainty estimates generated
Uncertainty statistics:
 Mean uncertainty: 0.033900
 Std uncertainty: 0.013166
 Min uncertainty: 0.000940
 Max uncertainty: 0.071419
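A simple sanity check on uncertainty estimates is interval coverage: the fraction of sources whose true value falls inside prediction ± k·σ. If the uncertainties behave like Gaussian 1σ errors, roughly 68% of true values should land within ±1σ. Note that preds covers the 200 test sources while, per the log, model_uncertainty was computed for all 1000 sources, so both arrays must first be aligned to the same sources. The sketch below uses toy arrays and a hypothetical helper, and clips intervals to [0, 1], the range of a fraction.

```python
import numpy as np


def interval_coverage(y_true, y_pred, sigma, k=1.0):
    """Fraction of sources whose true value lies within y_pred ± k*sigma.

    Intervals are clipped to [0, 1], the allowed range of fracAGN.
    """
    y_true, y_pred, sigma = (np.asarray(v, dtype=float) for v in (y_true, y_pred, sigma))
    lower = np.clip(y_pred - k * sigma, 0.0, 1.0)
    upper = np.clip(y_pred + k * sigma, 0.0, 1.0)
    return float(np.mean((y_true >= lower) & (y_true <= upper)))


# Toy values: sources 1 and 3 are covered, sources 2 and 4 are not.
y_true = [0.10, 0.50, 0.90, 0.00]
y_pred = [0.12, 0.40, 0.88, 0.02]
sigma = [0.03, 0.05, 0.03, 0.01]
coverage = interval_coverage(y_true, y_pred, sigma)
print(f"1-sigma coverage: {coverage:.2f}")
```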