// TunnelClient.java

package jasper.component;

import jasper.domain.proj.HasTags;
import jasper.errors.InvalidTunnelException;
import jasper.errors.RetryableTunnelException;
import jasper.repository.UserRepository;
import org.apache.sshd.client.SshClient;
import org.apache.sshd.client.auth.keyboard.UserInteraction;
import org.apache.sshd.client.keyverifier.ServerKeyVerifier;
import org.apache.sshd.client.session.ClientSession;
import org.apache.sshd.common.SshException;
import org.apache.sshd.common.config.keys.KeyUtils;
import org.apache.sshd.common.session.Session;
import org.apache.sshd.common.session.SessionListener;
import org.apache.sshd.common.util.net.SshdSocketAddress;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;

import java.io.ByteArrayInputStream;
import java.net.URI;
import java.net.URISyntaxException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;

import static jasper.domain.proj.HasTags.authors;
import static jasper.domain.proj.HasTags.hasMatchingTag;
import static jasper.domain.proj.Tag.defaultOrigin;
import static jasper.domain.proj.Tag.reverseOrigin;
import static jasper.plugin.Origin.getOrigin;
import static jasper.plugin.Tunnel.getTunnel;
import static jasper.util.Logging.getMessage;
import static org.apache.commons.lang3.StringUtils.isNotBlank;
import static org.apache.sshd.common.NamedResource.ofName;
import static org.apache.sshd.common.util.security.SecurityUtils.loadKeyPairIdentities;

@Component
public class TunnelClient {
	private static final Logger logger = LoggerFactory.getLogger(TunnelClient.class);

	// Used to schedule the delayed cleanup that follows every tunnel release.
	@Autowired
	TaskScheduler taskScheduler;

	// Source of the user record whose private key authenticates the SSH session.
	@Autowired
	UserRepository userRepository;

	// Used to attach tunnel logs/errors to the remote and to persist tunnel plugin changes.
	@Autowired
	Tagger tagger;

	/**
	 * One entry of the tunnel pool.
	 *
	 * @param tunnelPort  local port reported as bound by the port forwarding tracker
	 * @param connections number of callers currently holding this tunnel; incremented
	 *                    when the tunnel is reused and decremented on release
	 * @param client      SSH client that owns the tunnel
	 */
	record TunnelInfo(int tunnelPort, int connections, SshClient client) {}
	// Tunnel pool, keyed by "username@host:port" of the SSH endpoint.
	Map<String, TunnelInfo> tunnels = new ConcurrentHashMap<>();

	/**
	 * Reports every pooled SSH tunnel together with its current connection count.
	 * Scheduled with a 30 minute fixed delay after an initial 10 minute delay.
	 */
	@Scheduled(fixedDelay = 30, initialDelay = 10, timeUnit = TimeUnit.MINUTES)
	public void log() {
		tunnels.forEach((remote, info) ->
			logger.info("SSH Tunnel Pool: {} with {} connections", remote, info.connections()));
	}

	/**
	 * Sweeps the tunnel pool (scheduled with a one minute fixed delay) and evicts
	 * every entry whose SSH client is closed or fails the version probe.
	 */
	@Scheduled(fixedDelay = 1, timeUnit = TimeUnit.MINUTES)
	public void healthCheck() {
		for (var remote : tunnels.keySet()) {
			tunnels.compute(remote, (key, info) -> inspect(remote, info));
		}
	}

	/**
	 * Decides whether a single pool entry survives the health check.
	 *
	 * @param remote pool key of the entry, used for logging
	 * @param info   current pool entry, or null if it vanished since the key was read
	 * @return the entry to keep it in the pool, or null to remove it
	 */
	private TunnelInfo inspect(String remote, TunnelInfo info) {
		if (info == null) {
			logger.debug("Health check found null entry for {}", remote);
			return null;
		}
		var client = info.client();
		var connections = info.connections();
		if (client.isOpen()) {
			// Test SSH connection is responding
			try {
				client.getVersion();
				logger.debug("Healthy connection for {} with {} connections", remote, connections);
				return info;
			} catch (Exception e) {
				logger.warn("Failed connection test for {} with {} connections: {}", remote, connections, getMessage(e));
			}
		} else {
			logger.warn("Found closed client for {} with {} connections", remote, connections);
		}
		// Both failure paths end the same way: stop the client and drop the entry.
		client.stop();
		return null;
	}

	/**
	 * Runs a request against a remote, routing it through a pooled SSH tunnel
	 * when the remote is tagged {@code +plugin/origin/tunnel}.
	 *
	 * <p>Without the tunnel tag the request goes straight to the configured proxy
	 * URL (or the remote URL if no proxy is set). With the tag, the private key of
	 * the remote's first author is used to open (or reuse) a tunnel and the
	 * request is pointed at the local tunnel port; the tunnel is released after
	 * the request returns and killed if the request throws.
	 *
	 * <p>{@link RetryableTunnelException} and {@link InvalidTunnelException} are
	 * logged and attached to the remote instead of being rethrown.
	 *
	 * @param remote  the remote whose origin/tunnel configuration is used
	 * @param request callback that performs the actual request against the given URI
	 */
	public void proxy(HasTags remote, ProxyRequest request) {
		try {
			var config = getOrigin(remote);
			var address = isNotBlank(config.getProxy()) ? config.getProxy() : remote.getUrl();
			URI url;
			try {
				url = new URI(address);
			} catch (URISyntaxException e) {
				throw new InvalidTunnelException("Error parsing tunnel URI", e);
			}
			if (!hasMatchingTag(remote, "+plugin/origin/tunnel")) {
				// No tunnel configured: talk to the target directly.
				request.go(url);
				return;
			}
			var users = authors(remote);
			if (users.isEmpty()) {
				throw new InvalidTunnelException("Tunnel requested, but no user signature to lookup private key.");
			}
			var author = users.get(0);
			var user = userRepository.findOneByQualifiedTag(author + remote.getOrigin());
			if (user.isEmpty() || user.get().getKey() == null) {
				throw new InvalidTunnelException("Tunnel requested, but user " + author + " does not have a private key set.");
			}
			var tunnel = getTunnel(remote);
			var host = isNotBlank(tunnel.getSshHost()) ? tunnel.getSshHost() : url.getHost();
			var remoteUser = isNotBlank(tunnel.getRemoteUser()) ? tunnel.getRemoteUser() : user.get().getTag();
			var username = linuxUsername(defaultOrigin(remoteUser, config.getRemote()));
			var port = tunnel.getSshPort();
			var tunnelPort = pooledConnection(remote.getOrigin(), host, username, port, serverKeyVerifier(remote), user.get().getKey());
			try {
				request.go(new URI("http://localhost:" + tunnelPort));
			} catch (Exception e) {
				// Any failure while using the tunnel tears it down for this endpoint.
				killTunnel(host, username, port);
				throw new InvalidTunnelException("Error creating tunnel tracker", e);
			}
			releaseTunnel(tunnelPort, host, username, port);
		} catch (RetryableTunnelException e) {
			var summary = "Error creating SSH tunnel for %s: %s".formatted(remote.getTitle(), remote.getUrl());
			logger.info("{} Error creating SSH tunnel for {}: {}", remote.getOrigin(), remote.getTitle(), remote.getUrl());
			tagger.attachLogs(remote.getUrl(), remote.getOrigin(), summary, getMessage(e));
		} catch (InvalidTunnelException e) {
			var summary = "Fatal error creating SSH tunnel for %s: %s".formatted(remote.getTitle(), remote.getUrl());
			logger.error("{} Fatal error creating SSH tunnel for {}: {}", remote.getOrigin(), remote.getTitle(), remote.getUrl());
			tagger.attachError(remote.getUrl(), remote.getOrigin(), summary, getMessage(e));
		}
	}

	/**
	 * Resolves the URI a caller should use to reach a remote, reserving a pooled
	 * SSH tunnel when the remote is tagged {@code +plugin/origin/tunnel}.
	 *
	 * <p>Without the tunnel tag the configured proxy URL (or the remote URL if no
	 * proxy is set) is returned as is. With the tag, a tunnel is opened or reused
	 * with the private key of the remote's first author and a
	 * {@code http://localhost:<port>} URI pointing at the tunnel is returned.
	 *
	 * @param remote the remote whose origin/tunnel configuration is used
	 * @return the direct URI, or the local URI of the reserved tunnel
	 * @throws RetryableTunnelException rethrown after being logged and attached to the remote
	 * @throws InvalidTunnelException   rethrown after being logged and attached to the remote
	 */
	public URI reserveProxy(HasTags remote) throws RetryableTunnelException {
		try {
			var config = getOrigin(remote);
			var address = isNotBlank(config.getProxy()) ? config.getProxy() : remote.getUrl();
			URI url;
			try {
				url = new URI(address);
			} catch (URISyntaxException e) {
				throw new InvalidTunnelException("Error parsing tunnel URI", e);
			}
			var tunnelled = hasMatchingTag(remote, "+plugin/origin/tunnel");
			if (!tunnelled) {
				return url;
			}
			var users = authors(remote);
			if (users.isEmpty()) {
				throw new InvalidTunnelException("Tunnel requested, but no user signature to lookup private key.");
			}
			var author = users.get(0);
			var user = userRepository.findOneByQualifiedTag(author + remote.getOrigin());
			if (user.isEmpty() || user.get().getKey() == null) {
				throw new InvalidTunnelException("Tunnel requested, but user " + author + " does not have a private key set.");
			}
			var tunnel = getTunnel(remote);
			var host = isNotBlank(tunnel.getSshHost()) ? tunnel.getSshHost() : url.getHost();
			var remoteUser = isNotBlank(tunnel.getRemoteUser()) ? tunnel.getRemoteUser() : user.get().getTag();
			var username = linuxUsername(defaultOrigin(remoteUser, config.getRemote()));
			var port = tunnel.getSshPort();
			var tunnelPort = pooledConnection(remote.getOrigin(), host, username, port, serverKeyVerifier(remote), user.get().getKey());
			try {
				return new URI("http://localhost:" + tunnelPort);
			} catch (URISyntaxException e) {
				killTunnel(host, username, port);
				throw new InvalidTunnelException("Error creating tunnel tracker", e);
			}
		} catch (RetryableTunnelException e) {
			var summary = "Error creating SSH tunnel for %s: %s".formatted(remote.getTitle(), remote.getUrl());
			logger.info("{} Error creating SSH tunnel for {}: {}", remote.getOrigin(), remote.getTitle(), remote.getUrl());
			tagger.attachLogs(remote.getUrl(), remote.getOrigin(), summary, getMessage(e));
			throw e;
		} catch (InvalidTunnelException e) {
			var summary = "Fatal error creating SSH tunnel for %s: %s".formatted(remote.getTitle(), remote.getUrl());
			logger.error("{} Fatal error creating SSH tunnel for {}: {}", remote.getOrigin(), remote.getTitle(), remote.getUrl());
			tagger.attachError(remote.getUrl(), remote.getOrigin(), summary, getMessage(e));
			throw e;
		}
	}

	/**
	 * Builds a host key verifier for the given remote.
	 *
	 * <p>If the remote's tunnel config already stores a host fingerprint, the
	 * presented server key must match it. Otherwise the presented fingerprint is
	 * accepted, stored in the tunnel config and saved through the tagger.
	 *
	 * @param remote the remote whose tunnel config holds the expected fingerprint
	 * @return a verifier that returns true when the server key is accepted
	 */
	private ServerKeyVerifier serverKeyVerifier(HasTags remote) {
		return (session, address, key) -> {
			var presented = KeyUtils.getFingerPrint(key);
			var tunnel = getTunnel(remote);
			if (tunnel.getHostFingerprint() != null) {
				return presented.equals(tunnel.getHostFingerprint());
			}
			// First contact: remember the fingerprint for later connections.
			tunnel.setHostFingerprint(presented);
			tagger.plugin(remote.getUrl(), remote.getOrigin(), "+plugin/origin/tunnel", tunnel);
			return true;
		};
	}

	/**
	 * Releases one connection on the pooled tunnel of a remote that was
	 * previously obtained through {@code reserveProxy}.
	 *
	 * <p>Does nothing for remotes without the {@code +plugin/origin/tunnel} tag,
	 * and nothing when the SSH username cannot be determined (no explicit remote
	 * user and no authors on the remote).
	 *
	 * @param remote the remote whose tunnel connection is released
	 * @throws InvalidTunnelException if the proxy/remote URL cannot be parsed
	 */
	public void releaseProxy(HasTags remote) {
		var config = getOrigin(remote);
		URI url;
		try {
			url = new URI(isNotBlank(config.getProxy()) ? config.getProxy() : remote.getUrl());
		} catch (URISyntaxException e) {
			throw new InvalidTunnelException("Error parsing tunnel URI", e);
		}
		if (hasMatchingTag(remote, "+plugin/origin/tunnel")) {
			var users = authors(remote);
			var tunnel = getTunnel(remote);
			var remoteUser = tunnel.getRemoteUser();
			if (!isNotBlank(remoteUser) && users.isEmpty()) {
				// Reserving a tunnel is refused when the remote has no authors, so
				// there is nothing to release; avoid users.get(0) on an empty list.
				logger.debug("{} No tunnel user for {}, nothing to release", remote.getOrigin(), remote.getUrl());
				return;
			}
			var host = isNotBlank(tunnel.getSshHost()) ? tunnel.getSshHost() : url.getHost();
			var username = linuxUsername(defaultOrigin(isNotBlank(remoteUser) ? remoteUser : users.get(0), config.getRemote()));
			var port = tunnel.getSshPort();
			// null tunnel port: release whatever tunnel is pooled for this endpoint.
			releaseTunnel(null, host, username, port);
		}
	}

	/**
	 * Immediately stops and removes the pooled tunnel of a remote, regardless
	 * of how many connections are still using it.
	 *
	 * <p>Only remotes tagged {@code +plugin/origin/tunnel} have a tunnel; for any
	 * other remote this is a no-op. It is also a no-op when the SSH username
	 * cannot be determined (no explicit remote user and no authors on the remote).
	 *
	 * @param remote the remote whose tunnel is killed
	 * @throws InvalidTunnelException if the proxy/remote URL cannot be parsed
	 */
	public void killProxy(HasTags remote) {
		var config = getOrigin(remote);
		URI url;
		try {
			url = new URI(isNotBlank(config.getProxy()) ? config.getProxy() : remote.getUrl());
		} catch (URISyntaxException e) {
			throw new InvalidTunnelException("Error parsing tunnel URI", e);
		}
		// Same tag check as releaseProxy: only tunnel remotes have something to kill.
		if (hasMatchingTag(remote, "+plugin/origin/tunnel")) {
			var users = authors(remote);
			var tunnel = getTunnel(remote);
			var remoteUser = tunnel.getRemoteUser();
			if (!isNotBlank(remoteUser) && users.isEmpty()) {
				// No way to derive the pool key; avoid users.get(0) on an empty list.
				logger.debug("{} No tunnel user for {}, nothing to kill", remote.getOrigin(), remote.getUrl());
				return;
			}
			var host = isNotBlank(tunnel.getSshHost()) ? tunnel.getSshHost() : url.getHost();
			var username = linuxUsername(defaultOrigin(isNotBlank(remoteUser) ? remoteUser : users.get(0), config.getRemote()));
			var port = tunnel.getSshPort();
			this.killTunnel(host, username, port);
		}
	}

	/**
	 * Returns the local port of a pooled SSH tunnel to the given endpoint,
	 * opening a new tunnel if none is pooled or the pooled client is no longer open.
	 *
	 * <p>Reusing an open tunnel increments its connection count. A new tunnel
	 * authenticates with the given private key, forwards a local port (port 0 is
	 * requested, the bound port is read back from the tracker) to
	 * {@code localhost:<httpPort>} on the remote side, and starts with a
	 * connection count of 1. {@code httpPort} is parsed from the SSH banner and
	 * defaults to 38022 when the banner is not a number.
	 *
	 * @param origin            origin used as log prefix
	 * @param host              SSH host
	 * @param username          SSH login name
	 * @param port              SSH port
	 * @param serverKeyVerifier verifier applied to the server host key
	 * @param key               private key material used for public key auth
	 * @return the local port the tunnel is bound to
	 * @throws RetryableTunnelException if the tunnel could not be created for a
	 *                                  reason other than an {@link SshException}
	 * @throws InvalidTunnelException   if the tunnel could not be created because
	 *                                  of an {@link SshException}
	 */
	private int pooledConnection(String origin, String host, String username, int port, ServerKeyVerifier serverKeyVerifier, byte[] key) throws RetryableTunnelException {
		var remote = username + "@" + host + ":" + port;
		try {
			return tunnels.compute(remote, (k, v) -> {
				if (v != null) {
					if (v.client.isOpen()) return new TunnelInfo(v.tunnelPort, v.connections + 1, v.client);
					// Stale entry: stop the old client before replacing it so it is
					// not leaked (the health check does the same for closed clients).
					v.client.stop();
				}
				var client = SshClient.setUpDefaultClient();
				try {
					// Single-element array so the banner callback can update the port.
					final int[] httpPort = {38022};
					client.setUserInteraction(new GetBanner() {
						@Override
						public void banner(String banner) {
							logger.debug("Received SSH banner: {}", banner);
							try {
								httpPort[0] = Integer.parseInt(banner);
							} catch (Exception e) {
								logger.warn("{} Could not parse tunnel port from banner. Using default {}", origin, httpPort[0]);
							}
						}
					});
					client.setServerKeyVerifier(serverKeyVerifier);
					client.start();
					var session = client.connect(username, host, port).verify(30, TimeUnit.SECONDS).getSession();
					loadKeyPairIdentities(null, ofName(username), new ByteArrayInputStream(key), null)
						.forEach(session::addPublicKeyIdentity);
					session.auth().verify(30, TimeUnit.SECONDS);
					var tracker = session.createLocalPortForwardingTracker(0, new SshdSocketAddress("localhost", httpPort[0]));
					var tunnelPort = tracker.getBoundAddress().getPort();
					// Drop the pool entry as soon as the SSH session goes away.
					client.addSessionListener(new SessionListener() {
						@Override
						public void sessionClosed(Session session) {
							logger.debug("{} SSH session closed for {}", origin, remote);
							killTunnel(host, username, port);
						}
					});
					return new TunnelInfo(tunnelPort, 1, client);
				} catch (Exception e) {
					client.stop();
					// The remapping function cannot throw checked exceptions; unwrapped below.
					throw new RuntimeException(e);
				}
			}).tunnelPort;
		} catch (RuntimeException e) {
			logger.debug("{} Error creating tunnel SSH client", origin, e);
			if (e.getCause() instanceof SshException) {
				// SSH-level failures are reported as fatal tunnel errors so callers'
				// InvalidTunnelException handlers see them instead of a raw RuntimeException.
				throw new InvalidTunnelException("Error creating tunnel SSH client", e);
			}
			throw new RetryableTunnelException("Error creating tunnel SSH client", e);
		}
	}

	/**
	 * Decrements the connection count of the pooled tunnel for an endpoint and
	 * schedules a cleanup pass one minute later.
	 *
	 * @param tunnelPort local tunnel port the caller was using, or null to
	 *                   release whatever tunnel is pooled for the endpoint; a
	 *                   non-null port that differs from the pooled one leaves the
	 *                   entry untouched
	 * @param host       SSH host
	 * @param username   SSH login name
	 * @param port       SSH port
	 */
	private void releaseTunnel(Integer tunnelPort, String host, String username, int port) {
		var key = username + "@" + host + ":" + port;
		tunnels.computeIfPresent(key, (k, info) -> {
			var otherTunnel = tunnelPort != null && info.tunnelPort() != tunnelPort;
			if (otherTunnel) {
				return info;
			}
			return new TunnelInfo(info.tunnelPort(), info.connections() - 1, info.client());
		});
		// The cleanup is scheduled even when no entry was found.
		var cleanupAt = Instant.now().plus(1, ChronoUnit.MINUTES);
		taskScheduler.schedule(() -> cleanupTunnel(tunnelPort, host, username, port), cleanupAt);
	}

	/**
	 * Stops and removes the pooled tunnel for an endpoint if it has no
	 * connections left.
	 *
	 * @param tunnelPort local tunnel port the cleanup was scheduled for, or null
	 *                   to consider whatever tunnel is pooled; a non-null port
	 *                   that differs from the pooled one leaves the entry untouched
	 * @param host       SSH host
	 * @param username   SSH login name
	 * @param port       SSH port
	 */
	private void cleanupTunnel(Integer tunnelPort, String host, String username, int port) {
		var key = username + "@" + host + ":" + port;
		tunnels.computeIfPresent(key, (k, info) -> {
			var sameTunnel = tunnelPort == null || info.tunnelPort() == tunnelPort;
			if (sameTunnel && info.connections() <= 0) {
				info.client().stop();
				return null;
			}
			return info;
		});
	}

	/**
	 * Stops the SSH client of the pooled tunnel for an endpoint and removes the
	 * entry from the pool, whatever its connection count. No-op if nothing is pooled.
	 *
	 * @param host     SSH host
	 * @param username SSH login name
	 * @param port     SSH port
	 */
	private void killTunnel(String host, String username, int port) {
		var key = username + "@" + host + ":" + port;
		tunnels.computeIfPresent(key, (k, info) -> {
			info.client().stop();
			// Returning null removes the mapping.
			return null;
		});
	}

	/**
	 * Callback that performs a request against the URI chosen by
	 * {@code proxy}: either the direct target or the local end of an SSH tunnel.
	 */
	@FunctionalInterface
	public interface ProxyRequest {

		/**
		 * Performs the request.
		 *
		 * @param url the URI to send the request to
		 */
		void go(URI url);
	}

	/**
	 * Derives the SSH login name from a qualified tag.
	 *
	 * <p>The tag is first passed through {@code reverseOrigin}; then underscores
	 * and plus signs are removed, dots become hyphens and slashes become underscores.
	 *
	 * @param qualifiedTag the qualified tag to convert
	 * @return the derived login name
	 */
	private String linuxUsername(String qualifiedTag) {
		String name = reverseOrigin(qualifiedTag);
		// Existing underscores must go before slashes are turned into underscores.
		name = name.replace("_", "");
		name = name.replace("+", "");
		name = name.replace('.', '-');
		return name.replace('/', '_');
	}

	/**
	 * {@link UserInteraction} that only cares about the SSH welcome banner.
	 * The trimmed banner text is handed to {@link #banner(String)}; interactive
	 * prompts get an empty answer array and password updates get null.
	 */
	private abstract static class GetBanner implements UserInteraction {

		/**
		 * Receives the trimmed welcome banner.
		 *
		 * @param banner banner text with leading and trailing whitespace removed
		 */
		public abstract void banner(String banner);

		@Override
		public String getUpdatedPassword(ClientSession clientSession, String passwordPrompt, String language) {
			return null;
		}

		@Override
		public String[] interactive(ClientSession clientSession, String title, String details, String language, String[] prompts, boolean[] echoes) {
			return new String[0];
		}

		@Override
		public void welcome(ClientSession clientSession, String text, String language) {
			var trimmed = text.trim();
			banner(trimmed);
		}
	}
}