Source code for deepaudiox.utils.training_utils

import itertools
import logging
import math
import warnings
from pathlib import Path

import numpy as np
import torch
from torch import Generator, default_generator, randperm
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Subset

from deepaudiox.datasets.audio_classification_dataset import AudioClassificationDataset
from deepaudiox.schemas.types import DeviceName


def get_logger() -> logging.Logger:
    """Set up plain console logging and hand back the shared console logger.

    ``logging.basicConfig`` configures the root logger at INFO level with a bare
    ``%(message)s`` format; it has no effect if the root logger already has handlers.

    Returns:
        logging.Logger: The logger registered under the name ``"ConsoleLogger"``.
    """
    logging.basicConfig(format="%(message)s", level=logging.INFO)
    return logging.getLogger("ConsoleLogger")


def get_class_mapping_from_list(labels: list[str], sort_alphabetically: bool = True) -> dict[str, int]:
    """Get a class mapping dictionary given a list of class names.

    Duplicate names are collapsed so the resulting indices are always the contiguous
    range ``0 .. n_classes - 1``.

    Args:
        labels (list[str]): List of class names (duplicates are allowed).
        sort_alphabetically (bool): If ``True``, indices follow the sorted order of the
            names; otherwise they follow the order of first occurrence in ``labels``.

    Returns:
        dict[str, int]: The class mapping dictionary.

    Example:
        >>> from deepaudiox import get_class_mapping_from_list
        >>> labels = ["speech", "music", "noise"]
        >>> class_mapping = get_class_mapping_from_list(labels, sort_alphabetically=True)
        >>> # Example output:
        >>> # {'music': 0, 'noise': 1, 'speech': 2}
    """
    # dict.fromkeys de-duplicates while keeping first-occurrence order. Without this,
    # a repeated name would be overwritten by its later index, leaving gaps in the mapping.
    class_names = list(dict.fromkeys(labels))
    if sort_alphabetically:
        class_names.sort()
    return {name: idx for idx, name in enumerate(class_names)}
def get_class_mapping_from_dir(root_dir: str) -> dict[str, int]:
    """Load the class mapping given a folder of class sub-folders.

    Every immediate sub-directory of ``root_dir`` is treated as one class; plain files
    directly inside ``root_dir`` are ignored.

    Expected directory structure::

        root_dir/
        ├── class_a/
        │   ├── audio1.wav
        │   └── audio2.wav
        └── class_b/
            ├── audio3.wav
            └── audio4.wav

    Args:
        root_dir (str): The path to the root folder.

    Returns:
        dict[str, int]: The class mapping dictionary, ordered by folder name using
        Python string ordering (case-sensitive, e.g. ``"Zebra"`` sorts before ``"ant"``).

    Raises:
        ValueError: If ``root_dir`` is not an existing directory.

    Example:
        >>> from deepaudiox import get_class_mapping_from_dir
        >>> class_mapping = get_class_mapping_from_dir(root_dir="path/to/data")
        >>> # Example output:
        >>> # {'class_a': 0, 'class_b': 1}
    """
    root_path = Path(root_dir)
    if not root_path.is_dir():
        raise ValueError(f"The path '{root_dir}' is not a directory")

    # Sort the folder *names* rather than Path objects. Path ordering is case-insensitive
    # on Windows but case-sensitive on POSIX, so sorting Paths could give the same dataset
    # different class indices depending on the platform.
    class_names = sorted(entry.name for entry in root_path.iterdir() if entry.is_dir())
    return {name: idx for idx, name in enumerate(class_names)}
def get_device(device: DeviceName = "cpu", device_index: int | None = None) -> torch.device: """Returns a PyTorch device based on the user's choice. Args: device (DeviceName): The device to use. One of ``"cuda"``, ``"mps"``, or ``"cpu"``. Defaults to ``"cpu"``. device_index (int | None): The GPU device index. Only applicable when ``device="cuda"``. If ``None``, uses the default CUDA device. Returns: torch.device Raises: ValueError: If ``device="cuda"`` but CUDA is not available. ValueError: If ``device="cuda"`` and ``device_index`` is out of range. ValueError: If ``device="mps"`` but MPS is not available. ValueError: If ``device_index`` is provided for a non-CUDA device. """ if device == "cuda": if not torch.cuda.is_available(): raise ValueError("CUDA is not available on this machine. Use device='cpu' or device='mps' instead.") if device_index is not None and (device_index < 0 or device_index >= torch.cuda.device_count()): raise ValueError(f"Invalid device_index {device_index}. Available GPU count: {torch.cuda.device_count()}") if device_index is not None: torch_device = torch.device(f"cuda:{device_index}") print(f"Using GPU: {torch.cuda.get_device_name(device_index)}") else: torch_device = torch.device("cuda") print(f"Using GPU: {torch.cuda.get_device_name(0)}") elif device == "mps": if not torch.backends.mps.is_available(): raise ValueError("MPS is not available on this machine. Use device='cpu' instead.") if device_index is not None: raise ValueError("device_index is not supported for MPS. Apple Silicon has a single GPU.") torch_device = torch.device("mps") print("Using MPS (Apple Silicon GPU)") else: if device_index is not None: print("Warning: device_index is ignored when device='cpu'.") torch_device = torch.device("cpu") print("Using CPU.") return torch_device def pad_collate_fn(batch) -> dict: """Collate function that pads variable-length audio tensors in a batch. Uses ``torch.stack`` when all waveforms in the batch have the same length (e.g. 
when ``segment_duration`` is set), and falls back to ``pad_sequence`` for variable-length batches. Args: batch (list of dict): Each dict contains 'feature', 'y_true', and 'class_name'. Returns: dict[str, torch.Tensor | list[str]]: Batched and padded tensors. """ features = [torch.from_numpy(item["feature"]) for item in batch] labels = torch.tensor([item["y_true"] for item in batch], dtype=torch.long) class_names = [item["class_name"] for item in batch] if len({f.shape[0] for f in features}) == 1: batch_features = torch.stack(features) else: batch_features = pad_sequence(features, batch_first=True) return {"feature": batch_features, "y_true": labels, "class_name": class_names} def random_split_audio_dataset( dataset: AudioClassificationDataset, train_ratio: float, generator: Generator = default_generator ) -> list[Subset[AudioClassificationDataset]]: """ Split AudioClassificationDataset into train / val subsets specified by train ratio. Method accounts for segmentized waveforms. Args: dataset (AudioClassificationDataset): An AudioClassificationDataset train_ratio (float): Percentage of training set. generator (Generator): Random Generator. Returns: list[Subset[AudioClassificationDataset]]: List containing train and validation Subsets. 
""" # Validate ratio if not (0 <= train_ratio <= 1): raise ValueError("train_ratio must be between 0 and 1.") # Extract recording paths file_paths = np.array([item.path for item in dataset.items]) num_files = len(np.unique(file_paths)) if dataset.segment_duration else len(file_paths) # Compute split sizes n_train = int(math.floor(num_files * train_ratio)) n_valid = num_files - n_train subset_lengths = [n_train, n_valid] # Validate split sizes if n_train == 0: warnings.warn("Training split has length 0.", stacklevel=2) if n_valid == 0: warnings.warn("Validation split has length 0.", stacklevel=2) if sum(subset_lengths) != num_files: raise ValueError("Split sizes do not sum to the total number of items.") if dataset.segment_duration: # Segmentized dataset: shuffle files, then map to segments unique_files = np.unique(file_paths) num_unique_files = len(unique_files) shuffled_file_indices = randperm(num_unique_files, generator=generator).tolist() shuffled_files = unique_files[shuffled_file_indices] split_files_list = [ shuffled_files[offset - length : offset] for offset, length in zip(itertools.accumulate(subset_lengths), subset_lengths, strict=False) ] subsets = [ Subset(dataset, np.where(np.isin(file_paths, split_files))[0].tolist()) for split_files in split_files_list ] return subsets else: # Non-segmentized dataset: shuffle indices directly indices = randperm(num_files, generator=generator).tolist() subsets = [ Subset(dataset, indices[offset - length : offset]) for offset, length in zip(itertools.accumulate(subset_lengths), subset_lengths, strict=False) ] return subsets