diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index a8dd3286b07d15e1e98f5945417316e7c77ee70f..79e1610fab97c682ba3743922995111cc66fea11 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -26,7 +26,7 @@ publish_package:
   script:
     - pip install twine
     - python setup.py bdist_wheel
-    - TWINE_PASSWORD=${TWINE_PASSWORD} TWINE_USERNAME=${TWINE_USERNAME} python -m twine upload dist/*
+    - TWINE_PASSWORD=${pypln_token} TWINE_USERNAME=${account_name} python -m twine upload dist/*
   tags:
     - docker
   only:
diff --git a/pyPLNmodels/__init__.py b/pyPLNmodels/__init__.py
index 15591263ad54a35729ad9b724a7467f8be9b30b5..d895c7cd30d9f39e8069cb6793b48a4ddb2dac79 100644
--- a/pyPLNmodels/__init__.py
+++ b/pyPLNmodels/__init__.py
@@ -1,6 +1,12 @@
 from .models import PLNPCA, PLN  # pylint:disable=[C0114]
 from .elbos import profiled_elbo_pln, elbo_plnpca, elbo_pln
-from ._utils import get_simulated_count_data, get_real_count_data
+from ._utils import (
+    get_simulated_count_data,
+    get_real_count_data,
+    load_model,
+    load_plnpca,
+    load_pln,
+)
 
 __all__ = (
     "PLNPCA",
@@ -10,4 +16,7 @@ __all__ = (
     "elbo_pln",
     "get_simulated_count_data",
     "get_real_count_data",
+    "load_model",
+    "load_plnpca",
+    "load_pln",
 )
diff --git a/pyPLNmodels/_closed_forms.py b/pyPLNmodels/_closed_forms.py
index 889cf221a5711e7d30c8c5ec4b7dfee49eedd517..0dc9e7fab2b975949f515733508f915b103bf284 100644
--- a/pyPLNmodels/_closed_forms.py
+++ b/pyPLNmodels/_closed_forms.py
@@ -3,7 +3,11 @@ import torch  # pylint:disable=[C0114]
 
 def closed_formula_covariance(covariates, latent_mean, latent_var, coef, n_samples):
     """Closed form for covariance for the M step for the noPCA model."""
-    m_moins_xb = latent_mean - covariates @ coef
+    if covariates is None:
+        XB = 0
+    else:
+        XB = covariates @ coef
+    m_moins_xb = latent_mean - XB
     closed = m_moins_xb.T @ m_moins_xb + torch.diag(
         torch.sum(torch.square(latent_var), dim=0)
     )
@@ -12,6 +16,8 @@ def closed_formula_covariance(covariates, latent_mean, latent_var, coef, n_sampl
 
 def closed_formula_coef(covariates, latent_mean):
     """Closed form for coef for the M step for the noPCA model."""
+    if covariates is None:
+        return None
     return torch.inverse(covariates.T @ covariates) @ covariates.T @ latent_mean
 
 
diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py
index 1024746c6dd76601317650532dd981c24d4c8ba7..0e079ce52f357a6be07403d83c44734add074dd7 100644
--- a/pyPLNmodels/_utils.py
+++ b/pyPLNmodels/_utils.py
@@ -1,5 +1,6 @@
 import math  # pylint:disable=[C0114]
 import warnings
+import os
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -8,6 +9,7 @@ import torch.linalg as TLA
 import pandas as pd
 from matplotlib.patches import Ellipse
 from matplotlib import transforms
+from patsy import dmatrices
 
 torch.set_default_dtype(torch.float64)
 
@@ -81,7 +83,7 @@ class PLNPlotArgs:
         ax.legend()
 
 
-def init_sigma(counts, covariates, coef):
+def init_covariance(counts, covariates, coef):
     """Initialization for covariance for the PLN model. Take the log of counts
     (careful when counts=0), remove the covariates effects X@coef and
     then do as a MLE for Gaussians samples.
@@ -93,9 +95,7 @@ def init_sigma(counts, covariates, coef):
     Returns : torch.tensor of size (p,p).
     """
     log_y = torch.log(counts + (counts == 0) * math.exp(-2))
-    log_y_centered = (
-        log_y - torch.matmul(covariates.unsqueeze(1), coef.unsqueeze(0)).squeeze()
-    )
+    log_y_centered = log_y - torch.mean(log_y, axis=0)
     # MLE in a Gaussian setting
     n_samples = counts.shape[0]
     sigma_hat = 1 / (n_samples - 1) * (log_y_centered.T) @ log_y_centered
@@ -115,7 +115,7 @@ def init_components(counts, covariates, coef, rank):
     Returns :
         torch.tensor of size (p,rank). The initialization of components.
     """
-    sigma_hat = init_sigma(counts, covariates, coef).detach()
+    sigma_hat = init_covariance(counts, covariates, coef).detach()
     components = components_from_covariance(sigma_hat, rank)
     return components
 
@@ -188,15 +188,11 @@ def sample_pln(components, coef, covariates, offsets, _coef_inflation=None, seed
         torch.random.manual_seed(seed)
     n_samples = offsets.shape[0]
     rank = components.shape[1]
-    full_of_ones = torch.ones((n_samples, 1))
     if covariates is None:
-        covariates = full_of_ones
+        XB = 0
     else:
-        covariates = torch.stack((full_of_ones, covariates), axis=1).squeeze()
-    gaussian = (
-        torch.mm(torch.randn(n_samples, rank, device=DEVICE), components.T)
-        + covariates @ coef
-    )
+        XB = covariates @ coef
+    gaussian = torch.mm(torch.randn(n_samples, rank, device=DEVICE), components.T) + XB
     parameter = torch.exp(offsets + gaussian)
     if _coef_inflation is not None:
         print("ZIPLN is sampled")
@@ -235,13 +231,12 @@ def components_from_covariance(covariance, rank):
     return requested_components
 
 
-def init_coef(counts, covariates):
-    log_y = torch.log(counts + (counts == 0) * math.exp(-2))
-    log_y = log_y.to(DEVICE)
-    return torch.matmul(
-        torch.inverse(torch.matmul(covariates.T, covariates)),
-        torch.matmul(covariates.T, log_y),
-    )
+def init_coef(counts, covariates, offsets):
+    if covariates is None:
+        return None
+    poiss_reg = PoissonReg()
+    poiss_reg.fit(counts, covariates, offsets)
+    return poiss_reg.beta
 
 
 def log_stirling(integer):
@@ -275,8 +270,11 @@ def log_posterior(counts, covariates, offsets, posterior_mean, components, coef)
     components_posterior_mean = torch.matmul(
         components.unsqueeze(0), posterior_mean.unsqueeze(2)
     ).squeeze()
-
-    log_lambda = offsets + components_posterior_mean + covariates @ coef
+    if covariates is None:
+        XB = 0
+    else:
+        XB = covariates @ coef
+    log_lambda = offsets + components_posterior_mean + XB
     first_term = (
         -rank / 2 * math.log(2 * math.pi)
         - 1 / 2 * torch.norm(posterior_mean, dim=-1) ** 2
@@ -321,7 +319,21 @@ def check_two_dimensions_are_equal(
         )
 
 
+def init_S(counts, covariates, offsets, beta, C, M):
+    n, rank = M.shape
+    batch_matrix = torch.matmul(C.unsqueeze(2), C.unsqueeze(1)).unsqueeze(0)
+    CW = torch.matmul(C.unsqueeze(0), M.unsqueeze(2)).squeeze()
+    common = torch.exp(offsets + covariates @ beta + CW).unsqueeze(2).unsqueeze(3)
+    prod = batch_matrix * common
+    hess_posterior = torch.sum(prod, axis=1) + torch.eye(rank).to(DEVICE)
+    inv_hess_posterior = -torch.inverse(hess_posterior)
+    hess_posterior = torch.diagonal(inv_hess_posterior, dim1=-2, dim2=-1)
+    return hess_posterior
+
+
 def format_data(data):
+    if data is None:
+        return None
     if isinstance(data, pd.DataFrame):
         return torch.from_numpy(data.values).double().to(DEVICE)
     if isinstance(data, np.ndarray):
@@ -335,7 +347,8 @@ def format_data(data):
 
 def format_model_param(counts, covariates, offsets, offsets_formula):
     counts = format_data(counts)
-    covariates = prepare_covariates(covariates, counts.shape[0])
+    if covariates is not None:
+        covariates = format_data(covariates)
     if offsets is None:
         if offsets_formula == "logsum":
             print("Setting the offsets as the log of the sum of counts")
@@ -349,20 +362,26 @@ def format_model_param(counts, covariates, offsets, offsets_formula):
     return counts, covariates, offsets
 
 
-def prepare_covariates(covariates, n_samples):
-    full_of_ones = torch.full((n_samples, 1), 1, device=DEVICE).double()
-    if covariates is None:
-        return full_of_ones
+def remove_useless_intercepts(covariates):
     covariates = format_data(covariates)
-    return torch.concat((full_of_ones, covariates), axis=1)
+    if covariates.shape[1] < 2:
+        return covariates
+    first_column = covariates[:, 0]
+    second_column = covariates[:, 1]
+    diff = first_column - second_column
+    if torch.sum(torch.abs(diff - diff[0])) == 0:
+        print("removing one")
+        return covariates[:, 1:]
+    return covariates
 
 
 def check_data_shape(counts, covariates, offsets):
     n_counts, p_counts = counts.shape
     n_offsets, p_offsets = offsets.shape
-    n_cov, _ = covariates.shape
     check_two_dimensions_are_equal("counts", "offsets", n_counts, n_offsets, 0)
-    check_two_dimensions_are_equal("counts", "covariates", n_counts, n_cov, 0)
+    if covariates is not None:
+        n_cov, _ = covariates.shape
+        check_two_dimensions_are_equal("counts", "covariates", n_counts, n_cov, 0)
     check_two_dimensions_are_equal("counts", "offsets", p_counts, p_offsets, 1)
 
 
@@ -417,13 +436,13 @@ def get_components_simulation(dim, rank):
 def get_simulation_offsets_cov_coef(n_samples, nb_cov, dim):
     prev_state = torch.random.get_rng_state()
     torch.random.manual_seed(0)
-    if nb_cov < 2:
+    if nb_cov == 0:
         covariates = None
     else:
         covariates = torch.randint(
             low=-1,
             high=2,
-            size=(n_samples, nb_cov - 1),
+            size=(n_samples, nb_cov),
             dtype=torch.float64,
             device=DEVICE,
         )
@@ -471,9 +490,66 @@ def closest(lst, element):
     return lst[idx]
 
 
-def check_dimensions_are_equal(tens1, tens2):
-    if tens1.shape[0] != tens2.shape[0] or tens1.shape[1] != tens2.shape[1]:
-        raise ValueError("Tensors should have the same size.")
+class PoissonReg:
+    """Poisson regressor class."""
+
+    def __init__(self):
+        """No particular initialization is needed."""
+        pass
+
+    def fit(self, Y, covariates, O, Niter_max=300, tol=0.001, lr=0.005, verbose=False):
+        """Run a gradient ascent to maximize the log likelihood, using
+        pytorch autodifferentiation. The log likelihood considered is
+        the one from a poisson regression model. It is roughly the
+        same as PLN without the latent layer Z.
+
+        Args:
+            Y: torch.tensor. Counts with size (n,p)
+            O: torch.tensor. Offset, size (n,p)
+            covariates: torch.tensor. Covariates, size (n,d)
+            Niter_max: int, optional. The maximum number of iteration.
+                Default is 300.
+            tol: non negative float, optional. The tolerance criteria.
+                Will stop if the norm of the gradient is less than
+                or equal to this threshold. Default is 0.001.
+            lr: positive float, optional. Learning rate for the gradient ascent.
+                Default is 0.005.
+            verbose: bool, optional. If True, will print some stats.
+
+        Returns : None. Update the parameter beta. You can access it
+                by calling self.beta.
+        """
+        # Initialization of beta of size (d,p)
+        beta = torch.rand(
+            (covariates.shape[1], Y.shape[1]), device=DEVICE, requires_grad=True
+        )
+        optimizer = torch.optim.Rprop([beta], lr=lr)
+        i = 0
+        grad_norm = 2 * tol  # Criterion
+        while i < Niter_max and grad_norm > tol:
+            loss = -compute_poissreg_log_like(Y, O, covariates, beta)
+            loss.backward()
+            optimizer.step()
+            grad_norm = torch.norm(beta.grad)
+            beta.grad.zero_()
+            i += 1
+            if verbose and i % 10 == 0:
+                print("log like : ", -loss)
+                print("grad_norm : ", grad_norm)
+        if verbose:
+            if i < Niter_max:
+                print("Tolerance reached in {} iterations".format(i))
+            else:
+                print("Maximum number of iterations reached")
+        self.beta = beta
+
+
+def compute_poissreg_log_like(Y, O, covariates, beta):
+    """Compute the log likelihood of a Poisson regression."""
+    # Matrix multiplication of X and beta.
+    XB = torch.matmul(covariates.unsqueeze(1), beta.unsqueeze(0)).squeeze()
+    # Returns the formula of the log likelihood of a poisson regression model.
+    return torch.sum(-torch.exp(O + XB) + torch.multiply(Y, O + XB))
 
 
 def to_tensor(obj):
@@ -484,3 +560,84 @@ def to_tensor(obj):
     if isinstance(obj, pd.DataFrame):
         return torch.from_numpy(obj.values)
     raise TypeError("Please give either a nd.array or torch.Tensor or pd.DataFrame")
+
+
+def check_dimensions_are_equal(tens1, tens2):
+    if tens1.shape[0] != tens2.shape[0] or tens1.shape[1] != tens2.shape[1]:
+        raise ValueError("Tensors should have the same size.")
+
+
+def load_model(path_of_directory):
+    working_dict = os.getcwd()
+    os.chdir(path_of_directory)
+    all_files = os.listdir()
+    data = {}
+    for filename in all_files:
+        if len(filename) > 4:
+            if filename[-4:] == ".csv":
+                parameter = filename[:-4]
+                try:
+                    data[parameter] = pd.read_csv(filename, header=None).values
+                except pd.errors.EmptyDataError:
+                    print(
+                        f"Can't load {parameter} since empty. Standard initialization will be performed"
+                    )
+    os.chdir(working_dict)
+    return data
+
+
+def load_pln(path_of_directory):
+    return load_model(path_of_directory)
+
+
+def load_plnpca(path_of_directory, ranks=None):
+    working_dict = os.getcwd()
+    os.chdir(path_of_directory)
+    if ranks is None:
+        dirnames = os.listdir()
+        ranks = []
+        for dirname in dirnames:
+            try:
+                rank = int(dirname.split("_")[-1])
+            except ValueError as err:
+                raise ValueError(
+                    f"Can't load the model {dirname}. End of {dirname} should be an int"
+                ) from err
+            ranks.append(rank)
+    datas = {}
+    for rank in ranks:
+        datas[rank] = load_model(f"_PLNPCA_rank_{rank}")
+    os.chdir(working_dict)
+    return datas
+
+
+def check_right_rank(data, rank):
+    data_rank = data["latent_mean"].shape[1]
+    if data_rank != rank:
+        raise RuntimeError(
+            "Wrong rank during initialization."
+            f" Got rank {rank} and data with rank {data_rank}."
+        )
+
+
+def extract_data_from_formula(formula, data):
+    dmatrix = dmatrices(formula, data=data)
+    counts = dmatrix[0]
+    covariates = dmatrix[1]
+    print("covariates size:", covariates.size)
+    if covariates.size == 0:
+        covariates = None
+    offsets = data.get("offsets", None)
+    return counts, covariates, offsets
+
+
+def is_dict_of_dict(dictionnary):
+    if isinstance(dictionnary[list(dictionnary.keys())[0]], dict):
+        return True
+    return False
+
+
+def get_dict_initialization(rank, dict_of_dict):
+    if dict_of_dict is None:
+        return None
+    return dict_of_dict[rank]
diff --git a/pyPLNmodels/elbos.py b/pyPLNmodels/elbos.py
index 160468cfc407dd5e271dde305a9db1aa02063fa6..082bcdfaa943478ed1cf9999ce8016256130f930 100644
--- a/pyPLNmodels/elbos.py
+++ b/pyPLNmodels/elbos.py
@@ -22,7 +22,11 @@ def elbo_pln(counts, covariates, offsets, latent_mean, latent_var, covariance, c
     n_samples, dim = counts.shape
     s_rond_s = torch.square(latent_var)
     offsets_plus_m = offsets + latent_mean
-    m_minus_xb = latent_mean - covariates @ coef
+    if covariates is None:
+        XB = 0
+    else:
+        XB = covariates @ coef
+    m_minus_xb = latent_mean - XB
     d_plus_minus_xb2 = (
         torch.diag(torch.sum(s_rond_s, dim=0)) + m_minus_xb.T @ m_minus_xb
     )
@@ -90,7 +94,11 @@ def elbo_plnpca(counts, covariates, offsets, latent_mean, latent_var, components
     """
     n_samples = counts.shape[0]
     rank = components.shape[1]
-    log_intensity = offsets + covariates @ coef + latent_mean @ components.T
+    if covariates is None:
+        XB = 0
+    else:
+        XB = covariates @ coef
+    log_intensity = offsets + XB + latent_mean @ components.T
     s_rond_s = torch.square(latent_var)
     counts_log_intensity = torch.sum(counts * log_intensity)
     minus_intensity_plus_s_rond_s_cct = torch.sum(
diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index c4171d6aab7437a1cb56e7159090803ae32e990d..1be04cc1508ea93065ade6298c4392338e224085 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -3,6 +3,8 @@ from abc import ABC, abstractmethod
 import pickle
 import warnings
 import os
+from functools import singledispatchmethod
+from collections.abc import Iterable
 
 import pandas as pd
 import torch
@@ -10,6 +12,7 @@ import numpy as np
 import seaborn as sns
 import matplotlib.pyplot as plt
 from sklearn.decomposition import PCA
+from patsy import dmatrices
 
 
 from ._closed_forms import (
@@ -20,7 +23,7 @@ from ._closed_forms import (
 from .elbos import elbo_plnpca, elbo_zi_pln, profiled_elbo_pln
 from ._utils import (
     PLNPlotArgs,
-    init_sigma,
+    init_covariance,
     init_components,
     init_coef,
     check_two_dimensions_are_equal,
@@ -31,9 +34,12 @@ from ._utils import (
     nice_string_of_dict,
     plot_ellipse,
     closest,
-    prepare_covariates,
     to_tensor,
     check_dimensions_are_equal,
+    check_right_rank,
+    remove_useless_intercepts,
+    extract_data_from_formula,
+    get_dict_initialization,
 )
 
 if torch.cuda.is_available():
@@ -68,17 +74,50 @@ class _PLN(ABC):
     _latent_var: torch.Tensor
     _latent_mean: torch.Tensor
 
-    def __init__(self):
+    @singledispatchmethod
+    def __init__(
+        self,
+        counts,
+        covariates=None,
+        offsets=None,
+        offsets_formula="logsum",
+        dict_initialization=None,
+    ):
         """
         Simple initialization method.
         """
-        self._fitted = False
-        self.plotargs = PLNPlotArgs(self.WINDOW)
 
-    def format_model_param(self, counts, covariates, offsets, offsets_formula):
         self._counts, self._covariates, self._offsets = format_model_param(
             counts, covariates, offsets, offsets_formula
         )
+        check_data_shape(self._counts, self._covariates, self._offsets)
+        self._fitted = False
+        self.plotargs = PLNPlotArgs(self.WINDOW)
+        if dict_initialization is not None:
+            self.set_init_parameters(dict_initialization)
+
+    @__init__.register(str)
+    def _(
+        self,
+        formula: str,
+        data: dict,
+        offsets_formula="logsum",
+        dict_initialization=None,
+    ):
+        counts, covariates, offsets = extract_data_from_formula(formula, data)
+        self.__init__(counts, covariates, offsets, offsets_formula, dict_initialization)
+
+    def set_init_parameters(self, dict_initialization):
+        if "coef" not in dict_initialization:
+            print("No coef is initialized.")
+            self.coef = None
+        for key, array in dict_initialization.items():
+            array = format_data(array)
+            setattr(self, key, array)
+
+    @property
+    def fitted(self):
+        return self._fitted
 
     @property
     def nb_iteration_done(self):
@@ -94,12 +133,16 @@ class _PLN(ABC):
 
     @property
     def nb_cov(self):
+        if self.covariates is None:
+            return 0
         return self.covariates.shape[1]
 
     def smart_init_coef(self):
-        self._coef = init_coef(self._counts, self._covariates)
+        self._coef = init_coef(self._counts, self._covariates, self._offsets)
 
     def random_init_coef(self):
+        if self.nb_cov == 0:
+            self._coef = None; return  # early exit: randn below would overwrite None
         self._coef = torch.randn((self.nb_cov, self.dim), device=DEVICE)
 
     @abstractmethod
@@ -140,17 +183,12 @@ class _PLN(ABC):
 
     def fit(
         self,
-        counts,
-        covariates=None,
-        offsets=None,
         nb_max_iteration=50000,
         lr=0.01,
         class_optimizer=torch.optim.Rprop,
-        tol=1e-6,
+        tol=1e-3,
         do_smart_init=True,
         verbose=False,
-        offsets_formula="logsum",
-        keep_going=False,
     ):
         """
         Main function of the class. Fit a PLN to the data.
@@ -168,11 +206,9 @@ class _PLN(ABC):
         self.print_beginning_message()
         self.beginnning_time = time.time()
 
-        if keep_going is False:
-            self.format_model_param(counts, covariates, offsets, offsets_formula)
-            check_data_shape(self._counts, self._covariates, self._offsets)
+        if self._fitted is False:
             self.init_parameters(do_smart_init)
-        if self._fitted is True and keep_going is True:
+        else:
             self.beginnning_time -= self.plotargs.running_times[-1]
         self.optim = class_optimizer(self.list_of_parameters_needing_gradient, lr=lr)
         stop_condition = False
@@ -215,8 +251,8 @@ class _PLN(ABC):
     def print_end_of_fitting_message(self, stop_condition, tol):
         if stop_condition is True:
             print(
-                f"Tolerance {tol} reached"
-                f"n {self.plotargs.iteration_number} iterations"
+                f"Tolerance {tol} reached "
+                f"in {self.plotargs.iteration_number} iterations"
             )
         else:
             print(
@@ -360,7 +396,11 @@ class _PLN(ABC):
 
     @property
     def model_in_a_dict(self):
-        return self.dict_data | self.model_parameters | self.latent_parameters
+        return self.dict_data | self.dict_parameters
+
+    @property
+    def dict_parameters(self):
+        return self.model_parameters | self.latent_parameters
 
     @property
     def coef(self):
@@ -391,28 +431,18 @@ class _PLN(ABC):
         return None
 
     def save(self, path_of_directory="./"):
-        path = f"{path_of_directory}/{self.model_path}/"
+        path = f"{path_of_directory}/{self.path_to_directory}{self.directory_name}"
         os.makedirs(path, exist_ok=True)
-        for key, value in self.model_in_a_dict.items():
+        for key, value in self.dict_parameters.items():
             filename = f"{path}/{key}.csv"
             if isinstance(value, torch.Tensor):
                 pd.DataFrame(np.array(value.cpu().detach())).to_csv(
                     filename, header=None, index=None
                 )
-            else:
+            elif value is not None:
                 pd.DataFrame(np.array([value])).to_csv(
                     filename, header=None, index=None
                 )
-        self._fitted = True
-
-    def load(self, path_of_directory="./"):
-        path = f"{path_of_directory}/{self.model_path}/"
-        for key, value in self.model_in_a_dict.items():
-            value = torch.from_numpy(
-                pd.read_csv(path + key + ".csv", header=None).values
-            )
-            setattr(self, key, value)
-        self.put_parameters_to_device()
 
     @property
     def counts(self):
@@ -473,13 +503,26 @@ class _PLN(ABC):
         return self.covariance
 
     def predict(self, covariates=None):
-        if isinstance(covariates, torch.Tensor):
-            if covariates.shape[-1] != self.nb_cov - 1:
-                error_string = f"X has wrong shape ({covariates.shape}).Should"
-                error_string += f" be ({self.n_samples, self.nb_cov-1})."
-                raise RuntimeError(error_string)
-        covariates_with_ones = prepare_covariates(covariates, self.n_samples)
-        return covariates_with_ones @ self.coef
+        if covariates is not None and self.nb_cov == 0:
+            raise AttributeError("No covariates in the model, can't predict")
+        if covariates is None:
+            if self.covariates is None:
+                print("No covariates in the model.")
+                return None
+            return self.covariates @ self.coef
+        if covariates.shape[-1] != self.nb_cov:
+            error_string = f"X has wrong shape ({covariates.shape}).Should"
+            error_string += f" be ({self.n_samples, self.nb_cov})."
+            raise RuntimeError(error_string)
+        return covariates @ self.coef
+
+    @property
+    def directory_name(self):
+        return f"{self.NAME}_nbcov_{self.nb_cov}_dim_{self.dim}"
+
+    @property
+    def path_to_directory(self):
+        return ""
 
 
 # need to do a good init for M and S
@@ -505,12 +548,10 @@ class PLN(_PLN):
         self.random_init_latent_parameters()
 
     def random_init_latent_parameters(self):
-        self._latent_var = 1 / 2 * torch.ones((self.n_samples, self.dim)).to(DEVICE)
-        self._latent_mean = torch.ones((self.n_samples, self.dim)).to(DEVICE)
-
-    @property
-    def model_path(self):
-        return self.NAME
+        if not hasattr(self, "_latent_var"):
+            self._latent_var = 1 / 2 * torch.ones((self.n_samples, self.dim)).to(DEVICE)
+        if not hasattr(self, "_latent_mean"):
+            self._latent_mean = torch.ones((self.n_samples, self.dim)).to(DEVICE)
 
     @property
     def list_of_parameters_needing_gradient(self):
@@ -589,31 +630,121 @@ class PLN(_PLN):
         pass
 
 
+## trying to set up a single init for _PLNPCA
 class PLNPCA:
-    def __init__(self, ranks):
-        if isinstance(ranks, (list, np.ndarray)):
-            self.ranks = ranks
-            self.dict_models = {}
+    NAME = "PLNPCA"
+
+    @singledispatchmethod
+    def __init__(
+        self,
+        counts,
+        covariates=None,
+        offsets=None,
+        offsets_formula="logsum",
+        ranks=range(3, 5),
+        dict_of_dict_initialization=None,
+    ):
+        self.init_data(counts, covariates, offsets, offsets_formula)
+        self.init_models(ranks, dict_of_dict_initialization)
+
+    def init_data(self, counts, covariates, offsets, offsets_formula):
+        self._counts, self._covariates, self._offsets = format_model_param(
+            counts, covariates, offsets, offsets_formula
+        )
+        check_data_shape(self._counts, self._covariates, self._offsets)
+        self._fitted = False
+
+    @__init__.register(str)
+    def _(
+        self,
+        formula: str,
+        data: dict,
+        offsets_formula="logsum",
+        ranks=range(3, 5),
+        dict_of_dict_initialization=None,
+    ):
+        counts, covariates, offsets = extract_data_from_formula(formula, data)
+        self.__init__(
+            counts,
+            covariates,
+            offsets,
+            offsets_formula,
+            ranks,
+            dict_of_dict_initialization,
+        )
+
+    @property
+    def covariates(self):
+        return self.list_models[0].covariates
+
+    @property
+    def counts(self):
+        return self.list_models[0].counts
+
+    @counts.setter
+    def counts(self, counts):
+        counts = format_data(counts)
+        if hasattr(self, "_counts"):
+            check_dimensions_are_equal(self._counts, counts)
+        self._counts = counts
+
+    @covariates.setter
+    def covariates(self, covariates):
+        covariates = format_data(covariates)
+        # NOTE(review): unlike the counts setter, no dimension check here — confirm intended
+        self._covariates = covariates
+
+    @property
+    def offsets(self):
+        return self.list_models[0].offsets
+
+    def init_models(self, ranks, dict_of_dict_initialization):
+        if isinstance(ranks, (Iterable, np.ndarray)):
+            self.list_models = []
             for rank in ranks:
-                if isinstance(rank, (int, np.int64)):
-                    self.dict_models[rank] = _PLNPCA(rank)
+                if isinstance(rank, (int, np.integer)):
+                    dict_initialization = get_dict_initialization(
+                        rank, dict_of_dict_initialization
+                    )
+                    self.list_models.append(
+                        _PLNPCA(
+                            self._counts,
+                            self._covariates,
+                            self._offsets,
+                            rank,
+                            dict_initialization,
+                        )
+                    )
                 else:
                     raise TypeError(
-                        "Please instantiate with either a list\
-                              of integers or an integer."
+                        "Please instantiate with either a list "
+                        "of integers or an integer."
                     )
-        elif isinstance(ranks, int):
-            self.ranks = [ranks]
-            self.dict_models = {ranks: _PLNPCA(ranks)}
+        elif isinstance(ranks, (int, np.integer)):
+            dict_initialization = get_dict_initialization(
+                ranks, dict_of_dict_initialization
+            )
+            self.list_models = [
+                _PLNPCA(
+                    self._counts,
+                    self._covariates,
+                    self._offsets,
+                    ranks,
+                    dict_initialization,
+                )
+            ]
         else:
             raise TypeError(
-                "Please instantiate with either a list of \
-                        integers or an integer."
+                "Please instantiate with either a list of integers or an integer."
             )
 
     @property
-    def models(self):
-        return list(self.dict_models.values())
+    def ranks(self):
+        return [model.rank for model in self.list_models]
+
+    @property
+    def dict_models(self):
+        return {model.rank: model for model in self.list_models}
 
     def print_beginning_message(self):
         return f"Adjusting {len(self.ranks)} PLN models for PCA analysis \n"
@@ -622,39 +753,30 @@ class PLNPCA:
     def dim(self):
         return self[self.ranks[0]].dim
 
+    @property
+    def nb_cov(self):
+        return self[self.ranks[0]].nb_cov
+
     ## should do something for this weird init. pb: if doing the init of self._counts etc
     ## only in PLNPCA, then we don't do it for each _PLNPCA but then PLN is not doing it.
     def fit(
         self,
-        counts,
-        covariates=None,
-        offsets=None,
         nb_max_iteration=100000,
         lr=0.01,
         class_optimizer=torch.optim.Rprop,
-        tol=1e-6,
+        tol=1e-3,
         do_smart_init=True,
         verbose=False,
-        offsets_formula="logsum",
-        keep_going=False,
     ):
         self.print_beginning_message()
-        counts, _, offsets = format_model_param(
-            counts, covariates, offsets, offsets_formula
-        )
         for pca in self.dict_models.values():
             pca.fit(
-                counts,
-                covariates,
-                offsets,
                 nb_max_iteration,
                 lr,
                 class_optimizer,
                 tol,
                 do_smart_init,
                 verbose,
-                None,
-                keep_going,
             )
         self.print_ending_message()
 
@@ -681,15 +803,15 @@ class PLNPCA:
 
     @property
     def BIC(self):
-        return {model.rank: int(model.BIC) for model in self.models}
+        return {model.rank: int(model.BIC) for model in self.list_models}
 
     @property
     def AIC(self):
-        return {model.rank: int(model.AIC) for model in self.models}
+        return {model.rank: int(model.AIC) for model in self.list_models}
 
     @property
     def loglikes(self):
-        return {model.rank: model.loglike for model in self.models}
+        return {model.rank: model.loglike for model in self.list_models}
 
     def show(self):
         bic = self.BIC
@@ -730,24 +852,31 @@ class PLNPCA:
             return self[self.best_AIC_model_rank]
         raise ValueError(f"Unknown criterion {criterion}")
 
-    def save(self, path_of_directory="./"):
-        for model in self.models:
-            model.save(path_of_directory)
+    def save(self, path_of_directory="./", ranks=None):
+        if ranks is None:
+            ranks = self.ranks
+        for model in self.list_models:
+            if model.rank in ranks:
+                model.save(path_of_directory)
 
-    def load(self, path_of_directory="./"):
-        for model in self.models:
-            model.load(path_of_directory)
+    @property
+    def directory_name(self):
+        return f"{self.NAME}_nbcov_{self.nb_cov}_dim_{self.dim}"
 
     @property
     def n_samples(self):
-        return self.models[0].n_samples
+        return self.list_models[0].n_samples
 
     @property
     def _p(self):
         return self[self.ranks[0]].p
 
+    @property
+    def models(self):
+        return self.dict_models.values()
+
     def __str__(self):
-        nb_models = len(self.models)
+        nb_models = len(self.list_models)
         delimiter = "\n" + "-" * NB_CHARACTERS_FOR_NICE_PLOT + "\n"
         to_print = delimiter
         to_print += f"Collection of {nb_models} PLNPCA models with \
@@ -780,26 +909,47 @@ class PLNPCA:
         return ".BIC, .AIC, .loglikes"
 
 
+# Here, setting the value for each key in dict_parameters
 class _PLNPCA(_PLN):
-    NAME = "PLNPCA"
+    NAME = "_PLNPCA"
     _components: torch.Tensor
 
-    def __init__(self, rank):
-        super().__init__()
+    @singledispatchmethod
+    def __init__(self, counts, covariates, offsets, rank, dict_initialization=None):
         self._rank = rank
+        self._counts, self._covariates, self._offsets = format_model_param(
+            counts, covariates, offsets, None
+        )
+        check_data_shape(self._counts, self._covariates, self._offsets)
+        self.check_if_rank_is_too_high()
+        if dict_initialization is not None:
+            self.set_init_parameters(dict_initialization)
+        self._fitted = False
+        self.plotargs = PLNPlotArgs(self.WINDOW)
 
-    def init_parameters(self, do_smart_init):
-        if self.dim < self._rank:
-            warning_string = f"\nThe requested rank of approximation {self._rank} \
-                is greater than the number of variables {self.dim}. \
-                Setting rank to {self.dim}"
+    @__init__.register(str)
+    def _(self, formula, data, rank, dict_initialization):
+        counts, covariates, offsets = extract_data_from_formula(formula, data)
+        self.__init__(counts, covariates, offsets, rank, dict_initialization)
+
+    def check_if_rank_is_too_high(self):
+        if self.dim < self.rank:
+            warning_string = (
+                f"\nThe requested rank of approximation {self.rank} "
+                f"is greater than the number of variables {self.dim}. "
+                f"Setting rank to {self.dim}"
+            )
             warnings.warn(warning_string)
             self._rank = self.dim
-        super().init_parameters(do_smart_init)
 
     @property
-    def model_path(self):
-        return f"{self.NAME}_{self._rank}_rank"
+    def directory_name(self):
+        return f"{self.NAME}_rank_{self._rank}"
+        # return f"PLNPCA_nbcov_{self.nb_cov}_dim_{self.dim}/{self.NAME}_rank_{self._rank}"
+
+    @property
+    def path_to_directory(self):
+        return f"PLNPCA_nbcov_{self.nb_cov}_dim_{self.dim}/"
 
     @property
     def rank(self):
@@ -817,10 +967,12 @@ class _PLNPCA(_PLN):
         return {"coef": self.coef, "components": self.components}
 
     def smart_init_model_parameters(self):
-        super().smart_init_coef()
-        self._components = init_components(
-            self._counts, self._covariates, self._coef, self._rank
-        )
+        if not hasattr(self, "_coef"):
+            super().smart_init_coef()
+        if not hasattr(self, "_components"):
+            self._components = init_components(
+                self._counts, self._covariates, self._coef, self._rank
+            )
 
     def random_init_model_parameters(self):
         super().random_init_coef()
@@ -831,23 +983,27 @@ class _PLNPCA(_PLN):
         self._latent_mean = torch.ones((self.n_samples, self._rank)).to(DEVICE)
 
     def smart_init_latent_parameters(self):
-        self._latent_mean = (
-            init_latent_mean(
-                self._counts,
-                self._covariates,
-                self._offsets,
-                self._coef,
-                self._components,
+        if not hasattr(self, "_latent_mean"):
+            self._latent_mean = (
+                init_latent_mean(
+                    self._counts,
+                    self._covariates,
+                    self._offsets,
+                    self._coef,
+                    self._components,
+                )
+                .to(DEVICE)
+                .detach()
+            )
+        if not hasattr(self, "_latent_var"):
+            self._latent_var = (
+                1 / 2 * torch.ones((self.n_samples, self._rank)).to(DEVICE)
             )
-            .to(DEVICE)
-            .detach()
-        )
-        self._latent_var = 1 / 2 * torch.ones((self.n_samples, self._rank)).to(DEVICE)
-        self._latent_mean.requires_grad_(True)
-        self._latent_var.requires_grad_(True)
 
     @property
     def list_of_parameters_needing_gradient(self):
+        if self._coef is None:
+            return [self._components, self._latent_mean, self._latent_var]
         return [self._components, self._coef, self._latent_mean, self._latent_var]
 
     def compute_elbo(self):
@@ -896,6 +1052,16 @@ class _PLNPCA(_PLN):
         ortho_components = torch.linalg.qr(self._components, "reduced")[0]
         return torch.mm(self.latent_variables, ortho_components).detach().cpu()
 
+    def pca_projected_latent_variables(self, n_components=None):
+        if n_components is None:
+            n_components = self.get_max_components()
+        if n_components > self.dim:
+            raise RuntimeError(
+                f"You ask more components ({n_components}) than variables ({self.dim})"
+            )
+        pca = PCA(n_components=n_components)
+        return pca.fit_transform(self.projected_latent_variables.detach().cpu())
+
     @property
     def components(self):
         return self.attribute_or_none("_components")
@@ -905,13 +1071,16 @@ class _PLNPCA(_PLN):
         self._components = components
 
     def viz(self, ax=None, colors=None):
-        if self._rank != 2:
-            raise RuntimeError("Can't perform visualization for rank != 2.")
         if ax is None:
             ax = plt.gca()
-        proj_variables = self.projected_latent_variables
-        x = proj_variables[:, 0].cpu().numpy()
-        y = proj_variables[:, 1].cpu().numpy()
+        if self._rank < 2:
+            raise RuntimeError("Can't perform visualization for rank < 2.")
+        if self._rank > 2:
+            proj_variables = self.pca_projected_latent_variables(n_components=2)
+        if self._rank == 2:
+            proj_variables = self.projected_latent_variables.cpu().numpy()
+        x = proj_variables[:, 0]
+        y = proj_variables[:, 1]
         sns.scatterplot(x=x, y=y, hue=colors, ax=ax)
         covariances = torch.diag_embed(self._latent_var**2).detach().cpu()
         for i in range(covariances.shape[0]):
@@ -943,10 +1112,12 @@ class ZIPLN(PLN):
     # should change the good initialization, especially for _coef_inflation
     def smart_init_model_parameters(self):
         super().smart_init_model_parameters()
-        self._covariance = init_sigma(
-            self._counts, self._covariates, self._offsets, self._coef
-        )
-        self._coef_inflation = torch.randn(self.nb_cov, self.dim)
+        if not hasattr(self, "_covariance"):
+            self._covariance = init_covariance(
+                self._counts, self._covariates, self._coef
+            )
+        if not hasattr(self, "_coef_inflation"):
+            self._coef_inflation = torch.randn(self.nb_cov, self.dim)
 
     def random_init_latent_parameters(self):
         self._dirac = self._counts == 0
diff --git a/setup.py b/setup.py
index 74cc89090df8d4d49d7a9b99b19d1696b04d3de1..9c8b27450e6254d2ddc1dd060b31c3724054d812 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 from setuptools import setup, find_packages
 
-VERSION = "0.0.37"
+VERSION = "0.0.38"
 
 with open("README.md", "r") as fh:
     long_description = fh.read()
diff --git a/test.py b/test.py
index 900e42df2c9dd4070e663d898b93b219d4fa3734..b9d5e686d2a75cc14ad36c59459700ddf463bbe0 100644
--- a/test.py
+++ b/test.py
@@ -2,6 +2,8 @@ from pyPLNmodels.models import PLNPCA, _PLNPCA, PLN
 from pyPLNmodels import get_real_count_data, get_simulated_count_data
 
 import os
+import pandas as pd
+import numpy as np
 
 os.chdir("./pyPLNmodels/")
 
@@ -11,12 +13,15 @@ covariates = None
 offsets = None
 # counts, covariates, offsets = get_simulated_count_data(seed = 0)
 
-pca = PLNPCA([3, 4])
+pca = PLNPCA(counts, covariates, offsets, ranks=[3, 4])
+pca.fit(tol=0.1)
 
-pca.fit(counts, covariates, offsets, tol=0.1)
-print(pca)
-
-# pln = PLN()
+# pca.fit()
+# print(pca)
+# data = pd.DataFrame(counts)
+# pln = PLN("counts~1", data)
+# pln.fit()
+# print(pln)
 # pcamodel = pca.best_model()
 # pcamodel.save()
 # model = PLNPCA([4])[4]
diff --git a/tests/conftest.py b/tests/conftest.py
index 1df7d25bb7e0e401e6b471e678330ef839d98959..904240c9611cc87e63c8495600d21bfdfe94b6a8 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,3 +1,414 @@
 import sys
+import glob
+from functools import singledispatch
+
+import pytest
+import torch
+from pytest_lazyfixture import lazy_fixture as lf
+from pyPLNmodels import load_model, load_plnpca
+from pyPLNmodels.models import PLN, _PLNPCA, PLNPCA
+
 
 sys.path.append("../")
+
+pytest_plugins = [
+    fixture_file.replace("/", ".").replace(".py", "")
+    for fixture_file in glob.glob("src/**/tests/fixtures/[!__]*.py", recursive=True)
+]
+
+
+from tests.import_data import (
+    data_sim_0cov,
+    data_sim_2cov,
+    data_real,
+)
+
+
+counts_sim_0cov = data_sim_0cov["counts"]
+covariates_sim_0cov = data_sim_0cov["covariates"]
+offsets_sim_0cov = data_sim_0cov["offsets"]
+
+counts_sim_2cov = data_sim_2cov["counts"]
+covariates_sim_2cov = data_sim_2cov["covariates"]
+offsets_sim_2cov = data_sim_2cov["offsets"]
+
+counts_real = data_real["counts"]
+
+
+def add_fixture_to_dict(my_dict, string_fixture):
+    my_dict[string_fixture] = [lf(string_fixture)]
+    return my_dict
+
+
+def add_list_of_fixture_to_dict(
+    my_dict, name_of_list_of_fixtures, list_of_string_fixtures
+):
+    my_dict[name_of_list_of_fixtures] = []
+    for string_fixture in list_of_string_fixtures:
+        my_dict[name_of_list_of_fixtures].append(lf(string_fixture))
+    return my_dict
+
+
+RANK = 8
+RANKS = [2, 6]
+instances = []
+# dict_fixtures_models = []
+
+
+@singledispatch
+def convenient_plnpca(
+    counts,
+    covariates=None,
+    offsets=None,
+    offsets_formula=None,
+    dict_initialization=None,
+):
+    return _PLNPCA(
+        counts, covariates, offsets, rank=RANK, dict_initialization=dict_initialization
+    )
+
+
+@convenient_plnpca.register(str)
+def _(formula, data, offsets_formula=None, dict_initialization=None):
+    return _PLNPCA(formula, data, rank=RANK, dict_initialization=dict_initialization)
+
+
+@singledispatch
+def convenientplnpca(
+    counts,
+    covariates=None,
+    offsets=None,
+    offsets_formula=None,
+    dict_initialization=None,
+):
+    return PLNPCA(
+        counts,
+        covariates,
+        offsets,
+        offsets_formula,
+        dict_of_dict_initialization=dict_initialization,
+        ranks=RANKS,
+    )
+
+
+@convenientplnpca.register(str)
+def _(formula, data, offsets_formula=None, dict_initialization=None):
+    return PLNPCA(
+        formula,
+        data,
+        offsets_formula,
+        ranks=RANKS,
+        dict_of_dict_initialization=dict_initialization,
+    )
+
+
+def generate_new_model(model, *args, **kwargs):
+    name_dir = model.directory_name
+    name = model.NAME
+    if name in ("PLN", "_PLNPCA"):
+        path = model.path_to_directory + name_dir
+        init = load_model(path)
+        if name == "PLN":
+            new = PLN(*args, **kwargs, dict_initialization=init)
+        if name == "_PLNPCA":
+            new = convenient_plnpca(*args, **kwargs, dict_initialization=init)
+    if name == "PLNPCA":
+        init = load_plnpca(name_dir)
+        new = convenientplnpca(*args, **kwargs, dict_initialization=init)
+    return new
+
+
+def cache(func):
+    dict_cache = {}
+
+    def new_func(request):
+        if request.param.__name__ not in dict_cache:
+            dict_cache[request.param.__name__] = func(request)
+        return dict_cache[request.param.__name__]
+
+    return new_func
+
+
+params = [PLN, convenient_plnpca, convenientplnpca]
+dict_fixtures = {}
+
+
+@pytest.fixture(params=params)
+def simulated_pln_0cov_array(request):
+    cls = request.param
+    pln = cls(counts_sim_0cov, covariates_sim_0cov, offsets_sim_0cov)
+    return pln
+
+
+@pytest.fixture(params=params)
+@cache
+def simulated_fitted_pln_0cov_array(request):
+    cls = request.param
+    pln = cls(counts_sim_0cov, covariates_sim_0cov, offsets_sim_0cov)
+    pln.fit()
+    return pln
+
+
+@pytest.fixture(params=params)
+def simulated_pln_0cov_formula(request):
+    cls = request.param
+    pln = cls("counts ~ 0", data_sim_0cov)
+    return pln
+
+
+@pytest.fixture(params=params)
+@cache
+def simulated_fitted_pln_0cov_formula(request):
+    cls = request.param
+    pln = cls("counts ~ 0", data_sim_0cov)
+    pln.fit()
+    return pln
+
+
+@pytest.fixture
+def simulated_loaded_pln_0cov_formula(simulated_fitted_pln_0cov_formula):
+    simulated_fitted_pln_0cov_formula.save()
+    return generate_new_model(
+        simulated_fitted_pln_0cov_formula,
+        "counts ~ 0",
+        data_sim_0cov,
+    )
+
+
+@pytest.fixture
+def simulated_loaded_pln_0cov_array(simulated_fitted_pln_0cov_array):
+    simulated_fitted_pln_0cov_array.save()
+    return generate_new_model(
+        simulated_fitted_pln_0cov_array,
+        counts_sim_0cov,
+        covariates_sim_0cov,
+        offsets_sim_0cov,
+    )
+
+
+sim_pln_0cov_instance = [
+    "simulated_pln_0cov_array",
+    "simulated_pln_0cov_formula",
+]
+
+instances = sim_pln_0cov_instance + instances
+
+dict_fixtures = add_list_of_fixture_to_dict(
+    dict_fixtures, "sim_pln_0cov_instance", sim_pln_0cov_instance
+)
+
+sim_pln_0cov_fitted = [
+    "simulated_fitted_pln_0cov_array",
+    "simulated_fitted_pln_0cov_formula",
+]
+
+dict_fixtures = add_list_of_fixture_to_dict(
+    dict_fixtures, "sim_pln_0cov_fitted", sim_pln_0cov_fitted
+)
+
+sim_pln_0cov_loaded = [
+    "simulated_loaded_pln_0cov_array",
+    "simulated_loaded_pln_0cov_formula",
+]
+
+dict_fixtures = add_list_of_fixture_to_dict(
+    dict_fixtures, "sim_pln_0cov_loaded", sim_pln_0cov_loaded
+)
+
+sim_pln_0cov = sim_pln_0cov_instance + sim_pln_0cov_fitted + sim_pln_0cov_loaded
+dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "sim_pln_0cov", sim_pln_0cov)
+
+
+@pytest.fixture(params=params)
+@cache
+def simulated_pln_2cov_array(request):
+    cls = request.param
+    pln_full = cls(counts_sim_2cov, covariates_sim_2cov, offsets_sim_2cov)
+    return pln_full
+
+
+@pytest.fixture
+def simulated_fitted_pln_2cov_array(simulated_pln_2cov_array):
+    simulated_pln_2cov_array.fit()
+    return simulated_pln_2cov_array
+
+
+@pytest.fixture(params=params)
+@cache
+def simulated_pln_2cov_formula(request):
+    cls = request.param
+    pln_full = cls("counts ~ 0 + covariates", data_sim_2cov)
+    return pln_full
+
+
+@pytest.fixture
+def simulated_fitted_pln_2cov_formula(simulated_pln_2cov_formula):
+    simulated_pln_2cov_formula.fit()
+    return simulated_pln_2cov_formula
+
+
+@pytest.fixture
+def simulated_loaded_pln_2cov_formula(simulated_fitted_pln_2cov_formula):
+    simulated_fitted_pln_2cov_formula.save()
+    return generate_new_model(
+        simulated_fitted_pln_2cov_formula,
+        "counts ~0 + covariates",
+        data_sim_2cov,
+    )
+
+
+@pytest.fixture
+def simulated_loaded_pln_2cov_array(simulated_fitted_pln_2cov_array):
+    simulated_fitted_pln_2cov_array.save()
+    return generate_new_model(
+        simulated_fitted_pln_2cov_array,
+        counts_sim_2cov,
+        covariates_sim_2cov,
+        offsets_sim_2cov,
+    )
+
+
+sim_pln_2cov_instance = [
+    "simulated_pln_2cov_array",
+    "simulated_pln_2cov_formula",
+]
+instances = sim_pln_2cov_instance + instances
+
+dict_fixtures = add_list_of_fixture_to_dict(
+    dict_fixtures, "sim_pln_2cov_instance", sim_pln_2cov_instance
+)
+
+sim_pln_2cov_fitted = [
+    "simulated_fitted_pln_2cov_array",
+    "simulated_fitted_pln_2cov_formula",
+]
+
+dict_fixtures = add_list_of_fixture_to_dict(
+    dict_fixtures, "sim_pln_2cov_fitted", sim_pln_2cov_fitted
+)
+
+sim_pln_2cov_loaded = [
+    "simulated_loaded_pln_2cov_array",
+    "simulated_loaded_pln_2cov_formula",
+]
+
+dict_fixtures = add_list_of_fixture_to_dict(
+    dict_fixtures, "sim_pln_2cov_loaded", sim_pln_2cov_loaded
+)
+
+sim_pln_2cov = sim_pln_2cov_instance + sim_pln_2cov_fitted + sim_pln_2cov_loaded
+dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "sim_pln_2cov", sim_pln_2cov)
+
+
+@pytest.fixture(params=params)
+@cache
+def real_pln_intercept_array(request):
+    cls = request.param
+    pln_full = cls(counts_real, covariates=torch.ones((counts_real.shape[0], 1)))
+    return pln_full
+
+
+@pytest.fixture
+def real_fitted_pln_intercept_array(real_pln_intercept_array):
+    real_pln_intercept_array.fit()
+    return real_pln_intercept_array
+
+
+@pytest.fixture(params=params)
+@cache
+def real_pln_intercept_formula(request):
+    cls = request.param
+    pln_full = cls("counts ~ 1", data_real)
+    return pln_full
+
+
+@pytest.fixture
+def real_fitted_pln_intercept_formula(real_pln_intercept_formula):
+    real_pln_intercept_formula.fit()
+    return real_pln_intercept_formula
+
+
+@pytest.fixture
+def real_loaded_pln_intercept_formula(real_fitted_pln_intercept_formula):
+    real_fitted_pln_intercept_formula.save()
+    return generate_new_model(
+        real_fitted_pln_intercept_formula, "counts ~ 1", data_real
+    )
+
+
+@pytest.fixture
+def real_loaded_pln_intercept_array(real_fitted_pln_intercept_array):
+    real_fitted_pln_intercept_array.save()
+    return generate_new_model(
+        real_fitted_pln_intercept_array,
+        counts_real,
+        covariates=torch.ones((counts_real.shape[0], 1)),
+    )
+
+
+real_pln_instance = [
+    "real_pln_intercept_array",
+    "real_pln_intercept_formula",
+]
+instances = real_pln_instance + instances
+
+dict_fixtures = add_list_of_fixture_to_dict(
+    dict_fixtures, "real_pln_instance", real_pln_instance
+)
+
+real_pln_fitted = [
+    "real_fitted_pln_intercept_array",
+    "real_fitted_pln_intercept_formula",
+]
+dict_fixtures = add_list_of_fixture_to_dict(
+    dict_fixtures, "real_pln_fitted", real_pln_fitted
+)
+
+real_pln_loaded = [
+    "real_loaded_pln_intercept_array",
+    "real_loaded_pln_intercept_formula",
+]
+dict_fixtures = add_list_of_fixture_to_dict(
+    dict_fixtures, "real_pln_loaded", real_pln_loaded
+)
+
+sim_loaded_pln = sim_pln_0cov_loaded + sim_pln_2cov_loaded
+
+loaded_pln = real_pln_loaded + sim_loaded_pln
+dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "loaded_pln", loaded_pln)
+
+simulated_pln_fitted = sim_pln_0cov_fitted + sim_pln_2cov_fitted
+dict_fixtures = add_list_of_fixture_to_dict(
+    dict_fixtures, "simulated_pln_fitted", simulated_pln_fitted
+)
+fitted_pln = real_pln_fitted + simulated_pln_fitted
+dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "fitted_pln", fitted_pln)
+
+
+loaded_and_fitted_sim_pln = simulated_pln_fitted + sim_loaded_pln
+loaded_and_fitted_real_pln = real_pln_fitted + real_pln_loaded
+dict_fixtures = add_list_of_fixture_to_dict(
+    dict_fixtures, "loaded_and_fitted_real_pln", loaded_and_fitted_real_pln
+)
+dict_fixtures = add_list_of_fixture_to_dict(
+    dict_fixtures, "loaded_and_fitted_sim_pln", loaded_and_fitted_sim_pln
+)
+loaded_and_fitted_pln = fitted_pln + loaded_pln
+dict_fixtures = add_list_of_fixture_to_dict(
+    dict_fixtures, "loaded_and_fitted_pln", loaded_and_fitted_pln
+)
+
+real_pln = real_pln_instance + real_pln_fitted + real_pln_loaded
+dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "real_pln", real_pln)
+
+sim_pln = sim_pln_2cov + sim_pln_0cov
+dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "sim_pln", sim_pln)
+
+all_pln = real_pln + sim_pln + instances
+dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "instances", instances)
+dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "all_pln", all_pln)
+
+
+for string_fixture in all_pln:
+    print("string_fixture", string_fixture)
+    dict_fixtures = add_fixture_to_dict(dict_fixtures, string_fixture)
diff --git a/tests/import_data.py b/tests/import_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..80e613ba12fc5a92d6a8187c65f9bcfd32ee23b9
--- /dev/null
+++ b/tests/import_data.py
@@ -0,0 +1,39 @@
+import os
+
+from pyPLNmodels import (
+    get_simulated_count_data,
+    get_real_count_data,
+)
+
+
+(
+    counts_sim_0cov,
+    covariates_sim_0cov,
+    offsets_sim_0cov,
+    true_covariance_0cov,
+    true_coef_0cov,
+) = get_simulated_count_data(return_true_param=True, nb_cov=0)
+(
+    counts_sim_2cov,
+    covariates_sim_2cov,
+    offsets_sim_2cov,
+    true_covariance_2cov,
+    true_coef_2cov,
+) = get_simulated_count_data(return_true_param=True, nb_cov=2)
+
+data_sim_0cov = {
+    "counts": counts_sim_0cov,
+    "covariates": covariates_sim_0cov,
+    "offsets": offsets_sim_0cov,
+}
+true_sim_0cov = {"Sigma": true_covariance_0cov, "beta": true_coef_0cov}
+true_sim_2cov = {"Sigma": true_covariance_2cov, "beta": true_coef_2cov}
+
+
+data_sim_2cov = {
+    "counts": counts_sim_2cov,
+    "covariates": covariates_sim_2cov,
+    "offsets": offsets_sim_2cov,
+}
+counts_real = get_real_count_data()
+data_real = {"counts": counts_real}
diff --git a/tests/import_fixtures.py b/tests/import_fixtures.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tests/test_args.py b/tests/test_args.py
deleted file mode 100644
index 16c8a73d7f0c354e7691ed0b66c734518fe083ca..0000000000000000000000000000000000000000
--- a/tests/test_args.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import os
-
-from pyPLNmodels.models import PLN, PLNPCA, _PLNPCA
-from pyPLNmodels import get_simulated_count_data, get_real_count_data
-import pytest
-from pytest_lazyfixture import lazy_fixture as lf
-import pandas as pd
-import numpy as np
-
-(
-    counts_sim,
-    covariates_sim,
-    offsets_sim,
-) = get_simulated_count_data(nb_cov=2)
-
-couts_real = get_real_count_data(n_samples=298, dim=101)
-RANKS = [2, 8]
-
-
-@pytest.fixture
-def instance_plnpca():
-    plnpca = PLNPCA(ranks=RANKS)
-    return plnpca
-
-
-@pytest.fixture
-def instance__plnpca():
-    model = _PLNPCA(rank=RANKS[0])
-    return model
-
-
-@pytest.fixture
-def instance_pln_full():
-    return PLN()
-
-
-all_instances = [lf("instance_plnpca"), lf("instance__plnpca"), lf("instance_pln_full")]
-
-
-@pytest.mark.parametrize("instance", all_instances)
-def test_pandas_init(instance):
-    instance.fit(
-        pd.DataFrame(counts_sim.numpy()),
-        pd.DataFrame(covariates_sim.numpy()),
-        pd.DataFrame(offsets_sim.numpy()),
-    )
-
-
-@pytest.mark.parametrize("instance", all_instances)
-def test_numpy_init(instance):
-    instance.fit(counts_sim.numpy(), covariates_sim.numpy(), offsets_sim.numpy())
diff --git a/tests/test_common.py b/tests/test_common.py
index 7e0e6c955a074e8169b643f014bca59583e32d01..b74cf9fbae73676d7c5bcf34643324c63e07ad4a 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -1,192 +1,30 @@
-import torch
-import numpy as np
-import pandas as pd
-
-from pyPLNmodels.models import PLN, _PLNPCA
-from pyPLNmodels import get_simulated_count_data, get_real_count_data
-from tests.utils import MSE
-
-import pytest
-from pytest_lazyfixture import lazy_fixture as lf
 import os
 
-(
-    counts_sim,
-    covariates_sim,
-    offsets_sim,
-    true_covariance,
-    true_coef,
-) = get_simulated_count_data(return_true_param=True, nb_cov=2)
-
-
-counts_real = get_real_count_data()
-rank = 8
-
-
-@pytest.fixture
-def instance_pln_full():
-    pln_full = PLN()
-    return pln_full
-
-
-@pytest.fixture
-def instance__plnpca():
-    plnpca = _PLNPCA(rank=rank)
-    return plnpca
-
-
-@pytest.fixture
-def simulated_fitted_pln_full():
-    pln_full = PLN()
-    pln_full.fit(counts=counts_sim, covariates=covariates_sim, offsets=offsets_sim)
-    return pln_full
-
-
-@pytest.fixture
-def simulated_fitted__plnpca():
-    plnpca = _PLNPCA(rank=rank)
-    plnpca.fit(counts=counts_sim, covariates=covariates_sim, offsets=offsets_sim)
-    return plnpca
-
-
-@pytest.fixture
-def loaded_simulated_pln_full(simulated_fitted_pln_full):
-    simulated_fitted_pln_full.save()
-    loaded_pln_full = PLN()
-    loaded_pln_full.load()
-    return loaded_pln_full
-
-
-@pytest.fixture
-def loaded_refit_simulated_pln_full(loaded_simulated_pln_full):
-    loaded_simulated_pln_full.fit(
-        counts=counts_sim,
-        covariates=covariates_sim,
-        offsets=offsets_sim,
-        keep_going=True,
-    )
-    return loaded_simulated_pln_full
-
-
-@pytest.fixture
-def loaded_simulated__plnpca(simulated_fitted__plnpca):
-    simulated_fitted__plnpca.save()
-    loaded_pln_full = _PLNPCA(rank=rank)
-    loaded_pln_full.load()
-    return loaded_pln_full
-
-
-@pytest.fixture
-def loaded_refit_simulated__plnpca(loaded_simulated__plnpca):
-    loaded_simulated__plnpca.fit(
-        counts=counts_sim,
-        covariates=covariates_sim,
-        offsets=offsets_sim,
-        keep_going=True,
-    )
-    return loaded_simulated__plnpca
-
-
-@pytest.fixture
-def real_fitted_pln_full():
-    pln_full = PLN()
-    pln_full.fit(counts=counts_real)
-    return pln_full
-
-
-@pytest.fixture
-def loaded_real_pln_full(real_fitted_pln_full):
-    real_fitted_pln_full.save()
-    loaded_pln_full = PLN()
-    loaded_pln_full.load()
-    return loaded_pln_full
-
-
-@pytest.fixture
-def loaded_refit_real_pln_full(loaded_real_pln_full):
-    loaded_real_pln_full.fit(counts=counts_real, keep_going=True)
-    return loaded_real_pln_full
-
-
-@pytest.fixture
-def real_fitted__plnpca():
-    plnpca = _PLNPCA(rank=rank)
-    plnpca.fit(counts=counts_real)
-    return plnpca
-
-
-@pytest.fixture
-def loaded_real__plnpca(real_fitted__plnpca):
-    real_fitted__plnpca.save()
-    loaded_plnpca = _PLNPCA(rank=rank)
-    loaded_plnpca.load()
-    return loaded_plnpca
-
-
-@pytest.fixture
-def loaded_refit_real__plnpca(loaded_real__plnpca):
-    loaded_real__plnpca.fit(counts=counts_real, keep_going=True)
-    return loaded_real__plnpca
-
-
-real_pln_full = [
-    lf("real_fitted_pln_full"),
-    lf("loaded_real_pln_full"),
-    lf("loaded_refit_real_pln_full"),
-]
-real__plnpca = [
-    lf("real_fitted__plnpca"),
-    lf("loaded_real__plnpca"),
-    lf("loaded_refit_real__plnpca"),
-]
-simulated_pln_full = [
-    lf("simulated_fitted_pln_full"),
-    lf("loaded_simulated_pln_full"),
-    lf("loaded_refit_simulated_pln_full"),
-]
-simulated__plnpca = [
-    lf("simulated_fitted__plnpca"),
-    lf("loaded_simulated__plnpca"),
-    lf("loaded_refit_simulated__plnpca"),
-]
-
-loaded_sim_pln = [
-    lf("loaded_simulated__plnpca"),
-    lf("loaded_simulated_pln_full"),
-    lf("loaded_refit_simulated_pln_full"),
-    lf("loaded_refit_simulated_pln_full"),
-]
-
-
-@pytest.mark.parametrize("loaded", loaded_sim_pln)
-def test_refit_not_keep_going(loaded):
-    loaded.fit(
-        counts=counts_sim,
-        covariates=covariates_sim,
-        offsets=offsets_sim,
-        keep_going=False,
-    )
-
-
-all_instances = [lf("instance__plnpca"), lf("instance_pln_full")]
+import torch
+import pytest
 
-all_fitted__plnpca = simulated__plnpca + real__plnpca
-all_fitted_pln_full = simulated_pln_full + real_pln_full
+from tests.conftest import dict_fixtures
+from tests.utils import MSE, filter_models
 
-simulated_any_pln = simulated__plnpca + simulated_pln_full
-real_any_pln = real_pln_full + real__plnpca
-all_fitted_models = simulated_any_pln + real_any_pln
+from tests.import_data import true_sim_0cov, true_sim_2cov
 
 
-@pytest.mark.parametrize("any_pln", all_fitted_models)
+@pytest.mark.parametrize("any_pln", dict_fixtures["fitted_pln"])
+@filter_models(["PLN", "_PLNPCA"])
 def test_properties(any_pln):
-    assert hasattr(any_pln, "latent_variables")
-    assert hasattr(any_pln, "model_parameters")
     assert hasattr(any_pln, "latent_parameters")
+    assert hasattr(any_pln, "latent_variables")
     assert hasattr(any_pln, "optim_parameters")
+    assert hasattr(any_pln, "model_parameters")
+
+
+@pytest.mark.parametrize("any_pln", dict_fixtures["loaded_and_fitted_pln"])
+def test_print(any_pln):
+    print(any_pln)
 
 
-@pytest.mark.parametrize("any_pln", all_fitted_models)
+@pytest.mark.parametrize("any_pln", dict_fixtures["fitted_pln"])
+@filter_models(["PLN", "_PLNPCA"])
 def test_show_coef_transform_covariance_pcaprojected(any_pln):
     any_pln.show()
     any_pln.plotargs.show_loss()
@@ -200,137 +38,80 @@ def test_show_coef_transform_covariance_pcaprojected(any_pln):
         any_pln.pca_projected_latent_variables(n_components=any_pln.dim + 1)
 
 
-@pytest.mark.parametrize("sim_pln", simulated_any_pln)
+@pytest.mark.parametrize("sim_pln", dict_fixtures["loaded_and_fitted_pln"])
+@filter_models(["PLN", "_PLNPCA"])
 def test_predict_simulated(sim_pln):
-    X = torch.randn((sim_pln.n_samples, sim_pln.nb_cov - 1))
-    prediction = sim_pln.predict(X)
-    expected = (
-        torch.stack((torch.ones(sim_pln.n_samples, 1), X), axis=1).squeeze()
-        @ sim_pln.coef
-    )
-    assert torch.all(torch.eq(expected, prediction))
-
-
-@pytest.mark.parametrize("real_pln", real_any_pln)
-def test_predict_real(real_pln):
-    prediction = real_pln.predict()
-    expected = torch.ones(real_pln.n_samples, 1) @ real_pln.coef
-    assert torch.all(torch.eq(expected, prediction))
-
-
-@pytest.mark.parametrize("any_pln", all_fitted_models)
-def test_print(any_pln):
-    print(any_pln)
-
-
-@pytest.mark.parametrize("any_instance_pln", all_instances)
+    if sim_pln.nb_cov == 0:
+        assert sim_pln.predict() is None
+        with pytest.raises(AttributeError):
+            sim_pln.predict(1)
+    else:
+        X = torch.randn((sim_pln.n_samples, sim_pln.nb_cov))
+        prediction = sim_pln.predict(X)
+        expected = X @ sim_pln.coef
+        assert torch.all(torch.eq(expected, prediction))
+
+
+@pytest.mark.parametrize("any_instance_pln", dict_fixtures["instances"])
 def test_verbose(any_instance_pln):
-    any_instance_pln.fit(
-        counts=counts_sim, covariates=covariates_sim, offsets=offsets_sim, verbose=True
-    )
-
+    any_instance_pln.fit(verbose=True, tol=0.1)
 
-@pytest.mark.parametrize("sim_pln", simulated_any_pln)
-def test_only_counts(sim_pln):
-    sim_pln.fit(counts=counts_sim)
 
-
-@pytest.mark.parametrize("sim_pln", simulated_any_pln)
-def test_only_counts_and_offsets(sim_pln):
-    sim_pln.fit(counts=counts_sim, offsets=offsets_sim)
-
-
-@pytest.mark.parametrize("sim_pln", simulated_any_pln)
-def test_only_Y_and_cov(sim_pln):
-    sim_pln.fit(counts=counts_sim, covariates=covariates_sim)
-
-
-@pytest.mark.parametrize("simulated_fitted_any_pln", simulated_any_pln)
+@pytest.mark.parametrize(
+    "simulated_fitted_any_pln", dict_fixtures["loaded_and_fitted_sim_pln"]
+)
+@filter_models(["PLN", "_PLNPCA"])
 def test_find_right_covariance(simulated_fitted_any_pln):
+    if simulated_fitted_any_pln.nb_cov == 0:
+        true_covariance = true_sim_0cov["Sigma"]
+    elif simulated_fitted_any_pln.nb_cov == 2:  # NOTE(review): any other nb_cov leaves true_covariance unbound (NameError below) — confirm fixtures only ever use 0 or 2 covariates
+        true_covariance = true_sim_2cov["Sigma"]
     mse_covariance = MSE(simulated_fitted_any_pln.covariance - true_covariance)
     assert mse_covariance < 0.05
 
 
-@pytest.mark.parametrize("sim_pln", simulated_any_pln)
-def test_find_right_coef(sim_pln):
-    mse_coef = MSE(sim_pln.coef - true_coef)
-    assert mse_coef < 0.1
+@pytest.mark.parametrize(
+    "real_fitted_and_loaded_pln", dict_fixtures["loaded_and_fitted_real_pln"]
+)
+@filter_models(["PLN", "_PLNPCA"])
+def test_right_covariance_shape(real_fitted_and_loaded_pln):
+    assert real_fitted_and_loaded_pln.covariance.shape == (100, 100)
 
 
-def test_number_of_iterations_pln_full(simulated_fitted_pln_full):
-    nb_iterations = len(simulated_fitted_pln_full.elbos_list)
-    assert 50 < nb_iterations < 300
+@pytest.mark.parametrize(
+    "simulated_fitted_any_pln", dict_fixtures["loaded_and_fitted_pln"]
+)
+@filter_models(["PLN", "_PLNPCA"])
+def test_find_right_coef(simulated_fitted_any_pln):
+    if simulated_fitted_any_pln.nb_cov == 2:
+        true_coef = true_sim_2cov["beta"]
+        mse_coef = MSE(simulated_fitted_any_pln.coef - true_coef)
+        assert mse_coef < 0.1
+    elif simulated_fitted_any_pln.nb_cov == 0:
+        assert simulated_fitted_any_pln.coef is None
 
 
-def test_computable_elbopca(instance__plnpca, simulated_fitted__plnpca):
-    instance__plnpca.counts = simulated_fitted__plnpca.counts
-    instance__plnpca.covariates = simulated_fitted__plnpca.covariates
-    instance__plnpca.offsets = simulated_fitted__plnpca.offsets
-    instance__plnpca.latent_mean = simulated_fitted__plnpca.latent_mean
-    instance__plnpca.latent_var = simulated_fitted__plnpca.latent_var
-    instance__plnpca.components = simulated_fitted__plnpca.components
-    instance__plnpca.coef = simulated_fitted__plnpca.coef
-    instance__plnpca.compute_elbo()
-
-
-def test_computable_elbo_full(instance_pln_full, simulated_fitted_pln_full):
-    instance_pln_full.counts = simulated_fitted_pln_full.counts
-    instance_pln_full.covariates = simulated_fitted_pln_full.covariates
-    instance_pln_full.offsets = simulated_fitted_pln_full.offsets
-    instance_pln_full.latent_mean = simulated_fitted_pln_full.latent_mean
-    instance_pln_full.latent_var = simulated_fitted_pln_full.latent_var
-    instance_pln_full.covariance = simulated_fitted_pln_full.covariance
-    instance_pln_full.coef = simulated_fitted_pln_full.coef
-    instance_pln_full.compute_elbo()
-
-
-def test_fail_count_setter(simulated_fitted_pln_full):
+@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"])
+@filter_models(["PLN", "_PLNPCA"])
+def test_fail_count_setter(pln):
     wrong_counts = torch.randint(size=(10, 5), low=0, high=10)
     with pytest.raises(Exception):
-        simulated_fitted_pln_full.counts = wrong_counts
-
+        pln.counts = wrong_counts
 
-@pytest.mark.parametrize("any_pln", all_fitted_models)
-def test_setter_with_numpy(any_pln):
-    np_counts = any_pln.counts.numpy()
-    any_pln.counts = np_counts
 
-
-@pytest.mark.parametrize("any_pln", all_fitted_models)
-def test_setter_with_pandas(any_pln):
-    pd_counts = pd.DataFrame(any_pln.counts.numpy())
-    any_pln.counts = pd_counts
-
-
-@pytest.mark.parametrize("instance", all_instances)
+@pytest.mark.parametrize("instance", dict_fixtures["instances"])
 def test_random_init(instance):
-    instance.fit(counts_sim, covariates_sim, offsets_sim, do_smart_init=False)
+    instance.fit(do_smart_init=False)
 
 
-@pytest.mark.parametrize("instance", all_instances)
+@pytest.mark.parametrize("instance", dict_fixtures["instances"])
 def test_print_end_of_fitting_message(instance):
-    instance.fit(counts_sim, covariates_sim, offsets_sim, nb_max_iteration=4)
+    instance.fit(nb_max_iteration=4)
 
 
-@pytest.mark.parametrize("any_pln", all_fitted_models)
-def test_fail_wrong_covariates_prediction(any_pln):
-    X = torch.randn(any_pln.n_samples, any_pln.nb_cov)
+@pytest.mark.parametrize("pln", dict_fixtures["fitted_pln"])
+@filter_models(["PLN", "_PLNPCA"])
+def test_fail_wrong_covariates_prediction(pln):
+    X = torch.randn(pln.n_samples, pln.nb_cov + 1)
     with pytest.raises(Exception):
-        any_pln.predict(X)
-
-
-@pytest.mark.parametrize("any__plnpca", all_fitted__plnpca)
-def test_latent_var_pca(any__plnpca):
-    assert any__plnpca.transform(project=False).shape == any__plnpca.counts.shape
-    assert any__plnpca.transform().shape == (any__plnpca.n_samples, any__plnpca.rank)
-
-
-@pytest.mark.parametrize("any_pln_full", all_fitted_pln_full)
-def test_latent_var_pln_full(any_pln_full):
-    assert any_pln_full.transform().shape == any_pln_full.counts.shape
-
-
-def test_wrong_rank():
-    instance = _PLNPCA(counts_sim.shape[1] + 1)
-    with pytest.warns(UserWarning):
-        instance.fit(counts=counts_sim, covariates=covariates_sim, offsets=offsets_sim)
+        pln.predict(X)
diff --git a/tests/test_pln_full.py b/tests/test_pln_full.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f9d6a5d06dd4e207db265ed5981629d9da19e88
--- /dev/null
+++ b/tests/test_pln_full.py
@@ -0,0 +1,17 @@
+import pytest
+
+from tests.conftest import dict_fixtures
+from tests.utils import filter_models
+
+
+@pytest.mark.parametrize("fitted_pln", dict_fixtures["fitted_pln"])
+@filter_models(["PLN"])
+def test_number_of_iterations_pln_full(fitted_pln):
+    nb_iterations = len(fitted_pln.elbos_list)
+    assert 50 < nb_iterations < 300
+
+
+@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"])
+@filter_models(["PLN"])
+def test_latent_var_full(pln):
+    assert pln.transform().shape == pln.counts.shape
diff --git a/tests/test_plnpca.py b/tests/test_plnpca.py
index db0e324aac2d70ff3b98a517aca8fec9f01b0f4f..9eb1b2f4289efce04cef20a824570e407aede96e 100644
--- a/tests/test_plnpca.py
+++ b/tests/test_plnpca.py
@@ -1,158 +1,76 @@
 import os
 
 import pytest
-from pytest_lazyfixture import lazy_fixture as lf
-from pyPLNmodels.models import PLNPCA, _PLNPCA
-from pyPLNmodels import get_simulated_count_data, get_real_count_data
-from tests.utils import MSE
-
 import matplotlib.pyplot as plt
 import numpy as np
 
-(
-    counts_sim,
-    covariates_sim,
-    offsets_sim,
-    true_covariance,
-    true_coef,
-) = get_simulated_count_data(return_true_param=True)
-
-counts_real = get_real_count_data()
-RANKS = [2, 8]
-
-
-@pytest.fixture
-def my_instance_plnpca():
-    plnpca = PLNPCA(ranks=RANKS)
-    return plnpca
-
-
-@pytest.fixture
-def real_fitted_plnpca(my_instance_plnpca):
-    my_instance_plnpca.fit(counts_real)
-    return my_instance_plnpca
-
-
-@pytest.fixture
-def simulated_fitted_plnpca(my_instance_plnpca):
-    my_instance_plnpca.fit(
-        counts=counts_sim, covariates=covariates_sim, offsets=offsets_sim
-    )
-    return my_instance_plnpca
-
-
-@pytest.fixture
-def one_simulated_fitted_plnpca():
-    model = PLNPCA(ranks=2)
-    model.fit(counts=counts_sim, covariates=covariates_sim, offsets=offsets_sim)
-    return model
-
-
-@pytest.fixture
-def real_best_aic(real_fitted_plnpca):
-    return real_fitted_plnpca.best_model("AIC")
-
-
-@pytest.fixture
-def real_best_bic(real_fitted_plnpca):
-    return real_fitted_plnpca.best_model("BIC")
-
-
-@pytest.fixture
-def simulated_best_aic(simulated_fitted_plnpca):
-    return simulated_fitted_plnpca.best_model("AIC")
-
-
-@pytest.fixture
-def simulated_best_bic(simulated_fitted_plnpca):
-    return simulated_fitted_plnpca.best_model("BIC")
+from tests.conftest import dict_fixtures
+from tests.utils import MSE, filter_models
 
 
-simulated_best_models = [lf("simulated_best_aic"), lf("simulated_best_bic")]
-real_best_models = [lf("real_best_aic"), lf("real_best_bic")]
-best_models = simulated_best_models + real_best_models
-
-
-all_fitted_simulated_plnpca = [
-    lf("simulated_fitted_plnpca"),
-    lf("one_simulated_fitted_plnpca"),
-]
-all_fitted_plnpca = [lf("real_fitted_plnpca")] + all_fitted_simulated_plnpca
-
-
-def test_print_plnpca(simulated_fitted_plnpca):
-    print(simulated_fitted_plnpca)
-
-
-@pytest.mark.parametrize("best_model", best_models)
-def test_best_model(best_model):
+@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"])
+@filter_models(["PLNPCA"])
+def test_best_model(plnpca):
+    best_model = plnpca.best_model()
     print(best_model)
 
 
-@pytest.mark.parametrize("best_model", best_models)
-def test_projected_variables(best_model):
+@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"])
+@filter_models(["PLNPCA"])
+def test_projected_variables(plnpca):
+    best_model = plnpca.best_model()
     plv = best_model.projected_latent_variables
     assert plv.shape[0] == best_model.n_samples and plv.shape[1] == best_model.rank
 
 
-def test_save_load_back_and_refit(simulated_fitted_plnpca):
-    simulated_fitted_plnpca.save()
-    new = PLNPCA(ranks=RANKS)
-    new.load()
-    new.fit(counts=counts_sim, covariates=covariates_sim, offsets=offsets_sim)
-
-
-@pytest.mark.parametrize("plnpca", all_fitted_simulated_plnpca)
-def test_find_right_covariance(plnpca):
-    passed = True
-    for model in plnpca.models:
-        mse_covariance = MSE(model.covariance - true_covariance)
-        assert mse_covariance < 0.3
-
-
-@pytest.mark.parametrize("plnpca", all_fitted_simulated_plnpca)
-def test_find_right_coef(plnpca):
-    for model in plnpca.models:
-        mse_coef = MSE(model.coef - true_coef)
-        assert mse_coef < 0.3
-
-
-@pytest.mark.parametrize("all_pca", all_fitted_plnpca)
-def test_additional_methods_pca(all_pca):
-    all_pca.show()
-    all_pca.BIC
-    all_pca.AIC
-    all_pca.loglikes
-
-
-@pytest.mark.parametrize("all_pca", all_fitted_plnpca)
-def test_viz_pca(all_pca):
-    _, ax = plt.subplots()
-    all_pca[2].viz(ax=ax)
-    plt.show()
-    all_pca[2].viz()
-    plt.show()
-    n_samples = all_pca.n_samples
-    colors = np.random.randint(low=0, high=2, size=n_samples)
-    all_pca[2].viz(colors=colors)
-    plt.show()
-
-
-@pytest.mark.parametrize(
-    "pca", [lf("real_fitted_plnpca"), lf("simulated_fitted_plnpca")]
-)
-def test_fails_viz_pca(pca):
-    with pytest.raises(RuntimeError):
-        pca[8].viz()
-
-
-@pytest.mark.parametrize("all_pca", all_fitted_plnpca)
-def test_closest(all_pca):
+@pytest.mark.parametrize("fitted_pln", dict_fixtures["fitted_pln"])
+@filter_models(["_PLNPCA"])
+def test_number_of_iterations_plnpca(fitted_pln):
+    nb_iterations = len(fitted_pln.elbos_list)
+    assert 100 < nb_iterations < 5000
+
+
+@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"])
+@filter_models(["_PLNPCA"])
+def test_latent_var_pca(plnpca):
+    assert plnpca.transform(project=False).shape == plnpca.counts.shape
+    assert plnpca.transform().shape == (plnpca.n_samples, plnpca.rank)
+
+
+@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"])
+@filter_models(["PLNPCA"])
+def test_additional_methods_pca(plnpca):
+    plnpca.show()
+    plnpca.BIC
+    plnpca.AIC
+    plnpca.loglikes
+
+
+@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"])
+@filter_models(["PLNPCA"])
+def test_viz_pca(plnpca):
+    models = plnpca.models
+    for model in models:
+        _, ax = plt.subplots()
+        model.viz(ax=ax)
+        plt.show()
+        model.viz()
+        plt.show()
+        n_samples = plnpca.n_samples
+        colors = np.random.randint(low=0, high=2, size=n_samples)
+        model.viz(colors=colors)
+        plt.show()
+
+
+@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"])
+@filter_models(["PLNPCA"])
+def test_closest(plnpca):
     with pytest.warns(UserWarning):
-        all_pca[9]
+        plnpca[9]
 
 
-@pytest.mark.parametrize("plnpca", all_fitted_plnpca)
+@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"])
+@filter_models(["PLNPCA"])
 def test_wrong_criterion(plnpca):
     with pytest.raises(ValueError):
         plnpca.best_model("AIK")
diff --git a/tests/test_setters.py b/tests/test_setters.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1a9ba29c09d21faa2a10bc5dae71dbe368cbd7e
--- /dev/null
+++ b/tests/test_setters.py
@@ -0,0 +1,21 @@
+import pytest
+import pandas as pd
+
+from tests.conftest import dict_fixtures
+from tests.utils import filter_models
+
+
+@pytest.mark.parametrize("pln", dict_fixtures["all_pln"])
+@filter_models(["PLN", "PLNPCA"])
+def test_setter_with_numpy(pln):
+    np_counts = pln.counts.numpy()
+    pln.counts = np_counts
+    pln.fit()
+
+
+@pytest.mark.parametrize("pln", dict_fixtures["all_pln"])
+@filter_models(["PLN", "PLNPCA"])
+def test_setter_with_pandas(pln):
+    pd_counts = pd.DataFrame(pln.counts.numpy())
+    pln.counts = pd_counts
+    pln.fit()
diff --git a/tests/utils.py b/tests/utils.py
index 0cc7f2d7b9600b724015896ed97813699e153239..2e33a3d0111519d0e036040a43a38d74ecc3dafa 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,5 +1,20 @@
 import torch
+import functools
 
 
 def MSE(t):
     return torch.mean(t**2)
+
+
+def filter_models(models_name):
+    def decorator(my_test):
+        @functools.wraps(my_test)
+        def new_test(**kwargs):
+            fixture = next(iter(kwargs.values()))
+            if type(fixture).__name__ not in models_name:
+                return None  # NOTE(review): filtered-out models silently report as passed; pytest.skip would be clearer
+            return my_test(**kwargs)
+
+        return new_test
+
+    return decorator