Home > Enterprise >  How to implement a fast type inference procedure for SKI combinators in Python?
How to implement a fast type inference procedure for SKI combinators in Python?

Time:01-26

How to implement a fast simple type inference procedure for SKI combinators in Python?

I am interested in 2 functions:

  1. typable: returns True if a given SKI term has a type (I suppose this can be done faster than searching for a concrete type).

  2. principle_type: returns the principal type if it exists, and False otherwise.


typable(SKK) = True
typable(SII) = False # (I=SKK). This term does not have a type. Similar to \x.xx

principle_type(S) = (t1 -> t2 -> t3) -> (t1 -> t2) -> t1 -> t3
principle_type(K) = t1 -> t2 -> t1
principle_type(SK) = (t3 -> t2) -> t3 -> t3
principle_type(SKK) = principle_type(I) = t1 -> t1  

Theoretical questions:

  1. I read about the Hindley–Milner type system. There are two algorithms: Algorithm J and Algorithm W. Do I understand correctly that they are used for a more complex type system, such as System F or another system with parametric polymorphism? Are there combinators that are typable in System F but not in the simple type system?

  2. As I understand it, to find a principal type we need to solve a system of equations between symbolic expressions. Is it possible to simplify the algorithm and speed up the process by using SMT solvers like Z3?

My implementation of basic combinators, reduction and parsing:

from __future__ import annotations
import typing
from dataclasses import dataclass


@dataclass(eq=True, frozen=True)
class S:
    """The atomic S combinator (reduces as S x y z -> x z (y z))."""

    def __len__(self):
        # A bare combinator counts as a single symbol.
        return 1

    def __str__(self):
        return "S"


@dataclass(eq=True, frozen=True)
class K:
    """The atomic K combinator (reduces as K x y -> x)."""

    def __len__(self):
        # A bare combinator counts as a single symbol.
        return 1

    def __str__(self):
        return "K"


@dataclass(eq=True, frozen=True)
class App:
    """Application node: the term ``left`` applied to the term ``right``."""

    left: Term   # function part
    right: Term  # argument part

    def __str__(self):
        # Fully parenthesised, e.g. App(App(S(), K()), K()) renders as "((SK)K)".
        return "".join(("(", str(self.left), str(self.right), ")"))

    def __len__(self):
        # Length of the rendered string, parentheses included.
        rendered = str(self)
        return len(rendered)


# An SKI term: an atomic S or K combinator, or an App node applying one term
# to another.
Term = typing.Union[S, K, App]


def parse_ski_string(s):
    """Parse a string of S/K combinators into a Term.

    Whitespace is ignored. Parentheses group sub-terms, and juxtaposition is
    left-associative application, so "SKK", "(SK)K" and "((SK)K)" all give
    App(App(S(), K()), K()). Redundant parentheses such as "(S)" are allowed.

    Raises:
        ValueError: on an unknown character, unbalanced parentheses, or an
            empty (sub)term. ValueError is an Exception subclass, so callers
            that caught a bare Exception keep working.
    """
    # One operand list per currently open group; frames[0] is the top level.
    # Tracking groups explicitly means ")" only ever combines the operands of
    # the group it closes, never an operand belonging to an enclosing group.
    frames = [[]]
    for c in ''.join(s.split()):
        if c == '(':
            frames.append([])
        elif c == ')':
            if len(frames) == 1:
                raise ValueError(f"unbalanced ')' in {s!r}")
            group = frames.pop()
            frames[-1].append(_fold_applications(group, s))
        elif c == 'S':
            frames[-1].append(S())
        elif c == 'K':
            frames[-1].append(K())
        else:
            raise ValueError(f"unexpected character {c!r} in {s!r}")

    if len(frames) != 1:
        raise ValueError(f"unbalanced '(' in {s!r}")
    return _fold_applications(frames[0], s)


def _fold_applications(terms, source):
    """Combine juxtaposed terms left-associatively: [a, b, c] -> App(App(a, b), c)."""
    if not terms:
        raise ValueError(f"empty term in {source!r}")
    result = terms[0]
    for arg in terms[1:]:
        result = App(result, arg)
    return result


def simplify(expr: Term, max_steps: int = 10_000):
    """Reduce an SKI term to normal form.

    Rewrite rules, tried at the head of the term first:
        K x y   -> x
        S x y z -> x z (y z)
    When the head is not a redex, both sides of the application are
    normalised and the rebuilt term is examined again.

    Head reductions run in a loop rather than as recursive tail calls, so a
    term that needs many head steps no longer overflows Python's call stack.
    Recursion is still used for the two sides of an application.

    Args:
        expr: the term to normalise.
        max_steps: budget of loop iterations per call. When it is exhausted,
            RecursionError is raised, which is the same error the fully
            recursive version produced for terms that never normalise
            (e.g. SII(SII)).

    Raises:
        RecursionError: if no normal form is reached within the budget, or if
            the term is nested too deeply.
        Exception: if ``expr`` (or a subterm) is not an S, K or App node.
    """
    for _ in range(max_steps):
        if isinstance(expr, S) or isinstance(expr, K):
            return expr
        if not isinstance(expr, App):
            raise Exception('Wrong type of combinator', expr)

        head = expr.left
        if isinstance(head, App) and isinstance(head.left, K):
            # K x y -> x
            expr = head.right
            continue
        if (isinstance(head, App) and isinstance(head.left, App)
                and isinstance(head.left.left, S)):
            # S x y z -> x z (y z)
            x, y, z = head.left.right, head.right, expr.right
            expr = App(App(x, z), App(y, z))
            continue

        # No redex at the head: normalise both sides. If neither changed, the
        # term is in normal form; otherwise the rebuilt application may itself
        # be a redex, so go round again.
        l2 = simplify(expr.left, max_steps)
        r2 = simplify(expr.right, max_steps)
        if expr.left == l2 and expr.right == r2:
            return App(expr.left, expr.right)
        expr = App(l2, r2)

    raise RecursionError(
        f"no normal form reached within {max_steps} reduction steps")

# simplify(App(App(K(),S()),K())) = S
# simplify(parse_ski_string('(((SK)K)S)')) = S

Answer:

Here is a simple implementation: maybe not the fastest, but reasonably fast if the types are small.

from dataclasses import dataclass


class OccursError(Exception):
    """Raised when unification would bind a type variable to a type containing it.

    Such a binding would describe an infinite type, so the term being checked
    has no simple type (for example S I I).
    """


# Type variables are plain integers handed out by new_var().
Var = int

# Union-find table for type variables: each variable maps to itself while it
# is unbound, otherwise to the type (Var or Fun) it was unified with.
# NOTE: module-global and never cleared, so it keeps growing across inferences.
parent = {}


def new_var() -> Var:
    """Allocate a fresh, unbound type variable.

    Ids are consecutive integers taken from the current size of ``parent``;
    entries are never removed, so every id is handed out only once.
    """
    fresh = Var(len(parent))
    parent[fresh] = fresh  # an unbound variable is its own root
    return fresh


@dataclass
class Fun:
    """Function type ``dom -> cod``.

    Each side is either a type variable (Var) or another Fun. Instances are
    mutable and compare structurally.
    """

    dom: 'Var | Fun'  # argument type
    cod: 'Var | Fun'  # result type


def S() -> Fun:
    """Fresh instance of S's type: (t1 -> t2 -> t3) -> (t1 -> t2) -> t1 -> t3."""
    a, b, c = new_var(), new_var(), new_var()
    first_arg = Fun(a, Fun(b, c))   # t1 -> t2 -> t3
    second_arg = Fun(a, b)          # t1 -> t2
    return Fun(first_arg, Fun(second_arg, Fun(a, c)))


def K() -> Fun:
    """Fresh instance of K's type: t1 -> t2 -> t1."""
    a, b = new_var(), new_var()
    return Fun(a, Fun(b, a))


def I() -> Fun:
    """Fresh instance of I's type (I = SKK): t1 -> t1."""
    a = new_var()
    return Fun(dom=a, cod=a)


def find(t1: Var | Fun) -> Var | Fun:
    """Resolve a type through the union-find table ``parent``.

    A variable is followed to its representative, and the representative is
    stored back into ``parent`` (path compression). A Fun is rebuilt with both
    sides resolved (domain first, then codomain).

    Raises:
        TypeError: if ``t1`` is neither a Var nor a Fun.
    """
    if isinstance(t1, Fun):
        return Fun(find(t1.dom), find(t1.cod))
    if not isinstance(t1, Var):
        raise TypeError
    bound_to = parent[t1]
    if bound_to == t1:
        return t1
    representative = find(bound_to)
    parent[t1] = representative  # path compression
    return representative


def occurs(t1: Var, t2: Var | Fun) -> bool:
    """Return True if variable ``t1`` appears anywhere inside type ``t2``.

    Bound variables are not followed here, so ``t2`` should already be
    resolved with ``find`` (``unify`` does this before calling).

    Raises:
        TypeError: if a node of ``t2`` is neither a Var nor a Fun.
    """
    if isinstance(t2, Fun):
        # Short-circuit: the codomain is only searched if the domain misses.
        return occurs(t1, t2.dom) or occurs(t1, t2.cod)
    if isinstance(t2, Var):
        return t2 == t1
    raise TypeError


def unify(t1: Var | Fun, t2: Var | Fun):
    """Make ``t1`` and ``t2`` equal by binding type variables in ``parent``.

    Both sides are first resolved with ``find``. Two function types are
    unified component-wise (domain first, then codomain). A variable is bound
    to the other side, after an occurs check when that side is a Fun.

    Raises:
        OccursError: if a variable would have to contain itself.
        TypeError: if either side is neither a Var nor a Fun.
    """
    left = find(t1)
    right = find(t2)

    if isinstance(left, Fun) and isinstance(right, Fun):
        unify(left.dom, right.dom)
        unify(left.cod, right.cod)
        return

    if isinstance(left, Var) and isinstance(right, Var):
        parent[left] = right
        return

    # Exactly one side is a Fun now: orient the pair as (variable, type).
    if isinstance(left, Var) and isinstance(right, Fun):
        var, typ = left, right
    elif isinstance(left, Fun) and isinstance(right, Var):
        var, typ = right, left
    else:
        raise TypeError

    if occurs(var, typ):
        raise OccursError
    parent[var] = typ


def apply(t1: Var | Fun, t2: Var | Fun) -> Var | Fun:
    """Return the result type of applying something of type ``t1`` to ``t2``.

    A fresh result variable r is created and ``t1`` is unified with
    ``t2 -> r``; r is returned (its binding lives in ``parent``).

    Raises:
        OccursError: if ``t1`` cannot be unified with ``t2 -> r``.
    """
    result_type = new_var()
    expected_fun = Fun(t2, result_type)
    unify(t1, expected_fun)
    return result_type


def _infer_application_chain(make_head, *make_args):
    """Return the resolved type of ``head arg1 arg2 ...`` (left-associative).

    Each argument is a zero-argument type-scheme constructor such as S, K or I
    above. Each one is called once, in order, so every occurrence gets its own
    fresh type variables.

    Raises:
        OccursError: if the term has no simple type.
    """
    term_type = make_head()
    for make_arg in make_args:
        term_type = apply(term_type, make_arg())
    return find(term_type)


if __name__ == "__main__":
    # S K K (= I) is typable; S I I (like \x.xx) is not.
    for demo in ((S, K, K), (S, I, I)):
        try:
            demo_type = _infer_application_chain(*demo)
        except OccursError:
            print("# no type")
        else:
            print("#", demo_type)

# Fun(dom=6, cod=6)
# no type
  • Related