提交测试
This commit is contained in:
54
examples/recipes/spc/spc_basics.py
Normal file
54
examples/recipes/spc/spc_basics.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# ==============================================================================================================
|
||||
# The following snippet demonstrates the basic usage of kaolin's compressed octree,
|
||||
# termed "Structured Point Cloud (SPC)".
|
||||
# Note this is a low level structure: practitioners are encouraged to visit the references below.
|
||||
# ==============================================================================================================
|
||||
# See also:
|
||||
#
|
||||
# - Code: kaolin.ops.spc.SPC
|
||||
# https://kaolin.readthedocs.io/en/latest/modules/kaolin.rep.html?highlight=SPC#kaolin.rep.Spc
|
||||
#
|
||||
# - Tutorial: Understanding Structured Point Clouds (SPCs)
|
||||
# https://github.com/NVIDIAGameWorks/kaolin/blob/master/examples/tutorial/understanding_spcs_tutorial.ipynb
|
||||
#
|
||||
# - Documentation: Structured Point Clouds
|
||||
# https://kaolin.readthedocs.io/en/latest/modules/kaolin.ops.spc.html?highlight=spc#kaolin-ops-spc
|
||||
# ==============================================================================================================
|
||||
|
||||
import torch
|
||||
import kaolin
|
||||
|
||||
# Construct SPC from some points data. Point coordinates are expected to be normalized to the range [-1, 1].
|
||||
points = torch.tensor([[-1.0, -1.0, -1.0], [-0.9, -0.95, -1.0], [1.0, 1.0, 1.0]], device='cuda')
|
||||
|
||||
# In kaolin, operations are batched by default
|
||||
# Here, in contrast, we use a single point cloud and therefore invoke an unbatched conversion function.
|
||||
# The Structured Point Cloud will be using 3 levels of detail
|
||||
spc = kaolin.ops.conversions.pointcloud.unbatched_pointcloud_to_spc(pointcloud=points, level=3)
|
||||
|
||||
# SPC is a batched object, and most of its fields are packed.
|
||||
# (see: https://kaolin.readthedocs.io/en/latest/modules/kaolin.ops.batch.html#kaolin-ops-batch )
|
||||
# spc.length defines the boundaries between different batched SPC instances the same object holds.
|
||||
# Here we keep track of a single entry batch, which has 8 octree non-leaf cells.
|
||||
print(f'spc.batch_size: {spc.batch_size}')
|
||||
print(f'spc.lengths (cells per batch entry): {spc.lengths}')
|
||||
|
||||
# SPC is hierarchical and keeps information for every level of detail from 0 to 3.
|
||||
# spc.point_hierarchies keeps the sparse, zero indexed coordinates of each occupied cell, per level.
|
||||
print(f'SPC keeps track of total of {spc.point_hierarchies.shape[0]} parent + leaf cells:')
|
||||
|
||||
# To separate the boundaries, the spc.pyramids field is used.
|
||||
# This field is not-packed, unlike the other SPC fields.
|
||||
pyramid_of_first_entry_in_batch = spc.pyramids[0]
|
||||
cells_per_level = pyramid_of_first_entry_in_batch[0]
|
||||
cumulative_cells_per_level = pyramid_of_first_entry_in_batch[1]
|
||||
for i, lvl_cells in enumerate(cells_per_level[:-1]):
|
||||
print(f'LOD #{i} has {lvl_cells} cells.')
|
||||
|
||||
# The spc.octrees field keeps track of the fundamental occupancy information of each cell in the octree.
|
||||
print('The occupancy of each octant parent cell, in Morton / Z-curve order is:')
|
||||
print(['{0:08b}'.format(octree_byte) for octree_byte in spc.octrees])
|
||||
|
||||
# Since SPCs are low level objects, they require bookkeeping of multiple fields.
|
||||
# For ease of use, these fields are collected and tracked within a single class: kaolin.ops.spc.SPC
|
||||
# See references at the header for elaborate information on how to use this object.
|
||||
113
examples/recipes/spc/spc_conv3d_example.py
Normal file
113
examples/recipes/spc/spc_conv3d_example.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# ==============================================================================================================
|
||||
# The following code demonstrates the usage of kaolin's "Structured Point Cloud (SPC)" 3d convolution
|
||||
# functionality. Note that this sample does NOT demonstrate how to use Kaolin's Pytorch 3d convolution layers.
|
||||
# Rather, 3d convolutions are used to 'filter' color data useful for level-of-detail management during
|
||||
# rendering. This can be thought of as the 3d analog of generating a 2d mipmap.
|
||||
#
|
||||
# Note this is a low level interface: practitioners are encouraged to visit the references below.
|
||||
# ==============================================================================================================
|
||||
# See also:
|
||||
#
|
||||
# - Code: kaolin.ops.spc.SPC
|
||||
# https://kaolin.readthedocs.io/en/latest/modules/kaolin.rep.html?highlight=SPC#kaolin.rep.Spc
|
||||
#
|
||||
# - Tutorial: Understanding Structured Point Clouds (SPCs)
|
||||
# https://github.com/NVIDIAGameWorks/kaolin/blob/master/examples/tutorial/understanding_spcs_tutorial.ipynb
|
||||
#
|
||||
# - Documentation: Structured Point Clouds
|
||||
# https://kaolin.readthedocs.io/en/latest/modules/kaolin.ops.spc.html?highlight=spc#kaolin-ops-spc
|
||||
# ==============================================================================================================
|
||||
|
||||
import torch
|
||||
import kaolin
|
||||
|
||||
# The following function applies a series of SPC convolutions to encode the entire hierarchy into a single tensor.
|
||||
# Each step applies a convolution on the "highest" level of the SPC with some averaging kernel.
|
||||
# Therefore, each step locally averages the "colored point hierarchy", where each "colored point"
|
||||
# corresponds to a point in the SPC point hierarchy.
|
||||
# For a description of inputs 'octree', 'point_hierachy', 'level', 'pyramids', and 'exsum', as well a
|
||||
# detailed description of the mathematics of SPC convolutions, see:
|
||||
# https://kaolin.readthedocs.io/en/latest/modules/kaolin.ops.spc.html?highlight=SPC#kaolin.ops.spc.Conv3d
|
||||
# The input 'color' is Pytorch tensor containing color features corresponding to some 'level' of the hierarchy.
|
||||
def encode(colors, octree, point_hierachy, pyramids, exsum, level):
|
||||
|
||||
# SPC convolutions are characterized by a set of 'kernel vectors' and corresponding 'weights'.
|
||||
|
||||
# kernel_vectors is the "kernel support" -
|
||||
# a listing of 3D coordinates where the weights of the convolution are non-null,
|
||||
# in this case a it's a simple dense 2x2x2 grid.
|
||||
kernel_vectors = torch.tensor([[0,0,0],[0,0,1],[0,1,0],[0,1,1],
|
||||
[1,0,0],[1,0,1],[1,1,0],[1,1,1]],
|
||||
dtype=torch.short, device='cuda')
|
||||
|
||||
# The weights specify how the input colors 'under' the kernel are mapped to an output color,
|
||||
# in this case a simple average.
|
||||
weights = torch.diag(torch.tensor([0.125, 0.125, 0.125, 0.125],
|
||||
dtype=torch.float32, device='cuda')) # Tensor of (4, 4)
|
||||
weights = weights.repeat(8,1,1).contiguous() # Tensor of (8, 4, 4)
|
||||
|
||||
# Storage for the output color hierarchy is allocated. This includes points at the bottom of the hierarchy,
|
||||
# as well as intermediate SPC levels (which may store different features)
|
||||
color_hierarchy = torch.empty((pyramids[0,1,level+1],4), dtype=torch.float32, device='cuda')
|
||||
# Copy the input colors into the highest level of color_hierarchy. pyramids is used here to select all leaf
|
||||
# points at the bottom of the hierarchy and set them to some pre-sampled random color. Points at intermediate
|
||||
# levels are left empty.
|
||||
color_hierarchy[pyramids[0,1,level]:pyramids[0,1,level+1]] = colors[:]
|
||||
|
||||
# Performs the 3d convolutions in a bottom up fashion to 'filter' colors from the previous level
|
||||
for l in range(level,0,-1):
|
||||
|
||||
# Apply the 3d convolution. Note that jump=1 means the inputs and outputs differ by 1 level
|
||||
# This is analogous to to a stride=2 in grid based convolutions
|
||||
colors, ll = kaolin.ops.spc.conv3d(octree,
|
||||
point_hierachy,
|
||||
l,
|
||||
pyramids,
|
||||
exsum,
|
||||
colors,
|
||||
weights,
|
||||
kernel_vectors,
|
||||
jump=1)
|
||||
# Copy the output colors into the color hierarchy
|
||||
color_hierarchy[pyramids[0,1,ll]:pyramids[0,1,l]] = colors[:]
|
||||
print(f"At level {l}, output feature shape is:\n{colors.shape}")
|
||||
|
||||
# Normalize the colors.
|
||||
color_hierarchy /= color_hierarchy[:,3:]
|
||||
# Normalization is needed here due to the sparse nature of SPCs. When a point under a kernel is not
|
||||
# present in the point hierarchy, the corresponding data is treated as zeros. Normalization is equivalent
|
||||
# to having the filter weights sum to one. This may not always be desirable, e.g. alpha blending.
|
||||
|
||||
return color_hierarchy
|
||||
|
||||
|
||||
# Highest level of SPC
|
||||
level = 3
|
||||
|
||||
# Construct a fully occupied Structured Point Cloud with N levels of detail
|
||||
# See https://kaolin.readthedocs.io/en/latest/modules/kaolin.rep.html?highlight=SPC#kaolin.rep.Spc
|
||||
spc = kaolin.rep.Spc.make_dense(level, device='cuda')
|
||||
|
||||
# In kaolin, operations are batched by default, the spc object above contains a single item batch, hence [0]
|
||||
num_points_last_lod = spc.num_points(level)[0]
|
||||
|
||||
# Create tensor of random colors for all points in the highest level of detail
|
||||
colors = torch.rand((num_points_last_lod, 4), dtype=torch.float32, device='cuda')
|
||||
# Set 4th color channel to one for subsequent color normalization
|
||||
colors[:,3] = 1
|
||||
|
||||
print(f'Input SPC features: {colors.shape}')
|
||||
|
||||
# Encode color hierarchy by invoking a series of convolutions, until we end up with a single tensor.
|
||||
color_hierarchy = encode(colors=colors,
|
||||
octree=spc.octrees,
|
||||
point_hierachy=spc.point_hierarchies,
|
||||
pyramids=spc.pyramids,
|
||||
exsum=spc.exsum,
|
||||
level=level)
|
||||
|
||||
# Print root node color
|
||||
print(f'Final encoded value (average of averages):')
|
||||
print(color_hierarchy[0])
|
||||
# This will be the average of averages, over the entire spc hierarchy. Since the initial random colors
|
||||
# came from a uniform distribution, this should approach [0.5, 0.5, 0.5, 1.0] as 'level' increases
|
||||
73
examples/recipes/spc/spc_dual_octree.py
Normal file
73
examples/recipes/spc/spc_dual_octree.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# ==============================================================================================================
|
||||
# The following snippet demonstrates the basic usage of kaolin's dual octree, an octree which keeps features
|
||||
# at the 8 corners of each cell (the primary octree keeps a single feature at each cell center).
|
||||
# The implementation is realized through kaolin's "Structured Point Cloud (SPC)".
|
||||
# Note this is a low level structure: practitioners are encouraged to visit the references below.
|
||||
# ==============================================================================================================
|
||||
# See also:
|
||||
#
|
||||
# - Code: kaolin.ops.spc.SPC
|
||||
# https://kaolin.readthedocs.io/en/latest/modules/kaolin.rep.html?highlight=SPC#kaolin.rep.Spc
|
||||
#
|
||||
# - Tutorial: Understanding Structured Point Clouds (SPCs)
|
||||
# https://github.com/NVIDIAGameWorks/kaolin/blob/master/examples/tutorial/understanding_spcs_tutorial.ipynb
|
||||
#
|
||||
# - Documentation: Structured Point Clouds
|
||||
# https://kaolin.readthedocs.io/en/latest/modules/kaolin.ops.spc.html?highlight=spc#kaolin-ops-spc
|
||||
# ==============================================================================================================
|
||||
|
||||
import torch
|
||||
import kaolin
|
||||
|
||||
# Construct SPC from some points data. Point coordinates are expected to be normalized to the range [-1, 1].
|
||||
# To keep the example readable, by default we set the SPC level to 1: root + 8 cells
|
||||
# (note that with a single LOD, only 2 cells should be occupied due to quantization)
|
||||
level = 1
|
||||
points = torch.tensor([[-1.0, -1.0, -1.0], [-0.9, -0.95, -1.0], [1.0, 1.0, 1.0]], device='cuda')
|
||||
spc = kaolin.ops.conversions.pointcloud.unbatched_pointcloud_to_spc(pointcloud=points, level=level)
|
||||
|
||||
# Construct the dual octree with an unbatched operation, each cell is now converted to 8 corners
|
||||
# More info about batched / packed tensors at:
|
||||
# https://kaolin.readthedocs.io/en/latest/modules/kaolin.ops.batch.html#kaolin-ops-batch
|
||||
pyramid = spc.pyramids[0] # The pyramids field is batched, we select the singleton entry, #0
|
||||
point_hierarchy = spc.point_hierarchies # point_hierarchies is a packed tensor, so no need to unbatch
|
||||
point_hierarchy_dual, pyramid_dual = kaolin.ops.spc.unbatched_make_dual(point_hierarchy=point_hierarchy,
|
||||
pyramid=pyramid)
|
||||
|
||||
# Let's compare the primary and dual octrees.
|
||||
# The function 'unbatched_get_level_points' yields a tensor which lists all points / sparse cell coordinates occupied
|
||||
# at a certain level.
|
||||
# [Primary octree] [Dual octree]
|
||||
# . . . . . . . . X . . .X. . . X
|
||||
# | . X . X | . | . . | .
|
||||
# | . . . . . . . . ===> | X . . X . . . X
|
||||
# | | . X . | X . X | . . | .
|
||||
# | | . . . . . . . . | | X . . .X. . . X
|
||||
# | | | | | | X | | |
|
||||
# . .|. . | . . . | ===> X .|. . X . . X |
|
||||
# .| X |. X . | .| |. . X
|
||||
# . . | . . . . . | X . | . X . . X |
|
||||
# . | X . X . | . | . . |
|
||||
# . . . . . . . . X . . X . . . X
|
||||
#
|
||||
primary_lod0 = kaolin.ops.spc.unbatched_get_level_points(point_hierarchy, pyramid, level=0)
|
||||
primary_lod1 = kaolin.ops.spc.unbatched_get_level_points(point_hierarchy, pyramid, level=1)
|
||||
dual_lod0 = kaolin.ops.spc.unbatched_get_level_points(point_hierarchy_dual, pyramid_dual, level=0)
|
||||
dual_lod1 = kaolin.ops.spc.unbatched_get_level_points(point_hierarchy_dual, pyramid_dual, level=1)
|
||||
print(f'Primary octree: Level 0 (root cells): \n{primary_lod0}')
|
||||
print(f'Dual octree: Level 0 (root corners): \n{dual_lod0}')
|
||||
print(f'Primary octree: Level 1 (cells): \n{primary_lod1}')
|
||||
print(f'Dual octree: Level 1 (corners): \n{dual_lod1}')
|
||||
|
||||
# kaolin allows for interchangeable usage of the primary and dual octrees.
|
||||
# First we have to create a mapping between them:
|
||||
trinkets, _ = kaolin.ops.spc.unbatched_make_trinkets(point_hierarchy, pyramid, point_hierarchy_dual, pyramid_dual)
|
||||
|
||||
# trinkets are indirection pointers (in practice, indices) from the nodes of the primary octree
|
||||
# to the nodes of the dual octree. The nodes of the dual octree represent the corners of the voxels
|
||||
# defined by the primary octree.
|
||||
print(f'point_hierarchy is of shape {point_hierarchy.shape}')
|
||||
print(f'point_hierarchy_dual is of shape {point_hierarchy_dual.shape}')
|
||||
print(f'trinkets is of shape {trinkets.shape}')
|
||||
print(f'Trinket indices are multilevel: {trinkets}')
|
||||
# See also spc_trilinear_interp.py for a practical application which uses the dual octree & trinkets
|
||||
69
examples/recipes/spc/spc_trilinear_interp.py
Normal file
69
examples/recipes/spc/spc_trilinear_interp.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# ==============================================================================================================
|
||||
# The following snippet demonstrates the basic usage of kaolin's dual octree, an octree which keeps features
|
||||
# at the 8 corners of each cell (the primary octree keeps a single feature at each cell center).
|
||||
# In this example we sample an interpolated value according to the 8 corners of a cell.
|
||||
# The implementation is realized through kaolin's "Structured Point Cloud (SPC)".
|
||||
# Note this is a low level structure: practitioners are encouraged to visit the references below.
|
||||
# ==============================================================================================================
|
||||
# See also:
|
||||
#
|
||||
# - Code: kaolin.ops.spc.SPC
|
||||
# https://kaolin.readthedocs.io/en/latest/modules/kaolin.rep.html?highlight=SPC#kaolin.rep.Spc
|
||||
#
|
||||
# - Tutorial: Understanding Structured Point Clouds (SPCs)
|
||||
# https://github.com/NVIDIAGameWorks/kaolin/blob/master/examples/tutorial/understanding_spcs_tutorial.ipynb
|
||||
#
|
||||
# - Documentation: Structured Point Clouds
|
||||
# https://kaolin.readthedocs.io/en/latest/modules/kaolin.ops.spc.html?highlight=spc#kaolin-ops-spc
|
||||
# ==============================================================================================================
|
||||
|
||||
import torch
|
||||
import kaolin
|
||||
|
||||
# Construct SPC from some points data. Point coordinates are expected to be normalized to the range [-1, 1].
|
||||
# To keep the example readable, by default we set the SPC level to 1: root + 8 cells
|
||||
# (note that with a single LOD, only 2 cells should be occupied due to quantization)
|
||||
level = 1
|
||||
points = torch.tensor([[-1.0, -1.0, -1.0], [-0.9, -0.95, -1.0], [1.0, 1.0, 1.0]], device='cuda')
|
||||
spc = kaolin.ops.conversions.pointcloud.unbatched_pointcloud_to_spc(pointcloud=points, level=level)
|
||||
|
||||
# Construct the dual octree with an unbatched operation, each cell is now converted to 8 corners
|
||||
# More info about batched / packed tensors at:
|
||||
# https://kaolin.readthedocs.io/en/latest/modules/kaolin.ops.batch.html#kaolin-ops-batch
|
||||
pyramid = spc.pyramids[0] # The pyramids field is batched, we select the singleton entry, #0
|
||||
point_hierarchy = spc.point_hierarchies # point_hierarchies is a packed tensor, so no need to unbatch
|
||||
point_hierarchy_dual, pyramid_dual = kaolin.ops.spc.unbatched_make_dual(point_hierarchy=point_hierarchy,
|
||||
pyramid=pyramid)
|
||||
# kaolin allows for interchangeable usage of the primary and dual octrees via the "trinkets" mapping
|
||||
# trinkets are indirection pointers (in practice, indices) from the nodes of the primary octree
|
||||
# to the nodes of the dual octree. The nodes of the dual octree represent the corners of the voxels
|
||||
# defined by the primary octree.
|
||||
trinkets, _ = kaolin.ops.spc.unbatched_make_trinkets(point_hierarchy, pyramid, point_hierarchy_dual, pyramid_dual)
|
||||
|
||||
# We'll now apply the dual octree and trinkets to perform trilinaer interpolation.
|
||||
# First we'll generate some features for the corners.
|
||||
# The first dimension of pyramid / pyramid_dual specifies how many unique points exist per level.
|
||||
# For the pyramid_dual, this means how many "unique corners" are in place (as neighboring cells may share corners!)
|
||||
num_of_corners_at_last_lod = pyramid_dual[0, level]
|
||||
feature_dims = 32
|
||||
feats = torch.rand([num_of_corners_at_last_lod, feature_dims], device='cuda')
|
||||
|
||||
# Create some query coordinate with normalized values in the range [-1, 1], here we pick (0.5, 0.5, 0.5).
|
||||
# We'll also modify the dimensions of the query tensor to match the interpolation function api:
|
||||
# batch dimension refers to the unique number of spc cells we're querying.
|
||||
# samples_count refers to the number of interpolations we perform per cell.
|
||||
query_coord = points.new_tensor((0.5, 0.5, 0.5)).unsqueeze(0) # Tensor of (batch, 3), in this case batch=1
|
||||
sampled_query_coords = query_coord.unsqueeze(1) # Tensor of (batch, samples_count, 3), in this case samples_count=1
|
||||
|
||||
# unbatched_query converts from normalized coordinates to the index of the cell containing this point.
|
||||
# The query_index can be used to pick the point from point_hierarchy
|
||||
query_index = kaolin.ops.spc.unbatched_query(spc.octrees, spc.exsum, query_coord, level, with_parents=False)
|
||||
|
||||
# The unbatched_interpolate_trilinear function uses the query coordinates to perform trilinear interpolation.
|
||||
# Here, unbatched specifies this function supports only a single SPC at a time.
|
||||
# Per single SPC, we may interpolate a batch of coordinates and samples
|
||||
interpolated = kaolin.ops.spc.unbatched_interpolate_trilinear(coords=sampled_query_coords,
|
||||
pidx=query_index.int(),
|
||||
point_hierarchy=point_hierarchy,
|
||||
trinkets=trinkets, feats=feats, level=level)
|
||||
print(f'Interpolated a tensor of shape {interpolated.shape} with values: {interpolated}')
|
||||
Reference in New Issue
Block a user