"""
Projection
----------
"""
from math import ceil
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
import numpy as np
import scipy.sparse as sp # type: ignore
from anndata import AnnData # type: ignore
import metacells.parameters as pr
import metacells.tools as tl
import metacells.utilities as ut
from metacells import __version__ # pylint: disable=cyclic-import
__all__ = [
"projection_pipeline",
"outliers_projection_pipeline",
"write_projection_weights",
]
[docs]
@ut.logged()
@ut.timed_call()
@ut.expand_doc()
def projection_pipeline(
what: str = "__x__",
*,
adata: AnnData,
qdata: AnnData,
only_atlas_marker_genes: bool = pr.only_atlas_marker_genes,
only_query_marker_genes: bool = pr.only_query_marker_genes,
ignore_atlas_lateral_genes: bool = pr.ignore_atlas_lateral_genes,
ignore_query_lateral_genes: bool = pr.ignore_query_lateral_genes,
consider_atlas_noisy_genes: bool = pr.consider_atlas_noisy_genes,
consider_query_noisy_genes: bool = pr.consider_query_noisy_genes,
misfit_min_metacells_fraction: float = pr.misfit_min_metacells_fraction,
project_log_data: bool = pr.project_log_data,
project_fold_regularization: float = pr.project_fold_regularization,
project_candidates_count: int = pr.project_candidates_count,
project_min_candidates_fraction: float = pr.project_min_candidates_fraction,
project_min_significant_gene_umis: int = pr.project_min_significant_gene_umis,
project_min_usage_weight: float = pr.project_min_usage_weight,
project_filter_ranges: bool = pr.project_filter_ranges,
project_ignore_range_quantile: float = pr.project_ignore_range_quantile,
project_ignore_range_min_overlap_fraction: float = pr.project_ignore_range_min_overlap_fraction,
project_min_query_markers_fraction: float = pr.project_min_query_markers_fraction,
project_max_consistency_fold_factor: float = pr.project_max_consistency_fold_factor,
project_max_projection_fold_factor: float = pr.project_max_projection_fold_factor,
project_max_projection_noisy_fold_factor: float = pr.project_max_projection_noisy_fold_factor,
project_max_misfit_genes: int = pr.project_max_misfit_genes,
project_min_essential_genes_fraction: Optional[float] = pr.project_min_essential_genes_fraction,
atlas_type_property_name: str = "type",
project_corrections: bool = pr.project_corrections,
project_min_corrected_gene_correlation: float = pr.project_min_corrected_gene_correlation,
project_min_corrected_gene_factor: float = pr.project_min_corrected_gene_factor,
reproducible: bool,
top_level_parallel: bool = True,
) -> ut.CompressedMatrix:
"""
Complete pipeline for projecting query metacells onto an atlas of metacells for the ``what`` (default: {what}) data.
**Input**
Annotated query ``qdata`` and atlas ``adata``, where the observations are cells and the variables are genes, where
``what`` is a per-variable-per-observation matrix or the name of a per-variable-per-observation annotation
containing such a matrix, containing the fraction of each gene in each cell. The atlas should also contain a
``type`` per-observation (metacell) annotation.
The ``qdata`` may include per-gene masks, ``ignored_gene`` and ``ignored_gene_of_<type>``, which force the code
below to ignore the marked genes from either the preliminary projection or the following refined type-specific
projections.
**Returns**
A matrix whose rows are query metacells and columns are atlas metacells, where each entry is the weight of the atlas
metacell in the projection of the query metacells. The sum of weights in each row (that is, for a single query
metacell) is 1. The weighted sum of the atlas metacells using these weights is the "projected" image of the query
metacell onto the atlas.
Variable (Gene) Annotations
``atlas_gene``
A mask of the query genes that also exist in the atlas. We match genes by their name; if projecting query
data from a different technology, we expect the caller to modify the query gene names to match the atlas
before projecting it.
``atlas_lateral_gene``, ``atlas_noisy_gene``, ``atlas_marker_gene``, ``essential_gene_of_<type>``
Copied from the atlas to the query (``False`` for non-``atlas_gene``).
``projected_noisy_gene``
The mask of the genes that were considered "noisy" when computing the projection. By default this is the
union of the noisy atlas and query genes.
``correction_factor`` (if ``project_corrections``)
If projecting a query on an atlas with different technologies (e.g., 10X v3 to 10X v2), an automatically
computed factor we multiplied the query gene fractions by to compensate for the systematic difference
between the technologies (1.0 for uncorrected genes and 0.0 for non-``atlas_gene``).
``fitted_gene_of_<type>``
For each type, the genes that were projected well from the query to the atlas for most cells of that
type; any ``atlas_gene`` outside this mask failed to project well from the query to the atlas for most
metacells of this type. For non-``atlas_gene`` this is set to ``False``.
``misfit_gene_of_<type>``
For each query metacell type, a boolean mask indicating whether the gene has a strong bias in the query
metacells of this type compared to the atlas metacells of this type.
``ignored_gene_of_<type>``
For each query metacell type, a boolean mask indicating whether the gene was ignored by the projection (for
any reason) when computing the projection for metacells of this type.
Observation (Cell) Annotations
``total_atlas_umis``
The total UMIs of the ``atlas_gene`` in each query metacell. This is used in the analysis as described for
``total_umis`` above, that is, to ensure comparing expression levels will ignore cases where the total
number of UMIs of both compared gene profiles is too low to make a reliable determination. In such cases we
take the fold factor to be 0.
``projected_type``
For each query metacell, the best atlas ``type`` we can assign to it based on its projection. Note this does
not indicate that the query metacell is "truly" of this type; to make this determination one needs to look
at the quality control data below.
``projected_secondary_type``
In some cases, a query metacell may fail to project well to a single region of the atlas, but does project
well to a combination of two distinct atlas regions. This may be due to the query metacell containing
doublets, of a mixture of cells which match different atlas regions (e.g. due to sparsity of data in the
query data set). Either way, if this happens, we place here the type that best describes the secondary
region the query metacell was projected to; otherwise this would be the empty string. Note that the
``weights`` matrix above does not distinguish between the regions.
``projected_correlation`` per query metacell
The correlation between between the ``corrected_fraction`` and the ``projected_fraction`` for only the
``fitted_gene`` expression levels of each query metacell. This serves as a very rough estimator for the
quality of the projection for this query metacell (e.g. can be used to compute R^2 values).
In general we expect high correlation (more than 0.9 in most metacells) since we restricted the
``fitted_gene`` mask only to genes we projected well.
``similar`` mask per query metacell
A conservative determination of whether the query metacell is "similar" to its projection on the atlas. This
is based on whether the number of ``misfit_gene`` for the query metacell is low enough (by default, up to 3
genes), and also that at least 75% of the ``essential_gene`` of the query metacell were not ``misfit_gene``.
Note that this explicitly allows for a ``projected_secondary_type``, that is, a metacell of doublets will be
"similar" to the atlas, but a metacell of a novel state missing from the atlas will be "dissimilar".
The final determination of whether to accept the projection is, as always, up to the analyst, based on prior
biological knowledge, the context of the collection of the query (and atlas) data sets, etc. The analyst
need not (indeed, *should not*) blindly accept the ``similar`` determination without examining the rest of
the quality control data listed above.
Observation-Variable (Cell-Gene) Annotations
``corrected_fraction`` per gene per query metacell
For each ``atlas_gene``, its fraction in each query metacell, out of only the atlas genes. This may be
further corrected (see below) if projecting between different scRNA-seq technologies (e.g. 10X v2 and 10X
v3). For non-``atlas_gene`` this is 0.
``projected_fraction`` per gene per query metacell
For each ``atlas_gene``, its fraction in its projection on the atlas. This projection is computed as a
weighted average of some atlas metacells (see below), which are all sufficiently close to each other (in
terms of gene expression), so averaging them is reasonable to capture the fact the query metacell may be
along some position on some gradient that isn't an exact match for any specific atlas metacell. For
non-``atlas_gene`` this is 0.
``fitted`` mask per gene per query metacell
For each ``atlas_gene`` for each query metacell, whether the gene was expected to be projected well, based
on the query metacell ``projected_type`` (and the ``projected_secondary_type``, if any). For
non-``atlas_gene`` this is set to ``False``. This does not guarantee the gene was actually projected well.
``misfit``
For each ``atlas_gene`` for each query metacell, whether the ``corrected_fraction`` of the gene was
significantly different from the ``projected_fractions`` (that is, whether the gene was not projected well
for this metacell). For non-``atlas_gene`` this is set to ``False``, to make it easier to identify
problematic genes.
This is expected to be rare for ``fitted_gene`` and common for the rest of the ``atlas_gene``. If too many
``fitted_gene`` are also ``misfit_gene``, then one should be suspicious whether the query metacell is
"truly" of the ``projected_type``.
``essential``
Which of the ``atlas_gene`` were also listed in the ``essential_gene_of_<type>`` for the ``projected_type``
(and also the ``projected_secondary_type``, if any) of each query metacell.
If an ``essential_gene`` is also a ``misfit_gene``, then one should be very suspicious whether the query
metacell is "truly" of the ``projected_type``.
``projected_fold`` per gene per query metacell
The fold factor between the ``corrected_fraction`` and the ``projected_fraction`` (0 for
non-``atlas_gene``). If the absolute value of this is high (3 for 8x ratio) then the gene was not projected
well for this metacell. This will be 0 for non-``atlas_gene``.
It is expected this would have low values for most ``fitted_gene`` and high values for the rest of the
``atlas_gene``, but specific values will vary from one query metacell to another. This allows the analyst to
make fine-grained determination about the quality of the projection, and/or identify quantitative
differences between the query and the atlas (e.g., when studying perturbed systems such as knockouts or
disease models).
**Computation Parameters**
0. Find the subset of genes that exist in both the query and the atlas. All computations will be done on this common
subset. Normalize the fractions of the fitted gene fractions to sum to 1 in each metacell.
Compute preliminary projection:
1. Compute a mask of fitted genes, ignoring any genes included in ``ignored_gene``. If ``only_atlas_marker_genes``
(default: ``only_atlas_marker_genes``), ignore any non-``marker_gene`` of the atlas. If
``only_query_marker_genes`` (default: ``only_query_marker_genes``), ignore any non-``marker_gene`` of the query.
If ``ignore_atlas_lateral_genes`` (default: {ignore_atlas_lateral_genes}), ignore the ``lateral_gene`` of the
atlas. If ``ignore_query_lateral_genes`` (default: {ignore_query_lateral_genes}), ignore the ``lateral_gene`` of
the atlas. Normalize the fractions of the fitted gene fractions to sum to 1 in each metacell.
2. Invoke :py:func:`metacells.tools.project.compute_projection_weights`
to project each query metacell onto the atlas, using the ``project_log_data`` (default: {project_log_data}),
``project_fold_regularization`` (default: {project_fold_regularization}), ``project_min_significant_gene_umis``
(default: {project_min_significant_gene_umis}), ``project_max_consistency_fold_factor`` (default:
{project_max_consistency_fold_factor}), ``project_candidates_count`` (default: {project_candidates_count}),
``project_min_usage_weight`` (default: {project_min_usage_weight}), and ``reproducible``.
3. If ``project_corrections`` (default: {project_corrections}: Correlate the expression levels of each gene between
the query and projection. If this is at least ``project_min_corrected_gene_correlation`` (default:
{project_min_corrected_gene_correlation}), compute the ratio between the mean expression of the gene in the
projection and the query. If this is at most 1/(1+``project_min_corrected_gene_factor``) or at least
(1+``project_min_corrected_gene_factor``) (default: {project_min_corrected_gene_factor}), then multiply the
gene's value by this factor so its level would match the atlas. As usual, ignore genes which do not have at least
``project_min_significant_gene_umis``. If any genes were corrected, then repeat steps 2-3 (but do these steps no
more than 3 times).
4. If ``project_filter_ranges`` (default: {project_filter_ranges}): Compute for each gene its expression range
(lowest and highest ``project_ignore_range_quantile`` (default: {project_ignore_range_quantile}) in both the
projected and the corrected values. Compute the overlap between these ranges (shared range divided by the query
range). If this is less than ``project_ignore_range_min_overlap_fraction`` (default:
{project_ignore_range_min_overlap_fraction}), then ignore the gene. If any genes were ignored, repeat steps 2-4
(but do this step no more than 3 times).
5. Invoke :py:func:`metacells.tools.project.convey_atlas_to_query` to assign a projected type to each of the
query metacells based on the ``atlas_type_property_name`` (default: {atlas_type_property_name}).
Then, for each type of query metacells:
If ``top_level_parallel`` (default: ``top_level_parallel``), do this in parallel. This seems to work better than
doing this serially and using parallelism within each type. However this is still inefficient compared to using both
types of parallelism at once, which the code currently can't do without non-trivial coding (this would have been
trivial in Julia...).
6. Further reduce the mask of type-specific fitted genes by ignoring any genes in ``ignored_gene_of_<type>``, if
this annotation exists in the query. Normalize the sum of the fitted gene fractions to 1 in each metacell.
7. Invoke :py:func:`metacells.tools.project.compute_projection_weights` to project each query metacell of the type
onto the atlas. Note that even though we only look at query metacells (tentatively) assigned the type, their
projection on the atlas may use metacells of any type.
8. Invoke :py:func:`metacells.tools.quality.compute_projected_folds` to compute the significant fold factors
between the query and its projection.
9. Identify type-specific misfit genes whose fold factor is above ``project_max_projection_fold_factor``. If
``consider_atlas_noisy_genes`` and/or ``consider_query_noisy_genes``, then any gene listed in either is allowed
an additional ``project_max_projection_noisy_fold_factor``. If any gene has such a high fold factor in at least
``misfit_min_metacells_fraction``, remove it from the fitted genes mask and repeat steps 6-9.
10. Invoke :py:func:`metacells.tools.quality.compute_similar_query_metacells` to verify which query metacells
ended up being sufficiently similar to their projection, using ``project_max_consistency_fold_factor`` (default:
{project_max_consistency_fold_factor}), ``project_max_projection_noisy_fold_factor`` (default:
{project_max_projection_noisy_fold_factor}), ``project_max_misfit_genes`` (default: {project_max_misfit_genes}),
and if needed, the the ``essential_gene_of_<type>``. Also compute the correlation between the corrected
and projected gene fractions for each metacell.
And then:
11. Invoke :py:func:`metacells.tools.project.convey_atlas_to_query` to assign an updated projected type to each of
the query metacells based on the ``atlas_type_property_name`` (default: {atlas_type_property_name}). If this
changed the type assigned to any query metacell, repeat steps 6-11 (but do this step no more than 3 times).
For each query metacell that ended up being dissimilar, try to project them as a combination of two atlas regions:
12. Reduce the list of fitted genes for the metacell based on the ``ignored_gene_of_<type>`` for both the primary
(initially from the above) query metacell type and the secondary query metacell type (initially empty).
Normalize the sum of the gene fractions in the metacell to 1.
13. Invoke :py:func:`metacells.tools.project.compute_projection_weights` just for this metacell, allowing the
projection to use a secondary location in the atlas based on the residuals of the atlas metacells relative to
the primary query projection.
14. Invoke :py:func:`metacells.tools.project.convey_atlas_to_query` twice, once for the weights of the primary
location and once for the weights of the secondary location, to obtain a primary and secondary type for
the query metacell. If these have changed, repeat steps 13-14 (but do these steps no more than 3 times; note
will will always do them twice as the 1st run will generate some non-empty secondary type).
15. Invoke :py:func:`metacells.tools.quality.compute_projected_folds` and
:py:func:`metacells.tools.quality.compute_similar_query_metacells` to update the projection and evaluation of
the query metacell. If it is now similar, then use the results for the metacell; otherwise, keep the original
results as they were at the end of step 10.
"""
assert project_min_corrected_gene_factor >= 0
use_essential_genes = project_min_essential_genes_fraction is not None and _has_any_essential_genes(adata)
ut.set_m_data(qdata, "project_max_projection_fold_factor", project_max_projection_fold_factor)
ut.set_m_data(qdata, "project_max_projection_noisy_fold_factor", project_max_projection_noisy_fold_factor)
ut.set_m_data(qdata, "project_max_misfit_genes", project_max_misfit_genes)
atlas_common_gene_indices, query_common_gene_indices = _common_gene_indices(adata, qdata)
type_names = _initialize_atlas_data_in_query(
adata,
qdata,
atlas_type_property_name=atlas_type_property_name,
atlas_common_gene_indices=atlas_common_gene_indices,
query_common_gene_indices=query_common_gene_indices,
consider_atlas_noisy_genes=consider_atlas_noisy_genes,
consider_query_noisy_genes=consider_query_noisy_genes,
use_essential_genes=use_essential_genes,
)
common_adata = _initialize_common_adata(adata, what, atlas_common_gene_indices)
common_qdata = _initialize_common_qdata(qdata, what, query_common_gene_indices)
query_marker_genes_mask = ut.get_v_numpy(common_qdata, "marker_gene")
min_fitted_query_marker_genes = np.sum(query_marker_genes_mask) * project_min_query_markers_fraction
ut.log_calc("min_fitted_query_marker_genes", min_fitted_query_marker_genes)
min_essential_genes_of_type = _min_essential_genes_of_type(
adata=adata,
common_adata=common_adata,
type_names=type_names,
project_min_essential_genes_fraction=project_min_essential_genes_fraction,
)
preliminary_fitted_genes_mask = _preliminary_fitted_genes_mask(
common_adata=common_adata,
common_qdata=common_qdata,
only_atlas_marker_genes=only_atlas_marker_genes,
only_query_marker_genes=only_query_marker_genes,
ignore_atlas_lateral_genes=ignore_atlas_lateral_genes,
ignore_query_lateral_genes=ignore_query_lateral_genes,
min_significant_gene_umis=project_min_significant_gene_umis,
)
_compute_preliminary_projection(
common_adata=common_adata,
common_qdata=common_qdata,
preliminary_fitted_genes_mask=preliminary_fitted_genes_mask,
project_log_data=project_log_data,
project_fold_regularization=project_fold_regularization,
project_min_significant_gene_umis=project_min_significant_gene_umis,
project_filter_ranges=project_filter_ranges,
project_ignore_range_quantile=project_ignore_range_quantile,
project_ignore_range_min_overlap_fraction=project_ignore_range_min_overlap_fraction,
project_max_consistency_fold_factor=project_max_consistency_fold_factor,
project_candidates_count=project_candidates_count,
project_min_usage_weight=project_min_usage_weight,
project_corrections=project_corrections,
project_min_corrected_gene_correlation=project_min_corrected_gene_correlation,
project_min_corrected_gene_factor=project_min_corrected_gene_factor,
atlas_type_property_name=atlas_type_property_name,
reproducible=reproducible,
)
weights = _compute_per_type_projection(
common_adata=common_adata,
common_qdata=common_qdata,
type_names=type_names,
preliminary_fitted_genes_mask=preliminary_fitted_genes_mask,
misfit_min_metacells_fraction=misfit_min_metacells_fraction,
project_log_data=project_log_data,
project_fold_regularization=project_fold_regularization,
project_candidates_count=project_candidates_count,
project_min_significant_gene_umis=project_min_significant_gene_umis,
project_min_usage_weight=project_min_usage_weight,
project_max_consistency_fold_factor=project_max_consistency_fold_factor,
project_max_projection_fold_factor=project_max_projection_fold_factor,
project_max_projection_noisy_fold_factor=project_max_projection_noisy_fold_factor,
project_max_misfit_genes=project_max_misfit_genes,
min_fitted_query_marker_genes=min_fitted_query_marker_genes,
min_essential_genes_of_type=min_essential_genes_of_type,
atlas_type_property_name=atlas_type_property_name,
top_level_parallel=top_level_parallel,
reproducible=reproducible,
)
_compute_dissimilar_residuals_projection(
weights_per_atlas_per_query_metacell=weights,
common_adata=common_adata,
common_qdata=common_qdata,
project_log_data=project_log_data,
project_fold_regularization=project_fold_regularization,
project_candidates_count=project_candidates_count,
project_min_candidates_fraction=project_min_candidates_fraction,
project_min_significant_gene_umis=project_min_significant_gene_umis,
project_min_usage_weight=project_min_usage_weight,
project_max_consistency_fold_factor=project_max_consistency_fold_factor,
project_max_projection_fold_factor=project_max_projection_fold_factor,
project_max_projection_noisy_fold_factor=project_max_projection_noisy_fold_factor,
project_max_misfit_genes=project_max_misfit_genes,
min_fitted_query_marker_genes=min_fitted_query_marker_genes,
min_essential_genes_of_type=min_essential_genes_of_type,
atlas_type_property_name=atlas_type_property_name,
top_level_parallel=top_level_parallel,
reproducible=reproducible,
)
tl.compute_projected_fractions(
adata=common_adata,
qdata=common_qdata,
log_data=project_log_data,
fold_regularization=project_fold_regularization,
weights=weights,
)
_common_data_to_full(
qdata=qdata,
common_qdata=common_qdata,
project_corrections=project_corrections,
type_names=type_names,
use_essential_genes=use_essential_genes,
)
ut.set_m_data(qdata, "projection_algorithm", f"metacells.{__version__}")
return sp.csr_matrix(weights)
def _common_gene_indices(adata: AnnData, qdata: AnnData) -> Tuple[ut.NumpyVector, ut.NumpyVector]:
if list(qdata.var_names) == list(adata.var_names):
atlas_common_gene_indices = query_common_gene_indices = np.array(range(qdata.n_vars), dtype="int32")
else:
atlas_genes_list = list(adata.var_names)
query_genes_list = list(qdata.var_names)
common_genes_list = list(sorted(set(atlas_genes_list) & set(query_genes_list)))
assert len(common_genes_list) > 0
atlas_common_gene_indices = np.array([atlas_genes_list.index(gene) for gene in common_genes_list])
query_common_gene_indices = np.array([query_genes_list.index(gene) for gene in common_genes_list])
return atlas_common_gene_indices, query_common_gene_indices
def _has_any_essential_genes(adata: AnnData) -> bool:
for property_name in adata.var.keys():
if property_name.startswith("essential_gene_of_"):
return True
return False
def _initialize_atlas_data_in_query(
adata: AnnData,
qdata: AnnData,
*,
atlas_type_property_name: str,
atlas_common_gene_indices: ut.NumpyVector,
query_common_gene_indices: ut.NumpyVector,
consider_atlas_noisy_genes: bool,
consider_query_noisy_genes: bool,
use_essential_genes: bool,
) -> List[str]:
atlas_genes_mask = np.zeros(qdata.n_vars, dtype="bool")
atlas_genes_mask[query_common_gene_indices] = True
ut.set_v_data(qdata, "atlas_gene", atlas_genes_mask)
type_names = list(np.unique(ut.get_o_numpy(adata, atlas_type_property_name)))
genes_mask_names = ["lateral_gene", "noisy_gene", "marker_gene"]
if consider_atlas_noisy_genes:
genes_mask_names.append("noisy_gene")
if use_essential_genes:
genes_mask_names += [f"essential_gene_of_{type_name}" for type_name in type_names]
for genes_mask_name in genes_mask_names:
if not ut.has_data(adata, genes_mask_name):
continue
atlas_mask = ut.get_v_numpy(adata, genes_mask_name)
query_mask = np.zeros(qdata.n_vars, dtype="bool")
query_mask[query_common_gene_indices] = atlas_mask[atlas_common_gene_indices]
ut.set_v_data(
qdata,
genes_mask_name if genes_mask_name.startswith("essential_gene_of_") else "atlas_" + genes_mask_name,
query_mask,
)
noisy_masks: List[str] = []
if consider_atlas_noisy_genes and ut.has_data(qdata, "atlas_noisy_gene"):
noisy_masks.append("|atlas_noisy_gene")
if consider_query_noisy_genes and ut.has_data(qdata, "noisy_gene"):
noisy_masks.append("|noisy_gene")
if len(noisy_masks) > 0:
tl.combine_masks(qdata, noisy_masks, to="projected_noisy_gene")
return type_names
def _initialize_common_adata(adata: AnnData, what: str, atlas_common_gene_indices: ut.NumpyVector) -> AnnData:
common_adata = ut.slice(adata, name=".common", vars=atlas_common_gene_indices, top_level=False)
atlas_total_common_umis = ut.get_o_numpy(common_adata, "total_umis", sum=True)
ut.set_o_data(common_adata, "total_umis", atlas_total_common_umis)
_normalize_corrected_fractions(common_adata, what)
return common_adata
def _initialize_common_qdata(
qdata: AnnData,
what: str,
query_common_gene_indices: ut.NumpyVector,
) -> AnnData:
common_qdata = ut.slice(
qdata,
name=".common",
vars=query_common_gene_indices,
track_var="full_gene_index_of_qdata",
top_level=False,
)
query_total_common_umis = ut.get_o_numpy(common_qdata, "total_umis", sum=True)
ut.set_o_data(common_qdata, "total_umis", query_total_common_umis)
ut.set_o_data(qdata, "total_atlas_umis", query_total_common_umis)
_normalize_corrected_fractions(common_qdata, what)
return common_qdata
def _min_essential_genes_of_type(
*,
adata: AnnData,
common_adata: AnnData,
type_names: List[str],
project_min_essential_genes_fraction: Optional[float],
) -> Dict[Tuple[str, str], Optional[int]]:
min_essential_genes_of_type: Dict[Tuple[str, str], Optional[int]] = {
(type_name, other_type_name): None for type_name in type_names for other_type_name in type_names
}
has_any_essential_genes = False
if project_min_essential_genes_fraction is not None:
for type_name in type_names:
if type_name != "Outliers" and ut.has_data(adata, f"essential_gene_of_{type_name}"):
has_any_essential_genes = True
break
if not has_any_essential_genes:
return min_essential_genes_of_type
assert project_min_essential_genes_fraction is not None
for type_name in type_names:
if type_name == "Outliers":
min_essential_genes_count: Optional[int] = None
else:
atlas_essential_genes_mask = ut.get_v_numpy(adata, f"essential_gene_of_{type_name}")
atlas_essential_genes_count = np.sum(atlas_essential_genes_mask)
common_essential_genes_mask = ut.get_v_numpy(common_adata, f"essential_gene_of_{type_name}")
common_essential_genes_count = np.sum(common_essential_genes_mask)
min_essential_genes_count = ceil(project_min_essential_genes_fraction * atlas_essential_genes_count)
assert min_essential_genes_count is not None
assert min_essential_genes_count >= 0
if min_essential_genes_count > common_essential_genes_count:
atlas_essential_genes_names = sorted(adata.var_names[atlas_essential_genes_mask])
common_essential_genes_names = sorted(common_adata.var_names[common_essential_genes_mask])
missing_essential_genes_names = sorted(
set(atlas_essential_genes_names) - set(common_essential_genes_names)
)
ut.logger().warning( # pylint: disable=logging-fstring-interpolation
f"the {common_essential_genes_count} "
f"common essential gene(s) {', '.join(common_essential_genes_names)} "
f"for the type {type_name} "
f"are not enough for the required {min_essential_genes_count}; "
"reducing the minimal requirement "
f" (the non-common essential gene(s) are: {', '.join(missing_essential_genes_names)})"
)
min_essential_genes_count = common_essential_genes_count
min_essential_genes_of_type[(type_name, type_name)] = min_essential_genes_count
ut.log_calc(f"min_essential_genes_of_type[{type_name}]", min_essential_genes_count)
if min_essential_genes_count is None:
continue
for other_type_name in type_names:
if other_type_name <= type_name or other_type_name == "Outliers":
continue
other_atlas_essential_genes_mask = ut.get_v_numpy(adata, f"essential_gene_of_{type_name}")
pair_atlas_essential_genes_mask = atlas_essential_genes_mask | other_atlas_essential_genes_mask
pair_atlas_essential_genes_count = np.sum(pair_atlas_essential_genes_mask)
other_common_essential_genes_mask = ut.get_v_numpy(common_adata, f"essential_gene_of_{type_name}")
pair_common_essential_genes_mask = common_essential_genes_mask | other_common_essential_genes_mask
pair_common_essential_genes_count = np.sum(pair_common_essential_genes_mask)
missing_essential_genes_count = pair_atlas_essential_genes_count - pair_common_essential_genes_count
assert missing_essential_genes_count >= 0
min_essential_genes_count = ceil(project_min_essential_genes_fraction * pair_atlas_essential_genes_count)
assert min_essential_genes_count is not None
ut.log_calc(f"min_essential_genes_of_type[{type_name}, {other_type_name}]", min_essential_genes_count)
min_essential_genes_of_type[(type_name, other_type_name)] = min_essential_genes_count
min_essential_genes_of_type[(other_type_name, type_name)] = min_essential_genes_count
return min_essential_genes_of_type
def _preliminary_fitted_genes_mask(
*,
common_adata: AnnData,
common_qdata: AnnData,
only_atlas_marker_genes: bool,
only_query_marker_genes: bool,
ignore_atlas_lateral_genes: bool,
ignore_query_lateral_genes: bool,
min_significant_gene_umis: int,
) -> ut.NumpyVector:
atlas_total_umis_per_gene = ut.get_v_numpy(common_adata, "total_umis", sum=True)
query_total_umis_per_gene = ut.get_v_numpy(common_qdata, "total_umis", sum=True)
total_umis_per_gene = atlas_total_umis_per_gene + query_total_umis_per_gene
preliminary_fitted_genes_mask = total_umis_per_gene >= min_significant_gene_umis
ut.log_calc("total_umis_mask", preliminary_fitted_genes_mask)
if only_atlas_marker_genes and ut.has_data(common_adata, "marker_gene"):
preliminary_fitted_genes_mask &= ut.get_v_numpy(common_adata, "marker_gene")
if only_query_marker_genes and ut.has_data(common_qdata, "marker_gene"):
preliminary_fitted_genes_mask &= ut.get_v_numpy(common_qdata, "marker_gene")
if ignore_atlas_lateral_genes and ut.has_data(common_adata, "lateral_gene"):
preliminary_fitted_genes_mask &= ~ut.get_v_numpy(common_adata, "lateral_gene")
if ignore_query_lateral_genes and ut.has_data(common_qdata, "lateral_gene"):
preliminary_fitted_genes_mask &= ~ut.get_v_numpy(common_qdata, "lateral_gene")
if ut.has_data(common_qdata, "ignored_gene"):
preliminary_fitted_genes_mask &= ~ut.get_v_numpy(common_qdata, "ignored_gene")
return preliminary_fitted_genes_mask
def _common_data_to_full(
*,
qdata: AnnData,
common_qdata: AnnData,
project_corrections: bool,
use_essential_genes: bool,
type_names: List[str],
) -> None:
if use_essential_genes:
primary_type_per_metacell = ut.get_o_numpy(common_qdata, "projected_type")
secondary_type_per_metacell = ut.get_o_numpy(common_qdata, "projected_secondary_type")
essential_per_gene_per_metacell = np.zeros(qdata.shape, dtype="bool")
essential_gene_per_type = {
type_name: ut.get_v_numpy(qdata, f"essential_gene_of_{type_name}")
for type_name in type_names
if type_name != "Outliers"
}
for metacell_index in range(qdata.n_obs):
primary_type_of_metacell = primary_type_per_metacell[metacell_index]
if primary_type_of_metacell != "Outliers":
essential_per_gene_per_metacell[metacell_index, :] = essential_gene_per_type[primary_type_of_metacell]
secondary_type_of_metacell = secondary_type_per_metacell[metacell_index]
if secondary_type_of_metacell not in ("", "Outliers", primary_type_of_metacell):
essential_per_gene_per_metacell[metacell_index, :] |= essential_gene_per_type[
secondary_type_of_metacell
]
ut.set_vo_data(qdata, "essential", sp.csr_matrix(essential_per_gene_per_metacell))
full_gene_index_of_common_qdata = ut.get_v_numpy(common_qdata, "full_gene_index_of_qdata")
for property_name in (
"corrected_fraction",
"projected_fraction",
"fitted",
"misfit",
"projected_fold",
):
data_per_common_gene_per_metacell = ut.get_vo_proper(common_qdata, property_name)
data_per_gene_per_metacell = np.zeros(qdata.shape, dtype=ut.shaped_dtype(data_per_common_gene_per_metacell))
data_per_gene_per_metacell[:, full_gene_index_of_common_qdata] = ut.to_numpy_matrix(
data_per_common_gene_per_metacell
)
ut.set_vo_data(qdata, property_name, sp.csr_matrix(data_per_gene_per_metacell))
property_names = [f"fitted_gene_of_{type_name}" for type_name in type_names]
if project_corrections:
property_names.append("correction_factor")
for property_name in property_names:
data_per_common_gene = ut.get_v_numpy(common_qdata, property_name)
data_per_full_gene = np.zeros(qdata.n_vars, dtype=data_per_common_gene.dtype)
data_per_full_gene[full_gene_index_of_common_qdata] = data_per_common_gene
ut.set_v_data(qdata, property_name, data_per_full_gene)
for property_name, formatter in (
("projected_type", None),
("projected_secondary_type", None),
("projected_correlation", ut.sizes_description),
("similar", None),
):
data_per_metacell = ut.get_o_numpy(common_qdata, property_name, formatter=formatter)
ut.set_o_data(qdata, property_name, data_per_metacell, formatter=formatter)
@ut.logged()
@ut.timed_call()
def _compute_preliminary_projection(
*,
common_adata: AnnData,
common_qdata: AnnData,
preliminary_fitted_genes_mask: ut.NumpyVector,
project_log_data: bool,
project_fold_regularization: float,
project_min_significant_gene_umis: float,
project_filter_ranges: bool,
project_ignore_range_quantile: float,
project_ignore_range_min_overlap_fraction: float,
project_max_consistency_fold_factor: float,
project_candidates_count: int,
project_min_usage_weight: float,
project_corrections: bool,
project_min_corrected_gene_correlation: float,
project_min_corrected_gene_factor: float,
atlas_type_property_name: str,
reproducible: bool,
) -> None:
correction_factor_per_gene = np.full(common_qdata.n_vars, 1.0, dtype="float32")
repeat = 0
while True:
repeat += 1
ut.log_calc("preliminary repeat", repeat)
fitted_adata, weights = _compute_correction_factors(
common_adata=common_adata,
common_qdata=common_qdata,
correction_factor_per_gene=correction_factor_per_gene,
preliminary_fitted_genes_mask=preliminary_fitted_genes_mask,
project_log_data=project_log_data,
project_fold_regularization=project_fold_regularization,
project_min_significant_gene_umis=project_min_significant_gene_umis,
project_max_consistency_fold_factor=project_max_consistency_fold_factor,
project_candidates_count=project_candidates_count,
project_min_usage_weight=project_min_usage_weight,
project_corrections=project_corrections,
project_min_corrected_gene_correlation=project_min_corrected_gene_correlation,
project_min_corrected_gene_factor=project_min_corrected_gene_factor,
reproducible=reproducible,
)
if (
repeat > 2
or not project_filter_ranges
or not _filter_range_genes(
common_qdata=common_qdata,
project_fold_regularization=project_fold_regularization,
project_ignore_range_quantile=project_ignore_range_quantile,
project_ignore_range_min_overlap_fraction=project_ignore_range_min_overlap_fraction,
preliminary_fitted_genes_mask=preliminary_fitted_genes_mask,
)
):
ut.log_calc("preliminary last repeat", repeat)
break
if project_corrections:
ut.set_v_data(common_qdata, "correction_factor", correction_factor_per_gene)
tl.convey_atlas_to_query(
adata=fitted_adata,
qdata=common_qdata,
weights=weights,
property_name=atlas_type_property_name,
to_property_name="projected_type",
)
@ut.logged()
@ut.timed_call()
def _compute_correction_factors(
*,
common_adata: AnnData,
common_qdata: AnnData,
correction_factor_per_gene: ut.NumpyVector,
preliminary_fitted_genes_mask: ut.NumpyVector,
project_log_data: bool,
project_fold_regularization: float,
project_min_significant_gene_umis: float,
project_max_consistency_fold_factor: float,
project_candidates_count: int,
project_min_usage_weight: float,
project_corrections: bool,
project_min_corrected_gene_correlation: float,
project_min_corrected_gene_factor: float,
reproducible: bool,
) -> Tuple[AnnData, ut.ProperMatrix]:
repeat = 0
while True:
repeat += 1
ut.log_calc("corrections repeat", repeat)
fitted_adata = ut.slice(common_adata, name=".fitted", vars=preliminary_fitted_genes_mask, top_level=False)
fitted_qdata = ut.slice(
common_qdata,
name=".fitted",
vars=preliminary_fitted_genes_mask,
track_var="common_gene_index_of_qdata",
top_level=False,
)
weights = tl.compute_projection_weights(
adata=fitted_adata,
qdata=fitted_qdata,
log_data=project_log_data,
fold_regularization=project_fold_regularization,
min_significant_gene_umis=project_min_significant_gene_umis,
max_consistency_fold_factor=project_max_consistency_fold_factor,
candidates_count=project_candidates_count,
min_usage_weight=project_min_usage_weight,
reproducible=reproducible,
)
tl.compute_projected_fractions(
adata=common_adata,
qdata=common_qdata,
log_data=project_log_data,
fold_regularization=project_fold_regularization,
weights=weights,
)
if (
repeat > 2
or not project_corrections
or not _correct_correlated_genes(
common_adata=common_adata,
common_qdata=common_qdata,
preliminary_fitted_genes_mask=preliminary_fitted_genes_mask,
project_min_corrected_gene_correlation=project_min_corrected_gene_correlation,
project_min_corrected_gene_factor=project_min_corrected_gene_factor,
correction_factor_per_gene=correction_factor_per_gene,
reproducible=reproducible,
)
):
ut.log_calc("corrections last repeat", repeat)
return fitted_adata, weights
def _correct_correlated_genes(
*,
common_adata: AnnData,
common_qdata: AnnData,
preliminary_fitted_genes_mask: ut.NumpyVector,
project_min_corrected_gene_correlation: float,
project_min_corrected_gene_factor: float,
correction_factor_per_gene: ut.NumpyVector,
reproducible: bool,
) -> bool:
corrected_fractions_per_gene_per_metacell = ut.to_numpy_matrix(
ut.get_vo_proper(common_qdata, "corrected_fraction", layout="column_major")
)
projected_fractions_per_gene_per_metacell = ut.to_numpy_matrix(
ut.get_vo_proper(common_qdata, "projected_fraction", layout="column_major")
)
preliminary_fitted_genes_indices = np.where(preliminary_fitted_genes_mask)[0]
corrected_fractions_per_fitted_gene_per_metacell = corrected_fractions_per_gene_per_metacell[
:, preliminary_fitted_genes_indices
]
projected_fractions_per_fitted_gene_per_metacell = projected_fractions_per_gene_per_metacell[
:, preliminary_fitted_genes_indices
]
correlation_per_fitted_gene = np.full(len(preliminary_fitted_genes_indices), -2, dtype="float32")
correlation_per_fitted_gene = ut.pairs_corrcoef_rows(
corrected_fractions_per_fitted_gene_per_metacell.transpose(),
projected_fractions_per_fitted_gene_per_metacell.transpose(),
reproducible=reproducible,
)
correlated_fitted_genes_mask = correlation_per_fitted_gene >= project_min_corrected_gene_correlation
ut.log_calc("correlated_fitted_genes_mask", correlated_fitted_genes_mask)
if not np.any(correlated_fitted_genes_mask):
return False
total_corrected_fractions_per_fitted_gene = ut.sum_per(
corrected_fractions_per_fitted_gene_per_metacell, per="column"
)
total_projected_fractions_per_fitted_gene = ut.sum_per(
projected_fractions_per_fitted_gene_per_metacell, per="column"
)
zero_fitted_genes_mask = (total_projected_fractions_per_fitted_gene == 0) | (
total_corrected_fractions_per_fitted_gene == 0
)
ut.log_calc("zero_fitted_genes_mask", zero_fitted_genes_mask)
total_corrected_fractions_per_fitted_gene[zero_fitted_genes_mask] = 1.0
total_projected_fractions_per_fitted_gene[zero_fitted_genes_mask] = 1.0
current_correction_factor_per_fitted_gene = (
total_projected_fractions_per_fitted_gene / total_corrected_fractions_per_fitted_gene
)
high_factor = 1 + project_min_corrected_gene_factor
low_factor = 1.0 / high_factor
factor_fitted_genes_mask = (current_correction_factor_per_fitted_gene <= low_factor) | (
current_correction_factor_per_fitted_gene >= high_factor
)
factor_genes_mask = np.zeros(common_adata.n_vars, dtype="bool")
factor_genes_mask[preliminary_fitted_genes_indices] = factor_fitted_genes_mask
ut.log_calc("factor_fitted_genes_mask", factor_fitted_genes_mask)
if not np.any(factor_fitted_genes_mask):
return False
corrected_fitted_genes_mask = correlated_fitted_genes_mask & factor_fitted_genes_mask
ut.log_calc("corrected_fitted_genes_mask", corrected_fitted_genes_mask)
if not np.any(corrected_fitted_genes_mask):
return False
corrected_genes_mask = np.zeros(common_adata.n_vars, dtype="bool")
corrected_genes_mask[preliminary_fitted_genes_indices] = corrected_fitted_genes_mask
ut.log_calc("corrected_genes_mask", corrected_genes_mask)
correction_factor_per_corrected_gene = current_correction_factor_per_fitted_gene[corrected_fitted_genes_mask]
correction_factor_per_gene[corrected_genes_mask] *= correction_factor_per_corrected_gene
corrected_fractions_per_gene_per_metacell[:, corrected_genes_mask] *= correction_factor_per_corrected_gene[
np.newaxis, :
]
corrected_fractions_per_gene_per_metacell = ut.fraction_by( # type: ignore
ut.to_layout(corrected_fractions_per_gene_per_metacell, layout="row_major"), by="row"
)
ut.set_vo_data(common_qdata, "corrected_fraction", sp.csr_matrix(corrected_fractions_per_gene_per_metacell))
return True
@ut.logged()
@ut.timed_call()
def _filter_range_genes(
*,
common_qdata: AnnData,
project_fold_regularization: float,
project_ignore_range_quantile: float,
project_ignore_range_min_overlap_fraction: float,
preliminary_fitted_genes_mask: ut.NumpyVector,
) -> bool:
corrected_fractions_per_gene_per_metacell = ut.to_numpy_matrix(
ut.get_vo_proper(common_qdata, "corrected_fraction", layout="column_major")
)
projected_fractions_per_gene_per_metacell = ut.to_numpy_matrix(
ut.get_vo_proper(common_qdata, "projected_fraction", layout="column_major")
)
preliminary_fitted_genes_indices = np.where(preliminary_fitted_genes_mask)[0]
corrected_fractions_per_fitted_gene_per_metacell = corrected_fractions_per_gene_per_metacell[
:, preliminary_fitted_genes_indices
]
projected_fractions_per_fitted_gene_per_metacell = projected_fractions_per_gene_per_metacell[
:, preliminary_fitted_genes_indices
]
corrected_log_fractions_per_fitted_gene_per_metacell = (
corrected_fractions_per_fitted_gene_per_metacell + project_fold_regularization
)
projected_log_fractions_per_fitted_gene_per_metacell = (
projected_fractions_per_fitted_gene_per_metacell + project_fold_regularization
)
np.log2(
corrected_log_fractions_per_fitted_gene_per_metacell, out=corrected_log_fractions_per_fitted_gene_per_metacell
)
np.log2(
projected_log_fractions_per_fitted_gene_per_metacell, out=projected_log_fractions_per_fitted_gene_per_metacell
)
low_corrected_log_fractions_per_fitted_gene = ut.quantile_per(
corrected_log_fractions_per_fitted_gene_per_metacell, project_ignore_range_quantile, per="column"
)
low_projected_log_fractions_per_fitted_gene = ut.quantile_per(
projected_log_fractions_per_fitted_gene_per_metacell, project_ignore_range_quantile, per="column"
)
high_corrected_log_fractions_per_fitted_gene = ut.quantile_per(
corrected_log_fractions_per_fitted_gene_per_metacell, 1.0 - project_ignore_range_quantile, per="column"
)
high_projected_log_fractions_per_fitted_gene = ut.quantile_per(
projected_log_fractions_per_fitted_gene_per_metacell, 1.0 - project_ignore_range_quantile, per="column"
)
low_common_log_fractions_per_fitted_gene = np.maximum(
low_corrected_log_fractions_per_fitted_gene, low_projected_log_fractions_per_fitted_gene
)
high_common_log_fractions_per_fitted_gene = np.minimum(
high_corrected_log_fractions_per_fitted_gene, high_projected_log_fractions_per_fitted_gene
)
corrected_range_log_fractions_per_fitted_gene = (
high_corrected_log_fractions_per_fitted_gene - low_corrected_log_fractions_per_fitted_gene
)
common_range_log_fractions_per_fitted_gene = (
high_common_log_fractions_per_fitted_gene - low_common_log_fractions_per_fitted_gene
)
corrected_range_log_fractions_per_fitted_gene[corrected_range_log_fractions_per_fitted_gene == 0] = 1
overlap_per_fitted_gene = common_range_log_fractions_per_fitted_gene / corrected_range_log_fractions_per_fitted_gene
ignore_fitted_genes_mask = overlap_per_fitted_gene < project_ignore_range_min_overlap_fraction
ut.log_calc("ignore_fitted_genes_mask", ignore_fitted_genes_mask)
if not np.any(ignore_fitted_genes_mask):
return False
ignore_gene_indices = preliminary_fitted_genes_indices[ignore_fitted_genes_mask]
preliminary_fitted_genes_mask[ignore_gene_indices] = False
ut.log_calc("preliminary_fitted_genes_mask", preliminary_fitted_genes_mask)
return True
@ut.logged()
@ut.timed_call()
def _compute_per_type_projection(
*,
common_adata: AnnData,
common_qdata: AnnData,
type_names: List[str],
preliminary_fitted_genes_mask: ut.NumpyVector,
misfit_min_metacells_fraction: float,
project_log_data: bool,
project_fold_regularization: float,
project_candidates_count: int,
project_min_significant_gene_umis: float,
project_min_usage_weight: float,
project_max_consistency_fold_factor: float,
project_max_projection_fold_factor: float,
project_max_projection_noisy_fold_factor: float,
project_max_misfit_genes: int,
min_fitted_query_marker_genes: float,
min_essential_genes_of_type: Dict[Tuple[str, str], Optional[int]],
atlas_type_property_name: str,
top_level_parallel: bool,
reproducible: bool,
) -> ut.NumpyMatrix:
old_types_per_metacell: List[Set[str]] = []
for _metacell_index in range(common_qdata.n_obs):
old_types_per_metacell.append(set())
fitted_genes_mask_per_type = _initial_fitted_genes_mask_per_type(
common_qdata=common_qdata, type_names=type_names, preliminary_fitted_genes_mask=preliminary_fitted_genes_mask
)
misfit_per_gene_per_metacell = np.empty(common_qdata.shape, dtype="bool")
projected_correlation_per_metacell = np.empty(common_qdata.n_obs, dtype="float32")
projected_fold_per_gene_per_metacell = np.empty(common_qdata.shape, dtype="float32")
weights_per_atlas_per_query_metacell = np.empty((common_qdata.n_obs, common_adata.n_obs), dtype="float32")
similar_per_metacell = np.empty(common_qdata.n_obs, dtype="bool")
repeat = 0
while True:
repeat += 1
ut.log_calc("types repeat", repeat)
type_of_query_metacells = ut.get_o_numpy(common_qdata, "projected_type")
query_type_names = list(np.unique(type_of_query_metacells))
misfit_per_gene_per_metacell[:, :] = False
projected_fold_per_gene_per_metacell[:, :] = 0.0
@ut.timed_call("single_type_projection")
def _single_type_projection(
type_index: int,
) -> Dict[str, Any]:
return _compute_single_type_projection(
type_name=query_type_names[type_index],
common_adata=common_adata,
common_qdata=common_qdata,
fitted_genes_mask_per_type=fitted_genes_mask_per_type,
misfit_min_metacells_fraction=misfit_min_metacells_fraction,
project_log_data=project_log_data,
project_fold_regularization=project_fold_regularization,
project_candidates_count=project_candidates_count,
project_min_significant_gene_umis=project_min_significant_gene_umis,
project_min_usage_weight=project_min_usage_weight,
project_max_consistency_fold_factor=project_max_consistency_fold_factor,
project_max_projection_fold_factor=project_max_projection_fold_factor,
project_max_projection_noisy_fold_factor=project_max_projection_noisy_fold_factor,
project_max_misfit_genes=project_max_misfit_genes,
min_fitted_query_marker_genes=min_fitted_query_marker_genes,
min_essential_genes_of_type=min_essential_genes_of_type,
top_level_parallel=top_level_parallel,
reproducible=reproducible,
)
@ut.logged()
def _collect_single_type_result(
type_index: int,
*,
query_metacell_indices_of_type: ut.NumpyVector,
fitted_genes_indices_of_type: ut.NumpyVector,
similar_per_metacell_of_type: ut.NumpyVector,
misfit_per_gene_per_metacell_of_type: ut.ProperMatrix,
projected_correlation_per_metacell_of_type: ut.NumpyVector,
projected_fold_per_gene_per_metacell_of_type: ut.ProperMatrix,
weights_per_atlas_per_query_metacell_of_type: ut.ProperMatrix,
) -> None:
fitted_genes_mask_per_type[query_type_names[type_index]][:] = False
fitted_genes_mask_per_type[query_type_names[type_index]][fitted_genes_indices_of_type] = True
similar_per_metacell[query_metacell_indices_of_type] = similar_per_metacell_of_type
projected_correlation_per_metacell[
query_metacell_indices_of_type
] = projected_correlation_per_metacell_of_type
misfit_per_gene_per_metacell[query_metacell_indices_of_type, :] = ut.to_numpy_matrix(
misfit_per_gene_per_metacell_of_type
)
projected_fold_per_gene_per_metacell[query_metacell_indices_of_type, :] = ut.to_numpy_matrix(
projected_fold_per_gene_per_metacell_of_type
)
weights_per_atlas_per_query_metacell[query_metacell_indices_of_type, :] = ut.to_numpy_matrix(
weights_per_atlas_per_query_metacell_of_type
)
if top_level_parallel:
for type_index, result in enumerate(ut.parallel_map(_single_type_projection, len(query_type_names))):
_collect_single_type_result(type_index, **result)
else:
for type_index in range(len(query_type_names)):
result = _single_type_projection(type_index)
_collect_single_type_result(type_index, **result)
if repeat > 2 or not _changed_projected_types(
common_adata=common_adata,
common_qdata=common_qdata,
old_type_of_query_metacells=type_of_query_metacells,
weights_per_atlas_per_query_metacell=weights_per_atlas_per_query_metacell,
atlas_type_property_name=atlas_type_property_name,
old_types_per_metacell=old_types_per_metacell,
):
ut.log_calc("types last repeat", repeat)
break
for type_name, fitted_genes_mask_of_type in fitted_genes_mask_per_type.items():
ut.set_v_data(common_qdata, f"fitted_gene_of_{type_name}", fitted_genes_mask_of_type)
projected_type_per_metacell = ut.get_o_numpy(common_qdata, "projected_type")
fitted_mask_per_gene_per_metacell = np.vstack(
[fitted_genes_mask_per_type[type_name] for type_name in projected_type_per_metacell]
)
ut.set_vo_data(common_qdata, "fitted", fitted_mask_per_gene_per_metacell)
ut.set_o_data(common_qdata, "similar", similar_per_metacell)
ut.set_o_data(common_qdata, "projected_correlation", projected_correlation_per_metacell)
ut.set_vo_data(common_qdata, "misfit", misfit_per_gene_per_metacell)
ut.set_vo_data(common_qdata, "projected_fold", projected_fold_per_gene_per_metacell)
tl.compute_projected_fractions(
adata=common_adata,
qdata=common_qdata,
log_data=project_log_data,
fold_regularization=project_fold_regularization,
weights=weights_per_atlas_per_query_metacell,
)
return weights_per_atlas_per_query_metacell
def _initial_fitted_genes_mask_per_type(
common_qdata: AnnData,
type_names: List[str],
preliminary_fitted_genes_mask: ut.NumpyVector,
) -> Dict[str, ut.NumpyVector]:
fitted_genes_mask_per_type: Dict[str, ut.NumpyVector] = {}
for type_name in type_names:
fitted_genes_mask_of_type = preliminary_fitted_genes_mask.copy()
property_name = f"ignored_gene_of_{type_name}"
if ut.has_data(common_qdata, property_name):
ignored_gene_mask_of_type = ut.get_v_numpy(common_qdata, property_name)
fitted_genes_mask_of_type &= ~ignored_gene_mask_of_type
fitted_genes_mask_per_type[type_name] = fitted_genes_mask_of_type
return fitted_genes_mask_per_type
@ut.logged()
@ut.timed_call()
def _compute_single_type_projection(
*,
type_name: str,
common_adata: AnnData,
common_qdata: AnnData,
fitted_genes_mask_per_type: Dict[str, ut.NumpyVector],
misfit_min_metacells_fraction: float,
project_log_data: bool,
project_fold_regularization: float,
project_candidates_count: int,
project_min_significant_gene_umis: float,
project_min_usage_weight: float,
project_max_consistency_fold_factor: float,
project_max_projection_fold_factor: float,
project_max_projection_noisy_fold_factor: float,
project_max_misfit_genes: int,
min_fitted_query_marker_genes: float,
min_essential_genes_of_type: Dict[Tuple[str, str], Optional[int]],
top_level_parallel: bool,
reproducible: bool,
) -> Dict[str, Any]:
projected_type_per_metacell = ut.get_o_numpy(common_qdata, "projected_type")
query_metacell_mask_of_type = projected_type_per_metacell == type_name
ut.log_calc("query_metacell_mask_of_type", query_metacell_mask_of_type)
assert np.any(query_metacell_mask_of_type)
type_common_qdata = ut.slice(common_qdata, name=f".{type_name}", obs=query_metacell_mask_of_type, top_level=False)
corrected_fractions = ut.get_vo_proper(type_common_qdata, "corrected_fraction")
fitted_genes_mask_of_type = fitted_genes_mask_per_type[type_name]
assert np.any(fitted_genes_mask_of_type)
repeat = 0
while True:
repeat += 1
ut.log_calc(f"{type_name} misfit repeat", repeat)
type_fitted_adata = ut.slice(
common_adata, name=f".{type_name}.fitted", vars=fitted_genes_mask_of_type, top_level=False
)
type_fitted_qdata = ut.slice(
type_common_qdata,
name=".fitted",
vars=fitted_genes_mask_of_type,
track_var="common_gene_index_of_qdata",
top_level=False,
)
type_weights = tl.compute_projection_weights(
adata=type_fitted_adata,
qdata=type_fitted_qdata,
log_data=project_log_data,
fold_regularization=project_fold_regularization,
min_significant_gene_umis=project_min_significant_gene_umis,
max_consistency_fold_factor=project_max_consistency_fold_factor,
candidates_count=project_candidates_count,
min_usage_weight=project_min_usage_weight,
reproducible=reproducible,
)
tl.compute_projected_fractions(
adata=common_adata,
qdata=type_common_qdata,
log_data=project_log_data,
fold_regularization=project_fold_regularization,
weights=type_weights,
)
projected_fractions = ut.get_vo_proper(type_common_qdata, "projected_fraction")
tl.compute_projected_folds(
type_common_qdata,
fold_regularization=project_fold_regularization,
min_significant_gene_umis=project_min_significant_gene_umis,
)
projected_fold_per_gene_per_metacell_of_type = ut.get_vo_proper(type_common_qdata, "projected_fold")
if not _detect_type_misfit_genes(
type_common_qdata=type_common_qdata,
projected_fold_per_gene_per_metacell_of_type=projected_fold_per_gene_per_metacell_of_type,
max_projection_fold_factor=project_max_projection_fold_factor,
max_projection_noisy_fold_factor=project_max_projection_noisy_fold_factor,
misfit_min_metacells_fraction=misfit_min_metacells_fraction,
fitted_genes_mask_of_type=fitted_genes_mask_of_type,
):
ut.log_calc(f"{type_name} misfit last repeat", repeat)
break
tl.compute_similar_query_metacells(
type_common_qdata,
max_projection_fold_factor=project_max_projection_fold_factor,
max_projection_noisy_fold_factor=project_max_projection_noisy_fold_factor,
max_misfit_genes=project_max_misfit_genes,
min_fitted_query_marker_genes=min_fitted_query_marker_genes,
essential_genes_property=f"essential_gene_of_{type_name}",
min_essential_genes=min_essential_genes_of_type[(type_name, type_name)],
fitted_genes_mask=fitted_genes_mask_of_type,
)
similar_per_metacell_of_type = ut.get_o_numpy(type_common_qdata, "similar")
misfit_per_gene_per_metacell_of_type = ut.get_vo_proper(type_common_qdata, "misfit")
fitted_corrected_fractions = corrected_fractions[:, fitted_genes_mask_of_type]
fitted_projected_fractions = projected_fractions[:, fitted_genes_mask_of_type]
projected_correlation_per_metacell_of_type = ut.pairs_corrcoef_rows(
ut.to_layout(ut.to_numpy_matrix(fitted_corrected_fractions), layout="row_major"),
ut.to_layout(ut.to_numpy_matrix(fitted_projected_fractions), layout="row_major"),
reproducible=reproducible,
)
if top_level_parallel:
if not isinstance(misfit_per_gene_per_metacell_of_type, sp.csr_matrix):
misfit_per_gene_per_metacell_of_type = sp.csr_matrix(misfit_per_gene_per_metacell_of_type)
if not isinstance(projected_fold_per_gene_per_metacell_of_type, sp.csr_matrix):
projected_fold_per_gene_per_metacell_of_type = sp.csr_matrix(projected_fold_per_gene_per_metacell_of_type)
if not isinstance(type_weights, sp.csr_matrix):
type_weights = sp.csr_matrix(type_weights)
ut.log_return("query_metacell_mask_of_type", query_metacell_mask_of_type)
ut.log_return("fitted_genes_mask_of_type", fitted_genes_mask_of_type)
ut.log_return("similar_per_metacell_of_type", similar_per_metacell_of_type)
ut.log_return("misfit_per_gene_per_metacell_of_type", misfit_per_gene_per_metacell_of_type)
ut.log_return("projected_correlation_per_metacell_of_type", projected_correlation_per_metacell_of_type)
ut.log_return("projected_fold_per_gene_per_metacell_of_type", projected_fold_per_gene_per_metacell_of_type)
ut.log_return("weights", type_weights)
return {
"query_metacell_indices_of_type": np.where(query_metacell_mask_of_type)[0],
"fitted_genes_indices_of_type": np.where(fitted_genes_mask_of_type)[0],
"similar_per_metacell_of_type": similar_per_metacell_of_type,
"misfit_per_gene_per_metacell_of_type": misfit_per_gene_per_metacell_of_type,
"projected_correlation_per_metacell_of_type": projected_correlation_per_metacell_of_type,
"projected_fold_per_gene_per_metacell_of_type": projected_fold_per_gene_per_metacell_of_type,
"weights_per_atlas_per_query_metacell_of_type": type_weights,
}
def _detect_type_misfit_genes(
*,
type_common_qdata: AnnData,
projected_fold_per_gene_per_metacell_of_type: ut.ProperMatrix,
max_projection_fold_factor: float,
max_projection_noisy_fold_factor: float,
misfit_min_metacells_fraction: float,
fitted_genes_mask_of_type: ut.NumpyVector,
) -> bool:
assert max_projection_fold_factor >= 0
assert max_projection_noisy_fold_factor >= 0
assert 0 <= misfit_min_metacells_fraction <= 1
if ut.has_data(type_common_qdata, "projected_noisy_gene"):
max_projection_fold_factor_per_gene = np.full(
type_common_qdata.n_vars, max_projection_fold_factor, dtype="float32"
)
noisy_per_gene = ut.get_v_numpy(type_common_qdata, "projected_noisy_gene")
max_projection_fold_factor_per_gene[noisy_per_gene] += max_projection_noisy_fold_factor
high_projection_fold_per_gene_per_metacell_of_type = (
projected_fold_per_gene_per_metacell_of_type > max_projection_fold_factor_per_gene[np.newaxis, :]
)
else:
high_projection_fold_per_gene_per_metacell_of_type = (
ut.to_numpy_matrix(projected_fold_per_gene_per_metacell_of_type) > max_projection_fold_factor
)
ut.log_calc(
"high_projection_fold_per_gene_per_metacell_of_type", high_projection_fold_per_gene_per_metacell_of_type
)
high_projection_metacells_per_gene = ut.sum_per(
ut.to_layout(high_projection_fold_per_gene_per_metacell_of_type, layout="column_major"), per="column"
)
ut.log_calc(
"high_projection_metacells_per_gene", high_projection_metacells_per_gene, formatter=ut.sizes_description
)
min_high_projection_metacells = type_common_qdata.n_obs * misfit_min_metacells_fraction
ut.log_calc("min_high_projection_metacells", min_high_projection_metacells)
high_projection_genes_mask = high_projection_metacells_per_gene >= min_high_projection_metacells
type_misfit_genes_mask = fitted_genes_mask_of_type & high_projection_genes_mask
ut.log_calc("type_misfit_genes_mask", type_misfit_genes_mask)
if not np.any(type_misfit_genes_mask):
return False
fitted_genes_mask_of_type[type_misfit_genes_mask] = False
ut.log_calc("fitted_genes_mask_of_type", fitted_genes_mask_of_type)
return True
@ut.logged()
@ut.timed_call()
def _changed_projected_types(
*,
common_adata: AnnData,
common_qdata: AnnData,
old_type_of_query_metacells: ut.NumpyVector,
weights_per_atlas_per_query_metacell: ut.NumpyMatrix,
atlas_type_property_name: str,
old_types_per_metacell: List[Set[str]],
) -> bool:
tl.convey_atlas_to_query(
adata=common_adata,
qdata=common_qdata,
weights=weights_per_atlas_per_query_metacell,
property_name=atlas_type_property_name,
to_property_name="projected_type",
)
new_type_of_query_metacells = ut.get_o_numpy(common_qdata, "projected_type")
has_changed = False
for metacell_index, old_types_of_metacell in enumerate(old_types_per_metacell):
old_type = old_type_of_query_metacells[metacell_index]
old_types_of_metacell.add(old_type)
new_type = new_type_of_query_metacells[metacell_index]
if new_type not in old_types_of_metacell:
ut.log_calc(f"metacell: {metacell_index} changed from old type: {old_type} to new type", new_type)
has_changed = True
elif new_type != old_type:
ut.log_calc(f"metacell: {metacell_index} changed from old type: {old_type} to older type", new_type)
ut.log_return("has_changed", has_changed)
return has_changed
@ut.timed_call()
@ut.logged()
def _compute_dissimilar_residuals_projection(
*,
common_adata: AnnData,
common_qdata: AnnData,
weights_per_atlas_per_query_metacell: ut.NumpyMatrix,
project_log_data: bool,
project_fold_regularization: float,
project_candidates_count: int,
project_min_candidates_fraction: float,
project_min_significant_gene_umis: float,
project_min_usage_weight: float,
project_max_consistency_fold_factor: float,
project_max_projection_fold_factor: float,
project_max_projection_noisy_fold_factor: float,
project_max_misfit_genes: int,
min_fitted_query_marker_genes: float,
min_essential_genes_of_type: Dict[Tuple[str, str], Optional[int]],
atlas_type_property_name: str,
top_level_parallel: bool,
reproducible: bool,
) -> None:
secondary_type = [""] * common_qdata.n_obs
dissimilar_metacells_mask = ~ut.get_o_numpy(common_qdata, "similar")
if not np.any(dissimilar_metacells_mask):
ut.set_o_data(common_qdata, "projected_secondary_type", np.array(secondary_type))
return
dissimilar_metacell_indices = np.where(dissimilar_metacells_mask)[0]
@ut.timed_call("single_metacell_residuals")
def _single_metacell_residuals(
dissimilar_metacell_position: int,
) -> Optional[Dict[str, Any]]:
return _compute_single_metacell_residuals(
dissimilar_metacell_index=dissimilar_metacell_indices[dissimilar_metacell_position],
common_adata=common_adata,
common_qdata=common_qdata,
project_log_data=project_log_data,
project_fold_regularization=project_fold_regularization,
project_candidates_count=project_candidates_count,
project_min_candidates_fraction=project_min_candidates_fraction,
project_min_significant_gene_umis=project_min_significant_gene_umis,
project_min_usage_weight=project_min_usage_weight,
project_max_consistency_fold_factor=project_max_consistency_fold_factor,
project_max_projection_fold_factor=project_max_projection_fold_factor,
project_max_projection_noisy_fold_factor=project_max_projection_noisy_fold_factor,
project_max_misfit_genes=project_max_misfit_genes,
min_fitted_query_marker_genes=min_fitted_query_marker_genes,
min_essential_genes_of_type=min_essential_genes_of_type,
atlas_type_property_name=atlas_type_property_name,
reproducible=reproducible,
)
if top_level_parallel:
results = ut.parallel_map(_single_metacell_residuals, len(dissimilar_metacell_indices))
else:
results = [
_single_metacell_residuals(dissimilar_metacell_position)
for dissimilar_metacell_position in range(len(dissimilar_metacell_indices))
]
similar_per_metacell = ut.get_o_numpy(common_qdata, "similar").copy()
primary_type_per_metacell = ut.get_o_numpy(common_qdata, "projected_type").copy()
secondary_type_per_metacell = [""] * common_qdata.n_obs
fitted_per_gene_per_metacell = ut.to_numpy_matrix(ut.get_vo_proper(common_qdata, "fitted"), copy=True)
misfit_per_gene_per_metacell = ut.to_numpy_matrix(ut.get_vo_proper(common_qdata, "misfit"), copy=True)
projected_correlation_per_metacell = ut.get_o_numpy(common_qdata, "projected_correlation").copy()
projected_fold_per_gene_per_metacell = ut.to_numpy_matrix(
ut.get_vo_proper(common_qdata, "projected_fold"), copy=True
)
def _collect_metacell_residuals(
dissimilar_metacell_index: int,
primary_type: str,
secondary_type: str,
fitted_genes_mask: ut.NumpyVector,
misfit_genes_mask: ut.NumpyVector,
projected_correlation: float,
projected_fold_per_gene: ut.NumpyVector,
weights: ut.ProperMatrix,
) -> None:
similar_per_metacell[dissimilar_metacell_index] = True
primary_type_per_metacell[dissimilar_metacell_index] = primary_type
secondary_type_per_metacell[dissimilar_metacell_index] = secondary_type
projected_correlation_per_metacell[dissimilar_metacell_index] = projected_correlation
fitted_per_gene_per_metacell[dissimilar_metacell_index, :] = fitted_genes_mask
misfit_per_gene_per_metacell[dissimilar_metacell_index, :] = misfit_genes_mask
projected_fold_per_gene_per_metacell[dissimilar_metacell_index, :] = projected_fold_per_gene
weights_per_atlas_per_query_metacell[dissimilar_metacell_index, :] = weights
for dissimilar_metacell_index, result in zip(dissimilar_metacell_indices, results):
if result is not None:
_collect_metacell_residuals(dissimilar_metacell_index, **result)
ut.set_o_data(common_qdata, "similar", similar_per_metacell)
ut.set_o_data(common_qdata, "projected_type", primary_type_per_metacell)
ut.set_o_data(common_qdata, "projected_secondary_type", np.array(secondary_type_per_metacell))
ut.set_o_data(common_qdata, "projected_correlation", projected_correlation_per_metacell)
ut.set_vo_data(common_qdata, "fitted", fitted_per_gene_per_metacell)
ut.set_vo_data(common_qdata, "misfit", misfit_per_gene_per_metacell)
ut.set_vo_data(common_qdata, "projected_fold", projected_fold_per_gene_per_metacell)
@ut.timed_call()
@ut.logged()
def _compute_single_metacell_residuals( # pylint: disable=too-many-statements
*,
dissimilar_metacell_index: int,
common_adata: AnnData,
common_qdata: AnnData,
project_log_data: bool,
project_fold_regularization: float,
project_candidates_count: int,
project_min_candidates_fraction: float,
project_min_significant_gene_umis: float,
project_min_usage_weight: float,
project_max_consistency_fold_factor: float,
project_max_projection_fold_factor: float,
project_max_projection_noisy_fold_factor: float,
project_max_misfit_genes: int,
min_fitted_query_marker_genes: float,
min_essential_genes_of_type: Dict[Tuple[str, str], Optional[int]],
atlas_type_property_name: str,
reproducible: bool,
) -> Optional[Dict[str, Any]]:
dissimilar_qdata = ut.slice(
common_qdata, name=f".dissimilar.{dissimilar_metacell_index}", obs=[dissimilar_metacell_index], top_level=False
)
corrected_fractions = ut.get_vo_proper(dissimilar_qdata, "corrected_fraction")
primary_type = ut.get_o_numpy(dissimilar_qdata, "projected_type")[0]
secondary_type = ""
repeat = 0
while True:
repeat += 1
ut.log_calc("residuals repeat", repeat)
ut.log_calc("primary_type", primary_type)
ut.log_calc("secondary_type", secondary_type)
fitted_genes_mask = ut.get_v_numpy(dissimilar_qdata, f"fitted_gene_of_{primary_type}")
if secondary_type != "":
fitted_genes_mask = fitted_genes_mask | ut.get_v_numpy(dissimilar_qdata, f"fitted_gene_of_{secondary_type}")
fitted_adata = ut.slice(
common_adata,
name=f".dissimilar.{dissimilar_metacell_index}.fitted",
vars=fitted_genes_mask,
top_level=False,
)
metacell_fitted_qdata = ut.slice(dissimilar_qdata, name=".fitted", vars=fitted_genes_mask, top_level=False)
second_anchor_indices: List[int] = []
weights = tl.compute_projection_weights(
adata=fitted_adata,
qdata=metacell_fitted_qdata,
log_data=project_log_data,
fold_regularization=project_fold_regularization,
min_significant_gene_umis=project_min_significant_gene_umis,
max_consistency_fold_factor=project_max_consistency_fold_factor,
candidates_count=project_candidates_count,
min_candidates_fraction=project_min_candidates_fraction,
min_usage_weight=project_min_usage_weight,
reproducible=reproducible,
second_anchor_indices=second_anchor_indices,
)
tl.compute_projected_fractions(
adata=common_adata,
qdata=dissimilar_qdata,
log_data=project_log_data,
fold_regularization=project_fold_regularization,
weights=weights,
)
projected_fractions = ut.get_vo_proper(dissimilar_qdata, "projected_fraction")
first_anchor_weights = weights.copy()
first_anchor_weights[:, second_anchor_indices] = 0.0
if len(second_anchor_indices) == 0:
new_secondary_type = ""
else:
second_anchor_weights = weights - first_anchor_weights # type: ignore
tl.convey_atlas_to_query(
adata=common_adata,
qdata=dissimilar_qdata,
weights=second_anchor_weights,
property_name=atlas_type_property_name,
to_property_name="projected_secondary_type",
)
new_secondary_type = ut.get_o_numpy(dissimilar_qdata, "projected_secondary_type")[0]
if np.sum(first_anchor_weights.data) == 0:
new_primary_type = new_secondary_type
new_secondary_type = ""
else:
tl.convey_atlas_to_query(
adata=common_adata,
qdata=dissimilar_qdata,
weights=first_anchor_weights,
property_name=atlas_type_property_name,
to_property_name="projected_type",
)
new_primary_type = ut.get_o_numpy(dissimilar_qdata, "projected_type")[0]
if repeat > 2 or (new_primary_type == primary_type and new_secondary_type == secondary_type):
ut.log_calc("residuals last repeat", repeat)
break
primary_type = new_primary_type
secondary_type = new_secondary_type
projected_correlation = ut.pairs_corrcoef_rows(
ut.to_numpy_matrix(corrected_fractions), ut.to_numpy_matrix(projected_fractions), reproducible=reproducible
)[0]
tl.compute_projected_folds(
dissimilar_qdata,
fold_regularization=project_fold_regularization,
min_significant_gene_umis=project_min_significant_gene_umis,
)
projected_fold_per_gene = ut.to_numpy_vector(ut.get_vo_proper(dissimilar_qdata, "projected_fold"))
essential_genes_properties = [f"essential_gene_of_{primary_type}"]
if secondary_type == "":
min_essential_genes = min_essential_genes_of_type[(primary_type, primary_type)]
else:
min_essential_genes = min_essential_genes_of_type[(primary_type, secondary_type)]
essential_genes_properties.append(f"essential_gene_of_{secondary_type}")
tl.compute_similar_query_metacells(
dissimilar_qdata,
max_projection_fold_factor=project_max_projection_fold_factor,
max_projection_noisy_fold_factor=project_max_projection_noisy_fold_factor,
max_misfit_genes=project_max_misfit_genes,
essential_genes_property=essential_genes_properties,
min_essential_genes=min_essential_genes,
min_fitted_query_marker_genes=min_fitted_query_marker_genes,
fitted_genes_mask=fitted_genes_mask,
)
similar = ut.get_o_numpy(dissimilar_qdata, "similar")[0]
ut.log_return("similar", False)
if not similar:
return None
misfit_genes_mask = ut.to_numpy_vector(ut.get_vo_proper(dissimilar_qdata, "misfit")[0, :])
ut.log_return("primary_type", primary_type)
ut.log_return("secondary_type", secondary_type)
ut.log_return("fitted_genes_mask", fitted_genes_mask)
ut.log_return("misfit_genes_mask", misfit_genes_mask)
ut.log_return("projected_correlation", projected_correlation)
ut.log_return("projected_fold_per_gene", projected_fold_per_gene)
return {
"primary_type": primary_type,
"secondary_type": secondary_type,
"fitted_genes_mask": fitted_genes_mask,
"misfit_genes_mask": misfit_genes_mask,
"projected_correlation": projected_correlation,
"projected_fold_per_gene": projected_fold_per_gene,
"weights": ut.to_numpy_vector(weights),
}
def _normalize_corrected_fractions(adata: AnnData, what: str) -> None:
fractions_data = ut.get_vo_proper(adata, what, layout="row_major")
corrected_data = ut.fraction_by(fractions_data, by="row")
ut.set_vo_data(adata, "corrected_fraction", corrected_data)
[docs]
@ut.logged()
@ut.timed_call()
@ut.expand_doc()
def outliers_projection_pipeline(
what: str = "__x__",
*,
adata: AnnData,
odata: AnnData,
fold_regularization: float = pr.outliers_fold_regularization,
project_min_significant_gene_umis: int = pr.project_min_significant_gene_umis,
reproducible: bool,
) -> None:
"""
Project outliers on an atlas.
**Returns**
Sets the following in ``odata``:
Per-Observation (Cell) Annotations
``atlas_most_similar``
For each observation (outlier), the index of the "most similar" atlas metacell.
Per-Variable Per-Observation (Gene-Cell) Annotations
``atlas_most_similar_fold``
The fold factor between the outlier gene expression and their expression in the most similar atlas metacell
(unless the value is too low to be of interest, in which case it will be zero).
**Computation Parameters**
1. Invoke :py:func:`metacells.tools.quality.compute_outliers_matches` using the ``fold_regularization``
(default: {fold_regularization}) and ``reproducible``.
2. Invoke :py:func:`metacells.tools.quality.compute_outliers_fold_factors` using the ``fold_regularization``
(default: {fold_regularization}).
"""
if list(odata.var_names) != list(adata.var_names):
atlas_genes_list = list(adata.var_names)
query_genes_list = list(odata.var_names)
common_genes_list = list(sorted(set(atlas_genes_list) & set(query_genes_list)))
atlas_common_gene_indices = np.array([atlas_genes_list.index(gene) for gene in common_genes_list])
query_common_gene_indices = np.array([query_genes_list.index(gene) for gene in common_genes_list])
common_adata = ut.slice(adata, name=".common", vars=atlas_common_gene_indices, top_level=False)
common_odata = ut.slice(
odata,
name=".common",
vars=query_common_gene_indices,
top_level=False,
)
else:
common_adata = adata
common_odata = odata
tl.compute_outliers_matches(
what,
adata=common_odata,
gdata=common_adata,
most_similar="atlas_most_similar",
value_regularization=fold_regularization,
reproducible=reproducible,
)
tl.compute_outliers_fold_factors(
what,
adata=common_odata,
gdata=common_adata,
most_similar="atlas_most_similar",
min_gene_total=project_min_significant_gene_umis,
)
if list(odata.var_names) != list(adata.var_names):
atlas_most_similar = ut.get_o_numpy(common_odata, "atlas_most_similar")
ut.set_o_data(odata, "atlas_most_similar", atlas_most_similar)
common_folds = ut.get_vo_proper(common_odata, "atlas_most_similar_fold")
atlas_most_similar_fold = np.zeros(odata.shape, dtype="float32")
atlas_most_similar_fold[:, query_common_gene_indices] = common_folds
ut.set_vo_data(odata, "atlas_most_similar_fold", sp.csr_matrix(atlas_most_similar_fold))
[docs]
def write_projection_weights(path: str, adata: AnnData, qdata: AnnData, weights: ut.CompressedMatrix) -> None:
"""
Write into the ``path`` the ``weights`` computed for the projection of the query ``qdata`` on the atlas ``adata``.
Since the weights are (very) sparse, we just write them as a CSV file. This is also what ``MCView`` expect.
"""
with open(path, "w", encoding="utf8") as file:
file.write("query,atlas,weight\n")
for query_index, atlas_index in zip(*weights.nonzero()):
weight = weights[query_index, atlas_index]
file.write(f"{qdata.obs_names[query_index]},{adata.obs_names[atlas_index]},{weight}\n")