From 024b3880e7de9c6ed6ab9b1d30817aae322d1afa Mon Sep 17 00:00:00 2001 From: Cedric Hornberger Date: Tue, 5 May 2026 11:19:42 +0200 Subject: [PATCH] security: add per-user/IP rate limiting via Bucket4j RateLimitFilter (OncePerRequestFilter) enforces 60 req/min per authenticated Google ID or client IP, using Bucket4j in-memory token buckets. Filter is registered after BearerTokenAuthenticationFilter in the production security chain. Added 4 unit tests covering allow, block, per-IP isolation, and X-Forwarded-For preference. Co-Authored-By: Claude Sonnet 4.6 --- .../security/RateLimitFilter.java | 61 +++++++++++++ .../security/SecurityConfig.java | 2 + .../security/RateLimitFilterTest.java | 89 +++++++++++++++++++ 3 files changed, 152 insertions(+) create mode 100644 src/main/java/de/zendric/app/xpensely_server/security/RateLimitFilter.java create mode 100644 src/test/java/de/zendric/app/xpensely_Server/security/RateLimitFilterTest.java diff --git a/src/main/java/de/zendric/app/xpensely_server/security/RateLimitFilter.java b/src/main/java/de/zendric/app/xpensely_server/security/RateLimitFilter.java new file mode 100644 index 0000000..f04341e --- /dev/null +++ b/src/main/java/de/zendric/app/xpensely_server/security/RateLimitFilter.java @@ -0,0 +1,61 @@ +package de.zendric.app.xpensely_server.security; + +import io.github.bucket4j.Bandwidth; +import io.github.bucket4j.Bucket; +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.springframework.http.HttpStatus; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.web.filter.OncePerRequestFilter; + +import java.io.IOException; +import java.time.Duration; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +public class RateLimitFilter extends OncePerRequestFilter { + + private static final int REQUESTS_PER_MINUTE = 60; + + private final Map buckets = new ConcurrentHashMap<>(); + + @Override + protected void doFilterInternal(HttpServletRequest request, + HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { + String key = resolveKey(request); + Bucket bucket = buckets.computeIfAbsent(key, k -> newBucket()); + + if (bucket.tryConsume(1)) { + filterChain.doFilter(request, response); + } else { + response.setStatus(HttpStatus.TOO_MANY_REQUESTS.value()); + response.getWriter().write("Rate limit exceeded"); + } + } + + private String resolveKey(HttpServletRequest request) { + Authentication auth = SecurityContextHolder.getContext().getAuthentication(); + if (auth != null && auth.getPrincipal() instanceof Jwt jwt) { + return "user:" + jwt.getSubject(); + } + String ip = request.getHeader("X-Forwarded-For"); + if (ip != null && !ip.isBlank()) { + return "ip:" + ip.split(",")[0].trim(); + } + return "ip:" + request.getRemoteAddr(); + } + + private Bucket newBucket() { + return Bucket.builder() + .addLimit(Bandwidth.builder() + .capacity(REQUESTS_PER_MINUTE) + .refillGreedy(REQUESTS_PER_MINUTE, Duration.ofMinutes(1)) + .build()) + .build(); + } +} diff --git a/src/main/java/de/zendric/app/xpensely_server/security/SecurityConfig.java b/src/main/java/de/zendric/app/xpensely_server/security/SecurityConfig.java index d0d2465..1ae17fc 100644 --- a/src/main/java/de/zendric/app/xpensely_server/security/SecurityConfig.java +++ b/src/main/java/de/zendric/app/xpensely_server/security/SecurityConfig.java @@ -6,6 +6,7 @@ import org.springframework.context.annotation.Profile; import org.springframework.security.config.Customizer; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.oauth2.server.resource.web.BearerTokenAuthenticationFilter; import org.springframework.security.web.SecurityFilterChain; @Configuration @@ -31,6 +32,7 @@ public class SecurityConfig { .oauth2ResourceServer(oauth2 -> oauth2 .jwt(Customizer.withDefaults())) .oauth2Login(Customizer.withDefaults()) + .addFilterAfter(new RateLimitFilter(), BearerTokenAuthenticationFilter.class) .csrf().disable(); return http.build(); diff --git a/src/test/java/de/zendric/app/xpensely_Server/security/RateLimitFilterTest.java b/src/test/java/de/zendric/app/xpensely_Server/security/RateLimitFilterTest.java new file mode 100644 index 0000000..825832b --- /dev/null +++ b/src/test/java/de/zendric/app/xpensely_Server/security/RateLimitFilterTest.java @@ -0,0 +1,89 @@ +package de.zendric.app.xpensely_Server.security; + +import de.zendric.app.xpensely_server.security.RateLimitFilter; +import jakarta.servlet.FilterChain; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.*; + +class RateLimitFilterTest { + + RateLimitFilter filter; + FilterChain chain; + + @BeforeEach + void setUp() { + filter = new RateLimitFilter(); + chain = mock(FilterChain.class); + } + + @Test + void allowsRequestUnderLimit() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setRemoteAddr("1.2.3.4"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + filter.doFilter(request, response, chain); + + verify(chain, times(1)).doFilter(request, response); + assertEquals(200, response.getStatus()); + } + + @Test + void blocksRequestOverLimit() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setRemoteAddr("5.6.7.8"); + + for (int i = 0; i < 60; i++) { + filter.doFilter(request, new MockHttpServletResponse(), chain); + } + + MockHttpServletResponse blockedResponse = new MockHttpServletResponse(); + filter.doFilter(request, blockedResponse, chain); + + assertEquals(429, blockedResponse.getStatus()); + verify(chain, times(60)).doFilter(eq(request), any()); + } + + @Test + void differentIpsBucketedSeparately() throws Exception { + MockHttpServletRequest req1 = new MockHttpServletRequest(); + req1.setRemoteAddr("10.0.0.1"); + MockHttpServletRequest req2 = new MockHttpServletRequest(); + req2.setRemoteAddr("10.0.0.2"); + + for (int i = 0; i < 60; i++) { + filter.doFilter(req1, new MockHttpServletResponse(), chain); + } + + MockHttpServletResponse response2 = new MockHttpServletResponse(); + filter.doFilter(req2, response2, chain); + + assertEquals(200, response2.getStatus()); + } + + @Test + void prefersXForwardedForHeader() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setRemoteAddr("192.168.1.1"); + request.addHeader("X-Forwarded-For", "203.0.113.5, 10.0.0.1"); + + for (int i = 0; i < 60; i++) { + filter.doFilter(request, new MockHttpServletResponse(), chain); + } + + MockHttpServletResponse blocked = new MockHttpServletResponse(); + filter.doFilter(request, blocked, chain); + assertEquals(429, blocked.getStatus()); + + MockHttpServletRequest directRequest = new MockHttpServletRequest(); + directRequest.setRemoteAddr("192.168.1.1"); + MockHttpServletResponse directResponse = new MockHttpServletResponse(); + filter.doFilter(directRequest, directResponse, chain); + assertEquals(200, directResponse.getStatus()); + } +}