Rachit Anand
Understanding Interactive Proof Systems and Sum Check Protocol: Part 2

In Part 1, we covered the math and the logic behind the Sum-Check Protocol. In this part, we will hack together a quick Sum-Check implementation for demonstration.

Prerequisites:

  • Familiarity with GoLang.
  • Understanding of Interactive Proofs.
  • Understanding of the maths behind Sum-Check Protocol (covered here)

First, let's write some utility functions in a file named utils.go.

utils.go:
We start with a helper function that returns the number of arguments taken by a function.

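import (
 "math"
 "reflect"
 "strconv"
)

type FuncType func(...int) int

// Arity returns the number of parameters declared by f, or -1 if f is not a function.
// Note: a variadic parameter counts as one, so any FuncType value reports 1.
func Arity(f interface{}) int {
 // Get the type of f, which should be a function.
 fType := reflect.TypeOf(f)
 if fType == nil || fType.Kind() != reflect.Func {
  return -1
 }
 // Return the number of input arguments.
 return fType.NumIn()
}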
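
To see what the reflection-based helper reports in practice, here is a minimal standalone sketch. It repeats Arity with an added nil guard, and the sample functions are made up for illustration. The point worth noticing is that a variadic function declares a single parameter as far as reflect is concerned.

```go
package main

import (
	"fmt"
	"reflect"
)

// Arity mirrors the helper above: it reports how many parameters f declares,
// or -1 when f is not a function (or is nil).
func Arity(f interface{}) int {
	fType := reflect.TypeOf(f)
	if fType == nil || fType.Kind() != reflect.Func {
		return -1
	}
	return fType.NumIn()
}

func main() {
	// Illustrative functions, not part of the tutorial code.
	variadic := func(args ...int) int { return len(args) }
	binary := func(a, b int) int { return a + b }

	fmt.Println(Arity(variadic)) // 1: the variadic parameter counts once
	fmt.Println(Arity(binary))   // 2
	fmt.Println(Arity(42))       // -1: not a function
}
```

So for polynomials written as func(...int) int, the reported arity does not reflect how many variables the body actually reads.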


Then we write a helper function that takes an integer n and a pad length, and returns the binary digits of n as a slice, front-padded with zeroes to the pad length.

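// ToBits returns the binary digits of n, most significant first,
// front-padded with zeroes to at least padToLen entries.
// Note: ToBits(0, 0) returns [0], because FormatInt renders 0 as "0".
func ToBits(n int, padToLen int) []int {
 binStr := strconv.FormatInt(int64(n), 2)

 // Never truncate: if n needs more digits than requested, widen the target.
 if len(binStr) > padToLen {
  padToLen = len(binStr)
 }

 // make zero-initializes v, so only the '1' digits need to be set.
 v := make([]int, len(binStr))
 for i, ch := range binStr {
  if ch == '1' {
   v[i] = 1
  }
 }
 diff := padToLen - len(v)

 paddedV := make([]int, diff)
 return append(paddedV, v...)
}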
  1. Binary Representation:
    The line binStr := strconv.FormatInt(int64(n), 2) converts the integer n into a binary string representation. For instance, if n is 6, binStr will be "110".

  2. Creating a Slice of Integers:
    The subsequent loop creates a slice v of integers where each element corresponds to a digit in the binary string. '1's in the binary string are represented as 1 in the slice, and '0's as 0.

  3. Padding:
    The function then calculates diff, the difference between the desired length padToLen and the actual length of the binary representation. If the binary representation is shorter than padToLen, it needs to be padded with zeroes. paddedV := make([]int, diff) creates a slice of zeroes of length diff. Finally, the function returns a new slice that concatenates the padding and the binary representation slice (append(paddedV, v...)).
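
As a quick check of the steps above, the following standalone sketch repeats ToBits and prints a few padded vectors, including every 2-bit pattern that the prover will later enumerate. The inputs are chosen only for illustration.

```go
package main

import (
	"fmt"
	"strconv"
)

// ToBits repeats the helper from utils.go: binary digits of n,
// most significant first, front-padded with zeroes to padToLen.
func ToBits(n int, padToLen int) []int {
	binStr := strconv.FormatInt(int64(n), 2)
	if len(binStr) > padToLen {
		padToLen = len(binStr)
	}
	v := make([]int, len(binStr))
	for i, ch := range binStr {
		if ch == '1' {
			v[i] = 1
		}
	}
	paddedV := make([]int, padToLen-len(v))
	return append(paddedV, v...)
}

func main() {
	fmt.Println(ToBits(6, 5)) // [0 0 1 1 0]
	fmt.Println(ToBits(6, 2)) // [1 1 0]: never truncated

	// All 2-bit patterns, in the order a loop from 0 to 3 produces them.
	for i := 0; i < 4; i++ {
		fmt.Println(ToBits(i, 2))
	}
}
```
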

Lastly, we write a utility function that returns the degree of the j'th variable in g.

// DegJ estimates the degree of the j'th variable (0-indexed) of g.
// It sets every other variable to 1, evaluates g with the j'th variable at
// 100 and at 1000, and returns the smallest exponent that explains the growth.
// Only j+1 arguments are supplied, so g must not read variables beyond index j.
func DegJ(g FuncType, j int) int {
 args := make([]int, j+1)
 for i := range args {
  args[i] = 1
 }

 args[j] = 100
 out1 := g(args...)

 args[j] = 1000
 out2 := g(args...)

 // Consider f(x) = x³: f(100)/100³ and f(1000)/1000³ are both 1, while for
 // any smaller exponent the two scaled outputs differ widely.
 // The loop returns the first exponent for which they agree to within 1.
 for exp := 1; exp <= 10; exp++ {
  scaled1 := float64(out1) / math.Pow(100, float64(exp))
  scaled2 := float64(out2) / math.Pow(1000, float64(exp))
  if math.Abs(scaled1-scaled2) < 1 {
   return exp
  }
 }
 panic("exp grew larger than 10")
}
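
To make the scaling heuristic concrete, here is a minimal standalone sketch. estimateDegree is a hypothetical single-variable stand-in for DegJ, and the three sample polynomials are made up for illustration. It returns -1 instead of panicking when no exponent up to 10 fits.

```go
package main

import (
	"fmt"
	"math"
)

// estimateDegree is a hypothetical, single-variable stand-in for DegJ.
// It returns the smallest exponent e in 1..10 for which f(100)/100^e and
// f(1000)/1000^e agree to within 1, or -1 if none does.
func estimateDegree(f func(int) int) int {
	out1, out2 := float64(f(100)), float64(f(1000))
	for e := 1; e <= 10; e++ {
		s1 := out1 / math.Pow(100, float64(e))
		s2 := out2 / math.Pow(1000, float64(e))
		if math.Abs(s1-s2) < 1 {
			return e
		}
	}
	return -1
}

func main() {
	cubic := func(x int) int { return x*x*x + 2*x }
	quadratic := func(x int) int { return x*x + 2*x }
	linear := func(x int) int { return 3*x + 1 }

	fmt.Println(estimateDegree(cubic))     // 3
	fmt.Println(estimateDegree(quadratic)) // 2
	fmt.Println(estimateDegree(linear))    // 1
}
```

This is an estimate rather than an exact degree computation: a constant polynomial, for example, is reported as degree 1 because the search starts at exponent 1.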

verifier.go:
This is where all the logic related to the verifier takes place.

Firstly we define the verifier as follows:

type Verifier struct {
 g                FuncType
 gArity           int // number of inputs to the polynomial g
 H                int // claimed sum of g over all boolean inputs
 randomChallenges []int
 round            int
 polynomials      []FuncType
}

Now, we need a function that accepts a polynomial sent by the prover:

func (v *Verifier) RecievePolynomials(polynomial FuncType) {
 v.polynomials = append(v.polynomials, polynomial)
}

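Next, we need a function that verifies that the latest polynomial is a univariate polynomial of degree at most deg_j(g), and that it satisfies the round-consistency relation:

$$g_{j-1}(r_{j-1}) = g_j(0) + g_j(1)$$

In the first round there is no previous polynomial, so the sum g_1(0) + g_1(1) is compared against the claimed value H instead.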

For this, we write another function on the verifier struct:

func (v *Verifier) CheckLatestPolynomial() error {
 latestPoly := v.polynomials[len(v.polynomials)-1]

 // Degree check: g_j must not exceed the degree of g in its j'th variable.
 degLatest := DegJ(latestPoly, 0)
 degMax := DegJ(v.g, v.round-1)

 if degLatest > degMax {
  return fmt.Errorf("prover sent polynomial of degree %d, greater than expected %d", degLatest, degMax)
 }

 // Consistency check: g_j(0) + g_j(1) must equal H in round 1,
 // and g_{j-1}(r_{j-1}) in every later round.
 newSum := latestPoly(0) + latestPoly(1)

 var check int

 if v.round == 1 {
  check = v.H
 } else {
  check = v.polynomials[len(v.polynomials)-2](v.randomChallenges[len(v.randomChallenges)-1])
 }

 if check != newSum {
  return fmt.Errorf("prover sent incorrect polynomial: got %d, expected %d", newSum, check)
 }

 return nil
}
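
The checks above can be traced by hand on a small case. The sketch below is a standalone illustration rather than part of the tutorial code: g is a hypothetical two-variable polynomial, the round polynomials g1 and g2 are written out directly as an honest prover would compute them, and the challenges are fixed to 3 and 5 so that the run is reproducible.

```go
package main

import "fmt"

// g is a hypothetical two-variable polynomial chosen for illustration.
func g(x1, x2 int) int { return x1*x2 + 2*x1 + x2 }

// claimedSum computes H, the sum of g over all boolean inputs.
func claimedSum() int {
	sum := 0
	for x1 := 0; x1 <= 1; x1++ {
		for x2 := 0; x2 <= 1; x2++ {
			sum += g(x1, x2)
		}
	}
	return sum
}

// g1 is the round-1 polynomial: x1 stays free, x2 is summed over {0, 1}.
func g1(x int) int { return g(x, 0) + g(x, 1) }

// g2 is the round-2 polynomial: x1 is fixed to the challenge r1, x2 is free.
func g2(r1, x int) int { return g(r1, x) }

func main() {
	h := claimedSum()
	r1, r2 := 3, 5 // fixed stand-ins for the verifier's random challenges

	fmt.Println("H =", h)                                          // H = 7
	fmt.Println("round 1 ok:", g1(0)+g1(1) == h)                   // true
	fmt.Println("round 2 ok:", g2(r1, 0)+g2(r1, 1) == g1(r1))      // true
	fmt.Println("final check ok:", g(r1, r2) == g2(r1, r2))        // true
}
```

Here g1(X) works out to 5X + 1, so g1(0) + g1(1) = 1 + 6 = 7 = H, and g1(3) = 16 is exactly what g2(0) + g2(1) = 6 + 10 must add up to in the second round.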

As the protocol requires, the verifier also needs a mechanism to generate a random number r and send it as a challenge to the prover. To achieve this:

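func (v *Verifier) GetNewRandomValueAndSend(p *Prover) {
 // Seeding on every call is kept for simplicity; rand.Seed is deprecated as of Go 1.20.
 rand.Seed(time.Now().UnixNano())
 // Demo only: challenges are drawn from {0, 1}. For real soundness they
 // should be sampled from a large finite field.
 challenge := rand.Intn(2)
 v.randomChallenges = append(v.randomChallenges, challenge)
 p.ReceiveChallenge(challenge)
 v.round++
}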
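
The function above draws each challenge from {0, 1}, which keeps the integer arithmetic small for a demo. As a rough sketch of what sampling from a large field could look like, the snippet below uses crypto/rand and math/big. The names fieldModulus and RandomChallenge are hypothetical, the prime 2^61 - 1 is picked only for illustration, and this is not a drop-in replacement: the rest of the tutorial code works on plain int values without modular reduction.

```go
package main

import (
	"crypto/rand"
	"fmt"
	"math/big"
)

// fieldModulus is a hypothetical prime chosen for illustration: 2^61 - 1.
var fieldModulus = big.NewInt(2305843009213693951)

// RandomChallenge returns a uniformly random value in [0, fieldModulus),
// drawn from the operating system's cryptographic randomness source.
func RandomChallenge() (*big.Int, error) {
	return rand.Int(rand.Reader, fieldModulus)
}

func main() {
	r, err := RandomChallenge()
	if err != nil {
		fmt.Println("could not sample a challenge:", err)
		return
	}
	fmt.Println("challenge:", r)
}
```

With a field this large, a cheating prover's chance of surviving a round by luck becomes tiny, whereas with only two possible challenges a wrong polynomial can easily agree with the right one at the sampled point.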

Lastly, in the final step, the verifier holds all the random challenges and uses them to evaluate g directly. This value should match the final polynomial gₙ sent by the prover, evaluated at the last challenge rₙ. If the values match, the verifier accepts.

func (v *Verifier) EvaluateAndCheckGV() (bool, error) {
 if len(v.randomChallenges) != v.gArity-1 {
  return false, fmt.Errorf("incorrect number of random challenges: got %d, expected %d", len(v.randomChallenges), v.gArity-1)
 }

 // Draw the last challenge; it is used locally and never sent to the prover.
 v.randomChallenges = append(v.randomChallenges, rand.Intn(2))

 // Evaluate g at all challenges and compare with the last polynomial
 // evaluated at the last challenge.
 gFinal := v.g(v.randomChallenges...)
 check := v.polynomials[len(v.polynomials)-1](v.randomChallenges[len(v.randomChallenges)-1])

 if gFinal != check {
  return false, fmt.Errorf("prover sent incorrect final polynomial")
 }

 fmt.Println("VERIFIER ACCEPTS")
 return true, nil
}

And that’s it. We now have basic code for a verifier in the Sum-Check Protocol!

prover.go:

This is where all the logic related to the prover goes.

Firstly, we define the prover struct as follows:

type Prover struct {
 gArity            int
 randomChallenges  []int
 cachedPolynomials []FuncType
 round             int
 H                 int
}

Recall that in the Sum-Check protocol, a new polynomial is computed in each round, and the number of variables summed over equals the original number of arguments minus the round number.

The code iterates from 0 to 2^pad - 1. In each iteration, ToBits creates the binary representation of the loop counter i, padded to pad bits. The first argument (args[0]) is prepended to this binary slice, and the result is passed to the polynomial function poly.

In other words, gJ builds inputs for poly by combining its own first argument with every binary pattern for the remaining variables, and sums the outputs. This partial sum is the polynomial the prover sends in each round of the sum-check protocol.

func (p *Prover) ComputeAndSendNextPolynomial(v *Verifier) {
 round := p.round
 // poly is g with the challenges received so far already fixed.
 poly := p.cachedPolynomials[len(p.cachedPolynomials)-1]

 gJ := func(args ...int) int {
  if len(args) == 0 {
   // Handle the case where no arguments are passed
   panic("gJ requires at least one argument")
  }
  // pad is the number of variables still summed over {0, 1}.
  pad := p.gArity - round
  var sum int
  for i := 0; i < (1 << pad); i++ {
   // args[0] is the free variable; the rest is one boolean assignment.
   point := append([]int{args[0]}, ToBits(i, pad)...)
   sum += poly(point...)
  }
  return sum
 }

 v.RecievePolynomials(gJ)
 p.round++
}
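
To see the enumeration on concrete numbers, here is a standalone sketch. poly is a hypothetical three-variable polynomial and partialSum is an illustrative helper, neither of which appears in the tutorial code. It builds the bit patterns with shifts instead of ToBits so that the example has no other dependencies.

```go
package main

import "fmt"

// poly is a hypothetical 3-variable polynomial used only for illustration.
func poly(args ...int) int { return args[0]*args[1] + args[2] + 2*args[0] }

// partialSum fixes the first variable to x and sums poly over every boolean
// assignment of the remaining `free` variables.
func partialSum(x, free int) int {
	sum := 0
	for i := 0; i < (1 << free); i++ {
		point := []int{x}
		for b := free - 1; b >= 0; b-- {
			point = append(point, (i>>b)&1) // most significant bit first
		}
		sum += poly(point...)
	}
	return sum
}

func main() {
	// With two free variables the loop visits 00, 01, 10, 11.
	fmt.Println(partialSum(0, 2)) // 2
	fmt.Println(partialSum(1, 2)) // 12

	// Adding the two values gives the sum over all eight boolean inputs.
	fmt.Println(partialSum(0, 2) + partialSum(1, 2)) // 14
}
```

For this poly the partial sum is 10x + 2 as a function of the free variable x, which is the kind of univariate polynomial the prover hands to the verifier in the first round.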

The iteration in the gJ function is fundamental to the sum-check protocol: it systematically evaluates a polynomial over all binary combinations of a given length. Here's why we iterate this way:

  • Exhaustive Evaluation: The goal is to evaluate the polynomial poly for every possible combination of binary inputs for the remaining variables, so that the sum covers the entire boolean domain.
  • Binary Combinations: By iterating from 0 to 2^pad − 1 and converting these numbers to binary, we generate all possible combinations of binary digits of length pad. This is a standard way to cover all cases in binary representation.
  • Reducing Complexity Per Round: In each round of the sum-check protocol, one fewer variable is summed over, reducing the computational complexity step by step. pad decreases with each round, reflecting this reduction.

Lastly, we need a function to receive challenges from the verifier:

func (p *Prover) ReceiveChallenge(challenge int) {
 p.randomChallenges = append(p.randomChallenges, challenge)
 p.CacheNext(challenge)
 fmt.Printf("Received challenge %d, initiating round %d\n", challenge, p.round)
}

where CacheNext is simply:

func (p *Prover) CacheNext(challenge int) {
 poly := p.cachedPolynomials[len(p.cachedPolynomials)-1]

 nextPoly := func(args ...int) int {
  return poly(append([]int{challenge}, args...)...)
 }

 p.cachedPolynomials = append(p.cachedPolynomials, nextPoly)
}
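
The effect of caching is easiest to see in isolation. In the sketch below, fix is a hypothetical standalone version of the closure built in CacheNext, and g is a made-up three-variable polynomial. Each call pins the leading free variable, so challenges are applied in the order they arrive.

```go
package main

import "fmt"

type FuncType func(...int) int

// fix returns poly with its leading variable pinned to challenge,
// using the same prepend trick as CacheNext.
func fix(poly FuncType, challenge int) FuncType {
	return func(args ...int) int {
		return poly(append([]int{challenge}, args...)...)
	}
}

func main() {
	// Hypothetical polynomial g(x1, x2, x3) = x1*x2 + x3.
	g := func(args ...int) int { return args[0]*args[1] + args[2] }

	afterR1 := fix(g, 3)       // behaves like g(3, x2, x3)
	afterR2 := fix(afterR1, 5) // behaves like g(3, 5, x3)

	fmt.Println(afterR1(5, 7)) // 22
	fmt.Println(afterR2(7))    // 22
	fmt.Println(g(3, 5, 7))    // 22
}
```

Because every layer only prepends one value, the prover never needs to rewrite g itself; it just keeps wrapping the previous polynomial.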

Now to run this protocol, we can define a simple main.go file with the following code to run the prover and the verifier:

type SumcheckProtocol struct {
 gArity int
 p      *Prover
 v      *Verifier
 round  int
 done   bool
}

// do the initialization
....


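// Advance protocol by 1 round
func (s *SumcheckProtocol) AdvanceRound() {
 if s.done {
  panic("Sumcheck protocol is finished")
 }

 s.p.ComputeAndSendNextPolynomial(s.v)
 // Stop as soon as the verifier rejects; ignoring this error would let
 // the protocol continue after a failed check.
 if err := s.v.CheckLatestPolynomial(); err != nil {
  panic(err)
 }

 if s.round == s.gArity {
  // final round
  ok, err := s.v.EvaluateAndCheckGV()
  if err != nil {
   panic(err)
  }
  s.done = ok
 } else {
  s.v.GetNewRandomValueAndSend(s.p)
  s.round++
 }
}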

// Advance protocol to the end
func (s *SumcheckProtocol) AdvanceToEnd(verbose bool) {
 for !s.done {
  if verbose {
   fmt.Println("Advance Output:", s)
  }

  s.AdvanceRound()
 }
}

Now it's time to test the protocol!
To do this, we create a new file main_test.go and add the following lines:

func TestSumcheck(t *testing.T) {

 g := func(args ...int) int {
  a := args[0]
  return a + a + a*a
 }

 protocol := NewSumcheckProtocol(g)
 fmt.Print(protocol)
 protocol.AdvanceToEnd(true)

 f := func(args ...int) int {
  a := args[0]
  return a*a*a + a + a
 }

 protocol = NewSumcheckProtocol(f)
 protocol.AdvanceToEnd(true)

 ff := func(args ...int) int {
  a := args[0]
  return a*a*a + a + a + a*a
 }
 protocol = NewSumcheckProtocol(ff)
 protocol.AdvanceToEnd(true)
}

And Hurray 🎉 🎉 🎉 🎉 !! You just wrote the Sum-Check protocol from scratch.

Here is the complete source code in the tutorial: Github

Lastly, connect with me on my socials: Twitter, and LinkedIn. Happy to talk about cryptography, blockchain, and any computer science topic in general.
