How to write stronger unit tests with a custom go-mock matcher

TECH SCHOOL
We believe that everyone deserves a good and free education. The purpose of Tech School is to give everyone a chance to learn IT by giving free, high-quality tutorials and coding courses.

Hello everyone!

In this lecture, we will learn how to write a custom gomock matcher to make our Golang unit tests stronger.

The weak unit test

In the last lecture, we learned how to securely store users’ passwords using bcrypt. We also implemented the API to create a new user for our simple bank.

func (server *Server) createUser(ctx *gin.Context) {
    var req createUserRequest
    if err := ctx.ShouldBindJSON(&req); err != nil {
        ctx.JSON(http.StatusBadRequest, errorResponse(err))
        return
    }

    hashedPassword, err := util.HashPassword(req.Password)
    if err != nil {
        ctx.JSON(http.StatusInternalServerError, errorResponse(err))
        return
    }

    arg := db.CreateUserParams{
        Username:       req.Username,
        HashedPassword: hashedPassword,
        FullName:       req.FullName,
        Email:          req.Email,
    }

    user, err := server.store.CreateUser(ctx, arg)
    if err != nil {
        if pqErr, ok := err.(*pq.Error); ok {
            switch pqErr.Code.Name() {
            case "unique_violation":
                ctx.JSON(http.StatusForbidden, errorResponse(err))
                return
            }
        }
        ctx.JSON(http.StatusInternalServerError, errorResponse(err))
        return
    }

    rsp := createUserResponse{
        Username:          user.Username,
        FullName:          user.FullName,
        Email:             user.Email,
        PasswordChangedAt: user.PasswordChangedAt,
        CreatedAt:         user.CreatedAt,
    }
    ctx.JSON(http.StatusOK, rsp)
}

I didn’t show you how to write unit tests for this API, because they would be very similar to what we did in lecture 13 of the course. With the help of gomock, API unit tests like these are easy to write.

However, if you have tried to write unit tests for the create user API yourself, you might have found it a bit tricky, because the input password param is hashed before being stored in the database.

To understand why, let’s look at the simple version of the tests that I’ve already written here.

func TestCreateUserAPI(t *testing.T) {
    user, password := randomUser(t)

    testCases := []struct {
        name          string
        body          gin.H
        buildStubs    func(store *mockdb.MockStore)
        checkResponse func(recorder *httptest.ResponseRecorder)
    }{
        {
            name: "OK",
            body: gin.H{
                "username":  user.Username,
                "password":  password,
                "full_name": user.FullName,
                "email":     user.Email,
            },
            buildStubs: func(store *mockdb.MockStore) {
                store.EXPECT().
                    CreateUser(gomock.Any(), gomock.Any()).
                    Times(1).
                    Return(user, nil)
            },
            checkResponse: func(recorder *httptest.ResponseRecorder) {
                require.Equal(t, http.StatusOK, recorder.Code)
                requireBodyMatchUser(t, recorder.Body, user)
            },
        },
        {
            name: "InternalError",
            body: gin.H{
                "username":  user.Username,
                "password":  password,
                "full_name": user.FullName,
                "email":     user.Email,
            },
            buildStubs: func(store *mockdb.MockStore) {
                store.EXPECT().
                    CreateUser(gomock.Any(), gomock.Any()).
                    Times(1).
                    Return(db.User{}, sql.ErrConnDone)
            },
            checkResponse: func(recorder *httptest.ResponseRecorder) {
                require.Equal(t, http.StatusInternalServerError, recorder.Code)
            },
        },
        {
            name: "DuplicateUsername",
            body: gin.H{
                "username":  user.Username,
                "password":  password,
                "full_name": user.FullName,
                "email":     user.Email,
            },
            buildStubs: func(store *mockdb.MockStore) {
                store.EXPECT().
                    CreateUser(gomock.Any(), gomock.Any()).
                    Times(1).
                    Return(db.User{}, &pq.Error{Code: "23505"})
            },
            checkResponse: func(recorder *httptest.ResponseRecorder) {
                require.Equal(t, http.StatusForbidden, recorder.Code)
            },
        },
        {
            name: "InvalidUsername",
            body: gin.H{
                "username":  "invalid-user#1",
                "password":  password,
                "full_name": user.FullName,
                "email":     user.Email,
            },
            buildStubs: func(store *mockdb.MockStore) {
                store.EXPECT().
                    CreateUser(gomock.Any(), gomock.Any()).
                    Times(0)
            },
            checkResponse: func(recorder *httptest.ResponseRecorder) {
                require.Equal(t, http.StatusBadRequest, recorder.Code)
            },
        },
        {
            name: "InvalidEmail",
            body: gin.H{
                "username":  user.Username,
                "password":  password,
                "full_name": user.FullName,
                "email":     "invalid-email",
            },
            buildStubs: func(store *mockdb.MockStore) {
                store.EXPECT().
                    CreateUser(gomock.Any(), gomock.Any()).
                    Times(0)
            },
            checkResponse: func(recorder *httptest.ResponseRecorder) {
                require.Equal(t, http.StatusBadRequest, recorder.Code)
            },
        },
        {
            name: "TooShortPassword",
            body: gin.H{
                "username":  user.Username,
                "password":  "123",
                "full_name": user.FullName,
                "email":     user.Email,
            },
            buildStubs: func(store *mockdb.MockStore) {
                store.EXPECT().
                    CreateUser(gomock.Any(), gomock.Any()).
                    Times(0)
            },
            checkResponse: func(recorder *httptest.ResponseRecorder) {
                require.Equal(t, http.StatusBadRequest, recorder.Code)
            },
        },
    }

    ...
}

Basically, we first generate a random user (and its naked password) to be created.

Then we declare a table of test cases, where we can define the input request body and 2 functions: one to build the store stubs and one to check the response of the API.

There are several different cases we can test, such as:

  • The successful case
  • Internal server error case
  • Duplicate username case
  • Invalid username, email, or password case

We iterate through all of these cases, and run a separate sub-test for each of them.

func TestCreateUserAPI(t *testing.T) {
    ...

    for i := range testCases {
        tc := testCases[i]

        t.Run(tc.name, func(t *testing.T) {
            ctrl := gomock.NewController(t)
            defer ctrl.Finish()

            store := mockdb.NewMockStore(ctrl)
            tc.buildStubs(store)

            server := NewServer(store)
            recorder := httptest.NewRecorder()

            // Marshal body data to JSON
            data, err := json.Marshal(tc.body)
            require.NoError(t, err)

            url := "/users"
            request, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(data))
            require.NoError(t, err)

            server.router.ServeHTTP(recorder, request)
            tc.checkResponse(recorder)
        })
    }
}

In each sub-test, we create a new gomock controller, and use it to build a new mock DB store.

Then we call the buildStubs() function of the test case to set up the stubs for that store.

After that, we create a new server using the mock store, and create a new HTTP response recorder to record the result of the API call.

Next we marshal the input request body to JSON, and make a new POST request to the create-user API endpoint with that JSON data.

We call the server.router.ServeHTTP() function with the recorder and request objects, and finally call tc.checkResponse() to check the result.
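For readers who want to try this flow without gin or gomock, here is a stdlib-only sketch of the same steps: marshal the body, build a POST request, serve it through an httptest.ResponseRecorder, and inspect the result. It swaps in hypothetical pieces that are not from the project: a plain net/http handler (createUserHandler), a hand-written fakeStore that counts calls instead of a gomock expectation, and a 6-character minimum password length assumed purely for illustration.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
)

// createUserRequest mirrors the JSON body sent by the article's tests.
type createUserRequest struct {
	Username string `json:"username"`
	Password string `json:"password"`
	FullName string `json:"full_name"`
	Email    string `json:"email"`
}

// fakeStore is a hypothetical hand-written fake that just counts calls.
type fakeStore struct{ calls int }

func (s *fakeStore) CreateUser(req createUserRequest) error {
	s.calls++
	return nil
}

// createUserHandler is a hypothetical net/http stand-in for the gin handler.
// The 6-character password minimum is an assumption for this sketch.
func createUserHandler(store *fakeStore) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var req createUserRequest
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil || len(req.Password) < 6 {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		if err := store.CreateUser(req); err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		w.WriteHeader(http.StatusOK)
	})
}

// runCase follows the sub-test steps: marshal the body, POST it to /users
// through a ResponseRecorder, and report the status code and store calls.
func runCase(body map[string]string) (code, calls int) {
	store := &fakeStore{}
	recorder := httptest.NewRecorder()

	data, err := json.Marshal(body)
	if err != nil {
		panic(err)
	}
	request := httptest.NewRequest(http.MethodPost, "/users", bytes.NewReader(data))
	createUserHandler(store).ServeHTTP(recorder, request)
	return recorder.Code, store.calls
}

func main() {
	fmt.Println(runCase(map[string]string{"username": "alice", "password": "secret123"})) // 200 1
	fmt.Println(runCase(map[string]string{"username": "alice", "password": "123"}))       // 400 0
}
```

Note that this fake only records that the store was called, not what it was called with, which is the same blind spot the gomock.Any() matcher has.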

It’s pretty simple, just like what we’ve learned in Lecture 13.

I highly recommend reading it first to make sure you fully understand the code before continuing.

Now, for today’s lecture, we only need to focus on 1 case: the successful one.

func TestCreateUserAPI(t *testing.T) {
    user, password := randomUser(t)

    testCases := []struct {
        name          string
        body          gin.H
        buildStubs    func(store *mockdb.MockStore)
        checkResponse func(recorder *httptest.ResponseRecorder)
    }{
        {
            name: "OK",
            body: gin.H{
                "username":  user.Username,
                "password":  password,
                "full_name": user.FullName,
                "email":     user.Email,
            },
            buildStubs: func(store *mockdb.MockStore) {
                store.EXPECT().
                    CreateUser(gomock.Any(), gomock.Any()).
                    Times(1).
                    Return(user, nil)
            },
            checkResponse: func(recorder *httptest.ResponseRecorder) {
                require.Equal(t, http.StatusOK, recorder.Code)
                requireBodyMatchUser(t, recorder.Body, user)
            },
        },
        ...
    }

    ...
}

As you can see here, the request body’s parameters are all valid.

In the build stubs function, we expect the CreateUser() function of the store to be called with 2 parameters. In this simple version, we’re using the gomock.Any() matcher for both of them.

Note that the first argument of the store.CreateUser() function is a context, whose value we don’t care about, so it makes sense to use the Any() matcher for it.

However, using the same matcher for the second argument will weaken the test, because it won’t be able to detect whether the db.CreateUserParams object passed into the CreateUser() function is correct or not. I’m gonna show you how in a moment. For now, let’s just keep this gomock.Any() matcher.

The CreateUser() function is expected to be called exactly once, and it will return the user object with no error, since this is the happy case.

For the check response function, we just check that the HTTP status code is 200 OK, and make sure that the response body matches the created user object.

That’s it! Let’s run the test.


All passed. So we’re good, right?

Well, not really!

As I said before, this gomock.Any() matcher will make the test weaker.

How?

Let’s see what will happen if, in the createUser() handler, I set the arg variable to an empty db.CreateUserParams{} object.

func (server *Server) createUser(ctx *gin.Context) {
    ...

    arg := db.CreateUserParams{}

    user, err := server.store.CreateUser(ctx, arg)
    if err != nil {
        if pqErr, ok := err.(*pq.Error); ok {
            switch pqErr.Code.Name() {
            case "unique_violation":
                ctx.JSON(http.StatusForbidden, errorResponse(err))
                return
            }
        }
        ctx.JSON(http.StatusInternalServerError, errorResponse(err))
        return
    }

    ...
}

This will discard all the input parameters of the request, and will create a completely empty user in the database. So it should not be allowed, and we expect the test to fail, right?

However, if we run the test again:


It still passed!

This is very bad, because the implementation of the handler is completely wrong, but the test could not detect it!
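The problem can be reproduced in isolation. This stdlib-only sketch mirrors the shape of gomock's Matcher interface and its anyMatcher (both shown later in this lecture), with a local CreateUserParams struct standing in for db.CreateUserParams and made-up field values. The empty argument from the broken handler matches just as happily as a correct one.

```go
package main

import "fmt"

// Matcher mirrors the two-method shape of gomock's Matcher interface.
type Matcher interface {
	Matches(x interface{}) bool
	String() string
}

// CreateUserParams is a local stand-in for db.CreateUserParams.
type CreateUserParams struct {
	Username       string
	HashedPassword string
	FullName       string
	Email          string
}

// anyMatcher accepts every argument, like gomock.Any().
type anyMatcher struct{}

func (anyMatcher) Matches(interface{}) bool { return true }
func (anyMatcher) String() string           { return "is anything" }

func main() {
	var m Matcher = anyMatcher{}
	correct := CreateUserParams{Username: "alice", HashedPassword: "<hash>", FullName: "Alice", Email: "alice@example.com"}

	fmt.Println(m.Matches(correct))             // true
	fmt.Println(m.Matches(CreateUserParams{}))  // true: the empty arg slips through
	fmt.Println(m.Matches("not even a struct")) // true
}
```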

Another case that this test would not be able to detect is this one:

Let’s remove the statement that sets the empty argument, but this time I will ignore the user’s input password and just hash a constant value, such as "xyz".

func (server *Server) createUser(ctx *gin.Context) {
    ...

    hashedPassword, err := util.HashPassword("xyz")
    if err != nil {
        ctx.JSON(http.StatusInternalServerError, errorResponse(err))
        return
    }

    arg := db.CreateUserParams{
        Username:       req.Username,
        HashedPassword: hashedPassword,
        FullName:       req.FullName,
        Email:          req.Email,
    }

    user, err := server.store.CreateUser(ctx, arg)
    if err != nil {
        if pqErr, ok := err.(*pq.Error); ok {
            switch pqErr.Code.Name() {
            case "unique_violation":
                ctx.JSON(http.StatusForbidden, errorResponse(err))
                return
            }
        }
        ctx.JSON(http.StatusInternalServerError, errorResponse(err))
        return
    }

    ...
}

Then go back to the test, and run it again.


As you can see, it still passed!

This is unacceptable! The test we wrote is too weak! We need to fix it!

Try using gomock.Eq

One thing that might come to your mind is: what if instead of using gomock.Any() matcher, we use something else, such as the gomock.Eq() matcher? Let’s try it!

First, I will declare a new arg variable of type db.CreateUserParams, where Username is user.Username.

For the HashedPassword field, we need to hash the input naked password, so let’s go up to the top of the test. There, after generating the random user object, we call util.HashPassword() and pass in the generated password value.

This function will return a hashedPassword value and an error, so we have to make sure error is nil using require.NoError().

func TestCreateUserAPI(t *testing.T) {
    user, password := randomUser(t)

    hashedPassword, err := util.HashPassword(password)
    require.NoError(t, err)

    testCases := []struct {
        name          string
        body          gin.H
        buildStubs    func(store *mockdb.MockStore)
        checkResponse func(recorder *httptest.ResponseRecorder)
    }{
        {
            name: "OK",
            body: gin.H{
                "username":  user.Username,
                "password":  password,
                "full_name": user.FullName,
                "email":     user.Email,
            },
            buildStubs: func(store *mockdb.MockStore) {
                arg := db.CreateUserParams{
                    Username:       user.Username,
                    HashedPassword: hashedPassword,
                    FullName:       user.FullName,
                    Email:          user.Email,
                }
                store.EXPECT().
                    CreateUser(gomock.Any(), gomock.Eq(arg)).
                    Times(1).
                    Return(user, nil)
            },
            checkResponse: func(recorder *httptest.ResponseRecorder) {
                require.Equal(t, http.StatusOK, recorder.Code)
                requireBodyMatchUser(t, recorder.Body, user)
            },
        },
        ...
    }

    ...
}

Next, the FullName should be user.FullName, and finally Email should be user.Email.

With this object, we can now replace the gomock.Any() matcher with gomock.Eq(arg).

OK now, let’s test the case where all input parameters are discarded.

func (server *Server) createUser(ctx *gin.Context) {
    ...

    arg := db.CreateUserParams{}

    user, err := server.store.CreateUser(ctx, arg)
    if err != nil {
        if pqErr, ok := err.(*pq.Error); ok {
            switch pqErr.Code.Name() {
            case "unique_violation":
                ctx.JSON(http.StatusForbidden, errorResponse(err))
                return
            }
        }
        ctx.JSON(http.StatusInternalServerError, errorResponse(err))
        return
    }

    ...
}

Since we have used a stronger Eq() matcher here, the test should fail, right? Let’s run it to confirm!


Yes, that’s right! The test failed as expected.

The logs tell us that it failed because of a missing call. That’s right: although the CreateUser() function was called, it was not called with the input argument we expected.
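Here is the same kind of stdlib-only sketch for the equality matcher, mirroring gomock's eqMatcher (its reflect.DeepEqual implementation is shown later in this lecture) with a local CreateUserParams stand-in and made-up values. The empty struct no longer matches, which is why the expected call is reported as missing.

```go
package main

import (
	"fmt"
	"reflect"
)

// CreateUserParams is a local stand-in for db.CreateUserParams.
type CreateUserParams struct {
	Username       string
	HashedPassword string
	FullName       string
	Email          string
}

// eqMatcher mirrors gomock's equality matcher: a deep comparison of values.
type eqMatcher struct {
	x interface{}
}

func (e eqMatcher) Matches(x interface{}) bool { return reflect.DeepEqual(e.x, x) }
func (e eqMatcher) String() string             { return fmt.Sprintf("is equal to %v", e.x) }

func main() {
	expected := CreateUserParams{Username: "alice", HashedPassword: "h1", FullName: "Alice", Email: "alice@example.com"}
	m := eqMatcher{expected}

	fmt.Println(m.Matches(expected))           // true
	fmt.Println(m.Matches(CreateUserParams{})) // false: every field differs
	fmt.Println(m.String())                    // is equal to {alice h1 Alice alice@example.com}
}
```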

So is that it? Did we fix the issue?

Unfortunately, no! Let’s see what will happen if I remove the line arg := db.CreateUserParams{}.

func (server *Server) createUser(ctx *gin.Context) {
    ...

    arg := db.CreateUserParams{
        Username:       req.Username,
        HashedPassword: hashedPassword,
        FullName:       req.FullName,
        Email:          req.Email,
    }

    user, err := server.store.CreateUser(ctx, arg)
    if err != nil {
        if pqErr, ok := err.(*pq.Error); ok {
            switch pqErr.Code.Name() {
            case "unique_violation":
                ctx.JSON(http.StatusForbidden, errorResponse(err))
                return
            }
        }
        ctx.JSON(http.StatusInternalServerError, errorResponse(err))
        return
    }

    ...
}

Now the create user handler function correctly takes all input parameters into account, so we expect the test to pass, right? Let’s run it to confirm!


Sadly, the test didn’t pass. It still failed due to a missing call.

And the real reason is that the CreateUser() function is called with an input argument that doesn’t match the one we expect.

In the log, we can see clearly what value the mock store got, compared to what it wants to receive.

It looks like the Username, FullName, and Email are all matched. Only the HashedPassword values are different. Do you know why?

Well, if you still remember what we learned in the last lecture about bcrypt, it uses a random salt when hashing the password to prevent rainbow table attacks.

So even if we pass the same password value into the hash function, it will produce a different hashed value every time.

Because of this, the hashed value that we created in the test and the one in the create user handler will always be different. So we cannot simply use the built-in gomock.Eq() matcher to compare the argument.
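To see this property without a real password library, the sketch below swaps bcrypt (which lives in golang.org/x/crypto) for a toy salted SHA-256 scheme. hashPassword and checkPassword are hypothetical stand-ins for util.HashPassword and util.CheckPassword, and plain SHA-256 is not a safe password hash. Hashing the same password twice gives two different strings, yet both verify against the naked password.

```go
package main

import (
	"crypto/rand"
	"crypto/sha256"
	"crypto/subtle"
	"encoding/hex"
	"errors"
	"fmt"
	"strings"
)

// hashPassword is a toy salted hash: a random 16-byte salt plus
// SHA-256(salt || password), encoded as "<salt hex>$<digest hex>".
// Illustration only; real code should use bcrypt.
func hashPassword(password string) (string, error) {
	salt := make([]byte, 16)
	if _, err := rand.Read(salt); err != nil {
		return "", err
	}
	sum := sha256.Sum256(append(salt, password...))
	return hex.EncodeToString(salt) + "$" + hex.EncodeToString(sum[:]), nil
}

// checkPassword re-hashes password with the stored salt and compares digests.
func checkPassword(password, hashed string) error {
	saltHex, digestHex, ok := strings.Cut(hashed, "$")
	if !ok {
		return errors.New("malformed hash")
	}
	salt, err := hex.DecodeString(saltHex)
	if err != nil {
		return err
	}
	want, err := hex.DecodeString(digestHex)
	if err != nil {
		return err
	}
	sum := sha256.Sum256(append(salt, password...))
	if subtle.ConstantTimeCompare(sum[:], want) != 1 {
		return errors.New("password mismatch")
	}
	return nil
}

func main() {
	h1, _ := hashPassword("secret")
	h2, _ := hashPassword("secret")
	fmt.Println(h1 == h2)                           // false: the random salts differ
	fmt.Println(checkPassword("secret", h1) == nil) // true
	fmt.Println(checkPassword("secret", h2) == nil) // true
	fmt.Println(checkPassword("xyz", h1) == nil)    // false
}
```

Comparing two such hashes directly (as reflect.DeepEqual does on the whole struct) fails, while a check function succeeds, and that difference is what a custom matcher can build on.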

The proper way to fix this is to implement our own custom matcher. Although it sounds a bit annoying, it’s actually very easy to implement, and I think it will be useful if you ever encounter special cases like this in your real projects.

OK, let’s learn how to do it!

Implement a custom gomock matcher

First, I will remove the hashedPassword variable from the test, because it is not needed by the custom matcher that we’re going to implement.

We will have to replace the gomock.Eq matcher with our own matcher. So let’s open its implementation.

func Eq(x interface{}) Matcher { return eqMatcher{x} }

type Matcher interface {
    // Matches returns whether x is a match.
    Matches(x interface{}) bool

    // String describes what the matcher matches.
    String() string
}

It’s simply a function that takes the expected argument x as input and returns a Matcher interface. In this specific case, it returns an implementation of the matcher that matches on equality: eqMatcher.

For our custom matcher, we will have to write a similar implementation of the Matcher interface, which has only 2 methods:

  • The first one is Matches(), which should return whether the input x is a match or not.
  • And the second one is String(), which just describes what the matcher matches, for logging purposes.

There are several built-in implementations of the Matcher interface. For example, this one is anyMatcher, which will always return true regardless of the input argument.

type anyMatcher struct{}

func (anyMatcher) Matches(interface{}) bool {
    return true
}

func (anyMatcher) String() string {
    return "is anything"
}

Then this is the equal matcher that we’re using:

type eqMatcher struct {
    x interface{}
}

func (e eqMatcher) Matches(x interface{}) bool {
    return reflect.DeepEqual(e.x, x)
}

func (e eqMatcher) String() string {
    return fmt.Sprintf("is equal to %v", e.x)
}

It uses reflect.DeepEqual to compare the actual input argument with the expected one.

The custom matcher that we’re going to implement will be very similar to this one, so I’m gonna copy all of these functions and paste them at the top of the user_test.go file.

Then let’s change the name of this struct to eqCreateUserParamsMatcher. In order to compare the input arguments correctly, we will need to store 2 fields in this struct:

  • First the arg field of type db.CreateUserParams
  • And second, the password field to store the naked password value.

type eqCreateUserParamsMatcher struct {
    arg      db.CreateUserParams
    password string
}

OK, now let’s implement the Matches() function. Since the input parameter x is an interface, we use a type assertion to convert it to a db.CreateUserParams object.

If the conversion is not OK, then we just return false. Otherwise, we check whether the hashed password matches the expected password by calling the util.CheckPassword() function with e.password and arg.HashedPassword.

func (e eqCreateUserParamsMatcher) Matches(x interface{}) bool {
    arg, ok := x.(db.CreateUserParams)
    if !ok {
        return false
    }

    err := util.CheckPassword(e.password, arg.HashedPassword)
    if err != nil {
        return false
    }

    e.arg.HashedPassword = arg.HashedPassword
    return reflect.DeepEqual(e.arg, arg)
}

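If this function returns an error, then we return false. Otherwise, we set the HashedPassword field of the expected argument e.arg to the same value as the input arg.HashedPassword. Since Matches() has a value receiver, e is a copy, so this assignment doesn’t modify the matcher stored in the expectation.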

And we use reflect.DeepEqual to compare the expected argument e.arg with the input argument arg.

That’s all! Pretty simple, right?

OK, now let’s update this message of the String function to include the expected argument and naked password values.

func (e eqCreateUserParamsMatcher) String() string {
    return fmt.Sprintf("matches arg %v and password %v", e.arg, e.password)
}

And our custom matcher is done.
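One optional tweak, offered as a suggestion rather than part of the original code: %v prints struct fields without their names, so an empty HashedPassword shows up only as a double space in the failure log. This sketch, using a local CreateUserParams stand-in, a hypothetical describe() helper, and made-up values, compares %v with %+v, which adds field names.

```go
package main

import "fmt"

// CreateUserParams is a local stand-in for db.CreateUserParams.
type CreateUserParams struct {
	Username       string
	HashedPassword string
	FullName       string
	Email          string
}

// describe is a hypothetical helper that builds the matcher message,
// using %+v (with field names) when verbose is true and %v otherwise.
func describe(arg CreateUserParams, password string, verbose bool) string {
	if verbose {
		return fmt.Sprintf("matches arg %+v and password %v", arg, password)
	}
	return fmt.Sprintf("matches arg %v and password %v", arg, password)
}

func main() {
	arg := CreateUserParams{Username: "alice", FullName: "Alice", Email: "alice@example.com"}
	fmt.Println(describe(arg, "secret", false))
	// matches arg {alice  Alice alice@example.com} and password secret
	fmt.Println(describe(arg, "secret", true))
	// matches arg {Username:alice HashedPassword: FullName:Alice Email:alice@example.com} and password secret
}
```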

Next we will add a function to return an instance of this matcher, just like the Eq() function that returns an eqMatcher instance.

I’m gonna change this function’s name to EqCreateUserParams(), and it will have 2 input arguments: a db.CreateUserParams object, and a naked password string.

This function will return a Matcher interface, which in our case is the eqCreateUserParamsMatcher object with the input argument and password.

func EqCreateUserParams(arg db.CreateUserParams, password string) gomock.Matcher {
    return eqCreateUserParamsMatcher{arg, password}
}

Alright, so now we have everything we need for the new custom matcher. Let’s use it in the unit test to see how it goes.
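Putting the pieces together, here is a self-contained sketch that runs without gomock or the project packages. Assumptions: Matcher mirrors gomock's two-method interface, CreateUserParams stands in for db.CreateUserParams, and hashPassword/checkPassword are toy salted SHA-256 stand-ins for the bcrypt-based util.HashPassword and util.CheckPassword. It exercises the three scenarios from this lecture: correct params, an empty struct, and a hash of the constant "xyz".

```go
package main

import (
	"crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"reflect"
	"strings"
)

// Matcher mirrors the two-method shape of gomock's Matcher interface.
type Matcher interface {
	Matches(x interface{}) bool
	String() string
}

// CreateUserParams is a local stand-in for db.CreateUserParams.
type CreateUserParams struct {
	Username       string
	HashedPassword string
	FullName       string
	Email          string
}

// hashPassword is a toy stand-in for util.HashPassword: a random salt plus
// SHA-256, encoded as "<salt>$<digest>". Not a real password hash.
func hashPassword(password string) string {
	salt := make([]byte, 8)
	if _, err := rand.Read(salt); err != nil {
		panic(err)
	}
	sum := sha256.Sum256(append(salt, password...))
	return hex.EncodeToString(salt) + "$" + hex.EncodeToString(sum[:])
}

// checkPassword is a toy stand-in for util.CheckPassword: it re-hashes the
// password with the stored salt and compares the encoded results.
func checkPassword(password, hashed string) error {
	saltHex, _, ok := strings.Cut(hashed, "$")
	salt, err := hex.DecodeString(saltHex)
	if !ok || err != nil {
		return errors.New("malformed hash")
	}
	sum := sha256.Sum256(append(salt, password...))
	if hex.EncodeToString(salt)+"$"+hex.EncodeToString(sum[:]) != hashed {
		return errors.New("password mismatch")
	}
	return nil
}

type eqCreateUserParamsMatcher struct {
	arg      CreateUserParams
	password string
}

func (e eqCreateUserParamsMatcher) Matches(x interface{}) bool {
	arg, ok := x.(CreateUserParams)
	if !ok {
		return false
	}
	if err := checkPassword(e.password, arg.HashedPassword); err != nil {
		return false
	}
	// e is a copy (value receiver), so the stored matcher is not modified.
	e.arg.HashedPassword = arg.HashedPassword
	return reflect.DeepEqual(e.arg, arg)
}

func (e eqCreateUserParamsMatcher) String() string {
	return fmt.Sprintf("matches arg %v and password %v", e.arg, e.password)
}

// EqCreateUserParams returns the matcher as the interface type, like gomock.Eq.
func EqCreateUserParams(arg CreateUserParams, password string) Matcher {
	return eqCreateUserParamsMatcher{arg, password}
}

func main() {
	expected := CreateUserParams{Username: "alice", FullName: "Alice", Email: "alice@example.com"}
	m := EqCreateUserParams(expected, "secret")

	correct := expected
	correct.HashedPassword = hashPassword("secret")
	constant := expected
	constant.HashedPassword = hashPassword("xyz")

	fmt.Println(m.Matches(correct))            // true
	fmt.Println(m.Matches(CreateUserParams{})) // false: empty arg
	fmt.Println(m.Matches(constant))           // false: hash of "xyz"
}
```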

I’m gonna change the gomock.Eq() to EqCreateUserParams(), then pass in the CreateUserParams argument arg and naked password.

func TestCreateUserAPI(t *testing.T) {
    user, password := randomUser(t)

    testCases := []struct {
        name          string
        body          gin.H
        buildStubs    func(store *mockdb.MockStore)
        checkResponse func(recorder *httptest.ResponseRecorder)
    }{
        {
            name: "OK",
            body: gin.H{
                "username":  user.Username,
                "password":  password,
                "full_name": user.FullName,
                "email":     user.Email,
            },
            buildStubs: func(store *mockdb.MockStore) {
                arg := db.CreateUserParams{
                    Username: user.Username,
                    FullName: user.FullName,
                    Email:    user.Email,
                }
                store.EXPECT().
                    CreateUser(gomock.Any(), EqCreateUserParams(arg, password)).
                    Times(1).
                    Return(user, nil)
            },
            checkResponse: func(recorder *httptest.ResponseRecorder) {
                require.Equal(t, http.StatusOK, recorder.Code)
                requireBodyMatchUser(t, recorder.Body, user)
            },
        },
        ...
    }

    ...
}

Just like that, and we’re done. Let’s rerun the test!


It passed. Excellent!

Now let’s try the case where we set this argument to an empty db.CreateUserParams{} object.

func (server *Server) createUser(ctx *gin.Context) {
    ...

    arg := db.CreateUserParams{}

    user, err := server.store.CreateUser(ctx, arg)
    if err != nil {
        if pqErr, ok := err.(*pq.Error); ok {
            switch pqErr.Code.Name() {
            case "unique_violation":
                ctx.JSON(http.StatusForbidden, errorResponse(err))
                return
            }
        }
        ctx.JSON(http.StatusInternalServerError, errorResponse(err))
        return
    }

    ...
}

We expect the test to fail.


And it does fail, since the input argument doesn’t match the expected one.

OK, how about the case where we ignore the input password parameter, and just hash this constant password value: xyz

func (server *Server) createUser(ctx *gin.Context) {
    ...

    hashedPassword, err := util.HashPassword("xyz")
    if err != nil {
        ctx.JSON(http.StatusInternalServerError, errorResponse(err))
        return
    }

    ...
}

Let’s run the test!


It failed, just as we expected.

So now our unit test has become much stronger than it was before, thanks to the new custom matcher that we’ve just implemented.

And that’s it for today’s lecture. I hope you find it interesting and useful.

Thanks for reading, and see you in the next one!


If you like the article, please subscribe to our YouTube channel and follow us on Twitter or Facebook for more tutorials in the future.


If you want to join me on my current amazing team at Voodoo, check out our job openings here. Remote or onsite in Paris/Amsterdam/London/Berlin/Barcelona with visa sponsorship.
