Skip to content

Commit

Permalink
Draft lock per key implementation
Browse files Browse the repository at this point in the history
Signed-off-by: Sagar Upadhyaya <[email protected]>
  • Loading branch information
sgup432 committed May 31, 2024
1 parent 0c0a2b3 commit 727aa91
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Function;
Expand Down Expand Up @@ -77,15 +80,19 @@ public class TieredSpilloverCache<K, V> implements ICache<K, V> {
private final TieredSpilloverCacheStatsHolder statsHolder;
private ToLongBiFunction<ICacheKey<K>, V> weigher;
private final List<String> dimensionNames;
ReadWriteLock readWriteLock = new ReentrantReadWriteLock();
ReleasableLock readLock = new ReleasableLock(readWriteLock.readLock());
ReleasableLock writeLock = new ReleasableLock(readWriteLock.writeLock());
// ReadWriteLock readWriteLock = new ReentrantReadWriteLock();
// ReleasableLock readLock = new ReleasableLock(readWriteLock.readLock());
// ReleasableLock writeLock = new ReleasableLock(readWriteLock.writeLock());
/**
* Maintains caching tiers in ascending order of cache latency.
*/
private final Map<ICache<K, V>, TierInfo> caches;
private final List<Predicate<V>> policies;

Map<ICacheKey<K>, LockWrapper> locks = new ConcurrentHashMap<>();

Map<ICacheKey<K>, CompletableFuture<Tuple<ICacheKey<K>, V>>> completableFutureMap = new ConcurrentHashMap<>();

TieredSpilloverCache(Builder<K, V> builder) {
Objects.requireNonNull(builder.onHeapCacheFactory, "onHeap cache builder can't be null");
Objects.requireNonNull(builder.diskCacheFactory, "disk cache builder can't be null");
Expand Down Expand Up @@ -141,6 +148,14 @@ public class TieredSpilloverCache<K, V> implements ICache<K, V> {
.addSettingsUpdateConsumer(DISK_CACHE_ENABLED_SETTING_MAP.get(builder.cacheType), this::enableDisableDiskCache);
}

/**
 * Per-key lock state: a read/write lock pair plus a reference count.
 * The refCount tracks how many threads are currently using (or waiting on)
 * this wrapper; when it drops to zero the owning entry is removed from the
 * {@code locks} map so the map does not grow unboundedly with dead keys.
 */
class LockWrapper {
    // final: the lock objects are fixed for the wrapper's lifetime; only
    // refCount's value changes.
    final ReadWriteLock readWriteLock = new ReentrantReadWriteLock();
    final ReleasableLock readLock = new ReleasableLock(readWriteLock.readLock());
    final ReleasableLock writeLock = new ReleasableLock(readWriteLock.writeLock());

    // Number of threads holding a reference to this wrapper. Incremented on
    // acquire, decremented on release; 0 means the wrapper can be discarded.
    final AtomicInteger refCount = new AtomicInteger();
}

// Package private for testing
ICache<K, V> getOnHeapCache() {
return onHeapCache;
Expand All @@ -151,6 +166,38 @@ ICache<K, V> getDiskCache() {
return diskCache;
}

/**
 * Acquires the per-key write lock, creating the key's {@link LockWrapper} on demand.
 *
 * The wrapper lookup/creation and the refCount increment happen inside a single
 * {@link ConcurrentHashMap#compute} call. Doing them separately
 * (computeIfAbsent followed by incrementAndGet) leaves a window in which a
 * releasing thread observes refCount == 0 and removes the wrapper, so two
 * threads could end up locking two different wrappers for the same key.
 *
 * @param key the cache key whose write lock should be held
 */
private void writeLock(ICacheKey<K> key) {
    LockWrapper lockWrapper = locks.compute(key, (k, wrapper) -> {
        LockWrapper result = (wrapper == null) ? new LockWrapper() : wrapper;
        result.refCount.incrementAndGet();
        return result;
    });
    // Block outside compute(): acquire() may wait, and compute's remapping
    // function must not block on other threads.
    lockWrapper.writeLock.acquire();
}

/**
 * Acquires the per-key read lock, creating the key's {@link LockWrapper} on demand.
 *
 * The refCount increment is performed atomically with the map lookup/insert
 * (via {@link ConcurrentHashMap#compute}) so a concurrent release cannot
 * remove the wrapper between our lookup and our increment.
 *
 * @param key the cache key whose read lock should be held
 */
private void readLock(ICacheKey<K> key) {
    LockWrapper lockWrapper = locks.compute(key, (k, wrapper) -> {
        LockWrapper result = (wrapper == null) ? new LockWrapper() : wrapper;
        result.refCount.incrementAndGet();
        return result;
    });
    // Acquire outside compute(): the remapping function must not block.
    lockWrapper.readLock.acquire();
}

/**
 * Releases the per-key write lock and drops the {@link LockWrapper} from the
 * map once no thread references it.
 *
 * The decrement-and-maybe-remove is done inside {@link ConcurrentHashMap#compute}
 * (returning null removes the entry) so it is atomic with respect to concurrent
 * acquirers; the original get/decrement/remove sequence could remove a wrapper
 * that another thread had just looked up.
 *
 * @param key the cache key whose write lock is being released
 * @throws IllegalStateException if no lock wrapper exists for the key
 *         (i.e. an unlock without a matching lock)
 */
private void unlockWriteLock(ICacheKey<K> key) {
    locks.compute(key, (k, wrapper) -> {
        if (wrapper == null) {
            // Previously this NPE'd on locks.get(key); fail with a clear message.
            throw new IllegalStateException("No lock wrapper found for key on write unlock: " + k);
        }
        wrapper.writeLock.close();
        // Returning null removes the map entry when the last user releases.
        return wrapper.refCount.decrementAndGet() == 0 ? null : wrapper;
    });
}

/**
 * Releases the per-key read lock and drops the {@link LockWrapper} from the
 * map once no thread references it.
 *
 * Performed inside {@link ConcurrentHashMap#compute} so the decrement and the
 * conditional removal are atomic with respect to concurrent acquirers.
 *
 * @param key the cache key whose read lock is being released
 * @throws IllegalStateException if no lock wrapper exists for the key
 *         (i.e. an unlock without a matching lock)
 */
private void unlockReadLock(ICacheKey<K> key) {
    locks.compute(key, (k, wrapper) -> {
        if (wrapper == null) {
            // Previously this NPE'd on locks.get(key); fail with a clear message.
            throw new IllegalStateException("No lock wrapper found for key on read unlock: " + k);
        }
        wrapper.readLock.close();
        // Returning null removes the map entry when the last user releases.
        return wrapper.refCount.decrementAndGet() == 0 ? null : wrapper;
    });
}

// Package private for testing.
void enableDisableDiskCache(Boolean isDiskCacheEnabled) {
// When disk cache is disabled, we are not clearing up the disk cache entries yet as that should be part of
Expand All @@ -170,9 +217,12 @@ public V get(ICacheKey<K> key) {

@Override
public void put(ICacheKey<K> key, V value) {
try (ReleasableLock ignore = writeLock.acquire()) {
try {
writeLock(key);
onHeapCache.put(key, value);
updateStatsOnPut(TIER_DIMENSION_VALUE_ON_HEAP, key, value);
} finally {
unlockWriteLock(key);
}
}

Expand All @@ -191,8 +241,11 @@ public V computeIfAbsent(ICacheKey<K> key, LoadAwareCacheLoader<ICacheKey<K>, V>
// This is needed as there can be many requests for the same key at the same time and we only want to load
// the value once.
V value = null;
try (ReleasableLock ignore = writeLock.acquire()) {
try {
writeLock(key);
value = onHeapCache.computeIfAbsent(key, loader);
} finally {
unlockWriteLock(key);
}
// Handle stats
if (loader.isLoaded()) {
Expand Down Expand Up @@ -234,20 +287,23 @@ public void invalidate(ICacheKey<K> key) {
statsHolder.removeDimensions(dimensionValues);
}
if (key.key != null) {
try (ReleasableLock ignore = writeLock.acquire()) {
try {
writeLock(key);
cacheEntry.getKey().invalidate(key);
} finally {
unlockWriteLock(key);
}
}
}
}

@Override
public void invalidateAll() {
    // Iterates the tiers (kept in ascending latency order) and clears each,
    // then resets the aggregated stats. No per-key locking is taken here;
    // the previous cache-wide write lock was removed with the move to
    // per-key locks. NOTE(review): confirm concurrent put/computeIfAbsent
    // during invalidateAll is acceptable without a global lock.
    for (Map.Entry<ICache<K, V>, TierInfo> cacheEntry : caches.entrySet()) {
        cacheEntry.getKey().invalidateAll();
    }
    statsHolder.reset();
}

Expand Down Expand Up @@ -275,11 +331,11 @@ public long count() {

@Override
public void refresh() {
    // Refresh every tier, in ascending latency order. The former cache-wide
    // write lock was removed with the move to per-key locks; commented-out
    // lock code deleted rather than kept as dead text.
    for (Map.Entry<ICache<K, V>, TierInfo> cacheEntry : caches.entrySet()) {
        cacheEntry.getKey().refresh();
    }
}

@Override
Expand All @@ -302,7 +358,8 @@ public ImmutableCacheStatsHolder stats(String[] levels) {
*/
private Function<ICacheKey<K>, Tuple<V, String>> getValueFromTieredCache(boolean captureStats) {
return key -> {
try (ReleasableLock ignore = readLock.acquire()) {
try {
readLock(key);
for (Map.Entry<ICache<K, V>, TierInfo> cacheEntry : caches.entrySet()) {
if (cacheEntry.getValue().isEnabled()) {
V value = cacheEntry.getKey().get(key);
Expand All @@ -320,6 +377,8 @@ private Function<ICacheKey<K>, Tuple<V, String>> getValueFromTieredCache(boolean
}
}
return null;
} finally {
unlockReadLock(key);
}
};
}
Expand All @@ -328,8 +387,11 @@ void handleRemovalFromHeapTier(RemovalNotification<ICacheKey<K>, V> notification
ICacheKey<K> key = notification.getKey();
boolean wasEvicted = SPILLOVER_REMOVAL_REASONS.contains(notification.getRemovalReason());
if (caches.get(diskCache).isEnabled() && wasEvicted && evaluatePolicies(notification.getValue())) {
try (ReleasableLock ignore = writeLock.acquire()) {
try {
writeLock(key);
diskCache.put(key, notification.getValue()); // spill over to the disk tier and increment its stats
} finally {
unlockWriteLock(key);
}
updateStatsOnPut(TIER_DIMENSION_VALUE_DISK, key, notification.getValue());
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1320,7 +1320,65 @@ public void testTierStatsAddCorrectly() throws Exception {
clusterSettings.applySettings(
Settings.builder().put(DISK_CACHE_ENABLED_SETTING_MAP.get(CacheType.INDICES_REQUEST_CACHE).getKey(), true).build()
);
}

/**
 * Rough timing harness: hammers computeIfAbsent from several threads over a
 * shared random key space and reports the elapsed wall time. Worker failures
 * are collected and rethrown so the test cannot silently pass when the cache
 * throws under contention.
 */
public void testNumLocksTiming() throws Exception {
    int onHeapCacheSize = randomIntBetween(2400, 2401);
    int diskCacheSize = randomIntBetween(5000, 10000);
    int keyValueSize = 50;
    MockCacheRemovalListener<String, String> removalListener = new MockCacheRemovalListener<>();
    TieredSpilloverCache<String, String> tieredSpilloverCache = initializeTieredSpilloverCache(
        keyValueSize,
        diskCacheSize,
        removalListener,
        Settings.builder()
            .put(
                OpenSearchOnHeapCacheSettings.getSettingListForCacheType(CacheType.INDICES_REQUEST_CACHE)
                    .get(MAXIMUM_SIZE_IN_BYTES_KEY)
                    .getKey(),
                onHeapCacheSize * keyValueSize + "b"
            )
            .build(),
        0
    );

    int numRequests = 100_000;
    int numThreads = 8;
    Thread[] threads = new Thread[numThreads];
    // Phaser releases all workers at once so they contend from the start.
    Phaser phaser = new Phaser(numThreads + 1);
    CountDownLatch countDownLatch = new CountDownLatch(numThreads);
    // One failure slot per thread. Writes before countDown() happen-before
    // await() returning, so no extra synchronization is needed to read these.
    Exception[] failures = new Exception[numThreads];

    // Precompute the keys each thread will request so key generation is not
    // part of the timed section.
    List<List<ICacheKey<String>>> keysPerThread = new ArrayList<>();

    for (int i = 0; i < numThreads; i++) {
        keysPerThread.add(new ArrayList<>());
        int finalI = i;
        for (int j = 0; j < numRequests; j++) {
            keysPerThread.get(i).add(getICacheKey(String.valueOf(randomInt(numRequests))));
        }

        threads[i] = new Thread(() -> {
            phaser.arriveAndAwaitAdvance();
            try {
                // No per-iteration printing here: console I/O inside the timed
                // loop would dominate and skew the measurement.
                for (int j = 0; j < numRequests; j++) {
                    tieredSpilloverCache.computeIfAbsent(keysPerThread.get(finalI).get(j), getLoadAwareCacheLoader());
                }
            } catch (Exception e) {
                // Record instead of swallowing so the test can fail below.
                failures[finalI] = e;
            } finally {
                countDownLatch.countDown();
            }
        });
        threads[i].start();
    }
    long now = System.nanoTime();
    phaser.arriveAndAwaitAdvance();
    countDownLatch.await();
    long elapsed = System.nanoTime() - now;
    for (Exception failure : failures) {
        if (failure != null) {
            throw failure;
        }
    }
    System.out.println(
        "Time taken for " + numThreads + " threads x " + numRequests + " requests: " + elapsed + " ns (" + (float) elapsed
            / 1_000_000_000 + " sec)"
    );
}

private List<String> getMockDimensions() {
Expand Down

0 comments on commit 727aa91

Please sign in to comment.