diff --git a/jetty-core/jetty-server/src/main/config/etc/jetty-dos.xml b/jetty-core/jetty-server/src/main/config/etc/jetty-dos.xml new file mode 100644 index 00000000000..6ac61077623 --- /dev/null +++ b/jetty-core/jetty-server/src/main/config/etc/jetty-dos.xml @@ -0,0 +1,67 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/jetty-core/jetty-server/src/main/config/modules/dos.mod b/jetty-core/jetty-server/src/main/config/modules/dos.mod new file mode 100644 index 00000000000..108e2769070 --- /dev/null +++ b/jetty-core/jetty-server/src/main/config/modules/dos.mod @@ -0,0 +1,59 @@ +# DO NOT EDIT THIS FILE - See: https://eclipse.dev/jetty/documentation/ + +[description] +Enables the DosHandler for the server. + +[tags] +connector + +[depend] +server + +[xml] +etc/jetty-dos.xml + +[ini-template] + +## The algorithm to use for obtaining an Id from an Request: ID_FROM_REMOTE_ADDRESS, ID_FROM_REMOTE_PORT, ID_FROM_REMOTE_ADDRESS_PORT, ID_CONNECTION +#jetty.dos.id.type=ID_FROM_REMOTE_ADDRESS +#jetty.dos.id.class=org.eclipse.jetty.server.handler.DosHandler + +## The class to use to create RateControl instances to track the rate of requests +#jetty.dos.rateControlFactory=org.eclipse.jetty.server.handler.DosHandler$ExponentialMovingAverageRateControlFactory + +## The sample period(ms) to determine the request rate, or -1 for a default value +#jetty.dos.rateControlFactory.samplePeriodMs=100 + +## The Exponential factor for the moving average rate +#jetty.dos.rateControlFactory.expMovingAvg.alpha=0.2 + +## The maximum requests per second per client +#jetty.dos.maxRequestsPerSecond=100 + +## The Handler class to use to reject DOS requests +#jetty.dos.rejectHandler=org.eclipse.jetty.server.handler.DosHandler$TooManyRequestsRejectHandler + +## The period to delay dos requests before rejecting them. +#jetty.dos.rejectHandler.delayed.delayMs=1000 + +## The maximum number of requests to be held in the delay queue +#jetty.dos.rejectHandler.delayed.maxDelayQueue=1000 + +## The maximum number of clients to track; or -1 for a default value +#jetty.dos.maxTrackers=10000 + +## The status code used to reject requests; or 0 to abort the request; or -1 for a default +#jetty.dos.rejectStatus=429 + +## List of InetAddress patterns to include +#jetty.dos.include.inet=10.10.10-14.0-128 + +## List of InetAddressPatterns to exclude +#jetty.dos.exclude.inet=10.10.10-14.0-128 + +## List of path patterns to include +#jetty.dos.include.path=/context/* + +## List of path to exclude +#jetty.dos.exclude.path=/context/* + diff --git a/jetty-core/jetty-server/src/main/java/org/eclipse/jetty/server/handler/DoSHandler.java b/jetty-core/jetty-server/src/main/java/org/eclipse/jetty/server/handler/DoSHandler.java new file mode 100644 index 00000000000..db3be689d80 --- /dev/null +++ b/jetty-core/jetty-server/src/main/java/org/eclipse/jetty/server/handler/DoSHandler.java @@ -0,0 +1,653 @@ +// +// ======================================================================== +// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +// ======================================================================== +// + +package org.eclipse.jetty.server.handler; + +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.time.Duration; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; + +import org.eclipse.jetty.http.HttpStatus; +import org.eclipse.jetty.io.CyclicTimeouts; +import org.eclipse.jetty.server.ConnectionMetaData; +import org.eclipse.jetty.server.Handler; +import org.eclipse.jetty.server.Request; +import org.eclipse.jetty.server.Response; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.util.Callback; +import org.eclipse.jetty.util.NanoTime; +import org.eclipse.jetty.util.annotation.ManagedObject; +import org.eclipse.jetty.util.annotation.Name; +import org.eclipse.jetty.util.thread.AutoLock; +import org.eclipse.jetty.util.thread.Scheduler; + +/** + * A Denial of Service Handler. + *

Protect from denial of service attacks by limiting the request rate from remote hosts

+ */ +@ManagedObject("DoS Prevention Handler") +public class DoSHandler extends ConditionalHandler.ElseNext +{ + /** + * An id {@link Function} to create an ID from the remote address and port of a {@link Request} + */ + public static final Function ID_FROM_REMOTE_ADDRESS_PORT = request -> + { + SocketAddress remoteSocketAddress = request.getConnectionMetaData().getRemoteSocketAddress(); + if (remoteSocketAddress instanceof InetSocketAddress inetSocketAddress) + return inetSocketAddress.toString(); + return remoteSocketAddress.toString(); + }; + + /** + * An id {@link Function} to create an ID from the remote address of a {@link Request} + */ + public static final Function ID_FROM_REMOTE_ADDRESS = request -> + { + SocketAddress remoteSocketAddress = request.getConnectionMetaData().getRemoteSocketAddress(); + if (remoteSocketAddress instanceof InetSocketAddress inetSocketAddress) + return inetSocketAddress.getAddress().toString(); + return remoteSocketAddress.toString(); + }; + + /** + * An id {@link Function} to create an ID from the remote port of a {@link Request}. + * This can be useful if there is an untrusted intermediary, where the remote port can be a surrogate for the connection. + */ + public static final Function ID_FROM_REMOTE_PORT = request -> + { + SocketAddress remoteSocketAddress = request.getConnectionMetaData().getRemoteSocketAddress(); + if (remoteSocketAddress instanceof InetSocketAddress inetSocketAddress) + return Integer.toString(inetSocketAddress.getPort()); + return remoteSocketAddress.toString(); + }; + + /** + * An id {@link Function} to create an ID from {@link ConnectionMetaData#getId()} of the {@link Request} + */ + public static final Function ID_FROM_CONNECTION = request -> request.getConnectionMetaData().getId(); + + /** + * An interface implemented to track and control the rate of requests for a specific ID. + */ + public interface RateControl + { + /** + * Record a request and calculate if the rate is exceeded at the given time. + * @param now The {@link NanoTime#now()} at which to calculate the rate + * @return {@code true} if the rate is currently exceeded + */ + boolean onRequest(long now); + + /** + * Check if the tracker is now idle + * @param now The {@link NanoTime#now()} at which to calculate the rate + * @return {@code true} if the rate is currently near zero + */ + boolean isIdle(long now); + + /** + * A factory to create new {@link RateControl} instances + */ + interface Factory + { + RateControl newRateControl(); + + Duration idleCheckPeriod(); + } + } + + private final Map _trackers = new ConcurrentHashMap<>(); + private final Function _getId; + private final RateControl.Factory _rateControlFactory; + private final Request.Handler _rejectHandler; + private final int _maxTrackers; + private CyclicTimeouts _cyclicTimeouts; + + public DoSHandler() + { + this(null, 100, -1); + } + + /** + * @param maxRequestsPerSecond The maximum requests per second allows per ID. + */ + public DoSHandler(@Name("maxRequestsPerSecond") int maxRequestsPerSecond) + { + this(null, maxRequestsPerSecond, -1); + } + + /** + * @param getId Function to extract an remote ID from a request. + * @param maxRequestsPerSecond The maximum requests per second allows per ID. + * @param maxTrackers The maximum number of remote clients to track or -1 for a default value. If this limit is exceeded, then requests from additional remote clients are rejected. + */ + public DoSHandler( + @Name("getId") Function getId, + @Name("maxRequestsPerSecond") int maxRequestsPerSecond, + @Name("maxTrackers") int maxTrackers) + { + this(null, getId, new ExponentialMovingAverageRateControlFactory(maxRequestsPerSecond), null, maxTrackers); + } + + /** + * @param getId Function to extract an remote ID from a request. + * @param rateControlFactory Factory to create a Rate per Tracker + * @param rejectHandler A {@link Handler} used to reject excess requests, or {@code null} for a default. + * @param maxTrackers The maximum number of remote clients to track or -1 for a default value. If this limit is exceeded, then requests from additional remote clients are rejected. + */ + public DoSHandler( + @Name("getId") Function getId, + @Name("rateFactory") RateControl.Factory rateControlFactory, + @Name("rejectHandler") Request.Handler rejectHandler, + @Name("maxTrackers") int maxTrackers) + { + this(null, getId, rateControlFactory, rejectHandler, maxTrackers); + } + + /** + * @param handler Then next {@link Handler} or {@code null} + * @param getId Function to extract an remote ID from a request. + * @param rateControlFactory Factory to create a Rate per Tracker + * @param rejectHandler A {@link Handler} used to reject excess requests, or {@code null} for a default. + * @param maxTrackers The maximum number of remote clients to track or -1 for a default value. If this limit is exceeded, then requests from additional remote clients are rejected. + */ + public DoSHandler( + @Name("handler") Handler handler, + @Name("getId") Function getId, + @Name("rateFactory") RateControl.Factory rateControlFactory, + @Name("rejectHandler") Request.Handler rejectHandler, + @Name("maxTrackers") int maxTrackers) + { + super(handler); + installBean(_trackers); + _getId = Objects.requireNonNullElse(getId, ID_FROM_REMOTE_ADDRESS); + installBean(_getId); + _rateControlFactory = Objects.requireNonNull(rateControlFactory); + installBean(_rateControlFactory); + _maxTrackers = maxTrackers < 0 ? 10_000 : maxTrackers; + _rejectHandler = Objects.requireNonNullElseGet(rejectHandler, StatusRejectHandler::new); + installBean(_rejectHandler); + } + + @Override + public void setServer(Server server) + { + super.setServer(server); + if (_rejectHandler instanceof Handler handler) + handler.setServer(server); + } + + @Override + protected boolean onConditionsMet(Request request, Response response, Callback callback) throws Exception + { + // Reject if we have too many Trackers + if (_maxTrackers > 0 && _trackers.size() >= _maxTrackers) + { + // Try shrinking the tracker pool + long now = NanoTime.now(); + _trackers.values().removeIf(tracker -> tracker.isIdle(now)); + if (_trackers.size() >= _maxTrackers) + { + // Try shrinking the tracker pool as if we are at the next idle check already + long nextIdleCheck = NanoTime.now() + _rateControlFactory.idleCheckPeriod().getNano(); + _trackers.values().removeIf(tracker -> tracker.isIdle(nextIdleCheck)); + if (_trackers.size() >= _maxTrackers) + return _rejectHandler.handle(request, response, callback); + } + } + + // Calculate an id for the request (which may be global empty string) + String id = _getId.apply(request); + + if (id == null) + return _rejectHandler.handle(request, response, callback); + + // Obtain a tracker, creating a new one if necessary. Trackers are removed if CyclicTimeouts#onExpired returns true + Tracker tracker = _trackers.computeIfAbsent(id, this::newTracker); + + // If we are not over-limit then handle normally + if (!tracker.onRequest(request.getBeginNanoTime())) + return nextHandler(request, response, callback); + + // Otherwise reject the request + return _rejectHandler.handle(request, response, callback); + } + + Tracker newTracker(String id) + { + return new Tracker(id, _rateControlFactory.newRateControl()); + } + + @Override + protected void doStart() throws Exception + { + _cyclicTimeouts = new CyclicTimeouts<>(getServer().getScheduler()) + { + @Override + protected Iterator iterator() + { + return _trackers.values().iterator(); + } + + @Override + protected boolean onExpired(Tracker tracker) + { + return tracker.isIdle(NanoTime.now()); + } + }; + addBean(_cyclicTimeouts); + super.doStart(); + } + + @Override + protected void doStop() throws Exception + { + super.doStop(); + removeBean(_cyclicTimeouts); + _cyclicTimeouts.destroy(); + _cyclicTimeouts = null; + } + + /** + * A RateTracker is associated with a connection, and stores request rate data. + */ + class Tracker implements CyclicTimeouts.Expirable + { + private final AutoLock _lock = new AutoLock(); + private final String _id; + private final RateControl _rateControl; + private final Duration _idleCheck; + private long _nextIdleCheckAt; + + Tracker(String id, RateControl rateControl) + { + this(id, rateControl, null); + } + + Tracker(String id, RateControl rateControl, Duration idleCheck) + { + _id = id; + _rateControl = rateControl; + _idleCheck = idleCheck == null ? Duration.ofSeconds(2) : idleCheck; + _nextIdleCheckAt = NanoTime.now() + _idleCheck.toNanos(); + } + + public String getId() + { + return _id; + } + + RateControl getRateControl() + { + return _rateControl; + } + + public boolean onRequest(long now) + { + try (AutoLock l = _lock.lock()) + { + CyclicTimeouts cyclicTimeouts = _cyclicTimeouts; + if (cyclicTimeouts != null) + { + // schedule a check to remove this tracker if idle + _nextIdleCheckAt = now + _idleCheck.toNanos(); + cyclicTimeouts.schedule(this); + } + return _rateControl.onRequest(now); + } + } + + public boolean isIdle(long now) + { + try (AutoLock l = _lock.lock()) + { + CyclicTimeouts cyclicTimeouts = _cyclicTimeouts; + if (cyclicTimeouts != null) + { + _nextIdleCheckAt = now + _idleCheck.toNanos(); + cyclicTimeouts.schedule(this); + } + return _rateControl.isIdle(now); + } + } + + @Override + public long getExpireNanoTime() + { + return _nextIdleCheckAt; + } + + @Override + public String toString() + { + try (AutoLock l = _lock.lock()) + { + return "Tracker@%s{rc=%s}".formatted(_id, _rateControl); + } + } + } + + /** + * A {@link RateControl.Factory} that uses an + * Exponential Moving Average + * to limit the request rate to a maximum number of requests per second. + */ + public static class ExponentialMovingAverageRateControlFactory implements RateControl.Factory + { + private final Duration _samplePeriod; + private final Duration _idleCheckPeriod; + private final double _alpha; + private final int _maxRequestsPerSecond; + + public ExponentialMovingAverageRateControlFactory() + { + this(null, -1.0, 1000); + } + + public ExponentialMovingAverageRateControlFactory(@Name("maxRequestsPerSecond") int maxRateRequestsPerSecond) + { + this(null, -1.0, maxRateRequestsPerSecond); + } + + public ExponentialMovingAverageRateControlFactory( + @Name("samplePeriodMs") long samplePeriodMs, + @Name("alpha") double alpha, + @Name("maxRequestsPerSecond") int maxRequestsPerSecond) + { + this(samplePeriodMs <= 0 ? null : Duration.ofMillis(samplePeriodMs), alpha, maxRequestsPerSecond); + } + + public ExponentialMovingAverageRateControlFactory( + @Name("samplePeriod") Duration samplePeriod, + @Name("alpha") double alpha, + @Name("maxRequestsPerSecond") int maxRequestsPerSecond) + { + this(samplePeriod, null, alpha, maxRequestsPerSecond); + } + + public ExponentialMovingAverageRateControlFactory( + @Name("samplePeriod") Duration samplePeriod, + @Name("idleCheckPeriod") Duration idleCheckPeriod, + @Name("alpha") double alpha, + @Name("maxRequestsPerSecond") int maxRequestsPerSecond) + { + _samplePeriod = samplePeriod == null ? Duration.ofMillis(100) : samplePeriod; + _idleCheckPeriod = idleCheckPeriod == null ? _samplePeriod.multipliedBy(20) : idleCheckPeriod; + _alpha = alpha <= 0.0 ? 0.2 : alpha; + if (_samplePeriod.compareTo(Duration.ofSeconds(1)) > 0) + throw new IllegalArgumentException("Sample period must be less than or equal to 1 second"); + if (_alpha > 1.0) + throw new IllegalArgumentException("Alpha " + _alpha + " is too large"); + _maxRequestsPerSecond = maxRequestsPerSecond; + } + + @Override + public Duration idleCheckPeriod() + { + return _idleCheckPeriod; + } + + @Override + public RateControl newRateControl() + { + return new ExponentialMovingAverageRateControl(); + } + + class ExponentialMovingAverageRateControl implements RateControl + { + private double _exponentialMovingAverage; + private int _sampleCount; + private long _sampleStart; + + private ExponentialMovingAverageRateControl() + { + _sampleStart = NanoTime.now(); + } + + @Override + public boolean onRequest(long now) + { + // Count the request + _sampleCount++; + + long elapsedTime = NanoTime.elapsed(_sampleStart, now); + + // We calculate the moving average if: + // + the sample exceeds the rate + // + the sample period has been exceeded + if (_sampleCount > _maxRequestsPerSecond || (_sampleStart != 0 && elapsedTime > _samplePeriod.toNanos())) + { + calculateMovingAverage(now); + } + + // if the rate has been exceeded? + return _exponentialMovingAverage > _maxRequestsPerSecond; + } + + @Override + public boolean isIdle(long now) + { + calculateMovingAverage(now); + return _exponentialMovingAverage <= 0.0001; + } + + private void calculateMovingAverage(long now) + { + double elapsedTime1 = (double)(now - _sampleStart); + double count = _sampleCount; + if (elapsedTime1 > 0.0) + { + double currentRate = (count * TimeUnit.SECONDS.toNanos(1L)) / elapsedTime1; + // Adjust alpha based on the ratio of elapsed time to the interval to allow for long and short intervals + double adjustedAlpha = _alpha * (elapsedTime1 / _samplePeriod.toNanos()); + if (adjustedAlpha > 1.0) + adjustedAlpha = 1.0; // Ensure adjustedAlpha does not exceed 1.0 + + _exponentialMovingAverage = (adjustedAlpha * currentRate + (1.0 - adjustedAlpha) * _exponentialMovingAverage); + } + else + { + // assume count as the rate for the sample. + double guessedRate = count * TimeUnit.SECONDS.toNanos(1) / _samplePeriod.toNanos(); + _exponentialMovingAverage = (_alpha * guessedRate + (1.0 - _alpha) * _exponentialMovingAverage); + } + + // restart the sample + _sampleStart = now; + _sampleCount = 0; + } + + double getCurrentRatePerSecond() + { + return _exponentialMovingAverage; + } + } + } + + /** + * A Handler to reject DoS requests with a status code or failure. + */ + public static class StatusRejectHandler implements Request.Handler + { + private final int _status; + + public StatusRejectHandler() + { + this(-1); + } + + /** + * @param status The status used to reject a request, or 0 to fail the request or -1 for a default ({@link HttpStatus#TOO_MANY_REQUESTS_429}. + */ + public StatusRejectHandler(int status) + { + _status = status >= 0 ? status : HttpStatus.TOO_MANY_REQUESTS_429; + if (_status != 0 && _status != HttpStatus.OK_200 && !HttpStatus.isClientError(_status) && !HttpStatus.isServerError(_status)) + throw new IllegalArgumentException("status must be a client or server error"); + } + + @Override + public boolean handle(Request request, Response response, Callback callback) throws Exception + { + if (_status == 0) + callback.failed(new RejectedExecutionException()); + else + Response.writeError(request, response, callback, _status); + return true; + } + } + + /** + * A Handler to reject DoS requests after first delaying them. + */ + public static class DelayedRejectHandler extends Handler.Abstract + { + record Exchange(Request request, Response response, Callback callback) + {} + + private final AutoLock _lock = new AutoLock(); + private final Deque _delayQueue = new ArrayDeque<>(); + private final int _maxDelayQueue; + private final long _delayMs; + private final Request.Handler _reject; + private Scheduler _scheduler; + + public DelayedRejectHandler() + { + this(-1, -1, null); + } + + /** + * @param delayMs The delay in milliseconds to hold rejected requests before sending a response or -1 for a default (1000ms) + * @param maxDelayQueue The maximum number of delayed requests to hold or -1 for a default (1000ms). + * @param reject The {@link Request.Handler} used to reject {@link Request}s or null for a default ({@link HttpStatus#TOO_MANY_REQUESTS_429}). + */ + public DelayedRejectHandler( + @Name("delayMs") long delayMs, + @Name("maxDelayQueue") int maxDelayQueue, + @Name("reject") Request.Handler reject) + { + _delayMs = delayMs >= 0 ? delayMs : 1000; + _maxDelayQueue = maxDelayQueue >= 0 ? maxDelayQueue : 1000; + _reject = Objects.requireNonNullElseGet(reject, () -> new StatusRejectHandler(HttpStatus.TOO_MANY_REQUESTS_429)); + } + + @Override + protected void doStart() throws Exception + { + super.doStart(); + _scheduler = getServer().getScheduler(); + addBean(_scheduler); + } + + @Override + protected void doStop() throws Exception + { + super.doStop(); + removeBean(_scheduler); + _scheduler = null; + } + + @Override + public boolean handle(Request request, Response response, Callback callback) throws Exception + { + List rejects = null; + try (AutoLock ignored = _lock.lock()) + { + while (_delayQueue.size() >= _maxDelayQueue) + { + Exchange exchange = _delayQueue.removeFirst(); + if (rejects == null) + rejects = new ArrayList<>(); + rejects.add(exchange); + } + + if (_delayQueue.isEmpty()) + _scheduler.schedule(this::onTick, _delayMs / 2, TimeUnit.MILLISECONDS); + _delayQueue.addLast(new Exchange(request, response, callback)); + } + + if (rejects != null) + { + for (Exchange exchange : rejects) + { + try + { + if (!_reject.handle(exchange.request, exchange.response, exchange.callback)) + exchange.callback.failed(new RejectedExecutionException()); + } + catch (Throwable t) + { + exchange.callback.failed(t); + } + } + } + + return true; + } + + private void onTick() + { + long expired = NanoTime.now() - TimeUnit.MILLISECONDS.toNanos(_delayMs); + + List rejects = null; + try (AutoLock ignored = _lock.lock()) + { + Iterator iterator = _delayQueue.iterator(); + while (iterator.hasNext()) + { + Exchange exchange = iterator.next(); + if (NanoTime.isBeforeOrSame(exchange.request.getBeginNanoTime(), expired)) + { + iterator.remove(); + + if (rejects == null) + rejects = new ArrayList<>(); + rejects.add(exchange); + } + } + + if (!_delayQueue.isEmpty()) + _scheduler.schedule(this::onTick, _delayMs / 2, TimeUnit.MILLISECONDS); + } + + if (rejects != null) + { + for (Exchange exchange : rejects) + { + try + { + if (!_reject.handle(exchange.request, exchange.response, exchange.callback)) + exchange.callback.failed(new RejectedExecutionException()); + } + catch (Throwable t) + { + exchange.callback.failed(t); + } + } + } + } + } +} diff --git a/jetty-core/jetty-server/src/test/java/org/eclipse/jetty/server/handler/DoSHandlerTest.java b/jetty-core/jetty-server/src/test/java/org/eclipse/jetty/server/handler/DoSHandlerTest.java new file mode 100644 index 00000000000..2cca14bf680 --- /dev/null +++ b/jetty-core/jetty-server/src/test/java/org/eclipse/jetty/server/handler/DoSHandlerTest.java @@ -0,0 +1,291 @@ +// +// ======================================================================== +// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +// ======================================================================== +// + +package org.eclipse.jetty.server.handler; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import org.awaitility.Awaitility; +import org.eclipse.jetty.server.LocalConnector; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.util.NanoTime; +import org.junit.jupiter.api.Test; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.both; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.lessThan; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class DoSHandlerTest +{ + @Test + public void testTrackerSteadyBelowRate() throws Exception + { + DoSHandler handler = new DoSHandler(100); + DoSHandler.Tracker tracker = handler.newTracker("id"); + long now = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); + + for (int sample = 0; sample < 400; sample++) + { + boolean exceeded = tracker.onRequest(now); + assertFalse(exceeded); + now += TimeUnit.MILLISECONDS.toNanos(11); + } + double rate = tracker.getRateControl() instanceof DoSHandler.ExponentialMovingAverageRateControlFactory.ExponentialMovingAverageRateControl rc ? rc.getCurrentRatePerSecond() : 0.0; + assertThat(rate, both(greaterThan((1000.0D / 11) - 5)).and(lessThan(100.0D))); + } + + @Test + public void testTrackerSteadyAboveRate() throws Exception + { + DoSHandler handler = new DoSHandler(100); + DoSHandler.Tracker tracker = handler.newTracker("id"); + long now = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); + + boolean exceeded = false; + for (int sample = 0; sample < 200; sample++) + { + if (tracker.onRequest(now)) + { + exceeded = true; + break; + } + now += TimeUnit.MILLISECONDS.toNanos(9); + } + + assertTrue(exceeded); + } + + @Test + public void testTrackerUnevenBelowRate() throws Exception + { + DoSHandler handler = new DoSHandler(100); + DoSHandler.Tracker tracker = handler.newTracker("id"); + long now = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); + + for (int sample = 0; sample < 20; sample++) + { + for (int burst = 0; burst < 9; burst++) + { + boolean exceeded = tracker.onRequest(now); + assertFalse(exceeded); + } + + now += TimeUnit.MILLISECONDS.toNanos(100); + } + } + + @Test + public void testTrackerUnevenAboveRate() throws Exception + { + DoSHandler handler = new DoSHandler(100); + DoSHandler.Tracker tracker = handler.newTracker("id"); + long now = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); + + boolean exceeded = false; + for (int sample = 0; sample < 20; sample++) + { + for (int burst = 0; burst < 11; burst++) + { + if (tracker.onRequest(now)) + { + exceeded = true; + break; + } + } + + now += TimeUnit.MILLISECONDS.toNanos(100); + } + + assertTrue(exceeded); + } + + @Test + public void testTrackerBurstBelowRate() throws Exception + { + DoSHandler handler = new DoSHandler(100); + DoSHandler.Tracker tracker = handler.newTracker("id"); + long now = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); + + for (int seconds = 0; seconds < 2; seconds++) + { + for (int burst = 0; burst < 99; burst++) + { + boolean exceeded = tracker.onRequest(now); + assertFalse(exceeded); + } + + now += TimeUnit.MILLISECONDS.toNanos(1000); + } + } + + @Test + public void testTrackerBurstAboveRate() throws Exception + { + DoSHandler handler = new DoSHandler(100); + DoSHandler.Tracker tracker = handler.newTracker("id"); + long now = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); + + boolean exceeded = false; + for (int seconds = 0; seconds < 2; seconds++) + { + for (int burst = 0; burst < 101; burst++) + { + if (tracker.onRequest(now)) + { + exceeded = true; + break; + } + } + + now += TimeUnit.MILLISECONDS.toNanos(1000); + } + + assertTrue(exceeded); + } + + @Test + public void testRecoveryAfterBursts() throws Exception + { + DoSHandler handler = new DoSHandler(100); + DoSHandler.Tracker tracker = handler.newTracker("id"); + long now = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); + + for (int seconds = 0; seconds < 2; seconds++) + { + for (int burst = 0; burst < 99; burst++) + assertFalse(tracker.onRequest(now++)); + + now += TimeUnit.MILLISECONDS.toNanos(1000) - 100; + } + + double rate = tracker.getRateControl() instanceof DoSHandler.ExponentialMovingAverageRateControlFactory.ExponentialMovingAverageRateControl rc ? rc.getCurrentRatePerSecond() : 0.0; + assertThat(rate, both(greaterThan(90.0D)).and(lessThan(100.0D))); + + for (int seconds = 0; seconds < 2; seconds++) + { + for (int burst = 0; burst < 49; burst++) + assertFalse(tracker.onRequest(now++)); + + now += TimeUnit.MILLISECONDS.toNanos(1000) - 100; + } + + rate = tracker.getRateControl() instanceof DoSHandler.ExponentialMovingAverageRateControlFactory.ExponentialMovingAverageRateControl rc ? rc.getCurrentRatePerSecond() : 0.0; + assertThat(rate, both(greaterThan(40.0D)).and(lessThan(50.0D))); + } + + @Test + public void testOKRequestRate() throws Exception + { + Server server = new Server(); + LocalConnector connector = new LocalConnector(server); + server.addConnector(connector); + + DoSHandler dosHandler = new DoSHandler(1000); + DumpHandler dumpHandler = new DumpHandler(); + server.setHandler(dosHandler); + dosHandler.setHandler(dumpHandler); + + server.start(); + + long now = System.nanoTime(); + long end = now + TimeUnit.SECONDS.toNanos(5); + CountDownLatch latch = new CountDownLatch(90); + for (int thread = 0; thread < 90; thread++) + { + server.getThreadPool().execute(() -> + { + try + { + while (NanoTime.isBefore(NanoTime.now(), end)) + { + String response = connector.getResponse(""" + GET / HTTP/1.1\r + Host: local\r + + """); + assertThat(response, containsString("200 OK")); + Thread.sleep(100); + } + latch.countDown(); + } + catch (Throwable x) + { + throw new RuntimeException(x); + } + }); + } + + assertTrue(latch.await(10, TimeUnit.SECONDS)); + } + + @Test + public void testHighRequestRate() throws Exception + { + Server server = new Server(); + LocalConnector connector = new LocalConnector(server); + server.addConnector(connector); + + DoSHandler dosHandler = new DoSHandler(1000); + DumpHandler dumpHandler = new DumpHandler(); + server.setHandler(dosHandler); + dosHandler.setHandler(dumpHandler); + + server.start(); + + long now = System.nanoTime(); + long end = now + TimeUnit.SECONDS.toNanos(5); + AtomicInteger outstanding = new AtomicInteger(0); + AtomicInteger calm = new AtomicInteger(); + for (int thread = 0; thread < 90; thread++) + { + server.getThreadPool().execute(() -> + { + try + { + while (NanoTime.isBefore(NanoTime.now(), end)) + { + try + { + outstanding.incrementAndGet(); + String response = connector.getResponse(""" + GET / HTTP/1.1\r + Host: local\r + + """); + if (response.contains(" 429 ")) + calm.incrementAndGet(); + Thread.sleep(70); + } + finally + { + outstanding.decrementAndGet(); + } + } + } + catch (Throwable x) + { + throw new RuntimeException(x); + } + }); + } + + Awaitility.waitAtMost(10, TimeUnit.SECONDS).until(() -> outstanding.get() == 0); + assertThat(calm.get(), greaterThan(0)); + } +}