# Source code for deepaudiox.loops.trainer

import time
from dataclasses import dataclass, field

import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import LRScheduler, ReduceLROnPlateau
from torch.utils.data import DataLoader
from tqdm import tqdm

from deepaudiox.callbacks.base_callback import BaseCallback
from deepaudiox.callbacks.checkpointer import Checkpointer
from deepaudiox.callbacks.early_stopper import EarlyStopper
from deepaudiox.datasets.audio_classification_dataset import AudioClassificationDataset
from deepaudiox.modules.baseclasses import BaseAudioClassifier
from deepaudiox.schemas.types import DeviceName
from deepaudiox.utils.training_utils import get_device, get_logger, pad_collate_fn, random_split_audio_dataset


@dataclass
class State:
    """Dataclass that stores variables accessed throughout the training lifecycle.

    The Trainer passes itself to every callback, so callbacks read and update this
    state (e.g. the checkpointing callback updates ``lowest_loss``).

    Attributes:
        current_epoch (int): The current (1-based) epoch of the training process. Defaults to 1.
        lowest_loss (float): The lowest loss achieved during training. Defaults to ``np.inf``.
        train_loss (list[float]): An ordered list of train losses, one entry per epoch.
        validation_loss (list[float]): An ordered list of validation losses, one entry per epoch.
        early_stop (bool): Determines if training should be early stopped. Defaults to False.
        current_patience (int): Patience counter; presumably maintained by the early-stopping
            callback (it is not updated in this module). Defaults to 0.

    """

    current_epoch: int = 1
    lowest_loss: float = np.inf
    # default_factory gives every State instance its own list (no shared mutable default).
    train_loss: list[float] = field(default_factory=list)
    validation_loss: list[float] = field(default_factory=list)
    early_stop: bool = False
    current_patience: int = 0


class Trainer:
    """The core SDK module for training a model.

    The Trainer assembles all modules required for training and performs the training process.

    Attributes:
        state (State): Stores training variables.
        epochs (int): The maximum number of training epochs.
        verbose (bool): Whether to log epoch-level artifacts.
        path_to_checkpoint (str): Checkpoint path handed to the Checkpointer and reported in
            the final training summary.
        device: The training device returned by ``get_device`` (exposes ``.type``, e.g. a
            ``torch.device``).
        logger (logging.Logger): A module used for logging messages.
        train_dloader (torch.utils.data.DataLoader): The DataLoader of the training set.
        validation_dloader (torch.utils.data.DataLoader): The DataLoader of the validation set.
        model (BaseAudioClassifier): The BaseAudioClassifier to be trained.
        optimizer (torch.optim.Optimizer): The optimizer of the training process.
        scheduler (LRScheduler): The learning rate scheduler of the training process.
        loss_function (nn.Module): The loss function used for optimization.
        callbacks (list): A list of callbacks used throughout the training lifecycle.

    """

    def __init__(
        self,
        train_dset: AudioClassificationDataset,
        model: BaseAudioClassifier,
        validation_dset: AudioClassificationDataset | None = None,
        optimizer: torch.optim.Optimizer | None = None,
        learning_rate: float = 1e-3,
        lr_scheduler: LRScheduler | None = None,
        loss_function: nn.Module | None = None,
        train_ratio: float = 0.8,
        epochs: int = 100,
        patience: int | None = None,
        num_workers: int = 4,
        batch_size: int = 16,
        path_to_checkpoint: str = "checkpoint.pt",
        device: DeviceName = "cpu",
        device_index: int | None = None,
        verbose: bool = True,
    ):
        """Initialize the Trainer.

        Args:
            train_dset (AudioClassificationDataset): The training dataset.
            model (BaseAudioClassifier): The model to be trained.
            validation_dset (AudioClassificationDataset | None): The validation dataset.
                If None, a split is created from train_dset using train_ratio.
            optimizer (torch.optim.Optimizer | None): The optimizer used for training. Adam if None.
            learning_rate (float): Learning rate used when optimizer is None. Defaults to 1e-3.
            lr_scheduler (LRScheduler | None): The scheduler used for training.
                ReduceLROnPlateau if None. It must be built on the same optimizer the Trainer
                uses; otherwise a warning is logged because its updates would have no effect.
            loss_function (nn.Module | None): The loss function used for training.
                Uses CrossEntropy if None.
            train_ratio (float, optional): The ratio of the train split when validation_dset
                is None. Defaults to 0.8.
            epochs (int, optional): The maximum number of training epochs. Defaults to 100.
            patience (int | None): Epochs to wait without loss improvement before stopping.
                Early stopping is disabled if None (or 0).
            num_workers (int, optional): The number of workers for Python Data Loaders.
                Defaults to 4.
            batch_size (int, optional): The batch size for Python Data Loaders. Defaults to 16.
            path_to_checkpoint (str, optional): The path to the saved model checkpoint.
                Defaults to "checkpoint.pt".
            device (DeviceName): The device to use for training. One of ``"cuda"``, ``"mps"``,
                or ``"cpu"``. Defaults to ``"cpu"``.
            device_index (int | None): The GPU device index. Only applicable when
                ``device="cuda"``. If ``None``, uses the default CUDA device.
            verbose (bool): If True, logs epoch-level artifacts (loss, time). If False, only
                start/end messages and the final training summary are printed. Defaults to True.

        Example:
            >>> from deepaudiox import AudioClassifier, Trainer
            >>> from deepaudiox import audio_classification_dataset_from_dir, get_class_mapping_from_dir
            >>> class_mapping = get_class_mapping_from_dir(root_dir="path/to/data")
            >>> train_dataset = audio_classification_dataset_from_dir(
            ...     root_dir="path/to/data",
            ...     sample_rate=16_000,
            ...     class_mapping=class_mapping,
            ... )
            >>> model = AudioClassifier(num_classes=len(class_mapping), backbone="beats", sample_rate=16_000)
            >>> trainer = Trainer(train_dset=train_dataset, model=model, epochs=10)
            >>> trainer.train()

        """
        # Configure training state
        self.state = State()
        self.epochs = epochs
        self.verbose = verbose
        self.path_to_checkpoint = path_to_checkpoint
        self.device = get_device(device=device, device_index=device_index)

        # Configure logger
        self.logger = get_logger()

        # Load datasets
        self.train_dloader, self.validation_dloader = self._setup_dataloaders(
            train_dset=train_dset,
            validation_dset=validation_dset,
            train_ratio=train_ratio,
            batch_size=batch_size,
            num_workers=num_workers,
        )

        # Load model and training modules.
        # Explicit `is not None` checks: truthiness would wrongly fall back to the default
        # for user objects that define __len__/__bool__ (e.g. container-style modules).
        self.model = model
        self.model.to(self.device)
        self.optimizer = (
            optimizer if optimizer is not None else Adam(params=self.model.parameters(), lr=learning_rate)
        )
        self.scheduler = lr_scheduler if lr_scheduler is not None else ReduceLROnPlateau(self.optimizer, "min")
        self.loss_function = loss_function if loss_function is not None else nn.CrossEntropyLoss()

        # A scheduler bound to a different optimizer (e.g. one passed without its optimizer,
        # while the Trainer created its own Adam) would step an optimizer that never trains.
        if lr_scheduler is not None and getattr(lr_scheduler, "optimizer", None) is not self.optimizer:
            self.logger.warning(
                "lr_scheduler is not attached to the optimizer used by the Trainer; its "
                "learning-rate updates will have no effect. Pass the optimizer the scheduler "
                "was built with."
            )

        # Configure callbacks — Checkpointer must precede EarlyStopper (updates lowest_loss first)
        self.callbacks: list[BaseCallback] = [Checkpointer(path_to_checkpoint=path_to_checkpoint, logger=self.logger)]
        # Falsy patience (None or 0) disables early stopping.
        if patience:
            self.callbacks.append(EarlyStopper(patience=patience, logger=self.logger))

        self._epoch_start_time: float = 0.0

    def train_step(self) -> float:
        """Run one pass over the training set.

        Sets the model to train mode, iterates over ``train_dloader``, performs
        forward + backward + optimizer step per batch.

        Returns:
            float: Average training loss over the epoch (0.0 if the loader is empty).

        """
        train_loss = 0.0
        self.model.train()
        with tqdm(self.train_dloader, unit="batch", leave=False, desc="Training phase") as tbar:
            for batch in tbar:
                self.optimizer.zero_grad()
                x = batch["feature"].to(self.device)
                y_true = batch["y_true"].to(self.device)
                y_pred = self.model(x)
                batch_loss = self.loss_function(y_pred, y_true)
                batch_loss.backward()
                self.optimizer.step()
                train_loss += batch_loss.item()
        # max(1, ...) avoids ZeroDivisionError on an empty loader.
        return train_loss / max(1, len(self.train_dloader))

    def val_step(self) -> float:
        """Run one pass over the validation set.

        Sets the model to eval mode, iterates over ``validation_dloader`` under
        ``torch.no_grad()``.

        Returns:
            float: Average validation loss over the epoch (0.0 if the loader is empty).

        """
        val_loss = 0.0
        self.model.eval()
        with torch.no_grad(), tqdm(self.validation_dloader, unit="batch", leave=False, desc="Validation phase") as vbar:
            for batch in vbar:
                x = batch["feature"].to(self.device)
                y_true = batch["y_true"].to(self.device)
                y_pred = self.model(x)
                batch_loss = self.loss_function(y_pred, y_true)
                val_loss += batch_loss.item()
        # Same guard as train_step: an empty validation split must not raise ZeroDivisionError.
        return val_loss / max(1, len(self.validation_dloader))

    def epoch_step(self) -> tuple[float, float]:
        """Run one complete training epoch.

        Logs the epoch header and metrics when ``verbose=True``, calls ``train_step()``
        and ``val_step()``, updates the LR scheduler and ``self.state``, then executes
        ``on_epoch_end`` callbacks (which may trigger early stopping or checkpointing).

        Note:
            ``self.state.current_epoch`` must be set by the caller before invoking this
            method — ``train()`` does this automatically. When calling ``epoch_step()``
            directly, set it yourself: ``trainer.state.current_epoch = epoch``.

        Returns:
            tuple[float, float]: ``(train_loss, val_loss)`` for the epoch.

        Example:
            >>> from deepaudiox import AudioClassifier, Trainer
            >>> from deepaudiox import audio_classification_dataset_from_dir, get_class_mapping_from_dir
            >>> class_mapping = get_class_mapping_from_dir(root_dir="path/to/data")
            >>> train_dataset = audio_classification_dataset_from_dir(
            ...     root_dir="path/to/data",
            ...     sample_rate=16_000,
            ...     class_mapping=class_mapping,
            ... )
            >>> model = AudioClassifier(num_classes=len(class_mapping), backbone="beats", sample_rate=16_000)
            >>> trainer = Trainer(train_dset=train_dataset, model=model, epochs=10)
            >>> for epoch in range(1, trainer.epochs + 1):
            ...     trainer.state.current_epoch = epoch
            ...     train_loss, val_loss = trainer.epoch_step()
            ...     print(f"Epoch {epoch} — train: {train_loss:.4f}, val: {val_loss:.4f}")
            ...     if trainer.state.early_stop:
            ...         break

        """
        self._epoch_start_time = time.time()
        if self.verbose:
            self.logger.info(f"[Epoch {self.state.current_epoch}/{self.epochs}]")

        for cb in self.callbacks:
            cb.on_epoch_start(self)

        train_loss = self.train_step()
        val_loss = self.val_step()

        # ReduceLROnPlateau needs the monitored metric; other schedulers step unconditionally.
        if isinstance(self.scheduler, ReduceLROnPlateau):
            self.scheduler.step(val_loss)
        else:
            self.scheduler.step()

        self.state.train_loss.append(train_loss)
        self.state.validation_loss.append(val_loss)

        if self.verbose:
            elapsed = time.time() - self._epoch_start_time
            self.logger.info(
                f"Epoch {self.state.current_epoch} | "
                f"Train Loss: {train_loss:.4f} | "
                f"Val. Loss: {val_loss:.4f} | "
                f"Time: {elapsed:.2f}s"
            )

        for cb in self.callbacks:
            cb.on_epoch_end(self)

        return train_loss, val_loss

    def train(self) -> None:
        """Perform the full training process.

        Epoch-level output is controlled by ``verbose``. The training summary (best epoch,
        losses, checkpoint path) is printed on completion whenever at least one epoch ran.

        """
        self.logger.info("Training has started.")
        for cb in self.callbacks:
            cb.on_train_start(self)

        for epoch in range(1, self.epochs + 1):
            if self.state.early_stop:
                break
            self.state.current_epoch = epoch
            self.epoch_step()

        for cb in self.callbacks:
            cb.on_train_end(self)

        # np.argmin raises ValueError on an empty sequence, which happens when no epoch ran
        # (epochs < 1, or early_stop was already set before train() was called).
        if not self.state.validation_loss:
            self.logger.info("Training finished without completing any epochs; no summary available.")
            return

        best_idx = int(np.argmin(self.state.validation_loss))
        best_val_loss = self.state.validation_loss[best_idx]
        best_train_loss = self.state.train_loss[best_idx]
        best_epoch = best_idx + 1

        sep = "─" * 52
        self.logger.info(
            f"\n{sep}\n"
            f"  Training Complete\n"
            f"  Best Epoch   : {best_epoch}\n"
            f"  Train Loss   : {best_train_loss:.6f}\n"
            f"  Val. Loss    : {best_val_loss:.6f}\n"
            f"  Checkpoint   : {self.path_to_checkpoint}\n"
            f"{sep}"
        )

    def _setup_dataloaders(
        self,
        train_dset: AudioClassificationDataset,
        validation_dset: AudioClassificationDataset | None,
        train_ratio: float,
        batch_size: int,
        num_workers: int,
    ) -> tuple[DataLoader, DataLoader]:
        """Generate PyTorch DataLoaders for training and validation splits.

        Args:
            train_dset (AudioClassificationDataset): Training dataset.
            validation_dset (AudioClassificationDataset | None): Validation dataset. If None,
                ``train_dset`` is split with ``random_split_audio_dataset``.
            train_ratio (float): The ratio of the train split (used only when
                ``validation_dset`` is None).
            batch_size (int): The batch size for both DataLoaders.
            num_workers (int): The number of worker processes for both DataLoaders.

        Returns:
            tuple[DataLoader, DataLoader]: ``(train_dloader, validation_dloader)``; the training
            loader shuffles, the validation loader does not.

        """
        # Split to train and validation
        if validation_dset is None:
            train_dataset, validation_dataset = random_split_audio_dataset(train_dset, train_ratio)
        else:
            train_dataset = train_dset
            validation_dataset = validation_dset

        # Produce DataLoaders.
        # Pinned memory only helps host->CUDA transfers; persistent_workers requires workers.
        pin_memory = self.device.type == "cuda"
        persistent_workers = num_workers > 0

        train_dloader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
            collate_fn=pad_collate_fn,
        )
        validation_dloader = DataLoader(
            validation_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
            collate_fn=pad_collate_fn,
        )
        return train_dloader, validation_dloader