I'm learning Rust and trying to implement a simple binary search tree (I'm actually rewriting the Java implementation below). Here is what I have so far:
use std::cmp::Ordering;

// Node of this BST, the two generic types are key and value
struct Node<K: Ord, V> {
    key: K,
    value: V,
    left: Option<Box<Node<K, V>>>,
    right: Option<Box<Node<K, V>>>,
    number_of_nodes: i32,
}

impl<K: Ord, V> Node<K, V> {
    // Create a new node
    fn new(key: K, value: V, number_of_nodes: i32) -> Node<K, V> {
        Node {
            key,
            value,
            left: None,
            right: None,
            number_of_nodes,
        }
    }
}

struct BST<K: Ord, V> {
    root: Option<Box<Node<K, V>>>,
}

impl<K: Ord, V> BST<K, V> {
    // Get the size of this BST
    fn size(&self) -> i32 {
        size(&self.root)
    }

    // Search for key. Update value if found, otherwise insert the new node
    fn put(&self, key: K, value: V) {
        self.root = put(&self.root, key, value)
    }
}

// Function for recursively getting the size of a sub BST
fn size<K: Ord, V>(node: &Option<Box<Node<K, V>>>) -> i32 {
    match node {
        Some(real_node) => real_node.number_of_nodes,
        None => 0,
    }
}

// Function for recursively putting a new node into this BST
fn put<K: Ord, V>(node: &Option<Box<Node<K, V>>>, key: K, value: V) -> &Option<Box<Node<K, V>>> {
    match node {
        None => {
            let new_node = Some(Box::new(Node::new(key, value, 1)));
            return &new_node;
        },
        Some(real_node) => {
            match key.cmp(&real_node.key) {
                Ordering::Less => real_node.left = *put(&real_node.left, key, value),
                Ordering::Greater => real_node.right = *put(&real_node.right, key, value),
                Ordering::Equal => real_node.value = value,
            }
            real_node.number_of_nodes = size(&real_node.right) + size(&real_node.left) + 1;
            node
        },
    }
}
But this code won't compile. At the line self.root = put(&self.root, key, value), I get an error:
mismatched types
expected enum 'Option<Box<Node<K, V>>>' found reference '&Option<Box<Node<K, V>>>'
I don't know how to fix that. I tried changing the &self parameter to self, or self.root to *self.root, but I got more errors. I'm confused about references in Rust; all I want to do is rewrite the following Java code in Rust.
public class BST<Key extends Comparable<Key>, Value>
{
    private Node root; // root of BST

    private class Node
    {
        private Key key;          // key
        private Value val;        // associated value
        private Node right, left; // left and right subtrees
        private int N;            // number of nodes in subtree

        public Node(Key key, Value val, int N)
        {
            this.key = key;
            this.val = val;
            this.N = N;
        }
    }

    // Returns the number of key-value pairs in this symbol table.
    public int size()
    {
        return size(root);
    }

    // Return number of key-value pairs in BST rooted at x
    private int size(Node x)
    {
        if (x == null) return 0;
        else return x.N;
    }

    public void put(Key key, Value val)
    {
        root = put(root, key, val);
    }

    private Node put(Node x, Key key, Value val)
    {
        if (x == null) return new Node(key, val, 1);
        int cmp = key.compareTo(x.key);
        if (cmp < 0) x.left = put(x.left, key, val);
        else if (cmp > 0) x.right = put(x.right, key, val);
        else x.val = val;
        x.N = size(x.left) + size(x.right) + 1;
        return x;
    }
}
It's dead simple in Java because I don't need to handle references. So here are my questions:
- How can I fix the mismatched types error?
- What is the proper return type of the recursive function put: &Option<Box<Node<K, V>>> or Option<Box<Node<K, V>>>? What's the difference?
- Am I on the right track in rewriting this Java code? rust-analyzer only reports this mismatched types error, but I don't know whether the code will work as I expect. Honestly, I don't fully understand what I'm doing when I handle references in Rust, especially a reference to a struct or an enum.
It's hard to learn Rust because I don't have much experience with systems programming languages. I appreciate your help :)
CodePudding user response:
The simplest option is to take a mutable reference to the node:
impl<K: Ord, V> BST<K, V> {
    // ...
    fn put(&mut self, key: K, value: V) {
        put(&mut self.root, key, value)
    }
}

fn put<K: Ord, V>(node: &mut Option<Box<Node<K, V>>>, key: K, value: V) {
    match node {
        None => {
            // Empty slot: write the new leaf directly through the mutable reference
            *node = Some(Box::new(Node::new(key, value, 1)));
        }
        Some(real_node) => {
            if key < real_node.key {
                put(&mut real_node.left, key, value);
            } else {
                put(&mut real_node.right, key, value);
            }
            real_node.number_of_nodes = size(&real_node.right) + size(&real_node.left) + 1;
        }
    }
}
Note that I changed the insertion behavior when the keys are equal: an equal key is inserted into the right subtree as a new node instead of updating the value of the existing node. The node count is recomputed from the two subtree sizes rather than incremented, so it stays consistent either way.
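If the update-on-equal behavior of the Java code is wanted, the mutable-reference version can match on key.cmp(...) as the question's code does. The following is a minimal, self-contained sketch under that assumption; BST::new and get are hypothetical helpers added only so the example can be exercised, and Node::new here always starts with a count of 1.

```rust
use std::cmp::Ordering;

struct Node<K: Ord, V> {
    key: K,
    value: V,
    left: Option<Box<Node<K, V>>>,
    right: Option<Box<Node<K, V>>>,
    number_of_nodes: i32,
}

impl<K: Ord, V> Node<K, V> {
    // A fresh node is a leaf, so its subtree holds exactly one node.
    fn new(key: K, value: V) -> Self {
        Node { key, value, left: None, right: None, number_of_nodes: 1 }
    }
}

struct BST<K: Ord, V> {
    root: Option<Box<Node<K, V>>>,
}

impl<K: Ord, V> BST<K, V> {
    // Hypothetical constructor, not part of the original question.
    fn new() -> Self {
        BST { root: None }
    }

    fn size(&self) -> i32 {
        size(&self.root)
    }

    // `&mut self` is required because the tree is modified.
    fn put(&mut self, key: K, value: V) {
        put(&mut self.root, key, value)
    }

    // Hypothetical lookup helper, added only to observe the stored value.
    fn get(&self, key: &K) -> Option<&V> {
        let mut cur = &self.root;
        while let Some(n) = cur {
            match key.cmp(&n.key) {
                Ordering::Less => cur = &n.left,
                Ordering::Greater => cur = &n.right,
                Ordering::Equal => return Some(&n.value),
            }
        }
        None
    }
}

fn size<K: Ord, V>(node: &Option<Box<Node<K, V>>>) -> i32 {
    match node {
        Some(n) => n.number_of_nodes,
        None => 0,
    }
}

fn put<K: Ord, V>(node: &mut Option<Box<Node<K, V>>>, key: K, value: V) {
    match node {
        // Empty slot: store the new leaf in place.
        None => *node = Some(Box::new(Node::new(key, value))),
        Some(n) => {
            match key.cmp(&n.key) {
                Ordering::Less => put(&mut n.left, key, value),
                Ordering::Greater => put(&mut n.right, key, value),
                // Same key: overwrite the value, no new node is created.
                Ordering::Equal => n.value = value,
            }
            // Recompute (not increment) the count from the subtrees.
            n.number_of_nodes = size(&n.left) + size(&n.right) + 1;
        }
    }
}
```

With this variant, putting the same key twice leaves the size unchanged and replaces the stored value.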
Alternatively you could take the node by value and return the modified node:
impl<K: Ord, V> BST<K, V> {
    // ...
    pub fn put(&mut self, key: K, value: V) {
        // take() moves the root out and leaves None behind until it is reassigned
        self.root = put(self.root.take(), key, value);
    }
}

fn put<K: Ord, V>(node: Option<Box<Node<K, V>>>, key: K, value: V) -> Option<Box<Node<K, V>>> {
    match node {
        None => Some(Box::new(Node::new(key, value, 1))),
        Some(mut real_node) => {
            if key < real_node.key {
                real_node.left = put(real_node.left, key, value);
            } else {
                real_node.right = put(real_node.right, key, value);
            }
            real_node.number_of_nodes = size(&real_node.right) + size(&real_node.left) + 1;
            Some(real_node)
        }
    }
}
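The call to self.root.take() is what makes the by-value version work: a field cannot be moved out from behind &mut self, but Option::take swaps in None and hands back the owned value. A small sketch of that mechanism in isolation; Holder and bump are hypothetical names used only for illustration.

```rust
// Hypothetical type standing in for a struct that owns an optional boxed value.
struct Holder {
    slot: Option<Box<i32>>,
}

impl Holder {
    fn bump(&mut self) {
        // `let owned = self.slot;` would be rejected by the compiler,
        // because it moves a field out of a value behind `&mut self`.
        // `take()` leaves `None` in the field and returns the old contents.
        let owned: Option<Box<i32>> = self.slot.take();

        // Build the replacement from the owned value and store it back.
        self.slot = owned.map(|b| Box::new(*b + 1));
    }
}
```

Between the take() and the assignment the field briefly holds None, which is why the result has to be written back.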
I suggest reading the ownership and references chapter of the Rust book if you haven't already.
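As a rough illustration of the difference between the two return types asked about: Option<Box<...>> hands ownership to the caller, while &Option<Box<...>> only borrows a value that must already live somewhere else, which is why returning a reference to a local such as new_node cannot work. The function names below are hypothetical and the payload is simplified to an i32.

```rust
// Rejected by the compiler: `v` is dropped when the function returns,
// so there is nothing left for the reference to point to.
//
// fn make_ref() -> &Option<Box<i32>> {
//     let v = Some(Box::new(1));
//     &v
// }

// Returning the value itself moves ownership out to the caller.
fn make_owned() -> Option<Box<i32>> {
    Some(Box::new(1))
}

// Taking a shared reference only inspects a value owned by the caller.
fn peek(slot: &Option<Box<i32>>) -> Option<i32> {
    // as_deref turns &Option<Box<i32>> into Option<&i32>; copied yields Option<i32>.
    slot.as_deref().copied()
}
```

A caller can hold the result of make_owned() in a variable and pass a reference to it to peek as many times as needed, since borrowing does not consume the value.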