Speeding up numerical gradient of tensor with FFTs
I have a gradient problem involving the following function: add a phase to each column of a matrix, compute the FFT of the result (reshaped into a 4-D array), and collect the transformed data into a new array.
My brute-force numerical gradient perturbs each phase in turn by a small amount epsilon and recomputes the loss, which compares the result with a known matrix (support). currX is the current guess of the phases, Ts2pHH is the matrix to whose columns these phases are added, and the loss is computed by summing the intensity (squared magnitude of the FFT) over the first two dimensions, weighting the result by support, and adding it all up.
I want to know if this can be done more efficiently, because for my matrix size (~1000x1000) this takes around 2 minutes, which is very slow.
My code is shown below:
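% Forward-difference gradient; assumes loss (the unperturbed objective) and
% gradient (e.g. zeros(size(currX))) are defined before the loop
for k = 1:length(currX)
    currX_perturbed = currX;
    currX_perturbed(k) = currX_perturbed(k) + epsilon;   % perturb one phase
    phases_perturbed = exp(1i * [0, currX_perturbed]);   % first column's phase fixed at 0
    Tcorr_perturbed = Ts2pHH .* phases_perturbed;        % apply the phases column-wise
    % 4-D FFT of the reshaped matrix, then shift all four dimensions
    TcorrFFT_perturbed = fftshift(fft(fft(fft(fft(reshape(Tcorr_perturbed, [Npx, Npx, Nin, Nin]), [], 3), [], 4), [], 1), [], 2));
    % Intensity summed over the two Npx dimensions -> Nin-by-Nin
    inputFreq_perturbed = squeeze(sum(sum(abs(TcorrFFT_perturbed).^2, 1), 2));
    gradient(k) = gradient(k) + (-sum(inputFreq_perturbed .* support, 'all') - loss) / epsilon;
end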
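One observation about the loop above: the two FFTs along dimensions 1 and 2 can probably be dropped. By Parseval's theorem, summing |fft(x)|^2 along a dimension of length N gives N times the sum of |x|^2 along it, so inputFreq only changes by the constant factor Npx^2, and an fftshift along dimensions 1 and 2 does not affect a sum over those dimensions. The check below is a minimal sketch using small random arrays in place of the real data (the sizes and the variable names Ffull, Fin, freqFull, freqFast are illustrative assumptions):

```matlab
% Small random stand-ins for the question's data (assumptions)
Npx = 6; Nin = 4;
Ts2pHH = randn(Npx^2, Nin^2) + 1i*randn(Npx^2, Nin^2);
currX = 2*pi*rand(1, Nin^2 - 1);
Tcorr = reshape(Ts2pHH .* exp(1i*[0, currX]), [Npx, Npx, Nin, Nin]);

% As in the question: FFT along all four dims, shift all four, sum |.|^2 over dims 1 and 2
Ffull = fftshift(fft(fft(fft(fft(Tcorr, [], 3), [], 4), [], 1), [], 2));
freqFull = squeeze(sum(sum(abs(Ffull).^2, 1), 2));

% Cheaper: FFT and shift only along dims 3 and 4, then scale by Npx^2 (Parseval)
Fin = fftshift(fftshift(fft(fft(Tcorr, [], 3), [], 4), 3), 4);
freqFast = Npx^2 * squeeze(sum(sum(abs(Fin).^2, 1), 2));

relErr = max(abs(freqFull(:) - freqFast(:))) / max(abs(freqFull(:)));
fprintf('max relative difference: %g\n', relErr);
```

When Npx and Nin are of similar size, this removes roughly half of the FFT work per loss evaluation; it does not reduce the number of evaluations in the loop.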
Answers (1)
Sahas
on 19 Jul 2024
As I understand it, you would like to optimize this code so that it runs efficiently on larger inputs.
Since I do not know the full algorithm, it is hard to suggest an algorithmic optimization.
However, general MATLAB performance techniques such as “vectorization”, “parallelization”, and “pre-allocation” can reduce execution time, alone or in combination. More information about these methods can be found in the links below:
https://www.mathworks.com/help/coder/ug/optimize-generated-code.html -- Optimization strategies for code generated with MATLAB Coder
https://www.mathworks.com/help/matlab/matlab_prog/techniques-for-improving-performance.html -- Programming practices for better code performance
https://www.mathworks.com/help/matlab/matlab_prog/vectorization.html -- Basics of “vectorization” in MATLAB
https://www.mathworks.com/help/matlab/matlab_prog/preallocating-arrays.html -- Basics of pre-allocating memory in MATLAB
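The “parallelization” technique mentioned above fits the question's loop well, because each perturbation is computed independently of the others. The sketch below assumes loss is the unperturbed value of the same objective (as the question's loop implies) and uses small random stand-in data; the helper names fft4, inputFreq, objective, and grad are illustrative. Without Parallel Computing Toolbox, MATLAB runs a parfor loop serially, so the result is the same either way:

```matlab
% Small random stand-ins for the question's variables (assumptions)
Npx = 4; Nin = 3; epsilon = 1e-6;
Ts2pHH = randn(Npx^2, Nin^2) + 1i*randn(Npx^2, Nin^2);
currX = 2*pi*rand(1, Nin^2 - 1);
support = rand(Nin, Nin);

% Objective from the question's loop, written as anonymous functions
fft4 = @(X) fftshift(fft(fft(fft(fft(X, [], 3), [], 4), [], 1), [], 2));
inputFreq = @(T) squeeze(sum(sum(abs(fft4(reshape(T, [Npx, Npx, Nin, Nin]))).^2, 1), 2));
objective = @(x) -sum(sum(inputFreq(Ts2pHH .* exp(1i*[0, x])) .* support));

loss = objective(currX);              % unperturbed value
grad = zeros(size(currX));            % pre-allocated output (sliced in parfor)
parfor k = 1:numel(currX)
    xk = currX;                       % per-iteration copy of the phases
    xk(k) = xk(k) + epsilon;
    grad(k) = (objective(xk) - loss) / epsilon;   % forward difference
end
disp(grad);
```

Each worker still evaluates the full loss, so at best this divides the wall-clock time by the number of workers; the total FFT work is unchanged.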
The snippet below shows the “pre-allocation” technique for this loop. The only array that is indexed and filled inside the loop is the gradient, so that is the one to allocate up front. The temporaries (currX_perturbed, Tcorr_perturbed, TcorrFFT_perturbed, inputFreq_perturbed) are reassigned in full on every iteration, so pre-allocating them would only use extra memory.
% Pre-allocation Technique
gradient = zeros(size(currX));   % allocate the output once, before the loop
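Below is a sample implementation of the “vectorization” method, in which all perturbed copies of the matrix are transformed in a single batched FFT call instead of one call per loop iteration. Please note that this trades memory for speed: vectorizing over all perturbations means holding one perturbed copy of Ts2pHH per phase, so for a 1000-by-1000 Ts2pHH with about 1000 phases that is about 10^9 complex doubles (roughly 16 GB), and it may run out of memory for inputs of that size. Processing the perturbations in smaller batches limits the peak memory.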
% Vectorization Method (save as compute_gradient.m)
function [grad, elapsedTime] = compute_gradient(currX, Ts2pHH, support, epsilon, Npx, Nin)
% Forward-difference gradient of the objective used in the question,
%   f(x) = -sum(inputFreq(x) .* support, 'all'),
% with all numel(currX) perturbations evaluated in one batched FFT.
% Ts2pHH is (Npx^2)-by-(Nin^2); currX is 1-by-(Nin^2 - 1) because the
% phase of the first column is fixed at 0. The output is named grad to
% avoid shadowing MATLAB's gradient function.
tic;
nX = numel(currX);
theta = [0, currX(:).'];                                   % 1-by-Nin^2 phases
% Unperturbed objective (a batch of size 1)
loss = batchObjective(Ts2pHH .* exp(1i * theta), support, Npx, Nin);
% Column k of thetaBatch is theta with element k+1 perturbed by epsilon
thetaBatch = theta.' + epsilon * [zeros(1, nX); eye(nX)];  % Nin^2-by-nX
% Page k of Tbatch is Ts2pHH with the phases in column k of thetaBatch applied
Tbatch = Ts2pHH .* reshape(exp(1i * thetaBatch), [1, Nin^2, nX]);
% Evaluate all perturbed objectives at once, then take the forward difference
grad = (batchObjective(Tbatch, support, Npx, Nin) - loss) / epsilon;
elapsedTime = toc;
fprintf('Elapsed time: %.2f seconds\n', elapsedTime);
end

function f = batchObjective(T, support, Npx, Nin)
% T is (Npx^2)-by-(Nin^2)-by-B; f is a 1-by-B row of objective values
B = size(T, 3);
X = reshape(T, [Npx, Npx, Nin, Nin, B]);
% FFT along the four data dimensions only (not along the batch dimension)
F = fft(fft(fft(fft(X, [], 1), [], 2), [], 3), [], 4);
% Shift dims 3 and 4; a shift along dims 1 and 2 does not change their sum
F = fftshift(fftshift(F, 3), 4);
inputFreq = reshape(sum(sum(abs(F).^2, 1), 2), [Nin, Nin, B]);
f = reshape(-sum(sum(inputFreq .* support, 1), 2), 1, B);
end
% TESTBENCH (separate script; small sizes so the batched arrays fit in memory)
Npx = 8;
Nin = 4;
Ts2pHH = randn(Npx^2, Nin^2) + 1i * randn(Npx^2, Nin^2);  % example (Npx^2)-by-(Nin^2) matrix
currX = 2 * pi * rand(1, Nin^2 - 1);                      % example current phases
support = rand(Nin, Nin);                                 % example support matrix
epsilon = 1e-6;
[grad, elapsedTime] = compute_gradient(currX, Ts2pHH, support, epsilon, Npx, Nin);
disp(grad);
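A different technique, not part of the suggestions above, removes the per-phase FFTs altogether. Because the loss is a weighted sum of |FFT|^2, its gradient with respect to every phase can be obtained with one forward and one adjoint (inverse) FFT, i.e. reverse-mode differentiation of the same objective. A minimal sketch, assuming the objective is exactly -sum(inputFreq .* support, 'all') as in the question's loop and that the first column's phase is fixed at 0; the data are small random stand-ins and the names w, R, gAll, gradAnalytic, and gradFD are illustrative. It checks itself against the question's forward-difference loop:

```matlab
% Small random stand-ins for the question's data (assumptions)
Npx = 4; Nin = 3; epsilon = 1e-6;
Ts2pHH = randn(Npx^2, Nin^2) + 1i*randn(Npx^2, Nin^2);  % (Npx^2)-by-(Nin^2)
currX = 2*pi*rand(1, Nin^2 - 1);                         % first column's phase fixed at 0
support = rand(Nin, Nin);
N = Npx^2 * Nin^2;                                       % total number of FFT points
w = reshape(support, [1, 1, Nin, Nin]);                  % weights broadcast over dims 1 and 2
total = @(A) sum(A(:));

% Forward pass: same objective as the question's loop
% (fftn of a 4-D array equals the nested fft calls along dims 1 to 4)
Tcorr = Ts2pHH .* exp(1i*[0, currX]);
F = fftshift(fftn(reshape(Tcorr, [Npx, Npx, Nin, Nin])));
loss = -total(w .* abs(F).^2);

% Adjoint pass: the adjoint of fftn is N*ifftn, the adjoint of fftshift is ifftshift
R = N * ifftn(ifftshift(w .* F));
% d(loss)/d(theta_j) = 2*Im( sum over column j of conj(R).*Tcorr )
gAll = 2 * imag(sum(conj(reshape(R, Npx^2, Nin^2)) .* Tcorr, 1));
gradAnalytic = gAll(2:end);                              % drop the fixed first phase

% Check against the question's forward-difference loop
gradFD = zeros(size(currX));
for k = 1:numel(currX)
    xk = currX;
    xk(k) = xk(k) + epsilon;
    Fk = fftshift(fftn(reshape(Ts2pHH .* exp(1i*[0, xk]), [Npx, Npx, Nin, Nin])));
    gradFD(k) = (-total(w .* abs(Fk).^2) - loss) / epsilon;
end
relErr = max(abs(gradAnalytic - gradFD)) / max(abs(gradAnalytic));
fprintf('max relative difference vs finite differences: %g\n', relErr);
```

For a roughly 1000-by-1000 Ts2pHH, this would replace about a thousand 4-D FFTs per gradient with two, so it should be much faster than the loop; the formula is only valid if the loss matches the question's code exactly.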
Hope this is beneficial!