Commit abc14099 authored by davidkep's avatar davidkep

Merge branch 'feature/generic-admm' into develop

parents af240ac6 3f071d9d
......@@ -26,15 +26,15 @@
namespace nsoptim {
//! Configuration options for the ADMM algorithm.
//! Configuration options for the variable-stepsize ADMM algorithm.
//!
//! The members have the following meaning:
//! max_it ... maximum number of iterations allowed.
//! tau ... the step size. If negative (the default), use the square L_2 norm of `x` (for linearized ADMM).
//! tau ... the step size. If negative (the default), use the square L_2 norm of `x`.
//! tau_lower_mult ... lower bound for the step size, defined as a multiple of `tau`.
//! tau_adjustment_lower ... lower bound of the step-size adjustment factor.
//! tau_adjustment_upper ... upper bound of the step-size adjustment factor.
struct AdmmConfiguration {
struct AdmmVarStepConfiguration {
int max_it;
double tau;
double tau_lower_mult;
......@@ -42,9 +42,20 @@ struct AdmmConfiguration {
double tau_adjustment_upper;
};
//! Configuration options for the linearized ADMM algorithm.
//!
//! The members have the following meaning:
//! max_it ... maximum number of iterations allowed.
struct AdmmLinearConfiguration {
  int max_it;  //!< Maximum number of ADMM iterations allowed.
};
namespace admm_optimizer {
//! Default configuration for ADMM-type algorithms
constexpr AdmmConfiguration kDefaultAdmmConfiguration = {1000, -1, 0.01, 0.98, 0.999};
//! Default configuration for the variable-stepsize ADMM algorithm
constexpr AdmmVarStepConfiguration kDefaultVarStepConfig { 1000, -1, 0.01, 0.98, 0.999 };
//! Default configuration for the linearized ADMM algorithm
constexpr AdmmLinearConfiguration kDefaultLinConfig { 1000 };
//! How often the secondary convergence criterion (relative to the maximum number of iterations) needs to be
//! fulfilled to stop the linearized ADMM early.
constexpr int kSecondCriterionMultiplier = 10;
......@@ -58,6 +69,16 @@ struct DataCache {
arma::mat chol_xtx;
};
//! Check whether any of the predictors in `x` violates the KKT conditions for a EN-type problem.
//!
//! Overload for the Ridge penalty: the check is short-circuited and unconditionally reports a
//! violation, so callers never skip the optimization for a Ridge penalty.
//! NOTE(review): presumably because the Ridge penalty never sets coefficients exactly to zero,
//! no predictor can ever be ruled out a priori -- confirm against the callers of this function.
//!
//! @param x matrix of predictor values (unused).
//! @param residuals vector of residuals (unused).
//! @param lambda the lambda value to use (overrides the lambda in the EN penalty; unused).
//! @param penalty Ridge penalty object (unused).
//! @return always `true`.
inline bool AnyViolateKKT(const arma::mat&, const arma::vec&, const double, const RidgePenalty&) {
  return true;
}
//! Check whether any of the predictors in `x` violates the KKT conditions for a EN-type problem.
//!
//! @param x matrix of predictor values.
......@@ -67,7 +88,7 @@ struct DataCache {
inline bool AnyViolateKKT(const arma::mat& x, const arma::vec& residuals, const double lambda,
const EnPenalty& penalty) {
// const double lambda_1 = residuals.n_elem * penalty.alpha() * penalty.lambda();
const double lambda_1 = penalty.alpha() * lambda;
const double lambda_1 = penalty.alpha() * lambda * residuals.n_elem;
for (arma::uword j = 0; j < x.n_cols; ++j) {
// const double cutoff = lambda1 * penalty.loadings()[j];
const double inner = arma::dot(x.col(j), residuals);
......@@ -87,7 +108,7 @@ inline bool AnyViolateKKT(const arma::mat& x, const arma::vec& residuals, const
inline bool AnyViolateKKT(const arma::mat& x, const arma::vec& residuals, const double lambda,
const AdaptiveEnPenalty& penalty) {
// const double lambda_1 = residuals.n_elem * penalty.alpha() * penalty.lambda();
const double lambda_1 = penalty.alpha() * lambda;
const double lambda_1 = penalty.alpha() * lambda * residuals.n_elem;
for (arma::uword j = 0; j < x.n_cols; ++j) {
const double cutoff = lambda_1 * penalty.loadings()[j];
const double inner = arma::dot(x.col(j), residuals);
......@@ -143,46 +164,395 @@ inline bool SolveChol(const arma::mat& chol, arma::vec * const b) {
return true;
}
} // namespace admm_optimizer
//! Proximal operator for the unweighted LS loss.
class LsProximalOperator {
 public:
  using LossFunction = LsLoss;

  //! Initialize the proximal operator with fixed step size `1 / tau`.
  //! A negative `tau` (the default) requests an automatically determined, loss-specific step size.
  explicit LsProximalOperator(const double tau = -1) noexcept : config_tau_(tau) {}

  //! Set the loss function for the proximal operator.
  //!
  //! @param loss the LS-loss for optimization. Only a pointer is retained, so the caller must keep
  //!             the loss object alive for as long as this proximal operator is in use!
  inline void loss(LsLoss* loss) noexcept {
    loss_ = loss;
  }

  //! Compute the proximal operator `v` for the given input parameters.
  //!
  //! @param u input vector.
  //! @param v_prev ignored by this proximal operator.
  //! @param intercept current intercept value.
  //! @param lambda scaling factor of the proximal operator.
  //! @param metrics optional metrics object to collect metrics of the proximal operator.
  inline arma::vec operator()(const arma::vec& u, const arma::vec&, const double intercept, const double lambda,
                              Metrics * const = nullptr) {
    return (*this)(u, intercept, lambda);
  }

  //! Compute the proximal operator `v` for the given input parameters.
  //!
  //! @param u input vector.
  //! @param intercept current intercept value.
  //! @param lambda scaling factor of the proximal operator.
  //! @param metrics optional metrics object to collect metrics of the proximal operator.
  inline arma::vec operator()(const arma::vec& u, const double intercept, const double lambda,
                              Metrics * const = nullptr) {
    const int n_obs = loss_->data().n_obs();
    const double scaling = 1 / (n_obs + lambda);
    // Convex combination of the input `u` and the (possibly centered) response.
    if (!loss_->IncludeIntercept()) {
      return n_obs * scaling * u + lambda * scaling * loss_->data().cy();
    }
    return n_obs * scaling * u + lambda * scaling * (loss_->data().cy() - intercept);
  }

  //! Compute the intercept, i.e., the average residual of the given fitted values.
  inline double ComputeIntercept(const arma::vec& fitted) const noexcept {
    return arma::mean(loss_->data().cy() - fitted);
  }

  //! Compute the step size for the currently set loss.
  //!
  //! @param penalty the current penalty value.
  //! @param norm_x the operator norm of the predictor matrix.
  //! @return the loss-specific step size.
  double StepSize(const EnPenalty& penalty, const double norm_x) const {
    // A non-negative configured tau takes precedence over the automatic choice.
    if (!(config_tau_ < 0)) {
      return 1 / config_tau_;
    }
    const PredictorResponseData& data = loss_->data();
    const double nr_obs = static_cast<double>(data.n_obs());
    if (data.n_obs() < data.n_pred()) {
      const double exponent = std::max(0.5, nr_obs / data.n_pred());
      const double candidate = data.n_pred() * std::pow(norm_x, exponent);
      return std::min(candidate, norm_x * nr_obs / std::sqrt(data.n_pred() * penalty.lambda()));
    }
    const double exponent = std::max(0.5, data.n_pred() / nr_obs);
    const double candidate = std::max(1., data.n_pred() * std::pow(norm_x, exponent));
    return std::min(candidate, norm_x * nr_obs / std::sqrt(nr_obs * penalty.lambda()));
  }

  //! Compute the step size for the currently set loss.
  //!
  //! @param penalty the current penalty value.
  //! @param norm_x the operator norm of the predictor matrix.
  //! @return the loss-specific step size.
  double StepSize(const AdaptiveEnPenalty& penalty, const double norm_x) const {
    // A non-negative configured tau takes precedence over the automatic choice.
    if (!(config_tau_ < 0)) {
      return 1 / config_tau_;
    }
    const PredictorResponseData& data = loss_->data();
    const double nr_obs = static_cast<double>(data.n_obs());
    if (data.n_obs() < data.n_pred()) {
      const double exponent = std::max(0.5, nr_obs / data.n_pred());
      const double candidate = (data.n_pred() * std::pow(norm_x, exponent)) / nr_obs;
      return std::min(candidate, norm_x / std::sqrt(data.n_pred() * penalty.lambda()));
    }
    const double exponent = std::max(0.5, data.n_pred() / nr_obs);
    const double candidate = std::max(1., (data.n_pred() * std::pow(norm_x, exponent)) / nr_obs);
    return std::min(candidate, norm_x / std::sqrt(nr_obs * penalty.lambda()));
  }

  //! Compute the step size for the currently set loss.
  //! Delegates to the EN-penalty overload with an equivalent EN penalty.
  //!
  //! @param penalty the current penalty value.
  //! @param norm_x the operator norm of the predictor matrix.
  //! @return the loss-specific step size.
  inline double StepSize(const RidgePenalty& penalty, const double norm_x) const {
    return StepSize(EnPenalty(penalty), norm_x);
  }

  //! Compute the step size for the currently set loss.
  //! Delegates to the EN-penalty overload with an equivalent EN penalty.
  //!
  //! @param penalty the current penalty value.
  //! @param norm_x the operator norm of the predictor matrix.
  //! @return the loss-specific step size.
  inline double StepSize(const LassoPenalty& penalty, const double norm_x) const {
    return StepSize(EnPenalty(penalty), norm_x);
  }

 private:
  double config_tau_;  //!< Configured step size (negative = automatic).
  LsLoss * loss_;      //!< Non-owning pointer to the current loss.
};
//! Proximal operator for the weighted LS loss.
class WeightedLsProximalOperator {
 public:
  using LossFunction = WeightedLsLoss;

  //! Initialize the proximal operator with fixed step size `1 / tau`.
  //! A negative `tau` (the default) requests an automatically determined, loss-specific step size.
  explicit WeightedLsProximalOperator(const double tau = -1) noexcept : config_tau_(tau) {}

  //! Set the loss function for the proximal operator.
  //!
  //! @param loss the LS-loss for optimization. Only a pointer is retained, so the caller must keep
  //!             the loss object alive for as long as this proximal operator is in use!
  inline void loss(WeightedLsLoss* loss) noexcept {
    loss_ = loss;
    // Cache the observation weights of the new loss.
    weights_ = loss_->weights();
  }

  //! Compute the proximal operator `v` for the given input parameters.
  //!
  //! @param u input vector.
  //! @param v_prev ignored by this proximal operator.
  //! @param intercept current intercept value.
  //! @param lambda scaling factor of the proximal operator.
  //! @param metrics optional metrics object to collect metrics of the proximal operator.
  inline arma::vec operator()(const arma::vec& u, const arma::vec&, const double intercept, const double lambda,
                              Metrics * const = nullptr) const {
    return (*this)(u, intercept, lambda);
  }

  //! Compute the proximal operator `v` for the given input parameters.
  //!
  //! @param u input vector.
  //! @param intercept current intercept value.
  //! @param lambda scaling factor of the proximal operator.
  //! @param metrics optional metrics object to collect metrics of the proximal operator.
  inline arma::vec operator()(const arma::vec& u, const double intercept, const double lambda,
                              Metrics * const = nullptr) const {
    const auto nr_obs = loss_->data().n_obs();
    // Element-wise weighted average of the input `u` and the (possibly centered) response.
    if (!loss_->IncludeIntercept()) {
      return (nr_obs * u + lambda * weights_ % loss_->data().cy()) / (nr_obs + lambda * weights_);
    }
    return (nr_obs * u + lambda * weights_ % (loss_->data().cy() - intercept)) /
      (nr_obs + lambda * weights_);
  }

  //! Compute the intercept, i.e., the weighted average residual of the given fitted values.
  inline double ComputeIntercept(const arma::vec& fitted) const noexcept {
    return arma::mean((loss_->data().cy() - fitted) % weights_) / loss_->mean_weight();
  }

  //! Compute the step size for the currently set loss.
  //!
  //! @param penalty the current penalty value.
  //! @param norm_x the operator norm of the predictor matrix.
  //! @return the loss-specific step size.
  double StepSize(const EnPenalty& penalty, const double norm_x) const {
    // A non-negative configured tau takes precedence over the automatic choice.
    if (!(config_tau_ < 0)) {
      return 1 / config_tau_;
    }
    const PredictorResponseData& data = loss_->data();
    // Effective number of observations under the current weights.
    const double eff_obs = data.n_obs() * loss_->mean_weight();
    if (eff_obs < data.n_pred()) {
      const double exponent = std::max(0.5, eff_obs / data.n_pred());
      const double candidate = data.n_pred() * std::pow(norm_x, exponent) / loss_->mean_weight();
      return std::min(candidate, norm_x * data.n_obs() / std::sqrt(data.n_pred() * penalty.lambda()));
    }
    const double exponent = std::max(0.5, data.n_pred() / eff_obs);
    const double candidate = std::max(1., data.n_pred() * std::pow(norm_x, exponent)) / loss_->mean_weight();
    return std::min(candidate, norm_x * data.n_obs() / std::sqrt(eff_obs * penalty.lambda()));
  }

  //! Compute the step size for the currently set loss.
  //!
  //! @param penalty the current penalty value.
  //! @param norm_x the operator norm of the predictor matrix.
  //! @return the loss-specific step size.
  double StepSize(const AdaptiveEnPenalty& penalty, const double norm_x) const {
    // A non-negative configured tau takes precedence over the automatic choice.
    if (!(config_tau_ < 0)) {
      return 1 / config_tau_;
    }
    const PredictorResponseData& data = loss_->data();
    // Effective number of observations under the current weights.
    const double eff_obs = data.n_obs() * loss_->mean_weight();
    if (eff_obs < data.n_pred()) {
      const double exponent = std::max(0.5, eff_obs / data.n_pred());
      const double candidate = data.n_pred() * std::pow(norm_x, exponent);
      return std::min(candidate, norm_x * eff_obs / std::sqrt(data.n_pred() * penalty.lambda()));
    }
    const double exponent = std::max(0.5, data.n_pred() / eff_obs);
    const double candidate = std::max(1., data.n_pred() * std::pow(norm_x, exponent));
    return std::min(candidate, norm_x * eff_obs / std::sqrt(eff_obs * penalty.lambda()));
  }

  //! Compute the step size for the currently set loss.
  //! Delegates to the EN-penalty overload with an equivalent EN penalty.
  //!
  //! @param penalty the current penalty value.
  //! @param norm_x the operator norm of the predictor matrix.
  //! @return the loss-specific step size.
  inline double StepSize(const RidgePenalty& penalty, const double norm_x) const {
    return StepSize(EnPenalty(penalty), norm_x);
  }

  //! Compute the step size for the currently set loss.
  //! Delegates to the EN-penalty overload with an equivalent EN penalty.
  //!
  //! @param penalty the current penalty value.
  //! @param norm_x the operator norm of the predictor matrix.
  //! @return the loss-specific step size.
  inline double StepSize(const LassoPenalty& penalty, const double norm_x) const {
    return StepSize(EnPenalty(penalty), norm_x);
  }

 private:
  double config_tau_;       //!< Configured step size (negative = automatic).
  WeightedLsLoss * loss_;   //!< Non-owning pointer to the current loss.
  arma::vec weights_;       //!< Cached observation weights of the current loss.
};
namespace admm_optimizer {
//! Type trait mapping the LsLoss to the LsProximalOperator, the WeightedLsLoss to the WeightedLsProximalOperator,
//! and any other type to itself.
//!
//! NOTE(review): presumably this lets the optimizer be instantiated with either a loss type or a
//! proximal-operator type as first template argument -- confirm at the instantiation sites.
template <typename T>
using ProximalOperator = typename std::conditional<
std::is_same<T, LsLoss>::value, LsProximalOperator,
typename std::conditional<std::is_same<T, WeightedLsLoss>::value,
WeightedLsProximalOperator, T>::type >::type;
}  // namespace admm_optimizer
//! Compute the EN regression estimate using the alternating direction method of multiplier (ADMM)
//! with linearization
template <typename LossFunction, typename PenaltyFunction, typename Coefficients>
class AdmmLinearOptimizer : public Optimizer<LossFunction, PenaltyFunction, Coefficients> {
//! with linearization. This optimizer uses the given proximal operator class `ProxOp`.
//!
//! A proximal operator needs to implement the following methods:
//! void loss(const LossFunction& loss) ... change the loss function to `loss`.
//! arma::vec operator()(const vec& u, const vec& prev, const double intercept, const double lambda, Metrics * metrics)
//! ... get the value of the proximal operator of the function scaled by `lambda`, evaluated at `u`. The argument
//! `prev` is the previous value returned by the proximal operator and `intercept` is the current value
//! of the intercept or 0, if the loss does not use an intercept term.
//! double ComputeIntercept(const arma::vec& fitted) ... compute the intercept term, given the fitted values.
//! double StepSize(const PenaltyFunction& penalty, const double norm_x) ... get the step size required for the
//! loss function if the penalty is as given.
//!
//! See `LsProximalOperator` and `WeightedLsProximalOperator` for example implementations of the proximal operator.
template <typename ProxOp, typename PenaltyFunction, typename Coefficients>
class GenericLinearizedAdmmOptimizer : public Optimizer<typename ProxOp::LossFunction, PenaltyFunction, Coefficients> {
public:
using ProximalOperator = ProxOp;
using LossFunction = typename ProxOp::LossFunction;
private:
using Base = Optimizer<LossFunction, PenaltyFunction, Coefficients>;
using LossFunctionPtr = std::unique_ptr<LossFunction>;
using PenaltyPtr = std::unique_ptr<PenaltyFunction>;
using IsWeightedTag = typename traits::is_weighted<LossFunction>::type;
using IsAdaptiveTag = typename traits::is_adaptive<PenaltyFunction>::type;
using WeightsType = typename std::conditional<IsWeightedTag::value, arma::vec, char>::type;
static_assert(traits::is_en_penalty<PenaltyFunction>::value, "PenaltyFunction must be an EN-type penalty.");
static_assert(traits::is_ls_loss<LossFunction>::value, "LossFunction must be an least-squares-type loss.");
// ADMM state variables
struct State {
arma::vec v;
arma::vec l;
};
// Helper-traits to identify constructor arguments.
template<typename T>
using IsConfiguration = std::is_same<T, AdmmLinearConfiguration>;
template<typename T>
using IsLossFunction = std::is_same<T, LossFunction>;
template<typename T>
using IsProximalOperator = std::is_same<T, ProximalOperator>;
public:
using Optimum = typename Base::Optimum;
//! Ininitialize the optimizer using the given (weighted) LS loss function and the Ridge penalty.
//! Initialize the ADMM optimizer without setting a loss or penalty function.
GenericLinearizedAdmmOptimizer() noexcept
: config_(admm_optimizer::kDefaultLinConfig), prox_(), loss_(nullptr), penalty_(nullptr) {}
//! Initialize the ADMM optimizer with the given loss and penalty functions.
//!
//! @param loss a weighted LS loss function.
//! @param penalty the Ridge penalty.
explicit AdmmLinearOptimizer(const AdmmConfiguration& config = admm_optimizer::kDefaultAdmmConfiguration) noexcept
: config_(config), loss_(nullptr), penalty_(nullptr) {}
//! @param loss the loss function object.
//! @param penalty the penalty function object.
GenericLinearizedAdmmOptimizer(const LossFunction& loss, const PenaltyFunction& penalty) noexcept
: config_(admm_optimizer::kDefaultLinConfig), prox_(),
loss_(new LossFunction(loss)), penalty_(new PenaltyFunction(penalty)) {}
//! Ininitialize the optimizer using the given (weighted) LS loss function and the Ridge penalty.
//! Initialize the ADMM optimizer with the given loss and penalty functions.
//!
//! @param loss a weighted LS loss function.
//! @param penalty the Ridge penalty.
AdmmLinearOptimizer(const LossFunction& loss, const PenaltyFunction& penalty,
const AdmmConfiguration& config = admm_optimizer::kDefaultAdmmConfiguration) noexcept
: config_(config), loss_(new LossFunction(loss)), penalty_(new PenaltyFunction(penalty)) {}
//! @param loss the loss function object.
//! @param penalty the penalty function object.
//! @param prox proximal operator object.
//! @param config optional ADMM configuration object.
GenericLinearizedAdmmOptimizer(const LossFunction& loss, const PenaltyFunction& penalty, const ProximalOperator& prox,
const AdmmLinearConfiguration& config = admm_optimizer::kDefaultLinConfig) noexcept
: config_(config), prox_(prox), loss_(new LossFunction(loss)), penalty_(new PenaltyFunction(penalty)) {}
//! Initialize the ADMM optimizer without setting a loss or penalty function.
//!
//! @param prox_arg_1 first argument to constructor of the proximal operator.
//! @param prox_args... further arguments to the constructor of the proximal operator.
template<typename T, typename... Args, typename = typename
std::enable_if<!IsConfiguration<T>::value && !IsLossFunction<T>::value, void>::type >
explicit GenericLinearizedAdmmOptimizer(T&& prox_arg_1, Args&&... prox_args) noexcept
: config_(admm_optimizer::kDefaultLinConfig),
prox_(std::forward<T>(prox_arg_1), std::forward<Args>(prox_args)...),
loss_(nullptr), penalty_(nullptr) {}
//! Initialize the ADMM optimizer without setting a loss or penalty function.
//!
//! @param config ADMM configuration object.
//! @param prox_args... arguments to the constructor of the proximal operator.
template<typename C, typename... Args, typename = typename
std::enable_if<IsConfiguration<C>::value, void>::type >
explicit GenericLinearizedAdmmOptimizer(const C& config, Args&&... prox_args) noexcept
: config_(config), prox_(std::forward<Args>(prox_args)...), loss_(nullptr), penalty_(nullptr) {}
//! Initialize the ADMM optimizer without setting a loss or penalty function.
//!
//! @param prox_arg_1 first argument to constructor of the proximal operator.
//! @param prox_arg_2 second argument to constructor of the proximal operator.
//! @param prox_args... further arguments to the constructor of the proximal operator.
template<typename T1, typename T2, typename... Args, typename = typename
std::enable_if<!IsConfiguration<T1>::value && !IsLossFunction<T1>::value, void>::type >
GenericLinearizedAdmmOptimizer(T1&& prox_arg_1, T2&& prox_arg_2, Args&&... prox_args) noexcept
: config_(admm_optimizer::kDefaultLinConfig),
prox_(std::forward<T1>(prox_arg_1), std::forward<T2>(prox_arg_2), std::forward<Args>(prox_args)...),
loss_(nullptr), penalty_(nullptr) {}
//! Initialize the ADMM optimizer without setting a loss or penalty function.
//!
//! @param config ADMM configuration object.
//! @param prox_arg_1 first argument to constructor of the proximal operator.
//! @param prox_args... further arguments to the constructor of the proximal operator.
template<typename C, typename T1, typename... Args, typename = typename
std::enable_if<IsConfiguration<C>::value, void>::type >
GenericLinearizedAdmmOptimizer(const C& config, T1&& prox_arg_1, Args&&... prox_args) noexcept
: config_(config), prox_(std::forward<T1>(prox_arg_1), std::forward<Args>(prox_args)...),
loss_(nullptr), penalty_(nullptr) {}
//! Initialize the ADMM optimizer without setting a loss or penalty function.
//!
//! @param config ADMM configuration object.
//! @param prox_arg_1 first argument to constructor of the proximal operator.
//! @param prox_arg_2 second argument to constructor of the proximal operator.
//! @param prox_args... further arguments to the constructor of the proximal operator.
template<typename C, typename T1, typename T2, typename... Args, typename = typename
std::enable_if<IsConfiguration<C>::value, void>::type >
GenericLinearizedAdmmOptimizer(const C& config, T1&& prox_arg_1, T2&& prox_arg_2, Args&&... prox_args) noexcept
: config_(config),
prox_(std::forward<T1>(prox_arg_1), std::forward<T2>(prox_arg_2), std::forward<Args>(prox_args)...),
loss_(nullptr), penalty_(nullptr) {}
//! Initialize the ADMM optimizer with the given loss and penalty functions.
//!
//! @param loss the loss function object.
//! @param penalty the penalty function object.
//! @param prox_arg_1 first argument to constructor of the proximal operator.
//! @param prox_args... further arguments to the constructor of the proximal operator.
template<typename L, typename P, typename T1, typename... Args, typename = typename
std::enable_if<IsLossFunction<L>::value && !IsConfiguration<T1>::value, void>::type >
GenericLinearizedAdmmOptimizer(const L& loss, const P& penalty, T1&& prox_arg_1, Args&&... prox_args) noexcept
: config_(admm_optimizer::kDefaultLinConfig),
prox_(std::forward<T1>(prox_arg_1), std::forward<Args>(prox_args)...),
loss_(new LossFunction(loss)), penalty_(new PenaltyFunction(penalty)) {}
//! Initialize the ADMM optimizer with the given loss and penalty functions.
//!
//! @param loss the loss function object.
//! @param penalty the penalty function object.
//! @param config ADMM configuration object.
//! @param prox_args... further arguments to the constructor of the proximal operator.
template<typename L, typename P, typename C, typename... Args, typename = typename
std::enable_if<IsLossFunction<L>::value && IsConfiguration<C>::value, void>::type >
GenericLinearizedAdmmOptimizer(const L& loss, const P& penalty, const C& config, Args&&... prox_args) noexcept
: config_(config), prox_(std::forward<Args>(prox_args)...),
loss_(new LossFunction(loss)), penalty_(new PenaltyFunction(penalty)) {}
//! Default copy constructor.
//!
......@@ -190,15 +560,15 @@ class AdmmLinearOptimizer : public Optimizer<LossFunction, PenaltyFunction, Coef
//! In case the loss or penalty function are mutated in any way, the change will affect both optimizers.
//! If the loss/penalty function is changed on one of the optimizers (using the `loss()` or `penalty()` methods),
//! the two optimizers will *not* share the new loss/penalty function.
AdmmLinearOptimizer(const AdmmLinearOptimizer& other) noexcept
GenericLinearizedAdmmOptimizer(const GenericLinearizedAdmmOptimizer& other) noexcept
: config_(other.config_),
prox_(other.prox_),
loss_(other.loss_? new LossFunction(*other.loss_) : nullptr),
penalty_(other.penalty_ ? new PenaltyFunction(*other.penalty_) : nullptr),
coefs_(other.coefs_),
state_(other.state_),
step_size_(other.step_size_),
norm_x_(other.norm_x_),
norm_x_sq_inv_(other.norm_x_sq_inv_),
convergence_tolerance_(other.convergence_tolerance_) {}
//! Default copy assignment.
......@@ -207,27 +577,27 @@ class AdmmLinearOptimizer : public Optimizer<LossFunction, PenaltyFunction, Coef
//! In case the loss or penalty function are mutated in any way, the change will affect both optimizers.
//! If the loss/penalty function is changed on one of the optimizers (using the `loss()` or `penalty()` methods),
//! the two optimizers will *not* share the new loss/penalty function.
AdmmLinearOptimizer& operator=(const AdmmLinearOptimizer& other) = default;
GenericLinearizedAdmmOptimizer& operator=(const GenericLinearizedAdmmOptimizer& other) = default;
//! Default move constructor.
AdmmLinearOptimizer(AdmmLinearOptimizer&& other) = default;
GenericLinearizedAdmmOptimizer(GenericLinearizedAdmmOptimizer&& other) = default;
//! Default move assignment operator.
AdmmLinearOptimizer& operator=(AdmmLinearOptimizer&& other) = default;
GenericLinearizedAdmmOptimizer& operator=(GenericLinearizedAdmmOptimizer&& other) = default;
~AdmmLinearOptimizer() = default;
~GenericLinearizedAdmmOptimizer() = default;
AdmmLinearOptimizer Clone() const {
GenericLinearizedAdmmOptimizer Clone() const {
if (!loss_ || !penalty_) {
return AdmmLinearOptimizer();
return GenericLinearizedAdmmOptimizer();
}
AdmmLinearOptimizer clone(loss_->Clone(), penalty_->Clone());
GenericLinearizedAdmmOptimizer clone(loss_->Clone(), penalty_->Clone(), prox_, config_);
return clone;
}
AdmmLinearOptimizer Copy() {
return AdmmLinearOptimizer(*this);
GenericLinearizedAdmmOptimizer Copy() {
return GenericLinearizedAdmmOptimizer(*this);
}
void Reset() {
......@@ -243,6 +613,7 @@ class AdmmLinearOptimizer : public Optimizer<LossFunction, PenaltyFunction, Coef
void loss(const LossFunction& loss) noexcept {
loss_.reset(new LossFunction(loss));
prox_.loss(loss_.get());
// If the data changes, reset the ADMM state, but not the initial coefficient values.
state_.v.reset();
......@@ -252,7 +623,6 @@ class AdmmLinearOptimizer : public Optimizer<LossFunction, PenaltyFunction, Coef
arma::norm(loss_->data().cx(), 2);
norm_x_sq_inv_ = 1 / (norm_x_ * norm_x_);
UpdateWeights(IsWeightedTag{});
}
PenaltyFunction& penalty() const {
......@@ -308,10 +678,8 @@ class AdmmLinearOptimizer : public Optimizer<LossFunction, PenaltyFunction, Coef
}
coefs_ = start;
state_ = State {
ProximalLs(loss_->data().cx() * coefs_.beta, IsWeightedTag{}),
arma::zeros(loss_->data().n_obs())
};
// Reset the state (at least `v` so that the Optimize function knows that the state needs to be re-initialized)
state_.v.reset();
return Optimize(max_it);
}
......@@ -331,8 +699,7 @@ class AdmmLinearOptimizer : public Optimizer<LossFunction, PenaltyFunction, Coef
const PredictorResponseData& data = loss_->data();
const bool include_intercept = loss_->IncludeIntercept();
const double scaled_lambda = data.n_obs() * penalty_->lambda();
const bool check_empty = admm_optimizer::AllZero(coefs_.beta) || (coefs_.beta.n_elem != data.n_pred());
// const bool check_empty = admm_optimizer::AllZero(coefs_.beta) || (coefs_.beta.n_elem != data.n_pred());
// Check if the coefficients are correct.
if (coefs_.beta.n_elem != data.n_pred()) {
......@@ -345,30 +712,23 @@ class AdmmLinearOptimizer : public Optimizer<LossFunction, PenaltyFunction, Coef
// Compute the step size
const double old_step_size = step_size_;
UpdateStepSize(IsWeightedTag{});
step_size_ = prox_.StepSize(*penalty_, norm_x_);
// This is the convergence tolerance for the "un-standardized" residual.
const double conv_tol = convergence_tolerance_ * penalty_->alpha() * step_size_ * step_size_;
const auto en_cutoff = DetermineCutoff(IsAdaptiveTag{});
const double en_multiplier = 1 / (1 + norm_x_sq_inv_ * step_size_ * scaled_lambda * (1 - penalty_->alpha()));
const double en_multiplier = 1 / (1 + norm_x_sq_inv_ * step_size_ * penalty_->lambda() * (1 - penalty_->alpha()));
double gap = 0;
arma::vec fitted = data.cx() * coefs_.beta;
if (include_intercept) {
coefs_.intercept = ComputeIntercept(fitted, IsWeightedTag{});
}
if (check_empty && !admm_optimizer::AnyViolateKKT(data.cx(), EmptyModelResiduals(IsWeightedTag{}), scaled_lambda,
*penalty_)) {
// None of the predictors will be activated for the current penalty. Return the current coefficient value.
return FinalizeResult(0, 0, OptimumStatus::kOk, std::move(metrics));
coefs_.intercept = prox_.ComputeIntercept(fitted);
}
// Check if the state needs to be re-initialized
if (state_.v.n_elem != fitted.n_elem) {
// This is the ProximalLS function, but inlined...
state_.v = ProximalLs(fitted, IsWeightedTag{});
state_.v = prox_(fitted, coefs_.intercept, step_size_);
state_.l.zeros(data.n_obs());
} else if (old_step_size != step_size_) {
// Adjust the slack variable for the updated step size.
......@@ -382,13 +742,11 @@ class AdmmLinearOptimizer : public Optimizer<LossFunction, PenaltyFunction, Coef
metrics->AddDetail("step_size", step_size_);
int iter = 0;
int second_criterion = 0;
State prev_state;
while (iter++ < max_it) {
// State prev_state = state_;
prev_state.v = state_.v;
prev_state.l = state_.l;
Metrics& iter_metrics = metrics->CreateSubMetrics("admm-iteration");
prev_state = state_;
// remember: fitted is already fitted - state_.v
coefs_.beta = en_multiplier * SoftThreshold(coefs_.beta, -norm_x_sq_inv_,
data.cx().t() * (fitted + state_.l), en_cutoff);
......@@ -396,27 +754,29 @@ class AdmmLinearOptimizer : public Optimizer<LossFunction, PenaltyFunction, Coef
fitted = data.cx() * coefs_.beta;
if (include_intercept) {
coefs_.intercept = ComputeIntercept(fitted, IsWeightedTag{});
coefs_.intercept = prox_.ComputeIntercept(fitted);
}
// This is the ProximalLS function, but inlined...
state_.v = ProximalLs(fitted + state_.l, IsWeightedTag{});
state_.v = prox_(fitted + state_.l, state_.v, coefs_.intercept, step_size_,
&(iter_metrics.CreateSubMetrics("prox")));
fitted -= state_.v;
state_.l += fitted;
const double v_diff = arma::accu(arma::square(state_.v - prev_state.v));
const double l_diff = arma::accu(arma::square(state_.l - prev_state.l));
const double diff_scaling = arma::norm(state_.v, 2) + arma::norm(state_.l, 2);
// This is the "un-standardized" residual, without the division by the squared step-size.
gap = v_diff + l_diff;
if ((gap < conv_tol) || second_criterion > max_it) {
gap = (v_diff + l_diff) * diff_scaling * diff_scaling;
iter_metrics.AddDetail("v_diff", v_diff);
iter_metrics.AddDetail("l_diff", l_diff);
iter_metrics.AddDetail("gap", gap);
if (gap < conv_tol) {
return FinalizeResult(iter, gap, OptimumStatus::kOk, std::move(metrics));
} else if (gap * gap < conv_tol && (l_diff < gap || v_diff < gap)) {
second_criterion += admm_optimizer::kSecondCriterionMultiplier;
} else {
second_criterion = 0;
}
}
return FinalizeResult(iter, gap, OptimumStatus::kWarning, "ADMM-algorithm did not converge.", std::move(metrics));
return FinalizeResult(--iter, gap, OptimumStatus::kWarning, "ADMM-algorithm did not converge.", std::move(metrics));
}
private:
......@@ -435,25 +795,6 @@ class AdmmLinearOptimizer : public Optimizer<LossFunction, PenaltyFunction, Coef
return MakeOptimum(*loss_, *penalty_, coefs_, std::move(metrics), status, message);
}
//! Compute the **doubly** weighted residuals.
//! This does not compute the true "weighted" residuals W.(y - mu - X.beta), but rather W'W.(y - mu - X.beta)!
arma::vec EmptyModelResiduals(std::true_type) const {
if (loss_->IncludeIntercept()) {
return weights_ % (loss_->data().cy() - coefs_.intercept);
} else {
return weights_ % loss_->data().cy();
}
}
//! Compute the unweighted residuals.
arma::vec EmptyModelResiduals(std::false_type) const noexcept {
if (loss_->IncludeIntercept()) {
return loss_->data().cy() - coefs_.intercept;
} else {
return loss_->data().cy();
}
}
//! Determine the cutoff for the soft-threshold function for adaptive penalties
arma::vec DetermineCutoff(std::true_type) const noexcept {
return penalty_->loadings() * DetermineCutoff(std::false_type{});
......@@ -461,110 +802,26 @@ class AdmmLinearOptimizer : public Optimizer<LossFunction, PenaltyFunction, Coef