From 05152f7bae63df80aa3ac3936b342dc4f9ecfd69 Mon Sep 17 00:00:00 2001 From: Jihoon Son Date: Fri, 9 Aug 2024 12:44:26 -0700 Subject: [PATCH] Safely close multiple resources in RapidsBufferCatalog (#11307) * Safely close multiple resources in RapidsBufferCatalog Signed-off-by: Jihoon Son * remove duplicate null filtering * add nullifying back --------- Signed-off-by: Jihoon Son --- .../com/nvidia/spark/rapids/implicits.scala | 3 +- .../spark/rapids/RapidsBufferCatalog.scala | 41 +++++++------------ 2 files changed, 15 insertions(+), 29 deletions(-) diff --git a/sql-plugin/src/main/scala-2.12/com/nvidia/spark/rapids/implicits.scala b/sql-plugin/src/main/scala-2.12/com/nvidia/spark/rapids/implicits.scala index e29058789b4..eddad69ba97 100644 --- a/sql-plugin/src/main/scala-2.12/com/nvidia/spark/rapids/implicits.scala +++ b/sql-plugin/src/main/scala-2.12/com/nvidia/spark/rapids/implicits.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,7 +16,6 @@ package com.nvidia.spark.rapids -import scala.collection import scala.collection.generic.CanBuildFrom import scala.collection.mutable import scala.reflect.ClassTag diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala index 5a4086865cf..f61291a31ce 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala @@ -19,6 +19,8 @@ package com.nvidia.spark.rapids import java.util.concurrent.ConcurrentHashMap import java.util.function.BiFunction +import scala.collection.JavaConverters.collectionAsScalaIterableConverter + import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer, Rmm, Table} import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.RapidsBufferCatalog.getExistingRapidsBufferAndAcquire @@ -727,9 +729,8 @@ class RapidsBufferCatalog( def numBuffers: Int = bufferMap.size() override def close(): Unit = { - bufferIdToHandles.values.forEach { handles => - handles.foreach(_.close()) - } + bufferIdToHandles.values.asScala.toSeq.flatMap(_.seq).safeClose() + bufferIdToHandles.clear() } } @@ -864,30 +865,16 @@ object RapidsBufferCatalog extends Logging { } private def closeImpl(): Unit = synchronized { - if (_singleton != null) { - _singleton.close() - _singleton = null - } - - if (memoryEventHandler != null) { - // Workaround for shutdown ordering problems where device buffers allocated with this handler - // are being freed after the handler is destroyed - //Rmm.clearEventHandler() - memoryEventHandler = null - } - - if (deviceStorage != null) { - deviceStorage.close() - deviceStorage = null - } - if (hostStorage != null) { - hostStorage.close() - hostStorage = null - } - if (diskStorage != null) { - diskStorage.close() - diskStorage = null - } + 
Seq(_singleton, deviceStorage, hostStorage, diskStorage).safeClose() + + _singleton = null + // Workaround for shutdown ordering problems where device buffers allocated + // with this handler are being freed after the handler is destroyed + //Rmm.clearEventHandler() + memoryEventHandler = null + deviceStorage = null + hostStorage = null + diskStorage = null } def getDeviceStorage: RapidsDeviceMemoryStore = deviceStorage