Program synthesis is the act of automatically constructing a program that fulfills a given specification. Perhaps you are interested in sketching a program, leaving parts of it incomplete, and then having a tool fill in those missing bits for you? Or perhaps you are a compiler, and you have some instruction sequence, but you want to find an equivalent-but-better instruction sequence? Program synthesizers promise to help you out in these situations!

I recently stumbled across Adrian Sampson’s Program Synthesis is Possible blog post. Adrian describes and implements minisynth, a toy program synthesizer that generates constants for holes in a template program when given a specification. What fun! As a way to learn more about program synthesis myself, I ported minisynth to Rust.

The Language

The input language is quite simple. The only type is the signed integer and our operations are addition, subtraction, multiplication, division, negation, left- and right-shift, and if-then-else conditionals.

Here is an example:

x * 10 + y

And here is a conditional expression that evaluates to 27 if x is non-zero, and 42 otherwise:

x ? 27 : 42

Abstract Syntax Tree

My representation of the AST uses an id-based arena and interns identifier strings. This is a bit of overkill for such a small program, but it is a pattern that has worked well for me in Rust. It also makes implementing the petgraph crate’s traits easy, which gets you the graph traversals, dominator algorithms, and other analyses that a non-toy implementation will eventually want.

The ast::Context structure contains the arena of AST nodes and the interned strings.

// src/ast.rs

use std::collections::HashMap;

use id_arena::{Arena, Id};

pub type StringId = Id<String>;

#[derive(Default)]
pub struct Context {
    idents: Arena<String>,
    already_interned: HashMap<String, StringId>,
    nodes: Arena<Node>,
}

The ast::Node definition is an enum with a variant for each type of expression in the language.

// src/ast.rs

pub type NodeId = Id<Node>;

pub enum Node {
    Identifier(StringId),
    Addition(NodeId, NodeId),
    Subtraction(NodeId, NodeId),
    Multiplication(NodeId, NodeId),
    Division(NodeId, NodeId),
    RightShift(NodeId, NodeId),
    LeftShift(NodeId, NodeId),
    Const(i64),
    Negation(NodeId),
    Conditional(NodeId, NodeId, NodeId),
}

The ast::Context also has methods to allocate new ast::Nodes, get interned strings, and access allocated nodes. These definitions are straightforward, so I have elided them here. If you’re interested, you can look at the source on GitHub.
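To give a feel for the shape of those elided helpers, here is a self-contained miniature of the pattern, using plain Vec indices in place of id_arena’s typed ids. The method names match how they are used in the rest of this post, but this is a sketch, not the exact source from GitHub:

```rust
use std::collections::HashMap;

// Plain indices standing in for id_arena's `Id<T>`.
type StringId = usize;
type NodeId = usize;

enum Node {
    Identifier(StringId),
    Const(i64),
    Addition(NodeId, NodeId),
    Multiplication(NodeId, NodeId),
}

#[derive(Default)]
struct Context {
    idents: Vec<String>,
    already_interned: HashMap<String, StringId>,
    nodes: Vec<Node>,
}

impl Context {
    // Intern a string, reusing the existing id if we've seen it before.
    fn intern(&mut self, s: &str) -> StringId {
        if let Some(&id) = self.already_interned.get(s) {
            return id;
        }
        let id = self.idents.len();
        self.idents.push(s.to_string());
        self.already_interned.insert(s.to_string(), id);
        id
    }

    // Allocate an identifier node whose string is interned.
    fn new_identifier(&mut self, s: &str) -> NodeId {
        let s = self.intern(s);
        self.new_node(Node::Identifier(s))
    }

    // Allocate any node in the arena and hand back its id.
    fn new_node(&mut self, node: Node) -> NodeId {
        self.nodes.push(node);
        self.nodes.len() - 1
    }

    // Resolve an interned string id back to its string.
    fn interned(&self, id: StringId) -> &str {
        &self.idents[id]
    }

    // Resolve a node id to a reference to its node.
    fn node_ref(&self, id: NodeId) -> &Node {
        &self.nodes[id]
    }
}
```

Because identifiers are interned, looking up the same name twice yields the same StringId, so string comparisons elsewhere become cheap id comparisons.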

Parsing

I use the wonderful lalrpop parser generator to generate a parser for the input language. The grammar and actions are given in full below:

// src/parser/grammar.lalrpop

use crate::ast;
use std::str::FromStr;

grammar(ctx: &mut ast::Context);

Integer: i64 = <s:r"[0-9]+"> => i64::from_str(s).unwrap();

Identifier: ast::NodeId = <s:r"[a-zA-Z][a-zA-Z0-9_]*"> => ctx.new_identifier(s);

Sum: ast::NodeId = {
    <t:Term> => t,
    <l:Sum> "+" <r:Term> => ctx.new_node(ast::Node::Addition(l, r)),
    <l:Sum> "-" <r:Term> => ctx.new_node(ast::Node::Subtraction(l, r)),
};

Term: ast::NodeId = {
    <i:Item> => i,
    <l:Term> "*" <r:Item> => ctx.new_node(ast::Node::Multiplication(l, r)),
    <l:Term> "/" <r:Item> => ctx.new_node(ast::Node::Division(l, r)),
    <l:Term> ">>" <r:Item> => ctx.new_node(ast::Node::RightShift(l, r)),
    <l:Term> "<<" <r:Item> => ctx.new_node(ast::Node::LeftShift(l, r)),
};

Item: ast::NodeId = {
    <n:Integer> => ctx.new_node(ast::Node::Const(n)),
    "-" <i:Item> => ctx.new_node(ast::Node::Negation(i)),
    <i:Identifier> => i,
    "(" <s:Start> ")" => s,
};

pub Start: ast::NodeId = {
    <s:Sum> => s,
    <condition:Sum> "?" <consequent:Sum> ":" <alternative:Sum> =>
        ctx.new_node(ast::Node::Conditional(condition, consequent, alternative)),
};

Interpretation

To interpret expressions, we need a lookup function that maps identifiers to values, the ast::Context that owns the AST nodes, and the id of the node we are evaluating. We match on this node, and handle the following cases:

  • If the node represents a constant, we return that node’s associated constant value.

  • If the node represents an identifier, we get a reference to its interned identifier string from the context, and then query the lookup function for its value.

  • If the node represents an operator, we recursively evaluate its operands and then apply the operator to the operands’ values. For division, we also check for divide-by-zero and return an error.

Here is our initial interpreter function:

// src/eval.rs

pub fn eval<L>(
    ctx: &mut ast::Context,
    node: ast::NodeId,
    lookup: &mut L
) -> Result<i64>
where
    L: for<'a> FnMut(&'a str) -> Result<i64>,
{
    match *ctx.node_ref(node) {
        Node::Const(i) => Ok(i),
        Node::Identifier(s) => {
            let s = ctx.interned(s);
            lookup(s)
        }
        Node::Addition(lhs, rhs) => {
            let lhs = eval(ctx, lhs, lookup)?;
            let rhs = eval(ctx, rhs, lookup)?;
            Ok(lhs + rhs)
        }
        Node::Subtraction(lhs, rhs) => {
            let lhs = eval(ctx, lhs, lookup)?;
            let rhs = eval(ctx, rhs, lookup)?;
            Ok(lhs - rhs)
        }
        Node::Multiplication(lhs, rhs) => {
            let lhs = eval(ctx, lhs, lookup)?;
            let rhs = eval(ctx, rhs, lookup)?;
            Ok(lhs * rhs)
        }
        Node::Division(lhs, rhs) => {
            let lhs = eval(ctx, lhs, lookup)?;
            let rhs = eval(ctx, rhs, lookup)?;
            if rhs == 0 {
                bail!("divide by zero");
            }
            Ok(lhs / rhs)
        }
        Node::RightShift(lhs, rhs) => {
            let lhs = eval(ctx, lhs, lookup)?;
            let rhs = eval(ctx, rhs, lookup)?;
            Ok(lhs >> rhs)
        }
        Node::LeftShift(lhs, rhs) => {
            let lhs = eval(ctx, lhs, lookup)?;
            let rhs = eval(ctx, rhs, lookup)?;
            Ok(lhs << rhs)
        }
        Node::Negation(n) => {
            let n = eval(ctx, n, lookup)?;
            Ok(-n)
        }
        Node::Conditional(condition, consequent, alternative) => {
            let condition = eval(ctx, condition, lookup)?;
            let consequent = eval(ctx, consequent, lookup)?;
            let alternative = eval(ctx, alternative, lookup)?;
            Ok(if condition != 0 {
                consequent
            } else {
                alternative
            })
        }
    }
}

From Interpreter to Synthesizer

Our synthesizer will take a specification program and a template program. The template program may contain holes — in our system, these are variables that start with the letter “h”. Our goal is to synthesize constant values for these holes such that the template program implements the specification for all values of the non-hole variables.

Let’s consider the example from the original blog post:

// Specification:
x * 10

// Template:
(x << h1) + (x << h2)

Can we transform multiplication by ten into the sum of two constant left shifts? If we can find constant values for h1 and h2, then the answer is yes. Our synthesizer should answer that either h1 = 1 and h2 = 3, or that h1 = 3 and h2 = 1.
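We can spot-check that solution by hand before involving a solver: with h1 = 1 and h2 = 3 the template computes 2x + 8x = 10x. In plain Rust (this check is just for illustration; it is not part of the synthesizer):

```rust
// The template `(x << h1) + (x << h2)` with concrete hole values
// plugged in.
fn template(x: i64, h1: i64, h2: i64) -> i64 {
    (x << h1) + (x << h2)
}
```

Both solutions agree with `x * 10` on every input, including negative ones, since `<<` on i64 is multiplication by a power of two (barring overflow).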

To implement synthesis, we will walk the AST and generate constraints for the Z3 SMT solver that reflect the program’s semantics. We will do this for both the specification and the template, and then constrain their results to be equal for all values of the non-hole variables. Finally, we ask Z3 if it can find a solution to all of these constraints. Any solution that exists provides definitions for the holes.

That is, we are asking the solver to find a solution for

∃h₁, h₂, …, hₘ. ∀c₁, c₂, …, cₙ. t = s

where t is the template’s constraints, s is the specification’s constraints, the hᵢ are the holes in the template, and the cⱼ are the non-hole variables appearing in the template and specification programs. For the example above, this is ∃h₁, h₂. ∀x. (x << h₁) + (x << h₂) = x · 10.

Adrian’s original Python implementation of minisynth takes advantage of Python’s dynamic nature and the Z3 Python library’s operator overloading to reuse the interpreter for constraint generation without any changes to the interpreter function. All you have to do is supply a lookup function that returns Z3 bitvector variables instead of signed integers. A neat trick!

For our Rust implementation, we want to reuse the interpreter as well, but Rust is statically typed and the z3 crate for Rust doesn’t implement operator overloading. So we will factor out an interpret function from our eval function that is generic over some abstract interpreter.

An abstract interpreter must have an associated output type: for normal evaluation this will be i64, and for constraint generation it will be a Z3 constraint. The abstract interpreter must have a method for each operation of the input language, taking operands of its output type and returning a result of its output type. It must also provide a way to translate constants and identifiers into its output type.

// src/abstract_interpret.rs

pub trait AbstractInterpret {
    /// The output type of this interpreter.
    type Output;

    /// Create a constant output value.
    fn constant(&mut self, c: i64) -> Self::Output;

    /// `lhs + rhs`
    fn add(&mut self, lhs: &Self::Output, rhs: &Self::Output) -> Self::Output;

    /// `lhs - rhs`
    fn sub(&mut self, lhs: &Self::Output, rhs: &Self::Output) -> Self::Output;

    /// `lhs * rhs`
    fn mul(&mut self, lhs: &Self::Output, rhs: &Self::Output) -> Self::Output;

    /// `lhs / rhs`. Fails on divide by zero.
    fn div(&mut self, lhs: &Self::Output, rhs: &Self::Output) -> Result<Self::Output>;

    /// `lhs >> rhs`
    fn shr(&mut self, lhs: &Self::Output, rhs: &Self::Output) -> Self::Output;

    /// `lhs << rhs`
    fn shl(&mut self, lhs: &Self::Output, rhs: &Self::Output) -> Self::Output;

    /// `-e`
    fn neg(&mut self, e: &Self::Output) -> Self::Output;

    /// Returns `1` if `lhs == rhs`, returns `0` otherwise.
    fn eq(&mut self, lhs: &Self::Output, rhs: &Self::Output) -> Self::Output;

    /// Returns `1` if `lhs != rhs`, returns `0` otherwise.
    fn neq(&mut self, lhs: &Self::Output, rhs: &Self::Output) -> Self::Output;

    /// Perform variable lookup for the identifier `var`.
    fn lookup(&mut self, var: &str) -> Result<Self::Output>;
}

Next we make an interpretation function that takes an abstract interpreter and uses it to interpret an expression of our input language. This looks almost the same as our original eval function, but there is one tricky bit: encoding the conditional’s semantics in interpreter methods without using Rust’s control flow, which would be invisible to the solver. To do this, we multiply the activated conditional arm by one and the deactivated arm by zero, and then sum the products. An alternative approach would be to add a method for interpreting conditionals directly to the AbstractInterpret trait.

// src/abstract_interpret.rs

pub fn interpret<A>(
    interpreter: &mut A,
    ctx: &mut ast::Context,
    node: ast::NodeId,
) -> Result<A::Output>
where
    A: AbstractInterpret,
{
    match *ctx.node_ref(node) {
        Node::Const(i) => Ok(interpreter.constant(i)),
        Node::Identifier(s) => {
            let s = ctx.interned(s);
            interpreter.lookup(s)
        }
        Node::Addition(lhs, rhs) => {
            let lhs = interpret(interpreter, ctx, lhs)?;
            let rhs = interpret(interpreter, ctx, rhs)?;
            Ok(interpreter.add(&lhs, &rhs))
        }
        Node::Subtraction(lhs, rhs) => {
            let lhs = interpret(interpreter, ctx, lhs)?;
            let rhs = interpret(interpreter, ctx, rhs)?;
            Ok(interpreter.sub(&lhs, &rhs))
        }
        Node::Multiplication(lhs, rhs) => {
            let lhs = interpret(interpreter, ctx, lhs)?;
            let rhs = interpret(interpreter, ctx, rhs)?;
            Ok(interpreter.mul(&lhs, &rhs))
        }
        Node::Division(lhs, rhs) => {
            let lhs = interpret(interpreter, ctx, lhs)?;
            let rhs = interpret(interpreter, ctx, rhs)?;
            interpreter.div(&lhs, &rhs)
        }
        Node::RightShift(lhs, rhs) => {
            let lhs = interpret(interpreter, ctx, lhs)?;
            let rhs = interpret(interpreter, ctx, rhs)?;
            Ok(interpreter.shr(&lhs, &rhs))
        }
        Node::LeftShift(lhs, rhs) => {
            let lhs = interpret(interpreter, ctx, lhs)?;
            let rhs = interpret(interpreter, ctx, rhs)?;
            Ok(interpreter.shl(&lhs, &rhs))
        }
        Node::Negation(e) => {
            let e = interpret(interpreter, ctx, e)?;
            Ok(interpreter.neg(&e))
        }
        Node::Conditional(condition, consequent, alternative) => {
            let condition = interpret(interpreter, ctx, condition)?;
            let consequent = interpret(interpreter, ctx, consequent)?;
            let alternative = interpret(interpreter, ctx, alternative)?;

            let zero = interpreter.constant(0);
            let neq_zero = interpreter.neq(&condition, &zero);
            let eq_zero = interpreter.eq(&condition, &zero);

            let consequent = interpreter.mul(&neq_zero, &consequent);
            let alternative = interpreter.mul(&eq_zero, &alternative);

            Ok(interpreter.add(&consequent, &alternative))
        }
    }
}
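The arithmetic encoding of conditionals is worth a quick sanity check in ordinary integers: multiplying the taken arm by one and the untaken arm by zero, then summing, always agrees with a real branch. A standalone check, for illustration only:

```rust
// Branch-free version of `condition ? consequent : alternative`,
// mirroring the constraints that `interpret` builds for
// Node::Conditional.
fn encoded(condition: i64, consequent: i64, alternative: i64) -> i64 {
    let neq_zero = (condition != 0) as i64; // 1 when the condition holds, else 0
    let eq_zero = (condition == 0) as i64;  // 0 when the condition holds, else 1
    neq_zero * consequent + eq_zero * alternative
}
```

Note that both arms are always evaluated, which is exactly what we want for constraint generation: the solver sees the semantics of both branches.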

We refactor eval in terms of an implementation of AbstractInterpret whose associated Output type is i64 and which evaluates expressions directly:

// src/eval.rs

struct Eval<'a> {
    env: &'a HashMap<String, i64>,
}

impl<'a> AbstractInterpret for Eval<'a> {
    type Output = i64;

    fn constant(&mut self, c: i64) -> i64 { c }

    fn lookup(&mut self, var: &str) -> Result<i64> {
        self.env
            .get(var)
            .cloned()
            .ok_or_else(|| format_err!("undefined variable: {}", var))
    }

    fn neg(&mut self, e: &i64) -> i64 { -e }
    fn add(&mut self, lhs: &i64, rhs: &i64) -> i64 { lhs + rhs }
    fn sub(&mut self, lhs: &i64, rhs: &i64) -> i64 { lhs - rhs }
    fn mul(&mut self, lhs: &i64, rhs: &i64) -> i64 { lhs * rhs }
    fn shr(&mut self, lhs: &i64, rhs: &i64) -> i64 { lhs >> rhs }
    fn shl(&mut self, lhs: &i64, rhs: &i64) -> i64 { lhs << rhs }
    fn div(&mut self, lhs: &i64, rhs: &i64) -> Result<i64> {
        if *rhs == 0 {
            bail!("divide by zero");
        }
        Ok(lhs / rhs)
    }

    fn eq(&mut self, lhs: &i64, rhs: &i64) -> i64 {
        (lhs == rhs) as i64
    }
    fn neq(&mut self, lhs: &i64, rhs: &i64) -> i64 {
        (lhs != rhs) as i64
    }
}

pub fn eval(
    ctx: &mut ast::Context,
    node: NodeId,
    env: &HashMap<String, i64>
) -> Result<i64> {
    let eval = &mut Eval { env };
    interpret(eval, ctx, node)
}

Finally, we are ready to start implementing synthesis!

First we create an implementation of AbstractInterpret that builds up Z3 constraints. Its lookup method keeps track of which variables have been used, categorizes them by whether they are a hole or an unknown constant, and makes sure that subsequent lookups of the same identifier return the same Z3 variable. All other methods map straightforwardly onto Z3 method calls.

// src/synthesize.rs

struct Synthesize<'a, 'ctx>
where
    'ctx: 'a,
{
    ctx: &'ctx z3::Context,
    vars: &'a mut HashMap<String, z3::Ast<'ctx>>,
    holes: &'a mut HashMap<z3::Ast<'ctx>, String>,
    const_vars: &'a mut HashSet<z3::Ast<'ctx>>,
}

impl<'a, 'ctx> AbstractInterpret for Synthesize<'a, 'ctx> {
    type Output = z3::Ast<'ctx>;
    fn lookup(&mut self, var: &str) -> Result<z3::Ast<'ctx>> {
        if !self.vars.contains_key(var) {
            let c = self.ctx.fresh_bitvector_const(var, 64);
            self.vars.insert(var.to_string(), c.clone());
            if var.starts_with("h") {
                self.holes.insert(c, var.to_string());
            } else {
                self.const_vars.insert(c);
            }
        }

        Ok(self.vars[var].clone())
    }
    fn constant(&mut self, c: i64) -> z3::Ast<'ctx> {
        z3::Ast::bitvector_from_i64(self.ctx, c, 64)
    }
    fn add(&mut self, lhs: &z3::Ast<'ctx>, rhs: &z3::Ast<'ctx>) -> z3::Ast<'ctx> {
        lhs.bvadd(rhs)
    }
    fn sub(&mut self, lhs: &z3::Ast<'ctx>, rhs: &z3::Ast<'ctx>) -> z3::Ast<'ctx> {
        lhs.bvsub(rhs)
    }
    fn mul(&mut self, lhs: &z3::Ast<'ctx>, rhs: &z3::Ast<'ctx>) -> z3::Ast<'ctx> {
        lhs.bvmul(rhs)
    }
    fn div(&mut self, lhs: &z3::Ast<'ctx>, rhs: &z3::Ast<'ctx>) -> Result<z3::Ast<'ctx>> {
        Ok(lhs.bvsdiv(rhs))
    }
    fn shr(&mut self, lhs: &z3::Ast<'ctx>, rhs: &z3::Ast<'ctx>) -> z3::Ast<'ctx> {
        lhs.bvlshr(rhs)
    }
    fn shl(&mut self, lhs: &z3::Ast<'ctx>, rhs: &z3::Ast<'ctx>) -> z3::Ast<'ctx> {
        lhs.bvshl(rhs)
    }
    fn neg(&mut self, e: &z3::Ast<'ctx>) -> z3::Ast<'ctx> {
        e.bvneg()
    }
    fn eq(&mut self, lhs: &z3::Ast<'ctx>, rhs: &z3::Ast<'ctx>) -> z3::Ast<'ctx> {
        lhs._eq(rhs).ite(&self.constant(1), &self.constant(0))
    }
    fn neq(&mut self, lhs: &z3::Ast<'ctx>, rhs: &z3::Ast<'ctx>) -> z3::Ast<'ctx> {
        lhs._eq(rhs).not().ite(&self.constant(1), &self.constant(0))
    }
}

Our synthesis function will take the specification program and the template program, and then use the Synthesize abstract interpreter to generate constraints for each of them.

// src/synthesize.rs

pub fn synthesize<'a>(
    z3_ctx: &'a z3::Context,
    ast_ctx: &mut ast::Context,
    specification: NodeId,
    template: NodeId,
) -> Result<HashMap<String, i64>> {
    let mut vars = HashMap::new();
    let mut holes = HashMap::new();
    let mut const_vars = HashSet::new();

    let synth = &mut Synthesize {
        ctx: z3_ctx,
        vars: &mut vars,
        holes: &mut holes,
        const_vars: &mut const_vars,
    };

    let specification = interpret(synth, ast_ctx, specification)?;
    if !synth.holes.is_empty() {
        bail!("the specification cannot have any holes!");
    }
    let template = interpret(synth, ast_ctx, template)?;

    // ...
}

Next, we extract the constant variables and create our goal, which is a for-all constraint. The template must be equal to the specification for all possible values these constants could take.

let const_vars: Vec<_> = const_vars.iter().collect();
let templ_eq_spec = specification._eq(&template);
let goal = if const_vars.is_empty() {
    templ_eq_spec
} else {
    z3::Ast::forall_const(&const_vars, &templ_eq_spec)
};

Now that we have constructed our goal, we ask Z3 to solve it. If it finds an answer, we extract the value it assigned to each of the holes and return the results as a hash map.

let solver = z3::Solver::new(z3_ctx);
solver.assert(&goal);
if solver.check() {
    let model = solver.get_model();
    let mut results = HashMap::new();
    for (hole, name) in holes {
        results.insert(name, model.eval(&hole).unwrap().as_i64().unwrap());
    }
    Ok(results)
} else {
    bail!("no solution")
}

And now we have a synthesizer!

When given

// Specification:
x * 10
// Template:
(x << h1) + (x << h2)

our synthesizer gives the answer

{
    "h1": 1,
    "h2": 3,
}

And when given

// Specification:
x * 9
// Template:
x << (hb1 ? x : hn1) + (hb2 ? x : hn2)

it gives us the answer

{
    "hb1": 0,
    "hb2": 1,
    "hn1": 3,
    "hn2": 0,
}
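Checking this one by hand: `<<` binds tighter than `+` in our grammar, so the template parses as `(x << (hb1 ? x : hn1)) + (hb2 ? x : hn2)`, and with the hole values above it becomes `(x << 3) + x = 9x`. A quick illustrative check in plain Rust, using ordinary branches for the conditionals:

```rust
// The second template with its synthesized hole values plugged in.
fn template(x: i64, hb1: i64, hn1: i64, hb2: i64, hn2: i64) -> i64 {
    let shift = if hb1 != 0 { x } else { hn1 };
    let addend = if hb2 != 0 { x } else { hn2 };
    (x << shift) + addend
}
```

With hb1 = 0 the shift amount is the constant hn1 = 3, and with hb2 = 1 the addend is x itself, so the whole expression is 8x + x = 9x.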

Conclusion

This was quite fun!

Thanks to Adrian Sampson for writing the original blog post and minisynth Python implementation.

If you would like to learn more, here are a few resources: