
On Automata in Rust

By Radosław Miernik


Intro

While working on my PhD, I had to implement an interpreter of the language we’re working on. At the very core of it are finite automata and finite state machines, as well as tons of operations on them. In practice, it’s a directed graph with operations on the edges (assignments, comparisons, etc.).

As I’m most familiar with JavaScript, I implemented the first version of the interpreter in it. That allowed me to get a working version rather fast; as a bonus, it works in the browser, so it’s much more accessible to other people. However, at some point, the performance was no longer good enough to work with interactively.

Because of that, I decided to implement the same interpreter in a low-level language. As the automata definition was a simple JSON object, I decided to go with Rust, which I was already familiar with. Thanks to Serde, I had a 1-to-1 matching between the TypeScript and Rust structures within minutes.

In this text, I’ll focus on the performance of one operation that had problems in JavaScript – checking reachability. In the future, I plan to dive more into the whole project and what the rest of the code looks like.

Shall we?

Naive implementation

We’ll need a couple of definitions first. An Automaton (graph) is a collection of Edges. Every Edge connects two Nodes with a label. It has only one, rather intuitive operation: add_edge. The exact representation of it is not relevant to the algorithm, so we can use whatever we find suitable.

Example automaton (figure): nodes `begin`, `operate`, and `end`, with edges labeled with operations such as `x = 0;`, `x != 3;`, and `x == 3; x = fn[x];`.

There’s also State, which describes the state the automaton is in. Most importantly, it stores a position, which points to one of the Nodes. It has one crucial operation – next_states – which is used to iterate over the following states, i.e., the states that can be reached from the current one. To do that, it has to go through all of the edges that connect the position with other nodes and check whether each transition is allowed in the current state.

An example usage of next_states is the dreaded is_reachable function, which checks whether it’s possible to traverse the graph to a specified Node. It’s a recursive operation, potentially traversing the entire graph. In practice, it visits only a small subset of it, but it’s called very, very often.

Let’s get to coding, starting with the basic type definitions. They’re the same for all versions, so I won’t duplicate them in each. They’re simplified, but enough for the rest of the examples to run.

// `pub` means this type can be imported in other modules. The module
// system in Rust is slightly different, but think of it as `export`
// in JavaScript.

// There's no information stored in the node, so a number is enough.
pub type Node = usize;

// Directed pair of nodes.
pub struct Edge {
  lhs: Node,
  rhs: Node,
}

// A state of the automaton. In reality, there's much more data here.
// The `derive` attribute here generates a `default` function that
// creates an empty state, i.e., a state that is in `0`.
#[derive(Default)]
pub struct State {
  position: Node,
}

Now, let’s start with the most straightforward implementation I can imagine – one that I believe anyone who has never worked with Rust would write.

pub struct Automaton {
  // A list of edges.
  edges: Vec<Edge>,
}

impl Automaton {
  pub fn add_edge(&mut self, edge: Edge) {
    // Add it to the list.
    self.edges.push(edge)
  }
}

impl State {
  pub fn is_reachable(&self, automaton: &Automaton, position: Node) -> bool {
    // We're either already there, or...
    self.position == position ||
    // ...one of the next states is.
    self
      .next_states(automaton)
      // Convert `Vec` into `Iterator`.
      .iter()
      .any(|state| state.is_reachable(automaton, position))
  }

  pub fn next_states(&self, automaton: &Automaton) -> Vec<Self> {
    automaton
      .edges
      // Convert `Vec` into `Iterator`.
      .iter()
      // Filter only the ones that start at the current position.
      .filter(|edge| edge.lhs == self.position)
      // Create a state that is there. In reality, we'd have to copy
      // the rest of the state properties too.
      .map(|edge| Self { position: edge.rhs })
      // Convert `Iterator` into `Vec`.
      .collect()
  }
}

As you can see, it’s only a couple of lines of code – nothing crazy, nothing “smart” your future self would have to puzzle over. There is, however, a very obvious performance problem – we’re constantly switching back and forth between Vec and Iterator.

In Rust, Iterator is not a concrete type, but a trait – sort of an interface in the TypeScript world. That means there is some object underneath, but all we know about it is that we can iterate over it. We could skip the transformation, but then next_states would have to return a type we couldn’t easily name. Of course, it’s possible, but instead, let’s create a custom object we’ll implement the Iterator trait for.
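As a side note, here is a minimal sketch (my own, not the original code) of what “skipping the transformation” could look like: with `impl Trait` in return position, `next_states` hands back the filtering iterator directly, without naming its concrete adapter type.

```rust
pub type Node = usize;

pub struct Edge {
    lhs: Node,
    rhs: Node,
}

pub struct Automaton {
    edges: Vec<Edge>,
}

#[derive(Default)]
pub struct State {
    position: Node,
}

impl State {
    // The caller only learns "some iterator of `State`s"; the concrete
    // adapter type (a `Map<Filter<...>>`) stays hidden behind `impl Trait`.
    pub fn next_states<'a>(
        &self,
        automaton: &'a Automaton,
    ) -> impl Iterator<Item = State> + 'a {
        // Copy the position so the closure doesn't borrow `self`.
        let position = self.position;
        automaton
            .edges
            .iter()
            .filter(move |edge| edge.lhs == position)
            .map(|edge| State { position: edge.rhs })
    }
}

fn main() {
    let automaton = Automaton {
        edges: vec![Edge { lhs: 0, rhs: 1 }, Edge { lhs: 1, rhs: 2 }],
    };
    let state = State::default();
    let next: Vec<Node> = state.next_states(&automaton).map(|s| s.position).collect();
    assert_eq!(next, vec![1]);
}
```

This avoids the `Vec` round-trip entirely, at the cost of an opaque return type; the custom iterator below achieves the same lazily while keeping a nameable type.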

Custom iterator

To implement a trait, we need something to implement it for. Let’s create one that could be one’s first guess – a struct that stores the list of Edges we have to go through. Having done that, we’ll implement the Iterator trait and update the rest of the code accordingly.

pub struct StateIterator<'a> {
  // List of the `Edge`s we still have to check. Instead of copying
  // them, we store the references, hence the `'a` lifetime here.
  queue: Vec<&'a Edge>,
}

impl Iterator for StateIterator<'_> {
  // It's an iterator of `State`s, and...
  type Item = State;

  // ...to get the next one we...
  fn next(&mut self) -> Option<Self::Item> {
    // ...pop the `Edge` out of the queue.
    if let Some(edge) = self.queue.pop() {
      // If there's one, return a `State` that moved there.
      return Some(Self::Item { position: edge.rhs });
    }

    // If not, there's no next `State`.
    None
  }
}

impl State {
    pub fn is_reachable(&self, automaton: &Automaton, position: Node) -> bool {
      self.position == position ||
      self
        .next_states(automaton)
        // No `.iter()` here!
        .any(|state| state.is_reachable(automaton, position))
    }

    pub fn next_states<'a>(&'a self, automaton: &'a Automaton) -> StateIterator {
      StateIterator {
        queue: automaton
          .edges
          .iter()
          .filter(|edge| edge.lhs == self.position)
          // No `.map()` here!
          .collect(),
      }
    }
}

It’s a very straightforward (and naive) wrapping of our previous implementation in an iterator object. Definitely not a performant one (we’ll get to that later), but it allows us to improve the iterator independently from the rest.

No copying allowed

I said before that the problem with the first implementation was the constant transforming between Vectors and Iterators. It’s still there, as we have to .collect() them, i.e., create new Vectors, when passing them down to the StateIterator object. Can we do better? Instead, we could reference the entire edges list in there and only store the index we’re currently at.

pub struct StateIterator<'a> {
  // Position in the `queue`.
  index: usize,
  // Position in the automaton.
  position: Node,
  // List of _all_ edges of the automaton.
  queue: &'a Vec<Edge>,
}

impl Iterator for StateIterator<'_> {
  type Item = State;

  fn next(&mut self) -> Option<Self::Item> {
    // `.get()` returns the element or `None` if the index is out of
    // the bounds (i.e., larger than the number of `Edge`s).
    while let Some(edge) = self.queue.get(self.index) {
      // Proceed in the `queue`.
      self.index += 1;
      // Return the next state, but only if the `edge` is anchored in
      // the current `position`.
      if edge.lhs == self.position {
        return Some(Self::Item { position: edge.rhs });
      }
    }

    // Once the `queue` is exhausted, we're done.
    None
  }
}

impl State {
  pub fn next_states<'a>(&'a self, automaton: &'a Automaton) -> StateIterator {
    StateIterator {
      // Start at the first `Edge`.
      index: 0,
      // Copy the `position` to know which `Edge`s to use.
      position: self.position,
      // Reference the entire list of `Edge`s.
      queue: &automaton.edges,
    }
  }
}

Is it better? Intuitively yes, because fewer Vectors are created here. And as we know, fewer objects mean fewer operations, which leads to better performance. What’s next?

Slice the queue

Even without checking the implementation of .get, we can guess that indexing the vector from the start at every step involves extra work, like bounds checking and offset arithmetic. Can we somehow eliminate it?

Enter slices! A slice is a view into a contiguous sequence of memory. Its API allows us to peek at the elements and advance the view (i.e., point it at a later element).
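Before applying that to the iterator, here’s a tiny standalone demonstration (my own, for illustration) of how a slice pattern “pops” the front of a view by rebinding it to its tail – no data is copied, only the view moves:

```rust
fn main() {
    // A slice is a borrowed view into the `Vec`'s buffer.
    let numbers = vec![10, 20, 30];
    let mut view: &[i32] = numbers.as_slice();

    // Destructure into the first element and the rest; the pattern
    // fails on an empty slice, which ends the loop.
    let mut popped = Vec::new();
    while let [head, tail @ ..] = view {
        popped.push(*head);
        // Rebind the view to the tail, effectively advancing it.
        view = tail;
    }

    assert_eq!(popped, vec![10, 20, 30]);
    assert!(view.is_empty());
}
```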

pub struct StateIterator<'a> {
  position: Node,
  // A reference to a slice that lives as long as the original `Vec`.
  queue: &'a [Edge],
}

impl Iterator for StateIterator<'_> {
  type Item = State;

  fn next(&mut self) -> Option<Self::Item> {
    // Unpack the slice into the first element and the rest of them.
    while let [edge, tail @ ..] = self.queue {
      // Point the `queue` to the rest of the elements, effectively
      // "popping" the first element.
      self.queue = tail;
      if edge.lhs == self.position {
        return Some(Self::Item { position: edge.rhs });
      }
    }

    None
  }
}

impl State {
  pub fn next_states<'a>(&'a self, automaton: &'a Automaton) -> StateIterator {
    StateIterator {
      position: self.position,
      // Create a `slice` out of the `Vec`.
      queue: automaton.edges.as_slice(),
    }
  }
}

Is it better? We’ll check the performance later, but intuitively it should be, as there’s one field less in the StateIterator struct, and we’re no longer `.get`ing far-away elements but always just the first one.

What else could we do? We still have a problem: if there are hundreds of thousands of Edges there, every iterator has to go through all of them. That could be eliminated by storing them differently. Like, in a…

Tree!

As I said at the beginning, the Automaton is built once, and then we iterate over it a lot. If we make the next_states faster at the cost of the add_edge being slower, it’s most likely going to be a huge win.

Let’s do that then – Automaton should store the Edges in a way that’ll allow us to get the outgoing ones fast. Let’s use a BTreeMap, i.e., a tree-based key-value store [1]. The key will be the lhs of an Edge, and the value will be a Vector of the outgoing Edges.

use std::collections::BTreeMap;

#[derive(Default)]
pub struct Automaton {
  // `BTreeMap` receives both key and value types.
  edges: BTreeMap<Node, Vec<Edge>>,
}

impl Automaton {
  pub fn add_edge(&mut self, edge: Edge) {
    self.edges
      // Get the corresponding `Vec`...
      .entry(edge.lhs)
      // ...if it's not there, create it...
      .or_insert_with(Vec::default)
      // ...and add the `Edge` to it.
      .push(edge)
  }
}

pub struct StateIterator<'a> {
  // No `position` here! The `queue` is already filtered.
  queue: &'a [Edge],
}

impl Iterator for StateIterator<'_> {
  type Item = State;

  fn next(&mut self) -> Option<Self::Item> {
    if let [edge, tail @ ..] = self.queue {
      self.queue = tail;
      // No `if` here!
      return Some(Self::Item { position: edge.rhs });
    }

    None
  }
}

impl State {
  pub fn next_states<'a>(&'a self, automaton: &'a Automaton) -> StateIterator {
    StateIterator {
      queue: automaton
        .edges
        // Get the corresponding `Vec`...
        .get(&self.position)
        // ...and slice it. If it's not there, use an empty one.
        .map_or(&[], Vec::as_slice),
    }
  }
}

Intuitively, it should be much faster, as the creation of the StateIterator scales with the number of unique Nodes and not Edges. Additionally, the iteration itself is now direct, i.e., there’s no need to check if the lhs equals our position. That should be fast… I think.

Benchmark

Finally, we’ve arrived at what seems to be a sensible and performant approach. Of course, we’re responsible people, so instead of relying solely on a gut feeling, we’ll create a reproducible benchmarking suite to compare the implementations. To do that, we’ll use the Criterion package [2].

fn criterion_benchmark(criterion: &mut Criterion) {
  // Create a macro to create the same benchmark for all versions.
  // If you know a better way to do it, please let me know!
  macro_rules! bench {
    ($version: expr, $mod: ident) => {
      // All automatons have the same "depth" (i.e., the number of
      // states from the root to the furthest one), but they differ
      // in the branching factor, i.e., the average number of next
      // states from each state.
      for branching in 1..=5 {
        criterion.bench_with_input(
          BenchmarkId::new($version, branching),
          &branching,
          |bencher, branching| {
            let mut automaton = $mod::Automaton::default();
            for index in 0..=50 {
              // Artificially increase the branching factor by
              // creating edges not only to the next state, but
              // also to the following ones. As I said, that's
              // more or less what real-life data looks like.
              for offset in 1..=*branching {
                automaton.add_edge(Edge {
                  lhs: index,
                  rhs: index + offset,
                });
              }
            }

            // One root state for all iterations.
            let state = $mod::State::default();

            bencher.iter(|| {
              // Try different depths but aggregate their timings.
              black_box(state.is_reachable(&automaton, black_box(0)));
              black_box(state.is_reachable(&automaton, black_box(25)));
              black_box(state.is_reachable(&automaton, black_box(50)));
            })
          },
        );
      }
    };
  }

  // Register the benchmarks. In reality, I wrapped all of the versions
  // into separate modules just to be able to fit them into one file.
  bench!("v1", v1);
  bench!("v2", v2);
  bench!("v3", v3);
  bench!("v4", v4);
  bench!("v5", v5);
}

// Configure Criterion.
criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);

Having done that, we have to add the dependency and basic configuration to our Cargo.toml, and we’re good to go. Let’s run cargo bench, chart our results [3], and see how each version performs on my M2 MacBook Air.
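For reference, the relevant Cargo.toml entries might look roughly like this (a sketch – the Criterion version and the benchmark file name are my assumptions, not taken from the original project):

```toml
# Criterion is only needed for benchmarks, not the library itself.
[dev-dependencies]
criterion = "0.4"

# Disable the default libtest harness so Criterion can provide its
# own `main` (via `criterion_main!`).
[[bench]]
name = "reachability"  # expects the file benches/reachability.rs
harness = false
```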

(Chart: time [s] against branching factor 1–5, for versions v1–v5.)

Well… That’s not something you expected, huh? If you think about it, the sheer number of copies makes the creation of version 2’s iterators extremely slow. Actually, it was so slow that I lost patience while waiting for it to finish and decided to skip it entirely. Let’s remove all but the first point of version 2 and see the chart again.

(Chart: time [µs] against branching factor 1–5, for versions v1–v5, with version 2 shown only for branching factor 1.)

Much better! First of all, we see that versions 1 and 2 perform similarly when the branching factor is 1. That makes perfect sense, as next_states always creates a vector of length 1, and is_reachable immediately consumes it.

Secondly, versions 3 and 4 are nearly identical in terms of performance. Advancing a slice is as simple as moving a pointer, so version 3 tends to be marginally slower, as it has to index the vector from the start at every step instead (and it’s visible on the chart).

Lastly, version 5 outperforms all of them. That’s because the cost of navigation in such a small tree is negligible (the biggest automaton I’ve worked with so far had a few thousand states). As a result, the performance of creating an iterator scales with the branching factor and not the number of edges [4].

Table with the complete results.
| Version | Branching | Min     | Avg     | Max     |
|--------:|----------:|--------:|--------:|--------:|
| 1       | 1         | 2.627µs | 2.629µs | 2.631µs |
| 2       | 1         | 2.656µs | 2.659µs | 2.660µs |
| 3       | 1         | 881.4ns | 882.7ns | 884.0ns |
| 4       | 1         | 830.8ns | 831.9ns | 833.1ns |
| 5       | 1         | 618.3ns | 618.7ns | 619.1ns |
| 1       | 2         | 4.839µs | 4.851µs | 4.863µs |
| 2       | 2         | 20.63ms | 20.66ms | 20.70ms |
| 3       | 2         | 1.488µs | 1.490µs | 1.492µs |
| 4       | 2         | 1.395µs | 1.395µs | 1.396µs |
| 5       | 2         | 615.6ns | 616.6ns | 617.5ns |
| 1       | 3         | 5.895µs | 5.912µs | 5.930µs |
| 2       | 3         | 949.4ms | 949.9ms | 950.5ms |
| 3       | 3         | 2.041µs | 2.048µs | 2.056µs |
| 4       | 3         | 2.033µs | 2.035µs | 2.038µs |
| 5       | 3         | 612.9ns | 614.1ns | 615.4ns |
| 1       | 4         | 7.216µs | 7.244µs | 7.274µs |
| 2       | 4         | 4.8150s | 4.8300s | 4.8580s |
| 3       | 4         | 2.553µs | 2.555µs | 2.556µs |
| 4       | 4         | 2.491µs | 2.493µs | 2.494µs |
| 5       | 4         | 611.6ns | 612.7ns | 613.9ns |
| 1       | 5         | 10.68µs | 10.71µs | 10.75µs |
| 2       | 5         | ???     | ???     | ???     |
| 3       | 5         | 3.084µs | 3.087µs | 3.090µs |
| 4       | 5         | 2.990µs | 2.991µs | 2.992µs |
| 5       | 5         | 610.2ns | 611.0ns | 611.9ns |
Entire source code, including the benchmark.

Please note that it’s formatted slightly differently, according to the standard cargo fmt rules. It also passes the cargo clippy linter. Rust v1.65.0.

use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};

pub type Node = usize;

pub struct Edge {
    lhs: Node,
    rhs: Node,
}

mod v1 {
    use crate::{Edge, Node};

    #[derive(Default)]
    pub struct Automaton {
        edges: Vec<Edge>,
    }

    impl Automaton {
        pub fn add_edge(&mut self, edge: Edge) {
            self.edges.push(edge)
        }
    }

    #[derive(Default)]
    pub struct State {
        position: Node,
    }

    impl State {
        pub fn is_reachable(&self, automaton: &Automaton, position: Node) -> bool {
            self.position == position
                || self
                    .next_states(automaton)
                    .iter()
                    .any(|state| state.is_reachable(automaton, position))
        }

        pub fn next_states(&self, automaton: &Automaton) -> Vec<Self> {
            automaton
                .edges
                .iter()
                .filter(|edge| edge.lhs == self.position)
                .map(|edge| Self { position: edge.rhs })
                .collect()
        }
    }
}

mod v2 {
    use crate::{Edge, Node};

    #[derive(Default)]
    pub struct Automaton {
        edges: Vec<Edge>,
    }

    impl Automaton {
        pub fn add_edge(&mut self, edge: Edge) {
            self.edges.push(edge)
        }
    }

    #[derive(Default)]
    pub struct State {
        position: Node,
    }

    impl State {
        pub fn is_reachable(&self, automaton: &Automaton, position: Node) -> bool {
            self.position == position
                || self
                    .next_states(automaton)
                    .any(|state| state.is_reachable(automaton, position))
        }

        pub fn next_states<'a>(&'a self, automaton: &'a Automaton) -> StateIterator {
            StateIterator {
                queue: automaton
                    .edges
                    .iter()
                    .filter(|edge| edge.lhs == self.position)
                    .collect(),
            }
        }
    }

    pub struct StateIterator<'a> {
        queue: Vec<&'a Edge>,
    }

    impl Iterator for StateIterator<'_> {
        type Item = State;

        fn next(&mut self) -> Option<Self::Item> {
            if let Some(edge) = self.queue.pop() {
                return Some(Self::Item { position: edge.rhs });
            }

            None
        }
    }
}

mod v3 {
    use crate::{Edge, Node};

    #[derive(Default)]
    pub struct Automaton {
        edges: Vec<Edge>,
    }

    impl Automaton {
        pub fn add_edge(&mut self, edge: Edge) {
            self.edges.push(edge)
        }
    }

    #[derive(Default)]
    pub struct State {
        position: Node,
    }

    impl State {
        pub fn is_reachable(&self, automaton: &Automaton, position: Node) -> bool {
            self.position == position
                || self
                    .next_states(automaton)
                    .any(|state| state.is_reachable(automaton, position))
        }

        pub fn next_states<'a>(&'a self, automaton: &'a Automaton) -> StateIterator {
            StateIterator {
                index: 0,
                position: self.position,
                queue: &automaton.edges,
            }
        }
    }

    pub struct StateIterator<'a> {
        index: usize,
        position: Node,
        queue: &'a Vec<Edge>,
    }

    impl Iterator for StateIterator<'_> {
        type Item = State;

        fn next(&mut self) -> Option<Self::Item> {
            while let Some(edge) = self.queue.get(self.index) {
                self.index += 1;
                if edge.lhs == self.position {
                    return Some(Self::Item { position: edge.rhs });
                }
            }

            None
        }
    }
}

mod v4 {
    use crate::{Edge, Node};

    #[derive(Default)]
    pub struct Automaton {
        edges: Vec<Edge>,
    }

    impl Automaton {
        pub fn add_edge(&mut self, edge: Edge) {
            self.edges.push(edge)
        }
    }

    #[derive(Default)]
    pub struct State {
        position: Node,
    }

    impl State {
        pub fn is_reachable(&self, automaton: &Automaton, position: Node) -> bool {
            self.position == position
                || self
                    .next_states(automaton)
                    .any(|state| state.is_reachable(automaton, position))
        }

        pub fn next_states<'a>(&'a self, automaton: &'a Automaton) -> StateIterator {
            StateIterator {
                position: self.position,
                queue: automaton.edges.as_slice(),
            }
        }
    }

    pub struct StateIterator<'a> {
        position: Node,
        queue: &'a [Edge],
    }

    impl Iterator for StateIterator<'_> {
        type Item = State;

        fn next(&mut self) -> Option<Self::Item> {
            while let [edge, tail @ ..] = self.queue {
                self.queue = tail;
                if edge.lhs == self.position {
                    return Some(Self::Item { position: edge.rhs });
                }
            }

            None
        }
    }
}

mod v5 {
    use crate::{Edge, Node};
    use std::collections::BTreeMap;

    #[derive(Default)]
    pub struct Automaton {
        edges: BTreeMap<Node, Vec<Edge>>,
    }

    impl Automaton {
        pub fn add_edge(&mut self, edge: Edge) {
            self.edges
                .entry(edge.lhs)
                .or_insert_with(Vec::default)
                .push(edge)
        }
    }

    #[derive(Default)]
    pub struct State {
        position: Node,
    }

    impl State {
        pub fn is_reachable(&self, automaton: &Automaton, position: Node) -> bool {
            self.position == position
                || self
                    .next_states(automaton)
                    .any(|state| state.is_reachable(automaton, position))
        }

        pub fn next_states<'a>(&'a self, automaton: &'a Automaton) -> StateIterator {
            StateIterator {
                queue: automaton
                    .edges
                    .get(&self.position)
                    .map_or(&[], Vec::as_slice),
            }
        }
    }

    pub struct StateIterator<'a> {
        queue: &'a [Edge],
    }

    impl Iterator for StateIterator<'_> {
        type Item = State;

        fn next(&mut self) -> Option<Self::Item> {
            if let [edge, tail @ ..] = self.queue {
                self.queue = tail;
                return Some(Self::Item { position: edge.rhs });
            }

            None
        }
    }
}

fn criterion_benchmark(criterion: &mut Criterion) {
    macro_rules! bench {
        ($version: expr, $mod: ident) => {
            for branching in 1..=5 {
                criterion.bench_with_input(
                    BenchmarkId::new($version, branching),
                    &branching,
                    |bencher, branching| {
                        let mut automaton = $mod::Automaton::default();
                        for index in 0..=50 {
                            for offset in 1..=*branching {
                                automaton.add_edge(Edge {
                                    lhs: index,
                                    rhs: index + offset,
                                });
                            }
                        }

                        let state = $mod::State::default();

                        bencher.iter(|| {
                            black_box(state.is_reachable(&automaton, black_box(0)));
                            black_box(state.is_reachable(&automaton, black_box(25)));
                            black_box(state.is_reachable(&automaton, black_box(50)));
                        })
                    },
                );
            }
        };
    }

    bench!("v1", v1);
    bench!("v2", v2);
    bench!("v3", v3);
    bench!("v4", v4);
    bench!("v5", v5);
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);

Closing thoughts

That’s a wrap. This was extremely interesting and educational for me. I hadn’t used Rust’s slices that much before, and here they really shined. The final performance is more than enough, and I’m happy with it for now.

If you’re interested, the complete interpreter in Rust is between 5 and 20 times faster than the JavaScript one. There are definitely more improvements to come, but for now, the performance is fine.

Now, back to work!

[1] We could also use HashMap. It doesn’t matter for us API-wise, but it turned out to be slower, at least in my small-scale tests. I guess it’d be beneficial if the number of Nodes were much higher or the Node itself was bigger.

[2] I’d rather use the built-in #[bench] attribute, but it remains unstable. On the other hand, Criterion gives far more detailed feedback and provides tools to measure how code changes impact the benchmarks.

[3] I wanted to style the charts differently, as well as inline them into the post, so I used Vega instead of Criterion’s charts. It’s a powerful tool that allows you to “code” your charts, and I strongly recommend it!

[4] As you can see, the time actually goes down as the branching factor grows, which is at least weird, but I was able to reproduce it consistently. At first, I thought it was because the shortest path gets shorter, but that shouldn’t matter, as the iteration is depth-first and always finds the longest path.