10 #include <gsl/gsl_sf_gamma.h> 11 #include <gsl/gsl_sf_psi.h> 49 template <
typename T,
typename Scalar>
52 const size_t max_iter = 1000,
53 const bool use_psi_appr =
false) {
56 const auto x =
static_cast<size_t>(X.rows());
57 const auto y =
static_cast<size_t>(X.cols());
58 const auto z =
static_cast<size_t>(param_vec.size());
87 use_psi_appr ? util::psi_appr<T> : details::gsl_psi_wrapper<T>;
101 for (
size_t iter = 0; iter < max_iter; ++iter) {
116 alpha, beta, res.
S_ipk, res.
S_pjk, alpha_pk, S_ppk, param_vec[0].b);
Structure to hold the parameters for the Allocation Model.
Definition: alloc_model_params.hpp:25
void update_logH(const matrix_t< Scalar > &beta, const matrix_t< T > &S_pjk, const Scalar b, const PsiFunction &psi_fn, matrix_t< double > &logH)
Perform an update on the logH matrix.
Definition: online_EM_funcs.hpp:221
EMResult< T > online_EM(const matrix_t< T > &X, const std::vector< alloc_model::Params< Scalar >> &param_vec, const size_t max_iter=1000, const bool use_psi_appr=false)
Complete a matrix containing unobserved values given as NaN using an EM procedure according to the allocation model.
Definition: online_EM.hpp:50
void check_EM_params(const matrix_t< T > &X, const std::vector< alloc_model::Params< Scalar >> &param_vec)
Do parameter checks on EM algorithms for solving the BLD problem.
Definition: util_details.hpp:39
double delta_log_PS(const matrix_t< Scalar > &alpha, const matrix_t< Scalar > &beta, const matrix_t< T > &S_ipk, const matrix_t< T > &S_pjk, const vector_t< Scalar > &alpha_pk, const vector_t< T > &S_ppk, Scalar b)
Compute the change in the log_PS value using the given model variables.
Definition: online_EM_funcs.hpp:364
Eigen::Matrix< Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor > matrix_t
Matrix type used in the computations.
Definition: defs.hpp:41
Structure holding the results of EM procedures.
Definition: online_EM_defs.hpp:13
matrix_t< double > logH
Matrix whose $(k, j)^{th}$ entry contains $\log{H_{kj}}$.
Definition: online_EM_defs.hpp:41
std::tuple< std::vector< size_t >, std::vector< size_t >, std::vector< T > > find_nonzero(const matrix_t< T > &X)
Find nonzero entries and their indices in the given matrix and return indices and values as vectors...
Definition: online_EM_funcs.hpp:251
matrix_t< T > X_full
Completed version of the incomplete matrix given as input to an EM algorithm.
Definition: online_EM_defs.hpp:33
matrix_t< T > S_ipk
Sum of the hidden tensor along its second dimension, i.e. $S_{i+k} = \sum_j S_{ijk}$.
Definition: online_EM_defs.hpp:24
matrix_t< double > logW
Matrix whose $(i, k)^{th}$ entry contains $\log{W_{ik}}$.
Definition: online_EM_defs.hpp:37
matrix_t< T > S_pjk
Sum of the hidden tensor along its first dimension, i.e. $S_{+jk} = \sum_i S_{ijk}$.
Definition: online_EM_defs.hpp:19
size_t init_nan_values(matrix_t< T > &X)
Initialize all NaN values in the given matrix with the mean of the remaining values that are differen...
Definition: online_EM_funcs.hpp:28
std::tuple< matrix_t< T >, matrix_t< T >, vector_t< T > > init_S_xx(const matrix_t< T > &X_full, size_t z, const std::vector< size_t > &ii, const std::vector< size_t > &jj)
Initialize S_pjk, S_ipk matrices and S_ppk vector from a Dirichlet distribution with all parameters e...
Definition: online_EM_funcs.hpp:120
std::pair< matrix_t< Scalar >, matrix_t< Scalar > > init_alpha_beta(const std::vector< alloc_model::Params< Scalar >> &param_vec, size_t y)
Initialize each entry of alpha and beta matrices with the given model parameters. ...
Definition: online_EM_funcs.hpp:77
vector_t< double > log_PS
Vector containing EM bound computed after every iteration.
Definition: online_EM_defs.hpp:45
Eigen::Matrix< Scalar, 1, Eigen::Dynamic, Eigen::RowMajor > vector_t
Vector type used in the computations.
Definition: defs.hpp:27
Main namespace for the bnmf-algs library.
Definition: alloc_model_funcs.hpp:12
double update_allocation(const std::vector< size_t > &ii, const std::vector< size_t > &jj, const std::vector< T > &xx, bld::EMResult< T > &res, vector_t< T > &S_ppk)
Update the current allocation by performing a maximization step for each nonzero entry of the original matrix.
Definition: online_EM_funcs.hpp:285
void update_logW(const matrix_t< Scalar > &alpha, const matrix_t< T > &S_ipk, const vector_t< Scalar > &alpha_pk, const vector_t< T > &S_ppk, const PsiFunction &psi_fn, matrix_t< double > &logW)
Perform an update on the logW matrix.
Definition: online_EM_funcs.hpp:186