import copy
import os
from pathlib import Path

import librosa
import soundfile as sf
from torch.utils.data import Dataset

from deepaudiox.schemas.items import AudioClassificationItem
[docs]
class AudioClassificationDataset(Dataset):
"""PyTorch Dataset for audio classification tasks.
This dataset loads audio files and returns a dictionary containing the raw waveform
(under the key ``"feature"``), the corresponding class name, and the integer class ID
defined in ``class_mapping``. The ``file_to_class_mapping`` argument must be a dictionary
of the form::
{"abs/path/to/audio.wav": "class_name"}
Optionally, the dataset can segment each audio file into fixed-duration chunks using
``segment_duration``. When enabled, each segment becomes an individual dataset sample.
Attributes:
file_to_class_mapping (dict): Mapping from file paths to class names.
sample_rate (int): Target sampling rate for audio loading.
class_mapping (dict): Mapping from string class labels to integer IDs.
"""
[docs]
def __init__(
self,
file_to_class_mapping: dict[str | os.PathLike, str],
sample_rate: int,
class_mapping: dict[str, int],
segment_duration: float | None = None,
):
"""Initialize the dataset.
Args:
file_to_class_mapping (dict): Mapping from file paths to class names.
sample_rate (int): Target sampling rate for audio loading.
class_mapping (dict): Mapping from string labels to integer IDs.
segment_duration (float | None): Duration of audio segments in seconds. If None, load full audio.
When set, the last partial segment is dropped.
Example:
>>> from deepaudiox import AudioClassificationDataset
>>> file_to_class_mapping = {
... "path/to/audio1.wav": "speech",
... "path/to/audio2.wav": "music",
... }
>>> class_mapping = {"speech": 0, "music": 1}
>>> dataset = AudioClassificationDataset(
... file_to_class_mapping=file_to_class_mapping,
... sample_rate=16_000,
... class_mapping=class_mapping,
... segment_duration=2.0,
... )
"""
self.sample_rate = sample_rate
self.class_mapping = class_mapping
self.file_to_class_mapping = file_to_class_mapping
self.items = [
AudioClassificationItem(path=Path(path), class_name=class_name, y_true=self.class_mapping[class_name])
for path, class_name in file_to_class_mapping.items()
]
self.segment_duration = segment_duration
if self.segment_duration:
self._apply_segmentation(segment_duration)
def __len__(self) -> int:
"""Return the number of items in the dataset.
Returns:
int: Total number of samples.
"""
return len(self.items)
def __getitem__(self, idx: int) -> dict:
"""Get a single dataset item by index.
Args:
idx (int): Index of the item to retrieve.
Returns:
dict: An AudioClassificationItem in the form of dictionary.
"""
item = self.items[idx]
item.feature = librosa.load(
path=item.path,
sr=self.sample_rate,
mono=True,
offset=item.segment_idx * self.segment_duration if self.segment_duration else 0,
duration=self.segment_duration,
)[0]
return item.to_dict()
def _apply_segmentation(self, segment_duration: float | None):
"""Segmentize all audio files into fixed-duration segments.
Drops the last partial segment. Files shorter than segment_duration are excluded.
"""
valid_items = []
for item in self.items:
with sf.SoundFile(item.path) as f:
total_duration = len(f) / f.samplerate # seconds
if total_duration < segment_duration:
continue
num_segments = int(total_duration // segment_duration)
for seg_idx in range(num_segments):
valid_items.append(
AudioClassificationItem(
path=item.path,
y_true=item.y_true,
segment_idx=seg_idx,
class_name=item.class_name,
)
)
self.items = valid_items
[docs]
def audio_classification_dataset_from_dir(
    root_dir: str | os.PathLike,
    sample_rate: int,
    class_mapping: dict[str, int],
    segment_duration: float | None = None,
) -> AudioClassificationDataset:
    """Create an AudioClassificationDataset from a directory structure.

    Every immediate sub-directory of ``root_dir`` is treated as one class, named after
    the sub-directory. Audio files are collected recursively below each class directory;
    files placed directly in ``root_dir`` are ignored. Directories and files are visited
    in sorted order so that the sample order is reproducible across runs and machines.

    Args:
        root_dir (str | Path): Root directory containing class sub-folders.
            Only ``.wav`` and ``.mp3`` files are used.
        sample_rate (int): Target sampling rate for audio loading.
        class_mapping (dict): Mapping from string labels to integer IDs.
        segment_duration (float | None): Duration of audio segments in seconds. If None, load full audio.
            When set, the last partial segment is dropped.

    Returns:
        AudioClassificationDataset: The constructed dataset.

    Example:
        >>> from deepaudiox import audio_classification_dataset_from_dir, get_class_mapping_from_dir
        >>> class_mapping = get_class_mapping_from_dir(root_dir="path/to/data")
        >>> dataset = audio_classification_dataset_from_dir(
        ...     root_dir="path/to/data",
        ...     sample_rate=16_000,
        ...     class_mapping=class_mapping,
        ...     segment_duration=2.0,
        ... )
    """
    root_path = Path(root_dir)
    file_to_class_mapping: dict[Path, str] = {}

    class_dirs = sorted(entry for entry in root_path.iterdir() if entry.is_dir())
    for class_dir in class_dirs:
        # glob() yields matches in arbitrary (filesystem-dependent) order; sorting makes
        # the index -> sample mapping deterministic.
        audio_files = sorted([*class_dir.glob("**/*.wav"), *class_dir.glob("**/*.mp3")])
        for audio_file in audio_files:
            file_to_class_mapping[audio_file] = class_dir.name

    return AudioClassificationDataset(
        file_to_class_mapping=file_to_class_mapping,
        sample_rate=sample_rate,
        class_mapping=class_mapping,
        segment_duration=segment_duration,
    )
[docs]
def audio_classification_dataset_from_dictionary(
    file_to_class_mapping: dict[str | os.PathLike, str],
    sample_rate: int,
    class_mapping: dict[str, int],
    segment_duration: float | None = None,
) -> AudioClassificationDataset:
    """Build an AudioClassificationDataset directly from a path-to-label dictionary.

    This is a convenience wrapper: all arguments are forwarded unchanged to the
    ``AudioClassificationDataset`` constructor.

    Args:
        file_to_class_mapping (dict): Mapping from file paths to class names.
        sample_rate (int): Target sampling rate for audio loading.
        class_mapping (dict): Mapping from string labels to integer IDs.
        segment_duration (float | None): Duration of audio segments in seconds. If None, load full audio.
            When set, the last partial segment is dropped.

    Returns:
        AudioClassificationDataset: The constructed dataset.

    Example:
        >>> from deepaudiox import audio_classification_dataset_from_dictionary
        >>> file_to_class_mapping = {
        ...     "path/to/audio1.wav": "speech",
        ...     "path/to/audio2.wav": "music",
        ... }
        >>> class_mapping = {"speech": 0, "music": 1}
        >>> dataset = audio_classification_dataset_from_dictionary(
        ...     file_to_class_mapping=file_to_class_mapping,
        ...     sample_rate=16_000,
        ...     class_mapping=class_mapping,
        ...     segment_duration=None,
        ... )
    """
    dataset_kwargs = {
        "file_to_class_mapping": file_to_class_mapping,
        "sample_rate": sample_rate,
        "class_mapping": class_mapping,
        "segment_duration": segment_duration,
    }
    return AudioClassificationDataset(**dataset_kwargs)