ppforest2 v0.1.0
Projection Pursuit Decision Trees and Random Forests
Loading...
Searching...
No Matches
Model.hpp
Go to the documentation of this file.
1#pragma once
2
4#include "utils/Types.hpp"
5
6#include <memory>
7
8namespace ppforest2 {
// Forward declarations: Model::Visitor below refers to these types only by
// const reference, so their full definitions are not needed here.
// NOTE(review): ClassificationTree / ClassificationForest are also used by
// Model::Visitor, but their declarations are not visible in this listing
// (original lines 11-12 are hidden) — confirm they are declared.
9 class Tree;
10 class Forest;
13 class RegressionTree;
14 class RegressionForest;
15
// Empty tag type used to request vote-proportion predictions, e.g. the
// classification-only `predict(data, Proportions{})` overload.
23 struct Proportions {};
24
25
// Abstract base class for predictive models (trees and forests); concrete
// models must implement accept() and the single-observation predict().
// NOTE(review): this is a Doxygen source listing, not the raw header. Several
// original lines are not shown here: the `training_spec` member
// (Model.hpp:70), the signature line of the matrix predict() overload
// (Model.hpp:92), and presumably the static train() / check_train_inputs()
// declarations listed in the member tooltips.
29 class Model {
30 public:
// Shared-ownership handle to a model.
31 using Ptr = std::shared_ptr<Model>;
32
34 * @brief Visitor interface for model dispatch.
35 *
36 * Two layers of dispatch:
37 * - `visit(Tree)` / `visit(Forest)` — bimodal (classification or
38 * regression). Default to no-op so visitors can override only
39 * the cases they care about.
40 * - `visit(ClassificationTree)` / `visit(RegressionTree)` /
41 * `visit(ClassificationForest)` / `visit(RegressionForest)` —
42 * mode-specific. Default implementations delegate to the
43 * bimodal overload, so a visitor that overrides only `visit(Tree)`
44 * still receives both classification and regression trees.
45 *
46 * Visitors that don't care about mode override only the bimodal
47 * pair. Visitors that need to call mode-specific API (e.g., the
48 * classification-only `predict(data, Proportions{})`) override the
49 * relevant mode-specific overload(s).
50 */
51 class Visitor {
52 public:
53 virtual ~Visitor() = default;
54
// Bimodal overloads: no-ops by default (empty bodies).
// Caution (C++ name hiding): a derived visitor that overrides only some
// `visit` overloads hides the remaining ones for calls made through the
// derived type. Add `using Model::Visitor::visit;` in the derived class
// if it is ever invoked directly rather than through `Model::Visitor&`.
55 virtual void visit(Tree const&) {}
56 virtual void visit(Forest const&) {}
57
// Mode-specific overloads. Their bodies are not in this listing; per the
// comment above, the defaults forward to the bimodal overloads.
58 virtual void visit(ClassificationTree const& tree);
59 virtual void visit(ClassificationForest const& forest);
60 virtual void visit(RegressionTree const& tree);
61 virtual void visit(RegressionForest const& forest);
62 };
63
64 virtual ~Model() = default;
65
// Whether the model contains degenerate nodes/splits; false by default.
67 bool degenerate = false;
68
71
// Double-dispatch hook implemented by each concrete model; presumably it
// calls the visitor.visit overload matching the model's dynamic type.
73 virtual void accept(Visitor& visitor) const = 0;
74
// Predict the outcome for one observation given as a feature vector.
81 virtual types::Outcome predict(types::FeatureVector const& x) const = 0;
82
// Body of the matrix overload `predict(types::FeatureMatrix const& x)`
// (its signature line is not shown in this listing): one prediction per
// row of x, collected into an OutcomeVector.
93 types::OutcomeVector result(x.rows());
94
95 for (Eigen::Index i = 0; i < x.rows(); ++i) {
// Copy row i into a FeatureVector (a column vector per Types.hpp) and
// predict it with the single-observation overload.
96 result(i) = predict(static_cast<types::FeatureVector>(x.row(i)));
97 }
98
99 return result;
100 }
101
114
123 };
124
135
138
145 inline bool is_classification(Model const& model) {
146 return is_classification(model.training_spec);
147 }
148
155 inline bool is_regression(Model const& model) {
156 return is_regression(model.training_spec);
157 }
158
160 inline bool is_classification(Model const* model) {
161 return model != nullptr && is_classification(*model);
162 }
163
165 inline bool is_regression(Model const* model) {
166 return model != nullptr && is_regression(*model);
167 }
168}
Random forest of classification trees.
Definition ClassificationForest.hpp:18
A projection pursuit decision tree for classification.
Definition ClassificationTree.hpp:16
Abstract base class for projection pursuit random forests.
Definition Forest.hpp:31
Visitor interface for model dispatch.
Definition Model.hpp:51
virtual void visit(Tree const &)
Definition Model.hpp:55
virtual ~Visitor()=default
virtual void visit(RegressionForest const &forest)
virtual void visit(RegressionTree const &tree)
virtual void visit(Forest const &)
Definition Model.hpp:56
virtual void visit(ClassificationTree const &tree)
virtual void visit(ClassificationForest const &forest)
Abstract base class for predictive models (trees and forests).
Definition Model.hpp:29
virtual types::OutcomeVector predict(types::FeatureMatrix const &x) const
Predict a matrix of observations.
Definition Model.hpp:92
std::shared_ptr< Model > Ptr
Definition Model.hpp:31
bool degenerate
Whether the model contains degenerate nodes/splits.
Definition Model.hpp:67
static Ptr train(TrainingSpec const &spec, types::FeatureMatrix &x, types::OutcomeVector &y)
Train a model from a training specification.
virtual types::Outcome predict(types::FeatureVector const &x) const =0
Predict a single observation.
static void check_train_inputs(types::FeatureMatrix const &x, types::OutcomeVector const &y)
Validate common training inputs (y non-empty, matching x rows).
TrainingSpec::Ptr training_spec
Training specification used to build this model.
Definition Model.hpp:70
virtual void accept(Visitor &visitor) const =0
Accept a model visitor (double dispatch).
virtual ~Model()=default
Random forest of regression trees.
Definition RegressionForest.hpp:13
A projection pursuit decision tree for regression.
Definition RegressionTree.hpp:15
Training configuration for projection pursuit trees and forests.
Definition TrainingSpec.hpp:43
std::shared_ptr< TrainingSpec > Ptr
Definition TrainingSpec.hpp:45
Abstract base class for projection pursuit decision trees.
Definition Tree.hpp:29
Eigen::Matrix< Feature, Eigen::Dynamic, Eigen::Dynamic > FeatureMatrix
Dynamic-size matrix of feature values.
Definition Types.hpp:33
Eigen::Matrix< Outcome, Eigen::Dynamic, 1 > OutcomeVector
Dynamic-size column vector of predictions.
Definition Types.hpp:42
Eigen::Matrix< Feature, Eigen::Dynamic, 1 > FeatureVector
Dynamic-size column vector of feature values.
Definition Types.hpp:36
Feature Outcome
Scalar type for predictions (float for both classification and regression).
Definition Types.hpp:30
Binarization strategies for multiclass-to-binary reduction.
Definition Benchmark.hpp:25
bool is_classification(Model const &model)
Whether model was trained for classification.
Definition Model.hpp:145
types::FeatureMatrix predict_proportions(Model const &model, types::FeatureMatrix const &x)
Compute vote proportions for a classification model.
bool is_regression(Model const &model)
Whether model was trained for regression.
Definition Model.hpp:155
Tag type for requesting vote-proportion predictions.
Definition Model.hpp:23