10 #include <gsl/gsl_sf_gamma.h> 11 #include <gsl/gsl_sf_psi.h> 49 template <
typename T,
typename Scalar>
50 EMResult<T>
online_EM(
const matrix_t<T>& X,
51 const std::vector<alloc_model::Params<Scalar>>& param_vec,
52 const size_t max_iter = 1000,
53 const bool use_psi_appr =
false) {
56 const auto x =
static_cast<size_t>(X.rows());
57 const auto y =
static_cast<size_t>(X.cols());
58 const auto z =
static_cast<size_t>(param_vec.size());
76 matrix_t<Scalar> alpha, beta;
78 const vector_t<Scalar> alpha_pk = alpha.colwise().sum();
82 std::tie(res.S_pjk, res.S_ipk, S_ppk) =
87 use_psi_appr ? util::psi_appr<T> : details::gsl_psi_wrapper<T>;
90 res.logW = matrix_t<double>(x, z);
91 res.logH = matrix_t<double>(z, y);
98 res.log_PS = vector_t<double>::Constant(max_iter, 0);
101 for (
size_t iter = 0; iter < max_iter; ++iter) {
116 alpha, beta, res.S_ipk, res.S_pjk, alpha_pk, S_ppk, param_vec[0].b);
void update_logH(const matrix_t< Scalar > &beta, const matrix_t< T > &S_pjk, const Scalar b, const PsiFunction &psi_fn, matrix_t< double > &logH)
Perform an update on logH matrix.
Definition: online_EM_funcs.hpp:221
EMResult< T > online_EM(const matrix_t< T > &X, const std::vector< alloc_model::Params< Scalar >> &param_vec, const size_t max_iter=1000, const bool use_psi_appr=false)
Complete a matrix containing unobserved values given as NaN using an EM procedure according to the al...
Definition: online_EM.hpp:50
void check_EM_params(const matrix_t< T > &X, const std::vector< alloc_model::Params< Scalar >> &param_vec)
Do parameter checks on EM algorithms for solving the BLD problem.
Definition: util_details.hpp:39
double delta_log_PS(const matrix_t< Scalar > &alpha, const matrix_t< Scalar > &beta, const matrix_t< T > &S_ipk, const matrix_t< T > &S_pjk, const vector_t< Scalar > &alpha_pk, const vector_t< T > &S_ppk, Scalar b)
Compute the difference in log_PS value computed using the given model variables.
Definition: online_EM_funcs.hpp:364
std::tuple< std::vector< size_t >, std::vector< size_t >, std::vector< T > > find_nonzero(const matrix_t< T > &X)
Find nonzero entries and their indices in the given matrix and return indices and values as vectors...
Definition: online_EM_funcs.hpp:251
size_t init_nan_values(matrix_t< T > &X)
Initialize all NaN values in the given matrix with the mean of the remaining values that are differen...
Definition: online_EM_funcs.hpp:28
std::tuple< matrix_t< T >, matrix_t< T >, vector_t< T > > init_S_xx(const matrix_t< T > &X_full, size_t z, const std::vector< size_t > &ii, const std::vector< size_t > &jj)
Initialize S_pjk, S_ipk matrices and S_ppk vector from a Dirichlet distribution with all parameters e...
Definition: online_EM_funcs.hpp:120
std::pair< matrix_t< Scalar >, matrix_t< Scalar > > init_alpha_beta(const std::vector< alloc_model::Params< Scalar >> &param_vec, size_t y)
Initialize each entry of alpha and beta matrices with the given model parameters. ...
Definition: online_EM_funcs.hpp:77
Main namespace for bnmf-algs library.
Definition: alloc_model_funcs.hpp:12
double update_allocation(const std::vector< size_t > &ii, const std::vector< size_t > &jj, const std::vector< T > &xx, bld::EMResult< T > &res, vector_t< T > &S_ppk)
Update the current allocation by performing a maximization step for each nonzero entry of the origina...
Definition: online_EM_funcs.hpp:285
void update_logW(const matrix_t< Scalar > &alpha, const matrix_t< T > &S_ipk, const vector_t< Scalar > &alpha_pk, const vector_t< T > &S_ppk, const PsiFunction &psi_fn, matrix_t< double > &logW)
Perform an update on logW matrix.
Definition: online_EM_funcs.hpp:186