bnmf-algs
online_EM.hpp
Go to the documentation of this file.
1 #pragma once
2 
4 #include "bld/util_details.hpp"
5 #include "defs.hpp"
6 #include "online_EM_defs.hpp"
7 #include "online_EM_funcs.hpp"
8 #include "util/util.hpp"
9 #include <cmath>
10 #include <gsl/gsl_sf_gamma.h>
11 #include <gsl/gsl_sf_psi.h>
12 #include <tuple>
13 
14 namespace bnmf_algs {
15 namespace bld {
16 
49 template <typename T, typename Scalar>
50 EMResult<T> online_EM(const matrix_t<T>& X,
51  const std::vector<alloc_model::Params<Scalar>>& param_vec,
52  const size_t max_iter = 1000,
53  const bool use_psi_appr = false) {
54  details::check_EM_params(X, param_vec);
55 
56  const auto x = static_cast<size_t>(X.rows());
57  const auto y = static_cast<size_t>(X.cols());
58  const auto z = static_cast<size_t>(param_vec.size());
59 
60  EMResult<T> res;
61  res.X_full = X;
62 
63  if (details::online_EM::init_nan_values(res.X_full) == 0) {
64  // all the values are NaN
65  return res;
66  }
67 
68  // nonzero elems and indices (including NaN)
69  // these are const, but no good way of making them so in C++14
70  /* const */ std::vector<size_t> ii, jj;
71  /* const */ std::vector<T> xx;
73 
74  // init alpha and beta
75  // these are const, but no good way of making them so in C++14
76  /* const */ matrix_t<Scalar> alpha, beta;
77  std::tie(alpha, beta) = details::online_EM::init_alpha_beta(param_vec, y);
78  const vector_t<Scalar> alpha_pk = alpha.colwise().sum();
79 
80  // init S sums
81  vector_t<T> S_ppk;
82  std::tie(res.S_pjk, res.S_ipk, S_ppk) =
83  details::online_EM::init_S_xx(res.X_full, z, ii, jj);
84 
85  // psi func to use
86  std::function<T(T)> psi_fn =
87  use_psi_appr ? util::psi_appr<T> : details::gsl_psi_wrapper<T>;
88 
89  // init logW and logH
90  res.logW = matrix_t<double>(x, z);
91  res.logH = matrix_t<double>(z, y);
92  details::online_EM::update_logW(alpha, res.S_ipk, alpha_pk, S_ppk, psi_fn,
93  res.logW);
94  details::online_EM::update_logH(beta, res.S_pjk, param_vec[0].b, psi_fn,
95  res.logH);
96 
97  // iteration variables
98  res.log_PS = vector_t<double>::Constant(max_iter, 0);
99 
100  // EM
101  for (size_t iter = 0; iter < max_iter; ++iter) {
102  res.S_pjk.setZero();
103  res.S_ipk.setZero();
104  S_ppk.setZero();
105 
106  double delta_log_PS =
107  details::online_EM::update_allocation(ii, jj, xx, res, S_ppk);
108  res.log_PS(iter) += delta_log_PS;
109 
110  details::online_EM::update_logW(alpha, res.S_ipk, alpha_pk, S_ppk,
111  psi_fn, res.logW);
112  details::online_EM::update_logH(beta, res.S_pjk, param_vec[0].b, psi_fn,
113  res.logH);
114 
115  delta_log_PS = details::online_EM::delta_log_PS(
116  alpha, beta, res.S_ipk, res.S_pjk, alpha_pk, S_ppk, param_vec[0].b);
117  res.log_PS(iter) += delta_log_PS;
118  }
119 
120  return res;
121 }
122 
123 } // namespace bld
124 } // namespace bnmf_algs
void update_logH(const matrix_t< Scalar > &beta, const matrix_t< T > &S_pjk, const Scalar b, const PsiFunction &psi_fn, matrix_t< double > &logH)
Perform an update on logH matrix.
Definition: online_EM_funcs.hpp:221
EMResult< T > online_EM(const matrix_t< T > &X, const std::vector< alloc_model::Params< Scalar >> &param_vec, const size_t max_iter=1000, const bool use_psi_appr=false)
Complete a matrix containing unobserved values given as NaN using an EM procedure according to the al...
Definition: online_EM.hpp:50
T tie(T...args)
void check_EM_params(const matrix_t< T > &X, const std::vector< alloc_model::Params< Scalar >> &param_vec)
Do parameter checks on EM algorithms for solving the BLD problem.
Definition: util_details.hpp:39
double delta_log_PS(const matrix_t< Scalar > &alpha, const matrix_t< Scalar > &beta, const matrix_t< T > &S_ipk, const matrix_t< T > &S_pjk, const vector_t< Scalar > &alpha_pk, const vector_t< T > &S_ppk, Scalar b)
Compute the difference in log_PS value computed using the given model variables.
Definition: online_EM_funcs.hpp:364
std::tuple< std::vector< size_t >, std::vector< size_t >, std::vector< T > > find_nonzero(const matrix_t< T > &X)
Find nonzero entries and their indices in the given matrix and return indices and values as vectors...
Definition: online_EM_funcs.hpp:251
size_t init_nan_values(matrix_t< T > &X)
Initialize all NaN values in the given matrix with the mean of the remaining values that are differen...
Definition: online_EM_funcs.hpp:28
std::tuple< matrix_t< T >, matrix_t< T >, vector_t< T > > init_S_xx(const matrix_t< T > &X_full, size_t z, const std::vector< size_t > &ii, const std::vector< size_t > &jj)
Initialize S_pjk, S_ipk matrices and S_ppk vector from a Dirichlet distribution with all parameters e...
Definition: online_EM_funcs.hpp:120
std::pair< matrix_t< Scalar >, matrix_t< Scalar > > init_alpha_beta(const std::vector< alloc_model::Params< Scalar >> &param_vec, size_t y)
Initialize each entry of alpha and beta matrices with the given model parameters. ...
Definition: online_EM_funcs.hpp:77
Main namespace for bnmf-algs library.
Definition: alloc_model_funcs.hpp:12
double update_allocation(const std::vector< size_t > &ii, const std::vector< size_t > &jj, const std::vector< T > &xx, bld::EMResult< T > &res, vector_t< T > &S_ppk)
Update the current allocation by performing a maximization step for each nonzero entry of the origina...
Definition: online_EM_funcs.hpp:285
void update_logW(const matrix_t< Scalar > &alpha, const matrix_t< T > &S_ipk, const vector_t< Scalar > &alpha_pk, const vector_t< T > &S_ppk, const PsiFunction &psi_fn, matrix_t< double > &logW)
Perform an update on logW matrix.
Definition: online_EM_funcs.hpp:186