// JwkSigningKeyResolver.java
package jasper.security.jwt;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.JwsHeader;
import io.jsonwebtoken.SigningKeyResolver;
import io.jsonwebtoken.io.Decoders;
import org.springframework.web.client.RestTemplate;
import java.math.BigInteger;
import java.net.URI;
import java.security.Key;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.RSAPublicKeySpec;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;
/**
 * Resolves the key used to verify a JWT signature by looking up the token's
 * key id ({@code kid}) in a JWK set fetched from a remote endpoint.
 *
 * <p>Only entries whose public-key use is {@code "sig"} and whose key type is
 * {@code "RSA"} are kept. The resolved keys are cached in a snapshot map that
 * is replaced wholesale on every refresh; a lookup that misses the cache
 * triggers a refresh, so keys published after the previous fetch are picked
 * up the first time they are requested.
 *
 * <p>Lookups read the current snapshot through a volatile reference without
 * locking; refreshes are serialized on a private lock.
 */
class JwkSigningKeyResolver implements SigningKeyResolver {

    private final URI jwkUri;
    private final RestTemplate restTemplate;
    /** Monitor guarding refreshes; final so every thread synchronizes on the same object. */
    private final Object lock = new Object();
    /** Current key snapshot, keyed by key id. Never mutated in place, only replaced. */
    private volatile Map<String, Key> keyMap = new HashMap<>();

    /**
     * @param jwkUri       location of the JWK set document
     * @param restTemplate client used to fetch the JWK set
     */
    JwkSigningKeyResolver(URI jwkUri, RestTemplate restTemplate) {
        this.jwkUri = jwkUri;
        this.restTemplate = restTemplate;
    }

    /**
     * Resolves the verification key from the key id in the JWS header.
     * The claims are not consulted.
     *
     * @return the matching key, or {@code null} if the key id is unknown
     *         even after refreshing the key set
     */
    @Override
    public Key resolveSigningKey(JwsHeader header, Claims claims) {
        return getKey(header.getKeyId());
    }

    /**
     * Resolves the verification key from the key id in the JWS header.
     * The payload bytes are not consulted.
     *
     * @return the matching key, or {@code null} if the key id is unknown
     *         even after refreshing the key set
     */
    @Override
    public Key resolveSigningKey(JwsHeader header, byte[] content) {
        return getKey(header.getKeyId());
    }

    /**
     * Returns the cached key for {@code keyId}, refreshing the key set from
     * the remote endpoint once if the id is not in the current snapshot.
     *
     * @return the key, or {@code null} if it is absent after the refresh
     * @throws IllegalStateException if the refresh fails (see {@link #updateKeys()})
     */
    private Key getKey(String keyId) {
        // Fast path: read the current snapshot without taking the lock.
        Key result = keyMap.get(keyId);
        if (result != null) {
            return result;
        }
        synchronized (lock) {
            // Check again under the lock: a thread that held it before us
            // may already have refreshed the keys.
            result = keyMap.get(keyId);
            if (result != null) {
                return result;
            }
            // Still missing: refresh from the endpoint and return whatever
            // the new snapshot holds for this id (possibly null).
            updateKeys();
            return keyMap.get(keyId);
        }
    }

    /**
     * Fetches the JWK set and replaces the cached snapshot with the RSA
     * signature keys it contains. The previous snapshot is left untouched
     * if anything fails.
     *
     * @throws IllegalStateException if the endpoint returns no key set, or
     *         a key cannot be converted to an RSA public key
     */
    private void updateKeys() {
        JwkKeys response = restTemplate.getForObject(jwkUri, JwkKeys.class);
        // getForObject may return null (e.g. empty body); fail with a clear
        // message instead of a bare NullPointerException.
        if (response == null || response.getKeys() == null) {
            throw new IllegalStateException("JWK endpoint returned no keys: " + jwkUri);
        }
        Map<String, Key> newKeys = response.getKeys().stream()
                .filter(jwkKey -> "sig".equals(jwkKey.getPublicKeyUse()))
                .filter(jwkKey -> "RSA".equals(jwkKey.getKeyType()))
                .collect(Collectors.toMap(JwkKey::getKeyId, this::toPublicKey));
        keyMap = Collections.unmodifiableMap(newKeys);
    }

    /**
     * Builds an RSA public key from the modulus and exponent of a JWK entry.
     *
     * @throws IllegalStateException if the JCA provider rejects the key;
     *         the underlying exception is attached as the cause
     */
    private Key toPublicKey(JwkKey jwkKey) {
        BigInteger modulus = base64ToBigInteger(jwkKey.getPublicKeyModulus());
        BigInteger exponent = base64ToBigInteger(jwkKey.getPublicKeyExponent());
        RSAPublicKeySpec rsaPublicKeySpec = new RSAPublicKeySpec(modulus, exponent);
        try {
            KeyFactory keyFactory = KeyFactory.getInstance("RSA");
            return keyFactory.generatePublic(rsaPublicKeySpec);
        } catch (NoSuchAlgorithmException | InvalidKeySpecException e) {
            // Keep the cause so the real reason for the failure is diagnosable.
            throw new IllegalStateException("Failed to parse public key", e);
        }
    }

    /** Decodes a base64url string and interprets the bytes as an unsigned big-endian integer. */
    private BigInteger base64ToBigInteger(String value) {
        return new BigInteger(1, Decoders.BASE64URL.decode(value));
    }
}