diff --git a/src/main/java/nl/martijndwars/webpush/HttpEce.java b/src/main/java/nl/martijndwars/webpush/HttpEce.java index b9e9436..d382899 100644 --- a/src/main/java/nl/martijndwars/webpush/HttpEce.java +++ b/src/main/java/nl/martijndwars/webpush/HttpEce.java @@ -13,8 +13,9 @@ import java.security.*; import java.util.Arrays; import java.util.Base64; -import java.util.HashMap; +import java.util.Collections; import java.util.Map; +import java.util.function.Supplier; import static java.nio.charset.StandardCharsets.UTF_8; import static javax.crypto.Cipher.DECRYPT_MODE; @@ -39,17 +40,65 @@ public class HttpEce { public static final int TAG_SIZE = 16; public static final int TWO_BYTE_MAX = 65_536; public static final String WEB_PUSH_INFO = "WebPush: info\0"; + public static final String ECE_KEYLOG_PROP = "ECE_KEYLOG"; + private static final boolean LOG_ENABLED = "1".equals(System.getenv(ECE_KEYLOG_PROP)) || "1".equals(System.getProperty(ECE_KEYLOG_PROP)); + public static final Supplier DEFAULT_CIPHER_SUPPLIER = HttpEce::createCipher; + public static final ThreadLocal THREAD_LOCAL_CIPHER = ThreadLocal.withInitial(DEFAULT_CIPHER_SUPPLIER); + public static final Supplier THREAD_LOCAL_CIPHER_SUPPLIER = HttpEce::getThreadLocalCipher; - private Map keys; - private Map labels; + private final Map keys; + private final Map labels; + private final Supplier cipherSupplier; public HttpEce() { - this(new HashMap(), new HashMap()); + this(Collections.emptyMap(), Collections.emptyMap()); } public HttpEce(Map keys, Map labels) { + this(keys, labels, DEFAULT_CIPHER_SUPPLIER); + } + + public HttpEce(Map keys, Map labels, Supplier cipherSupplier) { this.keys = keys; this.labels = labels; + this.cipherSupplier = cipherSupplier; + + } + + public static HttpEce createWithDefaultCipher() { + return createWithDefaultCipher(Collections.emptyMap(), Collections.emptyMap()); + } + + public static HttpEce createWithDefaultCipher(Map keys, Map labels) { + return createWithCipher(keys, labels, DEFAULT_CIPHER_SUPPLIER); + } + + public static HttpEce createWithThreadLocalCipher() { + return createWithThreadLocalCipher(Collections.emptyMap(), Collections.emptyMap()); + } + + public static HttpEce createWithThreadLocalCipher(Map keys, Map labels) { + return createWithCipher(keys, labels, THREAD_LOCAL_CIPHER_SUPPLIER); + } + + public static HttpEce createWithCipher(Supplier cipherSupplier) { + return createWithCipher(Collections.emptyMap(), Collections.emptyMap(), cipherSupplier); + } + + public static HttpEce createWithCipher(Map keys, Map labels, Supplier cipherSupplier) { + return new HttpEce(keys, labels, cipherSupplier); + } + + private static Cipher getThreadLocalCipher() { + return THREAD_LOCAL_CIPHER.get(); + } + + private static Cipher createCipher() { + try { + return Cipher.getInstance("AES/GCM/NoPadding", "BC"); + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } } /** @@ -73,7 +122,7 @@ public byte[] encrypt(byte[] plaintext, byte[] salt, byte[] privateKey, String k byte[] nonce = keyAndNonce[1]; // Note: Cipher adds the tag to the end of the ciphertext - Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding", "BC"); + Cipher cipher = cipherSupplier.get(); GCMParameterSpec params = new GCMParameterSpec(TAG_SIZE * 8, nonce); cipher.init(ENCRYPT_MODE, new SecretKeySpec(key, "AES"), params); @@ -86,9 +135,11 @@ public byte[] encrypt(byte[] plaintext, byte[] salt, byte[] privateKey, String k log("padding", padding); byte[][] encrypted = {cipher.update(plaintext), cipher.update(padding), cipher.doFinal()}; - log("encrypted", concat(encrypted)); + log("encrypted", () -> concat(encrypted)); - return log("ciphertext", concat(header, concat(encrypted))); + byte[] result = concat(header, concat(encrypted)); + log("ciphertext", result); + return result; } else { return concat(cipher.update(new byte[2]), cipher.doFinal(plaintext)); } @@ -138,7 +189,7 @@ public byte[][] parseHeader(byte[] payload) { } public byte[] decryptRecord(byte[] ciphertext, byte[] key, byte[] nonce, Encoding version) throws NoSuchPaddingException, NoSuchAlgorithmException, NoSuchProviderException, InvalidAlgorithmParameterException, InvalidKeyException, BadPaddingException, IllegalBlockSizeException { - Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding", "BC"); + Cipher cipher = cipherSupplier.get(); GCMParameterSpec params = new GCMParameterSpec(TAG_SIZE * 8, nonce); cipher.init(DECRYPT_MODE, new SecretKeySpec(key, "AES"), params); @@ -437,10 +488,15 @@ private static byte[] intToBytes(int number) { * @return */ private static byte[] log(String info, byte[] array) { - if ("1".equals(System.getenv("ECE_KEYLOG"))) { + log(info, () -> array); + return array; + } + + private static void log(String info, Supplier arraySupplier) { + if (LOG_ENABLED) { + byte [] array = arraySupplier.get(); System.out.println(info + " [" + array.length + "]: " + Base64.getUrlEncoder().withoutPadding().encodeToString(array)); } - - return array; } + } diff --git a/src/test/java/nl/martijndwars/webpush/HttpEceTest.java b/src/test/java/nl/martijndwars/webpush/HttpEceTest.java index c3f6884..2c91b44 100644 --- a/src/test/java/nl/martijndwars/webpush/HttpEceTest.java +++ b/src/test/java/nl/martijndwars/webpush/HttpEceTest.java @@ -25,7 +25,7 @@ private byte[] decode(String s) { @Test public void testZeroSaltAndKey() throws GeneralSecurityException { - HttpEce httpEce = new HttpEce(); + HttpEce httpEce = HttpEce.createWithDefaultCipher(); String plaintext = "Hello"; byte[] salt = new byte[16]; byte[] key = new byte[16]; @@ -45,7 +45,7 @@ public void testZeroSaltAndKey() throws GeneralSecurityException { */ @Test public void testSampleEncryption() throws GeneralSecurityException { - HttpEce httpEce = new HttpEce(); + HttpEce httpEce = HttpEce.createWithThreadLocalCipher(); byte[] plaintext = "I am the walrus".getBytes(); byte[] salt = decode("I1BsxtFttlv3u_Oo94xnmw");