C Turner

Verifying Cognito JWT tokens in Go

You can use JSON Web Tokens (JWTs) as a part of OpenID Connect (OIDC) and OAuth 2.0 frameworks to limit client access to your APIs.

Background

Authorizing an API request consists of two steps: validating the token's claims and verifying its signature.

Validate claims

  • kid – The token header must carry a kid that matches one of the keys published at the jwks_uri; that key is the one that signed the token.
  • iss – Must match the issuer that is configured for the authorizer.
  • exp – Must be after the current time in UTC.
  • nbf – Must be before the current time in UTC.
  • iat – Must be before the current time in UTC.
  • aud or client_id – Must match the app client ID created in the Cognito user pool.

Cognito ID tokens carry the app client ID in the aud claim, while access tokens carry it in client_id.

Verify the signature

The signature is created from the header and payload segments, a signing algorithm, and a secret or a private key (depending on the chosen signing algorithm). Cognito signs its tokens with RS256, so you verify them with the matching RSA public key. Verifying the signature proves that the token was signed by the issuer and has not been altered in any way.

The public keys you need are published by your user pool at:
https://cognito-idp.{region}.amazonaws.com/{userPoolId}/.well-known/jwks.json
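Both the key set URL and the issuer value follow from the region and the user pool ID, so they can be built rather than typed by hand. A small sketch, where the helper names cognitoIssuer and cognitoJWKSURL are my own:

```go
package main

import "fmt"

// cognitoIssuer returns the value Cognito puts in the iss claim for a user pool.
func cognitoIssuer(region, userPoolID string) string {
	return fmt.Sprintf("https://cognito-idp.%s.amazonaws.com/%s", region, userPoolID)
}

// cognitoJWKSURL returns the address of the pool's public key set.
func cognitoJWKSURL(region, userPoolID string) string {
	return cognitoIssuer(region, userPoolID) + "/.well-known/jwks.json"
}
```

For example, calling cognitoJWKSURL with region "us-east-1" and a placeholder pool ID "us-east-1_EXAMPLE" gives https://cognito-idp.us-east-1.amazonaws.com/us-east-1_EXAMPLE/.well-known/jwks.json.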

I used Auth0's JWT middleware library (github.com/auth0/go-jwt-middleware) because it checks the Authorization header for a JWT, decodes it, and stores the parsed token in the request context.

You will also need the library "github.com/form3tech-oss/jwt-go".

I used jwtmiddleware so that every request has its JWT checked. Following this example, I wrote two functions that

  1. validate the claims and
  2. create the RSA public key, which jwtmiddleware's CheckJWT() function hands to jwt-go to check the token's signature.

Validate claims with a function

Pass this function to the middleware as its ValidationKeyGetter option.

Mine looked like this:


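import (
    "crypto/rsa"
    "encoding/base64"
    "encoding/json"
    "errors"
    "math/big"
    "net/http"

    jwtmiddleware "github.com/auth0/go-jwt-middleware"
    "github.com/form3tech-oss/jwt-go"
)

// Jwks is the JSON Web Key Set document served at the jwks.json URL.
type Jwks struct {
    Keys []JSONWebKeys `json:"keys"`
}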

type JSONWebKeys struct {
    Kty string   `json:"kty"`
    Kid string   `json:"kid"`
    Use string   `json:"use"`
    N   string   `json:"n"`
    E   string   `json:"e"`
}

// validationGetter checks the token's claims and returns the RSA public key
// that jwt-go should use to verify the signature.
func validationGetter(token *jwt.Token) (interface{}, error) {
    clientId := "your client id"
    // AWS Cognito public keys are available at:
    // https://cognito-idp.{region}.amazonaws.com/{userPoolId}/.well-known/jwks.json
    publicKeysURL := "url to your public key"
    iss := "your iss"

    // Fetch and decode the key set
    resp, err := http.Get(publicKeysURL)
    if err != nil {
        return nil, err
    }
    defer resp.Body.Close()

    var jwks Jwks
    if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
        return nil, err
    }

    // Use a checked type assertion so unexpected claim types cannot panic
    claims, ok := token.Claims.(jwt.MapClaims)
    if !ok {
        return nil, errors.New("invalid claims")
    }

    // Verify the 'iss' claim; passing true makes the claim mandatory
    if !claims.VerifyIssuer(iss, true) {
        return nil, errors.New("invalid iss")
    }

    // Verify the audience: access tokens carry the app client ID in 'client_id'
    if aud, _ := claims["client_id"].(string); aud != clientId {
        return nil, errors.New("invalid audience")
    }

    // Validate the time-based claims exp, iat and nbf
    if err := claims.Valid(); err != nil {
        return nil, errors.New("invalid time-based claims")
    }

    // getPublicKey returns an error if no key in the set matches the token's 'kid'
    pk, err := getPublicKey(token, jwks)
    if err != nil {
        return nil, errors.New("unable to resolve signing key")
    }
    return pk, nil
}


Keep the error messages you return vague so that they do not drop hints for anyone trying to forge tokens. The values clientId, iss, and publicKeysURL should come from your configuration rather than being hard-coded.

Create a public key that can be used to verify the signature

// getPublicKey returns the RSA public key from jwks whose kid matches the
// kid in the token header.
func getPublicKey(token *jwt.Token, jwks Jwks) (*rsa.PublicKey, error) {
    for _, key := range jwks.Keys {
        if token.Header["kid"] != key.Kid {
            continue
        }
        // n and e are base64url-encoded (unpadded) big-endian integers
        nb, err := base64.RawURLEncoding.DecodeString(key.N)
        if err != nil {
            // return the error instead of calling log.Fatal, which would
            // stop the whole server on a single bad key
            return nil, err
        }
        eb, err := base64.RawURLEncoding.DecodeString(key.E)
        if err != nil {
            return nil, err
        }
        return &rsa.PublicKey{
            N: new(big.Int).SetBytes(nb),
            // the exponent is usually 65537 ("AQAB"), but decode it anyway
            E: int(new(big.Int).SetBytes(eb).Int64()),
        }, nil
    }
    return nil, errors.New("could not find a matching key")
}
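The exponent "e" in a JWK is a base64url-encoded big-endian integer. As a quick illustration (decodeExponent is a helper name I made up), the sketch below decodes it in general, and both "AQAB" and "AAEAAQ" come out as 65537.

```go
package main

import (
	"encoding/base64"
	"math/big"
)

// decodeExponent converts the JWK "e" value into an int.
func decodeExponent(e string) (int, error) {
	b, err := base64.RawURLEncoding.DecodeString(e)
	if err != nil {
		return 0, err
	}
	// SetBytes interprets the bytes as a big-endian unsigned integer.
	return int(new(big.Int).SetBytes(b).Int64()), nil
}
```

"AQAB" decodes to the bytes [1, 0, 1] and "AAEAAQ" to [0, 1, 0, 1]; the leading zero byte does not change the value.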


Note that at the end of the function validationGetter we actually return the public key (pk) after we have completed our checks. The public key will then be used by jwtmiddleware to verify the signature of the token.

    jwtMiddleware := jwtmiddleware.New(jwtmiddleware.Options{
        ValidationKeyGetter: validationGetter,
        SigningMethod:       jwt.SigningMethodRS256,
        Debug:               true,
    })

And that's it! What do you think? Can it be simplified? Drop a comment below.
