If you have ever tried to train a Graph Neural Network (GNN) on a massive dataset, you already know the pain of the "Memory Wall."
Loading a datas...
Brilliant approach! Leveraging mmap to let the OS handle the paging is a huge brain move. I do have a quick technical question: Since you're relying on page faults from the NVMe SSD, how does this affect the training speed (I/O bottleneck) compared to a system that actually has enough RAM to fit the whole dataset? Is the OpenMP multi-threading enough to completely hide that I/O latency?
Thanks! You hit the fundamental trade-off. No, OpenMP doesn't completely hide I/O latency—DDR5 RAM will always beat an NVMe SSD.
However, two things mitigate this: the OS Page Cache stores "hot" nodes in free RAM after epoch 1, and OpenMP saturates the NVMe's IOPS queue with parallel requests.
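A minimal sketch of the reply's second point, assuming node features live in one flat, row-major float32 file on the SSD. Python's `ThreadPoolExecutor` stands in for OpenMP here. Explicit `os.pread` calls stand in for mmap page faults so the concurrency is easy to see: `os.pread` releases the GIL while it waits on I/O, so several reads are in flight at once. `FEATURE_DIM`, `write_features`, and `gather_rows` are hypothetical names, not the project's API.

```python
import os
import struct
import tempfile
from concurrent.futures import ThreadPoolExecutor

FEATURE_DIM = 4               # hypothetical: float32 values per node
ROW_BYTES = FEATURE_DIM * 4   # bytes per node row

def write_features(path, num_nodes):
    """Write a toy feature file where node i holds [i, i+1, ..., i+FEATURE_DIM-1]."""
    with open(path, "wb") as f:
        for i in range(num_nodes):
            f.write(struct.pack(f"<{FEATURE_DIM}f", *(float(i + k) for k in range(FEATURE_DIM))))

def read_row(fd, node_id):
    # os.pread takes an explicit offset and does not move a shared file position,
    # so many threads can read from the same descriptor at once (POSIX only).
    raw = os.pread(fd, ROW_BYTES, node_id * ROW_BYTES)
    return struct.unpack(f"<{FEATURE_DIM}f", raw)

def gather_rows(path, node_ids, workers=8):
    """Fetch feature rows with several reads in flight, returned in input order."""
    fd = os.open(path, os.O_RDONLY)
    try:
        with ThreadPoolExecutor(max_workers=workers) as pool:
            # Reads overlap because os.pread releases the GIL while blocked on I/O;
            # a repeated pass over the same nodes is typically served by the page cache.
            return list(pool.map(lambda n: read_row(fd, n), node_ids))
    finally:
        os.close(fd)

if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        path = os.path.join(d, "features.bin")
        write_features(path, 1000)
        print(gather_rows(path, [7, 3, 999])[0])  # (7.0, 8.0, 9.0, 10.0)
```

Keeping many requests outstanding is what lets an NVMe drive approach its IOPS limit. It does not make a single read faster, which matches the point that parallelism mitigates SSD latency without fully hiding it.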
Ultimately, it’s a choice between training at 70% speed versus an instant PyTorch OOM crash!
That makes perfect sense! Relying on the OS Page Cache for the 'hot' nodes after epoch 1 is a very elegant fallback. And you're absolutely right—70% training speed is infinitely faster than a 0% OOM crashed run 😂. Thanks for the detailed breakdown, brilliant engineering!
Really cool approach. The mmap strategy resonates with me — I run Llama 3 locally to generate content for a 100k+ page multilingual site, and memory management is a constant battle at that scale. Had to get creative with batch sizing and streaming to avoid OOM on a machine with 32GB RAM.
The 70% speed vs. OOM crash trade-off is the right framing. In my case I found a similar pattern: processing pages in streaming batches at ~60% throughput beats trying to load everything into memory and crashing halfway through a 10-hour generation run.
Curious — have you considered adding a prefetching strategy that pre-warms the page cache based on the graph's access patterns? If your neighbor sampler knows which nodes are likely to be accessed next, you could issue madvise(MADV_WILLNEED) hints ahead of time. Might close some of that 30% gap.
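A minimal sketch of the suggested prefetching, assuming fixed-size feature rows of `ROW_BYTES` per node in one memory-mapped file and a sampler that can hand over the next batch's node IDs ahead of time. `prefetch_nodes` and `ROW_BYTES` are hypothetical. `mmap.madvise` and `mmap.MADV_WILLNEED` are the Python 3.8+ bindings for the `madvise(2)` hint and exist only on platforms that support it.

```python
import mmap
import tempfile

ROW_BYTES = 512  # hypothetical: bytes of features stored per node

def prefetch_nodes(mm, node_ids, row_bytes=ROW_BYTES):
    """Ask the kernel to start paging in the rows for `node_ids`.

    Returns the (start, length) ranges that were advised. MADV_WILLNEED is only
    a hint: the call returns immediately and the kernel may ignore it.
    """
    page = mmap.PAGESIZE
    merged = []
    for n in sorted(set(node_ids)):
        start = (n * row_bytes // page) * page          # round down to a page boundary
        end = -(-(n + 1) * row_bytes // page) * page    # round up to a page boundary
        if merged and start <= merged[-1][1]:
            merged[-1][1] = max(merged[-1][1], end)     # coalesce touching ranges
        else:
            merged.append([start, end])
    advised = []
    for start, end in merged:
        length = min(end, len(mm)) - start              # never advise past the mapping
        if hasattr(mmap, "MADV_WILLNEED"):              # platform-dependent constant
            mm.madvise(mmap.MADV_WILLNEED, start, length)
        advised.append((start, length))
    return advised

if __name__ == "__main__":
    with tempfile.TemporaryFile() as f:
        f.truncate(100 * ROW_BYTES)                     # toy file: 100 zeroed node rows
        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
            # e.g. the nodes the sampler expects to touch in the *next* batch
            print(prefetch_nodes(mm, [3, 4, 90]))
```

The intended pattern would be to call this for batch k+1 while batch k is still computing, so the page faults overlap with useful work. Coalescing adjacent rows keeps the number of `madvise` calls down when sampled neighbors sit close together in the file.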
Thanks for the prefetching idea, I may look into this.
This is clever.
It moves beyond the usual band-aids of batch size reduction and gradient accumulation.
The zero-copy approach makes a lot of sense.
Data movement between CPU and GPU is such an overlooked bottleneck.
Everyone stares at compute but ignores the overhead of shuffling tensors around.
Questions that came to mind:
How does this compare to PyTorch's own memory optimization features, like checkpointing or max_split_size_mb?
Is this generalizable across different architectures, or does it work best for specific patterns like transformers?
It would be interesting to see benchmarks across different GPU hardware: A100 vs H100 vs consumer cards.
I'm also curious about the autograd trade-offs, since maintaining gradients with zero-copy is tricky territory.
Have you thought about open-sourcing it?
I would love to test this on some real workloads.
The approach reminds me of TensorRT, but keeping it inside the PyTorch ecosystem is a nice sweet spot.
Nice engineering!
Those are some hard questions to answer, so you may want to look into them yourself. It is open source (github link) and also available on PyPI.
This is a brilliant solution to a problem every GNN practitioner has faced. The mmap + zero-copy approach is elegant: it lets the OS handle page faults instead of fighting against RAM limits. The fact that you're handing raw C++ pointers directly to PyTorch via nanobind is exactly the kind of systems-level thinking that makes deep learning actually practical at scale. Impressive work!
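The zero-copy idea can be illustrated with only the Python standard library. This is not the project's nanobind code. It assumes a flat float32 feature file, and `node_view` and `FEATURE_DIM` are hypothetical. A `memoryview` over an mmap slice aliases the mapped pages instead of copying them, so a write through the mapping is immediately visible through the view.

```python
import mmap
import struct
import tempfile

FEATURE_DIM = 4  # hypothetical: float32 values per node row

def node_view(mm, node_id, dim=FEATURE_DIM):
    """Return a float32 view of one node's row that shares memory with `mm`."""
    row_bytes = dim * 4
    start = node_id * row_bytes
    # Slicing a memoryview yields another view onto the same buffer, and
    # cast("f") reinterprets those bytes as native-endian float32; nothing is copied.
    return memoryview(mm)[start:start + row_bytes].cast("f")

if __name__ == "__main__":
    with tempfile.TemporaryFile() as f:
        f.write(struct.pack("=8f", *(float(i) for i in range(8))))  # 2 nodes x 4 floats
        f.flush()
        mm = mmap.mmap(f.fileno(), 0)                 # read/write shared mapping
        row = node_view(mm, 1)
        print(row.tolist())                           # [4.0, 5.0, 6.0, 7.0]
        mm[16:20] = struct.pack("=f", 42.0)           # write through the mapping...
        print(row[0])                                 # ...and the view sees 42.0
        row.release()                                 # live views pin the mapping,
        mm.close()                                    # so release them before closing
```

A live view pins the mapping: `mm.close()` raises `BufferError` until every view is released. That is the Python-level analogue of keeping a mapping alive while a tensor still points into it. PyTorch's `torch.frombuffer` can wrap a buffer like this without copying, though whether the project takes that route is not stated here.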
This is a fascinating deep dive into low-level optimization! The zero-copy C++ graph engine approach is clever: bypassing PyTorch's overhead while maintaining the computational graph is no small feat.
Key takeaways that impressed me:
- Moving beyond just batch size reduction or gradient accumulation (the usual OOM band-aids)
- The zero-copy philosophy: minimizing data movement between CPU and GPU is often overlooked but critical for performance
- Building a custom C++ extension that maintains graph connectivity while reducing memory footprint
- How does this compare to PyTorch's own memory optimization features like checkpointing or max_split_size_mb?
- Is this generalizable across different model architectures, or does it work best for specific patterns (like transformers vs CNNs)?
- Did you have to make any trade-offs with autograd functionality? Maintaining gradients with zero-copy can be tricky.
- It would be interesting to see benchmarks across different GPU architectures (A100 vs H100 vs consumer cards).
- Integration with PyTorch's memory profiling tools would make this more accessible.
- A fallback mechanism for operations where zero-copy isn't optimal would also help.
This kind of systems-level optimization work is exactly what the ML engineering community needs. Have you considered open-sourcing it? Would love to test it on some production workloads!
The approach reminds me of how TensorRT optimizes graphs, but keeping it within PyTorch's ecosystem is a nice middle ground. Great engineering!
The GOAT.