Implement the recursive insert method of a binary search tree in Rust

Time:11-10

I'm learning Rust and trying to implement a simple binary search tree (actually, I'm rewriting the Java implementation below). Here is what I have so far:

use std::cmp::Ordering;

// Node of this BST, the two generic types are key and value
struct Node<K: Ord, V> {
    key: K,
    value: V,
    left: Option<Box<Node<K, V>>>,
    right: Option<Box<Node<K, V>>>,
    number_of_nodes: i32,
}

impl<K: Ord, V> Node<K, V> {
    // Create a new node
    fn new(key: K, value: V, number_of_nodes: i32) -> Node<K, V> {
        Node {
            key,
            value,
            left: None,
            right: None,
            number_of_nodes,
        }
    }
}

struct BST<K: Ord, V> {
    root: Option<Box<Node<K, V>>>,
}

impl<K: Ord, V> BST<K, V> {
    // Get the size of this BST
    fn size(&self) -> i32 {
        size(&self.root)
    }

    // Search for key. Update value if found, otherwise insert the new node
    fn put(&self, key: K, value: V) {
        self.root = put(&self.root, key, value)
    }
}

// Get the size of a sub-BST (0 if it is empty)
fn size<K: Ord, V>(node: &Option<Box<Node<K, V>>>) -> i32 {
    match node {
        Some(real_node) => real_node.number_of_nodes,
        None => 0,
    }
}

// Recursively put a new node into this BST
fn put<K: Ord, V>(node: &Option<Box<Node<K, V>>>, key: K, value: V) -> &Option<Box<Node<K, V>>> {
    match node {
        None => {
            let new_node = Some(Box::new(Node::new(key, value, 1)));
            return &new_node;
        },
        Some(real_node) => {
            match key.cmp(&real_node.key) {
                Ordering::Less => real_node.left = *put(&real_node.left, key, value),
                Ordering::Greater => real_node.right = *put(&real_node.right, key, value), 
                Ordering::Equal => real_node.value = value,
            }
            real_node.number_of_nodes = size(&real_node.right) + size(&real_node.left) + 1;
            node
        },
    }
}

But this code won't compile. At the line self.root = put(&self.root, key, value), I get this error:

mismatched types
expected enum 'Option<Box<Node<K, V>>>' found reference '&Option<Box<Node<K, V>>>'

I don't know how to fix that. I tried changing the &self parameter to self, and self.root to *self.root, but I got more errors. I'm really confused about references in Rust; all I want to do is rewrite the following Java code in Rust.

public class BST<Key extends Comparable<Key>, Value>
{
    private Node root;              //root of BST

    private class Node
    {
        private Key key;            // key
        private Value val;          // associated value
        private Node right, left;   // left and right subtrees
        private int N;              // number of nodes in subtree

        public Node(Key key, Value val, int N)
        {
            this.key = key;
            this.val = val;
            this.N = N;
        }
    }

    // Returns the number of key-value pairs in this symbol table.
    public int size()
    {
        return size(root);
    }

    // Return number of key-value pairs in BST rooted at x
    private int size(Node x)
    {
        if (x == null) return 0;
        else return x.N;
    }

    public void put(Key key, Value val)
    {
        root = put(root, key, val);
    }

    private Node put(Node x, Key key, Value val)
    {
        if (x == null) return new Node(key, val, 1);
        int cmp = key.compareTo(x.key);
        if (cmp < 0) x.left = put(x.left, key, val);
        else if (cmp > 0) x.right = put(x.right, key, val);
        else x.val = val;
        x.N = size(x.left) + size(x.right) + 1;
        return x;
    }
} 

It's dead simple in Java because I don't have to deal with references. So here are my questions:

  1. How can I fix that mismatched types error?
  2. What is the proper return type of the recursive function put: &Option<Box<Node<K, V>>> or Option<Box<Node<K, V>>>? What's the difference?
  3. Am I on the right track with this rewrite of the Java code? rust-analyzer only reports this mismatched types error, but I don't know whether the code will work as I expect. Honestly, I don't fully understand what I'm doing when I handle references in Rust, especially references to a struct or enum.

It's hard to learn Rust because I don't have much experience with systems programming languages. I appreciate your help :)

Answer:

The simplest option is to take a mutable reference to the node. BST::put then needs &mut self as well, because it modifies self.root:

impl<K: Ord, V> BST<K, V> {
    // ...

    fn put(&mut self, key: K, value: V) {
        put(&mut self.root, key, value)
    }
}

fn put<K: Ord, V>(node: &mut Option<Box<Node<K, V>>>, key: K, value: V) {
    match node {
        None => {
            *node = Some(Box::new(Node::new(key, value, 1)));
        }
        Some(real_node) => {
            if key < real_node.key {
                put(&mut real_node.left, key, value);
            } else {
                put(&mut real_node.right, key, value);
            }

            real_node.number_of_nodes = size(&real_node.right) + size(&real_node.left) + 1;
        }
    }
}

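Note that I changed the insertion behavior for equal keys: they now go into the right subtree, so calling put with a key that is already present inserts a second node instead of updating the value. The original Rust and Java code did not actually miscount in that case, because the number of nodes is recomputed from the subtree sizes rather than incremented. If you want the Java behavior (update the value in place), keep the Ordering::Equal => real_node.value = value arm.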
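
If you want put to keep the Java semantics (update the value when the key already exists), a minimal sketch of the mutable-reference version with an explicit Ordering::Equal arm might look like the following. The Node, BST and size definitions follow the question; the BST::new constructor is an addition for convenience and is not part of the original code.

```rust
use std::cmp::Ordering;

// Node of the BST; number_of_nodes is the size of the subtree rooted here.
struct Node<K: Ord, V> {
    key: K,
    value: V,
    left: Option<Box<Node<K, V>>>,
    right: Option<Box<Node<K, V>>>,
    number_of_nodes: i32,
}

impl<K: Ord, V> Node<K, V> {
    fn new(key: K, value: V, number_of_nodes: i32) -> Node<K, V> {
        Node { key, value, left: None, right: None, number_of_nodes }
    }
}

struct BST<K: Ord, V> {
    root: Option<Box<Node<K, V>>>,
}

impl<K: Ord, V> BST<K, V> {
    // Convenience constructor (not in the question): an empty tree.
    fn new() -> BST<K, V> {
        BST { root: None }
    }

    fn size(&self) -> i32 {
        size(&self.root)
    }

    // Update the value if the key exists, otherwise insert a new node.
    fn put(&mut self, key: K, value: V) {
        put(&mut self.root, key, value)
    }
}

// Size of a subtree; an empty subtree has size 0.
fn size<K: Ord, V>(node: &Option<Box<Node<K, V>>>) -> i32 {
    match node {
        Some(real_node) => real_node.number_of_nodes,
        None => 0,
    }
}

fn put<K: Ord, V>(node: &mut Option<Box<Node<K, V>>>, key: K, value: V) {
    match node {
        // Empty slot: write the new leaf through the mutable reference.
        None => *node = Some(Box::new(Node::new(key, value, 1))),
        Some(real_node) => {
            match key.cmp(&real_node.key) {
                Ordering::Less => put(&mut real_node.left, key, value),
                Ordering::Greater => put(&mut real_node.right, key, value),
                // Existing key: overwrite the value, no new node is created.
                Ordering::Equal => real_node.value = value,
            }
            // Recompute (not increment) the subtree size from the children.
            real_node.number_of_nodes = size(&real_node.left) + size(&real_node.right) + 1;
        }
    }
}
```

For example, putting the keys 5, 3 and 8 and then 3 again leaves size() at 3 and replaces the value stored under 3.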

Alternatively you could take the node by value and return the modified node:

impl<K: Ord, V> BST<K, V> {
    // ...

    pub fn put(&mut self, key: K, value: V) {
        self.root = put(self.root.take(), key, value);
    }
}

fn put<K: Ord, V>(node: Option<Box<Node<K, V>>>, key: K, value: V) -> Option<Box<Node<K, V>>> {
    match node {
        None => Some(Box::new(Node::new(key, value, 1))),
        Some(mut real_node) => {
            if key < real_node.key {
                real_node.left = put(real_node.left, key, value);
            } else {
                real_node.right = put(real_node.right, key, value);
            }

            real_node.number_of_nodes = size(&real_node.right) + size(&real_node.left) + 1;
            Some(real_node)
        }
    }
}
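
The by-value version relies on Option::take, which moves the value out of an &mut Option and leaves None behind. That is what lets put(self.root.take(), key, value) own the root even though BST::put only has &mut self. A tiny sketch of the same pattern, using Box<i32> as a stand-in for a tree node; Holder and replace_root are hypothetical names used only for illustration:

```rust
// Hypothetical stand-in for the free put: takes the old root by value
// and returns the new root.
fn replace_root(old: Option<Box<i32>>, value: i32) -> Option<Box<i32>> {
    match old {
        // Reuse the existing allocation when there is one.
        Some(mut boxed) => {
            *boxed = value;
            Some(boxed)
        }
        None => Some(Box::new(value)),
    }
}

// Hypothetical stand-in for BST: owns an optional boxed root.
struct Holder {
    root: Option<Box<i32>>,
}

impl Holder {
    fn set(&mut self, value: i32) {
        // take() moves the root out of self and leaves None in its place,
        // so replace_root can own it; the result is written straight back.
        self.root = replace_root(self.root.take(), value);
    }
}
```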

I suggest reading the "Understanding Ownership" chapter of the Rust book, in particular the section on references and borrowing, if you haven't already.
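
On question 2: &Option<Box<Node<K, V>>> is a borrow of an Option that lives somewhere else, while Option<Box<Node<K, V>>> is the Option itself, owned by whoever holds it. That is also why self.root = put(&self.root, key, value) fails to type-check: the field needs an owned Option, not a reference to one. In the question's put, the None arm builds new_node as a local and returns &new_node; that local is dropped when the function returns, so the reference would dangle, and the compiler rejects it. A function that creates a node therefore has to return it by value, while functions that only read or modify an existing slot can borrow it. A small sketch, again with Box<i32> standing in for a node (the function names are hypothetical):

```rust
// Creates a new value, so it must hand back ownership, not a reference.
fn make_leaf(value: i32) -> Option<Box<i32>> {
    Some(Box::new(value))
}

// Only reads the value, so a shared borrow is enough; the caller keeps ownership.
fn peek(node: &Option<Box<i32>>) -> i32 {
    match node {
        Some(boxed) => **boxed,
        None => 0,
    }
}

// Changes the caller's slot in place through a mutable borrow.
fn overwrite(node: &mut Option<Box<i32>>, value: i32) {
    *node = Some(Box::new(value));
}
```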
