scatteringTransform
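Syntax
[outCFS,outMETA] = scatteringTransform(jtfn,x)
[outCFS,outMETA] = scatteringTransform(___,Name=Value)
Description
[outCFS,outMETA] = scatteringTransform(jtfn,x) returns the joint time-frequency scattering (JTFS) transform outCFS of the input data x using the JTFS network jtfn. It also returns the transform metadata outMETA.
[outCFS,outMETA] = scatteringTransform(___,Name=Value) specifies options using one or more name-value arguments in addition to the input arguments in the previous syntax. For example, to average along the time dimension for all JTFS coefficients, set TimeAverage to "global".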
Examples
Joint Time-Frequency Scattering Transform of Signal
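Create a single-precision random signal with three channels, 1024 samples, and a batch size of 5. Save the signal as a dlarray in "CTB" format. A formatted dlarray stores its dimensions in a standard label order, so the resulting array has the format "CBT".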
nchan = 3;
nsam = 1024;
nbatch = 5;
sig = single(randn([nchan nsam nbatch]));
x = dlarray(sig,"CTB");
Create a JTFS network appropriate for the signal. Set the filter data type of the network to "single".
jtfn = timeFrequencyScattering(SignalLength=nsam, ...
    FilterDataType="single");
Use the scatteringTransform function to obtain the JTFS transform of the signal. Also obtain the transform metadata.
[outCFS,outMETA] = scatteringTransform(jtfn,x);
Inspect the JTFS coefficient arrays. The format of each coefficient array is path-by-frequency-by-time-by-channel-by-batch.
outCFS
outCFS = 
  dictionary (string ⟼ cell) with 5 entries:

    "S1FreqLowpass"       ⟼ {5-D dlarray}
    "S1SpinUpFreqLowpass" ⟼ {5-D dlarray}
    "SpinUp"              ⟼ {5-D dlarray}
    "SpinDown"            ⟼ {5-D dlarray}
    "U2JointLowpass"      ⟼ {5-D dlarray}
If the input signal is a formatted or unformatted dlarray, every dictionary value is an unformatted dlarray. Choose any dictionary value. Confirm that the value is an unformatted dlarray and that its underlying data type is single precision.
key = "S1SpinUpFreqLowpass";
val = outCFS{key};
dims(val)
ans = 0×0 empty char array
underlyingType(val)
ans = 'single'
Inspect the SpinUp coefficients array and its metadata. The metadata in the ith table row describes the coefficients outCFS{"SpinUp"}(i,:,:,:,:).
cfs = outCFS{"SpinUp"}; [numPath,numFrequency,numTime,numChannel,numBatch] = size(cfs) %#ok<*ASGLU>
numPath = 35
numFrequency = 6
numTime = 8
numChannel = 3
numBatch = 5
outMETA{3}
ans=35×5 table
      type       log2dsfactor     path     spin    log2stride
    ________    ____________    ______    ____    __________

    "SpinUp"       0    1       1    3      1       3    7
    "SpinUp"       0    1       2    3      1       3    7
    "SpinUp"       1    1       3    3      1       3    7
    "SpinUp"       2    1       4    3      1       3    7
    "SpinUp"       2    1       5    3      1       3    7
    "SpinUp"       0    2       1    4      1       3    7
    "SpinUp"       0    2       2    4      1       3    7
    "SpinUp"       1    2       3    4      1       3    7
    "SpinUp"       2    2       4    4      1       3    7
    "SpinUp"       2    2       5    4      1       3    7
    "SpinUp"       0    3       1    5      1       3    7
    "SpinUp"       0    3       2    5      1       3    7
    "SpinUp"       1    3       3    5      1       3    7
    "SpinUp"       2    3       4    5      1       3    7
    "SpinUp"       2    3       5    5      1       3    7
    "SpinUp"       0    4       1    6      1       3    7
       ⋮
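To pull out the coefficients for one path, you can search the metadata table and use the matching row index as the first subscript into the coefficient array. The sketch below has two assumptions. It recreates the signal as a numeric time-by-channel-by-batch array rather than a dlarray, only to keep the example short. It also relies on the third metadata table describing "SpinUp", as in this example, and checks that pairing through the type variable. The variable names meta, row, and pathCfs are illustrative.

```matlab
% Recreate a 3-channel, 5-observation signal as numeric
% time-by-channel-by-batch input (assumption: instead of the dlarray above).
nsam = 1024;
sig = single(randn([nsam 3 5]));
jtfn = timeFrequencyScattering(SignalLength=nsam,FilterDataType="single");
[outCFS,outMETA] = scatteringTransform(jtfn,sig);

cfs = outCFS{"SpinUp"};
meta = outMETA{3};                      % third table describes "SpinUp" here
assert(all(meta.type == "SpinUp"))      % confirm the pairing before using it

% Find the path that pairs frequency wavelet 2 with second-order time wavelet 4.
row = find(meta.path(:,1) == 2 & meta.path(:,2) == 4);

% Row i of the metadata table describes cfs(i,:,:,:,:).
pathCfs = cfs(row,:,:,:,:);
size(pathCfs)
```

The path variable is matched on both columns because neither column alone identifies a row. Frequency wavelet 2 appears once for every second-order time wavelet.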
Inspect the U2JointLowpass coefficients array and its metadata. Because the scatteringTransform function did not use spin-up or spin-down wavelets to compute these coefficients, the spin value for all coefficient paths is 0.
cfs = outCFS{"U2JointLowpass"};
[numPath,numFrequency,numTime,numChannel,numBatch] = size(cfs)
numPath = 7
numFrequency = 6
numTime = 8
numChannel = 3
numBatch = 5
outMETA{5}
ans=7×5 table
           type           log2dsfactor      path      spin    log2stride
    ________________    ____________    ________    ____    __________

    "U2JointLowpass"         1           -1    3      0       3    7
    "U2JointLowpass"         2           -1    4      0       3    7
    "U2JointLowpass"         3           -1    5      0       3    7
    "U2JointLowpass"         4           -1    6      0       3    7
    "U2JointLowpass"         5           -1    7      0       3    7
    "U2JointLowpass"         6           -1    8      0       3    7
    "U2JointLowpass"         6           -1    9      0       3    7
Change Oversampling Factors in Scattering Transform
Load the ECG signal data. The data has 2048 samples. Create a JTFS network appropriate for the signal.
load wecg
len = length(wecg);
jtfn = timeFrequencyScattering(SignalLength=len);
Obtain the JTFS transform of the signal using the default function parameters. Also obtain the transform metadata. By default, scatteringTransform critically downsamples the coefficients in time and frequency. Because the data contains one batch of a single-channel signal, the format of the coefficient dictionary values is path-by-frequency-by-time.
[outCFS,outMETA] = scatteringTransform(jtfn,wecg);
outCFS
outCFS = 
  dictionary (string ⟼ cell) with 5 entries:

    "S1FreqLowpass"       ⟼ {1×7×8 double}
    "S1SpinUpFreqLowpass" ⟼ {5×7×8 double}
    "SpinUp"              ⟼ {40×7×8 double}
    "SpinDown"            ⟼ {40×7×8 double}
    "U2JointLowpass"      ⟼ {8×7×8 double}
Obtain the JTFS transform with TimeOversamplingFactor set to 1. Because you specify a time oversampling factor of 1, the size of the time dimension in the coefficient arrays increases by a factor of 2. The sizes of the path and frequency dimensions remain the same.
[outCFS_T1,outMETA_T1] = scatteringTransform(jtfn,wecg, ...
TimeOversamplingFactor=1);
outCFS_T1
outCFS_T1 = 
  dictionary (string ⟼ cell) with 5 entries:

    "S1FreqLowpass"       ⟼ {1×7×16 double}
    "S1SpinUpFreqLowpass" ⟼ {5×7×16 double}
    "SpinUp"              ⟼ {40×7×16 double}
    "SpinDown"            ⟼ {40×7×16 double}
    "U2JointLowpass"      ⟼ {8×7×16 double}
Compare the first five rows of the "SpinDown" metadata tables from the two transforms. The second column of the log2dsfactor and log2stride table variables indicates the downsampling factor in time. Because the second transform oversamples in time by a factor of 1 on the base-2 logarithmic scale, those values are smaller by 1 in its metadata.
outMETA{4}(1:5,:)
ans=5×5 table
       type       log2dsfactor      path      spin    log2stride
    __________    ____________    ________    ____    __________

    "SpinDown"       0    1        6    3      -1       3    8
    "SpinDown"       0    1        7    3      -1       3    8
    "SpinDown"       1    1        8    3      -1       3    8
    "SpinDown"       2    1        9    3      -1       3    8
    "SpinDown"       2    1       10    3      -1       3    8
outMETA_T1{4}(1:5,:)
ans=5×5 table
       type       log2dsfactor      path      spin    log2stride
    __________    ____________    ________    ____    __________

    "SpinDown"       0    0        6    3      -1       3    7
    "SpinDown"       0    0        7    3      -1       3    7
    "SpinDown"       1    0        8    3      -1       3    7
    "SpinDown"       2    0        9    3      -1       3    7
    "SpinDown"       2    0       10    3      -1       3    7
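The metadata values are base-2 logarithms, so raising 2 to them gives the linear strides. As a rough consistency check against the array sizes shown earlier, the sketch below converts the first-row log2stride values from the two tables. It then divides the 2048-sample signal length by the time stride. Treating the size of the time dimension as signal length divided by time stride is an assumption. It matches this power-of-two example but is not a documented rule.

```matlab
% log2stride values copied from the first "SpinDown" row of each table above,
% as [frequency time] on a base-2 logarithmic scale.
log2strideDefault = [3 8];   % TimeOversamplingFactor = 0 (default)
log2strideT1      = [3 7];   % TimeOversamplingFactor = 1

% Convert to linear strides, in samples.
strideDefault = 2.^log2strideDefault;   % [8 256]
strideT1      = 2.^log2strideT1;        % [8 128]

% Assumption: for this power-of-two length, the number of time samples is
% the signal length divided by the time stride.
sigLen = 2048;                                % length(wecg)
numTimeDefault = sigLen / strideDefault(2);   % 8, matching {40×7×8 double}
numTimeT1      = sigLen / strideT1(2);        % 16, matching {40×7×16 double}
```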
Now obtain the JTFS transform of the signal with FrequencyOversamplingFactor set to 1. Compared with the first transform, the size of the frequency dimension in the coefficient arrays is twice as large. The sizes of the path and time dimensions are the same.
[outCFS_F1,outMETA_F1] = scatteringTransform(jtfn,wecg, ...
FrequencyOversamplingFactor=1);
outCFS_F1
outCFS_F1 = 
  dictionary (string ⟼ cell) with 5 entries:

    "S1FreqLowpass"       ⟼ {1×14×8 double}
    "S1SpinUpFreqLowpass" ⟼ {5×14×8 double}
    "SpinUp"              ⟼ {40×14×8 double}
    "SpinDown"            ⟼ {40×14×8 double}
    "U2JointLowpass"      ⟼ {8×14×8 double}
Time- and Frequency-Averaging in Scattering Transform
Create a single-precision random signal with three channels, 1000 samples, and a batch size of 5. For 3-D numeric input, scatteringTransform assumes the dimensions are time-by-channel-by-batch. Save the signal as a gpuArray.
nsam = 1000;
nchan = 3;
nbatch = 5;
sig = single(randn([nsam nchan nbatch]));
x = gpuArray(sig);
Create a JTFS network appropriate for the signal.
jtfn = timeFrequencyScattering(SignalLength=nsam, ...
    FilterDataType="single");
Obtain the JTFS transform of the signal using default settings. By default, the scatteringTransform function uses lowpass filtering to obtain the coefficients.
outCFS = scatteringTransform(jtfn,x)
outCFS = 
  dictionary (string ⟼ cell) with 5 entries:

    "S1FreqLowpass"       ⟼ {5-D gpuArray}
    "S1SpinUpFreqLowpass" ⟼ {5-D gpuArray}
    "SpinUp"              ⟼ {5-D gpuArray}
    "SpinDown"            ⟼ {5-D gpuArray}
    "U2JointLowpass"      ⟼ {5-D gpuArray}
Obtain the dimensions of the coefficient arrays. The arrays are in path-by-frequency-by-time-by-channel-by-batch format.
dictionaryValues = values(outCFS); cellfun(@size,dictionaryValues,UniformOutput=false)
ans=5×1 cell array
{[ 1 6 7 3 5]}
{[ 5 6 7 3 5]}
{[35 6 7 3 5]}
{[35 6 7 3 5]}
{[ 7 6 7 3 5]}
Obtain the JTFS transform of the signal with TimeAverage set to "global". Instead of using lowpass filtering, the function takes the mean along the time dimension for all the coefficients. The size of the time dimension in the coefficient arrays is 1.
outCFS_T = scatteringTransform(jtfn,x, ...
    TimeAverage="global");
dictionaryValues_T = values(outCFS_T);
cellfun(@size,dictionaryValues_T,UniformOutput=false)
ans=5×1 cell array
{[ 1 6 1 3 5]}
{[ 5 6 1 3 5]}
{[35 6 1 3 5]}
{[35 6 1 3 5]}
{[ 7 6 1 3 5]}
Obtain the JTFS transform of the signal with FrequencyAverage set to "global". Instead of using lowpass filtering, the function takes the mean along the frequency dimension for all the coefficients. The size of the frequency dimension in the coefficient arrays is 1.
outCFS_F = scatteringTransform(jtfn,x, ...
    FrequencyAverage="global");
dictionaryValues_F = values(outCFS_F);
cellfun(@size,dictionaryValues_F,UniformOutput=false)
ans=5×1 cell array
{[ 1 1 7 3 5]}
{[ 5 1 7 3 5]}
{[35 1 7 3 5]}
{[35 1 7 3 5]}
{[ 7 1 7 3 5]}
Obtain the JTFS transform of the signal with TimeAverage and FrequencyAverage both set to "global".
outCFS_TF = scatteringTransform(jtfn,x, ...
    TimeAverage="global", ...
    FrequencyAverage="global");
dictionaryValues_TF = values(outCFS_TF);
cellfun(@size,dictionaryValues_TF,UniformOutput=false)
ans=5×1 cell array
{[ 1 1 1 3 5]}
{[ 5 1 1 3 5]}
{[35 1 1 3 5]}
{[35 1 1 3 5]}
{[ 7 1 1 3 5]}
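The effect of "global" averaging on array shape can be reproduced with mean applied to a stand-in array of the same size. This sketch is hypothetical and only shows which dimension collapses in each case. The stand-in holds random numbers, so its values are unrelated to any scatteringTransform output.

```matlab
% Hypothetical stand-in with the shape of the locally averaged arrays above:
% path-by-frequency-by-time-by-channel-by-batch.
cfsLocal = randn([35 6 7 3 5],"single");

cfsT  = mean(cfsLocal,3);        % collapse time:           35x6x1x3x5
cfsF  = mean(cfsLocal,2);        % collapse frequency:      35x1x7x3x5
cfsTF = mean(cfsLocal,[2 3]);    % collapse both (vecdim):  35x1x1x3x5
```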
Confirm the underlying data type of the coefficients is single precision.
dictionaryValues = values(outCFS_TF); cellfun(@underlyingType,dictionaryValues,UniformOutput=false)
ans = 5×1 cell
{'single'}
{'single'}
{'single'}
{'single'}
{'single'}
Gather the "SpinUp" coefficients from the GPU. Compare them with the same coefficients in the JTFS transform of the original random signal, computed on the CPU. Confirm that the coefficients are equal to within single-precision round-off error.
cfs = "SpinUp";
cfsG = gather(outCFS_TF{cfs});
outCFS_TF_ORIG = scatteringTransform(jtfn,sig, ...
    TimeAverage="global", ...
    FrequencyAverage="global");
cfsO = outCFS_TF_ORIG{cfs};
max(abs(cfsG(:)-cfsO(:)))
ans = single
7.4506e-08
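The GPU and CPU results differ only by floating-point round-off, so a tolerance-based comparison is more robust than testing for exact equality. The sketch below checks every coefficient key rather than only "SpinUp". The relative tolerance 1e-5 is an assumed value for single-precision data. The GPU branch runs only when canUseGPU returns true.

```matlab
nsam = 1000;
sig = single(randn([nsam 3 5]));   % time-by-channel-by-batch
jtfn = timeFrequencyScattering(SignalLength=nsam,FilterDataType="single");
cfsCPU = scatteringTransform(jtfn,sig, ...
    TimeAverage="global",FrequencyAverage="global");

relTol = 1e-5;       % assumed tolerance for single-precision round-off
withinTol = true;    % stays true when no GPU is available
if canUseGPU
    cfsGPU = scatteringTransform(jtfn,gpuArray(sig), ...
        TimeAverage="global",FrequencyAverage="global");
    % keys returns a column vector; transpose it so the loop visits each key.
    for key = keys(cfsCPU).'
        a = gather(cfsGPU{key});
        b = cfsCPU{key};
        relErr = max(abs(a(:)-b(:))) / max(abs(b(:)));
        withinTol = withinTol && relErr <= relTol;
    end
end
withinTol
```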
Input Arguments
jtfn — Joint time-frequency scattering network
timeFrequencyScattering object
Joint time-frequency scattering network, specified as a timeFrequencyScattering object.
x — Input data
numeric array | dlarray object
Input data, specified as a formatted or unformatted dlarray (Deep Learning Toolbox) object or a numeric array. If x is a formatted dlarray, it must be in "CBT" format. If x is an unformatted dlarray, it must be compatible with "CBT" format and you must set DataFormat.
If x is 2-D, the scatteringTransform function assumes the first dimension is time and the columns of x are separate channels. If x is 3-D, the dimensions of x are time-by-channel-by-batch.
If x is a vector or unformatted dlarray, the number of samples in x must match the SignalLength property of jtfn.
If x is a numeric or unformatted matrix or a 3-D array, the number of rows in x must match SignalLength.
If x is a formatted dlarray, the length of the time dimension must match SignalLength.
Data Types: single | double
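For numeric input, the layout alone determines how scatteringTransform interprets the dimensions. The sketch below passes a 2-D array (time-by-channel) and a 3-D array (time-by-channel-by-batch) to the same network. It then checks where the channel and batch dimensions land in the output. The signal sizes are arbitrary choices for illustration.

```matlab
nsam = 1024;
jtfn = timeFrequencyScattering(SignalLength=nsam);

% 2-D input: rows are time samples (must equal SignalLength), columns are channels.
x2 = randn(nsam,3);
cfs2 = scatteringTransform(jtfn,x2);

% 3-D input: time-by-channel-by-batch.
x3 = randn(nsam,3,4);
cfs3 = scatteringTransform(jtfn,x3);

% Coefficients are path-by-frequency-by-time-by-channel-by-batch.
size(cfs2{"SpinUp"})
size(cfs3{"SpinUp"})
```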
Name-Value Arguments
Specify optional pairs of arguments as Name1=Value1,...,NameN=ValueN, where Name is the argument name and Value is the corresponding value. Name-value arguments must appear after other arguments, but the order of the pairs does not matter.
Example: outCFS = scatteringTransform(jtfn,x,DataFormat="CBT",FrequencyAverage="global") specifies the format of the unformatted dlarray x as "CBT" and takes the mean along the frequency dimension for all JTFS coefficients.
ExcludeCoefficients — Coefficients to exclude
'' (default) | string vector | cell array of character vectors
Coefficients to exclude from the JTFS transform, specified as a string vector or cell array of character vectors. You can specify these coefficients:
"S1FreqLowpass" — First-order time scattering coefficients filtered with the frequency lowpass filter
"S1SpinUpFreqLowpass" — First-order time scattering coefficients with the spin-up frequency wavelets
"SpinUp" — Second-order time scattering coefficients with spin-up wavelets
"SpinDown" — Second-order time scattering coefficients with spin-down wavelets
"U2JointLowpass" — Second-order time scattering coefficients filtered with joint lowpass filters
Example: outCFS = scatteringTransform(jtfn,x,ExcludeCoefficients=["S1FreqLowpass" "U2JointLowpass"])
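The sketch below runs the example call on a single-channel random signal and lists the returned keys. It does not assume whether excluded coefficients are absent from outCFS or present with empty values. Inspect the output of keys to see how your release represents them. The signal length is an arbitrary choice.

```matlab
nsam = 1024;
jtfn = timeFrequencyScattering(SignalLength=nsam);
x = randn(nsam,1);

% Skip the first-order lowpass and the joint-lowpass coefficients.
outCFS = scatteringTransform(jtfn,x, ...
    ExcludeCoefficients=["S1FreqLowpass" "U2JointLowpass"]);

keys(outCFS)   % inspect which coefficient keys were returned
```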
TimeAverage — Time-averaging option
"local" (default) | "global"
Time-averaging option, specified as one of these:
"local" — scatteringTransform uses the time lowpass filter when obtaining the JTFS coefficients.
"global" — scatteringTransform takes the mean along the time dimension for all JTFS coefficients.
FrequencyAverage — Frequency-averaging option
"local" (default) | "global"
Frequency-averaging option, specified as one of these:
"local" — scatteringTransform uses the frequency lowpass filter when obtaining the JTFS coefficients.
"global" — scatteringTransform takes the mean along the frequency dimension for all JTFS coefficients.
TimeOversamplingFactor — Time oversampling factor
0 (default) | nonnegative integer
Time oversampling factor, specified as a nonnegative integer. The factor specifies how much the coefficients are oversampled in time with respect to the critically downsampled values. The factor is on a base-2 logarithmic scale. For example, a value of 1 doubles the size of the time dimension in the coefficient arrays.
If you increase the oversampling factor, the computational costs and memory requirements of the scattering transform also increase.
Note
The number of paths in the JTFS network does not depend on the time oversampling factor. This is different from waveletScattering, where the value of the OversamplingFactor property of a wavelet scattering network affects the number of paths in the network.
Data Types: single | double
FrequencyOversamplingFactor — Frequency oversampling factor
0 (default) | nonnegative integer
Frequency oversampling factor, specified as a nonnegative integer. The factor specifies how much the coefficients are oversampled in frequency with respect to the critically downsampled values. The factor is on a base-2 logarithmic scale. For example, a value of 1 doubles the size of the frequency dimension in the coefficient arrays.
If you increase the oversampling factor, the computational costs and memory requirements of the scattering transform also increase.
Note
The number of paths in the JTFS network does not depend on the frequency oversampling factor. This is different from waveletScattering, where the value of the OversamplingFactor property of a wavelet scattering network affects the number of paths in the network.
Data Types: single | double
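The growth in memory can be estimated from the critically sampled array sizes. The sketch below starts from the "SpinUp" size reported for the wecg example ([40 7 8], path-by-frequency-by-time, double precision). It assumes that each factor k scales its dimension by 2^k while the path dimension stays fixed. That assumption matches the single-factor results in the wecg example. Setting both factors to 1 at the same time is a hypothetical case not shown on this page.

```matlab
% Critically sampled "SpinUp" size for wecg (path-by-frequency-by-time).
critSize = [40 7 8];

kF = 1;   % FrequencyOversamplingFactor (hypothetical setting)
kT = 1;   % TimeOversamplingFactor (hypothetical setting)

% Assumption: each factor scales its dimension by 2^k; paths are unchanged.
overSize = critSize .* [1 2^kF 2^kT];

bytesPerElement = 8;                     % double precision
bytesCrit = prod(critSize) * bytesPerElement;
bytesOver = prod(overSize) * bytesPerElement;
growth = bytesOver / bytesCrit;

fprintf("SpinUp size %s -> %s, memory x%g\n", ...
    mat2str(critSize), mat2str(overSize), growth)
```

Under this assumption, the memory for each coefficient array grows by 2^(kF+kT), independent of the number of paths.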
DataFormat — Data format
character vector | string scalar
Data format of x, specified as a character vector or string scalar. This name-value argument is valid only if x is an unformatted dlarray. If x is not a dlarray, the function ignores the DataFormat argument.
Each character in this argument must be one of these labels:
"C" — Channel
"B" — Batch observations
"T" — Time
DataFormat can be any permutation of "CBT".
Data Types: char | string
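With an unformatted dlarray, DataFormat tells scatteringTransform which dimension is time. The sketch below stores the data as time-by-channel-by-batch and labels it "TCB", one of the permutations of "CBT". The sizes are arbitrary. The checks follow the documented behavior: an unformatted output in path-by-frequency-by-time-by-channel-by-batch format.

```matlab
nsam = 1024;
jtfn = timeFrequencyScattering(SignalLength=nsam);

% Unformatted dlarray stored as time-by-channel-by-batch (2 channels, 4 observations).
xTCB = dlarray(randn(nsam,2,4));
outCFS = scatteringTransform(jtfn,xTCB,DataFormat="TCB");

v = outCFS{"SpinUp"};
dims(v)        % unformatted output, so no dimension labels
size(v)        % path-by-frequency-by-time-by-channel-by-batch
```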
Output Arguments
outCFS — Joint time-frequency scattering transform
dictionary object
Joint time-frequency scattering transform, returned as a dictionary object with these keys:
"S1FreqLowpass" — First-order time scattering coefficients filtered with the frequency lowpass filter
"S1SpinUpFreqLowpass" — First-order time scattering coefficients with the spin-up frequency wavelets
"SpinUp" — Second-order time scattering coefficients with spin-up wavelets
"SpinDown" — Second-order time scattering coefficients with spin-down wavelets
"U2JointLowpass" — Second-order time scattering coefficients filtered with joint lowpass filters
For more information, see Joint Time-Frequency Scattering Coefficients.
All dictionary values are in path-by-frequency-by-time-by-channel-by-batch format.
If x is a formatted or unformatted dlarray, every dictionary value is an unformatted dlarray.
outMETA — Metadata
cell array
Metadata for each coefficient key in outCFS, returned as a cell array of tables. All tables have these variables:
type — Coefficient key.
path — Two-column variable indicating the coefficient path. The first column is the index of the frequency wavelet, and the second column is the index of the second-order time wavelet. A value of –1 indicates the lowpass filter.
spin — Wavelet spin. A value of 1 indicates a spin-up wavelet, and –1 indicates a spin-down wavelet. A value of 0 indicates that scatteringTransform did not use a spin-up or spin-down wavelet to compute those coefficients.
log2stride — Two-column variable indicating how much scatteringTransform downsamples in frequency (first column) and time (second column) after applying the lowpass filters. Values are on a base-2 logarithmic scale.
log2dsfactor — Downsampling factors in frequency and second-order time. Values are on a base-2 logarithmic scale.
    If type is "SpinUp" or "SpinDown", then log2dsfactor is a two-column variable indicating the downsampling factors in frequency (first column) and second-order time (second column).
    If type is "S1SpinUpFreqLowpass", then log2dsfactor is a single-column variable indicating the downsampling factor in frequency.
    If type is "U2JointLowpass", then log2dsfactor is a single-column variable indicating the downsampling factor in second-order time.
    This table variable is not applicable when type is "S1FreqLowpass".
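Each metadata table records its own key in the type variable. You can therefore index the tables by key instead of by their position in outMETA. The sketch below builds such a lookup with configureDictionary and converts the base-2 logarithmic strides of the "SpinUp" paths to linear factors. The variable names metaByKey and strides are illustrative.

```matlab
nsam = 1024;
jtfn = timeFrequencyScattering(SignalLength=nsam);
[outCFS,outMETA] = scatteringTransform(jtfn,randn(nsam,1));

% Key each metadata table by its own "type" value rather than by position.
metaByKey = configureDictionary("string","cell");
for i = 1:numel(outMETA)
    metaByKey(outMETA{i}.type(1)) = outMETA(i);
end

% Convert base-2 logarithmic strides to linear factors: [frequency time].
spinUpMeta = metaByKey{"SpinUp"};
strides = 2.^spinUpMeta.log2stride;
```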
More About
Joint Time-Frequency Scattering Coefficients
The joint time-frequency scattering (JTFS) transform extracts time-frequency features from a signal that are invariant to shifts and deformations in time and frequency. To compute the JTFS transform, first convolve the signal in time with wavelets and apply a pointwise modulus nonlinearity. Then filter the result along frequency with frequential wavelets [1][2].
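Let:
x denote the signal.
$\psi_{\lambda_1}$ and $\psi_{\lambda_2}$ denote the time wavelets in the first- and second-order filter banks, respectively.
$\psi_{\mu,s}$ denote the frequential wavelets of spin s. If s = 1, these are the spin-up wavelets. If s = –1, these are the spin-down wavelets.
$\phi_T$ and $\phi_F$ denote the time and frequential lowpass filters, respectively.
Convolutions with $\psi_{\lambda_1}$, $\psi_{\lambda_2}$, and $\phi_T$ act along time. Convolutions with $\psi_{\mu,s}$ and $\phi_F$ act along the frequency axis $\lambda_1$ of the first-order coefficients.
Then the JTFS coefficients are defined as:
"S1FreqLowpass" — $\left( |x \ast \psi_{\lambda_1}| \ast \phi_T \right) \ast \phi_F$
"S1SpinUpFreqLowpass" — $\left| \left( |x \ast \psi_{\lambda_1}| \ast \phi_T \right) \ast \psi_{\mu,s} \right| \ast \phi_F$ for s = 1
"SpinUp" — $\left| |x \ast \psi_{\lambda_1}| \ast \psi_{\lambda_2} \ast \psi_{\mu,s} \right| \ast \phi_T \ast \phi_F$ for s = 1
"SpinDown" — $\left| |x \ast \psi_{\lambda_1}| \ast \psi_{\lambda_2} \ast \psi_{\mu,s} \right| \ast \phi_T \ast \phi_F$ for s = –1
"U2JointLowpass" — $\left| |x \ast \psi_{\lambda_1}| \ast \psi_{\lambda_2} \right| \ast \phi_T \ast \phi_F$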
For more information, see Joint Time-Frequency Scattering.
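The joint step behind the spin-up and spin-down coefficients can be illustrated with a toy two-dimensional example. A first-order scalogram is filtered with a wavelet along time and a frequential wavelet along frequency, a modulus is applied, and the result is lowpass filtered in both directions. This sketch is not the toolbox implementation. The stand-in scalogram is random, all filters are hypothetical Gabor filters, no downsampling is performed, and the spin-down filter is assumed to be the complex conjugate of the spin-up filter.

```matlab
rng(0)                                   % reproducible stand-in data
numFreq = 32;
numTime = 256;
U1 = abs(randn(numFreq,numTime));        % stand-in for |x * psi_lambda1|: frequency-by-time

t = (-32:32).';                          % time-filter support, in samples
f = (-8:8).';                            % frequency-filter support, in bins
gaussT = exp(-t.^2/(2*8^2));
gaussF = exp(-f.^2/(2*2^2));

psiT = gaussT .* exp(1i*pi/4*t);         % hypothetical second-order time wavelet
phiT = gaussT / sum(gaussT);             % hypothetical time lowpass filter
phiF = gaussF / sum(gaussF);             % hypothetical frequency lowpass filter

spins = [1 -1];
S = cell(1,2);
for k = 1:2
    % Frequential Gabor filter; s = -1 gives the complex conjugate of s = 1
    % (assumed stand-in for spin-down).
    psiF = gaussF .* exp(1i*spins(k)*pi/2*f);
    % conv2(u,v,A,'same') filters each column of A with u (frequency axis)
    % and then each row of the result with v (time axis).
    Y = abs(conv2(psiF, psiT.', U1, 'same'));   % joint wavelet, then modulus
    S{k} = conv2(phiF, phiT.', Y, 'same');      % lowpass in frequency and time
end
```

The two spins give different outputs even though they share the same time wavelet. This is consistent with the text's distinction between spin-up and spin-down frequential wavelets.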
References
[1] Andén, Joakim, Vincent Lostanlen, and Stéphane Mallat. “Joint Time–Frequency Scattering.” IEEE Transactions on Signal Processing 67, no. 14 (July 15, 2019): 3704–18. https://doi.org/10.1109/TSP.2019.2918992
[2] Lostanlen, Vincent, Christian El-Hajj, Mathias Rossignol, Grégoire Lafay, Joakim Andén, and Mathieu Lagrange. “Time–Frequency Scattering Accurately Models Auditory Similarities between Instrumental Playing Techniques.” EURASIP Journal on Audio, Speech, and Music Processing 2021, no. 1 (December 2021): 3. https://doi.org/10.1186/s13636-020-00187-z
[3] Mallat, Stéphane. “Group Invariant Scattering.” Communications on Pure and Applied Mathematics 65, no. 10 (October 2012): 1331–98. https://doi.org/10.1002/cpa.21413
Extended Capabilities
GPU Arrays
Accelerate code by running on a graphics processing unit (GPU) using Parallel Computing Toolbox™.
The scatteringTransform function fully supports GPU arrays. To run the function on a GPU, specify the input data as a gpuArray (Parallel Computing Toolbox). For more information, see Run MATLAB Functions on a GPU (Parallel Computing Toolbox).
Version History
Introduced in R2024b