36 template <
typename T,
typename Scalar>
40 auto x =
static_cast<size_t>(S.dimension(0));
41 auto y =
static_cast<size_t>(S.dimension(1));
42 auto z =
static_cast<size_t>(S.dimension(2));
44 BNMF_ASSERT(model_params.
alpha.
size() == x,
45 "Number of alpha parameters must be equal to S.dimension(0)");
46 BNMF_ASSERT(model_params.
beta.
size() == z,
47 "Number of beta parameters must be equal to z");
57 for (
size_t i = 0; i < x; ++i) {
58 for (
size_t k = 0; k < z; ++k) {
59 W(i, k) = model_params.
alpha[i] + S_ipk(i, k) - 1;
62 for (
size_t k = 0; k < z; ++k) {
63 for (
size_t j = 0; j < y; ++j) {
64 H(k, j) = model_params.
beta[k] + S_pjk(j, k) - 1;
67 for (
size_t j = 0; j < y; ++j) {
68 L(j) = (model_params.
a + S_pjp(j) - 1) / (model_params.
b + 1 + eps);
76 W = W.array().rowwise() / W_colsum.array();
77 H = H.array().rowwise() / H_colsum.array();
Structure to hold the parameters for the Allocation Model.
Definition: alloc_model_params.hpp:25
Eigen::Tensor< Scalar, N, Eigen::RowMajor > tensor_t
Tensor type used in the computations.
Definition: defs.hpp:52
Eigen::array< size_t, N > shape
Shape of vectors, matrices, tensors, etc.
Definition: defs.hpp:66
std::vector< Scalar > alpha
Parameter vector of Dirichlet prior for matrix \(W\) of size \(x \times z\).
Definition: alloc_model_params.hpp:37
Scalar a
Shape parameter of Gamma distribution.
Definition: alloc_model_params.hpp:28
Scalar b
Rate parameter of Gamma distribution.
Definition: alloc_model_params.hpp:32
std::vector< Scalar > beta
Parameter vector of Dirichlet prior for matrix \(H\) of size \(z \times y\).
Definition: alloc_model_params.hpp:42
Eigen::Matrix< Scalar, 1, Eigen::Dynamic, Eigen::RowMajor > vector_t
Vector type used in the computations.
Definition: defs.hpp:27
Main namespace for bnmf-algs library.
Definition: alloc_model_funcs.hpp:12
std::tuple< matrix_t< T >, matrix_t< T >, vector_t< T > > bld_fact(const tensor_t< T, 3 > &S, const alloc_model::Params< Scalar > &model_params, double eps=1e-50)
Compute matrices \(W\), \(H\) and vector \(L\) from tensor \(S\) according to the allocation model.
Definition: bld_fact.hpp:38