提交测试
This commit is contained in:
88
examples/tutorial/dmtet_network.py
Normal file
88
examples/tutorial/dmtet_network.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
# MLP + Positional Encoding
|
||||
class Decoder(torch.nn.Module):
|
||||
def __init__(self, input_dims = 3, internal_dims = 128, output_dims = 4, hidden = 5, multires = 2):
|
||||
super().__init__()
|
||||
self.embed_fn = None
|
||||
if multires > 0:
|
||||
embed_fn, input_ch = get_embedder(multires)
|
||||
self.embed_fn = embed_fn
|
||||
input_dims = input_ch
|
||||
|
||||
net = (torch.nn.Linear(input_dims, internal_dims, bias=False), torch.nn.ReLU())
|
||||
for i in range(hidden-1):
|
||||
net = net + (torch.nn.Linear(internal_dims, internal_dims, bias=False), torch.nn.ReLU())
|
||||
net = net + (torch.nn.Linear(internal_dims, output_dims, bias=False),)
|
||||
self.net = torch.nn.Sequential(*net)
|
||||
|
||||
def forward(self, p):
|
||||
if self.embed_fn is not None:
|
||||
p = self.embed_fn(p)
|
||||
out = self.net(p)
|
||||
return out
|
||||
|
||||
def pre_train_sphere(self, iter):
|
||||
print ("Initialize SDF to sphere")
|
||||
loss_fn = torch.nn.MSELoss()
|
||||
optimizer = torch.optim.Adam(list(self.parameters()), lr=1e-4)
|
||||
|
||||
for i in tqdm(range(iter)):
|
||||
p = torch.rand((1024,3), device='cuda') - 0.5
|
||||
ref_value = torch.sqrt((p**2).sum(-1)) - 0.3
|
||||
output = self(p)
|
||||
loss = loss_fn(output[...,0], ref_value)
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
print("Pre-trained MLP", loss.item())
|
||||
|
||||
|
||||
# Positional Encoding from https://github.com/yenchenlin/nerf-pytorch/blob/1f064835d2cca26e4df2d7d130daa39a8cee1795/run_nerf_helpers.py
|
||||
class Embedder:
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
self.create_embedding_fn()
|
||||
|
||||
def create_embedding_fn(self):
|
||||
embed_fns = []
|
||||
d = self.kwargs['input_dims']
|
||||
out_dim = 0
|
||||
if self.kwargs['include_input']:
|
||||
embed_fns.append(lambda x : x)
|
||||
out_dim += d
|
||||
|
||||
max_freq = self.kwargs['max_freq_log2']
|
||||
N_freqs = self.kwargs['num_freqs']
|
||||
|
||||
if self.kwargs['log_sampling']:
|
||||
freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
|
||||
else:
|
||||
freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
|
||||
|
||||
for freq in freq_bands:
|
||||
for p_fn in self.kwargs['periodic_fns']:
|
||||
embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))
|
||||
out_dim += d
|
||||
|
||||
self.embed_fns = embed_fns
|
||||
self.out_dim = out_dim
|
||||
|
||||
def embed(self, inputs):
|
||||
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
|
||||
|
||||
def get_embedder(multires):
|
||||
embed_kwargs = {
|
||||
'include_input' : True,
|
||||
'input_dims' : 3,
|
||||
'max_freq_log2' : multires-1,
|
||||
'num_freqs' : multires,
|
||||
'log_sampling' : True,
|
||||
'periodic_fns' : [torch.sin, torch.cos],
|
||||
}
|
||||
|
||||
embedder_obj = Embedder(**embed_kwargs)
|
||||
embed = lambda x, eo=embedder_obj : eo.embed(x)
|
||||
return embed, embedder_obj.out_dim
|
||||
Reference in New Issue
Block a user