Source code for metacells.pipeline.select
"""
Selection
---------
"""
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import numpy as np
from anndata import AnnData # type: ignore
import metacells.parameters as pr
import metacells.tools as tl
import metacells.utilities as ut
__all__ = [
"extract_selected_data",
]
# pylint: disable=dangerous-default-value
[docs]
@ut.logged()
@ut.timed_call()
@ut.expand_doc()
def extract_selected_data( # pylint: disable=too-many-branches, too-many-statements
adata: AnnData,
what: Union[str, ut.Matrix] = "__x__",
*,
name: Optional[str] = ".select",
downsample_min_samples: float = pr.select_downsample_min_samples,
downsample_min_cell_quantile: float = pr.select_downsample_min_cell_quantile,
downsample_max_cell_quantile: float = pr.select_downsample_max_cell_quantile,
min_gene_relative_variance: Optional[float], # = pr.select_min_gene_relative_variance,
min_genes: int = pr.select_min_genes,
min_gene_total: Optional[int] = pr.select_min_gene_total,
min_gene_top3: Optional[int] = pr.select_min_gene_top3,
additional_gene_masks: List[str] = ["&~lateral_gene"],
random_seed: int,
top_level: bool = True,
) -> AnnData:
"""
Select a subset of ``what`` (default: {what} data, to compute metacells by.
When computing metacells (or clustering cells in general), it makes sense to use a subset of the genes for computing
cell-cell similarity, for both technical (e.g., too low an expression level) and biological (e.g., ignoring
bookkeeping and cell cycle genes) reasons. The steps provided here are expected to be generically useful, but as
always specific data sets may require custom gene selection steps on a case-by-case basis.
**Input**
A presumably "clean" Annotated ``adata``, where the observations are cells and the variables are
genes, where ``what`` is a per-variable-per-observation matrix or the name of a
per-variable-per-observation annotation containing such a matrix.
Will obey the following annotations in the full ``adata``, if they exist:
Variable (Gene) Annotations
``select_gene``
If exists, force a mask of genes to use as "select" genes, ignoring everything else.
``lateral_gene``
A boolean mask of genes which are lateral from being chosen as "select" genes based
on their name.
**Returns**
Returns annotated sliced data containing the "select" subset of the original data. By default,
the ``name`` of this data is {name}. If no selects were selected, return ``None``.
Also sets the following annotations in the full ``adata``:
Unstructured Annotations
``downsample_samples``
The target total number of samples in each downsampled cell.
Observation-Variable (Cell-Gene) Annotations:
``downsampled``
The downsampled data where the total number of samples in each cell is at most ``downsample_samples``.
Variable (Gene) Annotations
``high_total_gene``
A boolean mask of genes with "high" expression level (unless a ``select_gene`` mask exists).
``high_relative_variance_gene``
A boolean mask of genes with "high" normalized variance, relative to other genes with a similar expression
level (unless a ``select_gene`` mask exists).
``selected_gene``
A boolean mask of the actually selected genes.
**Computation Parameters**
0. If a ``select_gene`` mask exists, just use these genes and go directly to the last step 6.
1. Invoke :py:func:`metacells.tools.downsample.downsample_cells` to downsample the cells to the same total number of
UMIs, using the ``downsample_min_samples`` (default: {downsample_min_samples}), ``downsample_min_cell_quantile``
(default: {downsample_min_cell_quantile}), ``downsample_max_cell_quantile`` (default:
{downsample_max_cell_quantile}) and the ``random_seed`` (non-zero for reproducible results).
2. Invoke :py:func:`metacells.tools.high.find_high_total_genes` to select high-expression genes (based on the
downsampled data), using ``min_gene_total``.
3. Invoke :py:func:`metacells.tools.high.find_high_relative_variance_genes` to select high-variance genes (based on
the downsampled data), using ``min_gene_relative_variance``.
4. Compute the set of genes that pass the above test, as well as match the ``additional_gene_masks`` (default:
{additional_gene_masks}).
5. If we found less than ``min_genes`` genes (default: {min_genes}, and ``min_gene_relative_variance`` was
specified, try to achieve the required minimal number of genes by reducing the ``min_gene_relative_variance``. In
extreme cases, give up on the relative variance requirement altogether.
6. Invoke :py:func:`metacells.tools.filter.filter_data` to slice just the selected genes.
"""
assert min_genes > 0
tl.downsample_cells(
adata,
what,
downsample_min_samples=downsample_min_samples,
downsample_min_cell_quantile=downsample_min_cell_quantile,
downsample_max_cell_quantile=downsample_max_cell_quantile,
random_seed=random_seed,
)
results: Optional[Tuple[AnnData, ut.PandasSeries, ut.PandasSeries]] = None
if ut.has_data(adata, "select_gene"):
results = tl.filter_data(
adata,
name=name,
top_level=top_level,
track_var="full_gene_index",
var_masks=["&select_gene"],
mask_var="selected_gene",
)
assert results is not None
else:
var_masks = []
if min_gene_top3 is not None:
var_masks.append("&high_top3_gene")
tl.find_high_topN_genes(adata, "downsampled", topN=3, min_gene_topN=min_gene_top3)
if min_gene_total is not None:
var_masks.append("&high_total_gene")
tl.find_high_total_genes(adata, "downsampled", min_gene_total=min_gene_total)
if min_gene_relative_variance is not None:
var_masks.append("&high_relative_variance_gene")
tl.find_high_relative_variance_genes(
adata, "downsampled", min_gene_relative_variance=min_gene_relative_variance
)
candidate_genes_mask = tl.combine_masks(adata, var_masks + additional_gene_masks)
assert candidate_genes_mask is not None
if np.sum(candidate_genes_mask.values) >= min_genes:
results = tl.filter_data(
adata,
name=name,
top_level=top_level,
track_var="full_gene_index",
var_masks=var_masks + additional_gene_masks,
mask_var="selected_gene",
)
if results is None and min_gene_relative_variance is not None:
valid_candidate_genes_mask = tl.combine_masks(adata, var_masks[:-1] + additional_gene_masks)
assert valid_candidate_genes_mask is not None
is_valid_set = False
if np.sum(valid_candidate_genes_mask.values) <= min_genes:
var_masks.pop()
else:
while True:
high_gene_relative_variance = min_gene_relative_variance
min_gene_relative_variance -= 1 / 8
tl.find_high_relative_variance_genes(
adata, "downsampled", min_gene_relative_variance=min_gene_relative_variance
)
valid_candidate_genes_mask = tl.combine_masks(adata, var_masks + additional_gene_masks)
assert valid_candidate_genes_mask is not None
if np.sum(valid_candidate_genes_mask.values) >= min_genes:
is_valid_set = True
break
for _ in range(4):
mid_gene_relative_variance = (high_gene_relative_variance + min_gene_relative_variance) / 2.0
tl.find_high_relative_variance_genes(
adata, "downsampled", min_gene_relative_variance=mid_gene_relative_variance
)
mid_candidate_genes_mask = tl.combine_masks(adata, var_masks + additional_gene_masks)
assert mid_candidate_genes_mask is not None
if np.sum(mid_candidate_genes_mask.values) >= min_genes:
min_gene_relative_variance = mid_gene_relative_variance
valid_candidate_genes_mask = mid_candidate_genes_mask
is_valid_set = True
else:
high_gene_relative_variance = mid_gene_relative_variance
is_valid_set = False
if not is_valid_set:
ut.set_v_data(adata, "high_relative_variance_gene", valid_candidate_genes_mask.values)
results = tl.filter_data(
adata,
name=name,
top_level=top_level,
track_var="full_gene_index",
var_masks=var_masks + additional_gene_masks,
mask_var="selected_gene",
)
if results is None:
raise ValueError(f"Failed to select {min_genes} genes")
return results[0]