반응형
1. Tuple 여러개를 트리 구조로 보고 싶어졌다
AttributeError: 'tuple' object has no attribute 'shape'
tuple 구조일 때 shape을 보기 힘들어서 여러 dimension인 자료구조에 대해 대응하기 힘들었다
한번 len으로 알더라도 중첩되면 계속 에러 나오는게 신경쓰였다.
2. 코드
import torch
def inspect_shapes(structure, prefix=''):
"""
Recursively inspect and print the shapes and lengths of tensors,
tuples, or lists within a given structure.
Args:
structure: The data structure (tensor, tuple, or list) to inspect.
prefix: The prefix for print statements to indicate nesting level.
"""
if isinstance(structure, torch.Tensor):
# If it's a tensor, print its shape
print(f"{prefix}Tensor with shape: {structure.shape}")
elif isinstance(structure, (list, tuple)):
# If it's a list or tuple, print its length
print(f"{prefix}{type(structure).__name__} with length: {len(structure)}")
# Recursively inspect each element
for i, item in enumerate(structure):
inspect_shapes(item, prefix=prefix + f' [item {i}] ')
else:
# Unexpected type
print(f"{prefix}Unexpected type: {type(structure)}")
# Example usage within a hook or standalone
example_output = (
torch.randn(2, 128, 64, 64), # A tensor
[torch.randn(2, 64, 32, 32), torch.randn(2, 32, 16, 16)], # A list of tensors
(torch.randn(2, 16, 8, 8),) # A tuple containing a tensor
)
# Inspect the structure to check shapes and lengths
inspect_shapes(example_output)
여러 구조로 tuple, list가 섞인 경우 어떻게 생겼는지 트리 구조로 쓰도록 코드 구해서 잘 사용했다.
반응형