# Copyright (c) 2019,20-21-22 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import collections
import functools
import logging
import numpy as np
import torch
import kaolin.ops.random as random
from kaolin.ops.spc.uint8 import uint8_bits_sum

BOOL_DTYPES = [torch.bool]
INT_DTYPES = [torch.uint8, torch.short, torch.int, torch.long]
FLOAT_DTYPES = [torch.half, torch.float, torch.double]
NUM_DTYPES = INT_DTYPES + FLOAT_DTYPES
ALL_DTYPES = NUM_DTYPES + BOOL_DTYPES
ALL_DEVICES = ['cpu', 'cuda']
BOOL_TYPES = [('cuda', torch.bool), ('cpu', torch.bool)]
INT_TYPES = [(device, dtype) for device in ALL_DEVICES for dtype in INT_DTYPES]
FLOAT_TYPES = [('cuda', dtype) for dtype in FLOAT_DTYPES] + \
[('cpu', dtype) for dtype in FLOAT_DTYPES if dtype != torch.half]
CUDA_FLOAT_TYPES = [('cuda', dtype) for dtype in FLOAT_DTYPES]
NUM_TYPES = INT_TYPES + FLOAT_TYPES
ALL_TYPES = NUM_TYPES + BOOL_TYPES


def with_seed(torch_seed=0, numpy_seed=None, random_seed=None):
    """Decorator to fix the seed of a function.

    Args:
        torch_seed (int): The desired seed for the torch module.
        numpy_seed (int): The desired seed for the numpy module. Default: torch_seed value.
        random_seed (int): The desired seed for the random module. Default: torch_seed value.
    """
def decorator(orig_test):
@functools.wraps(orig_test)
        def orig_test_wrapper(*args, **kwargs):
            torch_state, random_state, np_state = random.get_state()
            random.manual_seed(torch_seed, numpy_seed, random_seed)
            try:
                output = orig_test(*args, **kwargs)
            finally:
                # Restore the previous RNG states even if the wrapped function raises.
                random.set_state(torch_state, random_state, np_state)
            return output
return orig_test_wrapper
return decorator
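
# A minimal usage sketch (assumed, not part of the original API docs): with_seed
# pins the torch/numpy/random RNGs for one call, so a test is reproducible.
#
#   @with_seed(torch_seed=42)
#   def test_random_op():
#       return torch.rand(3)  # same values on every invocation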


def check_tensor(tensor, shape=None, dtype=None, device=None, throw=True):
    """Check if a :class:`torch.Tensor` is valid given a set of criteria.

    Args:
        tensor (torch.Tensor): the tensor to be tested.
        shape (list or tuple of int, optional): the expected shape,
            if a dimension is set at ``None`` then it's not verified.
        dtype (torch.dtype, optional): the expected dtype.
        device (torch.device, optional): the expected device.
        throw (bool): if True (default), raise an error if a check fails.

    Return:
        (bool): True if all checks pass.
    """
if shape is not None:
if len(shape) != tensor.ndim:
if throw:
raise ValueError(f"tensor have {tensor.ndim} ndim, should have {len(shape)}")
return False
for i, dim in enumerate(shape):
if dim is not None and tensor.shape[i] != dim:
if throw:
raise ValueError(f"tensor shape is {tensor.shape}, should be {shape}")
return False
if dtype is not None and dtype != tensor.dtype:
if throw:
raise TypeError(f"tensor dtype is {tensor.dtype}, should be {dtype}")
return False
if device is not None and device != tensor.device.type:
if throw:
raise TypeError(f"tensor device is {tensor.device.type}, should be {device}")
return False
return True
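
# A minimal sketch (assumed) of how check_tensor is typically used:
#
#   t = torch.zeros((4, 3))
#   check_tensor(t, shape=(None, 3), dtype=torch.float, device='cpu')  # True
#   check_tensor(t, shape=(4, 2), throw=False)                         # False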


def check_packed_tensor(tensor, total_numel=None, last_dim=None, dtype=None, device=None,
                        throw=True):
    """Check if a :ref:`packed tensor<packed>` is valid given a set of criteria.

    Args:
        tensor (torch.Tensor): the packed tensor to be tested.
        total_numel (int, optional): the expected total number of elements.
        last_dim (int, optional): the expected last dimension size.
        dtype (torch.dtype, optional): the expected dtype.
        device (torch.device, optional): the expected device.
        throw (bool): if True, raise an error if a check fails.

    Return:
        (bool): status of the check.
    """
if not check_tensor(tensor, shape=(None, None), dtype=dtype, device=device, throw=throw):
return False
if total_numel is not None and tensor.shape[0] != total_numel:
if throw:
raise ValueError(f"tensor total number of elements is {tensor.shape[0]}, "
f"should be {total_numel}")
return False
if last_dim is not None and last_dim != tensor.shape[-1]:
if throw:
raise ValueError(f"tensor last_dim is {tensor.shape[-1]}, should be {last_dim}")
return False
return True
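
# A minimal sketch (assumed): a packed tensor stacks the points of a batch of
# two clouds (5 and 3 points) along dim 0, with coordinates in the last dim.
#
#   packed = torch.rand((5 + 3, 3))
#   check_packed_tensor(packed, total_numel=8, last_dim=3, device='cpu')  # True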


def check_padded_tensor(tensor, padding_value=None, shape_per_tensor=None,
                        batch_size=None, max_shape=None, last_dim=None,
                        dtype=None, device=None, throw=True):
    """Check if a :ref:`padded tensor<padded>` is valid given a set of criteria.

    Args:
        tensor (torch.Tensor): the padded tensor to be tested.
        padding_value (float, optional): the expected padding value;
            :attr:`shape_per_tensor` must be provided together with padding_value.
        shape_per_tensor (torch.LongTensor, optional): the expected :attr:`shape_per_tensor`.
        batch_size (int, optional): the expected batch size.
        max_shape (list or tuple of int, optional): the expected maximum shape per
            dimension (after the batch dimension), if a dimension is set at ``None``
            then it's not verified.
        last_dim (int, optional): the expected last dimension size.
        dtype (torch.dtype, optional): the expected dtype.
        device (torch.device, optional): the expected device.
        throw (bool): if True, raise an error if a check fails.

    Return:
        (bool): status of the check.
    """
if not check_tensor(tensor, dtype=dtype, device=device, throw=throw):
return False
if shape_per_tensor is not None:
if batch_size is not None and batch_size != shape_per_tensor.shape[0]:
if throw:
raise ValueError(f"batch_size is {batch_size}, "
f"but there are {shape_per_tensor.shape[0]} "
f"shapes in shape_per_tensor")
return False
batch_size = shape_per_tensor.shape[0]
if batch_size is not None and batch_size != tensor.shape[0]:
if throw:
raise ValueError(f"tensor batch size is {tensor.shape[0]}, should be {batch_size}")
return False
if max_shape is not None:
for i, dim in enumerate(max_shape, 1):
if dim is not None and dim != tensor.shape[i]:
if throw:
raise ValueError(f"tensor max_shape is {tensor.shape[1:-1]}, should be {max_shape}")
return False
if last_dim is not None and last_dim != tensor.shape[-1]:
if throw:
raise ValueError(f"tensor last_dim is {tensor.shape[-1]}, should be {last_dim}")
return False
if padding_value is not None:
if shape_per_tensor is None:
raise ValueError("shape_per_tensor should not be None if padding_value is set")
mask = torch.ones(tensor.shape, dtype=torch.bool, device=tensor.device)
for i, shape in enumerate(shape_per_tensor):
mask[[i] + [slice(dim) for dim in shape]] = False
        if torch.any(tensor[mask] != padding_value):
if throw:
first_false_coord = tuple(
int(l[0]) for l in torch.where((tensor != padding_value) & mask))
raise ValueError(f"tensor padding at {first_false_coord} is {tensor[first_false_coord]}, "
f"should be {padding_value}")
return False
return True
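
# A minimal sketch (assumed): a padded tensor batches two clouds of 2 and 3
# points into shape (2, 3, 3), filling the unused rows with padding_value.
#
#   padded = torch.zeros((2, 3, 3))
#   padded[0, :2] = 1.
#   padded[1, :3] = 1.
#   shape_per_tensor = torch.LongTensor([[2], [3]])
#   check_padded_tensor(padded, padding_value=0., shape_per_tensor=shape_per_tensor,
#                       batch_size=2, last_dim=3)  # True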


def check_spc_octrees(octrees, lengths, batch_size=None, level=None,
                      device=None, throw=True):
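    """Check if a batch of Structured Point Cloud (SPC) octrees is well formed.

    Each octree is parsed level by level: the bit counts of the current level's
    bytes give the number of nodes at the next level, and the walk must end
    exactly at the corresponding length.

    Args:
        octrees (torch.ByteTensor): the octrees to be tested, as a packed 1D tensor.
        lengths (torch.Tensor): the size, in bytes, of each octree in ``octrees``.
        batch_size (int, optional): the expected batch size.
        level (int, optional): the expected depth of each octree.
        device (torch.device, optional): the expected device.
        throw (bool): if True, raise an error if a check fails.

    Return:
        (bool): status of the check.
    """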
if batch_size is not None and (batch_size,) != lengths.shape:
if throw:
raise ValueError(f"lengths is of shape {lengths.shape}, "
f"but batch_size should be {batch_size}")
return False
if device is not None and device != octrees.device.type:
if throw:
raise ValueError(f"octrees is on {octrees.device}, "
f"should be on {device}.")
return False
octree_start_idx = 0
for i, length in enumerate(lengths):
cur_node_idx = 0
cur_num_nodes = 1
cur_level = 0
octree = octrees[octree_start_idx:octree_start_idx + length]
while cur_node_idx < length:
cur_level += 1
cur_level_nodes = octree[cur_node_idx:cur_node_idx + cur_num_nodes]
cur_node_idx += cur_num_nodes
cur_num_nodes = int(torch.sum(uint8_bits_sum(cur_level_nodes).long()))
if cur_node_idx > length:
if throw:
raise ValueError(f"lengths at {i} is {length}, "
f"but level {cur_level} ends at length {cur_node_idx}")
return False
if level is not None and level != cur_level:
if throw:
raise ValueError(f"octree {i} ends at level {cur_level}, "
f"should end at {level}")
return False
octree_start_idx += length
return True
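
# A minimal sketch (assumed): a single-byte octree where the root byte 0xFF
# spawns 8 children, i.e. a fully occupied level-1 octree of length 1.
#
#   octrees = torch.tensor([0xFF], dtype=torch.uint8)
#   lengths = torch.tensor([1], dtype=torch.int)
#   check_spc_octrees(octrees, lengths, batch_size=1, level=1, device='cpu')  # True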


def tensor_info(t, name='', print_stats=False, detailed=False):
    """Convenience method to format diagnostic tensor information, including
    shape, type, and optional attributes if specified as string.

    This information can then be logged as:
        logger.debug(tensor_info(my_tensor, 'my tensor'))

    Log output:
        my_tensor: [10, 2, 100, 100] (torch.float32)[cuda:0]

    Args:
        t: input pytorch tensor, numpy array or None
        name: human readable name of the tensor (optional)
        print_stats: if True, includes min/max/mean statistics (takes compute time)
        detailed: if True, includes details about tensor properties

    Returns:
        (str): formatted string

    Examples:
        >>> t = torch.Tensor([0., 2., 3.])
        >>> tensor_info(t, 'mytensor', True, True)
        'mytensor: [3] (torch.float32)[cpu]  - [min 0.0000, max 3.0000, mean 1.6667]  - req_grad=False, is_leaf=True, layout=torch.strided'
    """
def _get_stats_str():
if torch.is_tensor(t):
return ' - [min %0.4f, max %0.4f, mean %0.4f]' % \
(torch.min(t).item(),
torch.max(t).item(),
torch.mean(t.to(torch.float32)).item())
        elif isinstance(t, np.ndarray):
return ' - [min %0.4f, max %0.4f, mean %0.4f]' % (np.min(t), np.max(t), np.mean(t))
else:
raise RuntimeError('Not implemented for {}'.format(type(t)))

    def _get_details_str():
if torch.is_tensor(t):
return ' - req_grad={}, is_leaf={}, layout={}'.format(
t.requires_grad, t.is_leaf, t.layout)

    if t is None:
return '%s: None' % name
shape_str = ''
if hasattr(t, 'shape'):
shape_str = '%s ' % str(list(t.shape))
if hasattr(t, 'dtype'):
type_str = '%s' % str(t.dtype)
else:
type_str = '{}'.format(type(t))
device_str = ''
if hasattr(t, 'device'):
device_str = '[{}]'.format(t.device)
name_str = ''
if name is not None and len(name) > 0:
name_str = '%s: ' % name
return ('%s%s(%s)%s %s %s' %
(name_str, shape_str, type_str, device_str,
(_get_stats_str() if print_stats else ''),
(_get_details_str() if detailed else '')))


def contained_torch_equal(elem, other, approximate=False, print_error_context=None, **allclose_args):
    """Check for equality (or allclose if approximate) of two objects potentially containing tensors.

    :func:`torch.equal` does not support data structures like dictionaries / arrays
    and `==` is ambiguous on :class:`torch.Tensor`.
    This function recurses through :class:`collections.abc.Mapping` and
    :class:`collections.abc.Sequence`, uses :func:`torch.equal` if the objects are
    `torch.Tensor`, or else falls back to the `==` operator.

    Args:
        elem (object, dict, list, tuple): The first object
        other (object, dict, list, tuple): The other object to compare to ``elem``
        approximate (bool): if requested will use allclose for comparison instead (default=False)
        print_error_context (str): set to any string value to print the context for the first nested failed match
        allclose_args: arguments to `torch.allclose` if approximate comparison requested

    Return (bool): the comparison result
    """
def _maybe_print(val, extra_context='', prefix_string='Failed match for '):
if not val and print_error_context is not None: # match failed
print(f'{prefix_string}{print_error_context}{extra_context}')
return val
elem_type = type(elem)
if elem_type != type(other):
return _maybe_print(False)
def _tensor_compare(a, b):
if not approximate:
return torch.equal(a, b)
else:
return torch.allclose(a, b, **allclose_args)
def _number_compare(a, b):
return _tensor_compare(torch.tensor([a]), torch.tensor([b]))
def _attrs_to_dict(a, attrs):
return {k : getattr(a, k) for k in attrs if hasattr(a, k)}
def _recursive_error_context(append_context):
if print_error_context is None:
return None
return f'{print_error_context}{append_context}'
recursive_args = copy.copy(allclose_args)
recursive_args['approximate'] = approximate
if isinstance(elem, torch.Tensor):
return _maybe_print(_tensor_compare(elem, other))
elif isinstance(elem, str):
return _maybe_print(elem == other, extra_context=f': {elem} vs {other}')
elif isinstance(elem, float):
return _maybe_print(_number_compare(elem, other), extra_context=f': {elem} vs {other}')
elif isinstance(elem, collections.abc.Mapping):
if elem.keys() != other.keys():
return _maybe_print(False, f': {elem.keys()} vs {other.keys()}', 'Different keys for ')
return all(contained_torch_equal(
elem[key], other[key],
print_error_context=_recursive_error_context(f'[{key}]'), **recursive_args) for key in elem)
elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
if set(elem._fields) != set(other._fields):
return _maybe_print(False, f': {elem._fields} vs {other._fields}', 'Different fields for ')
return all(contained_torch_equal(
getattr(elem, f), getattr(other, f),
print_error_context=_recursive_error_context(f'.{f}'), **recursive_args) for f in elem._fields)
elif isinstance(elem, collections.abc.Sequence):
if len(elem) != len(other):
            return _maybe_print(False, f': {len(elem)} vs {len(other)}', 'Different length for ')
return all(contained_torch_equal(
a, b, print_error_context=_recursive_error_context(f'[{i}]'), **recursive_args)
for i, (a, b) in enumerate(zip(elem, other)))
elif hasattr(elem, '__slots__'):
return contained_torch_equal(_attrs_to_dict(elem, elem.__slots__), _attrs_to_dict(other, other.__slots__),
print_error_context=print_error_context, **recursive_args)
else:
return _maybe_print(elem == other)
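
# A minimal sketch (assumed) of contained_torch_equal on nested containers:
#
#   a = {'pos': torch.ones(2, 3), 'names': ['x', 'y']}
#   b = {'pos': torch.ones(2, 3) + 1e-6, 'names': ['x', 'y']}
#   contained_torch_equal(a, b)                    # False (exact compare)
#   contained_torch_equal(a, b, approximate=True)  # True (uses torch.allclose)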


def check_allclose(tensor, other, rtol=1e-5, atol=1e-8, equal_nan=False):
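    """Check that two tensors are allclose, raising a descriptive ``ValueError``
    that pinpoints the differing indices and example values otherwise.

    Args:
        tensor (torch.Tensor): the tensor to be tested.
        other (torch.Tensor): the tensor to compare against.
        rtol (float): relative tolerance passed to :func:`torch.allclose`.
        atol (float): absolute tolerance passed to :func:`torch.allclose`.
        equal_nan (bool): if True, NaNs are considered equal.
    """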
    if not torch.allclose(tensor, other, atol=atol, rtol=rtol, equal_nan=equal_nan):
        diff_idx = torch.where(~torch.isclose(tensor, other, atol=atol, rtol=rtol, equal_nan=equal_nan))
        raise ValueError(f"Tensors are not close on indices {diff_idx}: "
                         f"Example values: {tensor[diff_idx][:10]} vs {other[diff_idx][:10]}.")


def check_tensor_attribute_shapes(container, throw=True, **attribute_info):
    """Checks the shapes of all specified attributes of the container.

    Args:
        container (dict, tuple, object): container with named attributes to be tested
        throw (bool): if True (default), raise an error on the first check that fails
        attribute_info: named attribute=shape values, where shape can be a list or tuple
            (see :func:`check_tensor`)

    Return:
        (bool): True if all checks pass
    """
def _get_item(container, attr):
if isinstance(container, collections.abc.Mapping):
return container[attr]
else:
return getattr(container, attr)
success = True
for k, shape in attribute_info.items():
val = _get_item(container, k)
if not check_tensor(val, shape=shape, throw=False):
success = False
message = f'Attribute {k} has shape {val.shape} (expected {shape})'
if throw:
raise ValueError(message)
else:
logging.error(message)
return success
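
# A minimal sketch (assumed): validating attribute shapes on a dict container.
#
#   data = {'vertices': torch.rand((10, 3)), 'faces': torch.randint(0, 10, (4, 3))}
#   check_tensor_attribute_shapes(data, vertices=(None, 3), faces=(4, 3))  # True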


def print_namedtuple_attributes(ntuple, name='', prefix='', **tensor_info_kwargs):
    r"""Same as :func:`print_dict_attributes`, but with a named tuple input."""
    print_dict_attributes(ntuple._asdict(), name=name, prefix=prefix, **tensor_info_kwargs)


def print_dict_attributes(in_dict, name='', prefix='', **tensor_info_kwargs):
    r"""Convenience function to print all attributes of a dict with names. Extra info
    is provided for tensors.

    Args:
        in_dict (dict): input dictionary containing any values
        name (str): name of the input dictionary (especially useful if printing nested dictionaries)
        prefix (str): prefix to prepend to every printed attribute (especially useful if printing nested dictionaries)
        **tensor_info_kwargs: extra arguments to pass to :func:`tensor_info`
    """
if len(name) > 0:
print(f'\nAttributes of {name}:')
for k, v in in_dict.items():
recurse = False
if torch.is_tensor(v):
tinfo = tensor_info(v, **tensor_info_kwargs)
elif isinstance(v, (str, int, float)):
tinfo = v
elif isinstance(v, collections.abc.Sequence):
tinfo = f'{type(v)} of length {len(v)}'
elif isinstance(v, collections.abc.Mapping):
tinfo = f'{type(v)} of length {len(v)}'
recurse = True
else:
tinfo = type(v)
print(f' {prefix}{k}: {tinfo}')
if recurse:
            print_dict_attributes(v, prefix='  ')
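
# A minimal sketch (assumed) of the printing helpers: nested mappings are
# printed recursively, and tensor values are summarized via tensor_info.
#
#   stats = {'loss': 0.25, 'grads': torch.zeros(10), 'meta': {'epoch': 3}}
#   print_dict_attributes(stats, name='training state')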