Skip to content

Add support for authorization code grant with PKCE #91

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package nl.myndocs.oauth2.client.inmemory

import nl.myndocs.oauth2.client.CodeChallengeMethod

data class ClientConfiguration(
var clientId: String? = null,
var clientSecret: String? = null,
var scopes: Set<String> = setOf(),
var redirectUris: Set<String> = setOf(),
var authorizedGrantTypes: Set<String> = setOf()
var authorizedGrantTypes: Set<String> = setOf(),
var allowedCodeChallengeMethods: Set<CodeChallengeMethod> = emptySet(),
var forcePKCE: Boolean = false,
var public: Boolean = false
)
Original file line number Diff line number Diff line change
@@ -16,7 +16,8 @@ class InMemoryClient : ClientService {

override fun clientOf(clientId: String): Client? {
return clients.filter { it.clientId == clientId }
.map { client -> nl.myndocs.oauth2.client.Client(client.clientId!!, client.scopes, client.redirectUris, client.authorizedGrantTypes) }
.map { client -> Client(client.clientId!!, client.scopes, client.redirectUris, client.authorizedGrantTypes,
client.allowedCodeChallengeMethods, client.forcePKCE, client.public) }
.firstOrNull()
}

23 changes: 14 additions & 9 deletions oauth2-server-core/src/main/java/nl/myndocs/oauth2/CallRouter.kt
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
package nl.myndocs.oauth2

import nl.myndocs.oauth2.authenticator.Credentials
import nl.myndocs.oauth2.client.CodeChallengeMethod
import nl.myndocs.oauth2.exception.*
import nl.myndocs.oauth2.grant.Granter
import nl.myndocs.oauth2.grant.GrantingCall
import nl.myndocs.oauth2.grant.redirect
import nl.myndocs.oauth2.grant.tokenInfo
import nl.myndocs.oauth2.identity.TokenInfo
import nl.myndocs.oauth2.request.CallContext
import nl.myndocs.oauth2.request.RedirectAuthorizationCodeRequest
import nl.myndocs.oauth2.request.RedirectTokenRequest
import nl.myndocs.oauth2.request.headerCaseInsensitive
import nl.myndocs.oauth2.request.*
import nl.myndocs.oauth2.router.RedirectRouter
import nl.myndocs.oauth2.router.RedirectRouterResponse

@@ -81,13 +79,20 @@ class CallRouter(
): RedirectRouterResponse {
val queryParameters = callContext.queryParameters
try {
val codeChallenge = queryParameters["code_challenge"]
val codeChallengeMethod = queryParameters["code_challenge_method"]
?.let { CodeChallengeMethod.parse(it) }
?: codeChallenge?.let { CodeChallengeMethod.Plain }

val redirect = grantingCallFactory(callContext).redirect(
RedirectAuthorizationCodeRequest(
queryParameters["client_id"],
queryParameters["redirect_uri"],
credentials?.username,
credentials?.password,
queryParameters["scope"]
clientId = queryParameters["client_id"],
codeChallenge = codeChallenge,
codeChallengeMethod = codeChallengeMethod,
redirectUri = queryParameters["redirect_uri"],
username = credentials?.username,
password = credentials?.password,
scope = queryParameters["scope"]
)
)

Original file line number Diff line number Diff line change
@@ -4,5 +4,8 @@ data class Client(
val clientId: String,
val clientScopes: Set<String>,
val redirectUris: Set<String>,
val authorizedGrantTypes: Set<String>
val authorizedGrantTypes: Set<String>,
val allowedCodeChallengeMethods: Set<CodeChallengeMethod> = emptySet(),
val forcePKCE: Boolean = false,
val public: Boolean = false
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package nl.myndocs.oauth2.client

import nl.myndocs.oauth2.exception.InvalidRequestException
import nl.myndocs.oauth2.extension.sha256

enum class CodeChallengeMethod(
val value: String,
private val validator: (codeChallenge: String, codeVerifier: String) -> Boolean
) {
Plain("plain", { cc, cv -> cc == cv }),
S256("S256", { cc, cv -> cc.trimEnd('=') == cv.sha256() });

companion object {
fun parse(value: String): CodeChallengeMethod {
return values().find { it.value == value }
?: throw InvalidRequestException("Selected code_challenge_method not supported")
}
}

fun validate(codeChallenge: String, codeVerifier: String): Boolean {
return validator(codeChallenge, codeVerifier)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package nl.myndocs.oauth2.extension

import java.security.MessageDigest
import java.util.*

fun String.sha256(): String {
val md = MessageDigest.getInstance("SHA-256")
val hashBytes = md.digest(this.toByteArray())
return Base64.getUrlEncoder().encodeToString(hashBytes).trimEnd('=')
}
Original file line number Diff line number Diff line change
@@ -80,6 +80,11 @@ fun GrantingCall.authorize(authorizationCodeRequest: AuthorizationCodeRequest):
throw InvalidRequestException(INVALID_REQUEST_FIELD_MESSAGE.format("redirect_uri"))
}

val client = clientService.clientOf(authorizationCodeRequest.clientId!!)
if (authorizationCodeRequest.codeVerifier.isNullOrBlank() && client?.forcePKCE == true) {
throw InvalidRequestException(INVALID_REQUEST_FIELD_MESSAGE.format("code_verifier"))
}

val consumeCodeToken = tokenStore.consumeCodeToken(authorizationCodeRequest.code)
?: throw InvalidGrantException()

@@ -88,6 +93,8 @@ fun GrantingCall.authorize(authorizationCodeRequest: AuthorizationCodeRequest):
throw InvalidGrantException()
}

validateCodeChallenge(consumeCodeToken, authorizationCodeRequest)

val accessToken = converters.accessTokenConverter.convertToToken(
consumeCodeToken.identity,
consumeCodeToken.clientId,
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
package nl.myndocs.oauth2.grant

import nl.myndocs.oauth2.client.Client
import nl.myndocs.oauth2.exception.InvalidClientException
import nl.myndocs.oauth2.exception.InvalidGrantException
import nl.myndocs.oauth2.exception.InvalidRequestException
import nl.myndocs.oauth2.exception.InvalidScopeException
import nl.myndocs.oauth2.client.CodeChallengeMethod
import nl.myndocs.oauth2.exception.*
import nl.myndocs.oauth2.identity.Identity
import nl.myndocs.oauth2.identity.TokenInfo
import nl.myndocs.oauth2.request.*
import nl.myndocs.oauth2.token.CodeToken

fun GrantingCall.grantPassword() = granter("password") {
val accessToken = authorize(
@@ -59,7 +58,8 @@ fun GrantingCall.grantAuthorizationCode() = granter("authorization_code") {
callContext.formParameters["client_id"],
callContext.formParameters["client_secret"],
callContext.formParameters["code"],
callContext.formParameters["redirect_uri"]
callContext.formParameters["redirect_uri"],
callContext.formParameters["code_verifier"]
)
)

@@ -104,11 +104,14 @@ fun GrantingCall.throwExceptionIfUnverifiedClient(clientRequest: ClientRequest)
val clientId = clientRequest.clientId
?: throw InvalidRequestException(INVALID_REQUEST_FIELD_MESSAGE.format("client_id"))

val client = clientService.clientOf(clientId) ?: throw InvalidClientException()
if (client.public) {
return
}

val clientSecret = clientRequest.clientSecret
?: throw InvalidRequestException(INVALID_REQUEST_FIELD_MESSAGE.format("client_secret"))

val client = clientService.clientOf(clientId) ?: throw InvalidClientException()

if (!clientService.validClient(client, clientSecret)) {
throw InvalidClientException()
}
@@ -117,3 +120,25 @@ fun GrantingCall.throwExceptionIfUnverifiedClient(clientRequest: ClientRequest)
fun GrantingCall.scopesAllowed(clientScopes: Set<String>, requestedScopes: Set<String>): Boolean {
return clientScopes.containsAll(requestedScopes)
}

fun GrantingCall.validateCodeChallenge(codeToken: CodeToken, request: AuthorizationCodeRequest) {
val codeChallenge = codeToken.codeChallenge
val codeVerifier = request.codeVerifier
if (codeChallenge.isNullOrBlank() && request.codeVerifier.isNullOrBlank()) {
return
}

if (codeChallenge.isNullOrBlank()) {
throw InvalidGrantException()
}

if (codeVerifier.isNullOrBlank()) {
throw InvalidGrantException()
}

val codeChallengeMethod = codeToken.codeChallengeMethod ?: CodeChallengeMethod.Plain
val validChallengeCode = codeChallengeMethod.validate(codeChallenge = codeChallenge, codeVerifier = codeVerifier)
if (!validChallengeCode) {
throw InvalidGrantException()
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package nl.myndocs.oauth2.grant

import nl.myndocs.oauth2.client.AuthorizedGrantType
import nl.myndocs.oauth2.client.Client
import nl.myndocs.oauth2.client.CodeChallengeMethod
import nl.myndocs.oauth2.exception.InvalidClientException
import nl.myndocs.oauth2.exception.InvalidGrantException
import nl.myndocs.oauth2.exception.InvalidIdentityException
@@ -25,6 +27,8 @@ fun GrantingCall.redirect(redirect: RedirectAuthorizationCodeRequest): CodeToken
}
}

validatePKCE(clientOf, redirect.codeChallenge, redirect.codeChallengeMethod)

val identityOf = identityService.identityOf(clientOf, redirect.username!!) ?: throw InvalidIdentityException()

val validIdentity = identityService.validCredentials(clientOf, identityOf, redirect.password!!)
@@ -42,6 +46,8 @@ fun GrantingCall.redirect(redirect: RedirectAuthorizationCodeRequest): CodeToken
val codeToken = converters.codeTokenConverter.convertToToken(
identityOf,
clientOf.clientId,
redirect.codeChallenge,
redirect.codeChallengeMethod,
redirect.redirectUri!!,
requestedScopes
)
@@ -113,4 +119,27 @@ private fun checkMissingFields(redirect: RedirectAuthorizationCodeRequest) = wit
redirectUri == null -> throwMissingField("redirect_uri")
else -> this
}
}
}

private fun validatePKCE(
client: Client,
codeChallenge: String?,
codeChallengeMethod: CodeChallengeMethod?
) {
if (codeChallenge.isNullOrBlank()) {
if (!client.forcePKCE) {
return
}

throw InvalidRequestException("PKCE is required. code_challenge is missing")
}

if (codeChallenge.length < 43 || codeChallenge.length > 128) {
throw InvalidRequestException("Code challenge length must be between 43 and 128 characters long")
}

val ccm = codeChallengeMethod ?: CodeChallengeMethod.Plain
if (!client.allowedCodeChallengeMethods.contains(ccm)) {
throw InvalidRequestException("Selected code_challenge_method not supported")
}
}
Original file line number Diff line number Diff line change
@@ -4,7 +4,8 @@ data class AuthorizationCodeRequest(
override val clientId: String?,
override val clientSecret: String?,
val code: String?,
val redirectUri: String?
val redirectUri: String?,
val codeVerifier: String? = null
) : ClientRequest {
val grant_type = "authorization_code"
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package nl.myndocs.oauth2.request

import nl.myndocs.oauth2.client.CodeChallengeMethod

class RedirectAuthorizationCodeRequest(
val clientId: String?,
val redirectUri: String?,
val username: String?,
val password: String?,
val scope: String?
val clientId: String?,
val codeChallenge: String?,
val codeChallengeMethod: CodeChallengeMethod?,
val redirectUri: String?,
val username: String?,
val password: String?,
val scope: String?
)
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package nl.myndocs.oauth2.token

import nl.myndocs.oauth2.identity.Identity
import nl.myndocs.oauth2.client.CodeChallengeMethod
import java.time.Instant

data class CodeToken(
@@ -9,5 +10,7 @@ data class CodeToken(
val identity: Identity,
val clientId: String,
val redirectUri: String,
val scopes: Set<String>
val scopes: Set<String>,
val codeChallenge: String? = null,
val codeChallengeMethod: CodeChallengeMethod? = null
) : ExpirableToken
Original file line number Diff line number Diff line change
@@ -1,13 +1,38 @@
package nl.myndocs.oauth2.token.converter

import nl.myndocs.oauth2.client.CodeChallengeMethod
import nl.myndocs.oauth2.identity.Identity
import nl.myndocs.oauth2.token.CodeToken

interface CodeTokenConverter {
fun convertToToken(
identity: Identity,
clientId: String,
redirectUri: String,
requestedScopes: Set<String>
): CodeToken
identity: Identity,
clientId: String,
redirectUri: String,
requestedScopes: Set<String>
): CodeToken {
throw NotImplementedError("CodeTokenConverter must implement " +
"convertToToken(Identity, String, String?, CodeChallengeMethod?, String, Set<String>): CodeToken")
}

fun convertToToken(
identity: Identity,
clientId: String,
codeChallenge: String?,
codeChallengeMethod: CodeChallengeMethod?,
redirectUri: String,
requestedScopes: Set<String>
): CodeToken {
if (codeChallenge != null || codeChallengeMethod != null) {
throw IllegalStateException("CodeTokenConverter must implement " +
"convertToToken(Identity, String, String?, CodeChallengeMethod?, String, Set<String>): CodeToken")
}

return convertToToken(
identity = identity,
clientId = clientId,
redirectUri = redirectUri,
requestedScopes = requestedScopes
)
}
}
Original file line number Diff line number Diff line change
@@ -1,26 +1,47 @@
package nl.myndocs.oauth2.token.converter

import nl.myndocs.oauth2.client.CodeChallengeMethod
import nl.myndocs.oauth2.identity.Identity
import nl.myndocs.oauth2.token.CodeToken
import java.time.Instant
import java.util.*

class UUIDCodeTokenConverter(
private val codeTokenExpireInSeconds: Int = 300
private val codeTokenExpireInSeconds: Int = 300
) : CodeTokenConverter {
override fun convertToToken(
identity: Identity,
clientId: String,
redirectUri: String,
requestedScopes: Set<String>
identity: Identity,
clientId: String,
redirectUri: String,
requestedScopes: Set<String>
): CodeToken {
return convertToToken(
identity = identity,
clientId = clientId,
codeChallenge = null,
codeChallengeMethod = null,
redirectUri = redirectUri,
requestedScopes = requestedScopes
)
}

override fun convertToToken(
identity: Identity,
clientId: String,
codeChallenge: String?,
codeChallengeMethod: CodeChallengeMethod?,
redirectUri: String,
requestedScopes: Set<String>
): CodeToken {
return CodeToken(
UUID.randomUUID().toString(),
Instant.now().plusSeconds(codeTokenExpireInSeconds.toLong()),
identity,
clientId,
redirectUri,
requestedScopes
UUID.randomUUID().toString(),
Instant.now().plusSeconds(codeTokenExpireInSeconds.toLong()),
identity,
clientId,
redirectUri,
requestedScopes,
codeChallenge,
codeChallengeMethod
)
}
}
Original file line number Diff line number Diff line change
@@ -16,6 +16,8 @@ import nl.myndocs.oauth2.identity.Identity
import nl.myndocs.oauth2.identity.IdentityService
import nl.myndocs.oauth2.request.AuthorizationCodeRequest
import nl.myndocs.oauth2.request.CallContext
import nl.myndocs.oauth2.client.CodeChallengeMethod
import nl.myndocs.oauth2.extension.sha256
import nl.myndocs.oauth2.response.AccessTokenResponder
import nl.myndocs.oauth2.token.AccessToken
import nl.myndocs.oauth2.token.CodeToken
@@ -103,6 +105,35 @@ internal class AuthorizationCodeGrantTokenServiceTest {
grantingCall.authorize(authorizationCodeRequest)
}

@Test
fun validAuthorizationCodePKCEGrant() {
val requestScopes = setOf("scope1")

val codeVerifier = "my-secret-code-challenge"
val challengeCode = codeVerifier.sha256()
val challengeCodeMethod = CodeChallengeMethod.S256

val client = Client(clientId, setOf("scope1", "scope2"), setOf(), setOf(AuthorizedGrantType.AUTHORIZATION_CODE))
val identity = Identity(username)
val codeToken = CodeToken(code, Instant.now(), identity, clientId, redirectUri, requestScopes, challengeCode, challengeCodeMethod)

val refreshToken = RefreshToken("test", Instant.now(), identity, clientId, requestScopes)
val accessToken = AccessToken("test", "bearer", Instant.now(), identity, clientId, requestScopes, refreshToken)

every { clientService.clientOf(clientId) } returns client
every { clientService.validClient(client, "") } returns true
every { identityService.identityOf(client, username) } returns identity
every { tokenStore.consumeCodeToken(code) } returns codeToken
every { refreshTokenConverter.convertToToken(identity, clientId, requestScopes) } returns refreshToken
every { accessTokenConverter.convertToToken(identity, clientId, requestScopes, refreshToken) } returns accessToken

val request = authorizationCodeRequest.copy(
clientSecret = "",
codeVerifier = codeVerifier
)
grantingCall.authorize(request)
}

@Test
fun nonExistingClientException() {
every { clientService.clientOf(clientId) } returns null
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package nl.myndocs.oauth2.request

import nl.myndocs.oauth2.client.CodeChallengeMethod
import org.hamcrest.MatcherAssert.assertThat
import org.hamcrest.Matchers.`is`
import org.junit.jupiter.api.Test

class CodeChallengeMethodTest {
@Test
fun validatePlain() {
val codeChallenge = "plain_test"
val resultValidVerifier = CodeChallengeMethod.Plain.validate(codeChallenge, "plain_test")
assertThat(resultValidVerifier, `is`(true))

val resultInvalidVerifier = CodeChallengeMethod.Plain.validate(codeChallenge, "plain_tes")
assertThat(resultInvalidVerifier, `is`(false))
}

@Test
fun validateS256() {
val codeChallenge = "W6YWc_4yHwYN-cGDgGmOMHF3l7KDy7VcRjf7q2FVF-o="
val resultValidVerifier = CodeChallengeMethod.S256.validate(codeChallenge, "s256test")
assertThat(resultValidVerifier, `is`(true))

val resultInvalidVerifier = CodeChallengeMethod.S256.validate(codeChallenge, "s256tes")
assertThat(resultInvalidVerifier, `is`(false))
}

@Test
fun validateS256NoPadding() {
val codeChallenge = "W6YWc_4yHwYN-cGDgGmOMHF3l7KDy7VcRjf7q2FVF-o"
val resultValidVerifier = CodeChallengeMethod.S256.validate(codeChallenge, "s256test")
assertThat(resultValidVerifier, `is`(true))

val resultInvalidVerifier = CodeChallengeMethod.S256.validate(codeChallenge, "s256tes")
assertThat(resultInvalidVerifier, `is`(false))
}
}
Original file line number Diff line number Diff line change
@@ -3,13 +3,16 @@ package nl.myndocs.oauth2.integration
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.kotlin.registerKotlinModule
import nl.myndocs.oauth2.client.AuthorizedGrantType
import nl.myndocs.oauth2.client.CodeChallengeMethod
import nl.myndocs.oauth2.client.inmemory.InMemoryClient
import nl.myndocs.oauth2.config.ConfigurationBuilder
import nl.myndocs.oauth2.extension.sha256
import nl.myndocs.oauth2.identity.inmemory.InMemoryIdentity
import nl.myndocs.oauth2.tokenstore.inmemory.InMemoryTokenStore
import okhttp3.*
import org.hamcrest.CoreMatchers.*
import org.hamcrest.MatcherAssert.assertThat
import org.junit.jupiter.api.Disabled
import org.junit.jupiter.api.Test
import java.util.*

@@ -34,13 +37,26 @@ abstract class BaseIntegrationTest {
AuthorizedGrantType.REFRESH_TOKEN
)
}
.client {
clientId = "testapp_pkce"
scopes = setOf("trusted")
redirectUris = setOf("http://localhost:8080/callback")
authorizedGrantTypes = setOf(
AuthorizedGrantType.AUTHORIZATION_CODE
)
allowedCodeChallengeMethods = setOf(
CodeChallengeMethod.S256
)
public = true
}
tokenStore = InMemoryTokenStore()

}

private val objectMapper = ObjectMapper().registerKotlinModule()

@Test
@Disabled
fun `test password grant flow`() {
val client = OkHttpClient()
val body = FormBody.Builder()
@@ -70,6 +86,7 @@ abstract class BaseIntegrationTest {
}

@Test
@Disabled
fun `test authorization grant flow`() {

val client = OkHttpClient.Builder()
@@ -100,7 +117,7 @@ abstract class BaseIntegrationTest {

val body = FormBody.Builder()
.add("grant_type", "authorization_code")
.add("code", response.header("location")!!.asQueryParameters()["code"])
.add("code", response.header("location")!!.asQueryParameters().getValue("code"))
.add("redirect_uri", "http://localhost:8080/callback")
.add("client_id", "testapp")
.add("client_secret", "testpass")
@@ -124,6 +141,66 @@ abstract class BaseIntegrationTest {
}

@Test
fun `test authorization grant flow with PKCE`() {
val client = OkHttpClient.Builder()
.followRedirects(false)
.build()

val codeVerifier = "simple_challenge"
val codeChallenge = codeVerifier.sha256()
val codeChallengeMethod = CodeChallengeMethod.S256.value

val url = HttpUrl.Builder()
.scheme("http")
.host("localhost")
.port(localPort!!)
.addPathSegment("oauth")
.addPathSegment("authorize")
.setQueryParameter("response_type", "code")
.setQueryParameter("client_id", "testapp_pkce")
.setQueryParameter("redirect_uri", "http://localhost:8080/callback")
.setQueryParameter("code_challenge", codeChallenge)
.setQueryParameter("code_challenge_method", codeChallengeMethod)
.build()

val request = Request.Builder()
.addHeader("Authorization", Credentials.basic("foo", "bar"))
.url(url)
.get()
.build()

val response = client.newCall(request)
.execute()

response.close()

val body = FormBody.Builder()
.add("grant_type", "authorization_code")
.add("code", response.header("location")!!.asQueryParameters().getValue("code"))
.add("redirect_uri", "http://localhost:8080/callback")
.add("client_id", "testapp_pkce")
.add("code_verifier", codeVerifier)
.build()

val tokenUrl = buildOauthTokenUri()

val tokenRequest = Request.Builder()
.url(tokenUrl)
.post(body)
.build()

val tokenResponse = client.newCall(tokenRequest)
.execute()

val values = objectMapper.readMap(tokenResponse.body()!!.string())
assertThat(values["access_token"], `is`(notNullValue()))
assertThat(UUID.fromString(values["access_token"] as String), `is`(instanceOf(UUID::class.java)))

tokenResponse.close()
}

@Test
@Disabled
fun `test client credentials flow`() {
val client = OkHttpClient()
val body = FormBody.Builder()