#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Segmentation methods for 1D signals.
This module gathers a collection of functions to detect regions of interest (ROIs)
in the temporal domain.
"""
#
# Authors: Juan Sebastian ULLOA <lisofomia@gmail.com>
# Sylvain HAUPERT <sylvain.haupert@mnhn.fr>
#
# License: New BSD License
# =============================================================================
# Load the modules
# =============================================================================
# Import external modules
import numpy as np
from scipy import signal
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import pandas as pd
# import internal modules
from maad.sound import sinc
from maad import sound, util
#%%
# =============================================================================
# Private functions
# =============================================================================
def _corresp_onset_offset(onset, offset, tmin, tmax):
"""
Check that each onsets have a corresponding offset.
Parameters
----------
onset: ndarray
array with onset from find_rois_1d
offset: ndarray
array with offset from find_rois_1d
tmin: float
Start time of wav file (in s)
tmax:
End time of wav file (in s)
Return
------
onset : ndarray
onset with corresponding offset
offset : ndarray
offset with corresponding onset
"""
if onset[0] > offset[0]: # check start
onset = np.insert(onset,0,tmin)
else:
pass
if onset[-1] > offset[-1]: # check end
offset = np.append(offset,tmax)
else:
pass
return onset, offset
#%%
def _energy_windowed(s, fs: float, wl: int=512):
"""
Computse windowed energy on an audio signal.
Computes the energy of the signals by windows of length wl. Used to amplify sectors where the density of energy is higher
Parameters
----------
s : ndarray
input signal
fs : float
frequency sampling of the signal, used to keep track of temporal information of the signal
wl : int, default is 512
length of the window to summarize the rms value
Returns
-------
time : ndarray
temporal index vector
s_rms : ndarray
windowed rms signal
"""
s_aux = np.lib.pad(s, (0, wl-len(s)%wl), 'reflect') # padding
s_aux = s_aux**2
# s_aux = np.abs(s_aux) # absolute value. alternative option
s_aux = np.reshape(s_aux,(int(len(s_aux)/wl),wl))
s_rms = np.mean(s_aux,1)
time = np.arange(0,len(s_rms)) * wl / fs + wl*0.5/fs
return time, s_rms
#%%
# =============================================================================
# Public functions
# =============================================================================
def _ricker(points, a):
    """Return a Ricker (Mexican hat) wavelet of ``points`` samples and width ``a``.

    Same formula as the former ``scipy.signal.ricker`` (deprecated in SciPy
    1.12, removed in SciPy 1.15).
    """
    amp = 2 / (np.sqrt(3 * a) * (np.pi ** 0.25))
    wsq = a ** 2
    vec = np.arange(0, points) - (points - 1.0) / 2
    xsq = vec ** 2
    mod = 1 - xsq / wsq
    gauss = np.exp(-xsq / (2 * wsq))
    return amp * mod * gauss


def _ricker_cwt(x, width):
    """Single-scale continuous wavelet transform of ``x`` with a Ricker wavelet.

    Reproduces ``scipy.signal.cwt(x, scipy.signal.ricker, [width])[0]``, which
    is no longer available in recent SciPy releases.
    """
    n_points = min(10 * width, len(x))
    wavelet = np.conj(_ricker(n_points, width)[::-1])
    return signal.convolve(x, wavelet, mode='same')


def find_rois_cwt(s, fs, flims, tlen, th: float=0, display=False, save_df=False,
                  savefilename='rois.csv', **kwargs):
    """
    Find regions of interest using known estimates of signal length and frequency limits.

    The general approach is based on the continuous wavelet transform and
    follows a three-step process:

    1. Filter the signal with a bandpass sinc filter.
    2. Smooth the signal by convolving it with a Mexican hat wavelet (Ricker wavelet) [1].
    3. Binarize the signal by applying a linear threshold.

    Parameters
    ----------
    s : ndarray
        input signal
    fs : float
        frequency sampling of the signal, used to keep track of temporal information of the signal
    flims : tuple of 2 floats
        lower and upper frequencies (in Hz), e.g. (min_f, max_f)
    tlen : float
        temporal length of the signal searched (in s)
    th : float, optional
        threshold to binarize the output
    display : boolean, optional, default is False
        plot results if set to True, default is False
    save_df : boolean, optional
        save results to csv file
    savefilename : str, optional
        Name of the file to save the table as comma separated values (csv)
    **kwargs
        Used only when ``display`` is True: ``figsize`` (default (12, 6)),
        ``nperseg`` (default 1024), ``noverlap`` (default ``nperseg // 2``)
        and ``db_range`` (default 80).

    Returns
    -------
    rois : pandas DataFrame
        a table with temporal and frequency limits of regions of interest
        (columns min_f, min_t, max_f, max_t). It is empty when no region is
        detected.

    References
    ----------
    .. [1] Pan Du, Warren A. Kibbe, Simon M. Lin, Improved peak detection in mass spectrum by incorporating continuous wavelet transform-based pattern matching, Bioinformatics, Volume 22, Issue 17, 1 September 2006, Pages 2059–2065, `DOI: 10.1093/bioinformatics/btl355 <https://doi.org/10.1093/bioinformatics/btl355>`_

    Examples
    --------
    >>> from maad import sound, rois
    >>> s, fs = sound.load('../data/spinetail.wav')
    >>> rois.find_rois_cwt(s, fs, flims=(4500,8000), tlen=2, th=0, display=True)
        min_f     min_t   max_f     max_t
    0  4500.0   0.74304  8000.0   2.50776
    1  4500.0   5.10839  8000.0   7.33751
    2  4500.0  11.23846  8000.0  13.37469
    3  4500.0  16.16109  8000.0  18.29732
    """
    # filter signal
    s_filt = sinc(s, flims, fs, atten=80, transition_bw=0.8)

    # energy: window of at most 5% of tlen (power of two); this downsampling
    # greatly improves the speed of the wavelet convolution
    wl = 2 ** np.floor(np.log2(tlen * fs * 0.05))
    t, s_rms = _energy_windowed(s_filt, fs, int(wl))

    # smooth with a Ricker wavelet whose width is tlen/2 expressed in windows
    cwt_width = round(tlen * fs / wl / 2)
    npad = 5  # reflect padding (in wavelet widths) to limit edge effects; 3 seems to work too
    pad_len = int(cwt_width * npad)
    s_rms = np.pad(s_rms, pad_len, 'reflect')
    s_cwt = _ricker_cwt(s_rms, cwt_width)
    s_cwt = s_cwt[pad_len:len(s_cwt) - pad_len]  # remove padding

    # find onset (rising edge) and offset (falling edge) of sound.
    # diff() shifts indices by one window; + t[0] accounts for that delay.
    segments_bin = np.array(s_cwt > th)
    edges = np.diff(segments_bin.astype(int))
    onset = t[np.flatnonzero(edges > 0)] + t[0]
    offset = t[np.flatnonzero(edges < 0)] + t[0]

    # format for output
    tmin, tmax = 0, len(s) / fs
    if onset.size == 0 and offset.size == 0:
        # No detection found
        print('Warning: No detection found')
        df = pd.DataFrame(data=None)
        if save_df:
            df.to_csv(savefilename, sep=',', header=False, index=False)
    else:
        # A region already active at the start of the file has an offset but
        # no onset (and vice versa at the end): close it at the file bounds.
        if onset.size == 0:
            onset = np.array([tmin], dtype=float)
        if offset.size == 0:
            offset = np.array([tmax], dtype=float)
        onset, offset = _corresp_onset_offset(onset, offset, tmin=tmin, tmax=tmax)
        rois_tf = np.transpose([np.repeat(flims[0], repeats=len(onset)),
                                np.round(onset, 5),
                                np.repeat(flims[1], repeats=len(onset)),
                                np.round(offset, 5)])
        cols = ['min_f', 'min_t', 'max_f', 'max_t']
        df = pd.DataFrame(data=rois_tf, columns=cols)
        if save_df:
            df.to_csv(savefilename, sep=',', header=True, index=False)

    # Display
    if display:
        figsize = kwargs.pop('figsize', (12, 6))
        nperseg = kwargs.pop('nperseg', 1024)
        noverlap = kwargs.pop('noverlap', nperseg // 2)
        db_range = kwargs.pop('db_range', 80)
        # plot wavelet response and threshold
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsize)
        ax1.margins(x=0)
        ax1.plot(s_cwt)
        ax1.set_xticks([])
        ax1.set_ylabel('Amplitude')
        ax1.grid(True)
        ax1.hlines(th, 0, len(s_cwt), linestyles='dashed', colors='r')
        # plot spectrogram with detected regions
        Sxx, tn, fn, ext = sound.spectrogram(s, fs, nperseg=nperseg, noverlap=noverlap, mode='psd')
        util.plot_spectrogram(Sxx, ext, db_range, ax=ax2, colorbar=False)
        ax2.set_ylabel('Frequency (Hz)')
        ax2.set_xlabel('Time (s)')
        if not df.empty:
            for _, row in df.iterrows():
                xy = (row.min_t, row.min_f)
                width = row.max_t - row.min_t
                height = row.max_f - row.min_f
                rect = patches.Rectangle(xy, width, height, lw=1,
                                         edgecolor='yellow', facecolor='none')
                ax2.add_patch(rect)
        plt.show()
    return df
if __name__ == "__main__":
    # Executed only as a script: run the docstring examples as doctests.
    from doctest import testmod
    testmod()