ppforest2 v0.1.0
Projection Pursuit Decision Trees and Random Forests
Loading...
Searching...
No Matches
TrainingSpec.hpp
Go to the documentation of this file.
1#pragma once
2
10
11#include <memory>
12#include <thread>
13#include <nlohmann/json.hpp>
14
15namespace ppforest2 {
44 public:
45 using Ptr = std::shared_ptr<TrainingSpec>;
46
61
64
66 int const size;
68 int const seed;
70 int const threads;
72 int const max_retries;
73
102 class Builder {
103 public:
125
128
130 : mode(mode) {}
131
133 config.pp = std::move(v);
134 return *this;
135 }
137 config.vars = std::move(v);
138 return *this;
139 }
141 config.cutpoint = std::move(v);
142 return *this;
143 }
145 config.stop = std::move(v);
146 return *this;
147 }
149 config.binarization = std::move(v);
150 return *this;
151 }
153 config.grouping = std::move(v);
154 return *this;
155 }
157 config.leaf = std::move(v);
158 return *this;
159 }
160
161 Builder& size(int v) {
162 config.size = v;
163 return *this;
164 }
165 Builder& seed(int v) {
166 config.seed = v;
167 return *this;
168 }
169 Builder& threads(int v) {
170 config.threads = v;
171 return *this;
172 }
174 config.max_retries = v;
175 return *this;
176 }
177
207
215
218 };
219
228
253 int size,
254 int seed,
255 int threads,
256 int max_retries
257 );
258
259 // -- Forwarding methods (delegate to the underlying strategy) -----------
260
262 void find_projection(NodeContext& ctx, stats::RNG& rng) const;
263
265 void select_vars(NodeContext& ctx, stats::RNG& rng) const;
266
268 void find_cutpoint(NodeContext& ctx, stats::RNG& rng) const;
269
277 bool should_stop(NodeContext const& ctx, stats::RNG& rng) const;
278
280 void regroup(NodeContext& ctx, stats::RNG& rng) const;
281
289 void group(NodeContext& ctx, stats::RNG& rng) const;
290
293
295 TreeNode::Ptr create_leaf(NodeContext const& ctx, stats::RNG& rng) const { return leaf->create_leaf(ctx, rng); }
296
298 bool is_forest() const { return size > 0; }
299
301 nlohmann::json to_json() const;
302
304 static Ptr from_json(nlohmann::json const& j);
305
307 template<typename... Args> static Ptr make(Args&&... args) {
308 return std::make_shared<TrainingSpec>(std::forward<Args>(args)...);
309 }
310
319 int resolve_threads() const {
320 return threads > 0 ? threads : static_cast<int>(std::thread::hardware_concurrency());
321 }
322 };
323
325 inline bool is_classification(TrainingSpec const& spec) {
326 return types::is_classification(spec.mode);
327 }
328
330 inline bool is_regression(TrainingSpec const& spec) {
331 return types::is_regression(spec.mode);
332 }
333
335 inline bool is_classification(TrainingSpec::Ptr const& spec) {
336 return spec != nullptr && is_classification(*spec);
337 }
338
340 inline bool is_regression(TrainingSpec::Ptr const& spec) {
341 return spec != nullptr && is_regression(*spec);
342 }
343}
std::shared_ptr< ProjectionPursuit > Ptr
Definition Strategy.hpp:95
Fluent builder for TrainingSpec.
Definition TrainingSpec.hpp:102
Builder & leaf(leaf::LeafStrategy::Ptr v)
Definition TrainingSpec.hpp:156
Builder & seed(int v)
Definition TrainingSpec.hpp:165
Builder & cutpoint(cutpoint::Cutpoint::Ptr v)
Definition TrainingSpec.hpp:140
Builder & threads(int v)
Definition TrainingSpec.hpp:169
Builder & binarization(binarize::Binarization::Ptr v)
Definition TrainingSpec.hpp:148
Builder & max_retries(int v)
Definition TrainingSpec.hpp:173
Builder & size(int v)
Definition TrainingSpec.hpp:161
Ptr make()
Shorthand for std::make_shared<TrainingSpec>(build()).
Builder(types::Mode mode)
Definition TrainingSpec.hpp:129
types::Mode const mode
Definition TrainingSpec.hpp:127
Builder & pp(pp::ProjectionPursuit::Ptr v)
Definition TrainingSpec.hpp:132
TrainingSpec build()
Finalize the builder into a TrainingSpec.
Builder & vars(vars::VariableSelection::Ptr v)
Definition TrainingSpec.hpp:136
Builder & apply_defaults()
Fill in any null strategy fields with mode-aware defaults.
Config config
Definition TrainingSpec.hpp:126
Builder & stop(stop::StopRule::Ptr v)
Definition TrainingSpec.hpp:144
Builder & grouping(grouping::Grouping::Ptr v)
Definition TrainingSpec.hpp:152
Training configuration for projection pursuit trees and forests.
Definition TrainingSpec.hpp:43
TreeNode::Ptr create_leaf(NodeContext const &ctx, stats::RNG &rng) const
Create a leaf node from the current node context.
Definition TrainingSpec.hpp:295
grouping::Grouping::Ptr const grouping
Grouping strategy.
Definition TrainingSpec.hpp:58
static Builder builder(types::Mode mode)
Create a builder for the given mode.
Definition TrainingSpec.hpp:227
static Ptr make(Args &&... args)
Create a shared pointer to a TrainingSpec.
Definition TrainingSpec.hpp:307
TrainingSpec(pp::ProjectionPursuit::Ptr pp, vars::VariableSelection::Ptr vars, cutpoint::Cutpoint::Ptr cutpoint, stop::StopRule::Ptr stop, binarize::Binarization::Ptr binarization, grouping::Grouping::Ptr grouping, leaf::LeafStrategy::Ptr leaf, types::Mode mode, int size, int seed, int threads, int max_retries)
Construct a training specification.
void group(NodeContext &ctx, stats::RNG &rng) const
Split observations into two child partitions.
binarize::Binarization::Ptr const binarization
Binarization strategy.
Definition TrainingSpec.hpp:56
leaf::LeafStrategy::Ptr const leaf
Leaf creation strategy.
Definition TrainingSpec.hpp:60
pp::ProjectionPursuit::Ptr const pp
Projection pursuit optimization strategy.
Definition TrainingSpec.hpp:48
int resolve_threads() const
Get the number of threads to use for training.
Definition TrainingSpec.hpp:319
void find_projection(NodeContext &ctx, stats::RNG &rng) const
Run projection pursuit optimization. Asserts postcondition: ctx.projector and ctx....
nlohmann::json to_json() const
Serialize the training spec to JSON.
void find_cutpoint(NodeContext &ctx, stats::RNG &rng) const
Compute the split cutpoint. Asserts postcondition: ctx.cutpoint is set.
bool is_forest() const
Whether this specification describes a forest (size > 0).
Definition TrainingSpec.hpp:298
cutpoint::Cutpoint::Ptr const cutpoint
Split cutpoint strategy.
Definition TrainingSpec.hpp:52
stop::StopRule::Ptr const stop
Stop rule strategy.
Definition TrainingSpec.hpp:54
bool should_stop(NodeContext const &ctx, stats::RNG &rng) const
Check whether the node should stop growing.
vars::VariableSelection::Ptr const vars
Variable selection strategy.
Definition TrainingSpec.hpp:50
int const max_retries
Maximum retry attempts for degenerate trees.
Definition TrainingSpec.hpp:72
int const size
Number of trees (0 = single tree).
Definition TrainingSpec.hpp:66
void regroup(NodeContext &ctx, stats::RNG &rng) const
Reduce multiclass partition to binary. Asserts postcondition: ctx.y_bin is set.
int const seed
RNG seed.
Definition TrainingSpec.hpp:68
void select_vars(NodeContext &ctx, stats::RNG &rng) const
Run variable selection. Asserts postcondition: ctx.var_selection is set.
static Ptr from_json(nlohmann::json const &j)
Deserialize a training spec from JSON.
types::Mode const mode
Training mode (classification or regression).
Definition TrainingSpec.hpp:63
int const threads
Number of threads for parallel forest training.
Definition TrainingSpec.hpp:70
stats::GroupPartition init_groups(types::OutcomeVector const &y) const
Create the initial group partition from the training response.
Definition TrainingSpec.hpp:292
std::shared_ptr< TrainingSpec > Ptr
Definition TrainingSpec.hpp:45
std::unique_ptr< TreeNode > Ptr
Definition TreeNode.hpp:21
Contiguous-block representation of grouped observations.
Definition GroupPartition.hpp:40
pcg32 RNG
Definition Stats.hpp:24
bool is_classification(Mode mode)
Whether mode is Classification.
Definition Types.hpp:61
Eigen::Matrix< Outcome, Eigen::Dynamic, 1 > OutcomeVector
Dynamic-size column vector of predictions.
Definition Types.hpp:42
bool is_regression(Mode mode)
Whether mode is Regression.
Definition Types.hpp:66
Mode
Training mode.
Definition Types.hpp:58
Binarization strategies for multiclass-to-binary reduction.
Definition Benchmark.hpp:25
bool is_classification(Model const &model)
Whether model was trained for classification.
Definition Model.hpp:145
bool is_regression(Model const &model)
Whether model was trained for regression.
Definition Model.hpp:155
Mutable context accumulating intermediate results during node training.
Definition NodeContext.hpp:20
Builder state — the configuration being assembled.
Definition TrainingSpec.hpp:111
leaf::LeafStrategy::Ptr leaf
Definition TrainingSpec.hpp:118
int max_retries
Definition TrainingSpec.hpp:123
int size
Definition TrainingSpec.hpp:120
cutpoint::Cutpoint::Ptr cutpoint
Definition TrainingSpec.hpp:114
int seed
Definition TrainingSpec.hpp:121
grouping::Grouping::Ptr grouping
Definition TrainingSpec.hpp:117
int threads
Definition TrainingSpec.hpp:122
vars::VariableSelection::Ptr vars
Definition TrainingSpec.hpp:113
pp::ProjectionPursuit::Ptr pp
Definition TrainingSpec.hpp:112
stop::StopRule::Ptr stop
Definition TrainingSpec.hpp:115
binarize::Binarization::Ptr binarization
Definition TrainingSpec.hpp:116