RateLimitConfig.java

package jasper.config;

import io.github.resilience4j.bulkhead.Bulkhead;
import io.github.resilience4j.bulkhead.BulkheadFullException;
import io.github.resilience4j.ratelimiter.RateLimiter;
import io.github.resilience4j.ratelimiter.RateLimiterConfig;
import jakarta.servlet.Filter;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jasper.component.ConfigCache;
import jasper.security.Auth;
import jasper.service.dto.TemplateDto;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.integration.annotation.ServiceActivator;
import org.springframework.messaging.Message;
import org.springframework.web.filter.GenericFilterBean;

import java.io.IOException;
import java.util.Locale;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadLocalRandom;

import static java.lang.String.format;
import static java.time.Duration.ofNanos;
import static org.apache.commons.lang3.StringUtils.isBlank;

@Configuration
public class RateLimitConfig {
	private static final Logger logger = LoggerFactory.getLogger(RateLimitConfig.class);

	@Autowired
	ConfigCache configs;

	@Autowired
	Auth auth;

	@Autowired
	Bulkhead httpBulkhead;

	private final ConcurrentHashMap<String, RateLimiter> originRateLimiters = new ConcurrentHashMap<>();
	private RateLimiter getOriginRateLimiter(String origin) {
		return originRateLimiters.computeIfAbsent(origin, k -> RateLimiter.of("http-" + origin, RateLimiterConfig.custom()
				.limitForPeriod(configs.security(origin).getMaxRequests())
				.limitRefreshPeriod(ofNanos(500))
				.build()));
	}

	@ServiceActivator(inputChannel = "templateRxChannel")
	public void handleTemplateUpdate(Message<TemplateDto> message) {
		var template = message.getPayload();
		if (isBlank(template.getTag())) return;
		if (template.getTag().startsWith("_config/server")) {
			httpRateLimiter().changeLimitForPeriod(configs.root().getMaxConcurrentRequests());
		}
		if (template.getTag().startsWith("_config/security")) {
			originRateLimiters.forEach((origin, r) -> r.changeLimitForPeriod(configs.security(origin).getMaxRequests()));
		}
	}

	@Bean
	RateLimiter httpRateLimiter() {
		return RateLimiter.of("http", RateLimiterConfig.custom()
				.limitForPeriod(configs.root().getMaxConcurrentRequests())
				.limitRefreshPeriod(ofNanos(500))
				.build());
	}

	@Bean
	public Filter rateLimitInterceptor(RateLimiter httpRateLimiter) {
		return new GenericFilterBean() {
			@Override
			public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
				HttpServletRequest httpRequest = (HttpServletRequest) request;
				HttpServletResponse httpResponse = (HttpServletResponse) response;
				var origin = auth.getOrigin();
				if (!getOriginRateLimiter(origin).acquirePermission()) {
					RateLimitConfig.logger.debug("{} Rate limit exceeded for origin: {}", origin, httpRequest.getRequestURI());
					httpResponse.setStatus(429);
					httpResponse.setHeader("X-RateLimit-Limit", ""+configs.security(origin).getMaxRequests());
					// Add random jitter from 3.5 to 4.5 seconds to prevent thundering herd
					httpResponse.setHeader("X-RateLimit-Retry-After", format("%.1f", ThreadLocalRandom.current().nextDouble(3.5, 4.5)));
					return;
				}
				if (!httpRateLimiter.acquirePermission()) {
					RateLimitConfig.logger.debug("HTTP rate limit exceeded for request: {}", httpRequest.getRequestURI());
					httpResponse.setStatus(429);
					// Add random jitter from 3.5 to 4.5 seconds to prevent thundering herd
					httpResponse.setHeader("X-RateLimit-Retry-After", format("%.1f", ThreadLocalRandom.current().nextDouble(3.5, 4.5)));
					return;
				}
				try {
					httpBulkhead.executeCheckedSupplier(() -> {
						chain.doFilter(request, response);
						return null;
					});
				} catch (BulkheadFullException e) {
					RateLimitConfig.logger.debug("HTTP concurrent limit exceeded for request: {}", httpRequest.getRequestURI());
					httpResponse.setStatus(HttpServletResponse.SC_SERVICE_UNAVAILABLE); // 503
					// Add random jitter from 3.5 to 4.5 seconds to prevent thundering herd
					httpResponse.setHeader("X-RateLimit-Retry-After", format("%.1f", ThreadLocalRandom.current().nextDouble(3.5, 4.5)));
				} catch (Throwable e) {
					throw new RuntimeException(e);
				}
			}
		};
	}
}