WebSocketConfig.java

package jasper.config;

import jakarta.servlet.http.HttpServletRequest;
import jasper.component.ConfigCache;
import jasper.domain.proj.HasOrigin;
import jasper.security.Auth;
import jasper.security.jwt.TokenProvider;
import jasper.security.jwt.TokenProviderImplDefault;
import jasper.service.dto.UserDto;
import org.apache.tomcat.websocket.server.WsSci;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.tomcat.TomcatContextCustomizer;
import org.springframework.boot.tomcat.servlet.TomcatServletWebServerFactory;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Profile;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.integration.annotation.ServiceActivator;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.config.ChannelRegistration;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.security.core.Authentication;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketTransportRegistration;
import org.springframework.web.socket.handler.WebSocketHandlerDecorator;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;

import java.io.IOException;
import java.security.Principal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import static jasper.domain.proj.HasOrigin.isSubOrigin;
import static jasper.security.Auth.LOCAL_ORIGIN_HEADER;
import static jasper.security.Auth.READ_ACCESS_HEADER;
import static jasper.security.Auth.TAG_READ_ACCESS_HEADER;
import static jasper.security.Auth.TAG_WRITE_ACCESS_HEADER;
import static jasper.security.Auth.USER_ROLE_HEADER;
import static jasper.security.Auth.USER_TAG_HEADER;
import static jasper.security.Auth.WRITE_ACCESS_HEADER;
import static org.apache.commons.lang3.StringUtils.isNotBlank;

@Profile("!no-websocket")
@Configuration
@EnableWebSocketMessageBroker
public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
	private static final Logger logger = LoggerFactory.getLogger(WebSocketConfig.class);

	public static final String AUTHORIZATION_HEADER = "Authorization";

	@Autowired
	Props props;

	@Autowired
	ConfigCache configs;

	@Autowired
	TokenProvider tokenProvider;

	@Autowired
	TokenProviderImplDefault defaultTokenProvider;

	@Autowired
	@Qualifier("authSingleton")
	Auth auth;

	private Set<WebSocketSession> sessions = ConcurrentHashMap.newKeySet();

	@Bean
	public TomcatServletWebServerFactory tomcatContainerFactory() {
		var factory = new TomcatServletWebServerFactory();
		factory.addContextCustomizers(tomcatContextCustomizer());
		return factory;
	}

	@Bean
	public TomcatContextCustomizer tomcatContextCustomizer() {
		return context -> context.addServletContainerInitializer(new WsSci(), null);
	}

	@Override
	public void configureMessageBroker(MessageBrokerRegistry config) {
		config.enableSimpleBroker("/topic");
		config.setApplicationDestinationPrefixes("/app");
	}

	@Override
	public void registerStompEndpoints(StompEndpointRegistry registry) {
		registry
			.addEndpoint("/api/stomp/")
				.setHandshakeHandler(new StompDefaultHandshakeHandler())
				.addInterceptors(new StompHandshakeInterceptor())
				.setAllowedOriginPatterns("*")
			.withSockJS()
				.setSuppressCors(true);
	}

	@Override
	public void configureClientInboundChannel(ChannelRegistration registration) {
		registration.interceptors(new JwtChannelInterceptor());
	}

	@Override
	public void configureWebSocketTransport(WebSocketTransportRegistration registration) {
		registration.addDecoratorFactory(handler -> new WebSocketHandlerDecorator(handler) {
			@Override
			public void afterConnectionEstablished(WebSocketSession session) throws Exception {
				sessions.add(session);
				super.afterConnectionEstablished(session);
			}
			@Override
			public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
				sessions.remove(session);
				super.afterConnectionClosed(session, closeStatus);
			}

		});
	}

	@ServiceActivator(inputChannel = "userRxChannel")
	public void handleUserUpdate(Message<UserDto> message) {
		// Just drop all sessions for now
		sessions.forEach(s -> {
            try {
                s.close(CloseStatus.SERVICE_RESTARTED);
            } catch (IOException e) {
                logger.warn("Could not close websocket session.", e);
            }
        });
		sessions.clear();
	}

	class StompHandshakeInterceptor implements HandshakeInterceptor {

		private String resolveOrigin(HttpServletRequest request) {
			var originHeader = request.getHeader(LOCAL_ORIGIN_HEADER);
			logger.debug("STOMP Local Origin Header: {}", originHeader);
			if (isNotBlank(originHeader)) {
				originHeader = originHeader.toLowerCase();
				if ("default".equals(originHeader)) return props.getLocalOrigin();
				if (originHeader.matches(HasOrigin.REGEX) && isSubOrigin(props.getLocalOrigin(), originHeader)) {
					return originHeader;
				}
			}
			return props.getOrigin();
		}

		@Override
		public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) {
			logger.debug("STOMP Handshake");
			if (request instanceof ServletServerHttpRequest servletRequest) {
				var httpServletRequest = servletRequest.getServletRequest();
				var token = httpServletRequest.getHeader(AUTHORIZATION_HEADER);
				if (isNotBlank(token) && token.startsWith("Bearer ")) {
					attributes.put("jwt", token.substring("Bearer ".length()));
				}
				attributes.put("origin", resolveOrigin(httpServletRequest));
				// Create WebSocket request attributes for proxy-controlled headers
				var wsAttributes = new WebSocketRequestAttributes();
				wsAttributes.setHeader(LOCAL_ORIGIN_HEADER, httpServletRequest.getHeader(LOCAL_ORIGIN_HEADER));
				wsAttributes.setHeader(USER_ROLE_HEADER, httpServletRequest.getHeader(USER_ROLE_HEADER));
				wsAttributes.setHeader(USER_TAG_HEADER, httpServletRequest.getHeader(USER_TAG_HEADER));
				wsAttributes.setHeader(READ_ACCESS_HEADER, httpServletRequest.getHeader(READ_ACCESS_HEADER));
				wsAttributes.setHeader(WRITE_ACCESS_HEADER, httpServletRequest.getHeader(WRITE_ACCESS_HEADER));
				wsAttributes.setHeader(TAG_READ_ACCESS_HEADER, httpServletRequest.getHeader(TAG_READ_ACCESS_HEADER));
				wsAttributes.setHeader(TAG_WRITE_ACCESS_HEADER, httpServletRequest.getHeader(TAG_WRITE_ACCESS_HEADER));
				attributes.put("wsAttributes", wsAttributes);
			}
			return true;
		}

		@Override
		public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) { }
	}

	class StompDefaultHandshakeHandler extends DefaultHandshakeHandler {
		@Override
		public Principal determineUser(ServerHttpRequest request, WebSocketHandler handler, Map<String, Object> attributes) {
			var origin = (String) attributes.get("origin");
			WebSocketConfig.logger.debug("{} STOMP Determine User", origin);
			if (!configs.root().web(origin)) {
				WebSocketConfig.logger.error("{} No web access for origin", origin);
				return null;
			}
			try {
				var wsAttributes = (WebSocketRequestAttributes) attributes.get("wsAttributes");
				RequestContextHolder.setRequestAttributes(wsAttributes);
				if (!attributes.containsKey("jwt")) return defaultTokenProvider.getAuthentication(null, origin);
				var token = (String) attributes.get("jwt");
				return tokenProvider.validateToken(token, origin) ? tokenProvider.getAuthentication(token, origin) : defaultTokenProvider.getAuthentication(null, origin);
			} finally {
				RequestContextHolder.resetRequestAttributes();
			}
		}
	}

	class JwtChannelInterceptor implements ChannelInterceptor {
		@Override
		public Message<?> preSend(Message<?> message, MessageChannel channel) {
			try {
				var accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
				if (accessor.getCommand() == StompCommand.BEGIN) return null; // No Transactions
				if (accessor.getCommand() == StompCommand.SEND) return null; // No Client Messages
				if (accessor.getCommand() != StompCommand.SUBSCRIBE) return message;
				var wsAttributes = (WebSocketRequestAttributes) accessor.getSessionAttributes().get("wsAttributes");
				RequestContextHolder.setRequestAttributes(wsAttributes);
				if (accessor.getUser() instanceof Authentication authentication) {
					auth.clear(authentication);
					logger.debug("{} STOMP User Set {}", auth.getOrigin(), auth.getUserTag());
					@SuppressWarnings("unchecked")
					var headers = (Map<String, ArrayList<String>>) message.getHeaders().get("nativeHeaders", Map.class);
					if (headers != null && headers.get("jwt") != null && headers.get("jwt").size() > 0) {
						var token = headers.get("jwt").get(0);
						var origin = auth.getOrigin();
						if  (tokenProvider.validateToken(token, origin)) {
							auth.clear(tokenProvider.getAuthentication(token, origin));
							logger.debug("{} STOMP SUBSCRIBE Credentials Header {}", auth.getOrigin(), auth.getUserTag());
						}
					}
				} else {
					auth.clear(defaultTokenProvider.getAuthentication(null, props.getOrigin()));
					logger.debug("{} STOMP Default auth {}", auth.getOrigin(), auth.getUserTag());
				}
				if (!configs.root().web(auth.getOrigin())) {
					logger.error("{} No web access for origin", auth.getOrigin());
					return null;
				}
				if (auth.canSubscribeTo(accessor.getDestination())) return message;
				logger.warn("{} {} can't subscribe to {}", auth.getOrigin(), auth.getUserTag(), accessor.getDestination());
			} catch (Exception e) {
				logger.error("{} Cannot authorize websocket subscription.", auth.getOrigin(), e);
			} finally {
				RequestContextHolder.resetRequestAttributes();
			}
			return null;
		}
	}

	public static class WebSocketRequestAttributes implements RequestAttributes {
		private final Map<String, String> headers = new HashMap<>();

		public void setHeader(String name, String value) {
			headers.put(name, value);
		}

		public String getHeader(String name) {
			return headers.get(name);
		}

		@Override
		public Object getAttribute(String name, int scope) {
			return headers.get(name);
		}

		@Override
		public void setAttribute(String name, Object value, int scope) {
			headers.put(name, (String)value);
		}

		@Override
		public void removeAttribute(String name, int scope) {

		}

		@Override
		public String[] getAttributeNames(int scope) {
			return new String[0];
		}

		@Override
		public void registerDestructionCallback(String name, Runnable callback, int scope) {}

		@Override
		public Object resolveReference(String key) { return null; }

		@Override
		public String getSessionId() { return null; }

		@Override
		public Object getSessionMutex() { return null; }
	}
}