ppforest2 v0.1.0
Projection Pursuit Decision Trees and Random Forests
Loading...
Searching...
No Matches
GroupPartition.hpp
Go to the documentation of this file.
1#pragma once
2
3
4#include "utils/Types.hpp"
5#include "utils/Invariant.hpp"
6
7#include <map>
8#include <optional>
9#include <set>
10#include <vector>
11#include <Eigen/Dense>
12
13namespace ppforest2::stats {
41 using Group = types::GroupId; // A single group label.
42 using GroupSet = std::set<types::GroupId>; // Ordered set of group labels.
43 using GroupMap = std::map<types::GroupId, types::GroupId>; // Group label -> group label mapping.
44 using GroupInvMap = std::map<types::GroupId, GroupSet>; // Group label -> set of group labels.
45 using GroupVector = types::GroupIdVector; // Column vector of group labels.
46
47 public:
49 static bool is_contiguous(GroupVector const& y); // Check whether all equal values in y form a single contiguous block.
50
57
66
72 GroupPartition(int start, int end); // Construct a single-group partition covering rows [start, end].
73
85 GroupPartition bisect(int mid) const; // Bisect a single-group partition at row index mid into two groups.
86
88 int group_start(Group const& group) const; // First row index of the block for group.
90 int group_end(Group const& group) const; // Last row index (inclusive) of the block for group.
92 int group_size(Group const& group) const; // Number of observations in group.
93
101 Group first_group() const {
102 invariant(!groups.empty(), "GroupPartition::first_group: partition is empty");
103 return *groups.begin();
104 }
105
114 int total_size() const; // Total number of observations across all groups in the partition.
115
127 template<typename Derived> auto group(Eigen::MatrixBase<Derived> const& x, Group const& group) const {
128 std::vector<int> indices;
129
130 auto const& subs = this->subgroups.at(group);
131
132 for (auto const& g : subs) {
133 for (int i = group_start(g); i <= group_end(g); ++i) {
134 invariant(i >= 0 && i < x.rows(), "GroupPartition::group: index out of bounds");
135 indices.push_back(i);
136 }
137 }
138
139 return x(indices, Eigen::all);
140 }
141
148 template<typename Derived> auto data(Eigen::MatrixBase<Derived> const& x) const {
149 std::vector<int> indices;
150
151 for (auto const& kv : blocks) {
152 auto const& g = kv.first;
153 for (int i = group_start(g); i <= group_end(g); ++i) {
154 indices.push_back(i);
155 }
156 }
157
158 return x(indices, Eigen::all);
159 }
160
167
174 GroupPartition subset(GroupSet const& groups) const; // Create a partition containing only the given groups.
175
176 using SplitSizes = std::map<types::GroupId, int>; // Group label -> integer size; the parameter type of split().
177
192 std::pair<GroupPartition, GroupPartition> split(SplitSizes const& left_sizes) const; // Split each group's block into left and right children.
193
200 GroupPartition remap(GroupMap const& mapping) const; // Merge groups according to a mapping.
201
208
210 GroupSet const groups; // Set of all group labels in this partition.
212 GroupMap const supergroups; // Maps each group to its supergroup (identity if no merge).
214 GroupInvMap const subgroups; // Maps each group to its set of subgroups.
215
216 private:
217 struct Block { // Per-group record; BlockMap stores one Block per group id.
218 int start; // NOTE(review): presumably the first row index of the group's block — confirm against init_blocks.
219 int end; // NOTE(review): presumably the last row index of the block (inclusive, as group_end() is) — confirm.
220 int size; // NOTE(review): presumably the number of rows in the block — confirm.
221 std::optional<types::GroupId> next; // NOTE(review): looks like a link to the following group, empty when there is none — confirm.
222 std::optional<types::GroupId> prev; // NOTE(review): looks like a link to the preceding group, empty when there is none — confirm.
223 };
224
225 using BlockMap = std::map<types::GroupId, Block>; // Group id -> Block, ordered by group id.
226 BlockMap const blocks; // One entry per group; data() iterates it to gather every group's rows.
227
228 static BlockMap init_blocks(GroupVector const& y); // NOTE(review): presumably builds the per-group blocks from a response vector — confirm in the implementation file.
229 static GroupMap init_supergroups(GroupSet const& groups); // NOTE(review): presumably builds the initial group -> supergroup mapping for the given labels — confirm.
230
231 explicit GroupPartition(BlockMap const& blocks); // Private: construct from prebuilt blocks. NOTE(review): how groups/supergroups are derived is not visible here — confirm.
232
233 GroupPartition(BlockMap const& blocks, GroupSet const& groups); // Private: construct from prebuilt blocks and an explicit group set.
234
235 GroupPartition(BlockMap const& blocks, GroupMap const& supergroups); // Private: construct from prebuilt blocks and an explicit supergroup mapping.
236
237 GroupPartition(BlockMap const& blocks, GroupSet const& groups, GroupMap const& supergroups); // Private: construct from prebuilt blocks, group set and supergroup mapping.
238 };
239}
void invariant(bool condition, char const *message)
Runtime assertion that throws on failure.
GroupPartition remap(GroupMap const &mapping) const
Merge groups according to a mapping.
GroupPartition(types::GroupIdVector const &y)
Construct from a sorted response vector.
GroupMap const supergroups
Maps each group to its supergroup (identity if no merge).
Definition GroupPartition.hpp:212
GroupPartition bisect(int mid) const
Bisect a single-group partition at row index mid into two groups.
Group first_group() const
Smallest group label in the partition.
Definition GroupPartition.hpp:101
types::FeatureMatrix wgss(types::FeatureMatrix const &x) const
Within-group sum of squares matrix (p × p).
types::FeatureVector mean(types::FeatureMatrix const &x) const
Overall mean of all grouped rows (p).
GroupSet const groups
Set of all group labels in this partition.
Definition GroupPartition.hpp:210
types::FeatureMatrix bgss(types::FeatureMatrix const &x) const
Between-group sum of squares matrix (p × p).
auto data(Eigen::MatrixBase< Derived > const &x) const
Extract all rows across all groups.
Definition GroupPartition.hpp:148
GroupPartition(int start, int end)
Construct a single-group partition covering rows [start, end].
int group_end(Group const &group) const
Last row index (inclusive) of the block for group.
int total_size() const
Total number of observations across all groups in the partition.
GroupPartition subset(GroupSet const &groups) const
Create a partition containing only the given groups.
GroupPartition(types::OutcomeVector const &y)
Construct from a float-typed response vector.
static bool is_contiguous(GroupVector const &y)
Check whether all equal values in y form a single contiguous block.
std::map< types::GroupId, int > SplitSizes
Definition GroupPartition.hpp:176
int group_start(Group const &group) const
First row index of the block for group.
int group_size(Group const &group) const
Number of observations in group.
GroupPartition collapse() const
Collapse all groups into a single supergroup.
std::pair< GroupPartition, GroupPartition > split(SplitSizes const &left_sizes) const
Split each group's block into left and right children.
auto group(Eigen::MatrixBase< Derived > const &x, Group const &group) const
Extract rows belonging to a group (or supergroup).
Definition GroupPartition.hpp:127
GroupInvMap const subgroups
Maps each group to its set of subgroups.
Definition GroupPartition.hpp:214
Statistical infrastructure for training and evaluation.
Definition ConfusionMatrix.hpp:11
Eigen::Matrix< Feature, Eigen::Dynamic, Eigen::Dynamic > FeatureMatrix
Dynamic-size matrix of feature values.
Definition Types.hpp:33
Eigen::Matrix< Outcome, Eigen::Dynamic, 1 > OutcomeVector
Dynamic-size column vector of predictions.
Definition Types.hpp:42
Eigen::Matrix< GroupId, Eigen::Dynamic, 1 > GroupIdVector
Dynamic-size column vector of internal group labels.
Definition Types.hpp:39
Eigen::Matrix< Feature, Eigen::Dynamic, 1 > FeatureVector
Dynamic-size column vector of feature values.
Definition Types.hpp:36
int GroupId
Scalar type for internal group labels (integer). Used as map keys, set elements, and partition indices...
Definition Types.hpp:27