7 #include <gsl/gsl_randist.h> 8 #include <gsl/gsl_sf_gamma.h> 13 namespace alloc_model {
15 template <
typename T,
typename Scalar>
33 template <
typename T,
typename Scalar>
36 const long x = S.dimension(0);
37 const long y = S.dimension(1);
38 const long z = S.dimension(2);
41 const double log_gamma_sum =
47 alpha.
begin(), alpha.
end(), log_gamma_alpha.begin(),
48 [](
const Scalar alpha_i) {
return gsl_sf_lngamma(alpha_i); });
50 const double sum_log_gamma =
54 #pragma omp parallel for reduction(+:first) 55 for (
int k = 0; k < z; ++k) {
57 for (
int i = 0; i < x; ++i) {
59 for (
int j = 0; j < y; ++j) {
63 first -= gsl_sf_lngamma(sum);
65 for (
int i = 0; i < x; ++i) {
67 for (
int j = 0; j < y; ++j) {
70 first += gsl_sf_lngamma(sum);
74 double base = log_gamma_sum * z - sum_log_gamma * z;
90 template <
typename T,
typename Scalar>
93 const long x = S.dimension(0);
94 const long y = S.dimension(1);
95 const long z = S.dimension(2);
98 const double log_gamma_sum =
104 [](
const Scalar beta) {
return gsl_sf_lngamma(beta); });
106 const double sum_log_gamma =
110 #pragma omp parallel for reduction(+:second) 111 for (
int j = 0; j < y; ++j) {
113 for (
int k = 0; k < z; ++k) {
115 for (
int i = 0; i < x; ++i) {
119 second -= gsl_sf_lngamma(sum);
121 for (
int k = 0; k < z; ++k) {
123 for (
int i = 0; i < x; ++i) {
126 second += gsl_sf_lngamma(sum);
130 double base = log_gamma_sum * y - sum_log_gamma * y;
131 return base + second;
145 template <
typename T,
typename Scalar>
147 const long x = S.dimension(0);
148 const long y = S.dimension(1);
149 const long z = S.dimension(2);
151 const double log_gamma = -gsl_sf_lngamma(a);
152 const double a_log_b = a *
std::log(b);
155 #pragma omp parallel for reduction(+:third) 156 for (
int j = 0; j < y; ++j) {
158 for (
int i = 0; i < x; ++i) {
159 for (
int k = 0; k < z; ++k) {
163 third += gsl_sf_lngamma(sum);
167 double base = log_gamma * y + a_log_b * y;
180 const long x = S.dimension(0);
181 const long y = S.dimension(1);
182 const long z = S.dimension(2);
185 #pragma omp parallel for reduction(+:fourth) 186 for (
int i = 0; i < x; ++i) {
187 for (
int j = 0; j < y; ++j) {
188 for (
int k = 0; k < z; ++k) {
189 fourth += gsl_sf_lngamma(S(i, j, k) + 1);
215 : X(X), model_params(model_params),
216 S(X.rows(), X.cols(), model_params.beta.size()) {
220 const Integer z = model_params.
beta.
size();
221 for (
long i = 0; i < X.rows(); ++i) {
222 for (
long j = 0; j < X.cols(); ++j) {
234 for (
long i = 0; i < X.rows(); ++i) {
235 for (
long j = 0; j < X.cols(); ++j) {
236 S(i, j, 0) = X(i, j);
244 model_params.
alpha.
end(), Scalar());
255 const double init_log_marginal =
258 return marginal_recursive(0, init_log_marginal);
272 double log_marginal_change_on_increment(
size_t i,
size_t j,
size_t k) {
273 return std::log(model_params.alpha[i] + S_ipk(i, k)) -
275 std::log(model_params.beta[k] + S_pjk(j, k));
288 double log_marginal_change_on_decrement(
size_t i,
size_t j,
size_t k) {
289 double result = -(
std::log(model_params.alpha[i] + S_ipk(i, k) - 1) -
290 std::log(sum_alpha + S_ppk(k) - 1) -
292 std::log(model_params.beta[k] + S_pjk(j, k) - 1));
318 double marginal_recursive(
const size_t fiber_index,
319 double prev_log_marginal) {
322 const auto& part_changes = alloc_vec[fiber_index];
323 const size_t i = ii[fiber_index];
324 const size_t j = jj[fiber_index];
326 const Integer old_value = S(i, j, 0);
329 const bool last_fiber = fiber_index == (ii.size() - 1);
332 last_fiber ?
std::exp(prev_log_marginal)
333 : marginal_recursive(fiber_index + 1, prev_log_marginal);
337 size_t incr_idx, decr_idx;
338 double increment_change, decrement_change, new_log_marginal;
339 for (
const auto& change_idx : part_changes) {
341 std::tie(decr_idx, incr_idx) = change_idx;
343 decrement_change = log_marginal_change_on_decrement(i, j, decr_idx);
347 --S_pjk(j, decr_idx);
348 --S_ipk(i, decr_idx);
351 increment_change = log_marginal_change_on_increment(i, j, incr_idx);
355 ++S_pjk(j, incr_idx);
356 ++S_ipk(i, incr_idx);
360 prev_log_marginal + increment_change + decrement_change;
364 result +=
std::exp(new_log_marginal);
366 result += marginal_recursive(fiber_index + 1, new_log_marginal);
369 prev_log_marginal = new_log_marginal;
373 const auto last_index = S.dimension(2) - 1;
374 S(i, j, last_index) = 0;
375 S(i, j, 0) = old_value;
376 S_pjk(j, last_index) -= old_value;
377 S_pjk(j, 0) += old_value;
378 S_ipk(i, last_index) -= old_value;
379 S_ipk(i, 0) += old_value;
380 S_ppk(last_index) -= old_value;
381 S_ppk(0) += old_value;
439 namespace alloc_model {
461 template <
typename T,
typename Scalar>
464 size_t x = tensor_shape[0], y = tensor_shape[1], z = tensor_shape[2];
466 BNMF_ASSERT(model_params.
alpha.
size() == x,
467 "Number of dirichlet parameters alpha must be equal to x");
468 BNMF_ASSERT(model_params.
beta.
size() == z,
469 "Number of dirichlet parameters beta must be equal to z");
475 for (
size_t i = 0; i < y; ++i) {
477 gsl_ran_gamma(rand_gen.get(), model_params.
a, model_params.
b);
481 matrix_t<T> prior_W(x, z);
483 for (
size_t i = 0; i < z; ++i) {
484 gsl_ran_dirichlet(rand_gen.get(), x, model_params.
alpha.
data(),
485 dirichlet_variates.data());
487 for (
size_t j = 0; j < x; ++j) {
488 prior_W(j, i) = dirichlet_variates(j);
493 matrix_t<T> prior_H(z, y);
495 for (
size_t i = 0; i < y; ++i) {
496 gsl_ran_dirichlet(rand_gen.get(), z, model_params.
beta.
data(),
497 dirichlet_variates.data());
499 for (
size_t j = 0; j < z; ++j) {
500 prior_H(j, i) = dirichlet_variates(j);
527 template <
typename T>
530 auto x =
static_cast<size_t>(prior_W.rows());
531 auto y =
static_cast<size_t>(prior_L.cols());
532 auto z =
static_cast<size_t>(prior_H.rows());
534 BNMF_ASSERT(prior_W.cols() == prior_H.rows(),
535 "Number of columns of W is different than number of rows of H");
536 BNMF_ASSERT(prior_H.cols() == prior_L.cols(),
537 "Number of columns of H is different than size of L");
543 for (
size_t i = 0; i < x; ++i) {
544 for (
size_t j = 0; j < y; ++j) {
545 for (
size_t k = 0; k < z; ++k) {
546 mu = prior_W(i, k) * prior_H(k, j) * prior_L(j);
547 sample(i, j, k) = gsl_ran_poisson(rand_gen.get(), mu);
573 template <
typename T,
typename Scalar>
577 static_cast<size_t>(S.dimension(0)),
578 "Number of alpha parameters must be equal to S.dimension(0)");
579 BNMF_ASSERT(model_params.
beta.
size() ==
static_cast<size_t>(S.dimension(2)),
580 "Number of alpha parameters must be equal to S.dimension(2)");
599 template <
typename Integer,
typename Scalar>
602 BNMF_ASSERT((X.array() >= 0).all(),
603 "X must be nonnegative in alloc_model::total_log_marginal");
604 BNMF_ASSERT(static_cast<size_t>(X.rows()) == model_params.
alpha.
size(),
605 "Model parameters are incompatible with given matrix X in " 606 "alloc_model::total_log_marginal");
609 double marginal = calc.calc_marginal();
Structure to hold the parameters for the Allocation Model.
Definition: alloc_model_params.hpp:25
TotalMarginalCalculator(const matrix_t< Integer > &X, const alloc_model::Params< Scalar > &model_params)
Construct the calculator class and initialize computation variables.
Definition: alloc_model_funcs.hpp:213
Eigen::Tensor< Scalar, N, Eigen::RowMajor > tensor_t
Tensor type used in the computations.
Definition: defs.hpp:52
Eigen::Matrix< Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor > matrix_t
Matrix type used in the computations.
Definition: defs.hpp:41
double total_log_marginal(const matrix_t< Integer > &X, const Params< Scalar > &model_params)
Calculate the total marginal value $p(X \mid \Theta)$, where $\Theta$ denotes the model parameters.
Definition: alloc_model_funcs.hpp:600
Eigen::array< size_t, N > shape
Shape of vectors, matrices, tensors, etc.
Definition: defs.hpp:66
double compute_first_term(const tensor_t< T, 3 > &S, const std::vector< Scalar > &alpha)
Compute the first term of the sum when calculating log marginal of S.
Definition: alloc_model_funcs.hpp:34
Class to compute the total $p(X \mid \Theta)$ for all possible allocation tensors $S$ while storing shared state between functi...
Definition: alloc_model_funcs.hpp:204
std::vector< Scalar > alpha
Parameter vector of the Dirichlet prior for matrix $W$, of size $x$.
Definition: alloc_model_params.hpp:37
double compute_second_term(const tensor_t< T, 3 > &S, const std::vector< Scalar > &beta)
Compute the second term of the sum when calculating log marginal of S.
Definition: alloc_model_funcs.hpp:91
Scalar a
Shape parameter of Gamma distribution.
Definition: alloc_model_params.hpp:28
Scalar b
Rate parameter of Gamma distribution.
Definition: alloc_model_params.hpp:32
double compute_third_term(const tensor_t< T, 3 > &S, Scalar a, Scalar b)
Compute the third term of the sum when calculating log marginal of S.
Definition: alloc_model_funcs.hpp:146
double calc_marginal()
Calculate the total marginal $p(X \mid \Theta)$ by calculating $p(S \mid \Theta)$ for every possible allocation tensor $S$.
Definition: alloc_model_funcs.hpp:254
std::tuple< matrix_t< T >, matrix_t< T >, vector_t< T > > bnmf_priors(const shape< 3 > &tensor_shape, const Params< Scalar > &model_params)
Return prior matrices W, H and vector L according to the Bayesian NMF allocation model using distribu...
Definition: alloc_model_funcs.hpp:463
double log_marginal_S(const tensor_t< T, 3 > &S, const Params< Scalar > &model_params)
Compute the log marginal of tensor S with respect to the given distribution parameters.
Definition: alloc_model_funcs.hpp:574
std::vector< std::pair< size_t, size_t > > partition_change_indices(Integer n, Integer k)
Compute the sequence of indices to be incremented and decremented to generate all partitions of numbe...
Definition: util.hpp:560
std::vector< Scalar > beta
Parameter vector of the Dirichlet prior for matrix $H$, of size $z$.
Definition: alloc_model_params.hpp:42
Eigen::Matrix< Scalar, 1, Eigen::Dynamic, Eigen::RowMajor > vector_t
Vector type used in the computations.
Definition: defs.hpp:27
Main namespace for bnmf-algs library.
Definition: alloc_model_funcs.hpp:12
tensor_t< T, 3 > sample_S(const matrix_t< T > &prior_W, const matrix_t< T > &prior_H, const vector_t< T > &prior_L)
Sample a tensor S from generative Bayesian NMF model using the given priors.
Definition: alloc_model_funcs.hpp:528
double compute_fourth_term(const tensor_t< T, 3 > &S)
Compute the fourth term of the sum when calculating log marginal of S.
Definition: alloc_model_funcs.hpp:179