13#include <nlohmann/json.hpp>
45 using Ptr = std::shared_ptr<TrainingSpec>;
137 config.vars = std::move(v);
141 config.cutpoint = std::move(v);
145 config.stop = std::move(v);
149 config.binarization = std::move(v);
153 config.grouping = std::move(v);
157 config.leaf = std::move(v);
307 template<
typename... Args>
static Ptr make(Args&&... args) {
308 return std::make_shared<TrainingSpec>(std::forward<Args>(args)...);
320 return threads > 0 ?
threads :
static_cast<int>(std::thread::hardware_concurrency());
std::shared_ptr< ProjectionPursuit > Ptr
Definition Strategy.hpp:95
Fluent builder for TrainingSpec.
Definition TrainingSpec.hpp:102
Builder & leaf(leaf::LeafStrategy::Ptr v)
Definition TrainingSpec.hpp:156
Builder & seed(int v)
Definition TrainingSpec.hpp:165
Builder & cutpoint(cutpoint::Cutpoint::Ptr v)
Definition TrainingSpec.hpp:140
Builder & threads(int v)
Definition TrainingSpec.hpp:169
Builder & binarization(binarize::Binarization::Ptr v)
Definition TrainingSpec.hpp:148
Builder & max_retries(int v)
Definition TrainingSpec.hpp:173
Builder & size(int v)
Definition TrainingSpec.hpp:161
Ptr make()
Shorthand for std::make_shared<TrainingSpec>(build()).
Builder(types::Mode mode)
Definition TrainingSpec.hpp:129
types::Mode const mode
Definition TrainingSpec.hpp:127
Builder & pp(pp::ProjectionPursuit::Ptr v)
Definition TrainingSpec.hpp:132
TrainingSpec build()
Finalize the builder into a TrainingSpec.
Builder & vars(vars::VariableSelection::Ptr v)
Definition TrainingSpec.hpp:136
Builder & apply_defaults()
Fill in any null strategy fields with mode-aware defaults.
Config config
Definition TrainingSpec.hpp:126
Builder & stop(stop::StopRule::Ptr v)
Definition TrainingSpec.hpp:144
Builder & grouping(grouping::Grouping::Ptr v)
Definition TrainingSpec.hpp:152
Training configuration for projection pursuit trees and forests.
Definition TrainingSpec.hpp:43
TreeNode::Ptr create_leaf(NodeContext const &ctx, stats::RNG &rng) const
Create a leaf node from the current node context.
Definition TrainingSpec.hpp:295
grouping::Grouping::Ptr const grouping
Grouping strategy.
Definition TrainingSpec.hpp:58
static Builder builder(types::Mode mode)
Create a builder for the given mode.
Definition TrainingSpec.hpp:227
static Ptr make(Args &&... args)
Create a shared pointer to a TrainingSpec.
Definition TrainingSpec.hpp:307
TrainingSpec(pp::ProjectionPursuit::Ptr pp, vars::VariableSelection::Ptr vars, cutpoint::Cutpoint::Ptr cutpoint, stop::StopRule::Ptr stop, binarize::Binarization::Ptr binarization, grouping::Grouping::Ptr grouping, leaf::LeafStrategy::Ptr leaf, types::Mode mode, int size, int seed, int threads, int max_retries)
Construct a training specification.
void group(NodeContext &ctx, stats::RNG &rng) const
Split observations into two child partitions.
binarize::Binarization::Ptr const binarization
Binarization strategy.
Definition TrainingSpec.hpp:56
leaf::LeafStrategy::Ptr const leaf
Leaf creation strategy.
Definition TrainingSpec.hpp:60
pp::ProjectionPursuit::Ptr const pp
Projection pursuit optimization strategy.
Definition TrainingSpec.hpp:48
int resolve_threads() const
Get the number of threads to use for training.
Definition TrainingSpec.hpp:319
void find_projection(NodeContext &ctx, stats::RNG &rng) const
Run projection pursuit optimization. Asserts postcondition: ctx.projector and ctx....
nlohmann::json to_json() const
Serialize the training spec to JSON.
void find_cutpoint(NodeContext &ctx, stats::RNG &rng) const
Compute the split cutpoint. Asserts postcondition: ctx.cutpoint is set.
bool is_forest() const
Whether this specification describes a forest (size > 0).
Definition TrainingSpec.hpp:298
cutpoint::Cutpoint::Ptr const cutpoint
Split cutpoint strategy.
Definition TrainingSpec.hpp:52
stop::StopRule::Ptr const stop
Stop rule strategy.
Definition TrainingSpec.hpp:54
bool should_stop(NodeContext const &ctx, stats::RNG &rng) const
Check whether the node should stop growing.
vars::VariableSelection::Ptr const vars
Variable selection strategy.
Definition TrainingSpec.hpp:50
int const max_retries
Maximum retry attempts for degenerate trees.
Definition TrainingSpec.hpp:72
int const size
Number of trees (0 = single tree).
Definition TrainingSpec.hpp:66
void regroup(NodeContext &ctx, stats::RNG &rng) const
Reduce multiclass partition to binary. Asserts postcondition: ctx.y_bin is set.
int const seed
RNG seed.
Definition TrainingSpec.hpp:68
void select_vars(NodeContext &ctx, stats::RNG &rng) const
Run variable selection. Asserts postcondition: ctx.var_selection is set.
static Ptr from_json(nlohmann::json const &j)
Deserialize a training spec from JSON.
types::Mode const mode
Training mode (classification or regression).
Definition TrainingSpec.hpp:63
int const threads
Number of threads for parallel forest training.
Definition TrainingSpec.hpp:70
stats::GroupPartition init_groups(types::OutcomeVector const &y) const
Create the initial group partition from the training response.
Definition TrainingSpec.hpp:292
std::shared_ptr< TrainingSpec > Ptr
Definition TrainingSpec.hpp:45
std::unique_ptr< TreeNode > Ptr
Definition TreeNode.hpp:21
Contiguous-block representation of grouped observations.
Definition GroupPartition.hpp:40
pcg32 RNG
Definition Stats.hpp:24
bool is_classification(Mode mode)
Whether mode is Classification.
Definition Types.hpp:61
Eigen::Matrix< Outcome, Eigen::Dynamic, 1 > OutcomeVector
Dynamic-size column vector of predictions.
Definition Types.hpp:42
bool is_regression(Mode mode)
Whether mode is Regression.
Definition Types.hpp:66
Mode
Training mode.
Definition Types.hpp:58
Binarization strategies for multiclass-to-binary reduction.
Definition Benchmark.hpp:25
bool is_classification(Model const &model)
Whether model was trained for classification.
Definition Model.hpp:145
bool is_regression(Model const &model)
Whether model was trained for regression.
Definition Model.hpp:155
Mutable context accumulating intermediate results during node training.
Definition NodeContext.hpp:20
Builder state — the configuration being assembled.
Definition TrainingSpec.hpp:111
leaf::LeafStrategy::Ptr leaf
Definition TrainingSpec.hpp:118
int max_retries
Definition TrainingSpec.hpp:123
int size
Definition TrainingSpec.hpp:120
cutpoint::Cutpoint::Ptr cutpoint
Definition TrainingSpec.hpp:114
int seed
Definition TrainingSpec.hpp:121
grouping::Grouping::Ptr grouping
Definition TrainingSpec.hpp:117
int threads
Definition TrainingSpec.hpp:122
vars::VariableSelection::Ptr vars
Definition TrainingSpec.hpp:113
pp::ProjectionPursuit::Ptr pp
Definition TrainingSpec.hpp:112
stop::StopRule::Ptr stop
Definition TrainingSpec.hpp:115
binarize::Binarization::Ptr binarization
Definition TrainingSpec.hpp:116