Есть ли какая-либо функция pytorch, которая может объединить конкретные непрерывные размеры тензора в один?

1

Позвольте мне вызвать функцию, которую я ищу " magic_combine ", которая может объединить непрерывные измерения тензора, которые я им даю. Для более конкретного, я хочу, чтобы он делал следующее:

a = torch.zeros(1, 2, 3, 4, 5, 6)  
b = a.magic_combine(2, 5) # combine dimension 2, 3, 4 
print(b.size()) # should be (1, 2, 60, 6)

Я знаю, что torch.view() может делать аналогичную вещь. Но мне просто интересно, есть ли более элегантный способ достижения цели?

Теги:
deep-learning
tensor
pytorch

1 ответ

1
Лучший ответ

Я не уверен, что вы имеете в виду "более элегантный способ", но Tensor.view() имеет преимущество не перераспределять данные для представления (исходный тензор и представление имеют одни и те же данные), что делает эту операцию довольно легкий вес.

Как уже упоминалось в @UmangGupta, однако довольно просто перенести эту функцию для достижения того, что вы хотите, например:

import torch

def magic_combine(x, dim_begin, dim_end):
    combined_shape = list(x.shape[:dim_begin]) + [-1] + list(x.shape[dim_end:])
    return x.view(combined_shape)

a = torch.zeros(1, 2, 3, 4, 5, 6)
b = magic_combine(a, 2, 5) # combine dimension 2, 3, 4
print(b.size())
# torch.Size([1, 2, 60, 6])

Ещё вопросы

Сообщество Overcoder
Наверх
Меню