DEV Community

Shannon

Testing Vault in Go

Recently, I wrote some code for a CLI that reads all the secrets in a Vault path, and I went down the road of figuring out how to actually unit test some of my functions that are reaching out to Vault via API calls.

Importantly, my functions should not be reaching out to a real instance of Vault to do this testing. As a core tenet of unit testing, we shouldn't be reaching out to real APIs, live databases, etc. in unit tests. So, this meant I really had two options:

  1. Mock out the methods that are communicating directly with the Vault API using interfaces and dependency injection.
  2. Spin up a test Vault server, write secrets into it, and have the API calls run against this test server

Luckily, both were viable options! I am going to run through both as a way to work through mocking out these methods, but in my opinion the second is clearly superior, because the code under test talks to a real Vault API. Additionally, this is how Vault itself is tested.

Mocking the Vault Client to test Read() and List() methods

First, let's set up some scaffolding below. All of this can go in main.go:

package main

import (
    "errors"
    "os"

    "github.com/hashicorp/vault/api"
)

var (
    ErrSecretNotFound    = errors.New("no secret found at given path")
    ErrVaultVarsNotFound = errors.New("VAULT_TOKEN and VAULT_ADDR environment variables must be set")
)

// 1
type Logicaler interface {
    List(path string) (*api.Secret, error)
    Read(path string) (*api.Secret, error)
}

// 2
type VaultClient struct {
    client Logicaler
}

// 3
func NewVaultClient() (*VaultClient, error) {
    token := os.Getenv("VAULT_TOKEN")
    vaultAddr := os.Getenv("VAULT_ADDR")

    if token == "" || vaultAddr == "" {
        return &VaultClient{}, ErrVaultVarsNotFound
    }

    config := &api.Config{
        Address: vaultAddr,
    }

    client, err := api.NewClient(config)
    if err != nil {
        return &VaultClient{}, err
    }

    client.SetToken(token)

    return &VaultClient{
        client: client.Logical(),
    }, nil
}

// 4
func (v *VaultClient) ReadSecret(endpoint string) ([]string, error) {
    // 5
    secret, err := v.client.Read(endpoint)
    if err != nil {
        return []string{}, err
    }
    if secret == nil {
        return []string{}, ErrSecretNotFound
    }

    list := []string{}
    // Named value rather than v so it does not shadow the method receiver.
    for _, value := range secret.Data["data"].(map[string]interface{}) {
        list = append(list, value.(string))
    }

    return list, nil
}

// 6
func (v *VaultClient) ListSecret(path string) ([]string, error) {
    // 7
    secret, err := v.client.List(path)
    if err != nil {
        return []string{}, err
    }
    if secret == nil {
        return []string{}, ErrSecretNotFound
    }

    list := []string{}
    // Named key rather than v so it does not shadow the method receiver.
    for _, key := range secret.Data["keys"].([]interface{}) {
        list = append(list, key.(string))
    }

    return list, nil
}

Some notes on the code above to make it comprehensible:

  1. We are defining an interface with two methods: Read and List. Go interfaces are satisfied implicitly, so any type that has both methods qualifies. It's idiomatic in Go to give interfaces an -er suffix, so I opted for the name Logicaler.
  2. Create a VaultClient struct that hosts a single field: the interface Logicaler.
  3. Create a new Vault client that establishes a real connection to a Vault server. You can think of this as a constructor.
  4. Method of the VaultClient to read a secret in Vault
  5. This part is important! It is the actual communication with Vault. Notice this is one of the two methods in the interface to be satisfied. When mocking, you want to search out the functions or methods that are communicating directly with the API to be mocked.
  6. Same as 4. We are listing secrets in a Vault path.
  7. Same as 5! This is the other method that must be mocked and satisfied in the interface.

Let's create some additional testing scaffolding in our main_test.go file before explaining how we're going to mock this out:

package main

import (
    // Add "reflect" and "testing" here once TestReadSecret (below) is written.
    // Go refuses to compile a file with unused imports.
    "github.com/hashicorp/vault/api"
)

// 1
type MockVault struct{}

// 2
func (mv *MockVault) List(path string) (*api.Secret, error) {
    return &api.Secret{}, nil
}

func (mv *MockVault) Read(path string) (*api.Secret, error) {
    return &api.Secret{}, nil
}

// 3
func NewMockVaultClient() *VaultClient {
    return &VaultClient{
        client: &MockVault{},
    }
}

  1. We're creating a MockVault struct to emulate the actual client's methods (Read and List).
  2. We're creating those methods (although they only return empty secrets at the moment) to satisfy the Logicaler interface.
  3. Create a mock Vault client where the MockVault struct is passed into the client field, which expects a value of type Logicaler.
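
Because interfaces are satisfied implicitly, nothing tells you the mock has drifted until you try to use it. One common Go idiom is a compile-time assertion. The sketch below shows it with a simplified Secret type that I made up as a stand-in for api.Secret, so it compiles without the Vault module; in the real test file you would use *api.Secret.

```go
package main

import "fmt"

// Secret is a simplified stand-in for api.Secret (an assumption made so
// this sketch has no external dependencies).
type Secret struct {
	Data map[string]interface{}
}

// Logicaler mirrors the interface from main.go, using the stand-in type.
type Logicaler interface {
	List(path string) (*Secret, error)
	Read(path string) (*Secret, error)
}

type MockVault struct{}

func (mv *MockVault) List(path string) (*Secret, error) { return &Secret{}, nil }
func (mv *MockVault) Read(path string) (*Secret, error) { return &Secret{}, nil }

// Compile-time check: the build fails on this line if *MockVault ever
// stops satisfying Logicaler, for example after a method is renamed.
var _ Logicaler = (*MockVault)(nil)

func main() {
	var client Logicaler = &MockVault{}
	secret, err := client.Read("any/path")
	fmt.Println(secret != nil, err) // true <nil>
}
```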

With this code in place, check that everything compiles. Note that go build skips _test.go files, so use go vet, which type-checks the test files as well.


Now that we have a basic structure, we can create some mock data to be returned in the *api.Secret object. Let's update the Read(path string) method to return mock data:

func (mv *MockVault) Read(path string) (*api.Secret, error) {
    return &api.Secret{
        Data: map[string]interface{}{
            "data": map[string]interface{}{
                "value": "fakedata",
            },
        },
    }, nil
}

The Data field is the actual contents of the secret data. Unfortunately, this is a bit tough to read because of the multiple nested map[string]interface{} objects. It is worth peeking at the definition of the *api.Secret object in the Vault package to see how this is defined.
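
Those nested map[string]interface{} values are also where things tend to panic, because an unchecked type assertion blows up if the shape is not what you expected. Below is a small sketch of the same extraction using the two-value form of the assertion. The extractValues helper is hypothetical and works on a plain map, so it assumes nothing about the Vault package beyond the "data" nesting shown above.

```go
package main

import (
	"errors"
	"fmt"
)

// extractValues is a hypothetical helper. It pulls the string values out
// of a KV v2 style payload, where the secret sits under a nested "data" key.
func extractValues(data map[string]interface{}) ([]string, error) {
	inner, ok := data["data"].(map[string]interface{})
	if !ok {
		return nil, errors.New(`payload has no nested "data" map`)
	}
	values := make([]string, 0, len(inner))
	for key, raw := range inner {
		s, ok := raw.(string)
		if !ok {
			return nil, fmt.Errorf("value for %q is %T, not a string", key, raw)
		}
		values = append(values, s)
	}
	return values, nil
}

func main() {
	payload := map[string]interface{}{
		"data": map[string]interface{}{"value": "fakedata"},
	}
	fmt.Println(extractValues(payload)) // [fakedata] <nil>

	// A payload without the nested map yields an error instead of a panic.
	_, err := extractValues(map[string]interface{}{})
	fmt.Println(err != nil) // true
}
```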


We have some mock data being returned in our Read() method now, so let's write a test for this!

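// TestReadSecret needs "reflect" and "testing" in the import block.
func TestReadSecret(t *testing.T) {
    vc := NewMockVaultClient()
    want := []string{"fakedata"}
    got, err := vc.ReadSecret("")
    if err != nil {
        t.Fatalf("unexpected error: %v", err)
    }

    if !reflect.DeepEqual(want, got) {
        t.Errorf("got %v but want %v", got, want)
    }
}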

As you can see, we now create the client with NewMockVaultClient() instead of the real constructor. Then, we call the method ReadSecret(path string), which calls Read(path string) on its client field. Because that field is typed as the Logicaler interface and holds a *MockVault, the call goes to the mocked (mv *MockVault) Read(path string) method instead of the actual Vault client! Finally, we're testing that the two slices are exactly the same.

If this is confusing, try implementing the same thing with TestListSecret afterwards. Now, you can test this is working with go test --run TestReadSecret.

Spinning up a test Vault server in Go to test Read() and List()

Let's take a step back and think about the differences between mocking and using a real Vault test client. Because we are not mocking the client, we no longer need an interface to mock out the methods. We'll still be using the real methods but calling out to a test server spun up during the testing process. Thus, we can get rid of the interface and adjust some of our code.

Here is the scaffolding:

type VaultClient struct {
    // 1
    client *api.Client
}

func NewVaultClient() (*VaultClient, error) {
    token := os.Getenv("VAULT_TOKEN")
    vaultAddr := os.Getenv("VAULT_ADDR")

    if token == "" || vaultAddr == "" {
        return &VaultClient{}, ErrVaultVarsNotFound
    }

    config := &api.Config{
        Address: vaultAddr,
    }

    client, err := api.NewClient(config)
    if err != nil {
        return &VaultClient{}, err
    }

    client.SetToken(token)

    return &VaultClient{
        // 2
        client: client,
    }, nil
}

func (v *VaultClient) ReadSecret(endpoint string) ([]string, error) {
    // 3
    secret, err := v.client.Logical().Read(endpoint)
    if err != nil {
        return []string{}, err
    }
    if secret == nil {
        return []string{}, ErrSecretNotFound
    }

    list := []string{}
    // Named value rather than v so it does not shadow the method receiver.
    for _, value := range secret.Data["data"].(map[string]interface{}) {
        list = append(list, value.(string))
    }

    return list, nil
}

func (v *VaultClient) ListSecret(path string) ([]string, error) {
    // 4
    secret, err := v.client.Logical().List(path)
    if err != nil {
        return []string{}, err
    }
    if secret == nil {
        return []string{}, ErrSecretNotFound
    }

    list := []string{}
    // Named key rather than v so it does not shadow the method receiver.
    for _, key := range secret.Data["keys"].([]interface{}) {
        list = append(list, key.(string))
    }

    return list, nil
}

I've marked a couple spots to note again, so let's walk through these.

  1. We are now passing in an *api.Client from the Vault package instead of the Logicaler interface. Because we are no longer mocking anything, we don't need an interface.
  2. The client itself is now stored in the struct. We are no longer calling the Logical() method in the constructor.
  3. As you can see, the method calls look a bit different. Each time we use the client, we are running the Logical() method, which is used to return the client for logical-backend API calls.
  4. Same as 3. You now need to call client.Logical().* to run the Vault commands.

So, we have our code that is calling out to Vault. Now we need to write some testing that spins up a testing Vault server to run this against. Let's create the scaffolding for the test cluster:

// createTestVault spins up a test Vault cluster so tests can run against
// an actual Vault instance. Currently, this is only set up for kv v2
func createTestVault(t testing.TB) *hashivault.TestCluster {
    t.Helper()

    // CoreConfig parameterizes the Vault core config
    coreConfig := &hashivault.CoreConfig{
        LogicalBackends: map[string]logical.Factory{
            "kv": kv.Factory,
        },
    }

    cluster := hashivault.NewTestCluster(t, coreConfig, &hashivault.TestClusterOptions{
        // Handler returns an http.Handler for the API. This can be used on
        // its own to mount the Vault API within another web server.
        HandlerFunc: vaulthttp.Handler,
    })
    cluster.Start()

    // Create KV V2 mount on the path /test
    // It starts in cluster mode, so you just pick one of the three clients
    // In this case, Cores[0] is just always picking the first one
    if err := cluster.Cores[0].Client.Sys().Mount("test", &api.MountInput{
        Type: "kv",
        Options: map[string]string{
            "version": "2",
        },
    }); err != nil {
        t.Fatal(err)
    }

    return cluster
}

Note: most of this is grabbed from this GitHub issue (https://github.com/hashicorp/vault/issues/8440). It's worth reading in full and shows multiple ways to set up test clusters.

The comments lay most of the details out, but the important bit is that we're setting up a test Vault cluster with the /test path mounted.


We'll take the testing of ReadSecret piece by piece, so here is the initial setup that creates the Vault cluster, makes a client, and waits some time post-mount with the new cluster.

func TestReadSecrets(t *testing.T) {
    cluster := createTestVault(t)
    defer cluster.Cleanup()
    vaultClient := cluster.Cores[0].Client // only need a client from 1 of the 3 cores

    _ = &VaultClient{
        client: vaultClient,
    }

    // time buffer required after new mount
    // https://github.com/hashicorp/terraform-provider-vault/issues/677#issuecomment-609116328
    // Code 400: Errors: Upgrading from non-versioned to versioned data. This backend will be unavailable for a brief period and will resume service shortly.
    time.Sleep(2 * time.Second)
}

Of note, we are using the client of only a single core. The cluster spins up with 3 cores, and we are just choosing the first.


Next, we need to write some fake data into the path /test in order to read the secrets later on. Let's write some data:

    // set up sample data to write into vault
    testData := []struct {
        path  string
        key   string
        value string
    }{
        {"test/data/test0", "key0", "data0"},
        {"test/data/test1", "key1", "data1"},
        {"test/data/test2", "key2", "data2"},
    }

    // write k/v data pairs into vault
    for _, v := range testData {
        _, err := vc.client.Logical().Write(v.path, map[string]interface{}{
            "data": map[string]interface{}{
                v.key: v.value,
            },
        })
        if err != nil {
            t.Fatal(err)
        }
    }

In the above code, we are creating a slice of three structs and looping over it to write each sample key/value pair into Vault.
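
Notice the extra data segment in each path. With a KV v2 mount, reads and writes go through <mount>/data/<name>, while listing goes through <mount>/metadata/<name>. It is easy to forget that segment, so a pair of tiny helpers can keep the paths consistent. The helper names below are hypothetical and not part of the Vault package.

```go
package main

import (
	"fmt"
	"path"
)

// kvV2DataPath is a hypothetical helper that builds the read/write path
// for a secret on a KV v2 mount.
func kvV2DataPath(mount, name string) string {
	return path.Join(mount, "data", name)
}

// kvV2MetadataPath is a hypothetical helper that builds the path used
// for listing keys on a KV v2 mount.
func kvV2MetadataPath(mount, name string) string {
	return path.Join(mount, "metadata", name)
}

func main() {
	fmt.Println(kvV2DataPath("test", "test0")) // test/data/test0
	fmt.Println(kvV2MetadataPath("test", ""))  // test/metadata
}
```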


Finally, we create a test table that reads secrets and confirms their validity.

    testTable := []struct {
        name       string
        endpoint   string
        key        string
        want       []string
        vaultError error
    }{
        // 1
        {
            name:       "find a k/v match",
            endpoint:   "test/data/test0",
            key:        "key0",
            want:       []string{"data0"},
            vaultError: nil,
        },
        // 2
        {
            name:       "do not find a secret",
            endpoint:   "test/data/test123",
            key:        "test_0_key",
            want:       []string{},
            vaultError: ErrSecretNotFound,
        },
    }

    for _, tc := range testTable {
        t.Run(tc.name, func(t *testing.T) {
            secrets, err := vc.ReadSecret(tc.endpoint)
            if err != tc.vaultError {
                t.Fatalf("got error %v, want %v", err, tc.vaultError)
            }

            // 3
            // Check the lengths first so the loop below cannot index past tc.want.
            if len(secrets) != len(tc.want) {
                t.Fatalf("got %d secrets, want %d", len(secrets), len(tc.want))
            }
            for i := range secrets {
                assert.Equal(t, tc.want[i], secrets[i])
            }
        })
    }

  1. We're finding a valid match here. Thus, we shouldn't have any kind of error returned, so that's why vaultError is nil.
  2. We're not finding a valid match here. However, the Vault client does not return an error when there isn't a match. It returns a nil secret instead. We have logic built in to account for this in vc.ReadSecret, so we're checking for ErrSecretNotFound.
  3. Instead of using reflect, we loop over the returned secrets and compare each one. When an error is returned and there are no secrets, the loop body never runs, as no comparison is needed.
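
One caveat with comparing index by index: ReadSecret builds its slice by ranging over a map, and Go does not guarantee map iteration order. With one key per secret, as in this test, that never shows up, but a secret with several keys could come back in a different order on each run. A possible way around it is to sort copies of both slices before comparing. The equalIgnoringOrder helper below is hypothetical.

```go
package main

import (
	"fmt"
	"reflect"
	"sort"
)

// equalIgnoringOrder is a hypothetical helper. It compares two string
// slices after sorting copies, so the callers' slices are left untouched.
func equalIgnoringOrder(got, want []string) bool {
	g := append([]string(nil), got...)
	w := append([]string(nil), want...)
	sort.Strings(g)
	sort.Strings(w)
	return reflect.DeepEqual(g, w)
}

func main() {
	fmt.Println(equalIgnoringOrder([]string{"b", "a"}, []string{"a", "b"})) // true
	fmt.Println(equalIgnoringOrder([]string{"a"}, []string{"a", "b"}))      // false
}
```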

Phew! That is a lot of code. But we now have a working iteration of a test Vault cluster being created, having secrets written in, and then running our methods against it.

And for posterity, I'm going to post the entire main_test.go file below because there were so many additions. See you very soon!

package main

import (
    "testing"
    "time"

    kv "github.com/hashicorp/vault-plugin-secrets-kv"
    "github.com/hashicorp/vault/api"
    vaulthttp "github.com/hashicorp/vault/http"
    "github.com/hashicorp/vault/sdk/logical"
    hashivault "github.com/hashicorp/vault/vault"
    "gotest.tools/assert"
)

// createTestVault spins up a test Vault cluster so tests can run against
// an actual Vault instance. Currently, this is only set up for
// kv v2. Mostly copied from this github issue:
// https://github.com/hashicorp/vault/issues/8440
func createTestVault(t testing.TB) *hashivault.TestCluster {
    t.Helper()

    // CoreConfig parameterizes the Vault core config
    coreConfig := &hashivault.CoreConfig{
        LogicalBackends: map[string]logical.Factory{
            "kv": kv.Factory,
        },
    }

    cluster := hashivault.NewTestCluster(t, coreConfig, &hashivault.TestClusterOptions{
        // Handler returns an http.Handler for the API. This can be used on
        // its own to mount the Vault API within another web server.
        HandlerFunc: vaulthttp.Handler,
    })
    cluster.Start()

    // Create KV V2 mount on the path /test
    // It starts in cluster mode, so you just pick one of the three clients
    // In this case, Cores[0] is just always picking the first one
    if err := cluster.Cores[0].Client.Sys().Mount("test", &api.MountInput{
        Type: "kv",
        Options: map[string]string{
            "version": "2",
        },
    }); err != nil {
        t.Fatal(err)
    }

    return cluster
}

func TestReadSecrets(t *testing.T) {
    cluster := createTestVault(t)
    defer cluster.Cleanup()
    vaultClient := cluster.Cores[0].Client // only need a client from 1 of the 3 cores

    vc := &VaultClient{
        client: vaultClient,
    }

    // time buffer required after new mount
    // https://github.com/hashicorp/terraform-provider-vault/issues/677#issuecomment-609116328
    // Code 400: Errors: Upgrading from non-versioned to versioned data. This backend will be unavailable for a brief period and will resume service shortly.
    time.Sleep(2 * time.Second)

    testData := []struct {
        path  string
        key   string
        value string
    }{
        {"test/data/test0", "key0", "data0"},
        {"test/data/test1", "key1", "data1"},
        {"test/data/test2", "key2", "data2"},
    }

    // write k/v data pairs into vault
    for _, v := range testData {
        _, err := vc.client.Logical().Write(v.path, map[string]interface{}{
            "data": map[string]interface{}{
                v.key: v.value,
            },
        })
        if err != nil {
            t.Fatal(err)
        }
    }

    testTable := []struct {
        name       string
        endpoint   string
        key        string
        want       []string
        vaultError error
    }{
        // 1
        {
            name:       "find a k/v match",
            endpoint:   "test/data/test0",
            key:        "key0",
            want:       []string{"data0"},
            vaultError: nil,
        },
        // 2
        {
            name:       "do not find a secret",
            endpoint:   "test/data/test123",
            key:        "test_0_key",
            want:       []string{},
            vaultError: ErrSecretNotFound,
        },
    }

    for _, tc := range testTable {
        t.Run(tc.name, func(t *testing.T) {
            secrets, err := vc.ReadSecret(tc.endpoint)
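            if err != tc.vaultError {
                t.Fatalf("got error %v, want %v", err, tc.vaultError)
            }

            // 3
            // Check the lengths first so the loop below cannot index past tc.want.
            if len(secrets) != len(tc.want) {
                t.Fatalf("got %d secrets, want %d", len(secrets), len(tc.want))
            }
            for i := range secrets {
                assert.Equal(t, tc.want[i], secrets[i])
            }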
        })
    }
}
