Commit dd82d05d authored by davidkep

fix bugs in optimizers

parent 0cef7b8b
@@ -62,9 +62,12 @@ struct DataCache {
 //!
 //! @param x matrix of predictor values.
 //! @param residuals vector of residuals.
+//! @param lambda the lambda value to use (overrides the lambda in the EN penalty).
 //! @param penalty elastic net penalty object.
-inline bool AnyViolateKKT(const arma::mat& x, const arma::vec& residuals, const EnPenalty& penalty) {
-  const double lambda_1 = residuals.n_elem * penalty.alpha() * penalty.lambda();
+inline bool AnyViolateKKT(const arma::mat& x, const arma::vec& residuals, const double lambda,
+                          const EnPenalty& penalty) {
+  // const double lambda_1 = residuals.n_elem * penalty.alpha() * penalty.lambda();
+  const double lambda_1 = penalty.alpha() * lambda;
   for (arma::uword j = 0; j < x.n_cols; ++j) {
     // const double cutoff = lambda_1 * penalty.loadings()[j];
     const double inner = arma::dot(x.col(j), residuals);
@@ -79,9 +82,12 @@ inline bool AnyViolateKKT(const arma::mat& x, const arma::vec& residuals, const
 //!
 //! @param x matrix of predictor values.
 //! @param residuals vector of residuals.
+//! @param lambda the lambda value to use (overrides the lambda in the EN penalty).
 //! @param penalty adaptive elastic net penalty object.
-inline bool AnyViolateKKT(const arma::mat& x, const arma::vec& residuals, const AdaptiveEnPenalty& penalty) {
-  const double lambda_1 = residuals.n_elem * penalty.alpha() * penalty.lambda();
+inline bool AnyViolateKKT(const arma::mat& x, const arma::vec& residuals, const double lambda,
+                          const AdaptiveEnPenalty& penalty) {
+  // const double lambda_1 = residuals.n_elem * penalty.alpha() * penalty.lambda();
+  const double lambda_1 = penalty.alpha() * lambda;
   for (arma::uword j = 0; j < x.n_cols; ++j) {
     const double cutoff = lambda_1 * penalty.loadings()[j];
     const double inner = arma::dot(x.col(j), residuals);
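
For orientation, here is a self-contained sketch of the screening rule both overloads implement (the helper name and the explicit `alpha`/`loadings` parameters are illustrative, not part of the package): a zero coefficient satisfies the KKT conditions only while the absolute correlation of its column with the residuals stays at or below the l1 threshold `alpha * lambda`, scaled by the per-predictor loading in the adaptive case.

// Illustrative sketch of the KKT screening above; assumes Armadillo.
// beta_j = 0 is optimal iff |x_j' r| <= alpha * lambda * loading_j, so any
// violation means the optimizer would activate predictor j at this lambda.
#include <armadillo>
#include <cmath>

bool AnyViolateKKTSketch(const arma::mat& x, const arma::vec& residuals,
                         const double lambda, const double alpha,
                         const arma::vec& loadings) {
  const double lambda_1 = alpha * lambda;
  for (arma::uword j = 0; j < x.n_cols; ++j) {
    const double cutoff = lambda_1 * loadings[j];  // loadings are all 1 for plain EN
    if (std::abs(arma::dot(x.col(j), residuals)) > cutoff) {
      return true;   // predictor j violates the KKT conditions at zero
    }
  }
  return false;  // the all-zero coefficient vector remains optimal
}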
@@ -323,19 +329,18 @@ class AdmmLinearOptimizer : public Optimizer<LossFunction, PenaltyFunction, Coef
     // Check if the data needs to be updated
     const bool reset = !data_;
-    bool check_empty = reset || admm_optimizer::AllZero(coefs_.beta);
     if (reset) {
       UpdateData(IsWeightedTag{});
     }
     const bool include_intercept = loss_->IncludeIntercept();
     const double scaled_lambda = ScaledLambda(IsWeightedTag{});
+    const bool check_empty = admm_optimizer::AllZero(coefs_.beta) || (coefs_.beta.n_elem != data_->n_pred());
     // Check if the coefficients are correct.
     if (coefs_.beta.n_elem != data_->n_pred()) {
       coefs_.beta.zeros(data_->n_pred());
       coefs_.intercept = 0;
-      check_empty = true;
     }
     std::unique_ptr<Metrics> metrics(new Metrics("admm"));
@@ -357,11 +362,10 @@ class AdmmLinearOptimizer : public Optimizer<LossFunction, PenaltyFunction, Coef
       fitted += InterceptUpdate(IsWeightedTag{});
     }
-    if (check_empty) {
-      if (!admm_optimizer::AnyViolateKKT(data_->cx(), data_->cy() - fitted, *penalty_)) {
-        // None of the predictors will be activated for the current penalty. Return the current coefficient value.
-        return FinalizeResult(0, 0, OptimumStatus::kOk, std::move(metrics));
-      }
-    }
+    if (check_empty && !admm_optimizer::AnyViolateKKT(data_->cx(), data_->cy() - fitted, scaled_lambda / step_size_,
+                                                      *penalty_)) {
+      // None of the predictors will be activated for the current penalty. Return the current coefficient value.
+      return FinalizeResult(0, 0, OptimumStatus::kOk, std::move(metrics));
+    }
     // Check if the state needs to be re-initialized
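
Note that this variant passes `scaled_lambda / step_size_` to the KKT check. A plausible reading: in an ADMM iteration the l1 penalty enters through a soft-thresholding step whose threshold is the penalty level divided by the step-size parameter (the classic z-update uses lambda / rho), so the screening rule must use the same effective level the prox operator sees. A generic sketch of that update, with illustrative names:

// Generic ADMM z-update for an l1 penalty (illustrative, not package code):
// S_t(v) = sign(v) .* max(|v| - t, 0) with threshold t = lambda / rho.
#include <armadillo>

arma::vec SoftThreshold(const arma::vec& v, const double threshold) {
  arma::vec shrunk = arma::abs(v) - threshold;
  shrunk.elem(arma::find(shrunk < 0)).zeros();  // clip negative magnitudes to 0
  return arma::sign(v) % shrunk;
}

arma::vec AdmmZUpdate(const arma::vec& beta, const arma::vec& u,
                      const double lambda, const double rho) {
  return SoftThreshold(beta + u, lambda / rho);
}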
@@ -724,7 +728,7 @@ class AdmmVarStepOptimizer : public Optimizer<LossFunction, PenaltyFunction, Coe
     const bool include_intercept = loss_->IncludeIntercept();
     const double scaled_lambda = ScaledLambda(IsWeightedTag{});
-    bool check_empty = admm_optimizer::AllZero(coefs_.beta);
+    const bool check_empty = admm_optimizer::AllZero(coefs_.beta) || (coefs_.beta.n_elem != data_->n_pred());
     std::unique_ptr<Metrics> metrics(new Metrics("admm"));
     metrics->AddDetail("type", "var-stepsize");
@@ -733,15 +737,13 @@ class AdmmVarStepOptimizer : public Optimizer<LossFunction, PenaltyFunction, Coe
     if (coefs_.beta.n_elem != data_->n_pred()) {
       coefs_.beta.zeros(data_->n_pred());
       coefs_.intercept = 0;
-      check_empty = true;
     }
     // Check if any of the predictors might be active.
-    if (check_empty) {
-      if (!admm_optimizer::AnyViolateKKT(data_->cx(), data_->cy() - arma::mean(data_->cy()), *penalty_)) {
-        // None of the predictors will be activated for the current penalty. Return the current coefficient value.
-        return FinalizeResult(0, OptimumStatus::kOk, std::move(metrics));
-      }
-    }
+    if (check_empty && !admm_optimizer::AnyViolateKKT(data_->cx(), data_->cy() - arma::mean(data_->cy()), scaled_lambda,
+                                                      *penalty_)) {
+      // None of the predictors will be activated for the current penalty. Return the current coefficient value.
+      return FinalizeResult(0, OptimumStatus::kOk, std::move(metrics));
+    }
     // Compute the upper limit for tau (if we would do a restart)
@@ -137,8 +137,8 @@ class DalEnOptimizer : public Optimizer<LossFunction, PenaltyFunction, Regressio
       coefs_.Reset();
     }
-    auto changes = data_.Update(loss);
+    loss_.reset(new LossFunction(loss));
+    auto changes = data_.Update(*loss_);
     if (changes.data_changed || changes.weights_changed > 1) {
       // If the data changed, the proximity parameters must be reset.
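
Copying the loss into the member `loss_` before calling `data_.Update(*loss_)` reads as a lifetime fix: as the `DataProxy` comments below spell out, the proxy only stores pointers into the loss, so updating it from a caller-owned object that may go away would leave those pointers dangling. A contrived, self-contained sketch of the hazard (all names illustrative, not the package's types):

#include <armadillo>

// A proxy that, like DataProxy, points into the loss instead of copying.
struct LossSketch {
  arma::mat data;
  const arma::mat& cdata() const { return data; }
};

struct ProxySketch {
  const arma::mat* data = nullptr;
  void Update(const LossSketch& loss) { data = &loss.cdata(); }
};

void Danger(ProxySketch& proxy) {
  LossSketch temp{arma::randu<arma::mat>(10, 3)};
  proxy.Update(temp);
}  // `temp` is destroyed here; `proxy.data` now dangles.
// Holding an owned copy (as the optimizer now does via `loss_`) keeps the
// proxy's pointers valid until the next Update().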
@@ -138,6 +138,10 @@ template<typename LossFunction>
 class DataProxy<LossFunction, std::true_type> {
  public:
   DataProxy() = default;
+  //! Important: the DataProxy only retains a reference to the loss' data and weights. Therefore, it is the client's
+  //! responsibility to ensure that the passed-in loss is available until *after* the `Update` method is called with
+  //! a new loss!
   explicit DataProxy(LossFunction const * const loss)
     : data_(loss ? &(loss->data()) : nullptr), sqrt_weights_(loss ? &(loss->sqrt_weights()) : nullptr),
       mean_weight_(loss ? loss->mean_weight() : 1.),
@@ -153,6 +157,9 @@ class DataProxy<LossFunction, std::true_type> {
   DataProxy& operator=(DataProxy&& other) = default;
   //! Update the data proxy with a new loss function.
+  //! Important: the DataProxy only retains a reference to the loss' data and weights. Therefore, it is the client's
+  //! responsibility to ensure that the passed-in loss is available until *after* the `Update` method is called with
+  //! a new loss!
   //!
   //! @param loss the new loss function.
   //! @return information on what data changed.
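
In other words, the contract imposes an ordering on the caller: keep the old loss alive across the `Update` call, and release it only afterwards. A minimal stand-in sketch of that sequencing (`MiniLoss` and `MiniProxy` are illustrative stand-ins, not the package's types):

#include <armadillo>
#include <memory>
#include <utility>

struct MiniLoss {
  arma::mat data;
  const arma::mat& cdata() const { return data; }
};

struct MiniProxy {
  const arma::mat* data = nullptr;
  // Like DataProxy::Update, this merely re-points the internal reference.
  void Update(const MiniLoss& loss) { data = &loss.cdata(); }
};

int main() {
  auto loss = std::unique_ptr<MiniLoss>(new MiniLoss{arma::randu<arma::mat>(5, 2)});
  MiniProxy proxy;
  proxy.Update(*loss);

  auto next = std::unique_ptr<MiniLoss>(new MiniLoss{arma::randu<arma::mat>(5, 2)});
  proxy.Update(*next);     // the old `loss` is still alive here, as required
  loss = std::move(next);  // release the previous loss only after the update
  return 0;
}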
@@ -163,6 +163,7 @@ class MMOptimizer : public Optimizer<LossFunction, PenaltyFunction, Coefficients
   //! @return information about the optimum.
   Optimum Optimize(const Coefficients& start, const int max_it) {
     coefs_ = start;
+    optimizer_.Reset();
     return Optimize(max_it);
   }
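
The added `optimizer_.Reset()` makes sense for this overload: when the caller supplies fresh starting coefficients, any state the inner optimizer accumulated during a previous run (dual variables, cached step sizes, and the like) no longer matches the iterate and should be cleared. A generic sketch of this reset-on-external-warm-start pattern, with illustrative names:

#include <vector>

// Illustrative inner optimizer that caches per-run state; not package code.
class InnerOptimizerSketch {
 public:
  // Discard state tied to the previous starting point.
  void Reset() {
    dual_.clear();
    initialized_ = false;
  }

  void Run(const std::vector<double>& start) {
    if (!initialized_) {
      dual_.assign(start.size(), 0.0);  // fresh multipliers for the new start
      initialized_ = true;
    }
    // ... iterate from `start`, updating `dual_` ...
  }

 private:
  std::vector<double> dual_;
  bool initialized_ = false;
};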