Why won't my Decision Tree classifier work? The functions say not enough input arguments


I have coded a Decision Tree classifier in MATLAB. To the best of my knowledge everything should work and the logic checks out. When I try to call the fit method, it breaks on one of my functions, telling me I haven't got the right input arguments, but I'm sure I do! I've been trying to solve this and similar errors to do with functions and input arguments for a day or two now. I wondered if it had something to do with calling them from within the constructor, but calling them from the main script still doesn't work. Please help!
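
In case it matters, this is how I understand method calls inside a classdef are supposed to work: a non-Static method gets the object as its first input, either passed explicitly with function syntax or implicitly with dot syntax. This is just a toy sketch I wrote to check my understanding, nothing to do with my tree code, so if I've got this wrong maybe that's part of my problem:

classdef Demo < handle
    properties
        Total = 0;
    end
    methods
        function obj = Demo()
            % function syntax: pass the object explicitly as the first argument
            addValue(obj, 5);
            % dot syntax: the object is passed implicitly, so only the
            % remaining arguments are listed
            obj.addValue(10);
        end
        function addValue(obj, v)
            obj.Total = obj.Total + v;
        end
    end
end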

classdef my_ClassificationTree < handle
    
    properties
        X % training examples
        Y % training labels
        MinParentSize % minimum parent node size
        MaxNumSplits % maximum number of splits        
        Verbose % are we printing out debug as we go?
       % MinLeafSize
        CutPoint
        CutPredictorIndex
        Children
        numSplits
        root
    end
    
    methods
        
        % constructor: implementing the fitting phase
        
        function obj = my_ClassificationTree(X, Y, MinParentSize, MaxNumSplits, Verbose)
            obj.X = X;
            obj.Y = Y;
            obj.MinParentSize = MinParentSize;
            obj.MaxNumSplits = MaxNumSplits;
            obj.Verbose = Verbose;
%             obj.Children = zeros(1, 2);
%             obj.CutPoint = 0;
%             obj.CutPredictorIndex = 0;
           % obj.MinLeafSize = MinLeafSize;
            obj.numSplits = 0;
            obj.root = Node(1, size(obj.X,1));
            root = Node(1, size(obj.X,1));
            fit(obj,root);
        end
        
        function node = Node(sIndex,eIndex)
            node.startIndex = sIndex;
            node.endIndex = eIndex;
            node.leaf = false;
            node.Children = 0;
            node.size = eIndex - sIndex + 1;
            node.CutPoint = 0;
            node.CutPredictorIndex = 0;
            node.NodeClass = 0;
        end

        function fit(obj,node)            
            if node.size < obj.MinParentSize || obj.numSplits >= obj.MaxNumSplits
                 % Mark the node as a leaf node
                 node.leaf = true;
                 % Calculate the majority class label for the examples at this node
                 labels = obj.Y(node.startIndex:node.endIndex); %gather all the labels for the data in the nodes range
                 node.NodeClass = mode(labels); %find the most frequent label and classify the node as such
                 return;
            end
            bestCutPoint = findBestCutPoint(node, obj.X, obj.Y);
            leftChild = Node(node.startIndex, bestCutPoint.CutIndex - 1);
            rightChild = Node(bestCutPoint.CutIndex, node.endIndex);
            obj.numSplits = obj.numSplits + 1;
            node.CutPoint = bestCutPoint.CutPoint;
            node.CutPredictorIndex = bestCutPoint.CutPredictorIndex;
            %Attach the child nodes to the parent node
            node.Children = [leftChild, rightChild];
            % Recursively build the tree for the left and right child nodes
            fit(obj, leftChild);
            fit(obj, rightChild);
        end        

        function bestCutPoint = findBestCutPoint(node, X, labels)
            bestCutPoint.CutPoint = 0;
            bestCutPoint.CutPredictorIndex = 0;
            bestCutPoint.CutIndex = 0;
            bestGDI = Inf; % Initialize the best GDI to a large value
            
            % Loop through all the features
            for i = 1:size(X, 2)
                % Loop through all the unique values of the feature
                values = unique(X(node.startIndex:node.endIndex, i));
                for j = 1:length(values)
                    % Calculate the weighted impurity of the two sides of
                    % the candidate cut
                    leftLabels = labels(node.startIndex:node.endIndex, 1);
                    rightLabels = labels(node.startIndex:node.endIndex, 1);
                    leftLabels = leftLabels(X(node.startIndex:node.endIndex, i) < values(j));
                    rightLabels = rightLabels(X(node.startIndex:node.endIndex, i) >= values(j));
                    leftGDI = weightedGDI(leftLabels, labels);
                    rightGDI = weightedGDI(rightLabels, labels);
                    % Calculate the weighted impurity of the split
                    cutGDI = leftGDI + rightGDI;
                    % Update the best split if the current split has a lower GDI
                    if cutGDI < bestGDI
                        bestGDI = cutGDI;
                        bestCutPoint.CutPoint = values(j);
                        bestCutPoint.CutPredictorIndex = i;
                        bestCutPoint.CutIndex = find(X(:, i) == values(j), 1, 'first');
                    end
                end
            end
        end

        % the prediction phase:
        function predictions = predict(obj, test_examples)
            
            % get ready to store our predicted class labels:
            predictions = categorical;
            
             % Iterate over each example in X
            for i = 1:size(test_examples, 1)
                % Set the current node to be the root node
                currentNode = obj.root;
                % While the current node is not a leaf node
                while ~currentNode.leaf 
                    % Check the value of the predictor feature specified by the CutPredictorIndex property of the current node
                    value = test_examples(i, currentNode.CutPredictorIndex);
                    % If the value is less than the CutPoint of the current node, set the current node to be the left child of the current node
                    if value < currentNode.CutPoint
                        currentNode = currentNode.Children(1);
                    % If the value is greater than or equal to the CutPoint of the current node, set the current node to be the right child of the current node
                    else
                        currentNode = currentNode.Children(2);
                    end
                end
                % Once the current node is a leaf node, add the NodeClass of the current node to the predictions vector
                predictions(i) = currentNode.NodeClass;
            end
        end
        
        % add any other methods you want on the lines below...

    end
    
end
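
For completeness, weightedGDI isn't shown above. It's meant to compute the Gini diversity index of the labels on one side of a candidate cut, weighted by that side's share of the examples, so that the two sides can be summed and compared. Something along these lines is what I have in mind (a simplified sketch, not necessarily the exact code):

function gdi = weightedGDI(subsetLabels, allLabels)
    % Gini diversity index of the subset, weighted by the fraction of
    % examples that fall into the subset
    if isempty(subsetLabels)
        gdi = 0;
        return;
    end
    classes = unique(subsetLabels);
    p = zeros(numel(classes), 1);
    for k = 1:numel(classes)
        p(k) = sum(subsetLabels == classes(k)) / numel(subsetLabels);
    end
    gini = 1 - sum(p.^2);                                   % standard Gini impurity
    gdi = (numel(subsetLabels) / numel(allLabels)) * gini;  % weight by subset share
end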

This is the function that calls my_ClassificationTree:

function m = my_fitctree(train_examples, train_labels, varargin)

    % take an extra name-value pair allowing us to turn debug on:
    p = inputParser;
    addParameter(p, 'Verbose', false);
               