如何保留指定的维度之后的所有维度,而不显式列出它们?

Eri*_*ric 16 matlab numpy

或者等效地,"在Matlab中NumPy的省略号索引的等价物"

说我有一些高维数组:

x = zeros(3, 4, 5, 6);
Run Code Online (Sandbox Code Playgroud)

我想编写一个采用大小数组的函数,(3, ...)并进行一些计算.在NumPy中,我可以这样写:

def fun(x):
    return x[0]*x[1] + x[2]
Run Code Online (Sandbox Code Playgroud)

但是,MATLAB中的等效函数不起作用,因为使用一个整数进行索引会将数组展平为1d

function y = fun_bad(x)
    y = x(1)*x(2) + x(3)
Run Code Online (Sandbox Code Playgroud)

我可以用最多三维数组来完成这项工作

function y = fun_ok3d(x)
    y = x(1,:,:)*x(2,:,:) + x(3,:,:)
Run Code Online (Sandbox Code Playgroud)

如果我希望这适用于多达10维阵列,我可以写

function y = fun_ok10d(x)
    y = x(1,:,:,:,:,:,:,:,:,:)*x(2,:,:,:,:,:,:,:,:,:) + x(3,:,:,:,:,:,:,:,:,:)
Run Code Online (Sandbox Code Playgroud)

我怎样才能避免在这里编写愚蠢数量的冒号,只是让它适用于任何维度?是否有一些x(1,...)语法暗示这一点?

NumPy可以在索引表达式中使用...(Ellipsis)文字来表示" :根据需要多次",这将解决此问题.

Lui*_*ndo 18

方法1:使用逗号分隔列表 ':'

我不知道指定的方法

: 根据需要多次

同时保持形状.但你可以指定

:一个任意次数

在运行时定义该次数的位置.使用此方法,您可以保留形状,前提是索引数与维数一致.

这是通过使用完成逗号分隔的列表从一个单元阵列生成的,并且利用这样的事实 ':'可被用作一个索引,而不是::

function y = fun(x)
colons = repmat({':'}, 1, ndims(x)-1); % row cell array containing the string ':'
                                       % repeated the required number of times
y = x(1,colons{:}).*x(2,colons{:}) + x(3,colons{:});
Run Code Online (Sandbox Code Playgroud)

这种方法可以很容易地推广到任何维度的索引,而不仅仅是第一个:

function y = fun(x, dim)
% Input argument dim is the dimension along which to index
colons_pre = repmat({':'}, 1, dim-1);
colons_post = repmat({':'}, 1, ndims(x)-dim);
y = x(colons_pre{:}, 1, colons_post{:}) ...
  .*x(colons_pre{:}, 2, colons_post{:}) ...
  + x(colons_pre{:}, 3, colons_post{:});
Run Code Online (Sandbox Code Playgroud)

方法2:拆分数组

您可以使用num2cell,沿第一个维度拆分数组,然后将操作应用于生成的子数组.当然这会占用更多内存; 正如@Adriaan所说,它更慢.

function y = fun(x)
xs = num2cell(x, [2:ndims(x)]); % x split along the first dimension
y = xs{1}.*xs{2} + xs{3};
Run Code Online (Sandbox Code Playgroud)

或者,为了沿任何维度编制索引:

function y = fun(x, dim)
xs = num2cell(x, [1:dim-1 dim+1:ndims(x)]); % x split along dimension dim
y = xs{1}.*xs{2} + xs{3};
Run Code Online (Sandbox Code Playgroud)


Adr*_*aan 15

当使用单个冒号时,MATLAB会平展所有尾随尺寸,因此您可以使用它来从N- D阵列到达2D阵列,您可以在计算后将其reshape返回到原始N维.

沿着第一个维度

如果你想使用第一个维度,你可以使用一个相对简单和短的代码:

function y = MyMultiDimensional(x)
    x_size = size(x); % Get input size
    yflat = x(1,:) .* x(2,:) + x(3,:); % Calculate "flattened" 2D function
    y = reshape(yflat, [1 x_size(2:end)]); % Reshape output back to original size
end
Run Code Online (Sandbox Code Playgroud)

沿着任意维度,现在以ND置换为特色.

当你希望你的函数沿着总共N个中的第n个维度行动时,你可以将该维度放在前面:permute

function y = MyMultiDimensional(x,n)
    x_size = size(x); % Get input size

    Order = 1:numel(x_size);
    Order(n)=[]; % Remove n-th dimension
    Order2 = [n, Order]; % Prepend n-th dimension

    xPermuted = permute(x,Order2); % permute the n-th dimension to the front
    yTmp = xPermuted (1,:) .* xPermuted (2,:) + xPermuted (3,:); % Calculate "flattened" 2D function
    y = reshape(yTmp, x_size(Order)); % Reshape output back to original size
end
Run Code Online (Sandbox Code Playgroud)

我计算了Luis'和我的方法的两种方法的结果:

function timeMultiDim()

x = rand(1e1,1e1,1e1,1e1,1e1,1e1,1e1,1e1);

    function y = Luis1(x)
        colons = repmat({':'}, 1, ndims(x)-1); % row cell array containing the string ':'
        % repeated the required number of times
        y = x(1,colons{:}).*x(2,colons{:}) + x(3,colons{:});

    end

    function y = Luis2(x)
        xs = num2cell(x, [2:ndims(x)]); % x split along the first dimension
        y = xs{1}.*xs{2} + xs{3};
    end

    function y = Adriaan(x)
        x_size = size(x); % Get input size
        yflat = x(1,:) .* x(2,:) + x(3,:); % Calculate "flattened" 2D function
        y = reshape(yflat, [1 x_size(2:end)]); % Reshape output back to original size
    end

n=1;
    function y = Adriaan2(x,n)
        x_size = size(x); % Get input size

        Order = 1:numel(x_size);
        Order(n)=[]; % Remove n-th dimension
        Order2 = [n, Order]; % Prepend n-th dimension

        xPermuted = permute(x,Order2); % permute the n-th dimension to the front
        yTmp = xPermuted (1,:) .* xPermuted (2,:) + xPermuted (3,:); % Calculate "flattened" 2D function
        y = reshape(yTmp, x_size(Order)); % Reshape output back to original size

    end

t1 = timeit(@() Luis1(x));
t2 = timeit(@() Luis2(x));
t3 = timeit(@() Adriaan(x));
t4 = timeit(@() Adriaan2(x,n));

format long g;
fprintf('Luis 1: %f seconds\n', t1);
fprintf('Luis 2: %f seconds\n', t2);
fprintf('Adriaan 1: %f seconds\n', t3);
fprintf('Adriaan 2: %f seconds\n', t4);

end

Luis 1: 0.698139 seconds
Luis 2: 4.082378 seconds
Adriaan 1: 0.696034 seconds
Adriaan 2: 0.691597 seconds
Run Code Online (Sandbox Code Playgroud)

所以,去一个单元格是坏的,它需要的时间超过5倍,reshape并且':'几乎没有分开,所以这要归结为偏好.