security hardening #12
@@ -0,0 +1,61 @@
|
|||||||
|
package de.zendric.app.xpensely_server.security;
|
||||||
|
|
||||||
|
import io.github.bucket4j.Bandwidth;
|
||||||
|
import io.github.bucket4j.Bucket;
|
||||||
|
import jakarta.servlet.FilterChain;
|
||||||
|
import jakarta.servlet.ServletException;
|
||||||
|
import jakarta.servlet.http.HttpServletRequest;
|
||||||
|
import jakarta.servlet.http.HttpServletResponse;
|
||||||
|
import org.springframework.http.HttpStatus;
|
||||||
|
import org.springframework.security.core.Authentication;
|
||||||
|
import org.springframework.security.core.context.SecurityContextHolder;
|
||||||
|
import org.springframework.security.oauth2.jwt.Jwt;
|
||||||
|
import org.springframework.web.filter.OncePerRequestFilter;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
|
||||||
|
public class RateLimitFilter extends OncePerRequestFilter {
|
||||||
|
|
||||||
|
private static final int REQUESTS_PER_MINUTE = 60;
|
||||||
|
|
||||||
|
private final Map<String, Bucket> buckets = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void doFilterInternal(HttpServletRequest request,
|
||||||
|
HttpServletResponse response,
|
||||||
|
FilterChain filterChain) throws ServletException, IOException {
|
||||||
|
String key = resolveKey(request);
|
||||||
|
Bucket bucket = buckets.computeIfAbsent(key, k -> newBucket());
|
||||||
|
|
||||||
|
if (bucket.tryConsume(1)) {
|
||||||
|
filterChain.doFilter(request, response);
|
||||||
|
} else {
|
||||||
|
response.setStatus(HttpStatus.TOO_MANY_REQUESTS.value());
|
||||||
|
response.getWriter().write("Rate limit exceeded");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private String resolveKey(HttpServletRequest request) {
|
||||||
|
Authentication auth = SecurityContextHolder.getContext().getAuthentication();
|
||||||
|
if (auth != null && auth.getPrincipal() instanceof Jwt jwt) {
|
||||||
|
return "user:" + jwt.getSubject();
|
||||||
|
}
|
||||||
|
String ip = request.getHeader("X-Forwarded-For");
|
||||||
|
if (ip != null && !ip.isBlank()) {
|
||||||
|
return "ip:" + ip.split(",")[0].trim();
|
||||||
|
}
|
||||||
|
return "ip:" + request.getRemoteAddr();
|
||||||
|
}
|
||||||
|
|
||||||
|
private Bucket newBucket() {
|
||||||
|
return Bucket.builder()
|
||||||
|
.addLimit(Bandwidth.builder()
|
||||||
|
.capacity(REQUESTS_PER_MINUTE)
|
||||||
|
.refillGreedy(REQUESTS_PER_MINUTE, Duration.ofMinutes(1))
|
||||||
|
.build())
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,6 +6,7 @@ import org.springframework.context.annotation.Profile;
|
|||||||
import org.springframework.security.config.Customizer;
|
import org.springframework.security.config.Customizer;
|
||||||
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
|
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
|
||||||
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
|
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
|
||||||
|
import org.springframework.security.oauth2.server.resource.web.BearerTokenAuthenticationFilter;
|
||||||
import org.springframework.security.web.SecurityFilterChain;
|
import org.springframework.security.web.SecurityFilterChain;
|
||||||
|
|
||||||
@Configuration
|
@Configuration
|
||||||
@@ -31,6 +32,7 @@ public class SecurityConfig {
|
|||||||
.oauth2ResourceServer(oauth2 -> oauth2
|
.oauth2ResourceServer(oauth2 -> oauth2
|
||||||
.jwt(Customizer.withDefaults()))
|
.jwt(Customizer.withDefaults()))
|
||||||
.oauth2Login(Customizer.withDefaults())
|
.oauth2Login(Customizer.withDefaults())
|
||||||
|
.addFilterAfter(new RateLimitFilter(), BearerTokenAuthenticationFilter.class)
|
||||||
.csrf().disable();
|
.csrf().disable();
|
||||||
|
|
||||||
return http.build();
|
return http.build();
|
||||||
|
|||||||
@@ -0,0 +1,89 @@
|
|||||||
|
package de.zendric.app.xpensely_Server.security;
|
||||||
|
|
||||||
|
import de.zendric.app.xpensely_server.security.RateLimitFilter;
|
||||||
|
import jakarta.servlet.FilterChain;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.springframework.mock.web.MockHttpServletRequest;
|
||||||
|
import org.springframework.mock.web.MockHttpServletResponse;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.mockito.Mockito.*;
|
||||||
|
|
||||||
|
class RateLimitFilterTest {
|
||||||
|
|
||||||
|
RateLimitFilter filter;
|
||||||
|
FilterChain chain;
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() {
|
||||||
|
filter = new RateLimitFilter();
|
||||||
|
chain = mock(FilterChain.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void allowsRequestUnderLimit() throws Exception {
|
||||||
|
MockHttpServletRequest request = new MockHttpServletRequest();
|
||||||
|
request.setRemoteAddr("1.2.3.4");
|
||||||
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||||
|
|
||||||
|
filter.doFilter(request, response, chain);
|
||||||
|
|
||||||
|
verify(chain, times(1)).doFilter(request, response);
|
||||||
|
assertEquals(200, response.getStatus());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void blocksRequestOverLimit() throws Exception {
|
||||||
|
MockHttpServletRequest request = new MockHttpServletRequest();
|
||||||
|
request.setRemoteAddr("5.6.7.8");
|
||||||
|
|
||||||
|
for (int i = 0; i < 60; i++) {
|
||||||
|
filter.doFilter(request, new MockHttpServletResponse(), chain);
|
||||||
|
}
|
||||||
|
|
||||||
|
MockHttpServletResponse blockedResponse = new MockHttpServletResponse();
|
||||||
|
filter.doFilter(request, blockedResponse, chain);
|
||||||
|
|
||||||
|
assertEquals(429, blockedResponse.getStatus());
|
||||||
|
verify(chain, times(60)).doFilter(eq(request), any());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void differentIpsBucketedSeparately() throws Exception {
|
||||||
|
MockHttpServletRequest req1 = new MockHttpServletRequest();
|
||||||
|
req1.setRemoteAddr("10.0.0.1");
|
||||||
|
MockHttpServletRequest req2 = new MockHttpServletRequest();
|
||||||
|
req2.setRemoteAddr("10.0.0.2");
|
||||||
|
|
||||||
|
for (int i = 0; i < 60; i++) {
|
||||||
|
filter.doFilter(req1, new MockHttpServletResponse(), chain);
|
||||||
|
}
|
||||||
|
|
||||||
|
MockHttpServletResponse response2 = new MockHttpServletResponse();
|
||||||
|
filter.doFilter(req2, response2, chain);
|
||||||
|
|
||||||
|
assertEquals(200, response2.getStatus());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void prefersXForwardedForHeader() throws Exception {
|
||||||
|
MockHttpServletRequest request = new MockHttpServletRequest();
|
||||||
|
request.setRemoteAddr("192.168.1.1");
|
||||||
|
request.addHeader("X-Forwarded-For", "203.0.113.5, 10.0.0.1");
|
||||||
|
|
||||||
|
for (int i = 0; i < 60; i++) {
|
||||||
|
filter.doFilter(request, new MockHttpServletResponse(), chain);
|
||||||
|
}
|
||||||
|
|
||||||
|
MockHttpServletResponse blocked = new MockHttpServletResponse();
|
||||||
|
filter.doFilter(request, blocked, chain);
|
||||||
|
assertEquals(429, blocked.getStatus());
|
||||||
|
|
||||||
|
MockHttpServletRequest directRequest = new MockHttpServletRequest();
|
||||||
|
directRequest.setRemoteAddr("192.168.1.1");
|
||||||
|
MockHttpServletResponse directResponse = new MockHttpServletResponse();
|
||||||
|
filter.doFilter(directRequest, directResponse, chain);
|
||||||
|
assertEquals(200, directResponse.getStatus());
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user