DEV Community

David Berry

Posted on

A Prime Iterator in Rust

Inspired by Alessio Saltarin's Primality Test in Scala, I thought it would be interesting to implement the same primality test in Rust and then use it to build an Iterator that lazily generates prime numbers.

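The primality test is based on the fact that every prime greater than 3 is of the form 6k ± 1 for some positive integer k. The converse does not hold (25 = 6 × 4 + 1 is not prime), so the 6k ± 1 form only narrows down the candidates. The Rust version of this algorithm is: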

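```rust
/// Returns true if `n` is prime, using trial division by 6k ± 1.
pub fn is_prime(n: u64) -> bool {
  if n < 4 {
    // 2 and 3 are prime; 0 and 1 are not.
    n > 1
  } else if n % 2 == 0 || n % 3 == 0 {
    false
  } else {
    // Upper bound for the trial divisors, computed once.
    let max_p = (n as f64).sqrt().ceil() as u64;
    // Test p = 6k - 1 and p + 2 = 6k + 1 for k = 1, 2, ...
    match (5..=max_p).step_by(6).find(|p| n % p == 0 || n % (p+2) == 0) {
      Some(_) => false,
      None => true
    }
  }
}
```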

is_prime first handles n < 4 by returning n > 1, so n = 2 and n = 3 return true while n = 0 and n = 1 return false. The next test checks whether n is divisible by 2 or 3. If it is, n is not prime and the function returns false.
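Ruling out 2 and 3 leaves only 6k ± 1 candidates. Every integer can be written as 6k + r with r in 0..6, and because 6k is divisible by both 2 and 3, n has the same remainders mod 2 and mod 3 as r does. The short sketch below checks each remainder. The helper name surviving_residues is hypothetical and is not part of the original post.

```rust
/// Hypothetical helper: returns the remainders r in 0..6 for which
/// n = 6k + r is divisible by neither 2 nor 3. Since 6k is divisible by
/// both 2 and 3, n % 2 == r % 2 and n % 3 == r % 3.
fn surviving_residues() -> Vec<u64> {
    (0..6).filter(|r| r % 2 != 0 && r % 3 != 0).collect()
}

fn main() {
    // Only 6k + 1 and 6k + 5 (that is, 6(k + 1) - 1) survive.
    println!("{:?}", surviving_residues()); // [1, 5]
}
```

r = 0, 2 and 4 make n even, and r = 0 and 3 make it a multiple of 3. That leaves only the 6k + 1 and 6k - 1 forms the loop tests.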

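The final test checks whether n is divisible by any number of the form 6k ± 1 up to the square root of n. Stopping at the square root is enough because a composite n always has a prime factor no larger than √n. Since 2 and 3 have already been ruled out, that factor must itself be of the form 6k ± 1. The loop starts at p = 5 and checks whether n is divisible by p or p + 2 (6 - 1 and 6 + 1 for k = 1). If either one divides n, the function returns false. Otherwise p increases by 6 and the test repeats. If no divisor is found, n is prime and the function returns true.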
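To make the search concrete, the sketch below lists the candidate divisors that is_prime examines for a given n. The helper trial_divisors is hypothetical and not from the original post. It assumes n ≥ 5 and that n is not divisible by 2 or 3. Unlike is_prime, it does not stop at the first divisor, so it shows every candidate up to the bound.

```rust
/// Hypothetical helper: the 6k ± 1 candidates that `is_prime` would test
/// for `n`, using the same ceil(sqrt(n)) bound and step of 6.
fn trial_divisors(n: u64) -> Vec<u64> {
    let max_p = (n as f64).sqrt().ceil() as u64;
    (5..=max_p)
        .step_by(6)
        .flat_map(|p| [p, p + 2]) // 6k - 1 and 6k + 1
        .collect()
}

fn main() {
    // ceil(sqrt(97)) = 10, so only k = 1 is tried.
    println!("{:?}", trial_divisors(97)); // [5, 7]
    // sqrt(121) = 11, so k = 1 and k = 2 are tried.
    println!("{:?}", trial_divisors(121)); // [5, 7, 11, 13]
}
```

Since neither 5 nor 7 divides 97, is_prime(97) returns true after a single loop iteration. For 121, the search reaches p = 11 and finds the divisor.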

Many will point out that sqrt() and ceil() are expensive, but I only call them once, up front, to get the upper bound for the search. This removes the need to square each 6k ± 1 candidate inside the test loop to decide when to stop. It also let me write the search as the range (5..=max_p).step_by(6).find with the predicate |p| n % p == 0 || n % (p+2) == 0. That call returns Some(p) for the first p where p or p + 2 divides n, and None if there is no such p. I wrap the result in a match expression that returns false if a divisor was found and true if n is prime.
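For comparison, here is a sketch of the integer-only alternative the paragraph argues against. It avoids floating point and instead stops once p * p would exceed n. The name is_prime_int is hypothetical. Writing the bound as p <= n / p rather than p * p <= n keeps the comparison from overflowing u64, and the two conditions are equivalent for positive integers.

```rust
/// Hypothetical integer-only variant of the 6k ± 1 test: no sqrt/ceil,
/// the loop bound is checked on every iteration instead.
fn is_prime_int(n: u64) -> bool {
    if n < 4 {
        return n > 1; // 2 and 3 are prime; 0 and 1 are not
    }
    if n % 2 == 0 || n % 3 == 0 {
        return false;
    }
    let mut p = 5;
    // p <= n / p  <=>  p * p <= n, without risking overflow.
    while p <= n / p {
        if n % p == 0 || n % (p + 2) == 0 {
            return false;
        }
        p += 6;
    }
    true
}

fn main() {
    let small: Vec<u64> = (0..30).filter(|&n| is_prime_int(n)).collect();
    println!("{:?}", small); // [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
    println!("{}", is_prime_int(1_000_000_007)); // true
}
```

The trade-off is one integer division per iteration versus one sqrt and ceil before the loop. Which is faster in practice would need measuring.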

Now that we have a primality test, we can construct the lazy prime Iterator in Rust. The code for the Iterator is:

```rust
pub struct Prime {
  curr: u64, // prime returned by the next call
  next: u64, // the prime after `curr`
}

impl Prime {
  pub fn new() -> Prime {
    Prime {
      curr: 2,
      next: 3,
    }
  }
}

impl Iterator for Prime {
  type Item = u64;

  fn next(&mut self) -> Option<Self::Item> {
    let prime = self.curr;
    self.curr = self.next;
    // Step through 6k ± 1 candidates until the next prime is found.
    loop {
      self.next += match self.next%6 {
        1 => 4, // 6k + 1 -> 6(k + 1) - 1
        _ => 2, // 3 -> 5, or 6k - 1 -> 6k + 1
      };
      if is_prime(self.next) {
        break;
      }
    }
    Some(prime)
  }
}
```

A new Prime starts with curr = 2 and next = 3. This lets the iterator skip the special case of 2 and 3, the only two primes that differ by just 1. The next method saves the prime in curr, which it will return, moves next into curr, and then searches for a new next. The search relies on the fact that all primes greater than 3 are of the form 6k ± 1. If the latest candidate p is 6k - 1, adding 2 gives the next candidate, 6k + 1. If p is 6k + 1, adding 4 gives the next candidate, 6(k + 1) - 1. The match on p % 6 implements this rule. If p % 6 is 1, p is a 6k + 1 number and we add 4. Otherwise we add 2. The catch-all _ => 2 covers both the initial state p = 3 (3 + 2 = 5) and every p of the form 6k - 1 (p % 6 == 5). Each candidate is passed to is_prime, and the loop stops at the first one that is prime.
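The stepping rule on its own can be sketched as follows. The helper candidates is hypothetical. It separates the 6k ± 1 stepping from the primality check, so it yields candidates rather than primes. 25 = 6 × 4 + 1 shows up because it has the right form, and in the real iterator it is is_prime that rejects it.

```rust
/// Hypothetical helper: applies the `next` step rule `count` times,
/// starting after `start`. Adds 4 when the value is 6k + 1
/// (value % 6 == 1) and 2 otherwise.
fn candidates(start: u64, count: usize) -> Vec<u64> {
    let mut out = Vec::with_capacity(count);
    let mut n = start;
    for _ in 0..count {
        n += match n % 6 {
            1 => 4,
            _ => 2,
        };
        out.push(n);
    }
    out
}

fn main() {
    // Starting from the iterator's initial `next` value of 3.
    println!("{:?}", candidates(3, 10)); // [5, 7, 11, 13, 17, 19, 23, 25, 29, 31]
}
```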

This iterator uses constant memory and generates only one prime at a time, but it performs a modulo and a match inside the loop. After some consideration, I realized both can be removed. If we initialize Prime with the first two primes (2 and 3) and the first pair of 6k ± 1 candidates (5 and 7), each new candidate is simply the candidate two steps back plus 6. This works because 6k - 1 + 6 = 6(k + 1) - 1 and 6k + 1 + 6 = 6(k + 1) + 1. The final Iterator is:

```rust
pub struct Prime {
  curr: u64,   // prime returned by the next call
  next: u64,   // the prime after `curr`
  trial1: u64, // next 6k ± 1 candidate to test
  trial2: u64  // the candidate after `trial1`
}

impl Prime {
  pub fn new() -> Prime {
    Prime {
      curr: 2,
      next: 3,
      trial1: 5,
      trial2: 7
    }
  }
}

impl Iterator for Prime {
  type Item = u64;

  fn next(&mut self) -> Option<Self::Item> {
    let prime = self.curr;
    self.curr = self.next;
    loop {
      // Shift the candidate window forward. The new trial2 is the old
      // trial1 plus 6, so no modulo or match is needed.
      self.next = self.trial1;
      self.trial1 = self.trial2;
      self.trial2 = self.next+6;
      if is_prime(self.next) {
        break;
      }
    }
    Some(prime)
  }
}
```
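Because Prime never returns None, it is an infinite iterator, so in practice you bound it with adapters such as take, take_while, or nth. The sketch below repeats is_prime and the final Prime so it compiles on its own. It then uses the iterator that way and cross-checks the output against a plain filter over the integers. Here is_prime is written with Iterator::any, which expresses the same test as the find/match pair. The main function and the specific checks are illustrative additions, not part of the original post.

```rust
pub fn is_prime(n: u64) -> bool {
    if n < 4 {
        n > 1
    } else if n % 2 == 0 || n % 3 == 0 {
        false
    } else {
        let max_p = (n as f64).sqrt().ceil() as u64;
        // Prime if no 6k ± 1 candidate up to the bound divides n.
        !(5..=max_p).step_by(6).any(|p| n % p == 0 || n % (p + 2) == 0)
    }
}

pub struct Prime {
    curr: u64,
    next: u64,
    trial1: u64,
    trial2: u64,
}

impl Prime {
    pub fn new() -> Prime {
        Prime { curr: 2, next: 3, trial1: 5, trial2: 7 }
    }
}

impl Iterator for Prime {
    type Item = u64;

    fn next(&mut self) -> Option<Self::Item> {
        let prime = self.curr;
        self.curr = self.next;
        loop {
            self.next = self.trial1;
            self.trial1 = self.trial2;
            self.trial2 = self.next + 6;
            if is_prime(self.next) {
                break;
            }
        }
        Some(prime)
    }
}

fn main() {
    // First ten primes.
    let first_ten: Vec<u64> = Prime::new().take(10).collect();
    println!("{:?}", first_ten); // [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]

    // All primes below 50 (take_while stops at the first prime >= 50).
    let below_50: Vec<u64> = Prime::new().take_while(|&p| p < 50).collect();
    println!("{:?}", below_50);

    // The 1000th prime; nth is zero-based.
    println!("{:?}", Prime::new().nth(999)); // Some(7919)

    // Cross-check the first 500 primes against a plain filter.
    let expected: Vec<u64> = (0..).filter(|&n| is_prime(n)).take(500).collect();
    assert_eq!(Prime::new().take(500).collect::<Vec<u64>>(), expected);
}
```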

Full code can be found on GitHub

Discussion (1)

Alessio Saltarin

Good work!