If you’re writing your own programming language, you’ve probably heard of type inference. In this series of articles, we will walk through how it works without too much theory and implement our own in Rust.
What is type inference?
In the context of compilers, type inference is the stage of compilation in which the compiler infers the types of expressions. It is most closely associated with functional languages, but other language families can implement it too.
Algorithm
Disclaimer: there will be a minimum of theory, and some standard terms and definitions may be left out.
Before doing type analysis, a compiler needs to parse the source. In this article we will rely on concepts like AST and HIR as little as possible and will look at type inference through code examples rather than trees.
Let’s imagine that we have code in a hypothetical programming language for which we want to implement type inference:
func inc(a) {
    return a + 1
}
Since the types are not specified in the function signature, we will have to infer them ourselves. The first thing we need to do is create type variables:
func inc(a: ?1) -> ?2 {
    return a + 1;
}
A type variable, in this case, denotes a type that we have not yet inferred.
For simplicity and clarity, let’s assume that the signature of the + operator is (int, int) -> int. Then, from a + 1 it follows that the type of the parameter a must be int and the type of the expression 1 must also be int. And since a + 1 has type int and is the inc function’s return value, the function’s return type is int. Thus, we have obtained information about the relationships between types, or type constraints:
?1 = int
int = int
?2 = int
A type constraint, simply put, is a way to record information about relationships between types.
Here we use equality constraints, but we could also use subtyping constraints, for example for numeric types (which we will cover in this article).
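To make this more concrete, here is a minimal sketch of how such equality constraints could be recorded as plain data in Rust. The names Ty, Constraint, and constraints_for_add are made up for this illustration, and the sketch hard-codes the assumption above that + has the signature (int, int) -> int.

```rust
// Hypothetical, simplified types used only for this illustration.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Ty {
    Int,
    Str,
    Var(u32), // a type variable such as ?1
}

// An equality constraint of the form `left = right`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Constraint {
    left: Ty,
    right: Ty,
}

// Constraints produced by `lhs + rhs`, assuming `+` has the signature
// (int, int) -> int. `result` is the type the surrounding code expects,
// for example the function's return type variable.
fn constraints_for_add(lhs: Ty, rhs: Ty, result: Ty) -> Vec<Constraint> {
    vec![
        Constraint { left: lhs, right: Ty::Int },
        Constraint { left: rhs, right: Ty::Int },
        Constraint { left: result, right: Ty::Int },
    ]
}
```

Calling constraints_for_add(Ty::Var(1), Ty::Int, Ty::Var(2)) yields the three constraints listed above for a + 1. Passing Ty::Str as the second argument yields a constraint between a string type and int, which can never hold.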
Now that we have a list of these relationships, we can work out the values of the type variables. This process is called unification. The idea is to treat the list of type constraints as a system of equations and solve it for the type variables:
?1 = int
?2 = int
Readers might wonder what happens if the system of equations has no solution. The answer is simple: a mismatched types error. For example:
func inc(a: ?1) -> ?2 {
    return a + "hello";
}
?1 = int
String = int
?2 = int
Since String cannot be equal to int, the type constraint is unmet and the system of equations has no solution, so the compiler reports a mismatched types error.
Let’s take a more interesting example:
func generate_nums(count) {
    var a = [];
    for (var i = 0; i < count; i++) {
        a.insert(i);
    }
    return a;
}
First, as in the previous example, we annotate the names and the function signature with type variables:
func generate_nums(count: ?1) -> ?2 {
    var a: ?3 = [];
    for (var i: ?4 = 0; i < count; i++) {
        a.insert(i);
    }
    return a;
}
Now, if we look at the place where we declare a, we can see that an empty array is created. We can use a type variable to stand for the type of the array elements:
?3 = Array<?5>
Moving to the next statement:
for (var i: ?4 = 0; i < count; i++) {
Suppose that the signature of the ++ operator is (int) -> int and the signature of the < operator is (int, int) -> bool. So we have 3 new constraints:
?4 = int
?4 = ?1
?4 = int
Let’s continue!
a.insert(i);
Let’s suppose the insert method of an array is defined as something like this: Array<T>.insert(element: T). Then for the generic T we need one more type variable:
?3 = Array<?6> // a
?6 = ?4 // i
Now let’s analyze the final return statement:
func generate_nums(count: ?1) -> ?2 {
    var a: ?3 = [];
    ...
    return a;
}
The return value of the function is a, which means that:
?3 = ?2
Thus, we obtain a system of equations:
?3 = Array<?5>
?4 = int
?4 = ?1
?4 = int
?3 = Array<?6>
?6 = ?4
?3 = ?2
And again, using the unification algorithm, we can solve this system of equations and get:
?1 = int
?2 = Array<int>
?3 = Array<int>
?4 = int
?5 = int
?6 = int
Substituting these types, we get the result of type inference:
func generate_nums(count: int) -> Array<int> {
    var a: Array<int> = [];
    for (var i: int = 0; i < count; i++) {
        a.insert(i);
    }
    return a;
}
Now, perhaps the reader is wondering how this mysterious unification algorithm, which solves a system of equations built from type constraints, actually works. Let’s figure it out!
Implementation
First, let’s write a representation of types in our language:
// Imports used by the code in this section.
use std::collections::HashMap;
use std::iter::zip;
use std::sync::Arc;

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum Type {
    Constructor(TypeConstructor),
    Variable(TypeVariable),
}

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct TypeConstructor {
    name: String,
    generics: Vec<Arc<Type>>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct TypeVariable(usize);
So that (with string literals standing in for String values):
- int is TypeConstructor { name: "int", generics: Vec::new() }.
- List<int> is TypeConstructor { name: "List", generics: vec![Arc::new(Type::Constructor(TypeConstructor { name: "int", generics: vec![] }))] }.
- ?1 is TypeVariable(1).
Now we can start implementing the unification algorithm:
fn unify(
    left: Arc<Type>,
    right: Arc<Type>,
    substitutions: &mut HashMap<TypeVariable, Arc<Type>>,
) {
    match (left.as_ref(), right.as_ref()) {
If both types are type constructors, we check that their names and numbers of generic arguments are equal and then unify their generic arguments pairwise:
        (
            Type::Constructor(TypeConstructor {
                name: name1,
                generics: generics1,
            }),
            Type::Constructor(TypeConstructor {
                name: name2,
                generics: generics2,
            }),
        ) => {
            // Different names or arities mean mismatched types.
            assert_eq!(name1, name2);
            assert_eq!(generics1.len(), generics2.len());
            for (left, right) in zip(generics1, generics2) {
                unify(left.clone(), right.clone(), substitutions);
            }
        }
For example, from:
Array<int> = Array<?1>
it follows that:
int = ?1
If we get two different type constructors, then we have a type mismatch:
Array<...> != Option<...>
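The code above uses assert_eq!, so a mismatch makes the program panic. A real compiler would more likely want to report the error to the user. As a rough sketch under that assumption, the constructor case could be pulled out into a helper that returns a Result. The names Ty, Con, TypeError, decompose, and con are hypothetical and simplified; they are not part of the implementation in this article.

```rust
use std::sync::Arc;

// Hypothetical, simplified stand-ins for the article's types.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Ty {
    Constructor(Con),
    Variable(usize),
}

#[derive(Debug, Clone, PartialEq, Eq)]
struct Con {
    name: String,
    generics: Vec<Arc<Ty>>,
}

// Hypothetical error type describing why two constructors do not match.
#[derive(Debug, PartialEq, Eq)]
enum TypeError {
    NameMismatch { left: String, right: String },
    ArityMismatch { name: String, left: usize, right: usize },
}

// Compares two constructors. On success, returns the pairs of generic
// arguments that still have to be unified with each other.
fn decompose(left: &Con, right: &Con) -> Result<Vec<(Arc<Ty>, Arc<Ty>)>, TypeError> {
    if left.name != right.name {
        return Err(TypeError::NameMismatch {
            left: left.name.clone(),
            right: right.name.clone(),
        });
    }
    if left.generics.len() != right.generics.len() {
        return Err(TypeError::ArityMismatch {
            name: left.name.clone(),
            left: left.generics.len(),
            right: right.generics.len(),
        });
    }
    Ok(left
        .generics
        .iter()
        .cloned()
        .zip(right.generics.iter().cloned())
        .collect())
}

// Helper for building constructors in examples.
fn con(name: &str, generics: Vec<Arc<Ty>>) -> Con {
    Con { name: name.to_string(), generics }
}
```

With this helper, Array<int> against Array<?1> produces the single pair (int, ?1), while Array<...> against Option<...> produces a NameMismatch value that the caller can turn into a diagnostic.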
If both sides are equal type variables, then everything is fine and we do nothing:
        (Type::Variable(TypeVariable(i)), Type::Variable(TypeVariable(j))) if i == j => {}
Otherwise, at least one side is a type variable. If that variable already has a value in the storage, we unify that value with the other side. If not, we record the other side as the variable’s value and, importantly, check that this does not create an infinite type.
        (_, Type::Variable(v @ TypeVariable(..))) => {
            // The variable already has a value: unify with that value instead.
            if let Some(substitution) = substitutions.get(v) {
                unify(left, substitution.clone(), substitutions);
                return;
            }
            // Occurs check: reject infinite types such as ?1 = Array<?1>.
            assert!(!v.occurs_in(left.clone(), substitutions));
            substitutions.insert(*v, left);
        }
        (Type::Variable(v @ TypeVariable(..)), _) => {
            if let Some(substitution) = substitutions.get(v) {
                unify(right, substitution.clone(), substitutions);
                return;
            }
            assert!(!v.occurs_in(right.clone(), substitutions));
            substitutions.insert(*v, right);
        }
    }
}
An example of trying to create an infinite type in Rust is a call to push on a vector a whose argument is a.to_vec(). The generic parameter of push is T, and the value passed to it has the type of a.to_vec(), that is, Vec<T>. We get T = Vec<T>. The only possible solution for this constraint is Vec<Vec<Vec<Vec<Vec<Vec<Vec<….>>>>>>>. Of course, there are languages that allow this, but in this case, for simplicity and to avoid problems, we will not accept such types.
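As a small illustration of the idea, here is a sketch of such a check on a deliberately simplified type representation. Ty, occurs, and bind are hypothetical names used only here, and the sketch ignores any values already recorded for other variables.

```rust
// Hypothetical, simplified type representation for illustration only.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Ty {
    Int,
    VecOf(Box<Ty>), // stands for Vec<element>
    Var(u32),       // a type variable such as T or ?1
}

// Returns true if the type variable `var` appears anywhere inside `ty`.
fn occurs(var: u32, ty: &Ty) -> bool {
    match ty {
        Ty::Int => false,
        Ty::Var(other) => *other == var,
        Ty::VecOf(element) => occurs(var, element),
    }
}

// Tries to bind `var` to `ty`, refusing bindings that would only be
// satisfiable by an infinitely nested type.
fn bind(var: u32, ty: Ty) -> Result<(u32, Ty), String> {
    if occurs(var, &ty) {
        Err(format!("infinite type: ?{} occurs in {:?}", var, ty))
    } else {
        Ok((var, ty))
    }
}
```

Here bind(0, Ty::VecOf(Box::new(Ty::Var(0)))) models T = Vec<T> and is rejected, while bind(0, Ty::VecOf(Box::new(Ty::Int))) is accepted.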
Let’s now implement occurs_in, which checks whether a type variable occurs in a type. If the type is a constructor, we look through its generic arguments; if it is a variable, we follow its value when it has one and otherwise check whether it is the same variable:
impl TypeVariable {
    fn occurs_in(
        &self,
        ty: Arc<Type>,
        substitutions: &HashMap<TypeVariable, Arc<Type>>,
    ) -> bool {
        match ty.as_ref() {
            Type::Variable(v @ TypeVariable(i)) => {
                // Follow the variable's value, if it has one.
                if let Some(substitution) = substitutions.get(v) {
                    if substitution.as_ref() != &Type::Variable(*v) {
                        return self.occurs_in(substitution.clone(), substitutions);
                    }
                }
                self.0 == *i
            }
            Type::Constructor(TypeConstructor { generics, .. }) => {
                for generic in generics {
                    if self.occurs_in(generic.clone(), substitutions) {
                        return true;
                    }
                }
                false
            }
        }
    }
}
We will also create a function that recursively walks our store of type variable values and replaces every variable that has a value with that value, so that no such variables remain in the result. For example:
?1 = ?2
?2 = ?3
?3 = int
substitute(?1) = substitute(?2) = substitute(?3) = int
impl Type {
    fn substitute(&self, substitutions: &HashMap<TypeVariable, Arc<Type>>) -> Arc<Type> {
        match self {
            Type::Constructor(TypeConstructor { name, generics }) => {
                Arc::new(Type::Constructor(TypeConstructor {
                    name: name.clone(),
                    generics: generics
                        .iter()
                        .map(|t| t.substitute(substitutions))
                        .collect(),
                }))
            }
            Type::Variable(TypeVariable(i)) => {
                if let Some(t) = substitutions.get(&TypeVariable(*i)) {
                    t.substitute(substitutions)
                } else {
                    Arc::new(self.clone())
                }
            }
        }
    }
}
We did it, we wrote a unification algorithm! Let’s check it out in practice! Remember the previous example?
?3 = Array<?5>
?4 = int
?4 = ?1
?4 = int
?3 = Array<?6>
?6 = ?4
?3 = ?2
Before running it, let’s write two macros.
The first macro is a shorthand for creating a type variable:
macro_rules! tvar {
    ($i:expr) => {
        Arc::new(Type::Variable(TypeVariable($i)))
    };
}
The second macro is a shorthand for creating a type constructor:
macro_rules! tconst {
    // Generic arguments are separated by commas.
    ($name:expr, $($generic:expr),*) => {
        Arc::new(Type::Constructor(TypeConstructor {
            name: $name.to_string(),
            generics: vec![$($generic),*],
        }))
    };
    ($name:expr) => { tconst!($name,) };
}
Now let’s run our previous example through our implementation of the unification algorithm:
fn main() {
    let mut substitutions = HashMap::new();
    unify(tvar!(3), tconst!("Array", tvar!(5)), &mut substitutions);
    unify(tvar!(4), tconst!("int"), &mut substitutions);
    unify(tvar!(4), tvar!(1), &mut substitutions);
    unify(tvar!(4), tconst!("int"), &mut substitutions);
    unify(tvar!(3), tconst!("Array", tvar!(6)), &mut substitutions);
    unify(tvar!(6), tvar!(4), &mut substitutions);
    unify(tvar!(3), tvar!(2), &mut substitutions);
    for i in 1..=6 {
        println!(
            "{}: {:?}",
            i,
            Type::Variable(TypeVariable(i)).substitute(&substitutions)
        );
    }
}
We get:
1: Constructor(TypeConstructor { name: "int", generics: [] })
2: Constructor(TypeConstructor { name: "Array", generics: [Constructor(TypeConstructor { name: "int", generics: [] })] })
3: Constructor(TypeConstructor { name: "Array", generics: [Constructor(TypeConstructor { name: "int", generics: [] })] })
4: Constructor(TypeConstructor { name: "int", generics: [] })
5: Constructor(TypeConstructor { name: "int", generics: [] })
6: Constructor(TypeConstructor { name: "int", generics: [] })
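The Debug output is accurate but hard to read. One possible refinement, sketched here as an assumption rather than as part of this article’s implementation, is a Display implementation that prints types the way we wrote them in the constraints. The type definitions are repeated (with fewer derives) so that the snippet stands on its own.

```rust
use std::fmt;
use std::sync::Arc;

// Same shape as the article's types, repeated to keep the sketch self-contained.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Type {
    Constructor(TypeConstructor),
    Variable(TypeVariable),
}

#[derive(Debug, Clone, PartialEq, Eq)]
struct TypeConstructor {
    name: String,
    generics: Vec<Arc<Type>>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct TypeVariable(usize);

impl fmt::Display for Type {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Type variables are printed as ?1, ?2, ...
            Type::Variable(TypeVariable(i)) => write!(f, "?{}", i),
            // Constructors are printed as Name or Name<arg1, arg2, ...>.
            Type::Constructor(TypeConstructor { name, generics }) => {
                write!(f, "{}", name)?;
                if !generics.is_empty() {
                    write!(f, "<")?;
                    for (index, generic) in generics.iter().enumerate() {
                        if index > 0 {
                            write!(f, ", ")?;
                        }
                        write!(f, "{}", generic)?;
                    }
                    write!(f, ">")?;
                }
                Ok(())
            }
        }
    }
}
```

With something like this in place, the loop in main could use {} instead of {:?}, and a substituted array type would be printed in the Array<int> form used throughout the article.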
In the next article, we’ll start writing our little programming language and start defining the types of simple expressions like literals or arrays!