ppforest2 v0.1.0
Projection Pursuit Decision Trees and Random Forests
Loading...
Searching...
No Matches
ClassificationTree.hpp
Go to the documentation of this file.
1#pragma once
2
3#include "models/Tree.hpp"
4
5#include <map>
6#include <set>
7
8namespace ppforest2 {
/// Projection-pursuit classification tree: a Tree that predicts over a fixed,
/// non-empty set of group labels.
///
// NOTE(review): this listing was extracted from a Doxygen page and several
// source lines were lost. Per the page's tooltip index they include the
// aliases FeatureMatrix (line 21), FeatureVector (22), OutcomeVector (23),
// Outcome (24) and GroupPartition (29), the data member `Groups groups` (41),
// and the constructor signature
// `ClassificationTree(TreeNode::Ptr root, TrainingSpec::Ptr spec, Groups groups)` (43).
// The four `predict` overloads listed in that index presumably occupy the
// bare lines 71, 82, 88 and 97 — confirm against the real header.
16 class ClassificationTree : public Tree {
17 public:
/// Owning handle to a tree, as returned by train().
18 using Ptr = std::unique_ptr<ClassificationTree>;
/// Re-expose Tree's predict overloads; without this using-declaration the
/// predict overloads declared in this class would hide them (C++ name hiding).
19 using Tree::predict;
20
/// Random number generator type (stats::RNG, a pcg32).
25 using RNG = stats::RNG;
26
/// Ordered set of the group labels this tree predicts over.
27 using Groups = std::set<types::GroupId>;
/// Map from group label to an int; the predict(..., GroupIndices const&)
/// overloads use it as an explicit one-hot column layout.
28 using GroupIndices = std::map<types::GroupId, int>;
30
42
/// Constructor (its signature is on the missing source line 43, see the note
/// above the class). It moves the root node and training spec into Tree,
/// moves the group labels into `groups`, then checks the class invariants.
/// invariant() throws when a condition fails.
// NOTE(review): is_classification is documented as
// `bool is_classification(Model const &model)`; passing `this` (a pointer)
// only compiles if a pointer overload exists — confirm, or call
// `is_classification(*this)`.
// NOTE(review): the training_spec null check runs after `spec` has already
// been handed to Tree's constructor; if Tree dereferences it, this check
// comes too late — confirm against Tree.hpp.
44 : Tree(std::move(root), std::move(spec))
45 , groups(std::move(groups)) {
46 invariant(this->training_spec != nullptr, "ClassificationTree requires a non-null TrainingSpec");
47 invariant(is_classification(this), "ClassificationTree requires a classification TrainingSpec");
48 invariant(!this->groups.empty(), "ClassificationTree requires a non-empty groups set");
49 }
50
51
/// Train a classification tree using a caller-supplied RNG.
/// @param s      training specification
/// @param x      feature matrix; taken by non-const reference, so presumably
///               it may be modified during training (confirm)
/// @param y      outcome vector; same caveat as x
/// @param y_part observations grouped by label (contiguous-block GroupPartition)
/// @param rng    random number generator; taken by non-const reference, so
///               presumably advanced by training
/// @return       owning pointer to the trained tree
64 static Ptr train(TrainingSpec const& s, FeatureMatrix& x, OutcomeVector& y, GroupPartition const& y_part, RNG& rng);
65
71
82
88
97
/// Model-visitor hook; overrides the Model declaration (see Model::Visitor).
98 void accept(Model::Visitor& visitor) const override;
99 };
100}
void invariant(bool condition, char const *message)
Runtime assertion that throws on failure.
types::OutcomeVector OutcomeVector
Definition ClassificationTree.hpp:23
void accept(Model::Visitor &visitor) const override
Accept a model visitor (mode-specific dispatch).
stats::GroupPartition GroupPartition
Definition ClassificationTree.hpp:29
types::Outcome Outcome
Definition ClassificationTree.hpp:24
std::unique_ptr< ClassificationTree > Ptr
Definition ClassificationTree.hpp:18
std::set< types::GroupId > Groups
Definition ClassificationTree.hpp:27
Groups groups
Set of group labels this tree predicts over.
Definition ClassificationTree.hpp:41
stats::RNG RNG
Definition ClassificationTree.hpp:25
std::map< types::GroupId, int > GroupIndices
Definition ClassificationTree.hpp:28
FeatureMatrix predict(FeatureMatrix const &x, Proportions, GroupIndices const &indices) const
One-hot encoding per row, with an explicit column layout.
FeatureVector predict(FeatureVector const &x, Proportions, GroupIndices const &indices) const
One-hot encoding for one observation, with an explicit column layout passed as a precomputed {group → column index} map.
static Ptr train(TrainingSpec const &s, FeatureMatrix &x, OutcomeVector &y, GroupPartition const &y_part, RNG &rng)
Train a classification tree with an external RNG.
FeatureMatrix predict(FeatureMatrix const &x, Proportions) const
One-hot encoding of the predicted group per row, with columns laid out in the order of `groups`.
ClassificationTree(TreeNode::Ptr root, TrainingSpec::Ptr spec, Groups groups)
Definition ClassificationTree.hpp:43
types::FeatureVector predict(types::FeatureVector const &x, Proportions) const
One-hot encoding of the predicted group for one observation, with columns laid out in the order of `groups`.
types::FeatureVector FeatureVector
Definition ClassificationTree.hpp:22
types::FeatureMatrix FeatureMatrix
Definition ClassificationTree.hpp:21
Visitor interface for model dispatch.
Definition Model.hpp:51
TrainingSpec::Ptr training_spec
Training specification used to build this model.
Definition Model.hpp:70
Training configuration for projection pursuit trees and forests.
Definition TrainingSpec.hpp:43
std::shared_ptr< TrainingSpec > Ptr
Definition TrainingSpec.hpp:45
std::unique_ptr< TreeNode > Ptr
Definition TreeNode.hpp:21
Root root
Root node of the tree.
Definition Tree.hpp:45
Tree(TreeNode::Ptr root, TrainingSpec::Ptr spec)
Definition Tree.hpp:93
types::Outcome predict(types::FeatureVector const &x) const override
Predict a single observation.
Contiguous-block representation of grouped observations.
Definition GroupPartition.hpp:40
pcg32 RNG
Definition Stats.hpp:24
Eigen::Matrix< Feature, Eigen::Dynamic, Eigen::Dynamic > FeatureMatrix
Dynamic-size matrix of feature values.
Definition Types.hpp:33
Eigen::Matrix< Outcome, Eigen::Dynamic, 1 > OutcomeVector
Dynamic-size column vector of predictions.
Definition Types.hpp:42
Eigen::Matrix< Feature, Eigen::Dynamic, 1 > FeatureVector
Dynamic-size column vector of feature values.
Definition Types.hpp:36
Feature Outcome
Scalar type for predictions (float for both classification and regression).
Definition Types.hpp:30
Binarization strategies for multiclass-to-binary reduction.
Definition Benchmark.hpp:25
bool is_classification(Model const &model)
Whether the model was trained for classification.
Definition Model.hpp:145
Tag type for requesting vote-proportion predictions.
Definition Model.hpp:23