Tutorial 01 — Quick Start

In this tutorial we will train an audio classifier on the GTZAN dataset, which consists of 30-second music clips from 10 different genres.

By the end of this tutorial we will have:

  • Loaded an audio dataset from a directory

  • Built and trained an AudioClassifier with a pretrained backbone

  • Evaluated model accuracy on a held-out test set

  • Run single-file inference on unseen data


Assumed dataset layout

DeepAudioX expects the dataset to be in the following directory structure.

gtzan/
├── train/
│   ├── blues/
│   │   ├── blues.00000.wav
│   │   ├── blues.00001.wav
│   │   └── ...
│   ├── classical/
│   ├── country/
│   ├── disco/
│   ├── hiphop/
│   ├── jazz/
│   ├── metal/
│   ├── pop/
│   ├── reggae/
│   └── rock/
└── test/
    ├── blues/
    └── ...

Each sub-folder name becomes the class label. See Tutorial 02 for an alternative loading strategy using Python dictionaries.

1. Configuration

Update the two directory paths below and choose your target device before running any other cell.

[1]:
TRAIN_DIR   = "/data/gtzan/train"  # directory containing one sub-folder per class
TEST_DIR    = "/data/gtzan/test"   # same layout as TRAIN_DIR
SAMPLE_RATE = 32_000               # sampling rate (Hz) used when loading audio files
CHECKPOINT  = "checkpoint.pt"      # path where the best model will be saved
DEVICE      = "cuda"               # "cuda" | "mps" | "cpu"
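If you would rather not hard-code DEVICE, a small helper can pick the first available option. This is a hypothetical convenience function (pick_device is not part of DeepAudioX); it only assumes that DeepAudioX accepts the three device strings listed above, and it falls back to "cpu" when PyTorch cannot be imported.

```python
def pick_device() -> str:
    """Return the first available device among "cuda", "mps" and "cpu".

    Hypothetical helper, not part of DeepAudioX.
    """
    try:
        import torch
    except ImportError:  # no PyTorch installed: CPU is the only safe choice
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps only exists in newer PyTorch builds (Apple Silicon support).
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


DEVICE = pick_device()
print(DEVICE)
```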

2. Loading the Dataset

get_class_mapping_from_dir scans the top-level sub-folders of the training directory and automatically builds a {class_name: int} mapping. Pass this same mapping to every dataset, evaluation, and inference call (the model itself only needs its length, as num_classes); it is the single source of truth for label ordering throughout the project.

[2]:
from deepaudiox import get_class_mapping_from_dir

class_mapping = get_class_mapping_from_dir(TRAIN_DIR)
print(f"{len(class_mapping)} classes detected: {list(class_mapping.keys())}")
10 classes detected: ['blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal', 'pop', 'reggae', 'rock']
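To make the labeling rule concrete, here is a minimal sketch of how such a mapping can be built with pathlib. build_class_mapping is a hypothetical name, and sorting the folder names alphabetically is an assumption that matches the output above rather than documented DeepAudioX behavior.

```python
from pathlib import Path


def build_class_mapping(root_dir: str) -> dict:
    """Map each immediate sub-folder name of root_dir to an integer ID.

    Sketch only: names are sorted alphabetically so that IDs are reproducible.
    Plain files (e.g. a README) are ignored; only directories become classes.
    """
    class_names = sorted(p.name for p in Path(root_dir).iterdir() if p.is_dir())
    return {name: idx for idx, name in enumerate(class_names)}
```

Sorting matters because Path.iterdir yields entries in arbitrary, filesystem-dependent order; without it, the same folder could receive a different integer ID on another machine, which would silently scramble the labels of a saved checkpoint.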

DeepAudioX makes it easy to create a PyTorch Dataset for audio classification. This is done with the audio_classification_dataset_from_dir function.

[3]:
from deepaudiox import audio_classification_dataset_from_dir

train_dataset = audio_classification_dataset_from_dir(
    root_dir=TRAIN_DIR,
    sample_rate=SAMPLE_RATE,
    class_mapping=class_mapping,
)

test_dataset = audio_classification_dataset_from_dir(
    root_dir=TEST_DIR,
    sample_rate=SAMPLE_RATE,
    class_mapping=class_mapping,
)

print(f"Train samples : {len(train_dataset)}")
print(f"Test  samples : {len(test_dataset)}")
Train samples : 799
Test  samples : 195

Each item returned by the dataset is a dictionary containing:

{
    "path": str,                # File path of the audio
    "y_true": int,              # Integer class ID
    "class_name": str,          # String class label
    "segment_idx": int,         # Segment index (for segmented audio)
    "feature": np.ndarray       # Audio waveform as numpy array
}
[4]:
item = train_dataset[0]
for key, value in item.items():
    print(key, ":", value)
path : /data/gtzan/train/blues/blues.00056.wav
y_true : 0
class_name : blues
segment_idx : 0
feature : [-0.15826386 -0.26958048 -0.29068074 ... -0.37689695 -0.36377403
 -0.12435105]
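Since y_true and class_name encode the same information through class_mapping, inverting the mapping lets you turn integer predictions back into genre names. A minimal sketch with a hypothetical invert_mapping helper, checked against the item printed above (the mapping literal is the GTZAN mapping shown in this tutorial):

```python
def invert_mapping(class_mapping: dict) -> dict:
    """Turn {class_name: id} into {id: class_name} for decoding predictions."""
    idx_to_class = {idx: name for name, idx in class_mapping.items()}
    # Duplicate IDs would silently overwrite each other, so fail loudly instead.
    if len(idx_to_class) != len(class_mapping):
        raise ValueError("class_mapping contains duplicate integer IDs")
    return idx_to_class


class_mapping = {"blues": 0, "classical": 1, "country": 2, "disco": 3, "hiphop": 4,
                 "jazz": 5, "metal": 6, "pop": 7, "reggae": 8, "rock": 9}
item = {"y_true": 0, "class_name": "blues"}  # the two label fields of train_dataset[0]

idx_to_class = invert_mapping(class_mapping)
print(idx_to_class[item["y_true"]])  # blues
```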

3. Working with Segmented Audio Files

audio_classification_dataset_from_dir accepts a segment_duration argument (in seconds) that controls the duration of each sample returned by __getitem__.

It is common practice in audio classification to work with short segments (e.g., under 10 seconds) to capture temporal changes within a track. When segment_duration is set, each segment is treated as an independent sample in the dataset and inherits the class label of its source file; the segment_idx field indicates which segment of the file a sample corresponds to. Below we re-initialize both datasets with 3-second segments. Since GTZAN clips are about 30 seconds long, each file yields about ten segments, so the sample counts grow roughly tenfold.

[5]:
train_dataset = audio_classification_dataset_from_dir(
    root_dir=TRAIN_DIR,
    sample_rate=SAMPLE_RATE,
    class_mapping=class_mapping,
    segment_duration=3.0
)

test_dataset = audio_classification_dataset_from_dir(
    root_dir=TEST_DIR,
    sample_rate=SAMPLE_RATE,
    class_mapping=class_mapping,
    segment_duration=3.0
)

print(f"Train samples : {len(train_dataset)}")
print(f"Test  samples : {len(test_dataset)}")
Train samples : 7981
Test  samples : 1950
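The arithmetic behind these counts: 3.0 s at 32 kHz is 96,000 samples per segment, and a clip of just over 30 s contains ten full segments. The sketch below uses a hypothetical split_into_segments helper (not DeepAudioX code). How DeepAudioX treats the incomplete tail of a file is not stated in this tutorial, so both dropping and zero-padding it are shown as assumptions.

```python
import numpy as np


def split_into_segments(waveform: np.ndarray, sample_rate: int,
                        segment_duration: float, pad_last: bool = False) -> list:
    """Cut a 1-D waveform into consecutive, non-overlapping fixed-length segments.

    Hypothetical helper. pad_last=False drops the incomplete tail;
    pad_last=True zero-pads it into one extra segment.
    """
    seg_len = int(round(segment_duration * sample_rate))  # 3.0 s * 32_000 Hz = 96_000 samples
    n_full = len(waveform) // seg_len
    segments = [waveform[i * seg_len:(i + 1) * seg_len] for i in range(n_full)]
    remainder = waveform[n_full * seg_len:]
    if pad_last and remainder.size > 0:
        # Zero-pad the leftover samples up to one full segment.
        segments.append(np.pad(remainder, (0, seg_len - remainder.size)))
    return segments


# Stand-in for a GTZAN clip: 960_416 samples, i.e. about 30.01 s at 32 kHz.
clip = np.zeros(960_416, dtype=np.float32)
print(len(split_into_segments(clip, 32_000, 3.0)))                 # 10 (tail dropped)
print(len(split_into_segments(clip, 32_000, 3.0, pad_last=True)))  # 11 (tail padded)
```

The choice is worth checking in your own pipeline: the inference example at the end of this tutorial reports 11 segment predictions for a single test file, while the segmented test set holds 10 samples per file on average (1,950 for 195 files).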

4. Building the Classifier

Use the AudioClassifier class to build an audio classifier on top of a pretrained backbone. The key parameters are num_classes, which sets the size of the classifier head, and sample_rate, which the backbone uses internally to extract spectro-temporal features from the raw audio waveforms.

We use PaSST, pretrained on AudioSet, as the backbone and freeze its weights, so only the lightweight classifier head is updated. This keeps training fast enough to run on a laptop GPU or even a CPU.

[6]:
from deepaudiox import AudioClassifier

model = AudioClassifier(
    num_classes=len(class_mapping),
    backbone="passt",
    pretrained=True,       # load pretrained weights for the backbone
    freeze_backbone=True,  # only the classifier head is trained
    sample_rate=SAMPLE_RATE,
)
[7]:
# Below you can check the architecture of the model.

model
[7]:
AudioClassifierConstructor(
  (backbone_constructor): BackboneConstructor(
    (backbone): PaSST(
      (feature_extractor): AugmentMelSTFT(
        winsize=800, hopsize=320
        (freqm): FrequencyMasking()
        (timem): TimeMasking()
      )
      (patch_embed): PatchEmbed(
        (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10))
        (norm): Identity()
      )
      (pos_drop): Dropout(p=0.0, inplace=False)
      (blocks): Sequential(
        (0): Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (act): GELU(approximate='none')
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (1): Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (act): GELU(approximate='none')
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (2): Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (act): GELU(approximate='none')
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (3): Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (act): GELU(approximate='none')
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (4): Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (act): GELU(approximate='none')
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (5): Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (act): GELU(approximate='none')
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (6): Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (act): GELU(approximate='none')
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (7): Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (act): GELU(approximate='none')
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (8): Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (act): GELU(approximate='none')
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (9): Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (act): GELU(approximate='none')
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (10): Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (act): GELU(approximate='none')
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
        (11): Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (act): GELU(approximate='none')
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
      )
      (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (pre_logits): Identity()
    )
    (pooling): GAP()
  )
  (classifier): MLPHead(
    (model): Sequential(
      (0): Linear(in_features=768, out_features=10, bias=True)
    )
  )
)
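The printout shows that the trainable head is a single Linear(in_features=768, out_features=10) layer. A back-of-the-envelope count, using only layer shapes visible above, shows how small that is next to the frozen transformer blocks. The helpers are illustrative, and only the head and the 12 blocks are counted; backbone parameters the printout does not list (for example positional embeddings) and the patch embedding are left out.

```python
def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Learnable values in a Linear layer: the weight matrix plus an optional bias."""
    return in_features * out_features + (out_features if bias else 0)


def layernorm_params(dim: int) -> int:
    """LayerNorm with elementwise_affine=True learns one scale and one shift per feature."""
    return 2 * dim


# Trainable part when freeze_backbone=True: the MLPHead's Linear(768 -> 10).
head = linear_params(768, 10)

# One transformer Block as printed above (attention + MLP + two LayerNorms).
block = (
    layernorm_params(768)          # norm1
    + linear_params(768, 2304)     # attn.qkv
    + linear_params(768, 768)      # attn.proj
    + layernorm_params(768)        # norm2
    + linear_params(768, 3072)     # mlp.fc1
    + linear_params(3072, 768)     # mlp.fc2
)

print(f"head:      {head:,}")        # 7,690
print(f"one block: {block:,}")       # 7,087,872
print(f"12 blocks: {12 * block:,}")  # 85,054,464
```

Assuming the returned model is a standard torch.nn.Module (the printout suggests it is), sum(p.numel() for p in model.parameters() if p.requires_grad) is a quick way to confirm the freeze: it should match the head's 7,690 if only the head is trainable.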

5. Training

The Trainer class handles the complete training loop out of the box, with the following defaults:

Feature                     Default
Train / validation split    80 / 20 (automatic)
Optimizer                   Adam, lr = 1e-3
LR scheduler                ReduceLROnPlateau
Loss                        Cross-entropy
Early stopping              stop after patience consecutive epochs without improvement in validation loss
Checkpointing               best model saved to path_to_checkpoint

All defaults can be overridden; see Tutorial 04 for advanced configuration.

[8]:
from deepaudiox import Trainer

trainer = Trainer(
    train_dset=train_dataset,
    model=model,
    epochs=30,
    patience=5,
    batch_size=64,
    path_to_checkpoint=CHECKPOINT,
    device=DEVICE
)

trainer.train() # Simply call the train method to start training!
Using GPU: NVIDIA GeForce RTX 4090
[Epoch 1/30]
Epoch 1 | Train Loss: 1.7399 | Val. Loss: 1.1800 | Time: 19.64s
[CHECKPOINTER] Validation loss decreased: (inf --> 1.180008), (-nan%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 2/30]
Epoch 2 | Train Loss: 1.1476 | Val. Loss: 0.8874 | Time: 18.99s
[CHECKPOINTER] Validation loss decreased: (1.180008 --> 0.887417), (-24.80%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 3/30]
Epoch 3 | Train Loss: 0.9254 | Val. Loss: 0.7591 | Time: 19.41s
[CHECKPOINTER] Validation loss decreased: (0.887417 --> 0.759138), (-14.46%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 4/30]
Epoch 4 | Train Loss: 0.8158 | Val. Loss: 0.7005 | Time: 19.03s
[CHECKPOINTER] Validation loss decreased: (0.759138 --> 0.700534), (-7.72%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 5/30]
Epoch 5 | Train Loss: 0.7439 | Val. Loss: 0.6588 | Time: 19.73s
[CHECKPOINTER] Validation loss decreased: (0.700534 --> 0.658843), (-5.95%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 6/30]
Epoch 6 | Train Loss: 0.6974 | Val. Loss: 0.6386 | Time: 19.19s
[CHECKPOINTER] Validation loss decreased: (0.658843 --> 0.638621), (-3.07%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 7/30]
Epoch 7 | Train Loss: 0.6623 | Val. Loss: 0.6237 | Time: 19.12s
[CHECKPOINTER] Validation loss decreased: (0.638621 --> 0.623666), (-2.34%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 8/30]
Epoch 8 | Train Loss: 0.6299 | Val. Loss: 0.6037 | Time: 19.14s
[CHECKPOINTER] Validation loss decreased: (0.623666 --> 0.603713), (-3.20%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 9/30]
Epoch 9 | Train Loss: 0.6011 | Val. Loss: 0.5913 | Time: 19.14s
[CHECKPOINTER] Validation loss decreased: (0.603713 --> 0.591263), (-2.06%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 10/30]
Epoch 10 | Train Loss: 0.5858 | Val. Loss: 0.5845 | Time: 19.12s
[CHECKPOINTER] Validation loss decreased: (0.591263 --> 0.584525), (-1.14%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 11/30]
Epoch 11 | Train Loss: 0.5679 | Val. Loss: 0.5834 | Time: 19.08s
[CHECKPOINTER] Validation loss decreased: (0.584525 --> 0.583369), (-0.20%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 12/30]
Epoch 12 | Train Loss: 0.5546 | Val. Loss: 0.5725 | Time: 19.24s
[CHECKPOINTER] Validation loss decreased: (0.583369 --> 0.572482), (-1.87%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 13/30]
Epoch 13 | Train Loss: 0.5458 | Val. Loss: 0.5639 | Time: 19.12s
[CHECKPOINTER] Validation loss decreased: (0.572482 --> 0.563908), (-1.50%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 14/30]
Epoch 14 | Train Loss: 0.5314 | Val. Loss: 0.5577 | Time: 19.14s
[CHECKPOINTER] Validation loss decreased: (0.563908 --> 0.557748), (-1.09%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 15/30]
Epoch 15 | Train Loss: 0.5154 | Val. Loss: 0.5573 | Time: 19.16s
[CHECKPOINTER] Validation loss decreased: (0.557748 --> 0.557256), (-0.09%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 16/30]
Epoch 16 | Train Loss: 0.5136 | Val. Loss: 0.5517 | Time: 19.08s
[CHECKPOINTER] Validation loss decreased: (0.557256 --> 0.551656), (-1.01%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 17/30]
Epoch 17 | Train Loss: 0.4960 | Val. Loss: 0.5537 | Time: 19.13s
[Epoch 18/30]
Epoch 18 | Train Loss: 0.4868 | Val. Loss: 0.5557 | Time: 18.84s
[Epoch 19/30]
Epoch 19 | Train Loss: 0.4890 | Val. Loss: 0.5368 | Time: 18.85s
[CHECKPOINTER] Validation loss decreased: (0.551656 --> 0.536754), (-2.70%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 20/30]
Epoch 20 | Train Loss: 0.4828 | Val. Loss: 0.5426 | Time: 19.40s
[Epoch 21/30]
Epoch 21 | Train Loss: 0.4664 | Val. Loss: 0.5397 | Time: 18.88s
[Epoch 22/30]
Epoch 22 | Train Loss: 0.4576 | Val. Loss: 0.5293 | Time: 18.82s
[CHECKPOINTER] Validation loss decreased: (0.536754 --> 0.529326), (-1.38%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 23/30]
Epoch 23 | Train Loss: 0.4554 | Val. Loss: 0.5226 | Time: 19.08s
[CHECKPOINTER] Validation loss decreased: (0.529326 --> 0.522598), (-1.27%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 24/30]
Epoch 24 | Train Loss: 0.4541 | Val. Loss: 0.5298 | Time: 19.10s
[Epoch 25/30]
Epoch 25 | Train Loss: 0.4490 | Val. Loss: 0.5253 | Time: 18.86s
[Epoch 26/30]
Epoch 26 | Train Loss: 0.4373 | Val. Loss: 0.5240 | Time: 18.86s
[Epoch 27/30]
Epoch 27 | Train Loss: 0.4341 | Val. Loss: 0.5269 | Time: 18.94s
[EARLY STOPPING] Elapsed epochs: 4 out of 5
[Epoch 28/30]
Epoch 28 | Train Loss: 0.4306 | Val. Loss: 0.5257 | Time: 18.84s
[EARLY STOPPING] Elapsed epochs: 5 out of 5
[EARLY STOPPING] Patience exceeded, early stoping ...
Training has finished.
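The stopping point follows directly from the patience rule. The sketch below is not DeepAudioX's implementation; it is a minimal patience-based early-stopping loop (strict improvement, no minimum delta, learning-rate scheduling ignored) with a hypothetical name, replayed on the validation losses from the log above.

```python
def replay_early_stopping(val_losses: list, patience: int) -> tuple:
    """Return (best_epoch, stop_epoch), both 1-based.

    Training stops once `patience` consecutive epochs fail to improve on the
    best validation loss so far; the checkpoint kept is the one from best_epoch.
    """
    best_loss = float("inf")
    best_epoch = 0
    bad_epochs = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch, bad_epochs = loss, epoch, 0  # new checkpoint, reset counter
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return best_epoch, epoch
    return best_epoch, len(val_losses)


# Validation losses copied from the training log above (epochs 1-28).
val_losses = [1.1800, 0.8874, 0.7591, 0.7005, 0.6588, 0.6386, 0.6237, 0.6037,
              0.5913, 0.5845, 0.5834, 0.5725, 0.5639, 0.5577, 0.5573, 0.5517,
              0.5537, 0.5557, 0.5368, 0.5426, 0.5397, 0.5293, 0.5226, 0.5298,
              0.5253, 0.5240, 0.5269, 0.5257]
print(replay_early_stopping(val_losses, patience=5))  # (23, 28)
```

Note that the counter resets on every improvement: epochs 17 and 18 did not improve, but epoch 19 did, so training continued. Only the five consecutive non-improving epochs 24 to 28 trigger the stop, and the saved checkpoint is the one from epoch 23.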

6. Evaluating on the Test Set

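Once training has finished, we reload the best checkpoint with AudioClassifier.from_checkpoint and pass it to an Evaluator, which measures the classifier's performance on the held-out test split.

evaluate() prints the class mapping, a classification report (per-class precision, recall, and F1-score, plus overall accuracy and macro and weighted averages), the confusion matrix, and the average posteriors per class. Because the test set was built with 3-second segments, all metrics are computed per segment (1,950 samples), not per file.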

[9]:
from deepaudiox import Evaluator

model = AudioClassifier.from_checkpoint(CHECKPOINT)

evaluator = Evaluator(
    test_dset=test_dataset,
    model=model,
    class_mapping=class_mapping,
    device=DEVICE,
)

evaluator.evaluate()
Using GPU: NVIDIA GeForce RTX 4090
Testing has finished.
[REPORTER] Class mapping: {'blues': 0, 'classical': 1, 'country': 2, 'disco': 3, 'hiphop': 4, 'jazz': 5, 'metal': 6, 'pop': 7, 'reggae': 8, 'rock': 9}

[REPORTER] Classification Report:

              precision    recall  f1-score   support

       blues       0.74      0.69      0.72       200
   classical       0.98      0.99      0.99       200
     country       0.66      0.95      0.78       200
       disco       0.86      0.89      0.88       200
      hiphop       0.91      0.98      0.94       200
        jazz       0.86      0.94      0.90       190
       metal       0.87      0.88      0.88       190
         pop       0.87      0.93      0.90       190
      reggae       0.90      0.51      0.65       190
        rock       0.73      0.55      0.63       190

    accuracy                           0.83      1950
   macro avg       0.84      0.83      0.83      1950
weighted avg       0.84      0.83      0.83      1950

[REPORTER] Confusion Matrix:

[[139   1  27  13   0  14   0   0   0   6]
 [  0 198   0   0   0   1   0   1   0   0]
 [  9   0 190   0   0   0   0   0   0   1]
 [  2   0   3 178   6   0   0   3   5   3]
 [  0   0   0   3 196   0   0   0   0   1]
 [  0   2   2   0   1 179   0   1   2   3]
 [  4   0   3   1   1   0 167   0   0  14]
 [  1   0   4   4   0   0   0 176   0   5]
 [ 30   0  12   5  10  11   0  21  96   5]
 [  2   1  48   2   1   2  24   1   4 105]]
[REPORTER] Average Posteriors:

blues               : 0.722
classical           : 0.958
country             : 0.920
disco               : 0.841
hiphop              : 0.893
jazz                : 0.932
metal               : 0.903
pop                 : 0.855
reggae              : 0.767
rock                : 0.639
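The report's headline numbers can be recomputed from the confusion matrix, which helps when reading it. Rows are true classes and columns are predicted classes (each row sums to the class's support in the report). The summarize function below is an illustrative sketch, not part of DeepAudioX.

```python
labels = ["blues", "classical", "country", "disco", "hiphop",
          "jazz", "metal", "pop", "reggae", "rock"]
# Confusion matrix copied from the report above (rows: true, columns: predicted).
cm = [
    [139,   1,  27,  13,   0,  14,   0,   0,   0,   6],
    [  0, 198,   0,   0,   0,   1,   0,   1,   0,   0],
    [  9,   0, 190,   0,   0,   0,   0,   0,   0,   1],
    [  2,   0,   3, 178,   6,   0,   0,   3,   5,   3],
    [  0,   0,   0,   3, 196,   0,   0,   0,   0,   1],
    [  0,   2,   2,   0,   1, 179,   0,   1,   2,   3],
    [  4,   0,   3,   1,   1,   0, 167,   0,   0,  14],
    [  1,   0,   4,   4,   0,   0,   0, 176,   0,   5],
    [ 30,   0,  12,   5,  10,  11,   0,  21,  96,   5],
    [  2,   1,  48,   2,   1,   2,  24,   1,   4, 105],
]


def summarize(cm: list, labels: list) -> tuple:
    """Return accuracy, per-class recall and precision, and the largest confusion."""
    n = len(labels)
    total = sum(sum(row) for row in cm)
    correct = sum(cm[i][i] for i in range(n))  # diagonal = correctly classified
    recall = {labels[i]: cm[i][i] / sum(cm[i]) if sum(cm[i]) else 0.0 for i in range(n)}
    precision = {}
    for j in range(n):
        predicted_as_j = sum(cm[i][j] for i in range(n))  # column sum
        precision[labels[j]] = cm[j][j] / predicted_as_j if predicted_as_j else 0.0
    # The largest off-diagonal cell is the most frequent (true, predicted) mix-up.
    worst = max((cm[i][j], labels[i], labels[j])
                for i in range(n) for j in range(n) if i != j)
    return correct / total, recall, precision, worst


accuracy, recall, precision, worst = summarize(cm, labels)
print(f"accuracy: {accuracy:.4f}")                       # 0.8328
print(f"reggae recall: {recall['reggae']:.2f}")          # 0.51
print(f"country precision: {precision['country']:.2f}")  # 0.66
print(worst)                                             # (48, 'rock', 'country')
```

The largest off-diagonal cells, rock predicted as country (48) and reggae predicted as blues (30), are the single biggest error sources for the two classes with the lowest recall.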

7. Inference on a New Audio File

The AudioClassifier method inference_on_file makes predictions on unseen data. It accepts a path to any WAV or MP3 file and returns the predicted class label along with its posterior probability.

The segment_duration argument sets the length of the segments into which the file is split before prediction. For best performance, it should match the segment duration used during training.

Here we take the first .wav file that Path.rglob finds under the test directory as a quick sanity check. The traversal order depends on the filesystem, so this is not necessarily the alphabetically first file.

[10]:
from pathlib import Path

sample_file = next(Path(TEST_DIR).rglob("*.wav"))
true_label  = sample_file.parent.name  # ground-truth genre, taken from the folder name

result = model.inference_on_file(
    path=str(sample_file),
    sample_rate=SAMPLE_RATE,
    class_mapping=class_mapping,
    segment_duration=3.0,  # same segment duration used during training
)

The method returns a dictionary whose final_label entry is the file-level prediction: the result of a majority vote across all 3-second segments in the file, with final_posterior as its confidence. The dictionary also contains the label and posterior of each individual segment (segment_labels and segment_posteriors).

[11]:
result
[11]:
{'final_label': 'country',
 'final_posterior': 0.959261554479599,
 'segment_labels': ['country',
  'country',
  'country',
  'country',
  'country',
  'country',
  'country',
  'country',
  'country',
  'country',
  'classical'],
 'segment_posteriors': [0.9425491094589233,
  0.9905596375465393,
  0.9864566326141357,
  0.9714261889457703,
  0.9702315330505371,
  0.973042905330658,
  0.9922925233840942,
  0.9042606353759766,
  0.9638944864273071,
  0.8979018926620483,
  0.6707818508148193]}
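In this particular output, final_posterior equals (up to floating-point rounding) the mean posterior of the ten segments labelled 'country'; the low-confidence 'classical' vote on the last segment does not contribute. The sketch below reproduces that observed behavior. aggregate_segments is a hypothetical name, and DeepAudioX's actual aggregation, including how it breaks ties, may differ.

```python
from collections import Counter


def aggregate_segments(segment_labels: list, segment_posteriors: list) -> tuple:
    """Majority vote over segment labels, plus the mean posterior of the winning segments.

    Hypothetical reconstruction, not DeepAudioX code. Ties go to the label
    encountered first, which is how Counter.most_common orders equal counts.
    """
    if not segment_labels or len(segment_labels) != len(segment_posteriors):
        raise ValueError("expected one posterior per segment label")
    final_label, _ = Counter(segment_labels).most_common(1)[0]
    winning = [p for lab, p in zip(segment_labels, segment_posteriors) if lab == final_label]
    return final_label, sum(winning) / len(winning)


# Segment-level values copied from the result printed above.
segment_labels = ["country"] * 10 + ["classical"]
segment_posteriors = [0.9425491094589233, 0.9905596375465393, 0.9864566326141357,
                      0.9714261889457703, 0.9702315330505371, 0.973042905330658,
                      0.9922925233840942, 0.9042606353759766, 0.9638944864273071,
                      0.8979018926620483, 0.6707818508148193]

label, confidence = aggregate_segments(segment_labels, segment_posteriors)
print(label, f"{confidence:.4f}")  # country 0.9593
```

Averaging only the winning segments reports how confident the model is in the label it chose; averaging all eleven segments instead would mix in the posterior of a different class.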