@@ -38,6 +38,13 @@ struct OrtArenaCfg {
3838 int max_dead_bytes_per_chunk; // use -1 to allow ORT to choose the default
3939 int initial_growth_chunk_size_bytes; // use -1 to allow ORT to choose the default
4040 int64_t max_power_of_two_extend_bytes; // use -1 to allow ORT to choose the default
41+ // Use CudaMemPool based arena if available (starting with cuda 11.2)
42+ int use_cuda_mempool = -1 ;
43+ // Amount of reserved memory in bytes to hold onto before trying
44+ // to release memory back to the OS.
45+ uint64_t cuda_mempool_release_threshold = 0 ;
46+ // Bytes to keep on shrink for CudaMemPool, 0 is to attempt to release all, allocated space not affected.
47+ size_t cuda_mempool_bytes_to_keep_on_shrink = 0 ;
4148
4249 bool IsValid () {
4350 return arena_extend_strategy >= -1 && arena_extend_strategy <= 1 &&
@@ -55,6 +62,9 @@ struct OrtArenaCfg {
5562 static constexpr const char * InitialGrowthChunkSizeBytes = " arena.initial_growth_chunk_size_bytes" ;
5663 static constexpr const char * MaxPowerOfTwoExtendBytes = " arena.max_power_of_two_extend_bytes" ;
5764 static constexpr const char * MaxMem = " arena.max_mem" ;
65+ static constexpr const char * UseCudaMemPool = " arena.use_cuda_mempool" ;
66+ static constexpr const char * CudaMempoolReleaseThreshold = " arena.cuda_mempool_release_threshold" ;
67+ static constexpr const char * CudaMempoolBytesToKeepOnShrink = " arena.cuda_mempool_bytes_to_keep_on_shrink" ;
5868 };
5969
6070 static onnxruntime::common::Status FromKeyValuePairs (const OrtKeyValuePairs& kvps, OrtArenaCfg& cfg);
@@ -348,4 +358,13 @@ void AllocatorDefaultFree(void* p);
348358void * AllocatorDefaultAllocAligned (size_t size, size_t alignment);
349359void AllocatorDefaultFreeAligned (void * p, size_t alignment);
350360
361+ class IArena : public IAllocator {
362+ public:
363+ using IAllocator::IAllocator;
364+ virtual Status Shrink () = 0;
365+ // Only implemented when IsStreamAware() returns true
366+ virtual void ReleaseStreamBuffers (Stream* /* stream*/ ) {}
367+ static IArena* SafeArenaCast (IAllocator* allocator);
368+ };
369+
351370} // namespace onnxruntime
0 commit comments