AesGcmCipher.kt

package de.pflugradts.passbird.application.security

import de.pflugradts.passbird.application.util.withScrambledBytes
import de.pflugradts.passbird.domain.model.shell.EncryptedShell
import de.pflugradts.passbird.domain.model.shell.MAX_ASCII_VALUE
import de.pflugradts.passbird.domain.model.shell.MIN_ASCII_VALUE
import de.pflugradts.passbird.domain.model.shell.PlainShell.Companion.SECURE_RANDOM
import de.pflugradts.passbird.domain.model.shell.Shell
import de.pflugradts.passbird.domain.model.shell.Shell.Companion.shellOf
import de.pflugradts.passbird.domain.service.password.encryption.CryptoProvider
import java.security.SecureRandom
import javax.crypto.Cipher
import javax.crypto.Cipher.DECRYPT_MODE
import javax.crypto.Cipher.ENCRYPT_MODE
import javax.crypto.SecretKeyFactory
import javax.crypto.spec.GCMParameterSpec
import javax.crypto.spec.PBEKeySpec
import javax.crypto.spec.SecretKeySpec

// Cipher configuration: AES in Galois/Counter Mode without padding, 128-bit authentication tag
// (the tag length is passed to GCMParameterSpec in bits).
private const val ENCRYPTION_ALGORITHM = "AES"
private const val TRANSFORMATION = ENCRYPTION_ALGORITHM + "/GCM/NoPadding"
private const val TAG_LENGTH_BIT = 128

// Parameters of the legacy key derivation in withLegacyKeyBytes: PBKDF2 with HMAC-SHA256,
// a fixed salt and a 128-bit output key. The derivation is deterministic, so changing any of
// these values changes every key derived through it.
private const val SECRET_KEY_ALGORITHM = "PBKDF2WithHmacSHA256"
private const val AES_KEY_LENGTH_BIT = 128
private val SALT = "PassbirdSalt2024".toByteArray(Charsets.UTF_8)

/**
 * [CryptoProvider] implemented with the JCA transformation "AES/GCM/NoPadding".
 *
 * Each [encrypt] call draws a fresh random IV of [IV_SIZE] bytes from a per-instance [SecureRandom].
 * It returns that IV together with the ciphertext in an [EncryptedShell]. [decrypt] reads the IV
 * back from the [EncryptedShell] it is given. A new [Cipher] instance is obtained for every operation.
 *
 * The primary constructor accepts a prepared [SecretKeySpec].
 */
class AesGcmCipher internal constructor(private val secretKeySpec: SecretKeySpec) : CryptoProvider {
    /** Builds the AES key directly from the raw bytes of [keyShell] (no key derivation step). */
    constructor(keyShell: Shell) : this(createCurrentSecretKeySpec(keyShell))

    private val secureRandom = SecureRandom()

    /** Encrypts [shell] under a newly generated IV and bundles the ciphertext with that IV. */
    override fun encrypt(shell: Shell): EncryptedShell {
        val iv = requestSecureIv()
        val ciphertext = cipherize(ENCRYPT_MODE, shell, iv.toByteArray())
        return EncryptedShell(payload = ciphertext, iv = iv)
    }

    /**
     * Decrypts the payload of [encryptedShell] using the IV stored alongside it.
     * `Cipher.doFinal` in GCM mode throws `javax.crypto.AEADBadTagException` when the tag does not verify.
     */
    override fun decrypt(encryptedShell: EncryptedShell): Shell {
        val ciphertext = encryptedShell.payload
        val iv = encryptedShell.iv.toByteArray()
        return cipherize(DECRYPT_MODE, ciphertext, iv)
    }

    /** Produces [IV_SIZE] bytes from [secureRandom], wrapped in a [Shell]. */
    private fun requestSecureIv(): Shell {
        val ivBytes = ByteArray(IV_SIZE)
        secureRandom.nextBytes(ivBytes)
        return ivBytes.toShell()
    }

    /**
     * Runs a single AES/GCM operation in [mode] over the bytes of [shell] with the given [iv].
     * Both the input and the output byte arrays are routed through the `with…Bytes` helpers,
     * which hand them to `withScrambledBytes`.
     */
    private fun cipherize(mode: Int, shell: Shell, iv: ByteArray) = withCipherInputBytes(shell) { inputBytes ->
        val cipher = Cipher.getInstance(TRANSFORMATION)
        cipher.init(mode, secretKeySpec, GCMParameterSpec(TAG_LENGTH_BIT, iv))
        val outputBytes = cipher.doFinal(inputBytes)
        withCipherOutputBytes(outputBytes) { resultBytes -> resultBytes.toShell() }
    }

    private fun ByteArray.toShell() = shellOf(this)

    companion object {
        /** IV length in bytes (96 bits) used for every encryption. */
        const val IV_SIZE = 12
    }
}

/**
 * Creates an [AesGcmCipher] whose key is derived from [keyShell] with the PBKDF2-based legacy
 * scheme (see `withLegacyKeyBytes`), rather than taken from the raw key bytes.
 * NOTE(review): presumably kept to read data written with the older key scheme — confirm with callers.
 */
internal fun createLegacyAesGcmCipher(keyShell: Shell): AesGcmCipher {
    val legacyKeySpec = createLegacySecretKeySpec(keyShell)
    return AesGcmCipher(legacyKeySpec)
}

/**
 * Hands the bytes of [shell] (via `shell.toByteArray()`) to [block] through `withScrambledBytes`
 * and returns whatever [block] returns.
 * NOTE(review): presumably `withScrambledBytes` overwrites the array once [block] finishes — confirm.
 */
internal inline fun <T> withCipherInputBytes(shell: Shell, block: (ByteArray) -> T): T {
    val inputBytes = shell.toByteArray()
    return withScrambledBytes(inputBytes, block)
}

/**
 * Hands the cipher output [bytes] to [block] through `withScrambledBytes` and returns the block's
 * result. [block] must copy anything it wants to keep.
 * NOTE(review): assumes `withScrambledBytes` overwrites the array afterwards — confirm.
 */
internal inline fun <T> withCipherOutputBytes(bytes: ByteArray, block: (ByteArray) -> T): T {
    return withScrambledBytes(bytes, block)
}

/**
 * Wraps the raw bytes of [keyShell] as an AES [SecretKeySpec].
 * `SecretKeySpec` copies the array it is given, so the spec stays valid after the temporary
 * buffer from `withCurrentKeyBytes` is processed by `withScrambledBytes`.
 */
private fun createCurrentSecretKeySpec(keyShell: Shell): SecretKeySpec =
    withCurrentKeyBytes(keyShell) { rawKeyBytes -> SecretKeySpec(rawKeyBytes, ENCRYPTION_ALGORITHM) }

/**
 * Wraps the PBKDF2-derived legacy key bytes for [keyShell] as an AES [SecretKeySpec].
 * `SecretKeySpec` copies the array it is given, so the derived-key buffer can be processed
 * by `withScrambledBytes` afterwards.
 */
private fun createLegacySecretKeySpec(keyShell: Shell): SecretKeySpec =
    withLegacyKeyBytes(keyShell) { derivedKeyBytes -> SecretKeySpec(derivedKeyBytes, ENCRYPTION_ALGORITHM) }

/**
 * Exposes the raw bytes of [keyShell] to [block] unchanged (no key derivation) through
 * `withScrambledBytes`, and returns the block's result.
 */
internal inline fun <T> withCurrentKeyBytes(keyShell: Shell, block: (ByteArray) -> T): T {
    val keyBytes = keyShell.toByteArray()
    return withScrambledBytes(keyBytes, block)
}

/**
 * Derives a key from [keyShell] and passes the derived bytes to [block], returning its result.
 *
 * The derivation uses PBKDF2 with HMAC-SHA256, the fixed [SALT], 100 iterations and an output of
 * [AES_KEY_LENGTH_BIT] bits. These parameters must not change, or previously derived keys will no
 * longer be reproduced. NOTE(review): 100 iterations with a constant salt is far below current
 * PBKDF2 guidance; this is acceptable only because the scheme is frozen for compatibility.
 *
 * Cleanup, innermost first:
 * - the derived key bytes go through `withScrambledBytes`;
 * - the PBEKeySpec's internal password copy is cleared;
 * - the temporary char copy of the password is overwritten by `scramble()`.
 *
 * NOTE(review): the intermediate `SecretKey` returned by `generateSecret` is not destroyed here.
 */
internal inline fun <T> withLegacyKeyBytes(keyShell: Shell, block: (ByteArray) -> T): T {
    val password = CharArray(keyShell.size) { index -> keyShell.getChar(index) }
    return try {
        val iterationCount = 100
        // PBEKeySpec clones the password; clearPassword() wipes that clone.
        val pbeKeySpec = PBEKeySpec(password, SALT, iterationCount, AES_KEY_LENGTH_BIT)
        try {
            val factory = SecretKeyFactory.getInstance(SECRET_KEY_ALGORITHM)
            val derivedKeyBytes = factory.generateSecret(pbeKeySpec).encoded
            withScrambledBytes(derivedKeyBytes, block)
        } finally {
            pbeKeySpec.clearPassword()
        }
    } finally {
        password.scramble()
    }
}

/**
 * Overwrites every element in place with a random char from `SECURE_RANDOM`.
 * Each value is drawn uniformly from the inclusive range [MIN_ASCII_VALUE]..[MAX_ASCII_VALUE].
 */
private fun CharArray.scramble() {
    val valueCount = 1 + MAX_ASCII_VALUE - MIN_ASCII_VALUE
    for (position in indices) {
        this[position] = (SECURE_RANDOM.nextInt(valueCount) + MIN_ASCII_VALUE).toChar()
    }
}