Source code for metacells.tools.project

"""
Project
-------
"""

from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

import numpy as np
import scipy.sparse as sp  # type: ignore
from anndata import AnnData  # type: ignore

import metacells.parameters as pr
import metacells.utilities as ut

__all__ = [
    "compute_projection_weights",
    "compute_projected_fractions",
    "convey_atlas_to_query",
]


@ut.logged()
@ut.timed_call()
@ut.expand_doc()
def compute_projection_weights(
    *,
    adata: AnnData,
    qdata: AnnData,
    from_atlas_layer: str = "corrected_fraction",
    from_query_layer: str = "corrected_fraction",
    to_query_layer: str = "projected_fraction",
    log_data: bool = pr.project_log_data,
    fold_regularization: float = pr.project_fold_regularization,
    min_significant_gene_umis: float = pr.project_min_significant_gene_umis,
    max_consistency_fold_factor: float = pr.project_max_consistency_fold_factor,
    candidates_count: int = pr.project_candidates_count,
    min_candidates_fraction: float = pr.project_min_candidates_fraction,
    min_usage_weight: float = pr.project_min_usage_weight,
    second_anchor_indices: Optional[List[int]] = None,
    reproducible: bool,
) -> ut.CompressedMatrix:
    """
    Compute the weights and results of projecting a query onto an atlas.

    **Input**

    Annotated query ``qdata`` and atlas ``adata``, where the observations are cells and the variables are genes.

    The atlas should contain ``from_atlas_layer`` (default: {from_atlas_layer}) containing gene fractions, and the
    query should similarly contain ``from_query_layer`` (default: {from_query_layer}) containing gene fractions.

    **Returns**

    A CSR matrix whose rows are query metacells and columns are atlas metacells, where each entry is the weight of
    the atlas metacell in the projection of the query metacell. The sum of weights in each row (that is, for a
    single query metacell) is 1. The weighted sum of the atlas metacells using these weights is the "projected"
    image of the query metacell onto the atlas.

    In addition, sets the following annotations in ``qdata``:

    Observation (Cell) Annotations
        ``similar``
            A boolean mask indicating whether the query metacell is similar to its projection onto the atlas. If
            ``False``, the metacell is said to be "dissimilar", which may indicate the query contains cell states
            that do not appear in the atlas.

    Observation-Variable (Cell-Gene) Annotations
        ``to_query_layer`` (default: {to_query_layer})
            A matrix of gene fractions describing the "projected" image of the query metacell onto the atlas. This
            projection is a weighted average of some atlas metacells (using the computed weights returned by this
            function).

    **Computation Parameters**

    0. All fold computations (log2 of the ratio between gene fractions) use the ``fold_regularization`` (default:
       {fold_regularization}).

    For each query metacell:

    1. Correlate the metacell with all the atlas metacells, and pick the highest-correlated one as the "anchor".
       If ``second_anchor_indices`` is not ``None``, then the ``qdata`` must contain only a single query metacell,
       and is expected to contain a ``projected`` per-observation-per-variable matrix containing the projected
       image of this query metacell on the atlas using a single anchor. The code will compute the residual of the
       query and the atlas relative to this projection and pick a second atlas anchor whose residuals are the most
       correlated to the query metacell's residuals. If ``reproducible``, a slower (still parallel) but
       reproducible algorithm will be used.

    2. Consider (for each anchor) the ``candidates_count`` (default: {candidates_count}) candidate metacells with
       the highest correlation with the query metacell.

    3. Keep as candidates only atlas metacells whose maximal gene fold factor compared to the anchor(s) is at most
       ``max_consistency_fold_factor`` (default: {max_consistency_fold_factor}). Keep at least
       ``min_candidates_fraction`` (default: {min_candidates_fraction}) of the original candidates even if they
       are less consistent. For this computation, ignore the fold factors of genes whose sum of UMIs in the
       anchor(s) and the candidate metacells is less than ``min_significant_gene_umis`` (default:
       {min_significant_gene_umis}).

    4. Compute the non-negative weights (with a sum of 1) of the selected candidates that give the best projection
       of the query metacell onto the atlas. If ``log_data`` (default: {log_data}), try to fit the log (base 2) of
       the fractions, otherwise, try to fit the fractions themselves. Since the algorithm for computing these
       weights rarely produces an exact 0 weight, reduce all weights less than the ``min_usage_weight`` (default:
       {min_usage_weight}) to zero.

    If ``second_anchor_indices`` is not ``None``, it is set to the list of indices of the used atlas metacell
    candidates correlated with the second anchor.
    """
    prepared_arguments = _project_query_atlas_data_arguments(
        adata=adata,
        qdata=qdata,
        from_atlas_layer=from_atlas_layer,
        from_query_layer=from_query_layer,
        to_query_layer=to_query_layer,
        log_data=log_data,
        fold_regularization=fold_regularization,
        min_significant_gene_umis=min_significant_gene_umis,
        max_consistency_fold_factor=max_consistency_fold_factor,
        candidates_count=candidates_count,
        min_candidates_fraction=min_candidates_fraction,
        min_usage_weight=min_usage_weight,
        second_anchor_indices=second_anchor_indices,
        reproducible=reproducible,
    )

    @ut.timed_call()
    def _project_single(query_metacell_index: int) -> Tuple[ut.NumpyVector, ut.NumpyVector]:
        return _project_single_metacell(
            query_metacell_index=query_metacell_index,
            **prepared_arguments,
        )

    if ut.is_main_process():
        results = ut.parallel_map(_project_single, qdata.n_obs)
    else:
        results = [_project_single(query_metacell_index) for query_metacell_index in range(qdata.n_obs)]

    indices = np.concatenate([result[0] for result in results], dtype="int32")
    data = np.concatenate([result[1] for result in results], dtype="float32")

    atlas_used_sizes = [len(result[0]) for result in results]
    atlas_used_sizes.insert(0, 0)
    indptr = np.cumsum(np.array(atlas_used_sizes))

    return sp.csr_matrix((data, indices, indptr), shape=(qdata.n_obs, adata.n_obs))
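

# Minimal usage sketch (not part of the library API): project a query metacells AnnData onto an atlas
# metacells AnnData. It assumes, as documented above, that both objects already contain a
# "corrected_fraction" layer and that the atlas contains a "total_umis" layer; the names
# ``atlas_mdata`` and ``query_mdata`` are placeholders.
def _example_compute_projection_weights(atlas_mdata: AnnData, query_mdata: AnnData) -> ut.CompressedMatrix:
    # Returns a CSR matrix of shape (query metacells, atlas metacells) whose rows each sum to 1.
    return compute_projection_weights(
        adata=atlas_mdata,
        qdata=query_mdata,
        reproducible=True,  # use the slower (still parallel) but reproducible correlation code
    )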


@ut.logged()
@ut.timed_call()
@ut.expand_doc()
def compute_projected_fractions(
    *,
    adata: AnnData,
    qdata: AnnData,
    from_atlas_layer: str = "corrected_fraction",
    to_query_layer: str = "projected_fraction",
    log_data: bool = pr.project_log_data,
    fold_regularization: float = pr.project_fold_regularization,
    weights: ut.ProperMatrix,
) -> None:
    """
    Compute the projected image of a query on an atlas.

    **Input**

    Annotated query ``qdata`` and atlas ``adata``, where the observations are cells and the variables are genes.
    The atlas should contain ``from_atlas_layer`` (default: {from_atlas_layer}) containing gene fractions.

    **Returns**

    Sets ``to_query_layer`` (default: {to_query_layer}) in the query containing the gene fractions of the
    projection of the atlas fractions using the ``weights`` matrix.

    .. note::

        It is important to use the same ``log_data`` value as that given to ``compute_projection_weights`` to
        compute the weights (default: {log_data}).
    """
    assert fold_regularization > 0

    atlas_fractions = ut.get_vo_proper(adata, from_atlas_layer, layout="row_major")
    if log_data:
        atlas_log_fractions = ut.to_numpy_matrix(atlas_fractions, copy=True)
        atlas_log_fractions += fold_regularization
        atlas_log_fractions = np.log2(atlas_log_fractions, out=atlas_log_fractions)
        projected_fractions = weights @ atlas_log_fractions
        projected_fractions = np.power(2.0, projected_fractions, out=projected_fractions)
        projected_fractions -= fold_regularization
    else:
        projected_fractions = weights @ atlas_fractions  # type: ignore
    assert projected_fractions.shape == qdata.shape

    projected_fractions = ut.fraction_by(projected_fractions, by="row").astype("float32")  # type: ignore
    ut.set_vo_data(qdata, to_query_layer, sp.csr_matrix(projected_fractions))
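

# Illustrative numpy sketch (assumed shapes and regularization value, not library code) of the log-space
# projection performed above when ``log_data`` is True: the weighted average is taken over
# log2(fraction + regularization), mapped back to fraction space, and each row is renormalized to sum to 1.
def _example_log_space_projection(
    weights: np.ndarray,  # (query metacells, atlas metacells), rows sum to 1
    atlas_fractions: np.ndarray,  # (atlas metacells, genes)
    fold_regularization: float = 1e-5,  # assumed regularization value
) -> np.ndarray:
    log_atlas = np.log2(atlas_fractions + fold_regularization)
    projected = 2.0 ** (weights @ log_atlas) - fold_regularization
    # The routine above renormalizes with ``ut.fraction_by``; a plain row normalization is used here
    # to keep the sketch self-contained.
    return projected / projected.sum(axis=1, keepdims=True)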


def _project_query_atlas_data_arguments(
    adata: AnnData,
    qdata: AnnData,
    from_atlas_layer: str,
    from_query_layer: str,
    to_query_layer: str,
    log_data: bool,
    fold_regularization: float,
    min_significant_gene_umis: float,
    max_consistency_fold_factor: float,
    candidates_count: int,
    min_candidates_fraction: float,
    min_usage_weight: float,
    second_anchor_indices: Optional[List[int]],
    reproducible: bool,
) -> Dict[str, Any]:
    assert fold_regularization > 0
    assert candidates_count > 0
    assert 0 <= min_candidates_fraction <= 1.0
    assert min_usage_weight >= 0
    assert max_consistency_fold_factor >= 0
    assert np.all(adata.var_names == qdata.var_names)

    atlas_umis = ut.get_vo_proper(adata, "total_umis", layout="row_major")
    atlas_fractions = ut.get_vo_proper(adata, from_atlas_layer, layout="row_major")
    query_fractions = ut.get_vo_proper(qdata, from_query_layer, layout="row_major")

    atlas_fractions = ut.to_numpy_matrix(atlas_fractions, copy=True)
    query_fractions = ut.to_numpy_matrix(query_fractions, copy=True)

    if second_anchor_indices is not None:
        assert qdata.n_obs == 1
        query_single_fractions = ut.to_numpy_vector(ut.get_vo_proper(qdata, to_query_layer))

        query_residual_fractions = query_fractions - query_single_fractions[np.newaxis, :]
        query_residual_fractions[query_residual_fractions < 0] = 0

        atlas_residual_fractions = atlas_fractions - ut.to_numpy_vector(query_residual_fractions)[np.newaxis, :]
        atlas_residual_fractions[atlas_residual_fractions < 0] = 0

        if log_data:
            atlas_residual_fractions += fold_regularization
            query_residual_fractions += fold_regularization
            atlas_residual_data = np.log2(atlas_residual_fractions, out=atlas_residual_fractions)
            query_residual_data = np.log2(query_residual_fractions, out=query_residual_fractions)
        else:
            atlas_residual_data = atlas_residual_fractions
            query_residual_data = query_residual_fractions

        query_atlas_corr_residual: Optional[ut.NumpyMatrix] = ut.cross_corrcoef_rows(
            query_residual_data, atlas_residual_data, reproducible=reproducible
        )
    else:
        query_atlas_corr_residual = None

    atlas_log_fractions = atlas_fractions + fold_regularization
    atlas_log_fractions = np.log2(atlas_log_fractions, out=atlas_log_fractions)

    query_log_fractions = query_fractions + fold_regularization
    query_log_fractions = np.log2(query_log_fractions, out=query_log_fractions)

    if log_data:
        atlas_data = atlas_log_fractions
        query_data = query_log_fractions
    else:
        atlas_data = atlas_fractions
        query_data = query_fractions

    query_atlas_corr = ut.cross_corrcoef_rows(query_log_fractions, atlas_log_fractions, reproducible=reproducible)

    return {
        "atlas_umis": atlas_umis,
        "query_atlas_corr": query_atlas_corr,
        "atlas_data": atlas_data,
        "atlas_log_fractions": atlas_log_fractions,
        "query_data": query_data,
        "candidates_count": candidates_count,
        "min_candidates_fraction": min_candidates_fraction,
        "min_significant_gene_umis": min_significant_gene_umis,
        "min_usage_weight": min_usage_weight,
        "max_consistency_fold_factor": max_consistency_fold_factor,
        "second_anchor_indices": second_anchor_indices,
        "query_atlas_corr_residual": query_atlas_corr_residual,
    }


@ut.logged()
def _project_single_metacell(  # pylint: disable=too-many-statements,too-many-branches
    *,
    query_metacell_index: int,
    atlas_umis: ut.Matrix,
    query_atlas_corr: ut.NumpyMatrix,
    atlas_data: ut.NumpyMatrix,
    atlas_log_fractions: ut.NumpyMatrix,
    query_data: ut.NumpyMatrix,
    candidates_count: int,
    min_candidates_fraction: float,
    min_significant_gene_umis: float,
    min_usage_weight: float,
    max_consistency_fold_factor: float,
    second_anchor_indices: Optional[List[int]],
    query_atlas_corr_residual: Optional[ut.NumpyMatrix],
) -> Tuple[ut.NumpyVector, ut.NumpyVector]:
    query_metacell_data = query_data[query_metacell_index, :]

    query_metacell_atlas_correlations = query_atlas_corr[query_metacell_index, :]
    query_metacell_atlas_order = np.argsort(-query_metacell_atlas_correlations)

    atlas_anchor_index = query_metacell_atlas_order[0]
    ut.log_calc("atlas_anchor_index", atlas_anchor_index)
    atlas_anchor_log_fractions = atlas_log_fractions[atlas_anchor_index, :]
    atlas_anchor_umis = ut.to_numpy_vector(atlas_umis[atlas_anchor_index, :])

    atlas_candidates_consistency = [0.0]
    atlas_candidates_indices = [atlas_anchor_index]

    position = 1
    while len(atlas_candidates_indices) < candidates_count and position < len(query_metacell_atlas_order):
        atlas_metacell_index = query_metacell_atlas_order[position]
        position += 1

        atlas_metacell_log_fractions = atlas_log_fractions[atlas_metacell_index, :]
        atlas_metacell_consistency_fold_factors = np.abs(atlas_metacell_log_fractions - atlas_anchor_log_fractions)
        atlas_metacell_umis = ut.to_numpy_vector(atlas_umis[atlas_metacell_index, :])
        atlas_metacell_significant_genes_mask = atlas_metacell_umis + atlas_anchor_umis >= min_significant_gene_umis
        if np.any(atlas_metacell_significant_genes_mask):
            atlas_metacell_consistency = np.max(
                atlas_metacell_consistency_fold_factors[atlas_metacell_significant_genes_mask]
            )
            atlas_candidates_consistency.append(atlas_metacell_consistency)
            atlas_candidates_indices.append(atlas_metacell_index)

    sorted_locations = list(np.argsort(np.array(atlas_candidates_consistency)))
    min_candidates_count = candidates_count * min_candidates_fraction

    while (
        len(sorted_locations) > min_candidates_count
        and atlas_candidates_consistency[sorted_locations[-1]] > max_consistency_fold_factor
    ):
        sorted_locations.pop()

    atlas_candidate_indices_set = set([atlas_anchor_index])
    for location in sorted_locations:
        atlas_metacell_index = atlas_candidates_indices[location]
        atlas_candidate_indices_set.add(atlas_metacell_index)

    ut.log_calc("atlas_candidates", len(atlas_candidate_indices_set))

    if query_atlas_corr_residual is None:
        atlas_candidate_indices = np.array(sorted(atlas_candidate_indices_set))
    else:
        query_metacell_atlas_residual_correlations = query_atlas_corr_residual[query_metacell_index, :]
        query_metacell_atlas_residual_order = np.argsort(-query_metacell_atlas_residual_correlations)

        atlas_secondary_anchor_index = query_metacell_atlas_residual_order[0]
        ut.log_calc("atlas_secondary_anchor_index", atlas_secondary_anchor_index)

        atlas_secondary_anchor_log_fractions = atlas_log_fractions[atlas_secondary_anchor_index, :]
        atlas_secondary_anchor_umis = ut.to_numpy_vector(atlas_umis[atlas_secondary_anchor_index, :])

        atlas_secondary_candidates_consistency = [0.0]
        atlas_secondary_candidates_indices = [atlas_secondary_anchor_index]

        position = 1
        while len(atlas_secondary_candidates_indices) < candidates_count and position < len(
            query_metacell_atlas_residual_order
        ):
            atlas_metacell_index = query_metacell_atlas_residual_order[position]
            position += 1

            atlas_metacell_log_fractions = atlas_log_fractions[atlas_metacell_index, :]
            atlas_metacell_consistency_fold_factors = np.abs(
                atlas_metacell_log_fractions - atlas_secondary_anchor_log_fractions
            )
            atlas_metacell_umis = ut.to_numpy_vector(atlas_umis[atlas_metacell_index, :])
            atlas_metacell_significant_genes_mask = (
                atlas_metacell_umis + atlas_secondary_anchor_umis >= min_significant_gene_umis
            )
            if np.any(atlas_metacell_significant_genes_mask):
                atlas_metacell_consistency = np.max(
                    atlas_metacell_consistency_fold_factors[atlas_metacell_significant_genes_mask]
                )
                atlas_secondary_candidates_consistency.append(atlas_metacell_consistency)
                atlas_secondary_candidates_indices.append(atlas_metacell_index)

        sorted_secondary_locations = list(np.argsort(np.array(atlas_secondary_candidates_consistency)))

        while (
            len(sorted_secondary_locations) > min_candidates_count
            and atlas_secondary_candidates_consistency[sorted_secondary_locations[-1]] > max_consistency_fold_factor
        ):
            sorted_secondary_locations.pop()

        atlas_secondary_candidate_indices_set = set([atlas_secondary_anchor_index])
        for location in sorted_secondary_locations:
            atlas_metacell_index = atlas_secondary_candidates_indices[location]
            atlas_secondary_candidate_indices_set.add(atlas_metacell_index)

        ut.log_calc("atlas_secondary_candidates", len(atlas_secondary_candidate_indices_set))

        atlas_candidate_indices = np.array(sorted(atlas_candidate_indices_set | atlas_secondary_candidate_indices_set))

    atlas_candidates_data = atlas_data[atlas_candidate_indices, :]

    represent_result = ut.represent(query_metacell_data, atlas_candidates_data)
    assert represent_result is not None
    atlas_candidate_weights = represent_result[1]
    atlas_candidate_weights[atlas_candidate_weights < min_usage_weight] = 0
    atlas_candidate_weights[:] /= np.sum(atlas_candidate_weights)

    atlas_used_mask = atlas_candidate_weights > 0

    atlas_used_indices = atlas_candidate_indices[atlas_used_mask].astype("int32")
    ut.log_return("atlas_used_indices", atlas_used_indices)

    if second_anchor_indices is not None:
        for atlas_metacell_index in atlas_used_indices:
            if atlas_metacell_index not in atlas_candidate_indices_set:
                second_anchor_indices.append(atlas_metacell_index)

    atlas_used_weights = atlas_candidate_weights[atlas_used_mask]
    atlas_used_weights = atlas_used_weights.astype("float32")
    ut.log_return("atlas_used_weights", atlas_used_weights)

    return (atlas_used_indices, atlas_used_weights)
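

# Illustrative sketch (assumed inputs and threshold value, not library code) of the per-candidate consistency
# test used above: a candidate is scored by its maximal absolute log2 fold factor versus the anchor, computed
# only over genes whose combined UMI count in the anchor and the candidate is significant.
def _example_candidate_consistency(
    candidate_log_fractions: np.ndarray,  # log2(fraction + regularization) of one atlas candidate
    anchor_log_fractions: np.ndarray,  # log2(fraction + regularization) of the anchor
    candidate_umis: np.ndarray,
    anchor_umis: np.ndarray,
    min_significant_gene_umis: float = 40,  # assumed threshold value
) -> Optional[float]:
    fold_factors = np.abs(candidate_log_fractions - anchor_log_fractions)
    significant_genes_mask = candidate_umis + anchor_umis >= min_significant_gene_umis
    if not np.any(significant_genes_mask):
        return None  # no gene has enough UMIs to judge consistency
    return float(np.max(fold_factors[significant_genes_mask]))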


@ut.logged()
@ut.timed_call()
def convey_atlas_to_query(
    *,
    adata: AnnData,
    qdata: AnnData,
    weights: ut.ProperMatrix,
    property_name: str,
    formatter: Optional[Callable[[Any], Any]] = None,
    to_property_name: Optional[str] = None,
    method: Callable[[ut.Vector, ut.Vector], Any] = ut.highest_weight,
) -> None:
    """
    Convey the value of a property from per-observation atlas data to per-observation query data.

    The input annotated ``adata`` is expected to contain a per-observation (cell) annotation named
    ``property_name``. Given the ``weights`` matrix, where each row specifies the weights of the atlas metacells
    used to project a single query metacell, this will generate a new per-observation (group) annotation in
    ``qdata``, named ``to_property_name`` (by default, the same as ``property_name``), containing the aggregated
    value of the property of all the observations (cells) that belong to the group.

    The aggregation method (by default, :py:func:`metacells.utilities.computation.highest_weight`) is any function
    taking two arrays, weights and values, and returning a single value.
    """
    if to_property_name is None:
        to_property_name = property_name

    property_of_atlas_metacells = ut.get_o_numpy(adata, property_name, formatter=formatter)
    property_of_query_metacells = []
    for query_metacell_index in range(qdata.n_obs):
        atlas_metacell_weights = ut.to_numpy_vector(weights[query_metacell_index, :])
        used_atlas_metacells_mask = atlas_metacell_weights > 0
        assert np.any(used_atlas_metacells_mask)
        used_atlas_metacell_weights = ut.to_numpy_vector(atlas_metacell_weights[used_atlas_metacells_mask])
        used_atlas_metacell_values = property_of_atlas_metacells[used_atlas_metacells_mask]
        property_of_query_metacells.append(method(used_atlas_metacell_weights, used_atlas_metacell_values))

    query_property_data = np.array(property_of_query_metacells, dtype=property_of_atlas_metacells.dtype)
    ut.set_o_data(qdata, to_property_name, query_property_data)
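

# Minimal usage sketch (not part of the library API): convey a per-metacell annotation from the atlas to the
# query using the weights returned by compute_projection_weights. The annotation names "type" and
# "projected_type" are placeholders; use whatever per-observation annotation the atlas actually carries.
def _example_convey_type(atlas_mdata: AnnData, query_mdata: AnnData, weights: ut.ProperMatrix) -> None:
    convey_atlas_to_query(
        adata=atlas_mdata,
        qdata=query_mdata,
        weights=weights,
        property_name="type",
        to_property_name="projected_type",  # hypothetical output annotation name
    )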