In [1]:
!pip install matplotlib --quiet
import glob
import math
from matplotlib import pyplot as plt
import nvdiffrast
import os
import torch

from tutorial_common import COMMON_DATA_DIR
import kaolin as kal

# OpenGL rasterization context on the GPU, shared by every call to render() below
glctx = nvdiffrast.torch.RasterizeGLContext(output_db=False, device='cuda')

[0m

## Load Mesh information

In [2]:
# ShapeNetV2 location: read from the KAOLIN_TEST_SHAPENETV2_PATH env variable
# (or replace it with your own ShapeNet path)
SHAPENETV2_PATH = os.environ.get('KAOLIN_TEST_SHAPENETV2_PATH')

if SHAPENETV2_PATH is None:
    # No dataset configured: load the fox.obj sample found under COMMON_DATA_DIR
    OBJ_PATH = os.path.join(COMMON_DATA_DIR, 'meshes', 'fox.obj')
    mesh = kal.io.obj.import_mesh(OBJ_PATH, with_materials=True, triangulate=True)
else:
    # Use the first 'car' entry of the dataset, loaded together with its materials
    ds = kal.io.shapenet.ShapeNetV2(
        root=SHAPENETV2_PATH,
        categories=['car'],
        train=True,
        split=1.,
        with_materials=True,
        output_dict=True,
    )
    mesh = ds[0]['mesh']

# Batch the mesh, move it to the GPU, then center its vertices and
# normalize them to the range [-0.5, 0.5]
mesh = mesh.to_batched()
mesh = mesh.cuda()
mesh.vertices = kal.ops.pointcloud.center_points(mesh.vertices, normalize=True)
print(mesh.to_string(print_stats=True))


def _diffuse_texture(material):
    """Convert one OBJ material into the CUDA tensor sampled at render time.

    When the material has a diffuse map ('map_Kd') it is given a leading
    dimension, moved to the GPU, cast to float and divided by 255
    (presumably a 0-255 image of shape (H, W, 3) -- TODO confirm).
    Otherwise the constant diffuse color 'Kd' is reshaped to (1, 1, 1, 3)
    and moved to the GPU.
    """
    if 'map_Kd' in material:
        return material['map_Kd'].unsqueeze(0).cuda().float() / 255.
    return material['Kd'].reshape(1, 1, 1, 3).cuda()


# Preprocess OBJ materials into the format used for rendering
materials = [_diffuse_texture(m) for m in mesh.materials[0]]

# Backup for faces without UVs (face_uvs_idx == -1): append one extra (zero) UV
# entry and redirect those indices to it, so a single diffuse color is used
mesh.uvs = torch.nn.functional.pad(mesh.uvs, (0, 0, 0, 1))
fallback_uv_idx = mesh.uvs.shape[1] - 1
mesh.face_uvs_idx[mesh.face_uvs_idx == -1] = fallback_uv_idx

SurfaceMesh object with batching strategy FIXED
            vertices: [1, 5002, 3] (torch.float32)[cuda:0]  - [min -0.5000, max 0.5000, mean 0.0200] 
               faces: [10000, 3] (torch.int64)[cuda:0]  - [min 0.0000, max 5001.0000, mean 2471.1404] 
                 uvs: [1, 5505, 2] (torch.float32)[cuda:0]  - [min 0.0140, max 0.9710, mean 0.5101] 
        face_uvs_idx: [1, 10000, 3] (torch.int64)[cuda:0]  - [min 0.0000, max 5504.0000, mean 2662.6140] 
material_assignments: [1, 10000] (torch.int16)[cuda:0]  - [min 0.0000, max 0.0000, mean 0.0000] 
           materials: [
                      0: list of length 1
                      ]
       face_vertices: if possible, computed on access from: (faces, vertices)
        face_normals: if possible, computed on access from: (normals, face_normals_idx) or (vertices, faces)
            face_uvs: if possible, computed on access from: (uvs, face_uvs_idx)
      vertex_normals: if possible, computed on access from: (faces, face_normals)
    

## Instantiate a camera

With the general constructor `Camera.from_args()`, the underlying constructors are `CameraExtrinsics.from_lookat()` and `PinholeIntrinsics.from_fov()`.

In [3]:
# Look-at parameters: camera placed at x=2, aimed at the origin, with +Y as up
cam_eye = torch.tensor([2., 0., 0.])
cam_at = torch.tensor([0., 0., 0.])
cam_up = torch.tensor([0., 1., 0.])
cam_fov = math.pi * 45 / 180  # 45 degrees expressed in radians

cam = kal.render.camera.Camera.from_args(
    eye=cam_eye,
    at=cam_at,
    up=cam_up,
    fov=cam_fov,
    width=512,
    height=512,
    device='cuda',
)

## Rendering a mesh

Here we render the loaded mesh with [nvdiffrast](https://github.com/NVlabs/nvdiffrast), using the camera object created above.

In [4]:
def render(mesh, cam):
    """Rasterize ``mesh`` as seen from ``cam`` and color it with its diffuse materials.

    Relies on two notebook-level objects defined in earlier cells: ``glctx``
    (the nvdiffrast GL context) and ``materials`` (one texture tensor per
    material).

    Args:
        mesh: batched mesh providing ``vertices``, ``faces``, ``uvs``,
            ``face_uvs_idx`` and ``material_assignments``. Only the first
            batch element of ``face_uvs_idx`` and ``material_assignments``
            is used.
        cam: camera providing ``extrinsics.transform()``,
            ``projection_matrix()``, ``height`` and ``width``.

    Returns:
        Tensor of shape ``(cam.height, cam.width, 3)`` with values clamped to
        [0, 1]. Pixels not covered by any triangle are 0.
    """
    vertices_camera = cam.extrinsics.transform(mesh.vertices)

    # Create a fake W (See nvdiffrast documentation)
    proj = cam.projection_matrix()[None]
    homogeneous_vecs = kal.render.camera.up_to_homogeneous(
        vertices_camera
    )[..., None]
    vertices_clip = (proj @ homogeneous_vecs).squeeze(-1)

    rast = nvdiffrast.torch.rasterize(
        glctx, vertices_clip, mesh.faces.int(),
        (cam.height, cam.width), grad_db=False
    )
    # Need to flip the image because opengl
    rast0 = torch.flip(rast[0], dims=(1,))
    # Last channel is 0 where no triangle was hit, otherwise face index + 1
    hard_mask = rast0[:, :, :, -1:] != 0
    face_idx = (rast0[..., -1].long() - 1).contiguous()

    uv_map = nvdiffrast.torch.interpolate(
        mesh.uvs, rast0, mesh.face_uvs_idx[0, ...].int()
    )[0] % 1.

    # Size the output from the camera (not a fixed 512x512) and allocate it on
    # the same device as the rasterizer output
    img = torch.zeros((1, cam.height, cam.width, 3),
                      dtype=torch.float, device=rast0.device)

    # Obj meshes can be composed of multiple materials
    # so at rendering we need to interpolate from corresponding materials
    im_material_idx = mesh.material_assignments[0, ...][face_idx]  # Assume single assignments map
    # Background pixels (face_idx == -1) picked up the last face's material above; reset them
    im_material_idx[face_idx == -1] = -1
    for i, material in enumerate(materials):
        mask = im_material_idx == i
        _texcoords = uv_map[mask]
        if _texcoords.shape[0] == 0:
            # No visible pixel uses this material
            continue
        # Negate V before sampling the texture
        _texcoords[:, 1] = -_texcoords[:, 1]
        pixel_val = nvdiffrast.torch.texture(
            material.contiguous(),
            _texcoords.reshape(1, 1, -1, 2).contiguous(),
            filter_mode='linear'
        )
        img[mask] = pixel_val[0, 0]

    # Zero out the background and drop the batch dimension
    return torch.clamp(img * hard_mask, 0., 1.)[0]

# Moving the camera

Once the camera is created, you can move it using `cam.move_up()`, `cam.move_right()` and `cam.move_forward()`.

Note that in OpenGL, `forward` in camera space actually points toward the viewer (so moving forward actually moves the camera away from the object being looked at).

<img src="./assets/ndc_camera_space.png"
     alt="NDC camera space diagram"
     style="float:left;margin-right:10px;width:400px" />
<img src="./assets/ndc_image_space.png"
     alt="NDC image space diagram"
     style="float:left;margin-right:10px;width:300px" />

Below is a simple interactive rendering, where buttons are linked to camera methods for moving it.

In [5]:
try:
    ipy_str = str(type(get_ipython()))
    if 'zmqshell' in ipy_str:
        %matplotlib notebook
finally:
    pass

from matplotlib.widgets import Button

# Create the figure, keeping the bottom 30% free for the buttons
fig, ax = plt.subplots()
fig.subplots_adjust(bottom=0.3)

# Show a first render; im_buffer is the image object refreshed by the callbacks
initial_frame = render(mesh, cam).cpu()
im_buffer = plt.imshow(initial_frame)

def update():
    """Re-render the mesh with the current camera and redraw the figure."""
    frame = render(mesh, cam).cpu()
    im_buffer.set_data(frame)
    plt.draw()

def on_button_up_clicked(b):
    """Click handler: call ``cam.move_up`` with +0.1, then refresh the view.

    ``b`` is the argument supplied by the button and is not used.
    """
    step = 0.1
    cam.move_up(step)
    update()

def on_button_down_clicked(b):
    """Click handler: call ``cam.move_up`` with -0.1, then refresh the view.

    ``b`` is the argument supplied by the button and is not used.
    """
    step = -0.1
    cam.move_up(step)
    update()

def on_button_left_clicked(b):
    """Click handler: call ``cam.move_right`` with -0.1, then refresh the view.

    ``b`` is the argument supplied by the button and is not used.
    """
    step = -0.1
    cam.move_right(step)
    update()

def on_button_right_clicked(b):
    """Click handler: call ``cam.move_right`` with +0.1, then refresh the view.

    ``b`` is the argument supplied by the button and is not used.
    """
    step = 0.1
    cam.move_right(step)
    update()

def on_button_forward_clicked(b):
    """Click handler: call ``cam.move_forward`` with +0.1, then refresh the view.

    Note: in camera space "forward" points toward the viewer (behind the
    camera), so this moves the camera away from the object looked at.

    ``b`` is the argument supplied by the button and is not used.
    """
    step = 0.1
    cam.move_forward(step)
    update()
    
def on_button_backward_clicked(b):
    """Click handler: call ``cam.move_forward`` with -0.1, then refresh the view.

    Note: in camera space "forward" points toward the viewer (behind the
    camera), so a negative step goes the opposite way.

    ``b`` is the argument supplied by the button and is not used.
    """
    step = -0.1
    cam.move_forward(step)
    update()

def on_button_increase_pitch_clicked(b):
    """Click handler: call ``cam.rotate`` with ``pitch=+0.1``, then refresh the view.

    ``b`` is the argument supplied by the button and is not used.
    """
    delta = 0.1
    cam.rotate(pitch=delta)
    update()
    
def on_button_decrease_pitch_clicked(b):
    """Click handler: call ``cam.rotate`` with ``pitch=-0.1``, then refresh the view.

    ``b`` is the argument supplied by the button and is not used.
    """
    delta = -0.1
    cam.rotate(pitch=delta)
    update()
    
def on_button_increase_yaw_clicked(b):
    """Click handler: call ``cam.rotate`` with ``yaw=+0.1``, then refresh the view.

    ``b`` is the argument supplied by the button and is not used.
    """
    delta = 0.1
    cam.rotate(yaw=delta)
    update()
    
def on_button_decrease_yaw_clicked(b):
    """Click handler: call ``cam.rotate`` with ``yaw=-0.1``, then refresh the view.

    ``b`` is the argument supplied by the button and is not used.
    """
    delta = -0.1
    cam.rotate(yaw=delta)
    update()
    
def on_button_increase_roll_clicked(b):
    """Click handler: call ``cam.rotate`` with ``roll=+0.1``, then refresh the view.

    ``b`` is the argument supplied by the button and is not used.
    """
    delta = 0.1
    cam.rotate(roll=delta)
    update()
    
def on_button_decrease_roll_clicked(b):
    """Click handler: call ``cam.rotate`` with ``roll=-0.1``, then refresh the view.

    ``b`` is the argument supplied by the button and is not used.
    """
    delta = -0.1
    cam.rotate(roll=delta)
    update()
    
# One axes per button, laid out as two rows of six slots that share the same
# left offsets; every slot is 0.14 wide and 0.08 high.
_slot_lefts = (0.0, 0.15, 0.3, 0.45, 0.6, 0.75)

# Upper row (bottom edge at 0.14): translation buttons
up_ax, left_ax, down_ax, right_ax, forward_ax, backward_ax = [
    plt.axes([left, 0.14, 0.14, 0.08]) for left in _slot_lefts
]

# Lower row (bottom edge at 0.05): rotation buttons
(increase_pitch_ax, decrease_pitch_ax,
 increase_yaw_ax, decrease_yaw_ax,
 increase_roll_ax, decrease_roll_ax) = [
    plt.axes([left, 0.05, 0.14, 0.08]) for left in _slot_lefts
]

# Translation buttons. The down button is labelled "Down" (it was "Bottom") so
# that it pairs with "Up" and matches its callback, on_button_down_clicked.
button_up = Button(up_ax, "Up")
button_down = Button(down_ax, "Down")
button_left = Button(left_ax, "Left")
button_right = Button(right_ax, "Right")
button_forward = Button(forward_ax, "Forward")
button_backward = Button(backward_ax, "Backward")

# Rotation buttons (two-line labels)
button_increase_pitch = Button(increase_pitch_ax, "Increase\nPitch")
button_decrease_pitch = Button(decrease_pitch_ax, "Decrease\nPitch")
button_increase_yaw = Button(increase_yaw_ax, "Increase\nYaw")
button_decrease_yaw = Button(decrease_yaw_ax, "Decrease\nYaw")
button_increase_roll = Button(increase_roll_ax, "Increase\nRoll")
button_decrease_roll = Button(decrease_roll_ax, "Decrease\nRoll")

# Connect every button to its callback. Registering inside a loop (a statement,
# not a trailing expression) keeps the cell from echoing the connection id
# returned by the last on_clicked() call.
for _button, _callback in (
    (button_up, on_button_up_clicked),
    (button_down, on_button_down_clicked),
    (button_left, on_button_left_clicked),
    (button_right, on_button_right_clicked),
    (button_forward, on_button_forward_clicked),
    (button_backward, on_button_backward_clicked),
    (button_increase_pitch, on_button_increase_pitch_clicked),
    (button_decrease_pitch, on_button_decrease_pitch_clicked),
    (button_increase_yaw, on_button_increase_yaw_clicked),
    (button_decrease_yaw, on_button_decrease_yaw_clicked),
    (button_increase_roll, on_button_increase_roll_clicked),
    (button_decrease_roll, on_button_decrease_roll_clicked),
):
    _button.on_clicked(_callback)


<IPython.core.display.Javascript object>

0