import torch

try:
    from tqdm import tqdm
except ImportError:  # tqdm only draws the progress bar; training works without it
    def tqdm(iterable, *args, **kwargs):
        """Fallback used when tqdm is missing: iterate without a progress bar."""
        return iterable


# MLP + Positional Encoding
class Decoder(torch.nn.Module):
    """Bias-free ReLU MLP, optionally preceded by a sinusoidal positional encoding.

    Args:
        input_dims: Size of the last axis of the raw (un-encoded) input.
        internal_dims: Width of every hidden layer.
        output_dims: Size of the last axis of the output.
        hidden: Number of ``Linear`` + ``ReLU`` pairs placed before the final
            linear layer.
        multires: Number of frequency bands of the positional encoding; a
            value <= 0 disables the encoding and feeds the raw input to the MLP.
    """

    def __init__(self, input_dims=3, internal_dims=128, output_dims=4, hidden=5, multires=2):
        super().__init__()
        # Raw input width, remembered so pre_train_sphere samples matching points.
        self.input_dims = input_dims

        self.embed_fn = None
        if multires > 0:
            # Build the encoding for the caller's input width so that the first
            # linear layer matches what embed_fn actually produces.
            self.embed_fn, input_dims = get_embedder(multires, input_dims)

        layers = [torch.nn.Linear(input_dims, internal_dims, bias=False), torch.nn.ReLU()]
        for _ in range(hidden - 1):
            layers += [torch.nn.Linear(internal_dims, internal_dims, bias=False), torch.nn.ReLU()]
        layers.append(torch.nn.Linear(internal_dims, output_dims, bias=False))
        self.net = torch.nn.Sequential(*layers)

    def forward(self, p):
        """Encode ``p`` (when an encoding is configured) and run it through the MLP."""
        if self.embed_fn is not None:
            p = self.embed_fn(p)
        return self.net(p)

    def pre_train_sphere(self, iter):
        """Fit output channel 0 to the distance field of a sphere of radius 0.3.

        Each step draws 1024 points uniformly from ``[-0.5, 0.5)`` per
        coordinate and minimises the MSE between ``output[..., 0]`` and
        ``||p|| - 0.3`` with Adam (lr=1e-4).

        Args:
            iter: Number of optimisation steps. The name shadows the builtin
                but is kept because callers may pass it by keyword.
        """
        print ("Initialize SDF to sphere")
        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(list(self.parameters()), lr=1e-4)
        # Sample on whatever device the weights live on instead of assuming CUDA.
        device = next(self.parameters()).device

        loss = None
        for _ in tqdm(range(iter)):
            p = torch.rand((1024, self.input_dims), device=device) - 0.5
            ref_value = torch.sqrt((p**2).sum(-1)) - 0.3
            output = self(p)
            loss = loss_fn(output[..., 0], ref_value)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # With zero iterations there is no loss to report.
        if loss is not None:
            print("Pre-trained MLP", loss.item())


# Positional Encoding from https://github.com/yenchenlin/nerf-pytorch/blob/1f064835d2cca26e4df2d7d130daa39a8cee1795/run_nerf_helpers.py
class Embedder:
    """Concatenate an input with periodic functions of it at several frequencies.

    Keyword Args:
        input_dims: Size of the last axis of the tensors passed to ``embed``.
        include_input: Whether the raw input is the first chunk of the output.
        max_freq_log2: log2 of the highest frequency.
        num_freqs: Number of frequency bands.
        log_sampling: If true, frequencies are ``2 ** linspace(0, max_freq_log2)``;
            otherwise they are spaced linearly between ``1`` and
            ``2 ** max_freq_log2``.
        periodic_fns: Callables applied to ``x * freq`` for every frequency.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """Populate ``self.embed_fns`` and ``self.out_dim`` from ``self.kwargs``."""
        cfg = self.kwargs
        dims = cfg['input_dims']
        embed_fns = []
        out_dim = 0

        if cfg['include_input']:
            embed_fns.append(lambda x: x)
            out_dim += dims

        max_freq = cfg['max_freq_log2']
        n_freqs = cfg['num_freqs']
        if cfg['log_sampling']:
            freq_bands = 2. ** torch.linspace(0., max_freq, steps=n_freqs)
        else:
            freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=n_freqs)

        for freq in freq_bands:
            for p_fn in cfg['periodic_fns']:
                # Default arguments freeze the current p_fn/freq; a plain
                # closure would see only the last loop values.
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                out_dim += dims

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def embed(self, inputs):
        """Return every embedding function applied to ``inputs``, joined on the last axis."""
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)


def get_embedder(multires, input_dims=3):
    """Create a sin/cos positional encoding with ``multires`` log-spaced bands.

    Args:
        multires: Number of frequency bands (frequencies ``2**0 .. 2**(multires-1)``).
        input_dims: Size of the last axis of the tensors to be encoded.

    Returns:
        A tuple ``(embed, out_dim)``: ``embed`` maps a tensor to its encoding
        and ``out_dim`` is the size of the encoding's last axis.
    """
    embed_kwargs = {
        'include_input': True,
        'input_dims': input_dims,
        'max_freq_log2': multires - 1,
        'num_freqs': multires,
        'log_sampling': True,
        'periodic_fns': [torch.sin, torch.cos],
    }
    embedder_obj = Embedder(**embed_kwargs)
    embed = lambda x, eo=embedder_obj: eo.embed(x)
    return embed, embedder_obj.out_dim