Skip to content

Commit

Permalink
Fixing deadlock by moving the removal listener call out of the LRU lock
Browse files Browse the repository at this point in the history
Signed-off-by: Sagar Upadhyaya <[email protected]>
  • Loading branch information
sgup432 committed Jun 5, 2024
1 parent 3b75273 commit 386287d
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,18 @@ class LockWrapper {
ReleasableLock writeLock = new ReleasableLock(readWriteLock.writeLock());

AtomicInteger refCount = new AtomicInteger();

String tierName = "";

/**
 * Creates a lock wrapper tagged with a tier name.
 *
 * @param tierName value stored in this wrapper's {@code tierName} field
 */
LockWrapper(String tierName) {
this.tierName = tierName;
}

// No-arg constructor: leaves tierName at its field initializer (the empty string).
LockWrapper() {}

/**
 * Overwrites the tier name recorded on this wrapper.
 *
 * @param tierName new value for the {@code tierName} field
 */
void setTierName(String tierName) {
this.tierName = tierName;
}
}

// Package private for testing
Expand All @@ -168,7 +180,8 @@ ICache<K, V> getDiskCache() {
return diskCache;
}

private void writeLock(ICacheKey<K> key) {

private void writeLock(ICacheKey<K> key, String tierName) {
if (threadLocal.get() != null) {
LockWrapper lockWrapper = threadLocal.get();
if (lockWrapper != null && lockWrapper.writeLock.isHeldByCurrentThread()) {
Expand All @@ -178,26 +191,39 @@ private void writeLock(ICacheKey<K> key) {
}
}
}

LockWrapper lockWrapper = locks.computeIfAbsent(key, key1 -> {
return new LockWrapper();
});
// if (lockWrapper.tierName.equals("")) {
// lockWrapper.setTierName(tierName);
// }
// if (!tierName.equals(lockWrapper.tierName)) {
// return;
// }
lockWrapper.refCount.incrementAndGet();
threadLocal.set(lockWrapper);
lockWrapper.writeLock.acquire();
}

private void readLock(ICacheKey<K> key) {
/**
 * Acquires the read lock associated with {@code key}.
 *
 * Looks up the key's {@code LockWrapper} in {@code locks} (creating one if absent), increments
 * its reference count, stores the wrapper in {@code threadLocal}, and then acquires its read lock.
 *
 * @param key cache key whose lock is taken
 * @param tierName currently unused; the only code that would read it is commented out below
 */
private void readLock(ICacheKey<K> key, String tierName) {
LockWrapper lockWrapper = locks.computeIfAbsent(key, key1 -> {
return new LockWrapper();
});
// Disabled tier check, left commented out by this commit:
// if (!tierName.equals(lockWrapper.tierName)) {
// return;
// }
// The reference count is incremented before the lock is acquired; unlockReadLock decrements it
// and removes the map entry once it reaches zero.
lockWrapper.refCount.incrementAndGet();
threadLocal.set(lockWrapper);
lockWrapper.readLock.acquire();
}

private void unlockWriteLock(ICacheKey<K> key) {
private void unlockWriteLock(ICacheKey<K> key, String tierName) {
LockWrapper lockWrapper = locks.get(key);
if (lockWrapper != null && lockWrapper.writeLock.isHeldByCurrentThread()) {
// if (!tierName.equals(lockWrapper.tierName)) {
// return;
// }
lockWrapper.writeLock.close();
if (lockWrapper.refCount.decrementAndGet() == 0) {
locks.remove(key);
Expand All @@ -206,9 +232,12 @@ private void unlockWriteLock(ICacheKey<K> key) {
threadLocal.remove();
}

private void unlockReadLock(ICacheKey<K> key) {
private void unlockReadLock(ICacheKey<K> key, String tierName) {
LockWrapper lockWrapper = locks.get(key);
if (lockWrapper != null && lockWrapper.readLock.isHeldByCurrentThread()) {
// if (!tierName.equals(lockWrapper.tierName)) {
// return;
// }
lockWrapper.readLock.close();
if (lockWrapper.refCount.decrementAndGet() == 0) {
locks.remove(key);
Expand Down Expand Up @@ -237,11 +266,11 @@ public V get(ICacheKey<K> key) {
/**
 * Puts {@code value} into the on-heap cache for {@code key} between a write-lock acquire and
 * release for that key, then calls {@code updateStatsOnPut} for the on-heap tier.
 *
 * NOTE(review): this text comes from a rendered diff, so each lock/unlock call appears twice;
 * the one-argument form looks like the pre-change line and the two-argument form its replacement.
 *
 * @param key cache key to store under
 * @param value value to store
 */
@Override
public void put(ICacheKey<K> key, V value) {
try {
writeLock(key);
writeLock(key, TIER_DIMENSION_VALUE_ON_HEAP);
onHeapCache.put(key, value);
updateStatsOnPut(TIER_DIMENSION_VALUE_ON_HEAP, key, value);
} finally {
// Runs even if the put throws. unlockWriteLock only releases the lock when the current
// thread holds the write lock, so it is safe to call if acquisition itself failed.
unlockWriteLock(key);
unlockWriteLock(key, TIER_DIMENSION_VALUE_ON_HEAP);
}
}

Expand All @@ -261,10 +290,10 @@ public V computeIfAbsent(ICacheKey<K> key, LoadAwareCacheLoader<ICacheKey<K>, V>
// the value once.
V value = null;
try {
writeLock(key);
writeLock(key, TIER_DIMENSION_VALUE_ON_HEAP);
value = onHeapCache.computeIfAbsent(key, loader);
} finally {
unlockWriteLock(key);
unlockWriteLock(key, TIER_DIMENSION_VALUE_ON_HEAP);
}
// Handle stats
if (loader.isLoaded()) {
Expand Down Expand Up @@ -307,10 +336,10 @@ public void invalidate(ICacheKey<K> key) {
}
if (key.key != null) {
try {
writeLock(key);
writeLock(key, TIER_DIMENSION_VALUE_ON_HEAP);
cacheEntry.getKey().invalidate(key);
} finally {
unlockWriteLock(key);
unlockWriteLock(key, TIER_DIMENSION_VALUE_ON_HEAP);
}
}
}
Expand Down Expand Up @@ -378,7 +407,7 @@ public ImmutableCacheStatsHolder stats(String[] levels) {
private Function<ICacheKey<K>, Tuple<V, String>> getValueFromTieredCache(boolean captureStats) {
return key -> {
try {
readLock(key);
readLock(key, TIER_DIMENSION_VALUE_ON_HEAP);
for (Map.Entry<ICache<K, V>, TierInfo> cacheEntry : caches.entrySet()) {
if (cacheEntry.getValue().isEnabled()) {
V value = cacheEntry.getKey().get(key);
Expand All @@ -397,7 +426,7 @@ private Function<ICacheKey<K>, Tuple<V, String>> getValueFromTieredCache(boolean
}
return null;
} finally {
unlockReadLock(key);
unlockReadLock(key, TIER_DIMENSION_VALUE_ON_HEAP);
}
};
}
Expand All @@ -407,10 +436,10 @@ void handleRemovalFromHeapTier(RemovalNotification<ICacheKey<K>, V> notification
boolean wasEvicted = SPILLOVER_REMOVAL_REASONS.contains(notification.getRemovalReason());
if (caches.get(diskCache).isEnabled() && wasEvicted && evaluatePolicies(notification.getValue())) {
try {
writeLock(key);
writeLock(key, TIER_DIMENSION_VALUE_DISK);
diskCache.put(key, notification.getValue()); // spill over to the disk tier and increment its stats
} finally {
unlockWriteLock(key);
unlockWriteLock(key, TIER_DIMENSION_VALUE_DISK);
}
updateStatsOnPut(TIER_DIMENSION_VALUE_DISK, key, notification.getValue());
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1350,21 +1350,21 @@ public void testNumLocksTiming() throws Exception {
Phaser phaser = new Phaser(numThreads + 1);
CountDownLatch countDownLatch = new CountDownLatch(numThreads);

new Thread(() -> {
try {
Thread.sleep(10000);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
System.out.println("Taking thread dump");
for (Map.Entry<Thread, StackTraceElement[]> entry : Thread.getAllStackTraces().entrySet()) {
System.out.println(entry.getKey() + " " + entry.getKey().getState());
for (StackTraceElement ste : entry.getValue()) {
System.out.println("\tat " + ste);
}
System.out.println();
}
}).start();
// new Thread(() -> {
// try {
// Thread.sleep(10000);
// } catch (InterruptedException e) {
// throw new RuntimeException(e);
// }
// System.out.println("Taking thread dump");
// for (Map.Entry<Thread, StackTraceElement[]> entry : Thread.getAllStackTraces().entrySet()) {
// System.out.println(entry.getKey() + " " + entry.getKey().getState());
// for (StackTraceElement ste : entry.getValue()) {
// System.out.println("\tat " + ste);
// }
// System.out.println();
// }
// }).start();
// Precompute the keys each thread will request so we don't include that in the time estimate
List<List<ICacheKey<String>>> keysPerThread = new ArrayList<>();

Expand All @@ -1381,8 +1381,7 @@ public void testNumLocksTiming() throws Exception {
for (int j = 0; j < numRequests; j++) {
tieredSpilloverCache.computeIfAbsent(keysPerThread.get(finalI).get(j), getLoadAwareCacheLoader());
if (j % 100 == 0) {
//System.out.println("Finished iter " + j + " elapsed time = " + (System
// .currentTimeMillis() - startTime));
System.out.println("Finished iter " + j + " elapsed time = " + (System.currentTimeMillis() - startTime));
}
}
} catch (Exception e) {
Expand Down Expand Up @@ -1481,7 +1480,7 @@ public String load(ICacheKey<String> key) throws Exception {
}
};

tieredSpilloverCache.setMockListener(mockRemovalListener);
//tieredSpilloverCache.setMockListener(mockRemovalListener);

tieredSpilloverCache.computeIfAbsent(cacheKey1, new LoadAwareCacheLoader<>() {
boolean isLoaded = false;
Expand Down
16 changes: 15 additions & 1 deletion server/src/main/java/org/opensearch/common/cache/Cache.java
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,8 @@ public long getEvictions() {

/**
 * Handles promotion of {@code entry} under {@code lruLock} and, when the entry was promoted,
 * prunes while the tail is prunable. The removal listener is invoked only after the
 * try-with-resources block has released {@code lruLock}, which is the point of this commit.
 *
 * Part of the state switch is elided by the diff hunk marker below, so how {@code promoted}
 * gets cleared is not visible here.
 *
 * @param entry the entry being promoted
 * @param now timestamp passed to {@code shouldPrune}
 * @return the {@code promoted} flag
 */
private boolean promote(Entry<K, V> entry, long now) {
boolean promoted = true;
boolean removed = false;
RemovalNotification<K, V> removalNotification = null;
try (ReleasableLock ignored = lruLock.acquire()) {
switch (entry.state) {
case DELETED:
Expand All @@ -782,9 +784,21 @@ private boolean promote(Entry<K, V> entry, long now) {
break;
}
if (promoted) {
// NOTE(review): `evict(now);` appears to be the pre-change line that the loop below replaces
// (the diff rendering shows removed and added lines together).
evict(now);
// NOTE(review): the loop condition inspects `tail`, but the body removes and unlinks `entry`
// (the entry that was just promoted), never `tail`. Unless unlinking `entry` changes `tail`
// or the result of shouldPrune, the condition cannot change and this may spin while holding
// lruLock. Confirm whether `tail` was intended throughout the body.
while (tail != null && shouldPrune(tail, now)) {
CacheSegment<K, V> segment = getCacheSegment(entry.key);
if (segment != null) {
segment.remove(entry.key, entry.value, f -> {});
}
if (unlink(entry)) {
removed = true;
// NOTE(review): removalNotification is assigned here but never read afterwards.
removalNotification = new RemovalNotification<>(entry.key, entry.value, RemovalReason.EVICTED);
}
}
}
}
// Listener call happens outside the lruLock scope. It builds a fresh notification for `entry`
// and fires at most once, regardless of how many loop iterations ran.
if (removed) {
removalListener.onRemoval(new RemovalNotification<>(entry.key, entry.value, RemovalReason.EVICTED));
}
return promoted;
}

Expand Down

0 comments on commit 386287d

Please sign in to comment.