PointNet : L'Architecture qui a tout changé

Comment traiter des nuages de points bruts sans voxélisation ni projection 2D ? La réponse de Stanford qui a révolutionné la 3D.
Expert
90 min
PyTorch, Mathématiques
"Avant 2017, le Deep Learning 3D était un enfer de trames voxels gourmandes en mémoire. PointNet a prouvé qu'on pouvait apprendre directement sur des coordonnées XYZ désordonnées."

1. Le Problème de l'Ordre

Un nuage de points est un ensemble de N points. Si vous permutez l'ordre des points, l'objet reste le même (une chaise est une chaise, que le point 1 soit le pied ou le dossier). Mais un réseau de neurones classique (MLP, CNN) est sensible à l'ordre d'entrée.

PointNet résout cela en utilisant une fonction symétrique (Max Pooling) qui rend le réseau invariant aux permutations.

Input (Nx3) MLP (Point-wise) Max Pool Global Feat

2. T-Net : L'Alignement Spatial

Comme pour l'ordre, le réseau doit être invariant à la rotation. PointNet apprend une matrice de transformation 3x3 (via un mini-réseau appelé T-Net) pour aligner canoniquement l'objet avant de le traiter.

tnet_module.py
class TNet(nn.Module):
    def __init__(self, k=3):
        super().__init__()
        self.conv1 = nn.Conv1d(k, 64, 1)
        self.fc1 = nn.Linear(1024, 512)
        # ... couches cachées ...
        self.fc3 = nn.Linear(256, k*k)
    
    def forward(self, x):
        # Apprend une matrice de transformation k x k
        matrix = self.fc3(x).view(-1, k, k)
        # Ajout de l'identité pour la stabilité
        identity = torch.eye(k).cuda()
        return matrix + identity

3. Implémentation du Backbone

Le cœur de PointNet est une série de MLPs partagés (Conv1d avec kernel size 1 en PyTorch) qui montent la dimension de chaque point de 3 (xyz) à 64, puis 1024.

pointnet_cls.py
class PointNetCls(nn.Module):
    def __init__(self, k=10, feature_transform=False):
        super().__init__()
        self.feat = PointNetFeat(global_feat=True)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k) # k classes
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):
        x, trans, trans_feat = self.feat(x)
        x = F.relu(self.fc1(x))
        x = F.log_softmax(self.fc3(x), dim=1)
        return x, trans_feat

4. Classification vs Segmentation

Tâche Output Architecture
Classification 1 classe par nuage Max Pooling global -> FC Layers
Segmentation 1 classe par point Concaténer Global Feat (1024) à chaque Point Feat (64)

Pour la segmentation, l'idée de génie est de réinjecter l'information globale (le contexte de la forme entière) au niveau local de chaque point. Ainsi, le point sait "je suis un point situé en bas" ET "je fais partie d'une chaise", donc "je suis probablement un pied".

🧠 Entraînez votre premier réseau 3D

Notre TP "Deep Learning 3D" vous guide pas à pas pour entraîner PointNet sur le dataset ModelNet40 et segmenter des parties d'avions.

Rejoindre l'Elite (120h)