ppforest2 v0.1.0
Projection Pursuit Decision Trees and Random Forests
Loading...
Searching...
No Matches
TreeNode.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <memory>
4#include <set>
5
6#include "utils/Types.hpp"
7
8namespace ppforest2 {
// Forward declarations of the two node kinds named in the
// TreeNode::Visitor::visit overloads below; only references to them are
// used in this header, so the full definitions are not needed here.
// TreeBranch: internal split node in a projection pursuit tree.
9 class TreeBranch;
// TreeLeaf: leaf node in a projection pursuit tree.
10 class TreeLeaf;
11
/**
 * @brief Abstract base class for nodes in a projection pursuit tree.
 *
 * Everything except the destructor and the two comparison operators is
 * pure virtual, so the class cannot be instantiated directly; the
 * forward-declared TreeBranch and TreeLeaf are the node kinds the nested
 * Visitor dispatches on.
 */
19 class TreeNode {
20 public:
   /// Owning handle to a node (std::unique_ptr, so ownership is exclusive).
21 using Ptr = std::unique_ptr<TreeNode>;
22
   /**
    * @brief Visitor interface for tree node dispatch.
    *
    * Both visit() overloads have empty default bodies, so a concrete
    * visitor only needs to override the node kind(s) it is interested in.
    */
31 class Visitor {
32 public:
      /// Virtual so that visitors can be destroyed through a Visitor pointer.
33 virtual ~Visitor() = default;
34
      /// Called for an internal split node; does nothing by default.
35 virtual void visit(TreeBranch const&) {}
      /// Called for a leaf node; does nothing by default.
36 virtual void visit(TreeLeaf const&) {}
37 };
38
   /// Whether this node (or any descendant) had a degenerate split.
   /// Defaults to false.
40 bool degenerate = false;
41
   /// Virtual so that nodes held through TreeNode::Ptr are destroyed correctly.
42 virtual ~TreeNode() = default;
43
   /// Accept a tree node visitor (double dispatch).
   /// @param visitor Visitor whose visit() overload for this node kind is invoked.
45 virtual void accept(Visitor& visitor) const = 0;
46
   /// Predict the group label for a single observation.
   /// @param x Feature values of the observation (a types::FeatureVector).
   /// @return The predicted label as a types::Outcome.
53 virtual types::Outcome predict(types::FeatureVector const& x) const = 0;
54
   /// The group label at this node (leaf value or majority group).
56 virtual types::Outcome response() const = 0;
57
   /// Number of distinct groups reachable from this node.
61 virtual int group_count() const = 0;
62
   /// Sorted set of group labels reachable from this node.
   /// Returned by value as a std::set of types::GroupId.
66 virtual std::set<types::GroupId> node_groups() const = 0;
67
   /// Structural equality comparison (value-based).
   /// @param other Node to compare against.
69 virtual bool equals(TreeNode const& other) const = 0;
70
   /// Deep copy of this node and its subtree.
   /// @return A newly owned copy wrapped in a Ptr.
72 virtual Ptr clone() const = 0;
73
   // Non-virtual comparison operators, declared here and defined elsewhere.
   // NOTE(review): presumably these forward to equals() and its negation;
   // confirm against the out-of-line definitions.
74 bool operator==(TreeNode const& other) const;
75 bool operator!=(TreeNode const& other) const;
76 };
77
/// Whether node is a TreeLeaf.
/// @param node Node to classify.
86 bool is_leaf(TreeNode const& node);
/// Overload of is_leaf() taking an owning TreeNode::Ptr instead of a reference.
/// NOTE(review): behaviour for a null Ptr is not visible from this header;
/// confirm against the definition before passing a possibly-empty pointer.
87 bool is_leaf(TreeNode::Ptr const& node);
88}
Internal split node in a projection pursuit tree.
Definition TreeBranch.hpp:15
Leaf node in a projection pursuit tree.
Definition TreeLeaf.hpp:12
Visitor interface for tree node dispatch.
Definition TreeNode.hpp:31
virtual void visit(TreeLeaf const &)
Definition TreeNode.hpp:36
virtual ~Visitor()=default
virtual void visit(TreeBranch const &)
Definition TreeNode.hpp:35
Abstract base class for nodes in a projection pursuit tree.
Definition TreeNode.hpp:19
bool operator!=(TreeNode const &other) const
virtual std::set< types::GroupId > node_groups() const =0
Sorted set of group labels reachable from this node.
bool degenerate
Whether this node (or any descendant) had a degenerate split.
Definition TreeNode.hpp:40
virtual void accept(Visitor &visitor) const =0
Accept a tree node visitor (double dispatch).
virtual types::Outcome predict(types::FeatureVector const &x) const =0
Predict the group label for a single observation.
virtual bool equals(TreeNode const &other) const =0
Structural equality comparison (value-based).
virtual types::Outcome response() const =0
The group label at this node (leaf value or majority group).
std::unique_ptr< TreeNode > Ptr
Definition TreeNode.hpp:21
virtual int group_count() const =0
Number of distinct groups reachable from this node.
virtual ~TreeNode()=default
bool operator==(TreeNode const &other) const
virtual Ptr clone() const =0
Deep copy of this node and its subtree.
Eigen::Matrix< Feature, Eigen::Dynamic, 1 > FeatureVector
Dynamic-size column vector of feature values.
Definition Types.hpp:36
Feature Outcome
Scalar type for predictions (float for both classification and regression).
Definition Types.hpp:30
Binarization strategies for multiclass-to-binary reduction.
Definition Benchmark.hpp:25
bool is_leaf(TreeNode const &node)
Whether node is a TreeLeaf.