Skip to content

WIP: Using DGL instead of Torch Geometric #48

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 33 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
ed49f41
Update local.yaml
gennarinoos Feb 3, 2025
9c7505e
WIP on using DGL
gennarinoos Feb 4, 2025
dcf9318
WIP: missing file
gennarinoos Feb 4, 2025
b2b3d28
fix: batching while conserving mesh face information
gennarinoos Feb 4, 2025
43ee8d2
fix: start fixing unit tests
gennarinoos Feb 5, 2025
0b8ffe4
fix: temporarily disable surface and triangle coll loss
gennarinoos Feb 5, 2025
64dc850
fix: fix point sampler tests
gennarinoos Feb 5, 2025
bc42714
fix: surface distance loss, reverting some optimizations
gennarinoos Feb 6, 2025
23e5df1
refactor: do not branch out from original loss functions
gennarinoos Feb 6, 2025
b766de0
fix: renaming
gennarinoos Feb 6, 2025
618921c
fix: renaming
gennarinoos Feb 6, 2025
2a4e19a
refactor: reverting some minor changes
gennarinoos Feb 6, 2025
61b542b
perf: optimize surface distance loss calculation
gennarinoos Feb 6, 2025
67a3993
chore: install right torch and dgl in CI
gennarinoos Feb 6, 2025
c399dad
fix: errors in pytest
gennarinoos Feb 6, 2025
a53a54f
fix: cpu and gpu compatible deps
gennarinoos Feb 6, 2025
98a5809
debug CI
gennarinoos Feb 6, 2025
a51b70d
debug CI 2
gennarinoos Feb 6, 2025
3ca6b66
fix: pin dgl version
gennarinoos Feb 7, 2025
6be72e9
fix: move graph to GPU
gennarinoos Feb 7, 2025
5220770
fix: update notebook for dgl for CUDA
gennarinoos Feb 7, 2025
1da4624
fix: subgraph created on original graph's device
gennarinoos Feb 7, 2025
23a6c7c
chore: more logging for remote debugging
gennarinoos Feb 7, 2025
fcd057b
debug: force simplified graph to be on GPU
gennarinoos Feb 7, 2025
401c272
debug: force triangle graph to be on GPU
gennarinoos Feb 7, 2025
49e80eb
test hyperparameters 1
gennarinoos Feb 7, 2025
86dc8bf
chore: more logging
gennarinoos Feb 7, 2025
27e1a48
perf: reduce number of neighbors
gennarinoos Feb 7, 2025
944582c
fix: use eval dataset for validation
gennarinoos Feb 7, 2025
fc75a0a
fix: evaluation script
gennarinoos Feb 7, 2025
d6a778e
fix: serial loading of data
gennarinoos Feb 7, 2025
22e64eb
chore: parameterize k value in edge crossings loss test
martinnormark Feb 7, 2025
792c195
avoid the triangle itself ranking as the closest
gennarinoos Feb 12, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cpu
pip install torch_cluster==1.6.3 torch_geometric==2.5.3 torch_scatter==2.1.2 torch_sparse==0.6.18 -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cpu
pip install dgl==2.2.0 -f https://data.dgl.ai/wheels/torch-2.3/repo.html
pip install torchdata==0.9.0
pip install -r requirements.txt
pip freeze
ls -la /opt/hostedtoolcache/Python/3.12.8/x64/lib/python3.12/site-packages/dgl/graphbolt/
- name: Install this package
run: |
pip install -e .
Expand Down
Binary file not shown.
9 changes: 6 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,19 @@ conda install pip
```

Depending on whether you are using PyTorch on a CPU or a GPU,
you'll have to use the correct binaries for PyTorch and the PyTorch Geometric libraries. You can install them via:
you'll have to use the correct binaries for PyTorch and DGL. You can install them via:

```bash
pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cpu
pip install torch_cluster==1.6.3 torch_geometric==2.5.3 torch_scatter==2.1.2 torch_sparse==0.6.18 -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cpu
pip install dgl==2.2.0 -f https://data.dgl.ai/wheels/torch-2.3/repo.html
```

Replace “cpu” with “cu121” or the appropriate CUDA version for your system. If you don't know what is your cuda version,
run `nvidia-smi`

NOTE: When updating version of PyTorch or DGL, please check https://www.dgl.ai/pages/start.html to determine
the versions compatible with CUDA.

After that you can install the remaining requirements

```bash
Expand Down
7 changes: 4 additions & 3 deletions configs/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ model:
hidden_dim: 128 # feature dimension for point sampler and face classifier (as per paper)
edge_hidden_dim: 64 # feature dimension for edge predictor (as per paper)
num_layers: 3 # number of convolutional layers (as per paper)
k: 15 # number of neighbors for graph construction (as per paper)
edge_k: 15 # number of neighbors for edge features (as per paper)
target_ratio: 0.5 # mesh simplification ratio
k: 5 # number of neighbors for graph construction (as per paper)
edge_k: 5 # number of neighbors for edge features (as per paper)
target_ratio: 0.4 # mesh simplification ratio

# Training parameters
training:
Expand All @@ -16,6 +16,7 @@ training:
num_epochs: 20 # total training epochs
early_stopping_patience: 15 # epochs before early stopping
checkpoint_dir: data/checkpoints # model save directory
num_workers: 0

# Data parameters
data:
Expand Down
10 changes: 5 additions & 5 deletions configs/local.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Model parameters
model:
input_dim: 3
hidden_dim: 64 # feature dimension for point sampler and face classifier (as per paper)
edge_hidden_dim: 128 # feature dimension for edge predictor (as per paper)
hidden_dim: 128 # feature dimension for point sampler and face classifier (as per paper)
edge_hidden_dim: 64 # feature dimension for edge predictor (as per paper)
num_layers: 3 # number of convolutional layers (as per paper)
k: 15 # number of neighbors for graph construction (as per paper)
k: 5 # number of neighbors for graph construction (as per paper)
edge_k: 15 # number of neighbors for edge features (as per paper)
target_ratio: 0.5 # mesh simplification ratio

Expand All @@ -13,10 +13,10 @@ training:
learning_rate: 1.0e-5
weight_decay: 0.99 # weight decay per epoch (as per paper)
batch_size: 2
accumulation_steps: 4
num_epochs: 20 # total training epochs
early_stopping_patience: 10 # epochs before early stopping
early_stopping_patience: 5 # epochs before early stopping
checkpoint_dir: data/checkpoints # model save directory
num_workers: 0

# Data parameters
data:
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pluggy==1.5.0
psutil==6.0.0
pyarrow==17.0.0
pyarrow-hotfix==0.6
pydantic==2.10.6
pyparsing==3.1.2
pytest==8.3.2
python-dateutil==2.9.0.post0
Expand All @@ -39,6 +40,7 @@ setuptools==72.1.0
six==1.16.0
sympy==1.13.1
threadpoolctl==3.5.0
torchdata==0.9.0
tqdm==4.66.5
trimesh==4.4.4
typing_extensions==4.12.2
Expand Down
2 changes: 1 addition & 1 deletion scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def load_config(config_path):
def main():
args = parse_args()
config = load_config(args.config)
config["data"]["eval_data_path"] = args.eval_data_path
config["data"]["data_dir"] = args.eval_data_path

trainer = Trainer(config)
trainer.load_checkpoint(args.checkpoint)
Expand Down
29 changes: 12 additions & 17 deletions src/neural_mesh_simplification/api/neural_mesh_simplifier.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
import torch
import trimesh
from torch_geometric.data import Data

from ..data.dataset import preprocess_mesh, mesh_to_tensor
from ..data.dataset import preprocess_mesh, mesh_to_dgl, dgl_to_trimesh
from ..models import NeuralMeshSimplification


class NeuralMeshSimplifier:
def __init__(
self,
input_dim,
hidden_dim,
edge_hidden_dim, # Separate hidden dim for edge predictor
num_layers,
k,
edge_k,
target_ratio
self,
input_dim,
hidden_dim,
edge_hidden_dim, # Separate hidden dim for edge predictor
num_layers,
k,
edge_k,
target_ratio
):
self.input_dim = input_dim
self.hidden_dim = hidden_dim
Expand Down Expand Up @@ -59,11 +58,7 @@ def simplify(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh:
preprocessed_mesh: trimesh.Trimesh = preprocess_mesh(mesh)

# Convert to a tensor
tensor: Data = mesh_to_tensor(preprocessed_mesh)
model_output = self.model(tensor)
graph, _ = mesh_to_dgl(preprocessed_mesh)
s_graph, s_faces, _ = self.model(graph)

vertices = model_output["sampled_vertices"].detach().numpy()
faces = model_output["simplified_faces"].numpy()
edges = model_output["edge_index"].t().numpy() # Transpose to get (n, 2) shape

return trimesh.Trimesh(vertices=vertices, faces=faces, edges=edges)
return dgl_to_trimesh(s_graph, s_faces)
86 changes: 58 additions & 28 deletions src/neural_mesh_simplification/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
import gc
import logging
import os
from typing import Optional

import dgl
import numpy as np
import torch
import trimesh
from torch.utils.data import Dataset
from torch_geometric.data import Data
from dgl.data import DGLDataset
from trimesh import Geometry, Trimesh

from ..utils import build_graph_from_mesh
logger = logging.getLogger(__name__)


class MeshSimplificationDataset(Dataset):
def __init__(self, data_dir: str, preprocess: bool = False, transform: Optional[callable] = None):
class MeshSimplificationDataset(DGLDataset):
def __init__(self, data_dir, preprocess: bool = False, transform: Optional[callable] = None):
super().__init__(name='mesh_simplification')
self.data_dir = data_dir
self.preprocess = preprocess
self.transform = transform
Expand All @@ -29,7 +30,7 @@ def _get_file_list(self):
def __len__(self):
return len(self.file_list)

def __getitem__(self, idx):
def __getitem__(self, idx) -> tuple[dgl.DGLGraph, torch.Tensor]:
file_path = os.path.join(self.data_dir, self.file_list[idx])
mesh = load_mesh(file_path)

Expand All @@ -39,9 +40,7 @@ def __getitem__(self, idx):
if self.transform:
mesh = self.transform(mesh)

data = mesh_to_tensor(mesh)
gc.collect()
return data
return mesh_to_dgl(mesh)


def load_mesh(file_path: str) -> Geometry | list[Geometry] | None:
Expand Down Expand Up @@ -81,28 +80,59 @@ def augment_mesh(mesh: trimesh.Trimesh) -> Trimesh | None:
return mesh


def mesh_to_tensor(mesh: trimesh.Trimesh) -> Data:
"""Convert a mesh to tensor representation including graph structure."""
def mesh_to_dgl(mesh) -> tuple[dgl.DGLGraph, torch.Tensor]:
if mesh is None:
return None
raise ValueError("Mesh is undefined")

# Convert vertices to tensor
vertices = torch.tensor(mesh.vertices, dtype=torch.float32)
num_nodes = vertices.shape[0]

# Convert unique edges
edges_np = np.array(list(mesh.edges_unique))
edges = torch.tensor(edges_np, dtype=torch.long).t()

# Create DGL graph
g = dgl.graph((edges[0], edges[1]), num_nodes=num_nodes)
g = dgl.add_self_loop(g)

# Convert vertices and faces to tensors
vertices_tensor = torch.tensor(mesh.vertices, dtype=torch.float32)
faces_tensor = torch.tensor(mesh.faces, dtype=torch.long).t()
# Verify node count matches
assert g.number_of_nodes() == vertices.shape[0], "Mismatch between nodes and features"

# Build graph structure
G = build_graph_from_mesh(mesh)
# Add node features
g.ndata['x'] = vertices
g.ndata['pos'] = vertices

# Create edge index tensor
edge_index = torch.tensor(list(G.edges), dtype=torch.long).t().contiguous()
# Store face information as node data
if hasattr(mesh, 'faces'):
faces_tensor = torch.tensor(mesh.faces, dtype=torch.long)
else:
faces_tensor = torch.empty((0, 3), dtype=torch.long)

# Create Data object
data = Data(
x=vertices_tensor,
pos=vertices_tensor,
edge_index=edge_index,
face=faces_tensor,
num_nodes=len(mesh.vertices),
return g, faces_tensor


def dgl_to_trimesh(g: dgl.DGLGraph, faces: torch.Tensor | None) -> Trimesh:
# Convert to a tensor
vertices = g.ndata['pos'].numpy()
vertex_normals = g.ndata.get('normal', None)
if vertex_normals is not None:
vertex_normals = vertex_normals.numpy()

return trimesh.Trimesh(
vertices=vertices,
faces=faces,
vertex_normals=vertex_normals,
process=True,
validate=True
)

return data

def collate(batch: list[tuple]) -> tuple[dgl.DGLGraph, torch.Tensor]:
graphs, faces = zip(*batch)
max_faces = max(f.shape[0] for f in faces)
padded_faces = torch.stack([
torch.nn.functional.pad(f, (0, 0, 0, max_faces - f.shape[0]), value=-1)
for f in faces
])
return graphs, padded_faces
7 changes: 1 addition & 6 deletions src/neural_mesh_simplification/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1 @@
from .chamfer_distance_loss import ProbabilisticChamferDistanceLoss
from .surface_distance_loss import ProbabilisticSurfaceDistanceLoss
from .triangle_collision_loss import TriangleCollisionLoss
from .edge_crossing_loss import EdgeCrossingLoss
from .overlapping_triangles_loss import OverlappingTrianglesLoss
from .combined_loss import CombinedMeshSimplificationLoss
from neural_mesh_simplification.losses.combined_loss import CombinedMeshSimplificationLoss
21 changes: 19 additions & 2 deletions src/neural_mesh_simplification/losses/chamfer_distance_loss.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
import logging

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


class ProbabilisticChamferDistanceLoss(nn.Module):
def __init__(self):
super(ProbabilisticChamferDistanceLoss, self).__init__()

def forward(self, P, Ps, probabilities):
def forward(
self,
P: torch.Tensor,
Ps: torch.Tensor,
probabilities: torch.Tensor
) -> torch.Tensor:
"""
Compute the Probabilistic Chamfer Distance loss.

Expand All @@ -18,6 +27,9 @@ def forward(self, P, Ps, probabilities):
Returns:
torch.Tensor: Scalar loss value
"""

logger.debug(f"Calculating CHAMFER loss on device {P.device} {Ps.device} {probabilities.device}")

if P.size(0) == 0 or Ps.size(0) == 0:
return torch.tensor(0.0, device=P.device, requires_grad=True)

Expand All @@ -42,7 +54,12 @@ def forward(self, P, Ps, probabilities):

return loss

def compute_minimum_distances(self, source, target, return_indices=False):
def compute_minimum_distances(
self,
source: torch.Tensor,
target: torch.Tensor,
return_indices: bool = False
):
"""
Compute the minimum distances from each point in source to target.

Expand Down
Loading
Loading