diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index a8dd3286b07d15e1e98f5945417316e7c77ee70f..79e1610fab97c682ba3743922995111cc66fea11 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -26,7 +26,7 @@ publish_package: script: - pip install twine - python setup.py bdist_wheel - - TWINE_PASSWORD=${TWINE_PASSWORD} TWINE_USERNAME=${TWINE_USERNAME} python -m twine upload dist/* + - TWINE_PASSWORD=${pypln_token} TWINE_USERNAME=${account_name} python -m twine upload dist/* tags: - docker only: diff --git a/pyPLNmodels/__init__.py b/pyPLNmodels/__init__.py index 15591263ad54a35729ad9b724a7467f8be9b30b5..d895c7cd30d9f39e8069cb6793b48a4ddb2dac79 100644 --- a/pyPLNmodels/__init__.py +++ b/pyPLNmodels/__init__.py @@ -1,6 +1,12 @@ from .models import PLNPCA, PLN # pylint:disable=[C0114] from .elbos import profiled_elbo_pln, elbo_plnpca, elbo_pln -from ._utils import get_simulated_count_data, get_real_count_data +from ._utils import ( + get_simulated_count_data, + get_real_count_data, + load_model, + load_plnpca, + load_pln, +) __all__ = ( "PLNPCA", @@ -10,4 +16,7 @@ __all__ = ( "elbo_pln", "get_simulated_count_data", "get_real_count_data", + "load_model", + "load_plnpca", + "load_pln", ) diff --git a/pyPLNmodels/_closed_forms.py b/pyPLNmodels/_closed_forms.py index 889cf221a5711e7d30c8c5ec4b7dfee49eedd517..0dc9e7fab2b975949f515733508f915b103bf284 100644 --- a/pyPLNmodels/_closed_forms.py +++ b/pyPLNmodels/_closed_forms.py @@ -3,7 +3,11 @@ import torch # pylint:disable=[C0114] def closed_formula_covariance(covariates, latent_mean, latent_var, coef, n_samples): """Closed form for covariance for the M step for the noPCA model.""" - m_moins_xb = latent_mean - covariates @ coef + if covariates is None: + XB = 0 + else: + XB = covariates @ coef + m_moins_xb = latent_mean - XB closed = m_moins_xb.T @ m_moins_xb + torch.diag( torch.sum(torch.square(latent_var), dim=0) ) @@ -12,6 +16,8 @@ def closed_formula_covariance(covariates, latent_mean, latent_var, coef, n_sampl def closed_formula_coef(covariates, latent_mean): """Closed form for coef for the M step for the noPCA model.""" + if covariates is None: + return None return torch.inverse(covariates.T @ covariates) @ covariates.T @ latent_mean diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py index 1024746c6dd76601317650532dd981c24d4c8ba7..0e079ce52f357a6be07403d83c44734add074dd7 100644 --- a/pyPLNmodels/_utils.py +++ b/pyPLNmodels/_utils.py @@ -1,5 +1,6 @@ import math # pylint:disable=[C0114] import warnings +import os import matplotlib.pyplot as plt import numpy as np @@ -8,6 +9,7 @@ import torch.linalg as TLA import pandas as pd from matplotlib.patches import Ellipse from matplotlib import transforms +from patsy import dmatrices torch.set_default_dtype(torch.float64) @@ -81,7 +83,7 @@ class PLNPlotArgs: ax.legend() -def init_sigma(counts, covariates, coef): +def init_covariance(counts, covariates, coef): """Initialization for covariance for the PLN model. Take the log of counts (careful when counts=0), remove the covariates effects X@coef and then do as a MLE for Gaussians samples. @@ -93,9 +95,7 @@ def init_sigma(counts, covariates, coef): Returns : torch.tensor of size (p,p). 
""" log_y = torch.log(counts + (counts == 0) * math.exp(-2)) - log_y_centered = ( - log_y - torch.matmul(covariates.unsqueeze(1), coef.unsqueeze(0)).squeeze() - ) + log_y_centered = log_y - torch.mean(log_y, axis=0) # MLE in a Gaussian setting n_samples = counts.shape[0] sigma_hat = 1 / (n_samples - 1) * (log_y_centered.T) @ log_y_centered @@ -115,7 +115,7 @@ def init_components(counts, covariates, coef, rank): Returns : torch.tensor of size (p,rank). The initialization of components. """ - sigma_hat = init_sigma(counts, covariates, coef).detach() + sigma_hat = init_covariance(counts, covariates, coef).detach() components = components_from_covariance(sigma_hat, rank) return components @@ -188,15 +188,11 @@ def sample_pln(components, coef, covariates, offsets, _coef_inflation=None, seed torch.random.manual_seed(seed) n_samples = offsets.shape[0] rank = components.shape[1] - full_of_ones = torch.ones((n_samples, 1)) if covariates is None: - covariates = full_of_ones + XB = 0 else: - covariates = torch.stack((full_of_ones, covariates), axis=1).squeeze() - gaussian = ( - torch.mm(torch.randn(n_samples, rank, device=DEVICE), components.T) - + covariates @ coef - ) + XB = covariates @ coef + gaussian = torch.mm(torch.randn(n_samples, rank, device=DEVICE), components.T) + XB parameter = torch.exp(offsets + gaussian) if _coef_inflation is not None: print("ZIPLN is sampled") @@ -235,13 +231,12 @@ def components_from_covariance(covariance, rank): return requested_components -def init_coef(counts, covariates): - log_y = torch.log(counts + (counts == 0) * math.exp(-2)) - log_y = log_y.to(DEVICE) - return torch.matmul( - torch.inverse(torch.matmul(covariates.T, covariates)), - torch.matmul(covariates.T, log_y), - ) +def init_coef(counts, covariates, offsets): + if covariates is None: + return None + poiss_reg = PoissonReg() + poiss_reg.fit(counts, covariates, offsets) + return poiss_reg.beta def log_stirling(integer): @@ -275,8 +270,11 @@ def log_posterior(counts, covariates, offsets, posterior_mean, components, coef) components_posterior_mean = torch.matmul( components.unsqueeze(0), posterior_mean.unsqueeze(2) ).squeeze() - - log_lambda = offsets + components_posterior_mean + covariates @ coef + if covariates is None: + XB = 0 + else: + XB = covariates @ coef + log_lambda = offsets + components_posterior_mean + XB first_term = ( -rank / 2 * math.log(2 * math.pi) - 1 / 2 * torch.norm(posterior_mean, dim=-1) ** 2 @@ -321,7 +319,21 @@ def check_two_dimensions_are_equal( ) +def init_S(counts, covariates, offsets, beta, C, M): + n, rank = M.shape + batch_matrix = torch.matmul(C.unsqueeze(2), C.unsqueeze(1)).unsqueeze(0) + CW = torch.matmul(C.unsqueeze(0), M.unsqueeze(2)).squeeze() + common = torch.exp(offsets + covariates @ beta + CW).unsqueeze(2).unsqueeze(3) + prod = batch_matrix * common + hess_posterior = torch.sum(prod, axis=1) + torch.eye(rank).to(DEVICE) + inv_hess_posterior = -torch.inverse(hess_posterior) + hess_posterior = torch.diagonal(inv_hess_posterior, dim1=-2, dim2=-1) + return hess_posterior + + def format_data(data): + if data is None: + return None if isinstance(data, pd.DataFrame): return torch.from_numpy(data.values).double().to(DEVICE) if isinstance(data, np.ndarray): @@ -335,7 +347,8 @@ def format_data(data): def format_model_param(counts, covariates, offsets, offsets_formula): counts = format_data(counts) - covariates = prepare_covariates(covariates, counts.shape[0]) + if covariates is not None: + covariates = format_data(covariates) if offsets is None: if offsets_formula == 
"logsum": print("Setting the offsets as the log of the sum of counts") @@ -349,20 +362,26 @@ def format_model_param(counts, covariates, offsets, offsets_formula): return counts, covariates, offsets -def prepare_covariates(covariates, n_samples): - full_of_ones = torch.full((n_samples, 1), 1, device=DEVICE).double() - if covariates is None: - return full_of_ones +def remove_useless_intercepts(covariates): covariates = format_data(covariates) - return torch.concat((full_of_ones, covariates), axis=1) + if covariates.shape[1] < 2: + return covariates + first_column = covariates[:, 0] + second_column = covariates[:, 1] + diff = first_column - second_column + if torch.sum(torch.abs(diff - diff[0])) == 0: + print("removing one") + return covariates[:, 1:] + return covariates def check_data_shape(counts, covariates, offsets): n_counts, p_counts = counts.shape n_offsets, p_offsets = offsets.shape - n_cov, _ = covariates.shape check_two_dimensions_are_equal("counts", "offsets", n_counts, n_offsets, 0) - check_two_dimensions_are_equal("counts", "covariates", n_counts, n_cov, 0) + if covariates is not None: + n_cov, _ = covariates.shape + check_two_dimensions_are_equal("counts", "covariates", n_counts, n_cov, 0) check_two_dimensions_are_equal("counts", "offsets", p_counts, p_offsets, 1) @@ -417,13 +436,13 @@ def get_components_simulation(dim, rank): def get_simulation_offsets_cov_coef(n_samples, nb_cov, dim): prev_state = torch.random.get_rng_state() torch.random.manual_seed(0) - if nb_cov < 2: + if nb_cov == 0: covariates = None else: covariates = torch.randint( low=-1, high=2, - size=(n_samples, nb_cov - 1), + size=(n_samples, nb_cov), dtype=torch.float64, device=DEVICE, ) @@ -471,9 +490,66 @@ def closest(lst, element): return lst[idx] -def check_dimensions_are_equal(tens1, tens2): - if tens1.shape[0] != tens2.shape[0] or tens1.shape[1] != tens2.shape[1]: - raise ValueError("Tensors should have the same size.") +class PoissonReg: + """Poisson regressor class.""" + + def __init__(self): + """No particular initialization is needed.""" + pass + + def fit(self, Y, covariates, O, Niter_max=300, tol=0.001, lr=0.005, verbose=False): + """Run a gradient ascent to maximize the log likelihood, using + pytorch autodifferentiation. The log likelihood considered is + the one from a poisson regression model. It is roughly the + same as PLN without the latent layer Z. + + Args: + Y: torch.tensor. Counts with size (n,p) + 0: torch.tensor. Offset, size (n,p) + covariates: torch.tensor. Covariates, size (n,d) + Niter_max: int, optional. The maximum number of iteration. + Default is 300. + tol: non negative float, optional. The tolerance criteria. + Will stop if the norm of the gradient is less than + or equal to this threshold. Default is 0.001. + lr: positive float, optional. Learning rate for the gradient ascent. + Default is 0.005. + verbose: bool, optional. If True, will print some stats. + + Returns : None. Update the parameter beta. You can access it + by calling self.beta. 
+ """ + # Initialization of beta of size (d,p) + beta = torch.rand( + (covariates.shape[1], Y.shape[1]), device=DEVICE, requires_grad=True + ) + optimizer = torch.optim.Rprop([beta], lr=lr) + i = 0 + grad_norm = 2 * tol # Criterion + while i < Niter_max and grad_norm > tol: + loss = -compute_poissreg_log_like(Y, O, covariates, beta) + loss.backward() + optimizer.step() + grad_norm = torch.norm(beta.grad) + beta.grad.zero_() + i += 1 + if verbose: + if i % 10 == 0: + print("log like : ", -loss) + print("grad_norm : ", grad_norm) + if i < Niter_max: + print("Tolerance reached in {} iterations".format(i)) + else: + print("Maxium number of iterations reached") + self.beta = beta + + +def compute_poissreg_log_like(Y, O, covariates, beta): + """Compute the log likelihood of a Poisson regression.""" + # Matrix multiplication of X and beta. + XB = torch.matmul(covariates.unsqueeze(1), beta.unsqueeze(0)).squeeze() + # Returns the formula of the log likelihood of a poisson regression model. + return torch.sum(-torch.exp(O + XB) + torch.multiply(Y, O + XB)) def to_tensor(obj): @@ -484,3 +560,84 @@ def to_tensor(obj): if isinstance(obj, pd.DataFrame): return torch.from_numpy(obj.values) raise TypeError("Please give either a nd.array or torch.Tensor or pd.DataFrame") + + +def check_dimensions_are_equal(tens1, tens2): + if tens1.shape[0] != tens2.shape[0] or tens1.shape[1] != tens2.shape[1]: + raise ValueError("Tensors should have the same size.") + + +def load_model(path_of_directory): + working_dict = os.getcwd() + os.chdir(path_of_directory) + all_files = os.listdir() + data = {} + for filename in all_files: + if len(filename) > 4: + if filename[-4:] == ".csv": + parameter = filename[:-4] + try: + data[parameter] = pd.read_csv(filename, header=None).values + except pd.errors.EmptyDataError as err: + print( + f"Can't load {parameter} since empty. Standard initialization will be performed" + ) + os.chdir(working_dict) + return data + + +def load_pln(path_of_directory): + return load_model(path_of_directory) + + +def load_plnpca(path_of_directory, ranks=None): + working_dict = os.getcwd() + os.chdir(path_of_directory) + if ranks is None: + dirnames = os.listdir() + ranks = [] + for dirname in dirnames: + try: + rank = int(dirname[-1]) + except ValueError: + raise ValueError( + f"Can't load the model {dirname}. End of {dirname} should be an int" + ) + ranks.append(rank) + datas = {} + for rank in ranks: + datas[rank] = load_model(f"_PLNPCA_rank_{rank}") + os.chdir(working_dict) + return datas + + +def check_right_rank(data, rank): + data_rank = data["latent_mean"].shape[1] + if data_rank != rank: + raise RuntimeError( + f"Wrong rank during initialization." + f" Got rank {rank} and data with rank {data_rank}." 
+ ) + + +def extract_data_from_formula(formula, data): + dmatrix = dmatrices(formula, data=data) + counts = dmatrix[0] + covariates = dmatrix[1] + print("covariates size:", covariates.size) + if covariates.size == 0: + covariates = None + offsets = data.get("offsets", None) + return counts, covariates, offsets + + +def is_dict_of_dict(dictionnary): + if isinstance(dictionnary[list(dictionnary.keys())[0]], dict): + return True + return False + + +def get_dict_initialization(rank, dict_of_dict): + if dict_of_dict is None: + return None + return dict_of_dict[rank] diff --git a/pyPLNmodels/elbos.py b/pyPLNmodels/elbos.py index 160468cfc407dd5e271dde305a9db1aa02063fa6..082bcdfaa943478ed1cf9999ce8016256130f930 100644 --- a/pyPLNmodels/elbos.py +++ b/pyPLNmodels/elbos.py @@ -22,7 +22,11 @@ def elbo_pln(counts, covariates, offsets, latent_mean, latent_var, covariance, c n_samples, dim = counts.shape s_rond_s = torch.square(latent_var) offsets_plus_m = offsets + latent_mean - m_minus_xb = latent_mean - covariates @ coef + if covariates is None: + XB = 0 + else: + XB = covariates @ coef + m_minus_xb = latent_mean - XB d_plus_minus_xb2 = ( torch.diag(torch.sum(s_rond_s, dim=0)) + m_minus_xb.T @ m_minus_xb ) @@ -90,7 +94,11 @@ def elbo_plnpca(counts, covariates, offsets, latent_mean, latent_var, components """ n_samples = counts.shape[0] rank = components.shape[1] - log_intensity = offsets + covariates @ coef + latent_mean @ components.T + if covariates is None: + XB = 0 + else: + XB = covariates @ coef + log_intensity = offsets + XB + latent_mean @ components.T s_rond_s = torch.square(latent_var) counts_log_intensity = torch.sum(counts * log_intensity) minus_intensity_plus_s_rond_s_cct = torch.sum( diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py index c4171d6aab7437a1cb56e7159090803ae32e990d..1be04cc1508ea93065ade6298c4392338e224085 100644 --- a/pyPLNmodels/models.py +++ b/pyPLNmodels/models.py @@ -3,6 +3,8 @@ from abc import ABC, abstractmethod import pickle import warnings import os +from functools import singledispatchmethod +from collections.abc import Iterable import pandas as pd import torch @@ -10,6 +12,7 @@ import numpy as np import seaborn as sns import matplotlib.pyplot as plt from sklearn.decomposition import PCA +from patsy import dmatrices from ._closed_forms import ( @@ -20,7 +23,7 @@ from ._closed_forms import ( from .elbos import elbo_plnpca, elbo_zi_pln, profiled_elbo_pln from ._utils import ( PLNPlotArgs, - init_sigma, + init_covariance, init_components, init_coef, check_two_dimensions_are_equal, @@ -31,9 +34,12 @@ from ._utils import ( nice_string_of_dict, plot_ellipse, closest, - prepare_covariates, to_tensor, check_dimensions_are_equal, + check_right_rank, + remove_useless_intercepts, + extract_data_from_formula, + get_dict_initialization, ) if torch.cuda.is_available(): @@ -68,17 +74,50 @@ class _PLN(ABC): _latent_var: torch.Tensor _latent_mean: torch.Tensor - def __init__(self): + @singledispatchmethod + def __init__( + self, + counts, + covariates=None, + offsets=None, + offsets_formula="logsum", + dict_initialization=None, + ): """ Simple initialization method. 
""" - self._fitted = False - self.plotargs = PLNPlotArgs(self.WINDOW) - def format_model_param(self, counts, covariates, offsets, offsets_formula): self._counts, self._covariates, self._offsets = format_model_param( counts, covariates, offsets, offsets_formula ) + check_data_shape(self._counts, self._covariates, self._offsets) + self._fitted = False + self.plotargs = PLNPlotArgs(self.WINDOW) + if dict_initialization is not None: + self.set_init_parameters(dict_initialization) + + @__init__.register(str) + def _( + self, + formula: str, + data: dict, + offsets_formula="logsum", + dict_initialization=None, + ): + counts, covariates, offsets = extract_data_from_formula(formula, data) + self.__init__(counts, covariates, offsets, offsets_formula, dict_initialization) + + def set_init_parameters(self, dict_initialization): + if "coef" not in dict_initialization.keys(): + print("No coef is initialized.") + self.coef = None + for key, array in dict_initialization.items(): + array = format_data(array) + setattr(self, key, array) + + @property + def fitted(self): + return @property def nb_iteration_done(self): @@ -94,12 +133,16 @@ class _PLN(ABC): @property def nb_cov(self): + if self.covariates is None: + return 0 return self.covariates.shape[1] def smart_init_coef(self): - self._coef = init_coef(self._counts, self._covariates) + self._coef = init_coef(self._counts, self._covariates, self._offsets) def random_init_coef(self): + if self.nb_cov == 0: + self._coef = None self._coef = torch.randn((self.nb_cov, self.dim), device=DEVICE) @abstractmethod @@ -140,17 +183,12 @@ class _PLN(ABC): def fit( self, - counts, - covariates=None, - offsets=None, nb_max_iteration=50000, lr=0.01, class_optimizer=torch.optim.Rprop, - tol=1e-6, + tol=1e-3, do_smart_init=True, verbose=False, - offsets_formula="logsum", - keep_going=False, ): """ Main function of the class. Fit a PLN to the data. 
@@ -168,11 +206,9 @@ class _PLN(ABC): self.print_beginning_message() self.beginnning_time = time.time() - if keep_going is False: - self.format_model_param(counts, covariates, offsets, offsets_formula) - check_data_shape(self._counts, self._covariates, self._offsets) + if self._fitted is False: self.init_parameters(do_smart_init) - if self._fitted is True and keep_going is True: + else: self.beginnning_time -= self.plotargs.running_times[-1] self.optim = class_optimizer(self.list_of_parameters_needing_gradient, lr=lr) stop_condition = False @@ -215,8 +251,8 @@ class _PLN(ABC): def print_end_of_fitting_message(self, stop_condition, tol): if stop_condition is True: print( - f"Tolerance {tol} reached" - f"n {self.plotargs.iteration_number} iterations" + f"Tolerance {tol} reached " + f"in {self.plotargs.iteration_number} iterations" ) else: print( @@ -360,7 +396,11 @@ class _PLN(ABC): @property def model_in_a_dict(self): - return self.dict_data | self.model_parameters | self.latent_parameters + return self.dict_data | self.dict_parameters + + @property + def dict_parameters(self): + return self.model_parameters | self.latent_parameters @property def coef(self): @@ -391,28 +431,18 @@ class _PLN(ABC): return None def save(self, path_of_directory="./"): - path = f"{path_of_directory}/{self.model_path}/" + path = f"{path_of_directory}/{self.path_to_directory}{self.directory_name}" os.makedirs(path, exist_ok=True) - for key, value in self.model_in_a_dict.items(): + for key, value in self.dict_parameters.items(): filename = f"{path}/{key}.csv" if isinstance(value, torch.Tensor): pd.DataFrame(np.array(value.cpu().detach())).to_csv( filename, header=None, index=None ) - else: + elif value is not None: pd.DataFrame(np.array([value])).to_csv( filename, header=None, index=None ) - self._fitted = True - - def load(self, path_of_directory="./"): - path = f"{path_of_directory}/{self.model_path}/" - for key, value in self.model_in_a_dict.items(): - value = torch.from_numpy( - pd.read_csv(path + key + ".csv", header=None).values - ) - setattr(self, key, value) - self.put_parameters_to_device() @property def counts(self): @@ -473,13 +503,26 @@ class _PLN(ABC): return self.covariance def predict(self, covariates=None): - if isinstance(covariates, torch.Tensor): - if covariates.shape[-1] != self.nb_cov - 1: - error_string = f"X has wrong shape ({covariates.shape}).Should" - error_string += f" be ({self.n_samples, self.nb_cov-1})." - raise RuntimeError(error_string) - covariates_with_ones = prepare_covariates(covariates, self.n_samples) - return covariates_with_ones @ self.coef + if covariates is not None and self.nb_cov == 0: + raise AttributeError("No covariates in the model, can't predict") + if covariates is None: + if self.covariates is None: + print("No covariates in the model.") + return None + return self.covariates @ self.coef + if covariates.shape[-1] != self.nb_cov: + error_string = f"X has wrong shape ({covariates.shape}).Should" + error_string += f" be ({self.n_samples, self.nb_cov})." 
+ raise RuntimeError(error_string) + return covariates @ self.coef + + @property + def directory_name(self): + return f"{self.NAME}_nbcov_{self.nb_cov}_dim_{self.dim}" + + @property + def path_to_directory(self): + return "" # need to do a good init for M and S @@ -505,12 +548,10 @@ class PLN(_PLN): self.random_init_latent_parameters() def random_init_latent_parameters(self): - self._latent_var = 1 / 2 * torch.ones((self.n_samples, self.dim)).to(DEVICE) - self._latent_mean = torch.ones((self.n_samples, self.dim)).to(DEVICE) - - @property - def model_path(self): - return self.NAME + if not hasattr(self, "_latent_var"): + self._latent_var = 1 / 2 * torch.ones((self.n_samples, self.dim)).to(DEVICE) + if not hasattr(self, "_latent_mean"): + self._latent_mean = torch.ones((self.n_samples, self.dim)).to(DEVICE) @property def list_of_parameters_needing_gradient(self): @@ -589,31 +630,121 @@ class PLN(_PLN): pass +## en train d'essayer de faire une seule init pour_PLNPCA class PLNPCA: - def __init__(self, ranks): - if isinstance(ranks, (list, np.ndarray)): - self.ranks = ranks - self.dict_models = {} + NAME = "PLNPCA" + + @singledispatchmethod + def __init__( + self, + counts, + covariates=None, + offsets=None, + offsets_formula="logsum", + ranks=range(3, 5), + dict_of_dict_initialization=None, + ): + self.init_data(counts, covariates, offsets, offsets_formula) + self.init_models(ranks, dict_of_dict_initialization) + + def init_data(self, counts, covariates, offsets, offsets_formula): + self._counts, self._covariates, self._offsets = format_model_param( + counts, covariates, offsets, offsets_formula + ) + check_data_shape(self._counts, self._covariates, self._offsets) + self._fitted = False + + @__init__.register(str) + def _( + self, + formula: str, + data: dict, + offsets_formula="logsum", + ranks=range(3, 5), + dict_of_dict_initialization=None, + ): + counts, covariates, offsets = extract_data_from_formula(formula, data) + self.__init__( + counts, + covariates, + offsets, + offsets_formula, + ranks, + dict_of_dict_initialization, + ) + + @property + def covariates(self): + return self.list_models[0].covariates + + @property + def counts(self): + return self.list_models[0].counts + + @counts.setter + def counts(self, counts): + counts = format_data(counts) + if hasattr(self, "_counts"): + check_dimensions_are_equal(self._counts, counts) + self._counts = counts + + @covariates.setter + def covariates(self, covariates): + covariates = format_data(covariates) + # if hasattr(self,) + self._covariates = covariates + + @property + def offsets(self): + return self.list_models[0].offsets + + def init_models(self, ranks, dict_of_dict_initialization): + if isinstance(ranks, (Iterable, np.ndarray)): + self.list_models = [] for rank in ranks: - if isinstance(rank, (int, np.int64)): - self.dict_models[rank] = _PLNPCA(rank) + if isinstance(rank, (int, np.integer)): + dict_initialization = get_dict_initialization( + rank, dict_of_dict_initialization + ) + self.list_models.append( + _PLNPCA( + self._counts, + self._covariates, + self._offsets, + rank, + dict_initialization, + ) + ) else: raise TypeError( - "Please instantiate with either a list\ - of integers or an integer." + f"Please instantiate with either a list " + f"of integers or an integer." 
) - elif isinstance(ranks, int): - self.ranks = [ranks] - self.dict_models = {ranks: _PLNPCA(ranks)} + elif isinstance(ranks, (int, np.integer)): + dict_initialization = get_dict_initialization( + ranks, dict_of_dict_initialization + ) + self.list_models = [ + _PLNPCA( + self._counts, + self._covariates, + self._offsets, + rank, + dict_initialization, + ) + ] else: raise TypeError( - "Please instantiate with either a list of \ - integers or an integer." + f"Please instantiate with either a list " f"of integers or an integer." ) @property - def models(self): - return list(self.dict_models.values()) + def ranks(self): + return [model.rank for model in self.list_models] + + @property + def dict_models(self): + return {model.rank: model for model in self.list_models} def print_beginning_message(self): return f"Adjusting {len(self.ranks)} PLN models for PCA analysis \n" @@ -622,39 +753,30 @@ class PLNPCA: def dim(self): return self[self.ranks[0]].dim + @property + def nb_cov(self): + return self[self.ranks[0]].nb_cov + ## should do something for this weird init. pb: if doing the init of self._counts etc ## only in PLNPCA, then we don't do it for each _PLNPCA but then PLN is not doing it. def fit( self, - counts, - covariates=None, - offsets=None, nb_max_iteration=100000, lr=0.01, class_optimizer=torch.optim.Rprop, - tol=1e-6, + tol=1e-3, do_smart_init=True, verbose=False, - offsets_formula="logsum", - keep_going=False, ): self.print_beginning_message() - counts, _, offsets = format_model_param( - counts, covariates, offsets, offsets_formula - ) for pca in self.dict_models.values(): pca.fit( - counts, - covariates, - offsets, nb_max_iteration, lr, class_optimizer, tol, do_smart_init, verbose, - None, - keep_going, ) self.print_ending_message() @@ -681,15 +803,15 @@ class PLNPCA: @property def BIC(self): - return {model.rank: int(model.BIC) for model in self.models} + return {model.rank: int(model.BIC) for model in self.list_models} @property def AIC(self): - return {model.rank: int(model.AIC) for model in self.models} + return {model.rank: int(model.AIC) for model in self.list_models} @property def loglikes(self): - return {model.rank: model.loglike for model in self.models} + return {model.rank: model.loglike for model in self.list_models} def show(self): bic = self.BIC @@ -730,24 +852,31 @@ class PLNPCA: return self[self.best_AIC_model_rank] raise ValueError(f"Unknown criterion {criterion}") - def save(self, path_of_directory="./"): - for model in self.models: - model.save(path_of_directory) + def save(self, path_of_directory="./", ranks=None): + if ranks is None: + ranks = self.ranks + for model in self.list_models: + if model.rank in ranks: + model.save(path_of_directory) - def load(self, path_of_directory="./"): - for model in self.models: - model.load(path_of_directory) + @property + def directory_name(self): + return f"{self.NAME}_nbcov_{self.nb_cov}_dim_{self.dim}" @property def n_samples(self): - return self.models[0].n_samples + return self.list_models[0].n_samples @property def _p(self): return self[self.ranks[0]].p + @property + def models(self): + return self.dict_models.values() + def __str__(self): - nb_models = len(self.models) + nb_models = len(self.list_models) delimiter = "\n" + "-" * NB_CHARACTERS_FOR_NICE_PLOT + "\n" to_print = delimiter to_print += f"Collection of {nb_models} PLNPCA models with \ @@ -780,26 +909,47 @@ class PLNPCA: return ".BIC, .AIC, .loglikes" +# Here, setting the value for each key in dict_parameters class _PLNPCA(_PLN): - NAME = "PLNPCA" + NAME = 
"_PLNPCA" _components: torch.Tensor - def __init__(self, rank): - super().__init__() + @singledispatchmethod + def __init__(self, counts, covariates, offsets, rank, dict_initialization=None): self._rank = rank + self._counts, self._covariates, self._offsets = format_model_param( + counts, covariates, offsets, None + ) + check_data_shape(self._counts, self._covariates, self._offsets) + self.check_if_rank_is_too_high() + if dict_initialization is not None: + self.set_init_parameters(dict_initialization) + self._fitted = False + self.plotargs = PLNPlotArgs(self.WINDOW) - def init_parameters(self, do_smart_init): - if self.dim < self._rank: - warning_string = f"\nThe requested rank of approximation {self._rank} \ - is greater than the number of variables {self.dim}. \ - Setting rank to {self.dim}" + @__init__.register(str) + def _(self, formula, data, rank, dict_initialization): + counts, covariates, offsets = extract_data_from_formula(formula, data) + self.__init__(counts, covariates, offsets, rank, dict_initialization) + + def check_if_rank_is_too_high(self): + if self.dim < self.rank: + warning_string = ( + f"\nThe requested rank of approximation {self.rank} " + f"is greater than the number of variables {self.dim}. " + f"Setting rank to {self.dim}" + ) warnings.warn(warning_string) self._rank = self.dim - super().init_parameters(do_smart_init) @property - def model_path(self): - return f"{self.NAME}_{self._rank}_rank" + def directory_name(self): + return f"{self.NAME}_rank_{self._rank}" + # return f"PLNPCA_nbcov_{self.nb_cov}_dim_{self.dim}/{self.NAME}_rank_{self._rank}" + + @property + def path_to_directory(self): + return f"PLNPCA_nbcov_{self.nb_cov}_dim_{self.dim}/" @property def rank(self): @@ -817,10 +967,12 @@ class _PLNPCA(_PLN): return {"coef": self.coef, "components": self.components} def smart_init_model_parameters(self): - super().smart_init_coef() - self._components = init_components( - self._counts, self._covariates, self._coef, self._rank - ) + if not hasattr(self, "_coef"): + super().smart_init_coef() + if not hasattr(self, "_components"): + self._components = init_components( + self._counts, self._covariates, self._coef, self._rank + ) def random_init_model_parameters(self): super().random_init_coef() @@ -831,23 +983,27 @@ class _PLNPCA(_PLN): self._latent_mean = torch.ones((self.n_samples, self._rank)).to(DEVICE) def smart_init_latent_parameters(self): - self._latent_mean = ( - init_latent_mean( - self._counts, - self._covariates, - self._offsets, - self._coef, - self._components, + if not hasattr(self, "_latent_mean"): + self._latent_mean = ( + init_latent_mean( + self._counts, + self._covariates, + self._offsets, + self._coef, + self._components, + ) + .to(DEVICE) + .detach() + ) + if not hasattr(self, "_latent_var"): + self._latent_var = ( + 1 / 2 * torch.ones((self.n_samples, self._rank)).to(DEVICE) ) - .to(DEVICE) - .detach() - ) - self._latent_var = 1 / 2 * torch.ones((self.n_samples, self._rank)).to(DEVICE) - self._latent_mean.requires_grad_(True) - self._latent_var.requires_grad_(True) @property def list_of_parameters_needing_gradient(self): + if self._coef is None: + return [self._components, self._latent_mean, self._latent_var] return [self._components, self._coef, self._latent_mean, self._latent_var] def compute_elbo(self): @@ -896,6 +1052,16 @@ class _PLNPCA(_PLN): ortho_components = torch.linalg.qr(self._components, "reduced")[0] return torch.mm(self.latent_variables, ortho_components).detach().cpu() + def pca_projected_latent_variables(self, 
n_components=None): + if n_components is None: + n_components = self.get_max_components() + if n_components > self.dim: + raise RuntimeError( + f"You ask more components ({n_components}) than variables ({self.dim})" + ) + pca = PCA(n_components=n_components) + return pca.fit_transform(self.projected_latent_variables.detach().cpu()) + @property def components(self): return self.attribute_or_none("_components") @@ -905,13 +1071,16 @@ class _PLNPCA(_PLN): self._components = components def viz(self, ax=None, colors=None): - if self._rank != 2: - raise RuntimeError("Can't perform visualization for rank != 2.") if ax is None: ax = plt.gca() - proj_variables = self.projected_latent_variables - x = proj_variables[:, 0].cpu().numpy() - y = proj_variables[:, 1].cpu().numpy() + if self._rank < 2: + raise RuntimeError("Can't perform visualization for rank < 2.") + if self._rank > 2: + proj_variables = self.pca_projected_latent_variables(n_components=2) + if self._rank == 2: + proj_variables = self.projected_latent_variables.cpu().numpy() + x = proj_variables[:, 0] + y = proj_variables[:, 1] sns.scatterplot(x=x, y=y, hue=colors, ax=ax) covariances = torch.diag_embed(self._latent_var**2).detach().cpu() for i in range(covariances.shape[0]): @@ -943,10 +1112,12 @@ class ZIPLN(PLN): # should change the good initialization, especially for _coef_inflation def smart_init_model_parameters(self): super().smart_init_model_parameters() - self._covariance = init_sigma( - self._counts, self._covariates, self._offsets, self._coef - ) - self._coef_inflation = torch.randn(self.nb_cov, self.dim) + if not hasattr(self, "_covariance"): + self._covariance = init_covariance( + self._counts, self._covariates, self._coef + ) + if not hasattr(self, "_coef_inflation"): + self._coef_inflation = torch.randn(self.nb_cov, self.dim) def random_init_latent_parameters(self): self._dirac = self._counts == 0 diff --git a/setup.py b/setup.py index 74cc89090df8d4d49d7a9b99b19d1696b04d3de1..9c8b27450e6254d2ddc1dd060b31c3724054d812 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- from setuptools import setup, find_packages -VERSION = "0.0.37" +VERSION = "0.0.38" with open("README.md", "r") as fh: long_description = fh.read() diff --git a/test.py b/test.py index 900e42df2c9dd4070e663d898b93b219d4fa3734..b9d5e686d2a75cc14ad36c59459700ddf463bbe0 100644 --- a/test.py +++ b/test.py @@ -2,6 +2,8 @@ from pyPLNmodels.models import PLNPCA, _PLNPCA, PLN from pyPLNmodels import get_real_count_data, get_simulated_count_data import os +import pandas as pd +import numpy as np os.chdir("./pyPLNmodels/") @@ -11,12 +13,15 @@ covariates = None offsets = None # counts, covariates, offsets = get_simulated_count_data(seed = 0) -pca = PLNPCA([3, 4]) +pca = PLNPCA(counts, covariates, offsets, ranks=[3, 4]) +pca.fit(tol=0.1) -pca.fit(counts, covariates, offsets, tol=0.1) -print(pca) - -# pln = PLN() +# pca.fit() +# print(pca) +# data = pd.DataFrame(counts) +# pln = PLN("counts~1", data) +# pln.fit() +# print(pln) # pcamodel = pca.best_model() # pcamodel.save() # model = PLNPCA([4])[4] diff --git a/tests/conftest.py b/tests/conftest.py index 1df7d25bb7e0e401e6b471e678330ef839d98959..904240c9611cc87e63c8495600d21bfdfe94b6a8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,414 @@ import sys +import glob +from functools import singledispatch + +import pytest +import torch +from pytest_lazyfixture import lazy_fixture as lf +from pyPLNmodels import load_model, load_plnpca +from pyPLNmodels.models import PLN, _PLNPCA, PLNPCA 
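+# Fixture overview: each of PLN, _PLNPCA (wrapped by convenient_plnpca at a fixed RANK) and
+# PLNPCA (wrapped by convenientplnpca over RANKS) is built from the simulated 0-covariate and
+# 2-covariate data and from the real counts, through both the array and the patsy formula
+# interfaces. Fitted and saved-then-reloaded variants are derived from these, and every fixture
+# is registered in dict_fixtures so the tests can parametrize over named groups such as
+# "fitted_pln" or "loaded_and_fitted_pln".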
+ sys.path.append("../") + +pytest_plugins = [ + fixture_file.replace("/", ".").replace(".py", "") + for fixture_file in glob.glob("src/**/tests/fixtures/[!__]*.py", recursive=True) +] + + +from tests.import_data import ( + data_sim_0cov, + data_sim_2cov, + data_real, +) + + +counts_sim_0cov = data_sim_0cov["counts"] +covariates_sim_0cov = data_sim_0cov["covariates"] +offsets_sim_0cov = data_sim_0cov["offsets"] + +counts_sim_2cov = data_sim_2cov["counts"] +covariates_sim_2cov = data_sim_2cov["covariates"] +offsets_sim_2cov = data_sim_2cov["offsets"] + +counts_real = data_real["counts"] + + +def add_fixture_to_dict(my_dict, string_fixture): + my_dict[string_fixture] = [lf(string_fixture)] + return my_dict + + +def add_list_of_fixture_to_dict( + my_dict, name_of_list_of_fixtures, list_of_string_fixtures +): + my_dict[name_of_list_of_fixtures] = [] + for string_fixture in list_of_string_fixtures: + my_dict[name_of_list_of_fixtures].append(lf(string_fixture)) + return my_dict + + +RANK = 8 +RANKS = [2, 6] +instances = [] +# dict_fixtures_models = [] + + +@singledispatch +def convenient_plnpca( + counts, + covariates=None, + offsets=None, + offsets_formula=None, + dict_initialization=None, +): + return _PLNPCA( + counts, covariates, offsets, rank=RANK, dict_initialization=dict_initialization + ) + + +@convenient_plnpca.register(str) +def _(formula, data, offsets_formula=None, dict_initialization=None): + return _PLNPCA(formula, data, rank=RANK, dict_initialization=dict_initialization) + + +@singledispatch +def convenientplnpca( + counts, + covariates=None, + offsets=None, + offsets_formula=None, + dict_initialization=None, +): + return PLNPCA( + counts, + covariates, + offsets, + offsets_formula, + dict_of_dict_initialization=dict_initialization, + ranks=RANKS, + ) + + +@convenientplnpca.register(str) +def _(formula, data, offsets_formula=None, dict_initialization=None): + return PLNPCA( + formula, + data, + offsets_formula, + ranks=RANKS, + dict_of_dict_initialization=dict_initialization, + ) + + +def generate_new_model(model, *args, **kwargs): + name_dir = model.directory_name + name = model.NAME + if name in ("PLN", "_PLNPCA"): + path = model.path_to_directory + name_dir + init = load_model(path) + if name == "PLN": + new = PLN(*args, **kwargs, dict_initialization=init) + if name == "_PLNPCA": + new = convenient_plnpca(*args, **kwargs, dict_initialization=init) + if name == "PLNPCA": + init = load_plnpca(name_dir) + new = convenientplnpca(*args, **kwargs, dict_initialization=init) + return new + + +def cache(func): + dict_cache = {} + + def new_func(request): + if request.param.__name__ not in dict_cache: + dict_cache[request.param.__name__] = func(request) + return dict_cache[request.param.__name__] + + return new_func + + +params = [PLN, convenient_plnpca, convenientplnpca] +dict_fixtures = {} + + +@pytest.fixture(params=params) +def simulated_pln_0cov_array(request): + cls = request.param + pln = cls(counts_sim_0cov, covariates_sim_0cov, offsets_sim_0cov) + return pln + + +@pytest.fixture(params=params) +@cache +def simulated_fitted_pln_0cov_array(request): + cls = request.param + pln = cls(counts_sim_0cov, covariates_sim_0cov, offsets_sim_0cov) + pln.fit() + return pln + + +@pytest.fixture(params=params) +def simulated_pln_0cov_formula(request): + cls = request.param + pln = cls("counts ~ 0", data_sim_0cov) + return pln + + +@pytest.fixture(params=params) +@cache +def simulated_fitted_pln_0cov_formula(request): + cls = request.param + pln = cls("counts ~ 0", data_sim_0cov) + pln.fit() + 
return pln + + +@pytest.fixture +def simulated_loaded_pln_0cov_formula(simulated_fitted_pln_0cov_formula): + simulated_fitted_pln_0cov_formula.save() + return generate_new_model( + simulated_fitted_pln_0cov_formula, + "counts ~ 0", + data_sim_0cov, + ) + + +@pytest.fixture +def simulated_loaded_pln_0cov_array(simulated_fitted_pln_0cov_array): + simulated_fitted_pln_0cov_array.save() + return generate_new_model( + simulated_fitted_pln_0cov_array, + counts_sim_0cov, + covariates_sim_0cov, + offsets_sim_0cov, + ) + + +sim_pln_0cov_instance = [ + "simulated_pln_0cov_array", + "simulated_pln_0cov_formula", +] + +instances = sim_pln_0cov_instance + instances + +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "sim_pln_0cov_instance", sim_pln_0cov_instance +) + +sim_pln_0cov_fitted = [ + "simulated_fitted_pln_0cov_array", + "simulated_fitted_pln_0cov_formula", +] + +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "sim_pln_0cov_fitted", sim_pln_0cov_fitted +) + +sim_pln_0cov_loaded = [ + "simulated_loaded_pln_0cov_array", + "simulated_loaded_pln_0cov_formula", +] + +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "sim_pln_0cov_loaded", sim_pln_0cov_loaded +) + +sim_pln_0cov = sim_pln_0cov_instance + sim_pln_0cov_fitted + sim_pln_0cov_loaded +dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "sim_pln_0cov", sim_pln_0cov) + + +@pytest.fixture(params=params) +@cache +def simulated_pln_2cov_array(request): + cls = request.param + pln_full = cls(counts_sim_2cov, covariates_sim_2cov, offsets_sim_2cov) + return pln_full + + +@pytest.fixture +def simulated_fitted_pln_2cov_array(simulated_pln_2cov_array): + simulated_pln_2cov_array.fit() + return simulated_pln_2cov_array + + +@pytest.fixture(params=params) +@cache +def simulated_pln_2cov_formula(request): + cls = request.param + pln_full = cls("counts ~ 0 + covariates", data_sim_2cov) + return pln_full + + +@pytest.fixture +def simulated_fitted_pln_2cov_formula(simulated_pln_2cov_formula): + simulated_pln_2cov_formula.fit() + return simulated_pln_2cov_formula + + +@pytest.fixture +def simulated_loaded_pln_2cov_formula(simulated_fitted_pln_2cov_formula): + simulated_fitted_pln_2cov_formula.save() + return generate_new_model( + simulated_fitted_pln_2cov_formula, + "counts ~0 + covariates", + data_sim_2cov, + ) + + +@pytest.fixture +def simulated_loaded_pln_2cov_array(simulated_fitted_pln_2cov_array): + simulated_fitted_pln_2cov_array.save() + return generate_new_model( + simulated_fitted_pln_2cov_array, + counts_sim_2cov, + covariates_sim_2cov, + offsets_sim_2cov, + ) + + +sim_pln_2cov_instance = [ + "simulated_pln_2cov_array", + "simulated_pln_2cov_formula", +] +instances = sim_pln_2cov_instance + instances + +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "sim_pln_2cov_instance", sim_pln_2cov_instance +) + +sim_pln_2cov_fitted = [ + "simulated_fitted_pln_2cov_array", + "simulated_fitted_pln_2cov_formula", +] + +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "sim_pln_2cov_fitted", sim_pln_2cov_fitted +) + +sim_pln_2cov_loaded = [ + "simulated_loaded_pln_2cov_array", + "simulated_loaded_pln_2cov_formula", +] + +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "sim_pln_2cov_loaded", sim_pln_2cov_loaded +) + +sim_pln_2cov = sim_pln_2cov_instance + sim_pln_2cov_fitted + sim_pln_2cov_loaded +dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "sim_pln_2cov", sim_pln_2cov) + + +@pytest.fixture(params=params) +@cache +def real_pln_intercept_array(request): + cls = 
request.param + pln_full = cls(counts_real, covariates=torch.ones((counts_real.shape[0], 1))) + return pln_full + + +@pytest.fixture +def real_fitted_pln_intercept_array(real_pln_intercept_array): + real_pln_intercept_array.fit() + return real_pln_intercept_array + + +@pytest.fixture(params=params) +@cache +def real_pln_intercept_formula(request): + cls = request.param + pln_full = cls("counts ~ 1", data_real) + return pln_full + + +@pytest.fixture +def real_fitted_pln_intercept_formula(real_pln_intercept_formula): + real_pln_intercept_formula.fit() + return real_pln_intercept_formula + + +@pytest.fixture +def real_loaded_pln_intercept_formula(real_fitted_pln_intercept_formula): + real_fitted_pln_intercept_formula.save() + return generate_new_model( + real_fitted_pln_intercept_formula, "counts ~ 1", data_real + ) + + +@pytest.fixture +def real_loaded_pln_intercept_array(real_fitted_pln_intercept_array): + real_fitted_pln_intercept_array.save() + return generate_new_model( + real_fitted_pln_intercept_array, + counts_real, + covariates=torch.ones((counts_real.shape[0], 1)), + ) + + +real_pln_instance = [ + "real_pln_intercept_array", + "real_pln_intercept_formula", +] +instances = real_pln_instance + instances + +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "real_pln_instance", real_pln_instance +) + +real_pln_fitted = [ + "real_fitted_pln_intercept_array", + "real_fitted_pln_intercept_formula", +] +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "real_pln_fitted", real_pln_fitted +) + +real_pln_loaded = [ + "real_loaded_pln_intercept_array", + "real_loaded_pln_intercept_formula", +] +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "real_pln_loaded", real_pln_loaded +) + +sim_loaded_pln = sim_pln_0cov_loaded + sim_pln_2cov_loaded + +loaded_pln = real_pln_loaded + sim_loaded_pln +dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "loaded_pln", loaded_pln) + +simulated_pln_fitted = sim_pln_0cov_fitted + sim_pln_2cov_fitted +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "simulated_pln_fitted", simulated_pln_fitted +) +fitted_pln = real_pln_fitted + simulated_pln_fitted +dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "fitted_pln", fitted_pln) + + +loaded_and_fitted_sim_pln = simulated_pln_fitted + sim_loaded_pln +loaded_and_fitted_real_pln = real_pln_fitted + real_pln_loaded +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "loaded_and_fitted_real_pln", loaded_and_fitted_real_pln +) +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "loaded_and_fitted_sim_pln", loaded_and_fitted_sim_pln +) +loaded_and_fitted_pln = fitted_pln + loaded_pln +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "loaded_and_fitted_pln", loaded_and_fitted_pln +) + +real_pln = real_pln_instance + real_pln_fitted + real_pln_loaded +dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "real_pln", real_pln) + +sim_pln = sim_pln_2cov + sim_pln_0cov +dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "sim_pln", sim_pln) + +all_pln = real_pln + sim_pln + instances +dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "instances", instances) +dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "all_pln", all_pln) + + +for string_fixture in all_pln: + print("string_fixture", string_fixture) + dict_fixtures = add_fixture_to_dict(dict_fixtures, string_fixture) diff --git a/tests/import_data.py b/tests/import_data.py new file mode 100644 index 
0000000000000000000000000000000000000000..80e613ba12fc5a92d6a8187c65f9bcfd32ee23b9 --- /dev/null +++ b/tests/import_data.py @@ -0,0 +1,39 @@ +import os + +from pyPLNmodels import ( + get_simulated_count_data, + get_real_count_data, +) + + +( + counts_sim_0cov, + covariates_sim_0cov, + offsets_sim_0cov, + true_covariance_0cov, + true_coef_0cov, +) = get_simulated_count_data(return_true_param=True, nb_cov=0) +( + counts_sim_2cov, + covariates_sim_2cov, + offsets_sim_2cov, + true_covariance_2cov, + true_coef_2cov, +) = get_simulated_count_data(return_true_param=True, nb_cov=2) + +data_sim_0cov = { + "counts": counts_sim_0cov, + "covariates": covariates_sim_0cov, + "offsets": offsets_sim_0cov, +} +true_sim_0cov = {"Sigma": true_covariance_0cov, "beta": true_coef_0cov} +true_sim_2cov = {"Sigma": true_covariance_2cov, "beta": true_coef_2cov} + + +data_sim_2cov = { + "counts": counts_sim_2cov, + "covariates": covariates_sim_2cov, + "offsets": offsets_sim_2cov, +} +counts_real = get_real_count_data() +data_real = {"counts": counts_real} diff --git a/tests/import_fixtures.py b/tests/import_fixtures.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_args.py b/tests/test_args.py deleted file mode 100644 index 16c8a73d7f0c354e7691ed0b66c734518fe083ca..0000000000000000000000000000000000000000 --- a/tests/test_args.py +++ /dev/null @@ -1,51 +0,0 @@ -import os - -from pyPLNmodels.models import PLN, PLNPCA, _PLNPCA -from pyPLNmodels import get_simulated_count_data, get_real_count_data -import pytest -from pytest_lazyfixture import lazy_fixture as lf -import pandas as pd -import numpy as np - -( - counts_sim, - covariates_sim, - offsets_sim, -) = get_simulated_count_data(nb_cov=2) - -couts_real = get_real_count_data(n_samples=298, dim=101) -RANKS = [2, 8] - - -@pytest.fixture -def instance_plnpca(): - plnpca = PLNPCA(ranks=RANKS) - return plnpca - - -@pytest.fixture -def instance__plnpca(): - model = _PLNPCA(rank=RANKS[0]) - return model - - -@pytest.fixture -def instance_pln_full(): - return PLN() - - -all_instances = [lf("instance_plnpca"), lf("instance__plnpca"), lf("instance_pln_full")] - - -@pytest.mark.parametrize("instance", all_instances) -def test_pandas_init(instance): - instance.fit( - pd.DataFrame(counts_sim.numpy()), - pd.DataFrame(covariates_sim.numpy()), - pd.DataFrame(offsets_sim.numpy()), - ) - - -@pytest.mark.parametrize("instance", all_instances) -def test_numpy_init(instance): - instance.fit(counts_sim.numpy(), covariates_sim.numpy(), offsets_sim.numpy()) diff --git a/tests/test_common.py b/tests/test_common.py index 7e0e6c955a074e8169b643f014bca59583e32d01..b74cf9fbae73676d7c5bcf34643324c63e07ad4a 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,192 +1,30 @@ -import torch -import numpy as np -import pandas as pd - -from pyPLNmodels.models import PLN, _PLNPCA -from pyPLNmodels import get_simulated_count_data, get_real_count_data -from tests.utils import MSE - -import pytest -from pytest_lazyfixture import lazy_fixture as lf import os -( - counts_sim, - covariates_sim, - offsets_sim, - true_covariance, - true_coef, -) = get_simulated_count_data(return_true_param=True, nb_cov=2) - - -counts_real = get_real_count_data() -rank = 8 - - -@pytest.fixture -def instance_pln_full(): - pln_full = PLN() - return pln_full - - -@pytest.fixture -def instance__plnpca(): - plnpca = _PLNPCA(rank=rank) - return plnpca - - -@pytest.fixture -def simulated_fitted_pln_full(): - pln_full = PLN() - 
pln_full.fit(counts=counts_sim, covariates=covariates_sim, offsets=offsets_sim) - return pln_full - - -@pytest.fixture -def simulated_fitted__plnpca(): - plnpca = _PLNPCA(rank=rank) - plnpca.fit(counts=counts_sim, covariates=covariates_sim, offsets=offsets_sim) - return plnpca - - -@pytest.fixture -def loaded_simulated_pln_full(simulated_fitted_pln_full): - simulated_fitted_pln_full.save() - loaded_pln_full = PLN() - loaded_pln_full.load() - return loaded_pln_full - - -@pytest.fixture -def loaded_refit_simulated_pln_full(loaded_simulated_pln_full): - loaded_simulated_pln_full.fit( - counts=counts_sim, - covariates=covariates_sim, - offsets=offsets_sim, - keep_going=True, - ) - return loaded_simulated_pln_full - - -@pytest.fixture -def loaded_simulated__plnpca(simulated_fitted__plnpca): - simulated_fitted__plnpca.save() - loaded_pln_full = _PLNPCA(rank=rank) - loaded_pln_full.load() - return loaded_pln_full - - -@pytest.fixture -def loaded_refit_simulated__plnpca(loaded_simulated__plnpca): - loaded_simulated__plnpca.fit( - counts=counts_sim, - covariates=covariates_sim, - offsets=offsets_sim, - keep_going=True, - ) - return loaded_simulated__plnpca - - -@pytest.fixture -def real_fitted_pln_full(): - pln_full = PLN() - pln_full.fit(counts=counts_real) - return pln_full - - -@pytest.fixture -def loaded_real_pln_full(real_fitted_pln_full): - real_fitted_pln_full.save() - loaded_pln_full = PLN() - loaded_pln_full.load() - return loaded_pln_full - - -@pytest.fixture -def loaded_refit_real_pln_full(loaded_real_pln_full): - loaded_real_pln_full.fit(counts=counts_real, keep_going=True) - return loaded_real_pln_full - - -@pytest.fixture -def real_fitted__plnpca(): - plnpca = _PLNPCA(rank=rank) - plnpca.fit(counts=counts_real) - return plnpca - - -@pytest.fixture -def loaded_real__plnpca(real_fitted__plnpca): - real_fitted__plnpca.save() - loaded_plnpca = _PLNPCA(rank=rank) - loaded_plnpca.load() - return loaded_plnpca - - -@pytest.fixture -def loaded_refit_real__plnpca(loaded_real__plnpca): - loaded_real__plnpca.fit(counts=counts_real, keep_going=True) - return loaded_real__plnpca - - -real_pln_full = [ - lf("real_fitted_pln_full"), - lf("loaded_real_pln_full"), - lf("loaded_refit_real_pln_full"), -] -real__plnpca = [ - lf("real_fitted__plnpca"), - lf("loaded_real__plnpca"), - lf("loaded_refit_real__plnpca"), -] -simulated_pln_full = [ - lf("simulated_fitted_pln_full"), - lf("loaded_simulated_pln_full"), - lf("loaded_refit_simulated_pln_full"), -] -simulated__plnpca = [ - lf("simulated_fitted__plnpca"), - lf("loaded_simulated__plnpca"), - lf("loaded_refit_simulated__plnpca"), -] - -loaded_sim_pln = [ - lf("loaded_simulated__plnpca"), - lf("loaded_simulated_pln_full"), - lf("loaded_refit_simulated_pln_full"), - lf("loaded_refit_simulated_pln_full"), -] - - -@pytest.mark.parametrize("loaded", loaded_sim_pln) -def test_refit_not_keep_going(loaded): - loaded.fit( - counts=counts_sim, - covariates=covariates_sim, - offsets=offsets_sim, - keep_going=False, - ) - - -all_instances = [lf("instance__plnpca"), lf("instance_pln_full")] +import torch +import pytest -all_fitted__plnpca = simulated__plnpca + real__plnpca -all_fitted_pln_full = simulated_pln_full + real_pln_full +from tests.conftest import dict_fixtures +from tests.utils import MSE, filter_models -simulated_any_pln = simulated__plnpca + simulated_pln_full -real_any_pln = real_pln_full + real__plnpca -all_fitted_models = simulated_any_pln + real_any_pln +from tests.import_data import true_sim_0cov, true_sim_2cov -@pytest.mark.parametrize("any_pln", 
all_fitted_models) +@pytest.mark.parametrize("any_pln", dict_fixtures["fitted_pln"]) +@filter_models(["PLN", "_PLNPCA"]) def test_properties(any_pln): - assert hasattr(any_pln, "latent_variables") - assert hasattr(any_pln, "model_parameters") assert hasattr(any_pln, "latent_parameters") + assert hasattr(any_pln, "latent_variables") assert hasattr(any_pln, "optim_parameters") + assert hasattr(any_pln, "model_parameters") + + +@pytest.mark.parametrize("any_pln", dict_fixtures["loaded_and_fitted_pln"]) +def test_print(any_pln): + print(any_pln) -@pytest.mark.parametrize("any_pln", all_fitted_models) +@pytest.mark.parametrize("any_pln", dict_fixtures["fitted_pln"]) +@filter_models(["PLN", "_PLNPCA"]) def test_show_coef_transform_covariance_pcaprojected(any_pln): any_pln.show() any_pln.plotargs.show_loss() @@ -200,137 +38,80 @@ def test_show_coef_transform_covariance_pcaprojected(any_pln): any_pln.pca_projected_latent_variables(n_components=any_pln.dim + 1) -@pytest.mark.parametrize("sim_pln", simulated_any_pln) +@pytest.mark.parametrize("sim_pln", dict_fixtures["loaded_and_fitted_pln"]) +@filter_models(["PLN", "_PLNPCA"]) def test_predict_simulated(sim_pln): - X = torch.randn((sim_pln.n_samples, sim_pln.nb_cov - 1)) - prediction = sim_pln.predict(X) - expected = ( - torch.stack((torch.ones(sim_pln.n_samples, 1), X), axis=1).squeeze() - @ sim_pln.coef - ) - assert torch.all(torch.eq(expected, prediction)) - - -@pytest.mark.parametrize("real_pln", real_any_pln) -def test_predict_real(real_pln): - prediction = real_pln.predict() - expected = torch.ones(real_pln.n_samples, 1) @ real_pln.coef - assert torch.all(torch.eq(expected, prediction)) - - -@pytest.mark.parametrize("any_pln", all_fitted_models) -def test_print(any_pln): - print(any_pln) - - -@pytest.mark.parametrize("any_instance_pln", all_instances) + if sim_pln.nb_cov == 0: + assert sim_pln.predict() is None + with pytest.raises(AttributeError): + sim_pln.predict(1) + else: + X = torch.randn((sim_pln.n_samples, sim_pln.nb_cov)) + prediction = sim_pln.predict(X) + expected = X @ sim_pln.coef + assert torch.all(torch.eq(expected, prediction)) + + +@pytest.mark.parametrize("any_instance_pln", dict_fixtures["instances"]) def test_verbose(any_instance_pln): - any_instance_pln.fit( - counts=counts_sim, covariates=covariates_sim, offsets=offsets_sim, verbose=True - ) - + any_instance_pln.fit(verbose=True, tol=0.1) -@pytest.mark.parametrize("sim_pln", simulated_any_pln) -def test_only_counts(sim_pln): - sim_pln.fit(counts=counts_sim) - -@pytest.mark.parametrize("sim_pln", simulated_any_pln) -def test_only_counts_and_offsets(sim_pln): - sim_pln.fit(counts=counts_sim, offsets=offsets_sim) - - -@pytest.mark.parametrize("sim_pln", simulated_any_pln) -def test_only_Y_and_cov(sim_pln): - sim_pln.fit(counts=counts_sim, covariates=covariates_sim) - - -@pytest.mark.parametrize("simulated_fitted_any_pln", simulated_any_pln) +@pytest.mark.parametrize( + "simulated_fitted_any_pln", dict_fixtures["loaded_and_fitted_sim_pln"] +) +@filter_models(["PLN", "_PLNPCA"]) def test_find_right_covariance(simulated_fitted_any_pln): + if simulated_fitted_any_pln.nb_cov == 0: + true_covariance = true_sim_0cov["Sigma"] + elif simulated_fitted_any_pln.nb_cov == 2: + true_covariance = true_sim_2cov["Sigma"] mse_covariance = MSE(simulated_fitted_any_pln.covariance - true_covariance) assert mse_covariance < 0.05 -@pytest.mark.parametrize("sim_pln", simulated_any_pln) -def test_find_right_coef(sim_pln): - mse_coef = MSE(sim_pln.coef - true_coef) - assert mse_coef < 0.1 
+@pytest.mark.parametrize(
+    "real_fitted_and_loaded_pln", dict_fixtures["loaded_and_fitted_real_pln"]
+)
+@filter_models(["PLN", "_PLNPCA"])
+def test_right_covariance_shape(real_fitted_and_loaded_pln):
+    assert real_fitted_and_loaded_pln.covariance.shape == (100, 100)


-def test_number_of_iterations_pln_full(simulated_fitted_pln_full):
-    nb_iterations = len(simulated_fitted_pln_full.elbos_list)
-    assert 50 < nb_iterations < 300


+@pytest.mark.parametrize(
+    "simulated_fitted_any_pln", dict_fixtures["loaded_and_fitted_pln"]
+)
+@filter_models(["PLN", "_PLNPCA"])
+def test_find_right_coef(simulated_fitted_any_pln):
+    if simulated_fitted_any_pln.nb_cov == 2:
+        true_coef = true_sim_2cov["beta"]
+        mse_coef = MSE(simulated_fitted_any_pln.coef - true_coef)
+        assert mse_coef < 0.1
+    elif simulated_fitted_any_pln.nb_cov == 0:
+        assert simulated_fitted_any_pln.coef is None


-def test_computable_elbopca(instance__plnpca, simulated_fitted__plnpca):
-    instance__plnpca.counts = simulated_fitted__plnpca.counts
-    instance__plnpca.covariates = simulated_fitted__plnpca.covariates
-    instance__plnpca.offsets = simulated_fitted__plnpca.offsets
-    instance__plnpca.latent_mean = simulated_fitted__plnpca.latent_mean
-    instance__plnpca.latent_var = simulated_fitted__plnpca.latent_var
-    instance__plnpca.components = simulated_fitted__plnpca.components
-    instance__plnpca.coef = simulated_fitted__plnpca.coef
-    instance__plnpca.compute_elbo()
-
-
-def test_computable_elbo_full(instance_pln_full, simulated_fitted_pln_full):
-    instance_pln_full.counts = simulated_fitted_pln_full.counts
-    instance_pln_full.covariates = simulated_fitted_pln_full.covariates
-    instance_pln_full.offsets = simulated_fitted_pln_full.offsets
-    instance_pln_full.latent_mean = simulated_fitted_pln_full.latent_mean
-    instance_pln_full.latent_var = simulated_fitted_pln_full.latent_var
-    instance_pln_full.covariance = simulated_fitted_pln_full.covariance
-    instance_pln_full.coef = simulated_fitted_pln_full.coef
-    instance_pln_full.compute_elbo()
-
-
-def test_fail_count_setter(simulated_fitted_pln_full):
+@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"])
+@filter_models(["PLN", "_PLNPCA"])
+def test_fail_count_setter(pln):
     wrong_counts = torch.randint(size=(10, 5), low=0, high=10)
     with pytest.raises(Exception):
-        simulated_fitted_pln_full.counts = wrong_counts
-
+        pln.counts = wrong_counts

-@pytest.mark.parametrize("any_pln", all_fitted_models)
-def test_setter_with_numpy(any_pln):
-    np_counts = any_pln.counts.numpy()
-    any_pln.counts = np_counts
-
-@pytest.mark.parametrize("any_pln", all_fitted_models)
-def test_setter_with_pandas(any_pln):
-    pd_counts = pd.DataFrame(any_pln.counts.numpy())
-    any_pln.counts = pd_counts
-
-
-@pytest.mark.parametrize("instance", all_instances)
+@pytest.mark.parametrize("instance", dict_fixtures["instances"])
 def test_random_init(instance):
-    instance.fit(counts_sim, covariates_sim, offsets_sim, do_smart_init=False)
+    instance.fit(do_smart_init=False)


-@pytest.mark.parametrize("instance", all_instances)
+@pytest.mark.parametrize("instance", dict_fixtures["instances"])
 def test_print_end_of_fitting_message(instance):
-    instance.fit(counts_sim, covariates_sim, offsets_sim, nb_max_iteration=4)
+    instance.fit(nb_max_iteration=4)


-@pytest.mark.parametrize("any_pln", all_fitted_models)
-def test_fail_wrong_covariates_prediction(any_pln):
-    X = torch.randn(any_pln.n_samples, any_pln.nb_cov)
+@pytest.mark.parametrize("pln", dict_fixtures["fitted_pln"])
+@filter_models(["PLN", "_PLNPCA"])
+def test_fail_wrong_covariates_prediction(pln):
+    X = torch.randn(pln.n_samples, pln.nb_cov + 1)
     with pytest.raises(Exception):
-        any_pln.predict(X)
-
-
-@pytest.mark.parametrize("any__plnpca", all_fitted__plnpca)
-def test_latent_var_pca(any__plnpca):
-    assert any__plnpca.transform(project=False).shape == any__plnpca.counts.shape
-    assert any__plnpca.transform().shape == (any__plnpca.n_samples, any__plnpca.rank)
-
-
-@pytest.mark.parametrize("any_pln_full", all_fitted_pln_full)
-def test_latent_var_pln_full(any_pln_full):
-    assert any_pln_full.transform().shape == any_pln_full.counts.shape
-
-
-def test_wrong_rank():
-    instance = _PLNPCA(counts_sim.shape[1] + 1)
-    with pytest.warns(UserWarning):
-        instance.fit(counts=counts_sim, covariates=covariates_sim, offsets=offsets_sim)
+        pln.predict(X)
diff --git a/tests/test_pln_full.py b/tests/test_pln_full.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f9d6a5d06dd4e207db265ed5981629d9da19e88
--- /dev/null
+++ b/tests/test_pln_full.py
@@ -0,0 +1,17 @@
+import pytest
+
+from tests.conftest import dict_fixtures
+from tests.utils import filter_models
+
+
+@pytest.mark.parametrize("fitted_pln", dict_fixtures["fitted_pln"])
+@filter_models(["PLN"])
+def test_number_of_iterations_pln_full(fitted_pln):
+    nb_iterations = len(fitted_pln.elbos_list)
+    assert 50 < nb_iterations < 300
+
+
+@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"])
+@filter_models(["PLN"])
+def test_latent_var_full(pln):
+    assert pln.transform().shape == pln.counts.shape
diff --git a/tests/test_plnpca.py b/tests/test_plnpca.py
index db0e324aac2d70ff3b98a517aca8fec9f01b0f4f..9eb1b2f4289efce04cef20a824570e407aede96e 100644
--- a/tests/test_plnpca.py
+++ b/tests/test_plnpca.py
@@ -1,158 +1,76 @@
 import os
 import pytest
-from pytest_lazyfixture import lazy_fixture as lf
-from pyPLNmodels.models import PLNPCA, _PLNPCA
-from pyPLNmodels import get_simulated_count_data, get_real_count_data
-from tests.utils import MSE
-
 import matplotlib.pyplot as plt
 import numpy as np

-(
-    counts_sim,
-    covariates_sim,
-    offsets_sim,
-    true_covariance,
-    true_coef,
-) = get_simulated_count_data(return_true_param=True)
-
-counts_real = get_real_count_data()
-RANKS = [2, 8]
-
-
-@pytest.fixture
-def my_instance_plnpca():
-    plnpca = PLNPCA(ranks=RANKS)
-    return plnpca
-
-
-@pytest.fixture
-def real_fitted_plnpca(my_instance_plnpca):
-    my_instance_plnpca.fit(counts_real)
-    return my_instance_plnpca
-
-
-@pytest.fixture
-def simulated_fitted_plnpca(my_instance_plnpca):
-    my_instance_plnpca.fit(
-        counts=counts_sim, covariates=covariates_sim, offsets=offsets_sim
-    )
-    return my_instance_plnpca
-
-
-@pytest.fixture
-def one_simulated_fitted_plnpca():
-    model = PLNPCA(ranks=2)
-    model.fit(counts=counts_sim, covariates=covariates_sim, offsets=offsets_sim)
-    return model
-
-
-@pytest.fixture
-def real_best_aic(real_fitted_plnpca):
-    return real_fitted_plnpca.best_model("AIC")
-
-
-@pytest.fixture
-def real_best_bic(real_fitted_plnpca):
-    return real_fitted_plnpca.best_model("BIC")
-
-
-@pytest.fixture
-def simulated_best_aic(simulated_fitted_plnpca):
-    return simulated_fitted_plnpca.best_model("AIC")
-
-
-@pytest.fixture
-def simulated_best_bic(simulated_fitted_plnpca):
-    return simulated_fitted_plnpca.best_model("BIC")
+from tests.conftest import dict_fixtures
+from tests.utils import MSE, filter_models


-simulated_best_models = [lf("simulated_best_aic"), lf("simulated_best_bic")]
-real_best_models = [lf("real_best_aic"), lf("real_best_bic")]
-best_models = simulated_best_models + real_best_models
-
-
-all_fitted_simulated_plnpca = [
-    lf("simulated_fitted_plnpca"),
-    lf("one_simulated_fitted_plnpca"),
-]
-all_fitted_plnpca = [lf("real_fitted_plnpca")] + all_fitted_simulated_plnpca
-
-
-def test_print_plnpca(simulated_fitted_plnpca):
-    print(simulated_fitted_plnpca)
-
-
-@pytest.mark.parametrize("best_model", best_models)
-def test_best_model(best_model):
+@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"])
+@filter_models(["PLNPCA"])
+def test_best_model(plnpca):
+    best_model = plnpca.best_model()
     print(best_model)


-@pytest.mark.parametrize("best_model", best_models)
-def test_projected_variables(best_model):
+@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"])
+@filter_models(["PLNPCA"])
+def test_projected_variables(plnpca):
+    best_model = plnpca.best_model()
     plv = best_model.projected_latent_variables
     assert plv.shape[0] == best_model.n_samples and plv.shape[1] == best_model.rank


-def test_save_load_back_and_refit(simulated_fitted_plnpca):
-    simulated_fitted_plnpca.save()
-    new = PLNPCA(ranks=RANKS)
-    new.load()
-    new.fit(counts=counts_sim, covariates=covariates_sim, offsets=offsets_sim)
-
-
-@pytest.mark.parametrize("plnpca", all_fitted_simulated_plnpca)
-def test_find_right_covariance(plnpca):
-    passed = True
-    for model in plnpca.models:
-        mse_covariance = MSE(model.covariance - true_covariance)
-        assert mse_covariance < 0.3
-
-
-@pytest.mark.parametrize("plnpca", all_fitted_simulated_plnpca)
-def test_find_right_coef(plnpca):
-    for model in plnpca.models:
-        mse_coef = MSE(model.coef - true_coef)
-        assert mse_coef < 0.3
-
-
-@pytest.mark.parametrize("all_pca", all_fitted_plnpca)
-def test_additional_methods_pca(all_pca):
-    all_pca.show()
-    all_pca.BIC
-    all_pca.AIC
-    all_pca.loglikes
-
-
-@pytest.mark.parametrize("all_pca", all_fitted_plnpca)
-def test_viz_pca(all_pca):
-    _, ax = plt.subplots()
-    all_pca[2].viz(ax=ax)
-    plt.show()
-    all_pca[2].viz()
-    plt.show()
-    n_samples = all_pca.n_samples
-    colors = np.random.randint(low=0, high=2, size=n_samples)
-    all_pca[2].viz(colors=colors)
-    plt.show()
-
-
-@pytest.mark.parametrize(
-    "pca", [lf("real_fitted_plnpca"), lf("simulated_fitted_plnpca")]
-)
-def test_fails_viz_pca(pca):
-    with pytest.raises(RuntimeError):
-        pca[8].viz()
-
-
-@pytest.mark.parametrize("all_pca", all_fitted_plnpca)
-def test_closest(all_pca):
+@pytest.mark.parametrize("fitted_pln", dict_fixtures["fitted_pln"])
+@filter_models(["_PLNPCA"])
+def test_number_of_iterations_plnpca(fitted_pln):
+    nb_iterations = len(fitted_pln.elbos_list)
+    assert 100 < nb_iterations < 5000
+
+
+@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"])
+@filter_models(["_PLNPCA"])
+def test_latent_var_pca(plnpca):
+    assert plnpca.transform(project=False).shape == plnpca.counts.shape
+    assert plnpca.transform().shape == (plnpca.n_samples, plnpca.rank)
+
+
+@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"])
+@filter_models(["PLNPCA"])
+def test_additional_methods_pca(plnpca):
+    plnpca.show()
+    plnpca.BIC
+    plnpca.AIC
+    plnpca.loglikes
+
+
+@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"])
+@filter_models(["PLNPCA"])
+def test_viz_pca(plnpca):
+    models = plnpca.models
+    for model in models:
+        _, ax = plt.subplots()
+        model.viz(ax=ax)
+        plt.show()
+        model.viz()
+        plt.show()
+        n_samples = plnpca.n_samples
+        colors = np.random.randint(low=0, high=2, size=n_samples)
+        model.viz(colors=colors)
+        plt.show()
+
+
+@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"]) +@filter_models(["PLNPCA"]) +def test_closest(plnpca): with pytest.warns(UserWarning): - all_pca[9] + plnpca[9] -@pytest.mark.parametrize("plnpca", all_fitted_plnpca) +@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"]) +@filter_models(["PLNPCA"]) def test_wrong_criterion(plnpca): with pytest.raises(ValueError): plnpca.best_model("AIK") diff --git a/tests/test_setters.py b/tests/test_setters.py new file mode 100644 index 0000000000000000000000000000000000000000..b1a9ba29c09d21faa2a10bc5dae71dbe368cbd7e --- /dev/null +++ b/tests/test_setters.py @@ -0,0 +1,21 @@ +import pytest +import pandas as pd + +from tests.conftest import dict_fixtures +from tests.utils import MSE, filter_models + + +@pytest.mark.parametrize("pln", dict_fixtures["all_pln"]) +@filter_models(["PLN", "PLNPCA"]) +def test_setter_with_numpy(pln): + np_counts = pln.counts.numpy() + pln.counts = np_counts + pln.fit() + + +@pytest.mark.parametrize("pln", dict_fixtures["all_pln"]) +@filter_models(["PLN", "PLNPCA"]) +def test_setter_with_pandas(pln): + pd_counts = pd.DataFrame(pln.counts.numpy()) + pln.counts = pd_counts + pln.fit() diff --git a/tests/utils.py b/tests/utils.py index 0cc7f2d7b9600b724015896ed97813699e153239..2e33a3d0111519d0e036040a43a38d74ecc3dafa 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,20 @@ import torch +import functools def MSE(t): return torch.mean(t**2) + + +def filter_models(models_name): + def decorator(my_test): + @functools.wraps(my_test) + def new_test(**kwargs): + fixture = next(iter(kwargs.values())) + if type(fixture).__name__ not in models_name: + return None + return my_test(**kwargs) + + return new_test + + return decorator