ppforest2 v0.1.0
Projection Pursuit Decision Trees and Random Forests
ppforest2::Tree Struct Reference

A single projection pursuit decision tree.

#include <Tree.hpp>

Public Member Functions

 Tree (TreeNode::Ptr root, TrainingSpec::Ptr training_spec)
 
void accept (Model::Visitor &visitor) const override
 Accept a model visitor (double dispatch).
 
bool operator!= (Tree const &other) const
 
bool operator== (Tree const &other) const
 
types::ResponseVector predict (types::FeatureMatrix const &data) const override
 Predict a matrix of observations.
 
types::FeatureMatrix predict (types::FeatureMatrix const &data, Proportions) const override
 Predict proportions for a matrix of observations.
 
types::Response predict (types::FeatureVector const &data) const override
 Predict a single observation.
 
- Public Member Functions inherited from ppforest2::Model
virtual ~Model ()=default
 

Static Public Member Functions

static Tree train (TrainingSpec const &training_spec, types::FeatureMatrix const &x, stats::GroupPartition const &group_spec, stats::RNG &rng)
 Train a tree from a group partition.
 
static Tree train (TrainingSpec const &training_spec, types::FeatureMatrix const &x, types::ResponseVector const &y, stats::RNG &rng)
 Train a tree from a response vector.
 
- Static Public Member Functions inherited from ppforest2::Model
static Ptr train (TrainingSpec const &spec, types::FeatureMatrix const &x, types::ResponseVector const &y)
 Train a model from a training specification.
 

Public Attributes

TreeNode::Ptr root
 Root node of the tree.
 
- Public Attributes inherited from ppforest2::Model
bool degenerate = false
 Whether the model contains degenerate nodes/splits.
 
TrainingSpec::Ptr training_spec
 Training specification used to build this model.
 

Additional Inherited Members

- Public Types inherited from ppforest2::Model
using Ptr = std::shared_ptr<Model>
 

Detailed Description

A single projection pursuit decision tree.

Each internal node projects data onto a linear combination of features and splits on the projected value. Leaf nodes hold group labels.

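```cpp
// Assumes `spec` (a TrainingSpec), `x` (types::FeatureMatrix, n × p) and
// `y` (types::ResponseVector, n) are already defined.
stats::RNG rng(0);                               // fixed seed for reproducible training
Tree tree = Tree::train(spec, x, y, rng);        // grow the tree from the response vector
types::Response label = tree.predict(x.row(0));  // predicted group of the first observation
```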
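To make the node structure concrete, here is a minimal standalone sketch of how a projection pursuit tree could route one observation to a leaf. It does not use ppforest2: `Node`, `project`, `leaf`, `split` and `predict_one` are hypothetical names, and sending observations whose projected value falls below the threshold to the lower child is an assumed convention.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical node type for illustration (not ppforest2's TreeNode).
struct Node {
    std::vector<double> projector;  // internal nodes: projection direction a
    double threshold = 0.0;         // internal nodes: split point on a^T x
    std::unique_ptr<Node> lower;    // followed when a^T x < threshold (assumed convention)
    std::unique_ptr<Node> upper;    // followed otherwise
    int label = -1;                 // leaves: predicted group label

    bool is_leaf() const { return !lower && !upper; }
};

// Projected value a^T x of one observation.
double project(const std::vector<double>& a, const std::vector<double>& x) {
    assert(a.size() == x.size());
    double s = 0.0;
    for (std::size_t j = 0; j < a.size(); ++j) s += a[j] * x[j];
    return s;
}

// Helpers to build small trees by hand.
std::unique_ptr<Node> leaf(int label) {
    auto n = std::make_unique<Node>();
    n->label = label;
    return n;
}

std::unique_ptr<Node> split(std::vector<double> a, double t,
                            std::unique_ptr<Node> lo, std::unique_ptr<Node> hi) {
    auto n = std::make_unique<Node>();
    n->projector = std::move(a);
    n->threshold = t;
    n->lower = std::move(lo);
    n->upper = std::move(hi);
    return n;
}

// Route one observation from the root to a leaf and return the leaf's label.
int predict_one(const Node& root, const std::vector<double>& x) {
    const Node* node = &root;
    while (!node->is_leaf()) {
        node = project(node->projector, x) < node->threshold ? node->lower.get()
                                                              : node->upper.get();
    }
    return node->label;
}
```

For example, with a root projector (1, 1) and threshold 0, the observation (-1, -1) projects to -2 and reaches the lower child directly; only observations on the upper side are projected again at the next node.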

Constructor & Destructor Documentation

◆ Tree()

ppforest2::Tree::Tree ( TreeNode::Ptr root,
TrainingSpec::Ptr training_spec )

Member Function Documentation

◆ accept()

void ppforest2::Tree::accept ( Model::Visitor & visitor) const
override virtual

Accept a model visitor (double dispatch).

Implements ppforest2::Model.
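A minimal sketch of the double-dispatch pattern behind `accept`, assuming a visitor interface with one `visit` overload per concrete model type. The `Model`, `Tree`, `Forest`, `KindOf` and `kind_of` types here are simplified stand-ins, not ppforest2's declarations.

```cpp
#include <cassert>
#include <string>

// Simplified stand-ins for a model hierarchy (not ppforest2's declarations).
struct Tree;
struct Forest;

struct Model {
    // One visit overload per concrete model type.
    struct Visitor {
        virtual ~Visitor() = default;
        virtual void visit(const Tree& tree) = 0;
        virtual void visit(const Forest& forest) = 0;
    };

    virtual ~Model() = default;
    virtual void accept(Visitor& visitor) const = 0;
};

struct Tree : Model {
    void accept(Visitor& visitor) const override;
};

struct Forest : Model {
    void accept(Visitor& visitor) const override;
};

// Inside accept, *this has its concrete static type, so overload resolution
// picks the matching visit() (the second dispatch).
void Tree::accept(Visitor& visitor) const { visitor.visit(*this); }
void Forest::accept(Visitor& visitor) const { visitor.visit(*this); }

// Example visitor that records which concrete model it was handed.
struct KindOf : Model::Visitor {
    std::string kind;
    void visit(const Tree&) override { kind = "tree"; }
    void visit(const Forest&) override { kind = "forest"; }
};

std::string kind_of(const Model& model) {
    KindOf visitor;
    model.accept(visitor);  // first dispatch: virtual call on the model
    assert(!visitor.kind.empty());
    return visitor.kind;
}
```

The first dispatch is the virtual call `model.accept(visitor)`, which selects the concrete model; the second is `visitor.visit(*this)`, which selects the visitor method for that concrete type. New operations can then be added as visitors without touching the model classes.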

◆ operator!=()

bool ppforest2::Tree::operator!= ( Tree const & other) const

◆ operator==()

bool ppforest2::Tree::operator== ( Tree const & other) const

◆ predict() [1/3]

types::ResponseVector ppforest2::Tree::predict ( types::FeatureMatrix const & data) const
override virtual

Predict a matrix of observations.

Parameters
data: Feature matrix (n × p).
Returns
Predicted group labels (n).

Implements ppforest2::Model.
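One straightforward way to predict a matrix is to apply the single-observation overload to each row, producing one label per row. The sketch below assumes that row-wise approach (the actual implementation may differ) and uses `std::vector` rows in place of `types::FeatureMatrix`; `predict_rows` and the `Row`/`Matrix` aliases are hypothetical.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical stand-ins: one observation per row, p features each.
using Row = std::vector<double>;
using Matrix = std::vector<Row>;

// Predict n observations by applying a single-observation predictor row by row.
std::vector<int> predict_rows(const Matrix& data,
                              const std::function<int(const Row&)>& predict_one) {
    std::vector<int> labels;
    labels.reserve(data.size());
    for (const Row& row : data) labels.push_back(predict_one(row));
    assert(labels.size() == data.size());  // one label per row
    return labels;
}
```

Any callable with the single-observation signature works here, for instance a one-split stump that returns 0 when `r[0] + r[1] < 0` and 1 otherwise.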

◆ predict() [2/3]

types::FeatureMatrix ppforest2::Tree::predict ( types::FeatureMatrix const & data,
Proportions  ) const
override virtual

Predict proportions for a matrix of observations.

Returns an (n × G) matrix. For forests, each row contains the fraction of trees that voted for each group. For single trees, each row is a one-hot encoding of the predicted group.

Parameters
data: Feature matrix (n × p).
Returns
Proportion matrix (n × G); each row sums to 1.0.

Implements ppforest2::Model.
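The proportion semantics described above can be sketched with plain vectors, assuming group labels are the integers 0 to G-1: `one_hot` builds the single-tree matrix and `vote_fractions` builds the forest matrix, and in both every row sums to 1.0. The functions and the `ProportionMatrix` alias are illustrative, not ppforest2 code.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for an (n x G) proportion matrix.
using ProportionMatrix = std::vector<std::vector<double>>;

// Single tree: row i is the one-hot encoding of labels[i] (labels assumed in 0..G-1).
ProportionMatrix one_hot(const std::vector<int>& labels, int groups) {
    ProportionMatrix out(labels.size(), std::vector<double>(groups, 0.0));
    for (std::size_t i = 0; i < labels.size(); ++i) {
        assert(labels[i] >= 0 && labels[i] < groups);
        out[i][static_cast<std::size_t>(labels[i])] = 1.0;
    }
    return out;
}

// Forest: votes[t][i] is tree t's label for observation i; entry (i, g) is the
// fraction of trees that voted for group g.
ProportionMatrix vote_fractions(const std::vector<std::vector<int>>& votes, int groups) {
    assert(!votes.empty());
    const std::size_t n = votes.front().size();
    ProportionMatrix out(n, std::vector<double>(groups, 0.0));
    for (const auto& tree_votes : votes) {
        assert(tree_votes.size() == n);
        const ProportionMatrix hot = one_hot(tree_votes, groups);
        for (std::size_t i = 0; i < n; ++i)
            for (std::size_t g = 0; g < out[i].size(); ++g) out[i][g] += hot[i][g];
    }
    const double trees = static_cast<double>(votes.size());
    for (auto& row : out)
        for (double& v : row) v /= trees;  // counts -> fractions
    return out;
}
```

Averaging one-hot rows over trees is exactly a vote fraction, which is why a single tree is the special case of a forest with one tree.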

◆ predict() [3/3]

types::Response ppforest2::Tree::predict ( types::FeatureVector const & data) const
override virtual

Predict a single observation.

Parameters
data: Feature vector (p).
Returns
Predicted group label.

Implements ppforest2::Model.

◆ train() [1/2]

static Tree ppforest2::Tree::train ( TrainingSpec const & training_spec,
types::FeatureMatrix const & x,
stats::GroupPartition const & group_spec,
stats::RNG & rng )
static

Train a tree from a group partition.

Recursively splits the data by finding optimal projections at each node until pure leaves are reached.

Parameters
training_spec: Training specification (strategy + DR).
x: Feature matrix (n × p).
group_spec: Group partition.
rng: Random number generator.
Returns
Trained tree.
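The recursion can be sketched as follows. This is a simplified stand-in, not ppforest2's algorithm: instead of the strategy configured in `TrainingSpec`, it separates the smallest remaining label from all other labels by projecting onto the difference of the two group means and splitting at the midpoint of the projected means. A node whose split would leave one side empty becomes a majority-label leaf, which guarantees termination. All names (`Node`, `grow`, `train`, `predict_one`) are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <numeric>
#include <vector>

using Row = std::vector<double>;

// Hypothetical node type (not ppforest2's TreeNode).
struct Node {
    Row projector;                       // projection direction (internal nodes)
    double threshold = 0.0;              // split point on the projected value
    std::unique_ptr<Node> lower, upper;  // lower: projection < threshold
    int label = -1;                      // majority label (read only at leaves)
};

double project(const Row& a, const Row& x) {
    double s = 0.0;
    for (std::size_t j = 0; j < a.size(); ++j) s += a[j] * x[j];
    return s;
}

// Column-wise mean of the rows listed in idx.
Row mean_of(const std::vector<Row>& x, const std::vector<std::size_t>& idx) {
    assert(!idx.empty());
    Row m(x[idx.front()].size(), 0.0);
    for (std::size_t i : idx)
        for (std::size_t j = 0; j < m.size(); ++j) m[j] += x[i][j];
    for (double& v : m) v /= static_cast<double>(idx.size());
    return m;
}

std::unique_ptr<Node> grow(const std::vector<Row>& x, const std::vector<int>& y,
                           const std::vector<std::size_t>& idx) {
    auto node = std::make_unique<Node>();
    std::map<int, std::size_t> counts;  // label -> observations at this node
    for (std::size_t i : idx) ++counts[y[i]];
    // std::max_element keeps the first maximum, i.e. the smallest tied label.
    node->label = std::max_element(counts.begin(), counts.end(),
                                   [](const auto& a, const auto& b) { return a.second < b.second; })
                      ->first;
    if (counts.size() == 1) return node;  // pure leaf: stop

    // Assumed split rule: smallest label vs. the rest, along the mean difference.
    const int first = counts.begin()->first;
    std::vector<std::size_t> a_idx, b_idx;
    for (std::size_t i : idx) (y[i] == first ? a_idx : b_idx).push_back(i);
    const Row ma = mean_of(x, a_idx), mb = mean_of(x, b_idx);
    Row dir(ma.size());
    for (std::size_t j = 0; j < dir.size(); ++j) dir[j] = mb[j] - ma[j];
    const double t = 0.5 * (project(dir, ma) + project(dir, mb));

    std::vector<std::size_t> lo, hi;
    for (std::size_t i : idx) (project(dir, x[i]) < t ? lo : hi).push_back(i);
    if (lo.empty() || hi.empty()) return node;  // split separates nothing: majority leaf

    node->projector = dir;
    node->threshold = t;
    node->lower = grow(x, y, lo);  // each child sees strictly fewer observations
    node->upper = grow(x, y, hi);
    return node;
}

std::unique_ptr<Node> train(const std::vector<Row>& x, const std::vector<int>& y) {
    assert(!x.empty() && x.size() == y.size());
    std::vector<std::size_t> all(x.size());
    std::iota(all.begin(), all.end(), std::size_t{0});
    return grow(x, y, all);
}

int predict_one(const Node& root, const Row& x) {
    const Node* n = &root;
    while (n->lower) n = project(n->projector, x) < n->threshold ? n->lower.get() : n->upper.get();
    return n->label;
}
```

On three well-separated groups in two dimensions this grows two internal nodes and three pure leaves, so every training observation is predicted correctly.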

◆ train() [2/2]

static Tree ppforest2::Tree::train ( TrainingSpec const & training_spec,
types::FeatureMatrix const & x,
types::ResponseVector const & y,
stats::RNG & rng )
static

Train a tree from a response vector.

Constructs a GroupPartition from y and delegates to the GroupPartition overload.

Parameters
training_spec: Training specification (strategy + DR).
x: Feature matrix (n × p).
y: Response vector (n).
rng: Random number generator.
Returns
Trained tree.
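A minimal sketch of the conversion step, assuming a partition can be represented as the list of observation indices for each label. The real `stats::GroupPartition` may store groups differently; `GroupIndex` and `partition_by_label` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

// Hypothetical stand-in for stats::GroupPartition: observation indices keyed by label.
using GroupIndex = std::map<int, std::vector<std::size_t>>;

GroupIndex partition_by_label(const std::vector<int>& y) {
    GroupIndex groups;
    for (std::size_t i = 0; i < y.size(); ++i) groups[y[i]].push_back(i);

    // Every observation lands in exactly one group.
    std::size_t total = 0;
    for (const auto& entry : groups) total += entry.second.size();
    assert(total == y.size());
    return groups;
}
```

With a partition in hand, the response-vector overload only needs to forward it, together with `training_spec`, `x` and `rng`, to the group-partition overload.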

Member Data Documentation

◆ root

TreeNode::Ptr ppforest2::Tree::root

Root node of the tree.


The documentation for this struct was generated from the following file: