Tutorial 01 — Quick Start
In this tutorial we will build and train an audio classifier on the GTZAN dataset, which consists of audio samples from 10 different music genres.
By the end of this tutorial we will have:
Loaded an audio dataset from a directory
Built and trained an AudioClassifier with a pretrained backbone
Evaluated model accuracy on a held-out test set
Run single-file inference on unseen data
Assumed dataset layout
DeepAudioX expects the dataset to be in the following directory structure.
gtzan/
├── train/
│   ├── blues/
│   │   ├── blues.00000.wav
│   │   ├── blues.00001.wav
│   │   └── ...
│   ├── classical/
│   ├── country/
│   ├── disco/
│   ├── hiphop/
│   ├── jazz/
│   ├── metal/
│   ├── pop/
│   ├── reggae/
│   └── rock/
└── test/
    ├── blues/
    └── ...
Each sub-folder name becomes the class label. See Tutorial 02 for an alternative loading strategy using Python dictionaries.
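Before loading anything, it can help to confirm that the directory really has this shape. The following is a minimal sketch using only the standard library; it is not part of DeepAudioX, and the helper names summarize_split and check_layout are hypothetical. It counts the .wav files in each class sub-folder and checks that train/ and test/ contain the same class folders.

```python
from pathlib import Path


def summarize_split(split_dir):
    """Return {class_name: number_of_wav_files} for one split directory."""
    split = Path(split_dir)
    return {
        class_dir.name: sum(1 for _ in class_dir.glob("*.wav"))
        for class_dir in sorted(split.iterdir())
        if class_dir.is_dir()
    }


def check_layout(root_dir):
    """Check that train/ and test/ contain the same class sub-folders."""
    root = Path(root_dir)
    train = summarize_split(root / "train")
    test = summarize_split(root / "test")
    # Symmetric difference: class folders present in only one of the splits.
    mismatched = sorted(set(train) ^ set(test))
    if mismatched:
        raise ValueError(f"Class folders differ between splits: {mismatched}")
    return train, test


# Example usage (the path is a placeholder for your own dataset root):
# train_counts, test_counts = check_layout("/data/gtzan")
# print(train_counts)
```

A class folder that exists in only one split would otherwise surface much later, as a missing label at evaluation time.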
1. Configuration
Update the two directory paths below and choose your target device before running any other cell.
[1]:
TRAIN_DIR = "/data/gtzan/train" # directory containing class sub-folders
TEST_DIR = "/data/gtzan/test"
SAMPLE_RATE = 32_000 # Sampling Rate to use when loading audio files
CHECKPOINT = "checkpoint.pt" # path where the best model will be saved
DEVICE = "cuda" # "cuda" | "mps" | "cpu"
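If you are unsure which device is available, the choice can be made programmatically. The sketch below is an optional convenience rather than a DeepAudioX feature: pick_device is a hypothetical helper that relies only on the standard PyTorch availability checks and falls back to the CPU when the preferred device cannot be used.

```python
try:
    import torch
except ImportError:  # PyTorch not installed; pick_device then returns "cpu"
    torch = None


def pick_device(preferred="cuda"):
    """Return `preferred` if it is usable on this machine, else a fallback."""
    if torch is None:
        return "cpu"
    available = {"cpu"}
    if torch.cuda.is_available():
        available.add("cuda")
    # torch.backends.mps is absent in older PyTorch builds, hence getattr.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        available.add("mps")
    if preferred in available:
        return preferred
    # Preferred device is unavailable: try the accelerators, then the CPU.
    for candidate in ("cuda", "mps"):
        if candidate in available:
            return candidate
    return "cpu"


DEVICE = pick_device("cuda")
print(f"Using device: {DEVICE}")
```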
2. Loading the Dataset
get_class_mapping_from_dir scans the top-level sub-folders of the training directory and builds a {class_name: int} mapping automatically. This mapping must be passed to every dataset, model, and inference call — it is the single source of truth for label ordering throughout the project.
[2]:
from deepaudiox import get_class_mapping_from_dir
class_mapping = get_class_mapping_from_dir(TRAIN_DIR)
print(f"{len(class_mapping)} classes detected: {list(class_mapping.keys())}")
10 classes detected: ['blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal', 'pop', 'reggae', 'rock']
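To make the idea concrete, here is a rough sketch of how such a mapping could be derived from folder names. It is an illustration only, not the library's implementation: class_mapping_from_subfolders is a hypothetical name, and the alphabetical ordering is an assumption that happens to agree with the output above.

```python
from pathlib import Path


def class_mapping_from_subfolders(root_dir):
    """Map each immediate sub-folder name to an integer ID (alphabetical)."""
    # Plain files at the top level are ignored; only directories are classes.
    names = sorted(p.name for p in Path(root_dir).iterdir() if p.is_dir())
    return {name: idx for idx, name in enumerate(names)}
```

Whatever the ordering rule is, it has to be deterministic: the integer IDs are what the model actually learns, so a mapping that changed between runs would silently scramble the labels.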
DeepAudioX makes it easy to create a PyTorch Dataset for audio classification. This is done via the audio_classification_dataset_from_dir function.
[3]:
from deepaudiox import audio_classification_dataset_from_dir
train_dataset = audio_classification_dataset_from_dir(
root_dir=TRAIN_DIR,
sample_rate=SAMPLE_RATE,
class_mapping=class_mapping,
)
test_dataset = audio_classification_dataset_from_dir(
root_dir=TEST_DIR,
sample_rate=SAMPLE_RATE,
class_mapping=class_mapping,
)
print(f"Train samples : {len(train_dataset)}")
print(f"Test samples : {len(test_dataset)}")
Train samples : 799
Test samples : 195
Each item returned by the dataset is a dictionary containing:
{
"path": str, # File path of the audio
"y_true": int, # Integer class ID
"class_name": str, # String class label
"segment_idx": int, # Segment index (for segmented audio)
"feature": np.ndarray # Audio waveform as numpy array
}
[4]:
item = train_dataset[0]
for key, value in item.items():
print(key, ":", value)
path : /data/gtzan/train/blues/blues.00056.wav
y_true : 0
class_name : blues
segment_idx : 0
feature : [-0.15826386 -0.26958048 -0.29068074 ... -0.37689695 -0.36377403
-0.12435105]
3. Working with segmented-audio files
The audio_classification_dataset_from_dir function accepts a segment_duration argument that controls the duration, in seconds, of each sample returned by __getitem__.
It is customary in audio analysis to work with short segments (e.g., < 10 sec) to capture temporal changes within a track. If you provide a segment_duration, each audio file is split into segments of that length, and each segment is treated as an independent sample with the same class label as the original file. The segment_idx field in the dataset output indicates which segment a sample corresponds to. Below we re-initialize the datasets with a segment duration of 3 seconds.
[5]:
train_dataset = audio_classification_dataset_from_dir(
root_dir=TRAIN_DIR,
sample_rate=SAMPLE_RATE,
class_mapping=class_mapping,
segment_duration=3.0
)
test_dataset = audio_classification_dataset_from_dir(
root_dir=TEST_DIR,
sample_rate=SAMPLE_RATE,
class_mapping=class_mapping,
segment_duration=3.0
)
print(f"Train samples : {len(train_dataset)}")
print(f"Test samples : {len(test_dataset)}")
Train samples : 7981
Test samples : 1950
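The jump from 799 to 7981 training samples is roughly a factor of ten, which matches GTZAN clips of about 30 seconds being cut into 3-second pieces. The sketch below shows the arithmetic for non-overlapping segments. It is an illustration under stated assumptions: segment_bounds and its keep_remainder flag are hypothetical, and how DeepAudioX treats a final segment shorter than segment_duration is not specified here.

```python
def segment_bounds(num_samples, sample_rate, segment_duration, keep_remainder=False):
    """Return (start, end) sample indices of consecutive, non-overlapping segments."""
    seg_len = int(round(segment_duration * sample_rate))
    if seg_len <= 0:
        raise ValueError("segment_duration must be positive")
    # Full-length segments only: the last start must leave room for seg_len samples.
    bounds = [(s, s + seg_len) for s in range(0, num_samples - seg_len + 1, seg_len)]
    tail_start = len(bounds) * seg_len
    # Hypothetical option: also keep the shorter leftover piece at the end.
    if keep_remainder and tail_start < num_samples:
        bounds.append((tail_start, num_samples))
    return bounds


# A 30-second clip at 32 kHz split into 3-second segments.
clip = segment_bounds(num_samples=30 * 32_000, sample_rate=32_000, segment_duration=3.0)
print(len(clip), clip[0], clip[-1])
```

Each 3-second segment holds 96,000 samples at 32 kHz, so a clip of exactly 30 seconds yields 10 full segments.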
4. Building the Classifier
Use the AudioClassifier class to easily build an audio classifier with a pretrained backbone. The only parameters you’ll need to specify are num_classes for the classifier head and sample_rate, which the backbone uses internally to extract spectro-temporal features from the raw audio waveforms.
We use PaSST, pre-trained on AudioSet, as the backbone and freeze its weights, so only the lightweight classifier head is updated. This keeps training fast enough to run on a laptop GPU or even a CPU.
[6]:
from deepaudiox import AudioClassifier
model = AudioClassifier(
num_classes=len(class_mapping),
backbone="passt",
pretrained=True, # load pretrained weights for the backbone
freeze_backbone=True, # only the classifier head is trained
sample_rate=SAMPLE_RATE,
)
[7]:
# Below you can check the architecture of the model.
model
[7]:
AudioClassifierConstructor(
(backbone_constructor): BackboneConstructor(
(backbone): PaSST(
(feature_extractor): AugmentMelSTFT(
winsize=800, hopsize=320
(freqm): FrequencyMasking()
(timem): TimeMasking()
)
(patch_embed): PatchEmbed(
(proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10))
(norm): Identity()
)
(pos_drop): Dropout(p=0.0, inplace=False)
(blocks): Sequential(
(0): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(1): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(2): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(3): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(4): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(5): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(6): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(7): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(8): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(9): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(10): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
(11): Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
)
)
)
(norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(pre_logits): Identity()
)
(pooling): GAP()
)
(classifier): MLPHead(
(model): Sequential(
(0): Linear(in_features=768, out_features=10, bias=True)
)
)
)
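The printout makes it easy to see why freezing the backbone keeps training cheap. The sketch below is a back-of-the-envelope count based only on the layer shapes shown above; it does not query the model. It covers the 12 transformer blocks and the classifier head, and deliberately leaves out the patch embedding, positional embeddings, and final LayerNorm, so the true backbone size is somewhat larger.

```python
def linear_params(in_features, out_features, bias=True):
    """Weights plus optional bias of a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)


def layernorm_params(dim):
    """Scale and shift vectors of a LayerNorm with elementwise_affine=True."""
    return 2 * dim


# One transformer Block, using the shapes from the printout.
block = (
    layernorm_params(768)          # norm1
    + linear_params(768, 2304)     # attn.qkv
    + linear_params(768, 768)      # attn.proj
    + layernorm_params(768)        # norm2
    + linear_params(768, 3072)     # mlp.fc1
    + linear_params(3072, 768)     # mlp.fc2
)
backbone_blocks = 12 * block       # blocks (0) to (11), all frozen
head = linear_params(768, 10)      # the only trainable part

print(f"Parameters per block : {block:,}")
print(f"12 frozen blocks     : {backbone_blocks:,}")
print(f"Trainable head       : {head:,}")
```

With freeze_backbone=True, the optimizer updates 7,690 head parameters while more than 85 million backbone parameters stay fixed.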
5. Training
The Trainer class handles the complete training loop out of the box:
| Feature | Default |
|---|---|
| Train / validation split | 80 / 20 (automatic) |
| Optimizer | Adam, lr = 1e-3 |
| LR scheduler | ReduceLROnPlateau |
| Loss | Cross-Entropy |
| Early stopping | |
| Checkpointing | Best model saved to |
All defaults can be overridden — see Tutorial 04 for advanced configuration.
[8]:
from deepaudiox import Trainer
trainer = Trainer(
train_dset=train_dataset,
model=model,
epochs=30,
patience=5,
batch_size=64,
path_to_checkpoint=CHECKPOINT,
device=DEVICE
)
trainer.train() # Simply call the train method to start training!
Using GPU: NVIDIA GeForce RTX 4090
[Epoch 1/30]
Epoch 1 | Train Loss: 1.7399 | Val. Loss: 1.1800 | Time: 19.64s
[CHECKPOINTER] Validation loss decreased: (inf --> 1.180008), (-nan%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 2/30]
Epoch 2 | Train Loss: 1.1476 | Val. Loss: 0.8874 | Time: 18.99s
[CHECKPOINTER] Validation loss decreased: (1.180008 --> 0.887417), (-24.80%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 3/30]
Epoch 3 | Train Loss: 0.9254 | Val. Loss: 0.7591 | Time: 19.41s
[CHECKPOINTER] Validation loss decreased: (0.887417 --> 0.759138), (-14.46%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 4/30]
Epoch 4 | Train Loss: 0.8158 | Val. Loss: 0.7005 | Time: 19.03s
[CHECKPOINTER] Validation loss decreased: (0.759138 --> 0.700534), (-7.72%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 5/30]
Epoch 5 | Train Loss: 0.7439 | Val. Loss: 0.6588 | Time: 19.73s
[CHECKPOINTER] Validation loss decreased: (0.700534 --> 0.658843), (-5.95%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 6/30]
Epoch 6 | Train Loss: 0.6974 | Val. Loss: 0.6386 | Time: 19.19s
[CHECKPOINTER] Validation loss decreased: (0.658843 --> 0.638621), (-3.07%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 7/30]
Epoch 7 | Train Loss: 0.6623 | Val. Loss: 0.6237 | Time: 19.12s
[CHECKPOINTER] Validation loss decreased: (0.638621 --> 0.623666), (-2.34%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 8/30]
Epoch 8 | Train Loss: 0.6299 | Val. Loss: 0.6037 | Time: 19.14s
[CHECKPOINTER] Validation loss decreased: (0.623666 --> 0.603713), (-3.20%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 9/30]
Epoch 9 | Train Loss: 0.6011 | Val. Loss: 0.5913 | Time: 19.14s
[CHECKPOINTER] Validation loss decreased: (0.603713 --> 0.591263), (-2.06%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 10/30]
Epoch 10 | Train Loss: 0.5858 | Val. Loss: 0.5845 | Time: 19.12s
[CHECKPOINTER] Validation loss decreased: (0.591263 --> 0.584525), (-1.14%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 11/30]
Epoch 11 | Train Loss: 0.5679 | Val. Loss: 0.5834 | Time: 19.08s
[CHECKPOINTER] Validation loss decreased: (0.584525 --> 0.583369), (-0.20%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 12/30]
Epoch 12 | Train Loss: 0.5546 | Val. Loss: 0.5725 | Time: 19.24s
[CHECKPOINTER] Validation loss decreased: (0.583369 --> 0.572482), (-1.87%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 13/30]
Epoch 13 | Train Loss: 0.5458 | Val. Loss: 0.5639 | Time: 19.12s
[CHECKPOINTER] Validation loss decreased: (0.572482 --> 0.563908), (-1.50%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 14/30]
Epoch 14 | Train Loss: 0.5314 | Val. Loss: 0.5577 | Time: 19.14s
[CHECKPOINTER] Validation loss decreased: (0.563908 --> 0.557748), (-1.09%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 15/30]
Epoch 15 | Train Loss: 0.5154 | Val. Loss: 0.5573 | Time: 19.16s
[CHECKPOINTER] Validation loss decreased: (0.557748 --> 0.557256), (-0.09%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 16/30]
Epoch 16 | Train Loss: 0.5136 | Val. Loss: 0.5517 | Time: 19.08s
[CHECKPOINTER] Validation loss decreased: (0.557256 --> 0.551656), (-1.01%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 17/30]
Epoch 17 | Train Loss: 0.4960 | Val. Loss: 0.5537 | Time: 19.13s
[Epoch 18/30]
Epoch 18 | Train Loss: 0.4868 | Val. Loss: 0.5557 | Time: 18.84s
[Epoch 19/30]
Epoch 19 | Train Loss: 0.4890 | Val. Loss: 0.5368 | Time: 18.85s
[CHECKPOINTER] Validation loss decreased: (0.551656 --> 0.536754), (-2.70%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 20/30]
Epoch 20 | Train Loss: 0.4828 | Val. Loss: 0.5426 | Time: 19.40s
[Epoch 21/30]
Epoch 21 | Train Loss: 0.4664 | Val. Loss: 0.5397 | Time: 18.88s
[Epoch 22/30]
Epoch 22 | Train Loss: 0.4576 | Val. Loss: 0.5293 | Time: 18.82s
[CHECKPOINTER] Validation loss decreased: (0.536754 --> 0.529326), (-1.38%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 23/30]
Epoch 23 | Train Loss: 0.4554 | Val. Loss: 0.5226 | Time: 19.08s
[CHECKPOINTER] Validation loss decreased: (0.529326 --> 0.522598), (-1.27%).
[CHECKPOINTER] Checkpoint saved successfully at: checkpoint.pt
[Epoch 24/30]
Epoch 24 | Train Loss: 0.4541 | Val. Loss: 0.5298 | Time: 19.10s
[Epoch 25/30]
Epoch 25 | Train Loss: 0.4490 | Val. Loss: 0.5253 | Time: 18.86s
[Epoch 26/30]
Epoch 26 | Train Loss: 0.4373 | Val. Loss: 0.5240 | Time: 18.86s
[Epoch 27/30]
Epoch 27 | Train Loss: 0.4341 | Val. Loss: 0.5269 | Time: 18.94s
[EARLY STOPPING] Elapsed epochs: 4 out of 5
[Epoch 28/30]
Epoch 28 | Train Loss: 0.4306 | Val. Loss: 0.5257 | Time: 18.84s
[EARLY STOPPING] Elapsed epochs: 5 out of 5
[EARLY STOPPING] Patience exceeded, early stoping ...
Training has finished.
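The log above can be read as a simple patience counter: every epoch without a new best validation loss increments the counter, and training stops once it reaches patience. The sketch below replays the validation losses from the log through that rule. It is an assumption about the logic that is consistent with the log, not the Trainer's actual code, and replay_early_stopping is a hypothetical helper.

```python
def replay_early_stopping(val_losses, patience):
    """Return (stop_epoch, best_epoch), both 1-indexed.

    stop_epoch is None if patience is never exhausted.
    """
    best_loss = float("inf")
    best_epoch = None
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            # New best: this is where a checkpoint would be written.
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch
    return None, best_epoch


# Validation losses copied from the training log above (epochs 1 to 28).
VAL_LOSSES = [
    1.1800, 0.8874, 0.7591, 0.7005, 0.6588, 0.6386, 0.6237,
    0.6037, 0.5913, 0.5845, 0.5834, 0.5725, 0.5639, 0.5577,
    0.5573, 0.5517, 0.5537, 0.5557, 0.5368, 0.5426, 0.5397,
    0.5293, 0.5226, 0.5298, 0.5253, 0.5240, 0.5269, 0.5257,
]

stop_epoch, best_epoch = replay_early_stopping(VAL_LOSSES, patience=5)
print(f"Stopped at epoch {stop_epoch}; best checkpoint from epoch {best_epoch}")
```

Under this rule the run stops at epoch 28, and the checkpoint on disk is the one written at epoch 23, the last epoch for which the log reports a saved checkpoint.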
6. Evaluating on the Test Set
Once training has finished, we reload the best checkpoint with AudioClassifier.from_checkpoint and initialize an Evaluator to measure the classifier's performance on the held-out test split.
evaluate() prints the class mapping, a classification report with per-class and averaged metrics, the confusion matrix, and the average posterior for each class.
[9]:
from deepaudiox import Evaluator
model = AudioClassifier.from_checkpoint(CHECKPOINT)
evaluator = Evaluator(
test_dset=test_dataset,
model=model,
class_mapping=class_mapping,
device=DEVICE,
)
evaluator.evaluate()
Using GPU: NVIDIA GeForce RTX 4090
Testing has finished.
[REPORTER] Class mapping: {'blues': 0, 'classical': 1, 'country': 2, 'disco': 3, 'hiphop': 4, 'jazz': 5, 'metal': 6, 'pop': 7, 'reggae': 8, 'rock': 9}
[REPORTER] Classification Report:
              precision    recall  f1-score   support
       blues       0.74      0.69      0.72       200
   classical       0.98      0.99      0.99       200
     country       0.66      0.95      0.78       200
       disco       0.86      0.89      0.88       200
      hiphop       0.91      0.98      0.94       200
        jazz       0.86      0.94      0.90       190
       metal       0.87      0.88      0.88       190
         pop       0.87      0.93      0.90       190
      reggae       0.90      0.51      0.65       190
        rock       0.73      0.55      0.63       190
    accuracy                           0.83      1950
   macro avg       0.84      0.83      0.83      1950
weighted avg       0.84      0.83      0.83      1950
[REPORTER] Confusion Matrix:
[[139 1 27 13 0 14 0 0 0 6]
[ 0 198 0 0 0 1 0 1 0 0]
[ 9 0 190 0 0 0 0 0 0 1]
[ 2 0 3 178 6 0 0 3 5 3]
[ 0 0 0 3 196 0 0 0 0 1]
[ 0 2 2 0 1 179 0 1 2 3]
[ 4 0 3 1 1 0 167 0 0 14]
[ 1 0 4 4 0 0 0 176 0 5]
[ 30 0 12 5 10 11 0 21 96 5]
[ 2 1 48 2 1 2 24 1 4 105]]
[REPORTER] Average Posteriors:
blues : 0.722
classical : 0.958
country : 0.920
disco : 0.841
hiphop : 0.893
jazz : 0.932
metal : 0.903
pop : 0.855
reggae : 0.767
rock : 0.639
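The numbers in the classification report can be recovered directly from the confusion matrix, which is a useful check when reading it. The sketch below assumes rows are true classes and columns are predicted classes; this matches the report, since each row sums to that class's support. The helper names are illustrative and not part of the Evaluator.

```python
CLASSES = ["blues", "classical", "country", "disco", "hiphop",
           "jazz", "metal", "pop", "reggae", "rock"]

# Confusion matrix copied from the evaluation output above.
CONFUSION = [
    [139, 1, 27, 13, 0, 14, 0, 0, 0, 6],
    [0, 198, 0, 0, 0, 1, 0, 1, 0, 0],
    [9, 0, 190, 0, 0, 0, 0, 0, 0, 1],
    [2, 0, 3, 178, 6, 0, 0, 3, 5, 3],
    [0, 0, 0, 3, 196, 0, 0, 0, 0, 1],
    [0, 2, 2, 0, 1, 179, 0, 1, 2, 3],
    [4, 0, 3, 1, 1, 0, 167, 0, 0, 14],
    [1, 0, 4, 4, 0, 0, 0, 176, 0, 5],
    [30, 0, 12, 5, 10, 11, 0, 21, 96, 5],
    [2, 1, 48, 2, 1, 2, 24, 1, 4, 105],
]


def per_class_metrics(matrix):
    """Return a list of (precision, recall, f1), one tuple per class."""
    n = len(matrix)
    metrics = []
    for k in range(n):
        true_positive = matrix[k][k]
        predicted_as_k = sum(matrix[row][k] for row in range(n))  # column sum
        actually_k = sum(matrix[k])                               # row sum
        precision = true_positive / predicted_as_k if predicted_as_k else 0.0
        recall = true_positive / actually_k if actually_k else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        metrics.append((precision, recall, f1))
    return metrics


def accuracy(matrix):
    """Fraction of samples on the diagonal."""
    correct = sum(matrix[k][k] for k in range(len(matrix)))
    return correct / sum(sum(row) for row in matrix)


for name, (p, r, f1) in zip(CLASSES, per_class_metrics(CONFUSION)):
    print(f"{name:>10}  precision={p:.2f}  recall={r:.2f}  f1={f1:.2f}")
print(f"accuracy={accuracy(CONFUSION):.2f}")
```

Reading the matrix this way also explains the weaker classes: 48 rock segments are predicted as country, which lowers country's precision and rock's recall at the same time, and 30 reggae segments are predicted as blues.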
7. Inference on a New Audio File
DeepAudioX provides a flexible method inference_on_file to make predictions on unseen data. The method accepts any WAV or MP3 path and returns the predicted class label along with its posterior probability.
The segment_duration argument sets the length of the segments into which the file is split for processing. For best performance, use the same segment duration that was used in training.
Here we grab the first .wav file from the test directory as a quick sanity check.
[10]:
from pathlib import Path
sample_file = next(Path(TEST_DIR).rglob("*.wav"))
true_label = sample_file.parent.name
result = model.inference_on_file(
path=str(sample_file),
sample_rate=SAMPLE_RATE,
class_mapping=class_mapping,
segment_duration=3.0, # same segment duration used during training
)
The method returns a dictionary whose final_label entry holds the file-level prediction, obtained by a majority vote across all 3-second segments in the file. The dictionary also contains the predicted label and posterior of each individual segment.
[11]:
result
[11]:
{'final_label': 'country',
'final_posterior': 0.959261554479599,
'segment_labels': ['country',
'country',
'country',
'country',
'country',
'country',
'country',
'country',
'country',
'country',
'classical'],
'segment_posteriors': [0.9425491094589233,
0.9905596375465393,
0.9864566326141357,
0.9714261889457703,
0.9702315330505371,
0.973042905330658,
0.9922925233840942,
0.9042606353759766,
0.9638944864273071,
0.8979018926620483,
0.6707818508148193]}
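The file-level result can be reproduced from the segment-level fields. The sketch below takes a majority vote over segment_labels and then averages the posteriors of the segments that voted for the winner. This averaging rule is an inference from the output above, where it reproduces final_posterior, rather than documented behaviour. The aggregate_segments helper is hypothetical, and ties are broken in favour of the label seen first, which may differ from the library.

```python
from collections import Counter


def aggregate_segments(segment_labels, segment_posteriors):
    """Majority vote over segments; mean posterior of the winning segments."""
    winner, _ = Counter(segment_labels).most_common(1)[0]
    winning = [p for label, p in zip(segment_labels, segment_posteriors)
               if label == winner]
    return {
        "final_label": winner,
        "final_posterior": sum(winning) / len(winning),
    }


# Segment-level values copied from the result above.
segment_labels = ["country"] * 10 + ["classical"]
segment_posteriors = [
    0.9425491094589233, 0.9905596375465393, 0.9864566326141357,
    0.9714261889457703, 0.9702315330505371, 0.973042905330658,
    0.9922925233840942, 0.9042606353759766, 0.9638944864273071,
    0.8979018926620483, 0.6707818508148193,
]

print(aggregate_segments(segment_labels, segment_posteriors))
```

The single classical segment is outvoted, and its low posterior of about 0.67 does not enter the final score.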