from dataclasses import dataclass, field
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from deepaudiox.callbacks.reporter import Reporter
from deepaudiox.datasets.audio_classification_dataset import AudioClassificationDataset
from deepaudiox.modules.baseclasses import BaseAudioClassifier
from deepaudiox.schemas.types import DeviceName
from deepaudiox.utils.training_utils import get_device, get_logger, pad_collate_fn
@dataclass
class State:
    """Container for the results gathered while a model is being tested.

    Each field starts out as an empty one-dimensional array. ``Evaluator.evaluate``
    overwrites all of them after the whole test set has been processed.

    Attributes:
        y_true (np.ndarray): True labels. Defaults to an empty integer array.
        y_pred (np.ndarray): Predicted labels. Defaults to an empty integer array.
        posteriors (np.ndarray): Posterior probabilities. Defaults to an empty float array.
    """

    # ``default_factory`` builds a fresh array for every instance, so two State
    # objects never share (and accidentally mutate) the same default array.
    y_true: np.ndarray = field(default_factory=lambda: np.empty(0, dtype=int))
    y_pred: np.ndarray = field(default_factory=lambda: np.empty(0, dtype=int))
    posteriors: np.ndarray = field(default_factory=lambda: np.empty(0, dtype=float))
class Evaluator:
    """The core SDK module for testing a model.

    The Evaluator assembles all modules required for testing
    and performs the testing process.

    Attributes:
        state (State): Stores testing variables (true labels, predictions, posteriors).
        verbose (bool): Whether to run the reporting callbacks after testing.
        device (torch.device): The device used for testing, as resolved by ``get_device``.
        class_mapping (dict): A mapping between class names and IDs.
        logger (logging.Logger): A module used for logging messages.
        test_dloader (torch.utils.data.DataLoader): The DataLoader of the testing set.
        model (BaseAudioClassifier): An AudioClassifier module inheriting from BaseAudioClassifier.
        callbacks (list): A list of callbacks used throughout the testing lifecycle.
    """

    def __init__(
        self,
        test_dset: AudioClassificationDataset,
        model: BaseAudioClassifier,
        class_mapping: dict,
        batch_size: int = 16,
        num_workers: int = 4,
        device: DeviceName = "cpu",
        device_index: int | None = None,
        verbose: bool = True,
    ):
        """Initialize the Evaluator.

        Args:
            test_dset (AudioClassificationDataset): The testing dataset.
            model (BaseAudioClassifier): An AudioClassifier module inheriting from BaseAudioClassifier.
            class_mapping (dict): A mapping between class names and IDs.
            batch_size (int, optional): The batch size for Python Data Loaders. Defaults to 16.
            num_workers (int, optional): The number of workers for Python Data Loaders. Defaults to 4.
            device (DeviceName): The device to use for evaluation. One of ``"cuda"``, ``"mps"``, or ``"cpu"``.
                Defaults to ``"cpu"``.
            device_index (int | None): The GPU device index. Only applicable when ``device="cuda"``.
                If ``None``, uses the default CUDA device.
            verbose (bool): If True, runs the registered callbacks (the ``Reporter``, which logs the
                classification report, confusion matrix, and average posteriors) after evaluation.
                "Evaluation has finished." is always logged. Defaults to True.

        Example:
            >>> import torch
            >>> from deepaudiox import AudioClassifier, Evaluator
            >>> from deepaudiox import audio_classification_dataset_from_dir, get_class_mapping_from_dir
            >>> class_mapping = get_class_mapping_from_dir(root_dir="path/to/data")
            >>> test_dataset = audio_classification_dataset_from_dir(
            ...     root_dir="path/to/data",
            ...     sample_rate=16_000,
            ...     class_mapping=class_mapping,
            ... )
            >>> model = AudioClassifier(num_classes=len(class_mapping), backbone="beats", sample_rate=16_000)
            >>> model.load_state_dict(torch.load("checkpoint.pt"))
            >>> evaluator = Evaluator(test_dset=test_dataset, model=model, class_mapping=class_mapping)
            >>> evaluator.evaluate()
        """
        self.state = State()
        self.verbose = verbose
        self.device = get_device(device=device, device_index=device_index)
        self.class_mapping = class_mapping

        # Configure logger
        self.logger = get_logger()

        # Load dataset. shuffle=False keeps batches in dataset order.
        self.test_dloader = DataLoader(
            test_dset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            # Page-locked host memory only helps host-to-GPU transfers.
            pin_memory=self.device.type == "cuda",
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            persistent_workers=num_workers > 0,
            collate_fn=pad_collate_fn,
        )

        # Load model
        self.model = model
        self.model.to(self.device)
        self.model.eval()

        # Configure callbacks
        self.callbacks = [Reporter(logger=self.logger)]

    @staticmethod
    def _to_numpy(values, dtype: type | None = None) -> np.ndarray:
        """Convert one batch of labels/model outputs into an (at least) 1-D NumPy array.

        Args:
            values: A ``torch.Tensor`` on any device, or anything ``np.array`` accepts
                (lists, scalars, NumPy arrays).
            dtype: Target dtype, or ``None`` to keep the input's dtype.

        Returns:
            np.ndarray: A host-memory copy with at least one dimension.
        """
        if isinstance(values, torch.Tensor):
            # NumPy cannot read tensors that live on an accelerator; copy to host first.
            values = values.detach().cpu().numpy()
        # A zero-dimensional result (e.g. a squeezed batch of one sample) cannot be
        # passed to np.concatenate, so promote it to shape (1,).
        return np.atleast_1d(np.array(values, dtype=dtype))

    @torch.inference_mode()
    def evaluate(self) -> None:
        """Run the full evaluation loop over the test set.

        Iterates over all batches in ``test_dloader``, accumulates true labels,
        predicted labels, and posterior probabilities into ``self.state``, then
        triggers the registered callbacks via ``on_testing_end``.

        "Evaluation has finished." is always logged (at INFO level) regardless of
        ``verbose``. The callbacks (the ``Reporter``: classification report,
        confusion matrix, average posteriors) are only executed when ``verbose=True``.

        After this method returns, ``self.state`` holds:
            - ``y_true`` (np.ndarray): Ground-truth labels, concatenated over batches.
            - ``y_pred`` (np.ndarray): Predicted labels as integers, concatenated over batches.
            - ``posteriors`` (np.ndarray): The ``"posteriors"`` returned by ``model.predict``
              as floats, concatenated over batches (documented as the max posterior per
              sample — depends on the model's ``predict``).

        Raises:
            ValueError: If the test DataLoader yields no batches (empty test set).

        Note:
            The model is put in eval mode at the start of every call, and the loop
            runs under ``torch.inference_mode()``, so gradients are fully disabled.
        """
        # Re-assert eval mode: the model may have been switched to training mode after
        # __init__ (e.g. by code sharing the same model), which changes the outputs of
        # layers such as dropout and batch norm.
        self.model.eval()

        # Per-batch results; concatenated once after the loop.
        y_true_batches, y_pred_batches, posterior_batches = [], [], []
        with tqdm(self.test_dloader, unit="batch", leave=False, desc="Evaluation phase") as tbar:
            for batch in tbar:
                # Move inputs
                x = batch["feature"].to(self.device)
                y_true_batches.append(self._to_numpy(batch["y_true"]))

                # Run model prediction
                inference = self.model.predict(x)
                y_pred_batches.append(self._to_numpy(inference["y_preds"], dtype=int))
                posterior_batches.append(self._to_numpy(inference["posteriors"], dtype=float))

        # np.concatenate on an empty list only says "need at least one array".
        if not y_true_batches:
            raise ValueError("The test set produced no batches; nothing to evaluate.")

        self.state.y_true = np.concatenate(y_true_batches)
        self.state.y_pred = np.concatenate(y_pred_batches)
        self.state.posteriors = np.concatenate(posterior_batches)

        self.logger.info("Evaluation has finished.")
        if self.verbose:
            for cb in self.callbacks:
                cb.on_testing_end(self)