Skip to content

Commit

Permalink
Safely close multiple resources in RapidsBufferCatalog (#11307)
Browse files Browse the repository at this point in the history
* Safely close multiple resources in RapidsBufferCatalog

Signed-off-by: Jihoon Son <[email protected]>

* remove duplicate null filtering

* add nullafying back

---------

Signed-off-by: Jihoon Son <[email protected]>
  • Loading branch information
jihoonson authored Aug 9, 2024
1 parent de39a94 commit 05152f7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 29 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,7 +16,6 @@

package com.nvidia.spark.rapids

import scala.collection
import scala.collection.generic.CanBuildFrom
import scala.collection.mutable
import scala.reflect.ClassTag
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package com.nvidia.spark.rapids
import java.util.concurrent.ConcurrentHashMap
import java.util.function.BiFunction

import scala.collection.JavaConverters.collectionAsScalaIterableConverter

import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer, Rmm, Table}
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.RapidsBufferCatalog.getExistingRapidsBufferAndAcquire
Expand Down Expand Up @@ -727,9 +729,8 @@ class RapidsBufferCatalog(
def numBuffers: Int = bufferMap.size()

override def close(): Unit = {
bufferIdToHandles.values.forEach { handles =>
handles.foreach(_.close())
}
bufferIdToHandles.values.asScala.toSeq.flatMap(_.seq).safeClose()

bufferIdToHandles.clear()
}
}
Expand Down Expand Up @@ -864,30 +865,16 @@ object RapidsBufferCatalog extends Logging {
}

private def closeImpl(): Unit = synchronized {
if (_singleton != null) {
_singleton.close()
_singleton = null
}

if (memoryEventHandler != null) {
// Workaround for shutdown ordering problems where device buffers allocated with this handler
// are being freed after the handler is destroyed
//Rmm.clearEventHandler()
memoryEventHandler = null
}

if (deviceStorage != null) {
deviceStorage.close()
deviceStorage = null
}
if (hostStorage != null) {
hostStorage.close()
hostStorage = null
}
if (diskStorage != null) {
diskStorage.close()
diskStorage = null
}
Seq(_singleton, deviceStorage, hostStorage, diskStorage).safeClose()

_singleton = null
// Workaround for shutdown ordering problems where device buffers allocated
// with this handler are being freed after the handler is destroyed
//Rmm.clearEventHandler()
memoryEventHandler = null
deviceStorage = null
hostStorage = null
diskStorage = null
}

def getDeviceStorage: RapidsDeviceMemoryStore = deviceStorage
Expand Down

0 comments on commit 05152f7

Please sign in to comment.